├── .gitignore
├── .vscode
│   └── settings.json
├── LICENSE
├── README.md
├── RLSbench
│   ├── __init__.py
│   ├── algorithms
│   │   ├── BN_adapt.py
│   │   ├── BN_adapt_adv.py
│   │   ├── CDAN.py
│   │   ├── COAL.py
│   │   ├── CORAL.py
│   │   ├── DANN.py
│   │   ├── ERM.py
│   │   ├── ERM_Adv.py
│   │   ├── SENTRY.py
│   │   ├── TENT.py
│   │   ├── algorithm.py
│   │   ├── fixmatch.py
│   │   ├── initializer.py
│   │   ├── noisy_student.py
│   │   ├── pseudolabel.py
│   │   └── single_model_algorithm.py
│   ├── collate_functions.py
│   ├── configs
│   │   ├── algorithm.py
│   │   ├── datasets.py
│   │   ├── supported.py
│   │   └── utils.py
│   ├── data_augmentation
│   │   ├── __init__.py
│   │   └── randaugment.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── data_utils.py
│   │   └── get_dataset.py
│   ├── helper.py
│   ├── label_shift_utils.py
│   ├── losses.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── cifar_efficientnet.py
│   │   ├── cifar_resnet.py
│   │   ├── clip.py
│   │   ├── domain_adversarial_network.py
│   │   ├── initializer.py
│   │   ├── mdd_net.py
│   │   ├── mimic_model.py
│   │   ├── mimic_tokenizer.py
│   │   ├── model_utils.py
│   │   └── transformers.py
│   ├── optimizer.py
│   ├── scheduler.py
│   ├── transforms.py
│   ├── utils.py
│   └── version.py
├── code_helper.md
├── dataset_scripts
│   ├── Imagenet
│   │   ├── ImageNet_reorg.py
│   │   ├── ImageNet_resize.py
│   │   ├── ImageNet_v2_reorg.py
│   │   ├── convert.sh
│   │   ├── convert_to_jpg.py
│   │   └── resize_ImageNet-C.sh
│   ├── convert.sh
│   ├── setup_BREEDs.sh
│   ├── setup_Imagenet.sh
│   ├── setup_Imagenet200.sh
│   ├── setup_camelyon.sh
│   ├── setup_cifar100c.sh
│   ├── setup_cifar10c.sh
│   ├── setup_domainnet.sh
│   ├── setup_fmow.sh
│   ├── setup_iwildcams.sh
│   ├── setup_office31.sh
│   ├── setup_officehome.sh
│   ├── setup_rxrx1.sh
│   ├── setup_visda.sh
│   └── visda_structure.py
├── images
│   ├── RLSbench_fig.png
│   └── datasets.png
├── pretrained_models
│   └── resnet18_imagenet32.pt
├── run_main.py
├── scripts
│   ├── eval_ERM.py
│   ├── run_ERM.sh
│   ├── run_adapt.py
│   └── run_tta.py
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
.vscode/*
__pycache__
logs
logs*
data
.DS_Store
wandb
LSBench.egg-info
# s3_config.sh
ecr_read_only.csv
label_shift_study_results

--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
{
    "python.formatting.provider": "black"
}
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

[![Website](https://img.shields.io/badge/www-Website-green)](https://sites.google.com/view/rlsbench/)


`RLSbench` is the official implementation of [RLSbench: Domain Adaptation Under Relaxed Label Shift](https://arxiv.org/abs/2302.03020). We release the dataset setup, code, and our logs/models from our paper.


### Using the Dataset in an Academic Setting

To try experiments in our RLSbench setup with limited resources or for fast prototyping, we recommend restricting experiments to a few datasets with fewer label distribution shift simulations. In particular, we observe that experiments on different domains in the CIFAR10, Entity13, Living17, Visda, and Retiring Adults datasets capture our main findings and failure modes. One can further restrict the simulated shifts to Dirichlet alpha values in {None, 1.0, 0.3}.


## Paper
For more details, refer to the accompanying paper:
[RLSbench: Domain Adaptation Under Relaxed Label Shift](https://arxiv.org/abs/2302.03020). If you have questions, please feel free to reach us at sgarg2@andrew.cmu.edu or open an issue.

![Setup](images/RLSbench_fig.png)


If you find this repository useful or use this code in your research, please cite the following paper:

> Garg, S., Erickson, N., Sharpnack, J., Smola, A., Balakrishnan, S., Lipton, Z. (2023). RLSbench: Domain Adaptation Under Relaxed Label Shift. In International Conference on Machine Learning (ICML).
```
@inproceedings{garg2023RLSBench,
    title={RLSbench: Domain Adaptation Under Relaxed Label Shift},
    author={Garg, Saurabh and Erickson, Nick and Sharpnack, James and Smola, Alex and Balakrishnan, Sivaraman and Lipton, Zachary},
    year={2023},
    booktitle={International Conference on Machine Learning (ICML)}
}
```

### Installation

To install `RLSbench`, run:
```bash
pip install -e .
```

To install the latest version of `RLSbench` as a package, run:

```bash
pip install git+https://github.com/acmi-lab/RLSbench
```

## Datasets
![Datasets](images/datasets.png)

## Dataset Setup
To set up the different datasets, run the scripts in the `dataset_scripts` folder. Except for the ImageNet dataset, which must be downloaded from the [official website](https://www.image-net.org/download.php), the scripts set up all the datasets (including all the source and target pairs) used in our study.

## Overview

The following are the crucial parts of the overall pipeline:

1. `label_shift_utils.py`: This file contains utility functions to simulate label shift in the target data.
2. `./datasets/get_dataset.py`: This file contains the code to get the source and target datasets.
3. `./algorithms/`: This folder contains the code for the different algorithms. We implement the following domain adaptation algorithms:
    - ERM variants: ERM, ERM-aug, with different pretraining techniques like ['rand', 'imagenet', 'clip', 'bert']
    - Domain alignment methods: DANN, CDAN, IW-DANN, IW-CDAN
    - Self-training methods: Noisy student, Pseudolabeling, FixMatch, SENTRY
    - Test-time adaptation methods: BN_adapt, TENT, CORAL


The entry point of the code is `run_main.py`. The `configs` folder contains the default parameters and hyperparameters needed for the base experiments. We pass the dataset name and the algorithm name to `run_main.py` with the flags `--dataset` and `--algorithm`. To simulate label shift, we pass the flag `--simulate_label_shift` together with the Dirichlet sampling parameter `--dirchilet_alpha`. The flag `--root_dir` specifies the data directory for the source and target datasets (see the example command at the end of this README).

Caveat: for test-time adaptation (TTA) methods, we need to provide the folder containing ERM-aug trained models via the parameter `--source_model_path`.

### Results Logging
The code evaluates the trained models and logs the results as a CSV file in the `./logs/` folder.

## License
This repository is licensed under the terms of the [Apache License](LICENSE).
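
### Example Command

The command below gives a concrete illustration of the flags described in the Overview above. The dataset and algorithm names and the alpha value are placeholders, and the supported options and defaults live in the `RLSbench/configs/` folder:

```bash
python run_main.py \
    --dataset cifar10 \
    --algorithm ERM-aug \
    --root_dir ./data \
    --simulate_label_shift \
    --dirchilet_alpha 0.5
```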
75 | 76 | -------------------------------------------------------------------------------- /RLSbench/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acmi-lab/RLSbench/eb67d5c78aa3646b7369830e481b3f15a59a087d/RLSbench/__init__.py -------------------------------------------------------------------------------- /RLSbench/algorithms/BN_adapt.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from RLSbench.algorithms.algorithm import Algorithm 5 | from RLSbench.models.initializer import initialize_model 6 | from RLSbench.utils import load, move_to 7 | 8 | logger = logging.getLogger("label_shift") 9 | 10 | 11 | class BN_adapt(Algorithm): 12 | def __init__(self, config): 13 | logger.info("Initializing model...") 14 | 15 | model = initialize_model( 16 | model_name=config.model, 17 | dataset_name=config.dataset, 18 | num_classes=config.num_classes, 19 | featurize=False, 20 | pretrained=False, 21 | ) 22 | 23 | model.to(config.device) 24 | 25 | # initialize module 26 | super().__init__( 27 | device=config.device, 28 | ) 29 | 30 | self.model = model 31 | 32 | self.source_balanced = config.source_balanced 33 | self.num_classes = config.num_classes 34 | 35 | def get_model_output(self, x): 36 | outputs = self.model(x) 37 | return outputs 38 | 39 | def process_batch(self, batch): 40 | """ 41 | A helper function for update() and evaluate() that processes the batch 42 | Args: 43 | - batch (tuple of Tensors): a batch of data yielded by data loaders 44 | Output: 45 | - results (dictionary): information about the batch 46 | - y_true (Tensor): ground truth labels for batch 47 | - y_pred (Tensor): model output for batch 48 | """ 49 | x, y_true = batch[:2] 50 | x = move_to(x, self.device) 51 | y_true = move_to(y_true, self.device) 52 | 53 | outputs = self.get_model_output(x) 54 | 55 | results = { 56 | "y_true": y_true, 57 | "y_pred": outputs, 58 | } 59 | return results 60 | 61 | def evaluate(self, batch): 62 | """ 63 | Process the batch and update the log, without updating the model 64 | Args: 65 | - batch (tuple of Tensors): a batch of data yielded by data loaders 66 | Output: 67 | - results (dictionary): information about the batch, such as: 68 | - y_true (Tensor) 69 | - outputs (Tensor) 70 | - y_pred (Tensor) 71 | """ 72 | assert not self.is_training 73 | results = self.process_batch(batch) 74 | return results 75 | 76 | def adapt( 77 | self, 78 | source_loader, 79 | target_loader, 80 | target_marginal=None, 81 | source_marginal=None, 82 | target_average=None, 83 | pretrained_path=None, 84 | ): 85 | """ 86 | Load the model and adapt it to the new data 87 | Args: 88 | - unlabeled_batch (tuple of Tensors): a batch of data yielded by unlabeled data loader 89 | - target_marginal (Tensor): the marginal distribution of the target 90 | - source_marginal (Tensor): the marginal distribution of the source 91 | - target_average (Tensor): the average of the target 92 | 93 | Output: 94 | """ 95 | 96 | if pretrained_path is not None: 97 | logger.info(f"Loading pretrained model from {pretrained_path}") 98 | load(self.model, pretrained_path, device=self.device) 99 | 100 | # self.train(True) 101 | 102 | logger.info("Adapting model to BN params ...") 103 | 104 | with torch.no_grad(): 105 | for batch in target_loader: 106 | inp = batch[0].to(self.device) 107 | self.model(inp) 108 | 109 | def reset(self): 110 | pass 111 | 
-------------------------------------------------------------------------------- /RLSbench/algorithms/BN_adapt_adv.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from RLSbench.algorithms.algorithm import Algorithm 5 | from RLSbench.models.initializer import initialize_model 6 | from RLSbench.utils import load, move_to 7 | from robustness.attacker import AttackerModel 8 | 9 | logger = logging.getLogger("label_shift") 10 | 11 | 12 | _DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406] 13 | _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225] 14 | 15 | 16 | class BN_adapt_adv(Algorithm): 17 | def __init__(self, config): 18 | logger.info("Initializing model...") 19 | 20 | model = initialize_model( 21 | model_name=config.model, 22 | dataset_name=config.dataset, 23 | num_classes=config.num_classes, 24 | featurize=False, 25 | pretrained=False, 26 | ) 27 | 28 | model = AttackerModel( 29 | model, 30 | torch.tensor(_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN), 31 | torch.tensor(_DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD), 32 | ) 33 | 34 | model.to(config.device) 35 | 36 | # initialize module 37 | super().__init__( 38 | device=config.device, 39 | ) 40 | 41 | self.model = model 42 | 43 | self.source_balanced = config.source_balanced 44 | self.num_classes = config.num_classes 45 | 46 | def get_model_output(self, x): 47 | outputs = self.model(x) 48 | return outputs 49 | 50 | def process_batch(self, batch): 51 | """ 52 | A helper function for update() and evaluate() that processes the batch 53 | Args: 54 | - batch (tuple of Tensors): a batch of data yielded by data loaders 55 | Output: 56 | - results (dictionary): information about the batch 57 | - y_true (Tensor): ground truth labels for batch 58 | - y_pred (Tensor): model output for batch 59 | """ 60 | x, y_true = batch[:2] 61 | x = move_to(x, self.device) 62 | y_true = move_to(y_true, self.device) 63 | 64 | outputs = self.model(x) 65 | 66 | results = { 67 | "y_true": y_true, 68 | "y_pred": outputs, 69 | } 70 | return results 71 | 72 | def evaluate(self, batch): 73 | """ 74 | Process the batch and update the log, without updating the model 75 | Args: 76 | - batch (tuple of Tensors): a batch of data yielded by data loaders 77 | Output: 78 | - results (dictionary): information about the batch, such as: 79 | - y_true (Tensor) 80 | - outputs (Tensor) 81 | - y_pred (Tensor) 82 | """ 83 | assert not self.is_training 84 | results = self.process_batch(batch) 85 | return results 86 | 87 | def adapt( 88 | self, 89 | source_loader, 90 | target_loader, 91 | target_marginal=None, 92 | source_marginal=None, 93 | target_average=None, 94 | pretrained_path=None, 95 | ): 96 | """ 97 | Load the model and adapt it to the new data 98 | Args: 99 | - unlabeled_batch (tuple of Tensors): a batch of data yielded by unlabeled data loader 100 | - target_marginal (Tensor): the marginal distribution of the target 101 | - source_marginal (Tensor): the marginal distribution of the source 102 | - target_average (Tensor): the average of the target 103 | 104 | Output: 105 | """ 106 | 107 | if pretrained_path is not None: 108 | logger.info(f"Loading pretrained model from {pretrained_path}") 109 | load(self.model, pretrained_path, device=self.device) 110 | 111 | # self.train(True) 112 | 113 | logger.info("Adapting model to BN params ...") 114 | 115 | with torch.no_grad(): 116 | for batch in target_loader: 117 | inp = batch[0].to(self.device) 118 | self.model(inp) 119 | 120 | def reset(self): 
121 |         pass
122 | 
--------------------------------------------------------------------------------
/RLSbench/algorithms/COAL.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Dict, List
3 | 
4 | import torch
5 | import torch.nn.functional as F
6 | from RLSbench.algorithms.single_model_algorithm import SingleModelAlgorithm
7 | from RLSbench.losses import initialize_loss
8 | from RLSbench.models.domain_adversarial_network import COALNetwork
9 | from RLSbench.models.initializer import initialize_model
10 | from RLSbench.models.model_utils import linear_probe
11 | from RLSbench.optimizer import initialize_optimizer_with_model_params
12 | from RLSbench.scheduler import LinearScheduleWithWarmupAndThreshold
13 | from RLSbench.utils import (
14 |     concat_input,
15 |     detach_and_clone,
16 |     move_to,
17 |     pseudolabel_multiclass_logits,
18 | )
19 | 
20 | logger = logging.getLogger("label_shift")
21 | 
22 | 
23 | def softmax_entropy(x: torch.Tensor) -> torch.Tensor:
24 |     """Entropy of softmax distribution from logits."""
25 |     return -(x.softmax(1) * x.log_softmax(1)).sum(1)
26 | 
27 | 
28 | class COAL(SingleModelAlgorithm):
29 |     """
30 |     COAL.
31 | 
32 |     Original paper:
33 |         @inproceedings{tan2020class,
34 |             title={Class-imbalanced domain adaptation: An empirical odyssey},
35 |             author={Tan, Shuhan and Peng, Xingchao and Saenko, Kate},
36 |             booktitle={European Conference on Computer Vision},
37 |             pages={585--602},
38 |             year={2020},
39 |             organization={Springer}
40 |         }
41 |     """
42 | 
43 |     def __init__(self, config, dataloader, loss_function, n_train_steps, **kwargs):
44 |         logger.info("Initializing COAL model")
45 | 
46 |         model = initialize_model(
47 |             model_name=config.model,
48 |             dataset_name=config.dataset,
49 |             num_classes=config.num_classes,
50 |             featurize=True,
51 |             pretrained=config.pretrained,
52 |             pretrained_path=config.pretrained_path,
53 |         )
54 | 
55 |         if config.algorithm.startswith("IW"):
56 |             self.use_target_marginal = True
57 |         else:
58 |             self.use_target_marginal = False
59 | 
60 |         if config.source_balanced or self.use_target_marginal:
61 |             loss = initialize_loss(loss_function, reduction="none")
62 |         else:
63 |             loss = initialize_loss(loss_function)
64 | 
65 |         # if config.pretrained:
66 |         #     featurizer, classifier = linear_probe(model, dataloader, device= config.device, progress_bar=config.progress_bar)
67 | 
68 |         model = COALNetwork(model[0], num_classes=config.num_classes)
69 | 
70 |         featurizer = model.featurizer
71 |         classifier = model.classifier
72 | 
73 |         if config.pretrained:
74 |             linear_probe(
75 |                 (featurizer, classifier),
76 |                 dataloader,
77 |                 device=config.device,
78 |                 progress_bar=config.progress_bar,
79 |             )
80 | 
81 |         parameters_to_optimize: List[Dict] = model.get_parameters_with_lr(
82 |             featurizer_lr=kwargs["featurizer_lr"],
83 |             classifier_lr=kwargs["classifier_lr"],
84 |             # discriminator_lr=kwargs["discriminator_lr"],
85 |         )
86 | 
87 |         self.optimizer = initialize_optimizer_with_model_params(
88 |             config, parameters_to_optimize
89 |         )
90 | 
91 |         # initialize module
92 |         super().__init__(
93 |             config=config,
94 |             model=model,
95 |             loss=loss,
96 |             n_train_steps=n_train_steps,
97 |         )
98 | 
99 |         # algorithm hyperparameters
100 |         self.confidence_threshold = kwargs["self_training_threshold"]
101 |         self.alpha = kwargs["alpha"]
102 |         self.process_pseudolabels_function = pseudolabel_multiclass_logits
103 | 
104 |         self.target_align = False
105 | 
106 |         self.source_balanced = config.source_balanced
107 |         self.num_classes =
config.num_classes 108 | 109 | def process_batch( 110 | self, 111 | batch, 112 | unlabeled_batch=None, 113 | target_marginal=None, 114 | source_marginal=None, 115 | target_average=None, 116 | ): 117 | """ 118 | Overrides single_model_algorithm.process_batch(). 119 | Args: 120 | - batch (tuple of Tensors): a batch of data yielded by data loaders 121 | - unlabeled_batch (tuple of Tensors or None): a batch of data yielded by unlabeled data loader 122 | Output: 123 | - results (dictionary): information about the batch 124 | - y_true (Tensor): ground truth labels for batch 125 | - y_pred (Tensor): model output for batch 126 | - unlabeled_y_pseudo (Tensor): pseudolabels on the unlabeled batch, already thresholded 127 | - unlabeled_y_pred (Tensor): model output on the unlabeled batch, already thresholded 128 | """ 129 | # Labeled examples 130 | x, y_true = batch[:2] 131 | x = move_to(x, self.device) 132 | y_true = move_to(y_true, self.device) 133 | 134 | n_lab = y_true.shape[0] 135 | 136 | # package the results 137 | results = { 138 | "y_true": y_true, 139 | } 140 | 141 | # TODO: Add target alignment if it is useful 142 | # alignment_dist = torch.divide(torch.tensor(target_marginal).to(self.device), torch.tensor(target_average).to(self.device)) 143 | 144 | if unlabeled_batch is not None: 145 | x_unlab = unlabeled_batch[0] 146 | x_unlab = move_to(x_unlab, self.device) 147 | 148 | x_cat = concat_input(x, x_unlab) 149 | outputs = self.model(x_cat) 150 | unlabeled_output = outputs[n_lab:] 151 | 152 | if self.target_align: 153 | ( 154 | unlabeled_y_pred, 155 | unlabeled_y_pseudo, 156 | pseudolabels_kept_frac, 157 | _, 158 | ) = self.process_pseudolabels_function( 159 | unlabeled_output, self.confidence_threshold, alignment_dist 160 | ) 161 | else: 162 | ( 163 | unlabeled_y_pred, 164 | unlabeled_y_pseudo, 165 | pseudolabels_kept_frac, 166 | _, 167 | ) = self.process_pseudolabels_function( 168 | unlabeled_output, self.confidence_threshold 169 | ) 170 | 171 | results["y_pred"] = outputs[:n_lab] 172 | results["unlabeled_y_pred"] = unlabeled_y_pred 173 | results["unlabeled_y_pseudo"] = detach_and_clone(unlabeled_y_pseudo) 174 | 175 | if self.source_balanced and source_marginal is not None: 176 | results["source_marginal"] = torch.tensor(source_marginal).to( 177 | self.device 178 | ) 179 | 180 | if self.use_target_marginal and target_marginal is not None: 181 | results["im_weights"] = torch.divide( 182 | torch.tensor(target_marginal).to(self.device), 183 | torch.tensor(source_marginal).to(self.device), 184 | ) 185 | results["target_marginal"] = torch.tensor(target_marginal).to( 186 | self.device 187 | ) 188 | 189 | x_unlab_copy = torch.clone(x_unlab) 190 | unlabeled_output_ent = self.model(x_unlab_copy, reverse=True) 191 | results["unlabeled_y_pred_ent"] = unlabeled_output_ent 192 | 193 | ## New edits below 194 | 195 | # outputs = self.model(x) 196 | # results['y_pred'] = outputs 197 | 198 | # outputs_unlab = self.model(x_unlab, reverse=True) 199 | # results['unlabeled_y_pred'] = outputs_unlab 200 | 201 | else: 202 | results["y_pred"] = self.get_model_output(x) 203 | pseudolabels_kept_frac = 0 204 | 205 | results["pseudolabels_kept_frac"] = pseudolabels_kept_frac 206 | 207 | return results 208 | 209 | def objective(self, results): 210 | # Labeled loss 211 | classification_loss = self.loss(results["y_pred"], results["y_true"]) 212 | 213 | if self.use_target_marginal: 214 | classification_loss = torch.mean( 215 | classification_loss * results["im_weights"][results["y_true"]] 216 | ) 217 | 218 | elif 
self.source_balanced: 219 | classification_loss = torch.mean( 220 | classification_loss 221 | / results["source_marginal"][results["y_true"]] 222 | / self.num_classes 223 | ) 224 | 225 | # Pseudolabeled loss 226 | if "unlabeled_y_pred" in results: 227 | loss_output = self.loss( 228 | results["unlabeled_y_pred"], 229 | results["unlabeled_y_pseudo"], 230 | ) 231 | 232 | if self.source_balanced: 233 | target_marginal = results["target_marginal"] 234 | target_marginal[target_marginal == 0] = 1.0 235 | 236 | loss_output = torch.mean( 237 | loss_output 238 | / target_marginal[results["unlabeled_y_pseudo"]] 239 | / self.num_classes 240 | ) 241 | 242 | elif self.use_target_marginal: 243 | loss_output = torch.mean(loss_output) 244 | 245 | consistency_loss = loss_output * results["pseudolabels_kept_frac"] 246 | 247 | y_pred_ent = results["unlabeled_y_pred_ent"] 248 | ent_loss = -self.alpha * torch.mean(softmax_entropy(y_pred_ent), dim=0) 249 | 250 | # import pdb; pdb.set_trace() 251 | else: 252 | consistency_loss = 0 253 | ent_loss = 0 254 | 255 | return classification_loss + ent_loss + consistency_loss 256 | -------------------------------------------------------------------------------- /RLSbench/algorithms/CORAL.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn as nn 5 | from RLSbench.algorithms.algorithm import Algorithm 6 | from RLSbench.models.initializer import initialize_model 7 | from RLSbench.models.model_utils import test_CORAL_params, train_CORAL 8 | from RLSbench.utils import load, move_to 9 | 10 | logger = logging.getLogger("label_shift") 11 | 12 | 13 | class CORALModel(nn.Module): 14 | def __init__(self, featurizer, classifier): 15 | super().__init__() 16 | self.featurizer = featurizer 17 | self.classifier = classifier 18 | # self.mean = None 19 | self.cov_inv = None 20 | 21 | def forward(self, x): 22 | features = self.featurizer(x) 23 | 24 | if self.cov_inv is not None: 25 | # centered = features - self.mean 26 | features = torch.mm(features, self.cov_inv) 27 | 28 | outputs = self.classifier(features) 29 | 30 | return outputs 31 | 32 | 33 | class CORAL(Algorithm): 34 | def __init__(self, config): 35 | logger.info("Initializing model...") 36 | 37 | if config.algorithm.startswith("IW"): 38 | self.use_target_marginal = True 39 | else: 40 | self.use_target_marginal = False 41 | 42 | model = initialize_model( 43 | model_name=config.model, 44 | dataset_name=config.dataset, 45 | num_classes=config.num_classes, 46 | featurize=True, 47 | pretrained=False, 48 | ) 49 | 50 | linear_layer = nn.Linear(model[0].d_out, config.num_classes) 51 | model = CORALModel(model[0], linear_layer) 52 | 53 | # initialize module 54 | super().__init__( 55 | device=config.device, 56 | ) 57 | 58 | model = model.to(config.device) 59 | 60 | self.model = model 61 | self.source_balanced = config.source_balanced 62 | self.num_classes = config.num_classes 63 | 64 | def get_model_output(self, x): 65 | return self.model(x) 66 | 67 | def process_batch(self, batch): 68 | """ 69 | A helper function for update() and evaluate() that processes the batch 70 | Args: 71 | - batch (tuple of Tensors): a batch of data yielded by data loaders 72 | Output: 73 | - results (dictionary): information about the batch 74 | - y_true (Tensor): ground truth labels for batch 75 | - y_pred (Tensor): model output for batch 76 | """ 77 | x, y_true = batch[:2] 78 | x = move_to(x, self.device) 79 | y_true = move_to(y_true, self.device) 80 | 81 | outputs = 
self.get_model_output(x)
82 | 
83 |         results = {
84 |             "y_true": y_true,
85 |             "y_pred": outputs,
86 |         }
87 |         return results
88 | 
89 |     def evaluate(self, batch):
90 |         """
91 |         Process the batch and update the log, without updating the model
92 |         Args:
93 |             - batch (tuple of Tensors): a batch of data yielded by data loaders
94 |         Output:
95 |             - results (dictionary): information about the batch, such as:
96 |                 - y_true (Tensor)
97 |                 - outputs (Tensor)
98 |                 - y_pred (Tensor)
99 |         """
100 |         assert not self.is_training
101 |         results = self.process_batch(batch)
102 |         return results
103 | 
104 |     def adapt(
105 |         self,
106 |         source_loader,
107 |         target_loader,
108 |         target_marginal=None,
109 |         source_marginal=None,
110 |         target_average=None,
111 |         pretrained_path=None,
112 |     ):
113 |         """
114 |         Load the source-trained model and adapt it to the target data.
115 |         Args:
116 |             - source_loader: loader over the labeled source data
117 |             - target_loader: loader over the unlabeled target data
118 |             - target_marginal (Tensor): the label marginal distribution of the target
119 |             - source_marginal (Tensor): the label marginal distribution of the source
120 |             - target_average (Tensor): the average model prediction on the target
121 |             - pretrained_path (str): optional path to a source-trained checkpoint
122 |         """
123 | 
124 |         if pretrained_path is not None:
125 |             logger.info(f"Loading pretrained model from {pretrained_path}")
126 | 
127 |             load(self.model, pretrained_path, device=self.device)
128 | 
129 |         # self.train(True)
130 | 
131 |         logger.info("Adapting model to CORAL ...")
132 | 
133 |         im_weights = None
134 | 
135 |         if self.use_target_marginal:
136 |             im_weights = torch.divide(
137 |                 torch.tensor(target_marginal).to(self.device),
138 |                 torch.tensor(source_marginal).to(self.device),
139 |             )
140 | 
141 |         self.model = train_CORAL(
142 |             self.model, source_loader, im_weights=im_weights, device=self.device
143 |         )
144 | 
145 |         # self.model.mean,
146 |         self.model.cov_inv = test_CORAL_params(
147 |             self.model, target_loader, device=self.device
148 |         )
149 | 
150 |         # self.model.mean = self.model.mean.to(self.device)
151 |         self.model.cov_inv = self.model.cov_inv.to(self.device)
152 | 
153 |     def reset(self):
154 |         # self.model.mean = None
155 |         self.model.cov_inv = None
156 | 
--------------------------------------------------------------------------------
/RLSbench/algorithms/DANN.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Dict, List
3 | 
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | from RLSbench.algorithms.single_model_algorithm import SingleModelAlgorithm
8 | from RLSbench.losses import initialize_loss
9 | from RLSbench.models.domain_adversarial_network import DomainAdversarialNetwork
10 | from RLSbench.models.initializer import initialize_model
11 | from RLSbench.models.model_utils import linear_probe
12 | from RLSbench.optimizer import initialize_optimizer_with_model_params
13 | from RLSbench.scheduler import CoeffSchedule
14 | from RLSbench.utils import concat_input, move_to
15 | from RLSbench.label_shift_utils import im_weights_update
16 | import torch.nn.functional as F
17 | from RLSbench.collate_functions import collate_fn_mimic
18 | 
19 | logger = logging.getLogger("label_shift")
20 | 
21 | 
22 | class DANN(SingleModelAlgorithm):
23 |     """
24 |     Domain-adversarial training of neural networks.
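    A gradient-reversal coefficient (scheduled by CoeffSchedule below) trains the
    featurizer adversarially against a domain discriminator; the IW variants
    additionally re-weight losses with importance weights estimated via
    im_weights_update.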
25 | 26 | Original paper: 27 | @inproceedings{dann, 28 | title={Domain-Adversarial Training of Neural Networks}, 29 | author={Ganin, Ustinova, Ajakan, Germain, Larochelle, Laviolette, Marchand and Lempitsky}, 30 | booktitle={Journal of Machine Learning Research 17}, 31 | year={2016} 32 | } 33 | """ 34 | 35 | def __init__( 36 | self, config, dataloader, loss_function, n_train_steps, n_domains=2, **kwargs 37 | ): 38 | logger.info("Initializing DANN models") 39 | 40 | if config.algorithm.startswith("IW"): 41 | self.use_target_marginal = True 42 | else: 43 | self.use_target_marginal = False 44 | 45 | if config.source_balanced or self.use_target_marginal: 46 | loss = initialize_loss(loss_function, reduction="none") 47 | else: 48 | loss = initialize_loss(loss_function) 49 | 50 | # Initialize model 51 | featurizer, classifier = initialize_model( 52 | model_name=config.model, 53 | dataset_name=config.dataset, 54 | num_classes=config.num_classes, 55 | featurize=True, 56 | pretrained=config.pretrained, 57 | pretrained_path=config.pretrained_path, 58 | data_dir=config.root_dir, 59 | ) 60 | 61 | self.im_weights = np.ones((config.num_classes, 1)) 62 | self.cov = np.zeros((config.num_classes, config.num_classes)) 63 | self.source_marginal = np.zeros((config.num_classes)) 64 | self.psuedo_marginal = np.zeros((config.num_classes)) 65 | self.source_num_samples = 0 66 | self.target_num_samples = 0 67 | 68 | # if config.pretrained : 69 | # featurizer, classifier = linear_probe( (featurizer, classifier), dataloader, device= config.device, progress_bar=config.progress_bar) 70 | 71 | model = DomainAdversarialNetwork( 72 | featurizer, classifier, n_domains, config.num_classes 73 | ) 74 | 75 | # featurizer = model.featurizer 76 | # classifier = nn.Sequential(model.bottleneck, model.classifier) 77 | 78 | # if config.pretrained : 79 | # linear_probe( (featurizer, classifier), dataloader, device= config.device, progress_bar=config.progress_bar) 80 | 81 | parameters_to_optimize: List[Dict] = model.get_parameters_with_lr( 82 | featurizer_lr=kwargs["featurizer_lr"], 83 | classifier_lr=kwargs["classifier_lr"], 84 | discriminator_lr=kwargs["discriminator_lr"], 85 | ) 86 | 87 | self.optimizer = initialize_optimizer_with_model_params( 88 | config, parameters_to_optimize 89 | ) 90 | 91 | # Initialize module 92 | super().__init__( 93 | config=config, 94 | model=model, 95 | loss=loss, 96 | n_train_steps=n_train_steps, 97 | ) 98 | 99 | self.coeff_schedule = CoeffSchedule(max_iter=n_train_steps) 100 | self.schedulers.append(self.coeff_schedule) 101 | 102 | self.domain_loss = initialize_loss("cross_entropy", reduction="none") 103 | 104 | # Algorithm hyperparameters 105 | self.penalty_weight = kwargs["penalty_weight"] 106 | self.source_balanced = config.source_balanced 107 | self.num_classes = config.num_classes 108 | 109 | self.dataset = config.dataset 110 | 111 | def process_batch( 112 | self, 113 | batch, 114 | unlabeled_batch=None, 115 | target_marginal=None, 116 | source_marginal=None, 117 | target_average=None, 118 | ): 119 | """ 120 | Overrides single_model_algorithm.process_batch(). 
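        When an unlabeled target batch is provided, source and target examples are
        concatenated into a single forward pass that produces class predictions for
        both and domain predictions for the discriminator.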
121 | Args: 122 | - batch (tuple of Tensors): a batch of data yielded by data loaders 123 | - unlabeled_batch (tuple of Tensors or None): a batch of data yielded by unlabeled data loader 124 | Output: 125 | - results (dictionary): information about the batch 126 | - y_true (Tensor): ground truth labels for batch 127 | - g (Tensor): groups for batch 128 | - metadata (Tensor): metadata for batch 129 | - y_pred (Tensor): model output for batch 130 | - domains_true (Tensor): true domains for batch and unlabeled batch 131 | - domains_pred (Tensor): predicted domains for batch and unlabeled batch 132 | - unlabeled_features (Tensor): featurizer outputs for unlabeled_batch 133 | """ 134 | # Forward pass 135 | ( 136 | x, 137 | y_true, 138 | ) = batch[:2] 139 | 140 | if unlabeled_batch is not None: 141 | unlabeled_x = unlabeled_batch[0] 142 | 143 | # Concatenate examples and true domains 144 | if "mimic" in self.dataset: 145 | # x_cat = collate_fn_mimic([x, unlabeled_x]) 146 | x_cat = [x[0] + unlabeled_x[0], x[1] + unlabeled_x[1]] 147 | domains_true = torch.cat( 148 | [ 149 | torch.zeros(len(x[0]), dtype=torch.long), 150 | torch.ones(len(unlabeled_x[0]), dtype=torch.long), 151 | ] 152 | ) 153 | 154 | else: 155 | x_cat = concat_input(x, unlabeled_x) 156 | domains_true = torch.cat( 157 | [ 158 | torch.zeros(len(x), dtype=torch.long), 159 | torch.ones(len(unlabeled_x), dtype=torch.long), 160 | ] 161 | ) 162 | 163 | x_cat = move_to(x_cat, self.device) 164 | # x_cat = x_cat.to(self.device) 165 | y_true = y_true.to(self.device) 166 | domains_true = domains_true.to(self.device) 167 | y_pred, domains_pred = self.model( 168 | x_cat, self.coeff_schedule.value, domain_classifier=True 169 | ) 170 | 171 | y_source_pred = y_pred[: len(y_true)] 172 | y_target_pred = y_pred[len(y_true) :] 173 | 174 | results = { 175 | "y_true": y_true, 176 | "y_pred": y_source_pred, 177 | "target_y_pred": y_target_pred, 178 | "domains_true": domains_true, 179 | "domains_pred": domains_pred, 180 | } 181 | 182 | # if self.use_target_marginal and target_marginal is not None: 183 | 184 | # results["im_weights"] = torch.divide(torch.tensor(target_marginal).to(self.device),\ 185 | # torch.tensor(source_marginal).to(self.device)) 186 | 187 | if source_marginal is not None: 188 | self.source_marginal = source_marginal 189 | 190 | # results["source_marginal"] = torch.tensor(source_marginal).to(self.device) 191 | 192 | return results 193 | 194 | else: 195 | x = move_to(x, self.device) 196 | y_true = y_true.to(self.device) 197 | 198 | y_pred = self.model(x) 199 | 200 | return { 201 | "y_true": y_true, 202 | "y_pred": y_pred, 203 | } 204 | 205 | def objective(self, results): 206 | if self.use_target_marginal: 207 | self.source_num_samples += len(results["y_pred"]) 208 | self.target_num_samples += len(results["target_y_pred"]) 209 | 210 | target_preds = F.softmax(results["target_y_pred"], dim=1) 211 | self.psuedo_marginal += ( 212 | torch.sum(target_preds, dim=0).detach().cpu().numpy() 213 | ) 214 | 215 | source_preds = F.softmax(results["y_pred"], dim=1) 216 | source_onehot = F.one_hot( 217 | results["y_true"], num_classes=self.num_classes 218 | ).float() 219 | 220 | self.cov += ( 221 | torch.mm(source_preds.transpose(1, 0), source_onehot) 222 | .detach() 223 | .cpu() 224 | .numpy() 225 | ) 226 | 227 | if self.batch_idx == 0: 228 | self.cov /= self.source_num_samples 229 | self.psuedo_marginal /= self.target_num_samples 230 | 231 | self.im_weights = im_weights_update( 232 | self.source_marginal, 233 | self.psuedo_marginal, 234 | self.cov, 235 
| self.im_weights, 236 | ) 237 | 238 | # import pdb; pdb.set_trace() 239 | 240 | self.cov = np.zeros((self.num_classes, self.num_classes)) 241 | self.source_marginal = np.zeros((self.num_classes)) 242 | self.psuedo_marginal = np.zeros((self.num_classes)) 243 | 244 | self.source_num_samples = 0 245 | self.target_num_samples = 0 246 | 247 | classification_loss = self.loss(results["y_pred"], results["y_true"]) 248 | 249 | im_weights = torch.tensor(self.im_weights).to(self.device) 250 | 251 | # if self.source_balanced: 252 | # classification_loss = torch.mean(classification_loss/results["source_marginal"][results["y_true"]]/ self.num_classes) 253 | 254 | if self.use_target_marginal: 255 | classification_loss = torch.mean( 256 | classification_loss * im_weights[results["y_true"]] 257 | ) 258 | 259 | if self.is_training: 260 | domain_classification_loss = self.domain_loss( 261 | results["domains_pred"], 262 | results["domains_true"], 263 | ) 264 | if self.use_target_marginal: 265 | source_size = len(results["y_true"]) 266 | domain_classification_loss_source = torch.mean( 267 | domain_classification_loss[:source_size] 268 | * im_weights[results["y_true"]] 269 | ) 270 | domain_classification_loss_target = torch.mean( 271 | domain_classification_loss[source_size:] 272 | ) 273 | domain_classification_loss = ( 274 | domain_classification_loss_source 275 | + domain_classification_loss_target 276 | ) / 2.0 277 | else: 278 | domain_classification_loss = torch.mean(domain_classification_loss) 279 | else: 280 | domain_classification_loss = 0.0 281 | 282 | return classification_loss + domain_classification_loss * self.penalty_weight 283 | -------------------------------------------------------------------------------- /RLSbench/algorithms/ERM.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn as nn 5 | from RLSbench.algorithms.single_model_algorithm import SingleModelAlgorithm 6 | from RLSbench.losses import initialize_loss 7 | from RLSbench.models.initializer import initialize_model 8 | from RLSbench.models.model_utils import linear_probe 9 | from RLSbench.utils import move_to 10 | 11 | logger = logging.getLogger("label_shift") 12 | 13 | 14 | class ERM(SingleModelAlgorithm): 15 | def __init__(self, config, dataloader, loss_function, n_train_steps): 16 | logger.info("Initializing model...") 17 | 18 | if config.algorithm.startswith("IW"): 19 | self.use_target_marginal = True 20 | else: 21 | self.use_target_marginal = False 22 | 23 | if config.source_balanced or self.use_target_marginal: 24 | loss = initialize_loss(loss_function, reduction="none") 25 | else: 26 | loss = initialize_loss(loss_function) 27 | 28 | model = initialize_model( 29 | model_name=config.model, 30 | dataset_name=config.dataset, 31 | num_classes=config.num_classes, 32 | featurize=True, 33 | pretrained=config.pretrained, 34 | pretrained_path=config.pretrained_path, 35 | data_dir=config.root_dir, 36 | ) 37 | 38 | if config.pretrained and "clip" in config.model: 39 | model = linear_probe( 40 | model, 41 | dataloader, 42 | device=config.device, 43 | progress_bar=config.progress_bar, 44 | ) 45 | 46 | model = nn.Sequential(*model) 47 | 48 | # initialize module 49 | super().__init__( 50 | config=config, 51 | model=model, 52 | loss=loss, 53 | n_train_steps=n_train_steps, 54 | ) 55 | 56 | self.use_unlabeled_y = ( 57 | config.use_unlabeled_y 58 | ) # Expect x,y,m from unlabeled loaders and train on the unlabeled y 59 | 60 | self.source_balanced = 
config.source_balanced 61 | self.num_classes = config.num_classes 62 | 63 | def process_batch( 64 | self, 65 | batch, 66 | unlabeled_batch=None, 67 | target_marginal=None, 68 | source_marginal=None, 69 | target_average=None, 70 | ): 71 | """ 72 | Overrides single_model_algorithm.process_batch(). 73 | ERM defines its own process_batch to handle if self.use_unlabeled_y is true. 74 | Args: 75 | - batch (tuple of Tensors): a batch of data yielded by data loaders 76 | - unlabeled_batch (tuple of Tensors or None): a batch of data yielded by unlabeled data loader 77 | Output: 78 | - results (dictionary): information about the batch 79 | - y_true (Tensor): ground truth labels for batch 80 | - y_pred (Tensor): model output for batch 81 | - unlabeled_y_pred (Tensor): predictions for unlabeled batch for fully-supervised ERM experiments 82 | - unlabeled_y_true (Tensor): true labels for unlabeled batch for fully-supervised ERM experiments 83 | """ 84 | x, y_true = batch[:2] 85 | # import pdb; pdb.set_trace() 86 | # print(x) 87 | # print(y_true) 88 | x = move_to(x, self.device) 89 | y_true = move_to(y_true, self.device) 90 | 91 | outputs = self.get_model_output(x) 92 | 93 | results = {"y_true": y_true, "y_pred": outputs} 94 | if unlabeled_batch is not None: 95 | if self.use_unlabeled_y: # expect loaders to return x,y,m 96 | x, y = unlabeled_batch[:2] 97 | y = move_to(y, self.device) 98 | x = move_to(x, self.device) 99 | results["unlabeled_y_pred"] = self.get_model_output(x) 100 | results["unlabeled_y_true"] = y 101 | 102 | if self.use_target_marginal and target_marginal is not None: 103 | results["im_weights"] = torch.divide( 104 | torch.tensor(target_marginal).to(self.device), 105 | torch.tensor(source_marginal).to(self.device), 106 | ) 107 | 108 | if self.source_balanced and source_marginal is not None: 109 | results["source_marginal"] = torch.tensor(source_marginal).to(self.device) 110 | 111 | return results 112 | 113 | def objective(self, results): 114 | labeled_loss = self.loss(results["y_pred"], results["y_true"]) 115 | 116 | if self.use_target_marginal: 117 | labeled_loss = torch.mean( 118 | labeled_loss * results["im_weights"][results["y_true"]] 119 | ) 120 | 121 | elif self.source_balanced: 122 | labeled_loss = torch.mean( 123 | labeled_loss 124 | / results["source_marginal"][results["y_true"]] 125 | / self.num_classes 126 | ) 127 | 128 | if self.use_unlabeled_y and "unlabeled_y_true" in results: 129 | unlabeled_loss = self.loss( 130 | results["unlabeled_y_pred"], 131 | results["unlabeled_y_true"], 132 | ) 133 | 134 | if self.use_target_marginal or self.source_balanced: 135 | unlabeled_loss = torch.mean(unlabeled_loss) 136 | 137 | return labeled_loss + unlabeled_loss 138 | 139 | else: 140 | return labeled_loss 141 | -------------------------------------------------------------------------------- /RLSbench/algorithms/ERM_Adv.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn as nn 5 | from RLSbench.algorithms.single_model_algorithm import SingleModelAlgorithm 6 | from RLSbench.losses import initialize_loss 7 | from RLSbench.models.initializer import initialize_model 8 | from RLSbench.models.model_utils import linear_probe 9 | from RLSbench.utils import move_to 10 | from robustness.attacker import AttackerModel 11 | 12 | _DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406] 13 | _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225] 14 | 15 | logger = logging.getLogger("label_shift") 
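
# ERM with adversarial training: for roughly the first 90% of epochs (see
# normal_train_epoch in __init__ below), labeled inputs are replaced with L2 PGD
# adversarial examples generated by AttackerModel; the remaining epochs train on
# clean inputs.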
16 | 17 | 18 | class ERM_Adv(SingleModelAlgorithm): 19 | def __init__(self, config, dataloader, loss_function, n_train_steps): 20 | logger.info("Initializing model...") 21 | 22 | if config.algorithm.startswith("IW"): 23 | self.use_target_marginal = True 24 | else: 25 | self.use_target_marginal = False 26 | 27 | if config.source_balanced or self.use_target_marginal: 28 | loss = initialize_loss(loss_function, reduction="none") 29 | else: 30 | loss = initialize_loss(loss_function) 31 | 32 | model = initialize_model( 33 | model_name=config.model, 34 | dataset_name=config.dataset, 35 | num_classes=config.num_classes, 36 | featurize=True, 37 | pretrained=config.pretrained, 38 | pretrained_path=config.pretrained_path, 39 | ) 40 | model = nn.Sequential(*model) 41 | 42 | model = AttackerModel( 43 | model, 44 | torch.tensor(_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN), 45 | torch.tensor(_DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD), 46 | ) 47 | 48 | # initialize module 49 | super().__init__( 50 | config=config, 51 | model=model, 52 | loss=loss, 53 | n_train_steps=n_train_steps, 54 | ) 55 | 56 | self.use_unlabeled_y = ( 57 | config.use_unlabeled_y 58 | ) # Expect x,y,m from unlabeled loaders and train on the unlabeled y 59 | 60 | self.source_balanced = config.source_balanced 61 | self.num_classes = config.num_classes 62 | self.n_epochs = config.n_epochs 63 | self.normal_train_epoch = int(self.n_epochs * 0.9) 64 | 65 | self.curr_epoch = 0 66 | 67 | self.attack_kwargs = { 68 | "constraint": "2", 69 | "eps": 0.5, 70 | "iterations": 3, 71 | "step_size": 1.0 / 3, 72 | "return_image": False, 73 | } 74 | 75 | def process_batch( 76 | self, 77 | batch, 78 | unlabeled_batch=None, 79 | target_marginal=None, 80 | source_marginal=None, 81 | target_average=None, 82 | ): 83 | """ 84 | Overrides single_model_algorithm.process_batch(). 85 | ERM defines its own process_batch to handle if self.use_unlabeled_y is true. 
86 | Args: 87 | - batch (tuple of Tensors): a batch of data yielded by data loaders 88 | - unlabeled_batch (tuple of Tensors or None): a batch of data yielded by unlabeled data loader 89 | Output: 90 | - results (dictionary): information about the batch 91 | - y_true (Tensor): ground truth labels for batch 92 | - y_pred (Tensor): model output for batch 93 | - unlabeled_y_pred (Tensor): predictions for unlabeled batch for fully-supervised ERM experiments 94 | - unlabeled_y_true (Tensor): true labels for unlabeled batch for fully-supervised ERM experiments 95 | """ 96 | if self.is_training and self.batch_idx == 0: 97 | self.curr_epoch += 1 98 | 99 | x, y_true = batch[:2] 100 | x = move_to(x, self.device) 101 | y_true = move_to(y_true, self.device) 102 | 103 | if self.is_training and self.curr_epoch < self.normal_train_epoch: 104 | outputs = self.model(x, y_true, make_adv=True, **self.attack_kwargs) 105 | else: 106 | outputs = self.model(x) 107 | 108 | # import pdb; pdb.set_trace() 109 | results = {"y_true": y_true, "y_pred": outputs} 110 | 111 | if self.use_target_marginal and target_marginal is not None: 112 | results["im_weights"] = torch.divide( 113 | torch.tensor(target_marginal).to(self.device), 114 | torch.tensor(source_marginal).to(self.device), 115 | ) 116 | 117 | if self.source_balanced and source_marginal is not None: 118 | results["source_marginal"] = torch.tensor(source_marginal).to(self.device) 119 | 120 | return results 121 | 122 | def objective(self, results): 123 | labeled_loss = self.loss(results["y_pred"], results["y_true"]) 124 | 125 | if self.use_target_marginal: 126 | labeled_loss = torch.mean( 127 | labeled_loss * results["im_weights"][results["y_true"]] 128 | ) 129 | 130 | elif self.source_balanced: 131 | labeled_loss = torch.mean( 132 | labeled_loss 133 | / results["source_marginal"][results["y_true"]] 134 | / self.num_classes 135 | ) 136 | 137 | if self.use_unlabeled_y and "unlabeled_y_true" in results: 138 | unlabeled_loss = self.loss( 139 | results["unlabeled_y_pred"], 140 | results["unlabeled_y_true"], 141 | ) 142 | 143 | if self.use_target_marginal or self.source_balanced: 144 | unlabeled_loss = torch.mean(unlabeled_loss) 145 | 146 | return labeled_loss + unlabeled_loss 147 | 148 | else: 149 | return labeled_loss 150 | -------------------------------------------------------------------------------- /RLSbench/algorithms/SENTRY.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from RLSbench.algorithms.single_model_algorithm import SingleModelAlgorithm 7 | from RLSbench.losses import initialize_loss 8 | from RLSbench.models.initializer import initialize_model 9 | from RLSbench.models.model_utils import linear_probe 10 | from RLSbench.utils import detach_and_clone, pseudolabel_multiclass_logits 11 | from RLSbench.optimizer import initialize_optimizer_with_model_params 12 | 13 | logger = logging.getLogger("label_shift") 14 | 15 | 16 | class SENTRY(SingleModelAlgorithm): 17 | """ 18 | Sentry: Selective Entropy Optimization via Committee Consistency for Unsupervised Domain Adaptation 19 | 20 | Original paper: 21 | @inproceedings{prabhu2021sentry, 22 | title={Sentry: Selective entropy optimization via committee consistency for unsupervised domain adaptation}, 23 | author={Prabhu, Viraj and Khare, Shivam and Kartik, Deeksha and Hoffman, Judy}, 24 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 
25 |             pages={8558--8567},
26 |             year={2021}
27 |         }
28 |     """
29 | 
30 |     def __init__(self, config, dataloader, loss_function, n_train_steps, **kwargs):
31 |         logger.info("Initializing SENTRY algorithm model")
32 | 
33 |         model = initialize_model(
34 |             model_name=config.model,
35 |             dataset_name=config.dataset,
36 |             num_classes=config.num_classes,
37 |             featurize=True,
38 |             pretrained=config.pretrained,
39 |             pretrained_path=config.pretrained_path,
40 |         )
41 | 
42 |         # if config.algorithm.startswith("IW"):
43 |         #     self.use_target_marginal = True
44 |         # else:
45 |         self.use_target_marginal = False
46 | 
47 |         # if config.source_balanced or self.use_target_marginal:
48 |         #     loss = initialize_loss(loss_function, reduction='none')
49 |         # else:
50 |         loss = initialize_loss(loss_function)
51 | 
52 |         # if config.pretrained:
53 |         #     model = linear_probe(model, dataloader, device= config.device, progress_bar=config.progress_bar)
54 | 
55 |         params = [
56 |             {"params": model[0].parameters(), "lr": config.lr * 0.1},
57 |             # {"params": self.bottleneck.parameters(), "lr": classifier_lr},
58 |             {"params": model[1].parameters(), "lr": config.lr},
59 |         ]
60 | 
61 |         model = nn.Sequential(*model)
62 | 
63 |         # self.optimizer =
64 |         self.optimizer = initialize_optimizer_with_model_params(config, params)
65 | 
66 |         # initialize module
67 |         super().__init__(
68 |             config=config,
69 |             model=model,
70 |             loss=loss,
71 |             n_train_steps=n_train_steps,
72 |         )
73 |         # algorithm hyperparameters
74 |         self.lambda_src = kwargs["lambda_src"]
75 |         self.lambda_unsup = kwargs["lambda_unsup"]
76 |         self.lambda_ent = kwargs["lambda_ent"]
77 | 
78 |         self.source_balanced = config.source_balanced
79 |         self.num_classes = config.num_classes
80 | 
81 |     def process_batch(
82 |         self,
83 |         batch,
84 |         unlabeled_batch=None,
85 |         target_marginal=None,
86 |         source_marginal=None,
87 |         target_average=None,
88 |     ):
89 |         """
90 |         Overrides single_model_algorithm.process_batch().
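        Each unlabeled example is pseudolabeled from its weakly augmented view; a
        committee of strongly augmented views then votes, and examples where at
        least two views agree (respectively disagree) with the pseudolabel are
        marked consistent (respectively inconsistent) for selective entropy
        optimization.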
91 | Args: 92 | - batch (x, y, m): a batch of data yielded by data loaders 93 | - unlabeled_batch: examples ((x_weak, x_strong), m) where x_weak is weakly augmented but x_strong is strongly augmented 94 | Output: 95 | - results (dictionary): information about the batch 96 | - y_true (Tensor): ground truth labels for batch 97 | - y_pred (Tensor): model output for batch 98 | - unlabeled_weak_y_pseudo (Tensor): pseudolabels on x_weak of the unlabeled batch, already thresholded 99 | - unlabeled_strong_y_pred (Tensor): model output on x_strong of the unlabeled batch, already thresholded 100 | """ 101 | # Labeled examples 102 | x, y_true = batch[:2] 103 | x = x.to(self.device) 104 | y_true = y_true.to(self.device) 105 | 106 | y_pred = self.model(x) 107 | # package the results 108 | results = { 109 | "y_true": y_true, 110 | "y_pred": y_pred, 111 | } 112 | 113 | # Unlabeled examples 114 | if unlabeled_batch is not None: 115 | x = unlabeled_batch[0] 116 | x_weak, x_strong_augs = x[0], x[1:] 117 | x_weak = x_weak.to(self.device) 118 | 119 | y_weak_pred = self.model(x_weak) 120 | 121 | y_weak_pseudo_label = y_weak_pred.max(dim=1)[1].detach().reshape(-1) 122 | 123 | correct_mask, incorrect_mask = torch.zeros_like(y_weak_pseudo_label).to( 124 | self.device 125 | ), torch.zeros_like(y_weak_pseudo_label).to(self.device) 126 | score_t_aug_pos, score_t_aug_neg = torch.zeros_like( 127 | y_weak_pred 128 | ), torch.zeros_like(y_weak_pred) 129 | 130 | for x_strong in x_strong_augs: 131 | y_strong_pred = self.model(x_strong.to(self.device)) 132 | y_strong_pseudo_label = y_strong_pred.max(dim=1)[1].reshape(-1) 133 | 134 | consistent_idxs = ( 135 | y_weak_pseudo_label == y_strong_pseudo_label 136 | ).detach() 137 | inconsistent_idxs = ( 138 | y_weak_pseudo_label != y_strong_pseudo_label 139 | ).detach() 140 | 141 | correct_mask = correct_mask + consistent_idxs.type(torch.uint8) 142 | incorrect_mask = incorrect_mask + inconsistent_idxs.type(torch.uint8) 143 | 144 | score_t_aug_pos[consistent_idxs, :] = y_strong_pred[consistent_idxs, :] 145 | score_t_aug_neg[inconsistent_idxs, :] = y_strong_pred[ 146 | inconsistent_idxs, : 147 | ] 148 | 149 | correct_mask, incorrect_mask = correct_mask >= 2, incorrect_mask >= 2 150 | 151 | correct_ratio = (correct_mask).sum().item() / x_weak.shape[0] 152 | incorrect_ratio = (incorrect_mask).sum().item() / x_weak.shape[0] 153 | 154 | results["correct_ratio"] = correct_ratio 155 | results["incorrect_ratio"] = incorrect_ratio 156 | 157 | results["y_weak_pred"] = y_weak_pred 158 | results["y_weak_pseudo_label"] = y_weak_pseudo_label 159 | 160 | results["score_t_aug_pos"] = score_t_aug_pos[correct_mask] 161 | results["score_t_aug_neg"] = score_t_aug_neg[incorrect_mask] 162 | 163 | if target_average is not None: 164 | # import pdb; pdb.set_trace() 165 | results["target_average"] = torch.tensor(target_average).to(self.device) 166 | 167 | # if self.use_target_marginal and target_marginal is not None: 168 | # results['im_weights'] = torch.divide(torch.tensor(target_marginal).to(self.device),\ 169 | # torch.tensor(source_marginal).to(self.device)) 170 | # results['target_marginal'] = torch.tensor(target_marginal).to(self.device) 171 | 172 | # import pdb; pdb.set_trace() 173 | # if self.source_balanced and source_marginal is not None: 174 | # results['source_marginal'] = torch.tensor(source_marginal).to(self.device) 175 | 176 | return results 177 | 178 | def objective(self, results): 179 | # Labeled loss 180 | classification_loss = self.loss(results["y_pred"], results["y_true"]) 181 | 182 | # 
if self.use_target_marginal: 183 | # classification_loss = self.lambda_src * torch.mean(classification_loss*results["im_weights"][results["y_true"]]) 184 | 185 | # loss_cent = 0.0 186 | 187 | # if results["correct_ratio"] > 0.0: 188 | # probs_t_pos = F.softmax(results["score_t_aug_pos"], dim=1) 189 | # loss_cent_correct = self.lambda_ent * -torch.mean(torch.sum(probs_t_pos * (torch.log(probs_t_pos + 1e-12)), 1)) 190 | # loss_cent += loss_cent_correct* results["correct_ratio"] 191 | 192 | # if results["incorrect_ratio"] > 0.0: 193 | # probs_t_neg = F.softmax(results["score_t_aug_neg"], dim=1) 194 | # loss_cent_incorrect = self.lambda_ent * torch.mean(torch.sum(probs_t_neg * (torch.log(probs_t_neg + 1e-12)), 1)) 195 | # loss_cent += loss_cent_incorrect* results["incorrect_ratio"] 196 | 197 | # elif self.source_balanced: 198 | # classification_loss = torch.mean(classification_loss/results["source_marginal"][results["y_true"]]/ self.num_classes) 199 | # loss_cent = 0.0 200 | 201 | # target_marginal = results["target_marginal"] 202 | # target_marginal[target_marginal == 0] = 1.0 203 | 204 | # if results["correct_ratio"] > 0.0: 205 | # probs_t_pos = F.softmax(results["score_t_aug_pos"], dim=1) 206 | # loss_cent_correct = torch.sum(probs_t_pos * (torch.log(probs_t_pos + 1e-12)), 1) 207 | 208 | # loss_cent_correct = -torch.mean(loss_cent_correct* 1.0/target_marginal[results["y_weak_pseudo_label"]]/ self.num_classes) 209 | 210 | # loss_cent += loss_cent_correct* results["correct_ratio"] 211 | 212 | # if results["incorrect_ratio"] > 0.0: 213 | # probs_t_neg = F.softmax(results["score_t_aug_neg"], dim=1) 214 | # loss_cent_incorrect = torch.sum(probs_t_neg * (torch.log(probs_t_neg + 1e-12)), 1) 215 | 216 | # loss_cent_incorrect = torch.mean(loss_cent_correct* 1.0/target_marginal[results["y_weak_pseudo_label"]]/ self.num_classes) 217 | 218 | # loss_cent += loss_cent_incorrect* results["incorrect_ratio"] 219 | 220 | # else: 221 | classification_loss = self.lambda_src * classification_loss 222 | 223 | loss_cent = 0.0 224 | 225 | if results["correct_ratio"] > 0.0: 226 | probs_t_pos = F.softmax(results["score_t_aug_pos"], dim=1) 227 | loss_cent_correct = self.lambda_ent * -torch.mean( 228 | torch.sum(probs_t_pos * (torch.log(probs_t_pos + 1e-12)), 1) 229 | ) 230 | loss_cent += loss_cent_correct * results["correct_ratio"] 231 | 232 | if results["incorrect_ratio"] > 0.0: 233 | probs_t_neg = F.softmax(results["score_t_aug_neg"], dim=1) 234 | loss_cent_incorrect = self.lambda_ent * torch.mean( 235 | torch.sum(probs_t_neg * (torch.log(probs_t_neg + 1e-12)), 1) 236 | ) 237 | loss_cent += loss_cent_incorrect * results["incorrect_ratio"] 238 | 239 | if "target_average" in results: 240 | loss_infoent = self.lambda_unsup * torch.mean( 241 | torch.sum( 242 | F.softmax(results["y_weak_pred"], dim=1) 243 | * torch.log(results["target_average"] + 1e-12).reshape( 244 | 1, self.num_classes 245 | ), 246 | dim=1, 247 | ) 248 | ) 249 | 250 | else: 251 | loss_infoent = 0.0 252 | 253 | return classification_loss + loss_cent + loss_infoent 254 | -------------------------------------------------------------------------------- /RLSbench/algorithms/TENT.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from RLSbench.algorithms.algorithm import Algorithm 5 | from RLSbench.models.initializer import initialize_model 6 | from RLSbench.models.model_utils import collect_params, configure_model 7 | from RLSbench.optimizer import 
initialize_optimizer_with_model_params 8 | from RLSbench.utils import load, move_to 9 | from torch.optim import SGD, Adam 10 | 11 | logger = logging.getLogger("label_shift") 12 | 13 | 14 | def softmax_entropy(x: torch.Tensor) -> torch.Tensor: 15 | """Entropy of softmax distribution from logits.""" 16 | return -(x.softmax(1) * x.log_softmax(1)).sum(1) 17 | 18 | 19 | class TENT(Algorithm): 20 | def __init__(self, config): 21 | logger.info("Initializing model...") 22 | 23 | model = initialize_model( 24 | model_name=config.model, 25 | dataset_name=config.dataset, 26 | num_classes=config.num_classes, 27 | featurize=False, 28 | pretrained=False, 29 | ) 30 | 31 | model.to(config.device) 32 | 33 | # initialize module 34 | super().__init__( 35 | device=config.device, 36 | ) 37 | self.model = model 38 | 39 | self.source_balanced = config.source_balanced 40 | self.num_classes = config.num_classes 41 | self.config = config 42 | 43 | def get_model_output(self, x): 44 | outputs = self.model(x) 45 | return outputs 46 | 47 | def process_batch(self, batch): 48 | """ 49 | A helper function for update() and evaluate() that processes the batch 50 | Args: 51 | - batch (tuple of Tensors): a batch of data yielded by data loaders 52 | Output: 53 | - results (dictionary): information about the batch 54 | - y_true (Tensor): ground truth labels for batch 55 | - y_pred (Tensor): model output for batch 56 | """ 57 | x, y_true = batch[:2] 58 | x = move_to(x, self.device) 59 | y_true = move_to(y_true, self.device) 60 | 61 | outputs = self.get_model_output(x) 62 | 63 | results = { 64 | "y_true": y_true, 65 | "y_pred": outputs, 66 | } 67 | return results 68 | 69 | def evaluate(self, batch): 70 | """ 71 | Process the batch and update the log, without updating the model 72 | Args: 73 | - batch (tuple of Tensors): a batch of data yielded by data loaders 74 | Output: 75 | - results (dictionary): information about the batch, such as: 76 | - y_true (Tensor) 77 | - outputs (Tensor) 78 | - y_pred (Tensor) 79 | """ 80 | assert not self.is_training 81 | results = self.process_batch(batch) 82 | return results 83 | 84 | def adapt( 85 | self, 86 | source_loader, 87 | target_loader, 88 | target_marginal=None, 89 | source_marginal=None, 90 | target_average=None, 91 | pretrained_path=None, 92 | ): 93 | """ 94 | Load the model and adapt it to the new data 95 | Args: 96 | - unlabeled_batch (tuple of Tensors): a batch of data yielded by unlabeled data loader 97 | - target_marginal (Tensor): the marginal distribution of the target 98 | - source_marginal (Tensor): the marginal distribution of the source 99 | - target_average (Tensor): the average of the target 100 | 101 | Output: 102 | """ 103 | 104 | if pretrained_path is not None: 105 | logger.info(f"Loading pretrained model from {pretrained_path}") 106 | load(self.model, pretrained_path, device=self.device) 107 | 108 | # TODO: Check what if we adapt to the BN params here 109 | # logger.info("Adapting BN params...") 110 | 111 | # self.train(True) 112 | 113 | # with torch.no_grad(): 114 | # for batch in target_loader: 115 | # self.model(batch[0].to(self.device)) 116 | 117 | self.model = configure_model(self.model) 118 | params, param_names = collect_params(self.model) 119 | 120 | self.optimizer = SGD(params, lr=1e-4, momentum=0.9) 121 | 122 | logger.info("Adapting model to TENT ...") 123 | 124 | # for epoch in range(5): 125 | for batch in target_loader: 126 | self.optimizer.zero_grad() 127 | outputs = self.model(batch[0].to(self.device)) 128 | 129 | loss = softmax_entropy(outputs).mean(0) 
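            # TENT update: one gradient step per target batch that minimizes the
            # mean prediction entropy; only the parameters gathered by
            # collect_params() above (the BatchNorm affine parameters in the
            # standard TENT recipe) receive updates.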
130 | 131 | loss.backward() 132 | self.optimizer.step() 133 | 134 | self.optimizer.zero_grad() 135 | 136 | def reset(self): 137 | pass 138 | -------------------------------------------------------------------------------- /RLSbench/algorithms/algorithm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from RLSbench.utils import detach_and_clone, move_to 3 | 4 | 5 | class Algorithm(nn.Module): 6 | def __init__( 7 | self, 8 | device, 9 | ): 10 | super().__init__() 11 | self.device = device 12 | self.out_device = "cpu" 13 | 14 | def update(self, batch): 15 | """ 16 | Process the batch, update the log, and update the model 17 | Args: 18 | - batch (tuple of Tensors): a batch of data yielded by data loaders 19 | Output: 20 | - results (dictionary): information about the batch, such as: 21 | - y_true (Tensor) 22 | - loss (Tensor) 23 | - metrics (Tensor) 24 | """ 25 | raise NotImplementedError 26 | 27 | def evaluate(self, batch): 28 | """ 29 | Process the batch and update the log, without updating the model 30 | Args: 31 | - batch (tuple of Tensors): a batch of data yielded by data loaders 32 | Output: 33 | - results (dictionary): information about the batch, such as: 34 | - y_true (Tensor) 35 | - loss (Tensor) 36 | - metrics (Tensor) 37 | """ 38 | raise NotImplementedError 39 | 40 | def train(self, mode=True): 41 | """ 42 | Switch to train mode 43 | """ 44 | self.is_training = mode 45 | super().train(mode) 46 | 47 | def step_schedulers(self, is_epoch): 48 | """ 49 | Update all relevant schedulers 50 | Args: 51 | - is_epoch (bool): epoch-wise update if set to True, batch-wise update otherwise 52 | - metrics (dict): a dictionary of metrics that can be used for scheduler updates 53 | - log_access (bool): whether metrics from self.get_log() can be used to update schedulers 54 | """ 55 | raise NotImplementedError 56 | -------------------------------------------------------------------------------- /RLSbench/algorithms/fixmatch.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from RLSbench.algorithms.single_model_algorithm import SingleModelAlgorithm 7 | from RLSbench.losses import initialize_loss 8 | from RLSbench.models.initializer import initialize_model 9 | from RLSbench.models.model_utils import linear_probe 10 | from RLSbench.utils import detach_and_clone, pseudolabel_multiclass_logits 11 | 12 | logger = logging.getLogger("label_shift") 13 | 14 | 15 | class FixMatch(SingleModelAlgorithm): 16 | """ 17 | FixMatch. 18 | This algorithm was originally proposed as a semi-supervised learning algorithm. 
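Here it is repurposed for unsupervised domain adaptation: the unlabeled target split plays the role of the unlabeled data.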
19 | 20 | Loss is of the form 21 | \ell_s + \lambda * \ell_u 22 | where 23 | \ell_s = cross-entropy with true labels using weakly augmented labeled examples 24 | \ell_u = cross-entropy with pseudolabel generated using weak augmentation and prediction 25 | using strong augmentation 26 | 27 | Original paper: 28 | @article{sohn2020fixmatch, 29 | title={Fixmatch: Simplifying semi-supervised learning with consistency and confidence}, 30 | author={Sohn, Kihyuk and Berthelot, David and Li, Chun-Liang and Zhang, Zizhao and Carlini, Nicholas and Cubuk, Ekin D and Kurakin, Alex and Zhang, Han and Raffel, Colin}, 31 | journal={arXiv preprint arXiv:2001.07685}, 32 | year={2020} 33 | } 34 | """ 35 | 36 | def __init__(self, config, dataloader, loss_function, n_train_steps, **kwargs): 37 | logger.info("Initializing FixMatch algorithm model") 38 | 39 | model = initialize_model( 40 | model_name=config.model, 41 | dataset_name=config.dataset, 42 | num_classes=config.num_classes, 43 | featurize=True, 44 | pretrained=config.pretrained, 45 | pretrained_path=config.pretrained_path, 46 | ) 47 | 48 | # if config.algorithm.startswith("IW"): 49 | # # self.use_target_marginal = True 50 | # self.use_target_marginal = False 51 | # else: 52 | self.use_target_marginal = False 53 | 54 | # if config.source_balanced or self.use_target_marginal: 55 | # loss = initialize_loss(loss_function, reduction='none') 56 | # else: 57 | loss = initialize_loss(loss_function) 58 | 59 | # if config.pretrained: 60 | # model = linear_probe(model, dataloader, device= config.device, progress_bar=config.progress_bar) 61 | 62 | model = nn.Sequential(*model) 63 | 64 | # initialize module 65 | super().__init__( 66 | config=config, 67 | model=model, 68 | loss=loss, 69 | n_train_steps=n_train_steps, 70 | ) 71 | # algorithm hyperparameters 72 | self.fixmatch_lambda = kwargs["self_training_lambda"] 73 | self.target_align = kwargs["target_align"] 74 | self.confidence_threshold = kwargs["self_training_threshold"] 75 | self.process_pseudolabels_function = pseudolabel_multiclass_logits 76 | 77 | self.source_balanced = config.source_balanced 78 | self.num_classes = config.num_classes 79 | 80 | def process_batch( 81 | self, 82 | batch, 83 | unlabeled_batch=None, 84 | target_marginal=None, 85 | source_marginal=None, 86 | target_average=None, 87 | ): 88 | """ 89 | Overrides single_model_algorithm.process_batch(). 
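The labeled batch and the strongly augmented unlabeled examples share one concatenated forward pass; pseudolabels come from a separate no-grad pass on the weakly augmented views (see the body below).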
90 | Args: 91 | - batch (x, y, m): a batch of data yielded by data loaders 92 | - unlabeled_batch: examples ((x_weak, x_strong), m) where x_weak is weakly augmented but x_strong is strongly augmented 93 | Output: 94 | - results (dictionary): information about the batch 95 | - y_true (Tensor): ground truth labels for batch 96 | - y_pred (Tensor): model output for batch 97 | - unlabeled_weak_y_pseudo (Tensor): pseudolabels on x_weak of the unlabeled batch, already thresholded 98 | - unlabeled_strong_y_pred (Tensor): model output on x_strong of the unlabeled batch, already thresholded 99 | """ 100 | # Labeled examples 101 | x, y_true = batch[:2] 102 | x = x.to(self.device) 103 | y_true = y_true.to(self.device) 104 | 105 | # package the results 106 | results = { 107 | "y_true": y_true, 108 | } 109 | 110 | # if self.source_balanced and source_marginal is not None: 111 | # results['source_marginal'] = torch.tensor(source_marginal).to(self.device) 112 | 113 | # if self.use_target_marginal and target_marginal is not None: 114 | # results['im_weights'] = torch.divide(torch.tensor(target_marginal).to(self.device),\ 115 | # torch.tensor(source_marginal).to(self.device)) 116 | 117 | pseudolabels_kept_frac = 0 118 | 119 | # Unlabeled examples 120 | if unlabeled_batch is not None: 121 | if self.target_align and target_average is not None: 122 | alignment_dist = torch.divide( 123 | torch.tensor(target_marginal).to(self.device), 124 | torch.tensor(target_average).to(self.device), 125 | ) 126 | 127 | (x_weak, x_strong) = unlabeled_batch[0] 128 | x_weak = x_weak.to(self.device) 129 | x_strong = x_strong.to(self.device) 130 | 131 | with torch.no_grad(): 132 | outputs = self.model(x_weak) 133 | 134 | if self.target_align and target_average is not None: 135 | ( 136 | _, 137 | pseudolabels, 138 | pseudolabels_kept_frac, 139 | mask, 140 | ) = self.process_pseudolabels_function( 141 | outputs, self.confidence_threshold, alignment_dist 142 | ) 143 | else: 144 | ( 145 | _, 146 | pseudolabels, 147 | pseudolabels_kept_frac, 148 | mask, 149 | ) = self.process_pseudolabels_function( 150 | outputs, self.confidence_threshold 151 | ) 152 | 153 | results["unlabeled_weak_y_pseudo"] = detach_and_clone(pseudolabels) 154 | 155 | results["pseudolabels_kept_frac"] = pseudolabels_kept_frac 156 | 157 | # Concat and call forward 158 | n_lab = x.shape[0] 159 | if unlabeled_batch is not None: 160 | x_concat = torch.cat((x, x_strong), dim=0) 161 | else: 162 | x_concat = x 163 | 164 | outputs = self.model(x_concat) 165 | results["y_pred"] = outputs[:n_lab] 166 | if unlabeled_batch is not None: 167 | results["unlabeled_strong_y_pred"] = ( 168 | outputs[n_lab:] if mask is None else outputs[n_lab:][mask] 169 | ) 170 | 171 | return results 172 | 173 | def objective(self, results): 174 | # Labeled loss 175 | classification_loss = self.loss(results["y_pred"], results["y_true"]) 176 | 177 | # if self.use_target_marginal: 178 | # classification_loss = torch.mean(classification_loss*results["im_weights"][results["y_true"]]) 179 | 180 | # elif self.source_balanced: 181 | # classification_loss = torch.mean(classification_loss/results["source_marginal"][results["y_true"]]/ self.num_classes) 182 | 183 | # Pseudolabeled loss 184 | if "unlabeled_weak_y_pseudo" in results: 185 | loss_output = self.loss( 186 | results["unlabeled_strong_y_pred"], 187 | results["unlabeled_weak_y_pseudo"], 188 | ) 189 | consistency_loss = loss_output * results["pseudolabels_kept_frac"] 190 | else: 191 | consistency_loss = 0 192 | 193 | return classification_loss + 
self.fixmatch_lambda * consistency_loss 194 | -------------------------------------------------------------------------------- /RLSbench/algorithms/initializer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from RLSbench.algorithms.BN_adapt import BN_adapt 4 | from RLSbench.algorithms.BN_adapt_adv import BN_adapt_adv 5 | from RLSbench.algorithms.CDAN import CDAN 6 | from RLSbench.algorithms.COAL import COAL 7 | from RLSbench.algorithms.CORAL import CORAL 8 | from RLSbench.algorithms.DANN import DANN 9 | from RLSbench.algorithms.ERM import ERM 10 | from RLSbench.algorithms.ERM_Adv import ERM_Adv 11 | from RLSbench.algorithms.fixmatch import FixMatch 12 | from RLSbench.algorithms.noisy_student import NoisyStudent 13 | from RLSbench.algorithms.pseudolabel import PseudoLabel 14 | from RLSbench.algorithms.SENTRY import SENTRY 15 | from RLSbench.algorithms.TENT import TENT 16 | 17 | logger = logging.getLogger("label_shift") 18 | 19 | 20 | def initialize_algorithm(config, datasets, dataloader): 21 | logger.info(f"Initializing algorithm {config.algorithm} ...") 22 | 23 | source_dataset = datasets["source_train"] 24 | trainloader_source = dataloader["source_train"] 25 | 26 | # Other config 27 | n_train_steps = ( 28 | len(trainloader_source) * config.n_epochs // config.gradient_accumulation_steps 29 | ) 30 | 31 | if config.algorithm in ( 32 | "ERM-rand", 33 | "ERM-imagenet", 34 | "ERM-clip", 35 | "ERM-bert", 36 | "ERM-aug-rand", 37 | "ERM-aug-imagenet", 38 | "ERM-swav", 39 | "ERM-oracle-rand", 40 | "ERM-oracle-imagenet", 41 | "IS-ERM-rand", 42 | "IS-ERM-imagenet", 43 | "IS-ERM-clip", 44 | "IS-ERM-aug-rand", 45 | "IS-ERM-aug-imagenet", 46 | "IS-ERM-swav", 47 | "IS-ERM-oracle-rand", 48 | "IS-ERM-oracle-imagenet", 49 | ): 50 | algorithm = ERM( 51 | config=config, 52 | dataloader=trainloader_source, 53 | loss_function=config.loss_function, 54 | n_train_steps=n_train_steps, 55 | ) 56 | 57 | elif config.algorithm in ("ERM-adv"): 58 | algorithm = ERM_Adv( 59 | config=config, 60 | dataloader=trainloader_source, 61 | loss_function=config.loss_function, 62 | n_train_steps=n_train_steps, 63 | ) 64 | 65 | elif config.algorithm in ("DANN", "IW-DANN", "IS-DANN"): 66 | algorithm = DANN( 67 | config=config, 68 | dataloader=trainloader_source, 69 | loss_function=config.loss_function, 70 | n_train_steps=n_train_steps, 71 | n_domains=2, 72 | **config.dann_kwargs, 73 | ) 74 | 75 | elif config.algorithm in ("CDANN", "IW-CDANN", "IS-CDANN"): 76 | algorithm = CDAN( 77 | config=config, 78 | dataloader=trainloader_source, 79 | loss_function=config.loss_function, 80 | n_train_steps=n_train_steps, 81 | n_domains=2, 82 | **config.cdan_kwargs, 83 | ) 84 | 85 | elif config.algorithm in ("FixMatch", "IS-FixMatch"): 86 | algorithm = FixMatch( 87 | config=config, 88 | dataloader=trainloader_source, 89 | loss_function=config.loss_function, 90 | n_train_steps=n_train_steps, 91 | **config.fixmatch_kwargs, 92 | ) 93 | 94 | elif config.algorithm in ("PseudoLabel", "IS-PseudoLabel"): 95 | algorithm = PseudoLabel( 96 | config=config, 97 | dataloader=trainloader_source, 98 | loss_function=config.loss_function, 99 | n_train_steps=n_train_steps, 100 | **config.pseudolabel_kwargs, 101 | ) 102 | 103 | elif config.algorithm in ("NoisyStudent", "IS-NoisyStudent"): 104 | algorithm = NoisyStudent( 105 | config=config, 106 | dataloader=trainloader_source, 107 | loss_function=config.loss_function, 108 | n_train_steps=n_train_steps, 109 | **config.noisystudent_kwargs, 110 | ) 111 | 112 
| elif config.algorithm in ("COAL", "IW-COAL"): 113 | algorithm = COAL( 114 | config=config, 115 | dataloader=trainloader_source, 116 | loss_function=config.loss_function, 117 | n_train_steps=n_train_steps, 118 | **config.coal_kwargs, 119 | ) 120 | 121 | elif config.algorithm in ("SENTRY", "IW-SENTRY"): 122 | algorithm = SENTRY( 123 | config=config, 124 | dataloader=trainloader_source, 125 | loss_function=config.loss_function, 126 | n_train_steps=n_train_steps, 127 | **config.sentry_kwargs, 128 | ) 129 | 130 | elif config.algorithm in ("CORAL", "IS-CORAL"): 131 | algorithm = CORAL(config=config) 132 | 133 | elif config.algorithm in ("BN_adapt", "IS-BN_adapt"): 134 | algorithm = BN_adapt(config=config) 135 | 136 | elif config.algorithm in ("BN_adapt-adv", "IS-BN_adapt-adv"): 137 | algorithm = BN_adapt_adv(config=config) 138 | 139 | elif config.algorithm in ("TENT", "IS-TENT"): 140 | algorithm = TENT(config=config) 141 | 142 | else: 143 | raise ValueError(f"Algorithm {config.algorithm} not recognized") 144 | 145 | return algorithm 146 | -------------------------------------------------------------------------------- /RLSbench/algorithms/noisy_student.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn as nn 5 | from RLSbench.algorithms.single_model_algorithm import SingleModelAlgorithm 6 | from RLSbench.losses import initialize_loss 7 | from RLSbench.models.initializer import initialize_model 8 | from RLSbench.models.model_utils import linear_probe 9 | from RLSbench.utils import concat_input, move_to 10 | 11 | logger = logging.getLogger("label_shift") 12 | 13 | 14 | class DropoutModel(nn.Module): 15 | def __init__(self, featurizer, classifier, dropout_rate): 16 | super().__init__() 17 | self.featurizer = featurizer 18 | self.dropout = nn.Dropout(p=dropout_rate) 19 | self.classifier = classifier 20 | 21 | def forward(self, x): 22 | features = self.featurizer(x) 23 | features_sparse = self.dropout(features) 24 | return self.classifier(features_sparse) 25 | 26 | 27 | class NoisyStudent(SingleModelAlgorithm): 28 | """ 29 | Noisy Student. 30 | This algorithm was originally proposed as a semi-supervised learning algorithm. 31 | 32 | One run of this codebase gives us one iteration (load a teacher, train student). To run another iteration, 33 | re-run the previous command, pointing config.teacher_model_path to the trained student weights. 34 | 35 | To warm start the student model, point config.pretrained_model_path to config.teacher_model_path 36 | 37 | Based on the original paper, loss is of the form 38 | \ell_s + \ell_u 39 | where 40 | \ell_s = cross-entropy with true labels; student predicts with noise 41 | \ell_u = cross-entropy with pseudolabel generated without noise; student predicts with noise 42 | The student is noised using: 43 | - Input images are augmented using RandAugment 44 | - Single dropout layer before final classifier (fc) layer 45 | We do not use stochastic depth. 46 | 47 | Pseudolabels are generated in the main training script (run_main.py in this codebase) on unlabeled images that have only been randomly cropped and flipped ("weak" transform). 48 | By default, we use hard pseudolabels; use the --soft_pseudolabels flag to generate soft pseudolabels. 49 | 50 | This code only supports a teacher that is the same class as the student (e.g. 
both densenet121s) 51 | 52 | Original paper: 53 | @inproceedings{xie2020self, 54 | title={Self-training with noisy student improves imagenet classification}, 55 | author={Xie, Qizhe and Luong, Minh-Thang and Hovy, Eduard and Le, Quoc V}, 56 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 57 | pages={10687--10698}, 58 | year={2020} 59 | } 60 | """ 61 | 62 | def __init__(self, config, dataloader, loss_function, n_train_steps, **kwargs): 63 | logger.info("Initializing Noisy Student algorithm model") 64 | 65 | # initialize student model with dropout before last layer 66 | if kwargs["noisystudent_add_dropout"]: 67 | featurizer, classifier = initialize_model( 68 | model_name=config.model, 69 | dataset_name=config.dataset, 70 | num_classes=config.num_classes, 71 | featurize=True, 72 | pretrained=config.pretrained, 73 | pretrained_path=config.pretrained_path, 74 | ) 75 | 76 | model = (featurizer, classifier) 77 | 78 | # if config.pretrained : 79 | # featurizer, classifier = linear_probe(model, dataloader, device= config.device, progress_bar=config.progress_bar) 80 | 81 | student_model = DropoutModel( 82 | featurizer, classifier, kwargs["noisystudent_dropout_rate"] 83 | ) 84 | 85 | else: 86 | featurizer, classifier = initialize_model( 87 | model_name=config.model, 88 | dataset_name=config.dataset, 89 | num_classes=config.num_classes, 90 | featurize=True, 91 | pretrained=config.pretrained, 92 | pretrained_path=config.pretrained_path, 93 | ) 94 | model = (featurizer, classifier) 95 | 96 | # if config.pretrained and config.featurize: 97 | # model = linear_probe(model, dataloader, device= config.device) 98 | 99 | student_model = nn.Sequential(*model) 100 | 101 | # if config.algorithm.startswith("IW"): 102 | # self.use_target_marginal = True 103 | # else: 104 | self.use_target_marginal = False 105 | 106 | # if config.source_balanced or self.use_target_marginal: 107 | # loss = initialize_loss(loss_function, reduction='none') 108 | # else: 109 | loss = initialize_loss(loss_function) 110 | 111 | # initialize module 112 | super().__init__( 113 | config=config, 114 | model=student_model, 115 | loss=loss, 116 | n_train_steps=n_train_steps, 117 | ) 118 | 119 | self.source_balanced = config.source_balanced 120 | self.num_classes = config.num_classes 121 | 122 | def process_batch( 123 | self, 124 | batch, 125 | unlabeled_batch=None, 126 | target_marginal=None, 127 | source_marginal=None, 128 | target_average=None, 129 | ): 130 | """ 131 | Overrides single_model_algorithm.process_batch(). 
132 | Args: 133 | - batch (x, y, m): a batch of data yielded by data loaders 134 | - unlabeled_batch: examples (x, y_pseudo, m) where y_pseudo is an already-computed teacher pseudolabel 135 | Output: 136 | - results (dictionary): information about the batch 137 | - y_true (Tensor): ground truth labels for batch 138 | - y_pred (Tensor): model output for batch 139 | - unlabeled_y_pseudo (Tensor): pseudolabels for unlabeled batch (from loader) 140 | - unlabeled_y_pred (Tensor): model output on unlabeled batch 141 | """ 142 | # Labeled examples 143 | x, y_true = batch[:2] 144 | n_lab = len(y_true) 145 | x = move_to(x, self.device) 146 | y_true = move_to(y_true, self.device) 147 | 148 | # package the results 149 | results = {"y_true": y_true} 150 | 151 | # Unlabeled examples with pseudolabels 152 | if unlabeled_batch is not None: 153 | x_unlab, y_pseudo = unlabeled_batch[:2] 154 | x_unlab = move_to(x_unlab, self.device) 155 | y_pseudo = move_to(y_pseudo, self.device) 156 | 157 | results["unlabeled_y_pseudo"] = y_pseudo 158 | 159 | x_cat = concat_input(x, x_unlab) 160 | 161 | outputs = self.get_model_output(x_cat) 162 | 163 | results["y_pred"] = outputs[:n_lab] 164 | results["unlabeled_y_pred"] = outputs[n_lab:] 165 | else: 166 | results["y_pred"] = self.get_model_output(x) 167 | 168 | # if self.use_target_marginal and target_marginal is not None: 169 | # results['im_weights'] = torch.divide(torch.tensor(target_marginal).to(self.device),\ 170 | # torch.tensor(source_marginal).to(self.device)) 171 | 172 | # if self.source_balanced and source_marginal is not None: 173 | # results['source_marginal'] = torch.tensor(source_marginal).to(self.device) 174 | 175 | return results 176 | 177 | def objective(self, results): 178 | # Labeled loss 179 | classification_loss = self.loss(results["y_pred"], results["y_true"]) 180 | 181 | # if self.use_target_marginal: 182 | # classification_loss = torch.mean(classification_loss*results["im_weights"][results["y_true"]]) 183 | 184 | # elif self.source_balanced: 185 | # classification_loss = torch.mean(classification_loss/results["source_marginal"][results["y_true"]]/ self.num_classes) 186 | 187 | # Pseudolabel loss 188 | if "unlabeled_y_pseudo" in results: 189 | consistency_loss = self.loss( 190 | results["unlabeled_y_pred"], 191 | results["unlabeled_y_pseudo"], 192 | ) 193 | # if self.use_target_marginal or self.source_balanced: 194 | # consistency_loss = torch.mean(consistency_loss) 195 | 196 | else: 197 | consistency_loss = 0 198 | 199 | return classification_loss + consistency_loss 200 | -------------------------------------------------------------------------------- /RLSbench/algorithms/pseudolabel.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from RLSbench.algorithms.single_model_algorithm import SingleModelAlgorithm 7 | from RLSbench.losses import initialize_loss 8 | from RLSbench.models.initializer import initialize_model 9 | from RLSbench.models.model_utils import linear_probe 10 | from RLSbench.scheduler import LinearScheduleWithWarmupAndThreshold 11 | from RLSbench.utils import ( 12 | concat_input, 13 | detach_and_clone, 14 | move_to, 15 | pseudolabel_multiclass_logits, 16 | ) 17 | 18 | logger = logging.getLogger("label_shift") 19 | 20 | 21 | class PseudoLabel(SingleModelAlgorithm): 22 | """ 23 | PseudoLabel. 
24 | This is a vanilla pseudolabeling algorithm which updates the model per batch and incorporates a confidence threshold. 25 | 26 | Original paper: 27 | @inproceedings{lee2013pseudo, 28 | title={Pseudo-label: The simple and efficient semi-supervised learning method for deep neural networks}, 29 | author={Lee, Dong-Hyun and others}, 30 | booktitle={Workshop on challenges in representation learning, ICML}, 31 | volume={3}, 32 | number={2}, 33 | pages={896}, 34 | year={2013} 35 | } 36 | """ 37 | 38 | def __init__(self, config, dataloader, loss_function, n_train_steps, **kwargs): 39 | logger.info("Initializing PseudoLabel models") 40 | 41 | model = initialize_model( 42 | model_name=config.model, 43 | dataset_name=config.dataset, 44 | num_classes=config.num_classes, 45 | featurize=True, 46 | pretrained=config.pretrained, 47 | pretrained_path=config.pretrained_path, 48 | data_dir=config.root_dir, 49 | ) 50 | 51 | self.use_target_marginal = False 52 | 53 | loss = initialize_loss(loss_function) 54 | 55 | # if config.pretrained : 56 | # model = linear_probe(model, dataloader, device= config.device, progress_bar=config.progress_bar) 57 | 58 | model = nn.Sequential(*model) 59 | 60 | # initialize module 61 | super().__init__( 62 | config=config, 63 | model=model, 64 | loss=loss, 65 | n_train_steps=n_train_steps, 66 | ) 67 | 68 | # algorithm hyperparameters 69 | self.lambda_scheduler = LinearScheduleWithWarmupAndThreshold( 70 | max_value=kwargs["self_training_lambda"], 71 | step_every_batch=True, # step per batch 72 | last_warmup_step=0, 73 | threshold_step=kwargs["pseudolabel_T2"] * n_train_steps, 74 | ) 75 | 76 | self.schedulers.append(self.lambda_scheduler) 77 | self.confidence_threshold = kwargs["self_training_threshold"] 78 | self.target_align = kwargs["target_align"] 79 | self.process_pseudolabels_function = pseudolabel_multiclass_logits 80 | 81 | self.source_balanced = config.source_balanced 82 | self.num_classes = config.num_classes 83 | self.dataset = config.dataset 84 | 85 | def process_batch( 86 | self, 87 | batch, 88 | unlabeled_batch=None, 89 | target_marginal=None, 90 | source_marginal=None, 91 | target_average=None, 92 | ): 93 | """ 94 | Overrides single_model_algorithm.process_batch(). 
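Unlike NoisyStudent, pseudolabels here are produced on the fly inside this method by thresholding the model's own predictions on the unlabeled batch.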
95 | Args: 96 | - batch (tuple of Tensors): a batch of data yielded by data loaders 97 | - unlabeled_batch (tuple of Tensors or None): a batch of data yielded by unlabeled data loader 98 | Output: 99 | - results (dictionary): information about the batch 100 | - y_true (Tensor): ground truth labels for batch 101 | - y_pred (Tensor): model output for batch 102 | - unlabeled_y_pseudo (Tensor): pseudolabels on the unlabeled batch, already thresholded 103 | - unlabeled_y_pred (Tensor): model output on the unlabeled batch, already thresholded 104 | """ 105 | # Labeled examples 106 | x, y_true = batch[:2] 107 | 108 | n_lab = y_true.shape[0] 109 | 110 | # package the results (dict initialized here so the marginal-based entries below land in a defined dict) 111 | results = {} 112 | if self.use_target_marginal and target_marginal is not None: 113 | results["im_weights"] = torch.divide( 114 | torch.tensor(target_marginal).to(self.device), 115 | torch.tensor(source_marginal).to(self.device), 116 | ) 117 | 118 | if self.source_balanced and source_marginal is not None: 119 | results["source_marginal"] = torch.tensor(source_marginal).to(self.device) 120 | 121 | if unlabeled_batch is not None: 122 | if self.target_align and target_average is not None: 123 | alignment_dist = torch.divide( 124 | torch.tensor(target_marginal).to(self.device), 125 | torch.tensor(target_average).to(self.device), 126 | ) 127 | 128 | x_unlab = unlabeled_batch[0] 129 | 130 | if "mimic" in self.dataset: 131 | # x_cat = collate_fn_mimic([x, unlabeled_x]) 132 | x_cat = [x[0] + x_unlab[0], x[1] + x_unlab[1]] 133 | 134 | else: 135 | x_cat = concat_input(x, x_unlab) 136 | 137 | x_cat = move_to(x_cat, self.device) 138 | outputs = self.get_model_output(x_cat) 139 | unlabeled_output = outputs[n_lab:] 140 | 141 | if self.target_align and target_average is not None: 142 | ( 143 | unlabeled_y_pred, 144 | unlabeled_y_pseudo, 145 | pseudolabels_kept_frac, 146 | _, 147 | ) = self.process_pseudolabels_function( 148 | unlabeled_output, self.confidence_threshold, alignment_dist 149 | ) 150 | else: 151 | ( 152 | unlabeled_y_pred, 153 | unlabeled_y_pseudo, 154 | pseudolabels_kept_frac, 155 | _, 156 | ) = self.process_pseudolabels_function( 157 | unlabeled_output, self.confidence_threshold 158 | ) 159 | 160 | y_true = move_to(y_true, self.device) 161 | 162 | results["y_true"] = y_true 163 | 164 | 165 | results["y_pred"] = outputs[:n_lab] 166 | results["unlabeled_y_pred"] = unlabeled_y_pred 167 | results["unlabeled_y_pseudo"] = detach_and_clone(unlabeled_y_pseudo) 168 | else: 169 | x = move_to(x, self.device) 170 | y_true = move_to(y_true, self.device) 171 | results["y_true"] = y_true 172 | 173 | 174 | results["y_pred"] = self.get_model_output(x) 175 | pseudolabels_kept_frac = 0 176 | 177 | results["pseudolabels_kept_frac"] = pseudolabels_kept_frac 178 | 179 | return results 180 | 181 | def objective(self, results): 182 | # Labeled loss 183 | classification_loss = self.loss(results["y_pred"], results["y_true"]) 184 | 185 | if self.use_target_marginal: 186 | classification_loss = torch.mean( 187 | classification_loss * results["im_weights"][results["y_true"]] 188 | ) 189 | 190 | elif self.source_balanced: 191 | classification_loss = torch.mean( 192 | classification_loss 193 | / results["source_marginal"][results["y_true"]] 194 | / self.num_classes 195 | ) 196 | 197 | # Pseudolabeled loss 198 | if "unlabeled_y_pseudo" in results: 199 | loss_output = self.loss( 200 | results["unlabeled_y_pred"], 201 | results["unlabeled_y_pseudo"], 202 | ) 203 | 204 | if self.use_target_marginal or self.source_balanced: 205 | loss_output = 
torch.mean(loss_output) 206 | 207 | consistency_loss = loss_output * results["pseudolabels_kept_frac"] 208 | else: 209 | consistency_loss = 0 210 | 211 | return classification_loss + self.lambda_scheduler.value * consistency_loss 212 | -------------------------------------------------------------------------------- /RLSbench/algorithms/single_model_algorithm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from RLSbench.algorithms.algorithm import Algorithm 3 | from RLSbench.optimizer import initialize_optimizer 4 | from RLSbench.scheduler import initialize_scheduler, step_scheduler 5 | from RLSbench.utils import move_to 6 | from torch.nn import DataParallel 7 | 8 | 9 | class SingleModelAlgorithm(Algorithm): 10 | """ 11 | An abstract class for algorithm that has one underlying model. 12 | """ 13 | 14 | def __init__(self, config, model, loss, n_train_steps): 15 | super().__init__(config.device) 16 | self.loss = loss 17 | 18 | # initialize models, optimizers, and schedulers 19 | if not hasattr(self, "optimizer") or self.optimizer is None: 20 | self.optimizer = initialize_optimizer(config, model) 21 | 22 | self.schedulers = [initialize_scheduler(config, self.optimizer, n_train_steps)] 23 | 24 | if config.use_data_parallel: 25 | model = DataParallel(model) 26 | 27 | model.to(config.device) 28 | 29 | self.batch_idx = 0 30 | self.gradient_accumulation_steps = config.gradient_accumulation_steps 31 | self.model = model 32 | 33 | def get_model_output(self, x): 34 | outputs = self.model(x) 35 | return outputs 36 | 37 | def process_batch(self, batch, unlabeled_batch=None): 38 | """ 39 | A helper function for update() and evaluate() that processes the batch 40 | Args: 41 | - batch (tuple of Tensors): a batch of data yielded by data loaders 42 | Output: 43 | - results (dictionary): information about the batch 44 | - y_true (Tensor): ground truth labels for batch 45 | - y_pred (Tensor): model output for batch 46 | """ 47 | x, y_true = batch[:2] 48 | x = move_to(x, self.device) 49 | y_true = move_to(y_true, self.device) 50 | 51 | outputs = self.get_model_output(x) 52 | 53 | results = { 54 | "y_true": y_true, 55 | "y_pred": outputs, 56 | } 57 | return results 58 | 59 | def objective(self, results): 60 | raise NotImplementedError 61 | 62 | def evaluate(self, batch): 63 | """ 64 | Process the batch and update the log, without updating the model 65 | Args: 66 | - batch (tuple of Tensors): a batch of data yielded by data loaders 67 | Output: 68 | - results (dictionary): information about the batch, such as: 69 | - y_true (Tensor) 70 | - outputs (Tensor) 71 | - y_pred (Tensor) 72 | """ 73 | assert not self.is_training 74 | results = self.process_batch(batch) 75 | return results 76 | 77 | def update( 78 | self, 79 | batch, 80 | unlabeled_batch=None, 81 | target_marginal=None, 82 | source_marginal=None, 83 | target_average=None, 84 | is_epoch_end=False, 85 | ): 86 | """ 87 | Process the batch, update the log, and update the model 88 | Args: 89 | - batch (tuple of Tensors): a batch of data yielded by data loaders 90 | - unlabeled_batch (tuple of Tensors or None): a batch of data yielded by unlabeled data loader 91 | - is_epoch_end: whether this batch is the last batch of the epoch. 
if so, force optimizer to step, 92 | regardless of whether this batch idx divides self.gradient_accumulation_steps evenly 93 | Output: 94 | - results (dictionary): information about the batch, such as: 95 | - g (Tensor) 96 | - y_true (Tensor) 97 | - metadata (Tensor) 98 | - outputs (Tensor) 99 | - y_pred (Tensor) 100 | - objective (float) 101 | """ 102 | assert self.is_training 103 | 104 | # self.optimizer.zero_grad() 105 | 106 | # process this batch 107 | results = self.process_batch( 108 | batch, unlabeled_batch, target_marginal, source_marginal, target_average 109 | ) 110 | 111 | # update running statistics and update model if we've reached end of effective batch 112 | # iterate batch index 113 | if is_epoch_end: 114 | self.batch_idx = 0 115 | 116 | else: 117 | self.batch_idx += 1 118 | 119 | self._update(results) 120 | 121 | return results 122 | 123 | def _update(self, results): 124 | """ 125 | Computes the objective and updates the model. 126 | Also updates the results dictionary yielded by process_batch(). 127 | Should be overridden to change algorithm update beyond modifying the objective. 128 | """ 129 | # compute objective 130 | objective = ( 131 | self.objective(results) / self.gradient_accumulation_steps 132 | ) # normalize by gradient accumulation steps 133 | results["objective"] = objective.item() 134 | objective.backward() 135 | 136 | # import pdb; pdb.set_trace() 137 | if ( 138 | self.batch_idx 139 | ) % self.gradient_accumulation_steps == 0 or self.batch_idx == 0: 140 | self.optimizer.step() 141 | self.model.zero_grad() 142 | # self.optimizer.step() 143 | self.step_schedulers(is_epoch=False) 144 | 145 | if self.batch_idx == 0: 146 | self.step_schedulers(is_epoch=True) 147 | 148 | def step_schedulers(self, is_epoch): 149 | """ 150 | Updates the scheduler after an epoch. 151 | """ 152 | for scheduler in self.schedulers: 153 | if scheduler is None: 154 | continue 155 | if scheduler.step_every_batch: 156 | step_scheduler(scheduler) 157 | elif is_epoch: 158 | step_scheduler(scheduler) 159 | -------------------------------------------------------------------------------- /RLSbench/collate_functions.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import List 3 | import torch 4 | import torchvision.transforms as transforms 5 | import torchvision.transforms.functional as TF 6 | import pickle 7 | from typing import List 8 | import torch 9 | import os 10 | from torch.nn.utils.rnn import pad_sequence 11 | 12 | 13 | def initialize_collate_function(collate_function_name): 14 | """ 15 | Initializes collate_function that takes in a batch of samples 16 | returned by data loader and combines them into a tensor. 17 | This is standard for images and datasets where each element 18 | has the same size. But, for datasets like mimic and arxiv, 19 | where each element in the batch varies in size, the collate 20 | function must handle the padding etc.. 21 | """ 22 | 23 | if collate_function_name is None or collate_function_name.lower() == "none": 24 | return None 25 | elif ( 26 | collate_function_name.lower() == "mimic_readmission" 27 | or collate_function_name.lower() == "mimic_mortality" 28 | ): 29 | return collate_fn_mimic 30 | else: 31 | raise ValueError(f"{collate_function_name} not recognized") 32 | 33 | 34 | def collate_fn_mimic(batch): 35 | """ 36 | batch is a list, where each element is also a list of size 37 | at least two. The first element of the inner list is 38 | [code, type] and the second element is the label. 
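For example (hypothetical values), one element might look like (([19, 42, 7], [0, 1, 0]), 1): three medical codes, their type ids, and a binary label.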
The rest 39 | of the dimensions may contain auxiliary information. 40 | """ 41 | codes = [item[0][0] for item in batch] 42 | types = [item[0][1] for item in batch] 43 | target_and_aux = [item[1:] for item in batch] 44 | target_and_aux = list(zip(*target_and_aux)) 45 | target_and_aux = [torch.tensor(item) for item in target_and_aux] 46 | return [(codes, types), *target_and_aux] 47 | -------------------------------------------------------------------------------- /RLSbench/configs/supported.py: -------------------------------------------------------------------------------- 1 | # See algorithms/initializer.py 2 | algorithms = [ 3 | "ERM", 4 | "IS-ERM", 5 | "ERM-aug", 6 | "IS-ERM-aug", 7 | "ERM-oracle", 8 | "IS-ERM-oracle", 9 | "ERM-adv", 10 | "DANN", 11 | "CDANN", 12 | "IW-DANN", 13 | "IW-CDANN", 14 | "IS-DANN", 15 | "IS-CDANN", 16 | "COAL", 17 | "IW-COAL", 18 | "SENTRY", 19 | "IW-SENTRY", 20 | "FixMatch", 21 | "IW-FixMatch", 22 | "IS-FixMatch", 23 | "PseudoLabel", 24 | "IS-PseudoLabel", 25 | "NoisyStudent", 26 | "IS-NoisyStudent", 27 | "CORAL", 28 | "IS-CORAL", 29 | "BN_adapt", 30 | "BN_adapt-adv", 31 | "IS-BN_adapt", 32 | "IS-BN_adapt-adv", 33 | "TENT", 34 | "IS-TENT", 35 | ] 36 | 37 | label_shift_adapt = ["MLLS", "true", "RLLS", "None", "baseline"] 38 | 39 | 40 | # See transforms.py 41 | transforms = [ 42 | "image_base", 43 | "image_resize_and_center_crop", 44 | "image_none", 45 | "rxrx1", 46 | "clip", 47 | "bert", 48 | "None", 49 | ] 50 | 51 | additional_transforms = [ 52 | "randaugment", 53 | "weak", 54 | ] 55 | collate_functions = ["mimic_readmission", "mimic_mortality", "None"] 56 | # See models/initializer.py 57 | models = [ 58 | "resnet18", 59 | "resnet34", 60 | "resnet50", 61 | "resnet101", 62 | "densenet121", 63 | "clipvitb32", 64 | "clipvitb16", 65 | "clipvitl14", 66 | "efficientnet_b0", 67 | "mimic_model", 68 | "distilbert-base-uncased", 69 | "MLP", 70 | ] 71 | 72 | # Pre-training type 73 | pretrainining_options = ["clip", "imagenet", "swav", "rand", "bert"] 74 | 75 | # See optimizer.py 76 | optimizers = ["SGD", "Adam", "AdamW"] 77 | 78 | # See scheduler.py 79 | schedulers = [ 80 | "linear_schedule_with_warmup", 81 | "cosine_schedule_with_warmup", 82 | "ReduceLROnPlateau", 83 | "StepLR", 84 | "FixMatchLR", 85 | "MultiStepLR", 86 | ] 87 | 88 | # See losses.py 89 | losses = ["cross_entropy", "cross_entropy_logits"] 90 | -------------------------------------------------------------------------------- /RLSbench/configs/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | from RLSbench.configs.algorithm import algorithm_defaults 4 | from RLSbench.configs.datasets import dataset_defaults 5 | 6 | 7 | def populate_defaults(config): 8 | """Populates hyperparameters with defaults implied by choices 9 | of other hyperparameters.""" 10 | 11 | orig_config = copy.deepcopy(config) 12 | assert config.dataset is not None, "dataset must be specified" 13 | assert config.algorithm is not None, "algorithm must be specified" 14 | 15 | config = populate_config(config, dataset_defaults[config.dataset]) 16 | 17 | # # implied defaults from choice of split 18 | # if config.dataset in split_defaults and config.split_scheme in split_defaults[config.dataset]: 19 | # config = populate_config( 20 | # config, 21 | # split_defaults[config.dataset][config.split_scheme] 22 | # ) 23 | 24 | # implied defaults from choice of algorithm 25 | config = populate_config(config, algorithm_defaults[config.algorithm]) 26 | 27 | # implied defaults from choice of 
model 28 | # if config.model: config = populate_config( 29 | # config, 30 | # model_defaults[config.model], 31 | # ) 32 | 33 | # # implied defaults from choice of scheduler 34 | # if config.scheduler: config = populate_config( 35 | # config, 36 | # scheduler_defaults[config.scheduler] 37 | # ) 38 | 39 | # # implied defaults from choice of loader 40 | # config = populate_config( 41 | # config, 42 | # loader_defaults 43 | # ) 44 | 45 | # import pdb; pdb.set_trace() 46 | if config.use_target: 47 | assert ( 48 | config.target_split is not None 49 | ), "target_split must be specified if use_target is True" 50 | 51 | if config.eval_only: 52 | assert ( 53 | config.use_source_model is not None 54 | ), "use_source_model must be specified if eval_only is True" 55 | 56 | if config.use_source_model: 57 | assert ( 58 | config.source_model_path is not None 59 | ), "source_model_path must be specified if use_source_model is True" 60 | 61 | if config.simulate_label_shift: 62 | assert ( 63 | config.use_target is not None 64 | ), "when simulating label shift target split is needed" 65 | assert ( 66 | config.dirichlet_alpha is not None 67 | ), "when simulating label shift dirichlet alpha is needed" 68 | 69 | # basic checks 70 | required_fields = [ 71 | "batch_size", 72 | "model", 73 | "loss_function", 74 | "n_epochs", 75 | "optimizer", 76 | "lr", 77 | "weight_decay", 78 | "default_normalization", 79 | ] 80 | for field in required_fields: 81 | assert ( 82 | getattr(config, field) is not None 83 | ), f"Must manually specify {field} for this setup." 84 | 85 | if config.pretrain_type == "imagenet": 86 | config.pretrained = True 87 | 88 | if "clip" in config.model: 89 | config.pretrain_type = "clip" 90 | config.transform = "clip" 91 | config.pretrained = True 92 | config.optimizer = "AdamW" 93 | config.optimizer_kwargs = {} 94 | config.scheduler = "cosine_schedule_with_warmup" 95 | config.scheduler_kwargs = {"warmup_frac": 0.1} 96 | 97 | if not config.pretrained: 98 | assert ( 99 | config.pretrain_type == "rand" 100 | ), "When pre-trained loading is False, pre-train type must be rand" 101 | 102 | if config.algorithm in ( 103 | "ERM", 104 | "ERM-aug", 105 | "ERM-oracle", 106 | "IS-ERM", 107 | "IS-ERM-aug", 108 | "IS-ERM-oracle", 109 | ): 110 | config.algorithm = f"{config.algorithm}-{config.pretrain_type}" 111 | 112 | # if "NoisyStudent" in config.algorithm: 113 | # assert "teacher_model_path" in config.noisystudent_kwargs, "Teacher model path needed for noisy student training." 114 | 115 | if "SENTRY" in config.algorithm: 116 | import math 117 | 118 | config.batch_size = int((config.batch_size / 6)) 119 | config.gradient_accumulation_steps = 6 120 | 121 | if ( 122 | "DANN" in config.algorithm 123 | or "FixMatch" in config.algorithm 124 | or "PseudoLabel" in config.algorithm 125 | or "NoisyStudent" in config.algorithm 126 | or "COAL" in config.algorithm 127 | ): 128 | if "civilcomments" in config.dataset: 129 | config.batch_size = config.batch_size // 3 130 | config.gradient_accumulation_steps = 3 131 | else: 132 | config.batch_size = config.batch_size // 2 133 | config.gradient_accumulation_steps = 2 134 | 135 | return config 136 | 137 | 138 | def populate_config(config, template: dict, force_compatibility=False): 139 | """Populates missing (key, val) pairs in config with (key, val) in template. 
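For instance (illustrative): with template = {"lr": 0.1}, a config whose lr is None gets lr = 0.1; an lr explicitly set to another value is kept, unless force_compatibility is True, in which case a ValueError is raised.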
140 | Example usage: populate config with defaults 141 | Args: 142 | - config: namespace 143 | - template: dict 144 | - force_compatibility: option to raise errors if config.key != template[key] 145 | """ 146 | if template is None: 147 | return config 148 | 149 | d_config = vars(config) 150 | for key, val in template.items(): 151 | if not isinstance(val, dict): # config[key] expected to be a non-index-able 152 | if key not in d_config or d_config[key] is None: 153 | d_config[key] = val 154 | elif d_config[key] != val and force_compatibility: 155 | raise ValueError(f"Argument {key} must be set to {val}") 156 | 157 | else: # config[key] expected to be a kwarg dict 158 | for kwargs_key, kwargs_val in val.items(): 159 | if kwargs_key not in d_config[key] or d_config[key][kwargs_key] is None: 160 | d_config[key][kwargs_key] = kwargs_val 161 | elif d_config[key][kwargs_key] != kwargs_val and force_compatibility: 162 | raise ValueError( 163 | f"Argument {key}[{kwargs_key}] must be set to {val}" 164 | ) 165 | return config 166 | -------------------------------------------------------------------------------- /RLSbench/data_augmentation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acmi-lab/RLSbench/eb67d5c78aa3646b7369830e481b3f15a59a087d/RLSbench/data_augmentation/__init__.py -------------------------------------------------------------------------------- /RLSbench/data_augmentation/randaugment.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/YBZh/Bridging_UDA_SSL 2 | 3 | import torch 4 | from PIL import Image, ImageOps, ImageEnhance, ImageDraw 5 | 6 | 7 | def AutoContrast(img, _): 8 | return ImageOps.autocontrast(img) 9 | 10 | 11 | def Brightness(img, v): 12 | assert v >= 0.0 13 | return ImageEnhance.Brightness(img).enhance(v) 14 | 15 | 16 | def Color(img, v): 17 | assert v >= 0.0 18 | return ImageEnhance.Color(img).enhance(v) 19 | 20 | 21 | def Contrast(img, v): 22 | assert v >= 0.0 23 | return ImageEnhance.Contrast(img).enhance(v) 24 | 25 | 26 | def Equalize(img, _): 27 | return ImageOps.equalize(img) 28 | 29 | 30 | def Invert(img, _): 31 | return ImageOps.invert(img) 32 | 33 | 34 | def Identity(img, v): 35 | return img 36 | 37 | 38 | def Posterize(img, v): # [4, 8] 39 | v = int(v) 40 | v = max(1, v) 41 | return ImageOps.posterize(img, v) 42 | 43 | 44 | def Rotate(img, v): # [-30, 30] 45 | return img.rotate(v) 46 | 47 | 48 | def Sharpness(img, v): # [0.1,1.9] 49 | assert v >= 0.0 50 | return ImageEnhance.Sharpness(img).enhance(v) 51 | 52 | 53 | def ShearX(img, v): # [-0.3, 0.3] 54 | return img.transform(img.size, Image.AFFINE, (1, v, 0, 0, 1, 0)) 55 | 56 | 57 | def ShearY(img, v): # [-0.3, 0.3] 58 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, v, 1, 0)) 59 | 60 | 61 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 62 | v = v * img.size[0] 63 | return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0)) 64 | 65 | 66 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 67 | return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0)) 68 | 69 | 70 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 71 | v = v * img.size[1] 72 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v)) 73 | 74 | 75 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 76 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v)) 77 | 78 | 79 | def 
Solarize(img, v): # [0, 256] 80 | assert 0 <= v <= 256 81 | return ImageOps.solarize(img, v) 82 | 83 | 84 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] => change to [0, 0.5] 85 | assert 0.0 <= v <= 0.5 86 | 87 | v = v * img.size[0] 88 | return CutoutAbs(img, v) 89 | 90 | 91 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 92 | if v < 0: 93 | return img 94 | w, h = img.size 95 | x_center = _sample_uniform(0, w) 96 | y_center = _sample_uniform(0, h) 97 | 98 | x0 = int(max(0, x_center - v / 2.0)) 99 | y0 = int(max(0, y_center - v / 2.0)) 100 | x1 = min(w, x0 + v) 101 | y1 = min(h, y0 + v) 102 | 103 | xy = (x0, y0, x1, y1) 104 | color = (125, 123, 114) 105 | img = img.copy() 106 | ImageDraw.Draw(img).rectangle(xy, color) 107 | return img 108 | 109 | 110 | FIX_MATCH_AUGMENTATION_POOL = [ 111 | (AutoContrast, 0, 1), 112 | (Brightness, 0.05, 0.95), 113 | (Color, 0.05, 0.95), 114 | (Contrast, 0.05, 0.95), 115 | (Equalize, 0, 1), 116 | (Identity, 0, 1), 117 | (Posterize, 4, 8), 118 | (Rotate, -30, 30), 119 | (Sharpness, 0.05, 0.95), 120 | (ShearX, -0.3, 0.3), 121 | (ShearY, -0.3, 0.3), 122 | (Solarize, 0, 256), 123 | (TranslateX, -0.3, 0.3), 124 | (TranslateY, -0.3, 0.3), 125 | ] 126 | 127 | 128 | def _sample_uniform(a, b): 129 | return torch.empty(1).uniform_(a, b).item() 130 | 131 | 132 | class RandAugment: 133 | def __init__(self, n, augmentation_pool): 134 | assert n >= 1, "RandAugment N has to be a value greater than or equal to 1." 135 | self.n = n 136 | self.augmentation_pool = augmentation_pool 137 | 138 | def __call__(self, img): 139 | ops = [ 140 | self.augmentation_pool[torch.randint(len(self.augmentation_pool), (1,))] 141 | for _ in range(self.n) 142 | ] 143 | for op, min_val, max_val in ops: 144 | val = min_val + float(max_val - min_val) * _sample_uniform(0, 1) 145 | img = op(img, val) 146 | cutout_val = _sample_uniform(0, 1) * 0.5 147 | img = Cutout(img, cutout_val) 148 | return img 149 | -------------------------------------------------------------------------------- /RLSbench/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from RLSbench.datasets.data_utils import * 2 | from RLSbench.datasets.get_dataset import * 3 | 4 | benchmark_datasets = [ 5 | "camelyon", 6 | "iwildcam", 7 | "fmow", 8 | "cifar10", 9 | "cifar100", 10 | "domainnet", 11 | "entity13", 12 | "entity30", 13 | "living17", 14 | "nonliving26", 15 | "office31", 16 | "officehome", 17 | "visda", 18 | "civilcomments", 19 | "amazon", 20 | "retiring_adult", 21 | "mimic_readmission", 22 | ] 23 | 24 | supported_datasets = benchmark_datasets 25 | 26 | dataset_map = { 27 | "cifar10": get_cifar10, 28 | "cifar100": get_cifar100, 29 | "office31": get_office31, 30 | "officehome": get_officehome, 31 | "visda": get_visda, 32 | "domainnet": get_domainnet, 33 | "entity13": get_entity13, 34 | "entity30": get_entity30, 35 | "living17": get_living17, 36 | "nonliving26": get_nonliving26, 37 | "fmow": get_fmow, 38 | "iwildcam": get_iwildcams, 39 | "rxrx1": get_rxrx1, 40 | "camelyon": get_camelyon, 41 | "civilcomments": get_civilcomments, 42 | "amazon": get_amazon, 43 | "retiring_adult": get_retiring_adult, 44 | "mimic_readmission": get_mimic_readmission, 45 | } 46 | 47 | 48 | def get_dataset( 49 | dataset, 50 | source=True, 51 | target=False, 52 | root_dir=None, 53 | target_split=None, 54 | transforms=None, 55 | num_classes=None, 56 | split_fraction=0.8, 57 | seed=42, 58 | ): 59 | """Get dataset. 60 | 61 | Args: 62 | dataset (str): Name of the dataset. 
63 | source (bool): Whether to return the source dataset. 64 | target (bool): Whether to return the target dataset. 65 | root_dir (str): Path to the root directory of the dataset. 66 | target_split (int): Num of the target split. 67 | transforms (dict): Dictionary of transformations. 68 | num_classes (int): Number of classes. 69 | split_fraction (float): Fraction of the dataset to use for training. 70 | seed (int): Random seed. 71 | 72 | Returns: 73 | dataset (torch.utils.data.Dataset): Dataset. 74 | """ 75 | 76 | return dataset_map[dataset]( 77 | source, 78 | target, 79 | root_dir, 80 | target_split, 81 | transforms, 82 | num_classes, 83 | split_fraction, 84 | seed, 85 | ) 86 | -------------------------------------------------------------------------------- /RLSbench/losses.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from wilds.common.metrics.loss import ElementwiseLoss 3 | 4 | # from wilds.common.metrics.all_metrics import MSE 5 | from RLSbench.utils import cross_entropy_with_logits_loss 6 | 7 | 8 | def initialize_loss(loss, reduction="mean"): 9 | if loss == "cross_entropy": 10 | return nn.CrossEntropyLoss(reduction=reduction) 11 | 12 | elif loss == "cross_entropy_logits": 13 | return ElementwiseLoss(loss_fn=cross_entropy_with_logits_loss) 14 | 15 | else: 16 | raise ValueError(f"loss {loss} not recognized") 17 | -------------------------------------------------------------------------------- /RLSbench/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acmi-lab/RLSbench/eb67d5c78aa3646b7369830e481b3f15a59a087d/RLSbench/models/__init__.py -------------------------------------------------------------------------------- /RLSbench/models/cifar_efficientnet.py: -------------------------------------------------------------------------------- 1 | """EfficientNet in PyTorch. 2 | Paper: "EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks". 
3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class Block(nn.Module): 10 | """expand + depthwise + pointwise + squeeze-excitation""" 11 | 12 | def __init__(self, in_planes, out_planes, expansion, stride): 13 | super(Block, self).__init__() 14 | self.stride = stride 15 | 16 | planes = expansion * in_planes 17 | self.conv1 = nn.Conv2d( 18 | in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False 19 | ) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d( 22 | planes, 23 | planes, 24 | kernel_size=3, 25 | stride=stride, 26 | padding=1, 27 | groups=planes, 28 | bias=False, 29 | ) 30 | self.bn2 = nn.BatchNorm2d(planes) 31 | self.conv3 = nn.Conv2d( 32 | planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False 33 | ) 34 | self.bn3 = nn.BatchNorm2d(out_planes) 35 | 36 | self.shortcut = nn.Sequential() 37 | if stride == 1 and in_planes != out_planes: 38 | self.shortcut = nn.Sequential( 39 | nn.Conv2d( 40 | in_planes, 41 | out_planes, 42 | kernel_size=1, 43 | stride=1, 44 | padding=0, 45 | bias=False, 46 | ), 47 | nn.BatchNorm2d(out_planes), 48 | ) 49 | 50 | # SE layers 51 | self.fc1 = nn.Conv2d(out_planes, out_planes // 16, kernel_size=1) 52 | self.fc2 = nn.Conv2d(out_planes // 16, out_planes, kernel_size=1) 53 | 54 | def forward(self, x): 55 | out = F.relu(self.bn1(self.conv1(x))) 56 | out = F.relu(self.bn2(self.conv2(out))) 57 | out = self.bn3(self.conv3(out)) 58 | shortcut = self.shortcut(x) if self.stride == 1 else out 59 | # Squeeze-Excitation 60 | w = F.avg_pool2d(out, out.size(2)) 61 | w = F.relu(self.fc1(w)) 62 | w = self.fc2(w).sigmoid() 63 | out = out * w + shortcut 64 | return out 65 | 66 | 67 | class EfficientNet(nn.Module): 68 | def __init__(self, cfg, num_classes=10, features=False): 69 | super(EfficientNet, self).__init__() 70 | self.cfg = cfg 71 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 72 | self.bn1 = nn.BatchNorm2d(32) 73 | self.layers = self._make_layers(in_planes=32) 74 | self.linear = nn.Linear(cfg[-1][1], num_classes) 75 | self.features = features 76 | 77 | def _make_layers(self, in_planes): 78 | layers = [] 79 | for expansion, out_planes, num_blocks, stride in self.cfg: 80 | strides = [stride] + [1] * (num_blocks - 1) 81 | for stride in strides: 82 | layers.append(Block(in_planes, out_planes, expansion, stride)) 83 | in_planes = out_planes 84 | return nn.Sequential(*layers) 85 | 86 | def forward(self, x): 87 | out = F.relu(self.bn1(self.conv1(x))) 88 | out = self.layers(out) 89 | out = out.view(out.size(0), -1) 90 | 91 | if self.features: 92 | return out 93 | 94 | else: 95 | out = self.linear(out) 96 | return out 97 | 98 | 99 | def EfficientNetB0(num_classes=10, features=False): 100 | # (expansion, out_planes, num_blocks, stride) 101 | cfg = [ 102 | (1, 16, 1, 2), 103 | (6, 24, 2, 1), 104 | (6, 40, 2, 2), 105 | (6, 80, 3, 2), 106 | (6, 112, 3, 1), 107 | (6, 192, 4, 2), 108 | (6, 320, 1, 2), 109 | ] 110 | return EfficientNet(cfg, num_classes, features) 111 | -------------------------------------------------------------------------------- /RLSbench/models/cifar_resnet.py: -------------------------------------------------------------------------------- 1 | """ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 8 | """ 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1): 18 | super(BasicBlock, self).__init__() 19 | self.expansion = 1 20 | self.conv1 = nn.Conv2d( 21 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 22 | ) 23 | self.bn1 = nn.BatchNorm2d(planes) 24 | self.conv2 = nn.Conv2d( 25 | planes, planes, kernel_size=3, stride=1, padding=1, bias=False 26 | ) 27 | self.bn2 = nn.BatchNorm2d(planes) 28 | 29 | self.shortcut = nn.Sequential() 30 | if stride != 1 or in_planes != self.expansion * planes: 31 | self.shortcut = nn.Sequential( 32 | nn.Conv2d( 33 | in_planes, 34 | self.expansion * planes, 35 | kernel_size=1, 36 | stride=stride, 37 | bias=False, 38 | ), 39 | nn.BatchNorm2d(self.expansion * planes), 40 | ) 41 | 42 | def forward(self, x): 43 | out = F.relu(self.bn1(self.conv1(x))) 44 | out = self.bn2(self.conv2(out)) 45 | out += self.shortcut(x) 46 | out = F.relu(out) 47 | return out 48 | 49 | 50 | class Bottleneck(nn.Module): 51 | expansion = 4 52 | 53 | def __init__(self, in_planes, planes, stride=1): 54 | super(Bottleneck, self).__init__() 55 | self.expansion = 4 56 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 57 | self.bn1 = nn.BatchNorm2d(planes) 58 | self.conv2 = nn.Conv2d( 59 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 60 | ) 61 | self.bn2 = nn.BatchNorm2d(planes) 62 | self.conv3 = nn.Conv2d( 63 | planes, self.expansion * planes, kernel_size=1, bias=False 64 | ) 65 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 66 | 67 | self.shortcut = nn.Sequential() 68 | if stride != 1 or in_planes != self.expansion * planes: 69 | self.shortcut = nn.Sequential( 70 | nn.Conv2d( 71 | in_planes, 72 | self.expansion * planes, 73 | kernel_size=1, 74 | stride=stride, 75 | bias=False, 76 | ), 77 | nn.BatchNorm2d(self.expansion * planes), 78 | ) 79 | 80 | def forward(self, x): 81 | out = F.relu(self.bn1(self.conv1(x))) 82 | out = F.relu(self.bn2(self.conv2(out))) 83 | out = self.bn3(self.conv3(out)) 84 | out += self.shortcut(x) 85 | out = F.relu(out) 86 | return out 87 | 88 | 89 | class ResNet(nn.Module): 90 | def __init__(self, block, num_blocks, num_classes=10, features=False): 91 | super(ResNet, self).__init__() 92 | self.num_classes = num_classes 93 | self.in_planes = 64 94 | 95 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 96 | self.bn1 = nn.BatchNorm2d(64) 97 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 98 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 99 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 100 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 101 | self.linear = nn.Linear(512 * block.expansion, num_classes) 102 | self.features = features 103 | 104 | def _make_layer(self, block, planes, num_blocks, stride): 105 | strides = [stride] + [1] * (num_blocks - 1) 106 | layers = [] 107 | for stride in strides: 108 | layers.append(block(self.in_planes, planes, stride)) 109 | self.in_planes = planes * block.expansion 110 | return nn.Sequential(*layers) 111 | 112 | def forward(self, x): 113 | out = F.relu(self.bn1(self.conv1(x))) 114 | out = self.layer1(out) 115 | out = self.layer2(out) 116 | out = self.layer3(out) 117 | out = self.layer4(out) 118 | out = F.avg_pool2d(out, 4) 119 | out = out.view(out.size(0), -1) 120 | if 
self.features: 121 | return out 122 | 123 | else: 124 | final_out = self.linear(out) 125 | return final_out 126 | 127 | 128 | def ResNet18(num_classes=10, features=False): 129 | return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, features=features) 130 | 131 | 132 | def ResNet34(num_classes=10, features=False): 133 | return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, features=features) 134 | 135 | 136 | def ResNet50(num_classes=10, features=False): 137 | return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, features=features) 138 | 139 | 140 | def ResNet101(num_classes=10, features=False): 141 | return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, features=features) 142 | 143 | 144 | def ResNet152(num_classes=10, features=False): 145 | return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, features=features) 146 | -------------------------------------------------------------------------------- /RLSbench/models/clip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from clip import clip 4 | import math 5 | 6 | representations_dims = { 7 | "RN50": 1024, 8 | "RN101": 512, 9 | "RN50x4": 640, 10 | "RN50x16": 768, 11 | "ViT-B/32": 512, 12 | "ViT-B/16": 512, 13 | "ViT-L/14": 768, 14 | } 15 | 16 | 17 | class LinearWrapper(torch.nn.Module): 18 | def __init__(self, in_features, num_classes, initial_weights=None): 19 | super(LinearWrapper, self).__init__() 20 | 21 | self.classification_head = torch.nn.Linear(in_features, num_classes) 22 | 23 | if initial_weights is not None and type(initial_weights) == tuple: 24 | print("tuple.") 25 | w, b = initial_weights 26 | self.classification_head.weight = torch.nn.Parameter(w.clone()) 27 | self.classification_head.bias = torch.nn.Parameter(b.clone()) 28 | else: 29 | if initial_weights is None: 30 | initial_weights = torch.zeros_like(self.classification_head.weight) 31 | torch.nn.init.kaiming_uniform_(initial_weights, a=math.sqrt(5)) 32 | self.classification_head.weight = torch.nn.Parameter( 33 | initial_weights.clone() 34 | ) 35 | # Note: modified. Initial bug in forgetting to zero bias. 36 | self.classification_head.bias = torch.nn.Parameter( 37 | torch.zeros_like(self.classification_head.bias) 38 | ) 39 | 40 | def forward(self, features): 41 | return self.classification_head(features) 42 | 43 | 44 | class ModelWrapper(torch.nn.Module): 45 | def __init__(self, backbone, normalize=True): 46 | super(ModelWrapper, self).__init__() 47 | 48 | self.model, _ = clip.load(backbone) 49 | in_features = self.model.visual.output_dim 50 | self.d_out = in_features 51 | self.normalize = normalize 52 | self.model.visual.float() 53 | 54 | # Note: modified. Get rid of the language part. 
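# Only `encode_image` is used in `forward` below, so the text tower is dead weight here; dropping it frees a sizeable amount of memory for the larger CLIP backbones.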
55 | delattr(self.model, "transformer") 56 | 57 | def forward(self, images): 58 | # with torch.no_grad(): 59 | features = self.model.encode_image(images) 60 | if self.normalize: 61 | features = features / features.norm(dim=-1, keepdim=True) 62 | return features 63 | 64 | 65 | def ClipRN50(num_classes=10): 66 | model = ModelWrapper("RN50") 67 | classifier = LinearWrapper(model.d_out, num_classes) 68 | 69 | return (model, classifier) 70 | 71 | 72 | def ClipRN101(num_classes=10): 73 | model = ModelWrapper("RN101") 74 | classifier = LinearWrapper(model.d_out, num_classes) 75 | 76 | return (model, classifier) 77 | 78 | 79 | def ClipRN50x4(num_classes=10): 80 | model = ModelWrapper("RN50x4") 81 | classifier = LinearWrapper(model.d_out, num_classes) 82 | 83 | return (model, classifier) 84 | 85 | 86 | def ClipRN50x16(num_classes=10): 87 | model = ModelWrapper("RN50x16") 88 | classifier = LinearWrapper(model.d_out, num_classes) 89 | 90 | return (model, classifier) 91 | 92 | 93 | def ClipViTB16(num_classes=10): 94 | model = ModelWrapper("ViT-B/16") 95 | classifier = LinearWrapper(model.d_out, num_classes) 96 | 97 | return (model, classifier) 98 | 99 | 100 | def ClipViTB32(num_classes=10): 101 | model = ModelWrapper("ViT-B/32") 102 | classifier = LinearWrapper(model.d_out, num_classes) 103 | 104 | return (model, classifier) 105 | 106 | 107 | def ClipViTL14(num_classes=10): 108 | model = ModelWrapper("ViT-L/14") 109 | classifier = LinearWrapper(model.d_out, num_classes) 110 | 111 | return (model, classifier) 112 | -------------------------------------------------------------------------------- /RLSbench/models/domain_adversarial_network.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Function 6 | import torch.nn.functional as F 7 | 8 | 9 | class DomainDiscriminator(nn.Sequential): 10 | """ 11 | Adapted from https://github.com/thuml/Transfer-Learning-Library 12 | 13 | Domain discriminator model from 14 | `"Domain-Adversarial Training of Neural Networks" `_ 15 | In the original paper and implementation, we distinguish whether the input features come 16 | from the source domain or the target domain. 17 | 18 | We extended this to work with multiple domains, which is controlled by the n_domains 19 | argument. 20 | 21 | Args: 22 | in_feature (int): dimension of the input feature 23 | n_domains (int): number of domains to discriminate 24 | hidden_size (int): dimension of the hidden features 25 | batch_norm (bool): whether use :class:`~torch.nn.BatchNorm1d`. 26 | Use :class:`~torch.nn.Dropout` if ``batch_norm`` is False. Default: True. 
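Example (hypothetical values): ``DomainDiscriminator(in_feature=2048, n_domains=2)`` builds an MLP with two hidden layers that predicts which domain a feature vector came from.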
27 | Shape: 28 | - Inputs: (minibatch, `in_feature`) 29 | - Outputs: :math:`(minibatch, n_domains)` 30 | """ 31 | 32 | def __init__( 33 | self, in_feature: int, n_domains, hidden_size: int = 1024, batch_norm=True 34 | ): 35 | if batch_norm: 36 | super(DomainDiscriminator, self).__init__( 37 | nn.Linear(in_feature, hidden_size), 38 | nn.BatchNorm1d(hidden_size), 39 | nn.ReLU(), 40 | nn.Linear(hidden_size, hidden_size), 41 | nn.BatchNorm1d(hidden_size), 42 | nn.ReLU(), 43 | nn.Linear(hidden_size, n_domains), 44 | ) 45 | else: 46 | super(DomainDiscriminator, self).__init__( 47 | nn.Linear(in_feature, hidden_size), 48 | nn.ReLU(inplace=True), 49 | nn.Dropout(0.5), 50 | nn.Linear(hidden_size, hidden_size), 51 | nn.ReLU(inplace=True), 52 | nn.Dropout(0.5), 53 | nn.Linear(hidden_size, n_domains), 54 | ) 55 | 56 | def get_parameters_with_lr(self, lr) -> List[Dict]: 57 | return [{"params": self.parameters(), "lr": lr}] 58 | 59 | 60 | class GradientReverseFunction(Function): 61 | """ 62 | Credit: https://github.com/thuml/Transfer-Learning-Library 63 | """ 64 | 65 | @staticmethod 66 | def forward( 67 | ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.0 68 | ) -> torch.Tensor: 69 | ctx.coeff = coeff 70 | output = input * 1.0 71 | return output 72 | 73 | @staticmethod 74 | def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]: 75 | return grad_output.neg() * ctx.coeff, None 76 | 77 | 78 | class GradientReverseLayer(nn.Module): 79 | """ 80 | Credit: https://github.com/thuml/Transfer-Learning-Library 81 | """ 82 | 83 | def __init__(self): 84 | super(GradientReverseLayer, self).__init__() 85 | 86 | def forward(self, input, coeff): 87 | return GradientReverseFunction.apply(input, coeff) 88 | 89 | 90 | class DomainAdversarialNetwork(nn.Module): 91 | def __init__( 92 | self, featurizer, classifier, n_domains, num_classes, bottleneck_dim=256 93 | ): 94 | super().__init__() 95 | self.featurizer = featurizer 96 | self.classifier = classifier 97 | # self.classifier = nn.Linear(bottleneck_dim, num_classes) 98 | # self.bottleneck = nn.Linear(featurizer.d_out, bottleneck_dim) 99 | # self.domain_classifier = DomainDiscriminator(featurizer.d_out, n_domains) 100 | self.domain_classifier = DomainDiscriminator( 101 | featurizer.d_out, n_domains, batch_norm=False 102 | ) 103 | 104 | self.gradient_reverse_layer = GradientReverseLayer() 105 | 106 | def forward(self, input, coeff=1.0, domain_classifier=False): 107 | features = self.featurizer(input) 108 | # features = self.bottleneck(features) 109 | y_pred = self.classifier(features) 110 | if domain_classifier: 111 | features = self.gradient_reverse_layer(features, coeff) 112 | domains_pred = self.domain_classifier(features) 113 | return y_pred, domains_pred 114 | else: 115 | return y_pred 116 | 117 | def get_parameters_with_lr( 118 | self, featurizer_lr, classifier_lr, discriminator_lr 119 | ) -> List[Dict]: 120 | """ 121 | Adapted from https://github.com/thuml/Transfer-Learning-Library 122 | 123 | A parameter list which decides optimization hyper-parameters, 124 | such as the relative learning rate of each layer 125 | """ 126 | # In TLL's implementation, the learning rate of this classifier is set 10 times to that of the 127 | # feature extractor for better accuracy by default. For our implementation, we allow the learning 128 | # rates to be passed in separately for featurizer and classifier. 
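# For example (hypothetical instance name and learning rates), a caller can mimic TLL's 10x rule for the classifier and discriminator heads:
#   params = dann.get_parameters_with_lr(1e-3, 1e-2, 1e-2)
#   optimizer = torch.optim.SGD(params, lr=1e-3, momentum=0.9)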
129 | params = [ 130 | {"params": self.featurizer.parameters(), "lr": featurizer_lr}, 131 | # {"params": self.bottleneck.parameters(), "lr": classifier_lr}, 132 | {"params": self.classifier.parameters(), "lr": classifier_lr}, 133 | ] 134 | return params + self.domain_classifier.get_parameters_with_lr(discriminator_lr) 135 | 136 | 137 | class classifier_deep(nn.Module): 138 | def __init__(self, num_classes, inc=4096, temp=0.05): 139 | super(classifier_deep, self).__init__() 140 | self.fc1 = nn.Linear(inc, 512) 141 | self.fc2 = nn.Linear(512, num_classes, bias=False) 142 | self.gradient_reverse_layer = GradientReverseLayer() 143 | self.temp = temp 144 | 145 | def forward(self, x, reverse=False, eta=0.1): 146 | x = self.fc1(x) 147 | if reverse: 148 | x = self.gradient_reverse_layer(x, eta) 149 | 150 | # x = F.normalize(x) 151 | x_out = self.fc2(x) # / self.temp 152 | return x_out 153 | 154 | 155 | class COALNetwork(nn.Module): 156 | def __init__(self, featurizer, num_classes): 157 | super().__init__() 158 | self.featurizer = featurizer 159 | self.classifier = classifier_deep(num_classes=num_classes, inc=featurizer.d_out) 160 | 161 | def forward(self, input, coeff=1.0, reverse=False): 162 | features = self.featurizer(input) 163 | 164 | y_pred = self.classifier(features, reverse=reverse, eta=coeff) 165 | return y_pred 166 | 167 | def get_parameters_with_lr(self, featurizer_lr, classifier_lr) -> List[Dict]: 168 | """ 169 | Adapted from https://github.com/thuml/Transfer-Learning-Library 170 | 171 | A parameter list which decides optimization hyper-parameters, 172 | such as the relative learning rate of each layer 173 | """ 174 | # In TLL's implementation, the learning rate of this classifier is set 10 times to that of the 175 | # feature extractor for better accuracy by default. For our implementation, we allow the learning 176 | # rates to be passed in separately for featurizer and classifier. 
177 | params = [ 178 | {"params": self.featurizer.parameters(), "lr": featurizer_lr}, 179 | {"params": self.classifier.parameters(), "lr": classifier_lr}, 180 | ] 181 | return params 182 | 183 | 184 | class ConditionalDomainAdversarialNetwork(nn.Module): 185 | def __init__( 186 | self, featurizer, classifier, n_domains, num_classes, bottleneck_dim=256 187 | ): 188 | super().__init__() 189 | self.featurizer = featurizer 190 | self.classifier = nn.Linear(bottleneck_dim, num_classes) 191 | self.bottleneck = nn.Linear(featurizer.d_out, bottleneck_dim) 192 | self.domain_classifier = DomainDiscriminator( 193 | bottleneck_dim * num_classes, n_domains, batch_norm=False 194 | ) 195 | self.gradient_reverse_layer = GradientReverseLayer() 196 | 197 | def forward(self, input, coeff=1.0, domain_classifier=False): 198 | features = self.featurizer(input) 199 | features = self.bottleneck(features) 200 | 201 | y_pred = self.classifier(features) 202 | 203 | if domain_classifier: 204 | softmax_out = F.softmax(y_pred, dim=1).detach() 205 | op_out = torch.bmm(softmax_out.unsqueeze(2), features.unsqueeze(1)) 206 | op_out = op_out.view(-1, softmax_out.size(1) * features.size(1)) 207 | 208 | op_out = self.gradient_reverse_layer(op_out, coeff) 209 | # op_out.register_hook(grl_hook(coeff)) 210 | # import pdb; pdb.set_trace() 211 | domains_pred = self.domain_classifier(op_out) 212 | return y_pred, domains_pred 213 | else: 214 | return y_pred 215 | 216 | def get_parameters_with_lr( 217 | self, featurizer_lr, classifier_lr, discriminator_lr 218 | ) -> List[Dict]: 219 | """ 220 | Adapted from https://github.com/thuml/Transfer-Learning-Library 221 | 222 | A parameter list which decides optimization hyper-parameters, 223 | such as the relative learning rate of each layer 224 | """ 225 | # In TLL's implementation, the learning rate of this classifier is set 10 times to that of the 226 | # feature extractor for better accuracy by default. For our implementation, we allow the learning 227 | # rates to be passed in separately for featurizer and classifier. 
228 | params = [ 229 | {"params": self.featurizer.parameters(), "lr": featurizer_lr}, 230 | {"params": self.bottleneck.parameters(), "lr": classifier_lr}, 231 | {"params": self.classifier.parameters(), "lr": classifier_lr}, 232 | ] 233 | return params + self.domain_classifier.get_parameters_with_lr(discriminator_lr) 234 | -------------------------------------------------------------------------------- /RLSbench/models/initializer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import antialiased_cnns 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torchvision 9 | from RLSbench.utils import load 10 | 11 | logger = logging.getLogger("label_shift") 12 | 13 | 14 | class Identity(nn.Module): 15 | """An identity layer""" 16 | 17 | def __init__(self, d): 18 | super().__init__() 19 | self.in_features = d 20 | self.out_features = d 21 | 22 | def forward(self, x): 23 | return x 24 | 25 | 26 | def initialize_model( 27 | model_name, 28 | dataset_name, 29 | num_classes, 30 | featurize=False, 31 | in_features=None, 32 | pretrained=False, 33 | pretrained_path=None, 34 | data_dir=None, 35 | ): 36 | """ 37 | Initializes models according to the config 38 | Args: 39 | - model_name: name of the model 40 | - dataset_name: name of the dataset 41 | - num_classes (int): number of classes in the dataset 42 | - featurize (bool): whether to return a model or a (featurizer, classifier) pair that constitutes a model. 43 | Output: 44 | If featurize=True: 45 | - featurizer: a model that outputs feature Tensors of shape (batch_size, ..., feature dimensionality) 46 | - classifier: a model that takes in feature Tensors and outputs predictions. In most cases, this is a linear layer.
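e.g. (hypothetical call): ``featurizer, classifier = initialize_model("resnet18", "cifar10", num_classes=10, featurize=True)``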
47 | 48 | If featurize=False: 49 | - model: a model that is equivalent to nn.Sequential(featurizer, classifier) 50 | """ 51 | 52 | if ( 53 | model_name 54 | in ("resnet18", "resnet34", "resnet50", "resnet101", "efficientnet_b0") 55 | and "cifar" in dataset_name 56 | ): 57 | from RLSbench.models.cifar_efficientnet import EfficientNetB0 58 | from RLSbench.models.cifar_resnet import ResNet18, ResNet34, ResNet50, ResNet101 59 | 60 | arch = { 61 | "resnet18": ResNet18, 62 | "resnet34": ResNet34, 63 | "resnet50": ResNet50, 64 | "resnet101": ResNet101, 65 | "efficientnet_b0": EfficientNetB0, 66 | } 67 | 68 | featurizer = arch[model_name](num_classes=1000, features=True) 69 | 70 | if pretrained: 71 | assert ( 72 | pretrained_path is not None 73 | ), "Must provide pretrained_path if pretrained=True" 74 | load(featurizer, pretrained_path) 75 | 76 | d_out = getattr(featurizer, "linear").in_features 77 | featurizer.d_out = d_out 78 | classifier = nn.Linear(d_out, num_classes) 79 | model = (featurizer, classifier) 80 | 81 | if not featurize: 82 | model = nn.Sequential(*model) 83 | 84 | elif model_name in ( 85 | "resnet18", 86 | "resnet34", 87 | "resnet50", 88 | "resnet101", 89 | "densenet121", 90 | "efficientnet_b0", 91 | ): 92 | featurizer = initialize_torchvision_model( 93 | name=model_name, d_out=None, pretrained=pretrained 94 | ) 95 | 96 | classifier = nn.Linear(featurizer.d_out, num_classes) 97 | model = (featurizer, classifier) 98 | 99 | # if pretrained_type in ('swav'): 100 | # load(model[0], pretrained_path) 101 | 102 | if not featurize: 103 | model = nn.Sequential(*model) 104 | 105 | elif model_name in ( 106 | "cliprn50", 107 | "cliprn101", 108 | "clipvitb16", 109 | "clipvitb32", 110 | "clipvitl14", 111 | ): 112 | from RLSbench.models.clip import ( 113 | ClipRN50, 114 | ClipRN101, 115 | ClipViTB16, 116 | ClipViTB32, 117 | ClipViTL14, 118 | ) 119 | 120 | arch = { 121 | "cliprn50": ClipRN50, 122 | "cliprn101": ClipRN101, 123 | "clipvitb16": ClipViTB16, 124 | "clipvitb32": ClipViTB32, 125 | "clipvitl14": ClipViTL14, 126 | } 127 | 128 | model = arch[model_name](num_classes=num_classes) 129 | 130 | if not featurize: 131 | model = nn.Sequential(*model) 132 | 133 | elif model_name in ("mimic_network",): 134 | from RLSbench.models.mimic_model import Transformer 135 | 136 | featurizer = Transformer( 137 | data_dir, embedding_size=128, dropout=0.5, layers=2, heads=2 138 | ) 139 | classifier = nn.Linear(featurizer.d_out, 2) 140 | model = (featurizer, classifier) 141 | 142 | if not featurize: 143 | model = nn.Sequential(*model) 144 | 145 | elif model_name in ("distilbert-base-uncased",): 146 | from RLSbench.models.transformers import initialize_bert_based_model 147 | 148 | featurizer = initialize_bert_based_model(model_name, num_classes) 149 | 150 | classifier = nn.Linear(featurizer.d_out, num_classes) 151 | 152 | model = (featurizer, classifier) 153 | 154 | if not featurize: 155 | model = nn.Sequential(*model) 156 | 157 | elif model_name in ("MLP",): 158 | featurizer = nn.Sequential( 159 | nn.Flatten(), 160 | nn.Linear(10, 100, bias=True), 161 | nn.ReLU(), 162 | nn.Linear(100, 100, bias=True), 163 | nn.ReLU(), 164 | ) 165 | 166 | featurizer.d_out = 100 167 | classifier = nn.Linear(100, num_classes) 168 | 169 | model = (featurizer, classifier) 170 | 171 | if not featurize: 172 | model = nn.Sequential(*model) 173 | 174 | elif model_name == "logistic_regression": 175 | assert not featurize, "Featurizer not supported for logistic regression" 176 | model = nn.Linear(in_features=in_features,
out_features=num_classes) 177 | 178 | else: 179 | raise ValueError(f"Model: {model_name} not recognized.") 180 | 181 | return model 182 | 183 | 184 | def initialize_torchvision_model(name, d_out, pretrained=True): 185 | # get constructor and last layer names 186 | if name == "wideresnet50": 187 | constructor_name = "wide_resnet50_2" 188 | last_layer_name = "fc" 189 | elif name == "densenet121": 190 | constructor_name = name 191 | last_layer_name = "classifier" 192 | elif name in ("resnet18", "resnet34", "resnet50", "resnet101"): 193 | constructor_name = name 194 | last_layer_name = "fc" 195 | elif name in ("efficientnet_b0"): 196 | constructor_name = name 197 | last_layer_name = "classifier" 198 | else: 199 | raise ValueError(f"Torchvision model {name} not recognized") 200 | # construct the default model, which has the default last layer 201 | constructor = getattr(antialiased_cnns, constructor_name) 202 | model = constructor(pretrained=pretrained) 203 | # adjust the last layer 204 | d_features = getattr(model, last_layer_name).in_features 205 | if d_out is None: # want to initialize a featurizer model 206 | last_layer = Identity(d_features) 207 | model.d_out = d_features 208 | else: # want to initialize a classifier for a particular num_classes 209 | last_layer = nn.Linear(d_features, d_out) 210 | model.d_out = d_out 211 | setattr(model, last_layer_name, last_layer) 212 | 213 | return model 214 | -------------------------------------------------------------------------------- /RLSbench/models/mdd_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from RLSbench.models.domain_adversarial_network import ( 5 | GradientReverseFunction, 6 | GradientReverseLayer, 7 | ) 8 | 9 | 10 | class MDDNet(nn.Module): 11 | def __init__( 12 | self, 13 | featurizer, 14 | class_num, 15 | bottleneck_dim=1024, 16 | classifier_width=1024, 17 | classifier_depth=2, 18 | ): 19 | super().__init__() 20 | 21 | self.class_num = class_num 22 | self.bottleneck_dim = bottleneck_dim 23 | self.freeze_backbone = False 24 | self.normalize_features = False 25 | 26 | self.base_network = featurizer 27 | 28 | self.use_bottleneck = True 29 | self.create_bottleneck_layer(use_dropout=True) 30 | 31 | self.create_f_and_fhat_classifiers( 32 | bottleneck_dim, 33 | classifier_width, 34 | class_num, 35 | classifier_depth, 36 | use_dropout=True, 37 | ) 38 | 39 | self.softmax = nn.Softmax(dim=1) 40 | 41 | # collect parameters 42 | self.parameter_list = [ 43 | {"params": self.base_network.parameters(), "lr_scale": 0.1}, 44 | {"params": self.bottleneck_layer.parameters(), "lr_scale": 1}, 45 | {"params": self.classifier_layer.parameters(), "lr_scale": 1}, 46 | {"params": self.classifier_layer_2.parameters(), "lr_scale": 1}, 47 | ] 48 | 49 | def create_bottleneck_layer(self, use_dropout): 50 | bottleneck_layer_list = [ 51 | nn.Linear(self.base_network.output_num(), self.bottleneck_dim), 52 | nn.BatchNorm1d(self.bottleneck_dim), 53 | nn.ReLU(), 54 | ] 55 | if use_dropout is True: 56 | bottleneck_layer_list.append(nn.Dropout(0.5)) 57 | 58 | self.bottleneck_layer = nn.Sequential(*bottleneck_layer_list) 59 | 60 | # init 61 | self.bottleneck_layer[0].weight.data.normal_(0, 0.005) 62 | self.bottleneck_layer[0].bias.data.fill_(0.1) 63 | 64 | def create_f_and_fhat_classifiers( 65 | self, 66 | bottleneck_dim, 67 | classifier_width, 68 | class_num, 69 | classifier_depth, 70 | use_dropout=True, 71 | ): 72 | self.classifier_layer = 
self.create_classifier( 73 | bottleneck_dim, 74 | classifier_width, 75 | class_num, 76 | classifier_depth, 77 | use_dropout=use_dropout, 78 | ) 79 | self.classifier_layer_2 = self.create_classifier( 80 | bottleneck_dim, 81 | classifier_width, 82 | class_num, 83 | classifier_depth, 84 | use_dropout=use_dropout, 85 | ) 86 | self.initialize_classifiers() 87 | 88 | def create_classifier( 89 | self, bottleneck_dim, width, class_num, depth=2, use_dropout=True 90 | ): 91 | layer_list = [] 92 | input_size = bottleneck_dim 93 | for ith_layer in range(depth - 1): 94 | layer_list.append(nn.Linear(input_size, width)) 95 | 96 | layer_list.append(nn.ReLU()) 97 | 98 | if use_dropout is True: 99 | layer_list.append(nn.Dropout(0.5)) 100 | 101 | input_size = width 102 | 103 | layer_list.append(nn.Linear(width, class_num)) 104 | classifier = nn.Sequential(*layer_list) 105 | return classifier 106 | 107 | def forward(self, inputs): 108 | features = self.feature_forward(inputs) 109 | outputs = self.classifier_layer(features) 110 | softmax_outputs = self.softmax(outputs) 111 | 112 | # gradient reversal (applied via the autograd Function, not the nn.Module) fuses the minimax problem into one loss function 113 | features_adv = GradientReverseFunction.apply(features, 1.0) 114 | outputs_adv = self.classifier_layer_2(features_adv) 115 | 116 | return features, outputs, softmax_outputs, outputs_adv 117 | 118 | def feature_forward(self, inputs): 119 | if self.freeze_backbone is True: 120 | with torch.no_grad(): 121 | features = self.base_network(inputs) 122 | else: 123 | features = self.base_network(inputs) 124 | 125 | if self.use_bottleneck: 126 | features = self.bottleneck_layer(features) 127 | 128 | if self.normalize_features is True: 129 | features_norm = torch.norm(features, p=2, dim=1).detach() 130 | features = features / features_norm.unsqueeze(1) 131 | return features 132 | 133 | def logits_forward(self, inputs): 134 | features = self.feature_forward(inputs) 135 | logits = self.classifier_layer(features) 136 | return logits 137 | 138 | def initialize_classifiers(self): 139 | self.xavier_initialization(self.classifier_layer) 140 | self.xavier_initialization(self.classifier_layer_2) 141 | 142 | def xavier_initialization(self, layers): 143 | for layer in layers: 144 | if type(layer) == nn.Linear: 145 | torch.nn.init.xavier_normal_(layer.weight) 146 | layer.bias.data.fill_(0.0) 147 | 148 | def initialize_bottleneck(self): 149 | for b_layer in self.bottleneck_layer: 150 | if type(b_layer) == nn.Linear: 151 | torch.nn.init.xavier_normal_(b_layer.weight) 152 | b_layer.bias.data.fill_(0.0) 153 | 154 | def get_parameter_list(self): 155 | c_net_params = self.parameter_list 156 | return c_net_params 157 | 158 | 159 | def get_mdd_loss(outputs, outputs_adv, labels_source, class_criterion, srcweight): 160 | # f(x) 161 | outputs_src = outputs.narrow(0, 0, labels_source.size(0)) 162 | label_preds_src = outputs_src.max(1)[1] 163 | outputs_tgt = outputs.narrow( 164 | 0, labels_source.size(0), outputs.size(0) - labels_source.size(0) 165 | ) 166 | probs_tgt = F.softmax(outputs_tgt, dim=1) 167 | # f'(x) 168 | outputs_adv_src = outputs_adv.narrow(0, 0, labels_source.size(0)) 169 | outputs_adv_tgt = outputs_adv.narrow( 170 | 0, labels_source.size(0), outputs.size(0) - labels_source.size(0) 171 | ) 172 | 173 | # classification loss on source domain 174 | # if self.args.mask_classifier is True: 175 | outputs_src_masked, _, _ = mask_clf_outputs( 176 | outputs_src, outputs_adv_src, outputs_adv_tgt, labels_source 177 | ) 178 | classifier_loss = class_criterion(outputs_src_masked,
labels_source) 179 | 180 | outputs_src, outputs_adv_src, outputs_adv_tgt = mask_clf_outputs( 181 | outputs_src, outputs_adv_src, outputs_adv_tgt, labels_source 182 | ) 183 | 184 | # use $f$ as the target for $f'$ 185 | target_adv = outputs.max(1)[1] # categorical labels from $f$ 186 | target_adv_src = target_adv.narrow(0, 0, labels_source.size(0)) 187 | target_adv_tgt = target_adv.narrow( 188 | 0, labels_source.size(0), outputs.size(0) - labels_source.size(0) 189 | ) 190 | 191 | # source classification acc 192 | classifier_acc = ( 193 | label_preds_src == labels_source 194 | ).sum().float() / labels_source.size(0) 195 | 196 | # adversarial loss for source domain 197 | classifier_loss_adv_src = class_criterion(outputs_adv_src, target_adv_src) 198 | 199 | # adversarial loss for target domain, opposite sign with source domain 200 | prob_adv = 1 - F.softmax(outputs_adv_tgt, dim=1) 201 | prob_adv = prob_adv.clamp(min=1e-7) 202 | logloss_tgt = torch.log(prob_adv) 203 | classifier_loss_adv_tgt = F.nll_loss(logloss_tgt, target_adv_tgt) 204 | 205 | # total adversarial loss 206 | adv_loss = srcweight * classifier_loss_adv_src + classifier_loss_adv_tgt 207 | 208 | # loss for explicit alignment 209 | 210 | total_loss = classifier_loss + adv_loss 211 | 212 | return total_loss 213 | 214 | 215 | def mask_clf_outputs(outputs_src, outputs_adv_src, outputs_adv_tgt, labels_source): 216 | mask = torch.zeros(outputs_src.shape[1]) 217 | mask[labels_source.unique()] = 1 218 | mask = mask.repeat((outputs_src.shape[0], 1)).cuda() 219 | outputs_src = outputs_src * mask 220 | outputs_adv_src = outputs_adv_src * mask 221 | outputs_adv_tgt = outputs_adv_tgt * mask 222 | return outputs_src, outputs_adv_src, outputs_adv_tgt 223 | -------------------------------------------------------------------------------- /RLSbench/models/mimic_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | 4 | import torch 5 | import torch.nn as nn 6 | import logging 7 | 8 | logger = logging.getLogger("label_shift") 9 | 10 | from RLSbench.models.mimic_tokenizer import MIMICTokenizer 11 | 12 | 13 | class Attention(nn.Module): 14 | def forward(self, query, key, value, mask, dropout=None): 15 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1)) 16 | if mask is not None: 17 | scores = scores.masked_fill(mask == 0, -1e9) 18 | p_attn = torch.softmax(scores, dim=-1) 19 | p_attn = p_attn.masked_fill(mask == 0, 0) 20 | if dropout is not None: 21 | p_attn = dropout(p_attn) 22 | return torch.matmul(p_attn, value), p_attn 23 | 24 | 25 | class MultiHeadedAttention(nn.Module): 26 | def __init__(self, h, d_model, dropout=0.1): 27 | super(MultiHeadedAttention, self).__init__() 28 | assert d_model % h == 0 29 | 30 | # We assume d_v always equals d_k 31 | self.d_k = d_model // h 32 | self.h = h 33 | 34 | self.linear_layers = nn.ModuleList( 35 | [nn.Linear(d_model, d_model, bias=False) for _ in range(3)] 36 | ) 37 | self.output_linear = nn.Linear(d_model, d_model, bias=False) 38 | self.attention = Attention() 39 | 40 | self.dropout = nn.Dropout(p=dropout) 41 | 42 | def forward(self, query, key, value, mask): 43 | """ 44 | :param query, key, value: [batch_size, seq_len, d_model] 45 | :param mask: [batch_size, seq_len, seq_len] 46 | :return: [batch_size, seq_len, d_model] 47 | """ 48 | 49 | batch_size = query.size(0) 50 | 51 | # 1) Do all the linear projections in batch from d_model => h x d_k 52 | query, key, value = [ 53 | l(x).view(batch_size, -1, self.h, 
self.d_k).transpose(1, 2) 54 | for l, x in zip(self.linear_layers, (query, key, value)) 55 | ] 56 | 57 | # 2) Apply attention on all the projected vectors in batch. 58 | x, attn = self.attention( 59 | query, key, value, mask=mask.unsqueeze(1), dropout=self.dropout 60 | ) 61 | 62 | # 3) "Concat" using a view and apply a final linear. 63 | x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k) 64 | 65 | return self.output_linear(x) 66 | 67 | 68 | class PositionwiseFeedForward(nn.Module): 69 | def __init__(self, d_model, d_ff, dropout=0.1): 70 | super(PositionwiseFeedForward, self).__init__() 71 | self.w_1 = nn.Linear(d_model, d_ff) 72 | self.w_2 = nn.Linear(d_ff, d_model) 73 | self.dropout = nn.Dropout(dropout) 74 | self.activation = nn.GELU() 75 | 76 | def forward(self, x, mask): 77 | x = self.w_2(self.dropout(self.activation(self.w_1(x)))) 78 | mask = mask.sum(dim=-1) > 0 79 | x[~mask] = 0 80 | return x 81 | 82 | 83 | class SublayerConnection(nn.Module): 84 | """ 85 | A residual connection followed by a layer norm. 86 | Note for code simplicity the norm is first as opposed to last. 87 | """ 88 | 89 | def __init__(self, size, dropout): 90 | super(SublayerConnection, self).__init__() 91 | self.norm = nn.LayerNorm(size) 92 | self.dropout = nn.Dropout(dropout) 93 | 94 | def forward(self, x, sublayer): 95 | """Apply residual connection to any sublayer with the same size.""" 96 | return x + self.dropout(sublayer(self.norm(x))) 97 | 98 | 99 | class TransformerBlock(nn.Module): 100 | """ 101 | Transformer Block = MultiHead Attention + Feed Forward with sublayer connection 102 | """ 103 | 104 | def __init__(self, hidden, attn_heads, dropout): 105 | """ 106 | :param hidden: hidden size of transformer 107 | :param attn_heads: head sizes of multi-head attention 108 | :param dropout: dropout rate 109 | """ 110 | 111 | super(TransformerBlock, self).__init__() 112 | self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden) 113 | self.feed_forward = PositionwiseFeedForward( 114 | d_model=hidden, d_ff=4 * hidden, dropout=dropout 115 | ) 116 | self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout) 117 | self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout) 118 | self.dropout = nn.Dropout(p=dropout) 119 | logger.info( 120 | f"TransformerBlock added with hid-{hidden}, head-{attn_heads}, in_hid-{2 * hidden}, drop-{dropout}" 121 | ) 122 | 123 | def forward(self, x, mask): 124 | """ 125 | :param x: [batch_size, seq_len, hidden] 126 | :param mask: [batch_size, seq_len, seq_len] 127 | :return: batch_size, seq_len, hidden] 128 | """ 129 | 130 | x = self.input_sublayer(x, lambda _x: self.attention(_x, _x, _x, mask=mask)) 131 | x = self.output_sublayer(x, lambda _x: self.feed_forward(_x, mask=mask)) 132 | return self.dropout(x) 133 | 134 | 135 | class Transformer(nn.Module): 136 | def __init__( 137 | self, 138 | data_dir, 139 | embedding_size: int, 140 | dropout: float, 141 | layers: int, 142 | heads: int, 143 | device="cpu", 144 | ): 145 | super(Transformer, self).__init__() 146 | self.tokenizer = MIMICTokenizer(data_dir) 147 | self.embedding_size = embedding_size 148 | self.dropout = dropout 149 | self.layers = layers 150 | self.heads = heads 151 | self.device = device 152 | 153 | # embedding 154 | self.code_embedding = nn.Embedding( 155 | self.tokenizer.get_code_vocabs_size(), embedding_size, padding_idx=0 156 | ) 157 | self.type_embedding = nn.Embedding( 158 | self.tokenizer.get_type_vocabs_size(), embedding_size, padding_idx=0 159 | ) 160 | 161 | # 
encoder 162 | self.transformer = nn.ModuleList( 163 | [TransformerBlock(embedding_size, heads, dropout) for _ in range(layers)] 164 | ) 165 | 166 | # binary classifier 167 | self.activation = nn.Sigmoid() 168 | 169 | self.d_out = embedding_size 170 | 171 | def forward(self, x): 172 | codes, types = x[0], x[1] 173 | codes, types = self.tokenizer(codes, types, padding=True, prefix="") 174 | codes = codes.cuda() 175 | types = types.cuda() 176 | 177 | """ embedding """ 178 | # [# admissions, # batch_codes, embedding_size] 179 | codes_emb = self.code_embedding(codes) 180 | types_emb = self.type_embedding(types) 181 | emb = codes_emb + types_emb 182 | 183 | """ transformer """ 184 | mask = codes != 0 185 | mask = torch.einsum("ab,ac->abc", mask, mask) 186 | for transformer in self.transformer: 187 | x = transformer(emb, mask) # [# admissions, # batch_codes, embedding_size] 188 | 189 | cls_emb = x[:, 0, :] 190 | # logits = self.fc(cls_emb) 191 | # logits = logits.squeeze(-1) 192 | return cls_emb 193 | 194 | def get_cls_embed(self, x): 195 | codes, types = x[0], x[1] 196 | codes, types = self.tokenizer(codes, types, padding=True, prefix="") 197 | codes = codes.cuda() 198 | types = types.cuda() 199 | 200 | """ embedding """ 201 | # [# admissions, # batch_codes, embedding_size] 202 | codes_emb = self.code_embedding(codes) 203 | types_emb = self.type_embedding(types) 204 | emb = codes_emb + types_emb 205 | 206 | """ transformer """ 207 | mask = codes != 0 208 | mask = torch.einsum("ab,ac->abc", mask, mask) 209 | for transformer in self.transformer: 210 | x = transformer(emb, mask) # [# admissions, # batch_codes, embedding_size] 211 | 212 | cls_embed = x[:, 0, :] # get CLS embedding 213 | return cls_embed 214 | -------------------------------------------------------------------------------- /RLSbench/models/model_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from tqdm import tqdm 9 | 10 | logger = logging.getLogger("label_shift") 11 | 12 | 13 | def PCA_whitener(centered): 14 | n, p = centered.shape 15 | cov = torch.matmul(centered.T, centered) / (n - 1) 16 | # torch.symeig much less stable since the small eigenvalues are very close together, often returns negatives. 17 | U, S, _ = torch.svd(cov) 18 | 19 | rho = 0.01 20 | return U, (1.0 - rho) * S + rho * S.mean() 21 | 22 | 23 | def linear_probe(model, dataloader, device, lambda_param=1e-6, progress_bar=True): 24 | logger.info("Linear probing ... 
") 25 | 26 | model[0].to(device) 27 | model[0].eval() 28 | data_features = [] 29 | data_labels = [] 30 | iterator = dataloader 31 | if progress_bar: 32 | iterator = tqdm(iterator) 33 | # import pdb; pdb.set_trace() 34 | with torch.no_grad(): 35 | for batch in iterator: 36 | # import pdb; pdb.set_trace() 37 | 38 | x, y = batch[:2] 39 | x = x.to(device) 40 | y = y.to(device) 41 | features = model[0](x) 42 | data_features.append(features.cpu().numpy()) 43 | data_labels.append(y.cpu().numpy()) 44 | 45 | data_features = torch.tensor(np.concatenate(data_features, axis=0)) 46 | data_labels = torch.tensor(np.concatenate(data_labels, axis=0)) 47 | 48 | optimizer = torch.optim.LBFGS( 49 | model[1].parameters(), history_size=100, max_iter=100, lr=0.1 50 | ) 51 | 52 | model[0].to(torch.device("cpu")) 53 | 54 | model[1].to(device) 55 | model[1].train() 56 | 57 | new_loss = -200.0 58 | loss = -100.0 59 | 60 | iteration = 0 61 | logger.info("Got features, now training the linear layer ...") 62 | while np.abs(new_loss - loss) > 1e-6: 63 | logger.info(f"Linear probing iteration {iteration+1} ... ") 64 | loss = new_loss 65 | data_features, data_labels = data_features.to(device), data_labels.to(device) 66 | 67 | def closure_fn(): 68 | out = model[1](data_features) 69 | 70 | l2_norm = sum(p.pow(2.0).sum() for p in model[1].parameters()) 71 | loss = F.cross_entropy(out, data_labels) + l2_norm * lambda_param 72 | 73 | optimizer.zero_grad() 74 | loss.backward() 75 | 76 | return loss 77 | 78 | optimizer.step(closure_fn) 79 | iteration = iteration + 1 80 | with torch.no_grad(): 81 | new_loss = ( 82 | F.cross_entropy(model[1](data_features), data_labels).cpu().numpy() 83 | ) 84 | 85 | model[1].to(torch.device("cpu")) 86 | 87 | return model 88 | 89 | 90 | def train_CORAL(model, dataloader, im_weights, device, lambda_param=1e-6): 91 | logger.info("Getting features ... ") 92 | 93 | model.featurizer.eval() 94 | data_features = [] 95 | data_labels = [] 96 | with torch.no_grad(): 97 | for batch in dataloader: 98 | x, y = batch[:2] 99 | x = x.to(device) 100 | y = y.to(device) 101 | features = model.featurizer(x) 102 | data_features.append(features.cpu().numpy()) 103 | data_labels.append(y.cpu().numpy()) 104 | 105 | data_features = torch.tensor(np.concatenate(data_features, axis=0)) 106 | data_labels = torch.tensor(np.concatenate(data_labels, axis=0)) 107 | 108 | # centered = data_features - torch.mean(data_features, 0, True) 109 | 110 | U, S = PCA_whitener(data_features) 111 | W = U @ torch.diag_embed(torch.reciprocal(torch.sqrt(S))) @ U.T 112 | # features = torch.mm(data_features, W) 113 | W = W.to(device) 114 | 115 | logger.info("Linear probing ... ") 116 | 117 | # optimizer = torch.optim.LBFGS(model.classifier.parameters(), history_size=100, max_iter=100, lr=0.1) 118 | optimizer = torch.optim.SGD( 119 | model.classifier.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0001 120 | ) 121 | 122 | model.classifier.train() 123 | 124 | for batch in dataloader: 125 | x, y = batch[:2] 126 | data_features, data_labels = x.to(device), y.to(device) 127 | 128 | with torch.no_grad(): 129 | data_features = model.featurizer(data_features) 130 | data_features = torch.mm(data_features, W) 131 | 132 | out = model.classifier(data_features) 133 | 134 | loss = F.cross_entropy(out, data_labels) 135 | 136 | optimizer.zero_grad() 137 | loss.backward() 138 | 139 | optimizer.step() 140 | 141 | return model 142 | 143 | 144 | def test_CORAL_params(model, dataloader, device): 145 | logger.info("Getting mean and variance of target ... 
") 146 | 147 | model.featurizer.eval() 148 | data_features = [] 149 | data_labels = [] 150 | with torch.no_grad(): 151 | for batch in dataloader: 152 | x, y = batch[:2] 153 | x = x.to(device) 154 | # y = y.to(device) 155 | features = model.featurizer(x) 156 | data_features.append(features.cpu().numpy()) 157 | # data_labels.append(y.cpu().numpy()) 158 | 159 | data_features = torch.tensor(np.concatenate(data_features, axis=0)) 160 | # data_labels = torch.tensor(np.concatenate(data_labels, axis=0)) 161 | 162 | # mean = torch.mean(data_features, 0, True) 163 | # centered = data_features - mean 164 | 165 | U, S = PCA_whitener(data_features) 166 | cov_inv = U @ torch.diag_embed(torch.reciprocal(torch.sqrt(S))) @ U.T 167 | 168 | logger.info("Done.") 169 | 170 | return cov_inv 171 | 172 | 173 | def configure_model(model): 174 | """Configure model for use with tent.""" 175 | # train mode, because tent optimizes the model to minimize entropy 176 | model.train() 177 | # disable grad, to (re-)enable only what tent updates 178 | model.requires_grad_(False) 179 | # configure norm for tent updates: enable grad + force batch statisics 180 | 181 | # TODO: Check what if we combine this with BN adapt 182 | for m in model.modules(): 183 | if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.LayerNorm): 184 | m.requires_grad_(True) 185 | # force use of batch stats in train and eval modes 186 | # m.track_running_stats = False 187 | # m.running_mean = None 188 | # m.running_var = None 189 | return model 190 | 191 | 192 | def collect_params(model): 193 | """Collect the affine scale + shift parameters from batch norms. 194 | Walk the model's modules and collect all batch normalization parameters. 195 | Return the parameters and their names. 196 | Note: other choices of parameterization are possible! 
197 | """ 198 | params = [] 199 | names = [] 200 | for nm, m in model.named_modules(): 201 | if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.LayerNorm): 202 | for np, p in m.named_parameters(): 203 | if np in ["weight", "bias"]: # weight is scale, bias is shift 204 | params.append(p) 205 | names.append(f"{nm}.{np}") 206 | return params, names 207 | -------------------------------------------------------------------------------- /RLSbench/models/transformers.py: -------------------------------------------------------------------------------- 1 | from transformers import DistilBertForSequenceClassification, DistilBertModel 2 | import torch.nn as nn 3 | 4 | 5 | class DistilBertClassifier(DistilBertForSequenceClassification): 6 | def __init__(self, config): 7 | super().__init__(config) 8 | 9 | def __call__(self, x): 10 | input_ids = x[:, :, 0] 11 | attention_mask = x[:, :, 1] 12 | outputs = super().__call__( 13 | input_ids=input_ids, 14 | attention_mask=attention_mask, 15 | )[0] 16 | return outputs 17 | 18 | 19 | class Identity(nn.Module): 20 | """An identity layer""" 21 | 22 | def __init__(self, d): 23 | super().__init__() 24 | self.in_features = d 25 | self.out_features = d 26 | 27 | def forward(self, x): 28 | return x 29 | 30 | 31 | class DistilBertFeaturizer(DistilBertModel): 32 | def __init__(self, config): 33 | super().__init__(config) 34 | self.d_out = config.hidden_size 35 | 36 | def __call__(self, x): 37 | input_ids = x[:, :, 0] 38 | attention_mask = x[:, :, 1] 39 | hidden_state = super().__call__( 40 | input_ids=input_ids, 41 | attention_mask=attention_mask, 42 | )[0] 43 | pooled_output = hidden_state[:, 0] 44 | return pooled_output 45 | 46 | 47 | def initialize_bert_based_model(net, num_classes): 48 | if net == "distilbert-base-uncased": 49 | model = DistilBertClassifier.from_pretrained(net, num_labels=num_classes) 50 | d_features = getattr(model, "classifier").in_features 51 | 52 | model.classifier = Identity(d_features) 53 | model.d_out = d_features 54 | else: 55 | raise ValueError(f"Model: {net} not recognized.") 56 | return model 57 | -------------------------------------------------------------------------------- /RLSbench/optimizer.py: -------------------------------------------------------------------------------- 1 | from torch.optim import SGD, Adam, AdamW 2 | 3 | 4 | def initialize_optimizer(config, model): 5 | # initialize optimizers 6 | if config.optimizer == "SGD": 7 | params = filter(lambda p: p.requires_grad, model.parameters()) 8 | optimizer = SGD( 9 | params, 10 | lr=config.lr, 11 | weight_decay=config.weight_decay, 12 | **config.optimizer_kwargs, 13 | ) 14 | elif config.optimizer == "Adam": 15 | params = filter(lambda p: p.requires_grad, model.parameters()) 16 | optimizer = Adam( 17 | params, 18 | lr=config.lr, 19 | weight_decay=config.weight_decay, 20 | **config.optimizer_kwargs, 21 | ) 22 | elif config.optimizer == "AdamW": 23 | params = filter(lambda p: p.requires_grad, model.parameters()) 24 | 25 | # import pdb; pdb.set_trace() 26 | optimizer = AdamW( 27 | params, 28 | lr=config.lr, 29 | weight_decay=config.weight_decay, 30 | **config.optimizer_kwargs, 31 | ) 32 | else: 33 | raise ValueError(f"Optimizer {config.optimizer} not recognized.") 34 | 35 | return optimizer 36 | 37 | 38 | def initialize_optimizer_with_model_params(config, params): 39 | if config.optimizer == "SGD": 40 | optimizer = SGD( 41 | params, 42 | lr=config.lr, 43 | weight_decay=config.weight_decay, 44 | **config.optimizer_kwargs, 45 | ) 46 | elif config.optimizer == "Adam": 47 | optimizer 
= Adam( 48 | params, 49 | lr=config.lr, 50 | weight_decay=config.weight_decay, 51 | **config.optimizer_kwargs, 52 | ) 53 | elif config.optimizer == "AdamW": 54 | optimizer = AdamW( 55 | params, 56 | lr=config.lr, 57 | weight_decay=config.weight_decay, 58 | **config.optimizer_kwargs, 59 | ) 60 | else: 61 | raise ValueError(f"Optimizer {config.optimizer} not supported.") 62 | 63 | return optimizer 64 | -------------------------------------------------------------------------------- /RLSbench/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.optim.lr_scheduler import LambdaLR, MultiStepLR, ReduceLROnPlateau, StepLR 3 | 4 | 5 | def initialize_scheduler(config, optimizer, n_train_steps): 6 | # construct schedulers 7 | if config.scheduler is None: 8 | return None 9 | elif config.scheduler == "linear_schedule_with_warmup": 10 | from transformers import get_linear_schedule_with_warmup 11 | 12 | scheduler = get_linear_schedule_with_warmup( 13 | optimizer, 14 | num_training_steps=n_train_steps, 15 | num_warmup_steps=int( 16 | config.scheduler_kwargs["warmup_frac"] * n_train_steps 17 | ), 18 | ) 19 | step_every_batch = True 20 | use_metric = False 21 | elif config.scheduler == "cosine_schedule_with_warmup": 22 | from transformers import get_cosine_schedule_with_warmup 23 | 24 | if "warmup_frac" not in config.scheduler_kwargs: 25 | config.scheduler_kwargs["num_warmup_steps"] = 0 26 | 27 | else: 28 | config.scheduler_kwargs["num_warmup_steps"] = int( 29 | config.scheduler_kwargs["warmup_frac"] * n_train_steps 30 | ) 31 | 32 | scheduler = get_cosine_schedule_with_warmup( 33 | optimizer, 34 | num_training_steps=n_train_steps, 35 | num_warmup_steps=config.scheduler_kwargs["num_warmup_steps"], 36 | ) 37 | step_every_batch = True 38 | use_metric = False 39 | 40 | # elif config.scheduler=='ReduceLROnPlateau': 41 | # assert config.scheduler_metric_name, f'scheduler metric must be specified for {config.scheduler}' 42 | # scheduler = ReduceLROnPlateau( 43 | # optimizer, 44 | # **config.scheduler_kwargs) 45 | # step_every_batch = False 46 | # use_metric = True 47 | 48 | elif config.scheduler == "StepLR": 49 | scheduler = StepLR(optimizer, **config.scheduler_kwargs) 50 | step_every_batch = False 51 | use_metric = False 52 | elif config.scheduler == "FixMatchLR": 53 | scheduler = LambdaLR( 54 | optimizer, lambda x: (1.0 + 10 * float(x) / n_train_steps) ** -0.75 55 | ) 56 | step_every_batch = True 57 | use_metric = False 58 | elif config.scheduler == "MultiStepLR": 59 | scheduler = MultiStepLR(optimizer, **config.scheduler_kwargs) 60 | step_every_batch = False 61 | use_metric = False 62 | else: 63 | raise ValueError(f"Scheduler: {config.scheduler} not supported.") 64 | 65 | # add an step_every_batch field 66 | scheduler.step_every_batch = step_every_batch 67 | return scheduler 68 | 69 | 70 | def step_scheduler(scheduler): 71 | scheduler.step() 72 | 73 | 74 | class LinearScheduleWithWarmupAndThreshold: 75 | """ 76 | Linear scheduler with warmup and threshold for non lr parameters. 77 | Parameters is held at 0 until some T1, linearly increased until T2, and then held 78 | at some max value after T2. 79 | Designed to be called by step_scheduler() above and used within Algorithm class. 80 | Args: 81 | - last_warmup_step: aka T1. for steps [0, T1) keep param = 0 82 | - threshold_step: aka T2. 
step over period [T1, T2) to reach param = max value 83 | - max value: end value of the param 84 | """ 85 | 86 | def __init__( 87 | self, max_value, last_warmup_step=0, threshold_step=1, step_every_batch=False 88 | ): 89 | self.max_value = max_value 90 | self.T1 = last_warmup_step 91 | self.T2 = threshold_step 92 | assert (0 <= self.T1) and (self.T1 < self.T2) 93 | 94 | # internal tracker of which step we're on 95 | self.current_step = 0 96 | self.value = 0 97 | 98 | # required fields called in Algorithm when stepping schedulers 99 | self.step_every_batch = step_every_batch 100 | 101 | def step(self): 102 | """This function is first called AFTER step 0, so increment first to set value for next step""" 103 | self.current_step += 1 104 | if self.current_step < self.T1: 105 | self.value = 0 106 | elif self.current_step < self.T2: 107 | self.value = ( 108 | (self.current_step - self.T1) / (self.T2 - self.T1) * self.max_value 109 | ) 110 | else: 111 | self.value = self.max_value 112 | 113 | 114 | class CoeffSchedule: 115 | def __init__(self, max_iter, high=1.0, low=0.0, alpha=10.0): 116 | self.max_iter = max_iter 117 | self.high = high 118 | self.low = low 119 | self.alpha = alpha 120 | self.iter_num = 0.0 121 | self.step_every_batch = True 122 | self.value = 0.0 123 | 124 | def step(self): 125 | self.value = float( 126 | 2.0 127 | * (self.high - self.low) 128 | / (1.0 + np.exp(-self.alpha * self.iter_num / self.max_iter)) 129 | - (self.high - self.low) 130 | + self.low 131 | ) 132 | self.iter_num = self.iter_num + 1 133 | -------------------------------------------------------------------------------- /RLSbench/version.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/snap-stanford/ogb/blob/master/ogb/version.py 2 | 3 | import os 4 | import logging 5 | from threading import Thread 6 | 7 | __version__ = "0.1" 8 | 9 | try: 10 | os.environ["OUTDATED_IGNORE"] = "1" 11 | from outdated import check_outdated # noqa 12 | except ImportError: 13 | check_outdated = None 14 | 15 | 16 | def check(): 17 | try: 18 | is_outdated, latest = check_outdated("RLSbench", __version__) 19 | if is_outdated: 20 | logging.warning( 21 | f"The Relaxed Label Shift study package is out of date. Your version is " 22 | f"{__version__}, while the latest version is {latest}." 23 | ) 24 | except Exception: 25 | pass 26 | 27 | 28 | if check_outdated is not None: 29 | thread = Thread(target=check) 30 | thread.start() 31 | -------------------------------------------------------------------------------- /code_helper.md: -------------------------------------------------------------------------------- 1 | # Label Shift Study Code 2 | 3 | ## Overview 4 | 5 | Overall pipeline: 6 | 7 | 8 | The following are the crucial parts of the code: 9 | 10 | 1. `label_shift_utils.py`: This file contains utility functions to simulate label shift in the target data. 11 | 2. `./datasets/get_dataset.py`: This file contains the code to get the source and target datasets. 12 | 3. `./algorithms/`: This folder contains the code for different algorithms.
We implement the following domain adaptation algorithms: 13 | - ERM variants: ERM, ERM-aug, with different pretraining techniques like ['rand', 'imagenet', 'clip'] 14 | - Domain alignment methods: DANN, CDAN 15 | - Self-training methods: Noisy student, Pseudolabeling, FixMatch, SENTRY, COAL 16 | - Self-supervised learning methods: SwAV 17 | - Test time adaptation methods: BN_adapt, TENT, CORAL 18 | 19 | 20 | The entry point of the code is `run_expt.py`. The `config` folder contains the default parameters and hyperparameters needed for the base experiments of the project. We need to pass the dataset name and the algorithm name with the flags `--dataset` and `--algorithm` to the `run_expt.py` file. To simulate label shift, we need to pass the flag `--simulate_label_shift` and the Dirichlet sampling parameter with `--dirchilet_alpha`. The flag `--root_dir` specifies the data directory for the source and target datasets. 21 | 22 | Caveat: For Test Time Adaptation (TTA) methods, we need to provide the folder with ERM-aug trained models via the parameter `--source_model_path`. 23 | 24 | ### Results Logging 25 | 26 | The code evaluates the trained models and logs the results as a CSV file in the `./logs/` folder. 27 | 28 | 29 | ## Simple example for running the code 30 | The following command can be used to run the code on the `cifar10` dataset with the `ERM-aug` algorithm: 31 | 32 | ```python 33 | python run_expt.py --dataset=cifar10 --algorithm=ERM-aug --simulate_label_shift --dirchilet_alpha=0.1 34 | ``` 35 | 36 | ## Requirements 37 | 38 | The code is written in Python and uses [PyTorch](https://pytorch.org/). To install the requirements, set up a conda environment using the following command: 39 | 40 | ```setup 41 | conda env create --file requirements.yml 42 | ``` 43 | 44 | ## Dataset Setup 45 | To set up the different datasets, run the scripts in the `dataset_scripts` folder. Except for the Imagenet dataset, which must be downloaded from the [official website](https://www.image-net.org/download.php), the scripts set up all the datasets (including all the source and target pairs) used in our study.
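For example, most setup scripts take the data root as their single argument (the script and target directory below are illustrative; check each script's usage message):

```bash
./dataset_scripts/setup_cifar10c.sh data/
```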
46 | 47 | 48 | ## Code structure 49 | The code structure is the following: 50 | ``` 51 | label_shift_study 52 | ├── algorithms 53 | │   ├── BN_adapt.py 54 | │   ├── CDAN.py 55 | │   ├── COAL.py 56 | │   ├── CORAL.py 57 | │   ├── DANN.py 58 | │   ├── ERM.py 59 | │   ├── MDD.py 60 | │   ├── SENTRY.py 61 | │   ├── TENT.py 62 | │   ├── algorithm.py 63 | │   ├── deepCORAL.py 64 | │   ├── fixmatch.py 65 | │   ├── initializer.py 66 | │   ├── noisy_student.py 67 | │   ├── pseudolabel.py 68 | │   └── single_model_algorithm.py 69 | ├── code_helper.md 70 | ├── configs 71 | │   ├── algorithm.py 72 | │   ├── datasets.py 73 | │   ├── supported.py 74 | │   └── utils.py 75 | ├── data_augmentation 76 | │   ├── __init__.py 77 | │   └── randaugment.py 78 | ├── dataset_scripts 79 | │   ├── Imagenet 80 | │   │   ├── ImageNet_reorg.py 81 | │   │   ├── ImageNet_resize.py 82 | │   │   ├── ImageNet_v2_reorg.py 83 | │   │   └── resize_ImageNet-C.sh 84 | │   ├── convert.sh 85 | │   ├── setup_BREEDs.sh 86 | │   ├── setup_Imagenet.sh 87 | │   ├── setup_Imagenet200.sh 88 | │   ├── setup_camelyon.sh 89 | │   ├── setup_cifar100c.sh 90 | │   ├── setup_cifar10c.sh 91 | │   ├── setup_domainnet.sh 92 | │   ├── setup_fmow.sh 93 | │   ├── setup_iwildcams.sh 94 | │   ├── setup_office31.sh 95 | │   ├── setup_officehome.sh 96 | │   ├── setup_rxrx1.sh 97 | │   ├── setup_visda.sh 98 | │   └── visda_structure.py 99 | ├── datasets 100 | │   ├── __init__.py 101 | │   ├── data_utils.py 102 | │   └── get_dataset.py 103 | ├── experiment_scripts 104 | ├── label_shift_utils.py 105 | ├── losses.py 106 | ├── models 107 | │   ├── __init__.py 108 | │   ├── cifar_efficientnet.py 109 | │   ├── cifar_resnet.py 110 | │   ├── clip.py 111 | │   ├── domain_adversarial_network.py 112 | │   ├── initializer.py 113 | │   ├── mdd_net.py 114 | │   └── model_utils.py 115 | ├── notebooks 116 | │   ├── image_show.ipynb 117 | │   └── wilds_loading.ipynb 118 | ├── optimizer.py 119 | ├── pretraining 120 | │   └── swav 121 | │   ├── LICENSE 122 | │   ├── README.md 123 | │   ├── main_swav.py 124 | │   └── src 125 | │   ├── config.py 126 | │   ├── logger.py 127 | │   ├── model.py 128 | │   ├── multicropdataset.py 129 | │   └── utils.py 130 | ├── run_expt.py 131 | ├── scheduler.py 132 | ├── train.py 133 | ├── transforms.py 134 | └── utils.py 135 | 136 | ``` 137 | -------------------------------------------------------------------------------- /dataset_scripts/Imagenet/ImageNet_reorg.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from absl import app, flags 3 | from concurrent import futures 4 | import os 5 | 6 | FLAGS = flags.FLAGS 7 | 8 | flags.DEFINE_string('dir', "/tmp/", "Dir to convert images") 9 | flags.DEFINE_string('newDir', "/tmp/", "Dir to convert images") 10 | flags.DEFINE_string('csv', "/tmp/", "CSV file") 11 | 12 | def main(_): 13 | 14 | 15 | file_map = {} 16 | with open(FLAGS.csv, "r" ) as file: 17 | file.readline() 18 | for line in file: 19 | file_name, file_class = line.split(" ")[0].split(",") 20 | file_map[file_name + ".JPEG"] = file_class 21 | 22 | print(file_name + ".JPEG", file_class) 23 | 24 | 25 | for r, d, f in os.walk(FLAGS.dir): 26 | for file in f: 27 | if file.endswith("jpg") or file.endswith("JPEG") or file.endswith("jpeg") or file.endswith("JPG"): 28 | file_name = (r + "/" + file).split("/")[-1] 29 | if not os.path.isdir(FLAGS.newDir + "/" + file_map[file_name]): 30 | os.makedirs(FLAGS.newDir + "/" + file_map[file_name]) 31 | 32 | os.rename(r + "/" + file, FLAGS.newDir + 
"/" + file_map[file_name] + "/" + file_name ) 33 | 34 | 35 | if __name__ == '__main__': 36 | app.run(main) -------------------------------------------------------------------------------- /dataset_scripts/Imagenet/ImageNet_resize.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from absl import app, flags 3 | from concurrent import futures 4 | import os 5 | 6 | FLAGS = flags.FLAGS 7 | 8 | flags.DEFINE_string('dir', "/tmp/", "Dir to convert images") 9 | 10 | 11 | def resize_crop(image_name): 12 | 13 | im = Image.open(image_name) 14 | width, height = im.size # Get dimensions 15 | 16 | if width >= height: 17 | width = int(width*1.0/height*256) 18 | height = 256 19 | else: 20 | height = int(height*1.0/width*256) 21 | width = 256 22 | 23 | im = im.resize((width, height)) 24 | 25 | # width, height = im.size 26 | # new_width, new_height = 224, 224 27 | 28 | # left = (width - new_width)/2 29 | # top = (height - new_height)/2 30 | # right = (width + new_width)/2 31 | # bottom = (height + new_height)/2 32 | 33 | # # Crop the center of the image 34 | # im = im.crop((left, top, right, bottom)) 35 | im.save(image_name) 36 | 37 | def main(_): 38 | 39 | pool = futures.ThreadPoolExecutor(20) 40 | 41 | processes = [] 42 | for r, d, f in os.walk(FLAGS.dir): 43 | for file in f: 44 | if file.endswith("png") or file.endswith("PNG") or file.endswith("jpg") or file.endswith("JPEG") or file.endswith("jpeg") or file.endswith("JPG"): 45 | process = pool.submit(resize_crop, r + "/" + file) 46 | processes.append(process) 47 | 48 | futures.wait(processes) 49 | 50 | 51 | 52 | if __name__ == '__main__': 53 | app.run(main) -------------------------------------------------------------------------------- /dataset_scripts/Imagenet/ImageNet_v2_reorg.py: -------------------------------------------------------------------------------- 1 | # from PIL import Image 2 | from absl import app, flags 3 | from concurrent import futures 4 | import os 5 | import json 6 | 7 | FLAGS = flags.FLAGS 8 | 9 | flags.DEFINE_string('dir', "/tmp/", "Dir to convert images") 10 | flags.DEFINE_string('info', "/tmp/", "json file") 11 | 12 | def main(_): 13 | 14 | 15 | file_map = {} 16 | with open(FLAGS.info, "r" ) as file: 17 | 18 | json_array = json.load(file) 19 | for i, line in enumerate(json_array): 20 | # print(line) 21 | file_map[str(line[0])] = line[1] 22 | 23 | for r, d, f in os.walk(FLAGS.dir): 24 | dir_name = r.split("/")[-1] 25 | if dir_name in file_map: 26 | os.rename(r, "/".join(r.split("/")[:-1]) + "/" + file_map[dir_name]) 27 | 28 | 29 | if __name__ == '__main__': 30 | app.run(main) 31 | -------------------------------------------------------------------------------- /dataset_scripts/Imagenet/convert.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | counter=0 3 | NUM_RUNS=20 4 | 5 | for i in `find . -name "*.png" -type f`; do 6 | convert $i "${i%.png}".jpg && rm $i & 7 | 8 | counter=$((counter+1)) 9 | 10 | if ! 
11 |         wait
12 |     fi
13 | 
14 | done
15 | 
16 | # folder=/tmp/
17 | # find $folder -name "*.png" -exec bash -c 'convert "$1" "${1%.png}".jpg && rm $1' - '{}' +
18 | 
--------------------------------------------------------------------------------
/dataset_scripts/Imagenet/convert_to_jpg.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from absl import app, flags
3 | from concurrent import futures
4 | import os
5 | 
6 | FLAGS = flags.FLAGS
7 | 
8 | flags.DEFINE_string('dir', "/tmp/", "Dir whose PNG images to convert to JPG")
9 | 
10 | 
11 | def png_to_jpg(image_name):
12 |     # JPEG cannot store alpha or palette data, so normalize to RGB first.
13 |     im = Image.open(image_name).convert("RGB")
14 |     im.save(f"{image_name[:-4]}.jpg")
15 | 
16 | def main(_):
17 | 
18 |     pool = futures.ThreadPoolExecutor(20)
19 | 
20 |     processes = []
21 |     for r, d, f in os.walk(FLAGS.dir):
22 |         for file in f:
23 |             if file.endswith("png") or file.endswith("PNG"):
24 |                 process = pool.submit(png_to_jpg, r + "/" + file)
25 |                 processes.append(process)
26 | 
27 |     futures.wait(processes)
28 | 
29 | 
30 | 
31 | if __name__ == '__main__':
32 |     app.run(main)
--------------------------------------------------------------------------------
/dataset_scripts/Imagenet/resize_ImageNet-C.sh:
--------------------------------------------------------------------------------
1 | TYPES=("shot_noise" "impulse_noise" "contrast" "elastic_transform" "pixelate" "jpeg_compression" "speckle_noise" "spatter" "gaussian_blur" "saturate")
2 | 
3 | for type in "${TYPES[@]}"; do
4 |     echo "${type}"
5 |     command="python ImageNet_resize.py --dir=data/ImageNet/ImageNet-C/${type}/"
6 |     eval $command
7 | done
--------------------------------------------------------------------------------
/dataset_scripts/convert.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | counter=0
3 | NUM_RUNS=20
4 | 
5 | for i in `find . -name "*.png" -type f`; do
6 |     convert $i "${i%.png}".jpg && rm $i &
7 | 
8 |     counter=$((counter+1))
9 | 
10 |     if ! ((counter % NUM_RUNS)); then
11 |         wait
12 |     fi
13 | 
14 | done
15 | 
16 | # folder=/tmp/
17 | # find $folder -name "*.png" -exec bash -c 'convert "$1" "${1%.png}".jpg && rm $1' - '{}' +
18 | 
--------------------------------------------------------------------------------
/dataset_scripts/setup_BREEDs.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | if [ "$#" -ne 1 ]; then
4 |     echo "Illegal number of parameters"
5 |     echo "Usage: ./setup_BREEDs.sh <data_dir>"
6 |     exit 1
7 | fi
8 | 
9 | ## Download Imagenet from here
10 | echo "Download Imagenet by registering and following instructions from http://image-net.org/download-images."
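# Example invocation (the path is illustrative, not prescribed by this script):
#   ./setup_BREEDs.sh /path/to/data
# ImageNet itself must be fetched manually (see the note above); the commands below
# only populate <data_dir>/imagenet_class_hierarchy/ with the BREEDS hierarchy files.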
11 | 
12 | ## Download Imagenet hierarchy
13 | git clone https://github.com/MadryLab/BREEDS-Benchmarks.git
14 | mkdir -p $1/imagenet_class_hierarchy
15 | mv BREEDS-Benchmarks/imagenet_class_hierarchy/modified/* $1/imagenet_class_hierarchy/
16 | rm -rf BREEDS-Benchmarks
17 | 
18 | 
--------------------------------------------------------------------------------
/dataset_scripts/setup_Imagenet.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | if [ "$#" -ne 1 ]; then
4 |     echo "Illegal number of parameters"
5 |     echo "Usage: ./setup_Imagenet.sh <data_dir>"
6 |     exit 1
7 | fi
8 | 
9 | if [ -f "$1/imagenet_class_hierarchy/dataset_class_info.json" ]
10 | then
11 |     echo "OK"
12 | else
13 |     echo "Please download the BREEDs hierarchy first with the following command:"
14 |     echo "./setup_BREEDs.sh ${1}"; exit 1
15 | fi
16 | 
17 | ## Download Imagenet from here
18 | echo "Download Imagenet by registering and following instructions from http://image-net.org/download-images."
19 | 
20 | ## Download Imagenetv2
21 | echo "Downloading Imagenetv2..."
22 | mkdir -p $1/imagenetv2
23 | 
24 | wget https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-matched-frequency.tar.gz
25 | tar -xvf imagenetv2-matched-frequency.tar.gz -C $1/imagenetv2/
26 | rm -rf imagenetv2-matched-frequency.tar.gz
27 | python dataset_scripts/Imagenet/ImageNet_v2_reorg.py --dir $1/imagenetv2/imagenetv2-matched-frequency-format-val --info $1/imagenet_class_hierarchy/dataset_class_info.json
28 | 
29 | wget https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-threshold0.7.tar.gz
30 | tar -xvf imagenetv2-threshold0.7.tar.gz -C $1/imagenetv2/
31 | rm -rf imagenetv2-threshold0.7.tar.gz
32 | python dataset_scripts/Imagenet/ImageNet_v2_reorg.py --dir $1/imagenetv2/imagenetv2-threshold0.7-format-val --info $1/imagenet_class_hierarchy/dataset_class_info.json
33 | 
34 | 
35 | wget https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-top-images.tar.gz
36 | tar -xvf imagenetv2-top-images.tar.gz -C $1/imagenetv2/
37 | rm -rf imagenetv2-top-images.tar.gz
38 | python dataset_scripts/Imagenet/ImageNet_v2_reorg.py --dir $1/imagenetv2/imagenetv2-top-images-format-val --info $1/imagenet_class_hierarchy/dataset_class_info.json
39 | 
40 | echo "Imagenetv2 downloaded"
41 | 
42 | ## Download Imagenet C
43 | echo "Downloading Imagenet C..."
44 | mkdir -p $1/imagenet-c
45 | 
46 | wget https://zenodo.org/record/2235448/files/blur.tar?download=1
47 | tar -xvf "blur.tar?download=1" -C $1/imagenet-c/
48 | rm -rf "blur.tar?download=1"
49 | 
50 | wget https://zenodo.org/record/2235448/files/digital.tar?download=1
51 | tar -xvf "digital.tar?download=1" -C $1/imagenet-c/
52 | rm -rf "digital.tar?download=1"
53 | 
54 | wget https://zenodo.org/record/2235448/files/extra.tar?download=1
55 | tar -xvf "extra.tar?download=1" -C $1/imagenet-c/
56 | rm -rf "extra.tar?download=1"
57 | 
58 | wget https://zenodo.org/record/2235448/files/noise.tar?download=1
59 | tar -xvf "noise.tar?download=1" -C $1/imagenet-c/
60 | rm -rf "noise.tar?download=1"
61 | 
62 | wget https://zenodo.org/record/2235448/files/weather.tar?download=1
63 | tar -xvf "weather.tar?download=1" -C $1/imagenet-c/
64 | rm -rf "weather.tar?download=1"
65 | 
66 | echo "Imagenet C downloaded"
67 | 
68 | ## Download Imagenet R
69 | 
70 | echo "Downloading Imagenet R..."
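# Once every section of this script has run, <data_dir> should roughly contain
# (layout inferred from the commands above and below; verify locally):
#   imagenetv2/        three ImageNetV2 test sets, reorganized into wnid folders
#   imagenet-c/        ImageNet-C corruption groups (blur, digital, extra, noise, weather)
#   imagenet-r/        ImageNet-R renditions
#   imagenet-sketch/   ImageNet-Sketch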
71 | wget https://people.eecs.berkeley.edu/~hendrycks/imagenet-r.tar
72 | tar -xvf imagenet-r.tar -C $1/
73 | rm -rf imagenet-r.tar
74 | 
75 | echo "Imagenet R downloaded"
76 | 
77 | ## Download Imagenet Sketch
78 | 
79 | echo "Downloading Imagenet Sketch..."
80 | gdown https://drive.google.com/uc?id=1Mj0i5HBthqH1p_yeXzsg22gZduvgoNeA
81 | unzip ImageNet-Sketch.zip -d $1/
82 | mv $1/sketch/ $1/imagenet-sketch/
83 | rm -rf ImageNet-Sketch.zip
84 | 
85 | echo "Imagenet Sketch downloaded"
86 | 
--------------------------------------------------------------------------------
/dataset_scripts/setup_Imagenet200.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | if [ "$#" -ne 2 ]; then
4 |     echo "Illegal number of parameters"
5 |     echo "Usage: ./setup_Imagenet200.sh <imagenet_dir> <imagenet200_dir>"
6 |     exit 1
7 | fi
8 | 
9 | mkdir -p $2
10 | 
11 | IDS=('n01443537' 'n01484850' 'n01494475' 'n01498041' 'n01514859' 'n01518878' 'n01531178' 'n01534433' 'n01614925' 'n01616318' 'n01630670' 'n01632777' 'n01644373' 'n01677366' 'n01694178' 'n01748264' 'n01770393' 'n01774750' 'n01784675' 'n01806143' 'n01820546' 'n01833805' 'n01843383' 'n01847000' 'n01855672' 'n01860187' 'n01882714' 'n01910747' 'n01944390' 'n01983481' 'n01986214' 'n02007558' 'n02009912' 'n02051845' 'n02056570' 'n02066245' 'n02071294' 'n02077923' 'n02085620' 'n02086240' 'n02088094' 'n02088238' 'n02088364' 'n02088466' 'n02091032' 'n02091134' 'n02092339' 'n02094433' 'n02096585' 'n02097298' 'n02098286' 'n02099601' 'n02099712' 'n02102318' 'n02106030' 'n02106166' 'n02106550' 'n02106662' 'n02108089' 'n02108915' 'n02109525' 'n02110185' 'n02110341' 'n02110958' 'n02112018' 'n02112137' 'n02113023' 'n02113624' 'n02113799' 'n02114367' 'n02117135' 'n02119022' 'n02123045' 'n02128385' 'n02128757' 'n02129165' 'n02129604' 'n02130308' 'n02134084' 'n02138441' 'n02165456' 'n02190166' 'n02206856' 'n02219486' 'n02226429' 'n02233338' 'n02236044' 'n02268443' 'n02279972' 'n02317335' 'n02325366' 'n02346627' 'n02356798' 'n02363005' 'n02364673' 'n02391049' 'n02395406' 'n02398521' 'n02410509' 'n02423022' 'n02437616' 'n02445715' 'n02447366' 'n02480495' 'n02480855' 'n02481823' 'n02483362' 'n02486410' 'n02510455' 'n02526121' 'n02607072' 'n02655020' 'n02672831' 'n02701002' 'n02749479' 'n02769748' 'n02793495' 'n02797295' 'n02802426' 'n02808440' 'n02814860' 'n02823750' 'n02841315' 'n02843684' 'n02883205' 'n02906734' 'n02909870' 'n02939185' 'n02948072' 'n02950826' 'n02951358' 'n02966193' 'n02980441' 'n02992529' 'n03124170' 'n03272010' 'n03345487' 'n03372029' 'n03424325' 'n03452741' 'n03467068' 'n03481172' 'n03494278' 'n03495258' 'n03498962' 'n03594945' 'n03602883' 'n03630383' 'n03649909' 'n03676483' 'n03710193' 'n03773504' 'n03775071' 'n03888257' 'n03930630' 'n03947888' 'n04086273' 'n04118538' 'n04133789' 'n04141076' 'n04146614' 'n04147183' 'n04192698' 'n04254680' 'n04266014' 'n04275548' 'n04310018' 'n04325704' 'n04347754' 'n04389033' 'n04409515' 'n04465501' 'n04487394' 'n04522168' 'n04536866' 'n04552348' 'n04591713' 'n07614500' 'n07693725' 'n07695742' 'n07697313' 'n07697537' 'n07714571' 'n07714990' 'n07718472' 'n07720875' 'n07734744' 'n07742313' 'n07745940' 'n07749582' 'n07753275' 'n07753592' 'n07768694' 'n07873807' 'n07880968' 'n07920052' 'n09472597' 'n09835506' 'n10565667' 'n12267677')
12 | 
13 | 
14 | for file in $1/imagenetv1/*; do
15 |     mkdir -p $2/imagenetv1/$(basename $file)
16 |     for id in ${IDS[@]}; do
17 |         ln -s $file/$id $2/imagenetv1/$(basename $file)/$id
18 |     done
19 | done
20 | 
21 | for file in $1/imagenetv2/*; do
22 |     mkdir -p $2/imagenetv2/$(basename $file)
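    # Link only the 200 selected wnids (IDS) from each ImageNetV2 variant into the new tree.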
23 |     for id in ${IDS[@]}; do
24 |         ln -s $file/$id $2/imagenetv2/$(basename $file)/$id
25 |     done
26 | done
27 | 
28 | for file in $1/imagenet-c/*; do
29 |     for severity in {0..4}; do
30 |         mkdir -p $2/imagenet-c/$(basename $file)/$severity
31 |         for id in ${IDS[@]}; do
32 |             ln -s $file/$severity/$id $2/imagenet-c/$(basename $file)/$severity/$id
33 |         done
34 |     done
35 | done
36 | 
37 | ln -s $1/imagenet-r $2/imagenet-r
38 | 
39 | mkdir -p $2/imagenet-sketch
40 | 
41 | for id in ${IDS[@]}; do
42 |     ln -s $1/imagenet-sketch/$id $2/imagenet-sketch/$id
43 | done
44 | 
45 | 
46 | 
--------------------------------------------------------------------------------
/dataset_scripts/setup_camelyon.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | if [ "$#" -ne 1 ]; then
4 |     echo "Illegal number of parameters"
5 |     echo "Usage: ./setup_camelyon.sh <data_dir>"
6 |     exit 1
7 | fi
8 | 
9 | 
10 | python <(echo "from wilds import get_dataset; get_dataset(dataset='camelyon17', download=True, root_dir='$1')")
--------------------------------------------------------------------------------
/dataset_scripts/setup_cifar100c.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | if [ "$#" -ne 1 ]; then
4 |     echo "Illegal number of parameters"
5 |     echo "Usage: ./setup_cifar100c.sh <data_dir>"
6 |     exit 1
7 | fi
8 | 
9 | ## Download CIFAR100-C
10 | echo "Downloading CIFAR-100 C..."
11 | 
12 | wget https://zenodo.org/record/3555552/files/CIFAR-100-C.tar?download=1
13 | tar -xvf "CIFAR-100-C.tar?download=1" -C $1/
14 | rm -rf "CIFAR-100-C.tar?download=1"
15 | 
16 | echo "CIFAR100-C downloaded"
17 | 
--------------------------------------------------------------------------------
/dataset_scripts/setup_cifar10c.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | if [ "$#" -ne 1 ]; then
4 |     echo "Illegal number of parameters"
5 |     echo "Usage: ./setup_cifar10c.sh <data_dir>"
6 |     exit 1
7 | fi
8 | 
9 | 
10 | ## Download CIFAR10v2
11 | echo "Downloading CIFAR10v2..."
12 | mkdir -p $1/cifar10v2
13 | 
14 | wget https://github.com/modestyachts/cifar-10.2/raw/master/cifar102_train.npz
15 | mv cifar102_train.npz $1/cifar10v2/
16 | 
17 | wget https://github.com/modestyachts/cifar-10.2/raw/master/cifar102_test.npz
18 | mv cifar102_test.npz $1/cifar10v2/
19 | 
20 | echo "CIFAR10v2 downloaded"
21 | 
22 | ## Download CIFAR10-C
23 | echo "Downloading CIFAR-10 C..."
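# CIFAR-10-C unpacks to a CIFAR-10-C/ folder holding .npy arrays rather than image
# folders: one file per corruption (the 10,000 test images stacked across 5 severities)
# plus labels.npy. A minimal Python loading sketch (verify file names after extraction):
#   import numpy as np
#   x = np.load("CIFAR-10-C/gaussian_noise.npy")   # shape (50000, 32, 32, 3)
#   y = np.load("CIFAR-10-C/labels.npy")           # shape (50000,)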
24 | 
25 | wget https://zenodo.org/record/2535967/files/CIFAR-10-C.tar?download=1
26 | tar -xvf "CIFAR-10-C.tar?download=1" -C $1/
27 | rm -rf "CIFAR-10-C.tar?download=1"
28 | 
29 | echo "CIFAR10-C downloaded"
30 | 
--------------------------------------------------------------------------------
/dataset_scripts/setup_domainnet.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | if [ "$#" -ne 1 ]; then
4 |     echo "Illegal number of parameters"
5 |     echo "Usage: ./setup_domainnet.sh <data_dir>"
6 |     exit 1
7 | fi
8 | 
9 | 
10 | python <(echo "from wilds import get_dataset; get_dataset(dataset='domainnet', download=True, root_dir='$1')")
--------------------------------------------------------------------------------
/dataset_scripts/setup_fmow.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | if [ "$#" -ne 1 ]; then
4 |     echo "Illegal number of parameters"
5 |     echo "Usage: ./setup_fmow.sh <data_dir>"
6 |     exit 1
7 | fi
8 | 
9 | 
10 | python <(echo "from wilds import get_dataset; get_dataset(dataset='fmow', download=True, root_dir='$1')")
--------------------------------------------------------------------------------
/dataset_scripts/setup_iwildcams.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | if [ "$#" -ne 1 ]; then
4 |     echo "Illegal number of parameters"
5 |     echo "Usage: ./setup_iwildcams.sh <data_dir>"
6 |     exit 1
7 | fi
8 | 
9 | 
10 | python <(echo "from wilds import get_dataset; get_dataset(dataset='iwildcam', download=True, root_dir='$1')")
--------------------------------------------------------------------------------
/dataset_scripts/setup_office31.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | if [ "$#" -ne 1 ]; then
4 |     echo "Illegal number of parameters"
5 |     echo "Usage: ./setup_office31.sh <data_dir>"
6 |     exit 1
7 | fi
8 | 
9 | curr_dir=$(pwd)
10 | cd $1
11 | 
12 | wget https://wjdcloud.blob.core.windows.net/dataset/OFFICE31.zip
13 | unzip OFFICE31.zip
14 | rm OFFICE31.zip
15 | 
16 | mv OFFICE31 office31
17 | 
18 | cd $curr_dir
19 | 
--------------------------------------------------------------------------------
/dataset_scripts/setup_officehome.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | if [ "$#" -ne 1 ]; then
4 |     echo "Illegal number of parameters"
5 |     echo "Usage: ./setup_officehome.sh <data_dir>"
6 |     exit 1
7 | fi
8 | 
9 | curr_dir=$(pwd)
10 | 
11 | 
12 | mkdir -p $1/
13 | cd $1/
14 | 
15 | wget https://wjdcloud.blob.core.windows.net/dataset/OfficeHome.zip
16 | unzip OfficeHome.zip
17 | rm OfficeHome.zip
18 | 
19 | mv OfficeHome officehome
20 | 
21 | cd $curr_dir
--------------------------------------------------------------------------------
/dataset_scripts/setup_rxrx1.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | if [ "$#" -ne 1 ]; then
4 |     echo "Illegal number of parameters"
5 |     echo "Usage: ./setup_rxrx1.sh <data_dir>"
6 |     exit 1
7 | fi
8 | 
9 | 
10 | python <(echo "from wilds import get_dataset; get_dataset(dataset='rxrx1', download=True, root_dir='$1')")
--------------------------------------------------------------------------------
/dataset_scripts/setup_visda.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | if [ "$#" -ne 2 ]; then
4 |     echo "Illegal number of parameters"
5 |     echo "Usage: ./setup_visda.sh <data_dir> <path_to_visda_structure.py>"
6 |     exit 1
7 | fi
8 | 
9 | curr_dir=$(pwd)
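# The commands below fetch the VisDA-2017 classification splits (train/validation/test);
# the unlabeled test split is then rearranged into class folders by visda_structure.py,
# passed in as $2.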
10 | mkdir -p $1/visda
11 | cd $1/visda
12 | 
13 | wget http://csr.bu.edu/ftp/visda17/clf/train.tar
14 | tar xvf train.tar
15 | 
16 | wget http://csr.bu.edu/ftp/visda17/clf/validation.tar
17 | tar xvf validation.tar
18 | 
19 | wget http://csr.bu.edu/ftp/visda17/clf/test.tar
20 | tar xvf test.tar
21 | 
22 | wget https://raw.githubusercontent.com/VisionLearningGroup/taskcv-2017-public/master/classification/data/image_list.txt
23 | 
24 | python $2 --dir=./test/ --map=image_list.txt
25 | 
26 | rm -rf test/trunk*
27 | 
28 | rm -rf train.tar validation.tar test.tar
29 | 
30 | cd $curr_dir
31 | 
--------------------------------------------------------------------------------
/dataset_scripts/visda_structure.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | 
4 | # VISDA classes
5 | classes = ["aeroplane","bicycle","bus","car","horse","knife","motorcycle","person","plant","skateboard","train","truck"]
6 | 
7 | # Argparser with dir input
8 | parser = argparse.ArgumentParser(description='VISDA structure')
9 | parser.add_argument('--dir', type=str, default='', help='path to dataset')
10 | parser.add_argument('--map', type=str, default='', help='path to map file')
11 | 
12 | 
13 | args = parser.parse_args()
14 | 
15 | map_files = {}
16 | 
17 | # Read map file
18 | with open(args.map, 'r') as f:
19 |     for line in f:
20 |         file_name, id = line.rstrip().split()
21 |         file_name = file_name.split('/')[-1]
22 |         map_files[file_name] = int(id)
23 | 
24 | 
25 | # Read directory args.dir recursively
26 | for root, dirs, files in os.walk(args.dir):
27 | 
28 |     # Load file
29 |     for file in files:
30 | 
31 |         img_file = f"{file}"
32 |         # img_file = img_file[2:]
33 |         print(root, dirs, img_file)
34 |         if img_file.endswith('.jpg') and 'trunk' in root:
35 |             img_id = map_files[img_file]
36 | 
37 |             # Create folder
38 |             if not os.path.exists(f"{args.dir}/{classes[img_id]}"):
39 |                 os.makedirs(f"{args.dir}/{classes[img_id]}")
40 | 
41 | 
42 |             # Move file
43 |             os.rename(f"{root}/{file}", f"{args.dir}/{classes[img_id]}/{file}")
--------------------------------------------------------------------------------
/images/RLSbench_fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acmi-lab/RLSbench/eb67d5c78aa3646b7369830e481b3f15a59a087d/images/RLSbench_fig.png
--------------------------------------------------------------------------------
/images/datasets.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acmi-lab/RLSbench/eb67d5c78aa3646b7369830e481b3f15a59a087d/images/datasets.png
--------------------------------------------------------------------------------
/pretrained_models/resnet18_imagenet32.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acmi-lab/RLSbench/eb67d5c78aa3646b7369830e481b3f15a59a087d/pretrained_models/resnet18_imagenet32.pt
--------------------------------------------------------------------------------
/scripts/eval_ERM.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import sys
4 | import time
5 | from collections import Counter
6 | from datetime import date
7 | from subprocess import Popen
8 | 
9 | 
10 | NUM_RUNS = 12
11 | GPU_IDS = list(range(4))
12 | NUM_GPUS = len(GPU_IDS)
13 | counter = 0
14 | 
15 | 
16 | DATASETS = [
17 |     "domainnet",
18 |     "camelyon",
19 |     # # 'iwildcam',
20 |     "fmow",
21 |     "cifar10",
22 |     "cifar100",
23 |     "entity13",
24 |     "entity30",
25 |     "living17",
26 |     "nonliving26",
27 |     # 'office31',
28 |     "officehome",
29 |     "visda",
30 | ]
31 | TARGET_SETS = {
32 |     "cifar10": ["0", "1", "10", "71", "95"],
33 |     "cifar100": ["0", "4", "12", "59", "82"],
34 |     "fmow": ["0", "1", "2"],
35 |     "iwildcams": ["0", "1", "2"],
36 |     "camelyon": ["0", "1", "2"],
37 |     "domainnet": ["0", "1", "2", "3"],
38 |     "entity13": ["0", "1", "2", "3"],
39 |     "entity30": ["0", "1", "2", "3"],
40 |     "living17": ["0", "1", "2", "3"],
41 |     "nonliving26": ["0", "1", "2", "3"],
42 |     "officehome": ["0", "1", "2", "3"],
43 |     "office31": ["0", "1", "2"],
44 |     "visda": ["0", "1", "2"],
45 | }
46 | 
47 | SEEDS = ["42"]
48 | ALPHA = ["0.5", "1.0", "3.0", "10.0", "100.0"]
49 | ALGORITHMS = ["ERM-aug"]
50 | # ALGORITHMS= ["ERM", "ERM-aug"]
51 | 
52 | SOURCE_FILE = {
53 |     "cifar10": "logs_consistent/cifar10_seed\:%s/%s-imagenet_pretrained\:imagenet/",
54 |     "cifar100": "logs_consistent/cifar100_seed\:%s/%s-imagenet_pretrained\:imagenet/",
55 |     "camelyon": "logs_consistent/camelyon_seed\:%s/%s-rand_pretrained\:rand/",
56 |     "entity13": "logs_consistent/entity13_seed\:%s/%s-rand_pretrained\:rand/",
57 |     "entity30": "logs_consistent/entity30_seed\:%s/%s-rand_pretrained\:rand/",
58 |     "living17": "logs_consistent/living17_seed\:%s/%s-rand_pretrained\:rand/",
59 |     "nonliving26": "logs_consistent/nonliving26_seed\:%s/%s-rand_pretrained\:rand/",
60 |     "fmow": "logs_consistent/fmow_seed\:%s/%s-imagenet_pretrained\:imagenet/",
61 |     "domainnet": "logs_consistent/domainnet_seed\:%s/%s-imagenet_pretrained\:imagenet/",
62 |     "officehome": "logs_consistent/officehome_seed\:%s/%s-imagenet_pretrained\:imagenet/",
63 |     "visda": "logs_consistent/visda_seed\:%s/%s-imagenet_pretrained\:imagenet/",
64 | }
65 | 
66 | procs = []
67 | 
68 | for dataset in DATASETS:
69 |     for seed in SEEDS:
70 |         for alpha in ALPHA:
71 |             for target_set in TARGET_SETS[dataset]:
72 |                 for algorithm in ALGORITHMS:
73 |                     gpu_id = GPU_IDS[counter % NUM_GPUS]
74 | 
75 |                     source_models = SOURCE_FILE[dataset] % (seed, algorithm)
76 | 
77 |                     cmd = f"CUDA_VISIBLE_DEVICES={gpu_id} python run_main.py --remote False \
78 |                         --dataset {dataset} --root_dir /home/ubuntu/data --seed {seed} \
79 |                         --transform image_none --algorithm {algorithm} --eval_only --use_source_model \
80 |                         --source_model_path={source_models} --dirichlet_alpha {alpha} \
81 |                         --target_split {target_set} --use_target True --simulate_label_shift True"
82 | 
83 |                     print(cmd)
84 |                     procs.append(Popen(cmd, shell=True))
85 | 
86 |                     time.sleep(3)
87 | 
88 |                     counter += 1
89 | 
90 |                     if counter % NUM_RUNS == 0:
91 |                         for p in procs:
92 |                             p.wait()
93 |                         procs = []
94 |                         time.sleep(3)
95 | 
96 |                         print("\n \n \n \n --------------------------- \n \n \n \n")
97 |                         print(f"{date.today()} - {counter} runs completed")
98 |                         sys.stdout.flush()
99 |                         print("\n \n \n \n --------------------------- \n \n \n \n")
100 | 
101 | 
102 | for p in procs:
103 |     p.wait()
104 | procs = []
105 | 
--------------------------------------------------------------------------------
/scripts/run_ERM.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | NUM_RUNS=2
3 | GPU_IDS=( 0 1 2 3 4 5 6 7 )
4 | NUM_GPUS=${#GPU_IDS[@]}
5 | counter=0
6 | 
7 | DATASETS=( 'fmow' )
8 | SEEDS=( 42 1234 )
9 | ALPHA=('0.0' '0.5' '1.0' '5.0' '10.0' '100.0')
10 | ALGORITHMS=( "ERM-aug" )
11 | 
12 | for dataset in "${DATASETS[@]}"; do
13 |     for algorithm in "${ALGORITHMS[@]}"; do
14 |         for seed in "${SEEDS[@]}"; do
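            # Runs are spread round-robin over GPU_IDS; after every NUM_RUNS background
            # jobs the loop below waits for the whole batch to finish.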
15 | 
16 |             # Get GPU id.
17 |             gpu_idx=$((counter % $NUM_GPUS))
18 |             gpu_id=${GPU_IDS[$gpu_idx]}
19 | 
20 |             cmd="CUDA_VISIBLE_DEVICES=${gpu_id} python run_main.py --remote False --dataset ${dataset} --n_epochs 60 --resume \
21 |                 --root_dir /home/sgarg2/data --seed ${seed} --transform image_none --algorithm ${algorithm} --progress_bar"
22 | 
23 |             echo $cmd
24 |             eval ${cmd} &
25 | 
26 |             counter=$((counter+1))
27 |             if ! ((counter % NUM_RUNS)); then
28 |                 wait
29 |             fi
30 |         done
31 |     done
32 | done
33 | 
34 | 
35 | 
--------------------------------------------------------------------------------
/scripts/run_adapt.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import sys
4 | import time
5 | from collections import Counter
6 | from datetime import date
7 | from subprocess import Popen
8 | 
9 | 
10 | def check_for_done(l):
11 |     for i, p in enumerate(l):
12 |         if p.poll() is not None:
13 |             return True, i
14 |     return False, False
15 | 
16 | 
17 | NUM_RUNS = 4
18 | GPU_IDS = [0, 1, 2, 3]
19 | NUM_GPUS = len(GPU_IDS)
20 | counter = 0
21 | 
22 | DATASETS = [
23 |     "domainnet",
24 |     "camelyon",
25 |     # # 'iwildcam',
26 |     "fmow",
27 |     "cifar10",
28 |     "cifar100",
29 |     "entity13",
30 |     "entity30",
31 |     "living17",
32 |     "nonliving26",
33 |     "office31",
34 |     "officehome",
35 |     "visda",
36 | ]
37 | TARGET_SETS = {
38 |     "cifar10": ["0", "1", "10", "23", "57", "71", "95"],
39 |     "cifar100": ["0", "4", "12", "43", "59", "82"],
40 |     "fmow": ["0", "1", "2"],
41 |     "iwildcams": ["0", "1", "2"],
42 |     "camelyon": ["0", "1", "2"],
43 |     "domainnet": ["0", "1", "2", "3"],
44 |     "entity13": ["0", "1", "2", "3"],
45 |     "entity30": ["0", "1", "2", "3"],
46 |     "living17": ["0", "1", "2", "3"],
47 |     "nonliving26": ["0", "1", "2", "3"],
48 |     "officehome": ["0", "1", "2", "3"],
49 |     "office31": ["0", "1", "2"],
50 |     "visda": ["0", "1", "2"],
51 | }
52 | 
53 | SEEDS = ["1234", "42"]
54 | ALPHA = ["0.5", "1.0", "3.0", "10.0", "100.0"]
55 | ALGORITHMS = [
56 |     "DANN",
57 |     "IW-DANN",
58 |     "IW-CDANN",
59 |     "FixMatch",
60 |     "CDANN",
61 |     "SENTRY",
62 |     "IS-DANN",
63 |     "IS-CDANN",
64 |     "IS-FixMatch",
65 | ]
66 | 
67 | procs = list()
68 | 
69 | for dataset in DATASETS:
70 |     for seed in SEEDS:
71 |         for alpha in ALPHA:
72 |             for algorithm in ALGORITHMS:
73 |                 for target_set in TARGET_SETS[dataset]:
74 |                     gpu_id = GPU_IDS[counter % NUM_GPUS]
75 | 
76 |                     cmd = f"CUDA_VISIBLE_DEVICES={gpu_id} python run_main.py --remote False \
77 |                         --dataset {dataset} --root_dir /home/sgarg2/data --seed {seed} \
78 |                         --transform image_none --algorithm {algorithm} --dirichlet_alpha {alpha} \
79 |                         --target_split {target_set} --use_target True --simulate_label_shift True"
80 | 
81 |                     print(cmd)
82 |                     procs.append(Popen(cmd, shell=True))
83 | 
84 |                     time.sleep(3)
85 | 
86 |                     counter += 1
87 | 
88 |                     if len(procs) == NUM_RUNS:
89 |                         wait = True
90 | 
91 |                         while wait:
92 |                             done, num = check_for_done(procs)
93 | 
94 |                             if done:
95 |                                 procs.pop(num)
96 |                                 wait = False
97 |                             else:
98 |                                 time.sleep(3)
99 | 
100 |                     print("\n \n \n \n --------------------------- \n \n \n \n")
101 |                     print(f"{date.today()} - {counter} runs completed")
102 |                     sys.stdout.flush()
103 |                     print("\n \n \n \n --------------------------- \n \n \n \n")
104 | 
105 | 
106 | for p in procs:
107 |     p.wait()
108 | procs = []
109 | 
--------------------------------------------------------------------------------
/scripts/run_tta.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import sys
4 | import time
5 | from collections import Counter
6 | from datetime import date
7 | from subprocess import Popen
8 | 
9 | 
10 | def check_for_done(l):
11 |     for i, p in enumerate(l):
12 |         if p.poll() is not None:
13 |             return True, i
14 |     return False, False
15 | 
16 | 
17 | NUM_RUNS = 32
18 | GPU_IDS = [0, 1, 2, 3, 4, 5, 6, 7]
19 | NUM_GPUS = len(GPU_IDS)
20 | counter = 0
21 | 
22 | DATASETS = [
23 |     "camelyon",
24 |     "fmow",
25 |     "domainnet",
26 |     "cifar10",
27 |     "cifar100",
28 |     "entity13",
29 |     "entity30",
30 |     "living17",
31 |     "nonliving26",
32 |     "officehome",
33 |     "visda",
34 | ]
35 | TARGET_SETS = {
36 |     "cifar10": ["0", "1", "10", "71", "95"],
37 |     "cifar100": ["0", "4", "12", "59", "82"],
38 |     "fmow": ["0", "1", "2"],
39 |     "iwildcams": ["0", "1", "2"],
40 |     "camelyon": ["0", "1", "2"],
41 |     "domainnet": ["0", "1", "2", "3"],
42 |     "entity13": ["0", "1", "2", "3"],
43 |     "entity30": ["0", "1", "2", "3"],
44 |     "living17": ["0", "1", "2", "3"],
45 |     "nonliving26": ["0", "1", "2", "3"],
46 |     "officehome": ["0", "1", "2", "3"],
47 |     "office31": ["0", "1", "2"],
48 |     "visda": ["0", "1", "2"],
49 | }
50 | 
51 | SEEDS = ["42", "1234"]
52 | ALPHA = ["0.5", "1.0", "3.0", "10.0", "100.0"]
53 | ALGORITHMS = ["TENT"]
54 | 
55 | SOURCE_FILE = {
56 |     "cifar10": "logs_consistent_erm/cifar10_seed\:%s/ERM-aug-imagenet_pretrained\:imagenet/",
57 |     "cifar100": "logs_consistent_erm/cifar100_seed\:%s/ERM-aug-imagenet_pretrained\:imagenet/",
58 |     "camelyon": "logs_consistent_erm/camelyon_seed\:%s/ERM-aug-rand_pretrained\:rand/",
59 |     "entity13": "logs_consistent_erm/entity13_seed\:%s/ERM-aug-rand_pretrained\:rand/",
60 |     "entity30": "logs_consistent_erm/entity30_seed\:%s/ERM-aug-rand_pretrained\:rand/",
61 |     "living17": "logs_consistent_erm/living17_seed\:%s/ERM-aug-rand_pretrained\:rand/",
62 |     "nonliving26": "logs_consistent_erm/nonliving26_seed\:%s/ERM-aug-rand_pretrained\:rand/",
63 |     "fmow": "logs_consistent_erm/fmow_seed\:%s/ERM-aug-imagenet_pretrained\:imagenet/",
64 |     "domainnet": "logs_consistent_erm/domainnet_seed\:%s/ERM-aug-imagenet_pretrained\:imagenet/",
65 |     "officehome": "logs_consistent_erm/officehome_seed\:%s/ERM-aug-imagenet_pretrained\:imagenet/",
66 |     "visda": "logs_consistent_erm/visda_seed\:%s/ERM-aug-imagenet_pretrained\:imagenet/",
67 | }
68 | 
69 | gpu_queue = list()
70 | procs = list()
71 | gpu_id = 0
72 | gpu_use = list()
73 | 
74 | for i in range(NUM_RUNS):
75 |     gpu_queue.append(i % NUM_GPUS)
76 | 
77 | for algorithm in ALGORITHMS:
78 |     for dataset in DATASETS:
79 |         for seed in SEEDS:
80 |             for alpha in ALPHA:
81 |                 for target_set in TARGET_SETS[dataset]:
82 |                     # gpu_id = GPU_IDS[counter % NUM_GPUS]
83 |                     gpu_id = gpu_queue.pop(0)
84 | 
85 |                     source_model_path = SOURCE_FILE[dataset] % (seed)
86 | 
87 |                     cmd = f"CUDA_VISIBLE_DEVICES={gpu_id} python run_main.py --remote False \
88 |                         --dataset {dataset} --root_dir /home/ubuntu/data --seed {seed} \
89 |                         --transform image_none --algorithm {algorithm} --test_time_adapt --use_source_model \
90 |                         --source_model_path={source_model_path} --dirichlet_alpha {alpha} \
91 |                         --target_split {target_set} --use_target True --simulate_label_shift True"
92 | 
93 |                     print(cmd)
94 | 
95 |                     procs.append(Popen(cmd, shell=True))
96 |                     gpu_use.append(gpu_id)
97 | 
98 |                     time.sleep(3)
99 | 
100 |                     counter += 1
101 | 
102 |                     if len(procs) == NUM_RUNS:
103 |                         wait = True
104 | 
105 |                         while wait:
106 |                             done, num = check_for_done(procs)
107 | 
108 |                             if done:
109 |                                 procs.pop(num)
110 |                                 wait = False
111 |                                 gpu_queue.append(gpu_use.pop(num))
112 |                             else:
113 |                                 time.sleep(3)
114 | 
115 |                     print("\n \n \n \n --------------------------- \n \n \n \n")
116 |                     print(f"{date.today()} - {counter} runs completed")
117 |                     sys.stdout.flush()
118 |                     print("\n \n \n \n --------------------------- \n \n \n \n")
119 | 
120 | for p in procs:
121 |     p.wait()
122 | procs = []
123 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/p-lambda/wilds/blob/main/setup.py
2 | 
3 | import os
4 | import sys
5 | 
6 | import setuptools
7 | 
8 | here = os.path.abspath(os.path.dirname(__file__))
9 | sys.path.insert(0, os.path.join(here, 'RLSbench'))
10 | 
11 | from version import __version__
12 | 
13 | print(f'Version {__version__}')
14 | 
15 | with open("README.md", "r", encoding="utf-8") as fh:
16 |     long_description = fh.read()
17 | 
18 | setuptools.setup(
19 |     name="RLSBench",
20 |     version=__version__,
21 |     author="Saurabh Garg",
22 |     author_email="sgarg2@andrew.cmu.edu",
23 |     description="Relaxed Label Shift benchmark",
24 |     long_description=long_description,
25 |     long_description_content_type="text/markdown",
26 |     install_requires = [
27 |         'numpy>=1.21.1',
28 |         'pandas>=1.1.0',
29 |         'pillow>=7.2.0',
30 |         'pytz>=2020.4',
31 |         'torch>=1.10.0',
32 |         'torchvision>=0.11.3',
33 |         'tqdm>=4.53.0',
34 |         'scikit-learn>=0.20.0',
35 |         'scipy>=1.5.4',
36 |         'cvxpy>=1.1.7',
37 |         'cvxopt>=1.3.0',
38 |         'transformers>=4.21',
39 |         'matplotlib>=3.5.1',
40 |         'networkx>=2.0',
41 |         'antialiased-cnns',
42 |         'folktables',
43 |         'clip @ git+https://github.com/openai/CLIP.git#egg=clip',
44 |         'calibration @ git+https://github.com/saurabhgarg1996/calibration.git#egg=calibration',
45 |         'wilds @ git+https://github.com/saurabhgarg1996/wilds.git#egg=wilds',
46 |         'robustness @ git+https://github.com/saurabhgarg1996/robustness.git#egg=robustness'
47 |     ],
48 |     dependency_links=[
49 |         'https://download.pytorch.org/whl/cu113',
50 |     ],
51 |     license='MIT',
52 |     packages=setuptools.find_packages(),
53 |     classifiers=[
54 |         'Topic :: Scientific/Engineering :: Artificial Intelligence',
55 |         'Intended Audience :: Science/Research',
56 |         "Programming Language :: Python :: 3",
57 |         "License :: OSI Approved :: MIT License",
58 |     ],
59 |     python_requires='>=3.6',
60 | )
61 | 
--------------------------------------------------------------------------------
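
The run scripts above (`eval_ERM.py`, `run_adapt.py`, `run_tta.py`) drive label-shift simulation through the `--simulate_label_shift` and `--dirichlet_alpha` flags. As a rough, self-contained sketch of what such a simulation involves (illustrative only, not the implementation in `RLSbench/label_shift_utils.py`; the function name and subsampling scheme here are assumptions), the target split can be subsampled so that its class marginal follows a Dirichlet draw:

```python
import numpy as np

def simulate_label_shift(labels, alpha, rng=None):
    """Subsample indices of `labels` so the class marginal follows Dirichlet(alpha).

    Hypothetical helper for illustration; not the repository's implementation.
    """
    rng = np.random.default_rng(rng)
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    # Draw a target class marginal from a symmetric Dirichlet prior.
    target_marginal = rng.dirichlet(alpha * np.ones(len(classes)))
    # Largest total sample size for which no class is oversampled.
    total = (counts / np.maximum(target_marginal, 1e-12)).min()
    keep = []
    for c, frac in zip(classes, target_marginal):
        idx = np.flatnonzero(labels == c)
        keep.append(rng.choice(idx, size=int(frac * total), replace=False))
    return np.concatenate(keep)

# Example: skew 10-class labels toward a Dirichlet(0.5) marginal.
y = np.random.default_rng(0).integers(0, 10, size=10000)
subsampled = simulate_label_shift(y, alpha=0.5, rng=0)
```

Smaller values of `alpha` concentrate the drawn marginal on a few classes (more severe shift), while larger values yield a nearly uniform target marginal, which is why the sweeps above range `ALPHA` from 0.5 up to 100.0.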