├── LICENSE.pdf ├── README.md ├── environment.yml ├── figures └── VisualCheXbert_Figure.png ├── requirements.txt └── src ├── bert_tokenizer.py ├── constants.py ├── datasets ├── impressions_dataset.py └── unlabeled_dataset.py ├── label.py ├── models └── bert_labeler.py ├── sample_reports ├── larger_sample_reports.csv └── sample_reports.csv └── utils.py /LICENSE.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordmlgroup/VisualCheXbert/c7cb58d8fc0197a1370e587ca78b8e183ec2dc8e/LICENSE.pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VisualCheXbert: Addressing the Discrepancy Between Radiology Report Labels and Image Labels 2 | VisualCheXbert is an automated deep learning-based chest radiology report labeler that labels reports for the following 14 medical observations: Fracture, Consolidation, Enlarged Cardiomediastinum, No Finding, Pleural Other, Cardiomegaly, Pneumothorax, Atelectasis, Support Devices, Edema, Pleural Effusion, Lung Lesion, Lung Opacity, and Pneumonia. VisualCheXbert aims to produce better labels for computer vision models using only the textual reports, and it improves on previous report-labeling approaches when evaluated against labels annotated by radiologists who examined the associated chest X-ray images. 3 | 4 | Paper (Accepted to ACM-CHIL 2021): https://arxiv.org/pdf/2102.11467.pdf
5 | Note that due to class imbalance in our test set, our paper reports scores on the following subset of the 14 conditions: Atelectasis, Cardiomegaly, Edema, Pleural Effusion, Enlarged Cardiomediastinum, Lung Opacity, Support Devices, and No Finding. 6 | 7 | 8 | License from us (For Commercial Purposes): Please use the same licensing contact for VisualCheXbert as for CheXbert, which can be found here: http://techfinder2.stanford.edu/technology_detail.php?ID=43869. 9 | 10 | ## Abstract 11 | 12 | Automatic extraction of medical conditions from free-text radiology reports is critical for supervising computer vision models to interpret medical images. In this work, we show that radiologists labeling reports significantly disagree with radiologists labeling corresponding chest X-ray images, which reduces the quality of report labels as proxies for image labels. We develop and evaluate methods to produce labels from radiology reports that have better agreement with radiologists labeling images. Our best performing method, called VisualCheXbert, uses a biomedically-pretrained BERT model to directly map from a radiology report to the image labels, with a supervisory signal determined by a computer vision model trained to detect medical conditions from chest X-ray images. We find that VisualCheXbert outperforms an approach using an existing radiology report labeler by an average F1 score of 0.14 (95% CI 0.12, 0.17). We also find that VisualCheXbert better agrees with radiologists labeling chest X-ray images than do radiologists labeling the corresponding radiology reports by an average F1 score across several medical conditions of between 0.12 (95% CI 0.09, 0.15) and 0.21 (95% CI 0.18, 0.24). 13 | 14 | ![The VisualCheXbert training approach](figures/VisualCheXbert_Figure.png) 15 | 16 | ## Prerequisites 17 | (Recommended) With Python 3.7 or higher, install the requirements using pip. 18 | 19 | ``` 20 | pip install -r requirements.txt 21 | ``` 22 | 23 | OR 24 | 25 | Create the conda environment. 26 | 27 | ``` 28 | conda env create -f environment.yml 29 | ``` 30 | 31 | Activate the environment. 32 | 33 | ``` 34 | conda activate visualCheXbert 35 | ``` 36 | 37 | By default, all available GPUs are used for labeling in parallel; if no GPU is available, the CPU is used. You can control which GPUs are used by setting CUDA_VISIBLE_DEVICES appropriately. The default batch size is 18 and can be changed in src/constants.py. 38 | 39 | ## Checkpoint download 40 | 41 | Download our trained model checkpoints here: https://drive.google.com/file/d/1QmwAurEmiXc-_uF0JrgAbWCxpF-pl52s/view?usp=drive_link. 42 | 43 | ## Usage 44 | 45 | ### Label reports with VisualCheXbert 46 | 47 | 1. Put all reports in a csv file under the column name "Report Impression" (see src/sample_reports/sample_reports.csv for an example). The path to this csv is {path to reports}. 48 | 2. Download and unzip the checkpoint folder in the src directory (see above section). The path to this folder is {path to checkpoint folder}. 49 | 3. Navigate to the src directory and run the following command, where the path to your desired output folder is {path to output dir}: 50 | 51 | ``` 52 | python label.py -d={path to reports} -o={path to output dir} -c={path to checkpoint folder} 53 | ``` 54 | 55 | The output file with labeled reports is {path to output dir}/labeled_reports.csv.
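As a quick sanity check, the output csv can be inspected with pandas; the path in this sketch is a placeholder for your {path to output dir}:

```
import pandas as pd

# Placeholder path: substitute your {path to output dir}.
df = pd.read_csv("output_dir/labeled_reports.csv")

# One row per report: the report text plus one binary column per condition.
print(df.columns.tolist())  # ['Report Impression', 'Enlarged Cardiomediastinum', ...]
print(df["Cardiomegaly"].value_counts())
```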
Note that the output of VisualCheXbert is binary: a label of 1 indicates that a condition is present in the associated X-ray, and a label of 0 indicates that it is absent. 56 | 57 | Run the following for descriptions of all command line arguments: 58 | 59 | ``` 60 | python label.py -h 61 | ``` 62 | 63 | **Ignore any error messages about the size of the report exceeding 512 tokens. All reports are automatically cut off at 512 tokens.** 64 | 65 | ## Label Convention 66 | 67 | The labeler outputs the following numbers corresponding to classes: 68 | - Positive: 1 69 | - Negative: 0 70 | 71 | ## Citation 72 | 73 | If you use the VisualCheXbert labeler in your work, please cite our paper: 74 | 75 | ``` 76 | @article{jain2021visualchexbert, 77 | title={VisualCheXbert: Addressing the Discrepancy Between Radiology Report Labels and Image Labels}, 78 | author={Jain, Saahil and Smit, Akshay and Truong, Steven QH and Nguyen, Chanh DT and Huynh, Minh-Thanh and Jain, Mudit and Young, Victoria A and Ng, Andrew Y and Lungren, Matthew P and Rajpurkar, Pranav}, 79 | journal={arXiv preprint arXiv:2102.11467}, 80 | year={2021} 81 | } 82 | ``` 83 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: visualCheXbert 2 | channels: 3 | - defaults 4 | dependencies: 5 | - blas=1.0 6 | - ca-certificates=2021.1.19 7 | - certifi=2020.12.5 8 | - intel-openmp=2019.4 9 | - libcxx=10.0.0 10 | - libedit=3.1.20191231 11 | - libffi=3.3 12 | - libgfortran=3.0.1 13 | - llvm-openmp=10.0.0 14 | - mkl=2019.4 15 | - mkl-service=2.3.0 16 | - mkl_fft=1.3.0 17 | - mkl_random=1.1.1 18 | - ncurses=6.2 19 | - numpy-base=1.19.2 20 | - openssl=1.1.1j 21 | - pip=21.0.1 22 | - python=3.7.7 23 | - readline=8.1 24 | - scipy=1.6.1 25 | - setuptools=52.0.0 26 | - six=1.15.0 27 | - sqlite=3.33.0 28 | - tk=8.6.10 29 | - wheel=0.36.2 30 | - xz=5.2.5 31 | - zlib=1.2.11 32 | - pip: 33 | - chardet==4.0.0 34 | - click==7.1.2 35 | - filelock==3.0.12 36 | - idna==2.10 37 | - joblib==0.14.1 38 | - numpy==1.20.1 39 | - packaging==20.9 40 | - pandas==1.2.2 41 | - pyparsing==2.4.7 42 | - python-dateutil==2.8.1 43 | - pytz==2021.1 44 | - regex==2020.11.13 45 | - requests==2.25.1 46 | - sacremoses==0.0.43 47 | - scikit-learn==0.22.1 48 | - sentencepiece==0.1.95 49 | - sklearn==0.0 50 | - threadpoolctl==2.1.0 51 | - tokenizers==0.8.1rc1 52 | - torch==1.4.0 53 | - tqdm==4.57.0 54 | - transformers==3.0.2 55 | - typing-extensions==3.7.4.3 56 | - urllib3==1.26.3 57 | 58 | -------------------------------------------------------------------------------- /figures/VisualCheXbert_Figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordmlgroup/VisualCheXbert/c7cb58d8fc0197a1370e587ca78b8e183ec2dc8e/figures/VisualCheXbert_Figure.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.18.2 2 | pandas==0.25.3 3 | scikit-learn==0.22.1 4 | tokenizers==0.5.2 5 | torch==1.4.0 6 | tqdm==4.38.0 7 | transformers==2.5.1 8 | -------------------------------------------------------------------------------- /src/bert_tokenizer.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import transformers 3 | from transformers import
BertTokenizer, AutoTokenizer 4 | import json 5 | from tqdm import tqdm 6 | 7 | def get_impressions_from_csv(path): 8 | df = pd.read_csv(path) 9 | imp = df['Report Impression'] 10 | imp = imp.str.strip() 11 | imp = imp.replace('\n',' ', regex=True) 12 | imp = imp.replace('\s+', ' ', regex=True) 13 | imp = imp.str.strip() 14 | return imp 15 | 16 | def tokenize(impressions, tokenizer): 17 | new_impressions = [] 18 | print("\nTokenizing report impressions. All reports are cut off at 512 tokens.") 19 | for i in tqdm(range(impressions.shape[0])): 20 | tokenized_imp = tokenizer.tokenize(impressions.iloc[i]) 21 | if tokenized_imp: #not an empty report 22 | res = tokenizer.encode_plus(tokenized_imp)['input_ids'] 23 | if len(res) > 512: #length exceeds maximum size 24 | #print("report length bigger than 512") 25 | res = res[:511] + [tokenizer.sep_token_id] 26 | new_impressions.append(res) 27 | else: #an empty report 28 | new_impressions.append([tokenizer.cls_token_id, tokenizer.sep_token_id]) 29 | return new_impressions 30 | 31 | def load_list(path): 32 | with open(path, 'r') as filehandle: 33 | impressions = json.load(filehandle) 34 | return impressions 35 | 36 | if __name__ == "__main__": 37 | tokenizer = BertTokenizer.from_pretrained('/data3/aihc-winter20-chexbert/bluebert/pretrain_repo') 38 | #tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 39 | #tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT") 40 | #tokenizer = AutoTokenizer.from_pretrained('xlnet-base-cased') 41 | 42 | impressions = get_impressions_from_csv('/data3/aihc-winter20-chexbert/chexpert_data/vision_test_gt.csv') 43 | new_impressions = tokenize(impressions, tokenizer) 44 | with open('/data3/aihc-winter20-chexbert/bluebert/vision_labels/impressions_lists/vision_test', 'w') as filehandle: 45 | json.dump(new_impressions, filehandle) 46 | -------------------------------------------------------------------------------- /src/constants.py: -------------------------------------------------------------------------------- 1 | NUM_EPOCHS = 8 2 | BATCH_SIZE = 18 #Refer to the load_unlabeled_data docstring in label.py before changing this!
3 | NUM_WORKERS = 4 #A value of 0 means the main process loads the data 4 | LEARNING_RATE = 2e-5 5 | PAD_IDX = 0 #padding index as required by the tokenizer 6 | LOG_EVERY = 200 #iterations after which to log status 7 | VALID_NITER = 2000 #iterations after which to evaluate model and possibly save 8 | TEMP = 1.0 #softmax temperature for self training 9 | 10 | #CONDITIONS is a list of all 14 medical observations 11 | CONDITIONS = ['Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', 12 | 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis', 13 | 'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 'Fracture', 14 | 'Support Devices', 'No Finding'] 15 | #List of conditions to evaluate model's performance during validation 16 | EVAL_CONDITIONS = set(CONDITIONS) 17 | CLASS_MAPPING = {0: "Blank", 1: "Positive", 2: "Negative", 3: "Uncertain"} 18 | -------------------------------------------------------------------------------- /src/datasets/impressions_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | import numpy as np 4 | from bert_tokenizer import load_list 5 | from torch.utils.data import Dataset, DataLoader 6 | 7 | class ImpressionsDataset(Dataset): 8 | """The dataset to contain report impressions and their labels.""" 9 | 10 | def __init__(self, csv_path, list_path): 11 | """ Initialize the dataset object 12 | @param csv_path (string): path to the csv file containing labels 13 | @param list_path (string): path to the list of encoded impressions 14 | """ 15 | self.df = pd.read_csv(csv_path) 16 | self.df = self.df[['Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', 17 | 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis', 18 | 'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 'Fracture', 19 | 'Support Devices', 'No Finding']] 20 | self.df.fillna(0.0, inplace=True) #only have two labels, positive and negative 21 | self.encoded_imp = load_list(path=list_path) 22 | 23 | def __len__(self): 24 | """Compute the length of the dataset 25 | 26 | @return (int): size of the dataframe 27 | """ 28 | return self.df.shape[0] 29 | 30 | def __getitem__(self, idx): 31 | """ Functionality to index into the dataset 32 | @param idx (int): Integer index into the dataset 33 | 34 | @return (dictionary): Has keys 'imp', 'label' and 'len'. The value of 'imp' is 35 | a LongTensor of an encoded impression. The value of 'label' 36 | is a LongTensor containing the labels and the value of 37 | 'len' is an integer representing the length of imp's value 38 | """ 39 | if torch.is_tensor(idx): 40 | idx = idx.tolist() 41 | label = self.df.iloc[idx].to_numpy() 42 | label = torch.Tensor(label) 43 | imp = self.encoded_imp[idx] 44 | imp = torch.LongTensor(imp) 45 | return {"imp": imp, "label": label, "len": imp.shape[0]} 46 | -------------------------------------------------------------------------------- /src/datasets/unlabeled_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | import numpy as np 4 | from transformers import BertTokenizer 5 | import bert_tokenizer 6 | from torch.utils.data import Dataset, DataLoader 7 | 8 | class UnlabeledDataset(Dataset): 9 | """The dataset to contain report impressions without any labels.""" 10 | 11 | def __init__(self, csv_path): 12 | """ Initialize the dataset object 13 | @param csv_path (string): path to the csv file containing the reports.
It 14 | should have a column named "Report Impression" 15 | """ 16 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 17 | impressions = bert_tokenizer.get_impressions_from_csv(csv_path) 18 | self.encoded_imp = bert_tokenizer.tokenize(impressions, tokenizer) 19 | 20 | def __len__(self): 21 | """Compute the length of the dataset 22 | 23 | @return (int): size of the dataframe 24 | """ 25 | return len(self.encoded_imp) 26 | 27 | def __getitem__(self, idx): 28 | """ Functionality to index into the dataset 29 | @param idx (int): Integer index into the dataset 30 | 31 | @return (dictionary): Has keys 'imp' and 'len'. The value of 'imp' is 32 | a LongTensor of an encoded impression and the value of 33 | 'len' is an integer representing the length of imp's 34 | value. This dataset has no labels 35 | """ 36 | if torch.is_tensor(idx): 37 | idx = idx.tolist() 38 | imp = self.encoded_imp[idx] 39 | imp = torch.LongTensor(imp) 40 | return {"imp": imp, "len": imp.shape[0]} 41 | -------------------------------------------------------------------------------- /src/label.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | import pandas as pd 6 | import numpy as np 7 | import utils 8 | from models.bert_labeler import bert_labeler 9 | from bert_tokenizer import tokenize 10 | from transformers import BertTokenizer 11 | from collections import OrderedDict 12 | from datasets.unlabeled_dataset import UnlabeledDataset 13 | from constants import * 14 | from tqdm import tqdm 15 | import pickle 16 | 17 | def collate_fn_no_labels(sample_list): 18 | """Custom collate function to pad reports in each batch to the max len, 19 | where the reports have no associated labels 20 | @param sample_list (List): A list of samples. Each sample is a dictionary with 21 | keys 'imp', 'len' as returned by the __getitem__ 22 | function of UnlabeledDataset 23 | 24 | @returns batch (dictionary): A dictionary with keys 'imp' and 'len' but now 25 | 'imp' is a tensor with padding and batch size as the 26 | first dimension. 'len' is a list of the length of 27 | each sequence in batch 28 | """ 29 | tensor_list = [s['imp'] for s in sample_list] 30 | batched_imp = torch.nn.utils.rnn.pad_sequence(tensor_list, 31 | batch_first=True, 32 | padding_value=PAD_IDX) 33 | len_list = [s['len'] for s in sample_list] 34 | batch = {'imp': batched_imp, 'len': len_list} 35 | return batch 36 | 37 | def load_unlabeled_data(csv_path, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, 38 | shuffle=False): 39 | """ Create UnlabeledDataset object for the input reports 40 | @param csv_path (string): path to csv file containing reports 41 | @param batch_size (int): the batch size. As per the BERT repository, the max batch size 42 | that can fit on a TITAN XP is 6 if the max sequence length 43 | is 512, which is our case.
We have 3 TITAN XPs 44 | @param num_workers (int): how many worker processes to use to load data 45 | @param shuffle (bool): whether to shuffle the data or not 46 | 47 | @returns loader (dataloader): dataloader object for the reports 48 | """ 49 | collate_fn = collate_fn_no_labels 50 | dset = UnlabeledDataset(csv_path) 51 | loader = torch.utils.data.DataLoader(dset, batch_size=batch_size, shuffle=shuffle, 52 | num_workers=num_workers, collate_fn=collate_fn) 53 | return loader 54 | 55 | def apply_logreg_mapping(df_probs, logreg_models_path): 56 | logreg_models = {} 57 | visualchexbert_dict = {} 58 | try: 59 | with open(logreg_models_path, "rb") as handle: 60 | logreg_models = pickle.load(handle) 61 | except Exception as e: 62 | print("Error loading path to logistic regression models. Please ensure that the pickle file is in the checkpoint folder.") 63 | print(f"Exception: {e}") 64 | for condition in CONDITIONS: 65 | clf = logreg_models[condition] 66 | y_pred = clf.predict(df_probs) 67 | visualchexbert_dict[condition] = y_pred 68 | df_visualchexbert = pd.DataFrame.from_dict(visualchexbert_dict) 69 | return df_visualchexbert 70 | 71 | def label_and_save_preds(checkpoint_folder, csv_path, out_path): 72 | """Labels a dataset of reports 73 | @param checkpoint_folder (string): path to folder containing the saved model checkpoint 74 | @param csv_path (string): location of csv with reports 75 | @param out_path (string): path to output directory 76 | 77 | @returns df_visualchexbert (pd.DataFrame): dataframe of binary labels for each of the 14 conditions, per report 78 | """ 79 | ld = load_unlabeled_data(csv_path) 80 | 81 | model = bert_labeler() 82 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 83 | checkpoint_path = f"{checkpoint_folder}/visualCheXbert.pth" 84 | if torch.cuda.device_count() > 0: #works even if only 1 GPU available 85 | print("Using", torch.cuda.device_count(), "GPUs!") 86 | model = nn.DataParallel(model) #to utilize multiple GPUs 87 | model = model.to(device) 88 | checkpoint = torch.load(checkpoint_path) 89 | model.load_state_dict(checkpoint['model_state_dict']) 90 | else: 91 | checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu')) 92 | new_state_dict = OrderedDict() 93 | for k, v in checkpoint['model_state_dict'].items(): 94 | name = k[7:] # remove `module.` 95 | new_state_dict[name] = v 96 | model.load_state_dict(new_state_dict) 97 | 98 | was_training = model.training 99 | model.eval() 100 | y_pred = [[] for _ in range(len(CONDITIONS))] 101 | 102 | print("\nBegin report impression labeling.
The progress bar counts the # of batches completed:") 103 | print("The batch size is %d" % BATCH_SIZE) 104 | with torch.no_grad(): 105 | for i, data in enumerate(tqdm(ld)): 106 | batch = data['imp'] #(batch_size, max_len) 107 | batch = batch.to(device) 108 | src_len = data['len'] 109 | batch_size = batch.shape[0] 110 | attn_mask = utils.generate_attention_masks(batch, src_len, device) 111 | 112 | out = model(batch, attn_mask) 113 | 114 | for j in range(len(out)): 115 | curr_y_pred = torch.sigmoid(out[j]) #shape is (batch_size) 116 | y_pred[j].append(curr_y_pred) 117 | 118 | for j in range(len(y_pred)): 119 | y_pred[j] = torch.cat(y_pred[j], dim=0) 120 | 121 | if was_training: 122 | model.train() 123 | 124 | y_pred = [t.tolist() for t in y_pred] 125 | y_pred = np.array(y_pred) 126 | y_pred = y_pred.T 127 | 128 | df = pd.DataFrame(y_pred, columns=CONDITIONS) 129 | 130 | # Apply mapping from probs to image labels 131 | logreg_models_path = f"{checkpoint_folder}/logreg_models.pickle" 132 | df_visualchexbert = apply_logreg_mapping(df, logreg_models_path) 133 | 134 | reports = pd.read_csv(csv_path)['Report Impression'].tolist() 135 | 136 | df_visualchexbert.insert(loc=0, column='Report Impression', value=reports) 137 | df_visualchexbert.to_csv(os.path.join(out_path, 'labeled_reports.csv'), index=False) 138 | 139 | return df_visualchexbert 140 | 141 | if __name__ == '__main__': 142 | parser = argparse.ArgumentParser(description='Label a csv file containing radiology reports') 143 | parser.add_argument('-d', '--data', type=str, nargs='?', required=True, 144 | help='path to csv containing reports. The reports should be \ 145 | under the \"Report Impression\" column') 146 | parser.add_argument('-o', '--output_dir', type=str, nargs='?', required=True, 147 | help='path to intended output folder') 148 | parser.add_argument('-c', '--checkpoint_folder', type=str, nargs='?', required=False, default="checkpoint", 149 | help='path to folder with trained model checkpoints') 150 | args = parser.parse_args() 151 | csv_path = args.data 152 | out_path = args.output_dir 153 | checkpoint_folder = args.checkpoint_folder 154 | 155 | label_and_save_preds(checkpoint_folder, csv_path, out_path) 156 | -------------------------------------------------------------------------------- /src/models/bert_labeler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import BertModel, AutoModel 4 | 5 | class bert_labeler(nn.Module): 6 | def __init__(self, p=0.1, clinical=False, freeze_embeddings=False, pretrain_path=None): 7 | """ Init the labeler module 8 | @param p (float): p to use for dropout in the linear heads, 0.1 by default is consistent with 9 | transformers.BertForSequenceClassification 10 | @param clinical (boolean): True if Bio_ClinicalBERT desired, False otherwise.
Ignored if 11 | pretrain_path is not None 12 | @param freeze_embeddings (boolean): true to freeze bert embeddings during training 13 | @param pretrain_path (string): path to load checkpoint from 14 | """ 15 | super(bert_labeler, self).__init__() 16 | 17 | if pretrain_path is not None: 18 | self.bert = BertModel.from_pretrained(pretrain_path) 19 | elif clinical: 20 | self.bert = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT") 21 | else: 22 | self.bert = BertModel.from_pretrained('bert-base-uncased') 23 | 24 | if freeze_embeddings: 25 | for param in self.bert.embeddings.parameters(): 26 | param.requires_grad = False 27 | 28 | self.dropout = nn.Dropout(p) 29 | #size of the output of transformer's last layer 30 | hidden_size = self.bert.pooler.dense.in_features 31 | #classes: present, absent 32 | self.linear_heads = nn.ModuleList([nn.Linear(hidden_size, 1, bias=True) for _ in range(14)]) 33 | 34 | def forward(self, source_padded, attention_mask): 35 | """ Forward pass of the labeler 36 | @param source_padded (torch.LongTensor): Tensor of word indices with padding, shape (batch_size, max_len) 37 | @param attention_mask (torch.Tensor): Mask to avoid attention on padding tokens, shape (batch_size, max_len) 38 | @returns out (List[torch.Tensor]): A list of size 14 containing tensors of shape (batch_size,) 39 | """ 40 | #shape (batch_size, max_len, hidden_size) 41 | final_hidden = self.bert(source_padded, attention_mask=attention_mask)[0] 42 | #shape (batch_size, hidden_size) 43 | cls_hidden = final_hidden[:, 0, :].squeeze(dim=1) 44 | cls_hidden = self.dropout(cls_hidden) 45 | out = [] 46 | for i in range(14): 47 | lin_out = self.linear_heads[i](cls_hidden) 48 | out.append(lin_out.squeeze(dim=1)) 49 | return out 50 | 51 | -------------------------------------------------------------------------------- /src/sample_reports/larger_sample_reports.csv: -------------------------------------------------------------------------------- 1 | Report Impression 2 | No evidence of acute disease. 3 | Slight pulmonary vascular congestion without pulmonary edema or focal consolidation. Stable mild cardiomegaly. 4 | No acute cardiopulmonary process. 5 | No definite mass identified. Bibasilar opacities are likely atelectasis 6 | "In comparison with the earlier study of this date, the nasogastric tube has been pushed forward so that the side port is below the level of the esophagogastric junction. Otherwise little change." 7 | Increased bibasilar opacities could be due to aspiration or atelectasis. 8 | "Probable, new right lower lobe pneumonia." 9 | "As compared to ___ radiograph, right lower lobe opacity has nearly resolved, but a left lower lobe opacity has progressed. In a patient with a neurological condition, recurrent aspiration and or developing aspiration pneumonia should be considered. Exam is otherwise remarkable for removal of nasogastric tube and development of distended loops of bowel in the imaged upper abdomen, incompletely evaluated on this portable chest radiograph." 10 | New heterogeneous opacification at the right lung base could be pneumonia or aspirated blood. Cardiomediastinal and hilar silhouettes and pleural surfaces are normal. Right PIC line ends in the low SVC. Nasogastric tube passes into the stomach and out of view. 11 | "No relevant change as compared to the previous exam. The patient is rotated to the right, which slightly increases the radiodensity on the right side. No new focal parenchymal opacities. No pleural effusions. No pulmonary edema.
Unchanged normal size of the cardiac silhouette." 12 | 1. Satisfactory position of endotracheal and enteric tubes. 2. The left internal jugular central venous line appears to be high in position and terminates in the internal jugular vein. 3. Small left pleural effusion is partially imaged. 13 | NG tube tip isin the stomach. Cardiac size is normal. Right PICC tip is in the lower SVC. Visualized lungs are clear. Of note the upper portion of the chest was not included on the field of view. 14 | "In comparison with the study of ___, the tip of the replaced nasogastric tube is in the mid portion of the stomach. However, the side port is at the level of the esophagogastric junction so that the tube should be pushed forward at least 5-7 cm for better positioning. The cardiac silhouette is within normal limits. Prominence of interstitial markings at the bases with Kerley lines are consistent with some elevation of pulmonary venous pressure. However, areas of possible coalescence at the bases could represent superimposed pneumonia in the appropriate clinical setting." 15 | No significant interval change. 16 | No acute intrathoracic process 17 | No acute cardiopulmonary process. 18 | No evidence of acute disease. 19 | "1. No acute fracture is seen, but chest radiographs are insensitive for rib fractures. If there is high clinical concern for rib fracture, dedicated rib films would be recommended with the sites of tenderness to palpitation marked. 2. On the lateral view, the decending aorta is more prominent than on prior exams, which could represent a small aneursym at that level. Recommend correlation with clinical history and physical findings to assess the likelihood that the patient's fall was caused by this possible anuerysm or an associated acute dissection. These findings were communicated to Dr. ___ at 8:01 a.m. on ___ by phone." 20 | No evidence of acute cardiopulmonary process. 21 | Bibasilar atelectasis and upper zone redistribution. No CHF or frank consolidation. No gross effusion. 22 | New left lower lobe opacity concerning for pneumonia or aspiration pneumonia. Improving right basal opacity. Bilateral pleural effusions. These findings were communicated via the radiology critical results dashboard at 12:57 p.m. 23 | "No definite acute cardiopulmonary process. Bibasilar atelectasis. Vague opacities projecting over the left upper lung medially, potentially costochondral calcifications; however, dedicated chest x-ray with two views is suggested when patient is amenable." 24 | Streaky bibasilar atelectasis. No focal consolidation to suggest pneumonia. 25 | No acute cardiopulmonary process. 26 | "No radiographic evidence of pneumonia or other acute cardiopulmonary abnormalities. Small nodular opacity described above is likely a spinous process tip in this projection, but cannot rule out a pulmonary nodule. Recommend oblique radiographs for further evaluation." 27 | Heart size and mediastinum are stable. Lungs are slightly hyperinflated but essentially clear. There is no appreciable pleural effusion or pneumothorax. Nodule projecting over the left apex appears to be within the posterior aspect of the third ribs as demonstrated on the ___ view obtained and ___ 28 | "1. No acute cardiopulmonary processes. 2. Thickening of the cortex and trabecula of the left humerus, suggestive of Paget's disease. Dedicated humeral radiographs may be obtained for further evaluation." 29 | No acute cardiopulmonary process. 
30 | "Bilateral small pleural effusions with lower lung atelectasis, less likely pneumonia." 31 | "In comparison with the study of ___, the right IJ catheter remains in place. There is suggestion of lucency and opacity in the right supraclavicular region, which may well relate to material outside the patient. There are slightly better lung volumes. Central catheter again extends to the cavoatrial junction or upper right atrium. Bibasilar opacification, more prominent on the left, is consistent with atelectatic changes and small pleural effusion. No definite vascular congestion." 32 | "Patient has been extubated, but lung volumes are unchanged and cardiomediastinal caliber is normal for the postoperative. . Previous pulmonary edema has resolved. No pneumothorax. Left pleural effusion is small." 33 | "Normal chest radiograph; specifically, no evidence of pneumonia." 34 | "PA and lateral chest compared to ___, read in conjunction with imaging of the chest on an abdomen CT performed earlier today and reported separately. Small bilateral pleural effusions are essentially unchanged. Lungs are well expanded and clear. Normal cardiomediastinal and hilar silhouettes. No free subdiaphragmatic gas. Healed right mid rib fracture." 35 | No acute cardiopulmonary abnormality. 36 | "As on the previous image, several millimetric metallic particles are seen. The biggest of these particles projects over the region of the right atrium on the frontal image. Another particles appears to be located in the lingular. 2 additional barely visible particles with diameters be low 1 mm project over the left lung bases and are not clearly visualized on the previous cross-table image. No other change, no pneumothorax." 37 | No acute intrathoracic process. No visualized food products. 38 | "As compared to the previous radiograph, the previously placed right pigtail catheter has been removed. There is a minimal decrease in extent of the known right pneumothorax. Right basilar atelectasis, however, persists. The pre-existing depression of the right hemidiaphragm is less severe than on the previous image. Unchanged appearance of the cardiac silhouette. Normal left hemithorax." 39 | "As compared to the previous image, the position of the right chest tube is constant. There is no noticeable decrease in severity of the known right pneumothorax, account for not by depression of the right hemidiaphragm. The cardiac silhouette and the left lung is unchanged." 40 | "As compared to the previous image, there is again an increase in extent and severity of the known right pneumothorax. The pneumothorax is now severe, with an average with of 3 cm throughout the right hemi thorax. There also is ongoing depression of the right hemidiaphragm. New is a larger right-sided pleural fluid level. Placement of a new chest tube should again be considered. Normal size of the heart. Unremarkable appearance of the left lung." 41 | "Significant worsening opacification of the entire right hemithorax, suggesting re-expansion pulmonary edema in the setting of recent thoracentesis." 42 | "As compared to the previous radiograph, the known right pneumothorax with depression of the right hemidiaphragm has not substantially changed. The relatively substantial pneumothorax has a diameter of close to 4 cm at the lung apex. The basal right lung is atelectatic. Unchanged appearance of the left lung. The right pigtail catheter in the pleural space is in unchanged position." 
43 | Previous large right pleural effusion has been replaced by a large pneumothorax and small pleural effusion. Interval shift of the mediastinum to the left suggests accumulation of air in the right pleural space under pressure. Heart size is normal. Left lung is clear. 44 | "As compared to the previous image, there is virtually no change in appearance of the large fluidopneumothorax on the right, with mild depression of the right hemidiaphragm." 45 | The the VP shunt is projecting over the right hemi thorax. Elevated right hemidiaphragm is present most likely increased in part due to subpulmonic effusion and atelectasis. Substantial portion of the right lower lung is still collapsed and substantial amount of pleural effusion is still present. No definitive pneumothorax is seen. 46 | "In comparison with the study of ___, there is been a thoracentesis on the right with removal of a substantial amount of pleural fluid. No evidence of pneumothorax. A moderate right effusion persists with areas of underlying atelectasis and re-expansion pulmonary edema. The left lung is essentially clear." 47 | "No focal consolidation, pneumothorax, or pleural effusion." 48 | "Possible tiny, left apical pneumothorax. Bibasilar atelectasis, moderate left pleural effusion, and small right pleural effusion are essentially unchanged." 49 | The previously placed left chest tube was removed. There is no radiographic evidence for the presence of a pneumothorax. Minimal atelectasis at the left lung basis. No larger pleural effusions. No pneumonia. The patient has also been extubated and the nasogastric tube was removed. 50 | No acute cardiopulmonary process. 51 | No evidence of acute disease. 52 | -------------------------------------------------------------------------------- /src/sample_reports/sample_reports.csv: -------------------------------------------------------------------------------- 1 | Report Impression 2 | Heart size normal and lungs are clear. No edema or pneumonia. No effusion 3 | "1. Left pleural effusion with adjacent atelectasis. Right effusion is also present. 4 | 5 | 2. Cardiomegaly without overt edema." 6 | "Minimal patchy airspace disease within the lingula, may reflect atelectasis or consolidation." 7 | 1. Stable mild cardiomegaly. 2. Hyperexpanded but clear lungs. 
-------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import torch.nn as nn 4 | import pandas as pd 5 | import numpy as np 6 | import json 7 | from models.bert_labeler import bert_labeler 8 | from bert_tokenizer import tokenize 9 | from transformers import BertTokenizer 10 | from sklearn.metrics import roc_auc_score 11 | from constants import * 12 | 13 | def generate_attention_masks(batch, source_lengths, device): 14 | """Generate masks for padded batches to avoid self-attention over pad tokens 15 | @param batch (Tensor): tensor of token indices of shape (batch_size, max_len) 16 | where max_len is length of longest sequence in the batch 17 | @param source_lengths (List[Int]): List of actual lengths for each of the 18 | sequences in the batch 19 | @param device (torch.device): device on which the data should be placed 20 | 21 | @returns masks (Tensor): Tensor of masks of shape (batch_size, max_len) 22 | """ 23 | masks = torch.ones(batch.size(0), batch.size(1), dtype=torch.float) 24 | for idx, src_len in enumerate(source_lengths): 25 | masks[idx, src_len:] = 0 26 | return masks.to(device) 27 | 28 | def compute_auc(y_true, y_pred): 29 | """Compute the positive auc score 30 | @param y_true (list): List of 14 tensors each of shape (dev_set_size) 31 | @param y_pred (list): Same as y_true but for model predictions 32 | 33 | @returns res (list): List of 14 scalars 34 | """ 35 | res = [] 36 | for j in range(len(y_true)): 37 | if len(set(y_true[j].tolist())) == 1: #only one class present 38 | res.append(1.0) 39 | else: 40 | res.append(roc_auc_score(y_true[j], y_pred[j])) 41 | return res 42 | 43 | def evaluate(model, dev_loader, device, return_pred=False): 44 | """ Function to evaluate the current model weights 45 | @param model (nn.Module): the labeler module 46 | @param dev_loader (torch.utils.data.DataLoader): dataloader for dev set 47 | @param device (torch.device): device on which the data should be placed 48 | @param return_pred (bool): whether to return predictions or not 49 | 50 | @returns res_dict (dictionary): dictionary with key 'auc' and with value 51 | being a list of length 14 with each element in the 52 | list as a scalar.
If return_pred is true then a 53 | tuple is returned with the aforementioned dictionary 54 | as the first item, a list of predictions as the 55 | second item, and a list of ground truth as the 56 | third item 57 | """ 58 | 59 | was_training = model.training 60 | model.eval() 61 | y_pred = [[] for _ in range(len(CONDITIONS))] 62 | y_true = [[] for _ in range(len(CONDITIONS))] 63 | 64 | with torch.no_grad(): 65 | for i, data in enumerate(dev_loader, 0): 66 | batch = data['imp'] #(batch_size, max_len) 67 | batch = batch.to(device) 68 | label = data['label'] #(batch_size, 14) 69 | label = label.permute(1, 0).to(device) 70 | src_len = data['len'] 71 | batch_size = batch.shape[0] 72 | attn_mask = generate_attention_masks(batch, src_len, device) 73 | 74 | out = model(batch, attn_mask) 75 | 76 | for j in range(len(out)): 77 | out[j] = out[j].to('cpu') #move to cpu for sklearn 78 | curr_y_pred = torch.sigmoid(out[j]) #shape is (batch_size) 79 | y_pred[j].append(curr_y_pred) 80 | y_true[j].append(label[j].to('cpu')) 81 | 82 | if (i+1) % 200 == 0: 83 | print('Evaluation batch no: ', i+1) 84 | 85 | for j in range(len(y_true)): 86 | y_true[j] = torch.cat(y_true[j], dim=0) 87 | y_pred[j] = torch.cat(y_pred[j], dim=0) 88 | 89 | if was_training: 90 | model.train() 91 | 92 | auc = compute_auc(copy.deepcopy(y_true), copy.deepcopy(y_pred)) 93 | res_dict = {'auc': auc} 94 | 95 | if return_pred: 96 | return res_dict, y_pred, y_true 97 | else: 98 | return res_dict 99 | 100 | def test(model, checkpoint_path, test_ld): 101 | """Evaluate model on test set. 102 | @param model (nn.Module): labeler module 103 | @param checkpoint_path (string): location of saved model checkpoint 104 | @param test_ld (dataloader): dataloader for test set 105 | """ 106 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 107 | if torch.cuda.device_count() > 1: 108 | print("Using", torch.cuda.device_count(), "GPUs!") 109 | model = nn.DataParallel(model) #to utilize multiple GPUs 110 | model = model.to(device) 111 | 112 | checkpoint = torch.load(checkpoint_path) 113 | model.load_state_dict(checkpoint['model_state_dict']) 114 | 115 | print("Doing evaluation on test set\n") 116 | metrics = evaluate(model, test_ld, device) 117 | auc = metrics['auc'] 118 | 119 | print() 120 | metric_avg = [] 121 | for j in range(len(CONDITIONS)): 122 | print('%s auc: %.3f' % (CONDITIONS[j], auc[j])) 123 | if CONDITIONS[j] in EVAL_CONDITIONS: 124 | metric_avg.append(auc[j]) 125 | print('average auc: %.3f' % (np.mean(metric_avg))) 126 | 127 | 128 | def label_report_list(checkpoint_path, report_list): 129 | """ Evaluate model on list of reports.
130 | @param checkpoint_path (string): location of saved model checkpoint 131 | @param report_list (list): list of report impressions (string) 132 | """ 133 | imp = pd.Series(report_list) 134 | imp = imp.str.strip() 135 | imp = imp.replace('\n',' ', regex=True) 136 | imp = imp.replace('[0-9]\.', '', regex=True) 137 | imp = imp.replace('\s+', ' ', regex=True) 138 | imp = imp.str.strip() 139 | 140 | model = bert_labeler() 141 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 142 | if torch.cuda.device_count() > 1: 143 | print("Using", torch.cuda.device_count(), "GPUs!") 144 | model = nn.DataParallel(model) #to utilize multiple GPUs 145 | model = model.to(device) 146 | checkpoint = torch.load(checkpoint_path) 147 | model.load_state_dict(checkpoint['model_state_dict']) 148 | model.eval() 149 | 150 | y_pred = [] 151 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 152 | new_imps = tokenize(imp, tokenizer) 153 | with torch.no_grad(): 154 | for imp in new_imps: 155 | # run forward prop 156 | imp = torch.LongTensor(imp) 157 | source = imp.view(1, len(imp)) 158 | 159 | attention = torch.ones(len(imp)) 160 | attention = attention.view(1, len(imp)) 161 | out = model(source.to(device), attention.to(device)) 162 | 163 | # get predictions 164 | result = {} 165 | for j in range(len(out)): 166 | curr_y_pred = torch.sigmoid(out[j]) #probability that condition is present, shape is (1) 167 | result[CONDITIONS[j]] = CLASS_MAPPING[1 if curr_y_pred.item() > 0.5 else 2] #threshold at 0.5; 1 maps to "Positive", 2 to "Negative" 168 | y_pred.append(result) 169 | return y_pred 170 | 171 | --------------------------------------------------------------------------------
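As a usage sketch (not part of the repository), label_report_list from src/utils.py can be driven roughly as follows. Run it from the src directory so the imports resolve; the checkpoint path is a placeholder for the downloaded visualCheXbert.pth file, and note that this helper thresholds sigmoid probabilities directly rather than applying the logistic regression mapping used by label.py:

```
# Hypothetical driver script, assuming the checkpoint has been downloaded and unzipped.
from utils import label_report_list

reports = [
    "Heart size normal and lungs are clear. No edema or pneumonia. No effusion",
    "1. Stable mild cardiomegaly. 2. Hyperexpanded but clear lungs.",
]

# Placeholder path; point it at the unzipped checkpoint file.
preds = label_report_list("checkpoint/visualCheXbert.pth", reports)

# Each element of preds is a dict mapping a condition name to "Positive" or "Negative".
for report, labels in zip(reports, preds):
    print(labels["Cardiomegaly"], "|", report)
```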