├── images
│   ├── figure.png
│   ├── q_results.png
│   ├── results1.png
│   └── results2.png
├── embeddings
│   └── nih_chest_xray_biobert.npy
├── .gitignore
├── plots.py
├── scripts
│   ├── test_densenet121.sh
│   └── train_densenet121.sh
├── test.py
├── train.py
├── loss.py
├── arguments.py
├── environment.yml
├── README.md
├── dataset.py
├── zsl_models.py
└── ChexnetTrainer.py

/images/figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nyuad-cai/CXR-ML-GZSL/HEAD/images/figure.png
--------------------------------------------------------------------------------

/images/q_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nyuad-cai/CXR-ML-GZSL/HEAD/images/q_results.png
--------------------------------------------------------------------------------

/images/results1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nyuad-cai/CXR-ML-GZSL/HEAD/images/results1.png
--------------------------------------------------------------------------------

/images/results2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nyuad-cai/CXR-ML-GZSL/HEAD/images/results2.png
--------------------------------------------------------------------------------

/embeddings/nih_chest_xray_biobert.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nyuad-cai/CXR-ML-GZSL/HEAD/embeddings/nih_chest_xray_biobert.npy
--------------------------------------------------------------------------------

/.gitignore:
--------------------------------------------------------------------------------
*.tar
__pycache__/*
.ipynb_checkpoints/*
.DS_Store
._**
*/**/._*
checkpoints/*
plots
*.ipynb
__pycache__
--------------------------------------------------------------------------------

/plots.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt


def plot_array(array, disc='loss'):
    plt.plot(array)
    plt.ylabel(disc)
    plt.savefig(f'{disc}.pdf')
    plt.close()
--------------------------------------------------------------------------------

/scripts/test_densenet121.sh:
--------------------------------------------------------------------------------
CUDA_VISIBLE_DEVICES=0 python test.py \
    --vision-backbone densenet121 \
    --textual-embeddings embeddings/nih_chest_xray_biobert.npy \
    --load-from checkpoints/best_auroc_checkpoint.pth.tar
--------------------------------------------------------------------------------

/scripts/train_densenet121.sh:
--------------------------------------------------------------------------------
CUDA_VISIBLE_DEVICES=0 python train.py \
    --pretrained \
    --vision-backbone densenet121 \
    --save-dir checkpoints \
    --epochs 40 \
    --lr 0.0001 \
    --beta-rank 1 \
    --beta-map 0.01 \
    --beta-con 0.01 \
    --neg-penalty 0.20 \
    --textual-embeddings embeddings/nih_chest_xray_biobert.npy \
    --data-root /data/shamoutlab/nih_chest_xrays
--------------------------------------------------------------------------------
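The three `--beta-*` flags in the training script weight the loss terms that `ZSLNet.forward_ranking` (in `zsl_models.py` below) sums into the training objective. A minimal sketch of that weighted sum, with made-up scalar values standing in for the real loss tensors; note that in the released code the alignment term is additionally multiplied by 0.0:

```python
# Sketch only: the weighted objective assembled in ZSLNet.forward_ranking.
# The loss values below are made up; see zsl_models.py for the real terms.
beta_rank, beta_map, beta_con = 1.0, 0.01, 0.01  # values from train_densenet121.sh

loss_rank = 0.8          # multi-label ranking loss
loss_alignment = 0.3     # visual-text cosine alignment loss
loss_consistency = 0.1   # semantic consistency regularizer

total = beta_rank * loss_rank + beta_con * loss_consistency + beta_map * loss_alignment
print(f'total: {total:.4f}')
```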
/test.py:
--------------------------------------------------------------------------------
import os

import numpy as np

from ChexnetTrainer import ChexnetTrainer
from arguments import parse_args


def main():
    args = parse_args()

    try:
        os.mkdir(args.save_dir)
    except OSError as error:
        print(error)

    trainer = ChexnetTrainer(args)
    print('Testing the trained model')

    test_ind_auroc = np.array(trainer.test())

    trainer.print_auroc(test_ind_auroc[trainer.test_dl.dataset.seen_class_ids], trainer.test_dl.dataset.seen_class_ids, prefix='\ntest_seen')
    trainer.print_auroc(test_ind_auroc[trainer.test_dl.dataset.unseen_class_ids], trainer.test_dl.dataset.unseen_class_ids, prefix='\ntest_unseen')


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------

/train.py:
--------------------------------------------------------------------------------
import os

import numpy as np
import torch

from ChexnetTrainer import ChexnetTrainer
from arguments import parse_args


def main():
    args = parse_args()

    # Fix the random seeds for reproducibility.
    seed = 1002
    torch.manual_seed(seed)
    np.random.seed(seed)

    try:
        os.mkdir(args.save_dir)
    except OSError as error:
        print(error)

    trainer = ChexnetTrainer(args)
    trainer()

    # Evaluate the checkpoint with the minimum validation loss.
    checkpoint = torch.load(f'{args.save_dir}/min_loss_checkpoint.pth.tar')
    trainer.model.load_state_dict(checkpoint['state_dict'])
    print('Testing the min loss model')
    test_ind_auroc = np.array(trainer.test())

    trainer.print_auroc(test_ind_auroc[trainer.test_dl.dataset.seen_class_ids], trainer.test_dl.dataset.seen_class_ids, prefix='\ntest_seen')
    trainer.print_auroc(test_ind_auroc[trainer.test_dl.dataset.unseen_class_ids], trainer.test_dl.dataset.unseen_class_ids, prefix='\ntest_unseen')

    # Evaluate the checkpoint with the best validation AUROC.
    checkpoint = torch.load(f'{args.save_dir}/best_auroc_checkpoint.pth.tar')
    trainer.model.load_state_dict(checkpoint['state_dict'])
    print('Testing the best AUROC model')
    test_ind_auroc = np.array(trainer.test())

    trainer.print_auroc(test_ind_auroc[trainer.test_dl.dataset.seen_class_ids], trainer.test_dl.dataset.seen_class_ids, prefix='\ntest_seen')
    trainer.print_auroc(test_ind_auroc[trainer.test_dl.dataset.unseen_class_ids], trainer.test_dl.dataset.unseen_class_ids, prefix='\ntest_unseen')


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
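Both scripts above restore weights from the checkpoint dictionaries written by `ChexnetTrainer.save_checkpoint` (see `ChexnetTrainer.py` below). A minimal sketch of inspecting one outside the trainer; the path is just an example:

```python
import torch

# Keys written by ChexnetTrainer.save_checkpoint; the path is an example.
checkpoint = torch.load('checkpoints/best_auroc_checkpoint.pth.tar', map_location='cpu')
print(checkpoint['epoch'], checkpoint['max_auroc_mean'], checkpoint['lossMIN'])
# 'state_dict' holds the ZSLNet weights; 'optimizer' holds the Adam state.
model_weights = checkpoint['state_dict']
```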
/loss.py:
--------------------------------------------------------------------------------
"""
Author: Nasir Hayat (nasirhayat6160@gmail.com)
Date: June 10, 2020
"""

from __future__ import print_function

import torch
import torch.nn as nn
from torch.nn.functional import kl_div, softmax, log_softmax


class KLDivLoss(nn.Module):
    def __init__(self, temperature=0.2):
        super(KLDivLoss, self).__init__()
        self.temperature = temperature

    def forward(self, emb1, emb2):
        # emb1 provides the (detached) target distribution, emb2 the prediction.
        emb1 = softmax(emb1 / self.temperature, dim=1).detach()
        emb2 = log_softmax(emb2 / self.temperature, dim=1)
        loss_kldiv = kl_div(emb2, emb1, reduction='none')
        loss_kldiv = torch.sum(loss_kldiv, dim=1)
        loss_kldiv = torch.mean(loss_kldiv)
        return loss_kldiv


class RankingLoss(nn.Module):
    def __init__(self, neg_penalty=0.03):
        super(RankingLoss, self).__init__()
        self.neg_penalty = neg_penalty

    def forward(self, ranks, labels, class_ids_loaded, device):
        '''
        For each image, every positive class should score higher than
        every absent class by at least the neg_penalty margin.
        '''
        labels = labels[:, class_ids_loaded]
        ranks_loaded = ranks[:, class_ids_loaded]
        neg_labels = 1 + (labels * -1)
        loss_rank = torch.zeros(1).to(device)
        for i in range(len(labels)):
            correct = ranks_loaded[i, labels[i] == 1]
            wrong = ranks_loaded[i, neg_labels[i] == 1]
            correct = correct.reshape((-1, 1)).repeat((1, len(wrong)))
            wrong = wrong.repeat(len(correct)).reshape(len(correct), -1)
            # Hinge on every (positive, negative) pair of scores.
            image_level_penalty = ((self.neg_penalty + wrong) - correct)
            image_level_penalty[image_level_penalty < 0] = 0
            loss_rank += image_level_penalty.sum()
        loss_rank /= len(labels)

        return loss_rank


class CosineLoss(nn.Module):

    def forward(self, t_emb, v_emb):
        # 1 - mean cosine similarity between matched text/visual embeddings.
        a_norm = v_emb / v_emb.norm(dim=1)[:, None]
        b_norm = t_emb / t_emb.norm(dim=1)[:, None]
        loss = 1 - torch.mean(torch.diagonal(torch.mm(a_norm, b_norm.t()), 0))

        return loss
--------------------------------------------------------------------------------
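A minimal shape-check for the two losses used in training; the tensors here are random stand-ins for the projected features and similarity scores:

```python
import torch
from loss import RankingLoss, CosineLoss

batch, num_classes, dim = 4, 14, 128
ranks = torch.randn(batch, num_classes)              # image-vs-class similarity scores
labels = torch.zeros(batch, num_classes, dtype=torch.long)
labels[:, :3] = 1                                    # every image has three positive classes
class_ids_loaded = torch.arange(num_classes)         # evaluate over all 14 classes

rank_loss = RankingLoss(neg_penalty=0.2)(ranks, labels, class_ids_loaded, 'cpu')
cos_loss = CosineLoss()(torch.randn(batch, dim), torch.randn(batch, dim))
print(rank_loss.item(), cos_loss.item())
```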
/arguments.py:
--------------------------------------------------------------------------------
import argparse

def parse_args():
    argParser = argparse.ArgumentParser(description='arguments')

    argParser.add_argument('--data-root', default='data/nih_chest_xrays', type=str, help='the path to the dataset')
    argParser.add_argument('--save-dir', default='checkpoints', type=str, help='the path to save the checkpoints')
    argParser.add_argument('--train-file', default='dataset_splits/train.txt', type=str, help='the path to the train list')
    argParser.add_argument('--val-file', default='dataset_splits/val.txt', type=str, help='the path to the val list')
    argParser.add_argument('--test-file', default='dataset_splits/test.txt', type=str, help='the path to the test list')

    argParser.add_argument('--pretrained', dest='pretrained', action='store_true', help='load an ImageNet-pretrained backbone')
    argParser.add_argument('--bce-only', dest='bce_only', help='train with only the binary cross-entropy loss', action='store_true')

    argParser.add_argument('--num-classes', default=14, type=int, help='number of classes')
    argParser.add_argument('--batch-size', default=16, type=int, help='training batch size')
    argParser.add_argument('--epochs', default=40, type=int, help='number of epochs to train')
    argParser.add_argument('--vision-backbone', default='densenet121', type=str, help='[densenet121, densenet169, densenet201]')
    argParser.add_argument('--resume-from', default=None, type=str, help='path to a checkpoint to resume training from')
    argParser.add_argument('--load-from', default=None, type=str, help='path to a checkpoint to load the weights from')

    argParser.add_argument('--resize', default=256, type=int, help='size to resize input images to')
    argParser.add_argument('--crop', default=224, type=int, help='crop size fed to the network')
    argParser.add_argument('--lr', default=0.0001, type=float, help='learning rate')
    argParser.add_argument('--steps', default='20, 40, 60, 80', type=str, help='learning rate decay steps, comma separated')

    argParser.add_argument('--beta-map', default=0.1, type=float, help='weight of the visual-semantic alignment loss')
    argParser.add_argument('--beta-con', default=0.1, type=float, help='weight of the semantic consistency loss')
    argParser.add_argument('--beta-rank', default=1, type=float, help='weight of the ranking loss')
    argParser.add_argument('--neg-penalty', default=0.03, type=float, help='margin penalty for negative classes in the ranking loss')

    argParser.add_argument('--wo-con', dest='wo_con', help='train without the semantic consistency regularizer loss', action='store_true')
    argParser.add_argument('--wo-map', dest='wo_map', help='train without the alignment loss', action='store_true')

    argParser.add_argument('--textual-embeddings', default='../text_embeddings/embeddings/nih_chest_xray_biobert.npy', type=str, help='the path to the label embeddings')
    args = argParser.parse_args()
    return args
--------------------------------------------------------------------------------
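`--steps` is consumed downstream by `ChexnetTrainer` for the optional step learning-rate schedule; a sketch of that parsing with the default value inlined:

```python
steps = '20, 40, 60, 80'                            # default of --steps
parsed = [int(step) for step in steps.split(',')]   # int() tolerates the spaces
print(parsed)  # [20, 40, 60, 80]
```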
/environment.yml:
--------------------------------------------------------------------------------
name: zsl
channels:
  - pytorch
  - conda-forge
  - anaconda
  - defaults
dependencies:
  - _libgcc_mutex=0.1
  - argon2-cffi=20.1.0
  - async_generator=1.10
  - attrs=20.2.0
  - backcall=0.2.0
  - blas=1.0
  - bleach=3.2.1
  - brotlipy=0.7.0
  - bzip2=1.0.8
  - ca-certificates=2021.1.19
  - cairo=1.14.12
  - certifi=2020.12.5
  - cffi=1.14.3
  - chardet=3.0.4
  - cloudpickle=1.6.0
  - cryptography=3.1.1
  - cudatoolkit=10.0.130
  - cycler=0.10.0
  - dbus=1.13.6
  - decorator=4.4.2
  - defusedxml=0.6.0
  - entrypoints=0.3
  - expat=2.2.9
  - ffmpeg=4.0
  - fontconfig=2.13.1
  - freeglut=3.0.0
  - freetype=2.10.4
  - glib=2.66.1
  - graphite2=1.3.14
  - gst-plugins-base=1.14.0
  - gstreamer=1.14.0
  - harfbuzz=1.8.8
  - hdf5=1.10.2
  - icu=58.2
  - idna=2.10
  - importlib-metadata=2.0.0
  - importlib_metadata=2.0.0
  - intel-openmp=2020.2
  - ipykernel=5.3.4
  - ipython=7.16.1
  - ipython_genutils=0.2.0
  - jasper=2.0.14
  - jedi=0.17.2
  - jinja2=2.11.2
  - joblib=0.17.0
  - jpeg=9b
  - jsonschema=3.2.0
  - jupyter_client=6.1.7
  - jupyter_core=4.6.3
  - jupyterlab_pygments=0.1.2
  - kiwisolver=1.3.1
  - lcms2=2.11
  - ld_impl_linux-64=2.33.1
  - libedit=3.1.20191231
  - libffi=3.3
  - libgcc-ng=9.1.0
  - libgfortran-ng=7.3.0
  - libglu=9.0.0
  - libopencv=3.4.2
  - libopus=1.3.1
  - libpng=1.6.37
  - libsodium=1.0.18
  - libstdcxx-ng=9.1.0
  - libtiff=4.1.0
  - libuuid=2.32.1
  - libvpx=1.7.0
  - libxcb=1.13
  - libxml2=2.9.10
  - lz4-c=1.9.2
  - markupsafe=1.1.1
  - matplotlib=3.3.3
  - matplotlib-base=3.3.3
  - mistune=0.8.4
  - mkl=2020.2
  - mkl-service=2.3.0
  - mkl_fft=1.2.0
  - mkl_random=1.1.1
  - nb_conda=2.2.1
  - nb_conda_kernels=2.3.0
  - nbclient=0.5.1
  - nbconvert=6.0.7
  - nbformat=5.0.8
  - ncurses=6.2
  - nest-asyncio=1.4.1
  - networkx=2.5
  - ninja=1.10.2
  - notebook=6.1.4
  - numpy=1.19.2
  - numpy-base=1.19.2
  - olefile=0.46
  - opencv=3.4.2
  - openssl=1.1.1j
  - packaging=20.4
  - pandas=1.1.3
  - pandoc=2.11
  - pandocfilters=1.4.2
  - parso=0.7.0
  - pcre=8.44
  - pexpect=4.8.0
  - pickleshare=0.7.5
  - pillow=8.0.0
  - pip=20.3.3
  - pixman=0.40.0
  - prometheus_client=0.8.0
  - prompt-toolkit=3.0.8
  - pthread-stubs=0.4
  - ptyprocess=0.6.0
  - py-opencv=3.4.2
  - pycparser=2.20
  - pygments=2.7.1
  - pyopenssl=19.1.0
  - pyparsing=2.4.7
  - pyqt=5.9.2
  - pyrsistent=0.17.3
  - pysocks=1.7.1
  - python=3.6.12
  - python-dateutil=2.8.1
  - python_abi=3.6
  - pytorch=1.4.0
  - pytz=2020.1
  - pyzmq=19.0.2
  - qt=5.9.7
  - readline=8.0
  - requests=2.24.0
  - scikit-learn=0.23.2
  - scipy=1.5.2
  - send2trash=1.5.0
  - setuptools=51.0.0
  - sip=4.19.8
  - six=1.15.0
  - sqlite=3.33.0
  - terminado=0.9.1
  - testpath=0.4.4
  - threadpoolctl=2.1.0
  - tk=8.6.10
  - torchvision=0.5.0
  - tornado=6.0.4
  - tqdm=4.55.1
  - traitlets=4.3.3
  - urllib3=1.25.11
  - wcwidth=0.2.5
  - webencodings=0.5.1
  - wheel=0.36.2
  - xorg-libxau=1.0.9
  - xorg-libxdmcp=1.1.3
  - xz=5.2.5
  - zeromq=4.3.3
  - zipp=3.3.1
  - zlib=1.2.11
  - zstd=1.4.5
--------------------------------------------------------------------------------
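The BioBERT label embeddings shipped under `embeddings/` are loaded with `np.load` by `ChexnetTrainer` and passed to `ZSLNet`. A quick sanity check; the 768 feature size is an assumption based on BioBERT-base, not something the repo states:

```python
import numpy as np

emb = np.load('embeddings/nih_chest_xray_biobert.npy')
# One row per NIH class, in the order of NIHChestXray.CLASSES below;
# for BioBERT-base the feature dimension is typically 768 (assumption).
print(emb.shape)  # e.g. (14, 768)
```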
/README.md:
--------------------------------------------------------------------------------


## Code for [MLHC 2021](https://www.mlforhc.org/accepted-papers-1) Multi-Label Generalized Zero Shot Learning for the Classification of Disease in Chest Radiographs


Table of contents
=================

* [Background](#Background)
* [Overview](#Overview)
* [Environment setup](#Environment-setup)
* [Dataset](#Dataset)
* [Model evaluation](#Model-evaluation)
* [Model training](#Model-training)
* [Citation](#Citation)
* [Results](#Results)


Background
============
Despite the success of deep neural networks in chest X-ray (CXR) diagnosis, supervised learning only allows the prediction of disease classes that were seen during training. At inference, these networks cannot predict an unseen disease class. Incorporating a new class requires collecting labeled data, which is not a trivial task, especially for less frequently occurring diseases. As a result, it becomes infeasible to build a model that can diagnose all possible disease classes. This repo contains the PyTorch implementation of our proposed network, CXR-ML-GZSL: a multi-label generalized zero-shot learning approach that can simultaneously predict multiple seen and unseen diseases in CXR images. Given an input image, CXR-ML-GZSL learns the visual representations guided by the input's corresponding semantics extracted from a rich medical text corpus.

Overview of the CXR-ML-GZSL network
====================================

The network consists of (i) a trainable visual encoder, (ii) a fixed semantic encoder, and (iii) a projection module that maps the encoded features to a joint latent space. Our approach is end-to-end trainable and does not require offline training of the visual feature encoder.
![](images/figure.png)


Environment setup
==================

    git clone https://github.com/nyuad-cai/CXR-ML-GZSL.git
    cd CXR-ML-GZSL
    conda env create -f environment.yml
    conda activate zsl

Dataset
-------------

We evaluated the proposed method on the NIH Chest X-ray dataset with a random split of 10 seen and 4 unseen classes. To train and evaluate the network, download the [NIH chest x-ray dataset](https://nihcc.app.box.com/v/ChestXray-NIHCC).


Model evaluation
------------------

- To perform evaluation only, you can simply download the pretrained network [weights](https://drive.google.com/file/d/17ioJMW3qNx1Ktmr-hXn-eqp431cm49Rm/view?usp=sharing).
- Update the paths of the data root directory and the pretrained weights, then run the following script:

      sh ./scripts/test_densenet121.sh

Model training
-----------------

- To train the network, run the following script after setting the data root directory path to the downloaded dataset:

      sh ./scripts/train_densenet121.sh

Citation
============

If you use this code for your research, please consider citing:

```
@misc{hayat2021multilabel,
      title={Multi-Label Generalized Zero Shot Learning for the Classification of Disease in Chest Radiographs},
      author={Nasir Hayat and Hazem Lashen and Farah E. Shamout},
      year={2021},
      eprint={2107.06563},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```
Results
============

- We compare the performance of our proposed approach with two baseline methods and report overall precision, recall, F1-score @ k ∈ {2, 3}, and AUROC for seen and unseen classes.

![](images/results1.png)
- Class-wise comparison with baseline methods. The last four italicized classes are unseen during training.

![](images/results2.png)


- Below we show some of the visual results for the top-3 predictions. Green, orange, and red represent true positives, false negatives, and false positives, respectively. Note that our method predicts the unseen classes among the top 3 even when the number of ground-truth labels is greater than 3.

![](images/q_results.png)
--------------------------------------------------------------------------------
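For reference, the 10 seen / 4 unseen split mentioned in the README is hard-coded in `NIHChestXray` (next file); copied here verbatim:

```python
# Split used by NIHChestXray in dataset.py (copied for reference).
seen_classes = ['Atelectasis', 'Effusion', 'Infiltration', 'Mass', 'Nodule',
                'Pneumothorax', 'Consolidation', 'Cardiomegaly',
                'Pleural_Thickening', 'Hernia']
unseen_classes = ['Edema', 'Pneumonia', 'Emphysema', 'Fibrosis']
```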
/dataset.py:
--------------------------------------------------------------------------------
import glob

import numpy as np
import pandas as pd
from PIL import Image

import torch
from torch.utils.data import Dataset


class NIHChestXray(Dataset):

    def __init__(self, args, pathDatasetFile, transform, classes_to_load='seen', exclude_all=True):
        self.listImagePaths = []
        self.listImageLabels = []
        self.transform = transform
        self.num_classes = args.num_classes

        self._data_path = args.data_root
        self.args = args

        self.split_path = pathDatasetFile
        self.CLASSES = ['Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass', 'Nodule', 'Pneumonia',
                        'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia']

        # Random split of 10 seen and 4 unseen classes.
        self.unseen_classes = ['Edema', 'Pneumonia', 'Emphysema', 'Fibrosis']

        self.seen_classes = ['Atelectasis', 'Effusion', 'Infiltration', 'Mass', 'Nodule',
                             'Pneumothorax', 'Consolidation', 'Cardiomegaly', 'Pleural_Thickening', 'Hernia']

        self._class_ids = {v: i for i, v in enumerate(self.CLASSES) if v != 'No Finding'}

        self.seen_class_ids = [self._class_ids[label] for label in self.seen_classes]
        self.unseen_class_ids = [self._class_ids[label] for label in self.unseen_classes]

        self.classes_to_load = classes_to_load
        self.exclude_all = exclude_all
        self._construct_index()

    def _construct_index(self):
        # Compile the split data path.
        max_labels = 0
        paths = glob.glob(f'{self._data_path}/**/images/*.png')
        self.names_to_path = {path.split('/')[-1]: path for path in paths}
        data_entry_file = 'Data_Entry_2017.csv'
        print(f'data partition path: {self.split_path}')
        with open(self.split_path, 'r') as f:
            file_names = f.readlines()

        split_file_names = np.array([file_name.strip().split(' ')[0].split('/')[-1] for file_name in file_names])
        df = pd.read_csv(f'{self._data_path}/{data_entry_file}')
        image_index = df.iloc[:, 0].values

        _, split_index, _ = np.intersect1d(image_index, split_file_names, return_indices=True)

        labels = df.iloc[:, 1].values
        labels = np.array(labels)[split_index]
        labels = [label.split('|') for label in labels]

        image_index = image_index[split_index]

        # Construct the image db, skipping 'No Finding'-only images and images
        # whose label set does not match the requested class subset.
        self._imdb = []
        self.class_ids_loaded = []
        for index in range(len(split_index)):
            if len(labels[index]) == 1 and labels[index][0] == 'No Finding':
                continue
            if self._should_load_image(labels[index]) is False:
                continue
            class_ids = [self._class_ids[label] for label in labels[index]]
            self.class_ids_loaded += class_ids
            self._imdb.append({
                'im_path': self.names_to_path[image_index[index]],
                'labels': class_ids,
            })
            max_labels = max(max_labels, len(class_ids))

        self.class_ids_loaded = np.unique(np.array(self.class_ids_loaded))
        print(f'Number of images: {len(self._imdb)}')
        print(f'Number of max labels per image: {max_labels}')
        print(f'Number of classes: {len(self.class_ids_loaded)}')

    def _should_load_image(self, labels):
        selected_class_labels = self.CLASSES
        if self.classes_to_load == 'seen':
            selected_class_labels = self.seen_classes
        elif self.classes_to_load == 'unseen':
            selected_class_labels = self.unseen_classes
        elif self.classes_to_load == 'all':
            return True

        count = 0
        for label in labels:
            if label in selected_class_labels:
                count += 1

        if count == len(labels):
            # All labels are from the selected subset.
            return True
        elif count == 0:
            # No label is from the selected subset.
            return False
        else:
            # Some labels are from the selected subset.
            if self.exclude_all is True:
                return False
            else:
                return True

    def __getitem__(self, index):
        imagePath = self._imdb[index]['im_path']
        imageData = Image.open(imagePath).convert('RGB')

        labels = torch.tensor(self._imdb[index]['labels'])
        labels = labels.unsqueeze(0)
        # Scatter the class-id list into a multi-hot target vector.
        imageLabel = torch.zeros(labels.size(0), self.num_classes).scatter_(1, labels, 1.).squeeze()

        img = self.transform(imageData)
        return img, imageLabel

    def __len__(self):
        return len(self._imdb)
--------------------------------------------------------------------------------
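`NIHChestXray.__getitem__` turns the per-image class-id list into a multi-hot vector via `scatter_`; a minimal sketch of that encoding with a made-up label set:

```python
import torch

num_classes = 14
labels = torch.tensor([0, 2, 8]).unsqueeze(0)  # e.g. Atelectasis, Effusion, Consolidation
multi_hot = torch.zeros(labels.size(0), num_classes).scatter_(1, labels, 1.).squeeze()
print(multi_hot)  # 1.0 at indices 0, 2 and 8, zeros elsewhere
```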
/zsl_models.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torchvision

from loss import RankingLoss, CosineLoss


class ZSLNet(nn.Module):

    def __init__(self, args, textual_embeddings=None, device='cpu'):
        super(ZSLNet, self).__init__()
        self.args = args
        self.device = device
        self.vision_backbone = getattr(torchvision.models, self.args.vision_backbone)(pretrained=self.args.pretrained)

        # Remove the classification layer from the visual encoder.
        classifiers = ['classifier', 'fc']
        for classifier in classifiers:
            cls_layer = getattr(self.vision_backbone, classifier, None)
            if cls_layer is None:
                continue
            d_visual = cls_layer.in_features
            setattr(self.vision_backbone, classifier, nn.Identity(d_visual))
            break

        # Optional: start from an offline BCE-pretrained encoder and freeze it.
        pretrained_encoder = False
        if pretrained_encoder:
            self.vision_backbone.classifier = nn.Identity(d_visual)

            path = 'checkpoints/bce_only_imagenet/last_epoch_checkpoint.pth.tar'

            self.classifier = nn.Sequential(nn.Linear(d_visual, self.args.num_classes), nn.Sigmoid())
            checkpoint = torch.load(path)
            self.load_state_dict(checkpoint['state_dict'])

            for p in self.vision_backbone.parameters():
                p.requires_grad = False

        if self.args.bce_only:
            self.bce_loss = torch.nn.BCELoss(size_average=True)
            self.classifier = nn.Sequential(nn.Linear(d_visual, self.args.num_classes), nn.Sigmoid())
        else:
            self.emb_loss = CosineLoss()
            self.ranking_loss = RankingLoss(neg_penalty=self.args.neg_penalty)
            self.textual_embeddings = textual_embeddings
            d_textual = self.textual_embeddings.shape[-1]

            self.textual_embeddings = torch.from_numpy(self.textual_embeddings).to(self.device)

            # Projection heads mapping both modalities to a 128-d joint space.
            self.fc_v = nn.Sequential(
                nn.Linear(d_visual, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, 128),
            )

            self.fc_t = nn.Sequential(
                nn.Linear(d_textual, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, 128)
            )

    def forward(self, x, labels=None, epoch=0, n_crops=0, bs=16):
        if self.args.bce_only:
            return self.forward_bce_only(x, labels=labels, n_crops=n_crops, bs=bs)
        else:
            return self.forward_ranking(x, labels=labels, epoch=epoch, n_crops=n_crops, bs=bs)

    def forward_bce_only(self, x, labels=None, n_crops=0, bs=16):
        lossvalue_bce = torch.zeros(1).to(self.device)

        visual_feats = self.vision_backbone(x)
        preds = self.classifier(visual_feats)

        if labels is not None:
            lossvalue_bce = self.bce_loss(preds, labels)

        return preds, lossvalue_bce, f'bce:\t {lossvalue_bce.item():0.4f}'

    def forward_ranking(self, x, labels=None, epoch=0, n_crops=0, bs=16):
        loss_rank = torch.zeros(1).to(self.device)
        loss_alignment_cos = torch.zeros(1).to(self.device)
        loss_mapping_consistency = torch.zeros(1).to(self.device)

        visual_feats = self.vision_backbone(x)
        visual_feats = self.fc_v(visual_feats)
        text_feats = self.fc_t(self.textual_embeddings)

        # Semantic consistency: keep the pairwise similarities of the projected
        # text embeddings close to those of the original embeddings.
        if not self.args.wo_con and epoch >= 0:
            text_mapped_sim = self.sim_score(text_feats, text_feats.detach())
            text_orig_sim = self.sim_score(self.textual_embeddings, self.textual_embeddings)
            loss_mapping_consistency = torch.abs(text_orig_sim - text_mapped_sim).mean()

        if labels is not None:
            mapped_visual, mapped_text = self.map_visual_text(visual_feats, labels, text_feats)
            if mapped_visual is not None and not self.args.wo_map and epoch >= 0:
                loss_alignment_cos = self.emb_loss(mapped_text, mapped_visual)

        ranks = self.sim_score(visual_feats, text_feats)
        if n_crops > 0:
            # Average the scores of the ten crops back to one row per image.
            ranks = ranks.view(bs, n_crops, -1).mean(1)

        if labels is not None:
            loss_rank = self.ranking_loss(ranks, labels, self.class_ids_loaded, self.device)
        loss_alignment_cos = (self.args.beta_map * loss_alignment_cos)
        loss_rank = (self.args.beta_rank * loss_rank)
        loss_mapping_consistency = (self.args.beta_con * loss_mapping_consistency)
        # NOTE: as released, the alignment term is computed on detached
        # features and scaled by 0.0 below, so it does not contribute here.
        losses = loss_rank + loss_mapping_consistency + 0.0 * loss_alignment_cos
        return ranks, losses

    def sim_score(self, a, b):
        # Cosine similarity between every row of a and every row of b.
        a_norm = a / a.norm(dim=1)[:, None]
        b_norm = b / (1e-6 + b.norm(dim=1))[:, None]
        score = torch.mm(a_norm, b_norm.t())
        return score

    def map_visual_text(self, visual_feats, labels, labels_embd):
        # For each image, average the text embeddings of its positive classes.
        mapped_labels_embd = []
        for i in range(0, labels.shape[0]):
            class_embd = labels_embd[labels[i] == 1].mean(dim=0)[None, :]
            mapped_labels_embd.append(class_embd)
        mapped_labels_embd = torch.cat(mapped_labels_embd)

        return visual_feats.detach(), mapped_labels_embd.detach()
--------------------------------------------------------------------------------
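`ZSLNet.sim_score` is a row-normalized dot product, i.e. a cosine-similarity matrix between projected visual and textual features; a quick standalone check on random inputs:

```python
import torch

def sim_score(a, b):
    # Mirrors ZSLNet.sim_score: cosine similarity between all row pairs.
    a_norm = a / a.norm(dim=1)[:, None]
    b_norm = b / (1e-6 + b.norm(dim=1))[:, None]
    return torch.mm(a_norm, b_norm.t())

visual = torch.randn(4, 128)   # projected visual features (batch of 4)
text = torch.randn(14, 128)    # projected text features, one row per class
print(sim_score(visual, text).shape)  # torch.Size([4, 14]), values in [-1, 1]
```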
/ChexnetTrainer.py:
--------------------------------------------------------------------------------
import time
from datetime import datetime, timedelta

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from tqdm import tqdm

from sklearn.metrics import roc_auc_score

from dataset import NIHChestXray
from plots import plot_array
from zsl_models import ZSLNet
# --------------------------------------------------------------------------------


class ChexnetTrainer(object):
    def __init__(self, args):
        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.textual_embeddings = np.load(args.textual_embeddings)

        self.model = ZSLNet(self.args, self.textual_embeddings, self.device).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
        # self.optimizer = optim.SGD(self.model.parameters(), lr=self.args.lr, momentum=0.9, weight_decay=0.0001)
        # self.scheduler = self.step_lr
        self.scheduler = ReduceLROnPlateau(self.optimizer, factor=0.1, patience=5, mode='min')

        self.loss = torch.nn.BCELoss(size_average=True)
        self.auroc_min_loss = 0.0

        self.start_epoch = 1
        self.lossMIN = float('inf')
        self.max_auroc_mean = float('-inf')
        self.best_epoch = 1

        self.val_losses = []
        print(self.model)
        print(self.optimizer)
        print(self.scheduler)
        print(self.loss)
        print(f'\n\nloaded ImageNet weights {self.args.pretrained}\n\n\n')
        self.resume_from()
        self.load_from()
        self.init_dataset()
        self.steps = [int(step) for step in self.args.steps.split(',')]
        self.time_start = time.time()
        self.time_end = time.time()
        self.should_test = False
        # The ranking loss is only evaluated over the classes present in the
        # training split, i.e. the seen classes.
        self.model.class_ids_loaded = self.train_dl.dataset.class_ids_loaded

    def __call__(self):
        self.train()

    def step_lr(self, epoch):
        # Optional step decay schedule (unused by default; see --steps).
        step = self.steps[0]
        for index, s in enumerate(self.steps):
            if epoch < s:
                break
            else:
                step = s

        lr = self.args.lr * (0.1 ** (epoch // step))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def load_from(self):
        if self.args.load_from is not None:
            checkpoint = torch.load(self.args.load_from)
            self.model.load_state_dict(checkpoint['state_dict'])
            print(f'loaded checkpoint from {self.args.load_from}')

    def resume_from(self):
        if self.args.resume_from is not None:
            checkpoint = torch.load(self.args.resume_from)
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.start_epoch = checkpoint['epoch'] + 1
            self.lossMIN = checkpoint['lossMIN']
            self.max_auroc_mean = checkpoint['max_auroc_mean']
            print(f'resuming training from epoch {self.start_epoch}')

    def save_checkpoint(self, prefix='best'):
        path = f'{self.args.save_dir}/{prefix}_checkpoint.pth.tar'
        torch.save(
            {
                'epoch': self.epoch,
                'state_dict': self.model.state_dict(),
                'max_auroc_mean': self.max_auroc_mean,
                'optimizer': self.optimizer.state_dict(),
                'lossMIN': self.lossMIN
            }, path)
        print(f"saving {prefix} checkpoint")

    def init_dataset(self):
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        train_transforms = []
        train_transforms.append(transforms.RandomResizedCrop(self.args.crop))
        train_transforms.append(transforms.RandomHorizontalFlip())
        train_transforms.append(transforms.ToTensor())
        train_transforms.append(normalize)

        datasetTrain = NIHChestXray(self.args, self.args.train_file, transform=transforms.Compose(train_transforms))
        self.train_dl = DataLoader(dataset=datasetTrain, batch_size=self.args.batch_size, shuffle=True, num_workers=4, pin_memory=True)

        test_transforms = []
        test_transforms.append(transforms.Resize(self.args.resize))
        test_transforms.append(transforms.TenCrop(self.args.crop))
        test_transforms.append(transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])))
        test_transforms.append(transforms.Lambda(lambda crops: torch.stack([normalize(crop) for crop in crops])))

        datasetVal = NIHChestXray(self.args, self.args.val_file, transform=transforms.Compose(test_transforms))
        self.val_dl = DataLoader(dataset=datasetVal, batch_size=self.args.batch_size*10, shuffle=False, num_workers=4, pin_memory=True)

        datasetTest = NIHChestXray(self.args, self.args.test_file, transform=transforms.Compose(test_transforms), classes_to_load='all')
        self.test_dl = DataLoader(dataset=datasetTest, batch_size=self.args.batch_size*3, num_workers=8, shuffle=False, pin_memory=True)
        print(datasetTest.CLASSES)
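    # Note on evaluation shapes: the val/test pipelines use TenCrop, so each
    # batch arrives as (bs, n_crops, c, h, w). epochVal()/test() below flatten
    # the crops with inputs.view(-1, c, h, w) before the forward pass, and
    # ZSLNet.forward_ranking averages the per-crop ranks back to a single
    # (bs, num_classes) score matrix when n_crops > 0.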
    def train(self):
        for self.epoch in range(self.start_epoch, self.args.epochs):
            self.epochTrain()
            lossVal, val_ind_auroc = self.epochVal()
            val_ind_auroc = np.array(val_ind_auroc)

            aurocMean = val_ind_auroc.mean()
            self.save_checkpoint(prefix='last_epoch')
            self.should_test = False

            if aurocMean > self.max_auroc_mean:
                self.max_auroc_mean = aurocMean
                self.save_checkpoint(prefix='best_auroc')
                self.best_epoch = self.epoch
                self.should_test = True
            if lossVal < self.lossMIN:
                self.lossMIN = lossVal
                self.auroc_min_loss = aurocMean
                self.save_checkpoint(prefix='min_loss')
                self.should_test = True

            self.print_auroc(val_ind_auroc, self.val_dl.dataset.class_ids_loaded, prefix='val')
            if self.should_test is True:
                test_ind_auroc = np.array(self.test())

                self.write_results(val_ind_auroc, self.val_dl.dataset.class_ids_loaded, prefix=f'\n\nepoch {self.epoch}\nval', mode='a')
                self.write_results(test_ind_auroc[self.test_dl.dataset.seen_class_ids], self.test_dl.dataset.seen_class_ids, prefix='\ntest_seen', mode='a')
                self.write_results(test_ind_auroc[self.test_dl.dataset.unseen_class_ids], self.test_dl.dataset.unseen_class_ids, prefix='\ntest_unseen', mode='a')

                self.print_auroc(test_ind_auroc[self.test_dl.dataset.seen_class_ids], self.test_dl.dataset.seen_class_ids, prefix='\ntest_seen')
                self.print_auroc(test_ind_auroc[self.test_dl.dataset.unseen_class_ids], self.test_dl.dataset.unseen_class_ids, prefix='\ntest_unseen')

            plot_array(self.val_losses, f'{self.args.save_dir}/val_loss')
            print(f'best epoch {self.best_epoch} best auroc {self.max_auroc_mean} loss {lossVal:.6f} auroc at min loss {self.auroc_min_loss:0.4f}')

            self.scheduler.step(lossVal)

    # --------------------------------------------------------------------------------
    def get_eta(self, epoch, iter):
        # Estimate the remaining training time from the duration of the last batch.
        self.time_end = time.time()
        delta = self.time_end - self.time_start
        delta = delta * (len(self.train_dl) * (self.args.epochs - epoch) - iter)
        sec = timedelta(seconds=int(delta))
        d = (datetime(1, 1, 1) + sec)
        eta = f"{d.day-1} Days {d.hour}:{d.minute}:{d.second}"
        self.time_start = time.time()

        return eta

    def epochTrain(self):
        self.model.train()
        epoch_loss = 0
        for batchID, (inputs, target) in enumerate(self.train_dl):
            target = target.to(self.device)
            inputs = inputs.to(self.device)
            output, loss = self.model(inputs, target, self.epoch)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            eta = self.get_eta(self.epoch, batchID)
            epoch_loss += loss.item()
            if batchID % 10 == 9:
                print(f" epoch [{self.epoch:04d} / {self.args.epochs:04d}] eta: {eta:<20} [{batchID:04}/{len(self.train_dl)}] lr: \t{self.optimizer.param_groups[0]['lr']:0.4E} loss: \t{epoch_loss/(batchID+1):0.5f}")

    # --------------------------------------------------------------------------------
    def epochVal(self):
        self.model.eval()

        lossVal = 0

        outGT = torch.FloatTensor().to(self.device)
        outPRED = torch.FloatTensor().to(self.device)
        for i, (inputs, target) in enumerate(tqdm(self.val_dl)):
            with torch.no_grad():
                target = target.to(self.device)
                bs, n_crops, c, h, w = inputs.size()
                # Flatten the ten crops into the batch dimension.
                varInput = inputs.view(-1, c, h, w).to(self.device)

                varOutput, losstensor = self.model(varInput, target, n_crops=n_crops, bs=bs)

                outPRED = torch.cat((outPRED, varOutput), 0)
                outGT = torch.cat((outGT, target), 0)

                lossVal += losstensor.item()
                del varOutput, varInput, target, inputs
        lossVal = lossVal / len(self.val_dl)

        aurocIndividual = self.computeAUROC(outGT, outPRED, self.val_dl.dataset.class_ids_loaded)
        self.val_losses.append(lossVal)

        return lossVal, aurocIndividual

    def test(self):
        cudnn.benchmark = True
        outGT = torch.FloatTensor().to(self.device)
        outPRED = torch.FloatTensor().to(self.device)

        self.model.eval()

        for i, (inputs, target) in enumerate(tqdm(self.test_dl)):
            with torch.no_grad():
                target = target.to(self.device)
                outGT = torch.cat((outGT, target), 0)

                bs, n_crops, c, h, w = inputs.size()
                varInput = inputs.view(-1, c, h, w).to(self.device)

                out, _ = self.model(varInput, n_crops=n_crops, bs=bs)
                outPRED = torch.cat((outPRED, out.data), 0)

        aurocIndividual = self.computeAUROC(outGT, outPRED, self.test_dl.dataset.class_ids_loaded)

        return aurocIndividual
    def computeAUROC(self, dataGT, dataPRED, class_ids):
        # Per-class AUROC over the requested class ids.
        outAUROC = []
        datanpGT = dataGT.cpu().numpy()
        datanpPRED = dataPRED.cpu().numpy()

        for i in class_ids:
            outAUROC.append(roc_auc_score(datanpGT[:, i], datanpPRED[:, i]))
        return outAUROC

    def write_results(self, aurocIndividual, class_ids, prefix='val', mode='a'):
        with open(f"{self.args.save_dir}/results.txt", mode) as results_file:
            aurocMean = aurocIndividual.mean()
            results_file.write(f'{prefix} AUROC mean {aurocMean:0.4f}\n')
            for i, class_id in enumerate(class_ids):
                results_file.write(f'{self.val_dl.dataset.CLASSES[class_id]} {aurocIndividual[i]:0.4f}\n')

    def print_auroc(self, aurocIndividual, class_ids, prefix='val'):
        aurocMean = aurocIndividual.mean()
        print(f'{prefix} AUROC mean {aurocMean:0.4f}')

        for i, class_id in enumerate(class_ids):
            print(f'{self.val_dl.dataset.CLASSES[class_id]} {aurocIndividual[i]:0.4f}')
--------------------------------------------------------------------------------
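`computeAUROC` above is a thin wrapper over scikit-learn's `roc_auc_score`, applied column by column; a self-contained sketch with dummy predictions:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(0)
gt = rng.integers(0, 2, size=(100, 14))   # dummy multi-hot ground truth
pred = rng.random((100, 14))              # dummy per-class scores
per_class = [roc_auc_score(gt[:, i], pred[:, i]) for i in range(14)]
print(np.mean(per_class))                 # mean AUROC over all 14 classes
```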