├── README.md
├── bpa_trainer.py
├── dataset
│   └── celebA.py
├── factory.py
├── main.py
├── models
│   └── classification.py
├── modules
│   ├── centroids.py
│   ├── loss.py
│   └── transform.py
├── requirements.txt
├── trainer.py
└── utils
    ├── io_utils.py
    └── train_utils.py
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | # Unsupervised Learning of Debiased Representations with Pseudo-Attributes
3 | PyTorch implementation for BPA (CVPR 2022)
4 | 
5 | [Seonguk Seo](https://seoseong.uk/), [Joon-Young Lee](https://joonyoung-cv.github.io/), [Bohyung Han](https://cv.snu.ac.kr/~bhhan/)
6 | 
7 | Seoul National University, Adobe Research
8 | ### [[Paper](https://arxiv.org/abs/2108.02943)]
9 | 
10 | >

11 | > Dataset bias is a critical challenge in machine learning since it often leads to a negative impact on a model due to the unintended decision rules captured by spurious correlations. Although existing works often handle this issue based on human supervision, the availability of the proper annotations is impractical and even unrealistic. To better tackle the limitation, we propose a simple but effective unsupervised debiasing technique. Specifically, we first identify pseudo-attributes based on the results from clustering performed in the feature embedding space even without an explicit bias attribute supervision. Then, we employ a novel cluster-wise reweighting scheme to learn debiased representation; the proposed method prevents minority groups from being discounted for minimizing the overall loss, which is desirable for worst-case generalization. The extensive experiments demonstrate the outstanding performance of our approach on multiple standard benchmarks, even achieving the competitive accuracy to the supervised counterpart. 12 | 13 | 14 | --- 15 | 16 | ## Installation 17 | ``` 18 | git clone https://github.com/skynbe/pseudo-attributes.git 19 | cd pseudo-attributes 20 | pip install -r requirements.txt 21 | ``` 22 | Download CelebA dataset at $ROOT_PATH/data/celebA. 23 | 24 | 25 | ### Quick Start 26 | 27 | Train baseline model: 28 | ``` 29 | python main.py --arch ResNet18 --trainer classify --desc base --dataset celebA --test_epoch 1 --lr 1e-4 --target_attr Blond_Hair --bias_attrs Male --no_save 30 | ``` 31 | 32 | Train BPA model: 33 | ``` 34 | python main.py --arch ResNet18 --trainer bpa --desc bpa_k8 --dataset celebA --test_epoch 1 --lr 1e-4 --target_attr Blond_Hair --bias_attrs Male --k 8 --ks 8 --no_save --use_base {$BASE_PATH} 35 | ``` 36 | 37 | 38 | ## Citation 39 | 40 | If you find our work useful in your research, please cite: 41 | 42 | ``` 43 | @inproceedings{seo2022unsupervised, 44 | title={Unsupervised Learning of Debiased Representations with Pseudo-Attributes}, 45 | author={Seo, Seonguk and Lee, Joon-Young and Han, Bohyung}, 46 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 47 | pages={16742--16751}, 48 | year={2022} 49 | } 50 | ``` 51 | -------------------------------------------------------------------------------- /bpa_trainer.py: -------------------------------------------------------------------------------- 1 | 2 | import copy 3 | from trainer import * 4 | from kmeans_pytorch import kmeans 5 | from modules.centroids import AvgFixedCentroids 6 | from torch.autograd import Variable 7 | 8 | 9 | class OnlineTrainer(BiasedClassifyTrainer): 10 | 11 | def __init__(self, args, model, loaders, optimizer, num_classes, per_clusters=0): 12 | super().__init__(args, model, loaders, optimizer, num_classes) 13 | if per_clusters == 0: 14 | per_clusters = args.k 15 | 16 | self.centroids = AvgFixedCentroids(args, num_classes, per_clusters=per_clusters) 17 | self.update_cluster_iter = args.update_cluster_iter 18 | self.checkpoint_dir = args.checkpoint_dir 19 | 20 | def save_model(self, epoch): 21 | if not self.args.no_save or epoch % self.args.save_epoch == 0: 22 | torch.save({ 23 | 'epoch': epoch, 24 | 'state_dict': self.model.state_dict(), 25 | 'optimizer' : self.optimizer.state_dict(), 26 | }, self.checkpoint_dir / 'e{:04d}.pth'.format(epoch)) 27 | return 28 | 29 | def load_model(self, epoch=0): 30 | self.logger.info('Resume training') 31 | if epoch==0: 32 | checkpoint_path = max((f.stat().st_mtime, f) for f in 
self.checkpoint_dir.glob('*.pth'))[1] 33 | self.logger.info('Resume Latest from {}'.format(checkpoint_path)) 34 | else: 35 | self.logger.info('Resume from {}'.format(epoch)) 36 | checkpoint_path = self.checkpoint_dir / 'e{:04d}.pth'.format(epoch) 37 | 38 | checkpoint = torch.load(checkpoint_path) 39 | self.model.load_state_dict(checkpoint['state_dict']) # Set CUDA before if error occurs. 40 | self.optimizer.load_state_dict(checkpoint['optimizer']) 41 | self.epoch = checkpoint['epoch'] 42 | 43 | 44 | 45 | class BPATrainer(OnlineTrainer): 46 | 47 | def __init__(self, args, model, loaders, optimizer, num_classes): 48 | super().__init__(args, model, loaders, optimizer, num_classes) 49 | self.class_weights = None 50 | self.base_model = copy.deepcopy(self.model) 51 | if not args.use_base: 52 | assert ValueError 53 | 54 | def save_model(self, epoch): 55 | if not self.args.no_save or epoch % self.args.save_epoch == 0: 56 | torch.save({ 57 | 'epoch': epoch, 58 | 'state_dict': self.model.state_dict(), 59 | 'optimizer' : self.optimizer.state_dict(), 60 | }, self.checkpoint_dir / 'e{:04d}.pth'.format(epoch)) 61 | return 62 | 63 | 64 | def use_base_model(self, file_name): 65 | self.logger.info('Loading ({}) base model'.format(file_name)) 66 | checkpoint_path = self.checkpoint_dir / '..' / '{}.pth'.format(file_name) 67 | checkpoint = torch.load(checkpoint_path) 68 | self.base_model.load_state_dict(checkpoint['state_dict']) # Set CUDA before if error occurs. 69 | 70 | 71 | def load_model(self, epoch=0): 72 | self.logger.info('Resume training') 73 | if epoch==0: 74 | checkpoint_path = max((f.stat().st_mtime, f) for f in self.checkpoint_dir.glob('*.pth'))[1] 75 | self.logger.info('Resume Latest from {}'.format(checkpoint_path)) 76 | else: 77 | self.logger.info('Resume from {}'.format(epoch)) 78 | checkpoint_path = self.checkpoint_dir / 'e{:04d}.pth'.format(epoch) 79 | 80 | checkpoint = torch.load(checkpoint_path) 81 | self.model.load_state_dict(checkpoint['state_dict']) # Set CUDA before if error occurs. 
82 | self.optimizer.load_state_dict(checkpoint['optimizer']) 83 | self.epoch = checkpoint['epoch'] 84 | 85 | 86 | def _extract_features(self, model, data_loader): 87 | features, targets = [], [] 88 | ids = [] 89 | 90 | for data, target, index in tqdm(data_loader, desc='Feature extraction for clustering..', ncols=5): 91 | data, target, index = data.cuda(), target.cuda(), index.cuda() 92 | results = model(data) 93 | features.append(results["feature"]) 94 | targets.append(target) 95 | ids.append(index) 96 | 97 | features = torch.cat(features) 98 | targets = torch.cat(targets) 99 | ids = torch.cat(ids) 100 | return features, targets, ids 101 | 102 | 103 | def _cluster_features(self, data_loader, features, targets, ids, num_clusters): 104 | 105 | N = len(data_loader.dataset) 106 | num_classes = data_loader.dataset.num_classes 107 | sorted_target_clusters = torch.zeros(N).long().cuda() + num_clusters*num_classes 108 | 109 | target_clusters = torch.zeros_like(targets)-1 110 | cluster_centers = [] 111 | 112 | for t in range(num_classes): 113 | target_assigns = (targets==t).nonzero().squeeze() 114 | feautre_assigns = features[target_assigns] 115 | 116 | cluster_ids, cluster_center = kmeans(X=feautre_assigns, num_clusters=num_clusters, distance='cosine', tqdm_flag=False, device=0) 117 | cluster_ids_ = cluster_ids + t*num_clusters 118 | 119 | target_clusters[target_assigns] = cluster_ids_.cuda() 120 | cluster_centers.append(cluster_center) 121 | 122 | sorted_target_clusters[ids] = target_clusters 123 | cluster_centers = torch.cat(cluster_centers, 0) 124 | return sorted_target_clusters, cluster_centers 125 | 126 | 127 | def inital_clustering(self): 128 | data_loader = self.loaders['train_eval'] 129 | data_loader.dataset.clustering_on() 130 | self.base_model.eval() 131 | 132 | with torch.no_grad(): 133 | 134 | features, targets, ids = self._extract_features(self.base_model, data_loader) 135 | num_clusters = self.args.num_clusters 136 | cluster_assigns, cluster_centers = self._cluster_features(data_loader, features, targets, ids, num_clusters) 137 | 138 | cluster_counts = cluster_assigns.bincount().float() 139 | print("Cluster counts : {}, len({})".format(cluster_counts, len(cluster_counts))) 140 | 141 | 142 | data_loader.dataset.clustering_off() 143 | return cluster_assigns, cluster_centers 144 | 145 | 146 | def train(self, epoch): 147 | 148 | cluster_weights = None 149 | # if epoch > 1 and not self.centroids.initialized: 150 | if not self.centroids.initialized: 151 | cluster_assigns, cluster_centers = self.inital_clustering() 152 | self.centroids.initialize_(cluster_assigns, cluster_centers) 153 | 154 | data_loader = self.loaders['train'] 155 | total_loss, total_num, train_bar = 0.0, 0, tqdm(data_loader) 156 | criterion = torch.nn.CrossEntropyLoss(reduction='none') 157 | total_metric_loss = 0.0 158 | 159 | i = 0 160 | for data, target, _, _, cluster, weight_pre, ids in train_bar: 161 | i += 1 162 | B = target.size(0) 163 | 164 | data, target = data.cuda(), target.cuda() 165 | 166 | results = self.model(data) 167 | weight = self.centroids.get_cluster_weights(ids) 168 | loss = torch.mean(criterion(results["out"], target.long()) * (weight)) 169 | 170 | self.optimizer.zero_grad() 171 | (loss).backward() 172 | self.optimizer.step() 173 | 174 | total_num += B 175 | total_loss += loss.item() * B 176 | 177 | train_bar.set_description('Train Epoch: [{}/{}] Loss: {:.4f}'.format(epoch, self.max_epoch, total_loss / total_num)) 178 | 179 | if self.centroids.initialized: 180 | self.centroids.update(results, target, 
ids) 181 | if self.args.update_cluster_iter > 0 and i % self.args.update_cluster_iter == 0: 182 | self.centroids.compute_centroids() 183 | 184 | 185 | if self.centroids.initialized: 186 | self.centroids.compute_centroids(verbose=True) 187 | 188 | 189 | return total_loss / total_num 190 | 191 | 192 | def test_unbiased(self, epoch, train_eval=True): 193 | self.model.eval() 194 | 195 | test_envs = ['valid', 'test'] 196 | 197 | for desc in test_envs: 198 | loader = self.loaders[desc] 199 | 200 | total_top1, total_top5, total_num, test_bar = 0.0, 0.0, 0, tqdm(loader) 201 | 202 | num_classes = len(loader.dataset.classes) 203 | num_groups = loader.dataset.num_groups 204 | 205 | bias_counts = torch.zeros(num_groups).cuda() 206 | bias_top1s = torch.zeros(num_groups).cuda() 207 | 208 | with torch.no_grad(): 209 | 210 | features, labels = [], [] 211 | corrects = [] 212 | 213 | for data, target, biases, group, _, _, ids in test_bar: 214 | data, target, biases, group = data.cuda(), target.cuda(), biases.cuda(), group.cuda() 215 | 216 | B = target.size(0) 217 | num_groups = np.power(num_classes, biases.size(1)+1) 218 | 219 | results = self.model(data) 220 | pred_labels = results["out"].argsort(dim=-1, descending=True) 221 | features.append(results["feature"]) 222 | labels.append(group) 223 | 224 | 225 | top1s = (pred_labels[:, :1] == target.unsqueeze(dim=-1)).squeeze().unsqueeze(0) 226 | group_indices = (group==torch.arange(num_groups).unsqueeze(1).long().cuda()) 227 | 228 | bias_counts += group_indices.sum(1) 229 | bias_top1s += (top1s * group_indices).sum(1) 230 | 231 | corrects.append(top1s) 232 | 233 | total_num += B 234 | total_top1 += torch.sum((pred_labels[:, :1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item() 235 | total_top5 += torch.sum((pred_labels[:, :5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item() 236 | acc1, acc5 = total_top1 / total_num * 100, total_top5 / total_num * 100 237 | 238 | bias_accs = bias_top1s / bias_counts * 100 239 | 240 | avg_acc = np.nanmean(bias_accs.cpu().numpy()) 241 | worst_acc = np.nanmin(bias_accs.cpu().numpy()) 242 | 243 | acc_desc = '/'.join(['{:.1f}%'.format(acc) for acc in bias_accs]) 244 | 245 | test_bar.set_description('Eval Epoch [{}/{}] [{}] Bias: {:.2f}%'.format(epoch, self.max_epoch, desc, avg_acc)) 246 | 247 | 248 | log = self.logger.info if desc in ['train', 'train_eval'] else self.logger.warning 249 | log('Eval Epoch [{}/{}] [{}] Unbiased: {:.2f}% [{}]'.format(epoch, self.max_epoch, desc, avg_acc, acc_desc)) 250 | self.logger.info('Total [{}]: Acc@1:{:.2f}% Acc@5:{:.2f}%'.format(desc, acc1, acc5)) 251 | print(" {} / {} / {}".format(self.args.desc, self.args.target_attr, self.args.bias_attrs)) 252 | 253 | 254 | self.model.train() 255 | return 256 | -------------------------------------------------------------------------------- /dataset/celebA.py: -------------------------------------------------------------------------------- 1 | import os, pdb 2 | import torch 3 | import pandas as pd 4 | from PIL import Image 5 | import numpy as np 6 | import torchvision.transforms as transforms 7 | # from models import model_attributes 8 | from torch.utils.data import Dataset, Subset 9 | from pathlib import Path 10 | import torchvision 11 | from torch.utils import data 12 | from torch.utils.data.sampler import WeightedRandomSampler 13 | 14 | 15 | 16 | class CelebA(torchvision.datasets.CelebA): 17 | 18 | # Attributes : '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 
'Blond_Hair', 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin', 'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard', 'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young' 19 | 20 | def __init__(self, root, split="train", target_type="attr", transform=None, 21 | target_transform=None, download=False, 22 | target_attr='', aux_attrs=[], bias_attrs=[], domain_attr=None, domain_type=None, 23 | pair=False, args=None, scale=1.0): 24 | 25 | super().__init__(root, split=split, target_type=target_type, transform=transform, 26 | target_transform=target_transform, download=download) 27 | 28 | self.target_attr = target_attr 29 | self.aux_attrs = aux_attrs 30 | self.bias_attrs = bias_attrs 31 | self.domain_attr = domain_attr 32 | 33 | self.target_idx = self.attr_names.index(target_attr) 34 | self.aux_indices = [self.attr_names.index(aux_att) for aux_att in aux_attrs] if aux_attrs else [] 35 | self.domain_idx = self.attr_names.index(domain_attr) if domain_attr else None 36 | 37 | self.domain_type = domain_type 38 | 39 | self.bias_indices = [self.attr_names.index(bias_att) for bias_att in bias_attrs] 40 | 41 | self.cluster_ids = None 42 | self.clustering = False 43 | self.sample_weights = None 44 | 45 | self.pair = pair 46 | 47 | self.visualize_image = False 48 | 49 | self.args = args 50 | self.visualize = False 51 | self.scale = scale 52 | 53 | 54 | 55 | @property 56 | def class_elements(self): 57 | return self.attr[:, self.target_idx] 58 | 59 | @property 60 | def group_elements(self): 61 | group_attrs = self.attr[:, [self.target_idx]+self.bias_indices] 62 | weight = np.power(self.num_classes, np.arange(group_attrs.size(1))) 63 | group_elems = (group_attrs*weight).sum(1) 64 | return group_elems 65 | 66 | @property 67 | def group_counts(self): 68 | group_attrs = self.attr[:, [self.target_idx]+self.bias_indices] 69 | weight = np.power(self.num_classes, np.arange(group_attrs.size(1))) 70 | group_elems = (group_attrs*weight).sum(1) 71 | return group_elems.bincount() 72 | 73 | 74 | def group_counts_with_attr(self, attr): 75 | target_idx = self.attr_names.index(attr) 76 | group_attrs = self.attr[:, [target_idx]+self.bias_indices] 77 | weight = np.power(self.num_classes, np.arange(group_attrs.size(1))) 78 | group_elems = (group_attrs*weight).sum(1) 79 | return group_elems.bincount() 80 | 81 | def visualize(self): 82 | self.visualize_image = True 83 | 84 | def clustering_on(self): 85 | self.clustering = True 86 | 87 | def clustering_off(self): 88 | self.clustering = False 89 | 90 | def update_clusters(self, cluster_ids): 91 | self.cluster_ids = cluster_ids 92 | 93 | def update_weights(self, weights): 94 | self.sample_weights = weights 95 | 96 | 97 | def __len__(self): 98 | len = super().__len__() 99 | if self.scale < 1.0: 100 | len = int(len*self.scale) 101 | return len 102 | 103 | 104 | def get_sample_index(self, index): 105 | return index 106 | 107 | def __getitem__(self, index_): 108 | 109 | index = self.get_sample_index(index_) 110 | 111 | img_path = os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]) 112 | img_ = Image.open(img_path) 113 | 114 | target = [] 115 | for t in self.target_type: 116 | if t == "attr": 117 | target.append(self.attr[index, :]) 118 | elif t == "identity": 119 | 
target.append(self.identity[index, 0]) 120 | elif t == "bbox": 121 | target.append(self.bbox[index, :]) 122 | elif t == "landmarks": 123 | target.append(self.landmarks_align[index, :]) 124 | else: 125 | # TODO: refactor with utils.verify_str_arg 126 | raise ValueError("Target type \"{}\" is not recognized.".format(t)) 127 | 128 | if target: 129 | target = tuple(target) if len(target) > 1 else target[0] 130 | 131 | if self.target_transform is not None: 132 | target = self.target_transform(target) 133 | else: 134 | target = None 135 | 136 | target_attr = target[self.target_idx] 137 | bias_attrs = np.array([target[bias_idx] for bias_idx in self.bias_indices]) 138 | group_attrs = np.insert(bias_attrs, 0, target_attr) # target first 139 | 140 | bit = np.power(self.num_classes, np.arange(len(group_attrs))) 141 | group = np.sum(bit * group_attrs) 142 | 143 | 144 | if self.cluster_ids is not None: 145 | cluster = self.cluster_ids[index] 146 | else: 147 | cluster = -1 148 | 149 | if self.sample_weights is not None: 150 | weight = self.sample_weights[index] 151 | else: 152 | weight = 1 153 | 154 | 155 | if self.transform is not None: 156 | transform = self.transform 157 | img = transform(img_) 158 | 159 | # for clustering 160 | if self.clustering is True: 161 | return img, target_attr, index 162 | 163 | 164 | return img, target_attr, bias_attrs, group, cluster, weight, index 165 | 166 | 167 | @property 168 | def classes(self): 169 | return ['0', '1'] 170 | 171 | @property 172 | def num_classes(self): 173 | return len(self.classes) 174 | 175 | @property 176 | def num_groups(self): 177 | return np.power(len(self.classes), len(self.bias_attrs)+1) 178 | 179 | @property 180 | def bias_attributes(self): 181 | return 182 | 183 | @property 184 | def attribute_names(self): 185 | return ['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin', 'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard', 'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young'] 186 | 187 | 188 | 189 | 190 | 191 | def get_celebA_dataloader(root, batch_size, split, target_attr, bias_attrs, aux_attrs=None, num_workers=4, pair=False, cluster_ids=None, args=None): 192 | 193 | from factory import TransformFactory 194 | 195 | ### Transform and scale 196 | if split in ['train', 'train_target']: 197 | celebA_transform = TransformFactory.create("celebA_train") 198 | 199 | elif split in ['valid', 'test', 'train_eval']: 200 | celebA_transform = TransformFactory.create("celebA_test") 201 | 202 | ### Dataset split 203 | celebDataset = CelebA 204 | if split in ['train', 'train_eval']: 205 | dataset_split = 'train' 206 | elif split in ['valid']: 207 | dataset_split = 'valid' 208 | elif split in ['test']: 209 | dataset_split = 'test' 210 | 211 | 212 | dataset = celebDataset(root, split=dataset_split, transform=celebA_transform, download=True, 213 | target_attr=target_attr, bias_attrs=bias_attrs, aux_attrs=aux_attrs, args=args) 214 | 215 | 216 | dataloader = data.DataLoader(dataset=dataset, 217 | batch_size=batch_size, 218 | shuffle=True, 219 | num_workers=num_workers, 220 | pin_memory=True) 221 | 222 | return dataloader, dataset 223 | 224 | 
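# --- Hedged usage sketch (not part of the original file) ---------------------
# Minimal example of how this dataset is consumed. The group index returned by
# __getitem__ encodes (target_attr, bias_attrs) as digits of a base-num_classes
# number; with Blond_Hair as target and Male as bias, group = target + 2 * male.
# The dataset is built directly here (bypassing factory.TransformFactory) with a
# plain resize/normalize pipeline; the root path is an assumption and must
# already contain the CelebA files.
if __name__ == "__main__":
    from pathlib import Path
    from torchvision import transforms as T

    transform = T.Compose([
        T.Resize(224),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    dataset = CelebA(root=Path('./data/celebA'), split='valid', transform=transform,
                     target_attr='Blond_Hair', bias_attrs=['Male'])
    img, target, bias, group, cluster, weight, index = dataset[0]
    # cluster is -1 and weight is 1 until a trainer calls update_clusters()/update_weights()
    print(target.item(), bias, group)  # group == target + 2 * male for this attribute choice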
-------------------------------------------------------------------------------- /factory.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, Iterable, List, Optional 2 | from pathlib import Path 3 | from dataset.celebA import CelebA, get_celebA_dataloader 4 | 5 | import torchvision 6 | from modules.transform import * 7 | from models.classification import ResNet18 8 | 9 | from trainer import BiasedClassifyTrainer 10 | from bpa_trainer import BPATrainer 11 | from torch.utils.data import DataLoader 12 | import pdb 13 | 14 | 15 | 16 | class Factory(object): 17 | 18 | PRODUCTS: Dict[str, Callable] = {} 19 | 20 | def __init__(self): 21 | raise ValueError( 22 | f"""Cannot instantiate {self.__class__.__name__} object, use 23 | `create` classmethod to create a product from this factory. 24 | """ 25 | ) 26 | 27 | @classmethod 28 | def create(cls, name: str, *args, **kwargs) -> Any: 29 | r"""Create an object by its name, args and kwargs.""" 30 | if name not in cls.PRODUCTS: 31 | raise KeyError(f"{cls.__class__.__name__} cannot create {name}.") 32 | 33 | return cls.PRODUCTS[name](*args, **kwargs) 34 | 35 | 36 | 37 | class ModelFactory(Factory): 38 | 39 | 40 | MODELS: Dict[str, Callable] = { 41 | "ResNet18": ResNet18, 42 | } 43 | 44 | @classmethod 45 | def create(cls, name: str, *args, **kwargs) -> Any: 46 | 47 | return cls.MODELS[name](*args, **kwargs) 48 | 49 | 50 | 51 | 52 | class TransformFactory(Factory): 53 | PRODUCTS: Dict[str, Callable] = { 54 | 55 | "train": train_transform, 56 | "test": test_transform, 57 | 58 | "celebA_train": celebA_train_transform, 59 | "celebA_test": celebA_test_transform, 60 | 61 | } 62 | 63 | @classmethod 64 | def create(cls, name: str, *args, **kwargs) -> Any: 65 | r"""Create an object by its name, args and kwargs.""" 66 | if name not in cls.PRODUCTS: 67 | raise KeyError(f"{cls.__class__.__name__} cannot create {name}.") 68 | 69 | return cls.PRODUCTS[name] 70 | 71 | 72 | 73 | 74 | 75 | class DataLoaderFactory(Factory): 76 | 77 | @classmethod 78 | def create(cls, name: str, trainer: str, batch_size: int, num_workers: int, configs: Any, cluster_ids: Any = None) -> Any: 79 | 80 | if name == 'celebA': 81 | 82 | train_loader, train_set = get_celebA_dataloader( 83 | root=Path('./data/celebA'), batch_size=batch_size, split='train', 84 | target_attr=configs.target_attr, bias_attrs=configs.bias_attrs, 85 | cluster_ids=cluster_ids, args=configs) 86 | valid_loader, valid_set = get_celebA_dataloader( 87 | root=Path('./data/celebA'), batch_size=batch_size, split='valid', 88 | target_attr=configs.target_attr, bias_attrs=configs.bias_attrs, 89 | args=configs) 90 | test_loader, test_set = get_celebA_dataloader( 91 | root=Path('./data/celebA'), batch_size=batch_size, split='test', 92 | target_attr=configs.target_attr, bias_attrs=configs.bias_attrs, 93 | args=configs) 94 | train_eval_loader, train_eval_set = get_celebA_dataloader( 95 | root=Path('./data/celebA'), batch_size=batch_size, split='train_eval', 96 | target_attr=configs.target_attr, bias_attrs=configs.bias_attrs, 97 | args=configs) 98 | 99 | datasets = { 100 | 'train': train_set, 101 | 'valid': valid_set, 102 | 'test': test_set, 103 | 'train_eval': train_eval_set, 104 | } 105 | 106 | data_loaders = { 107 | 'train': train_loader, 108 | 'valid': valid_loader, 109 | 'test': test_loader, 110 | 'train_eval': train_eval_loader, 111 | } 112 | 113 | 114 | else: 115 | raise ValueError 116 | 117 | 118 | return data_loaders, datasets 119 | 120 | 121 | 122 | 
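# A hedged sketch of how the factories above are composed into a training run
# (this mirrors main.py; `args` stands for the parsed argparse namespace and
# `optimizer` for the Adam/SGD optimizer built there):
#
#   loaders, datasets = DataLoaderFactory.create(
#       'celebA', trainer=args.trainer, batch_size=args.batch_size,
#       num_workers=4, configs=args)
#   model = ModelFactory.create(name=args.arch, feature_dim=args.feature_dim,
#                               num_classes=2, feature_fix=args.feature_fix).cuda()
#   trainer = TrainerFactory.create(args.trainer, args, model, loaders,
#                                   optimizer, num_classes=2)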
class TrainerFactory(Factory): 123 | 124 | TRAINERS: Dict[str, Callable] = { 125 | "classify": BiasedClassifyTrainer, 126 | "bpa": BPATrainer, 127 | } 128 | 129 | @classmethod 130 | def create(cls, name: str, *args, **kwargs) -> Any: 131 | 132 | return cls.TRAINERS[name](*args, **kwargs) 133 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, copy 3 | 4 | import numpy as np 5 | import torch 6 | import torch.optim as optim 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm 11 | 12 | # import utils 13 | from modules.transform import * 14 | from model import Model 15 | import os, sys, logging, time, random, json, pdb 16 | from pathlib import Path 17 | 18 | 19 | from factory import ModelFactory, TrainerFactory, DataLoaderFactory, TransformFactory 20 | 21 | 22 | import wandb 23 | 24 | DATA_ROOT = Path('./data') 25 | CHECKPOINT_ROOT = Path('./checkpoint') 26 | 27 | 28 | def main(): 29 | parser = argparse.ArgumentParser(description='Train SimCLR') 30 | parser.add_argument('--feature_dim', default=512, type=int, help='Feature dim for latent vector') 31 | parser.add_argument('--lr', default=1e-3, type=float, help='Learning Rate') 32 | parser.add_argument('--batch_size', default=256, type=int, help='Number of images in each mini-batch') 33 | parser.add_argument('--max_epoch', default=50, type=int, help='Number of sweeps over the dataset to train') 34 | parser.add_argument('--test_epoch', default=25, type=int, help='Test epoch') 35 | parser.add_argument('--save_epoch', default=25, type=int, help='Save epoch') 36 | parser.add_argument('--train_eval_epoch', default=50, type=int, help='Save epoch') 37 | 38 | parser.add_argument('--dataset', default='celebA', type=str, help='Dataset') 39 | parser.add_argument('--arch', default='ResNet18', type=str, help='Model architecture') 40 | parser.add_argument('--trainer', default='classify', type=str, help='Training scheme') 41 | parser.add_argument('--cluster_weight_type', default='scale_loss', type=str, help='Training scheme') 42 | parser.add_argument('--centroid', default='avgfixed', type=str, help='') 43 | 44 | parser.add_argument('--target_attr', default='', type=str, help='Target attributes') 45 | parser.add_argument('--bias_attrs', nargs='+', help='Bias attributes') 46 | 47 | parser.add_argument('--num_partitions', default=1, type=int, help='Test epoch') 48 | parser.add_argument('--k', default=1, type=int, help='# of clusters') 49 | parser.add_argument('--ks', default=[], nargs='+', help='# of clusters list (multi)') 50 | parser.add_argument('--update_cluster_iter', default=10, type=int, help='0 for every epoch') 51 | parser.add_argument('--feature_bank_init', action='store_true') 52 | parser.add_argument('--num_multi_centroids', default=1, type=int, help='# of centroids') 53 | 54 | parser.add_argument('--desc', default='test', type=str, help='Checkpoint folder name') 55 | parser.add_argument('--load_epoch', default=-1, type=int, help='Load model epoch') 56 | parser.add_argument('--weight_decay', default=1e-2, type=float, help='Weight decay') 57 | parser.add_argument('--momentum', default=0.3, type=float, help='Positive class priorx') 58 | parser.add_argument('--adj', default=2.0, type=float, help='Label noise ratio') 59 | parser.add_argument('--adj_type', default='', type=str, help='multiply or default') 60 | 
parser.add_argument('--exp_step', default=0.05, type=float, help='Exponential step size for weight averaging in AvgFixedCentroids') 61 | parser.add_argument('--avg_weight_type', default='expavg', type=str, help='avg type for weight averaging in AvgFixedCentroids') 62 | parser.add_argument('--overlap_type', default='exclusive', type=str, help='Channel overlap type for hetero clustering, [exclusive, half_exclusive]') 63 | parser.add_argument('--gamma_reverse', action='store_true') 64 | parser.add_argument('--scale', default=1.0, type=float, help='Dataset scale') 65 | parser.add_argument('--sampling', default='', type=str, help='class_subsampling/class_resampling') 66 | 67 | parser.add_argument('--use_base', default='', type=str, help='use base model') 68 | parser.add_argument('--load_base', default='', type=str, help='load checkpoint file as base model') 69 | parser.add_argument('--load_path', default='', type=str, help='load checkpoint file') 70 | 71 | parser.add_argument('--scheduler', default='cosine', type=str, help='cosine') 72 | parser.add_argument('--scheduler_param', default=100, type=int) 73 | 74 | parser.add_argument('--resume', default='', type=str, help='Run ID') 75 | parser.add_argument('--optim', default='adam', type=str, help='adam or sgd') 76 | parser.add_argument('--no_save', action='store_true') 77 | parser.add_argument('--verbose', action='store_true') 78 | parser.add_argument('--feature_fix', action='store_true') 79 | 80 | parser.add_argument('--eval', action='store_true') 81 | 82 | args = parser.parse_args() 83 | args.num_clusters = args.k 84 | args.num_multi_centroids = len(args.ks) 85 | 86 | # savings 87 | checkpoint_dir = CHECKPOINT_ROOT / args.dataset / args.target_attr / args.desc 88 | if not checkpoint_dir.exists(): 89 | checkpoint_dir.mkdir(parents=True, exist_ok=True) 90 | 91 | loaders, datasets = DataLoaderFactory.create(args.dataset, trainer=args.trainer, batch_size=args.batch_size, 92 | num_workers=4, configs=args) 93 | 94 | num_classes = len(datasets['test'].classes) 95 | print('# Classes: {}'.format(num_classes)) 96 | 97 | model_args = { 98 | "name": args.arch, 99 | "feature_dim": args.feature_dim, 100 | "num_classes": num_classes, 101 | "feature_fix": args.feature_fix, 102 | } 103 | 104 | # model setup and optimizer config 105 | model = ModelFactory.create(**model_args).cuda() 106 | model = nn.DataParallel(model) 107 | if args.optim == 'sgd': 108 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay) 109 | else: 110 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 111 | 112 | 113 | scheduler = None 114 | if args.scheduler == 'cosine': 115 | assert args.scheduler_param != 0 116 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.scheduler_param) 117 | 118 | 119 | args.checkpoint_dir = checkpoint_dir 120 | trainer = TrainerFactory.create(args.trainer, args, model, loaders, optimizer, num_classes) 121 | trainer.set_checkpoint_dir(checkpoint_dir) 122 | start_epoch = 1 123 | 124 | if args.load_base: 125 | trainer.load_base_model(args.load_base) 126 | 127 | if args.use_base: 128 | trainer.use_base_model(args.use_base) 129 | 130 | if args.load_epoch >= 0: 131 | trainer.load_model(args.load_epoch) 132 | start_epoch = trainer.epoch + 1 133 | 134 | if args.load_path: 135 | if args.origin_attr: 136 | path = '{}/{}'.format(args.origin_attr, args.load_path) 137 | else: 138 | trainer.load_path(args.load_path) 139 | 140 | if args.eval: 141 | 
trainer.test_unbiased(epoch=start_epoch-1) 142 | return 143 | 144 | 145 | for epoch in range(start_epoch, args.max_epoch+1): 146 | trainer.train(epoch=epoch) 147 | 148 | if epoch % args.test_epoch == 0: 149 | trainer.save_model(epoch=epoch) 150 | trainer.test_unbiased(epoch=epoch) 151 | 152 | if scheduler is not None: 153 | scheduler.step() 154 | 155 | 156 | if __name__ == "__main__": 157 | main() 158 | 159 | 160 | -------------------------------------------------------------------------------- /models/classification.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision.models 5 | import pdb 6 | from utils.train_utils import * 7 | 8 | 9 | 10 | class ResNet(nn.Module): 11 | def __init__(self, feature_dim, num_classes, arch='', feature_fix=False): 12 | super(ResNet, self).__init__() 13 | 14 | self.feature_dim = feature_dim 15 | self.num_classes = num_classes 16 | self.arch = arch 17 | resnet = self.get_backbone() 18 | self.conv1 = resnet.conv1 19 | self.bn1 = resnet.bn1 20 | self.relu = resnet.relu # 1/2, 64 21 | self.maxpool = resnet.maxpool 22 | 23 | self.res2 = resnet.layer1 # 1/4, 64 24 | self.res3 = resnet.layer2 # 1/8, 128 25 | self.res4 = resnet.layer3 # 1/16, 256 26 | self.res5 = resnet.layer4 # 1/32, 512 27 | 28 | self.f = nn.Sequential(self.conv1, self.bn1, self.relu, self.maxpool, 29 | self.res2, self.res3, self.res4, self.res5) 30 | 31 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 32 | # classifier 33 | self.fc = self.get_fc(num_classes) 34 | 35 | 36 | self.feature_fix = feature_fix 37 | if feature_fix: 38 | print("Fix parameters except fc layer") 39 | for param in self.parameters(): 40 | param.requires_grad = False 41 | 42 | self.fc.weight.requires_grad = True 43 | self.fc.bias.requires_grad = True 44 | 45 | 46 | def get_backbone(self): 47 | raise NotImplementedError 48 | 49 | def get_fc(self, num_classes): 50 | raise NotImplementedError 51 | 52 | def forward(self, x): 53 | 54 | feature = self.f(x) 55 | feature = torch.flatten(self.avgpool(feature), start_dim=1) 56 | 57 | logits = self.fc(feature) 58 | 59 | results = { 60 | "out": logits, 61 | "feature": feature, 62 | } 63 | return results 64 | 65 | 66 | class ResNet18(ResNet): 67 | 68 | def get_backbone(self): 69 | return torchvision.models.resnet18(pretrained=True) 70 | 71 | def get_fc(self, num_classes): 72 | return nn.Linear(512, num_classes, bias=True) 73 | 74 | 75 | 76 | -------------------------------------------------------------------------------- /modules/centroids.py: -------------------------------------------------------------------------------- 1 | import torch, pdb 2 | from torch import nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from kmeans_pytorch import kmeans 6 | 7 | from utils.train_utils import grad_mul_const 8 | 9 | 10 | class Centroids(nn.Module): 11 | 12 | def __init__(self, args, num_classes, per_clusters, feature_dim=None): 13 | super(Centroids, self).__init__() 14 | self.momentum = args.momentum 15 | self.per_clusters = per_clusters 16 | self.num_classes = num_classes 17 | self.feature_dim = args.feature_dim if feature_dim is None else feature_dim 18 | 19 | # Cluster 20 | self.cluster_means = None 21 | self.cluster_vars = torch.zeros((self.num_classes, self.per_clusters)) 22 | self.cluster_losses = torch.zeros((self.num_classes, self.per_clusters)) 23 | self.cluster_accs = torch.zeros((self.num_classes, self.per_clusters)) 24 | 
self.cluster_weights = torch.zeros((self.num_classes, self.per_clusters)) 25 | 26 | # Sample 27 | self.feature_bank = None 28 | self.assigns = None 29 | self.corrects = None 30 | self.losses = None 31 | self.weights = None 32 | 33 | self.initialized = False 34 | self.weight_type = args.cluster_weight_type 35 | 36 | self.max_cluster_weights = 0. # 0 means no-limit 37 | 38 | def __repr__(self): 39 | return "{}(Y{}/K{}/dim{})".format(self.__class__.__name__, self.num_classes, self.per_clusters, self.feature_dim) 40 | 41 | @property 42 | def num_clusters(self): 43 | return self.num_classes * self.per_clusters 44 | 45 | @property 46 | def cluster_counts(self): 47 | if self.assigns is None: 48 | return 0 49 | return self.assigns.bincount(minlength=self.num_clusters) 50 | 51 | 52 | def _clamp_weights(self, weights): 53 | if self.max_cluster_weights > 0: 54 | if weights.max() > self.max_cluster_weights: 55 | scale = np.log(self.max_cluster_weights)/torch.log(weights.cpu().max()) 56 | scale = scale.cuda() 57 | print("> Weight : {:.4f}, scale : {:.4f}".format(weights.max(), scale)) 58 | return weights ** scale 59 | return weights 60 | 61 | 62 | def get_cluster_weights(self, ids): 63 | if self.assigns is None: 64 | return 1 65 | 66 | cluster_counts = self.cluster_counts + (self.cluster_counts==0).float() # avoid nans 67 | 68 | cluster_weights = cluster_counts.sum()/(cluster_counts.float()) 69 | assigns_id = self.assigns[ids] 70 | 71 | if (self.losses == -1).nonzero().size(0) == 0: 72 | cluster_losses_ = self.cluster_losses.view(-1) 73 | losses_weight = cluster_losses_.float()/cluster_losses_.sum() 74 | weights_ = cluster_weights[assigns_id] * losses_weight[assigns_id].cuda() 75 | weights_ /= weights_.mean() 76 | else: 77 | weights_ = cluster_weights[assigns_id] 78 | weights_ /= weights_.mean() 79 | 80 | return self._clamp_weights(weights_) 81 | 82 | 83 | def initialize_(self, cluster_assigns, cluster_centers, sorted_features=None): 84 | cluster_means = cluster_centers.detach().cuda() 85 | cluster_means = F.normalize(cluster_means, 1) 86 | self.cluster_means = cluster_means.view(self.num_classes, self.per_clusters, -1) 87 | 88 | N = cluster_assigns.size(0) 89 | self.feature_bank = torch.zeros((N, self.feature_dim)).cuda() if sorted_features is None else sorted_features 90 | self.assigns = cluster_assigns 91 | self.corrects = torch.zeros((N)).long().cuda() - 1 92 | self.losses = torch.zeros((N)).cuda() - 1 93 | self.weights = torch.ones((N)).cuda() 94 | self.initialized = True 95 | 96 | 97 | def get_variances(self, x, y): 98 | return 1 - (y @ x).mean(0) 99 | 100 | def compute_centroids(self, verbose=False, split=False): 101 | for y in range(self.num_classes): 102 | for k in range(self.per_clusters): 103 | l = y*self.per_clusters + k 104 | ids = (self.assigns==l).nonzero() 105 | if ids.size(0) == 0: 106 | continue 107 | self.cluster_means[y, k] = self.feature_bank[ids].mean(0) 108 | self.cluster_vars[y, k] = self.get_variances(self.cluster_means[y, k], self.feature_bank[ids]) 109 | 110 | corrs = self.corrects[ids] 111 | corrs_nz = (corrs[:, 0]>=0).nonzero() 112 | if corrs_nz.size(0) > 0: 113 | self.cluster_accs[y, k] = corrs[corrs_nz].float().mean(0) 114 | 115 | losses = self.losses[ids] 116 | loss_nz = (losses[:, 0]>=0).nonzero() 117 | if loss_nz.size(0) > 0: 118 | self.cluster_losses[y, k] = losses[loss_nz].float().mean(0) 119 | 120 | return 121 | 122 | 123 | def update(self, results, target, ids, features=None): 124 | assert self.initialized 125 | 126 | ### update feature and assigns 127 | feature = 
results["feature"] if features is None else features 128 | feature_ = F.normalize(feature, 1).detach() 129 | 130 | feature_new = (1-self.momentum) * self.feature_bank[ids] + self.momentum * feature_ 131 | feature_new = F.normalize(feature_new, 1) 132 | 133 | self.feature_bank[ids] = feature_new 134 | 135 | sim_score = self.cluster_means @ feature_new.permute(1, 0) # YKC/CB => YKB 136 | 137 | for y in range(self.num_classes): 138 | sim_score[y, :, (target!=y).nonzero()] -= 1e4 139 | 140 | sim_score_ = sim_score.view(self.num_clusters, -1) 141 | new_assigns = sim_score_.argmax(0) 142 | self.assigns[ids] = new_assigns 143 | 144 | corrects = (results["out"].argmax(1) == target).long() 145 | self.corrects[ids] = corrects 146 | 147 | criterion = torch.nn.CrossEntropyLoss(reduction='none') 148 | losses = criterion(results["out"], target.long()).detach() 149 | self.losses[ids] = losses 150 | 151 | return 152 | 153 | 154 | 155 | class FixedCentroids(Centroids): 156 | 157 | def compute_centroids(self, verbose='', split=False): 158 | 159 | for y in range(self.num_classes): 160 | for k in range(self.per_clusters): 161 | l = y*self.per_clusters + k 162 | 163 | ids = (self.assigns==l).nonzero() 164 | if ids.size(0) == 0: 165 | continue 166 | 167 | corrs = self.corrects[ids] 168 | corrs_nz = (corrs[:, 0]>=0).nonzero() 169 | if corrs_nz.size(0) > 0: 170 | self.cluster_accs[y, k] = corrs[corrs_nz].float().mean(0) 171 | 172 | losses = self.losses[ids] 173 | loss_nz = (losses[:, 0]>=0).nonzero() 174 | if loss_nz.size(0) > 0: 175 | self.cluster_losses[y, k] = losses[loss_nz].float().mean(0) 176 | 177 | self.cluster_weights[y, k] = self.weights[ids].float().mean(0) 178 | 179 | return 180 | 181 | def get_cluster_weights(self, ids): 182 | weights_ids = super().get_cluster_weights(ids) 183 | self.weights[ids] = weights_ids 184 | return weights_ids 185 | 186 | 187 | def update(self, results, target, ids, features=None, preds=None): 188 | assert self.initialized 189 | 190 | out = preds if preds is not None else results["out"] 191 | 192 | corrects = (out.argmax(1) == target).long() 193 | self.corrects[ids] = corrects 194 | 195 | criterion = torch.nn.CrossEntropyLoss(reduction='none') 196 | losses = criterion(out, target.long()).detach() 197 | self.losses[ids] = losses 198 | 199 | return 200 | 201 | 202 | 203 | class AvgFixedCentroids(FixedCentroids): 204 | 205 | def __init__(self, args, num_classes, per_clusters, feature_dim=None): 206 | super(AvgFixedCentroids, self).__init__(args, num_classes, per_clusters, feature_dim) 207 | self.exp_step = args.exp_step 208 | self.avg_weight_type = args.avg_weight_type 209 | 210 | def compute_centroids(self, verbose='', split=False): 211 | 212 | for y in range(self.num_classes): 213 | for k in range(self.per_clusters): 214 | l = y*self.per_clusters + k 215 | 216 | ids = (self.assigns==l).nonzero() 217 | if ids.size(0) == 0: 218 | continue 219 | 220 | corrs = self.corrects[ids] 221 | corrs_nz = (corrs[:, 0]>=0).nonzero() 222 | if corrs_nz.size(0) > 0: 223 | self.cluster_accs[y, k] = corrs[corrs_nz].float().mean(0) 224 | 225 | losses = self.losses[ids] 226 | loss_nz = (losses[:, 0]>=0).nonzero() 227 | if loss_nz.size(0) > 0: 228 | self.cluster_losses[y, k] = losses[loss_nz].float().mean(0) 229 | 230 | self.cluster_weights[y, k] = self.weights[ids].float().mean(0) 231 | 232 | return 233 | 234 | 235 | def get_cluster_weights(self, ids): 236 | 237 | weights_ids = super().get_cluster_weights(ids) 238 | 239 | if self.avg_weight_type == 'expavg': 240 | weights_ids_ = self.weights[ids] * 
torch.exp(self.exp_step*weights_ids.data) 241 | elif self.avg_weight_type == 'avg': 242 | weights_ids_ = (1-self.momentum) * self.weights[ids] + self.momentum * weights_ids 243 | elif self.avg_weight_type == 'expgrad': 244 | weights_ids_l1 = weights_ids / weights_ids.sum() 245 | prev_ids_l1 = self.weights[ids] / self.weights[ids].sum() 246 | weights_ids_ = prev_ids_l1 * torch.exp(self.exp_step*weights_ids_l1.data) 247 | else: 248 | raise ValueError 249 | 250 | self.weights[ids] = weights_ids_ / weights_ids_.mean() 251 | return self.weights[ids] 252 | 253 | 254 | 255 | class HeteroCentroids(nn.Module): 256 | 257 | def __init__(self, args, num_classes, num_hetero_clusters, centroids_type): 258 | super(HeteroCentroids, self).__init__() 259 | self.momentum = args.momentum 260 | self.num_classes = num_classes 261 | self.feature_dim = args.feature_dim 262 | self.initialized = False 263 | 264 | self.num_hetero_clusters = num_hetero_clusters 265 | self.num_multi_centroids = len(num_hetero_clusters) 266 | self.centroids_list = [centroids_type(args, num_classes, per_clusters=num_hetero_clusters[m], feature_dim=self.feature_dim) for m in range(self.num_multi_centroids)] 267 | 268 | def __repr__(self): 269 | return self.__class__.__name__ + "(" + ", ".join([centroids.__repr__() for centroids in self.centroids_list])+ ")" 270 | 271 | 272 | def initialize_multi(self, multi_cluster_assigns, multi_cluster_centers): 273 | for cluster_assigns, cluster_centers, centroids in zip(multi_cluster_assigns, multi_cluster_centers, self.centroids_list): 274 | centroids.initialize_(cluster_assigns, cluster_centers) 275 | self.initialized = True 276 | 277 | def compute_centroids(self, verbose=False, split=False): 278 | for m, centroids in enumerate(self.centroids_list): 279 | verbose_ = str(m) if verbose else '' 280 | centroids.compute_centroids(verbose=verbose_) 281 | 282 | 283 | def update(self, results, target, ids): 284 | for m, centroids in enumerate(self.centroids_list): 285 | features = results["feature"] 286 | centroids.update(results, target, ids) 287 | 288 | 289 | def get_cluster_weights(self, ids): 290 | 291 | weights_list = [centroids.get_cluster_weights(ids) for centroids in self.centroids_list] 292 | weights_ = torch.stack(weights_list).mean(0) 293 | return weights_ 294 | 295 | 296 | 297 | 298 | 299 | 300 | -------------------------------------------------------------------------------- /modules/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | 7 | 8 | class GeneralizedCELoss(nn.Module): 9 | 10 | def __init__(self, q=0.7): 11 | super(GeneralizedCELoss, self).__init__() 12 | self.q = q 13 | 14 | def forward(self, logits, targets): 15 | p = F.softmax(logits, dim=1) 16 | if np.isnan(p.mean().item()): 17 | raise NameError('GCE_p') 18 | Yg = torch.gather(p, 1, torch.unsqueeze(targets, 1)) 19 | # modify gradient of cross entropy 20 | loss_weight = (Yg.squeeze().detach()**self.q)*self.q 21 | if np.isnan(Yg.mean().item()): 22 | raise NameError('GCE_Yg') 23 | 24 | loss = F.cross_entropy(logits, targets, reduction='none') * loss_weight 25 | 26 | return loss 27 | -------------------------------------------------------------------------------- /modules/transform.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from torchvision import transforms 3 | import cv2 4 | import numpy as np 5 | 6 | 7 | 
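# train_transform / test_transform below use 32x32 crops with CIFAR-10
# normalization statistics, while the celebA_* transforms resize to 224 and use
# ImageNet statistics to match the ImageNet-pretrained ResNet18 backbone.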
train_transform = transforms.Compose([ 8 | transforms.RandomResizedCrop(32), 9 | transforms.RandomHorizontalFlip(p=0.5), 10 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8), 11 | transforms.RandomGrayscale(p=0.2), 12 | transforms.ToTensor(), 13 | transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])]) 14 | 15 | test_transform = transforms.Compose([ 16 | transforms.ToTensor(), 17 | transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])]) 18 | 19 | 20 | 21 | 22 | celebA_test_transform = transforms.Compose([ 23 | transforms.Resize(224), 24 | transforms.ToTensor(), 25 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 26 | ]) 27 | 28 | celebA_train_transform = transforms.Compose([ 29 | transforms.RandomResizedCrop(224), 30 | transforms.RandomHorizontalFlip(), 31 | transforms.ToTensor(), 32 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 33 | ]) 34 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.4.0 2 | torchvision>=0.2.1 3 | matplotlib 4 | pathlib 5 | pandas 6 | kmeans_pytorch 7 | configargparse 8 | tqdm 9 | seaborn 10 | opencv-python 11 | coloredlogs 12 | wandb 13 | numpy -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, pdb 3 | import logging 4 | from PIL import Image 5 | 6 | import numpy as np 7 | import torch 8 | import torch.optim as optim 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.utils.data import DataLoader 12 | from tqdm import tqdm 13 | from torch.distributions import Categorical 14 | 15 | from modules.transform import * 16 | from model import Model 17 | from utils.train_utils import * 18 | from utils.io_utils import * 19 | import importlib 20 | 21 | 22 | import pandas as pd 23 | from sklearn.manifold import TSNE 24 | import seaborn as sns 25 | from matplotlib import pyplot as plt 26 | from kmeans_pytorch import kmeans 27 | 28 | import wandb 29 | import pickle 30 | 31 | 32 | class Trainer(): 33 | 34 | def __init__(self, args, model, loaders, optimizer, num_classes): 35 | print(self) 36 | self.args = args 37 | self.model = model 38 | self.loaders = loaders 39 | self.optimizer = optimizer 40 | 41 | self.max_epoch = args.max_epoch 42 | self.batch_size = args.batch_size 43 | 44 | self.k = args.k 45 | self.num_classes = num_classes 46 | self.num_groups = np.power(self.num_classes, len(self.args.bias_attrs)+1) 47 | self.num_clusters = self.k 48 | 49 | self.logger = get_logger('') 50 | 51 | self.accs = WindowAvgMeter(name='accs', max_count=20) 52 | 53 | 54 | def set_checkpoint_dir(self, checkpoint_dir): 55 | self.checkpoint_dir = checkpoint_dir 56 | 57 | 58 | def save_model(self, epoch): 59 | if not self.args.no_save or epoch % self.args.save_epoch == 0: 60 | torch.save({ 61 | 'epoch': epoch, 62 | 'state_dict': self.model.state_dict(), 63 | 'optimizer' : self.optimizer.state_dict(), 64 | }, self.checkpoint_dir / 'e{:04d}.pth'.format(epoch)) 65 | return 66 | 67 | def load_model(self, epoch=0): 68 | self.logger.info('Resume training') 69 | if epoch==0: 70 | checkpoint_path = max((f.stat().st_mtime, f) for f in self.checkpoint_dir.glob('*.pth'))[1] 71 | self.logger.info('Resume Latest from {}'.format(checkpoint_path)) 72 | else: 73 | checkpoint_path = 
self.checkpoint_dir / 'e{:04d}.pth'.format(epoch) 74 | self.logger.info('Resume from {}'.format(checkpoint_path)) 75 | 76 | checkpoint = torch.load(checkpoint_path) 77 | self.model.load_state_dict((checkpoint['state_dict'])) # Set CUDA before if error occurs. 78 | self.optimizer.load_state_dict(checkpoint['optimizer']) 79 | self.epoch = checkpoint['epoch'] 80 | 81 | 82 | def load_path(self, file_name): 83 | self.logger.info('Loading model at ({})'.format(file_name)) 84 | checkpoint_path = self.checkpoint_dir / '..' / '..' / '{}.pth'.format(file_name) 85 | checkpoint = torch.load(checkpoint_path) 86 | self.model.load_state_dict(checkpoint['state_dict']) # Set CUDA before if error occurs. 87 | 88 | 89 | def finetune(self, epoch, iter): 90 | return 91 | 92 | 93 | def extract_sample(self, data, desc): 94 | data_ = data.permute(1,2,0).cpu().numpy().astype(np.uint8) 95 | img = Image.fromarray(data_) 96 | img.save(self.checkpoint_dir / '{}.png'.format(desc)) 97 | 98 | 99 | 100 | def _extract_features_with_path(self, model, data_loader): 101 | features, targets = [], [] 102 | ids = [] 103 | paths = [] 104 | 105 | for data, target, index, path in tqdm(data_loader, desc='Feature extraction for clustering..', ncols=5): 106 | data, target, index = data.cuda(), target.cuda(), index.cuda() 107 | results = model(data) 108 | features.append(results["feature"]) 109 | targets.append(target) 110 | ids.append(index) 111 | paths.append(path) 112 | 113 | features = torch.cat(features) 114 | targets = torch.cat(targets) 115 | ids = torch.cat(ids) 116 | paths = np.concatenate(paths) 117 | return features, targets, ids, paths 118 | 119 | 120 | 121 | 122 | class ClassifyTrainer(Trainer): 123 | 124 | def train(self, epoch): 125 | 126 | data_loader = self.loaders['train'] 127 | total_loss, total_num, train_bar = 0.0, 0, tqdm(data_loader, ncols=100) 128 | criterion = torch.nn.CrossEntropyLoss(reduction='none') 129 | 130 | 131 | for data, target, _, _, _, _, _ in train_bar: 132 | B = target.size(0) 133 | 134 | data, target = data.cuda(), target.cuda() 135 | 136 | results = self.model(data) 137 | loss = torch.mean(criterion(results["out"], target.long())) 138 | 139 | self.optimizer.zero_grad() 140 | loss.backward() 141 | self.optimizer.step() 142 | 143 | total_num += B 144 | total_loss += loss.item() * B 145 | 146 | train_bar.set_description('Train Epoch: [{}/{}] Loss: {:.4f}'.format(epoch, self.max_epoch, total_loss / total_num)) 147 | 148 | 149 | return total_loss / total_num 150 | 151 | 152 | def test(self, epoch): 153 | self.model.eval() 154 | test_loader = self.loaders['test'] 155 | 156 | total_top1, total_top5, total_num, test_bar = 0.0, 0.0, 0, tqdm(test_loader, ncols=100) 157 | with torch.no_grad(): 158 | 159 | for data, target, _, _, _, _, _ in test_bar: 160 | data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True) 161 | B = target.size(0) 162 | 163 | results = self.model(data) 164 | pred_labels = results["out"].argsort(dim=-1, descending=True) 165 | 166 | total_num += B 167 | total_top1 += torch.sum((pred_labels[:, :1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item() 168 | total_top5 += torch.sum((pred_labels[:, :5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item() 169 | test_bar.set_description('Test Epoch: [{}/{}] Acc@1:{:.2f}% Acc@5:{:.2f}%' 170 | .format(epoch, self.max_epoch, total_top1 / total_num * 100, total_top5 / total_num * 100)) 171 | 172 | 173 | self.model.train() 174 | 175 | return 176 | 177 | 178 | class BiasedClassifyTrainer(ClassifyTrainer): 179 | 180 | 181 | 
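    # BiasedClassifyTrainer evaluates with explicit bias-group bookkeeping:
    # test() reports standard top-1/top-5 accuracy, while test_unbiased() bins
    # samples into (target, bias) groups and reports the group-averaged
    # ("unbiased") accuracy together with the worst-group accuracy.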
def test(self, epoch): 182 | self.model.eval() 183 | 184 | for desc in ['test']: 185 | loader = self.loaders[desc] 186 | 187 | total_top1, total_top5, total_num, test_bar = 0.0, 0.0, 0, tqdm(loader, ncols=100) 188 | 189 | with torch.no_grad(): 190 | 191 | for data, target, bias, _, _, _, _ in test_bar: 192 | data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True) 193 | 194 | B = target.size(0) 195 | 196 | results = self.model(data) 197 | pred_labels = results["out"].argsort(dim=-1, descending=True) 198 | 199 | total_num += B 200 | total_top1 += torch.sum((pred_labels[:, :1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item() 201 | total_top5 += torch.sum((pred_labels[:, :5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item() 202 | test_bar.set_description('[{}] Test Epoch: [{}/{}] Acc@1:{:.2f}% Acc@5:{:.2f}%' 203 | .format(desc, epoch, self.max_epoch, total_top1 / total_num * 100, total_top5 / total_num * 100)) 204 | 205 | log = self.logger.info if desc in ['train', 'train_eval'] else self.logger.warning 206 | log('Eval Epoch [{}/{}] ({}) Acc@1:{:.2f}% Acc@5:{:.2f}%'.format(epoch, self.max_epoch, desc, total_top1 / total_num * 100, total_top5 / total_num * 100)) 207 | 208 | self.model.train() 209 | return 210 | 211 | 212 | 213 | 214 | def test_unbiased(self, epoch, train_eval=True): 215 | self.model.eval() 216 | 217 | test_envs = ['valid', 'test'] 218 | for desc in test_envs: 219 | loader = self.loaders[desc] 220 | 221 | total_top1, total_top5, total_num, test_bar = 0.0, 0.0, 0, tqdm(loader, ncols=100) 222 | 223 | num_classes = len(loader.dataset.classes) 224 | num_groups = loader.dataset.num_groups 225 | 226 | bias_counts = torch.zeros(num_groups).cuda() 227 | bias_top1s = torch.zeros(num_groups).cuda() 228 | 229 | 230 | with torch.no_grad(): 231 | 232 | features, labels = [], [] 233 | logits = [] 234 | corrects = [] 235 | 236 | for data, target, biases, group, _, _, ids in test_bar: 237 | data, target, biases, group = data.cuda(), target.cuda(), biases.cuda(), group.cuda() 238 | 239 | B = target.size(0) 240 | 241 | results = self.model(data) 242 | pred_labels = results["out"].argsort(dim=-1, descending=True) 243 | features.append(results["feature"]) 244 | logits.append(results["out"]) 245 | labels.append(group) 246 | 247 | 248 | top1s = (pred_labels[:, :1] == target.unsqueeze(dim=-1)).squeeze().unsqueeze(0) 249 | group_indices = (group==torch.arange(num_groups).unsqueeze(1).long().cuda()) 250 | 251 | bias_counts += group_indices.sum(1) 252 | bias_top1s += (top1s * group_indices).sum(1) 253 | 254 | corrects.append(top1s) 255 | 256 | total_num += B 257 | total_top1 += torch.sum((pred_labels[:, :1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item() 258 | total_top5 += torch.sum((pred_labels[:, :5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item() 259 | acc1, acc5 = total_top1 / total_num * 100, total_top5 / total_num * 100 260 | 261 | bias_accs = bias_top1s / bias_counts * 100 262 | 263 | avg_acc = np.nanmean(bias_accs.cpu().numpy()) 264 | worst_acc = np.nanmin(bias_accs.cpu().numpy()) 265 | std_acc = np.nanstd(bias_accs.cpu().numpy()) 266 | 267 | acc_desc = '/'.join(['{:.1f}%'.format(acc) for acc in bias_accs]) 268 | 269 | test_bar.set_description('Eval Epoch [{}/{}] [{}] Bias: {:.2f}%'.format(epoch, self.max_epoch, desc, avg_acc)) 270 | 271 | features = torch.cat(features) 272 | logits = torch.cat(logits) 273 | labels = torch.cat(labels) 274 | corrects = torch.cat(corrects, 1) 275 | 276 | 277 | log = self.logger.info if desc in ['train', 
'train_eval'] else self.logger.warning 278 | log('Eval Epoch [{}/{}] [{}] Unbiased: {:.2f}% (std: {:.2f}), Worst: {:.2f}% [{}] (Average: {:.2f}%)'.format(epoch, self.max_epoch, desc, avg_acc, std_acc, worst_acc, acc_desc)) 279 | self.logger.info('Total [{}]: Acc@1:{:.2f}% Acc@5:{:.2f}%'.format(desc, acc1, acc5)) 280 | print(" {} / {} / {}".format(self.args.desc, self.args.target_attr, self.args.bias_attrs)) 281 | 282 | 283 | self.model.train() 284 | 285 | return 286 | 287 | 288 | 289 | 290 | 291 | -------------------------------------------------------------------------------- /utils/io_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import hashlib 4 | import errno 5 | import logging 6 | from collections import defaultdict 7 | from string import Formatter 8 | import torch 9 | import re, pdb 10 | from datetime import datetime 11 | 12 | import cv2 13 | import numpy as np 14 | import pdb 15 | 16 | 17 | def get_logger(name, fmt='%(asctime)s:%(name)s:%(message)s', 18 | print_level=logging.INFO, 19 | write_level=logging.DEBUG, log_file='', mode='w'): 20 | """ 21 | Get Logger with given name 22 | :param name: logger name. 23 | :param fmt: log format. (default: %(asctime)s:%(levelname)s:%(name)s:%(message)s) 24 | :param level: logging level. (default: logging.DEBUG) 25 | :param log_file: path of log file. (default: None) 26 | :return: 27 | """ 28 | logger = logging.getLogger(name) 29 | # logger.setLevel(write_level) 30 | logging.basicConfig(level=print_level) 31 | formatter = logging.Formatter(fmt, datefmt='%Y/%m/%d %H:%M:%S') 32 | 33 | # Add file handler 34 | if log_file: 35 | file_handler = logging.FileHandler(log_file, mode=mode) 36 | file_handler.setLevel(write_level) 37 | file_handler.setFormatter(formatter) 38 | logger.addHandler(file_handler) 39 | 40 | if print_level is not None: 41 | try: 42 | import coloredlogs 43 | coloredlogs.install(level=print_level, logger=logger) 44 | coloredlogs.DEFAULT_LEVEL_STYLES = {'critical': {'color': 'red', 'bold': True}, 'debug': {'color': 'green'}, 'error': {'color': 'red'}, 'info': {}, 'notice': {'color': 'magenta'}, 'spam': {'color': 'green', 'faint': True}, 'success': {'color': 'green', 'bold': True}, 'verbose': {'color': 'blue'}, 'warning': {'color': 'yellow'}} 45 | except ImportError: 46 | print("Please install Coloredlogs for better view") 47 | # Add stream handler 48 | stream_handler = logging.StreamHandler() 49 | stream_handler.setLevel(print_level) 50 | stream_handler.setFormatter(formatter) 51 | logger.addHandler(stream_handler) 52 | return logger -------------------------------------------------------------------------------- /utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import torch, pdb 2 | import numpy as np 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | 6 | def get_negative_mask(batch_size): 7 | negative_mask = torch.ones((batch_size, 2 * batch_size), dtype=bool) 8 | for i in range(batch_size): 9 | negative_mask[i, i] = 0 10 | negative_mask[i, i + batch_size] = 0 11 | 12 | negative_mask = torch.cat((negative_mask, negative_mask), 0) 13 | return negative_mask 14 | 15 | 16 | def get_negative_class_mask(targets): 17 | batch_size = targets.size(0) 18 | negative_mask = torch.ones((batch_size, 2 * batch_size), dtype=bool) 19 | for i in range(batch_size): 20 | current_c = targets[i] 21 | same_indices = (targets == current_c).nonzero().squeeze() 22 | for s in same_indices: 23 | 
negative_mask[i, s] = 0 24 | negative_mask[i, s + batch_size] = 0 25 | 26 | negative_mask = torch.cat((negative_mask, negative_mask), 0) 27 | return negative_mask 28 | 29 | 30 | 31 | class GradMulConst(torch.autograd.Function): 32 | """ 33 | This layer is used to create an adversarial loss. 34 | """ 35 | @staticmethod 36 | def forward(ctx, x, const): 37 | ctx.const = const 38 | return x.view_as(x) 39 | 40 | @staticmethod 41 | def backward(ctx, grad_output): 42 | return grad_output * ctx.const, None 43 | 44 | def grad_mul_const(x, const): 45 | return GradMulConst.apply(x, const) 46 | 47 | 48 | 49 | class AvgMeter(object): 50 | def __init__(self, name=''): 51 | self.reset() 52 | 53 | def reset(self): 54 | self.val = 0. 55 | self.avg = 0. 56 | self.sum = 0. 57 | self.count = 0. 58 | 59 | def update(self, val, n=1): 60 | if type(val) is torch.Tensor: 61 | val = val.detach() 62 | val = val.cpu() 63 | val = val.numpy() 64 | 65 | if n==len(val): 66 | self.val = val[-1] 67 | self.sum += np.sum(val) 68 | self.count += len(val) 69 | elif n==0: # array 70 | self.val = val[-1] 71 | self.sum += np.sum(val) 72 | self.count += len(val) 73 | else: 74 | self.val = val 75 | self.sum += val 76 | self.count += n 77 | 78 | self.avg = self.sum / self.count 79 | 80 | def __repr__(self): 81 | return self.name+":"+str(round(self.avg, 3)) 82 | 83 | 84 | class WindowAvgMeter(object): 85 | def __init__(self, name='', max_count=20): 86 | self.values = [] 87 | self.name = name 88 | self.max_count = max_count 89 | 90 | @property 91 | def avg(self): 92 | if len(self.values) > 0: 93 | return np.sum(self.values)/len(self.values) 94 | else: 95 | return 0 96 | 97 | def update(self, val): 98 | if type(val) is torch.Tensor: 99 | val = val.detach() 100 | val = val.cpu() 101 | val = val.numpy() 102 | 103 | self.values.append(val) 104 | if len(self.values) > self.max_count: 105 | self.values.pop(0) 106 | 107 | def __repr__(self): 108 | return self.name+":"+str(round(self.avg, 3)) 109 | 110 | 111 | class EMA(object): 112 | # Exponential Moving Average 113 | 114 | def __init__(self, label, alpha=0.9): 115 | self.label = label 116 | self.alpha = alpha 117 | self.parameter = torch.zeros(label.size(0)) 118 | self.updated = torch.zeros(label.size(0)) 119 | 120 | def update(self, data, index): 121 | self.parameter[index] = self.alpha * self.parameter[index] + (1-self.alpha*self.updated[index]) * data 122 | self.updated[index] = 1 123 | 124 | def max_loss(self, label): 125 | label_index = np.where(self.label == label)[0] 126 | return self.parameter[label_index].max() 127 | 128 | 129 | 130 | 131 | class FocalLoss(torch.nn.Module): 132 | def __init__(self, gamma=0, alpha=None, size_average=True): 133 | super(FocalLoss, self).__init__() 134 | self.gamma = gamma 135 | self.alpha = alpha 136 | if isinstance(alpha,(float,int)): self.alpha = torch.Tensor([alpha,1-alpha]) 137 | if isinstance(alpha,list): self.alpha = torch.Tensor(alpha) 138 | self.size_average = size_average 139 | 140 | def forward(self, input, target): 141 | if input.dim()>2: 142 | input = input.view(input.size(0),input.size(1),-1) # N,C,H,W => N,C,H*W 143 | input = input.transpose(1,2) # N,C,H*W => N,H*W,C 144 | input = input.contiguous().view(-1,input.size(2)) # N,H*W,C => N*H*W,C 145 | target = target.view(-1,1) 146 | 147 | logpt = F.log_softmax(input, 1) 148 | logpt = logpt.gather(1,target) 149 | logpt = logpt.view(-1) 150 | pt = Variable(logpt.data.exp()) 151 | 152 | if self.alpha is not None: 153 | if self.alpha.type()!=input.data.type(): 154 | self.alpha = 
self.alpha.type_as(input.data) 155 | at = self.alpha.gather(0,target.data.view(-1)) 156 | logpt = logpt * Variable(at) 157 | 158 | loss = -1 * (1-pt)**self.gamma * logpt 159 | if self.size_average: return loss.mean() 160 | else: return loss.sum() 161 | 162 | 163 | 164 | def KL_u_p_loss(outputs): 165 | # KL(u||p) 166 | num_classes = outputs.size(1) 167 | uniform_tensors = torch.ones(outputs.size()) 168 | uniform_dists = Variable(uniform_tensors / num_classes).cuda() 169 | instance_losses = F.kl_div(F.log_softmax(outputs, dim=1), uniform_dists, reduction='none').sum(dim=1) 170 | return instance_losses 171 | --------------------------------------------------------------------------------
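For reference, below is a simplified, self-contained sketch of the cluster-wise reweighting idea implemented in `Centroids.get_cluster_weights` (modules/centroids.py): each sample is weighted by the inverse size of its pseudo-attribute cluster, scaled by the cluster's average loss and renormalized to unit mean, so minority clusters are not discounted when the average loss is minimized. The function name and the toy tensors are illustrative and not part of the repository.

```python
import torch

def cluster_weights(assigns: torch.Tensor, cluster_losses: torch.Tensor) -> torch.Tensor:
    """Per-sample weights from pseudo-attribute (cluster) assignments and per-cluster losses."""
    num_clusters = cluster_losses.numel()
    counts = assigns.bincount(minlength=num_clusters).float()
    counts = counts + (counts == 0).float()             # guard against empty clusters
    size_term = counts.sum() / counts                   # rare clusters get large weights
    loss_term = cluster_losses / cluster_losses.sum()   # high-loss clusters get large weights
    weights = size_term[assigns] * loss_term[assigns]
    return weights / weights.mean()                     # normalized to mean 1, as in the repo

if __name__ == "__main__":
    assigns = torch.tensor([0, 0, 0, 0, 1, 1, 2])       # cluster id of each sample
    cluster_losses = torch.tensor([0.2, 0.5, 1.0])      # mean CE loss per cluster
    print(cluster_weights(assigns, cluster_losses))     # the small, lossy cluster 2 is upweighted most
```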