├── images └── main.jpg ├── .gitmodules ├── LICENSE ├── module ├── util.py ├── mlp.py ├── resnet.py └── activations.py ├── test.py ├── README.md ├── train.py ├── data ├── civilcomments.py ├── grouper.py ├── __init__.py └── util.py ├── util.py └── learner.py /images/main.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shengliu66/LC/HEAD/images/main.jpg -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "module/lc-loss"] 2 | path = module/lc-loss 3 | url = https://github.com/amazon-science/lc-loss.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Sheng Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /module/util.py: -------------------------------------------------------------------------------- 1 | ''' Modified from https://github.com/alinlab/LfF/blob/master/module/util.py ''' 2 | 3 | import torch.nn as nn 4 | from module.resnet import resnet20 5 | from module.mlp import * 6 | from torchvision.models import resnet18, resnet50 7 | from transformers import BertForSequenceClassification, BertConfig 8 | 9 | def get_model(model_tag, num_classes, bias = True): 10 | if model_tag == "ResNet20": 11 | return resnet20(num_classes) 12 | elif model_tag == "ResNet20_OURS": 13 | model = resnet20(num_classes) 14 | model.fc = nn.Linear(128, num_classes) 15 | return model 16 | elif model_tag == "ResNet18": 17 | print('bringing no pretrained resnet18 ...') 18 | model = resnet18(pretrained=False) 19 | model.fc = nn.Linear(512, num_classes) 20 | return model 21 | elif model_tag == "MLP": 22 | return MLP(num_classes=num_classes) 23 | elif model_tag == "mlp_DISENTANGLE": 24 | return MLP_DISENTANGLE(num_classes=num_classes, bias = bias) 25 | elif model_tag == "mlp_DISENTANGLE_EASY": 26 | return MLP_DISENTANGLE_EASY(num_classes=num_classes) 27 | elif model_tag == 'resnet_DISENTANGLE': 28 | print('bringing no pretrained resnet18 disentangle...') 29 | model = resnet18(pretrained=False) 30 | model.fc = nn.Linear(1024//2, num_classes) 31 | return model 32 | elif model_tag == 'resnet_DISENTANGLE_pretrained': 33 | print('bringing pretrained resnet18 disentangle...') 34 | model = resnet18(pretrained=True) 35 | model.fc = nn.Linear(1024//2, num_classes) 36 | return model 37 | 38 | elif model_tag == 'resnet_50_pretrained': 39 | print('bringing pretrained resnet50 for water bird...') 40 | model = resnet50(pretrained=True) 41 | model.fc = nn.Linear(2048, num_classes) 42 | return model 43 | elif model_tag == 'resnet_50': 44 | print('bringing pretrained resnet50 for water bird...') 45 | model = resnet50(pretrained=False) 46 | model.fc = nn.Linear(2048, num_classes) 47 | return model 48 | elif 'bert' in model_tag: 49 | if model_tag[-3:] == '_pt': 50 | model_name = model_tag[:-3] 51 | else: 52 | model_name = model_tag 53 | 54 | config_class = BertConfig 55 | model_class = BertForSequenceClassification 56 | config = config_class.from_pretrained(model_name, 57 | num_labels=num_classes, 58 | finetuning_task='civilcomments') 59 | model = model_class.from_pretrained(model_name, from_tf=False, 60 | config=config) 61 | model.activation_layer = 'bert.pooler.activation' 62 | return model 63 | else: 64 | raise NotImplementedError 65 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | from learner import Learner 5 | import argparse 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser(description='Avoiding spurious correlations via logit correction') 9 | 10 | # training 11 | parser.add_argument("--batch_size", help="batch_size", default=256, type=int) 12 | parser.add_argument("--lr",help='learning rate',default=1e-3, type=float) 13 | parser.add_argument("--weight_decay",help='weight_decay',default=0.0, type=float) 14 | parser.add_argument("--momentum",help='momentum',default=0.9, type=float) 15 | parser.add_argument("--num_workers", help="workers number", default=16, type=int) 16 | parser.add_argument("--exp", help='experiment name', default='Test', 
type=str) 17 | parser.add_argument("--device", help="cuda or cpu", default='cuda', type=str) 18 | parser.add_argument("--num_steps", help="# of iterations", default= 500 * 100, type=int) 19 | parser.add_argument("--target_attr_idx", help="target_attr_idx", default= 0, type=int) 20 | parser.add_argument("--bias_attr_idx", help="bias_attr_idx", default= 1, type=int) 21 | parser.add_argument("--dataset", help="data to train, [cmnist, cifar10, bffhq]", default= 'cmnist', type=str) 22 | parser.add_argument("--percent", help="percentage of conflict", default= "1pct", type=str) 23 | parser.add_argument("--use_lr_decay", action='store_true', help="whether to use learning rate decay") 24 | parser.add_argument("--lr_decay_step", help="learning rate decay steps", type=int, default=10000) 25 | parser.add_argument("--q", help="GCE parameter q", type=float, default=0.7) 26 | parser.add_argument("--lr_gamma", help="lr gamma", type=float, default=0.1) 27 | parser.add_argument("--lambda_dis_align", help="lambda_dis in Eq.2", type=float, default=1.0) 28 | parser.add_argument("--lambda_swap_align", help="lambda_swap_b in Eq.3", type=float, default=1.0) 29 | parser.add_argument("--lambda_swap", help="lambda swap (lambda_swap in Eq.4)", type=float, default=1.0) 30 | parser.add_argument("--ema_alpha", help="use weight mul", type=float, default=0.7) 31 | parser.add_argument("--curr_step", help="curriculum steps", type=int, default= 0) 32 | parser.add_argument("--use_type0", action='store_true', help="whether to use type 0 CIFAR10C") 33 | parser.add_argument("--use_type1", action='store_true', help="whether to use type 1 CIFAR10C") 34 | parser.add_argument("--use_resnet20", help="Use Resnet20", action="store_true") # ResNet 20 was used in Learning From Failure CifarC10 (We used ResNet18 in our paper) 35 | parser.add_argument("--model", help="which network, [MLP, ResNet18, ResNet20, ResNet50]", default= 'MLP', type=str) 36 | 37 | # logging 38 | parser.add_argument("--log_dir", help='path for saving models & logs', default='./log', type=str) 39 | parser.add_argument("--data_dir", help='path for loading data', default='dataset', type=str) 40 | parser.add_argument("--valid_freq", help='frequency to evaluate on valid/test set', default=500, type=int) 41 | parser.add_argument("--log_freq", help='frequency to log on tensorboard', default=500, type=int) 42 | parser.add_argument("--save_freq", help='frequency to save model checkpoint', default=1000, type=int) 43 | parser.add_argument("--wandb", action="store_true", help="whether to use wandb") 44 | parser.add_argument("--tensorboard", action="store_true", help="whether to use tensorboard") 45 | 46 | parser.add_argument("--tau", help="loss tau", type=float, default=1) 47 | parser.add_argument("--avg_type", help="pya estimation types", type=str, default='mv') 48 | 49 | 50 | # experiment 51 | parser.add_argument("--pretrained_path", help="path for pretrained model", type=str) 52 | 53 | args = parser.parse_args() 54 | 55 | # init learner 56 | learner = Learner(args) 57 | 58 | # actual testing 59 | print('Official Pytorch Code of "Avoiding spurious correlations via logit correction"') 60 | print('Test starts ...') 61 | 62 | learner.test_ours(args) 63 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Avoiding spurious correlations via logit correction 2 | This repository provides the official PyTorch implementation of the following paper: 3 | > 
Avoiding spurious correlations via logit correction
4 | > [Sheng Liu](https://shengliu66.github.io/) (NYU), Xu Zhang (Amazon), Nitesh Sekhar (Amazon), Yue Wu (Amazon), Prateek Singhal (Amazon), Carlos Fernandez-Granda (NYU) 5 | > ICLR 2023
6 | 7 | > Paper: [Arxiv](https://arxiv.org/abs/2212.01433)
8 | 9 | **Abstract:** 10 | *Empirical studies suggest that machine learning models trained with empirical risk minimization (ERM) often rely on attributes that may be spuriously correlated with the class labels. Such models typically lead to poor performance during inference for data lacking such correlations. In this work, we explicitly consider a situation where potential spurious correlations are present in the majority of training data. In contrast with existing approaches, which use the ERM model outputs to detect the samples without spurious correlations and either heuristically upweight or upsample those samples, we propose the logit correction (LC) loss, a simple yet effective improvement on the softmax cross-entropy loss, to correct the sample logit. We demonstrate that minimizing the LC loss is equivalent to maximizing the group-balanced accuracy, so the proposed LC could mitigate the negative impacts of spurious correlations. Our extensive experimental results further reveal that the proposed LC loss outperforms state-of-the-art solutions on multiple popular benchmarks by a large margin, an average 5.5\% absolute improvement, without access to spurious attribute labels. LC is also competitive with oracle methods that make use of the attribute labels.*
11 | 12 |

13 | 14 |

15 | 16 | ## PyTorch Implementation 17 | ### Installation 18 | Clone this repository. 19 | ``` 20 | git clone --recursive https://github.com/shengliu66/LC-private 21 | cd LC-private 22 | pip install -r requirements.txt 23 | ``` 24 | ### Datasets 25 | - **Waterbirds:** Download the Waterbirds dataset from [here](https://nlp.stanford.edu/data/dro/waterbird_complete95_forest2water2.tar.gz). 26 | - In that directory, our code expects `data/waterbird_complete95_forest2water2/` with `metadata.csv` inside. 27 | - Specify this directory with `--data_dir`. 28 | 29 | - **CivilComments:** Download the CivilComments dataset from [here](https://worksheets.codalab.org/rest/bundles/0x8cd3de0634154aeaad2ee6eb96723c6e/contents/blob/). 30 | - In that directory, our code expects a folder `data` with the downloaded dataset. 31 | - Specify this directory with `--data_dir`. 32 | 33 | ### Running LC 34 | #### Waterbird 35 | ``` 36 | python train.py --dataset waterbird --exp=waterbird_ours --lr=1e-3 --weight_decay=1e-4 --curr_epoch=50 --lr_decay_epoch 50 --use_lr_decay --lambda_dis_align=2.0 --ema_alpha 0.5 --tau 0.1 --train_ours --q 0.8 --avg_type batch --data_dir /dir/to/data/ 37 | ``` 38 | #### CivilComments 39 | ``` 40 | python train.py --dataset civilcomments --exp=civilcomments_ours --lr=1e-5 --q 0.7 --log_freq 400 --valid_freq 400 --weight_decay=1e-2 --curr_step 400 --use_lr_decay --num_epochs 3 --lr_decay_step 4000 --lambda_dis_align=0.1 --ema_alpha 0.9 --tau 1.0 --avg_type mv_batch --train_ours --data_dir /dir/to/data/ 41 | ``` 42 | 43 | ## Monitoring Performance 44 | We use [Weights & Biases](https://wandb.ai/site) to monitor training. Follow the [quickstart guide](https://docs.wandb.ai/quickstart) to install and log in to W&B, then add the `--wandb` argument. 45 | 46 | 47 | ### Evaluate Models 48 | To test our pretrained models, run the following command (the angle-bracketed values are placeholders for your checkpoint path and dataset name). 49 | ``` 50 | python test.py --pretrained_path=<path/to/checkpoint> --dataset=<dataset> 51 | ``` 52 | 53 | ### Citations 54 | 55 | ### BibTeX 56 | ```bibtex 57 | @Inproceedings{ 58 | Liu2023, 59 | author = {Sheng Liu and Xu Zhang and Nitesh Sekhar and Yue Wu and Prateek Singhal and Carlos Fernandez-Granda}, 60 | title = {Avoiding spurious correlations via logit correction}, 61 | year = {2023}, 62 | url = {https://www.amazon.science/publications/avoiding-spurious-correlations-via-logit-correction}, 63 | booktitle = {ICLR 2023}, 64 | } 65 | ``` 66 | 67 | ### Contact 68 | Sheng Liu (shengliu@nyu.edu) 69 | 70 | ### Acknowledgments 71 | This work was mainly done while the authors were doing internships at Amazon Science. 72 | Our PyTorch implementation is based on [Disentangled](https://github.com/kakaoenterprise/Learning-Debiased-Disentangled). 73 | We would like to thank the authors for making their code public. 
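### Appendix: logit correction in a nutshell
For readers who want a concrete picture of the idea summarized in the abstract, the snippet below is a minimal, illustrative sketch in the spirit of logit adjustment: each sample's logits are shifted by a (temperature-scaled) log prior that reflects how often its class co-occurs with the estimated spurious group, before applying the usual softmax cross-entropy. This is **not** the exact LC implementation in `learner.py`; the tensor names, shapes, and the uniform placeholder prior are assumptions made purely for illustration.
```python
import torch
import torch.nn.functional as F

def logit_corrected_ce(logits, targets, log_prior, tau=1.0):
    """Cross-entropy on prior-corrected logits (illustrative sketch only)."""
    # Shift the logits by the scaled log prior before softmax cross-entropy;
    # tau plays the role of the --tau temperature flag.
    return F.cross_entropy(logits + tau * log_prior, targets)

# Toy usage with random tensors (shapes only; the values are meaningless).
logits = torch.randn(8, 2)                      # (batch, num_classes)
targets = torch.randint(0, 2, (8,))             # ground-truth class labels
log_prior = torch.log(torch.full((8, 2), 0.5))  # placeholder uniform prior
loss = logit_corrected_ce(logits, targets, log_prior, tau=1.0)
```
In the actual method the prior is estimated online during training (e.g., with the exponential-moving-average utilities in `util.py`, controlled by `--avg_type` and `--ema_alpha`) rather than assumed uniform as in this toy example.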
74 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | from learner import Learner 5 | import argparse 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser(description='Avoiding spurious correlations via logit correction') 9 | 10 | # training 11 | parser.add_argument("--batch_size", help="batch_size", default=256, type=int) 12 | parser.add_argument("--lr",help='learning rate',default=1e-3, type=float) 13 | parser.add_argument("--weight_decay",help='weight_decay',default=0.0, type=float) 14 | parser.add_argument("--momentum",help='momentum',default=0.9, type=float) 15 | parser.add_argument("--num_workers", help="workers number", default=4, type=int) 16 | parser.add_argument("--exp", help='experiment name', default='debugging', type=str) 17 | parser.add_argument("--device", help="cuda or cpu", default='cuda', type=str) 18 | parser.add_argument("--num_steps", help="# of iterations", default= 500 * 100, type=int) 19 | parser.add_argument("--num_epochs", help="# of epochs", default= 300, type=int) 20 | parser.add_argument("--target_attr_idx", help="target_attr_idx", default= 0, type=int) 21 | parser.add_argument("--bias_attr_idx", help="bias_attr_idx", default= 1, type=int) 22 | parser.add_argument("--dataset", help="data to train, [cmnist, cifar10, bffhq]", default= 'cmnist', type=str) 23 | parser.add_argument("--percent", help="percentage of conflict", default= "1pct", type=str) 24 | parser.add_argument("--use_lr_decay", action='store_true', help="whether to use learning rate decay") 25 | parser.add_argument("--lr_decay_step", help="learning rate decay steps", type=int, default=10000) 26 | parser.add_argument("--lr_decay_epoch", help="learning rate decay epochs", type=int, default=70) 27 | parser.add_argument("--q", help="GCE parameter q", type=float, default=0.7) 28 | parser.add_argument("--lr_gamma", help="lr gamma", type=float, default=0.1) 29 | parser.add_argument("--lambda_dis_align", help="lambda_dis in Eq.2", type=float, default=1.0) 30 | parser.add_argument("--lambda_swap_align", help="lambda_swap_b in Eq.3", type=float, default=1.0) 31 | parser.add_argument("--lambda_swap", help="lambda swap (lambda_swap in Eq.4)", type=float, default=1.0) 32 | parser.add_argument("--ema_alpha", help="use weight mul", type=float, default=0.995) 33 | parser.add_argument("--curr_step", help="curriculum steps", type=int, default= 0) 34 | parser.add_argument("--curr_epoch", help="curriculum epochs", type=int, default= 0) 35 | parser.add_argument("--use_type0", action='store_true', help="whether to use type 0 CIFAR10C") 36 | parser.add_argument("--use_type1", action='store_true', help="whether to use type 1 CIFAR10C") 37 | parser.add_argument("--use_resnet20", help="Use Resnet20", action="store_true") # ResNet 20 was used in Learning From Failure CifarC10 (We used ResNet18 in our paper) 38 | parser.add_argument("--model", help="which network, [MLP, ResNet18, ResNet20, ResNet50]", default= 'MLP', type=str) 39 | 40 | # logging 41 | parser.add_argument("--log_dir", help='path for saving model', default='./log', type=str) 42 | parser.add_argument("--data_dir", help='path for loading data', default='dataset', type=str) 43 | parser.add_argument("--valid_freq", help='frequency to evaluate on valid/test set', default=500, type=int) 44 | parser.add_argument("--log_freq", help='frequency to log on tensorboard', 
default=500, type=int) 45 | parser.add_argument("--save_freq", help='frequency to save model checkpoint', default=1000, type=int) 46 | parser.add_argument("--wandb", action="store_true", help="whether to use wandb") 47 | parser.add_argument("--tensorboard", action="store_true", help="whether to use tensorboard") 48 | 49 | # experiment 50 | parser.add_argument("--train_ours", action="store_true", help="whether to train our method") 51 | parser.add_argument("--train_vanilla", action="store_true", help="whether to train vanilla") 52 | 53 | 54 | parser.add_argument("--alpha", help="mixup alpha", type=float, default=16) 55 | parser.add_argument('--proto_m', default=0.95, type=float, 56 | help='momentum for computing the momving average of prototypes') 57 | parser.add_argument('--temperature', default=0.1, type=float, 58 | help='contrastive temperature') 59 | parser.add_argument("--tau", help="loss tau", type=float, default=1) 60 | parser.add_argument("--avg_type", help="pya estimation types", type=str, default='mv') 61 | 62 | args = parser.parse_args() 63 | 64 | seed = 222 65 | torch.manual_seed(seed) 66 | torch.cuda.manual_seed(seed) 67 | torch.cuda.manual_seed_all(seed) 68 | np.random.seed(seed) 69 | random.seed(seed) 70 | torch.backends.cudnn.benchmark = False 71 | torch.backends.cudnn.deterministic = True 72 | 73 | # init learner 74 | learner = Learner(args) 75 | 76 | 77 | # actual training 78 | print('Official Pytorch Code of "Avoiding spurious correlations via logit correction"') 79 | print('Training starts ...') 80 | 81 | if args.train_ours: 82 | learner.train_ours(args) 83 | elif args.train_vanilla: 84 | learner.train_vanilla(args) 85 | else: 86 | print('choose one of the two options ...') 87 | import sys 88 | sys.exit(0) 89 | -------------------------------------------------------------------------------- /module/mlp.py: -------------------------------------------------------------------------------- 1 | ''' Modified from https://github.com/alinlab/LfF/blob/master/module/mlp.py''' 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | class MLP_DISENTANGLE(nn.Module): 8 | def __init__(self, num_classes = 10, bias = True): 9 | super(MLP_DISENTANGLE, self).__init__() 10 | self.feature = nn.Sequential( 11 | nn.Linear(3*28*28, 100, bias = bias), 12 | nn.ReLU(), 13 | nn.Linear(100, 100, bias = bias), 14 | nn.ReLU(), 15 | nn.Linear(100, 16, bias = bias), 16 | nn.ReLU(), 17 | ) 18 | self.fc = nn.Sequential( 19 | nn.Linear(16, num_classes)) 20 | 21 | # self.projection_head = nn.Sequential( 22 | # nn.ReLU(), 23 | # nn.Linear(16,16, bias = bias)) 24 | 25 | 26 | def extract(self, x): 27 | x = x.view(x.size(0), -1) / 255 28 | feat = self.feature(x) 29 | # feat = self.projection_head(x) 30 | return feat 31 | 32 | def extract_rawfeature(self, x): 33 | x = x.view(x.size(0), -1) / 255 34 | feat = self.feature(x) 35 | return feat 36 | 37 | def predict(self, x): 38 | prediction = self.fc(x) 39 | return prediction 40 | 41 | def forward(self, x, mode=None, return_feat=False): 42 | x = x.view(x.size(0), -1) / 255 43 | feat = x = self.feature(x) 44 | final_x = self.fc(x) 45 | if mode == 'tsne' or mode == 'mixup': 46 | return x, final_x 47 | else: 48 | if return_feat: 49 | return final_x, feat 50 | else: 51 | return final_x 52 | 53 | 54 | 55 | 56 | class MLP_DISENTANGLE_EASY(nn.Module): 57 | def __init__(self, num_classes = 10): 58 | super(MLP_DISENTANGLE_EASY, self).__init__() 59 | self.feature = nn.Sequential( 60 | nn.Linear(3*28*28, 50), 61 | nn.ReLU(), 62 | 
nn.Linear(50, 50), 63 | nn.ReLU(), 64 | nn.Linear(50, 16), 65 | nn.ReLU() 66 | ) 67 | self.fc = nn.Linear(32, num_classes) 68 | self.fc2 = nn.Linear(16, num_classes) 69 | 70 | def extract(self, x): 71 | x = x.view(x.size(0), -1) / 255 72 | feat = self.feature(x) 73 | return feat 74 | 75 | def predict(self, x): 76 | prediction = self.classifier(x) 77 | return prediction 78 | 79 | def forward(self, x, mode=None, return_feat=False): 80 | x = x.view(x.size(0), -1) / 255 81 | feat = x = self.feature(x) 82 | final_x = self.classifier(x) 83 | if mode == 'tsne' or mode == 'mixup': 84 | return x, final_x 85 | else: 86 | if return_feat: 87 | return final_x, feat 88 | else: 89 | return final_x 90 | 91 | 92 | 93 | 94 | class MLP(nn.Module): 95 | def __init__(self, num_classes = 10): 96 | super(MLP, self).__init__() 97 | self.feature = nn.Sequential( 98 | nn.Linear(3*28*28, 100), 99 | nn.ReLU(), 100 | nn.Linear(100, 100), 101 | nn.ReLU(), 102 | nn.Linear(100, 16), 103 | nn.ReLU() 104 | ) 105 | self.classifier = nn.Linear(16, num_classes) 106 | 107 | 108 | def forward(self, x, mode=None, return_feat=False): 109 | x = x.view(x.size(0), -1) / 255 110 | feat = x = self.feature(x) 111 | final_x = self.classifier(x) 112 | if mode == 'tsne' or mode == 'mixup': 113 | return x, final_x 114 | else: 115 | if return_feat: 116 | return final_x, feat 117 | else: 118 | return final_x 119 | 120 | class MLP_DISENTANGLE_SHENG(nn.Module): 121 | def __init__(self, num_classes = 10): 122 | super(MLP_DISENTANGLE_SHENG, self).__init__() 123 | self.feature = nn.Sequential( 124 | nn.Linear(3*28*28, 100), 125 | nn.ReLU(), 126 | nn.Linear(100, 100), 127 | nn.ReLU(), 128 | nn.Linear(100, 16), 129 | nn.ReLU() 130 | ) 131 | 132 | self.task_head_target = nn.Sequential( 133 | nn.Linear(16,16), 134 | nn.ReLU(), 135 | # nn.Linear(16,16), 136 | # nn.ReLU(), 137 | ) 138 | 139 | self.task_head_bias = nn.Sequential( 140 | nn.Linear(16,16), 141 | nn.ReLU(), 142 | # nn.Linear(16,16), 143 | # nn.ReLU(), 144 | ) 145 | 146 | self.fc_target = nn.Linear(16, num_classes) 147 | self.fc_bias = nn.Linear(16, num_classes) 148 | # self.classifier_bias = nn.Linear(16, num_classes) 149 | 150 | 151 | def extract_target(self, x): 152 | x = x.view(x.size(0), -1) / 255 153 | x = self.feature(x) 154 | feat = self.task_head_target(x) 155 | return feat 156 | 157 | def extract_bias(self, x): 158 | x = x.view(x.size(0), -1) / 255 159 | x = self.feature(x) 160 | feat = self.task_head_bias(x) 161 | return feat 162 | 163 | def predict_target(self, x): 164 | prediction = self.fc_target(x) 165 | return prediction 166 | 167 | def predict_bias(self, x): 168 | prediction = self.fc_bias(x) 169 | return prediction 170 | 171 | def forward(self, x, mode=None, return_feat=False): 172 | x = x.view(x.size(0), -1) / 255 173 | feat = x = self.feature(x) 174 | feat_target = self.task_head_target(x) 175 | feat_bias = self.task_head_bias(x) 176 | 177 | final_x_target = self.fc_target(feat_target) 178 | final_x_bias = self.fc_bias(feat_bias) 179 | 180 | if mode == 'tsne' or mode == 'mixup': 181 | return feat_target, feat_bias, final_x_target, final_x_bias 182 | else: 183 | if return_feat: 184 | return final_x_target, feat_target, final_x_bias, feat_bias 185 | else: 186 | return final_x_target, final_x_bias 187 | 188 | 189 | class Noise_MLP(nn.Module): 190 | def __init__(self, n_dim=16, n_layer=3): 191 | super(Noise_MLP, self).__init__() 192 | 193 | layers = [] 194 | for i in range(n_layer): 195 | layers.append(nn.Linear(n_dim, n_dim)) 196 | layers.append(nn.LeakyReLU(0.2)) 197 | 198 
| self.style = nn.Sequential(*layers) 199 | 200 | def forward(self, z): 201 | x = self.style(z) 202 | return x 203 | -------------------------------------------------------------------------------- /data/civilcomments.py: -------------------------------------------------------------------------------- 1 | """ 2 | CivilComments Dataset 3 | - Reference code: https://github.com/p-lambda/wilds/blob/main/wilds/datasets/civilcomments_dataset.py 4 | - See WILDS, https://wilds.stanford.edu for more 5 | """ 6 | import os 7 | import numpy as np 8 | import pandas as pd 9 | 10 | import torch 11 | from torch.utils.data import Dataset, DataLoader 12 | from transformers import BertTokenizerFast 13 | 14 | from data.grouper import CombinatorialGrouper 15 | 16 | 17 | class CivilComments(Dataset): 18 | """ 19 | CivilComments dataset 20 | """ 21 | def __init__(self, root_dir, 22 | target_name='toxic', confounder_names=['identities'], 23 | split='train', transform=None): 24 | self.root_dir = root_dir 25 | self.target_name = target_name 26 | self.confounder_names = confounder_names 27 | self.transform = transform 28 | self.split = split 29 | 30 | # Labels 31 | self.class_names = ['non_toxic', 'toxic'] 32 | 33 | # Set up data directories 34 | self.data_dir = os.path.join(self.root_dir) 35 | if not os.path.exists(self.data_dir): 36 | raise ValueError( 37 | f'{self.data_dir} does not exist yet. Please generate the dataset first.') 38 | 39 | # Read in metadata 40 | type_of_split = self.target_name.split('_')[-1] 41 | self.metadata_df = pd.read_csv( 42 | os.path.join(self.data_dir, 'all_data_with_identities.csv'), 43 | index_col=0) 44 | 45 | # Get split 46 | self.split_array = self.metadata_df['split'].values 47 | self.metadata_df = self.metadata_df[ 48 | self.metadata_df['split'] == split] 49 | 50 | # Get the y values 51 | self.y_array = torch.LongTensor( 52 | self.metadata_df['toxicity'].values >= 0.5) 53 | self.y_size = 1 54 | self.n_classes = 2 55 | 56 | # Get text 57 | self.x_array = np.array(self.metadata_df['comment_text']) 58 | 59 | # Get confounders and map to groups 60 | self._identity_vars = ['male', 61 | 'female', 62 | 'LGBTQ', 63 | 'christian', 64 | 'muslim', 65 | 'other_religions', 66 | 'black', 67 | 'white'] 68 | self._auxiliary_vars = ['identity_any', 69 | 'severe_toxicity', 70 | 'obscene', 71 | 'threat', 72 | 'insult', 73 | 'identity_attack', 74 | 'sexual_explicit'] 75 | self.metadata_array = torch.cat( 76 | (torch.LongTensor((self.metadata_df.loc[:, self._identity_vars] >= 0.5).values), 77 | torch.LongTensor((self.metadata_df.loc[:, self._auxiliary_vars] >= 0.5).values), 78 | self.y_array.reshape((-1, 1))), dim=1) 79 | 80 | self.metadata_fields = self._identity_vars + self._auxiliary_vars + ['y'] 81 | self.confounder_array = self.metadata_array[:, np.arange(len(self._identity_vars))] 82 | self.metadata_map = None 83 | 84 | self._eval_groupers = [ 85 | CombinatorialGrouper( 86 | dataset=self, 87 | groupby_fields=[identity_var, 'y']) 88 | for identity_var in self._identity_vars] 89 | 90 | # Get sub_targets / group_idx 91 | groupby_fields = self._identity_vars + ['y'] 92 | self.eval_grouper = CombinatorialGrouper(self, groupby_fields) 93 | self.group_array = self.eval_grouper.metadata_to_group(self.metadata_array, 94 | return_counts=False) 95 | self.n_groups = len(np.unique(self.group_array)) 96 | 97 | # Get spurious labels 98 | self.spurious_grouper = CombinatorialGrouper(self, 99 | self._identity_vars) 100 | self.spurious_array = self.spurious_grouper.metadata_to_group( 101 | self.metadata_array, 
return_counts=False).numpy() 102 | 103 | # Get consistent label attributes 104 | self.targets = self.y_array 105 | 106 | unique_group_ix = np.unique(self.spurious_array) 107 | group_ix_to_label = {} 108 | for i, gix in enumerate(unique_group_ix): 109 | group_ix_to_label[gix] = i 110 | spurious_labels = [group_ix_to_label[int(s)] 111 | for s in self.spurious_array] 112 | 113 | self.targets_all = {'target': np.array(self.y_array), 114 | 'group_idx': np.array(self.group_array), 115 | 'spurious': np.array(spurious_labels), 116 | 'sub_target': np.array(self.metadata_array[:, self.eval_grouper.groupby_field_indices]), 117 | 'metadata': np.array(self.metadata_array)} 118 | self.group_labels = [self.group_str(i) for i in range(self.n_groups)] 119 | 120 | def __len__(self): 121 | return len(self.y_array) 122 | 123 | def __getitem__(self, idx): 124 | x = self.x_array[idx] 125 | y = self.y_array[idx] 126 | p = self.targets_all['spurious'][idx] 127 | g = self.group_array[idx] 128 | if self.transform is not None: 129 | x = self.transform(x) 130 | 131 | attr = torch.LongTensor( 132 | [y, p, g]) 133 | 134 | 135 | if self.split != 'train': 136 | return x, attr, 0, -1 137 | else: 138 | return x, attr, 0, int((p == 3) and (y == 1)) 139 | 140 | 141 | 142 | 143 | # return (x, y, idx) # g 144 | 145 | def group_str(self, group_idx): 146 | return self.eval_grouper.group_str(group_idx) 147 | 148 | def get_text(self, idx): 149 | return self.x_array[idx] 150 | 151 | 152 | 153 | def init_bert_transform(tokenizer, model_name, max_token_length): 154 | """ 155 | Inspired from the WILDS dataset: 156 | - https://github.com/p-lambda/wilds/blob/main/examples/transforms.py 157 | """ 158 | def transform(text): 159 | tokens = tokenizer(text, padding='max_length', 160 | truncation=True, 161 | max_length=max_token_length,#args.max_token_length, # 300 162 | return_tensors='pt') 163 | if model_name == 'bert-base-uncased': 164 | x = torch.stack((tokens['input_ids'], 165 | tokens['attention_mask'], 166 | tokens['token_type_ids']), dim=2) 167 | # Not supported for now 168 | elif model_name == 'distilbert-base-uncased': 169 | x = torch.stack((tokens['input_ids'], 170 | tokens['attention_mask']), dim=2) 171 | x = torch.squeeze(x, dim=0) # First shape dim is always 1 172 | return x 173 | return transform -------------------------------------------------------------------------------- /module/resnet.py: -------------------------------------------------------------------------------- 1 | ''' From https://github.com/alinlab/LfF/blob/master/module/resnet.py ''' 2 | 3 | """ 4 | Properly implemented ResNet-s for CIFAR10 as described in paper [1]. 5 | The implementation and structure of this file is hugely influenced by [2] 6 | which is implemented for ImageNet and doesn't have option A for identity. 7 | Moreover, most of the implementations on the web is copy-paste from 8 | torchvision's resnet and has wrong number of params. 9 | Proper ResNet-s for CIFAR10 (for fair comparision and etc.) has following 10 | number of layers and parameters: 11 | name | layers | params 12 | ResNet20 | 20 | 0.27M 13 | ResNet32 | 32 | 0.46M 14 | ResNet44 | 44 | 0.66M 15 | ResNet56 | 56 | 0.85M 16 | ResNet110 | 110 | 1.7M 17 | ResNet1202| 1202 | 19.4m 18 | which this implementation indeed has. 19 | Reference: 20 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 21 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 22 | [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 23 | If you use this implementation in you work, please don't forget to mention the 24 | author, Yerlan Idelbayev. 25 | """ 26 | import torch 27 | import torch.nn as nn 28 | import torch.nn.functional as F 29 | import torch.nn.init as init 30 | 31 | from torch.autograd import Variable 32 | 33 | __all__ = [ 34 | "ResNet", 35 | "resnet20", 36 | "resnet32", 37 | "resnet44", 38 | "resnet56", 39 | "resnet110", 40 | "resnet1202", 41 | ] 42 | 43 | 44 | def _weights_init(m): 45 | classname = m.__class__.__name__ 46 | # print(classname) 47 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 48 | init.kaiming_normal_(m.weight) 49 | 50 | 51 | class LambdaLayer(nn.Module): 52 | def __init__(self, lambd): 53 | super(LambdaLayer, self).__init__() 54 | self.lambd = lambd 55 | 56 | def forward(self, x): 57 | return self.lambd(x) 58 | 59 | 60 | class BasicBlock(nn.Module): 61 | expansion = 1 62 | 63 | def __init__(self, in_planes, planes, stride=1, option="A"): 64 | super(BasicBlock, self).__init__() 65 | self.conv1 = nn.Conv2d( 66 | in_planes, 67 | planes, 68 | kernel_size=3, 69 | stride=stride, 70 | padding=1, 71 | bias=False, 72 | ) 73 | self.bn1 = nn.BatchNorm2d(planes) 74 | self.conv2 = nn.Conv2d( 75 | planes, planes, kernel_size=3, stride=1, padding=1, bias=False 76 | ) 77 | self.bn2 = nn.BatchNorm2d(planes) 78 | 79 | self.shortcut = nn.Sequential() 80 | if stride != 1 or in_planes != planes: 81 | if option == "A": 82 | """ 83 | For CIFAR10 ResNet paper uses option A. 84 | """ 85 | self.shortcut = LambdaLayer( 86 | lambda x: F.pad( 87 | x[:, :, ::2, ::2], 88 | (0, 0, 0, 0, planes // 4, planes // 4), 89 | "constant", 90 | 0, 91 | ) 92 | ) 93 | elif option == "B": 94 | self.shortcut = nn.Sequential( 95 | nn.Conv2d( 96 | in_planes, 97 | self.expansion * planes, 98 | kernel_size=1, 99 | stride=stride, 100 | bias=False, 101 | ), 102 | nn.BatchNorm2d(self.expansion * planes), 103 | ) 104 | 105 | def forward(self, x): 106 | out = F.relu(self.bn1(self.conv1(x))) 107 | out = self.bn2(self.conv2(out)) 108 | out += self.shortcut(x) 109 | out = F.relu(out) 110 | return out 111 | 112 | 113 | class ResNet(nn.Module): 114 | def __init__(self, block, num_blocks, num_classes=10): 115 | super(ResNet, self).__init__() 116 | self.in_planes = 16 117 | # print('!!!1'*100) 118 | 119 | self.conv1 = nn.Conv2d( 120 | 3, 16, kernel_size=3, stride=1, padding=1, bias=False 121 | ) 122 | self.bn1 = nn.BatchNorm2d(16) 123 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 124 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 125 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 126 | self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1)) 127 | self.fc = nn.Linear(64, num_classes) 128 | self.fc2 = nn.Linear(64, num_classes) 129 | 130 | self.projection_head = nn.Sequential( 131 | # nn.ReLU(), 132 | nn.Linear(64,16), 133 | # nn.ReLU(), 134 | nn.Linear(16,64)) 135 | 136 | self.apply(_weights_init) 137 | 138 | def _make_layer(self, block, planes, num_blocks, stride): 139 | strides = [stride] + [1] * (num_blocks - 1) 140 | layers = [] 141 | for stride in strides: 142 | layers.append(block(self.in_planes, planes, stride)) 143 | self.in_planes = planes * block.expansion 144 | 145 | return nn.Sequential(*layers) 146 | 147 | def extract(self, x): 148 | out = F.relu(self.bn1(self.conv1(x))) 149 | out = self.layer1(out) 150 | out = self.layer2(out) 151 | out = self.layer3(out) 152 
| out = F.avg_pool2d(out, out.size()[3]) 153 | feat = out.view(out.size(0), -1) 154 | 155 | return feat 156 | 157 | def predict(self, x): 158 | prediction = self.fc(x) 159 | return prediction 160 | 161 | def forward(self, x, mode=None): 162 | out = F.relu(self.bn1(self.conv1(x))) 163 | out = self.layer1(out) 164 | out = self.layer2(out) 165 | out = self.layer3(out) 166 | # out = F.avg_pool2d(out, out.size()[3]) 167 | # out = out.view(out.size(0), -1) 168 | out = self.avgpool(out) 169 | out = out.view(out.size(0), -1) 170 | final_out = self.fc(out) 171 | if mode == 'tsne' or mode == 'mixup': 172 | return out, final_out 173 | else: 174 | return final_out 175 | 176 | 177 | def resnet20(num_classes): 178 | print('!!!!!'*100) 179 | return ResNet(BasicBlock, [3, 3, 3], num_classes) 180 | 181 | 182 | def resnet32(): 183 | return ResNet(BasicBlock, [5, 5, 5]) 184 | 185 | 186 | def resnet44(): 187 | return ResNet(BasicBlock, [7, 7, 7]) 188 | 189 | 190 | def resnet56(): 191 | return ResNet(BasicBlock, [9, 9, 9]) 192 | 193 | 194 | def resnet110(): 195 | return ResNet(BasicBlock, [18, 18, 18]) 196 | 197 | 198 | def resnet1202(): 199 | return ResNet(BasicBlock, [200, 200, 200]) 200 | 201 | 202 | def test(net): 203 | import numpy as np 204 | 205 | total_params = 0 206 | 207 | for x in filter(lambda p: p.requires_grad, net.parameters()): 208 | total_params += np.prod(x.data.numpy().shape) 209 | print("Total number of params", total_params) 210 | print( 211 | "Total layers", 212 | len( 213 | list( 214 | filter( 215 | lambda p: p.requires_grad and len(p.data.size()) > 1, 216 | net.parameters(), 217 | ) 218 | ) 219 | ), 220 | ) 221 | 222 | 223 | 224 | if __name__ == "__main__": 225 | for net_name in __all__: 226 | if net_name.startswith("resnet"): 227 | print(net_name) 228 | test(globals()[net_name]()) 229 | print() 230 | -------------------------------------------------------------------------------- /data/grouper.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset grouer for subgroup and group_ix information 3 | - Used by CivilComments 4 | 5 | From WILDS: https://github.com/p-lambda/wilds/blob/main/wilds/common/grouper.py 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | # from wilds.common.utils import get_counts 11 | # from wilds.datasets.wilds_dataset import WILDSSubset 12 | import warnings 13 | 14 | 15 | def get_counts(g, n_groups): 16 | """ 17 | This differs from split_into_groups in how it handles missing groups. 18 | get_counts always returns a count Tensor of length n_groups, 19 | whereas split_into_groups returns a unique_counts Tensor 20 | whose length is the number of unique groups present in g. 21 | Args: 22 | - g (Tensor): Vector of groups 23 | Returns: 24 | - counts (Tensor): A list of length n_groups, denoting the count of each group. 25 | """ 26 | unique_groups, unique_counts = torch.unique(g, sorted=False, return_counts=True) 27 | counts = torch.zeros(n_groups, device=g.device) 28 | counts[unique_groups] = unique_counts.float() 29 | return counts 30 | 31 | 32 | class Grouper: 33 | """ 34 | Groupers group data points together based on their metadata. 35 | They are used for training and evaluation, 36 | e.g., to measure the accuracies of different groups of data. 37 | """ 38 | def __init__(self): 39 | raise NotImplementedError 40 | 41 | @property 42 | def n_groups(self): 43 | """ 44 | The number of groups defined by this Grouper. 
45 | """ 46 | return self._n_groups 47 | 48 | def metadata_to_group(self, metadata, return_counts=False): 49 | """ 50 | Args: 51 | - metadata (Tensor): An n x d matrix containing d metadata fields 52 | for n different points. 53 | - return_counts (bool): If True, return group counts as well. 54 | Output: 55 | - group (Tensor): An n-length vector of groups. 56 | - group_counts (Tensor): Optional, depending on return_counts. 57 | An n_group-length vector of integers containing the 58 | numbers of data points in each group in the metadata. 59 | """ 60 | raise NotImplementedError 61 | 62 | def group_str(self, group): 63 | """ 64 | Args: 65 | - group (int): A single integer representing a group. 66 | Output: 67 | - group_str (str): A string containing the pretty name of that group. 68 | """ 69 | raise NotImplementedError 70 | 71 | def group_field_str(self, group): 72 | """ 73 | Args: 74 | - group (int): A single integer representing a group. 75 | Output: 76 | - group_str (str): A string containing the name of that group. 77 | """ 78 | raise NotImplementedError 79 | 80 | class CombinatorialGrouper(Grouper): 81 | def __init__(self, dataset, groupby_fields): 82 | """ 83 | CombinatorialGroupers form groups by taking all possible combinations of the metadata 84 | fields specified in groupby_fields, in lexicographical order. 85 | For example, if: 86 | dataset.metadata_fields = ['country', 'time', 'y'] 87 | groupby_fields = ['country', 'time'] 88 | and if in dataset.metadata, country is in {0, 1} and time is in {0, 1, 2}, 89 | then the grouper will assign groups in the following way: 90 | country = 0, time = 0 -> group 0 91 | country = 1, time = 0 -> group 1 92 | country = 0, time = 1 -> group 2 93 | country = 1, time = 1 -> group 3 94 | country = 0, time = 2 -> group 4 95 | country = 1, time = 2 -> group 5 96 | If groupby_fields is None, then all data points are assigned to group 0. 97 | Args: 98 | - dataset (WILDSDataset) 99 | - groupby_fields (list of str) 100 | """ 101 | # if isinstance(dataset, WILDSSubset): 102 | # raise ValueError("Grouper should be defined for the full dataset, not a subset") 103 | self.groupby_fields = groupby_fields 104 | 105 | if groupby_fields is None: 106 | self._n_groups = 1 107 | else: 108 | # We assume that the metadata fields are integers, 109 | # so we can measure the cardinality of each field by taking its max + 1. 110 | # Note that this might result in some empty groups. 111 | self.groupby_field_indices = [i for (i, field) in enumerate(dataset.metadata_fields) if field in groupby_fields] 112 | if len(self.groupby_field_indices) != len(self.groupby_fields): 113 | raise ValueError('At least one group field not found in dataset.metadata_fields') 114 | grouped_metadata = dataset.metadata_array[:, self.groupby_field_indices] 115 | if not isinstance(grouped_metadata, torch.LongTensor): 116 | grouped_metadata_long = grouped_metadata.long() 117 | if not torch.all(grouped_metadata == grouped_metadata_long): 118 | warnings.warn(f'CombinatorialGrouper: converting metadata with fields [{", ".join(groupby_fields)}] into long') 119 | grouped_metadata = grouped_metadata_long 120 | for idx, field in enumerate(self.groupby_fields): 121 | min_value = grouped_metadata[:,idx].min() 122 | if min_value < 0: 123 | raise ValueError(f"Metadata for CombinatorialGrouper cannot have values less than 0: {field}, {min_value}") 124 | if min_value > 0: 125 | warnings.warn(f"Minimum metadata value for CombinatorialGrouper is not 0 ({field}, {min_value}). 
This will result in empty groups") 126 | self.cardinality = 1 + torch.max( 127 | grouped_metadata, dim=0)[0] 128 | cumprod = torch.cumprod(self.cardinality, dim=0) 129 | self._n_groups = cumprod[-1].item() 130 | self.factors_np = np.concatenate(([1], cumprod[:-1])) 131 | self.factors = torch.from_numpy(self.factors_np) 132 | self.metadata_map = dataset.metadata_map 133 | 134 | def metadata_to_group(self, metadata, return_counts=False): 135 | if self.groupby_fields is None: 136 | groups = torch.zeros(metadata.shape[0], dtype=torch.long) 137 | else: 138 | groups = metadata[:, self.groupby_field_indices].long() @ self.factors 139 | 140 | if return_counts: 141 | group_counts = get_counts(groups, self._n_groups) 142 | return groups, group_counts 143 | else: 144 | return groups 145 | 146 | def group_str(self, group): 147 | if self.groupby_fields is None: 148 | return 'all' 149 | 150 | # group is just an integer, not a Tensor 151 | n = len(self.factors_np) 152 | metadata = np.zeros(n) 153 | for i in range(n-1): 154 | metadata[i] = (group % self.factors_np[i+1]) // self.factors_np[i] 155 | metadata[n-1] = group // self.factors_np[n-1] 156 | group_name = '' 157 | for i in reversed(range(n)): 158 | meta_val = int(metadata[i]) 159 | if self.metadata_map is not None: 160 | if self.groupby_fields[i] in self.metadata_map: 161 | meta_val = self.metadata_map[self.groupby_fields[i]][meta_val] 162 | group_name += f'{self.groupby_fields[i]} = {meta_val}, ' 163 | group_name = group_name[:-2] 164 | return group_name 165 | 166 | # a_n = S / x_n 167 | # a_{n-1} = (S % x_n) / x_{n-1} 168 | # a_{n-2} = (S % x_{n-1}) / x_{n-2} 169 | # ... 170 | # 171 | # g = 172 | # a_1 * x_1 + 173 | # a_2 * x_2 + ... 174 | # a_n * x_n 175 | 176 | def group_field_str(self, group): 177 | return self.group_str(group).replace('=', ':').replace(',','_').replace(' ','') -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Datasets 3 | """ 4 | import copy 5 | import numpy as np 6 | import importlib 7 | 8 | 9 | def initialize_data(args): 10 | """ 11 | Set dataset-specific arguments 12 | By default, the args.root_dir below should work ifinstalling datasets as 13 | specified in the README to the specified locations 14 | - Otherwise, change `args.root_dir` to the path where the data is stored. 
15 | """ 16 | dataset_module = importlib.import_module(f'data.{args.dataset}') 17 | load_dataloaders = getattr(dataset_module, 'load_dataloaders') 18 | visualize_dataset = getattr(dataset_module, 'visualize_dataset') 19 | 20 | if 'waterbirds' in args.dataset: 21 | args.root_dir = '../slice-and-dice-smol/datasets/data/Waterbirds/' 22 | # args.root_dir = './datasets/data/Waterbirds/' 23 | args.target_name = 'waterbird_complete95' 24 | args.confounder_names = ['forest2water2'] 25 | args.image_mean = np.mean([0.485, 0.456, 0.406]) 26 | args.image_std = np.mean([0.229, 0.224, 0.225]) 27 | args.augment_data = False 28 | args.train_classes = ['landbirds', 'waterbirds'] 29 | if args.dataset == 'waterbirds_r': 30 | args.train_classes = ['land', 'water'] 31 | 32 | elif 'colored_mnist' in args.dataset: 33 | args.root_dir = './datasets/data/' 34 | args.data_path = './datasets/data/' 35 | args.target_name = 'digit' 36 | args.confounder_names = ['color'] 37 | args.image_mean = 0.5 38 | args.image_std = 0.5 39 | args.augment_data = False 40 | # args.train_classes = args.train_classes 41 | 42 | elif 'celebA' in args.dataset: 43 | # args.root_dir = './datasets/data/CelebA/' 44 | args.root_dir = '/dfs/scratch0/nims/CelebA/celeba/' 45 | # IMPORTANT - dataloader assumes that we have directory structure 46 | # in ./datasets/data/CelebA/ : 47 | # |-- list_attr_celeba.csv 48 | # |-- list_eval_partition.csv 49 | # |-- img_align_celeba/ 50 | # |-- image1.png 51 | # |-- ... 52 | # |-- imageN.png 53 | args.target_name = 'Blond_Hair' 54 | args.confounder_names = ['Male'] 55 | args.image_mean = np.mean([0.485, 0.456, 0.406]) 56 | args.image_std = np.mean([0.229, 0.224, 0.225]) 57 | args.augment_data = False 58 | args.image_path = './images/celebA/' 59 | args.train_classes = ['blond', 'nonblond'] 60 | args.val_split = 0.2 61 | 62 | elif 'civilcomments' in args.dataset: 63 | args.root_dir = '/gpfs/data/razavianlab/data/nlp/CivilComments/' 64 | args.target_name = 'toxic' 65 | args.confounder_names = ['identities'] 66 | args.image_mean = 0 67 | args.image_std = 0 68 | args.augment_data = False 69 | args.image_path = './images/civilcomments/' 70 | args.train_classes = ['non_toxic', 'toxic'] 71 | args.max_token_length = 300 72 | args.arch = 'bert-base-uncased_pt' 73 | args.bs_trn = 16 74 | args.bs_val = 128 75 | 76 | elif 'cxr' in args.dataset: 77 | args.root_dir = '/dfs/scratch1/ksaab/data/4tb_hdd/CXR' 78 | args.target_name = 'pmx' 79 | args.confounder_names = ['chest_tube'] 80 | args.image_mean = 0.48865 81 | args.image_std = 0.24621 82 | args.augment_data = False 83 | args.image_path = './images/cxr/' 84 | args.train_classes = ['no_pmx', 'pmx'] 85 | 86 | args.task = args.dataset # e.g. 
'civilcomments', for BERT 87 | args.num_classes = len(args.train_classes) 88 | return load_dataloaders, visualize_dataset 89 | 90 | 91 | def train_val_split(dataset, val_split, seed): 92 | """ 93 | Compute indices for train and val splits 94 | 95 | Args: 96 | - dataset (torch.utils.data.Dataset): Pytorch dataset 97 | - val_split (float): Fraction of dataset allocated to validation split 98 | - seed (int): Reproducibility seed 99 | Returns: 100 | - train_indices, val_indices (np.array, np.array): Dataset indices 101 | """ 102 | train_ix = int(np.round(val_split * len(dataset))) 103 | all_indices = np.arange(len(dataset)) 104 | np.random.seed(seed) 105 | np.random.shuffle(all_indices) 106 | train_indices = all_indices[train_ix:] 107 | val_indices = all_indices[:train_ix] 108 | return train_indices, val_indices 109 | 110 | 111 | def get_resampled_indices(dataloader, args, sampling='subsample', seed=None): 112 | """ 113 | Args: 114 | - dataloader (torch.utils.data.DataLoader): 115 | - sampling (str): 'subsample' or 'upsample' 116 | """ 117 | try: 118 | indices = dataloader.sampler.indices 119 | except: 120 | indices = np.arange(len(dataloader.dataset)) 121 | indices = np.arange(len(dataloader.dataset)) 122 | target_vals, target_val_counts = np.unique( 123 | dataloader.dataset.targets_all['target'][indices], 124 | return_counts=True) 125 | sampled_indices = [] 126 | if sampling == 'subsample': 127 | sample_size = np.min(target_val_counts) 128 | elif sampling == 'upsample': 129 | sample_size = np.max(target_val_counts) 130 | else: 131 | return indices 132 | 133 | if seed is None: 134 | seed = args.seed 135 | np.random.seed(seed) 136 | for v in target_vals: 137 | group_indices = np.where( 138 | dataloader.dataset.targets_all['target'][indices] == v)[0] 139 | if sampling == 'subsample': 140 | sampling_size = np.min([len(group_indices), sample_size]) 141 | replace = False 142 | elif sampling == 'upsample': 143 | sampling_size = np.max([0, sample_size - len(group_indices)]) 144 | sampled_indices.append(group_indices) 145 | replace = True 146 | sampled_indices.append(np.random.choice( 147 | group_indices, size=sampling_size, replace=replace)) 148 | sampled_indices = np.concatenate(sampled_indices) 149 | np.random.seed(seed) 150 | np.random.shuffle(sampled_indices) 151 | return indices[sampled_indices] 152 | 153 | 154 | def get_resampled_set(dataset, resampled_set_indices, copy_dataset=False): 155 | """ 156 | Obtain spurious dataset resampled_set 157 | Args: 158 | - dataset (torch.utils.data.Dataset): Spurious correlations dataset 159 | - resampled_set_indices (int[]): List-like of indices 160 | - deepcopy (bool): If true, copy the dataset 161 | """ 162 | resampled_set = copy.deepcopy(dataset) if copy_dataset else dataset 163 | try: # Some dataset classes may not have these attributes 164 | resampled_set.y_array = resampled_set.y_array[resampled_set_indices] 165 | resampled_set.group_array = resampled_set.group_array[resampled_set_indices] 166 | resampled_set.split_array = resampled_set.split_array[resampled_set_indices] 167 | resampled_set.targets = resampled_set.y_array 168 | try: # Depending on the dataset these are responsible for the X features 169 | resampled_set.filename_array = resampled_set.filename_array[resampled_set_indices] 170 | except: 171 | resampled_set.x_array = resampled_set.x_array[resampled_set_indices] 172 | except AttributeError as e: 173 | try: 174 | resampled_set.targets = resampled_set.targets[resampled_set_indices] 175 | except: 176 | resampled_set_indices = 
np.concatenate(resampled_set_indices) 177 | resampled_set.targets = resampled_set.targets[resampled_set_indices] 178 | try: 179 | resampled_set.df = resampled_set.df.iloc[resampled_set_indices] 180 | except AttributeError: 181 | pass 182 | 183 | try: 184 | resampled_set.data = resampled_set.data[resampled_set_indices] 185 | except AttributeError: 186 | pass 187 | 188 | try: # Depending on the dataset these are responsible for the X features 189 | resampled_set.filename_array = resampled_set.filename_array[resampled_set_indices] 190 | except: 191 | pass 192 | 193 | for target_type, target_val in resampled_set.targets_all.items(): 194 | resampled_set.targets_all[target_type] = target_val[resampled_set_indices] 195 | 196 | print('len(resampled_set.targets)', len(resampled_set.targets)) 197 | return resampled_set 198 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | '''Modified from https://github.com/alinlab/LfF/blob/master/util.py''' 2 | 3 | import io 4 | import torch 5 | import numpy as np 6 | import torch.nn as nn 7 | from transformers import get_linear_schedule_with_warmup 8 | 9 | class EMA: 10 | def __init__(self, label, num_classes=None, alpha=0.9): 11 | self.label = label.cuda() 12 | self.alpha = alpha 13 | self.parameter = torch.zeros(label.size(0), num_classes) 14 | self.updated = torch.zeros(label.size(0), num_classes) 15 | self.num_classes = num_classes 16 | self.max = torch.zeros(self.num_classes).cuda() 17 | 18 | def update(self, data, index, curve=None, iter_range=None, step=None): 19 | self.parameter = self.parameter.to(data.device) 20 | self.updated = self.updated.to(data.device) 21 | index = index.to(data.device) 22 | 23 | if curve is None: 24 | self.parameter[index] = self.alpha * self.parameter[index] + (1 - self.alpha * self.updated[index]) * data 25 | else: 26 | alpha = curve ** -(step / iter_range) 27 | self.parameter[index] = alpha * self.parameter[index] + (1 - alpha * self.updated[index]) * data 28 | self.updated[index] = 1 29 | 30 | def max_loss(self, label): 31 | label_index = torch.where(self.label == label)[0] 32 | return self.parameter[label_index].max() 33 | 34 | def min_loss(self, label): 35 | label_index = torch.where(self.label == label)[0] 36 | return self.parameter[label_index].min() 37 | 38 | 39 | class EMA_squre: 40 | def __init__(self, num_classes=None, alpha=0.9, avg_type = 'mv'): 41 | self.alpha = alpha 42 | self.parameter = torch.zeros(num_classes, num_classes) 43 | self.global_count_ = torch.zeros(num_classes, num_classes) 44 | self.updated = torch.zeros(num_classes, num_classes) 45 | self.num_classes = num_classes 46 | self.max = torch.zeros(self.num_classes).cuda() 47 | self.avg_type = avg_type 48 | 49 | def update(self, data, y_list, a_list, curve=None, iter_range=None, step=None, bias=None, fix = None): 50 | self.parameter = self.parameter.to(data.device) 51 | self.updated = self.updated.to(data.device) 52 | # self.global_count_ = self.global_count_.to(data.device) 53 | y_list = y_list.to(data.device) 54 | a_list = a_list.to(data.device) 55 | 56 | 57 | count = torch.zeros(self.num_classes, self.num_classes).to(data.device) 58 | # parameter_temp = torch.zeros(self.num_classes, self.num_classes, self.num_classes).to(data.device) 59 | 60 | if self.avg_type == 'mv': 61 | if curve is None: 62 | for i, (y, a) in enumerate(zip(y_list, a_list)): 63 | # parameter_temp[y,a] += data[i] 64 | count[y,a] += 1 65 | 
self.global_count_[y,a] += 1 66 | self.parameter[y,a] = self.alpha * self.parameter[y,a] + (1 - self.alpha * self.updated[y,a]) * data[i,y]#parameter_temp[y,a]/count[y,a] 67 | self.updated[y,a] = 1 68 | else: 69 | alpha = curve ** -(step / iter_range) 70 | for i, (y, a) in enumerate(zip(y_list, a_list)): 71 | # parameter_temp[y,a] += data[i] 72 | count[y,a] += 1 73 | self.global_count_[y,a] += 1 74 | self.parameter[y,a] = alpha * self.parameter[y,a] + (1 - alpha * self.updated[y,a]) * data[i,y]#parameter_temp[y,a]/count[y,a] 75 | self.updated[y,a] = 1 76 | elif self.avg_type == 'mv_batch': 77 | self.parameter_temp = torch.zeros(self.num_classes, self.num_classes).to(data.device) 78 | for i, (y, a) in enumerate(zip(y_list, a_list)): 79 | count[y,a] += 1 80 | self.global_count_[y,a] += 1 81 | self.parameter_temp[y,a] += data[i,y] 82 | self.parameter = self.alpha * self.parameter + (1 - self.alpha) * self.parameter_temp / (count + 1e-4) 83 | elif self.avg_type == 'batch': 84 | self.parameter_temp = torch.zeros(self.num_classes, self.num_classes).to(data.device) 85 | for i, (y, a) in enumerate(zip(y_list, a_list)): 86 | count[y,a] += 1 87 | self.global_count_[y,a] += 1 88 | self.parameter_temp[y,a] += data[i,y] 89 | self.parameter = self.parameter_temp / (count + 1e-4) 90 | elif self.avg_type == 'epoch': 91 | for i, (y, a) in enumerate(zip(y_list, a_list)): 92 | count[y,a] += 1 93 | self.global_count_[y,a] += 1 94 | self.parameter[y,a] += data[i,y] 95 | else: 96 | raise NotImplementedError("This averaging type is not yet implemented!") 97 | 98 | if fix is not None: 99 | self.parameter = torch.ones(self.num_classes, self.num_classes) * 0.1#* 0.005/(self.num_classes-1) 100 | # for i in range(self.num_classes): 101 | # self.parameter[i,i] = 0.995 102 | self.parameter = self.parameter.to(data.device) 103 | 104 | 105 | 106 | 107 | 108 | # def max_loss(self, label): 109 | # label_index = torch.where(self.label == label)[0] 110 | # return self.parameter[label_index].max() 111 | 112 | # def min_loss(self, label): 113 | # label_index = torch.where(self.label == label)[0] 114 | # return self.parameter[label_index].min() 115 | 116 | 117 | 118 | class EMA_area: 119 | def __init__(self, label, num_classes=None, alpha=0.9): 120 | self.label = label.cuda() 121 | self.alpha = alpha 122 | self.parameter = torch.zeros(label.size(0)) 123 | self.updated = torch.zeros(label.size(0)) 124 | self.num_classes = num_classes 125 | # self.max = torch.zeros(self.num_classes).cuda() 126 | self.data_old = torch.zeros(label.size(0)).cuda() 127 | 128 | def update(self, data, index, curve=None, iter_range=None, step=None): 129 | self.parameter = self.parameter.to(data.device) 130 | self.updated = self.updated.to(data.device) 131 | index = index.to(data.device) 132 | 133 | self.parameter[index] += 0.5 * self.data_old[index] + (1 - 0.5 * self.updated[index]) * data 134 | self.updated[index] = 1 135 | self.data_old[index] = data 136 | 137 | def max_area(self, label, temp = 1): 138 | label_index = torch.where(self.label == label)[0] 139 | return torch.nn.functional.sigmoid(-self.parameter[label_index]/temp).max() 140 | 141 | 142 | 143 | class EMA_feature: 144 | def __init__(self, label, num_classes=None, alpha=0.9): 145 | self.label = label.cuda() 146 | self.alpha = alpha 147 | self.parameter = torch.zeros((label.size(0),num_classes)) 148 | self.updated = torch.zeros((label.size(0),num_classes)) 149 | self.num_classes = num_classes 150 | #self.max = torch.zeros(self.num_classes).cuda() 151 | 152 | def update(self, data, index, 
curve=None, iter_range=None, step=None): 153 | self.parameter = self.parameter.to(data.device) 154 | self.updated = self.updated.to(data.device) 155 | index = index.to(data.device) 156 | 157 | if curve is None: 158 | self.parameter[index] = self.alpha * self.parameter[index] + (1 - self.alpha * self.updated[index]) * data 159 | else: 160 | alpha = curve ** -(step / iter_range) 161 | self.parameter[index] = alpha * self.parameter[index] + (1 - alpha * self.updated[index]) * data 162 | self.updated[index] = 1 163 | 164 | # def max_loss(self, label): 165 | # label_index = torch.where(self.label == label)[0] 166 | # return self.parameter[label_index].max() 167 | 168 | 169 | def sigmoid_rampup(current, rampup_length): 170 | """Exponential rampup from 2""" 171 | if rampup_length == 0: 172 | return 1.0 173 | else: 174 | current = np.clip(current, 0.0, rampup_length) 175 | phase = 1.0 - current / rampup_length 176 | return float(np.exp(-5.0 * phase * phase)) 177 | 178 | 179 | def linear_rampup(current, rampup_length): 180 | """Linear rampup""" 181 | assert current >= 0 and rampup_length >= 0 182 | if current >= rampup_length: 183 | return 1.0 184 | else: 185 | return current / rampup_length 186 | EPS = 1e-9 187 | 188 | def grad_norm(module): 189 | parameters = module.parameters() 190 | if isinstance(parameters, torch.Tensor): 191 | parameters = [parameters] 192 | 193 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 194 | 195 | total_norm = 0 196 | for p in parameters: 197 | param_norm = p.grad.data.norm(2) 198 | total_norm = param_norm.item() ** 2 199 | total_norm = total_norm ** (1. / 2) 200 | return total_norm 201 | 202 | 203 | def adaptive_gradient_clipping_(generator_module: nn.Module, mi_module: nn.Module): 204 | """ 205 | Clips the gradient according to the min norm of the generator and mi estimator 206 | Arguments: 207 | generator_module -- nn.Module 208 | mi_module -- nn.Module 209 | """ 210 | norm_generator = grad_norm(generator_module) 211 | #norm_estimator = grad_norm(mi_module) 212 | 213 | min_norm = norm_generator#np.minimum(norm_generator, norm_estimator) 214 | 215 | parameters = list( 216 | filter(lambda p: p.grad is not None, mi_module.parameters())) 217 | if isinstance(parameters, torch.Tensor): 218 | parameters = [parameters] 219 | 220 | for p in parameters: 221 | p.grad.data.mul_(min_norm/(norm_estimator + EPS)) 222 | 223 | 224 | 225 | def get_bert_scheduler(optimizer, n_epochs, warmup_steps, dataloader, last_epoch=-1): 226 | """ 227 | Learning rate scheduler for BERT model training 228 | """ 229 | num_training_steps = int(np.round(len(dataloader) * n_epochs)) 230 | print(f'\nt_total is {num_training_steps}\n') 231 | scheduler = get_linear_schedule_with_warmup(optimizer, 232 | warmup_steps, 233 | num_training_steps, 234 | last_epoch) 235 | return scheduler 236 | 237 | # From pytorch-transformers: 238 | # def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, 239 | # num_training_steps, last_epoch=-1): 240 | # """ 241 | # Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after 242 | # a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. 243 | # Args: 244 | # optimizer (:class:`~torch.optim.Optimizer`): 245 | # The optimizer for which to schedule the learning rate. 246 | # num_warmup_steps (:obj:`int`): 247 | # The number of steps for the warmup phase. 248 | # num_training_steps (:obj:`int`): 249 | # The total number of training steps. 
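# A quick numeric check of the two rampup schedules defined above: both map the current
# step to a weight in [0, 1], with sigmoid_rampup following exp(-5 * (1 - t/T)**2) after
# clipping t to [0, T], and linear_rampup growing linearly until step T. The rampup
# length of 100 steps is an arbitrary value chosen for illustration.
import numpy as np

T = 100
for t in (0, 25, 50, 100, 150):
    phase = 1.0 - np.clip(t, 0.0, T) / T
    sig = float(np.exp(-5.0 * phase * phase))       # same value as sigmoid_rampup(t, T)
    lin = 1.0 if t >= T else t / T                  # same value as linear_rampup(t, T)
    print(f"t={t:3d}  sigmoid={sig:.4f}  linear={lin:.2f}")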
250 | # last_epoch (:obj:`int`, `optional`, defaults to -1): 251 | # The index of the last epoch when resuming training. 252 | # Return: 253 | # :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 254 | # """ 255 | 256 | # def lr_lambda(current_step: int): 257 | # if current_step < num_warmup_steps: 258 | # return float(current_step) / float(max(1, num_warmup_steps)) 259 | # return max( 260 | # 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) 261 | # ) 262 | 263 | # return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch) 264 | -------------------------------------------------------------------------------- /module/activations.py: -------------------------------------------------------------------------------- 1 | """ 2 | Functions to help with feature representations 3 | """ 4 | import numpy as np 5 | import torch 6 | from tqdm import tqdm 7 | 8 | from utils import print_header 9 | from utils.visualize import plot_umap 10 | from network import get_output 11 | 12 | from sklearn.linear_model import LogisticRegression 13 | from sklearn.model_selection import train_test_split 14 | 15 | 16 | class SaveOutput: 17 | def __init__(self): 18 | self.outputs = [] 19 | 20 | def __call__(self, module, module_in, module_out): 21 | try: 22 | module_out = module_out.detach().cpu() 23 | self.outputs.append(module_out) # .detach().cpu().numpy() 24 | except Exception as e: 25 | print(e) 26 | self.outputs.append(module_out) 27 | 28 | def clear(self): 29 | self.outputs = [] 30 | 31 | 32 | def save_activations(model, dataloader, args): 33 | """ 34 | total_embeddings = save_activations(net, train_loader, args) 35 | """ 36 | save_output = SaveOutput() 37 | hook_handles = [] 38 | 39 | # if 'resnet' in args.arch: 40 | # for name, layer in model.named_modules(): 41 | # if name == model.activation_layer or \ 42 | # (isinstance(model, torch.nn.DataParallel) and \ 43 | # name.replace('module.', '') == model.activation_layer): 44 | # handle = layer.register_forward_hook(save_output) 45 | # hook_handles.append(handle) 46 | # elif 'densenet' in args.arch: 47 | # for name, layer in model.named_modules(): 48 | # if name == model.activation_layer or \ 49 | # (isinstance(model, torch.nn.DataParallel) and \ 50 | # name.replace('module.', '') == model.activation_layer): 51 | # handle = layer.register_forward_hook(save_output) 52 | # hook_handles.append(handle) 53 | # elif 'bert' in args.arch: 54 | for name, layer in model.named_modules(): 55 | if name == model.activation_layer or \ 56 | (isinstance(model, torch.nn.DataParallel) and \ 57 | name.replace('module.', '') == model.activation_layer): 58 | handle = layer.register_forward_hook(save_output) 59 | hook_handles.append(handle) 60 | print(f'Activation layer: {name}') 61 | # else: 62 | # # Only get last activation layer that fits the criteria? 
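# SaveOutput above is meant to be registered as a forward hook: PyTorch then calls it
# with (module, input, output) on every forward pass, and it stores the detached layer
# output. A self-contained sketch of the same pattern on a throwaway two-layer network
# (the toy model and shapes are assumptions of this example, not part of the repo):
import torch
import torch.nn as nn

captured = []
def capture_hook(module, module_in, module_out):
    # same idea as SaveOutput.__call__: detach and keep the intermediate activation
    captured.append(module_out.detach().cpu())

net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
handle = net[1].register_forward_hook(capture_hook)  # hook the ReLU, analogous to model.activation_layer
_ = net(torch.randn(4, 8))                           # the forward pass fills `captured`
handle.remove()                                      # drop the handle once it is no longer needed
print(captured[0].shape)                             # torch.Size([4, 16])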
63 | # activation_layers = [] 64 | # for layer in model.modules(): 65 | # # for name, layer in model.named_modules() 66 | # try: 67 | # if isinstance(layer, torch.nn.ReLU) or isinstance(layer, torch.nn.Identity): 68 | # activation_layers.append(layer) 69 | # # handle = layer.register_forward_hook(save_output) 70 | # # hook_handles.append(handle) 71 | # except AttributeError: 72 | # if isinstance(layer, torch.nn.ReLU): 73 | # activation_layers.append(layer) 74 | # # handle = layer.register_forward_hook(save_output) 75 | # # hook_handles.append(handle) 76 | # # Only get last activation layer that fits the criteria 77 | # if 'cnn' in args.arch and args.no_projection_head is False: 78 | # # or args.dataset == 'colored_mnist'): 79 | # handle = activation_layers[-2].register_forward_hook(save_output) 80 | # else: 81 | # handle = activation_layers[-1].register_forward_hook(save_output) 82 | # hook_handles.append(handle) 83 | model.to(args.device) 84 | model.eval() 85 | 86 | # Forward pass on test set to save activations 87 | correct_train = 0 88 | total_train = 0 89 | total_embeddings = [] 90 | total_inputs = [] 91 | total_labels = [] 92 | 93 | total_predictions = [] 94 | 95 | print('> Saving activations') 96 | 97 | with torch.no_grad(): 98 | for i, data in enumerate(tqdm(dataloader, desc='Running inference')): 99 | inputs, labels, data_ix = data 100 | inputs = inputs.to(args.device) 101 | labels = labels.to(args.device) 102 | 103 | try: 104 | if args.mode == 'contrastive_train': 105 | input_ids = inputs[:, :, 0] 106 | input_masks = inputs[:, :, 1] 107 | segment_ids = inputs[:, :, 2] 108 | outputs = model((input_ids, input_masks, segment_ids, None)) # .logits <- changed this in the contrastive network definitino 109 | else: 110 | outputs = get_output(model, inputs, labels, args) 111 | except: 112 | outputs = get_output(model, inputs, labels, args) 113 | # Why was I collecting these? 
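# In this loop (and in learner.py) BERT inputs are packed into a single tensor whose
# last dimension holds token ids, attention mask and token type ids, unpacked as
# inputs[:, :, 0], inputs[:, :, 1] and inputs[:, :, 2]. A sketch of building that layout
# with the tokenizer used elsewhere in the repo; max_length=300 mirrors the value passed
# to init_bert_transform in data/util.py, and the example sentence is made up:
import torch
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
enc = tokenizer(["an example comment"], padding='max_length', truncation=True,
                max_length=300, return_tensors='pt')
packed = torch.stack((enc['input_ids'], enc['attention_mask'], enc['token_type_ids']), dim=2)
print(packed.shape)                                  # torch.Size([1, 300, 3])
input_ids, input_masks, segment_ids = packed[:, :, 0], packed[:, :, 1], packed[:, :, 2]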
4/27/21 114 | # total_inputs.extend(inputs.detach().cpu().numpy()) 115 | # total_labels.extend(labels.detach().cpu().numpy()) 116 | 117 | _, predicted = torch.max(outputs.data, 1) 118 | total_train += labels.size(0) 119 | correct_train += (predicted == labels).sum().item() 120 | 121 | # Clear memory 122 | inputs = inputs.detach().cpu() 123 | labels = labels.detach().cpu() 124 | outputs = outputs.detach().cpu() 125 | predicted = predicted.detach().cpu() 126 | total_predictions.append(predicted) 127 | del inputs; del labels; del outputs; del predicted 128 | 129 | # print(f'Accuracy of the network on the test images: %d %%' % ( 130 | # 100 * correct_train / total_train)) 131 | 132 | # Testing this 133 | save_output.outputs = [so.detach() for so in save_output.outputs] 134 | 135 | total_predictions = np.concatenate(total_predictions) 136 | # Consolidate embeddings 137 | total_embeddings = [None] * len(save_output.outputs) 138 | 139 | for ix, output in enumerate(save_output.outputs): 140 | total_embeddings[ix] = output.numpy().squeeze() 141 | 142 | # print(total_embeddings) 143 | 144 | if 'resnet' in args.arch or 'densenet' in args.arch or 'bert' in args.arch or 'cnn' in args.arch or 'mlp' in args.arch: 145 | total_embeddings = np.concatenate(total_embeddings) 146 | if len(total_embeddings.shape) > 2: # Should just be (n_datapoints, embedding_dim) 147 | total_embeddings = total_embeddings.reshape(len(total_embeddings), -1) 148 | save_output.clear() 149 | del save_output; del hook_handles 150 | return total_embeddings, total_predictions 151 | 152 | total_embeddings_relu1 = np.concatenate( 153 | [total_embeddings[0::2]], axis=0).reshape(-1, total_embeddings[0].shape[-1]) 154 | total_embeddings_relu2 = np.concatenate( 155 | [total_embeddings[1::2]], axis=0).reshape(-1, total_embeddings[1].shape[-1]) 156 | 157 | save_output.clear() 158 | del save_output; del hook_handles 159 | return total_embeddings_relu1, total_embeddings_relu2, total_predictions 160 | 161 | 162 | def visualize_activations(net, dataloader, label_types, num_data=None, 163 | figsize=(8, 6), save=True, ftype='png', 164 | title_suffix=None, save_id_suffix=None, args=None, 165 | cmap='tab10', annotate_points=None, 166 | predictions=None, return_embeddings=False): 167 | """ 168 | Visualize and save model activations 169 | 170 | Args: 171 | - net (torch.nn.Module): Pytorch neural net model 172 | - dataloader (torch.utils.data.DataLoader): Pytorch dataloader 173 | - label_types (str[]): List of label types, e.g. 
['target', 'spurious', 'sub_target'] 174 | - num_data (int): Number of datapoints to plot 175 | - figsize (int()): Tuple of image dimensions, by (height, weight) 176 | - save (bool): Whether to save the image 177 | - ftype (str): File format for saving 178 | - args (argparse): Experiment arguments 179 | """ 180 | if 'resnet' in args.arch or 'densenet' in args.arch or 'bert' in args.arch or 'cnn' in args.arch or 'mlp' in args.arch: 181 | total_embeddings, predictions = save_activations(net, dataloader, args) 182 | print(f'total_embeddings.shape: {total_embeddings.shape}') 183 | e1 = total_embeddings 184 | e2 = total_embeddings 185 | n_mult = 1 186 | else: 187 | e1, e2, predictions = save_activations(net, dataloader, args) 188 | n_mult = 2 189 | pbar = tqdm(total=n_mult * len(label_types)) 190 | for label_type in label_types: 191 | # For now just save both classifier ReLU activation layers (for MLP, BaseCNN) 192 | if save_id_suffix is not None: 193 | save_id = f'{label_type[0]}{label_type[-1]}_{save_id_suffix}_e1' 194 | else: 195 | save_id = f'{label_type[0]}{label_type[-1]}_e1' 196 | # if title_suffix is not None: 197 | # save_id += f'-{title_suffix}' 198 | plot_umap(e1, dataloader.dataset, label_type, num_data, method='umap', 199 | offset=0, figsize=figsize, save_id=save_id, save=save, 200 | ftype=ftype, title_suffix=title_suffix, args=args, 201 | cmap=cmap, annotate_points=annotate_points, 202 | predictions=predictions) 203 | # Add MDS 204 | plot_umap(e1, dataloader.dataset, label_type, 1000, method='mds', 205 | offset=0, figsize=figsize, save_id=save_id, save=save, 206 | ftype=ftype, title_suffix=title_suffix, args=args, 207 | cmap=cmap, annotate_points=annotate_points, 208 | predictions=predictions) 209 | pbar.update(1) 210 | # if 'resnet' not in args.arch and 'densenet' not in args.arch and 'bert' not in args.arch: 211 | # save_id = f'{label_type}_e2' 212 | # if title_suffix is not None: 213 | # save_id += f'-{title_suffix}' 214 | # plot_umap(e2, dataloader.dataset, label_type, num_data, 215 | # offset=0, figsize=figsize, save_id=save_id, save=save, 216 | # ftype=ftype, title_suffix=title_suffix, args=args, 217 | # cmap=cmap, annotate_points=annotate_points, 218 | # predictions=predictions) 219 | # pbar.update(1) 220 | if return_embeddings: 221 | return e1, e2, predictions 222 | del total_embeddings, predictions 223 | del e1; e2 224 | # 225 | 226 | 227 | def estimate_y_probs(classifier, attribute, dataloader, 228 | classifier_test_size=0.5, 229 | model=None, embeddings=None, 230 | seed=42, reshape_prior=True, args=None): 231 | if embeddings is None: 232 | embeddings, _ = save_activations(model, dataloader, args) 233 | 234 | X = embeddings 235 | y = dataloader.dataset.targets_all[attribute] 236 | X_train, X_test, y_train, y_test = train_test_split( 237 | X, y, test_size=classifier_test_size, random_state=seed) 238 | 239 | # Fit linear classifier 240 | classifier.fit(X_train, y_train) 241 | score = classifier.score(X_test, y_test) 242 | print(f'Linear classifier score: {score:<.3f}') 243 | 244 | # Compute p(y) 245 | _, y_prior = np.unique(y_test, return_counts=True) 246 | y_prior = y_prior / y_prior.sum() 247 | 248 | # Compute p(y | X) 249 | y_post = classifier.predict_proba(X_test) 250 | 251 | if reshape_prior: 252 | y_prior = y_prior.reshape(1, -1).repeat(y_post.shape[0], axis=0) 253 | return y_post, y_prior 254 | 255 | 256 | def estimate_mi(classifier, attribute, dataloader, 257 | classifier_test_size=0.5, 258 | model=None, embeddings=None, 259 | seed=42, args=None): 260 | if 
embeddings is None: 261 | assert model is not None 262 | embeddings, _ = save_activations(model, dataloader, args) 263 | # Compute p(y | x), p(y) 264 | y_post, y_prior = estimate_y_probs(classifier, attribute, 265 | dataloader, classifier_test_size, 266 | model, embeddings, seed, 267 | reshape_prior=True, args=args) 268 | min_size = np.min((y_post.shape[-1], y_prior.shape[-1])) 269 | y_post = y_post[:,:min_size] 270 | y_prior = y_prior[:,:min_size] 271 | return np.sum(y_post * (np.log(y_post) - np.log(y_prior)), axis=1).mean() 272 | 273 | 274 | def compute_activation_mi(attributes, dataloader, 275 | method='logistic_regression', 276 | classifier_test_size=0.5, max_iter=1000, 277 | model=None, embeddings=None, 278 | seed=42, args=None): 279 | if embeddings is None: 280 | assert model is not None 281 | embeddings, _ = save_activations(model, dataloader, args) 282 | 283 | if method == 'logistic_regression': 284 | clf = LogisticRegression(random_state=seed, max_iter=max_iter) 285 | else: 286 | raise NotImplementedError 287 | 288 | mi_by_attributes = [] 289 | for attribute in attributes: # ['target', 'spurious'] 290 | mi = estimate_mi(clf, attribute, dataloader, 291 | classifier_test_size, model, embeddings, 292 | seed, args) 293 | mi_by_attributes.append(mi) 294 | return mi_by_attributes 295 | 296 | 297 | def compute_align_loss(embeddings, dataloader, measure_by='target', norm=True): 298 | targets_all = dataloader.dataset.targets_all 299 | 300 | if measure_by == 'target': 301 | targets_t = targets_all['target'] 302 | targets_s = targets_all['spurious'] 303 | elif measure_by == 'spurious': # A bit hacky 304 | targets_t = targets_all['spurious'] 305 | targets_s = targets_all['target'] 306 | 307 | embeddings_by_class = {} 308 | for t in np.unique(targets_t): 309 | tix = np.where(targets_t == t)[0] 310 | anchor_embeddings = [] 311 | positive_embeddings = [] 312 | for s in np.unique(targets_s): 313 | six = np.where(targets_s[tix] == s)[0] 314 | if t == s: # For waterbirds, colored MNIST only 315 | anchor_embeddings.append(embeddings[tix][six]) 316 | else: 317 | positive_embeddings.append(embeddings[tix][six]) 318 | 319 | embeddings_by_class[t] = {'anchor': np.concatenate(anchor_embeddings), 320 | 'positive': np.concatenate(positive_embeddings)} 321 | 322 | all_l2 = [] 323 | for c, embeddings_ in embeddings_by_class.items(): # keys 324 | embeddings_a = embeddings_['anchor'] 325 | embeddings_p = embeddings_['positive'] 326 | if norm: 327 | embeddings_a /= np.linalg.norm(embeddings_a) 328 | embeddings_p /= np.linalg.norm(embeddings_p) 329 | 330 | pairwise_l2 = np.linalg.norm(embeddings_a[:, None, :] - embeddings_p[None, :, :], 331 | axis=-1) ** 2 332 | all_l2.append(pairwise_l2.flatten()) 333 | 334 | return np.concatenate(all_l2).mean() 335 | 336 | 337 | def compute_aligned_loss_from_model(model, dataloader, norm, args): 338 | embeddings, predictions = save_activations(model, dataloader, args) 339 | return compute_align_loss(embeddings, dataloader, norm) 340 | 341 | """ 342 | Legacy 343 | """ 344 | def get_embeddings(net, dataloader, args): 345 | net.to(args.device) 346 | test_embeddings = [] 347 | test_embeddings_r = [] 348 | 349 | with torch.no_grad(): 350 | for i, data in enumerate(dataloader): 351 | inputs, labels, data_ix = data 352 | inputs = inputs.to(args.device) 353 | labels = labels.to(args.device) 354 | 355 | embeddings = net.embed(inputs) 356 | embeddings_r = net.embed(inputs, relu=True) 357 | 358 | test_embeddings.append(embeddings.detach().cpu().numpy()) 359 | 
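# estimate_y_probs / estimate_mi above approximate the mutual information between the
# embeddings and an attribute as the average KL divergence between a linear probe's
# posterior p(y|x) and the empirical prior p(y):  mean_x sum_y p(y|x) (log p(y|x) - log p(y)).
# A self-contained numeric sketch on synthetic two-class blobs (the data and every
# hyperparameter here are made up for illustration):
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

rng = np.random.RandomState(0)
X = np.concatenate([rng.randn(200, 2) + 3, rng.randn(200, 2) - 3])   # two well-separated blobs
y = np.array([0] * 200 + [1] * 200)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.5, random_state=42)

clf = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)
y_post = clf.predict_proba(X_te)                                     # p(y | x)
_, counts = np.unique(y_te, return_counts=True)
y_prior = (counts / counts.sum()).reshape(1, -1)                     # p(y)
mi = np.sum(y_post * (np.log(y_post) - np.log(y_prior)), axis=1).mean()
print(f"estimated MI (nats): {mi:.3f}")                              # close to log(2) ~ 0.693 here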
test_embeddings_r.append(embeddings_r.detach().cpu().numpy()) 360 | 361 | test_embeddings = np.concatenate(test_embeddings, axis=0) 362 | test_embeddings_r = np.concatenate(test_embeddings_r, axis=0) 363 | return test_embeddings, test_embeddings_r -------------------------------------------------------------------------------- /data/util.py: -------------------------------------------------------------------------------- 1 | '''Modified from https://github.com/alinlab/LfF/blob/master/data/util.py''' 2 | 3 | import os 4 | import torch 5 | from torch.utils.data.dataset import Dataset, Subset 6 | from torchvision import transforms as T 7 | from glob import glob 8 | from PIL import Image 9 | from torch.utils.data import DataLoader 10 | import pandas as pd 11 | import numpy as np 12 | from .civilcomments import CivilComments, init_bert_transform 13 | from transformers import BertTokenizerFast 14 | class dataloader(): 15 | def __init__(self, dataset, data_dir, percent, data2preprocess, use_type0, use_type1, batch_size, num_workers): 16 | self.dataset = dataset 17 | self.data_dir = data_dir 18 | self.percent = percent 19 | self.data2preprocess = data2preprocess 20 | self.use_type0 = use_type0 21 | self.use_type1 = use_type1 22 | self.num_workers = num_workers 23 | self.batch_size = batch_size 24 | 25 | 26 | 27 | def run(self,mode,preds=[],probs=[]): 28 | if mode=='train': 29 | train_dataset = get_dataset( 30 | self.dataset, 31 | data_dir=self.data_dir, 32 | dataset_split="train", 33 | transform_split="train", 34 | percent=self.percent, 35 | use_preprocess=self.data2preprocess[self.dataset], 36 | use_type0=self.use_type0, 37 | use_type1=self.use_type1 38 | ) 39 | 40 | Idx_train_dataset = IdxDataset(train_dataset) 41 | 42 | trainloader = DataLoader( 43 | Idx_train_dataset, 44 | batch_size=self.batch_size, 45 | shuffle=True, 46 | num_workers=self.num_workers, 47 | pin_memory=True, 48 | drop_last=True 49 | ) 50 | return trainloader, train_dataset 51 | 52 | 53 | elif mode=='valid': 54 | valid_dataset= get_dataset( 55 | self.dataset, 56 | data_dir=self.data_dir, 57 | dataset_split="valid", 58 | transform_split="valid", 59 | percent=self.percent, 60 | use_preprocess=self.data2preprocess[self.dataset], 61 | use_type0=self.use_type0, 62 | use_type1=self.use_type1 63 | ) 64 | 65 | valid_dataset = IdxDataset(valid_dataset) 66 | 67 | valid_trainloader = DataLoader( 68 | valid_dataset, 69 | batch_size=self.batch_size, 70 | shuffle=False, 71 | num_workers=self.num_workers, 72 | pin_memory=True, 73 | ) 74 | 75 | 76 | return valid_trainloader 77 | 78 | elif mode=='test': 79 | test_dataset= get_dataset( 80 | self.dataset, 81 | data_dir=self.data_dir, 82 | dataset_split="test", 83 | transform_split="valid", 84 | percent=self.percent, 85 | use_preprocess=self.data2preprocess[self.dataset], 86 | use_type0=self.use_type0, 87 | use_type1=self.use_type1 88 | ) 89 | 90 | test_dataset = IdxDataset(test_dataset) 91 | 92 | test_trainloader = DataLoader( 93 | test_dataset, 94 | batch_size=self.batch_size, 95 | shuffle=False, 96 | num_workers=self.num_workers, 97 | pin_memory=True, 98 | ) 99 | return test_trainloader 100 | 101 | elif mode=='eval_train': 102 | eval_train_dataset = get_dataset( 103 | self.dataset, 104 | data_dir=self.data_dir, 105 | dataset_split="train", 106 | transform_split="train", 107 | percent=self.percent, 108 | use_preprocess=self.data2preprocess[self.dataset], 109 | use_type0=self.use_type0, 110 | use_type1=self.use_type1 111 | ) 112 | 113 | eval_train_dataset = IdxDataset(eval_train_dataset) 114 | 
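# Every split above is wrapped in IdxDataset (defined just below) before being handed to
# a DataLoader, so each batch carries the original dataset indices alongside the sample;
# those indices are what the per-sample EMA trackers in util.py are keyed on. A minimal
# sketch of the same wrapping on a toy TensorDataset (the toy data is an assumption of
# this example):
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset

class _IdxWrapper(Dataset):
    """Same idea as IdxDataset: prepend the index to whatever the base dataset returns."""
    def __init__(self, dataset):
        self.dataset = dataset
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, idx):
        return (idx, *self.dataset[idx])

base = TensorDataset(torch.randn(10, 3), torch.arange(10))
loader = DataLoader(_IdxWrapper(base), batch_size=4, shuffle=True)
idx, x, y = next(iter(loader))
print(idx.shape, x.shape, y.shape)   # torch.Size([4]) torch.Size([4, 3]) torch.Size([4])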
115 | eval_train_trainloader = DataLoader( 116 | eval_train_dataset, 117 | batch_size=self.batch_size, 118 | shuffle=False, 119 | num_workers=self.num_workers, 120 | pin_memory=True, 121 | drop_last=False 122 | ) 123 | return eval_train_trainloader 124 | 125 | 126 | 127 | class IdxDataset(Dataset): 128 | def __init__(self, dataset): 129 | self.dataset = dataset 130 | 131 | def __len__(self): 132 | return len(self.dataset) 133 | 134 | def __getitem__(self, idx): 135 | return (idx, *self.dataset[idx]) 136 | 137 | 138 | class ZippedDataset(Dataset): 139 | def __init__(self, datasets): 140 | super(ZippedDataset, self).__init__() 141 | self.dataset_sizes = [len(d) for d in datasets] 142 | self.datasets = datasets 143 | 144 | def __len__(self): 145 | return max(self.dataset_sizes) 146 | 147 | def __getitem__(self, idx): 148 | items = [] 149 | for dataset_idx, dataset_size in enumerate(self.dataset_sizes): 150 | items.append(self.datasets[dataset_idx][idx % dataset_size]) 151 | 152 | item = [torch.stack(tensors, dim=0) for tensors in zip(*items)] 153 | 154 | return item 155 | 156 | class CMNISTDataset(Dataset): 157 | def __init__(self,root,split,transform=None, image_path_list=None, preds = None, bias = True): 158 | super(CMNISTDataset, self).__init__() 159 | self.transform = transform 160 | self.root = root 161 | self.image2pseudo = {} 162 | self.image_path_list = image_path_list 163 | 164 | if split=='train': 165 | self.align = glob(os.path.join(root, 'align',"*","*")) 166 | self.conflict = glob(os.path.join(root, 'conflict',"*","*")) 167 | data = self.align + self.conflict 168 | indicator = [0] * len(self.align) + [1] * len(self.conflict) 169 | 170 | # print(len(self.data),'***************') 171 | 172 | 173 | 174 | if (preds is not None): 175 | pred_idx = (preds).numpy().nonzero()[0] 176 | if bias: 177 | print("Discovered biased example id", pred_idx) 178 | else: 179 | print("Discovered unbiased example id", pred_idx) 180 | 181 | self.data = [data[i] for i in pred_idx] * int(len(data)/len(pred_idx)) 182 | self.indicator = [indicator[i] for i in pred_idx] * int(len(data)/len(pred_idx)) 183 | else: 184 | self.data = data 185 | self.indicator = indicator 186 | 187 | elif split=='valid': 188 | self.data = glob(os.path.join(root,split,"*")) 189 | elif split=='test': 190 | self.data = glob(os.path.join(root, '../test',"*","*")) 191 | self.split = split 192 | 193 | 194 | def __len__(self): 195 | return len(self.data) 196 | 197 | def __getitem__(self, index): 198 | attr = torch.LongTensor([int(self.data[index].split('_')[-2]),int(self.data[index].split('_')[-1].split('.')[0])]) 199 | image = Image.open(self.data[index]).convert('RGB') 200 | 201 | if self.transform is not None: 202 | image = self.transform(image) 203 | 204 | if self.split != 'train': 205 | return image, attr, self.data[index], -1 206 | else: 207 | return image, attr, self.data[index], self.indicator[index] 208 | 209 | 210 | class CIFAR10Dataset(Dataset): 211 | def __init__(self, root, split, transform=None, image_path_list=None, use_type0=None, use_type1=None, preds = None, bias = True): 212 | super(CIFAR10Dataset, self).__init__() 213 | self.transform = transform 214 | self.root = root 215 | self.image2pseudo = {} 216 | self.image_path_list = image_path_list 217 | 218 | if split=='train': 219 | self.align = glob(os.path.join(root, 'align',"*","*")) 220 | self.conflict = glob(os.path.join(root, 'conflict',"*","*")) 221 | data = self.align + self.conflict 222 | indicator = [0] * len(self.align) + [1] * len(self.conflict) 223 | 224 | if 
(preds is not None): 225 | pred_idx = (preds).numpy().nonzero()[0] 226 | if bias: 227 | print("Discovered biased example id", pred_idx) 228 | else: 229 | print("Discovered unbiased example id", pred_idx) 230 | 231 | self.data = [data[i] for i in pred_idx] * int(len(data)/len(pred_idx)) 232 | self.indicator = [indicator[i] for i in pred_idx] * int(len(data)/len(pred_idx)) 233 | else: 234 | self.data = data 235 | self.indicator = indicator 236 | 237 | elif split=='valid': 238 | self.data = glob(os.path.join(root,split,"*", "*")) 239 | 240 | elif split=='test': 241 | self.data = glob(os.path.join(root, '../test',"*","*")) 242 | 243 | 244 | self.split = split 245 | 246 | def __len__(self): 247 | return len(self.data) 248 | 249 | def __getitem__(self, index): 250 | attr = torch.LongTensor( 251 | [int(self.data[index].split('_')[-2]), int(self.data[index].split('_')[-1].split('.')[0])]) 252 | image = Image.open(self.data[index]).convert('RGB') 253 | 254 | 255 | if self.transform is not None: 256 | image = self.transform(image) 257 | 258 | if self.split != 'train': 259 | return image, attr, self.data[index], -1 260 | else: 261 | return image, attr, self.data[index], self.indicator[index] 262 | 263 | 264 | class bFFHQDataset(Dataset): 265 | def __init__(self, root, split, transform=None, image_path_list=None, preds = None, bias = True): 266 | super(bFFHQDataset, self).__init__() 267 | self.transform = transform 268 | self.root = root 269 | 270 | self.image2pseudo = {} 271 | self.image_path_list = image_path_list 272 | 273 | 274 | 275 | 276 | if split=='train': 277 | self.align = glob(os.path.join(root, 'align',"*","*")) 278 | self.conflict = glob(os.path.join(root, 'conflict',"*","*")) 279 | data = self.align + self.conflict 280 | indicator = [0] * len(self.align) + [1] * len(self.conflict) 281 | 282 | if (preds is not None): 283 | pred_idx = (preds).numpy().nonzero()[0] 284 | if bias: 285 | print("Discovered biased example id", pred_idx) 286 | else: 287 | print("Discovered unbiased example id", pred_idx) 288 | 289 | self.data = [data[i] for i in pred_idx] * int(len(data)/len(pred_idx)) 290 | self.indicator = [indicator[i] for i in pred_idx] * int(len(data)/len(pred_idx)) 291 | else: 292 | self.data = data 293 | self.indicator = indicator 294 | 295 | 296 | elif split=='valid': 297 | self.data = glob(os.path.join(os.path.dirname(root), split, "*")) 298 | 299 | elif split=='test': 300 | self.data = glob(os.path.join(os.path.dirname(root), split, "*")) 301 | data_conflict = [] 302 | for path in self.data: 303 | target_label = path.split('/')[-1].split('.')[0].split('_')[1] 304 | bias_label = path.split('/')[-1].split('.')[0].split('_')[2] 305 | if target_label != bias_label: 306 | data_conflict.append(path) 307 | self.data = data_conflict 308 | self.split = split 309 | def __len__(self): 310 | return len(self.data) 311 | 312 | def __getitem__(self, index): 313 | attr = torch.LongTensor( 314 | [int(self.data[index].split('_')[-2]), int(self.data[index].split('_')[-1].split('.')[0])]) 315 | image = Image.open(self.data[index]).convert('RGB') 316 | 317 | if self.transform is not None: 318 | image = self.transform(image) 319 | 320 | if self.split != 'train': 321 | return image, attr, self.data[index], -1 322 | else: 323 | return image, attr, self.data[index], self.indicator[index] 324 | 325 | 326 | class WaterBirdsDataset(Dataset): 327 | def __init__(self, root, split="train", transform=None, image_path_list=None, preds = None, bias = True): 328 | try: 329 | split_i = ["train", "valid", 
"test"].index(split) 330 | except ValueError: 331 | raise(f"Unknown split {split}") 332 | self.split = split 333 | metadata_df = pd.read_csv(os.path.join(root, "metadata.csv")) 334 | self.metadata_df = metadata_df[metadata_df["split"] == split_i] 335 | self.root = root 336 | self.transform = transform 337 | self.y_array = self.metadata_df['y'].values 338 | self.p_array = self.metadata_df['place'].values 339 | self.n_classes = np.unique(self.y_array).size 340 | self.confounder_array = self.metadata_df['place'].values 341 | self.n_places = np.unique(self.confounder_array).size 342 | self.group_array = (self.y_array * self.n_places + self.confounder_array).astype('int') 343 | self.indicator = np.abs(self.y_array - self.confounder_array).astype('int') 344 | self.n_groups = self.n_classes * self.n_places 345 | self.group_counts = ( 346 | torch.arange(self.n_groups).unsqueeze(1) == torch.from_numpy(self.group_array)).sum(1).float() 347 | self.y_counts = ( 348 | torch.arange(self.n_classes).unsqueeze(1) == torch.from_numpy(self.y_array)).sum(1).float() 349 | self.p_counts = ( 350 | torch.arange(self.n_places).unsqueeze(1) == torch.from_numpy(self.p_array)).sum(1).float() 351 | self.filename_array = self.metadata_df['img_filename'].values 352 | 353 | def __len__(self): 354 | return len(self.metadata_df) 355 | 356 | def __getitem__(self, idx): 357 | y = self.y_array[idx] 358 | g = self.group_array[idx] 359 | p = self.confounder_array[idx] 360 | 361 | attr = torch.LongTensor( 362 | [y, p, g]) 363 | 364 | img_path = os.path.join(self.root, self.filename_array[idx]) 365 | img = Image.open(img_path).convert('RGB') 366 | # img = read_image(img_path) 367 | # img = img.float() / 255. 368 | 369 | 370 | 371 | if self.transform: 372 | img = self.transform(img) 373 | 374 | if self.split != 'train': 375 | return img, attr, self.filename_array[idx], self.indicator[idx] 376 | else: 377 | return img, attr, self.filename_array[idx], self.indicator[idx] 378 | 379 | 380 | 381 | 382 | transforms = { 383 | "cmnist": { 384 | "train": T.Compose([T.ToTensor()]), 385 | "valid": T.Compose([T.ToTensor()]), 386 | "test": T.Compose([T.ToTensor()]) 387 | }, 388 | "bffhq": { 389 | "train": T.Compose([T.Resize((224,224)), T.ToTensor()]), 390 | "valid": T.Compose([T.Resize((224,224)), T.ToTensor()]), 391 | "test": T.Compose([T.Resize((224,224)), T.ToTensor()]) 392 | }, 393 | "cifar10c": { 394 | "train": T.Compose([T.ToTensor(),]), 395 | "valid": T.Compose([T.ToTensor(),]), 396 | "test": T.Compose([T.ToTensor(),]), 397 | }, 398 | 399 | "waterbird":{ 400 | "train": T.Compose([ 401 | T.Resize((256, 256)), 402 | T.CenterCrop((224,224)), 403 | T.ToTensor(), 404 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]), 405 | "valid": T.Compose([ 406 | T.Resize((256, 256)), 407 | T.CenterCrop((224,224)), 408 | T.ToTensor(), 409 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]), 410 | "test": T.Compose([ 411 | T.Resize((256, 256)), 412 | T.CenterCrop((224,224)), 413 | T.ToTensor(), 414 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 415 | }, 416 | 417 | } 418 | 419 | 420 | transforms_preprcs = { 421 | "cmnist": { 422 | "train": T.Compose([T.ToTensor()]), 423 | "valid": T.Compose([T.ToTensor()]), 424 | "test": T.Compose([T.ToTensor()]) 425 | }, 426 | "bffhq": { 427 | "train": T.Compose([ 428 | T.Resize((224,224)), 429 | T.RandomCrop(224, padding=4), 430 | T.RandomHorizontalFlip(), 431 | T.ToTensor(), 432 | T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 433 | ] 434 | ), 435 | "valid": 
T.Compose([ 436 | T.Resize((224,224)), 437 | T.ToTensor(), 438 | T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 439 | ] 440 | ), 441 | "test": T.Compose([ 442 | T.Resize((224,224)), 443 | T.ToTensor(), 444 | T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 445 | ] 446 | ) 447 | }, 448 | "cifar10c": { 449 | "train": T.Compose( 450 | [ 451 | T.RandomCrop(32, padding=4), 452 | # T.RandomResizedCrop(32), 453 | T.RandomHorizontalFlip(), 454 | T.ToTensor(), 455 | T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 456 | ] 457 | ), 458 | "valid": T.Compose( 459 | [ 460 | T.ToTensor(), 461 | T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 462 | ] 463 | ), 464 | "test": T.Compose( 465 | [ 466 | T.ToTensor(), 467 | T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 468 | ] 469 | ), 470 | }, 471 | "waterbird": { 472 | "train": T.Compose( 473 | [ 474 | T.RandomResizedCrop( 475 | (224,224), 476 | scale=(0.7, 1.0), 477 | ratio=(0.75, 1.3333333333333333), 478 | interpolation=2), 479 | T.RandomHorizontalFlip(), 480 | T.ToTensor(), 481 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 482 | ] 483 | ), 484 | "valid": T.Compose( 485 | [ 486 | T.Resize((256, 256)), 487 | T.CenterCrop((224,224)), 488 | T.ToTensor(), 489 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 490 | ] 491 | ), 492 | "test": T.Compose( 493 | [ 494 | T.Resize((256, 256)), 495 | T.CenterCrop((224,224)), 496 | T.ToTensor(), 497 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 498 | ] 499 | ), 500 | }, 501 | } 502 | 503 | 504 | 505 | transforms_preprcs_ae = { 506 | "cmnist": { 507 | "train": T.Compose([T.ToTensor()]), 508 | "valid": T.Compose([T.ToTensor()]), 509 | "test": T.Compose([T.ToTensor()]) 510 | }, 511 | "bffhq": { 512 | "train": T.Compose([ 513 | T.Resize((224,224)), 514 | T.RandomCrop(224, padding=4), 515 | T.RandomHorizontalFlip(), 516 | T.ToTensor(), 517 | T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 518 | ] 519 | ), 520 | "valid": T.Compose([ 521 | T.Resize((224,224)), 522 | T.ToTensor(), 523 | T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 524 | ] 525 | ), 526 | "test": T.Compose([ 527 | T.Resize((224,224)), 528 | T.ToTensor(), 529 | T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 530 | ] 531 | ) 532 | }, 533 | "cifar10c": { 534 | "train": T.Compose( 535 | [ 536 | # T.RandomCrop(32, padding=4), 537 | T.RandomResizedCrop(32), 538 | T.RandomHorizontalFlip(), 539 | T.ToTensor(), 540 | T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 541 | ] 542 | ), 543 | "valid": T.Compose( 544 | [ 545 | T.ToTensor(), 546 | T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 547 | ] 548 | ), 549 | "test": T.Compose( 550 | [ 551 | T.ToTensor(), 552 | T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 553 | ] 554 | ), 555 | }, 556 | 557 | "waterbird":{ 558 | "train": T.Compose([ 559 | T.Resize((256, 256)), 560 | T.CenterCrop((224,224)), 561 | T.ToTensor(), 562 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]), 563 | "valid": T.Compose([ 564 | T.Resize((256, 256)), 565 | T.CenterCrop((224,224)), 566 | T.ToTensor(), 567 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]), 568 | "test": T.Compose([ 569 | T.Resize((256, 256)), 570 | T.CenterCrop((224,224)), 571 | T.ToTensor(), 572 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 573 | }, 574 | } 575 | def get_dataset(dataset, data_dir, dataset_split, 
transform_split, percent, use_preprocess=None, image_path_list=None, use_type0=None, use_type1=None, preds = None, bias = True): 576 | dataset_category = dataset.split("-")[0] 577 | if dataset != 'civilcomments': 578 | if use_preprocess: 579 | transform = transforms_preprcs[dataset_category][transform_split] 580 | else: 581 | transform = transforms[dataset_category][transform_split] 582 | else: 583 | arch = 'bert-base-uncased_pt' 584 | pretrained_name = arch if arch[-3:] != '_pt' else arch[:-3] 585 | tokenizer = BertTokenizerFast.from_pretrained(pretrained_name) # 'bert-base-uncased' 586 | transform = init_bert_transform(tokenizer, pretrained_name, max_token_length = 300) 587 | 588 | 589 | dataset_split = "valid" if (dataset_split == "eval") else dataset_split 590 | if dataset == 'cmnist': 591 | root = data_dir + f"/cmnist/{percent}" 592 | dataset = CMNISTDataset(root=root,split=dataset_split,transform=transform, image_path_list=image_path_list, preds = preds, bias = bias) 593 | 594 | elif 'cifar10c' in dataset: 595 | # if use_type0: 596 | # root = data_dir + f"/cifar10c_0805_type0/{percent}" 597 | # elif use_type1: 598 | # root = data_dir + f"/cifar10c_0805_type1/{percent}" 599 | # else: 600 | root = data_dir + f"/cifar10c/{percent}" 601 | dataset = CIFAR10Dataset(root=root, split=dataset_split, transform=transform, image_path_list=image_path_list, use_type0=use_type0, use_type1=use_type1) 602 | 603 | elif dataset == "bffhq": 604 | root = data_dir + f"/bffhq/{percent}" 605 | dataset = bFFHQDataset(root=root, split=dataset_split, transform=transform, image_path_list=image_path_list) 606 | elif dataset == 'waterbird': 607 | root = data_dir + f"/waterbird" 608 | dataset = WaterBirdsDataset(root=root, split=dataset_split, transform=transform, image_path_list=image_path_list) 609 | elif dataset == 'civilcomments': 610 | root = data_dir + f"/CivilComments" 611 | if dataset_split == 'train': 612 | dataset = CivilComments(root, target_name='toxic', 613 | confounder_names=['identities'], 614 | split='train', transform=transform) 615 | elif dataset_split == 'valid': 616 | dataset = CivilComments(root, target_name='toxic', 617 | confounder_names=['identities'], 618 | split='val', transform=transform) 619 | elif dataset_split == 'test': 620 | dataset = CivilComments(root, target_name='toxic', 621 | confounder_names=['identities'], 622 | split='test', transform=transform) 623 | else: 624 | print('wrong dataset ...') 625 | import sys 626 | sys.exit(0) 627 | 628 | return dataset 629 | 630 | -------------------------------------------------------------------------------- /learner.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import wandb 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | 7 | import os 8 | import torch.optim as optim 9 | 10 | from data.util import dataloader 11 | from module.lc_loss.loss import GeneralizedCELoss, LogitCorrectionLoss 12 | from module.lc_loss.group_mixup import group_mixUp 13 | from module.util import get_model 14 | from util import EMA, EMA_squre, sigmoid_rampup, get_bert_scheduler 15 | 16 | import torch.nn.functional as F 17 | import random 18 | 19 | 20 | import matplotlib.pyplot as plt 21 | import matplotlib.cm as cmap 22 | 23 | def cycle(iterable): 24 | while True: 25 | for x in iterable: 26 | yield x 27 | 28 | 29 | 30 | class Learner(nn.Module): 31 | def __init__(self, args): 32 | super(Learner, self).__init__() 33 | data2model = {'cmnist': "MLP", 34 | 'cifar10c': "ResNet18", 35 | 
'bffhq': "ResNet18", 36 | 'waterbird': "resnet_50", 37 | 'civilcomments': "bert-base-uncased_pt"} 38 | 39 | data2batch_size = {'cmnist': 256, 40 | 'cifar10c': 256, 41 | 'bffhq': 64, 42 | 'waterbird': 32, 43 | 'civilcomments': 16} 44 | 45 | data2preprocess = {'cmnist': None, 46 | 'cifar10c': True, 47 | 'bffhq': True, 48 | 'waterbird': True, 49 | 'civilcomments': None} 50 | 51 | self.data2preprocess = data2preprocess 52 | self.data2batch_size = data2batch_size 53 | 54 | 55 | args.exp = '{:s}_ema_{:.2f}_tau_{:.2f}_lambda_{:.2f}_avgtype_{:s}'.format(args.exp, args.ema_alpha, args.tau, args.lambda_dis_align, args.avg_type) 56 | 57 | if args.wandb: 58 | import wandb 59 | wandb.init(project='Learning-with-Logit-Correction') 60 | wandb.run.name = args.exp 61 | 62 | run_name = args.exp 63 | if args.tensorboard: 64 | from tensorboardX import SummaryWriter 65 | self.writer = SummaryWriter(f'result/summary/{run_name}') 66 | 67 | self.model = data2model[args.dataset] 68 | self.batch_size = data2batch_size[args.dataset] 69 | 70 | print(f'model: {self.model} || dataset: {args.dataset}') 71 | print(f'working with experiment: {args.exp}...') 72 | self.log_dir = os.makedirs(os.path.join(args.log_dir, args.dataset, args.exp), exist_ok=True) 73 | self.device = torch.device(args.device) 74 | self.args = args 75 | 76 | # logging directories 77 | self.log_dir = os.path.join(args.log_dir, args.dataset, args.exp) 78 | self.summary_dir = os.path.join(args.log_dir, args.dataset, "summary", args.exp) 79 | self.summary_gradient_dir = os.path.join(self.log_dir, "gradient") 80 | self.result_dir = os.path.join(self.log_dir, "result") 81 | self.plot_dir = os.path.join(self.log_dir, "figure") 82 | os.makedirs(self.summary_dir, exist_ok=True) 83 | os.makedirs(self.result_dir, exist_ok=True) 84 | os.makedirs(self.plot_dir, exist_ok=True) 85 | 86 | self.loader = dataloader( 87 | args.dataset, 88 | args.data_dir, 89 | args.percent, 90 | data2preprocess, 91 | args.use_type0, 92 | args.use_type1, 93 | self.batch_size, 94 | args.num_workers 95 | ) 96 | 97 | self.train_loader, self.train_dataset = self.loader.run('train') 98 | 99 | self.valid_loader = self.loader.run('valid') 100 | self.test_loader = self.loader.run('test') 101 | 102 | 103 | if args.dataset == 'waterbird' or args.dataset == 'civilcomments': 104 | train_target_attr = self.train_dataset.y_array 105 | train_target_attr = torch.LongTensor(train_target_attr) 106 | else: 107 | train_target_attr = [] 108 | for data in self.train_dataset.data: 109 | train_target_attr.append(int(data.split('_')[-2])) 110 | train_target_attr = torch.LongTensor(train_target_attr) 111 | 112 | attr_dims = [] 113 | attr_dims.append(torch.max(train_target_attr).item() + 1) 114 | self.num_classes = attr_dims[0] 115 | num_example = len(train_target_attr) 116 | print('Num example in training is {:d}, Num classes is {:d} \n'.format(num_example, self.num_classes)) 117 | 118 | 119 | self.sample_margin_ema_b = EMA(torch.LongTensor(train_target_attr), num_classes=self.num_classes, alpha=0) 120 | self.confusion = EMA_squre(num_classes=self.num_classes, alpha=args.ema_alpha, avg_type = args.avg_type) 121 | print(f'alpha : {self.sample_margin_ema_b.alpha}') 122 | 123 | self.best_valid_acc_b, self.best_test_acc_b = 0., 0. 124 | self.best_valid_acc_d, self.best_test_acc_d = 0., 0. 125 | 126 | self.best_valid_acc_avg, self.best_test_acc_avg = 0., 0. 127 | self.best_valid_acc_worst, self.best_test_acc_worst = 0., 0. 
128 | 129 | print('finished model initialization....') 130 | 131 | 132 | # evaluation code for vanilla 133 | def evaluate(self, model, data_loader): 134 | model.eval() 135 | total_correct, total_num = 0, 0 136 | for _, data, attr, _, _ in tqdm(data_loader, leave=False): 137 | 138 | label = attr[:, 0] 139 | if attr.shape[1] > 2: 140 | group = attr[:, 2] 141 | else: 142 | group = None 143 | # label = attr 144 | data = data.to(self.device) 145 | label = label.to(self.device) 146 | 147 | 148 | 149 | # label = attr[:, 0] 150 | # data = data.to(self.device) 151 | # label = label.to(self.device) 152 | 153 | with torch.no_grad(): 154 | logit = model(data) 155 | pred = logit.data.max(1, keepdim=True)[1].squeeze(1) 156 | correct = (pred == label).long() 157 | total_correct += correct.sum() 158 | total_num += correct.shape[0] 159 | 160 | accs = total_correct/float(total_num) 161 | model.train() 162 | 163 | return accs 164 | 165 | 166 | def summarize_acc(self, correct_by_groups, total_by_groups, bias = True, split = 'Train',stdout=True): 167 | all_correct = 0 168 | all_total = 0 169 | min_acc = 101. 170 | min_correct_total = [None, None] 171 | # if stdout: 172 | # print(split + ' Accuracies by groups:') 173 | for yix, y_group in enumerate(correct_by_groups): 174 | for aix, a_group in enumerate(y_group): 175 | acc = a_group / total_by_groups[yix][aix] * 100 176 | if acc < min_acc: 177 | min_acc = acc 178 | min_correct_total[0] = a_group 179 | min_correct_total[1] = total_by_groups[yix][aix] 180 | if stdout: 181 | print( 182 | f'{yix}, {aix} acc: {int(a_group):5d} / {int(total_by_groups[yix][aix]):5d} = {a_group / total_by_groups[yix][aix] * 100:>7.3f}') 183 | all_correct += a_group 184 | all_total += total_by_groups[yix][aix] 185 | if stdout: 186 | if bias: 187 | average_str = f'Bised Average acc: {int(all_correct):5d} / {int(all_total):5d} = {100 * all_correct / all_total:>7.3f}' 188 | robust_str = f'Bised Robust acc: {int(min_correct_total[0]):5d} / {int(min_correct_total[1]):5d} = {min_acc:>7.3f}' 189 | else: 190 | average_str = f'Average acc: {int(all_correct):5d} / {int(all_total):5d} = {100 * all_correct / all_total:>7.3f}' 191 | robust_str = f'Robust acc: {int(min_correct_total[0]):5d} / {int(min_correct_total[1]):5d} = {min_acc:>7.3f}' 192 | print('-' * len(average_str)) 193 | print(average_str) 194 | print(robust_str) 195 | print('-' * len(average_str)) 196 | # return all_correct / all_total * 100, min_acc 197 | return min_acc 198 | 199 | # model_b, model_l, data_loader, n_group = 4, model='label', mode='dummy' 200 | def evaluate_civilcomments(self, net_b, net, dataloader, bias = False, n_group = 4, model='label', mode='dummy', split='Train', step = 0): 201 | 202 | 203 | if bias: 204 | net = net_b 205 | 206 | dataset = dataloader.dataset.dataset 207 | metadata = dataset.metadata_array 208 | correct_by_groups = np.zeros([2, len(dataset._identity_vars)]) 209 | total_by_groups = np.zeros(correct_by_groups.shape) 210 | 211 | identity_to_ix = {} 212 | for idx, identity in enumerate(dataset._identity_vars): 213 | identity_to_ix[identity] = idx 214 | 215 | for identity_var, eval_grouper in zip(dataset._identity_vars, 216 | dataset._eval_groupers): 217 | group_idx = eval_grouper.metadata_to_group(metadata).numpy() 218 | 219 | g_list, g_counts = np.unique(group_idx, return_counts=True) 220 | # print(identity_var, identity_to_ix[identity_var]) 221 | # print(g_counts) 222 | 223 | for g_ix, g in enumerate(g_list): 224 | g_count = g_counts[g_ix] 225 | # Only pick from positive identities 226 | # e.g. 
only 1 and 3 from here: 227 | # 0 y:0_male:0 228 | # 1 y:0_male:1 229 | # 2 y:1_male:0 230 | # 3 y:1_male:1 231 | n_total = g_counts[g_ix] # + g_counts[3] 232 | if g in [1, 3]: 233 | class_ix = 0 if g == 1 else 1 # 1 y:0_male:1 234 | # print(g_ix, g, n_total) 235 | 236 | # net.to(args.device) 237 | net.eval() 238 | total_correct = 0 239 | with torch.no_grad(): 240 | all_predictions = [] 241 | all_correct = [] 242 | 243 | for data in tqdm(dataloader, leave=False): 244 | i, inputs, attr, _, _ = data 245 | # inputs, labels, data_ix = data 246 | #for i, data in enumerate(tqdm(dataloader)): 247 | labels = attr[:, 0] 248 | # label = attr 249 | inputs = inputs.to(self.device) 250 | labels = labels.to(self.device) 251 | 252 | # Add this here to generalize NLP, CV models 253 | #outputs = get_output(net, inputs, labels, args) 254 | input_ids = inputs[:, :, 0] 255 | input_masks = inputs[:, :, 1] 256 | segment_ids = inputs[:, :, 2] 257 | outputs = net(input_ids=input_ids, 258 | attention_mask=input_masks, 259 | token_type_ids=segment_ids, 260 | labels=labels)[1] 261 | 262 | _, predicted = torch.max(outputs.data, 1) 263 | correct = (predicted == labels).detach().cpu() 264 | total_correct += correct.sum().item() 265 | all_correct.append(correct) 266 | all_predictions.append(predicted.detach().cpu()) 267 | 268 | inputs = inputs.to(torch.device('cpu')) 269 | labels = labels.to(torch.device('cpu')) 270 | outputs = outputs.to(torch.device('cpu')) 271 | del inputs; del labels; del outputs 272 | 273 | all_correct = torch.cat(all_correct).numpy() 274 | all_predictions = torch.cat(all_predictions) 275 | 276 | # Evaluate predictions 277 | # dataset = dataloader.dataset 278 | y_pred = all_predictions # torch.tensors 279 | y_true = dataset.y_array 280 | metadata = dataset.metadata_array 281 | 282 | correct_by_groups = np.zeros([2, len(dataset._identity_vars)]) 283 | total_by_groups = np.zeros(correct_by_groups.shape) 284 | 285 | n_group = 2 * len(dataset._identity_vars) 286 | for identity_var, eval_grouper in zip(dataset._identity_vars, 287 | dataset._eval_groupers): 288 | group_idx = eval_grouper.metadata_to_group(metadata).numpy() 289 | 290 | g_list, g_counts = np.unique(group_idx, return_counts=True) 291 | # print(g_counts) 292 | 293 | idx = identity_to_ix[identity_var] 294 | 295 | for g_ix, g in enumerate(g_list): 296 | g_count = g_counts[g_ix] 297 | # Only pick from positive identities 298 | # e.g. 
only 1 and 3 from here: 299 | # 0 y:0_male:0 300 | # 1 y:0_male:1 301 | # 2 y:1_male:0 302 | # 3 y:1_male:1 303 | n_total = g_count # s[1] + g_counts[3] 304 | if g in [1, 3]: 305 | n_correct = all_correct[group_idx == g].sum() 306 | class_ix = 0 if g == 1 else 1 # 1 y:0_male:1 307 | correct_by_groups[class_ix][idx] += n_correct 308 | total_by_groups[class_ix][idx] += n_total 309 | 310 | group_acc = correct_by_groups / total_by_groups 311 | acc_groups = {} 312 | for group in range(n_group): 313 | acc_groups[group] = group_acc[group // len(dataset._identity_vars), group % len(dataset._identity_vars)] 314 | accs = total_correct/len(dataset) 315 | 316 | robust_acc = self.summarize_acc(correct_by_groups, total_by_groups, bias = bias, split=split, stdout=True) 317 | if not bias: 318 | if split == 'test': 319 | self.save_bert(step, robust_acc, net) 320 | 321 | net.train() 322 | 323 | 324 | return accs, acc_groups 325 | # return 0, total_correct, len(dataset), correct_by_groups, total_by_groups, None, None, None 326 | 327 | def evaluate_ours(self, model_b, model_l, data_loader, n_group = 4, model='label', mode='dummy', split='Train', step = 0): 328 | if self.args.dataset =='civilcomments': 329 | return self.evaluate_civilcomments(model_b, model_l, data_loader, bias = (model == 'bias'), n_group = 4, model='label', mode='dummy', split=split, step = step) 330 | 331 | 332 | model_b.eval() 333 | model_l.eval() 334 | 335 | total_correct, total_num = 0, 0 336 | total_correct_groups = {} 337 | total_num_groups = {} 338 | acc_groups = {} 339 | for group in range(n_group): 340 | total_correct_groups[group] = 0 341 | total_num_groups[group] = 0 342 | 343 | 344 | iter = 0 345 | 346 | # index, data, attr, image_path, indicator 347 | 348 | for _, data, attr, _, _ in tqdm(data_loader, leave=False): 349 | label = attr[:, 0] 350 | if attr.shape[1] > 2: 351 | group = attr[:, 2] 352 | else: 353 | group = None 354 | # label = attr 355 | data = data.to(self.device) 356 | label = label.to(self.device) 357 | 358 | with torch.no_grad(): 359 | if self.args.dataset == 'cmnist': 360 | z_l = model_l.extract(data) 361 | z_b = model_b.extract(data) 362 | elif self.args.dataset != 'civilcomments': 363 | z_l, z_b = [], [] 364 | if mode == 'dummy': 365 | hook_fn = self.model_l.avgpool.register_forward_hook(self.concat_dummy(z_l)) 366 | else: 367 | hook_fn = self.model_l.avgpool.register_forward_hook(self.no_dummy(z_l)) 368 | _ = self.model_l(data) 369 | hook_fn.remove() 370 | z_l = z_l[0] 371 | if mode == 'dummy': 372 | hook_fn = self.model_b.avgpool.register_forward_hook(self.concat_dummy(z_b)) 373 | else: 374 | hook_fn = self.model_b.avgpool.register_forward_hook(self.no_dummy(z_b)) 375 | _ = self.model_b(data) 376 | hook_fn.remove() 377 | z_b = z_b[0] 378 | 379 | if mode == 'dummy': 380 | if iter == 0: 381 | # print('Current mode using is {:s} \n'.format(mode)) 382 | iter += 1 383 | z_origin = torch.cat((z_l, z_b), dim=1) 384 | if model == 'bias': 385 | pred_label = model_b.fc(z_origin) 386 | else: 387 | pred_label = model_l.fc(z_origin) 388 | else: 389 | if iter == 0: 390 | # print('Current mode using is {:s} \n'.format(mode)) 391 | iter += 1 392 | z_origin = z_l 393 | if model == 'bias': 394 | pred_label = model_b.fc(z_origin) 395 | else: 396 | pred_label = model_l.fc(z_origin) 397 | else: 398 | input_ids = data[:, :, 0] 399 | input_masks = data[:, :, 1] 400 | segment_ids = data[:, :, 2] 401 | if model == 'bias': 402 | pred_label = model_b(input_ids=input_ids, 403 | attention_mask=input_masks, 404 | 
token_type_ids=segment_ids, 405 | labels=label)[1] 406 | else: 407 | pred_label = model_l(input_ids=input_ids, 408 | attention_mask=input_masks, 409 | token_type_ids=segment_ids, 410 | labels=label)[1] 411 | 412 | 413 | pred = pred_label.data.max(1, keepdim=True)[1].squeeze(1) 414 | 415 | correct = (pred == label).long() 416 | total_correct += correct.sum() 417 | total_num += correct.shape[0] 418 | 419 | if group is not None: 420 | for group_id in range(n_group): 421 | group_select = (group == group_id) 422 | correct_select = (pred[group_select] == label[group_select]).long() 423 | total_correct_groups[group_id] += correct_select.sum() 424 | total_num_groups[group_id] += correct_select.shape[0] 425 | 426 | 427 | accs = total_correct/float(total_num) 428 | if group is not None: 429 | for group_id in range(n_group): 430 | acc_groups[group_id] = (total_correct_groups[group_id]/float(total_num_groups[group_id])).item() 431 | else: 432 | acc_groups = None 433 | model_b.train() 434 | model_l.train() 435 | 436 | return accs, acc_groups 437 | 438 | def save_vanilla(self, step, best=None): 439 | if best: 440 | model_path = os.path.join(self.result_dir, "best_model.th") 441 | else: 442 | model_path = os.path.join(self.result_dir, "model_{}.th".format(step)) 443 | state_dict = { 444 | 'steps': step, 445 | 'state_dict': self.model_b.state_dict(), 446 | 'optimizer': self.optimizer_b.state_dict(), 447 | } 448 | with open(model_path, "wb") as f: 449 | torch.save(state_dict, f) 450 | print(f'{step} model saved ...') 451 | 452 | 453 | def save_ours(self, step, best=None): 454 | if best: 455 | model_path = os.path.join(self.result_dir, "best_model_l.th") 456 | else: 457 | model_path = os.path.join(self.result_dir, "model_l_{}.th".format(step)) 458 | state_dict = { 459 | 'steps': step, 460 | 'state_dict': self.model_l.state_dict(), 461 | 'optimizer': self.optimizer_l.state_dict(), 462 | } 463 | with open(model_path, "wb") as f: 464 | torch.save(state_dict, f) 465 | 466 | if best: 467 | model_path = os.path.join(self.result_dir, "best_model_b.th") 468 | else: 469 | model_path = os.path.join(self.result_dir, "model_b_{}.th".format(step)) 470 | state_dict = { 471 | 'steps': step, 472 | 'state_dict': self.model_b.state_dict(), 473 | 'optimizer': self.optimizer_b.state_dict(), 474 | } 475 | with open(model_path, "wb") as f: 476 | torch.save(state_dict, f) 477 | 478 | print(f'{step} model saved ...') 479 | 480 | def save_bert(self, step, robust_acc, net): 481 | model_path = os.path.join(self.result_dir, f"model_{step}_{robust_acc}.th") 482 | state_dict = { 483 | 'state_dict': net.state_dict(), 484 | } 485 | with open(model_path, "wb") as f: 486 | torch.save(state_dict, f) 487 | 488 | print(f'model saved ...') 489 | 490 | 491 | def board_vanilla_loss(self, step, loss_b): 492 | if self.args.wandb: 493 | wandb.log({ 494 | "loss_b_train": loss_b, 495 | }, step=step,) 496 | 497 | if self.args.tensorboard: 498 | self.writer.add_scalar(f"loss/loss_b_train", loss_b, step) 499 | 500 | 501 | def board_ours_loss(self, step, loss_dis_conflict, loss_dis_align, confusion, global_count): 502 | 503 | flatten_confusion = confusion.flatten() 504 | print('Correction: ', flatten_confusion) 505 | if self.args.wandb: 506 | wandb.log({ 507 | "loss_dis_conflict": loss_dis_conflict, 508 | "loss_dis_align": loss_dis_align, 509 | "loss": loss_dis_conflict + loss_dis_align, 510 | }, step=step,) 511 | 512 | 513 | flatten_global_count = global_count.flatten() 514 | for i in range(len(flatten_confusion)): 515 | 
wandb.log({"logit_correction_"+str(i): flatten_confusion[i]}, step=step,) 516 | wandb.log({"global_count_"+str(i): flatten_global_count[i]}, step=step,) 517 | 518 | if self.args.tensorboard: 519 | self.writer.add_scalar(f"loss/loss_dis_conflict", loss_dis_conflict, step) 520 | self.writer.add_scalar(f"loss/loss_dis_align", loss_dis_align, step) 521 | self.writer.add_scalar(f"loss/loss", loss_dis_conflict + loss_dis_align) 522 | 523 | def board_vanilla_acc(self, step, epoch, inference=None): 524 | valid_accs_b = self.evaluate(self.model_b, self.valid_loader) 525 | test_accs_b = self.evaluate(self.model_b, self.test_loader) 526 | 527 | # print(f'epoch: {epoch}') 528 | 529 | if valid_accs_b >= self.best_valid_acc_b: 530 | self.best_valid_acc_b = valid_accs_b 531 | if test_accs_b >= self.best_test_acc_b: 532 | self.best_test_acc_b = test_accs_b 533 | self.save_vanilla(step, best=True) 534 | 535 | if self.args.wandb: 536 | wandb.log({ 537 | "acc_b_valid": valid_accs_b, 538 | "acc_b_test": test_accs_b, 539 | }, 540 | step=step,) 541 | wandb.log({ 542 | "best_acc_b_valid": self.best_valid_acc_b, 543 | "best_acc_b_test": self.best_test_acc_b, 544 | }, 545 | step=step, ) 546 | 547 | print(f'valid_b: {valid_accs_b} || test_b: {test_accs_b}') 548 | 549 | if self.args.tensorboard: 550 | self.writer.add_scalar(f"acc/acc_b_valid", valid_accs_b, step) 551 | self.writer.add_scalar(f"acc/acc_b_test", test_accs_b, step) 552 | 553 | self.writer.add_scalar(f"acc/best_acc_b_valid", self.best_valid_acc_b, step) 554 | self.writer.add_scalar(f"acc/best_acc_b_test", self.best_test_acc_b, step) 555 | 556 | 557 | def board_ours_acc(self, step, inference=None, model ='debias', mode = 'dummy', n_group = 4, eval = False, save = True): 558 | # check label network 559 | 560 | valid_accs_d, valid_acc_groups = self.evaluate_ours(self.model_b, self.model_l, self.valid_loader, n_group = n_group, model=model, mode=mode, split = 'valid', step = step) 561 | 562 | test_accs_d, test_acc_groups = self.evaluate_ours(self.model_b, self.model_l, self.test_loader, n_group = n_group, model=model, mode=mode, split = 'test', step = step) 563 | 564 | if eval: 565 | return 566 | 567 | 568 | if valid_acc_groups is not None: 569 | valid_group_acc_list = list(valid_acc_groups.values()) 570 | valid_accs_avg = np.nanmean(valid_group_acc_list) 571 | valid_accs_worst = np.nanmin(valid_group_acc_list) 572 | 573 | if valid_accs_avg >= self.best_valid_acc_avg: 574 | self.best_valid_acc_avg = valid_accs_avg 575 | if valid_accs_worst >= self.best_valid_acc_worst: 576 | self.best_valid_acc_worst = valid_accs_worst 577 | 578 | if test_acc_groups is not None: 579 | 580 | test_group_acc_list = list(test_acc_groups.values()) 581 | test_accs_avg = np.nanmean(test_group_acc_list) 582 | test_accs_worst = np.nanmin(test_group_acc_list) 583 | 584 | if test_accs_avg >=self.best_test_acc_avg: 585 | self.best_test_acc_avg = test_accs_avg 586 | if test_accs_worst >=self.best_test_acc_worst: 587 | self.best_test_acc_worst = test_accs_worst 588 | 589 | # else: 590 | # valid_accs_avg = 0 591 | # valid_accs_worst = 0 592 | # test_accs_avg = 0 593 | # test_accs_worst = 0 594 | 595 | 596 | if inference: 597 | print(f'test acc: {test_accs_d.item()}') 598 | import sys 599 | sys.exit(0) 600 | 601 | if valid_accs_d >= self.best_valid_acc_d: 602 | self.best_valid_acc_d = valid_accs_d 603 | if test_accs_d >= self.best_test_acc_d: 604 | self.best_test_acc_d = test_accs_d 605 | if save: 606 | self.save_ours(step, best=True) 607 | 608 | 609 | if self.args.wandb: 610 | wandb.log({ 
611 | "acc_d_valid": valid_accs_d, 612 | "acc_d_test": test_accs_d, 613 | "acc_avg_valid": valid_accs_avg, 614 | "acc_worst_valid": valid_accs_worst, 615 | "acc_avg_test": test_accs_avg, 616 | "acc_worst_test": test_accs_worst 617 | }, 618 | step=step, ) 619 | wandb.log({ 620 | "best_acc_d_valid": self.best_valid_acc_d, 621 | "best_acc_d_test": self.best_test_acc_d, 622 | "best_acc_avg_valid": self.best_valid_acc_avg, 623 | "best_acc_avg_test": self.best_test_acc_avg, 624 | "best_acc_worst_valid": self.best_valid_acc_worst, 625 | "best_acc_worst_test": self.best_test_acc_worst, 626 | }, 627 | step=step, ) 628 | 629 | if (test_acc_groups is not None) and len(test_group_acc_list) < 16: 630 | 631 | for g_id in range(len(test_group_acc_list)): 632 | wandb.log({ 633 | "test_acc_group_" + str(g_id): test_group_acc_list[g_id], 634 | }, 635 | step=step, ) 636 | 637 | 638 | if self.args.tensorboard: 639 | self.writer.add_scalar(f"acc/acc_d_valid", valid_accs_d, step) 640 | self.writer.add_scalar(f"acc/acc_d_test", test_accs_d, step) 641 | self.writer.add_scalar(f"acc/best_acc_d_valid", self.best_valid_acc_d, step) 642 | self.writer.add_scalar(f"acc/best_acc_d_test", self.best_test_acc_d, step) 643 | self.writer.add_scalar(f"acc/best_acc_avg_valid", self.best_valid_acc_avg, step) 644 | self.writer.add_scalar(f"acc/best_acc_worst_valid", self.best_valid_acc_worst, step) 645 | self.writer.add_scalar(f"acc/best_acc_worst_test", self.best_test_acc_worst, step) 646 | 647 | if (test_acc_groups is not None) and len(test_group_acc_list) < 16: 648 | for g_id in range(len(test_group_acc_list)): 649 | print(f"test_acc_group_{g_id}: {valid_group_acc_list[g_id]}") 650 | print(f"valid_acc_group_{g_id}: {test_group_acc_list[g_id]}") 651 | print(f"Best Worst Test:{self.best_test_acc_worst}") 652 | print(f'valid_d: {valid_accs_d} || test_d: {test_accs_d} || best_test_d: {self.best_test_acc_d}') 653 | 654 | def concat_dummy(self, z): 655 | def hook(model, input, output): 656 | z.append(output.squeeze()) 657 | return torch.cat((output, torch.zeros_like(output)), dim=1) 658 | return hook 659 | 660 | def no_dummy(self, z): 661 | def hook(model, input, output): 662 | z.append(output.squeeze()) 663 | return output 664 | return hook 665 | 666 | def no_dummy_input(self, z): 667 | def hook(model, input): 668 | z.append(input[0]) 669 | return input 670 | return hook 671 | 672 | def train_vanilla(self, args): 673 | self.criterion = nn.CrossEntropyLoss() 674 | if args.dataset == 'cmnist': 675 | self.model_l = get_model('mlp_DISENTANGLE', self.num_classes, bias = True).to(self.device) 676 | self.model_b = get_model('mlp_DISENTANGLE', self.num_classes, bias = True).to(self.device) 677 | elif args.dataset == 'waterbird': 678 | self.model_l = get_model('resnet_50_pretrained', self.num_classes, bias = True).to(self.device) 679 | self.model_b = get_model('resnet_50_pretrained', self.num_classes, bias = True).to(self.device) 680 | elif args.dataset == 'civilcomments': 681 | self.model_l = get_model('bert-base-uncased_pt', self.num_classes, bias = True).to(self.device) 682 | self.model_b = get_model('bert-base-uncased_pt', self.num_classes, bias = True).to(self.device) 683 | else: 684 | if self.args.use_resnet20: # Use this option only for comparing with LfF 685 | self.model_l = get_model('ResNet20_OURS', self.num_classes).to(self.device) 686 | self.model_b = get_model('ResNet20_OURS', self.num_classes).to(self.device) 687 | print('our resnet20....') 688 | else: 689 | self.model_l = get_model('resnet_DISENTANGLE_pretrained', 
self.num_classes).to(self.device) 690 | self.model_b = get_model('resnet_DISENTANGLE_pretrained', self.num_classes).to(self.device) 691 | 692 | # self.model_b.load_state_dict(torch.load(os.path.join('./log/waterbird/waterbird_ours_GEC_0.9_ema_0.50_tau_0.10_lambda_2.00_avgtype_mv_batch/result/', 'best_model_b.th'))['state_dict']) 693 | 694 | if args.dataset == 'waterbird': 695 | print('!' * 10 + ' Using SGD ' + '!' * 10) 696 | 697 | self.optimizer_b = torch.optim.SGD( 698 | self.model_b.parameters(), 699 | lr=args.lr, 700 | weight_decay=0, 701 | ) 702 | elif args.dataset == 'civilcomments': 703 | print('------------------- AdamW -------------------------------') 704 | self.optimizer_b = torch.optim.SGD( 705 | self.model_b.parameters(), #1e-3 706 | lr=2e-5, 707 | weight_decay=0 708 | ) 709 | 710 | n_group = 16 711 | 712 | else: 713 | 714 | self.optimizer_b = torch.optim.Adam( 715 | self.model_b.fc.parameters(), 716 | lr=args.lr, 717 | weight_decay=args.weight_decay, 718 | ) 719 | 720 | step = 0 721 | 722 | for epoch in tqdm(range(args.num_epochs)): 723 | 724 | for index, data, attr, image_path, indicator in tqdm(self.train_loader, leave=False): 725 | 726 | data = data.to(self.device) 727 | attr = attr.to(self.device) 728 | label = attr[:, args.target_attr_idx] 729 | spurious_label = attr[:, 1] 730 | 731 | logit_b = self.model_b(data) 732 | 733 | loss_b_update = self.criterion(logit_b, label) 734 | 735 | loss = loss_b_update.mean() 736 | 737 | 738 | self.optimizer_b.zero_grad() 739 | loss.backward() 740 | self.optimizer_b.step() 741 | 742 | ################################################## 743 | #################### LOGGING ##################### 744 | ################################################## 745 | 746 | if (step % args.valid_freq == 0) and (args.dataset != 'waterbird'): 747 | print("------------------------ Validation Starts--------------------------------") 748 | self.board_vanilla_acc(step, epoch, inference=None) 749 | print("------------------------ Validation Done --------------------------------") 750 | step += 1 751 | 752 | if args.dataset == 'waterbird': 753 | self.board_ours_acc(epoch, model = 'bias', mode = 'no_dummy' , n_group = 4, save = False) 754 | 755 | if len(random_indices_all_groups) > 0: 756 | mixed_feature = lam * feature[indices_all_groups] + (1 - lam) * feature[random_indices_all_groups] 757 | mixed_correction = lam * correction[indices_all_groups] + (1 - lam) * correction[random_indices_all_groups] 758 | label_a = label[indices_all_groups] 759 | label_b = label[random_indices_all_groups] 760 | else: 761 | mixed_feature = None 762 | label_a, label_b, lam = None, None, None 763 | 764 | return mixed_feature, mixed_correction, label_a, label_b, lam 765 | 766 | 767 | def train_ours(self, args): 768 | epoch, cnt = 0, 0 769 | print('************** main training starts... 
************** ') 770 | train_num = len(self.train_dataset) 771 | print('Length of training set: {:d}'.format(train_num)) 772 | 773 | self.bias_criterion = GeneralizedCELoss(q=args.q) 774 | self.criterion = LogitCorrectionLoss(eta = 1.0) 775 | 776 | if args.dataset == 'cmnist': 777 | self.model_l = get_model('mlp_DISENTANGLE', self.num_classes, bias = True).to(self.device) 778 | self.model_b = get_model('mlp_DISENTANGLE', self.num_classes, bias = True).to(self.device) 779 | elif args.dataset == 'waterbird': 780 | self.model_l = get_model('resnet_50_pretrained', self.num_classes, bias = True).to(self.device) 781 | self.model_b = get_model('resnet_50_pretrained', self.num_classes, bias = True).to(self.device) 782 | elif args.dataset == 'civilcomments': 783 | self.model_l = get_model('bert-base-uncased_pt', self.num_classes, bias = True).to(self.device) 784 | self.model_b = get_model('bert-base-uncased_pt', self.num_classes, bias = True).to(self.device) 785 | else: 786 | if self.args.use_resnet20: # Use this option only for comparing with LfF 787 | self.model_l = get_model('ResNet20_OURS', self.num_classes).to(self.device) 788 | self.model_b = get_model('ResNet20_OURS', self.num_classes).to(self.device) 789 | print('our resnet20....') 790 | else: 791 | self.model_l = get_model('resnet_DISENTANGLE', self.num_classes).to(self.device) 792 | self.model_b = get_model('resnet_DISENTANGLE', self.num_classes).to(self.device) 793 | 794 | if args.dataset == 'waterbird': 795 | print('!' * 10 + ' Using SGD ' + '!' * 10) 796 | 797 | self.optimizer_l = torch.optim.SGD( 798 | self.model_l.parameters(), #1e-3 799 | lr=args.lr, 800 | weight_decay=args.weight_decay,#1e-3 801 | ) 802 | 803 | self.optimizer_b = torch.optim.SGD( 804 | self.model_b.parameters(), 805 | lr=args.lr*0.1, 806 | weight_decay=0, 807 | ) 808 | elif args.dataset == 'civilcomments': 809 | self.optimizer_b = torch.optim.SGD( 810 | self.model_b.parameters(), 811 | lr=2e-5, 812 | weight_decay=0 813 | ) 814 | 815 | no_decay = ['bias', 'LayerNorm.weight'] 816 | optimizer_grouped_parameters = [ 817 | {'params': [p for n, p in self.model_l.named_parameters() 818 | if not any(nd in n for nd in no_decay)], 819 | 'weight_decay': args.weight_decay}, 820 | {'params': [p for n, p in self.model_l.named_parameters() 821 | if any(nd in n for nd in no_decay)], 822 | 'weight_decay': 0.0}] 823 | self.optimizer_l = optim.AdamW(optimizer_grouped_parameters, 824 | lr=args.lr, eps=1e-8) 825 | 826 | n_group = 16 827 | 828 | else: 829 | 830 | self.optimizer_l = torch.optim.Adam( 831 | self.model_l.parameters(), 832 | lr=args.lr, 833 | weight_decay=args.weight_decay, 834 | ) 835 | 836 | self.optimizer_b = torch.optim.Adam( 837 | self.model_b.parameters(), 838 | lr=args.lr*0.1, 839 | weight_decay=args.weight_decay, 840 | ) 841 | 842 | 843 | if args.use_lr_decay and args.dataset == 'waterbird': 844 | self.scheduler_b = optim.lr_scheduler.StepLR(self.optimizer_b, step_size=args.lr_decay_epoch, gamma=args.lr_gamma) 845 | self.scheduler_l = optim.lr_scheduler.StepLR(self.optimizer_l, step_size=args.lr_decay_epoch, gamma=args.lr_gamma) 846 | elif args.use_lr_decay and args.dataset == 'civilcomments': 847 | total_updates = args.num_epochs 848 | 849 | self.scheduler_b = get_bert_scheduler(self.optimizer_b, n_epochs=1,#args.num_epochs, 850 | warmup_steps=0, 851 | dataloader=self.train_loader) 852 | self.scheduler_l = get_bert_scheduler(self.optimizer_l, n_epochs=total_updates,#args.num_epochs, 853 | warmup_steps=0, 854 | dataloader=self.train_loader) 855 | else: 856 | 
self.scheduler_b = optim.lr_scheduler.StepLR(self.optimizer_b, step_size=args.lr_decay_step, gamma=args.lr_gamma) 857 | self.scheduler_l = optim.lr_scheduler.StepLR(self.optimizer_l, step_size=args.lr_decay_step, gamma=args.lr_gamma) 858 | 859 | print(f'criterion: {self.criterion}') 860 | print(f'bias criterion: {self.bias_criterion}') 861 | train_iter = iter(self.train_loader) 862 | 863 | step = 0 864 | 865 | for epoch in tqdm(range(args.num_epochs)): 866 | 867 | for index, data, attr, image_path, indicator in tqdm(self.train_loader, leave=False): 868 | 869 | data = data.to(self.device) 870 | attr = attr.to(self.device) 871 | label = attr[:, args.target_attr_idx].to(self.device) 872 | bias = attr[:, 1].to(self.device) 873 | if args.dataset == 'waterbird': 874 | alpha = sigmoid_rampup(epoch, args.curr_epoch)*0.5 875 | else: 876 | alpha = sigmoid_rampup(step, args.curr_step)*0.5 877 | 878 | if args.dataset == 'cmnist': 879 | # z_l = self.model_l.extract(data) 880 | z_b = self.model_b.extract(data) 881 | pred_align = self.model_b.fc(z_b) 882 | self.sample_margin_ema_b.update(F.softmax(pred_align.detach()/args.tau), index) 883 | pred_align_mv = self.sample_margin_ema_b.parameter[index].clone().detach() 884 | _, pseudo_label = torch.max(pred_align_mv, dim=1) 885 | self.confusion.update(pred_align_mv, label, pseudo_label, fix = None) 886 | correction_matrix = self.confusion.parameter.clone().detach() 887 | if args.avg_type == 'epoch': 888 | correction_matrix = correction_matrix/self.confusion.global_count_.to(self.device) 889 | 890 | correction_delta = correction_matrix[:,pseudo_label] 891 | correction_delta = torch.t(correction_delta) 892 | return_dict = group_mixUp(data, pseudo_label, correction_delta, label, self.num_classes, alpha) 893 | mixed_target_data = return_dict["mixed_feature"] 894 | mixed_biased_prediction = return_dict["mixed_correction"] 895 | label_a = return_dict["label_majority"] 896 | label_b = return_dict["label_minority"] 897 | lam_target = return_dict["lam"] 898 | 899 | z_l = self.model_l.extract(mixed_target_data) 900 | pred_conflict = self.model_l.fc(z_l) 901 | 902 | elif args.dataset == 'civilcomments': 903 | 904 | input_ids = data[:, :, 0] 905 | input_masks = data[:, :, 1] 906 | segment_ids = data[:, :, 2] 907 | 908 | pred_align = self.model_b(input_ids=input_ids, 909 | attention_mask=input_masks, 910 | token_type_ids=segment_ids, 911 | labels=label)[1] 912 | 913 | 914 | self.sample_margin_ema_b.update(F.softmax(pred_align.detach()/args.tau), index) 915 | pred_align_mv = self.sample_margin_ema_b.parameter[index].clone().detach() 916 | _, pseudo_label = torch.max(pred_align_mv, dim=1) 917 | self.confusion.update(pred_align_mv, label, pseudo_label, fix = None) 918 | correction_matrix = self.confusion.parameter.clone().detach() 919 | if args.avg_type == 'epoch': 920 | correction_matrix = correction_matrix/self.confusion.global_count_.to(self.device) 921 | 922 | correction_matrix = correction_matrix/((correction_matrix).sum(dim=0,keepdims =True) + 1e-4) 923 | correction_delta = correction_matrix[:,pseudo_label] 924 | correction_delta = torch.t(correction_delta) 925 | 926 | 927 | z_l = [] 928 | hook_fn = self.model_l.dropout.register_forward_pre_hook(self.no_dummy_input(z_l))#register_forward_hook(self.no_dummy_input(z_l)) 929 | _ = self.model_l(input_ids=input_ids, 930 | attention_mask=input_masks, 931 | token_type_ids=segment_ids, 932 | labels=label) 933 | hook_fn.remove() 934 | z_l = z_l[0] 935 | 936 | return_dict = group_mixUp(z_l, pseudo_label, correction_delta, label, 
self.num_classes, alpha) 937 | mixed_target_z_l = return_dict["mixed_feature"] 938 | mixed_biased_prediction = return_dict["mixed_correction"] 939 | label_a = return_dict["label_majority"] 940 | label_b = return_dict["label_minority"] 941 | lam_target = return_dict["lam"] 942 | 943 | 944 | pred_conflict = self.model_l.classifier(self.model_l.dropout(mixed_target_z_l)) 945 | 946 | else: 947 | z_b = [] 948 | # Use this only for reproducing CIFARC10 of LfF 949 | if self.args.use_resnet20: 950 | hook_fn = self.model_b.layer3.register_forward_hook(self.no_dummy(z_b)) 951 | _ = self.model_b(data) 952 | hook_fn.remove() 953 | z_b = z_b[0] 954 | 955 | pred_align = self.model_b.fc(z_b) 956 | self.sample_margin_ema_b.update(F.softmax(pred_align.detach()/args.tau), index) 957 | pred_align_mv = self.sample_margin_ema_b.parameter[index].clone().detach() 958 | _, pseudo_label = torch.max(pred_align_mv, dim=1) 959 | self.confusion.update(pred_align_mv, label, pseudo_label, fix = None) 960 | correction_matrix = self.confusion.parameter.clone().detach() 961 | if args.avg_type == 'epoch': 962 | correction_matrix = correction_matrix/self.confusion.global_count_.to(self.device) 963 | correction_matrix = correction_matrix/(correction_matrix).sum(dim=0,keepdims =True) 964 | correction_delta = correction_matrix[:,pseudo_label] 965 | correction_delta = torch.t(correction_delta) 966 | return_dict = group_mixUp(data, pseudo_label, correction_delta, label, self.num_classes, alpha) 967 | mixed_target_data = return_dict["mixed_feature"] 968 | mixed_biased_prediction = return_dict["mixed_correction"] 969 | label_a = return_dict["label_majority"] 970 | label_b = return_dict["label_minority"] 971 | lam_target = return_dict["lam"] 972 | 973 | 974 | z_l = [] 975 | hook_fn = self.model_l.layer3.register_forward_hook(self.no_dummy(z_l)) 976 | _ = self.model_l(mixed_target_data) 977 | hook_fn.remove() 978 | 979 | z_l = z_l[0] 980 | 981 | pred_conflict = self.model_l.fc(z_l) 982 | 983 | else: 984 | hook_fn = self.model_b.avgpool.register_forward_hook(self.no_dummy(z_b)) 985 | _ = self.model_b(data) 986 | hook_fn.remove() 987 | z_b = z_b[0] 988 | 989 | pred_align = self.model_b.fc(z_b) 990 | self.sample_margin_ema_b.update(F.softmax(pred_align.detach()/args.tau), index) 991 | pred_align_mv = self.sample_margin_ema_b.parameter[index].clone().detach() 992 | 993 | _, pseudo_label = torch.max(pred_align_mv, dim=1) 994 | self.confusion.update(pred_align_mv, label, pseudo_label, fix = None) 995 | correction_matrix = self.confusion.parameter.clone().detach() 996 | if args.avg_type == 'epoch': 997 | correction_matrix = correction_matrix/self.confusion.global_count_.to(self.device) 998 | correction_matrix = correction_matrix/(correction_matrix).sum(dim=0,keepdims =True) 999 | correction_delta = correction_matrix[:,pseudo_label] 1000 | correction_delta = torch.t(correction_delta) 1001 | 1002 | return_dict = group_mixUp(data, pseudo_label, correction_delta, label, self.num_classes, alpha) 1003 | mixed_target_data = return_dict["mixed_feature"] 1004 | mixed_biased_prediction = return_dict["mixed_correction"] 1005 | label_a = return_dict["label_majority"] 1006 | label_b = return_dict["label_minority"] 1007 | lam_target = return_dict["lam"] 1008 | 1009 | z_l = [] 1010 | hook_fn = self.model_l.avgpool.register_forward_hook(self.no_dummy(z_l)) 1011 | _ = self.model_l(mixed_target_data) 1012 | hook_fn.remove() 1013 | 1014 | z_l = z_l[0] 1015 | 1016 | pred_conflict = self.model_l.fc(z_l) 1017 | 1018 | 1019 | 1020 | 
self.sample_margin_ema_b.update(F.softmax(pred_align.detach()), index) 1021 | 1022 | loss_dis_conflict = lam_target * self.criterion(pred_conflict, label_a, mixed_biased_prediction) +\ 1023 | (1 - lam_target) * self.criterion(pred_conflict, label_b, mixed_biased_prediction) 1024 | 1025 | loss_dis_align = self.bias_criterion(pred_align, label) 1026 | loss = loss_dis_conflict.mean() + args.lambda_dis_align * loss_dis_align.mean() # Eq.2 L_dis 1027 | 1028 | 1029 | self.optimizer_l.zero_grad() 1030 | self.optimizer_b.zero_grad() 1031 | 1032 | loss.backward() 1033 | 1034 | if args.dataset == 'civilcomments': 1035 | torch.nn.utils.clip_grad_norm_(self.model_l.parameters(), 1.0) 1036 | 1037 | if args.use_lr_decay and args.dataset != 'waterbird': 1038 | self.scheduler_b.step() 1039 | self.scheduler_l.step() 1040 | 1041 | 1042 | self.optimizer_l.step() 1043 | self.optimizer_b.step() 1044 | 1045 | if args.use_lr_decay and step % args.lr_decay_step == 0 and args.dataset != 'waterbird': 1046 | print('******* learning rate decay .... ********') 1047 | print(f"self.optimizer_b lr: { self.optimizer_b.param_groups[-1]['lr']}") 1048 | print(f"self.optimizer_l lr: { self.optimizer_l.param_groups[-1]['lr']}") 1049 | 1050 | 1051 | if step > 0 and step % args.save_freq == 0 and args.dataset != 'waterbird' and args.dataset != 'civilcomments': 1052 | self.save_ours(step) 1053 | 1054 | 1055 | if step > 0 and step % args.log_freq == 0 and args.dataset != 'waterbird': #and args.dataset != 'civilcomments': 1056 | confusion_numpy = correction_matrix.cpu().numpy() 1057 | self.board_ours_loss( 1058 | step=step, 1059 | loss_dis_conflict=loss_dis_conflict.mean(), 1060 | loss_dis_align=args.lambda_dis_align * loss_dis_align.mean(), 1061 | confusion=confusion_numpy, 1062 | global_count=self.confusion.global_count_.cpu().numpy() 1063 | ) 1064 | if step > 0 and (step % args.valid_freq == 0) and (args.dataset != 'waterbird'): #and (args.dataset != 'civilcomments'): 1065 | print('################################ #epoch {:d} ############################\n'.format(epoch)) 1066 | self.board_ours_acc(step, model = 'debias', mode = 'no_dummy', n_group = n_group) 1067 | 1068 | step += 1 1069 | 1070 | if args.use_lr_decay and args.dataset == 'waterbird': 1071 | self.scheduler_b.step() 1072 | self.scheduler_l.step() 1073 | 1074 | if args.use_lr_decay and epoch % args.lr_decay_epoch == 0 and args.dataset == 'waterbird': 1075 | print('******* learning rate decay .... 
********') 1076 | print(f"self.optimizer_b lr: { self.optimizer_b.param_groups[-1]['lr']}") 1077 | print(f"self.optimizer_l lr: { self.optimizer_l.param_groups[-1]['lr']}") 1078 | 1079 | 1080 | confusion_numpy = correction_matrix.cpu().numpy() 1081 | if (args.dataset == 'waterbird'): 1082 | self.board_ours_loss( 1083 | step=epoch, 1084 | loss_dis_conflict=loss_dis_conflict.mean(), 1085 | loss_dis_align=args.lambda_dis_align * loss_dis_align.mean(), 1086 | confusion=confusion_numpy, 1087 | global_count=self.confusion.global_count_.cpu().numpy() 1088 | ) 1089 | if (args.dataset == 'waterbird'): 1090 | self.board_ours_acc(epoch, model = 'debias', mode = 'no_dummy' , n_group = 4) 1091 | self.confusion.global_count_ = torch.zeros(self.num_classes, self.num_classes) 1092 | if args.avg_type == 'epoch': 1093 | self.confusion.initiate_parameter() 1094 | 1095 | 1096 | 1097 | def test_ours(self, args): 1098 | if args.dataset == 'cmnist': 1099 | self.model_l = get_model('mlp_DISENTANGLE', self.num_classes, bias = True).to(self.device) 1100 | self.model_b = get_model('mlp_DISENTANGLE', self.num_classes, bias = True).to(self.device) 1101 | elif args.dataset == 'waterbird': 1102 | self.model_l = get_model('resnet_50_pretrained', self.num_classes, bias = True).to(self.device) 1103 | self.model_b = get_model('resnet_50_pretrained', self.num_classes, bias = True).to(self.device) 1104 | elif args.dataset == 'civilcomments': 1105 | print('----------------- Load Bert --------------------') 1106 | self.model_l = get_model('bert-base-uncased_pt', self.num_classes, bias = True).to(self.device) 1107 | self.model_b = get_model('bert-base-uncased_pt', self.num_classes, bias = True).to(self.device) 1108 | else: 1109 | if self.args.use_resnet20: # Use this option only for comparing with LfF 1110 | self.model_l = get_model('ResNet20_OURS', self.num_classes).to(self.device) 1111 | self.model_b = get_model('ResNet20_OURS', self.num_classes).to(self.device) 1112 | print('our resnet20....') 1113 | else: 1114 | self.model_l = get_model('resnet_DISENTANGLE', self.num_classes).to(self.device) 1115 | self.model_b = get_model('resnet_DISENTANGLE', self.num_classes).to(self.device) 1116 | 1117 | self.model_l.load_state_dict(torch.load(os.path.join(args.pretrained_path, 'model_200_37.55377996312232.th'))['state_dict']) 1118 | self.board_ours_acc(-1, model = 'debias', mode = 'no_dummy', n_group = 16, eval = True) 1119 | --------------------------------------------------------------------------------
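The learner.py listing above leans on several helpers and patterns that are easier to follow with small, self-contained sketches. First, board_ours_acc turns the per-group accuracy dictionaries returned by evaluate_ours into an average and a worst-group number via np.nanmean / np.nanmin. The sketch below reproduces just that aggregation; the dictionary contents and the four-group layout are illustrative stand-ins, not repository data.

import numpy as np

def summarize_group_accuracy(group_accs):
    """Aggregate a {group_id: accuracy} dict the way board_ours_acc does:
    nanmean for the average and nanmin for the worst group, so groups that
    never appear in a split (NaN entries) are ignored."""
    accs = np.array(list(group_accs.values()), dtype=float)
    return float(np.nanmean(accs)), float(np.nanmin(accs))

if __name__ == "__main__":
    # Hypothetical 4-group result (e.g. 2 classes x 2 backgrounds on waterbird).
    group_accs = {0: 0.98, 1: 0.72, 2: 0.91, 3: float("nan")}
    avg_acc, worst_acc = summarize_group_accuracy(group_accs)
    print(f"avg: {avg_acc:.3f} || worst: {worst_acc:.3f}")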
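The concat_dummy / no_dummy / no_dummy_input helpers are closures around PyTorch forward hooks: they stash a module's activation in a Python list and may rewrite the module output before it reaches the classifier head. A minimal version of the no_dummy pattern on a torchvision ResNet-18 is shown below; the model and layer names follow torchvision rather than the repository's wrappers, and random weights are used since pretraining is irrelevant to the hook mechanics.

import torch
from torchvision.models import resnet18

def capture_output(buffer):
    """Forward hook mirroring no_dummy in learner.py: record the squeezed
    activation and pass the output through unchanged."""
    def hook(module, inputs, output):
        buffer.append(output.squeeze())
        return output  # returning the same tensor leaves the forward pass unmodified
    return hook

model = resnet18()                  # random weights are fine for demonstrating the hook
feats = []
handle = model.avgpool.register_forward_hook(capture_output(feats))
with torch.no_grad():
    _ = model(torch.randn(4, 3, 224, 224))
handle.remove()                     # detach the hook once the features are captured

z = feats[0]                        # shape (4, 512): pooled features before model.fc
logits = model.fc(z)                # same head call the training loop applies to z_b / z_l
print(z.shape, logits.shape)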
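The mixing strength alpha in train_ours is annealed toward 0.5 with sigmoid_rampup(current, ramp_length), a helper defined outside this file. The sketch below uses the common form of that schedule from the mean-teacher literature, exp(-5 * (1 - t)^2) with t = current / ramp_length clipped to [0, 1]; the repository's version may differ in constants, so treat this as an assumed equivalent.

import numpy as np

def sigmoid_rampup(current, rampup_length):
    """Assumed sigmoid ramp-up: 0 near step 0, 1 once current >= rampup_length."""
    if rampup_length == 0:
        return 1.0
    phase = 1.0 - np.clip(current, 0.0, rampup_length) / rampup_length
    return float(np.exp(-5.0 * phase * phase))

# alpha climbs smoothly toward the 0.5 cap used in train_ours.
for step in (0, 100, 500, 1000):
    alpha = sigmoid_rampup(step, 1000) * 0.5
    print(step, round(alpha, 4))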
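sample_margin_ema_b keeps a per-sample exponential moving average of the biased model's tempered softmax outputs; the loop writes into it with update(probs, index) and reads back parameter[index] to form pseudo labels. That class lives outside this file, so the snippet below is a minimal stand-in that only matches how the training loop calls it; the attribute names, the 0.9 momentum, and the first-visit initialization are assumptions.

import torch
import torch.nn.functional as F

class SampleEMA:
    """Per-sample EMA buffer indexed by dataset position, used here the way
    train_ours uses sample_margin_ema_b.update(...) / .parameter[index]."""
    def __init__(self, num_samples, num_classes, momentum=0.9):
        self.momentum = momentum
        self.parameter = torch.zeros(num_samples, num_classes)
        self.initialized = torch.zeros(num_samples, dtype=torch.bool)

    def update(self, probs, index):
        old = self.parameter[index]
        seen = self.initialized[index].unsqueeze(1)
        new = torch.where(seen,
                          self.momentum * old + (1 - self.momentum) * probs,
                          probs)                     # first visit: copy directly
        self.parameter[index] = new
        self.initialized[index] = True

ema = SampleEMA(num_samples=8, num_classes=3)
logits = torch.randn(4, 3)
index = torch.tensor([0, 2, 5, 7])
tau = 0.1                                            # temperature, playing the role of args.tau
ema.update(F.softmax(logits.detach() / tau, dim=1), index)
pred_align_mv = ema.parameter[index].clone().detach()
pseudo_label = pred_align_mv.argmax(dim=1)           # pseudo bias label per sample
print(pseudo_label)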
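The confusion object accumulates, per pseudo (bias) label, how often each true class occurs; column-normalizing that matrix gives an estimated class prior conditioned on the pseudo group, and indexing it with each sample's pseudo label yields the per-sample correction handed to the loss. The sketch below reproduces only that indexing arithmetic with a hand-built count matrix; the counts themselves are made up.

import torch

num_classes = 3
# Hypothetical running counts: rows = true class, columns = pseudo (bias) label.
counts = torch.tensor([[90.,  5.,  5.],
                       [ 8., 80., 12.],
                       [ 2., 15., 83.]])

# Column-normalize so each column sums to one, matching the
# correction_matrix / correction_matrix.sum(dim=0, ...) step in train_ours.
correction_matrix = counts / (counts.sum(dim=0, keepdim=True) + 1e-4)

# Pick the matching prior for each sample's pseudo label and transpose so every
# row is one sample's correction vector (correction_delta in the loop).
pseudo_label = torch.tensor([0, 2, 2, 1])
correction_delta = correction_matrix[:, pseudo_label].t()   # shape (batch, num_classes)
print(correction_delta)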
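group_mixUp is defined outside this file; the orphaned fragment above train_ours and the way the loop consumes return_dict show that it mixes features and correction vectors with a partner sample and returns both labels plus the mixing weight. The stand-in below keeps that interface but uses a deliberately simple pairing rule, a random permutation of the batch; the real routine's group-aware pairing and its choice of lam are not reproduced here.

import torch

def group_mixup_sketch(feature, correction, label, lam):
    """Simplified stand-in for group_mixUp: mix every sample with a randomly
    permuted partner, interpolating features and per-sample correction vectors
    with the same weight lam, and return both labels for the interpolated loss."""
    perm = torch.randperm(feature.size(0))
    return {
        "mixed_feature": lam * feature + (1 - lam) * feature[perm],
        "mixed_correction": lam * correction + (1 - lam) * correction[perm],
        "label_majority": label,        # consumed as label_a in the loop
        "label_minority": label[perm],  # consumed as label_b in the loop
        "lam": lam,
    }

feature = torch.randn(4, 8)
correction = torch.full((4, 3), 1.0 / 3)   # uniform priors as a placeholder
label = torch.tensor([0, 1, 2, 1])
out = group_mixup_sketch(feature, correction, label, lam=0.7)
print(out["mixed_feature"].shape, out["lam"])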
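Two losses drive train_ours: GeneralizedCELoss(q=args.q) on the biased branch and LogitCorrectionLoss(eta=1.0) on the target branch, both implemented in the module/ package rather than in learner.py. The versions below are sketches of the standard formulations they are presumably built on: the GCE loss (1 - p_y^q)/q, and a logit-adjusted cross-entropy that adds eta * log(correction) to the logits before the softmax. The exact form, the eta placement, and the epsilon are assumptions, and the tensors are random placeholders.

import torch
import torch.nn.functional as F

def generalized_ce(logits, target, q=0.7):
    """GCE sketch: (1 - p_y^q) / q per sample; approaches plain CE as q -> 0."""
    p_y = F.softmax(logits, dim=1).gather(1, target[:, None]).squeeze(1)
    return (1.0 - p_y.clamp_min(1e-8) ** q) / q

def logit_correction_ce(logits, target, correction, eta=1.0):
    """Logit-correction sketch: shift the logits by eta * log(prior) for the
    sample's bias group, then take per-sample cross-entropy."""
    adjusted = logits + eta * torch.log(correction.clamp_min(1e-8))
    return F.cross_entropy(adjusted, target, reduction="none")

# How the training loop combines them with the mixup weight and lambda_dis_align.
logits_target = torch.randn(4, 3)
logits_bias = torch.randn(4, 3)
label = torch.tensor([0, 1, 2, 1])       # plays the role of label / label_a
label_b = torch.tensor([2, 1, 0, 0])     # partners' labels from the mixup pairing
correction = torch.softmax(torch.randn(4, 3), dim=1)   # placeholder per-sample priors
lam, lambda_dis_align = 0.7, 2.0

loss_conflict = lam * logit_correction_ce(logits_target, label, correction) \
    + (1 - lam) * logit_correction_ce(logits_target, label_b, correction)
loss_align = generalized_ce(logits_bias, label, q=0.7)
loss = loss_conflict.mean() + lambda_dis_align * loss_align.mean()
print(loss.item())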
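For civilcomments, train_ours builds the target-model optimizer with the usual BERT recipe: biases and LayerNorm weights are excluded from weight decay, AdamW drives the updates, a linear warmup/decay schedule is stepped per batch, and gradients are clipped to norm 1.0. get_bert_scheduler is defined elsewhere in the repository; the transformers helper used below is assumed to be a reasonable equivalent, and the epoch and batch counts are made up.

import torch
from transformers import BertForSequenceClassification, get_linear_schedule_with_warmup

model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# Exclude biases and LayerNorm weights from weight decay, as in train_ours.
no_decay = ["bias", "LayerNorm.weight"]
grouped_params = [
    {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
     "weight_decay": 0.01},
    {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
     "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(grouped_params, lr=2e-5, eps=1e-8)

# Linear warmup then linear decay over the whole run, stepped once per batch.
num_epochs, batches_per_epoch = 3, 1000          # hypothetical sizes
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=num_epochs * batches_per_epoch)

# Inside the batch loop the listing clips before stepping, then advances the schedule:
# torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# optimizer.step(); scheduler.step(); optimizer.zero_grad()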