├── data └── .gitignore ├── checkpoints └── .gitignore ├── predictions └── .gitignore ├── bpe_simple_vocab_16e6.txt.gz ├── .gitignore ├── requirements.txt ├── LICENSE ├── NOTICE.md ├── run_preprocess.py ├── data_process.py ├── run_train.py ├── simple_tokenizer.py ├── eval.py ├── clip.py ├── preprocess_padchest.py ├── train.py ├── README.md ├── metrics.py ├── model.py ├── zero_shot.py └── notebooks └── zero_shot.ipynb /data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | */ 3 | !.gitignore 4 | -------------------------------------------------------------------------------- /checkpoints/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | */ 3 | !.gitignore 4 | -------------------------------------------------------------------------------- /predictions/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | */ 3 | !.gitignore 4 | -------------------------------------------------------------------------------- /bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajpurkarlab/CheXzero/HEAD/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | model.pt 2 | notebooks/.ipynb_checkpoints 3 | .ipynb_checkpoints 4 | __pycache__ 5 | notebooks/clip_v1_0.1_state_dict.pt 6 | notebooks/models 7 | notebooks/train/wandb 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations==1.1.0 2 | argparse==1.4.0 3 | ftfy==6.1.1 4 | grpcio==1.46.1 5 | h5py==3.1.0 6 | huggingface-hub==0.6.0 7 | imageio==2.19.1 8 | joblib==1.0.1 9 | matplotlib==3.3.4 10 | numpy==1.19.5 11 | opencv-python==4.5.3.56 12 | opencv-python-headless==4.1.2.30 13 | pandas==1.2.1 14 | pathlib==1.0.1 15 | plotly==5.9.0 16 | psutil==5.8.0 17 | python-dateutil==2.8.1 18 | regex==2020.11.13 19 | scikit-image==0.19.2 20 | scikit-learn==0.24.1 21 | scipy==1.6.1 22 | sklearn==0.0 23 | tifffile==2022.5.4 24 | tokenizers==0.12.1 25 | torch==1.10.2 26 | torchaudio==0.10.2 27 | torchvision==0.11.3 28 | transformers==4.19.0 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Rajpurkar Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /NOTICE.md: -------------------------------------------------------------------------------- 1 | # Notices for CheXzero 2 | This software incorporates material from third parties. 3 | 4 | ## Project Licenses 5 | The source code of this repository was derived from CLIP developed by OpenAI (https://github.com/openai/CLIP). This work uses and modifies code that defines the CLIP model architecture, preprocesses unstructured text, and runs inference. 6 | 7 | ### Open Source License / Copyright Notice 8 | ``` 9 | MIT License 10 | 11 | Copyright (c) 2021 OpenAI 12 | 13 | Permission is hereby granted, free of charge, to any person obtaining a copy 14 | of this software and associated documentation files (the "Software"), to deal 15 | in the Software without restriction, including without limitation the rights 16 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 17 | copies of the Software, and to permit persons to whom the Software is 18 | furnished to do so, subject to the following conditions: 19 | 20 | The above copyright notice and this permission notice shall be included in all 21 | copies or substantial portions of the Software. 22 | 23 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 24 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 25 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 26 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 27 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 28 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 29 | SOFTWARE. 30 | ``` 31 | -------------------------------------------------------------------------------- /run_preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from data_process import get_cxr_paths_list, img_to_hdf5, get_cxr_path_csv, write_report_csv 4 | 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--csv_out_path', type=str, default='data/cxr_paths.csv', help="Directory to save paths to all chest x-ray images in dataset.") 9 | parser.add_argument('--cxr_out_path', type=str, default='data/cxr.h5', help="Directory to save processed chest x-ray image data.") 10 | parser.add_argument('--dataset_type', type=str, default='mimic', choices=['mimic', 'chexpert-test'], help="Type of dataset to pre-process") 11 | parser.add_argument('--mimic_impressions_path', default='data/mimic_impressions.csv', help="Directory to save extracted impressions from radiology reports.") 12 | parser.add_argument('--chest_x_ray_path', default='/deep/group/data/mimic-cxr/mimic-cxr-jpg/2.0.0/files', help="Directory where chest x-ray image data is stored. This should point to the files folder from the MIMIC chest x-ray dataset.") 13 | parser.add_argument('--radiology_reports_path', default='/deep/group/data/med-data/files/', help="Directory radiology reports are stored. 
This should point to the files folder from the MIMIC radiology reports dataset.") 14 | args = parser.parse_args() 15 | return args 16 | 17 | if __name__ == "__main__": 18 | args = parse_args() 19 | if args.dataset_type == "mimic": 20 | # Write Chest X-ray Image HDF5 File 21 | get_cxr_path_csv(args.csv_out_path, args.chest_x_ray_path) 22 | cxr_paths = get_cxr_paths_list(args.csv_out_path) 23 | img_to_hdf5(cxr_paths, args.cxr_out_path) 24 | 25 | #Write CSV File Containing Impressions for each Chest X-ray 26 | write_report_csv(cxr_paths, args.radiology_reports_path, args.mimic_impressions_path) 27 | elif args.dataset_type == "chexpert-test": 28 | # Get all test paths based on cxr dir 29 | cxr_dir = Path(args.chest_x_ray_path) 30 | cxr_paths = list(cxr_dir.rglob("*.jpg")) 31 | cxr_paths = list(filter(lambda x: "view1" in str(x), cxr_paths)) # filter only first frontal views 32 | cxr_paths = sorted(cxr_paths) # sort to align with groundtruth 33 | assert(len(cxr_paths) == 500) 34 | 35 | img_to_hdf5(cxr_paths, args.cxr_out_path) 36 | 37 | 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /data_process.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import glob 4 | import numpy as np 5 | import pandas as pd 6 | import csv 7 | import matplotlib.pyplot as plt 8 | from tqdm import tqdm 9 | 10 | from PIL import Image 11 | import h5py 12 | import cv2 13 | from typing import * 14 | from pathlib import Path 15 | 16 | import torch 17 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 18 | 19 | def load_data(filepath): 20 | dataframe = pd.read_csv(filepath) 21 | return dataframe 22 | 23 | def get_cxr_paths_list(filepath): 24 | dataframe = load_data(filepath) 25 | cxr_paths = dataframe['Path'] 26 | return cxr_paths 27 | 28 | ''' 29 | This function resizes and zero pads image 30 | ''' 31 | def preprocess(img, desired_size=320): 32 | old_size = img.size 33 | ratio = float(desired_size)/max(old_size) 34 | new_size = tuple([int(x*ratio) for x in old_size]) 35 | img = img.resize(new_size, Image.ANTIALIAS) 36 | # create a new image and paste the resized on it 37 | 38 | new_img = Image.new('L', (desired_size, desired_size)) 39 | new_img.paste(img, ((desired_size-new_size[0])//2, 40 | (desired_size-new_size[1])//2)) 41 | return new_img 42 | 43 | def img_to_hdf5(cxr_paths: List[Union[str, Path]], out_filepath: str, resolution=320): 44 | """ 45 | Convert directory of images into a .h5 file given paths to all 46 | images. 
47 | """ 48 | dset_size = len(cxr_paths) 49 | failed_images = [] 50 | with h5py.File(out_filepath,'w') as h5f: 51 | img_dset = h5f.create_dataset('cxr', shape=(dset_size, resolution, resolution)) 52 | for idx, path in enumerate(tqdm(cxr_paths)): 53 | try: 54 | # read image using cv2 55 | img = cv2.imread(str(path)) 56 | # convert to PIL Image object 57 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 58 | img_pil = Image.fromarray(img) 59 | # preprocess 60 | img = preprocess(img_pil, desired_size=resolution) 61 | img_dset[idx] = img 62 | except Exception as e: 63 | failed_images.append((path, e)) 64 | print(f"{len(failed_images)} / {len(cxr_paths)} images failed to be added to h5.", failed_images) 65 | 66 | def get_files(directory): 67 | files = [] 68 | for (dirpath, dirnames, filenames) in os.walk(directory): 69 | for file in filenames: 70 | if file.endswith(".jpg"): 71 | files.append(os.path.join(dirpath, file)) 72 | return files 73 | 74 | def get_cxr_path_csv(out_filepath, directory): 75 | files = get_files(directory) 76 | file_dict = {"Path": files} 77 | df = pd.DataFrame(file_dict) 78 | df.to_csv(out_filepath, index=False) 79 | 80 | def section_start(lines, section=' IMPRESSION'): 81 | for idx, line in enumerate(lines): 82 | if line.startswith(section): 83 | return idx 84 | return -1 85 | 86 | def section_end(lines, section_start): 87 | num_lines = len(lines) 88 | 89 | def getIndexOfLast(l, element): 90 | """ Get index of last occurence of element 91 | @param l (list): list of elements 92 | @param element (string): element to search for 93 | @returns (int): index of last occurrence of element 94 | """ 95 | i = max(loc for loc, val in enumerate(l) if val == element) 96 | return i 97 | 98 | def write_report_csv(cxr_paths, txt_folder, out_path): 99 | imps = {"filename": [], "impression": []} 100 | txt_reports = [] 101 | for cxr_path in cxr_paths: 102 | tokens = cxr_path.split('/') 103 | study_num = tokens[-2] 104 | patient_num = tokens[-3] 105 | patient_group = tokens[-4] 106 | txt_report = txt_folder + patient_group + '/' + patient_num + '/' + study_num + '.txt' 107 | filename = study_num + '.txt' 108 | f = open(txt_report, 'r') 109 | s = f.read() 110 | s_split = s.split() 111 | if "IMPRESSION:" in s_split: 112 | begin = getIndexOfLast(s_split, "IMPRESSION:") + 1 113 | end = None 114 | end_cand1 = None 115 | end_cand2 = None 116 | # remove recommendation(s) and notification 117 | if "RECOMMENDATION(S):" in s_split: 118 | end_cand1 = s_split.index("RECOMMENDATION(S):") 119 | elif "RECOMMENDATION:" in s_split: 120 | end_cand1 = s_split.index("RECOMMENDATION:") 121 | elif "RECOMMENDATIONS:" in s_split: 122 | end_cand1 = s_split.index("RECOMMENDATIONS:") 123 | 124 | if "NOTIFICATION:" in s_split: 125 | end_cand2 = s_split.index("NOTIFICATION:") 126 | elif "NOTIFICATIONS:" in s_split: 127 | end_cand2 = s_split.index("NOTIFICATIONS:") 128 | 129 | if end_cand1 and end_cand2: 130 | end = min(end_cand1, end_cand2) 131 | elif end_cand1: 132 | end = end_cand1 133 | elif end_cand2: 134 | end = end_cand2 135 | 136 | if end == None: 137 | imp = " ".join(s_split[begin:]) 138 | else: 139 | imp = " ".join(s_split[begin:end]) 140 | else: 141 | imp = 'NO IMPRESSION' 142 | 143 | imps["impression"].append(imp) 144 | imps["filename"].append(filename) 145 | 146 | df = pd.DataFrame(data=imps) 147 | df.to_csv(out_path, index=False) 148 | 149 | -------------------------------------------------------------------------------- /run_train.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import pprint 3 | import argparse 4 | from tqdm import tqdm 5 | 6 | import torch 7 | from torch.utils import data 8 | from torch import nn 9 | import torch.optim as optim 10 | from torchvision.transforms import Compose, Normalize, Resize 11 | 12 | import clip 13 | from model import CLIP 14 | from simple_tokenizer import SimpleTokenizer 15 | 16 | from train import train_main, load_data, load_clip, preprocess_text 17 | from zero_shot import run_cxr_zero_shot, run_zero_shot 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--cxr_filepath', type=str, default='data/cxr.h5', help="Directory to load chest x-ray image data from.") 22 | parser.add_argument('--txt_filepath', type=str, default='data/mimic_impressions.csv', help="Directory to load radiology report impressions text from.") 23 | parser.add_argument('--batch_size', type=int, default=16) 24 | parser.add_argument('--epochs', type=int, default=4) 25 | parser.add_argument('--lr', type=float, default=1e-4) 26 | parser.add_argument('--save_interval', type=int, default=100) 27 | parser.add_argument('--log_interval', type=int, default=10) 28 | parser.add_argument('--save_dir', type=str, default="checkpoints/", help="Directory to save the trained model.") 29 | parser.add_argument('--seed', type=int, default=1234) 30 | parser.add_argument('--optimizer', type=str, default="sgd") 31 | parser.add_argument('--momentum', type=float, default=0.9) 32 | parser.add_argument('--context_length', type=int, default=77) 33 | parser.add_argument('--random_init', action='store_true') 34 | parser.add_argument('--model_name', type=str, default="pt-imp") 35 | args = parser.parse_args() 36 | return args 37 | 38 | def model_pipeline(config, verbose=0): 39 | # make the model, data, and optimization problem 40 | model, data_loader, device, criterion, optimizer = make(config) 41 | 42 | # and use them to train the model 43 | train(model, data_loader, device, criterion, optimizer, config) 44 | 45 | # save model 46 | model_path = os.path.join(config.save_dir, str(config.model_name), 'checkpoint.pt') 47 | save(model, model_path) 48 | 49 | if verbose: 50 | print(model) 51 | return model 52 | 53 | def make(config): 54 | pretrained = not config.random_init 55 | data_loader, device = load_data(config.cxr_filepath, config.txt_filepath, batch_size=config.batch_size, pretrained=pretrained, column="impression") 56 | model = load_clip(model_path=None, pretrained=pretrained, context_length=config.context_length) 57 | model.to(device) 58 | print('Model on Device.') 59 | 60 | # make the optimizer 61 | criterion = nn.CrossEntropyLoss().cuda() 62 | if config.optimizer == "adam": 63 | optimizer = optim.AdamW(model.parameters(), lr=config.lr) 64 | elif config.optimizer == "sgd": 65 | optimizer = optim.SGD(model.parameters(), lr=config.lr, momentum=config.momentum) 66 | return model, data_loader, device, criterion, optimizer 67 | 68 | def train(model, loader, device, criterion, optimizer, config): 69 | model_save_dir = os.path.join(config.save_dir, config.model_name) 70 | if not os.path.exists(model_save_dir): 71 | # Create a new folder if not exists 72 | os.makedirs(model_save_dir) 73 | 74 | # Run training 75 | total_batches = len(loader) * config.epochs 76 | example_ct = 0 # number of examples seen 77 | batch_ct = 0 78 | report_freq = config.log_interval 79 | highest_val_auc = 0 # save highest mean auc 80 | 81 | for epoch in range(config.epochs): 82 
| running_loss = 0.0 # running loss over batch 83 | for data in tqdm(loader): 84 | # get the images 85 | images = data['img'] 86 | 87 | texts = data['txt'] 88 | texts = preprocess_text(texts, model) 89 | 90 | # perform step for a single batch 91 | loss = train_batch(images, texts, model, device, criterion, optimizer) 92 | example_ct += len(images) 93 | batch_ct += 1 94 | running_loss += loss.item() 95 | 96 | # Report metrics every `report_freq` batch 97 | if (batch_ct % report_freq) == 0: 98 | train_log(running_loss / report_freq, example_ct, epoch) 99 | running_loss = 0.0 100 | 101 | if (batch_ct % config.save_interval) == 0: 102 | model_path = os.path.join(model_save_dir, "checkpoint_{batch_ct}.pt".format( 103 | batch_ct=str(batch_ct), 104 | )) 105 | print("Saved checkpoint to: ", model_path) 106 | save(model, model_path) 107 | 108 | def train_batch(images, texts, model, device, criterion, optimizer): 109 | images, texts = images.to(device), texts.to(device) 110 | 111 | # Forward pass ➡ 112 | logits_per_image, logits_per_text = model(images, texts) 113 | 114 | # Create labels 115 | batch_size = images.shape[0] 116 | labels = torch.arange(batch_size).to(device) 117 | 118 | # Compute loss 119 | loss_img = criterion(logits_per_image, labels) 120 | loss_txt = criterion(logits_per_text, labels) 121 | loss = (loss_img + loss_txt)/2 # avg. img and txt loss 122 | 123 | # Backward pass ⬅ 124 | optimizer.zero_grad() 125 | loss.backward() 126 | 127 | # Step with optimizer 128 | optimizer.step() 129 | 130 | return loss 131 | 132 | def train_log(loss, example_ct, epoch): 133 | loss = float(loss) 134 | print(f"Loss after " + str(example_ct).zfill(5) + f" examples: {loss:.3f}") 135 | 136 | def save(model, path): 137 | torch.save(model.state_dict(), path) 138 | 139 | if __name__ == "__main__": 140 | args = parse_args() 141 | model = model_pipeline(args) 142 | 143 | 144 | -------------------------------------------------------------------------------- /simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2021 OpenAI 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | 24 | """ 25 | import gzip 26 | import html 27 | import os 28 | from functools import lru_cache 29 | 30 | import ftfy 31 | import regex as re 32 | 33 | 34 | @lru_cache() 35 | def default_bpe(): 36 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 37 | 38 | 39 | @lru_cache() 40 | def bytes_to_unicode(): 41 | """ 42 | Returns list of utf-8 byte and a corresponding list of unicode strings. 43 | The reversible bpe codes work on unicode strings. 44 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 45 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 46 | This is a signficant percentage of your normal, say, 32K bpe vocab. 47 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 48 | And avoids mapping to whitespace/control characters the bpe code barfs on. 49 | """ 50 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 51 | cs = bs[:] 52 | n = 0 53 | for b in range(2**8): 54 | if b not in bs: 55 | bs.append(b) 56 | cs.append(2**8+n) 57 | n += 1 58 | cs = [chr(n) for n in cs] 59 | return dict(zip(bs, cs)) 60 | 61 | 62 | def get_pairs(word): 63 | """Return set of symbol pairs in a word. 64 | Word is represented as tuple of symbols (symbols being variable-length strings). 65 | """ 66 | pairs = set() 67 | prev_char = word[0] 68 | for char in word[1:]: 69 | pairs.add((prev_char, char)) 70 | prev_char = char 71 | return pairs 72 | 73 | 74 | def basic_clean(text): 75 | text = ftfy.fix_text(text) 76 | text = html.unescape(html.unescape(text)) 77 | return text.strip() 78 | 79 | 80 | def whitespace_clean(text): 81 | text = re.sub(r'\s+', ' ', text) 82 | text = text.strip() 83 | return text 84 | 85 | 86 | class SimpleTokenizer(object): 87 | def __init__(self, bpe_path: str = default_bpe()): 88 | self.byte_encoder = bytes_to_unicode() 89 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 90 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 91 | merges = merges[1:49152-256-2+1] 92 | merges = [tuple(merge.split()) for merge in merges] 93 | vocab = list(bytes_to_unicode().values()) 94 | vocab = vocab + [v+'' for v in vocab] 95 | for merge in merges: 96 | vocab.append(''.join(merge)) 97 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 98 | self.encoder = dict(zip(vocab, range(len(vocab)))) 99 | self.decoder = {v: k for k, v in self.encoder.items()} 100 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 101 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 102 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 103 | 104 | def bpe(self, token): 105 | if token in self.cache: 106 | return self.cache[token] 107 | word = tuple(token[:-1]) + ( token[-1] + '',) 108 | pairs = get_pairs(word) 109 | 110 | if not pairs: 111 | return token+'' 112 | 113 | while True: 114 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 115 | if bigram not in self.bpe_ranks: 116 | break 117 | first, second = bigram 118 | new_word = [] 119 | i = 0 120 | while i < len(word): 121 | try: 122 | j = word.index(first, i) 123 | new_word.extend(word[i:j]) 124 | i = j 125 | except: 126 | new_word.extend(word[i:]) 127 | break 128 | 129 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 130 | 
new_word.append(first+second) 131 | i += 2 132 | else: 133 | new_word.append(word[i]) 134 | i += 1 135 | new_word = tuple(new_word) 136 | word = new_word 137 | if len(word) == 1: 138 | break 139 | else: 140 | pairs = get_pairs(word) 141 | word = ' '.join(word) 142 | self.cache[token] = word 143 | return word 144 | 145 | def encode(self, text): 146 | bpe_tokens = [] 147 | text = whitespace_clean(basic_clean(text)).lower() 148 | for token in re.findall(self.pat, text): 149 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 150 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 151 | return bpe_tokens 152 | 153 | def decode(self, tokens): 154 | text = ''.join([self.decoder[token] for token in tokens]) 155 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 156 | return text 157 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import numpy as np 3 | import os 4 | import pandas as pd 5 | from PIL import Image 6 | import h5py 7 | import matplotlib.pyplot as plt 8 | from typing import List, Callable 9 | 10 | import torch 11 | from torch.utils import data 12 | from tqdm.notebook import tqdm 13 | import torch.nn as nn 14 | from torchvision.transforms import Compose, Normalize, Resize 15 | 16 | import sklearn 17 | from sklearn.metrics import matthews_corrcoef, confusion_matrix, accuracy_score, auc, roc_auc_score, roc_curve, classification_report 18 | from sklearn.metrics import precision_recall_curve, f1_score 19 | from sklearn.metrics import average_precision_score 20 | from sklearn.utils import resample 21 | 22 | import scipy 23 | import scipy.stats 24 | 25 | import sys 26 | sys.path.append('../..') 27 | 28 | import clip 29 | from model import CLIP 30 | 31 | def compute_mean(stats, is_df=True): 32 | spec_labels = ["Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Pleural Effusion"] 33 | if is_df: 34 | spec_df = stats[spec_labels] 35 | res = np.mean(spec_df.iloc[0]) 36 | else: 37 | # cis is df, within bootstrap 38 | vals = [stats[spec_label][0] for spec_label in spec_labels] 39 | res = np.mean(vals) 40 | return res 41 | 42 | def accuracy(output, target, topk=(1,)): 43 | pred = output.topk(max(topk), 1, True, True)[1].t() 44 | print('pred: ', pred) 45 | 46 | expand = target.expand(-1, max(topk)) 47 | print('expand: ', expand) 48 | 49 | correct = pred.eq(expand) 50 | print('correct: ', correct) 51 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] 52 | 53 | def sigmoid(x): 54 | z = 1/(1 + np.exp(-x)) 55 | return z 56 | 57 | ''' ROC CURVE ''' 58 | def plot_roc(y_pred, y_true, roc_name, plot=False): 59 | # given the test_ground_truth, and test_predictions 60 | fpr, tpr, thresholds = roc_curve(y_true, y_pred) 61 | 62 | roc_auc = auc(fpr, tpr) 63 | 64 | if plot: 65 | plt.figure(dpi=100) 66 | plt.title(roc_name) 67 | plt.plot(fpr, tpr, 'b', label = 'AUC = %0.2f' % roc_auc) 68 | plt.legend(loc = 'lower right') 69 | plt.plot([0, 1], [0, 1],'r--') 70 | plt.xlim([0, 1]) 71 | plt.ylim([0, 1]) 72 | plt.ylabel('True Positive Rate') 73 | plt.xlabel('False Positive Rate') 74 | plt.show() 75 | return fpr, tpr, thresholds, roc_auc 76 | 77 | # J = TP/(TP+FN) + TN/(TN+FP) - 1 = tpr - fpr 78 | def choose_operating_point(fpr, tpr, thresholds): 79 | sens = 0 80 | spec = 0 81 | J = 0 82 | for _fpr, _tpr in 
zip(fpr, tpr): 83 | if _tpr - _fpr > J: 84 | sens = _tpr 85 | spec = 1-_fpr 86 | J = _tpr - _fpr 87 | return sens, spec 88 | 89 | ''' PRECISION-RECALL CURVE ''' 90 | def plot_pr(y_pred, y_true, pr_name, plot=False): 91 | precision, recall, thresholds = precision_recall_curve(y_true, y_pred) 92 | pr_auc = auc(recall, precision) 93 | # plot the precision-recall curves 94 | baseline = len(y_true[y_true==1]) / len(y_true) 95 | 96 | if plot: 97 | plt.figure(dpi=20) 98 | plt.title(pr_name) 99 | plt.plot(recall, precision, 'b', label='AUC = %0.2f' % pr_auc) 100 | # axis labels 101 | plt.legend(loc = 'lower right') 102 | plt.plot([0, 1], [baseline, baseline],'r--') 103 | plt.xlim([0, 1]) 104 | plt.ylim([0, 1]) 105 | plt.xlabel('Recall') 106 | plt.ylabel('Precision') 107 | # show the plot 108 | plt.show() 109 | return precision, recall, thresholds 110 | 111 | def evaluate(y_pred, y_true, cxr_labels, 112 | roc_name='Receiver Operating Characteristic', pr_name='Precision-Recall Curve', label_idx_map=None): 113 | 114 | ''' 115 | We expect `y_pred` and `y_true` to be numpy arrays, both of shape (num_samples, num_classes) 116 | 117 | `y_pred` is a numpy array consisting of probability scores with all values in range 0-1. 118 | 119 | `y_true` is a numpy array consisting of binary values representing if a class is present in 120 | the cxr. 121 | 122 | This function provides all relevant evaluation information, ROC, AUROC, Sensitivity, Specificity, 123 | PR-Curve, Precision, Recall for each class. 124 | ''' 125 | import warnings 126 | warnings.filterwarnings('ignore') 127 | 128 | num_classes = y_pred.shape[-1] # number of total labels 129 | 130 | dataframes = [] 131 | for i in range(num_classes): 132 | # print('{}.'.format(cxr_labels[i])) 133 | 134 | if label_idx_map is None: 135 | y_pred_i = y_pred[:, i] # (num_samples,) 136 | y_true_i = y_true[:, i] # (num_samples,) 137 | 138 | else: 139 | y_pred_i = y_pred[:, i] # (num_samples,) 140 | 141 | true_index = label_idx_map[cxr_labels[i]] 142 | y_true_i = y_true[:, true_index] # (num_samples,) 143 | 144 | cxr_label = cxr_labels[i] 145 | 146 | ''' ROC CURVE ''' 147 | roc_name = cxr_label + ' ROC Curve' 148 | fpr, tpr, thresholds, roc_auc = plot_roc(y_pred_i, y_true_i, roc_name) 149 | 150 | sens, spec = choose_operating_point(fpr, tpr, thresholds) 151 | 152 | results = [[roc_auc]] 153 | df = pd.DataFrame(results, columns=[cxr_label+'_auc']) 154 | dataframes.append(df) 155 | 156 | ''' PRECISION-RECALL CURVE ''' 157 | pr_name = cxr_label + ' Precision-Recall Curve' 158 | precision, recall, thresholds = plot_pr(y_pred_i, y_true_i, pr_name) 159 | 160 | dfs = pd.concat(dataframes, axis=1) 161 | return dfs 162 | 163 | ''' Bootstrap and Confidence Intervals ''' 164 | def compute_cis(data, confidence_level=0.05): 165 | """ 166 | FUNCTION: compute_cis 167 | ------------------------------------------------------ 168 | Given a Pandas dataframe of (n, labels), return another 169 | Pandas dataframe that is (3, labels). 170 | 171 | Each row is lower bound, mean, upper bound of a confidence 172 | interval with `confidence`. 
173 | 174 | Args: 175 | * data - Pandas Dataframe, of shape (num_bootstrap_samples, num_labels) 176 | * confidence_level (optional) - confidence level of interval 177 | 178 | Returns: 179 | * Pandas Dataframe, of shape (3, labels), representing mean, lower, upper 180 | """ 181 | data_columns = list(data) 182 | intervals = [] 183 | for i in data_columns: 184 | series = data[i] 185 | sorted_perfs = series.sort_values() 186 | lower_index = int(confidence_level/2 * len(sorted_perfs)) - 1 187 | upper_index = int((1 - confidence_level/2) * len(sorted_perfs)) - 1 188 | lower = sorted_perfs.iloc[lower_index].round(4) 189 | upper = sorted_perfs.iloc[upper_index].round(4) 190 | mean = round(sorted_perfs.mean(), 4) 191 | interval = pd.DataFrame({i : [mean, lower, upper]}) 192 | intervals.append(interval) 193 | intervals_df = pd.concat(intervals, axis=1) 194 | intervals_df.index = ['mean', 'lower', 'upper'] 195 | return intervals_df 196 | 197 | def bootstrap(y_pred, y_true, cxr_labels, n_samples=1000, label_idx_map=None): 198 | ''' 199 | This function will randomly sample with replacement 200 | from y_pred and y_true then evaluate `n` times 201 | and obtain AUROC scores for each. 202 | 203 | You can specify the number of samples that should be 204 | used with the `n_samples` parameter. 205 | 206 | Confidence intervals will be generated from each 207 | of the samples. 208 | 209 | Note: 210 | * n_total_labels >= n_cxr_labels 211 | `n_total_labels` is greater iff alternative labels are being tested 212 | ''' 213 | np.random.seed(97) 214 | y_pred # (500, n_total_labels) 215 | y_true # (500, n_cxr_labels) 216 | 217 | idx = np.arange(len(y_true)) 218 | 219 | boot_stats = [] 220 | for i in tqdm(range(n_samples)): 221 | sample = resample(idx, replace=True, random_state=i) 222 | y_pred_sample = y_pred[sample] 223 | y_true_sample = y_true[sample] 224 | 225 | sample_stats = evaluate(y_pred_sample, y_true_sample, cxr_labels, label_idx_map=label_idx_map) 226 | boot_stats.append(sample_stats) 227 | 228 | boot_stats = pd.concat(boot_stats) # pandas array of evaluations for each sample 229 | return boot_stats, compute_cis(boot_stats) 230 | -------------------------------------------------------------------------------- /clip.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2021 OpenAI 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | 24 | """ 25 | 26 | import hashlib 27 | import os 28 | import urllib 29 | import warnings 30 | from typing import Union, List 31 | 32 | import torch 33 | from PIL import Image 34 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 35 | from tqdm import tqdm 36 | 37 | from model import build_model 38 | from simple_tokenizer import SimpleTokenizer as _Tokenizer 39 | 40 | __all__ = ["available_models", "load", "tokenize"] 41 | _tokenizer = _Tokenizer() 42 | 43 | _MODELS = { 44 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 45 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 46 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 47 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 48 | } 49 | 50 | 51 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 52 | os.makedirs(root, exist_ok=True) 53 | filename = os.path.basename(url) 54 | 55 | expected_sha256 = url.split("/")[-2] 56 | download_target = os.path.join(root, filename) 57 | 58 | if os.path.exists(download_target) and not os.path.isfile(download_target): 59 | raise RuntimeError(f"{download_target} exists and is not a regular file") 60 | 61 | if os.path.isfile(download_target): 62 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 63 | return download_target 64 | else: 65 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 66 | 67 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 68 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 69 | while True: 70 | buffer = source.read(8192) 71 | if not buffer: 72 | break 73 | 74 | output.write(buffer) 75 | loop.update(len(buffer)) 76 | 77 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 78 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 79 | 80 | return download_target 81 | 82 | 83 | def _transform(n_px): 84 | return Compose([ 85 | Resize(n_px, interpolation=Image.BICUBIC), 86 | CenterCrop(n_px), 87 | lambda image: image.convert("RGB"), 88 | ToTensor(), 89 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 90 | ]) 91 | 92 | 93 | def available_models() -> List[str]: 94 | """Returns the names of available CLIP models""" 95 | return list(_MODELS.keys()) 96 | 97 | 98 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True): 99 | """Load a CLIP model 100 | 101 | Parameters 102 | ---------- 103 | name : str 104 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 105 | 106 | device : Union[str, torch.device] 107 | The device to put the loaded model 108 | 109 | jit : bool 110 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 
111 | 112 | Returns 113 | ------- 114 | model : torch.nn.Module 115 | The CLIP model 116 | 117 | preprocess : Callable[[PIL.Image], torch.Tensor] 118 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 119 | """ 120 | if name in _MODELS: 121 | model_path = _download(_MODELS[name]) 122 | elif os.path.isfile(name): 123 | model_path = name 124 | else: 125 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 126 | 127 | try: 128 | # loading JIT archive 129 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 130 | state_dict = None 131 | except RuntimeError: 132 | # loading saved state dict 133 | if jit: 134 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 135 | jit = False 136 | state_dict = torch.load(model_path, map_location="cpu") 137 | 138 | if not jit: 139 | model = build_model(state_dict or model.state_dict()).to(device) 140 | if str(device) == "cpu": 141 | model.float() 142 | return model, _transform(model.visual.input_resolution) 143 | 144 | # patch the device names 145 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 146 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 147 | 148 | def patch_device(module): 149 | graphs = [module.graph] if hasattr(module, "graph") else [] 150 | if hasattr(module, "forward1"): 151 | graphs.append(module.forward1.graph) 152 | 153 | for graph in graphs: 154 | for node in graph.findAllNodes("prim::Constant"): 155 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 156 | node.copyAttributes(device_node) 157 | 158 | model.apply(patch_device) 159 | patch_device(model.encode_image) 160 | patch_device(model.encode_text) 161 | 162 | # patch dtype to float32 on CPU 163 | if str(device) == "cpu": 164 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 165 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 166 | float_node = float_input.node() 167 | 168 | def patch_float(module): 169 | graphs = [module.graph] if hasattr(module, "graph") else [] 170 | if hasattr(module, "forward1"): 171 | graphs.append(module.forward1.graph) 172 | 173 | for graph in graphs: 174 | for node in graph.findAllNodes("aten::to"): 175 | inputs = list(node.inputs()) 176 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 177 | if inputs[i].node()["value"] == 5: 178 | inputs[i].node().copyAttributes(float_node) 179 | 180 | model.apply(patch_float) 181 | patch_float(model.encode_image) 182 | patch_float(model.encode_text) 183 | 184 | model.float() 185 | 186 | return model, _transform(model.input_resolution.item()) 187 | 188 | 189 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 190 | """ 191 | Returns the tokenized representation of given input string(s) 192 | 193 | Parameters 194 | ---------- 195 | texts : Union[str, List[str]] 196 | An input string or a list of input strings to tokenize 197 | 198 | context_length : int 199 | The context length to use; all CLIP models use 77 as the context length 200 | 201 | Returns 202 | ------- 203 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 204 | """ 205 | if isinstance(texts, str): 206 | texts = [texts] 207 | 208 | sot_token = _tokenizer.encoder["<|startoftext|>"] 
209 | eot_token = _tokenizer.encoder["<|endoftext|>"] 210 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 211 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 212 | 213 | for i, tokens in enumerate(all_tokens): 214 | if len(tokens) > context_length: 215 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 216 | result[i, :len(tokens)] = torch.tensor(tokens) 217 | 218 | return result 219 | -------------------------------------------------------------------------------- /preprocess_padchest.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import numpy as np 3 | import os 4 | import pandas as pd 5 | from PIL import Image 6 | import h5py 7 | import matplotlib.pyplot as plt 8 | from typing import List 9 | 10 | import torch 11 | from torch.utils import data 12 | from tqdm.notebook import tqdm 13 | import torch.nn as nn 14 | from torchvision.transforms import Compose, Normalize 15 | 16 | import sklearn 17 | from sklearn.metrics import confusion_matrix, accuracy_score, auc, roc_auc_score, roc_curve, classification_report 18 | from sklearn.metrics import precision_recall_curve, f1_score 19 | from sklearn.metrics import average_precision_score 20 | 21 | import sys 22 | sys.path.append('../..') 23 | sys.path.append('../data-process') 24 | sys.path.append('data/padchest') 25 | 26 | from data_process import * 27 | 28 | 29 | 30 | def preprocess_data(data_root): 31 | labels_path = os.path.join(data_root, 32 | 'PADCHEST_chest_x_ray_images_labels_160K_01.02.19.csv') 33 | labels = pd.read_csv(labels_path) 34 | # get filepaths of 2.zip images 35 | text_file_path = os.path.join(data_root, '2.zip.unzip-l.txt') 36 | image_paths = extract_filenames(text_file_path) 37 | labels_2_df = labels[labels['ImageID'].isin(image_paths)] 38 | unique_labels = get_unique_labels(labels_2_df) 39 | # multi hot encoding for labels 40 | df_lab = create_multi_hot_labels(labels_2_df, unique_labels) 41 | 42 | loc_2_df = labels[labels['ImageID'].isin(image_paths)] 43 | loc_col_2 = loc_2_df.loc[:, "Labels"] 44 | # multihot encoding for localizations 45 | unique_loc = get_unique_labels(loc_2_df, column="Labels") 46 | df_loc = create_multi_hot_labels(loc_2_df, unique_loc, column="Labels") 47 | directory = 'data/padchest/images/' 48 | cxr_paths = get_paths(directory) 49 | write_h5(cxr_paths) 50 | unique_labels = np.load('unique_labels.npy') 51 | return unique_labels[0:1] 52 | 53 | def extract_filenames(txt_path): 54 | """ 55 | Given a filepath to a txt file with image file names, 56 | extract a list of filenames for this zip. 57 | 58 | Assume that the txt file has two unnecessary lines at 59 | both the top and the bottom of the file. 60 | """ 61 | df = pd.read_csv(txt_path) 62 | df_list = df.values.tolist() 63 | df_list = df_list[2:-2] 64 | 65 | images_list = [] 66 | for file in df_list: 67 | parsed_filename = file[0].split()[-1] 68 | images_list.append(parsed_filename) 69 | return images_list 70 | 71 | # get paths of all possible labels 72 | def get_unique_labels(labels_df, column='Labels'): 73 | """ 74 | Given labels_df, return a list containing all unique labels 75 | present in this dataset. 
76 | """ 77 | 78 | unique_labels = set() 79 | # iterate through all rows in the dataframe 80 | for index, row in labels_df.iterrows(): 81 | labels = row[column] 82 | try: 83 | # convert labels str to array 84 | labels_arr = labels.strip('][').split(', ') 85 | for label in labels_arr: 86 | # process string 87 | processed_label = label.split("'")[1].strip() 88 | processed_label = processed_label.lower() 89 | unique_labels.add(processed_label) 90 | except: 91 | continue 92 | 93 | return list(unique_labels) 94 | 95 | def create_multi_hot_labels(labels_df, unique_labels_list, column='Labels'): 96 | """ 97 | Args: 98 | * labels_df: original df where labels are an arr 99 | * labels_list: list of all possible labels in respective order 100 | 101 | Given all entries and it's corresponding labels, create a one(multi)-hot vector 102 | where a 1 represents the presence of that disease. 103 | 104 | Returns a Pandas dataframe mapping filename to it's multi-hot representation. Each of the diseases 105 | are columns. 106 | """ 107 | 108 | # todo: check how the labels are represented for CheXpert 109 | # create a pandas datafraame with columns as unique labels, start with list of dicts 110 | dict_list = [] 111 | 112 | # iterate through all rows in the dataframe 113 | for index, row in labels_df.iterrows(): 114 | labels = row[column] 115 | try: 116 | # convert labels str to array 117 | labels_arr = labels.strip('][').split(', ') 118 | # print(labels_arr, len(labels_arr)) 119 | 120 | count_dict = dict() # map label name to count 121 | count_dict['ImageID'] = row['ImageID'] 122 | # init count dict with 0s 123 | for unq_label in unique_labels_list: 124 | count_dict[unq_label] = 0 125 | 126 | if len(labels_arr) > 0 and labels_arr[0] != '': 127 | for label in labels_arr: 128 | # process string 129 | processed_label = label.split("'")[1].strip() 130 | processed_label = processed_label.lower() 131 | count_dict[processed_label] = 1 132 | 133 | dict_list.append(count_dict) 134 | except: 135 | print("error when creating labels for this img.") 136 | continue 137 | 138 | multi_hot_labels_df = pd.DataFrame(dict_list, columns=(['ImageID'] + unique_labels_list)) 139 | return multi_hot_labels_df 140 | 141 | # convert folder of images to h5 file 142 | def get_paths(directory): 143 | """ 144 | Given a directory, this function outputs 145 | all the image paths in that directory as a 146 | list. 147 | """ 148 | paths_list = [] 149 | for filename in os.listdir(directory): 150 | if filename.endswith(".png"): 151 | paths_list.append(os.path.join(directory, filename)) 152 | else: 153 | continue 154 | return paths_list 155 | 156 | def img_to_h5( 157 | cxr_paths: List[str], 158 | out_filepath: str, 159 | resolution: int = 320, 160 | ) -> List[str]: 161 | """ 162 | Converts a set of images into a single `.h5` file. 163 | 164 | Args: 165 | cxr_paths: List of paths to images as `.png` 166 | out_filepath: Path to store h5 file 167 | resolution: image resolution 168 | 169 | Returns a list of cxr_paths that were successfully stored in the 170 | `.h5` file. 
171 | """ 172 | dset_size = len(cxr_paths) 173 | proper_cxr_paths = [] 174 | with h5py.File(out_filepath,'w') as h5f: 175 | img_dset = h5f.create_dataset('cxr', shape=(dset_size, resolution, resolution)) 176 | 177 | ctr = 0 178 | for idx, path in enumerate(tqdm(cxr_paths)): 179 | try: 180 | # read image using cv2 181 | img = cv2.imread(path) 182 | # convert to PIL Image object 183 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 184 | img_pil = Image.fromarray(img) 185 | # preprocess 186 | img = preprocess(img_pil, desired_size=resolution) 187 | img_dset[ctr] = img 188 | ctr += 1 189 | proper_cxr_paths.append(path) 190 | except: 191 | print(f"Image {ctr} failed loading...") 192 | continue 193 | print(h5f) 194 | 195 | return proper_cxr_paths 196 | 197 | def write_h5(cxr_paths, resolution: int = 320): 198 | out_filepath = 'data/padchest/images/2_cxr_dset_sample.h5' 199 | dset_size = len(cxr_paths) 200 | 201 | proper_cxr_paths = [] 202 | with h5py.File(out_filepath,'w') as h5f: 203 | img_dset = h5f.create_dataset('cxr', shape=(2978, resolution, resolution)) # todo: replace magic number with actual number 204 | # print('Dataset initialized.') 205 | 206 | ctr = 0 207 | for idx, path in enumerate(tqdm(cxr_paths)): 208 | try: 209 | # read image using cv2 210 | img = cv2.imread(path) 211 | # convert to PIL Image object 212 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 213 | img_pil = Image.fromarray(img) 214 | # preprocess 215 | img = preprocess(img_pil, desired_size=resolution) 216 | plt.imshow(img) 217 | img_dset[ctr] = img 218 | ctr += 1 219 | proper_cxr_paths.append(path) 220 | except: 221 | print("failed!") 222 | continue 223 | print(h5f) 224 | np.save("proper_cxr_paths.npy", np.array(proper_cxr_paths)) 225 | out_filepath = 'data/padchest/images/2_cxr.h5' 226 | img_to_hdf5(cxr_paths, out_filepath, resolution=320) 227 | df_labels_new = order_labels(df_lab, proper_cxr_paths) 228 | labels_path = 'data/padchest/2_cxr_labels.csv' 229 | df_labels_new.to_csv(labels_path) 230 | 231 | def order_labels(df, cxr_paths): 232 | """ 233 | Fixes multi-hot labels to be in order of cxr_paths 234 | """ 235 | df_new = pd.DataFrame(columns=df.columns) 236 | for path in cxr_paths: 237 | imageId = path.split('/')[-1] 238 | row = df.loc[df['ImageID'] == imageId] 239 | df_new = df_new.append(row) 240 | return df_new 241 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import matplotlib.pyplot as plt 6 | from tqdm.notebook import tqdm 7 | 8 | from PIL import Image 9 | import h5py 10 | 11 | import torch 12 | from torch.utils import data 13 | from torch import nn 14 | import torch.optim as optim 15 | from torchvision.transforms import Compose, Normalize, Resize, InterpolationMode 16 | 17 | import sys 18 | sys.path.append('../..') 19 | 20 | import clip 21 | from model import CLIP 22 | from simple_tokenizer import SimpleTokenizer 23 | 24 | class CXRDataset(data.Dataset): 25 | """Represents an abstract HDF5 dataset. 26 | 27 | Input params: 28 | file_path: Path to the folder containing the dataset (one or multiple HDF5 files). 29 | recursive: If True, searches for h5 files in subdirectories. 30 | load_data: If True, loads all the data immediately into RAM. Use this if 31 | the dataset is fits into memory. Otherwise, leave this at false and 32 | the data will load lazily. 
33 | data_cache_size: Number of HDF5 files that can be cached in the cache (default=3). 34 | transform: PyTorch transform to apply to every data instance (default=None). 35 | """ 36 | def __init__(self, img_path, txt_path, column='report', size=None, transform=None): 37 | super().__init__() 38 | if size != None: 39 | self.img_dset = h5py.File(img_path, 'r')['cxr'][:size] 40 | self.txt_dset = pd.read_csv(txt_path)[column][:size] 41 | else: 42 | self.img_dset = h5py.File(img_path, 'r')['cxr'] 43 | self.txt_dset = pd.read_csv(txt_path)[column] 44 | self.transform = transform 45 | 46 | def __len__(self): 47 | return len(self.txt_dset) 48 | 49 | def __getitem__(self, idx): 50 | if torch.is_tensor(idx): 51 | idx = idx.tolist() 52 | 53 | img = self.img_dset[idx] # np array, (320, 320) 54 | img = np.expand_dims(img, axis=0) 55 | img = np.repeat(img, 3, axis=0) 56 | txt = self.txt_dset[idx] # python str 57 | if type(txt) == type(float("nan")): # capture the case of empty "Impression" sections 58 | txt = " " 59 | 60 | img = torch.from_numpy(img) # torch, (3, 320, 320) 61 | if self.transform: 62 | img = self.transform(img) 63 | sample = {'img': img, 'txt': txt } 64 | 65 | return sample 66 | 67 | def load_data(cxr_filepath, txt_filepath, batch_size=4, column='report', pretrained=False, verbose=False): 68 | if torch.cuda.is_available(): 69 | dev = "cuda:0" 70 | cuda_available = True 71 | print('Using CUDA.') 72 | else: 73 | dev = "cpu" 74 | cuda_available = False 75 | print('Using cpu.') 76 | 77 | device = torch.device(dev) 78 | 79 | if cuda_available: 80 | torch.cuda.set_device(device) 81 | 82 | if pretrained: 83 | input_resolution = 224 84 | transform = Compose([ 85 | Normalize((101.48761, 101.48761, 101.48761), (83.43944, 83.43944, 83.43944)), 86 | Resize(input_resolution, interpolation=InterpolationMode.BICUBIC), 87 | ]) 88 | print('Interpolation Mode: ', InterpolationMode.BICUBIC) 89 | print("Finished image transforms for pretrained model.") 90 | else: 91 | input_resolution = 320 92 | transform = Compose([ 93 | Normalize((101.48761, 101.48761, 101.48761), (83.43944, 83.43944, 83.43944)), 94 | ]) 95 | print("Finished image transforms for clip model.") 96 | 97 | torch_dset = CXRDataset(img_path=cxr_filepath, 98 | txt_path=txt_filepath, column=column, transform=transform) 99 | 100 | if verbose: 101 | for i in range(len(torch_dset)): 102 | sample = torch_dset[i] 103 | plt.imshow(sample['img'][0]) 104 | plt.show() 105 | print(i, sample['img'].size(), sample['txt']) 106 | if i == 3: 107 | break 108 | 109 | loader_params = {'batch_size':batch_size, 'shuffle': True, 'num_workers': 0} 110 | data_loader = data.DataLoader(torch_dset, **loader_params) 111 | return data_loader, device 112 | 113 | def load_clip(model_path=None, pretrained=False, context_length=77): 114 | ''' 115 | FUNCTION: load_clip 116 | ------------------------------- 117 | This function loads in a model with the CLIP model 118 | architecture. 
119 | 120 | args: 121 | * model_path (optional) - path to model weights that the model 122 | will be initialized with 123 | * pretrained (optional) - if True, will load the pretrained 124 | CLIP model 125 | * context_length (optional) - length of the maximum number of 126 | tokens that can be inputted into the CLIP model 127 | ''' 128 | 129 | params = { 130 | 'embed_dim':768, 131 | 'image_resolution': 320, 132 | 'vision_layers': 12, 133 | 'vision_width': 768, 134 | 'vision_patch_size': 16, 135 | 'context_length': context_length, 136 | 'vocab_size': 49408, 137 | 'transformer_width': 512, 138 | 'transformer_heads': 8, 139 | 'transformer_layers': 12 140 | } 141 | 142 | # set device 143 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 144 | 145 | if pretrained: 146 | # load clip pre-trained model 147 | model, preprocess = clip.load("ViT-B/32", device=device, jit=False) 148 | print("Loaded in pretrained model.") 149 | else: 150 | model = CLIP(**params) 151 | print("Loaded in clip model.") 152 | 153 | # if a model_path is provided, load in weights to backbone 154 | if model_path != None: 155 | model.load_state_dict(torch.load(model_path, map_location=device)) 156 | return model 157 | 158 | 159 | def preprocess_text(texts, model): 160 | # if model.context_length is None: 161 | # model = model.module 162 | 163 | _tokenizer = SimpleTokenizer() 164 | sot_token = _tokenizer.encoder["<|startoftext|>"] 165 | eot_token = _tokenizer.encoder["<|endoftext|>"] 166 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 167 | result = torch.zeros(len(all_tokens), model.context_length, dtype=torch.long) 168 | 169 | for i, tokens in enumerate(all_tokens): 170 | if len(tokens) > model.context_length: 171 | tokens = tokens[:model.context_length] 172 | tokens[model.context_length - 1] = eot_token 173 | result[i, :len(tokens)] = torch.tensor(tokens) 174 | return result 175 | 176 | def make(config, cxr_filepath, txt_filepath, model_path=None): 177 | ''' 178 | FUNCTION: make 179 | --------------------------------- 180 | This function makes the model, the data loader, loss and optimizer. 
181 | 182 | args: 183 | * config - dict, configuration of experiment 184 | * cxr_filepath - string, filepath to chest x-ray images 185 | * txt_filepath - string, filepath to corresponding text reports 186 | * model_path - string, filepath to previously trained model 187 | ''' 188 | data_loader, device = load_data(cxr_filepath, txt_filepath, batch_size=config.batch_size, pretrained=config.pretrained, column=config.column) 189 | model = load_clip(model_path=model_path, pretrained=config.pretrained, context_length=config.context_length) 190 | model.to(device) 191 | print('Model on Device.') 192 | 193 | # make the optimizer 194 | criterion = nn.CrossEntropyLoss().cuda() 195 | # todo: incorporate - torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False) 196 | optimizer = optim.AdamW(model.parameters(), lr=config.lr) 197 | return model, data_loader, device, criterion, optimizer 198 | 199 | 200 | def train_main(cxr_filepath, txt_filepath, hyperparams, output_path, model_path=None, pretrained=False): 201 | ''' 202 | args: 203 | * cxr_filpath- str filepath to cxr images 204 | * txt_filepath- str filepath to text reports 205 | * hyperparams- dictionary with the following hyperparams: 206 | `batch_size`, `criterion`, `learning_rate`, `momentum`, `epochs` 207 | * output_path- str filepath to where the trained model will be saved 208 | * model_path- str filepath to model that will be used as baseline model for training. 209 | If not provided, a model will be trained from scratch 210 | * pretrained- whether or not the clip model was pretrained with generic images 211 | This function is the main train function for CXR-CLIP. 212 | ''' 213 | 214 | # unpack `hyperparams` 215 | batch_size = hyperparams['batch_size'] 216 | criterion = hyperparams['criterion'] 217 | learning_rate = hyperparams['learning_rate'] 218 | momentum = hyperparams['momentum'] 219 | epochs = hyperparams['epochs'] 220 | 221 | # load input cxr + report data 222 | data_loader, device = load_data(cxr_filepath, txt_filepath, batch_size=batch_size, pretrained=pretrained) 223 | model = load_clip(model_path=model_path, pretrained=pretrained) 224 | 225 | optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum) 226 | train_clip(model, data_loader, device, criterion, optimizer, epochs, output_path) 227 | return model 228 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Expert-level detection of pathologies from unannotated chest X-ray images via self-supervised learning 2 | 3 |
4 | 5 | Expert-level detection of pathologies from unannotated chest X-ray images via self-supervised learning, Nat. Biomed. Eng (2022). 6 | [Paper] 7 |
Ekin Tiu, Ellie Talius, Pujan Patel, Curtis P. Langlotz, Andrew Y. Ng, Pranav Rajpurkar
8 |
9 | 10 | ```bash 11 | Tiu, E., Talius, E., Patel, P. et al. Expert-level detection of pathologies from unannotated chest X-ray images via self-supervised learning. Nat. Biomed. Eng (2022). https://doi.org/10.1038/s41551-022-00936-9 12 | ``` 13 |
14 | 15 | Screen Shot 2022-09-15 at 10 57 16 AM 16 | 17 | This repository contains code to train a self-supervised learning model on chest X-ray images that lack explicit annotations and evaluate this model's performance on pathology-classification tasks. 18 | 19 |
20 | 21 | Main Findings 22 | 23 | 24 | 1. **Automatically detecting pathologies in chest x-rays without explicit annotations:** Our method learns directly from the combination of images and unstructured radiology reports, thereby avoiding time-consuming labeling efforts. Our deep learning method is capable of predicting multiple pathologies and differential diagnoses that it had not explicitly seen during training. 25 | 2. **Matching radiologist performance on different tasks on an external test set:** Our method performed on par with human performance when evaluated on an external validation set (CheXpert) of chest x-ray images labeled for the presence of 14 different conditions by multiple radiologists. 26 | 3. **Outperforming approaches that train on explicitly labeled data on an external test set:** Using no labels, we outperformed a fully supervised approach (100% of labels) on 3 out of the 8 selected pathologies on a dataset (PadChest) collected in a different country. We further demonstrated high performance (AUC > 0.9) on 14 findings and at least 0.700 on 53 findings out of 107 radiographic findings that the method had not seen during training. 27 |
28 | 29 | 30 | ## Dependencies 31 | To clone all files: 32 | 33 | ```git clone https://github.com/rajpurkarlab/CheXzero.git``` 34 | 35 | To install Python dependencies: 36 | 37 | ```pip install -r requirements.txt``` 38 | 39 | ## Data 40 | ### Training Dataset 41 | 1. Download images come from [MIMIC-CXR JPG] https://physionet.org/content/mimic-cxr-jpg/2.0.0/ and reports from [MIMIC-CXR Database](https://physionet.org/content/mimic-cxr/2.0.0/) Note: in order to gain access to the data, you must be a credentialed user as defined on [PhysioNet](https://physionet.org/settings/credentialing/). 42 | 2. Copy the dataset into the `data/` directory. 43 | 3. Run `python run_preprocess.py` 44 | 4. This should preprocess the chest x-ray images into a Hierarchical Data Format (HDF) format used for training stored at `data/cxr.h5` and extract the impressions section as text from the corresponding chest x-ray radiology report stored at `data/mimic_impressions.csv` . 45 | 46 | ### Evaluation Dataset 47 | 48 | #### CheXpert Dataset 49 | The CheXpert dataset consists of chest radiographic examinations from Stanford Hospital, performed between October 2002 50 | and July 2017 in both inpatient and outpatient centers. Population-level characteristics are unavailable for the CheXpert test 51 | dataset, as they are used for official evaluation on the CheXpert leaderboard. 52 | 53 | The main data (CheXpert data) supporting the results of this study are available at https://aimi.stanford.edu/chexpert-chest-x-rays. 54 | 55 | The CheXpert **test** dataset has recently been made public, and can be found by following the steps in the [cheXpert-test-set-labels](https://github.com/rajpurkarlab/cheXpert-test-set-labels) repository. 56 | 57 | #### PadChest Dataset 58 | The PadChest dataset contains chest X-rays that were interpreted by 18 radiologists at the Hospital Universitario de San Juan, 59 | Alicante, Spain, from January 2009 to December 2017. The dataset contains 109,931 image studies and 168,861 images. 60 | PadChest also contains 206,222 study reports. 61 | 62 | The [PadChest](https://arxiv.org/abs/1901.07441) is publicly available at https://bimcv.cipf.es/bimcv-projects/padchest. Those who would like to use PadChest for experimentation should request access to PadChest at the [link](https://bimcv.cipf.es/bimcv-projects/padchest). 63 | 64 | ### Model Checkpoints 65 | Model checkpoints of CheXzero pre-trained on MIMIC-CXR are publicly available at the following [link](https://drive.google.com/drive/folders/1makFLiEMbSleYltaRxw81aBhEDMpVwno?usp=sharing). Download files and save them in the `./checkpoints/chexzero_weights` directory. 66 | 67 | ## Running Training 68 | Run the following command to perform CheXzero pretraining. 69 | ```bash 70 | python run_train.py --cxr_filepath "./data/cxr.h5" --txt_filepath "data/mimic_impressions.csv" 71 | ``` 72 | 73 | ### Arguments 74 | * `--cxr_filepath` Directory to load chest x-ray image data from. 75 | * `--txt_filepath` Directory to load radiology report impressions text from. 76 | 77 | Use `-h` flag to see all optional arguments. 78 | 79 | ## Zero-Shot Inference 80 | See the following [notebook](https://github.com/rajpurkarlab/CheXzero/blob/main/notebooks/zero_shot.ipynb) for an example of how to use CheXzero to perform zero-shot inference on a chest x-ray dataset. The example shows how to output predictions from the model ensemble and evaluate performance of the model if ground truth labels are available. 
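For context, the `zero_shot.ensemble_models` call shown below assumes that the model paths, image file, labels, prompt templates, and cache directory have already been defined. A minimal, illustrative setup might look like the following; the paths are placeholders and should point to your own data and checkpoint locations (here model paths are collected with `glob`, whereas the sample notebook walks the directory with `os.walk`):

```python
import glob
from typing import List, Tuple

# checkpoints downloaded into ./checkpoints/chexzero_weights
model_paths: List[str] = sorted(glob.glob("checkpoints/chexzero_weights/*.pt"))

# preprocessed chest x-ray images (.h5) and the pathologies to query
cxr_filepath: str = "data/chexpert_test.h5"
cxr_labels: List[str] = [
    "Atelectasis", "Cardiomegaly", "Consolidation", "Edema",
    "Enlarged Cardiomediastinum", "Fracture", "Lung Lesion", "Lung Opacity",
    "No Finding", "Pleural Effusion", "Pleural Other", "Pneumonia",
    "Pneumothorax", "Support Devices",
]

# positive/negative prompt pair used for the softmax evaluation
cxr_pair_template: Tuple[str, str] = ("{}", "no {}")

# where to cache per-checkpoint predictions so they are not recomputed
cache_dir: str = "predictions/cached"
```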
81 | 82 | ```python 83 | import zero_shot 84 | 85 | # computes predictions for a set of images stored as a np array of probabilities for each pathology 86 | predictions, y_pred_avg = zero_shot.ensemble_models( 87 | model_paths=model_paths, 88 | cxr_filepath=cxr_filepath, 89 | cxr_labels=cxr_labels, 90 | cxr_pair_template=cxr_pair_template, 91 | cache_dir=cache_dir, 92 | ) 93 | ``` 94 | ### Arguments 95 | * `model_paths: List[str]`: List of paths to all checkpoints to be used in the ensemble. To run on a single model, input a list containing a single path. 96 | * `cxr_filepath: str`: Path to images `.h5` file 97 | * `cxr_labels: List[str]`: List of pathologies to query in each image 98 | * `cxr_pair_templates: Tuple[str, str]`: constrasting templates used to query model (see Figure 1 in article for visual explanation). 99 | * `cache_dir: str`: Directory to cache predictions of each checkpoint, use to avoid recomputing predictions. 100 | 101 | In order to use CheXzero for zero-shot inference, ensure the following requirements are met: 102 | * All input *`images`* must be stored in a single `.h5` (Hierarchical Data Format). See the [`img_to_h5`](https://github.com/rajpurkarlab/CheXzero/blob/main/preprocess_padchest.py#L156) function in [preprocess_padchest.py](https://github.com/rajpurkarlab/internal-chexzero/blob/cleanversion/preprocess_padchest.py) for an example of how to convert a list of paths to `.png` files into a valid `.h5` file. 103 | * The *ground truth `labels`* must be in a `.csv` dataframe where rows represent each image sample, and each column represents the binary labels for a particular pathology on each sample. 104 | * Ensure all [model checkpoints](https://drive.google.com/drive/folders/1makFLiEMbSleYltaRxw81aBhEDMpVwno?usp=sharing) are stored in `checkpoints/chexzero_weights/`, or the `model_dir` that is specified in the notebook. 105 | 106 | ## Evaluation 107 | Given a numpy array of predictions (obtained from zero-shot inference), and a numpy array of ground truth labels, one can evaluate the performance of the model using the following code: 108 | ```python 109 | import zero_shot 110 | import eval 111 | 112 | # loads in ground truth labels into memory 113 | test_pred = y_pred_avg 114 | test_true = zero_shot.make_true_labels(cxr_true_labels_path=cxr_true_labels_path, cxr_labels=cxr_labels) 115 | 116 | # evaluate model, no bootstrap 117 | cxr_results: pd.DataFrame = eval.evaluate(test_pred, test_true, cxr_labels) # eval on full test datset 118 | 119 | # boostrap evaluations for 95% confidence intervals 120 | bootstrap_results: Tuple[pd.DataFrame, pd.DataFrame] = eval.bootstrap(test_pred, test_true, cxr_labels) # (df of results for each bootstrap, df of CI) 121 | 122 | # print results with confidence intervals 123 | print(bootstrap_results[1]) 124 | ``` 125 | The results are represented as a `pd.DataFrame` which can be saved as a `.csv`. 126 | 127 | ### CheXpert Test Dataset 128 | In order to replicate the results in the paper, zero-shot inference and evaluation can be performed on the now publicly available CheXpert test dataset. 129 | 1) Download labels at [cheXpert-test-set-labels](https://github.com/rajpurkarlab/cheXpert-test-set-labels/blob/main/groundtruth.csv) and image files from [Stanford AIMI](https://stanfordaimi.azurewebsites.net/datasets/23c56a0d-15de-405b-87c8-99c30138950c) and save in the `./data` directory in `CheXzero/`. 
The test dataset images should have the following directory structure: 130 | ``` 131 | data/ 132 | ├─ CheXpert/ 133 | │ ├─ test/ 134 | │ │ ├─ patient64741/ 135 | │ │ │ ├─ study1/ 136 | │ │ │ │ ├─ view1_frontal.jpg 137 | │ │ ├─ .../ 138 | ``` 139 | 140 | 2) Run `run_preprocess.py` script with the following arguments: 141 | ```bash 142 | python run_preprocess.py --dataset_type "chexpert-test" --cxr_out_path "./data/chexpert_test.h5" --chest_x_ray_path "./data/CheXpert/test/" 143 | ``` 144 | This should save a `.h5` version of the test dataset images which can be used for evaluation. 145 | 146 | 3) Open sample zero-shot [notebook](https://github.com/rajpurkarlab/CheXzero/blob/main/notebooks/zero_shot.ipynb) and run all cells. If the directory structure is set up correctly, then all cells should run without errors. 147 | 148 | ## Issues 149 | Please open new issue threads specifying the issue with the codebase or report issues directly to ekintiu@stanford.edu. 150 | 151 | ## Citation 152 | ```bash 153 | Tiu, E., Talius, E., Patel, P. et al. Expert-level detection of pathologies from unannotated chest X-ray images via self-supervised learning. Nat. Biomed. Eng (2022). https://doi.org/10.1038/s41551-022-00936-9 154 | ``` 155 | 156 | ## License 157 | The source code for the site is licensed under the MIT license, which you can find in the `LICENSE` file. Also see `NOTICE.md` for attributions to third-party sources. 158 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import numpy as np 3 | import os 4 | import pandas as pd 5 | from PIL import Image 6 | import h5py 7 | import matplotlib.pyplot as plt 8 | from typing import List, Callable 9 | from collections import defaultdict 10 | 11 | import torch 12 | from torch.utils import data 13 | from tqdm.notebook import tqdm 14 | import torch.nn as nn 15 | from torchvision.transforms import Compose, Normalize, Resize 16 | 17 | import sklearn 18 | from sklearn.metrics import matthews_corrcoef, confusion_matrix, accuracy_score, auc, roc_auc_score, roc_curve, classification_report 19 | from sklearn.metrics import precision_recall_curve, f1_score 20 | from sklearn.metrics import average_precision_score 21 | from sklearn.utils import resample 22 | 23 | import scipy 24 | import scipy.stats 25 | 26 | import sys 27 | sys.path.append('../..') 28 | 29 | import clip 30 | from model import CLIP 31 | from eval import * 32 | from zero_shot import * 33 | 34 | def evaluate_model(X_dir, y_dir, model_path, cxr_labels, alt_labels_dict=None): 35 | cxr_filepath = X_dir 36 | final_label_path = y_dir 37 | 38 | results_out_folder = './results' 39 | context_length = 77 40 | 41 | # templates list of positive and negative template pairs 42 | cxr_pair_templates = [("{}", "no {}")] 43 | 44 | cxr_results, y_pred = run_zero_shot(cxr_labels, cxr_pair_templates, model_path, cxr_filepath=cxr_filepath, final_label_path=final_label_path, alt_labels_dict=alt_labels_dict, softmax_eval=True, context_length=context_length, pretrained=True, use_bootstrap=True, cutlabels=True) 45 | return cxr_results, y_pred 46 | 47 | def f1_mcc_bootstrap(y_pred, y_true, cxr_labels, best_p_vals, eval_func, n_samples=5000, label_idx_map=None): 48 | ''' 49 | This function will randomly sample with replacement 50 | from y_pred and y_true then evaluate `n` times 51 | and obtain AUROC scores for each. 
52 | 53 | You can specify the number of samples that should be 54 | used with the `n_samples` parameter. 55 | 56 | Confidence intervals will be generated from each 57 | of the samples. 58 | ''' 59 | y_pred # (500, 14) 60 | y_true # (500, 14) 61 | 62 | idx = np.arange(len(y_true)) 63 | 64 | boot_stats = [] 65 | for i in tqdm(range(n_samples)): 66 | sample = resample(idx, replace=True) 67 | y_pred_sample = y_pred[sample] 68 | y_true_sample = y_true[sample] 69 | 70 | sample_stats = eval_func(y_pred_sample, y_true_sample, best_p_vals, cxr_labels=cxr_labels, label_idx_map=label_idx_map) 71 | boot_stats.append(sample_stats) 72 | 73 | boot_stats = pd.concat(boot_stats) # pandas array of evaluations for each sample 74 | return boot_stats, compute_cis(boot_stats) 75 | 76 | def get_best_alt_labels(res_df, cxr_labels): 77 | best_alt_labels_dict = dict() 78 | best_alt_labels_vals = dict() 79 | res_cols = list(res_df) 80 | 81 | curr_path_name = None 82 | for col in res_cols: # for each col 83 | path_name = col.split("_")[0] # pathology name 84 | mean_auc = res_df[col][0] # mean auc 85 | 86 | if path_name in cxr_labels: 87 | # reset the vars 88 | curr_path_name = path_name 89 | best_alt_labels_dict[path_name] = [path_name] 90 | best_alt_labels_vals[path_name] = mean_auc 91 | 92 | if best_alt_labels_vals[curr_path_name] < mean_auc: 93 | best_alt_labels_vals[curr_path_name] = mean_auc 94 | best_alt_labels_dict[curr_path_name] = [path_name] 95 | 96 | return best_alt_labels_dict 97 | 98 | def y_true_csv_to_np(df_path, cxr_labels): 99 | groundtruth = pd.read_csv(df_path) 100 | groundtruth = groundtruth[cxr_labels] 101 | groundtruth = groundtruth.to_numpy()[:,:].astype(int) 102 | return groundtruth 103 | 104 | def get_best_p_vals(pred, groundtruth, cxr_labels, metric_func=matthews_corrcoef, spline_k: int = None, verbose: bool = False): 105 | """ 106 | WARNING: CXR_LABELS must 107 | Params: 108 | * pred : np arr 109 | probabilities output by model 110 | 111 | * plot_graphs : bool 112 | if True, will save plots for metric vs. threshold for 113 | each pathology 114 | 115 | Note: 116 | * `probabilities` value is a linspace of possible probabilities 117 | """ 118 | probabilities = [val for val in np.arange(0.4, 0.64, 0.0001)] 119 | best_p_vals = dict() 120 | for idx, cxr_label in enumerate(cxr_labels): 121 | y_true = groundtruth[:, idx] 122 | _, _, probabilities = roc_curve(y_true, pred[:, idx]) 123 | probabilities = probabilities[1:] 124 | probabilities.sort() 125 | 126 | metrics_list = [] 127 | for p in probabilities: 128 | y_pred = np.where(pred[:, idx] < p, 0, 1) 129 | metric = metric_func(y_true, y_pred) 130 | metrics_list.append(metric) 131 | 132 | if spline_k is not None: 133 | try: 134 | spl = UnivariateSpline(probabilities, metrics_list, k=spline_k) 135 | spl_y = spl(probabilities) 136 | # get optimal thresholds on the spline and on the val_metric_list 137 | best_index = np.argmax(spl_y) 138 | except: 139 | best_index = np.argmax(metrics_list) 140 | else: 141 | best_index = np.argmax(metrics_list) 142 | 143 | best_p = probabilities[best_index] 144 | best_metric = metrics_list[best_index] 145 | if verbose: 146 | print("Best metric for {} is {}. 
threshold = {}.".format(cxr_label, best_metric, best_p)) 147 | 148 | best_p_vals[cxr_label] = best_p 149 | return best_p_vals 150 | 151 | def compute_f1_mcc(X_test_dir, y_test_dir, X_val_dir, y_val_dir, model_path, alt_labels_dict : dict = None, find_best_alt: bool = True, thresh_func: Callable = matthews_corrcoef): 152 | """ 153 | Computes f1 and mcc scores given test dataset, validation dataset (to find 154 | thresholds) and path to the model. 155 | 156 | Params: 157 | * find_best_alt : bool 158 | If True, will filter alt_labels_dict to only the best alternative labels 159 | based on the validation dataset. Otherwise, will run on all alternative labels 160 | provided. 161 | """ 162 | 163 | 164 | 165 | # specify basic cxr labels 166 | cxr_labels = ['Atelectasis','Cardiomegaly', 167 | 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion', 168 | 'Lung Opacity', 'No Finding','Pleural Effusion', 'Pleural Other', 'Pneumonia', 169 | 'Pneumothorax', 'Support Devices'] 170 | 171 | # load in ground truth 172 | VAL_GROUNDTRUTH_PATH = "val_groundtruth.csv" 173 | GROUNDTRUTH_PATH = "groundtruth.csv" 174 | 175 | val_groundtruth = y_true_csv_to_np(VAL_GROUNDTRUTH_PATH, cxr_labels) 176 | groundtruth = y_true_csv_to_np(GROUNDTRUTH_PATH, cxr_labels) 177 | 178 | NUM_LABELS = 14 179 | 180 | # run evaluation on validation and test datasets 181 | # dir for validation datasets 182 | val_X = "/deep/group/data/med-data/valid.h5" 183 | val_y = "/deep/group/data/CheXpert-320x320/valid.csv" 184 | 185 | # dir for test datasets 186 | test_X = "/deep/group/data/med-data/test_cxr.h5" 187 | test_y = "/deep/group/data/med-data/final_paths.csv" 188 | 189 | if alt_labels_dict is not None and find_best_alt: 190 | # find best alternate labels 191 | val_res, val_pred = evaluate_model(val_X, val_y, best_model_path, alt_labels_dict=alt_labels_dict) 192 | # save alternative label results on validation dataset 193 | alt_val_res = val_res[0][1] 194 | best_alt_labels_dict = get_best_alt_labels(alt_val_res, cxr_labels) 195 | elif alt_labels_dict is not None: # find_best_alt == False 196 | best_alt_labels_dict = alt_labels_dict 197 | else: # no alternative labels 198 | best_alt_labels_dict = None 199 | 200 | # create alt_labels 201 | if best_alt_labels_dict is not None: 202 | alt_labels_list, alt_label_idx_map = process_alt_labels(best_alt_labels_dict, cxr_labels) 203 | else: 204 | alt_labels_list, alt_label_idx_map = cxr_labels, None 205 | 206 | # TODO: convert preds into binarized and make this one of the things that are returned 207 | val_res, val_pred = evaluate_model(val_X, val_y, model_path, cxr_labels, alt_labels_dict=best_alt_labels_dict) 208 | test_res, test_pred = evaluate_model(test_X, test_y, model_path, cxr_labels, alt_labels_dict=best_alt_labels_dict) 209 | 210 | # get best thresholds 211 | best_p_vals = get_best_p_vals(val_pred, val_groundtruth, alt_labels_list, alt_label_idx_map, metric_func=thresh_func) 212 | 213 | # f1 computation 214 | f1_cis = compute_f1(test_pred, groundtruth, alt_labels_list, best_p_vals, alt_label_idx_map) 215 | # mcc computation 216 | mcc_cis = compute_mcc(test_pred, groundtruth, alt_labels_list, best_p_vals, alt_label_idx_map) 217 | 218 | return f1_cis, mcc_cis 219 | 220 | def compute_f1(y_pred, y_true, cxr_labels, thresholds, label_idx_map=None): 221 | def get_f1_clip_bootstrap(y_pred, y_true, best_p_vals, cxr_labels=cxr_labels, label_idx_map=None): 222 | stats = {} 223 | probs = np.copy(y_pred) 224 | for idx, cxr_label in enumerate(cxr_labels): 225 | p = 
best_p_vals[cxr_label] 226 | probs[:,idx] = np.where(probs[:,idx] < p, 0, 1) 227 | clip_preds = np.copy(probs) 228 | for idx, cxr_label in enumerate(cxr_labels): 229 | 230 | if label_idx_map is None: 231 | curr_y_true = y_true[:, idx] 232 | else: 233 | curr_y_true = y_true[:, label_idx_map[cxr_label]] 234 | curr_y_pred = clip_preds[:, idx] 235 | 236 | m = confusion_matrix(curr_y_true, curr_y_pred) 237 | if len(m.ravel()) == 1: 238 | tn = 500 239 | fp = 0 240 | fn = 0 241 | tp = 0 242 | else: 243 | tn, fp, fn, tp = m.ravel() 244 | 245 | if ((2*tp + fp +fn) == 0): 246 | stats[cxr_label] = 1 247 | continue 248 | 249 | stats[cxr_label] = [(2 * tp) / (2*tp + fp +fn)] 250 | # compute mean over five major pathologies 251 | stats["Mean"] = compute_mean(stats, is_df=False) 252 | return pd.DataFrame.from_dict(stats) 253 | 254 | boot_stats, f1_cis = f1_mcc_bootstrap(y_pred, y_true, cxr_labels, thresholds, get_f1_clip_bootstrap, n_samples=1000, label_idx_map=label_idx_map) 255 | return f1_cis 256 | 257 | def compute_mcc(y_pred: np.array, y_true: np.array, cxr_labels: List, thresholds: dict, label_idx_map: dict = None): 258 | def get_mcc_bootstrap(y_pred, y_true, best_p_vals, cxr_labels=cxr_labels, label_idx_map=None): 259 | stats = {} 260 | probs = np.copy(y_pred) 261 | 262 | for idx, cxr_label in enumerate(cxr_labels): 263 | p = best_p_vals[cxr_label] 264 | probs[:,idx] = np.where(probs[:,idx] < p, 0, 1) 265 | 266 | clip_preds = np.copy(probs) 267 | 268 | for idx, cxr_label in enumerate(cxr_labels): 269 | if label_idx_map is None: 270 | curr_y_true = y_true[:, idx] 271 | else: 272 | curr_y_true = y_true[:, label_idx_map[cxr_label]] 273 | 274 | curr_y_pred = clip_preds[:, idx] 275 | stats[cxr_label] = [matthews_corrcoef(curr_y_true, curr_y_pred)] 276 | # compute mean over five major pathologies 277 | stats["Mean"] = compute_mean(stats, is_df=False) 278 | return pd.DataFrame.from_dict(stats) 279 | 280 | boot_stats, mcc_cis = f1_mcc_bootstrap(y_pred, y_true, cxr_labels, thresholds, get_mcc_bootstrap, n_samples=1000, label_idx_map=label_idx_map) 281 | return mcc_cis 282 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2021 OpenAI 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | 24 | """ 25 | from collections import OrderedDict 26 | from typing import Tuple, Union 27 | 28 | import numpy as np 29 | import torch 30 | import torch.nn.functional as F 31 | from torch import nn 32 | 33 | 34 | class Bottleneck(nn.Module): 35 | expansion = 4 36 | 37 | def __init__(self, inplanes, planes, stride=1): 38 | super().__init__() 39 | 40 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 41 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 42 | self.bn1 = nn.BatchNorm2d(planes) 43 | 44 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | 47 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 48 | 49 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 50 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 51 | 52 | self.relu = nn.ReLU(inplace=True) 53 | self.downsample = None 54 | self.stride = stride 55 | 56 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 57 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 58 | self.downsample = nn.Sequential(OrderedDict([ 59 | ("-1", nn.AvgPool2d(stride)), 60 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 61 | ("1", nn.BatchNorm2d(planes * self.expansion)) 62 | ])) 63 | 64 | def forward(self, x: torch.Tensor): 65 | identity = x 66 | 67 | out = self.relu(self.bn1(self.conv1(x))) 68 | out = self.relu(self.bn2(self.conv2(out))) 69 | out = self.avgpool(out) 70 | out = self.bn3(self.conv3(out)) 71 | 72 | if self.downsample is not None: 73 | identity = self.downsample(x) 74 | 75 | out += identity 76 | out = self.relu(out) 77 | return out 78 | 79 | 80 | class AttentionPool2d(nn.Module): 81 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 82 | super().__init__() 83 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 84 | self.k_proj = nn.Linear(embed_dim, embed_dim) 85 | self.q_proj = nn.Linear(embed_dim, embed_dim) 86 | self.v_proj = nn.Linear(embed_dim, embed_dim) 87 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 88 | self.num_heads = num_heads 89 | 90 | def forward(self, x): 91 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 92 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 93 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 94 | x, _ = F.multi_head_attention_forward( 95 | query=x, key=x, value=x, 96 | embed_dim_to_check=x.shape[-1], 97 | num_heads=self.num_heads, 98 | q_proj_weight=self.q_proj.weight, 99 | k_proj_weight=self.k_proj.weight, 100 | v_proj_weight=self.v_proj.weight, 101 | in_proj_weight=None, 102 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 103 | bias_k=None, 104 | bias_v=None, 105 | add_zero_attn=False, 106 | dropout_p=0, 107 | out_proj_weight=self.c_proj.weight, 108 | out_proj_bias=self.c_proj.bias, 109 | use_separate_proj_weight=True, 110 | training=self.training, 111 | need_weights=False 112 | ) 113 | 114 | return x[0] 115 | 116 | 117 | class ModifiedResNet(nn.Module): 118 | """ 119 | A ResNet class that is similar to torchvision's but contains the following changes: 120 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
121 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 122 | - The final pooling layer is a QKV attention instead of an average pool 123 | """ 124 | 125 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 126 | super().__init__() 127 | self.output_dim = output_dim 128 | self.input_resolution = input_resolution 129 | 130 | # the 3-layer stem 131 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 132 | self.bn1 = nn.BatchNorm2d(width // 2) 133 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 134 | self.bn2 = nn.BatchNorm2d(width // 2) 135 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 136 | self.bn3 = nn.BatchNorm2d(width) 137 | self.avgpool = nn.AvgPool2d(2) 138 | self.relu = nn.ReLU(inplace=True) 139 | 140 | # residual layers 141 | self._inplanes = width # this is a *mutable* variable used during construction 142 | self.layer1 = self._make_layer(width, layers[0]) 143 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 144 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 145 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 146 | 147 | embed_dim = width * 32 # the ResNet feature dimension 148 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 149 | 150 | def _make_layer(self, planes, blocks, stride=1): 151 | layers = [Bottleneck(self._inplanes, planes, stride)] 152 | 153 | self._inplanes = planes * Bottleneck.expansion 154 | for _ in range(1, blocks): 155 | layers.append(Bottleneck(self._inplanes, planes)) 156 | 157 | return nn.Sequential(*layers) 158 | 159 | def forward(self, x): 160 | def stem(x): 161 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 162 | x = self.relu(bn(conv(x))) 163 | x = self.avgpool(x) 164 | return x 165 | 166 | x = x.type(self.conv1.weight.dtype) 167 | x = stem(x) 168 | x = self.layer1(x) 169 | x = self.layer2(x) 170 | x = self.layer3(x) 171 | x = self.layer4(x) 172 | x = self.attnpool(x) 173 | 174 | return x 175 | 176 | 177 | class LayerNorm(nn.LayerNorm): 178 | """Subclass torch's LayerNorm to handle fp16.""" 179 | 180 | def forward(self, x: torch.Tensor): 181 | orig_type = x.dtype 182 | ret = super().forward(x.type(torch.float32)) 183 | return ret.type(orig_type) 184 | 185 | 186 | class QuickGELU(nn.Module): 187 | def forward(self, x: torch.Tensor): 188 | return x * torch.sigmoid(1.702 * x) 189 | 190 | 191 | class ResidualAttentionBlock(nn.Module): 192 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 193 | super().__init__() 194 | 195 | self.attn = nn.MultiheadAttention(d_model, n_head) 196 | self.ln_1 = LayerNorm(d_model) 197 | self.mlp = nn.Sequential(OrderedDict([ 198 | ("c_fc", nn.Linear(d_model, d_model * 4)), 199 | ("gelu", QuickGELU()), 200 | ("c_proj", nn.Linear(d_model * 4, d_model)) 201 | ])) 202 | self.ln_2 = LayerNorm(d_model) 203 | self.attn_mask = attn_mask 204 | 205 | def attention(self, x: torch.Tensor): 206 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 207 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 208 | 209 | def forward(self, x: torch.Tensor): 210 | x = x + self.attention(self.ln_1(x)) 211 | x = x + self.mlp(self.ln_2(x)) 212 | return x 213 | 214 | 215 | class Transformer(nn.Module): 216 | def 
__init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 217 | super().__init__() 218 | self.width = width 219 | self.layers = layers 220 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 221 | 222 | def forward(self, x: torch.Tensor): 223 | return self.resblocks(x) 224 | 225 | 226 | class VisualTransformer(nn.Module): 227 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 228 | super().__init__() 229 | self.input_resolution = input_resolution 230 | self.output_dim = output_dim 231 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 232 | 233 | scale = width ** -0.5 234 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 235 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 236 | self.ln_pre = LayerNorm(width) 237 | 238 | self.transformer = Transformer(width, layers, heads) 239 | 240 | self.ln_post = LayerNorm(width) 241 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 242 | 243 | def forward(self, x: torch.Tensor): 244 | x = self.conv1(x) # shape = [*, width, grid, grid] 245 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 246 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 247 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 248 | x = x + self.positional_embedding.to(x.dtype) 249 | x = self.ln_pre(x) 250 | 251 | x = x.permute(1, 0, 2) # NLD -> LND 252 | x = self.transformer(x) 253 | x = x.permute(1, 0, 2) # LND -> NLD 254 | 255 | x = self.ln_post(x[:, 0, :]) 256 | 257 | if self.proj is not None: 258 | x = x @ self.proj 259 | 260 | return x 261 | 262 | 263 | class CLIP(nn.Module): 264 | def __init__(self, 265 | embed_dim: int, 266 | # vision 267 | image_resolution: int, 268 | vision_layers: Union[Tuple[int, int, int, int], int], 269 | vision_width: int, 270 | vision_patch_size: int, 271 | # text 272 | context_length: int, 273 | vocab_size: int, 274 | transformer_width: int, 275 | transformer_heads: int, 276 | transformer_layers: int 277 | ): 278 | super().__init__() 279 | 280 | self.context_length = context_length 281 | 282 | if isinstance(vision_layers, (tuple, list)): 283 | vision_heads = vision_width * 32 // 64 284 | self.visual = ModifiedResNet( 285 | layers=vision_layers, 286 | output_dim=embed_dim, 287 | heads=vision_heads, 288 | input_resolution=image_resolution, 289 | width=vision_width 290 | ) 291 | else: 292 | vision_heads = vision_width // 64 293 | self.visual = VisualTransformer( 294 | input_resolution=image_resolution, 295 | patch_size=vision_patch_size, 296 | width=vision_width, 297 | layers=vision_layers, 298 | heads=vision_heads, 299 | output_dim=embed_dim 300 | ) 301 | 302 | self.transformer = Transformer( 303 | width=transformer_width, 304 | layers=transformer_layers, 305 | heads=transformer_heads, 306 | attn_mask=self.build_attention_mask() 307 | ) 308 | 309 | self.vocab_size = vocab_size 310 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 311 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 312 | self.ln_final = LayerNorm(transformer_width) 313 | 314 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 315 | 
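# Note: the learnable temperature parameter below is initialised to log(1/0.07),
# i.e. the 0.07 temperature used in the original CLIP recipe; forward() exponentiates
# it before scaling the image-text cosine similarities into logits.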
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 316 | 317 | self.initialize_parameters() 318 | 319 | def initialize_parameters(self): 320 | nn.init.normal_(self.token_embedding.weight, std=0.02) 321 | nn.init.normal_(self.positional_embedding, std=0.01) 322 | 323 | if isinstance(self.visual, ModifiedResNet): 324 | if self.visual.attnpool is not None: 325 | std = self.visual.attnpool.c_proj.in_features ** -0.5 326 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 327 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 328 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 329 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 330 | 331 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 332 | for name, param in resnet_block.named_parameters(): 333 | if name.endswith("bn3.weight"): 334 | nn.init.zeros_(param) 335 | 336 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 337 | attn_std = self.transformer.width ** -0.5 338 | fc_std = (2 * self.transformer.width) ** -0.5 339 | for block in self.transformer.resblocks: 340 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 341 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 342 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 343 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 344 | 345 | if self.text_projection is not None: 346 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 347 | 348 | def build_attention_mask(self): 349 | # lazily create causal attention mask, with full attention between the vision tokens 350 | # pytorch uses additive attention mask; fill with -inf 351 | mask = torch.empty(self.context_length, self.context_length) 352 | mask.fill_(float("-inf")) 353 | mask.triu_(1) # zero out the lower diagonal 354 | return mask 355 | 356 | @property 357 | def dtype(self): 358 | return self.visual.conv1.weight.dtype 359 | 360 | def encode_image(self, image): 361 | return self.visual(image.type(self.dtype)) 362 | 363 | def encode_text(self, text): 364 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 365 | 366 | x = x + self.positional_embedding.type(self.dtype) 367 | x = x.permute(1, 0, 2) # NLD -> LND 368 | x = self.transformer(x) 369 | x = x.permute(1, 0, 2) # LND -> NLD 370 | x = self.ln_final(x).type(self.dtype) 371 | 372 | # x.shape = [batch_size, n_ctx, transformer.width] 373 | # take features from the eot embedding (eot_token is the highest number in each sequence) 374 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 375 | 376 | return x 377 | 378 | def forward(self, image, text): 379 | image_features = self.encode_image(image) 380 | text_features = self.encode_text(text) 381 | 382 | # normalized features 383 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 384 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 385 | 386 | # cosine similarity as logits 387 | logit_scale = self.logit_scale.exp() 388 | logits_per_image = logit_scale * image_features @ text_features.t() 389 | logits_per_text = logit_scale * text_features @ image_features.t() 390 | 391 | # shape = [global_batch_size, global_batch_size] 392 | return logits_per_image, logits_per_text 393 | 394 | 395 | def convert_weights(model: nn.Module): 396 | """Convert applicable model parameters to fp16""" 397 | 398 | def _convert_weights_to_fp16(l): 399 | if 
isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 400 | l.weight.data = l.weight.data.half() 401 | if l.bias is not None: 402 | l.bias.data = l.bias.data.half() 403 | 404 | if isinstance(l, nn.MultiheadAttention): 405 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 406 | tensor = getattr(l, attr) 407 | if tensor is not None: 408 | tensor.data = tensor.data.half() 409 | 410 | for name in ["text_projection", "proj"]: 411 | if hasattr(l, name): 412 | attr = getattr(l, name) 413 | if attr is not None: 414 | attr.data = attr.data.half() 415 | 416 | model.apply(_convert_weights_to_fp16) 417 | 418 | 419 | def build_model(state_dict: dict): 420 | vit = "visual.proj" in state_dict 421 | 422 | if vit: 423 | vision_width = state_dict["visual.conv1.weight"].shape[0] 424 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 425 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 426 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 427 | image_resolution = vision_patch_size * grid_size 428 | else: 429 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 430 | vision_layers = tuple(counts) 431 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 432 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 433 | vision_patch_size = None 434 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 435 | image_resolution = output_width * 32 436 | 437 | embed_dim = state_dict["text_projection"].shape[1] 438 | context_length = state_dict["positional_embedding"].shape[0] 439 | vocab_size = state_dict["token_embedding.weight"].shape[0] 440 | transformer_width = state_dict["ln_final.weight"].shape[0] 441 | transformer_heads = transformer_width // 64 442 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 443 | 444 | model = CLIP( 445 | embed_dim, 446 | image_resolution, vision_layers, vision_width, vision_patch_size, 447 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 448 | ) 449 | 450 | for key in ["input_resolution", "context_length", "vocab_size"]: 451 | if key in state_dict: 452 | del state_dict[key] 453 | 454 | convert_weights(model) 455 | model.load_state_dict(state_dict) 456 | return model.eval() 457 | -------------------------------------------------------------------------------- /zero_shot.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import numpy as np 3 | import os 4 | import sys 5 | import pandas as pd 6 | from PIL import Image 7 | import h5py 8 | import matplotlib.pyplot as plt 9 | from typing import List, Tuple 10 | 11 | import torch 12 | from torch.utils import data 13 | from tqdm.notebook import tqdm 14 | import torch.nn as nn 15 | from torchvision.transforms import Compose, Normalize, Resize, InterpolationMode 16 | 17 | import sklearn 18 | from sklearn.metrics import confusion_matrix, accuracy_score, auc, roc_auc_score, roc_curve, classification_report 19 | from sklearn.metrics import precision_recall_curve, f1_score 20 | from sklearn.metrics import average_precision_score 21 | 22 | import clip 23 | from model import CLIP 24 | from eval import evaluate, plot_roc, accuracy, sigmoid, bootstrap, compute_cis 
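from pathlib import Path  # needed by ensemble_models() below (Path(path).stem and the cache_path construction); missing from the original import list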
25 | 26 | CXR_FILEPATH = '../../project-files/data/test_cxr.h5' 27 | FINAL_LABEL_PATH = '../../project-files/data/final_paths.csv' 28 | 29 | class CXRTestDataset(data.Dataset): 30 | """Represents an abstract HDF5 dataset. 31 | 32 | Input params: 33 | img_path: Path to hdf5 file containing images. 34 | label_path: Path to file containing labels 35 | transform: PyTorch transform to apply to every data instance (default=None). 36 | """ 37 | def __init__( 38 | self, 39 | img_path: str, 40 | transform = None, 41 | ): 42 | super().__init__() 43 | self.img_dset = h5py.File(img_path, 'r')['cxr'] 44 | self.transform = transform 45 | 46 | def __len__(self): 47 | return len(self.img_dset) 48 | 49 | def __getitem__(self, idx): 50 | if torch.is_tensor(idx): 51 | idx = idx.tolist() 52 | 53 | img = self.img_dset[idx] # np array, (320, 320) 54 | img = np.expand_dims(img, axis=0) 55 | img = np.repeat(img, 3, axis=0) 56 | img = torch.from_numpy(img) # torch, (320, 320) 57 | 58 | if self.transform: 59 | img = self.transform(img) 60 | 61 | sample = {'img': img} 62 | 63 | return sample 64 | 65 | def load_clip(model_path, pretrained=False, context_length=77): 66 | """ 67 | FUNCTION: load_clip 68 | --------------------------------- 69 | """ 70 | device = torch.device("cpu") 71 | if pretrained is False: 72 | # use new model params 73 | params = { 74 | 'embed_dim':768, 75 | 'image_resolution': 320, 76 | 'vision_layers': 12, 77 | 'vision_width': 768, 78 | 'vision_patch_size': 16, 79 | 'context_length': context_length, 80 | 'vocab_size': 49408, 81 | 'transformer_width': 512, 82 | 'transformer_heads': 8, 83 | 'transformer_layers': 12 84 | } 85 | 86 | model = CLIP(**params) 87 | else: 88 | model, preprocess = clip.load("ViT-B/32", device=device, jit=False) 89 | try: 90 | model.load_state_dict(torch.load(model_path, map_location=device)) 91 | except: 92 | print("Argument error. Set pretrained = True.", sys.exc_info()[0]) 93 | raise 94 | return model 95 | 96 | def zeroshot_classifier(classnames, templates, model, context_length=77): 97 | """ 98 | FUNCTION: zeroshot_classifier 99 | ------------------------------------- 100 | This function outputs the weights for each of the classes based on the 101 | output of the trained clip model text transformer. 102 | 103 | args: 104 | * classnames - Python list of classes for a specific zero-shot task. (i.e. ['Atelectasis',...]). 105 | * templates - Python list of phrases that will be indpendently tested as input to the clip model. 106 | * model - Pytorch model, full trained clip model. 107 | * context_length (optional) - int, max number of tokens of text inputted into the model. 108 | 109 | Returns PyTorch Tensor, output of the text encoder given templates. 
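Example (illustrative, mirroring how run_single_prediction calls this function):
    zeroshot_weights = zeroshot_classifier(['Atelectasis', 'Edema'], ['{}'], model)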
110 | """ 111 | with torch.no_grad(): 112 | zeroshot_weights = [] 113 | # compute embedding through model for each class 114 | for classname in tqdm(classnames): 115 | texts = [template.format(classname) for template in templates] # format with class 116 | texts = clip.tokenize(texts, context_length=context_length) # tokenize 117 | class_embeddings = model.encode_text(texts) # embed with text encoder 118 | 119 | # normalize class_embeddings 120 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) 121 | # average over templates 122 | class_embedding = class_embeddings.mean(dim=0) 123 | # norm over new averaged templates 124 | class_embedding /= class_embedding.norm() 125 | zeroshot_weights.append(class_embedding) 126 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1) 127 | return zeroshot_weights 128 | 129 | def predict(loader, model, zeroshot_weights, softmax_eval=True, verbose=0): 130 | """ 131 | FUNCTION: predict 132 | --------------------------------- 133 | This function runs the cxr images through the model 134 | and computes the cosine similarities between the images 135 | and the text embeddings. 136 | 137 | args: 138 | * loader - PyTorch data loader, loads in cxr images 139 | * model - PyTorch model, trained clip model 140 | * zeroshot_weights - PyTorch Tensor, outputs of text encoder for labels 141 | * softmax_eval (optional) - Use +/- softmax method for evaluation 142 | * verbose (optional) - bool, If True, will print out intermediate tensor values for debugging. 143 | 144 | Returns numpy array, predictions on all test data samples. 145 | """ 146 | y_pred = [] 147 | with torch.no_grad(): 148 | for i, data in enumerate(tqdm(loader)): 149 | images = data['img'] 150 | 151 | # predict 152 | image_features = model.encode_image(images) 153 | image_features /= image_features.norm(dim=-1, keepdim=True) # (1, 768) 154 | 155 | # obtain logits 156 | logits = image_features @ zeroshot_weights # (1, num_classes) 157 | logits = np.squeeze(logits.numpy(), axis=0) # (num_classes,) 158 | 159 | if softmax_eval is False: 160 | norm_logits = (logits - logits.mean()) / (logits.std()) 161 | logits = sigmoid(norm_logits) 162 | 163 | y_pred.append(logits) 164 | 165 | if verbose: 166 | plt.imshow(images[0][0]) 167 | plt.show() 168 | print('images: ', images) 169 | print('images size: ', images.size()) 170 | 171 | print('image_features size: ', image_features.size()) 172 | print('logits: ', logits) 173 | print('logits size: ', logits.size()) 174 | 175 | y_pred = np.array(y_pred) 176 | return np.array(y_pred) 177 | 178 | def run_single_prediction(cxr_labels, template, model, loader, softmax_eval=True, context_length=77): 179 | """ 180 | FUNCTION: run_single_prediction 181 | -------------------------------------- 182 | This function will make probability predictions for a single template 183 | (i.e. "has {}"). 184 | 185 | args: 186 | * cxr_labels - list, labels for a specific zero-shot task. (i.e. ['Atelectasis',...]) 187 | * template - string, template to input into model. 188 | * model - PyTorch model, trained clip model 189 | * loader - PyTorch data loader, loads in cxr images 190 | * softmax_eval (optional) - Use +/- softmax method for evaluation 191 | * context_length (optional) - int, max number of tokens of text inputted into the model. 192 | 193 | Returns list, predictions from the given template. 
194 | """ 195 | cxr_phrase = [template] 196 | zeroshot_weights = zeroshot_classifier(cxr_labels, cxr_phrase, model, context_length=context_length) 197 | y_pred = predict(loader, model, zeroshot_weights, softmax_eval=softmax_eval) 198 | return y_pred 199 | 200 | def process_alt_labels(alt_labels_dict, cxr_labels): 201 | """ 202 | Process alt labels and return relevant info. If `alt_labels_dict` is 203 | None, return None. 204 | 205 | Returns: 206 | * alt_label_list : list 207 | List of all alternative labels 208 | * alt_label_idx_map : dict 209 | Maps alt label to idx of original label in cxr_labels 210 | Needed to access correct column during evaluation 211 | 212 | """ 213 | 214 | if alt_labels_dict is None: 215 | return None, None 216 | 217 | def get_inverse_labels(labels_alt_map: dict): 218 | """ 219 | Returns dict mapping alternative label back to actual label. 220 | Used for reference during evaluation. 221 | """ 222 | inverse_labels_dict = {} 223 | for main in labels_alt_map: 224 | inverse_labels_dict[main] = main # adds self to list of alt labels 225 | for alt in labels_alt_map[main]: 226 | inverse_labels_dict[alt] = main 227 | return inverse_labels_dict 228 | 229 | inv_labels_dict = get_inverse_labels(alt_labels_dict) 230 | alt_label_list = [w for w in inv_labels_dict.keys()] 231 | 232 | # create index map 233 | index_map = dict() 234 | for i, label in enumerate(cxr_labels): 235 | index_map[label] = i 236 | 237 | # make map to go from alt label directly to index 238 | alt_label_idx_map = dict() 239 | for alt_label in alt_label_list: 240 | alt_label_idx_map[alt_label] = index_map[inv_labels_dict[alt_label]] 241 | 242 | return alt_label_list, alt_label_idx_map 243 | 244 | def run_softmax_eval(model, loader, eval_labels: list, pair_template: tuple, context_length: int = 77): 245 | """ 246 | Run softmax evaluation to obtain a single prediction from the model. 247 | """ 248 | # get pos and neg phrases 249 | pos = pair_template[0] 250 | neg = pair_template[1] 251 | 252 | # get pos and neg predictions, (num_samples, num_classes) 253 | pos_pred = run_single_prediction(eval_labels, pos, model, loader, 254 | softmax_eval=True, context_length=context_length) 255 | neg_pred = run_single_prediction(eval_labels, neg, model, loader, 256 | softmax_eval=True, context_length=context_length) 257 | 258 | # compute probabilities with softmax 259 | sum_pred = np.exp(pos_pred) + np.exp(neg_pred) 260 | y_pred = np.exp(pos_pred) / sum_pred 261 | return y_pred 262 | 263 | def run_experiment(model, cxr_labels, cxr_templates, loader, y_true, alt_labels_dict=None, softmax_eval=True, context_length=77, use_bootstrap=True): 264 | ''' 265 | FUNCTION: run_experiment 266 | ---------------------------------------- 267 | This function runs the zeroshot experiment on each of the templates 268 | individually, and stores the results in a list. 269 | 270 | args: 271 | * model - PyTorch model, trained clip model 272 | * cxr_labels - list, labels for a specific zero-shot task. (i.e. ['Atelectasis',...]) 273 | * cxr_templates - list, templates to input into model. If softmax_eval is True, 274 | this should be a list of tuples, where each tuple is a +/- pair 275 | * loader - PyTorch data loader, loads in cxr images 276 | * y_true - list, ground truth labels for test dataset 277 | * softmax_eval (optional) - bool, if True, will evaluate results through softmax of pos vs. neg samples. 278 | * context_length - int, max number of tokens of text inputted into the model. 
279 | * use_bootstrap (optional) - bool, whether or not to use bootstrap sampling 280 | 281 | Returns a list of results from the experiment. 282 | ''' 283 | 284 | alt_label_list, alt_label_idx_map = process_alt_labels(alt_labels_dict, cxr_labels) 285 | if alt_label_list is not None: 286 | eval_labels = alt_label_list 287 | else: 288 | eval_labels = cxr_labels 289 | 290 | results = [] 291 | for template in cxr_templates: 292 | print('Phrase being used: ', template) 293 | 294 | try: 295 | if softmax_eval: 296 | y_pred = run_softmax_eval(model, loader, eval_labels, template, context_length=context_length) 297 | 298 | else: 299 | # get single prediction 300 | y_pred = run_single_prediction(eval_labels, template, model, loader, 301 | softmax_eval=softmax_eval, context_length=context_length) 302 | # print("y_pred: ", y_pred) 303 | except: 304 | print("Argument error. Make sure cxr_templates is proper format.", sys.exc_info()[0]) 305 | raise 306 | 307 | # evaluate 308 | if use_bootstrap: 309 | # compute bootstrap stats 310 | boot_stats = bootstrap(y_pred, y_true, eval_labels, label_idx_map=alt_label_idx_map) 311 | results.append(boot_stats) # each template has a pandas array of samples 312 | else: 313 | stats = evaluate(y_pred, y_true, eval_labels) 314 | results.append(stats) 315 | 316 | return results, y_pred 317 | 318 | def make_true_labels( 319 | cxr_true_labels_path: str, 320 | cxr_labels: List[str], 321 | cutlabels: bool = True 322 | ): 323 | """ 324 | Loads in data containing the true binary labels 325 | for each pathology in `cxr_labels` for all samples. This 326 | is used for evaluation of model performance. 327 | 328 | args: 329 | * cxr_true_labels_path - str, path to csv containing ground truth labels 330 | * cxr_labels - List[str], subset of label columns to select from ground truth df 331 | * cutlabels - bool, if True, will keep columns of ground truth labels that correspond 332 | with the labels inputted through `cxr_labels`. Otherwise, drop the first column and keep remaining. 333 | 334 | Returns a numpy array of shape (# samples, # labels/pathologies) 335 | representing the binary ground truth labels for each pathology on each sample. 336 | """ 337 | # create ground truth labels 338 | full_labels = pd.read_csv(cxr_true_labels_path) 339 | if cutlabels: 340 | full_labels = full_labels.loc[:, cxr_labels] 341 | else: 342 | full_labels.drop(full_labels.columns[0], axis=1, inplace=True) 343 | 344 | y_true = full_labels.to_numpy() 345 | return y_true 346 | 347 | def make( 348 | model_path: str, 349 | cxr_filepath: str, 350 | pretrained: bool = True, 351 | context_length: bool = 77, 352 | ): 353 | """ 354 | FUNCTION: make 355 | ------------------------------------------- 356 | This function makes the model, the data loader, and the ground truth labels. 357 | 358 | args: 359 | * model_path - String for directory to the weights of the trained clip model. 360 | * context_length - int, max number of tokens of text inputted into the model. 361 | * cxr_filepath - String for path to the chest x-ray images. 362 | * cxr_labels - Python list of labels for a specific zero-shot task. (i.e. ['Atelectasis',...]) 363 | * pretrained - bool, whether or not model uses pretrained clip weights 364 | * cutlabels - bool, if True, will keep columns of ground truth labels that correspond 365 | with the labels inputted through `cxr_labels`. Otherwise, drop the first column and keep remaining. 366 | 367 | Returns model, data loader. 
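Example (illustrative; paths are placeholders):
    model, loader = make(
        model_path='checkpoints/chexzero_weights/best_64_5e-05_original_22000_0.864.pt',
        cxr_filepath='data/chexpert_test.h5',
    )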
368 | """ 369 | # load model 370 | model = load_clip( 371 | model_path=model_path, 372 | pretrained=pretrained, 373 | context_length=context_length 374 | ) 375 | 376 | # load data 377 | transformations = [ 378 | # means computed from sample in `cxr_stats` notebook 379 | Normalize((101.48761, 101.48761, 101.48761), (83.43944, 83.43944, 83.43944)), 380 | ] 381 | # if using CLIP pretrained model 382 | if pretrained: 383 | # resize to input resolution of pretrained clip model 384 | input_resolution = 224 385 | transformations.append(Resize(input_resolution, interpolation=InterpolationMode.BICUBIC)) 386 | transform = Compose(transformations) 387 | 388 | # create dataset 389 | torch_dset = CXRTestDataset( 390 | img_path=cxr_filepath, 391 | transform=transform, 392 | ) 393 | loader = torch.utils.data.DataLoader(torch_dset, shuffle=False) 394 | 395 | return model, loader 396 | 397 | ## Run the model on the data set using ensembled models 398 | def ensemble_models( 399 | model_paths: List[str], 400 | cxr_filepath: str, 401 | cxr_labels: List[str], 402 | cxr_pair_template: Tuple[str], 403 | cache_dir: str = None, 404 | save_name: str = None, 405 | ) -> Tuple[List[np.ndarray], np.ndarray]: 406 | """ 407 | Given a list of `model_paths`, ensemble model and return 408 | predictions. Caches predictions at `cache_dir` if location provided. 409 | 410 | Returns a list of each model's predictions and the averaged 411 | set of predictions. 412 | """ 413 | 414 | predictions = [] 415 | model_paths = sorted(model_paths) # ensure consistency of 416 | for path in model_paths: # for each model 417 | model_name = Path(path).stem 418 | 419 | # load in model and `torch.DataLoader` 420 | model, loader = make( 421 | model_path=path, 422 | cxr_filepath=cxr_filepath, 423 | ) 424 | 425 | # path to the cached prediction 426 | if cache_dir is not None: 427 | if save_name is not None: 428 | cache_path = Path(cache_dir) / f"{save_name}_{model_name}.npy" 429 | else: 430 | cache_path = Path(cache_dir) / f"{model_name}.npy" 431 | 432 | # if prediction already cached, don't recompute prediction 433 | if cache_dir is not None and os.path.exists(cache_path): 434 | print("Loading cached prediction for {}".format(model_name)) 435 | y_pred = np.load(cache_path) 436 | else: # cached prediction not found, compute preds 437 | print("Inferring model {}".format(path)) 438 | y_pred = run_softmax_eval(model, loader, cxr_labels, cxr_pair_template) 439 | if cache_dir is not None: 440 | Path(cache_dir).mkdir(exist_ok=True, parents=True) 441 | np.save(file=cache_path, arr=y_pred) 442 | predictions.append(y_pred) 443 | 444 | # compute average predictions 445 | y_pred_avg = np.mean(predictions, axis=0) 446 | 447 | return predictions, y_pred_avg 448 | 449 | def run_zero_shot(cxr_labels, cxr_templates, model_path, cxr_filepath, final_label_path, alt_labels_dict: dict = None, softmax_eval = True, context_length=77, pretrained: bool = False, use_bootstrap=True, cutlabels=True): 450 | """ 451 | FUNCTION: run_zero_shot 452 | -------------------------------------- 453 | This function is the main function to run the zero-shot pipeline given a dataset, 454 | labels, templates for those labels, ground truth labels, and config parameters. 455 | 456 | args: 457 | * cxr_labels - list 458 | labels for a specific zero-shot task. (i.e. ['Atelectasis',...]) 459 | task can either be a string or a tuple (name of alternative label, name of label in csv) 460 | * cxr_templates - list, phrases that will be indpendently tested as input to the clip model. 
If `softmax_eval` is True, this parameter should be a 461 | list of positive and negative template pairs stored as tuples. 462 | * model_path - String for directory to the weights of the trained clip model. 463 | * cxr_filepath - String for path to the chest x-ray images. 464 | * final_label_path - String for path to ground truth labels. 465 | 466 | * alt_labels_dict (optional) - dict, map cxr_labels to list of alternative labels (i.e. 'Atelectasis': ['lung collapse', 'atelectatic lung', ...]) 467 | * softmax_eval (optional) - bool, if True, will evaluate results through softmax of pos vs. neg samples. 468 | * context_length (optional) - int, max number of tokens of text inputted into the model. 469 | * pretrained (optional) - bool, whether or not model uses pretrained clip weights 470 | * use_bootstrap (optional) - bool, whether or not to use bootstrap sampling 471 | * cutlabels (optional) - bool, if True, will keep columns of ground truth labels that correspond 472 | with the labels inputted through `cxr_labels`. Otherwise, drop the first column and keep remaining. 473 | 474 | Returns an array of results per template, each consists of a tuple containing a pandas dataframes 475 | for n bootstrap samples, and another pandas dataframe with the confidence intervals for each class. 476 | """ 477 | 478 | np.random.seed(97) 479 | # make the model, data loader, and ground truth labels 480 | model, loader = make( 481 | model_path=model_path, 482 | cxr_filepath=cxr_filepath, 483 | pretrained=pretrained, 484 | context_length=context_length 485 | ) 486 | 487 | y_true = make_true_labels( 488 | cxr_true_labels_path=final_label_path, 489 | cxr_labels=cxr_labels, 490 | cutlabels=cutlabels, 491 | ) 492 | 493 | # run multiphrase experiment 494 | results, y_pred = run_experiment(model, cxr_labels, cxr_templates, loader, y_true, 495 | alt_labels_dict=alt_labels_dict, softmax_eval=softmax_eval, context_length=context_length, use_bootstrap=use_bootstrap) 496 | return results, y_pred 497 | 498 | def run_cxr_zero_shot(model_path, context_length=77, pretrained=False): 499 | """ 500 | FUNCTION: run_cxr_zero_shot 501 | -------------------------------------- 502 | This function runs zero-shot specifically for the cxr dataset. 503 | The only difference between this function and `run_zero_shot` is that 504 | this function is already pre-parameterized for the 14 cxr labels evaluated 505 | using softmax method of positive and negative templates. 506 | 507 | args: 508 | * model_path - string, filepath of model being evaluated 509 | * context_length (optional) - int, max number of tokens of text inputted into the model. 510 | * pretrained (optional) - bool, whether or not model uses pretrained clip weights 511 | * use_bootstrap (optional) - bool, whether or not to use bootstrap sampling 512 | 513 | Returns an array of labels, and an array of results per template, 514 | each consists of a tuple containing a pandas dataframes 515 | for n bootstrap samples, and another pandas dataframe with the confidence intervals for each class. 
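Example (illustrative; the checkpoint path is a placeholder):
    cxr_labels, cxr_results = run_cxr_zero_shot(
        'checkpoints/chexzero_weights/best_64_5e-05_original_22000_0.864.pt',
        pretrained=True,
    )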
516 | """ 517 | cxr_filepath = '/deep/group/data/med-data/test_cxr.h5' 518 | final_label_path = '/deep/group/data/med-data/final_paths.csv' 519 | 520 | cxr_labels = ['Atelectasis','Cardiomegaly', 521 | 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion', 522 | 'Lung Opacity', 'No Finding','Pleural Effusion', 'Pleural Other', 'Pneumonia', 523 | 'Pneumothorax', 'Support Devices'] 524 | 525 | # templates list of positive and negative template pairs 526 | cxr_templates = [("{}", "no {}")] 527 | 528 | cxr_results = run_zero_shot(cxr_labels, cxr_templates, model_path, cxr_filepath=cxr_filepath, final_label_path=final_label_path, softmax_eval=True, context_length=context_length, pretrained=pretrained, use_bootstrap=False, cutlabels=True) 529 | 530 | return cxr_labels, cxr_results[0] 531 | 532 | def validation_zero_shot(model_path, context_length=77, pretrained=False): 533 | """ 534 | FUNCTION: validation_zero_shot 535 | -------------------------------------- 536 | This function uses the CheXpert validation dataset to make predictions 537 | on an alternative task (ap/pa, sex) in order to tune hyperparameters. 538 | 539 | args: 540 | * model_path - string, filepath of model being evaluated 541 | * context_length (optional) - int, max number of tokens of text inputted into the model. 542 | * pretrained (optional) - bool, whether or not model uses pretrained clip weights 543 | * use_bootstrap (optional) - bool, whether or not to use bootstrap sampling 544 | 545 | Returns an array of labels, and an array of results per template, 546 | each consists of a tuple containing a pandas dataframes 547 | for n bootstrap samples, and another pandas dataframe with the confidence intervals for each class. 548 | """ 549 | cxr_sex_labels = ['Female', 'Male'] 550 | 551 | cxr_sex_templates = [ 552 | #'{}', 553 | # 'the patient is a {}', 554 | "the patient's sex is {}", 555 | ] 556 | 557 | # run zero shot experiment 558 | sex_labels_path = '../../data/val_sex_labels.csv' 559 | results = run_zero_shot(cxr_sex_labels, cxr_sex_templates, model_path, cxr_filepath=cxr_filepath, final_label_path=sex_labels_path, softmax_eval=False, context_length=context_length, pretrained=True, use_bootstrap=True, cutlabels=False) 560 | 561 | results = run_zero_shot(cxr_sex_labels, cxr_sex_templates, model_path, cxr_filepath=cxr_filepath, final_label_path=sex_labels_path, softmax_eval=False, context_length=context_length, pretrained=True, use_bootstrap=True, cutlabels=False) 562 | pass 563 | 564 | 565 | 566 | 567 | 568 | 569 | 570 | 571 | -------------------------------------------------------------------------------- /notebooks/zero_shot.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Sample Notebook for Zero-Shot Inference with CheXzero\n", 8 | "This notebook walks through how to use CheXzero to perform zero-shot inference on a chest x-ray image dataset." 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "metadata": {}, 14 | "source": [ 15 | "## Import Libraries" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 16, 21 | "metadata": {}, 22 | "outputs": [ 23 | { 24 | "name": "stdout", 25 | "output_type": "stream", 26 | "text": [ 27 | "The autoreload extension is already loaded. 
To reload it, use:\n", 28 | " %reload_ext autoreload\n" 29 | ] 30 | } 31 | ], 32 | "source": [ 33 | "import os\n", 34 | "import numpy as np\n", 35 | "import pandas as pd\n", 36 | "from pathlib import Path\n", 37 | "from typing import List, Tuple, Optional\n", 38 | "\n", 39 | "import sys\n", 40 | "sys.path.append('../')\n", 41 | "\n", 42 | "from eval import evaluate, bootstrap\n", 43 | "from zero_shot import make, make_true_labels, run_softmax_eval\n", 44 | "\n", 45 | "%load_ext autoreload\n", 46 | "%autoreload 2" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "## Directories and Constants" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 17, 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "name": "stdout", 63 | "output_type": "stream", 64 | "text": [ 65 | "['../checkpoints/chexzero_weights/best_64_0.0001_original_17000_0.863.pt', '../checkpoints/chexzero_weights/best_128_5e-05_original_22000_0.855.pt', '../checkpoints/chexzero_weights/best_64_0.0001_original_35000_0.864.pt', '../checkpoints/chexzero_weights/best_64_5e-05_original_18000_0.862.pt', '../checkpoints/chexzero_weights/best_128_0.0002_original_8000_0.857.pt', '../checkpoints/chexzero_weights/best_64_5e-05_original_22000_0.864.pt', '../checkpoints/chexzero_weights/best_64_5e-05_original_16000_0.858.pt', '../checkpoints/chexzero_weights/best_128_0.0002_original_15000_0.859.pt', '../checkpoints/chexzero_weights/best_64_0.0002_original_23000_0.854.pt', '../checkpoints/chexzero_weights/best_64_0.0001_original_16000_0.861.pt']\n" 66 | ] 67 | } 68 | ], 69 | "source": [ 70 | "## Define Zero Shot Labels and Templates\n", 71 | "\n", 72 | "# ----- DIRECTORIES ------ #\n", 73 | "cxr_filepath: str = '../data/chexpert_test.h5' # filepath of chest x-ray images (.h5)\n", 74 | "cxr_true_labels_path: Optional[str] = '../data/groundtruth.csv' # (optional for evaluation) if labels are provided, provide path\n", 75 | "model_dir: str = '../checkpoints/chexzero_weights' # where pretrained models are saved (.pt) \n", 76 | "predictions_dir: Path = Path('../predictions') # where to save predictions\n", 77 | "cache_dir: str = predictions_dir / \"cached\" # where to cache ensembled predictions\n", 78 | "\n", 79 | "context_length: int = 77\n", 80 | "\n", 81 | "# ------- LABELS ------ #\n", 82 | "# Define labels to query each image | will return a prediction for each label\n", 83 | "cxr_labels: List[str] = ['Atelectasis','Cardiomegaly', \n", 84 | " 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion',\n", 85 | " 'Lung Opacity', 'No Finding','Pleural Effusion', 'Pleural Other', 'Pneumonia', \n", 86 | " 'Pneumothorax', 'Support Devices']\n", 87 | "\n", 88 | "# ---- TEMPLATES ----- # \n", 89 | "# Define set of templates | see Figure 1 for more details \n", 90 | "cxr_pair_template: Tuple[str] = (\"{}\", \"no {}\")\n", 91 | "\n", 92 | "# ----- MODEL PATHS ------ #\n", 93 | "# If using ensemble, collect all model paths\n", 94 | "model_paths = []\n", 95 | "for subdir, dirs, files in os.walk(model_dir):\n", 96 | " for file in files:\n", 97 | " full_dir = os.path.join(subdir, file)\n", 98 | " model_paths.append(full_dir)\n", 99 | " \n", 100 | "print(model_paths)" 101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "metadata": {}, 106 | "source": [ 107 | "## Run Inference" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 19, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "## Run the model on the data set using 
ensembled models\n", 117 | "def ensemble_models(\n", 118 | " model_paths: List[str], \n", 119 | " cxr_filepath: str, \n", 120 | " cxr_labels: List[str], \n", 121 | " cxr_pair_template: Tuple[str], \n", 122 | " cache_dir: str = None, \n", 123 | " save_name: str = None,\n", 124 | ") -> Tuple[List[np.ndarray], np.ndarray]: \n", 125 | " \"\"\"\n", 126 | " Given a list of `model_paths`, ensemble model and return\n", 127 | " predictions. Caches predictions at `cache_dir` if location provided.\n", 128 | "\n", 129 | " Returns a list of each model's predictions and the averaged\n", 130 | " set of predictions.\n", 131 | " \"\"\"\n", 132 | "\n", 133 | " predictions = []\n", 134 | " model_paths = sorted(model_paths) # ensure consistency of \n", 135 | " for path in model_paths: # for each model\n", 136 | " model_name = Path(path).stem\n", 137 | "\n", 138 | " # load in model and `torch.DataLoader`\n", 139 | " model, loader = make(\n", 140 | " model_path=path, \n", 141 | " cxr_filepath=cxr_filepath, \n", 142 | " ) \n", 143 | " \n", 144 | " # path to the cached prediction\n", 145 | " if cache_dir is not None:\n", 146 | " if save_name is not None: \n", 147 | " cache_path = Path(cache_dir) / f\"{save_name}_{model_name}.npy\"\n", 148 | " else: \n", 149 | " cache_path = Path(cache_dir) / f\"{model_name}.npy\"\n", 150 | "\n", 151 | " # if prediction already cached, don't recompute prediction\n", 152 | " if cache_dir is not None and os.path.exists(cache_path): \n", 153 | " print(\"Loading cached prediction for {}\".format(model_name))\n", 154 | " y_pred = np.load(cache_path)\n", 155 | " else: # cached prediction not found, compute preds\n", 156 | " print(\"Inferring model {}\".format(path))\n", 157 | " y_pred = run_softmax_eval(model, loader, cxr_labels, cxr_pair_template)\n", 158 | " if cache_dir is not None: \n", 159 | " Path(cache_dir).mkdir(exist_ok=True, parents=True)\n", 160 | " np.save(file=cache_path, arr=y_pred)\n", 161 | " predictions.append(y_pred)\n", 162 | " \n", 163 | " # compute average predictions\n", 164 | " y_pred_avg = np.mean(predictions, axis=0)\n", 165 | " \n", 166 | " return predictions, y_pred_avg" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 21, 172 | "metadata": {}, 173 | "outputs": [ 174 | { 175 | "name": "stdout", 176 | "output_type": "stream", 177 | "text": [ 178 | "Inferring model ../checkpoints/chexzero_weights/best_128_0.0002_original_15000_0.859.pt\n" 179 | ] 180 | }, 181 | { 182 | "data": { 183 | "application/vnd.jupyter.widget-view+json": { 184 | "model_id": "9e09e7e227cb4d4d9871f0b06f02ff61", 185 | "version_major": 2, 186 | "version_minor": 0 187 | }, 188 | "text/plain": [ 189 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))" 190 | ] 191 | }, 192 | "metadata": {}, 193 | "output_type": "display_data" 194 | }, 195 | { 196 | "name": "stdout", 197 | "output_type": "stream", 198 | "text": [ 199 | "\n" 200 | ] 201 | }, 202 | { 203 | "data": { 204 | "application/vnd.jupyter.widget-view+json": { 205 | "model_id": "c9603105fe154361921ade24843bb63a", 206 | "version_major": 2, 207 | "version_minor": 0 208 | }, 209 | "text/plain": [ 210 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))" 211 | ] 212 | }, 213 | "metadata": {}, 214 | "output_type": "display_data" 215 | }, 216 | { 217 | "name": "stdout", 218 | "output_type": "stream", 219 | "text": [ 220 | "\n" 221 | ] 222 | }, 223 | { 224 | "data": { 225 | "application/vnd.jupyter.widget-view+json": { 226 | "model_id": 
"1ca871c5168b412eaed223fd4407f14c", 227 | "version_major": 2, 228 | "version_minor": 0 229 | }, 230 | "text/plain": [ 231 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))" 232 | ] 233 | }, 234 | "metadata": {}, 235 | "output_type": "display_data" 236 | }, 237 | { 238 | "name": "stdout", 239 | "output_type": "stream", 240 | "text": [ 241 | "\n" 242 | ] 243 | }, 244 | { 245 | "data": { 246 | "application/vnd.jupyter.widget-view+json": { 247 | "model_id": "8feea81cea1a4e44886f78cbd3dbd95e", 248 | "version_major": 2, 249 | "version_minor": 0 250 | }, 251 | "text/plain": [ 252 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))" 253 | ] 254 | }, 255 | "metadata": {}, 256 | "output_type": "display_data" 257 | }, 258 | { 259 | "name": "stdout", 260 | "output_type": "stream", 261 | "text": [ 262 | "\n", 263 | "Inferring model ../checkpoints/chexzero_weights/best_128_0.0002_original_8000_0.857.pt\n" 264 | ] 265 | }, 266 | { 267 | "data": { 268 | "application/vnd.jupyter.widget-view+json": { 269 | "model_id": "753a46885545435480f8a559b7a29955", 270 | "version_major": 2, 271 | "version_minor": 0 272 | }, 273 | "text/plain": [ 274 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))" 275 | ] 276 | }, 277 | "metadata": {}, 278 | "output_type": "display_data" 279 | }, 280 | { 281 | "name": "stdout", 282 | "output_type": "stream", 283 | "text": [ 284 | "\n" 285 | ] 286 | }, 287 | { 288 | "data": { 289 | "application/vnd.jupyter.widget-view+json": { 290 | "model_id": "251f63415a314d2ca2064d1d00a028ae", 291 | "version_major": 2, 292 | "version_minor": 0 293 | }, 294 | "text/plain": [ 295 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))" 296 | ] 297 | }, 298 | "metadata": {}, 299 | "output_type": "display_data" 300 | }, 301 | { 302 | "name": "stdout", 303 | "output_type": "stream", 304 | "text": [ 305 | "\n" 306 | ] 307 | }, 308 | { 309 | "data": { 310 | "application/vnd.jupyter.widget-view+json": { 311 | "model_id": "d1e56347bb704d8c9033d9ccec2f2015", 312 | "version_major": 2, 313 | "version_minor": 0 314 | }, 315 | "text/plain": [ 316 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))" 317 | ] 318 | }, 319 | "metadata": {}, 320 | "output_type": "display_data" 321 | }, 322 | { 323 | "name": "stdout", 324 | "output_type": "stream", 325 | "text": [ 326 | "\n" 327 | ] 328 | }, 329 | { 330 | "data": { 331 | "application/vnd.jupyter.widget-view+json": { 332 | "model_id": "8c53cb0b856e4a7c8dc4851aff857322", 333 | "version_major": 2, 334 | "version_minor": 0 335 | }, 336 | "text/plain": [ 337 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))" 338 | ] 339 | }, 340 | "metadata": {}, 341 | "output_type": "display_data" 342 | }, 343 | { 344 | "name": "stdout", 345 | "output_type": "stream", 346 | "text": [ 347 | "\n", 348 | "Inferring model ../checkpoints/chexzero_weights/best_128_5e-05_original_22000_0.855.pt\n" 349 | ] 350 | }, 351 | { 352 | "data": { 353 | "application/vnd.jupyter.widget-view+json": { 354 | "model_id": "a59cc7a934b64d03af79a60949c2e01f", 355 | "version_major": 2, 356 | "version_minor": 0 357 | }, 358 | "text/plain": [ 359 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))" 360 | ] 361 | }, 362 | "metadata": {}, 363 | "output_type": "display_data" 364 | }, 365 | { 366 | "name": "stdout", 367 | "output_type": "stream", 368 | "text": [ 369 | "\n" 
370 | ] 371 | }, 372 | { 373 | "data": { 374 | "application/vnd.jupyter.widget-view+json": { 375 | "model_id": "f9f4cf985d5140759ee44a39625ff30d", 376 | "version_major": 2, 377 | "version_minor": 0 378 | }, 379 | "text/plain": [ 380 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))" 381 | ] 382 | }, 383 | "metadata": {}, 384 | "output_type": "display_data" 385 | }, 386 | { 387 | "name": "stdout", 388 | "output_type": "stream", 389 | "text": [ 390 | "\n" 391 | ] 392 | }, 393 | { 394 | "data": { 395 | "application/vnd.jupyter.widget-view+json": { 396 | "model_id": "ed6484585691494490fad5a40480dedb", 397 | "version_major": 2, 398 | "version_minor": 0 399 | }, 400 | "text/plain": [ 401 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))" 402 | ] 403 | }, 404 | "metadata": {}, 405 | "output_type": "display_data" 406 | }, 407 | { 408 | "name": "stdout", 409 | "output_type": "stream", 410 | "text": [ 411 | "\n" 412 | ] 413 | }, 414 | { 415 | "data": { 416 | "application/vnd.jupyter.widget-view+json": { 417 | "model_id": "795518422bd44731a78254d9cf1759fb", 418 | "version_major": 2, 419 | "version_minor": 0 420 | }, 421 | "text/plain": [ 422 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))" 423 | ] 424 | }, 425 | "metadata": {}, 426 | "output_type": "display_data" 427 | }, 428 | { 429 | "name": "stdout", 430 | "output_type": "stream", 431 | "text": [ 432 | "\n", 433 | "Inferring model ../checkpoints/chexzero_weights/best_64_0.0001_original_16000_0.861.pt\n" 434 | ] 435 | }, 436 | { 437 | "data": { 438 | "application/vnd.jupyter.widget-view+json": { 439 | "model_id": "4a565effbb694f639f0cdf633da11884", 440 | "version_major": 2, 441 | "version_minor": 0 442 | }, 443 | "text/plain": [ 444 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))" 445 | ] 446 | }, 447 | "metadata": {}, 448 | "output_type": "display_data" 449 | }, 450 | { 451 | "name": "stdout", 452 | "output_type": "stream", 453 | "text": [ 454 | "\n" 455 | ] 456 | }, 457 | { 458 | "data": { 459 | "application/vnd.jupyter.widget-view+json": { 460 | "model_id": "ea44e753e85e442a9ee9080d52108149", 461 | "version_major": 2, 462 | "version_minor": 0 463 | }, 464 | "text/plain": [ 465 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))" 466 | ] 467 | }, 468 | "metadata": {}, 469 | "output_type": "display_data" 470 | }, 471 | { 472 | "name": "stdout", 473 | "output_type": "stream", 474 | "text": [ 475 | "\n" 476 | ] 477 | }, 478 | { 479 | "data": { 480 | "application/vnd.jupyter.widget-view+json": { 481 | "model_id": "7f085cd056304e498a9dc251db09cb9a", 482 | "version_major": 2, 483 | "version_minor": 0 484 | }, 485 | "text/plain": [ 486 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))" 487 | ] 488 | }, 489 | "metadata": {}, 490 | "output_type": "display_data" 491 | }, 492 | { 493 | "name": "stdout", 494 | "output_type": "stream", 495 | "text": [ 496 | "\n" 497 | ] 498 | }, 499 | { 500 | "data": { 501 | "application/vnd.jupyter.widget-view+json": { 502 | "model_id": "dbaf547d56ab47fcafd7cccde650a9ef", 503 | "version_major": 2, 504 | "version_minor": 0 505 | }, 506 | "text/plain": [ 507 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))" 508 | ] 509 | }, 510 | "metadata": {}, 511 | "output_type": "display_data" 512 | }, 513 | { 514 | "name": "stdout", 515 | "output_type": "stream", 516 | "text": 
[ 517 | "\n", 518 | "Inferring model ../checkpoints/chexzero_weights/best_64_0.0001_original_17000_0.863.pt\n" 519 | ] 520 | }, 521 | { 522 | "data": { 523 | "application/vnd.jupyter.widget-view+json": { 524 | "model_id": "cd029103e56842e18feae09bb7488fd2", 525 | "version_major": 2, 526 | "version_minor": 0 527 | }, 528 | "text/plain": [ 529 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))" 530 | ] 531 | }, 532 | "metadata": {}, 533 | "output_type": "display_data" 534 | }, 535 | { 536 | "name": "stdout", 537 | "output_type": "stream", 538 | "text": [ 539 | "\n" 540 | ] 541 | }, 542 | { 543 | "data": { 544 | "application/vnd.jupyter.widget-view+json": { 545 | "model_id": "7534cec7de8e4f0d9cfc22776a94d52d", 546 | "version_major": 2, 547 | "version_minor": 0 548 | }, 549 | "text/plain": [ 550 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))" 551 | ] 552 | }, 553 | "metadata": {}, 554 | "output_type": "display_data" 555 | }, 556 | { 557 | "name": "stdout", 558 | "output_type": "stream", 559 | "text": [ 560 | "\n" 561 | ] 562 | }, 563 | { 564 | "data": { 565 | "application/vnd.jupyter.widget-view+json": { 566 | "model_id": "71ace33603504ca692a81330ad2c5c2a", 567 | "version_major": 2, 568 | "version_minor": 0 569 | }, 570 | "text/plain": [ 571 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))" 572 | ] 573 | }, 574 | "metadata": {}, 575 | "output_type": "display_data" 576 | }, 577 | { 578 | "name": "stdout", 579 | "output_type": "stream", 580 | "text": [ 581 | "\n" 582 | ] 583 | }, 584 | { 585 | "data": { 586 | "application/vnd.jupyter.widget-view+json": { 587 | "model_id": "e5b30aed5f47451b80bbc61da9d45ad6", 588 | "version_major": 2, 589 | "version_minor": 0 590 | }, 591 | "text/plain": [ 592 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))" 593 | ] 594 | }, 595 | "metadata": {}, 596 | "output_type": "display_data" 597 | }, 598 | { 599 | "name": "stdout", 600 | "output_type": "stream", 601 | "text": [ 602 | "\n", 603 | "Inferring model ../checkpoints/chexzero_weights/best_64_0.0001_original_35000_0.864.pt\n" 604 | ] 605 | }, 606 | { 607 | "data": { 608 | "application/vnd.jupyter.widget-view+json": { 609 | "model_id": "e9908db37e064d06af19924b9b02e1de", 610 | "version_major": 2, 611 | "version_minor": 0 612 | }, 613 | "text/plain": [ 614 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))" 615 | ] 616 | }, 617 | "metadata": {}, 618 | "output_type": "display_data" 619 | }, 620 | { 621 | "name": "stdout", 622 | "output_type": "stream", 623 | "text": [ 624 | "\n" 625 | ] 626 | }, 627 | { 628 | "data": { 629 | "application/vnd.jupyter.widget-view+json": { 630 | "model_id": "3bf4a3799ba142babd924ed5fe13eaf2", 631 | "version_major": 2, 632 | "version_minor": 0 633 | }, 634 | "text/plain": [ 635 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))" 636 | ] 637 | }, 638 | "metadata": {}, 639 | "output_type": "display_data" 640 | }, 641 | { 642 | "name": "stdout", 643 | "output_type": "stream", 644 | "text": [ 645 | "\n" 646 | ] 647 | }, 648 | { 649 | "data": { 650 | "application/vnd.jupyter.widget-view+json": { 651 | "model_id": "87c65ae5d98b4679907d25cb98157485", 652 | "version_major": 2, 653 | "version_minor": 0 654 | }, 655 | "text/plain": [ 656 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))" 657 | ] 658 | }, 659 | "metadata": {}, 660 | 
"output_type": "display_data" 661 | }, 662 | { 663 | "name": "stdout", 664 | "output_type": "stream", 665 | "text": [ 666 | "\n" 667 | ] 668 | }, 669 | { 670 | "data": { 671 | "application/vnd.jupyter.widget-view+json": { 672 | "model_id": "52db175c5f3b41148c56b8ecab7f789b", 673 | "version_major": 2, 674 | "version_minor": 0 675 | }, 676 | "text/plain": [ 677 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))" 678 | ] 679 | }, 680 | "metadata": {}, 681 | "output_type": "display_data" 682 | }, 683 | { 684 | "name": "stdout", 685 | "output_type": "stream", 686 | "text": [ 687 | "\n", 688 | "Inferring model ../checkpoints/chexzero_weights/best_64_0.0002_original_23000_0.854.pt\n" 689 | ] 690 | }, 691 | { 692 | "data": { 693 | "application/vnd.jupyter.widget-view+json": { 694 | "model_id": "e000f01f6d394e3584c8aebc74d965f0", 695 | "version_major": 2, 696 | "version_minor": 0 697 | }, 698 | "text/plain": [ 699 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))" 700 | ] 701 | }, 702 | "metadata": {}, 703 | "output_type": "display_data" 704 | }, 705 | { 706 | "name": "stdout", 707 | "output_type": "stream", 708 | "text": [ 709 | "\n" 710 | ] 711 | }, 712 | { 713 | "data": { 714 | "application/vnd.jupyter.widget-view+json": { 715 | "model_id": "92429fac05ce4e87ac8093ed1e6b6e0b", 716 | "version_major": 2, 717 | "version_minor": 0 718 | }, 719 | "text/plain": [ 720 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))" 721 | ] 722 | }, 723 | "metadata": {}, 724 | "output_type": "display_data" 725 | }, 726 | { 727 | "name": "stdout", 728 | "output_type": "stream", 729 | "text": [ 730 | "\n" 731 | ] 732 | }, 733 | { 734 | "data": { 735 | "application/vnd.jupyter.widget-view+json": { 736 | "model_id": "6daaede09a1745da99b23784ab1e65da", 737 | "version_major": 2, 738 | "version_minor": 0 739 | }, 740 | "text/plain": [ 741 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))" 742 | ] 743 | }, 744 | "metadata": {}, 745 | "output_type": "display_data" 746 | }, 747 | { 748 | "name": "stdout", 749 | "output_type": "stream", 750 | "text": [ 751 | "\n" 752 | ] 753 | }, 754 | { 755 | "data": { 756 | "application/vnd.jupyter.widget-view+json": { 757 | "model_id": "50eb0ec5ac9f4561a6eee9e0b64e325e", 758 | "version_major": 2, 759 | "version_minor": 0 760 | }, 761 | "text/plain": [ 762 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))" 763 | ] 764 | }, 765 | "metadata": {}, 766 | "output_type": "display_data" 767 | }, 768 | { 769 | "name": "stdout", 770 | "output_type": "stream", 771 | "text": [ 772 | "\n", 773 | "Inferring model ../checkpoints/chexzero_weights/best_64_5e-05_original_16000_0.858.pt\n" 774 | ] 775 | }, 776 | { 777 | "data": { 778 | "application/vnd.jupyter.widget-view+json": { 779 | "model_id": "75ed921a1c384da082f47aaae0d60db9", 780 | "version_major": 2, 781 | "version_minor": 0 782 | }, 783 | "text/plain": [ 784 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))" 785 | ] 786 | }, 787 | "metadata": {}, 788 | "output_type": "display_data" 789 | }, 790 | { 791 | "name": "stdout", 792 | "output_type": "stream", 793 | "text": [ 794 | "\n" 795 | ] 796 | }, 797 | { 798 | "data": { 799 | "application/vnd.jupyter.widget-view+json": { 800 | "model_id": "02ea0f79065342f7bd0c5ed353c1e6e4", 801 | "version_major": 2, 802 | "version_minor": 0 803 | }, 804 | "text/plain": [ 805 | 
"HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))" 806 | ] 807 | }, 808 | "metadata": {}, 809 | "output_type": "display_data" 810 | }, 811 | { 812 | "name": "stdout", 813 | "output_type": "stream", 814 | "text": [ 815 | "\n" 816 | ] 817 | }, 818 | { 819 | "data": { 820 | "application/vnd.jupyter.widget-view+json": { 821 | "model_id": "f861a35947ab4068b05514f51c75ea22", 822 | "version_major": 2, 823 | "version_minor": 0 824 | }, 825 | "text/plain": [ 826 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))" 827 | ] 828 | }, 829 | "metadata": {}, 830 | "output_type": "display_data" 831 | }, 832 | { 833 | "name": "stdout", 834 | "output_type": "stream", 835 | "text": [ 836 | "\n" 837 | ] 838 | }, 839 | { 840 | "data": { 841 | "application/vnd.jupyter.widget-view+json": { 842 | "model_id": "7249bae84ec04a2392614d30d832c645", 843 | "version_major": 2, 844 | "version_minor": 0 845 | }, 846 | "text/plain": [ 847 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))" 848 | ] 849 | }, 850 | "metadata": {}, 851 | "output_type": "display_data" 852 | }, 853 | { 854 | "name": "stdout", 855 | "output_type": "stream", 856 | "text": [ 857 | "\n", 858 | "Inferring model ../checkpoints/chexzero_weights/best_64_5e-05_original_18000_0.862.pt\n" 859 | ] 860 | }, 861 | { 862 | "data": { 863 | "application/vnd.jupyter.widget-view+json": { 864 | "model_id": "75285bd7d3884a5fad36d80483421cd1", 865 | "version_major": 2, 866 | "version_minor": 0 867 | }, 868 | "text/plain": [ 869 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))" 870 | ] 871 | }, 872 | "metadata": {}, 873 | "output_type": "display_data" 874 | }, 875 | { 876 | "name": "stdout", 877 | "output_type": "stream", 878 | "text": [ 879 | "\n" 880 | ] 881 | }, 882 | { 883 | "data": { 884 | "application/vnd.jupyter.widget-view+json": { 885 | "model_id": "6ce671a4cc6c476f82bb657a22402f57", 886 | "version_major": 2, 887 | "version_minor": 0 888 | }, 889 | "text/plain": [ 890 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))" 891 | ] 892 | }, 893 | "metadata": {}, 894 | "output_type": "display_data" 895 | }, 896 | { 897 | "name": "stdout", 898 | "output_type": "stream", 899 | "text": [ 900 | "\n" 901 | ] 902 | }, 903 | { 904 | "data": { 905 | "application/vnd.jupyter.widget-view+json": { 906 | "model_id": "d11b53b9ecee49858b365ebf9a1104de", 907 | "version_major": 2, 908 | "version_minor": 0 909 | }, 910 | "text/plain": [ 911 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))" 912 | ] 913 | }, 914 | "metadata": {}, 915 | "output_type": "display_data" 916 | }, 917 | { 918 | "name": "stdout", 919 | "output_type": "stream", 920 | "text": [ 921 | "\n" 922 | ] 923 | }, 924 | { 925 | "data": { 926 | "application/vnd.jupyter.widget-view+json": { 927 | "model_id": "f83be71ddfa8430b868f5c00e8c708be", 928 | "version_major": 2, 929 | "version_minor": 0 930 | }, 931 | "text/plain": [ 932 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))" 933 | ] 934 | }, 935 | "metadata": {}, 936 | "output_type": "display_data" 937 | }, 938 | { 939 | "name": "stdout", 940 | "output_type": "stream", 941 | "text": [ 942 | "\n", 943 | "Inferring model ../checkpoints/chexzero_weights/best_64_5e-05_original_22000_0.864.pt\n" 944 | ] 945 | }, 946 | { 947 | "data": { 948 | "application/vnd.jupyter.widget-view+json": { 949 | "model_id": 
"48b10dc6aa524f5081861eef3a1fdf3f", 950 | "version_major": 2, 951 | "version_minor": 0 952 | }, 953 | "text/plain": [ 954 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))" 955 | ] 956 | }, 957 | "metadata": {}, 958 | "output_type": "display_data" 959 | }, 960 | { 961 | "name": "stdout", 962 | "output_type": "stream", 963 | "text": [ 964 | "\n" 965 | ] 966 | }, 967 | { 968 | "data": { 969 | "application/vnd.jupyter.widget-view+json": { 970 | "model_id": "d1f822091f93432eaa5f29c1d62d0e29", 971 | "version_major": 2, 972 | "version_minor": 0 973 | }, 974 | "text/plain": [ 975 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))" 976 | ] 977 | }, 978 | "metadata": {}, 979 | "output_type": "display_data" 980 | }, 981 | { 982 | "name": "stdout", 983 | "output_type": "stream", 984 | "text": [ 985 | "\n" 986 | ] 987 | }, 988 | { 989 | "data": { 990 | "application/vnd.jupyter.widget-view+json": { 991 | "model_id": "2ab6afefd29a4e008df820e3442def05", 992 | "version_major": 2, 993 | "version_minor": 0 994 | }, 995 | "text/plain": [ 996 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))" 997 | ] 998 | }, 999 | "metadata": {}, 1000 | "output_type": "display_data" 1001 | }, 1002 | { 1003 | "name": "stdout", 1004 | "output_type": "stream", 1005 | "text": [ 1006 | "\n" 1007 | ] 1008 | }, 1009 | { 1010 | "data": { 1011 | "application/vnd.jupyter.widget-view+json": { 1012 | "model_id": "fff116a1443d4e9b96f41fbe53ebd1ac", 1013 | "version_major": 2, 1014 | "version_minor": 0 1015 | }, 1016 | "text/plain": [ 1017 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))" 1018 | ] 1019 | }, 1020 | "metadata": {}, 1021 | "output_type": "display_data" 1022 | }, 1023 | { 1024 | "name": "stdout", 1025 | "output_type": "stream", 1026 | "text": [ 1027 | "\n" 1028 | ] 1029 | } 1030 | ], 1031 | "source": [ 1032 | "predictions, y_pred_avg = ensemble_models(\n", 1033 | " model_paths=model_paths, \n", 1034 | " cxr_filepath=cxr_filepath, \n", 1035 | " cxr_labels=cxr_labels, \n", 1036 | " cxr_pair_template=cxr_pair_template, \n", 1037 | " cache_dir=cache_dir,\n", 1038 | ")" 1039 | ] 1040 | }, 1041 | { 1042 | "cell_type": "code", 1043 | "execution_count": 22, 1044 | "metadata": {}, 1045 | "outputs": [], 1046 | "source": [ 1047 | "# save averaged preds\n", 1048 | "pred_name = \"chexpert_preds.npy\" # add name of preds\n", 1049 | "predictions_dir = predictions_dir / pred_name\n", 1050 | "np.save(file=predictions_dir, arr=y_pred_avg)" 1051 | ] 1052 | }, 1053 | { 1054 | "cell_type": "markdown", 1055 | "metadata": {}, 1056 | "source": [ 1057 | "## (Optional) Evaluate Results\n", 1058 | "If ground truth labels are available, compute AUC on each pathology to evaluate the performance of the zero-shot model. 
" 1059 | ] 1060 | }, 1061 | { 1062 | "cell_type": "code", 1063 | "execution_count": 23, 1064 | "metadata": {}, 1065 | "outputs": [ 1066 | { 1067 | "data": { 1068 | "application/vnd.jupyter.widget-view+json": { 1069 | "model_id": "a338f58b25a94d68a3be00565cbaca39", 1070 | "version_major": 2, 1071 | "version_minor": 0 1072 | }, 1073 | "text/plain": [ 1074 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 1075 | ] 1076 | }, 1077 | "metadata": {}, 1078 | "output_type": "display_data" 1079 | }, 1080 | { 1081 | "name": "stdout", 1082 | "output_type": "stream", 1083 | "text": [ 1084 | "\n" 1085 | ] 1086 | } 1087 | ], 1088 | "source": [ 1089 | "# make test_true\n", 1090 | "test_pred = y_pred_avg\n", 1091 | "test_true = make_true_labels(cxr_true_labels_path=cxr_true_labels_path, cxr_labels=cxr_labels)\n", 1092 | "\n", 1093 | "# evaluate model\n", 1094 | "cxr_results = evaluate(test_pred, test_true, cxr_labels)\n", 1095 | "\n", 1096 | "# boostrap evaluations for 95% confidence intervals\n", 1097 | "bootstrap_results = bootstrap(test_pred, test_true, cxr_labels)" 1098 | ] 1099 | }, 1100 | { 1101 | "cell_type": "code", 1102 | "execution_count": 25, 1103 | "metadata": {}, 1104 | "outputs": [ 1105 | { 1106 | "data": { 1107 | "text/html": [ 1108 | "
\n", 1109 | "\n", 1122 | "\n", 1123 | " \n", 1124 | " \n", 1125 | " \n", 1126 | " \n", 1127 | " \n", 1128 | " \n", 1129 | " \n", 1130 | " \n", 1131 | " \n", 1132 | " \n", 1133 | " \n", 1134 | " \n", 1135 | " \n", 1136 | " \n", 1137 | " \n", 1138 | " \n", 1139 | " \n", 1140 | " \n", 1141 | " \n", 1142 | " \n", 1143 | " \n", 1144 | " \n", 1145 | " \n", 1146 | " \n", 1147 | " \n", 1148 | " \n", 1149 | " \n", 1150 | " \n", 1151 | " \n", 1152 | " \n", 1153 | " \n", 1154 | " \n", 1155 | " \n", 1156 | " \n", 1157 | " \n", 1158 | " \n", 1159 | " \n", 1160 | " \n", 1161 | " \n", 1162 | " \n", 1163 | " \n", 1164 | " \n", 1165 | " \n", 1166 | " \n", 1167 | " \n", 1168 | " \n", 1169 | " \n", 1170 | " \n", 1171 | " \n", 1172 | " \n", 1173 | " \n", 1174 | " \n", 1175 | " \n", 1176 | " \n", 1177 | " \n", 1178 | " \n", 1179 | " \n", 1180 | " \n", 1181 | " \n", 1182 | " \n", 1183 | " \n", 1184 | " \n", 1185 | " \n", 1186 | " \n", 1187 | " \n", 1188 | " \n", 1189 | " \n", 1190 | " \n", 1191 | " \n", 1192 | " \n", 1193 | " \n", 1194 | " \n", 1195 | "
Atelectasis_aucCardiomegaly_aucConsolidation_aucEdema_aucEnlarged Cardiomediastinum_aucFracture_aucLung Lesion_aucLung Opacity_aucNo Finding_aucPleural Effusion_aucPleural Other_aucPneumonia_aucPneumothorax_aucSupport Devices_auc
mean0.81180.91320.89010.89940.91600.56030.73600.92130.07000.93170.60250.77980.65200.7735
lower0.77200.88490.82010.86620.89120.26460.56580.89610.04510.90530.46080.56950.48540.7310
upper0.84790.93670.94700.92950.93750.87250.87790.94260.09520.95360.88550.94830.82430.8130
\n", 1196 | "
" 1197 | ], 1198 | "text/plain": [ 1199 | " Atelectasis_auc Cardiomegaly_auc Consolidation_auc Edema_auc \\\n", 1200 | "mean 0.8118 0.9132 0.8901 0.8994 \n", 1201 | "lower 0.7720 0.8849 0.8201 0.8662 \n", 1202 | "upper 0.8479 0.9367 0.9470 0.9295 \n", 1203 | "\n", 1204 | " Enlarged Cardiomediastinum_auc Fracture_auc Lung Lesion_auc \\\n", 1205 | "mean 0.9160 0.5603 0.7360 \n", 1206 | "lower 0.8912 0.2646 0.5658 \n", 1207 | "upper 0.9375 0.8725 0.8779 \n", 1208 | "\n", 1209 | " Lung Opacity_auc No Finding_auc Pleural Effusion_auc \\\n", 1210 | "mean 0.9213 0.0700 0.9317 \n", 1211 | "lower 0.8961 0.0451 0.9053 \n", 1212 | "upper 0.9426 0.0952 0.9536 \n", 1213 | "\n", 1214 | " Pleural Other_auc Pneumonia_auc Pneumothorax_auc Support Devices_auc \n", 1215 | "mean 0.6025 0.7798 0.6520 0.7735 \n", 1216 | "lower 0.4608 0.5695 0.4854 0.7310 \n", 1217 | "upper 0.8855 0.9483 0.8243 0.8130 " 1218 | ] 1219 | }, 1220 | "execution_count": 25, 1221 | "metadata": {}, 1222 | "output_type": "execute_result" 1223 | } 1224 | ], 1225 | "source": [ 1226 | "# display AUC with confidence intervals\n", 1227 | "bootstrap_results[1]" 1228 | ] 1229 | }, 1230 | { 1231 | "cell_type": "code", 1232 | "execution_count": null, 1233 | "metadata": {}, 1234 | "outputs": [], 1235 | "source": [] 1236 | } 1237 | ], 1238 | "metadata": { 1239 | "interpreter": { 1240 | "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e" 1241 | }, 1242 | "kernelspec": { 1243 | "display_name": "Python 3", 1244 | "language": "python", 1245 | "name": "python3" 1246 | }, 1247 | "language_info": { 1248 | "codemirror_mode": { 1249 | "name": "ipython", 1250 | "version": 3 1251 | }, 1252 | "file_extension": ".py", 1253 | "mimetype": "text/x-python", 1254 | "name": "python", 1255 | "nbconvert_exporter": "python", 1256 | "pygments_lexer": "ipython3", 1257 | "version": "3.8.5" 1258 | } 1259 | }, 1260 | "nbformat": 4, 1261 | "nbformat_minor": 2 1262 | } 1263 | --------------------------------------------------------------------------------