├── networks
├── __init__.py
├── model_factory.py
├── mlp.py
├── autoencoder.py
├── resnet_dropout.py
└── resnet.py
├── trainer
├── __init__.py
├── loss_utils.py
├── vanilla_train.py
├── hsic.py
├── fairhsic.py
├── mfd.py
├── trainer_factory.py
├── reweighting.py
└── adv_debiasing.py
├── data_handler
├── __init__.py
├── adult.py
├── compas.py
├── custom_loader.py
├── custom_loader_hsic.py
├── AIF360
│ ├── binary_label_dataset.py
│ ├── credit_dataset.py
│ ├── dataset.py
│ ├── bank_dataset.py
│ ├── compas_dataset.py
│ ├── adult_dataset.py
│ └── standard_dataset.py
├── fairface.py
├── dataloader_factory.py
├── tabular_dataset.py
├── utils.py
├── dataset_factory.py
├── utkface.py
├── utkface_fairface.py
├── celeba.py
└── ssl_dataset.py
├── LICENSE
├── utils.py
├── arguments.py
├── main.py
├── README.md
├── main_groupclf.py
└── NOTICE
/networks/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | cgl_fairness
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT license
5 | """
6 | from networks.model_factory import *
7 |
--------------------------------------------------------------------------------
/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | cgl_fairness
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT license
5 | """
6 | from trainer.trainer_factory import *
7 |
--------------------------------------------------------------------------------
/data_handler/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | cgl_fairness
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT license
5 | """
6 | from data_handler.dataloader_factory import *
7 | from data_handler.dataset_factory import *
8 | from data_handler.ssl_dataset import *
9 |
--------------------------------------------------------------------------------
/data_handler/adult.py:
--------------------------------------------------------------------------------
1 | """
2 | cgl_fairness
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT license
5 | """
6 | import pandas as pd
7 | from data_handler.AIF360.adult_dataset import AdultDataset
8 | from data_handler.tabular_dataset import TabularDataset
9 |
10 |
class AdultDataset_torch(TabularDataset):
    """Adult dataset exposed through the ``TabularDataset`` interface.

    ``target_attr`` chooses which column of the AIF360 feature matrix is
    treated as the sensitive attribute ('sex' or 'race').
    """
    name = 'adult'

    def __init__(self, root, target_attr='sex', **kwargs):
        # The raw dataset is loaded before the attribute is validated.
        raw_dataset = AdultDataset(root_dir=root)

        # Feature-matrix column used for each supported sensitive attribute.
        for attr_name, column in (('sex', 3), ('race', 2)):
            if target_attr == attr_name:
                sen_attr_idx = column
                break
        else:
            raise Exception('Not allowed group')

        # Set before the base constructor runs: TabularDataset.__init__ reads
        # both values when it counts samples per (group, class).
        self.num_groups = 2
        self.num_classes = 2

        super().__init__(root=root, dataset=raw_dataset, sen_attr_idx=sen_attr_idx, **kwargs)
29 |
30 |
--------------------------------------------------------------------------------
/data_handler/compas.py:
--------------------------------------------------------------------------------
1 | """
2 | cgl_fairness
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT license
5 | """
6 | import pandas as pd
7 | from data_handler.AIF360.compas_dataset import CompasDataset
8 | from data_handler.tabular_dataset import TabularDataset
9 |
10 |
class CompasDataset_torch(TabularDataset):
    """COMPAS dataset exposed through the ``TabularDataset`` interface.

    ``target_attr`` chooses which column of the AIF360 feature matrix is
    treated as the sensitive attribute ('sex' or 'race').
    """
    name = 'compas'

    def __init__(self, root, target_attr='race', **kwargs):
        # The raw dataset is loaded before the attribute is validated.
        raw_dataset = CompasDataset(root_dir=root)

        # Feature-matrix column used for each supported sensitive attribute.
        for attr_name, column in (('sex', 0), ('race', 2)):
            if target_attr == attr_name:
                sen_attr_idx = column
                break
        else:
            raise Exception('Not allowed group')

        # Set before the base constructor runs: TabularDataset.__init__ reads
        # both values when it counts samples per (group, class).
        self.num_groups = 2
        self.num_classes = 2

        super().__init__(root=root, dataset=raw_dataset, sen_attr_idx=sen_attr_idx, **kwargs)
29 |
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | cgl_fairness
2 | Copyright (c) 2022-present NAVER Corp.
3 |
4 | Permission is hereby granted, free of charge, to any person obtaining a copy
5 | of this software and associated documentation files (the "Software"), to deal
6 | in the Software without restriction, including without limitation the rights
7 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 | copies of the Software, and to permit persons to whom the Software is
9 | furnished to do so, subject to the following conditions:
10 |
11 | The above copyright notice and this permission notice shall be included in
12 | all copies or substantial portions of the Software.
13 |
14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 | THE SOFTWARE.
21 |
--------------------------------------------------------------------------------
/networks/model_factory.py:
--------------------------------------------------------------------------------
1 | """
2 | cgl_fairness
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT license
5 | """
6 | import torch.nn as nn
7 |
8 | from networks.resnet import resnet10, resnet12,resnet18, resnet34, resnet50, resnet101
9 | from networks.mlp import MLP
10 | from networks.resnet_dropout import resnet18_dropout
11 |
class ModelFactory():
    """Factory that builds the network named by ``target_model``."""

    def __init__(self):
        pass

    @staticmethod
    def get_model(target_model, num_classes=2, img_size=224, pretrained=False, num_groups=2):
        """Build and return a model.

        Args:
            target_model: 'mlp' or the name of one of the resnet builders
                imported by this module (e.g. 'resnet18', 'resnet18_dropout').
            num_classes: number of output units.
            img_size: forwarded to the builder; for 'mlp' it is used as the
                input feature size.
            pretrained: for resnets, build with ``pretrained=True`` and then
                replace the final ``fc`` layer with a fresh ``num_classes`` head.
            num_groups: forwarded to non-pretrained resnet builders.

        Raises:
            NotImplementedError: if ``target_model`` is not a supported name.
        """
        if target_model == 'mlp':
            return MLP(feature_size=img_size, hidden_dim=64, num_classes=num_classes)

        if 'resnet' not in target_model:
            raise NotImplementedError

        # Explicit lookup table instead of eval(): only the builders this module
        # imports can be selected, and no arbitrary expression is ever executed.
        resnet_builders = {
            'resnet10': resnet10,
            'resnet12': resnet12,
            'resnet18': resnet18,
            'resnet34': resnet34,
            'resnet50': resnet50,
            'resnet101': resnet101,
            'resnet18_dropout': resnet18_dropout,
        }
        model_class = resnet_builders.get(target_model)
        if model_class is None:
            raise NotImplementedError('Unknown resnet variant: {}'.format(target_model))

        if pretrained:
            model = model_class(pretrained=True, img_size=img_size)
            # Swap the pretrained head for one sized to this task.
            model.fc = nn.Linear(in_features=model.fc.weight.shape[1], out_features=num_classes, bias=True)
        else:
            model = model_class(pretrained=False, num_classes=num_classes, num_groups=num_groups, img_size=img_size)
        return model
33 |
34 |
--------------------------------------------------------------------------------
/data_handler/custom_loader.py:
--------------------------------------------------------------------------------
1 | """
2 | Original code:
3 | https://github.com/sangwon79/Fair-Feature-Distillation-for-Visual-Recognition
4 | """
5 | from copy import copy
6 | import numpy as np
7 | from torch.utils.data.sampler import RandomSampler
8 | import random
9 |
10 |
class Customsampler(RandomSampler):
    """Sampler that yields the same number of indices for every (group, label) cell.

    Each cell is oversampled (reshuffled and repeated) up to
    ``numdata_per_group`` indices, and the per-cell lists are interleaved so
    that consecutive indices cycle through the cells in (group, label) order.

    The data source must provide ``num_classes``, ``num_groups``, ``num_data``
    (a numpy array of per-cell counts) and ``idxs_per_group`` (a mapping from
    ``(group, label)`` to a sequence of dataset indices).
    """

    def __init__(self, data_source, replacement=False, num_samples=None, batch_size=None, generator=None):
        super(Customsampler, self).__init__(data_source=data_source, replacement=replacement,
                                            num_samples=num_samples, generator=generator)
        self.l = data_source.num_classes
        self.g = data_source.num_groups
        self.nbatch_size = batch_size // (self.l*self.g)
        self.num_data = data_source.num_data
        self.idxs_per_group = data_source.idxs_per_group

        # Position of the cell that holds the largest number of data points.
        self.max_pos = np.unravel_index(np.argmax(self.num_data), self.num_data.shape)
        # Smallest multiple of (nbatch_size + 1) strictly greater than the largest cell count.
        self.numdata_per_group = (self.num_data[self.max_pos] // (self.nbatch_size+1) + 1) * (self.nbatch_size+1)

    def _oversample(self, pool, cell):
        """Return ``numdata_per_group`` indices drawn from ``pool`` by repeated reshuffling."""
        if len(pool) == 0:
            # Without this guard the fill loop below would never terminate.
            raise ValueError('No samples available for (group, label) = {}'.format(cell))

        target = self.numdata_per_group
        drawn = []
        while len(drawn) < target:
            tmp = copy(pool)
            random.shuffle(tmp)
            # Take the whole shuffled pool, or just enough to reach the target.
            drawn.extend(tmp[:target - len(drawn)])
        return drawn

    def __iter__(self):
        """Iterate over the interleaved, cell-balanced index sequence.

        Raises:
            ValueError: if any (group, label) cell has no indices.
        """
        index_list = []
        for g in range(self.g):
            for l in range(self.l):
                index_list.append(self._oversample(self.idxs_per_group[(g, l)], (g, l)))

        # Column-major flatten interleaves the cells: one index per cell, repeated.
        final_list = np.array(index_list).flatten('F')
        return iter(list(final_list))
50 |
--------------------------------------------------------------------------------
/networks/mlp.py:
--------------------------------------------------------------------------------
1 | """
2 | cgl_fairness
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT license
5 | """
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.autograd import Function
10 |
class MLP(nn.Module):
    """Fully connected network with an optional gradient-reversal input layer.

    Args:
        feature_size: input width as an int, or an iterable of dimension sizes
            whose product is the flattened input width.
        hidden_dim: width of every hidden layer.
        num_classes: number of output units of the final linear head.
        num_layer: total number of linear layers including the head; with 1 the
            feature extractor is the identity and the head maps input to output.
        adv: if True, inputs pass through ``ReverseLayerF`` before the network.
        adv_lambda: scale applied to the reversed gradient when ``adv`` is set.
    """

    def __init__(self, feature_size, hidden_dim, num_classes=None, num_layer=3, adv=False, adv_lambda=1.):
        super(MLP, self).__init__()
        try:
            # Iterable of dimension sizes.
            in_features = self.compute_input_size(feature_size)
        except TypeError:
            # A plain int is not iterable: it already is the input width.
            in_features = feature_size

        self.adv = adv
        if self.adv:
            self.adv_lambda = adv_lambda
        self._make_layer(in_features, hidden_dim, num_classes, num_layer)

    def forward(self, feature, get_inter=False):
        """Return the logits, or ``(hidden, logits)`` when ``get_inter`` is True."""
        feature = torch.flatten(feature, 1)
        if self.adv:
            feature = ReverseLayerF.apply(feature, self.adv_lambda)

        h = self.features(feature)
        out = self.head(h)
        # NOTE(review): squeeze() drops every size-1 dimension, so a batch of
        # one sample also loses its batch axis - confirm callers expect that.
        out = out.squeeze()

        if get_inter:
            return h, out
        else:
            return out

    def compute_input_size(self, feature_size):
        """Return the product of the sizes in ``feature_size``."""
        in_features = 1
        for size in feature_size:
            in_features = in_features * size

        return in_features

    def _make_layer(self, in_dim, h_dim, num_classes, num_layer):
        """Create ``self.features`` (hidden stack) and ``self.head`` (output layer)."""
        if num_layer == 1:
            self.features = nn.Identity()
            # No hidden layers: the head consumes the raw input width.
            h_dim = in_dim
        else:
            features = []
            for i in range(num_layer-1):
                features.append(nn.Linear(in_dim, h_dim) if i == 0 else nn.Linear(h_dim, h_dim))
                features.append(nn.ReLU())
            self.features = nn.Sequential(*features)

        self.head = nn.Linear(h_dim, num_classes)
59 |
60 |
class ReverseLayerF(Function):
    """Gradient-reversal function: identity on the forward pass, negated and
    ``alpha``-scaled gradient on the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        # Remember the scale for the backward pass.
        ctx.alpha = alpha
        passthrough = x.view_as(x)
        return passthrough

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = grad_output.neg() * ctx.alpha
        # The second input (alpha) receives no gradient.
        return reversed_grad, None
72 |
--------------------------------------------------------------------------------
/data_handler/custom_loader_hsic.py:
--------------------------------------------------------------------------------
1 | """
2 | cgl_fairness
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT license
5 | """
6 | from copy import copy
7 | import numpy as np
8 | from torch.utils.data.sampler import RandomSampler
9 | import random
10 |
11 |
class Customsampler(RandomSampler):
    """Sampler that yields the same number of indices for every label.

    Indices of all groups sharing a label are pooled together. Each label pool
    is oversampled (reshuffled and repeated) up to ``numdata_per_group``
    indices, and the per-label lists are interleaved so that consecutive
    indices cycle through the labels.

    The data source must provide ``num_classes``, ``num_groups``, ``num_data``
    (a numpy array summed over axis 0 to get per-label counts) and
    ``idxs_per_group`` (a mapping from ``(group, label)`` to dataset indices).
    """

    def __init__(self, data_source, replacement=False, num_samples=None, batch_size=None, generator=None):
        super(Customsampler, self).__init__(data_source=data_source, replacement=replacement,
                                            num_samples=num_samples, generator=generator)
        self.l = data_source.num_classes
        self.g = data_source.num_groups
        self.nbatch_size = batch_size // (self.l*self.g)
        # Collapse the group axis: one count per label.
        self.num_data = np.sum(data_source.num_data, axis=0)

        # Merge the per-(group, label) index lists into one pool per label.
        self.idxs_per_group = data_source.idxs_per_group
        idxs_per_group = {}
        for l in range(self.l):
            idxs_per_group[l] = []
            for g in range(self.g):
                idxs_per_group[l].extend(self.idxs_per_group[(g, l)])

        self.idxs_per_group = idxs_per_group
        # Position of the label that holds the largest number of data points.
        self.max_pos = np.unravel_index(np.argmax(self.num_data), self.num_data.shape)

        # Smallest multiple of (nbatch_size + 1) strictly greater than the largest label count.
        self.numdata_per_group = (self.num_data[self.max_pos] // (self.nbatch_size+1) + 1) * (self.nbatch_size+1)

    def _oversample(self, pool, label):
        """Return ``numdata_per_group`` indices drawn from ``pool`` by repeated reshuffling."""
        if len(pool) == 0:
            # Without this guard the fill loop below would never terminate.
            raise ValueError('No samples available for label = {}'.format(label))

        target = self.numdata_per_group
        drawn = []
        while len(drawn) < target:
            tmp = copy(pool)
            random.shuffle(tmp)
            # Take the whole shuffled pool, or just enough to reach the target.
            drawn.extend(tmp[:target - len(drawn)])
        return drawn

    def __iter__(self):
        """Iterate over the interleaved, label-balanced index sequence.

        Raises:
            ValueError: if any label has no indices.
        """
        index_list = [self._oversample(self.idxs_per_group[l], l) for l in range(self.l)]

        # Column-major flatten interleaves the labels: one index per label, repeated.
        final_list = np.array(index_list).flatten('F')
        return iter(list(final_list))
57 |
--------------------------------------------------------------------------------
/data_handler/AIF360/binary_label_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Original code:
3 | https://github.com/Trusted-AI/AIF360
4 | """
5 | import numpy as np
6 |
7 | from data_handler.AIF360.structured_dataset import StructuredDataset
8 |
9 |
class BinaryLabelDataset(StructuredDataset):
    """Structured dataset whose labels take exactly two values."""

    def __init__(self, favorable_label=1., unfavorable_label=0., **kwargs):
        """Store the two label values, then defer to ``StructuredDataset``.

        Args:
            favorable_label (float): Label value regarded as favorable
                (the "positive" outcome).
            unfavorable_label (float): Label value regarded as unfavorable
                (the "negative" outcome).
            **kwargs: Forwarded unchanged to ``StructuredDataset``.
        """
        # Both values are assigned before the base constructor is called.
        self.favorable_label = float(favorable_label)
        self.unfavorable_label = float(unfavorable_label)

        super().__init__(**kwargs)

    def validate_dataset(self):
        """Run the base validation plus binary-label specific checks.

        Raises:
            ValueError: if ``labels`` has more than one column.
            ValueError: if ``labels`` contains a value other than
                ``favorable_label`` / ``unfavorable_label``.
        """
        # When scores are just a copy of the labels, rewrite them as 1.0 for
        # favorable and 0.0 otherwise before the base validation runs.
        scores_mirror_labels = np.all(self.scores == self.labels)
        if scores_mirror_labels:
            self.scores = (self.scores == self.favorable_label).astype(np.float64)

        super().validate_dataset()

        # Shape check: labels must be a single column.
        label_columns = self.labels.shape[1]
        if label_columns != 1:
            raise ValueError("BinaryLabelDataset only supports single-column "
                             "labels:\n\tlabels.shape = {}".format(self.labels.shape))

        # Value check: every label present must be one of the two declared values.
        allowed_values = set([self.favorable_label, self.unfavorable_label])
        observed_values = set(self.labels.ravel())
        if not observed_values <= allowed_values:
            raise ValueError("The favorable and unfavorable labels provided do "
                             "not match the labels in the dataset.")
52 |
--------------------------------------------------------------------------------
/trainer/loss_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Original code:
3 | https://github.com/sangwon79/Fair-Feature-Distillation-for-Visual-Recognition
4 | """
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
def mse(inputs, targets):
    """Return the mean of the element-wise squared difference."""
    squared_error = (inputs - targets).pow(2)
    return squared_error.mean()
11 |
12 |
def compute_feature_loss(inputs, t_inputs, student, teacher, device=0, regressor=None):
    """Feature-distillation loss between student and teacher intermediate features.

    Both models are called with ``get_inter=True``; the second-to-last element
    of each output is taken as the feature and the last element as the logits.

    Args:
        inputs: batch fed to the student.
        t_inputs: batch fed to the teacher.
        student, teacher: callables accepting ``get_inter=True``.
        device: device the teacher feature is moved to.
        regressor: optional module applied to the student feature first.

    Returns:
        ``(loss, stu_logits, tea_logits, f_s, f_t)`` where ``loss`` is half the
        MSE between the flattened features and ``f_t`` is detached.
    """
    # Student side.
    student_out = student(inputs, get_inter=True)
    stu_logits = student_out[-1]
    f_s = student_out[-2]
    if regressor is not None:
        f_s = regressor.forward(f_s)
    f_s = f_s.view(f_s.shape[0], -1)

    # Teacher side: the feature is detached so no gradient reaches the teacher.
    teacher_out = teacher(t_inputs, get_inter=True)
    tea_logits = teacher_out[-1]
    f_t = teacher_out[-2].to(device)
    f_t = f_t.view(f_t.shape[0], -1).detach()

    fitnet_loss = (1 / 2) * mse(f_s, f_t)

    return fitnet_loss, stu_logits, tea_logits, f_s, f_t
31 |
32 |
def compute_hinton_loss(outputs, t_outputs=None, teacher=None, t_inputs=None, kd_temp=3, device=0):
    """Temperature-scaled knowledge-distillation (KL) loss.

    Args:
        outputs: student logits.
        t_outputs: teacher logits; if None they are computed as
            ``teacher(t_inputs)``.
        teacher: teacher model, used only when ``t_outputs`` is None.
        t_inputs: teacher inputs, used only when ``t_outputs`` is None.
        kd_temp: softmax temperature.
        device: device the soft labels are moved to.

    Returns:
        KL divergence (``batchmean`` reduction) between the tempered student
        and teacher distributions, multiplied by ``kd_temp ** 2``.

    Raises:
        ValueError: if neither ``t_outputs`` nor both ``teacher`` and
            ``t_inputs`` are given.
    """
    if t_outputs is None:
        if t_inputs is None or teacher is None:
            # Previously the exception was built but never raised, so execution
            # fell through and failed on ``None / kd_temp`` instead.
            raise ValueError('Nothing is given to compute hinton loss')
        t_outputs = teacher(t_inputs)

    # Soft labels are detached so no gradient flows into the teacher.
    soft_label = F.softmax(t_outputs / kd_temp, dim=1).to(device).detach()
    kd_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(outputs / kd_temp, dim=1),
                                                  soft_label) * (kd_temp * kd_temp)

    return kd_loss
45 |
46 |
def compute_at_loss(inputs, t_inputs, student, teacher, device=0, for_cifar=False):
    """Attention-transfer loss between student and teacher intermediate features.

    Both models are called with ``get_inter=True`` (plus ``before_fc=True``
    when ``for_cifar`` is set); the second-to-last output is the feature and
    the last output is the logits.

    Returns:
        ``(loss, stu_logits, tea_logits, f_s, f_t)`` where ``loss`` is half of
        ``at_loss(f_s, f_t)``.
    """
    if for_cifar:
        stu_outputs = student(inputs, get_inter=True, before_fc=True)
        # NOTE(review): the teacher is fed `inputs` here while the non-cifar
        # path feeds it `t_inputs` - confirm this asymmetry is intended.
        tea_outputs = teacher(inputs, get_inter=True, before_fc=True)
    else:
        stu_outputs = student(inputs, get_inter=True)
        tea_outputs = teacher(t_inputs, get_inter=True)

    f_s = stu_outputs[-2]
    stu_logits = stu_outputs[-1]

    f_t = tea_outputs[-2].to(device)
    tea_logits = tea_outputs[-1].to(device)

    attention_loss = (1 / 2) * at_loss(f_s, f_t)
    return attention_loss, stu_logits, tea_logits, f_s, f_t
57 |
58 |
def at(x):
    """Return the normalized attention map of ``x``.

    Squares ``x``, averages over dimension 1, flattens everything after the
    first dimension and normalizes each row with ``F.normalize``.
    """
    energy = x.pow(2).mean(1)
    flattened = energy.view(x.size(0), -1)
    return F.normalize(flattened)
61 |
62 |
def at_loss(x, y):
    """Return the mean squared difference between the attention maps of ``x`` and ``y``."""
    attention_gap = at(x) - at(y)
    return attention_gap.pow(2).mean()
