├── data
│   └── .gitignore
├── checkpoints
│   └── .gitignore
├── predictions
│   └── .gitignore
├── bpe_simple_vocab_16e6.txt.gz
├── .gitignore
├── requirements.txt
├── LICENSE
├── NOTICE.md
├── run_preprocess.py
├── data_process.py
├── run_train.py
├── simple_tokenizer.py
├── eval.py
├── clip.py
├── preprocess_padchest.py
├── train.py
├── README.md
├── metrics.py
├── model.py
├── zero_shot.py
└── notebooks
    └── zero_shot.ipynb
/data/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | */
3 | !.gitignore
4 |
--------------------------------------------------------------------------------
/checkpoints/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | */
3 | !.gitignore
4 |
--------------------------------------------------------------------------------
/predictions/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | */
3 | !.gitignore
4 |
--------------------------------------------------------------------------------
/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rajpurkarlab/CheXzero/HEAD/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | model.pt
2 | notebooks/.ipynb_checkpoints
3 | .ipynb_checkpoints
4 | __pycache__
5 | notebooks/clip_v1_0.1_state_dict.pt
6 | notebooks/models
7 | notebooks/train/wandb
8 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | albumentations==1.1.0
2 | argparse==1.4.0
3 | ftfy==6.1.1
4 | grpcio==1.46.1
5 | h5py==3.1.0
6 | huggingface-hub==0.6.0
7 | imageio==2.19.1
8 | joblib==1.0.1
9 | matplotlib==3.3.4
10 | numpy==1.19.5
11 | opencv-python==4.5.3.56
12 | opencv-python-headless==4.1.2.30
13 | pandas==1.2.1
14 | pathlib==1.0.1
15 | plotly==5.9.0
16 | psutil==5.8.0
17 | python-dateutil==2.8.1
18 | regex==2020.11.13
19 | scikit-image==0.19.2
20 | scikit-learn==0.24.1
21 | scipy==1.6.1
22 | sklearn==0.0
23 | tifffile==2022.5.4
24 | tokenizers==0.12.1
25 | torch==1.10.2
26 | torchaudio==0.10.2
27 | torchvision==0.11.3
28 | transformers==4.19.0
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Rajpurkar Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/NOTICE.md:
--------------------------------------------------------------------------------
1 | # Notices for CheXzero
2 | This software incorporates material from third parties.
3 |
4 | ## Project Licenses
5 | The source code of this repository was derived from CLIP developed by OpenAI (https://github.com/openai/CLIP). This work uses and modifies code that defines the CLIP model architecture, preprocesses unstructured text, and runs inference.
6 |
7 | ### Open Source License / Copyright Notice
8 | ```
9 | MIT License
10 |
11 | Copyright (c) 2021 OpenAI
12 |
13 | Permission is hereby granted, free of charge, to any person obtaining a copy
14 | of this software and associated documentation files (the "Software"), to deal
15 | in the Software without restriction, including without limitation the rights
16 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
17 | copies of the Software, and to permit persons to whom the Software is
18 | furnished to do so, subject to the following conditions:
19 |
20 | The above copyright notice and this permission notice shall be included in all
21 | copies or substantial portions of the Software.
22 |
23 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
24 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
25 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
26 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
27 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
28 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
29 | SOFTWARE.
30 | ```
31 |
--------------------------------------------------------------------------------
/run_preprocess.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | from data_process import get_cxr_paths_list, img_to_hdf5, get_cxr_path_csv, write_report_csv
4 |
5 |
6 | def parse_args():
7 | parser = argparse.ArgumentParser()
 8 |     parser.add_argument('--csv_out_path', type=str, default='data/cxr_paths.csv', help="File path for the CSV listing paths to all chest x-ray images in the dataset.")
 9 |     parser.add_argument('--cxr_out_path', type=str, default='data/cxr.h5', help="File path for the processed chest x-ray image HDF5 file.")
10 |     parser.add_argument('--dataset_type', type=str, default='mimic', choices=['mimic', 'chexpert-test'], help="Type of dataset to pre-process.")
11 |     parser.add_argument('--mimic_impressions_path', default='data/mimic_impressions.csv', help="File path for the CSV of impressions extracted from the radiology reports.")
12 |     parser.add_argument('--chest_x_ray_path', default='/deep/group/data/mimic-cxr/mimic-cxr-jpg/2.0.0/files', help="Directory where chest x-ray image data is stored. This should point to the files folder of the MIMIC chest x-ray dataset.")
13 |     parser.add_argument('--radiology_reports_path', default='/deep/group/data/med-data/files/', help="Directory where radiology reports are stored. This should point to the files folder of the MIMIC radiology reports dataset.")
14 | args = parser.parse_args()
15 | return args
16 |
17 | if __name__ == "__main__":
18 | args = parse_args()
19 | if args.dataset_type == "mimic":
20 | # Write Chest X-ray Image HDF5 File
21 | get_cxr_path_csv(args.csv_out_path, args.chest_x_ray_path)
22 | cxr_paths = get_cxr_paths_list(args.csv_out_path)
23 | img_to_hdf5(cxr_paths, args.cxr_out_path)
24 |
25 | #Write CSV File Containing Impressions for each Chest X-ray
26 | write_report_csv(cxr_paths, args.radiology_reports_path, args.mimic_impressions_path)
27 | elif args.dataset_type == "chexpert-test":
28 | # Get all test paths based on cxr dir
29 | cxr_dir = Path(args.chest_x_ray_path)
30 | cxr_paths = list(cxr_dir.rglob("*.jpg"))
31 | cxr_paths = list(filter(lambda x: "view1" in str(x), cxr_paths)) # filter only first frontal views
32 | cxr_paths = sorted(cxr_paths) # sort to align with groundtruth
33 | assert(len(cxr_paths) == 500)
34 |
35 | img_to_hdf5(cxr_paths, args.cxr_out_path)
36 |
37 |
38 |
39 |
40 |
41 |
42 |
--------------------------------------------------------------------------------
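
A minimal usage sketch for the MIMIC branch above, calling the same `data_process` helpers that `run_preprocess.py` wires together; the MIMIC directory paths are placeholders and must point to a local, credentialed download of MIMIC-CXR-JPG and its reports:

```python
from data_process import get_cxr_path_csv, get_cxr_paths_list, img_to_hdf5, write_report_csv

csv_out_path = "data/cxr_paths.csv"                          # CSV listing every .jpg found
cxr_out_path = "data/cxr.h5"                                 # output HDF5 of 320x320 images
impressions_out_path = "data/mimic_impressions.csv"          # extracted IMPRESSION sections
chest_x_ray_dir = "/path/to/mimic-cxr-jpg/files"             # placeholder path
radiology_reports_dir = "/path/to/mimic-cxr-reports/files/"  # placeholder path (trailing slash expected)

get_cxr_path_csv(csv_out_path, chest_x_ray_dir)   # index all images into a CSV
cxr_paths = get_cxr_paths_list(csv_out_path)      # read the paths back as a pandas Series
img_to_hdf5(cxr_paths, cxr_out_path)              # resize/pad and store images in HDF5
write_report_csv(cxr_paths, radiology_reports_dir, impressions_out_path)  # impressions CSV
```
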
/data_process.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import glob
4 | import numpy as np
5 | import pandas as pd
6 | import csv
7 | import matplotlib.pyplot as plt
8 | from tqdm import tqdm
9 |
10 | from PIL import Image
11 | import h5py
12 | import cv2
13 | from typing import *
14 | from pathlib import Path
15 |
16 | import torch
17 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
18 |
19 | def load_data(filepath):
20 | dataframe = pd.read_csv(filepath)
21 | return dataframe
22 |
23 | def get_cxr_paths_list(filepath):
24 | dataframe = load_data(filepath)
25 | cxr_paths = dataframe['Path']
26 | return cxr_paths
27 |
28 | '''
29 | This function resizes and zero-pads an image to a square of size `desired_size`.
30 | '''
31 | def preprocess(img, desired_size=320):
32 | old_size = img.size
33 | ratio = float(desired_size)/max(old_size)
34 | new_size = tuple([int(x*ratio) for x in old_size])
35 | img = img.resize(new_size, Image.ANTIALIAS)
36 | # create a new image and paste the resized on it
37 |
38 | new_img = Image.new('L', (desired_size, desired_size))
39 | new_img.paste(img, ((desired_size-new_size[0])//2,
40 | (desired_size-new_size[1])//2))
41 | return new_img
42 |
43 | def img_to_hdf5(cxr_paths: List[Union[str, Path]], out_filepath: str, resolution=320):
44 | """
45 | Convert directory of images into a .h5 file given paths to all
46 | images.
47 | """
48 | dset_size = len(cxr_paths)
49 | failed_images = []
50 | with h5py.File(out_filepath,'w') as h5f:
51 | img_dset = h5f.create_dataset('cxr', shape=(dset_size, resolution, resolution))
52 | for idx, path in enumerate(tqdm(cxr_paths)):
53 | try:
54 | # read image using cv2
55 | img = cv2.imread(str(path))
56 | # convert to PIL Image object
57 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
58 | img_pil = Image.fromarray(img)
59 | # preprocess
60 | img = preprocess(img_pil, desired_size=resolution)
61 | img_dset[idx] = img
62 | except Exception as e:
63 | failed_images.append((path, e))
64 | print(f"{len(failed_images)} / {len(cxr_paths)} images failed to be added to h5.", failed_images)
65 |
66 | def get_files(directory):
67 | files = []
68 | for (dirpath, dirnames, filenames) in os.walk(directory):
69 | for file in filenames:
70 | if file.endswith(".jpg"):
71 | files.append(os.path.join(dirpath, file))
72 | return files
73 |
74 | def get_cxr_path_csv(out_filepath, directory):
75 | files = get_files(directory)
76 | file_dict = {"Path": files}
77 | df = pd.DataFrame(file_dict)
78 | df.to_csv(out_filepath, index=False)
79 |
80 | def section_start(lines, section=' IMPRESSION'):
81 | for idx, line in enumerate(lines):
82 | if line.startswith(section):
83 | return idx
84 | return -1
85 |
86 | def section_end(lines, section_start):
87 | num_lines = len(lines)
88 |
89 | def getIndexOfLast(l, element):
90 | """ Get index of last occurence of element
91 | @param l (list): list of elements
92 | @param element (string): element to search for
93 | @returns (int): index of last occurrence of element
94 | """
95 | i = max(loc for loc, val in enumerate(l) if val == element)
96 | return i
97 |
98 | def write_report_csv(cxr_paths, txt_folder, out_path):
99 | imps = {"filename": [], "impression": []}
100 | txt_reports = []
101 | for cxr_path in cxr_paths:
102 | tokens = cxr_path.split('/')
103 | study_num = tokens[-2]
104 | patient_num = tokens[-3]
105 | patient_group = tokens[-4]
106 | txt_report = txt_folder + patient_group + '/' + patient_num + '/' + study_num + '.txt'
107 | filename = study_num + '.txt'
108 |         with open(txt_report, 'r') as f:
109 |             s = f.read()
110 | s_split = s.split()
111 | if "IMPRESSION:" in s_split:
112 | begin = getIndexOfLast(s_split, "IMPRESSION:") + 1
113 | end = None
114 | end_cand1 = None
115 | end_cand2 = None
116 | # remove recommendation(s) and notification
117 | if "RECOMMENDATION(S):" in s_split:
118 | end_cand1 = s_split.index("RECOMMENDATION(S):")
119 | elif "RECOMMENDATION:" in s_split:
120 | end_cand1 = s_split.index("RECOMMENDATION:")
121 | elif "RECOMMENDATIONS:" in s_split:
122 | end_cand1 = s_split.index("RECOMMENDATIONS:")
123 |
124 | if "NOTIFICATION:" in s_split:
125 | end_cand2 = s_split.index("NOTIFICATION:")
126 | elif "NOTIFICATIONS:" in s_split:
127 | end_cand2 = s_split.index("NOTIFICATIONS:")
128 |
129 | if end_cand1 and end_cand2:
130 | end = min(end_cand1, end_cand2)
131 | elif end_cand1:
132 | end = end_cand1
133 | elif end_cand2:
134 | end = end_cand2
135 |
136 | if end == None:
137 | imp = " ".join(s_split[begin:])
138 | else:
139 | imp = " ".join(s_split[begin:end])
140 | else:
141 | imp = 'NO IMPRESSION'
142 |
143 | imps["impression"].append(imp)
144 | imps["filename"].append(filename)
145 |
146 | df = pd.DataFrame(data=imps)
147 | df.to_csv(out_path, index=False)
148 |
149 |
--------------------------------------------------------------------------------
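
For a single image, the resize-and-pad step in `preprocess` can be exercised directly, mirroring what `img_to_hdf5` does per file; `example_cxr.jpg` is a placeholder path for any local chest X-ray:

```python
import cv2
from PIL import Image
from data_process import preprocess

img = cv2.imread("example_cxr.jpg")             # BGR numpy array (placeholder file)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)      # convert to RGB
img_pil = Image.fromarray(img)

padded = preprocess(img_pil, desired_size=320)  # aspect-preserving resize + zero padding
print(padded.size, padded.mode)                 # (320, 320), grayscale ('L')
```
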
/run_train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pprint
3 | import argparse
4 | from tqdm import tqdm
5 |
6 | import torch
7 | from torch.utils import data
8 | from torch import nn
9 | import torch.optim as optim
10 | from torchvision.transforms import Compose, Normalize, Resize
11 |
12 | import clip
13 | from model import CLIP
14 | from simple_tokenizer import SimpleTokenizer
15 |
16 | from train import train_main, load_data, load_clip, preprocess_text
17 | from zero_shot import run_cxr_zero_shot, run_zero_shot
18 |
19 | def parse_args():
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument('--cxr_filepath', type=str, default='data/cxr.h5', help="Directory to load chest x-ray image data from.")
22 | parser.add_argument('--txt_filepath', type=str, default='data/mimic_impressions.csv', help="Directory to load radiology report impressions text from.")
23 | parser.add_argument('--batch_size', type=int, default=16)
24 | parser.add_argument('--epochs', type=int, default=4)
25 | parser.add_argument('--lr', type=float, default=1e-4)
26 | parser.add_argument('--save_interval', type=int, default=100)
27 | parser.add_argument('--log_interval', type=int, default=10)
28 | parser.add_argument('--save_dir', type=str, default="checkpoints/", help="Directory to save the trained model.")
29 | parser.add_argument('--seed', type=int, default=1234)
30 | parser.add_argument('--optimizer', type=str, default="sgd")
31 | parser.add_argument('--momentum', type=float, default=0.9)
32 | parser.add_argument('--context_length', type=int, default=77)
33 | parser.add_argument('--random_init', action='store_true')
34 | parser.add_argument('--model_name', type=str, default="pt-imp")
35 | args = parser.parse_args()
36 | return args
37 |
38 | def model_pipeline(config, verbose=0):
39 | # make the model, data, and optimization problem
40 | model, data_loader, device, criterion, optimizer = make(config)
41 |
42 | # and use them to train the model
43 | train(model, data_loader, device, criterion, optimizer, config)
44 |
45 | # save model
46 | model_path = os.path.join(config.save_dir, str(config.model_name), 'checkpoint.pt')
47 | save(model, model_path)
48 |
49 | if verbose:
50 | print(model)
51 | return model
52 |
53 | def make(config):
54 | pretrained = not config.random_init
55 | data_loader, device = load_data(config.cxr_filepath, config.txt_filepath, batch_size=config.batch_size, pretrained=pretrained, column="impression")
56 | model = load_clip(model_path=None, pretrained=pretrained, context_length=config.context_length)
57 | model.to(device)
58 | print('Model on Device.')
59 |
60 | # make the optimizer
61 | criterion = nn.CrossEntropyLoss().cuda()
62 | if config.optimizer == "adam":
63 | optimizer = optim.AdamW(model.parameters(), lr=config.lr)
64 | elif config.optimizer == "sgd":
65 | optimizer = optim.SGD(model.parameters(), lr=config.lr, momentum=config.momentum)
66 | return model, data_loader, device, criterion, optimizer
67 |
68 | def train(model, loader, device, criterion, optimizer, config):
69 | model_save_dir = os.path.join(config.save_dir, config.model_name)
70 | if not os.path.exists(model_save_dir):
71 | # Create a new folder if not exists
72 | os.makedirs(model_save_dir)
73 |
74 | # Run training
75 | total_batches = len(loader) * config.epochs
76 | example_ct = 0 # number of examples seen
77 | batch_ct = 0
78 | report_freq = config.log_interval
79 | highest_val_auc = 0 # save highest mean auc
80 |
81 | for epoch in range(config.epochs):
82 | running_loss = 0.0 # running loss over batch
83 | for data in tqdm(loader):
84 | # get the images
85 | images = data['img']
86 |
87 | texts = data['txt']
88 | texts = preprocess_text(texts, model)
89 |
90 | # perform step for a single batch
91 | loss = train_batch(images, texts, model, device, criterion, optimizer)
92 | example_ct += len(images)
93 | batch_ct += 1
94 | running_loss += loss.item()
95 |
96 | # Report metrics every `report_freq` batch
97 | if (batch_ct % report_freq) == 0:
98 | train_log(running_loss / report_freq, example_ct, epoch)
99 | running_loss = 0.0
100 |
101 | if (batch_ct % config.save_interval) == 0:
102 | model_path = os.path.join(model_save_dir, "checkpoint_{batch_ct}.pt".format(
103 | batch_ct=str(batch_ct),
104 | ))
105 | print("Saved checkpoint to: ", model_path)
106 | save(model, model_path)
107 |
108 | def train_batch(images, texts, model, device, criterion, optimizer):
109 | images, texts = images.to(device), texts.to(device)
110 |
111 | # Forward pass ➡
112 | logits_per_image, logits_per_text = model(images, texts)
113 |
114 | # Create labels
115 | batch_size = images.shape[0]
116 | labels = torch.arange(batch_size).to(device)
117 |
118 | # Compute loss
119 | loss_img = criterion(logits_per_image, labels)
120 | loss_txt = criterion(logits_per_text, labels)
121 | loss = (loss_img + loss_txt)/2 # avg. img and txt loss
122 |
123 | # Backward pass ⬅
124 | optimizer.zero_grad()
125 | loss.backward()
126 |
127 | # Step with optimizer
128 | optimizer.step()
129 |
130 | return loss
131 |
132 | def train_log(loss, example_ct, epoch):
133 | loss = float(loss)
134 |     print(f"Loss after {str(example_ct).zfill(5)} examples: {loss:.3f}")
135 |
136 | def save(model, path):
137 | torch.save(model.state_dict(), path)
138 |
139 | if __name__ == "__main__":
140 | args = parse_args()
141 | model = model_pipeline(args)
142 |
143 |
144 |
--------------------------------------------------------------------------------
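
The same pipeline can be launched programmatically with the defaults that `python run_train.py` would use; this is a sketch that assumes `data/cxr.h5` and `data/mimic_impressions.csv` were produced by `run_preprocess.py` and that a CUDA device is available:

```python
from argparse import Namespace
from run_train import model_pipeline

config = Namespace(
    cxr_filepath="data/cxr.h5",
    txt_filepath="data/mimic_impressions.csv",
    batch_size=16,
    epochs=4,
    lr=1e-4,
    save_interval=100,
    log_interval=10,
    save_dir="checkpoints/",
    seed=1234,
    optimizer="sgd",
    momentum=0.9,
    context_length=77,
    random_init=False,   # False -> start from the pretrained OpenAI ViT-B/32 weights
    model_name="pt-imp",
)
model = model_pipeline(config)  # trains, logs loss, and writes checkpoints/pt-imp/checkpoint_*.pt
```
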
/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2021 OpenAI
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
24 | """
25 | import gzip
26 | import html
27 | import os
28 | from functools import lru_cache
29 |
30 | import ftfy
31 | import regex as re
32 |
33 |
34 | @lru_cache()
35 | def default_bpe():
36 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
37 |
38 |
39 | @lru_cache()
40 | def bytes_to_unicode():
41 | """
42 | Returns list of utf-8 byte and a corresponding list of unicode strings.
43 | The reversible bpe codes work on unicode strings.
44 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
45 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
46 | This is a signficant percentage of your normal, say, 32K bpe vocab.
47 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
48 | And avoids mapping to whitespace/control characters the bpe code barfs on.
49 | """
50 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
51 | cs = bs[:]
52 | n = 0
53 | for b in range(2**8):
54 | if b not in bs:
55 | bs.append(b)
56 | cs.append(2**8+n)
57 | n += 1
58 | cs = [chr(n) for n in cs]
59 | return dict(zip(bs, cs))
60 |
61 |
62 | def get_pairs(word):
63 | """Return set of symbol pairs in a word.
64 | Word is represented as tuple of symbols (symbols being variable-length strings).
65 | """
66 | pairs = set()
67 | prev_char = word[0]
68 | for char in word[1:]:
69 | pairs.add((prev_char, char))
70 | prev_char = char
71 | return pairs
72 |
73 |
74 | def basic_clean(text):
75 | text = ftfy.fix_text(text)
76 | text = html.unescape(html.unescape(text))
77 | return text.strip()
78 |
79 |
80 | def whitespace_clean(text):
81 | text = re.sub(r'\s+', ' ', text)
82 | text = text.strip()
83 | return text
84 |
85 |
86 | class SimpleTokenizer(object):
87 | def __init__(self, bpe_path: str = default_bpe()):
88 | self.byte_encoder = bytes_to_unicode()
89 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
90 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
91 | merges = merges[1:49152-256-2+1]
92 | merges = [tuple(merge.split()) for merge in merges]
93 | vocab = list(bytes_to_unicode().values())
94 |         vocab = vocab + [v+'</w>' for v in vocab]
95 | for merge in merges:
96 | vocab.append(''.join(merge))
97 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
98 | self.encoder = dict(zip(vocab, range(len(vocab))))
99 | self.decoder = {v: k for k, v in self.encoder.items()}
100 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
101 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
102 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
103 |
104 | def bpe(self, token):
105 | if token in self.cache:
106 | return self.cache[token]
107 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
108 | pairs = get_pairs(word)
109 |
110 | if not pairs:
111 |             return token+'</w>'
112 |
113 | while True:
114 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
115 | if bigram not in self.bpe_ranks:
116 | break
117 | first, second = bigram
118 | new_word = []
119 | i = 0
120 | while i < len(word):
121 | try:
122 | j = word.index(first, i)
123 | new_word.extend(word[i:j])
124 | i = j
125 | except:
126 | new_word.extend(word[i:])
127 | break
128 |
129 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
130 | new_word.append(first+second)
131 | i += 2
132 | else:
133 | new_word.append(word[i])
134 | i += 1
135 | new_word = tuple(new_word)
136 | word = new_word
137 | if len(word) == 1:
138 | break
139 | else:
140 | pairs = get_pairs(word)
141 | word = ' '.join(word)
142 | self.cache[token] = word
143 | return word
144 |
145 | def encode(self, text):
146 | bpe_tokens = []
147 | text = whitespace_clean(basic_clean(text)).lower()
148 | for token in re.findall(self.pat, text):
149 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
150 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
151 | return bpe_tokens
152 |
153 | def decode(self, tokens):
154 | text = ''.join([self.decoder[token] for token in tokens])
155 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
156 | return text
157 |
--------------------------------------------------------------------------------
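
A quick round trip through the tokenizer; the impression text is a made-up example, and `bpe_simple_vocab_16e6.txt.gz` must sit next to `simple_tokenizer.py` (which `default_bpe()` assumes):

```python
from simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()
ids = tokenizer.encode("No acute cardiopulmonary process.")
print(ids)                    # list of BPE token ids
print(tokenizer.decode(ids))  # approximately recovers the cleaned, lower-cased text
```
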
/eval.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import numpy as np
3 | import os
4 | import pandas as pd
5 | from PIL import Image
6 | import h5py
7 | import matplotlib.pyplot as plt
8 | from typing import List, Callable
9 |
10 | import torch
11 | from torch.utils import data
12 | from tqdm.notebook import tqdm
13 | import torch.nn as nn
14 | from torchvision.transforms import Compose, Normalize, Resize
15 |
16 | import sklearn
17 | from sklearn.metrics import matthews_corrcoef, confusion_matrix, accuracy_score, auc, roc_auc_score, roc_curve, classification_report
18 | from sklearn.metrics import precision_recall_curve, f1_score
19 | from sklearn.metrics import average_precision_score
20 | from sklearn.utils import resample
21 |
22 | import scipy
23 | import scipy.stats
24 |
25 | import sys
26 | sys.path.append('../..')
27 |
28 | import clip
29 | from model import CLIP
30 |
31 | def compute_mean(stats, is_df=True):
32 | spec_labels = ["Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Pleural Effusion"]
33 | if is_df:
34 | spec_df = stats[spec_labels]
35 | res = np.mean(spec_df.iloc[0])
36 | else:
37 | # cis is df, within bootstrap
38 | vals = [stats[spec_label][0] for spec_label in spec_labels]
39 | res = np.mean(vals)
40 | return res
41 |
42 | def accuracy(output, target, topk=(1,)):
43 | pred = output.topk(max(topk), 1, True, True)[1].t()
44 | print('pred: ', pred)
45 |
46 | expand = target.expand(-1, max(topk))
47 | print('expand: ', expand)
48 |
49 | correct = pred.eq(expand)
50 | print('correct: ', correct)
51 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
52 |
53 | def sigmoid(x):
54 | z = 1/(1 + np.exp(-x))
55 | return z
56 |
57 | ''' ROC CURVE '''
58 | def plot_roc(y_pred, y_true, roc_name, plot=False):
59 | # given the test_ground_truth, and test_predictions
60 | fpr, tpr, thresholds = roc_curve(y_true, y_pred)
61 |
62 | roc_auc = auc(fpr, tpr)
63 |
64 | if plot:
65 | plt.figure(dpi=100)
66 | plt.title(roc_name)
67 | plt.plot(fpr, tpr, 'b', label = 'AUC = %0.2f' % roc_auc)
68 | plt.legend(loc = 'lower right')
69 | plt.plot([0, 1], [0, 1],'r--')
70 | plt.xlim([0, 1])
71 | plt.ylim([0, 1])
72 | plt.ylabel('True Positive Rate')
73 | plt.xlabel('False Positive Rate')
74 | plt.show()
75 | return fpr, tpr, thresholds, roc_auc
76 |
77 | # J = TP/(TP+FN) + TN/(TN+FP) - 1 = tpr - fpr
78 | def choose_operating_point(fpr, tpr, thresholds):
79 | sens = 0
80 | spec = 0
81 | J = 0
82 | for _fpr, _tpr in zip(fpr, tpr):
83 | if _tpr - _fpr > J:
84 | sens = _tpr
85 | spec = 1-_fpr
86 | J = _tpr - _fpr
87 | return sens, spec
88 |
89 | ''' PRECISION-RECALL CURVE '''
90 | def plot_pr(y_pred, y_true, pr_name, plot=False):
91 | precision, recall, thresholds = precision_recall_curve(y_true, y_pred)
92 | pr_auc = auc(recall, precision)
93 | # plot the precision-recall curves
94 | baseline = len(y_true[y_true==1]) / len(y_true)
95 |
96 | if plot:
97 | plt.figure(dpi=20)
98 | plt.title(pr_name)
99 | plt.plot(recall, precision, 'b', label='AUC = %0.2f' % pr_auc)
100 | # axis labels
101 | plt.legend(loc = 'lower right')
102 | plt.plot([0, 1], [baseline, baseline],'r--')
103 | plt.xlim([0, 1])
104 | plt.ylim([0, 1])
105 | plt.xlabel('Recall')
106 | plt.ylabel('Precision')
107 | # show the plot
108 | plt.show()
109 | return precision, recall, thresholds
110 |
111 | def evaluate(y_pred, y_true, cxr_labels,
112 | roc_name='Receiver Operating Characteristic', pr_name='Precision-Recall Curve', label_idx_map=None):
113 |
114 | '''
115 | We expect `y_pred` and `y_true` to be numpy arrays, both of shape (num_samples, num_classes)
116 |
117 | `y_pred` is a numpy array consisting of probability scores with all values in range 0-1.
118 |
119 | `y_true` is a numpy array consisting of binary values representing if a class is present in
120 | the cxr.
121 |
122 | This function provides all relevant evaluation information, ROC, AUROC, Sensitivity, Specificity,
123 | PR-Curve, Precision, Recall for each class.
124 | '''
125 | import warnings
126 | warnings.filterwarnings('ignore')
127 |
128 | num_classes = y_pred.shape[-1] # number of total labels
129 |
130 | dataframes = []
131 | for i in range(num_classes):
132 | # print('{}.'.format(cxr_labels[i]))
133 |
134 | if label_idx_map is None:
135 | y_pred_i = y_pred[:, i] # (num_samples,)
136 | y_true_i = y_true[:, i] # (num_samples,)
137 |
138 | else:
139 | y_pred_i = y_pred[:, i] # (num_samples,)
140 |
141 | true_index = label_idx_map[cxr_labels[i]]
142 | y_true_i = y_true[:, true_index] # (num_samples,)
143 |
144 | cxr_label = cxr_labels[i]
145 |
146 | ''' ROC CURVE '''
147 | roc_name = cxr_label + ' ROC Curve'
148 | fpr, tpr, thresholds, roc_auc = plot_roc(y_pred_i, y_true_i, roc_name)
149 |
150 | sens, spec = choose_operating_point(fpr, tpr, thresholds)
151 |
152 | results = [[roc_auc]]
153 | df = pd.DataFrame(results, columns=[cxr_label+'_auc'])
154 | dataframes.append(df)
155 |
156 | ''' PRECISION-RECALL CURVE '''
157 | pr_name = cxr_label + ' Precision-Recall Curve'
158 | precision, recall, thresholds = plot_pr(y_pred_i, y_true_i, pr_name)
159 |
160 | dfs = pd.concat(dataframes, axis=1)
161 | return dfs
162 |
163 | ''' Bootstrap and Confidence Intervals '''
164 | def compute_cis(data, confidence_level=0.05):
165 | """
166 | FUNCTION: compute_cis
167 | ------------------------------------------------------
168 | Given a Pandas dataframe of (n, labels), return another
169 | Pandas dataframe that is (3, labels).
170 |
171 |     Each row is the mean, lower bound, and upper bound of a
172 |     confidence interval at the given `confidence_level`.
173 |
174 | Args:
175 | * data - Pandas Dataframe, of shape (num_bootstrap_samples, num_labels)
176 | * confidence_level (optional) - confidence level of interval
177 |
178 | Returns:
179 | * Pandas Dataframe, of shape (3, labels), representing mean, lower, upper
180 | """
181 | data_columns = list(data)
182 | intervals = []
183 | for i in data_columns:
184 | series = data[i]
185 | sorted_perfs = series.sort_values()
186 | lower_index = int(confidence_level/2 * len(sorted_perfs)) - 1
187 | upper_index = int((1 - confidence_level/2) * len(sorted_perfs)) - 1
188 | lower = sorted_perfs.iloc[lower_index].round(4)
189 | upper = sorted_perfs.iloc[upper_index].round(4)
190 | mean = round(sorted_perfs.mean(), 4)
191 | interval = pd.DataFrame({i : [mean, lower, upper]})
192 | intervals.append(interval)
193 | intervals_df = pd.concat(intervals, axis=1)
194 | intervals_df.index = ['mean', 'lower', 'upper']
195 | return intervals_df
196 |
197 | def bootstrap(y_pred, y_true, cxr_labels, n_samples=1000, label_idx_map=None):
198 | '''
199 | This function will randomly sample with replacement
200 | from y_pred and y_true then evaluate `n` times
201 | and obtain AUROC scores for each.
202 |
203 | You can specify the number of samples that should be
204 | used with the `n_samples` parameter.
205 |
206 | Confidence intervals will be generated from each
207 | of the samples.
208 |
209 | Note:
210 | * n_total_labels >= n_cxr_labels
211 | `n_total_labels` is greater iff alternative labels are being tested
212 | '''
213 | np.random.seed(97)
214 | y_pred # (500, n_total_labels)
215 | y_true # (500, n_cxr_labels)
216 |
217 | idx = np.arange(len(y_true))
218 |
219 | boot_stats = []
220 | for i in tqdm(range(n_samples)):
221 | sample = resample(idx, replace=True, random_state=i)
222 | y_pred_sample = y_pred[sample]
223 | y_true_sample = y_true[sample]
224 |
225 | sample_stats = evaluate(y_pred_sample, y_true_sample, cxr_labels, label_idx_map=label_idx_map)
226 | boot_stats.append(sample_stats)
227 |
228 | boot_stats = pd.concat(boot_stats) # pandas array of evaluations for each sample
229 | return boot_stats, compute_cis(boot_stats)
230 |
--------------------------------------------------------------------------------
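
A minimal sketch of `evaluate` and `bootstrap` on synthetic predictions; in real use `y_pred` would be the model's probability outputs and `y_true` the CheXpert test-set labels, both of shape (num_samples, num_classes):

```python
import numpy as np
from eval import evaluate, bootstrap

cxr_labels = ["Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Pleural Effusion"]
rng = np.random.default_rng(0)
y_pred = rng.random((500, len(cxr_labels)))               # probability scores in [0, 1]
y_true = rng.integers(0, 2, size=(500, len(cxr_labels)))  # binary ground-truth labels

per_label_auc = evaluate(y_pred, y_true, cxr_labels)      # one-row DataFrame, columns like 'Atelectasis_auc'
print(per_label_auc)

boot_stats, cis = bootstrap(y_pred, y_true, cxr_labels, n_samples=10)
print(cis)                                                # mean / lower / upper AUROC per label
```
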
/clip.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2021 OpenAI
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
24 | """
25 |
26 | import hashlib
27 | import os
28 | import urllib
29 | import warnings
30 | from typing import Union, List
31 |
32 | import torch
33 | from PIL import Image
34 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
35 | from tqdm import tqdm
36 |
37 | from model import build_model
38 | from simple_tokenizer import SimpleTokenizer as _Tokenizer
39 |
40 | __all__ = ["available_models", "load", "tokenize"]
41 | _tokenizer = _Tokenizer()
42 |
43 | _MODELS = {
44 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
45 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
46 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
47 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
48 | }
49 |
50 |
51 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
52 | os.makedirs(root, exist_ok=True)
53 | filename = os.path.basename(url)
54 |
55 | expected_sha256 = url.split("/")[-2]
56 | download_target = os.path.join(root, filename)
57 |
58 | if os.path.exists(download_target) and not os.path.isfile(download_target):
59 | raise RuntimeError(f"{download_target} exists and is not a regular file")
60 |
61 | if os.path.isfile(download_target):
62 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
63 | return download_target
64 | else:
65 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
66 |
67 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
68 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
69 | while True:
70 | buffer = source.read(8192)
71 | if not buffer:
72 | break
73 |
74 | output.write(buffer)
75 | loop.update(len(buffer))
76 |
77 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
78 |         raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
79 |
80 | return download_target
81 |
82 |
83 | def _transform(n_px):
84 | return Compose([
85 | Resize(n_px, interpolation=Image.BICUBIC),
86 | CenterCrop(n_px),
87 | lambda image: image.convert("RGB"),
88 | ToTensor(),
89 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
90 | ])
91 |
92 |
93 | def available_models() -> List[str]:
94 | """Returns the names of available CLIP models"""
95 | return list(_MODELS.keys())
96 |
97 |
98 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
99 | """Load a CLIP model
100 |
101 | Parameters
102 | ----------
103 | name : str
104 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
105 |
106 | device : Union[str, torch.device]
107 | The device to put the loaded model
108 |
109 | jit : bool
110 | Whether to load the optimized JIT model (default) or more hackable non-JIT model.
111 |
112 | Returns
113 | -------
114 | model : torch.nn.Module
115 | The CLIP model
116 |
117 | preprocess : Callable[[PIL.Image], torch.Tensor]
118 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
119 | """
120 | if name in _MODELS:
121 | model_path = _download(_MODELS[name])
122 | elif os.path.isfile(name):
123 | model_path = name
124 | else:
125 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
126 |
127 | try:
128 | # loading JIT archive
129 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
130 | state_dict = None
131 | except RuntimeError:
132 | # loading saved state dict
133 | if jit:
134 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
135 | jit = False
136 | state_dict = torch.load(model_path, map_location="cpu")
137 |
138 | if not jit:
139 | model = build_model(state_dict or model.state_dict()).to(device)
140 | if str(device) == "cpu":
141 | model.float()
142 | return model, _transform(model.visual.input_resolution)
143 |
144 | # patch the device names
145 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
146 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
147 |
148 | def patch_device(module):
149 | graphs = [module.graph] if hasattr(module, "graph") else []
150 | if hasattr(module, "forward1"):
151 | graphs.append(module.forward1.graph)
152 |
153 | for graph in graphs:
154 | for node in graph.findAllNodes("prim::Constant"):
155 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
156 | node.copyAttributes(device_node)
157 |
158 | model.apply(patch_device)
159 | patch_device(model.encode_image)
160 | patch_device(model.encode_text)
161 |
162 | # patch dtype to float32 on CPU
163 | if str(device) == "cpu":
164 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
165 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
166 | float_node = float_input.node()
167 |
168 | def patch_float(module):
169 | graphs = [module.graph] if hasattr(module, "graph") else []
170 | if hasattr(module, "forward1"):
171 | graphs.append(module.forward1.graph)
172 |
173 | for graph in graphs:
174 | for node in graph.findAllNodes("aten::to"):
175 | inputs = list(node.inputs())
176 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
177 | if inputs[i].node()["value"] == 5:
178 | inputs[i].node().copyAttributes(float_node)
179 |
180 | model.apply(patch_float)
181 | patch_float(model.encode_image)
182 | patch_float(model.encode_text)
183 |
184 | model.float()
185 |
186 | return model, _transform(model.input_resolution.item())
187 |
188 |
189 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
190 | """
191 | Returns the tokenized representation of given input string(s)
192 |
193 | Parameters
194 | ----------
195 | texts : Union[str, List[str]]
196 | An input string or a list of input strings to tokenize
197 |
198 | context_length : int
199 | The context length to use; all CLIP models use 77 as the context length
200 |
201 | Returns
202 | -------
203 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
204 | """
205 | if isinstance(texts, str):
206 | texts = [texts]
207 |
208 | sot_token = _tokenizer.encoder["<|startoftext|>"]
209 | eot_token = _tokenizer.encoder["<|endoftext|>"]
210 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
211 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
212 |
213 | for i, tokens in enumerate(all_tokens):
214 | if len(tokens) > context_length:
215 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
216 | result[i, :len(tokens)] = torch.tensor(tokens)
217 |
218 | return result
219 |
--------------------------------------------------------------------------------
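
A minimal sketch of loading the OpenAI ViT-B/32 backbone and tokenizing text through this module; the checkpoint is downloaded to `~/.cache/clip` on first use, the prompt strings are illustrative only, and the non-JIT path relies on `build_model` from `model.py`:

```python
import torch
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

texts = clip.tokenize(["pneumonia", "no pneumonia"])  # LongTensor of shape (2, 77)
with torch.no_grad():
    text_features = model.encode_text(texts.to(device))
print(text_features.shape)
```
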
/preprocess_padchest.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import numpy as np
3 | import os
4 | import pandas as pd
5 | from PIL import Image
6 | import h5py
7 | import matplotlib.pyplot as plt
8 | from typing import List
9 |
10 | import torch
11 | from torch.utils import data
12 | from tqdm.notebook import tqdm
13 | import torch.nn as nn
14 | from torchvision.transforms import Compose, Normalize
15 |
16 | import sklearn
17 | from sklearn.metrics import confusion_matrix, accuracy_score, auc, roc_auc_score, roc_curve, classification_report
18 | from sklearn.metrics import precision_recall_curve, f1_score
19 | from sklearn.metrics import average_precision_score
20 |
21 | import sys
22 | sys.path.append('../..')
23 | sys.path.append('../data-process')
24 | sys.path.append('data/padchest')
25 |
26 | from data_process import *
27 |
28 |
29 |
30 | def preprocess_data(data_root):
31 | labels_path = os.path.join(data_root,
32 | 'PADCHEST_chest_x_ray_images_labels_160K_01.02.19.csv')
33 | labels = pd.read_csv(labels_path)
34 | # get filepaths of 2.zip images
35 | text_file_path = os.path.join(data_root, '2.zip.unzip-l.txt')
36 | image_paths = extract_filenames(text_file_path)
37 | labels_2_df = labels[labels['ImageID'].isin(image_paths)]
38 | unique_labels = get_unique_labels(labels_2_df)
39 | # multi hot encoding for labels
40 | df_lab = create_multi_hot_labels(labels_2_df, unique_labels)
41 |
42 | loc_2_df = labels[labels['ImageID'].isin(image_paths)]
43 | loc_col_2 = loc_2_df.loc[:, "Labels"]
44 | # multihot encoding for localizations
45 | unique_loc = get_unique_labels(loc_2_df, column="Labels")
46 | df_loc = create_multi_hot_labels(loc_2_df, unique_loc, column="Labels")
47 | directory = 'data/padchest/images/'
48 | cxr_paths = get_paths(directory)
49 |     write_h5(cxr_paths, df_lab)
50 | unique_labels = np.load('unique_labels.npy')
51 | return unique_labels[0:1]
52 |
53 | def extract_filenames(txt_path):
54 | """
55 | Given a filepath to a txt file with image file names,
56 | extract a list of filenames for this zip.
57 |
58 | Assume that the txt file has two unnecessary lines at
59 | both the top and the bottom of the file.
60 | """
61 | df = pd.read_csv(txt_path)
62 | df_list = df.values.tolist()
63 | df_list = df_list[2:-2]
64 |
65 | images_list = []
66 | for file in df_list:
67 | parsed_filename = file[0].split()[-1]
68 | images_list.append(parsed_filename)
69 | return images_list
70 |
71 | # get paths of all possible labels
72 | def get_unique_labels(labels_df, column='Labels'):
73 | """
74 | Given labels_df, return a list containing all unique labels
75 | present in this dataset.
76 | """
77 |
78 | unique_labels = set()
79 | # iterate through all rows in the dataframe
80 | for index, row in labels_df.iterrows():
81 | labels = row[column]
82 | try:
83 | # convert labels str to array
84 | labels_arr = labels.strip('][').split(', ')
85 | for label in labels_arr:
86 | # process string
87 | processed_label = label.split("'")[1].strip()
88 | processed_label = processed_label.lower()
89 | unique_labels.add(processed_label)
90 | except:
91 | continue
92 |
93 | return list(unique_labels)
94 |
95 | def create_multi_hot_labels(labels_df, unique_labels_list, column='Labels'):
96 | """
97 | Args:
98 | * labels_df: original df where labels are an arr
99 | * labels_list: list of all possible labels in respective order
100 |
101 | Given all entries and it's corresponding labels, create a one(multi)-hot vector
102 | where a 1 represents the presence of that disease.
103 |
104 | Returns a Pandas dataframe mapping filename to it's multi-hot representation. Each of the diseases
105 | are columns.
106 | """
107 |
108 | # todo: check how the labels are represented for CheXpert
109 | # create a pandas datafraame with columns as unique labels, start with list of dicts
110 | dict_list = []
111 |
112 | # iterate through all rows in the dataframe
113 | for index, row in labels_df.iterrows():
114 | labels = row[column]
115 | try:
116 | # convert labels str to array
117 | labels_arr = labels.strip('][').split(', ')
118 | # print(labels_arr, len(labels_arr))
119 |
120 | count_dict = dict() # map label name to count
121 | count_dict['ImageID'] = row['ImageID']
122 | # init count dict with 0s
123 | for unq_label in unique_labels_list:
124 | count_dict[unq_label] = 0
125 |
126 | if len(labels_arr) > 0 and labels_arr[0] != '':
127 | for label in labels_arr:
128 | # process string
129 | processed_label = label.split("'")[1].strip()
130 | processed_label = processed_label.lower()
131 | count_dict[processed_label] = 1
132 |
133 | dict_list.append(count_dict)
134 | except:
135 | print("error when creating labels for this img.")
136 | continue
137 |
138 | multi_hot_labels_df = pd.DataFrame(dict_list, columns=(['ImageID'] + unique_labels_list))
139 | return multi_hot_labels_df
140 |
141 | # convert folder of images to h5 file
142 | def get_paths(directory):
143 | """
144 | Given a directory, this function outputs
145 | all the image paths in that directory as a
146 | list.
147 | """
148 | paths_list = []
149 | for filename in os.listdir(directory):
150 | if filename.endswith(".png"):
151 | paths_list.append(os.path.join(directory, filename))
152 | else:
153 | continue
154 | return paths_list
155 |
156 | def img_to_h5(
157 | cxr_paths: List[str],
158 | out_filepath: str,
159 | resolution: int = 320,
160 | ) -> List[str]:
161 | """
162 | Converts a set of images into a single `.h5` file.
163 |
164 | Args:
165 | cxr_paths: List of paths to images as `.png`
166 | out_filepath: Path to store h5 file
167 | resolution: image resolution
168 |
169 | Returns a list of cxr_paths that were successfully stored in the
170 | `.h5` file.
171 | """
172 | dset_size = len(cxr_paths)
173 | proper_cxr_paths = []
174 | with h5py.File(out_filepath,'w') as h5f:
175 | img_dset = h5f.create_dataset('cxr', shape=(dset_size, resolution, resolution))
176 |
177 | ctr = 0
178 | for idx, path in enumerate(tqdm(cxr_paths)):
179 | try:
180 | # read image using cv2
181 | img = cv2.imread(path)
182 | # convert to PIL Image object
183 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
184 | img_pil = Image.fromarray(img)
185 | # preprocess
186 | img = preprocess(img_pil, desired_size=resolution)
187 | img_dset[ctr] = img
188 | ctr += 1
189 | proper_cxr_paths.append(path)
190 | except:
191 | print(f"Image {ctr} failed loading...")
192 | continue
193 | print(h5f)
194 |
195 | return proper_cxr_paths
196 |
197 | def write_h5(cxr_paths, df_lab, resolution: int = 320):
198 | out_filepath = 'data/padchest/images/2_cxr_dset_sample.h5'
199 | dset_size = len(cxr_paths)
200 |
201 | proper_cxr_paths = []
202 | with h5py.File(out_filepath,'w') as h5f:
203 |         img_dset = h5f.create_dataset('cxr', shape=(dset_size, resolution, resolution))
204 | # print('Dataset initialized.')
205 |
206 | ctr = 0
207 | for idx, path in enumerate(tqdm(cxr_paths)):
208 | try:
209 | # read image using cv2
210 | img = cv2.imread(path)
211 | # convert to PIL Image object
212 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
213 | img_pil = Image.fromarray(img)
214 | # preprocess
215 | img = preprocess(img_pil, desired_size=resolution)
216 | plt.imshow(img)
217 | img_dset[ctr] = img
218 | ctr += 1
219 | proper_cxr_paths.append(path)
220 | except:
221 | print("failed!")
222 | continue
223 | print(h5f)
224 | np.save("proper_cxr_paths.npy", np.array(proper_cxr_paths))
225 | out_filepath = 'data/padchest/images/2_cxr.h5'
226 | img_to_hdf5(cxr_paths, out_filepath, resolution=320)
227 | df_labels_new = order_labels(df_lab, proper_cxr_paths)
228 | labels_path = 'data/padchest/2_cxr_labels.csv'
229 | df_labels_new.to_csv(labels_path)
230 |
231 | def order_labels(df, cxr_paths):
232 | """
233 | Fixes multi-hot labels to be in order of cxr_paths
234 | """
235 | df_new = pd.DataFrame(columns=df.columns)
236 | for path in cxr_paths:
237 | imageId = path.split('/')[-1]
238 | row = df.loc[df['ImageID'] == imageId]
239 | df_new = df_new.append(row)
240 | return df_new
241 |
--------------------------------------------------------------------------------
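
A toy example of the label-encoding helpers; the ImageIDs and label names are made up, but the `Labels` column mimics the stringified-list format the parsers above expect:

```python
import pandas as pd
from preprocess_padchest import get_unique_labels, create_multi_hot_labels

labels_df = pd.DataFrame({
    "ImageID": ["img_001.png", "img_002.png"],
    "Labels": ["['pneumonia', 'cardiomegaly']", "['normal']"],
})

unique_labels = get_unique_labels(labels_df)                   # e.g. ['pneumonia', 'cardiomegaly', 'normal']
multi_hot = create_multi_hot_labels(labels_df, unique_labels)  # one row per ImageID, one 0/1 column per label
print(multi_hot)
```
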
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import matplotlib.pyplot as plt
6 | from tqdm.notebook import tqdm
7 |
8 | from PIL import Image
9 | import h5py
10 |
11 | import torch
12 | from torch.utils import data
13 | from torch import nn
14 | import torch.optim as optim
15 | from torchvision.transforms import Compose, Normalize, Resize, InterpolationMode
16 |
17 | import sys
18 | sys.path.append('../..')
19 |
20 | import clip
21 | from model import CLIP
22 | from simple_tokenizer import SimpleTokenizer
23 |
24 | class CXRDataset(data.Dataset):
25 |     """Dataset pairing chest X-ray images stored in an HDF5 file with the
26 |     corresponding report text from a CSV file.
27 |
28 |     Input params:
29 |         img_path: Path to the HDF5 file containing the 'cxr' image dataset.
30 |         txt_path: Path to the CSV file containing the report text.
31 |         column: Name of the CSV column holding the text (default='report').
32 |         size: If given, only the first `size` image/text pairs are loaded
33 |             into memory; otherwise the full dataset is used.
34 |         transform: PyTorch transform to apply to every image (default=None).
35 |     """
36 | def __init__(self, img_path, txt_path, column='report', size=None, transform=None):
37 | super().__init__()
38 | if size != None:
39 | self.img_dset = h5py.File(img_path, 'r')['cxr'][:size]
40 | self.txt_dset = pd.read_csv(txt_path)[column][:size]
41 | else:
42 | self.img_dset = h5py.File(img_path, 'r')['cxr']
43 | self.txt_dset = pd.read_csv(txt_path)[column]
44 | self.transform = transform
45 |
46 | def __len__(self):
47 | return len(self.txt_dset)
48 |
49 | def __getitem__(self, idx):
50 | if torch.is_tensor(idx):
51 | idx = idx.tolist()
52 |
53 | img = self.img_dset[idx] # np array, (320, 320)
54 | img = np.expand_dims(img, axis=0)
55 | img = np.repeat(img, 3, axis=0)
56 | txt = self.txt_dset[idx] # python str
57 | if type(txt) == type(float("nan")): # capture the case of empty "Impression" sections
58 | txt = " "
59 |
60 | img = torch.from_numpy(img) # torch, (3, 320, 320)
61 | if self.transform:
62 | img = self.transform(img)
63 | sample = {'img': img, 'txt': txt }
64 |
65 | return sample
66 |
67 | def load_data(cxr_filepath, txt_filepath, batch_size=4, column='report', pretrained=False, verbose=False):
68 | if torch.cuda.is_available():
69 | dev = "cuda:0"
70 | cuda_available = True
71 | print('Using CUDA.')
72 | else:
73 | dev = "cpu"
74 | cuda_available = False
75 | print('Using cpu.')
76 |
77 | device = torch.device(dev)
78 |
79 | if cuda_available:
80 | torch.cuda.set_device(device)
81 |
82 | if pretrained:
83 | input_resolution = 224
84 | transform = Compose([
85 | Normalize((101.48761, 101.48761, 101.48761), (83.43944, 83.43944, 83.43944)),
86 | Resize(input_resolution, interpolation=InterpolationMode.BICUBIC),
87 | ])
88 | print('Interpolation Mode: ', InterpolationMode.BICUBIC)
89 | print("Finished image transforms for pretrained model.")
90 | else:
91 | input_resolution = 320
92 | transform = Compose([
93 | Normalize((101.48761, 101.48761, 101.48761), (83.43944, 83.43944, 83.43944)),
94 | ])
95 | print("Finished image transforms for clip model.")
96 |
97 | torch_dset = CXRDataset(img_path=cxr_filepath,
98 | txt_path=txt_filepath, column=column, transform=transform)
99 |
100 | if verbose:
101 | for i in range(len(torch_dset)):
102 | sample = torch_dset[i]
103 | plt.imshow(sample['img'][0])
104 | plt.show()
105 | print(i, sample['img'].size(), sample['txt'])
106 | if i == 3:
107 | break
108 |
109 | loader_params = {'batch_size':batch_size, 'shuffle': True, 'num_workers': 0}
110 | data_loader = data.DataLoader(torch_dset, **loader_params)
111 | return data_loader, device
112 |
113 | def load_clip(model_path=None, pretrained=False, context_length=77):
114 | '''
115 | FUNCTION: load_clip
116 | -------------------------------
117 | This function loads in a model with the CLIP model
118 | architecture.
119 |
120 | args:
121 | * model_path (optional) - path to model weights that the model
122 | will be initialized with
123 | * pretrained (optional) - if True, will load the pretrained
124 | CLIP model
125 | * context_length (optional) - length of the maximum number of
126 | tokens that can be inputted into the CLIP model
127 | '''
128 |
129 | params = {
130 | 'embed_dim':768,
131 | 'image_resolution': 320,
132 | 'vision_layers': 12,
133 | 'vision_width': 768,
134 | 'vision_patch_size': 16,
135 | 'context_length': context_length,
136 | 'vocab_size': 49408,
137 | 'transformer_width': 512,
138 | 'transformer_heads': 8,
139 | 'transformer_layers': 12
140 | }
141 |
142 | # set device
143 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
144 |
145 | if pretrained:
146 | # load clip pre-trained model
147 | model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
148 | print("Loaded in pretrained model.")
149 | else:
150 | model = CLIP(**params)
151 | print("Loaded in clip model.")
152 |
153 | # if a model_path is provided, load in weights to backbone
154 | if model_path != None:
155 | model.load_state_dict(torch.load(model_path, map_location=device))
156 | return model
157 |
158 |
159 | def preprocess_text(texts, model):
160 | # if model.context_length is None:
161 | # model = model.module
162 |
163 | _tokenizer = SimpleTokenizer()
164 | sot_token = _tokenizer.encoder["<|startoftext|>"]
165 | eot_token = _tokenizer.encoder["<|endoftext|>"]
166 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
167 | result = torch.zeros(len(all_tokens), model.context_length, dtype=torch.long)
168 |
169 | for i, tokens in enumerate(all_tokens):
170 | if len(tokens) > model.context_length:
171 | tokens = tokens[:model.context_length]
172 | tokens[model.context_length - 1] = eot_token
173 | result[i, :len(tokens)] = torch.tensor(tokens)
174 | return result
175 |
176 | def make(config, cxr_filepath, txt_filepath, model_path=None):
177 | '''
178 | FUNCTION: make
179 | ---------------------------------
180 | This function makes the model, the data loader, loss and optimizer.
181 |
182 | args:
183 | * config - dict, configuration of experiment
184 | * cxr_filepath - string, filepath to chest x-ray images
185 | * txt_filepath - string, filepath to corresponding text reports
186 | * model_path - string, filepath to previously trained model
187 | '''
188 | data_loader, device = load_data(cxr_filepath, txt_filepath, batch_size=config.batch_size, pretrained=config.pretrained, column=config.column)
189 | model = load_clip(model_path=model_path, pretrained=config.pretrained, context_length=config.context_length)
190 | model.to(device)
191 | print('Model on Device.')
192 |
193 | # make the optimizer
194 | criterion = nn.CrossEntropyLoss().cuda()
195 | # todo: incorporate - torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False)
196 | optimizer = optim.AdamW(model.parameters(), lr=config.lr)
197 | return model, data_loader, device, criterion, optimizer
198 |
199 |
200 | def train_main(cxr_filepath, txt_filepath, hyperparams, output_path, model_path=None, pretrained=False):
201 | '''
202 | args:
203 |     * cxr_filepath - str filepath to cxr images
204 | * txt_filepath- str filepath to text reports
205 | * hyperparams- dictionary with the following hyperparams:
206 | `batch_size`, `criterion`, `learning_rate`, `momentum`, `epochs`
207 | * output_path- str filepath to where the trained model will be saved
208 | * model_path- str filepath to model that will be used as baseline model for training.
209 | If not provided, a model will be trained from scratch
210 | * pretrained- whether or not the clip model was pretrained with generic images
211 | This function is the main train function for CXR-CLIP.
212 | '''
213 |
214 | # unpack `hyperparams`
215 | batch_size = hyperparams['batch_size']
216 | criterion = hyperparams['criterion']
217 | learning_rate = hyperparams['learning_rate']
218 | momentum = hyperparams['momentum']
219 | epochs = hyperparams['epochs']
220 |
221 | # load input cxr + report data
222 | data_loader, device = load_data(cxr_filepath, txt_filepath, batch_size=batch_size, pretrained=pretrained)
223 | model = load_clip(model_path=model_path, pretrained=pretrained)
224 |
225 | optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
226 | train_clip(model, data_loader, device, criterion, optimizer, epochs, output_path)
227 | return model
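# Example (illustrative): the values below are placeholders, not recommended settings.
#   hyperparams = {
#       'batch_size': 32,
#       'criterion': nn.CrossEntropyLoss(),
#       'learning_rate': 1e-4,
#       'momentum': 0.9,
#       'epochs': 5,
#   }
#   train_main("./data/cxr.h5", "data/mimic_impressions.csv", hyperparams,
#              output_path="./checkpoints/", pretrained=True)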
228 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Expert-level detection of pathologies from unannotated chest X-ray images via self-supervised learning
2 |
3 |
4 |
5 | Expert-level detection of pathologies from unannotated chest X-ray images via self-supervised learning, Nat. Biomed. Eng (2022). [[Paper]](https://doi.org/10.1038/s41551-022-00936-9)
6 | 
7 | Ekin Tiu, Ellie Talius, Pujan Patel, Curtis P. Langlotz, Andrew Y. Ng, Pranav Rajpurkar
8 | 
9 | 
10 | ```bash
11 | Tiu, E., Talius, E., Patel, P. et al. Expert-level detection of pathologies from unannotated chest X-ray images via self-supervised learning. Nat. Biomed. Eng (2022). https://doi.org/10.1038/s41551-022-00936-9
12 | ```
13 |
14 |
15 |
16 |
17 | This repository contains code to train a self-supervised learning model on chest X-ray images that lack explicit annotations and to evaluate the model's performance on pathology-classification tasks.
18 |
19 |
20 |
21 | ## Main Findings
22 |
23 |
24 | 1. **Automatically detecting pathologies in chest x-rays without explicit annotations:** Our method learns directly from the combination of images and unstructured radiology reports, thereby avoiding time-consuming labeling efforts. Our deep learning method is capable of predicting multiple pathologies and differential diagnoses that it had not explicitly seen during training.
25 | 2. **Matching radiologist performance on different tasks on an external test set:** Our method performed on par with human performance when evaluated on an external validation set (CheXpert) of chest x-ray images labeled for the presence of 14 different conditions by multiple radiologists.
26 | 3. **Outperforming approaches that train on explicitly labeled data on an external test set:** Using no labels, we outperformed a fully supervised approach (100% of labels) on 3 out of the 8 selected pathologies on a dataset (PadChest) collected in a different country. We further demonstrated high performance (AUC > 0.9) on 14 findings and at least 0.700 on 53 findings out of 107 radiographic findings that the method had not seen during training.
27 |
28 |
29 |
30 | ## Dependencies
31 | To clone all files:
32 |
33 | ```git clone https://github.com/rajpurkarlab/CheXzero.git```
34 |
35 | To install Python dependencies:
36 |
37 | ```pip install -r requirements.txt```
38 |
39 | ## Data
40 | ### Training Dataset
41 | 1. Download images from [MIMIC-CXR-JPG](https://physionet.org/content/mimic-cxr-jpg/2.0.0/) and reports from the [MIMIC-CXR Database](https://physionet.org/content/mimic-cxr/2.0.0/). Note: in order to gain access to the data, you must be a credentialed user as defined on [PhysioNet](https://physionet.org/settings/credentialing/).
42 | 2. Copy the dataset into the `data/` directory.
43 | 3. Run `python run_preprocess.py`
44 | 4. This preprocesses the chest X-ray images into a Hierarchical Data Format (HDF5) file used for training, stored at `data/cxr.h5`, and extracts the impressions section of each corresponding radiology report into `data/mimic_impressions.csv`. A quick sanity check of both outputs is sketched below.
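As a quick check that preprocessing succeeded, you can inspect both outputs (a minimal sketch; it assumes the `cxr` dataset key and 320×320 resolution used elsewhere in this repository):

```python
import h5py
import pandas as pd

# preprocessed images: one HDF5 dataset named "cxr" of shape (num_images, 320, 320)
with h5py.File("data/cxr.h5", "r") as f:
    print(f["cxr"].shape)

# extracted impression sections, one row per report
impressions = pd.read_csv("data/mimic_impressions.csv")
print(impressions.head())
```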
45 |
46 | ### Evaluation Dataset
47 |
48 | #### CheXpert Dataset
49 | The CheXpert dataset consists of chest radiographic examinations from Stanford Hospital, performed between October 2002
50 | and July 2017 in both inpatient and outpatient centers. Population-level characteristics are unavailable for the CheXpert test
51 | dataset, as it is used for official evaluation on the CheXpert leaderboard.
52 |
53 | The main data (CheXpert data) supporting the results of this study are available at https://aimi.stanford.edu/chexpert-chest-x-rays.
54 |
55 | The CheXpert **test** dataset has recently been made public, and can be found by following the steps in the [cheXpert-test-set-labels](https://github.com/rajpurkarlab/cheXpert-test-set-labels) repository.
56 |
57 | #### PadChest Dataset
58 | The PadChest dataset contains chest X-rays that were interpreted by 18 radiologists at the Hospital Universitario de San Juan,
59 | Alicante, Spain, from January 2009 to December 2017. The dataset contains 109,931 image studies and 168,861 images.
60 | PadChest also contains 206,222 study reports.
61 |
62 | The [PadChest](https://arxiv.org/abs/1901.07441) dataset is publicly available at https://bimcv.cipf.es/bimcv-projects/padchest. Those who would like to use PadChest for experimentation should request access at the [link](https://bimcv.cipf.es/bimcv-projects/padchest).
63 |
64 | ### Model Checkpoints
65 | Model checkpoints of CheXzero pre-trained on MIMIC-CXR are publicly available at the following [link](https://drive.google.com/drive/folders/1makFLiEMbSleYltaRxw81aBhEDMpVwno?usp=sharing). Download files and save them in the `./checkpoints/chexzero_weights` directory.
66 |
67 | ## Running Training
68 | Run the following command to perform CheXzero pretraining.
69 | ```bash
70 | python run_train.py --cxr_filepath "./data/cxr.h5" --txt_filepath "data/mimic_impressions.csv"
71 | ```
72 |
73 | ### Arguments
74 | * `--cxr_filepath` Path to the `.h5` file to load chest x-ray image data from.
75 | * `--txt_filepath` Path to the `.csv` file to load radiology report impressions text from.
76 |
77 | Use `-h` flag to see all optional arguments.
78 |
79 | ## Zero-Shot Inference
80 | See the following [notebook](https://github.com/rajpurkarlab/CheXzero/blob/main/notebooks/zero_shot.ipynb) for an example of how to use CheXzero to perform zero-shot inference on a chest x-ray dataset. The example shows how to output predictions from the model ensemble and evaluate performance of the model if ground truth labels are available.
81 |
82 | ```python
83 | import zero_shot
84 |
85 | # computes predictions for a set of images; returns a np array of probabilities for each pathology
86 | predictions, y_pred_avg = zero_shot.ensemble_models(
87 | model_paths=model_paths,
88 | cxr_filepath=cxr_filepath,
89 | cxr_labels=cxr_labels,
90 | cxr_pair_template=cxr_pair_template,
91 | cache_dir=cache_dir,
92 | )
93 | ```
94 | ### Arguments
95 | * `model_paths: List[str]`: List of paths to all checkpoints to be used in the ensemble. To run on a single model, input a list containing a single path.
96 | * `cxr_filepath: str`: Path to images `.h5` file
97 | * `cxr_labels: List[str]`: List of pathologies to query in each image
98 | * `cxr_pair_template: Tuple[str, str]`: contrasting positive/negative templates used to query the model (see Figure 1 in the article for a visual explanation).
99 | * `cache_dir: str`: Directory in which to cache each checkpoint's predictions; used to avoid recomputing predictions. An example configuration is shown below.
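The arguments might be configured as follows (the label list and template pair mirror those used in this repository; the checkpoint glob and cache directory are illustrative and assume `.pt` files saved under `checkpoints/chexzero_weights/`):

```python
from pathlib import Path

cxr_filepath = "./data/chexpert_test.h5"  # output of run_preprocess.py
cxr_labels = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema',
              'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion',
              'Lung Opacity', 'No Finding', 'Pleural Effusion', 'Pleural Other',
              'Pneumonia', 'Pneumothorax', 'Support Devices']
cxr_pair_template = ("{}", "no {}")  # positive/negative prompt pair
model_paths = [str(p) for p in sorted(Path("checkpoints/chexzero_weights").glob("*.pt"))]
cache_dir = "./predictions/cached"
```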
100 |
101 | In order to use CheXzero for zero-shot inference, ensure the following requirements are met:
102 | * All input *`images`* must be stored in a single `.h5` (HDF5) file. See the [`img_to_h5`](https://github.com/rajpurkarlab/CheXzero/blob/main/preprocess_padchest.py#L156) function in [preprocess_padchest.py](https://github.com/rajpurkarlab/CheXzero/blob/main/preprocess_padchest.py) for an example of how to convert a list of paths to `.png` files into a valid `.h5` file; a minimal sketch is also shown after this list.
103 | * The *ground truth `labels`* must be in a `.csv` file where each row corresponds to an image sample and each column contains the binary label for a particular pathology.
104 | * Ensure all [model checkpoints](https://drive.google.com/drive/folders/1makFLiEMbSleYltaRxw81aBhEDMpVwno?usp=sharing) are stored in `checkpoints/chexzero_weights/`, or the `model_dir` that is specified in the notebook.
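If your images are not already in this format, a minimal sketch of building a compatible `.h5` file with `h5py` is shown below (it assumes 320×320 grayscale images stored under the `cxr` key, matching what `zero_shot.CXRTestDataset` reads; the input directory is hypothetical):

```python
import glob

import h5py
import numpy as np
from PIL import Image

png_paths = sorted(glob.glob("data/my_cxr_pngs/*.png"))  # hypothetical image folder

with h5py.File("data/my_cxr.h5", "w") as f:
    dset = f.create_dataset("cxr", shape=(len(png_paths), 320, 320), dtype="float32")
    for i, path in enumerate(png_paths):
        img = Image.open(path).convert("L").resize((320, 320))  # grayscale, 320x320
        dset[i] = np.asarray(img, dtype=np.float32)
```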
105 |
106 | ## Evaluation
107 | Given a numpy array of predictions (obtained from zero-shot inference), and a numpy array of ground truth labels, one can evaluate the performance of the model using the following code:
108 | ```python
109 | import zero_shot
110 | import eval
111 |
112 | # loads in ground truth labels into memory
113 | test_pred = y_pred_avg
114 | test_true = zero_shot.make_true_labels(cxr_true_labels_path=cxr_true_labels_path, cxr_labels=cxr_labels)
115 |
116 | # evaluate model, no bootstrap
117 | cxr_results: pd.DataFrame = eval.evaluate(test_pred, test_true, cxr_labels) # eval on full test dataset
118 |
119 | # bootstrap evaluations for 95% confidence intervals
120 | bootstrap_results: Tuple[pd.DataFrame, pd.DataFrame] = eval.bootstrap(test_pred, test_true, cxr_labels) # (df of results for each bootstrap, df of CI)
121 |
122 | # print results with confidence intervals
123 | print(bootstrap_results[1])
124 | ```
125 | The results are represented as a `pd.DataFrame` which can be saved as a `.csv`.
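For example, the confidence-interval dataframe from the snippet above can be written out for later inspection (the output path here is illustrative):

```python
bootstrap_results[1].to_csv("predictions/chexpert_test_cis.csv")
```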
126 |
127 | ### CheXpert Test Dataset
128 | In order to replicate the results in the paper, zero-shot inference and evaluation can be performed on the now publicly available CheXpert test dataset.
129 | 1) Download labels at [cheXpert-test-set-labels](https://github.com/rajpurkarlab/cheXpert-test-set-labels/blob/main/groundtruth.csv) and image files from [Stanford AIMI](https://stanfordaimi.azurewebsites.net/datasets/23c56a0d-15de-405b-87c8-99c30138950c) and save in the `./data` directory in `CheXzero/`. The test dataset images should have the following directory structure:
130 | ```
131 | data/
132 | ├─ CheXpert/
133 | │ ├─ test/
134 | │ │ ├─ patient64741/
135 | │ │ │ ├─ study1/
136 | │ │ │ │ ├─ view1_frontal.jpg
137 | │ │ ├─ .../
138 | ```
139 |
140 | 2) Run `run_preprocess.py` script with the following arguments:
141 | ```bash
142 | python run_preprocess.py --dataset_type "chexpert-test" --cxr_out_path "./data/chexpert_test.h5" --chest_x_ray_path "./data/CheXpert/test/"
143 | ```
144 | This should save a `.h5` version of the test dataset images which can be used for evaluation.
145 |
146 | 3) Open the sample zero-shot [notebook](https://github.com/rajpurkarlab/CheXzero/blob/main/notebooks/zero_shot.ipynb) and run all cells. If the directory structure is set up correctly, all cells should run without errors.
147 |
148 | ## Issues
149 | Please open a new issue thread describing the problem with the codebase, or report issues directly to ekintiu@stanford.edu.
150 |
151 | ## Citation
152 | ```bash
153 | Tiu, E., Talius, E., Patel, P. et al. Expert-level detection of pathologies from unannotated chest X-ray images via self-supervised learning. Nat. Biomed. Eng (2022). https://doi.org/10.1038/s41551-022-00936-9
154 | ```
155 |
156 | ## License
157 | The source code in this repository is licensed under the MIT license, which you can find in the `LICENSE` file. Also see `NOTICE.md` for attributions to third-party sources.
158 |
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import numpy as np
3 | import os
4 | import pandas as pd
5 | from PIL import Image
6 | import h5py
7 | import matplotlib.pyplot as plt
8 | from typing import List, Callable
9 | from collections import defaultdict
10 |
11 | import torch
12 | from torch.utils import data
13 | from tqdm.notebook import tqdm
14 | import torch.nn as nn
15 | from torchvision.transforms import Compose, Normalize, Resize
16 |
17 | import sklearn
18 | from sklearn.metrics import matthews_corrcoef, confusion_matrix, accuracy_score, auc, roc_auc_score, roc_curve, classification_report
19 | from sklearn.metrics import precision_recall_curve, f1_score
20 | from sklearn.metrics import average_precision_score
21 | from sklearn.utils import resample
22 |
23 | import scipy
24 | import scipy.stats
25 | from scipy.interpolate import UnivariateSpline
26 | import sys
27 | sys.path.append('../..')
28 |
29 | import clip
30 | from model import CLIP
31 | from eval import *
32 | from zero_shot import *
33 |
34 | def evaluate_model(X_dir, y_dir, model_path, cxr_labels, alt_labels_dict=None):
35 | cxr_filepath = X_dir
36 | final_label_path = y_dir
37 |
38 | results_out_folder = './results'
39 | context_length = 77
40 |
41 | # templates list of positive and negative template pairs
42 | cxr_pair_templates = [("{}", "no {}")]
43 |
44 | cxr_results, y_pred = run_zero_shot(cxr_labels, cxr_pair_templates, model_path, cxr_filepath=cxr_filepath, final_label_path=final_label_path, alt_labels_dict=alt_labels_dict, softmax_eval=True, context_length=context_length, pretrained=True, use_bootstrap=True, cutlabels=True)
45 | return cxr_results, y_pred
46 |
47 | def f1_mcc_bootstrap(y_pred, y_true, cxr_labels, best_p_vals, eval_func, n_samples=5000, label_idx_map=None):
48 |     '''
49 |     This function will randomly sample with replacement
50 |     from y_pred and y_true, then evaluate `n_samples` times
51 |     with `eval_func` and obtain a score for each sample.
52 | 
53 |     You can specify the number of bootstrap samples that should be
54 |     drawn with the `n_samples` parameter.
55 | 
56 |     Confidence intervals will be generated from the
57 |     distribution of sample scores.
58 |     '''
59 |     # y_pred: (n_samples, n_labels), e.g. (500, 14)
60 |     # y_true: (n_samples, n_labels), e.g. (500, 14)
61 |
62 | idx = np.arange(len(y_true))
63 |
64 | boot_stats = []
65 | for i in tqdm(range(n_samples)):
66 | sample = resample(idx, replace=True)
67 | y_pred_sample = y_pred[sample]
68 | y_true_sample = y_true[sample]
69 |
70 | sample_stats = eval_func(y_pred_sample, y_true_sample, best_p_vals, cxr_labels=cxr_labels, label_idx_map=label_idx_map)
71 | boot_stats.append(sample_stats)
72 |
73 | boot_stats = pd.concat(boot_stats) # pandas array of evaluations for each sample
74 | return boot_stats, compute_cis(boot_stats)
75 |
76 | def get_best_alt_labels(res_df, cxr_labels):
77 | best_alt_labels_dict = dict()
78 | best_alt_labels_vals = dict()
79 | res_cols = list(res_df)
80 |
81 | curr_path_name = None
82 | for col in res_cols: # for each col
83 | path_name = col.split("_")[0] # pathology name
84 | mean_auc = res_df[col][0] # mean auc
85 |
86 | if path_name in cxr_labels:
87 | # reset the vars
88 | curr_path_name = path_name
89 | best_alt_labels_dict[path_name] = [path_name]
90 | best_alt_labels_vals[path_name] = mean_auc
91 |
92 | if best_alt_labels_vals[curr_path_name] < mean_auc:
93 | best_alt_labels_vals[curr_path_name] = mean_auc
94 | best_alt_labels_dict[curr_path_name] = [path_name]
95 |
96 | return best_alt_labels_dict
97 |
98 | def y_true_csv_to_np(df_path, cxr_labels):
99 | groundtruth = pd.read_csv(df_path)
100 | groundtruth = groundtruth[cxr_labels]
101 | groundtruth = groundtruth.to_numpy()[:,:].astype(int)
102 | return groundtruth
103 |
104 | def get_best_p_vals(pred, groundtruth, cxr_labels, metric_func=matthews_corrcoef, spline_k: int = None, verbose: bool = False):
105 |     """
106 |     WARNING: `cxr_labels` must be ordered to match the columns of `pred` and `groundtruth`.
107 |     Params:
108 |     * pred : np arr
109 |         probabilities output by the model, one column per label
110 | 
111 |     * metric_func : callable
112 |         metric (e.g. `matthews_corrcoef`) that is maximized when selecting
113 |         the per-label threshold
114 | 
115 |     Note:
116 |     * candidate thresholds for each label are taken from `roc_curve`
117 |     """
118 | probabilities = [val for val in np.arange(0.4, 0.64, 0.0001)]
119 | best_p_vals = dict()
120 | for idx, cxr_label in enumerate(cxr_labels):
121 | y_true = groundtruth[:, idx]
122 | _, _, probabilities = roc_curve(y_true, pred[:, idx])
123 | probabilities = probabilities[1:]
124 | probabilities.sort()
125 |
126 | metrics_list = []
127 | for p in probabilities:
128 | y_pred = np.where(pred[:, idx] < p, 0, 1)
129 | metric = metric_func(y_true, y_pred)
130 | metrics_list.append(metric)
131 |
132 | if spline_k is not None:
133 | try:
134 | spl = UnivariateSpline(probabilities, metrics_list, k=spline_k)
135 | spl_y = spl(probabilities)
136 | # get optimal thresholds on the spline and on the val_metric_list
137 | best_index = np.argmax(spl_y)
138 | except:
139 | best_index = np.argmax(metrics_list)
140 | else:
141 | best_index = np.argmax(metrics_list)
142 |
143 | best_p = probabilities[best_index]
144 | best_metric = metrics_list[best_index]
145 | if verbose:
146 | print("Best metric for {} is {}. threshold = {}.".format(cxr_label, best_metric, best_p))
147 |
148 | best_p_vals[cxr_label] = best_p
149 | return best_p_vals
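# Example (illustrative): pick per-label thresholds on validation predictions,
# then binarize with the same np.where convention used above.
#   best_p_vals = get_best_p_vals(val_pred, val_groundtruth, cxr_labels)
#   val_bin = np.stack([np.where(val_pred[:, i] < best_p_vals[label], 0, 1)
#                       for i, label in enumerate(cxr_labels)], axis=1)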
150 |
151 | def compute_f1_mcc(X_test_dir, y_test_dir, X_val_dir, y_val_dir, model_path, alt_labels_dict : dict = None, find_best_alt: bool = True, thresh_func: Callable = matthews_corrcoef):
152 | """
153 | Computes f1 and mcc scores given test dataset, validation dataset (to find
154 | thresholds) and path to the model.
155 |
156 | Params:
157 | * find_best_alt : bool
158 | If True, will filter alt_labels_dict to only the best alternative labels
159 | based on the validation dataset. Otherwise, will run on all alternative labels
160 | provided.
161 | """
162 |
163 |
164 |
165 | # specify basic cxr labels
166 | cxr_labels = ['Atelectasis','Cardiomegaly',
167 | 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion',
168 | 'Lung Opacity', 'No Finding','Pleural Effusion', 'Pleural Other', 'Pneumonia',
169 | 'Pneumothorax', 'Support Devices']
170 |
171 |     # load in ground truth from the provided label filepaths
172 |     VAL_GROUNDTRUTH_PATH = y_val_dir
173 |     GROUNDTRUTH_PATH = y_test_dir
174 | 
175 |     val_groundtruth = y_true_csv_to_np(VAL_GROUNDTRUTH_PATH, cxr_labels)
176 |     groundtruth = y_true_csv_to_np(GROUNDTRUTH_PATH, cxr_labels)
177 | 
178 |     NUM_LABELS = len(cxr_labels)
179 | 
180 |     # run evaluation on validation and test datasets
181 |     # paths for the validation dataset (images `.h5`, labels `.csv`)
182 |     val_X = X_val_dir
183 |     val_y = y_val_dir
184 | 
185 |     # paths for the test dataset
186 |     test_X = X_test_dir
187 |     test_y = y_test_dir
188 |
189 | if alt_labels_dict is not None and find_best_alt:
190 | # find best alternate labels
191 |         val_res, val_pred = evaluate_model(val_X, val_y, model_path, cxr_labels, alt_labels_dict=alt_labels_dict)
192 | # save alternative label results on validation dataset
193 | alt_val_res = val_res[0][1]
194 | best_alt_labels_dict = get_best_alt_labels(alt_val_res, cxr_labels)
195 | elif alt_labels_dict is not None: # find_best_alt == False
196 | best_alt_labels_dict = alt_labels_dict
197 | else: # no alternative labels
198 | best_alt_labels_dict = None
199 |
200 | # create alt_labels
201 | if best_alt_labels_dict is not None:
202 | alt_labels_list, alt_label_idx_map = process_alt_labels(best_alt_labels_dict, cxr_labels)
203 | else:
204 | alt_labels_list, alt_label_idx_map = cxr_labels, None
205 |
206 | # TODO: convert preds into binarized and make this one of the things that are returned
207 | val_res, val_pred = evaluate_model(val_X, val_y, model_path, cxr_labels, alt_labels_dict=best_alt_labels_dict)
208 | test_res, test_pred = evaluate_model(test_X, test_y, model_path, cxr_labels, alt_labels_dict=best_alt_labels_dict)
209 |
210 | # get best thresholds
211 |     best_p_vals = get_best_p_vals(val_pred, val_groundtruth, alt_labels_list, metric_func=thresh_func)
212 |
213 | # f1 computation
214 | f1_cis = compute_f1(test_pred, groundtruth, alt_labels_list, best_p_vals, alt_label_idx_map)
215 | # mcc computation
216 | mcc_cis = compute_mcc(test_pred, groundtruth, alt_labels_list, best_p_vals, alt_label_idx_map)
217 |
218 | return f1_cis, mcc_cis
219 |
220 | def compute_f1(y_pred, y_true, cxr_labels, thresholds, label_idx_map=None):
221 | def get_f1_clip_bootstrap(y_pred, y_true, best_p_vals, cxr_labels=cxr_labels, label_idx_map=None):
222 | stats = {}
223 | probs = np.copy(y_pred)
224 | for idx, cxr_label in enumerate(cxr_labels):
225 | p = best_p_vals[cxr_label]
226 | probs[:,idx] = np.where(probs[:,idx] < p, 0, 1)
227 | clip_preds = np.copy(probs)
228 | for idx, cxr_label in enumerate(cxr_labels):
229 |
230 | if label_idx_map is None:
231 | curr_y_true = y_true[:, idx]
232 | else:
233 | curr_y_true = y_true[:, label_idx_map[cxr_label]]
234 | curr_y_pred = clip_preds[:, idx]
235 |
236 | m = confusion_matrix(curr_y_true, curr_y_pred)
237 | if len(m.ravel()) == 1:
238 | tn = 500
239 | fp = 0
240 | fn = 0
241 | tp = 0
242 | else:
243 | tn, fp, fn, tp = m.ravel()
244 |
245 | if ((2*tp + fp +fn) == 0):
246 | stats[cxr_label] = 1
247 | continue
248 |
249 | stats[cxr_label] = [(2 * tp) / (2*tp + fp +fn)]
250 | # compute mean over five major pathologies
251 | stats["Mean"] = compute_mean(stats, is_df=False)
252 | return pd.DataFrame.from_dict(stats)
253 |
254 | boot_stats, f1_cis = f1_mcc_bootstrap(y_pred, y_true, cxr_labels, thresholds, get_f1_clip_bootstrap, n_samples=1000, label_idx_map=label_idx_map)
255 | return f1_cis
256 |
257 | def compute_mcc(y_pred: np.array, y_true: np.array, cxr_labels: List, thresholds: dict, label_idx_map: dict = None):
258 | def get_mcc_bootstrap(y_pred, y_true, best_p_vals, cxr_labels=cxr_labels, label_idx_map=None):
259 | stats = {}
260 | probs = np.copy(y_pred)
261 |
262 | for idx, cxr_label in enumerate(cxr_labels):
263 | p = best_p_vals[cxr_label]
264 | probs[:,idx] = np.where(probs[:,idx] < p, 0, 1)
265 |
266 | clip_preds = np.copy(probs)
267 |
268 | for idx, cxr_label in enumerate(cxr_labels):
269 | if label_idx_map is None:
270 | curr_y_true = y_true[:, idx]
271 | else:
272 | curr_y_true = y_true[:, label_idx_map[cxr_label]]
273 |
274 | curr_y_pred = clip_preds[:, idx]
275 | stats[cxr_label] = [matthews_corrcoef(curr_y_true, curr_y_pred)]
276 | # compute mean over five major pathologies
277 | stats["Mean"] = compute_mean(stats, is_df=False)
278 | return pd.DataFrame.from_dict(stats)
279 |
280 | boot_stats, mcc_cis = f1_mcc_bootstrap(y_pred, y_true, cxr_labels, thresholds, get_mcc_bootstrap, n_samples=1000, label_idx_map=label_idx_map)
281 | return mcc_cis
282 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2021 OpenAI
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
24 | """
25 | from collections import OrderedDict
26 | from typing import Tuple, Union
27 |
28 | import numpy as np
29 | import torch
30 | import torch.nn.functional as F
31 | from torch import nn
32 |
33 |
34 | class Bottleneck(nn.Module):
35 | expansion = 4
36 |
37 | def __init__(self, inplanes, planes, stride=1):
38 | super().__init__()
39 |
40 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
41 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
42 | self.bn1 = nn.BatchNorm2d(planes)
43 |
44 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
45 | self.bn2 = nn.BatchNorm2d(planes)
46 |
47 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
48 |
49 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
50 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
51 |
52 | self.relu = nn.ReLU(inplace=True)
53 | self.downsample = None
54 | self.stride = stride
55 |
56 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
57 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
58 | self.downsample = nn.Sequential(OrderedDict([
59 | ("-1", nn.AvgPool2d(stride)),
60 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
61 | ("1", nn.BatchNorm2d(planes * self.expansion))
62 | ]))
63 |
64 | def forward(self, x: torch.Tensor):
65 | identity = x
66 |
67 | out = self.relu(self.bn1(self.conv1(x)))
68 | out = self.relu(self.bn2(self.conv2(out)))
69 | out = self.avgpool(out)
70 | out = self.bn3(self.conv3(out))
71 |
72 | if self.downsample is not None:
73 | identity = self.downsample(x)
74 |
75 | out += identity
76 | out = self.relu(out)
77 | return out
78 |
79 |
80 | class AttentionPool2d(nn.Module):
81 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
82 | super().__init__()
83 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
84 | self.k_proj = nn.Linear(embed_dim, embed_dim)
85 | self.q_proj = nn.Linear(embed_dim, embed_dim)
86 | self.v_proj = nn.Linear(embed_dim, embed_dim)
87 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
88 | self.num_heads = num_heads
89 |
90 | def forward(self, x):
91 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
92 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
93 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
94 | x, _ = F.multi_head_attention_forward(
95 | query=x, key=x, value=x,
96 | embed_dim_to_check=x.shape[-1],
97 | num_heads=self.num_heads,
98 | q_proj_weight=self.q_proj.weight,
99 | k_proj_weight=self.k_proj.weight,
100 | v_proj_weight=self.v_proj.weight,
101 | in_proj_weight=None,
102 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
103 | bias_k=None,
104 | bias_v=None,
105 | add_zero_attn=False,
106 | dropout_p=0,
107 | out_proj_weight=self.c_proj.weight,
108 | out_proj_bias=self.c_proj.bias,
109 | use_separate_proj_weight=True,
110 | training=self.training,
111 | need_weights=False
112 | )
113 |
114 | return x[0]
115 |
116 |
117 | class ModifiedResNet(nn.Module):
118 | """
119 | A ResNet class that is similar to torchvision's but contains the following changes:
120 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
121 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
122 | - The final pooling layer is a QKV attention instead of an average pool
123 | """
124 |
125 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
126 | super().__init__()
127 | self.output_dim = output_dim
128 | self.input_resolution = input_resolution
129 |
130 | # the 3-layer stem
131 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
132 | self.bn1 = nn.BatchNorm2d(width // 2)
133 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
134 | self.bn2 = nn.BatchNorm2d(width // 2)
135 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
136 | self.bn3 = nn.BatchNorm2d(width)
137 | self.avgpool = nn.AvgPool2d(2)
138 | self.relu = nn.ReLU(inplace=True)
139 |
140 | # residual layers
141 | self._inplanes = width # this is a *mutable* variable used during construction
142 | self.layer1 = self._make_layer(width, layers[0])
143 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
144 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
145 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
146 |
147 | embed_dim = width * 32 # the ResNet feature dimension
148 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
149 |
150 | def _make_layer(self, planes, blocks, stride=1):
151 | layers = [Bottleneck(self._inplanes, planes, stride)]
152 |
153 | self._inplanes = planes * Bottleneck.expansion
154 | for _ in range(1, blocks):
155 | layers.append(Bottleneck(self._inplanes, planes))
156 |
157 | return nn.Sequential(*layers)
158 |
159 | def forward(self, x):
160 | def stem(x):
161 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
162 | x = self.relu(bn(conv(x)))
163 | x = self.avgpool(x)
164 | return x
165 |
166 | x = x.type(self.conv1.weight.dtype)
167 | x = stem(x)
168 | x = self.layer1(x)
169 | x = self.layer2(x)
170 | x = self.layer3(x)
171 | x = self.layer4(x)
172 | x = self.attnpool(x)
173 |
174 | return x
175 |
176 |
177 | class LayerNorm(nn.LayerNorm):
178 | """Subclass torch's LayerNorm to handle fp16."""
179 |
180 | def forward(self, x: torch.Tensor):
181 | orig_type = x.dtype
182 | ret = super().forward(x.type(torch.float32))
183 | return ret.type(orig_type)
184 |
185 |
186 | class QuickGELU(nn.Module):
187 | def forward(self, x: torch.Tensor):
188 | return x * torch.sigmoid(1.702 * x)
189 |
190 |
191 | class ResidualAttentionBlock(nn.Module):
192 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
193 | super().__init__()
194 |
195 | self.attn = nn.MultiheadAttention(d_model, n_head)
196 | self.ln_1 = LayerNorm(d_model)
197 | self.mlp = nn.Sequential(OrderedDict([
198 | ("c_fc", nn.Linear(d_model, d_model * 4)),
199 | ("gelu", QuickGELU()),
200 | ("c_proj", nn.Linear(d_model * 4, d_model))
201 | ]))
202 | self.ln_2 = LayerNorm(d_model)
203 | self.attn_mask = attn_mask
204 |
205 | def attention(self, x: torch.Tensor):
206 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
207 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
208 |
209 | def forward(self, x: torch.Tensor):
210 | x = x + self.attention(self.ln_1(x))
211 | x = x + self.mlp(self.ln_2(x))
212 | return x
213 |
214 |
215 | class Transformer(nn.Module):
216 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
217 | super().__init__()
218 | self.width = width
219 | self.layers = layers
220 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
221 |
222 | def forward(self, x: torch.Tensor):
223 | return self.resblocks(x)
224 |
225 |
226 | class VisualTransformer(nn.Module):
227 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
228 | super().__init__()
229 | self.input_resolution = input_resolution
230 | self.output_dim = output_dim
231 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
232 |
233 | scale = width ** -0.5
234 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
235 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
236 | self.ln_pre = LayerNorm(width)
237 |
238 | self.transformer = Transformer(width, layers, heads)
239 |
240 | self.ln_post = LayerNorm(width)
241 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
242 |
243 | def forward(self, x: torch.Tensor):
244 | x = self.conv1(x) # shape = [*, width, grid, grid]
245 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
246 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
247 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
248 | x = x + self.positional_embedding.to(x.dtype)
249 | x = self.ln_pre(x)
250 |
251 | x = x.permute(1, 0, 2) # NLD -> LND
252 | x = self.transformer(x)
253 | x = x.permute(1, 0, 2) # LND -> NLD
254 |
255 | x = self.ln_post(x[:, 0, :])
256 |
257 | if self.proj is not None:
258 | x = x @ self.proj
259 |
260 | return x
261 |
262 |
263 | class CLIP(nn.Module):
264 | def __init__(self,
265 | embed_dim: int,
266 | # vision
267 | image_resolution: int,
268 | vision_layers: Union[Tuple[int, int, int, int], int],
269 | vision_width: int,
270 | vision_patch_size: int,
271 | # text
272 | context_length: int,
273 | vocab_size: int,
274 | transformer_width: int,
275 | transformer_heads: int,
276 | transformer_layers: int
277 | ):
278 | super().__init__()
279 |
280 | self.context_length = context_length
281 |
282 | if isinstance(vision_layers, (tuple, list)):
283 | vision_heads = vision_width * 32 // 64
284 | self.visual = ModifiedResNet(
285 | layers=vision_layers,
286 | output_dim=embed_dim,
287 | heads=vision_heads,
288 | input_resolution=image_resolution,
289 | width=vision_width
290 | )
291 | else:
292 | vision_heads = vision_width // 64
293 | self.visual = VisualTransformer(
294 | input_resolution=image_resolution,
295 | patch_size=vision_patch_size,
296 | width=vision_width,
297 | layers=vision_layers,
298 | heads=vision_heads,
299 | output_dim=embed_dim
300 | )
301 |
302 | self.transformer = Transformer(
303 | width=transformer_width,
304 | layers=transformer_layers,
305 | heads=transformer_heads,
306 | attn_mask=self.build_attention_mask()
307 | )
308 |
309 | self.vocab_size = vocab_size
310 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
311 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
312 | self.ln_final = LayerNorm(transformer_width)
313 |
314 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
315 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
316 |
317 | self.initialize_parameters()
318 |
319 | def initialize_parameters(self):
320 | nn.init.normal_(self.token_embedding.weight, std=0.02)
321 | nn.init.normal_(self.positional_embedding, std=0.01)
322 |
323 | if isinstance(self.visual, ModifiedResNet):
324 | if self.visual.attnpool is not None:
325 | std = self.visual.attnpool.c_proj.in_features ** -0.5
326 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
327 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
328 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
329 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
330 |
331 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
332 | for name, param in resnet_block.named_parameters():
333 | if name.endswith("bn3.weight"):
334 | nn.init.zeros_(param)
335 |
336 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
337 | attn_std = self.transformer.width ** -0.5
338 | fc_std = (2 * self.transformer.width) ** -0.5
339 | for block in self.transformer.resblocks:
340 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
341 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
342 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
343 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
344 |
345 | if self.text_projection is not None:
346 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
347 |
348 | def build_attention_mask(self):
349 | # lazily create causal attention mask, with full attention between the vision tokens
350 | # pytorch uses additive attention mask; fill with -inf
351 | mask = torch.empty(self.context_length, self.context_length)
352 | mask.fill_(float("-inf"))
353 | mask.triu_(1) # zero out the lower diagonal
354 | return mask
355 |
356 | @property
357 | def dtype(self):
358 | return self.visual.conv1.weight.dtype
359 |
360 | def encode_image(self, image):
361 | return self.visual(image.type(self.dtype))
362 |
363 | def encode_text(self, text):
364 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
365 |
366 | x = x + self.positional_embedding.type(self.dtype)
367 | x = x.permute(1, 0, 2) # NLD -> LND
368 | x = self.transformer(x)
369 | x = x.permute(1, 0, 2) # LND -> NLD
370 | x = self.ln_final(x).type(self.dtype)
371 |
372 | # x.shape = [batch_size, n_ctx, transformer.width]
373 | # take features from the eot embedding (eot_token is the highest number in each sequence)
374 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
375 |
376 | return x
377 |
378 | def forward(self, image, text):
379 | image_features = self.encode_image(image)
380 | text_features = self.encode_text(text)
381 |
382 | # normalized features
383 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
384 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
385 |
386 | # cosine similarity as logits
387 | logit_scale = self.logit_scale.exp()
388 | logits_per_image = logit_scale * image_features @ text_features.t()
389 | logits_per_text = logit_scale * text_features @ image_features.t()
390 |
391 | # shape = [global_batch_size, global_batch_size]
392 | return logits_per_image, logits_per_text
393 |
394 |
395 | def convert_weights(model: nn.Module):
396 | """Convert applicable model parameters to fp16"""
397 |
398 | def _convert_weights_to_fp16(l):
399 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
400 | l.weight.data = l.weight.data.half()
401 | if l.bias is not None:
402 | l.bias.data = l.bias.data.half()
403 |
404 | if isinstance(l, nn.MultiheadAttention):
405 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
406 | tensor = getattr(l, attr)
407 | if tensor is not None:
408 | tensor.data = tensor.data.half()
409 |
410 | for name in ["text_projection", "proj"]:
411 | if hasattr(l, name):
412 | attr = getattr(l, name)
413 | if attr is not None:
414 | attr.data = attr.data.half()
415 |
416 | model.apply(_convert_weights_to_fp16)
417 |
418 |
419 | def build_model(state_dict: dict):
420 | vit = "visual.proj" in state_dict
421 |
422 | if vit:
423 | vision_width = state_dict["visual.conv1.weight"].shape[0]
424 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
425 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
426 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
427 | image_resolution = vision_patch_size * grid_size
428 | else:
429 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
430 | vision_layers = tuple(counts)
431 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
432 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
433 | vision_patch_size = None
434 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
435 | image_resolution = output_width * 32
436 |
437 | embed_dim = state_dict["text_projection"].shape[1]
438 | context_length = state_dict["positional_embedding"].shape[0]
439 | vocab_size = state_dict["token_embedding.weight"].shape[0]
440 | transformer_width = state_dict["ln_final.weight"].shape[0]
441 | transformer_heads = transformer_width // 64
442 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
443 |
444 | model = CLIP(
445 | embed_dim,
446 | image_resolution, vision_layers, vision_width, vision_patch_size,
447 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
448 | )
449 |
450 | for key in ["input_resolution", "context_length", "vocab_size"]:
451 | if key in state_dict:
452 | del state_dict[key]
453 |
454 | convert_weights(model)
455 | model.load_state_dict(state_dict)
456 | return model.eval()
457 |
--------------------------------------------------------------------------------
/zero_shot.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import numpy as np
3 | import os
4 | import sys
5 | import pandas as pd
6 | from PIL import Image
7 | import h5py
8 | import matplotlib.pyplot as plt
9 | from typing import List, Tuple
10 | from pathlib import Path
11 | import torch
12 | from torch.utils import data
13 | from tqdm.notebook import tqdm
14 | import torch.nn as nn
15 | from torchvision.transforms import Compose, Normalize, Resize, InterpolationMode
16 |
17 | import sklearn
18 | from sklearn.metrics import confusion_matrix, accuracy_score, auc, roc_auc_score, roc_curve, classification_report
19 | from sklearn.metrics import precision_recall_curve, f1_score
20 | from sklearn.metrics import average_precision_score
21 |
22 | import clip
23 | from model import CLIP
24 | from eval import evaluate, plot_roc, accuracy, sigmoid, bootstrap, compute_cis
25 |
26 | CXR_FILEPATH = '../../project-files/data/test_cxr.h5'
27 | FINAL_LABEL_PATH = '../../project-files/data/final_paths.csv'
28 |
29 | class CXRTestDataset(data.Dataset):
30 | """Represents an abstract HDF5 dataset.
31 |
32 | Input params:
33 | img_path: Path to hdf5 file containing images.
34 |         (note: labels are loaded separately; this dataset returns images only)
35 | transform: PyTorch transform to apply to every data instance (default=None).
36 | """
37 | def __init__(
38 | self,
39 | img_path: str,
40 | transform = None,
41 | ):
42 | super().__init__()
43 | self.img_dset = h5py.File(img_path, 'r')['cxr']
44 | self.transform = transform
45 |
46 | def __len__(self):
47 | return len(self.img_dset)
48 |
49 | def __getitem__(self, idx):
50 | if torch.is_tensor(idx):
51 | idx = idx.tolist()
52 |
53 | img = self.img_dset[idx] # np array, (320, 320)
54 | img = np.expand_dims(img, axis=0)
55 | img = np.repeat(img, 3, axis=0)
56 |         img = torch.from_numpy(img) # torch, (3, 320, 320)
57 |
58 | if self.transform:
59 | img = self.transform(img)
60 |
61 | sample = {'img': img}
62 |
63 | return sample
64 |
65 | def load_clip(model_path, pretrained=False, context_length=77):
66 | """
67 | FUNCTION: load_clip
68 | ---------------------------------
69 | """
70 | device = torch.device("cpu")
71 | if pretrained is False:
72 | # use new model params
73 | params = {
74 | 'embed_dim':768,
75 | 'image_resolution': 320,
76 | 'vision_layers': 12,
77 | 'vision_width': 768,
78 | 'vision_patch_size': 16,
79 | 'context_length': context_length,
80 | 'vocab_size': 49408,
81 | 'transformer_width': 512,
82 | 'transformer_heads': 8,
83 | 'transformer_layers': 12
84 | }
85 |
86 | model = CLIP(**params)
87 | else:
88 | model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
89 | try:
90 | model.load_state_dict(torch.load(model_path, map_location=device))
91 | except:
92 | print("Argument error. Set pretrained = True.", sys.exc_info()[0])
93 | raise
94 | return model
95 |
96 | def zeroshot_classifier(classnames, templates, model, context_length=77):
97 | """
98 | FUNCTION: zeroshot_classifier
99 | -------------------------------------
100 | This function outputs the weights for each of the classes based on the
101 | output of the trained clip model text transformer.
102 |
103 | args:
104 | * classnames - Python list of classes for a specific zero-shot task. (i.e. ['Atelectasis',...]).
105 |     * templates - Python list of phrases that will be independently tested as input to the clip model.
106 | * model - Pytorch model, full trained clip model.
107 | * context_length (optional) - int, max number of tokens of text inputted into the model.
108 |
109 | Returns PyTorch Tensor, output of the text encoder given templates.
110 | """
111 | with torch.no_grad():
112 | zeroshot_weights = []
113 | # compute embedding through model for each class
114 | for classname in tqdm(classnames):
115 | texts = [template.format(classname) for template in templates] # format with class
116 | texts = clip.tokenize(texts, context_length=context_length) # tokenize
117 | class_embeddings = model.encode_text(texts) # embed with text encoder
118 |
119 | # normalize class_embeddings
120 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
121 | # average over templates
122 | class_embedding = class_embeddings.mean(dim=0)
123 | # norm over new averaged templates
124 | class_embedding /= class_embedding.norm()
125 | zeroshot_weights.append(class_embedding)
126 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1)
127 | return zeroshot_weights
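# Example (illustrative):
#   weights = zeroshot_classifier(['Atelectasis', 'Edema'], ["{}", "no {}"], model)
#   # weights has shape (embed_dim, num_classes); predict() matmuls it with image features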
128 |
129 | def predict(loader, model, zeroshot_weights, softmax_eval=True, verbose=0):
130 | """
131 | FUNCTION: predict
132 | ---------------------------------
133 | This function runs the cxr images through the model
134 | and computes the cosine similarities between the images
135 | and the text embeddings.
136 |
137 | args:
138 | * loader - PyTorch data loader, loads in cxr images
139 | * model - PyTorch model, trained clip model
140 | * zeroshot_weights - PyTorch Tensor, outputs of text encoder for labels
141 | * softmax_eval (optional) - Use +/- softmax method for evaluation
142 | * verbose (optional) - bool, If True, will print out intermediate tensor values for debugging.
143 |
144 | Returns numpy array, predictions on all test data samples.
145 | """
146 | y_pred = []
147 | with torch.no_grad():
148 | for i, data in enumerate(tqdm(loader)):
149 | images = data['img']
150 |
151 | # predict
152 | image_features = model.encode_image(images)
153 | image_features /= image_features.norm(dim=-1, keepdim=True) # (1, 768)
154 |
155 | # obtain logits
156 | logits = image_features @ zeroshot_weights # (1, num_classes)
157 | logits = np.squeeze(logits.numpy(), axis=0) # (num_classes,)
158 |
159 | if softmax_eval is False:
160 | norm_logits = (logits - logits.mean()) / (logits.std())
161 | logits = sigmoid(norm_logits)
162 |
163 | y_pred.append(logits)
164 |
165 | if verbose:
166 | plt.imshow(images[0][0])
167 | plt.show()
168 | print('images: ', images)
169 | print('images size: ', images.size())
170 |
171 | print('image_features size: ', image_features.size())
172 | print('logits: ', logits)
173 |                 print('logits shape: ', logits.shape)
174 |
175 |     y_pred = np.array(y_pred)
176 |     return y_pred
177 |
178 | def run_single_prediction(cxr_labels, template, model, loader, softmax_eval=True, context_length=77):
179 | """
180 | FUNCTION: run_single_prediction
181 | --------------------------------------
182 | This function will make probability predictions for a single template
183 | (i.e. "has {}").
184 |
185 | args:
186 | * cxr_labels - list, labels for a specific zero-shot task. (i.e. ['Atelectasis',...])
187 | * template - string, template to input into model.
188 | * model - PyTorch model, trained clip model
189 | * loader - PyTorch data loader, loads in cxr images
190 | * softmax_eval (optional) - Use +/- softmax method for evaluation
191 | * context_length (optional) - int, max number of tokens of text inputted into the model.
192 |
193 | Returns list, predictions from the given template.
194 | """
195 | cxr_phrase = [template]
196 | zeroshot_weights = zeroshot_classifier(cxr_labels, cxr_phrase, model, context_length=context_length)
197 | y_pred = predict(loader, model, zeroshot_weights, softmax_eval=softmax_eval)
198 | return y_pred
199 |
200 | def process_alt_labels(alt_labels_dict, cxr_labels):
201 | """
202 | Process alt labels and return relevant info. If `alt_labels_dict` is
203 | None, return None.
204 |
205 | Returns:
206 | * alt_label_list : list
207 | List of all alternative labels
208 | * alt_label_idx_map : dict
209 | Maps alt label to idx of original label in cxr_labels
210 | Needed to access correct column during evaluation
211 |
212 | """
213 |
214 | if alt_labels_dict is None:
215 | return None, None
216 |
217 | def get_inverse_labels(labels_alt_map: dict):
218 | """
219 | Returns dict mapping alternative label back to actual label.
220 | Used for reference during evaluation.
221 | """
222 | inverse_labels_dict = {}
223 | for main in labels_alt_map:
224 | inverse_labels_dict[main] = main # adds self to list of alt labels
225 | for alt in labels_alt_map[main]:
226 | inverse_labels_dict[alt] = main
227 | return inverse_labels_dict
228 |
229 | inv_labels_dict = get_inverse_labels(alt_labels_dict)
230 | alt_label_list = [w for w in inv_labels_dict.keys()]
231 |
232 | # create index map
233 | index_map = dict()
234 | for i, label in enumerate(cxr_labels):
235 | index_map[label] = i
236 |
237 | # make map to go from alt label directly to index
238 | alt_label_idx_map = dict()
239 | for alt_label in alt_label_list:
240 | alt_label_idx_map[alt_label] = index_map[inv_labels_dict[alt_label]]
241 |
242 | return alt_label_list, alt_label_idx_map
243 |
244 | def run_softmax_eval(model, loader, eval_labels: list, pair_template: tuple, context_length: int = 77):
245 | """
246 | Run softmax evaluation to obtain a single prediction from the model.
247 | """
248 | # get pos and neg phrases
249 | pos = pair_template[0]
250 | neg = pair_template[1]
251 |
252 | # get pos and neg predictions, (num_samples, num_classes)
253 | pos_pred = run_single_prediction(eval_labels, pos, model, loader,
254 | softmax_eval=True, context_length=context_length)
255 | neg_pred = run_single_prediction(eval_labels, neg, model, loader,
256 | softmax_eval=True, context_length=context_length)
257 |
258 | # compute probabilities with softmax
259 | sum_pred = np.exp(pos_pred) + np.exp(neg_pred)
260 | y_pred = np.exp(pos_pred) / sum_pred
261 | return y_pred
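# Example (illustrative): for the template pair ("{}", "no {}") and one label,
# a positive-prompt logit of 1.2 and a negative-prompt logit of 0.3 give
#   prob = np.exp(1.2) / (np.exp(1.2) + np.exp(0.3))  # ~0.71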
262 |
263 | def run_experiment(model, cxr_labels, cxr_templates, loader, y_true, alt_labels_dict=None, softmax_eval=True, context_length=77, use_bootstrap=True):
264 | '''
265 | FUNCTION: run_experiment
266 | ----------------------------------------
267 | This function runs the zeroshot experiment on each of the templates
268 | individually, and stores the results in a list.
269 |
270 | args:
271 | * model - PyTorch model, trained clip model
272 | * cxr_labels - list, labels for a specific zero-shot task. (i.e. ['Atelectasis',...])
273 | * cxr_templates - list, templates to input into model. If softmax_eval is True,
274 | this should be a list of tuples, where each tuple is a +/- pair
275 | * loader - PyTorch data loader, loads in cxr images
276 | * y_true - list, ground truth labels for test dataset
277 | * softmax_eval (optional) - bool, if True, will evaluate results through softmax of pos vs. neg samples.
278 | * context_length - int, max number of tokens of text inputted into the model.
279 | * use_bootstrap (optional) - bool, whether or not to use bootstrap sampling
280 |
281 | Returns a list of results from the experiment.
282 | '''
283 |
284 | alt_label_list, alt_label_idx_map = process_alt_labels(alt_labels_dict, cxr_labels)
285 | if alt_label_list is not None:
286 | eval_labels = alt_label_list
287 | else:
288 | eval_labels = cxr_labels
289 |
290 | results = []
291 | for template in cxr_templates:
292 | print('Phrase being used: ', template)
293 |
294 | try:
295 | if softmax_eval:
296 | y_pred = run_softmax_eval(model, loader, eval_labels, template, context_length=context_length)
297 |
298 | else:
299 | # get single prediction
300 | y_pred = run_single_prediction(eval_labels, template, model, loader,
301 | softmax_eval=softmax_eval, context_length=context_length)
302 | # print("y_pred: ", y_pred)
303 | except:
304 | print("Argument error. Make sure cxr_templates is proper format.", sys.exc_info()[0])
305 | raise
306 |
307 | # evaluate
308 | if use_bootstrap:
309 | # compute bootstrap stats
310 | boot_stats = bootstrap(y_pred, y_true, eval_labels, label_idx_map=alt_label_idx_map)
311 | results.append(boot_stats) # each template has a pandas array of samples
312 | else:
313 | stats = evaluate(y_pred, y_true, eval_labels)
314 | results.append(stats)
315 |
316 | return results, y_pred
317 |
318 | def make_true_labels(
319 | cxr_true_labels_path: str,
320 | cxr_labels: List[str],
321 | cutlabels: bool = True
322 | ):
323 | """
324 | Loads in data containing the true binary labels
325 | for each pathology in `cxr_labels` for all samples. This
326 | is used for evaluation of model performance.
327 |
328 | args:
329 | * cxr_true_labels_path - str, path to csv containing ground truth labels
330 | * cxr_labels - List[str], subset of label columns to select from ground truth df
331 | * cutlabels - bool, if True, will keep columns of ground truth labels that correspond
332 | with the labels inputted through `cxr_labels`. Otherwise, drop the first column and keep remaining.
333 |
334 | Returns a numpy array of shape (# samples, # labels/pathologies)
335 | representing the binary ground truth labels for each pathology on each sample.
336 | """
337 | # create ground truth labels
338 | full_labels = pd.read_csv(cxr_true_labels_path)
339 | if cutlabels:
340 | full_labels = full_labels.loc[:, cxr_labels]
341 | else:
342 | full_labels.drop(full_labels.columns[0], axis=1, inplace=True)
343 |
344 | y_true = full_labels.to_numpy()
345 | return y_true
346 |
347 | def make(
348 | model_path: str,
349 | cxr_filepath: str,
350 | pretrained: bool = True,
351 |     context_length: int = 77,
352 | ):
353 | """
354 | FUNCTION: make
355 | -------------------------------------------
356 |     This function makes the model and the data loader used for zero-shot inference.
357 | 
358 |     args:
359 |     * model_path - String for path to the weights of the trained clip model.
360 |     * cxr_filepath - String for path to the chest x-ray images (`.h5` file).
361 |     * pretrained - bool, whether or not model uses pretrained clip weights
362 |     * context_length - int, max number of tokens of text inputted into the model.
363 | 
364 |     Note: ground truth labels are not created here; use `make_true_labels`
365 |     (with `cxr_labels` and `cutlabels`) to build them separately.
366 | 
367 |     Returns model, data loader.
368 |     """
369 | # load model
370 | model = load_clip(
371 | model_path=model_path,
372 | pretrained=pretrained,
373 | context_length=context_length
374 | )
375 |
376 | # load data
377 | transformations = [
378 | # means computed from sample in `cxr_stats` notebook
379 | Normalize((101.48761, 101.48761, 101.48761), (83.43944, 83.43944, 83.43944)),
380 | ]
381 | # if using CLIP pretrained model
382 | if pretrained:
383 | # resize to input resolution of pretrained clip model
384 | input_resolution = 224
385 | transformations.append(Resize(input_resolution, interpolation=InterpolationMode.BICUBIC))
386 | transform = Compose(transformations)
387 |
388 | # create dataset
389 | torch_dset = CXRTestDataset(
390 | img_path=cxr_filepath,
391 | transform=transform,
392 | )
393 | loader = torch.utils.data.DataLoader(torch_dset, shuffle=False)
394 |
395 | return model, loader
396 |
397 | ## Run the model on the data set using ensembled models
398 | def ensemble_models(
399 | model_paths: List[str],
400 | cxr_filepath: str,
401 | cxr_labels: List[str],
402 | cxr_pair_template: Tuple[str],
403 | cache_dir: str = None,
404 | save_name: str = None,
405 | ) -> Tuple[List[np.ndarray], np.ndarray]:
406 | """
407 | Given a list of `model_paths`, ensemble model and return
408 | predictions. Caches predictions at `cache_dir` if location provided.
409 |
410 | Returns a list of each model's predictions and the averaged
411 | set of predictions.
412 | """
413 |
414 | predictions = []
415 |     model_paths = sorted(model_paths) # ensure a consistent ordering of the ensembled checkpoints
416 | for path in model_paths: # for each model
417 | model_name = Path(path).stem
418 |
419 | # load in model and `torch.DataLoader`
420 | model, loader = make(
421 | model_path=path,
422 | cxr_filepath=cxr_filepath,
423 | )
424 |
425 | # path to the cached prediction
426 | if cache_dir is not None:
427 | if save_name is not None:
428 | cache_path = Path(cache_dir) / f"{save_name}_{model_name}.npy"
429 | else:
430 | cache_path = Path(cache_dir) / f"{model_name}.npy"
431 |
432 | # if prediction already cached, don't recompute prediction
433 | if cache_dir is not None and os.path.exists(cache_path):
434 | print("Loading cached prediction for {}".format(model_name))
435 | y_pred = np.load(cache_path)
436 | else: # cached prediction not found, compute preds
437 | print("Inferring model {}".format(path))
438 | y_pred = run_softmax_eval(model, loader, cxr_labels, cxr_pair_template)
439 | if cache_dir is not None:
440 | Path(cache_dir).mkdir(exist_ok=True, parents=True)
441 | np.save(file=cache_path, arr=y_pred)
442 | predictions.append(y_pred)
443 |
444 | # compute average predictions
445 | y_pred_avg = np.mean(predictions, axis=0)
446 |
447 | return predictions, y_pred_avg
448 |
449 | def run_zero_shot(cxr_labels, cxr_templates, model_path, cxr_filepath, final_label_path, alt_labels_dict: dict = None, softmax_eval = True, context_length=77, pretrained: bool = False, use_bootstrap=True, cutlabels=True):
450 | """
451 | FUNCTION: run_zero_shot
452 | --------------------------------------
453 | This is the main entry point for running the zero-shot pipeline given a dataset,
454 | labels, templates for those labels, ground truth labels, and config parameters.
455 |
456 | args:
457 | * cxr_labels - list
458 | labels for a specific zero-shot task. (e.g. ['Atelectasis',...])
459 | each label can either be a string or a tuple of (name of alternative label, name of label in csv)
460 | * cxr_templates - list, phrases that will be independently tested as input to the clip model. If `softmax_eval` is True, this parameter should be a
461 | list of positive and negative template pairs stored as tuples.
462 | * model_path - String for directory to the weights of the trained clip model.
463 | * cxr_filepath - String for path to the chest x-ray images.
464 | * final_label_path - String for path to ground truth labels.
465 |
466 | * alt_labels_dict (optional) - dict, maps each label in cxr_labels to a list of alternative labels (e.g. 'Atelectasis': ['lung collapse', 'atelectatic lung', ...])
467 | * softmax_eval (optional) - bool, if True, will evaluate results through softmax of pos vs. neg samples.
468 | * context_length (optional) - int, max number of tokens of text inputted into the model.
469 | * pretrained (optional) - bool, whether or not model uses pretrained clip weights
470 | * use_bootstrap (optional) - bool, whether or not to use bootstrap sampling
471 | * cutlabels (optional) - bool, if True, keep only the ground truth label columns that correspond
472 | to the labels passed in through `cxr_labels`. Otherwise, drop the first column and keep the remaining columns.
473 |
474 | Returns an array of results per template; each result is a tuple containing a pandas dataframe
475 | for the n bootstrap samples and another pandas dataframe with the confidence intervals for each class.
476 | """
477 |
478 | np.random.seed(97)
479 | # make the model, data loader, and ground truth labels
480 | model, loader = make(
481 | model_path=model_path,
482 | cxr_filepath=cxr_filepath,
483 | pretrained=pretrained,
484 | context_length=context_length
485 | )
486 |
487 | y_true = make_true_labels(
488 | cxr_true_labels_path=final_label_path,
489 | cxr_labels=cxr_labels,
490 | cutlabels=cutlabels,
491 | )
492 |
493 | # run multiphrase experiment
494 | results, y_pred = run_experiment(model, cxr_labels, cxr_templates, loader, y_true,
495 | alt_labels_dict=alt_labels_dict, softmax_eval=softmax_eval, context_length=context_length, use_bootstrap=use_bootstrap)
496 | return results, y_pred
497 |
498 | def run_cxr_zero_shot(model_path, context_length=77, pretrained=False):
499 | """
500 | FUNCTION: run_cxr_zero_shot
501 | --------------------------------------
502 | This function runs zero-shot specifically for the cxr dataset.
503 | The only difference between this function and `run_zero_shot` is that
504 | this function is already pre-parameterized for the 14 cxr labels evaluated
505 | using softmax method of positive and negative templates.
506 |
507 | args:
508 | * model_path - string, filepath of model being evaluated
509 | * context_length (optional) - int, max number of tokens of text inputted into the model.
510 | * pretrained (optional) - bool, whether or not model uses pretrained clip weights
511 | * note - bootstrap sampling is disabled here (run_zero_shot is called with use_bootstrap=False).
512 |
513 | Returns an array of labels and an array of results per template;
514 | each result is a tuple containing a pandas dataframe
515 | for the n bootstrap samples and another pandas dataframe with the confidence intervals for each class.
516 | """
517 | cxr_filepath = '/deep/group/data/med-data/test_cxr.h5'
518 | final_label_path = '/deep/group/data/med-data/final_paths.csv'
519 |
520 | cxr_labels = ['Atelectasis','Cardiomegaly',
521 | 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion',
522 | 'Lung Opacity', 'No Finding','Pleural Effusion', 'Pleural Other', 'Pneumonia',
523 | 'Pneumothorax', 'Support Devices']
524 |
525 | # templates list of positive and negative template pairs
526 | cxr_templates = [("{}", "no {}")]
527 |
528 | cxr_results = run_zero_shot(cxr_labels, cxr_templates, model_path, cxr_filepath=cxr_filepath, final_label_path=final_label_path, softmax_eval=True, context_length=context_length, pretrained=pretrained, use_bootstrap=False, cutlabels=True)
529 |
530 | return cxr_labels, cxr_results[0]
531 |
532 | def validation_zero_shot(model_path, context_length=77, pretrained=False):
533 | """
534 | FUNCTION: validation_zero_shot
535 | --------------------------------------
536 | This function uses the CheXpert validation dataset to make predictions
537 | on an alternative task (ap/pa, sex) in order to tune hyperparameters.
538 |
539 | args:
540 | * model_path - string, filepath of model being evaluated
541 | * context_length (optional) - int, max number of tokens of text inputted into the model.
542 | * pretrained (optional) - bool, whether or not model uses pretrained clip weights
543 | * note - bootstrap sampling is enabled here (run_zero_shot is called with use_bootstrap=True).
544 |
545 | Returns the results and predictions from `run_zero_shot` for the sex prediction task;
546 | each result is a tuple containing a pandas dataframe
547 | for the n bootstrap samples and another pandas dataframe with the confidence intervals for each class.
548 | """
549 | cxr_sex_labels = ['Female', 'Male']
550 |
551 | cxr_sex_templates = [
552 | #'{}',
553 | # 'the patient is a {}',
554 | "the patient's sex is {}",
555 | ]
556 |
557 | # run zero shot experiment on the validation set
558 | sex_labels_path = '../../data/val_sex_labels.csv'
559 | # note: `cxr_filepath` must point to the validation chest x-ray images (.h5) and be defined in scope
560 | results = run_zero_shot(cxr_sex_labels, cxr_sex_templates, model_path, cxr_filepath=cxr_filepath, final_label_path=sex_labels_path, softmax_eval=False, context_length=context_length, pretrained=True, use_bootstrap=True, cutlabels=False)
561 | 
562 | return results
563 |
564 |
565 |
566 |
567 |
568 |
569 |
570 |
571 |
--------------------------------------------------------------------------------
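A minimal usage sketch of the zero_shot.py helpers above. The file paths, checkpoint directory, and label list below simply mirror the defaults shown in notebooks/zero_shot.ipynb and are assumptions, not requirements; `ensemble_models`, `make_true_labels`, and `evaluate` are the functions defined in this repository.

import glob
import numpy as np

from eval import evaluate
from zero_shot import ensemble_models, make_true_labels

# assumed locations, matching the notebook defaults -- adjust as needed
cxr_filepath = '../data/chexpert_test.h5'            # chest x-ray images (.h5)
cxr_true_labels_path = '../data/groundtruth.csv'     # ground truth labels (optional)
model_paths = sorted(glob.glob('../checkpoints/chexzero_weights/*.pt'))

cxr_labels = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema',
              'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion', 'Lung Opacity',
              'No Finding', 'Pleural Effusion', 'Pleural Other', 'Pneumonia',
              'Pneumothorax', 'Support Devices']
cxr_pair_template = ("{}", "no {}")  # positive/negative prompt pair

# run every checkpoint and average the softmax predictions across models
predictions, y_pred_avg = ensemble_models(
    model_paths=model_paths,
    cxr_filepath=cxr_filepath,
    cxr_labels=cxr_labels,
    cxr_pair_template=cxr_pair_template,
    cache_dir='../predictions/cached',
)

# optional: compute per-pathology metrics against the ground truth labels
y_true = make_true_labels(cxr_true_labels_path=cxr_true_labels_path, cxr_labels=cxr_labels)
cxr_results = evaluate(y_pred_avg, y_true, cxr_labels)
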
/notebooks/zero_shot.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Sample Notebook for Zero-Shot Inference with CheXzero\n",
8 | "This notebook walks through how to use CheXzero to perform zero-shot inference on a chest x-ray image dataset."
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "metadata": {},
14 | "source": [
15 | "## Import Libraries"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": 16,
21 | "metadata": {},
22 | "outputs": [
23 | {
24 | "name": "stdout",
25 | "output_type": "stream",
26 | "text": [
27 | "The autoreload extension is already loaded. To reload it, use:\n",
28 | " %reload_ext autoreload\n"
29 | ]
30 | }
31 | ],
32 | "source": [
33 | "import os\n",
34 | "import numpy as np\n",
35 | "import pandas as pd\n",
36 | "from pathlib import Path\n",
37 | "from typing import List, Tuple, Optional\n",
38 | "\n",
39 | "import sys\n",
40 | "sys.path.append('../')\n",
41 | "\n",
42 | "from eval import evaluate, bootstrap\n",
43 | "from zero_shot import make, make_true_labels, run_softmax_eval\n",
44 | "\n",
45 | "%load_ext autoreload\n",
46 | "%autoreload 2"
47 | ]
48 | },
49 | {
50 | "cell_type": "markdown",
51 | "metadata": {},
52 | "source": [
53 | "## Directories and Constants"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": 17,
59 | "metadata": {},
60 | "outputs": [
61 | {
62 | "name": "stdout",
63 | "output_type": "stream",
64 | "text": [
65 | "['../checkpoints/chexzero_weights/best_64_0.0001_original_17000_0.863.pt', '../checkpoints/chexzero_weights/best_128_5e-05_original_22000_0.855.pt', '../checkpoints/chexzero_weights/best_64_0.0001_original_35000_0.864.pt', '../checkpoints/chexzero_weights/best_64_5e-05_original_18000_0.862.pt', '../checkpoints/chexzero_weights/best_128_0.0002_original_8000_0.857.pt', '../checkpoints/chexzero_weights/best_64_5e-05_original_22000_0.864.pt', '../checkpoints/chexzero_weights/best_64_5e-05_original_16000_0.858.pt', '../checkpoints/chexzero_weights/best_128_0.0002_original_15000_0.859.pt', '../checkpoints/chexzero_weights/best_64_0.0002_original_23000_0.854.pt', '../checkpoints/chexzero_weights/best_64_0.0001_original_16000_0.861.pt']\n"
66 | ]
67 | }
68 | ],
69 | "source": [
70 | "## Define Zero Shot Labels and Templates\n",
71 | "\n",
72 | "# ----- DIRECTORIES ------ #\n",
73 | "cxr_filepath: str = '../data/chexpert_test.h5' # filepath of chest x-ray images (.h5)\n",
74 | "cxr_true_labels_path: Optional[str] = '../data/groundtruth.csv' # (optional for evaluation) if labels are provided, provide path\n",
75 | "model_dir: str = '../checkpoints/chexzero_weights' # where pretrained models are saved (.pt) \n",
76 | "predictions_dir: Path = Path('../predictions') # where to save predictions\n",
77 | "cache_dir: str = predictions_dir / \"cached\" # where to cache ensembled predictions\n",
78 | "\n",
79 | "context_length: int = 77\n",
80 | "\n",
81 | "# ------- LABELS ------ #\n",
82 | "# Define labels to query each image | will return a prediction for each label\n",
83 | "cxr_labels: List[str] = ['Atelectasis','Cardiomegaly', \n",
84 | " 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion',\n",
85 | " 'Lung Opacity', 'No Finding','Pleural Effusion', 'Pleural Other', 'Pneumonia', \n",
86 | " 'Pneumothorax', 'Support Devices']\n",
87 | "\n",
88 | "# ---- TEMPLATES ----- # \n",
89 | "# Define set of templates | see Figure 1 for more details \n",
90 | "cxr_pair_template: Tuple[str] = (\"{}\", \"no {}\")\n",
91 | "\n",
92 | "# ----- MODEL PATHS ------ #\n",
93 | "# If using ensemble, collect all model paths\n",
94 | "model_paths = []\n",
95 | "for subdir, dirs, files in os.walk(model_dir):\n",
96 | " for file in files:\n",
97 | " full_dir = os.path.join(subdir, file)\n",
98 | " model_paths.append(full_dir)\n",
99 | " \n",
100 | "print(model_paths)"
101 | ]
102 | },
103 | {
104 | "cell_type": "markdown",
105 | "metadata": {},
106 | "source": [
107 | "## Run Inference"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": 19,
113 | "metadata": {},
114 | "outputs": [],
115 | "source": [
116 | "## Run the model on the data set using ensembled models\n",
117 | "def ensemble_models(\n",
118 | " model_paths: List[str], \n",
119 | " cxr_filepath: str, \n",
120 | " cxr_labels: List[str], \n",
121 | " cxr_pair_template: Tuple[str], \n",
122 | " cache_dir: str = None, \n",
123 | " save_name: str = None,\n",
124 | ") -> Tuple[List[np.ndarray], np.ndarray]: \n",
125 | " \"\"\"\n",
126 | " Given a list of `model_paths`, ensemble model and return\n",
127 | " predictions. Caches predictions at `cache_dir` if location provided.\n",
128 | "\n",
129 | " Returns a list of each model's predictions and the averaged\n",
130 | " set of predictions.\n",
131 | " \"\"\"\n",
132 | "\n",
133 | " predictions = []\n",
134 | " model_paths = sorted(model_paths) # ensure consistency of \n",
135 | " for path in model_paths: # for each model\n",
136 | " model_name = Path(path).stem\n",
137 | "\n",
138 | " # load in model and `torch.DataLoader`\n",
139 | " model, loader = make(\n",
140 | " model_path=path, \n",
141 | " cxr_filepath=cxr_filepath, \n",
142 | " ) \n",
143 | " \n",
144 | " # path to the cached prediction\n",
145 | " if cache_dir is not None:\n",
146 | " if save_name is not None: \n",
147 | " cache_path = Path(cache_dir) / f\"{save_name}_{model_name}.npy\"\n",
148 | " else: \n",
149 | " cache_path = Path(cache_dir) / f\"{model_name}.npy\"\n",
150 | "\n",
151 | " # if prediction already cached, don't recompute prediction\n",
152 | " if cache_dir is not None and os.path.exists(cache_path): \n",
153 | " print(\"Loading cached prediction for {}\".format(model_name))\n",
154 | " y_pred = np.load(cache_path)\n",
155 | " else: # cached prediction not found, compute preds\n",
156 | " print(\"Inferring model {}\".format(path))\n",
157 | " y_pred = run_softmax_eval(model, loader, cxr_labels, cxr_pair_template)\n",
158 | " if cache_dir is not None: \n",
159 | " Path(cache_dir).mkdir(exist_ok=True, parents=True)\n",
160 | " np.save(file=cache_path, arr=y_pred)\n",
161 | " predictions.append(y_pred)\n",
162 | " \n",
163 | " # compute average predictions\n",
164 | " y_pred_avg = np.mean(predictions, axis=0)\n",
165 | " \n",
166 | " return predictions, y_pred_avg"
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": 21,
172 | "metadata": {},
173 | "outputs": [
174 | {
175 | "name": "stdout",
176 | "output_type": "stream",
177 | "text": [
178 | "Inferring model ../checkpoints/chexzero_weights/best_128_0.0002_original_15000_0.859.pt\n"
179 | ]
180 | },
181 | {
182 | "data": {
183 | "application/vnd.jupyter.widget-view+json": {
184 | "model_id": "9e09e7e227cb4d4d9871f0b06f02ff61",
185 | "version_major": 2,
186 | "version_minor": 0
187 | },
188 | "text/plain": [
189 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))"
190 | ]
191 | },
192 | "metadata": {},
193 | "output_type": "display_data"
194 | },
195 | {
196 | "name": "stdout",
197 | "output_type": "stream",
198 | "text": [
199 | "\n"
200 | ]
201 | },
202 | {
203 | "data": {
204 | "application/vnd.jupyter.widget-view+json": {
205 | "model_id": "c9603105fe154361921ade24843bb63a",
206 | "version_major": 2,
207 | "version_minor": 0
208 | },
209 | "text/plain": [
210 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))"
211 | ]
212 | },
213 | "metadata": {},
214 | "output_type": "display_data"
215 | },
216 | {
217 | "name": "stdout",
218 | "output_type": "stream",
219 | "text": [
220 | "\n"
221 | ]
222 | },
223 | {
224 | "data": {
225 | "application/vnd.jupyter.widget-view+json": {
226 | "model_id": "1ca871c5168b412eaed223fd4407f14c",
227 | "version_major": 2,
228 | "version_minor": 0
229 | },
230 | "text/plain": [
231 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))"
232 | ]
233 | },
234 | "metadata": {},
235 | "output_type": "display_data"
236 | },
237 | {
238 | "name": "stdout",
239 | "output_type": "stream",
240 | "text": [
241 | "\n"
242 | ]
243 | },
244 | {
245 | "data": {
246 | "application/vnd.jupyter.widget-view+json": {
247 | "model_id": "8feea81cea1a4e44886f78cbd3dbd95e",
248 | "version_major": 2,
249 | "version_minor": 0
250 | },
251 | "text/plain": [
252 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))"
253 | ]
254 | },
255 | "metadata": {},
256 | "output_type": "display_data"
257 | },
258 | {
259 | "name": "stdout",
260 | "output_type": "stream",
261 | "text": [
262 | "\n",
263 | "Inferring model ../checkpoints/chexzero_weights/best_128_0.0002_original_8000_0.857.pt\n"
264 | ]
265 | },
266 | {
267 | "data": {
268 | "application/vnd.jupyter.widget-view+json": {
269 | "model_id": "753a46885545435480f8a559b7a29955",
270 | "version_major": 2,
271 | "version_minor": 0
272 | },
273 | "text/plain": [
274 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))"
275 | ]
276 | },
277 | "metadata": {},
278 | "output_type": "display_data"
279 | },
280 | {
281 | "name": "stdout",
282 | "output_type": "stream",
283 | "text": [
284 | "\n"
285 | ]
286 | },
287 | {
288 | "data": {
289 | "application/vnd.jupyter.widget-view+json": {
290 | "model_id": "251f63415a314d2ca2064d1d00a028ae",
291 | "version_major": 2,
292 | "version_minor": 0
293 | },
294 | "text/plain": [
295 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))"
296 | ]
297 | },
298 | "metadata": {},
299 | "output_type": "display_data"
300 | },
301 | {
302 | "name": "stdout",
303 | "output_type": "stream",
304 | "text": [
305 | "\n"
306 | ]
307 | },
308 | {
309 | "data": {
310 | "application/vnd.jupyter.widget-view+json": {
311 | "model_id": "d1e56347bb704d8c9033d9ccec2f2015",
312 | "version_major": 2,
313 | "version_minor": 0
314 | },
315 | "text/plain": [
316 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))"
317 | ]
318 | },
319 | "metadata": {},
320 | "output_type": "display_data"
321 | },
322 | {
323 | "name": "stdout",
324 | "output_type": "stream",
325 | "text": [
326 | "\n"
327 | ]
328 | },
329 | {
330 | "data": {
331 | "application/vnd.jupyter.widget-view+json": {
332 | "model_id": "8c53cb0b856e4a7c8dc4851aff857322",
333 | "version_major": 2,
334 | "version_minor": 0
335 | },
336 | "text/plain": [
337 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))"
338 | ]
339 | },
340 | "metadata": {},
341 | "output_type": "display_data"
342 | },
343 | {
344 | "name": "stdout",
345 | "output_type": "stream",
346 | "text": [
347 | "\n",
348 | "Inferring model ../checkpoints/chexzero_weights/best_128_5e-05_original_22000_0.855.pt\n"
349 | ]
350 | },
351 | {
352 | "data": {
353 | "application/vnd.jupyter.widget-view+json": {
354 | "model_id": "a59cc7a934b64d03af79a60949c2e01f",
355 | "version_major": 2,
356 | "version_minor": 0
357 | },
358 | "text/plain": [
359 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))"
360 | ]
361 | },
362 | "metadata": {},
363 | "output_type": "display_data"
364 | },
365 | {
366 | "name": "stdout",
367 | "output_type": "stream",
368 | "text": [
369 | "\n"
370 | ]
371 | },
372 | {
373 | "data": {
374 | "application/vnd.jupyter.widget-view+json": {
375 | "model_id": "f9f4cf985d5140759ee44a39625ff30d",
376 | "version_major": 2,
377 | "version_minor": 0
378 | },
379 | "text/plain": [
380 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))"
381 | ]
382 | },
383 | "metadata": {},
384 | "output_type": "display_data"
385 | },
386 | {
387 | "name": "stdout",
388 | "output_type": "stream",
389 | "text": [
390 | "\n"
391 | ]
392 | },
393 | {
394 | "data": {
395 | "application/vnd.jupyter.widget-view+json": {
396 | "model_id": "ed6484585691494490fad5a40480dedb",
397 | "version_major": 2,
398 | "version_minor": 0
399 | },
400 | "text/plain": [
401 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))"
402 | ]
403 | },
404 | "metadata": {},
405 | "output_type": "display_data"
406 | },
407 | {
408 | "name": "stdout",
409 | "output_type": "stream",
410 | "text": [
411 | "\n"
412 | ]
413 | },
414 | {
415 | "data": {
416 | "application/vnd.jupyter.widget-view+json": {
417 | "model_id": "795518422bd44731a78254d9cf1759fb",
418 | "version_major": 2,
419 | "version_minor": 0
420 | },
421 | "text/plain": [
422 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))"
423 | ]
424 | },
425 | "metadata": {},
426 | "output_type": "display_data"
427 | },
428 | {
429 | "name": "stdout",
430 | "output_type": "stream",
431 | "text": [
432 | "\n",
433 | "Inferring model ../checkpoints/chexzero_weights/best_64_0.0001_original_16000_0.861.pt\n"
434 | ]
435 | },
436 | {
437 | "data": {
438 | "application/vnd.jupyter.widget-view+json": {
439 | "model_id": "4a565effbb694f639f0cdf633da11884",
440 | "version_major": 2,
441 | "version_minor": 0
442 | },
443 | "text/plain": [
444 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))"
445 | ]
446 | },
447 | "metadata": {},
448 | "output_type": "display_data"
449 | },
450 | {
451 | "name": "stdout",
452 | "output_type": "stream",
453 | "text": [
454 | "\n"
455 | ]
456 | },
457 | {
458 | "data": {
459 | "application/vnd.jupyter.widget-view+json": {
460 | "model_id": "ea44e753e85e442a9ee9080d52108149",
461 | "version_major": 2,
462 | "version_minor": 0
463 | },
464 | "text/plain": [
465 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))"
466 | ]
467 | },
468 | "metadata": {},
469 | "output_type": "display_data"
470 | },
471 | {
472 | "name": "stdout",
473 | "output_type": "stream",
474 | "text": [
475 | "\n"
476 | ]
477 | },
478 | {
479 | "data": {
480 | "application/vnd.jupyter.widget-view+json": {
481 | "model_id": "7f085cd056304e498a9dc251db09cb9a",
482 | "version_major": 2,
483 | "version_minor": 0
484 | },
485 | "text/plain": [
486 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))"
487 | ]
488 | },
489 | "metadata": {},
490 | "output_type": "display_data"
491 | },
492 | {
493 | "name": "stdout",
494 | "output_type": "stream",
495 | "text": [
496 | "\n"
497 | ]
498 | },
499 | {
500 | "data": {
501 | "application/vnd.jupyter.widget-view+json": {
502 | "model_id": "dbaf547d56ab47fcafd7cccde650a9ef",
503 | "version_major": 2,
504 | "version_minor": 0
505 | },
506 | "text/plain": [
507 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))"
508 | ]
509 | },
510 | "metadata": {},
511 | "output_type": "display_data"
512 | },
513 | {
514 | "name": "stdout",
515 | "output_type": "stream",
516 | "text": [
517 | "\n",
518 | "Inferring model ../checkpoints/chexzero_weights/best_64_0.0001_original_17000_0.863.pt\n"
519 | ]
520 | },
521 | {
522 | "data": {
523 | "application/vnd.jupyter.widget-view+json": {
524 | "model_id": "cd029103e56842e18feae09bb7488fd2",
525 | "version_major": 2,
526 | "version_minor": 0
527 | },
528 | "text/plain": [
529 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))"
530 | ]
531 | },
532 | "metadata": {},
533 | "output_type": "display_data"
534 | },
535 | {
536 | "name": "stdout",
537 | "output_type": "stream",
538 | "text": [
539 | "\n"
540 | ]
541 | },
542 | {
543 | "data": {
544 | "application/vnd.jupyter.widget-view+json": {
545 | "model_id": "7534cec7de8e4f0d9cfc22776a94d52d",
546 | "version_major": 2,
547 | "version_minor": 0
548 | },
549 | "text/plain": [
550 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))"
551 | ]
552 | },
553 | "metadata": {},
554 | "output_type": "display_data"
555 | },
556 | {
557 | "name": "stdout",
558 | "output_type": "stream",
559 | "text": [
560 | "\n"
561 | ]
562 | },
563 | {
564 | "data": {
565 | "application/vnd.jupyter.widget-view+json": {
566 | "model_id": "71ace33603504ca692a81330ad2c5c2a",
567 | "version_major": 2,
568 | "version_minor": 0
569 | },
570 | "text/plain": [
571 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))"
572 | ]
573 | },
574 | "metadata": {},
575 | "output_type": "display_data"
576 | },
577 | {
578 | "name": "stdout",
579 | "output_type": "stream",
580 | "text": [
581 | "\n"
582 | ]
583 | },
584 | {
585 | "data": {
586 | "application/vnd.jupyter.widget-view+json": {
587 | "model_id": "e5b30aed5f47451b80bbc61da9d45ad6",
588 | "version_major": 2,
589 | "version_minor": 0
590 | },
591 | "text/plain": [
592 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))"
593 | ]
594 | },
595 | "metadata": {},
596 | "output_type": "display_data"
597 | },
598 | {
599 | "name": "stdout",
600 | "output_type": "stream",
601 | "text": [
602 | "\n",
603 | "Inferring model ../checkpoints/chexzero_weights/best_64_0.0001_original_35000_0.864.pt\n"
604 | ]
605 | },
606 | {
607 | "data": {
608 | "application/vnd.jupyter.widget-view+json": {
609 | "model_id": "e9908db37e064d06af19924b9b02e1de",
610 | "version_major": 2,
611 | "version_minor": 0
612 | },
613 | "text/plain": [
614 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))"
615 | ]
616 | },
617 | "metadata": {},
618 | "output_type": "display_data"
619 | },
620 | {
621 | "name": "stdout",
622 | "output_type": "stream",
623 | "text": [
624 | "\n"
625 | ]
626 | },
627 | {
628 | "data": {
629 | "application/vnd.jupyter.widget-view+json": {
630 | "model_id": "3bf4a3799ba142babd924ed5fe13eaf2",
631 | "version_major": 2,
632 | "version_minor": 0
633 | },
634 | "text/plain": [
635 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))"
636 | ]
637 | },
638 | "metadata": {},
639 | "output_type": "display_data"
640 | },
641 | {
642 | "name": "stdout",
643 | "output_type": "stream",
644 | "text": [
645 | "\n"
646 | ]
647 | },
648 | {
649 | "data": {
650 | "application/vnd.jupyter.widget-view+json": {
651 | "model_id": "87c65ae5d98b4679907d25cb98157485",
652 | "version_major": 2,
653 | "version_minor": 0
654 | },
655 | "text/plain": [
656 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))"
657 | ]
658 | },
659 | "metadata": {},
660 | "output_type": "display_data"
661 | },
662 | {
663 | "name": "stdout",
664 | "output_type": "stream",
665 | "text": [
666 | "\n"
667 | ]
668 | },
669 | {
670 | "data": {
671 | "application/vnd.jupyter.widget-view+json": {
672 | "model_id": "52db175c5f3b41148c56b8ecab7f789b",
673 | "version_major": 2,
674 | "version_minor": 0
675 | },
676 | "text/plain": [
677 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))"
678 | ]
679 | },
680 | "metadata": {},
681 | "output_type": "display_data"
682 | },
683 | {
684 | "name": "stdout",
685 | "output_type": "stream",
686 | "text": [
687 | "\n",
688 | "Inferring model ../checkpoints/chexzero_weights/best_64_0.0002_original_23000_0.854.pt\n"
689 | ]
690 | },
691 | {
692 | "data": {
693 | "application/vnd.jupyter.widget-view+json": {
694 | "model_id": "e000f01f6d394e3584c8aebc74d965f0",
695 | "version_major": 2,
696 | "version_minor": 0
697 | },
698 | "text/plain": [
699 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))"
700 | ]
701 | },
702 | "metadata": {},
703 | "output_type": "display_data"
704 | },
705 | {
706 | "name": "stdout",
707 | "output_type": "stream",
708 | "text": [
709 | "\n"
710 | ]
711 | },
712 | {
713 | "data": {
714 | "application/vnd.jupyter.widget-view+json": {
715 | "model_id": "92429fac05ce4e87ac8093ed1e6b6e0b",
716 | "version_major": 2,
717 | "version_minor": 0
718 | },
719 | "text/plain": [
720 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))"
721 | ]
722 | },
723 | "metadata": {},
724 | "output_type": "display_data"
725 | },
726 | {
727 | "name": "stdout",
728 | "output_type": "stream",
729 | "text": [
730 | "\n"
731 | ]
732 | },
733 | {
734 | "data": {
735 | "application/vnd.jupyter.widget-view+json": {
736 | "model_id": "6daaede09a1745da99b23784ab1e65da",
737 | "version_major": 2,
738 | "version_minor": 0
739 | },
740 | "text/plain": [
741 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))"
742 | ]
743 | },
744 | "metadata": {},
745 | "output_type": "display_data"
746 | },
747 | {
748 | "name": "stdout",
749 | "output_type": "stream",
750 | "text": [
751 | "\n"
752 | ]
753 | },
754 | {
755 | "data": {
756 | "application/vnd.jupyter.widget-view+json": {
757 | "model_id": "50eb0ec5ac9f4561a6eee9e0b64e325e",
758 | "version_major": 2,
759 | "version_minor": 0
760 | },
761 | "text/plain": [
762 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))"
763 | ]
764 | },
765 | "metadata": {},
766 | "output_type": "display_data"
767 | },
768 | {
769 | "name": "stdout",
770 | "output_type": "stream",
771 | "text": [
772 | "\n",
773 | "Inferring model ../checkpoints/chexzero_weights/best_64_5e-05_original_16000_0.858.pt\n"
774 | ]
775 | },
776 | {
777 | "data": {
778 | "application/vnd.jupyter.widget-view+json": {
779 | "model_id": "75ed921a1c384da082f47aaae0d60db9",
780 | "version_major": 2,
781 | "version_minor": 0
782 | },
783 | "text/plain": [
784 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))"
785 | ]
786 | },
787 | "metadata": {},
788 | "output_type": "display_data"
789 | },
790 | {
791 | "name": "stdout",
792 | "output_type": "stream",
793 | "text": [
794 | "\n"
795 | ]
796 | },
797 | {
798 | "data": {
799 | "application/vnd.jupyter.widget-view+json": {
800 | "model_id": "02ea0f79065342f7bd0c5ed353c1e6e4",
801 | "version_major": 2,
802 | "version_minor": 0
803 | },
804 | "text/plain": [
805 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))"
806 | ]
807 | },
808 | "metadata": {},
809 | "output_type": "display_data"
810 | },
811 | {
812 | "name": "stdout",
813 | "output_type": "stream",
814 | "text": [
815 | "\n"
816 | ]
817 | },
818 | {
819 | "data": {
820 | "application/vnd.jupyter.widget-view+json": {
821 | "model_id": "f861a35947ab4068b05514f51c75ea22",
822 | "version_major": 2,
823 | "version_minor": 0
824 | },
825 | "text/plain": [
826 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))"
827 | ]
828 | },
829 | "metadata": {},
830 | "output_type": "display_data"
831 | },
832 | {
833 | "name": "stdout",
834 | "output_type": "stream",
835 | "text": [
836 | "\n"
837 | ]
838 | },
839 | {
840 | "data": {
841 | "application/vnd.jupyter.widget-view+json": {
842 | "model_id": "7249bae84ec04a2392614d30d832c645",
843 | "version_major": 2,
844 | "version_minor": 0
845 | },
846 | "text/plain": [
847 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))"
848 | ]
849 | },
850 | "metadata": {},
851 | "output_type": "display_data"
852 | },
853 | {
854 | "name": "stdout",
855 | "output_type": "stream",
856 | "text": [
857 | "\n",
858 | "Inferring model ../checkpoints/chexzero_weights/best_64_5e-05_original_18000_0.862.pt\n"
859 | ]
860 | },
861 | {
862 | "data": {
863 | "application/vnd.jupyter.widget-view+json": {
864 | "model_id": "75285bd7d3884a5fad36d80483421cd1",
865 | "version_major": 2,
866 | "version_minor": 0
867 | },
868 | "text/plain": [
869 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))"
870 | ]
871 | },
872 | "metadata": {},
873 | "output_type": "display_data"
874 | },
875 | {
876 | "name": "stdout",
877 | "output_type": "stream",
878 | "text": [
879 | "\n"
880 | ]
881 | },
882 | {
883 | "data": {
884 | "application/vnd.jupyter.widget-view+json": {
885 | "model_id": "6ce671a4cc6c476f82bb657a22402f57",
886 | "version_major": 2,
887 | "version_minor": 0
888 | },
889 | "text/plain": [
890 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))"
891 | ]
892 | },
893 | "metadata": {},
894 | "output_type": "display_data"
895 | },
896 | {
897 | "name": "stdout",
898 | "output_type": "stream",
899 | "text": [
900 | "\n"
901 | ]
902 | },
903 | {
904 | "data": {
905 | "application/vnd.jupyter.widget-view+json": {
906 | "model_id": "d11b53b9ecee49858b365ebf9a1104de",
907 | "version_major": 2,
908 | "version_minor": 0
909 | },
910 | "text/plain": [
911 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))"
912 | ]
913 | },
914 | "metadata": {},
915 | "output_type": "display_data"
916 | },
917 | {
918 | "name": "stdout",
919 | "output_type": "stream",
920 | "text": [
921 | "\n"
922 | ]
923 | },
924 | {
925 | "data": {
926 | "application/vnd.jupyter.widget-view+json": {
927 | "model_id": "f83be71ddfa8430b868f5c00e8c708be",
928 | "version_major": 2,
929 | "version_minor": 0
930 | },
931 | "text/plain": [
932 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))"
933 | ]
934 | },
935 | "metadata": {},
936 | "output_type": "display_data"
937 | },
938 | {
939 | "name": "stdout",
940 | "output_type": "stream",
941 | "text": [
942 | "\n",
943 | "Inferring model ../checkpoints/chexzero_weights/best_64_5e-05_original_22000_0.864.pt\n"
944 | ]
945 | },
946 | {
947 | "data": {
948 | "application/vnd.jupyter.widget-view+json": {
949 | "model_id": "48b10dc6aa524f5081861eef3a1fdf3f",
950 | "version_major": 2,
951 | "version_minor": 0
952 | },
953 | "text/plain": [
954 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))"
955 | ]
956 | },
957 | "metadata": {},
958 | "output_type": "display_data"
959 | },
960 | {
961 | "name": "stdout",
962 | "output_type": "stream",
963 | "text": [
964 | "\n"
965 | ]
966 | },
967 | {
968 | "data": {
969 | "application/vnd.jupyter.widget-view+json": {
970 | "model_id": "d1f822091f93432eaa5f29c1d62d0e29",
971 | "version_major": 2,
972 | "version_minor": 0
973 | },
974 | "text/plain": [
975 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))"
976 | ]
977 | },
978 | "metadata": {},
979 | "output_type": "display_data"
980 | },
981 | {
982 | "name": "stdout",
983 | "output_type": "stream",
984 | "text": [
985 | "\n"
986 | ]
987 | },
988 | {
989 | "data": {
990 | "application/vnd.jupyter.widget-view+json": {
991 | "model_id": "2ab6afefd29a4e008df820e3442def05",
992 | "version_major": 2,
993 | "version_minor": 0
994 | },
995 | "text/plain": [
996 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=14.0), HTML(value='')))"
997 | ]
998 | },
999 | "metadata": {},
1000 | "output_type": "display_data"
1001 | },
1002 | {
1003 | "name": "stdout",
1004 | "output_type": "stream",
1005 | "text": [
1006 | "\n"
1007 | ]
1008 | },
1009 | {
1010 | "data": {
1011 | "application/vnd.jupyter.widget-view+json": {
1012 | "model_id": "fff116a1443d4e9b96f41fbe53ebd1ac",
1013 | "version_major": 2,
1014 | "version_minor": 0
1015 | },
1016 | "text/plain": [
1017 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))"
1018 | ]
1019 | },
1020 | "metadata": {},
1021 | "output_type": "display_data"
1022 | },
1023 | {
1024 | "name": "stdout",
1025 | "output_type": "stream",
1026 | "text": [
1027 | "\n"
1028 | ]
1029 | }
1030 | ],
1031 | "source": [
1032 | "predictions, y_pred_avg = ensemble_models(\n",
1033 | " model_paths=model_paths, \n",
1034 | " cxr_filepath=cxr_filepath, \n",
1035 | " cxr_labels=cxr_labels, \n",
1036 | " cxr_pair_template=cxr_pair_template, \n",
1037 | " cache_dir=cache_dir,\n",
1038 | ")"
1039 | ]
1040 | },
1041 | {
1042 | "cell_type": "code",
1043 | "execution_count": 22,
1044 | "metadata": {},
1045 | "outputs": [],
1046 | "source": [
1047 | "# save averaged preds\n",
1048 | "pred_name = \"chexpert_preds.npy\" # add name of preds\n",
1049 | "predictions_dir = predictions_dir / pred_name\n",
1050 | "np.save(file=predictions_dir, arr=y_pred_avg)"
1051 | ]
1052 | },
1053 | {
1054 | "cell_type": "markdown",
1055 | "metadata": {},
1056 | "source": [
1057 | "## (Optional) Evaluate Results\n",
1058 | "If ground truth labels are available, compute AUC on each pathology to evaluate the performance of the zero-shot model. "
1059 | ]
1060 | },
1061 | {
1062 | "cell_type": "code",
1063 | "execution_count": 23,
1064 | "metadata": {},
1065 | "outputs": [
1066 | {
1067 | "data": {
1068 | "application/vnd.jupyter.widget-view+json": {
1069 | "model_id": "a338f58b25a94d68a3be00565cbaca39",
1070 | "version_major": 2,
1071 | "version_minor": 0
1072 | },
1073 | "text/plain": [
1074 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))"
1075 | ]
1076 | },
1077 | "metadata": {},
1078 | "output_type": "display_data"
1079 | },
1080 | {
1081 | "name": "stdout",
1082 | "output_type": "stream",
1083 | "text": [
1084 | "\n"
1085 | ]
1086 | }
1087 | ],
1088 | "source": [
1089 | "# make test_true\n",
1090 | "test_pred = y_pred_avg\n",
1091 | "test_true = make_true_labels(cxr_true_labels_path=cxr_true_labels_path, cxr_labels=cxr_labels)\n",
1092 | "\n",
1093 | "# evaluate model\n",
1094 | "cxr_results = evaluate(test_pred, test_true, cxr_labels)\n",
1095 | "\n",
1096 | "# boostrap evaluations for 95% confidence intervals\n",
1097 | "bootstrap_results = bootstrap(test_pred, test_true, cxr_labels)"
1098 | ]
1099 | },
1100 | {
1101 | "cell_type": "code",
1102 | "execution_count": 25,
1103 | "metadata": {},
1104 | "outputs": [
1105 | {
1106 | "data": {
1107 | "text/html": [
1108 | "
\n",
1109 | "\n",
1122 | "
\n",
1123 | " \n",
1124 | " \n",
1125 | " | \n",
1126 | " Atelectasis_auc | \n",
1127 | " Cardiomegaly_auc | \n",
1128 | " Consolidation_auc | \n",
1129 | " Edema_auc | \n",
1130 | " Enlarged Cardiomediastinum_auc | \n",
1131 | " Fracture_auc | \n",
1132 | " Lung Lesion_auc | \n",
1133 | " Lung Opacity_auc | \n",
1134 | " No Finding_auc | \n",
1135 | " Pleural Effusion_auc | \n",
1136 | " Pleural Other_auc | \n",
1137 | " Pneumonia_auc | \n",
1138 | " Pneumothorax_auc | \n",
1139 | " Support Devices_auc | \n",
1140 | "
\n",
1141 | " \n",
1142 | " \n",
1143 | " \n",
1144 | " | mean | \n",
1145 | " 0.8118 | \n",
1146 | " 0.9132 | \n",
1147 | " 0.8901 | \n",
1148 | " 0.8994 | \n",
1149 | " 0.9160 | \n",
1150 | " 0.5603 | \n",
1151 | " 0.7360 | \n",
1152 | " 0.9213 | \n",
1153 | " 0.0700 | \n",
1154 | " 0.9317 | \n",
1155 | " 0.6025 | \n",
1156 | " 0.7798 | \n",
1157 | " 0.6520 | \n",
1158 | " 0.7735 | \n",
1159 | "
\n",
1160 | " \n",
1161 | " | lower | \n",
1162 | " 0.7720 | \n",
1163 | " 0.8849 | \n",
1164 | " 0.8201 | \n",
1165 | " 0.8662 | \n",
1166 | " 0.8912 | \n",
1167 | " 0.2646 | \n",
1168 | " 0.5658 | \n",
1169 | " 0.8961 | \n",
1170 | " 0.0451 | \n",
1171 | " 0.9053 | \n",
1172 | " 0.4608 | \n",
1173 | " 0.5695 | \n",
1174 | " 0.4854 | \n",
1175 | " 0.7310 | \n",
1176 | "
\n",
1177 | " \n",
1178 | " | upper | \n",
1179 | " 0.8479 | \n",
1180 | " 0.9367 | \n",
1181 | " 0.9470 | \n",
1182 | " 0.9295 | \n",
1183 | " 0.9375 | \n",
1184 | " 0.8725 | \n",
1185 | " 0.8779 | \n",
1186 | " 0.9426 | \n",
1187 | " 0.0952 | \n",
1188 | " 0.9536 | \n",
1189 | " 0.8855 | \n",
1190 | " 0.9483 | \n",
1191 | " 0.8243 | \n",
1192 | " 0.8130 | \n",
1193 | "
\n",
1194 | " \n",
1195 | "
\n",
1196 | "
"
1197 | ],
1198 | "text/plain": [
1199 | " Atelectasis_auc Cardiomegaly_auc Consolidation_auc Edema_auc \\\n",
1200 | "mean 0.8118 0.9132 0.8901 0.8994 \n",
1201 | "lower 0.7720 0.8849 0.8201 0.8662 \n",
1202 | "upper 0.8479 0.9367 0.9470 0.9295 \n",
1203 | "\n",
1204 | " Enlarged Cardiomediastinum_auc Fracture_auc Lung Lesion_auc \\\n",
1205 | "mean 0.9160 0.5603 0.7360 \n",
1206 | "lower 0.8912 0.2646 0.5658 \n",
1207 | "upper 0.9375 0.8725 0.8779 \n",
1208 | "\n",
1209 | " Lung Opacity_auc No Finding_auc Pleural Effusion_auc \\\n",
1210 | "mean 0.9213 0.0700 0.9317 \n",
1211 | "lower 0.8961 0.0451 0.9053 \n",
1212 | "upper 0.9426 0.0952 0.9536 \n",
1213 | "\n",
1214 | " Pleural Other_auc Pneumonia_auc Pneumothorax_auc Support Devices_auc \n",
1215 | "mean 0.6025 0.7798 0.6520 0.7735 \n",
1216 | "lower 0.4608 0.5695 0.4854 0.7310 \n",
1217 | "upper 0.8855 0.9483 0.8243 0.8130 "
1218 | ]
1219 | },
1220 | "execution_count": 25,
1221 | "metadata": {},
1222 | "output_type": "execute_result"
1223 | }
1224 | ],
1225 | "source": [
1226 | "# display AUC with confidence intervals\n",
1227 | "bootstrap_results[1]"
1228 | ]
1229 | },
1230 | {
1231 | "cell_type": "code",
1232 | "execution_count": null,
1233 | "metadata": {},
1234 | "outputs": [],
1235 | "source": []
1236 | }
1237 | ],
1238 | "metadata": {
1239 | "interpreter": {
1240 | "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
1241 | },
1242 | "kernelspec": {
1243 | "display_name": "Python 3",
1244 | "language": "python",
1245 | "name": "python3"
1246 | },
1247 | "language_info": {
1248 | "codemirror_mode": {
1249 | "name": "ipython",
1250 | "version": 3
1251 | },
1252 | "file_extension": ".py",
1253 | "mimetype": "text/x-python",
1254 | "name": "python",
1255 | "nbconvert_exporter": "python",
1256 | "pygments_lexer": "ipython3",
1257 | "version": "3.8.5"
1258 | }
1259 | },
1260 | "nbformat": 4,
1261 | "nbformat_minor": 2
1262 | }
1263 |
--------------------------------------------------------------------------------
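As a follow-up to the notebook above, a short hedged sketch of reusing the predictions it writes out: `ensemble_models` caches one .npy file per checkpoint under the cache directory, and the notebook saves the averaged predictions as chexpert_preds.npy. The paths below follow the notebook defaults and are assumptions.

import glob
import numpy as np

cache_dir = '../predictions/cached'        # per-model predictions cached by ensemble_models
cached_paths = sorted(glob.glob(f'{cache_dir}/*.npy'))

# each cached file holds one checkpoint's prediction array from run_softmax_eval
per_model_preds = [np.load(p) for p in cached_paths]

# averaging across models reproduces what ensemble_models returns as y_pred_avg
y_pred_avg = np.mean(per_model_preds, axis=0)

# the notebook saves this same array for downstream evaluation
np.save('../predictions/chexpert_preds.npy', y_pred_avg)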