├── models └── README.md ├── training_code ├── frequency_training │ ├── environment.yml │ ├── requirements.txt │ ├── scripts │ │ └── run_exp.sh │ ├── lib │ │ ├── perturb_helper.py │ │ ├── models.py │ │ ├── infinite_loader.py │ │ ├── utils.py │ │ ├── train_helper.py │ │ ├── cxr_preprocess.py │ │ └── data.py │ ├── README.md │ ├── launchers.py │ ├── Constants.py │ ├── sweep.py │ ├── experiments.py │ ├── train_representation.py │ ├── train_model.py │ └── perturb.py ├── CXR_training │ ├── Emory_CXR │ │ └── Emory_CXR_resnet34_race_detection_2021_06_29.ipynb │ └── CheXpert │ │ └── CheXpert_resnet34_race_detection_2021_09_21_premade_splits.ipynb ├── EM-CS_training │ └── Emory_C-spine_race_detection_2021_06_29.ipynb └── digital_hand_atlas │ └── dha_2_classes.ipynb ├── requirements.txt ├── .github └── workflows │ └── blank.yml ├── data └── README.md └── README.md /models/README.md: -------------------------------------------------------------------------------- 1 | ## models 2 | 3 | The trained models were not released to prevent leakage issues with medical data. -------------------------------------------------------------------------------- /training_code/frequency_training/environment.yml: -------------------------------------------------------------------------------- 1 | name: cxr_bias 2 | dependencies: 3 | - python=3.6 4 | - pip 5 | - pip: 6 | - -r file:requirements.txt -------------------------------------------------------------------------------- /training_code/frequency_training/requirements.txt: -------------------------------------------------------------------------------- 1 | timm==0.4.5 2 | torchvision==0.9.1 3 | torch==1.8.1 4 | numpy==1.19.5 5 | pandas==1.1.5 6 | scipy==1.5.4 7 | scikit-learn==0.24.1 8 | scikit-image==0.17.2 9 | Pillow==8.1.0 10 | matplotlib==2.2.2 11 | pytorch_metric_learning==0.9.98 12 | -------------------------------------------------------------------------------- /training_code/frequency_training/scripts/run_exp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../ 4 | 5 | echo "Running experiment:" "${1}" 6 | 7 | slurm_pre="--partition t4v2,t4v1,rtx6000,p100 --gres gpu:1 --mem 40gb -c 4 --exclude gpu080 --job-name ${1} --output /scratch/ssd001/home/haoran/projects/CXR_Bias/logs/${1}_%A.log" 8 | 9 | python sweep.py launch \ 10 | --experiment ${1} \ 11 | --output_dir "/scratch/hdd001/home/haoran/cxr_bias/${1}/" \ 12 | --slurm_pre "${slurm_pre}" \ 13 | --command_launcher "slurm" 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pip install numpy==1.19.5 2 | pip install pandas==1.0.3 3 | pip install matplotlib==3.2.1 4 | pip install Pillow==7.1.2 5 | pip install scikit-learn==0.22.2.post1 6 | pip install tensorflow==2.5.0 7 | pip install tensorflow-gpu==2.5.0 8 | pip install classification_models==0.1 9 | pip install keras==2.3.1 10 | pip install torch==1.9.0 11 | pip install torchvision==0.10.0 12 | pip install h5py==3.1.0 13 | pip install scikit-image==0.18.1 14 | pip install opencv-python==4.2.0.34 15 | pip install seaborn==0.11.1 16 | pip install scipy==1.4.1 17 | pip install efficientnet==1.1.1 18 | pip install tqdm==4.61.1 19 | pip install lime==0.2.0.1 -------------------------------------------------------------------------------- /training_code/frequency_training/lib/perturb_helper.py: 
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | 
4 | def fgsm_attack(batch, epsilon, data_grads):
5 |     sign_data_grads = data_grads.sign()
6 | 
7 |     perturbed_batch = batch + epsilon*sign_data_grads
8 |     perturbed_batch = torch.clamp(perturbed_batch, 0., 1.)
9 | 
10 |     return perturbed_batch
11 | 
12 | def random_attack(batch, epsilon):
13 |     random_noise = torch.normal(mean = 0., std = 1., size = batch.shape).to(batch.device)
14 | 
15 |     perturbed_batch = batch + epsilon*random_noise
16 |     perturbed_batch = torch.clamp(perturbed_batch, 0., 1.)
17 | 
18 |     return perturbed_batch
19 | 
--------------------------------------------------------------------------------
/.github/workflows/blank.yml:
--------------------------------------------------------------------------------
1 | 
2 | name: hugo CI
3 | 
4 | on:
5 |   push:
6 |     branches: [ master ]
7 | 
8 | jobs:
9 |   build:
10 |     runs-on: ubuntu-latest
11 | 
12 |     steps:
13 |     - uses: actions/checkout@v2
14 |       with:
15 |         submodules: true
16 |         fetch-depth: 1
17 | 
18 |     - name: Setup Hugo
19 |       uses: peaceiris/actions-hugo@v2
20 |       with:
21 |         hugo-version: 'latest'
22 | 
23 |     - name: Build
24 |       run: hugo
25 | 
26 |     - name: Deploy
27 |       uses: peaceiris/actions-gh-pages@v3
28 |       with:
29 |         personal_token: ${{ secrets.PERSONAL_TOKEN }}
30 |         external_repository: myuser/dotnetramblings
31 |         publish_branch: master
32 |         publish_dir: ./public
33 | 
34 | 
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | ## data
2 | 
3 | The experiments in this GitHub repository use large medical datasets that cannot be pushed to GitHub. Some of the datasets are open source and others are proprietary. Details of the open-source datasets are given below.
4 | 
5 | | Dataset | Open Source/Proprietary | URL |
6 | | ----------- | ----------- | --- |
7 | | CheXpert | Open Source | https://stanfordmlgroup.github.io/competitions/chexpert/ |
8 | | Emory CXR | Proprietary | |
9 | | MIMIC | Open Source | https://physionet.org/content/mimic-cxr-jpg/2.0.0/ |
10 | | Emory Cervical Spine | Proprietary | |
11 | | Digital Hand Atlas | Open Source | https://ipilab.usc.edu/research/baaweb/ |
12 | | Mammogram | Proprietary | |
13 | 
--------------------------------------------------------------------------------
/training_code/frequency_training/README.md:
--------------------------------------------------------------------------------
1 | ## Environment Setup
2 | Run the following commands to create the Conda environment:
3 | ```
4 | conda env create -f environment.yml
5 | conda activate cxr_bias
6 | ```
7 | 
8 | ## Training a Single Model
9 | To train a single model, use `train_model.py` with the appropriate arguments, for example:
10 | ```
11 | python train_model.py \
12 |     --domain MIMIC \
13 |     --use_pretrained \
14 |     --target race \
15 |     --data_type ifft \
16 |     --model densenet \
17 |     --filter_type high \
18 |     --filter_thres 100
19 | ```
20 | 
21 | 
22 | ## Training a Grid of Models
23 | To reproduce the experiments from the paper, use `sweep.py` with the experiment grids defined in `experiments.py`, for example:
24 | ```
25 | python sweep.py launch \
26 |     --output_dir=/my/sweep/output/path \
27 |     --command_launcher slurm \
28 |     --slurm_pre "slurm_arguments" \
29 |     --experiment IFFTPatched
30 | ```
31 | 
32 | 
33 | We provide the bash script used for our main experiments in the scripts directory; the command that `sweep.py` generates for each job is illustrated below.
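`sweep.py` wraps every job it creates in an `sbatch` call, so the string passed to `--slurm_pre` should contain raw `sbatch` options (partition, GPU and memory requests, log path), as in `scripts/run_exp.sh`. As a rough, hedged illustration only — `<slurm_arguments>` and `<md5_of_hparams>` are placeholders, the flag values come from the `IFFTPatched` grid in `experiments.py`, some flags are elided, and the per-job output directory is an MD5 hash of that job's hyperparameters — each launched job looks like:

```
sbatch <slurm_arguments> --wrap "python train_model.py --data_type ifft --domain MIMIC --exp_name IFFTPatched --filter_thres 100 --filter_type high --model densenet --target race --use_pretrained ... --output_dir /my/sweep/output/path/<md5_of_hparams>"
```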
You will need to customize them, along with the launcher, to your compute environment. 34 | 35 | ## Aggregating Results 36 | 37 | We provide sample code for creating aggregate results after running all experiments in `AggResults.ipynb`. 38 | 39 | -------------------------------------------------------------------------------- /training_code/frequency_training/launchers.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import sys 3 | import unicodedata 4 | import getpass 5 | import time 6 | 7 | def local_launcher(commands): 8 | for cmd in commands: 9 | subprocess.call(cmd, shell=True) 10 | 11 | def slurm_launcher(commands): 12 | MAX_SLURM_JOBS = 500 13 | for cmd in commands: 14 | block_until_running(MAX_SLURM_JOBS, getpass.getuser()) 15 | subprocess.call(cmd, shell=True) 16 | 17 | def get_num_jobs(user): 18 | # returns a list of (# queued and waiting, # running) 19 | out = subprocess.run(['squeue -u ' + user], shell = True, stdout = subprocess.PIPE).stdout.decode(sys.stdout.encoding) 20 | a = list(filter(lambda x: len(x) > 0, map(lambda x: x.split(), out.split('\n')))) 21 | queued, running = 0,0 22 | for i in a: 23 | if i[0].isnumeric(): 24 | if i[4].strip() == 'PD': 25 | queued += 1 26 | else: 27 | running += 1 28 | return (queued, running) 29 | 30 | def block_until_running(n, user): 31 | while True: 32 | if sum(get_num_jobs(user)) < n: 33 | time.sleep(0.2) 34 | return True 35 | else: 36 | time.sleep(10) 37 | 38 | REGISTRY = { 39 | 'local': local_launcher, 40 | 'slurm': slurm_launcher 41 | } -------------------------------------------------------------------------------- /training_code/frequency_training/lib/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision.models as mod 5 | import timm 6 | 7 | class DenseNet(nn.Module): 8 | def __init__(self, use_pretrained, num_extra=0, n_outputs = 2): 9 | super().__init__() 10 | self.num_extra = num_extra 11 | self.model_conv = mod.densenet121(pretrained= use_pretrained) 12 | self.num_ftrs = self.model_conv.classifier.in_features + num_extra 13 | self.model_conv.classifier = nn.Identity() 14 | self.class_conf = nn.Linear(self.num_ftrs,n_outputs) 15 | 16 | def forward(self,x,*args): 17 | assert(len(args) <= 1) 18 | img_conv_out = self.model_conv(x) 19 | if self.num_extra: 20 | assert(args[0].shape[1] == self.num_extra) 21 | img_conv_out = torch.cat((img_conv_out, args[0]), -1) 22 | out = self.class_conf(img_conv_out) 23 | return out 24 | 25 | class VisionTransformer(nn.Module): 26 | def __init__(self, use_pretrained, num_extra=0, n_outputs = 2): 27 | super().__init__() 28 | self.num_extra = num_extra 29 | self.model_conv = timm.create_model('vit_deit_small_patch16_224', pretrained= use_pretrained) 30 | self.num_ftrs = self.model_conv.head.in_features 31 | self.model_conv.head = nn.Identity() 32 | self.class_conf = nn.Linear(self.num_ftrs, n_outputs) 33 | 34 | def forward(self,x,*args): 35 | assert(len(args) <= 1) 36 | img_conv_out = self.model_conv(x) 37 | if self.num_extra: 38 | assert(args[0].shape[1] == self.num_extra) 39 | img_conv_out = torch.cat((img_conv_out, args[0]), -1) 40 | out = self.class_conf(img_conv_out) 41 | return out -------------------------------------------------------------------------------- /training_code/frequency_training/Constants.py: -------------------------------------------------------------------------------- 1 | from 
pathlib import Path 2 | import pandas as pd 3 | import numpy as np 4 | import os 5 | 6 | ## CXR 7 | 8 | df_paths = { 9 | 'MIMIC': { 10 | 'train': "/scratch/hdd001/projects/ml4h/projects/mimic_access_required/MIMIC-CXR/laleh/new_split/8-1-1/new_train.csv", 11 | 'val': "/scratch/hdd001/projects/ml4h/projects/mimic_access_required/MIMIC-CXR/laleh/new_split/8-1-1/new_valid.csv", 12 | 'test': "/scratch/hdd001/projects/ml4h/projects/mimic_access_required/MIMIC-CXR/laleh/new_split/8-1-1/new_test.csv" 13 | }, 14 | 'CXP':{ 15 | 'train': "/scratch/hdd001/projects/ml4h/projects/CheXpert/split/July19/new_train.csv", 16 | 'val': "/scratch/hdd001/projects/ml4h/projects/CheXpert/split/July19/new_valid.csv", 17 | 'test': "/scratch/hdd001/projects/ml4h/projects/CheXpert/split/July19/new_test.csv" 18 | }, 19 | 'NIH':{ 20 | 'train': "/scratch/hdd001/projects/ml4h/projects/NIH/split/July16/train.csv", 21 | 'val': "/scratch/hdd001/projects/ml4h/projects/NIH/split/July16/valid.csv", 22 | 'test': "/scratch/hdd001/projects/ml4h/projects/NIH/split/July16/test.csv" 23 | }, 24 | 'PAD':{ 25 | 'train': "/scratch/hdd001/projects/ml4h/projects/padchest/PADCHEST/haoran_split/train.csv", 26 | 'val': "/scratch/hdd001/projects/ml4h/projects/padchest/PADCHEST/haoran_split/valid.csv", 27 | 'test': "/scratch/hdd001/projects/ml4h/projects/padchest/PADCHEST/haoran_split/test.csv" 28 | } 29 | } 30 | 31 | image_paths = { 32 | 'MIMIC': "/scratch/hdd001/projects/ml4h/projects/mimic_access_required/MIMIC-CXR/", 33 | 'CXP': "/scratch/hdd001/projects/ml4h/projects/CheXpert/", 34 | 'NIH': "/scratch/hdd001/projects/ml4h/projects/NIH/images/", 35 | 'PAD': '/scratch/hdd001/projects/ml4h/projects/padchest/PADCHEST/images-224' 36 | } 37 | 38 | MIMIC_details = "/scratch/hdd001/projects/ml4h/projects/mimic_access_required/MIMIC-CXR/vin_new_split/8-1-1/mimic-cxr-metadata-detail.csv" 39 | CXP_details = "/scratch/hdd001/projects/ml4h/projects/CheXpert/chexpert_demographics.csv" 40 | PAD_details = "/scratch/hdd001/projects/ml4h/projects/padchest/PADCHEST/PADCHEST_chest_x_ray_images_labels_160K_01.02.19.csv" 41 | cache_dir = "/scratch/hdd001/home/{}/datasets/cache".format(os.environ.get('USER')) 42 | 43 | IMAGENET_MEAN = [0.485, 0.456, 0.406] # Mean of ImageNet dataset (used for normalization) 44 | IMAGENET_STD = [0.229, 0.224, 0.225] # Std of ImageNet dataset (used for normalization) 45 | 46 | take_labels = ['No Finding', 'Atelectasis', 'Cardiomegaly', 'Effusion', 'Pneumonia', 'Pneumothorax', 'Consolidation','Edema'] 47 | 48 | race_mapping = { 49 | 0: 'White', 50 | 1: "Black", 51 | 2: "Hispanic", 52 | 3: "Asian", 53 | 4: "Other" 54 | } 55 | 56 | gender_mapping = { 57 | 0: 'F', 58 | 1: 'M' 59 | } 60 | 61 | N_PATCHES = 9 62 | -------------------------------------------------------------------------------- /training_code/frequency_training/lib/infinite_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class StatefulSampler(torch.utils.data.sampler.Sampler): 4 | def __init__(self, data_source, shuffle=False): 5 | self.data = data_source 6 | self.shuffle = shuffle 7 | 8 | # initial dataloader index 9 | self.init_index() 10 | 11 | def init_index(self): 12 | if self.shuffle: 13 | self.indices = torch.randperm(len(self.data)) 14 | else: 15 | self.indices = torch.arange(len(self.data)) 16 | 17 | self.data_counter = 0 18 | 19 | def __iter__(self): 20 | return self 21 | 22 | def __len__(self): 23 | return len(self.data) 24 | 25 | def __next__(self): 26 | if self.data_counter == 
len(self.data):
27 |             self.init_index()
28 |             raise StopIteration()
29 |         else:
30 |             ele = self.indices[self.data_counter]
31 |             self.data_counter += 1
32 |             return int(ele)
33 | 
34 |     def state_dict(self, dataloader_iter=None):
35 |         prefetched_num = 0
36 |         # in the case of multiworker dataloader, the helper worker could be
37 |         # pre-fetching the data that is not consumed by the main dataloader.
38 |         # we need to subtract the unconsumed part.
39 |         if dataloader_iter is not None:
40 |             if dataloader_iter._num_workers > 0:
41 |                 batch_size = dataloader_iter._index_sampler.batch_size
42 |                 prefetched_num = \
43 |                     (dataloader_iter._send_idx - dataloader_iter._rcvd_idx) * batch_size
44 |         return {
45 |             'indices': self.indices,
46 |             'data_counter': self.data_counter - prefetched_num,
47 |         }
48 | 
49 |     def load_state_dict(self, state_dict):
50 |         self.indices = state_dict['indices']
51 |         self.data_counter = state_dict['data_counter']
52 | 
53 | class _InfiniteSampler(torch.utils.data.Sampler):
54 |     """Wraps another Sampler to yield an infinite stream."""
55 |     def __init__(self, sampler):
56 |         self.sampler = sampler
57 |         self.batch_size = sampler.batch_size
58 | 
59 |     def __iter__(self):
60 |         while True:
61 |             for batch in self.sampler:
62 |                 yield batch
63 | 
64 | class InfiniteDataLoader:
65 |     def __init__(self, dataset, batch_size, num_workers):
66 |         super().__init__()
67 |         self.sampler = StatefulSampler(dataset, shuffle = True)
68 | 
69 |         batch_sampler = torch.utils.data.BatchSampler(
70 |             self.sampler,
71 |             batch_size=batch_size,
72 |             drop_last=True)
73 | 
74 |         self._infinite_iterator = iter(torch.utils.data.DataLoader(
75 |             dataset,
76 |             num_workers=num_workers,
77 |             batch_sampler=_InfiniteSampler(batch_sampler)
78 |         ))
79 | 
80 |     def __iter__(self):
81 |         while True:
82 |             yield next(self._infinite_iterator)
83 | 
84 |     def __len__(self):
85 |         raise ValueError
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![Contributors][contributors-shield]][contributors-url]
2 | [![Stargazers][stars-shield]][stars-url]
3 | [![Issues][issues-shield]][issues-url]
4 | 
5 | # AI-Vengers
6 | 
7 | This GitHub repository holds training and validation code for deep learning and machine learning models that detect patients' racial demographics from medical images.
8 | 
9 | The **data** and **models** folders are empty because:
10 | 1. The data and model files are too large to be pushed to a GitHub repository.
11 | 2. Some of the data used in the experiments is proprietary. URLs are provided for the open-source datasets.
12 | 3. The trained ML/DL models could be used to partially reconstruct patient information, which would risk leaking the proprietary data.
13 | 
14 | The **training_code** folder contains the training code for all the experiments. The experiments, with their corresponding data and models, are as follows:
15 | 
16 | | Training Folder Name | Training File Name | Data | Model |
17 | | -------------------- | ------------------ | ---- | ----- |
18 | | CXR_training | CheXpert_resnet34_race_detection_2021_06_29.ipynb | CheXpert | ResNet34 |
19 | | CXR_training | Emory_CXR_resnet34_race_detection_2021_06_29.ipynb | Emory CXR | ResNet34 |
20 | | CXR_training | MIMIC_resnet34_race_detection_2021_06_29.ipynb | MIMIC | ResNet34 |
21 | | EM-CS_training | Emory_C-spine_race_detection_2021_06_29.ipynb | Emory Cervical Spine | ResNet34 |
22 | | EM_Mammo_Training | training code.ipynb | Mammogram | EfficientNetB2 |
23 | | Densenet121_CXR_Training | Lung_segmentation_MIMIC.ipynb | MIMIC | U-Net |
24 | | Densenet121_CXR_Training | Race classification with No Finding label only_MIMIC_Densenet121.ipynb | MIMIC | DenseNet121 |
25 | | Densenet121_CXR_Training | Race classification_MIMIC_Densenet121.ipynb | MIMIC | DenseNet121 |
26 | | Densenet121_CXR_Training | Race_classification_Emory_Densenet121.ipynb | Emory CXR | DenseNet121 |
27 | | digital_hand_atlas | dha_2_classes.ipynb | Digital Hand Atlas | ResNet50 |
28 | 
29 | The final IPython notebook, bias_pred.ipynb, contains validation code for all of the above trained models (except frequency training).
30 | 
31 | To run the validation code:
32 | 1. Fork or download the GitHub repository.
33 | 2. Fetch the open-source datasets from the URLs above and place them in the data folder.
34 | 3. Run the corresponding training code and save the trained model in the models folder.
35 | 4. Update the model path and the corresponding function in the validation code.
36 | 
37 | https://emory-hiti.github.io/AI-Vengers/
38 | 
39 | [contributors-shield]: https://img.shields.io/github/contributors/Emory-HITI/AI-Vengers.svg?style=flat-square
40 | [contributors-url]: https://github.com/Emory-HITI/AI-Vengers/graphs/contributors
41 | [stars-shield]: https://img.shields.io/github/stars/Emory-HITI/AI-Vengers.svg?style=flat-square
42 | [stars-url]: https://github.com/Emory-HITI/AI-Vengers/stargazers
43 | [issues-shield]: https://img.shields.io/github/issues/Emory-HITI/AI-Vengers.svg?style=flat-square
44 | [issues-url]: https://github.com/Emory-HITI/AI-Vengers/issues
45 | 
--------------------------------------------------------------------------------
/training_code/frequency_training/lib/utils.py:
--------------------------------------------------------------------------------
1 | import getpass
2 | import os
3 | import torch
4 | from pathlib import Path
5 | import numpy as np
6 | import math
7 | 
8 | class EarlyStopping:
9 |     # adapted from https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py
10 |     def __init__(self, patience=5):
11 |         self.patience = patience
12 |         self.counter = 0
13 |         self.best_score = None
14 |         self.early_stop = False
15 | 
16 |     def __call__(self, val_loss, step, state_dict, path): # lower loss is better
17 |         score = -val_loss # higher score is better
18 | 
19 |         if self.best_score is None:
20 |             self.best_score = score
21 |             self.step = step
22 |             save_model(state_dict, path)
23 |         elif score < self.best_score:
24 |             self.counter += 1
25 |             # print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
26 |             if self.counter >= self.patience:
27 |                 self.early_stop = True
28 |         else:
29 |             save_model(state_dict, path)
30 |             self.best_score = score
31 |             self.step = step
32 |             self.counter = 0
33 | 
34 | def save_model(state_dict, path):
35 |     torch.save(state_dict, path)
36 | 
37 | def
save_checkpoint(model, optimizer, scheduler, sampler_dict, start_step, es, rng): 38 | slurm_job_id = os.environ.get('SLURM_JOB_ID') 39 | 40 | if slurm_job_id is not None and Path('/checkpoint/').exists(): 41 | torch.save({'model_dict': model.state_dict(), 42 | 'optimizer_dict': optimizer.state_dict(), 43 | 'scheduler_dict': scheduler.state_dict(), 44 | 'sampler_dict': sampler_dict, 45 | 'start_step': start_step, 46 | 'es': es, 47 | 'rng': rng 48 | } 49 | , 50 | Path(f'/checkpoint/{getpass.getuser()}/{slurm_job_id}/chkpt').open('wb') 51 | ) 52 | 53 | 54 | def has_checkpoint(): 55 | slurm_job_id = os.environ.get('SLURM_JOB_ID') 56 | if slurm_job_id is not None and Path(f'/checkpoint/{getpass.getuser()}/{slurm_job_id}/chkpt').exists(): 57 | return True 58 | return False 59 | 60 | 61 | def load_checkpoint(): 62 | slurm_job_id = os.environ.get('SLURM_JOB_ID') 63 | if slurm_job_id is not None and Path('/checkpoint/').exists(): 64 | return torch.load(f'/checkpoint/{getpass.getuser()}/{slurm_job_id}/chkpt') 65 | 66 | def delete_checkpoint(): 67 | slurm_job_id = os.environ.get('SLURM_JOB_ID') 68 | chkpt_file = Path(f'/checkpoint/{getpass.getuser()}/{slurm_job_id}/chkpt') 69 | if slurm_job_id is not None and chkpt_file.exists(): 70 | return chkpt_file.unlink() 71 | 72 | def fft(img): 73 | assert(img.ndim == 2) 74 | img_c2 = np.fft.fft2(img) 75 | img_c3 = np.fft.fftshift(img_c2) 76 | spectra = np.log(1+np.abs(img_c3)) 77 | return spectra 78 | 79 | 80 | def filter_circle(TFcircleIN,fft_img_channel): 81 | temp = np.zeros(fft_img_channel.shape[:2],dtype=complex) 82 | temp[TFcircleIN] = fft_img_channel[TFcircleIN] 83 | return(temp) 84 | 85 | def draw_circle(shape,diameter): 86 | assert len(shape) == 2 87 | TF = np.zeros(shape,dtype=np.bool) 88 | center = np.array(TF.shape)/2.0 89 | 90 | for iy in range(shape[0]): 91 | for ix in range(shape[1]): 92 | TF[iy,ix] = (iy- center[0])**2 + (ix - center[1])**2 < diameter **2 93 | return(TF) 94 | 95 | def filter_and_ifft(x, filter): 96 | return np.real_if_close(np.fft.ifft2(np.fft.ifftshift(filter_circle(filter, x)))) 97 | 98 | def split_tensor(tensor, tile_size=256, offset=256): 99 | tiles = [] 100 | h, w = tensor.size(1), tensor.size(2) 101 | for y in range(int(math.floor(h/offset))): 102 | for x in range(int(math.floor(w/offset))): 103 | tiles.append(tensor[:, offset*y:min(offset*y+tile_size, h), offset*x:min(offset*x+tile_size, w)]) 104 | if tensor.is_cuda: 105 | base_tensor = torch.zeros(tensor.size(), device=tensor.get_device()) 106 | else: 107 | base_tensor = torch.zeros(tensor.size()) 108 | return tiles, base_tensor 109 | 110 | def blacken_tensor(tensor, patch_ind, tile_size=256, offset=256): 111 | h, w = tensor.size(1), tensor.size(2) 112 | c = 0 113 | for y in range(int(math.floor(h/offset))): 114 | for x in range(int(math.floor(w/offset))): 115 | if c == patch_ind: 116 | tensor[:, offset*y:min(offset*y+tile_size, h), offset*x:min(offset*x+tile_size, w)] = 0 117 | c += 1 118 | return tensor 119 | -------------------------------------------------------------------------------- /training_code/frequency_training/lib/train_helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from collections import defaultdict 5 | 6 | from pytorch_metric_learning import miners, losses 7 | from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator 8 | 9 | def train_step(data, target, model, device, optimizer, extras = None): 10 | 
model.train() 11 | 12 | data = data.float().to(device) 13 | target = target.long().to(device) 14 | target = target.view((-1)) 15 | 16 | if torch.is_tensor(extras): 17 | extras = extras.float().to(device) 18 | 19 | output = model(data, extras) ## no softmax applied 20 | 21 | loss = F.cross_entropy(output, target) 22 | 23 | optimizer.zero_grad() 24 | loss.backward() 25 | optimizer.step() 26 | 27 | pred = F.softmax(output, dim=1).max(1, keepdim=True)[1] 28 | 29 | acc = 1.*pred.eq(target.view_as(pred)).sum().item()/len(data) 30 | 31 | return loss.item(), acc 32 | 33 | 34 | def prediction(model, device, data_loader): 35 | model.eval() 36 | prob_list = []; pred_list = []; target_list = [] 37 | total_loss = 0.0 38 | with torch.no_grad(): 39 | for [data, target, extras] in data_loader: 40 | data = data.float().to(device) # add channel dimension 41 | target = target.long().to(device) 42 | target = target.view((-1,)) 43 | 44 | if torch.is_tensor(extras): 45 | extras = extras.float().to(device) 46 | 47 | target_list = target_list + list(target.cpu().detach().tolist()) 48 | 49 | output = model(data, extras) 50 | 51 | total_loss += F.cross_entropy(output, target) 52 | 53 | prob = F.softmax(output, dim=1) 54 | prob_list = prob_list + list(prob.cpu().detach().tolist()) 55 | 56 | pred = prob.max(1, keepdim=True)[1] 57 | pred_list = pred_list + list(pred.cpu().detach().tolist()) 58 | 59 | total_loss /= len(data_loader.dataset) 60 | 61 | return total_loss, np.array(prob_list), np.array(pred_list), np.array(target_list) 62 | 63 | def train_triplet_step(data, target, model, device, optimizer, miner, extras = None): 64 | model.train() 65 | 66 | loss_func = losses.MarginLoss() 67 | acc_calc = AccuracyCalculator() 68 | 69 | data = data.float().to(device) 70 | target = target.long().to(device) 71 | target = target.view((-1)) 72 | 73 | if torch.is_tensor(extras): 74 | extras = extras.float().to(device) 75 | 76 | embedding = model(data, extras) 77 | triplets = miner.mine(embedding, target, embedding, target) 78 | 79 | loss = loss_func(embedding, target, triplets) 80 | 81 | optimizer.zero_grad() 82 | loss.backward() 83 | optimizer.step() 84 | 85 | with torch.no_grad(): 86 | acc_dict = acc_calc.get_accuracy(embedding, embedding, target, target, embeddings_come_from_same_source=True) 87 | 88 | return loss.item(), acc_dict["precision_at_1"] 89 | 90 | def representation(model, device, data_loader): 91 | model.eval() 92 | target_list = [] 93 | embedding_list = [] 94 | total_loss = 0.0 95 | 96 | loss_func = losses.MarginLoss() 97 | acc_calc = AccuracyCalculator() 98 | miner = miners.BatchEasyHardMiner(pos_strategy='all', neg_strategy='all') 99 | 100 | acc_dicts = defaultdict(list) 101 | with torch.no_grad(): 102 | for [data, target, extras] in data_loader: 103 | data = data.float().to(device) # add channel dimension 104 | target = target.long().to(device) 105 | target = target.view((-1,)) 106 | 107 | if torch.is_tensor(extras): 108 | extras = extras.float().to(device) 109 | 110 | target_list = target_list + list(target.cpu().detach().tolist()) 111 | 112 | embedding = model(data, extras) 113 | triplets = miner.mine(embedding, target, embedding, target) 114 | 115 | embedding_list = embedding_list + list(embedding.cpu().detach().tolist()) 116 | 117 | total_loss += loss_func(embedding, target, triplets) 118 | 119 | acc_dict = acc_calc.get_accuracy(embedding, embedding, target, target, embeddings_come_from_same_source=True) 120 | for key in acc_dict: 121 | acc_dicts[key].append(acc_dict[key]) 122 | 123 | total_loss /= 
len(data_loader.dataset) 124 | 125 | avg_acc_dict = {key: np.mean(acc_dicts[key]) for key in acc_dicts} 126 | return total_loss, avg_acc_dict, np.array(embedding_list), np.array(target_list) 127 | -------------------------------------------------------------------------------- /training_code/frequency_training/lib/cxr_preprocess.py: -------------------------------------------------------------------------------- 1 | import Constants 2 | from torch.utils.data import Dataset 3 | import os 4 | import numpy as np 5 | from PIL import Image 6 | import pandas as pd 7 | import torch 8 | from pathlib import Path 9 | from collections import defaultdict 10 | import re 11 | 12 | def preprocess_MIMIC(split, only_frontal, return_all_labels = False): 13 | details = pd.read_csv(Constants.MIMIC_details) 14 | details = details.drop(columns=['dicom_id', 'study_id', 'religion', 'marital_status', 'gender']) 15 | details.drop_duplicates(subset="subject_id", keep="first", inplace=True) 16 | df = pd.merge(split, details) 17 | 18 | copy_subjectid = df['subject_id'] 19 | df = df.drop(columns = ['subject_id']).replace( 20 | [[None], -1, "[False]", "[True]", "[ True]", 'UNABLE TO OBTAIN', 'UNKNOWN', 'MARRIED', 'LIFE PARTNER', 21 | 'DIVORCED', 'SEPARATED', '0-10', '10-20', '20-30', '30-40', '40-50', '50-60', '60-70', '70-80', '80-90', 22 | '>=90'], 23 | [0, 0, 0, 1, 1, 0, 0, 'MARRIED/LIFE PARTNER', 'MARRIED/LIFE PARTNER', 'DIVORCED/SEPARATED', 24 | 'DIVORCED/SEPARATED', '0-20', '0-20', '20-40', '20-40', '40-60', '40-60', '60-80', '60-80', '80-', '80-']) 25 | 26 | df['subject_id'] = copy_subjectid.astype(str) 27 | df['Age'] = df["age_decile"] 28 | df['Sex'] = df["gender"] 29 | df = df.drop(columns=["age_decile", 'gender']) 30 | df = df.rename( 31 | columns = { 32 | 'Pleural Effusion':'Effusion', 33 | }) 34 | df['study_id'] = df['path'].apply(lambda x: x[x.index('p'):x.rindex('/')]) 35 | df['path'] = Constants.image_paths['MIMIC'] + df['path'].astype(str) 36 | df['frontal'] = (df.view == 'frontal') 37 | if only_frontal: 38 | df = df[df.frontal] 39 | 40 | df['env'] = 'MIMIC' 41 | df.loc[df.Age == 0, 'Age'] = '0-20' 42 | 43 | df = df[(~df.race.isin(['UNKNOWN', 'UNABLE TO OBTAIN', 0, '0'])) & (~pd.isnull(df.race))] 44 | 45 | race_mapping = defaultdict(lambda:4) 46 | race_mapping['WHITE'] = 0 47 | race_mapping['BLACK/AFRICAN AMERICAN'] = 1 48 | race_mapping['HISPANIC/LATINO'] = 2 49 | race_mapping['ASIAN'] = 3 50 | 51 | df['race'] = df['race'].map(race_mapping) 52 | df['Sex'] = (df['Sex'] == 'M').astype(int) 53 | 54 | return df[['subject_id','path','Sex', "Age", 'env', 'frontal', 'study_id', 'race'] + Constants.take_labels + 55 | (['Enlarged Cardiomediastinum', 'Airspace Opacity', 'Lung Lesion', 'Pleural Other', 'Fracture', 'Support Devices'] if return_all_labels else [])] 56 | 57 | def preprocess_CXP(split, only_frontal, return_all_labels = False): 58 | details = pd.read_csv(Constants.CXP_details)[['PATIENT', 'PRIMARY_RACE']] 59 | details['subject_id'] = details['PATIENT'].apply(lambda x: x[7:]).astype(int).astype(str) 60 | 61 | split['Age'] = np.where(split['Age'].between(0,19), 19, split['Age']) 62 | split['Age'] = np.where(split['Age'].between(20,39), 39, split['Age']) 63 | split['Age'] = np.where(split['Age'].between(40,59), 59, split['Age']) 64 | split['Age'] = np.where(split['Age'].between(60,79), 79, split['Age']) 65 | split['Age'] = np.where(split['Age']>=80, 81, split['Age']) 66 | 67 | copy_subjectid = split['subject_id'] 68 | split = split.drop(columns = ['subject_id']).replace([[None], -1, "[False]", "[True]", 
"[ True]", 19, 39, 59, 79, 81], 69 | [0, 0, 0, 1, 1, "0-20", "20-40", "40-60", "60-80", "80-"]) 70 | 71 | split['subject_id'] = copy_subjectid.astype(str) 72 | split['Sex'] = np.where(split['Sex']=='Female', 'F', split['Sex']) 73 | split['Sex'] = np.where(split['Sex']=='Male', 'M', split['Sex']) 74 | split = split.rename( 75 | columns = { 76 | 'Pleural Effusion':'Effusion', 77 | 'Lung Opacity': 'Airspace Opacity' 78 | }) 79 | split['path'] = Constants.image_paths['CXP'] + split['Path'].astype(str) 80 | split['frontal'] = (split['Frontal/Lateral'] == 'Frontal') 81 | if only_frontal: 82 | split = split[split['frontal']] 83 | split['env'] = 'CXP' 84 | split['study_id'] = split['path'].apply(lambda x: x[x.index('patient'):x.rindex('/')]) 85 | 86 | split = pd.merge(split, details, on = 'subject_id', how = 'inner') 87 | split = split[(~split.PRIMARY_RACE.isin(['Unknown', 'Patient Refused'])) & (~pd.isnull(split.PRIMARY_RACE))] 88 | 89 | def cat_race(r): 90 | if r.startswith('White'): 91 | return 0 92 | elif r.startswith('Black'): 93 | return 1 94 | elif 'Hispanic' in r and 'non-Hispanic' not in r: 95 | return 2 96 | elif 'Asian' in r: 97 | return 3 98 | else: 99 | return 4 100 | 101 | split['race'] = split['PRIMARY_RACE'].apply(cat_race) 102 | split['Sex'] = (split['Sex'] == 'M').astype(int) 103 | 104 | return split[['subject_id','path','Sex',"Age", 'env', 'frontal','study_id', 'race'] + Constants.take_labels + 105 | (['Enlarged Cardiomediastinum', 'Airspace Opacity', 'Lung Lesion', 'Pleural Other', 'Fracture', 'Support Devices'] if return_all_labels else [])] 106 | 107 | def get_process_func(env): 108 | if env == 'MIMIC': 109 | return preprocess_MIMIC 110 | elif env == 'CXP': 111 | return preprocess_CXP 112 | else: 113 | raise NotImplementedError 114 | -------------------------------------------------------------------------------- /training_code/frequency_training/sweep.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import getpass 4 | import hashlib 5 | import json 6 | import os 7 | import random 8 | import shutil 9 | import time 10 | import uuid 11 | 12 | import numpy as np 13 | import torch 14 | 15 | import tqdm 16 | import shlex 17 | import experiments 18 | import launchers 19 | 20 | class Job: 21 | NOT_LAUNCHED = 'Not launched' 22 | INCOMPLETE = 'Incomplete' 23 | DONE = 'Done' 24 | 25 | def __init__(self, train_args, sweep_output_dir, slurm_pre, script_name, exp_name): 26 | args_str = json.dumps(train_args, sort_keys=True) 27 | args_hash = hashlib.md5(args_str.encode('utf-8')).hexdigest() 28 | self.output_dir = os.path.join(sweep_output_dir, args_hash) 29 | 30 | self.train_args = copy.deepcopy(train_args) 31 | self.train_args['output_dir'] = self.output_dir 32 | self.train_args['exp_name'] = exp_name 33 | command = ['python', script_name] 34 | for k, v in sorted(self.train_args.items()): 35 | if isinstance(v, (list, tuple)): 36 | v = ' '.join([str(v_) for v_ in v]) 37 | elif isinstance(v, str): 38 | v = shlex.quote(v) 39 | 40 | if k: 41 | if not isinstance(v, bool): 42 | command.append(f'--{k} {v}') 43 | else: 44 | if v: 45 | command.append(f'--{k}') 46 | else: 47 | pass 48 | 49 | self.command_str = ' '.join(command) 50 | self.command_str = f'sbatch {slurm_pre} --wrap "{self.command_str}"' 51 | 52 | print(self.command_str) 53 | 54 | if os.path.exists(os.path.join(self.output_dir, 'done')): 55 | self.state = Job.DONE 56 | elif os.path.exists(self.output_dir): 57 | self.state = Job.INCOMPLETE 58 | else: 59 | 
self.state = Job.NOT_LAUNCHED 60 | 61 | def __str__(self): 62 | job_info = {i:self.train_args[i] for i in self.train_args if i not in ['experiment','output_dir']} 63 | return '{}: {} {}'.format( 64 | self.state, 65 | self.output_dir, 66 | job_info) 67 | 68 | @staticmethod 69 | def launch(jobs, launcher_fn): 70 | print('Launching...') 71 | jobs = jobs.copy() 72 | print('Making job directories:') 73 | for job in tqdm.tqdm(jobs, leave=False): 74 | os.makedirs(job.output_dir, exist_ok=True) 75 | commands = [job.command_str for job in jobs] 76 | launcher_fn(commands) 77 | print(f'Launched {len(jobs)} jobs!') 78 | 79 | @staticmethod 80 | def delete(jobs): 81 | print('Deleting...') 82 | for job in jobs: 83 | shutil.rmtree(job.output_dir) 84 | print(f'Deleted {len(jobs)} jobs!') 85 | 86 | def ask_for_confirmation(): 87 | response = input('Are you sure? (y/n) ') 88 | if not response.lower().strip()[:1] == "y": 89 | exit(0) 90 | 91 | def make_args_list(experiment): 92 | return experiments.get_hparams(experiment) 93 | 94 | if __name__ == "__main__": 95 | parser = argparse.ArgumentParser(description='Run a sweep') 96 | parser.add_argument('command', choices=['launch', 'delete_incomplete', 'delete_all']) 97 | parser.add_argument('--experiment', type=str, required = True) 98 | parser.add_argument('--output_dir', type=str, required=True) 99 | parser.add_argument('--skip_confirmation', action='store_true') 100 | parser.add_argument('--slurm_pre', type=str, required = True) 101 | parser.add_argument('--command_launcher', type=str, required=True) 102 | args = parser.parse_args() 103 | 104 | args_list = make_args_list(args.experiment) 105 | jobs = [Job(train_args, args.output_dir, args.slurm_pre, experiments.get_script_name(args.experiment), args.experiment) for train_args in args_list] 106 | 107 | for job in jobs: 108 | print(job) 109 | print("{} jobs: {} done, {} incomplete, {} not launched.".format( 110 | len(jobs), 111 | len([j for j in jobs if j.state == Job.DONE]), 112 | len([j for j in jobs if j.state == Job.INCOMPLETE]), 113 | len([j for j in jobs if j.state == Job.NOT_LAUNCHED])) 114 | ) 115 | 116 | if args.command == 'launch': 117 | to_launch = [j for j in jobs if j.state in [Job.NOT_LAUNCHED, job.INCOMPLETE]] 118 | print(f'About to launch {len(to_launch)} jobs.') 119 | if not args.skip_confirmation: 120 | ask_for_confirmation() 121 | launcher_fn = launchers.REGISTRY[args.command_launcher] 122 | Job.launch(to_launch, launcher_fn) 123 | 124 | elif args.command == 'delete_incomplete': 125 | to_delete = [j for j in jobs if j.state == Job.INCOMPLETE] 126 | print(f'About to delete {len(to_delete)} jobs.') 127 | if not args.skip_confirmation: 128 | ask_for_confirmation() 129 | Job.delete(to_delete) 130 | 131 | elif args.command == 'delete_all': 132 | to_delete = [j for j in jobs if j.state == Job.INCOMPLETE or j.state == job.DONE] 133 | print(f'About to delete {len(to_delete)} jobs.') 134 | if not args.skip_confirmation: 135 | ask_for_confirmation() 136 | Job.delete(to_delete) 137 | -------------------------------------------------------------------------------- /training_code/frequency_training/experiments.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import Constants 3 | import os 4 | 5 | def combinations(grid): 6 | keys = list(grid.keys()) 7 | limits = [len(grid[i]) for i in keys] 8 | all_args = [] 9 | 10 | index = [0]*len(keys) 11 | 12 | while True: 13 | args = {} 14 | for c, i in enumerate(index): 15 | key = keys[c] 16 | args[key] = 
grid[key][i] 17 | all_args.append(args) 18 | 19 | # increment index by 1 20 | carry = False 21 | index[-1] += 1 22 | ind = len(index) -1 23 | while ind >= 0: 24 | if carry: 25 | index[ind] += 1 26 | 27 | if index[ind] == limits[ind]: 28 | index[ind] = 0 29 | carry = True 30 | else: 31 | carry = False 32 | ind -= 1 33 | 34 | if carry: 35 | break 36 | 37 | return all_args 38 | 39 | def get_hparams(experiment): 40 | if experiment not in globals(): 41 | raise NotImplementedError 42 | return globals()[experiment].hparams() 43 | 44 | 45 | def get_script_name(experiment): 46 | if experiment not in globals(): 47 | raise NotImplementedError 48 | return globals()[experiment].fname 49 | 50 | 51 | #### write experiments here 52 | class IFFT(): 53 | fname = 'train_model.py' 54 | @staticmethod 55 | def hparams(): 56 | grid = { 57 | 'domain': ['both', 'MIMIC', 'CXP'], 58 | 'use_pretrained': [True], 59 | 'seed': [0], 60 | 'target': ['race', 'gender', 'Pneumonia'], 61 | 'data_type': ['ifft'], 62 | 'model': ['densenet'], 63 | 'filter_type': ['low', 'high'], 64 | 'filter_thres': [1, 5, 10, 25, 50, 75, 100, 125, 150, 200], 65 | 'add_noise': [False, True], 66 | 'crop_patch_at_end': [False], 67 | 'augment': [False] 68 | } 69 | return combinations(grid) 70 | 71 | class Normal(): 72 | fname = 'train_model.py' 73 | @staticmethod 74 | def hparams(): 75 | grid = { 76 | 'domain': ['both', 'MIMIC', 'CXP'], 77 | 'use_pretrained': [True], 78 | 'seed': [0], 79 | 'target': ['race', 'gender', 'Pneumonia'], 80 | 'data_type': ['normal'], 81 | 'model': ['densenet', 'vision_transformer'], 82 | 'pixel_thres': [1.0, 0.6], 83 | 'augment': [True, False] 84 | } 85 | return combinations(grid) 86 | 87 | class NormalMIMIC(): 88 | fname = 'train_model.py' 89 | @staticmethod 90 | def hparams(): 91 | grid = { 92 | 'domain': ['MIMIC'], 93 | 'use_pretrained': [True], 94 | 'seed': [0], 95 | 'target': ['race', 'gender', 'Pneumonia'], 96 | 'data_type': ['normal'], 97 | 'model': ['densenet'], 98 | 'pixel_thres': [1.0], 99 | 'augment': [True, False] 100 | } 101 | return combinations(grid) 102 | 103 | class RepresentationNormalMIMIC(): 104 | fname = 'train_representation.py' 105 | @staticmethod 106 | def hparams(): 107 | grid = { 108 | 'domain': ['MIMIC'], 109 | 'use_pretrained': [True], 110 | 'seed': [0], 111 | 'target': ['race', 'gender', 'Pneumonia'], 112 | 'data_type': ['normal'], 113 | 'embed_dim': [128], 114 | 'model': ['densenet'], 115 | 'pixel_thres': [1.0], 116 | 'augment': [True, False], 117 | 'epochs': [15] 118 | } 119 | return combinations(grid) 120 | 121 | class PerturbMIMIC(): 122 | fname = 'perturb.py' 123 | @staticmethod 124 | def hparams(): 125 | grid = { 126 | 'domain': ['MIMIC'], 127 | 'plabel': ['race'], 128 | 'use_pretrained': [True], 129 | 'attack': ['fgsm', 'random'], 130 | 'input_dir': ['/scratch/hdd001/home/{}/projects/CXR_Bias/output/NormalMIMIC/'.format(os.environ.get('USER'))], 131 | 'seed': [0], 132 | 'data_type': ['normal'], 133 | 'model': ['densenet'], 134 | 'pixel_thres': [1.0], 135 | 'augment': [False] 136 | } 137 | return combinations(grid) 138 | 139 | class IFFTPatched(): 140 | fname = 'train_model.py' 141 | @staticmethod 142 | def hparams(): 143 | grid = { 144 | 'domain': ['MIMIC'], 145 | 'use_pretrained': [True], 146 | 'seed': [0], 147 | 'target': ['race', 'gender', 'Pneumonia'], 148 | 'data_type': ['ifft', 'normal'], 149 | 'model': ['densenet'], 150 | 'filter_type': ['high'], 151 | 'filter_thres': [100], 152 | 'patch_ind': list(range(9)), 153 | 'add_noise': [False], 154 | 'crop_patch_at_end': [False], 155 | 
'augment': [False], 156 | 'patched': ['patch', 'invpatch'] 157 | } 158 | return combinations(grid) 159 | 160 | 161 | class IFFTNotchBandpass(): 162 | fname = 'train_model.py' 163 | @staticmethod 164 | def hparams(): 165 | grid = { 166 | 'domain': ['MIMIC'], 167 | 'use_pretrained': [True], 168 | 'seed': [0], 169 | 'target': ['race', 'gender', 'Pneumonia'], 170 | 'data_type': ['ifft'], 171 | 'model': ['densenet'], 172 | 'filter_type': ['notch', 'bandpass'], 173 | 'filter_thres': [10, 25, 50, 75, 100, 125, 150], 174 | 'filter_thres2': [10, 25, 50, 75, 100, 125, 150], 175 | 'add_noise': [False], 176 | 'crop_patch_at_end': [False], 177 | 'augment': [False], 178 | } 179 | options = combinations(grid) 180 | 181 | return [i for i in options if i['filter_thres2'] > i['filter_thres']] 182 | -------------------------------------------------------------------------------- /training_code/frequency_training/train_representation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import sys 4 | import pickle 5 | import os 6 | import random 7 | from pathlib import Path 8 | import math 9 | import json 10 | 11 | from lib import models 12 | from lib.data import get_dataset 13 | from lib.train_helper import representation, train_triplet_step 14 | from lib.utils import EarlyStopping, save_checkpoint, has_checkpoint, load_checkpoint 15 | from lib.infinite_loader import InfiniteDataLoader 16 | import Constants 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | import torch.optim as optim 22 | from torch.optim import lr_scheduler 23 | from torch.utils.data import DataLoader, Subset 24 | 25 | from pytorch_metric_learning import miners 26 | from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator 27 | 28 | 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--exp_name', type=str, required = True) 31 | parser.add_argument('--domain', type=str, choices = ['MIMIC', 'CXP', 'both']) 32 | parser.add_argument('--target', type=str, choices = ['race', 'gender', 'Pneumonia']) 33 | parser.add_argument('--es_patience', type=int, default=7) # *val_freq steps 34 | parser.add_argument('--val_freq', type=int, default=100) 35 | parser.add_argument('--model', type = str, choices = ['densenet', 'vision_transformer'], default = 'densenet') 36 | parser.add_argument('--embed_dim', type = int, default=128) 37 | parser.add_argument('--data_type', type = str, choices = ['normal','fft','ifft']) 38 | parser.add_argument('--augment', action = 'store_true') 39 | parser.add_argument('--use_pretrained', action = 'store_true') 40 | parser.add_argument('--pixel_thres', type = float, default = 1.0) 41 | parser.add_argument('--debug', action = 'store_true') 42 | parser.add_argument('--batch_size', type=int, default=64) 43 | parser.add_argument('--lr', type=float, default=1e-4) 44 | parser.add_argument('--seed', type=int, default=42) 45 | parser.add_argument('--epochs', type=int, default=10) 46 | parser.add_argument('--output_dir', type=str) 47 | args = parser.parse_args() 48 | 49 | torch.manual_seed(args.seed) 50 | np.random.seed(args.seed) 51 | random.seed(args.seed) 52 | torch.backends.cudnn.deterministic = True 53 | 54 | output_dir = Path(args.output_dir) 55 | output_dir.mkdir(parents = True, exist_ok = True) 56 | 57 | with open(Path(output_dir)/'args.json', 'w') as outfile: 58 | json.dump(vars(args), outfile) 59 | 60 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
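# A minimal illustrative sketch of the metric-learning objective this script
# optimizes: lib.train_helper.train_triplet_step mines triplets from each batch of
# embeddings with pytorch_metric_learning's DistanceWeightedMiner and scores them
# with MarginLoss. The helper below (a hypothetical name) only restates that
# recipe for readability and is never called by the training loop.
def _sketch_triplet_margin_loss(embeddings, labels):
    from pytorch_metric_learning import losses  # `miners` is already imported above
    miner = miners.DistanceWeightedMiner()
    loss_func = losses.MarginLoss()
    triplets = miner.mine(embeddings, labels, embeddings, labels)
    return loss_func(embeddings, labels, triplets)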
61 | 62 | if args.target == 'race': 63 | classes = [Constants.race_mapping[i] for i in range(len(Constants.race_mapping))] 64 | elif args.target == 'gender': 65 | classes = [Constants.gender_mapping[i] for i in range(len(Constants.gender_mapping))] 66 | else: 67 | classes = ['neg', 'pos'] 68 | 69 | n_outputs = args.embed_dim 70 | 71 | if args.model == 'densenet': 72 | model = models.DenseNet(args.use_pretrained, n_outputs = n_outputs).to(device) 73 | elif args.model == 'vision_transformer': 74 | raise NotImplementedError("Vision transformer not currently supported on this branch.") 75 | 76 | print("Total parameters: " + str(sum(p.numel() for p in model.parameters()))) 77 | 78 | optimizer = optim.Adam(model.parameters(), lr = args.lr) 79 | lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode = 'min', patience=5, factor=0.1) 80 | 81 | if args.domain in ['MIMIC', 'CXP']: 82 | envs = [args.domain] 83 | elif args.domain == 'both': 84 | envs = ['MIMIC', 'CXP'] 85 | else: 86 | raise NotImplementedError 87 | 88 | common_args = { 89 | 'envs': envs, 90 | 'only_frontal': True, 91 | 'subset_label': args.target, 92 | 'output_type': args.data_type, 93 | 'imagenet_norm': not args.data_type == 'fft', 94 | 'augment': int(args.augment), 95 | } 96 | 97 | train_data = get_dataset(split = 'train', **common_args) 98 | val_data = get_dataset(split = 'val', **common_args) 99 | test_data = get_dataset(split = 'test', **common_args) 100 | test_data_aug = get_dataset(augment = 1, split = 'test', **{i:common_args[i] for i in common_args if i != 'augment'}) 101 | 102 | if args.debug: 103 | val_data = Subset(val_data, list(range(1024))) 104 | test_data = Subset(test_data, list(range(1024))) 105 | test_data_aug = Subset(test_data_aug, list(range(1024))) 106 | else: 107 | val_data = Subset(val_data, np.random.choice(np.arange(len(val_data)), min(1024*8, len(val_data)), replace = False)) 108 | 109 | es = EarlyStopping(patience = args.es_patience) 110 | batch_size = args.batch_size 111 | if args.debug: 112 | n_steps = 50 113 | else: 114 | n_steps = args.epochs * (len(train_data) // batch_size) 115 | 116 | train_loader = InfiniteDataLoader(train_data, batch_size=batch_size, num_workers = 1) 117 | validation_loader = DataLoader(val_data, batch_size=batch_size*2, shuffle=False) 118 | test_loader = DataLoader(test_data, batch_size=batch_size*2, shuffle=False) 119 | test_loader_aug = DataLoader(test_data_aug, batch_size=batch_size*2, shuffle=False) 120 | 121 | miner = miners.DistanceWeightedMiner() 122 | 123 | if has_checkpoint() and not args.debug: 124 | state = load_checkpoint() 125 | model.load_state_dict(state['model_dict']) 126 | optimizer.load_state_dict(state['optimizer_dict']) 127 | lr_scheduler.load_state_dict(state['scheduler_dict']) 128 | train_loader.sampler.load_state_dict(state['sampler_dict']) 129 | start_step = state['start_step'] 130 | es = state['es'] 131 | torch.random.set_rng_state(state['rng']) 132 | print("Loaded checkpoint at step %s" % start_step) 133 | else: 134 | start_step = 0 135 | 136 | for step in range(start_step, n_steps): 137 | if es.early_stop: 138 | break 139 | data, target, meta = next(iter(train_loader)) 140 | step_loss, step_precision = train_triplet_step(data, target, model, 141 | device, optimizer, miner) 142 | 143 | print('Train Step: {} Precision@1: {:.4f}\tLoss: {:.6f}'.format( 144 | step, step_precision, step_loss), flush = True) 145 | 146 | if step % args.val_freq == 0: 147 | total_loss, acc_dict, embedding_list, target_list = representation(model, device, validation_loader) 
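        # The validation loss just computed drives the three updates below:
        #   * lr_scheduler.step(total_loss): ReduceLROnPlateau cuts the learning rate by 10x
        #     after 5 validation evaluations without improvement;
        #   * es(...): EarlyStopping (lib/utils.py) saves the best weights to output_dir/model.pt
        #     and flags early stopping after `es_patience` evaluations without improvement;
        #   * save_checkpoint(...): when running under SLURM with a /checkpoint directory,
        #     persists model/optimizer/scheduler/sampler state so a preempted job can resume.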
148 | lr_scheduler.step(total_loss) 149 | es(total_loss, step , model.state_dict(), output_dir/'model.pt') 150 | 151 | save_checkpoint(model, optimizer, lr_scheduler, 152 | train_loader.sampler.state_dict(train_loader._infinite_iterator), 153 | step+1, es, torch.random.get_rng_state()) 154 | 155 | 156 | _, acc_dict, embedding_list, target_list = representation(model, device, test_loader) 157 | _, acc_dict_aug, embedding_list_aug, target_list_aug = representation(model, device, test_loader_aug) 158 | 159 | results = {} 160 | acc_calc = AccuracyCalculator() 161 | for m, embedding, target in zip(['unaug', 'aug'], [embedding_list, embedding_list_aug], [target_list, target_list_aug]): 162 | results[m] = {} 163 | for grp in np.unique(target): 164 | target_bin = target == grp 165 | embedding_bin = embedding[target_bin,:] 166 | results[m][f'metrics_{classes[grp]}'] = acc_calc.get_accuracy(embedding_bin, embedding_bin, target_bin, target_bin, embeddings_come_from_same_source=True) 167 | 168 | results[m]['targets'] = target 169 | 170 | if args.debug: 171 | print(results) 172 | 173 | pickle.dump(results, (output_dir/'results.pkl').open('wb')) 174 | 175 | with open(output_dir/'done', 'w') as f: 176 | f.write('done') 177 | 178 | -------------------------------------------------------------------------------- /training_code/frequency_training/train_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import sys 4 | import pickle 5 | import os 6 | import random 7 | from pathlib import Path 8 | import math 9 | import json 10 | 11 | from lib import models 12 | from lib.data import get_dataset 13 | from lib.train_helper import prediction, train_step 14 | from lib.utils import EarlyStopping, save_checkpoint, has_checkpoint, load_checkpoint 15 | from lib.infinite_loader import InfiniteDataLoader 16 | from sklearn.metrics import roc_auc_score 17 | import Constants 18 | 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | import torch.optim as optim 23 | from torch.optim import lr_scheduler 24 | from torch.utils.data import DataLoader, Subset 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--exp_name', type=str, required = True) 28 | parser.add_argument('--domain', type=str, choices = ['MIMIC', 'CXP', 'both']) 29 | parser.add_argument('--target', type=str, choices = ['race', 'gender', 'Pneumonia']) 30 | parser.add_argument('--es_patience', type=int, default=7) # *val_freq steps 31 | parser.add_argument('--val_freq', type=int, default=100) 32 | parser.add_argument('--model', type = str, choices = ['densenet', 'vision_transformer'], default = 'densenet') 33 | parser.add_argument('--data_type', type = str, choices = ['normal','fft','ifft'], 34 | help = '`normal`: train on orginal images, `fft`: train on frequency spectra (not used), `ifft` train on inverse transformed images') 35 | parser.add_argument('--patched', type = str, choices = ['patch', 'invpatch', 'none'], default = 'none', 36 | help = '`none`: train on the whole image, `patch`: train using only patch patch_ind, `invpatch`: set patch_ind to black, train on whole image') 37 | parser.add_argument('--patch_ind', type = int, choices = list(range(Constants.N_PATCHES)), default = None) 38 | parser.add_argument('--filter_type', type = str, choices = ['low', 'high', 'notch', 'bandpass']) 39 | parser.add_argument('--filter_thres', type = float) 40 | parser.add_argument('--filter_thres2', type = float) 41 | 
parser.add_argument('--augment', action = 'store_true', help = 'whether to use data augmentation') 42 | parser.add_argument('--use_pretrained', action = 'store_true') 43 | parser.add_argument('--pixel_thres', type = float, default = 1.0, help = 'intensity threshold for clipping pixels, 1.0 for no clipping') 44 | parser.add_argument('--crop_patch_at_end', action = 'store_true') 45 | parser.add_argument('--debug', action = 'store_true') 46 | parser.add_argument('--batch_size', type=int, default=64) 47 | parser.add_argument('--lr', type=float, default=1e-4) 48 | parser.add_argument('--seed', type=int, default=42) 49 | parser.add_argument('--epochs', type=int, default=10) 50 | parser.add_argument('--output_dir', type=str) 51 | args = parser.parse_args() 52 | 53 | torch.manual_seed(args.seed) 54 | np.random.seed(args.seed) 55 | random.seed(args.seed) 56 | torch.backends.cudnn.deterministic = True 57 | 58 | output_dir = Path(args.output_dir) 59 | output_dir.mkdir(parents = True, exist_ok = True) 60 | 61 | with open(Path(output_dir)/'args.json', 'w') as outfile: 62 | json.dump(vars(args), outfile) 63 | 64 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 65 | 66 | if args.target == 'race': 67 | classes = [Constants.race_mapping[i] for i in range(len(Constants.race_mapping))] 68 | elif args.target == 'gender': 69 | classes = [Constants.gender_mapping[i] for i in range(len(Constants.gender_mapping))] 70 | else: 71 | classes = ['neg', 'pos'] 72 | 73 | n_outputs = len(classes) 74 | 75 | if args.model == 'densenet': 76 | model = models.DenseNet(args.use_pretrained, n_outputs = n_outputs).to(device) 77 | elif args.model == 'vision_transformer': 78 | model = models.VisionTransformer(args.use_pretrained, n_outputs = n_outputs).to(device) 79 | 80 | print("Total parameters: " + str(sum(p.numel() for p in model.parameters()))) 81 | 82 | optimizer = optim.Adam(model.parameters(), lr = args.lr) 83 | lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode = 'min', patience=5, factor=0.1) 84 | 85 | if args.domain in ['MIMIC', 'CXP']: 86 | envs = [args.domain] 87 | elif args.domain == 'both': 88 | envs = ['MIMIC', 'CXP'] 89 | else: 90 | raise NotImplementedError 91 | 92 | common_args = { 93 | 'envs': envs, 94 | 'only_frontal': True, 95 | 'subset_label': args.target, 96 | 'output_type': args.data_type, 97 | 'imagenet_norm': not args.data_type == 'fft', 98 | 'augment': int(args.augment), 99 | 'ifft_filter': (args.filter_type, args.filter_thres, args.filter_thres2), 100 | 'pixel_thres': args.pixel_thres, 101 | 'crop_patch_at_end': args.crop_patch_at_end, 102 | 'patched': args.patched, 103 | 'patch_ind': args.patch_ind 104 | } 105 | 106 | train_data = get_dataset(split = 'train', **common_args) 107 | val_data = get_dataset(split = 'val', **common_args) 108 | test_data = get_dataset(split = 'test', **common_args) 109 | test_data_aug = get_dataset(augment = 1, split = 'test', **{i:common_args[i] for i in common_args if i != 'augment'}) 110 | 111 | if args.debug: 112 | val_data = Subset(val_data, list(range(1024))) 113 | test_data = Subset(test_data, list(range(1024))) 114 | test_data_aug = Subset(test_data_aug, list(range(1024))) 115 | else: 116 | val_data = Subset(val_data, np.random.choice(np.arange(len(val_data)), min(1024*8, len(val_data)), replace = False)) 117 | 118 | es = EarlyStopping(patience = args.es_patience) 119 | batch_size = args.batch_size 120 | if args.debug: 121 | n_steps = 50 122 | else: 123 | n_steps = args.epochs * (len(train_data) // batch_size) 124 | 125 | 
train_loader = InfiniteDataLoader(train_data, batch_size=batch_size, num_workers = 1) 126 | validation_loader = DataLoader(val_data, batch_size=batch_size*2, shuffle=False) 127 | test_loader = DataLoader(test_data, batch_size=batch_size*2, shuffle=False) 128 | test_loader_aug = DataLoader(test_data_aug, batch_size=batch_size*2, shuffle=False) 129 | 130 | if has_checkpoint() and not args.debug: 131 | state = load_checkpoint() 132 | model.load_state_dict(state['model_dict']) 133 | optimizer.load_state_dict(state['optimizer_dict']) 134 | lr_scheduler.load_state_dict(state['scheduler_dict']) 135 | train_loader.sampler.load_state_dict(state['sampler_dict']) 136 | start_step = state['start_step'] 137 | es = state['es'] 138 | torch.random.set_rng_state(state['rng']) 139 | print("Loaded checkpoint at step %s" % start_step) 140 | else: 141 | start_step = 0 142 | 143 | for step in range(start_step, n_steps): 144 | if es.early_stop: 145 | break 146 | data, target, meta = next(iter(train_loader)) 147 | step_loss, step_acc = train_step(data, target, model, 148 | device, optimizer) 149 | 150 | print('Train Step: {} Accuracy: {:.4f}\tLoss: {:.6f}'.format( 151 | step, step_acc, step_loss), flush = True) 152 | 153 | if step % args.val_freq == 0: 154 | total_loss, prob_list, pred_list, target_list = prediction(model, device, validation_loader) 155 | lr_scheduler.step(total_loss) 156 | es(total_loss, step , model.state_dict(), output_dir/'model.pt') 157 | 158 | save_checkpoint(model, optimizer, lr_scheduler, 159 | train_loader.sampler.state_dict(train_loader._infinite_iterator), 160 | step+1, es, torch.random.get_rng_state()) 161 | 162 | 163 | _, prob_list, _, target_list = prediction(model, device, test_loader) 164 | _, prob_list_aug, _, target_list_aug = prediction(model, device, test_loader_aug) 165 | 166 | results = {} 167 | for m, prob, target in zip(['unaug', 'aug'], [prob_list, prob_list_aug], [target_list, target_list_aug]): 168 | results[m] = {} 169 | for grp in np.unique(target): 170 | target_bin = target == grp 171 | pred = prob[:, grp] 172 | results[m][f'roc_{classes[grp]}'] = roc_auc_score(target_bin, pred) 173 | 174 | results[m]['preds'] = prob 175 | results[m]['targets'] = target 176 | 177 | if args.debug: 178 | print(results) 179 | 180 | pickle.dump(results, (output_dir/'results.pkl').open('wb')) 181 | 182 | with open(output_dir/'done', 'w') as f: 183 | f.write('done') 184 | -------------------------------------------------------------------------------- /training_code/frequency_training/perturb.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import sys 4 | import pickle 5 | import os 6 | import sys 7 | import random 8 | from pathlib import Path 9 | import math 10 | import json 11 | 12 | from lib import models 13 | from lib.data import get_dataset 14 | from lib.perturb_helper import fgsm_attack, random_attack 15 | from sklearn.metrics import roc_auc_score, accuracy_score 16 | import Constants 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | from torch.utils.data import DataLoader, Subset 22 | 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--exp_name', type=str, required = True) 26 | parser.add_argument('--domain', type=str, choices = ['MIMIC', 'CXP', 'both']) 27 | parser.add_argument('--plabel', type=str, choices = ['race', 'gender', 'Pneumonia']) 28 | parser.add_argument('--model', type = str, choices = ['densenet', 'vision_transformer'], 
default = 'densenet') 29 | parser.add_argument('--use_pretrained', action = 'store_true') 30 | parser.add_argument('--input_dir', type = str, required = True) 31 | parser.add_argument('--attack', type = str, choices = ['fgsm', 'random'], default = 'fgsm') 32 | parser.add_argument('--epsilons', nargs='+', type = float, default = np.linspace(0., 0.5, 10).tolist()) 33 | parser.add_argument('--data_type', type = str, choices = ['normal']) 34 | parser.add_argument('--patched', type = str, choices = ['patch', 'invpatch', 'none'], default = 'none') 35 | parser.add_argument('--patch_ind', type = int, choices = list(range(Constants.N_PATCHES)), default = None) 36 | parser.add_argument('--augment', action = 'store_true') 37 | parser.add_argument('--pixel_thres', type = float, default = 1.0) 38 | parser.add_argument('--debug', action = 'store_true') 39 | parser.add_argument('--batch_size', type=int, default=64) 40 | parser.add_argument('--seed', type=int, default=42) 41 | parser.add_argument('--output_dir', type=str, required=True) 42 | args = parser.parse_args() 43 | 44 | torch.manual_seed(args.seed) 45 | np.random.seed(args.seed) 46 | random.seed(args.seed) 47 | torch.backends.cudnn.deterministic = True 48 | 49 | output_dir = Path(args.output_dir) 50 | output_dir.mkdir(parents = True, exist_ok = True) 51 | 52 | with open(Path(output_dir)/'args.json', 'w') as outfile: 53 | json.dump(vars(args), outfile) 54 | 55 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 56 | 57 | config = {} 58 | for indir in os.listdir(Path(args.input_dir)): 59 | model_path = os.path.join(*[args.input_dir, indir, 'model.pt']) 60 | args_path = os.path.join(*[args.input_dir, indir, 'args.json']) 61 | if os.path.exists(model_path) and os.path.exists(args_path): 62 | with open(args_path, 'r') as fp: 63 | try: 64 | train_args = json.load(fp) 65 | except: 66 | print('Unexpected error encountered trying to load JSON {}: {}'.format(args_path, sys.exc_info()[0])) 67 | train_args = {} 68 | use = True 69 | for key in train_args: 70 | if key in vars(args) and key not in ['exp_name', 'output_dir', 'seed', 'debug', 'batch_size'] and vars(args)[key] != train_args[key]: 71 | use = False 72 | break 73 | if use and 'target' in train_args: 74 | config[train_args['target']] = model_path 75 | 76 | if args.debug: 77 | assert args.plabel in config, "Label with which to perturb model must exist as key in the config." 
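# Added comment: the loop above pairs each prediction target (race, gender, or a disease label such as
# Pneumonia) with a matching checkpoint from --input_dir by comparing its saved args.json against the
# current arguments; below, each matched checkpoint is restored into its own DenseNet and kept in eval
# mode so that the same clean or perturbed test image can be scored by every model.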
78 | 79 | models_dict = {} 80 | classes_dict = {} 81 | for key in config: 82 | 83 | if key == 'race': 84 | classes_dict[key] = [Constants.race_mapping[i] for i in range(len(Constants.race_mapping))] 85 | elif key == 'gender': 86 | classes_dict[key] = [Constants.gender_mapping[i] for i in range(len(Constants.gender_mapping))] 87 | else: 88 | classes_dict[key] = ['neg', 'pos'] 89 | 90 | n_outputs = len(classes_dict[key]) 91 | 92 | if args.model == 'densenet': 93 | model = models.DenseNet(args.use_pretrained, n_outputs = n_outputs).to(device) 94 | elif args.model == 'vision_transformer': 95 | raise NotImplementedError("Vision transformer not currently supported on this branch.") 96 | 97 | checkpoint_state_dict = torch.load(config[key]) 98 | model.load_state_dict(checkpoint_state_dict) 99 | 100 | model.eval() 101 | models_dict[key] = model 102 | 103 | print(f"Key {key}: Total parameters: " + str(sum(p.numel() for p in model.parameters()))) 104 | 105 | if args.domain in ['MIMIC', 'CXP']: 106 | envs = [args.domain] 107 | elif args.domain == 'both': 108 | envs = ['MIMIC', 'CXP'] 109 | else: 110 | raise NotImplementedError 111 | 112 | common_args = { 113 | 'envs': envs, 114 | 'only_frontal': True, 115 | 'subset_label': args.plabel, 116 | 'output_type': args.data_type, 117 | 'imagenet_norm': not args.data_type == 'fft', 118 | 'augment': int(args.augment), 119 | 'patched': args.patched, 120 | 'patch_ind': args.patch_ind 121 | } 122 | 123 | test_data = get_dataset(split = 'test', **common_args) 124 | test_data_aug = get_dataset(augment = 1, split = 'test', **{i:common_args[i] for i in common_args if i != 'augment'}) 125 | 126 | if args.debug: 127 | test_data = Subset(test_data, list(range(1024))) 128 | test_data_aug = Subset(test_data_aug, list(range(1024))) 129 | 130 | batch_size = args.batch_size 131 | 132 | test_loader = DataLoader(test_data, batch_size=1, shuffle=False) 133 | test_loader_aug = DataLoader(test_data_aug, batch_size=1, shuffle=False) 134 | 135 | if args.attack == 'fgsm': 136 | attack_func = fgsm_attack 137 | elif args.attack == 'random': 138 | attack_func = random_attack 139 | else: 140 | raise ValueError("Attack must be either 'fgsm' or 'random.'") 141 | 142 | auc_dict = {} 143 | acc_dict = {} 144 | 145 | for epsilon in args.epsilons: 146 | 147 | targets_dict = {key: [] for key in config} 148 | preds_dict = {key: [] for key in config} 149 | probs_dict = {key: [] for key in config} 150 | 151 | # Loop over all examples in test set 152 | for data, target, meta in test_loader: 153 | 154 | data, target = data.float().to(device), target.long().to(device) 155 | target = target.view((-1)) 156 | 157 | data.requires_grad = True 158 | 159 | output = models_dict[args.plabel](data) 160 | 161 | init_probs = F.softmax(output, dim=1) 162 | init_pred = init_probs.max(1, keepdim=True)[1] 163 | 164 | targets_dict[args.plabel].append(target.cpu().item()) 165 | for key in config: 166 | if key != args.plabel: 167 | if key == 'gender': 168 | targets_dict[key].append(meta['Sex'].item()) 169 | continue 170 | targets_dict[key].append(meta[key].item()) 171 | # If the initial prediction is wrong, dont bother attacking, just move on 172 | if init_pred.item() != target.item() or epsilon == 0.: 173 | probs_dict[args.plabel].append(init_probs.squeeze().detach().cpu().numpy()) 174 | preds_dict[args.plabel].append(init_pred.cpu().item()) 175 | for key in config: 176 | if key != args.plabel: 177 | with torch.no_grad(): 178 | output = models_dict[key](data) 179 | 180 | probs = F.softmax(output, dim=1) 181 | preds = 
probs.max(1, keepdim=True)[1] 182 | 183 | probs_dict[key].append(probs.squeeze().cpu().numpy()) 184 | preds_dict[key].append(preds.cpu().item()) 185 | continue 186 | 187 | # Get the loss 188 | loss = F.cross_entropy(output, target) 189 | 190 | # Zero all existing gradients 191 | models_dict[args.plabel].zero_grad() 192 | 193 | # Calculate gradients of model in backward pass 194 | loss.backward() 195 | 196 | # Collect datagrad 197 | data_grad = data.grad.data 198 | 199 | attack_dict = {'batch': data, 'epsilon': epsilon} 200 | if args.attack == 'fgsm': 201 | attack_dict['data_grads'] = data_grad 202 | 203 | # Call Attack 204 | perturbed_data = attack_func(**attack_dict) 205 | 206 | # Re-classify the perturbed image 207 | for key in config: 208 | with torch.no_grad(): 209 | perturbed_output = models_dict[key](perturbed_data) 210 | 211 | perturbed_probs = F.softmax(perturbed_output, dim=1) 212 | perturbed_pred = perturbed_probs.max(1, keepdim=True)[1] 213 | 214 | probs_dict[key].append(perturbed_probs.squeeze().cpu().numpy()) 215 | preds_dict[key].append(perturbed_pred.item()) 216 | 217 | auc_dict[epsilon] = {} 218 | for key in config: 219 | prob = np.stack(probs_dict[key]) 220 | target = np.hstack(targets_dict[key]) 221 | auc_dict[epsilon][key] = {} 222 | for grp in np.unique(target): 223 | target_bin = target == grp 224 | pred = prob[:, int(grp)] 225 | auc_dict[epsilon][key][f'roc_{classes_dict[key][int(grp)]}'] = roc_auc_score(target_bin, pred) 226 | acc_dict[epsilon] = {key: accuracy_score(targets_dict[key], preds_dict[key]) for key in config} 227 | 228 | #### RESULTS 229 | 230 | results = {'auc': auc_dict, 'acc': acc_dict} 231 | print(results) 232 | 233 | with open(output_dir/'results.json', 'w') as fp: 234 | json.dump(results, fp) 235 | 236 | with open(output_dir/'done', 'w') as f: 237 | f.write('done') 238 | -------------------------------------------------------------------------------- /training_code/frequency_training/lib/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | import Constants 6 | from lib import cxr_preprocess as preprocess 7 | import pandas as pd 8 | from torchvision import transforms 9 | import pickle 10 | from pathlib import Path 11 | from torch.utils.data import Dataset, ConcatDataset 12 | from lib.utils import fft, filter_and_ifft, draw_circle, split_tensor, blacken_tensor 13 | import math 14 | 15 | 16 | def get_dataset(envs = [], split = None, only_frontal = False, imagenet_norm = True, augment = 0, cache = True, subset_label = None, 17 | augmented_dfs = None, output_type = 'normal', ifft_filter = None, 18 | pixel_thres = None, crop_patch_at_end = False, patched = 'none', patch_ind = None): 19 | 20 | if augment == 1: # normal image augmentation 21 | image_transforms = [transforms.RandomHorizontalFlip(), 22 | transforms.RandomRotation(15), 23 | transforms.RandomResizedCrop(size = 224, scale = (0.75, 1.0)), 24 | transforms.ToTensor()] 25 | elif augment == 0: 26 | image_transforms = [transforms.ToTensor()] 27 | elif augment == -1: # only resize, just return a dataset with PIL images; don't ToTensor() 28 | image_transforms = [] 29 | 30 | if imagenet_norm and augment != -1: 31 | image_transforms.append(transforms.Normalize(Constants.IMAGENET_MEAN, Constants.IMAGENET_STD)) 32 | 33 | datasets = [] 34 | for e in envs: 35 | func = preprocess.get_process_func(e) 36 | paths = Constants.df_paths[e] 37 | 38 | if split is not None: 39 | splits = [split] 40 | 
else: 41 | splits = ['train', 'val', 'test'] 42 | 43 | if augmented_dfs is not None: # use provided dataframes for subsample augmentation 44 | dfs = [augmented_dfs[e][i] for i in splits] 45 | else: 46 | dfs = [func(pd.read_csv(paths[i]), only_frontal) for i in splits] 47 | 48 | for c, s in enumerate(splits): 49 | cache_dir = Path(Constants.cache_dir)/ f'{e}_{s}/' 50 | cache_dir.mkdir(parents=True, exist_ok=True) 51 | datasets.append(AllDatasetsShared(dfs[c], transform = transforms.Compose(image_transforms) 52 | , split = split, cache = cache, cache_dir = cache_dir, subset_label = subset_label, output_type = output_type, 53 | ifft_filter = ifft_filter, pixel_thres = pixel_thres, crop_patch_at_end = crop_patch_at_end, 54 | patched = patched, patch_ind = patch_ind)) 55 | 56 | if len(datasets) == 0: 57 | return None 58 | elif len(datasets) == 1: 59 | ds = datasets[0] 60 | else: 61 | ds = ConcatDataset(datasets) 62 | ds.dataframe = pd.concat([i.dataframe for i in datasets]) 63 | 64 | return ds 65 | 66 | class AllDatasetsShared(Dataset): 67 | def __init__(self, dataframe, transform=None, split = None, cache = True, cache_dir = '', subset_label = None, output_type = 'normal', ifft_filter = None, 68 | pixel_thres = None, crop_patch_at_end = False, patched = 'none', patch_ind = None): 69 | super().__init__() 70 | self.dataframe = dataframe 71 | self.dataset_size = self.dataframe.shape[0] 72 | self.transform = transform 73 | self.split = split 74 | self.cache = cache 75 | self.cache_dir = Path(cache_dir) 76 | self.subset_label = subset_label # (str) select one label instead of returning all Constants.take_labels 77 | self.output_type = output_type 78 | if ifft_filter is not None and ifft_filter[0] in ['low', 'high']: 79 | self.ifft_filter = ifft_filter[:2] 80 | else: 81 | self.ifft_filter = ifft_filter 82 | self.pixel_thres = pixel_thres 83 | self.crop_patch_at_end = crop_patch_at_end 84 | self.patched = patched 85 | self.patch_ind = patch_ind 86 | 87 | if self.output_type == 'ifft': 88 | if self.ifft_filter[0] in ['low','high']: 89 | self.filter = draw_circle(shape = (224, 224), diameter = self.ifft_filter[1]) 90 | if self.ifft_filter[0] == 'high': 91 | self.filter = ~self.filter 92 | elif self.ifft_filter[0] in ['bandpass', 'notch']: 93 | filter_outer = draw_circle(shape = (224, 224), diameter = self.ifft_filter[2]) 94 | filter_inner = draw_circle(shape = (224, 224), diameter = self.ifft_filter[1]) 95 | if self.ifft_filter[0] == 'notch': 96 | self.filter = ~(~filter_inner & filter_outer) 97 | else: 98 | self.filter = (~filter_inner & filter_outer) 99 | else: 100 | raise NotImplementedError 101 | 102 | def get_cache_path(self, cache_dir, meta): 103 | path = Path(meta['path']) 104 | if meta['env'] in ['PAD', 'NIH']: 105 | return cache_dir / (path.stem + '.pkl') 106 | elif meta['env'] in ['MIMIC', 'CXP']: 107 | return (cache_dir / '_'.join(path.parts[-3:])).with_suffix('.pkl') 108 | 109 | def __getitem__(self, idx): 110 | item = self.dataframe.iloc[idx] 111 | cache_path = self.get_cache_path(self.cache_dir, item) 112 | 113 | if self.cache and cache_path.is_file(): 114 | img, label, meta = pickle.load(cache_path.open('rb')) 115 | meta = item.to_dict() # override 116 | else: 117 | img = np.array(Image.open(item["path"])) 118 | 119 | if img.dtype == 'int32': 120 | img = np.uint8(img/(2**16)*255) 121 | elif img.dtype == 'bool': 122 | img = np.uint8(img) 123 | else: #uint8 124 | pass 125 | 126 | if len(img.shape) == 2: 127 | img = img[:, :, np.newaxis] 128 | img = np.concatenate([img, img, img], axis=2) 
129 | elif len(img.shape)>2: 130 | # print('different shape', img.shape, item) 131 | img = img[:,:,0] 132 | img = img[:, :, np.newaxis] 133 | img = np.concatenate([img, img, img], axis=2) 134 | 135 | img = Image.fromarray(img) 136 | resize_transform = transforms.Resize(size = [224, 224]) 137 | img = transforms.Compose([resize_transform])(img) 138 | 139 | label = torch.FloatTensor(np.zeros(len(Constants.take_labels), dtype=float)) 140 | for i in range(0, len(Constants.take_labels)): 141 | if (self.dataframe[Constants.take_labels[i].strip()].iloc[idx].astype('float') > 0): 142 | label[i] = self.dataframe[Constants.take_labels[i].strip()].iloc[idx].astype('float') 143 | 144 | meta = item.to_dict() 145 | 146 | if self.cache: 147 | pickle.dump((img, label, meta), cache_path.open('wb')) 148 | 149 | if self.transform is not None: # apply image augmentations after caching 150 | img = self.transform(img) 151 | 152 | if self.subset_label: 153 | if self.subset_label in Constants.take_labels: 154 | label = int(label[Constants.take_labels.index(self.subset_label)]) 155 | elif self.subset_label == 'gender': 156 | label = meta['Sex'] 157 | elif self.subset_label == 'race': 158 | label = meta['race'] 159 | elif self.subset_label == 'insurance': 160 | label = meta['insurance'] 161 | else: 162 | raise NotImplementedError 163 | 164 | if self.output_type == 'fft': 165 | img = img[0, :, :].float().numpy() 166 | spectra = fft(img) 167 | img = torch.from_numpy(np.stack([spectra, spectra, spectra])) 168 | 169 | elif self.output_type == 'ifft': 170 | img = img[0, :, :].float().numpy() 171 | spectra = np.fft.fftshift(np.fft.fft2(img)) 172 | img = filter_and_ifft(spectra, self.filter) 173 | img = torch.from_numpy(np.stack([img, img, img])) 174 | else: 175 | assert self.output_type == 'normal' 176 | 177 | if self.patched != 'none' and self.patched is not None: 178 | assert(not self.crop_patch_at_end) 179 | # assert(0 <= self.patch_ind < Constants.N_PATCHES) 180 | if self.patched == 'patch': 181 | img = split_tensor(img, tile_size=int(224/math.sqrt(Constants.N_PATCHES)), offset=int(224/math.sqrt(Constants.N_PATCHES)))[0][self.patch_ind] 182 | img = transforms.Resize((224, 224))(img) 183 | elif self.patched == 'invpatch': 184 | img = blacken_tensor(img, self.patch_ind, tile_size=int(224/math.sqrt(Constants.N_PATCHES)), offset=int(224/math.sqrt(Constants.N_PATCHES))) 185 | else: 186 | raise NotImplementedError 187 | 188 | if self.crop_patch_at_end: 189 | img = transforms.RandomResizedCrop(size = 224, scale = (0.2, 0.2), ratio = (1.0, 1.0))(img) 190 | 191 | if self.pixel_thres is not None: 192 | img[img > self.pixel_thres] = self.pixel_thres 193 | 194 | return img, label, meta 195 | 196 | def __len__(self): 197 | return self.dataset_size 198 | -------------------------------------------------------------------------------- /training_code/CXR_training/Emory_CXR/Emory_CXR_resnet34_race_detection_2021_06_29.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "\n", 10 | "import os\n", 11 | "import copy\n", 12 | "import sys\n", 13 | "import math\n", 14 | "from datetime import datetime\n", 15 | "import random\n", 16 | "import numpy as np\n", 17 | "import pandas as pd\n", 18 | "import matplotlib.pyplot as plt\n", 19 | "from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve, precision_recall_curve, matthews_corrcoef, auc, accuracy_score, 
recall_score, precision_score, f1_score\n", 20 | "from sklearn.utils import shuffle\n", 21 | "import tensorflow as tf\n", 22 | "from tensorflow.keras import backend as K\n", 23 | "from tensorflow.keras.optimizers import Adam,RMSprop,SGD\n", 24 | "from tensorflow.keras import layers\n", 25 | "from tensorflow.keras.layers import concatenate, add, GlobalAveragePooling2D, BatchNormalization, Input, Dense\n", 26 | "from tensorflow.keras.models import Model\n", 27 | "from tensorflow.keras import initializers\n", 28 | "from tensorflow.keras.callbacks import EarlyStopping, LearningRateScheduler, ReduceLROnPlateau, ModelCheckpoint\n", 29 | "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n", 30 | "from tensorflow.keras.preprocessing.image import load_img\n", 31 | "from tensorflow.keras import optimizers\n", 32 | "from tensorflow.keras.models import Sequential\n", 33 | "from tensorflow.keras.applications.densenet import DenseNet121\n", 34 | "from classification_models.tfkeras import Classifiers\n", 35 | "from tensorflow.keras.models import load_model\n", 36 | "from PIL import ImageFile\n", 37 | "import random as python_random\n", 38 | "\n" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 2, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "\n", 48 | "np.random.seed(2021)\n", 49 | "python_random.seed(2021)\n", 50 | "tf.random.set_seed(2021)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 3, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "import os\n", 60 | "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"4\"\n" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 4, 66 | "metadata": {}, 67 | "outputs": [ 68 | { 69 | "name": "stdout", 70 | "output_type": "stream", 71 | "text": [ 72 | "2.0.0\n" 73 | ] 74 | } 75 | ], 76 | "source": [ 77 | "print(tf.__version__)" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 5, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "ImageFile.LOAD_TRUNCATED_IMAGES = True\n" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 6, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "train_df = pd.read_csv('train_censored.csv')\n", 96 | "validate_df = pd.read_csv('val_censored.csv')\n", 97 | "test_df = pd.read_csv('test_censored.csv')" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 7, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "train_df = train_df[train_df.Race.isin(['ASIAN','BLACK/AFRICAN AMERICAN','WHITE'])]\n", 107 | "validate_df = validate_df[validate_df.Race.isin(['ASIAN','BLACK/AFRICAN AMERICAN','WHITE'])]\n", 108 | "test_df = test_df[test_df.Race.isin(['ASIAN','BLACK/AFRICAN AMERICAN','WHITE'])]" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 8, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "train_df.hiti_path = train_df.hiti_path.astype(str)\n", 118 | "validate_df.hiti_path = validate_df.hiti_path.astype(str)\n", 119 | "test_df.hiti_path = test_df.hiti_path.astype(str)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 9, 125 | "metadata": {}, 126 | "outputs": [ 127 | { 128 | "data": { 129 | "text/plain": [ 130 | "184974" 131 | ] 132 | }, 133 | "execution_count": 9, 134 | "metadata": {}, 135 | "output_type": "execute_result" 136 | } 137 | ], 138 | "source": [ 139 | "len(train_df)" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 10, 145 | 
"metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "#remove 0 byte images\n", 149 | "validate_df = validate_df[~validate_df.hiti_path.str.contains('406e0996e5f1cf082487d7d096574d10b46c0c52710222a4884db1cc|dd97e997cc2a4166dc6e192cb62e29553aa28f4671d98c9577e32cfd|6224290209c45bb2b3e07b3b3a27778d1d10f7953567b3c59158e099')]\n", 150 | "test_df = test_df[~test_df.hiti_path.str.contains('406e0996e5f1cf082487d7d096574d10b46c0c52710222a4884db1cc|dd97e997cc2a4166dc6e192cb62e29553aa28f4671d98c9577e32cfd|6224290209c45bb2b3e07b3b3a27778d1d10f7953567b3c59158e099')]\n", 151 | "train_df = train_df[~train_df.hiti_path.str.contains('406e0996e5f1cf082487d7d096574d10b46c0c52710222a4884db1cc|dd97e997cc2a4166dc6e192cb62e29553aa28f4671d98c9577e32cfd|6224290209c45bb2b3e07b3b3a27778d1d10f7953567b3c59158e099')]" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 11, 157 | "metadata": {}, 158 | "outputs": [ 159 | { 160 | "data": { 161 | "text/plain": [ 162 | "WHITE 91369\n", 163 | "BLACK/AFRICAN AMERICAN 87139\n", 164 | "ASIAN 6457\n", 165 | "Name: Race, dtype: int64" 166 | ] 167 | }, 168 | "execution_count": 11, 169 | "metadata": {}, 170 | "output_type": "execute_result" 171 | } 172 | ], 173 | "source": [ 174 | "train_df.Race.value_counts()" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 12, 180 | "metadata": {}, 181 | "outputs": [ 182 | { 183 | "data": { 184 | "text/plain": [ 185 | "BLACK/AFRICAN AMERICAN 7540\n", 186 | "WHITE 6656\n", 187 | "ASIAN 530\n", 188 | "Name: Race, dtype: int64" 189 | ] 190 | }, 191 | "execution_count": 12, 192 | "metadata": {}, 193 | "output_type": "execute_result" 194 | } 195 | ], 196 | "source": [ 197 | "validate_df.Race.value_counts()" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 13, 203 | "metadata": {}, 204 | "outputs": [ 205 | { 206 | "data": { 207 | "text/plain": [ 208 | "BLACK/AFRICAN AMERICAN 6067\n", 209 | "WHITE 5281\n", 210 | "ASIAN 484\n", 211 | "Name: Race, dtype: int64" 212 | ] 213 | }, 214 | "execution_count": 13, 215 | "metadata": {}, 216 | "output_type": "execute_result" 217 | } 218 | ], 219 | "source": [ 220 | "test_df.Race.value_counts()" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 14, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "HEIGHT, WIDTH = 320, 320" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 15, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "arc_name = \"Emory_CXR-\" + str(HEIGHT) + \"x\" + str(WIDTH) + \"resnet34-Float32_3-race_\"" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 16, 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "from tensorflow.keras.mixed_precision import experimental as mixed_precision\n", 248 | "policy = mixed_precision.Policy('mixed_float16')\n", 249 | "\n", 250 | "mixed_precision.set_policy(policy)" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 17, 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [ 259 | "resnet34, preprocess_input = Classifiers.get('resnet34')\n" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 18, 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "input_a = Input(shape=(HEIGHT, WIDTH, 3))\n", 269 | "base_model = resnet34(input_tensor=input_a, include_top=False, input_shape=(HEIGHT,WIDTH,3), weights='imagenet')\n", 270 | "x = GlobalAveragePooling2D()(base_model.output)\n", 271 | "x 
= layers.Dense(3, name='dense_logits')(x)\n", 272 | "output = layers.Activation('softmax', dtype='float32', name='predictions')(x)\n", 273 | "model = Model(inputs=[input_a], outputs=[output])" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 19, 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "learning_rate = 1e-3\n", 283 | "decay_val= 0.0\n", 284 | "batch_s = 256\n", 285 | "desired_epoch = 3\n", 286 | "train_batch_size = batch_s\n", 287 | "test_batch_size = 64" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": 20, 293 | "metadata": {}, 294 | "outputs": [], 295 | "source": [ 296 | "reduce_lr = ReduceLROnPlateau(monitor='val_loss', mode='min', factor=0.1,\n", 297 | " patience=2, min_lr=1e-5, verbose=1)" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": 21, 303 | "metadata": {}, 304 | "outputs": [], 305 | "source": [ 306 | "adam_opt = optimizers.Adam(lr=learning_rate, decay=decay_val)\n" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": 22, 312 | "metadata": {}, 313 | "outputs": [], 314 | "source": [ 315 | "model.compile(optimizer=adam_opt,\n", 316 | " loss=tf.losses.CategoricalCrossentropy(),\n", 317 | " metrics=[\n", 318 | " tf.keras.metrics.AUC(curve='ROC', name='ROC-AUC'),\n", 319 | " tf.keras.metrics.AUC(curve='PR', name='PR-AUC')\n", 320 | " ],\n", 321 | ")" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": 23, 327 | "metadata": {}, 328 | "outputs": [], 329 | "source": [ 330 | "train_gen = ImageDataGenerator(\n", 331 | " rotation_range=15, \n", 332 | " fill_mode='constant',\n", 333 | " horizontal_flip=True,\n", 334 | " zoom_range=0.1,\n", 335 | " preprocessing_function=preprocess_input\n", 336 | ")\n", 337 | "\n", 338 | "validate_gen = ImageDataGenerator(preprocessing_function=preprocess_input)" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": 24, 344 | "metadata": {}, 345 | "outputs": [ 346 | { 347 | "name": "stdout", 348 | "output_type": "stream", 349 | "text": [ 350 | "Found 184965 validated image filenames belonging to 3 classes.\n", 351 | "Found 14726 validated image filenames belonging to 3 classes.\n" 352 | ] 353 | } 354 | ], 355 | "source": [ 356 | "train_batches = train_gen.flow_from_dataframe(train_df, x_col=\"hiti_path\", y_col=\"Race\", class_mode=\"categorical\",target_size=(HEIGHT, WIDTH),shuffle=True,seed=2021,batch_size=train_batch_size, dtype='float32')\n", 357 | "\n", 358 | "validate_batches = validate_gen.flow_from_dataframe(validate_df,x_col=\"hiti_path\", y_col=\"Race\", class_mode=\"categorical\",target_size=(HEIGHT, WIDTH),shuffle=False,batch_size=test_batch_size, dtype='float32') \n", 359 | "\n" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": 25, 365 | "metadata": {}, 366 | "outputs": [], 367 | "source": [ 368 | "train_epoch = math.ceil(len(train_df) / train_batch_size)\n", 369 | "val_epoch = math.ceil(len(validate_df) / test_batch_size)" 370 | ] 371 | }, 372 | { 373 | "cell_type": "code", 374 | "execution_count": 26, 375 | "metadata": {}, 376 | "outputs": [], 377 | "source": [ 378 | "var_date = datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n", 379 | "ES = EarlyStopping(monitor='val_loss', mode='min', patience=4, restore_best_weights=True)\n", 380 | "checkloss = ModelCheckpoint(\"../saved_models/\" + str(arc_name) + \"_LR-\" + str(learning_rate) + \"_\" + var_date+\"_epoch:{epoch:03d}_val_loss:{val_loss:.5f}.hdf5\", monitor='val_loss', mode='min', 
verbose=1, save_best_only=True, save_weights_only=False)\n" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": 27, 386 | "metadata": {}, 387 | "outputs": [ 388 | { 389 | "name": "stdout", 390 | "output_type": "stream", 391 | "text": [ 392 | "Epoch 1/100\n", 393 | "722/723 [============================>.] - ETA: 12s - loss: 0.4047 - ROC-AUC: 0.9553 - PR-AUC: 0.9212\n", 394 | "Epoch 00001: val_loss improved from inf to 0.30497, saving model to ../saved_models/Emory_CXR-320x320resnet34-Float32_3-race__LR-0.001_20210627-214820_epoch:001_val_loss:0.30497.hdf5\n", 395 | "723/723 [==============================] - 10084s 14s/step - loss: 0.4045 - ROC-AUC: 0.9553 - PR-AUC: 0.9213 - val_loss: 0.3050 - val_ROC-AUC: 0.9739 - val_PR-AUC: 0.9548\n", 396 | "Epoch 2/100\n", 397 | "722/723 [============================>.] - ETA: 11s - loss: 0.2760 - ROC-AUC: 0.9784 - PR-AUC: 0.9613\n", 398 | "Epoch 00002: val_loss improved from 0.30497 to 0.22691, saving model to ../saved_models/Emory_CXR-320x320resnet34-Float32_3-race__LR-0.001_20210627-214820_epoch:002_val_loss:0.22691.hdf5\n", 399 | "723/723 [==============================] - 9132s 13s/step - loss: 0.2760 - ROC-AUC: 0.9784 - PR-AUC: 0.9613 - val_loss: 0.2269 - val_ROC-AUC: 0.9853 - val_PR-AUC: 0.9739\n", 400 | "Epoch 3/100\n", 401 | "722/723 [============================>.] - ETA: 11s - loss: 0.2413 - ROC-AUC: 0.9832 - PR-AUC: 0.9696\n", 402 | "Epoch 00003: val_loss did not improve from 0.22691\n", 403 | "723/723 [==============================] - 9356s 13s/step - loss: 0.2413 - ROC-AUC: 0.9832 - PR-AUC: 0.9696 - val_loss: 0.2996 - val_ROC-AUC: 0.9760 - val_PR-AUC: 0.9574\n", 404 | "Epoch 4/100\n", 405 | "722/723 [============================>.] - ETA: 10s - loss: 0.2202 - ROC-AUC: 0.9857 - PR-AUC: 0.9741\n", 406 | "Epoch 00004: ReduceLROnPlateau reducing learning rate to 0.00010000000474974513.\n", 407 | "\n", 408 | "Epoch 00004: val_loss did not improve from 0.22691\n", 409 | "723/723 [==============================] - 8650s 12s/step - loss: 0.2202 - ROC-AUC: 0.9858 - PR-AUC: 0.9741 - val_loss: 0.2553 - val_ROC-AUC: 0.9821 - val_PR-AUC: 0.9687\n", 410 | "Epoch 5/100\n", 411 | "722/723 [============================>.] - ETA: 12s - loss: 0.1675 - ROC-AUC: 0.9911 - PR-AUC: 0.9837\n", 412 | "Epoch 00005: val_loss improved from 0.22691 to 0.17858, saving model to ../saved_models/Emory_CXR-320x320resnet34-Float32_3-race__LR-0.001_20210627-214820_epoch:005_val_loss:0.17858.hdf5\n", 413 | "723/723 [==============================] - 9727s 13s/step - loss: 0.1675 - ROC-AUC: 0.9911 - PR-AUC: 0.9837 - val_loss: 0.1786 - val_ROC-AUC: 0.9902 - val_PR-AUC: 0.9826\n", 414 | "Epoch 6/100\n", 415 | "722/723 [============================>.] - ETA: 13s - loss: 0.1521 - ROC-AUC: 0.9925 - PR-AUC: 0.9861\n", 416 | "Epoch 00006: val_loss improved from 0.17858 to 0.16077, saving model to ../saved_models/Emory_CXR-320x320resnet34-Float32_3-race__LR-0.001_20210627-214820_epoch:006_val_loss:0.16077.hdf5\n", 417 | "723/723 [==============================] - 10374s 14s/step - loss: 0.1521 - ROC-AUC: 0.9925 - PR-AUC: 0.9861 - val_loss: 0.1608 - val_ROC-AUC: 0.9916 - val_PR-AUC: 0.9851\n", 418 | "Epoch 7/100\n", 419 | "722/723 [============================>.] 
- ETA: 13s - loss: 0.1443 - ROC-AUC: 0.9930 - PR-AUC: 0.9870\n", 420 | "Epoch 00007: val_loss did not improve from 0.16077\n", 421 | "723/723 [==============================] - 10532s 15s/step - loss: 0.1444 - ROC-AUC: 0.9930 - PR-AUC: 0.9870 - val_loss: 0.1687 - val_ROC-AUC: 0.9909 - val_PR-AUC: 0.9840\n", 422 | "Epoch 8/100\n", 423 | "722/723 [============================>.] - ETA: 9s - loss: 0.1390 - ROC-AUC: 0.9935 - PR-AUC: 0.9879 \n", 424 | "Epoch 00008: ReduceLROnPlateau reducing learning rate to 1.0000000474974514e-05.\n", 425 | "\n", 426 | "Epoch 00008: val_loss did not improve from 0.16077\n", 427 | "723/723 [==============================] - 7289s 10s/step - loss: 0.1389 - ROC-AUC: 0.9935 - PR-AUC: 0.9879 - val_loss: 0.1789 - val_ROC-AUC: 0.9901 - val_PR-AUC: 0.9825\n", 428 | "Epoch 9/100\n", 429 | "722/723 [============================>.] - ETA: 8s - loss: 0.1294 - ROC-AUC: 0.9942 - PR-AUC: 0.9892 \n", 430 | "Epoch 00009: val_loss did not improve from 0.16077\n", 431 | "723/723 [==============================] - 6371s 9s/step - loss: 0.1294 - ROC-AUC: 0.9942 - PR-AUC: 0.9892 - val_loss: 0.1657 - val_ROC-AUC: 0.9913 - val_PR-AUC: 0.9846\n", 432 | "Epoch 10/100\n", 433 | "722/723 [============================>.] - ETA: 7s - loss: 0.1261 - ROC-AUC: 0.9944 - PR-AUC: 0.9896 \n", 434 | "Epoch 00010: ReduceLROnPlateau reducing learning rate to 1e-05.\n", 435 | "\n", 436 | "Epoch 00010: val_loss did not improve from 0.16077\n", 437 | "723/723 [==============================] - 6292s 9s/step - loss: 0.1262 - ROC-AUC: 0.9944 - PR-AUC: 0.9896 - val_loss: 0.1643 - val_ROC-AUC: 0.9914 - val_PR-AUC: 0.9848\n" 438 | ] 439 | }, 440 | { 441 | "data": { 442 | "text/plain": [ 443 | "" 444 | ] 445 | }, 446 | "execution_count": 27, 447 | "metadata": {}, 448 | "output_type": "execute_result" 449 | } 450 | ], 451 | "source": [ 452 | "model.fit_generator(\n", 453 | " train_batches, \n", 454 | " steps_per_epoch=train_epoch,\n", 455 | " initial_epoch=0,\n", 456 | " epochs=100, \n", 457 | " verbose=1, \n", 458 | " callbacks=[reduce_lr, checkloss, ES],\n", 459 | " validation_data=validate_batches, \n", 460 | " validation_steps=val_epoch, \n", 461 | " validation_freq=1,\n", 462 | " class_weight=None,\n", 463 | " max_queue_size=10,\n", 464 | " workers=32,\n", 465 | " use_multiprocessing=False,\n", 466 | " shuffle=True\n", 467 | ")\n" 468 | ] 469 | }, 470 | { 471 | "cell_type": "code", 472 | "execution_count": 28, 473 | "metadata": {}, 474 | "outputs": [], 475 | "source": [ 476 | "test_batch_size = 32" 477 | ] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "execution_count": 29, 482 | "metadata": {}, 483 | "outputs": [ 484 | { 485 | "name": "stdout", 486 | "output_type": "stream", 487 | "text": [ 488 | "Found 11832 validated image filenames belonging to 3 classes.\n" 489 | ] 490 | } 491 | ], 492 | "source": [ 493 | "test_batches = validate_gen.flow_from_dataframe(test_df,x_col=\"hiti_path\", y_col=\"Race\", class_mode=\"categorical\",target_size=(HEIGHT, WIDTH),shuffle=False,batch_size=test_batch_size, dtype='float32') \n" 494 | ] 495 | }, 496 | { 497 | "cell_type": "code", 498 | "execution_count": 30, 499 | "metadata": {}, 500 | "outputs": [ 501 | { 502 | "name": "stdout", 503 | "output_type": "stream", 504 | "text": [ 505 | "370/370 [==============================] - 1057s 3s/step\n" 506 | ] 507 | } 508 | ], 509 | "source": [ 510 | "multilabel_predict_test = model.predict(test_batches, max_queue_size=10, verbose=1, steps=math.ceil(len(test_df)/test_batch_size), workers=16)\n" 511 | ] 512 | }, 513 | { 
514 | "cell_type": "code", 515 | "execution_count": 31, 516 | "metadata": {}, 517 | "outputs": [], 518 | "source": [ 519 | "input_prediction = multilabel_predict_test\n", 520 | "input_df = test_df\n", 521 | "input_prediction_df = pd.DataFrame(input_prediction)\n", 522 | "true_logits = pd.DataFrame()\n", 523 | "loss_log = pd.DataFrame()\n", 524 | "#input_prediction_df = np.transpose(input_prediction_df)" 525 | ] 526 | }, 527 | { 528 | "cell_type": "code", 529 | "execution_count": 32, 530 | "metadata": {}, 531 | "outputs": [], 532 | "source": [ 533 | "def stat_calc(input_prediction_df, input_df):\n", 534 | " ground_truth = input_df.Race\n", 535 | " #ground_truth = ground_truth.apply(', '.join)\n", 536 | " pathology_array=[\n", 537 | " 'ASIAN',\n", 538 | " 'BLACK/AFRICAN AMERICAN',\n", 539 | " 'WHITE'\n", 540 | " ]\n", 541 | "\n", 542 | " i=0\n", 543 | " auc_array = []\n", 544 | " for pathology in pathology_array:\n", 545 | " \n", 546 | " new_truth = (ground_truth.str.contains(pathology)).apply(int)\n", 547 | " input_prediction_val = input_prediction_df[i]\n", 548 | " val = input_prediction_val\n", 549 | " AUC = roc_auc_score(new_truth, val)\n", 550 | " true_logits.insert(i, i, new_truth, True)\n", 551 | " auc_array.append(AUC)\n", 552 | " i += 1\n", 553 | " \n", 554 | " progress_df = pd.DataFrame({'Study':pathology_array, 'AUC':auc_array})\n", 555 | " print(progress_df)\n" 556 | ] 557 | }, 558 | { 559 | "cell_type": "code", 560 | "execution_count": 33, 561 | "metadata": {}, 562 | "outputs": [ 563 | { 564 | "name": "stdout", 565 | "output_type": "stream", 566 | "text": [ 567 | " Study AUC\n", 568 | "0 ASIAN 0.969191\n", 569 | "1 BLACK/AFRICAN AMERICAN 0.992430\n", 570 | "2 WHITE 0.987709\n" 571 | ] 572 | } 573 | ], 574 | "source": [ 575 | "stat_calc(input_prediction_df, input_df)" 576 | ] 577 | }, 578 | { 579 | "cell_type": "code", 580 | "execution_count": null, 581 | "metadata": {}, 582 | "outputs": [], 583 | "source": [] 584 | } 585 | ], 586 | "metadata": { 587 | "kernelspec": { 588 | "display_name": "Python 3", 589 | "language": "python", 590 | "name": "python3" 591 | }, 592 | "language_info": { 593 | "codemirror_mode": { 594 | "name": "ipython", 595 | "version": 3 596 | }, 597 | "file_extension": ".py", 598 | "mimetype": "text/x-python", 599 | "name": "python", 600 | "nbconvert_exporter": "python", 601 | "pygments_lexer": "ipython3", 602 | "version": "3.7.3" 603 | } 604 | }, 605 | "nbformat": 4, 606 | "nbformat_minor": 4 607 | } 608 | -------------------------------------------------------------------------------- /training_code/EM-CS_training/Emory_C-spine_race_detection_2021_06_29.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "\n", 10 | "import os\n", 11 | "import copy\n", 12 | "import sys\n", 13 | "import math\n", 14 | "from datetime import datetime\n", 15 | "import random\n", 16 | "import numpy as np\n", 17 | "import pandas as pd\n", 18 | "import matplotlib.pyplot as plt\n", 19 | "from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve, precision_recall_curve, matthews_corrcoef, auc, accuracy_score, recall_score, precision_score, f1_score\n", 20 | "from sklearn.utils import shuffle\n", 21 | "import tensorflow as tf\n", 22 | "from tensorflow.keras import backend as K\n", 23 | "from tensorflow.keras.optimizers import Adam,RMSprop,SGD\n", 24 | "from tensorflow.keras import layers\n", 25 | "from 
tensorflow.keras.layers import concatenate, add, GlobalAveragePooling2D, BatchNormalization, Input, Dense\n", 26 | "from tensorflow.keras.models import Model\n", 27 | "from tensorflow.keras import initializers\n", 28 | "from tensorflow.keras.callbacks import EarlyStopping, LearningRateScheduler, ReduceLROnPlateau, ModelCheckpoint\n", 29 | "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n", 30 | "from tensorflow.keras.preprocessing.image import load_img\n", 31 | "from tensorflow.keras import optimizers\n", 32 | "from tensorflow.keras.models import Sequential\n", 33 | "#from tensorflow.keras.applications.densenet import DenseNet121\n", 34 | "from classification_models.tfkeras import Classifiers\n", 35 | "from tensorflow.keras.models import load_model\n", 36 | "import random as python_random\n" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 2, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "\n", 46 | "np.random.seed(2021)\n", 47 | "python_random.seed(2021)\n", 48 | "tf.random.set_seed(2021)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 3, 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "name": "stdout", 58 | "output_type": "stream", 59 | "text": [ 60 | "2.0.0\n" 61 | ] 62 | } 63 | ], 64 | "source": [ 65 | "print(tf.__version__)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 4, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "import os\n", 75 | "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"3\"" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 5, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "compare_df = pd.read_csv('cspin_split_80-10-10.csv')" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 6, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "data_df = pd.read_csv('cspin_split_80-10-10_ver_C.csv')" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 7, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "compare_df = compare_df[compare_df.ViewPosition.isin(['LATERAL','LATERAL FLEX','LATERAL EXT'])]" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 8, 108 | "metadata": {}, 109 | "outputs": [ 110 | { 111 | "data": { 112 | "text/plain": [ 113 | "13287" 114 | ] 115 | }, 116 | "execution_count": 8, 117 | "metadata": {}, 118 | "output_type": "execute_result" 119 | } 120 | ], 121 | "source": [ 122 | "len(compare_df)" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 9, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "split = data_df.Image.str.split(\"/\", n=6, expand=True)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 10, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "data_df.Image = \"../data/cspine_hardware/new_extract/\" + split[6]" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 11, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "compare_df = compare_df.rename(columns={\"png_path\": \"Image\"})" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 12, 155 | "metadata": {}, 156 | "outputs": [ 157 | { 158 | "data": { 159 | "text/plain": [ 160 | "13287" 161 | ] 162 | }, 163 | "execution_count": 12, 164 | "metadata": {}, 165 | "output_type": "execute_result" 166 | } 167 | ], 168 | "source": [ 169 | "len(compare_df)" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 
174 | "execution_count": 13, 175 | "metadata": {}, 176 | "outputs": [ 177 | { 178 | "data": { 179 | "text/plain": [ 180 | "19118" 181 | ] 182 | }, 183 | "execution_count": 13, 184 | "metadata": {}, 185 | "output_type": "execute_result" 186 | } 187 | ], 188 | "source": [ 189 | "len(data_df)" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 14, 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "total_df = data_df.merge(compare_df, on='Image')" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 15, 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "total_df = total_df.drop_duplicates(subset=\"Image\")" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 16, 213 | "metadata": {}, 214 | "outputs": [ 215 | { 216 | "data": { 217 | "text/plain": [ 218 | "10729" 219 | ] 220 | }, 221 | "execution_count": 16, 222 | "metadata": {}, 223 | "output_type": "execute_result" 224 | } 225 | ], 226 | "source": [ 227 | "len(total_df)" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 17, 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "data_df = total_df" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": 18, 242 | "metadata": {}, 243 | "outputs": [], 244 | "source": [ 245 | "data_df = data_df[data_df.Race.isin(['African American or Black', 'Caucasian or White'])]" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 19, 251 | "metadata": {}, 252 | "outputs": [ 253 | { 254 | "data": { 255 | "text/plain": [ 256 | "10358" 257 | ] 258 | }, 259 | "execution_count": 19, 260 | "metadata": {}, 261 | "output_type": "execute_result" 262 | } 263 | ], 264 | "source": [ 265 | "len(data_df)" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 20, 271 | "metadata": {}, 272 | "outputs": [ 273 | { 274 | "data": { 275 | "text/plain": [ 276 | "980" 277 | ] 278 | }, 279 | "execution_count": 20, 280 | "metadata": {}, 281 | "output_type": "execute_result" 282 | } 283 | ], 284 | "source": [ 285 | "len(data_df.EMPI.unique())" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": 21, 291 | "metadata": {}, 292 | "outputs": [ 293 | { 294 | "data": { 295 | "text/plain": [ 296 | "Caucasian or White 7589\n", 297 | "African American or Black 2769\n", 298 | "Name: Race, dtype: int64" 299 | ] 300 | }, 301 | "execution_count": 21, 302 | "metadata": {}, 303 | "output_type": "execute_result" 304 | } 305 | ], 306 | "source": [ 307 | "data_df.Race.value_counts()" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": 22, 313 | "metadata": {}, 314 | "outputs": [ 315 | { 316 | "data": { 317 | "text/plain": [ 318 | "Caucasian or White 0.73267\n", 319 | "African American or Black 0.26733\n", 320 | "Name: Race, dtype: float64" 321 | ] 322 | }, 323 | "execution_count": 22, 324 | "metadata": {}, 325 | "output_type": "execute_result" 326 | } 327 | ], 328 | "source": [ 329 | "data_df.Race.value_counts(normalize=True)" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": 23, 335 | "metadata": {}, 336 | "outputs": [ 337 | { 338 | "data": { 339 | "text/plain": [ 340 | "F 5488\n", 341 | "M 4870\n", 342 | "Name: Sex, dtype: int64" 343 | ] 344 | }, 345 | "execution_count": 23, 346 | "metadata": {}, 347 | "output_type": "execute_result" 348 | } 349 | ], 350 | "source": [ 351 | "data_df.Sex.value_counts()" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | 
"execution_count": 24, 357 | "metadata": {}, 358 | "outputs": [ 359 | { 360 | "data": { 361 | "text/plain": [ 362 | "train 0.805947\n", 363 | "validate 0.102819\n", 364 | "test 0.091234\n", 365 | "Name: split, dtype: float64" 366 | ] 367 | }, 368 | "execution_count": 24, 369 | "metadata": {}, 370 | "output_type": "execute_result" 371 | } 372 | ], 373 | "source": [ 374 | "data_df = data_df.rename(columns={\"split_x\": \"split\"})\n", 375 | "data_df.split.value_counts(normalize=True)" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": 25, 381 | "metadata": {}, 382 | "outputs": [], 383 | "source": [ 384 | "data_df.Image = data_df.Image.astype(str)" 385 | ] 386 | }, 387 | { 388 | "cell_type": "code", 389 | "execution_count": 26, 390 | "metadata": {}, 391 | "outputs": [], 392 | "source": [ 393 | "from tensorflow.keras.mixed_precision import experimental as mixed_precision\n", 394 | "policy = mixed_precision.Policy('mixed_float16')\n", 395 | "\n", 396 | "mixed_precision.set_policy(policy)" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": 27, 402 | "metadata": {}, 403 | "outputs": [], 404 | "source": [ 405 | "HEIGHT, WIDTH = 320, 320" 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "execution_count": 28, 411 | "metadata": {}, 412 | "outputs": [], 413 | "source": [ 414 | "resnet34, preprocess_input = Classifiers.get('resnet34')" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": 29, 420 | "metadata": {}, 421 | "outputs": [], 422 | "source": [ 423 | "input_a = Input(shape=(HEIGHT, WIDTH, 3))\n", 424 | "base_model = resnet34(input_tensor=input_a, include_top=False, input_shape=(HEIGHT,WIDTH,3), weights='imagenet')\n", 425 | "x = GlobalAveragePooling2D()(base_model.output)\n", 426 | "x = layers.Dense(2, name='dense_logits')(x)\n", 427 | "output = layers.Activation('softmax', dtype='float32', name='predictions')(x)\n", 428 | "model = Model(inputs=[input_a], outputs=[output])" 429 | ] 430 | }, 431 | { 432 | "cell_type": "code", 433 | "execution_count": 30, 434 | "metadata": {}, 435 | "outputs": [], 436 | "source": [ 437 | "learning_rate = 1e-3\n", 438 | "decay_val= 0.0 \n", 439 | "batch_s = 256\n", 440 | "\n", 441 | "reduce_lr = ReduceLROnPlateau(monitor='val_loss', mode='min', factor=0.1,\n", 442 | " patience=2, min_lr=1e-5, verbose=1)\n", 443 | "\n", 444 | "adam_opt = optimizers.Adam(learning_rate=learning_rate, decay=decay_val)\n", 445 | "\n", 446 | "model.compile(optimizer=adam_opt,\n", 447 | " loss=tf.losses.CategoricalCrossentropy(name='loss'),\n", 448 | " metrics=[\n", 449 | " tf.keras.metrics.AUC(curve='ROC', name='ROC-AUC')\n", 450 | " ],\n", 451 | ")" 452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "execution_count": 31, 457 | "metadata": {}, 458 | "outputs": [], 459 | "source": [ 460 | "train_gen = ImageDataGenerator(\n", 461 | " rotation_range=15,\n", 462 | " fill_mode='constant',\n", 463 | " zoom_range=0.1,\n", 464 | " horizontal_flip=True,\n", 465 | " preprocessing_function=preprocess_input\n", 466 | ")\n", 467 | "\n", 468 | "validate_gen = ImageDataGenerator(preprocessing_function=preprocess_input)" 469 | ] 470 | }, 471 | { 472 | "cell_type": "code", 473 | "execution_count": 32, 474 | "metadata": {}, 475 | "outputs": [], 476 | "source": [ 477 | "train_batch_size = batch_s\n", 478 | "test_batch_size = 64" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 33, 484 | "metadata": {}, 485 | "outputs": [], 486 | "source": [ 487 | "train_df = 
data_df[data_df.split==\"train\"]\n", 488 | "validate_df = data_df[data_df.split==\"validate\"]\n", 489 | "test_df = data_df[data_df.split==\"test\"]" 490 | ] 491 | }, 492 | { 493 | "cell_type": "code", 494 | "execution_count": 34, 495 | "metadata": {}, 496 | "outputs": [ 497 | { 498 | "name": "stdout", 499 | "output_type": "stream", 500 | "text": [ 501 | "Found 8346 validated image filenames belonging to 2 classes.\n", 502 | "Found 1065 validated image filenames belonging to 2 classes.\n" 503 | ] 504 | }, 505 | { 506 | "name": "stderr", 507 | "output_type": "stream", 508 | "text": [ 509 | "/home/jupyter-brandon/.local/lib/python3.7/site-packages/keras_preprocessing/image/dataframe_iterator.py:282: UserWarning: Found 2 invalid image filename(s) in x_col=\"Image\". These filename(s) will be ignored.\n", 510 | " .format(n_invalid, x_col)\n" 511 | ] 512 | } 513 | ], 514 | "source": [ 515 | "train_batches = train_gen.flow_from_dataframe(train_df, x_col=\"Image\", y_col=\"Race\", class_mode=\"categorical\",target_size=(HEIGHT, WIDTH),shuffle=True,seed=2021,batch_size=train_batch_size, dtype='float32')\n", 516 | "\n", 517 | "validate_batches = validate_gen.flow_from_dataframe(validate_df, x_col=\"Image\", y_col=\"Race\", class_mode=\"categorical\",target_size=(HEIGHT, WIDTH),shuffle=False,batch_size=test_batch_size, dtype='float32') " 518 | ] 519 | }, 520 | { 521 | "cell_type": "code", 522 | "execution_count": 35, 523 | "metadata": {}, 524 | "outputs": [], 525 | "source": [ 526 | "arc_name = \"resnet34_cspine_race_detection_with_random_seed_\"" 527 | ] 528 | }, 529 | { 530 | "cell_type": "code", 531 | "execution_count": 36, 532 | "metadata": {}, 533 | "outputs": [], 534 | "source": [ 535 | "train_epoch = math.ceil(len(train_df) / batch_s)\n", 536 | "val_epoch = math.ceil(len(validate_df) / test_batch_size)" 537 | ] 538 | }, 539 | { 540 | "cell_type": "code", 541 | "execution_count": 37, 542 | "metadata": {}, 543 | "outputs": [], 544 | "source": [ 545 | "var_date = datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n", 546 | "ES = EarlyStopping(monitor='val_loss', mode='min', patience=4, restore_best_weights=True)\n", 547 | "checkloss = ModelCheckpoint(\"../saved_models/\" + str(arc_name) + \"_CXR_\" +var_date+\"_epoch:{epoch:03d}_val_loss:{val_loss:.5f}.hdf5\", monitor='val_loss', mode='min', verbose=1, save_best_only=True, save_weights_only=False)\n" 548 | ] 549 | }, 550 | { 551 | "cell_type": "code", 552 | "execution_count": 38, 553 | "metadata": { 554 | "scrolled": false 555 | }, 556 | "outputs": [ 557 | { 558 | "name": "stdout", 559 | "output_type": "stream", 560 | "text": [ 561 | "Epoch 1/100\n", 562 | "32/33 [============================>.] - ETA: 10s - loss: 0.6883 - ROC-AUC: 0.7275\n", 563 | "Epoch 00001: val_loss improved from inf to 8.61255, saving model to ../saved_models/resnet34_cspine_race_detection_with_random_seed__CXR_20210627-212707_epoch:001_val_loss:8.61255.hdf5\n", 564 | "33/33 [==============================] - 454s 14s/step - loss: 0.6849 - ROC-AUC: 0.7281 - val_loss: 8.6126 - val_ROC-AUC: 0.3036\n", 565 | "Epoch 2/100\n", 566 | "32/33 [============================>.] 
- ETA: 9s - loss: 0.5373 - ROC-AUC: 0.8061 \n", 567 | "Epoch 00002: val_loss improved from 8.61255 to 2.10922, saving model to ../saved_models/resnet34_cspine_race_detection_with_random_seed__CXR_20210627-212707_epoch:002_val_loss:2.10922.hdf5\n", 568 | "33/33 [==============================] - 340s 10s/step - loss: 0.5358 - ROC-AUC: 0.8075 - val_loss: 2.1092 - val_ROC-AUC: 0.3446\n", 569 | "Epoch 3/100\n", 570 | "32/33 [============================>.] - ETA: 3s - loss: 0.4154 - ROC-AUC: 0.8911\n", 571 | "Epoch 00003: val_loss improved from 2.10922 to 0.82805, saving model to ../saved_models/resnet34_cspine_race_detection_with_random_seed__CXR_20210627-212707_epoch:003_val_loss:0.82805.hdf5\n", 572 | "33/33 [==============================] - 147s 4s/step - loss: 0.4144 - ROC-AUC: 0.8917 - val_loss: 0.8280 - val_ROC-AUC: 0.8110\n", 573 | "Epoch 4/100\n", 574 | "32/33 [============================>.] - ETA: 3s - loss: 0.3202 - ROC-AUC: 0.9370\n", 575 | "Epoch 00004: val_loss improved from 0.82805 to 0.68681, saving model to ../saved_models/resnet34_cspine_race_detection_with_random_seed__CXR_20210627-212707_epoch:004_val_loss:0.68681.hdf5\n", 576 | "33/33 [==============================] - 131s 4s/step - loss: 0.3192 - ROC-AUC: 0.9374 - val_loss: 0.6868 - val_ROC-AUC: 0.8475\n", 577 | "Epoch 5/100\n", 578 | "32/33 [============================>.] - ETA: 3s - loss: 0.2590 - ROC-AUC: 0.9589\n", 579 | "Epoch 00005: val_loss did not improve from 0.68681\n", 580 | "33/33 [==============================] - 131s 4s/step - loss: 0.2574 - ROC-AUC: 0.9594 - val_loss: 1.3166 - val_ROC-AUC: 0.7579\n", 581 | "Epoch 6/100\n", 582 | "32/33 [============================>.] - ETA: 3s - loss: 0.2043 - ROC-AUC: 0.9746\n", 583 | "Epoch 00006: ReduceLROnPlateau reducing learning rate to 0.00010000000474974513.\n", 584 | "\n", 585 | "Epoch 00006: val_loss did not improve from 0.68681\n", 586 | "33/33 [==============================] - 127s 4s/step - loss: 0.2036 - ROC-AUC: 0.9748 - val_loss: 1.7914 - val_ROC-AUC: 0.5367\n", 587 | "Epoch 7/100\n", 588 | "32/33 [============================>.] - ETA: 3s - loss: 0.1354 - ROC-AUC: 0.9895\n", 589 | "Epoch 00007: val_loss improved from 0.68681 to 0.42511, saving model to ../saved_models/resnet34_cspine_race_detection_with_random_seed__CXR_20210627-212707_epoch:007_val_loss:0.42511.hdf5\n", 590 | "33/33 [==============================] - 132s 4s/step - loss: 0.1356 - ROC-AUC: 0.9894 - val_loss: 0.4251 - val_ROC-AUC: 0.9272\n", 591 | "Epoch 8/100\n", 592 | "32/33 [============================>.] - ETA: 3s - loss: 0.0934 - ROC-AUC: 0.9950\n", 593 | "Epoch 00008: val_loss improved from 0.42511 to 0.39056, saving model to ../saved_models/resnet34_cspine_race_detection_with_random_seed__CXR_20210627-212707_epoch:008_val_loss:0.39056.hdf5\n", 594 | "33/33 [==============================] - 135s 4s/step - loss: 0.0929 - ROC-AUC: 0.9951 - val_loss: 0.3906 - val_ROC-AUC: 0.9422\n", 595 | "Epoch 9/100\n", 596 | "32/33 [============================>.] - ETA: 3s - loss: 0.0814 - ROC-AUC: 0.9958\n", 597 | "Epoch 00009: val_loss did not improve from 0.39056\n", 598 | "33/33 [==============================] - 123s 4s/step - loss: 0.0816 - ROC-AUC: 0.9958 - val_loss: 0.5122 - val_ROC-AUC: 0.9357\n", 599 | "Epoch 10/100\n", 600 | "32/33 [============================>.] 
- ETA: 3s - loss: 0.0679 - ROC-AUC: 0.9970\n", 601 | "Epoch 00010: ReduceLROnPlateau reducing learning rate to 1.0000000474974514e-05.\n", 602 | "\n", 603 | "Epoch 00010: val_loss did not improve from 0.39056\n", 604 | "33/33 [==============================] - 122s 4s/step - loss: 0.0684 - ROC-AUC: 0.9969 - val_loss: 0.5492 - val_ROC-AUC: 0.9314\n", 605 | "Epoch 11/100\n", 606 | "32/33 [============================>.] - ETA: 3s - loss: 0.0569 - ROC-AUC: 0.9979\n", 607 | "Epoch 00011: val_loss did not improve from 0.39056\n", 608 | "33/33 [==============================] - 123s 4s/step - loss: 0.0573 - ROC-AUC: 0.9979 - val_loss: 0.4744 - val_ROC-AUC: 0.9420\n", 609 | "Epoch 12/100\n", 610 | "32/33 [============================>.] - ETA: 3s - loss: 0.0549 - ROC-AUC: 0.9982\n", 611 | "Epoch 00012: ReduceLROnPlateau reducing learning rate to 1e-05.\n", 612 | "\n", 613 | "Epoch 00012: val_loss did not improve from 0.39056\n", 614 | "33/33 [==============================] - 119s 4s/step - loss: 0.0543 - ROC-AUC: 0.9983 - val_loss: 0.4206 - val_ROC-AUC: 0.9482\n" 615 | ] 616 | }, 617 | { 618 | "data": { 619 | "text/plain": [ 620 | "" 621 | ] 622 | }, 623 | "execution_count": 38, 624 | "metadata": {}, 625 | "output_type": "execute_result" 626 | } 627 | ], 628 | "source": [ 629 | "model.fit_generator(\n", 630 | " train_batches, \n", 631 | " steps_per_epoch=train_epoch, \n", 632 | " epochs=100, \n", 633 | " callbacks=[reduce_lr, checkloss, ES],\n", 634 | " validation_data=validate_batches, \n", 635 | " validation_steps=val_epoch, \n", 636 | " max_queue_size=10,\n", 637 | " workers=32,\n", 638 | " shuffle=True\n", 639 | ")\n" 640 | ] 641 | }, 642 | { 643 | "cell_type": "code", 644 | "execution_count": 39, 645 | "metadata": {}, 646 | "outputs": [ 647 | { 648 | "name": "stdout", 649 | "output_type": "stream", 650 | "text": [ 651 | "Found 945 validated image filenames belonging to 2 classes.\n" 652 | ] 653 | } 654 | ], 655 | "source": [ 656 | "test_batches = validate_gen.flow_from_dataframe(test_df, x_col=\"Image\", y_col=\"Race\", class_mode=\"categorical\",target_size=(HEIGHT, WIDTH),shuffle=False,batch_size=test_batch_size, dtype='float32') " 657 | ] 658 | }, 659 | { 660 | "cell_type": "code", 661 | "execution_count": 40, 662 | "metadata": {}, 663 | "outputs": [ 664 | { 665 | "data": { 666 | "text/plain": [ 667 | "Caucasian or White 0.697354\n", 668 | "African American or Black 0.302646\n", 669 | "Name: Race, dtype: float64" 670 | ] 671 | }, 672 | "execution_count": 40, 673 | "metadata": {}, 674 | "output_type": "execute_result" 675 | } 676 | ], 677 | "source": [ 678 | "test_df.Race.value_counts(normalize=True)" 679 | ] 680 | }, 681 | { 682 | "cell_type": "code", 683 | "execution_count": 42, 684 | "metadata": {}, 685 | "outputs": [ 686 | { 687 | "name": "stdout", 688 | "output_type": "stream", 689 | "text": [ 690 | "15/15 [==============================] - 300s 20s/step\n" 691 | ] 692 | } 693 | ], 694 | "source": [ 695 | "race_multilabel_predict_test = model.predict(test_batches, max_queue_size=10, verbose=1, steps=math.ceil(len(test_df)/test_batch_size))" 696 | ] 697 | }, 698 | { 699 | "cell_type": "code", 700 | "execution_count": 43, 701 | "metadata": {}, 702 | "outputs": [], 703 | "source": [ 704 | "race_input_prediction = race_multilabel_predict_test\n", 705 | "input_df = test_df\n", 706 | "race_input_prediction_df = pd.DataFrame(race_input_prediction)\n", 707 | "race_true_logits = pd.DataFrame()\n", 708 | "race_loss_log = pd.DataFrame()" 709 | ] 710 | }, 711 | { 712 | "cell_type": "code", 713 | 
"execution_count": 44, 714 | "metadata": {}, 715 | "outputs": [], 716 | "source": [ 717 | "def stat_calc(input_prediction_df, input_df):\n", 718 | " ground_truth = input_df.Race\n", 719 | " #ground_truth = ground_truth.apply(', '.join)\n", 720 | " pathology_array=[\n", 721 | " 'African American or Black',\n", 722 | " 'Caucasian or White'\n", 723 | " ]\n", 724 | "\n", 725 | " i=0\n", 726 | " auc_array = []\n", 727 | " for pathology in pathology_array:\n", 728 | " \n", 729 | " new_truth = (ground_truth.str.contains(pathology)).apply(int)\n", 730 | " input_prediction_val = input_prediction_df[i]\n", 731 | " val = input_prediction_val\n", 732 | " AUC = roc_auc_score(new_truth, val)\n", 733 | " race_true_logits.insert(i, i, new_truth, True)\n", 734 | " auc_array.append(AUC)\n", 735 | " i += 1\n", 736 | " \n", 737 | " progress_df = pd.DataFrame({'Study':pathology_array, 'AUC':auc_array})\n", 738 | " print(progress_df)\n" 739 | ] 740 | }, 741 | { 742 | "cell_type": "code", 743 | "execution_count": 48, 744 | "metadata": {}, 745 | "outputs": [ 746 | { 747 | "name": "stdout", 748 | "output_type": "stream", 749 | "text": [ 750 | " Study AUC\n", 751 | "0 African American or Black 0.919251\n", 752 | "1 Caucasian or White 0.919251\n" 753 | ] 754 | } 755 | ], 756 | "source": [ 757 | "stat_calc(race_input_prediction_df, input_df)" 758 | ] 759 | } 760 | ], 761 | "metadata": { 762 | "kernelspec": { 763 | "display_name": "Python 3", 764 | "language": "python", 765 | "name": "python3" 766 | }, 767 | "language_info": { 768 | "codemirror_mode": { 769 | "name": "ipython", 770 | "version": 3 771 | }, 772 | "file_extension": ".py", 773 | "mimetype": "text/x-python", 774 | "name": "python", 775 | "nbconvert_exporter": "python", 776 | "pygments_lexer": "ipython3", 777 | "version": "3.7.3" 778 | } 779 | }, 780 | "nbformat": 4, 781 | "nbformat_minor": 4 782 | } 783 | -------------------------------------------------------------------------------- /training_code/CXR_training/CheXpert/CheXpert_resnet34_race_detection_2021_09_21_premade_splits.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# This model was made using a docker image\n", 10 | "# Docker image can be found at https://hub.docker.com/r/blackboxradiology/tf-2.6_with_pytorch\n", 11 | "# docker pull blackboxradiology/tf-2.6_with_pytorch\n", 12 | "\n", 13 | "# python version 3.6.9\n", 14 | "# mayplotlib version 3.3.4\n", 15 | "# numpy version 1.19.5\n", 16 | "# pandas version 1.1.5\n", 17 | "# PIL version 8.2.0\n", 18 | "# sklearn version 0.24.2\n", 19 | "# tensorflow version 2.6.0\n", 20 | "\n", 21 | "from datetime import datetime\n", 22 | "import glob\n", 23 | "import math\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "import matplotlib\n", 26 | "import numpy as np\n", 27 | "import os\n", 28 | "import pandas as pd\n", 29 | "from PIL import Image\n", 30 | "import random as python_random\n", 31 | "import seaborn as sns\n", 32 | "from sklearn.metrics import classification_report, roc_auc_score, roc_curve, precision_recall_curve\n", 33 | "from sklearn.metrics import auc, accuracy_score, recall_score, precision_score, f1_score, confusion_matrix\n", 34 | "from sklearn.utils import shuffle\n", 35 | "import sys\n", 36 | "import tensorflow as tf\n", 37 | "from tensorflow.keras.optimizers import Adam\n", 38 | "from tensorflow.keras.layers import GlobalAveragePooling2D, Input, Dense, 
Activation\n", 39 | "from tensorflow.keras.models import Model\n", 40 | "from tensorflow.keras import initializers\n", 41 | "from tensorflow.keras.callbacks import EarlyStopping, LearningRateScheduler, ReduceLROnPlateau, ModelCheckpoint\n", 42 | "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n", 43 | "from tensorflow.keras.models import load_model\n", 44 | "from tensorflow.keras.mixed_precision import experimental as mixed_precision\n", 45 | "\n", 46 | "# pip install image-classifiers==1.0.0b1\n", 47 | "from classification_models.tfkeras import Classifiers\n", 48 | "# More information about this package can be found at https://github.com/qubvel/classification_models" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 2, 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "name": "stdout", 58 | "output_type": "stream", 59 | "text": [ 60 | "INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK\n", 61 | "Your GPUs will likely run quickly with dtype policy mixed_float16 as they all have compute capability of at least 7.0\n", 62 | "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/mixed_precision/loss_scale.py:52: DynamicLossScale.__init__ (from tensorflow.python.training.experimental.loss_scale) is deprecated and will be removed in a future version.\n", 63 | "Instructions for updating:\n", 64 | "Use tf.keras.mixed_precision.LossScaleOptimizer instead. LossScaleOptimizer now has all the functionality of DynamicLossScale\n" 65 | ] 66 | } 67 | ], 68 | "source": [ 69 | "np.random.seed(2021)\n", 70 | "python_random.seed(2021)\n", 71 | "tf.random.set_seed(2021)\n", 72 | "\n", 73 | "policy = mixed_precision.Policy('mixed_float16')\n", 74 | "mixed_precision.set_policy(policy)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 3, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "# CheXpert images can be found: https://stanfordaimi.azurewebsites.net/datasets/8cbd9ed4-2eb9-4565-affc-111cf4f7ebe2\n", 84 | "data_df = pd.read_csv('chexpert_train.csv')\n", 85 | "\n", 86 | "# Demographic labels can be found: https://stanfordaimi.azurewebsites.net/datasets/192ada7c-4d43-466e-b8bb-b81992bb80cf\n", 87 | "demo_df = pd.DataFrame(pd.read_excel(\"CHEXPERT DEMO.xlsx\", engine='openpyxl')) #pip install openpyxl\n", 88 | "\n", 89 | "# 60-10-30, train-val-test split that we used\n", 90 | "# These splits can be found in this repository\n", 91 | "split_df = pd.read_csv('chexpert_split_2021_08_20.csv').set_index('index')\n", 92 | "\n", 93 | "\n", 94 | "# All preprocessing steps of CheXpert .jpg images are included in this repository\n", 95 | "# Image data preprocessing include resizing to 320x320\n", 96 | "# and normalizing images with ImageNet mean and standard deviation values\n", 97 | "# using resnet34, preprocess_input = Classifiers.get('resnet34') from the classification_models.tfkeras package" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 4, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "data_df = pd.concat([data_df,split_df], axis=1)\n", 107 | "data_df = data_df[~data_df.split.isna()]\n", 108 | "\n", 109 | "path_split = data_df.Path.str.split(\"/\", expand = True)\n", 110 | "data_df[\"patient_id\"] = path_split[2]\n", 111 | "demo_df = demo_df.rename(columns={'PATIENT': 'patient_id'})\n", 112 | "data_df = data_df.merge(demo_df, on=\"patient_id\")\n", 113 | "\n", 114 | "mask = (data_df.PRIMARY_RACE.str.contains(\"Black\", na=False))\n", 115 | 
"data_df.loc[mask, \"race\"] = \"BLACK/AFRICAN AMERICAN\"\n", 116 | "\n", 117 | "mask = (data_df.PRIMARY_RACE.str.contains(\"White\", na=False))\n", 118 | "data_df.loc[mask, \"race\"] = \"WHITE\"\n", 119 | "\n", 120 | "mask = (data_df.PRIMARY_RACE.str.contains(\"Asian\", na=False))\n", 121 | "data_df.loc[mask, \"race\"] = \"ASIAN\"" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 5, 127 | "metadata": {}, 128 | "outputs": [ 129 | { 130 | "data": { 131 | "text/html": [ 132 | "
\n", 133 | "\n", 146 | "\n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | "
PathSexAgeFrontalLateralPositionNo FindingEnlarged CardiomediastinumCardiomegalyLung OpacityLung Lesion...FractureSupport DevicesUnnamed: 0splitpatient_idGENDERAGE_AT_CXRPRIMARY_RACEETHNICITYrace
\n", 176 | "

0 rows × 27 columns

\n", 177 | "
" 178 | ], 179 | "text/plain": [ 180 | "Empty DataFrame\n", 181 | "Columns: [Path, Sex, Age, FrontalLateral, Position, No Finding, Enlarged Cardiomediastinum, Cardiomegaly, Lung Opacity, Lung Lesion, Edema, Consolidation, Pneumonia, Atelectasis, Pneumothorax, Pleural Effusion, Pleural Other, Fracture, Support Devices, Unnamed: 0, split, patient_id, GENDER, AGE_AT_CXR, PRIMARY_RACE, ETHNICITY, race]\n", 182 | "Index: []\n", 183 | "\n", 184 | "[0 rows x 27 columns]" 185 | ] 186 | }, 187 | "execution_count": 5, 188 | "metadata": {}, 189 | "output_type": "execute_result" 190 | } 191 | ], 192 | "source": [ 193 | "data_df[:0]" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 6, 199 | "metadata": {}, 200 | "outputs": [ 201 | { 202 | "data": { 203 | "text/plain": [ 204 | "127118" 205 | ] 206 | }, 207 | "execution_count": 6, 208 | "metadata": {}, 209 | "output_type": "execute_result" 210 | } 211 | ], 212 | "source": [ 213 | "len(data_df)" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 7, 219 | "metadata": {}, 220 | "outputs": [ 221 | { 222 | "data": { 223 | "text/plain": [ 224 | "train 0.599482\n", 225 | "test 0.300823\n", 226 | "validate 0.099695\n", 227 | "Name: split, dtype: float64" 228 | ] 229 | }, 230 | "execution_count": 7, 231 | "metadata": {}, 232 | "output_type": "execute_result" 233 | } 234 | ], 235 | "source": [ 236 | "data_df.split.value_counts(normalize=True)" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": 8, 242 | "metadata": {}, 243 | "outputs": [ 244 | { 245 | "data": { 246 | "text/plain": [ 247 | "WHITE 99027\n", 248 | "ASIAN 18830\n", 249 | "BLACK/AFRICAN AMERICAN 9261\n", 250 | "Name: race, dtype: int64" 251 | ] 252 | }, 253 | "execution_count": 8, 254 | "metadata": {}, 255 | "output_type": "execute_result" 256 | } 257 | ], 258 | "source": [ 259 | "data_df.race.value_counts()" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 9, 265 | "metadata": {}, 266 | "outputs": [ 267 | { 268 | "data": { 269 | "text/plain": [ 270 | "WHITE 0.779016\n", 271 | "ASIAN 0.148130\n", 272 | "BLACK/AFRICAN AMERICAN 0.072854\n", 273 | "Name: race, dtype: float64" 274 | ] 275 | }, 276 | "execution_count": 9, 277 | "metadata": {}, 278 | "output_type": "execute_result" 279 | } 280 | ], 281 | "source": [ 282 | "data_df.race.value_counts(normalize=True)" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": 10, 288 | "metadata": {}, 289 | "outputs": [], 290 | "source": [ 291 | "train_df = data_df[data_df.split==\"train\"]\n", 292 | "validation_df = data_df[data_df.split==\"validate\"]\n", 293 | "test_df = data_df[data_df.split==\"test\"]" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 11, 299 | "metadata": {}, 300 | "outputs": [ 301 | { 302 | "data": { 303 | "text/plain": [ 304 | "False" 305 | ] 306 | }, 307 | "execution_count": 11, 308 | "metadata": {}, 309 | "output_type": "execute_result" 310 | } 311 | ], 312 | "source": [ 313 | "#False indicates no patient_id shared between groups\n", 314 | "\n", 315 | "unique_train_id = train_df.patient_id.unique()\n", 316 | "unique_validation_id = validation_df.patient_id.unique()\n", 317 | "unique_test_id = test_df.patient_id.unique()\n", 318 | "all_id = np.concatenate((unique_train_id, unique_validation_id, unique_test_id), axis=None)\n", 319 | "\n", 320 | "def contains_duplicates(X):\n", 321 | " return len(np.unique(X)) != len(X)\n", 322 | "\n", 323 | "contains_duplicates(all_id)" 324 | ] 325 | }, 
326 | { 327 | "cell_type": "code", 328 | "execution_count": 12, 329 | "metadata": {}, 330 | "outputs": [], 331 | "source": [ 332 | "HEIGHT, WIDTH = 320, 320\n", 333 | "\n", 334 | "arc_name = \"CHEXPERT-\" + str(HEIGHT) + \"x\" + str(WIDTH) + \"_60-10-30-split-resnet34-Float16_3-race_detection\"" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": 13, 340 | "metadata": {}, 341 | "outputs": [], 342 | "source": [ 343 | "resnet34, preprocess_input = Classifiers.get('resnet34')" 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": 14, 349 | "metadata": {}, 350 | "outputs": [ 351 | { 352 | "name": "stdout", 353 | "output_type": "stream", 354 | "text": [ 355 | "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n", 356 | "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n", 357 | "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n", 358 | "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n", 359 | "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n", 360 | "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n", 361 | "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n", 362 | "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n", 363 | "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n", 364 | "INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).\n" 365 | ] 366 | } 367 | ], 368 | "source": [ 369 | "input_a = Input(shape=(HEIGHT, WIDTH, 3))\n", 370 | "base_model = resnet34(input_tensor=input_a, include_top=False, input_shape=(HEIGHT,WIDTH,3), weights='imagenet')\n", 371 | "x = GlobalAveragePooling2D()(base_model.output)\n", 372 | "x = Dense(3, name='dense_logits')(x)\n", 373 | "output = Activation('softmax', dtype='float32', name='predictions')(x)\n", 374 | "model = Model(inputs=[input_a], outputs=[output])" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": 15, 380 | "metadata": {}, 381 | "outputs": [], 382 | "source": [ 383 | "learning_rate = 1e-3\n", 384 | "momentum_val=0.9\n", 385 | "decay_val= 0.0\n", 386 | "train_batch_size = 256 # may need to reduce batch size if OOM error occurs\n", 387 | "test_batch_size = 256\n", 388 | "\n", 389 | "reduce_lr = ReduceLROnPlateau(monitor='val_loss', mode='min', factor=0.1, patience=2, min_lr=1e-5, verbose=1)\n", 390 | "\n", 391 | "adam_opt = tf.keras.optimizers.Adam(learning_rate=learning_rate, decay=decay_val)\n", 392 | "adam_opt = tf.keras.mixed_precision.LossScaleOptimizer(adam_opt)\n", 393 | "\n", 394 | "model.compile(optimizer=adam_opt,\n", 395 | " loss=tf.losses.CategoricalCrossentropy(),\n", 396 | " metrics=[\n", 397 | " tf.keras.metrics.AUC(curve='ROC', name='ROC-AUC'),\n", 398 | " 
tf.keras.metrics.AUC(curve='PR', name='PR-AUC')\n", 399 | " ],\n", 400 | ")" 401 | ] 402 | }, 403 | { 404 | "cell_type": "code", 405 | "execution_count": 16, 406 | "metadata": {}, 407 | "outputs": [], 408 | "source": [ 409 | "train_gen = ImageDataGenerator(\n", 410 | " rotation_range=15,\n", 411 | " fill_mode='constant',\n", 412 | " horizontal_flip=True,\n", 413 | " zoom_range=0.1,\n", 414 | " preprocessing_function=preprocess_input\n", 415 | " )\n", 416 | "\n", 417 | "validate_gen = ImageDataGenerator(preprocessing_function=preprocess_input)" 418 | ] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "execution_count": 17, 423 | "metadata": {}, 424 | "outputs": [ 425 | { 426 | "name": "stdout", 427 | "output_type": "stream", 428 | "text": [ 429 | "Found 76205 validated image filenames belonging to 3 classes.\n", 430 | "Found 12673 validated image filenames belonging to 3 classes.\n" 431 | ] 432 | } 433 | ], 434 | "source": [ 435 | "train_batches = train_gen.flow_from_dataframe(train_df, directory=\"/path/to/directory/\", x_col=\"Path\", y_col=\"race\", class_mode=\"categorical\",target_size=(HEIGHT, WIDTH),shuffle=True,seed=2021,batch_size=train_batch_size, dtype='float32')\n", 436 | "validate_batches = validate_gen.flow_from_dataframe(validation_df, directory=\"/path/to/directory/\", x_col=\"Path\", y_col=\"race\", class_mode=\"categorical\",target_size=(HEIGHT, WIDTH),shuffle=False,batch_size=test_batch_size, dtype='float32') " 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": 18, 442 | "metadata": {}, 443 | "outputs": [], 444 | "source": [ 445 | "train_epoch = math.ceil(len(train_df) / train_batch_size)\n", 446 | "val_epoch = math.ceil(len(validation_df) / test_batch_size)" 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": 19, 452 | "metadata": {}, 453 | "outputs": [], 454 | "source": [ 455 | "var_date = datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n", 456 | "ES = EarlyStopping(monitor='val_loss', mode='min', patience=4, restore_best_weights=True)\n", 457 | "checkloss = ModelCheckpoint(\"../saved_models/racial_bias/trials/\" + str(arc_name) + \"_CXR_LR-\" + str(learning_rate) + \"_\" + var_date+\"_epoch:{epoch:03d}_val_loss:{val_loss:.5f}.hdf5\", monitor='val_loss', mode='min', verbose=1, save_best_only=True, save_weights_only=False)" 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": 40, 463 | "metadata": {}, 464 | "outputs": [ 465 | { 466 | "name": "stdout", 467 | "output_type": "stream", 468 | "text": [ 469 | "Epoch 1/100\n", 470 | "INFO:tensorflow:batch_all_reduce: 108 all-reduces with algorithm = nccl, num_packs = 1\n", 471 | "INFO:tensorflow:batch_all_reduce: 108 all-reduces with algorithm = nccl, num_packs = 1\n", 472 | "298/298 [==============================] - 551s 2s/step - loss: 0.4800 - ROC-AUC: 0.9374 - PR-AUC: 0.8913 - val_loss: 0.8220 - val_ROC-AUC: 0.9188 - val_PR-AUC: 0.8732\n", 473 | "\n", 474 | "Epoch 00001: val_loss improved from inf to 0.82205, saving model to ../saved_models/racial_bias/trials/CHEXPERT-320x320_60-10-30-split-resnet34-Float16_3-race_detection_CXR_LR-0.001_20210818-211230_epoch:001_val_loss:0.82205.hdf5\n" 475 | ] 476 | }, 477 | { 478 | "name": "stderr", 479 | "output_type": "stream", 480 | "text": [ 481 | "/usr/local/lib/python3.6/dist-packages/keras/utils/generic_utils.py:497: CustomMaskWarning: Custom mask layers require a config and must override get_config. 
When loading, the custom mask layer must be passed to the custom_objects argument.\n", 482 | " category=CustomMaskWarning)\n" 483 | ] 484 | }, 485 | { 486 | "name": "stdout", 487 | "output_type": "stream", 488 | "text": [ 489 | "Epoch 2/100\n", 490 | "298/298 [==============================] - 492s 2s/step - loss: 0.3307 - ROC-AUC: 0.9695 - PR-AUC: 0.9454 - val_loss: 0.4644 - val_ROC-AUC: 0.9483 - val_PR-AUC: 0.9092\n", 491 | "\n", 492 | "Epoch 00002: val_loss improved from 0.82205 to 0.46438, saving model to ../saved_models/racial_bias/trials/CHEXPERT-320x320_60-10-30-split-resnet34-Float16_3-race_detection_CXR_LR-0.001_20210818-211230_epoch:002_val_loss:0.46438.hdf5\n", 493 | "Epoch 3/100\n", 494 | "298/298 [==============================] - 494s 2s/step - loss: 0.2830 - ROC-AUC: 0.9772 - PR-AUC: 0.9589 - val_loss: 0.4079 - val_ROC-AUC: 0.9624 - val_PR-AUC: 0.9354\n", 495 | "\n", 496 | "Epoch 00003: val_loss improved from 0.46438 to 0.40790, saving model to ../saved_models/racial_bias/trials/CHEXPERT-320x320_60-10-30-split-resnet34-Float16_3-race_detection_CXR_LR-0.001_20210818-211230_epoch:003_val_loss:0.40790.hdf5\n", 497 | "Epoch 4/100\n", 498 | "298/298 [==============================] - 495s 2s/step - loss: 0.2545 - ROC-AUC: 0.9812 - PR-AUC: 0.9656 - val_loss: 0.3118 - val_ROC-AUC: 0.9755 - val_PR-AUC: 0.9562\n", 499 | "\n", 500 | "Epoch 00004: val_loss improved from 0.40790 to 0.31179, saving model to ../saved_models/racial_bias/trials/CHEXPERT-320x320_60-10-30-split-resnet34-Float16_3-race_detection_CXR_LR-0.001_20210818-211230_epoch:004_val_loss:0.31179.hdf5\n", 501 | "Epoch 5/100\n", 502 | "298/298 [==============================] - 497s 2s/step - loss: 0.2407 - ROC-AUC: 0.9829 - PR-AUC: 0.9685 - val_loss: 0.3397 - val_ROC-AUC: 0.9752 - val_PR-AUC: 0.9542\n", 503 | "\n", 504 | "Epoch 00005: val_loss did not improve from 0.31179\n", 505 | "Epoch 6/100\n", 506 | "298/298 [==============================] - 496s 2s/step - loss: 0.2257 - ROC-AUC: 0.9849 - PR-AUC: 0.9723 - val_loss: 0.4189 - val_ROC-AUC: 0.9600 - val_PR-AUC: 0.9311\n", 507 | "\n", 508 | "Epoch 00006: val_loss did not improve from 0.31179\n", 509 | "\n", 510 | "Epoch 00006: ReduceLROnPlateau reducing learning rate to 0.00010000000474974513.\n", 511 | "Epoch 7/100\n", 512 | "298/298 [==============================] - 497s 2s/step - loss: 0.1763 - ROC-AUC: 0.9901 - PR-AUC: 0.9813 - val_loss: 0.2191 - val_ROC-AUC: 0.9852 - val_PR-AUC: 0.9735\n", 513 | "\n", 514 | "Epoch 00007: val_loss improved from 0.31179 to 0.21913, saving model to ../saved_models/racial_bias/trials/CHEXPERT-320x320_60-10-30-split-resnet34-Float16_3-race_detection_CXR_LR-0.001_20210818-211230_epoch:007_val_loss:0.21913.hdf5\n", 515 | "Epoch 8/100\n", 516 | "298/298 [==============================] - 496s 2s/step - loss: 0.1589 - ROC-AUC: 0.9918 - PR-AUC: 0.9846 - val_loss: 0.2221 - val_ROC-AUC: 0.9852 - val_PR-AUC: 0.9732\n", 517 | "\n", 518 | "Epoch 00008: val_loss did not improve from 0.21913\n", 519 | "Epoch 9/100\n", 520 | "298/298 [==============================] - 496s 2s/step - loss: 0.1507 - ROC-AUC: 0.9924 - PR-AUC: 0.9855 - val_loss: 0.2251 - val_ROC-AUC: 0.9849 - val_PR-AUC: 0.9717\n", 521 | "\n", 522 | "Epoch 00009: val_loss did not improve from 0.21913\n", 523 | "\n", 524 | "Epoch 00009: ReduceLROnPlateau reducing learning rate to 1.0000000474974514e-05.\n", 525 | "Epoch 10/100\n", 526 | "298/298 [==============================] - 496s 2s/step - loss: 0.1411 - ROC-AUC: 0.9933 - PR-AUC: 0.9872 - val_loss: 0.2234 - val_ROC-AUC: 0.9853 - 
val_PR-AUC: 0.9730\n", 527 | "\n", 528 | "Epoch 00010: val_loss did not improve from 0.21913\n", 529 | "Epoch 11/100\n", 530 | "298/298 [==============================] - 497s 2s/step - loss: 0.1385 - ROC-AUC: 0.9935 - PR-AUC: 0.9879 - val_loss: 0.2250 - val_ROC-AUC: 0.9852 - val_PR-AUC: 0.9728\n", 531 | "\n", 532 | "Epoch 00011: val_loss did not improve from 0.21913\n", 533 | "\n", 534 | "Epoch 00011: ReduceLROnPlateau reducing learning rate to 1e-05.\n" 535 | ] 536 | }, 537 | { 538 | "data": { 539 | "text/plain": [ 540 | "" 541 | ] 542 | }, 543 | "execution_count": 40, 544 | "metadata": {}, 545 | "output_type": "execute_result" 546 | } 547 | ], 548 | "source": [ 549 | "model.fit(train_batches,\n", 550 | " validation_data=validate_batches,\n", 551 | " epochs=100,\n", 552 | " steps_per_epoch=int(train_epoch),\n", 553 | " validation_steps=int(val_epoch),\n", 554 | " workers=32,\n", 555 | " max_queue_size=50,\n", 556 | " shuffle=False,\n", 557 | " callbacks=[checkloss, reduce_lr, ES]\n", 558 | " )" 559 | ] 560 | }, 561 | { 562 | "cell_type": "code", 563 | "execution_count": 41, 564 | "metadata": {}, 565 | "outputs": [ 566 | { 567 | "name": "stdout", 568 | "output_type": "stream", 569 | "text": [ 570 | "Found 38240 validated image filenames belonging to 3 classes.\n" 571 | ] 572 | } 573 | ], 574 | "source": [ 575 | "test_batches = validate_gen.flow_from_dataframe(test_df, directory=\"/path/to/directory/\", x_col=\"Path\", y_col=\"race\", class_mode=\"categorical\",target_size=(HEIGHT, WIDTH),shuffle=False,batch_size=test_batch_size, dtype='float32') " 576 | ] 577 | }, 578 | { 579 | "cell_type": "code", 580 | "execution_count": 42, 581 | "metadata": {}, 582 | "outputs": [ 583 | { 584 | "name": "stdout", 585 | "output_type": "stream", 586 | "text": [ 587 | "150/150 [==============================] - 98s 629ms/step\n" 588 | ] 589 | } 590 | ], 591 | "source": [ 592 | "multilabel_predict_test = model.predict(test_batches, max_queue_size=10, verbose=1, steps=math.ceil(len(test_df)/test_batch_size), workers=16)" 593 | ] 594 | }, 595 | { 596 | "cell_type": "code", 597 | "execution_count": 48, 598 | "metadata": {}, 599 | "outputs": [ 600 | { 601 | "name": "stdout", 602 | "output_type": "stream", 603 | "text": [ 604 | "Classwise ROC AUC \n", 605 | "\n", 606 | "Class - Asian ROC-AUC- 0.98\n", 607 | "Class - Black ROC-AUC- 0.98\n", 608 | "Class - White ROC-AUC- 0.97\n", 609 | " precision recall f1-score support\n", 610 | "\n", 611 | " Asian 0.87 0.82 0.85 5650\n", 612 | " Black 0.89 0.75 0.81 2746\n", 613 | " White 0.95 0.97 0.96 29844\n", 614 | "\n", 615 | " accuracy 0.94 38240\n", 616 | " macro avg 0.90 0.85 0.87 38240\n", 617 | "weighted avg 0.93 0.94 0.93 38240\n", 618 | "\n" 619 | ] 620 | }, 621 | { 622 | "data": { 623 | "text/plain": [ 624 | "" 625 | ] 626 | }, 627 | "execution_count": 48, 628 | "metadata": {}, 629 | "output_type": "execute_result" 630 | }, 631 | { 632 | "data": { 633 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAW0AAAD4CAYAAAAn3bdmAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAhe0lEQVR4nO3de5yN5f7/8ddnTsyQ02DCkOOun0qKpK0jlUPa6GBLBzqgIimVU1GidNJOotBEexOKooMtOSTayiHfkMrssJvJIafIaJvD9ftj3U2LPWaGZiz37f3cj/sxa13rvq/1udfWe6513deaZc45RETEH6IiXYCIiBSeQltExEcU2iIiPqLQFhHxEYW2iIiPxBT3E3y7NUPLU4rZaRUTIl1C4GVm5US6hJPCKSWj7I/2EX9ur0JnzoEvR//h5zveNNIWEfGRYh9pi4gcVxbssahCW0SCJSo60hUUK4W2iASL+W6a+qgotEUkWDQ9IiLiIxppi4j4iEbaIiI+opG2iIiPaPWIiIiPaHpERMRHND0iIuIjGmmLiPiIQltExEeidSFSRMQ/NKctIuIjmh4REfERjbRFRHxEI20RER/RSFtExEf0MXYRER/R9IiIiI9oekRExEc00hYR8RGFtoiIj+hCpIiIj2hOW0TERwI+PRLssxORk49Z4bd8u7HqZrbQzL42s3Vmdp/X/piZpZvZam9rE3bMADNLNbNvzaxlWHsrry3VzPqHtdcys8+99mlmFlfQ6Sm0RSRQzKzQWwGygL7OufpAU6CnmdX3HnvBOdfQ2z70nrc+0Ak4E2gFjDGzaDOLBl4GWgP1gRvD+nna66susBu4o6CiFNoiEihFFdrOuS3OuVXe7X3AeqBaPoe0A6Y65/7rnNsIpAJNvC3VOfe9c+4gMBVoZ6ECmgNve8dPAtoXdH4KbREJFIuywm9m3c1sRdjWPc8+zWoC5wKfe029zOwrM0sxs/JeWzXgh7DD0ry2I7UnAnucc1mHtefrpLoQmZ2dzQPdbyKxUmUGjxiFc45/THiZpYvmERUVTet213PN9Z0BWPPlCiaMfpasrCzKlC3HU6NeA2Dl50uZ8NKzZOfkcNXV7bn+ptsjeUq+sHfvXh4f/Aipqd9hZjz+xJOULBnPsKFDyMjIoGrVajz1zHOULl060qX6ypuT3+CdGW+Bc7S/7gY639yFF0c+y+JPFhIbG0tycnWGDH2SU8qUYc+e3fTr24ev162l7V/a02/go5Euv9gUYtojl3NuHDCugP5KAzOAPs65vWY2FngCcN7P54HjFgQnVWi/9/YUqp9Wi4yM/QDMnzObHdu3Mubv7xAVFcWe3bsA+GXfPl554Ukee/ZlKiVVyW3Pzs7m1b+NYOjzY0mslETfHjfRpNml1KhZJ2Ln5AfPPDWcZhddzPN/G0XmwYMc+PVX7rrzNh54qB+Nz2/COzPfZmLKBHr17hPpUn0jdcN3vDPjLd6YPJ2Y2Fh639ONiy+5jAua/pmeve8nJiaGUS88x+uvjaP3/Q9SIq4Ed/fsTWrqBv6duiHS5RerowntQvQVSyiwJzvnZgI457aFPT4eeN+7mw5UDzs82WvjCO07gXJmFuONtsP3P6KTZnpkx/ZtrFi2hCvbdshtmzPrLf7apTtRUaGXoVz5CgAs/ngOF17SgkpJVQ5p37B+LVWqVefUqsnExsZycfOWfL5k0fE9EZ/Zt28fK1cup8N11wMQGxdHmTJl2Lx5E40anw/AhRc2Y/68jyJZpu9s2vg9Z53dgJLx8cTExHBeo/NZMH8eTf/cjJiY0Fjs7AbnsH17KF/iExJoeF4jSpQoEcmyj4uimtP25pxfA9Y750aGtVcJ260DsNa7PRvoZGYlzKwWUA/4AlgO1PNWisQRulg52znngIXA9d7xXYBZBZ1fgSNtMzuD0AT7b3Mt6d4Tri/o2BPJhNHP0vWu+ziQkZHbtvXHNJYs/Ihlny6gTNnydL/vYaomn0Z62mays7IYeN+dHMjI4JrrbqR5q2vYuWM7FSsn5R5fsVIS365fm9fTiSc9LY3y5SsweNAAvv32G+qfeSYP9x9Enbr1WLhgPs1bXMFHc//J1q1bIl2qr9SpW48xL/2NPXt2U7JESZYuWcz/q3/WIfvMfncmV7ZsHaEKI6joBtrNgFuANWa22msbSGj1R0NC0yObgB4Azrl1ZjYd+JrQypOezrlsADPrBcwFooEU59w6r79+wFQzGwZ8SeiXRL7yDW0z6wfcSOhq5xdeczLwpplNdc6NKPC0TwDLP1tM2XIVqHt6fdZ8uSK3PTPzILFxcYwcN4XPFs9n1IjHGTE6hezsbFK/W8+wka9y8L+/8tA9XTj9zAYRPAP/ys7O4pv1X9N/0KM0aHAOTz81jJQJ43j8ieGMeGo4414Zw2WXNyc2tsDlqRKmVu063HrbnfS6607i4+P50+lnEB39+xvn18a/QnR0NK2vviaCVUZGUU2POOeWkPevgA/zOWY4MDyP9g/zOs459z2h1SWFVtBI+w7gTOdcZnijmY0E1gF5hrZ3BbY7wOPPvMRfb4nsxbqv167mi88+YeXnSzh48CAZ+/fz/LBBJFZK4sJLWgBw4cXNGTXiMQAqVqpMmTJlKRkfT8n4eM485zw2pn5HxUqV2bE9dzqLHT9tI7FipUickm8kJZ1KUtKpNGhwDgBXXtWKlAnj6NW7D6+OTwFg06aNLP5kUQSr9Kf2115P+2tD76xfHvUClZNC7wLfm/UOSxYvYuy414t0ftcvfpvuDKqCzi4HqJpHexXvsTw558Y55xo75xpHOrABunTvzetvz2XCtA95aPAIGpx3Pn0fGU7Tiy5jzarlAKxdvZKqyTUAuKDZZXy9ZjXZWVn899cDfLd+LdVPq0W9M87kx7T/sHVLOpmZmXy6YC4XNLssgmd24qtYqRJJp57Kpo3fA/D5sn9Ru04ddu7cCUBOTg7jXx3LDX/tFMkyfWmX9xpu3fIjC+bPo1Xrtny29FPemPgaI18cQ8n4+AhXGBlF+OGaE1JBI+0+wHwz28Dv6wxrAHWBXsVY13FxXefbGTlsILPfmkzJ+HjufXgwANVr1ua8Jn+m9+0dsagorry6A6fVrgtAjz79eOzBe8jJyeGKNu2oUUsrRwrSf+CjDOj3IJmZmSQnV2fosKd4b/a7TH1zCgAtrriS9h2ui3CV/vNw3/v4+ec9xMTE0G/go5xSpgzPPDWMzIMH6XlX6IN1Z519DgMffQyAa1q3YP8v+8nMzOSThfMZ/coEatepG8EzKCb+zOJCs9AFzHx2MIsiNOcSfiFy+W8T7AX5dmtG/k8gf9hpFRMiXULgZWYd8Y2lFKFTSkb94cit2HVqoTNnx8ROvov4AlePOOdygGXHoRYRkT/Mr9MehXVSfbhGRILP/vhg/YSm0BaRQNFIW0TERxTaIiI+otAWEfERhbaIiJ8EO7MV2iISLEH/GLtCW0QCRdMjIiJ+EuzMVmiLSLBopC0i4iMKbRERH1Foi4j4iP72iIiIj2ikLSLiIwptEREfCXhmK7RFJFg00hYR8ZEoXY
gUEfGPgA+0FdoiEiwaaYuI+IhG2iIiPqILkSIiPhLwzCbYfy1cRE46UVFRhd7yY2bVzWyhmX1tZuvM7D6vvYKZzTOzDd7P8l67mdkoM0s1s6/M7Lywvrp4+28wsy5h7Y3MbI13zCgrxNsEhbaIBIpZ4bcCZAF9nXP1gaZATzOrD/QH5jvn6gHzvfsArYF63tYdGBuqxyoAQ4ALgCbAkN+C3tunW9hxrQoqSqEtIoFiZoXe8uOc2+KcW+Xd3gesB6oB7YBJ3m6TgPbe7XbAGy5kGVDOzKoALYF5zrldzrndwDyglfdYGefcMuecA94I6+uIFNoiEihHM9I2s+5mtiJs6553n1YTOBf4HEhyzm3xHtoKJHm3qwE/hB2W5rXl156WR3u+dCFSRALlaFaPOOfGAeMK6K80MAPo45zbG96/c86ZmTvGUo+JRtoiEihFOKeNmcUSCuzJzrmZXvM2b2oD7+d2rz0dqB52eLLXll97ch7t+VJoi0igREVZobf8eCs5XgPWO+dGhj00G/htBUgXYFZY+63eKpKmwM/eNMpc4CozK+9dgLwKmOs9ttfMmnrPdWtYX0dU7NMjp1VMKO6nOOntO5AV6RICL6FEdKRLkEIqwg/XNANuAdaY2WqvbSAwAphuZncAm4GO3mMfAm2AVCADuA3AObfLzJ4Alnv7DXXO7fJu3wNMBOKBOd6WL81pi0igFFVmO+eWAEfqrUUe+zug5xH6SgFS8mhfAZx1NHUptEUkUPQxdhERHwl4Ziu0RSRY9KdZRUR8RNMjIiI+otAWEfGRgGe2QltEgkUjbRERHwl4Ziu0RSRYtHpERMRHogI+1FZoi0igBDyzFdoiEiy6ECki4iMBn9JWaItIsOhCpIiIj9gR/5pqMCi0RSRQAj7QVmiLSLDoQqSIiI8EPLMV2iISLPpwjYiIj2j1iIiIjwR8oK3QFpFg0fSIiIiPBDuyFdoiEjBa8ici4iMBvw6p0BaRYNHqERERH9H0iIiIjwR8oK3QFpFgCfpIOyrSBYiIFCU7iq3AvsxSzGy7ma0Na3vMzNLNbLW3tQl7bICZpZrZt2bWMqy9ldeWamb9w9prmdnnXvs0M4srqCaFtogESnSUFXorhIlAqzzaX3DONfS2DwHMrD7QCTjTO2aMmUWbWTTwMtAaqA/c6O0L8LTXV11gN3BHQQWddNMjgx8ZwOJPFlGhQiIzZ70PwMjnnuaTRQuJjY0luXoNhg57ijJlypCenkaHa9pQs2YtAM4+5xweHTI0kuWfsLZt3cKwIQPYvWsnmPGXDjfQ8cZb2PvzHgYPeJCtW9I5tUo1ho54njJlyrJqxRcM6HsvVapVA+DSy6/gtm738J9NGxk8sG9uvz+mp3Fnj1507HxrpE7thHZ1y+aUSihFVHQ00dHRTJ42A4Cpk//O9KlTiIqO5qJLLqXPAw/lHrNly49c364tPe7pya1dC8wI3ynK6RHn3GIzq1nI3dsBU51z/wU2mlkq0MR7LNU5971X31SgnZmtB5oDnb19JgGPAWPze5KTLrTbtb+WGzvfzKAB/XLbml7YjN59+hITE8MLzz/La+Nf5f6+oX/kydVrMH3mrEiV6xvRMTH0uv9hTj+jPhn793P7LTdw/gUXMue9d2nU5AJu6dqNv08czz8mTuCe3qFQPufcRjzztzGH9FOjZi0mTpkJQHZ2Nh3aXM4ll19x3M/HT15NeYPy5cvn3l/+xTIWLVzA1BmziIuLY9fOnYfsP/LZETS76OLjXeZxczSZbWbdge5hTeOcc+MKcWgvM7sVWAH0dc7tBqoBy8L2SfPaAH44rP0CIBHY45zLymP/IzrppkcaNT6fMmXLHtL252YXERMT+v3V4JyGbN+2NRKl+VrFipU4/YzQO76EUqWoWbM2O7Zv59NPFtK6bXsAWrdtz6eLFhS6z5XLl1GtWnVOrVK1OEoOrLenTeW2O7oRFxeaHq2QmJj72ML5H1O1WjK169aNVHnFLsqs0JtzbpxzrnHYVpjAHgvUARoCW4Dni/N8DnfMoW1mtxVlISeKd2fOoNnFl+TeT09Po+N17bm9y82sWrkigpX5x5Yf0/nu2/XUP6sBu3ftpGLFSgAkJlYMTZ941q5ZTZcbO9C3dw++/3fq//Tz8dw5XNGyzf+0y+/MjJ497qBzx2uZ8dY0ADZv3sSqVSu4tXNH7ux6M+vWrgEgI2M/E1PG0+PunpEsudiZFX47Fs65bc65bOdcDjCe36dA0oHqYbsme21Hat8JlDOzmMPa8/VHpkceB17P64Hwtxyjx7zKHd2657XbCWf8q2OJjonm6rZ/AaBSpcrM/Xgh5cqV5+t1a+nTuyczZ31A6dKlI1zpiSsjYz+DHu7DfX37U+qw18nC/ks5/Yz6vP3ePBISSvGvJYsZ+OC9TH1nTu6+mZkHWbp4IXf16nM8y/edlElTqJyUxK6dO7m7++3UrFWb7Oxs9v78M5MmT2Pd2jX0e7AP7835mFfHjOamW7qSkFAq0mUXq+Je8mdmVZxzW7y7HYDfVpbMBqaY2UigKlAP+ILQQpV6ZlaLUCh3Ajo755yZLQSuB6YCXYAC52LzDW0z++pIDwFJRzrOe4sxDuDXLFxBRZwIZr0zk8WfLGLcaxNz/0+Pi4vLfYtZ/8yzqF69Bps3beTMs86OZKknrKysTB55uA9XtbqaS5tfCUD5Cons2PETFStWYseOnyhfvgLAIYF+4UWX8PzTT7Bnz27KlQvNzS5buoQ/nVGfCokVj/+J+EjlpNB/hhUSE7m8xRWsW/sVlZOSaH7FlZgZZ53dgCiLYs/u3axZ8xUfz5vLiy88y759+4iyKOLiStCp880RPouiFV2EoW1mbwKXARXNLA0YAlxmZg0BB2wCegA459aZ2XTgayAL6Omcy/b66QXMBaKBFOfcOu8p+gFTzWwY8CXwWkE1FTTSTgJaElqKcsi5AJ8V1LlfLP10MRNTJvDapH8QHx+f275r1y7Kli1LdHQ0aT/8wObNm0hOrp5PTycv5xxPDR3MabVq0+nmrrntF116OXPef5dbunZjzvvvcvGllwOwc8dPVEisiJnx9dqvyMnJoWzZcrnHfTz3Q02NFOBARgY5LodSpUpzICODZZ8tpdtdPUlIKMWKL77g/CZN2bxpI5mZmZQrX56USZNzj31lzEskJCQELrChaD8R6Zy7MY/mIwarc244MDyP9g+BD/No/57fp1cKpaDQfh8o7ZxbffgDZrboaJ7oRNHvwQdYsfwL9uzZzZXNL+HunveSMn4cBzMPctedoWn635b2rVqxnJdHjyI2JgaLiuKRwY9Ttly5yJ7ACeqr/1vF3A9nU6fun+ja+VoAetzTh5u73MngAQ/wwayZJFWpyhNPha7ZLJr/Ee/MmEZ0dDQlSpTk8Sefy32Hc+BABsu/+IyHBg2J2Pn4wc6dO+nbpxcQWmnTqk1bml10MZmZB3ns0UHc0OEaYmNjeXz4iMB/SjBc0D/Gbs4V7+yFX6ZH/GzfgayCd5I/JKFEdKRLOCmUivvjv136vvdtoTPn+WtO913En3Trt
EUk2II+0lZoi0igBH0mSKEtIoESE/DUVmiLSKAEPLMV2iISLFEBT22FtogESsAzW6EtIsGi1SMiIj5SyC838C2FtogESsAzW6EtIsFihfr2R/9SaItIoGikLSLiIwptEREfCfpfNFRoi0igRAf8m28V2iISKPpEpIiIj2hOW0TERwI+0FZoi0iwRGmdtoiIf2ikLSLiIzEBn9RWaItIoGikLSLiI1ryJyLiIwHPbIW2iARLwD8QqdAWkWDR9IiIiI8EPbSD/k5CRE4ydhRbgX2ZpZjZdjNbG9ZWwczmmdkG72d5r93MbJSZpZrZV2Z2XtgxXbz9N5hZl7D2Rma2xjtmlBXiTxQqtEUkUMwKvxXCRKDVYW39gfnOuXrAfO8+QGugnrd1B8aG6rEKwBDgAqAJMOS3oPf26RZ23OHP9T8U2iISKGZW6K0gzrnFwK7DmtsBk7zbk4D2Ye1vuJBlQDkzqwK0BOY553Y553YD84BW3mNlnHPLnHMOeCOsryNSaItIoEQdxWZm3c1sRdjWvRBPkeSc2+Ld3gokeberAT+E7ZfmteXXnpZHe750IVJEAuVoLkQ658YB4471uZxzzszcsR5/LIo9tLNzjuv5nJRKl9Tv3uJWoUmvSJdwUjjw5eg/3Mdx+LqxbWZWxTm3xZvi2O61pwPVw/ZL9trSgcsOa1/ktSfnsX++ND0iIoFyNNMjx2g28NsKkC7ArLD2W71VJE2Bn71plLnAVWZW3rsAeRUw13tsr5k19VaN3BrW1xFpiCYigVKUI20ze5PQKLmimaURWgUyAphuZncAm4GO3u4fAm2AVCADuA3AObfLzJ4Alnv7DXXO/XZx8x5CK1TigTneli+FtogESlFOjjjnbjzCQy3y2NcBPY/QTwqQkkf7CuCso6lJoS0igRId8E9EKrRFJFACntkKbREJFtN3RIqI+IdG2iIiPqJvYxcR8RGNtEVEfCTof09boS0igRIV7MxWaItIsGj1iIiIjwR8dkShLSLBopG2iIiPaE5bRMRHtHpERMRHgh3ZCm0RCRiNtEVEfCTYka3QFpGgCXhqK7RFJFA0PSIi4iPBjmyFtogETcBTW6EtIoGiT0SKiPhIwKe0FdoiEiwBz2yFtogEiwV8qK3QFpFACXhmK7RFJFgCntkKbREJmICntkJbRAJFS/4C6OqWzSmVUIqo6Giio6OZPG0G/R68n82bNgKwb99eTjmlDFPffjf3mC1bfuT6dm3pcU9Pbu16R4Qq948hjwxg8eJFVKiQyIx33wfg22++YfgTQ8jIyKBq1Wo8+fRzlC5dmj17dvPg/b1Zt3Ytf2nfgQGDBke4+hNHclI5JjxxK5UTT8E5SJmxlJffXMTZf6rGS4M6USq+BJt/3Mltgyaxb/+vADx4+1V0bXch2Tk59H3mbT7+13rqnVaZvz99e26/taol8sTYDxg9Jf++/Ehz2gH1asoblC9fPvf+08+9kHt75LMjKF36lEP2H/nsCJpddPFxq8/v/tL+Wjp1vplHBvbLbXt8yCAeeLAfjc9vwrsz32bS6xPoeW8fSsSVoOe995G6YQOpqRsiWPWJJys7h/4jZ7L6mzRKJ5Tgsyn9mP/5N4wd3Jn+L7zDkpWp3NquKfd3acHQMR9wRu1TuaHleZx3/XCqVCrLh6/04uz2Q9mweTtNO40AICrK+Pfc4cxe+H8AR+zLr4oytM1sE7APyAaynHONzawCMA2oCWwCOjrndlto2cqLQBsgA+jqnFvl9dMFeMTrdphzbtKx1hR1rAcGlXOOeXP/Sas2V+e2LZz/MVWrJVO7bt0IVuYvjRqfT5myZQ9p+8/mTTRqfD4ATS9sxvx5HwEQn5DAuec1Jq5EieNe54lu6469rP4mDYBfMv7LNxu3UrVSOerWqMySlakALFj2De1bNASg7WUNeGvuKg5mZrH5x538+4cdnH9WzUP6vLzJ6WxM+4n/bNkNcMS+/MqO4n+FdLlzrqFzrrF3vz8w3zlXD5jv3QdoDdTztu7AWAAv5IcAFwBNgCFmVp5jVGBom9kZZtbCzEof1t7qWJ800syMnj3uoHPHa5nx1rRDHlu1cgUVEhOpcVpNADIy9jMxZTw97u4ZgUqDpXadeixcMB+AeR/9k61bt0S4In+pUaUCDU9PZvnaTaz/fgvXXNYAgGuvPI/kpFAGVKtUlrStu3OPSd++m6qVD/3leUPLRkz/58rc+0fqy6/MCr8do3bAbyPlSUD7sPY3XMgyoJyZVQFaAvOcc7ucc7uBecAx52e+oW1mvYFZwL3AWjNrF/bwk8f6pJGWMmkKU6bPZPTY8UyfOoWVK5bnPjZ3zgeHjLJfHTOam27pSkJCqUiUGiiPPzGc6VOncGPHa9m/fz+xsXGRLsk3SsXH8eZzd/LQczPYt/9Xejw2me4dL2bp5IcpnVCCg5nZheonNiaaqy89m5nzvsxtO9a+TlR2NJtZdzNbEbZ1P6w7B3xkZivDHktyzv024tgKJHm3qwE/hB2b5rUdqf2YFDSn3Q1o5Jz7xcxqAm+bWU3n3Ivks7DGO7nuAKNefoXb7zz8dYisykmh17hCYiKXt7iCdWu/olHj88nKymLBx/OYPG1G7r5r1nzFx/Pm8uILz7Jv3z6iLIq4uBJ06nxzpMr3rVq16/DK+BQANm/ayKeLF0W2IJ+IiYnizee6MW3OCmYtCM1Df7dpG9fc8zIQmt5offGZAKT/9DPJp/4+Uq5WuTw/bv85937Li+qz+psf2L5rX27bkfryraMYQTvnxgHj8tnlIudcuplVBuaZ2TeHHe/MzB1TnceooNCOcs79AuCc22RmlxEK7tPI56UJfyH2H3TH9YQKciAjgxyXQ6lSpTmQkcGyz5bS7a7Q1Mfny/5FzVq1SDr11Nz9UyZNzr39ypiXSEhIUGAfo107d1IhMZGcnBzGvzqWGzp2inRJvvDKkJv4duNWRv1jQW5bpfKl+Wn3L5gZ/bu1ZPzbSwD4YNFXTHyqK6P+voAqlcpSt0Yllq/dlHtcx1aND5kaya8vvyrKL0FwzqV7P7eb2TuE5qS3mVkV59wWb/pju7d7OlA97PBkry0duOyw9kXHWlNBob3NzBo651Z7hf9iZm2BFODsY33SSNq5cyd9+/QCIDs7m1Zt2uauCvlozge0atM2kuUFRv+HHmDF8i/Ys2c3V7W4hLvvuZeMjAymTZ0CQIsrrqRdh+ty9299VXP2//ILmZmZLFzwMWPHpVCnji78/rlhbW5qewFrvktn2dTQ9a4ho2dTt3plevz1EgBmLVjNG7OWAbD++63M+OhLvpwxiKzsHPqMmE5OTmjclFAyjuYXnEGvYW8e8hwdWzXOsy+/KqrINrNShAau+7zbVwFDgdlAF2CE93OWd8hsoJeZTSV00fFnL9jnAk+GXXy8ChhwzHW5fAbCZpZMaJnL1jwea+acW1rQE5xoI+0gCvrXK50IKjTpFekSTgoHvhz9h/8xf7cto9CZ86ekhPymeWsD73h3Y4ApzrnhZpYITAdqAJsJLfnb5S35G03oImMGcJtzboXX1+3AQK+v
4c6514/ytH6vK7/QLgoK7eKn0C5+Cu3joyhCe8O2A4XOnHpJ8b77j+ek/XCNiART0McwCm0RCZSAZ7ZCW0SCRV+CICLiIwHPbIW2iARLwDNboS0iARPw1FZoi0ig6EsQRER8RHPaIiI+EqXQFhHxk2CntkJbRAJF0yMiIj4S8MxWaItIsGikLSLiI/oYu4iIjwQ7shXaIhIwAR9oK7RFJFj0iUgRET8JdmYrtEUkWAKe2QptEQmWoH9nqkJbRAIl4JlNVKQLEBGRwtNIW0QCJegjbYW2iASKlvyJiPiIRtoiIj6i0BYR8RFNj4iI+IhG2iIiPhLwzFZoi0jABDy1FdoiEihB/xi7OeciXcMJx8y6O+fGRbqOINNrXPz0GgeTPsaet+6RLuAkoNe4+Ok1DiCFtoiIjyi0RUR8RKGdN80DFj+9xsVPr3EA6UKkiIiPaKQtIuIjCm0RER9RaIcxs1Zm9q2ZpZpZ/0jXE0RmlmJm281sbaRrCSozq25mC83sazNbZ2b3RbomKTqa0/aYWTTwHXAlkAYsB250zn0d0cICxswuAX4B3nDOnRXpeoLIzKoAVZxzq8zsFGAl0F7/loNBI+3fNQFSnXPfO+cOAlOBdhGuKXCcc4uBXZGuI8icc1ucc6u82/uA9UC1yFYlRUWh/btqwA9h99PQP3TxOTOrCZwLfB7hUqSIKLRFAsrMSgMzgD7Oub2RrkeKhkL7d+lA9bD7yV6biO+YWSyhwJ7snJsZ6Xqk6Ci0f7ccqGdmtcwsDugEzI5wTSJHzcwMeA1Y75wbGel6pGgptD3OuSygFzCX0IWb6c65dZGtKnjM7E3gX8DpZpZmZndEuqYAagbcAjQ3s9Xe1ibSRUnR0JI/EREf0UhbRMRHFNoiIj6i0BYR8RGFtoiIjyi0RUR8RKEtIuIjCm0RER/5/w2U5XyF/ncHAAAAAElFTkSuQmCC\n", 634 | "text/plain": [ 635 | "
" 636 | ] 637 | }, 638 | "metadata": { 639 | "needs_background": "light" 640 | }, 641 | "output_type": "display_data" 642 | } 643 | ], 644 | "source": [ 645 | "result = multilabel_predict_test\n", 646 | "#result = model.predict(validate_batches, val_epoch)\n", 647 | "labels = np.argmax(result, axis=1)\n", 648 | "target_names = ['Asian', 'Black', 'White']\n", 649 | "\n", 650 | "print ('Classwise ROC AUC \\n')\n", 651 | "for p in list(set(labels)):\n", 652 | " fpr, tpr, thresholds = roc_curve(test_batches.classes, result[:,p], pos_label = p)\n", 653 | " auroc = round(auc(fpr, tpr), 2)\n", 654 | " print ('Class - {} ROC-AUC- {}'.format(target_names[p], auroc))\n", 655 | "\n", 656 | "print (classification_report(test_batches.classes, labels, target_names=target_names))\n", 657 | "class_matrix = confusion_matrix(test_batches.classes, labels)\n", 658 | "\n", 659 | "sns.heatmap(class_matrix, annot=True, fmt='d', cmap='Blues')" 660 | ] 661 | } 662 | ], 663 | "metadata": { 664 | "kernelspec": { 665 | "display_name": "Python 3", 666 | "language": "python", 667 | "name": "python3" 668 | }, 669 | "language_info": { 670 | "codemirror_mode": { 671 | "name": "ipython", 672 | "version": 3 673 | }, 674 | "file_extension": ".py", 675 | "mimetype": "text/x-python", 676 | "name": "python", 677 | "nbconvert_exporter": "python", 678 | "pygments_lexer": "ipython3", 679 | "version": "3.6.9" 680 | } 681 | }, 682 | "nbformat": 4, 683 | "nbformat_minor": 4 684 | } 685 | -------------------------------------------------------------------------------- /training_code/digital_hand_atlas/dha_2_classes.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "1aab02dd", 6 | "metadata": {}, 7 | "source": [ 8 | "## Processing" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "6e8966f8", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import glob\n", 19 | "import random\n", 20 | "import tensorflow as tf\n", 21 | "import pandas as pd\n", 22 | "import numpy as np\n", 23 | "from sklearn.model_selection import train_test_split\n", 24 | "\n", 25 | "random.seed(2021)" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "id": "95d9f9ef", 32 | "metadata": {}, 33 | "outputs": [ 34 | { 35 | "data": { 36 | "text/html": [ 37 | "
\n", 38 | "\n", 51 | "\n", 52 | " \n", 53 | " \n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | "
filenameracegender
05599.jpg0F
15004.jpg0F
25024.jpg0F
36114.jpg0F
46127.jpg0F
\n", 93 | "
" 94 | ], 95 | "text/plain": [ 96 | " filename race gender\n", 97 | "0 5599.jpg 0 F\n", 98 | "1 5004.jpg 0 F\n", 99 | "2 5024.jpg 0 F\n", 100 | "3 6114.jpg 0 F\n", 101 | "4 6127.jpg 0 F" 102 | ] 103 | }, 104 | "execution_count": 2, 105 | "metadata": {}, 106 | "output_type": "execute_result" 107 | } 108 | ], 109 | "source": [ 110 | "img_height = 320\n", 111 | "img_width = 320\n", 112 | "batch_size = 8\n", 113 | "\n", 114 | "train_data_dir = \"../IMAGES/\"\n", 115 | "\n", 116 | "data = []\n", 117 | "labels = []\n", 118 | "\n", 119 | "metadata_df = pd.read_csv('../metadata.csv')\n", 120 | "metadata_df['race'] = pd.Categorical(pd.factorize(metadata_df.race)[0])\n", 121 | "metadata_df.head()" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 3, 127 | "id": "8ecce600", 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "metadata_df['race'] = metadata_df['race'].replace(0, 'Asian')\n", 132 | "metadata_df['race'] = metadata_df['race'].replace(1, 'Black')\n", 133 | "metadata_df['race'] = metadata_df['race'].replace(2, 'Caucasian')\n", 134 | "metadata_df['race'] = metadata_df['race'].replace(3, 'Hispanic')" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 4, 140 | "id": "0d32a491", 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "metadata_df = metadata_df[(metadata_df['race'] == 'Black') |\n", 145 | " ((metadata_df['race'] == 'Caucasian'))]" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 5, 151 | "id": "1407b14a", 152 | "metadata": {}, 153 | "outputs": [ 154 | { 155 | "data": { 156 | "text/plain": [ 157 | "Black 358\n", 158 | "Caucasian 333\n", 159 | "Name: race, dtype: int64" 160 | ] 161 | }, 162 | "execution_count": 5, 163 | "metadata": {}, 164 | "output_type": "execute_result" 165 | } 166 | ], 167 | "source": [ 168 | "num_classes = len(metadata_df['race'].value_counts().index)\n", 169 | "metadata_df['race'].value_counts()" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 6, 175 | "id": "b0ca5e0c", 176 | "metadata": {}, 177 | "outputs": [ 178 | { 179 | "data": { 180 | "text/plain": [ 181 | "array(['Black', 'Caucasian'], dtype=object)" 182 | ] 183 | }, 184 | "execution_count": 6, 185 | "metadata": {}, 186 | "output_type": "execute_result" 187 | } 188 | ], 189 | "source": [ 190 | "from sklearn.preprocessing import LabelEncoder\n", 191 | "\n", 192 | "le = LabelEncoder()\n", 193 | "metadata_df['race'] = le.fit_transform(metadata_df['race'])\n", 194 | "le.classes_" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 7, 200 | "id": "c6c0677f", 201 | "metadata": {}, 202 | "outputs": [ 203 | { 204 | "data": { 205 | "text/plain": [ 206 | "M 351\n", 207 | "F 340\n", 208 | "Name: gender, dtype: int64" 209 | ] 210 | }, 211 | "execution_count": 7, 212 | "metadata": {}, 213 | "output_type": "execute_result" 214 | } 215 | ], 216 | "source": [ 217 | "metadata_df['gender'].value_counts()" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 8, 223 | "id": "b25d3d9a", 224 | "metadata": {}, 225 | "outputs": [], 226 | "source": [ 227 | "for i in metadata_df['filename']:\n", 228 | " key = i\n", 229 | " label = metadata_df.loc[metadata_df['filename'] == key].iloc[0]['race']\n", 230 | " labels.append(label)\n", 231 | " filepath = train_data_dir+i\n", 232 | " image=tf.keras.preprocessing.image.load_img(filepath, color_mode='rgb', target_size= (img_height,img_width))\n", 233 | " image=np.array(image)\n", 234 | " 
data.append(image)\n", 235 | " \n", 236 | "X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=0.20, random_state=42) \n", 237 | "X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.125, random_state=42)" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 9, 243 | "id": "070528ab", 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "import pickle\n", 248 | "\n", 249 | "test_data = (X_test, y_test)\n", 250 | "pickle.dump(test_data, open('../test_data_dha_2_classes.pkl', 'wb'))" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 10, 256 | "id": "5bffef0a", 257 | "metadata": {}, 258 | "outputs": [], 259 | "source": [ 260 | "X_train = np.asarray(X_train)\n", 261 | "X_test = np.asarray(X_test)\n", 262 | "y_train = np.asarray(y_train)\n", 263 | "y_test = np.asarray(y_test)\n", 264 | "X_val = np.asarray(X_val)\n", 265 | "y_val = np.asarray(y_val)" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 11, 271 | "id": "969e0eae", 272 | "metadata": {}, 273 | "outputs": [ 274 | { 275 | "data": { 276 | "text/plain": [ 277 | "
" 278 | ] 279 | }, 280 | "metadata": {}, 281 | "output_type": "display_data" 282 | }, 283 | { 284 | "data": { 285 | "text/plain": [ 286 | "1" 287 | ] 288 | }, 289 | "metadata": {}, 290 | "output_type": "display_data" 291 | } 292 | ], 293 | "source": [ 294 | "from matplotlib import pyplot as plt\n", 295 | "plt.imshow(X_train[10], interpolation='nearest')\n", 296 | "plt.show()\n", 297 | "display(y_train[10])" 298 | ] 299 | }, 300 | { 301 | "cell_type": "markdown", 302 | "id": "c282099e", 303 | "metadata": {}, 304 | "source": [ 305 | "## Model creation" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": 12, 311 | "id": "a5f07180", 312 | "metadata": {}, 313 | "outputs": [], 314 | "source": [ 315 | "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n", 316 | "from tensorflow.keras.models import Sequential\n", 317 | "from tensorflow.keras.layers import Conv2D, MaxPool2D\n", 318 | "from tensorflow.keras.layers import Activation, Dropout, Flatten, Dense, concatenate, add, GlobalAveragePooling2D, BatchNormalization, Input\n", 319 | "from tensorflow.keras import backend as K\n", 320 | "from tensorflow.keras.optimizers import Adam\n", 321 | "from tensorflow.keras import layers\n", 322 | "from tensorflow.keras.models import Model\n", 323 | "from tensorflow.keras.callbacks import EarlyStopping, LearningRateScheduler, ReduceLROnPlateau, ModelCheckpoint\n", 324 | "from tensorflow.keras import optimizers\n", 325 | "from tensorflow.keras.applications.densenet import DenseNet121\n", 326 | "from classification_models.tfkeras import Classifiers\n", 327 | "from tensorflow.keras.models import load_model\n", 328 | "\n", 329 | "resnet34, preprocess_input = Classifiers.get('resnet50')\n", 330 | "\n", 331 | "if K.image_data_format() == 'channels_first':\n", 332 | " input_shape = (3, img_width, img_height)\n", 333 | "else:\n", 334 | " input_shape = (img_width, img_height, 3)\n", 335 | " \n", 336 | "input_a = Input(shape=(img_height, img_width, 3))\n", 337 | "base_model = resnet34(input_tensor=input_a, include_top=False, input_shape=(img_height, img_width,3), weights='imagenet')\n", 338 | "\n", 339 | "x = GlobalAveragePooling2D()(base_model.output)\n", 340 | "x = layers.Dense(num_classes, name='dense_logits')(x)\n", 341 | "output = layers.Activation('softmax', dtype='float32', name='predictions')(x)\n", 342 | "model = Model(inputs=[input_a], outputs=[output])\n", 343 | "\n", 344 | "adam_opt = Adam(lr=0.000001)\n", 345 | "reduce_lr = ReduceLROnPlateau(monitor='val_loss', mode='min', factor=0.1, patience=2, min_lr=1e-6, verbose=1)\n", 346 | "model.compile(optimizer=adam_opt, loss=tf.losses.CategoricalCrossentropy(),\n", 347 | " metrics=[\n", 348 | " tf.keras.metrics.AUC(curve='ROC', name='ROC-AUC'),\n", 349 | " tf.keras.metrics.AUC(curve='PR', name='PR-AUC')\n", 350 | " ],\n", 351 | " )" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": 13, 357 | "id": "e2441236", 358 | "metadata": {}, 359 | "outputs": [ 360 | { 361 | "name": "stderr", 362 | "output_type": "stream", 363 | "text": [ 364 | "Using TensorFlow backend.\n" 365 | ] 366 | } 367 | ], 368 | "source": [ 369 | "from keras.utils import to_categorical\n", 370 | "y_train_cat = to_categorical(y_train)\n", 371 | "y_test_cat = to_categorical(y_test)\n", 372 | "y_val_cat = to_categorical(y_val)" 373 | ] 374 | }, 375 | { 376 | "cell_type": "markdown", 377 | "id": "cb1fd580", 378 | "metadata": {}, 379 | "source": [ 380 | "## Model Training" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | 
"execution_count": 14, 386 | "id": "c2064f60", 387 | "metadata": { 388 | "scrolled": true 389 | }, 390 | "outputs": [ 391 | { 392 | "name": "stdout", 393 | "output_type": "stream", 394 | "text": [ 395 | "Train on 483 samples, validate on 69 samples\n", 396 | "Epoch 1/100\n", 397 | "483/483 [==============================] - 14s 29ms/sample - loss: 0.7112 - ROC-AUC: 0.5332 - PR-AUC: 0.5126 - val_loss: 0.7675 - val_ROC-AUC: 0.4033 - val_PR-AUC: 0.4225\n", 398 | "Epoch 2/100\n", 399 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.6822 - ROC-AUC: 0.5925 - PR-AUC: 0.5768 - val_loss: 0.7336 - val_ROC-AUC: 0.4500 - val_PR-AUC: 0.4538\n", 400 | "Epoch 3/100\n", 401 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.6647 - ROC-AUC: 0.6375 - PR-AUC: 0.6208 - val_loss: 0.7190 - val_ROC-AUC: 0.4789 - val_PR-AUC: 0.4806\n", 402 | "Epoch 4/100\n", 403 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.6482 - ROC-AUC: 0.6818 - PR-AUC: 0.6656 - val_loss: 0.7075 - val_ROC-AUC: 0.5090 - val_PR-AUC: 0.5002\n", 404 | "Epoch 5/100\n", 405 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.6409 - ROC-AUC: 0.7010 - PR-AUC: 0.6888 - val_loss: 0.7064 - val_ROC-AUC: 0.5177 - val_PR-AUC: 0.4990\n", 406 | "Epoch 6/100\n", 407 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.6116 - ROC-AUC: 0.7758 - PR-AUC: 0.7684 - val_loss: 0.7044 - val_ROC-AUC: 0.5316 - val_PR-AUC: 0.5053\n", 408 | "Epoch 7/100\n", 409 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.5962 - ROC-AUC: 0.8073 - PR-AUC: 0.8045 - val_loss: 0.6984 - val_ROC-AUC: 0.5579 - val_PR-AUC: 0.5230\n", 410 | "Epoch 8/100\n", 411 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.5772 - ROC-AUC: 0.8482 - PR-AUC: 0.8478 - val_loss: 0.6902 - val_ROC-AUC: 0.5880 - val_PR-AUC: 0.5569\n", 412 | "Epoch 9/100\n", 413 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.5672 - ROC-AUC: 0.8657 - PR-AUC: 0.8674 - val_loss: 0.6830 - val_ROC-AUC: 0.6127 - val_PR-AUC: 0.5727\n", 414 | "Epoch 10/100\n", 415 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.5459 - ROC-AUC: 0.9062 - PR-AUC: 0.9045 - val_loss: 0.6742 - val_ROC-AUC: 0.6303 - val_PR-AUC: 0.5874\n", 416 | "Epoch 11/100\n", 417 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.5525 - ROC-AUC: 0.8860 - PR-AUC: 0.8889 - val_loss: 0.6656 - val_ROC-AUC: 0.6455 - val_PR-AUC: 0.6007\n", 418 | "Epoch 12/100\n", 419 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.5306 - ROC-AUC: 0.9221 - PR-AUC: 0.9240 - val_loss: 0.6554 - val_ROC-AUC: 0.6648 - val_PR-AUC: 0.6240\n", 420 | "Epoch 13/100\n", 421 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.5180 - ROC-AUC: 0.9340 - PR-AUC: 0.9316 - val_loss: 0.6454 - val_ROC-AUC: 0.6856 - val_PR-AUC: 0.6522\n", 422 | "Epoch 14/100\n", 423 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.5063 - ROC-AUC: 0.9458 - PR-AUC: 0.9466 - val_loss: 0.6377 - val_ROC-AUC: 0.6993 - val_PR-AUC: 0.6757\n", 424 | "Epoch 15/100\n", 425 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.4904 - ROC-AUC: 0.9576 - PR-AUC: 0.9576 - val_loss: 0.6324 - val_ROC-AUC: 0.7065 - val_PR-AUC: 0.6878\n", 426 | "Epoch 16/100\n", 427 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.4763 - ROC-AUC: 0.9626 - PR-AUC: 0.9637 - val_loss: 0.6264 - val_ROC-AUC: 0.7116 - val_PR-AUC: 0.6964\n", 428 | "Epoch 17/100\n", 429 | 
"483/483 [==============================] - 6s 12ms/sample - loss: 0.4688 - ROC-AUC: 0.9661 - PR-AUC: 0.9644 - val_loss: 0.6214 - val_ROC-AUC: 0.7203 - val_PR-AUC: 0.7103\n", 430 | "Epoch 18/100\n", 431 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.4521 - ROC-AUC: 0.9760 - PR-AUC: 0.9767 - val_loss: 0.6167 - val_ROC-AUC: 0.7279 - val_PR-AUC: 0.7192\n", 432 | "Epoch 19/100\n", 433 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.4486 - ROC-AUC: 0.9744 - PR-AUC: 0.9746 - val_loss: 0.6121 - val_ROC-AUC: 0.7307 - val_PR-AUC: 0.7197\n", 434 | "Epoch 20/100\n", 435 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.4297 - ROC-AUC: 0.9812 - PR-AUC: 0.9811 - val_loss: 0.6080 - val_ROC-AUC: 0.7322 - val_PR-AUC: 0.7225\n", 436 | "Epoch 21/100\n", 437 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.4244 - ROC-AUC: 0.9816 - PR-AUC: 0.9814 - val_loss: 0.6050 - val_ROC-AUC: 0.7336 - val_PR-AUC: 0.7226\n", 438 | "Epoch 22/100\n", 439 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.4138 - ROC-AUC: 0.9868 - PR-AUC: 0.9870 - val_loss: 0.6029 - val_ROC-AUC: 0.7368 - val_PR-AUC: 0.7271\n", 440 | "Epoch 23/100\n", 441 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.4011 - ROC-AUC: 0.9867 - PR-AUC: 0.9865 - val_loss: 0.5993 - val_ROC-AUC: 0.7446 - val_PR-AUC: 0.7345\n", 442 | "Epoch 24/100\n", 443 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.3909 - ROC-AUC: 0.9879 - PR-AUC: 0.9865 - val_loss: 0.5993 - val_ROC-AUC: 0.7379 - val_PR-AUC: 0.7276\n", 444 | "Epoch 25/100\n", 445 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.3791 - ROC-AUC: 0.9921 - PR-AUC: 0.9920 - val_loss: 0.5993 - val_ROC-AUC: 0.7398 - val_PR-AUC: 0.7285\n", 446 | "Epoch 26/100\n", 447 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.3733 - ROC-AUC: 0.9904 - PR-AUC: 0.9907 - val_loss: 0.5968 - val_ROC-AUC: 0.7447 - val_PR-AUC: 0.7339\n", 448 | "Epoch 27/100\n", 449 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.3665 - ROC-AUC: 0.9875 - PR-AUC: 0.9852 - val_loss: 0.5962 - val_ROC-AUC: 0.7442 - val_PR-AUC: 0.7321\n", 450 | "Epoch 28/100\n", 451 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.3466 - ROC-AUC: 0.9930 - PR-AUC: 0.9932 - val_loss: 0.5966 - val_ROC-AUC: 0.7438 - val_PR-AUC: 0.7343\n", 452 | "Epoch 29/100\n", 453 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.3452 - ROC-AUC: 0.9954 - PR-AUC: 0.9955 - val_loss: 0.5934 - val_ROC-AUC: 0.7481 - val_PR-AUC: 0.7406\n", 454 | "Epoch 30/100\n", 455 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.3337 - ROC-AUC: 0.9935 - PR-AUC: 0.9937 - val_loss: 0.5937 - val_ROC-AUC: 0.7504 - val_PR-AUC: 0.7432\n", 456 | "Epoch 31/100\n", 457 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.3218 - ROC-AUC: 0.9951 - PR-AUC: 0.9951 - val_loss: 0.5941 - val_ROC-AUC: 0.7472 - val_PR-AUC: 0.7382\n", 458 | "Epoch 32/100\n", 459 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.3167 - ROC-AUC: 0.9962 - PR-AUC: 0.9963 - val_loss: 0.5930 - val_ROC-AUC: 0.7477 - val_PR-AUC: 0.7336\n", 460 | "Epoch 33/100\n", 461 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.3156 - ROC-AUC: 0.9938 - PR-AUC: 0.9940 - val_loss: 0.5898 - val_ROC-AUC: 0.7511 - val_PR-AUC: 0.7363\n", 462 | "Epoch 34/100\n", 463 | "483/483 [==============================] - 6s 12ms/sample - loss: 
0.3098 - ROC-AUC: 0.9935 - PR-AUC: 0.9937 - val_loss: 0.5903 - val_ROC-AUC: 0.7505 - val_PR-AUC: 0.7349\n", 464 | "Epoch 35/100\n", 465 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.2812 - ROC-AUC: 0.9984 - PR-AUC: 0.9984 - val_loss: 0.5877 - val_ROC-AUC: 0.7555 - val_PR-AUC: 0.7394\n", 466 | "Epoch 36/100\n", 467 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.2811 - ROC-AUC: 0.9958 - PR-AUC: 0.9958 - val_loss: 0.5871 - val_ROC-AUC: 0.7581 - val_PR-AUC: 0.7473\n", 468 | "Epoch 37/100\n", 469 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.2800 - ROC-AUC: 0.9965 - PR-AUC: 0.9966 - val_loss: 0.5878 - val_ROC-AUC: 0.7543 - val_PR-AUC: 0.7404\n", 470 | "Epoch 38/100\n", 471 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.2608 - ROC-AUC: 0.9964 - PR-AUC: 0.9962 - val_loss: 0.5891 - val_ROC-AUC: 0.7556 - val_PR-AUC: 0.7416\n", 472 | "Epoch 39/100\n", 473 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.2555 - ROC-AUC: 0.9983 - PR-AUC: 0.9983 - val_loss: 0.5875 - val_ROC-AUC: 0.7559 - val_PR-AUC: 0.7431\n", 474 | "Epoch 40/100\n", 475 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.2521 - ROC-AUC: 0.9966 - PR-AUC: 0.9962 - val_loss: 0.5892 - val_ROC-AUC: 0.7576 - val_PR-AUC: 0.7457\n", 476 | "Epoch 41/100\n", 477 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.2340 - ROC-AUC: 0.9996 - PR-AUC: 0.9996 - val_loss: 0.5860 - val_ROC-AUC: 0.7615 - val_PR-AUC: 0.7502\n", 478 | "Epoch 42/100\n", 479 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.2302 - ROC-AUC: 0.9979 - PR-AUC: 0.9979 - val_loss: 0.5808 - val_ROC-AUC: 0.7660 - val_PR-AUC: 0.7513\n", 480 | "Epoch 43/100\n", 481 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.2228 - ROC-AUC: 0.9992 - PR-AUC: 0.9993 - val_loss: 0.5809 - val_ROC-AUC: 0.7669 - val_PR-AUC: 0.7527\n", 482 | "Epoch 44/100\n", 483 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.2269 - ROC-AUC: 0.9982 - PR-AUC: 0.9982 - val_loss: 0.5763 - val_ROC-AUC: 0.7706 - val_PR-AUC: 0.7566\n", 484 | "Epoch 45/100\n", 485 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.2056 - ROC-AUC: 0.9988 - PR-AUC: 0.9988 - val_loss: 0.5723 - val_ROC-AUC: 0.7743 - val_PR-AUC: 0.7613\n", 486 | "Epoch 46/100\n", 487 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.2119 - ROC-AUC: 0.9967 - PR-AUC: 0.9967 - val_loss: 0.5703 - val_ROC-AUC: 0.7775 - val_PR-AUC: 0.7645\n", 488 | "Epoch 47/100\n", 489 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.1863 - ROC-AUC: 0.9996 - PR-AUC: 0.9996 - val_loss: 0.5698 - val_ROC-AUC: 0.7783 - val_PR-AUC: 0.7638\n", 490 | "Epoch 48/100\n", 491 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.1925 - ROC-AUC: 0.9994 - PR-AUC: 0.9994 - val_loss: 0.5660 - val_ROC-AUC: 0.7863 - val_PR-AUC: 0.7730\n", 492 | "Epoch 49/100\n", 493 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.1862 - ROC-AUC: 0.9990 - PR-AUC: 0.9990 - val_loss: 0.5584 - val_ROC-AUC: 0.7923 - val_PR-AUC: 0.7780\n", 494 | "Epoch 50/100\n", 495 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.1806 - ROC-AUC: 0.9988 - PR-AUC: 0.9988 - val_loss: 0.5578 - val_ROC-AUC: 0.7927 - val_PR-AUC: 0.7763\n", 496 | "Epoch 51/100\n", 497 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.1671 - ROC-AUC: 0.9998 - PR-AUC: 0.9998 - val_loss: 0.5571 - 
val_ROC-AUC: 0.7912 - val_PR-AUC: 0.7748\n", 498 | "Epoch 52/100\n", 499 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.1731 - ROC-AUC: 0.9996 - PR-AUC: 0.9996 - val_loss: 0.5550 - val_ROC-AUC: 0.7942 - val_PR-AUC: 0.7805\n", 500 | "Epoch 53/100\n", 501 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.1584 - ROC-AUC: 0.9998 - PR-AUC: 0.9998 - val_loss: 0.5508 - val_ROC-AUC: 0.7972 - val_PR-AUC: 0.7850\n", 502 | "Epoch 54/100\n", 503 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.1558 - ROC-AUC: 0.9994 - PR-AUC: 0.9994 - val_loss: 0.5518 - val_ROC-AUC: 0.7963 - val_PR-AUC: 0.7815\n", 504 | "Epoch 55/100\n", 505 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.1524 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5538 - val_ROC-AUC: 0.7963 - val_PR-AUC: 0.7793\n", 506 | "Epoch 56/100\n", 507 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.1543 - ROC-AUC: 0.9997 - PR-AUC: 0.9997 - val_loss: 0.5566 - val_ROC-AUC: 0.7966 - val_PR-AUC: 0.7797\n", 508 | "Epoch 57/100\n", 509 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.1446 - ROC-AUC: 0.9999 - PR-AUC: 0.9999 - val_loss: 0.5546 - val_ROC-AUC: 0.7954 - val_PR-AUC: 0.7788\n", 510 | "Epoch 58/100\n", 511 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.1481 - ROC-AUC: 0.9997 - PR-AUC: 0.9997 - val_loss: 0.5590 - val_ROC-AUC: 0.7948 - val_PR-AUC: 0.7796\n", 512 | "Epoch 59/100\n", 513 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.1426 - ROC-AUC: 0.9980 - PR-AUC: 0.9980 - val_loss: 0.5604 - val_ROC-AUC: 0.7937 - val_PR-AUC: 0.7770\n", 514 | "Epoch 60/100\n", 515 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.1391 - ROC-AUC: 0.9993 - PR-AUC: 0.9993 - val_loss: 0.5616 - val_ROC-AUC: 0.7910 - val_PR-AUC: 0.7728\n", 516 | "Epoch 61/100\n", 517 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.1357 - ROC-AUC: 0.9996 - PR-AUC: 0.9996 - val_loss: 0.5544 - val_ROC-AUC: 0.7992 - val_PR-AUC: 0.7813\n", 518 | "Epoch 62/100\n", 519 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.1165 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5499 - val_ROC-AUC: 0.8051 - val_PR-AUC: 0.7893\n", 520 | "Epoch 63/100\n", 521 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.1190 - ROC-AUC: 0.9997 - PR-AUC: 0.9997 - val_loss: 0.5491 - val_ROC-AUC: 0.8026 - val_PR-AUC: 0.7856\n", 522 | "Epoch 64/100\n", 523 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.1111 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5542 - val_ROC-AUC: 0.7973 - val_PR-AUC: 0.7803\n", 524 | "Epoch 65/100\n", 525 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.1101 - ROC-AUC: 0.9999 - PR-AUC: 0.9999 - val_loss: 0.5555 - val_ROC-AUC: 0.7982 - val_PR-AUC: 0.7826\n", 526 | "Epoch 66/100\n", 527 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.1063 - ROC-AUC: 0.9999 - PR-AUC: 0.9999 - val_loss: 0.5483 - val_ROC-AUC: 0.8026 - val_PR-AUC: 0.7870\n", 528 | "Epoch 67/100\n", 529 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.1054 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5462 - val_ROC-AUC: 0.8057 - val_PR-AUC: 0.7901\n", 530 | "Epoch 68/100\n", 531 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.1012 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5489 - val_ROC-AUC: 0.8047 - val_PR-AUC: 0.7892\n", 532 | "Epoch 69/100\n", 
533 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.1016 - ROC-AUC: 0.9998 - PR-AUC: 0.9998 - val_loss: 0.5487 - val_ROC-AUC: 0.8072 - val_PR-AUC: 0.7936\n", 534 | "Epoch 70/100\n", 535 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.1041 - ROC-AUC: 0.9999 - PR-AUC: 0.9999 - val_loss: 0.5462 - val_ROC-AUC: 0.8097 - val_PR-AUC: 0.7971\n", 536 | "Epoch 71/100\n", 537 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0970 - ROC-AUC: 0.9998 - PR-AUC: 0.9998 - val_loss: 0.5514 - val_ROC-AUC: 0.8082 - val_PR-AUC: 0.7939\n", 538 | "Epoch 72/100\n", 539 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0821 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5477 - val_ROC-AUC: 0.8093 - val_PR-AUC: 0.7941\n", 540 | "Epoch 73/100\n", 541 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0807 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5491 - val_ROC-AUC: 0.8091 - val_PR-AUC: 0.7969\n", 542 | "Epoch 74/100\n", 543 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0935 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5440 - val_ROC-AUC: 0.8124 - val_PR-AUC: 0.8010\n", 544 | "Epoch 75/100\n", 545 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0841 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5424 - val_ROC-AUC: 0.8141 - val_PR-AUC: 0.8023\n", 546 | "Epoch 76/100\n", 547 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0834 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5389 - val_ROC-AUC: 0.8147 - val_PR-AUC: 0.8028\n", 548 | "Epoch 77/100\n", 549 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0716 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5417 - val_ROC-AUC: 0.8150 - val_PR-AUC: 0.8019\n", 550 | "Epoch 78/100\n", 551 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0654 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5441 - val_ROC-AUC: 0.8131 - val_PR-AUC: 0.7998\n", 552 | "Epoch 79/100\n", 553 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0625 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5393 - val_ROC-AUC: 0.8188 - val_PR-AUC: 0.8080\n", 554 | "Epoch 80/100\n", 555 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0682 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5413 - val_ROC-AUC: 0.8175 - val_PR-AUC: 0.8059\n", 556 | "Epoch 81/100\n", 557 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0692 - ROC-AUC: 0.9999 - PR-AUC: 0.9999 - val_loss: 0.5407 - val_ROC-AUC: 0.8181 - val_PR-AUC: 0.8061\n", 558 | "Epoch 82/100\n", 559 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0658 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5372 - val_ROC-AUC: 0.8220 - val_PR-AUC: 0.8109\n", 560 | "Epoch 83/100\n", 561 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0663 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5370 - val_ROC-AUC: 0.8210 - val_PR-AUC: 0.8096\n", 562 | "Epoch 84/100\n", 563 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0701 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5424 - val_ROC-AUC: 0.8183 - val_PR-AUC: 0.8071\n", 564 | "Epoch 85/100\n", 565 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0523 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5436 - val_ROC-AUC: 0.8183 - val_PR-AUC: 0.8073\n", 566 | "Epoch 86/100\n", 567 | "483/483 [==============================] - 6s 12ms/sample - 
loss: 0.0547 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5454 - val_ROC-AUC: 0.8176 - val_PR-AUC: 0.8054\n", 568 | "Epoch 87/100\n", 569 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0589 - ROC-AUC: 0.9998 - PR-AUC: 0.9998 - val_loss: 0.5500 - val_ROC-AUC: 0.8164 - val_PR-AUC: 0.8035\n", 570 | "Epoch 88/100\n", 571 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0686 - ROC-AUC: 0.9996 - PR-AUC: 0.9996 - val_loss: 0.5475 - val_ROC-AUC: 0.8192 - val_PR-AUC: 0.8079\n", 572 | "Epoch 89/100\n", 573 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0457 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5473 - val_ROC-AUC: 0.8171 - val_PR-AUC: 0.8056\n", 574 | "Epoch 90/100\n", 575 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0501 - ROC-AUC: 0.9999 - PR-AUC: 0.9999 - val_loss: 0.5539 - val_ROC-AUC: 0.8151 - val_PR-AUC: 0.8038\n", 576 | "Epoch 91/100\n", 577 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0574 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5638 - val_ROC-AUC: 0.8139 - val_PR-AUC: 0.8021\n", 578 | "Epoch 92/100\n", 579 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0586 - ROC-AUC: 0.9999 - PR-AUC: 0.9999 - val_loss: 0.5643 - val_ROC-AUC: 0.8141 - val_PR-AUC: 0.8013\n", 580 | "Epoch 93/100\n", 581 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0496 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5552 - val_ROC-AUC: 0.8174 - val_PR-AUC: 0.8054\n", 582 | "Epoch 94/100\n", 583 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0494 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5378 - val_ROC-AUC: 0.8229 - val_PR-AUC: 0.8120\n", 584 | "Epoch 95/100\n", 585 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0458 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5392 - val_ROC-AUC: 0.8229 - val_PR-AUC: 0.8130\n", 586 | "Epoch 96/100\n", 587 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0511 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5420 - val_ROC-AUC: 0.8227 - val_PR-AUC: 0.8131\n", 588 | "Epoch 97/100\n", 589 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0473 - ROC-AUC: 0.9999 - PR-AUC: 0.9999 - val_loss: 0.5453 - val_ROC-AUC: 0.8225 - val_PR-AUC: 0.8122\n", 590 | "Epoch 98/100\n", 591 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0528 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5456 - val_ROC-AUC: 0.8234 - val_PR-AUC: 0.8126\n", 592 | "Epoch 99/100\n", 593 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0404 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5478 - val_ROC-AUC: 0.8215 - val_PR-AUC: 0.8081\n", 594 | "Epoch 100/100\n", 595 | "483/483 [==============================] - 6s 12ms/sample - loss: 0.0407 - ROC-AUC: 1.0000 - PR-AUC: 1.0000 - val_loss: 0.5488 - val_ROC-AUC: 0.8238 - val_PR-AUC: 0.8106\n", 596 | "WARNING:tensorflow:From /home/jupyter-anbhimi/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1786: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n", 597 | "Instructions for updating:\n", 598 | "If using Keras pass *_constraint arguments to layers.\n", 599 | "INFO:tensorflow:Assets written to: ../classification_model_2_classes/assets\n" 600 | ] 601 | } 602 | ], 603 | "source": [ 604 | "model_path = 
'../classification_model_2_classes'\n", 605 | "history = model.fit(X_train,y_train_cat, validation_data=(X_val, y_val_cat), epochs=100, callbacks=[reduce_lr])\n", 606 | "tf.keras.models.save_model(model = model, filepath = model_path)" 607 | ] 608 | }, 609 | { 610 | "cell_type": "code", 611 | "execution_count": 15, 612 | "id": "4786cefb", 613 | "metadata": {}, 614 | "outputs": [ 615 | { 616 | "data": { 617 | "text/plain": [ 618 | "(139, 320, 320, 3)" 619 | ] 620 | }, 621 | "execution_count": 15, 622 | "metadata": {}, 623 | "output_type": "execute_result" 624 | } 625 | ], 626 | "source": [ 627 | "X_test.shape" 628 | ] 629 | }, 630 | { 631 | "cell_type": "code", 632 | "execution_count": 16, 633 | "id": "a83e76b5", 634 | "metadata": {}, 635 | "outputs": [], 636 | "source": [ 637 | "from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score\n", 638 | "\n", 639 | "model = tf.keras.models.load_model(filepath = model_path)\n", 640 | "predictions = model.predict(X_test)\n", 641 | "predictions_rounded = np.argmax(predictions, axis=1)" 642 | ] 643 | }, 644 | { 645 | "cell_type": "code", 646 | "execution_count": 17, 647 | "id": "779d9fac", 648 | "metadata": {}, 649 | "outputs": [ 650 | { 651 | "data": { 652 | "text/plain": [ 653 | "[0, 1]" 654 | ] 655 | }, 656 | "execution_count": 17, 657 | "metadata": {}, 658 | "output_type": "execute_result" 659 | } 660 | ], 661 | "source": [ 662 | "list(set(predictions_rounded))" 663 | ] 664 | }, 665 | { 666 | "cell_type": "code", 667 | "execution_count": 18, 668 | "id": "018be441", 669 | "metadata": {}, 670 | "outputs": [ 671 | { 672 | "name": "stdout", 673 | "output_type": "stream", 674 | "text": [ 675 | "0 - 0.86\n", 676 | "1 - 0.86\n" 677 | ] 678 | } 679 | ], 680 | "source": [ 681 | "from sklearn.metrics import roc_curve, auc\n", 682 | "\n", 683 | "for p in list(set(predictions_rounded)):\n", 684 | " fpr, tpr, thresholds = roc_curve(y_test, model.predict(X_test)[:,p], pos_label = p)\n", 685 | " auroc = round(auc(fpr, tpr), 2)\n", 686 | " print ('{} - {}'.format(p, auroc))" 687 | ] 688 | }, 689 | { 690 | "cell_type": "code", 691 | "execution_count": 19, 692 | "id": "0016732e", 693 | "metadata": {}, 694 | "outputs": [ 695 | { 696 | "data": { 697 | "text/plain": [ 698 | "array([[74, 8],\n", 699 | " [19, 38]])" 700 | ] 701 | }, 702 | "execution_count": 19, 703 | "metadata": {}, 704 | "output_type": "execute_result" 705 | } 706 | ], 707 | "source": [ 708 | "confusion_matrix(y_test, predictions_rounded)" 709 | ] 710 | }, 711 | { 712 | "cell_type": "code", 713 | "execution_count": 20, 714 | "id": "ab0099f4", 715 | "metadata": {}, 716 | "outputs": [ 717 | { 718 | "name": "stdout", 719 | "output_type": "stream", 720 | "text": [ 721 | " precision recall f1-score support\n", 722 | "\n", 723 | " 0 0.80 0.90 0.85 82\n", 724 | " 1 0.83 0.67 0.74 57\n", 725 | "\n", 726 | " accuracy 0.81 139\n", 727 | " macro avg 0.81 0.78 0.79 139\n", 728 | "weighted avg 0.81 0.81 0.80 139\n", 729 | "\n" 730 | ] 731 | } 732 | ], 733 | "source": [ 734 | "print (classification_report(y_test, predictions_rounded))" 735 | ] 736 | }, 737 | { 738 | "cell_type": "markdown", 739 | "id": "e4733ecb-ced6-499f-9800-9bd5f587a384", 740 | "metadata": {}, 741 | "source": [ 742 | "## LIME Model Interpretation" 743 | ] 744 | }, 745 | { 746 | "cell_type": "code", 747 | "execution_count": null, 748 | "id": "f3bdf4b2-c3ad-41d0-93d7-ee9b95c8642e", 749 | "metadata": {}, 750 | "outputs": [], 751 | "source": [ 752 | "import lime\n", 753 | "from lime import lime_image" 754 | ] 755 | }, 756 | {
757 | "cell_type": "code", 758 | "execution_count": null, 759 | "id": "1d196c2c-c02b-4a0f-9ed1-e4fbb7007bce", 760 | "metadata": {}, 761 | "outputs": [], 762 | "source": [ 763 | "i = 0 # change this to the index of the image to explain in the dataframe\n", 764 | "filepath = train_data_dir+metadata_df['filename'][i]\n", 765 | "image = tf.keras.preprocessing.image.load_img(filepath, color_mode='rgb', target_size=(img_height, img_width))\n", 766 | "image = np.array(image)" 767 | ] 768 | }, 769 | { 770 | "cell_type": "code", 771 | "execution_count": null, 772 | "id": "97146366-85b2-42c3-94c7-aeb27a287203", 773 | "metadata": {}, 774 | "outputs": [], 775 | "source": [ 776 | "explainer = lime_image.LimeImageExplainer()\n", 777 | "explanation = explainer.explain_instance(image, model.predict, top_labels=5)" 778 | ] 779 | }, 780 | { 781 | "cell_type": "code", 782 | "execution_count": null, 783 | "id": "6abe6822-1755-45f4-a3d3-a5a6c1a7eb83", 784 | "metadata": {}, 785 | "outputs": [], 786 | "source": [ 787 | "from skimage.segmentation import mark_boundaries\n", 788 | "\n", 789 | "temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=True, num_features=5, hide_rest=True)\n", 790 | "plt.imshow(mark_boundaries(image, mask))" 791 | ] 792 | }, 793 | { 794 | "cell_type": "code", 795 | "execution_count": null, 796 | "id": "7ee3eddb-0f45-4bae-8418-de74f659a47d", 797 | "metadata": {}, 798 | "outputs": [], 799 | "source": [ 800 | "temp, mask = explanation.get_image_and_mask(explanation.top_labels[1], positive_only=True, num_features=5, hide_rest=True)\n", 801 | "plt.imshow(mark_boundaries(image, mask))" 802 | ] 803 | }, 804 | { 805 | "cell_type": "markdown", 806 | "id": "d45adf91-f0b9-4ece-91ef-76a60cff2a0a", 807 | "metadata": {}, 808 | "source": [ 809 | "## Grad Cams" 810 | ] 811 | }, 812 | { 813 | "cell_type": "code", 814 | "execution_count": null, 815 | "id": "cc4a5898-50ac-4959-a085-5e9c0b7f1b3b", 816 | "metadata": {}, 817 | "outputs": [], 818 | "source": [ 819 | "# https://keras.io/examples/vision/grad_cam/\n", 820 | "def get_img_array(image_path, image_size):\n", 821 | " image = tf.keras.preprocessing.image.load_img(image_path, target_size=image_size)\n", 822 | " array = tf.keras.preprocessing.image.img_to_array(image)\n", 823 | " array = np.expand_dims(array, axis=0)\n", 824 | " return (array)\n", 825 | "\n", 826 | "def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):\n", 827 | " grad_model = tf.keras.models.Model(\n", 828 | " [model.inputs], [model.get_layer(last_conv_layer_name).output, model.output]\n", 829 | " )\n", 830 | " \n", 831 | " with tf.GradientTape() as tape:\n", 832 | " last_conv_layer_output, preds = grad_model(img_array)\n", 833 | " if pred_index is None:\n", 834 | " pred_index = tf.argmax(preds[0])\n", 835 | " class_channel = preds[:, pred_index]\n", 836 | " \n", 837 | " grads = tape.gradient(class_channel, last_conv_layer_output)\n", 838 | " pooled_grads = tf.reduce_mean(grads, axis=(0,1,2))\n", 839 | " last_conv_layer_output = last_conv_layer_output[0]\n", 840 | " heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]\n", 841 | " heatmap = tf.squeeze(heatmap)\n", 842 | " \n", 843 | " heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)\n", 844 | " return (heatmap.numpy())\n", 845 | "\n", 846 | "def save_and_display_gradcam(img_path, heatmap, cam_path='./cam.jpg', alpha=0.4):\n", 847 | " img = tf.keras.preprocessing.image.load_img(img_path)\n", 848 | " img = tf.keras.preprocessing.image.img_to_array(img)\n",
849 | " \n", 850 | " heatmap = np.uint8(255*heatmap)\n", 851 | " jet = cm.get_cmap('jet')\n", 852 | " jet_colors = jet(np.arange(256))[:, :3]\n", 853 | " jet_heatmap = jet_colors[heatmap]\n", 854 | " \n", 855 | " jet_heatmap = tf.keras.preprocessing.image.array_to_img(jet_heatmap)\n", 856 | " jet_heatmap = jet_heatmap.resize((img.shape[1], img.shape[0]))\n", 857 | " jet_heatmap = tf.keras.preprocessing.image.img_to_array(jet_heatmap)\n", 858 | " \n", 859 | " superimposed_img = jet_heatmap * alpha + img\n", 860 | " superimposed_img = tf.keras.preprocessing.image.array_to_img(superimposed_img)\n", 861 | " superimposed_img.save(cam_path)\n", 862 | " \n", 863 | " display(Image(cam_path))" 864 | ] 865 | }, 866 | { 867 | "cell_type": "code", 868 | "execution_count": null, 869 | "id": "5c742258-caa0-480d-9580-d82e818a8287", 870 | "metadata": {}, 871 | "outputs": [], 872 | "source": [ 873 | "i = 100 # change this to the index of the image to explain in the dataframe\n", 874 | "image_path = train_data_dir+metadata_df['filename'][i]\n", 875 | "img_array = get_img_array(image_path=image_path, image_size=(img_height, img_width))\n", 876 | "model = tf.keras.models.load_model(filepath = model_path)\n", 877 | "model.layers[-1].activation = None # remove the final softmax so Grad-CAM is computed on the raw class score\n", 878 | "preds = model.predict(img_array)\n", 879 | "heatmap = make_gradcam_heatmap(img_array, model, \n", 880 | " last_conv_layer_name='stage4_unit3_conv3')\n", 881 | "\n", 882 | "save_and_display_gradcam(image_path, heatmap)" 883 | ] 884 | }, 885 | { 886 | "cell_type": "code", 887 | "execution_count": null, 888 | "id": "2007d919-bee2-4b59-bf35-094ae9f8e761", 889 | "metadata": {}, 890 | "outputs": [], 891 | "source": [] 892 | } 893 | ], 894 | "metadata": { 895 | "kernelspec": { 896 | "display_name": "Python 3", 897 | "language": "python", 898 | "name": "python3" 899 | }, 900 | "language_info": { 901 | "codemirror_mode": { 902 | "name": "ipython", 903 | "version": 3 904 | }, 905 | "file_extension": ".py", 906 | "mimetype": "text/x-python", 907 | "name": "python", 908 | "nbconvert_exporter": "python", 909 | "pygments_lexer": "ipython3", 910 | "version": "3.8.8" 911 | } 912 | }, 913 | "nbformat": 4, 914 | "nbformat_minor": 5 915 | } 916 | --------------------------------------------------------------------------------