├── images
│   └── main.jpg
├── .gitmodules
├── LICENSE
├── module
│   ├── util.py
│   ├── mlp.py
│   ├── resnet.py
│   └── activations.py
├── test.py
├── README.md
├── train.py
├── data
│   ├── civilcomments.py
│   ├── grouper.py
│   ├── __init__.py
│   └── util.py
├── util.py
└── learner.py
/images/main.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shengliu66/LC/HEAD/images/main.jpg
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "module/lc-loss"]
2 | path = module/lc-loss
3 | url = https://github.com/amazon-science/lc-loss.git
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Sheng Liu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/module/util.py:
--------------------------------------------------------------------------------
1 | ''' Modified from https://github.com/alinlab/LfF/blob/master/module/util.py '''
2 |
3 | import torch.nn as nn
4 | from module.resnet import resnet20
5 | from module.mlp import *
6 | from torchvision.models import resnet18, resnet50
7 | from transformers import BertForSequenceClassification, BertConfig
8 |
9 | def get_model(model_tag, num_classes, bias = True):
10 | if model_tag == "ResNet20":
11 | return resnet20(num_classes)
12 | elif model_tag == "ResNet20_OURS":
13 | model = resnet20(num_classes)
14 | model.fc = nn.Linear(128, num_classes)
15 | return model
16 | elif model_tag == "ResNet18":
17 | print('bringing no pretrained resnet18 ...')
18 | model = resnet18(pretrained=False)
19 | model.fc = nn.Linear(512, num_classes)
20 | return model
21 | elif model_tag == "MLP":
22 | return MLP(num_classes=num_classes)
23 | elif model_tag == "mlp_DISENTANGLE":
24 | return MLP_DISENTANGLE(num_classes=num_classes, bias = bias)
25 | elif model_tag == "mlp_DISENTANGLE_EASY":
26 | return MLP_DISENTANGLE_EASY(num_classes=num_classes)
27 | elif model_tag == 'resnet_DISENTANGLE':
28 | print('bringing no pretrained resnet18 disentangle...')
29 | model = resnet18(pretrained=False)
30 | model.fc = nn.Linear(1024//2, num_classes)
31 | return model
32 | elif model_tag == 'resnet_DISENTANGLE_pretrained':
33 | print('bringing pretrained resnet18 disentangle...')
34 | model = resnet18(pretrained=True)
35 | model.fc = nn.Linear(1024//2, num_classes)
36 | return model
37 |
38 | elif model_tag == 'resnet_50_pretrained':
39 | print('bringing pretrained resnet50 for water bird...')
40 | model = resnet50(pretrained=True)
41 | model.fc = nn.Linear(2048, num_classes)
42 | return model
43 | elif model_tag == 'resnet_50':
44 |         print('bringing non-pretrained resnet50 for water bird...')
45 | model = resnet50(pretrained=False)
46 | model.fc = nn.Linear(2048, num_classes)
47 | return model
48 | elif 'bert' in model_tag:
49 | if model_tag[-3:] == '_pt':
50 | model_name = model_tag[:-3]
51 | else:
52 | model_name = model_tag
53 |
54 | config_class = BertConfig
55 | model_class = BertForSequenceClassification
56 | config = config_class.from_pretrained(model_name,
57 | num_labels=num_classes,
58 | finetuning_task='civilcomments')
59 | model = model_class.from_pretrained(model_name, from_tf=False,
60 | config=config)
61 | model.activation_layer = 'bert.pooler.activation'
62 | return model
63 | else:
64 | raise NotImplementedError
65 |
--------------------------------------------------------------------------------
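A quick usage sketch for `module/util.py` above (illustrative only, not part of the repository): `get_model` is a small factory where the tag string selects the backbone and the classifier head is re-sized for `num_classes`.

```python
# Illustrative sketch of calling module/util.get_model (assumes the repo root is on PYTHONPATH).
from module.util import get_model

# ResNet-18 backbone with a fresh 10-way head; ImageNet weights are only
# loaded by the *_pretrained tags.
model = get_model("ResNet18", num_classes=10)
print(model.fc)  # Linear(in_features=512, out_features=10, bias=True)

# The MLP variants expect flattened 3x28x28 inputs (e.g. Colored MNIST).
mlp = get_model("MLP", num_classes=10)
```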
/test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import random
4 | from learner import Learner
5 | import argparse
6 |
7 | if __name__ == '__main__':
8 | parser = argparse.ArgumentParser(description='Avoiding spurious correlations via logit correction')
9 |
10 | # training
11 | parser.add_argument("--batch_size", help="batch_size", default=256, type=int)
12 | parser.add_argument("--lr",help='learning rate',default=1e-3, type=float)
13 | parser.add_argument("--weight_decay",help='weight_decay',default=0.0, type=float)
14 | parser.add_argument("--momentum",help='momentum',default=0.9, type=float)
15 | parser.add_argument("--num_workers", help="workers number", default=16, type=int)
16 | parser.add_argument("--exp", help='experiment name', default='Test', type=str)
17 | parser.add_argument("--device", help="cuda or cpu", default='cuda', type=str)
18 | parser.add_argument("--num_steps", help="# of iterations", default= 500 * 100, type=int)
19 | parser.add_argument("--target_attr_idx", help="target_attr_idx", default= 0, type=int)
20 | parser.add_argument("--bias_attr_idx", help="bias_attr_idx", default= 1, type=int)
21 | parser.add_argument("--dataset", help="data to train, [cmnist, cifar10, bffhq]", default= 'cmnist', type=str)
22 | parser.add_argument("--percent", help="percentage of conflict", default= "1pct", type=str)
23 | parser.add_argument("--use_lr_decay", action='store_true', help="whether to use learning rate decay")
24 | parser.add_argument("--lr_decay_step", help="learning rate decay steps", type=int, default=10000)
25 | parser.add_argument("--q", help="GCE parameter q", type=float, default=0.7)
26 | parser.add_argument("--lr_gamma", help="lr gamma", type=float, default=0.1)
27 | parser.add_argument("--lambda_dis_align", help="lambda_dis in Eq.2", type=float, default=1.0)
28 | parser.add_argument("--lambda_swap_align", help="lambda_swap_b in Eq.3", type=float, default=1.0)
29 | parser.add_argument("--lambda_swap", help="lambda swap (lambda_swap in Eq.4)", type=float, default=1.0)
30 | parser.add_argument("--ema_alpha", help="use weight mul", type=float, default=0.7)
31 | parser.add_argument("--curr_step", help="curriculum steps", type=int, default= 0)
32 | parser.add_argument("--use_type0", action='store_true', help="whether to use type 0 CIFAR10C")
33 | parser.add_argument("--use_type1", action='store_true', help="whether to use type 1 CIFAR10C")
34 | parser.add_argument("--use_resnet20", help="Use Resnet20", action="store_true") # ResNet 20 was used in Learning From Failure CifarC10 (We used ResNet18 in our paper)
35 | parser.add_argument("--model", help="which network, [MLP, ResNet18, ResNet20, ResNet50]", default= 'MLP', type=str)
36 |
37 | # logging
38 |     parser.add_argument("--log_dir", help='path for saving models & logs', default='./log', type=str)
39 |     parser.add_argument("--data_dir", help='path for loading data', default='dataset', type=str)
40 | parser.add_argument("--valid_freq", help='frequency to evaluate on valid/test set', default=500, type=int)
41 | parser.add_argument("--log_freq", help='frequency to log on tensorboard', default=500, type=int)
42 | parser.add_argument("--save_freq", help='frequency to save model checkpoint', default=1000, type=int)
43 | parser.add_argument("--wandb", action="store_true", help="whether to use wandb")
44 | parser.add_argument("--tensorboard", action="store_true", help="whether to use tensorboard")
45 |
46 | parser.add_argument("--tau", help="loss tau", type=float, default=1)
47 | parser.add_argument("--avg_type", help="pya estimation types", type=str, default='mv')
48 |
49 |
50 | # experiment
51 | parser.add_argument("--pretrained_path", help="path for pretrained model", type=str)
52 |
53 | args = parser.parse_args()
54 |
55 | # init learner
56 | learner = Learner(args)
57 |
58 | # actual training
59 | print('Official Pytorch Code of "Avoiding spurious correlations via logit correction"')
60 | print('Test starts ...')
61 |
62 | learner.test_ours(args)
63 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Avoiding spurious correlations via logit correction
2 | This repository provides the official PyTorch implementation of the following paper:
3 | > Avoiding spurious correlations via logit correction
4 | > [Sheng Liu](https://shengliu66.github.io/) (NYU), Xu Zhang (Amazon), Nitesh Sekhar (Amazon), Yue Wu (Amazon), Prateek Singhal (Amazon), Carlos Fernandez-Granda (NYU)
5 | > ICLR 2023
6 |
7 | > Paper: [Arxiv](https://arxiv.org/abs/2212.01433)
8 |
9 | **Abstract:**
10 | *Empirical studies suggest that machine learning models trained with empirical risk minimization (ERM) often rely on attributes that may be spuriously correlated with the class labels. Such models typically lead to poor performance during inference for data lacking such correlations. In this work, we explicitly consider a situation where potential spurious correlations are present in the majority of training data. In contrast with existing approaches, which use the ERM model outputs to detect the samples without spurious correlations and either heuristically upweight or upsample those samples, we propose the logit correction (LC) loss, a simple yet effective improvement on the softmax cross-entropy loss, to correct the sample logit. We demonstrate that minimizing the LC loss is equivalent to maximizing the group-balanced accuracy, so the proposed LC could mitigate the negative impacts of spurious correlations. Our extensive experimental results further reveal that the proposed LC loss outperforms state-of-the-art solutions on multiple popular benchmarks by a large margin, an average 5.5\% absolute improvement, without access to spurious attribute labels. LC is also competitive with oracle methods that make use of the attribute labels.*
11 |
12 |
13 |
14 |
15 |
16 | ## PyTorch Implementation
17 | ### Installation
18 | Clone this repository.
19 | ```
20 | git clone --recursive https://github.com/shengliu66/LC
21 | cd LC
22 | pip install -r requirements.txt
23 | ```
24 | ### Datasets
25 | - **Waterbirds:** Download the Waterbirds dataset from [here](https://nlp.stanford.edu/data/dro/waterbird_complete95_forest2water2.tar.gz).
26 | - In that directory, our code expects `data/waterbird_complete95_forest2water2/` with `metadata.csv` inside.
27 | - Specify that directory via `--data_dir`.
28 | 
29 | - **CivilComments:** Download the CivilComments dataset from [here](https://worksheets.codalab.org/rest/bundles/0x8cd3de0634154aeaad2ee6eb96723c6e/contents/blob/).
30 | - In that directory, our code expects a folder `data` with the downloaded dataset.
31 | - Specify that directory via `--data_dir`.
32 |
33 | ### Running LC
34 | #### Waterbird
35 | ```
36 | python train.py --dataset waterbird --exp=waterbird_ours --lr=1e-3 --weight_decay=1e-4 --curr_epoch=50 --lr_decay_epoch 50 --use_lr_decay --lambda_dis_align=2.0 --ema_alpha 0.5 --tau 0.1 --train_ours --q 0.8 --avg_type batch --data_dir /dir/to/data/
37 | ```
38 | #### CivilComments
39 | ```
40 | python train.py --dataset civilcomments --exp=civilcomments_ours --lr=1e-5 --q 0.7 --log_freq 400 --valid_freq 400 --weight_decay=1e-2 --curr_step 400 --use_lr_decay --num_epochs 3 --lr_decay_step 4000 --lambda_dis_align=0.1 --ema_alpha 0.9 --tau 1.0 --avg_type mv_batch --train_ours --data_dir /dir/to/data/
41 | ```
42 |
43 | ## Monitoring Performance
44 | We use [Weights & Biases](https://wandb.ai/site) to monitor training. Follow the [quickstart](https://docs.wandb.ai/quickstart) to install and log in to W&B, then add the `--wandb` argument.
45 |
46 |
47 | ### Evaluate Models
48 | To evaluate our pretrained models, run the following command, filling in the checkpoint path and dataset name.
49 | ```
50 | python test.py --pretrained_path= --dataset=
51 | ```
52 |
53 | ### Citations
54 |
55 | ### BibTeX
56 | ```bibtex
57 | @Inproceedings{
58 | Liu2023,
59 | author = {Sheng Liu and Xu Zhang and Nitesh Sekhar and Yue Wu and Prateek Singhal and Carlos Fernandez-Granda},
60 | title = {Avoiding spurious correlations via logit correction},
61 | year = {2023},
62 | url = {https://www.amazon.science/publications/avoiding-spurious-correlations-via-logit-correction},
63 | booktitle = {ICLR 2023},
64 | }
65 | ```
66 |
67 | ### Contact
68 | Sheng Liu (shengliu@nyu.edu)
69 |
70 | ### Acknowledgments
71 | This work was mainly done while the authors were interning at Amazon Science.
72 | Our pytorch implementation is based on [Disentangled](https://github.com/kakaoenterprise/Learning-Debiased-Disentangled).
73 | We would like to thank the authors for making their code public.
74 |
--------------------------------------------------------------------------------
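The README abstract describes the LC loss as a correction applied to the sample logits before the softmax cross-entropy. The sketch below illustrates only the generic logit-adjustment idea with a hypothetical class prior; it is not the repository's group-aware LC loss, which is implemented in `learner.py` using the estimates kept in `util.py` (here `tau` simply plays the role of the `--tau` scaling argument).

```python
# Illustrative sketch only: a generic logit-adjustment-style correction,
# NOT the exact LC loss of this repository (see learner.py and the paper).
import torch
import torch.nn.functional as F

def corrected_cross_entropy(logits, targets, log_prior, tau=1.0):
    """Cross-entropy on logits shifted by tau * log_prior (one value per class)."""
    return F.cross_entropy(logits + tau * log_prior, targets)

# Hypothetical class prior estimated from training data.
prior = torch.tensor([0.9, 0.1])
logits = torch.randn(8, 2)
targets = torch.randint(0, 2, (8,))
loss = corrected_cross_entropy(logits, targets, prior.log(), tau=1.0)
```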
/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import random
4 | from learner import Learner
5 | import argparse
6 |
7 | if __name__ == '__main__':
8 | parser = argparse.ArgumentParser(description='Avoiding spurious correlations via logit correction')
9 |
10 | # training
11 | parser.add_argument("--batch_size", help="batch_size", default=256, type=int)
12 | parser.add_argument("--lr",help='learning rate',default=1e-3, type=float)
13 | parser.add_argument("--weight_decay",help='weight_decay',default=0.0, type=float)
14 | parser.add_argument("--momentum",help='momentum',default=0.9, type=float)
15 | parser.add_argument("--num_workers", help="workers number", default=4, type=int)
16 | parser.add_argument("--exp", help='experiment name', default='debugging', type=str)
17 | parser.add_argument("--device", help="cuda or cpu", default='cuda', type=str)
18 | parser.add_argument("--num_steps", help="# of iterations", default= 500 * 100, type=int)
19 | parser.add_argument("--num_epochs", help="# of epochs", default= 300, type=int)
20 | parser.add_argument("--target_attr_idx", help="target_attr_idx", default= 0, type=int)
21 | parser.add_argument("--bias_attr_idx", help="bias_attr_idx", default= 1, type=int)
22 | parser.add_argument("--dataset", help="data to train, [cmnist, cifar10, bffhq]", default= 'cmnist', type=str)
23 | parser.add_argument("--percent", help="percentage of conflict", default= "1pct", type=str)
24 | parser.add_argument("--use_lr_decay", action='store_true', help="whether to use learning rate decay")
25 | parser.add_argument("--lr_decay_step", help="learning rate decay steps", type=int, default=10000)
26 | parser.add_argument("--lr_decay_epoch", help="learning rate decay epochs", type=int, default=70)
27 | parser.add_argument("--q", help="GCE parameter q", type=float, default=0.7)
28 | parser.add_argument("--lr_gamma", help="lr gamma", type=float, default=0.1)
29 | parser.add_argument("--lambda_dis_align", help="lambda_dis in Eq.2", type=float, default=1.0)
30 | parser.add_argument("--lambda_swap_align", help="lambda_swap_b in Eq.3", type=float, default=1.0)
31 | parser.add_argument("--lambda_swap", help="lambda swap (lambda_swap in Eq.4)", type=float, default=1.0)
32 | parser.add_argument("--ema_alpha", help="use weight mul", type=float, default=0.995)
33 | parser.add_argument("--curr_step", help="curriculum steps", type=int, default= 0)
34 | parser.add_argument("--curr_epoch", help="curriculum epochs", type=int, default= 0)
35 | parser.add_argument("--use_type0", action='store_true', help="whether to use type 0 CIFAR10C")
36 | parser.add_argument("--use_type1", action='store_true', help="whether to use type 1 CIFAR10C")
37 | parser.add_argument("--use_resnet20", help="Use Resnet20", action="store_true") # ResNet 20 was used in Learning From Failure CifarC10 (We used ResNet18 in our paper)
38 | parser.add_argument("--model", help="which network, [MLP, ResNet18, ResNet20, ResNet50]", default= 'MLP', type=str)
39 |
40 | # logging
41 | parser.add_argument("--log_dir", help='path for saving model', default='./log', type=str)
42 | parser.add_argument("--data_dir", help='path for loading data', default='dataset', type=str)
43 | parser.add_argument("--valid_freq", help='frequency to evaluate on valid/test set', default=500, type=int)
44 | parser.add_argument("--log_freq", help='frequency to log on tensorboard', default=500, type=int)
45 | parser.add_argument("--save_freq", help='frequency to save model checkpoint', default=1000, type=int)
46 | parser.add_argument("--wandb", action="store_true", help="whether to use wandb")
47 | parser.add_argument("--tensorboard", action="store_true", help="whether to use tensorboard")
48 |
49 | # experiment
50 | parser.add_argument("--train_ours", action="store_true", help="whether to train our method")
51 | parser.add_argument("--train_vanilla", action="store_true", help="whether to train vanilla")
52 |
53 |
54 | parser.add_argument("--alpha", help="mixup alpha", type=float, default=16)
55 | parser.add_argument('--proto_m', default=0.95, type=float,
56 | help='momentum for computing the momving average of prototypes')
57 | parser.add_argument('--temperature', default=0.1, type=float,
58 | help='contrastive temperature')
59 | parser.add_argument("--tau", help="loss tau", type=float, default=1)
60 | parser.add_argument("--avg_type", help="pya estimation types", type=str, default='mv')
61 |
62 | args = parser.parse_args()
63 |
64 | seed = 222
65 | torch.manual_seed(seed)
66 | torch.cuda.manual_seed(seed)
67 | torch.cuda.manual_seed_all(seed)
68 | np.random.seed(seed)
69 | random.seed(seed)
70 | torch.backends.cudnn.benchmark = False
71 | torch.backends.cudnn.deterministic = True
72 |
73 | # init learner
74 | learner = Learner(args)
75 |
76 |
77 | # actual training
78 | print('Official Pytorch Code of "Avoiding spurious correlations via logit correction"')
79 | print('Training starts ...')
80 |
81 | if args.train_ours:
82 | learner.train_ours(args)
83 | elif args.train_vanilla:
84 | learner.train_vanilla(args)
85 | else:
86 | print('choose one of the two options ...')
87 | import sys
88 | sys.exit(0)
89 |
--------------------------------------------------------------------------------
/module/mlp.py:
--------------------------------------------------------------------------------
1 | ''' Modified from https://github.com/alinlab/LfF/blob/master/module/mlp.py'''
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | class MLP_DISENTANGLE(nn.Module):
8 | def __init__(self, num_classes = 10, bias = True):
9 | super(MLP_DISENTANGLE, self).__init__()
10 | self.feature = nn.Sequential(
11 | nn.Linear(3*28*28, 100, bias = bias),
12 | nn.ReLU(),
13 | nn.Linear(100, 100, bias = bias),
14 | nn.ReLU(),
15 | nn.Linear(100, 16, bias = bias),
16 | nn.ReLU(),
17 | )
18 | self.fc = nn.Sequential(
19 | nn.Linear(16, num_classes))
20 |
21 | # self.projection_head = nn.Sequential(
22 | # nn.ReLU(),
23 | # nn.Linear(16,16, bias = bias))
24 |
25 |
26 | def extract(self, x):
27 | x = x.view(x.size(0), -1) / 255
28 | feat = self.feature(x)
29 | # feat = self.projection_head(x)
30 | return feat
31 |
32 | def extract_rawfeature(self, x):
33 | x = x.view(x.size(0), -1) / 255
34 | feat = self.feature(x)
35 | return feat
36 |
37 | def predict(self, x):
38 | prediction = self.fc(x)
39 | return prediction
40 |
41 | def forward(self, x, mode=None, return_feat=False):
42 | x = x.view(x.size(0), -1) / 255
43 | feat = x = self.feature(x)
44 | final_x = self.fc(x)
45 | if mode == 'tsne' or mode == 'mixup':
46 | return x, final_x
47 | else:
48 | if return_feat:
49 | return final_x, feat
50 | else:
51 | return final_x
52 |
53 |
54 |
55 |
56 | class MLP_DISENTANGLE_EASY(nn.Module):
57 | def __init__(self, num_classes = 10):
58 | super(MLP_DISENTANGLE_EASY, self).__init__()
59 | self.feature = nn.Sequential(
60 | nn.Linear(3*28*28, 50),
61 | nn.ReLU(),
62 | nn.Linear(50, 50),
63 | nn.ReLU(),
64 | nn.Linear(50, 16),
65 | nn.ReLU()
66 | )
67 | self.fc = nn.Linear(32, num_classes)
68 | self.fc2 = nn.Linear(16, num_classes)
69 |
70 | def extract(self, x):
71 | x = x.view(x.size(0), -1) / 255
72 | feat = self.feature(x)
73 | return feat
74 |
75 | def predict(self, x):
76 |         prediction = self.fc2(x)  # fc2 matches the 16-d feature output
77 | return prediction
78 |
79 | def forward(self, x, mode=None, return_feat=False):
80 | x = x.view(x.size(0), -1) / 255
81 | feat = x = self.feature(x)
82 |         final_x = self.fc2(x)  # fc2 matches the 16-d feature output
83 | if mode == 'tsne' or mode == 'mixup':
84 | return x, final_x
85 | else:
86 | if return_feat:
87 | return final_x, feat
88 | else:
89 | return final_x
90 |
91 |
92 |
93 |
94 | class MLP(nn.Module):
95 | def __init__(self, num_classes = 10):
96 | super(MLP, self).__init__()
97 | self.feature = nn.Sequential(
98 | nn.Linear(3*28*28, 100),
99 | nn.ReLU(),
100 | nn.Linear(100, 100),
101 | nn.ReLU(),
102 | nn.Linear(100, 16),
103 | nn.ReLU()
104 | )
105 | self.classifier = nn.Linear(16, num_classes)
106 |
107 |
108 | def forward(self, x, mode=None, return_feat=False):
109 | x = x.view(x.size(0), -1) / 255
110 | feat = x = self.feature(x)
111 | final_x = self.classifier(x)
112 | if mode == 'tsne' or mode == 'mixup':
113 | return x, final_x
114 | else:
115 | if return_feat:
116 | return final_x, feat
117 | else:
118 | return final_x
119 |
120 | class MLP_DISENTANGLE_SHENG(nn.Module):
121 | def __init__(self, num_classes = 10):
122 | super(MLP_DISENTANGLE_SHENG, self).__init__()
123 | self.feature = nn.Sequential(
124 | nn.Linear(3*28*28, 100),
125 | nn.ReLU(),
126 | nn.Linear(100, 100),
127 | nn.ReLU(),
128 | nn.Linear(100, 16),
129 | nn.ReLU()
130 | )
131 |
132 | self.task_head_target = nn.Sequential(
133 | nn.Linear(16,16),
134 | nn.ReLU(),
135 | # nn.Linear(16,16),
136 | # nn.ReLU(),
137 | )
138 |
139 | self.task_head_bias = nn.Sequential(
140 | nn.Linear(16,16),
141 | nn.ReLU(),
142 | # nn.Linear(16,16),
143 | # nn.ReLU(),
144 | )
145 |
146 | self.fc_target = nn.Linear(16, num_classes)
147 | self.fc_bias = nn.Linear(16, num_classes)
148 | # self.classifier_bias = nn.Linear(16, num_classes)
149 |
150 |
151 | def extract_target(self, x):
152 | x = x.view(x.size(0), -1) / 255
153 | x = self.feature(x)
154 | feat = self.task_head_target(x)
155 | return feat
156 |
157 | def extract_bias(self, x):
158 | x = x.view(x.size(0), -1) / 255
159 | x = self.feature(x)
160 | feat = self.task_head_bias(x)
161 | return feat
162 |
163 | def predict_target(self, x):
164 | prediction = self.fc_target(x)
165 | return prediction
166 |
167 | def predict_bias(self, x):
168 | prediction = self.fc_bias(x)
169 | return prediction
170 |
171 | def forward(self, x, mode=None, return_feat=False):
172 | x = x.view(x.size(0), -1) / 255
173 | feat = x = self.feature(x)
174 | feat_target = self.task_head_target(x)
175 | feat_bias = self.task_head_bias(x)
176 |
177 | final_x_target = self.fc_target(feat_target)
178 | final_x_bias = self.fc_bias(feat_bias)
179 |
180 | if mode == 'tsne' or mode == 'mixup':
181 | return feat_target, feat_bias, final_x_target, final_x_bias
182 | else:
183 | if return_feat:
184 | return final_x_target, feat_target, final_x_bias, feat_bias
185 | else:
186 | return final_x_target, final_x_bias
187 |
188 |
189 | class Noise_MLP(nn.Module):
190 | def __init__(self, n_dim=16, n_layer=3):
191 | super(Noise_MLP, self).__init__()
192 |
193 | layers = []
194 | for i in range(n_layer):
195 | layers.append(nn.Linear(n_dim, n_dim))
196 | layers.append(nn.LeakyReLU(0.2))
197 |
198 | self.style = nn.Sequential(*layers)
199 |
200 | def forward(self, z):
201 | x = self.style(z)
202 | return x
203 |
--------------------------------------------------------------------------------
/data/civilcomments.py:
--------------------------------------------------------------------------------
1 | """
2 | CivilComments Dataset
3 | - Reference code: https://github.com/p-lambda/wilds/blob/main/wilds/datasets/civilcomments_dataset.py
4 | - See WILDS, https://wilds.stanford.edu for more
5 | """
6 | import os
7 | import numpy as np
8 | import pandas as pd
9 |
10 | import torch
11 | from torch.utils.data import Dataset, DataLoader
12 | from transformers import BertTokenizerFast
13 |
14 | from data.grouper import CombinatorialGrouper
15 |
16 |
17 | class CivilComments(Dataset):
18 | """
19 | CivilComments dataset
20 | """
21 | def __init__(self, root_dir,
22 | target_name='toxic', confounder_names=['identities'],
23 | split='train', transform=None):
24 | self.root_dir = root_dir
25 | self.target_name = target_name
26 | self.confounder_names = confounder_names
27 | self.transform = transform
28 | self.split = split
29 |
30 | # Labels
31 | self.class_names = ['non_toxic', 'toxic']
32 |
33 | # Set up data directories
34 | self.data_dir = os.path.join(self.root_dir)
35 | if not os.path.exists(self.data_dir):
36 | raise ValueError(
37 | f'{self.data_dir} does not exist yet. Please generate the dataset first.')
38 |
39 | # Read in metadata
40 | type_of_split = self.target_name.split('_')[-1]
41 | self.metadata_df = pd.read_csv(
42 | os.path.join(self.data_dir, 'all_data_with_identities.csv'),
43 | index_col=0)
44 |
45 | # Get split
46 | self.split_array = self.metadata_df['split'].values
47 | self.metadata_df = self.metadata_df[
48 | self.metadata_df['split'] == split]
49 |
50 | # Get the y values
51 | self.y_array = torch.LongTensor(
52 | self.metadata_df['toxicity'].values >= 0.5)
53 | self.y_size = 1
54 | self.n_classes = 2
55 |
56 | # Get text
57 | self.x_array = np.array(self.metadata_df['comment_text'])
58 |
59 | # Get confounders and map to groups
60 | self._identity_vars = ['male',
61 | 'female',
62 | 'LGBTQ',
63 | 'christian',
64 | 'muslim',
65 | 'other_religions',
66 | 'black',
67 | 'white']
68 | self._auxiliary_vars = ['identity_any',
69 | 'severe_toxicity',
70 | 'obscene',
71 | 'threat',
72 | 'insult',
73 | 'identity_attack',
74 | 'sexual_explicit']
75 | self.metadata_array = torch.cat(
76 | (torch.LongTensor((self.metadata_df.loc[:, self._identity_vars] >= 0.5).values),
77 | torch.LongTensor((self.metadata_df.loc[:, self._auxiliary_vars] >= 0.5).values),
78 | self.y_array.reshape((-1, 1))), dim=1)
79 |
80 | self.metadata_fields = self._identity_vars + self._auxiliary_vars + ['y']
81 | self.confounder_array = self.metadata_array[:, np.arange(len(self._identity_vars))]
82 | self.metadata_map = None
83 |
84 | self._eval_groupers = [
85 | CombinatorialGrouper(
86 | dataset=self,
87 | groupby_fields=[identity_var, 'y'])
88 | for identity_var in self._identity_vars]
89 |
90 | # Get sub_targets / group_idx
91 | groupby_fields = self._identity_vars + ['y']
92 | self.eval_grouper = CombinatorialGrouper(self, groupby_fields)
93 | self.group_array = self.eval_grouper.metadata_to_group(self.metadata_array,
94 | return_counts=False)
95 | self.n_groups = len(np.unique(self.group_array))
96 |
97 | # Get spurious labels
98 | self.spurious_grouper = CombinatorialGrouper(self,
99 | self._identity_vars)
100 | self.spurious_array = self.spurious_grouper.metadata_to_group(
101 | self.metadata_array, return_counts=False).numpy()
102 |
103 | # Get consistent label attributes
104 | self.targets = self.y_array
105 |
106 | unique_group_ix = np.unique(self.spurious_array)
107 | group_ix_to_label = {}
108 | for i, gix in enumerate(unique_group_ix):
109 | group_ix_to_label[gix] = i
110 | spurious_labels = [group_ix_to_label[int(s)]
111 | for s in self.spurious_array]
112 |
113 | self.targets_all = {'target': np.array(self.y_array),
114 | 'group_idx': np.array(self.group_array),
115 | 'spurious': np.array(spurious_labels),
116 | 'sub_target': np.array(self.metadata_array[:, self.eval_grouper.groupby_field_indices]),
117 | 'metadata': np.array(self.metadata_array)}
118 | self.group_labels = [self.group_str(i) for i in range(self.n_groups)]
119 |
120 | def __len__(self):
121 | return len(self.y_array)
122 |
123 | def __getitem__(self, idx):
124 | x = self.x_array[idx]
125 | y = self.y_array[idx]
126 | p = self.targets_all['spurious'][idx]
127 | g = self.group_array[idx]
128 | if self.transform is not None:
129 | x = self.transform(x)
130 |
131 | attr = torch.LongTensor(
132 | [y, p, g])
133 |
134 |
135 | if self.split != 'train':
136 | return x, attr, 0, -1
137 | else:
138 | return x, attr, 0, int((p == 3) and (y == 1))
139 |
140 |
141 |
142 |
143 | # return (x, y, idx) # g
144 |
145 | def group_str(self, group_idx):
146 | return self.eval_grouper.group_str(group_idx)
147 |
148 | def get_text(self, idx):
149 | return self.x_array[idx]
150 |
151 |
152 |
153 | def init_bert_transform(tokenizer, model_name, max_token_length):
154 | """
155 | Inspired from the WILDS dataset:
156 | - https://github.com/p-lambda/wilds/blob/main/examples/transforms.py
157 | """
158 | def transform(text):
159 | tokens = tokenizer(text, padding='max_length',
160 | truncation=True,
161 | max_length=max_token_length,#args.max_token_length, # 300
162 | return_tensors='pt')
163 | if model_name == 'bert-base-uncased':
164 | x = torch.stack((tokens['input_ids'],
165 | tokens['attention_mask'],
166 | tokens['token_type_ids']), dim=2)
167 | # Not supported for now
168 | elif model_name == 'distilbert-base-uncased':
169 | x = torch.stack((tokens['input_ids'],
170 | tokens['attention_mask']), dim=2)
171 | x = torch.squeeze(x, dim=0) # First shape dim is always 1
172 | return x
173 | return transform
--------------------------------------------------------------------------------
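A small usage sketch for `init_bert_transform` above (illustrative, not part of the repository): the returned transform tokenizes one comment and stacks input ids, attention mask, and token type ids into a single `(max_token_length, 3)` tensor, matching `args.max_token_length = 300` set in `data/__init__.py`.

```python
# Illustrative sketch of data/civilcomments.init_bert_transform (downloads the BERT tokenizer).
from transformers import BertTokenizerFast
from data.civilcomments import init_bert_transform

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
transform = init_bert_transform(tokenizer, 'bert-base-uncased', max_token_length=300)

x = transform("An example comment.")
print(x.shape)  # torch.Size([300, 3]): input_ids, attention_mask, token_type_ids stacked
```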
/module/resnet.py:
--------------------------------------------------------------------------------
1 | ''' From https://github.com/alinlab/LfF/blob/master/module/resnet.py '''
2 |
3 | """
4 | Properly implemented ResNet-s for CIFAR10 as described in paper [1].
5 | The implementation and structure of this file is hugely influenced by [2]
6 | which is implemented for ImageNet and doesn't have option A for identity.
7 | Moreover, most of the implementations on the web are copy-pasted from
8 | torchvision's resnet and have the wrong number of params.
9 | Proper ResNet-s for CIFAR10 (for fair comparison, etc.) have the following
10 | number of layers and parameters:
11 | name | layers | params
12 | ResNet20 | 20 | 0.27M
13 | ResNet32 | 32 | 0.46M
14 | ResNet44 | 44 | 0.66M
15 | ResNet56 | 56 | 0.85M
16 | ResNet110 | 110 | 1.7M
17 | ResNet1202| 1202 | 19.4m
18 | which this implementation indeed has.
19 | Reference:
20 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
21 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
22 | [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
23 | If you use this implementation in your work, please don't forget to mention the
24 | author, Yerlan Idelbayev.
25 | """
26 | import torch
27 | import torch.nn as nn
28 | import torch.nn.functional as F
29 | import torch.nn.init as init
30 |
31 | from torch.autograd import Variable
32 |
33 | __all__ = [
34 | "ResNet",
35 | "resnet20",
36 | "resnet32",
37 | "resnet44",
38 | "resnet56",
39 | "resnet110",
40 | "resnet1202",
41 | ]
42 |
43 |
44 | def _weights_init(m):
45 | classname = m.__class__.__name__
46 | # print(classname)
47 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
48 | init.kaiming_normal_(m.weight)
49 |
50 |
51 | class LambdaLayer(nn.Module):
52 | def __init__(self, lambd):
53 | super(LambdaLayer, self).__init__()
54 | self.lambd = lambd
55 |
56 | def forward(self, x):
57 | return self.lambd(x)
58 |
59 |
60 | class BasicBlock(nn.Module):
61 | expansion = 1
62 |
63 | def __init__(self, in_planes, planes, stride=1, option="A"):
64 | super(BasicBlock, self).__init__()
65 | self.conv1 = nn.Conv2d(
66 | in_planes,
67 | planes,
68 | kernel_size=3,
69 | stride=stride,
70 | padding=1,
71 | bias=False,
72 | )
73 | self.bn1 = nn.BatchNorm2d(planes)
74 | self.conv2 = nn.Conv2d(
75 | planes, planes, kernel_size=3, stride=1, padding=1, bias=False
76 | )
77 | self.bn2 = nn.BatchNorm2d(planes)
78 |
79 | self.shortcut = nn.Sequential()
80 | if stride != 1 or in_planes != planes:
81 | if option == "A":
82 | """
83 | For CIFAR10 ResNet paper uses option A.
84 | """
85 | self.shortcut = LambdaLayer(
86 | lambda x: F.pad(
87 | x[:, :, ::2, ::2],
88 | (0, 0, 0, 0, planes // 4, planes // 4),
89 | "constant",
90 | 0,
91 | )
92 | )
93 | elif option == "B":
94 | self.shortcut = nn.Sequential(
95 | nn.Conv2d(
96 | in_planes,
97 | self.expansion * planes,
98 | kernel_size=1,
99 | stride=stride,
100 | bias=False,
101 | ),
102 | nn.BatchNorm2d(self.expansion * planes),
103 | )
104 |
105 | def forward(self, x):
106 | out = F.relu(self.bn1(self.conv1(x)))
107 | out = self.bn2(self.conv2(out))
108 | out += self.shortcut(x)
109 | out = F.relu(out)
110 | return out
111 |
112 |
113 | class ResNet(nn.Module):
114 | def __init__(self, block, num_blocks, num_classes=10):
115 | super(ResNet, self).__init__()
116 | self.in_planes = 16
117 | # print('!!!1'*100)
118 |
119 | self.conv1 = nn.Conv2d(
120 | 3, 16, kernel_size=3, stride=1, padding=1, bias=False
121 | )
122 | self.bn1 = nn.BatchNorm2d(16)
123 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
124 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
125 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
126 | self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
127 | self.fc = nn.Linear(64, num_classes)
128 | self.fc2 = nn.Linear(64, num_classes)
129 |
130 | self.projection_head = nn.Sequential(
131 | # nn.ReLU(),
132 | nn.Linear(64,16),
133 | # nn.ReLU(),
134 | nn.Linear(16,64))
135 |
136 | self.apply(_weights_init)
137 |
138 | def _make_layer(self, block, planes, num_blocks, stride):
139 | strides = [stride] + [1] * (num_blocks - 1)
140 | layers = []
141 | for stride in strides:
142 | layers.append(block(self.in_planes, planes, stride))
143 | self.in_planes = planes * block.expansion
144 |
145 | return nn.Sequential(*layers)
146 |
147 | def extract(self, x):
148 | out = F.relu(self.bn1(self.conv1(x)))
149 | out = self.layer1(out)
150 | out = self.layer2(out)
151 | out = self.layer3(out)
152 | out = F.avg_pool2d(out, out.size()[3])
153 | feat = out.view(out.size(0), -1)
154 |
155 | return feat
156 |
157 | def predict(self, x):
158 | prediction = self.fc(x)
159 | return prediction
160 |
161 | def forward(self, x, mode=None):
162 | out = F.relu(self.bn1(self.conv1(x)))
163 | out = self.layer1(out)
164 | out = self.layer2(out)
165 | out = self.layer3(out)
166 | # out = F.avg_pool2d(out, out.size()[3])
167 | # out = out.view(out.size(0), -1)
168 | out = self.avgpool(out)
169 | out = out.view(out.size(0), -1)
170 | final_out = self.fc(out)
171 | if mode == 'tsne' or mode == 'mixup':
172 | return out, final_out
173 | else:
174 | return final_out
175 |
176 |
177 | def resnet20(num_classes=10):
178 | print('!!!!!'*100)
179 | return ResNet(BasicBlock, [3, 3, 3], num_classes)
180 |
181 |
182 | def resnet32():
183 | return ResNet(BasicBlock, [5, 5, 5])
184 |
185 |
186 | def resnet44():
187 | return ResNet(BasicBlock, [7, 7, 7])
188 |
189 |
190 | def resnet56():
191 | return ResNet(BasicBlock, [9, 9, 9])
192 |
193 |
194 | def resnet110():
195 | return ResNet(BasicBlock, [18, 18, 18])
196 |
197 |
198 | def resnet1202():
199 | return ResNet(BasicBlock, [200, 200, 200])
200 |
201 |
202 | def test(net):
203 | import numpy as np
204 |
205 | total_params = 0
206 |
207 | for x in filter(lambda p: p.requires_grad, net.parameters()):
208 | total_params += np.prod(x.data.numpy().shape)
209 | print("Total number of params", total_params)
210 | print(
211 | "Total layers",
212 | len(
213 | list(
214 | filter(
215 | lambda p: p.requires_grad and len(p.data.size()) > 1,
216 | net.parameters(),
217 | )
218 | )
219 | ),
220 | )
221 |
222 |
223 |
224 | if __name__ == "__main__":
225 | for net_name in __all__:
226 | if net_name.startswith("resnet"):
227 | print(net_name)
228 | test(globals()[net_name]())
229 | print()
230 |
--------------------------------------------------------------------------------
/data/grouper.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataset grouper for subgroup and group_ix information
3 | - Used by CivilComments
4 |
5 | From WILDS: https://github.com/p-lambda/wilds/blob/main/wilds/common/grouper.py
6 | """
7 |
8 | import numpy as np
9 | import torch
10 | # from wilds.common.utils import get_counts
11 | # from wilds.datasets.wilds_dataset import WILDSSubset
12 | import warnings
13 |
14 |
15 | def get_counts(g, n_groups):
16 | """
17 | This differs from split_into_groups in how it handles missing groups.
18 | get_counts always returns a count Tensor of length n_groups,
19 | whereas split_into_groups returns a unique_counts Tensor
20 | whose length is the number of unique groups present in g.
21 | Args:
22 | - g (Tensor): Vector of groups
23 | Returns:
24 | - counts (Tensor): A list of length n_groups, denoting the count of each group.
25 | """
26 | unique_groups, unique_counts = torch.unique(g, sorted=False, return_counts=True)
27 | counts = torch.zeros(n_groups, device=g.device)
28 | counts[unique_groups] = unique_counts.float()
29 | return counts
30 |
31 |
32 | class Grouper:
33 | """
34 | Groupers group data points together based on their metadata.
35 | They are used for training and evaluation,
36 | e.g., to measure the accuracies of different groups of data.
37 | """
38 | def __init__(self):
39 | raise NotImplementedError
40 |
41 | @property
42 | def n_groups(self):
43 | """
44 | The number of groups defined by this Grouper.
45 | """
46 | return self._n_groups
47 |
48 | def metadata_to_group(self, metadata, return_counts=False):
49 | """
50 | Args:
51 | - metadata (Tensor): An n x d matrix containing d metadata fields
52 | for n different points.
53 | - return_counts (bool): If True, return group counts as well.
54 | Output:
55 | - group (Tensor): An n-length vector of groups.
56 | - group_counts (Tensor): Optional, depending on return_counts.
57 | An n_group-length vector of integers containing the
58 | numbers of data points in each group in the metadata.
59 | """
60 | raise NotImplementedError
61 |
62 | def group_str(self, group):
63 | """
64 | Args:
65 | - group (int): A single integer representing a group.
66 | Output:
67 | - group_str (str): A string containing the pretty name of that group.
68 | """
69 | raise NotImplementedError
70 |
71 | def group_field_str(self, group):
72 | """
73 | Args:
74 | - group (int): A single integer representing a group.
75 | Output:
76 | - group_str (str): A string containing the name of that group.
77 | """
78 | raise NotImplementedError
79 |
80 | class CombinatorialGrouper(Grouper):
81 | def __init__(self, dataset, groupby_fields):
82 | """
83 | CombinatorialGroupers form groups by taking all possible combinations of the metadata
84 | fields specified in groupby_fields, in lexicographical order.
85 | For example, if:
86 | dataset.metadata_fields = ['country', 'time', 'y']
87 | groupby_fields = ['country', 'time']
88 | and if in dataset.metadata, country is in {0, 1} and time is in {0, 1, 2},
89 | then the grouper will assign groups in the following way:
90 | country = 0, time = 0 -> group 0
91 | country = 1, time = 0 -> group 1
92 | country = 0, time = 1 -> group 2
93 | country = 1, time = 1 -> group 3
94 | country = 0, time = 2 -> group 4
95 | country = 1, time = 2 -> group 5
96 | If groupby_fields is None, then all data points are assigned to group 0.
97 | Args:
98 | - dataset (WILDSDataset)
99 | - groupby_fields (list of str)
100 | """
101 | # if isinstance(dataset, WILDSSubset):
102 | # raise ValueError("Grouper should be defined for the full dataset, not a subset")
103 | self.groupby_fields = groupby_fields
104 |
105 | if groupby_fields is None:
106 | self._n_groups = 1
107 | else:
108 | # We assume that the metadata fields are integers,
109 | # so we can measure the cardinality of each field by taking its max + 1.
110 | # Note that this might result in some empty groups.
111 | self.groupby_field_indices = [i for (i, field) in enumerate(dataset.metadata_fields) if field in groupby_fields]
112 | if len(self.groupby_field_indices) != len(self.groupby_fields):
113 | raise ValueError('At least one group field not found in dataset.metadata_fields')
114 | grouped_metadata = dataset.metadata_array[:, self.groupby_field_indices]
115 | if not isinstance(grouped_metadata, torch.LongTensor):
116 | grouped_metadata_long = grouped_metadata.long()
117 | if not torch.all(grouped_metadata == grouped_metadata_long):
118 | warnings.warn(f'CombinatorialGrouper: converting metadata with fields [{", ".join(groupby_fields)}] into long')
119 | grouped_metadata = grouped_metadata_long
120 | for idx, field in enumerate(self.groupby_fields):
121 | min_value = grouped_metadata[:,idx].min()
122 | if min_value < 0:
123 | raise ValueError(f"Metadata for CombinatorialGrouper cannot have values less than 0: {field}, {min_value}")
124 | if min_value > 0:
125 | warnings.warn(f"Minimum metadata value for CombinatorialGrouper is not 0 ({field}, {min_value}). This will result in empty groups")
126 | self.cardinality = 1 + torch.max(
127 | grouped_metadata, dim=0)[0]
128 | cumprod = torch.cumprod(self.cardinality, dim=0)
129 | self._n_groups = cumprod[-1].item()
130 | self.factors_np = np.concatenate(([1], cumprod[:-1]))
131 | self.factors = torch.from_numpy(self.factors_np)
132 | self.metadata_map = dataset.metadata_map
133 |
134 | def metadata_to_group(self, metadata, return_counts=False):
135 | if self.groupby_fields is None:
136 | groups = torch.zeros(metadata.shape[0], dtype=torch.long)
137 | else:
138 | groups = metadata[:, self.groupby_field_indices].long() @ self.factors
139 |
140 | if return_counts:
141 | group_counts = get_counts(groups, self._n_groups)
142 | return groups, group_counts
143 | else:
144 | return groups
145 |
146 | def group_str(self, group):
147 | if self.groupby_fields is None:
148 | return 'all'
149 |
150 | # group is just an integer, not a Tensor
151 | n = len(self.factors_np)
152 | metadata = np.zeros(n)
153 | for i in range(n-1):
154 | metadata[i] = (group % self.factors_np[i+1]) // self.factors_np[i]
155 | metadata[n-1] = group // self.factors_np[n-1]
156 | group_name = ''
157 | for i in reversed(range(n)):
158 | meta_val = int(metadata[i])
159 | if self.metadata_map is not None:
160 | if self.groupby_fields[i] in self.metadata_map:
161 | meta_val = self.metadata_map[self.groupby_fields[i]][meta_val]
162 | group_name += f'{self.groupby_fields[i]} = {meta_val}, '
163 | group_name = group_name[:-2]
164 | return group_name
165 |
166 | # a_n = S / x_n
167 | # a_{n-1} = (S % x_n) / x_{n-1}
168 | # a_{n-2} = (S % x_{n-1}) / x_{n-2}
169 | # ...
170 | #
171 | # g =
172 | # a_1 * x_1 +
173 | # a_2 * x_2 + ...
174 | # a_n * x_n
175 |
176 | def group_field_str(self, group):
177 | return self.group_str(group).replace('=', ':').replace(',','_').replace(' ','')
--------------------------------------------------------------------------------
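A minimal sketch (not part of the repository) that mirrors the docstring example in `CombinatorialGrouper` above, using a dummy object exposing the three attributes the grouper reads (`metadata_fields`, `metadata_array`, `metadata_map`):

```python
# Tiny illustrative sketch of data/grouper.CombinatorialGrouper on dummy metadata.
import torch
from types import SimpleNamespace
from data.grouper import CombinatorialGrouper

dummy = SimpleNamespace(
    metadata_fields=['country', 'time', 'y'],
    metadata_array=torch.tensor([[0, 0, 1],
                                 [1, 0, 0],
                                 [0, 2, 1]]),
    metadata_map=None,
)

grouper = CombinatorialGrouper(dummy, groupby_fields=['country', 'time'])
groups = grouper.metadata_to_group(dummy.metadata_array)
print(grouper.n_groups)      # 6  (country has 2 values, time has 3)
print(groups.tolist())       # [0, 1, 4]  (group index = country + 2 * time)
print(grouper.group_str(4))  # 'time = 2, country = 0'
```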
/data/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Datasets
3 | """
4 | import copy
5 | import numpy as np
6 | import importlib
7 |
8 |
9 | def initialize_data(args):
10 | """
11 | Set dataset-specific arguments
12 |     By default, the args.root_dir paths below should work if the datasets are
13 |     installed in the locations specified in the README
14 | - Otherwise, change `args.root_dir` to the path where the data is stored.
15 | """
16 | dataset_module = importlib.import_module(f'data.{args.dataset}')
17 | load_dataloaders = getattr(dataset_module, 'load_dataloaders')
18 | visualize_dataset = getattr(dataset_module, 'visualize_dataset')
19 |
20 | if 'waterbirds' in args.dataset:
21 | args.root_dir = '../slice-and-dice-smol/datasets/data/Waterbirds/'
22 | # args.root_dir = './datasets/data/Waterbirds/'
23 | args.target_name = 'waterbird_complete95'
24 | args.confounder_names = ['forest2water2']
25 | args.image_mean = np.mean([0.485, 0.456, 0.406])
26 | args.image_std = np.mean([0.229, 0.224, 0.225])
27 | args.augment_data = False
28 | args.train_classes = ['landbirds', 'waterbirds']
29 | if args.dataset == 'waterbirds_r':
30 | args.train_classes = ['land', 'water']
31 |
32 | elif 'colored_mnist' in args.dataset:
33 | args.root_dir = './datasets/data/'
34 | args.data_path = './datasets/data/'
35 | args.target_name = 'digit'
36 | args.confounder_names = ['color']
37 | args.image_mean = 0.5
38 | args.image_std = 0.5
39 | args.augment_data = False
40 | # args.train_classes = args.train_classes
41 |
42 | elif 'celebA' in args.dataset:
43 | # args.root_dir = './datasets/data/CelebA/'
44 | args.root_dir = '/dfs/scratch0/nims/CelebA/celeba/'
45 | # IMPORTANT - dataloader assumes that we have directory structure
46 | # in ./datasets/data/CelebA/ :
47 | # |-- list_attr_celeba.csv
48 | # |-- list_eval_partition.csv
49 | # |-- img_align_celeba/
50 | # |-- image1.png
51 | # |-- ...
52 | # |-- imageN.png
53 | args.target_name = 'Blond_Hair'
54 | args.confounder_names = ['Male']
55 | args.image_mean = np.mean([0.485, 0.456, 0.406])
56 | args.image_std = np.mean([0.229, 0.224, 0.225])
57 | args.augment_data = False
58 | args.image_path = './images/celebA/'
59 | args.train_classes = ['blond', 'nonblond']
60 | args.val_split = 0.2
61 |
62 | elif 'civilcomments' in args.dataset:
63 | args.root_dir = '/gpfs/data/razavianlab/data/nlp/CivilComments/'
64 | args.target_name = 'toxic'
65 | args.confounder_names = ['identities']
66 | args.image_mean = 0
67 | args.image_std = 0
68 | args.augment_data = False
69 | args.image_path = './images/civilcomments/'
70 | args.train_classes = ['non_toxic', 'toxic']
71 | args.max_token_length = 300
72 | args.arch = 'bert-base-uncased_pt'
73 | args.bs_trn = 16
74 | args.bs_val = 128
75 |
76 | elif 'cxr' in args.dataset:
77 | args.root_dir = '/dfs/scratch1/ksaab/data/4tb_hdd/CXR'
78 | args.target_name = 'pmx'
79 | args.confounder_names = ['chest_tube']
80 | args.image_mean = 0.48865
81 | args.image_std = 0.24621
82 | args.augment_data = False
83 | args.image_path = './images/cxr/'
84 | args.train_classes = ['no_pmx', 'pmx']
85 |
86 | args.task = args.dataset # e.g. 'civilcomments', for BERT
87 | args.num_classes = len(args.train_classes)
88 | return load_dataloaders, visualize_dataset
89 |
90 |
91 | def train_val_split(dataset, val_split, seed):
92 | """
93 | Compute indices for train and val splits
94 |
95 | Args:
96 | - dataset (torch.utils.data.Dataset): Pytorch dataset
97 | - val_split (float): Fraction of dataset allocated to validation split
98 | - seed (int): Reproducibility seed
99 | Returns:
100 | - train_indices, val_indices (np.array, np.array): Dataset indices
101 | """
102 | train_ix = int(np.round(val_split * len(dataset)))
103 | all_indices = np.arange(len(dataset))
104 | np.random.seed(seed)
105 | np.random.shuffle(all_indices)
106 | train_indices = all_indices[train_ix:]
107 | val_indices = all_indices[:train_ix]
108 | return train_indices, val_indices
109 |
110 |
111 | def get_resampled_indices(dataloader, args, sampling='subsample', seed=None):
112 | """
113 | Args:
114 | - dataloader (torch.utils.data.DataLoader):
115 | - sampling (str): 'subsample' or 'upsample'
116 | """
117 | try:
118 | indices = dataloader.sampler.indices
119 | except:
120 | indices = np.arange(len(dataloader.dataset))
121 | indices = np.arange(len(dataloader.dataset))
122 | target_vals, target_val_counts = np.unique(
123 | dataloader.dataset.targets_all['target'][indices],
124 | return_counts=True)
125 | sampled_indices = []
126 | if sampling == 'subsample':
127 | sample_size = np.min(target_val_counts)
128 | elif sampling == 'upsample':
129 | sample_size = np.max(target_val_counts)
130 | else:
131 | return indices
132 |
133 | if seed is None:
134 | seed = args.seed
135 | np.random.seed(seed)
136 | for v in target_vals:
137 | group_indices = np.where(
138 | dataloader.dataset.targets_all['target'][indices] == v)[0]
139 | if sampling == 'subsample':
140 | sampling_size = np.min([len(group_indices), sample_size])
141 | replace = False
142 | elif sampling == 'upsample':
143 | sampling_size = np.max([0, sample_size - len(group_indices)])
144 | sampled_indices.append(group_indices)
145 | replace = True
146 | sampled_indices.append(np.random.choice(
147 | group_indices, size=sampling_size, replace=replace))
148 | sampled_indices = np.concatenate(sampled_indices)
149 | np.random.seed(seed)
150 | np.random.shuffle(sampled_indices)
151 | return indices[sampled_indices]
152 |
153 |
154 | def get_resampled_set(dataset, resampled_set_indices, copy_dataset=False):
155 | """
156 | Obtain spurious dataset resampled_set
157 | Args:
158 | - dataset (torch.utils.data.Dataset): Spurious correlations dataset
159 | - resampled_set_indices (int[]): List-like of indices
160 | - deepcopy (bool): If true, copy the dataset
161 | """
162 | resampled_set = copy.deepcopy(dataset) if copy_dataset else dataset
163 | try: # Some dataset classes may not have these attributes
164 | resampled_set.y_array = resampled_set.y_array[resampled_set_indices]
165 | resampled_set.group_array = resampled_set.group_array[resampled_set_indices]
166 | resampled_set.split_array = resampled_set.split_array[resampled_set_indices]
167 | resampled_set.targets = resampled_set.y_array
168 | try: # Depending on the dataset these are responsible for the X features
169 | resampled_set.filename_array = resampled_set.filename_array[resampled_set_indices]
170 | except:
171 | resampled_set.x_array = resampled_set.x_array[resampled_set_indices]
172 | except AttributeError as e:
173 | try:
174 | resampled_set.targets = resampled_set.targets[resampled_set_indices]
175 | except:
176 | resampled_set_indices = np.concatenate(resampled_set_indices)
177 | resampled_set.targets = resampled_set.targets[resampled_set_indices]
178 | try:
179 | resampled_set.df = resampled_set.df.iloc[resampled_set_indices]
180 | except AttributeError:
181 | pass
182 |
183 | try:
184 | resampled_set.data = resampled_set.data[resampled_set_indices]
185 | except AttributeError:
186 | pass
187 |
188 | try: # Depending on the dataset these are responsible for the X features
189 | resampled_set.filename_array = resampled_set.filename_array[resampled_set_indices]
190 | except:
191 | pass
192 |
193 | for target_type, target_val in resampled_set.targets_all.items():
194 | resampled_set.targets_all[target_type] = target_val[resampled_set_indices]
195 |
196 | print('len(resampled_set.targets)', len(resampled_set.targets))
197 | return resampled_set
198 |
--------------------------------------------------------------------------------
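A small sketch (not part of the repository) of `train_val_split` above, which shuffles indices with the given seed and carves off the first `val_split` fraction for validation:

```python
# Illustrative sketch of data.train_val_split on a toy dataset.
import torch
from torch.utils.data import TensorDataset

from data import train_val_split

dataset = TensorDataset(torch.arange(100))
train_idx, val_idx = train_val_split(dataset, val_split=0.2, seed=222)  # seed 222 as in train.py
print(len(train_idx), len(val_idx))  # 80 20
```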
/util.py:
--------------------------------------------------------------------------------
1 | '''Modified from https://github.com/alinlab/LfF/blob/master/util.py'''
2 |
3 | import io
4 | import torch
5 | import numpy as np
6 | import torch.nn as nn
7 | from transformers import get_linear_schedule_with_warmup
8 |
9 | class EMA:
10 | def __init__(self, label, num_classes=None, alpha=0.9):
11 | self.label = label.cuda()
12 | self.alpha = alpha
13 | self.parameter = torch.zeros(label.size(0), num_classes)
14 | self.updated = torch.zeros(label.size(0), num_classes)
15 | self.num_classes = num_classes
16 | self.max = torch.zeros(self.num_classes).cuda()
17 |
18 | def update(self, data, index, curve=None, iter_range=None, step=None):
19 | self.parameter = self.parameter.to(data.device)
20 | self.updated = self.updated.to(data.device)
21 | index = index.to(data.device)
22 |
23 | if curve is None:
24 | self.parameter[index] = self.alpha * self.parameter[index] + (1 - self.alpha * self.updated[index]) * data
25 | else:
26 | alpha = curve ** -(step / iter_range)
27 | self.parameter[index] = alpha * self.parameter[index] + (1 - alpha * self.updated[index]) * data
28 | self.updated[index] = 1
29 |
30 | def max_loss(self, label):
31 | label_index = torch.where(self.label == label)[0]
32 | return self.parameter[label_index].max()
33 |
34 | def min_loss(self, label):
35 | label_index = torch.where(self.label == label)[0]
36 | return self.parameter[label_index].min()
37 |
38 |
39 | class EMA_squre:
40 | def __init__(self, num_classes=None, alpha=0.9, avg_type = 'mv'):
41 | self.alpha = alpha
42 | self.parameter = torch.zeros(num_classes, num_classes)
43 | self.global_count_ = torch.zeros(num_classes, num_classes)
44 | self.updated = torch.zeros(num_classes, num_classes)
45 | self.num_classes = num_classes
46 | self.max = torch.zeros(self.num_classes).cuda()
47 | self.avg_type = avg_type
48 |
49 | def update(self, data, y_list, a_list, curve=None, iter_range=None, step=None, bias=None, fix = None):
50 | self.parameter = self.parameter.to(data.device)
51 | self.updated = self.updated.to(data.device)
52 | # self.global_count_ = self.global_count_.to(data.device)
53 | y_list = y_list.to(data.device)
54 | a_list = a_list.to(data.device)
55 |
56 |
57 | count = torch.zeros(self.num_classes, self.num_classes).to(data.device)
58 | # parameter_temp = torch.zeros(self.num_classes, self.num_classes, self.num_classes).to(data.device)
59 |
60 | if self.avg_type == 'mv':
61 | if curve is None:
62 | for i, (y, a) in enumerate(zip(y_list, a_list)):
63 | # parameter_temp[y,a] += data[i]
64 | count[y,a] += 1
65 | self.global_count_[y,a] += 1
66 | self.parameter[y,a] = self.alpha * self.parameter[y,a] + (1 - self.alpha * self.updated[y,a]) * data[i,y]#parameter_temp[y,a]/count[y,a]
67 | self.updated[y,a] = 1
68 | else:
69 | alpha = curve ** -(step / iter_range)
70 | for i, (y, a) in enumerate(zip(y_list, a_list)):
71 | # parameter_temp[y,a] += data[i]
72 | count[y,a] += 1
73 | self.global_count_[y,a] += 1
74 | self.parameter[y,a] = alpha * self.parameter[y,a] + (1 - alpha * self.updated[y,a]) * data[i,y]#parameter_temp[y,a]/count[y,a]
75 | self.updated[y,a] = 1
76 | elif self.avg_type == 'mv_batch':
77 | self.parameter_temp = torch.zeros(self.num_classes, self.num_classes).to(data.device)
78 | for i, (y, a) in enumerate(zip(y_list, a_list)):
79 | count[y,a] += 1
80 | self.global_count_[y,a] += 1
81 | self.parameter_temp[y,a] += data[i,y]
82 | self.parameter = self.alpha * self.parameter + (1 - self.alpha) * self.parameter_temp / (count + 1e-4)
83 | elif self.avg_type == 'batch':
84 | self.parameter_temp = torch.zeros(self.num_classes, self.num_classes).to(data.device)
85 | for i, (y, a) in enumerate(zip(y_list, a_list)):
86 | count[y,a] += 1
87 | self.global_count_[y,a] += 1
88 | self.parameter_temp[y,a] += data[i,y]
89 | self.parameter = self.parameter_temp / (count + 1e-4)
90 | elif self.avg_type == 'epoch':
91 | for i, (y, a) in enumerate(zip(y_list, a_list)):
92 | count[y,a] += 1
93 | self.global_count_[y,a] += 1
94 | self.parameter[y,a] += data[i,y]
95 | else:
96 | raise NotImplementedError("This averaging type is not yet implemented!")
97 |
98 | if fix is not None:
99 | self.parameter = torch.ones(self.num_classes, self.num_classes) * 0.1#* 0.005/(self.num_classes-1)
100 | # for i in range(self.num_classes):
101 | # self.parameter[i,i] = 0.995
102 | self.parameter = self.parameter.to(data.device)
103 |
104 |
105 |
106 |
107 |
108 | # def max_loss(self, label):
109 | # label_index = torch.where(self.label == label)[0]
110 | # return self.parameter[label_index].max()
111 |
112 | # def min_loss(self, label):
113 | # label_index = torch.where(self.label == label)[0]
114 | # return self.parameter[label_index].min()
115 |
116 |
117 |
118 | class EMA_area:
119 | def __init__(self, label, num_classes=None, alpha=0.9):
120 | self.label = label.cuda()
121 | self.alpha = alpha
122 | self.parameter = torch.zeros(label.size(0))
123 | self.updated = torch.zeros(label.size(0))
124 | self.num_classes = num_classes
125 | # self.max = torch.zeros(self.num_classes).cuda()
126 | self.data_old = torch.zeros(label.size(0)).cuda()
127 |
128 | def update(self, data, index, curve=None, iter_range=None, step=None):
129 | self.parameter = self.parameter.to(data.device)
130 | self.updated = self.updated.to(data.device)
131 | index = index.to(data.device)
132 |
133 | self.parameter[index] += 0.5 * self.data_old[index] + (1 - 0.5 * self.updated[index]) * data
134 | self.updated[index] = 1
135 | self.data_old[index] = data
136 |
137 | def max_area(self, label, temp = 1):
138 | label_index = torch.where(self.label == label)[0]
139 | return torch.nn.functional.sigmoid(-self.parameter[label_index]/temp).max()
140 |
141 |
142 |
143 | class EMA_feature:
144 | def __init__(self, label, num_classes=None, alpha=0.9):
145 | self.label = label.cuda()
146 | self.alpha = alpha
147 | self.parameter = torch.zeros((label.size(0),num_classes))
148 | self.updated = torch.zeros((label.size(0),num_classes))
149 | self.num_classes = num_classes
150 | #self.max = torch.zeros(self.num_classes).cuda()
151 |
152 | def update(self, data, index, curve=None, iter_range=None, step=None):
153 | self.parameter = self.parameter.to(data.device)
154 | self.updated = self.updated.to(data.device)
155 | index = index.to(data.device)
156 |
157 | if curve is None:
158 | self.parameter[index] = self.alpha * self.parameter[index] + (1 - self.alpha * self.updated[index]) * data
159 | else:
160 | alpha = curve ** -(step / iter_range)
161 | self.parameter[index] = alpha * self.parameter[index] + (1 - alpha * self.updated[index]) * data
162 | self.updated[index] = 1
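# Note on the update rule above (a reading of the code, for clarity): entries start at
# zero with updated == 0, so the first call stores `data` directly (the coefficient
# (1 - alpha * updated) is 1); later calls apply the usual EMA
#   parameter <- alpha * parameter + (1 - alpha) * data.
# When `curve` is given, alpha is annealed as curve ** -(step / iter_range).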
163 |
164 | # def max_loss(self, label):
165 | # label_index = torch.where(self.label == label)[0]
166 | # return self.parameter[label_index].max()
167 |
168 |
169 | def sigmoid_rampup(current, rampup_length):
170 | """Exponential rampup from 2"""
171 | if rampup_length == 0:
172 | return 1.0
173 | else:
174 | current = np.clip(current, 0.0, rampup_length)
175 | phase = 1.0 - current / rampup_length
176 | return float(np.exp(-5.0 * phase * phase))
177 |
178 |
179 | def linear_rampup(current, rampup_length):
180 | """Linear rampup"""
181 | assert current >= 0 and rampup_length >= 0
182 | if current >= rampup_length:
183 | return 1.0
184 | else:
185 | return current / rampup_length
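# Sanity-check values for the ramp-up schedules (illustrative):
#   sigmoid_rampup(0, 80)  ~= exp(-5.0)  ~= 0.0067
#   sigmoid_rampup(40, 80) ~= exp(-1.25) ~= 0.2865
#   sigmoid_rampup(80, 80) == 1.0
#   linear_rampup(40, 80)  == 0.5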
186 | EPS = 1e-9
187 |
188 | def grad_norm(module):
189 | parameters = module.parameters()
190 | if isinstance(parameters, torch.Tensor):
191 | parameters = [parameters]
192 |
193 | parameters = list(filter(lambda p: p.grad is not None, parameters))
194 |
195 | total_norm = 0
196 | for p in parameters:
197 | param_norm = p.grad.data.norm(2)
198 | total_norm += param_norm.item() ** 2  # accumulate squared norms across all parameters
199 | total_norm = total_norm ** (1. / 2)
200 | return total_norm
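# Assumed usage: called after loss.backward(), so that p.grad is populated, e.g.
#   g = grad_norm(model)   # global L2 norm over all parameter gradients of `model`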
201 |
202 |
203 | def adaptive_gradient_clipping_(generator_module: nn.Module, mi_module: nn.Module):
204 | """
205 | Clips the gradient according to the min norm of the generator and mi estimator
206 | Arguments:
207 | generator_module -- nn.Module
208 | mi_module -- nn.Module
209 | """
210 | norm_generator = grad_norm(generator_module)
211 | norm_estimator = grad_norm(mi_module)
212 |
213 | min_norm = norm_generator  # np.minimum(norm_generator, norm_estimator)
214 |
215 | parameters = list(
216 | filter(lambda p: p.grad is not None, mi_module.parameters()))
217 | if isinstance(parameters, torch.Tensor):
218 | parameters = [parameters]
219 |
220 | for p in parameters:
221 | p.grad.data.mul_(min_norm/(norm_estimator + EPS))
222 |
223 |
224 |
225 | def get_bert_scheduler(optimizer, n_epochs, warmup_steps, dataloader, last_epoch=-1):
226 | """
227 | Learning rate scheduler for BERT model training
228 | """
229 | num_training_steps = int(np.round(len(dataloader) * n_epochs))
230 | print(f'\nt_total is {num_training_steps}\n')
231 | scheduler = get_linear_schedule_with_warmup(optimizer,
232 | warmup_steps,
233 | num_training_steps,
234 | last_epoch)
235 | return scheduler
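# Intended call pattern (sketch; `optimizer` and `train_loader` are placeholders):
#   scheduler = get_bert_scheduler(optimizer, n_epochs, warmup_steps, train_loader)
#   for batch in train_loader:
#       ...
#       optimizer.step()
#       scheduler.step()   # linear warmup to the base lr, then linear decay towards 0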
236 |
237 | # From pytorch-transformers:
238 | # def get_linear_schedule_with_warmup(optimizer, num_warmup_steps,
239 | # num_training_steps, last_epoch=-1):
240 | # """
241 | # Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
242 | # a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
243 | # Args:
244 | # optimizer (:class:`~torch.optim.Optimizer`):
245 | # The optimizer for which to schedule the learning rate.
246 | # num_warmup_steps (:obj:`int`):
247 | # The number of steps for the warmup phase.
248 | # num_training_steps (:obj:`int`):
249 | # The total number of training steps.
250 | # last_epoch (:obj:`int`, `optional`, defaults to -1):
251 | # The index of the last epoch when resuming training.
252 | # Return:
253 | # :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
254 | # """
255 |
256 | # def lr_lambda(current_step: int):
257 | # if current_step < num_warmup_steps:
258 | # return float(current_step) / float(max(1, num_warmup_steps))
259 | # return max(
260 | # 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
261 | # )
262 |
263 | # return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
264 |
--------------------------------------------------------------------------------
/module/activations.py:
--------------------------------------------------------------------------------
1 | """
2 | Functions to help with feature representations
3 | """
4 | import numpy as np
5 | import torch
6 | from tqdm import tqdm
7 |
8 | from utils import print_header
9 | from utils.visualize import plot_umap
10 | from network import get_output
11 |
12 | from sklearn.linear_model import LogisticRegression
13 | from sklearn.model_selection import train_test_split
14 |
15 |
16 | class SaveOutput:
17 | def __init__(self):
18 | self.outputs = []
19 |
20 | def __call__(self, module, module_in, module_out):
21 | try:
22 | module_out = module_out.detach().cpu()
23 | self.outputs.append(module_out) # .detach().cpu().numpy()
24 | except Exception as e:
25 | print(e)
26 | self.outputs.append(module_out)
27 |
28 | def clear(self):
29 | self.outputs = []
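# Minimal sketch of how SaveOutput is attached via a forward hook (save_activations
# below does this for the configured activation layer):
#   saver = SaveOutput()
#   handle = some_layer.register_forward_hook(saver)   # saver(module, input, output) runs on every forward
#   _ = model(x)
#   features = saver.outputs[0]
#   handle.remove(); saver.clear()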
30 |
31 |
32 | def save_activations(model, dataloader, args):
33 | """
34 | total_embeddings = save_activations(net, train_loader, args)
35 | """
36 | save_output = SaveOutput()
37 | hook_handles = []
38 |
39 | # if 'resnet' in args.arch:
40 | # for name, layer in model.named_modules():
41 | # if name == model.activation_layer or \
42 | # (isinstance(model, torch.nn.DataParallel) and \
43 | # name.replace('module.', '') == model.activation_layer):
44 | # handle = layer.register_forward_hook(save_output)
45 | # hook_handles.append(handle)
46 | # elif 'densenet' in args.arch:
47 | # for name, layer in model.named_modules():
48 | # if name == model.activation_layer or \
49 | # (isinstance(model, torch.nn.DataParallel) and \
50 | # name.replace('module.', '') == model.activation_layer):
51 | # handle = layer.register_forward_hook(save_output)
52 | # hook_handles.append(handle)
53 | # elif 'bert' in args.arch:
54 | for name, layer in model.named_modules():
55 | if name == model.activation_layer or \
56 | (isinstance(model, torch.nn.DataParallel) and \
57 | name.replace('module.', '') == model.activation_layer):
58 | handle = layer.register_forward_hook(save_output)
59 | hook_handles.append(handle)
60 | print(f'Activation layer: {name}')
61 | # else:
62 | # # Only get last activation layer that fits the criteria?
63 | # activation_layers = []
64 | # for layer in model.modules():
65 | # # for name, layer in model.named_modules()
66 | # try:
67 | # if isinstance(layer, torch.nn.ReLU) or isinstance(layer, torch.nn.Identity):
68 | # activation_layers.append(layer)
69 | # # handle = layer.register_forward_hook(save_output)
70 | # # hook_handles.append(handle)
71 | # except AttributeError:
72 | # if isinstance(layer, torch.nn.ReLU):
73 | # activation_layers.append(layer)
74 | # # handle = layer.register_forward_hook(save_output)
75 | # # hook_handles.append(handle)
76 | # # Only get last activation layer that fits the criteria
77 | # if 'cnn' in args.arch and args.no_projection_head is False:
78 | # # or args.dataset == 'colored_mnist'):
79 | # handle = activation_layers[-2].register_forward_hook(save_output)
80 | # else:
81 | # handle = activation_layers[-1].register_forward_hook(save_output)
82 | # hook_handles.append(handle)
83 | model.to(args.device)
84 | model.eval()
85 |
86 | # Forward pass on test set to save activations
87 | correct_train = 0
88 | total_train = 0
89 | total_embeddings = []
90 | total_inputs = []
91 | total_labels = []
92 |
93 | total_predictions = []
94 |
95 | print('> Saving activations')
96 |
97 | with torch.no_grad():
98 | for i, data in enumerate(tqdm(dataloader, desc='Running inference')):
99 | inputs, labels, data_ix = data
100 | inputs = inputs.to(args.device)
101 | labels = labels.to(args.device)
102 |
103 | try:
104 | if args.mode == 'contrastive_train':
105 | input_ids = inputs[:, :, 0]
106 | input_masks = inputs[:, :, 1]
107 | segment_ids = inputs[:, :, 2]
108 | outputs = model((input_ids, input_masks, segment_ids, None)) # .logits <- changed this in the contrastive network definition
109 | else:
110 | outputs = get_output(model, inputs, labels, args)
111 | except Exception:
112 | outputs = get_output(model, inputs, labels, args)
113 | # Why was I collecting these? 4/27/21
114 | # total_inputs.extend(inputs.detach().cpu().numpy())
115 | # total_labels.extend(labels.detach().cpu().numpy())
116 |
117 | _, predicted = torch.max(outputs.data, 1)
118 | total_train += labels.size(0)
119 | correct_train += (predicted == labels).sum().item()
120 |
121 | # Clear memory
122 | inputs = inputs.detach().cpu()
123 | labels = labels.detach().cpu()
124 | outputs = outputs.detach().cpu()
125 | predicted = predicted.detach().cpu()
126 | total_predictions.append(predicted)
127 | del inputs; del labels; del outputs; del predicted
128 |
129 | # print(f'Accuracy of the network on the test images: %d %%' % (
130 | # 100 * correct_train / total_train))
131 |
132 | # Testing this
133 | save_output.outputs = [so.detach() for so in save_output.outputs]
134 |
135 | total_predictions = np.concatenate(total_predictions)
136 | # Consolidate embeddings
137 | total_embeddings = [None] * len(save_output.outputs)
138 |
139 | for ix, output in enumerate(save_output.outputs):
140 | total_embeddings[ix] = output.numpy().squeeze()
141 |
142 | # print(total_embeddings)
143 |
144 | if 'resnet' in args.arch or 'densenet' in args.arch or 'bert' in args.arch or 'cnn' in args.arch or 'mlp' in args.arch:
145 | total_embeddings = np.concatenate(total_embeddings)
146 | if len(total_embeddings.shape) > 2: # Should just be (n_datapoints, embedding_dim)
147 | total_embeddings = total_embeddings.reshape(len(total_embeddings), -1)
148 | save_output.clear()
149 | del save_output; del hook_handles
150 | return total_embeddings, total_predictions
151 |
152 | total_embeddings_relu1 = np.concatenate(
153 | [total_embeddings[0::2]], axis=0).reshape(-1, total_embeddings[0].shape[-1])
154 | total_embeddings_relu2 = np.concatenate(
155 | [total_embeddings[1::2]], axis=0).reshape(-1, total_embeddings[1].shape[-1])
156 |
157 | save_output.clear()
158 | del save_output; del hook_handles
159 | return total_embeddings_relu1, total_embeddings_relu2, total_predictions
160 |
161 |
162 | def visualize_activations(net, dataloader, label_types, num_data=None,
163 | figsize=(8, 6), save=True, ftype='png',
164 | title_suffix=None, save_id_suffix=None, args=None,
165 | cmap='tab10', annotate_points=None,
166 | predictions=None, return_embeddings=False):
167 | """
168 | Visualize and save model activations
169 |
170 | Args:
171 | - net (torch.nn.Module): Pytorch neural net model
172 | - dataloader (torch.utils.data.DataLoader): Pytorch dataloader
173 | - label_types (str[]): List of label types, e.g. ['target', 'spurious', 'sub_target']
174 | - num_data (int): Number of datapoints to plot
175 | - figsize (tuple): Tuple of image dimensions, (height, width)
176 | - save (bool): Whether to save the image
177 | - ftype (str): File format for saving
178 | - args (argparse): Experiment arguments
179 | """
180 | if 'resnet' in args.arch or 'densenet' in args.arch or 'bert' in args.arch or 'cnn' in args.arch or 'mlp' in args.arch:
181 | total_embeddings, predictions = save_activations(net, dataloader, args)
182 | print(f'total_embeddings.shape: {total_embeddings.shape}')
183 | e1 = total_embeddings
184 | e2 = total_embeddings
185 | n_mult = 1
186 | else:
187 | e1, e2, predictions = save_activations(net, dataloader, args)
188 | n_mult = 2
189 | pbar = tqdm(total=n_mult * len(label_types))
190 | for label_type in label_types:
191 | # For now just save both classifier ReLU activation layers (for MLP, BaseCNN)
192 | if save_id_suffix is not None:
193 | save_id = f'{label_type[0]}{label_type[-1]}_{save_id_suffix}_e1'
194 | else:
195 | save_id = f'{label_type[0]}{label_type[-1]}_e1'
196 | # if title_suffix is not None:
197 | # save_id += f'-{title_suffix}'
198 | plot_umap(e1, dataloader.dataset, label_type, num_data, method='umap',
199 | offset=0, figsize=figsize, save_id=save_id, save=save,
200 | ftype=ftype, title_suffix=title_suffix, args=args,
201 | cmap=cmap, annotate_points=annotate_points,
202 | predictions=predictions)
203 | # Add MDS
204 | plot_umap(e1, dataloader.dataset, label_type, 1000, method='mds',
205 | offset=0, figsize=figsize, save_id=save_id, save=save,
206 | ftype=ftype, title_suffix=title_suffix, args=args,
207 | cmap=cmap, annotate_points=annotate_points,
208 | predictions=predictions)
209 | pbar.update(1)
210 | # if 'resnet' not in args.arch and 'densenet' not in args.arch and 'bert' not in args.arch:
211 | # save_id = f'{label_type}_e2'
212 | # if title_suffix is not None:
213 | # save_id += f'-{title_suffix}'
214 | # plot_umap(e2, dataloader.dataset, label_type, num_data,
215 | # offset=0, figsize=figsize, save_id=save_id, save=save,
216 | # ftype=ftype, title_suffix=title_suffix, args=args,
217 | # cmap=cmap, annotate_points=annotate_points,
218 | # predictions=predictions)
219 | # pbar.update(1)
220 | if return_embeddings:
221 | return e1, e2, predictions
222 | del predictions
223 | del e1; del e2
224 | #
225 |
226 |
227 | def estimate_y_probs(classifier, attribute, dataloader,
228 | classifier_test_size=0.5,
229 | model=None, embeddings=None,
230 | seed=42, reshape_prior=True, args=None):
231 | if embeddings is None:
232 | embeddings, _ = save_activations(model, dataloader, args)
233 |
234 | X = embeddings
235 | y = dataloader.dataset.targets_all[attribute]
236 | X_train, X_test, y_train, y_test = train_test_split(
237 | X, y, test_size=classifier_test_size, random_state=seed)
238 |
239 | # Fit linear classifier
240 | classifier.fit(X_train, y_train)
241 | score = classifier.score(X_test, y_test)
242 | print(f'Linear classifier score: {score:<.3f}')
243 |
244 | # Compute p(y)
245 | _, y_prior = np.unique(y_test, return_counts=True)
246 | y_prior = y_prior / y_prior.sum()
247 |
248 | # Compute p(y | X)
249 | y_post = classifier.predict_proba(X_test)
250 |
251 | if reshape_prior:
252 | y_prior = y_prior.reshape(1, -1).repeat(y_post.shape[0], axis=0)
253 | return y_post, y_prior
254 |
255 |
256 | def estimate_mi(classifier, attribute, dataloader,
257 | classifier_test_size=0.5,
258 | model=None, embeddings=None,
259 | seed=42, args=None):
260 | if embeddings is None:
261 | assert model is not None
262 | embeddings, _ = save_activations(model, dataloader, args)
263 | # Compute p(y | x), p(y)
264 | y_post, y_prior = estimate_y_probs(classifier, attribute,
265 | dataloader, classifier_test_size,
266 | model, embeddings, seed,
267 | reshape_prior=True, args=args)
268 | min_size = np.min((y_post.shape[-1], y_prior.shape[-1]))
269 | y_post = y_post[:,:min_size]
270 | y_prior = y_prior[:,:min_size]
271 | return np.sum(y_post * (np.log(y_post) - np.log(y_prior)), axis=1).mean()
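# The returned value is the plug-in mutual-information estimate
#   I(f(X); Y) ~= E_x[ KL( p(y | x) || p(y) ) ],
# with p(y | x) taken from the fitted linear probe and p(y) from empirical label frequencies.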
272 |
273 |
274 | def compute_activation_mi(attributes, dataloader,
275 | method='logistic_regression',
276 | classifier_test_size=0.5, max_iter=1000,
277 | model=None, embeddings=None,
278 | seed=42, args=None):
279 | if embeddings is None:
280 | assert model is not None
281 | embeddings, _ = save_activations(model, dataloader, args)
282 |
283 | if method == 'logistic_regression':
284 | clf = LogisticRegression(random_state=seed, max_iter=max_iter)
285 | else:
286 | raise NotImplementedError
287 |
288 | mi_by_attributes = []
289 | for attribute in attributes: # ['target', 'spurious']
290 | mi = estimate_mi(clf, attribute, dataloader,
291 | classifier_test_size, model, embeddings,
292 | seed, args)
293 | mi_by_attributes.append(mi)
294 | return mi_by_attributes
295 |
296 |
297 | def compute_align_loss(embeddings, dataloader, measure_by='target', norm=True):
298 | targets_all = dataloader.dataset.targets_all
299 |
300 | if measure_by == 'target':
301 | targets_t = targets_all['target']
302 | targets_s = targets_all['spurious']
303 | elif measure_by == 'spurious': # A bit hacky
304 | targets_t = targets_all['spurious']
305 | targets_s = targets_all['target']
306 |
307 | embeddings_by_class = {}
308 | for t in np.unique(targets_t):
309 | tix = np.where(targets_t == t)[0]
310 | anchor_embeddings = []
311 | positive_embeddings = []
312 | for s in np.unique(targets_s):
313 | six = np.where(targets_s[tix] == s)[0]
314 | if t == s: # For waterbirds, colored MNIST only
315 | anchor_embeddings.append(embeddings[tix][six])
316 | else:
317 | positive_embeddings.append(embeddings[tix][six])
318 |
319 | embeddings_by_class[t] = {'anchor': np.concatenate(anchor_embeddings),
320 | 'positive': np.concatenate(positive_embeddings)}
321 |
322 | all_l2 = []
323 | for c, embeddings_ in embeddings_by_class.items(): # keys
324 | embeddings_a = embeddings_['anchor']
325 | embeddings_p = embeddings_['positive']
326 | if norm:
327 | embeddings_a /= np.linalg.norm(embeddings_a)
328 | embeddings_p /= np.linalg.norm(embeddings_p)
329 |
330 | pairwise_l2 = np.linalg.norm(embeddings_a[:, None, :] - embeddings_p[None, :, :],
331 | axis=-1) ** 2
332 | all_l2.append(pairwise_l2.flatten())
333 |
334 | return np.concatenate(all_l2).mean()
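# In words: within each class, embeddings whose spurious attribute matches the class label
# ("anchor", bias-aligned) are compared against embeddings whose attribute conflicts
# ("positive", bias-conflicting), and the mean squared pairwise L2 distance is returned.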
335 |
336 |
337 | def compute_aligned_loss_from_model(model, dataloader, norm, args):
338 | embeddings, predictions = save_activations(model, dataloader, args)
339 | return compute_align_loss(embeddings, dataloader, norm=norm)
340 |
341 | """
342 | Legacy
343 | """
344 | def get_embeddings(net, dataloader, args):
345 | net.to(args.device)
346 | test_embeddings = []
347 | test_embeddings_r = []
348 |
349 | with torch.no_grad():
350 | for i, data in enumerate(dataloader):
351 | inputs, labels, data_ix = data
352 | inputs = inputs.to(args.device)
353 | labels = labels.to(args.device)
354 |
355 | embeddings = net.embed(inputs)
356 | embeddings_r = net.embed(inputs, relu=True)
357 |
358 | test_embeddings.append(embeddings.detach().cpu().numpy())
359 | test_embeddings_r.append(embeddings_r.detach().cpu().numpy())
360 |
361 | test_embeddings = np.concatenate(test_embeddings, axis=0)
362 | test_embeddings_r = np.concatenate(test_embeddings_r, axis=0)
363 | return test_embeddings, test_embeddings_r
--------------------------------------------------------------------------------
/data/util.py:
--------------------------------------------------------------------------------
1 | '''Modified from https://github.com/alinlab/LfF/blob/master/data/util.py'''
2 |
3 | import os
4 | import torch
5 | from torch.utils.data.dataset import Dataset, Subset
6 | from torchvision import transforms as T
7 | from glob import glob
8 | from PIL import Image
9 | from torch.utils.data import DataLoader
10 | import pandas as pd
11 | import numpy as np
12 | from .civilcomments import CivilComments, init_bert_transform
13 | from transformers import BertTokenizerFast
14 | class dataloader():
15 | def __init__(self, dataset, data_dir, percent, data2preprocess, use_type0, use_type1, batch_size, num_workers):
16 | self.dataset = dataset
17 | self.data_dir = data_dir
18 | self.percent = percent
19 | self.data2preprocess = data2preprocess
20 | self.use_type0 = use_type0
21 | self.use_type1 = use_type1
22 | self.num_workers = num_workers
23 | self.batch_size = batch_size
24 |
25 |
26 |
27 | def run(self,mode,preds=[],probs=[]):
28 | if mode=='train':
29 | train_dataset = get_dataset(
30 | self.dataset,
31 | data_dir=self.data_dir,
32 | dataset_split="train",
33 | transform_split="train",
34 | percent=self.percent,
35 | use_preprocess=self.data2preprocess[self.dataset],
36 | use_type0=self.use_type0,
37 | use_type1=self.use_type1
38 | )
39 |
40 | Idx_train_dataset = IdxDataset(train_dataset)
41 |
42 | trainloader = DataLoader(
43 | Idx_train_dataset,
44 | batch_size=self.batch_size,
45 | shuffle=True,
46 | num_workers=self.num_workers,
47 | pin_memory=True,
48 | drop_last=True
49 | )
50 | return trainloader, train_dataset
51 |
52 |
53 | elif mode=='valid':
54 | valid_dataset= get_dataset(
55 | self.dataset,
56 | data_dir=self.data_dir,
57 | dataset_split="valid",
58 | transform_split="valid",
59 | percent=self.percent,
60 | use_preprocess=self.data2preprocess[self.dataset],
61 | use_type0=self.use_type0,
62 | use_type1=self.use_type1
63 | )
64 |
65 | valid_dataset = IdxDataset(valid_dataset)
66 |
67 | valid_trainloader = DataLoader(
68 | valid_dataset,
69 | batch_size=self.batch_size,
70 | shuffle=False,
71 | num_workers=self.num_workers,
72 | pin_memory=True,
73 | )
74 |
75 |
76 | return valid_trainloader
77 |
78 | elif mode=='test':
79 | test_dataset= get_dataset(
80 | self.dataset,
81 | data_dir=self.data_dir,
82 | dataset_split="test",
83 | transform_split="valid",
84 | percent=self.percent,
85 | use_preprocess=self.data2preprocess[self.dataset],
86 | use_type0=self.use_type0,
87 | use_type1=self.use_type1
88 | )
89 |
90 | test_dataset = IdxDataset(test_dataset)
91 |
92 | test_trainloader = DataLoader(
93 | test_dataset,
94 | batch_size=self.batch_size,
95 | shuffle=False,
96 | num_workers=self.num_workers,
97 | pin_memory=True,
98 | )
99 | return test_trainloader
100 |
101 | elif mode=='eval_train':
102 | eval_train_dataset = get_dataset(
103 | self.dataset,
104 | data_dir=self.data_dir,
105 | dataset_split="train",
106 | transform_split="train",
107 | percent=self.percent,
108 | use_preprocess=self.data2preprocess[self.dataset],
109 | use_type0=self.use_type0,
110 | use_type1=self.use_type1
111 | )
112 |
113 | eval_train_dataset = IdxDataset(eval_train_dataset)
114 |
115 | eval_train_trainloader = DataLoader(
116 | eval_train_dataset,
117 | batch_size=self.batch_size,
118 | shuffle=False,
119 | num_workers=self.num_workers,
120 | pin_memory=True,
121 | drop_last=False
122 | )
123 | return eval_train_trainloader
124 |
125 |
126 |
127 | class IdxDataset(Dataset):
128 | def __init__(self, dataset):
129 | self.dataset = dataset
130 |
131 | def __len__(self):
132 | return len(self.dataset)
133 |
134 | def __getitem__(self, idx):
135 | return (idx, *self.dataset[idx])
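# IdxDataset only prepends the sample index: a wrapped dataset returning
# (image, attr, path, indicator) yields (idx, image, attr, path, indicator),
# which is why the training/evaluation loops unpack five values per batch.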
136 |
137 |
138 | class ZippedDataset(Dataset):
139 | def __init__(self, datasets):
140 | super(ZippedDataset, self).__init__()
141 | self.dataset_sizes = [len(d) for d in datasets]
142 | self.datasets = datasets
143 |
144 | def __len__(self):
145 | return max(self.dataset_sizes)
146 |
147 | def __getitem__(self, idx):
148 | items = []
149 | for dataset_idx, dataset_size in enumerate(self.dataset_sizes):
150 | items.append(self.datasets[dataset_idx][idx % dataset_size])
151 |
152 | item = [torch.stack(tensors, dim=0) for tensors in zip(*items)]
153 |
154 | return item
155 |
156 | class CMNISTDataset(Dataset):
157 | def __init__(self,root,split,transform=None, image_path_list=None, preds = None, bias = True):
158 | super(CMNISTDataset, self).__init__()
159 | self.transform = transform
160 | self.root = root
161 | self.image2pseudo = {}
162 | self.image_path_list = image_path_list
163 |
164 | if split=='train':
165 | self.align = glob(os.path.join(root, 'align',"*","*"))
166 | self.conflict = glob(os.path.join(root, 'conflict',"*","*"))
167 | data = self.align + self.conflict
168 | indicator = [0] * len(self.align) + [1] * len(self.conflict)
169 |
170 | # print(len(self.data),'***************')
171 |
172 |
173 |
174 | if (preds is not None):
175 | pred_idx = (preds).numpy().nonzero()[0]
176 | if bias:
177 | print("Discovered biased example id", pred_idx)
178 | else:
179 | print("Discovered unbiased example id", pred_idx)
180 |
181 | self.data = [data[i] for i in pred_idx] * int(len(data)/len(pred_idx))
182 | self.indicator = [indicator[i] for i in pred_idx] * int(len(data)/len(pred_idx))
183 | else:
184 | self.data = data
185 | self.indicator = indicator
186 |
187 | elif split=='valid':
188 | self.data = glob(os.path.join(root,split,"*"))
189 | elif split=='test':
190 | self.data = glob(os.path.join(root, '../test',"*","*"))
191 | self.split = split
192 |
193 |
194 | def __len__(self):
195 | return len(self.data)
196 |
197 | def __getitem__(self, index):
198 | attr = torch.LongTensor([int(self.data[index].split('_')[-2]),int(self.data[index].split('_')[-1].split('.')[0])])
199 | image = Image.open(self.data[index]).convert('RGB')
200 |
201 | if self.transform is not None:
202 | image = self.transform(image)
203 |
204 | if self.split != 'train':
205 | return image, attr, self.data[index], -1
206 | else:
207 | return image, attr, self.data[index], self.indicator[index]
208 |
209 |
210 | class CIFAR10Dataset(Dataset):
211 | def __init__(self, root, split, transform=None, image_path_list=None, use_type0=None, use_type1=None, preds = None, bias = True):
212 | super(CIFAR10Dataset, self).__init__()
213 | self.transform = transform
214 | self.root = root
215 | self.image2pseudo = {}
216 | self.image_path_list = image_path_list
217 |
218 | if split=='train':
219 | self.align = glob(os.path.join(root, 'align',"*","*"))
220 | self.conflict = glob(os.path.join(root, 'conflict',"*","*"))
221 | data = self.align + self.conflict
222 | indicator = [0] * len(self.align) + [1] * len(self.conflict)
223 |
224 | if (preds is not None):
225 | pred_idx = (preds).numpy().nonzero()[0]
226 | if bias:
227 | print("Discovered biased example id", pred_idx)
228 | else:
229 | print("Discovered unbiased example id", pred_idx)
230 |
231 | self.data = [data[i] for i in pred_idx] * int(len(data)/len(pred_idx))
232 | self.indicator = [indicator[i] for i in pred_idx] * int(len(data)/len(pred_idx))
233 | else:
234 | self.data = data
235 | self.indicator = indicator
236 |
237 | elif split=='valid':
238 | self.data = glob(os.path.join(root,split,"*", "*"))
239 |
240 | elif split=='test':
241 | self.data = glob(os.path.join(root, '../test',"*","*"))
242 |
243 |
244 | self.split = split
245 |
246 | def __len__(self):
247 | return len(self.data)
248 |
249 | def __getitem__(self, index):
250 | attr = torch.LongTensor(
251 | [int(self.data[index].split('_')[-2]), int(self.data[index].split('_')[-1].split('.')[0])])
252 | image = Image.open(self.data[index]).convert('RGB')
253 |
254 |
255 | if self.transform is not None:
256 | image = self.transform(image)
257 |
258 | if self.split != 'train':
259 | return image, attr, self.data[index], -1
260 | else:
261 | return image, attr, self.data[index], self.indicator[index]
262 |
263 |
264 | class bFFHQDataset(Dataset):
265 | def __init__(self, root, split, transform=None, image_path_list=None, preds = None, bias = True):
266 | super(bFFHQDataset, self).__init__()
267 | self.transform = transform
268 | self.root = root
269 |
270 | self.image2pseudo = {}
271 | self.image_path_list = image_path_list
272 |
273 |
274 |
275 |
276 | if split=='train':
277 | self.align = glob(os.path.join(root, 'align',"*","*"))
278 | self.conflict = glob(os.path.join(root, 'conflict',"*","*"))
279 | data = self.align + self.conflict
280 | indicator = [0] * len(self.align) + [1] * len(self.conflict)
281 |
282 | if (preds is not None):
283 | pred_idx = (preds).numpy().nonzero()[0]
284 | if bias:
285 | print("Discovered biased example id", pred_idx)
286 | else:
287 | print("Discovered unbiased example id", pred_idx)
288 |
289 | self.data = [data[i] for i in pred_idx] * int(len(data)/len(pred_idx))
290 | self.indicator = [indicator[i] for i in pred_idx] * int(len(data)/len(pred_idx))
291 | else:
292 | self.data = data
293 | self.indicator = indicator
294 |
295 |
296 | elif split=='valid':
297 | self.data = glob(os.path.join(os.path.dirname(root), split, "*"))
298 |
299 | elif split=='test':
300 | self.data = glob(os.path.join(os.path.dirname(root), split, "*"))
301 | data_conflict = []
302 | for path in self.data:
303 | target_label = path.split('/')[-1].split('.')[0].split('_')[1]
304 | bias_label = path.split('/')[-1].split('.')[0].split('_')[2]
305 | if target_label != bias_label:
306 | data_conflict.append(path)
307 | self.data = data_conflict
308 | self.split = split
309 | def __len__(self):
310 | return len(self.data)
311 |
312 | def __getitem__(self, index):
313 | attr = torch.LongTensor(
314 | [int(self.data[index].split('_')[-2]), int(self.data[index].split('_')[-1].split('.')[0])])
315 | image = Image.open(self.data[index]).convert('RGB')
316 |
317 | if self.transform is not None:
318 | image = self.transform(image)
319 |
320 | if self.split != 'train':
321 | return image, attr, self.data[index], -1
322 | else:
323 | return image, attr, self.data[index], self.indicator[index]
324 |
325 |
326 | class WaterBirdsDataset(Dataset):
327 | def __init__(self, root, split="train", transform=None, image_path_list=None, preds = None, bias = True):
328 | try:
329 | split_i = ["train", "valid", "test"].index(split)
330 | except ValueError:
331 | raise(f"Unknown split {split}")
332 | self.split = split
333 | metadata_df = pd.read_csv(os.path.join(root, "metadata.csv"))
334 | self.metadata_df = metadata_df[metadata_df["split"] == split_i]
335 | self.root = root
336 | self.transform = transform
337 | self.y_array = self.metadata_df['y'].values
338 | self.p_array = self.metadata_df['place'].values
339 | self.n_classes = np.unique(self.y_array).size
340 | self.confounder_array = self.metadata_df['place'].values
341 | self.n_places = np.unique(self.confounder_array).size
342 | self.group_array = (self.y_array * self.n_places + self.confounder_array).astype('int')
343 | self.indicator = np.abs(self.y_array - self.confounder_array).astype('int')
344 | self.n_groups = self.n_classes * self.n_places
345 | self.group_counts = (
346 | torch.arange(self.n_groups).unsqueeze(1) == torch.from_numpy(self.group_array)).sum(1).float()
347 | self.y_counts = (
348 | torch.arange(self.n_classes).unsqueeze(1) == torch.from_numpy(self.y_array)).sum(1).float()
349 | self.p_counts = (
350 | torch.arange(self.n_places).unsqueeze(1) == torch.from_numpy(self.p_array)).sum(1).float()
351 | self.filename_array = self.metadata_df['img_filename'].values
352 |
353 | def __len__(self):
354 | return len(self.metadata_df)
355 |
356 | def __getitem__(self, idx):
357 | y = self.y_array[idx]
358 | g = self.group_array[idx]
359 | p = self.confounder_array[idx]
360 |
361 | attr = torch.LongTensor(
362 | [y, p, g])
363 |
364 | img_path = os.path.join(self.root, self.filename_array[idx])
365 | img = Image.open(img_path).convert('RGB')
366 | # img = read_image(img_path)
367 | # img = img.float() / 255.
368 |
369 |
370 |
371 | if self.transform:
372 | img = self.transform(img)
373 |
374 | if self.split != 'train':
375 | return img, attr, self.filename_array[idx], self.indicator[idx]
376 | else:
377 | return img, attr, self.filename_array[idx], self.indicator[idx]
378 |
379 |
380 |
381 |
382 | transforms = {
383 | "cmnist": {
384 | "train": T.Compose([T.ToTensor()]),
385 | "valid": T.Compose([T.ToTensor()]),
386 | "test": T.Compose([T.ToTensor()])
387 | },
388 | "bffhq": {
389 | "train": T.Compose([T.Resize((224,224)), T.ToTensor()]),
390 | "valid": T.Compose([T.Resize((224,224)), T.ToTensor()]),
391 | "test": T.Compose([T.Resize((224,224)), T.ToTensor()])
392 | },
393 | "cifar10c": {
394 | "train": T.Compose([T.ToTensor(),]),
395 | "valid": T.Compose([T.ToTensor(),]),
396 | "test": T.Compose([T.ToTensor(),]),
397 | },
398 |
399 | "waterbird":{
400 | "train": T.Compose([
401 | T.Resize((256, 256)),
402 | T.CenterCrop((224,224)),
403 | T.ToTensor(),
404 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
405 | "valid": T.Compose([
406 | T.Resize((256, 256)),
407 | T.CenterCrop((224,224)),
408 | T.ToTensor(),
409 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
410 | "test": T.Compose([
411 | T.Resize((256, 256)),
412 | T.CenterCrop((224,224)),
413 | T.ToTensor(),
414 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
415 | },
416 |
417 | }
418 |
419 |
420 | transforms_preprcs = {
421 | "cmnist": {
422 | "train": T.Compose([T.ToTensor()]),
423 | "valid": T.Compose([T.ToTensor()]),
424 | "test": T.Compose([T.ToTensor()])
425 | },
426 | "bffhq": {
427 | "train": T.Compose([
428 | T.Resize((224,224)),
429 | T.RandomCrop(224, padding=4),
430 | T.RandomHorizontalFlip(),
431 | T.ToTensor(),
432 | T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
433 | ]
434 | ),
435 | "valid": T.Compose([
436 | T.Resize((224,224)),
437 | T.ToTensor(),
438 | T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
439 | ]
440 | ),
441 | "test": T.Compose([
442 | T.Resize((224,224)),
443 | T.ToTensor(),
444 | T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
445 | ]
446 | )
447 | },
448 | "cifar10c": {
449 | "train": T.Compose(
450 | [
451 | T.RandomCrop(32, padding=4),
452 | # T.RandomResizedCrop(32),
453 | T.RandomHorizontalFlip(),
454 | T.ToTensor(),
455 | T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
456 | ]
457 | ),
458 | "valid": T.Compose(
459 | [
460 | T.ToTensor(),
461 | T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
462 | ]
463 | ),
464 | "test": T.Compose(
465 | [
466 | T.ToTensor(),
467 | T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
468 | ]
469 | ),
470 | },
471 | "waterbird": {
472 | "train": T.Compose(
473 | [
474 | T.RandomResizedCrop(
475 | (224,224),
476 | scale=(0.7, 1.0),
477 | ratio=(0.75, 1.3333333333333333),
478 | interpolation=2),
479 | T.RandomHorizontalFlip(),
480 | T.ToTensor(),
481 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
482 | ]
483 | ),
484 | "valid": T.Compose(
485 | [
486 | T.Resize((256, 256)),
487 | T.CenterCrop((224,224)),
488 | T.ToTensor(),
489 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
490 | ]
491 | ),
492 | "test": T.Compose(
493 | [
494 | T.Resize((256, 256)),
495 | T.CenterCrop((224,224)),
496 | T.ToTensor(),
497 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
498 | ]
499 | ),
500 | },
501 | }
502 |
503 |
504 |
505 | transforms_preprcs_ae = {
506 | "cmnist": {
507 | "train": T.Compose([T.ToTensor()]),
508 | "valid": T.Compose([T.ToTensor()]),
509 | "test": T.Compose([T.ToTensor()])
510 | },
511 | "bffhq": {
512 | "train": T.Compose([
513 | T.Resize((224,224)),
514 | T.RandomCrop(224, padding=4),
515 | T.RandomHorizontalFlip(),
516 | T.ToTensor(),
517 | T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
518 | ]
519 | ),
520 | "valid": T.Compose([
521 | T.Resize((224,224)),
522 | T.ToTensor(),
523 | T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
524 | ]
525 | ),
526 | "test": T.Compose([
527 | T.Resize((224,224)),
528 | T.ToTensor(),
529 | T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
530 | ]
531 | )
532 | },
533 | "cifar10c": {
534 | "train": T.Compose(
535 | [
536 | # T.RandomCrop(32, padding=4),
537 | T.RandomResizedCrop(32),
538 | T.RandomHorizontalFlip(),
539 | T.ToTensor(),
540 | T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
541 | ]
542 | ),
543 | "valid": T.Compose(
544 | [
545 | T.ToTensor(),
546 | T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
547 | ]
548 | ),
549 | "test": T.Compose(
550 | [
551 | T.ToTensor(),
552 | T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
553 | ]
554 | ),
555 | },
556 |
557 | "waterbird":{
558 | "train": T.Compose([
559 | T.Resize((256, 256)),
560 | T.CenterCrop((224,224)),
561 | T.ToTensor(),
562 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
563 | "valid": T.Compose([
564 | T.Resize((256, 256)),
565 | T.CenterCrop((224,224)),
566 | T.ToTensor(),
567 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
568 | "test": T.Compose([
569 | T.Resize((256, 256)),
570 | T.CenterCrop((224,224)),
571 | T.ToTensor(),
572 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
573 | },
574 | }
575 | def get_dataset(dataset, data_dir, dataset_split, transform_split, percent, use_preprocess=None, image_path_list=None, use_type0=None, use_type1=None, preds = None, bias = True):
576 | dataset_category = dataset.split("-")[0]
577 | if dataset != 'civilcomments':
578 | if use_preprocess:
579 | transform = transforms_preprcs[dataset_category][transform_split]
580 | else:
581 | transform = transforms[dataset_category][transform_split]
582 | else:
583 | arch = 'bert-base-uncased_pt'
584 | pretrained_name = arch if arch[-3:] != '_pt' else arch[:-3]
585 | tokenizer = BertTokenizerFast.from_pretrained(pretrained_name) # 'bert-base-uncased'
586 | transform = init_bert_transform(tokenizer, pretrained_name, max_token_length = 300)
587 |
588 |
589 | dataset_split = "valid" if (dataset_split == "eval") else dataset_split
590 | if dataset == 'cmnist':
591 | root = data_dir + f"/cmnist/{percent}"
592 | dataset = CMNISTDataset(root=root,split=dataset_split,transform=transform, image_path_list=image_path_list, preds = preds, bias = bias)
593 |
594 | elif 'cifar10c' in dataset:
595 | # if use_type0:
596 | # root = data_dir + f"/cifar10c_0805_type0/{percent}"
597 | # elif use_type1:
598 | # root = data_dir + f"/cifar10c_0805_type1/{percent}"
599 | # else:
600 | root = data_dir + f"/cifar10c/{percent}"
601 | dataset = CIFAR10Dataset(root=root, split=dataset_split, transform=transform, image_path_list=image_path_list, use_type0=use_type0, use_type1=use_type1)
602 |
603 | elif dataset == "bffhq":
604 | root = data_dir + f"/bffhq/{percent}"
605 | dataset = bFFHQDataset(root=root, split=dataset_split, transform=transform, image_path_list=image_path_list)
606 | elif dataset == 'waterbird':
607 | root = data_dir + f"/waterbird"
608 | dataset = WaterBirdsDataset(root=root, split=dataset_split, transform=transform, image_path_list=image_path_list)
609 | elif dataset == 'civilcomments':
610 | root = data_dir + f"/CivilComments"
611 | if dataset_split == 'train':
612 | dataset = CivilComments(root, target_name='toxic',
613 | confounder_names=['identities'],
614 | split='train', transform=transform)
615 | elif dataset_split == 'valid':
616 | dataset = CivilComments(root, target_name='toxic',
617 | confounder_names=['identities'],
618 | split='val', transform=transform)
619 | elif dataset_split == 'test':
620 | dataset = CivilComments(root, target_name='toxic',
621 | confounder_names=['identities'],
622 | split='test', transform=transform)
623 | else:
624 | print(f'Unknown dataset: {dataset} ...')
625 | import sys
626 | sys.exit(0)
627 |
628 | return dataset
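# Illustrative call (the data_dir layout and percent string are hypothetical and depend
# on how the datasets were generated):
#   train_set = get_dataset('cmnist', data_dir='./dataset', dataset_split='train',
#                           transform_split='train', percent='5pct', use_preprocess=None)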
629 |
630 |
--------------------------------------------------------------------------------
/learner.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import wandb
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 |
7 | import os
8 | import torch.optim as optim
9 |
10 | from data.util import dataloader
11 | from module.lc_loss.loss import GeneralizedCELoss, LogitCorrectionLoss
12 | from module.lc_loss.group_mixup import group_mixUp
13 | from module.util import get_model
14 | from util import EMA, EMA_squre, sigmoid_rampup, get_bert_scheduler
15 |
16 | import torch.nn.functional as F
17 | import random
18 |
19 |
20 | import matplotlib.pyplot as plt
21 | import matplotlib.cm as cmap
22 |
23 | def cycle(iterable):
24 | while True:
25 | for x in iterable:
26 | yield x
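# cycle() turns a finite DataLoader into an endless iterator, so training can be driven
# by a fixed number of steps rather than epochs, e.g. (illustrative):
#   train_iter = cycle(train_loader)
#   batch = next(train_iter)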
27 |
28 |
29 |
30 | class Learner(nn.Module):
31 | def __init__(self, args):
32 | super(Learner, self).__init__()
33 | data2model = {'cmnist': "MLP",
34 | 'cifar10c': "ResNet18",
35 | 'bffhq': "ResNet18",
36 | 'waterbird': "resnet_50",
37 | 'civilcomments': "bert-base-uncased_pt"}
38 |
39 | data2batch_size = {'cmnist': 256,
40 | 'cifar10c': 256,
41 | 'bffhq': 64,
42 | 'waterbird': 32,
43 | 'civilcomments': 16}
44 |
45 | data2preprocess = {'cmnist': None,
46 | 'cifar10c': True,
47 | 'bffhq': True,
48 | 'waterbird': True,
49 | 'civilcomments': None}
50 |
51 | self.data2preprocess = data2preprocess
52 | self.data2batch_size = data2batch_size
53 |
54 |
55 | args.exp = '{:s}_ema_{:.2f}_tau_{:.2f}_lambda_{:.2f}_avgtype_{:s}'.format(args.exp, args.ema_alpha, args.tau, args.lambda_dis_align, args.avg_type)
56 |
57 | if args.wandb:
58 | import wandb
59 | wandb.init(project='Learning-with-Logit-Correction')
60 | wandb.run.name = args.exp
61 |
62 | run_name = args.exp
63 | if args.tensorboard:
64 | from tensorboardX import SummaryWriter
65 | self.writer = SummaryWriter(f'result/summary/{run_name}')
66 |
67 | self.model = data2model[args.dataset]
68 | self.batch_size = data2batch_size[args.dataset]
69 |
70 | print(f'model: {self.model} || dataset: {args.dataset}')
71 | print(f'working with experiment: {args.exp}...')
72 | os.makedirs(os.path.join(args.log_dir, args.dataset, args.exp), exist_ok=True)
73 | self.device = torch.device(args.device)
74 | self.args = args
75 |
76 | # logging directories
77 | self.log_dir = os.path.join(args.log_dir, args.dataset, args.exp)
78 | self.summary_dir = os.path.join(args.log_dir, args.dataset, "summary", args.exp)
79 | self.summary_gradient_dir = os.path.join(self.log_dir, "gradient")
80 | self.result_dir = os.path.join(self.log_dir, "result")
81 | self.plot_dir = os.path.join(self.log_dir, "figure")
82 | os.makedirs(self.summary_dir, exist_ok=True)
83 | os.makedirs(self.result_dir, exist_ok=True)
84 | os.makedirs(self.plot_dir, exist_ok=True)
85 |
86 | self.loader = dataloader(
87 | args.dataset,
88 | args.data_dir,
89 | args.percent,
90 | data2preprocess,
91 | args.use_type0,
92 | args.use_type1,
93 | self.batch_size,
94 | args.num_workers
95 | )
96 |
97 | self.train_loader, self.train_dataset = self.loader.run('train')
98 |
99 | self.valid_loader = self.loader.run('valid')
100 | self.test_loader = self.loader.run('test')
101 |
102 |
103 | if args.dataset == 'waterbird' or args.dataset == 'civilcomments':
104 | train_target_attr = self.train_dataset.y_array
105 | train_target_attr = torch.LongTensor(train_target_attr)
106 | else:
107 | train_target_attr = []
108 | for data in self.train_dataset.data:
109 | train_target_attr.append(int(data.split('_')[-2]))
110 | train_target_attr = torch.LongTensor(train_target_attr)
111 |
112 | attr_dims = []
113 | attr_dims.append(torch.max(train_target_attr).item() + 1)
114 | self.num_classes = attr_dims[0]
115 | num_example = len(train_target_attr)
116 | print('Num example in training is {:d}, Num classes is {:d} \n'.format(num_example, self.num_classes))
117 |
118 |
119 | self.sample_margin_ema_b = EMA(torch.LongTensor(train_target_attr), num_classes=self.num_classes, alpha=0)
120 | self.confusion = EMA_squre(num_classes=self.num_classes, alpha=args.ema_alpha, avg_type = args.avg_type)
121 | print(f'alpha : {self.sample_margin_ema_b.alpha}')
122 |
123 | self.best_valid_acc_b, self.best_test_acc_b = 0., 0.
124 | self.best_valid_acc_d, self.best_test_acc_d = 0., 0.
125 |
126 | self.best_valid_acc_avg, self.best_test_acc_avg = 0., 0.
127 | self.best_valid_acc_worst, self.best_test_acc_worst = 0., 0.
128 |
129 | print('finished model initialization....')
130 |
131 |
132 | # evaluation code for vanilla
133 | def evaluate(self, model, data_loader):
134 | model.eval()
135 | total_correct, total_num = 0, 0
136 | for _, data, attr, _, _ in tqdm(data_loader, leave=False):
137 |
138 | label = attr[:, 0]
139 | if attr.shape[1] > 2:
140 | group = attr[:, 2]
141 | else:
142 | group = None
143 | # label = attr
144 | data = data.to(self.device)
145 | label = label.to(self.device)
146 |
147 |
148 |
149 | # label = attr[:, 0]
150 | # data = data.to(self.device)
151 | # label = label.to(self.device)
152 |
153 | with torch.no_grad():
154 | logit = model(data)
155 | pred = logit.data.max(1, keepdim=True)[1].squeeze(1)
156 | correct = (pred == label).long()
157 | total_correct += correct.sum()
158 | total_num += correct.shape[0]
159 |
160 | accs = total_correct/float(total_num)
161 | model.train()
162 |
163 | return accs
164 |
165 |
166 | def summarize_acc(self, correct_by_groups, total_by_groups, bias = True, split = 'Train',stdout=True):
167 | all_correct = 0
168 | all_total = 0
169 | min_acc = 101.
170 | min_correct_total = [None, None]
171 | # if stdout:
172 | # print(split + ' Accuracies by groups:')
173 | for yix, y_group in enumerate(correct_by_groups):
174 | for aix, a_group in enumerate(y_group):
175 | acc = a_group / total_by_groups[yix][aix] * 100
176 | if acc < min_acc:
177 | min_acc = acc
178 | min_correct_total[0] = a_group
179 | min_correct_total[1] = total_by_groups[yix][aix]
180 | if stdout:
181 | print(
182 | f'{yix}, {aix} acc: {int(a_group):5d} / {int(total_by_groups[yix][aix]):5d} = {a_group / total_by_groups[yix][aix] * 100:>7.3f}')
183 | all_correct += a_group
184 | all_total += total_by_groups[yix][aix]
185 | if stdout:
186 | if bias:
187 | average_str = f'Biased Average acc: {int(all_correct):5d} / {int(all_total):5d} = {100 * all_correct / all_total:>7.3f}'
188 | robust_str = f'Biased Robust acc: {int(min_correct_total[0]):5d} / {int(min_correct_total[1]):5d} = {min_acc:>7.3f}'
189 | else:
190 | average_str = f'Average acc: {int(all_correct):5d} / {int(all_total):5d} = {100 * all_correct / all_total:>7.3f}'
191 | robust_str = f'Robust acc: {int(min_correct_total[0]):5d} / {int(min_correct_total[1]):5d} = {min_acc:>7.3f}'
192 | print('-' * len(average_str))
193 | print(average_str)
194 | print(robust_str)
195 | print('-' * len(average_str))
196 | # return all_correct / all_total * 100, min_acc
197 | return min_acc
198 |
199 | # model_b, model_l, data_loader, n_group = 4, model='label', mode='dummy'
200 | def evaluate_civilcomments(self, net_b, net, dataloader, bias = False, n_group = 4, model='label', mode='dummy', split='Train', step = 0):
201 |
202 |
203 | if bias:
204 | net = net_b
205 |
206 | dataset = dataloader.dataset.dataset
207 | metadata = dataset.metadata_array
208 | correct_by_groups = np.zeros([2, len(dataset._identity_vars)])
209 | total_by_groups = np.zeros(correct_by_groups.shape)
210 |
211 | identity_to_ix = {}
212 | for idx, identity in enumerate(dataset._identity_vars):
213 | identity_to_ix[identity] = idx
214 |
215 | for identity_var, eval_grouper in zip(dataset._identity_vars,
216 | dataset._eval_groupers):
217 | group_idx = eval_grouper.metadata_to_group(metadata).numpy()
218 |
219 | g_list, g_counts = np.unique(group_idx, return_counts=True)
220 | # print(identity_var, identity_to_ix[identity_var])
221 | # print(g_counts)
222 |
223 | for g_ix, g in enumerate(g_list):
224 | g_count = g_counts[g_ix]
225 | # Only pick from positive identities
226 | # e.g. only 1 and 3 from here:
227 | # 0 y:0_male:0
228 | # 1 y:0_male:1
229 | # 2 y:1_male:0
230 | # 3 y:1_male:1
231 | n_total = g_counts[g_ix] # + g_counts[3]
232 | if g in [1, 3]:
233 | class_ix = 0 if g == 1 else 1 # 1 y:0_male:1
234 | # print(g_ix, g, n_total)
235 |
236 | # net.to(args.device)
237 | net.eval()
238 | total_correct = 0
239 | with torch.no_grad():
240 | all_predictions = []
241 | all_correct = []
242 |
243 | for data in tqdm(dataloader, leave=False):
244 | i, inputs, attr, _, _ = data
245 | # inputs, labels, data_ix = data
246 | #for i, data in enumerate(tqdm(dataloader)):
247 | labels = attr[:, 0]
248 | # label = attr
249 | inputs = inputs.to(self.device)
250 | labels = labels.to(self.device)
251 |
252 | # Add this here to generalize NLP, CV models
253 | #outputs = get_output(net, inputs, labels, args)
254 | input_ids = inputs[:, :, 0]
255 | input_masks = inputs[:, :, 1]
256 | segment_ids = inputs[:, :, 2]
257 | outputs = net(input_ids=input_ids,
258 | attention_mask=input_masks,
259 | token_type_ids=segment_ids,
260 | labels=labels)[1]
261 |
262 | _, predicted = torch.max(outputs.data, 1)
263 | correct = (predicted == labels).detach().cpu()
264 | total_correct += correct.sum().item()
265 | all_correct.append(correct)
266 | all_predictions.append(predicted.detach().cpu())
267 |
268 | inputs = inputs.to(torch.device('cpu'))
269 | labels = labels.to(torch.device('cpu'))
270 | outputs = outputs.to(torch.device('cpu'))
271 | del inputs; del labels; del outputs
272 |
273 | all_correct = torch.cat(all_correct).numpy()
274 | all_predictions = torch.cat(all_predictions)
275 |
276 | # Evaluate predictions
277 | # dataset = dataloader.dataset
278 | y_pred = all_predictions # torch.tensors
279 | y_true = dataset.y_array
280 | metadata = dataset.metadata_array
281 |
282 | correct_by_groups = np.zeros([2, len(dataset._identity_vars)])
283 | total_by_groups = np.zeros(correct_by_groups.shape)
284 |
285 | n_group = 2 * len(dataset._identity_vars)
286 | for identity_var, eval_grouper in zip(dataset._identity_vars,
287 | dataset._eval_groupers):
288 | group_idx = eval_grouper.metadata_to_group(metadata).numpy()
289 |
290 | g_list, g_counts = np.unique(group_idx, return_counts=True)
291 | # print(g_counts)
292 |
293 | idx = identity_to_ix[identity_var]
294 |
295 | for g_ix, g in enumerate(g_list):
296 | g_count = g_counts[g_ix]
297 | # Only pick from positive identities
298 | # e.g. only 1 and 3 from here:
299 | # 0 y:0_male:0
300 | # 1 y:0_male:1
301 | # 2 y:1_male:0
302 | # 3 y:1_male:1
303 | n_total = g_count # s[1] + g_counts[3]
304 | if g in [1, 3]:
305 | n_correct = all_correct[group_idx == g].sum()
306 | class_ix = 0 if g == 1 else 1 # 1 y:0_male:1
307 | correct_by_groups[class_ix][idx] += n_correct
308 | total_by_groups[class_ix][idx] += n_total
309 |
310 | group_acc = correct_by_groups / total_by_groups
311 | acc_groups = {}
312 | for group in range(n_group):
313 | acc_groups[group] = group_acc[group // len(dataset._identity_vars), group % len(dataset._identity_vars)]
314 | accs = total_correct/len(dataset)
315 |
316 | robust_acc = self.summarize_acc(correct_by_groups, total_by_groups, bias = bias, split=split, stdout=True)
317 | if not bias:
318 | if split == 'test':
319 | self.save_bert(step, robust_acc, net)
320 |
321 | net.train()
322 |
323 |
324 | return accs, acc_groups
325 | # return 0, total_correct, len(dataset), correct_by_groups, total_by_groups, None, None, None
326 |
327 | def evaluate_ours(self, model_b, model_l, data_loader, n_group = 4, model='label', mode='dummy', split='Train', step = 0):
328 | if self.args.dataset =='civilcomments':
329 | return self.evaluate_civilcomments(model_b, model_l, data_loader, bias = (model == 'bias'), n_group = 4, model='label', mode='dummy', split=split, step = step)
330 |
331 |
332 | model_b.eval()
333 | model_l.eval()
334 |
335 | total_correct, total_num = 0, 0
336 | total_correct_groups = {}
337 | total_num_groups = {}
338 | acc_groups = {}
339 | for group in range(n_group):
340 | total_correct_groups[group] = 0
341 | total_num_groups[group] = 0
342 |
343 |
344 | iter = 0
345 |
346 | # index, data, attr, image_path, indicator
347 |
348 | for _, data, attr, _, _ in tqdm(data_loader, leave=False):
349 | label = attr[:, 0]
350 | if attr.shape[1] > 2:
351 | group = attr[:, 2]
352 | else:
353 | group = None
354 | # label = attr
355 | data = data.to(self.device)
356 | label = label.to(self.device)
357 |
358 | with torch.no_grad():
359 | if self.args.dataset == 'cmnist':
360 | z_l = model_l.extract(data)
361 | z_b = model_b.extract(data)
362 | elif self.args.dataset != 'civilcomments':
363 | z_l, z_b = [], []
364 | if mode == 'dummy':
365 | hook_fn = self.model_l.avgpool.register_forward_hook(self.concat_dummy(z_l))
366 | else:
367 | hook_fn = self.model_l.avgpool.register_forward_hook(self.no_dummy(z_l))
368 | _ = self.model_l(data)
369 | hook_fn.remove()
370 | z_l = z_l[0]
371 | if mode == 'dummy':
372 | hook_fn = self.model_b.avgpool.register_forward_hook(self.concat_dummy(z_b))
373 | else:
374 | hook_fn = self.model_b.avgpool.register_forward_hook(self.no_dummy(z_b))
375 | _ = self.model_b(data)
376 | hook_fn.remove()
377 | z_b = z_b[0]
378 |
379 | if mode == 'dummy':
380 | if iter == 0:
381 | # print('Current mode using is {:s} \n'.format(mode))
382 | iter += 1
383 | z_origin = torch.cat((z_l, z_b), dim=1)
384 | if model == 'bias':
385 | pred_label = model_b.fc(z_origin)
386 | else:
387 | pred_label = model_l.fc(z_origin)
388 | else:
389 | if iter == 0:
390 | # print('Current mode using is {:s} \n'.format(mode))
391 | iter += 1
392 | z_origin = z_l
393 | if model == 'bias':
394 | pred_label = model_b.fc(z_origin)
395 | else:
396 | pred_label = model_l.fc(z_origin)
397 | else:
398 | input_ids = data[:, :, 0]
399 | input_masks = data[:, :, 1]
400 | segment_ids = data[:, :, 2]
401 | if model == 'bias':
402 | pred_label = model_b(input_ids=input_ids,
403 | attention_mask=input_masks,
404 | token_type_ids=segment_ids,
405 | labels=label)[1]
406 | else:
407 | pred_label = model_l(input_ids=input_ids,
408 | attention_mask=input_masks,
409 | token_type_ids=segment_ids,
410 | labels=label)[1]
411 |
412 |
413 | pred = pred_label.data.max(1, keepdim=True)[1].squeeze(1)
414 |
415 | correct = (pred == label).long()
416 | total_correct += correct.sum()
417 | total_num += correct.shape[0]
418 |
419 | if group is not None:
420 | for group_id in range(n_group):
421 | group_select = (group == group_id)
422 | correct_select = (pred[group_select] == label[group_select]).long()
423 | total_correct_groups[group_id] += correct_select.sum()
424 | total_num_groups[group_id] += correct_select.shape[0]
425 |
426 |
427 | accs = total_correct/float(total_num)
428 | if group is not None:
429 | for group_id in range(n_group):
430 | acc_groups[group_id] = (total_correct_groups[group_id]/float(total_num_groups[group_id])).item()
431 | else:
432 | acc_groups = None
433 | model_b.train()
434 | model_l.train()
435 |
436 | return accs, acc_groups
437 |
438 | def save_vanilla(self, step, best=None):
439 | if best:
440 | model_path = os.path.join(self.result_dir, "best_model.th")
441 | else:
442 | model_path = os.path.join(self.result_dir, "model_{}.th".format(step))
443 | state_dict = {
444 | 'steps': step,
445 | 'state_dict': self.model_b.state_dict(),
446 | 'optimizer': self.optimizer_b.state_dict(),
447 | }
448 | with open(model_path, "wb") as f:
449 | torch.save(state_dict, f)
450 | print(f'{step} model saved ...')
451 |
452 |
453 | def save_ours(self, step, best=None):
454 | if best:
455 | model_path = os.path.join(self.result_dir, "best_model_l.th")
456 | else:
457 | model_path = os.path.join(self.result_dir, "model_l_{}.th".format(step))
458 | state_dict = {
459 | 'steps': step,
460 | 'state_dict': self.model_l.state_dict(),
461 | 'optimizer': self.optimizer_l.state_dict(),
462 | }
463 | with open(model_path, "wb") as f:
464 | torch.save(state_dict, f)
465 |
466 | if best:
467 | model_path = os.path.join(self.result_dir, "best_model_b.th")
468 | else:
469 | model_path = os.path.join(self.result_dir, "model_b_{}.th".format(step))
470 | state_dict = {
471 | 'steps': step,
472 | 'state_dict': self.model_b.state_dict(),
473 | 'optimizer': self.optimizer_b.state_dict(),
474 | }
475 | with open(model_path, "wb") as f:
476 | torch.save(state_dict, f)
477 |
478 | print(f'{step} model saved ...')
479 |
480 | def save_bert(self, step, robust_acc, net):
481 | model_path = os.path.join(self.result_dir, f"model_{step}_{robust_acc}.th")
482 | state_dict = {
483 | 'state_dict': net.state_dict(),
484 | }
485 | with open(model_path, "wb") as f:
486 | torch.save(state_dict, f)
487 |
488 | print(f'model saved ...')
489 |
490 |
491 | def board_vanilla_loss(self, step, loss_b):
492 | if self.args.wandb:
493 | wandb.log({
494 | "loss_b_train": loss_b,
495 | }, step=step,)
496 |
497 | if self.args.tensorboard:
498 | self.writer.add_scalar(f"loss/loss_b_train", loss_b, step)
499 |
500 |
501 | def board_ours_loss(self, step, loss_dis_conflict, loss_dis_align, confusion, global_count):
502 |
503 | flatten_confusion = confusion.flatten()
504 | print('Correction: ', flatten_confusion)
505 | if self.args.wandb:
506 | wandb.log({
507 | "loss_dis_conflict": loss_dis_conflict,
508 | "loss_dis_align": loss_dis_align,
509 | "loss": loss_dis_conflict + loss_dis_align,
510 | }, step=step,)
511 |
512 |
513 | flatten_global_count = global_count.flatten()
514 | for i in range(len(flatten_confusion)):
515 | wandb.log({"logit_correction_"+str(i): flatten_confusion[i]}, step=step,)
516 | wandb.log({"global_count_"+str(i): flatten_global_count[i]}, step=step,)
517 |
518 | if self.args.tensorboard:
519 | self.writer.add_scalar(f"loss/loss_dis_conflict", loss_dis_conflict, step)
520 | self.writer.add_scalar(f"loss/loss_dis_align", loss_dis_align, step)
521 | self.writer.add_scalar(f"loss/loss", loss_dis_conflict + loss_dis_align)
522 |
523 | def board_vanilla_acc(self, step, epoch, inference=None):
524 | valid_accs_b = self.evaluate(self.model_b, self.valid_loader)
525 | test_accs_b = self.evaluate(self.model_b, self.test_loader)
526 |
527 | # print(f'epoch: {epoch}')
528 |
529 | if valid_accs_b >= self.best_valid_acc_b:
530 | self.best_valid_acc_b = valid_accs_b
531 | if test_accs_b >= self.best_test_acc_b:
532 | self.best_test_acc_b = test_accs_b
533 | self.save_vanilla(step, best=True)
534 |
535 | if self.args.wandb:
536 | wandb.log({
537 | "acc_b_valid": valid_accs_b,
538 | "acc_b_test": test_accs_b,
539 | },
540 | step=step,)
541 | wandb.log({
542 | "best_acc_b_valid": self.best_valid_acc_b,
543 | "best_acc_b_test": self.best_test_acc_b,
544 | },
545 | step=step, )
546 |
547 | print(f'valid_b: {valid_accs_b} || test_b: {test_accs_b}')
548 |
549 | if self.args.tensorboard:
550 | self.writer.add_scalar(f"acc/acc_b_valid", valid_accs_b, step)
551 | self.writer.add_scalar(f"acc/acc_b_test", test_accs_b, step)
552 |
553 | self.writer.add_scalar(f"acc/best_acc_b_valid", self.best_valid_acc_b, step)
554 | self.writer.add_scalar(f"acc/best_acc_b_test", self.best_test_acc_b, step)
555 |
556 |
557 | def board_ours_acc(self, step, inference=None, model ='debias', mode = 'dummy', n_group = 4, eval = False, save = True):
558 | # check label network
559 |
560 | valid_accs_d, valid_acc_groups = self.evaluate_ours(self.model_b, self.model_l, self.valid_loader, n_group = n_group, model=model, mode=mode, split = 'valid', step = step)
561 |
562 | test_accs_d, test_acc_groups = self.evaluate_ours(self.model_b, self.model_l, self.test_loader, n_group = n_group, model=model, mode=mode, split = 'test', step = step)
563 |
564 | if eval:
565 | return
566 |
567 |
568 | if valid_acc_groups is not None:
569 | valid_group_acc_list = list(valid_acc_groups.values())
570 | valid_accs_avg = np.nanmean(valid_group_acc_list)
571 | valid_accs_worst = np.nanmin(valid_group_acc_list)
572 |
573 | if valid_accs_avg >= self.best_valid_acc_avg:
574 | self.best_valid_acc_avg = valid_accs_avg
575 | if valid_accs_worst >= self.best_valid_acc_worst:
576 | self.best_valid_acc_worst = valid_accs_worst
577 |
578 | if test_acc_groups is not None:
579 |
580 | test_group_acc_list = list(test_acc_groups.values())
581 | test_accs_avg = np.nanmean(test_group_acc_list)
582 | test_accs_worst = np.nanmin(test_group_acc_list)
583 |
584 | if test_accs_avg >=self.best_test_acc_avg:
585 | self.best_test_acc_avg = test_accs_avg
586 | if test_accs_worst >=self.best_test_acc_worst:
587 | self.best_test_acc_worst = test_accs_worst
588 |
589 | # else:
590 | # valid_accs_avg = 0
591 | # valid_accs_worst = 0
592 | # test_accs_avg = 0
593 | # test_accs_worst = 0
594 |
595 |
596 | if inference:
597 | print(f'test acc: {test_accs_d.item()}')
598 | import sys
599 | sys.exit(0)
600 |
601 | if valid_accs_d >= self.best_valid_acc_d:
602 | self.best_valid_acc_d = valid_accs_d
603 | if test_accs_d >= self.best_test_acc_d:
604 | self.best_test_acc_d = test_accs_d
605 | if save:
606 | self.save_ours(step, best=True)
607 |
608 |
609 | if self.args.wandb:
610 | wandb.log({
611 | "acc_d_valid": valid_accs_d,
612 | "acc_d_test": test_accs_d,
613 | "acc_avg_valid": valid_accs_avg,
614 | "acc_worst_valid": valid_accs_worst,
615 | "acc_avg_test": test_accs_avg,
616 | "acc_worst_test": test_accs_worst
617 | },
618 | step=step, )
619 | wandb.log({
620 | "best_acc_d_valid": self.best_valid_acc_d,
621 | "best_acc_d_test": self.best_test_acc_d,
622 | "best_acc_avg_valid": self.best_valid_acc_avg,
623 | "best_acc_avg_test": self.best_test_acc_avg,
624 | "best_acc_worst_valid": self.best_valid_acc_worst,
625 | "best_acc_worst_test": self.best_test_acc_worst,
626 | },
627 | step=step, )
628 |
629 | if (test_acc_groups is not None) and len(test_group_acc_list) < 16:
630 |
631 | for g_id in range(len(test_group_acc_list)):
632 | wandb.log({
633 | "test_acc_group_" + str(g_id): test_group_acc_list[g_id],
634 | },
635 | step=step, )
636 |
637 |
638 | if self.args.tensorboard:
639 | self.writer.add_scalar(f"acc/acc_d_valid", valid_accs_d, step)
640 | self.writer.add_scalar(f"acc/acc_d_test", test_accs_d, step)
641 | self.writer.add_scalar(f"acc/best_acc_d_valid", self.best_valid_acc_d, step)
642 | self.writer.add_scalar(f"acc/best_acc_d_test", self.best_test_acc_d, step)
643 | self.writer.add_scalar(f"acc/best_acc_avg_valid", self.best_valid_acc_avg, step)
644 | self.writer.add_scalar(f"acc/best_acc_worst_valid", self.best_valid_acc_worst, step)
645 | self.writer.add_scalar(f"acc/best_acc_worst_test", self.best_test_acc_worst, step)
646 |
647 | if (test_acc_groups is not None) and len(test_group_acc_list) < 16:
648 | for g_id in range(len(test_group_acc_list)):
649 |                     print(f"test_acc_group_{g_id}: {test_group_acc_list[g_id]}")
650 |                     print(f"valid_acc_group_{g_id}: {valid_group_acc_list[g_id]}")
651 | print(f"Best Worst Test:{self.best_test_acc_worst}")
652 | print(f'valid_d: {valid_accs_d} || test_d: {test_accs_d} || best_test_d: {self.best_test_acc_d}')
653 |
654 | def concat_dummy(self, z):
655 | def hook(model, input, output):
656 | z.append(output.squeeze())
657 | return torch.cat((output, torch.zeros_like(output)), dim=1)
658 | return hook
659 |
660 | def no_dummy(self, z):
661 | def hook(model, input, output):
662 | z.append(output.squeeze())
663 | return output
664 | return hook
665 |
666 | def no_dummy_input(self, z):
667 | def hook(model, input):
668 | z.append(input[0])
669 | return input
670 | return hook
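    # The three hook factories above expose intermediate features: concat_dummy appends the
    # squeezed output to z and returns it concatenated with zeros along dim 1, while no_dummy /
    # no_dummy_input record the output (resp. input) of the hooked module without altering it.
    # train_ours below registers them on layer3 / avgpool (ResNet) or dropout (BERT).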
671 |
672 | def train_vanilla(self, args):
673 | self.criterion = nn.CrossEntropyLoss()
674 | if args.dataset == 'cmnist':
675 | self.model_l = get_model('mlp_DISENTANGLE', self.num_classes, bias = True).to(self.device)
676 | self.model_b = get_model('mlp_DISENTANGLE', self.num_classes, bias = True).to(self.device)
677 | elif args.dataset == 'waterbird':
678 | self.model_l = get_model('resnet_50_pretrained', self.num_classes, bias = True).to(self.device)
679 | self.model_b = get_model('resnet_50_pretrained', self.num_classes, bias = True).to(self.device)
680 | elif args.dataset == 'civilcomments':
681 | self.model_l = get_model('bert-base-uncased_pt', self.num_classes, bias = True).to(self.device)
682 | self.model_b = get_model('bert-base-uncased_pt', self.num_classes, bias = True).to(self.device)
683 | else:
684 | if self.args.use_resnet20: # Use this option only for comparing with LfF
685 | self.model_l = get_model('ResNet20_OURS', self.num_classes).to(self.device)
686 | self.model_b = get_model('ResNet20_OURS', self.num_classes).to(self.device)
687 | print('our resnet20....')
688 | else:
689 | self.model_l = get_model('resnet_DISENTANGLE_pretrained', self.num_classes).to(self.device)
690 | self.model_b = get_model('resnet_DISENTANGLE_pretrained', self.num_classes).to(self.device)
691 |
692 | # self.model_b.load_state_dict(torch.load(os.path.join('./log/waterbird/waterbird_ours_GEC_0.9_ema_0.50_tau_0.10_lambda_2.00_avgtype_mv_batch/result/', 'best_model_b.th'))['state_dict'])
693 |
694 | if args.dataset == 'waterbird':
695 | print('!' * 10 + ' Using SGD ' + '!' * 10)
696 |
697 | self.optimizer_b = torch.optim.SGD(
698 | self.model_b.parameters(),
699 | lr=args.lr,
700 | weight_decay=0,
701 | )
702 | elif args.dataset == 'civilcomments':
703 |             print('------------------- SGD (civilcomments) -------------------------------')
704 | self.optimizer_b = torch.optim.SGD(
705 | self.model_b.parameters(), #1e-3
706 | lr=2e-5,
707 | weight_decay=0
708 | )
709 |
710 | n_group = 16
711 |
712 | else:
713 |
714 | self.optimizer_b = torch.optim.Adam(
715 | self.model_b.fc.parameters(),
716 | lr=args.lr,
717 | weight_decay=args.weight_decay,
718 | )
719 |
720 | step = 0
721 |
722 | for epoch in tqdm(range(args.num_epochs)):
723 |
724 | for index, data, attr, image_path, indicator in tqdm(self.train_loader, leave=False):
725 |
726 | data = data.to(self.device)
727 | attr = attr.to(self.device)
728 | label = attr[:, args.target_attr_idx]
729 | spurious_label = attr[:, 1]
730 |
731 | logit_b = self.model_b(data)
732 |
733 | loss_b_update = self.criterion(logit_b, label)
734 |
735 | loss = loss_b_update.mean()
736 |
737 |
738 | self.optimizer_b.zero_grad()
739 | loss.backward()
740 | self.optimizer_b.step()
741 |
742 | ##################################################
743 | #################### LOGGING #####################
744 | ##################################################
745 |
746 | if (step % args.valid_freq == 0) and (args.dataset != 'waterbird'):
747 | print("------------------------ Validation Starts--------------------------------")
748 | self.board_vanilla_acc(step, epoch, inference=None)
749 | print("------------------------ Validation Done --------------------------------")
750 | step += 1
751 |
752 | if args.dataset == 'waterbird':
753 | self.board_ours_acc(epoch, model = 'bias', mode = 'no_dummy' , n_group = 4, save = False)
754 |
755 | if len(random_indices_all_groups) > 0:
756 | mixed_feature = lam * feature[indices_all_groups] + (1 - lam) * feature[random_indices_all_groups]
757 | mixed_correction = lam * correction[indices_all_groups] + (1 - lam) * correction[random_indices_all_groups]
758 | label_a = label[indices_all_groups]
759 | label_b = label[random_indices_all_groups]
760 | else:
761 |             mixed_feature, mixed_correction = None, None
762 | label_a, label_b, lam = None, None, None
763 |
764 | return mixed_feature, mixed_correction, label_a, label_b, lam
765 |
766 |
767 | def train_ours(self, args):
768 | epoch, cnt = 0, 0
769 | print('************** main training starts... ************** ')
770 | train_num = len(self.train_dataset)
771 | print('Length of training set: {:d}'.format(train_num))
772 |
773 | self.bias_criterion = GeneralizedCELoss(q=args.q)
774 | self.criterion = LogitCorrectionLoss(eta = 1.0)
775 |
776 | if args.dataset == 'cmnist':
777 | self.model_l = get_model('mlp_DISENTANGLE', self.num_classes, bias = True).to(self.device)
778 | self.model_b = get_model('mlp_DISENTANGLE', self.num_classes, bias = True).to(self.device)
779 | elif args.dataset == 'waterbird':
780 | self.model_l = get_model('resnet_50_pretrained', self.num_classes, bias = True).to(self.device)
781 | self.model_b = get_model('resnet_50_pretrained', self.num_classes, bias = True).to(self.device)
782 | elif args.dataset == 'civilcomments':
783 | self.model_l = get_model('bert-base-uncased_pt', self.num_classes, bias = True).to(self.device)
784 | self.model_b = get_model('bert-base-uncased_pt', self.num_classes, bias = True).to(self.device)
785 | else:
786 | if self.args.use_resnet20: # Use this option only for comparing with LfF
787 | self.model_l = get_model('ResNet20_OURS', self.num_classes).to(self.device)
788 | self.model_b = get_model('ResNet20_OURS', self.num_classes).to(self.device)
789 | print('our resnet20....')
790 | else:
791 | self.model_l = get_model('resnet_DISENTANGLE', self.num_classes).to(self.device)
792 | self.model_b = get_model('resnet_DISENTANGLE', self.num_classes).to(self.device)
793 |
794 | if args.dataset == 'waterbird':
795 | print('!' * 10 + ' Using SGD ' + '!' * 10)
796 |
797 | self.optimizer_l = torch.optim.SGD(
798 | self.model_l.parameters(), #1e-3
799 | lr=args.lr,
800 | weight_decay=args.weight_decay,#1e-3
801 | )
802 |
803 | self.optimizer_b = torch.optim.SGD(
804 | self.model_b.parameters(),
805 | lr=args.lr*0.1,
806 | weight_decay=0,
807 | )
808 | elif args.dataset == 'civilcomments':
809 | self.optimizer_b = torch.optim.SGD(
810 | self.model_b.parameters(),
811 | lr=2e-5,
812 | weight_decay=0
813 | )
814 |
815 | no_decay = ['bias', 'LayerNorm.weight']
816 | optimizer_grouped_parameters = [
817 | {'params': [p for n, p in self.model_l.named_parameters()
818 | if not any(nd in n for nd in no_decay)],
819 | 'weight_decay': args.weight_decay},
820 | {'params': [p for n, p in self.model_l.named_parameters()
821 | if any(nd in n for nd in no_decay)],
822 | 'weight_decay': 0.0}]
823 | self.optimizer_l = optim.AdamW(optimizer_grouped_parameters,
824 | lr=args.lr, eps=1e-8)
825 |
826 | n_group = 16
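            # 16 evaluation groups for CivilComments (presumably 2 labels x 8 identity
            # attributes in WILDS); waterbird uses n_group = 4 further below.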
827 |
828 | else:
829 |             n_group = 4  # default evaluation-group count for cmnist / cifar10c (matches the board_ours_acc default)
830 | self.optimizer_l = torch.optim.Adam(
831 | self.model_l.parameters(),
832 | lr=args.lr,
833 | weight_decay=args.weight_decay,
834 | )
835 |
836 | self.optimizer_b = torch.optim.Adam(
837 | self.model_b.parameters(),
838 | lr=args.lr*0.1,
839 | weight_decay=args.weight_decay,
840 | )
841 |
842 |
843 | if args.use_lr_decay and args.dataset == 'waterbird':
844 | self.scheduler_b = optim.lr_scheduler.StepLR(self.optimizer_b, step_size=args.lr_decay_epoch, gamma=args.lr_gamma)
845 | self.scheduler_l = optim.lr_scheduler.StepLR(self.optimizer_l, step_size=args.lr_decay_epoch, gamma=args.lr_gamma)
846 | elif args.use_lr_decay and args.dataset == 'civilcomments':
847 | total_updates = args.num_epochs
848 |
849 | self.scheduler_b = get_bert_scheduler(self.optimizer_b, n_epochs=1,#args.num_epochs,
850 | warmup_steps=0,
851 | dataloader=self.train_loader)
852 | self.scheduler_l = get_bert_scheduler(self.optimizer_l, n_epochs=total_updates,#args.num_epochs,
853 | warmup_steps=0,
854 | dataloader=self.train_loader)
855 | else:
856 | self.scheduler_b = optim.lr_scheduler.StepLR(self.optimizer_b, step_size=args.lr_decay_step, gamma=args.lr_gamma)
857 | self.scheduler_l = optim.lr_scheduler.StepLR(self.optimizer_l, step_size=args.lr_decay_step, gamma=args.lr_gamma)
858 |
859 | print(f'criterion: {self.criterion}')
860 | print(f'bias criterion: {self.bias_criterion}')
861 | train_iter = iter(self.train_loader)
862 |
863 | step = 0
864 |
865 | for epoch in tqdm(range(args.num_epochs)):
866 |
867 | for index, data, attr, image_path, indicator in tqdm(self.train_loader, leave=False):
868 |
869 | data = data.to(self.device)
870 | attr = attr.to(self.device)
871 | label = attr[:, args.target_attr_idx].to(self.device)
872 | bias = attr[:, 1].to(self.device)
873 | if args.dataset == 'waterbird':
874 | alpha = sigmoid_rampup(epoch, args.curr_epoch)*0.5
875 | else:
876 | alpha = sigmoid_rampup(step, args.curr_step)*0.5
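                # alpha is the group-mixup interpolation strength passed to group_mixUp below;
                # sigmoid_rampup increases it smoothly from 0 towards the 0.5 cap over
                # args.curr_epoch epochs (waterbird) or args.curr_step steps (other datasets).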
877 |
878 | if args.dataset == 'cmnist':
879 | # z_l = self.model_l.extract(data)
880 | z_b = self.model_b.extract(data)
881 | pred_align = self.model_b.fc(z_b)
882 |                     self.sample_margin_ema_b.update(F.softmax(pred_align.detach()/args.tau, dim=1), index)
883 | pred_align_mv = self.sample_margin_ema_b.parameter[index].clone().detach()
884 | _, pseudo_label = torch.max(pred_align_mv, dim=1)
885 | self.confusion.update(pred_align_mv, label, pseudo_label, fix = None)
886 | correction_matrix = self.confusion.parameter.clone().detach()
887 | if args.avg_type == 'epoch':
888 | correction_matrix = correction_matrix/self.confusion.global_count_.to(self.device)
889 |
890 | correction_delta = correction_matrix[:,pseudo_label]
891 | correction_delta = torch.t(correction_delta)
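                    # correction_matrix is the running class-confusion estimate built from the
                    # EMA-smoothed biased predictions; indexing its columns by pseudo_label and
                    # transposing gives a per-sample [batch, num_classes] correction that is
                    # mixed by group_mixUp and then consumed by LogitCorrectionLoss below.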
892 | return_dict = group_mixUp(data, pseudo_label, correction_delta, label, self.num_classes, alpha)
893 | mixed_target_data = return_dict["mixed_feature"]
894 | mixed_biased_prediction = return_dict["mixed_correction"]
895 | label_a = return_dict["label_majority"]
896 | label_b = return_dict["label_minority"]
897 | lam_target = return_dict["lam"]
898 |
899 | z_l = self.model_l.extract(mixed_target_data)
900 | pred_conflict = self.model_l.fc(z_l)
901 |
902 | elif args.dataset == 'civilcomments':
903 |
904 | input_ids = data[:, :, 0]
905 | input_masks = data[:, :, 1]
906 | segment_ids = data[:, :, 2]
907 |
908 | pred_align = self.model_b(input_ids=input_ids,
909 | attention_mask=input_masks,
910 | token_type_ids=segment_ids,
911 | labels=label)[1]
912 |
913 |
914 |                     self.sample_margin_ema_b.update(F.softmax(pred_align.detach()/args.tau, dim=1), index)
915 | pred_align_mv = self.sample_margin_ema_b.parameter[index].clone().detach()
916 | _, pseudo_label = torch.max(pred_align_mv, dim=1)
917 | self.confusion.update(pred_align_mv, label, pseudo_label, fix = None)
918 | correction_matrix = self.confusion.parameter.clone().detach()
919 | if args.avg_type == 'epoch':
920 | correction_matrix = correction_matrix/self.confusion.global_count_.to(self.device)
921 |
922 | correction_matrix = correction_matrix/((correction_matrix).sum(dim=0,keepdims =True) + 1e-4)
923 | correction_delta = correction_matrix[:,pseudo_label]
924 | correction_delta = torch.t(correction_delta)
925 |
926 |
927 | z_l = []
928 | hook_fn = self.model_l.dropout.register_forward_pre_hook(self.no_dummy_input(z_l))#register_forward_hook(self.no_dummy_input(z_l))
929 | _ = self.model_l(input_ids=input_ids,
930 | attention_mask=input_masks,
931 | token_type_ids=segment_ids,
932 | labels=label)
933 | hook_fn.remove()
934 | z_l = z_l[0]
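                    # The forward_pre_hook on model_l.dropout captures the pooled BERT feature
                    # going into the classifier head; group mixup is applied in this feature
                    # space and the classifier head is re-run on the mixed features below.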
935 |
936 | return_dict = group_mixUp(z_l, pseudo_label, correction_delta, label, self.num_classes, alpha)
937 | mixed_target_z_l = return_dict["mixed_feature"]
938 | mixed_biased_prediction = return_dict["mixed_correction"]
939 | label_a = return_dict["label_majority"]
940 | label_b = return_dict["label_minority"]
941 | lam_target = return_dict["lam"]
942 |
943 |
944 | pred_conflict = self.model_l.classifier(self.model_l.dropout(mixed_target_z_l))
945 |
946 | else:
947 | z_b = []
948 | # Use this only for reproducing CIFARC10 of LfF
949 | if self.args.use_resnet20:
950 | hook_fn = self.model_b.layer3.register_forward_hook(self.no_dummy(z_b))
951 | _ = self.model_b(data)
952 | hook_fn.remove()
953 | z_b = z_b[0]
954 |
955 | pred_align = self.model_b.fc(z_b)
956 |                         self.sample_margin_ema_b.update(F.softmax(pred_align.detach()/args.tau, dim=1), index)
957 | pred_align_mv = self.sample_margin_ema_b.parameter[index].clone().detach()
958 | _, pseudo_label = torch.max(pred_align_mv, dim=1)
959 | self.confusion.update(pred_align_mv, label, pseudo_label, fix = None)
960 | correction_matrix = self.confusion.parameter.clone().detach()
961 | if args.avg_type == 'epoch':
962 | correction_matrix = correction_matrix/self.confusion.global_count_.to(self.device)
963 | correction_matrix = correction_matrix/(correction_matrix).sum(dim=0,keepdims =True)
964 | correction_delta = correction_matrix[:,pseudo_label]
965 | correction_delta = torch.t(correction_delta)
966 | return_dict = group_mixUp(data, pseudo_label, correction_delta, label, self.num_classes, alpha)
967 | mixed_target_data = return_dict["mixed_feature"]
968 | mixed_biased_prediction = return_dict["mixed_correction"]
969 | label_a = return_dict["label_majority"]
970 | label_b = return_dict["label_minority"]
971 | lam_target = return_dict["lam"]
972 |
973 |
974 | z_l = []
975 | hook_fn = self.model_l.layer3.register_forward_hook(self.no_dummy(z_l))
976 | _ = self.model_l(mixed_target_data)
977 | hook_fn.remove()
978 |
979 | z_l = z_l[0]
980 |
981 | pred_conflict = self.model_l.fc(z_l)
982 |
983 | else:
984 | hook_fn = self.model_b.avgpool.register_forward_hook(self.no_dummy(z_b))
985 | _ = self.model_b(data)
986 | hook_fn.remove()
987 | z_b = z_b[0]
988 |
989 | pred_align = self.model_b.fc(z_b)
990 |                         self.sample_margin_ema_b.update(F.softmax(pred_align.detach()/args.tau, dim=1), index)
991 | pred_align_mv = self.sample_margin_ema_b.parameter[index].clone().detach()
992 |
993 | _, pseudo_label = torch.max(pred_align_mv, dim=1)
994 | self.confusion.update(pred_align_mv, label, pseudo_label, fix = None)
995 | correction_matrix = self.confusion.parameter.clone().detach()
996 | if args.avg_type == 'epoch':
997 | correction_matrix = correction_matrix/self.confusion.global_count_.to(self.device)
998 | correction_matrix = correction_matrix/(correction_matrix).sum(dim=0,keepdims =True)
999 | correction_delta = correction_matrix[:,pseudo_label]
1000 | correction_delta = torch.t(correction_delta)
1001 |
1002 | return_dict = group_mixUp(data, pseudo_label, correction_delta, label, self.num_classes, alpha)
1003 | mixed_target_data = return_dict["mixed_feature"]
1004 | mixed_biased_prediction = return_dict["mixed_correction"]
1005 | label_a = return_dict["label_majority"]
1006 | label_b = return_dict["label_minority"]
1007 | lam_target = return_dict["lam"]
1008 |
1009 | z_l = []
1010 | hook_fn = self.model_l.avgpool.register_forward_hook(self.no_dummy(z_l))
1011 | _ = self.model_l(mixed_target_data)
1012 | hook_fn.remove()
1013 |
1014 | z_l = z_l[0]
1015 |
1016 | pred_conflict = self.model_l.fc(z_l)
1017 |
1018 |
1019 |
1020 |                     self.sample_margin_ema_b.update(F.softmax(pred_align.detach(), dim=1), index)
1021 |
1022 | loss_dis_conflict = lam_target * self.criterion(pred_conflict, label_a, mixed_biased_prediction) +\
1023 | (1 - lam_target) * self.criterion(pred_conflict, label_b, mixed_biased_prediction)
1024 |
1025 | loss_dis_align = self.bias_criterion(pred_align, label)
1026 | loss = loss_dis_conflict.mean() + args.lambda_dis_align * loss_dis_align.mean() # Eq.2 L_dis
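                # Total objective: the debiased model (model_l) is trained with the mixup-weighted
                # logit-correction loss over the two mixed label sets, while the biased model (model_b)
                # is trained with GeneralizedCELoss; lambda_dis_align balances the two terms.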
1027 |
1028 |
1029 | self.optimizer_l.zero_grad()
1030 | self.optimizer_b.zero_grad()
1031 |
1032 | loss.backward()
1033 |
1034 | if args.dataset == 'civilcomments':
1035 | torch.nn.utils.clip_grad_norm_(self.model_l.parameters(), 1.0)
1036 |
1037 |                 self.optimizer_l.step()
1038 |                 self.optimizer_b.step()
1039 |
1040 |                 # Step the per-iteration schedulers after the optimizer update, as PyTorch expects.
1041 |                 if args.use_lr_decay and args.dataset != 'waterbird':
1042 |                     self.scheduler_b.step()
1043 |                     self.scheduler_l.step()
1044 |
1045 | if args.use_lr_decay and step % args.lr_decay_step == 0 and args.dataset != 'waterbird':
1046 | print('******* learning rate decay .... ********')
1047 | print(f"self.optimizer_b lr: { self.optimizer_b.param_groups[-1]['lr']}")
1048 | print(f"self.optimizer_l lr: { self.optimizer_l.param_groups[-1]['lr']}")
1049 |
1050 |
1051 | if step > 0 and step % args.save_freq == 0 and args.dataset != 'waterbird' and args.dataset != 'civilcomments':
1052 | self.save_ours(step)
1053 |
1054 |
1055 | if step > 0 and step % args.log_freq == 0 and args.dataset != 'waterbird': #and args.dataset != 'civilcomments':
1056 | confusion_numpy = correction_matrix.cpu().numpy()
1057 | self.board_ours_loss(
1058 | step=step,
1059 | loss_dis_conflict=loss_dis_conflict.mean(),
1060 | loss_dis_align=args.lambda_dis_align * loss_dis_align.mean(),
1061 | confusion=confusion_numpy,
1062 | global_count=self.confusion.global_count_.cpu().numpy()
1063 | )
1064 | if step > 0 and (step % args.valid_freq == 0) and (args.dataset != 'waterbird'): #and (args.dataset != 'civilcomments'):
1065 | print('################################ #epoch {:d} ############################\n'.format(epoch))
1066 | self.board_ours_acc(step, model = 'debias', mode = 'no_dummy', n_group = n_group)
1067 |
1068 | step += 1
1069 |
1070 | if args.use_lr_decay and args.dataset == 'waterbird':
1071 | self.scheduler_b.step()
1072 | self.scheduler_l.step()
1073 |
1074 | if args.use_lr_decay and epoch % args.lr_decay_epoch == 0 and args.dataset == 'waterbird':
1075 | print('******* learning rate decay .... ********')
1076 | print(f"self.optimizer_b lr: { self.optimizer_b.param_groups[-1]['lr']}")
1077 | print(f"self.optimizer_l lr: { self.optimizer_l.param_groups[-1]['lr']}")
1078 |
1079 |
1080 | confusion_numpy = correction_matrix.cpu().numpy()
1081 | if (args.dataset == 'waterbird'):
1082 | self.board_ours_loss(
1083 | step=epoch,
1084 | loss_dis_conflict=loss_dis_conflict.mean(),
1085 | loss_dis_align=args.lambda_dis_align * loss_dis_align.mean(),
1086 | confusion=confusion_numpy,
1087 | global_count=self.confusion.global_count_.cpu().numpy()
1088 | )
1089 | if (args.dataset == 'waterbird'):
1090 | self.board_ours_acc(epoch, model = 'debias', mode = 'no_dummy' , n_group = 4)
1091 | self.confusion.global_count_ = torch.zeros(self.num_classes, self.num_classes)
1092 | if args.avg_type == 'epoch':
1093 | self.confusion.initiate_parameter()
1094 |
1095 |
1096 |
1097 | def test_ours(self, args):
1098 | if args.dataset == 'cmnist':
1099 | self.model_l = get_model('mlp_DISENTANGLE', self.num_classes, bias = True).to(self.device)
1100 | self.model_b = get_model('mlp_DISENTANGLE', self.num_classes, bias = True).to(self.device)
1101 | elif args.dataset == 'waterbird':
1102 | self.model_l = get_model('resnet_50_pretrained', self.num_classes, bias = True).to(self.device)
1103 | self.model_b = get_model('resnet_50_pretrained', self.num_classes, bias = True).to(self.device)
1104 | elif args.dataset == 'civilcomments':
1105 | print('----------------- Load Bert --------------------')
1106 | self.model_l = get_model('bert-base-uncased_pt', self.num_classes, bias = True).to(self.device)
1107 | self.model_b = get_model('bert-base-uncased_pt', self.num_classes, bias = True).to(self.device)
1108 | else:
1109 | if self.args.use_resnet20: # Use this option only for comparing with LfF
1110 | self.model_l = get_model('ResNet20_OURS', self.num_classes).to(self.device)
1111 | self.model_b = get_model('ResNet20_OURS', self.num_classes).to(self.device)
1112 | print('our resnet20....')
1113 | else:
1114 | self.model_l = get_model('resnet_DISENTANGLE', self.num_classes).to(self.device)
1115 | self.model_b = get_model('resnet_DISENTANGLE', self.num_classes).to(self.device)
1116 |
1117 | self.model_l.load_state_dict(torch.load(os.path.join(args.pretrained_path, 'model_200_37.55377996312232.th'))['state_dict'])
1118 | self.board_ours_acc(-1, model = 'debias', mode = 'no_dummy', n_group = 16, eval = True)
1119 |
--------------------------------------------------------------------------------