├── README.md
├── bpa_trainer.py
├── dataset
└── celebA.py
├── factory.py
├── main.py
├── models
└── classification.py
├── modules
├── centroids.py
├── loss.py
└── transform.py
├── requirements.txt
├── trainer.py
└── utils
├── io_utils.py
└── train_utils.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Unsupervised Learning of Debiased Representations with Pseudo-Attributes
3 | PyTorch implementation of BPA (CVPR 2022)
4 |
5 | [Seonguk Seo](https://seoseong.uk/), [Joon-Young Lee](https://joonyoung-cv.github.io/), [Bohyung Han](https://cv.snu.ac.kr/~bhhan/)
6 |
7 | Seoul National University, Adobe Research
8 | ### [[Paper](https://arxiv.org/abs/2108.02943)]
9 |
10 | >
11 | > Dataset bias is a critical challenge in machine learning since it often leads to a negative impact on a model due to the unintended decision rules captured by spurious correlations. Although existing works often handle this issue based on human supervision, the availability of the proper annotations is impractical and even unrealistic. To better tackle the limitation, we propose a simple but effective unsupervised debiasing technique. Specifically, we first identify pseudo-attributes based on the results from clustering performed in the feature embedding space even without an explicit bias attribute supervision. Then, we employ a novel cluster-wise reweighting scheme to learn debiased representation; the proposed method prevents minority groups from being discounted for minimizing the overall loss, which is desirable for worst-case generalization. The extensive experiments demonstrate the outstanding performance of our approach on multiple standard benchmarks, even achieving the competitive accuracy to the supervised counterpart.
12 |
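The method in brief: features from a pretrained base model are clustered within each class to form pseudo-attribute groups, and the classification loss is then reweighted per sample according to its group, so that small (minority) groups are not discounted when minimizing the average loss. Below is a minimal sketch of the reweighting step, for illustration only; the actual logic lives in `bpa_trainer.py` and `modules/centroids.py`.

```
# Illustrative sketch of inverse-size cluster reweighting (not the training code).
import torch
import torch.nn.functional as F

def cluster_weights(assigns: torch.Tensor) -> torch.Tensor:
    """Per-sample weights that grow as the assigned pseudo-attribute cluster shrinks."""
    counts = assigns.bincount().float().clamp(min=1)  # samples per cluster
    w = counts.sum() / counts                         # inverse-frequency weight per cluster
    w = w[assigns]                                     # broadcast cluster weights to samples
    return w / w.mean()                                # mean-normalize to keep the loss scale

# weighted_loss = (F.cross_entropy(logits, targets, reduction='none') * cluster_weights(assigns)).mean()
```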
13 |
14 | ---
15 |
16 | ## Installation
17 | ```
18 | git clone https://github.com/skynbe/pseudo-attributes.git
19 | cd pseudo-attributes
20 | pip install -r requirements.txt
21 | ```
22 | Download the CelebA dataset to $ROOT_PATH/data/celebA (the code loads it from ./data/celebA relative to the repository root; torchvision will attempt to download it automatically if it is missing).
23 |
24 |
25 | ### Quick Start
26 |
27 | Train the baseline model:
28 | ```
29 | python main.py --arch ResNet18 --trainer classify --desc base --dataset celebA --test_epoch 1 --lr 1e-4 --target_attr Blond_Hair --bias_attrs Male --no_save
30 | ```
31 |
32 | Train the BPA model (requires a trained baseline checkpoint, passed by name via `--use_base`):
33 | ```
34 | python main.py --arch ResNet18 --trainer bpa --desc bpa_k8 --dataset celebA --test_epoch 1 --lr 1e-4 --target_attr Blond_Hair --bias_attrs Male --k 8 --ks 8 --no_save --use_base {$BASE_PATH}
35 | ```
36 |
37 |
38 | ## Citation
39 |
40 | If you find our work useful in your research, please cite:
41 |
42 | ```
43 | @inproceedings{seo2022unsupervised,
44 | title={Unsupervised Learning of Debiased Representations with Pseudo-Attributes},
45 | author={Seo, Seonguk and Lee, Joon-Young and Han, Bohyung},
46 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
47 | pages={16742--16751},
48 | year={2022}
49 | }
50 | ```
51 |
--------------------------------------------------------------------------------
/bpa_trainer.py:
--------------------------------------------------------------------------------
1 |
2 | import copy
3 | from trainer import *
4 | from kmeans_pytorch import kmeans
5 | from modules.centroids import AvgFixedCentroids
6 | from torch.autograd import Variable
7 |
8 |
9 | class OnlineTrainer(BiasedClassifyTrainer):
10 |
11 | def __init__(self, args, model, loaders, optimizer, num_classes, per_clusters=0):
12 | super().__init__(args, model, loaders, optimizer, num_classes)
13 | if per_clusters == 0:
14 | per_clusters = args.k
15 |
16 | self.centroids = AvgFixedCentroids(args, num_classes, per_clusters=per_clusters)
17 | self.update_cluster_iter = args.update_cluster_iter
18 | self.checkpoint_dir = args.checkpoint_dir
19 |
20 | def save_model(self, epoch):
21 | if not self.args.no_save or epoch % self.args.save_epoch == 0:
22 | torch.save({
23 | 'epoch': epoch,
24 | 'state_dict': self.model.state_dict(),
25 | 'optimizer' : self.optimizer.state_dict(),
26 | }, self.checkpoint_dir / 'e{:04d}.pth'.format(epoch))
27 | return
28 |
29 | def load_model(self, epoch=0):
30 | self.logger.info('Resume training')
31 | if epoch==0:
32 | checkpoint_path = max((f.stat().st_mtime, f) for f in self.checkpoint_dir.glob('*.pth'))[1]
33 | self.logger.info('Resume Latest from {}'.format(checkpoint_path))
34 | else:
35 | self.logger.info('Resume from {}'.format(epoch))
36 | checkpoint_path = self.checkpoint_dir / 'e{:04d}.pth'.format(epoch)
37 |
38 | checkpoint = torch.load(checkpoint_path)
39 | self.model.load_state_dict(checkpoint['state_dict']) # Set CUDA before if error occurs.
40 | self.optimizer.load_state_dict(checkpoint['optimizer'])
41 | self.epoch = checkpoint['epoch']
42 |
43 |
44 |
45 | class BPATrainer(OnlineTrainer):
46 |
47 | def __init__(self, args, model, loaders, optimizer, num_classes):
48 | super().__init__(args, model, loaders, optimizer, num_classes)
49 | self.class_weights = None
50 | self.base_model = copy.deepcopy(self.model)
51 | if not args.use_base:
52 |             raise ValueError('BPATrainer requires a pretrained base model; pass it via --use_base')
53 |
54 | def save_model(self, epoch):
55 | if not self.args.no_save or epoch % self.args.save_epoch == 0:
56 | torch.save({
57 | 'epoch': epoch,
58 | 'state_dict': self.model.state_dict(),
59 | 'optimizer' : self.optimizer.state_dict(),
60 | }, self.checkpoint_dir / 'e{:04d}.pth'.format(epoch))
61 | return
62 |
63 |
64 | def use_base_model(self, file_name):
65 | self.logger.info('Loading ({}) base model'.format(file_name))
66 | checkpoint_path = self.checkpoint_dir / '..' / '{}.pth'.format(file_name)
67 | checkpoint = torch.load(checkpoint_path)
68 | self.base_model.load_state_dict(checkpoint['state_dict']) # Set CUDA before if error occurs.
69 |
70 |
71 | def load_model(self, epoch=0):
72 | self.logger.info('Resume training')
73 | if epoch==0:
74 | checkpoint_path = max((f.stat().st_mtime, f) for f in self.checkpoint_dir.glob('*.pth'))[1]
75 | self.logger.info('Resume Latest from {}'.format(checkpoint_path))
76 | else:
77 | self.logger.info('Resume from {}'.format(epoch))
78 | checkpoint_path = self.checkpoint_dir / 'e{:04d}.pth'.format(epoch)
79 |
80 | checkpoint = torch.load(checkpoint_path)
81 | self.model.load_state_dict(checkpoint['state_dict']) # Set CUDA before if error occurs.
82 | self.optimizer.load_state_dict(checkpoint['optimizer'])
83 | self.epoch = checkpoint['epoch']
84 |
85 |
86 | def _extract_features(self, model, data_loader):
87 | features, targets = [], []
88 | ids = []
89 |
90 | for data, target, index in tqdm(data_loader, desc='Feature extraction for clustering..', ncols=5):
91 | data, target, index = data.cuda(), target.cuda(), index.cuda()
92 | results = model(data)
93 | features.append(results["feature"])
94 | targets.append(target)
95 | ids.append(index)
96 |
97 | features = torch.cat(features)
98 | targets = torch.cat(targets)
99 | ids = torch.cat(ids)
100 | return features, targets, ids
101 |
102 |
103 | def _cluster_features(self, data_loader, features, targets, ids, num_clusters):
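        """Class-wise k-means over the extracted features.

        Samples of each class t are clustered into `num_clusters` groups using cosine
        distance, and the resulting cluster ids are offset by t * num_clusters so that
        cluster labels are unique across classes. Returns per-sample cluster assignments
        (indexed by dataset position) and the concatenated cluster centers.
        """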
104 |
105 | N = len(data_loader.dataset)
106 | num_classes = data_loader.dataset.num_classes
107 | sorted_target_clusters = torch.zeros(N).long().cuda() + num_clusters*num_classes
108 |
109 | target_clusters = torch.zeros_like(targets)-1
110 | cluster_centers = []
111 |
112 | for t in range(num_classes):
113 | target_assigns = (targets==t).nonzero().squeeze()
114 |             feature_assigns = features[target_assigns]
115 | 
116 |             cluster_ids, cluster_center = kmeans(X=feature_assigns, num_clusters=num_clusters, distance='cosine', tqdm_flag=False, device=0)
117 | cluster_ids_ = cluster_ids + t*num_clusters
118 |
119 | target_clusters[target_assigns] = cluster_ids_.cuda()
120 | cluster_centers.append(cluster_center)
121 |
122 | sorted_target_clusters[ids] = target_clusters
123 | cluster_centers = torch.cat(cluster_centers, 0)
124 | return sorted_target_clusters, cluster_centers
125 |
126 |
127 |     def initial_clustering(self):
128 | data_loader = self.loaders['train_eval']
129 | data_loader.dataset.clustering_on()
130 | self.base_model.eval()
131 |
132 | with torch.no_grad():
133 |
134 | features, targets, ids = self._extract_features(self.base_model, data_loader)
135 | num_clusters = self.args.num_clusters
136 | cluster_assigns, cluster_centers = self._cluster_features(data_loader, features, targets, ids, num_clusters)
137 |
138 | cluster_counts = cluster_assigns.bincount().float()
139 | print("Cluster counts : {}, len({})".format(cluster_counts, len(cluster_counts)))
140 |
141 |
142 | data_loader.dataset.clustering_off()
143 | return cluster_assigns, cluster_centers
144 |
145 |
146 | def train(self, epoch):
147 |
148 | cluster_weights = None
149 | # if epoch > 1 and not self.centroids.initialized:
150 | if not self.centroids.initialized:
151 |             cluster_assigns, cluster_centers = self.initial_clustering()
152 | self.centroids.initialize_(cluster_assigns, cluster_centers)
153 |
154 | data_loader = self.loaders['train']
155 | total_loss, total_num, train_bar = 0.0, 0, tqdm(data_loader)
156 | criterion = torch.nn.CrossEntropyLoss(reduction='none')
157 | total_metric_loss = 0.0
158 |
159 | i = 0
160 | for data, target, _, _, cluster, weight_pre, ids in train_bar:
161 | i += 1
162 | B = target.size(0)
163 |
164 | data, target = data.cuda(), target.cuda()
165 |
166 | results = self.model(data)
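            # Cluster-wise reweighting: each sample's weight comes from its pseudo-attribute
            # (cluster) assignment, up-weighting samples from small or high-loss clusters so
            # minority groups are not discounted when minimizing the average loss.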
167 | weight = self.centroids.get_cluster_weights(ids)
168 | loss = torch.mean(criterion(results["out"], target.long()) * (weight))
169 |
170 | self.optimizer.zero_grad()
171 | (loss).backward()
172 | self.optimizer.step()
173 |
174 | total_num += B
175 | total_loss += loss.item() * B
176 |
177 | train_bar.set_description('Train Epoch: [{}/{}] Loss: {:.4f}'.format(epoch, self.max_epoch, total_loss / total_num))
178 |
179 | if self.centroids.initialized:
180 | self.centroids.update(results, target, ids)
181 | if self.args.update_cluster_iter > 0 and i % self.args.update_cluster_iter == 0:
182 | self.centroids.compute_centroids()
183 |
184 |
185 | if self.centroids.initialized:
186 | self.centroids.compute_centroids(verbose=True)
187 |
188 |
189 | return total_loss / total_num
190 |
191 |
192 | def test_unbiased(self, epoch, train_eval=True):
193 | self.model.eval()
194 |
195 | test_envs = ['valid', 'test']
196 |
197 | for desc in test_envs:
198 | loader = self.loaders[desc]
199 |
200 | total_top1, total_top5, total_num, test_bar = 0.0, 0.0, 0, tqdm(loader)
201 |
202 | num_classes = len(loader.dataset.classes)
203 | num_groups = loader.dataset.num_groups
204 |
205 | bias_counts = torch.zeros(num_groups).cuda()
206 | bias_top1s = torch.zeros(num_groups).cuda()
207 |
208 | with torch.no_grad():
209 |
210 | features, labels = [], []
211 | corrects = []
212 |
213 | for data, target, biases, group, _, _, ids in test_bar:
214 | data, target, biases, group = data.cuda(), target.cuda(), biases.cuda(), group.cuda()
215 |
216 | B = target.size(0)
217 | num_groups = np.power(num_classes, biases.size(1)+1)
218 |
219 | results = self.model(data)
220 | pred_labels = results["out"].argsort(dim=-1, descending=True)
221 | features.append(results["feature"])
222 | labels.append(group)
223 |
224 |
225 | top1s = (pred_labels[:, :1] == target.unsqueeze(dim=-1)).squeeze().unsqueeze(0)
226 | group_indices = (group==torch.arange(num_groups).unsqueeze(1).long().cuda())
227 |
228 | bias_counts += group_indices.sum(1)
229 | bias_top1s += (top1s * group_indices).sum(1)
230 |
231 | corrects.append(top1s)
232 |
233 | total_num += B
234 | total_top1 += torch.sum((pred_labels[:, :1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
235 | total_top5 += torch.sum((pred_labels[:, :5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
236 | acc1, acc5 = total_top1 / total_num * 100, total_top5 / total_num * 100
237 |
238 | bias_accs = bias_top1s / bias_counts * 100
239 |
240 | avg_acc = np.nanmean(bias_accs.cpu().numpy())
241 | worst_acc = np.nanmin(bias_accs.cpu().numpy())
242 |
243 | acc_desc = '/'.join(['{:.1f}%'.format(acc) for acc in bias_accs])
244 |
245 | test_bar.set_description('Eval Epoch [{}/{}] [{}] Bias: {:.2f}%'.format(epoch, self.max_epoch, desc, avg_acc))
246 |
247 |
248 | log = self.logger.info if desc in ['train', 'train_eval'] else self.logger.warning
249 | log('Eval Epoch [{}/{}] [{}] Unbiased: {:.2f}% [{}]'.format(epoch, self.max_epoch, desc, avg_acc, acc_desc))
250 | self.logger.info('Total [{}]: Acc@1:{:.2f}% Acc@5:{:.2f}%'.format(desc, acc1, acc5))
251 | print(" {} / {} / {}".format(self.args.desc, self.args.target_attr, self.args.bias_attrs))
252 |
253 |
254 | self.model.train()
255 | return
256 |
--------------------------------------------------------------------------------
/dataset/celebA.py:
--------------------------------------------------------------------------------
1 | import os, pdb
2 | import torch
3 | import pandas as pd
4 | from PIL import Image
5 | import numpy as np
6 | import torchvision.transforms as transforms
7 | # from models import model_attributes
8 | from torch.utils.data import Dataset, Subset
9 | from pathlib import Path
10 | import torchvision
11 | from torch.utils import data
12 | from torch.utils.data.sampler import WeightedRandomSampler
13 |
14 |
15 |
16 | class CelebA(torchvision.datasets.CelebA):
17 |
18 | # Attributes : '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin', 'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard', 'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young'
19 |
20 | def __init__(self, root, split="train", target_type="attr", transform=None,
21 | target_transform=None, download=False,
22 | target_attr='', aux_attrs=[], bias_attrs=[], domain_attr=None, domain_type=None,
23 | pair=False, args=None, scale=1.0):
24 |
25 | super().__init__(root, split=split, target_type=target_type, transform=transform,
26 | target_transform=target_transform, download=download)
27 |
28 | self.target_attr = target_attr
29 | self.aux_attrs = aux_attrs
30 | self.bias_attrs = bias_attrs
31 | self.domain_attr = domain_attr
32 |
33 | self.target_idx = self.attr_names.index(target_attr)
34 | self.aux_indices = [self.attr_names.index(aux_att) for aux_att in aux_attrs] if aux_attrs else []
35 | self.domain_idx = self.attr_names.index(domain_attr) if domain_attr else None
36 |
37 | self.domain_type = domain_type
38 |
39 | self.bias_indices = [self.attr_names.index(bias_att) for bias_att in bias_attrs]
40 |
41 | self.cluster_ids = None
42 | self.clustering = False
43 | self.sample_weights = None
44 |
45 | self.pair = pair
46 |
47 | self.visualize_image = False
48 |
49 | self.args = args
50 | self.visualize = False
51 | self.scale = scale
52 |
53 |
54 |
55 | @property
56 | def class_elements(self):
57 | return self.attr[:, self.target_idx]
58 |
59 | @property
60 | def group_elements(self):
61 | group_attrs = self.attr[:, [self.target_idx]+self.bias_indices]
62 | weight = np.power(self.num_classes, np.arange(group_attrs.size(1)))
63 | group_elems = (group_attrs*weight).sum(1)
64 | return group_elems
65 |
66 | @property
67 | def group_counts(self):
68 | group_attrs = self.attr[:, [self.target_idx]+self.bias_indices]
69 | weight = np.power(self.num_classes, np.arange(group_attrs.size(1)))
70 | group_elems = (group_attrs*weight).sum(1)
71 | return group_elems.bincount()
72 |
73 |
74 | def group_counts_with_attr(self, attr):
75 | target_idx = self.attr_names.index(attr)
76 | group_attrs = self.attr[:, [target_idx]+self.bias_indices]
77 | weight = np.power(self.num_classes, np.arange(group_attrs.size(1)))
78 | group_elems = (group_attrs*weight).sum(1)
79 | return group_elems.bincount()
80 |
81 | def visualize(self):
82 | self.visualize_image = True
83 |
84 | def clustering_on(self):
85 | self.clustering = True
86 |
87 | def clustering_off(self):
88 | self.clustering = False
89 |
90 | def update_clusters(self, cluster_ids):
91 | self.cluster_ids = cluster_ids
92 |
93 | def update_weights(self, weights):
94 | self.sample_weights = weights
95 |
96 |
97 | def __len__(self):
98 |         length = super().__len__()
99 |         if self.scale < 1.0:
100 |             length = int(length*self.scale)
101 |         return length
102 |
103 |
104 | def get_sample_index(self, index):
105 | return index
106 |
107 | def __getitem__(self, index_):
108 |
109 | index = self.get_sample_index(index_)
110 |
111 | img_path = os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index])
112 | img_ = Image.open(img_path)
113 |
114 | target = []
115 | for t in self.target_type:
116 | if t == "attr":
117 | target.append(self.attr[index, :])
118 | elif t == "identity":
119 | target.append(self.identity[index, 0])
120 | elif t == "bbox":
121 | target.append(self.bbox[index, :])
122 | elif t == "landmarks":
123 | target.append(self.landmarks_align[index, :])
124 | else:
125 | # TODO: refactor with utils.verify_str_arg
126 | raise ValueError("Target type \"{}\" is not recognized.".format(t))
127 |
128 | if target:
129 | target = tuple(target) if len(target) > 1 else target[0]
130 |
131 | if self.target_transform is not None:
132 | target = self.target_transform(target)
133 | else:
134 | target = None
135 |
136 | target_attr = target[self.target_idx]
137 | bias_attrs = np.array([target[bias_idx] for bias_idx in self.bias_indices])
138 | group_attrs = np.insert(bias_attrs, 0, target_attr) # target first
139 |
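        # Encode (target, bias_1, ..., bias_M) as a single group index in base num_classes;
        # e.g. with one binary bias attribute, group = target + 2 * bias.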
140 | bit = np.power(self.num_classes, np.arange(len(group_attrs)))
141 | group = np.sum(bit * group_attrs)
142 |
143 |
144 | if self.cluster_ids is not None:
145 | cluster = self.cluster_ids[index]
146 | else:
147 | cluster = -1
148 |
149 | if self.sample_weights is not None:
150 | weight = self.sample_weights[index]
151 | else:
152 | weight = 1
153 |
154 |
155 | if self.transform is not None:
156 | transform = self.transform
157 | img = transform(img_)
158 |
159 | # for clustering
160 | if self.clustering is True:
161 | return img, target_attr, index
162 |
163 |
164 | return img, target_attr, bias_attrs, group, cluster, weight, index
165 |
166 |
167 | @property
168 | def classes(self):
169 | return ['0', '1']
170 |
171 | @property
172 | def num_classes(self):
173 | return len(self.classes)
174 |
175 | @property
176 | def num_groups(self):
177 | return np.power(len(self.classes), len(self.bias_attrs)+1)
178 |
179 | @property
180 | def bias_attributes(self):
181 | return
182 |
183 | @property
184 | def attribute_names(self):
185 | return ['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin', 'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard', 'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young']
186 |
187 |
188 |
189 |
190 |
191 | def get_celebA_dataloader(root, batch_size, split, target_attr, bias_attrs, aux_attrs=None, num_workers=4, pair=False, cluster_ids=None, args=None):
192 |
193 | from factory import TransformFactory
194 |
195 | ### Transform and scale
196 | if split in ['train', 'train_target']:
197 | celebA_transform = TransformFactory.create("celebA_train")
198 |
199 | elif split in ['valid', 'test', 'train_eval']:
200 | celebA_transform = TransformFactory.create("celebA_test")
201 |
202 | ### Dataset split
203 | celebDataset = CelebA
204 | if split in ['train', 'train_eval']:
205 | dataset_split = 'train'
206 | elif split in ['valid']:
207 | dataset_split = 'valid'
208 | elif split in ['test']:
209 | dataset_split = 'test'
210 |
211 |
212 | dataset = celebDataset(root, split=dataset_split, transform=celebA_transform, download=True,
213 | target_attr=target_attr, bias_attrs=bias_attrs, aux_attrs=aux_attrs, args=args)
214 |
215 |
216 | dataloader = data.DataLoader(dataset=dataset,
217 | batch_size=batch_size,
218 | shuffle=True,
219 | num_workers=num_workers,
220 | pin_memory=True)
221 |
222 | return dataloader, dataset
223 |
224 |
--------------------------------------------------------------------------------
/factory.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Dict, Iterable, List, Optional
2 | from pathlib import Path
3 | from dataset.celebA import CelebA, get_celebA_dataloader
4 |
5 | import torchvision
6 | from modules.transform import *
7 | from models.classification import ResNet18
8 |
9 | from trainer import BiasedClassifyTrainer
10 | from bpa_trainer import BPATrainer
11 | from torch.utils.data import DataLoader
12 | import pdb
13 |
14 |
15 |
16 | class Factory(object):
17 |
18 | PRODUCTS: Dict[str, Callable] = {}
19 |
20 | def __init__(self):
21 | raise ValueError(
22 | f"""Cannot instantiate {self.__class__.__name__} object, use
23 | `create` classmethod to create a product from this factory.
24 | """
25 | )
26 |
27 | @classmethod
28 | def create(cls, name: str, *args, **kwargs) -> Any:
29 | r"""Create an object by its name, args and kwargs."""
30 | if name not in cls.PRODUCTS:
31 |             raise KeyError(f"{cls.__name__} cannot create {name}.")
32 |
33 | return cls.PRODUCTS[name](*args, **kwargs)
34 |
35 |
36 |
37 | class ModelFactory(Factory):
38 |
39 |
40 | MODELS: Dict[str, Callable] = {
41 | "ResNet18": ResNet18,
42 | }
43 |
44 | @classmethod
45 | def create(cls, name: str, *args, **kwargs) -> Any:
46 |
47 | return cls.MODELS[name](*args, **kwargs)
48 |
49 |
50 |
51 |
52 | class TransformFactory(Factory):
53 | PRODUCTS: Dict[str, Callable] = {
54 |
55 | "train": train_transform,
56 | "test": test_transform,
57 |
58 | "celebA_train": celebA_train_transform,
59 | "celebA_test": celebA_test_transform,
60 |
61 | }
62 |
63 | @classmethod
64 | def create(cls, name: str, *args, **kwargs) -> Any:
65 | r"""Create an object by its name, args and kwargs."""
66 | if name not in cls.PRODUCTS:
67 |             raise KeyError(f"{cls.__name__} cannot create {name}.")
68 |
69 | return cls.PRODUCTS[name]
70 |
71 |
72 |
73 |
74 |
75 | class DataLoaderFactory(Factory):
76 |
77 | @classmethod
78 | def create(cls, name: str, trainer: str, batch_size: int, num_workers: int, configs: Any, cluster_ids: Any = None) -> Any:
79 |
80 | if name == 'celebA':
81 |
82 | train_loader, train_set = get_celebA_dataloader(
83 | root=Path('./data/celebA'), batch_size=batch_size, split='train',
84 | target_attr=configs.target_attr, bias_attrs=configs.bias_attrs,
85 | cluster_ids=cluster_ids, args=configs)
86 | valid_loader, valid_set = get_celebA_dataloader(
87 | root=Path('./data/celebA'), batch_size=batch_size, split='valid',
88 | target_attr=configs.target_attr, bias_attrs=configs.bias_attrs,
89 | args=configs)
90 | test_loader, test_set = get_celebA_dataloader(
91 | root=Path('./data/celebA'), batch_size=batch_size, split='test',
92 | target_attr=configs.target_attr, bias_attrs=configs.bias_attrs,
93 | args=configs)
94 | train_eval_loader, train_eval_set = get_celebA_dataloader(
95 | root=Path('./data/celebA'), batch_size=batch_size, split='train_eval',
96 | target_attr=configs.target_attr, bias_attrs=configs.bias_attrs,
97 | args=configs)
98 |
99 | datasets = {
100 | 'train': train_set,
101 | 'valid': valid_set,
102 | 'test': test_set,
103 | 'train_eval': train_eval_set,
104 | }
105 |
106 | data_loaders = {
107 | 'train': train_loader,
108 | 'valid': valid_loader,
109 | 'test': test_loader,
110 | 'train_eval': train_eval_loader,
111 | }
112 |
113 |
114 | else:
115 | raise ValueError
116 |
117 |
118 | return data_loaders, datasets
119 |
120 |
121 |
122 | class TrainerFactory(Factory):
123 |
124 | TRAINERS: Dict[str, Callable] = {
125 | "classify": BiasedClassifyTrainer,
126 | "bpa": BPATrainer,
127 | }
128 |
129 | @classmethod
130 | def create(cls, name: str, *args, **kwargs) -> Any:
131 |
132 | return cls.TRAINERS[name](*args, **kwargs)
133 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, copy
3 |
4 | import numpy as np
5 | import torch
6 | import torch.optim as optim
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.utils.data import DataLoader
10 | from tqdm import tqdm
11 |
12 | # import utils
13 | from modules.transform import *
14 |
15 | import os, sys, logging, time, random, json, pdb
16 | from pathlib import Path
17 |
18 |
19 | from factory import ModelFactory, TrainerFactory, DataLoaderFactory, TransformFactory
20 |
21 |
22 | import wandb
23 |
24 | DATA_ROOT = Path('./data')
25 | CHECKPOINT_ROOT = Path('./checkpoint')
26 |
27 |
28 | def main():
29 |     parser = argparse.ArgumentParser(description='Train BPA')
30 | parser.add_argument('--feature_dim', default=512, type=int, help='Feature dim for latent vector')
31 | parser.add_argument('--lr', default=1e-3, type=float, help='Learning Rate')
32 | parser.add_argument('--batch_size', default=256, type=int, help='Number of images in each mini-batch')
33 | parser.add_argument('--max_epoch', default=50, type=int, help='Number of sweeps over the dataset to train')
34 | parser.add_argument('--test_epoch', default=25, type=int, help='Test epoch')
35 | parser.add_argument('--save_epoch', default=25, type=int, help='Save epoch')
36 |     parser.add_argument('--train_eval_epoch', default=50, type=int, help='Train-set evaluation epoch interval')
37 |
38 | parser.add_argument('--dataset', default='celebA', type=str, help='Dataset')
39 | parser.add_argument('--arch', default='ResNet18', type=str, help='Model architecture')
40 | parser.add_argument('--trainer', default='classify', type=str, help='Training scheme')
41 |     parser.add_argument('--cluster_weight_type', default='scale_loss', type=str, help='Cluster weighting scheme')
42 | parser.add_argument('--centroid', default='avgfixed', type=str, help='')
43 |
44 | parser.add_argument('--target_attr', default='', type=str, help='Target attributes')
45 | parser.add_argument('--bias_attrs', nargs='+', help='Bias attributes')
46 |
47 |     parser.add_argument('--num_partitions', default=1, type=int, help='Number of partitions')
48 | parser.add_argument('--k', default=1, type=int, help='# of clusters')
49 | parser.add_argument('--ks', default=[], nargs='+', help='# of clusters list (multi)')
50 | parser.add_argument('--update_cluster_iter', default=10, type=int, help='0 for every epoch')
51 | parser.add_argument('--feature_bank_init', action='store_true')
52 | parser.add_argument('--num_multi_centroids', default=1, type=int, help='# of centroids')
53 |
54 | parser.add_argument('--desc', default='test', type=str, help='Checkpoint folder name')
55 | parser.add_argument('--load_epoch', default=-1, type=int, help='Load model epoch')
56 | parser.add_argument('--weight_decay', default=1e-2, type=float, help='Weight decay')
57 |     parser.add_argument('--momentum', default=0.3, type=float, help='Momentum for feature-bank and weight moving averages')
58 |     parser.add_argument('--adj', default=2.0, type=float, help='Adjustment scale (used with --adj_type)')
59 | parser.add_argument('--adj_type', default='', type=str, help='multiply or default')
60 | parser.add_argument('--exp_step', default=0.05, type=float, help='Exponential step size for weight averaging in AvgFixedCentroids')
61 | parser.add_argument('--avg_weight_type', default='expavg', type=str, help='avg type for weight averaging in AvgFixedCentroids')
62 | parser.add_argument('--overlap_type', default='exclusive', type=str, help='Channel overlap type for hetero clustering, [exclusive, half_exclusive]')
63 | parser.add_argument('--gamma_reverse', action='store_true')
64 | parser.add_argument('--scale', default=1.0, type=float, help='Dataset scale')
65 | parser.add_argument('--sampling', default='', type=str, help='class_subsampling/class_resampling')
66 |
67 | parser.add_argument('--use_base', default='', type=str, help='use base model')
68 | parser.add_argument('--load_base', default='', type=str, help='load checkpoint file as base model')
69 | parser.add_argument('--load_path', default='', type=str, help='load checkpoint file')
70 |
71 | parser.add_argument('--scheduler', default='cosine', type=str, help='cosine')
72 | parser.add_argument('--scheduler_param', default=100, type=int)
73 |
74 | parser.add_argument('--resume', default='', type=str, help='Run ID')
75 | parser.add_argument('--optim', default='adam', type=str, help='adam or sgd')
76 | parser.add_argument('--no_save', action='store_true')
77 | parser.add_argument('--verbose', action='store_true')
78 | parser.add_argument('--feature_fix', action='store_true')
79 |
80 | parser.add_argument('--eval', action='store_true')
81 |
82 | args = parser.parse_args()
83 | args.num_clusters = args.k
84 | args.num_multi_centroids = len(args.ks)
85 |
86 | # savings
87 | checkpoint_dir = CHECKPOINT_ROOT / args.dataset / args.target_attr / args.desc
88 | if not checkpoint_dir.exists():
89 | checkpoint_dir.mkdir(parents=True, exist_ok=True)
90 |
91 | loaders, datasets = DataLoaderFactory.create(args.dataset, trainer=args.trainer, batch_size=args.batch_size,
92 | num_workers=4, configs=args)
93 |
94 | num_classes = len(datasets['test'].classes)
95 | print('# Classes: {}'.format(num_classes))
96 |
97 | model_args = {
98 | "name": args.arch,
99 | "feature_dim": args.feature_dim,
100 | "num_classes": num_classes,
101 | "feature_fix": args.feature_fix,
102 | }
103 |
104 | # model setup and optimizer config
105 | model = ModelFactory.create(**model_args).cuda()
106 | model = nn.DataParallel(model)
107 | if args.optim == 'sgd':
108 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
109 | else:
110 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
111 |
112 |
113 | scheduler = None
114 | if args.scheduler == 'cosine':
115 | assert args.scheduler_param != 0
116 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.scheduler_param)
117 |
118 |
119 | args.checkpoint_dir = checkpoint_dir
120 | trainer = TrainerFactory.create(args.trainer, args, model, loaders, optimizer, num_classes)
121 | trainer.set_checkpoint_dir(checkpoint_dir)
122 | start_epoch = 1
123 |
124 | if args.load_base:
125 | trainer.load_base_model(args.load_base)
126 |
127 | if args.use_base:
128 | trainer.use_base_model(args.use_base)
129 |
130 | if args.load_epoch >= 0:
131 | trainer.load_model(args.load_epoch)
132 | start_epoch = trainer.epoch + 1
133 |
134 |     if args.load_path:
135 |         trainer.load_path(args.load_path)
139 |
140 | if args.eval:
141 | trainer.test_unbiased(epoch=start_epoch-1)
142 | return
143 |
144 |
145 | for epoch in range(start_epoch, args.max_epoch+1):
146 | trainer.train(epoch=epoch)
147 |
148 | if epoch % args.test_epoch == 0:
149 | trainer.save_model(epoch=epoch)
150 | trainer.test_unbiased(epoch=epoch)
151 |
152 | if scheduler is not None:
153 | scheduler.step()
154 |
155 |
156 | if __name__ == "__main__":
157 | main()
158 |
159 |
160 |
--------------------------------------------------------------------------------
/models/classification.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision.models
5 | import pdb
6 | from utils.train_utils import *
7 |
8 |
9 |
10 | class ResNet(nn.Module):
11 | def __init__(self, feature_dim, num_classes, arch='', feature_fix=False):
12 | super(ResNet, self).__init__()
13 |
14 | self.feature_dim = feature_dim
15 | self.num_classes = num_classes
16 | self.arch = arch
17 | resnet = self.get_backbone()
18 | self.conv1 = resnet.conv1
19 | self.bn1 = resnet.bn1
20 | self.relu = resnet.relu # 1/2, 64
21 | self.maxpool = resnet.maxpool
22 |
23 | self.res2 = resnet.layer1 # 1/4, 64
24 | self.res3 = resnet.layer2 # 1/8, 128
25 | self.res4 = resnet.layer3 # 1/16, 256
26 | self.res5 = resnet.layer4 # 1/32, 512
27 |
28 | self.f = nn.Sequential(self.conv1, self.bn1, self.relu, self.maxpool,
29 | self.res2, self.res3, self.res4, self.res5)
30 |
31 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
32 | # classifier
33 | self.fc = self.get_fc(num_classes)
34 |
35 |
36 | self.feature_fix = feature_fix
37 | if feature_fix:
38 | print("Fix parameters except fc layer")
39 | for param in self.parameters():
40 | param.requires_grad = False
41 |
42 | self.fc.weight.requires_grad = True
43 | self.fc.bias.requires_grad = True
44 |
45 |
46 | def get_backbone(self):
47 | raise NotImplementedError
48 |
49 | def get_fc(self, num_classes):
50 | raise NotImplementedError
51 |
52 | def forward(self, x):
53 |
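        # Returns both the logits ("out") and the globally pooled backbone feature
        # ("feature"); BPA uses the latter for clustering and centroid updates.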
54 | feature = self.f(x)
55 | feature = torch.flatten(self.avgpool(feature), start_dim=1)
56 |
57 | logits = self.fc(feature)
58 |
59 | results = {
60 | "out": logits,
61 | "feature": feature,
62 | }
63 | return results
64 |
65 |
66 | class ResNet18(ResNet):
67 |
68 | def get_backbone(self):
69 | return torchvision.models.resnet18(pretrained=True)
70 |
71 | def get_fc(self, num_classes):
72 | return nn.Linear(512, num_classes, bias=True)
73 |
74 |
75 |
76 |
--------------------------------------------------------------------------------
/modules/centroids.py:
--------------------------------------------------------------------------------
1 | import torch, pdb
2 | from torch import nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from kmeans_pytorch import kmeans
6 |
7 | from utils.train_utils import grad_mul_const
8 |
9 |
10 | class Centroids(nn.Module):
11 |
12 | def __init__(self, args, num_classes, per_clusters, feature_dim=None):
13 | super(Centroids, self).__init__()
14 | self.momentum = args.momentum
15 | self.per_clusters = per_clusters
16 | self.num_classes = num_classes
17 | self.feature_dim = args.feature_dim if feature_dim is None else feature_dim
18 |
19 | # Cluster
20 | self.cluster_means = None
21 | self.cluster_vars = torch.zeros((self.num_classes, self.per_clusters))
22 | self.cluster_losses = torch.zeros((self.num_classes, self.per_clusters))
23 | self.cluster_accs = torch.zeros((self.num_classes, self.per_clusters))
24 | self.cluster_weights = torch.zeros((self.num_classes, self.per_clusters))
25 |
26 | # Sample
27 | self.feature_bank = None
28 | self.assigns = None
29 | self.corrects = None
30 | self.losses = None
31 | self.weights = None
32 |
33 | self.initialized = False
34 | self.weight_type = args.cluster_weight_type
35 |
36 | self.max_cluster_weights = 0. # 0 means no-limit
37 |
38 | def __repr__(self):
39 | return "{}(Y{}/K{}/dim{})".format(self.__class__.__name__, self.num_classes, self.per_clusters, self.feature_dim)
40 |
41 | @property
42 | def num_clusters(self):
43 | return self.num_classes * self.per_clusters
44 |
45 | @property
46 | def cluster_counts(self):
47 | if self.assigns is None:
48 | return 0
49 | return self.assigns.bincount(minlength=self.num_clusters)
50 |
51 |
52 | def _clamp_weights(self, weights):
53 | if self.max_cluster_weights > 0:
54 | if weights.max() > self.max_cluster_weights:
55 | scale = np.log(self.max_cluster_weights)/torch.log(weights.cpu().max())
56 | scale = scale.cuda()
57 | print("> Weight : {:.4f}, scale : {:.4f}".format(weights.max(), scale))
58 | return weights ** scale
59 | return weights
60 |
61 |
62 | def get_cluster_weights(self, ids):
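        # Cluster-wise weights: inversely proportional to the size of the assigned cluster,
        # and (once per-sample losses have been recorded) additionally scaled by the cluster's
        # relative average loss; the result is mean-normalized over the requested ids.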
63 | if self.assigns is None:
64 | return 1
65 |
66 | cluster_counts = self.cluster_counts + (self.cluster_counts==0).float() # avoid nans
67 |
68 | cluster_weights = cluster_counts.sum()/(cluster_counts.float())
69 | assigns_id = self.assigns[ids]
70 |
71 | if (self.losses == -1).nonzero().size(0) == 0:
72 | cluster_losses_ = self.cluster_losses.view(-1)
73 | losses_weight = cluster_losses_.float()/cluster_losses_.sum()
74 | weights_ = cluster_weights[assigns_id] * losses_weight[assigns_id].cuda()
75 | weights_ /= weights_.mean()
76 | else:
77 | weights_ = cluster_weights[assigns_id]
78 | weights_ /= weights_.mean()
79 |
80 | return self._clamp_weights(weights_)
81 |
82 |
83 | def initialize_(self, cluster_assigns, cluster_centers, sorted_features=None):
84 | cluster_means = cluster_centers.detach().cuda()
85 | cluster_means = F.normalize(cluster_means, 1)
86 | self.cluster_means = cluster_means.view(self.num_classes, self.per_clusters, -1)
87 |
88 | N = cluster_assigns.size(0)
89 | self.feature_bank = torch.zeros((N, self.feature_dim)).cuda() if sorted_features is None else sorted_features
90 | self.assigns = cluster_assigns
91 | self.corrects = torch.zeros((N)).long().cuda() - 1
92 | self.losses = torch.zeros((N)).cuda() - 1
93 | self.weights = torch.ones((N)).cuda()
94 | self.initialized = True
95 |
96 |
97 | def get_variances(self, x, y):
98 | return 1 - (y @ x).mean(0)
99 |
100 | def compute_centroids(self, verbose=False, split=False):
101 | for y in range(self.num_classes):
102 | for k in range(self.per_clusters):
103 | l = y*self.per_clusters + k
104 | ids = (self.assigns==l).nonzero()
105 | if ids.size(0) == 0:
106 | continue
107 | self.cluster_means[y, k] = self.feature_bank[ids].mean(0)
108 | self.cluster_vars[y, k] = self.get_variances(self.cluster_means[y, k], self.feature_bank[ids])
109 |
110 | corrs = self.corrects[ids]
111 | corrs_nz = (corrs[:, 0]>=0).nonzero()
112 | if corrs_nz.size(0) > 0:
113 | self.cluster_accs[y, k] = corrs[corrs_nz].float().mean(0)
114 |
115 | losses = self.losses[ids]
116 | loss_nz = (losses[:, 0]>=0).nonzero()
117 | if loss_nz.size(0) > 0:
118 | self.cluster_losses[y, k] = losses[loss_nz].float().mean(0)
119 |
120 | return
121 |
122 |
123 | def update(self, results, target, ids, features=None):
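        """Update the feature bank and cluster assignments for the given sample ids.

        The stored feature of each sample is an exponential moving average of its past
        (normalized) features; the sample is then re-assigned to the most similar centroid
        of its own class (other classes are masked out), and its correctness and
        cross-entropy loss are recorded for the cluster statistics.
        """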
124 | assert self.initialized
125 |
126 | ### update feature and assigns
127 | feature = results["feature"] if features is None else features
128 | feature_ = F.normalize(feature, 1).detach()
129 |
130 | feature_new = (1-self.momentum) * self.feature_bank[ids] + self.momentum * feature_
131 | feature_new = F.normalize(feature_new, 1)
132 |
133 | self.feature_bank[ids] = feature_new
134 |
135 | sim_score = self.cluster_means @ feature_new.permute(1, 0) # YKC/CB => YKB
136 |
137 | for y in range(self.num_classes):
138 | sim_score[y, :, (target!=y).nonzero()] -= 1e4
139 |
140 | sim_score_ = sim_score.view(self.num_clusters, -1)
141 | new_assigns = sim_score_.argmax(0)
142 | self.assigns[ids] = new_assigns
143 |
144 | corrects = (results["out"].argmax(1) == target).long()
145 | self.corrects[ids] = corrects
146 |
147 | criterion = torch.nn.CrossEntropyLoss(reduction='none')
148 | losses = criterion(results["out"], target.long()).detach()
149 | self.losses[ids] = losses
150 |
151 | return
152 |
153 |
154 |
155 | class FixedCentroids(Centroids):
156 |
157 | def compute_centroids(self, verbose='', split=False):
158 |
159 | for y in range(self.num_classes):
160 | for k in range(self.per_clusters):
161 | l = y*self.per_clusters + k
162 |
163 | ids = (self.assigns==l).nonzero()
164 | if ids.size(0) == 0:
165 | continue
166 |
167 | corrs = self.corrects[ids]
168 | corrs_nz = (corrs[:, 0]>=0).nonzero()
169 | if corrs_nz.size(0) > 0:
170 | self.cluster_accs[y, k] = corrs[corrs_nz].float().mean(0)
171 |
172 | losses = self.losses[ids]
173 | loss_nz = (losses[:, 0]>=0).nonzero()
174 | if loss_nz.size(0) > 0:
175 | self.cluster_losses[y, k] = losses[loss_nz].float().mean(0)
176 |
177 | self.cluster_weights[y, k] = self.weights[ids].float().mean(0)
178 |
179 | return
180 |
181 | def get_cluster_weights(self, ids):
182 | weights_ids = super().get_cluster_weights(ids)
183 | self.weights[ids] = weights_ids
184 | return weights_ids
185 |
186 |
187 | def update(self, results, target, ids, features=None, preds=None):
188 | assert self.initialized
189 |
190 | out = preds if preds is not None else results["out"]
191 |
192 | corrects = (out.argmax(1) == target).long()
193 | self.corrects[ids] = corrects
194 |
195 | criterion = torch.nn.CrossEntropyLoss(reduction='none')
196 | losses = criterion(out, target.long()).detach()
197 | self.losses[ids] = losses
198 |
199 | return
200 |
201 |
202 |
203 | class AvgFixedCentroids(FixedCentroids):
204 |
205 | def __init__(self, args, num_classes, per_clusters, feature_dim=None):
206 | super(AvgFixedCentroids, self).__init__(args, num_classes, per_clusters, feature_dim)
207 | self.exp_step = args.exp_step
208 | self.avg_weight_type = args.avg_weight_type
209 |
210 | def compute_centroids(self, verbose='', split=False):
211 |
212 | for y in range(self.num_classes):
213 | for k in range(self.per_clusters):
214 | l = y*self.per_clusters + k
215 |
216 | ids = (self.assigns==l).nonzero()
217 | if ids.size(0) == 0:
218 | continue
219 |
220 | corrs = self.corrects[ids]
221 | corrs_nz = (corrs[:, 0]>=0).nonzero()
222 | if corrs_nz.size(0) > 0:
223 | self.cluster_accs[y, k] = corrs[corrs_nz].float().mean(0)
224 |
225 | losses = self.losses[ids]
226 | loss_nz = (losses[:, 0]>=0).nonzero()
227 | if loss_nz.size(0) > 0:
228 | self.cluster_losses[y, k] = losses[loss_nz].float().mean(0)
229 |
230 | self.cluster_weights[y, k] = self.weights[ids].float().mean(0)
231 |
232 | return
233 |
234 |
235 | def get_cluster_weights(self, ids):
236 |
237 | weights_ids = super().get_cluster_weights(ids)
238 |
239 | if self.avg_weight_type == 'expavg':
240 | weights_ids_ = self.weights[ids] * torch.exp(self.exp_step*weights_ids.data)
241 | elif self.avg_weight_type == 'avg':
242 | weights_ids_ = (1-self.momentum) * self.weights[ids] + self.momentum * weights_ids
243 | elif self.avg_weight_type == 'expgrad':
244 | weights_ids_l1 = weights_ids / weights_ids.sum()
245 | prev_ids_l1 = self.weights[ids] / self.weights[ids].sum()
246 | weights_ids_ = prev_ids_l1 * torch.exp(self.exp_step*weights_ids_l1.data)
247 | else:
248 | raise ValueError
249 |
250 | self.weights[ids] = weights_ids_ / weights_ids_.mean()
251 | return self.weights[ids]
252 |
253 |
254 |
255 | class HeteroCentroids(nn.Module):
256 |
257 | def __init__(self, args, num_classes, num_hetero_clusters, centroids_type):
258 | super(HeteroCentroids, self).__init__()
259 | self.momentum = args.momentum
260 | self.num_classes = num_classes
261 | self.feature_dim = args.feature_dim
262 | self.initialized = False
263 |
264 | self.num_hetero_clusters = num_hetero_clusters
265 | self.num_multi_centroids = len(num_hetero_clusters)
266 | self.centroids_list = [centroids_type(args, num_classes, per_clusters=num_hetero_clusters[m], feature_dim=self.feature_dim) for m in range(self.num_multi_centroids)]
267 |
268 | def __repr__(self):
269 | return self.__class__.__name__ + "(" + ", ".join([centroids.__repr__() for centroids in self.centroids_list])+ ")"
270 |
271 |
272 | def initialize_multi(self, multi_cluster_assigns, multi_cluster_centers):
273 | for cluster_assigns, cluster_centers, centroids in zip(multi_cluster_assigns, multi_cluster_centers, self.centroids_list):
274 | centroids.initialize_(cluster_assigns, cluster_centers)
275 | self.initialized = True
276 |
277 | def compute_centroids(self, verbose=False, split=False):
278 | for m, centroids in enumerate(self.centroids_list):
279 | verbose_ = str(m) if verbose else ''
280 | centroids.compute_centroids(verbose=verbose_)
281 |
282 |
283 | def update(self, results, target, ids):
284 | for m, centroids in enumerate(self.centroids_list):
285 | features = results["feature"]
286 | centroids.update(results, target, ids)
287 |
288 |
289 | def get_cluster_weights(self, ids):
290 |
291 | weights_list = [centroids.get_cluster_weights(ids) for centroids in self.centroids_list]
292 | weights_ = torch.stack(weights_list).mean(0)
293 | return weights_
294 |
295 |
296 |
297 |
298 |
299 |
300 |
--------------------------------------------------------------------------------
/modules/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import numpy as np
6 |
7 |
8 | class GeneralizedCELoss(nn.Module):
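    """Generalized cross-entropy loss, L_q = (1 - p_y^q) / q.

    Realized here by scaling the per-sample cross-entropy with the detached factor
    q * p_y^q, which puts most of the gradient on samples the model already classifies
    confidently.
    """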
9 |
10 | def __init__(self, q=0.7):
11 | super(GeneralizedCELoss, self).__init__()
12 | self.q = q
13 |
14 | def forward(self, logits, targets):
15 | p = F.softmax(logits, dim=1)
16 | if np.isnan(p.mean().item()):
17 | raise NameError('GCE_p')
18 | Yg = torch.gather(p, 1, torch.unsqueeze(targets, 1))
19 | # modify gradient of cross entropy
20 | loss_weight = (Yg.squeeze().detach()**self.q)*self.q
21 | if np.isnan(Yg.mean().item()):
22 | raise NameError('GCE_Yg')
23 |
24 | loss = F.cross_entropy(logits, targets, reduction='none') * loss_weight
25 |
26 | return loss
27 |
--------------------------------------------------------------------------------
/modules/transform.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from torchvision import transforms
3 | import cv2
4 | import numpy as np
5 |
6 |
7 | train_transform = transforms.Compose([
8 | transforms.RandomResizedCrop(32),
9 | transforms.RandomHorizontalFlip(p=0.5),
10 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
11 | transforms.RandomGrayscale(p=0.2),
12 | transforms.ToTensor(),
13 | transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])
14 |
15 | test_transform = transforms.Compose([
16 | transforms.ToTensor(),
17 | transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])
18 |
19 |
20 |
21 |
22 | celebA_test_transform = transforms.Compose([
23 | transforms.Resize(224),
24 | transforms.ToTensor(),
25 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
26 | ])
27 |
28 | celebA_train_transform = transforms.Compose([
29 | transforms.RandomResizedCrop(224),
30 | transforms.RandomHorizontalFlip(),
31 | transforms.ToTensor(),
32 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
33 | ])
34 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.4.0
2 | torchvision>=0.2.1
3 | matplotlib
4 | pathlib
5 | pandas
6 | kmeans_pytorch
7 | configargparse
8 | tqdm
9 | seaborn
10 | opencv-python
11 | coloredlogs
12 | wandb
13 | numpy
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, pdb
3 | import logging
4 | from PIL import Image
5 |
6 | import numpy as np
7 | import torch
8 | import torch.optim as optim
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from torch.utils.data import DataLoader
12 | from tqdm import tqdm
13 | from torch.distributions import Categorical
14 |
15 | from modules.transform import *
16 |
17 | from utils.train_utils import *
18 | from utils.io_utils import *
19 | import importlib
20 |
21 |
22 | import pandas as pd
23 | from sklearn.manifold import TSNE
24 | import seaborn as sns
25 | from matplotlib import pyplot as plt
26 | from kmeans_pytorch import kmeans
27 |
28 | import wandb
29 | import pickle
30 |
31 |
32 | class Trainer():
33 |
34 | def __init__(self, args, model, loaders, optimizer, num_classes):
35 | print(self)
36 | self.args = args
37 | self.model = model
38 | self.loaders = loaders
39 | self.optimizer = optimizer
40 |
41 | self.max_epoch = args.max_epoch
42 | self.batch_size = args.batch_size
43 |
44 | self.k = args.k
45 | self.num_classes = num_classes
46 | self.num_groups = np.power(self.num_classes, len(self.args.bias_attrs)+1)
47 | self.num_clusters = self.k
48 |
49 | self.logger = get_logger('')
50 |
51 | self.accs = WindowAvgMeter(name='accs', max_count=20)
52 |
53 |
54 | def set_checkpoint_dir(self, checkpoint_dir):
55 | self.checkpoint_dir = checkpoint_dir
56 |
57 |
58 | def save_model(self, epoch):
59 | if not self.args.no_save or epoch % self.args.save_epoch == 0:
60 | torch.save({
61 | 'epoch': epoch,
62 | 'state_dict': self.model.state_dict(),
63 | 'optimizer' : self.optimizer.state_dict(),
64 | }, self.checkpoint_dir / 'e{:04d}.pth'.format(epoch))
65 | return
66 |
67 | def load_model(self, epoch=0):
68 | self.logger.info('Resume training')
69 | if epoch==0:
70 | checkpoint_path = max((f.stat().st_mtime, f) for f in self.checkpoint_dir.glob('*.pth'))[1]
71 | self.logger.info('Resume Latest from {}'.format(checkpoint_path))
72 | else:
73 | checkpoint_path = self.checkpoint_dir / 'e{:04d}.pth'.format(epoch)
74 | self.logger.info('Resume from {}'.format(checkpoint_path))
75 |
76 | checkpoint = torch.load(checkpoint_path)
77 | self.model.load_state_dict((checkpoint['state_dict'])) # Set CUDA before if error occurs.
78 | self.optimizer.load_state_dict(checkpoint['optimizer'])
79 | self.epoch = checkpoint['epoch']
80 |
81 |
82 | def load_path(self, file_name):
83 | self.logger.info('Loading model at ({})'.format(file_name))
84 | checkpoint_path = self.checkpoint_dir / '..' / '..' / '{}.pth'.format(file_name)
85 | checkpoint = torch.load(checkpoint_path)
86 | self.model.load_state_dict(checkpoint['state_dict']) # Set CUDA before if error occurs.
87 |
88 |
89 | def finetune(self, epoch, iter):
90 | return
91 |
92 |
93 | def extract_sample(self, data, desc):
94 | data_ = data.permute(1,2,0).cpu().numpy().astype(np.uint8)
95 | img = Image.fromarray(data_)
96 | img.save(self.checkpoint_dir / '{}.png'.format(desc))
97 |
98 |
99 |
100 | def _extract_features_with_path(self, model, data_loader):
101 | features, targets = [], []
102 | ids = []
103 | paths = []
104 |
105 | for data, target, index, path in tqdm(data_loader, desc='Feature extraction for clustering..', ncols=5):
106 | data, target, index = data.cuda(), target.cuda(), index.cuda()
107 | results = model(data)
108 | features.append(results["feature"])
109 | targets.append(target)
110 | ids.append(index)
111 | paths.append(path)
112 |
113 | features = torch.cat(features)
114 | targets = torch.cat(targets)
115 | ids = torch.cat(ids)
116 | paths = np.concatenate(paths)
117 | return features, targets, ids, paths
118 |
119 |
120 |
121 |
122 | class ClassifyTrainer(Trainer):
123 |
124 | def train(self, epoch):
125 |
126 | data_loader = self.loaders['train']
127 | total_loss, total_num, train_bar = 0.0, 0, tqdm(data_loader, ncols=100)
128 | criterion = torch.nn.CrossEntropyLoss(reduction='none')
129 |
130 |
131 | for data, target, _, _, _, _, _ in train_bar:
132 | B = target.size(0)
133 |
134 | data, target = data.cuda(), target.cuda()
135 |
136 | results = self.model(data)
137 | loss = torch.mean(criterion(results["out"], target.long()))
138 |
139 | self.optimizer.zero_grad()
140 | loss.backward()
141 | self.optimizer.step()
142 |
143 | total_num += B
144 | total_loss += loss.item() * B
145 |
146 | train_bar.set_description('Train Epoch: [{}/{}] Loss: {:.4f}'.format(epoch, self.max_epoch, total_loss / total_num))
147 |
148 |
149 | return total_loss / total_num
150 |
151 |
152 | def test(self, epoch):
153 | self.model.eval()
154 | test_loader = self.loaders['test']
155 |
156 | total_top1, total_top5, total_num, test_bar = 0.0, 0.0, 0, tqdm(test_loader, ncols=100)
157 | with torch.no_grad():
158 |
159 | for data, target, _, _, _, _, _ in test_bar:
160 | data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
161 | B = target.size(0)
162 |
163 | results = self.model(data)
164 | pred_labels = results["out"].argsort(dim=-1, descending=True)
165 |
166 | total_num += B
167 | total_top1 += torch.sum((pred_labels[:, :1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
168 | total_top5 += torch.sum((pred_labels[:, :5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
169 | test_bar.set_description('Test Epoch: [{}/{}] Acc@1:{:.2f}% Acc@5:{:.2f}%'
170 | .format(epoch, self.max_epoch, total_top1 / total_num * 100, total_top5 / total_num * 100))
171 |
172 |
173 | self.model.train()
174 |
175 | return
176 |
177 |
178 | class BiasedClassifyTrainer(ClassifyTrainer):
179 |
180 |
181 | def test(self, epoch):
182 | self.model.eval()
183 |
184 | for desc in ['test']:
185 | loader = self.loaders[desc]
186 |
187 | total_top1, total_top5, total_num, test_bar = 0.0, 0.0, 0, tqdm(loader, ncols=100)
188 |
189 | with torch.no_grad():
190 |
191 | for data, target, bias, _, _, _, _ in test_bar:
192 | data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
193 |
194 | B = target.size(0)
195 |
196 | results = self.model(data)
197 | pred_labels = results["out"].argsort(dim=-1, descending=True)
198 |
199 | total_num += B
200 | total_top1 += torch.sum((pred_labels[:, :1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
201 | total_top5 += torch.sum((pred_labels[:, :5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
202 | test_bar.set_description('[{}] Test Epoch: [{}/{}] Acc@1:{:.2f}% Acc@5:{:.2f}%'
203 | .format(desc, epoch, self.max_epoch, total_top1 / total_num * 100, total_top5 / total_num * 100))
204 |
205 | log = self.logger.info if desc in ['train', 'train_eval'] else self.logger.warning
206 | log('Eval Epoch [{}/{}] ({}) Acc@1:{:.2f}% Acc@5:{:.2f}%'.format(epoch, self.max_epoch, desc, total_top1 / total_num * 100, total_top5 / total_num * 100))
207 |
208 | self.model.train()
209 | return
210 |
211 |
212 |
213 |
214 | def test_unbiased(self, epoch, train_eval=True):
215 | self.model.eval()
216 |
217 | test_envs = ['valid', 'test']
218 | for desc in test_envs:
219 | loader = self.loaders[desc]
220 |
221 | total_top1, total_top5, total_num, test_bar = 0.0, 0.0, 0, tqdm(loader, ncols=100)
222 |
223 | num_classes = len(loader.dataset.classes)
224 | num_groups = loader.dataset.num_groups
225 |
226 | bias_counts = torch.zeros(num_groups).cuda()
227 | bias_top1s = torch.zeros(num_groups).cuda()
228 |
229 |
230 | with torch.no_grad():
231 |
232 | features, labels = [], []
233 | logits = []
234 | corrects = []
235 |
236 | for data, target, biases, group, _, _, ids in test_bar:
237 | data, target, biases, group = data.cuda(), target.cuda(), biases.cuda(), group.cuda()
238 |
239 | B = target.size(0)
240 |
241 | results = self.model(data)
242 | pred_labels = results["out"].argsort(dim=-1, descending=True)
243 | features.append(results["feature"])
244 | logits.append(results["out"])
245 | labels.append(group)
246 |
247 |
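                    # Accumulate per-group (target x bias) sample counts and top-1 hits so that
                    # unbiased (group-averaged) and worst-group accuracies can be reported below.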
248 | top1s = (pred_labels[:, :1] == target.unsqueeze(dim=-1)).squeeze().unsqueeze(0)
249 | group_indices = (group==torch.arange(num_groups).unsqueeze(1).long().cuda())
250 |
251 | bias_counts += group_indices.sum(1)
252 | bias_top1s += (top1s * group_indices).sum(1)
253 |
254 | corrects.append(top1s)
255 |
256 | total_num += B
257 | total_top1 += torch.sum((pred_labels[:, :1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
258 | total_top5 += torch.sum((pred_labels[:, :5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
259 | acc1, acc5 = total_top1 / total_num * 100, total_top5 / total_num * 100
260 |
261 | bias_accs = bias_top1s / bias_counts * 100
262 |
263 | avg_acc = np.nanmean(bias_accs.cpu().numpy())
264 | worst_acc = np.nanmin(bias_accs.cpu().numpy())
265 | std_acc = np.nanstd(bias_accs.cpu().numpy())
266 |
267 | acc_desc = '/'.join(['{:.1f}%'.format(acc) for acc in bias_accs])
268 |
269 | test_bar.set_description('Eval Epoch [{}/{}] [{}] Bias: {:.2f}%'.format(epoch, self.max_epoch, desc, avg_acc))
270 |
271 | features = torch.cat(features)
272 | logits = torch.cat(logits)
273 | labels = torch.cat(labels)
274 | corrects = torch.cat(corrects, 1)
275 |
276 |
277 | log = self.logger.info if desc in ['train', 'train_eval'] else self.logger.warning
278 | log('Eval Epoch [{}/{}] [{}] Unbiased: {:.2f}% (std: {:.2f}), Worst: {:.2f}% [{}] (Average: {:.2f}%)'.format(epoch, self.max_epoch, desc, avg_acc, std_acc, worst_acc, acc_desc))
279 | self.logger.info('Total [{}]: Acc@1:{:.2f}% Acc@5:{:.2f}%'.format(desc, acc1, acc5))
280 | print(" {} / {} / {}".format(self.args.desc, self.args.target_attr, self.args.bias_attrs))
281 |
282 |
283 | self.model.train()
284 |
285 | return
286 |
287 |
288 |
289 |
290 |
291 |
--------------------------------------------------------------------------------
/utils/io_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import hashlib
4 | import errno
5 | import logging
6 | from collections import defaultdict
7 | from string import Formatter
8 | import torch
9 | import re, pdb
10 | from datetime import datetime
11 |
12 | import cv2
13 | import numpy as np
14 | import pdb
15 |
16 |
17 | def get_logger(name, fmt='%(asctime)s:%(name)s:%(message)s',
18 | print_level=logging.INFO,
19 | write_level=logging.DEBUG, log_file='', mode='w'):
20 | """
21 | Get Logger with given name
22 | :param name: logger name.
23 |     :param fmt: log format. (default: %(asctime)s:%(name)s:%(message)s)
24 |     :param print_level: console logging level; write_level: file logging level.
25 |     :param log_file: path of log file; no file handler is added if empty. mode: file open mode.
26 |     :return: configured logging.Logger
27 | """
28 | logger = logging.getLogger(name)
29 | # logger.setLevel(write_level)
30 | logging.basicConfig(level=print_level)
31 | formatter = logging.Formatter(fmt, datefmt='%Y/%m/%d %H:%M:%S')
32 |
33 | # Add file handler
34 | if log_file:
35 | file_handler = logging.FileHandler(log_file, mode=mode)
36 | file_handler.setLevel(write_level)
37 | file_handler.setFormatter(formatter)
38 | logger.addHandler(file_handler)
39 |
40 | if print_level is not None:
41 | try:
42 | import coloredlogs
43 | coloredlogs.install(level=print_level, logger=logger)
44 | coloredlogs.DEFAULT_LEVEL_STYLES = {'critical': {'color': 'red', 'bold': True}, 'debug': {'color': 'green'}, 'error': {'color': 'red'}, 'info': {}, 'notice': {'color': 'magenta'}, 'spam': {'color': 'green', 'faint': True}, 'success': {'color': 'green', 'bold': True}, 'verbose': {'color': 'blue'}, 'warning': {'color': 'yellow'}}
45 | except ImportError:
46 | print("Please install Coloredlogs for better view")
47 | # Add stream handler
48 | stream_handler = logging.StreamHandler()
49 | stream_handler.setLevel(print_level)
50 | stream_handler.setFormatter(formatter)
51 | logger.addHandler(stream_handler)
52 | return logger
--------------------------------------------------------------------------------
/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | import torch, pdb
2 | import numpy as np
3 | from torch.autograd import Variable
4 | import torch.nn.functional as F
5 |
6 | def get_negative_mask(batch_size):
7 | negative_mask = torch.ones((batch_size, 2 * batch_size), dtype=bool)
8 | for i in range(batch_size):
9 | negative_mask[i, i] = 0
10 | negative_mask[i, i + batch_size] = 0
11 |
12 | negative_mask = torch.cat((negative_mask, negative_mask), 0)
13 | return negative_mask
14 |
15 |
16 | def get_negative_class_mask(targets):
17 | batch_size = targets.size(0)
18 | negative_mask = torch.ones((batch_size, 2 * batch_size), dtype=bool)
19 | for i in range(batch_size):
20 | current_c = targets[i]
21 | same_indices = (targets == current_c).nonzero().squeeze()
22 | for s in same_indices:
23 | negative_mask[i, s] = 0
24 | negative_mask[i, s + batch_size] = 0
25 |
26 | negative_mask = torch.cat((negative_mask, negative_mask), 0)
27 | return negative_mask
28 |
29 |
30 |
31 | class GradMulConst(torch.autograd.Function):
32 | """
33 |     Identity in the forward pass; multiplies the gradient by const in the backward pass (const=-1 gives gradient reversal for adversarial losses).
34 | """
35 | @staticmethod
36 | def forward(ctx, x, const):
37 | ctx.const = const
38 | return x.view_as(x)
39 |
40 | @staticmethod
41 | def backward(ctx, grad_output):
42 | return grad_output * ctx.const, None
43 |
44 | def grad_mul_const(x, const):
45 | return GradMulConst.apply(x, const)
46 |
47 |
48 |
49 | class AvgMeter(object):
50 | def __init__(self, name=''):
51 |         self.name = name
             self.reset()
52 |
53 | def reset(self):
54 | self.val = 0.
55 | self.avg = 0.
56 | self.sum = 0.
57 | self.count = 0.
58 |
59 | def update(self, val, n=1):
60 | if type(val) is torch.Tensor:
61 | val = val.detach()
62 | val = val.cpu()
63 | val = val.numpy()
64 |
65 |         if np.ndim(val) > 0 and n == len(val):
66 | self.val = val[-1]
67 | self.sum += np.sum(val)
68 | self.count += len(val)
69 | elif n==0: # array
70 | self.val = val[-1]
71 | self.sum += np.sum(val)
72 | self.count += len(val)
73 | else:
74 | self.val = val
75 | self.sum += val
76 | self.count += n
77 |
78 | self.avg = self.sum / self.count
79 |
80 | def __repr__(self):
81 | return self.name+":"+str(round(self.avg, 3))
82 |
83 |
84 | class WindowAvgMeter(object):
85 | def __init__(self, name='', max_count=20):
86 | self.values = []
87 | self.name = name
88 | self.max_count = max_count
89 |
90 | @property
91 | def avg(self):
92 | if len(self.values) > 0:
93 | return np.sum(self.values)/len(self.values)
94 | else:
95 | return 0
96 |
97 | def update(self, val):
98 | if type(val) is torch.Tensor:
99 | val = val.detach()
100 | val = val.cpu()
101 | val = val.numpy()
102 |
103 | self.values.append(val)
104 | if len(self.values) > self.max_count:
105 | self.values.pop(0)
106 |
107 | def __repr__(self):
108 | return self.name+":"+str(round(self.avg, 3))
109 |
110 |
111 | class EMA(object):
112 | # Exponential Moving Average
113 |
114 | def __init__(self, label, alpha=0.9):
115 | self.label = label
116 | self.alpha = alpha
117 | self.parameter = torch.zeros(label.size(0))
118 | self.updated = torch.zeros(label.size(0))
119 |
120 | def update(self, data, index):
121 | self.parameter[index] = self.alpha * self.parameter[index] + (1-self.alpha*self.updated[index]) * data
122 | self.updated[index] = 1
123 |
124 | def max_loss(self, label):
125 | label_index = np.where(self.label == label)[0]
126 | return self.parameter[label_index].max()
127 |
128 |
129 |
130 |
131 | class FocalLoss(torch.nn.Module):
132 | def __init__(self, gamma=0, alpha=None, size_average=True):
133 | super(FocalLoss, self).__init__()
134 | self.gamma = gamma
135 | self.alpha = alpha
136 | if isinstance(alpha,(float,int)): self.alpha = torch.Tensor([alpha,1-alpha])
137 | if isinstance(alpha,list): self.alpha = torch.Tensor(alpha)
138 | self.size_average = size_average
139 |
140 | def forward(self, input, target):
141 | if input.dim()>2:
142 | input = input.view(input.size(0),input.size(1),-1) # N,C,H,W => N,C,H*W
143 | input = input.transpose(1,2) # N,C,H*W => N,H*W,C
144 | input = input.contiguous().view(-1,input.size(2)) # N,H*W,C => N*H*W,C
145 | target = target.view(-1,1)
146 |
147 | logpt = F.log_softmax(input, 1)
148 | logpt = logpt.gather(1,target)
149 | logpt = logpt.view(-1)
150 | pt = Variable(logpt.data.exp())
151 |
152 | if self.alpha is not None:
153 | if self.alpha.type()!=input.data.type():
154 | self.alpha = self.alpha.type_as(input.data)
155 | at = self.alpha.gather(0,target.data.view(-1))
156 | logpt = logpt * Variable(at)
157 |
158 | loss = -1 * (1-pt)**self.gamma * logpt
159 | if self.size_average: return loss.mean()
160 | else: return loss.sum()
161 |
162 |
163 |
164 | def KL_u_p_loss(outputs):
165 | # KL(u||p)
166 | num_classes = outputs.size(1)
167 | uniform_tensors = torch.ones(outputs.size())
168 | uniform_dists = Variable(uniform_tensors / num_classes).cuda()
169 | instance_losses = F.kl_div(F.log_softmax(outputs, dim=1), uniform_dists, reduction='none').sum(dim=1)
170 | return instance_losses
171 |
--------------------------------------------------------------------------------