├── utilities
│   ├── __init__.py
│   ├── sampler.py
│   ├── presets.py
│   └── transforms.py
├── assets
│   └── fairtune_framework.png
├── config.yaml
├── data
│   ├── fitzpatrick.py
│   ├── papila.py
│   ├── glaucoma.py
│   ├── chexpert.py
│   ├── ol3i.py
│   ├── oasis.py
│   └── HAM10000.py
├── README.md
├── parse_args.py
├── finetune_with_mask.py
└── search_mask.py
/utilities/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/fairtune_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Raman1121/FairTune/HEAD/assets/fairtune_framework.png -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | HAM10000: 3 | root_path: '' 4 | img_path: '' 5 | train_csv: '' 6 | val_csv: '' 7 | test_csv: '' 8 | 9 | fitzpatrick: 10 | root_path: '' 11 | img_path: '' 12 | train_csv: '' 13 | val_csv: '' 14 | test_csv: '' 15 | 16 | papila: 17 | root_path: '' 18 | img_path: '' 19 | train_csv: '' 20 | val_csv: '' 21 | test_csv: '' 22 | 23 | ol3i: 24 | root_path: '' 25 | img_path: '' 26 | train_csv: '' 27 | val_csv: '' 28 | test_csv: '' 29 | 30 | oasis: 31 | root_path: '' 32 | img_path: '' 33 | train_csv: '' 34 | val_csv: '' 35 | test_csv: '' 36 | 37 | chexpert: 38 | root_path: '' 39 | img_path: '' 40 | train_csv: '' 41 | val_csv: '' 42 | test_csv: '' 43 | 44 | glaucoma: 45 | root_path: '' 46 | img_path: '' 47 | train_csv: '' 48 | val_csv: '' 49 | test_csv: '' 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /data/fitzpatrick.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import torchvision 4 | from torchvision import transforms as T 5 | from torchvision.io import read_image, ImageReadMode 6 | 7 | import numpy as np 8 | import random 9 | import yaml 10 | from PIL import Image 11 | 12 | 13 | class FitzpatrickDataset(Dataset): 14 | def __init__(self, df, transform=None, skin_type="multi"): 15 | self.df = df 16 | self.transform = transform 17 | self.skin_type = skin_type 18 | self.classes = self.get_num_classes() 19 | self.class_to_idx = self._get_class_to_idx() 20 | 21 | def __len__(self): 22 | return len(self.df) 23 | 24 | def get_num_classes(self): 25 | # return self.df['label_idx'].unique() 26 | return self.df["binary_label"].unique() 27 | 28 | def _get_class_to_idx(self): 29 | return {"benign": 0, "malignant": 1, "non-neoplastic": 2} 30 | 31 | def __getitem__(self, idx): 32 | image = read_image(self.df.iloc[idx]["Path"], mode=ImageReadMode.RGB) 33 | image = T.ToPILImage()(image) 34 | # label = self.df.iloc[idx]['label_idx'] 35 | label = self.df.iloc[idx]["binary_label"] 36 | 37 | if self.skin_type == "multi": 38 | sens_attribute = self.df.iloc[idx]["skin_type"] 39 | elif self.skin_type == "binary": 40 | sens_attribute = self.df.iloc[idx]["skin_binary"] 41 | 42 | if self.transform: 43 | image = self.transform(image) 44 | 45 | return image, label, sens_attribute 46 | -------------------------------------------------------------------------------- /data/papila.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import torchvision 4 | from torchvision
import transforms as T 5 | from torchvision.io import read_image, ImageReadMode 6 | 7 | import numpy as np 8 | import random 9 | import yaml 10 | from PIL import Image 11 | 12 | 13 | class PapilaDataset(Dataset): 14 | def __init__(self, df, sens_attribute=None, transform=None): 15 | assert sens_attribute is not None 16 | 17 | self.df = df 18 | self.transform = transform 19 | self.sens_attribute = sens_attribute 20 | self.classes = self.get_num_classes() 21 | self.class_to_idx = self._get_class_to_idx() 22 | 23 | def __len__(self): 24 | return len(self.df) 25 | 26 | def get_num_classes(self): 27 | return self.df["Diagnosis"].unique() 28 | 29 | def _get_class_to_idx(self): 30 | return { 31 | "healthy": 0, 32 | "glaucoma": 1, 33 | } 34 | 35 | def __getitem__(self, idx): 36 | image = read_image(self.df.iloc[idx]["Path"], mode=ImageReadMode.RGB) 37 | image = T.ToPILImage()(image) 38 | label = torch.tensor(self.df.iloc[idx]["Diagnosis"]).to(torch.int64) 39 | 40 | if self.sens_attribute == "gender": 41 | sens_attribute = self.df.iloc[idx]["Sex"] 42 | elif self.sens_attribute == "age": 43 | # sens_attribute = self.df.iloc[idx]['Age_multi'] 44 | sens_attribute = self.df.iloc[idx]["Age_binary"] 45 | 46 | if self.transform: 47 | image = self.transform(image) 48 | 49 | return image, label, sens_attribute 50 | -------------------------------------------------------------------------------- /data/glaucoma.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import torchvision 4 | from torchvision import transforms as T 5 | from torchvision.io import read_image, ImageReadMode 6 | 7 | import numpy as np 8 | import random 9 | import yaml 10 | from PIL import Image 11 | 12 | 13 | class HarvardGlaucoma(Dataset): 14 | def __init__(self, df, sens_attribute=None, transform=None, age_type="binary"): 15 | assert sens_attribute is not None 16 | 17 | self.df = df 18 | self.transform = transform 19 | self.sens_attribute = sens_attribute 20 | self.age_type = age_type 21 | self.classes = self.get_num_classes() 22 | self.class_to_idx = self._get_class_to_idx() 23 | 24 | def __len__(self): 25 | return len(self.df) 26 | 27 | def get_num_classes(self): 28 | return self.df["Glaucoma"].unique() 29 | 30 | def _get_class_to_idx(self): 31 | return { 32 | "healthy": 0, 33 | "glaucoma": 1, 34 | } 35 | 36 | def __getitem__(self, idx): 37 | image = read_image(self.df.iloc[idx]["Path"], mode=ImageReadMode.RGB) 38 | image = T.ToPILImage()(image) 39 | label = torch.tensor(self.df.iloc[idx]["Glaucoma"]).to(torch.int64) 40 | 41 | if self.sens_attribute == "gender": 42 | sens_attribute = self.df.iloc[idx]["Sex"] 43 | elif self.sens_attribute == "age": 44 | if self.age_type == "multi": 45 | sens_attribute = self.df.iloc[idx]["Age_multi"] 46 | elif self.age_type == "binary": 47 | sens_attribute = self.df.iloc[idx]["Age_binary"] 48 | elif self.sens_attribute == "race": 49 | sens_attribute = self.df.iloc[idx]["Race"] 50 | 51 | if self.transform: 52 | image = self.transform(image) 53 | 54 | return image, label, sens_attribute 55 | -------------------------------------------------------------------------------- /data/chexpert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import torchvision 4 | from torchvision import transforms as T 5 | from torchvision.io import read_image, ImageReadMode 6 | 7 | import numpy as np 8 | import 
random 9 | import yaml 10 | from PIL import Image 11 | 12 | 13 | class ChexpertDataset(Dataset): 14 | def __init__(self, df, sens_attribute=None, transform=None, age_type="binary"): 15 | assert sens_attribute is not None 16 | 17 | self.df = df 18 | self.transform = transform 19 | self.sens_attribute = sens_attribute 20 | self.age_type = age_type 21 | self.classes = self.get_num_classes() 22 | self.class_to_idx = self._get_class_to_idx() 23 | 24 | def __len__(self): 25 | return len(self.df) 26 | 27 | def get_num_classes(self): 28 | return self.df["No Finding"].unique() 29 | 30 | def _get_class_to_idx(self): 31 | return { 32 | "Healthy": 0, 33 | "Not Healthy": 1, 34 | } 35 | 36 | def __getitem__(self, idx): 37 | image = read_image(self.df.iloc[idx]["Path"], mode=ImageReadMode.RGB) 38 | image = T.ToPILImage()(image) 39 | label = torch.tensor(self.df.iloc[idx]["No Finding"]).to(torch.int64) 40 | 41 | if self.sens_attribute == "gender": 42 | sens_attribute = self.df.iloc[idx]["Sex"] 43 | elif self.sens_attribute == "age": 44 | if self.age_type == "multi": 45 | sens_attribute = self.df.iloc[idx]["Age_multi"] 46 | elif self.age_type == "binary": 47 | sens_attribute = self.df.iloc[idx]["Age_binary"] 48 | elif self.sens_attribute == "age_sex": 49 | sens_attribute = self.df.iloc[idx]["Multi_age_multi_sens"] 50 | 51 | if self.transform: 52 | image = self.transform(image) 53 | 54 | return image, label, sens_attribute 55 | -------------------------------------------------------------------------------- /data/ol3i.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import torchvision 4 | from torchvision import transforms as T 5 | from torchvision.io import read_image, ImageReadMode 6 | 7 | import numpy as np 8 | import random 9 | import yaml 10 | from PIL import Image 11 | 12 | 13 | class OL3IDataset(Dataset): 14 | def __init__(self, df, sens_attribute=None, transform=None, age_type=None): 15 | assert sens_attribute is not None 16 | assert age_type is not None 17 | 18 | self.df = df 19 | self.transform = transform 20 | self.sens_attribute = sens_attribute 21 | self.age_type = age_type 22 | self.classes = self.get_num_classes() 23 | self.class_to_idx = self._get_class_to_idx() 24 | 25 | def __len__(self): 26 | return len(self.df) 27 | 28 | def get_num_classes(self): 29 | return self.df["label_1y"].unique() 30 | 31 | def _get_class_to_idx(self): 32 | return { 33 | 0: 0, 34 | 1: 1, 35 | } 36 | 37 | def __getitem__(self, idx): 38 | img_path = self.df.iloc[idx]["path"] 39 | image = np.load(img_path).astype(np.uint8) 40 | image = Image.fromarray(image).convert("RGB") 41 | label = torch.tensor(self.df.iloc[idx]["label_1y"]).to(torch.int64) 42 | 43 | if self.sens_attribute == "gender": 44 | sens_attribute = self.df.iloc[idx]["sex"] 45 | elif self.sens_attribute == "age": 46 | if self.age_type == "multi": 47 | sens_attribute = self.df.iloc[idx]["Age_multi"] 48 | elif self.age_type == "binary": 49 | sens_attribute = self.df.iloc[idx]["Age_binary"] 50 | else: 51 | raise NotImplementedError("Age type not implemented") 52 | 53 | if self.transform: 54 | image = self.transform(image) 55 | 56 | return image, label, sens_attribute 57 | -------------------------------------------------------------------------------- /data/oasis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import torchvision 4 | from 
torchvision import transforms as T 5 | from torchvision.io import read_image, ImageReadMode 6 | 7 | import numpy as np 8 | import random 9 | import yaml 10 | from PIL import Image 11 | 12 | 13 | class OASISDataset(Dataset): 14 | def __init__(self, df, sens_attribute=None, transform=None, age_type=None): 15 | assert sens_attribute is not None 16 | assert age_type is not None 17 | 18 | self.df = df 19 | self.transform = transform 20 | self.sens_attribute = sens_attribute 21 | self.age_type = age_type 22 | self.classes = self.get_num_classes() 23 | self.class_to_idx = self._get_class_to_idx() 24 | 25 | def __len__(self): 26 | return len(self.df) 27 | 28 | def get_num_classes(self): 29 | return self.df["CDR"].unique() 30 | 31 | def _get_class_to_idx(self): 32 | # return {0:'Non-Demented', 33 | # 1:'Very Mild-Dementia', 34 | # 2:'Mild-Dementia', 35 | # 3:'Moderate-Dementia' 36 | # } 37 | return { 38 | 0: "Non-Demented", 39 | 1: "Dementia", 40 | } 41 | 42 | def __getitem__(self, idx): 43 | img_path = self.df.iloc[idx]["Path"] 44 | image = np.load(img_path).astype(np.uint8) 45 | image = Image.fromarray(image).convert("RGB") 46 | label = torch.tensor(self.df.iloc[idx]["CDR"]).to(torch.int64) 47 | 48 | if self.sens_attribute == "gender": 49 | sens_attribute = self.df.iloc[idx]["Gender"] 50 | elif self.sens_attribute == "age": 51 | if self.age_type == "multi": 52 | sens_attribute = self.df.iloc[idx]["Age_multi"] 53 | elif self.age_type == "binary": 54 | sens_attribute = self.df.iloc[idx]["Age_binary"] 55 | else: 56 | raise NotImplementedError("Age type not implemented") 57 | 58 | if self.transform: 59 | image = self.transform(image) 60 | 61 | return image, label, sens_attribute 62 | -------------------------------------------------------------------------------- /utilities/sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | 7 | class RASampler(torch.utils.data.Sampler): 8 | """Sampler that restricts data loading to a subset of the dataset for distributed training, 9 | with repeated augmentation. 10 | It ensures that each augmented version of a sample will be visible to a 11 | different process (GPU). 12 | Heavily based on 'torch.utils.data.DistributedSampler'.
13 | 14 | This is borrowed from the DeiT Repo: 15 | https://github.com/facebookresearch/deit/blob/main/samplers.py 16 | """ 17 | 18 | def __init__( 19 | self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3 20 | ): 21 | if num_replicas is None: 22 | if not dist.is_available(): 23 | raise RuntimeError("Requires distributed package to be available!") 24 | num_replicas = dist.get_world_size() 25 | if rank is None: 26 | if not dist.is_available(): 27 | raise RuntimeError("Requires distributed package to be available!") 28 | rank = dist.get_rank() 29 | self.dataset = dataset 30 | self.num_replicas = num_replicas 31 | self.rank = rank 32 | self.epoch = 0 33 | self.num_samples = int( 34 | math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas) 35 | ) 36 | self.total_size = self.num_samples * self.num_replicas 37 | self.num_selected_samples = int( 38 | math.floor(len(self.dataset) // 256 * 256 / self.num_replicas) 39 | ) 40 | self.shuffle = shuffle 41 | self.seed = seed 42 | self.repetitions = repetitions 43 | 44 | def __iter__(self): 45 | if self.shuffle: 46 | # Deterministically shuffle based on epoch 47 | g = torch.Generator() 48 | g.manual_seed(self.seed + self.epoch) 49 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 50 | else: 51 | indices = list(range(len(self.dataset))) 52 | 53 | # Add extra samples to make it evenly divisible 54 | indices = [ele for ele in indices for i in range(self.repetitions)] 55 | indices += indices[: (self.total_size - len(indices))] 56 | assert len(indices) == self.total_size 57 | 58 | # Subsample 59 | indices = indices[self.rank : self.total_size : self.num_replicas] 60 | assert len(indices) == self.num_samples 61 | 62 | return iter(indices[: self.num_selected_samples]) 63 | 64 | def __len__(self): 65 | return self.num_selected_samples 66 | 67 | def set_epoch(self, epoch): 68 | self.epoch = epoch 69 | -------------------------------------------------------------------------------- /data/HAM10000.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import torchvision 4 | from torchvision import transforms as T 5 | from torchvision.io import read_image, ImageReadMode 6 | 7 | import numpy as np 8 | import random 9 | import yaml 10 | from PIL import Image 11 | 12 | 13 | class HAM10000Dataset(Dataset): 14 | def __init__( 15 | self, df, sens_attribute, transform=None, age_type="multi", label_type="binary" 16 | ): 17 | assert sens_attribute is not None 18 | 19 | self.df = df 20 | self.transform = transform 21 | self.sens_attribute = sens_attribute 22 | self.age_type = age_type 23 | self.label_type = label_type 24 | # self.use_binary_label = use_binary_label 25 | self.classes = self.get_num_classes() 26 | self.class_to_idx = self._get_class_to_idx() 27 | 28 | def __len__(self): 29 | return len(self.df) 30 | 31 | def get_num_classes(self): 32 | # return self.df['dx_index'].unique() 33 | if self.label_type == "multi": 34 | return self.df["MultiLabels"].unique() 35 | else: 36 | return self.df["binaryLabel"].unique() 37 | 38 | def _get_original_labels(self): 39 | return { 40 | "akiec": "Bowen's disease", 41 | "bcc": "basal cell carcinoma", 42 | "bkl": "benign keratosis-like lesions", 43 | "df": "dermatofibroma", 44 | "nv": "melanocytic nevi", 45 | "mel": "melanoma", 46 | "vasc": "vascular lesions", 47 | } 48 | 49 | def _get_class_to_idx(self): 50 | return {"akiec": 0, "bcc": 1, "bkl": 2, "df": 3, "nv": 4, "mel": 
5, "vasc": 6} 51 | 52 | def __getitem__(self, idx): 53 | image = read_image(self.df.iloc[idx]["Path"], mode=ImageReadMode.RGB) 54 | image = T.ToPILImage()(image) 55 | # label = torch.tensor(self.df.iloc[idx]['dx_index']) 56 | 57 | if self.label_type == "multi": 58 | label = torch.tensor(self.df.iloc[idx]["MultiLabels"]) 59 | else: 60 | label = torch.tensor(self.df.iloc[idx]["binaryLabel"]) 61 | 62 | # print("LABEL: ", label) 63 | 64 | if self.sens_attribute == "gender": 65 | sens_attribute = self.df.iloc[idx]["sex"] 66 | elif self.sens_attribute == "age": 67 | if self.age_type == "multi": 68 | sens_attribute = self.df.iloc[idx]["Age_multi2"] 69 | elif self.age_type == "binary": 70 | sens_attribute = self.df.iloc[idx]["Age_binary"] 71 | else: 72 | raise ValueError("Invalid sensitive attribute for HAM10000 dataset") 73 | 74 | if self.transform: 75 | image = self.transform(image) 76 | 77 | return image, label, sens_attribute 78 | -------------------------------------------------------------------------------- /utilities/presets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.transforms import autoaugment, transforms 3 | from torchvision.transforms.functional import InterpolationMode 4 | 5 | 6 | class ClassificationPresetTrain: 7 | def __init__( 8 | self, 9 | *, 10 | crop_size, 11 | mean=(0.485, 0.456, 0.406), 12 | std=(0.229, 0.224, 0.225), 13 | interpolation=InterpolationMode.BILINEAR, 14 | hflip_prob=0.5, 15 | auto_augment_policy=None, 16 | ra_magnitude=9, 17 | augmix_severity=3, 18 | random_erase_prob=0.0, 19 | ): 20 | trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] 21 | if hflip_prob > 0: 22 | trans.append(transforms.RandomHorizontalFlip(hflip_prob)) 23 | if auto_augment_policy is not None: 24 | if auto_augment_policy == "ra": 25 | trans.append( 26 | autoaugment.RandAugment( 27 | interpolation=interpolation, magnitude=ra_magnitude 28 | ) 29 | ) 30 | elif auto_augment_policy == "ta_wide": 31 | trans.append( 32 | autoaugment.TrivialAugmentWide(interpolation=interpolation) 33 | ) 34 | elif auto_augment_policy == "augmix": 35 | trans.append( 36 | autoaugment.AugMix( 37 | interpolation=interpolation, severity=augmix_severity 38 | ) 39 | ) 40 | else: 41 | aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) 42 | trans.append( 43 | autoaugment.AutoAugment( 44 | policy=aa_policy, interpolation=interpolation 45 | ) 46 | ) 47 | trans.extend( 48 | [ 49 | transforms.PILToTensor(), 50 | transforms.ConvertImageDtype(torch.float), 51 | transforms.Normalize(mean=mean, std=std), 52 | ] 53 | ) 54 | if random_erase_prob > 0: 55 | trans.append(transforms.RandomErasing(p=random_erase_prob)) 56 | 57 | self.transforms = transforms.Compose(trans) 58 | 59 | def __call__(self, img): 60 | return self.transforms(img) 61 | 62 | 63 | class ClassificationPresetEval: 64 | def __init__( 65 | self, 66 | *, 67 | crop_size, 68 | resize_size=256, 69 | mean=(0.485, 0.456, 0.406), 70 | std=(0.229, 0.224, 0.225), 71 | interpolation=InterpolationMode.BILINEAR, 72 | ): 73 | self.transforms = transforms.Compose( 74 | [ 75 | transforms.Resize(resize_size, interpolation=interpolation), 76 | transforms.CenterCrop(crop_size), 77 | transforms.PILToTensor(), 78 | transforms.ConvertImageDtype(torch.float), 79 | transforms.Normalize(mean=mean, std=std), 80 | ] 81 | ) 82 | 83 | def __call__(self, img): 84 | return self.transforms(img) 85 | -------------------------------------------------------------------------------- 
/README.md: -------------------------------------------------------------------------------- 1 | # *FairTune*: Optimizing Parameter-Efficient Fine-Tuning for Fairness in Medical Image Analysis 2 | 3 | :star2: Accepted to ICLR 2024! | [Paper Link](https://arxiv.org/abs/2310.05055) 4 | 5 | ![FairTune](assets/fairtune_framework.png) 6 | 7 | Abstract: Training models with robust group fairness properties is crucial in ethically sensitive application areas such as medical diagnosis. Despite the growing body of work aiming to minimise demographic bias in AI, this problem remains challenging. A key reason for this challenge is the fairness generalisation gap: High-capacity deep learning models can fit all training data nearly perfectly, and thus also exhibit perfect fairness during training. In this case, bias emerges only during testing when generalisation performance differs across subgroups. This motivates us to take a bi-level optimisation perspective on fair learning: Optimising the learning strategy based on validation fairness. Specifically, we consider the highly effective workflow of adapting pre-trained models to downstream medical imaging tasks using parameter-efficient fine-tuning (PEFT) techniques. There is a trade-off between updating more parameters, enabling a better fit to the task of interest, vs. fewer parameters, potentially reducing the generalisation gap. To manage this trade-off, we propose *FairTune*, a framework to optimise the choice of PEFT parameters with respect to fairness. We demonstrate empirically that *FairTune* leads to improved fairness on a range of medical imaging datasets. 8 | 9 | ## Installation 10 | Python >= 3.8 and PyTorch >= 1.10 are required to run the code. 11 | Main packages: [PyTorch](https://pytorch.org/get-started/locally/), [Optuna](https://optuna.readthedocs.io/en/stable/installation.html), [FairLearn](https://fairlearn.org/v0.9/quickstart.html) 12 | 13 | 14 | ## Dataset Preparation 15 | We follow the steps in [MEDFAIR](https://github.com/ys-zong/MEDFAIR/tree/main) for preparing the datasets. Please see [this page](https://github.com/ys-zong/MEDFAIR/tree/main#data-preprocessing). 16 | Detailed instructions for preparing the datasets are given in the Appendix. 17 | 18 | After preprocessing, specify the paths of the metadata and pickle files in `config.yaml`. 19 | 20 | 21 | ### Dataset 22 | Due to the data use agreements, we cannot directly share the download link.
Please register and download datasets using the links from the table below: 23 | 24 | | **Dataset** | **Access** | 25 | |--------------|-----------------------------------------------------------------------------------------------| 26 | | CheXpert | https://stanfordmlgroup.github.io/competitions/chexpert/ | 27 | | OL3I | https://stanfordaimi.azurewebsites.net/datasets/3263e34a-252e-460f-8f63-d585a9bfecfc | 28 | | PAPILA | https://www.nature.com/articles/s41597-022-01388-1#Sec6 | 29 | | HAM10000 | https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/DBW86T | 30 | | Oasis-1 | https://www.oasis-brains.org/#data | 31 | | Fitzpatrick17k | https://github.com/mattgroh/fitzpatrick17k | 32 | | Harvard-GF3300 | https://ophai.hms.harvard.edu/datasets/harvard-glaucoma-fairness-3300-samples/ | 33 | 34 | 35 | ### Run the HPO search to find the best mask (Stage 1) 36 | ```bash 37 | python search_mask.py --model [model] --epochs [epochs] --batch-size [batch-size] \ 38 | --opt [opt] --lr [lr] --lr-scheduler [lr-scheduler] --lr-warmup-method [lr-warmup-method] --lr-warmup-epochs [lr-warmup-epochs] \ 39 | --tuning_method [tuning_method] --dataset [dataset] --sens_attribute [sens_attribute] \ 40 | --objective_metric [objective_metric] --num_trials [num_trials] --disable_storage --disable_checkpointing 41 | ``` 42 | `The mask will be saved in the directory FairTune////` 43 | 44 | You can use different metrics as the objective for the HPO search. Please check `parse_args.py` for more options. 45 | 46 | ### Fine-tune on the downstream task using the searched mask (Stage 2) 47 | ```bash 48 | python finetune_with_mask.py --model [model] --epochs [epochs] --batch-size [batch-size] \ 49 | --opt [opt] --lr [lr] --lr-scheduler [lr-scheduler] --lr-warmup-method [lr-warmup-method] --lr-warmup-epochs [lr-warmup-epochs] \ 50 | --tuning_method [tuning_method] --dataset [dataset] --sens_attribute [sens_attribute] \ 51 | --cal_equiodds --mask_path [mask_path] --use_metric auc 52 | ``` 53 | `The results will be saved in a CSV file located at FairTune///` 54 | 55 | `Note: The PAPILA and OL3I datasets are highly imbalanced, so it is advisable to use a weighted loss; enable it with the --compute_cw argument.` 56 | 57 | ### Citing FairTune 58 | 59 | ``` 60 | @inproceedings{dutt2023fairtune, 61 | title={Fairtune: Optimizing parameter efficient fine tuning for fairness in medical image analysis}, 62 | author={Dutt, Raman and Bohdal, Ondrej and Tsaftaris, Sotirios A and Hospedales, Timothy}, 63 | booktitle={International Conference on Learning Representations}, 64 | year={2024} 65 | } 66 | ``` 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /utilities/transforms.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple 3 | 4 | import torch 5 | from torch import Tensor 6 | from torchvision.transforms import functional as F 7 | 8 | 9 | class RandomMixup(torch.nn.Module): 10 | """Randomly apply Mixup to the provided batch and targets. 11 | The class implements the data augmentations as described in the paper 12 | `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_. 13 | 14 | Args: 15 | num_classes (int): number of classes used for one-hot encoding. 16 | p (float): probability of the batch being transformed. Default value is 0.5. 17 | alpha (float): hyperparameter of the Beta distribution used for mixup. 18 | Default value is 1.0.
19 | inplace (bool): boolean to make this transform inplace. Default set to False. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | num_classes: int, 25 | p: float = 0.5, 26 | alpha: float = 1.0, 27 | inplace: bool = False, 28 | ) -> None: 29 | super().__init__() 30 | 31 | if num_classes < 1: 32 | raise ValueError( 33 | f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}" 34 | ) 35 | 36 | if alpha <= 0: 37 | raise ValueError("Alpha param can't be zero.") 38 | 39 | self.num_classes = num_classes 40 | self.p = p 41 | self.alpha = alpha 42 | self.inplace = inplace 43 | 44 | def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: 45 | """ 46 | Args: 47 | batch (Tensor): Float tensor of size (B, C, H, W) 48 | target (Tensor): Integer tensor of size (B, ) 49 | 50 | Returns: 51 | Tensor: Randomly transformed batch. 52 | """ 53 | if batch.ndim != 4: 54 | raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}") 55 | if target.ndim != 1: 56 | raise ValueError(f"Target ndim should be 1. Got {target.ndim}") 57 | if not batch.is_floating_point(): 58 | raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") 59 | if target.dtype != torch.int64: 60 | raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}") 61 | 62 | if not self.inplace: 63 | batch = batch.clone() 64 | target = target.clone() 65 | 66 | if target.ndim == 1: 67 | target = torch.nn.functional.one_hot( 68 | target, num_classes=self.num_classes 69 | ).to(dtype=batch.dtype) 70 | 71 | if torch.rand(1).item() >= self.p: 72 | return batch, target 73 | 74 | # It's faster to roll the batch by one instead of shuffling it to create image pairs 75 | batch_rolled = batch.roll(1, 0) 76 | target_rolled = target.roll(1, 0) 77 | 78 | # Implemented as on mixup paper, page 3. 79 | lambda_param = float( 80 | torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0] 81 | ) 82 | batch_rolled.mul_(1.0 - lambda_param) 83 | batch.mul_(lambda_param).add_(batch_rolled) 84 | 85 | target_rolled.mul_(1.0 - lambda_param) 86 | target.mul_(lambda_param).add_(target_rolled) 87 | 88 | return batch, target 89 | 90 | def __repr__(self) -> str: 91 | s = ( 92 | f"{self.__class__.__name__}(" 93 | f"num_classes={self.num_classes}" 94 | f", p={self.p}" 95 | f", alpha={self.alpha}" 96 | f", inplace={self.inplace}" 97 | f")" 98 | ) 99 | return s 100 | 101 | 102 | class RandomCutmix(torch.nn.Module): 103 | """Randomly apply Cutmix to the provided batch and targets. 104 | The class implements the data augmentations as described in the paper 105 | `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" 106 | `_. 107 | 108 | Args: 109 | num_classes (int): number of classes used for one-hot encoding. 110 | p (float): probability of the batch being transformed. Default value is 0.5. 111 | alpha (float): hyperparameter of the Beta distribution used for cutmix. 112 | Default value is 1.0. 113 | inplace (bool): boolean to make this transform inplace. Default set to False. 114 | """ 115 | 116 | def __init__( 117 | self, 118 | num_classes: int, 119 | p: float = 0.5, 120 | alpha: float = 1.0, 121 | inplace: bool = False, 122 | ) -> None: 123 | super().__init__() 124 | if num_classes < 1: 125 | raise ValueError( 126 | "Please provide a valid positive value for the num_classes." 
127 | ) 128 | if alpha <= 0: 129 | raise ValueError("Alpha param can't be zero.") 130 | 131 | self.num_classes = num_classes 132 | self.p = p 133 | self.alpha = alpha 134 | self.inplace = inplace 135 | 136 | def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: 137 | """ 138 | Args: 139 | batch (Tensor): Float tensor of size (B, C, H, W) 140 | target (Tensor): Integer tensor of size (B, ) 141 | 142 | Returns: 143 | Tensor: Randomly transformed batch. 144 | """ 145 | if batch.ndim != 4: 146 | raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}") 147 | if target.ndim != 1: 148 | raise ValueError(f"Target ndim should be 1. Got {target.ndim}") 149 | if not batch.is_floating_point(): 150 | raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") 151 | if target.dtype != torch.int64: 152 | raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}") 153 | 154 | if not self.inplace: 155 | batch = batch.clone() 156 | target = target.clone() 157 | 158 | if target.ndim == 1: 159 | target = torch.nn.functional.one_hot( 160 | target, num_classes=self.num_classes 161 | ).to(dtype=batch.dtype) 162 | 163 | if torch.rand(1).item() >= self.p: 164 | return batch, target 165 | 166 | # It's faster to roll the batch by one instead of shuffling it to create image pairs 167 | batch_rolled = batch.roll(1, 0) 168 | target_rolled = target.roll(1, 0) 169 | 170 | # Implemented as on cutmix paper, page 12 (with minor corrections on typos). 171 | lambda_param = float( 172 | torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0] 173 | ) 174 | _, H, W = F.get_dimensions(batch) 175 | 176 | r_x = torch.randint(W, (1,)) 177 | r_y = torch.randint(H, (1,)) 178 | 179 | r = 0.5 * math.sqrt(1.0 - lambda_param) 180 | r_w_half = int(r * W) 181 | r_h_half = int(r * H) 182 | 183 | x1 = int(torch.clamp(r_x - r_w_half, min=0)) 184 | y1 = int(torch.clamp(r_y - r_h_half, min=0)) 185 | x2 = int(torch.clamp(r_x + r_w_half, max=W)) 186 | y2 = int(torch.clamp(r_y + r_h_half, max=H)) 187 | 188 | batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2] 189 | lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) 190 | 191 | target_rolled.mul_(1.0 - lambda_param) 192 | target.mul_(lambda_param).add_(target_rolled) 193 | 194 | return batch, target 195 | 196 | def __repr__(self) -> str: 197 | s = ( 198 | f"{self.__class__.__name__}(" 199 | f"num_classes={self.num_classes}" 200 | f", p={self.p}" 201 | f", alpha={self.alpha}" 202 | f", inplace={self.inplace}" 203 | f")" 204 | ) 205 | return s 206 | -------------------------------------------------------------------------------- /parse_args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_args_parser(add_help=True): 5 | parser = argparse.ArgumentParser( 6 | description="PyTorch Classification Training", add_help=add_help 7 | ) 8 | 9 | parser.add_argument( 10 | "--dataset", 11 | default=None, 12 | required=True, 13 | type=str, 14 | help="Dataset for finetuning.", 15 | ) 16 | parser.add_argument( 17 | "--train_subset_ratio", 18 | default=None, 19 | type=float, 20 | help="Subset of the training dataset to be used", 21 | ) 22 | parser.add_argument( 23 | "--val_subset_ratio", 24 | default=None, 25 | type=float, 26 | help="Subset of the validation dataset to be used", 27 | ) 28 | parser.add_argument( 29 | "--dataset_basepath", 30 | required=False, 31 | type=str, 32 | help="Base path for all the datasets.", 33 | ) 34 | parser.add_argument( 35 | 
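# Class-weighted loss for the heavily imbalanced datasets (e.g., PAPILA, OL3I); see the README note.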
"--compute_cw", 36 | help="Compute and use class weights for imbalance dataset", 37 | action="store_true", 38 | ) 39 | parser.add_argument("--model", default="resnet18", type=str, help="model name") 40 | parser.add_argument( 41 | "--device", 42 | default="cuda", 43 | type=str, 44 | help="device (Use cuda or cpu Default: cuda)", 45 | ) 46 | parser.add_argument( 47 | "-b", 48 | "--batch-size", 49 | default=32, 50 | type=int, 51 | help="images per gpu, the total batch size is $NGPU x batch_size", 52 | ) 53 | parser.add_argument( 54 | "--epochs", 55 | default=90, 56 | type=int, 57 | metavar="N", 58 | help="number of total epochs to run", 59 | ) 60 | parser.add_argument( 61 | "-j", 62 | "--workers", 63 | default=16, 64 | type=int, 65 | metavar="N", 66 | help="number of data loading workers (default: 16)", 67 | ) 68 | parser.add_argument("--opt", default="sgd", type=str, help="optimizer") 69 | parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate") 70 | parser.add_argument( 71 | "--lr_scaler", 72 | default=1, 73 | type=float, 74 | help="Multiplier for the LR used in the inner loop weight update", 75 | ) 76 | 77 | parser.add_argument( 78 | "--outer_lr", default=0.1, type=float, help="outer loop learning rate" 79 | ) 80 | parser.add_argument( 81 | "--momentum", default=0.9, type=float, metavar="M", help="momentum" 82 | ) 83 | parser.add_argument( 84 | "--wd", 85 | "--weight-decay", 86 | default=1e-4, 87 | type=float, 88 | metavar="W", 89 | help="weight decay (default: 1e-4)", 90 | dest="weight_decay", 91 | ) 92 | parser.add_argument( 93 | "--norm-weight-decay", 94 | default=None, 95 | type=float, 96 | help="weight decay for Normalization layers (default: None, same value as --wd)", 97 | ) 98 | parser.add_argument( 99 | "--bias-weight-decay", 100 | default=None, 101 | type=float, 102 | help="weight decay for bias parameters of all layers (default: None, same value as --wd)", 103 | ) 104 | parser.add_argument( 105 | "--transformer-embedding-decay", 106 | default=None, 107 | type=float, 108 | help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)", 109 | ) 110 | parser.add_argument( 111 | "--label-smoothing", 112 | default=0.0, 113 | type=float, 114 | help="label smoothing (default: 0.0)", 115 | dest="label_smoothing", 116 | ) 117 | parser.add_argument( 118 | "--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)" 119 | ) 120 | parser.add_argument( 121 | "--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)" 122 | ) 123 | parser.add_argument( 124 | "--lr-scheduler", 125 | default="steplr", 126 | type=str, 127 | help="the lr scheduler (default: steplr)", 128 | ) 129 | parser.add_argument( 130 | "--lr-scheduler-outer", 131 | default="constant", 132 | type=str, 133 | help="the lr scheduler (default: constant)", 134 | ) 135 | parser.add_argument( 136 | "--lr-warmup-epochs", 137 | default=0, 138 | type=int, 139 | help="the number of epochs to warmup (default: 0)", 140 | ) 141 | parser.add_argument( 142 | "--lr-warmup-method", 143 | default="constant", 144 | type=str, 145 | help="the warmup method (default: constant)", 146 | ) 147 | parser.add_argument( 148 | "--lr-warmup-decay", default=0.01, type=float, help="the decay for lr" 149 | ) 150 | parser.add_argument( 151 | "--lr-step-size", 152 | default=30, 153 | type=int, 154 | help="decrease lr every step-size epochs", 155 | ) 156 | parser.add_argument( 157 | "--lr-gamma", 158 | default=0.1, 159 | type=float, 160 | help="decrease lr 
by a factor of lr-gamma", 161 | ) 162 | parser.add_argument( 163 | "--lr-min", 164 | default=0.0, 165 | type=float, 166 | help="minimum lr of lr schedule (default: 0.0)", 167 | ) 168 | parser.add_argument("--print-freq", default=10, type=int, help="print frequency") 169 | parser.add_argument( 170 | "--output-dir", default=".", type=str, help="path to save outputs" 171 | ) 172 | parser.add_argument("--resume", default="", type=str, help="path of checkpoint") 173 | parser.add_argument( 174 | "--start-epoch", default=0, type=int, metavar="N", help="start epoch" 175 | ) 176 | parser.add_argument( 177 | "--cache-dataset", 178 | dest="cache_dataset", 179 | help="Cache the datasets for quicker initialization. It also serializes the transforms", 180 | action="store_true", 181 | ) 182 | parser.add_argument( 183 | "--sync-bn", 184 | dest="sync_bn", 185 | help="Use sync batch norm", 186 | action="store_true", 187 | ) 188 | parser.add_argument( 189 | "--test-only", 190 | dest="test_only", 191 | help="Only test the model", 192 | action="store_true", 193 | ) 194 | parser.add_argument( 195 | "--auto-augment", 196 | default=None, 197 | type=str, 198 | help="auto augment policy (default: None)", 199 | ) 200 | parser.add_argument( 201 | "--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy" 202 | ) 203 | parser.add_argument( 204 | "--augmix-severity", default=3, type=int, help="severity of augmix policy" 205 | ) 206 | parser.add_argument( 207 | "--random-erase", 208 | default=0.0, 209 | type=float, 210 | help="random erasing probability (default: 0.0)", 211 | ) 212 | 213 | # Mixed precision training parameters 214 | parser.add_argument( 215 | "--amp", 216 | action="store_true", 217 | help="Use torch.cuda.amp for mixed precision training", 218 | ) 219 | 220 | # distributed training parameters 221 | parser.add_argument( 222 | "--world-size", default=1, type=int, help="number of distributed processes" 223 | ) 224 | parser.add_argument( 225 | "--dist-url", 226 | default="env://", 227 | type=str, 228 | help="url used to set up distributed training", 229 | ) 230 | parser.add_argument( 231 | "--model-ema", 232 | action="store_true", 233 | help="enable tracking Exponential Moving Average of model parameters", 234 | ) 235 | parser.add_argument( 236 | "--model-ema-steps", 237 | type=int, 238 | default=32, 239 | help="the number of iterations that controls how often to update the EMA model (default: 32)", 240 | ) 241 | parser.add_argument( 242 | "--model-ema-decay", 243 | type=float, 244 | default=0.99998, 245 | help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)", 246 | ) 247 | parser.add_argument( 248 | "--use-deterministic-algorithms", 249 | action="store_true", 250 | help="Forces the use of deterministic algorithms only.", 251 | ) 252 | parser.add_argument( 253 | "--interpolation", 254 | default="bilinear", 255 | type=str, 256 | help="the interpolation method (default: bilinear)", 257 | ) 258 | parser.add_argument( 259 | "--val-resize-size", 260 | default=256, 261 | type=int, 262 | help="the resize size used for validation (default: 256)", 263 | ) 264 | parser.add_argument( 265 | "--val-crop-size", 266 | default=224, 267 | type=int, 268 | help="the central crop size used for validation (default: 224)", 269 | ) 270 | parser.add_argument( 271 | "--train-crop-size", 272 | default=224, 273 | type=int, 274 | help="the random crop size used for training (default: 224)", 275 | ) 276 | parser.add_argument( 277 | "--clip-grad-norm", 278 | default=None, 279 | 
type=float, 280 | help="the maximum gradient norm (default None)", 281 | ) 282 | parser.add_argument( 283 | "--ra-sampler", 284 | action="store_true", 285 | help="whether to use Repeated Augmentation in training", 286 | ) 287 | parser.add_argument( 288 | "--ra-reps", 289 | default=3, 290 | type=int, 291 | help="number of repetitions for Repeated Augmentation (default: 3)", 292 | ) 293 | parser.add_argument( 294 | "--weights", 295 | default="IMAGENET1K_V1", 296 | type=str, 297 | help="the weights enum name to load", 298 | ) 299 | 300 | # RANDOM SEARCH AND TUNING METHOD PARAMETERS 301 | parser.add_argument( 302 | "--tuning_method", 303 | default="fullft", 304 | type=str, 305 | help="Type of fine-tuning method to use", 306 | ) 307 | parser.add_argument("--masking_vector_idx", type=int, default=None) 308 | 309 | parser.add_argument( 310 | "--masking_vector", 311 | metavar="N", 312 | type=float, 313 | nargs="+", 314 | help="Elements of the masking vector", 315 | ) 316 | 317 | parser.add_argument("--subnetwork_mask_name", type=str, default=None) 318 | 319 | parser.add_argument("--mask_path", type=str, default=None) 320 | 321 | parser.add_argument( 322 | "--exp_vector_path", 323 | type=str, 324 | ) 325 | 326 | # MASK GENERATION PARAMETERS 327 | parser.add_argument( 328 | "--mask_gen_method", 329 | default="random", 330 | type=str, 331 | ) 332 | parser.add_argument( 333 | "--sigma", 334 | default=0.1, 335 | type=float, 336 | ) 337 | parser.add_argument( 338 | "--use_adaptive_threshold", 339 | action="store_true", 340 | help="Use adaptive thresholding for masking", 341 | ) 342 | parser.add_argument( 343 | "--thr_ema_decay", 344 | default=0.99, 345 | type=float, 346 | ) 347 | parser.add_argument( 348 | "--use_gumbel_sigmoid", 349 | action="store_true", 350 | help="Use sigmoid function to the weight mask", 351 | ) 352 | 353 | ## LOGGING AND MISC PARAMETERS 354 | parser.add_argument( 355 | "--wandb_logging", 356 | action="store_true", 357 | help="To enable/ disable wandb logging.", 358 | ) 359 | parser.add_argument( 360 | "--ckpt_dir", type=str, default=None, help="Path to the checkpoint directory" 361 | ) 362 | parser.add_argument( 363 | "--disable_checkpointing", 364 | action="store_true", 365 | help="Disable saving of model checkopints", 366 | ) 367 | parser.add_argument( 368 | "--disable_training", 369 | action="store_true", 370 | help="To disable/ skip the training process.", 371 | ) 372 | parser.add_argument( 373 | "--disable_plotting", 374 | action="store_true", 375 | help="To disable/ skip the plotting.", 376 | ) 377 | parser.add_argument( 378 | "--dev_mode", 379 | action="store_true", 380 | help="Dev mode disables plotting, checkpointing, etc", 381 | ) 382 | 383 | # FAIRNESS Arguements 384 | parser.add_argument( 385 | "--sens_attribute", 386 | type=str, 387 | default=None, 388 | help="Sensitive attribute to be used for fairness", 389 | ) 390 | parser.add_argument( 391 | "--age_type", type=str, default="binary", choices=["binary", "multi"] 392 | ) 393 | parser.add_argument( 394 | "--skin_type", type=str, default="binary", choices=["binary", "multi"] 395 | ) 396 | parser.add_argument("--use_metric", type=str, default="auc", choices=["acc", "auc"]) 397 | 398 | # HPARAM OPT (HPO) Arguements 399 | parser.add_argument( 400 | "--objective_metric", 401 | type=str, 402 | default="min_acc", 403 | choices=[ 404 | "min_acc", 405 | "min_auc", 406 | "acc_diff", 407 | "auc_diff", 408 | "max_loss", 409 | "overall_acc", 410 | "overall_auc", 411 | ], 412 | ) 413 | parser.add_argument("--num_trials", type=int, 
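# Number of Optuna trials for the Stage-1 mask search (see search_mask.py and the README).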
default=5) 414 | parser.add_argument( 415 | "--pruner", 416 | type=str, 417 | default="SuccessiveHalving", 418 | choices=["SuccessiveHalving", "MedianPruner", "Hyperband"], 419 | ) 420 | parser.add_argument( 421 | "--disable_storage", 422 | action="store_true", 423 | help="Disable creating a storage DB for the experiment", 424 | ) 425 | parser.add_argument( 426 | "--use_multi_objective", 427 | help="Use multi-objective optimization for HPO", 428 | action="store_true", 429 | ) 430 | 431 | # FAIRPRUNE ARGUEMENTS 432 | parser.add_argument( 433 | "--pruning_ratio", 434 | default=0.35, 435 | type=float, 436 | help="Pruning ratio in FairPrune", 437 | ) 438 | parser.add_argument( 439 | "--b_param", 440 | default=0.33, 441 | type=float, 442 | help="Beta Parameter in FairPrune", 443 | ) 444 | 445 | parser.add_argument( 446 | "--cal_equiodds", 447 | action="store_true", 448 | help="Calculate Equalized odds and DPD", 449 | ) 450 | 451 | # FSCL ARGUEMENTS 452 | parser.add_argument( 453 | "--fscl", 454 | action="store_true", 455 | help="To perform FSCL", 456 | ) 457 | parser.add_argument( 458 | "--train_encoder_lr", 459 | default=0.1, 460 | type=float, 461 | help="Learning rate for training the encoder in FSCL", 462 | ) 463 | parser.add_argument( 464 | "--train_classifier_lr", 465 | default=0.1, 466 | type=float, 467 | help="Learning rate for training the classifier in FSCL", 468 | ) 469 | parser.add_argument( 470 | "--fscl_eval_only", 471 | action="store_true", 472 | help="To perform FSCL eval using the trained model", 473 | ) 474 | parser.add_argument( 475 | "--temperature", 476 | default=0.1, 477 | type=float, 478 | help="Temperature Parameter in FSCL", 479 | ) 480 | parser.add_argument( 481 | "--contrast_mode", 482 | default="all", 483 | type=str, 484 | help="Contrast Mode in FSCL", 485 | ) 486 | parser.add_argument( 487 | "--base_temperature", 488 | default=0.1, 489 | type=float, 490 | help="Temperature Parameter in FSCL", 491 | ) 492 | parser.add_argument("--group_norm", type=int, default=0, help="group normalization") 493 | 494 | parser.add_argument( 495 | "--method", 496 | type=str, 497 | default="FSCL", 498 | choices=["FSCL", "FSCL*", "SupCon", "SimCLR"], 499 | help="choose method", 500 | ) 501 | 502 | # HAM10000 LABELS 503 | parser.add_argument( 504 | "--label_type", 505 | default="binary", 506 | type=str, 507 | help="Binary/ Multi labels to be used", 508 | ) 509 | 510 | return parser 511 | -------------------------------------------------------------------------------- /finetune_with_mask.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["TORCH_HOME"] = os.path.dirname(os.getcwd()) 4 | 5 | import datetime 6 | import random 7 | import re 8 | import time 9 | import json 10 | import warnings 11 | import timm 12 | import pandas as pd 13 | import numpy as np 14 | from utilities import presets 15 | import torch 16 | import torch.utils.data 17 | import torchvision 18 | import utilities.transforms 19 | from utilities.utils import * 20 | from utilities.training_utils import * 21 | from parse_args import * 22 | from utilities.sampler import RASampler 23 | from torch import nn 24 | from torch.utils.data.dataloader import default_collate 25 | from torchvision.transforms.functional import InterpolationMode 26 | import yaml 27 | from pprint import pprint 28 | 29 | 30 | def create_results_df(args): 31 | test_results_df = None 32 | 33 | if args.sens_attribute == "gender": 34 | test_results_df = pd.DataFrame( 35 | columns=[ 36 | "Tuning 
Method", 37 | "Train Percent", 38 | "LR", 39 | "Test AUC Overall", 40 | "Test AUC Male", 41 | "Test AUC Female", 42 | "Test AUC Difference", 43 | "EquiOdd_diff", 44 | "EquiOdd_ratio", 45 | "DPD", 46 | "DPR", 47 | "Mask Path" 48 | ] 49 | ) 50 | 51 | elif ( 52 | args.sens_attribute == "skin_type" 53 | or args.sens_attribute == "age" 54 | or args.sens_attribute == "race" 55 | or args.sens_attribute == "age_sex" 56 | ): 57 | cols = [ 58 | "Tuning Method", 59 | "Train Percent", 60 | "LR", 61 | "Test AUC Overall", 62 | "Test AUC (Best)", 63 | "Test AUC (Worst)", 64 | "Test AUC Difference", 65 | "EquiOdd_diff", 66 | "EquiOdd_ratio", 67 | "DPD", 68 | "DPR", 69 | "Mask Path", 70 | ] 71 | 72 | test_results_df = pd.DataFrame(columns=cols) 73 | else: 74 | raise NotImplementedError("Sensitive attribute not implemented") 75 | 76 | return test_results_df 77 | 78 | 79 | def main(args): 80 | assert args.sens_attribute is not None, "Sensitive attribute not provided" 81 | 82 | os.makedirs(args.fig_savepath, exist_ok=True) 83 | 84 | # Making directory for saving checkpoints 85 | if args.output_dir: 86 | utils.mkdir(args.output_dir) 87 | utils.mkdir(os.path.join(args.output_dir, "checkpoints")) 88 | 89 | try: 90 | test_results_df = pd.read_csv( 91 | os.path.join(args.output_dir, args.test_results_df) 92 | ) 93 | print("Reading existing results dataframe") 94 | except: 95 | print("Creating new results dataframe") 96 | test_results_df = create_results_df(args) 97 | 98 | utils.init_distributed_mode(args) 99 | print(args) 100 | 101 | device = torch.device(args.device) 102 | 103 | if args.use_deterministic_algorithms: 104 | torch.backends.cudnn.benchmark = False 105 | torch.use_deterministic_algorithms(True) 106 | else: 107 | torch.backends.cudnn.benchmark = True 108 | 109 | ## CREATING DATASETS AND DATALOADERS 110 | 111 | with open("config.yaml") as file: 112 | yaml_data = yaml.safe_load(file) 113 | 114 | ( 115 | dataset, 116 | dataset_val, 117 | dataset_test, 118 | train_sampler, 119 | val_sampler, 120 | test_sampler, 121 | ) = get_fairness_data(args, yaml_data) 122 | 123 | args.num_classes = len(dataset.classes) 124 | print("DATASET: ", args.dataset) 125 | print("Size of training dataset: ", len(dataset)) 126 | print("Size of validation dataset: ", len(dataset_val)) 127 | print("Size of test dataset: ", len(dataset_test)) 128 | print("Number of classes: ", args.num_classes) 129 | 130 | collate_fn = None 131 | mixup_transforms = get_mixup_transforms(args) 132 | 133 | if mixup_transforms: 134 | mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms) 135 | 136 | def collate_fn(batch): 137 | return mixupcutmix(*default_collate(batch)) 138 | 139 | if args.dataset == "papila": 140 | drop_last = False 141 | else: 142 | drop_last = True 143 | 144 | data_loader = torch.utils.data.DataLoader( 145 | dataset, 146 | batch_size=args.batch_size, 147 | sampler=train_sampler, 148 | num_workers=args.workers, 149 | pin_memory=True, 150 | collate_fn=collate_fn, 151 | drop_last=drop_last, 152 | ) 153 | data_loader_val = torch.utils.data.DataLoader( 154 | dataset_val, 155 | batch_size=args.batch_size, 156 | sampler=val_sampler, 157 | num_workers=args.workers, 158 | pin_memory=True, 159 | # drop_last=True 160 | ) 161 | data_loader_test = torch.utils.data.DataLoader( 162 | dataset_test, 163 | batch_size=args.batch_size, 164 | sampler=test_sampler, 165 | num_workers=args.workers, 166 | pin_memory=True, 167 | drop_last=False, 168 | ) 169 | 170 | # CREATING THE MODEL 171 | print("TUNING METHOD: ", args.tuning_method) 172 | 
print("Creating model") 173 | 174 | if args.tuning_method == "train_from_scratch": 175 | model = utils.get_timm_model( 176 | args.model, num_classes=args.num_classes, pretrained=False 177 | ) 178 | else: 179 | model = utils.get_timm_model(args.model, num_classes=args.num_classes) 180 | 181 | # Calculate the sum of model parameters 182 | total_params = sum([p.sum() for p in model.parameters()]) 183 | print("Sum of parameters: ", total_params) 184 | 185 | base_model = model 186 | 187 | if args.tuning_method == "fullft" or args.tuning_method == "train_from_scratch": 188 | pass 189 | elif args.tuning_method == "linear_readout": 190 | utils.disable_module(model) 191 | utils.enable_module(model.head) 192 | else: 193 | assert args.mask_path is not None 194 | mask = np.load(args.mask_path) 195 | mask = get_masked_model(model, args.tuning_method, mask=mask) 196 | print("LOADED MASK: ", mask) 197 | 198 | if np.all(np.array(mask) == 1): 199 | # If the mask contains all ones 200 | args.tuning_method = "Vanilla_" + args.tuning_method 201 | print( 202 | "Mask contains all ones. Changing tuning method to: ", 203 | args.tuning_method, 204 | ) 205 | 206 | # Check Tunable Params 207 | trainable_params, all_param = utils.check_tunable_params(model, True) 208 | trainable_percentage = 100 * trainable_params / all_param 209 | 210 | model.to(device) 211 | 212 | if args.distributed and args.sync_bn: 213 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 214 | 215 | # Computing class weights for a weighted loss in case of highly imbalanced datasets 216 | if args.compute_cw: 217 | if args.dataset == "ol3i": 218 | weight = torch.tensor([0.04361966711306677, 0.9563803328869332]) 219 | elif args.dataset == "papila": 220 | weight = torch.tensor([0.20714285714285716, 0.7928571428571428]) 221 | weight = weight.to(device) 222 | print("Using CW Loss with weights: ", weight) 223 | else: 224 | weight = None 225 | 226 | criterion = nn.CrossEntropyLoss( 227 | label_smoothing=args.label_smoothing, reduction="none", weight=weight 228 | ) 229 | 230 | ece_criterion = utils.ECELoss() 231 | 232 | custom_keys_weight_decay = [] 233 | if args.bias_weight_decay is not None: 234 | custom_keys_weight_decay.append(("bias", args.bias_weight_decay)) 235 | if args.transformer_embedding_decay is not None: 236 | for key in [ 237 | "class_token", 238 | "position_embedding", 239 | "relative_position_bias_table", 240 | ]: 241 | custom_keys_weight_decay.append((key, args.transformer_embedding_decay)) 242 | 243 | parameters = utils.set_weight_decay( 244 | model, 245 | args.weight_decay, 246 | norm_weight_decay=args.norm_weight_decay, 247 | custom_keys_weight_decay=custom_keys_weight_decay 248 | if len(custom_keys_weight_decay) > 0 249 | else None, 250 | ) 251 | 252 | # Optimizer 253 | optimizer = get_optimizer(args, parameters) 254 | scaler = torch.cuda.amp.GradScaler() if args.amp else None 255 | 256 | # LR Scheduler 257 | lr_scheduler = get_lr_scheduler(args, optimizer) 258 | 259 | model_without_ddp = model 260 | if args.distributed: 261 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 262 | model_without_ddp = model.module 263 | 264 | # From training_utils.py 265 | model_ema = get_model_ema(model_without_ddp, args) 266 | print("model ema", model_ema) 267 | 268 | if args.resume: 269 | checkpoint = torch.load(args.resume, map_location="cpu") 270 | model_without_ddp.load_state_dict(checkpoint["model"]) 271 | if not args.test_only: 272 | optimizer.load_state_dict(checkpoint["optimizer"]) 273 | 
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) 274 | args.start_epoch = checkpoint["epoch"] + 1 275 | if model_ema: 276 | model_ema.load_state_dict(checkpoint["model_ema"]) 277 | if scaler: 278 | scaler.load_state_dict(checkpoint["scaler"]) 279 | 280 | if args.disable_training: 281 | print("Training Process Skipped") 282 | else: 283 | print("Start training") 284 | start_time = time.time() 285 | for epoch in range(args.start_epoch, args.epochs): 286 | if args.distributed: 287 | train_sampler.set_epoch(epoch) 288 | ( 289 | train_acc, 290 | train_best_acc, 291 | train_worst_acc, 292 | train_auc, 293 | train_best_auc, 294 | train_worst_auc, 295 | ) = train_one_epoch_fairness( 296 | model, 297 | criterion, 298 | ece_criterion, 299 | optimizer, 300 | data_loader, 301 | device, 302 | epoch, 303 | args, 304 | model_ema, 305 | scaler, 306 | ) 307 | lr_scheduler.step() 308 | 309 | if args.sens_attribute == "gender": 310 | if args.cal_equiodds: 311 | ( 312 | val_acc, 313 | val_male_acc, 314 | val_female_acc, 315 | val_auc, 316 | val_male_auc, 317 | val_female_auc, 318 | val_loss, 319 | val_max_loss, 320 | equiodds_diff, 321 | equiodds_ratio, 322 | dpd, 323 | dpr, 324 | ) = evaluate_fairness_gender( 325 | model, 326 | criterion, 327 | ece_criterion, 328 | data_loader_val, 329 | args=args, 330 | device=device, 331 | ) 332 | else: 333 | ( 334 | val_acc, 335 | val_male_acc, 336 | val_female_acc, 337 | val_auc, 338 | val_male_auc, 339 | val_female_auc, 340 | val_loss, 341 | val_max_loss, 342 | ) = evaluate_fairness_gender( 343 | model, 344 | criterion, 345 | ece_criterion, 346 | data_loader_val, 347 | args=args, 348 | device=device, 349 | ) 350 | 351 | best_val_acc = max(val_male_acc, val_female_acc) 352 | worst_val_acc = min(val_male_acc, val_female_acc) 353 | 354 | best_val_auc = max(val_male_auc, val_female_auc) 355 | worst_val_auc = min(val_male_auc, val_female_auc) 356 | 357 | print( 358 | "Val Acc: {:.2f}, Val Male Acc {:.2f}, Val Female Acc {:.2f}, Val Loss: {:.2f}, Val MAX LOSS: {:.2f}".format( 359 | val_acc, 360 | val_male_acc, 361 | val_female_acc, 362 | torch.mean(val_loss), 363 | val_max_loss, 364 | ) 365 | ) 366 | print( 367 | "Val AUC: {:.2f}, Val Male AUC {:.2f}, Val Female AUC {:.2f}".format( 368 | val_auc, 369 | val_male_auc, 370 | val_female_auc, 371 | ) 372 | ) 373 | 374 | elif args.sens_attribute == "skin_type": 375 | assert args.skin_type == "binary" 376 | assert args.cal_equiodds is not None 377 | 378 | ( 379 | val_acc, 380 | val_acc_type0, 381 | val_acc_type1, 382 | val_auc, 383 | val_auc_type0, 384 | val_auc_type1, 385 | val_loss, 386 | val_max_loss, 387 | equiodds_diff, 388 | equiodds_ratio, 389 | dpd, 390 | dpr, 391 | ) = evaluate_fairness_skin_type_binary( 392 | model, 393 | criterion, 394 | ece_criterion, 395 | data_loader_val, 396 | args=args, 397 | device=device, 398 | ) 399 | 400 | best_val_acc = max(val_acc_type0, val_acc_type1) 401 | worst_val_acc = min(val_acc_type0, val_acc_type1) 402 | 403 | best_val_auc = max(val_auc_type0, val_auc_type1) 404 | worst_val_auc = min(val_auc_type0, val_auc_type1) 405 | 406 | print( 407 | "Val Acc: {:.2f}, Val Type 0 Acc: {:.2f}, Val Type 1 Acc: {:.2f}, Val Loss: {:.2f}, Val MAX LOSS: {:.2f}".format( 408 | val_acc, 409 | val_acc_type0, 410 | val_acc_type1, 411 | torch.mean(val_loss), 412 | val_max_loss, 413 | ) 414 | ) 415 | print( 416 | "Val AUC: {:.2f}, Val Type 0 AUC: {:.2f}, Val Type 1 AUC: {:.2f}".format( 417 | val_auc, 418 | val_auc_type0, 419 | val_auc_type1, 420 | ) 421 | ) 422 | print("\n") 423 | 424 | elif 
args.sens_attribute == "age": 425 | assert args.age_type == "binary" 426 | assert args.cal_equiodds is not None 427 | 428 | ( 429 | val_acc, 430 | acc_age0_avg, 431 | acc_age1_avg, 432 | val_auc, 433 | auc_age0_avg, 434 | auc_age1_avg, 435 | val_loss, 436 | val_max_loss, 437 | equiodds_diff, 438 | equiodds_ratio, 439 | dpd, 440 | dpr, 441 | ) = evaluate_fairness_age_binary( 442 | model, 443 | criterion, 444 | ece_criterion, 445 | data_loader_val, 446 | args=args, 447 | device=device, 448 | ) 449 | 450 | best_val_acc = max(acc_age0_avg, acc_age1_avg) 451 | worst_val_acc = min(acc_age0_avg, acc_age1_avg) 452 | 453 | best_val_auc = max(auc_age0_avg, auc_age1_avg) 454 | worst_val_auc = min(auc_age0_avg, auc_age1_avg) 455 | 456 | print( 457 | "Val Acc: {:.2f}, Val Age Group0 Acc: {:.2f}, Val Age Group1 Acc: {:.2f}, Val Loss: {:.2f}, Val MAX LOSS: {:.2f}".format( 458 | val_acc, 459 | acc_age0_avg, 460 | acc_age1_avg, 461 | torch.mean(val_loss), 462 | val_max_loss, 463 | ) 464 | ) 465 | print( 466 | "Val AUC: {:.2f}, Val Age Group0 AUC: {:.2f}, Val Age Group1 AUC: {:.2f}".format( 467 | val_auc, 468 | auc_age0_avg, 469 | auc_age1_avg, 470 | ) 471 | ) 472 | print("\n") 473 | 474 | elif args.sens_attribute == "race": 475 | assert args.cal_equiodds is not None 476 | ( 477 | val_acc, 478 | acc_race0_avg, 479 | acc_race1_avg, 480 | val_auc, 481 | auc_race0_avg, 482 | auc_race1_avg, 483 | val_loss, 484 | val_max_loss, 485 | equiodds_diff, 486 | equiodds_ratio, 487 | dpd, 488 | dpr, 489 | ) = evaluate_fairness_race_binary( 490 | model, 491 | criterion, 492 | ece_criterion, 493 | data_loader_val, 494 | args=args, 495 | device=device, 496 | ) 497 | 498 | best_val_acc = max(acc_race0_avg, acc_race1_avg) 499 | worst_val_acc = min(acc_race0_avg, acc_race1_avg) 500 | 501 | best_val_auc = max(auc_race0_avg, auc_race1_avg) 502 | worst_val_auc = min(auc_race0_avg, auc_race1_avg) 503 | 504 | print( 505 | "Val Acc: {:.2f}, Val Race Group0 Acc: {:.2f}, Val Race Group1 Acc: {:.2f}, Val Loss: {:.2f}, Val MAX LOSS: {:.2f}".format( 506 | val_acc, 507 | acc_race0_avg, 508 | acc_race1_avg, 509 | torch.mean(val_loss), 510 | val_max_loss, 511 | ) 512 | ) 513 | print( 514 | "Val AUC: {:.2f}, Val Race Group0 AUC: {:.2f}, Val Race Group1 AUC: {:.2f}".format( 515 | val_auc, 516 | auc_race0_avg, 517 | auc_race1_avg, 518 | ) 519 | ) 520 | print("\n") 521 | 522 | elif args.sens_attribute == "age_sex": 523 | assert args.dataset == "chexpert" 524 | assert args.cal_equiodds is not None 525 | ( 526 | val_acc, 527 | val_acc_type0, 528 | val_acc_type1, 529 | val_acc_type2, 530 | val_acc_type3, 531 | val_auc, 532 | val_auc_type0, 533 | val_auc_type1, 534 | val_auc_type2, 535 | val_auc_type3, 536 | val_loss, 537 | val_max_loss, 538 | equiodds_diff, 539 | equiodds_ratio, 540 | dpd, 541 | dpr, 542 | ) = evaluate_fairness_age_sex( 543 | model, 544 | criterion, 545 | ece_criterion, 546 | data_loader_val, 547 | args=args, 548 | device=device, 549 | ) 550 | 551 | best_val_acc = max( 552 | val_acc_type0, val_acc_type1, val_acc_type2, val_acc_type3 553 | ) 554 | worst_val_acc = min( 555 | val_acc_type0, val_acc_type1, val_acc_type2, val_acc_type3 556 | ) 557 | 558 | best_val_auc = max( 559 | val_auc_type0, val_auc_type1, val_auc_type2, val_auc_type3 560 | ) 561 | worst_val_auc = min( 562 | val_auc_type0, val_auc_type1, val_auc_type2, val_auc_type3 563 | ) 564 | 565 | print( 566 | "Val Acc: {:.2f}, Val Young Male Acc: {:.2f}, Val Old Male Acc: {:.2f}, Val Young Female Acc: {:.2f}, Val Old Female Acc: {:.2f}, Val Loss: {:.2f}, Val MAX LOSS: 
{:.2f}".format( 567 | val_acc, 568 | val_acc_type0, 569 | val_acc_type1, 570 | val_acc_type2, 571 | val_acc_type3, 572 | torch.mean(val_loss), 573 | val_max_loss, 574 | ) 575 | ) 576 | print( 577 | "Val AUC: {:.2f}, Val Young Male AUC: {:.2f}, Val Old Male AUC: {:.2f}, Val Young Female AUC: {:.2f}, Val Old Female AUC: {:.2f}".format( 578 | val_auc, 579 | val_auc_type0, 580 | val_auc_type1, 581 | val_auc_type2, 582 | val_auc_type3, 583 | ) 584 | ) 585 | print("\n") 586 | 587 | else: 588 | raise NotImplementedError("Sensitive attribute not implemented") 589 | 590 | if args.output_dir: 591 | checkpoint = { 592 | "model": model_without_ddp.state_dict(), 593 | "optimizer": optimizer.state_dict(), 594 | "lr_scheduler": lr_scheduler.state_dict(), 595 | "epoch": epoch, 596 | "args": args, 597 | } 598 | if model_ema: 599 | checkpoint["model_ema"] = model_ema.state_dict() 600 | if scaler: 601 | checkpoint["scaler"] = scaler.state_dict() 602 | 603 | ckpt_path = os.path.join( 604 | args.output_dir, 605 | "checkpoints", 606 | "checkpoint_" + args.tuning_method + ".pth", 607 | ) 608 | if not args.disable_checkpointing: 609 | utils.save_on_master(checkpoint, ckpt_path) 610 | 611 | total_time = time.time() - start_time 612 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 613 | print("Training time {}".format(total_time_str)) 614 | 615 | # Obtaining the performance on test set 616 | print("Obtaining the performance on test set") 617 | if args.sens_attribute == "gender": 618 | assert args.cal_equiodds is not None 619 | ( 620 | test_acc, 621 | test_male_acc, 622 | test_female_acc, 623 | test_auc, 624 | test_male_auc, 625 | test_female_auc, 626 | test_loss, 627 | test_max_loss, 628 | equiodds_diff, 629 | equiodds_ratio, 630 | dpd, 631 | dpr, 632 | ) = evaluate_fairness_gender( 633 | model, 634 | criterion, 635 | ece_criterion, 636 | data_loader_test, 637 | args=args, 638 | device=device, 639 | ) 640 | 641 | print("\n") 642 | print("Overall Test Accuracy: ", test_acc) 643 | print("Test Male Accuracy: ", test_male_acc) 644 | print("Test Female Accuracy: ", test_female_acc) 645 | print("\n") 646 | print("Overall Test AUC: ", test_auc) 647 | print("Test Male AUC: ", test_male_auc) 648 | print("Test Female AUC: ", test_female_auc) 649 | if args.cal_equiodds: 650 | print("\n") 651 | print("EquiOdds Difference: ", equiodds_diff) 652 | print("EquiOdds Ratio: ", equiodds_ratio) 653 | print("DPD: ", dpd) 654 | print("DPR: ", dpr) 655 | 656 | elif args.sens_attribute == "skin_type": 657 | assert args.skin_type == "binary" 658 | assert args.cal_equiodds is not None 659 | ( 660 | test_acc, 661 | test_acc_type0, 662 | test_acc_type1, 663 | test_auc, 664 | test_auc_type0, 665 | test_auc_type1, 666 | test_loss, 667 | test_max_loss, 668 | equiodds_diff, 669 | equiodds_ratio, 670 | dpd, 671 | dpr, 672 | ) = evaluate_fairness_skin_type_binary( 673 | model, 674 | criterion, 675 | ece_criterion, 676 | data_loader_test, 677 | args=args, 678 | device=device, 679 | ) 680 | 681 | print("\n") 682 | print("Overall Test accuracy: ", test_acc) 683 | print("Test Type 0 Accuracy: ", test_acc_type0) 684 | print("Test Type 1 Accuracy: ", test_acc_type1) 685 | print("\n") 686 | print("Overall Test AUC: ", test_auc) 687 | print("Test Type 0 AUC: ", test_auc_type0) 688 | print("Test Type 1 AUC: ", test_auc_type1) 689 | if args.cal_equiodds: 690 | print("\n") 691 | print("EquiOdds Difference: ", equiodds_diff) 692 | print("EquiOdds Ratio: ", equiodds_ratio) 693 | print("DPD: ", dpd) 694 | print("DPR: ", dpr) 695 | 696 | elif 
args.sens_attribute == "age": 697 | assert args.age_type == "binary" 698 | assert args.cal_equiodds is not None 699 | 700 | ( 701 | test_acc, 702 | test_acc_type0, 703 | test_acc_type1, 704 | test_auc, 705 | test_auc_type0, 706 | test_auc_type1, 707 | test_loss, 708 | test_max_loss, 709 | equiodds_diff, 710 | equiodds_ratio, 711 | dpd, 712 | dpr, 713 | ) = evaluate_fairness_age_binary( 714 | model, 715 | criterion, 716 | ece_criterion, 717 | data_loader_test, 718 | args=args, 719 | device=device, 720 | ) 721 | 722 | print("\n") 723 | print("Overall Test accuracy: ", test_acc) 724 | print("Test Age Group 0 Accuracy: ", test_acc_type0) 725 | print("Test Age Group 1 Accuracy: ", test_acc_type1) 726 | print("\n") 727 | print("Overall Test AUC: ", test_auc) 728 | print("Test Age Group 0 AUC: ", test_auc_type0) 729 | print("Test Age Group 1 AUC: ", test_auc_type1) 730 | if args.cal_equiodds: 731 | print("\n") 732 | print("EquiOdds Difference: ", equiodds_diff) 733 | print("EquiOdds Ratio: ", equiodds_ratio) 734 | print("DPD: ", dpd) 735 | print("DPR: ", dpr) 736 | 737 | elif args.sens_attribute == "race": 738 | ( 739 | test_acc, 740 | test_acc_type0, 741 | test_acc_type1, 742 | test_auc, 743 | test_auc_type0, 744 | test_auc_type1, 745 | test_loss, 746 | test_max_loss, 747 | equiodds_diff, 748 | equiodds_ratio, 749 | dpd, 750 | dpr, 751 | ) = evaluate_fairness_race_binary( 752 | model, 753 | criterion, 754 | ece_criterion, 755 | data_loader_test, 756 | args=args, 757 | device=device, 758 | ) 759 | 760 | print("\n") 761 | print("Overall Test accuracy: ", test_acc) 762 | print("Test Race Group 0 Accuracy: ", test_acc_type0) 763 | print("Test Race Group 1 Accuracy: ", test_acc_type1) 764 | print("\n") 765 | print("Overall Test AUC: ", test_auc) 766 | print("Test Race Group 0 AUC: ", test_auc_type0) 767 | print("Test Race Group 1 AUC: ", test_auc_type1) 768 | if args.cal_equiodds: 769 | print("\n") 770 | print("EquiOdds Difference: ", equiodds_diff) 771 | print("EquiOdds Ratio: ", equiodds_ratio) 772 | print("DPD: ", dpd) 773 | print("DPR: ", dpr) 774 | 775 | elif args.sens_attribute == "age_sex": 776 | ( 777 | test_acc, 778 | test_acc_type0, 779 | test_acc_type1, 780 | test_acc_type2, 781 | test_acc_type3, 782 | test_auc, 783 | test_auc_type0, 784 | test_auc_type1, 785 | test_auc_type2, 786 | test_auc_type3, 787 | test_loss, 788 | test_max_loss, 789 | equiodds_diff, 790 | equiodds_ratio, 791 | dpd, 792 | dpr, 793 | ) = evaluate_fairness_age_sex( 794 | model, 795 | criterion, 796 | ece_criterion, 797 | data_loader_test, 798 | args=args, 799 | device=device, 800 | ) 801 | 802 | print("\n") 803 | print("Overall Test accuracy: ", test_acc) 804 | print("Test Young Male Accuracy: ", test_acc_type0) 805 | print("Test Old Male Accuracy: ", test_acc_type1) 806 | print("Test Young Female Accuracy: ", test_acc_type2) 807 | print("Test Old Female Accuracy: ", test_acc_type3) 808 | print("\n") 809 | print("Overall Test AUC: ", test_auc) 810 | print("Test Young Male AUC: ", test_auc_type0) 811 | print("Test Old Male AUC: ", test_auc_type1) 812 | print("Test Young Female AUC: ", test_auc_type2) 813 | print("Test Old Female AUC: ", test_auc_type3) 814 | 815 | if args.cal_equiodds: 816 | print("\n") 817 | print("EquiOdds Difference: ", equiodds_diff) 818 | print("EquiOdds Ratio: ", equiodds_ratio) 819 | print("DPD: ", dpd) 820 | print("DPR: ", dpr) 821 | else: 822 | raise NotImplementedError("Sensitive Attribute not supported") 823 | 824 | print("Test loss: ", round(torch.mean(test_loss).item(), 3)) 825 | 
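    # A hedged sketch (not in the original script): because the criterion is
    # built with reduction="none", `test_loss` holds one loss value per test
    # sample, so tail statistics beyond the mean and the max are cheap to add.
    loss_p95 = torch.quantile(test_loss.float(), 0.95).item()
    print("Test 95th-percentile loss: ", round(loss_p95, 3))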
print("Test max loss: ", round(test_max_loss.item(), 3)) 826 | 827 | # Add these results to CSV 828 | # Here we are adding results on the test set 829 | 830 | if args.mask_path is not None: 831 | mask_path = args.mask_path.split("/")[-1] 832 | else: 833 | mask_path = "None" 834 | 835 | if args.sens_attribute == "gender": 836 | assert args.use_metric == "auc" 837 | 838 | new_row2 = [ 839 | args.tuning_method, 840 | round(trainable_percentage, 3), 841 | args.lr, 842 | test_auc, 843 | test_male_auc, 844 | test_female_auc, 845 | round(abs(test_male_auc - test_female_auc), 3), 846 | round(equiodds_diff, 3), 847 | round(equiodds_ratio, 3), 848 | round(dpd, 3), 849 | round(dpr, 3), 850 | mask_path, 851 | ] 852 | 853 | elif args.sens_attribute == "skin_type" or args.sens_attribute == "age" or args.sens_attribute == "race": 854 | # assert args.skin_type == "binary" 855 | assert args.use_metric == "auc" 856 | assert args.cal_equiodds is not None 857 | 858 | best_auc = max(test_auc_type0, test_auc_type1) 859 | worst_auc = min(test_auc_type0, test_auc_type1) 860 | 861 | new_row2 = [ 862 | args.tuning_method, 863 | round(trainable_percentage, 3), 864 | args.lr, 865 | test_auc, 866 | best_auc, 867 | worst_auc, 868 | round(abs(best_auc - worst_auc), 3), 869 | round(equiodds_diff, 3), 870 | round(equiodds_ratio, 3), 871 | round(dpd, 3), 872 | round(dpr, 3), 873 | mask_path, 874 | ] 875 | 876 | # elif args.sens_attribute == "age": 877 | # assert args.age_type == "binary" 878 | # assert args.use_metric == "auc" 879 | # assert args.cal_equiodds is not None 880 | 881 | # best_auc = max(test_auc_type0, test_auc_type1) 882 | # worst_auc = min(test_auc_type0, test_auc_type1) 883 | 884 | # new_row2 = [ 885 | # args.tuning_method, 886 | # round(trainable_percentage, 3), 887 | # args.lr, 888 | # test_auc, 889 | # best_auc, 890 | # worst_auc, 891 | # round(abs(best_auc - worst_auc), 3), 892 | # round(equiodds_diff, 3), 893 | # round(equiodds_ratio, 3), 894 | # round(dpd, 3), 895 | # round(dpr, 3), 896 | # mask_path, 897 | # ] 898 | 899 | # elif args.sens_attribute == "race": 900 | # assert args.use_metric == "auc" 901 | # assert args.cal_equiodds is not None 902 | 903 | # best_auc = max(test_auc_type0, test_auc_type1) 904 | # worst_auc = min(test_auc_type0, test_auc_type1) 905 | 906 | # new_row2 = [ 907 | # args.tuning_method, 908 | # round(trainable_percentage, 3), 909 | # args.lr, 910 | # test_auc, 911 | # best_auc, 912 | # worst_auc, 913 | # round(abs(best_auc - worst_auc), 3), 914 | # round(equiodds_diff, 3), 915 | # round(equiodds_ratio, 3), 916 | # round(dpd, 3), 917 | # round(dpr, 3), 918 | # mask_path, 919 | # ] 920 | 921 | elif args.sens_attribute == "age_sex": 922 | assert args.use_metric == "auc" 923 | assert args.cal_equiodds is not None 924 | 925 | best_auc = max( 926 | test_auc_type0, test_auc_type1, test_auc_type2, test_auc_type3 927 | ) 928 | worst_auc = min( 929 | test_auc_type0, test_auc_type1, test_auc_type2, test_auc_type3 930 | ) 931 | 932 | new_row2 = [ 933 | args.tuning_method, 934 | round(trainable_percentage, 3), 935 | args.lr, 936 | test_auc, 937 | best_auc, 938 | worst_auc, 939 | round(abs(best_auc - worst_auc), 3), 940 | round(equiodds_diff, 3), 941 | round(equiodds_ratio, 3), 942 | round(dpd, 3), 943 | round(dpr, 3), 944 | mask_path, 945 | ] 946 | 947 | else: 948 | raise NotImplementedError("Sensitive attribute not implemented") 949 | 950 | test_results_df.loc[len(test_results_df)] = new_row2 951 | 952 | print( 953 | "Saving test results df at: {}".format( 954 | 
os.path.join(args.output_dir, args.test_results_df) 955 | ) 956 | ) 957 | 958 | test_results_df.to_csv( 959 | os.path.join(args.output_dir, args.test_results_df), index=False 960 | ) 961 | 962 | 963 | if __name__ == "__main__": 964 | args = get_args_parser().parse_args() 965 | args.output_dir = os.path.join(os.getcwd(), args.model, args.dataset) 966 | 967 | if "auc" in args.objective_metric: 968 | args.use_metric = "auc" 969 | 970 | assert args.use_metric == "auc" 971 | assert args.cal_equiodds is not None 972 | 973 | args.test_results_df = ( 974 | "RESULTS_" + args.sens_attribute + "_" + args.objective_metric + ".csv" 975 | ) 976 | 977 | current_wd = os.getcwd() 978 | args.fig_savepath = os.path.join(args.output_dir, "plots/") 979 | 980 | args.train_fscl_classifier = False 981 | args.train_fscl_encoder = False 982 | 983 | main(args) 984 | -------------------------------------------------------------------------------- /search_mask.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["TORCH_HOME"] = os.path.dirname(os.getcwd()) 4 | 5 | import datetime 6 | import random 7 | import sys 8 | import re 9 | import time 10 | import warnings 11 | import timm 12 | import logging 13 | import pandas as pd 14 | import numpy as np 15 | import matplotlib.pyplot as plt 16 | from utilities import presets 17 | import torch 18 | import torch.utils.data 19 | import torchvision 20 | from utilities import transforms 21 | from utilities.utils import * 22 | from utilities.training_utils import * 23 | from parse_args import * 24 | from utilities.sampler import RASampler 25 | from torch import nn 26 | from torch.utils.data.dataloader import default_collate 27 | from torchvision.transforms.functional import InterpolationMode 28 | import yaml 29 | from pprint import pprint 30 | 31 | import optuna 32 | from optuna.trial import TrialState 33 | 34 | ############################################################################ 35 | 36 | 37 | def create_opt_mask(trial, args, num_blocks): 38 | mask_length = None 39 | # if(args.tuning_method == 'tune_full_block' or args.tuning_method == 'tune_attention_blocks_random'): 40 | if args.tuning_method in [ 41 | "tune_full_block", 42 | "tune_attention_blocks_random", 43 | "tune_layernorm_blocks_random", 44 | "tune_attention_layernorm", 45 | "tune_attention_mlp", 46 | "tune_layernorm_mlp", 47 | "tune_attention_layernorm_mlp", 48 | ]: 49 | mask_length = num_blocks 50 | elif args.tuning_method == "auto_peft1": 51 | mask_length = num_blocks * 3 52 | elif ( 53 | args.tuning_method == "tune_attention_params_random" 54 | or args.tuning_method == "auto_peft2" 55 | ): 56 | mask_length = num_blocks * 4 57 | elif args.tuning_method == "fullft": 58 | return None 59 | else: 60 | raise NotImplementedError 61 | 62 | mask = np.zeros(mask_length, dtype=np.int8) 63 | 64 | for i in range(mask_length): 65 | mask[i] = trial.suggest_int("Mask Idx {}".format(i), 0, 1) 66 | 67 | return mask 68 | 69 | 70 | def create_model(args): 71 | print("Creating model") 72 | print("TUNING METHOD: ", args.tuning_method) 73 | model = utils.get_timm_model(args.model, num_classes=args.num_classes) 74 | 75 | custom_keys_weight_decay = [] 76 | if args.bias_weight_decay is not None: 77 | custom_keys_weight_decay.append(("bias", args.bias_weight_decay)) 78 | if args.transformer_embedding_decay is not None: 79 | for key in [ 80 | "class_token", 81 | "position_embedding", 82 | "relative_position_bias_table", 83 | ]: 84 | custom_keys_weight_decay.append((key, 
args.transformer_embedding_decay)) 85 | 86 | parameters = utils.set_weight_decay( 87 | model, 88 | args.weight_decay, 89 | norm_weight_decay=args.norm_weight_decay, 90 | custom_keys_weight_decay=custom_keys_weight_decay 91 | if len(custom_keys_weight_decay) > 0 92 | else None, 93 | ) 94 | 95 | return model, parameters 96 | 97 | 98 | def create_results_df(args): 99 | test_results_df = None 100 | 101 | if args.sens_attribute == "gender": 102 | if args.use_metric == "acc": 103 | test_results_df = pd.DataFrame( 104 | columns=[ 105 | "Tuning Method", 106 | "Train Percent", 107 | "LR", 108 | "Test Acc Overall", 109 | "Test Acc Male", 110 | "Test Acc Female", 111 | "Test Acc Difference", 112 | ] 113 | ) 114 | elif args.use_metric == "auc": 115 | test_results_df = pd.DataFrame( 116 | columns=[ 117 | "Tuning Method", 118 | "Train Percent", 119 | "LR", 120 | "Test AUC Overall", 121 | "Test AUC Male", 122 | "Test AUC Female", 123 | "Test AUC Difference", 124 | ] 125 | ) 126 | elif ( 127 | args.sens_attribute == "skin_type" 128 | or args.sens_attribute == "age" 129 | or args.sens_attribute == "race" 130 | or args.sens_attribute == "age_sex" 131 | ): 132 | if args.use_metric == "acc": 133 | test_results_df = pd.DataFrame( 134 | columns=[ 135 | "Tuning Method", 136 | "Train Percent", 137 | "LR", 138 | "Test Acc Overall", 139 | "Test Acc (Best)", 140 | "Test Acc (Worst)", 141 | "Test Acc Difference", 142 | ] 143 | ) 144 | elif args.use_metric == "auc": 145 | test_results_df = pd.DataFrame( 146 | columns=[ 147 | "Tuning Method", 148 | "Train Percent", 149 | "LR", 150 | "Test AUC Overall", 151 | "Test AUC (Best)", 152 | "Test AUC (Worst)", 153 | "Test AUC Difference", 154 | ] 155 | ) 156 | else: 157 | raise NotImplementedError 158 | 159 | return test_results_df 160 | 161 | 162 | def define_dataloaders(args): 163 | with open("config.yaml") as file: 164 | yaml_data = yaml.safe_load(file) 165 | 166 | print("Creating dataset") 167 | ( 168 | dataset, 169 | dataset_val, 170 | dataset_test, 171 | train_sampler, 172 | val_sampler, 173 | test_sampler, 174 | ) = get_fairness_data(args, yaml_data) 175 | 176 | args.num_classes = len(dataset.classes) 177 | print("DATASET: ", args.dataset) 178 | print("Size of training dataset: ", len(dataset)) 179 | print("Size of validation dataset: ", len(dataset_val)) 180 | print("Size of test dataset: ", len(dataset_test)) 181 | print("Number of classes: ", args.num_classes) 182 | pprint(dataset.class_to_idx) 183 | 184 | collate_fn = None 185 | mixup_transforms = get_mixup_transforms(args) 186 | 187 | if mixup_transforms: 188 | mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms) 189 | 190 | def collate_fn(batch): 191 | return mixupcutmix(*default_collate(batch)) 192 | 193 | print("Creating data loaders") 194 | 195 | if args.dataset == "papila": 196 | drop_last = False 197 | else: 198 | drop_last = True 199 | 200 | data_loader = torch.utils.data.DataLoader( 201 | dataset, 202 | batch_size=args.batch_size, 203 | sampler=train_sampler, 204 | num_workers=args.workers, 205 | pin_memory=True, 206 | collate_fn=collate_fn, 207 | drop_last=drop_last, 208 | ) 209 | data_loader_val = torch.utils.data.DataLoader( 210 | dataset_val, 211 | batch_size=args.batch_size, 212 | sampler=val_sampler, 213 | num_workers=args.workers, 214 | pin_memory=True, 215 | # drop_last=True 216 | ) 217 | data_loader_test = torch.utils.data.DataLoader( 218 | dataset_test, 219 | batch_size=args.batch_size, 220 | sampler=test_sampler, 221 | num_workers=args.workers, 222 | pin_memory=True, 223 | 
drop_last=False, 224 | ) 225 | 226 | return data_loader, data_loader_val, data_loader_test 227 | 228 | 229 | def objective(trial): 230 | args = get_args_parser().parse_args() 231 | device = torch.device(args.device) 232 | 233 | args.train_fscl_classifier = False 234 | args.train_fscl_encoder = False 235 | 236 | if args.dev_mode: 237 | args.disable_plotting = True 238 | args.disable_checkpointing = True 239 | 240 | # try: 241 | # _temp_trainable_params_df = pd.read_csv('_temp_trainable_params_df.csv') 242 | # except: 243 | # _temp_trainable_params_df = pd.DataFrame(columns=['Trainable Params']) 244 | 245 | # Saving results to a dataframe 246 | results_df_savedir = os.path.join(args.model, args.dataset, "Optuna Results") 247 | if not os.path.exists(results_df_savedir): 248 | os.makedirs(results_df_savedir, exist_ok=True) 249 | results_df_name = ( 250 | "Fairness_Optuna_" 251 | + args.sens_attribute 252 | + "_" 253 | + args.tuning_method 254 | + "_" 255 | + args.model 256 | + "_" 257 | + args.objective_metric 258 | + ".csv" 259 | ) 260 | 261 | if "auc" in args.objective_metric: 262 | args.use_metric = "auc" 263 | 264 | try: 265 | results_df = pd.read_csv(os.path.join(results_df_savedir, results_df_name)) 266 | except: 267 | results_df = create_results_df(args) 268 | 269 | args.output_dir = os.path.join(os.getcwd(), args.model, args.dataset) 270 | 271 | args.distributed = False 272 | if args.use_deterministic_algorithms: 273 | torch.backends.cudnn.benchmark = False 274 | torch.use_deterministic_algorithms(True) 275 | else: 276 | torch.backends.cudnn.benchmark = True 277 | 278 | # Create datasets and dataloaders here 279 | data_loader, data_loader_val, data_loader_test = define_dataloaders(args) 280 | 281 | # Create the model here 282 | model, parameters = create_model(args) 283 | 284 | mask = create_opt_mask(trial, args, len(model.blocks)) 285 | print("Mask: ", mask) 286 | print("\n") 287 | 288 | masking_vector = utils.get_masked_model(model, args.tuning_method, mask=list(mask)) 289 | 290 | trainable_params, all_param = utils.check_tunable_params(model, True) 291 | trainable_percentage = 100 * trainable_params / all_param 292 | 293 | model.to(device) 294 | 295 | # OL3I and Papila are highly imbalanced datasets. 
    # We use class weights to balance the loss function.
    # (A hedged note, not from the original code: for binary labels these
    # hard-coded tensors match an inverse-frequency scheme, weight_c = 1 - freq_c,
    # which up-weights the rare positive class.)
    if args.compute_cw:
        if args.dataset == "ol3i":
            weight = torch.tensor([0.04361966711306677, 0.9563803328869332])
        elif args.dataset == "papila":
            weight = torch.tensor([0.20714285714285716, 0.7928571428571428])
        else:
            raise NotImplementedError("Class weights not calculated for this dataset")
        weight = weight.to(device)
        print("Using CW Loss with weights: ", weight)
    else:
        weight = None

    # Create the criterion, optimizer, and lr_scheduler here.
    # Pass the class weights to the criterion; they were computed above but
    # previously never reached the loss.
    criterion = nn.CrossEntropyLoss(
        label_smoothing=args.label_smoothing, reduction="none", weight=weight
    )
    ece_criterion = utils.ECELoss()
    args.lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
    optimizer = get_optimizer(args, parameters)
    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    lr_scheduler = get_lr_scheduler(args, optimizer)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    print("Start training")
    start_time = time.time()

    model_ema = None

    for epoch in range(args.start_epoch, args.epochs):
        train_one_epoch_fairness(
            model,
            criterion,
            ece_criterion,
            optimizer,
            data_loader,
            device,
            epoch,
            args,
            model_ema,
            scaler,
        )
        lr_scheduler.step()

        if args.sens_attribute == "gender":
            (
                val_acc,
                val_male_acc,
                val_female_acc,
                val_auc,
                val_male_auc,
                val_female_auc,
                val_loss,
                val_max_loss,
            ) = evaluate_fairness_gender(
                model, criterion, ece_criterion, data_loader_val, args=args, device=device
            )
            print(
                "Val Acc: {:.2f}, Val Male Acc {:.2f}, Val Female Acc {:.2f}, Val Loss: {:.2f}, Val MAX LOSS: {:.2f}".format(
                    val_acc,
                    val_male_acc,
                    val_female_acc,
                    torch.mean(val_loss),
                    val_max_loss,
                )
            )
            print(
                "Val AUC: {:.2f}, Val Male AUC {:.2f}, Val Female AUC {:.2f}".format(
                    val_auc, val_male_auc, val_female_auc
                )
            )

        elif args.sens_attribute == "skin_type":
            if args.skin_type == "multi":
                (
                    val_acc,
                    val_acc_type0,
                    val_acc_type1,
                    val_acc_type2,
                    val_acc_type3,
                    val_acc_type4,
                    val_acc_type5,
                    val_auc,
                    val_auc_type0,
                    val_auc_type1,
                    val_auc_type2,
                    val_auc_type3,
                    val_auc_type4,
                    val_auc_type5,
                    val_loss,
                    val_max_loss,
                ) = evaluate_fairness_skin_type(
                    model, criterion, ece_criterion, data_loader_val, args=args, device=device
                )
                print(
                    "Val Acc: {:.2f}, Val Type 0 Acc: {:.2f}, Val Type 1 Acc: {:.2f}, Val Type 2 Acc: {:.2f}, Val Type 3 Acc: {:.2f}, Val Type 4 Acc: {:.2f}, Val Type 5 Acc: {:.2f}, Val Loss: {:.2f}, Val MAX LOSS: {:.2f}".format(
                        val_acc,
                        val_acc_type0,
                        val_acc_type1,
                        val_acc_type2,
                        val_acc_type3,
                        val_acc_type4,
                        val_acc_type5,
                        torch.mean(val_loss),
                        val_max_loss,
                    )
                )
                print(
                    "Val AUC: {:.2f}, Val Type 0 AUC: {:.2f}, Val Type 1 AUC: {:.2f}, Val Type 2 AUC: {:.2f}, Val Type 3 AUC: {:.2f}, Val Type 4 AUC: {:.2f}, Val Type 5 AUC: {:.2f}".format(
                        val_auc,
                        val_auc_type0,
val_auc_type1, 426 | val_auc_type2, 427 | val_auc_type3, 428 | val_auc_type4, 429 | val_auc_type5, 430 | ) 431 | ) 432 | print("\n") 433 | elif args.skin_type == "binary": 434 | ( 435 | val_acc, 436 | val_acc_type0, 437 | val_acc_type1, 438 | val_auc, 439 | val_auc_type0, 440 | val_auc_type1, 441 | val_loss, 442 | val_max_loss, 443 | ) = evaluate_fairness_skin_type_binary( 444 | model, 445 | criterion, 446 | ece_criterion, 447 | data_loader_val, 448 | args=args, 449 | device=device, 450 | ) 451 | print( 452 | "Val Acc: {:.2f}, Val Type 0 Acc: {:.2f}, Val Type 1 Acc: {:.2f}, Val Loss: {:.2f}, Val MAX LOSS: {:.2f}".format( 453 | val_acc, 454 | val_acc_type0, 455 | val_acc_type1, 456 | torch.mean(val_loss), 457 | val_max_loss, 458 | ) 459 | ) 460 | print( 461 | "Val AUC: {:.2f}, Val Type 0 AUC: {:.2f}, Val Type 1 AUC: {:.2f}".format( 462 | val_auc, 463 | val_auc_type0, 464 | val_auc_type1, 465 | ) 466 | ) 467 | print("\n") 468 | 469 | elif args.sens_attribute == "age": 470 | if args.age_type == "multi": 471 | ( 472 | val_acc, 473 | acc_age0_avg, 474 | acc_age1_avg, 475 | acc_age2_avg, 476 | acc_age3_avg, 477 | acc_age4_avg, 478 | val_auc, 479 | auc_age0_avg, 480 | auc_age1_avg, 481 | auc_age2_avg, 482 | auc_age3_avg, 483 | auc_age4_avg, 484 | val_loss, 485 | val_max_loss, 486 | ) = evaluate_fairness_age( 487 | model, 488 | criterion, 489 | ece_criterion, 490 | data_loader_val, 491 | args=args, 492 | device=device, 493 | ) 494 | print( 495 | "Val Acc: {:.2f}, Val Age Group0 Acc: {:.2f}, Val Age Group1 Acc: {:.2f}, Val Age Group2 Acc: {:.2f}, Val Age Group3 Acc: {:.2f}, Val Age Group4 Acc: {:.2f}, Val Loss: {:.2f}, Val MAX LOSS: {:.2f}".format( 496 | val_acc, 497 | acc_age0_avg, 498 | acc_age1_avg, 499 | acc_age2_avg, 500 | acc_age3_avg, 501 | acc_age4_avg, 502 | torch.mean(val_loss), 503 | val_max_loss, 504 | ) 505 | ) 506 | print( 507 | "Val AUC: {:.2f}, Val Age Group0 AUC: {:.2f}, Val Age Group1 AUC: {:.2f}, Val Age Group2 AUC: {:.2f}, Val Age Group3 AUC: {:.2f}, Val Age Group4 AUC: {:.2f}".format( 508 | val_auc, 509 | auc_age0_avg, 510 | auc_age1_avg, 511 | auc_age2_avg, 512 | auc_age3_avg, 513 | auc_age4_avg, 514 | ) 515 | ) 516 | elif args.age_type == "binary": 517 | ( 518 | val_acc, 519 | acc_age0_avg, 520 | acc_age1_avg, 521 | val_auc, 522 | auc_age0_avg, 523 | auc_age1_avg, 524 | val_loss, 525 | val_max_loss, 526 | ) = evaluate_fairness_age_binary( 527 | model, 528 | criterion, 529 | ece_criterion, 530 | data_loader_val, 531 | args=args, 532 | device=device, 533 | ) 534 | print( 535 | "Val Acc: {:.2f}, Val Age Group0 Acc: {:.2f}, Val Age Group1 Acc: {:.2f}, Val Loss: {:.2f}, Val MAX LOSS: {:.2f}".format( 536 | val_acc, 537 | acc_age0_avg, 538 | acc_age1_avg, 539 | torch.mean(val_loss), 540 | val_max_loss, 541 | ) 542 | ) 543 | print( 544 | "Val AUC: {:.2f}, Val Age Group0 AUC: {:.2f}, Val Age Group1 AUC: {:.2f}".format( 545 | val_auc, 546 | auc_age0_avg, 547 | auc_age1_avg, 548 | ) 549 | ) 550 | print("\n") 551 | else: 552 | raise NotImplementedError( 553 | "Age type not supported. 
Choose from 'multi' or 'binary'" 554 | ) 555 | 556 | elif args.sens_attribute == "race": 557 | if args.cal_equiodds: 558 | ( 559 | val_acc, 560 | acc_race0_avg, 561 | acc_race1_avg, 562 | val_auc, 563 | auc_race0_avg, 564 | auc_race1_avg, 565 | val_loss, 566 | val_max_loss, 567 | equiodds_diff, 568 | equiodds_ratio, 569 | dpd, 570 | dpr, 571 | ) = evaluate_fairness_race_binary( 572 | model, 573 | criterion, 574 | ece_criterion, 575 | data_loader_val, 576 | args=args, 577 | device=device, 578 | ) 579 | else: 580 | ( 581 | val_acc, 582 | acc_race0_avg, 583 | acc_race1_avg, 584 | val_auc, 585 | auc_race0_avg, 586 | auc_race1_avg, 587 | val_loss, 588 | val_max_loss, 589 | ) = evaluate_fairness_race_binary( 590 | model, 591 | criterion, 592 | ece_criterion, 593 | data_loader_val, 594 | args=args, 595 | device=device, 596 | ) 597 | 598 | elif args.sens_attribute == "age_sex": 599 | assert args.dataset == "chexpert" 600 | if args.cal_equiodds: 601 | ( 602 | val_acc, 603 | val_acc_type0, 604 | val_acc_type1, 605 | val_acc_type2, 606 | val_acc_type3, 607 | val_auc, 608 | val_auc_type0, 609 | val_auc_type1, 610 | val_auc_type2, 611 | val_auc_type3, 612 | val_loss, 613 | val_max_loss, 614 | equiodds_diff, 615 | equiodds_ratio, 616 | dpd, 617 | dpr, 618 | ) = evaluate_fairness_age_sex( 619 | model, 620 | criterion, 621 | ece_criterion, 622 | data_loader_val, 623 | args=args, 624 | device=device, 625 | ) 626 | else: 627 | ( 628 | val_acc, 629 | val_acc_type0, 630 | val_acc_type1, 631 | val_acc_type2, 632 | val_acc_type3, 633 | val_auc, 634 | val_auc_type0, 635 | val_auc_type1, 636 | val_auc_type2, 637 | val_auc_type3, 638 | val_loss, 639 | val_max_loss, 640 | ) = evaluate_fairness_age_sex( 641 | model, 642 | criterion, 643 | ece_criterion, 644 | data_loader_val, 645 | args=args, 646 | device=device, 647 | ) 648 | 649 | else: 650 | raise NotImplementedError("Sensitive attribute not implemented") 651 | 652 | if args.output_dir: 653 | checkpoint = { 654 | "model": model_without_ddp.state_dict(), 655 | "optimizer": optimizer.state_dict(), 656 | "lr_scheduler": lr_scheduler.state_dict(), 657 | "epoch": epoch, 658 | "args": args, 659 | } 660 | 661 | if model_ema: 662 | checkpoint["model_ema"] = model_ema.state_dict() 663 | if scaler: 664 | checkpoint["scaler"] = scaler.state_dict() 665 | # utils.save_on_master(checkpoint, os.path.join(args.output_dir, 'checkpoints', f"model_{epoch}.pth"))7 666 | ckpt_path = os.path.join( 667 | args.output_dir, 668 | "checkpoints", 669 | "checkpoint_" + args.tuning_method + ".pth", 670 | ) 671 | if not args.disable_checkpointing: 672 | utils.save_on_master(checkpoint, ckpt_path) 673 | 674 | total_time = time.time() - start_time 675 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 676 | print("Training time {}".format(total_time_str)) 677 | 678 | # Obtaining the performance on val set 679 | print("Training finished | Evaluating on the val set") 680 | 681 | if args.sens_attribute == "gender": 682 | ( 683 | val_acc, 684 | val_male_acc, 685 | val_female_acc, 686 | val_auc, 687 | val_male_auc, 688 | val_female_auc, 689 | val_loss, 690 | val_max_loss, 691 | ) = evaluate_fairness_gender( 692 | model, 693 | criterion, 694 | ece_criterion, 695 | data_loader_val, 696 | args=args, 697 | device=device, 698 | ) 699 | 700 | max_acc = max(val_male_acc, val_female_acc) 701 | min_acc = min(val_male_acc, val_female_acc) 702 | acc_diff = abs(max_acc - min_acc) 703 | 704 | max_auc = max(val_male_auc, val_female_auc) 705 | min_auc = min(val_male_auc, val_female_auc) 706 | 
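        # Hedged illustration (not part of the original search script):
        # alongside the absolute best/worst gaps (`acc_diff`, `auc_diff`) used
        # as search objectives, a scale-free alternative is the worst-to-best
        # ratio, which stays comparable across datasets.
        auc_ratio = min_auc / max_auc if max_auc > 0 else 0.0
        print("Worst/best AUC ratio (illustrative): ", round(auc_ratio, 3))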
auc_diff = abs(max_auc - min_auc) 707 | 708 | print("\n") 709 | print("Val Male Accuracy: ", val_male_acc) 710 | print("Val Female Accuracy: ", val_female_acc) 711 | print("Difference in sub-group performance: ", acc_diff) 712 | print("\n") 713 | print("Val Male AUC: ", val_male_auc) 714 | print("Val Female AUC: ", val_female_auc) 715 | print("Difference in sub-group performance (AUC): ", auc_diff) 716 | 717 | elif args.sens_attribute == "skin_type": 718 | if args.skin_type == "multi": 719 | ( 720 | val_acc, 721 | val_acc_type0, 722 | val_acc_type1, 723 | val_acc_type2, 724 | val_acc_type3, 725 | val_acc_type4, 726 | val_acc_type5, 727 | val_auc, 728 | val_auc_type0, 729 | val_auc_type1, 730 | val_auc_type2, 731 | val_auc_type3, 732 | val_auc_type4, 733 | val_auc_type5, 734 | val_loss, 735 | val_max_loss, 736 | ) = evaluate_fairness_skin_type( 737 | model, 738 | criterion, 739 | ece_criterion, 740 | data_loader_val, 741 | args=args, 742 | device=device, 743 | ) 744 | 745 | max_acc = max( 746 | val_acc_type0, 747 | val_acc_type1, 748 | val_acc_type2, 749 | val_acc_type3, 750 | val_acc_type4, 751 | val_acc_type5, 752 | ) 753 | min_acc = min( 754 | val_acc_type0, 755 | val_acc_type1, 756 | val_acc_type2, 757 | val_acc_type3, 758 | val_acc_type4, 759 | val_acc_type5, 760 | ) 761 | acc_diff = abs(max_acc - min_acc) 762 | 763 | max_auc = max( 764 | val_auc_type0, 765 | val_auc_type1, 766 | val_auc_type2, 767 | val_auc_type3, 768 | val_auc_type4, 769 | val_auc_type5, 770 | ) 771 | min_auc = min( 772 | val_auc_type0, 773 | val_auc_type1, 774 | val_auc_type2, 775 | val_auc_type3, 776 | val_auc_type4, 777 | val_auc_type5, 778 | ) 779 | auc_diff = abs(max_auc - min_auc) 780 | 781 | print("\n") 782 | print("Val Type 0 Accuracy: ", val_acc_type0) 783 | print("Val Type 1 Accuracy: ", val_acc_type1) 784 | print("Val Type 2 Accuracy: ", val_acc_type2) 785 | print("Val Type 3 Accuracy: ", val_acc_type3) 786 | print("Val Type 4 Accuracy: ", val_acc_type4) 787 | print("Val Type 5 Accuracy: ", val_acc_type5) 788 | print("Difference in sub-group performance (Accuracy): ", acc_diff) 789 | 790 | print("\n") 791 | print("Val Type 0 AUC: ", val_auc_type0) 792 | print("Val Type 1 AUC: ", val_auc_type1) 793 | print("Val Type 2 AUC: ", val_auc_type2) 794 | print("Val Type 3 AUC: ", val_auc_type3) 795 | print("Val Type 4 AUC: ", val_auc_type4) 796 | print("Val Type 5 AUC: ", val_auc_type5) 797 | print("Difference in sub-group performance (AUC): ", auc_diff) 798 | 799 | elif args.skin_type == "binary": 800 | ( 801 | val_acc, 802 | val_acc_type0, 803 | val_acc_type1, 804 | val_auc, 805 | val_auc_type0, 806 | val_auc_type1, 807 | val_loss, 808 | val_max_loss, 809 | ) = evaluate_fairness_skin_type_binary( 810 | model, 811 | criterion, 812 | ece_criterion, 813 | data_loader_val, 814 | args=args, 815 | device=device, 816 | ) 817 | 818 | max_acc = max(val_acc_type0, val_acc_type1) 819 | min_acc = min(val_acc_type0, val_acc_type1) 820 | acc_diff = abs(max_acc - min_acc) 821 | 822 | max_auc = max(val_auc_type0, val_auc_type1) 823 | min_auc = min(val_auc_type0, val_auc_type1) 824 | auc_diff = abs(max_auc - min_auc) 825 | 826 | print("\n") 827 | print("Val Type 0 Accuracy: ", val_acc_type0) 828 | print("Val Type 1 Accuracy: ", val_acc_type1) 829 | print("Difference in sub-group performance (Accuracy): ", acc_diff) 830 | 831 | print("\n") 832 | print("Overall Val AUC: ", val_auc) 833 | print("Val Type 0 AUC: ", val_auc_type0) 834 | print("Val Type 1 AUC: ", val_auc_type1) 835 | print("Difference in sub-group performance (AUC): ", 
auc_diff) 836 | 837 | elif args.sens_attribute == "age": 838 | if args.age_type == "multi": 839 | ( 840 | val_acc, 841 | acc_age0_avg, 842 | acc_age1_avg, 843 | acc_age2_avg, 844 | acc_age3_avg, 845 | acc_age4_avg, 846 | val_auc, 847 | auc_age0_avg, 848 | auc_age1_avg, 849 | auc_age2_avg, 850 | auc_age3_avg, 851 | auc_age4_avg, 852 | val_loss, 853 | val_max_loss, 854 | ) = evaluate_fairness_age( 855 | model, 856 | criterion, 857 | ece_criterion, 858 | data_loader_val, 859 | args=args, 860 | device=device, 861 | ) 862 | 863 | max_acc = max( 864 | acc_age0_avg, acc_age1_avg, acc_age2_avg, acc_age3_avg, acc_age4_avg 865 | ) 866 | min_acc = min( 867 | acc_age0_avg, acc_age1_avg, acc_age2_avg, acc_age3_avg, acc_age4_avg 868 | ) 869 | acc_diff = abs(max_acc - min_acc) 870 | 871 | max_auc = max( 872 | auc_age0_avg, auc_age1_avg, auc_age2_avg, auc_age3_avg, auc_age4_avg 873 | ) 874 | min_auc = min( 875 | auc_age0_avg, auc_age1_avg, auc_age2_avg, auc_age3_avg, auc_age4_avg 876 | ) 877 | auc_diff = abs(max_auc - min_auc) 878 | 879 | print("\n") 880 | print("val Age Group 0 Accuracy: ", acc_age0_avg) 881 | print("val Age Group 1 Accuracy: ", acc_age1_avg) 882 | print("val Age Group 2 Accuracy: ", acc_age2_avg) 883 | print("val Age Group 3 Accuracy: ", acc_age3_avg) 884 | print("val Age Group 4 Accuracy: ", acc_age4_avg) 885 | print("Difference in sub-group performance (Accuracy): ", acc_diff) 886 | 887 | print("\n") 888 | print("val Age Group 0 AUC: ", auc_age0_avg) 889 | print("val Age Group 1 AUC: ", auc_age1_avg) 890 | print("val Age Group 2 AUC: ", auc_age2_avg) 891 | print("val Age Group 3 AUC: ", auc_age3_avg) 892 | print("val Age Group 4 AUC: ", auc_age4_avg) 893 | print("Difference in sub-group performance (AUC): ", auc_diff) 894 | 895 | elif args.age_type == "binary": 896 | ( 897 | val_acc, 898 | acc_age0_avg, 899 | acc_age1_avg, 900 | val_auc, 901 | auc_age0_avg, 902 | auc_age1_avg, 903 | val_loss, 904 | val_max_loss, 905 | ) = evaluate_fairness_age_binary( 906 | model, 907 | criterion, 908 | ece_criterion, 909 | data_loader_val, 910 | args=args, 911 | device=device, 912 | ) 913 | 914 | max_acc = max(acc_age0_avg, acc_age1_avg) 915 | min_acc = min(acc_age0_avg, acc_age1_avg) 916 | acc_diff = abs(max_acc - min_acc) 917 | 918 | max_auc = max(auc_age0_avg, auc_age1_avg) 919 | min_auc = min(auc_age0_avg, auc_age1_avg) 920 | auc_diff = abs(max_auc - min_auc) 921 | 922 | print("\n") 923 | print("val Age Group 0 Accuracy: ", acc_age0_avg) 924 | print("val Age Group 1 Accuracy: ", acc_age1_avg) 925 | print("Difference in sub-group performance (Accuracy): ", acc_diff) 926 | 927 | print("\n") 928 | print("val Age Group 0 AUC: ", auc_age0_avg) 929 | print("val Age Group 1 AUC: ", auc_age1_avg) 930 | print("Difference in sub-group performance (AUC): ", auc_diff) 931 | 932 | else: 933 | raise NotImplementedError( 934 | "Age type not supported. 
Choose from 'multi' or 'binary'"
            )

    elif args.sens_attribute == "race":
        if args.cal_equiodds:
            (
                val_acc,
                acc_race0_avg,
                acc_race1_avg,
                val_auc,
                auc_race0_avg,
                auc_race1_avg,
                val_loss,
                val_max_loss,
                equiodds_diff,
                equiodds_ratio,
                dpd,
                dpr,
            ) = evaluate_fairness_race_binary(
                model, criterion, ece_criterion, data_loader_val, args=args, device=device
            )
        else:
            (
                val_acc,
                acc_race0_avg,
                acc_race1_avg,
                val_auc,
                auc_race0_avg,
                auc_race1_avg,
                val_loss,
                val_max_loss,
            ) = evaluate_fairness_race_binary(
                model, criterion, ece_criterion, data_loader_val, args=args, device=device
            )

        max_acc = max(acc_race0_avg, acc_race1_avg)
        min_acc = min(acc_race0_avg, acc_race1_avg)
        acc_diff = abs(max_acc - min_acc)

        max_auc = max(auc_race0_avg, auc_race1_avg)
        min_auc = min(auc_race0_avg, auc_race1_avg)
        auc_diff = abs(max_auc - min_auc)

        print("\n")
        print("val Race Group 0 Accuracy: ", acc_race0_avg)
        print("val Race Group 1 Accuracy: ", acc_race1_avg)
        print("Difference in sub-group performance (Accuracy): ", acc_diff)

        print("\n")
        print("val Race Group 0 AUC: ", auc_race0_avg)
        print("val Race Group 1 AUC: ", auc_race1_avg)
        print("Difference in sub-group performance (AUC): ", auc_diff)

    elif args.sens_attribute == "age_sex":
        assert args.dataset == "chexpert"
        if args.cal_equiodds:
            (
                val_acc,
                val_acc_type0,
                val_acc_type1,
                val_acc_type2,
                val_acc_type3,
                val_auc,
                val_auc_type0,
                val_auc_type1,
                val_auc_type2,
                val_auc_type3,
                val_loss,
                val_max_loss,
                equiodds_diff,
                equiodds_ratio,
                dpd,
                dpr,
            ) = evaluate_fairness_age_sex(
                model, criterion, ece_criterion, data_loader_val, args=args, device=device
            )
        else:
            (
                val_acc,
                val_acc_type0,
                val_acc_type1,
                val_acc_type2,
                val_acc_type3,
                val_auc,
                val_auc_type0,
                val_auc_type1,
                val_auc_type2,
                val_auc_type3,
                val_loss,
                val_max_loss,
            ) = evaluate_fairness_age_sex(
                model, criterion, ece_criterion, data_loader_val, args=args, device=device
            )

        max_acc = max(val_acc_type0, val_acc_type1, val_acc_type2, val_acc_type3)
        min_acc = min(val_acc_type0, val_acc_type1, val_acc_type2, val_acc_type3)
        acc_diff = abs(max_acc - min_acc)

        max_auc = max(val_auc_type0, val_auc_type1, val_auc_type2, val_auc_type3)
        min_auc = min(val_auc_type0, val_auc_type1, val_auc_type2, val_auc_type3)
        auc_diff = abs(max_auc - min_auc)

        print("\n")
        print("val AgeSex Group 0 Accuracy: ", val_acc_type0)
        print("val AgeSex Group 1 Accuracy: ", val_acc_type1)
        print("val AgeSex Group 2 Accuracy: ", val_acc_type2)
        print("val AgeSex Group 3 Accuracy: ", val_acc_type3)
        print("Difference in sub-group performance (Accuracy): ", acc_diff)

        print("\n")
        print("val AgeSex Group 0 AUC: ", val_auc_type0)
        print("val AgeSex Group 1 AUC: ", val_auc_type1)
        print("val AgeSex Group 2 AUC: ", val_auc_type2)
        print("val 
AgeSex Group 3 AUC: ", val_auc_type3) 1068 | 1069 | else: 1070 | raise NotImplementedError("Sensitive attribute not implemented") 1071 | 1072 | print("Val overall accuracy: ", val_acc) 1073 | print("Val Max Accuracy: ", round(max_acc, 3)) 1074 | print("Val Min Accuracy: ", round(min_acc, 3)) 1075 | print("Val Accuracy Difference: ", round(acc_diff, 3)) 1076 | print("Val loss: ", round(torch.mean(val_loss).item(), 3)) 1077 | print("Val max loss: ", round(val_max_loss.item(), 3)) 1078 | 1079 | print("Val overall AUC: ", val_auc) 1080 | print("Val Max AUC: ", round(max_auc, 3)) 1081 | print("Val Min AUC: ", round(min_auc, 3)) 1082 | print("Val AUC Difference: ", round(auc_diff, 3)) 1083 | 1084 | # Adding results to the dataframe 1085 | if args.sens_attribute == "gender": 1086 | if args.use_metric == "acc": 1087 | _row = [ 1088 | args.tuning_method, 1089 | trainable_percentage, 1090 | args.lr, 1091 | round(val_acc, 3), 1092 | round(val_male_acc, 3), 1093 | round(val_female_acc, 3), 1094 | round(acc_diff, 3), 1095 | ] 1096 | if args.use_metric == "auc": 1097 | _row = [ 1098 | args.tuning_method, 1099 | trainable_percentage, 1100 | args.lr, 1101 | round(val_auc, 3), 1102 | round(val_male_auc, 3), 1103 | round(val_female_auc, 3), 1104 | round(auc_diff, 3), 1105 | ] 1106 | 1107 | elif ( 1108 | args.sens_attribute == "age" 1109 | or args.sens_attribute == "skin_type" 1110 | or args.sens_attribute == "race" 1111 | or args.sens_attribute == "age_sex" 1112 | ): 1113 | if args.use_metric == "acc": 1114 | _row = [ 1115 | args.tuning_method, 1116 | round(trainable_percentage, 3), 1117 | args.lr, 1118 | round(val_acc, 3), 1119 | round(max_acc, 3), 1120 | round(min_acc, 3), 1121 | round(acc_diff, 3), 1122 | ] 1123 | if args.use_metric == "auc": 1124 | _row = [ 1125 | args.tuning_method, 1126 | round(trainable_percentage, 3), 1127 | args.lr, 1128 | round(val_auc, 3), 1129 | round(max_auc, 3), 1130 | round(min_auc, 3), 1131 | round(auc_diff, 3), 1132 | ] 1133 | 1134 | results_df.loc[len(results_df)] = _row 1135 | print( 1136 | "!!! 
Saving the results dataframe at {}".format( 1137 | os.path.join(results_df_savedir, results_df_name) 1138 | ) 1139 | ) 1140 | results_df.to_csv(os.path.join(results_df_savedir, results_df_name), index=False) 1141 | 1142 | # Pruning 1143 | if args.objective_metric == "min_acc": 1144 | trial.report(min_acc, epoch) 1145 | elif args.objective_metric == "min_auc": 1146 | trial.report(min_auc, epoch) 1147 | 1148 | elif args.objective_metric == "acc_diff": 1149 | trial.report(acc_diff, epoch) 1150 | elif args.objective_metric == "auc_diff": 1151 | trial.report(auc_diff, epoch) 1152 | 1153 | elif args.objective_metric == "max_loss": 1154 | trial.report(val_max_loss, epoch) 1155 | 1156 | elif args.objective_metric == "overall_acc": 1157 | trial.report(val_acc, epoch) 1158 | elif args.objective_metric == "overall_auc": 1159 | trial.report(val_auc, epoch) 1160 | else: 1161 | raise NotImplementedError("Objective metric not implemented") 1162 | 1163 | if trial.should_prune(): 1164 | raise optuna.exceptions.TrialPruned() 1165 | 1166 | if args.objective_metric == "acc_diff": 1167 | try: 1168 | return acc_diff.item() 1169 | except: 1170 | return acc_diff 1171 | 1172 | elif args.objective_metric == "auc_diff": 1173 | try: 1174 | return auc_diff.item() 1175 | except: 1176 | return auc_diff 1177 | 1178 | elif args.objective_metric == "min_acc": 1179 | try: 1180 | return min_acc.item() 1181 | except: 1182 | return min_acc 1183 | elif args.objective_metric == "min_auc": 1184 | try: 1185 | return min_auc.item() 1186 | except: 1187 | return min_auc 1188 | 1189 | elif args.objective_metric == "max_loss": 1190 | try: 1191 | return val_max_loss.item() 1192 | except: 1193 | return val_max_loss 1194 | elif args.objective_metric == "overall_acc": 1195 | try: 1196 | return val_acc.item() 1197 | except: 1198 | return val_acc 1199 | elif args.objective_metric == "overall_auc": 1200 | try: 1201 | return val_auc.item() 1202 | except: 1203 | return val_auc 1204 | 1205 | else: 1206 | raise NotImplementedError("Objective metric not implemented") 1207 | 1208 | 1209 | if __name__ == "__main__": 1210 | args = get_args_parser().parse_args() 1211 | args.plots_save_dir = os.path.join( 1212 | os.getcwd(), 1213 | "plots", 1214 | "optuna_plots", 1215 | args.model, 1216 | args.dataset, 1217 | args.tuning_method, 1218 | args.sens_attribute, 1219 | ) 1220 | 1221 | if args.dev_mode: 1222 | args.disable_plotting = True 1223 | args.disable_checkpointing = True 1224 | 1225 | if not os.path.exists(args.plots_save_dir): 1226 | os.makedirs(args.plots_save_dir, exist_ok=True) 1227 | 1228 | if ( 1229 | args.objective_metric == "acc_diff" 1230 | or args.objective_metric == "auc_diff" 1231 | or args.objective_metric == "max_loss" 1232 | ): 1233 | direction = "minimize" 1234 | elif ( 1235 | args.objective_metric == "min_acc" 1236 | or args.objective_metric == "overall_acc" 1237 | or args.objective_metric == "overall_auc" 1238 | or args.objective_metric == "min_auc" 1239 | ): 1240 | direction = "maximize" 1241 | else: 1242 | raise NotImplementedError 1243 | 1244 | # Pruners 1245 | if args.pruner == "SuccessiveHalving": 1246 | pruner = optuna.pruners.SuccessiveHalvingPruner() 1247 | elif args.pruner == "MedianPruner": 1248 | pruner = optuna.pruners.MedianPruner() 1249 | elif args.pruner == "Hyperband": 1250 | pruner = optuna.pruners.HyperbandPruner() 1251 | else: 1252 | raise NotImplementedError 1253 | 1254 | if not args.disable_storage: 1255 | study_name = ( 1256 | args.dataset 1257 | + "_" 1258 | + args.tuning_method 1259 | + "_" 1260 | + 
args.sens_attribute 1261 | + "_" 1262 | + args.objective_metric 1263 | ) 1264 | optuna.logging.get_logger("optuna").addHandler( 1265 | logging.StreamHandler(sys.stdout) 1266 | ) 1267 | storage_dir = os.path.join(os.getcwd(), "Optuna_StorageDB") 1268 | if not os.path.exists(storage_dir): 1269 | os.makedirs(storage_dir) 1270 | # storage_name = os.path.join(storage_dir, "sqlite:///{}.db".format(study_name)) 1271 | storage_name = "sqlite:///{}.db".format(study_name) 1272 | print("!!! Creating the study DB at {}".format(storage_name)) 1273 | study = optuna.create_study( 1274 | direction=direction, pruner=pruner, storage=storage_name 1275 | ) 1276 | else: 1277 | study = optuna.create_study(direction=direction, pruner=pruner) 1278 | 1279 | study.optimize(objective, n_trials=args.num_trials, show_progress_bar=True) 1280 | 1281 | pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED]) 1282 | complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE]) 1283 | 1284 | # print("Pruned trials: ", pruned_trials) 1285 | 1286 | print("Study statistics: ") 1287 | print(" Number of finished trials: ", len(study.trials)) 1288 | print(" Number of pruned trials: ", len(pruned_trials)) 1289 | print(" Number of complete trials: ", len(complete_trials)) 1290 | 1291 | print("Best trial:") 1292 | trial = study.best_trial 1293 | 1294 | print(" Value: ", trial.value) 1295 | 1296 | print(" Params: ") 1297 | best_mask = [] 1298 | for key, value in trial.params.items(): 1299 | print(" {}: {}".format(key, value)) 1300 | if key == "lr": 1301 | continue 1302 | best_mask.append(value) 1303 | 1304 | # Save the best mask 1305 | best_mask = np.array(best_mask).astype(np.int8) 1306 | mask_savedir = os.path.join( 1307 | args.model, args.dataset, "Optuna_Masks", args.sens_attribute 1308 | ) 1309 | if not os.path.exists(mask_savedir): 1310 | os.makedirs(mask_savedir) 1311 | 1312 | print( 1313 | "!!! 
Saving the best mask at {}".format( 1314 | os.path.join( 1315 | mask_savedir, 1316 | args.tuning_method 1317 | + "_best_mask_" 1318 | + args.objective_metric 1319 | + "_" 1320 | + str(trial.value) 1321 | + ".npy", 1322 | ) 1323 | ) 1324 | ) 1325 | if not args.dev_mode: 1326 | np.save( 1327 | os.path.join( 1328 | mask_savedir, 1329 | args.tuning_method 1330 | + "_best_mask_" 1331 | + args.objective_metric 1332 | + "_" 1333 | + str(trial.value) 1334 | + ".npy", 1335 | ), 1336 | best_mask, 1337 | ) 1338 | 1339 | # Save these results to a dataframe 1340 | stats_df_savedir = os.path.join(args.model, args.dataset, "Optuna Run Stats") 1341 | if not os.path.exists(stats_df_savedir): 1342 | os.makedirs(stats_df_savedir) 1343 | stats_df_name = ( 1344 | "Run_Stats_" 1345 | + args.sens_attribute 1346 | + "_" 1347 | + args.tuning_method 1348 | + "_" 1349 | + args.model 1350 | + "_" 1351 | + args.objective_metric 1352 | + ".csv" 1353 | ) 1354 | best_params_df_name = ( 1355 | "Best_Params" 1356 | + args.sens_attribute 1357 | + "_" 1358 | + args.tuning_method 1359 | + "_" 1360 | + args.model 1361 | + "_" 1362 | + args.objective_metric 1363 | + ".csv" 1364 | ) 1365 | 1366 | df = study.trials_dataframe() 1367 | try: 1368 | df = df.drop( 1369 | [ 1370 | "datetime_start", 1371 | "datetime_complete", 1372 | "duration", 1373 | "system_attrs_completed_rung_0", 1374 | ], 1375 | axis=1, 1376 | ) # Drop unnecessary columns 1377 | except: 1378 | pass 1379 | df = df.rename(columns={"value": args.objective_metric}) 1380 | df.to_csv(os.path.join(stats_df_savedir, stats_df_name), index=False) 1381 | 1382 | # Save the best params 1383 | cols = ["Col {}".format(i) for i in range(len(trial.params))] 1384 | best_params_df = pd.DataFrame(columns=cols) 1385 | best_params_df.loc[len(best_params_df)] = list(trial.params.values()) 1386 | best_params_df.to_csv( 1387 | os.path.join(stats_df_savedir, best_params_df_name), index=False 1388 | ) 1389 | 1390 | #################### Plotting #################### 1391 | 1392 | # 1. Parameter importance plots 1393 | 1394 | # a) Bar Plot 1395 | 1396 | if not args.disable_plotting: 1397 | try: 1398 | param_imp_plot = optuna.visualization.matplotlib.plot_param_importances( 1399 | study 1400 | ) 1401 | param_imp_plot.figure.tight_layout() 1402 | param_imp_plot.figure.savefig( 1403 | os.path.join( 1404 | args.plots_save_dir, 1405 | "param_importance_{}.jpg".format(args.objective_metric), 1406 | ), 1407 | format="jpg", 1408 | ) 1409 | except: 1410 | print("Error in plotting parameter importance plot") 1411 | 1412 | # b) Contour Plot 1413 | try: 1414 | contour_fig = plt.figure() 1415 | contour_plot = optuna.visualization.matplotlib.plot_contour(study) 1416 | except: 1417 | print("Error in plotting contour plot") 1418 | # contour_fig.savefig(os.path.join(args.plots_save_dir, "contour_plot.jpg"), format="jpg") 1419 | 1420 | # print(contour_plot) 1421 | # contour_plot.figure.savefig(os.path.join(args.plots_save_dir, "contour_plot_{}.jpg".format(args.objective_metric)), format="jpg") 1422 | 1423 | # 2. Slice plot 1424 | # fig2 = plt.figure() 1425 | # slice_plot = optuna.visualization.matplotlib.plot_slice(study) 1426 | # print(slice_plot) 1427 | # slice_plot.figure.savefig(os.path.join(args.plots_save_dir, "slice_plot_{}.jpg".format(args.objective_metric)), format="jpg") 1428 | # fig2.add_axes(axes) 1429 | # plt.savefig(os.path.join(args.plots_save_dir, "slice_plot.jpg"), format="jpg") 1430 | # plt.close(fig2) 1431 | 1432 | # 3. 
Optimization history plot 1433 | try: 1434 | history_plot = optuna.visualization.matplotlib.plot_optimization_history( 1435 | study 1436 | ) 1437 | history_plot.figure.tight_layout() 1438 | history_plot.figure.savefig( 1439 | os.path.join( 1440 | args.plots_save_dir, 1441 | "optimization_history_{}.jpg".format(args.objective_metric), 1442 | ), 1443 | format="jpg", 1444 | ) 1445 | except: 1446 | print("Error in plotting optimization history plot") 1447 | 1448 | # 4. High-dimensional parameter relationships plot 1449 | try: 1450 | parallel_plot = optuna.visualization.matplotlib.plot_parallel_coordinate( 1451 | study 1452 | ) 1453 | parallel_plot.figure.tight_layout() 1454 | parallel_plot.figure.savefig( 1455 | os.path.join( 1456 | args.plots_save_dir, 1457 | "parallel_coordinate_{}.jpg".format(args.objective_metric), 1458 | ), 1459 | format="jpg", 1460 | ) 1461 | except: 1462 | print("Error in plotting parallel coordinate plot") 1463 | 1464 | # 5. Pareto front plot 1465 | try: 1466 | pareto_plot = optuna.visualization.matplotlib.plot_pareto_front(study) 1467 | pareto_plot.figure.tight_layout() 1468 | pareto_plot.figure.savefig( 1469 | os.path.join( 1470 | args.plots_save_dir, 1471 | "pareto_plot_{}.jpg".format(args.objective_metric), 1472 | ), 1473 | format="jpg", 1474 | ) 1475 | except: 1476 | print("Error in plotting Pareto front plot") 1477 | 1478 | # 6. Parameter Rank plot 1479 | try: 1480 | # param_rank_plot = optuna.visualization.matplotlib.plot_param_importances(study, target=lambda t: t.params[args.objective_metric]) 1481 | param_rank_plot = optuna.visualization.matplotlib.plot_rank(study) 1482 | param_rank_plot.figure.tight_layout() 1483 | param_rank_plot.figure.savefig( 1484 | os.path.join( 1485 | args.plots_save_dir, 1486 | "rank_plot_{}.jpg".format(args.objective_metric), 1487 | ), 1488 | format="jpg", 1489 | ) 1490 | except: 1491 | print("Error in plotting parameter rank plot") 1492 | --------------------------------------------------------------------------------
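Taken together, search_mask.py and finetune_with_mask.py implement a search-then-finetune loop: Optuna proposes a binary mask over the transformer blocks plus a learning rate, the objective trains with that mask and returns a fairness metric, and the best mask is saved to disk and consumed by the fine-tuning script via --mask_path. The snippet below is a minimal, self-contained sketch of that pattern; the 12-block mask length and the synthetic auc_diff stand-in for training and evaluation are illustrative assumptions, not project code.

# Minimal sketch of the mask-search pattern used above (assumptions noted inline).
import numpy as np
import optuna


def toy_objective(trial):
    # Assumption: a ViT-Base-like backbone with 12 transformer blocks, so the
    # mask has 12 binary entries, one per block (mirrors create_opt_mask).
    num_blocks = 12
    mask = np.array(
        [trial.suggest_int("Mask Idx {}".format(i), 0, 1) for i in range(num_blocks)],
        dtype=np.int8,
    )
    lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
    # Synthetic stand-in for train_one_epoch_fairness + evaluate_fairness_*:
    # pretend the AUC gap shrinks when roughly half the blocks are tuned at a
    # moderate learning rate.
    auc_diff = abs(mask.mean() - 0.5) + 0.01 * abs(np.log10(lr) + 4)
    return float(auc_diff)


# "auc_diff" is a gap, so the study minimizes it, as in search_mask.py.
study = optuna.create_study(direction="minimize", pruner=optuna.pruners.MedianPruner())
study.optimize(toy_objective, n_trials=20)

# Recover the mask exactly as search_mask.py does: every param except "lr".
best_mask = np.array(
    [v for k, v in study.best_trial.params.items() if k != "lr"], dtype=np.int8
)
np.save("best_mask.npy", best_mask)  # then: finetune_with_mask.py --mask_path best_mask.npy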