65 |
--------------------------------------------------------------------------------
/data_handler/AIF360/credit_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Original code:
3 | https://github.com/Trusted-AI/AIF360
4 | """
5 | import os
6 |
7 | import pandas as pd
8 |
9 | from data_handler.AIF360.standard_dataset import StandardDataset
10 |
11 |
class CreditDataset(StandardDataset):
    """Credit dataset read from ``credit.csv`` under ``root_dir``.

    NOTE(review): the previous docstring described the Bank marketing dataset
    and looked copy-pasted; the defaults here (label 'default payment next
    month', protected attribute 'SEX') belong to a credit-default table.
    """

    def __init__(self, root_dir='./data/credit', label_name='default payment next month', favorable_classes=[1],
                 protected_attribute_names=['SEX'],
                 privileged_classes=[lambda x: x == 1],
                 instance_weights_name=None,
                 categorical_features=['EDUCATION', 'MARRIAGE', 'AGE', 'PAY_0', 'PAY_2',
                                       'PAY_3', 'PAY_4', 'PAY_5', 'PAY_6'],
                 features_to_keep=[], features_to_drop=[],
                 na_values=[], custom_preprocessing=None,
                 metadata=None):
        """Load the CSV and forward everything to ``StandardDataset``.

        See :obj:`StandardDataset` for a description of the arguments. If the
        CSV cannot be read, the error is printed and the process exits with
        status 1.
        """
        csv_path = os.path.join(root_dir, 'credit.csv')

        try:
            frame = pd.read_csv(csv_path)
        except IOError as err:
            print("IOError: {}".format(err))
            import sys
            sys.exit(1)

        standard_kwargs = dict(
            df=frame,
            label_name=label_name,
            favorable_classes=favorable_classes,
            protected_attribute_names=protected_attribute_names,
            privileged_classes=privileged_classes,
            instance_weights_name=instance_weights_name,
            categorical_features=categorical_features,
            features_to_keep=features_to_keep,
            features_to_drop=features_to_drop,
            na_values=na_values,
            custom_preprocessing=custom_preprocessing,
            metadata=metadata,
        )
        super().__init__(**standard_kwargs)
50 |
--------------------------------------------------------------------------------
/data_handler/fairface.py:
--------------------------------------------------------------------------------
1 | """
2 | cgl_fairness
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT license
5 | """
6 | import os
7 | from os.path import join
8 | from data_handler import GenericDataset
9 | import pandas
10 | from torchvision import transforms
11 | from functools import partial
12 | from data_handler.utils import get_mean_std
13 |
14 |
class Fairface(GenericDataset):
    """FairFace image dataset with 'race' as the sensitive attribute.

    Labels are read from ``fairface_label_<split>.csv`` under ``root`` and
    converted into rows of ``[race, age, image_path]``.
    """
    # FairFace has no entry of its own in get_mean_std, so the UTKFace
    # statistics are reused. The key was misspelled 'uktface', which made
    # get_mean_std return (None, None) and left Normalize without mean/std.
    mean, std = get_mean_std('utkface')
    train_transform = transforms.Compose(
        [transforms.Resize((256, 256)),
         transforms.RandomCrop(224),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         transforms.Normalize(mean=mean, std=std)]
    )
    test_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize(mean=mean, std=std)]
    )
    name = 'fairface'

    def __init__(self, target_attr='age', root='fairface', split='train', seed=0):
        """Select the transform for ``split`` and load the label file."""
        transform = self.train_transform if split == 'train' else self.test_transform
        GenericDataset.__init__(self, root=root, split=split, seed=seed, transform=transform)

        self.sensitive_attr = 'race'
        self.target_attr = target_attr

        fn = partial(join, self.root)
        label_file = "fairface_label_{}.csv".format(split)
        label_mat = pandas.read_csv(fn(label_file))
        self.feature_mat = self._preprocessing(label_mat)

    def _preprocessing(self, label_mat):
        """Map raw label rows to ``[race, age, image_path]`` entries.

        Column 0 of each row is the image path relative to ``self.root``,
        column 1 the age bucket and column 3 the race name. Rows whose race is
        not in ``race_dict`` are dropped.
        """
        # Race names are merged into four groups.
        race_dict = {
            'White': 0,
            'Middle Eastern': 0,
            'Black': 1,
            'East Asian': 2,
            'Southeast Asian': 2,
            'Indian': 3
        }

        # Age buckets are merged into three classes.
        age_dict = {
            '0-2': 0,
            '3-9': 0,
            '10-19': 0,
            '20-29': 1,
            '30-39': 1,
            '40-49': 2,
            '50-59': 2,
            '60-69': 2,
            'more than 70': 2,
        }

        race_idx = 3
        age_idx = 1

        feature_mat = []
        for row in label_mat.values:
            _age = row[age_idx]
            _race = row[race_idx]
            if _race not in race_dict:
                continue
            race = race_dict[_race]
            age = age_dict[_age]
            feature_mat.append([race, age, os.path.join(self.root, row[0])])

        return feature_mat
79 |
--------------------------------------------------------------------------------
/data_handler/dataloader_factory.py:
--------------------------------------------------------------------------------
1 | """
2 | Original code:
3 | https://github.com/sangwon79/Fair-Feature-Distillation-for-Visual-Recognition
4 | """
5 | from data_handler.dataset_factory import DatasetFactory
6 |
7 | import numpy as np
8 | from torch.utils.data import DataLoader
9 |
10 |
class DataloaderFactory:
    """Factory that builds the train and test ``DataLoader`` pair for a dataset."""

    def __init__(self):
        pass

    @staticmethod
    def get_dataloader(name, batch_size=256, seed=0, num_workers=4,
                       target_attr='Attractive', add_attr=None, labelwise=False, sv_ratio=1, version='', args=None):
        """Create the datasets for ``name`` and wrap them in data loaders.

        Args:
            name: dataset name passed to ``DatasetFactory.get_dataset``.
            batch_size: batch size of the training loader (the test loader
                always uses 256).
            seed: forwarded to the datasets and used to seed numpy in workers.
            num_workers: number of worker processes for both loaders.
            target_attr: forwarded to the dataset; overridden to 'sex' for
                'adult' and 'race' for 'compas'.
            add_attr, sv_ratio, version: forwarded to the dataset factory.
            labelwise: if True, train with a balanced ``Customsampler``
                instead of plain shuffling.
            args: optional namespace; only ``args.method`` is read, to choose
                the HSIC variant of the sampler when it equals 'fairhsic'.

        Returns:
            ``(num_classes, num_groups, train_dataloader, test_dataloader)``.
        """
        if name == 'adult':
            target_attr = 'sex'
        elif name == 'compas':
            target_attr = 'race'

        test_dataset = DatasetFactory.get_dataset(name, split='test', sv_ratio=sv_ratio, version=version,
                                                  target_attr=target_attr, seed=seed, add_attr=add_attr)
        train_dataset = DatasetFactory.get_dataset(name, split='train', sv_ratio=sv_ratio, version=version,
                                                   target_attr=target_attr, seed=seed, add_attr=add_attr)

        def _init_fn(worker_id):
            # Every worker is seeded with the same value.
            np.random.seed(int(seed))

        num_classes = test_dataset.num_classes
        num_groups = test_dataset.num_groups

        shuffle = True
        sampler = None
        if labelwise:
            # args defaults to None; fall back to the generic sampler rather
            # than failing with AttributeError on ``args.method``.
            method = getattr(args, 'method', None)
            if method == 'fairhsic':
                from data_handler.custom_loader_hsic import Customsampler
            else:
                from data_handler.custom_loader import Customsampler
            sampler = Customsampler(train_dataset, replacement=False, batch_size=batch_size)

            # DataLoader does not accept shuffle=True together with a sampler.
            shuffle = False

        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
                                      num_workers=num_workers, worker_init_fn=_init_fn, pin_memory=True, drop_last=False)

        test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False,
                                     num_workers=num_workers, worker_init_fn=_init_fn, pin_memory=True)

        print('# of test data : {}'.format(len(test_dataset)))
        print('# of train data : {}'.format(len(train_dataset)))
        print('Dataset loaded.')
        print('# of classes, # of groups : {}, {}'.format(num_classes, num_groups))

        return num_classes, num_groups, train_dataloader, test_dataloader
58 |
--------------------------------------------------------------------------------
/data_handler/AIF360/dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Original code:
3 | https://github.com/Trusted-AI/AIF360
4 | """
5 | from abc import ABC, abstractmethod
6 | import copy
7 |
8 |
class Dataset(ABC):
    """Abstract dataset base that records its provenance in ``metadata``."""

    @abstractmethod
    def __init__(self, **kwargs):
        """Initialise ``metadata`` and validate the dataset.

        The optional ``metadata`` keyword is removed from ``kwargs`` and used
        as the starting dictionary; the remaining keywords are recorded under
        ``'params'``.
        """
        supplied = kwargs.pop('metadata', dict())
        self.metadata = supplied or dict()
        self.metadata.update(
            transformer='{}.__init__'.format(type(self).__name__),
            params=kwargs,
            previous=[],
        )
        self.validate_dataset()

    def validate_dataset(self):
        """Hook for error checking and type validation; a no-op here."""
        pass

    def copy(self, deepcopy=False):
        """Return a copy of this dataset with updated provenance metadata.

        Args:
            deepcopy (bool, optional): use :func:`~copy.deepcopy` when `True`,
                otherwise make a shallow copy.
        Returns:
            Dataset: the copy; its ``metadata`` names this dataset as
            ``'previous'``.
        """
        if deepcopy:
            duplicate = copy.deepcopy(self)
        else:
            duplicate = copy.copy(self)
        # Give the copy its own metadata dict, keeping any user-created fields.
        duplicate.metadata = duplicate.metadata.copy()
        duplicate.metadata.update(
            transformer='{}.copy'.format(type(self).__name__),
            params={'deepcopy': deepcopy},
            previous=[self],
        )
        return duplicate

    @abstractmethod
    def export_dataset(self):
        """Save this Dataset to disk."""
        raise NotImplementedError

    @abstractmethod
    def split(self, num_or_size_splits, shuffle=False):
        """Split this dataset into multiple partitions.

        Args:
            num_or_size_splits (array or int): An int *k* requests *k* folds
                of equal size (approximately equal when *k* does not divide
                the dataset evenly). An array of ints gives the indices at
                which to split. An array of floats (< 1.) gives the split
                points as fractions of the dataset.
            shuffle (bool, optional): Randomly shuffle the dataset before
                splitting.
        Returns:
            list(Dataset): The splits: *k* datasets, or
            `len(num_or_size_splits) + 1` datasets for an array argument.
        """
        raise NotImplementedError
69 |
--------------------------------------------------------------------------------
/data_handler/tabular_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | cgl_fairness
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT license
5 | """
6 | import numpy as np
7 | import random
8 | from data_handler import SSLDataset
9 |
10 |
class TabularDataset(SSLDataset):
    """Tabular dataset built from a dataset object exposing ``split``, ``features`` and ``labels``.

    Subclasses are expected to set ``self.num_groups`` and ``self.num_classes``
    before calling this constructor, because both are read here.
    """
    # Column layout of every row in ``self.features``:
    # index 0  -> sensitive attribute (group)
    # index 1  -> label
    # index 2+ -> standardized features (sensitive column removed)
    def __init__(self, dataset, sen_attr_idx, **kwargs):
        # presumably SSLDataset.__init__ sets self.split / self.version /
        # self.sv_ratio / self.seed from kwargs - verify against SSLDataset
        super(TabularDataset, self).__init__(**kwargs)
        self.sen_attr_idx = sen_attr_idx

        # Fixed 80/20 split with a constant seed, independent of self.seed.
        dataset_train, dataset_test = dataset.split([0.8], shuffle=True, seed=0)
        # features, labels = self._balance_test_set(dataset)
        # The train partition is used for split == 'train' and for every split
        # when 'group' appears in self.version; otherwise the test partition.
        self.dataset = dataset_train if (self.split == 'train') or ('group' in self.version) else dataset_test

        # Drop the sensitive column, then standardize with statistics computed
        # on the train partition only.
        features = np.delete(self.dataset.features, self.sen_attr_idx, axis=1)
        mean, std = self._get_mean_n_std(dataset_train.features)
        features = (features - mean) / std

        self.groups = np.expand_dims(self.dataset.features[:, self.sen_attr_idx], axis=1)
        self.labels = np.squeeze(self.dataset.labels)

        # self.features = self.dataset.features
        # Rows become [group, label, standardized features...].
        self.features = np.concatenate((self.groups, self.dataset.labels, features), axis=1)

        # Per-cell sample counts and index lists.
        # presumably _data_count is inherited from SSLDataset - verify
        self.num_data, self.idxs_per_group = self._data_count(self.features, self.num_groups, self.num_classes)

        # if semi-supervised learning,
        if self.sv_ratio < 1:
            # we want the different supervision according to the seed
            random.seed(self.seed)
            self.features, self.num_data, self.idxs_per_group = self.ssl_processing(self.features, self.num_data, self.idxs_per_group, )
        if 'group' in self.version:
            # In 'group' versions the roles of group and class are exchanged,
            # so their counts are swapped as well.
            a, b = self.num_groups, self.num_classes
            self.num_groups, self.num_classes = b, a

    def get_dim(self):
        """Return the number of columns of the underlying ``dataset.features``."""
        # NOTE(review): this count still includes the sensitive column, whereas
        # the feature vector returned by __getitem__ has it removed - confirm
        # which width callers expect.
        return self.dataset.features.shape[-1]

    def __getitem__(self, idx):
        """Return ``(feature, 0, group, label, (idx, 0))`` for row ``idx``.

        When 'group' is in ``self.version`` the third and fourth entries are
        exchanged: the label comes third and the group, cast to int64, fourth.
        """
        features = self.features[idx]
        group = features[0]
        label = features[1]
        feature = features[2:]

        if 'group' in self.version:
            return np.float32(feature), 0, label, np.int64(group), (idx, 0)
        else:
            return np.float32(feature), 0, group, np.int64(label), (idx, 0)

    def _get_mean_n_std(self, train_features):
        """Return per-column mean and std of ``train_features`` without the sensitive column."""
        features = np.delete(train_features, self.sen_attr_idx, axis=1)
        mean = features.mean(axis=0)
        std = features.std(axis=0)
        # Constant columns get a tiny std so the later division is defined.
        std[std == 0] += 1e-7
        return mean, std
66 |
--------------------------------------------------------------------------------
/data_handler/AIF360/bank_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Original code:
3 | https://github.com/Trusted-AI/AIF360
4 | """
5 | import os
6 |
7 | import pandas as pd
8 |
9 | from data_handler.AIF360.standard_dataset import StandardDataset
10 |
11 |
class BankDataset(StandardDataset):
    """Bank marketing dataset read from ``bank-additional-full.csv``.

    See :file:`aif360/data/raw/bank/README.md`.
    """

    def __init__(self, root_dir='./data/bank', label_name='y', favorable_classes=['yes'],
                 protected_attribute_names=['age'],
                 privileged_classes=[lambda x: (x >= 25 and x <= 60)],
                 instance_weights_name=None,
                 categorical_features=['job', 'marital', 'education', 'default',
                                       'housing', 'loan', 'contact', 'month', 'day_of_week',
                                       'poutcome'],
                 features_to_keep=[], features_to_drop=[],
                 na_values=["unknown"], custom_preprocessing=None,
                 metadata=None):
        """Load the CSV and forward everything to ``StandardDataset``.

        See :obj:`StandardDataset` for a description of the arguments. With
        the defaults, the protected attribute is 'age' and ages from 25 to 60
        inclusive are privileged. If the CSV cannot be read, download
        instructions are printed and the process exits with status 1.

        NOTE(review): the previous docstring mentioned converting 'marital'
        to a binary value; nothing in this constructor does that - confirm
        whether it happens elsewhere.
        """
        csv_path = os.path.join(root_dir, 'bank-additional-full.csv')

        try:
            frame = pd.read_csv(csv_path, sep=';', na_values=na_values)
        except IOError as err:
            print("IOError: {}".format(err))
            print("To use this class, please download the following file:")
            print("\n\thttps://archive.ics.uci.edu/ml/machine-learning-databases/00222/bank-additional.zip")
            print("\nunzip it and place the files, as-is, in the folder:")
            print("\n\t{}\n".format(root_dir))
            import sys
            sys.exit(1)

        standard_kwargs = dict(
            df=frame,
            label_name=label_name,
            favorable_classes=favorable_classes,
            protected_attribute_names=protected_attribute_names,
            privileged_classes=privileged_classes,
            instance_weights_name=instance_weights_name,
            categorical_features=categorical_features,
            features_to_keep=features_to_keep,
            features_to_drop=features_to_drop,
            na_values=na_values,
            custom_preprocessing=custom_preprocessing,
            metadata=metadata,
        )
        super().__init__(**standard_kwargs)
55 |
--------------------------------------------------------------------------------
/data_handler/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | cgl_fairness
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT license
5 | """
6 | import torch
7 | from torch.utils.data import DataLoader
8 | import numpy as np
9 |
10 |
def get_mean_std(dataset, skew_ratio=0.8):
    """Return the per-channel ``(mean, std)`` lists for a known dataset name.

    ``skew_ratio`` is accepted but not used. Unknown names give
    ``(None, None)``.
    """
    # Rebuilt on every call so callers always receive fresh list objects.
    known_stats = (
        ('utkface', [0.5960, 0.4573, 0.3921], [0.2586, 0.2314, 0.2275]),
        # for skew 0.8
        ('cifar10s', [0.4871, 0.4811, 0.4632], [0.2431, 0.2414, 0.2506]),
        # default target is 'Attractive'
        ('celeba', [0.5063, 0.4258, 0.3832], [0.3107, 0.2904, 0.2897]),
        ('imagenet', [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    )
    for known_name, mean, std in known_stats:
        if dataset == known_name:
            return mean, std
    return None, None
33 |
34 |
def predict_group(model, loader, args):
    """Overwrite missing entries of ``loader.dataset.features`` with model predictions.

    A checkpoint selected by ``args.slversion`` (3 or 5), ``args.seed`` and
    ``args.sv`` is loaded into ``model``. The dataset behind ``loader`` is then
    scanned in order; for every sample whose third batch field equals ``-1``
    the arg-max prediction of ``model`` is written to ``features[idx][1]``.

    Args:
        model: network to load the checkpoint into; moved to
            ``cuda:{args.device}``.
        loader: data loader whose ``dataset`` exposes ``features`` and yields
            ``(inputs, _, groups, labels, (idxs, _))`` batches.
        args: namespace providing ``device``, ``slversion``, ``seed``, ``sv``,
            ``batch_size``, ``num_workers``, ``cuda``, ``term`` and
            ``labelwise``.

    Returns:
        A new ``DataLoader`` driven by ``Customsampler`` when
        ``args.labelwise`` is truthy, otherwise ``None``.

    Raises:
        ValueError: if ``args.slversion`` is neither 3 nor 5 (previously this
            surfaced as an ``UnboundLocalError`` on ``filename``).
    """
    checkpoint_templates = {
        3: 'trained_models/group_clf/utkface/scratch/resnet18_seed{}_epochs70_bs128_lr0.001_sv{}_version0.0.pt',
        5: 'trained_models/group_clf_pretrain/utkface/scratch/resnet18_seed{}_epochs70_bs128_lr0.001_sv{}_version0.0.pt',
    }
    # Validate before touching the GPU so a bad configuration fails fast.
    if args.slversion not in checkpoint_templates:
        raise ValueError(
            'Unsupported slversion {!r}; expected one of {}'.format(
                args.slversion, sorted(checkpoint_templates)))
    path = checkpoint_templates[args.slversion].format(str(args.seed), str(args.sv))

    device = torch.device('cuda:{}'.format(args.device))
    model.cuda(device)
    model.load_state_dict(torch.load(path, map_location=device))

    features = loader.dataset.features

    # Sequential (non-shuffled) pass so every sample is visited exactly once.
    dataloader = DataLoader(loader.dataset, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.num_workers, pin_memory=True, drop_last=False)
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(dataloader):
            inputs, _, groups, labels, (idxs, _) = data
            missing = groups == -1
            if missing.sum() == 0:
                continue

            # Select on the CPU first; only the inputs need to reach the GPU.
            inputs = inputs[missing]
            idxs = idxs[missing]
            if args.cuda:
                # Same device as the model, not whatever the default device is.
                inputs = inputs.cuda(device)

            outputs = model(inputs)
            # One device-to-host transfer per batch instead of one per element.
            preds = torch.argmax(outputs, 1).cpu()
            for j, idx in enumerate(idxs.cpu().numpy()):
                features[idx][1] = preds[j]
            if i % args.term == 0:
                print('[{}] in group prediction'.format(i))

    if not args.labelwise:
        return None

    # Refresh the per-group statistics now that the features were rewritten.
    loader.dataset.num_data, loader.dataset.idxs_per_group = loader.dataset._data_count()
    from data_handler.custom_loader import Customsampler

    def _init_fn(worker_id):
        np.random.seed(int(args.seed))

    sampler = Customsampler(loader.dataset, replacement=False, batch_size=args.batch_size)
    train_dataloader = DataLoader(loader.dataset, batch_size=args.batch_size, sampler=sampler,
                                  num_workers=args.num_workers, worker_init_fn=_init_fn,
                                  pin_memory=True, drop_last=True)
    return train_dataloader
84 |
--------------------------------------------------------------------------------
/trainer/vanilla_train.py:
--------------------------------------------------------------------------------
1 | """
2 | cgl_fairness
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT license
5 | """
6 | from __future__ import print_function
7 |
8 | import time
9 | from utils import get_accuracy
10 | import trainer
11 |
12 |
class Trainer(trainer.GenericTrainer):
    """Plain supervised trainer: minimises the criterion on (inputs, labels) only."""

    def __init__(self, args, **kwargs):
        super().__init__(args=args, **kwargs)

    def train(self, train_loader, test_loader, epochs, criterion=None, writer=None):
        """Train for ``epochs`` epochs, evaluating on ``test_loader`` after each one.

        Args:
            train_loader: loader yielding ``(inputs, _, groups, targets, _)``.
            test_loader: loader passed to ``self.evaluate`` after every epoch.
            epochs: number of epochs to run.
            criterion: loss to optimise; defaults to ``self.criterion``.
            writer: summary writer used when ``self.record`` is truthy
                (``add_scalar`` / ``add_scalars`` are called on it).
        """
        if criterion is None:
            criterion = self.criterion
        model = self.model
        model.train()

        for epoch in range(epochs):
            self._train_epoch(epoch, train_loader, model, criterion)

            eval_start_time = time.time()
            eval_loss, eval_acc, eval_deom, eval_deoa, eval_subgroup_acc = self.evaluate(self.model, test_loader, criterion)
            eval_end_time = time.time()
            print('[{}/{}] Method: {} '
                  'Test Loss: {:.3f} Test Acc: {:.2f} Test DEOM {:.2f} [{:.2f} s]'.format
                  (epoch + 1, epochs, self.method,
                   eval_loss, eval_acc, eval_deom, (eval_end_time - eval_start_time)))
            if self.record:
                train_loss, train_acc, train_deom, train_deoa, train_subgroup_acc = self.evaluate(self.model, train_loader, self.criterion)
                writer.add_scalar('train_loss', train_loss, epoch)
                writer.add_scalar('train_acc', train_acc, epoch)
                writer.add_scalar('train_deom', train_deom, epoch)
                writer.add_scalar('train_deoa', train_deoa, epoch)
                writer.add_scalar('eval_loss', eval_loss, epoch)
                writer.add_scalar('eval_acc', eval_acc, epoch)
                writer.add_scalar('eval_deom', eval_deom, epoch)
                writer.add_scalar('eval_deoa', eval_deoa, epoch)

                # These names were previously undefined here (NameError); take
                # them from the dataset, as the fairhsic trainer does.
                num_groups = train_loader.dataset.num_groups
                num_classes = train_loader.dataset.num_classes

                eval_contents = {}
                train_contents = {}
                for g in range(num_groups):
                    for l in range(num_classes):
                        eval_contents[f'g{g},l{l}'] = eval_subgroup_acc[g, l]
                        train_contents[f'g{g},l{l}'] = train_subgroup_acc[g, l]
                writer.add_scalars('eval_subgroup_acc', eval_contents, epoch)
                writer.add_scalars('train_subgroup_acc', train_contents, epoch)

            # Step only when a scheduler exists; plateau-style schedulers need
            # the monitored metric, the others take no argument.
            if self.scheduler is not None:
                if 'Reduce' in type(self.scheduler).__name__:
                    self.scheduler.step(eval_loss)
                else:
                    self.scheduler.step()
        print('Training Finished!')

    def _train_epoch(self, epoch, train_loader, model, criterion=None):
        """Run one optimisation pass over ``train_loader``.

        Prints the running loss/accuracy averaged over every ``self.term``
        mini-batches. ``criterion`` falls back to ``self.criterion`` when None.
        """
        model.train()

        running_acc = 0.0
        running_loss = 0.0
        batch_start_time = time.time()
        for i, data in enumerate(train_loader):
            # Only inputs and targets are used by the vanilla objective.
            inputs, _, groups, targets, _ = data
            labels = targets

            if self.cuda:
                inputs = inputs.cuda(device=self.device)
                labels = labels.cuda(device=self.device)
            outputs = model(inputs)
            if criterion is not None:
                loss = criterion(outputs, labels)
            else:
                loss = self.criterion(outputs, labels)
            running_loss += loss.item()
            running_acc += get_accuracy(outputs, labels)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if i % self.term == self.term-1:  # print every self.term mini-batches
                avg_batch_time = time.time()-batch_start_time
                print('[{}/{}, {:5d}] Method: {} Train Loss: {:.3f} Train Acc: {:.2f} '
                      '[{:.2f} s/batch]'.format
                      (epoch + 1, self.epochs, i+1, self.method, running_loss / self.term, running_acc / self.term,
                       avg_batch_time/self.term))

                running_loss = 0.0
                running_acc = 0.0
                batch_start_time = time.time()
96 |
--------------------------------------------------------------------------------
/data_handler/dataset_factory.py:
--------------------------------------------------------------------------------
1 | """
2 | cgl_fairness
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT license
5 | """
6 | import importlib
7 | import torch.utils.data as data
8 | import numpy as np
9 | from collections import defaultdict
10 |
# Registry of supported datasets: name -> [module path, class name].
# DatasetFactory.get_dataset imports the module and looks the class up by name.
dataset_dict = {
    'utkface': ['data_handler.utkface', 'UTKFaceDataset'],
    'utkface_fairface': ['data_handler.utkface_fairface', 'UTKFaceFairface_Dataset'],
    'celeba': ['data_handler.celeba', 'CelebA'],
    'adult': ['data_handler.adult', 'AdultDataset_torch'],
    'compas': ['data_handler.compas', 'CompasDataset_torch'],
    'cifar100s': ['data_handler.cifar100s', 'CIFAR_100S'],
}
19 |
20 |
class DatasetFactory:
    """Factory that instantiates the dataset class registered under a name."""

    def __init__(self):
        pass

    @staticmethod
    def get_dataset(name, split='Train', seed=0, sv_ratio=1, version=1, target_attr='Attractive', add_attr=None, ups_iter=0):
        """Import and build the dataset registered as ``name`` in ``dataset_dict``.

        Args:
            name: key of ``dataset_dict``.
            split, seed, sv_ratio, version, ups_iter: forwarded unchanged to
                the dataset constructor.
            target_attr: forwarded only for 'celeba', 'adult' and 'compas'.
            add_attr: forwarded only for 'celeba'.

        Returns:
            An instance of the registered dataset class.

        Raises:
            ValueError: if ``name`` is not a registered dataset.

        NOTE(review): the default ``split='Train'`` is capitalised while
        UTKFaceDataset compares the split against 'train' - confirm that
        callers always pass ``split`` explicitly.
        """
        if name not in dataset_dict:
            # ValueError is still an Exception, so existing handlers keep working.
            raise ValueError(
                f"Unknown dataset '{name}'; expected one of {sorted(dataset_dict)}")

        # 'utkface_fairface' reuses the files stored under the utkface folder.
        root = f'./data/{name}' if name != 'utkface_fairface' else './data/utkface'
        kwargs = {
            'root': root,
            'split': split,
            'seed': seed,
            'sv_ratio': sv_ratio,
            'version': version,
            'ups_iter': ups_iter
        }

        if name == 'celeba':
            kwargs['add_attr'] = add_attr
        if name in ('celeba', 'adult', 'compas'):
            kwargs['target_attr'] = target_attr

        module_path, class_name = dataset_dict[name]
        module = importlib.import_module(module_path)
        class_ = getattr(module, class_name)
        return class_(**kwargs)
51 |
52 |
class GenericDataset(data.Dataset):
    """Base dataset holding ``features`` records that start with (group, label).

    Subclasses are expected to populate ``self.features`` and ``self.num_data``
    (a ``num_groups x num_classes`` count matrix, see ``_data_count``).
    """

    def __init__(self, root, split='train', transform=None, seed=0):
        self.root = root
        self.split = split
        self.transform = transform
        self.seed = seed
        self.num_data = None

    def __len__(self):
        # Cast so ``len()`` always receives a plain Python int.
        return int(np.sum(self.num_data))

    def _data_count(self, features, num_groups, num_classes):
        """Count samples per (group, label) cell and collect their indices.

        Returns:
            ``(data_count, idxs_per_group)``: an int array of shape
            ``(num_groups, num_classes)`` and a dict mapping ``(group, label)``
            to the list of positions in ``features``.
        """
        idxs_per_group = defaultdict(list)
        data_count = np.zeros((num_groups, num_classes), dtype=int)

        for idx, i in enumerate(features):
            s, l = int(i[0]), int(i[1])
            data_count[s, l] += 1
            idxs_per_group[(s, l)].append(idx)

        print(f'mode : {self.split}')
        for i in range(num_groups):
            print('# of %d group data : ' % i, data_count[i, :])
        return data_count, idxs_per_group

    def _make_data(self, features, num_groups, num_classes, min_cnt=100):
        """Split ``features`` into train / test when no official split exists.

        Walking ``features`` from the end, the first ``min_cnt`` records met
        for every (group, label) cell go to the test set; the rest stay in the
        training set in their original order.

        The split is computed in a single pass over ``features`` itself and
        applied afterwards. The earlier version iterated ``self.features``
        while calling ``features.remove(...)``, which mutates the list being
        iterated and removes by equality in O(n) per record.

        Args:
            features: list of records ``[group, label, ...]``; it is updated
                in place to hold only the training records.
            num_groups: number of groups (rows of the count matrix).
            num_classes: number of classes (columns of the count matrix).
            min_cnt: test samples reserved per (group, label) cell.

        Returns:
            ``(train_data, test_data)``; ``train_data`` is the same list
            object as ``features``, ``test_data`` is in reverse encounter order.
        """
        data_count = np.zeros((num_groups, num_classes), dtype=int)
        kept_reversed = []
        test_data = []
        for record in reversed(features):
            s, l = int(record[0]), int(record[1])
            data_count[s, l] += 1
            if data_count[s, l] <= min_cnt:
                test_data.append(record)
            else:
                kept_reversed.append(record)

        # Preserve the in-place contract of the original implementation.
        features[:] = kept_reversed[::-1]
        return features, test_data

    def _make_weights(self):
        """Return one weight per sample: dataset size / size of its (group, label) cell."""
        group_weights = len(self) / self.num_data
        weights = [group_weights[s, l] for s, l, _ in self.features]
        return weights

    def _balance_test_data(self, num_data, num_groups, num_classes):
        """Keep the first ``min(num_data)`` records of every (group, label) cell.

        Used when the dataset ships with its own train / test split.
        """
        print('balance test data...')
        num_data_min = np.min(num_data)
        print('min : ', num_data_min)
        data_count = np.zeros((num_groups, num_classes), dtype=int)
        new_features = []
        for record in self.features:
            s, l = int(record[0]), int(record[1])
            if data_count[s, l] < num_data_min:
                new_features.append(record)
                data_count[s, l] += 1

        return new_features
113 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Original code:
3 | https://github.com/sangwon79/Fair-Feature-Distillation-for-Visual-Recognition
4 | """
5 | import torch
6 | import numpy as np
7 | import random
8 | import os
9 | import torch.nn.functional as F
10 |
11 |
def list_files(root, suffix, prefix=False):
    """List the regular files directly inside ``root`` whose name ends with ``suffix``.

    Args:
        root: directory to scan; a leading ``~`` is expanded.
        suffix: required file-name ending (anything ``str.endswith`` accepts).
        prefix: when exactly ``True``, return paths joined with ``root``
            instead of bare file names.

    Returns:
        A list of names (or paths) in ``os.listdir`` order.
    """
    root = os.path.expanduser(root)
    matches = [
        entry
        for entry in os.listdir(root)
        if os.path.isfile(os.path.join(root, entry)) and entry.endswith(suffix)
    ]
    if prefix is True:
        return [os.path.join(root, entry) for entry in matches]
    return matches
23 |
24 |
def set_seed(seed):
    """Seed torch, numpy and ``random`` and switch cuDNN to deterministic mode.

    Args:
        seed: value handed to ``torch.manual_seed``, ``np.random.seed`` and
            ``random.seed``, in that order.
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
32 |
33 |
def get_accuracy(outputs, labels, binary=False, reduction='mean'):
    """Compute classification accuracy from raw model outputs.

    Args:
        outputs: model outputs (logits).
        labels: ground truth. When it has more than one dimension the task is
            treated as multi-label and outputs are thresholded at 0.
        binary: for single-label targets, threshold ``sigmoid(outputs)`` at
            0.5 instead of taking the arg-max over dimension 1.
        reduction: ``'none'`` returns the per-sample correctness tensor; any
            other value returns the mean as a Python float.

    Returns:
        A float accuracy, or a float tensor of 0/1 hits when
        ``reduction == 'none'`` (single-label case only).
    """
    # multi-label classification: element-wise comparison, always averaged
    if labels.dim() > 1:
        hits = ((outputs > 0.0).float() == labels).float().sum()
        total = torch.tensor(labels.shape[0] * labels.shape[1], dtype=torch.float)
        return (hits / total).item()

    predictions = (
        (torch.sigmoid(outputs) >= 0.5).float() if binary
        else torch.argmax(outputs, 1)
    )
    per_sample = (predictions == labels).float().squeeze()
    if reduction == 'none':
        return per_sample
    return torch.mean(per_sample).item()
54 |
55 |
def check_log_dir(log_dir):
    """Create ``log_dir`` (including parents) unless it already is a directory.

    An ``OSError`` is reported on stdout and otherwise ignored.
    """
    try:
        if os.path.isdir(log_dir):
            return
        os.makedirs(log_dir)
    except OSError:
        print("Failed to create directory!!")
62 |
63 |
class FitnetRegressor(torch.nn.Module):
    """1x1 convolution followed by ReLU, mapping ``in_feature`` to ``out_feature`` channels."""

    def __init__(self, in_feature, out_feature):
        super().__init__()
        self.in_feature, self.out_feature = in_feature, out_feature

        conv = torch.nn.Conv2d(in_feature, out_feature, 1, bias=False)
        # The Kaiming initialisation is applied first and then overwritten by
        # the small uniform range; both calls are kept so the RNG stream and
        # the resulting weights are unchanged.
        torch.nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
        conv.weight.data.uniform_(-0.005, 0.005)
        self.regressor = conv

    def forward(self, feature):
        """Apply the regressor; 2-D input gets two trailing singleton dims appended."""
        if feature.dim() == 2:
            feature = feature[:, :, None, None]
        projected = self.regressor(feature)
        return F.relu(projected)
79 |
80 |
def make_log_name(args):
    """Build the log / checkpoint name that encodes the run configuration.

    In ``'eva'`` mode the name is the file name of ``args.modelpath`` with its
    last three characters (the ``.pt`` suffix) removed. Otherwise the name
    starts with ``args.model`` and gets one suffix per relevant option
    (seed/epochs/batch size/lr, method-specific hyper-parameters, labelwise
    flag, teacher information, CelebA attributes, supervision ratio, version).

    Args:
        args: parsed command-line namespace.

    Returns:
        The composed name as a string.
    """
    pieces = [args.model]

    if args.mode == 'eva':
        # keep only the file name and drop the trailing '.pt'
        return args.modelpath.split('/')[-1][:-3]

    if args.pretrained:
        pieces.append('_pretrained')
    pieces.append('_seed{}_epochs{}_bs{}_lr{}'.format(args.seed, args.epochs, args.batch_size, args.lr))

    method = args.method
    if method == 'adv':
        pieces.append('_lamb{}_eta{}'.format(args.lamb, args.eta))
    elif method in ('scratch_mmd', 'kd_mfd'):
        pieces.append('_{}'.format(args.kernel))
        if args.kernel == 'rbf':
            pieces.append('_sigma{}'.format(args.sigma))
        pieces.append('_{}'.format(args.lambhf))
    elif method == 'reweighting':
        pieces.append('_eta{}_iter{}'.format(args.eta, args.iteration))
    elif 'groupdro' in method:
        pieces.append('_gamma{}'.format(args.gamma))

    if args.labelwise:
        pieces.append('_labelwise')

    if args.teacher_path is not None or method == 'fairhsic':
        pieces.append('_lamb{}'.format(args.lamb))
        pieces.append('_from_{}'.format(args.teacher_type))

    if args.dataset == 'celeba':
        if args.target != 'Attractive':
            pieces.append('_{}'.format(args.target))
        if args.add_attr is not None:
            pieces.append('_{}'.format(args.add_attr))
    if args.sv < 1:
        pieces.append('_sv{}'.format(args.sv))
    pieces.append('_{}'.format(args.version))

    return ''.join(pieces)
124 |
--------------------------------------------------------------------------------
/data_handler/AIF360/compas_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Original code:
3 | https://github.com/Trusted-AI/AIF360
4 | """
5 | import os
6 |
7 | import pandas as pd
8 |
9 | from data_handler.AIF360.standard_dataset import StandardDataset
10 |
11 |
# Default `metadata` of CompasDataset: display names for the numeric label
# values and for the values of each protected attribute.
default_mappings = {
    'label_maps': [{1.0: 'Did recid.', 0.0: 'No recid.'}],
    'protected_attribute_maps': [{0.0: 'Male', 1.0: 'Female'},
                                 {1.0: 'Caucasian', 0.0: 'Not Caucasian'}]
}
17 |
18 |
def default_preprocessing(df):
    """Filter rows the same way as the original ProPublica analysis.

    https://github.com/propublica/compas-analysis/blob/master/Compas%20Analysis.ipynb

    A row is kept when ``days_b_screening_arrest`` lies in [-30, 30],
    ``is_recid`` is not -1, ``c_charge_degree`` is not 'O' and ``score_text``
    is not 'N/A'.

    Args:
        df: raw COMPAS data frame.

    Returns:
        The filtered data frame (original index preserved).
    """
    screened_within_30_days = (
        (df.days_b_screening_arrest <= 30)
        & (df.days_b_screening_arrest >= -30)
    )
    recid_known = df.is_recid != -1
    ordinary_charge = df.c_charge_degree != 'O'
    has_score = df.score_text != 'N/A'
    return df[screened_within_30_days & recid_known & ordinary_charge & has_score]
28 |
29 |
class CompasDataset(StandardDataset):
    """ProPublica COMPAS Dataset, read from ``compas-scores-two-years.csv``.

    See :file:`aif360/data/raw/compas/README.md`.
    """

    def __init__(self, root_dir='./data/compas',
                 label_name='two_year_recid', favorable_classes=[0],
                 protected_attribute_names=['sex', 'race'],
                 privileged_classes=[['Female'], ['Caucasian']],
                 instance_weights_name=None,
                 categorical_features=['age_cat', 'c_charge_degree', 'c_charge_desc'],
                 features_to_keep=['sex', 'age', 'age_cat', 'race',
                                   'juv_fel_count', 'juv_misd_count', 'juv_other_count',
                                   'priors_count', 'c_charge_degree', 'c_charge_desc',
                                   'two_year_recid'],
                 features_to_drop=[], na_values=[],
                 custom_preprocessing=default_preprocessing,
                 metadata=default_mappings):
        """Load the CSV from ``root_dir`` and hand it to :obj:`StandardDataset`.

        See :obj:`StandardDataset` for a description of the arguments.

        Note: the label value 0 is the favorable one here (no recidivism).
        If the CSV cannot be read, download instructions are printed and the
        process exits with status 1.

        Examples:
            A mapping from `float -> str` for labels and protected attributes
            can be attached through `metadata`:

            >>> label_map = {1.0: 'Did recid.', 0.0: 'No recid.'}
            >>> protected_attribute_maps = [{1.0: 'Male', 0.0: 'Female'}]
            >>> cd = CompasDataset(protected_attribute_names=['sex'],
            ... privileged_classes=[['Male']], metadata={'label_maps': [label_map],
            ... 'protected_attribute_maps': protected_attribute_maps})

            The mapping stays attached to the dataset and can be used for
            more descriptive visualizations.
        """
        csv_path = os.path.join(root_dir, 'compas-scores-two-years.csv')

        try:
            df = pd.read_csv(csv_path, index_col='id', na_values=na_values)
        except IOError as err:
            import sys
            for message in (
                "IOError: {}".format(err),
                "To use this class, please download the following file:",
                "\n\thttps://raw.githubusercontent.com/propublica/compas-analysis/master/compas-scores-two-years.csv",
                "\nand place it, as-is, in the folder:",
                "\n\t{}\n".format(root_dir),
            ):
                print(message)
            sys.exit(1)

        super().__init__(
            df=df,
            label_name=label_name,
            favorable_classes=favorable_classes,
            protected_attribute_names=protected_attribute_names,
            privileged_classes=privileged_classes,
            instance_weights_name=instance_weights_name,
            categorical_features=categorical_features,
            features_to_keep=features_to_keep,
            features_to_drop=features_to_drop,
            na_values=na_values,
            custom_preprocessing=custom_preprocessing,
            metadata=metadata,
        )
87 |
--------------------------------------------------------------------------------
/networks/autoencoder.py:
--------------------------------------------------------------------------------
1 | """
2 | cgl_fairness
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT license
5 | """
6 | import torch
7 | import torch.nn as nn
8 |
9 | # Conv Layer
class ConvLayer(nn.Module):
    """Reflection padding of ``kernel_size // 2`` followed by a 2-D convolution."""

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        # Padding is done by the reflection layer, so the convolution itself
        # is created without any padding.
        self.reflection_pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        return self.conv2d(self.reflection_pad(x))
21 |
22 | # Upsample Conv Layer
class UpsampleConvLayer(nn.Module):
    """Optional nearest-neighbour upsampling, then reflection padding and convolution."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super().__init__()
        # A truthy ``upsample`` is used as the scale factor; otherwise the
        # attribute keeps the falsy value that was passed in.
        self.upsample = (
            nn.Upsample(scale_factor=upsample, mode='nearest') if upsample else upsample
        )
        self.reflection_pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        if self.upsample:
            x = self.upsample(x)
        return self.conv2d(self.reflection_pad(x))
39 |
40 | # Residual Block
41 | # adapted from pytorch tutorial
42 | # https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-
43 | # intermediate/deep_residual_network/main.py
class ResidualBlock(nn.Module):
    """Two 3x3 conv + instance-norm stages with an identity skip connection."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in1 = nn.InstanceNorm2d(channels, affine=True)
        self.relu = nn.ReLU()
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in2 = nn.InstanceNorm2d(channels, affine=True)

    def forward(self, x):
        hidden = self.conv1(x)
        hidden = self.relu(self.in1(hidden))
        hidden = self.in2(self.conv2(hidden))
        # skip connection, then the final non-linearity
        return self.relu(hidden + x)
60 |
61 | # Image Transform Network
class ImageTransformNet(nn.Module):
    """Encoder / residual / decoder image-to-image network with 3 input and output channels."""

    def __init__(self):
        super().__init__()

        # non-linearities (``tanh`` is created but not applied in ``forward``)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

        # encoder: 3 -> 32 -> 64 -> 128 channels, the last two with stride 2
        self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
        self.in1_e = nn.InstanceNorm2d(32, affine=True)

        self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.in2_e = nn.InstanceNorm2d(64, affine=True)

        self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
        self.in3_e = nn.InstanceNorm2d(128, affine=True)

        # five residual blocks at 128 channels
        self.res1 = ResidualBlock(128)
        self.res2 = ResidualBlock(128)
        self.res3 = ResidualBlock(128)
        self.res4 = ResidualBlock(128)
        self.res5 = ResidualBlock(128)

        # decoder: 128 -> 64 -> 32 -> 3 channels, the first two upsample by 2
        self.deconv3 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
        self.in3_d = nn.InstanceNorm2d(64, affine=True)

        self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
        self.in2_d = nn.InstanceNorm2d(32, affine=True)

        self.deconv1 = UpsampleConvLayer(32, 3, kernel_size=9, stride=1)
        # kept so the module's parameters / state dict stay unchanged, although
        # ``forward`` does not apply it
        self.in1_d = nn.InstanceNorm2d(3, affine=True)

    def forward(self, x):
        y = x

        # encode
        for conv, norm in (
            (self.conv1, self.in1_e),
            (self.conv2, self.in2_e),
            (self.conv3, self.in3_e),
        ):
            y = self.relu(norm(conv(y)))

        # residual layers
        for block in (self.res1, self.res2, self.res3, self.res4, self.res5):
            y = block(y)

        # decode
        for deconv, norm in (
            (self.deconv3, self.in3_d),
            (self.deconv2, self.in2_d),
        ):
            y = self.relu(norm(deconv(y)))

        # the last layer is returned raw: no instance norm, no tanh
        return self.deconv1(y)
117 |
--------------------------------------------------------------------------------
/trainer/hsic.py:
--------------------------------------------------------------------------------
1 | """
2 | Original code:
3 | https://github.com/clovaai/rebias
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 |
def to_numpy(x):
    """Return a detached CPU copy of the tensor ``x`` as a numpy array."""
    detached_copy = x.clone().detach()
    return detached_copy.cpu().numpy()
13 |
14 |
class HSIC(nn.Module):
    r"""Base class for the finite sample estimator of Hilbert-Schmidt Independence Criterion (HSIC)

    ..math:: HSIC (X, Y) := || C_{x, y} ||^2_{HS}, where HSIC (X, Y) = 0 iff X and Y are independent.

    Empirically, we use the finite sample estimator of HSIC (with m observations) by,
    (1) biased estimator (HSIC_0)
        Gretton, Arthur, et al. "Measuring statistical dependence with Hilbert-Schmidt norms." 2005.
        :math: (m - 1)^{-2} tr KHLH.
        where K_{ij} = kernel_x (x_i, x_j), L_{ij} = kernel_y (y_i, y_j), H = 1 - m^{-1} 1 1 (Hence, K, L, H are m by m matrices).
    (2) unbiased estimator (HSIC_1)
        Song, Le, et al. "Feature selection via dependence maximization." 2012.
        :math: \frac{1}{m (m - 3)} \bigg[ tr (\tilde K \tilde L) + \frac{1^\top \tilde K 1 1^\top \tilde L 1}{(m-1)(m-2)} - \frac{2}{m-2} 1^\top \tilde K \tilde L 1 \bigg].
        where \tilde K and \tilde L are K and L with their diagonal entries set to zero.

    Subclasses provide the kernels through ``_kernel_x`` and ``_kernel_y``.

    Parameters
    ----------
    sigma_x : float
        the kernel size of the kernel function for X.
    sigma_y : float
        the kernel size of the kernel function for Y (defaults to ``sigma_x``).
    algorithm: str ('unbiased' / 'biased')
        the algorithm for the finite sample estimator.
    reduction: not used (for compatibility with other losses).

    Raises
    ------
    ValueError
        if ``algorithm`` is neither 'biased' nor 'unbiased'.
    """
    def __init__(self, sigma_x, sigma_y=None, algorithm='unbiased',
                 reduction=None):
        super(HSIC, self).__init__()

        if sigma_y is None:
            sigma_y = sigma_x

        self.sigma_x = sigma_x
        self.sigma_y = sigma_y

        if algorithm == 'biased':
            self.estimator = self.biased_estimator
        elif algorithm == 'unbiased':
            self.estimator = self.unbiased_estimator
        else:
            raise ValueError('invalid estimator: {}'.format(algorithm))

    def _kernel_x(self, X):
        raise NotImplementedError

    def _kernel_y(self, Y):
        raise NotImplementedError

    @staticmethod
    def _zero_diagonal(kernel):
        """Return ``kernel`` with its diagonal entries replaced by zero."""
        # torch.diag on a matrix yields the diagonal *vector*; applying it a
        # second time rebuilds a diagonal matrix that can be subtracted.
        return kernel - torch.diag(torch.diag(kernel))

    def biased_estimator(self, input1, input2):
        """Biased estimator of Hilbert-Schmidt Independence Criterion
        Gretton, Arthur, et al. "Measuring statistical dependence with Hilbert-Schmidt norms." 2005.
        """
        K = self._kernel_x(input1)
        L = self._kernel_y(input2)

        KH = K - K.mean(0, keepdim=True)
        LH = L - L.mean(0, keepdim=True)

        N = len(input1)

        return torch.trace(KH @ LH / (N - 1) ** 2)

    def unbiased_estimator(self, input1, input2):
        """Unbiased estimator of Hilbert-Schmidt Independence Criterion
        Song, Le, et al. "Feature selection via dependence maximization." 2012.

        The estimate is only defined for more than three samples (the
        normaliser contains ``N - 2`` and ``N - 3``).
        """
        kernel_XX = self._kernel_x(input1)
        kernel_YY = self._kernel_y(input2)

        # The previous ``kernel - torch.diag(kernel)`` broadcast-subtracted the
        # diagonal vector from every row, which matches the estimator only for
        # kernels with a constant diagonal. Zeroing the diagonal is correct
        # for any kernel and gives the same estimate in the constant case.
        tK = self._zero_diagonal(kernel_XX)
        tL = self._zero_diagonal(kernel_YY)
        N = len(input1)

        hsic = (
            torch.trace(tK @ tL)
            + (torch.sum(tK) * torch.sum(tL) / (N - 1) / (N - 2))
            - (2 * torch.sum(tK, 0).dot(torch.sum(tL, 0)) / (N - 2))
        )

        return hsic / (N * (N - 3))

    def forward(self, input1, input2, **kwargs):
        return self.estimator(input1, input2)
95 |
96 |
class RbfHSIC(HSIC):
    """Radial Basis Function (RBF) kernel HSIC implementation.
    """
    def _kernel(self, X, sigma):
        """Return the RBF Gram matrix of the row-normalised, flattened ``X``.

        The bandwidth is derived from the mean pairwise squared distance of
        the batch; the ``sigma`` argument is not used.
        """
        flat = X.view(len(X), -1)
        unit = flat.div(flat.norm(2, dim=1, keepdim=True))
        gram = unit @ unit.t()
        sq_norms = torch.diag(gram)
        # squared distances, clamped from below at 1e-12
        sq_dist = (-2 * gram + sq_norms.unsqueeze(1) + sq_norms.unsqueeze(0)).clamp(1e-12)
        # detached so the bandwidth itself receives no gradient
        gamma = 1/(2*sq_dist.mean().detach())
        return torch.exp(-gamma * sq_dist)

    def _kernel_x(self, X):
        return self._kernel(X, self.sigma_x)

    def _kernel_y(self, Y):
        return self._kernel(Y, self.sigma_y)
120 |
121 |
class MinusRbfHSIC(RbfHSIC):
    """Sign-flipped RbfHSIC, for use when the criterion has to be maximised.
    """
    def forward(self, input1, input2, **kwargs):
        estimate = self.estimator(input1, input2)
        return -estimate
127 |
--------------------------------------------------------------------------------
/data_handler/utkface.py:
--------------------------------------------------------------------------------
1 | """
2 | cgl_fairness
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT license
5 | """
6 | from os.path import join
7 | from PIL import Image
8 | from utils import list_files
9 | from natsort import natsorted
10 | import random
11 | import numpy as np
12 | from torchvision import transforms
13 | from data_handler import SSLDataset
14 | from data_handler.utils import get_mean_std
15 |
16 |
class UTKFaceDataset(SSLDataset):
    """UTKFace image dataset whose attributes are parsed from the file names.

    ``self.features`` is a list of ``[sensi, label, filename]`` records, where
    ``sensi`` and ``label`` are read from the '_'-separated fields of the file
    name selected by ``fea_map``.
    """
    # attribute used as the classification target
    label = 'age'
    # attribute used as the sensitive (group) variable
    sensi = 'race'
    # position of each attribute inside the '_'-separated file name
    fea_map = {
        'age': 0,
        'gender': 1,
        'race': 2
    }
    # number of distinct values per attribute
    num_map = {
        'age': 100,  # overwritten in `_delete_others_n_age_filter` with the number of age bins found
        'gender': 2,
        'race': 4
    }
    mean, std = get_mean_std('utkface')
    # training pipeline: resize, random 224 crop, random horizontal flip, normalise
    train_transform = transforms.Compose(
        [transforms.Resize((256, 256)),
         transforms.RandomCrop(224),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         transforms.Normalize(mean=mean, std=std)]
    )
    # evaluation pipeline: deterministic resize and normalisation only
    test_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize(mean=mean, std=std)]
    )
    name = 'utkface'

    def __init__(self, **kwargs):
        """Build the dataset; ``kwargs`` must contain ``'split'`` and is forwarded to ``SSLDataset.__init__``.

        The statement order matters: the global ``random`` module is seeded
        before shuffling and again before the semi-supervised processing.
        NOTE(review): ``self.root``, ``self.split``, ``self.version``,
        ``self.sv_ratio`` and ``self.seed`` are presumably set by
        ``SSLDataset.__init__`` - confirm.
        """
        # the training transform is used only for the exact split name 'train'
        transform = self.train_transform if kwargs['split'] == 'train' else self.test_transform

        SSLDataset.__init__(self, transform=transform, **kwargs)

        filenames = list_files(self.root, '.jpg')
        filenames = natsorted(filenames)
        self._data_preprocessing(filenames)
        self.num_groups = self.num_map[self.sensi]
        self.num_classes = self.num_map[self.label]

        # we want the same train / test set, so fix the seed to 1
        random.seed(1)
        random.shuffle(self.features)

        train, test = self._make_data(self.features, self.num_groups, self.num_classes)
        # 'group' versions always use the training portion, whatever the split
        self.features = train if self.split == 'train' or 'group' in self.version else test

        self.num_data, self.idxs_per_group = self._data_count(self.features, self.num_groups, self.num_classes)

        # if semi-supervised learning,
        if self.sv_ratio < 1:
            # we want the different supervision according to the seed
            random.seed(self.seed)
            self.features, self.num_data, self.idxs_per_group = self.ssl_processing(self.features, self.num_data, self.idxs_per_group, )
            # NOTE(review): the nesting of this branch under the sv_ratio check
            # was reconstructed from a whitespace-stripped listing - confirm.
            if 'group' in self.version:
                # swap the roles (and therefore the sizes) of group and class
                a, b = self.num_groups, self.num_classes
                self.num_groups, self.num_classes = b, a

        self.weights = self._make_weights()

    def __getitem__(self, index):
        """Return ``(image, 1, group, label, (index, img_name))`` for one sample.

        For versions containing 'group' the stored ``s`` / ``l`` values swap
        positions in the returned tuple.
        """
        s, l, img_name = self.features[index]

        image_path = join(self.root, img_name)
        image = Image.open(image_path, mode='r').convert('RGB')

        if self.transform:
            image = self.transform(image)

        if 'group' in self.version:
            return image, 1, np.float32(l), np.int64(s), (index, img_name)
        else:
            return image, 1, np.float32(s), np.int64(l), (index, img_name)

    # five functions below preprocess UTKFace dataset
    def _data_preprocessing(self, filenames):
        """Filter ``filenames`` and fill ``self.features`` with ``[s, y, filename]`` records."""
        filenames = self._delete_incomplete_images(filenames)
        filenames = self._delete_others_n_age_filter(filenames)
        self.features = []
        for filename in filenames:
            s, y = self._filename2SY(filename)
            self.features.append([s, y, filename])

    def _filename2SY(self, filename):
        """Parse the sensitive and label values from a file name; ages are binned."""
        tmp = filename.split('_')
        sensi = int(tmp[self.fea_map[self.sensi]])
        label = int(tmp[self.fea_map[self.label]])
        if self.sensi == 'age':
            sensi = self._transform_age(sensi)
        if self.label == 'age':
            label = self._transform_age(label)
        return int(sensi), int(label)

    def _transform_age(self, age):
        """Map an age to a bin: 0 for below 20, 1 for 20-39, 2 otherwise."""
        if age < 20:
            label = 0
        elif age < 40:
            label = 1
        else:
            label = 2
        return label

    def _delete_incomplete_images(self, filenames):
        """Keep only file names that split into exactly four '_'-separated fields."""
        filenames = [image for image in filenames if len(image.split('_')) == 4]
        return filenames

    def _delete_others_n_age_filter(self, filenames):
        """Drop files whose race field is '4' and record the number of age bins present.

        Note that ``num_map`` is a class attribute, so the update of
        ``num_map['age']`` is shared by every instance of the class.
        """
        filenames = [image for image in filenames
                     if ((image.split('_')[self.fea_map['race']] != '4'))]
        ages = [self._transform_age(int(image.split('_')[self.fea_map['age']])) for image in filenames]
        self.num_map['age'] = len(set(ages))
        return filenames
128 |
--------------------------------------------------------------------------------
/trainer/fairhsic.py:
--------------------------------------------------------------------------------
1 | """
2 | cgl_fairness
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT license
5 | """
6 | from __future__ import print_function
7 |
8 | import torch.nn.functional as F
9 | import time
10 | from utils import get_accuracy
11 | import trainer
12 | from .hsic import RbfHSIC
13 |
14 |
class Trainer(trainer.GenericTrainer):
    """Trainer that adds a per-class HSIC penalty to the task loss.

    For each class index, the samples of the batch carrying that target are
    selected and ``RbfHSIC.unbiased_estimator`` is evaluated between the
    second-to-last element returned by ``model(inputs, get_inter=True)`` and
    the one-hot encoded group labels.  The sum over classes is scaled by
    ``args.lamb`` and added to ``self.criterion``.

    Attributes such as ``model``, ``criterion``, ``optimizer``, ``scheduler``,
    ``epochs``, ``term``, ``record``, ``cuda``, ``device`` and ``method`` are
    expected to be provided by ``trainer.GenericTrainer`` (not visible here).
    """

    def __init__(self, args, **kwargs):
        """Store the FairHSIC-related options found on ``args``.

        Args:
            args: parsed options; ``lamb``, ``sigma``, ``kernel``, ``sv`` and
                ``version`` are read here.
            **kwargs: forwarded unchanged to ``trainer.GenericTrainer``.
        """
        super().__init__(args=args, **kwargs)
        self.lamb = args.lamb
        self.sigma = args.sigma
        self.kernel = args.kernel
        # ``sv`` is the ratio of group-annotated training samples.
        self.slmode = args.sv < 1
        self.version = args.version

    def train(self, train_loader, test_loader, epochs, writer=None):
        """Run the training loop, evaluating on ``test_loader`` every epoch.

        Args:
            train_loader: loader whose ``dataset`` exposes ``num_classes``
                and ``num_groups``.
            test_loader: loader passed to ``self.evaluate`` after each epoch.
            epochs: only used in the progress message; the loop itself runs
                for ``self.epochs`` iterations.
            writer: scalar writer used when ``self.record`` is true; must
                provide ``add_scalar`` and ``add_scalars``.
        """
        num_classes = train_loader.dataset.num_classes
        num_groups = train_loader.dataset.num_groups

        hsic = RbfHSIC(1, 1)

        for epoch in range(self.epochs):
            self._train_epoch(epoch, train_loader, self.model, hsic=hsic, num_classes=num_classes)

            eval_start_time = time.time()
            eval_loss, eval_acc, eval_deom, eval_deoa, eval_subgroup_acc = self.evaluate(
                self.model, test_loader, self.criterion)
            eval_end_time = time.time()
            print('[{}/{}] Method: {} '
                  'Test Loss: {:.3f} Test Acc: {:.2f} Test DEOM {:.2f} [{:.2f} s]'.format
                  (epoch + 1, epochs, self.method,
                   eval_loss, eval_acc, eval_deom, (eval_end_time - eval_start_time)))

            if self.record:
                train_loss, train_acc, train_deom, train_deoa, train_subgroup_acc = self.evaluate(
                    self.model, train_loader, self.criterion)
                scalars = {
                    'train_loss': train_loss,
                    'train_acc': train_acc,
                    'train_deom': train_deom,
                    'train_deoa': train_deoa,
                    'eval_loss': eval_loss,
                    'eval_acc': eval_acc,
                    'eval_deom': eval_deom,
                    'eval_deoa': eval_deoa,
                }
                for tag, value in scalars.items():
                    writer.add_scalar(tag, value, epoch)

                eval_contents = {}
                train_contents = {}
                for g in range(num_groups):
                    for l in range(num_classes):
                        eval_contents[f'g{g},l{l}'] = eval_subgroup_acc[g, l]
                        train_contents[f'g{g},l{l}'] = train_subgroup_acc[g, l]
                writer.add_scalars('eval_subgroup_acc', eval_contents, epoch)
                writer.add_scalars('train_subgroup_acc', train_contents, epoch)

            # Only step when a scheduler exists; previously a missing
            # scheduler fell through to ``None.step()``.
            if self.scheduler is not None:
                if 'Reduce' in type(self.scheduler).__name__:
                    # Reduce-on-plateau style schedulers need the monitored value.
                    self.scheduler.step(eval_loss)
                else:
                    self.scheduler.step()

        print('Training Finished!')

    def _train_epoch(self, epoch, train_loader, model, hsic=None, num_classes=3):
        """Train ``model`` for one pass over ``train_loader``.

        Args:
            epoch: zero-based epoch index, used for logging only.
            train_loader: yields ``(inputs, _, groups, targets, (idx, _))``.
            model: network called as ``model(inputs, get_inter=True)``; its
                last output is used as logits and the one before it as the
                feature fed to the HSIC estimator.
            hsic: object exposing ``unbiased_estimator(features, onehot)``.
            num_classes: number of target classes to condition on.
        """
        model.train()

        running_acc = 0.0
        running_loss = 0.0
        batch_start_time = time.time()

        for i, data in enumerate(train_loader):
            inputs, _, groups, targets, _ = data
            labels = targets
            # ``F.one_hot`` only accepts integer index tensors, so cast on
            # every device, not just on the CUDA path.
            groups = groups.long()
            if self.cuda:
                inputs = inputs.cuda(self.device)
                labels = labels.cuda(self.device)
                groups = groups.cuda(self.device)

            outputs = model(inputs, get_inter=True)
            stu_logits = outputs[-1]
            f_s = outputs[-2]

            loss = self.criterion(stu_logits, labels)
            running_acc += get_accuracy(stu_logits, labels)

            group_onehot = F.one_hot(groups).float()
            hsic_loss = 0
            for l in range(num_classes):
                # Build the mask on the same device as the features.
                mask = labels == l
                # NOTE(review): a class with very few samples in the batch may
                # make the estimator ill-defined -- confirm against
                # RbfHSIC.unbiased_estimator.
                hsic_loss += hsic.unbiased_estimator(f_s[mask], group_onehot[mask])

            loss = loss + self.lamb * hsic_loss
            running_loss += loss.item()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if i % self.term == self.term - 1:  # print every self.term mini-batches
                avg_batch_time = time.time() - batch_start_time
                print('[{}/{}, {:5d}] Method: {} Train Loss: {:.3f} Train Acc: {:.2f} '
                      '[{:.2f} s/batch]'.format
                      (epoch + 1, self.epochs, i + 1, self.method, running_loss / self.term,
                       running_acc / self.term, avg_batch_time / self.term))

                running_loss = 0.0
                running_acc = 0.0
                batch_start_time = time.time()
115 |
--------------------------------------------------------------------------------
/arguments.py:
--------------------------------------------------------------------------------
1 | """
2 | cgl_fairness
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT license
5 | """
6 | import argparse
7 |
8 |
def get_args():
    """Parse the command line and return the resulting namespace.

    ``--dataset``, ``--model`` and ``--method`` are mandatory.  After parsing,
    ``args.cuda`` is set to ``True``.  When training with ``--method mfd`` both
    ``--teacher-type`` and ``--teacher-path`` must be supplied.

    Returns:
        argparse.Namespace: the parsed options.

    Raises:
        Exception: if MFD training is requested without a teacher type or
            without a teacher checkpoint path.
    """
    cli = argparse.ArgumentParser(description='Fairness')
    add = cli.add_argument

    # bookkeeping / output locations
    add('--date', default='20200101', type=str, help='experiment date')
    add('--term', default=20, type=int, help='the period for recording train acc')
    add('--record', action='store_true', help='record using tensorboardX')
    add('--result-dir', default='./results/',
        help='directory to save results (default: ./results/)')
    add('--log-dir', default='./logs/',
        help='directory to save logs (default: ./logs/)')
    add('--data-dir', default='./data/',
        help='data directory (default: ./data/)')
    add('--save-dir', default='./trained_models/',
        help='directory to save trained models (default: ./trained_models/)')
    add('--mode', default='train', choices=['train', 'eval'])
    add('--evalset', default='test', choices=['all', 'train', 'test'])
    add('--get-inter', action='store_true',
        help='get penultimate features for TSNE visualization')

    # base configuration for learning
    add('--seed', default=0, type=int, help='seed for randomness')

    # dataset
    add('--dataset', required=True, default='',
        choices=['adult', 'compas', 'utkface', 'celeba', 'utkface_fairface'])
    add('--batch-size', default=128, type=int, help='mini batch size')
    add('--img-size', default=224, type=int, help='img size for preprocessing')
    add('--num-workers', default=2, type=int, help='the number of thread used in dataloader')
    add('--labelwise', action='store_true', help='balanced sampling over groups')
    # only for celebA
    add('--target', default='Attractive', type=str, help='target attribute for celeba')
    add('--add-attr', default=None, help='additional group attribute for celeba')

    # model
    add('--model', default='', required=True, choices=['mlp', 'resnet18', 'resnet18_dropout'])
    add('--modelpath', default=None)
    add('--pretrained', action='store_true', help='load imagenet pretrained model')
    add('--device', default=0, type=int, help='cuda device number')
    add('--t-device', default=0, type=int, help='teacher cuda device number')

    # optimization
    add('--method', default='scratch', type=str, required=True,
        choices=['scratch', 'reweighting', 'mfd', 'adv', 'fairhsic'])
    add('--optimizer', default='Adam', type=str,
        choices=['AdamP', 'SGD', 'SGD_momentum_decay', 'Adam'],
        help='(default=%(default)s)')
    add('--epochs', default=50, type=int, help='number of training epochs')
    add('--lr', default=0.001, type=float, help='learning rate')
    add('--weight-decay', default=0, type=float, help='weight decay')

    # for base fairness methods
    add('--lamb', default=1, type=float, help='fairness strength')
    # for MFD
    add('--sigma', default=1.0, type=float, help='sigma for rbf kernel')
    add('--kernel', default='rbf', type=str, choices=['rbf', 'poly'], help='kernel for mmd')
    add('--teacher-type', default=None, choices=['mlp', 'resnet18', 'resnet18_dropout'])
    add('--teacher-path', default=None, help='teacher model path')

    # for reweighting & adv
    add('--reweighting-target-criterion', default='eo', type=str, help='fairness criterion')
    add('--iteration', default=10, type=int, help='iteration for reweighting')
    add('--ups-iter', default=10, type=int, help='iteration for reweighting')
    add('--eta', default=0.001, type=float,
        help='adversary training learning rate or lr for reweighting')

    # for fair FG
    add('--sv', default=1, type=float, help='the ratio of group annotation for a training set')
    add('--version', default='', type=str, help='version about how the unsupervised data is used')

    args = cli.parse_args()
    args.cuda = True

    # Distillation cannot start without a teacher network and its checkpoint.
    needs_teacher = args.mode == 'train' and args.method == 'mfd'
    if needs_teacher and args.teacher_type is None:
        raise Exception('A teacher model needs to be specified for distillation')
    if needs_teacher and args.teacher_path is None:
        raise Exception('A teacher model path is not specified.')

    return args
85 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Original code:
3 | https://github.com/sangwon79/Fair-Feature-Distillation-for-Visual-Recognition
4 | """
5 | import torch
6 | import torch.optim as optim
7 | import numpy as np
8 | import networks
9 | import data_handler
10 | import trainer
11 | from utils import check_log_dir, make_log_name, set_seed
12 | from adamp import AdamP
13 | from tensorboardX import SummaryWriter
14 |
15 | from arguments import get_args
16 | import time
17 | import os
18 | args = get_args()
19 |
20 |
def main():
    """Build data loaders, model and trainer from ``args``, then train or evaluate.

    In ``train`` mode the trainer is run for ``args.epochs`` epochs and the
    model is saved under ``<save_dir>/<date>/<dataset>/<method>``.  In ``eval``
    mode the checkpoint at ``args.modelpath`` is loaded and confusion matrices
    are computed for the split(s) selected by ``args.evalset``.

    Raises:
        ValueError: if ``--mode eval`` is used without ``--modelpath``, or the
            optimizer name is not recognised.
    """
    torch.backends.cudnn.enabled = True

    seed = args.seed
    set_seed(seed)

    np.set_printoptions(precision=4)
    torch.set_printoptions(precision=4)

    log_name = make_log_name(args)
    dataset = args.dataset
    save_dir = os.path.join(args.save_dir, args.date, dataset, args.method)
    result_dir = os.path.join(args.result_dir, args.date, dataset, args.method)
    check_log_dir(save_dir)
    check_log_dir(result_dir)
    writer = None
    if args.record:
        log_dir = os.path.join(args.log_dir, args.date, dataset, args.method)
        check_log_dir(log_dir)
        writer = SummaryWriter(log_dir + '/' + log_name)

    ########################## get dataloader ################################
    # Input size is fixed per dataset: tabular sets use their own sizes,
    # every other dataset uses 224.
    tabular_input_size = {'adult': 97, 'compas': 400}
    args.img_size = tabular_input_size.get(args.dataset, 224)
    tmp = data_handler.DataloaderFactory.get_dataloader(args.dataset,
                                                        batch_size=args.batch_size, seed=args.seed,
                                                        num_workers=args.num_workers,
                                                        target_attr=args.target,
                                                        add_attr=args.add_attr,
                                                        labelwise=args.labelwise,
                                                        sv_ratio=args.sv,
                                                        version=args.version,
                                                        args=args
                                                        )
    num_classes, num_groups, train_loader, test_loader = tmp

    ########################## get model ##################################
    # Same mapping as above, re-applied before the model is built (kept from
    # the original flow).  The former 'cifar' branch was unreachable because
    # no cifar dataset is among the accepted --dataset choices.
    if args.dataset in tabular_input_size:
        args.img_size = tabular_input_size[args.dataset]

    model = networks.ModelFactory.get_model(args.model, num_classes, args.img_size,
                                            pretrained=args.pretrained, num_groups=num_groups)

    device_name = 'cuda:{}'.format(args.device)
    model.cuda(device_name)

    if args.modelpath is not None:
        # map_location keeps checkpoints saved on another GPU index loadable,
        # matching how the teacher checkpoint is loaded below.
        model.load_state_dict(torch.load(args.modelpath, map_location=torch.device(device_name)))

    teacher = None
    if (args.method == 'mfd' or args.teacher_path is not None) and args.mode != 'eval':
        teacher_device = 'cuda:{}'.format(args.t_device)
        teacher = networks.ModelFactory.get_model(args.teacher_type, train_loader.dataset.num_classes, args.img_size)
        teacher.load_state_dict(torch.load(args.teacher_path, map_location=torch.device(teacher_device)))
        teacher.cuda(teacher_device)

    ########################## get trainer ##################################
    # Compare names exactly: the former substring test ('Adam' in name) also
    # matched 'AdamP', so AdamP could never be selected.
    if args.optimizer == 'AdamP':
        optimizer = AdamP(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer.startswith('SGD'):
        # NOTE(review): 'SGD_momentum_decay' currently gets no weight decay,
        # exactly like plain 'SGD' -- confirm whether that is intended.
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
    else:
        raise ValueError('Unsupported optimizer: {}'.format(args.optimizer))

    trainer_ = trainer.TrainerFactory.get_trainer(args.method, model=model, args=args,
                                                  optimizer=optimizer, teacher=teacher)

    ####################### start training or evaluating ####################

    if args.mode == 'train':
        start_t = time.time()
        trainer_.train(train_loader, test_loader, args.epochs, writer=writer)
        end_t = time.time()
        train_t = int((end_t - start_t) / 60)  # to minutes
        print('Training Time : {} hours {} minutes'.format(int(train_t / 60), (train_t % 60)))
        trainer_.save_model(save_dir, log_name)

    else:
        print('Evaluation ----------------')
        model_to_load = args.modelpath
        if model_to_load is None:
            # Fail with a readable message instead of crashing inside torch.load.
            raise ValueError('--modelpath must be given when --mode eval is used')
        trainer_.model.load_state_dict(torch.load(model_to_load, map_location=torch.device(device_name)))
        print('Trained model loaded successfully')

        loaders = {'train': train_loader, 'test': test_loader}
        if args.evalset == 'all':
            eval_splits = ['train', 'test']
        elif args.evalset == 'train':
            eval_splits = ['train']
        else:
            eval_splits = ['test']
        for split_name in eval_splits:
            loader = loaders[split_name]
            trainer_.compute_confusion_matix(split_name, loader.dataset.num_classes, loader, result_dir, log_name)

    if writer is not None:
        writer.close()
    print('Done!')


if __name__ == '__main__':
    main()
125 |
--------------------------------------------------------------------------------
/data_handler/AIF360/adult_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Original code:
3 | https://github.com/Trusted-AI/AIF360
4 | """
5 | import os
6 |
7 | import pandas as pd
8 |
9 | from data_handler.AIF360.standard_dataset import StandardDataset
10 |
11 |
# Human-readable names for the encoded label and protected-attribute values.
default_mappings = {
    'label_maps': [{1.0: '>50K', 0.0: '<=50K'}],
    'protected_attribute_maps': [{1.0: 'White', 0.0: 'Non-white'},
                                 {1.0: 'Male', 0.0: 'Female'}]
}


class AdultDataset(StandardDataset):
    """Adult Census Income Dataset.
    See :file:`aif360/data/raw/adult/README.md`.
    """

    def __init__(self, root_dir='./data/adult',
                 label_name='income-per-year',
                 favorable_classes=['>50K', '>50K.'],
                 protected_attribute_names=['race', 'sex'],
                 privileged_classes=[['White'], ['Male']],
                 instance_weights_name=None,
                 categorical_features=['workclass', 'education',
                                       'marital-status', 'occupation', 'relationship',
                                       'native-country'],
                 features_to_keep=[], features_to_drop=['fnlwgt'],
                 na_values=['?'], custom_preprocessing=None,
                 metadata=default_mappings):
        """Load ``adult.data`` and ``adult.test`` from ``root_dir`` and build the dataset.

        See :obj:`StandardDataset` for a description of the arguments other
        than ``root_dir``, which is the folder holding the two raw files.
        If either file cannot be read, download instructions are printed and
        the process exits with status 1.

        Examples:
            The following will instantiate a dataset which uses the `fnlwgt`
            feature:
                >>> from aif360.datasets import AdultDataset
                >>> ad = AdultDataset(instance_weights_name='fnlwgt',
                ... features_to_drop=[])
                WARNING:root:Missing Data: 3620 rows removed from dataset.
                >>> not np.all(ad.instance_weights == 1.)
                True
            To instantiate a dataset which utilizes only numerical features and
            a single protected attribute, run:
                >>> single_protected = ['sex']
                >>> single_privileged = [['Male']]
                >>> ad = AdultDataset(protected_attribute_names=single_protected,
                ... privileged_classes=single_privileged,
                ... categorical_features=[],
                ... features_to_keep=['age', 'education-num'])
                >>> print(ad.feature_names)
                ['education-num', 'age', 'sex']
                >>> print(ad.label_names)
                ['income-per-year']
            Note: the `protected_attribute_names` and `label_name` are kept even
            if they are not explicitly given in `features_to_keep`.
            In some cases, it may be useful to keep track of a mapping from
            `float -> str` for protected attributes and/or labels. If our use
            case differs from the default, we can modify the mapping stored in
            `metadata` (same keys and list-of-dicts layout as
            `default_mappings`):
                >>> label_map = {1.0: '>50K', 0.0: '<=50K'}
                >>> protected_attribute_maps = [{1.0: 'Male', 0.0: 'Female'}]
                >>> ad = AdultDataset(protected_attribute_names=['sex'],
                ... categorical_features=['workclass', 'education', 'marital-status',
                ... 'occupation', 'relationship', 'native-country', 'race'],
                ... privileged_classes=[['Male']], metadata={'label_maps': [label_map],
                ... 'protected_attribute_maps': protected_attribute_maps})
            Note that we are now adding `race` as a `categorical_features`.
            Now this information will stay attached to the dataset and can be
            used for more descriptive visualizations.
        """

        train_path = os.path.join(root_dir, 'adult.data')
        test_path = os.path.join(root_dir, 'adult.test')

        # as given by adult.names
        column_names = ['age', 'workclass', 'fnlwgt', 'education',
                        'education-num', 'marital-status', 'occupation', 'relationship',
                        'race', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week',
                        'native-country', 'income-per-year']
        try:
            train = pd.read_csv(train_path, header=None, names=column_names,
                                skipinitialspace=True, na_values=na_values)
            # header=0: the first line of adult.test is consumed as a header
            # row and replaced by column_names.
            test = pd.read_csv(test_path, header=0, names=column_names,
                               skipinitialspace=True, na_values=na_values)
        except IOError as err:
            print("IOError: {}".format(err))
            print("To use this class, please download the following files:")
            print("\n\thttps://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data")
            print("\thttps://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test")
            print("\thttps://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.names")
            print("\nand place them, as-is, in the folder:")
            # Report the folder that was actually searched rather than a
            # hard-coded default path.
            print("\n\t{}\n".format(root_dir))
            import sys
            sys.exit(1)

        # Test rows come first, followed by the training rows.
        df = pd.concat([test, train], ignore_index=True)

        super(AdultDataset, self).__init__(df=df, label_name=label_name,
                                           favorable_classes=favorable_classes,
                                           protected_attribute_names=protected_attribute_names,
                                           privileged_classes=privileged_classes,
                                           instance_weights_name=instance_weights_name,
                                           categorical_features=categorical_features,
                                           features_to_keep=features_to_keep,
                                           features_to_drop=features_to_drop, na_values=na_values,
                                           custom_preprocessing=custom_preprocessing, metadata=metadata)
111 |
--------------------------------------------------------------------------------
/data_handler/utkface_fairface.py:
--------------------------------------------------------------------------------
1 | """
2 | cgl_fairness
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT license
5 | """
6 | from os.path import join
7 | from PIL import Image
8 | from utils import list_files
9 | from natsort import natsorted
10 | import random
11 | import numpy as np
12 | from torchvision import transforms
13 | from data_handler import SSLDataset
14 | from data_handler.utils import get_mean_std
15 | from data_handler.fairface import Fairface
16 |
17 |
class UTKFaceFairface_Dataset(SSLDataset):
    """UTKFace images, extended with FairFace training images for training.

    The target is the binned age and the sensitive attribute is the race,
    both parsed from the UTKFace file names.  For the training split (or any
    ``version`` containing ``'group'``) the features of a ``Fairface`` dataset
    are appended after the UTKFace ones.

    ``root``, ``split``, ``version``, ``seed``, ``sv_ratio`` and ``transform``
    are presumably set by ``SSLDataset.__init__`` from ``kwargs`` -- that class
    is not visible here.
    """

    label = 'age'
    sensi = 'race'
    # Position of each attribute inside a '_'-separated UTKFace file name.
    fea_map = {
        'age': 0,
        'gender': 1,
        'race': 2
    }
    num_map = {
        'age': 100,  # replaced by the number of age bins found once '_delete_others_n_age_filter' runs
        'gender': 2,
        'race': 4
    }
    mean, std = get_mean_std('utkface')
    train_transform = transforms.Compose(
        [transforms.Resize((256, 256)),
         transforms.RandomCrop(224),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         transforms.Normalize(mean=mean, std=std)]
    )
    test_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize(mean=mean, std=std)]
    )
    name = 'utkface_fairface'

    def __init__(self, **kwargs):
        """Build the feature list ``[sensitive, label, path]`` for the requested split.

        Args:
            **kwargs: forwarded to ``SSLDataset.__init__``; ``split`` is
                required and selects the train or test transform.
        """
        transform = self.train_transform if kwargs['split'] == 'train' else self.test_transform

        SSLDataset.__init__(self, transform=transform, **kwargs)

        filenames = list_files(self.root, '.jpg')
        filenames = natsorted(filenames)
        self._data_preprocessing(filenames)
        self.num_groups = self.num_map[self.sensi]
        self.num_classes = self.num_map[self.label]

        # we want the same train / test set, so fix the seed to 1
        random.seed(1)
        random.shuffle(self.features)

        train, test = self._make_data(self.features, self.num_groups, self.num_classes)
        # if we train a group classifier, we don't use the original test dataset
        # and we split the original training dataset into train and validation dataset
        uses_train_part = self.split == 'train' or 'group' in self.version
        self.features = train if uses_train_part else test
        self.num_data, self.idxs_per_group = self._data_count(self.features, self.num_groups, self.num_classes)

        # add fairface
        idxs_dict = None
        if uses_train_part:
            fairface = Fairface(target_attr='age', root='./data/fairface', split='train', seed=self.seed)
            fairface_feature = fairface.feature_mat
            # shuffle
            random.shuffle(fairface_feature)

            # count only fairface
            print('count the number of data in fairface')
            num_data_fairface, idxs_per_group_fairface = self._data_count(
                fairface_feature, self.num_groups, self.num_classes)
            print(np.sum(num_data_fairface))

            # make idx_dict: UTKFace indices go under 'annotated'; FairFace
            # indices, shifted past the UTKFace block they are appended
            # after, go under 'non-annotated'.
            idxs_dict = {'annotated': {}, 'non-annotated': {}}
            total_num_utkface = len(self.features)
            for g in range(self.num_groups):
                for l in range(self.num_classes):
                    shifted = np.array(idxs_per_group_fairface[g, l]) + total_num_utkface
                    idxs_dict['annotated'][(g, l)] = self.idxs_per_group[(g, l)]
                    idxs_dict['non-annotated'][(g, l)] = list(shifted)

            self.features.extend(fairface_feature)

            # count again
            self.num_data, self.idxs_per_group = self._data_count(
                self.features, self.num_groups, self.num_classes)

        # if semi-supervised learning,
        if self.sv_ratio < 1:
            # we want the different supervision according to the seed
            random.seed(self.seed)
            self.features, self.num_data, self.idxs_per_group = self.ssl_processing(
                self.features, self.num_data, self.idxs_per_group, idxs_dict=idxs_dict)

        # __getitem__ swaps target and sensitive attribute whenever the version
        # contains 'group', so the two counts are swapped on the same condition.
        if 'group' in self.version:
            self.num_groups, self.num_classes = self.num_classes, self.num_groups

    def __getitem__(self, index):
        """Return ``(image, 1, group, target, (index, path))`` for ``index``.

        When ``version`` contains ``'group'`` the roles are swapped: the label
        is returned in the group slot and the sensitive attribute becomes the
        target.
        """
        s, l, img_name = self.features[index]

        # image_path = join(self.root, img_name)
        image = Image.open(img_name, mode='r').convert('RGB')

        if self.transform:
            image = self.transform(image)

        if 'group' in self.version:
            return image, 1, np.float32(l), np.int64(s), (index, img_name)
        return image, 1, np.float32(s), np.int64(l), (index, img_name)

    # five functions below preprocess UTKFace dataset
    def _data_preprocessing(self, filenames):
        """Filter ``filenames`` and fill ``self.features`` with ``[s, y, path]`` rows."""
        filenames = self._delete_incomplete_images(filenames)
        filenames = self._delete_others_n_age_filter(filenames)
        self.features = []
        for filename in filenames:
            s, y = self._filename2SY(filename)
            self.features.append([s, y, join(self.root, filename)])

    def _filename2SY(self, filename):
        """Parse the (sensitive, label) integer pair out of a UTKFace file name."""
        fields = filename.split('_')
        sensi = int(fields[self.fea_map[self.sensi]])
        label = int(fields[self.fea_map[self.label]])
        if self.sensi == 'age':
            sensi = self._transform_age(sensi)
        if self.label == 'age':
            label = self._transform_age(label)
        return int(sensi), int(label)

    def _transform_age(self, age):
        """Bin an age into 0 (< 20), 1 (20-39) or 2 (40 and above)."""
        if age < 20:
            return 0
        if age < 40:
            return 1
        return 2

    def _delete_incomplete_images(self, filenames):
        """Keep only names that split into exactly four '_'-separated fields."""
        return [image for image in filenames if len(image.split('_')) == 4]

    def _delete_others_n_age_filter(self, filenames):
        """Drop images whose race field is '4' and record how many age bins remain.

        The number of distinct age bins is stored in an instance-level copy of
        ``num_map`` so the class-level dict shared by every instance is not
        mutated.
        """
        race_pos = self.fea_map['race']
        age_pos = self.fea_map['age']
        kept = [image for image in filenames if image.split('_')[race_pos] != '4']
        age_bins = {self._transform_age(int(image.split('_')[age_pos])) for image in kept}
        self.num_map = {**self.num_map, 'age': len(age_bins)}
        return kept
158 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning Fair Classifiers with Partially Annotated Group Labels (CVPR 2022)
2 |
3 | Official Pytorch implementation of Learning Fair Classifiers with Partially Annotated Group Labels | [Paper](https://arxiv.org/abs/2111.14581)
4 |
5 | [Sangwon Jung](https://scholar.google.com/citations?user=WdC_a5IAAAAJ&hl=ko)1 [Sanghyuk Chun](https://sanghyukchun.github.io/home/)2 [Taesup Moon](https://scholar.google.com/citations?user=lQlioBoAAAAJ&hl=ko)1, 3
6 |
7 | 1Department of ECE/ASRI, Seoul National University
8 | 2[NAVER AI LAB](https://naver-career.gitbook.io/en/teams/clova-cic)
9 | 3Interdisciplinary Program in Artificial Intelligence, Seoul National University
10 |
11 | Recently, fairness-aware learning has become increasingly crucial, but most of those methods operate by assuming the availability of fully annotated demographic group labels. We emphasize that such an assumption is unrealistic for real-world applications since group label annotations are expensive and can conflict with privacy issues. In this paper, we consider a more practical scenario, dubbed as Algorithmic Group Fairness with the Partially annotated Group labels (Fair-PG). We observe that the existing methods to achieve group fairness perform even worse than the vanilla training, which simply uses full data only with target labels, under Fair-PG. To address this problem, we propose a simple Confidence-based Group Label assignment (CGL) strategy that is readily applicable to any fairness-aware learning method. CGL utilizes an auxiliary group classifier to assign pseudo group labels, where random labels are assigned to low-confidence samples. We first theoretically show that our method design is better than the vanilla pseudo-labeling strategy in terms of fairness criteria. Then, we empirically show on several benchmark datasets that by combining CGL and the state-of-the-art fairness-aware in-processing methods, the target accuracies and the fairness metrics can be jointly improved compared to the baselines. Furthermore, we convincingly show that CGL makes it possible to naturally augment the given group-labeled dataset with external target label-only datasets so that both accuracy and fairness can be improved.
12 |
13 | ## Updates
14 |
15 | - 11 Apr, 2022: Initial upload.
16 |
17 | ## Dataset preparation
18 | 1. Download dataset
19 | - UTKFace :
20 | [link](https://susanqq.github.io/UTKFace/) (We used Aligned&Cropped Faces from the site)
21 | - CelebA :
22 | [link](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) (We used Aligned&Cropped Faces from the site)
23 | - Tabular datasets (Propublica Compas & Adult) :
24 | [link](https://github.com/Trusted-AI/AIF360)
25 | 2. Locate downloaded datasets to './data' directory
26 |
27 | ## How to train
28 | We note that CGL first trains a group classifier and finds a proper threshold using a train set and a validation set split from the group-labeled training dataset. Then, the group-unlabeled training samples are annotated with pseudo group labels based on CGL's assignment rules. Finally, we can train a fair model using base fair training methods such as LBC, FairHSIC or MFD.
29 |
30 | ### 1. Train a group classifier
31 | ```
32 | $ python main_groupclf.py --model {model_type} --method scratch \
33 | --dataset {dataset_name} \
34 | --version groupclf_val \
35 | --sv {group_label_ratio}
36 | ```
37 |
38 | In the above command, 'version' can be chosen between 'groupclf' and 'groupclf_val', which indicates whether or not the group-labeled data is split into a train/validation set. For CGL, you should choose 'groupclf_val'; for the pseudo-label baseline, you should choose 'groupclf'.
39 |
40 | ### 2. Find a threshold and save the predictions of group classifier
41 | ```
42 | $ python main_groupclf.py --model {model_type} --method scratch \
43 | --dataset {dataset_name} \
44 | --mode eval \
45 | --version groupclf_val \
46 | --sv {group_label_ratio}
47 | ```
48 |
49 | ### 3. Train a fair model using any fair-training methods
50 |
51 | - MFD. For the feature distillation, MFD needs a teacher model that is trained from scratch.
52 | ```
53 | # train a scratch model
54 | $ python main.py --model {model_type} --method scratch --dataset {dataset_name}
55 | $ python main.py --model {model_type} --method mfd \
56 | --dataset {dataset_name} \
57 | --labelwise \
58 | --version cgl \
59 | --sv {group_label_ratio} \
60 | --lamb 100 \
61 | --teacher-type mlp \
62 | --teacher-path {teacher_model_path}
63 | ```
64 | - FairHSIC
65 | ```
66 | $ python main.py --model {model_type} --method fairhsic \
67 | --dataset {dataset_name} \
68 | --labelwise \
69 | --version cgl \
70 | --sv {group_label_ratio} \
71 | --lamb 100
72 | ```
73 | - LBC
74 | ```
75 | $ python main.py --model {model_type} --method reweighting \
76 | --dataset {dataset_name} \
77 | --iteration {num_iterations} \
78 | --version cgl \
79 | --sv {group_label_ratio}
80 | ```
81 | ## How to cite
82 |
83 | ```
84 | @inproceedings{jung2021cgl,
85 | title={Learning Fair Classifiers with Partially Annotated Group Labels},
86 | author={Sangwon Jung and Sanghyuk Chun and Taesup Moon},
87 | year={2022},
88 | booktitle={Conference on Computer Vision and Pattern Recognition (CVPR)},
89 | }
90 | ```
91 | ## License
92 |
93 | ```
94 | Copyright (c) 2022-present NAVER Corp.
95 |
96 | Permission is hereby granted, free of charge, to any person obtaining a copy
97 | of this software and associated documentation files (the "Software"), to deal
98 | in the Software without restriction, including without limitation the rights
99 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
100 | copies of the Software, and to permit persons to whom the Software is
101 | furnished to do so, subject to the following conditions:
102 |
103 | The above copyright notice and this permission notice shall be included in
104 | all copies or substantial portions of the Software.
105 |
106 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
107 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
108 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
109 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
110 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
111 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
112 | THE SOFTWARE.
113 | ```
114 |
--------------------------------------------------------------------------------
/data_handler/celeba.py:
--------------------------------------------------------------------------------
1 | """
2 | cgl_fairness
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT license
5 | """
6 | import torch
7 | import os
8 | from os.path import join
9 | import PIL
10 | import pandas
11 | import random
12 | import zipfile
13 | from functools import partial
14 | from torchvision.datasets.utils import download_file_from_google_drive, check_integrity, verify_str_arg
15 | from torchvision import transforms
16 | import data_handler
17 | from data_handler.utils import get_mean_std
18 |
19 |
class CelebA(data_handler.SSLDataset):
    """CelebA face images labelled with a target attribute and a sensitive group.

    The sensitive attribute is fixed to ``'Male'``. When ``add_attr`` is given,
    it is combined with the sensitive attribute so that four groups are formed
    instead of two.

    There currently does not appear to be an easy way to extract 7z in python (without introducing additional
    dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
    right now.
    """
    file_list = [
        # File ID                         MD5 Hash                            Filename
        ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
        # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
        # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
        ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
        ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
        ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
        ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
        # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
        ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
    ]
    mean, std = get_mean_std('celeba')
    # Training pipeline: resize, random 224x224 crop, random flip, normalize.
    train_transform = transforms.Compose(
        [transforms.Resize((256, 256)),
         transforms.RandomCrop(224),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         transforms.Normalize(mean=mean, std=std)]
    )
    # Evaluation pipeline: deterministic resize to 224x224, normalize.
    test_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize(mean=mean, std=std)]
    )
    name = 'celeba'

    def __init__(self, target_attr='Attractive', add_attr=None, download=False, **kwargs):
        """Load the attribute tables and build the ``[group, label, filename]`` list.

        Args:
            target_attr: Name of the attribute column used as the class label.
            add_attr: Optional second attribute column; when given, the group id
                is ``2 * Male + add_attr`` and ``num_groups`` becomes 4.
            download: If True, fetch and extract the dataset files first.
            **kwargs: Forwarded to the base-class constructor. Must contain
                ``'split'``, which selects the train or test transform.

        Raises:
            RuntimeError: If the dataset files are missing or fail the
                integrity check.
        """
        transform = self.train_transform if kwargs['split'] == 'train' else self.test_transform
        super(CelebA, self).__init__(transform=transform, **kwargs)

        if download:
            self._download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')
        # SELECT the features
        self.sensitive_attr = 'Male'
        self.add_attr = add_attr
        self.target_attr = target_attr
        split_map = {
            "train": 0,
            "valid": 1,
            "test": 2,
            "all": None,
        }
        split = split_map[verify_str_arg(self.split.lower(), "split",
                                         ("train", "valid", "test", "all"))]
        # 'group' versions always read partition 0, regardless of the requested split.
        if 'group' in self.version:
            split = 0
        fn = partial(join, self.root)
        # sep=r"\s+" is the documented replacement for the deprecated
        # delim_whitespace=True and parses the files identically.
        splits = pandas.read_csv(fn("list_eval_partition.txt"), sep=r"\s+", header=None, index_col=0)
        attr = pandas.read_csv(fn("list_attr_celeba.txt"), sep=r"\s+", header=1)
        print('Im add attr : ', target_attr, add_attr)

        mask = slice(None) if split is None else (splits[1] == split)

        self.filename = splits[mask].index.values
        self.attr = torch.as_tensor(attr[mask].values)
        self.attr = (self.attr + 1) // 2  # map from {-1, 1} to {0, 1}
        self.attr_names = list(attr.columns)

        self.target_idx = self.attr_names.index(self.target_attr)
        self.sensi_idx = self.attr_names.index(self.sensitive_attr)
        self.add_idx = self.attr_names.index(self.add_attr) if self.add_attr is not None else -1
        # Every attribute column that is neither the target nor a group attribute.
        self.feature_idx = [i for i in range(len(self.attr_names)) if i != self.target_idx and i != self.sensi_idx]
        if self.add_attr is not None:
            self.feature_idx.remove(self.add_idx)
        self.num_classes = 2
        self.num_groups = 2 if self.add_attr is None else 4

        if self.add_attr is None:
            self.features = [[int(s), int(l), filename] for s, l, filename in
                             zip(self.attr[:, self.sensi_idx], self.attr[:, self.target_idx], self.filename)]
        else:
            self.features = [[int(s1)*2 + int(s2), int(l), filename] for s1, s2, l, filename in
                             zip(self.attr[:, self.sensi_idx], self.attr[:, self.add_idx],
                                 self.attr[:, self.target_idx], self.filename)]
        self.num_data, self.idxs_per_group = self._data_count(self.features, self.num_groups, self.num_classes)

        if self.split == "test" and 'group' not in self.version:
            self.features = self._balance_test_data(self.num_data, self.num_groups, self.num_classes)
            self.num_data, self.idxs_per_group = self._data_count(self.features, self.num_groups, self.num_classes)

        # if semi-supervised learning,
        if self.sv_ratio < 1:
            # we want the different supervision according to the seed
            random.seed(self.seed)
            self.features, self.num_data, self.idxs_per_group = self.ssl_processing(
                self.features, self.num_data, self.idxs_per_group)
        # For 'group' versions the roles of group and class counts are swapped.
        if 'group' in self.version:
            a, b = self.num_groups, self.num_classes
            self.num_groups, self.num_classes = b, a

    def __getitem__(self, index):
        """Return ``(image, 0, group, label, (index, img_name))`` for one sample.

        For 'group' versions the third and fourth elements are swapped, i.e.
        ``(image, 0, label, group, (index, img_name))``.
        """
        sensitive, target, img_name = self.features[index]
        image = PIL.Image.open(os.path.join(self.root, "img_align_celeba", img_name))

        if self.transform is not None:
            image = self.transform(image)

        if 'group' in self.version:
            return image, 0, target, sensitive, (index, img_name)
        return image, 0, sensitive, target, (index, img_name)

    def _check_integrity(self):
        """Return True if all annotation files pass MD5 and the image folder exists."""
        for (_, md5, filename) in self.file_list:
            fpath = os.path.join(self.root, filename)
            _, ext = os.path.splitext(filename)
            # Allow original archive to be deleted (zip and 7z)
            # Only need the extracted images
            if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
                return False

        # Should check a hash of the images
        return os.path.isdir(os.path.join(self.root, "img_align_celeba"))

    def _download(self):
        """Download every entry of ``file_list`` and extract the image archive."""
        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        for (file_id, md5, filename) in self.file_list:
            download_file_from_google_drive(file_id, self.root, filename, md5)

        with zipfile.ZipFile(os.path.join(self.root, "img_align_celeba.zip"), "r") as f:
            f.extractall(self.root)
152 |
--------------------------------------------------------------------------------
/trainer/mfd.py:
--------------------------------------------------------------------------------
1 | """
2 | Original code:
3 | https://github.com/sangwon79/Fair-Feature-Distillation-for-Visual-Recognition
4 | """
5 | from __future__ import print_function
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | import torch.nn as nn
10 | import time
11 | from utils import get_accuracy
12 | import trainer
13 |
14 |
class Trainer(trainer.GenericTrainer):
    """Trainer that adds an MMD feature-distillation loss from a teacher model."""

    def __init__(self, args, **kwargs):
        """Store the distillation hyper-parameters read from ``args``.

        Args:
            args: Namespace providing ``lamb``, ``sigma``, ``kernel``, ``sv``
                and ``version`` in addition to what the base class reads.
            **kwargs: Forwarded to the base-class constructor.
        """
        super().__init__(args=args, **kwargs)
        self.lamb = args.lamb
        self.sigma = args.sigma
        self.kernel = args.kernel
        self.slmode = True if args.sv < 1 else False
        self.version = args.version

    def train(self, train_loader, test_loader, epochs, writer=None):
        """Run the full training loop with per-epoch evaluation.

        Args:
            train_loader: Loader whose ``dataset`` exposes ``num_classes`` and
                ``num_groups``.
            test_loader: Loader passed to ``self.evaluate`` after each epoch.
            epochs: Only used in the progress message; the loop itself runs
                for ``self.epochs`` iterations.
            writer: Object with ``add_scalar`` / ``add_scalars``; used only
                when ``self.record`` is truthy.
        """
        num_classes = train_loader.dataset.num_classes
        num_groups = train_loader.dataset.num_groups

        distiller = MMDLoss(w_m=self.lamb, sigma=self.sigma,
                            num_classes=num_classes, num_groups=num_groups, kernel=self.kernel)
        for epoch in range(self.epochs):
            self._train_epoch(epoch, train_loader, self.model, self.teacher, distiller=distiller)
            eval_start_time = time.time()
            eval_loss, eval_acc, eval_deom, eval_deoa, eval_subgroup_acc = self.evaluate(
                self.model, test_loader, self.criterion)
            eval_end_time = time.time()
            print('[{}/{}] Method: {} '
                  'Test Loss: {:.3f} Test Acc: {:.2f} Test DEOM {:.2f} [{:.2f} s]'.format
                  (epoch + 1, epochs, self.method,
                   eval_loss, eval_acc, eval_deom, (eval_end_time - eval_start_time)))
            if self.record:
                train_loss, train_acc, train_deom, train_deoa, train_subgroup_acc = self.evaluate(
                    self.model, train_loader, self.criterion)
                writer.add_scalar('train_loss', train_loss, epoch)
                writer.add_scalar('train_acc', train_acc, epoch)
                writer.add_scalar('train_deom', train_deom, epoch)
                writer.add_scalar('train_deoa', train_deoa, epoch)
                writer.add_scalar('eval_loss', eval_loss, epoch)
                writer.add_scalar('eval_acc', eval_acc, epoch)
                writer.add_scalar('eval_deom', eval_deom, epoch)
                writer.add_scalar('eval_deoa', eval_deoa, epoch)

                eval_contents = {}
                train_contents = {}
                for g in range(num_groups):
                    for l in range(num_classes):
                        eval_contents[f'g{g},l{l}'] = eval_subgroup_acc[g, l]
                        train_contents[f'g{g},l{l}'] = train_subgroup_acc[g, l]
                writer.add_scalars('eval_subgroup_acc', eval_contents, epoch)
                writer.add_scalars('train_subgroup_acc', train_contents, epoch)

            # Only step when a scheduler exists; previously a missing scheduler
            # fell into the else-branch and called ``None.step()``.
            if self.scheduler is not None:
                if 'Reduce' in type(self.scheduler).__name__:
                    # Plateau-style schedulers need the monitored metric.
                    self.scheduler.step(eval_loss)
                else:
                    self.scheduler.step()

        print('Training Finished!')

    def _train_epoch(self, epoch, train_loader, model, teacher, distiller=None):
        """Train ``model`` for one epoch on cross-entropy plus the MMD loss.

        Args:
            epoch: Zero-based epoch index, used for the progress message.
            train_loader: Iterable yielding
                ``(inputs, _, groups, targets, (idx, _))`` batches.
            model: Student network; called with ``get_inter=True`` and expected
                to return a sequence whose last item is the logits and whose
                second-to-last item is the feature to distil.
            teacher: Teacher network, called the same way and kept in eval mode.
            distiller: Loss object whose ``forward(f_s, f_t, groups=, labels=)``
                returns the distillation term.
        """
        model.train()
        teacher.eval()

        running_acc = 0.0
        running_loss = 0.0
        batch_start_time = time.time()
        for i, data in enumerate(train_loader):
            # Get the inputs
            inputs, _, groups, targets, (idx, _) = data
            labels = targets
            if self.cuda:
                inputs = inputs.cuda(self.device)
                labels = labels.cuda(self.device)
                groups = groups.long().cuda(self.device)

            # The teacher may live on a different device than the student.
            t_inputs = inputs.to(self.t_device)

            outputs = model(inputs, get_inter=True)
            stu_logits = outputs[-1]

            t_outputs = teacher(t_inputs, get_inter=True)

            loss = self.criterion(stu_logits, labels)

            running_acc += get_accuracy(stu_logits, labels)
            f_s = outputs[-2]

            # Detach so no gradient flows into the teacher.
            f_t = t_outputs[-2].detach()

            mmd_loss = distiller.forward(f_s, f_t, groups=groups, labels=labels)

            loss = loss + mmd_loss
            running_loss += loss.item()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if i % self.term == self.term - 1:  # print every self.term mini-batches
                avg_batch_time = time.time() - batch_start_time
                print('[{}/{}, {:5d}] Method: {} Train Loss: {:.3f} Train Acc: {:.2f} '
                      '[{:.2f} s/batch]'.format
                      (epoch + 1, self.epochs, i + 1, self.method, running_loss / self.term, running_acc / self.term,
                       avg_batch_time / self.term))

                running_loss = 0.0
                running_acc = 0.0
                batch_start_time = time.time()
114 |
115 |
class MMDLoss(nn.Module):
    """Maximum-mean-discrepancy loss between teacher and student features.

    For every class ``c`` and group ``g`` the loss compares the student
    features of samples in ``(c, g)`` with the teacher features of all samples
    in class ``c``.
    """

    def __init__(self, w_m, sigma, num_groups, num_classes, kernel):
        """
        Args:
            w_m: Weight applied to the summed MMD term.
            sigma: Base bandwidth multiplier for the ``'rbf'`` kernel.
            num_groups: Number of group ids to iterate over.
            num_classes: Number of class ids to iterate over.
            kernel: ``'rbf'`` or ``'poly'``.
        """
        super(MMDLoss, self).__init__()
        self.w_m = w_m
        self.sigma = sigma
        self.num_groups = num_groups
        self.num_classes = num_classes
        self.kernel = kernel

    def forward(self, f_s, f_t, groups, labels):
        """Return ``0.5 * w_m * sum_{c,g} MMD^2(teacher_c, student_{c,g})``.

        Args:
            f_s: Student features; flattened to ``(batch, -1)``.
            f_t: Teacher features; flattened to ``(batch, -1)``.
            groups: Per-sample group ids.
            labels: Per-sample class ids.

        Returns:
            The weighted loss. When no ``(class, group)`` cell has samples the
            result is the plain number ``0`` scaled by ``w_m``.

        Raises:
            ValueError: If ``self.kernel`` is neither ``'rbf'`` nor ``'poly'``
                (raised by ``pdist``).
        """
        if self.kernel == 'poly':
            # The polynomial kernel is evaluated on L2-normalised features.
            student = F.normalize(f_s.view(f_s.shape[0], -1), dim=1)
            teacher = F.normalize(f_t.view(f_t.shape[0], -1), dim=1).detach()
        else:
            student = f_s.view(f_s.shape[0], -1)
            teacher = f_t.view(f_t.shape[0], -1)

        mmd_loss = 0
        # Shared bandwidth estimated once from all teacher/student pairs.
        with torch.no_grad():
            _, sigma_avg = self.pdist(teacher, student, sigma_base=self.sigma, kernel=self.kernel)

        for c in range(self.num_classes):
            class_mask = labels == c
            teacher_c = teacher[class_mask]
            if len(teacher_c) == 0:
                continue
            # The teacher-teacher term depends only on the class, so it is
            # computed lazily once and reused for every non-empty group.
            k_tt_mean = None
            for g in range(self.num_groups):
                student_cg = student[class_mask & (groups == g)]
                if len(student_cg) == 0:
                    continue
                if k_tt_mean is None:
                    K_TT, _ = self.pdist(teacher_c, teacher_c, sigma_base=self.sigma,
                                         sigma_avg=sigma_avg, kernel=self.kernel)
                    k_tt_mean = K_TT.mean()
                K_TS, _ = self.pdist(teacher_c, student_cg,
                                     sigma_base=self.sigma, sigma_avg=sigma_avg, kernel=self.kernel)
                K_SS, _ = self.pdist(student_cg, student_cg,
                                     sigma_base=self.sigma, sigma_avg=sigma_avg, kernel=self.kernel)

                mmd_loss += k_tt_mean + K_SS.mean() - 2 * K_TS.mean()
        loss = (1/2) * self.w_m * mmd_loss
        return loss

    @staticmethod
    def pdist(e1, e2, eps=1e-12, kernel='rbf', sigma_base=1.0, sigma_avg=None):
        """Return the pairwise kernel matrix between rows of ``e1`` and ``e2``.

        Args:
            e1: Tensor of row vectors.
            e2: Tensor of row vectors with the same feature size.
            eps: Lower clamp for squared distances (``'rbf'`` only).
            kernel: ``'rbf'`` or ``'poly'``.
            sigma_base: Bandwidth multiplier (``'rbf'`` only).
            sigma_avg: Bandwidth scale; when None it is set to the detached
                mean squared distance of this call (``'rbf'`` only).

        Returns:
            ``(kernel_matrix, sigma_avg)``. If either input is empty the matrix
            is a single zero on ``e1``'s device and dtype.

        Raises:
            ValueError: If ``kernel`` is not supported.
        """
        if len(e1) == 0 or len(e2) == 0:
            return e1.new_zeros(1), sigma_avg

        if kernel == 'rbf':
            e1_square = e1.pow(2).sum(dim=1)
            e2_square = e2.pow(2).sum(dim=1)
            prod = e1 @ e2.t()
            # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, clamped against round-off.
            sq_dist = (e1_square.unsqueeze(1) + e2_square.unsqueeze(0) - 2 * prod).clamp(min=eps)
            if sigma_avg is None:
                sigma_avg = sq_dist.mean().detach()
            res = torch.exp(-sq_dist / (2 * sigma_base * sigma_avg))
        elif kernel == 'poly':
            res = torch.matmul(e1, e2.t()).pow(2)
        else:
            raise ValueError("Unsupported kernel: {!r} (expected 'rbf' or 'poly')".format(kernel))

        return res, sigma_avg
172 |
--------------------------------------------------------------------------------
/data_handler/AIF360/standard_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Original code:
3 | https://github.com/Trusted-AI/AIF360
4 | """
5 | from logging import warning
6 |
7 | import numpy as np
8 | import pandas as pd
9 |
10 | from data_handler.AIF360.binary_label_dataset import BinaryLabelDataset
11 |
12 |
class StandardDataset(BinaryLabelDataset):
    """Base class for every :obj:`BinaryLabelDataset` provided out of the box by
    aif360.
    It is not strictly necessary to inherit this class when adding custom
    datasets but it may be useful.
    This class is very loosely based on code from
    https://github.com/algofairness/fairness-comparison.
    """

    def __init__(self, df, label_name, favorable_classes,
                 protected_attribute_names, privileged_classes,
                 instance_weights_name='', scores_name='',
                 categorical_features=[], features_to_keep=[],
                 features_to_drop=[], na_values=[], custom_preprocessing=None,
                 metadata=None):
        """
        Subclasses of StandardDataset should perform the following before
        calling `super().__init__`:
            1. Load the dataframe from a raw file.
        Then, this class will go through a standard preprocessing routine which:
            2. (optional) Performs some dataset-specific preprocessing (e.g.
               renaming columns/values, handling missing data).
            3. Drops unrequested columns (see `features_to_keep` and
               `features_to_drop` for details).
            4. Drops rows with NA values.
            5. Creates a one-hot encoding of the categorical variables.
            6. Maps protected attributes to binary privileged/unprivileged
               values (1/0).
            7. Maps labels to binary favorable/unfavorable labels (1/0).
        Args:
            df (pandas.DataFrame): DataFrame on which to perform standard
                processing.
            label_name: Name of the label column in `df`.
            favorable_classes (list or function): Label values which are
                considered favorable or a boolean function which returns `True`
                if favorable. All others are unfavorable. Label values are
                mapped to 1 (favorable) and 0 (unfavorable) if they are not
                already binary and numerical.
            protected_attribute_names (list): List of names corresponding to
                protected attribute columns in `df`.
            privileged_classes (list(list or function)): Each element is
                a list of values which are considered privileged or a boolean
                function which return `True` if privileged for the corresponding
                column in `protected_attribute_names`. All others are
                unprivileged. Values are mapped to 1 (privileged) and 0
                (unprivileged) if they are not already numerical.
            instance_weights_name (optional): Name of the instance weights
                column in `df`.
            scores_name (optional): Name of the scores column in `df`; passed
                on as a one-element list when non-empty.
            categorical_features (optional, list): List of column names in the
                DataFrame which are to be expanded into one-hot vectors.
            features_to_keep (optional, list): Column names to keep. All others
                are dropped except those present in `protected_attribute_names`,
                `categorical_features`, `label_name` or `instance_weights_name`.
                Defaults to all columns if not provided.
            features_to_drop (optional, list): Column names to drop. *Note: this
                overrides* `features_to_keep`.
            na_values (optional): Additional strings to recognize as NA. See
                :func:`pandas.read_csv` for details. Not read by this
                constructor itself.
            custom_preprocessing (function): A function object which
                acts on and returns a DataFrame (f: DataFrame -> DataFrame). If
                `None`, no extra preprocessing is applied.
            metadata (optional): Additional metadata to append.
        """
        # NOTE: the list defaults above are only read or rebound below, never
        # mutated in place, so sharing them between calls is harmless here.

        # 2. Perform dataset-specific preprocessing
        if custom_preprocessing:
            df = custom_preprocessing(df)

        # 3. Drop unrequested columns
        features_to_keep = features_to_keep or df.columns.tolist()
        keep = (set(features_to_keep) | set(protected_attribute_names)
                | set(categorical_features) | set([label_name]))
        if instance_weights_name:
            keep |= set([instance_weights_name])
        # Sorting by column position keeps the original column order.
        df = df[sorted(keep - set(features_to_drop), key=df.columns.get_loc)]
        categorical_features = sorted(set(categorical_features) - set(features_to_drop), key=df.columns.get_loc)

        # 4. Remove any rows that have missing data.
        dropped = df.dropna()
        count = df.shape[0] - dropped.shape[0]
        if count > 0:
            warning("Missing Data: {} rows removed from {}.".format(count,
                    type(self).__name__))
        df = dropped

        # 5. Create a one-hot encoding of the categorical variables.
        df = pd.get_dummies(df, columns=categorical_features, prefix_sep='=')

        # 6. Map protected attributes to privileged/unprivileged
        privileged_protected_attributes = []
        unprivileged_protected_attributes = []
        for attr, vals in zip(protected_attribute_names, privileged_classes):
            privileged_values = [1.]
            unprivileged_values = [0.]
            if callable(vals):
                df[attr] = df[attr].apply(vals)
            elif np.issubdtype(df[attr].dtype, np.number):
                # this attribute is numeric; no remapping needed
                privileged_values = vals
                unprivileged_values = list(set(df[attr]).difference(vals))
            else:
                # find all instances which match any of the attribute values
                priv = np.logical_or.reduce(np.equal.outer(vals, df[attr].to_numpy()))
                df.loc[priv, attr] = privileged_values[0]
                df.loc[~priv, attr] = unprivileged_values[0]

            privileged_protected_attributes.append(
                np.array(privileged_values, dtype=np.float64))
            unprivileged_protected_attributes.append(
                np.array(unprivileged_values, dtype=np.float64))

        # 7. Make labels binary
        favorable_label = 1.
        unfavorable_label = 0.
        if callable(favorable_classes):
            df[label_name] = df[label_name].apply(favorable_classes)
        # Test the column's dtype explicitly, as the protected-attribute branch
        # above does, instead of passing the Series itself to np.issubdtype.
        elif np.issubdtype(df[label_name].dtype, np.number) and len(set(df[label_name])) == 2:
            # labels are already binary; don't change them
            favorable_label = favorable_classes[0]
            unfavorable_label = set(df[label_name]).difference(favorable_classes).pop()
        else:
            # find all instances which match any of the favorable classes
            pos = np.logical_or.reduce(np.equal.outer(favorable_classes,
                                                      df[label_name].to_numpy()))
            df.loc[pos, label_name] = favorable_label
            df.loc[~pos, label_name] = unfavorable_label

        super(StandardDataset, self).__init__(df=df, label_names=[label_name],
                                              protected_attribute_names=protected_attribute_names,
                                              privileged_protected_attributes=privileged_protected_attributes,
                                              unprivileged_protected_attributes=unprivileged_protected_attributes,
                                              instance_weights_name=instance_weights_name,
                                              scores_names=[scores_name] if scores_name else [],
                                              favorable_label=favorable_label,
                                              unfavorable_label=unfavorable_label, metadata=metadata)
147 |
--------------------------------------------------------------------------------
/trainer/trainer_factory.py:
--------------------------------------------------------------------------------
1 | """
2 | cgl_fairness
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT license
5 | """
6 | import torch
7 | import numpy as np
8 | import os
9 | import torch.nn as nn
10 | from torch.optim.lr_scheduler import ReduceLROnPlateau, MultiStepLR, CosineAnnealingLR
11 | from sklearn.metrics import confusion_matrix
12 | from utils import make_log_name
13 |
14 |
class TrainerFactory:
    """Factory that resolves a method name to its ``Trainer`` implementation."""

    def __init__(self):
        pass

    @staticmethod
    def get_trainer(method, **kwargs):
        """Import the module for ``method`` lazily and build its ``Trainer``.

        Args:
            method: One of ``'scratch'``, ``'mfd'``, ``'fairhsic'``, ``'adv'``,
                ``'reweighting'``, ``'groupdro'`` or ``'groupdro_ori'``.
            **kwargs: Passed unchanged to the selected ``Trainer`` constructor.

        Raises:
            Exception: If ``method`` is not one of the names above.
        """
        # Imports stay inside the branches so only the requested trainer
        # module is loaded.
        if method == 'scratch':
            import trainer.vanilla_train as trainer_module
        elif method == 'mfd':
            import trainer.mfd as trainer_module
        elif method == 'fairhsic':
            import trainer.fairhsic as trainer_module
        elif method == 'adv':
            import trainer.adv_debiasing as trainer_module
        elif method == 'reweighting':
            import trainer.reweighting as trainer_module
        elif method in ('groupdro', 'groupdro_ori'):
            # Both names resolve to the same module.
            import trainer.groupdro as trainer_module
        else:
            raise Exception('Not allowed method')

        return trainer_module.Trainer(**kwargs)
38 |
39 |
class GenericTrainer:
    '''
    Base class for trainer; to implement a new training routine, inherit from this.
    '''
    def __init__(self, model, args, optimizer, teacher=None):
        """Copy run settings from ``args`` and pick a learning-rate scheduler.

        Args:
            model: Network to train.
            args: Namespace with the run configuration (``get_inter``,
                ``record``, ``cuda``, ``device``, ``t_device``, ``term``,
                ``lr``, ``epochs``, ``method``, ``model``, ``optimizer``,
                ``log_dir``, ``save_dir``, ``date``, ``dataset``).
            optimizer: Optimizer instance, or None. With None no scheduler is
                created and ``self.scheduler`` stays None.
            teacher: Optional teacher network.
        """
        self.get_inter = args.get_inter

        self.record = args.record
        self.cuda = args.cuda
        self.device = args.device
        self.t_device = args.t_device
        self.term = args.term
        self.lr = args.lr
        self.epochs = args.epochs
        self.method = args.method
        self.model_name = args.model
        self.model = model
        self.teacher = teacher
        self.optimizer = optimizer
        self.optim_type = args.optimizer
        self.criterion = torch.nn.CrossEntropyLoss()
        self.scheduler = None

        self.log_name = make_log_name(args)
        self.log_dir = os.path.join(args.log_dir, args.date, args.dataset, args.method)
        self.save_dir = os.path.join(args.save_dir, args.date, args.dataset, args.method)

        # A scheduler can only wrap a real optimizer; previously a None
        # optimizer was handed to MultiStepLR, which raised.
        if self.optimizer is not None:
            if self.optim_type == 'Adam':
                self.scheduler = ReduceLROnPlateau(self.optimizer)
            elif self.optim_type == 'AdamP':
                # 200-epoch runs use a 66-step half period; every other length
                # anneals over the whole run. Lengths other than "<100" or
                # "==200" used to leave t_max unbound and crash.
                t_max = 66 if self.epochs == 200 else self.epochs
                self.scheduler = CosineAnnealingLR(self.optimizer, t_max)
            else:
                self.scheduler = MultiStepLR(self.optimizer, [60, 120, 180], gamma=0.1)

    def evaluate(self, model, loader, criterion, device=None):
        """Evaluate ``model`` on ``loader`` and compute per-subgroup accuracy.

        Args:
            model: Network to evaluate; put in eval mode during the call and
                back in train mode afterwards.
            loader: Iterable of ``(inputs, _, groups, classes, _)`` batches
                whose ``dataset`` exposes ``num_groups`` and ``num_classes``.
                Loaders whose type name contains ``'Custom'`` are expanded via
                ``loader.generate()``.
            criterion: Loss called as ``criterion(outputs, labels)``.
            device: Target device when ``self.cuda`` is set; defaults to
                ``self.device``.

        Returns:
            ``(loss, acc, max_gap, avg_gap, subgroup_acc)`` where
            ``subgroup_acc`` is a ``(num_groups, num_classes)`` tensor and the
            gaps are the per-class max-minus-min accuracy across groups
            (maximum and mean over classes). Empty subgroups yield NaN.
        """
        model.eval()
        num_groups = loader.dataset.num_groups
        num_classes = loader.dataset.num_classes
        device = self.device if device is None else device

        eval_acc = 0
        eval_loss = 0
        eval_eopp_list = torch.zeros(num_groups, num_classes)
        eval_data_count = torch.zeros(num_groups, num_classes)
        # Only move the accumulators to the GPU when CUDA is in use, so the
        # method also works on CPU-only runs.
        if self.cuda:
            eval_eopp_list = eval_eopp_list.cuda(device)
            eval_data_count = eval_data_count.cuda(device)

        if 'Custom' in type(loader).__name__:
            loader = loader.generate()
        with torch.no_grad():
            for j, eval_data in enumerate(loader):
                # NOTE(review): evaluation stops after the first 100 batches.
                if j == 100:
                    break
                # Get the inputs
                inputs, _, groups, classes, _ = eval_data
                labels = classes
                if self.cuda:
                    inputs = inputs.cuda(device)
                    labels = labels.cuda(device)
                    groups = groups.cuda(device)

                outputs = model(inputs)

                loss = criterion(outputs, labels)
                eval_loss += loss.item() * len(labels)
                preds = torch.argmax(outputs, 1)
                # Keep the batch dimension (no squeeze) so a one-sample batch
                # can still be indexed with the boolean masks below.
                acc = (preds == labels).float()
                eval_acc += acc.sum()

                for g in range(num_groups):
                    for l in range(num_classes):
                        cell = (groups == g) & (labels == l)
                        eval_eopp_list[g, l] += acc[cell].sum()
                        eval_data_count[g, l] += torch.sum(cell)

        eval_loss = eval_loss / eval_data_count.sum()
        eval_acc = eval_acc / eval_data_count.sum()
        eval_eopp_list = eval_eopp_list / eval_data_count
        eval_max_eopp = torch.max(eval_eopp_list, dim=0)[0] - torch.min(eval_eopp_list, dim=0)[0]
        eval_avg_eopp = torch.mean(eval_max_eopp).item()
        eval_max_eopp = torch.max(eval_max_eopp).item()
        model.train()
        return eval_loss, eval_acc, eval_max_eopp, eval_avg_eopp, eval_eopp_list

    def save_model(self, save_dir, log_name="", model=None):
        """Save the state dict of ``model`` (default ``self.model``) to ``save_dir/log_name.pt``."""
        model_to_save = self.model if model is None else model
        model_savepath = os.path.join(save_dir, log_name + '.pt')
        torch.save(model_to_save.state_dict(), model_savepath)

        print('Model saved to %s' % model_savepath)

    def compute_confusion_matix(self, dataset='test', num_classes=2,
                                dataloader=None, log_dir="", log_name=""):
        """Compute per-group confusion matrices and dump them with the raw outputs.

        Two ``.mat`` files are written under ``log_dir``:
        ``<log_name>_<dataset>_confu`` (one matrix per group id) and
        ``<log_name>_<dataset>_pred`` (groups, targets, outputs and, when
        ``self.get_inter`` is set, the intermediate features).

        Args:
            dataset: Label used in messages and file names.
            num_classes: Size of each confusion matrix.
            dataloader: Iterable of ``(inputs, _, groups, targets, _)`` batches.
            log_dir: Output directory.
            log_name: File-name prefix.

        Returns:
            A dict-like mapping ``str(group_id)`` to its confusion matrix.
        """
        from scipy.io import savemat
        from collections import defaultdict
        self.model.eval()
        confu_mat = defaultdict(lambda: np.zeros((num_classes, num_classes)))
        print('# of {} data : {}'.format(dataset, len(dataloader.dataset)))

        predict_mat = {}
        output_set = torch.tensor([])
        group_set = torch.tensor([], dtype=torch.long)
        target_set = torch.tensor([], dtype=torch.long)
        intermediate_feature_set = torch.tensor([])

        with torch.no_grad():
            for i, data in enumerate(dataloader):
                # Get the inputs
                inputs, _, groups, targets, _ = data
                labels = targets
                groups = groups.long()

                if self.cuda:
                    inputs = inputs.cuda(self.device)
                    labels = labels.cuda(self.device)

                # forward
                outputs = self.model(inputs)
                if self.get_inter:
                    intermediate_feature = self.model.forward(inputs, get_inter=True)[-2]

                # The accumulated sets are kept on the CPU.
                group_set = torch.cat((group_set, groups))
                target_set = torch.cat((target_set, targets))
                output_set = torch.cat((output_set, outputs.cpu()))
                if self.get_inter:
                    intermediate_feature_set = torch.cat((intermediate_feature_set, intermediate_feature.cpu()))

                pred = torch.argmax(outputs, 1)
                group_element = list(torch.unique(groups).numpy())
                for group_id in group_element:
                    mask = groups == group_id
                    if len(labels[mask]) != 0:
                        confu_mat[str(group_id)] += confusion_matrix(
                            labels[mask].cpu().numpy(), pred[mask].cpu().numpy(),
                            labels=[c for c in range(num_classes)])

        predict_mat['group_set'] = group_set.numpy()
        predict_mat['target_set'] = target_set.numpy()
        predict_mat['output_set'] = output_set.numpy()
        if self.get_inter:
            predict_mat['intermediate_feature_set'] = intermediate_feature_set.numpy()

        savepath = os.path.join(log_dir, log_name + '_{}_confu'.format(dataset))
        print('savepath', savepath)
        savemat(savepath, confu_mat, appendmat=True)

        savepath_pred = os.path.join(log_dir, log_name + '_{}_pred'.format(dataset))
        savemat(savepath_pred, predict_mat, appendmat=True)

        print('Computed confusion matrix for {} dataset successfully!'.format(dataset))
        return confu_mat
197 |
--------------------------------------------------------------------------------
/trainer/reweighting.py:
--------------------------------------------------------------------------------
1 | """
2 | cgl_fairness
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT license
5 | """
6 | from __future__ import print_function
7 | import torch
8 | import torch.nn as nn
9 | import time
10 | from utils import get_accuracy
11 | import trainer
12 | from torch.utils.data import DataLoader
13 |
14 |
15 | class Trainer(trainer.GenericTrainer):
def __init__(self, args, **kwargs):
    """Initialise the base trainer and keep the reweighting settings.

    Args:
        args: Namespace providing ``eta``, ``iteration``, ``batch_size``,
            ``num_workers``, ``reweighting_target_criterion``, ``sv`` and
            ``version`` in addition to what the base class reads.
        **kwargs: Forwarded to the base-class constructor.
    """
    super().__init__(args=args, **kwargs)

    # Settings for the outer multiplier-update loop.
    self.eta = args.eta
    self.iteration = args.iteration
    self.reweighting_target_criterion = args.reweighting_target_criterion

    # Settings for the loader built in get_statistics.
    self.batch_size = args.batch_size
    self.num_workers = args.num_workers

    # The flag is set whenever args.sv is below 1.
    self.slmode = bool(args.sv < 1)
    self.version = args.version
26 |
def train(self, train_loader, test_loader, epochs, dummy_loader=None, writer=None):
    """Alternate between weighted training and multiplier updates.

    For each of ``self.iteration`` outer rounds, sample weights are derived
    from the current multipliers, the model is trained for ``epochs``
    epochs, and the multipliers are then moved against the measured
    fairness violations with step size ``self.eta``.

    Args:
        train_loader: Loader whose ``dataset`` exposes ``num_groups`` and
            ``num_classes``.
        test_loader: Loader passed to ``self.evaluate`` after each epoch.
        epochs: Number of epochs per outer round.
        dummy_loader: Unused.
        writer: Object with ``add_scalar`` / ``add_scalars``; used only
            when ``self.record`` is truthy.
    """
    model = self.model
    model.train()
    num_groups = train_loader.dataset.num_groups
    num_classes = train_loader.dataset.num_classes

    extended_multipliers = torch.zeros((num_groups, num_classes))
    if self.cuda:
        extended_multipliers = extended_multipliers.cuda()
    _, Y_train, S_train = self.get_statistics(train_loader.dataset, batch_size=self.batch_size,
                                              num_workers=self.num_workers)

    eta_learning_rate = self.eta
    print('eta_learning_rate : ', eta_learning_rate)
    n_iters = self.iteration
    print('n_iters : ', n_iters)

    # Stays 0 (multipliers unchanged) when the criterion is neither
    # 'dp' nor 'eo'.
    violations = 0

    for iter_ in range(n_iters):
        start_t = time.time()
        weight_set = self.debias_weights(Y_train, S_train, extended_multipliers, num_groups, num_classes)

        for epoch in range(epochs):
            self._train_epoch(epoch, train_loader, model, weight_set)
            eval_start_time = time.time()
            eval_loss, eval_acc, eval_deom, eval_deoa, eval_subgroup_acc = self.evaluate(
                self.model, test_loader, self.criterion)
            eval_end_time = time.time()
            print('[{}/{}] Method: {} '
                  'Test Loss: {:.3f} Test Acc: {:.2f} Test DEOM {:.2f} [{:.2f} s]'.format
                  (epoch + 1, epochs, self.method,
                   eval_loss, eval_acc, eval_deom, (eval_end_time - eval_start_time)))
            if self.record:
                train_loss, train_acc, train_deom, train_deoa, train_subgroup_acc = self.evaluate(
                    self.model, train_loader, self.criterion)
                writer.add_scalar('train_loss', train_loss, epoch)
                writer.add_scalar('train_acc', train_acc, epoch)
                writer.add_scalar('train_deom', train_deom, epoch)
                writer.add_scalar('train_deoa', train_deoa, epoch)
                writer.add_scalar('eval_loss', eval_loss, epoch)
                writer.add_scalar('eval_acc', eval_acc, epoch)
                writer.add_scalar('eval_deopm', eval_deom, epoch)
                writer.add_scalar('eval_deopa', eval_deoa, epoch)

                eval_contents = {}
                train_contents = {}
                for g in range(num_groups):
                    for l in range(num_classes):
                        eval_contents[f'g{g},l{l}'] = eval_subgroup_acc[g, l]
                        train_contents[f'g{g},l{l}'] = train_subgroup_acc[g, l]
                writer.add_scalars('eval_subgroup_acc', eval_contents, epoch)
                writer.add_scalars('train_subgroup_acc', train_contents, epoch)

            # Only step when a scheduler exists; previously a missing
            # scheduler fell into the else-branch and called ``None.step()``.
            if self.scheduler is not None:
                if 'Reduce' in type(self.scheduler).__name__:
                    # Plateau-style schedulers need the monitored metric.
                    self.scheduler.step(eval_loss)
                else:
                    self.scheduler.step()

        end_t = time.time()
        train_t = int((end_t - start_t) / 60)
        print('Training Time : {} hours {} minutes / iter : {}/{}'.format(int(train_t / 60), (train_t % 60),
                                                                          (iter_ + 1), n_iters))

        Y_pred_train, Y_train, S_train = self.get_statistics(train_loader.dataset, batch_size=self.batch_size,
                                                             num_workers=self.num_workers, model=model)

        if self.reweighting_target_criterion == 'dp':
            acc, violations = self.get_error_and_violations_DP(Y_pred_train, Y_train, S_train, num_groups, num_classes)
        elif self.reweighting_target_criterion == 'eo':
            acc, violations = self.get_error_and_violations_EO(Y_pred_train, Y_train, S_train, num_groups, num_classes)
        extended_multipliers -= eta_learning_rate * violations
97 |
def _train_epoch(self, epoch, train_loader, model, weight_set):
    """Run one epoch of sample-weighted cross-entropy training.

    Args:
        epoch: Zero-based epoch index, used for the progress message.
        train_loader: Iterable yielding
            ``(inputs, _, groups, targets, indexes)`` batches, where
            ``indexes[0]`` selects each sample's entry in ``weight_set``.
        model: Network being optimised.
        weight_set: Per-sample weights indexed by ``indexes[0]``.

    Returns:
        Index of the last processed batch.
    """
    model.train()

    loss_sum = 0.0
    acc_sum = 0.0
    elapsed = 0.0
    # Un-reduced loss so each sample can be scaled by its own weight.
    per_sample_loss = nn.CrossEntropyLoss(reduction='none')

    for step, batch in enumerate(train_loader):
        tic = time.time()
        # Get the inputs
        inputs, _, groups, targets, indexes = batch
        labels = targets.long()
        weights = weight_set[indexes[0]]

        if self.cuda:
            inputs, labels, weights, groups = (
                inputs.cuda(), labels.cuda(), weights.cuda(), groups.cuda())

        outputs = model(inputs)

        loss = torch.mean(weights * per_sample_loss(outputs, labels))
        loss_sum += loss.item()
        acc_sum += get_accuracy(outputs, labels)

        # zero the parameter gradients + backward + optimize
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        elapsed += time.time() - tic

        if step % self.term == self.term - 1:  # print every self.term mini-batches
            print('[{}/{}, {:5d}] Method: {} Train Loss: {:.3f} Train Acc: {:.2f} '
                  '[{:.2f} s/batch]'.format
                  (epoch + 1, self.epochs, step + 1, self.method, loss_sum / self.term, acc_sum / self.term,
                   elapsed / self.term))

            loss_sum = 0.0
            acc_sum = 0.0
            elapsed = 0.0

    return step
146 |
def get_statistics(self, dataset, batch_size=128, num_workers=2, model=None):
    """Collect labels, group ids and (optionally) predictions over ``dataset``.

    Args:
        dataset: Dataset yielding ``(inputs, _, sen_attrs, targets, indexes)``.
        batch_size: Batch size of the sequential loader built here.
        num_workers: Worker count of that loader.
        model: When given, it is put in eval mode and its arg-max
            predictions are collected; otherwise predictions are empty.

    Returns:
        ``(Y_pred, Y, S)`` as long tensors in dataset order. ``Y`` and ``S``
        are moved to the GPU only when ``self.cuda`` is set.
    """
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=True, drop_last=False)

    if model is not None:
        model.eval()

    Y_pred_set = []
    Y_set = []
    S_set = []
    # Predictions are only read, never back-propagated, so no graph is needed.
    with torch.no_grad():
        for data in dataloader:
            inputs, _, sen_attrs, targets, _ = data
            Y_set.append(targets)  # sen_attrs = -1 means no supervision for sensitive group
            S_set.append(sen_attrs)

            if model is None:
                continue
            if self.cuda:
                inputs = inputs.cuda()
            outputs = model(inputs)
            Y_pred_set.append(torch.argmax(outputs, dim=1))

    Y_set = torch.cat(Y_set).long()
    S_set = torch.cat(S_set).long()
    Y_pred_set = torch.cat(Y_pred_set) if len(Y_pred_set) != 0 else torch.zeros(0)
    # Keep labels and groups on the same device as the predictions; the
    # previous unconditional .cuda() failed on CPU-only runs.
    if self.cuda:
        Y_set = Y_set.cuda()
        S_set = S_set.cuda()
    return Y_pred_set.long(), Y_set, S_set
176 |
177 | # Vectorized version for DP & multi-class
178 | def get_error_and_violations_DP(self, y_pred, label, sen_attrs, num_groups, num_classes):
179 | acc = torch.mean(y_pred == label)
180 | total_num = len(y_pred)
181 | violations = torch.zeros((num_groups, num_classes))
182 | if self.cuda:
183 | violations = violations.cuda()
184 | for g in range(num_groups):
185 | for c in range(num_classes):
186 | pivot = len(torch.where(y_pred == c)[0]) / total_num
187 | group_idxs = torch.where(sen_attrs == g)[0]
188 | group_pred_idxs = torch.where(torch.logical_and(sen_attrs == g, y_pred == c))[0]
189 | violations[g, c] = len(group_pred_idxs)/len(group_idxs) - pivot
190 | return acc, violations
191 |
192 | # Vectorized version for EO & multi-class
193 | def get_error_and_violations_EO(self, y_pred, label, sen_attrs, num_groups, num_classes):
194 | acc = torch.mean((y_pred == label).float())
195 | violations = torch.zeros((num_groups, num_classes))
196 | if self.cuda:
197 | violations = violations.cuda()
198 | for g in range(num_groups):
199 | for c in range(num_classes):
200 | class_idxs = torch.where(label == c)[0]
201 | pred_class_idxs = torch.where(torch.logical_and(y_pred == c, label == c))[0]
202 | pivot = len(pred_class_idxs)/len(class_idxs)
203 | group_class_idxs = torch.where(torch.logical_and(sen_attrs == g, label == c))[0]
204 | group_pred_class_idxs = torch.where(torch.logical_and(torch.logical_and(sen_attrs == g, y_pred == c), label == c))[0]
205 | violations[g, c] = len(group_pred_class_idxs)/len(group_class_idxs) - pivot
206 | print('violations', violations)
207 | return acc, violations
208 |
209 | # update weight
210 | def debias_weights(self, label, sen_attrs, extended_multipliers, num_groups, num_classes):
211 | weights = torch.zeros(len(label))
212 | w_matrix = torch.sigmoid(extended_multipliers) # g by c
213 | weights = w_matrix[sen_attrs, label]
214 | if self.slmode and self.version == 2:
215 | weights[sen_attrs == -1] = 0.5
216 | return weights
217 |
218 | def criterion(self, model, outputs, labels):
219 | return nn.CrossEntropyLoss()(outputs, labels)
220 |
--------------------------------------------------------------------------------
/trainer/adv_debiasing.py:
--------------------------------------------------------------------------------
1 | """
2 | Original code:
3 | https://github.com/sangwon79/Fair-Feature-Distillation-for-Visual-Recognition
4 | """
5 | from __future__ import print_function
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 |
12 | import time
13 |
14 | from utils import get_accuracy
15 | from networks.mlp import MLP
16 | import trainer
17 | from torch.optim.lr_scheduler import ReduceLROnPlateau
18 |
19 |
class Trainer(trainer.GenericTrainer):
    """Trainer that optimises a classifier jointly with a group adversary.

    The adversary (``self.sa_clf``) is fed the classifier logits, laid out
    according to ``self.target_criterion`` (see ``_build_adv_inputs``), and is
    trained to predict the sensitive group. Classifier and adversary losses
    are summed and back-propagated in one pass.
    """

    def __init__(self, args, **kwargs):
        """Read the adversary hyper-parameters from ``args``."""
        super().__init__(args=args, **kwargs)
        self.adv_lambda = args.lamb  # forwarded to the adversary MLP as ``adv_lambda``
        self.adv_lr = args.eta  # learning rate of the adversary's Adam optimizer
        self.target_criterion = 'eo'  # 'eo' or 'dp'; selects the adversary input layout

    def train(self, train_loader, test_loader, epochs):
        """Train for ``epochs`` epochs, evaluating on ``test_loader`` after each.

        Both learning-rate schedulers are stepped with the corresponding
        evaluation loss. Returns the trained classifier (``self.model``).
        """
        model = self.model
        num_groups = train_loader.dataset.num_groups
        num_classes = train_loader.dataset.num_classes
        self._init_adversary(num_groups, num_classes, train_loader)
        self.scheduler = ReduceLROnPlateau(self.optimizer, patience=5)

        for epoch in range(epochs):
            self._train_epoch(epoch, train_loader, model)

            eval_start_time = time.time()
            eval_loss, eval_acc, eval_adv_loss, eval_adv_acc, eval_deopp = \
                self.evaluate(model, self.sa_clf, test_loader, self.criterion, self.adv_criterion)
            eval_end_time = time.time()
            print('[{}/{}] Method: {} '
                  'Test Loss: {:.3f} Test Acc: {:.2f} Test Adv loss: {:.3f} Test Adv Acc: {:.2f} Test DEopp {:.2f} [{:.2f} s]'.format
                  (epoch + 1, epochs, self.method,
                   eval_loss, eval_acc, eval_adv_loss, eval_adv_acc, eval_deopp, (eval_end_time - eval_start_time)))

            if self.scheduler is not None:
                self.scheduler.step(eval_loss)
                self.adv_scheduler.step(eval_adv_loss)

        print('Training Finished!')
        return model

    def _build_adv_inputs(self, logits, labels, num_classes):
        """Build the adversary input from classifier logits.

        * ``'dp'``: the logits themselves.
        * ``'eo'``: the logits concatenated with ``num_classes`` copies of the
          logits in which every copy except the one matching the sample's
          label is zeroed, giving ``num_classes * (num_classes + 1)`` features.

        Raises:
            ValueError: if ``self.target_criterion`` is neither value.
        """
        if self.target_criterion == 'eo':
            input_loc = F.one_hot(labels.long(), num_classes).repeat_interleave(num_classes, dim=1)
            masked = logits.repeat(1, num_classes) * input_loc
            return torch.cat((logits, masked), dim=1)
        if self.target_criterion == 'dp':
            return logits
        raise ValueError('unknown target_criterion: {!r}'.format(self.target_criterion))

    def _train_epoch(self, epoch, train_loader, model):
        """Run one training epoch, logging running averages every ``self.term`` batches."""
        num_classes = train_loader.dataset.num_classes

        model.train()

        running_acc = 0.0
        running_loss = 0.0
        running_adv_loss = 0.0
        batch_start_time = time.time()

        for i, data in enumerate(train_loader):
            # Each batch unpacks into five items; inputs, groups and targets are used.
            inputs, _, groups, targets, _ = data
            labels = targets

            if self.cuda:
                inputs = inputs.cuda()
                labels = labels.cuda()
                groups = groups.cuda()

            labels = labels.long()
            groups = groups.long()

            outputs = model(inputs)

            adv_inputs = self._build_adv_inputs(outputs, labels, num_classes)
            adv_preds = self.sa_clf(adv_inputs)
            adv_loss = self.adv_criterion(adv_preds, groups)

            self.optimizer.zero_grad()
            self.adv_optimizer.zero_grad()

            loss = self.criterion(outputs, labels)

            # NOTE(review): the adversary is built with adv=True, so it
            # presumably reverses gradients flowing back into the classifier;
            # confirm against networks.mlp.MLP.
            (loss + adv_loss).backward()

            self.optimizer.step()
            self.adv_optimizer.step()

            running_loss += loss.item()
            running_adv_loss += adv_loss.item()
            running_acc += get_accuracy(outputs, labels)

            if i % self.term == self.term - 1:  # print every self.term mini-batches
                avg_batch_time = time.time() - batch_start_time
                print_statement = '[{}/{}, {:5d}] Method: {} Train Loss: {:.3f} Adv Loss: {:.3f} Train Acc: {:.2f} [{:.2f} s/batch]'\
                    .format(epoch + 1, self.epochs, i + 1, self.method, running_loss / self.term,
                            running_adv_loss / self.term, running_acc / self.term, avg_batch_time / self.term)
                print(print_statement)

                running_loss = 0.0
                running_acc = 0.0
                running_adv_loss = 0.0
                batch_start_time = time.time()

    def evaluate(self, model, adversary, loader, criterion, adv_criterion):
        """Evaluate classifier and adversary on ``loader``.

        Args:
            model: classifier; called as ``model(inputs, get_inter=False)``.
            adversary: group classifier fed with ``_build_adv_inputs`` output.
            loader: data loader; if its type name contains ``'Custom'`` its
                ``generate()`` result is iterated instead.
            criterion: callable ``(logits, labels) -> loss`` for the classifier.
            adv_criterion: callable ``(adv_preds, groups) -> loss``.

        Returns:
            ``(loss, acc, adv_loss, adv_acc, max_eopp)``. The first four are
            averaged over the samples counted in the (group, class) table;
            ``max_eopp`` is the largest per-class gap between the best and
            worst group accuracy, as a Python float.
        """
        model.eval()
        num_groups = loader.dataset.num_groups
        num_classes = loader.dataset.num_classes
        eval_acc = 0
        eval_adv_acc = 0
        eval_loss = 0
        eval_adv_loss = 0
        eval_eopp_list = torch.zeros(num_groups, num_classes)
        eval_data_count = torch.zeros(num_groups, num_classes)
        # Respect the configured device instead of always requiring a GPU.
        if self.cuda:
            eval_eopp_list = eval_eopp_list.cuda()
            eval_data_count = eval_data_count.cuda()

        if 'Custom' in type(loader).__name__:
            loader = loader.generate()
        with torch.no_grad():
            for eval_data in loader:
                inputs, _, groups, classes, _ = eval_data
                labels = classes
                groups = groups.long()
                if self.cuda:
                    inputs = inputs.cuda()
                    labels = labels.cuda()
                    groups = groups.cuda()

                labels = labels.long()

                outputs = model(inputs, get_inter=False)
                adv_inputs = self._build_adv_inputs(outputs, labels, num_classes)

                loss = criterion(outputs, labels)
                eval_loss += loss.item() * len(labels)
                acc = get_accuracy(outputs, labels, reduction='none')
                eval_acc += acc.sum()

                for g in range(num_groups):
                    for l in range(num_classes):
                        in_cell = (groups == g) * (labels == l)
                        eval_eopp_list[g, l] += acc[in_cell].sum()
                        eval_data_count[g, l] += torch.sum(in_cell)

                adv_preds = adversary(adv_inputs)
                adv_loss = adv_criterion(adv_preds, groups)
                eval_adv_loss += adv_loss.item() * len(labels)
                # Accumulate per-sample correctness (like eval_acc above) so the
                # division by the sample count below yields a real accuracy;
                # summing per-batch means under-reported it.
                eval_adv_acc += get_accuracy(adv_preds, groups, reduction='none').sum()

        # NOTE(review): only samples whose group/label fall inside the table
        # are counted here, while the loss sums cover every sample.
        total_count = eval_data_count.sum()
        eval_loss = eval_loss / total_count
        eval_acc = eval_acc / total_count
        eval_adv_loss = eval_adv_loss / total_count
        eval_adv_acc = eval_adv_acc / total_count
        eval_eopp_list = eval_eopp_list / eval_data_count
        eval_max_eopp = torch.max(eval_eopp_list, dim=0)[0] - torch.min(eval_eopp_list, dim=0)[0]
        eval_max_eopp = torch.max(eval_max_eopp).item()
        model.train()
        return eval_loss, eval_acc, eval_adv_loss, eval_adv_acc, eval_max_eopp

    def _init_adversary(self, num_groups, num_classes, dataloader):
        """Create the adversary MLP with its optimizer, scheduler and criterion.

        ``dataloader`` is accepted but not used.

        Raises:
            ValueError: if ``self.target_criterion`` is neither 'eo' nor 'dp'.
        """
        self.model.eval()
        if self.target_criterion == 'eo':
            feature_size = num_classes * (num_classes + 1)
        elif self.target_criterion == 'dp':
            feature_size = num_classes
        else:
            raise ValueError('unknown target_criterion: {!r}'.format(self.target_criterion))

        sa_clf = MLP(feature_size=feature_size, hidden_dim=32, num_classes=num_groups,
                     num_layer=2, adv=True, adv_lambda=self.adv_lambda)
        if self.cuda:
            sa_clf.cuda()
        sa_clf.train()
        self.sa_clf = sa_clf
        self.adv_optimizer = optim.Adam(sa_clf.parameters(), lr=self.adv_lr)
        self.adv_scheduler = ReduceLROnPlateau(self.adv_optimizer, patience=5)
        self.adv_criterion = self.criterion

    def criterion(self, model, outputs, labels=None):
        """Return the mean cross-entropy loss.

        Accepts both ``criterion(model, outputs, labels)`` and the two-argument
        form ``criterion(outputs, labels)`` that every call site in this class
        uses; ``model`` is never used for the computation.
        """
        if labels is None:
            # Two-argument call: shift the positional arguments.
            outputs, labels = model, outputs
        return nn.CrossEntropyLoss()(outputs, labels)
238 |
--------------------------------------------------------------------------------
/data_handler/ssl_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | cgl_fairness
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT license
5 | """
6 | import random
7 | import numpy as np
8 | import torch
9 | from data_handler import GenericDataset
10 | import os
11 | import pickle
12 |
13 |
class SSLDataset(GenericDataset):
    """Dataset base that simulates partially group-annotated data.

    A fraction ``sv_ratio`` of the samples in every (group, class) cell keeps
    its group annotation; ``version`` selects how the remaining samples are
    treated (dropped, randomly relabelled, pseudo-labelled, ...). The
    group label of a sample is stored in ``features[idx][0]``.
    """

    def __init__(self, sv_ratio=1.0, version='', ups_iter=0, **kwargs):
        super(SSLDataset, self).__init__(**kwargs)
        self.sv_ratio = sv_ratio
        self.version = version
        self.add_attr = None
        self.ups_iter = ups_iter

    def _celeba_suffix(self):
        """Return the filename suffix for non-default CelebA attributes, else ''."""
        suffix = ''
        if self.name == 'celeba':
            if self.target_attr != 'Attractive':
                suffix += f'_{self.target_attr}'
            if self.add_attr is not None:
                suffix += f'_{self.add_attr}'
        return suffix

    @staticmethod
    def _annotated_counts(idxs_dict, num_groups, num_classes):
        """Return an int array ``[class, group]`` of annotated sample counts."""
        counts = np.zeros((num_classes, num_groups))
        for l in range(num_classes):
            for g in range(num_groups):
                counts[l, g] = len(idxs_dict['annotated'][(g, l)])
        return counts.astype(int)

    @staticmethod
    def _gather(features, idxs):
        """Return ``features`` restricted to ``idxs``, in ascending index order."""
        return [features[idx] for idx in sorted(idxs)]

    def ssl_processing(self, features, num_data, idxs_per_group, idxs_dict=None):
        """Apply the partial-annotation scheme selected by ``self.version``.

        Args:
            features: per-sample records; element 0 of each is the group label.
            num_data: array of shape ``(num_groups, num_classes)`` with counts.
            idxs_per_group: mapping ``(group, class) -> list of sample indices``.
            idxs_dict: optional precomputed annotated / non-annotated split;
                when None it is loaded from (or created at) a pickle file
                under ``<root>/annotated_idxs``.

        Returns:
            ``(features, num_data, idxs_per_group)`` after processing. The
            test split is returned untouched unless the version name contains
            ``'group'``.

        Raises:
            ValueError: if ``sv_ratio >= 1``, the version is unknown, or a
                group-classifier version is used with an unexpected split.
        """
        if self.sv_ratio >= 1:
            raise ValueError('ssl_processing requires sv_ratio < 1, got {}'.format(self.sv_ratio))

        if self.split == 'test' and 'group' not in self.version:
            return features, num_data, idxs_per_group

        print('preprocessing for ssl...')
        num_groups, num_classes = num_data.shape

        folder_name = 'annotated_idxs'
        idx_filename = '{}_{}'.format(self.seed, self.sv_ratio) + self._celeba_suffix() + '.pkl'
        filepath = os.path.join(self.root, folder_name, idx_filename)

        if idxs_dict is None:
            if not os.path.isfile(filepath):
                idxs_dict = self.pick_idxs(num_data, idxs_per_group, filepath)
            else:
                # NOTE: pickle must only be used on trusted, locally generated files.
                with open(filepath, 'rb') as f:
                    idxs_dict = pickle.load(f)

        if self.version == 'bs1':
            # Keep only the annotated samples.
            new_idxs = []
            for g in range(num_groups):
                for l in range(num_classes):
                    new_idxs.extend(idxs_dict['annotated'][(g, l)])
            features = self._gather(features, new_idxs)

        elif self.version == 'bs2':
            # Draw a random group for non-annotated samples, weighted by the
            # annotated group counts of the sample's class.
            idx_pool = list(range(num_groups))
            total_per_group = self._annotated_counts(idxs_dict, num_groups, num_classes)
            for g in range(num_groups):
                for l in range(num_classes):
                    for idx in idxs_dict['non-annotated'][(g, l)]:
                        features[idx][0] = random.choices(idx_pool, k=1, weights=list(total_per_group[l]))[0]

        elif self.version == 'bs1uncertain':
            # Keep annotated samples plus non-annotated ones whose stored
            # prediction is not -1, using that prediction as group label.
            folder = 'group_clf'
            model = 'resnet18_dropout'
            epochs = '70'
            filename_pre = f'{model}_seed{self.seed}_epochs{epochs}_bs128_lr0.001'
            filename_post = f'_sv{self.sv_ratio}_groupclf.pt'
            path = os.path.join(self.root, folder, filename_pre + filename_post)
            preds = torch.load(path)['pred']

            new_idxs = []
            for g in range(num_groups):
                for l in range(num_classes):
                    new_idxs.extend(idxs_dict['annotated'][(g, l)])
                    for idx in idxs_dict['non-annotated'][(g, l)]:
                        if preds[idx] != -1:
                            new_idxs.append(idx)
                            features[idx][0] = preds[idx].item()

            print(len(new_idxs), len(features))
            features = self._gather(features, new_idxs)

        elif self.version == 'bs3':
            # Pseudo-label every non-annotated sample with the stored prediction.
            folder = 'group_clf'
            image_data = self.name in ['utkface', 'celeba']
            model = 'resnet18' if image_data else 'mlp'
            epochs = '70' if image_data else '50'
            filename_pre = f'{model}_seed{self.seed}_epochs{epochs}_bs128_lr0.001' + self._celeba_suffix()
            filename_post = f'_sv{self.sv_ratio}_groupclf.pt'
            path = os.path.join(self.root, folder, filename_pre + filename_post)
            preds = torch.load(path)['pred']

            for g in range(num_groups):
                for l in range(num_classes):
                    for idx in idxs_dict['non-annotated'][(g, l)]:
                        features[idx][0] = preds[idx].item()

        elif self.version == 'groupclf':
            if self.ups_iter > 0:
                folder = 'group_clf'
                filename_pre = f'resnet18_dropout_seed{self.seed}_epochs70_bs128_lr0.001'
                if self.ups_iter > 1:
                    # The original referenced an undefined name `_iter` here.
                    # NOTE(review): assumes iteration k reads the predictions
                    # saved by iteration k-1 - confirm against the writer.
                    filename_post = f'_sv{self.sv_ratio}_iter{self.ups_iter - 1}_groupclf.pt'
                else:
                    filename_post = f'_sv{self.sv_ratio}_groupclf.pt'
                path = os.path.join(self.root, folder, filename_pre + filename_post)
                # NOTE(review): the predictions are loaded but not used below;
                # the load is kept so a missing file still fails early.
                preds = torch.load(path)['pred']

            features = self._gather(features, self._split_idxs(idxs_dict, num_groups, num_classes))

        elif self.version == 'groupclf_val':
            # Train split: first 80% of each annotated cell; the rest is
            # recorded in self.val_idxs. Test split: non-annotated samples.
            new_idxs = []
            val_idxs = []
            for g in range(num_groups):
                for l in range(num_classes):
                    if self.split == 'train':
                        annotated = idxs_dict['annotated'][(g, l)]
                        train_num = int(len(annotated) * 0.8)
                        idxs = annotated[:train_num]
                        val_idxs.extend(annotated[train_num:])
                    elif self.split == 'test':
                        idxs = idxs_dict['non-annotated'][(g, l)]
                    else:
                        raise ValueError('unsupported split for groupclf_val: {!r}'.format(self.split))
                    new_idxs.extend(idxs)
            features = self._gather(features, new_idxs)
            self.val_idxs = val_idxs

        elif self.version == 'oracle':
            # this version is to make an oracle model about predicting noisy labels.
            folder = 'group_clf'
            filename = 'resnet18_seed{}_epochs70_bs128_lr0.001_sv{}_groupclf_val.pt'.format(self.seed, self.sv_ratio)
            path = os.path.join(self.root, folder, filename)
            preds = torch.load(path)['pred']
            idx_pool = list(range(num_groups))
            for g in range(num_groups):
                for l in range(num_classes):
                    for idx in idxs_dict['non-annotated'][(g, l)]:
                        if features[idx][0] != preds[idx].item():
                            features[idx][0] = random.choices(idx_pool, k=1)[0]

        elif self.version == 'cgl':
            # Confident samples take the predicted group; the others get a
            # random group weighted by the annotated counts of their class.
            folder = 'group_clf'
            image_data = self.name in ['utkface', 'celeba', 'utkface_fairface']
            model = 'resnet18' if image_data else 'mlp'
            epochs = '70' if image_data else '50'
            bs = '128' if self.name != 'adult' else '1024'
            filename_pre = f'{model}_seed{self.seed}_epochs{epochs}_bs{bs}_lr0.001' + self._celeba_suffix()
            filename_post = f'_sv{self.sv_ratio}_groupclf_val.pt'
            filename = filename_pre + filename_post

            if self.name == 'utkface_fairface':
                path = os.path.join('./data/utkface_fairface', folder, filename)
            else:
                path = os.path.join(self.root, folder, filename)

            # Read the checkpoint once instead of once per field.
            checkpoint = torch.load(path)
            preds = checkpoint['pred']
            probs = checkpoint['probs']
            thres = checkpoint['opt_thres']
            print('thres : ', thres)
            idx_pool = list(range(num_groups))

            total_per_group = self._annotated_counts(idxs_dict, num_groups, num_classes)
            for g in range(num_groups):
                for l in range(num_classes):
                    for idx in idxs_dict['non-annotated'][(g, l)]:
                        if probs[idx].max() >= thres:
                            features[idx][0] = preds[idx].item()
                        else:
                            features[idx][0] = random.choices(idx_pool, k=1, weights=list(total_per_group[l]))[0]

        else:
            raise ValueError('unknown ssl version: {!r}'.format(self.version))

        print('count the number of data newly!')
        num_data, idxs_per_group = self._data_count(features, num_groups, num_classes)

        return features, num_data, idxs_per_group

    def _split_idxs(self, idxs_dict, num_groups, num_classes):
        """Return annotated indices for the train split, non-annotated for test."""
        if self.split == 'train':
            key = 'annotated'
        elif self.split == 'test':
            key = 'non-annotated'
        else:
            raise ValueError('unsupported split for groupclf: {!r}'.format(self.split))
        selected = []
        for g in range(num_groups):
            for l in range(num_classes):
                selected.extend(idxs_dict[key][(g, l)])
        return selected

    def pick_idxs(self, num_data, idxs_per_group, filepath):
        """Randomly split every (group, class) cell and pickle the result.

        ``int(count * (1 - sv_ratio))`` indices of each cell are sampled as
        non-annotated; the rest stay annotated (in their original order).

        Returns:
            dict with keys ``'annotated'`` and ``'non-annotated'``, each
            mapping ``(group, class)`` to a list of sample indices.
        """
        # The original format string was empty and printed a blank line.
        print('picking annotated idxs -> {}'.format(filepath))
        os.makedirs(os.path.join(self.root, 'annotated_idxs'), exist_ok=True)
        num_groups, num_classes = num_data.shape
        idxs_dict = {'annotated': {}, 'non-annotated': {}}
        for g in range(num_groups):
            for l in range(num_classes):
                num_nonannotated = int(num_data[g, l] * (1-self.sv_ratio))
                print(g, l, num_nonannotated)
                idxs_nonannotated = random.sample(idxs_per_group[(g, l)], num_nonannotated)
                # Set lookup keeps the filtering linear in the cell size.
                held_out = set(idxs_nonannotated)
                idxs_annotated = [idx for idx in idxs_per_group[(g, l)] if idx not in held_out]
                idxs_dict['non-annotated'][(g, l)] = idxs_nonannotated
                idxs_dict['annotated'][(g, l)] = idxs_annotated

        with open(filepath, 'wb') as f:
            pickle.dump(idxs_dict, f, pickle.HIGHEST_PROTOCOL)

        return idxs_dict
239 |
--------------------------------------------------------------------------------
/networks/resnet_dropout.py:
--------------------------------------------------------------------------------
1 | """
2 | Original code:
3 | https://github.com/nayeemrizve/ups
4 | """
5 | import sys
6 | import math
7 | import itertools
8 |
9 | import torch
10 | from torch import nn
11 | from torch.nn import functional as F
12 | from torch.autograd import Variable, Function
13 | # import torch.utils.model_zoo as model_zoo
14 |
class ResNet224x224(nn.Module):
    """Four-stage residual network with a 7x7 stem and a linear head.

    ``block`` supplies the residual unit; it must expose a classmethod
    ``out_channels(planes, groups)`` and be constructible as
    ``block(inplanes, planes, groups, stride, downsample)``.
    """

    def __init__(self, block, layers, channels, groups=1, num_classes=1000, downsample='basic'):
        super().__init__()
        assert len(layers) == 4
        depth1, depth2, depth3, depth4 = layers
        self.downsample_mode = downsample
        self.inplanes = 64

        # Stem.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages; the width doubles at every stage after the first.
        self.layer1 = self._make_layer(block, channels, groups, depth1)
        self.layer2 = self._make_layer(block, channels * 2, groups, depth2, stride=2)
        self.layer3 = self._make_layer(block, channels * 4, groups, depth3, stride=2)
        self.layer4 = self._make_layer(block, channels * 8, groups, depth4, stride=2)

        # Head.
        self.avgpool = nn.AvgPool2d(7)
        head_width = block.out_channels(channels * 8, groups)
        self.fc = nn.Linear(head_width, num_classes)

        # Conv weights ~ N(0, 2 / (kh * kw * out_channels)); BN starts as identity.
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Conv2d):
                k_h, k_w = module.kernel_size
                fan = k_h * k_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan))

    def _make_layer(self, block, planes, groups, blocks, stride=1):
        """Stack ``blocks`` residual units; only the first may stride or project."""
        out_planes = block.out_channels(planes, groups)
        shortcut = None
        if stride != 1 or self.inplanes != out_planes:
            if self.downsample_mode == 'basic' or stride == 1:
                shortcut = nn.Sequential(
                    nn.Conv2d(self.inplanes, out_planes,
                              kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(out_planes),
                )
            elif self.downsample_mode == 'shift_conv':
                shortcut = ShiftConvDownsample(in_channels=self.inplanes,
                                               out_channels=out_planes)
            else:
                assert False

        stage = [block(self.inplanes, planes, groups, stride, shortcut)]
        self.inplanes = out_planes
        stage.extend(block(self.inplanes, planes, groups) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the logits for a batch of images."""
        out = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        out = self.avgpool(out)
        flat = out.view(out.size(0), -1)
        return self.fc(flat)
80 |
81 |
class ResNet32x32(nn.Module):
    """Three-stage residual network with a 3x3 stem and a linear head.

    ``block`` must expose a classmethod ``out_channels(planes, groups)`` and
    accept a ``dropout`` keyword; ``dropout`` is forwarded to every block.
    """

    def __init__(self, block, layers, channels, groups=1, downsample='basic', num_classes=10, dropout=0.3):
        super().__init__()
        assert len(layers) == 3
        depth1, depth2, depth3 = layers
        self.downsample_mode = downsample
        self.inplanes = 16
        self.dropout = dropout

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.layer1 = self._make_layer(block, channels, groups, depth1)
        self.layer2 = self._make_layer(block, channels * 2, groups, depth2, stride=2)
        self.layer3 = self._make_layer(block, channels * 4, groups, depth3, stride=2)
        self.avgpool = nn.AvgPool2d(8)
        head_width = block.out_channels(channels * 4, groups)
        self.fc = nn.Linear(head_width, num_classes)

        # Conv weights ~ N(0, 2 / (kh * kw * out_channels)); BN starts as identity.
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Conv2d):
                k_h, k_w = module.kernel_size
                fan = k_h * k_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan))

    def _make_layer(self, block, planes, groups, blocks, stride=1):
        """Stack ``blocks`` residual units; only the first may stride or project."""
        out_planes = block.out_channels(planes, groups)
        shortcut = None
        if stride != 1 or self.inplanes != out_planes:
            if self.downsample_mode == 'basic' or stride == 1:
                shortcut = nn.Sequential(
                    nn.Conv2d(self.inplanes, out_planes,
                              kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(out_planes),
                )
            elif self.downsample_mode == 'shift_conv':
                shortcut = ShiftConvDownsample(in_channels=self.inplanes,
                                               out_channels=out_planes)
            else:
                assert False

        stage = [block(self.inplanes, planes, groups, stride, shortcut, dropout=self.dropout)]
        self.inplanes = out_planes
        stage.extend(block(self.inplanes, planes, groups, dropout=self.dropout)
                     for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x, randomness=0, mini_batch=1):
        """Return the logits; ``randomness`` and ``mini_batch`` are not used."""
        out = self.conv1(x)
        for stage in (self.layer1, self.layer2, self.layer3):
            out = stage(out)
        out = self.avgpool(out)
        flat = out.view(out.size(0), -1)
        return self.fc(flat)
139 |
140 |
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 convolution with padding 1."""
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
145 |
146 |
class BottleneckBlock(nn.Module):
    """Residual bottleneck unit: 1x1 reduce, 3x3 (optionally grouped), 1x1 expand."""

    @classmethod
    def out_channels(cls, planes, groups):
        """Return the output width: ``2 * planes`` when grouped, else ``4 * planes``."""
        expansion = 2 if groups > 1 else 4
        return expansion * planes

    def __init__(self, inplanes, planes, groups, stride=1, downsample=None):
        super().__init__()
        expanded = self.out_channels(planes, groups)
        self.relu = nn.ReLU(inplace=True)

        self.conv_a1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn_a1 = nn.BatchNorm2d(planes)
        self.conv_a2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                                 padding=1, bias=False, groups=groups)
        self.bn_a2 = nn.BatchNorm2d(planes)
        self.conv_a3 = nn.Conv2d(planes, expanded, kernel_size=1, bias=False)
        self.bn_a3 = nn.BatchNorm2d(expanded)

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(shortcut(x) + branch(x))``."""
        out = self.relu(self.bn_a1(self.conv_a1(x)))
        out = self.relu(self.bn_a2(self.conv_a2(out)))
        out = self.bn_a3(self.conv_a3(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(shortcut + out)
187 |
188 |
class ShakeShakeBlock(nn.Module):
    """Residual unit with two parallel conv branches mixed by ``shake``.

    Each branch is ReLU -> conv3x3 -> BN -> ReLU -> (dropout) -> conv3x3 -> BN.
    Only ``groups == 1`` is supported.
    """

    @classmethod
    def out_channels(cls, planes, groups):
        """Return the output width, which equals ``planes``."""
        assert groups == 1
        return planes

    def __init__(self, inplanes, planes, groups, stride=1, downsample=None, dropout=0.3):
        super().__init__()
        assert groups == 1
        # Branch a.
        self.conv_a1 = conv3x3(inplanes, planes, stride)
        self.bn_a1 = nn.BatchNorm2d(planes)
        self.conv_a2 = conv3x3(planes, planes)
        self.bn_a2 = nn.BatchNorm2d(planes)

        # Branch b.
        self.conv_b1 = conv3x3(inplanes, planes, stride)
        self.bn_b1 = nn.BatchNorm2d(planes)
        self.conv_b2 = conv3x3(planes, planes)
        self.bn_b2 = nn.BatchNorm2d(planes)

        self.downsample = downsample
        self.stride = stride
        self.dropout = dropout
        # Dropout modules exist only when dropout is enabled.
        if self.dropout > 0:
            self.drop_a = nn.Dropout(p=self.dropout)
            self.drop_b = nn.Dropout(p=self.dropout)

    def _run_branch(self, x, first_conv, first_bn, second_conv, second_bn, drop):
        """Apply one branch to ``x``; ``drop`` is None when dropout is disabled."""
        out = first_bn(first_conv(F.relu(x, inplace=False)))
        out = F.relu(out, inplace=True)
        if drop is not None:
            out = drop(out)
        return second_bn(second_conv(out))

    def forward(self, x):
        """Return ``shortcut(x) + shake(branch_a(x), branch_b(x))``."""
        with_dropout = self.dropout > 0
        a = self._run_branch(x, self.conv_a1, self.bn_a1, self.conv_a2, self.bn_a2,
                             self.drop_a if with_dropout else None)
        b = self._run_branch(x, self.conv_b1, self.bn_b1, self.conv_b2, self.bn_b2,
                             self.drop_b if with_dropout else None)

        mixed = shake(a, b, training=self.training)

        shortcut = x if self.downsample is None else self.downsample(x)
        return shortcut + mixed
243 |
class Shake(Function):
    """Autograd function mixing two tensors with a per-sample random gate.

    Forward computes ``inp1 * gate + inp2 * (1 - gate)`` where ``gate`` has one
    value per sample (first dimension), drawn from U(0, 1) in training and
    fixed at 0.5 otherwise. Backward draws a fresh, independent gate to split
    the incoming gradient between the two inputs.

    ``forward`` and ``backward`` are static methods, as the PyTorch autograd
    extension API requires for functions invoked through ``apply``.
    """

    @staticmethod
    def forward(ctx, inp1, inp2, training):
        assert inp1.size() == inp2.size()
        # One gate value per sample, broadcast over the remaining dimensions.
        gate_size = [inp1.size()[0], *itertools.repeat(1, inp1.dim() - 1)]
        gate = inp1.new_empty(gate_size)
        if training:
            gate.uniform_(0, 1)
        else:
            gate.fill_(0.5)
        return inp1 * gate + inp2 * (1. - gate)

    @staticmethod
    def backward(ctx, grad_output):
        grad_inp1 = grad_inp2 = grad_training = None
        gate_size = [grad_output.size()[0], *itertools.repeat(1,
                                                              grad_output.dim() - 1)]
        # Independent of the forward gate by design.
        gate = grad_output.new_empty(gate_size).uniform_(0, 1)
        if ctx.needs_input_grad[0]:
            grad_inp1 = grad_output * gate
        if ctx.needs_input_grad[1]:
            grad_inp2 = grad_output * (1 - gate)
        # `training` is a plain flag and never receives a gradient.
        assert not ctx.needs_input_grad[2]
        return grad_inp1, grad_inp2, grad_training
268 |
269 |
def shake(inp1, inp2, training=False):
    """Mix ``inp1`` and ``inp2`` through the ``Shake`` autograd function."""
    mixed = Shake.apply(inp1, inp2, training)
    return mixed
272 |
273 |
class ShiftConvDownsample(nn.Module):
    """Halve the spatial size by stacking two shifted sub-grids and projecting.

    The even-indexed and the odd-indexed pixel grids are concatenated along
    the channel axis, passed through ReLU, a grouped (groups=2) 1x1
    convolution and batch normalisation.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(2 * in_channels, out_channels, kernel_size=1, groups=2)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        even_grid = x[:, :, 0::2, 0::2]
        odd_grid = x[:, :, 1::2, 1::2]
        stacked = torch.cat((even_grid, odd_grid), dim=1)
        return self.bn(self.conv(self.relu(stacked)))
291 |
def resnet18_dropout(pretrained=False, num_groups=0, img_size=224, **kwargs):
    """Build a ``ResNet224x224`` of ``ShakeShakeBlock`` units, two per stage.

    ``pretrained``, ``num_groups`` and ``img_size`` are accepted but not used;
    the remaining keyword arguments are forwarded to ``ResNet224x224``.
    """
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet224x224(ShakeShakeBlock, channels=64, layers=blocks_per_stage, **kwargs)
294 |
295 |
if __name__ == "__main__":
    # Smoke test: build a 32x32 shake-shake network and run one forward pass.
    model = ResNet32x32(ShakeShakeBlock,
                        layers=[4, 4, 4],
                        channels=96,
                        downsample='shift_conv')
    x = torch.randn(1, 3, 32, 32)
    # forward() returns only the logits, so there is no second value to
    # unpack (the previous `y, feature = model(x)` raised ValueError).
    y = model(x)
    print(y.size())  # torch.Size([1, 10]) with the default num_classes
308 |
--------------------------------------------------------------------------------
/networks/resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Original code:
3 | https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 | from torchvision.models.utils import load_state_dict_from_url
8 |
9 |
# Names exported by a star-import of this module.
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50']
11 |
12 |
# Download locations of the pretrained checkpoints, keyed by architecture name.
model_urls = dict(
    resnet18='https://download.pytorch.org/models/resnet18-5c106cde.pth',
    resnet34='https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    resnet50='https://download.pytorch.org/models/resnet50-19c8e357.pth',
)
18 |
19 |
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 convolution whose padding equals its dilation."""
    conv_options = dict(kernel_size=3, stride=stride, padding=dilation,
                        groups=groups, bias=False, dilation=dilation)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
24 |
25 |
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 (pointwise) convolution."""
    conv_options = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
29 |
30 |
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions with a skip connection.

    Only ``groups=1``, ``base_width=64`` and ``dilation=1`` are supported;
    other values raise in the constructor. ``downsample``, when given, is
    applied to the block input to produce the skip branch.
    """
    expansion = 1
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # When stride != 1 both conv1 and the downsample branch reduce the
        # spatial resolution of the input.
        # Submodules are registered in the same order as before so that
        # state_dict keys and module iteration order are unchanged.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + skip(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Skip branch: identity unless a projection module was supplied.
        identity = x if self.downsample is None else self.downsample(x)

        out += identity
        return self.relu(out)
70 |
71 |
class Bottleneck(nn.Module):
    """Residual block with a 1x1 -> 3x3 -> 1x1 convolution stack.

    The inner width is ``int(planes * (base_width / 64.)) * groups`` and the
    block outputs ``planes * expansion`` channels. ``downsample``, when
    given, is applied to the block input to produce the skip branch.
    """
    expansion = 4
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        width = int(planes * (base_width / 64.)) * groups
        out_channels = planes * self.expansion
        # When stride != 1 both conv2 and the downsample branch reduce the
        # spatial resolution of the input.
        # Submodules are registered in the same order as before so that
        # state_dict keys and module iteration order are unchanged.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, out_channels)
        self.bn3 = norm_layer(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the three conv/BN stages, add the skip branch, apply ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Skip branch: identity unless a projection module was supplied.
        identity = x if self.downsample is None else self.downsample(x)

        out += identity
        return self.relu(out)
114 |
115 |
class ResNet(nn.Module):
    """ResNet backbone with an optional multi-head ("hydra") classifier.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``); must
            expose an ``expansion`` attribute.
        layers: four integers, the number of blocks in each of the four stages.
        num_classes: output size of every classification head.
        zero_init_residual: if True, zero the weight of the last norm layer of
            each residual block so the block starts as an identity mapping.
        groups, width_per_group: forwarded to every block.
        replace_stride_with_dilation: ``None`` or three booleans; each one
            replaces the stride-2 of stages 2-4 with dilation.
        norm_layer: normalisation layer factory, ``nn.BatchNorm2d`` by default.
        hydra: if True, ``fc`` is an ``nn.ModuleList`` of ``num_groups``
            linear heads instead of a single ``nn.Linear``.
        num_groups: number of heads; required when ``hydra`` is True.
        img_size: 32 selects a 3x3 stride-1 stem without max-pooling; any
            other value selects the 7x7 stride-2 stem followed by max-pooling.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have three
            elements.
        TypeError: if ``hydra`` is True and ``num_groups`` is None.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, hydra=False, num_groups=None, img_size=32):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.img_size = img_size
        if img_size == 32:
            # Small-image stem: keep full resolution (max-pool is skipped in forward).
            self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1,
                                   bias=False)
        else:
            self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                                   bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        # Always registered (it has no parameters); only used when img_size != 32.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        feature_dim = 512 * block.expansion
        if not hydra:
            self.fc = nn.Linear(feature_dim, num_classes)
        else:
            if num_groups is None:
                # range(None) would raise the same exception type with an
                # opaque message; fail early with an explanation instead.
                raise TypeError("hydra=True requires an integer num_groups "
                                "(one classification head is built per group)")
            self.fc = nn.ModuleList()
            for _ in range(num_groups):
                self.fc.append(torch.nn.Linear(feature_dim, num_classes))

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks as an ``nn.Sequential``.

        Only the first block of the stage receives ``stride`` and the
        projection shortcut; ``self.inplanes`` is updated to the stage's
        output channel count as a side effect.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade the spatial stride for dilation in this stage.
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut so the skip branch matches the main branch.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _classify(self, features):
        """Apply the classifier head(s) to flattened features.

        Returns a single tensor for a plain ``nn.Linear`` head, or a list with
        one tensor per head when ``fc`` is an ``nn.ModuleList`` (hydra mode).
        """
        if isinstance(self.fc, nn.ModuleList):
            return [head(features) for head in self.fc]
        return self.fc(features)

    def _forward_impl(self, x, get_inter=False, reid = False):
        """Run the network.

        Returns:
            ``(b1, b2, b3, b4, logits)`` if ``get_inter`` is True; otherwise
            the flattened pooled features if ``reid`` is True; otherwise the
            logits (a list of tensors in hydra mode).
        """
        # See note [TorchScript super()]
        h = self.conv1(x)
        h = self.bn1(h)
        h = self.relu(h)
        if self.img_size != 32:
            h = self.maxpool(h)

        b1 = self.layer1(h)
        b2 = self.layer2(b1)
        b3 = self.layer3(b2)
        b4 = self.layer4(b3)

        h = self.avgpool(b4)
        h1 = torch.flatten(h, 1)
        h = self._classify(h1)

        if get_inter:
            return b1, b2, b3, b4, h
        elif reid:
            return h1
        else:
            return h

    def forward(self, x, get_inter=False, reid=False):
        """Public entry point; see ``_forward_impl`` for the return values."""
        return self._forward_impl(x, get_inter, reid)

    def adapt_classifier(self, x):
        """Pool and classify a feature map ``x`` (e.g. an output of ``layer4``).

        Hydra models return one output per head, consistent with ``forward``;
        previously this method called the ``nn.ModuleList`` directly and failed.
        """
        h = self.avgpool(x)
        h1 = torch.flatten(h, 1)
        return self._classify(h1)
239 |
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Construct a ``ResNet`` and optionally load a downloaded checkpoint.

    Args:
        arch: key into ``model_urls`` used to find the checkpoint URL.
        block: residual block class passed to ``ResNet``.
        layers: number of blocks per stage, passed to ``ResNet``.
        pretrained: if True, download the state dict for ``arch`` and load it.
        progress: forwarded to ``load_state_dict_from_url``.
        **kwargs: forwarded to the ``ResNet`` constructor.

    Raises:
        KeyError: if ``pretrained`` is True and ``arch`` has no entry in
            ``model_urls``.
    """
    model = ResNet(block, layers, **kwargs)
    if not pretrained:
        return model
    checkpoint = load_state_dict_from_url(model_urls[arch], progress=progress)
    model.load_state_dict(checkpoint)
    return model
247 |
def resnet12(pretrained=False, progress=True,**kwargs):
    r"""Build a shallow ResNet from ``BasicBlock`` with stage depths ``[1, 1, 2, 1]``.

    Args:
        pretrained (bool): If True, ``_resnet`` looks up ``model_urls['resnet12']``.
            NOTE: ``model_urls`` in this module has no such entry, so passing
            True raises ``KeyError``.
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ResNet`` constructor.
    """
    stage_depths = [1, 1, 2, 1]
    return _resnet('resnet12', BasicBlock, stage_depths, pretrained, progress, **kwargs)
257 |
def resnet10(pretrained=False, progress=True,**kwargs):
    r"""Build a shallow ResNet from ``BasicBlock`` with one block per stage.

    Args:
        pretrained (bool): If True, ``_resnet`` looks up ``model_urls['resnet10']``.
            NOTE: ``model_urls`` in this module has no such entry, so passing
            True raises ``KeyError``.
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ResNet`` constructor.
    """
    stage_depths = [1, 1, 1, 1]
    return _resnet('resnet10', BasicBlock, stage_depths, pretrained, progress, **kwargs)
267 |
def resnet18(pretrained=False, progress=True, **kwargs):
    r"""Build ResNet-18 (``BasicBlock``, stage depths ``[2, 2, 2, 2]``) from
    "Deep Residual Learning for Image Recognition".

    Args:
        pretrained (bool): If True, load the checkpoint downloaded from
            ``model_urls['resnet18']``.
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ResNet`` constructor.
    """
    stage_depths = [2, 2, 2, 2]
    return _resnet('resnet18', BasicBlock, stage_depths, pretrained, progress, **kwargs)
277 |
278 |
def resnet34(pretrained=False, progress=True, **kwargs):
    r"""Build ResNet-34 (``BasicBlock``, stage depths ``[3, 4, 6, 3]``) from
    "Deep Residual Learning for Image Recognition".

    Args:
        pretrained (bool): If True, load the checkpoint downloaded from
            ``model_urls['resnet34']``.
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ResNet`` constructor.
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnet34', BasicBlock, stage_depths, pretrained, progress, **kwargs)
288 |
289 |
def resnet50(pretrained=False, progress=True, **kwargs):
    r"""Build ResNet-50 (``Bottleneck``, stage depths ``[3, 4, 6, 3]``) from
    "Deep Residual Learning for Image Recognition".

    Args:
        pretrained (bool): If True, load the checkpoint downloaded from
            ``model_urls['resnet50']``.
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ResNet`` constructor.
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnet50', Bottleneck, stage_depths, pretrained, progress, **kwargs)
299 |
def resnet101(pretrained=False, progress=True, **kwargs) -> ResNet:
    r"""Build ResNet-101 (``Bottleneck``, stage depths ``[3, 4, 23, 3]``) from
    "Deep Residual Learning for Image Recognition".

    Args:
        pretrained (bool): If True, ``_resnet`` loads the checkpoint found
            under ``model_urls['resnet101']``.
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to the ``ResNet`` constructor.
    """
    stage_depths = [3, 4, 23, 3]
    return _resnet('resnet101', Bottleneck, stage_depths, pretrained, progress, **kwargs)
309 |
--------------------------------------------------------------------------------
/main_groupclf.py:
--------------------------------------------------------------------------------
1 | """
2 | cgl_fairness
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT license
5 | """
6 | import torch
7 | from sklearn import metrics
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | import numpy as np
11 | import networks
12 | import torch.nn.functional as F
13 | import data_handler
14 | import trainer
15 | from utils import check_log_dir, make_log_name, set_seed
16 | from data_handler.dataset_factory import DatasetFactory
17 | from arguments import get_args
18 | import time
19 | import pickle
20 | import os
21 | from torch.utils.data import DataLoader
# Module-level argument namespace, created at import time by
# arguments.get_args(); main() below reads it as a global.
args = get_args()
23 | ###################################################################################################################################################################
24 | """
25 | Used only for training the group classifier
26 | """
27 | ###################################################################################################################################################################
28 |
29 |
def get_weights(loader, cuda=True):
    """Count samples per group over ``loader`` and derive relative weights.

    Each batch must unpack into five items, the third of which holds the
    group ids. The number of groups is read from ``loader.dataset.num_groups``.

    Returns:
        ``(weights, data_counts)`` where ``data_counts[g]`` is the number of
        samples with group id ``g`` and ``weights`` is ``data_counts`` divided
        by its smallest entry.
    """
    n_groups = loader.dataset.num_groups
    counts = torch.zeros(n_groups)
    if cuda:
        counts = counts.cuda()

    for batch in loader:
        _, _, groups, _, _ = batch
        for group_id in range(n_groups):
            counts[group_id] += (groups == group_id).sum()

    return counts / counts.min(), counts
42 |
43 |
def focal_loss(input_values, gamma=10):
    """Return the mean focal loss for a tensor of per-sample loss values.

    Each value ``l`` is scaled by ``(1 - exp(-l)) ** gamma`` before averaging,
    so samples with a small loss contribute less.
    """
    modulating_factor = (1 - torch.exp(-input_values)) ** gamma
    return (modulating_factor * input_values).mean()
49 |
50 |
class FocalLoss(nn.Module):
    """Focal loss on top of (optionally class-weighted) cross entropy.

    ``weight`` is kept as a plain attribute and handed to
    ``F.cross_entropy``; ``gamma`` must be non-negative.
    """

    def __init__(self, weight=None, gamma=0.5):
        super().__init__()
        assert gamma >= 0
        self.gamma = gamma
        self.weight = weight

    def forward(self, input, target):
        """Compute per-sample cross entropy and reduce it with ``focal_loss``."""
        per_sample_ce = F.cross_entropy(input, target, reduction='none', weight=self.weight)
        return focal_loss(per_sample_ce, self.gamma)
60 |
61 |
def measure_ood(ind_probs, ood_probs, tpr_thres=0.95):
    """Score how well a confidence value separates two groups of samples.

    ``ind_probs`` are treated as the positive class (label 1) and
    ``ood_probs`` as the negative class (label 2) of a ROC analysis.

    Returns:
        dict with ``auroc``, ``tnr_at_tpr95`` (1 - FPR at the first ROC point
        whose TPR reaches ``tpr_thres``), ``detection_acc`` (the best
        count-weighted accuracy over all ROC thresholds) and ``opt_thres``
        (the threshold achieving it).
    """
    n_pos, n_neg = len(ind_probs), len(ood_probs)

    positives = np.ones_like(ind_probs) * 1
    negatives = np.ones_like(ood_probs) * 2
    labels = np.append(positives, negatives)
    preds = np.append(ind_probs, ood_probs)

    fpr, tpr, thresholds = metrics.roc_curve(labels, preds, pos_label=1)
    auroc = metrics.auc(fpr, tpr)

    # First ROC point at which the true-positive rate reaches the target.
    thres_indx = np.where(tpr >= tpr_thres)[0][0]
    tnr_at_tpr95 = 1 - fpr[thres_indx]

    # Accuracy at every threshold, weighting each class by its sample count.
    acc_per_threshold = ((1 - fpr) * n_neg + tpr * n_pos) / (n_pos + n_neg)
    temp_idx = np.argmax(acc_per_threshold)
    detection_acc = np.max(acc_per_threshold)
    if n_neg != n_pos:
        print(f'warning: n_neg ({n_neg}) != n_pos ({n_pos}). It may shows weird detection acc')
    print(f'current threshold ({thresholds[temp_idx]}): FPR ({fpr[temp_idx]}), TPR ({tpr[temp_idx]})')

    return {
        'auroc': auroc,
        'tnr_at_tpr95': tnr_at_tpr95,
        'detection_acc': detection_acc,
        'opt_thres': thresholds[temp_idx]
    }
89 |
90 |
def predict_thres(probs, true_idxs, false_idxs, val_idxs):
    """Pick a confidence threshold from the validation samples.

    The validation indices are split into correctly and incorrectly predicted
    ones; the maximum class probability of each sample is then fed to
    ``measure_ood`` and its ``opt_thres`` entry is returned.
    """
    print(len(true_idxs), len(false_idxs))
    wrong_on_val = list(set(false_idxs).intersection(val_idxs))
    right_on_val = list(set(true_idxs).intersection(val_idxs))

    right_confidence = probs[right_on_val].max(dim=1)[0]
    wrong_confidence = probs[wrong_on_val].max(dim=1)[0]
    ood_stats = measure_ood(right_confidence.numpy(), wrong_confidence.numpy())
    return ood_stats['opt_thres']
100 |
101 |
def predict_group(model, args, trainloader, testloader):
    """Predict the group label of every training sample with ``model``.

    A fresh, unshuffled loader over the full training split (``sv_ratio=1``)
    is built, the model is run over it (with MC dropout when the model name
    contains 'dropout'), and accuracy is printed for the non-annotated
    samples and, for ``version == 'groupclf_val'``, for the validation samples.

    Args:
        model: group classifier; moved to ``cuda:{args.device}``.
        args: parsed arguments (reads dataset, target, add_attr, batch_size,
            num_workers, version, seed, sv, model, cuda, device).
        trainloader: only used to read ``dataset.val_idxs``.
        testloader: not used by this function.

    Returns:
        ``(results, val_idxs)`` where ``results`` holds the concatenated
        'pred', 'group', 'probs' and 'label' tensors plus the 'true_idxs' /
        'false_idxs' lists of sample indices whose prediction did / did not
        match the group.
    """
    # NOTE(review): the model is moved to the GPU unconditionally, while the
    # inputs below are only moved when args.cuda is set.
    model.cuda('cuda:{}'.format(args.device))
    target_attr = None
    target_attr = args.target
    # The tabular datasets use a fixed sensitive attribute.
    if args.dataset == 'adult':
        target_attr = 'sex'
    elif args.dataset == 'compas':
        target_attr = 'race'
    dataset = DatasetFactory.get_dataset(args.dataset, split='train', sv_ratio=1, version='',
                                         target_attr=target_attr, add_attr=args.add_attr)
    # shuffle=False / drop_last=False: every sample is visited once, in dataset order.
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.num_workers, pin_memory=True, drop_last=False)

    val_idxs = []
    if args.version == 'groupclf_val':
        val_idxs = trainloader.dataset.val_idxs

    # Locate the pickle listing which sample indices were annotated.
    test_idxs = []
    filename = f'{args.seed}_{args.sv}'
    if args.dataset == 'celeba':
        if args.target != 'Attractive':
            filename += f'_{args.target}'
        if args.add_attr is not None:
            filename += f'_{args.add_attr}'
    filename += '.pkl'
    filename = os.path.join(dataset.root, 'annotated_idxs', filename)
    # NOTE(review): pickle.load executes arbitrary code; this file must come
    # from a trusted source.
    with open(filename, 'rb') as f:
        idxs_dict = pickle.load(f)
    # The non-annotated samples serve as the test set for the group classifier.
    for key in idxs_dict['non-annotated'].keys():
        test_idxs.extend(idxs_dict['non-annotated'][key])

    model.eval()
    preds_list = []
    groups_list = []
    labels_list = []
    probs_list = []
    true_idxs = []
    false_idxs = []

    uncertainty = False
    if 'dropout' in args.model:
        uncertainty = True
        # Put the dropout layers back in train mode so repeated passes differ.
        enable_dropout(model)

    with torch.no_grad():
        for i, data in enumerate(dataloader):
            inputs, _, groups, labels, (idxs, _) = data
            if args.cuda:
                inputs = inputs.cuda()
                groups = groups.cuda()
                labels = labels.cuda()
                idxs = idxs.cuda()

            if uncertainty:
                # mc_dropout marks uncertain samples with pred == -1, so they
                # fall into false_mask below.
                preds, probs = mc_dropout(model, inputs)

            else:
                outputs = model(inputs)
                probs = torch.nn.functional.softmax(outputs, dim=1)
                preds = torch.argmax(outputs, 1)

            true_mask = groups == preds
            false_mask = groups != preds

            probs_list.append(probs.cpu())

            preds_list.append(preds.cpu())
            groups_list.append(groups.cpu())
            labels_list.append(labels.cpu())

            true_idxs.extend(list(idxs[true_mask].cpu().numpy()))
            false_idxs.extend(list(idxs[false_mask].cpu().numpy()))

    preds = torch.cat(preds_list)
    groups = torch.cat(groups_list)
    probs = torch.cat(probs_list)
    labels = torch.cat(labels_list)

    # calculate the test acc
    # NOTE(review): indexing by test_idxs assumes the stored indices are
    # positions in the unshuffled dataset - confirm. An empty test_idxs
    # raises ZeroDivisionError here.
    test_acc = (preds == groups)[test_idxs].sum().item() / len(test_idxs)
    print(f'test acc : {test_acc}')
    if args.version == 'groupclf_val':
        val_acc = (preds == groups)[val_idxs].sum().item() / len(val_idxs)
        print(f'val acc : {val_acc}')

    results = {}
    results['pred'] = preds
    results['group'] = groups
    results['probs'] = probs
    results['label'] = labels
    results['true_idxs'] = true_idxs
    results['false_idxs'] = false_idxs
    return results, val_idxs
195 |
196 |
def enable_dropout(model):
    """Switch every dropout layer of ``model`` to train mode.

    A layer counts as dropout when its class name starts with 'Dropout';
    all other modules keep their current mode.
    """
    dropout_layers = [
        layer for layer in model.modules()
        if type(layer).__name__.startswith('Dropout')
    ]
    for layer in dropout_layers:
        layer.train()
202 |
def mc_dropout(model, inputs, n_passes=10, std_threshold=0.05):
    """Predict with Monte-Carlo dropout and flag uncertain samples.

    The model is run ``n_passes`` times on ``inputs``; the softmax outputs
    are averaged, and the class with the highest mean probability is the
    prediction. A sample is considered uncertain when the standard deviation
    of its predicted class's probability across passes exceeds
    ``std_threshold``.

    Args:
        model: callable returning logits of shape (batch, classes). Its
            dropout layers must already be in train mode for the passes to
            differ.
        inputs: batch passed to ``model`` unchanged.
        n_passes: number of forward passes (default 10).
        std_threshold: uncertainty cut-off on the standard deviation
            (default 0.05).

    Returns:
        ``(preds, probs)``: ``preds`` holds the predicted class per sample,
        with -1 for uncertain samples; ``probs`` is the mean softmax output.
    """
    # One softmax per pass, stacked along a new leading dimension.
    pass_probs = torch.stack([F.softmax(model(inputs), dim=1) for _ in range(n_passes)])
    out_std = torch.std(pass_probs, dim=0)
    probs = torch.mean(pass_probs, dim=0)
    _, preds = torch.max(probs, dim=1)
    # Std of the predicted class only. squeeze(1) keeps the batch dimension
    # even when the batch holds a single sample.
    max_std = out_std.gather(1, preds.view(-1, 1)).squeeze(1)
    uncertain = max_std > std_threshold
    n_uncertain = uncertain.sum()
    print('# of uncertain samples : ', n_uncertain)
    preds[uncertain] = -1
    return preds, probs
223 |
224 |
def main():
    """Train or evaluate the group classifier according to the global ``args``.

    In 'train' mode the model is trained, saved, and confusion matrices are
    computed on the split(s) selected by ``args.evalset``. In any other mode
    the saved model is loaded, group predictions for the training split are
    written with ``torch.save``, and the function returns.

    Raises:
        ValueError: if 'groupclf' is not part of ``args.version``.
    """
    torch.backends.cudnn.enabled = True

    seed = args.seed
    set_seed(seed)

    np.set_printoptions(precision=4)
    torch.set_printoptions(precision=4)

    # img_size is handed to the model factory below; for the tabular datasets
    # it is presumably the input feature dimension - TODO confirm.
    if args.dataset == 'adult':
        args.img_size = 97
    elif args.dataset == 'compas':
        args.img_size = 400
    log_name = make_log_name(args)
    dataset = args.dataset
    save_dir = os.path.join(args.save_dir, args.date, dataset, args.method)
    log_dir = os.path.join(args.result_dir, args.date, dataset, args.method)
    check_log_dir(save_dir)
    check_log_dir(log_dir)
    # This entry point is only meant for the group-classifier versions.
    if 'groupclf' not in args.version:
        raise ValueError
    ########################## get dataloader ################################
    tmp = data_handler.DataloaderFactory.get_dataloader(args.dataset,
                                                        batch_size=args.batch_size, seed=args.seed,
                                                        num_workers=args.num_workers,
                                                        target_attr=args.target,
                                                        add_attr=args.add_attr,
                                                        labelwise=args.labelwise,
                                                        sv_ratio=args.sv,
                                                        version=args.version,
                                                        )
    num_groups, num_classes, train_loader, test_loader = tmp
    ########################## get model ##################################

    # When starting from a checkpoint, build the model with the checkpoint's
    # output size (num_classes) first, then swap the head for a group head.
    num_model_output = num_classes if args.modelpath is not None else num_groups
    model = networks.ModelFactory.get_model(args.model, num_model_output, args.img_size)

    if args.modelpath is not None:
        model.load_state_dict(torch.load(args.modelpath))
        model.fc = nn.Linear(in_features=model.fc.weight.shape[1], out_features=num_groups, bias=True)

    model.cuda('cuda:{}'.format(args.device))
    teacher = None
    criterion = None
    if args.dataset in ['compas', 'adult']:
        # For the tabular datasets, train with a class-weighted focal loss.
        weights = None
        weights, data_counts = get_weights(train_loader, cuda=args.cuda)
        cls_num_list = []
        for i in range(num_groups):
            cls_num_list.append(data_counts[i].item())
        # Per-group weight (1 - beta) / (1 - beta ** n_g), rescaled so the
        # weights sum to the number of groups.
        beta = 0.999
        effective_num = 1.0 - np.power(beta, cls_num_list)
        per_cls_weights = (1.0 - beta) / np.array(effective_num)
        per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
        # NOTE(review): moved to the GPU regardless of args.cuda.
        per_cls_weights = torch.FloatTensor(per_cls_weights).cuda()
        criterion = FocalLoss(weight=per_cls_weights).cuda()
    ########################## get trainer ##################################
    # NOTE(review): `optimizer` stays undefined (NameError below) when
    # args.optimizer is neither 'Adam' nor contains 'SGD'.
    if args.optimizer == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
    elif 'SGD' in args.optimizer:
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)

    trainer_ = trainer.TrainerFactory.get_trainer(args.method, model=model, args=args,
                                                  optimizer=optimizer, teacher=teacher)

    ####################### start training or evaluating ####################

    if args.mode == 'train':
        start_t = time.time()
        trainer_.train(train_loader, test_loader, args.epochs, criterion=criterion)

        end_t = time.time()
        train_t = int((end_t - start_t)/60)  # to minutes
        print('Training Time : {} hours {} minutes'.format(int(train_t/60), (train_t % 60)))
        trainer_.save_model(save_dir, log_name)
    else:
        print('Evaluation ----------------')
        model_name = os.path.join(save_dir, log_name+'.pt')
        model.load_state_dict(torch.load(model_name))
        results, val_idxs = predict_group(model, args, train_loader, test_loader)
        if args.version == 'groupclf_val':
            # Also store the confidence threshold chosen on the validation samples.
            opt_thres = predict_thres(results['probs'], results['true_idxs'], results['false_idxs'], val_idxs)
            results['opt_thres'] = opt_thres
        # The predictions are written relative to the current working directory.
        save_path = os.path.join('../data', args.dataset, args.date)
        if not os.path.isdir(save_path):
            os.makedirs(save_path)
        torch.save(results, os.path.join(save_path, log_name+'.pt'))
        print(os.path.join(save_path, log_name+'.pt'))
        # Evaluation mode stops here; the confusion matrices below are only
        # computed after training.
        return

    if args.evalset == 'all':
        trainer_.compute_confusion_matix('train', train_loader.dataset.num_classes, train_loader, log_dir, log_name)
        trainer_.compute_confusion_matix('test', test_loader.dataset.num_classes, test_loader, log_dir, log_name)

    elif args.evalset == 'train':
        trainer_.compute_confusion_matix('train', train_loader.dataset.num_classes, train_loader, log_dir, log_name)
    else:
        trainer_.compute_confusion_matix('test', test_loader.dataset.num_classes, test_loader, log_dir, log_name)

    print('Done!')
325 |
326 |
# Run the group-classifier pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()
329 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | cgl_fairness
2 | Copyright (c) 2022-present NAVER Corp.
3 |
4 | Permission is hereby granted, free of charge, to any person obtaining a copy
5 | of this software and associated documentation files (the "Software"), to deal
6 | in the Software without restriction, including without limitation the rights
7 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 | copies of the Software, and to permit persons to whom the Software is
9 | furnished to do so, subject to the following conditions:
10 |
11 | The above copyright notice and this permission notice shall be included in
12 | all copies or substantial portions of the Software.
13 |
14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 | THE SOFTWARE.
21 |
22 | --------------------------------------------------------------------------------------
23 |
24 | This project contains subcomponents with separate copyright notices and license terms.
25 | Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.
26 |
27 | =====
28 |
29 | Trusted-AI/AIF360
30 | https://github.com/Trusted-AI/AIF360
31 |
32 |
33 | Copyright 2018-2021 The AI Fairness 360 (AIF360) Authors
34 |
35 | Licensed under the Apache License, Version 2.0 (the "License");
36 | you may not use this file except in compliance with the License.
37 | You may obtain a copy of the License at
38 |
39 | http://www.apache.org/licenses/LICENSE-2.0
40 |
41 | Unless required by applicable law or agreed to in writing, software
42 | distributed under the License is distributed on an "AS IS" BASIS,
43 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
44 | See the License for the specific language governing permissions and
45 | limitations under the License.
46 |
47 | =====
48 |
49 | pytorch/vision
50 | https://github.com/pytorch/vision
51 |
52 |
53 | BSD 3-Clause License
54 |
55 | Copyright (c) Soumith Chintala 2016,
56 | All rights reserved.
57 |
58 | Redistribution and use in source and binary forms, with or without
59 | modification, are permitted provided that the following conditions are met:
60 |
61 | * Redistributions of source code must retain the above copyright notice, this
62 | list of conditions and the following disclaimer.
63 |
64 | * Redistributions in binary form must reproduce the above copyright notice,
65 | this list of conditions and the following disclaimer in the documentation
66 | and/or other materials provided with the distribution.
67 |
68 | * Neither the name of the copyright holder nor the names of its
69 | contributors may be used to endorse or promote products derived from
70 | this software without specific prior written permission.
71 |
72 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
73 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
74 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
75 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
76 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
77 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
78 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
79 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
80 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
81 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
82 |
83 | =====
84 |
85 | sangwon79/Fair-Feature-Distillation-for-Visual-Recognition
86 | https://github.com/sangwon79/Fair-Feature-Distillation-for-Visual-Recognition
87 |
88 |
89 |
90 | MIT License
91 |
92 | Copyright (c) 2021 Donggyu Lee
93 |
94 | Permission is hereby granted, free of charge, to any person obtaining a copy
95 | of this software and associated documentation files (the "Software"), to deal
96 | in the Software without restriction, including without limitation the rights
97 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
98 | copies of the Software, and to permit persons to whom the Software is
99 | furnished to do so, subject to the following conditions:
100 |
101 | The above copyright notice and this permission notice shall be included in all
102 | copies or substantial portions of the Software.
103 |
104 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
105 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
106 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
107 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
108 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
109 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
110 | SOFTWARE.
111 |
112 | =====
113 |
114 | yunjey/pytorch-tutorial
115 | https://github.com/yunjey/pytorch-tutorial
116 |
117 |
118 | MIT License
119 |
120 | Copyright (c) 2017
121 |
122 | Permission is hereby granted, free of charge, to any person obtaining a copy
123 | of this software and associated documentation files (the "Software"), to deal
124 | in the Software without restriction, including without limitation the rights
125 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
126 | copies of the Software, and to permit persons to whom the Software is
127 | furnished to do so, subject to the following conditions:
128 |
129 | The above copyright notice and this permission notice shall be included in all
130 | copies or substantial portions of the Software.
131 |
132 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
133 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
134 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
135 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
136 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
137 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
138 | SOFTWARE.
139 |
140 | =====
141 |
142 | nayeemrizve/ups
143 | https://github.com/nayeemrizve/ups
144 |
145 |
146 | MIT License
147 |
148 | Copyright (c) 2021 Mamshad Nayeem Rizve
149 |
150 | Permission is hereby granted, free of charge, to any person obtaining a copy
151 | of this software and associated documentation files (the "Software"), to deal
152 | in the Software without restriction, including without limitation the rights
153 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
154 | copies of the Software, and to permit persons to whom the Software is
155 | furnished to do so, subject to the following conditions:
156 |
157 | The above copyright notice and this permission notice shall be included in all
158 | copies or substantial portions of the Software.
159 |
160 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
161 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
162 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
163 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
164 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
165 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
166 | SOFTWARE.
167 |
168 | =====
169 |
170 | clovaai/rebias
171 | https://github.com/clovaai/rebias
172 |
173 |
174 | ReBias
175 | Copyright (c) 2020-present NAVER Corp.
176 |
177 | Permission is hereby granted, free of charge, to any person obtaining a copy
178 | of this software and associated documentation files (the "Software"), to deal
179 | in the Software without restriction, including without limitation the rights
180 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
181 | copies of the Software, and to permit persons to whom the Software is
182 | furnished to do so, subject to the following conditions:
183 |
184 | The above copyright notice and this permission notice shall be included in
185 | all copies or substantial portions of the Software.
186 |
187 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
188 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
189 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
190 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
191 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
192 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
193 | THE SOFTWARE.
194 |
195 | --------------------------------------------------------------------------------------
196 |
197 | This project contains subcomponents with separate copyright notices and license terms.
198 | Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.
199 |
200 | =====
201 |
202 | facebookresearch/SlowFast
203 | https://github.com/facebookresearch/SlowFast
204 |
205 |
206 | Copyright 2019, Facebook, Inc
207 |
208 | Licensed under the Apache License, Version 2.0 (the "License");
209 | you may not use this file except in compliance with the License.
210 | You may obtain a copy of the License at
211 |
212 | http://www.apache.org/licenses/LICENSE-2.0
213 |
214 | Unless required by applicable law or agreed to in writing, software
215 | distributed under the License is distributed on an "AS IS" BASIS,
216 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
217 | See the License for the specific language governing permissions and
218 | limitations under the License.
219 |
220 | =====
221 |
222 | pytorch/vision
223 | https://github.com/pytorch/vision
224 |
225 |
226 | BSD 3-Clause License
227 |
228 | Copyright (c) Soumith Chintala 2016,
229 | All rights reserved.
230 |
231 | Redistribution and use in source and binary forms, with or without
232 | modification, are permitted provided that the following conditions are met:
233 |
234 | * Redistributions of source code must retain the above copyright notice, this
235 | list of conditions and the following disclaimer.
236 |
237 | * Redistributions in binary form must reproduce the above copyright notice,
238 | this list of conditions and the following disclaimer in the documentation
239 | and/or other materials provided with the distribution.
240 |
241 | * Neither the name of the copyright holder nor the names of its
242 | contributors may be used to endorse or promote products derived from
243 | this software without specific prior written permission.
244 |
245 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
246 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
247 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
248 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
249 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
250 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
251 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
252 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
253 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
254 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
255 |
256 | =====
257 |
258 | wielandbrendel/bag-of-local-features-models
259 | https://github.com/wielandbrendel/bag-of-local-features-models
260 |
261 |
262 | Copyright (c) 2019 Wieland Brendel
263 |
264 | Permission is hereby granted, free of charge, to any person obtaining a copy
265 | of this software and associated documentation files (the "Software"), to deal
266 | in the Software without restriction, including without limitation the rights
267 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
268 | copies of the Software, and to permit persons to whom the Software is
269 | furnished to do so, subject to the following conditions:
270 |
271 | The above copyright notice and this permission notice shall be included in all
272 | copies or substantial portions of the Software.
273 |
274 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
275 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
276 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
277 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
278 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
279 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
280 | SOFTWARE.
281 |
282 | =====
283 |
284 | cdancette/rubi.bootstrap.pytorch
285 | https://github.com/cdancette/rubi.bootstrap.pytorch
286 |
287 |
288 | BSD 3-Clause License
289 |
290 | Copyright (c) 2019+, Remi Cadene, Corentin Dancette
291 | All rights reserved.
292 |
293 | Redistribution and use in source and binary forms, with or without
294 | modification, are permitted provided that the following conditions are met:
295 |
296 | * Redistributions of source code must retain the above copyright notice, this
297 | list of conditions and the following disclaimer.
298 |
299 | * Redistributions in binary form must reproduce the above copyright notice,
300 | this list of conditions and the following disclaimer in the documentation
301 | and/or other materials provided with the distribution.
302 |
303 | * Neither the name of the copyright holder nor the names of its
304 | contributors may be used to endorse or promote products derived from
305 | this software without specific prior written permission.
306 |
307 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
308 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
309 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
310 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
311 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
312 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
313 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
314 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
315 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
316 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
317 |
318 | =====
319 |
320 | chrisc36/debias
321 | https://github.com/chrisc36/debias
322 |
323 |
324 | "Don’t Take the Easy Way Out: Ensemble Based Methods for Avoiding Known Dataset Biases". Christopher Clark, Mark Yatskar, Luke Zettlemoyer. In EMNLP 2019.
325 |
326 | ---
327 |
328 | Licensed under the Apache License, Version 2.0 (the "License");
329 | you may not use this file except in compliance with the License.
330 | You may obtain a copy of the License at
331 |
332 | http://www.apache.org/licenses/LICENSE-2.0
333 |
334 | Unless required by applicable law or agreed to in writing, software
335 | distributed under the License is distributed on an "AS IS" BASIS,
336 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
337 | See the License for the specific language governing permissions and
338 | limitations under the License.
339 |
340 | =====
341 |
342 | =====
343 |
--------------------------------------------------------------------------------