├── networks ├── __init__.py ├── model_factory.py ├── mlp.py ├── autoencoder.py ├── resnet_dropout.py └── resnet.py ├── trainer ├── __init__.py ├── loss_utils.py ├── vanilla_train.py ├── hsic.py ├── fairhsic.py ├── mfd.py ├── trainer_factory.py ├── reweighting.py └── adv_debiasing.py ├── data_handler ├── __init__.py ├── adult.py ├── compas.py ├── custom_loader.py ├── custom_loader_hsic.py ├── AIF360 │ ├── binary_label_dataset.py │ ├── credit_dataset.py │ ├── dataset.py │ ├── bank_dataset.py │ ├── compas_dataset.py │ ├── adult_dataset.py │ └── standard_dataset.py ├── fairface.py ├── dataloader_factory.py ├── tabular_dataset.py ├── utils.py ├── dataset_factory.py ├── utkface.py ├── utkface_fairface.py ├── celeba.py └── ssl_dataset.py ├── LICENSE ├── utils.py ├── arguments.py ├── main.py ├── README.md ├── main_groupclf.py └── NOTICE /networks/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | cgl_fairness 3 | Copyright (c) 2022-present NAVER Corp. 4 | MIT license 5 | """ 6 | from networks.model_factory import *  # re-export ModelFactory at package level 7 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | cgl_fairness 3 | Copyright (c) 2022-present NAVER Corp. 4 | MIT license 5 | """ 6 | from trainer.trainer_factory import * 7 | -------------------------------------------------------------------------------- /data_handler/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | cgl_fairness 3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT license 5 | """ 6 | from data_handler.dataloader_factory import * 7 | from data_handler.dataset_factory import * 8 | from data_handler.ssl_dataset import * 9 | -------------------------------------------------------------------------------- /data_handler/adult.py: -------------------------------------------------------------------------------- 1 | """ 2 | cgl_fairness 3 | Copyright (c) 2022-present NAVER Corp. 4 | MIT license 5 | """ 6 | import pandas as pd 7 | from data_handler.AIF360.adult_dataset import AdultDataset 8 | from data_handler.tabular_dataset import TabularDataset 9 | 10 | 11 | class AdultDataset_torch(TabularDataset): 12 | """Adult (AIF360) dataset as a TabularDataset; the sensitive attribute is 'sex' (default) or 'race'.""" 13 | name = 'adult' 14 | def __init__(self, root, target_attr='sex', **kwargs): 15 | 16 | dataset = AdultDataset(root_dir=root) 17 | if target_attr == 'sex': 18 | sen_attr_idx = 3  # column of dataset.features used as the sensitive attribute 19 | elif target_attr == 'race': 20 | sen_attr_idx = 2 21 | else: 22 | raise Exception('Not allowed group')  # any target_attr other than 'sex' / 'race' 23 | 24 | self.num_groups = 2  # binary groups and labels; set before super().__init__() because TabularDataset.__init__ reads them 25 | self.num_classes = 2 26 | 27 | super(AdultDataset_torch, self).__init__(root=root, dataset=dataset, sen_attr_idx=sen_attr_idx, 28 | **kwargs) 29 | 30 | -------------------------------------------------------------------------------- /data_handler/compas.py: -------------------------------------------------------------------------------- 1 | """ 2 | cgl_fairness 3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT license 5 | """ 6 | import pandas as pd 7 | from data_handler.AIF360.compas_dataset import CompasDataset 8 | from data_handler.tabular_dataset import TabularDataset 9 | 10 | 11 | class CompasDataset_torch(TabularDataset): 12 | """COMPAS (AIF360) dataset as a TabularDataset; the sensitive attribute is 'race' (default) or 'sex'.""" 13 | name = 'compas' 14 | def __init__(self, root, target_attr='race', **kwargs): 15 | 16 | dataset = CompasDataset(root_dir=root) 17 | if target_attr == 'sex': 18 | sen_attr_idx = 0  # column of dataset.features used as the sensitive attribute 19 | elif target_attr == 'race': 20 | sen_attr_idx = 2 21 | else: 22 | raise Exception('Not allowed group')  # any target_attr other than 'sex' / 'race' 23 | 24 | self.num_groups = 2  # binary groups and labels; set before super().__init__() because TabularDataset.__init__ reads them 25 | self.num_classes = 2 26 | 27 | super(CompasDataset_torch, self).__init__(root=root, dataset=dataset, sen_attr_idx=sen_attr_idx, 28 | **kwargs) 29 | 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | cgl_fairness 2 | Copyright (c) 2022-present NAVER Corp. 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy 5 | of this software and associated documentation files (the "Software"), to deal 6 | in the Software without restriction, including without limitation the rights 7 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | copies of the Software, and to permit persons to whom the Software is 9 | furnished to do so, subject to the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be included in 12 | all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 | THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /networks/model_factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | cgl_fairness 3 | Copyright (c) 2022-present NAVER Corp. 4 | MIT license 5 | """ 6 | import torch.nn as nn 7 | 8 | from networks.resnet import resnet10, resnet12,resnet18, resnet34, resnet50, resnet101 9 | from networks.mlp import MLP 10 | from networks.resnet_dropout import resnet18_dropout 11 | 12 | class ModelFactory(): 13 | def __init__(self): 14 | pass 15 | 16 | @staticmethod 17 | def get_model(target_model, num_classes=2, img_size=224, pretrained=False, num_groups=2): 18 | 19 | if target_model == 'mlp': 20 | return MLP(feature_size=img_size, hidden_dim=64, num_classes=num_classes) 21 | 22 | elif 'resnet' in target_model: 23 | model_class = eval(target_model) 24 | if pretrained: 25 | model = model_class(pretrained=True, img_size=img_size) 26 | model.fc = nn.Linear(in_features=model.fc.weight.shape[1], out_features=num_classes, bias=True) 27 | else: 28 | model = model_class(pretrained=False, num_classes=num_classes, num_groups=num_groups, img_size=img_size) 29 | return model 30 | 31 | else: 32 | raise NotImplementedError 33 | 34 | -------------------------------------------------------------------------------- /data_handler/custom_loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Original code: 3 | https://github.com/sangwon79/Fair-Feature-Distillation-for-Visual-Recognition 4 | """ 5 | from copy import copy 6 | import numpy as np 7 | from torch.utils.data.sampler import RandomSampler 8 | import random 9 | 10 | 11 | class 
Customsampler(RandomSampler): 12 | 13 | def __init__(self, data_source, replacement=False, num_samples=None, batch_size=None, generator=None): 14 | super(Customsampler, self).__init__(data_source=data_source, replacement=replacement, 15 | num_samples=num_samples, generator=generator) 16 | self.l = data_source.num_classes 17 | self.g = data_source.num_groups 18 | self.nbatch_size = batch_size // (self.l*self.g) 19 | self.num_data = data_source.num_data 20 | self.idxs_per_group = data_source.idxs_per_group 21 | 22 | # which one is a group that has the largest number of data poitns 23 | self.max_pos = np.unravel_index(np.argmax(self.num_data), self.num_data.shape) 24 | self.numdata_per_group = (self.num_data[self.max_pos] // (self.nbatch_size+1) + 1) * (self.nbatch_size+1) 25 | 26 | def __iter__(self): 27 | index_list = [] 28 | 29 | for g in range(self.g): 30 | for l in range(self.l): 31 | total = 0 32 | group_index_list = [] 33 | while total < self.numdata_per_group: 34 | tmp = copy(self.idxs_per_group[(g, l)]) 35 | random.shuffle(tmp) 36 | remained_data = self.numdata_per_group - total 37 | if remained_data > len(tmp): 38 | group_index_list.extend(tmp) 39 | else: 40 | group_index_list.extend(tmp[:remained_data]) 41 | break 42 | total += len(tmp) 43 | index_list.append(group_index_list) 44 | 45 | final_list = np.array(index_list) 46 | final_list = final_list.flatten('F') 47 | final_list = list(final_list) 48 | 49 | return iter(final_list) 50 | -------------------------------------------------------------------------------- /networks/mlp.py: -------------------------------------------------------------------------------- 1 | """ 2 | cgl_fairness 3 | Copyright (c) 2022-present NAVER Corp. 
4 | MIT license 5 | """ 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.autograd import Function 10 | 11 | class MLP(nn.Module): 12 | def __init__(self, feature_size, hidden_dim, num_classes=None, num_layer=3, adv=False, adv_lambda=1.): 13 | super(MLP, self).__init__() 14 | try: #list 15 | in_features = self.compute_input_size(feature_size) 16 | except : #int 17 | in_features = feature_size 18 | 19 | num_outputs = num_classes 20 | self.adv = adv 21 | if self.adv: 22 | self.adv_lambda = adv_lambda 23 | self._make_layer(in_features, hidden_dim, num_classes, num_layer) 24 | 25 | def forward(self, feature, get_inter=False): 26 | feature = torch.flatten(feature, 1) 27 | if self.adv: 28 | feature = ReverseLayerF.apply(feature, self.adv_lambda) 29 | 30 | h = self.features(feature) 31 | out = self.head(h) 32 | out = out.squeeze() 33 | 34 | if get_inter: 35 | return h, out 36 | else: 37 | return out 38 | 39 | def compute_input_size(self, feature_size): 40 | in_features = 1 41 | for size in feature_size: 42 | in_features = in_features * size 43 | 44 | return in_features 45 | 46 | def _make_layer(self, in_dim, h_dim, num_classes, num_layer): 47 | 48 | if num_layer == 1: 49 | self.features = nn.Identity() 50 | h_dim = in_dim 51 | else: 52 | features = [] 53 | for i in range(num_layer-1): 54 | features.append(nn.Linear(in_dim, h_dim) if i == 0 else nn.Linear(h_dim, h_dim)) 55 | features.append(nn.ReLU()) 56 | self.features = nn.Sequential(*features) 57 | 58 | self.head = nn.Linear(h_dim, num_classes) 59 | 60 | 61 | class ReverseLayerF(Function): 62 | 63 | @staticmethod 64 | def forward(ctx, x, alpha): 65 | ctx.alpha = alpha 66 | 67 | return x.view_as(x) 68 | 69 | @staticmethod 70 | def backward(ctx, grad_output): 71 | return grad_output.neg() * ctx.alpha, None 72 | -------------------------------------------------------------------------------- /data_handler/custom_loader_hsic.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | cgl_fairness 3 | Copyright (c) 2022-present NAVER Corp. 4 | MIT license 5 | """ 6 | from copy import copy 7 | import numpy as np 8 | from torch.utils.data.sampler import RandomSampler 9 | import random 10 | 11 | 12 | class Customsampler(RandomSampler): 13 | def __init__(self, data_source, replacement=False, num_samples=None, batch_size=None, generator=None): 14 | super(Customsampler, self).__init__(data_source=data_source, replacement=replacement, 15 | num_samples=num_samples, generator=generator) 16 | self.l = data_source.num_classes 17 | self.g = data_source.num_groups 18 | self.nbatch_size = batch_size // (self.l*self.g) 19 | self.num_data = np.sum(data_source.num_data, axis=0) 20 | 21 | self.idxs_per_group = data_source.idxs_per_group 22 | idxs_per_group = {} 23 | for l in range(self.l): 24 | idxs_per_group[l] = [] 25 | for g in range(self.g): 26 | idxs_per_group[l].extend(self.idxs_per_group[(g, l)]) 27 | 28 | self.idxs_per_group = idxs_per_group 29 | # which one is a group that has the largest number of data poitns 30 | self.max_pos = np.unravel_index(np.argmax(self.num_data), self.num_data.shape) 31 | 32 | self.numdata_per_group = (self.num_data[self.max_pos] // (self.nbatch_size+1) + 1) * (self.nbatch_size+1) 33 | 34 | def __iter__(self): 35 | index_list = [] 36 | 37 | for l in range(self.l): 38 | total = 0 39 | group_index_list = [] 40 | while total < self.numdata_per_group: 41 | tmp = copy(self.idxs_per_group[l]) 42 | random.shuffle(tmp) 43 | remained_data = self.numdata_per_group - total 44 | if remained_data > len(tmp): 45 | group_index_list.extend(tmp) 46 | else: 47 | group_index_list.extend(tmp[:remained_data]) 48 | break 49 | total += len(tmp) 50 | index_list.append(group_index_list) 51 | 52 | final_list = np.array(index_list) 53 | final_list = final_list.flatten('F') 54 | final_list = list(final_list) 55 | 56 | return iter(final_list) 57 | 
-------------------------------------------------------------------------------- /data_handler/AIF360/binary_label_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Original code: 3 | https://github.com/Trusted-AI/AIF360 4 | """ 5 | import numpy as np 6 | 7 | from data_handler.AIF360.structured_dataset import StructuredDataset 8 | 9 | 10 | class BinaryLabelDataset(StructuredDataset): 11 | """Base class for all structured datasets with binary labels.""" 12 | 13 | def __init__(self, favorable_label=1., unfavorable_label=0., **kwargs): 14 | """ 15 | Args: 16 | favorable_label (float): Label value which is considered favorable 17 | (i.e. "positive"). 18 | unfavorable_label (float): Label value which is considered 19 | unfavorable (i.e. "negative"). 20 | **kwargs: StructuredDataset arguments. 21 | """ 22 | self.favorable_label = float(favorable_label) 23 | self.unfavorable_label = float(unfavorable_label) 24 | 25 | super(BinaryLabelDataset, self).__init__(**kwargs) 26 | 27 | def validate_dataset(self): 28 | """Error checking and type validation. 29 | Raises: 30 | ValueError: `labels` must be shape [n, 1]. 31 | ValueError: `favorable_label` and `unfavorable_label` must be the 32 | only values present in `labels`.
33 | """ 34 | # fix scores before validating 35 | if np.all(self.scores == self.labels): 36 | self.scores = (self.scores == self.favorable_label).astype(np.float64) 37 | 38 | super(BinaryLabelDataset, self).validate_dataset() 39 | 40 | # =========================== SHAPE CHECKING =========================== 41 | # Verify if the labels are only 1 column 42 | if self.labels.shape[1] != 1: 43 | raise ValueError("BinaryLabelDataset only supports single-column " 44 | "labels:\n\tlabels.shape = {}".format(self.labels.shape)) 45 | 46 | # =========================== VALUE CHECKING =========================== 47 | # Check if the favorable and unfavorable labels match those in the dataset 48 | if (not set(self.labels.ravel()) <= 49 | set([self.favorable_label, self.unfavorable_label])): 50 | raise ValueError("The favorable and unfavorable labels provided do " 51 | "not match the labels in the dataset.") 52 | -------------------------------------------------------------------------------- /trainer/loss_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Original code: 3 | https://github.com/sangwon79/Fair-Feature-Distillation-for-Visual-Recognition 4 | """ 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | def mse(inputs, targets):  # mean of the squared element-wise differences 10 | return (inputs - targets).pow(2).mean() 11 | 12 | 13 | def compute_feature_loss(inputs, t_inputs, student, teacher, device=0, regressor=None):  # FitNet-style hint loss: 0.5 * MSE(flattened student features, detached teacher features) 14 | stu_outputs = student(inputs, get_inter=True) 15 | f_s = stu_outputs[-2] 16 | if regressor is not None: 17 | f_s = regressor.forward(f_s) 18 | 19 | f_s = f_s.view(f_s.shape[0], -1) 20 | stu_logits = stu_outputs[-1] 21 | 22 | tea_outputs = teacher(t_inputs, get_inter=True) 23 | f_t = tea_outputs[-2].to(device) 24 | f_t = f_t.view(f_t.shape[0], -1).detach() 25 | 26 | tea_logits = tea_outputs[-1] 27 | 28 | fitnet_loss = (1 / 2) * (mse(f_s, f_t)) 29 | 30 | return fitnet_loss, stu_logits, tea_logits, f_s, f_t 31 | 32 | 33 | def
compute_hinton_loss(outputs, t_outputs=None, teacher=None, t_inputs=None, kd_temp=3, device=0):  # KD loss: KL(teacher || student) on temperature-softened logits, scaled by kd_temp**2 34 | if t_outputs is None: 35 | if (t_inputs is not None and teacher is not None): 36 | t_outputs = teacher(t_inputs) 37 | else: 38 | raise ValueError('Nothing is given to compute hinton loss')  # was constructed but never raised, leaving t_outputs as None 39 | 40 | soft_label = F.softmax(t_outputs / kd_temp, dim=1).to(device).detach() 41 | kd_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(outputs / kd_temp, dim=1), 42 | soft_label) * (kd_temp * kd_temp) 43 | 44 | return kd_loss 45 | 46 | 47 | def compute_at_loss(inputs, t_inputs, student, teacher, device=0, for_cifar=False):  # attention-transfer loss between the second-to-last outputs of student and teacher 48 | stu_outputs = student(inputs, get_inter=True) if not for_cifar else student(inputs, get_inter=True, before_fc=True) 49 | stu_logits = stu_outputs[-1] 50 | f_s = stu_outputs[-2] 51 | 52 | tea_outputs = teacher(t_inputs, get_inter=True) if not for_cifar else teacher(t_inputs, get_inter=True, before_fc=True)  # the teacher always sees t_inputs, in both branches 53 | tea_logits = tea_outputs[-1].to(device) 54 | f_t = tea_outputs[-2].to(device) 55 | attention_loss = (1 / 2) * (at_loss(f_s, f_t)) 56 | return attention_loss, stu_logits, tea_logits, f_s, f_t 57 | 58 | 59 | def at(x):  # per-sample attention map: mean of x**2 over dim 1, flattened and L2-normalized 60 | return F.normalize(x.pow(2).mean(1).view(x.size(0), -1)) 61 | 62 | 63 | def at_loss(x, y): 64 | return (at(x) - at(y)).pow(2).mean() 65 | -------------------------------------------------------------------------------- /data_handler/AIF360/credit_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Original code: 3 | https://github.com/Trusted-AI/AIF360 4 | """ 5 | import os 6 | 7 | import pandas as pd 8 | 9 | from data_handler.AIF360.standard_dataset import StandardDataset 10 | 11 | 12 | class CreditDataset(StandardDataset): 13 | """Credit-card default dataset (label column ``default payment next month``). 14 | Loaded from ``credit.csv`` under ``root_dir``.
15 | """ 16 | 17 | def __init__(self, root_dir='./data/credit', label_name='default payment next month', favorable_classes=[1], 18 | protected_attribute_names=['SEX'], 19 | privileged_classes=[lambda x: x == 1], 20 | instance_weights_name=None, 21 | categorical_features=['EDUCATION', 'MARRIAGE', 'AGE', 'PAY_0', 'PAY_2', 22 | 'PAY_3', 'PAY_4', 'PAY_5', 'PAY_6'], 23 | features_to_keep=[], features_to_drop=[], 24 | na_values=[], custom_preprocessing=None, 25 | metadata=None): 26 | """See :obj:`StandardDataset` for a description of the arguments. 27 | By default, ``SEX`` is the protected attribute with ``SEX == 1`` treated 28 | as privileged, and a ``default payment next month`` value of 1 is the 29 | favorable label. Rows are read from ``<root_dir>/credit.csv``. 30 | """ 31 | 32 | filepath = os.path.join(root_dir, 'credit.csv') 33 | 34 | try: 35 | df = pd.read_csv(filepath) 36 | except IOError as err: 37 | print("IOError: {}".format(err)) 38 | import sys 39 | sys.exit(1)  # NOTE(review): terminates the whole process when the CSV is missing 40 | 41 | super(CreditDataset, self).__init__(df=df, label_name=label_name, 42 | favorable_classes=favorable_classes, 43 | protected_attribute_names=protected_attribute_names, 44 | privileged_classes=privileged_classes, 45 | instance_weights_name=instance_weights_name, 46 | categorical_features=categorical_features, 47 | features_to_keep=features_to_keep, 48 | features_to_drop=features_to_drop, na_values=na_values, 49 | custom_preprocessing=custom_preprocessing, metadata=metadata) 50 | -------------------------------------------------------------------------------- /data_handler/fairface.py: -------------------------------------------------------------------------------- 1 | """ 2 | cgl_fairness 3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT license 5 | """ 6 | import os 7 | from os.path import join 8 | from data_handler import GenericDataset 9 | import pandas 10 | from torchvision import transforms 11 | from functools import partial 12 | from data_handler.utils import get_mean_std 13 | 14 | 15 | class Fairface(GenericDataset): 16 | mean, std = get_mean_std('uktface') 17 | train_transform = transforms.Compose( 18 | [transforms.Resize((256, 256)), 19 | transforms.RandomCrop(224), 20 | transforms.RandomHorizontalFlip(), 21 | transforms.ToTensor(), 22 | transforms.Normalize(mean=mean, std=std)] 23 | ) 24 | test_transform = transforms.Compose( 25 | [transforms.Resize((224, 224)), 26 | transforms.ToTensor(), 27 | transforms.Normalize(mean=mean, std=std)] 28 | ) 29 | name = 'fairface' 30 | 31 | def __init__(self, target_attr='age', root='fairface', split='train', seed=0): 32 | transform = self.train_transform if split == 'train' else self.test_transform 33 | GenericDataset.__init__(self, root=root, split=split, seed=seed, transform=transform) 34 | 35 | self.sensitive_attr = 'race' 36 | self.target_attr = target_attr 37 | 38 | fn = partial(join, self.root) 39 | label_file = "fairface_label_{}.csv".format(split) 40 | label_mat = pandas.read_csv(fn(label_file)) 41 | self.feature_mat = self._preprocessing(label_mat) 42 | 43 | def _preprocessing(self, label_mat): 44 | race_dict = { 45 | 'White': 0, 46 | 'Middle Eastern': 0, 47 | 'Black': 1, 48 | 'East Asian': 2, 49 | 'Southeast Asian': 2, 50 | 'Indian': 3 51 | } 52 | 53 | age_dict = { 54 | '0-2': 0, 55 | '3-9': 0, 56 | '10-19': 0, 57 | '20-29': 1, 58 | '30-39': 1, 59 | '40-49': 2, 60 | '50-59': 2, 61 | '60-69': 2, 62 | 'more than 70': 2, 63 | } 64 | 65 | race_idx = 3 66 | age_idx = 1 67 | 68 | feature_mat = [] 69 | for row in label_mat.values: 70 | _age = row[age_idx] 71 | _race = row[race_idx] 72 | if _race not in race_dict.keys(): 73 | continue 74 | race = race_dict[_race] 75 | age = age_dict[_age] 76 | feature_mat.append([race, age, 
os.path.join(self.root, row[0])]) 77 | 78 | return feature_mat 79 | -------------------------------------------------------------------------------- /data_handler/dataloader_factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Original code: 3 | https://github.com/sangwon79/Fair-Feature-Distillation-for-Visual-Recognition 4 | """ 5 | from data_handler.dataset_factory import DatasetFactory 6 | 7 | import numpy as np 8 | from torch.utils.data import DataLoader 9 | 10 | 11 | class DataloaderFactory: 12 | def __init__(self): 13 | pass 14 | 15 | @staticmethod 16 | def get_dataloader(name, batch_size=256, seed=0, num_workers=4, 17 | target_attr='Attractive', add_attr=None, labelwise=False, sv_ratio=1, version='', args=None): 18 | if name == 'adult': 19 | target_attr = 'sex' 20 | elif name == 'compas': 21 | target_attr = 'race' 22 | 23 | test_dataset = DatasetFactory.get_dataset(name, split='test', sv_ratio=sv_ratio, version=version, 24 | target_attr=target_attr, seed=seed, add_attr=add_attr) 25 | train_dataset = DatasetFactory.get_dataset(name, split='train', sv_ratio=sv_ratio, version=version, 26 | target_attr=target_attr, seed=seed, add_attr=add_attr) 27 | 28 | def _init_fn(worker_id): 29 | np.random.seed(int(seed)) 30 | 31 | num_classes = test_dataset.num_classes 32 | num_groups = test_dataset.num_groups 33 | 34 | shuffle = True 35 | sampler = None 36 | if labelwise: 37 | if args.method == 'fairhsic': 38 | from data_handler.custom_loader_hsic import Customsampler 39 | sampler = Customsampler(train_dataset, replacement=False, batch_size=batch_size) 40 | else: 41 | from data_handler.custom_loader import Customsampler 42 | sampler = Customsampler(train_dataset, replacement=False, batch_size=batch_size) 43 | 44 | shuffle = False 45 | 46 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, 47 | num_workers=num_workers, worker_init_fn=_init_fn, pin_memory=True, 
drop_last=False) 48 | 49 | test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False, 50 | num_workers=num_workers, worker_init_fn=_init_fn, pin_memory=True) 51 | 52 | print('# of test data : {}'.format(len(test_dataset))) 53 | print('# of train data : {}'.format(len(train_dataset))) 54 | print('Dataset loaded.') 55 | print('# of classes, # of groups : {}, {}'.format(num_classes, num_groups)) 56 | 57 | return num_classes, num_groups, train_dataloader, test_dataloader 58 | -------------------------------------------------------------------------------- /data_handler/AIF360/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Original code: 3 | https://github.com/Trusted-AI/AIF360 4 | """ 5 | from abc import ABC, abstractmethod 6 | import copy 7 | 8 | 9 | class Dataset(ABC): 10 | """Abstract base class for datasets.""" 11 | 12 | @abstractmethod 13 | def __init__(self, **kwargs): 14 | self.metadata = kwargs.pop('metadata', dict()) or dict() 15 | self.metadata.update({ 16 | 'transformer': '{}.__init__'.format(type(self).__name__), 17 | 'params': kwargs, 18 | 'previous': [] 19 | }) 20 | self.validate_dataset() 21 | 22 | def validate_dataset(self): 23 | """Error checking and type validation.""" 24 | pass 25 | 26 | def copy(self, deepcopy=False): 27 | """Convenience method to return a copy of this dataset. 28 | Args: 29 | deepcopy (bool, optional): :func:`~copy.deepcopy` this dataset if 30 | `True`, shallow copy otherwise. 31 | Returns: 32 | Dataset: A new dataset with fields copied from this object and 33 | metadata set accordingly. 
34 | """ 35 | cpy = copy.deepcopy(self) if deepcopy else copy.copy(self) 36 | # preserve any user-created fields 37 | cpy.metadata = cpy.metadata.copy() 38 | cpy.metadata.update({ 39 | 'transformer': '{}.copy'.format(type(self).__name__), 40 | 'params': {'deepcopy': deepcopy}, 41 | 'previous': [self] 42 | }) 43 | return cpy 44 | 45 | @abstractmethod 46 | def export_dataset(self): 47 | """Save this Dataset to disk.""" 48 | raise NotImplementedError 49 | 50 | @abstractmethod 51 | def split(self, num_or_size_splits, shuffle=False): 52 | """Split this dataset into multiple partitions. 53 | Args: 54 | num_or_size_splits (array or int): If `num_or_size_splits` is an 55 | int, *k*, the value is the number of equal-sized folds to make 56 | (if *k* does not evenly divide the dataset these folds are 57 | approximately equal-sized). If `num_or_size_splits` is an array 58 | of type int, the values are taken as the indices at which to 59 | split the dataset. If the values are floats (< 1.), they are 60 | considered to be fractional proportions of the dataset at which 61 | to split. 62 | shuffle (bool, optional): Randomly shuffle the dataset before 63 | splitting. 64 | Returns: 65 | list(Dataset): Splits. Contains *k* or `len(num_or_size_splits) + 1` 66 | datasets depending on `num_or_size_splits`. 67 | """ 68 | raise NotImplementedError 69 | -------------------------------------------------------------------------------- /data_handler/tabular_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | cgl_fairness 3 | Copyright (c) 2022-present NAVER Corp. 
4 | MIT license 5 | """ 6 | import numpy as np 7 | import random 8 | from data_handler import SSLDataset 9 | 10 | 11 | class TabularDataset(SSLDataset): 12 | """AIF360-backed tabular dataset served through the SSLDataset interface.""" 13 | # column 0 -> sensitive attribute (group) 14 | # column 1 -> label 15 | # columns 2: -> standardized features (sensitive column removed) 16 | def __init__(self, dataset, sen_attr_idx, **kwargs): 17 | super(TabularDataset, self).__init__(**kwargs) 18 | self.sen_attr_idx = sen_attr_idx 19 | 20 | dataset_train, dataset_test = dataset.split([0.8], shuffle=True, seed=0)  # fixed 80/20 split (seed=0), independent of self.seed 21 | # features, labels = self._balance_test_set(dataset) 22 | self.dataset = dataset_train if (self.split == 'train') or ('group' in self.version) else dataset_test  # 'group' versions always read the train split 23 | 24 | features = np.delete(self.dataset.features, self.sen_attr_idx, axis=1) 25 | mean, std = self._get_mean_n_std(dataset_train.features)  # standardize with train-split statistics 26 | features = (features - mean) / std 27 | 28 | self.groups = np.expand_dims(self.dataset.features[:, self.sen_attr_idx], axis=1) 29 | self.labels = np.squeeze(self.dataset.labels) 30 | 31 | # self.features = self.dataset.features 32 | self.features = np.concatenate((self.groups, self.dataset.labels, features), axis=1) 33 | 34 | # sample counts and index lists per (group, label) cell; Customsampler reads them as num_data and idxs_per_group[(g, l)] 35 | self.num_data, self.idxs_per_group = self._data_count(self.features, self.num_groups, self.num_classes) 36 | 37 | # if semi-supervised learning, 38 | if self.sv_ratio < 1: 39 | # we want the different supervision according to the seed 40 | random.seed(self.seed) 41 | self.features, self.num_data, self.idxs_per_group = self.ssl_processing(self.features, self.num_data, self.idxs_per_group, ) 42 | if 'group' in self.version:  # group-prediction mode: groups and classes swap roles (see __getitem__) 43 | a, b = self.num_groups, self.num_classes 44 | self.num_groups, self.num_classes = b, a 45 | 46 | def get_dim(self):  # NOTE(review): counts the sensitive column that __getitem__ drops from the feature vector; confirm callers expect that 47 | return self.dataset.features.shape[-1] 48 | 49 | def __getitem__(self, idx):  # returns (features, 0, group, label, (idx, 0)); group and label swap places in 'group' versions 50 | features = self.features[idx] 51 | group = features[0] 52 | label = features[1] 53 | feature = features[2:] 54 | 55 | if 'group' in self.version: 56 | return
np.float32(feature), 0, label, np.int64(group), (idx, 0) 57 | else: 58 | return np.float32(feature), 0, group, np.int64(label), (idx, 0) 59 | 60 | def _get_mean_n_std(self, train_features): 61 | features = np.delete(train_features, self.sen_attr_idx, axis=1) 62 | mean = features.mean(axis=0) 63 | std = features.std(axis=0) 64 | std[std == 0] += 1e-7  # avoid dividing by zero on constant columns 65 | return mean, std 66 | -------------------------------------------------------------------------------- /data_handler/AIF360/bank_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Original code: 3 | https://github.com/Trusted-AI/AIF360 4 | """ 5 | import os 6 | 7 | import pandas as pd 8 | 9 | from data_handler.AIF360.standard_dataset import StandardDataset 10 | 11 | 12 | class BankDataset(StandardDataset): 13 | """Bank marketing Dataset. 14 | See :file:`aif360/data/raw/bank/README.md`. 15 | """ 16 | 17 | def __init__(self, root_dir='./data/bank', label_name='y', favorable_classes=['yes'], 18 | protected_attribute_names=['age'], 19 | privileged_classes=[lambda x: (x >= 25 and x <= 60)], 20 | instance_weights_name=None, 21 | categorical_features=['job', 'marital', 'education', 'default', 22 | 'housing', 'loan', 'contact', 'month', 'day_of_week', 23 | 'poutcome'], 24 | features_to_keep=[], features_to_drop=[], 25 | na_values=["unknown"], custom_preprocessing=None, 26 | metadata=None): 27 | """See :obj:`StandardDataset` for a description of the arguments. 28 | By default, ``age`` is the protected attribute with ages in [25, 60] 29 | treated as privileged, ``y == 'yes'`` is the favorable label, and 30 | ``"unknown"`` entries are read as missing values.
31 | """ 32 | 33 | filepath = os.path.join(root_dir, 'bank-additional-full.csv') 34 | 35 | try: 36 | df = pd.read_csv(filepath, sep=';', na_values=na_values) 37 | except IOError as err: 38 | print("IOError: {}".format(err)) 39 | print("To use this class, please download the following file:") 40 | print("\n\thttps://archive.ics.uci.edu/ml/machine-learning-databases/00222/bank-additional.zip") 41 | print("\nunzip it and place the files, as-is, in the folder:") 42 | print("\n\t{}\n".format(root_dir)) 43 | import sys 44 | sys.exit(1) 45 | 46 | super(BankDataset, self).__init__(df=df, label_name=label_name, 47 | favorable_classes=favorable_classes, 48 | protected_attribute_names=protected_attribute_names, 49 | privileged_classes=privileged_classes, 50 | instance_weights_name=instance_weights_name, 51 | categorical_features=categorical_features, 52 | features_to_keep=features_to_keep, 53 | features_to_drop=features_to_drop, na_values=na_values, 54 | custom_preprocessing=custom_preprocessing, metadata=metadata) 55 | -------------------------------------------------------------------------------- /data_handler/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | cgl_fairness 3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT license 5 | """ 6 | import torch 7 | from torch.utils.data import DataLoader 8 | import numpy as np 9 | 10 | 11 | def get_mean_std(dataset, skew_ratio=0.8): 12 | mean, std = None, None 13 | 14 | if dataset == 'utkface': 15 | mean = [0.5960, 0.4573, 0.3921] 16 | std = [0.2586, 0.2314, 0.2275] 17 | 18 | elif dataset == 'cifar10s': 19 | # for skew 0.8 20 | mean = [0.4871, 0.4811, 0.4632] 21 | std = [0.2431, 0.2414, 0.2506] 22 | 23 | elif dataset == 'celeba': 24 | # default target is 'Attractive' 25 | mean = [0.5063, 0.4258, 0.3832] 26 | std = [0.3107, 0.2904, 0.2897] 27 | 28 | elif dataset == 'imagenet': 29 | mean = [0.485, 0.456, 0.406] 30 | std = [0.229, 0.224, 0.225] 31 | 32 | return mean, std 33 | 34 | 35 | def predict_group(model, loader, args): 36 | model.cuda('cuda:{}'.format(args.device)) 37 | if args.slversion == 3: 38 | filename = 'trained_models/group_clf/utkface/scratch/resnet18_seed{}_epochs70_bs128_lr0.001_sv{}_version0.0.pt' 39 | elif args.slversion == 5: 40 | filename = 'trained_models/group_clf_pretrain/utkface/scratch/resnet18_seed{}_epochs70_bs128_lr0.001_sv{}_version0.0.pt' 41 | path = filename.format(str(args.seed), str(args.sv)) 42 | model.load_state_dict(torch.load(path, map_location=torch.device('cuda:{}'.format(args.device)))) 43 | 44 | features = loader.dataset.features 45 | 46 | dataloader = DataLoader(loader.dataset, batch_size=args.batch_size, shuffle=False, 47 | num_workers=args.num_workers, pin_memory=True, drop_last=False) 48 | model.eval() 49 | with torch.no_grad(): 50 | for i, data in enumerate(dataloader): 51 | inputs, _, groups, labels, (idxs, _) = data 52 | if (groups == -1).sum() == 0: 53 | continue 54 | 55 | if args.cuda: 56 | inputs = inputs.cuda() 57 | groups = groups.cuda() 58 | idxs = idxs.cuda() 59 | inputs = inputs[groups == -1] 60 | idxs = idxs[groups == -1] 61 | 62 | outputs = model(inputs) 63 | preds = torch.argmax(outputs, 1) 64 | for j, idx in enumerate(idxs.cpu().numpy()): 65 | features[idx][1] = 
preds.cpu()[j] 66 | if i % args.term == 0: 67 | print('[{}] in group prediction'.format(i)) 68 | 69 | if args.labelwise: 70 | loader.dataset.num_data, loader.dataset.idxs_per_group = loader.dataset._data_count() 71 | from data_handler.custom_loader import Customsampler 72 | 73 | def _init_fn(worker_id): 74 | np.random.seed(int(args.seed)) 75 | sampler = Customsampler(loader.dataset, replacement=False, batch_size=args.batch_size) 76 | train_dataloader = DataLoader(loader.dataset, batch_size=args.batch_size, sampler=sampler, 77 | num_workers=args.num_workers, worker_init_fn=_init_fn, pin_memory=True, drop_last=True) 78 | 79 | del dataloader 80 | del model 81 | del loader 82 | if args.labelwise: 83 | return train_dataloader 84 | -------------------------------------------------------------------------------- /trainer/vanilla_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | cgl_fairness 3 | Copyright (c) 2022-present NAVER Corp. 4 | MIT license 5 | """ 6 | from __future__ import print_function 7 | 8 | import time 9 | from utils import get_accuracy 10 | import trainer 11 | 12 | 13 | class Trainer(trainer.GenericTrainer): 14 | def __init__(self, args, **kwargs): 15 | super().__init__(args=args, **kwargs) 16 | 17 | def train(self, train_loader, test_loader, epochs, criterion=None, writer=None): 18 | global loss_set 19 | if criterion == None: 20 | criterion = self.criterion 21 | model = self.model 22 | model.train() 23 | 24 | for epoch in range(epochs): 25 | self._train_epoch(epoch, train_loader, model, criterion) 26 | 27 | eval_start_time = time.time() 28 | eval_loss, eval_acc, eval_deom, eval_deoa, eval_subgroup_acc = self.evaluate(self.model, test_loader, criterion) 29 | eval_end_time = time.time() 30 | print('[{}/{}] Method: {} ' 31 | 'Test Loss: {:.3f} Test Acc: {:.2f} Test DEOM {:.2f} [{:.2f} s]'.format 32 | (epoch + 1, epochs, self.method, 33 | eval_loss, eval_acc, eval_deom, (eval_end_time - eval_start_time))) 
34 | if self.record: 35 | train_loss, train_acc, train_deom, train_deoa, train_subgroup_acc = self.evaluate(self.model, train_loader, self.criterion) 36 | writer.add_scalar('train_loss', train_loss, epoch) 37 | writer.add_scalar('train_acc', train_acc, epoch) 38 | writer.add_scalar('train_deom', train_deom, epoch) 39 | writer.add_scalar('train_deoa', train_deoa, epoch) 40 | writer.add_scalar('eval_loss', eval_loss, epoch) 41 | writer.add_scalar('eval_acc', eval_acc, epoch) 42 | writer.add_scalar('eval_deom', eval_deom, epoch) 43 | writer.add_scalar('eval_deoa', eval_deoa, epoch) 44 | 45 | eval_contents = {} 46 | train_contents = {} 47 | for g in range(num_groups): 48 | for l in range(num_classes): 49 | eval_contents[f'g{g},l{l}'] = eval_subgroup_acc[g, l] 50 | train_contents[f'g{g},l{l}'] = train_subgroup_acc[g, l] 51 | writer.add_scalars('eval_subgroup_acc', eval_contents, epoch) 52 | writer.add_scalars('train_subgroup_acc', train_contents, epoch) 53 | 54 | if self.scheduler is not None and 'Reduce' in type(self.scheduler).__name__: 55 | self.scheduler.step(eval_loss) 56 | else: 57 | self.scheduler.step() 58 | print('Training Finished!') 59 | 60 | def _train_epoch(self, epoch, train_loader, model, criterion=None): 61 | model.train() 62 | 63 | running_acc = 0.0 64 | running_loss = 0.0 65 | batch_start_time = time.time() 66 | for i, data in enumerate(train_loader): 67 | # Get the inputs 68 | inputs, _, groups, targets, _ = data 69 | labels = targets 70 | 71 | if self.cuda: 72 | inputs = inputs.cuda(device=self.device) 73 | labels = labels.cuda(device=self.device) 74 | outputs = model(inputs) 75 | if criterion is not None: 76 | loss = criterion(outputs, labels) 77 | else: 78 | loss = self.criterion(outputs, labels) 79 | running_loss += loss.item() 80 | running_acc += get_accuracy(outputs, labels) 81 | 82 | self.optimizer.zero_grad() 83 | loss.backward() 84 | self.optimizer.step() 85 | 86 | if i % self.term == self.term-1: # print every self.term mini-batches 87 | 
avg_batch_time = time.time()-batch_start_time 88 | print('[{}/{}, {:5d}] Method: {} Train Loss: {:.3f} Train Acc: {:.2f} ' 89 | '[{:.2f} s/batch]'.format 90 | (epoch + 1, self.epochs, i+1, self.method, running_loss / self.term, running_acc / self.term, 91 | avg_batch_time/self.term)) 92 | 93 | running_loss = 0.0 94 | running_acc = 0.0 95 | batch_start_time = time.time() 96 | -------------------------------------------------------------------------------- /data_handler/dataset_factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | cgl_fairness 3 | Copyright (c) 2022-present NAVER Corp. 4 | MIT license 5 | """ 6 | import importlib 7 | import torch.utils.data as data 8 | import numpy as np 9 | from collections import defaultdict 10 | 11 | dataset_dict = { 12 | 'utkface': ['data_handler.utkface', 'UTKFaceDataset'], 13 | 'utkface_fairface': ['data_handler.utkface_fairface', 'UTKFaceFairface_Dataset'], 14 | 'celeba': ['data_handler.celeba', 'CelebA'], 15 | 'adult': ['data_handler.adult', 'AdultDataset_torch'], 16 | 'compas': ['data_handler.compas', 'CompasDataset_torch'], 17 | 'cifar100s': ['data_handler.cifar100s', 'CIFAR_100S'], 18 | } 19 | 20 | 21 | class DatasetFactory: 22 | def __init__(self): 23 | pass 24 | 25 | @staticmethod 26 | def get_dataset(name, split='Train', seed=0, sv_ratio=1, version=1, target_attr='Attractive', add_attr=None, ups_iter=0): 27 | root = f'./data/{name}' if name != 'utkface_fairface' else './data/utkface' 28 | kwargs = { 29 | 'root': root, 30 | 'split': split, 31 | 'seed': seed, 32 | 'sv_ratio': sv_ratio, 33 | 'version': version, 34 | 'ups_iter': ups_iter 35 | } 36 | 37 | if name not in dataset_dict.keys(): 38 | raise Exception('Not allowed method') 39 | 40 | if name == 'celeba': 41 | kwargs['add_attr'] = add_attr 42 | kwargs['target_attr'] = target_attr 43 | elif name == 'adult': 44 | kwargs['target_attr'] = target_attr 45 | elif name == 'compas': 46 | kwargs['target_attr'] = target_attr 47 
| 48 | module = importlib.import_module(dataset_dict[name][0]) 49 | class_ = getattr(module, dataset_dict[name][1]) 50 | return class_(**kwargs) 51 | 52 | 53 | class GenericDataset(data.Dataset): 54 | def __init__(self, root, split='train', transform=None, seed=0): 55 | self.root = root 56 | self.split = split 57 | self.transform = transform 58 | self.seed = seed 59 | self.num_data = None 60 | 61 | def __len__(self): 62 | return np.sum(self.num_data) 63 | 64 | def _data_count(self, features, num_groups, num_classes): 65 | idxs_per_group = defaultdict(lambda: []) 66 | data_count = np.zeros((num_groups, num_classes), dtype=int) 67 | 68 | for idx, i in enumerate(features): 69 | s, l = int(i[0]), int(i[1]) 70 | data_count[s, l] += 1 71 | idxs_per_group[(s, l)].append(idx) 72 | 73 | print(f'mode : {self.split}') 74 | for i in range(num_groups): 75 | print('# of %d group data : ' % i, data_count[i, :]) 76 | return data_count, idxs_per_group 77 | 78 | def _make_data(self, features, num_groups, num_classes): 79 | # if the original dataset not is divided into train / test set, this function is used 80 | min_cnt = 100 81 | data_count = np.zeros((num_groups, num_classes), dtype=int) 82 | tmp = [] 83 | for i in reversed(self.features): 84 | s, l = int(i[0]), int(i[1]) 85 | data_count[s, l] += 1 86 | if data_count[s, l] <= min_cnt: 87 | features.remove(i) 88 | tmp.append(i) 89 | 90 | train_data = features 91 | test_data = tmp 92 | return train_data, test_data 93 | 94 | def _make_weights(self): 95 | group_weights = len(self) / self.num_data 96 | weights = [group_weights[s, l] for s, l, _ in self.features] 97 | return weights 98 | 99 | def _balance_test_data(self, num_data, num_groups, num_classes): 100 | print('balance test data...') 101 | # if the original dataset is divided into train / test set, this function is used 102 | num_data_min = np.min(num_data) 103 | print('min : ', num_data_min) 104 | data_count = np.zeros((num_groups, num_classes), dtype=int) 105 | new_features = 
[] 106 | for idx, i in enumerate(self.features): 107 | s, l = int(i[0]), int(i[1]) 108 | if data_count[s, l] < num_data_min: 109 | new_features.append(i) 110 | data_count[s, l] += 1 111 | 112 | return new_features 113 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Original code: 3 | https://github.com/sangwon79/Fair-Feature-Distillation-for-Visual-Recognition 4 | """ 5 | import torch 6 | import numpy as np 7 | import random 8 | import os 9 | import torch.nn.functional as F 10 | 11 | 12 | def list_files(root, suffix, prefix=False): 13 | root = os.path.expanduser(root) 14 | files = list( 15 | filter( 16 | lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix), 17 | os.listdir(root) 18 | ) 19 | ) 20 | if prefix is True: 21 | files = [os.path.join(root, d) for d in files] 22 | return files 23 | 24 | 25 | def set_seed(seed): 26 | torch.manual_seed(seed) 27 | # torch.cuda.manual_seed(seed) 28 | np.random.seed(seed) 29 | random.seed(seed) 30 | torch.backends.cudnn.benchmark = False 31 | torch.backends.cudnn.deterministic = True 32 | 33 | 34 | def get_accuracy(outputs, labels, binary=False, reduction='mean'): 35 | # if multi-label classification 36 | if len(labels.size()) > 1: 37 | outputs = (outputs > 0.0).float() 38 | correct = ((outputs == labels)).float().sum() 39 | total = torch.tensor(labels.shape[0] * labels.shape[1], dtype=torch.float) 40 | avg = correct / total 41 | return avg.item() 42 | 43 | if binary: 44 | predictions = (torch.sigmoid(outputs) >= 0.5).float() 45 | else: 46 | predictions = torch.argmax(outputs, 1) 47 | 48 | c = (predictions == labels).float().squeeze() 49 | if reduction == 'none': 50 | return c 51 | else: 52 | accuracy = torch.mean(c) 53 | return accuracy.item() 54 | 55 | 56 | def check_log_dir(log_dir): 57 | try: 58 | if not os.path.isdir(log_dir): 59 | os.makedirs(log_dir) 60 | except 
OSError: 61 | print("Failed to create directory!!") 62 | 63 | 64 | class FitnetRegressor(torch.nn.Module): 65 | def __init__(self, in_feature, out_feature): 66 | super(FitnetRegressor, self).__init__() 67 | self.in_feature = in_feature 68 | self.out_feature = out_feature 69 | 70 | self.regressor = torch.nn.Conv2d(in_feature, out_feature, 1, bias=False) 71 | torch.nn.init.kaiming_normal_(self.regressor.weight, mode='fan_out', nonlinearity='relu') 72 | self.regressor.weight.data.uniform_(-0.005, 0.005) 73 | 74 | def forward(self, feature): 75 | if feature.dim() == 2: 76 | feature = feature.unsqueeze(2).unsqueeze(3) 77 | 78 | return F.relu(self.regressor(feature)) 79 | 80 | 81 | def make_log_name(args): 82 | log_name = args.model 83 | 84 | if args.mode == 'eva': 85 | log_name = args.modelpath.split('/')[-1] 86 | # remove .pt from name 87 | log_name = log_name[:-3] 88 | 89 | else: 90 | if args.pretrained: 91 | log_name += '_pretrained' 92 | log_name += '_seed{}_epochs{}_bs{}_lr{}'.format(args.seed, args.epochs, args.batch_size, args.lr) 93 | if args.method == 'adv': 94 | log_name += '_lamb{}_eta{}'.format(args.lamb, args.eta) 95 | 96 | elif args.method == 'scratch_mmd' or args.method == 'kd_mfd': 97 | log_name += '_{}'.format(args.kernel) 98 | log_name += '_sigma{}'.format(args.sigma) if args.kernel == 'rbf' else '' 99 | log_name += '_{}'.format(args.lambhf) 100 | 101 | elif args.method == 'reweighting': 102 | log_name += '_eta{}_iter{}'.format(args.eta, args.iteration) 103 | 104 | elif 'groupdro' in args.method: 105 | log_name += '_gamma{}'.format(args.gamma) 106 | 107 | if args.labelwise: 108 | log_name += '_labelwise' 109 | 110 | if args.teacher_path is not None or args.method == 'fairhsic': 111 | log_name += '_lamb{}'.format(args.lamb) 112 | log_name += '_from_{}'.format(args.teacher_type) 113 | 114 | if args.dataset == 'celeba': 115 | if args.target != 'Attractive': 116 | log_name += '_{}'.format(args.target) 117 | if args.add_attr is not None: 118 | log_name += 
'_{}'.format(args.add_attr) 119 | if args.sv < 1: 120 | log_name += '_sv{}'.format(args.sv) 121 | log_name += '_{}'.format(args.version) 122 | 123 | return log_name 124 | -------------------------------------------------------------------------------- /data_handler/AIF360/compas_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Original code: 3 | https://github.com/Trusted-AI/AIF360 4 | """ 5 | import os 6 | 7 | import pandas as pd 8 | 9 | from data_handler.AIF360.standard_dataset import StandardDataset 10 | 11 | 12 | default_mappings = { 13 | 'label_maps': [{1.0: 'Did recid.', 0.0: 'No recid.'}], 14 | 'protected_attribute_maps': [{0.0: 'Male', 1.0: 'Female'}, 15 | {1.0: 'Caucasian', 0.0: 'Not Caucasian'}] 16 | } 17 | 18 | 19 | def default_preprocessing(df): 20 | """Perform the same preprocessing as the original analysis: 21 | https://github.com/propublica/compas-analysis/blob/master/Compas%20Analysis.ipynb 22 | """ 23 | return df[(df.days_b_screening_arrest <= 30) 24 | & (df.days_b_screening_arrest >= -30) 25 | & (df.is_recid != -1) 26 | & (df.c_charge_degree != 'O') 27 | & (df.score_text != 'N/A')] 28 | 29 | 30 | class CompasDataset(StandardDataset): 31 | """ProPublica COMPAS Dataset. 32 | See :file:`aif360/data/raw/compas/README.md`. 
33 | """ 34 | 35 | def __init__(self, root_dir='./data/compas', 36 | label_name='two_year_recid', favorable_classes=[0], 37 | protected_attribute_names=['sex', 'race'], 38 | privileged_classes=[['Female'], ['Caucasian']], 39 | instance_weights_name=None, 40 | categorical_features=['age_cat', 'c_charge_degree', 'c_charge_desc'], 41 | features_to_keep=['sex', 'age', 'age_cat', 'race', 42 | 'juv_fel_count', 'juv_misd_count', 'juv_other_count', 43 | 'priors_count', 'c_charge_degree', 'c_charge_desc', 44 | 'two_year_recid'], 45 | features_to_drop=[], na_values=[], 46 | custom_preprocessing=default_preprocessing, 47 | metadata=default_mappings): 48 | """See :obj:`StandardDataset` for a description of the arguments. 49 | Note: The label value 0 in this case is considered favorable (no 50 | recidivism). 51 | Examples: 52 | In some cases, it may be useful to keep track of a mapping from 53 | `float -> str` for protected attributes and/or labels. If our use 54 | case differs from the default, we can modify the mapping stored in 55 | `metadata`: 56 | >>> label_map = {1.0: 'Did recid.', 0.0: 'No recid.'} 57 | >>> protected_attribute_maps = [{1.0: 'Male', 0.0: 'Female'}] 58 | >>> cd = CompasDataset(protected_attribute_names=['sex'], 59 | ... privileged_classes=[['Male']], metadata={'label_map': label_map, 60 | ... 'protected_attribute_maps': protected_attribute_maps}) 61 | Now this information will stay attached to the dataset and can be 62 | used for more descriptive visualizations. 
63 | """ 64 | 65 | filepath = os.path.join(root_dir, 'compas-scores-two-years.csv') 66 | 67 | try: 68 | df = pd.read_csv(filepath, index_col='id', na_values=na_values) 69 | except IOError as err: 70 | print("IOError: {}".format(err)) 71 | print("To use this class, please download the following file:") 72 | print("\n\thttps://raw.githubusercontent.com/propublica/compas-analysis/master/compas-scores-two-years.csv") 73 | print("\nand place it, as-is, in the folder:") 74 | print("\n\t{}\n".format(root_dir)) 75 | import sys 76 | sys.exit(1) 77 | 78 | super(CompasDataset, self).__init__(df=df, label_name=label_name, 79 | favorable_classes=favorable_classes, 80 | protected_attribute_names=protected_attribute_names, 81 | privileged_classes=privileged_classes, 82 | instance_weights_name=instance_weights_name, 83 | categorical_features=categorical_features, 84 | features_to_keep=features_to_keep, 85 | features_to_drop=features_to_drop, na_values=na_values, 86 | custom_preprocessing=custom_preprocessing, metadata=metadata) 87 | -------------------------------------------------------------------------------- /networks/autoencoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | cgl_fairness 3 | Copyright (c) 2022-present NAVER Corp. 
4 | MIT license 5 | """ 6 | import torch 7 | import torch.nn as nn 8 | 9 | # Conv Layer 10 | class ConvLayer(nn.Module): 11 | def __init__(self, in_channels, out_channels, kernel_size, stride): 12 | super(ConvLayer, self).__init__() 13 | padding = kernel_size // 2 14 | self.reflection_pad = nn.ReflectionPad2d(padding) 15 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride) #, padding) 16 | 17 | def forward(self, x): 18 | out = self.reflection_pad(x) 19 | out = self.conv2d(out) 20 | return out 21 | 22 | # Upsample Conv Layer 23 | class UpsampleConvLayer(nn.Module): 24 | def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None): 25 | super(UpsampleConvLayer, self).__init__() 26 | self.upsample = upsample 27 | if upsample: 28 | self.upsample = nn.Upsample(scale_factor=upsample, mode='nearest') 29 | reflection_padding = kernel_size // 2 30 | self.reflection_pad = nn.ReflectionPad2d(reflection_padding) 31 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride) 32 | 33 | def forward(self, x): 34 | if self.upsample: 35 | x = self.upsample(x) 36 | out = self.reflection_pad(x) 37 | out = self.conv2d(out) 38 | return out 39 | 40 | # Residual Block 41 | # adapted from pytorch tutorial 42 | # https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02- 43 | # intermediate/deep_residual_network/main.py 44 | class ResidualBlock(nn.Module): 45 | def __init__(self, channels): 46 | super(ResidualBlock, self).__init__() 47 | self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1) 48 | self.in1 = nn.InstanceNorm2d(channels, affine=True) 49 | self.relu = nn.ReLU() 50 | self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1) 51 | self.in2 = nn.InstanceNorm2d(channels, affine=True) 52 | 53 | def forward(self, x): 54 | residual = x 55 | out = self.relu(self.in1(self.conv1(x))) 56 | out = self.in2(self.conv2(out)) 57 | out = out + residual 58 | out = self.relu(out) 59 | return out 60 | 61 | # 
Image Transform Network 62 | class ImageTransformNet(nn.Module): 63 | def __init__(self): 64 | super(ImageTransformNet, self).__init__() 65 | 66 | # nonlineraity 67 | self.relu = nn.ReLU() 68 | self.tanh = nn.Tanh() 69 | 70 | # encoding layers 71 | self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1) 72 | self.in1_e = nn.InstanceNorm2d(32, affine=True) 73 | 74 | self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2) 75 | self.in2_e = nn.InstanceNorm2d(64, affine=True) 76 | 77 | self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2) 78 | self.in3_e = nn.InstanceNorm2d(128, affine=True) 79 | 80 | # residual layers 81 | self.res1 = ResidualBlock(128) 82 | self.res2 = ResidualBlock(128) 83 | self.res3 = ResidualBlock(128) 84 | self.res4 = ResidualBlock(128) 85 | self.res5 = ResidualBlock(128) 86 | 87 | # decoding layers 88 | self.deconv3 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2 ) 89 | self.in3_d = nn.InstanceNorm2d(64, affine=True) 90 | 91 | self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2 ) 92 | self.in2_d = nn.InstanceNorm2d(32, affine=True) 93 | 94 | self.deconv1 = UpsampleConvLayer(32, 3, kernel_size=9, stride=1) 95 | self.in1_d = nn.InstanceNorm2d(3, affine=True) 96 | 97 | def forward(self, x): 98 | # encode 99 | y = self.relu(self.in1_e(self.conv1(x))) 100 | y = self.relu(self.in2_e(self.conv2(y))) 101 | y = self.relu(self.in3_e(self.conv3(y))) 102 | 103 | # residual layers 104 | y = self.res1(y) 105 | y = self.res2(y) 106 | y = self.res3(y) 107 | y = self.res4(y) 108 | y = self.res5(y) 109 | 110 | # decode 111 | y = self.relu(self.in3_d(self.deconv3(y))) 112 | y = self.relu(self.in2_d(self.deconv2(y))) 113 | #y = self.tanh(self.in1_d(self.deconv1(y))) 114 | y = self.deconv1(y) 115 | 116 | return y 117 | -------------------------------------------------------------------------------- /trainer/hsic.py: -------------------------------------------------------------------------------- 1 | """ 2 | Original code: 
3 | https://github.com/clovaai/rebias 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | def to_numpy(x): 10 | """convert Pytorch tensor to numpy array 11 | """ 12 | return x.clone().detach().cpu().numpy() 13 | 14 | 15 | class HSIC(nn.Module): 16 | """Base class for the finite sample estimator of Hilbert-Schmidt Independence Criterion (HSIC) 17 | ..math:: HSIC (X, Y) := || C_{x, y} ||^2_{HS}, where HSIC (X, Y) = 0 iif X and Y are independent. 18 | Empirically, we use the finite sample estimator of HSIC (with m observations) by, 19 | (1) biased estimator (HSIC_0) 20 | Gretton, Arthur, et al. "Measuring statistical dependence with Hilbert-Schmidt norms." 2005. 21 | :math: (m - 1)^2 tr KHLH. 22 | where K_{ij} = kernel_x (x_i, x_j), L_{ij} = kernel_y (y_i, y_j), H = 1 - m^{-1} 1 1 (Hence, K, L, H are m by m matrices). 23 | (2) unbiased estimator (HSIC_1) 24 | Song, Le, et al. "Feature selection via dependence maximization." 2012. 25 | :math: \frac{1}{m (m - 3)} \bigg[ tr (\tilde K \tilde L) + \frac{1^\top \tilde K 1 1^\top \tilde L 1}{(m-1)(m-2)} - \frac{2}{m-2} 1^\top \tilde K \tilde L 1 \bigg]. 26 | where \tilde K and \tilde L are related to K and L by the diagonal entries of \tilde K_{ij} and \tilde L_{ij} are set to zero. 27 | Parameters 28 | ---------- 29 | sigma_x : float 30 | the kernel size of the kernel function for X. 31 | sigma_y : float 32 | the kernel size of the kernel function for Y. 33 | algorithm: str ('unbiased' / 'biased') 34 | the algorithm for the finite sample estimator. 'unbiased' is used for our paper. 35 | reduction: not used (for compatibility with other losses). 
36 | """ 37 | def __init__(self, sigma_x, sigma_y=None, algorithm='unbiased', 38 | reduction=None): 39 | super(HSIC, self).__init__() 40 | 41 | if sigma_y is None: 42 | sigma_y = sigma_x 43 | 44 | self.sigma_x = sigma_x 45 | self.sigma_y = sigma_y 46 | 47 | if algorithm == 'biased': 48 | self.estimator = self.biased_estimator 49 | elif algorithm == 'unbiased': 50 | self.estimator = self.unbiased_estimator 51 | else: 52 | raise ValueError('invalid estimator: {}'.format(algorithm)) 53 | 54 | def _kernel_x(self, X): 55 | raise NotImplementedError 56 | 57 | def _kernel_y(self, Y): 58 | raise NotImplementedError 59 | 60 | def biased_estimator(self, input1, input2): 61 | """Biased estimator of Hilbert-Schmidt Independence Criterion 62 | Gretton, Arthur, et al. "Measuring statistical dependence with Hilbert-Schmidt norms." 2005. 63 | """ 64 | K = self._kernel_x(input1) 65 | L = self._kernel_y(input2) 66 | 67 | KH = K - K.mean(0, keepdim=True) 68 | LH = L - L.mean(0, keepdim=True) 69 | 70 | N = len(input1) 71 | 72 | return torch.trace(KH @ LH / (N - 1) ** 2) 73 | 74 | def unbiased_estimator(self, input1, input2): 75 | """Unbiased estimator of Hilbert-Schmidt Independence Criterion 76 | Song, Le, et al. "Feature selection via dependence maximization." 2012. 77 | """ 78 | kernel_XX = self._kernel_x(input1) 79 | kernel_YY = self._kernel_y(input2) 80 | 81 | tK = kernel_XX - torch.diag(kernel_XX) 82 | tL = kernel_YY - torch.diag(kernel_YY) 83 | N = len(input1) 84 | 85 | hsic = ( 86 | torch.trace(tK @ tL) 87 | + (torch.sum(tK) * torch.sum(tL) / (N - 1) / (N - 2)) 88 | - (2 * torch.sum(tK, 0).dot(torch.sum(tL, 0)) / (N - 2)) 89 | ) 90 | 91 | return hsic / (N * (N - 3)) 92 | 93 | def forward(self, input1, input2, **kwargs): 94 | return self.estimator(input1, input2) 95 | 96 | 97 | class RbfHSIC(HSIC): 98 | """Radial Basis Function (RBF) kernel HSIC implementation. 
99 | """ 100 | def _kernel(self, X, sigma): 101 | X = X.view(len(X), -1) 102 | Xn = X.norm(2, dim=1, keepdim=True) 103 | X = X.div(Xn) 104 | XX = X @ X.t() 105 | X_sqnorms = torch.diag(XX) 106 | X_L2 = -2 * XX + X_sqnorms.unsqueeze(1) + X_sqnorms.unsqueeze(0) 107 | X_L2 = X_L2.clamp(1e-12) 108 | sigma_avg = X_L2.mean().detach() 109 | gamma = 1/(2*sigma_avg) 110 | # gamma = 1 / (2 * sigma ** 2) 111 | 112 | kernel_XX = torch.exp(-gamma * X_L2) 113 | return kernel_XX 114 | 115 | def _kernel_x(self, X): 116 | return self._kernel(X, self.sigma_x) 117 | 118 | def _kernel_y(self, Y): 119 | return self._kernel(Y, self.sigma_y) 120 | 121 | 122 | class MinusRbfHSIC(RbfHSIC): 123 | """``Minus'' RbfHSIC for the ``max'' optimization. 124 | """ 125 | def forward(self, input1, input2, **kwargs): 126 | return -self.estimator(input1, input2) 127 | -------------------------------------------------------------------------------- /data_handler/utkface.py: -------------------------------------------------------------------------------- 1 | """ 2 | cgl_fairness 3 | Copyright (c) 2022-present NAVER Corp. 
4 | MIT license 5 | """ 6 | from os.path import join 7 | from PIL import Image 8 | from utils import list_files 9 | from natsort import natsorted 10 | import random 11 | import numpy as np 12 | from torchvision import transforms 13 | from data_handler import SSLDataset 14 | from data_handler.utils import get_mean_std 15 | 16 | 17 | class UTKFaceDataset(SSLDataset): 18 | label = 'age' 19 | sensi = 'race' 20 | fea_map = { 21 | 'age': 0, 22 | 'gender': 1, 23 | 'race': 2 24 | } 25 | num_map = { 26 | 'age': 100, # will be changed if the function '_transorm_age' is called 27 | 'gender': 2, 28 | 'race': 4 29 | } 30 | mean, std = get_mean_std('utkface') 31 | train_transform = transforms.Compose( 32 | [transforms.Resize((256, 256)), 33 | transforms.RandomCrop(224), 34 | transforms.RandomHorizontalFlip(), 35 | transforms.ToTensor(), 36 | transforms.Normalize(mean=mean, std=std)] 37 | ) 38 | test_transform = transforms.Compose( 39 | [transforms.Resize((224, 224)), 40 | transforms.ToTensor(), 41 | transforms.Normalize(mean=mean, std=std)] 42 | ) 43 | name = 'utkface' 44 | 45 | def __init__(self, **kwargs): 46 | transform = self.train_transform if kwargs['split'] == 'train' else self.test_transform 47 | 48 | SSLDataset.__init__(self, transform=transform, **kwargs) 49 | 50 | filenames = list_files(self.root, '.jpg') 51 | filenames = natsorted(filenames) 52 | self._data_preprocessing(filenames) 53 | self.num_groups = self.num_map[self.sensi] 54 | self.num_classes = self.num_map[self.label] 55 | 56 | # we want the same train / test set, so fix the seed to 1 57 | random.seed(1) 58 | random.shuffle(self.features) 59 | 60 | train, test = self._make_data(self.features, self.num_groups, self.num_classes) 61 | self.features = train if self.split == 'train' or 'group' in self.version else test 62 | 63 | self.num_data, self.idxs_per_group = self._data_count(self.features, self.num_groups, self.num_classes) 64 | 65 | # if semi-supervised learning, 66 | if self.sv_ratio < 1: 67 | # we want 
the different supervision according to the seed 68 | random.seed(self.seed) 69 | self.features, self.num_data, self.idxs_per_group = self.ssl_processing(self.features, self.num_data, self.idxs_per_group, ) 70 | if 'group' in self.version: 71 | a, b = self.num_groups, self.num_classes 72 | self.num_groups, self.num_classes = b, a 73 | 74 | self.weights = self._make_weights() 75 | 76 | def __getitem__(self, index): 77 | s, l, img_name = self.features[index] 78 | 79 | image_path = join(self.root, img_name) 80 | image = Image.open(image_path, mode='r').convert('RGB') 81 | 82 | if self.transform: 83 | image = self.transform(image) 84 | 85 | if 'group' in self.version: 86 | return image, 1, np.float32(l), np.int64(s), (index, img_name) 87 | else: 88 | return image, 1, np.float32(s), np.int64(l), (index, img_name) 89 | 90 | # five functions below preprocess UTKFace dataset 91 | def _data_preprocessing(self, filenames): 92 | filenames = self._delete_incomplete_images(filenames) 93 | filenames = self._delete_others_n_age_filter(filenames) 94 | self.features = [] 95 | for filename in filenames: 96 | s, y = self._filename2SY(filename) 97 | self.features.append([s, y, filename]) 98 | 99 | def _filename2SY(self, filename): 100 | tmp = filename.split('_') 101 | sensi = int(tmp[self.fea_map[self.sensi]]) 102 | label = int(tmp[self.fea_map[self.label]]) 103 | if self.sensi == 'age': 104 | sensi = self._transform_age(sensi) 105 | if self.label == 'age': 106 | label = self._transform_age(label) 107 | return int(sensi), int(label) 108 | 109 | def _transform_age(self, age): 110 | if age < 20: 111 | label = 0 112 | elif age < 40: 113 | label = 1 114 | else: 115 | label = 2 116 | return label 117 | 118 | def _delete_incomplete_images(self, filenames): 119 | filenames = [image for image in filenames if len(image.split('_')) == 4] 120 | return filenames 121 | 122 | def _delete_others_n_age_filter(self, filenames): 123 | filenames = [image for image in filenames 124 | if 
((image.split('_')[self.fea_map['race']] != '4'))] 125 | ages = [self._transform_age(int(image.split('_')[self.fea_map['age']])) for image in filenames] 126 | self.num_map['age'] = len(set(ages)) 127 | return filenames 128 | -------------------------------------------------------------------------------- /trainer/fairhsic.py: -------------------------------------------------------------------------------- 1 | """ 2 | cgl_fairness 3 | Copyright (c) 2022-present NAVER Corp. 4 | MIT license 5 | """ 6 | from __future__ import print_function 7 | 8 | import torch.nn.functional as F 9 | import time 10 | from utils import get_accuracy 11 | import trainer 12 | from .hsic import RbfHSIC 13 | 14 | 15 | class Trainer(trainer.GenericTrainer): 16 | def __init__(self, args, **kwargs): 17 | super().__init__(args=args, **kwargs) 18 | self.lamb = args.lamb 19 | self.sigma = args.sigma 20 | self.kernel = args.kernel 21 | self.slmode = True if args.sv < 1 else False 22 | self.version = args.version 23 | 24 | def train(self, train_loader, test_loader, epochs, writer=None): 25 | num_classes = train_loader.dataset.num_classes 26 | num_groups = train_loader.dataset.num_groups 27 | 28 | hsic = RbfHSIC(1, 1) 29 | 30 | for epoch in range(self.epochs): 31 | self._train_epoch(epoch, train_loader, self.model, hsic=hsic, num_classes=num_classes) 32 | 33 | eval_start_time = time.time() 34 | eval_loss, eval_acc, eval_deom, eval_deoa, eval_subgroup_acc = self.evaluate(self.model, test_loader, self.criterion) 35 | eval_end_time = time.time() 36 | print('[{}/{}] Method: {} ' 37 | 'Test Loss: {:.3f} Test Acc: {:.2f} Test DEOM {:.2f} [{:.2f} s]'.format 38 | (epoch + 1, epochs, self.method, 39 | eval_loss, eval_acc, eval_deom, (eval_end_time - eval_start_time))) 40 | 41 | if self.record: 42 | train_loss, train_acc, train_deom, train_deoa, train_subgroup_acc = self.evaluate(self.model, train_loader, self.criterion) 43 | writer.add_scalar('train_loss', train_loss, epoch) 44 | 
writer.add_scalar('train_acc', train_acc, epoch) 45 | writer.add_scalar('train_deom', train_deom, epoch) 46 | writer.add_scalar('train_deoa', train_deoa, epoch) 47 | writer.add_scalar('eval_loss', eval_loss, epoch) 48 | writer.add_scalar('eval_acc', eval_acc, epoch) 49 | writer.add_scalar('eval_deom', eval_deom, epoch) 50 | writer.add_scalar('eval_deoa', eval_deoa, epoch) 51 | 52 | eval_contents = {} 53 | train_contents = {} 54 | for g in range(num_groups): 55 | for l in range(num_classes): 56 | eval_contents[f'g{g},l{l}'] = eval_subgroup_acc[g, l] 57 | train_contents[f'g{g},l{l}'] = train_subgroup_acc[g, l] 58 | writer.add_scalars('eval_subgroup_acc', eval_contents, epoch) 59 | writer.add_scalars('train_subgroup_acc', train_contents, epoch) 60 | 61 | if self.scheduler is not None and 'Reduce' in type(self.scheduler).__name__: 62 | self.scheduler.step(eval_loss) 63 | else: 64 | self.scheduler.step() 65 | 66 | print('Training Finished!') 67 | 68 | def _train_epoch(self, epoch, train_loader, model, hsic=None, num_classes=3): 69 | model.train() 70 | 71 | running_acc = 0.0 72 | running_loss = 0.0 73 | batch_start_time = time.time() 74 | 75 | for i, data in enumerate(train_loader): 76 | # Get the inputs 77 | inputs, _, groups, targets, (idx, _) = data 78 | labels = targets 79 | if self.cuda: 80 | inputs = inputs.cuda(self.device) 81 | labels = labels.cuda(self.device) 82 | groups = groups.long().cuda(self.device) 83 | 84 | outputs = model(inputs, get_inter=True) 85 | 86 | stu_logits = outputs[-1] 87 | 88 | loss = self.criterion(stu_logits, labels) 89 | 90 | running_acc += get_accuracy(stu_logits, labels) 91 | 92 | f_s = outputs[-2] 93 | 94 | group_onehot = F.one_hot(groups).float() 95 | hsic_loss = 0 96 | for l in range(num_classes): 97 | mask = targets == l 98 | hsic_loss += hsic.unbiased_estimator(f_s[mask], group_onehot[mask]) 99 | 100 | loss = loss + self.lamb * hsic_loss 101 | running_loss += loss.item() 102 | self.optimizer.zero_grad() 103 | loss.backward() 104 | 
self.optimizer.step() 105 | if i % self.term == self.term - 1: # print every self.term mini-batches 106 | avg_batch_time = time.time() - batch_start_time 107 | print('[{}/{}, {:5d}] Method: {} Train Loss: {:.3f} Train Acc: {:.2f} ' 108 | '[{:.2f} s/batch]'.format 109 | (epoch + 1, self.epochs, i + 1, self.method, running_loss / self.term, running_acc / self.term, 110 | avg_batch_time / self.term)) 111 | 112 | running_loss = 0.0 113 | running_acc = 0.0 114 | batch_start_time = time.time() 115 | -------------------------------------------------------------------------------- /arguments.py: -------------------------------------------------------------------------------- 1 | """ 2 | cgl_fairness 3 | Copyright (c) 2022-present NAVER Corp. 4 | MIT license 5 | """ 6 | import argparse 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser(description='Fairness') 11 | parser.add_argument('--date', default='20200101', type=str, help='experiment date') 12 | parser.add_argument('--term', default=20, type=int, help='the period for recording train acc') 13 | parser.add_argument('--record', default=False, action='store_true', help='record using tensorboardX') 14 | parser.add_argument('--result-dir', default='./results/', 15 | help='directory to save results (default: ./results/)') 16 | parser.add_argument('--log-dir', default='./logs/', 17 | help='directory to save logs (default: ./logs/)') 18 | parser.add_argument('--data-dir', default='./data/', 19 | help='data directory (default: ./data/)') 20 | parser.add_argument('--save-dir', default='./trained_models/', 21 | help='directory to save trained models (default: ./trained_models/)') 22 | parser.add_argument('--mode', default='train', choices=['train', 'eval']) 23 | parser.add_argument('--evalset', default='test', choices=['all', 'train', 'test']) 24 | parser.add_argument('--get-inter', default=False, action='store_true', 25 | help='get penultimate features for TSNE visualization') 26 | 27 | 28 | #### base 
configuration for learning #### 29 | parser.add_argument('--seed', default=0, type=int, help='seed for randomness') 30 | # dataset 31 | parser.add_argument('--dataset', required=True, default='', choices=['adult', 'compas','utkface', 'celeba', 'utkface_fairface']) 32 | parser.add_argument('--batch-size', default=128, type=int, help='mini batch size') 33 | parser.add_argument('--img-size', default=224, type=int, help='img size for preprocessing') 34 | parser.add_argument('--num-workers', default=2, type=int, help='the number of thread used in dataloader') 35 | parser.add_argument('--labelwise', default=False, action='store_true', help='balanced sampling over groups') 36 | # only for celebA 37 | parser.add_argument('--target', default='Attractive', type=str, help='target attribute for celeba') 38 | parser.add_argument('--add-attr', default=None, help='additional group attribute for celeba') 39 | 40 | 41 | # model 42 | parser.add_argument('--model', default='', required=True, choices=['mlp', 'resnet18','resnet18_dropout']) 43 | parser.add_argument('--modelpath', default=None) 44 | parser.add_argument('--pretrained', default=False, action='store_true', help='load imagenet pretrained model') 45 | parser.add_argument('--device', default=0, type=int, help='cuda device number') 46 | parser.add_argument('--t-device', default=0, type=int, help='teacher cuda device number') 47 | # optimization 48 | parser.add_argument('--method', default='scratch', type=str, required=True, 49 | choices=['scratch', 'reweighting','mfd', 'adv', 'fairhsic']) 50 | parser.add_argument('--optimizer', default='Adam', type=str, required=False, 51 | choices=['AdamP', 'SGD', 'SGD_momentum_decay', 'Adam'], 52 | help='(default=%(default)s)') 53 | parser.add_argument('--epochs', default=50, type=int, help='number of training epochs') 54 | parser.add_argument('--lr', default=0.001, type=float, help='learning rate') 55 | parser.add_argument('--weight-decay', default=0, type=float, help='weight decay') 56 | 
57 | 58 | # for base fairness methods 59 | parser.add_argument('--lamb', default=1, type=float, help='fairness strength') 60 | # for MFD 61 | parser.add_argument('--sigma', default=1.0, type=float, help='sigma for rbf kernel') 62 | parser.add_argument('--kernel', default='rbf', type=str, choices=['rbf', 'poly'], help='kernel for mmd') 63 | parser.add_argument('--teacher-type', default=None, choices=['mlp', 'resnet18', 'resnet18_dropout']) 64 | parser.add_argument('--teacher-path', default=None, help='teacher model path') 65 | 66 | # For reweighting & adv 67 | parser.add_argument('--reweighting-target-criterion', default='eo', type=str, help='fairness criterion') 68 | parser.add_argument('--iteration', default=10, type=int, help='iteration for reweighting') 69 | parser.add_argument('--ups-iter', default=10, type=int, help='iteration for reweighting') 70 | parser.add_argument('--eta', default=0.001, type=float, help='adversary training learning rate or lr for reweighting') 71 | 72 | # For fair FG, 73 | parser.add_argument('--sv', default=1, type=float, help='the ratio of group annotation for a training set') 74 | parser.add_argument('--version', default='', type=str, help='version about how the unsupervised data is used') 75 | 76 | args = parser.parse_args() 77 | args.cuda=True 78 | if args.mode == 'train' and args.method == 'mfd': 79 | if args.teacher_type is None: 80 | raise Exception('A teacher model needs to be specified for distillation') 81 | elif args.teacher_path is None: 82 | raise Exception('A teacher model path is not specified.') 83 | 84 | return args 85 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Original code: 3 | https://github.com/sangwon79/Fair-Feature-Distillation-for-Visual-Recognition 4 | """ 5 | import torch 6 | import torch.optim as optim 7 | import numpy as np 8 | import networks 9 | import data_handler 
10 | import trainer 11 | from utils import check_log_dir, make_log_name, set_seed 12 | from adamp import AdamP 13 | from tensorboardX import SummaryWriter 14 | 15 | from arguments import get_args 16 | import time 17 | import os 18 | args = get_args() 19 | 20 | 21 | def main(): 22 | torch.backends.cudnn.enabled = True 23 | 24 | seed = args.seed 25 | set_seed(seed) 26 | 27 | np.set_printoptions(precision=4) 28 | torch.set_printoptions(precision=4) 29 | 30 | log_name = make_log_name(args) 31 | dataset = args.dataset 32 | save_dir = os.path.join(args.save_dir, args.date, dataset, args.method) 33 | result_dir = os.path.join(args.result_dir, args.date, dataset, args.method) 34 | check_log_dir(save_dir) 35 | check_log_dir(result_dir) 36 | writer = None 37 | if args.record: 38 | log_dir = os.path.join(args.log_dir, args.date, dataset, args.method) 39 | check_log_dir(log_dir) 40 | writer = SummaryWriter(log_dir + '/' + log_name) 41 | 42 | ########################## get dataloader ################################ 43 | if args.dataset == 'adult': 44 | args.img_size = 97 45 | elif args.dataset == 'compas': 46 | args.img_size = 400 47 | else: 48 | args.img_size = 224 49 | tmp = data_handler.DataloaderFactory.get_dataloader(args.dataset, 50 | batch_size=args.batch_size, seed=args.seed, 51 | num_workers=args.num_workers, 52 | target_attr=args.target, 53 | add_attr=args.add_attr, 54 | labelwise=args.labelwise, 55 | sv_ratio=args.sv, 56 | version=args.version, 57 | args=args 58 | ) 59 | num_classes, num_groups, train_loader, test_loader = tmp 60 | 61 | ########################## get model ################################## 62 | if args.dataset == 'adult': 63 | args.img_size = 97 64 | elif args.dataset == 'compas': 65 | args.img_size = 400 66 | elif 'cifar' in args.dataset: 67 | args.img_size = 32 68 | 69 | model = networks.ModelFactory.get_model(args.model, num_classes, args.img_size, 70 | pretrained=args.pretrained, num_groups=num_groups) 71 | 72 | 
model.cuda('cuda:{}'.format(args.device)) 73 | 74 | if args.modelpath is not None: 75 | model.load_state_dict(torch.load(args.modelpath)) 76 | 77 | teacher = None 78 | if (args.method == 'mfd' or args.teacher_path is not None) and args.mode != 'eval': 79 | teacher = networks.ModelFactory.get_model(args.teacher_type, train_loader.dataset.num_classes, args.img_size) 80 | teacher.load_state_dict(torch.load(args.teacher_path, map_location=torch.device('cuda:{}'.format(args.t_device)))) 81 | teacher.cuda('cuda:{}'.format(args.t_device)) 82 | 83 | ########################## get trainer ################################## 84 | if 'Adam' in args.optimizer: 85 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 86 | elif 'AdamP' in args.optimizer: 87 | optimizer = AdamP(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 88 | elif 'SGD' in args.optimizer: 89 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9) 90 | 91 | trainer_ = trainer.TrainerFactory.get_trainer(args.method, model=model, args=args, 92 | optimizer=optimizer, teacher=teacher) 93 | 94 | ####################### start training or evaluating #################### 95 | 96 | if args.mode == 'train': 97 | start_t = time.time() 98 | trainer_.train(train_loader, test_loader, args.epochs, writer=writer) 99 | end_t = time.time() 100 | train_t = int((end_t - start_t)/60) # to minutes 101 | print('Training Time : {} hours {} minutes'.format(int(train_t/60), (train_t % 60))) 102 | trainer_.save_model(save_dir, log_name) 103 | 104 | else: 105 | print('Evaluation ----------------') 106 | model_to_load = args.modelpath 107 | trainer_.model.load_state_dict(torch.load(model_to_load)) 108 | print('Trained model loaded successfully') 109 | 110 | if args.evalset == 'all': 111 | trainer_.compute_confusion_matix('train', train_loader.dataset.num_classes, train_loader, result_dir, log_name) 112 | trainer_.compute_confusion_matix('test', 
test_loader.dataset.num_classes, test_loader, result_dir, log_name) 113 | 114 | elif args.evalset == 'train': 115 | trainer_.compute_confusion_matix('train', train_loader.dataset.num_classes, train_loader, result_dir, log_name) 116 | else: 117 | trainer_.compute_confusion_matix('test', test_loader.dataset.num_classes, test_loader, result_dir, log_name) 118 | if writer is not None: 119 | writer.close() 120 | print('Done!') 121 | 122 | 123 | if __name__ == '__main__': 124 | main() 125 | -------------------------------------------------------------------------------- /data_handler/AIF360/adult_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Original code: 3 | https://github.com/Trusted-AI/AIF360 4 | """ 5 | import os 6 | 7 | import pandas as pd 8 | 9 | from data_handler.AIF360.standard_dataset import StandardDataset 10 | 11 | 12 | default_mappings = { 13 | 'label_maps': [{1.0: '>50K', 0.0: '<=50K'}], 14 | 'protected_attribute_maps': [{1.0: 'White', 0.0: 'Non-white'}, 15 | {1.0: 'Male', 0.0: 'Female'}] 16 | } 17 | 18 | 19 | class AdultDataset(StandardDataset): 20 | """Adult Census Income Dataset. 21 | See :file:`aif360/data/raw/adult/README.md`. 22 | """ 23 | 24 | def __init__(self, root_dir='./data/adult', 25 | label_name='income-per-year', 26 | favorable_classes=['>50K', '>50K.'], 27 | protected_attribute_names=['race', 'sex'], 28 | privileged_classes=[['White'], ['Male']], 29 | instance_weights_name=None, 30 | categorical_features=['workclass', 'education', 31 | 'marital-status', 'occupation', 'relationship', 32 | 'native-country'], 33 | features_to_keep=[], features_to_drop=['fnlwgt'], 34 | na_values=['?'], custom_preprocessing=None, 35 | metadata=default_mappings): 36 | """See :obj:`StandardDataset` for a description of the arguments. 
37 | Examples: 38 | The following will instantiate a dataset which uses the `fnlwgt` 39 | feature: 40 | >>> from aif360.datasets import AdultDataset 41 | >>> ad = AdultDataset(instance_weights_name='fnlwgt', 42 | ... features_to_drop=[]) 43 | WARNING:root:Missing Data: 3620 rows removed from dataset. 44 | >>> not np.all(ad.instance_weights == 1.) 45 | True 46 | To instantiate a dataset which utilizes only numerical features and 47 | a single protected attribute, run: 48 | >>> single_protected = ['sex'] 49 | >>> single_privileged = [['Male']] 50 | >>> ad = AdultDataset(protected_attribute_names=single_protected, 51 | ... privileged_classes=single_privileged, 52 | ... categorical_features=[], 53 | ... features_to_keep=['age', 'education-num']) 54 | >>> print(ad.feature_names) 55 | ['education-num', 'age', 'sex'] 56 | >>> print(ad.label_names) 57 | ['income-per-year'] 58 | Note: the `protected_attribute_names` and `label_name` are kept even 59 | if they are not explicitly given in `features_to_keep`. 60 | In some cases, it may be useful to keep track of a mapping from 61 | `float -> str` for protected attributes and/or labels. If our use 62 | case differs from the default, we can modify the mapping stored in 63 | `metadata`: 64 | >>> label_map = {1.0: '>50K', 0.0: '<=50K'} 65 | >>> protected_attribute_maps = [{1.0: 'Male', 0.0: 'Female'}] 66 | >>> ad = AdultDataset(protected_attribute_names=['sex'], 67 | ... categorical_features=['workclass', 'education', 'marital-status', 68 | ... 'occupation', 'relationship', 'native-country', 'race'], 69 | ... privileged_classes=[['Male']], metadata={'label_map': label_map, 70 | ... 'protected_attribute_maps': protected_attribute_maps}) 71 | Note that we are now adding `race` as a `categorical_features`. 72 | Now this information will stay attached to the dataset and can be 73 | used for more descriptive visualizations. 
74 | """ 75 | 76 | train_path = os.path.join(root_dir, 'adult.data') 77 | test_path = os.path.join(root_dir, 'adult.test') 78 | 79 | # as given by adult.names 80 | column_names = ['age', 'workclass', 'fnlwgt', 'education', 81 | 'education-num', 'marital-status', 'occupation', 'relationship', 82 | 'race', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week', 83 | 'native-country', 'income-per-year'] 84 | try: 85 | train = pd.read_csv(train_path, header=None, names=column_names, 86 | skipinitialspace=True, na_values=na_values) 87 | test = pd.read_csv(test_path, header=0, names=column_names, 88 | skipinitialspace=True, na_values=na_values) 89 | except IOError as err: 90 | print("IOError: {}".format(err)) 91 | print("To use this class, please download the following files:") 92 | print("\n\thttps://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data") 93 | print("\thttps://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test") 94 | print("\thttps://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.names") 95 | print("\nand place them, as-is, in the folder:") 96 | print("\n\t{}\n".format('./data/adult')) 97 | import sys 98 | sys.exit(1) 99 | 100 | df = pd.concat([test, train], ignore_index=True) 101 | 102 | super(AdultDataset, self).__init__(df=df, label_name=label_name, 103 | favorable_classes=favorable_classes, 104 | protected_attribute_names=protected_attribute_names, 105 | privileged_classes=privileged_classes, 106 | instance_weights_name=instance_weights_name, 107 | categorical_features=categorical_features, 108 | features_to_keep=features_to_keep, 109 | features_to_drop=features_to_drop, na_values=na_values, 110 | custom_preprocessing=custom_preprocessing, metadata=metadata) 111 | -------------------------------------------------------------------------------- /data_handler/utkface_fairface.py: -------------------------------------------------------------------------------- 1 | """ 2 | cgl_fairness 3 | Copyright (c) 
2022-present NAVER Corp. 4 | MIT license 5 | """ 6 | from os.path import join 7 | from PIL import Image 8 | from utils import list_files 9 | from natsort import natsorted 10 | import random 11 | import numpy as np 12 | from torchvision import transforms 13 | from data_handler import SSLDataset 14 | from data_handler.utils import get_mean_std 15 | from data_handler.fairface import Fairface 16 | 17 | 18 | class UTKFaceFairface_Dataset(SSLDataset): 19 | label = 'age' 20 | sensi = 'race' 21 | fea_map = { 22 | 'age': 0, 23 | 'gender': 1, 24 | 'race': 2 25 | } 26 | num_map = { 27 | 'age': 100, # will be changed if the function '_transorm_age' is called 28 | 'gender': 2, 29 | 'race': 4 30 | } 31 | mean, std = get_mean_std('utkface') 32 | train_transform = transforms.Compose( 33 | [transforms.Resize((256, 256)), 34 | transforms.RandomCrop(224), 35 | transforms.RandomHorizontalFlip(), 36 | transforms.ToTensor(), 37 | transforms.Normalize(mean=mean, std=std)] 38 | ) 39 | test_transform = transforms.Compose( 40 | [transforms.Resize((224, 224)), 41 | transforms.ToTensor(), 42 | transforms.Normalize(mean=mean, std=std)] 43 | ) 44 | name = 'utkface_fairface' 45 | 46 | def __init__(self, **kwargs): 47 | transform = self.train_transform if kwargs['split'] == 'train' else self.test_transform 48 | 49 | SSLDataset.__init__(self, transform=transform, **kwargs) 50 | 51 | filenames = list_files(self.root, '.jpg') 52 | filenames = natsorted(filenames) 53 | self._data_preprocessing(filenames) 54 | self.num_groups = self.num_map[self.sensi] 55 | self.num_classes = self.num_map[self.label] 56 | 57 | # we want the same train / test set, so fix the seed to 1 58 | random.seed(1) 59 | random.shuffle(self.features) 60 | 61 | train, test = self._make_data(self.features, self.num_groups, self.num_classes) 62 | # if we train a group classifier, we don't use the original test dataset 63 | # and we split the original training dataset into train and validation dataset 64 | self.features = train if 
self.split == 'train' or 'group' in self.version else test 65 | self.num_data, self.idxs_per_group = self._data_count(self.features, self.num_groups, self.num_classes) 66 | 67 | # add fairface 68 | idxs_dict = None 69 | if self.split == 'train' or 'group' in self.version: 70 | fairface = Fairface(target_attr='age', root='./data/fairface', split='train', seed=self.seed) 71 | fairface_feature = fairface.feature_mat 72 | # shuffle 73 | random.shuffle(fairface_feature) 74 | 75 | # count only fairface 76 | print('count the number of data in fairface') 77 | num_data_fairface, idxs_per_group_fairface = self._data_count(fairface_feature, self.num_groups, self.num_classes) 78 | print(np.sum(num_data_fairface)) 79 | 80 | # make idx_dict 81 | idxs_dict = {} 82 | idxs_dict['annotated'] = {} 83 | idxs_dict['non-annotated'] = {} 84 | total_num_utkface = len(self.features) 85 | for g in range(self.num_groups): 86 | for l in range(self.num_classes): 87 | tmp = np.array(idxs_per_group_fairface[g, l]) + total_num_utkface 88 | idxs_dict['annotated'][(g, l)] = self.idxs_per_group[(g, l)] 89 | idxs_dict['non-annotated'][(g, l)] = list(tmp) 90 | 91 | self.features.extend(fairface_feature) 92 | 93 | # count again 94 | self.num_data, self.idxs_per_group = self._data_count(self.features, self.num_groups, self.num_classes) 95 | 96 | # if semi-supervised learning, 97 | if self.sv_ratio < 1: 98 | # we want the different supervision according to the seed 99 | random.seed(self.seed) 100 | self.features, self.num_data, self.idxs_per_group = self.ssl_processing(self.features, self.num_data, self.idxs_per_group, idxs_dict=idxs_dict) 101 | if 'group' in self.version: 102 | a, b = self.num_groups, self.num_classes 103 | self.num_groups, self.num_classes = b, a 104 | 105 | def __getitem__(self, index): 106 | s, l, img_name = self.features[index] 107 | 108 | # image_path = join(self.root, img_name) 109 | image = Image.open(img_name, mode='r').convert('RGB') 110 | 111 | if self.transform: 112 | image = 
self.transform(image) 113 | 114 | if 'group' in self.version: 115 | return image, 1, np.float32(l), np.int64(s), (index, img_name) 116 | else: 117 | return image, 1, np.float32(s), np.int64(l), (index, img_name) 118 | 119 | # five functions below preprocess UTKFace dataset 120 | def _data_preprocessing(self, filenames): 121 | filenames = self._delete_incomplete_images(filenames) 122 | filenames = self._delete_others_n_age_filter(filenames) 123 | self.features = [] 124 | for filename in filenames: 125 | s, y = self._filename2SY(filename) 126 | filename = join(self.root, filename) 127 | self.features.append([s, y, filename]) 128 | 129 | def _filename2SY(self, filename): 130 | tmp = filename.split('_') 131 | sensi = int(tmp[self.fea_map[self.sensi]]) 132 | label = int(tmp[self.fea_map[self.label]]) 133 | if self.sensi == 'age': 134 | sensi = self._transform_age(sensi) 135 | if self.label == 'age': 136 | label = self._transform_age(label) 137 | return int(sensi), int(label) 138 | 139 | def _transform_age(self, age): 140 | if age < 20: 141 | label = 0 142 | elif age < 40: 143 | label = 1 144 | else: 145 | label = 2 146 | return label 147 | 148 | def _delete_incomplete_images(self, filenames): 149 | filenames = [image for image in filenames if len(image.split('_')) == 4] 150 | return filenames 151 | 152 | def _delete_others_n_age_filter(self, filenames): 153 | filenames = [image for image in filenames 154 | if ((image.split('_')[self.fea_map['race']] != '4'))] 155 | ages = [self._transform_age(int(image.split('_')[self.fea_map['age']])) for image in filenames] 156 | self.num_map['age'] = len(set(ages)) 157 | return filenames 158 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning Fair Classifiers with Partially Annotated Group Labels (CVPR 2022) 2 | 3 | Official Pytorch implementation of Learning Fair Classifiers with Partially 
Annotated Group Labels | [Paper](https://arxiv.org/abs/2111.14581) 4 | 5 | [Sangwon Jung](https://scholar.google.com/citations?user=WdC_a5IAAAAJ&hl=ko)1 [Sanghyuk Chun](https://sanghyukchun.github.io/home/)2 [Taesup Moon](https://scholar.google.com/citations?user=lQlioBoAAAAJ&hl=ko)1, 3 6 | 7 | 1Department of ECE/ASRI, Seoul National University
8 | 2[NAVER AI LAB](https://naver-career.gitbook.io/en/teams/clova-cic)
9 | 3Interdisciplinary Program in Artificial Intelligence, Seoul National University 10 | 11 | Recently, fairness-aware learning has become increasingly crucial, but most of those methods operate by assuming the availability of fully annotated demographic group labels. We emphasize that such an assumption is unrealistic for real-world applications since group label annotations are expensive and can conflict with privacy issues. In this paper, we consider a more practical scenario, dubbed Algorithmic Group Fairness with the Partially annotated Group labels (Fair-PG). We observe that the existing methods to achieve group fairness perform even worse than the vanilla training, which simply uses full data only with target labels, under Fair-PG. To address this problem, we propose a simple Confidence-based Group Label assignment (CGL) strategy that is readily applicable to any fairness-aware learning method. CGL utilizes an auxiliary group classifier to assign pseudo group labels, where random labels are assigned to low-confidence samples. We first theoretically show that our method design is better than the vanilla pseudo-labeling strategy in terms of fairness criteria. Then, we empirically show on several benchmark datasets that by combining CGL and the state-of-the-art fairness-aware in-processing methods, the target accuracies and the fairness metrics can be jointly improved compared to the baselines. Furthermore, we convincingly show that CGL enables us to naturally augment the given group-labeled dataset with external target label-only datasets so that both accuracy and fairness can be improved. 12 | 13 | ## Updates 14 | 15 | - 11 Apr, 2022: Initial upload. 16 | 17 | ## Dataset preparation 18 | 1.
Download the datasets 19 | - UTKFace : 20 | [link](https://susanqq.github.io/UTKFace/) (We used Aligned&Cropped Faces from the site) 21 | - CelebA : 22 | [link](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) (We used Aligned&Cropped Faces from the site) 23 | - Tabular datasets (ProPublica COMPAS & Adult) : 24 | [link](https://github.com/Trusted-AI/AIF360) 25 | 2. Place the downloaded datasets in the './data' directory 26 | 27 | ## How to train 28 | CGL first trains a group classifier and finds a proper confidence threshold, using a train set and a validation set split from the group-labeled training data. Then, the group-unlabeled training samples are annotated with pseudo group labels based on CGL's assignment rules. Finally, a fair model can be trained with a base fair-training method such as LBC, FairHSIC or MFD. 29 | 30 | ### 1. Train a group classifier 31 | ``` 32 | $ python main_groupclf.py --model {model_name} --method scratch \ 33 | --dataset {dataset_name} \ 34 | --version groupclf_val \ 35 | --sv {group_label_ratio} 36 | ``` 37 | 38 | In the above command, 'version' can be either 'groupclf' or 'groupclf_val', which indicates whether or not the group-labeled data is split into a train/validation set. For CGL, choose 'groupclf_val'; for the pseudo-label baseline, choose 'groupclf'. 39 | 40 | ### 2. Find a threshold and save the predictions of the group classifier 41 | ``` 42 | $ python main_groupclf.py --model {model_name} --method scratch \ 43 | --dataset {dataset_name} \ 44 | --mode eval \ 45 | --version groupclf_val \ 46 | --sv {group_label_ratio} 47 | ``` 48 | 49 | ### 3. Train a fair model using any fair-training method 50 | 51 | - MFD. For feature distillation, MFD needs a teacher model that is trained from scratch.
52 | ``` 53 | # train a scratch model 54 | $ python main.py --model {model_name} --method scratch --dataset {dataset_name} 55 | $ python main.py --model {model_name} --method mfd \ 56 | --dataset {dataset_name} \ 57 | --labelwise \ 58 | --version cgl \ 59 | --sv {group_label_ratio} \ 60 | --lamb 100 \ 61 | --teacher-type mlp \ 62 | --teacher-path {teacher_model_path} 63 | ``` 64 | - FairHSIC 65 | ``` 66 | $ python main.py --model {model_name} --method fairhsic \ 67 | --dataset {dataset_name} \ 68 | --labelwise \ 69 | --version cgl \ 70 | --sv {group_label_ratio} \ 71 | --lamb 100 72 | ``` 73 | - LBC 74 | ``` 75 | $ python main.py --model {model_name} --method reweighting \ 76 | --dataset {dataset_name} \ 77 | --iteration {num_iterations} \ 78 | --version cgl \ 79 | --sv {group_label_ratio} 80 | ``` 81 | ## How to cite 82 | 83 | ``` 84 | @inproceedings{jung2021cgl, 85 | title={Learning Fair Classifiers with Partially Annotated Group Labels}, 86 | author={Sangwon Jung and Sanghyuk Chun and Taesup Moon}, 87 | year={2022}, 88 | booktitle={Conference on Computer Vision and Pattern Recognition (CVPR)}, 89 | } 90 | ``` 91 | ## License 92 | 93 | ``` 94 | Copyright (c) 2022-present NAVER Corp. 95 | 96 | Permission is hereby granted, free of charge, to any person obtaining a copy 97 | of this software and associated documentation files (the "Software"), to deal 98 | in the Software without restriction, including without limitation the rights 99 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 100 | copies of the Software, and to permit persons to whom the Software is 101 | furnished to do so, subject to the following conditions: 102 | 103 | The above copyright notice and this permission notice shall be included in 104 | all copies or substantial portions of the Software. 105 | 106 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 107 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 108 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 109 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 110 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 111 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 112 | THE SOFTWARE. 113 | ``` 114 | -------------------------------------------------------------------------------- /data_handler/celeba.py: -------------------------------------------------------------------------------- 1 | """ 2 | cgl_fairness 3 | Copyright (c) 2022-present NAVER Corp. 4 | MIT license 5 | """ 6 | import torch 7 | import os 8 | from os.path import join 9 | import PIL 10 | import pandas 11 | import random 12 | import zipfile 13 | from functools import partial 14 | from torchvision.datasets.utils import download_file_from_google_drive, check_integrity, verify_str_arg 15 | from torchvision import transforms 16 | import data_handler 17 | from data_handler.utils import get_mean_std 18 | 19 | 20 | class CelebA(data_handler.SSLDataset): 21 | """ 22 | There currently does not appear to be a easy way to extract 7z in python (without introducing additional 23 | dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available 24 | right now. 
25 | """ 26 | file_list = [ 27 | # File ID MD5 Hash Filename 28 | ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"), 29 | # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"), 30 | # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"), 31 | ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"), 32 | ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"), 33 | ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"), 34 | ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"), 35 | # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"), 36 | ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"), 37 | ] 38 | mean, std = get_mean_std('celeba') 39 | train_transform = transforms.Compose( 40 | [transforms.Resize((256, 256)), 41 | transforms.RandomCrop(224), 42 | transforms.RandomHorizontalFlip(), 43 | transforms.ToTensor(), 44 | transforms.Normalize(mean=mean, std=std)] 45 | ) 46 | test_transform = transforms.Compose( 47 | [transforms.Resize((224, 224)), 48 | transforms.ToTensor(), 49 | transforms.Normalize(mean=mean, std=std)] 50 | ) 51 | name = 'celeba' 52 | 53 | def __init__(self, target_attr='Attractive', add_attr=None, download=False, **kwargs): 54 | transform = self.train_transform if kwargs['split'] == 'train' else self.test_transform 55 | super(CelebA, self).__init__(transform=transform, **kwargs) 56 | 57 | if download: 58 | self._download() 59 | 60 | if not self._check_integrity(): 61 | raise RuntimeError('Dataset not found or corrupted.' 
+ 62 | ' You can use download=True to download it') 63 | # SELECT the features 64 | self.sensitive_attr = 'Male' 65 | self.add_attr = add_attr 66 | self.target_attr = target_attr 67 | split_map = { 68 | "train": 0, 69 | "valid": 1, 70 | "test": 2, 71 | "all": None, 72 | } 73 | split = split_map[verify_str_arg(self.split.lower(), "split", 74 | ("train", "valid", "test", "all"))] 75 | if 'group' in self.version: 76 | split = 0 77 | fn = partial(join, self.root) 78 | splits = pandas.read_csv(fn("list_eval_partition.txt"), delim_whitespace=True, header=None, index_col=0) 79 | attr = pandas.read_csv(fn("list_attr_celeba.txt"), delim_whitespace=True, header=1) 80 | print('Im add attr : ', target_attr, add_attr) 81 | 82 | mask = slice(None) if split is None else (splits[1] == split) 83 | 84 | self.filename = splits[mask].index.values 85 | self.attr = torch.as_tensor(attr[mask].values) 86 | self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1} 87 | self.attr_names = list(attr.columns) 88 | 89 | self.target_idx = self.attr_names.index(self.target_attr) 90 | self.sensi_idx = self.attr_names.index(self.sensitive_attr) 91 | self.add_idx = self.attr_names.index(self.add_attr) if self.add_attr is not None else -1 92 | self.feature_idx = [i for i in range(len(self.attr_names)) if i != self.target_idx and i != self.sensi_idx] 93 | if self.add_attr is not None: 94 | self.feature_idx.remove(self.add_idx) 95 | self.num_classes = 2 96 | self.num_groups = 2 if self.add_attr is None else 4 97 | 98 | if self.add_attr is None: 99 | self.features = [[int(s), int(l), filename] for s, l, filename in 100 | zip(self.attr[:, self.sensi_idx], self.attr[:, self.target_idx], self.filename)] 101 | else: 102 | self.features = [[int(s1)*2 + int(s2), int(l), filename] for s1, s2, l, filename in 103 | zip(self.attr[:, self.sensi_idx], self.attr[:, self.add_idx], self.attr[:, self.target_idx], self.filename)] 104 | self.num_data, self.idxs_per_group = self._data_count(self.features, 
self.num_groups, self.num_classes) 105 | 106 | if self.split == "test" and 'group' not in self.version: 107 | self.features = self._balance_test_data(self.num_data, self.num_groups, self.num_classes) 108 | self.num_data, self.idxs_per_group = self._data_count(self.features, self.num_groups, self.num_classes) 109 | 110 | # if semi-supervised learning, 111 | if self.sv_ratio < 1: 112 | # we want the different supervision according to the seed 113 | random.seed(self.seed) 114 | self.features, self.num_data, self.idxs_per_group = self.ssl_processing(self.features, self.num_data, self.idxs_per_group) 115 | if 'group' in self.version: 116 | a, b = self.num_groups, self.num_classes 117 | self.num_groups, self.num_classes = b, a 118 | 119 | def __getitem__(self, index): 120 | sensitive, target, img_name = self.features[index] 121 | image = PIL.Image.open(os.path.join(self.root, "img_align_celeba", img_name)) 122 | 123 | if self.transform is not None: 124 | image = self.transform(image) 125 | 126 | if 'group' in self.version: 127 | return image, 0, target, sensitive, (index, img_name) 128 | return image, 0, sensitive, target, (index, img_name) 129 | 130 | def _check_integrity(self): 131 | for (_, md5, filename) in self.file_list: 132 | fpath = os.path.join(self.root, filename) 133 | _, ext = os.path.splitext(filename) 134 | # Allow original archive to be deleted (zip and 7z) 135 | # Only need the extracted images 136 | if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5): 137 | return False 138 | 139 | # Should check a hash of the images 140 | return os.path.isdir(os.path.join(self.root, "img_align_celeba")) 141 | 142 | def _download(self): 143 | if self._check_integrity(): 144 | print('Files already downloaded and verified') 145 | return 146 | 147 | for (file_id, md5, filename) in self.file_list: 148 | download_file_from_google_drive(file_id, self.root, filename, md5) 149 | 150 | with zipfile.ZipFile(os.path.join(self.root, "img_align_celeba.zip"), "r") as f: 
151 | f.extractall(self.root) 152 | -------------------------------------------------------------------------------- /trainer/mfd.py: -------------------------------------------------------------------------------- 1 | """ 2 | Original code: 3 | https://github.com/sangwon79/Fair-Feature-Distillation-for-Visual-Recognition 4 | """ 5 | from __future__ import print_function 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.nn as nn 10 | import time 11 | from utils import get_accuracy 12 | import trainer 13 | 14 | 15 | class Trainer(trainer.GenericTrainer): 16 | def __init__(self, args, **kwargs): 17 | super().__init__(args=args, **kwargs) 18 | self.lamb = args.lamb 19 | self.sigma = args.sigma 20 | self.kernel = args.kernel 21 | self.slmode = True if args.sv < 1 else False 22 | self.version = args.version 23 | 24 | def train(self, train_loader, test_loader, epochs, writer=None): 25 | num_classes = train_loader.dataset.num_classes 26 | num_groups = train_loader.dataset.num_groups 27 | 28 | distiller = MMDLoss(w_m=self.lamb, sigma=self.sigma, 29 | num_classes=num_classes, num_groups=num_groups, kernel=self.kernel) 30 | for epoch in range(self.epochs): 31 | self._train_epoch(epoch, train_loader, self.model, self.teacher, distiller=distiller) 32 | eval_start_time = time.time() 33 | eval_loss, eval_acc, eval_deom, eval_deoa, eval_subgroup_acc = self.evaluate(self.model, test_loader, self.criterion) 34 | eval_end_time = time.time() 35 | print('[{}/{}] Method: {} ' 36 | 'Test Loss: {:.3f} Test Acc: {:.2f} Test DEOM {:.2f} [{:.2f} s]'.format 37 | (epoch + 1, epochs, self.method, 38 | eval_loss, eval_acc, eval_deom, (eval_end_time - eval_start_time))) 39 | if self.record: 40 | train_loss, train_acc, train_deom, train_deoa, train_subgroup_acc = self.evaluate(self.model, train_loader, self.criterion) 41 | writer.add_scalar('train_loss', train_loss, epoch) 42 | writer.add_scalar('train_acc', train_acc, epoch) 43 | writer.add_scalar('train_deom', 
train_deom, epoch) 44 | writer.add_scalar('train_deoa', train_deoa, epoch) 45 | writer.add_scalar('eval_loss', eval_loss, epoch) 46 | writer.add_scalar('eval_acc', eval_acc, epoch) 47 | writer.add_scalar('eval_deom', eval_deom, epoch) 48 | writer.add_scalar('eval_deoa', eval_deoa, epoch) 49 | 50 | eval_contents = {} 51 | train_contents = {} 52 | for g in range(num_groups): 53 | for l in range(num_classes): 54 | eval_contents[f'g{g},l{l}'] = eval_subgroup_acc[g, l] 55 | train_contents[f'g{g},l{l}'] = train_subgroup_acc[g, l] 56 | writer.add_scalars('eval_subgroup_acc', eval_contents, epoch) 57 | writer.add_scalars('train_subgroup_acc', train_contents, epoch) 58 | 59 | if self.scheduler is not None and 'Reduce' in type(self.scheduler).__name__: 60 | self.scheduler.step(eval_loss) 61 | else: 62 | self.scheduler.step() 63 | 64 | print('Training Finished!') 65 | 66 | def _train_epoch(self, epoch, train_loader, model, teacher, distiller=None): 67 | model.train() 68 | teacher.eval() 69 | 70 | running_acc = 0.0 71 | running_loss = 0.0 72 | batch_start_time = time.time() 73 | # tmp = np.zeros((4,3)) 74 | for i, data in enumerate(train_loader): 75 | # Get the inputs 76 | inputs, _, groups, targets, (idx, _) = data 77 | labels = targets 78 | if self.cuda: 79 | inputs = inputs.cuda(self.device) 80 | labels = labels.cuda(self.device) 81 | groups = groups.long().cuda(self.device) 82 | 83 | t_inputs = inputs.to(self.t_device) 84 | 85 | outputs = model(inputs, get_inter=True) 86 | stu_logits = outputs[-1] 87 | 88 | t_outputs = teacher(t_inputs, get_inter=True) 89 | 90 | loss = self.criterion(stu_logits, labels) 91 | 92 | running_acc += get_accuracy(stu_logits, labels) 93 | f_s = outputs[-2] 94 | 95 | f_t = t_outputs[-2].detach() 96 | 97 | mmd_loss = distiller.forward(f_s, f_t, groups=groups, labels=labels) 98 | 99 | loss = loss + mmd_loss 100 | running_loss += loss.item() 101 | self.optimizer.zero_grad() 102 | loss.backward() 103 | self.optimizer.step() 104 | if i % self.term == 
self.term - 1: # print every self.term mini-batches 105 | avg_batch_time = time.time() - batch_start_time 106 | print('[{}/{}, {:5d}] Method: {} Train Loss: {:.3f} Train Acc: {:.2f} ' 107 | '[{:.2f} s/batch]'.format 108 | (epoch + 1, self.epochs, i + 1, self.method, running_loss / self.term, running_acc / self.term, 109 | avg_batch_time / self.term)) 110 | 111 | running_loss = 0.0 112 | running_acc = 0.0 113 | batch_start_time = time.time() 114 | 115 | 116 | class MMDLoss(nn.Module): 117 | def __init__(self, w_m, sigma, num_groups, num_classes, kernel): 118 | super(MMDLoss, self).__init__() 119 | self.w_m = w_m 120 | self.sigma = sigma 121 | self.num_groups = num_groups 122 | self.num_classes = num_classes 123 | self.kernel = kernel 124 | 125 | def forward(self, f_s, f_t, groups, labels): 126 | if self.kernel == 'poly': 127 | student = F.normalize(f_s.view(f_s.shape[0], -1), dim=1) 128 | teacher = F.normalize(f_t.view(f_t.shape[0], -1), dim=1).detach() 129 | else: 130 | student = f_s.view(f_s.shape[0], -1) 131 | teacher = f_t.view(f_t.shape[0], -1) 132 | 133 | mmd_loss = 0 134 | with torch.no_grad(): 135 | _, sigma_avg = self.pdist(teacher, student, sigma_base=self.sigma, kernel=self.kernel) 136 | 137 | for c in range(self.num_classes): 138 | if len(teacher[labels == c]) == 0: 139 | continue 140 | for g in range(self.num_groups): 141 | if len(student[(labels == c) * (groups == g)]) == 0: 142 | continue 143 | K_TS, _ = self.pdist(teacher[labels == c], student[(labels == c) * (groups == g)], 144 | sigma_base=self.sigma, sigma_avg=sigma_avg, kernel=self.kernel) 145 | K_SS, _ = self.pdist(student[(labels == c) * (groups == g)], student[(labels == c) * (groups == g)], 146 | sigma_base=self.sigma, sigma_avg=sigma_avg, kernel=self.kernel) 147 | 148 | K_TT, _ = self.pdist(teacher[labels == c], teacher[labels == c], sigma_base=self.sigma, 149 | sigma_avg=sigma_avg, kernel=self.kernel) 150 | 151 | mmd_loss += K_TT.mean() + K_SS.mean() - 2 * K_TS.mean() 152 | loss = (1/2) * 
self.w_m * mmd_loss 153 | return loss 154 | 155 | @staticmethod 156 | def pdist(e1, e2, eps=1e-12, kernel='rbf', sigma_base=1.0, sigma_avg=None): 157 | if len(e1) == 0 or len(e2) == 0: 158 | res = torch.zeros(1) 159 | else: 160 | if kernel == 'rbf': 161 | e1_square = e1.pow(2).sum(dim=1) 162 | e2_square = e2.pow(2).sum(dim=1) 163 | prod = e1 @ e2.t() 164 | res = (e1_square.unsqueeze(1) + e2_square.unsqueeze(0) - 2 * prod).clamp(min=eps) 165 | res = res.clone() 166 | sigma_avg = res.mean().detach() if sigma_avg is None else sigma_avg 167 | res = torch.exp(-res / (2*(sigma_base)*sigma_avg)) 168 | elif kernel == 'poly': 169 | res = torch.matmul(e1, e2.t()).pow(2) 170 | 171 | return res, sigma_avg 172 | -------------------------------------------------------------------------------- /data_handler/AIF360/standard_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Original code: 3 | https://github.com/Trusted-AI/AIF360 4 | """ 5 | from logging import warning 6 | 7 | import numpy as np 8 | import pandas as pd 9 | 10 | from data_handler.AIF360.binary_label_dataset import BinaryLabelDataset 11 | 12 | 13 | class StandardDataset(BinaryLabelDataset): 14 | """Base class for every :obj:`BinaryLabelDataset` provided out of the box by 15 | aif360. 16 | It is not strictly necessary to inherit this class when adding custom 17 | datasets but it may be useful. 18 | This class is very loosely based on code from 19 | https://github.com/algofairness/fairness-comparison. 20 | """ 21 | 22 | def __init__(self, df, label_name, favorable_classes, 23 | protected_attribute_names, privileged_classes, 24 | instance_weights_name='', scores_name='', 25 | categorical_features=[], features_to_keep=[], 26 | features_to_drop=[], na_values=[], custom_preprocessing=None, 27 | metadata=None): 28 | """ 29 | Subclasses of StandardDataset should perform the following before 30 | calling `super().__init__`: 31 | 1. Load the dataframe from a raw file. 
32 | Then, this class will go through a standard preprocessing routine which: 33 | 2. (optional) Performs some dataset-specific preprocessing (e.g. 34 | renaming columns/values, handling missing data). 35 | 3. Drops unrequested columns (see `features_to_keep` and 36 | `features_to_drop` for details). 37 | 4. Drops rows with NA values. 38 | 5. Creates a one-hot encoding of the categorical variables. 39 | 6. Maps protected attributes to binary privileged/unprivileged 40 | values (1/0). 41 | 7. Maps labels to binary favorable/unfavorable labels (1/0). 42 | Args: 43 | df (pandas.DataFrame): DataFrame on which to perform standard 44 | processing. 45 | label_name: Name of the label column in `df`. 46 | favorable_classes (list or function): Label values which are 47 | considered favorable or a boolean function which returns `True` 48 | if favorable. All others are unfavorable. Label values are 49 | mapped to 1 (favorable) and 0 (unfavorable) if they are not 50 | already binary and numerical. 51 | protected_attribute_names (list): List of names corresponding to 52 | protected attribute columns in `df`. 53 | privileged_classes (list(list or function)): Each element is 54 | a list of values which are considered privileged or a boolean 55 | function which return `True` if privileged for the corresponding 56 | column in `protected_attribute_names`. All others are 57 | unprivileged. Values are mapped to 1 (privileged) and 0 58 | (unprivileged) if they are not already numerical. 59 | instance_weights_name (optional): Name of the instance weights 60 | column in `df`. 61 | categorical_features (optional, list): List of column names in the 62 | DataFrame which are to be expanded into one-hot vectors. 63 | features_to_keep (optional, list): Column names to keep. All others 64 | are dropped except those present in `protected_attribute_names`, 65 | `categorical_features`, `label_name` or `instance_weights_name`. 66 | Defaults to all columns if not provided. 
67 | features_to_drop (optional, list): Column names to drop. *Note: this 68 | overrides* `features_to_keep`. 69 | na_values (optional): Additional strings to recognize as NA. See 70 | :func:`pandas.read_csv` for details. 71 | custom_preprocessing (function): A function object which 72 | acts on and returns a DataFrame (f: DataFrame -> DataFrame). If 73 | `None`, no extra preprocessing is applied. 74 | metadata (optional): Additional metadata to append. 75 | """ 76 | # 2. Perform dataset-specific preprocessing 77 | if custom_preprocessing: 78 | df = custom_preprocessing(df) 79 | 80 | # 3. Drop unrequested columns 81 | features_to_keep = features_to_keep or df.columns.tolist() 82 | keep = (set(features_to_keep) | set(protected_attribute_names) 83 | | set(categorical_features) | set([label_name])) 84 | if instance_weights_name: 85 | keep |= set([instance_weights_name]) 86 | df = df[sorted(keep - set(features_to_drop), key=df.columns.get_loc)] 87 | categorical_features = sorted(set(categorical_features) - set(features_to_drop), key=df.columns.get_loc) 88 | 89 | # 4. Remove any rows that have missing data. 90 | dropped = df.dropna() 91 | count = df.shape[0] - dropped.shape[0] 92 | if count > 0: 93 | warning("Missing Data: {} rows removed from {}.".format(count, 94 | type(self).__name__)) 95 | df = dropped 96 | 97 | # 5. Create a one-hot encoding of the categorical variables. 98 | df = pd.get_dummies(df, columns=categorical_features, prefix_sep='=') 99 | 100 | # 6. Map protected attributes to privileged/unprivileged 101 | privileged_protected_attributes = [] 102 | unprivileged_protected_attributes = [] 103 | for attr, vals in zip(protected_attribute_names, privileged_classes): 104 | privileged_values = [1.] 105 | unprivileged_values = [0.] 
106 | if callable(vals): 107 | df[attr] = df[attr].apply(vals) 108 | elif np.issubdtype(df[attr].dtype, np.number): 109 | # this attribute is numeric; no remapping needed 110 | privileged_values = vals 111 | unprivileged_values = list(set(df[attr]).difference(vals)) 112 | else: 113 | # find all instances which match any of the attribute values 114 | priv = np.logical_or.reduce(np.equal.outer(vals, df[attr].to_numpy())) 115 | df.loc[priv, attr] = privileged_values[0] 116 | df.loc[~priv, attr] = unprivileged_values[0] 117 | 118 | privileged_protected_attributes.append( 119 | np.array(privileged_values, dtype=np.float64)) 120 | unprivileged_protected_attributes.append( 121 | np.array(unprivileged_values, dtype=np.float64)) 122 | 123 | # 7. Make labels binary 124 | favorable_label = 1. 125 | unfavorable_label = 0. 126 | if callable(favorable_classes): 127 | df[label_name] = df[label_name].apply(favorable_classes) 128 | elif np.issubdtype(df[label_name], np.number) and len(set(df[label_name])) == 2: 129 | # labels are already binary; don't change them 130 | favorable_label = favorable_classes[0] 131 | unfavorable_label = set(df[label_name]).difference(favorable_classes).pop() 132 | else: 133 | # find all instances which match any of the favorable classes 134 | pos = np.logical_or.reduce(np.equal.outer(favorable_classes, 135 | df[label_name].to_numpy())) 136 | df.loc[pos, label_name] = favorable_label 137 | df.loc[~pos, label_name] = unfavorable_label 138 | 139 | super(StandardDataset, self).__init__(df=df, label_names=[label_name], 140 | protected_attribute_names=protected_attribute_names, 141 | privileged_protected_attributes=privileged_protected_attributes, 142 | unprivileged_protected_attributes=unprivileged_protected_attributes, 143 | instance_weights_name=instance_weights_name, 144 | scores_names=[scores_name] if scores_name else [], 145 | favorable_label=favorable_label, 146 | unfavorable_label=unfavorable_label, metadata=metadata) 147 | 
# /trainer/trainer_factory.py
"""
cgl_fairness
Copyright (c) 2022-present NAVER Corp.
MIT license
"""
import torch
import numpy as np
import os
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau, MultiStepLR, CosineAnnealingLR
from sklearn.metrics import confusion_matrix
from utils import make_log_name


class TrainerFactory:
    """Resolve a training-method name to its ``Trainer`` implementation."""

    # Method name -> module that defines a ``Trainer`` class.
    # NOTE(review): no ``trainer/groupdro.py`` appears in the repository tree
    # listed with this file, so 'groupdro' / 'groupdro_ori' would fail at
    # import time -- confirm that module exists.
    _METHOD_MODULES = {
        'scratch': 'trainer.vanilla_train',
        'mfd': 'trainer.mfd',
        'fairhsic': 'trainer.fairhsic',
        'adv': 'trainer.adv_debiasing',
        'reweighting': 'trainer.reweighting',
        'groupdro': 'trainer.groupdro',
        'groupdro_ori': 'trainer.groupdro',
    }

    def __init__(self):
        pass

    @staticmethod
    def get_trainer(method, **kwargs):
        """Instantiate the ``Trainer`` registered for ``method``.

        Args:
            method: one of the keys of ``_METHOD_MODULES`` (e.g. 'scratch', 'mfd').
            **kwargs: forwarded unchanged to the selected ``Trainer.__init__``.

        Raises:
            Exception: if ``method`` is not a registered name.
        """
        import importlib

        module_name = TrainerFactory._METHOD_MODULES.get(method)
        if module_name is None:
            raise Exception('Not allowed method')
        # Imported lazily so only the selected trainer's module is loaded.
        return importlib.import_module(module_name).Trainer(**kwargs)


class GenericTrainer:
    """Base class for trainers; inherit from it to implement a new training routine.

    Stores the run configuration read from ``args``, builds the learning-rate
    scheduler matching ``args.optimizer`` and provides shared evaluation,
    checkpointing and confusion-matrix utilities.
    """

    def __init__(self, model, args, optimizer, teacher=None):
        """
        Args:
            model: network being trained.
            args: parsed option namespace. Read here: get_inter, record, cuda,
                device, t_device, term, lr, epochs, method, model, optimizer,
                log_dir, save_dir, date, dataset (plus whatever
                ``make_log_name`` reads).
            optimizer: optimizer for ``model``. If ``None`` no scheduler is
                built and ``self.scheduler`` stays ``None``.
            teacher: optional teacher network (stored as ``self.teacher``).
        """
        self.get_inter = args.get_inter

        self.record = args.record
        self.cuda = args.cuda
        self.device = args.device
        self.t_device = args.t_device
        self.term = args.term
        self.lr = args.lr
        self.epochs = args.epochs
        self.method = args.method
        self.model_name = args.model
        self.model = model
        self.teacher = teacher
        self.optimizer = optimizer
        self.optim_type = args.optimizer
        self.criterion = torch.nn.CrossEntropyLoss()
        self.scheduler = None

        self.log_name = make_log_name(args)
        self.log_dir = os.path.join(args.log_dir, args.date, args.dataset, args.method)
        self.save_dir = os.path.join(args.save_dir, args.date, args.dataset, args.method)

        # Without an optimizer there is nothing to schedule (previously this
        # case fell through to MultiStepLR(None, ...) and crashed).
        if self.optimizer is not None:
            if self.optim_type == 'Adam':
                self.scheduler = ReduceLROnPlateau(self.optimizer)
            elif self.optim_type == 'AdamP':
                # T_max of 66 for the 200-epoch setting; otherwise anneal over
                # the whole run. (Previously t_max was left unbound for any
                # epochs >= 100 other than 200.)
                t_max = 66 if self.epochs == 200 else self.epochs
                self.scheduler = CosineAnnealingLR(self.optimizer, t_max)
            else:
                self.scheduler = MultiStepLR(self.optimizer, [60, 120, 180], gamma=0.1)

    def evaluate(self, model, loader, criterion, device=None, max_batches=100):
        """Evaluate ``model`` and compute accuracy per (group, class) cell.

        Args:
            model: network to evaluate. It is switched to eval mode for the
                pass and back to train mode before returning.
            loader: a loader whose ``dataset`` exposes ``num_groups`` and
                ``num_classes``. If its class name contains 'Custom', the
                result of ``loader.generate()`` is iterated instead. Each batch
                unpacks as ``(inputs, _, groups, labels, _)``.
            criterion: called as ``criterion(outputs, labels)``. Its value is
                weighted by the batch size when averaging.
            device: CUDA device for the batch tensors; defaults to ``self.device``.
            max_batches: stop after this many batches (default 100, the
                previous hard-coded limit); ``None`` evaluates the whole loader.

        Returns:
            ``(loss, acc, max_eopp, avg_eopp, subgroup_acc)`` where:
            - ``loss`` and ``acc`` are sample-weighted averages (0-d tensors);
            - ``max_eopp`` and ``avg_eopp`` are the max and mean, over classes,
              of the gap between the best and worst group accuracy (floats);
            - ``subgroup_acc`` is the ``(num_groups, num_classes)`` accuracy
              matrix. Cells with no samples are NaN (0 / 0).
        """
        model.eval()
        num_groups = loader.dataset.num_groups
        num_classes = loader.dataset.num_classes
        device = self.device if device is None else device

        eval_acc = 0
        eval_loss = 0
        eval_eopp_list = torch.zeros(num_groups, num_classes)
        eval_data_count = torch.zeros(num_groups, num_classes)
        # Only move the accumulators to the GPU when CUDA is enabled, so
        # CPU-only evaluation works.
        if self.cuda:
            eval_eopp_list = eval_eopp_list.cuda(device)
            eval_data_count = eval_data_count.cuda(device)

        if 'Custom' in type(loader).__name__:
            loader = loader.generate()
        with torch.no_grad():
            for j, eval_data in enumerate(loader):
                if max_batches is not None and j >= max_batches:
                    break
                inputs, _, groups, labels, _ = eval_data
                if self.cuda:
                    inputs = inputs.cuda(device)
                    labels = labels.cuda(device)
                    groups = groups.cuda(device)

                outputs = model(inputs)

                loss = criterion(outputs, labels)
                eval_loss += loss.item() * len(labels)
                preds = torch.argmax(outputs, 1)
                # No .squeeze(): a batch of size 1 would become a 0-d tensor
                # that cannot be boolean-masked below.
                acc = (preds == labels).float()
                eval_acc += acc.sum()

                for g in range(num_groups):
                    group_mask = groups == g
                    for l in range(num_classes):
                        cell_mask = group_mask & (labels == l)
                        eval_eopp_list[g, l] += acc[cell_mask].sum()
                        eval_data_count[g, l] += cell_mask.sum()

        total = eval_data_count.sum()
        eval_loss = eval_loss / total
        eval_acc = eval_acc / total
        eval_eopp_list = eval_eopp_list / eval_data_count
        group_gap = torch.max(eval_eopp_list, dim=0)[0] - torch.min(eval_eopp_list, dim=0)[0]
        eval_avg_eopp = torch.mean(group_gap).item()
        eval_max_eopp = torch.max(group_gap).item()
        model.train()
        return eval_loss, eval_acc, eval_max_eopp, eval_avg_eopp, eval_eopp_list

    def save_model(self, save_dir, log_name="", model=None):
        """Save the ``state_dict`` of ``model`` (default ``self.model``).

        The file is written to ``<save_dir>/<log_name>.pt``; ``save_dir`` must
        already exist.
        """
        model_to_save = self.model if model is None else model
        model_savepath = os.path.join(save_dir, log_name + '.pt')
        torch.save(model_to_save.state_dict(), model_savepath)

        print('Model saved to %s' % model_savepath)

    def compute_confusion_matix(self, dataset='test', num_classes=2,
                                dataloader=None, log_dir="", log_name=""):
        """Compute per-group confusion matrices for ``self.model`` and save them.

        Writes two MATLAB files into ``log_dir`` (``savemat`` appends ``.mat``):
        - ``<log_name>_<dataset>_confu``: one ``num_classes x num_classes``
          matrix per group id (keys are the group ids as strings);
        - ``<log_name>_<dataset>_pred``: the concatenated groups, targets and
          raw model outputs, plus intermediate features if ``self.get_inter``.

        The model is left in eval mode.

        Returns:
            The dict-like mapping from group id (str) to confusion matrix.
        """
        from scipy.io import savemat
        from collections import defaultdict

        self.model.eval()
        confu_mat = defaultdict(lambda: np.zeros((num_classes, num_classes)))
        print('# of {} data : {}'.format(dataset, len(dataloader.dataset)))

        predict_mat = {}
        output_set = torch.tensor([])
        group_set = torch.tensor([], dtype=torch.long)
        target_set = torch.tensor([], dtype=torch.long)
        intermediate_feature_set = torch.tensor([])
        class_ids = list(range(num_classes))

        with torch.no_grad():
            for data in dataloader:
                inputs, _, groups, targets, _ = data
                labels = targets
                groups = groups.long()

                if self.cuda:
                    inputs = inputs.cuda(self.device)
                    labels = labels.cuda(self.device)

                outputs = self.model(inputs)
                if self.get_inter:
                    # Second forward pass just to fetch the penultimate features.
                    intermediate_feature = self.model.forward(inputs, get_inter=True)[-2]
                    intermediate_feature_set = torch.cat(
                        (intermediate_feature_set, intermediate_feature.cpu()))

                group_set = torch.cat((group_set, groups))
                target_set = torch.cat((target_set, targets))
                output_set = torch.cat((output_set, outputs.cpu()))

                pred = torch.argmax(outputs, 1)
                for group_id in list(torch.unique(groups).numpy()):
                    mask = groups == group_id
                    if len(labels[mask]) != 0:
                        confu_mat[str(group_id)] += confusion_matrix(
                            labels[mask].cpu().numpy(), pred[mask].cpu().numpy(),
                            labels=class_ids)

        predict_mat['group_set'] = group_set.numpy()
        predict_mat['target_set'] = target_set.numpy()
        predict_mat['output_set'] = output_set.numpy()
        if self.get_inter:
            predict_mat['intermediate_feature_set'] = intermediate_feature_set.numpy()

        savepath = os.path.join(log_dir, log_name + '_{}_confu'.format(dataset))
        print('savepath', savepath)
        savemat(savepath, confu_mat, appendmat=True)

        savepath_pred = os.path.join(log_dir, log_name + '_{}_pred'.format(dataset))
        savemat(savepath_pred, predict_mat, appendmat=True)

        print('Computed confusion matrix for {} dataset successfully!'.format(dataset))
        return confu_mat
-------------------------------------------------------------------------------- 1 | """ 2 | cgl_fairness 3 | Copyright (c) 2022-present NAVER Corp. 4 | MIT license 5 | """ 6 | from __future__ import print_function 7 | import torch 8 | import torch.nn as nn 9 | import time 10 | from utils import get_accuracy 11 | import trainer 12 | from torch.utils.data import DataLoader 13 | 14 | 15 | class Trainer(trainer.GenericTrainer): 16 | def __init__(self, args, **kwargs): 17 | super().__init__(args=args, **kwargs) 18 | 19 | self.eta = args.eta 20 | self.iteration = args.iteration 21 | self.batch_size = args.batch_size 22 | self.num_workers = args.num_workers 23 | self.reweighting_target_criterion = args.reweighting_target_criterion 24 | self.slmode = True if args.sv < 1 else False 25 | self.version = args.version 26 | 27 | def train(self, train_loader, test_loader, epochs, dummy_loader=None, writer=None): 28 | model = self.model 29 | model.train() 30 | num_groups = train_loader.dataset.num_groups 31 | num_classes = train_loader.dataset.num_classes 32 | 33 | extended_multipliers = torch.zeros((num_groups, num_classes)) 34 | if self.cuda: 35 | extended_multipliers = extended_multipliers.cuda() 36 | _, Y_train, S_train = self.get_statistics(train_loader.dataset, batch_size=self.batch_size, 37 | num_workers=self.num_workers) 38 | 39 | eta_learning_rate = self.eta 40 | print('eta_learning_rate : ', eta_learning_rate) 41 | n_iters = self.iteration 42 | print('n_iters : ', n_iters) 43 | 44 | violations = 0 45 | 46 | for iter_ in range(n_iters): 47 | start_t = time.time() 48 | weight_set = self.debias_weights(Y_train, S_train, extended_multipliers, num_groups, num_classes) 49 | 50 | for epoch in range(epochs): 51 | self._train_epoch(epoch, train_loader, model, weight_set) 52 | eval_start_time = time.time() 53 | eval_loss, eval_acc, eval_deom, eval_deoa, eval_subgroup_acc = self.evaluate(self.model, test_loader, self.criterion) 54 | eval_end_time = time.time() 55 | print('[{}/{}] 
Method: {} ' 56 | 'Test Loss: {:.3f} Test Acc: {:.2f} Test DEOM {:.2f} [{:.2f} s]'.format 57 | (epoch + 1, epochs, self.method, 58 | eval_loss, eval_acc, eval_deom, (eval_end_time - eval_start_time))) 59 | if self.record: 60 | train_loss, train_acc, train_deom, train_deoa, train_subgroup_acc = self.evaluate(self.model, train_loader, self.criterion) 61 | writer.add_scalar('train_loss', train_loss, epoch) 62 | writer.add_scalar('train_acc', train_acc, epoch) 63 | writer.add_scalar('train_deom', train_deom, epoch) 64 | writer.add_scalar('train_deoa', train_deoa, epoch) 65 | writer.add_scalar('eval_loss', eval_loss, epoch) 66 | writer.add_scalar('eval_acc', eval_acc, epoch) 67 | writer.add_scalar('eval_deopm', eval_deom, epoch) 68 | writer.add_scalar('eval_deopa', eval_deoa, epoch) 69 | 70 | eval_contents = {} 71 | train_contents = {} 72 | for g in range(num_groups): 73 | for l in range(num_classes): 74 | eval_contents[f'g{g},l{l}'] = eval_subgroup_acc[g, l] 75 | train_contents[f'g{g},l{l}'] = train_subgroup_acc[g, l] 76 | writer.add_scalars('eval_subgroup_acc', eval_contents, epoch) 77 | writer.add_scalars('train_subgroup_acc', train_contents, epoch) 78 | 79 | if self.scheduler is not None and 'Reduce' in type(self.scheduler).__name__: 80 | self.scheduler.step(eval_loss) 81 | else: 82 | self.scheduler.step() 83 | 84 | end_t = time.time() 85 | train_t = int((end_t - start_t) / 60) 86 | print('Training Time : {} hours {} minutes / iter : {}/{}'.format(int(train_t / 60), (train_t % 60), 87 | (iter_ + 1), n_iters)) 88 | 89 | Y_pred_train, Y_train, S_train = self.get_statistics(train_loader.dataset, batch_size=self.batch_size, 90 | num_workers=self.num_workers, model=model) 91 | 92 | if self.reweighting_target_criterion == 'dp': 93 | acc, violations = self.get_error_and_violations_DP(Y_pred_train, Y_train, S_train, num_groups, num_classes) 94 | elif self.reweighting_target_criterion == 'eo': 95 | acc, violations = self.get_error_and_violations_EO(Y_pred_train, Y_train, 
S_train, num_groups, num_classes) 96 | extended_multipliers -= eta_learning_rate * violations 97 | 98 | def _train_epoch(self, epoch, train_loader, model, weight_set): 99 | model.train() 100 | 101 | running_acc = 0.0 102 | running_loss = 0.0 103 | avg_batch_time = 0.0 104 | 105 | for i, data in enumerate(train_loader): 106 | batch_start_time = time.time() 107 | # Get the inputs 108 | inputs, _, groups, targets, indexes = data 109 | labels = targets 110 | labels = labels.long() 111 | 112 | weights = weight_set[indexes[0]] 113 | 114 | if self.cuda: 115 | inputs = inputs.cuda() 116 | labels = labels.cuda() 117 | weights = weights.cuda() 118 | groups = groups.cuda() 119 | 120 | outputs = model(inputs) 121 | 122 | loss = torch.mean(weights * nn.CrossEntropyLoss(reduction='none')(outputs, labels)) 123 | running_loss += loss.item() 124 | running_acc += get_accuracy(outputs, labels) 125 | 126 | # zero the parameter gradients + backward + optimize 127 | self.optimizer.zero_grad() 128 | loss.backward() 129 | self.optimizer.step() 130 | 131 | batch_end_time = time.time() 132 | avg_batch_time += batch_end_time - batch_start_time 133 | 134 | if i % self.term == self.term - 1: # print every self.term mini-batches 135 | print('[{}/{}, {:5d}] Method: {} Train Loss: {:.3f} Train Acc: {:.2f} ' 136 | '[{:.2f} s/batch]'.format 137 | (epoch + 1, self.epochs, i + 1, self.method, running_loss / self.term, running_acc / self.term, 138 | avg_batch_time / self.term)) 139 | 140 | running_loss = 0.0 141 | running_acc = 0.0 142 | avg_batch_time = 0.0 143 | 144 | last_batch_idx = i 145 | return last_batch_idx 146 | 147 | def get_statistics(self, dataset, batch_size=128, num_workers=2, model=None): 148 | 149 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, 150 | num_workers=num_workers, pin_memory=True, drop_last=False) 151 | 152 | if model is not None: 153 | model.eval() 154 | 155 | Y_pred_set = [] 156 | Y_set = [] 157 | S_set = [] 158 | total = 0 159 | for i, data in 
enumerate(dataloader): 160 | inputs, _, sen_attrs, targets, indexes = data 161 | Y_set.append(targets) # sen_attrs = -1 means no supervision for sensitive group 162 | S_set.append(sen_attrs) 163 | 164 | if self.cuda: 165 | inputs = inputs.cuda() 166 | groups = sen_attrs.cuda() 167 | if model is not None: 168 | outputs = model(inputs) 169 | Y_pred_set.append(torch.argmax(outputs, dim=1)) 170 | total += inputs.shape[0] 171 | 172 | Y_set = torch.cat(Y_set) 173 | S_set = torch.cat(S_set) 174 | Y_pred_set = torch.cat(Y_pred_set) if len(Y_pred_set) != 0 else torch.zeros(0) 175 | return Y_pred_set.long(), Y_set.long().cuda(), S_set.long().cuda() 176 | 177 | # Vectorized version for DP & multi-class 178 | def get_error_and_violations_DP(self, y_pred, label, sen_attrs, num_groups, num_classes): 179 | acc = torch.mean(y_pred == label) 180 | total_num = len(y_pred) 181 | violations = torch.zeros((num_groups, num_classes)) 182 | if self.cuda: 183 | violations = violations.cuda() 184 | for g in range(num_groups): 185 | for c in range(num_classes): 186 | pivot = len(torch.where(y_pred == c)[0]) / total_num 187 | group_idxs = torch.where(sen_attrs == g)[0] 188 | group_pred_idxs = torch.where(torch.logical_and(sen_attrs == g, y_pred == c))[0] 189 | violations[g, c] = len(group_pred_idxs)/len(group_idxs) - pivot 190 | return acc, violations 191 | 192 | # Vectorized version for EO & multi-class 193 | def get_error_and_violations_EO(self, y_pred, label, sen_attrs, num_groups, num_classes): 194 | acc = torch.mean((y_pred == label).float()) 195 | violations = torch.zeros((num_groups, num_classes)) 196 | if self.cuda: 197 | violations = violations.cuda() 198 | for g in range(num_groups): 199 | for c in range(num_classes): 200 | class_idxs = torch.where(label == c)[0] 201 | pred_class_idxs = torch.where(torch.logical_and(y_pred == c, label == c))[0] 202 | pivot = len(pred_class_idxs)/len(class_idxs) 203 | group_class_idxs = torch.where(torch.logical_and(sen_attrs == g, label == c))[0] 
204 | group_pred_class_idxs = torch.where(torch.logical_and(torch.logical_and(sen_attrs == g, y_pred == c), label == c))[0] 205 | violations[g, c] = len(group_pred_class_idxs)/len(group_class_idxs) - pivot 206 | print('violations', violations) 207 | return acc, violations 208 | 209 | # update weight 210 | def debias_weights(self, label, sen_attrs, extended_multipliers, num_groups, num_classes): 211 | weights = torch.zeros(len(label)) 212 | w_matrix = torch.sigmoid(extended_multipliers) # g by c 213 | weights = w_matrix[sen_attrs, label] 214 | if self.slmode and self.version == 2: 215 | weights[sen_attrs == -1] = 0.5 216 | return weights 217 | 218 | def criterion(self, model, outputs, labels): 219 | return nn.CrossEntropyLoss()(outputs, labels) 220 | -------------------------------------------------------------------------------- /trainer/adv_debiasing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Original code: 3 | https://github.com/sangwon79/Fair-Feature-Distillation-for-Visual-Recognition 4 | """ 5 | from __future__ import print_function 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | 12 | import time 13 | 14 | from utils import get_accuracy 15 | from networks.mlp import MLP 16 | import trainer 17 | from torch.optim.lr_scheduler import ReduceLROnPlateau 18 | 19 | 20 | class Trainer(trainer.GenericTrainer): 21 | def __init__(self, args, **kwargs): 22 | super().__init__(args=args, **kwargs) 23 | self.adv_lambda = args.lamb 24 | self.adv_lr = args.eta 25 | self.target_criterion = 'eo' 26 | 27 | # self.film = args.film 28 | # self.no_film_residual = args.no_film_residual 29 | 30 | # self.no_groupmask = args.no_groupmask 31 | # self.mask_step = args.mask_step 32 | # param_m = [param for name, param in self.model.named_parameters() if 'mask' in name] \ 33 | # if not args.no_groupmask and self.decouple else None 34 | # self.mask_optimizer = 
optim.Adam(param_m, lr=args.mask_lr, weight_decay=args.weight_decay) \ 35 | # if not args.no_groupmask and self.decouple else None 36 | # self.scheduler_mask = ReduceLROnPlateau(self.mask_optimizer, patience=5) \ 37 | # if not args.no_groupmask and self.decouple else None 38 | 39 | def train(self, train_loader, test_loader, epochs): 40 | model = self.model 41 | num_groups = train_loader.dataset.num_groups 42 | num_classes = train_loader.dataset.num_classes 43 | self._init_adversary(num_groups, num_classes, train_loader) 44 | self.scheduler = ReduceLROnPlateau(self.optimizer, patience=5) 45 | 46 | for epoch in range(epochs): 47 | self._train_epoch(epoch, train_loader, model) 48 | 49 | eval_start_time = time.time() 50 | eval_loss, eval_acc, eval_adv_loss, eval_adv_acc, eval_deopp = \ 51 | self.evaluate(model, self.sa_clf, test_loader, self.criterion, self.adv_criterion) 52 | eval_end_time = time.time() 53 | print('[{}/{}] Method: {} ' 54 | 'Test Loss: {:.3f} Test Acc: {:.2f} Test Adv loss: {:.3f} Test Adv Acc: {:.2f} Test DEopp {:.2f} [{:.2f} s]'.format 55 | (epoch + 1, epochs, self.method, 56 | eval_loss, eval_acc, eval_adv_loss, eval_adv_acc, eval_deopp, (eval_end_time - eval_start_time))) 57 | 58 | if self.scheduler != None: 59 | self.scheduler.step(eval_loss) 60 | self.adv_scheduler.step(eval_adv_loss) 61 | 62 | print('Training Finished!') 63 | return model 64 | 65 | def _train_epoch(self, epoch, train_loader, model): 66 | num_classes = train_loader.dataset.num_classes 67 | num_groups = train_loader.dataset.num_groups 68 | 69 | model.train() 70 | 71 | running_acc = 0.0 72 | running_loss = 0.0 73 | running_adv_loss = 0.0 74 | batch_start_time = time.time() 75 | 76 | for i, data in enumerate(train_loader): 77 | # Get the inputs 78 | inputs, _, groups, targets, _ = data 79 | labels = targets 80 | # groups = groups.long() 81 | 82 | if self.cuda: 83 | inputs = inputs.cuda() 84 | labels = labels.cuda() 85 | groups = groups.cuda() 86 | 87 | labels = labels.long() 88 | 
groups = groups.long() 89 | 90 | outputs = model(inputs) 91 | 92 | 93 | inputs_for_adv = outputs 94 | logits = outputs 95 | 96 | adv_inputs = None 97 | if self.target_criterion =='eo': 98 | repeat_times = num_classes 99 | input_loc = F.one_hot(labels.long(), num_classes).repeat_interleave(repeat_times, dim=1) 100 | adv_inputs = inputs_for_adv.repeat(1, repeat_times) * input_loc 101 | adv_inputs = torch.cat((inputs_for_adv, adv_inputs), dim=1) 102 | 103 | elif self.target_criterion == 'dp': 104 | adv_inputs = inputs_for_adv 105 | 106 | adv_preds = self.sa_clf(adv_inputs) 107 | # adv_loss = self.adv_criterion(self.sa_clf, adv_preds, groups) 108 | adv_loss = self.adv_criterion(adv_preds, groups) 109 | 110 | self.optimizer.zero_grad() 111 | self.adv_optimizer.zero_grad() 112 | 113 | #adv_loss.backward()#retain_graph=True) 114 | #adv_loss.backward(retain_graph=True) 115 | #for n, p in model.named_parameters(): 116 | # unit_adv_grad = p.grad / (p.grad.norm() + torch.finfo(torch.float32).tiny) 117 | # p.grad += torch.sum(p.grad * unit_adv_grad) * unit_adv_grad # gradients are already reversed 118 | 119 | loss = self.criterion(logits, labels) 120 | 121 | (loss+adv_loss).backward() 122 | 123 | self.optimizer.step() 124 | self.adv_optimizer.step() 125 | 126 | running_loss += loss.item() 127 | running_adv_loss += adv_loss.item() 128 | # binary = True if num_classes ==2 else False 129 | running_acc += get_accuracy(outputs, labels) 130 | 131 | # self.optimizer.step() 132 | # self.adv_optimizer.step() 133 | 134 | if i % self.term == self.term - 1: # print every self.term mini-batches 135 | avg_batch_time = time.time() - batch_start_time 136 | print_statement = '[{}/{}, {:5d}] Method: {} Train Loss: {:.3f} Adv Loss: {:.3f} Train Acc: {:.2f} [{:.2f} s/batch]'\ 137 | .format(epoch + 1, self.epochs, i + 1, self.method, running_loss / self.term, 138 | running_adv_loss / self.term,running_acc / self.term, avg_batch_time / self.term) 139 | print(print_statement) 140 | 141 | 
running_loss = 0.0 142 | running_acc = 0.0 143 | running_adv_loss = 0.0 144 | batch_start_time = time.time() 145 | 146 | def evaluate(self, model, adversary, loader, criterion, adv_criterion): 147 | model.eval() 148 | num_groups = loader.dataset.num_groups 149 | num_classes = loader.dataset.num_classes 150 | eval_acc = 0 151 | eval_adv_acc = 0 152 | eval_loss = 0 153 | eval_adv_loss = 0 154 | eval_eopp_list = torch.zeros(num_groups, num_classes).cuda() 155 | eval_data_count = torch.zeros(num_groups, num_classes).cuda() 156 | 157 | if 'Custom' in type(loader).__name__: 158 | loader = loader.generate() 159 | with torch.no_grad(): 160 | for j, eval_data in enumerate(loader): 161 | # Get the inputs 162 | inputs, _, groups, classes, _ = eval_data 163 | # 164 | labels = classes 165 | groups = groups.long() 166 | if self.cuda: 167 | inputs = inputs.cuda() 168 | labels = labels.cuda() 169 | groups = groups.cuda() 170 | 171 | labels = labels.long() 172 | 173 | get_inter = False 174 | outputs = model(inputs, get_inter=get_inter) 175 | 176 | inputs_for_adv = outputs[-2] if get_inter else outputs 177 | logits = outputs[-1] if get_inter else outputs 178 | 179 | adv_inputs = None 180 | if self.target_criterion == 'eo': 181 | repeat_times = num_classes 182 | input_loc = F.one_hot(labels.long(), num_classes).repeat_interleave(repeat_times, dim=1) 183 | adv_inputs = inputs_for_adv.repeat(1, repeat_times) * input_loc 184 | adv_inputs = torch.cat((inputs_for_adv, adv_inputs), dim=1) 185 | 186 | elif self.target_criterion == 'dp': 187 | adv_inputs = inputs_for_adv 188 | 189 | loss = criterion(logits, labels) 190 | eval_loss += loss.item() * len(labels) 191 | binary = True if num_classes == 2 else False 192 | acc = get_accuracy(outputs, labels, reduction='none') 193 | eval_acc += acc.sum() 194 | 195 | for g in range(num_groups): 196 | for l in range(num_classes): 197 | eval_eopp_list[g, l] += acc[(groups == g) * (labels == l)].sum() 198 | eval_data_count[g, l] += torch.sum((groups == 
g) * (labels == l)) 199 | 200 | adv_preds = adversary(adv_inputs) 201 | # groups = groups.float() if num_groups == 2 else groups.long() 202 | groups = groups.long() 203 | adv_loss = adv_criterion(adv_preds, groups) 204 | eval_adv_loss += adv_loss.item() * len(labels) 205 | # binary = True if num_groups == 2 else False 206 | eval_adv_acc += get_accuracy(adv_preds, groups) 207 | 208 | eval_loss = eval_loss / eval_data_count.sum() 209 | eval_acc = eval_acc / eval_data_count.sum() 210 | eval_adv_loss = eval_adv_loss / eval_data_count.sum() 211 | eval_adv_acc = eval_adv_acc / eval_data_count.sum() 212 | eval_eopp_list = eval_eopp_list / eval_data_count 213 | eval_max_eopp = torch.max(eval_eopp_list, dim=0)[0] - torch.min(eval_eopp_list, dim=0)[0] 214 | eval_max_eopp = torch.max(eval_max_eopp).item() 215 | model.train() 216 | return eval_loss, eval_acc, eval_adv_loss, eval_adv_acc, eval_max_eopp 217 | 218 | def _init_adversary(self, num_groups, num_classes, dataloader): 219 | self.model.eval() 220 | if self.target_criterion == 'eo': 221 | feature_size = num_classes * (num_classes + 1) 222 | elif self.target_criterion == 'dp': 223 | feature_size = num_classes 224 | 225 | 226 | sa_clf = MLP(feature_size=feature_size, hidden_dim=32, num_classes=num_groups, 227 | num_layer=2, adv=True, adv_lambda=self.adv_lambda) 228 | if self.cuda: 229 | sa_clf.cuda() 230 | sa_clf.train() 231 | self.sa_clf = sa_clf 232 | self.adv_optimizer = optim.Adam(sa_clf.parameters(), lr=self.adv_lr) 233 | self.adv_scheduler = ReduceLROnPlateau(self.adv_optimizer, patience=5) 234 | self.adv_criterion = self.criterion 235 | 236 | def criterion(self, model, outputs, labels): 237 | return nn.CrossEntropyLoss()(outputs, labels) 238 | -------------------------------------------------------------------------------- /data_handler/ssl_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | cgl_fairness 3 | Copyright (c) 2022-present NAVER Corp. 
4 | MIT license 5 | """ 6 | import random 7 | import numpy as np 8 | import torch 9 | from data_handler import GenericDataset 10 | import os 11 | import pickle 12 | 13 | 14 | class SSLDataset(GenericDataset): 15 | def __init__(self, sv_ratio=1.0, version='', ups_iter=0, **kwargs): 16 | super(SSLDataset, self).__init__(**kwargs) 17 | self.sv_ratio = sv_ratio 18 | self.version = version 19 | self.add_attr = None 20 | self.ups_iter = ups_iter 21 | 22 | def ssl_processing(self, features, num_data, idxs_per_group, idxs_dict=None): 23 | if self.sv_ratio >= 1: 24 | raise ValueError 25 | 26 | if self.split == 'test' and 'group' not in self.version: 27 | return features, num_data, idxs_per_group 28 | 29 | print('preprocessing for ssl...') 30 | num_groups, num_classes = num_data.shape 31 | 32 | folder_name = 'annotated_idxs' 33 | idx_filename = '{}_{}'.format(self.seed, self.sv_ratio) 34 | if self.name == 'celeba': 35 | if self.target_attr != 'Attractive': 36 | idx_filename += f'_{self.target_attr}' 37 | if self.add_attr is not None: 38 | idx_filename += f'_{self.add_attr}' 39 | idx_filename += '.pkl' 40 | filepath = os.path.join(self.root, folder_name, idx_filename) 41 | 42 | if idxs_dict is None: 43 | if not os.path.isfile(filepath): 44 | idxs_dict = self.pick_idxs(num_data, idxs_per_group, filepath) 45 | else: 46 | with open(filepath, 'rb') as f: 47 | idxs_dict = pickle.load(f) 48 | 49 | if self.version == 'bs1': 50 | new_idxs = [] 51 | for g in range(num_groups): 52 | for l in range(num_classes): 53 | new_idxs.extend(idxs_dict['annotated'][(g, l)]) 54 | new_idxs.sort() 55 | new_features = [features[idx] for idx in new_idxs] 56 | features = new_features 57 | 58 | elif self.version == 'bs2': 59 | idx_pool = list(range(num_groups)) 60 | total_per_group = np.zeros((num_classes, num_groups)) 61 | for l in range(num_classes): 62 | for g in range(num_groups): 63 | total_per_group[l, g] = len(idxs_dict['annotated'][(g, l)]) 64 | total_per_group = total_per_group.astype(int) 65 | 
for g in range(num_groups): 66 | for l in range(num_classes): 67 | for idx in idxs_dict['non-annotated'][(g, l)]: 68 | features[idx][0] = random.choices(idx_pool, k=1, weights=list(total_per_group[l]))[0] 69 | 70 | elif self.version == 'bs1uncertain': 71 | # bs2new is to use unlabeled data only for accuracy 72 | folder = 'group_clf' 73 | model = 'resnet18_dropout' 74 | epochs = '70' 75 | filename_pre = f'{model}_seed{self.seed}_epochs{epochs}_bs128_lr0.001' 76 | filename_post = f'_sv{self.sv_ratio}_groupclf.pt' 77 | filename = filename_pre + filename_post 78 | path = os.path.join(self.root, folder, filename) 79 | preds = torch.load(path)['pred'] 80 | 81 | new_idxs = [] 82 | for g in range(num_groups): 83 | for l in range(num_classes): 84 | # annotated 85 | new_idxs.extend(idxs_dict['annotated'][(g, l)]) 86 | # non-annotated 87 | for idx in idxs_dict['non-annotated'][(g, l)]: 88 | if preds[idx] != -1: 89 | new_idxs.append(idx) 90 | features[idx][0] = preds[idx].item() 91 | 92 | print(len(new_idxs), len(features)) 93 | new_idxs.sort() 94 | new_features = [features[idx] for idx in new_idxs] 95 | features = new_features 96 | 97 | elif self.version == 'bs3': 98 | # bs 3 is to make psuedo labels using a model being trained with labeld group label from scratch 99 | folder = 'group_clf' if self.version == 'bs3' else 'group_clf_pretrain' 100 | model = 'resnet18' if self.name in ['utkface', 'celeba'] else 'mlp' 101 | epochs = '70' if self.name in ['utkface', 'celeba'] else '50' 102 | filename_pre = f'{model}_seed{self.seed}_epochs{epochs}_bs128_lr0.001' 103 | filename_post = f'_sv{self.sv_ratio}_groupclf.pt' 104 | if self.name == 'celeba': 105 | if self.target_attr != 'Attractive': 106 | filename_pre += f'_{self.target_attr}' 107 | if self.add_attr is not None: 108 | filename_pre += f'_{self.add_attr}' 109 | filename = filename_pre + filename_post 110 | path = os.path.join(self.root, folder, filename) 111 | preds = torch.load(path)['pred'] 112 | 113 | for g in 
range(num_groups): 114 | for l in range(num_classes): 115 | for idx in idxs_dict['non-annotated'][(g, l)]: 116 | features[idx][0] = preds[idx].item() 117 | 118 | elif self.version == 'groupclf': 119 | if self.ups_iter > 0: 120 | folder = 'group_clf' 121 | filename_pre = f'resnet18_dropout_seed{self.seed}_epochs70_bs128_lr0.001' 122 | if self.ups_iter > 1: 123 | filename_post = f'_sv{self.sv_ratio}_iter{_iter}_groupclf.pt' 124 | else: 125 | filename_post = f'_sv{self.sv_ratio}_groupclf.pt' 126 | filename = filename_pre + filename_post 127 | path = os.path.join(self.root, folder, filename) 128 | preds = torch.load(path)['pred'] 129 | 130 | new_idxs = [] 131 | for g in range(num_groups): 132 | for l in range(num_classes): 133 | if self.split == 'train': 134 | idxs = idxs_dict['annotated'][(g, l)] 135 | elif self.split == 'test': 136 | idxs = idxs_dict['non-annotated'][(g, l)] 137 | new_idxs.extend(idxs) 138 | new_idxs.sort() 139 | new_features = [features[idx] for idx in new_idxs] 140 | features = new_features 141 | 142 | elif self.version == 'groupclf_val': 143 | new_idxs = [] 144 | val_idxs = [] 145 | for g in range(num_groups): 146 | for l in range(num_classes): 147 | if self.split == 'train': 148 | train_num = int(len(idxs_dict['annotated'][(g, l)]) * 0.8) 149 | idxs = idxs_dict['annotated'][(g, l)][:train_num] 150 | val_idxs.extend(idxs_dict['annotated'][(g, l)][train_num:]) 151 | elif self.split == 'test': 152 | idxs = idxs_dict['non-annotated'][(g, l)] 153 | new_idxs.extend(idxs) 154 | new_idxs.sort() 155 | new_features = [features[idx] for idx in new_idxs] 156 | features = new_features 157 | self.val_idxs = val_idxs 158 | 159 | elif self.version == 'oracle': 160 | # this version is to make a oracle model about predicting noisy lables. 
161 | folder = 'group_clf' 162 | filename = 'resnet18_seed{}_epochs70_bs128_lr0.001_sv{}_groupclf_val.pt'.format(self.seed, self.sv_ratio) 163 | path = os.path.join(self.root, folder, filename) 164 | preds = torch.load(path)['pred'] 165 | idx_pool = list(range(num_groups)) 166 | for g in range(num_groups): 167 | for l in range(num_classes): 168 | for idx in idxs_dict['non-annotated'][(g, l)]: 169 | if features[idx][0] != preds[idx].item(): 170 | features[idx][0] = random.choices(idx_pool, k=1)[0] 171 | 172 | elif self.version == 'cgl': 173 | folder = 'group_clf' 174 | model = 'resnet18' if self.name in ['utkface', 'celeba', 'utkface_fairface'] else 'mlp' 175 | epochs = '70' if self.name in ['utkface', 'celeba', 'utkface_fairface'] else '50' 176 | bs = '128' if self.name != 'adult' else '1024' 177 | filename_pre = f'{model}_seed{self.seed}_epochs{epochs}_bs{bs}_lr0.001' 178 | filename_post = f'_sv{self.sv_ratio}_groupclf_val.pt' 179 | if self.name == 'celeba': 180 | if self.target_attr != 'Attractive': 181 | filename_pre += f'_{self.target_attr}' 182 | if self.add_attr is not None: 183 | filename_pre += f'_{self.add_attr}' 184 | filename = filename_pre + filename_post 185 | 186 | if self.name == 'utkface_fairface': 187 | path = os.path.join('./data/utkface_fairface', folder, filename) 188 | else: 189 | path = os.path.join(self.root, folder, filename) 190 | 191 | preds = torch.load(path)['pred'] 192 | probs = torch.load(path)['probs'] 193 | thres = torch.load(path)['opt_thres'] 194 | print('thres : ', thres) 195 | idx_pool = list(range(num_groups)) 196 | 197 | total_per_group = np.zeros((num_classes, num_groups)) 198 | for l in range(num_classes): 199 | for g in range(num_groups): 200 | total_per_group[l, g] = len(idxs_dict['annotated'][(g, l)]) 201 | total_per_group = total_per_group.astype(int) 202 | for g in range(num_groups): 203 | for l in range(num_classes): 204 | for idx in idxs_dict['non-annotated'][(g, l)]: 205 | if self.version == 'cgl': 206 | if 
probs[idx].max() >= thres: 207 | features[idx][0] = preds[idx].item() 208 | else: 209 | features[idx][0] = random.choices(idx_pool, k=1, weights=list(total_per_group[l]))[0] 210 | 211 | else: 212 | raise ValueError 213 | print('count the number of data newly!') 214 | num_data, idxs_per_group = self._data_count(features, num_groups, num_classes) 215 | 216 | return features, num_data, idxs_per_group 217 | 218 | def pick_idxs(self, num_data, idxs_per_group, filepath): 219 | print(''.format(filepath)) 220 | if not os.path.isdir(os.path.join(self.root, 'annotated_idxs')): 221 | os.mkdir(os.path.join(self.root, 'annotated_idxs')) 222 | num_groups, num_classes = num_data.shape 223 | idxs_dict = {} 224 | idxs_dict['annotated'] = {} 225 | idxs_dict['non-annotated'] = {} 226 | for g in range(num_groups): 227 | for l in range(num_classes): 228 | num_nonannotated = int(num_data[g, l] * (1-self.sv_ratio)) 229 | print(g, l, num_nonannotated) 230 | idxs_nonannotated = random.sample(idxs_per_group[(g, l)], num_nonannotated) 231 | idxs_annotated = [idx for idx in idxs_per_group[(g, l)] if idx not in idxs_nonannotated] 232 | idxs_dict['non-annotated'][(g, l)] = idxs_nonannotated 233 | idxs_dict['annotated'][(g, l)] = idxs_annotated 234 | 235 | with open(filepath, 'wb') as f: 236 | pickle.dump(idxs_dict, f, pickle.HIGHEST_PROTOCOL) 237 | 238 | return idxs_dict 239 | -------------------------------------------------------------------------------- /networks/resnet_dropout.py: -------------------------------------------------------------------------------- 1 | """ 2 | Original code: 3 | https://github.com/nayeemrizve/ups 4 | """ 5 | import sys 6 | import math 7 | import itertools 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | from torch.autograd import Variable, Function 13 | # import torch.utils.model_zoo as model_zoo 14 | 15 | class ResNet224x224(nn.Module): 16 | def __init__(self, block, layers, channels, groups=1, num_classes=1000, 
downsample='basic'): 17 | super().__init__() 18 | assert len(layers) == 4 19 | self.downsample_mode = downsample 20 | self.inplanes = 64 21 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 22 | bias=False) 23 | self.bn1 = nn.BatchNorm2d(self.inplanes) 24 | self.relu = nn.ReLU(inplace=True) 25 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 26 | self.layer1 = self._make_layer(block, channels, groups, layers[0]) 27 | self.layer2 = self._make_layer( 28 | block, channels * 2, groups, layers[1], stride=2) 29 | self.layer3 = self._make_layer( 30 | block, channels * 4, groups, layers[2], stride=2) 31 | self.layer4 = self._make_layer( 32 | block, channels * 8, groups, layers[3], stride=2) 33 | self.avgpool = nn.AvgPool2d(7) 34 | self.fc = nn.Linear(block.out_channels( 35 | channels * 8, groups), num_classes) 36 | 37 | for m in self.modules(): 38 | if isinstance(m, nn.Conv2d): 39 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 40 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 41 | elif isinstance(m, nn.BatchNorm2d): 42 | m.weight.data.fill_(1) 43 | m.bias.data.zero_() 44 | 45 | def _make_layer(self, block, planes, groups, blocks, stride=1): 46 | downsample = None 47 | if stride != 1 or self.inplanes != block.out_channels(planes, groups): 48 | if self.downsample_mode == 'basic' or stride == 1: 49 | downsample = nn.Sequential( 50 | nn.Conv2d(self.inplanes, block.out_channels(planes, groups), 51 | kernel_size=1, stride=stride, bias=False), 52 | nn.BatchNorm2d(block.out_channels(planes, groups)), 53 | ) 54 | elif self.downsample_mode == 'shift_conv': 55 | downsample = ShiftConvDownsample(in_channels=self.inplanes, 56 | out_channels=block.out_channels(planes, groups)) 57 | else: 58 | assert False 59 | 60 | layers = [] 61 | layers.append(block(self.inplanes, planes, groups, stride, downsample)) 62 | self.inplanes = block.out_channels(planes, groups) 63 | for i in range(1, blocks): 64 | layers.append(block(self.inplanes, planes, groups)) 65 | 66 | return nn.Sequential(*layers) 67 | 68 | def forward(self, x): 69 | x = self.conv1(x) 70 | x = self.bn1(x) 71 | x = self.relu(x) 72 | x = self.maxpool(x) 73 | x = self.layer1(x) 74 | x = self.layer2(x) 75 | x = self.layer3(x) 76 | x = self.layer4(x) 77 | x = self.avgpool(x) 78 | x = x.view(x.size(0), -1) 79 | return self.fc(x) 80 | 81 | 82 | class ResNet32x32(nn.Module): 83 | def __init__(self, block, layers, channels, groups=1, downsample='basic', num_classes=10, dropout=0.3): 84 | super().__init__() 85 | assert len(layers) == 3 86 | self.downsample_mode = downsample 87 | self.inplanes = 16 88 | self.dropout = dropout 89 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, 90 | padding=1, bias=False) 91 | self.layer1 = self._make_layer(block, channels, groups, layers[0]) 92 | self.layer2 = self._make_layer( 93 | block, channels * 2, groups, layers[1], stride=2) 94 | self.layer3 = self._make_layer( 95 | block, channels * 4, groups, layers[2], stride=2) 96 | self.avgpool = nn.AvgPool2d(8) 97 
| self.fc = nn.Linear(block.out_channels( 98 | channels * 4, groups), num_classes) 99 | 100 | for m in self.modules(): 101 | if isinstance(m, nn.Conv2d): 102 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 103 | m.weight.data.normal_(0, math.sqrt(2. / n)) 104 | elif isinstance(m, nn.BatchNorm2d): 105 | m.weight.data.fill_(1) 106 | m.bias.data.zero_() 107 | 108 | def _make_layer(self, block, planes, groups, blocks, stride=1): 109 | downsample = None 110 | if stride != 1 or self.inplanes != block.out_channels(planes, groups): 111 | if self.downsample_mode == 'basic' or stride == 1: 112 | downsample = nn.Sequential( 113 | nn.Conv2d(self.inplanes, block.out_channels(planes, groups), 114 | kernel_size=1, stride=stride, bias=False), 115 | nn.BatchNorm2d(block.out_channels(planes, groups)), 116 | ) 117 | elif self.downsample_mode == 'shift_conv': 118 | downsample = ShiftConvDownsample(in_channels=self.inplanes, 119 | out_channels=block.out_channels(planes, groups)) 120 | else: 121 | assert False 122 | 123 | layers = [] 124 | layers.append(block(self.inplanes, planes, groups, stride, downsample, dropout=self.dropout)) 125 | self.inplanes = block.out_channels(planes, groups) 126 | for i in range(1, blocks): 127 | layers.append(block(self.inplanes, planes, groups, dropout=self.dropout)) 128 | 129 | return nn.Sequential(*layers) 130 | 131 | def forward(self, x, randomness=0, mini_batch=1): 132 | x = self.conv1(x) 133 | x = self.layer1(x) 134 | x = self.layer2(x) 135 | x = self.layer3(x) 136 | x = self.avgpool(x) 137 | x = x.view(x.size(0), -1) 138 | return self.fc(x) 139 | 140 | 141 | def conv3x3(in_planes, out_planes, stride=1): 142 | "3x3 convolution with padding" 143 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 144 | padding=1, bias=False) 145 | 146 | 147 | class BottleneckBlock(nn.Module): 148 | @classmethod 149 | def out_channels(cls, planes, groups): 150 | if groups > 1: 151 | return 2 * planes 152 | else: 153 | return 4 * planes 
154 | 155 | def __init__(self, inplanes, planes, groups, stride=1, downsample=None): 156 | super().__init__() 157 | self.relu = nn.ReLU(inplace=True) 158 | 159 | self.conv_a1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 160 | self.bn_a1 = nn.BatchNorm2d(planes) 161 | self.conv_a2 = nn.Conv2d( 162 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False, groups=groups) 163 | self.bn_a2 = nn.BatchNorm2d(planes) 164 | self.conv_a3 = nn.Conv2d(planes, self.out_channels( 165 | planes, groups), kernel_size=1, bias=False) 166 | self.bn_a3 = nn.BatchNorm2d(self.out_channels(planes, groups)) 167 | 168 | self.downsample = downsample 169 | self.stride = stride 170 | 171 | def forward(self, x): 172 | a, residual = x, x 173 | 174 | a = self.conv_a1(a) 175 | a = self.bn_a1(a) 176 | a = self.relu(a) 177 | a = self.conv_a2(a) 178 | a = self.bn_a2(a) 179 | a = self.relu(a) 180 | a = self.conv_a3(a) 181 | a = self.bn_a3(a) 182 | 183 | if self.downsample is not None: 184 | residual = self.downsample(residual) 185 | 186 | return self.relu(residual + a) 187 | 188 | 189 | class ShakeShakeBlock(nn.Module): 190 | @classmethod 191 | def out_channels(cls, planes, groups): 192 | assert groups == 1 193 | return planes 194 | 195 | def __init__(self, inplanes, planes, groups, stride=1, downsample=None, dropout=0.3): 196 | super().__init__() 197 | assert groups == 1 198 | self.conv_a1 = conv3x3(inplanes, planes, stride) 199 | self.bn_a1 = nn.BatchNorm2d(planes) 200 | self.conv_a2 = conv3x3(planes, planes) 201 | self.bn_a2 = nn.BatchNorm2d(planes) 202 | 203 | self.conv_b1 = conv3x3(inplanes, planes, stride) 204 | self.bn_b1 = nn.BatchNorm2d(planes) 205 | self.conv_b2 = conv3x3(planes, planes) 206 | self.bn_b2 = nn.BatchNorm2d(planes) 207 | 208 | self.downsample = downsample 209 | self.stride = stride 210 | self.dropout = dropout 211 | if self.dropout > 0: 212 | self.drop_a = nn.Dropout(p=self.dropout) 213 | self.drop_b = nn.Dropout(p=self.dropout) 214 | # self.drop = 
nn.Dropout2d(self.dropout) 215 | 216 | def forward(self, x): 217 | a, b, residual = x, x, x 218 | 219 | a = F.relu(a, inplace=False) 220 | a = self.conv_a1(a) 221 | a = self.bn_a1(a) 222 | a = F.relu(a, inplace=True) 223 | if self.dropout > 0: 224 | a = self.drop_a(a) 225 | a = self.conv_a2(a) 226 | a = self.bn_a2(a) 227 | 228 | b = F.relu(b, inplace=False) 229 | b = self.conv_b1(b) 230 | b = self.bn_b1(b) 231 | b = F.relu(b, inplace=True) 232 | if self.dropout > 0: 233 | b = self.drop_b(b) 234 | b = self.conv_b2(b) 235 | b = self.bn_b2(b) 236 | 237 | ab = shake(a, b, training=self.training) 238 | 239 | if self.downsample is not None: 240 | residual = self.downsample(x) 241 | 242 | return residual + ab 243 | 244 | class Shake(Function): 245 | @classmethod 246 | def forward(cls, ctx, inp1, inp2, training): 247 | assert inp1.size() == inp2.size() 248 | gate_size = [inp1.size()[0], *itertools.repeat(1, inp1.dim() - 1)] 249 | gate = inp1.new(*gate_size) 250 | if training: 251 | gate.uniform_(0, 1) 252 | else: 253 | gate.fill_(0.5) 254 | return inp1 * gate + inp2 * (1. 
- gate) 255 | 256 | @classmethod 257 | def backward(cls, ctx, grad_output): 258 | grad_inp1 = grad_inp2 = grad_training = None 259 | gate_size = [grad_output.size()[0], *itertools.repeat(1, 260 | grad_output.dim() - 1)] 261 | gate = Variable(grad_output.data.new(*gate_size).uniform_(0, 1)) 262 | if ctx.needs_input_grad[0]: 263 | grad_inp1 = grad_output * gate 264 | if ctx.needs_input_grad[1]: 265 | grad_inp2 = grad_output * (1 - gate) 266 | assert not ctx.needs_input_grad[2] 267 | return grad_inp1, grad_inp2, grad_training 268 | 269 | 270 | def shake(inp1, inp2, training=False): 271 | return Shake.apply(inp1, inp2, training) 272 | 273 | 274 | class ShiftConvDownsample(nn.Module): 275 | def __init__(self, in_channels, out_channels): 276 | super().__init__() 277 | self.relu = nn.ReLU(inplace=True) 278 | self.conv = nn.Conv2d(in_channels=2 * in_channels, 279 | out_channels=out_channels, 280 | kernel_size=1, 281 | groups=2) 282 | self.bn = nn.BatchNorm2d(out_channels) 283 | 284 | def forward(self, x): 285 | x = torch.cat((x[:, :, 0::2, 0::2], 286 | x[:, :, 1::2, 1::2]), dim=1) 287 | x = self.relu(x) 288 | x = self.conv(x) 289 | x = self.bn(x) 290 | return x 291 | 292 | def resnet18_dropout(pretrained=False, num_groups=0, img_size=224, **kwargs): 293 | return ResNet224x224(ShakeShakeBlock, channels=64, layers=[2,2,2,2], **kwargs) 294 | 295 | 296 | if __name__ == "__main__": 297 | model = ResNet32x32(ShakeShakeBlock, 298 | layers=[4, 4, 4], 299 | channels=96, 300 | downsample='shift_conv') 301 | x = torch.randn(1,3,32,32) 302 | y, feature = model(x) 303 | # y.shape = [1, 10] 304 | # feature.shape = [1, 512] 305 | # import IPython 306 | # IPython.embed() 307 | # print(y.size()) 308 | -------------------------------------------------------------------------------- /networks/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Original code: 3 | https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py 4 | 
""" 5 | import torch 6 | import torch.nn as nn 7 | from torchvision.models.utils import load_state_dict_from_url 8 | 9 | 10 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50'] 11 | 12 | 13 | model_urls = { 14 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 15 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 16 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 17 | } 18 | 19 | 20 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 21 | """3x3 convolution with padding""" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 23 | padding=dilation, groups=groups, bias=False, dilation=dilation) 24 | 25 | 26 | def conv1x1(in_planes, out_planes, stride=1): 27 | """1x1 convolution""" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 29 | 30 | 31 | class BasicBlock(nn.Module): 32 | expansion = 1 33 | __constants__ = ['downsample'] 34 | 35 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 36 | base_width=64, dilation=1, norm_layer=None): 37 | super(BasicBlock, self).__init__() 38 | if norm_layer is None: 39 | norm_layer = nn.BatchNorm2d 40 | if groups != 1 or base_width != 64: 41 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 42 | if dilation > 1: 43 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 44 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 45 | self.conv1 = conv3x3(inplanes, planes, stride) 46 | self.bn1 = norm_layer(planes) 47 | self.relu = nn.ReLU(inplace=True) 48 | self.conv2 = conv3x3(planes, planes) 49 | self.bn2 = norm_layer(planes) 50 | self.downsample = downsample 51 | self.stride = stride 52 | 53 | def forward(self, x): 54 | identity = x 55 | 56 | out = self.conv1(x) 57 | out = self.bn1(out) 58 | out = self.relu(out) 59 | 60 | out = self.conv2(out) 61 | out = self.bn2(out) 62 | 63 | if 
self.downsample is not None: 64 | identity = self.downsample(x) 65 | 66 | out += identity 67 | out = self.relu(out) 68 | 69 | return out 70 | 71 | 72 | class Bottleneck(nn.Module): 73 | expansion = 4 74 | __constants__ = ['downsample'] 75 | 76 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 77 | base_width=64, dilation=1, norm_layer=None): 78 | super(Bottleneck, self).__init__() 79 | if norm_layer is None: 80 | norm_layer = nn.BatchNorm2d 81 | width = int(planes * (base_width / 64.)) * groups 82 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 83 | self.conv1 = conv1x1(inplanes, width) 84 | self.bn1 = norm_layer(width) 85 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 86 | self.bn2 = norm_layer(width) 87 | self.conv3 = conv1x1(width, planes * self.expansion) 88 | self.bn3 = norm_layer(planes * self.expansion) 89 | self.relu = nn.ReLU(inplace=True) 90 | self.downsample = downsample 91 | self.stride = stride 92 | 93 | def forward(self, x): 94 | identity = x 95 | 96 | out = self.conv1(x) 97 | out = self.bn1(out) 98 | out = self.relu(out) 99 | 100 | out = self.conv2(out) 101 | out = self.bn2(out) 102 | out = self.relu(out) 103 | 104 | out = self.conv3(out) 105 | out = self.bn3(out) 106 | 107 | if self.downsample is not None: 108 | identity = self.downsample(x) 109 | 110 | out += identity 111 | out = self.relu(out) 112 | 113 | return out 114 | 115 | 116 | class ResNet(nn.Module): 117 | 118 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 119 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 120 | norm_layer=None, hydra=False, num_groups=None, img_size=32): 121 | super(ResNet, self).__init__() 122 | if norm_layer is None: 123 | norm_layer = nn.BatchNorm2d 124 | self._norm_layer = norm_layer 125 | 126 | self.inplanes = 64 127 | self.dilation = 1 128 | if replace_stride_with_dilation is None: 129 | # each element in the tuple indicates if we 
should replace 130 | # the 2x2 stride with a dilated convolution instead 131 | replace_stride_with_dilation = [False, False, False] 132 | if len(replace_stride_with_dilation) != 3: 133 | raise ValueError("replace_stride_with_dilation should be None " 134 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 135 | self.groups = groups 136 | self.base_width = width_per_group 137 | self.img_size = img_size 138 | if img_size==32: 139 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, 140 | bias=False) 141 | else: 142 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 143 | bias=False) 144 | self.bn1 = norm_layer(self.inplanes) 145 | self.relu = nn.ReLU(inplace=True) 146 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 147 | self.layer1 = self._make_layer(block, 64, layers[0]) 148 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 149 | dilate=replace_stride_with_dilation[0]) 150 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 151 | dilate=replace_stride_with_dilation[1]) 152 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 153 | dilate=replace_stride_with_dilation[2]) 154 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 155 | if not hydra: 156 | self.fc = nn.Linear(512 * block.expansion, num_classes) 157 | else: 158 | self.fc = nn.ModuleList() 159 | for _ in range(num_groups): 160 | self.fc.append(torch.nn.Linear(512 * block.expansion, num_classes)) 161 | 162 | for m in self.modules(): 163 | if isinstance(m, nn.Conv2d): 164 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 165 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 166 | nn.init.constant_(m.weight, 1) 167 | nn.init.constant_(m.bias, 0) 168 | 169 | # Zero-initialize the last BN in each residual branch, 170 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
171 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 172 | if zero_init_residual: 173 | for m in self.modules(): 174 | if isinstance(m, Bottleneck): 175 | nn.init.constant_(m.bn3.weight, 0) 176 | elif isinstance(m, BasicBlock): 177 | nn.init.constant_(m.bn2.weight, 0) 178 | 179 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 180 | norm_layer = self._norm_layer 181 | downsample = None 182 | previous_dilation = self.dilation 183 | if dilate: 184 | self.dilation *= stride 185 | stride = 1 186 | if stride != 1 or self.inplanes != planes * block.expansion: 187 | downsample = nn.Sequential( 188 | conv1x1(self.inplanes, planes * block.expansion, stride), 189 | norm_layer(planes * block.expansion), 190 | ) 191 | 192 | layers = [] 193 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 194 | self.base_width, previous_dilation, norm_layer)) 195 | self.inplanes = planes * block.expansion 196 | for _ in range(1, blocks): 197 | layers.append(block(self.inplanes, planes, groups=self.groups, 198 | base_width=self.base_width, dilation=self.dilation, 199 | norm_layer=norm_layer)) 200 | 201 | return nn.Sequential(*layers) 202 | 203 | def _forward_impl(self, x, get_inter=False, reid = False): 204 | # See note [TorchScript super()] 205 | h = self.conv1(x) 206 | h = self.bn1(h) 207 | h = self.relu(h) 208 | if self.img_size != 32: 209 | h = self.maxpool(h) 210 | 211 | b1 = self.layer1(h) 212 | b2 = self.layer2(b1) 213 | b3 = self.layer3(b2) 214 | b4 = self.layer4(b3) 215 | 216 | h = self.avgpool(b4) 217 | h1 = torch.flatten(h, 1) 218 | if isinstance(self.fc, nn.ModuleList): 219 | h = [] 220 | for i in range(len(self.fc)): 221 | h.append(self.fc[i](h1)) 222 | else: 223 | h = self.fc(h1) 224 | 225 | if get_inter: 226 | return b1, b2, b3, b4, h 227 | elif reid: 228 | return h1 229 | else: 230 | return h 231 | 232 | def forward(self, x, get_inter=False, reid=False): 233 | return self._forward_impl(x, 
get_inter, reid) 234 | 235 | def adapt_classifier(self, x): 236 | h = self.avgpool(x) 237 | h1 = torch.flatten(h, 1) 238 | return self.fc(h1) 239 | 240 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 241 | model = ResNet(block, layers, **kwargs) 242 | if pretrained: 243 | state_dict = load_state_dict_from_url(model_urls[arch], 244 | progress=progress) 245 | model.load_state_dict(state_dict) 246 | return model 247 | 248 | def resnet12(pretrained=False, progress=True,**kwargs): 249 | r"""ResNet-18 model from 250 | `"Deep Residual Learning for Image Recognition" `_ 251 | Args: 252 | pretrained (bool): If True, returns a model pre-trained on ImageNet 253 | progress (bool): If True, displays a progress bar of the download to stderr 254 | """ 255 | return _resnet('resnet12', BasicBlock, [1, 1, 2, 1], pretrained, progress, 256 | **kwargs) 257 | 258 | def resnet10(pretrained=False, progress=True,**kwargs): 259 | r"""ResNet-18 model from 260 | `"Deep Residual Learning for Image Recognition" `_ 261 | Args: 262 | pretrained (bool): If True, returns a model pre-trained on ImageNet 263 | progress (bool): If True, displays a progress bar of the download to stderr 264 | """ 265 | return _resnet('resnet10', BasicBlock, [1, 1, 1, 1], pretrained, progress, 266 | **kwargs) 267 | 268 | def resnet18(pretrained=False, progress=True, **kwargs): 269 | r"""ResNet-18 model from 270 | `"Deep Residual Learning for Image Recognition" `_ 271 | Args: 272 | pretrained (bool): If True, returns a model pre-trained on ImageNet 273 | progress (bool): If True, displays a progress bar of the download to stderr 274 | """ 275 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 276 | **kwargs) 277 | 278 | 279 | def resnet34(pretrained=False, progress=True, **kwargs): 280 | r"""ResNet-34 model from 281 | `"Deep Residual Learning for Image Recognition" `_ 282 | Args: 283 | pretrained (bool): If True, returns a model pre-trained on ImageNet 284 | progress (bool): 
If True, displays a progress bar of the download to stderr 285 | """ 286 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 287 | **kwargs) 288 | 289 | 290 | def resnet50(pretrained=False, progress=True, **kwargs): 291 | r"""ResNet-50 model from 292 | `"Deep Residual Learning for Image Recognition" `_ 293 | Args: 294 | pretrained (bool): If True, returns a model pre-trained on ImageNet 295 | progress (bool): If True, displays a progress bar of the download to stderr 296 | """ 297 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 298 | **kwargs) 299 | 300 | def resnet101(pretrained=False, progress=True, **kwargs) -> ResNet: 301 | r"""ResNet-101 model from 302 | `"Deep Residual Learning for Image Recognition" `_. 303 | Args: 304 | pretrained (bool): If True, returns a model pre-trained on ImageNet 305 | progress (bool): If True, displays a progress bar of the download to stderr 306 | """ 307 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 308 | **kwargs) 309 | -------------------------------------------------------------------------------- /main_groupclf.py: -------------------------------------------------------------------------------- 1 | """ 2 | cgl_fairness 3 | Copyright (c) 2022-present NAVER Corp. 
4 | MIT license 5 | """ 6 | import torch 7 | from sklearn import metrics 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import numpy as np 11 | import networks 12 | import torch.nn.functional as F 13 | import data_handler 14 | import trainer 15 | from utils import check_log_dir, make_log_name, set_seed 16 | from data_handler.dataset_factory import DatasetFactory 17 | from arguments import get_args 18 | import time 19 | import pickle 20 | import os 21 | from torch.utils.data import DataLoader 22 | args = get_args() 23 | ################################################################################################################################################################### 24 | """ 25 | Used only for training group classifer 26 | """ 27 | ################################################################################################################################################################### 28 | 29 | 30 | def get_weights(loader, cuda=True): 31 | num_groups = loader.dataset.num_groups 32 | data_counts = torch.zeros(num_groups) 33 | data_counts = data_counts.cuda() if cuda else data_counts 34 | 35 | for data in loader: 36 | inputs, _, groups, _, _ = data 37 | for g in range(num_groups): 38 | data_counts[g] += torch.sum((groups == g)) 39 | 40 | weights = data_counts / data_counts.min() 41 | return weights, data_counts 42 | 43 | 44 | def focal_loss(input_values, gamma=10): 45 | """Computes the focal loss""" 46 | p = torch.exp(-input_values) 47 | loss = (1 - p) ** gamma * input_values 48 | return loss.mean() 49 | 50 | 51 | class FocalLoss(nn.Module): 52 | def __init__(self, weight=None, gamma=0.5): 53 | super(FocalLoss, self).__init__() 54 | assert gamma >= 0 55 | self.gamma = gamma 56 | self.weight = weight 57 | 58 | def forward(self, input, target): 59 | return focal_loss(F.cross_entropy(input, target, reduction='none', weight=self.weight), self.gamma) 60 | 61 | 62 | def measure_ood(ind_probs, ood_probs, tpr_thres=0.95): 63 | n_pos = 
len(ind_probs) 64 | n_neg = len(ood_probs) 65 | 66 | labels = np.append( 67 | np.ones_like(ind_probs) * 1, 68 | np.ones_like(ood_probs) * 2 69 | ) 70 | preds = np.append(ind_probs, ood_probs) 71 | fpr, tpr, thresholds = metrics.roc_curve(labels, preds, pos_label=1) 72 | auroc = metrics.auc(fpr, tpr) 73 | 74 | thres_indx = np.where(tpr >= tpr_thres)[0][0] 75 | tnr_at_tpr95 = 1 - fpr[thres_indx] 76 | 77 | temp_idx = np.argmax(((1 - fpr) * n_neg + tpr * n_pos) / (n_pos + n_neg)) 78 | detection_acc = np.max(((1 - fpr) * n_neg + tpr * n_pos) / (n_pos + n_neg)) 79 | if n_neg != n_pos: 80 | print(f'warning: n_neg ({n_neg}) != n_pos ({n_pos}). It may shows weird detection acc') 81 | print(f'current threshold ({thresholds[temp_idx]}): FPR ({fpr[temp_idx]}), TPR ({tpr[temp_idx]})') 82 | 83 | return { 84 | 'auroc': auroc, 85 | 'tnr_at_tpr95': tnr_at_tpr95, 86 | 'detection_acc': detection_acc, 87 | 'opt_thres': thresholds[temp_idx] 88 | } 89 | 90 | 91 | def predict_thres(probs, true_idxs, false_idxs, val_idxs): 92 | print(len(true_idxs), len(false_idxs)) 93 | val_false_idxs = list(set(false_idxs).intersection(val_idxs)) 94 | val_true_idxs = list(set(true_idxs).intersection(val_idxs)) 95 | 96 | val_true_maxprob = probs[val_true_idxs].max(dim=1)[0] 97 | val_false_maxprob = probs[val_false_idxs].max(dim=1)[0] 98 | r = measure_ood(val_true_maxprob.numpy(), val_false_maxprob.numpy()) 99 | return r['opt_thres'] 100 | 101 | 102 | def predict_group(model, args, trainloader, testloader): 103 | model.cuda('cuda:{}'.format(args.device)) 104 | target_attr = None 105 | target_attr = args.target 106 | if args.dataset == 'adult': 107 | target_attr = 'sex' 108 | elif args.dataset == 'compas': 109 | target_attr = 'race' 110 | dataset = DatasetFactory.get_dataset(args.dataset, split='train', sv_ratio=1, version='', 111 | target_attr=target_attr, add_attr=args.add_attr) 112 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, 113 | num_workers=args.num_workers, 
pin_memory=True, drop_last=False) 114 | 115 | val_idxs = [] 116 | if args.version == 'groupclf_val': 117 | val_idxs = trainloader.dataset.val_idxs 118 | 119 | test_idxs = [] 120 | filename = f'{args.seed}_{args.sv}' 121 | if args.dataset == 'celeba': 122 | if args.target != 'Attractive': 123 | filename += f'_{args.target}' 124 | if args.add_attr is not None: 125 | filename += f'_{args.add_attr}' 126 | filename += '.pkl' 127 | filename = os.path.join(dataset.root, 'annotated_idxs', filename) 128 | with open(filename, 'rb') as f: 129 | idxs_dict = pickle.load(f) 130 | for key in idxs_dict['non-annotated'].keys(): 131 | test_idxs.extend(idxs_dict['non-annotated'][key]) 132 | 133 | model.eval() 134 | preds_list = [] 135 | groups_list = [] 136 | labels_list = [] 137 | probs_list = [] 138 | true_idxs = [] 139 | false_idxs = [] 140 | 141 | uncertainty = False 142 | if 'dropout' in args.model: 143 | uncertainty = True 144 | enable_dropout(model) 145 | 146 | with torch.no_grad(): 147 | for i, data in enumerate(dataloader): 148 | inputs, _, groups, labels, (idxs, _) = data 149 | if args.cuda: 150 | inputs = inputs.cuda() 151 | groups = groups.cuda() 152 | labels = labels.cuda() 153 | idxs = idxs.cuda() 154 | 155 | if uncertainty: 156 | preds, probs = mc_dropout(model, inputs) 157 | 158 | else: 159 | outputs = model(inputs) 160 | probs = torch.nn.functional.softmax(outputs, dim=1) 161 | preds = torch.argmax(outputs, 1) 162 | 163 | true_mask = groups == preds 164 | false_mask = groups != preds 165 | 166 | probs_list.append(probs.cpu()) 167 | 168 | preds_list.append(preds.cpu()) 169 | groups_list.append(groups.cpu()) 170 | labels_list.append(labels.cpu()) 171 | 172 | true_idxs.extend(list(idxs[true_mask].cpu().numpy())) 173 | false_idxs.extend(list(idxs[false_mask].cpu().numpy())) 174 | 175 | preds = torch.cat(preds_list) 176 | groups = torch.cat(groups_list) 177 | probs = torch.cat(probs_list) 178 | labels = torch.cat(labels_list) 179 | 180 | # calculate the test acc 181 | 
test_acc = (preds == groups)[test_idxs].sum().item() / len(test_idxs) 182 | print(f'test acc : {test_acc}') 183 | if args.version == 'groupclf_val': 184 | val_acc = (preds == groups)[val_idxs].sum().item() / len(val_idxs) 185 | print(f'val acc : {val_acc}') 186 | 187 | results = {} 188 | results['pred'] = preds 189 | results['group'] = groups 190 | results['probs'] = probs 191 | results['label'] = labels 192 | results['true_idxs'] = true_idxs 193 | results['false_idxs'] = false_idxs 194 | return results, val_idxs 195 | 196 | 197 | def enable_dropout(model): 198 | for m in model.modules(): 199 | if m.__class__.__name__.startswith('Dropout'): 200 | m.train() 201 | 202 | 203 | def mc_dropout(model, inputs): 204 | temp_nl = 2 205 | f_pass = 10 206 | out_prob = [] 207 | out_prob_nl = [] 208 | for _ in range(f_pass): 209 | outputs = model(inputs) 210 | out_prob.append(F.softmax(outputs, dim=1)) # for selecting positive pseudo-labels 211 | out_prob_nl.append(F.softmax(outputs/temp_nl, dim=1)) # for selecting negative pseudo-labels 212 | out_prob = torch.stack(out_prob) 213 | out_prob_nl = torch.stack(out_prob_nl) 214 | out_std = torch.std(out_prob, dim=0) 215 | probs = torch.mean(out_prob, dim=0) 216 | out_prob_nl = torch.mean(out_prob_nl, dim=0) 217 | max_value, preds = torch.max(probs, dim=1) 218 | max_std = out_std.gather(1, preds.view(-1, 1)).squeeze() 219 | n_uncertain = (max_std > 0.05).sum() 220 | print('# of uncertain samples : ', n_uncertain) 221 | preds[max_std > 0.05] = -1 222 | return preds, probs 223 | 224 | 225 | def main(): 226 | torch.backends.cudnn.enabled = True 227 | 228 | seed = args.seed 229 | set_seed(seed) 230 | 231 | np.set_printoptions(precision=4) 232 | torch.set_printoptions(precision=4) 233 | 234 | if args.dataset == 'adult': 235 | args.img_size = 97 236 | elif args.dataset == 'compas': 237 | args.img_size = 400 238 | log_name = make_log_name(args) 239 | dataset = args.dataset 240 | save_dir = os.path.join(args.save_dir, args.date, dataset, 
args.method) 241 | log_dir = os.path.join(args.result_dir, args.date, dataset, args.method) 242 | check_log_dir(save_dir) 243 | check_log_dir(log_dir) 244 | if 'groupclf' not in args.version: 245 | raise ValueError 246 | ########################## get dataloader ################################ 247 | tmp = data_handler.DataloaderFactory.get_dataloader(args.dataset, 248 | batch_size=args.batch_size, seed=args.seed, 249 | num_workers=args.num_workers, 250 | target_attr=args.target, 251 | add_attr=args.add_attr, 252 | labelwise=args.labelwise, 253 | sv_ratio=args.sv, 254 | version=args.version, 255 | ) 256 | num_groups, num_classes, train_loader, test_loader = tmp 257 | ########################## get model ################################## 258 | 259 | num_model_output = num_classes if args.modelpath is not None else num_groups 260 | model = networks.ModelFactory.get_model(args.model, num_model_output, args.img_size) 261 | 262 | if args.modelpath is not None: 263 | model.load_state_dict(torch.load(args.modelpath)) 264 | model.fc = nn.Linear(in_features=model.fc.weight.shape[1], out_features=num_groups, bias=True) 265 | 266 | model.cuda('cuda:{}'.format(args.device)) 267 | teacher = None 268 | criterion = None 269 | if args.dataset in ['compas', 'adult']: 270 | weights = None 271 | weights, data_counts = get_weights(train_loader, cuda=args.cuda) 272 | cls_num_list = [] 273 | for i in range(num_groups): 274 | cls_num_list.append(data_counts[i].item()) 275 | beta = 0.999 276 | effective_num = 1.0 - np.power(beta, cls_num_list) 277 | per_cls_weights = (1.0 - beta) / np.array(effective_num) 278 | per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list) 279 | per_cls_weights = torch.FloatTensor(per_cls_weights).cuda() 280 | criterion = FocalLoss(weight=per_cls_weights).cuda() 281 | ########################## get trainer ################################## 282 | if args.optimizer == 'Adam': 283 | optimizer = optim.Adam(model.parameters(), lr=args.lr, 
weight_decay=1e-4) 284 | elif 'SGD' in args.optimizer: 285 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9) 286 | 287 | trainer_ = trainer.TrainerFactory.get_trainer(args.method, model=model, args=args, 288 | optimizer=optimizer, teacher=teacher) 289 | 290 | ####################### start training or evaluating #################### 291 | 292 | if args.mode == 'train': 293 | start_t = time.time() 294 | trainer_.train(train_loader, test_loader, args.epochs, criterion=criterion) 295 | 296 | end_t = time.time() 297 | train_t = int((end_t - start_t)/60) # to minutes 298 | print('Training Time : {} hours {} minutes'.format(int(train_t/60), (train_t % 60))) 299 | trainer_.save_model(save_dir, log_name) 300 | else: 301 | print('Evaluation ----------------') 302 | model_name = os.path.join(save_dir, log_name+'.pt') 303 | model.load_state_dict(torch.load(model_name)) 304 | results, val_idxs = predict_group(model, args, train_loader, test_loader) 305 | if args.version == 'groupclf_val': 306 | opt_thres = predict_thres(results['probs'], results['true_idxs'], results['false_idxs'], val_idxs) 307 | results['opt_thres'] = opt_thres 308 | save_path = os.path.join('../data', args.dataset, args.date) 309 | if not os.path.isdir(save_path): 310 | os.makedirs(save_path) 311 | torch.save(results, os.path.join(save_path, log_name+'.pt')) 312 | print(os.path.join(save_path, log_name+'.pt')) 313 | return 314 | 315 | if args.evalset == 'all': 316 | trainer_.compute_confusion_matix('train', train_loader.dataset.num_classes, train_loader, log_dir, log_name) 317 | trainer_.compute_confusion_matix('test', test_loader.dataset.num_classes, test_loader, log_dir, log_name) 318 | 319 | elif args.evalset == 'train': 320 | trainer_.compute_confusion_matix('train', train_loader.dataset.num_classes, train_loader, log_dir, log_name) 321 | else: 322 | trainer_.compute_confusion_matix('test', test_loader.dataset.num_classes, test_loader, log_dir, log_name) 323 | 324 | print('Done!') 325 
| 326 | 327 | if __name__ == '__main__': 328 | main() 329 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | cgl_fairness 2 | Copyright (c) 2022-present NAVER Corp. 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy 5 | of this software and associated documentation files (the "Software"), to deal 6 | in the Software without restriction, including without limitation the rights 7 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | copies of the Software, and to permit persons to whom the Software is 9 | furnished to do so, subject to the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be included in 12 | all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 | THE SOFTWARE. 21 | 22 | -------------------------------------------------------------------------------------- 23 | 24 | This project contains subcomponents with separate copyright notices and license terms. 25 | Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses. 
26 | 27 | ===== 28 | 29 | Trusted-AI/AIF360 30 | https://github.com/Trusted-AI/AIF360 31 | 32 | 33 | Copyright 2018-2021 The AI Fairness 360 (AIF360) Authors 34 | 35 | Licensed under the Apache License, Version 2.0 (the "License"); 36 | you may not use this file except in compliance with the License. 37 | You may obtain a copy of the License at 38 | 39 | http://www.apache.org/licenses/LICENSE-2.0 40 | 41 | Unless required by applicable law or agreed to in writing, software 42 | distributed under the License is distributed on an "AS IS" BASIS, 43 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 44 | See the License for the specific language governing permissions and 45 | limitations under the License. 46 | 47 | ===== 48 | 49 | pytorch/vision 50 | https://github.com/pytorch/vision 51 | 52 | 53 | BSD 3-Clause License 54 | 55 | Copyright (c) Soumith Chintala 2016, 56 | All rights reserved. 57 | 58 | Redistribution and use in source and binary forms, with or without 59 | modification, are permitted provided that the following conditions are met: 60 | 61 | * Redistributions of source code must retain the above copyright notice, this 62 | list of conditions and the following disclaimer. 63 | 64 | * Redistributions in binary form must reproduce the above copyright notice, 65 | this list of conditions and the following disclaimer in the documentation 66 | and/or other materials provided with the distribution. 67 | 68 | * Neither the name of the copyright holder nor the names of its 69 | contributors may be used to endorse or promote products derived from 70 | this software without specific prior written permission. 71 | 72 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 73 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 74 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 75 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 76 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 77 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 78 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 79 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 80 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 81 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 82 | 83 | ===== 84 | 85 | sangwon79/Fair-Feature-Distillation-for-Visual-Recognition 86 | https://github.com/sangwon79/Fair-Feature-Distillation-for-Visual-Recognition 87 | 88 | 89 | 90 | MIT License 91 | 92 | Copyright (c) 2021 Donggyu Lee 93 | 94 | Permission is hereby granted, free of charge, to any person obtaining a copy 95 | of this software and associated documentation files (the "Software"), to deal 96 | in the Software without restriction, including without limitation the rights 97 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 98 | copies of the Software, and to permit persons to whom the Software is 99 | furnished to do so, subject to the following conditions: 100 | 101 | The above copyright notice and this permission notice shall be included in all 102 | copies or substantial portions of the Software. 103 | 104 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 105 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 106 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 107 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 108 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 109 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 110 | SOFTWARE. 
111 | 112 | ===== 113 | 114 | yunjey/pytorch-tutorial 115 | https://github.com/yunjey/pytorch-tutorial 116 | 117 | 118 | MIT License 119 | 120 | Copyright (c) 2017 121 | 122 | Permission is hereby granted, free of charge, to any person obtaining a copy 123 | of this software and associated documentation files (the "Software"), to deal 124 | in the Software without restriction, including without limitation the rights 125 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 126 | copies of the Software, and to permit persons to whom the Software is 127 | furnished to do so, subject to the following conditions: 128 | 129 | The above copyright notice and this permission notice shall be included in all 130 | copies or substantial portions of the Software. 131 | 132 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 133 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 134 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 135 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 136 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 137 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 138 | SOFTWARE. 
139 | 140 | ===== 141 | 142 | nayeemrizve/ups 143 | https://github.com/nayeemrizve/ups 144 | 145 | 146 | MIT License 147 | 148 | Copyright (c) 2021 Mamshad Nayeem Rizve 149 | 150 | Permission is hereby granted, free of charge, to any person obtaining a copy 151 | of this software and associated documentation files (the "Software"), to deal 152 | in the Software without restriction, including without limitation the rights 153 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 154 | copies of the Software, and to permit persons to whom the Software is 155 | furnished to do so, subject to the following conditions: 156 | 157 | The above copyright notice and this permission notice shall be included in all 158 | copies or substantial portions of the Software. 159 | 160 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 161 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 162 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 163 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 164 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 165 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 166 | SOFTWARE. 167 | 168 | ===== 169 | 170 | clovaai/rebias 171 | https://github.com/clovaai/rebias 172 | 173 | 174 | ReBias 175 | Copyright (c) 2020-present NAVER Corp. 
176 | 177 | Permission is hereby granted, free of charge, to any person obtaining a copy 178 | of this software and associated documentation files (the "Software"), to deal 179 | in the Software without restriction, including without limitation the rights 180 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 181 | copies of the Software, and to permit persons to whom the Software is 182 | furnished to do so, subject to the following conditions: 183 | 184 | The above copyright notice and this permission notice shall be included in 185 | all copies or substantial portions of the Software. 186 | 187 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 188 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 189 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 190 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 191 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 192 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 193 | THE SOFTWARE. 194 | 195 | -------------------------------------------------------------------------------------- 196 | 197 | This project contains subcomponents with separate copyright notices and license terms. 198 | Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses. 199 | 200 | ===== 201 | 202 | facebookresearch/SlowFast 203 | https://github.com/facebookresearch/SlowFast 204 | 205 | 206 | Copyright 2019, Facebook, Inc 207 | 208 | Licensed under the Apache License, Version 2.0 (the "License"); 209 | you may not use this file except in compliance with the License. 
210 | You may obtain a copy of the License at 211 | 212 | http://www.apache.org/licenses/LICENSE-2.0 213 | 214 | Unless required by applicable law or agreed to in writing, software 215 | distributed under the License is distributed on an "AS IS" BASIS, 216 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 217 | See the License for the specific language governing permissions and 218 | limitations under the License. 219 | 220 | ===== 221 | 222 | pytorch/vision 223 | https://github.com/pytorch/vision 224 | 225 | 226 | BSD 3-Clause License 227 | 228 | Copyright (c) Soumith Chintala 2016, 229 | All rights reserved. 230 | 231 | Redistribution and use in source and binary forms, with or without 232 | modification, are permitted provided that the following conditions are met: 233 | 234 | * Redistributions of source code must retain the above copyright notice, this 235 | list of conditions and the following disclaimer. 236 | 237 | * Redistributions in binary form must reproduce the above copyright notice, 238 | this list of conditions and the following disclaimer in the documentation 239 | and/or other materials provided with the distribution. 240 | 241 | * Neither the name of the copyright holder nor the names of its 242 | contributors may be used to endorse or promote products derived from 243 | this software without specific prior written permission. 244 | 245 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 246 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 247 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 248 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 249 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 250 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 251 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 252 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 253 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 254 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 255 | 256 | ===== 257 | 258 | wielandbrendel/bag-of-local-features-models 259 | https://github.com/wielandbrendel/bag-of-local-features-models 260 | 261 | 262 | Copyright (c) 2019 Wieland Brendel 263 | 264 | Permission is hereby granted, free of charge, to any person obtaining a copy 265 | of this software and associated documentation files (the "Software"), to deal 266 | in the Software without restriction, including without limitation the rights 267 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 268 | copies of the Software, and to permit persons to whom the Software is 269 | furnished to do so, subject to the following conditions: 270 | 271 | The above copyright notice and this permission notice shall be included in all 272 | copies or substantial portions of the Software. 273 | 274 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 275 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 276 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 277 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 278 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 279 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 280 | SOFTWARE. 
281 | 282 | ===== 283 | 284 | cdancette/rubi.bootstrap.pytorch 285 | https://github.com/cdancette/rubi.bootstrap.pytorch 286 | 287 | 288 | BSD 3-Clause License 289 | 290 | Copyright (c) 2019+, Remi Cadene, Corentin Dancette 291 | All rights reserved. 292 | 293 | Redistribution and use in source and binary forms, with or without 294 | modification, are permitted provided that the following conditions are met: 295 | 296 | * Redistributions of source code must retain the above copyright notice, this 297 | list of conditions and the following disclaimer. 298 | 299 | * Redistributions in binary form must reproduce the above copyright notice, 300 | this list of conditions and the following disclaimer in the documentation 301 | and/or other materials provided with the distribution. 302 | 303 | * Neither the name of the copyright holder nor the names of its 304 | contributors may be used to endorse or promote products derived from 305 | this software without specific prior written permission. 306 | 307 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 308 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 309 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 310 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 311 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 312 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 313 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 314 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 315 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 316 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
317 | 318 | ===== 319 | 320 | chrisc36/debias 321 | https://github.com/chrisc36/debias 322 | 323 | 324 | "Don’t Take the Easy Way Out: Ensemble Based Methods for Avoiding Known Dataset Biases". Christopher Clark, Mark Yatskar, Luke Zettlemoyer. In EMNLP 2019. 325 | 326 | --- 327 | 328 | Licensed under the Apache License, Version 2.0 (the "License"); 329 | you may not use this file except in compliance with the License. 330 | You may obtain a copy of the License at 331 | 332 | http://www.apache.org/licenses/LICENSE-2.0 333 | 334 | Unless required by applicable law or agreed to in writing, software 335 | distributed under the License is distributed on an "AS IS" BASIS, 336 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 337 | See the License for the specific language governing permissions and 338 | limitations under the License. 339 | 340 | ===== 341 | 342 | ===== 343 | --------------------------------------------------------------------------------