├── models ├── SWA │ ├── __init__.py │ └── SWA.py ├── SAMSWAD │ └── __init__.py ├── baseline │ ├── __init__.py │ └── baseline.py ├── DomainInd │ ├── __init__.py │ └── DomainInd.py ├── resampling │ ├── __init__.py │ └── resampling.py ├── resamplingSWAD │ ├── __init__.py │ └── resamplingSWAD.py ├── EnD │ ├── __init__.py │ ├── model.py │ └── EnD.py ├── LAFTR │ ├── __init__.py │ ├── model.py │ └── LAFTR.py ├── CFair │ ├── __init__.py │ ├── CFair.py │ └── model.py ├── GSAM │ ├── __init__.py │ └── GSAM.py ├── GroupDRO │ ├── __init__.py │ ├── utils.py │ └── GroupDRO.py ├── SWAD │ └── __init__.py ├── SAM │ ├── __init__.py │ ├── utils.py │ └── SAM.py ├── LNL │ ├── __init__.py │ ├── LNL.py │ └── model.py ├── ODR │ ├── __init__.py │ └── utils.py ├── __init__.py ├── basemodels_mlp.py ├── basemodels_3d.py ├── basemodels.py └── utils.py ├── utils ├── __init__.py └── basics.py ├── configs ├── wandb_init.json └── datasets.json ├── sweep ├── train-sweep │ ├── sweep_count.sh │ ├── slurm_sweep_count.sh │ ├── sweep_resampling.yaml │ ├── sweep_DomainInd.yaml │ ├── sweep_CFair.yaml │ ├── sweep_baseline.yaml │ ├── sweep_LAFTR.yaml │ ├── sweep_LNL.yaml │ ├── sweep_EnD.yaml │ ├── sweep_GroupDRO.yaml │ ├── sweep_SAM.yaml │ ├── sweep_SWA.yaml │ ├── sweep_SWAD.yaml │ ├── sweep_resamplingSWAD.yaml │ ├── sweep_ODR.yaml │ ├── sweep_GSAM.yaml │ └── sweep_batch.py └── test │ └── cross_domain │ ├── batch_submit.sh │ ├── slurm_batch_submit.sh │ └── cross_test.py ├── datasets ├── __init__.py ├── OCT.py ├── RadFusion_images.py ├── COVID_CT_MD.py ├── RadFusion_EHR.py ├── MIMIC_III.py ├── ADNI.py ├── PAPILA.py ├── Fitz17k.py ├── CXP.py ├── MIMIC_CXR.py ├── eICU.py ├── HAM10000.py ├── utils.py └── BaseDataset.py ├── docs ├── index.md ├── reference.md ├── quickstart.md └── customization.md ├── main.py ├── notebooks ├── CovidCT.ipynb ├── fit17k.ipynb ├── OCT.ipynb ├── PAPILA.ipynb ├── ADNI.ipynb └── HAM10000-example.ipynb └── README.md /models/SWA/__init__.py: -------------------------------------------------------------------------------- 1 | from models.SWA.SWA import SWA -------------------------------------------------------------------------------- /models/SAMSWAD/__init__.py: -------------------------------------------------------------------------------- 1 | from models.SAMSWAD.SAMSWAD import SAMSWAD -------------------------------------------------------------------------------- /models/baseline/__init__.py: -------------------------------------------------------------------------------- 1 | from models.baseline.baseline import baseline -------------------------------------------------------------------------------- /models/DomainInd/__init__.py: -------------------------------------------------------------------------------- 1 | from models.DomainInd.DomainInd import DomainInd -------------------------------------------------------------------------------- /models/resampling/__init__.py: -------------------------------------------------------------------------------- 1 | from models.resampling.resampling import resampling -------------------------------------------------------------------------------- /models/resamplingSWAD/__init__.py: -------------------------------------------------------------------------------- 1 | from models.resamplingSWAD.resamplingSWAD import resamplingSWAD -------------------------------------------------------------------------------- /models/EnD/__init__.py: -------------------------------------------------------------------------------- 1 | from models.EnD.model import EnDNet, EnDNet3D 2 | 
from models.EnD.EnD import EnD -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | import utils.spatial_transforms 2 | import utils.basics 3 | import utils.evaluation -------------------------------------------------------------------------------- /models/LAFTR/__init__.py: -------------------------------------------------------------------------------- 1 | from models.LAFTR.model import LaftrNet 2 | from models.LAFTR.LAFTR import LAFTR -------------------------------------------------------------------------------- /models/CFair/__init__.py: -------------------------------------------------------------------------------- 1 | from models.CFair.model import CFairNet, CFairNet3D 2 | from models.CFair.CFair import CFair -------------------------------------------------------------------------------- /models/GSAM/__init__.py: -------------------------------------------------------------------------------- 1 | from models.GSAM.GSAM import GSAM 2 | from models.GSAM.utils import GSAM_optimizer, LinearScheduler -------------------------------------------------------------------------------- /models/GroupDRO/__init__.py: -------------------------------------------------------------------------------- 1 | from models.GroupDRO.utils import LossComputer 2 | from models.GroupDRO.GroupDRO import GroupDRO -------------------------------------------------------------------------------- /models/SWAD/__init__.py: -------------------------------------------------------------------------------- 1 | from models.SWAD.SWAD import SWAD 2 | from models.SWAD.utils import AveragedModel, update_bn, LossValley -------------------------------------------------------------------------------- /models/SAM/__init__.py: -------------------------------------------------------------------------------- 1 | from models.SAM.SAM import SAM 2 | from models.SAM.utils import SAM_optimizer, disable_running_stats, enable_running_stats -------------------------------------------------------------------------------- /models/LNL/__init__.py: -------------------------------------------------------------------------------- 1 | from models.LNL.model import LNLNet, LNLNet3D, LNLPredictor, LNLPredictor3D, grad_reverseLNL 2 | from models.LNL.LNL import LNL -------------------------------------------------------------------------------- /models/ODR/__init__.py: -------------------------------------------------------------------------------- 1 | from models.ODR.model import ODRModel, ODR_Encoder3D 2 | from models.ODR.ODR import ODR 3 | from models.ODR.utils import OrthoLoss -------------------------------------------------------------------------------- /configs/wandb_init.json: -------------------------------------------------------------------------------- 1 | { 2 | "project" : "MEDFAIR", 3 | "name": "baseline", 4 | "dir": "./output/", 5 | "entity": "yourname", 6 | "allow_val_change": true, 7 | "sync_tensorboard": true, 8 | "tags": ["all"], 9 | "mode": "online" 10 | } -------------------------------------------------------------------------------- /models/resampling/resampling.py: -------------------------------------------------------------------------------- 1 | from models.baseline import baseline 2 | 3 | 4 | class resampling(baseline): 5 | def __init__(self, opt, wandb): 6 | super(resampling, self).__init__(opt, wandb) 7 | self.set_network(opt) 8 | self.set_optimizer(opt) 9 | 
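Note on `resampling` above: the class only re-runs the baseline setup; the balancing itself happens when the training dataloader is built (the sweep configs later in this listing pass `--resample_which balanced` to `main.py`). The sampler construction is not shown in this listing, so the following is only a minimal sketch of what a group-balanced sampler typically looks like in PyTorch — the helper name and arguments are illustrative, not MEDFAIR's actual API:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

def make_balanced_loader(dataset, group_labels, batch_size):
    # Illustrative sketch (not part of MEDFAIR): oversample minority groups
    # so that every group is drawn with roughly equal probability.
    counts = np.bincount(group_labels)            # samples per group id
    weights = 1.0 / counts[group_labels]          # inverse-frequency weight per sample
    sampler = WeightedRandomSampler(
        weights=torch.as_tensor(weights, dtype=torch.double),
        num_samples=len(group_labels),            # one epoch's worth of draws
        replacement=True,
    )
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)
```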
-------------------------------------------------------------------------------- /sweep/train-sweep/sweep_count.sh: -------------------------------------------------------------------------------- 1 | OPTIONS=d: 2 | LONGOPTS=sweep_id: 3 | ! PARSED=$(getopt --options=$OPTIONS --longoptions=$LONGOPTS --name "$0" -- "$@") 4 | eval set -- "$PARSED" 5 | 6 | sweep_id="$2" 7 | echo "$sweep_id" 8 | 9 | wandb agent --count 1 $sweep_id 10 | 11 | echo "done" -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import models.baseline 2 | import models.CFair 3 | import models.LAFTR 4 | import models.basemodels 5 | import models.basemodels_3d 6 | import models.basenet 7 | import models.LNL 8 | import models.EnD 9 | import models.DomainInd 10 | import models.ODR 11 | import models.SWA 12 | import models.SWAD 13 | import models.SAM 14 | import models.GSAM 15 | 16 | import models.SAMSWAD 17 | 18 | import models.utils -------------------------------------------------------------------------------- /sweep/train-sweep/slurm_sweep_count.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | #SBATCH -N 1 3 | #SBATCH --ntasks-per-node=1 4 | #SBATCH --partition ampere 5 | #SBATCH --gres=gpu:1 6 | #SBATCH --account your_account 7 | #SBATCH --time=4:30:30 8 | 9 | 10 | OPTIONS=d: 11 | LONGOPTS=sweep_id: 12 | ! PARSED=$(getopt --options=$OPTIONS --longoptions=$LONGOPTS --name "$0" -- "$@") 13 | eval set -- "$PARSED" 14 | 15 | sweep_id="$2" 16 | echo "$sweep_id" 17 | 18 | wandb agent --count 1 $sweep_id 19 | 20 | echo "done" 21 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from datasets.BaseDataset import BaseDataset 2 | from datasets.CXP import CXP 3 | from datasets.MIMIC_CXR import MIMIC_CXR 4 | from datasets.HAM10000 import HAM10000 5 | from datasets.PAPILA import PAPILA 6 | from datasets.RadFusion_images import RadFusion_images 7 | from datasets.RadFusion_EHR import RadFusion_EHR 8 | from datasets.OCT import OCT 9 | from datasets.ADNI import ADNI 10 | from datasets.Fitz17k import Fitz17k 11 | from datasets.COVID_CT_MD import COVID_CT_MD 12 | from datasets.MIMIC_III import MIMIC_III 13 | from datasets.eICU import eICU 14 | -------------------------------------------------------------------------------- /sweep/train-sweep/sweep_resampling.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python 5 | - ${program} 6 | - ${args} 7 | - "--experiment" 8 | - "resampling" 9 | - "--hyper_search" 10 | - "True" 11 | 12 | method: bayes 13 | metric: 14 | name: Validation loss 15 | goal: minimize 16 | parameters: 17 | lr: 18 | distribution: uniform 19 | min: 1e-5 20 | max: 1e-3 21 | weight_decay: 22 | #value: 1e-4 23 | value: 1e-5 24 | 25 | total_epochs: 26 | value: 20 27 | -------------------------------------------------------------------------------- /sweep/train-sweep/sweep_DomainInd.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python 5 | - ${program} 6 | - ${args} 7 | - "--experiment" 8 | - "DomainInd" 9 | - "--hyper_search" 10 | - "True" 11 | 12 | method: bayes 13 | metric: 14 | name: Validation loss 15 | goal: minimize 16
| parameters: 17 | lr: 18 | values: 19 | #- 0.005 20 | #- 0.001 21 | - 0.0005 22 | - 0.0001 23 | weight_decay: 24 | #value: 1e-4 25 | value: 1e-5 26 | batch_size: 27 | values: 28 | - 1024 29 | total_epochs: 30 | value: 20 31 | -------------------------------------------------------------------------------- /sweep/train-sweep/sweep_CFair.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python 5 | - ${program} 6 | - ${args} 7 | - "--experiment" 8 | - "CFair" 9 | - "--hyper_search" 10 | - "True" 11 | 12 | method: bayes 13 | metric: 14 | name: Validation loss 15 | goal: minimize 16 | parameters: 17 | mu: 18 | distribution: uniform 19 | min: 0.01 20 | max: 5 21 | lr: 22 | distribution: uniform 23 | min: 1e-5 24 | max: 1e-3 25 | weight_decay: 26 | #value: 1e-4 27 | value: 1e-5 28 | 29 | total_epochs: 30 | value: 20 31 | -------------------------------------------------------------------------------- /sweep/train-sweep/sweep_baseline.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python 5 | - ${program} 6 | - ${args} 7 | - "--experiment" 8 | - "baseline" 9 | - "--hyper_search" 10 | - "True" 11 | 12 | method: bayes 13 | metric: 14 | name: Validation loss 15 | goal: minimize 16 | parameters: 17 | lr: 18 | distribution: uniform 19 | min: 1e-5 20 | max: 1e-3 21 | weight_decay: 22 | #value: 1e-4 23 | value: 1e-5 24 | batch_size: 25 | values: 26 | - 1024 27 | - 512 28 | - 256 29 | total_epochs: 30 | value: 30 31 | -------------------------------------------------------------------------------- /sweep/train-sweep/sweep_LAFTR.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python 5 | - ${program} 6 | - ${args} 7 | - "--experiment" 8 | - "LAFTR" 9 | - "--hyper_search" 10 | - "True" 11 | 12 | method: bayes 13 | metric: 14 | name: Validation loss 15 | goal: minimize 16 | parameters: 17 | class_coeff: 18 | distribution: uniform 19 | min: 0.01 20 | max: 5 21 | lr: 22 | distribution: uniform 23 | min: 1e-5 24 | max: 1e-3 25 | weight_decay: 26 | #value: 1e-4 27 | value: 1e-5 28 | 29 | total_epochs: 30 | value: 20 31 | -------------------------------------------------------------------------------- /models/resamplingSWAD/resamplingSWAD.py: -------------------------------------------------------------------------------- 1 | from models.SWAD import SWAD 2 | from torch.optim.lr_scheduler import CosineAnnealingLR 3 | from models.SWAD.utils import AveragedModel, update_bn, LossValley 4 | 5 | class resamplingSWAD(SWAD): 6 | def __init__(self, opt, wandb): 7 | super(resamplingSWAD, self).__init__(opt, wandb) 8 | self.annealing_epochs = opt['swa_annealing_epochs'] 9 | 10 | self.set_optimizer(opt) 11 | self.swad = LossValley(n_converge = opt['swad_n_converge'], n_tolerance = opt['swad_n_converge'] + opt['swad_n_tolerance'], 12 | tolerance_ratio = opt['swad_tolerance_ratio']) 13 | 14 | self.step = 0 15 | -------------------------------------------------------------------------------- /sweep/train-sweep/sweep_LNL.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python 5 | - ${program} 6 | - ${args} 7 | - "--experiment" 8 | - "LNL" 9 | - "--hyper_search" 10 | - "True" 11 | 12 | method: bayes 13 | metric: 14 | name: Validation loss 15 | goal: minimize 16 |
parameters: 17 | _lambda: 18 | distribution: uniform 19 | min: 0.001 20 | max: 3 21 | lr: 22 | distribution: uniform 23 | min: 1e-5 24 | max: 1e-3 25 | weight_decay: 26 | #value: 1e-4 27 | value: 1e-5 28 | 29 | lr_decay_rate: 30 | values: 31 | - 10 32 | - 5 33 | total_epochs: 34 | value: 20 35 | -------------------------------------------------------------------------------- /sweep/train-sweep/sweep_EnD.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python 5 | - ${program} 6 | - ${args} 7 | - "--experiment" 8 | - "EnD" 9 | - "--hyper_search" 10 | - "True" 11 | - "--sens_classes" 12 | - "5" 13 | 14 | method: bayes 15 | metric: 16 | name: Validation loss 17 | goal: minimize 18 | parameters: 19 | alpha: 20 | distribution: uniform 21 | min: 0.01 22 | max: 5.0 23 | beta: 24 | distribution: uniform 25 | min: 0.01 26 | max: 5.0 27 | lr: 28 | distribution: uniform 29 | min: 1e-5 30 | max: 1e-3 31 | weight_decay: 32 | #value: 1e-4 33 | value: 1e-5 34 | 35 | total_epochs: 36 | value: 20 37 | -------------------------------------------------------------------------------- /sweep/train-sweep/sweep_GroupDRO.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python 5 | - ${program} 6 | - ${args} 7 | - "--experiment" 8 | - "GroupDRO" 9 | - "--hyper_search" 10 | - "True" 11 | 12 | method: bayes 13 | metric: 14 | name: Validation loss 15 | goal: minimize 16 | parameters: 17 | groupdro_alpha: 18 | distribution: uniform 19 | min: 0.01 20 | max: 5 21 | groupdro_gamma: 22 | distribution: uniform 23 | min: 0.01 24 | max: 5 25 | lr: 26 | distribution: uniform 27 | min: 1e-5 28 | max: 1e-3 29 | weight_decay: 30 | values: 31 | - 1e-5 32 | - 1e-3 33 | - 1e-2 34 | - 1e-1 35 | - 1e-4 36 | 37 | total_epochs: 38 | value: 20 39 | -------------------------------------------------------------------------------- /sweep/train-sweep/sweep_SAM.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python 5 | - ${program} 6 | - ${args} 7 | - "--experiment" 8 | - "SAM" 9 | - "--hyper_search" 10 | - "True" 11 | - "--early_stopping" 12 | - "10" 13 | 14 | method: bayes 15 | metric: 16 | name: Validation loss 17 | goal: minimize 18 | parameters: 19 | rho: 20 | distribution: uniform 21 | min: 0.01 22 | max: 5 23 | adaptive: 24 | values: 25 | - False 26 | lr: 27 | distribution: uniform 28 | min: 1e-4 29 | max: 1e-1 30 | weight_decay: 31 | values: 32 | - 1e-5 33 | - 1e-4 34 | - 1e-3 35 | T_max: 36 | values: 37 | - 100 38 | - 200 39 | 40 | total_epochs: 41 | value: 40 42 | -------------------------------------------------------------------------------- /sweep/train-sweep/sweep_SWA.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python 5 | - ${program} 6 | - ${args} 7 | - "--experiment" 8 | - "SWA" 9 | - "--hyper_search" 10 | - "True" 11 | 12 | method: bayes 13 | metric: 14 | name: Validation loss 15 | goal: minimize 16 | parameters: 17 | swa_start: 18 | values: 19 | - 5 20 | - 7 21 | - 10 22 | swa_lr: 23 | values: 24 | - 0.1 25 | - 0.05 26 | - 0.01 27 | - 0.005 28 | - 0.0001 29 | swa_annealing_epochs: 30 | values: 31 | - 0 32 | - 3 33 | - 5 34 | - 7 35 | lr: 36 | distribution: uniform 37 | min: 1e-5 38 | max: 1e-3 39 | weight_decay: 40 | value: 1e-5 41 | total_epochs: 42 | value: 20 43
| -------------------------------------------------------------------------------- /sweep/train-sweep/sweep_SWAD.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python 5 | - ${program} 6 | - ${args} 7 | - "--experiment" 8 | - "SWAD" 9 | - "--hyper_search" 10 | - "True" 11 | 12 | method: bayes 13 | metric: 14 | name: Validation loss 15 | goal: minimize 16 | parameters: 17 | swad_n_converge: 18 | values: 19 | - 3 20 | - 5 21 | - 7 22 | - 9 23 | swad_n_tolerance: 24 | values: 25 | - 3 26 | - 5 27 | - 7 28 | - 9 29 | swad_tolerance_ratio: 30 | distribution: uniform 31 | min: 0.01 32 | max: 0.3 33 | lr: 34 | distribution: uniform 35 | min: 1e-5 36 | max: 1e-3 37 | weight_decay: 38 | value: 1e-5 39 | #value: 1e-4 40 | 41 | total_epochs: 42 | value: 20 43 | -------------------------------------------------------------------------------- /datasets/OCT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | from datasets.BaseDataset import BaseDataset 5 | 6 | class OCT(BaseDataset): 7 | def __init__(self, dataframe, path_to_images, sens_name, sens_classes, transform): 8 | super(OCT, self).__init__(dataframe, path_to_images, sens_name, sens_classes, transform) 9 | 10 | self.A = self.set_A(sens_name) 11 | self.Y = (np.asarray(self.dataframe['label'].values) > 0).astype('float') 12 | self.AY_proportion = None 13 | 14 | def __getitem__(self, idx): 15 | item = self.dataframe.iloc[idx] 16 | img = np.load(os.path.join(self.path_to_images, item["Path"])) 17 | img = self.transform(img) 18 | 19 | label = torch.FloatTensor([item['label']]) 20 | 21 | sensitive = self.get_sensitive(self.sens_name, self.sens_classes, item) 22 | 23 | return img, label, sensitive, idx -------------------------------------------------------------------------------- /sweep/train-sweep/sweep_resamplingSWAD.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python 5 | - ${program} 6 | - ${args} 7 | - "--experiment" 8 | - "resamplingSWAD" 9 | - "--hyper_search" 10 | - "True" 11 | - "--resample_which" 12 | - "balanced" 13 | 14 | method: bayes 15 | metric: 16 | name: Validation loss 17 | goal: minimize 18 | parameters: 19 | swad_n_converge: 20 | values: 21 | - 3 22 | - 5 23 | - 7 24 | - 9 25 | swad_n_tolerance: 26 | values: 27 | - 3 28 | - 5 29 | - 7 30 | - 9 31 | swad_tolerance_ratio: 32 | values: 33 | - 0.03 34 | - 0.05 35 | - 0.1 36 | lr: 37 | distribution: uniform 38 | min: 1e-5 39 | max: 1e-3 40 | weight_decay: 41 | value: 1e-5 42 | 43 | total_epochs: 44 | value: 20 45 | -------------------------------------------------------------------------------- /sweep/train-sweep/sweep_ODR.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python 5 | - ${program} 6 | - ${args} 7 | - "--experiment" 8 | - "ODR" 9 | - "--hyper_search" 10 | - "True" 11 | 12 | method: bayes 13 | metric: 14 | name: Validation loss 15 | goal: minimize 16 | parameters: 17 | lambda_e: 18 | distribution: uniform 19 | min: 0.01 20 | max: 5.0 21 | lambda_od: 22 | distribution: uniform 23 | min: 0.01 24 | max: 5.0 25 | gamma_e: 26 | distribution: uniform 27 | min: 0.01 28 | max: 5.0 29 | gamma_od: 30 | distribution: uniform 31 | min: 0.01 32 | max: 5.0 33 | step_size: 34 | values: 35 | - 10 36 | - 20 37 | - 50 38 | lr: 39 |
distribution: uniform 40 | min: 1e-5 41 | max: 1e-3 42 | 43 | total_epochs: 44 | value: 25 45 | -------------------------------------------------------------------------------- /datasets/RadFusion_images.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | from datasets.BaseDataset import BaseDataset 5 | 6 | class RadFusion_images(BaseDataset): 7 | def __init__(self, dataframe, path_to_images, sens_name, sens_classes, transform): 8 | super(RadFusion_images, self).__init__(dataframe, path_to_images, sens_name, sens_classes, transform) 9 | 10 | self.A = self.set_A(sens_name) 11 | self.Y = (np.asarray(self.dataframe['label'].values) > 0).astype('float') 12 | self.AY_proportion = None 13 | 14 | def __getitem__(self, idx): 15 | item = self.dataframe.iloc[idx] 16 | 17 | img = np.load(os.path.join(self.path_to_images, item["Path"])) 18 | 19 | img = self.transform(img) 20 | 21 | label = torch.FloatTensor([item['label']]) 22 | 23 | sensitive = self.get_sensitive(self.sens_name, self.sens_classes, item) 24 | 25 | return img, label, sensitive, idx -------------------------------------------------------------------------------- /datasets/COVID_CT_MD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pickle 4 | import os 5 | from datasets.BaseDataset import BaseDataset 6 | 7 | class COVID_CT_MD(BaseDataset): 8 | def __init__(self, dataframe, path_to_images, sens_name, sens_classes, transform): 9 | super(COVID_CT_MD, self).__init__(dataframe, path_to_images, sens_name, sens_classes, transform) 10 | 11 | self.A = self.set_A(sens_name) 12 | self.Y = (np.asarray(self.dataframe['binary_label'].values) > 0).astype('float') 13 | self.AY_proportion = None 14 | 15 | def __getitem__(self, idx): 16 | item = self.dataframe.iloc[idx] 17 | img = np.load(os.path.join(self.path_to_images, item["Path"])) 18 | img = self.transform(img) 19 | 20 | label = torch.FloatTensor([item['binary_label']]) 21 | 22 | sensitive = self.get_sensitive(self.sens_name, self.sens_classes, item) 23 | 24 | return img, label, sensitive, idx -------------------------------------------------------------------------------- /sweep/train-sweep/sweep_GSAM.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python 5 | - ${program} 6 | - ${args} 7 | - "--experiment" 8 | - "GSAM" 9 | - "--hyper_search" 10 | - "True" 11 | - "--early_stopping" 12 | - "5" 13 | 14 | method: bayes 15 | metric: 16 | name: Validation loss 17 | goal: minimize 18 | parameters: 19 | rho: 20 | values: 21 | - 0.05 22 | #- 0.1 23 | #- 0.5 24 | #- 1 25 | gsam_alpha: 26 | values: 27 | - 0.01 28 | - 0.05 29 | - 0.1 30 | lr: 31 | values: 32 | - 0.1 33 | - 0.05 34 | - 0.01 35 | weight_decay: 36 | values: 37 | - 1e-5 38 | - 1e-4 39 | - 1e-3 40 | T_max: 41 | values: 42 | - 20 43 | - 50 44 | - 100 45 | - 200 46 | #batch_size: 47 | # values: 48 | # - 1024 49 | total_epochs: 50 | value: 40 51 | -------------------------------------------------------------------------------- /datasets/RadFusion_EHR.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from datasets.BaseDataset import BaseDataset 4 | 5 | 6 | class RadFusion_EHR(BaseDataset): 7 | def __init__(self, dataframe, data_df, sens_name, sens_classes, transform): 8 | super(RadFusion_EHR, 
self).__init__(dataframe, data_df, sens_name, sens_classes, transform) 9 | 10 | self.data_df = data_df 11 | self.A = self.set_A(sens_name) 12 | self.Y = (np.asarray(self.dataframe['label'].values) > 0).astype('float') 13 | self.AY_proportion = None 14 | 15 | def __getitem__(self, idx): 16 | item = self.dataframe.iloc[idx] 17 | ehr = self.data_df[self.data_df['idx']==item['idx']].drop(columns = ['idx']).values.squeeze() 18 | ehr = torch.FloatTensor(ehr) 19 | 20 | label = torch.FloatTensor([int(item['label'])]) 21 | 22 | sensitive = self.get_sensitive(self.sens_name, self.sens_classes, item) 23 | 24 | return ehr, label, sensitive, idx -------------------------------------------------------------------------------- /datasets/MIMIC_III.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import numpy as np 4 | from PIL import Image 5 | import pickle 6 | from datasets.BaseDataset import BaseDataset 7 | 8 | 9 | class MIMIC_III(BaseDataset): 10 | def __init__(self, dataframe, text_features, sens_name, sens_classes, transform): 11 | super(MIMIC_III, self).__init__(dataframe, text_features, sens_name, sens_classes, transform) 12 | 13 | self.text_features = text_features 14 | self.A = self.set_A(sens_name) 15 | self.Y = (np.asarray(self.dataframe['label'].values) > 0).astype('float') 16 | self.AY_proportion = None 17 | 18 | def __getitem__(self, idx): 19 | item = self.dataframe.iloc[idx] 20 | t_feature = self.text_features[idx] 21 | t_feature = torch.FloatTensor(t_feature) 22 | 23 | label = torch.FloatTensor([int(item['label'])]) 24 | 25 | sensitive = self.get_sensitive(self.sens_name, self.sens_classes, item) 26 | 27 | return t_feature, label, sensitive, idx -------------------------------------------------------------------------------- /datasets/ADNI.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import os 4 | import numpy as np 5 | from PIL import Image 6 | import pickle 7 | from datasets.BaseDataset import BaseDataset 8 | 9 | 10 | class ADNI(BaseDataset): 11 | def __init__(self, dataframe, path_to_images, sens_name, sens_classes, transform): 12 | super(ADNI, self).__init__(dataframe, path_to_images, sens_name, sens_classes, transform) 13 | 14 | self.A = self.set_A(sens_name) 15 | self.Y = (np.asarray(self.dataframe['label'].values) > 0).astype('float') 16 | self.AY_proportion = None 17 | 18 | def __getitem__(self, idx): 19 | item = self.dataframe.iloc[idx] 20 | 21 | img = np.load(os.path.join(self.path_to_images, item["Path"]).split('.nii')[0] + '.npy') 22 | 23 | img = self.transform(img) 24 | 25 | label = torch.FloatTensor([item['label']]) 26 | 27 | sensitive = self.get_sensitive(self.sens_name, self.sens_classes, item) 28 | 29 | return img, label, sensitive, idx -------------------------------------------------------------------------------- /datasets/PAPILA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import numpy as np 4 | from PIL import Image 5 | import pickle 6 | from datasets.BaseDataset import BaseDataset 7 | 8 | 9 | class PAPILA(BaseDataset): 10 | def __init__(self, dataframe, path_to_pickles, sens_name, sens_classes, transform): 11 | super(PAPILA, self).__init__(dataframe, path_to_pickles, sens_name, sens_classes, transform) 12 | 13 | with open(path_to_pickles, 'rb') as f: 14 | self.tol_images = pickle.load(f) 15 | 16 | self.A = self.set_A(sens_name) 17 | 
self.Y = (np.asarray(self.dataframe['Diagnosis'].values) > 0).astype('float') 18 | self.AY_proportion = None 19 | 20 | def __getitem__(self, idx): 21 | item = self.dataframe.iloc[idx] 22 | img = Image.fromarray(self.tol_images[idx]) 23 | img = self.transform(img) 24 | 25 | label = torch.FloatTensor([item['Diagnosis']]) 26 | 27 | sensitive = self.get_sensitive(self.sens_name, self.sens_classes, item) 28 | 29 | return img, label, sensitive, idx -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Welcome to the documentation of MEDFAIR. 2 | 3 | ## Quick Start 4 | - [Installation](quickstart.md#installation) 5 | - [Dataset Download](quickstart.md#dataset-download) 6 | - [Running Experiments](quickstart.md#usage) 7 | - [Analyze Results](quickstart.md#analysis) 8 | 9 | ## Experiment Customization 10 | - [Customize Dataset](customization.md#customize-dataset) 11 | - [Customize Network Architectures](customization.md#customize-network-architectures) 12 | - [Customize Debiasing Algorithms](customization.md#customize-debiasing-algorithms) 13 | - [Customize Evaluation Metrics](customization.md#customize-evaluation-metrics) 14 | 15 | ## Code Structure 16 | - `configs`: Configuration for the datasets and [Weights & Biases](https://wandb.ai/). 17 | - `datasets`: Dataset classes for loading data. 18 | - `docs`: Documentation. 19 | - `models`: Implementations of the debiasing algorithms. 20 | - `notebooks`: Data preprocessing/analysis scripts. 21 | - `utils`: Utility functions. 22 | - `main.py`: Main entry point. 23 | 24 | ## References 25 | - [Datasets](reference.md#datasets) 26 | - [Debiasing Methods](reference.md#debiasing-methods) 27 | -------------------------------------------------------------------------------- /datasets/Fitz17k.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import numpy as np 4 | from PIL import Image 5 | import pickle 6 | from datasets.BaseDataset import BaseDataset 7 | 8 | 9 | class Fitz17k(BaseDataset): 10 | def __init__(self, dataframe, path_to_pickles, sens_name, sens_classes, transform): 11 | super(Fitz17k, self).__init__(dataframe, path_to_pickles, sens_name, sens_classes, transform) 12 | 13 | with open(path_to_pickles, 'rb') as f: 14 | self.tol_images = pickle.load(f) 15 | self.A = self.set_A(sens_name) 16 | 17 | self.Y = (np.asarray(self.dataframe['binary_label'].values) > 0).astype('float') 18 | self.AY_proportion = None 19 | 20 | def __getitem__(self, idx): 21 | item = self.dataframe.iloc[idx] 22 | img = Image.fromarray(self.tol_images[idx]) 23 | img = self.transform(img) 24 | 25 | label = torch.FloatTensor([int(item['binary_label'])]) 26 | 27 | sensitive = self.get_sensitive(self.sens_name, self.sens_classes, item) 28 | 29 | return img, label, sensitive, idx -------------------------------------------------------------------------------- /datasets/CXP.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import numpy as np 4 | from PIL import Image 5 | import pickle 6 | from datasets.BaseDataset import BaseDataset 7 | 8 | 9 | class CXP(BaseDataset): 10 | def __init__(self, dataframe, path_to_pickles, sens_name, sens_classes, transform): 11 | super(CXP, self).__init__(dataframe, path_to_pickles, sens_name, sens_classes, transform) 12 | 13 | with open(path_to_pickles, 'rb') as f: 14 | self.tol_images =
pickle.load(f) 15 | 16 | self.A = self.set_A(sens_name) 17 | self.Y = (np.asarray(self.dataframe['No Finding'].values) > 0).astype('float') 18 | self.AY_proportion = None 19 | 20 | def __getitem__(self, idx): 21 | item = self.dataframe.iloc[idx] 22 | 23 | img = Image.fromarray(self.tol_images[idx]).convert('RGB') 24 | img = self.transform(img) 25 | 26 | label = torch.FloatTensor([int(item['No Finding'].astype('float') > 0)]) 27 | 28 | sensitive = self.get_sensitive(self.sens_name, self.sens_classes, item) 29 | 30 | return img, label, sensitive, idx -------------------------------------------------------------------------------- /datasets/MIMIC_CXR.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import numpy as np 4 | from PIL import Image 5 | import pickle 6 | from datasets.BaseDataset import BaseDataset 7 | 8 | 9 | class MIMIC_CXR(BaseDataset): 10 | def __init__(self, dataframe, PATH_TO_IMAGES, sens_name, sens_classes, transform): 11 | super(MIMIC_CXR, self).__init__(dataframe, PATH_TO_IMAGES, sens_name, sens_classes, transform) 12 | 13 | with open(PATH_TO_IMAGES, 'rb') as f: 14 | self.tol_images = pickle.load(f) 15 | 16 | self.A = self.set_A(sens_name) 17 | self.Y = (np.asarray(self.dataframe['No Finding'].values) > 0).astype('float') 18 | self.AY_proportion = None 19 | 20 | def __getitem__(self, idx): 21 | item = self.dataframe.iloc[idx] 22 | img = Image.fromarray(self.tol_images[idx]).convert('RGB') 23 | img = self.transform(img) 24 | 25 | label = torch.FloatTensor([int(item['No Finding'].astype('float') > 0)]) 26 | 27 | sensitive = self.get_sensitive(self.sens_name, self.sens_classes, item) 28 | 29 | return img, label, sensitive, idx -------------------------------------------------------------------------------- /datasets/eICU.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from datasets.BaseDataset import BaseDataset 4 | 5 | 6 | class eICU(BaseDataset): 7 | def __init__(self, dataframe, s_features, sens_name, sens_classes, transform): 8 | super(eICU, self).__init__(dataframe, s_features, sens_name, sens_classes, transform) 9 | 10 | self.s_features = s_features 11 | self.A = self.set_A(sens_name) 12 | self.Y = (np.asarray(self.dataframe['mortality_LABEL'].values) > 0).astype('float') 13 | self.AY_proportion = None 14 | 15 | def __getitem__(self, idx): 16 | item = self.dataframe.iloc[idx] 17 | patient_idx = int(item['patientunitstayid']) 18 | feature_idx = int(np.where(self.s_features[:, -1]==patient_idx)[0]) 19 | s_feature = self.s_features[feature_idx, :-1] 20 | s_feature = torch.FloatTensor(s_feature) 21 | 22 | label = torch.FloatTensor([int(item['mortality_LABEL'])]) 23 | 24 | sensitive = self.get_sensitive(self.sens_name, self.sens_classes, item) 25 | 26 | return s_feature, label, sensitive, idx -------------------------------------------------------------------------------- /datasets/HAM10000.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import numpy as np 4 | from PIL import Image 5 | import pickle 6 | from datasets.BaseDataset import BaseDataset 7 | 8 | 9 | class HAM10000(BaseDataset): 10 | def __init__(self, dataframe, path_to_pickles, sens_name, sens_classes, transform): 11 | super(HAM10000, self).__init__(dataframe, path_to_pickles, sens_name, sens_classes, transform) 12 | 13 | with open(path_to_pickles, 'rb') as f: 14 | self.tol_images = 
pickle.load(f) 15 | 16 | self.A = self.set_A(sens_name) 17 | self.Y = (np.asarray(self.dataframe['binaryLabel'].values) > 0).astype('float') 18 | self.AY_proportion = None 19 | 20 | 21 | def __getitem__(self, idx): 22 | item = self.dataframe.iloc[idx] 23 | 24 | img = Image.fromarray(self.tol_images[idx]) 25 | img = self.transform(img) 26 | 27 | label = torch.FloatTensor([int(item['binaryLabel'])]) 28 | 29 | sensitive = self.get_sensitive(self.sens_name, self.sens_classes, item) 30 | 31 | return img, label, sensitive, idx -------------------------------------------------------------------------------- /models/baseline/baseline.py: -------------------------------------------------------------------------------- 1 | from models.utils import standard_train 2 | from models.basenet import BaseNet 3 | from importlib import import_module 4 | 5 | 6 | class baseline(BaseNet): 7 | def __init__(self, opt, wandb): 8 | super(baseline, self).__init__(opt, wandb) 9 | self.set_network(opt) 10 | self.set_optimizer(opt) 11 | 12 | def set_network(self, opt): 13 | """Define the network""" 14 | 15 | if self.is_3d: 16 | mod = import_module("models.basemodels_3d") 17 | cusModel = getattr(mod, self.backbone) 18 | self.network = cusModel(n_classes=self.output_dim, pretrained = self.pretrained).to(self.device) 19 | elif self.is_tabular: 20 | mod = import_module("models.basemodels_mlp") 21 | cusModel = getattr(mod, self.backbone) 22 | self.network = cusModel(n_classes=self.output_dim, in_features= self.in_features, hidden_features = 1024).to(self.device) 23 | else: 24 | mod = import_module("models.basemodels") 25 | cusModel = getattr(mod, self.backbone) 26 | self.network = cusModel(n_classes=self.output_dim, pretrained=self.pretrained).to(self.device) 27 | 28 | def _train(self, loader): 29 | """Train the model for one epoch""" 30 | 31 | self.network.train() 32 | auc, train_loss = standard_train(self.opt, self.network, self.optimizer, loader, self._criterion, self.wandb) 33 | 34 | print('Training epoch {}: AUC:{}'.format(self.epoch, auc)) 35 | print('Training epoch {}: loss:{}'.format(self.epoch, train_loss)) 36 | 37 | self.epoch += 1 38 | -------------------------------------------------------------------------------- /models/basemodels_mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from importlib import import_module 5 | 6 | 7 | class MLP(nn.Module): 8 | def __init__(self, in_features, hidden_features=1024, out_features=1): 9 | super().__init__() 10 | self.fc1 = nn.Linear(in_features, hidden_features) 11 | self.relu = nn.ReLU() 12 | self.fc2 = nn.Linear(hidden_features, out_features) 13 | 14 | 15 | def forward(self, x): 16 | x1 = self.fc1(x) 17 | x_hidden = self.relu(x1) 18 | x_out = self.fc2(x_hidden) 19 | return x_out, x_hidden.squeeze() 20 | 21 | 22 | class cusMLP(nn.Module): 23 | def __init__(self, n_classes, in_features, hidden_features, disentangle = False): 24 | super(cusMLP, self).__init__() 25 | self.backbone = MLP(in_features, hidden_features, n_classes) 26 | 27 | if disentangle is True: 28 | self.backbone.fc2 = nn.Linear(hidden_features * 2, n_classes) 29 | 30 | def forward(self, x): 31 | outputs, hidden = self.backbone(x) 32 | return outputs, hidden 33 | 34 | def inference(self, x): 35 | outputs, hidden = self.backbone(x) 36 | return outputs, hidden 37 | 38 | 39 | class MLPclassifer(nn.Module): 40 | def __init__(self, input_dim, output_dim): 41 | super(MLPclassifer, 
self).__init__() 42 | self.relu = nn.ReLU() 43 | self.fc1 = nn.Linear(input_dim, output_dim) 44 | #self.fc2 = nn.Linear(hidden_dim, output_dim) 45 | 46 | def forward(self,x): 47 | x = self.relu(x) 48 | x = self.fc1(x) 49 | #x = self.fc2(x) 50 | return x -------------------------------------------------------------------------------- /models/EnD/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from importlib import import_module 4 | import numpy as np 5 | 6 | 7 | class pattern_norm(nn.Module): 8 | def __init__(self, scale = 1.0): 9 | super(pattern_norm, self).__init__() 10 | self.scale = scale 11 | 12 | def forward(self, input): 13 | sizes = input.size() 14 | if len(sizes) > 2: 15 | input = input.view(-1, np.prod(sizes[1:])) 16 | input = torch.nn.functional.normalize(input, p=2, dim=1, eps=1e-12) 17 | input = input.view(sizes) 18 | return input 19 | 20 | 21 | def EnDNet(backbone, n_classes, pretrained = True): 22 | mod = import_module("models.basemodels") 23 | cusModel = getattr(mod, backbone) 24 | model = cusModel(n_classes=n_classes, pretrained=pretrained) 25 | model.body.avgpool = nn.Sequential( 26 | model.avgpool, 27 | pattern_norm() 28 | ) 29 | return model 30 | 31 | 32 | def EnDNet3D(backbone, n_classes, pretrained = True): 33 | 34 | mod = import_module("models.basemodels_3d") 35 | cusModel = getattr(mod, backbone) 36 | model = cusModel(n_classes=n_classes, pretrained=pretrained) 37 | model.body.avgpool = nn.Sequential( 38 | model.avgpool, 39 | pattern_norm() 40 | ) 41 | return model 42 | 43 | def EnDNetMLP(backbone, n_classes, in_features, hidden_features=1024): 44 | mod = import_module("models.basemodels_mlp") 45 | cusModel = getattr(mod, backbone) 46 | model = cusModel(n_classes=n_classes, in_features= in_features, hidden_features=hidden_features) 47 | model.backbone.fc1 = nn.Sequential( 48 | model.backbone.fc1, 49 | pattern_norm() 50 | ) 51 | return model -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import parse_args 2 | import json 3 | import numpy as np 4 | import pandas as pd 5 | from utils import basics 6 | import glob 7 | 8 | 9 | def train(model, opt): 10 | for epoch in range(opt['total_epochs']): 11 | ifbreak = model.train(epoch) 12 | if ifbreak: 13 | break 14 | 15 | # record val metrics for hyperparameter selection 16 | pred_df = model.record_val() 17 | return pred_df 18 | 19 | 20 | if __name__ == '__main__': 21 | 22 | opt, wandb = parse_args.collect_args() 23 | if not opt['test_mode']: 24 | 25 | random_seeds = np.random.choice(range(100), size = 3, replace=False).tolist() 26 | val_df = pd.DataFrame() 27 | test_df = pd.DataFrame() 28 | print('Random seed: ', random_seeds) 29 | for random_seed in random_seeds: 30 | opt['random_seed'] = random_seed 31 | model = basics.get_model(opt, wandb) 32 | pred_df = train(model, opt) 33 | val_df = pd.concat([val_df, pred_df]) 34 | 35 | pred_df = model.test() 36 | test_df = pd.concat([test_df, pred_df]) 37 | 38 | stat_val = basics.avg_eval(val_df, opt, 'val') 39 | stat_test = basics.avg_eval(test_df, opt, 'test') 40 | model.log_wandb(stat_val.to_dict()) 41 | model.log_wandb(stat_test.to_dict()) 42 | else: 43 | 44 | if opt['cross_testing']: 45 | 46 | test_df = pd.DataFrame() 47 | method_model_path = opt['cross_testing_model_path'] 48 | model_paths = glob.glob(method_model_path + '/cross_domain_*.pth') 49 | for model_path in 
model_paths: 50 | opt['cross_testing_model_path_single'] = model_path 51 | model = basics.get_model(opt, wandb) 52 | pred_df = model.test() 53 | 54 | test_df = pd.concat([test_df, pred_df]) 55 | stat_test = basics.avg_eval(test_df, opt, 'cross_testing') 56 | 57 | model.log_wandb(stat_test.to_dict()) -------------------------------------------------------------------------------- /models/basemodels_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.nn as nn 4 | from torchvision.models.feature_extraction import create_feature_extractor 5 | 6 | 7 | class cusResNet18_3d(nn.Module): 8 | def __init__(self, n_classes, pretrained = True): 9 | super(cusResNet18_3d, self).__init__() 10 | resnet = torchvision.models.video.r3d_18(pretrained=pretrained) 11 | 12 | resnet.fc = nn.Linear(resnet.fc.in_features, n_classes) 13 | self.avgpool = resnet.avgpool 14 | 15 | self.returnkey_avg = 'avgpool' 16 | self.returnkey_fc = 'fc' 17 | self.body = create_feature_extractor( 18 | resnet, return_nodes={'avgpool': self.returnkey_avg, 'fc': self.returnkey_fc}) 19 | 20 | def forward(self, x): 21 | outputs = self.body(x) 22 | return outputs[self.returnkey_fc], outputs[self.returnkey_avg].squeeze() 23 | 24 | def inference(self, x): 25 | outputs = self.body(x) 26 | return outputs[self.returnkey_fc], outputs[self.returnkey_avg].squeeze() 27 | 28 | 29 | class cusResNet50_3d(cusResNet18_3d): 30 | def __init__(self, n_classes, pretrained = True): 31 | super(cusResNet50_3d, self).__init__(n_classes, pretrained) 32 | resnet = torch.hub.load('facebookresearch/pytorchvideo', 'slow_r50', pretrained=pretrained) 33 | 34 | resnet.blocks[-1].proj = nn.Linear(2048, n_classes) 35 | self.avgpool = resnet.blocks[-1].pool 36 | 37 | self.returnkey_avg = 'avgpool' 38 | self.returnkey_fc = 'fc' 39 | self.body = create_feature_extractor( 40 | resnet, return_nodes={'blocks.5.pool': self.returnkey_avg, 'blocks.5.proj': self.returnkey_fc}) 41 | 42 | 43 | class MLPclassifer(nn.Module): 44 | def __init__(self, input_dim, hidden_dim, output_dim): 45 | super(MLPclassifer, self).__init__() 46 | self.relu = nn.ReLU() 47 | self.fc1 = nn.Linear(input_dim, output_dim) 48 | #self.fc2 = nn.Linear(hidden_dim, output_dim) 49 | 50 | def forward(self,x): 51 | x = self.relu(x) 52 | x = self.fc1(x) 53 | #x = self.fc2(x) 54 | return x -------------------------------------------------------------------------------- /configs/datasets.json: -------------------------------------------------------------------------------- 1 | { 2 | "CXP": 3 | { 4 | "image_feature_path": "yourpath", 5 | "pickle_train_path": "yourpath", 6 | "pickle_val_path": "yourpath", 7 | "pickle_test_path": "yourpath", 8 | 9 | "train_meta_path": "yourpath", 10 | "val_meta_path": "yourpath", 11 | "test_meta_path": "yourpath" 12 | }, 13 | "MIMIC_CXR": 14 | { 15 | "image_feature_path": "yourpath", 16 | "pickle_train_path": "yourpath", 17 | "pickle_val_path": "yourpath", 18 | "pickle_test_path": "yourpath", 19 | "train_meta_path": "yourpath", 20 | "val_meta_path": "yourpath", 21 | "test_meta_path": "yourpath" 22 | }, 23 | "PAPILA": 24 | { 25 | "image_feature_path": "yourpath", 26 | "pickle_train_path": "yourpath", 27 | "pickle_val_path": "yourpath", 28 | "pickle_test_path": "yourpath", 29 | "train_meta_path": "yourpath", 30 | "val_meta_path": "yourpath", 31 | "test_meta_path": "yourpath" 32 | }, 33 | "OCT": 34 | { 35 | "image_feature_path": "yourpath", 36 | "train_meta_path": "yourpath", 37 | "val_meta_path": 
"yourpath", 38 | "test_meta_path": "yourpath" 39 | }, 40 | "HAM10000": 41 | { 42 | "image_feature_path": "yourpath", 43 | "pickle_train_path": "yourpath", 44 | "pickle_val_path": "yourpath", 45 | "pickle_test_path": "yourpath", 46 | "train_meta_path": "yourpath", 47 | "val_meta_path": "yourpath", 48 | "test_meta_path": "yourpath" 49 | }, 50 | "Fitz17k": 51 | { 52 | "image_feature_path": "yourpath", 53 | "pickle_train_path": "yourpath", 54 | "pickle_val_path": "yourpath", 55 | "pickle_test_path": "yourpath", 56 | "train_meta_path": "yourpath", 57 | "val_meta_path": "yourpath", 58 | "test_meta_path": "yourpath" 59 | }, 60 | "ADNI": 61 | { 62 | "image_feature_path": "yourpath", 63 | "train_meta_path": "yourpath", 64 | "val_meta_path": "yourpath", 65 | "test_meta_path": "yourpath" 66 | }, 67 | "ADNI3T": 68 | { 69 | "image_feature_path": "yourpath", 70 | "train_meta_path": "yourpath", 71 | "val_meta_path": "yourpath", 72 | "test_meta_path": "yourpath" 73 | }, 74 | "COVID_CT_MD": 75 | { 76 | "image_feature_path": "yourpath", 77 | "train_meta_path": "yourpath", 78 | "val_meta_path": "yourpath", 79 | "test_meta_path": "yourpath" 80 | } 81 | } -------------------------------------------------------------------------------- /sweep/test/cross_domain/batch_submit.sh: -------------------------------------------------------------------------------- 1 | OPTIONS=d: 2 | LONGOPTS=experiment:,dataset_name:,sensitive_name:,output_dim:,num_classes:,batch_size:,cross_testing_model_path:,sens_classes:,backbone:,source_domain:,target_domain: 3 | 4 | # -regarding ! and PIPESTATUS see above 5 | # -temporarily store output to be able to check for errors 6 | # -activate quoting/enhanced mode (e.g. by writing out “--options”) 7 | # -pass arguments only via -- "$@" to separate them correctly 8 | 9 | ! 
PARSED=$(getopt --options=$OPTIONS --longoptions=$LONGOPTS --name "$0" -- "$@") 10 | 11 | eval set -- "$PARSED" 12 | 13 | experiment="baseline" 14 | dataset_name="CXP" 15 | sensitive_name="Age" 16 | wandb_name="default" 17 | output_dim=1 18 | num_classes=1 19 | batch_size=1024 20 | backbone="cusResNet18" 21 | sens_classes=2 22 | cross_testing_model_path="" 23 | source_domain="" 24 | target_domain="" 25 | 26 | while true; do 27 | case "$1" in 28 | --experiment) 29 | experiment="$2" 30 | shift 2 31 | ;; 32 | --dataset_name) 33 | dataset_name="$2" 34 | shift 2 35 | ;; 36 | --sensitive_name) 37 | sensitive_name="$2" 38 | shift 2 39 | ;; 40 | --output_dim) 41 | output_dim="$2" 42 | shift 2 43 | ;; 44 | --num_classes) 45 | num_classes="$2" 46 | shift 2 47 | ;; 48 | --batch_size) 49 | batch_size="$2" 50 | shift 2 51 | ;; 52 | --cross_testing_model_path) 53 | cross_testing_model_path="$2" 54 | shift 2 55 | ;; 56 | --sens_classes) 57 | sens_classes="$2" 58 | shift 2 59 | ;; 60 | --backbone) 61 | backbone="$2" 62 | shift 2 63 | ;; 64 | --source_domain) 65 | source_domain="$2" 66 | shift 2 67 | ;; 68 | --target_domain) 69 | target_domain="$2" 70 | shift 2 71 | ;; 72 | --) 73 | shift 74 | break 75 | ;; 76 | *) 77 | echo "Programming error" 78 | exit 3 79 | ;; 80 | esac 81 | done 82 | 83 | 84 | python main.py --experiment $experiment --dataset_name $dataset_name --experiment_name $wandb_name --sensitive_name $sensitive_name --output_dim $output_dim --num_classes $num_classes --batch_size $batch_size --sens_classes $sens_classes --cross_testing --cross_testing_model_path $cross_testing_model_path --test_mode True --backbone $backbone --source_domain $source_domain --target_domain $target_domain -------------------------------------------------------------------------------- /utils/basics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | from importlib import import_module 8 | import random 9 | import csv 10 | 11 | 12 | def save_results(t_predictions, tol_target, s_prediction, tol_sensitive, path): 13 | np.save(os.path.join(path, 'tpredictions.npy'), np.asarray(t_predictions)) 14 | np.save(os.path.join(path, 'ttargets.npy'), np.asarray(tol_target)) 15 | np.save(os.path.join(path, 'spredictions.npy'), np.asarray(s_prediction)) 16 | np.save(os.path.join(path, 'stargets.npy'), np.asarray(tol_sensitive)) 17 | 18 | 19 | def save_result_csv(log_dict, path): 20 | with open(path + '/results.csv', 'w') as f: 21 | w = csv.DictWriter(f, log_dict.keys()) 22 | w.writeheader() 23 | w.writerow(log_dict) 24 | 25 | 26 | def add_dict_prefix(dicts, prefix): 27 | new_dict = {} 28 | for k, v in dicts.items(): 29 | new_dict[prefix + k] = dicts[k] 30 | return new_dict 31 | 32 | 33 | def get_model(opt, wandb): 34 | mod = import_module("models" + '.' 
+ opt['experiment']) 35 | model_name = getattr(mod, opt['experiment']) 36 | model = model_name(opt, wandb) 37 | return model 38 | 39 | 40 | def avg_eval(val_df, opt, mode = 'val'): 41 | val_df = val_df.reset_index(drop=True) 42 | 43 | mean_df = val_df.mean() 44 | std_df = val_df.std() 45 | sem_df = val_df.sem() 46 | ci95_hi = pd.DataFrame(mean_df + 1.96 * sem_df).transpose() 47 | ci95_lo = pd.DataFrame(mean_df - 1.96 * sem_df).transpose() 48 | mean_df = pd.DataFrame(mean_df).transpose() 49 | std_df = pd.DataFrame(std_df).transpose() 50 | 51 | stat = pd.concat([mean_df, std_df, ci95_hi, ci95_lo]).reset_index(drop=True) 52 | stat = stat.rename(index={0: 'mean', 1: 'std', 2: 'ci95_hi', 3: 'ci95_lo'}) 53 | save_path = os.path.join(opt['save_folder'], opt['experiment'] + '_'+ opt['hash'] + '_' + mode + '_pred_stat.csv') 54 | stat.to_csv(save_path) 55 | return stat 56 | 57 | 58 | def save_pkl(pkl_data, save_path): 59 | with open(save_path, 'wb') as f: 60 | pickle.dump(pkl_data, f) 61 | 62 | 63 | def load_pkl(load_path): 64 | with open(load_path, 'rb') as f: 65 | pkl_data = pickle.load(f) 66 | return pkl_data 67 | 68 | 69 | def save_json(json_data, save_path): 70 | with open(save_path, 'w') as f: 71 | json.dump(json_data, f) 72 | 73 | 74 | def load_json(load_path): 75 | with open(load_path, 'r') as f: 76 | json_data = json.load(f) 77 | return json_data 78 | 79 | 80 | def save_state_dict(state_dict, save_path): 81 | torch.save(state_dict, save_path) 82 | 83 | 84 | def creat_folder(path): 85 | if not os.path.exists(path): 86 | os.makedirs(path) 87 | 88 | -------------------------------------------------------------------------------- /sweep/test/cross_domain/slurm_batch_submit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | #SBATCH -N 1 3 | #SBATCH --ntasks-per-node=1 4 | #SBATCH --partition ampere 5 | #SBATCH --gres=gpu:1 6 | #SBATCH --account your_account 7 | #SBATCH --time=00:30:30 8 | 9 | OPTIONS=d: 10 | LONGOPTS=experiment:,dataset_name:,sensitive_name:,output_dim:,num_classes:,batch_size:,cross_testing_model_path:,sens_classes:,backbone:,source_domain:,target_domain: 11 | 12 | # -regarding ! and PIPESTATUS see above 13 | # -temporarily store output to be able to check for errors 14 | # -activate quoting/enhanced mode (e.g. by writing out “--options”) 15 | # -pass arguments only via -- "$@" to separate them correctly 16 | 17 | ! 
PARSED=$(getopt --options=$OPTIONS --longoptions=$LONGOPTS --name "$0" -- "$@") 18 | 19 | eval set -- "$PARSED" 20 | 21 | experiment="baseline" 22 | dataset_name="CXP" 23 | sensitive_name="Age" 24 | wandb_name="default" 25 | output_dim=1 26 | num_classes=1 27 | batch_size=1024 28 | backbone="cusResNet18" 29 | sens_classes=2 30 | cross_testing_model_path="" 31 | source_domain="" 32 | target_domain="" 33 | 34 | while true; do 35 | case "$1" in 36 | --experiment) 37 | experiment="$2" 38 | shift 2 39 | ;; 40 | --dataset_name) 41 | dataset_name="$2" 42 | shift 2 43 | ;; 44 | --sensitive_name) 45 | sensitive_name="$2" 46 | shift 2 47 | ;; 48 | --output_dim) 49 | output_dim="$2" 50 | shift 2 51 | ;; 52 | --num_classes) 53 | num_classes="$2" 54 | shift 2 55 | ;; 56 | --batch_size) 57 | batch_size="$2" 58 | shift 2 59 | ;; 60 | --cross_testing_model_path) 61 | cross_testing_model_path="$2" 62 | shift 2 63 | ;; 64 | --sens_classes) 65 | sens_classes="$2" 66 | shift 2 67 | ;; 68 | --backbone) 69 | backbone="$2" 70 | shift 2 71 | ;; 72 | --source_domain) 73 | source_domain="$2" 74 | shift 2 75 | ;; 76 | --target_domain) 77 | target_domain="$2" 78 | shift 2 79 | ;; 80 | --) 81 | shift 82 | break 83 | ;; 84 | *) 85 | echo "Programming error" 86 | exit 3 87 | ;; 88 | esac 89 | done 90 | 91 | 92 | python main.py --experiment $experiment --dataset_name $dataset_name --experiment_name $wandb_name --sensitive_name $sensitive_name --output_dim $output_dim --num_classes $num_classes --batch_size $batch_size --sens_classes $sens_classes --cross_testing --cross_testing_model_path $cross_testing_model_path --test_mode True --backbone $backbone --source_domain $source_domain --target_domain $target_domain -------------------------------------------------------------------------------- /models/basemodels.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torchvision.models.feature_extraction import create_feature_extractor 6 | 7 | 8 | class cusResNet18(nn.Module): 9 | def __init__(self, n_classes, pretrained = True): 10 | super(cusResNet18, self).__init__() 11 | resnet = torchvision.models.resnet18(pretrained=pretrained) 12 | 13 | resnet.fc = nn.Linear(resnet.fc.in_features, n_classes) 14 | self.avgpool = resnet.avgpool 15 | 16 | self.returnkey_avg = 'avgpool' 17 | self.returnkey_fc = 'fc' 18 | self.body = create_feature_extractor( 19 | resnet, return_nodes={'avgpool': self.returnkey_avg, 'fc': self.returnkey_fc}) 20 | 21 | def forward(self, x): 22 | outputs = self.body(x) 23 | return outputs[self.returnkey_fc], outputs[self.returnkey_avg].squeeze() 24 | 25 | def inference(self, x): 26 | outputs = self.body(x) 27 | return outputs[self.returnkey_fc], outputs[self.returnkey_avg].squeeze() 28 | 29 | 30 | class cusResNet50(cusResNet18): 31 | def __init__(self, n_classes, pretrained = True): 32 | super(cusResNet50, self).__init__(n_classes, pretrained) 33 | resnet = torchvision.models.resnet50(pretrained=pretrained) 34 | resnet.fc = nn.Linear(resnet.fc.in_features, n_classes) 35 | 36 | self.avgpool = resnet.avgpool 37 | self.returnkey_avg = 'avgpool' 38 | self.returnkey_fc = 'fc' 39 | self.body = create_feature_extractor( 40 | resnet, return_nodes={'avgpool': self.returnkey_avg, 'fc': self.returnkey_fc}) 41 | 42 | 43 | class cusDenseNet121(cusResNet18): 44 | def __init__(self, n_classes, pretrained = True, disentangle = False): 45 | super(cusDenseNet121, 
self).__init__(n_classes, pretrained) 46 | resnet = torchvision.models.densenet121(pretrained=pretrained) 47 | 48 | resnet.classifier = nn.Linear(resnet.classifier.in_features, n_classes) 49 | 50 | self.returnkey_fc = 'classifier' 51 | self.body = create_feature_extractor( 52 | resnet, return_nodes={'classifier': self.returnkey_fc}) 53 | 54 | def forward(self, x): 55 | outputs = self.body(x) 56 | return outputs[self.returnkey_fc], outputs[self.returnkey_fc] 57 | 58 | def inference(self, x): 59 | outputs = self.body(x) 60 | return outputs[self.returnkey_fc], outputs[self.returnkey_fc] 61 | 62 | 63 | class MLPclassifer(nn.Module): 64 | def __init__(self, input_dim, hidden_dim, output_dim): 65 | super(MLPclassifer, self).__init__() 66 | self.relu = nn.ReLU() 67 | self.fc1 = nn.Linear(input_dim, output_dim) 68 | 69 | def forward(self,x): 70 | x = self.relu(x) 71 | x = self.fc1(x) 72 | #x = self.fc2(x) 73 | return x -------------------------------------------------------------------------------- /docs/reference.md: -------------------------------------------------------------------------------- 1 | # Reference 2 | 3 | ## Datasets 4 | | **Dataset** | **Link** | 5 | |--------------|-----------------------------------------------------------------------------------------------| 6 | | CheXpert | [paper](https://arxiv.org/abs/1901.07031) | 7 | | MIMIC-CXR | [paper](https://arxiv.org/abs/1901.07042) | 8 | | PAPILA | [paper](https://www.nature.com/articles/s41597-022-01388-1) | 9 | | HAM10000 | [paper](https://www.nature.com/articles/sdata2018161) | 10 | | OCT | [paper](https://www.sciencedirect.com/science/article/pii/S016164201300612X) | 11 | | COVID-CT-MD | [paper](https://www.nature.com/articles/s41597-021-00900-3) | 12 | | Fitzpatrick17k | [paper](https://arxiv.org/abs/2104.09957) | 13 | | ADNI | [paper](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2809036/) | 14 | 15 | 16 | ## Debiasing Methods 17 | | **Method** | **Link** | 18 | |--------------|-----------------------------------------------------------------------------------------------| 19 | | DomainInd | [paper](https://openaccess.thecvf.com/content_CVPR_2020/papers/Wang_Towards_Fairness_in_Visual_Recognition_Effective_Strategies_for_Bias_Mitigation_CVPR_2020_paper.pdf), [original repo](https://github.com/princetonvisualai/DomainBiasMitigation) | 20 | | LAFTR | [paper](https://arxiv.org/abs/1802.06309), [original repo](https://github.com/VectorInstitute/laftr) | 21 | | CFair | [paper](https://arxiv.org/abs/1910.07162), [original repo](https://github.com/KeiraZhao/ICLR2020-CFair) | 22 | | LNL | [paper](https://arxiv.org/abs/1812.10352), [original repo](https://github.com/feidfoe/learning-not-to-learn) | 23 | | EnD | [paper](https://arxiv.org/abs/2103.02023), [original repo](https://github.com/EIDOSLAB/entangling-disentangling-bias) | 24 | | ODR | [paper](https://arxiv.org/abs/2003.05707), [original repo](https://github.com/spyrosavl/Fairness-by-Learning-Orthogonal-Disentangled-Representations) | 25 | | GroupDRO | [paper](https://arxiv.org/abs/1911.08731), [original repo](https://github.com/kohpangwei/group_DRO) | 26 | | SWAD | [paper](https://arxiv.org/abs/2102.08604), [original repo](https://github.com/khanrc/swad) | 27 | | SAM | [paper](https://openreview.net/forum?id=6Tm1mposlrM), [original repo](https://github.com/google-research/sam) |
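Each method in the table above is implemented as a `models/<Name>` package that exports a class named exactly like the `--experiment` flag (see the `models/*/__init__.py` files earlier in this listing), and `utils.basics.get_model` loads it by dynamic import. A minimal sketch of that dispatch convention — `opt` and `wandb` are just the option dict and logger handle passed through `main.py`:

```python
from importlib import import_module

def load_method(experiment, opt, wandb=None):
    # e.g. experiment="GroupDRO" imports models.GroupDRO and instantiates
    # models.GroupDRO.GroupDRO(opt, wandb), mirroring utils.basics.get_model.
    mod = import_module("models." + experiment)
    cls = getattr(mod, experiment)
    return cls(opt, wandb)
```

Adding a new debiasing method therefore only requires a new `models/<Name>` package that follows the same export pattern.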
-------------------------------------------------------------------------------- /models/ODR/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from torch.distributions.multivariate_normal import MultivariateNormal 6 | from torch.distributions import Normal 7 | 8 | 9 | class OrthoLoss(nn.Module): 10 | def __init__(self, lambda_e, lambda_od, gamma_e, gamma_od, step_size, device): 11 | super(OrthoLoss, self).__init__() 12 | self.lambda_e = lambda_e 13 | self.lambda_od = lambda_od 14 | self.gamma_e = gamma_e 15 | self.gamma_od = gamma_od 16 | self.step_size = step_size 17 | self.device = device 18 | 19 | self.bce = nn.BCEWithLogitsLoss() 20 | self.cross = nn.CrossEntropyLoss() 21 | self.kld = nn.KLDivLoss(reduction='batchmean') 22 | 23 | def mean_tensors(self, mean_1, mean_2, i): 24 | mean_1[i] = 1 25 | mean_2[i] = 0 26 | mean_t = torch.from_numpy(mean_1).float() 27 | mean_s = torch.from_numpy(mean_2).float() 28 | return mean_t, mean_s 29 | 30 | def L_e(self, sen_dis_out): 31 | L_e = -torch.sum(torch.softmax(sen_dis_out, dim=1) * torch.log_softmax(sen_dis_out, dim=1)) / sen_dis_out.shape[0] 32 | return L_e 33 | 34 | def forward(self, inputs, target, sensitive, current_step): 35 | mean_t, mean_s, log_std_t, log_std_s = inputs[0] 36 | y_zt, s_zt, s_zs = inputs[1] 37 | z1, z2 = inputs[2] 38 | y_zt, s_zt, s_zs = y_zt.to(self.device), s_zt.to(self.device), s_zs.to(self.device) 39 | target = target.to(self.device) 40 | 41 | L_t = self.bce(y_zt, target) 42 | mean_1, mean_2 = self.mean_tensors(np.zeros(128), np.ones(128), 13) 43 | m_t = MultivariateNormal(mean_1, torch.eye(128)) 44 | m_s = MultivariateNormal(mean_2, torch.eye(128)) 45 | 46 | Loss_e = self.L_e(s_zt) 47 | prior_t=[]; prior_s=[] 48 | enc_dis_t=[]; enc_dis_s=[] 49 | 50 | for i in range(z1.shape[0]): 51 | prior_t.append(m_t.sample()) 52 | prior_s.append(m_s.sample()) 53 | n_t = MultivariateNormal(mean_t[i], torch.diag(torch.exp(log_std_t[i]))) 54 | n_s = MultivariateNormal(mean_s[i], torch.diag(torch.exp(log_std_s[i]))) 55 | enc_dis_t.append(n_t.sample()) 56 | enc_dis_s.append(n_s.sample()) 57 | 58 | prior_t = torch.stack(prior_t) 59 | prior_s = torch.stack(prior_s) 60 | enc_dis_t = torch.stack(enc_dis_t) 61 | enc_dis_s = torch.stack(enc_dis_s) 62 | 63 | L_zt = self.kld(torch.log_softmax(prior_t, dim=1).to(self.device), torch.softmax(enc_dis_t, dim=1).to(self.device),) 64 | L_zs = self.kld(torch.log_softmax(prior_s, dim=1).to(self.device), torch.softmax(enc_dis_s, dim=1).to(self.device),) 65 | 66 | lambda_e = self.lambda_e * self.gamma_e ** (current_step/self.step_size) 67 | lambda_od = self.lambda_od * self.gamma_od ** (current_step/self.step_size) 68 | 69 | Loss = L_t + lambda_od * (L_zt + L_zs) + lambda_e * Loss_e 70 | return Loss -------------------------------------------------------------------------------- /models/SAM/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.modules.batchnorm import _BatchNorm 3 | 4 | 5 | def disable_running_stats(model): 6 | def _disable(module): 7 | if isinstance(module, _BatchNorm): 8 | module.backup_momentum = module.momentum 9 | module.momentum = 0 10 | 11 | model.apply(_disable) 12 | 13 | def enable_running_stats(model): 14 | def _enable(module): 15 | if isinstance(module, _BatchNorm) and hasattr(module, "backup_momentum"): 16 | module.momentum = module.backup_momentum 17 | 18 | model.apply(_enable) 19 | 20 | 
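# SAM performs a two-step update per batch: first_step perturbs the weights to
# w + e(w), where e(w) = rho * grad / ||grad|| approximately maximizes the loss
# within an L2 ball of radius rho; second_step restores the original weights w
# and lets the base optimizer apply the gradient computed at the perturbed
# point. The BatchNorm helpers above support this scheme: running statistics
# are frozen during the second forward pass so each batch updates them once.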
class SAM_optimizer(torch.optim.Optimizer):
21 |     def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
22 |         assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
23 | 
24 |         defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
25 |         super(SAM_optimizer, self).__init__(params, defaults)
26 | 
27 |         self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
28 |         self.param_groups = self.base_optimizer.param_groups
29 |         self.defaults.update(self.base_optimizer.defaults)
30 | 
31 |     @torch.no_grad()
32 |     def first_step(self, zero_grad=False):
33 |         grad_norm = self._grad_norm()
34 |         for group in self.param_groups:
35 |             scale = group["rho"] / (grad_norm + 1e-12)
36 | 
37 |             for p in group["params"]:
38 |                 if p.grad is None: continue
39 |                 self.state[p]["old_p"] = p.data.clone()
40 |                 e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
41 |                 p.add_(e_w)  # climb to the local maximum "w + e(w)"
42 | 
43 |         if zero_grad: self.zero_grad()
44 | 
45 |     @torch.no_grad()
46 |     def second_step(self, zero_grad=False):
47 |         for group in self.param_groups:
48 |             for p in group["params"]:
49 |                 if p.grad is None: continue
50 |                 p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"
51 | 
52 |         self.base_optimizer.step()  # do the actual "sharpness-aware" update
53 | 
54 |         if zero_grad: self.zero_grad()
55 | 
56 |     @torch.no_grad()
57 |     def step(self, closure=None):
58 |         assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
59 |         closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass
60 | 
61 |         self.first_step(zero_grad=True)
62 |         closure()
63 |         self.second_step()
64 | 
65 |     def _grad_norm(self):
66 |         shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
67 |         norm = torch.norm(
68 |                     torch.stack([
69 |                         ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
70 |                         for group in self.param_groups for p in group["params"]
71 |                         if p.grad is not None
72 |                     ]),
73 |                     p=2
74 |                )
75 |         return norm
76 | 
77 |     def load_state_dict(self, state_dict):
78 |         super().load_state_dict(state_dict)
79 |         self.base_optimizer.param_groups = self.param_groups
--------------------------------------------------------------------------------
/sweep/test/cross_domain/cross_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import argparse
4 | 
5 | 
6 | parser = argparse.ArgumentParser(description='Hyperparameter sweep.')
7 | 
8 | parser.add_argument('--source_dataset', type=str, default='CXP', help='source dataset name')
9 | parser.add_argument('--target_dataset', type=str, default='MIMIC_CXR', help='target dataset name')
10 | parser.add_argument('--sensitive_name', type=str, default='Age', help='sensitive attribute name')
11 | parser.add_argument('--sens_classes', type=int, default=2, help='number of sensitive classes')
12 | parser.add_argument('--num_classes', type=int, default=1, help='number of classes')
13 | parser.add_argument('--is_3d', type=lambda s: str(s).lower() == 'true', default=False, help='whether 3D dataset')  # parses 'True'/'False' strings
14 | parser.add_argument('--is_slurm', type=lambda s: str(s).lower() == 'true', default=True, help='whether using Slurm')  # parses 'True'/'False' strings
15 | 
16 | args = parser.parse_args()
17 | 
18 | 
19 | if args.is_3d:
20 |     backbone = 'cusResNet18_3d'
21 |     batch_size = 8
22 | else:
23 |     backbone = 'cusResNet18'
24 |     batch_size = 1024
25 | 
26 | sensitive_name = args.sensitive_name
27 | source_dataset = args.source_dataset 
28 | target_dataset = args.target_dataset
29 | output_dim = num_classes = args.num_classes
30 | sens_classes = args.sens_classes
31 | 
32 | methods = ['baseline', 'resampling', 'LAFTR', 'CFair', 'LNL', 'EnD', 'DomainInd', 'GroupDRO', 'ODR', 'SWAD', 'resamplingSWAD', 'SAM']
33 | 
34 | 
35 | model_path = 'your_path/fairness_data/model_records/{datas}/{attr}/{bkbone}/'.format(
36 |     datas = source_dataset, attr = sensitive_name, bkbone = backbone)
37 | 
38 | for method in methods:
39 | 
40 |     method_model_path = os.path.join(model_path, method)
41 |     if args.is_slurm:
42 |         MAIN_CMD = f"sbatch sweep/test/cross_domain/slurm_batch_submit.sh" \
43 |                    f" --experiment {method}"\
44 |                    f" --dataset_name {target_dataset}"\
45 |                    f" --sensitive_name {sensitive_name}"\
46 |                    f" --output_dim {output_dim}"\
47 |                    f" --num_classes {num_classes}"\
48 |                    f" --batch_size {batch_size}"\
49 |                    f" --cross_testing_model_path {method_model_path}"\
50 |                    f" --sens_classes {sens_classes}"\
51 |                    f" --backbone {backbone}"\
52 |                    f" --source_domain {source_dataset}"\
53 |                    f" --target_domain {target_dataset}"
54 |     else:
55 |         MAIN_CMD = f"bash sweep/test/cross_domain/batch_submit.sh" \
56 |                    f" --experiment {method}"\
57 |                    f" --dataset_name {target_dataset}"\
58 |                    f" --sensitive_name {sensitive_name}"\
59 |                    f" --output_dim {output_dim}"\
60 |                    f" --num_classes {num_classes}"\
61 |                    f" --batch_size {batch_size}"\
62 |                    f" --cross_testing_model_path {method_model_path}"\
63 |                    f" --sens_classes {sens_classes}"\
64 |                    f" --backbone {backbone}"\
65 |                    f" --source_domain {source_dataset}"\
66 |                    f" --target_domain {target_dataset}"
67 | 
68 |     print('command is ', MAIN_CMD)
69 |     CMD = MAIN_CMD.split(' ')
70 |     process = subprocess.Popen(CMD, stdout=subprocess.PIPE, universal_newlines=True)
71 |     out, err = process.communicate()
72 |     print(out)
73 | 
74 | 
--------------------------------------------------------------------------------
/docs/quickstart.md:
--------------------------------------------------------------------------------
1 | # Quick Start
2 | 
3 | ## Installation
4 | Python >= 3.8 and PyTorch >= 1.10 are required to run the code. Other necessary packages are listed in [`environment.yml`](../environment.yml).
5 | 
6 | ### Installation via conda:
7 | ```bash
8 | cd MEDFAIR/
9 | conda env create -n medfair_env -f environment.yml
10 | conda activate medfair_env
11 | ```
12 | 
13 | ## Dataset Download
14 | Due to the data use agreements, we cannot directly share the download link. 
Please follow the instructions and download the datasets via the links in the table below:
15 | 
16 | 
17 | | **Dataset** | **Access** |
18 | |--------------|-----------------------------------------------------------------------------------------------|
19 | | CheXpert | Original data: https://stanfordmlgroup.github.io/competitions/chexpert/ |
20 | | | Demographic data: https://stanfordaimi.azurewebsites.net/datasets/192ada7c-4d43-466e-b8bb-b81992bb80cf |
21 | | MIMIC-CXR | https://physionet.org/content/mimic-cxr-jpg/2.0.0/ |
22 | | PAPILA | https://www.nature.com/articles/s41597-022-01388-1#Sec6 |
23 | | HAM10000 | https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/DBW86T |
24 | | OCT | https://people.duke.edu/~sf59/RPEDC_Ophth_2013_dataset.htm |
25 | | Fitzpatrick17k | https://github.com/mattgroh/fitzpatrick17k |
26 | | COVID-CT-MD | https://doi.org/10.6084/m9.figshare.12991592 |
27 | | ADNI 1.5T/3T | https://ida.loni.usc.edu/login.jsp?project=ADNI |
28 | 
29 | 
30 | ## Usage
31 | 
32 | ### Dataset Preprocessing
33 | See `notebooks/HAM10000.ipynb` for a simple example of how to preprocess the data into the desired format. 
34 | Basically, it contains 3 steps: 
35 | 1. Preprocess the metadata.
36 | 2. Split into train/val/test sets.
37 | 3. Save the images into pickle files (optional -- we usually do this for 2D images rather than 3D images, as data IO is not the bottleneck for training 3D images).
38 | 
39 | After preprocessing, specify the paths of the metadata and pickle files in `configs/datasets.json`.
40 | 
41 | 
42 | ### Run a single experiment
43 | ```bash
44 | python main.py --experiment [experiment] --experiment_name [experiment_name] --dataset_name [dataset_name] \
45 |      --backbone [backbone] --total_epochs [total_epochs] --sensitive_name [sensitive_name] \
46 |      --batch_size [batch_size] --lr [lr] --sens_classes [sens_classes] --val_strategy [val_strategy] \
47 |      --output_dim [output_dim] --num_classes [num_classes]
48 | ```
49 | See `parse_args.py` for more options.
50 | 
51 | ### Run a grid search on a Slurm cluster/Regular Machine
52 | ```bash
53 | python sweep/train-sweep/sweep_batch.py --is_slurm True/False
54 | ```
55 | Set the other arguments as needed.
56 | 
57 | 
58 | ## Model selection and Results analysis
59 | See `notebooks/results_analysis.ipynb` for a step-by-step example.
60 | 
61 | ## Tabular data
62 | We also implement all the algorithms with a three-layer multi-layer perceptron (MLP) as the backbone for exploring tabular data (this is not covered in the paper). Enable the tabular mode by passing `--backbone cusMLP` together with `--is_tabular True`, as sketched below. 
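A minimal sketch of a tabular run (the dataset and the remaining argument values here are illustrative; see `parse_args.py` for the authoritative options):
```bash
python main.py --experiment baseline --dataset_name eICU \
     --backbone cusMLP --is_tabular True --total_epochs 20 \
     --sensitive_name Sex --batch_size 512 --sens_classes 2 \
     --output_dim 1 --num_classes 1
```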
-------------------------------------------------------------------------------- /models/GSAM/GSAM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils.evaluation import calculate_auc 5 | from models.basenet import BaseNet 6 | from importlib import import_module 7 | import torchvision 8 | 9 | from models.GSAM.utils import GSAM_optimizer, LinearScheduler 10 | 11 | 12 | class GSAM(BaseNet): 13 | def __init__(self, opt, wandb): 14 | super(GSAM, self).__init__(opt, wandb) 15 | self.set_network(opt) 16 | self.set_optimizer(opt) 17 | 18 | 19 | def set_network(self, opt): 20 | """Define the network""" 21 | 22 | if not self.is_3d: 23 | mod = import_module("models.basemodels") 24 | cusModel = getattr(mod, self.backbone) 25 | self.network = cusModel(n_classes=self.output_dim, pretrained=self.pretrained).to(self.device) 26 | 27 | else: 28 | mod = import_module("models.basemodels_3d") 29 | cusModel = getattr(mod, self.backbone) 30 | self.network = cusModel(n_classes=self.output_dim, pretrained = self.pretrained).to(self.device) 31 | 32 | #self.network = cusResNet18(n_classes=self.output_dim, pretrained=self.pretrained).to(self.device) 33 | 34 | 35 | def forward(self, x): 36 | out, feature = self.network(x) 37 | return out, feature 38 | 39 | def set_optimizer(self, opt): 40 | optimizer_setting = opt['optimizer_setting'] 41 | self.base_optimizer = torch.optim.Adam( 42 | params=self.network.parameters(), 43 | lr=optimizer_setting['lr'], 44 | weight_decay=optimizer_setting['weight_decay'] 45 | ) 46 | self.lr_scheduler = LinearScheduler(T_max=opt['T_max'], \ 47 | max_value=optimizer_setting['lr'], min_value=optimizer_setting['lr']*0.01, optimizer=self.base_optimizer) 48 | self.rho_scheduler = LinearScheduler(T_max=opt['T_max'], max_value=0.04, min_value=0.02) 49 | self.gsam_optimizer = GSAM_optimizer(params=self.network.parameters(), base_optimizer=self.base_optimizer, model=self.network,\ 50 | gsam_alpha=0.01, rho_scheduler=self.rho_scheduler, adaptive=False) 51 | 52 | def state_dict(self): 53 | state_dict = { 54 | 'model': self.network.state_dict(), 55 | 'optimizer': self.gsam_optimizer.state_dict(), 56 | 'epoch': self.epoch 57 | } 58 | return state_dict 59 | 60 | def _train(self, loader): 61 | """Train the model for one epoch""" 62 | 63 | self.network.train() 64 | 65 | train_loss = 0 66 | auc = 0. 
67 | no_iter = 0 68 | for i, (images, targets, sensitive_attr, index) in enumerate(loader): 69 | images, targets, sensitive_attr = images.to(self.device), targets.to(self.device), sensitive_attr.to(self.device) 70 | 71 | self.gsam_optimizer.set_closure(self._criterion, images, targets) 72 | outputs, loss = self.gsam_optimizer.step() 73 | self.lr_scheduler.step() 74 | self.gsam_optimizer.update_rho_t() 75 | 76 | 77 | auc += calculate_auc(F.sigmoid(outputs).cpu().data.numpy(), targets.cpu().data.numpy()) 78 | 79 | train_loss += loss.item() 80 | no_iter += 1 81 | 82 | if self.log_freq and (i % self.log_freq == 0): 83 | self.wandb.log({'Training loss': train_loss / (i+1), 'Training AUC': auc / (i+1)}) 84 | 85 | auc = 100 * auc / no_iter 86 | train_loss /= no_iter 87 | 88 | 89 | print('Training epoch {}: AUC:{}'.format(self.epoch, auc)) 90 | print('Training epoch {}: loss:{}'.format(self.epoch, train_loss)) 91 | 92 | self.epoch += 1 93 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from utils.evaluation import calculate_auc, calculate_metrics 6 | 7 | from importlib import import_module 8 | 9 | def standard_train(opt, network, optimizer, loader, _criterion, wandb): 10 | """Train the model for one epoch""" 11 | train_loss, auc, no_iter = 0., 0., 0 12 | for i, (images, targets, sensitive_attr, index) in enumerate(loader): 13 | images, targets, sensitive_attr = images.to(opt['device']), targets.to(opt['device']), sensitive_attr.to(opt['device']) 14 | optimizer.zero_grad() 15 | outputs, _ = network(images) 16 | 17 | loss = _criterion(outputs, targets) 18 | loss.backward() 19 | optimizer.step() 20 | 21 | auc += calculate_auc(F.sigmoid(outputs).cpu().data.numpy(), targets.cpu().data.numpy()) 22 | 23 | train_loss += loss.item() 24 | no_iter += 1 25 | 26 | if opt['log_freq'] and (i % opt['log_freq'] == 0): 27 | wandb.log({'Training loss': train_loss / no_iter, 'Training AUC': auc / no_iter}) 28 | 29 | auc = 100 * auc / no_iter 30 | train_loss /= no_iter 31 | return auc, train_loss 32 | 33 | 34 | def standard_val(opt, network, loader, _criterion, sens_classes, wandb): 35 | """Compute model output on validation set""" 36 | tol_output, tol_target, tol_sensitive, tol_index = [], [], [], [] 37 | 38 | val_loss, auc = 0., 0. 
39 |     no_iter = 0
40 |     with torch.no_grad():
41 |         for i, (images, targets, sensitive_attr, index) in enumerate(loader):
42 |             images, targets, sensitive_attr = images.to(opt['device']), targets.to(opt['device']), sensitive_attr.to(
43 |                 opt['device'])
44 |             outputs, features = network.forward(images)
45 |             loss = _criterion(outputs, targets)
46 |             try:
47 |                 val_loss += loss.item()
48 |             except (RuntimeError, ValueError):  # non-scalar loss (e.g. reduction='none'); average it first
49 |                 val_loss += loss.mean().item()
50 |             tol_output += F.sigmoid(outputs).flatten().cpu().data.numpy().tolist()
51 |             tol_target += targets.cpu().data.numpy().tolist()
52 |             tol_sensitive += sensitive_attr.cpu().data.numpy().tolist()
53 |             tol_index += index.numpy().tolist()
54 | 
55 |             auc += calculate_auc(F.sigmoid(outputs).cpu().data.numpy(),
56 |                                  targets.cpu().data.numpy())
57 | 
58 |             no_iter += 1
59 | 
60 |             if opt['log_freq'] and (i % opt['log_freq'] == 0):
61 |                 wandb.log({'Validation loss': val_loss / no_iter, 'Validation AUC': auc / no_iter})
62 | 
63 |     auc = 100 * auc / no_iter
64 |     val_loss /= no_iter
65 |     log_dict, t_predictions, pred_df = calculate_metrics(tol_output, tol_target, tol_sensitive, tol_index, sens_classes)
66 | 
67 |     return auc, val_loss, log_dict, pred_df
68 | 
69 | 
70 | def standard_test(opt, network, loader, _criterion, wandb):
71 |     """Compute model output on testing set"""
72 |     tol_output, tol_target, tol_sensitive, tol_index = [], [], [], []
73 | 
74 |     with torch.no_grad():
75 |         for i, (images, targets, sensitive_attr, index) in enumerate(loader):
76 |             images, targets, sensitive_attr = images.to(opt['device']), targets.to(opt['device']), sensitive_attr.to(
77 |                 opt['device'])
78 |             outputs, features = network.forward(images)
79 | 
80 |             tol_output += F.sigmoid(outputs).flatten().cpu().data.numpy().tolist()
81 |             tol_target += targets.cpu().data.numpy().tolist()
82 |             tol_sensitive += sensitive_attr.cpu().data.numpy().tolist()
83 |             tol_index += index.numpy().tolist()
84 | 
85 |     return tol_output, tol_target, tol_sensitive, tol_index
--------------------------------------------------------------------------------
/sweep/train-sweep/sweep_batch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import itertools
4 | import yaml
5 | import wandb
6 | import argparse
7 | 
8 | 
9 | parser = argparse.ArgumentParser(description='Hyperparameter sweep.')
10 | 
11 | parser.add_argument('--dataset_name', type=str, default='HAM10000', help='dataset name')
12 | parser.add_argument('--sensitive_name', type=str, default='Age', help='sensitive attribute name')
13 | parser.add_argument('--total_epochs', type=int, default=20, help='total epochs')
14 | parser.add_argument('--sens_classes', type=int, default=2, help='number of sensitive classes')
15 | parser.add_argument('--num_classes', type=int, default=1, help='number of classes')
16 | parser.add_argument('--val_strategy', type=str, default='worst_auc', help='validation strategy')
17 | parser.add_argument('--is_3d', type=lambda s: str(s).lower() == 'true', default=False, help='whether 3D dataset')  # parses 'True'/'False' strings
18 | parser.add_argument('--is_slurm', type=lambda s: str(s).lower() == 'true', default=True, help='whether using Slurm')  # parses 'True'/'False' strings
19 | parser.add_argument('--resample_which', type=str, default='class', help='what to resample')
20 | 
21 | args = parser.parse_args()
22 | 
23 | is_3d = args.is_3d
24 | if is_3d:
25 |     backbone = 'cusResNet18_3d'
26 |     batch_size = 8
27 | #elif is_tabular:
28 | #    backbone = 'cusMLP'
29 | #    batch_size = 512
30 | else:
31 |     backbone = 'cusResNet18'
32 |     batch_size = 1024
33 | 
34 | 
35 | sensitive_name = args.sensitive_name
36 | dataset_name = 
args.dataset_name 37 | total_epochs = args.total_epochs 38 | output_dim = num_classes = args.num_classes 39 | val_strategy = args.val_strategy 40 | sens_classes = args.sens_classes 41 | resample_which = args.resample_which 42 | 43 | methods = ['baseline', 'resampling', 'LAFTR', 'CFair', 'LNL', 'EnD', 'DomainInd', 'GroupDRO', 'ODR', 'SWAD', 'SAM'] 44 | 45 | 46 | for method in methods: 47 | print(method) 48 | project_name = '{dataset} {meth}'.format(dataset = dataset_name, meth = method) 49 | wandb.init(project=project_name) 50 | 51 | with open('sweep/train-sweep/sweep_{}.yaml'.format(method)) as file: 52 | config_dict = yaml.load(file, Loader=yaml.FullLoader) 53 | 54 | config_dict['name'] = '{dataset} {meth} {sens} multiAttr'.format(dataset = dataset_name, meth = method, sens=sensitive_name) 55 | 56 | command_list = config_dict['command'] 57 | command_list += ['--dataset_name', dataset_name] 58 | command_list += ['--experiment_name', '{meth}_{dataset}_{sens}'.format(meth = method, dataset = dataset_name, sens=sensitive_name)] 59 | command_list += ['--sensitive_name', sensitive_name] 60 | command_list += ['--total_epochs', total_epochs] 61 | command_list += ['--output_dim', num_classes] 62 | command_list += ['--num_classes', num_classes] 63 | command_list += ['--batch_size', batch_size] 64 | command_list += ['--val_strategy', val_strategy] 65 | command_list += ['--sens_classes', sens_classes] 66 | 67 | command_list += ['--resample_which', resample_which] 68 | if is_3d: 69 | command_list += ['--is_3d', is_3d] 70 | command_list += ['--backbone', backbone] 71 | #elif is_tabular: 72 | # command_list += ['--is_tabular', is_3d] 73 | # command_list += ['--backbone', backbone] 74 | 75 | config_dict['command'] = command_list 76 | #print(config_dict) 77 | 78 | sweep_id = wandb.sweep(config_dict, project=project_name) 79 | 80 | counts = 30 81 | 82 | for i in range(counts): 83 | if args.is_slurm: 84 | MAIN_CMD = f"sbatch sweep/train-sweep/slurm_sweep_count.sh" \ 85 | f" --sweep_id {sweep_id}" 86 | else: 87 | MAIN_CMD = f"bash sweep/train-sweep/sweep_count.sh" \ 88 | f" --sweep_id {sweep_id}" \ 89 | 90 | print('command is ', MAIN_CMD) 91 | CMD = MAIN_CMD.split(' ') 92 | process = subprocess.Popen(CMD, stdout=subprocess.PIPE, universal_newlines=True) 93 | out, err = process.communicate() 94 | print(out) 95 | -------------------------------------------------------------------------------- /models/SAM/SAM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils.evaluation import calculate_auc 5 | from models.basenet import BaseNet 6 | from importlib import import_module 7 | 8 | from torch.optim.lr_scheduler import CosineAnnealingLR 9 | from models.SAM.utils import SAM_optimizer, disable_running_stats, enable_running_stats 10 | 11 | 12 | class SAM(BaseNet): 13 | def __init__(self, opt, wandb): 14 | super(SAM, self).__init__(opt, wandb) 15 | self.set_network(opt) 16 | self.set_optimizer(opt) 17 | 18 | def set_network(self, opt): 19 | """Define the network""" 20 | 21 | if self.is_3d: 22 | mod = import_module("models.basemodels_3d") 23 | cusModel = getattr(mod, self.backbone) 24 | self.network = cusModel(n_classes=self.output_dim, pretrained = self.pretrained).to(self.device) 25 | elif self.is_tabular: 26 | mod = import_module("models.basemodels_mlp") 27 | cusModel = getattr(mod, self.backbone) 28 | self.network = cusModel(n_classes=self.output_dim, in_features= self.in_features, 
hidden_features = 1024).to(self.device) 29 | else: 30 | mod = import_module("models.basemodels") 31 | cusModel = getattr(mod, self.backbone) 32 | self.network = cusModel(n_classes=self.output_dim, pretrained=self.pretrained).to(self.device) 33 | 34 | def set_optimizer(self, opt): 35 | optimizer_setting = opt['optimizer_setting'] 36 | self.base_optimizer = torch.optim.Adam 37 | self.optimizer = SAM_optimizer(params = self.network.parameters(), base_optimizer = self.base_optimizer, rho=opt['rho'], adaptive=opt['adaptive'], lr=optimizer_setting['lr'], weight_decay=optimizer_setting['weight_decay']) 38 | 39 | self.scheduler = CosineAnnealingLR(self.optimizer.base_optimizer, T_max=opt['T_max']) 40 | 41 | def _criterion(self, output, target): 42 | self.criterion = nn.BCEWithLogitsLoss(reduction='none') 43 | return self.criterion(output, target) 44 | 45 | 46 | def state_dict(self): 47 | state_dict = { 48 | 'model': self.network.state_dict(), 49 | 'optimizer': self.optimizer.state_dict(), 50 | 'epoch': self.epoch 51 | } 52 | return state_dict 53 | 54 | def _train(self, loader): 55 | """Train the model for one epoch""" 56 | 57 | self.network.train() 58 | 59 | train_loss = 0 60 | auc = 0. 61 | no_iter = 0 62 | for i, (images, targets, sensitive_attr, index) in enumerate(loader): 63 | images, targets, sensitive_attr = images.to(self.device), targets.to(self.device), sensitive_attr.to(self.device) 64 | 65 | enable_running_stats(self.network) 66 | outputs, _ = self.network(images) 67 | 68 | loss = self._criterion(outputs, targets) 69 | loss.mean().backward() 70 | self.optimizer.first_step(zero_grad=True) 71 | self.scheduler.step() 72 | 73 | disable_running_stats(self.network) 74 | outputs, _ = self.network(images) 75 | self._criterion(outputs, targets).mean().backward() 76 | self.optimizer.second_step(zero_grad=True) 77 | self.scheduler.step() 78 | 79 | auc += calculate_auc(F.sigmoid(outputs).cpu().data.numpy(), targets.cpu().data.numpy()) 80 | 81 | train_loss += loss.mean().item() 82 | no_iter += 1 83 | 84 | if self.log_freq and (i % self.log_freq == 0): 85 | self.wandb.log({'Training loss': train_loss / (i+1), 'Training AUC': auc / (i+1)}) 86 | 87 | auc = 100 * auc / no_iter 88 | train_loss /= no_iter 89 | 90 | 91 | print('Training epoch {}: AUC:{}'.format(self.epoch, auc)) 92 | print('Training epoch {}: loss:{}'.format(self.epoch, train_loss)) 93 | 94 | self.epoch += 1 95 | -------------------------------------------------------------------------------- /notebooks/CovidCT.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "bbd4178f", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import h5py\n", 11 | "import pandas as pd\n", 12 | "import numpy as np\n", 13 | "import cv2\n", 14 | "import os\n", 15 | "import matplotlib.pyplot as plt\n", 16 | "from collections import Counter\n", 17 | "from sklearn.model_selection import train_test_split\n", 18 | "import pydicom as dicom" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "id": "e1173431", 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "# read metadata\n", 29 | "path = '/yourpath/data/COVID_CT_MD/'\n", 30 | "\n", 31 | "demo_data = pd.read_csv(path + 'Clinical-data.csv')\n", 32 | "demo_data" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "id": "88dd276b", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "Diagnosis_list 
= demo_data['Diagnosis'].values.tolist()\n", 43 | "Folder_list = demo_data['Folder'].values.tolist()\n", 44 | "\n", 45 | "Path_list = [x +'/'+y+'.npy' for x, y in zip(Diagnosis_list, Folder_list)]\n", 46 | "\n", 47 | "binary_label_list = [1 if x=='COVID-19' else 0 for x in Diagnosis_list]\n", 48 | "\n", 49 | "demo_data['Path'] = Path_list\n", 50 | "demo_data['binary_label'] = binary_label_list" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "id": "7d69f0c9", 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "\n", 61 | "demo_data['Age_multi'] = demo_data['Patient Age'].str[:-1].values.astype('int')\n", 62 | "demo_data['Age_multi'] = np.where(demo_data['Age_multi'].between(-1,20), 0, demo_data['Age_multi'])\n", 63 | "demo_data['Age_multi'] = np.where(demo_data['Age_multi'].between(20,39), 1, demo_data['Age_multi'])\n", 64 | "demo_data['Age_multi'] = np.where(demo_data['Age_multi'].between(40,59), 2, demo_data['Age_multi'])\n", 65 | "demo_data['Age_multi'] = np.where(demo_data['Age_multi'].between(60,79), 3, demo_data['Age_multi'])\n", 66 | "demo_data['Age_multi'] = np.where(demo_data['Age_multi']>=80, 4, demo_data['Age_multi'])\n", 67 | "\n", 68 | "demo_data['Age_binary'] = demo_data['Patient Age'].str[:-1].values.astype('int')\n", 69 | "demo_data['Age_binary'] = np.where(demo_data['Age_binary'].between(-1, 60), 0, demo_data['Age_binary'])\n", 70 | "demo_data['Age_binary'] = np.where(demo_data['Age_binary']>= 60, 1, demo_data['Age_binary'])\n", 71 | "\n", 72 | "demo_data = demo_data.rename(columns={'Patient Gender': 'Sex'})" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "id": "e4f757b9", 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "def split_712(all_meta, patient_ids):\n", 83 | " sub_train, sub_val_test = train_test_split(patient_ids, test_size=0.3, random_state=10)\n", 84 | " sub_val, sub_test = train_test_split(sub_val_test, test_size=0.66, random_state=0)\n", 85 | " train_meta = all_meta[all_meta.Folder.isin(sub_train.astype('str'))]\n", 86 | " val_meta = all_meta[all_meta.Folder.isin(sub_val.astype('str'))]\n", 87 | " test_meta = all_meta[all_meta.Folder.isin(sub_test.astype('str'))]\n", 88 | " return train_meta, val_meta, test_meta\n", 89 | "\n", 90 | "sub_train, sub_val, sub_test = split_712(demo_data, np.unique(demo_data['Folder']))" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "id": "b91657ee", 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "sub_train.to_csv('/yourpath/data/COVID_CT_MD/split/new_train.csv')\n", 101 | "sub_val.to_csv('/yourpath/data/COVID_CT_MD/split/new_val.csv')\n", 102 | "sub_test.to_csv('/yourpath/data/COVID_CT_MD/split/new_test.csv')" 103 | ] 104 | } 105 | ], 106 | "metadata": { 107 | "kernelspec": { 108 | "display_name": "torch11", 109 | "language": "python", 110 | "name": "torch11" 111 | }, 112 | "language_info": { 113 | "codemirror_mode": { 114 | "name": "ipython", 115 | "version": 3 116 | }, 117 | "file_extension": ".py", 118 | "mimetype": "text/x-python", 119 | "name": "python", 120 | "nbconvert_exporter": "python", 121 | "pygments_lexer": "ipython3", 122 | "version": "3.8.12" 123 | } 124 | }, 125 | "nbformat": 4, 126 | "nbformat_minor": 5 127 | } 128 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MEDFAIR: Benchmarking Fairness for Medical Imaging 2 | 3 | MEDFAIR is a 
fairness benchmarking suite for medical imaging ([paper](https://arxiv.org/abs/2210.01725)). We are actively updating this repo and will incorporate more datasets and algorithms in the future. Contributions are warmly welcomed!
4 | Check our [website](https://ys-zong.github.io/MEDFAIR/) for a brief summary of the paper.
5 | 
6 | :grinning: MEDFAIR is accepted at ICLR'23 as a *Spotlight*!
7 | 
8 | ## Documentation
9 | Detailed documentation can be found [here](https://github.com/ys-zong/MEDFAIR/blob/main/docs/index.md).
10 | 
11 | ## Quick Start
12 | 
13 | ### Installation
14 | Python >= 3.8 and PyTorch >= 1.10 are required to run the code. Other necessary packages are listed in [`environment.yml`](environment.yml).
15 | 
16 | ### Installation via conda:
17 | ```bash
18 | cd MEDFAIR/
19 | conda env create -n fair_benchmark -f environment.yml
20 | conda activate fair_benchmark
21 | ```
22 | 
23 | ### Dataset
24 | Due to the data use agreements, we cannot directly share the download link. Please register and download the datasets using the links in the table below:
25 | 
26 | | **Dataset** | **Access** |
27 | |--------------|-----------------------------------------------------------------------------------------------|
28 | | CheXpert | Original data: https://stanfordmlgroup.github.io/competitions/chexpert/ |
29 | | | Demographic data: https://stanfordaimi.azurewebsites.net/datasets/192ada7c-4d43-466e-b8bb-b81992bb80cf |
30 | | MIMIC-CXR | https://physionet.org/content/mimic-cxr-jpg/2.0.0/ |
31 | | PAPILA | https://www.nature.com/articles/s41597-022-01388-1#Sec6 |
32 | | HAM10000 | https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/DBW86T |
33 | | OCT | https://people.duke.edu/~sf59/RPEDC_Ophth_2013_dataset.htm |
34 | | Fitzpatrick17k | https://github.com/mattgroh/fitzpatrick17k |
35 | | COVID-CT-MD | https://doi.org/10.6084/m9.figshare.12991592 |
36 | | ADNI 1.5T/3T | https://ida.loni.usc.edu/login.jsp?project=ADNI |
37 | 
38 | 
39 | ### Data Preprocessing
40 | See `notebooks/HAM10000.ipynb` for a simple example of how to preprocess the data into the desired format. You can also find the preprocessing scripts for the other datasets there. 
41 | Basically, it contains 3 steps:
42 | 1. Preprocess the metadata.
43 | 2. Split into train/val/test sets.
44 | 3. Save the images into pickle files (optional -- we usually do this for 2D images rather than 3D images, as data IO is not the bottleneck for training 3D images).
45 | 
46 | After preprocessing, specify the paths of the metadata and pickle files in `configs/datasets.json`.
47 | 
48 | 
49 | ### Run a single experiment
50 | ```bash
51 | python main.py --experiment [experiment] --experiment_name [experiment_name] --dataset_name [dataset_name] \
52 |      --backbone [backbone] --total_epochs [total_epochs] --sensitive_name [sensitive_name] \
53 |      --batch_size [batch_size] --lr [lr] --sens_classes [sens_classes] --val_strategy [val_strategy] \
54 |      --output_dim [output_dim] --num_classes [num_classes]
55 | ```
56 | 
57 | For example, to run `ERM` on the `HAM10000` dataset with `Sex` as the sensitive attribute:
58 | ```bash
59 | python main.py --experiment baseline --dataset_name HAM10000 \
60 |      --total_epochs 20 --sensitive_name Sex --batch_size 1024 \
61 |      --sens_classes 2 --output_dim 1 --num_classes 1
62 | ```
63 | 
64 | See `parse_args.py` for more options. 
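For cross-domain testing (training on a source dataset and evaluating on a target dataset), `main.py` also accepts the flags used by the scripts in `sweep/test/cross_domain/`. A minimal sketch -- the model path and argument values below are illustrative:
```bash
python main.py --experiment baseline --dataset_name MIMIC_CXR \
     --sensitive_name Age --sens_classes 2 --output_dim 1 --num_classes 1 \
     --batch_size 1024 --backbone cusResNet18 --test_mode True \
     --cross_testing --cross_testing_model_path /path/to/model_records/CXP/Age/cusResNet18/baseline \
     --source_domain CXP --target_domain MIMIC_CXR
```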
65 | 
66 | ### Run a grid search on a Slurm cluster/Regular Machine
67 | ```bash
68 | python sweep/train-sweep/sweep_batch.py --is_slurm True/False
69 | ```
70 | Set the other arguments as needed.
71 | 
72 | ### Model selection and Results analysis
73 | See `notebooks/results_analysis.ipynb` for a step-by-step example.
74 | 
75 | ## Citation
76 | Please consider citing our paper if you find this repo useful.
77 | ```
78 | @inproceedings{zong2023medfair,
79 |   title={MEDFAIR: Benchmarking Fairness for Medical Imaging},
80 |   author={Yongshuo Zong and Yongxin Yang and Timothy Hospedales},
81 |   booktitle={International Conference on Learning Representations (ICLR)},
82 |   year={2023},
83 | }
84 | ```
85 | 
86 | ## Acknowledgement
87 | MEDFAIR adapts implementations from many repos (check [here](docs/reference.md#debiasing-methods) for the original implementations of the algorithms), as well as many other code bases. Many thanks!
88 | 
--------------------------------------------------------------------------------
/notebooks/fit17k.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": null,
6 |    "id": "bbd4178f",
7 |    "metadata": {},
8 |    "outputs": [],
9 |    "source": [
10 |     "import h5py\n",
11 |     "import pandas as pd\n",
12 |     "import numpy as np\n",
13 |     "import cv2\n",
14 |     "import os\n",
15 |     "import matplotlib.pyplot as plt\n",
16 |     "from collections import Counter\n",
17 |     "from sklearn.model_selection import train_test_split"
18 |    ]
19 |   },
20 |   {
21 |    "cell_type": "code",
22 |    "execution_count": null,
23 |    "id": "e1173431",
24 |    "metadata": {},
25 |    "outputs": [],
26 |    "source": [
27 |     "# read metadata\n",
28 |     "path = 'yourpath/data/finalfitz17k/'\n",
29 |     "\n",
30 |     "annot_data = pd.read_csv(path + 'fitzpatrick17k.csv')\n",
31 |     "annot_data"
32 |    ]
33 |   },
34 |   {
35 |    "cell_type": "code",
36 |    "execution_count": null,
37 |    "id": "95739fb4",
38 |    "metadata": {},
39 |    "outputs": [],
40 |    "source": [
41 |     "# 'fitzpatrick_scale' is the skin type\n",
42 |     "Counter(annot_data['fitzpatrick_scale'])"
43 |    ]
44 |   },
45 |   {
46 |    "cell_type": "code",
47 |    "execution_count": null,
48 |    "id": "0b1098d0",
49 |    "metadata": {},
50 |    "outputs": [],
51 |    "source": [
52 |     "pathlist = annot_data['md5hash'].values.tolist()\n",
53 |     "paths = ['images/' + i + '.jpg' for i in pathlist]"
54 |    ]
55 |   },
56 |   {
57 |    "cell_type": "code",
58 |    "execution_count": null,
59 |    "id": "a7847f09",
60 |    "metadata": {},
61 |    "outputs": [],
62 |    "source": [
63 |     "annot_data['Path'] = paths"
64 |    ]
65 |   },
66 |   {
67 |    "cell_type": "code",
68 |    "execution_count": null,
69 |    "id": "c80a746b",
70 |    "metadata": {},
71 |    "outputs": [],
72 |    "source": [
73 |     "# remove skin type == null \n",
74 |     "annot_data = annot_data[annot_data['fitzpatrick_scale'] != -1]\n",
75 |     "annot_data"
76 |    ]
77 |   },
78 |   {
79 |    "cell_type": "code",
80 |    "execution_count": null,
81 |    "id": "7f8b82c1",
82 |    "metadata": {},
83 |    "outputs": [],
84 |    "source": [
85 |     "Counter(annot_data['three_partition_label'])"
86 |    ]
87 |   },
88 |   {
89 |    "cell_type": "code",
90 |    "execution_count": null,
91 |    "id": "221a1aa0",
92 |    "metadata": {},
93 |    "outputs": [],
94 |    "source": [
95 |     "# binarize the label\n",
96 |     "labellist = annot_data['three_partition_label'].values.tolist()\n",
97 |     "labels = [1 if x == 'malignant' else 0 for x in labellist]\n",
98 |     "print(Counter(labels))\n",
99 |     "annot_data['binary_label'] = labels"
100 |    ]
101 |   },
102 |   {
103 |    "cell_type": "code", 
104 | "execution_count": null, 105 | "id": "39807da2", 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "annot_data['skin_type'] = annot_data['fitzpatrick_scale'] - 1\n", 110 | "skin_lists = annot_data['skin_type'].values.tolist()\n", 111 | "annot_data['skin_binary'] = [0 if x <=2 else 1 for x in skin_lists] " 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "id": "e4f757b9", 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "def split_811(all_meta, patient_ids):\n", 122 | " sub_train, sub_val_test = train_test_split(patient_ids, test_size=0.2, random_state=5)\n", 123 | " sub_val, sub_test = train_test_split(sub_val_test, test_size=0.5, random_state=6)\n", 124 | " train_meta = all_meta[all_meta.md5hash.isin(sub_train)]\n", 125 | " val_meta = all_meta[all_meta.md5hash.isin(sub_val)]\n", 126 | " test_meta = all_meta[all_meta.md5hash.isin(sub_test)]\n", 127 | " return train_meta, val_meta, test_meta\n", 128 | "\n", 129 | "sub_train, sub_val, sub_test = split_811(annot_data, np.unique(annot_data['md5hash']))" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "id": "b91657ee", 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "sub_train.to_csv('/yourpath/data/finalfitz17k/split/new_train.csv')\n", 140 | "sub_val.to_csv('/yourpath/data/finalfitz17k/split/new_val.csv')\n", 141 | "sub_test.to_csv('/yourpath/data/finalfitz17k/split/new_test.csv')" 142 | ] 143 | } 144 | ], 145 | "metadata": { 146 | "kernelspec": { 147 | "display_name": "torch11", 148 | "language": "python", 149 | "name": "torch11" 150 | }, 151 | "language_info": { 152 | "codemirror_mode": { 153 | "name": "ipython", 154 | "version": 3 155 | }, 156 | "file_extension": ".py", 157 | "mimetype": "text/x-python", 158 | "name": "python", 159 | "nbconvert_exporter": "python", 160 | "pygments_lexer": "ipython3", 161 | "version": "3.8.12" 162 | } 163 | }, 164 | "nbformat": 4, 165 | "nbformat_minor": 5 166 | } 167 | -------------------------------------------------------------------------------- /notebooks/OCT.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "bbd4178f", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import h5py\n", 11 | "import pandas as pd\n", 12 | "import numpy as np\n", 13 | "import cv2\n", 14 | "import os\n", 15 | "import matplotlib.pyplot as plt\n", 16 | "from collections import Counter\n", 17 | "from sklearn.model_selection import train_test_split\n", 18 | "import scipy.io" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "id": "e1173431", 25 | "metadata": { 26 | "scrolled": true 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "# read metadata\n", 31 | "path = '/yourpath/data/OCT/'\n", 32 | "\n", 33 | "clses = ['AMD', 'Control']\n", 34 | "ages = []\n", 35 | "paths = []\n", 36 | "labels = []\n", 37 | "for cls in clses:\n", 38 | " files = os.listdir(path + cls)\n", 39 | " for file in files:\n", 40 | " if '.mat' in file:\n", 41 | " mat = scipy.io.loadmat(os.path.join(path, cls, file))\n", 42 | " img = mat['images']\n", 43 | " age = mat['Age'].item()\n", 44 | " ages.append(age)\n", 45 | " paths.append(os.path.join(path, cls, file))\n", 46 | " label = 0 if cls == 'AMD' else 1\n", 47 | " labels.append(label)\n", 48 | " \n", 49 | " # write to npy\n", 50 | " filename = file.split('.')[0] + '.npy'\n", 51 | " 
np.save(os.path.join(path, 'images', filename), img)\n", 52 | " \n" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "id": "75d82bf3", 59 | "metadata": { 60 | "scrolled": true 61 | }, 62 | "outputs": [], 63 | "source": [ 64 | "to_create = list(zip(paths, ages, labels))\n", 65 | "meta_all = pd.DataFrame(to_create, columns=['Path', 'Age', 'label'])" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "id": "d12badfe", 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "paths = meta_all['Path'].values.tolist()\n", 76 | "paths = [x.split('/')[-1].replace('mat', 'npy') for x in paths]\n", 77 | "meta_all['Path'] = paths" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "id": "c2445a0d", 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "meta_all.to_csv('/yourpath/data/OCT/split/meta_all.csv', index = False)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "id": "1908675b", 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "meta_all['Age_multi'] = meta_all['Age'].values.astype('int')\n", 98 | "meta_all['Age_multi'] = np.where(meta_all['Age_multi'].between(0,19), 0, meta_all['Age_multi'])\n", 99 | "meta_all['Age_multi'] = np.where(meta_all['Age_multi'].between(20,39), 1, meta_all['Age_multi'])\n", 100 | "meta_all['Age_multi'] = np.where(meta_all['Age_multi'].between(40,59), 2, meta_all['Age_multi'])\n", 101 | "meta_all['Age_multi'] = np.where(meta_all['Age_multi'].between(60,79), 3, meta_all['Age_multi'])\n", 102 | "meta_all['Age_multi'] = np.where(meta_all['Age_multi']>=80, 4, meta_all['Age_multi'])\n", 103 | "\n", 104 | "meta_all['Age_binary'] = meta_all['Age'].values.astype('int')\n", 105 | "meta_all['Age_binary'] = np.where(meta_all['Age_binary'].between(0, 70), 0, meta_all['Age_binary'])\n", 106 | "meta_all['Age_binary'] = np.where(meta_all['Age_binary']>= 70, 1, meta_all['Age_binary'])\n", 107 | "meta_all" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "id": "e4f757b9", 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "def split_712(all_meta, patient_ids):\n", 118 | " sub_train, sub_val_test = train_test_split(patient_ids, test_size=0.3, random_state=5)\n", 119 | " sub_val, sub_test = train_test_split(sub_val_test, test_size=0.6, random_state=6)\n", 120 | " train_meta = all_meta[all_meta.Path.isin(sub_train)]\n", 121 | " val_meta = all_meta[all_meta.Path.isin(sub_val)]\n", 122 | " test_meta = all_meta[all_meta.Path.isin(sub_test)]\n", 123 | " return train_meta, val_meta, test_meta\n", 124 | "\n", 125 | "sub_train, sub_val, sub_test = split_712(meta_all, np.unique(meta_all['Path']))" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "id": "b91657ee", 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "sub_train.to_csv('/yourpath/data/OCT/split/new_train.csv')\n", 136 | "sub_val.to_csv('/yourpath/data/OCT/split/new_val.csv')\n", 137 | "sub_test.to_csv('/yourpath/data/OCT/split/new_test.csv')" 138 | ] 139 | } 140 | ], 141 | "metadata": { 142 | "kernelspec": { 143 | "display_name": "torch11", 144 | "language": "python", 145 | "name": "torch11" 146 | }, 147 | "language_info": { 148 | "codemirror_mode": { 149 | "name": "ipython", 150 | "version": 3 151 | }, 152 | "file_extension": ".py", 153 | "mimetype": "text/x-python", 154 | "name": "python", 155 | "nbconvert_exporter": "python", 156 | "pygments_lexer": 
"ipython3", 157 | "version": "3.8.12" 158 | } 159 | }, 160 | "nbformat": 4, 161 | "nbformat_minor": 5 162 | } 163 | -------------------------------------------------------------------------------- /models/LNL/LNL.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from models.LNL.model import LNLNet, LNLNet3D, LNLNet_MLP, LNLPredictor_MLP, LNLPredictor, LNLPredictor3D, grad_reverseLNL 6 | from utils.evaluation import calculate_auc 7 | from models.basenet import BaseNet 8 | 9 | 10 | class LNL(BaseNet): 11 | def __init__(self, opt, wandb): 12 | super(LNL, self).__init__(opt, wandb) 13 | self.set_network(opt) 14 | self.set_optimizer(opt) 15 | self._lambda = opt['_lambda'] 16 | 17 | self.pred_loss = nn.CrossEntropyLoss() 18 | 19 | def set_network(self, opt): 20 | """Define the network""" 21 | 22 | if self.is_3d: 23 | self.network = LNLNet3D(backbone = self.backbone, num_classes=self.num_classes, pretrained=self.pretrained).to(self.device) 24 | #pred_ch = self.network.body.layer2[-1].conv1[0].in_channels 25 | pred_ch = pred_ch = self.network.pred_ch 26 | self.pred_net = LNLPredictor3D(input_ch=pred_ch, num_classes=self.sens_classes).to(self.device) 27 | elif self.is_tabular: 28 | self.network = LNLNet_MLP(backbone = self.backbone, num_classes=self.num_classes, in_features=self.in_features, hidden_features=1024).to(self.device) 29 | pred_ch = self.network.pred_ch 30 | self.pred_net = LNLPredictor_MLP(input_ch=pred_ch, num_classes=self.sens_classes).to(self.device) 31 | 32 | else: 33 | self.network = LNLNet(backbone = self.backbone, num_classes=self.num_classes, pretrained=self.pretrained).to(self.device) 34 | #pred_ch = self.network.body.layer2[-1].conv1.in_channels 35 | pred_ch = self.network.pred_ch 36 | self.pred_net = LNLPredictor(input_ch=pred_ch, num_classes=self.sens_classes).to(self.device) 37 | 38 | 39 | #print(self.network) 40 | #print(self.pred_net) 41 | 42 | def forward(self, x): 43 | pred_label, feat_label = self.network(x) 44 | return pred_label, feat_label 45 | 46 | def set_optimizer(self, opt): 47 | optimizer_setting = opt['optimizer_setting'] 48 | self.optimizer = optimizer_setting['optimizer']( 49 | params=filter(lambda p: p.requires_grad, self.network.parameters()), 50 | lr=optimizer_setting['lr'], 51 | weight_decay=optimizer_setting['weight_decay'] 52 | ) 53 | self.optimizer_pred = optimizer_setting['optimizer']( 54 | params=filter(lambda p: p.requires_grad, self.network.parameters()), 55 | lr=optimizer_setting['lr'], 56 | weight_decay=optimizer_setting['weight_decay']) 57 | 58 | lr_lambda = lambda step: opt['lr_decay_rate'] ** (step // opt['lr_decay_period']) 59 | self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_lambda, last_epoch=-1) 60 | self.scheduler_pred = optim.lr_scheduler.LambdaLR(self.optimizer_pred, lr_lambda=lr_lambda, last_epoch=-1) 61 | 62 | def _train(self, loader): 63 | """Train the model for one epoch""" 64 | self.network.train() 65 | 66 | running_loss = 0. 67 | running_adv_loss = 0. 68 | running_MI = 0. 69 | auc = 0. 
71 |         no_iter = 0
72 |         for i, (images, targets, sensitive_attr, index) in enumerate(loader):
73 |             images, targets, sensitive_attr = images.to(self.device), targets.to(self.device), sensitive_attr.to(self.device)
74 | 
75 |             # phase 1: update the task classifier, regularized by the entropy of
76 |             # the sensitive-attribute predictions (pushes them towards uniform)
77 |             self.optimizer.zero_grad()
78 |             self.optimizer_pred.zero_grad()
79 | 
80 |             pred_label, feat_label = self.forward(images)
81 |             pseudo_pred, _ = self.pred_net(feat_label)
82 | 
83 |             loss_pred_cls = self._criterion(pred_label, targets)
84 |             pseudo_pred = F.sigmoid(pseudo_pred)
85 |             loss_pseudo_pred = torch.mean(torch.sum(pseudo_pred * torch.log(pseudo_pred), 1))
86 | 
87 |             loss = loss_pred_cls + loss_pseudo_pred * self._lambda
88 |             loss.backward()
89 | 
90 |             self.optimizer.step()
91 |             self.optimizer_pred.step()
92 | 
93 |             # phase 2: adversarial update through the gradient-reversal layer so
94 |             # the features become less predictive of the sensitive attribute
95 |             self.optimizer.zero_grad()
96 |             self.optimizer_pred.zero_grad()
97 | 
98 |             pred_label, feat_label = self.forward(images)
99 |             feat_sens = grad_reverseLNL(feat_label)
100 |             _, pred_ = self.pred_net(feat_sens)
101 |             loss_pred_sensi = self.pred_loss(pred_, sensitive_attr)
102 |             loss_pred_sensi.backward()
103 | 
104 |             self.optimizer.step()
105 |             self.optimizer_pred.step()
106 | 
107 |             running_loss += loss_pred_cls.item()
108 |             running_adv_loss += loss_pseudo_pred.item()
109 |             running_MI += loss_pred_sensi.item()
110 | 
111 |             auc += calculate_auc(F.sigmoid(pred_label).cpu().data.numpy(), targets.cpu().data.numpy())
112 |             no_iter += 1
113 | 
114 |             if self.log_freq and (i % self.log_freq == 0):
115 |                 self.wandb.log({'Training loss': running_loss / (i+1), 'Training AUC': auc / (i+1)})
116 | 
117 |         running_loss /= no_iter
118 |         running_adv_loss /= no_iter
119 |         running_MI /= no_iter
120 |         auc = auc / no_iter
121 |         print('Training epoch {}: AUC:{}'.format(self.epoch, auc))
122 |         print('Training epoch {}: cls loss:{}, adv loss:{}, MI:{}'.format(self.epoch, running_loss, running_adv_loss, running_MI))
123 |         #self.log_result('Train epoch', {'cls loss': running_loss, 'adv loss': running_adv_loss, 'AUC': auc}, self.epoch)
124 |         self.epoch += 1
--------------------------------------------------------------------------------
/models/CFair/CFair.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from models.CFair.model import CFairNet, CFairNet3D, CFairNet_MLP
7 | from utils import basics
8 | from utils.evaluation import calculate_auc, calculate_metrics, calculate_FPR_FNR
9 | from models.basenet import BaseNet
10 | from models.utils import standard_val, standard_test
11 | 
12 | 
13 | 
14 | class CFair(BaseNet):
15 |     def __init__(self, opt, wandb):
16 |         super(CFair, self).__init__(opt, wandb)
17 | 
18 |         self.test_classes = opt['sens_classes']
19 |         self.sens_classes = 2
20 | 
21 |         self.set_network(opt)
22 |         self.set_optimizer(opt)
23 | 
24 |         self.mu = opt['mu']  # coefficient for adversarial loss
25 | 
26 |     def set_network(self, opt):
27 |         """Define the network"""
28 |         if self.is_3d:
29 |             self.network = CFairNet3D(backbone = self.backbone, num_classes=self.num_classes, adversary_size = 128, pretrained = self.pretrained).to(self.device)
30 |         elif self.is_tabular:
31 |             self.network = CFairNet_MLP(backbone = self.backbone, num_classes=self.num_classes, adversary_size=128, device=self.device, in_features=self.in_features, hidden_features=1024).to(self.device)
32 |         else:
33 |             self.network = CFairNet(backbone = self.backbone, num_classes=self.num_classes, adversary_size = 128, pretrained = self.pretrained).to(self.device) 
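    # Reweighting follows the CFair formulation: the "cfair" variant weights the
    # target loss by the inverse label base rates (optimizing balanced error),
    # while "cfair-eo" keeps uniform target weights (targeting equalized odds).
    # The per-group tensors computed below reweight the adversary's
    # label-conditional losses by the corresponding inverse base rates.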
34 | 35 | def get_reweight_tensor(self, model_name): 36 | train_target_attrs = self.train_data.A 37 | train_target_labels = self.train_data.Y 38 | train_y_1 = np.mean(train_target_labels) 39 | 40 | if model_name == "cfair": 41 | reweight_target_tensor = torch.FloatTensor([1.0 / (1.0 - train_y_1), 1.0 / train_y_1]).to(self.device) 42 | elif model_name == "cfair-eo": 43 | reweight_target_tensor = torch.FloatTensor([1.0, 1.0]).to(self.device) 44 | 45 | train_idx = train_target_attrs == 0 46 | train_base_0, train_base_1 = np.mean(train_target_labels[train_idx]), np.mean(train_target_labels[~train_idx]) 47 | reweight_attr_0_tensor = torch.FloatTensor([1.0 / (1.0 - train_base_0), 1.0 / train_base_0]).to(self.device) 48 | reweight_attr_1_tensor = torch.FloatTensor([1.0 / (1.0 - train_base_1), 1.0 / train_base_1]).to(self.device) 49 | reweight_attr_tensors = [reweight_attr_0_tensor, reweight_attr_1_tensor] 50 | return reweight_target_tensor, reweight_attr_tensors 51 | 52 | def _train(self, loader): 53 | """Train the model for one epoch""" 54 | reweight_target_tensor, reweight_attr_tensors = self.get_reweight_tensor(model_name='cfair') 55 | self._criterion = nn.BCEWithLogitsLoss(pos_weight=reweight_target_tensor) 56 | self.network.train() 57 | 58 | running_loss = 0. 59 | running_adv_loss = 0. 60 | auc = 0. 61 | no_iter = 0 62 | for i, (images, targets, sensitive_attr, index) in enumerate(loader): 63 | images, targets, sensitive_attr = images.to(self.device), targets.to(self.device), sensitive_attr.to(self.device) 64 | self.optimizer.zero_grad() 65 | ypreds, apreds = self.network.forward(images, targets) 66 | 67 | loss = self._criterion(ypreds, targets) 68 | 69 | adv_loss = torch.mean(torch.stack([F.nll_loss(apreds[j], sensitive_attr[targets[:, 0] == j], weight= reweight_attr_tensors[j]) for j in range(self.sens_classes)])) 70 | running_loss += loss.item() 71 | running_adv_loss += adv_loss.item() 72 | 73 | loss += self.mu * adv_loss 74 | 75 | loss.backward() 76 | self.optimizer.step() 77 | 78 | auc += calculate_auc(F.sigmoid(ypreds).cpu().data.numpy(), targets.cpu().data.numpy()) 79 | no_iter += 1 80 | 81 | if self.log_freq and (i % self.log_freq == 0): 82 | self.wandb.log({'Training loss': running_loss / (i+1), 'Training AUC': auc / (i+1)}) 83 | 84 | running_loss /= no_iter 85 | running_adv_loss /= no_iter 86 | auc = auc / no_iter 87 | print('Training epoch {}: AUC:{}'.format(self.epoch, auc)) 88 | print('Training epoch {}: cls loss:{}, adv loss:{}'.format(self.epoch, running_loss, running_adv_loss)) 89 | self.epoch += 1 90 | 91 | def _val(self, loader): 92 | """Compute model output on validation set""" 93 | 94 | self.network.eval() 95 | auc, val_loss, log_dict, pred_df = standard_val(self.opt, self.network, loader, self._criterion, self.test_classes, self.wandb) 96 | 97 | print('Validation epoch {}: validation loss:{}, AUC:{}'.format( 98 | self.epoch, val_loss, auc)) 99 | return val_loss, auc, log_dict, pred_df 100 | 101 | def _test(self, loader): 102 | """Compute model output on testing set""" 103 | 104 | self.network.eval() 105 | tol_output, tol_target, tol_sensitive, tol_index = standard_test(self.opt, self.network, loader, self._criterion, self.wandb) 106 | 107 | log_dict, t_predictions, pred_df = calculate_metrics(tol_output, tol_target, tol_sensitive, tol_index, self.test_classes) 108 | overall_FPR, overall_FNR, FPRs, FNRs = calculate_FPR_FNR(pred_df, self.test_meta, self.opt) 109 | log_dict['Overall FPR'] = overall_FPR 110 | log_dict['Overall FNR'] = overall_FNR 111 | 112 | for i, FPR in 
enumerate(FPRs): 113 | log_dict['FPR-group_' + str(i)] = FPR 114 | for i, FNR in enumerate(FNRs): 115 | log_dict['FNR-group_' + str(i)] = FNR 116 | 117 | log_dict = basics.add_dict_prefix(log_dict, 'Test ') 118 | self.opt['sens_classes'] = 2 119 | return log_dict -------------------------------------------------------------------------------- /models/EnD/EnD.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from models.EnD.model import EnDNet, EnDNet3D,EnDNetMLP 6 | from utils.evaluation import calculate_auc 7 | from models.basenet import BaseNet 8 | 9 | 10 | class EnD(BaseNet): 11 | def __init__(self, opt, wandb): 12 | super(EnD, self).__init__(opt, wandb) 13 | 14 | self.set_network(opt) 15 | self.set_optimizer(opt) 16 | 17 | self.alpha = opt['alpha'] 18 | self.beta = opt['beta'] 19 | 20 | def set_network(self, opt): 21 | """Define the network""" 22 | if self.is_3d: 23 | self.network = EnDNet3D(backbone = self.backbone, n_classes = self.num_classes, pretrained = self.pretrained).to(self.device) 24 | elif self.is_tabular: 25 | self.network = EnDNetMLP(backbone = self.backbone, n_classes = self.num_classes, in_features=self.in_features, hidden_features=1024).to(self.device) 26 | else: 27 | self.network = EnDNet(backbone = self.backbone, n_classes = self.num_classes, pretrained = self.pretrained).to(self.device) 28 | 29 | def _train(self, loader): 30 | """Train the model for one epoch""" 31 | 32 | self.network.train() 33 | 34 | running_loss = 0. 35 | running_adv_loss = 0. 36 | auc = 0. 37 | no_iter = 0 38 | for i, (images, targets, sensitive_attr, index) in enumerate(loader): 39 | images, targets, sensitive_attr = images.to(self.device), targets.to(self.device), sensitive_attr.to( 40 | self.device) 41 | self.optimizer.zero_grad() 42 | outputs, features = self.network.forward(images) 43 | 44 | bce_loss = self._criterion(outputs, targets) 45 | abs_loss = self.abs_regu(features, targets, sensitive_attr, self.alpha, self.beta) 46 | loss = bce_loss + abs_loss 47 | loss.backward() 48 | self.optimizer.step() 49 | 50 | running_loss += loss.item() 51 | running_adv_loss += abs_loss.item() 52 | 53 | auc += calculate_auc(F.sigmoid(outputs).cpu().data.numpy(), 54 | targets.cpu().data.numpy()) 55 | 56 | no_iter += 1 57 | if self.log_freq and (i % self.log_freq == 0): 58 | self.wandb.log({'Training loss': running_loss / (i+1), 'Training AUC': auc / (i+1)}) 59 | 60 | 61 | running_loss /= no_iter 62 | running_adv_loss /= no_iter 63 | auc = auc / no_iter 64 | print('Training epoch {}: AUC:{}'.format(self.epoch, auc)) 65 | print('Training epoch {}: cls loss:{}, adv loss:{}'.format( 66 | self.epoch, running_loss, running_adv_loss)) 67 | 68 | self.epoch += 1 69 | 70 | def abs_orthogonal_blind(self, output, gram, target_labels, bias_labels): 71 | # For each discriminatory class, orthogonalize samples 72 | 73 | bias_classes = torch.unique(bias_labels) 74 | orthogonal_loss = torch.tensor(0.).to(output.device) 75 | M_tot = 0. 
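        # gram[i, j] (lower triangle) is the inner product between the feature
        # vectors of samples i and j. For every pair of samples that share the
        # same bias class, the loop below accumulates |sum of their correlations|,
        # pushing same-bias representations towards orthogonality.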
76 | 77 | for bias_class in bias_classes: 78 | bias_mask = (bias_labels == bias_class).type(torch.float).unsqueeze(dim=1) 79 | bias_mask = torch.tril(torch.mm(bias_mask, torch.transpose(bias_mask, 0, 1)), diagonal=-1) 80 | M = bias_mask.sum() 81 | M_tot += M 82 | 83 | if M > 0: 84 | orthogonal_loss += torch.abs(torch.sum(gram * bias_mask)) 85 | 86 | if M_tot > 0: 87 | orthogonal_loss /= M_tot 88 | return orthogonal_loss 89 | 90 | def abs_parallel(self, gram, target_labels, bias_labels): 91 | # For each target class, parallelize samples belonging to 92 | # different discriminatory classes 93 | 94 | target_classes = torch.unique(target_labels) 95 | bias_classes = torch.unique(bias_labels) 96 | 97 | parallel_loss = torch.tensor(0.).to(gram.device) 98 | M_tot = 0. 99 | 100 | for target_class in target_classes: 101 | class_mask = (target_labels == target_class).type(torch.float).unsqueeze(dim=1) 102 | 103 | for idx, bias_class in enumerate(bias_classes): 104 | bias_mask = (bias_labels == bias_class).type(torch.float).unsqueeze(dim=1) 105 | 106 | for other_bias_class in bias_classes[idx:]: 107 | if other_bias_class == bias_class: 108 | continue 109 | 110 | other_bias_mask = (bias_labels == other_bias_class).type(torch.float).unsqueeze(dim=1) 111 | mask = torch.tril( 112 | torch.mm(class_mask * bias_mask, torch.transpose(class_mask * other_bias_mask, 0, 1)), 113 | diagonal=-1) 114 | M = mask.sum() 115 | M_tot += M 116 | 117 | if M > 0: 118 | parallel_loss -= torch.sum((1.0 + gram) * mask * 0.5) 119 | if M_tot > 0: 120 | parallel_loss = 1.0 + (parallel_loss / M_tot) 121 | return parallel_loss 122 | 123 | def abs_regu(self, feat, target_labels, bias_labels, alpha=1.0, beta=1.0, sum=True): 124 | D = feat 125 | if len(D.size()) > 2: 126 | D = D.view(-1, np.prod((D.size()[1:]))) 127 | 128 | gram_matrix = torch.tril(torch.mm(D, torch.transpose(D, 0, 1)), diagonal=-1) 129 | # not really needed, just for safety for approximate repr 130 | gram_matrix = torch.clamp(gram_matrix, -1, 1.) 
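# The regularizer returned below is the EnD term R = alpha * R_ortho + beta * R_parallel:
# gram[i, j] holds the pairwise feature dot products (strictly lower triangle), R_ortho
# decorrelates samples that share a bias class, and R_parallel aligns samples that share
# a target class but come from different bias classes. The clamp above treats the
# entries as approximate cosine similarities, which presumes roughly unit-norm features.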
131 | 132 | zero = torch.tensor(0.).to(target_labels.device) 133 | R_ortho = self.abs_orthogonal_blind(D, gram_matrix, target_labels, bias_labels) if alpha != 0 else zero 134 | R_parallel = self.abs_parallel(gram_matrix, target_labels, bias_labels) if beta != 0 else zero 135 | 136 | if sum: 137 | return alpha * R_ortho + beta * R_parallel 138 | return alpha * R_ortho, beta * R_parallel 139 | -------------------------------------------------------------------------------- /models/GroupDRO/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class LossComputer: 4 | def __init__(self, criterion, is_robust, dataset, alpha=None, gamma=0.1, adj=None, min_var_weight=0, step_size=0.01, normalize_loss=False, btl=False): 5 | self.criterion = criterion 6 | self.is_robust = is_robust 7 | self.gamma = gamma 8 | self.alpha = alpha 9 | self.min_var_weight = min_var_weight 10 | self.step_size = step_size 11 | self.normalize_loss = normalize_loss 12 | self.btl = btl 13 | 14 | self.n_groups = dataset.sens_classes 15 | _, self.group_counts = dataset.group_counts() 16 | self.group_counts = self.group_counts.cuda() 17 | self.group_frac = self.group_counts/self.group_counts.sum() 18 | #self.group_str = dataset.group_str 19 | 20 | if adj is not None: 21 | self.adj = torch.from_numpy(adj).float().cuda() 22 | else: 23 | self.adj = torch.zeros(self.n_groups).float().cuda() 24 | 25 | if is_robust: 26 | assert alpha, 'alpha must be specified' 27 | 28 | # quantities maintained throughout training 29 | self.adv_probs = torch.ones(self.n_groups).cuda()/self.n_groups 30 | self.exp_avg_loss = torch.zeros(self.n_groups).cuda() 31 | self.exp_avg_initialized = torch.zeros(self.n_groups).byte().cuda() 32 | 33 | self.reset_stats() 34 | 35 | def loss(self, yhat, y, group_idx=None, is_training=False): 36 | # compute per-sample and per-group losses 37 | per_sample_losses = self.criterion(yhat, y) 38 | group_loss, group_count = self.compute_group_avg(per_sample_losses, group_idx) 39 | #group_acc, group_count = self.compute_group_avg((torch.argmax(yhat,1)==y).float(), group_idx) 40 | group_acc, group_count = self.compute_group_avg((yhat > 0.5).float(), group_idx) 41 | 42 | # update historical losses 43 | self.update_exp_avg_loss(group_loss, group_count) 44 | 45 | # compute overall loss 46 | if self.is_robust and not self.btl: 47 | actual_loss, weights = self.compute_robust_loss(group_loss, group_count) 48 | elif self.is_robust and self.btl: 49 | actual_loss, weights = self.compute_robust_loss_btl(group_loss, group_count) 50 | else: 51 | actual_loss = per_sample_losses.mean() 52 | weights = None 53 | 54 | # update stats 55 | self.update_stats(actual_loss, group_loss, group_acc, group_count, weights) 56 | 57 | return actual_loss 58 | 59 | def compute_robust_loss(self, group_loss, group_count): 60 | adjusted_loss = group_loss 61 | if torch.all(self.adj>0): 62 | adjusted_loss += self.adj/torch.sqrt(self.group_counts) 63 | if self.normalize_loss: 64 | adjusted_loss = adjusted_loss/(adjusted_loss.sum()) 65 | self.adv_probs = self.adv_probs * torch.exp(self.step_size*adjusted_loss.data) 66 | self.adv_probs = self.adv_probs/(self.adv_probs.sum()) 67 | 68 | robust_loss = group_loss @ self.adv_probs 69 | return robust_loss, self.adv_probs 70 | 71 | def compute_group_avg(self, losses, group_idx): 72 | # compute observed counts and mean loss for each group 73 | group_map = (group_idx == torch.arange(self.n_groups).unsqueeze(1).long().cuda()).float() #size: 2 x batch_size 74 | 
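# group_map is a one-hot [n_groups, batch_size] matrix (the "2 x batch_size" above is
# the binary-attribute case). Worked example (illustrative values): with
# group_idx = [0, 1, 0] and per-sample losses [l0, l1, l2],
#   group_map @ losses = [l0 + l2, l1] and group_count = [2, 1],
# giving group_loss = [(l0 + l2) / 2, l1] after the nan-safe division below.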
group_count = group_map.sum(1) 75 | group_denom = group_count + (group_count==0).float() # avoid nans 76 | #import pdb; pdb.set_trace() 77 | 78 | group_loss = (group_map @ losses.view(-1))/group_denom 79 | return group_loss, group_count 80 | 81 | def update_exp_avg_loss(self, group_loss, group_count): 82 | prev_weights = (1 - self.gamma*(group_count>0).float()) * (self.exp_avg_initialized>0).float() 83 | curr_weights = 1 - prev_weights 84 | self.exp_avg_loss = self.exp_avg_loss * prev_weights + group_loss*curr_weights 85 | self.exp_avg_initialized = (self.exp_avg_initialized>0) + (group_count>0) 86 | 87 | def reset_stats(self): 88 | self.processed_data_counts = torch.zeros(self.n_groups).cuda() 89 | self.update_data_counts = torch.zeros(self.n_groups).cuda() 90 | self.update_batch_counts = torch.zeros(self.n_groups).cuda() 91 | self.avg_group_loss = torch.zeros(self.n_groups).cuda() 92 | self.avg_group_acc = torch.zeros(self.n_groups).cuda() 93 | self.avg_per_sample_loss = 0. 94 | self.avg_actual_loss = 0. 95 | self.avg_acc = 0. 96 | self.batch_count = 0. 97 | 98 | def update_stats(self, actual_loss, group_loss, group_acc, group_count, weights=None): 99 | # avg group loss 100 | denom = self.processed_data_counts + group_count 101 | denom += (denom==0).float() 102 | prev_weight = self.processed_data_counts/denom 103 | curr_weight = group_count/denom 104 | self.avg_group_loss = prev_weight*self.avg_group_loss + curr_weight*group_loss 105 | 106 | # avg group acc 107 | self.avg_group_acc = prev_weight*self.avg_group_acc + curr_weight*group_acc 108 | 109 | # batch-wise average actual loss 110 | denom = self.batch_count + 1 111 | self.avg_actual_loss = (self.batch_count/denom)*self.avg_actual_loss + (1/denom)*actual_loss 112 | 113 | # counts 114 | self.processed_data_counts += group_count 115 | if self.is_robust: 116 | self.update_data_counts += group_count*((weights>0).float()) 117 | self.update_batch_counts += ((group_count*weights)>0).float() 118 | else: 119 | self.update_data_counts += group_count 120 | self.update_batch_counts += (group_count>0).float() 121 | self.batch_count+=1 122 | 123 | # avg per-sample quantities 124 | group_frac = self.processed_data_counts/(self.processed_data_counts.sum()) 125 | self.avg_per_sample_loss = group_frac @ self.avg_group_loss 126 | self.avg_acc = group_frac @ self.avg_group_acc 127 | -------------------------------------------------------------------------------- /notebooks/PAPILA.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "bbd4178f", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import h5py\n", 11 | "import pandas as pd\n", 12 | "import numpy as np\n", 13 | "import cv2\n", 14 | "import os\n", 15 | "import matplotlib.pyplot as plt\n", 16 | "from collections import Counter\n", 17 | "from sklearn.model_selection import train_test_split" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "id": "e1173431", 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "# read metadata\n", 28 | "path = 'your_path/data/PAPILA/'\n", 29 | "\n", 30 | "# OD for right, OS for left\n", 31 | "od_meta = pd.read_csv(path + 'ClinicalData/patient_data_od.csv')\n", 32 | "os_meta = pd.read_csv(path + 'ClinicalData/patient_data_os.csv')\n", 33 | "od_meta.head()" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "id": "75d82bf3", 40 | "metadata": {}, 41 | 
"outputs": [], 42 | "source": [ 43 | "os_meta.head()" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "id": "95739fb4", 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "ids = os_meta['ID'].values\n", 54 | "os_path = ['RET' + x[1:] + 'OS.jpg' for x in ids]\n", 55 | "os_meta['Path'] = os_path" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "id": "a376a0aa", 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "ids = od_meta['ID'].values\n", 66 | "od_path = ['RET' + x[1:] + 'OD.jpg' for x in ids]\n", 67 | "od_meta['Path'] = od_path" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "id": "a7847f09", 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "meta_all = pd.concat([od_meta, os_meta])\n", 78 | "subcolumns = ['ID', 'Age', 'Gender', 'Diagnosis', 'Path']\n", 79 | "meta_all = meta_all[subcolumns]\n", 80 | "meta_all" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "id": "0e7dff3c", 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "meta_all.to_csv(path + 'ClinicalData/patient_meta_concat.csv')" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "id": "7263de2a", 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "# the patient (0 for male and 1 for female), \n", 101 | "# the diagnosis (0 stands for healthy, 1 for glaucoma, and 2 for suspicious)\n", 102 | "\n", 103 | "sex = meta_all['Gender'].values.astype('str')\n", 104 | "sex[sex == '0.0'] = 'M'\n", 105 | "sex[sex == '1.0'] = 'F'\n", 106 | "meta_all['Sex'] = sex\n", 107 | "meta_all" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "id": "1908675b", 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "meta_all['Age_multi'] = meta_all['Age'].values.astype('int')\n", 118 | "meta_all['Age_multi'] = np.where(meta_all['Age_multi'].between(0,19), 0, meta_all['Age_multi'])\n", 119 | "meta_all['Age_multi'] = np.where(meta_all['Age_multi'].between(20,39), 1, meta_all['Age_multi'])\n", 120 | "meta_all['Age_multi'] = np.where(meta_all['Age_multi'].between(40,59), 2, meta_all['Age_multi'])\n", 121 | "meta_all['Age_multi'] = np.where(meta_all['Age_multi'].between(60,79), 3, meta_all['Age_multi'])\n", 122 | "meta_all['Age_multi'] = np.where(meta_all['Age_multi']>=80, 4, meta_all['Age_multi'])\n", 123 | "\n", 124 | "meta_all['Age_binary'] = meta_all['Age'].values.astype('int')\n", 125 | "meta_all['Age_binary'] = np.where(meta_all['Age_binary'].between(0, 60), 0, meta_all['Age_binary'])\n", 126 | "meta_all['Age_binary'] = np.where(meta_all['Age_binary']>= 60, 1, meta_all['Age_binary'])\n", 127 | "meta_all" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "id": "55550dfe", 134 | "metadata": { 135 | "scrolled": true 136 | }, 137 | "outputs": [], 138 | "source": [ 139 | "# binary , only use healthy and glaucoma, i.e. 
0 and 1.\n", 140 | "\n", 141 | "meta_binary = meta_all[(meta_all['Diagnosis'].values == 1.0) | (meta_all['Diagnosis'].values == 0.0)]\n", 142 | "len(meta_binary)" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "id": "e4f757b9", 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "def split_712(all_meta, patient_ids):\n", 153 | " sub_train, sub_val_test = train_test_split(patient_ids, test_size=0.3, random_state=5)\n", 154 | " sub_val, sub_test = train_test_split(sub_val_test, test_size=0.66, random_state=15)\n", 155 | " train_meta = all_meta[all_meta.ID.isin(sub_train)]\n", 156 | " val_meta = all_meta[all_meta.ID.isin(sub_val)]\n", 157 | " test_meta = all_meta[all_meta.ID.isin(sub_test)]\n", 158 | " return train_meta, val_meta, test_meta\n", 159 | "\n", 160 | "sub_train, sub_val, sub_test = split_712(meta_binary, np.unique(meta_binary['ID']))" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "id": "b91657ee", 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "sub_train.to_csv('/yourpath/data/PAPILA/split/new_train.csv')\n", 171 | "sub_val.to_csv('/yourpath/data/PAPILA/split/new_val.csv')\n", 172 | "sub_test.to_csv('/yourpath/data/PAPILA/split/new_test.csv')" 173 | ] 174 | } 175 | ], 176 | "metadata": { 177 | "kernelspec": { 178 | "display_name": "torch11", 179 | "language": "python", 180 | "name": "torch11" 181 | }, 182 | "language_info": { 183 | "codemirror_mode": { 184 | "name": "ipython", 185 | "version": 3 186 | }, 187 | "file_extension": ".py", 188 | "mimetype": "text/x-python", 189 | "name": "python", 190 | "nbconvert_exporter": "python", 191 | "pygments_lexer": "ipython3", 192 | "version": "3.8.12" 193 | } 194 | }, 195 | "nbformat": 4, 196 | "nbformat_minor": 5 197 | } 198 | -------------------------------------------------------------------------------- /models/LNL/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.nn as nn 4 | from importlib import import_module 5 | from torch.autograd import Function 6 | from torchvision.models.feature_extraction import create_feature_extractor 7 | 8 | 9 | class LNLGradReverse(Function): 10 | """ 11 | Implement the gradient reversal layer for the convenience of domain adaptation neural network. 12 | The forward part is the identity function while the backward part is the negative function. 
13 | """ 14 | 15 | @staticmethod 16 | def forward(ctx, x): 17 | return x.view_as(x) 18 | 19 | @staticmethod 20 | def backward(ctx, grad_output): 21 | return grad_output.neg() 22 | 23 | 24 | def grad_reverseLNL(x): 25 | return LNLGradReverse.apply(x) 26 | 27 | 28 | class LNLNet(nn.Module): 29 | def __init__(self, backbone, num_classes, pretrained=True): 30 | super(LNLNet, self).__init__() 31 | 32 | self.backbone = backbone[3:].lower() 33 | mod = import_module("torchvision.models") 34 | cusModel = getattr(mod, self.backbone) 35 | resnet = cusModel(pretrained=pretrained) 36 | 37 | resnet.fc = nn.Linear(resnet.fc.in_features, num_classes) 38 | self.pred_ch = resnet.layer2[-1].conv1.in_channels 39 | 40 | self.returnkey = 'layer2' 41 | self.returnkey_avg = 'avgpool' 42 | self.returnkey_fc = 'fc' 43 | self.body = create_feature_extractor( 44 | resnet, return_nodes={'layer2': self.returnkey, 'avgpool': self.returnkey_avg, 'fc': self.returnkey_fc}) 45 | 46 | def forward(self, x): 47 | output = self.body(x) 48 | return output[self.returnkey_fc], output[self.returnkey] 49 | 50 | def inference(self, x): 51 | 52 | output = self.body(x) 53 | return output[self.returnkey_fc], output[self.returnkey_avg].squeeze() 54 | 55 | 56 | class LNLPredictor(nn.Module): 57 | def __init__(self, input_ch, num_classes=2): 58 | super(LNLPredictor, self).__init__() 59 | self.pred_conv1 = nn.Conv2d(input_ch, input_ch, kernel_size=3, 60 | stride=1, padding=1) 61 | self.pred_bn1 = nn.BatchNorm2d(input_ch) 62 | self.relu = nn.ReLU(inplace=True) 63 | self.pred_conv2 = nn.Conv2d(input_ch, input_ch, kernel_size=3, 64 | stride=1, padding=1) 65 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 66 | self.linear = nn.Linear(input_ch, num_classes) # binary classification, here use sigmoid instead of softmax 67 | 68 | def forward(self, x): 69 | x = self.pred_conv1(x) 70 | x = self.pred_bn1(x) 71 | x = self.relu(x) 72 | x = self.pred_conv2(x) 73 | x2 = self.avgpool(x) 74 | x2 = x2.view(x2.size(0), -1) 75 | px = self.linear(x2) 76 | return x, px 77 | 78 | 79 | class LNLNet3D(nn.Module): 80 | def __init__(self, backbone, num_classes, pretrained=True): 81 | super(LNLNet3D, self).__init__() 82 | 83 | self.backbone = backbone[3:].lower() 84 | #mod = import_module("torchvision.models.video.r3d_18") 85 | #cusModel = getattr(mod, self.backbone) 86 | resnet = torchvision.models.video.r3d_18(pretrained=pretrained) 87 | 88 | resnet.fc = nn.Linear(resnet.fc.in_features, num_classes) 89 | self.pred_ch = resnet.layer2[-1].conv1[0].in_channels 90 | 91 | self.returnkey = 'layer2' 92 | self.returnkey_avg = 'avgpool' 93 | self.returnkey_fc = 'fc' 94 | self.body = create_feature_extractor( 95 | resnet, return_nodes={'layer2': self.returnkey, 'avgpool': self.returnkey_avg, 'fc': self.returnkey_fc}) 96 | 97 | def forward(self, x): 98 | output = self.body(x) 99 | return output[self.returnkey_fc], output[self.returnkey] 100 | 101 | def inference(self, x): 102 | output = self.body(x) 103 | return output[self.returnkey_fc], output[self.returnkey_avg].squeeze() 104 | 105 | 106 | class LNLPredictor3D(nn.Module): 107 | def __init__(self, input_ch, num_classes=2): 108 | super(LNLPredictor3D, self).__init__() 109 | self.pred_conv1 = nn.Conv3d(input_ch, input_ch, kernel_size=3, 110 | stride=1, padding=1) 111 | self.pred_bn1 = nn.BatchNorm3d(input_ch) 112 | self.relu = nn.ReLU(inplace=True) 113 | self.pred_conv2 = nn.Conv3d(input_ch, input_ch, kernel_size=3, 114 | stride=1, padding=1) 115 | self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) 116 | self.linear = 
nn.Linear(input_ch, num_classes) # binary classification, here use sigmoid instead of softmax 117 | 118 | def forward(self, x): 119 | x = self.pred_conv1(x) 120 | x = self.pred_bn1(x) 121 | x = self.relu(x) 122 | x = self.pred_conv2(x) 123 | x2 = self.avgpool(x) 124 | x2 = x2.view(x2.size(0), -1) 125 | px = self.linear(x2) 126 | return x, px 127 | 128 | 129 | class LNLNet_MLP(nn.Module): 130 | def __init__(self, backbone, num_classes, in_features=1024, hidden_features=1024): 131 | super(LNLNet_MLP, self).__init__() 132 | 133 | mod = import_module("models.basemodels_mlp") 134 | cusModel = getattr(mod, backbone) 135 | self.net = cusModel(n_classes=num_classes, in_features= in_features, hidden_features=hidden_features) 136 | hidden_size = hidden_features 137 | 138 | self.pred_ch = hidden_size 139 | 140 | def forward(self, x): 141 | output, hidden = self.net.backbone(x) 142 | return output, hidden 143 | 144 | def inference(self, x): 145 | output, hidden = self.net.backbone(x) 146 | return output, hidden 147 | 148 | 149 | class LNLPredictor_MLP(nn.Module): 150 | def __init__(self, input_ch, num_classes=2, hidden_features = 512): 151 | super(LNLPredictor_MLP, self).__init__() 152 | self.pred_f1 = nn.Linear(input_ch, hidden_features) 153 | self.relu = nn.ReLU() 154 | self.pred_fc2 = nn.Linear(hidden_features, num_classes) 155 | 156 | def forward(self, x): 157 | x = self.pred_f1(x) 158 | x = self.relu(x) 159 | px = self.pred_fc2(x) 160 | return x, px -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision.transforms as transforms 4 | import datasets 5 | import pandas as pd 6 | import random 7 | import torchio as tio 8 | from utils.spatial_transforms import ToTensor 9 | 10 | from torchvision.transforms._transforms_video import ( 11 | NormalizeVideo, 12 | ) 13 | 14 | from torch.utils.data import WeightedRandomSampler 15 | 16 | 17 | def get_dataset(opt): 18 | data_setting = opt['data_setting'] 19 | mean=[0.485, 0.456, 0.406] 20 | std=[0.229, 0.224, 0.225] 21 | normalize = transforms.Normalize(mean=mean, std=std) 22 | if opt['is_3d']: 23 | mean_3d = [0.45, 0.45, 0.45] 24 | std_3d = [0.225, 0.225, 0.225] 25 | sizes = {'ADNI': (192, 192, 128), 'ADNI3T': (192, 192, 128), 'OCT': (192, 192, 96), 'COVID_CT_MD': (224, 224, 80)} 26 | if data_setting['augment']: 27 | transform_train = transforms.Compose([ 28 | tio.transforms.RandomFlip(), 29 | tio.transforms.RandomAffine(scales=(0.9, 1.2), degrees=15,), 30 | tio.transforms.CropOrPad(sizes[opt['dataset_name']]), 31 | 32 | ToTensor(), 33 | NormalizeVideo(mean_3d, std_3d), 34 | ]) 35 | else: 36 | transform_train = transforms.Compose([ 37 | tio.transforms.CropOrPad(sizes[opt['dataset_name']]), 38 | ToTensor(), 39 | NormalizeVideo(mean_3d, std_3d), 40 | ]) 41 | 42 | transform_test = transforms.Compose([ 43 | tio.transforms.CropOrPad(sizes[opt['dataset_name']]), 44 | ToTensor(), 45 | NormalizeVideo(mean_3d, std_3d), 46 | ]) 47 | elif opt['is_tabular']: 48 | pass 49 | else: 50 | if data_setting['augment']: 51 | transform_train = transforms.Compose([ 52 | transforms.Resize(256), 53 | transforms.RandomHorizontalFlip(), 54 | transforms.RandomRotation((-15, 15)), 55 | transforms.RandomCrop((224, 224)), 56 | transforms.ToTensor(), 57 | normalize, 58 | ]) 59 | else: 60 | transform_train = transforms.Compose([ 61 | transforms.Resize(256), 62 | transforms.CenterCrop(224), 63 | 
transforms.ToTensor(), 64 | normalize, 65 | ]) 66 | 67 | transform_test = transforms.Compose([ 68 | transforms.Resize(256), 69 | transforms.CenterCrop(224), 70 | transforms.ToTensor(), 71 | normalize, 72 | ]) 73 | 74 | g = torch.Generator() 75 | g.manual_seed(opt['random_seed']) 76 | def seed_worker(worker_id): 77 | np.random.seed(opt['random_seed'] ) 78 | random.seed(opt['random_seed']) 79 | 80 | image_path = data_setting['image_feature_path'] 81 | train_meta = pd.read_csv(data_setting['train_meta_path']) 82 | val_meta = pd.read_csv(data_setting['val_meta_path']) 83 | test_meta = pd.read_csv(data_setting['test_meta_path']) 84 | 85 | if opt['is_3d']: 86 | dataset_name = getattr(datasets, opt['dataset_name']) 87 | train_data = dataset_name(train_meta, image_path, opt['sensitive_name'], opt['train_sens_classes'], transform_train) 88 | val_data = dataset_name(val_meta, image_path, opt['sensitive_name'], opt['sens_classes'], transform_test) 89 | test_data = dataset_name(test_meta, image_path, opt['sensitive_name'], opt['sens_classes'], transform_test) 90 | elif opt['is_tabular']: 91 | # different format 92 | dataset_name = getattr(datasets, opt['dataset_name']) 93 | data_train_path = data_setting['data_train_path'] 94 | data_val_path = data_setting['data_val_path'] 95 | data_test_path = data_setting['data_test_path'] 96 | 97 | data_train_df = pd.read_csv(data_train_path) 98 | data_val_df = pd.read_csv(data_val_path) 99 | data_test_df = pd.read_csv(data_test_path) 100 | 101 | train_data = dataset_name(train_meta, data_train_df, opt['sensitive_name'], opt['train_sens_classes'], None) 102 | val_data = dataset_name(val_meta, data_val_df, opt['sensitive_name'], opt['sens_classes'], None) 103 | test_data = dataset_name(test_meta, data_test_df, opt['sensitive_name'], opt['sens_classes'], None) 104 | 105 | else: 106 | dataset_name = getattr(datasets, opt['dataset_name']) 107 | pickle_train_path = data_setting['pickle_train_path'] 108 | pickle_val_path = data_setting['pickle_val_path'] 109 | pickle_test_path = data_setting['pickle_test_path'] 110 | train_data = dataset_name(train_meta, pickle_train_path, opt['sensitive_name'], opt['train_sens_classes'], transform_train) 111 | val_data = dataset_name(val_meta, pickle_val_path, opt['sensitive_name'], opt['sens_classes'], transform_test) 112 | test_data = dataset_name(test_meta, pickle_test_path, opt['sensitive_name'], opt['sens_classes'], transform_test) 113 | 114 | print('loaded dataset ', opt['dataset_name']) 115 | 116 | if opt['experiment']=='resampling' or opt['experiment']=='GroupDRO' or opt['experiment']=='resamplingSWAD': 117 | weights = train_data.get_weights(resample_which = opt['resample_which']) 118 | sampler = WeightedRandomSampler(weights, len(weights), replacement=True, generator = g) 119 | else: 120 | sampler = None 121 | 122 | train_loader = torch.utils.data.DataLoader( 123 | train_data, batch_size=opt['batch_size'], 124 | sampler=sampler, 125 | shuffle=(opt['experiment']!='resampling' and opt['experiment']!='GroupDRO' and opt['experiment']!='resamplingSWAD'), num_workers=8, 126 | worker_init_fn=seed_worker, generator=g, pin_memory=True) 127 | val_loader = torch.utils.data.DataLoader( 128 | val_data, batch_size=opt['batch_size'], 129 | shuffle=True, num_workers=8, worker_init_fn=seed_worker, generator=g, pin_memory=True) 130 | test_loader = torch.utils.data.DataLoader( 131 | test_data, batch_size=opt['batch_size'], 132 | shuffle=True, num_workers=8, worker_init_fn=seed_worker, generator=g, pin_memory=True) 133 | 134 | return train_data, 
val_data, test_data, train_loader, val_loader, test_loader, val_meta, test_meta 135 | -------------------------------------------------------------------------------- /models/LAFTR/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from importlib import import_module 4 | 5 | 6 | class LaftrNet(nn.Module): 7 | def __init__(self, backbone, num_classes, adversary_size = 128, pretrained = True, device = 'cuda', model_var = 'laftr-dp'): 8 | super(LaftrNet, self).__init__() 9 | 10 | self.backbone = backbone 11 | self.model_var = model_var 12 | self.num_classes = num_classes 13 | self.used_classes = 2 14 | 15 | mod = import_module("models.basemodels") 16 | cusModel = getattr(mod, self.backbone) 17 | self.net = cusModel(n_classes=self.num_classes, pretrained=pretrained) 18 | hidden_size = self.net.body.fc.in_features 19 | 20 | self.device = device 21 | 22 | if self.model_var != "laftr-dp": 23 | self.adv_neurons = [hidden_size + self.used_classes - 1] \ 24 | + [adversary_size] \ 25 | + [self.used_classes - 1] 26 | else: 27 | self.adv_neurons = [hidden_size] + [adversary_size] + [self.used_classes - 1] 28 | 29 | self.num_adversaries_layers = len(self.adv_neurons) 30 | # Conditional adversaries for sensitive attribute classification, one separate adversarial classifier for one class label. 31 | self.discriminator = nn.ModuleList([nn.Linear(self.adv_neurons[i], self.adv_neurons[i + 1]) 32 | for i in range(self.num_adversaries_layers -1)]) 33 | 34 | def forward(self, X, Y=None): 35 | Y_logits, Z = self.net(X) 36 | if Y is None: 37 | # for inference 38 | return Y_logits, Z 39 | 40 | if self.model_var != "laftr-dp": 41 | Z = torch.cat( 42 | [Z, torch.unsqueeze(Y[:, 0].type(torch.FloatTensor), 1).to(self.device)], 43 | axis=1,) 44 | for hidden in self.discriminator: 45 | Z = hidden(Z) 46 | 47 | # For discriminator loss 48 | A_logits = torch.squeeze(Z) 49 | return Y_logits, A_logits 50 | 51 | def inference(self, X): 52 | Y_logits, Z = self.net(X) 53 | return Y_logits, Z 54 | 55 | 56 | class LaftrNet3D(nn.Module): 57 | def __init__(self, backbone, num_classes, adversary_size = 128, pretrained = True, device = 'cuda', model_var = 'laftr-dp'): 58 | super(LaftrNet3D, self).__init__() 59 | 60 | self.backbone = backbone 61 | self.model_var = model_var 62 | self.num_classes = num_classes 63 | self.used_classes = 2 64 | 65 | mod = import_module("models.basemodels_3d") 66 | cusModel = getattr(mod, self.backbone) 67 | self.net = cusModel(n_classes=self.num_classes, pretrained=pretrained) 68 | hidden_size = self.net.body.fc.in_features 69 | 70 | self.device = device 71 | 72 | if self.model_var != "laftr-dp": 73 | self.adv_neurons = [hidden_size + self.used_classes - 1] \ 74 | + [adversary_size] \ 75 | + [self.used_classes - 1] 76 | else: 77 | self.adv_neurons = [hidden_size] + [adversary_size] + [self.used_classes - 1] 78 | 79 | 80 | self.num_adversaries_layers = len(self.adv_neurons) 81 | # Conditional adversaries for sensitive attribute classification, one separate adversarial classifier for one class label. 
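# Illustrative sizes (example numbers, not from the original source): with
# hidden_size = 512 and adversary_size = 128, any model_var other than "laftr-dp"
# concatenates the label to the representation in forward(), giving
# adv_neurons = [513, 128, 1], while "laftr-dp" gives [512, 128, 1]; the ModuleList
# below is exactly that stack of Linear layers.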
82 | self.discriminator = nn.ModuleList([nn.Linear(self.adv_neurons[i], self.adv_neurons[i + 1]) 83 | for i in range(self.num_adversaries_layers -1)]) 84 | 85 | def forward(self, X, Y=None): 86 | Y_logits, Z = self.net(X) 87 | if Y is None: 88 | # for inference 89 | return Y_logits, Z 90 | if self.model_var != "laftr-dp": 91 | Z = torch.cat( 92 | [Z, torch.unsqueeze(Y[:, 0].type(torch.FloatTensor), 1).to(self.device)], 93 | axis=1,) 94 | for hidden in self.discriminator: 95 | Z = hidden(Z) 96 | 97 | # For discriminator loss 98 | A_logits = torch.squeeze(Z) 99 | return Y_logits, A_logits 100 | 101 | def inference(self, X): 102 | Y_logits, Z = self.net(X) 103 | return Y_logits, Z 104 | 105 | 106 | 107 | class LaftrNet_MLP(nn.Module): 108 | def __init__(self, backbone, num_classes, adversary_size = 128, device = 'cuda', model_var = 'laftr-dp', in_features=1024, hidden_features=1024): 109 | super(LaftrNet_MLP, self).__init__() 110 | 111 | self.backbone = backbone 112 | self.model_var = model_var 113 | self.num_classes = num_classes 114 | self.used_classes = 2 115 | 116 | mod = import_module("models.basemodels_mlp") 117 | cusModel = getattr(mod, self.backbone) 118 | self.net = cusModel(n_classes=self.num_classes, in_features= in_features, hidden_features=hidden_features) 119 | hidden_size = hidden_features 120 | 121 | self.device = device 122 | 123 | if self.model_var != "laftr-dp": 124 | self.adv_neurons = [hidden_size + self.used_classes - 1] \ 125 | + [adversary_size] \ 126 | + [self.used_classes - 1] 127 | else: 128 | self.adv_neurons = [hidden_size] + [adversary_size] + [self.used_classes - 1] 129 | 130 | 131 | self.num_adversaries_layers = len(self.adv_neurons) 132 | # Conditional adversaries for sensitive attribute classification, one separate adversarial classifier for 133 | # one class label. 
134 | self.discriminator = nn.ModuleList([nn.Linear(self.adv_neurons[i], self.adv_neurons[i + 1]) 135 | for i in range(self.num_adversaries_layers -1)]) 136 | 137 | def forward(self, X, Y=None): 138 | Y_logits, Z = self.net(X) 139 | if Y is None: 140 | # for inference 141 | return Y_logits, Z 142 | if self.model_var != "laftr-dp": 143 | Z = torch.cat( 144 | [Z, torch.unsqueeze(Y[:, 0].type(torch.FloatTensor), 1).to(self.device)], 145 | axis=1,) 146 | for hidden in self.discriminator: 147 | Z = hidden(Z) 148 | 149 | # For discriminator loss 150 | A_logits = torch.squeeze(Z) 151 | return Y_logits, A_logits 152 | 153 | def inference(self, X): 154 | Y_logits, Z = self.net(X) 155 | return Y_logits, Z -------------------------------------------------------------------------------- /models/DomainInd/DomainInd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn.functional as F 4 | from utils import basics 5 | from utils.evaluation import calculate_auc, calculate_metrics, calculate_FPR_FNR 6 | from models.baseline import baseline 7 | 8 | 9 | class DomainInd(baseline): 10 | def __init__(self, opt, wandb): 11 | super(DomainInd, self).__init__(opt, wandb) 12 | 13 | self.set_network(opt) 14 | self.set_data(opt) 15 | self.set_optimizer(opt) 16 | 17 | def _criterion_domain(self, output, target, sensitive_attr): 18 | domain_label = sensitive_attr.long() #.reshape(-1, 1) 19 | class_num = output.shape[1] // self.sens_classes 20 | preds = [] 21 | for i in range(domain_label.shape[0]): 22 | preds.append(output[i, domain_label[i] * class_num: (domain_label[i]+1) *class_num]) 23 | preds = torch.stack(preds) 24 | loss = F.binary_cross_entropy_with_logits(preds, target) 25 | return loss 26 | 27 | def inference_sum_prob(self, output): 28 | """Inference method: sum the probability from multiple domains""" 29 | #predict_prob = torch.sigmoid(output) 30 | predict_prob = output 31 | class_num = predict_prob.shape[1] // self.sens_classes 32 | predict_prob_sum = [] 33 | for i in range(self.sens_classes): 34 | predict_prob_sum.append(predict_prob[:, i * class_num: (i+1) * class_num]) 35 | predict_prob_sum = torch.stack(predict_prob_sum).sum(0) 36 | predict_prob_sum = torch.sigmoid(predict_prob_sum) 37 | return predict_prob_sum 38 | 39 | def _train(self, loader): 40 | """Train the model for one epoch""" 41 | 42 | self.network.train() 43 | train_loss, auc, no_iter = 0, 0., 0 44 | 45 | for i, (images, targets, sensitive_attr, index) in enumerate(loader): 46 | images, targets, sensitive_attr = images.to(self.device), targets.to(self.device), sensitive_attr.to( 47 | self.device) 48 | self.optimizer.zero_grad() 49 | outputs, _ = self.network.forward(images) 50 | 51 | loss = self._criterion_domain(outputs, targets, sensitive_attr) 52 | loss.backward() 53 | self.optimizer.step() 54 | 55 | outputs = self.inference_sum_prob(outputs) 56 | auc += calculate_auc(outputs.cpu().data.numpy(), 57 | targets.cpu().data.numpy()) 58 | train_loss += loss.item() 59 | no_iter += 1 60 | 61 | if self.log_freq and (i % self.log_freq == 0): 62 | self.wandb.log({'Training loss': train_loss / (i+1), 'Training AUC': auc / (i+1)}) 63 | 64 | auc = 100 * auc / no_iter 65 | train_loss /= no_iter 66 | 67 | print('Training epoch {}: AUC:{}'.format(self.epoch, auc)) 68 | print('Training epoch {}: loss:{}'.format(self.epoch, train_loss)) 69 | self.epoch += 1 70 | 71 | def _val(self, loader): 72 | """Compute model output on validation set""" 73 | 74 | self.network.eval() 75 | 
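# Validation mirrors training: the loss comes from the domain-conditioned BCE in
# _criterion_domain, while the reported predictions/AUC use inference_sum_prob,
# which sums the per-domain logits before applying the sigmoid.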
tol_output, tol_target, tol_sensitive, tol_index = [], [], [], [] 76 | val_loss, auc = 0., 0. 77 | no_iter = 0 78 | with torch.no_grad(): 79 | for i, (images, targets, sensitive_attr, index) in enumerate(loader): 80 | images, targets, sensitive_attr = images.to(self.device), targets.to(self.device), sensitive_attr.to( 81 | self.device) 82 | outputs, features = self.network.inference(images) 83 | loss = self._criterion_domain(outputs, targets, sensitive_attr) 84 | val_loss += loss.item() 85 | outputs = self.inference_sum_prob(outputs) 86 | 87 | tol_output += outputs.flatten().cpu().data.numpy().tolist() 88 | tol_target += targets.cpu().data.numpy().tolist() 89 | tol_sensitive += sensitive_attr.cpu().data.numpy().tolist() 90 | tol_index += index.numpy().tolist() 91 | 92 | auc += calculate_auc(outputs.cpu().data.numpy(), 93 | targets.cpu().data.numpy()) 94 | no_iter += 1 95 | if self.log_freq and (i % self.log_freq == 0): 96 | self.wandb.log({'Validation loss': val_loss / (i+1), 'Validation AUC': auc / (i+1)}) 97 | 98 | auc = 100 * auc / no_iter 99 | val_loss /= no_iter 100 | 101 | log_dict, t_predictions, pred_df = calculate_metrics(tol_output, tol_target, tol_sensitive, tol_index, self.sens_classes) 102 | print('Validation epoch {}: validation loss:{}, AUC:{}'.format( 103 | self.epoch, val_loss, auc)) 104 | 105 | return val_loss, auc, log_dict, pred_df 106 | 107 | def _test(self, loader): 108 | """Compute model output on testing set""" 109 | 110 | self.network.eval() 111 | tol_output, tol_target, tol_sensitive, tol_index = [], [], [], [] 112 | with torch.no_grad(): 113 | for i, (images, targets, sensitive_attr, index) in enumerate(loader): 114 | images, targets, sensitive_attr = images.to(self.device), targets.to(self.device), sensitive_attr.to( 115 | self.device) 116 | outputs, features = self.network.inference(images) 117 | outputs = self.inference_sum_prob(outputs) 118 | 119 | tol_output += outputs.flatten().cpu().data.numpy().tolist() 120 | tol_target += targets.cpu().data.numpy().tolist() 121 | tol_sensitive += sensitive_attr.cpu().data.numpy().tolist() 122 | tol_index += index.numpy().tolist() 123 | 124 | log_dict, t_predictions, pred_df = calculate_metrics(tol_output, tol_target, tol_sensitive, tol_index, self.sens_classes) 125 | overall_FPR, overall_FNR, FPRs, FNRs = calculate_FPR_FNR(pred_df, self.test_meta, self.opt) 126 | log_dict['Overall FPR'] = overall_FPR 127 | log_dict['Overall FNR'] = overall_FNR 128 | pred_df.to_csv(os.path.join(self.save_path, self.experiment + '_pred.csv'), index = False) 129 | 130 | for i, FPR in enumerate(FPRs): 131 | log_dict['FPR-group_' + str(i)] = FPR 132 | for i, FNR in enumerate(FNRs): 133 | log_dict['FNR-group_' + str(i)] = FNR 134 | 135 | log_dict = basics.add_dict_prefix(log_dict, 'Test ') 136 | 137 | return log_dict 138 | -------------------------------------------------------------------------------- /notebooks/ADNI.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "bbd4178f", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import h5py\n", 11 | "import pandas as pd\n", 12 | "import numpy as np\n", 13 | "import cv2\n", 14 | "import os\n", 15 | "import matplotlib.pyplot as plt\n", 16 | "from collections import Counter\n", 17 | "from sklearn.model_selection import train_test_split" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "id": "e1173431", 24 | "metadata": {}, 25 
| "outputs": [], 26 | "source": [ 27 | "# read metadata\n", 28 | "path = '/yourpath/data/ADNI/'\n", 29 | "\n", 30 | "# use `ADNI1_Baseline_3T_7_07_2022.csv` for ADNI 3T \n", 31 | "demo_data = pd.read_csv(path + 'ADNI1_Screening_1.5T_7_02_2022.csv')\n", 32 | "demo_data" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "id": "52196b73", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "demo_data = demo_data[demo_data['Group'] != 'MCI']" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "id": "843f8689", 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "labels = demo_data['Group'].values.tolist()\n", 53 | "labels = [1 if x == 'AD' else 0 for x in labels]\n", 54 | "demo_data['label'] = labels" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "id": "5cd32963", 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "# flatten directories\n", 65 | "\n", 66 | "import os\n", 67 | "import itertools\n", 68 | "import shutil\n", 69 | "\n", 70 | "\n", 71 | "def move(destination):\n", 72 | " all_files = []\n", 73 | " for root, _dirs, files in itertools.islice(os.walk(destination), 1, None):\n", 74 | " for filename in files:\n", 75 | " #print(filename)\n", 76 | " all_files.append(os.path.join(root, filename))\n", 77 | " for filename in all_files:\n", 78 | " shutil.move(filename, destination)\n", 79 | "\n", 80 | "move(path + 'images-bk')" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "id": "ffcf05e9", 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "# rename file: subject_id __ image_id\n", 91 | "images = os.listdir(path + 'images-all')\n", 92 | "for image in images:\n", 93 | " subject_id = image[5: 15]\n", 94 | " image_id = image[-10:]\n", 95 | " \n", 96 | " new_name = subject_id + '__' + image_id\n", 97 | " old_path = os.path.join(path, 'images-all', image)\n", 98 | " new_path = os.path.join(path, 'images-all', new_name)\n", 99 | " os.rename(old_path, new_path)\n", 100 | "\n" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "id": "46c1b226", 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "# rename file: subject_id __ image_id\n", 111 | "images = os.listdir(path + 'images')\n", 112 | "for image in images:\n", 113 | " image_id = image[-10:]\n", 114 | " if not image_id[0] == 'I':\n", 115 | " image_id = 'I'+ image_id\n", 116 | " new_name = image.split('__')[0] + '__' + image_id\n", 117 | " old_path = os.path.join(path, 'images', image)\n", 118 | " new_path = os.path.join(path, 'images', new_name)\n", 119 | " os.rename(old_path, new_path)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "id": "2bf42610", 126 | "metadata": { 127 | "scrolled": true 128 | }, 129 | "outputs": [], 130 | "source": [ 131 | "def addpath(row):\n", 132 | " return str(row['Subject']) + '__'+ str(row['Image Data ID']) + '.nii'\n", 133 | "\n", 134 | "demo_data[\"Path\"] = demo_data.apply(addpath, axis=1)\n", 135 | "#all_meta['Path'] = 'images/' + str(all_meta['idx']) + '.npy'\n", 136 | "demo_data" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "id": "9a1ed584", 143 | "metadata": { 144 | "scrolled": true 145 | }, 146 | "outputs": [], 147 | "source": [ 148 | "# the patient (0 for male and 1 for female), \n", 149 | "# the diagnosis (0 stands for healthy, 1 for glaucoma, and 2 for suspicious)\n", 150 | "\n", 151 | 
"\n", 152 | "demo_data['Age_multi'] = demo_data['Age'].values.astype('int')\n", 153 | "demo_data['Age_multi'] = np.where(demo_data['Age_multi'].between(0,54), 0, demo_data['Age_multi'])\n", 154 | "demo_data['Age_multi'] = np.where(demo_data['Age_multi'].between(55,65), 1, demo_data['Age_multi'])\n", 155 | "demo_data['Age_multi'] = np.where(demo_data['Age_multi'].between(65,75), 2, demo_data['Age_multi'])\n", 156 | "demo_data['Age_multi'] = np.where(demo_data['Age_multi'].between(75,85), 3, demo_data['Age_multi'])\n", 157 | "demo_data['Age_multi'] = np.where(demo_data['Age_multi']>=85, 4, demo_data['Age_multi'])\n", 158 | "\n", 159 | "demo_data['Age_binary'] = demo_data['Age'].values.astype('int')\n", 160 | "demo_data['Age_binary'] = np.where(demo_data['Age_binary'].between(0, 75), 0, demo_data['Age_binary'])\n", 161 | "demo_data['Age_binary'] = np.where(demo_data['Age_binary']>= 75, 1, demo_data['Age_binary'])\n", 162 | "demo_data" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "id": "e4f757b9", 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "def split_712(all_meta, patient_ids):\n", 173 | " sub_train, sub_val_test = train_test_split(patient_ids, test_size=0.3, random_state=5)\n", 174 | " sub_val, sub_test = train_test_split(sub_val_test, test_size=0.66, random_state=6)\n", 175 | " train_meta = all_meta[all_meta.Subject.isin(sub_train.astype('str'))]\n", 176 | " val_meta = all_meta[all_meta.Subject.isin(sub_val.astype('str'))]\n", 177 | " test_meta = all_meta[all_meta.Subject.isin(sub_test.astype('str'))]\n", 178 | " return train_meta, val_meta, test_meta\n", 179 | "\n", 180 | "sub_train, sub_val, sub_test = split_712(demo_data, np.unique(demo_data['Subject']))" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "id": "b91657ee", 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "sub_train.to_csv('/yourpath/ADNI/split/new_train.csv')\n", 191 | "sub_val.to_csv('/yourpath/ADNI/split/new_val.csv')\n", 192 | "sub_test.to_csv('/yourpath/ADNI/split/new_test.csv')" 193 | ] 194 | } 195 | ], 196 | "metadata": { 197 | "kernelspec": { 198 | "display_name": "torch11", 199 | "language": "python", 200 | "name": "torch11" 201 | }, 202 | "language_info": { 203 | "codemirror_mode": { 204 | "name": "ipython", 205 | "version": 3 206 | }, 207 | "file_extension": ".py", 208 | "mimetype": "text/x-python", 209 | "name": "python", 210 | "nbconvert_exporter": "python", 211 | "pygments_lexer": "ipython3", 212 | "version": "3.8.12" 213 | } 214 | }, 215 | "nbformat": 4, 216 | "nbformat_minor": 5 217 | } 218 | -------------------------------------------------------------------------------- /models/GroupDRO/GroupDRO.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from models import basemodels 7 | from utils import basics 8 | from utils.evaluation import calculate_auc, calculate_metrics, calculate_FPR_FNR 9 | from models.basenet import BaseNet 10 | from importlib import import_module 11 | from models.GroupDRO.utils import LossComputer 12 | 13 | 14 | class GroupDRO(BaseNet): 15 | def __init__(self, opt, wandb): 16 | super(GroupDRO, self).__init__(opt, wandb) 17 | 18 | self.set_network(opt) 19 | self.set_optimizer(opt) 20 | 21 | self.groupdro_alpha = opt['groupdro_alpha'] 22 | self.groupdro_gamma = opt['groupdro_gamma'] 23 | self.register_buffer("q", 
torch.ones(self.sens_classes)) 24 | 25 | self.criterion = nn.BCEWithLogitsLoss(reduction = 'none') 26 | 27 | generalization_adjustment = "0" 28 | adjustments = [float(c) for c in generalization_adjustment.split(',')] 29 | assert len(adjustments) in (1, self.train_data.sens_classes) 30 | if len(adjustments)==1: 31 | adjustments = np.array(adjustments* self.train_data.sens_classes) 32 | else: 33 | adjustments = np.array(adjustments) 34 | self.train_loss_computer = LossComputer( 35 | criterion = self.criterion,  # the unreduced criterion defined above; LossComputer needs per-sample losses to average per group 36 | is_robust=True, 37 | dataset=self.train_data, 38 | alpha=self.groupdro_alpha, 39 | gamma=self.groupdro_gamma, 40 | adj=adjustments, 41 | step_size=0.01, 42 | normalize_loss=False, 43 | btl=False, 44 | min_var_weight=0) 45 | 46 | def set_network(self, opt): 47 | """Define the network""" 48 | 49 | if self.is_3d: 50 | mod = import_module("models.basemodels_3d") 51 | cusModel = getattr(mod, self.backbone) 52 | self.network = cusModel(n_classes=self.output_dim, pretrained = self.pretrained).to(self.device) 53 | elif self.is_tabular: 54 | mod = import_module("models.basemodels_mlp") 55 | cusModel = getattr(mod, self.backbone) 56 | self.network = cusModel(n_classes=self.output_dim, in_features= self.in_features, hidden_features = 1024).to(self.device) 57 | else: 58 | mod = import_module("models.basemodels") 59 | cusModel = getattr(mod, self.backbone) 60 | self.network = cusModel(n_classes=self.output_dim, pretrained=self.pretrained).to(self.device) 61 | 62 | def _train(self, loader): 63 | """Train the model for one epoch""" 64 | self.network.train() 65 | 66 | running_loss, auc = 0., 0. 67 | no_iter = 0 68 | for i, (images, targets, sensitive_attr, index) in enumerate(loader): 69 | images, targets, sensitive_attr = images.to(self.device), targets.to(self.device), sensitive_attr.to(self.device) 70 | self.optimizer.zero_grad() 71 | outputs, features = self.network.forward(images) 72 | 73 | loss = self.train_loss_computer.loss(outputs, targets, sensitive_attr, is_training = True) 74 | 75 | running_loss += loss.item() 76 | 77 | loss.backward() 78 | self.optimizer.step() 79 | 80 | auc += calculate_auc(F.sigmoid(outputs).cpu().data.numpy(), targets.cpu().data.numpy()) 81 | no_iter += 1 82 | 83 | if self.log_freq and (i % self.log_freq == 0): 84 | self.wandb.log({'Training loss': running_loss / (i+1), 'Training AUC': auc / (i+1)}) 85 | 86 | running_loss /= no_iter 87 | auc = auc / no_iter 88 | print('Training epoch {}: AUC:{}'.format(self.epoch, auc)) 89 | print('Training epoch {}: loss:{}'.format(self.epoch, running_loss)) 90 | self.epoch += 1 91 | 92 | def _val(self, loader): 93 | """Compute model output on validation set""" 94 | 95 | self.network.eval() 96 | tol_output, tol_target, tol_sensitive, tol_index = [], [], [], [] 97 | val_loss, auc = 0., 0.
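# Note: validation reuses train_loss_computer with is_training=False, but
# LossComputer.loss still updates its exponential-average and bookkeeping statistics
# on every call (see update_exp_avg_loss / update_stats in models/GroupDRO/utils.py).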
98 | no_iter = 0 99 | with torch.no_grad(): 100 | for i, (images, targets, sensitive_attr, index) in enumerate(loader): 101 | images, targets, sensitive_attr = images.to(self.device), targets.to(self.device), sensitive_attr.to( 102 | self.device) 103 | outputs, features = self.network.inference(images) 104 | loss = self.train_loss_computer.loss(outputs, targets, sensitive_attr, is_training = False) 105 | val_loss += loss.item() 106 | 107 | tol_output += F.sigmoid(outputs).flatten().cpu().data.numpy().tolist() 108 | tol_target += targets.cpu().data.numpy().tolist() 109 | tol_sensitive += sensitive_attr.cpu().data.numpy().tolist() 110 | tol_index += index.numpy().tolist() 111 | 112 | auc += calculate_auc(outputs.cpu().data.numpy(), 113 | targets.cpu().data.numpy()) 114 | no_iter += 1 115 | if self.log_freq and (i % self.log_freq == 0): 116 | self.wandb.log({'Validation loss': val_loss / (i+1), 'Validation AUC': auc / (i+1)}) 117 | 118 | auc = 100 * auc / no_iter 119 | val_loss /= no_iter 120 | 121 | log_dict, t_predictions, pred_df = calculate_metrics(tol_output, tol_target, tol_sensitive, tol_index, self.sens_classes) 122 | print('Validation epoch {}: validation loss:{}, AUC:{}'.format( 123 | self.epoch, val_loss, auc)) 124 | 125 | return val_loss, auc, log_dict, pred_df 126 | 127 | def _test(self, loader): 128 | """Compute model output on testing set""" 129 | 130 | self.network.eval() 131 | tol_output, tol_target, tol_sensitive, tol_index = [], [], [], [] 132 | with torch.no_grad(): 133 | for i, (images, targets, sensitive_attr, index) in enumerate(loader): 134 | images, targets, sensitive_attr = images.to(self.device), targets.to(self.device), sensitive_attr.to( 135 | self.device) 136 | outputs, features = self.network.inference(images) 137 | 138 | tol_output += F.sigmoid(outputs).flatten().cpu().data.numpy().tolist() 139 | tol_target += targets.cpu().data.numpy().tolist() 140 | tol_sensitive += sensitive_attr.cpu().data.numpy().tolist() 141 | tol_index += index.numpy().tolist() 142 | 143 | 144 | log_dict, t_predictions, pred_df = calculate_metrics(tol_output, tol_target, tol_sensitive, tol_index, self.sens_classes) 145 | overall_FPR, overall_FNR, FPRs, FNRs = calculate_FPR_FNR(pred_df, self.test_meta, self.opt) 146 | log_dict['Overall FPR'] = overall_FPR 147 | log_dict['Overall FNR'] = overall_FNR 148 | pred_df.to_csv(os.path.join(self.save_path, self.experiment + '_pred.csv'), index = False) 149 | #basics.save_results(t_predictions, tol_target, s_prediction, tol_sensitive, self.save_path) 150 | for i, FPR in enumerate(FPRs): 151 | log_dict['FPR-group_' + str(i)] = FPR 152 | for i, FNR in enumerate(FNRs): 153 | log_dict['FNR-group_' + str(i)] = FNR 154 | 155 | log_dict = basics.add_dict_prefix(log_dict, 'Test ') 156 | 157 | return log_dict -------------------------------------------------------------------------------- /models/CFair/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from importlib import import_module 6 | from torch.autograd import Function 7 | 8 | 9 | class GradReverse(Function): 10 | """ 11 | Gradient reversal layer, adapted from domain-adversarial neural networks (DANN). 12 | The forward pass is the identity function, while the backward pass negates the incoming gradient.
13 | """ 14 | @staticmethod 15 | def forward(ctx, x): 16 | return x.view_as(x) 17 | 18 | @staticmethod 19 | def backward(ctx, grad_output): 20 | return grad_output.neg() 21 | 22 | def grad_reverse(x): 23 | return GradReverse.apply(x) 24 | 25 | 26 | class CFairNet(nn.Module): 27 | def __init__(self, backbone, num_classes, adversary_size = 128, pretrained = True): 28 | super(CFairNet, self).__init__() 29 | 30 | self.num_classes = num_classes 31 | self.used_classes = 2 # can only handle binary attributes 32 | mod = import_module("models.basemodels") 33 | cusModel = getattr(mod, backbone) 34 | self.net = cusModel(n_classes=self.num_classes, pretrained=pretrained) 35 | hidden_size = self.net.body.fc.in_features 36 | # Parameter of the conditional adversary classification layer. 37 | self.num_adversaries = [hidden_size] + [adversary_size] 38 | self.num_adversaries_layers = len([adversary_size]) 39 | # Conditional adversaries for sensitive attribute classification, one separate adversarial classifier for one class label. 40 | self.adversaries = nn.ModuleList([nn.ModuleList([nn.Linear(self.num_adversaries[i], self.num_adversaries[i + 1]) 41 | for i in range(self.num_adversaries_layers)]) 42 | for _ in range(self.used_classes)]) 43 | self.sensitive_cls = nn.ModuleList([nn.Linear(self.num_adversaries[-1], 2) for _ in range(self.used_classes)]) 44 | 45 | def forward(self, inputs, labels=None): 46 | h_relu = inputs 47 | outputs, features = self.net(h_relu) 48 | if labels is None: 49 | # for inference 50 | return outputs, features 51 | h_relu = F.relu(features) 52 | 53 | # Adversary classification component. 54 | c_losses = [] 55 | h_relu = grad_reverse(h_relu) 56 | 57 | for j in range(self.used_classes): 58 | idx = labels[:, 0] == j 59 | c_h_relu = h_relu[idx] 60 | for hidden in self.adversaries[j]: 61 | c_h_relu = F.relu(hidden(c_h_relu)) 62 | c_cls = F.log_softmax(self.sensitive_cls[j](c_h_relu), dim=1) 63 | c_losses.append(c_cls) 64 | return outputs, c_losses 65 | 66 | def inference(self, inputs): 67 | outputs, features = self.net(inputs) 68 | return outputs, features 69 | 70 | class CFairNet3D(nn.Module): 71 | def __init__(self, backbone, num_classes, adversary_size = 128, pretrained = True): 72 | super(CFairNet3D, self).__init__() 73 | 74 | self.backbone = backbone 75 | self.num_classes = num_classes 76 | self.used_classes = 2 # can only handle binary attributes 77 | mod = import_module("models.basemodels_3d") 78 | cusModel = getattr(mod, self.backbone) 79 | self.net = cusModel(n_classes=self.num_classes, pretrained=pretrained) 80 | hidden_size = self.net.body.fc.in_features 81 | # Parameter of the conditional adversary classification layer. 82 | self.num_adversaries = [hidden_size] + [adversary_size] 83 | self.num_adversaries_layers = len([adversary_size]) 84 | # Conditional adversaries for sensitive attribute classification, one separate adversarial classifier for one class label. 85 | self.adversaries = nn.ModuleList([nn.ModuleList([nn.Linear(self.num_adversaries[i], self.num_adversaries[i + 1]) 86 | for i in range(self.num_adversaries_layers)]) 87 | for _ in range(self.used_classes)]) 88 | self.sensitive_cls = nn.ModuleList([nn.Linear(self.num_adversaries[-1], 2) for _ in range(self.used_classes)]) 89 | 90 | def forward(self, inputs, labels=None): 91 | h_relu = inputs 92 | outputs, features = self.net(h_relu) 93 | if labels is None: 94 | # for inference 95 | return outputs, features 96 | h_relu = F.relu(features) 97 | 98 | # Adversary classification component. 
99 | c_losses = [] 100 | h_relu = grad_reverse(h_relu) 101 | 102 | for j in range(self.used_classes): 103 | idx = labels[:, 0] == j 104 | c_h_relu = h_relu[idx] 105 | for hidden in self.adversaries[j]: 106 | c_h_relu = F.relu(hidden(c_h_relu)) 107 | c_cls = F.log_softmax(self.sensitive_cls[j](c_h_relu), dim=1) 108 | c_losses.append(c_cls) 109 | return outputs, c_losses 110 | 111 | def inference(self, inputs): 112 | outputs, features = self.net(inputs) 113 | return outputs, features 114 | 115 | 116 | class CFairNet_MLP(nn.Module): 117 | def __init__(self, backbone, num_classes, adversary_size = 128, device = 'cuda', in_features=1024, hidden_features=1024): 118 | super(CFairNet_MLP, self).__init__() 119 | 120 | self.backbone = backbone 121 | self.num_classes = num_classes 122 | self.used_classes = 2 # can only handle binary attributes 123 | 124 | mod = import_module("models.basemodels_mlp") 125 | cusModel = getattr(mod, self.backbone) 126 | self.net = cusModel(n_classes=self.num_classes, in_features= in_features, hidden_features=hidden_features) 127 | hidden_size = hidden_features 128 | # Parameter of the conditional adversary classification layer. 129 | self.num_adversaries = [hidden_size] + [adversary_size] 130 | self.num_adversaries_layers = len([adversary_size]) 131 | # Conditional adversaries for sensitive attribute classification, one separate adversarial classifier for 132 | # one class label. 133 | self.adversaries = nn.ModuleList([nn.ModuleList([nn.Linear(self.num_adversaries[i], self.num_adversaries[i + 1]) 134 | for i in range(self.num_adversaries_layers)]) 135 | for _ in range(self.used_classes)]) 136 | self.sensitive_cls = nn.ModuleList([nn.Linear(self.num_adversaries[-1], 2) for _ in range(self.used_classes)]) 137 | 138 | def forward(self, inputs, labels=None): 139 | h_relu = inputs 140 | outputs, features = self.net(h_relu) 141 | if labels is None: 142 | # for inference 143 | return outputs, features 144 | h_relu = F.relu(features) 145 | 146 | # Adversary classification component. 
147 | c_losses = [] 148 | h_relu = grad_reverse(h_relu) 149 | 150 | for j in range(self.used_classes): 151 | idx = labels[:, 0] == j 152 | c_h_relu = h_relu[idx] 153 | for hidden in self.adversaries[j]: 154 | c_h_relu = F.relu(hidden(c_h_relu)) 155 | c_cls = F.log_softmax(self.sensitive_cls[j](c_h_relu), dim=1) 156 | c_losses.append(c_cls) 157 | return outputs, c_losses 158 | 159 | def inference(self, inputs): 160 | outputs, features = self.net(inputs) 161 | return outputs, features -------------------------------------------------------------------------------- /notebooks/HAM10000-example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "bbd4178f", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import h5py\n", 11 | "import pandas as pd\n", 12 | "import numpy as np\n", 13 | "import cv2\n", 14 | "import os\n", 15 | "import matplotlib.pyplot as plt\n", 16 | "from collections import Counter\n", 17 | "from sklearn.model_selection import train_test_split\n", 18 | "import pickle\n", 19 | "import time" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "id": "5cf5d2b1", 25 | "metadata": {}, 26 | "source": [ 27 | "## Preprocess metadata" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "id": "e1173431", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "# read metadata\n", 38 | "path = 'your_path/fairness_data/HAM10000/'\n", 39 | "\n", 40 | "demo_data = pd.read_csv(path + 'HAM10000_metadata.csv')\n", 41 | "demo_data" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "95739fb4", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "Counter(demo_data['dataset'])" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 22, 57 | "id": "a376a0aa", 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "# add image path to the metadata\n", 62 | "pathlist = demo_data['image_id'].values.tolist()\n", 63 | "paths = ['HAM10000_images/' + i + '.jpg' for i in pathlist]\n", 64 | "demo_data['Path'] = paths" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "id": "c80a746b", 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "# remove rows where age/sex is null\n", 75 | "demo_data = demo_data[~demo_data['age'].isnull()]\n", 76 | "demo_data = demo_data[~demo_data['sex'].isnull()]\n", 77 | "demo_data" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "id": "e08809a8", 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "# unify the value of sensitive attributes\n", 88 | "sex = demo_data['sex'].values\n", 89 | "sex[sex == 'male'] = 'M'\n", 90 | "sex[sex == 'female'] = 'F'\n", 91 | "demo_data['Sex'] = sex\n", 92 | "demo_data" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "id": "39807da2", 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "# split subjects into different age groups\n", 103 | "demo_data['Age_multi'] = demo_data['age'].values.astype('int')\n", 104 | "demo_data['Age_multi'] = np.where(demo_data['Age_multi'].between(-1,19), 0, demo_data['Age_multi'])\n", 105 | "demo_data['Age_multi'] = np.where(demo_data['Age_multi'].between(20,39), 1, demo_data['Age_multi'])\n", 106 | "demo_data['Age_multi'] = np.where(demo_data['Age_multi'].between(40,59), 2, demo_data['Age_multi'])\n", 107 | "demo_data['Age_multi'] =
np.where(demo_data['Age_multi'].between(60,79), 3, demo_data['Age_multi'])\n",
108 |     "demo_data['Age_multi'] = np.where(demo_data['Age_multi']>=80, 4, demo_data['Age_multi'])\n",
109 |     "\n",
110 |     "demo_data['Age_binary'] = demo_data['age'].values.astype('int')\n",
111 |     "demo_data['Age_binary'] = np.where(demo_data['Age_binary'].between(-1, 59), 0, demo_data['Age_binary'])\n",
112 |     "demo_data['Age_binary'] = np.where(demo_data['Age_binary']>= 60, 1, demo_data['Age_binary'])\n",
113 |     "demo_data"
114 |    ]
115 |   },
116 |   {
117 |    "cell_type": "code",
118 |    "execution_count": null,
119 |    "id": "c6135d73",
120 |    "metadata": {},
121 |    "outputs": [],
122 |    "source": [
123 |     "# convert to binary labels\n",
124 |     "# benign: bcc, bkl, dermatofibroma, nv, vasc\n",
125 |     "# malignant: akiec, mel\n",
126 |     "\n",
127 |     "labels = demo_data['dx'].values.copy()\n",
128 |     "labels[labels == 'akiec'] = '1'\n",
129 |     "labels[labels == 'mel'] = '1'\n",
130 |     "labels[labels != '1'] = '0'\n",
131 |     "\n",
132 |     "labels = labels.astype('int')\n",
133 |     "\n",
134 |     "demo_data['binaryLabel'] = labels\n",
135 |     "demo_data"
136 |    ]
137 |   },
138 |   {
139 |    "cell_type": "markdown",
140 |    "id": "b528c300",
141 |    "metadata": {},
142 |    "source": [
143 |     "## Split train/val/test"
144 |    ]
145 |   },
146 |   {
147 |    "cell_type": "code",
148 |    "execution_count": 27,
149 |    "id": "e4f757b9",
150 |    "metadata": {},
151 |    "outputs": [],
152 |    "source": [
153 |     "def split_811(all_meta, patient_ids):\n",
154 |     "    sub_train, sub_val_test = train_test_split(patient_ids, test_size=0.2, random_state=0)\n",
155 |     "    sub_val, sub_test = train_test_split(sub_val_test, test_size=0.5, random_state=0)\n",
156 |     "    train_meta = all_meta[all_meta.lesion_id.isin(sub_train)]\n",
157 |     "    val_meta = all_meta[all_meta.lesion_id.isin(sub_val)]\n",
158 |     "    test_meta = all_meta[all_meta.lesion_id.isin(sub_test)]\n",
159 |     "    return train_meta, val_meta, test_meta\n",
160 |     "\n",
161 |     "sub_train, sub_val, sub_test = split_811(demo_data, np.unique(demo_data['lesion_id']))"
162 |    ]
163 |   },
164 |   {
165 |    "cell_type": "code",
166 |    "execution_count": 33,
167 |    "id": "b91657ee",
168 |    "metadata": {},
169 |    "outputs": [],
170 |    "source": [
171 |     "sub_train.to_csv('your_path/fariness_data/HAM10000/split/new_train.csv')\n",
172 |     "sub_val.to_csv('your_path/fariness_data/HAM10000/split/new_val.csv')\n",
173 |     "sub_test.to_csv('your_path/fariness_data/HAM10000/split/new_test.csv')"
174 |    ]
175 |   },
176 |   {
177 |    "cell_type": "code",
178 |    "execution_count": null,
179 |    "id": "c1d7d453",
180 |    "metadata": {},
181 |    "outputs": [],
182 |    "source": [
183 |     "# you can have a look at some examples here\n",
184 |     "img = cv2.imread('your_path/fariness_data/HAM10000/HAM10000_images/ISIC_0027419.jpg')\n",
185 |     "print(img.shape)\n",
186 |     "plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))"
187 |    ]
188 |   },
189 |   {
190 |    "cell_type": "markdown",
191 |    "id": "ab951da8",
192 |    "metadata": {},
193 |    "source": [
194 |     "## Save images into pickle files\n",
195 |     "This is optional, but if you are training many models, this step can save a lot of time by reducing the data IO."
196 |    ]
197 |   },
198 |   {
199 |    "cell_type": "code",
200 |    "execution_count": null,
201 |    "id": "5d667132",
202 |    "metadata": {},
203 |    "outputs": [],
204 |    "source": [
205 |     "train_meta = pd.read_csv('your_path/fariness_data/HAM10000/split/new_train.csv')\n",
206 |     "\n",
207 |     "root = 'your_path/fariness_data/HAM10000/'  # images are read from the dataset root; pickles are written to pkls/\n",
208 |     "images = []\n",
209 |     "start = time.time()\n",
210 |     "for i in range(len(train_meta)):\n",
211 |     "\n",
212 |     "    img = cv2.imread(root + train_meta.iloc[i]['Path'])\n",
213 |     "    # resize to the input size in advance to save time during training\n",
214 |     "    img = cv2.resize(img, (256, 256))\n",
215 |     "    images.append(img)\n",
216 |     "    \n",
217 |     "end = time.time()\n",
218 |     "print('took {:.1f}s'.format(end - start))\n",
219 |     "with open(root + 'pkls/train_images.pkl', 'wb') as f:\n",
220 |     "    pickle.dump(images, f)"
221 |    ]
222 |   }
223 |  ],
224 |  "metadata": {
225 |   "kernelspec": {
226 |    "display_name": "torch11",
227 |    "language": "python",
228 |    "name": "torch11"
229 |   },
230 |   "language_info": {
231 |    "codemirror_mode": {
232 |     "name": "ipython",
233 |     "version": 3
234 |    },
235 |    "file_extension": ".py",
236 |    "mimetype": "text/x-python",
237 |    "name": "python",
238 |    "nbconvert_exporter": "python",
239 |    "pygments_lexer": "ipython3",
240 |    "version": "3.8.12"
241 |   }
242 |  },
243 |  "nbformat": 4,
244 |  "nbformat_minor": 5
245 | }
246 | 
--------------------------------------------------------------------------------
/models/SWA/SWA.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from models import basemodels
6 | from utils import basics
7 | import pandas as pd
8 | from utils.evaluation import calculate_auc, calculate_metrics, calculate_FPR_FNR
9 | from models.basenet import BaseNet
10 | import torchvision
11 | 
12 | from importlib import import_module
13 | from torch.optim.swa_utils import AveragedModel, SWALR
14 | from torch.optim.lr_scheduler import CosineAnnealingLR
15 | 
16 | 
17 | class SWA(BaseNet):
18 |     def __init__(self, opt, wandb):
19 |         super(SWA, self).__init__(opt, wandb)
20 |         self.set_network(opt)
21 |         self.swa_start = opt['swa_start']
22 |         self.swa_lr = opt['swa_lr']
23 |         self.annealing_epochs = opt['swa_annealing_epochs']
24 | 
25 |         self.set_optimizer(opt)
26 | 
27 | 
28 |     def set_network(self, opt):
29 |         """Define the network"""
30 | 
31 |         if self.is_3d:
32 |             mod = import_module("models.basemodels_3d")
33 |             cusModel = getattr(mod, self.backbone)
34 |             self.network = cusModel(n_classes=self.output_dim, pretrained = self.pretrained).to(self.device)
35 |         elif self.is_tabular:
36 |             mod = import_module("models.basemodels_mlp")
37 |             cusModel = getattr(mod, self.backbone)
38 |             self.network = cusModel(n_classes=self.output_dim, in_features= self.in_features, hidden_features = 1024).to(self.device)
39 |         else:
40 |             mod = import_module("models.basemodels")
41 |             cusModel = getattr(mod, self.backbone)
42 |             self.network = cusModel(n_classes=self.output_dim, pretrained=self.pretrained).to(self.device)
43 | 
44 |         self.swa_model = AveragedModel(self.network).to(self.device)
45 | 
46 |     def set_optimizer(self, opt):
47 |         optimizer_setting = opt['optimizer_setting']
48 |         self.optimizer = optimizer_setting['optimizer'](
49 |             params=filter(lambda p: p.requires_grad, self.network.parameters()),
50 |             lr=optimizer_setting['lr'],
51 |             weight_decay=optimizer_setting['weight_decay']
52 |         )
53 | 
54 |         self.scheduler = CosineAnnealingLR(self.optimizer, T_max=100)
55 |         self.swa_scheduler = SWALR(self.optimizer, anneal_epochs = self.annealing_epochs, swa_lr=self.swa_lr)
56 | 
57 |     def state_dict(self):
58 |         state_dict = {
59 |             'model': self.swa_model.state_dict(),
60 |             'optimizer': self.optimizer.state_dict(),
61 |             'epoch': self.epoch
62 |         }
63 |         return state_dict
64 | 
65 |     def _train(self, loader):
66 |         """Train the model for one epoch"""
67 | 
68 |         self.network.train()
69 | 
70 |         train_loss = 0
71 |         auc = 0.
72 |         no_iter = 0
73 |         for i, (images, targets, sensitive_attr, index) in enumerate(loader):
74 |             images, targets, sensitive_attr = images.to(self.device), targets.to(self.device), sensitive_attr.to(self.device)
75 | 
76 |             self.optimizer.zero_grad()
77 |             outputs, _ = self.network(images)
78 | 
79 |             loss = self._criterion(outputs, targets)
80 |             loss.backward()
81 |             self.optimizer.step()
82 | 
83 |             auc += calculate_auc(F.sigmoid(outputs).cpu().data.numpy(), targets.cpu().data.numpy())
84 | 
85 |             train_loss += loss.item()
86 |             no_iter += 1
87 | 
88 |             if self.log_freq and (i % self.log_freq == 0):
89 |                 self.wandb.log({'Training loss': train_loss / (i+1), 'Training AUC': auc / (i+1)})
90 | 
91 |         auc = 100 * auc / no_iter
92 |         train_loss /= no_iter
93 | 
94 | 
95 |         print('Training epoch {}: AUC:{}'.format(self.epoch, auc))
96 |         print('Training epoch {}: loss:{}'.format(self.epoch, train_loss))
97 | 
98 |         self.epoch += 1
99 | 
100 |         if self.epoch >= self.swa_start:
101 |             #if self.epoch == self.swa_start:
102 |             #    for g in self.optimizer.param_groups:
103 |             #        g['lr'] = 0.05
104 | 
105 |             self.swa_model.update_parameters(self.network)
106 |             self.swa_scheduler.step()
107 |         else:
108 |             self.scheduler.step()
109 | 
110 | 
111 | 
112 |     def test(self):
113 |         if self.test_mode:
114 |             if not self.cross_testing:
115 |                 if self.hyper_search is True:
116 |                     state_dict = torch.load(os.path.join(self.resume_path, self.hash + '_' + str(self.seed) + '_best.pth'))
117 |                     print('Testing, loaded model from ', os.path.join(self.resume_path, self.hash + '_' + str(self.seed) + '_best.pth'))
118 |                 else:
119 |                     state_dict = torch.load(os.path.join(self.resume_path, str(self.seed) +'_best.pth'))
120 |                     print('Testing, loaded model from ', os.path.join(self.resume_path, str(self.seed) +'_best.pth'))
121 |             else:
122 |                 state_dict = torch.load(self.load_path)
123 |                 print('Testing, loaded model from ', self.load_path)
124 |             self.swa_model.load_state_dict(state_dict['model'])  # the checkpoint stores the SWA-averaged weights, which _test uses
125 |         else:
126 |             torch.optim.swa_utils.update_bn(self.train_loader, self.swa_model, device = self.device)
127 |             if self.hyper_search is True:
128 |                 basics.save_state_dict(self.state_dict(), os.path.join(self.save_path, self.hash + '_' + str(self.seed) + '_best.pth'))
129 |                 print('saving best model in ', os.path.join(self.save_path, self.hash + '_' + str(self.seed) + '_best.pth'))
130 |             else:
131 |                 basics.save_state_dict(self.state_dict(), os.path.join(self.save_path, str(self.seed) + '_best.pth'))
132 |                 print('saving best model in ', os.path.join(self.save_path, str(self.seed) + '_best.pth'))
133 |             self.network = self.swa_model.to(self.device)
134 | 
135 |         log_dict = self._test(self.test_loader)
136 | 
137 |         print('Finish testing')
138 |         print(log_dict)
139 |         return pd.DataFrame(log_dict, index=[0])
140 | 
141 | 
142 |     def _test(self, loader):
143 |         self.network.eval()
144 |         tol_output, tol_target, tol_sensitive, tol_index = [], [], [], []
145 | 
146 |         with torch.no_grad():
147 |             for i, (images, targets, sensitive_attr, index) in enumerate(loader):
148 |                 images, targets, sensitive_attr = images.to(self.device), targets.to(self.device), sensitive_attr.to(
149 | 
self.device) 150 | outputs, _ = self.swa_model(images) 151 | 152 | tol_output += F.sigmoid(outputs).flatten().cpu().data.numpy().tolist() 153 | tol_target += targets.cpu().data.numpy().tolist() 154 | tol_sensitive += sensitive_attr.cpu().data.numpy().tolist() 155 | tol_index += index.numpy().tolist() 156 | 157 | 158 | log_dict, t_predictions, pred_df = calculate_metrics(tol_output, tol_target, tol_sensitive, tol_index, self.sens_classes) 159 | overall_FPR, overall_FNR, FPRs, FNRs = calculate_FPR_FNR(pred_df, self.test_meta, self.opt) 160 | log_dict['Overall FPR'] = overall_FPR 161 | log_dict['Overall FNR'] = overall_FNR 162 | pred_df.to_csv(os.path.join(self.save_path, 'pred.csv'), index = False) 163 | #basics.save_results(t_predictions, tol_target, s_prediction, tol_sensitive, self.save_path) 164 | for i, FPR in enumerate(FPRs): 165 | log_dict['FPR-group_' + str(i)] = FPR 166 | for i, FNR in enumerate(FNRs): 167 | log_dict['FNR-group_' + str(i)] = FNR 168 | 169 | log_dict = basics.add_dict_prefix(log_dict, 'Test ') 170 | #log_dict.update({'s_acc': round(sens_acc, 4),}) 171 | 172 | return log_dict -------------------------------------------------------------------------------- /datasets/BaseDataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class BaseDataset(torch.utils.data.Dataset): 6 | def __init__(self, dataframe, path_to_images, sens_name, sens_classes, transform): 7 | super(BaseDataset, self).__init__() 8 | 9 | self.dataframe = dataframe 10 | self.dataset_size = self.dataframe.shape[0] 11 | self.transform = transform 12 | self.path_to_images = path_to_images 13 | self.sens_name = sens_name 14 | self.sens_classes = sens_classes 15 | 16 | self.A = None 17 | self.Y = None 18 | self.AY_proportion = None 19 | 20 | def get_AY_proportions(self): 21 | if self.AY_proportion: 22 | return self.AY_proportion 23 | 24 | A_num_class = 2 25 | Y_num_class = 2 26 | A_label = self.A 27 | Y_label = self.Y 28 | 29 | A = self.A.tolist() 30 | Y = self.Y.tolist() 31 | ttl = len(A) 32 | 33 | len_A0Y0 = len([ay for ay in zip(A, Y) if ay == (0, 0)]) 34 | len_A0Y1 = len([ay for ay in zip(A, Y) if ay == (0, 1)]) 35 | len_A1Y0 = len([ay for ay in zip(A, Y) if ay == (1, 0)]) 36 | len_A1Y1 = len([ay for ay in zip(A, Y) if ay == (1, 1)]) 37 | 38 | assert ( 39 | len_A0Y0 + len_A0Y1 + len_A1Y0 + len_A1Y1 40 | ) == ttl, "Problem computing train set AY proportion." 
41 |         A0Y0 = len_A0Y0 / ttl
42 |         A0Y1 = len_A0Y1 / ttl
43 |         A1Y0 = len_A1Y0 / ttl
44 |         A1Y1 = len_A1Y1 / ttl
45 | 
46 |         self.AY_proportion = [[A0Y0, A0Y1], [A1Y0, A1Y1]]
47 | 
48 |         return self.AY_proportion
49 | 
50 |     def get_A_proportions(self):
51 |         AY = self.get_AY_proportions()
52 |         ret = [AY[0][0] + AY[0][1], AY[1][0] + AY[1][1]]
53 |         np.testing.assert_almost_equal(np.sum(ret), 1.0)
54 |         return ret
55 | 
56 |     def get_Y_proportions(self):
57 |         AY = self.get_AY_proportions()
58 |         ret = [AY[0][0] + AY[1][0], AY[0][1] + AY[1][1]]
59 |         np.testing.assert_almost_equal(np.sum(ret), 1.0)
60 |         return ret
61 | 
62 |     def set_A(self, sens_name):
63 |         if sens_name == 'Sex':
64 |             A = np.asarray(self.dataframe['Sex'].values != 'M').astype('float')
65 |         elif sens_name == 'Age':
66 |             A = np.asarray(self.dataframe['Age_binary'].values.astype('int') == 1).astype('float')
67 |         elif sens_name == 'Race':
68 |             A = np.asarray(self.dataframe['Race'].values == 'White').astype('float')
69 |         elif sens_name == 'skin_type':
70 |             A = np.asarray(self.dataframe['skin_binary'].values != 0).astype('float')
71 |         elif sens_name == 'Insurance':
72 |             A = np.asarray(self.dataframe['Insurance_binary'].values != 0).astype('float')
73 |         else:
74 |             raise ValueError("Does not contain {}".format(sens_name))
75 |         return A
76 | 
77 |     def get_weights(self, resample_which):
78 |         sens_attr, group_num = self.group_counts(resample_which)
79 |         group_weights = [1/x.item() for x in group_num]  # inverse-frequency weight per group
80 |         sample_weights = [group_weights[int(i)] for i in sens_attr]
81 |         return sample_weights
82 | 
83 |     def group_counts(self, resample_which = 'group'):
84 |         if resample_which == 'group' or resample_which == 'balanced':
85 |             if self.sens_name == 'Sex':
86 |                 mapping = {'M': 0, 'F': 1}
87 |                 groups = self.dataframe['Sex'].values
88 |                 group_array = [*map(mapping.get, groups)]
89 | 
90 |             elif self.sens_name == 'Age':
91 |                 if self.sens_classes == 2:
92 |                     groups = self.dataframe['Age_binary'].values
93 |                 elif self.sens_classes == 5:
94 |                     groups = self.dataframe['Age_multi'].values
95 |                 elif self.sens_classes == 4:
96 |                     groups = self.dataframe['Age_multi4'].values.astype('int')
97 |                 group_array = groups.tolist()
98 | 
99 |             elif self.sens_name == 'Race':
100 |                 mapping = {'White': 0, 'non-White': 1}
101 |                 groups = self.dataframe['Race'].values
102 |                 group_array = [*map(mapping.get, groups)]
103 |             elif self.sens_name == 'skin_type':
104 |                 if self.sens_classes == 2:
105 |                     groups = self.dataframe['skin_binary'].values
106 |                 elif self.sens_classes == 6:
107 |                     groups = self.dataframe['skin_type'].values
108 |                 group_array = groups.tolist()
109 |             elif self.sens_name == 'Insurance':
110 |                 if self.sens_classes == 2:
111 |                     groups = self.dataframe['Insurance_binary'].values
112 |                 elif self.sens_classes == 5:
113 |                     groups = self.dataframe['Insurance'].values
114 |                 group_array = groups.tolist()
115 |             else:
116 |                 raise ValueError("sensitive attribute is not defined in BaseDataset")
117 | 
118 |             if resample_which == 'balanced':
119 |                 #get class
120 |                 labels = self.Y.tolist()
121 |                 num_labels = len(set(labels))
122 |                 num_groups = len(set(group_array))
123 | 
124 |                 group_array = (np.asarray(group_array) * num_labels + np.asarray(labels)).tolist()
125 | 
126 |         elif resample_which == 'class':
127 |             group_array = self.Y.tolist()
128 |             num_labels = len(set(group_array))
129 | 
130 |         self._group_array = torch.LongTensor(group_array)
131 |         if resample_which == 'group':
132 |             self._group_counts = (torch.arange(self.sens_classes).unsqueeze(1)==self._group_array).sum(1).float()
133 |         elif resample_which == 'balanced':
134 |             self._group_counts = (torch.arange(num_labels * num_groups).unsqueeze(1)==self._group_array).sum(1).float()
135 |         elif resample_which == 'class':
136 |             self._group_counts = (torch.arange(num_labels).unsqueeze(1)==self._group_array).sum(1).float()
137 |         return group_array, self._group_counts
138 | 
139 |     def __len__(self):
140 |         return self.dataset_size
141 | 
142 |     def get_labels(self):
143 |         # for sensitive attribute imbalance
144 |         if self.sens_classes == 2:
145 |             return self.A
146 |         elif self.sens_classes == 5:
147 |             return self.dataframe['Age_multi'].values.tolist()
148 |         elif self.sens_classes == 4:
149 |             return self.dataframe['Age_multi4'].values.tolist()
150 | 
151 |     def get_sensitive(self, sens_name, sens_classes, item):
152 |         if sens_name == 'Sex':
153 |             if item['Sex'] == 'M':
154 |                 sensitive = 0
155 |             else:
156 |                 sensitive = 1
157 |         elif sens_name == 'Age':
158 |             if sens_classes == 2:
159 |                 sensitive = int(item['Age_binary'])
160 |             elif sens_classes == 5:
161 |                 sensitive = int(item['Age_multi'])
162 |             elif sens_classes == 4:
163 |                 sensitive = int(item['Age_multi4'])
164 |         elif sens_name == 'Race':
165 |             if item['Race'] == 'White':
166 |                 sensitive = 0
167 |             else:
168 |                 sensitive = 1
169 |         elif sens_name == 'skin_type':
170 |             if sens_classes == 2:
171 |                 sensitive = int(item['skin_binary'])
172 |             else:
173 |                 sensitive = int(item['skin_type'])
174 |         elif sens_name == 'Insurance':
175 |             if sens_classes == 2:
176 |                 sensitive = int(item['Insurance_binary'])
177 |             elif sens_classes == 5:
178 |                 sensitive = int(item['Insurance'])
179 |         else:
180 |             raise ValueError('Please check the sensitive attributes.')
181 |         return sensitive
--------------------------------------------------------------------------------
/models/LAFTR/LAFTR.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from models.LAFTR.model import LaftrNet, LaftrNet3D, LaftrNet_MLP
4 | from utils import basics
5 | from utils.evaluation import calculate_auc, calculate_metrics, calculate_FPR_FNR
6 | from models.utils import standard_val, standard_test
7 | from models.basenet import BaseNet
8 | 
9 | 
10 | class LAFTR(BaseNet):
11 |     def __init__(self, opt, wandb):
12 |         super(LAFTR, self).__init__(opt, wandb)
13 | 
14 |         self.model_var = opt['model_var']
15 |         self.test_classes = opt['sens_classes']
16 |         self.sens_classes = 2
17 | 
18 |         self.set_network(opt)
19 |         self.set_optimizer(opt)
20 | 
21 |         self.aud_steps = opt['aud_steps']
22 |         self.class_coeff = opt['class_coeff']
23 |         self.fair_coeff = opt['fair_coeff']
24 | 
25 |     def set_network(self, opt):
26 |         """Define the network"""
27 |         if self.is_3d:
28 |             self.network = LaftrNet3D(backbone = self.backbone, num_classes=self.num_classes, adversary_size=128, pretrained = self.pretrained, device=self.device, model_var=self.model_var).to(self.device)
29 |         elif self.is_tabular:
30 |             self.network = LaftrNet_MLP(backbone = self.backbone, num_classes=self.num_classes, adversary_size=128, device=self.device, model_var=self.model_var, in_features=self.in_features, hidden_features=1024).to(self.device)
31 |         else:
32 |             self.network = LaftrNet(backbone = self.backbone, num_classes=self.num_classes, adversary_size=128, pretrained = self.pretrained, device=self.device, model_var=self.model_var).to(self.device)
33 | 
34 |     def set_optimizer(self, opt):
35 |         optimizer_setting = opt['optimizer_setting']
36 |         self.optimizer = optimizer_setting['optimizer'](
37 |             params=self.network.net.parameters(),
38 |             lr=optimizer_setting['lr'],
39 |             weight_decay=optimizer_setting['weight_decay']
40 |         )
41 |         self.optimizer_disc = optimizer_setting['optimizer'](
42 |             params=filter(lambda p: p.requires_grad, self.network.discriminator.parameters()),
43 |             lr=optimizer_setting['lr'],
44 |             weight_decay=optimizer_setting['weight_decay']
45 |         )
46 | 
47 |     def get_AYweights(self, data):
48 |         A_weights, Y_weights, AY_weights = (
49 |             data.get_A_proportions(),
50 |             data.get_Y_proportions(),
51 |             data.get_AY_proportions(),
52 |         )
53 |         return A_weights, Y_weights, AY_weights
54 | 
55 |     def l1_loss(self, y, y_logits):
56 |         """Returns l1 loss"""
57 |         y_hat = torch.sigmoid(y_logits)
58 |         return torch.squeeze(torch.abs(y - y_hat))
59 | 
60 |     def get_weighted_aud_loss(self, L, X, Y, A, A_wts, Y_wts, AY_wts):
61 |         """Returns weighted discriminator loss"""
62 |         Y = Y[:, 0]
63 |         if self.model_var == "laftr-dp":
64 |             A0_wt = A_wts[0]
65 |             A1_wt = A_wts[1]
66 |             wts = A0_wt * (1 - A) + A1_wt * A
67 |             wtd_L = L * torch.squeeze(wts)
68 |         elif (
69 |             self.model_var == "laftr-eqodd"
70 |             or self.model_var == "laftr-eqopp0"
71 |             or self.model_var == "laftr-eqopp1"
72 |         ):
73 |             A0_Y0_wt = AY_wts[0][0]
74 |             A0_Y1_wt = AY_wts[0][1]
75 |             A1_Y0_wt = AY_wts[1][0]
76 |             A1_Y1_wt = AY_wts[1][1]
77 | 
78 |             if self.model_var == "laftr-eqodd":
79 |                 wts = (
80 |                     A0_Y0_wt * (1 - A) * (1 - Y)
81 |                     + A0_Y1_wt * (1 - A) * (Y)
82 |                     + A1_Y0_wt * (A) * (1 - Y)
83 |                     + A1_Y1_wt * (A) * (Y)
84 |                 )
85 |             elif self.model_var == "laftr-eqopp0":
86 |                 wts = A0_Y0_wt * (1 - A) * (1 - Y) + A1_Y0_wt * (A) * (1 - Y)
87 |             elif self.model_var == "laftr-eqopp1":
88 |                 wts = A0_Y1_wt * (1 - A) * (Y) + A1_Y1_wt * (A) * (Y)
89 | 
90 |             wtd_L = L * torch.squeeze(wts)
91 |         else:
92 |             raise ValueError("Unknown model_var: {}".format(self.model_var))
93 | 
94 | 
95 |         return wtd_L
96 | 
97 |     def _train(self, loader):
98 |         """Train the model for one epoch"""
99 |         A_weights, Y_weights, AY_weights = self.get_AYweights(self.train_data)
100 | 
101 |         self.network.train()
102 | 
103 |         running_loss = 0.
104 |         running_adv_loss = 0.
105 |         auc = 0.
106 |         no_iter = 0
107 |         for i, (images, targets, sensitive_attr, index) in enumerate(loader):
108 |             images, targets, sensitive_attr = images.to(self.device), targets.to(self.device), sensitive_attr.to(
109 |                 self.device)
110 |             self.optimizer.zero_grad()
111 |             Y_logits, A_logits = self.network.forward(images, targets)
112 | 
113 |             class_loss = self.class_coeff * self._criterion(Y_logits, targets)
114 |             aud_loss = -self.fair_coeff * self.l1_loss(sensitive_attr, A_logits)
115 |             weighted_aud_loss = self.get_weighted_aud_loss(aud_loss, images, targets, sensitive_attr, A_weights,
116 |                                                            Y_weights, AY_weights)
117 |             weighted_aud_loss = torch.mean(weighted_aud_loss)
118 |             loss = class_loss + weighted_aud_loss
119 | 
120 |             torch.autograd.set_detect_anomaly(True)
121 | 
122 |             self.optimizer.zero_grad()
123 |             self.optimizer_disc.zero_grad()
124 | 
125 |             loss.backward(retain_graph=True)
126 |             torch.nn.utils.clip_grad_norm_(self.network.net.parameters(), 5.0)
127 | 
128 |             for step in range(self.aud_steps):  # renamed from i to avoid shadowing the loader index used for logging below
129 |                 if step != self.aud_steps - 1:
130 |                     loss.backward(retain_graph=True)
131 |                 else:
132 |                     loss.backward()
133 |                 torch.nn.utils.clip_grad_norm_(self.network.discriminator.parameters(), 5.0)
134 |                 self.optimizer_disc.step()
135 |             self.optimizer.step()
136 | 
137 |             running_loss += loss.item()
138 |             running_adv_loss += weighted_aud_loss.item()
139 | 
140 |             auc += calculate_auc(F.sigmoid(Y_logits).cpu().data.numpy(),
141 |                                  targets.cpu().data.numpy())
142 | 
143 |             no_iter += 1
144 | 
145 |             if self.log_freq and (i % self.log_freq == 0):
146 |                 self.wandb.log({'Training loss': running_loss / (i+1), 'Training AUC': auc / (i+1)})
147 | 
148 | 
149 |         running_loss /= no_iter
150 |         running_adv_loss /= no_iter
151 | 
152 |         auc = auc / no_iter
153 |         print('Training epoch {}: AUC:{}'.format(self.epoch, auc))
154 |         print('Training epoch {}: cls loss:{}, adv loss:{}'.format(
155 |             self.epoch, running_loss, running_adv_loss))
156 | 
157 |         self.epoch += 1
158 | 
159 |     def _val(self, loader):
160 | 
161 |         self.network.eval()
162 |         auc, val_loss, log_dict, pred_df = standard_val(self.opt, self.network, loader, self._criterion, self.test_classes, self.wandb)
163 | 
164 |         print('Validation epoch {}: validation loss:{}, AUC:{}'.format(
165 |             self.epoch, val_loss, auc))
166 |         return val_loss, auc, log_dict, pred_df
167 | 
168 |     def _test(self, loader):
169 | 
170 |         self.network.eval()
171 |         tol_output, tol_target, tol_sensitive, tol_index = standard_test(self.opt, self.network, loader, self._criterion, self.wandb)
172 | 
173 |         log_dict, t_predictions, pred_df = calculate_metrics(tol_output, tol_target, tol_sensitive, tol_index, self.test_classes)
174 |         overall_FPR, overall_FNR, FPRs, FNRs = calculate_FPR_FNR(pred_df, self.test_meta, self.opt)
175 |         log_dict['Overall FPR'] = overall_FPR
176 |         log_dict['Overall FNR'] = overall_FNR
177 | 
178 |         for i, FPR in enumerate(FPRs):
179 |             log_dict['FPR-group_' + str(i)] = FPR
180 |         for i, FNR in enumerate(FNRs):
181 |             log_dict['FNR-group_' + str(i)] = FNR
182 | 
183 |         log_dict = basics.add_dict_prefix(log_dict, 'Test ')
184 |         return log_dict
185 | 
--------------------------------------------------------------------------------
/docs/customization.md:
--------------------------------------------------------------------------------
1 | # Not Enough? Customize your own experiments.
2 | In our benchmarking framework, you can easily add different datasets, network architectures, debiasing algorithms, and evaluation metrics for your own experiments.
3 | 
4 | ## Customize Dataset
5 | You can easily add any dataset you need following the three steps below.
6 | 
7 | ### STEP 1. Configure dataset
8 | Preprocess the dataset and image files in a way similar to `notebooks/HAM10000-example.ipynb`; a sketch of the kind of metadata this step produces is given below.
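Concretely, the result of this step is a metadata dataframe with one row per image, holding a relative image path, the task label, and harmonized sensitive-attribute columns. A minimal sketch of what the HAM10000 notebook ends up with; the column names mirror that notebook, and the output path `your_path/DatasetX/...` is a placeholder:

```python
import pandas as pd

# one row per image: relative image path, binary task label,
# and the harmonized sensitive attributes used by the dataset classes
metadata = pd.DataFrame({
    'Path': ['HAM10000_images/ISIC_0027419.jpg'],
    'binaryLabel': [0],   # 0: benign, 1: malignant
    'Sex': ['M'],         # harmonized to 'M'/'F'
    'Age_multi': [2],     # age groups: 0-19, 20-39, 40-59, 60-79, 80+
    'Age_binary': [0],    # 0: under 60, 1: 60 and over
})
metadata.to_csv('your_path/DatasetX/split/new_train.csv')
```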
9 | 
10 | ### STEP 2. Implement the Dataset Class
11 | We write the dataset class by inheriting from the regular PyTorch Dataset ([official tutorial](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html)). We provide a base dataset class in `datasets/BaseDataset.py`. In the `datasets` folder, create a new script named after your dataset (e.g. `DatasetX.py`), and name the new dataset class with the same name as the script (i.e. `class DatasetX`). An example script is given below. The input paths need to be specified in `configs/datasets.json`. The comments in the code block below may be helpful.
12 | 
13 | ```python
14 | import torch
15 | import pickle
16 | import numpy as np
17 | from PIL import Image
18 | from datasets.BaseDataset import BaseDataset
19 | 
20 | class DatasetX(BaseDataset):
21 |     def __init__(self, dataframe, path_to_pickles, sens_name, sens_classes, transform):
22 |         super(DatasetX, self).__init__(dataframe, path_to_pickles, sens_name, sens_classes, transform)
23 | 
24 |         """
25 |         Dataset class for customized dataset
26 | 
27 |         Arguments:
28 |             dataframe: the metadata in pandas dataframe format.
29 |             path_to_pickles: path to the pickle file containing images.
30 |             sens_name: which sensitive attribute to use, e.g., Sex.
31 |             sens_classes: number of sensitive classes.
32 |             transform: the transform applied to each image.
33 | 
34 |         Returns:
35 |             image, label, sensitive attribute, and index.
36 |         """
37 | 
38 |         # load the pickle file containing all images
39 |         with open(path_to_pickles, 'rb') as f:
40 |             self.tol_images = pickle.load(f)
41 | 
42 |         self.A = self.set_A(sens_name)
43 |         self.Y = (np.asarray(self.dataframe['binaryLabel'].values) > 0).astype('float')
44 |         self.AY_proportion = None
45 | 
46 |     def __getitem__(self, idx):
47 |         # get the item based on the index
48 |         item = self.dataframe.iloc[idx]
49 | 
50 |         # get the image from the pickle file
51 |         img = Image.fromarray(self.tol_images[idx])
52 |         # uncomment the line below to load the image directly if you don't want to use a pickle file.
53 |         # Note, the `path_to_images` variable needs to be modified accordingly.
54 |         # img = Image.open(path_to_images[idx])
55 | 
56 |         # apply image transform/augmentation
57 |         img = self.transform(img)
58 | 
59 |         label = torch.FloatTensor([int(item['binaryLabel'])])
60 | 
61 |         # get sensitive attributes in numerical values
62 |         sensitive = self.get_sensitive(self.sens_name, self.sens_classes, item)
63 | 
64 |         return img, label, sensitive, idx
65 | ```
66 | 
67 | You can also refer to the other dataset classes we wrote in the `datasets` folder.
68 | 
69 | ### STEP 3. Register the dataset
70 | - Add the dataset name in the choices of the `dataset_name` argument in `parse_args.py`.
71 | - Import the dataset class in `datasets/__init__.py`.
72 | - Make sure the paths to the dataset are written to `configs/datasets.json` (a sketch of the registration is given below).
73 | 
74 | Now, you can use your own dataset for training!
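A minimal sketch of the registration, assuming the `from datasets.<Name> import <Name>` import pattern used for the existing dataset classes; the `--dataset_name` definition shown here is illustrative, so keep whatever choices your copy of `parse_args.py` already lists and simply append the new name:

```python
# datasets/__init__.py: make the new class importable
from datasets.DatasetX import DatasetX

# parse_args.py: append 'DatasetX' to the existing choices
# (the other choices shown here are illustrative)
parser.add_argument('--dataset_name', type=str, default='DatasetX',
                    choices=['HAM10000', 'PAPILA', 'DatasetX'])
```

The corresponding entry in `configs/datasets.json` should mirror the structure of the entries already there, e.g. the paths to the train/val/test metadata CSVs and to the pickle files produced in STEP 1.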
75 | 
76 | ## Customize Network Architectures
77 | You can add more network architectures (CNN-based) to the framework easily. Transformer models can also be incorporated, but require some additional modifications.
78 | 
79 | ### STEP 1. Implement the Network Class
80 | You can incorporate any 2D model in `models/basemodels.py` and any 3D model in `models/basemodels_3d.py`. We use the backbones provided by the torchvision model zoo, but you can also implement your own customized network structures. We use the `create_feature_extractor` function to extract the intermediate feature map. An example is given below:
81 | 
82 | ```python
83 | class cusResNet18(nn.Module):
84 |     def __init__(self, n_classes, pretrained = True):
85 |         super(cusResNet18, self).__init__()
86 |         # load the model backbone
87 |         resnet = torchvision.models.resnet18(pretrained=pretrained)
88 |         # change the number of output neurons of the fc layer
89 |         resnet.fc = nn.Linear(resnet.fc.in_features, n_classes)
90 |         self.avgpool = resnet.avgpool
91 | 
92 |         # specify the feature layers you want to extract
93 |         self.returnkey_avg = 'avgpool'
94 |         self.returnkey_fc = 'fc'
95 |         self.body = create_feature_extractor(
96 |             resnet, return_nodes={'avgpool': self.returnkey_avg, 'fc': self.returnkey_fc})
97 | 
98 |     def forward(self, x):
99 |         outputs = self.body(x)
100 |         return outputs[self.returnkey_fc], outputs[self.returnkey_avg].squeeze()
101 | 
102 |     def inference(self, x):
103 |         outputs = self.body(x)
104 |         return outputs[self.returnkey_fc], outputs[self.returnkey_avg].squeeze()
105 | ```
106 | 
107 | ### STEP 2. Register the Network
108 | - Add the network backbone name in the choices of the `backbone` argument in `parse_args.py`.
109 | 
110 | ## Customize Debiasing Algorithms
111 | 
112 | ### STEP 1. Implement the Algorithm Class
113 | For debiasing algorithms, we provide a base algorithm class in `models/basenet.py`, which includes the necessary initializations, training/testing loops, and evaluations. To add a new algorithm, first create a new folder named after your algorithm under the `models` folder (e.g. `AlgorithmX`), and create a `__init__.py` file and an `AlgorithmX.py` file. In `AlgorithmX.py`, name the new algorithm class with the same name as the script (i.e. `class AlgorithmX`). Then, follow the three steps below to implement your algorithm:
114 | 
115 | 1. Configure Hyper-parameters.
116 | 
117 | You can add options for algorithm-specific hyper-parameters in `parse_args.py`. For example, for the EnD method, we set two hyper-parameters `alpha` and `beta`:
118 | ```python
119 | # EnD
120 | parser.add_argument('--alpha', type=float, default=0.1, help='weighting parameters alpha for EnD method')
121 | parser.add_argument('--beta', type=float, default=0.1, help='weighting parameters beta for EnD method')
122 | ```
123 | 
124 | 2. Network and other utils.
125 | 
126 | - If your method does not require a customized network architecture, you can use the regular networks as in `models/basemodels.py` for 2D networks and `models/basemodels_3d.py` for 3D networks.
127 | - If you need to modify the network architecture, you can create a file `models/AlgorithmX/model.py` and implement the network class there with a `forward` function for training, and an `inference` function for testing. Example implementations can be found in `models/LAFTR/model.py`, etc.
128 | - If you need other functions for training, you can create a file `models/AlgorithmX/utils.py` and implement them there.
129 | - Import the modules you need in `models/AlgorithmX/__init__.py`.
130 | 
131 | 3. Training loop
132 | - If the train/val/test procedure is the regular loop like that in `models/baseline/baseline.py`, and does not require other loss functions, backward propagation, other outputs, etc. (check the `standard_train`, `standard_val`, and `standard_test` functions in `models/utils.py`), you do not need further modifications.
133 | - If you want to write your own training loop, you can override the `_train` function within the `AlgorithmX` class (see the sketch after this list). You can have a look at the implementations of other algorithms in the `models` folder for reference. Also, override the `_val` and `_test` functions if needed.
134 | 
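To make the second case concrete, here is a minimal, hypothetical `AlgorithmX` that overrides `_train` to add an extra penalty on the extracted features. It is a sketch, not a prescribed implementation: it follows the same inheritance pattern as `models/resampling/resampling.py`, the loader unpacking mirrors the training loops elsewhere in `models`, and `my_coeff` stands in for whatever hyper-parameter you registered in step 1:

```python
import torch.nn.functional as F
from models.baseline import baseline
from utils.evaluation import calculate_auc


class AlgorithmX(baseline):
    def __init__(self, opt, wandb):
        super(AlgorithmX, self).__init__(opt, wandb)
        self.set_network(opt)
        self.set_optimizer(opt)
        self.my_coeff = opt['my_coeff']  # hypothetical hyper-parameter from parse_args.py

    def _train(self, loader):
        """Train the model for one epoch with an extra feature penalty."""
        self.network.train()
        train_loss, auc, no_iter = 0., 0., 0
        for i, (images, targets, sensitive_attr, index) in enumerate(loader):
            images, targets = images.to(self.device), targets.to(self.device)

            self.optimizer.zero_grad()
            outputs, features = self.network(images)

            # standard classification loss plus a hypothetical penalty term
            loss = self._criterion(outputs, targets) + self.my_coeff * features.pow(2).mean()
            loss.backward()
            self.optimizer.step()

            auc += calculate_auc(F.sigmoid(outputs).cpu().data.numpy(),
                                 targets.cpu().data.numpy())
            train_loss += loss.item()
            no_iter += 1

        print('Training epoch {}: AUC:{}'.format(self.epoch, auc / no_iter))
        print('Training epoch {}: loss:{}'.format(self.epoch, train_loss / no_iter))
        self.epoch += 1
```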
135 | 
136 | ### STEP 2. Register the Algorithm
137 | - Add the algorithm name (`AlgorithmX`) in the choices of the `experiment` argument in `parse_args.py`.
138 | - Import the algorithm class in `models/__init__.py`.
139 | 
140 | Now, you can use your own algorithm for training!
141 | 
142 | ## Customize Evaluation Metrics
143 | Currently, the evaluation metrics are implemented in `utils/evaluation.py` and are all recorded in the `calculate_metrics` function. To add a new metric, implement it in this file and then add it to `calculate_metrics`, as sketched below.
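For instance, a worst-case group AUC could be added as follows. This is a sketch under the assumptions visible elsewhere in the codebase: predictions and targets arrive as flat lists of binary scores/labels, `tol_sensitive` holds integer group indices, and `log_dict` is the dictionary assembled inside `calculate_metrics`:

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def worst_group_auc(tol_output, tol_target, tol_sensitive, sens_classes):
    """Return the minimum per-group AUC over the sensitive groups."""
    output = np.asarray(tol_output).reshape(-1)
    target = np.asarray(tol_target).reshape(-1)
    sensitive = np.asarray(tol_sensitive).reshape(-1)

    group_aucs = []
    for g in range(sens_classes):
        mask = sensitive == g
        # AUC is undefined for a group in which only one class is present
        if mask.sum() > 0 and len(np.unique(target[mask])) == 2:
            group_aucs.append(roc_auc_score(target[mask], output[mask]))
    return min(group_aucs) if group_aucs else float('nan')

# inside calculate_metrics, record it alongside the existing entries:
# log_dict['worst-group AUC'] = worst_group_auc(tol_output, tol_target, tol_sensitive, sens_classes)
```
--------------------------------------------------------------------------------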