├── .gitignore
├── LICENSE
├── README.md
├── checkpoints
│   ├── .gitignore
│   ├── README
│   └── download.sh
├── cider-df.py
├── classification_metric_analysis.py
├── clinicgen
│   ├── data
│   │   ├── areport.py
│   │   ├── chexpert.py
│   │   ├── flickr30k.py
│   │   ├── image2text.py
│   │   ├── mednli.py
│   │   ├── mimiccxr.py
│   │   ├── mimiccxr_custom.py
│   │   ├── openi.py
│   │   └── utils.py
│   ├── eval.py
│   ├── external
│   │   ├── LICENSE_bleu-cider-rouge-spice
│   │   ├── bleu
│   │   │   ├── __init__.py
│   │   │   ├── bleu.py
│   │   │   └── bleu_scorer.py
│   │   ├── cider
│   │   │   ├── __init__.py
│   │   │   ├── cider.py
│   │   │   └── cider_scorer.py
│   │   ├── rouge
│   │   │   ├── __init__.py
│   │   │   └── rouge.py
│   │   └── spice
│   │       ├── .gitignore
│   │       ├── __init__.py
│   │       ├── lib
│   │       │   ├── Meteor-1.5.jar
│   │       │   ├── SceneGraphParser-1.0.jar
│   │       │   ├── ejml-0.23.jar
│   │       │   ├── fst-2.47.jar
│   │       │   ├── guava-19.0.jar
│   │       │   ├── hamcrest-core-1.3.jar
│   │       │   ├── jackson-core-2.5.3.jar
│   │       │   ├── javassist-3.19.0-GA.jar
│   │       │   ├── json-simple-1.1.1.jar
│   │       │   ├── junit-4.12.jar
│   │       │   ├── lmdbjni-0.4.6.jar
│   │       │   ├── lmdbjni-linux64-0.4.6.jar
│   │       │   ├── lmdbjni-osx64-0.4.6.jar
│   │       │   ├── lmdbjni-win64-0.4.6.jar
│   │       │   ├── objenesis-2.4.jar
│   │       │   ├── slf4j-api-1.7.12.jar
│   │       │   └── slf4j-simple-1.7.21.jar
│   │       ├── spice-1.0.jar
│   │       └── spice.py
│   ├── log.py
│   ├── models
│   │   ├── bertnli.py
│   │   ├── cnnrnnrnn.py
│   │   ├── image.py
│   │   ├── image2text.py
│   │   ├── image_orig.py
│   │   ├── kwl.py
│   │   ├── m2transformer.py
│   │   ├── sat.py
│   │   ├── tienet.py
│   │   ├── transformer.py
│   │   └── utils.py
│   ├── nli.py
│   ├── optmizer.py
│   ├── text
│   │   ├── parser.py
│   │   ├── sentsplit.py
│   │   ├── textfilter.py
│   │   ├── tokenfilter.py
│   │   └── tokenizer.py
│   └── utils.py
├── convert_generated.py
├── create_sections.py
├── custom_models.py
├── environment.yml
├── eval_prf.py
├── extract_reports.py
├── infer.py
├── libs.yml
├── make_radnli-pseudo-train.py
├── metric_analysis.py
├── ner_reports.py
├── resize_mimic-cxr-jpg.py
├── resources
│   ├── .gitignore
│   ├── download.sh
│   └── radnli_pseudo-train_indexes.jsonl
├── section_parser.py
├── setup.py
├── temp.ipynb
├── tests
│   ├── test_eval.py
│   ├── test_nli.py
│   └── text
│       ├── test_sentsplit.py
│       ├── test_textfilter.py
│       └── test_tokenizer.py
├── train.py
└── train_image.py

/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/.gitignore
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# On the Importance of Image Encoding in Automated Chest X-Ray Report Generation
### BMVC, 2022

The reference code of [On the Importance of Image Encoding in Automated Chest X-Ray Report Generation](https://arxiv.org/abs/2211.13465).

The original code was taken from [Improving Factual Completeness and Consistency of Image-to-text Radiology Report Generation](https://github.com/ysmiura/ifcc).

Changes to the original code:
- [custom_models.py](custom_models.py) - contains the various image encoders
- [image.py](clinicgen/models/image.py) - a modified version of the original file that accommodates the various encoders

### Running the code
The original [repository](https://github.com/ysmiura/ifcc) has detailed instructions for running the code, and we recommend following them.
To choose an encoder, select the corresponding model within the [custom_models.py](custom_models.py) file.
New encoders can also be added easily, as sketched below.
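A minimal sketch of what such an encoder can look like (the class name and interface here are illustrative assumptions, not the exact contents of [custom_models.py](custom_models.py); any module that maps an image batch to feature maps can be plugged in the same way):

```python
import torch.nn as nn
import torchvision.models as models


class DenseNet121Encoder(nn.Module):
    """Example encoder: a torchvision backbone with its classifier removed."""

    def __init__(self, pretrained=True):
        super().__init__()
        backbone = models.densenet121(pretrained=pretrained)
        self.features = backbone.features  # keep only the convolutional trunk

    def forward(self, x):
        # (N, 3, H, W) -> (N, 1024, H', W') feature maps for the report decoder
        return self.features(x)
```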
## License
See LICENSE and clinicgen/external/LICENSE_bleu-cider-rouge-spice for details.
--------------------------------------------------------------------------------
/checkpoints/.gitignore:
--------------------------------------------------------------------------------
*.dict.gz
--------------------------------------------------------------------------------
/checkpoints/README:
--------------------------------------------------------------------------------
[checkpoint_nll-bs.dict.gz] An M2 Transformer checkpoint trained with NLL+BERTScore
[checkpoint_nll-bs-emexact.dict.gz] An M2 Transformer checkpoint trained with NLL+BERTScore+EntityMatchExact
[checkpoint_nll-bs-emnli.dict.gz] An M2 Transformer checkpoint trained with NLL+BERTScore+EntityMatchNLI
--------------------------------------------------------------------------------
/checkpoints/download.sh:
--------------------------------------------------------------------------------
#!/bin/sh
wget https://nlp.stanford.edu/ysmiura/ifcc/checkpoints.tar
tar xvf checkpoints.tar
rm checkpoints.tar
--------------------------------------------------------------------------------
/cider-df.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import csv
import gzip
import os
import pickle
from clinicgen.data.areport import AReportData
from clinicgen.data.mimiccxr import MIMICCXRData
from clinicgen.eval import GenEval
from clinicgen.text.textfilter import get_textfilter
from clinicgen.text.tokenfilter import get_tokenfilter
from clinicgen.text.tokenizer import get_tokenizer


def main(args):
    tokenizer = get_tokenizer(args.tokenizer)
    textfilter = get_textfilter(args.textfilter)
    tokenfilter = get_tokenfilter(args.tokenfilter)

    texts = []
    if args.corpus == 'mimic-cxr':
        path = os.path.join(args.data, 'mimic-cxr-resized', '2.0.0', MIMICCXRData.SPLITS_PATH)
        train_ids = {}
        with gzip.open(path, 'rt', encoding='utf-8') as f:
            header = f.readline()
            reader = csv.reader(f)
            for row in reader:
                if row[3] == 'train':
                    train_ids['s' + row[1]] = True
        path = os.path.join(args.data, 'mimic-cxr-resized', '2.0.0', MIMICCXRData.SECTIONED_PATH)
        with gzip.open(path, 'rt', encoding='utf-8') as f:
            header = f.readline()
            reader = csv.reader(f)
            for row in reader:
                if row[0] in train_ids:
                    if args.section == 'impression':
                        text = row[1]
                    else:
                        text = row[2]
                    if len(text) > 0:
                        texts.append(text)
    elif args.corpus == 'a':
        dump_dir = os.path.join(args.cache, 'a', 'train')
        dataset = AReportData(args.data, section=args.section, anatomy=args.anatomy, exclude_ids=args.exclude_ids,
                              meta=args.meta, split='train', cache_image=True, cache_text=True, dump_dir=dump_dir,
                              multi_image=2)
        for target in dataset.targets:
            target = gzip.decompress(target).decode('utf-8')
            target = dataset.extract_section(target)
            if len(target) > 0:
                texts.append(target)
    else:
        raise ValueError('Unknown corpus {0}'.format(args.corpus))
    print('{0} texts'.format(len(texts)))

    ftexts = []
    for text in texts:
        toks = tokenizer.tokenize(textfilter.filter(text))
        toks = tokenfilter.filter(toks)
        ftext = ' '.join(toks)
        ftexts.append(ftext)

    # Compute document frequencies on the filtered, tokenized texts.
    df = GenEval.compute_cider_df(ftexts)
    with gzip.open(args.output, 'w') as f:
        pickle.dump(df, f)


def parse_args():
    parser = argparse.ArgumentParser(description='Compute CIDEr DFs')
    parser.add_argument('data', type=str, help='A path to clinical data')
    parser.add_argument('output', type=str, help='An output path')
    parser.add_argument('--anatomy', type=str, default=None, help='An anatomy')
    parser.add_argument('--cache', type=str, default=None, help='A cache path')
    parser.add_argument('--corpus', type=str, default='mimic-cxr', help='Corpus name')
    parser.add_argument('--exclude-ids', type=str, default=None, help='Exclude IDs')
    parser.add_argument('--meta', type=str, default=None, help='A meta data path')
    parser.add_argument('--section', type=str, default='findings', help='Target section')
    parser.add_argument('--textfilter', type=str, default='lower', help='Text filter')
    parser.add_argument('--tokenfilter', type=str, default='none', help='Token filter')
    parser.add_argument('--tokenizer', type=str, default='nltk', choices=['nltk', 'none', 'stanford', 'whitespace'],
                        help='Tokenizer name')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    main(args)
--------------------------------------------------------------------------------
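The script above writes the document frequencies as a gzipped pickle, so a run such as `python cider-df.py <data_root> <output>` can be sanity-checked as follows (the file name here is a made-up example, not a path used by the repository):

import gzip
import pickle

# Inspect a DF file produced by cider-df.py; the path is hypothetical.
with gzip.open('mimic-cxr_train-df.pkl.gz') as f:
    df = pickle.load(f)
print('{0} n-grams with document frequencies'.format(len(df)))
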
/classification_metric_analysis.py:
--------------------------------------------------------------------------------
from transformers import pipeline
from clinicgen.nli import BERTScorer
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch

# Questions for the QA model
QUESTIONS = [
    'Is there monia?',
    'Is there edema?',
    'Is there thorax?',
    'Are there devices?',
    'Is there opacity?',
    'Is there atelectasis?',
    'Is there heart?',
    'Is there lung lesion?',
    'Is there consolidation?',
    'Is there fracture?',
]

# Model for BERTScore
bert_model = 'distilbert-base-uncased'

# Read the original and the generated reports
reports_df = pd.read_csv('/home/otabek.nazarov/Downloads/thesis/ifcc/labeled_reports_test.csv')
genered_df = pd.read_csv('/home/otabek.nazarov/Downloads/thesis/ifcc/trans_baseline.csv')

# Batch size configuration for the models (samples_cnt must be divisible by batch_size)
samples_cnt = 3800
batch_size = 50  # alternatives: 47 or 41
batch_count = int(samples_cnt / batch_size)

# Load QA model
device_id = 2  # -1 for CPU
qa_model = pipeline("question-answering",
                    model='franklu/pubmed_bert_squadv2',
                    framework='pt',
                    device=device_id)
qa_model.model.to(torch.device('cuda:{0}'.format(device_id)))

QA_THRESHOLD = 0.30

# Load BERTScore model
bert_score_qa_model = BERTScorer(model_type=bert_model, batch_size=batch_size,
                                 nthreads=2, lang='en', rescale_with_baseline=True,
                                 penalty=False)

# Turn into batches for fast processing
orig_reports = np.reshape(reports_df['Report Impression'].values[:samples_cnt], (batch_count, batch_size))
mask_reports = np.reshape(genered_df['Report Impression'].values[:samples_cnt], (batch_count, batch_size))

qa_bert_scores = []
full_bert_scores = []
for idx in tqdm(range(batch_count)):

    refs_l = orig_reports[idx, :].tolist()
    hypos_l = mask_reports[idx, :].tolist()

    f1_scores = np.empty((len(refs_l), len(QUESTIONS)))
    f1_scores.fill(np.nan)

    full_f1_scores = np.empty((len(refs_l), len(QUESTIONS)))
    full_f1_scores.fill(np.nan)

    # The full-report BERTScore does not depend on the question, so compute it once per batch.
    _, _, full_f1 = bert_score_qa_model.score(hypos_l, refs_l)
    full_f1 = full_f1.numpy()

    for q_idx, cur_question in enumerate(QUESTIONS):
        # Copy questions for batch forwarding to the model
        question_batch = [cur_question] * len(hypos_l)

        # Get results from QA model
        refs_cur_results = qa_model(question=question_batch, context=refs_l)
        hypo_cur_results = qa_model(question=question_batch, context=hypos_l)

        # Get BERT scores for the extracted answers
        bert_score_refs = []
        bert_score_hypo = []
        for cur_ref_res, cur_hypo_res in zip(refs_cur_results, hypo_cur_results):
            bert_score_refs.append(cur_ref_res['answer'])
            bert_score_hypo.append(cur_hypo_res['answer'])

        _, _, b_f1 = bert_score_qa_model.score(bert_score_hypo, bert_score_refs)
        b_f1 = b_f1.numpy()

        # Keep scores only where the QA model is confident enough on either report
        for sample_idx, (cur_ref_res, cur_hypo_res) in enumerate(zip(refs_cur_results, hypo_cur_results)):
            if cur_ref_res['score'] > QA_THRESHOLD or cur_hypo_res['score'] > QA_THRESHOLD:
                f1_scores[sample_idx, q_idx] = b_f1[sample_idx]
                full_f1_scores[sample_idx, q_idx] = full_f1[sample_idx]

    qa_bert_scores.append(f1_scores)
    full_bert_scores.append(full_f1_scores)

# Save BERT scores
bert_scores_np = np.reshape(np.array(qa_bert_scores), (samples_cnt, len(QUESTIONS)))
save_df = pd.DataFrame(bert_scores_np, columns=QUESTIONS)
save_df.to_csv(f'qa_bert_scores_heart_{QA_THRESHOLD}.csv', index=False)

bert_scores_np = np.reshape(np.array(full_bert_scores), (samples_cnt, len(QUESTIONS)))
save_df = pd.DataFrame(bert_scores_np, columns=QUESTIONS)
save_df.to_csv(f'full_bert_scores_heart_{QA_THRESHOLD}.csv', index=False)
--------------------------------------------------------------------------------
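For reference, the scoring primitive used above can be exercised in isolation; this toy example reuses only the BERTScorer construction already shown in the script, and the report texts are made up:

from clinicgen.nli import BERTScorer

scorer = BERTScorer(model_type='distilbert-base-uncased', batch_size=2, nthreads=2,
                    lang='en', rescale_with_baseline=True, penalty=False)
refs = ['there is mild pulmonary edema .', 'no pneumothorax is seen .']
hypos = ['there is moderate pulmonary edema .', 'no pneumothorax is seen .']
_, _, f1 = scorer.score(hypos, refs)  # precision, recall and F1, as in the script
print(f1.numpy())
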
/clinicgen/data/chexpert.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import csv
import os
import time
import torch
from tqdm import tqdm
from clinicgen.data.image2text import _RadiologyReportData


class CheXpertData(_RadiologyReportData):
    IMAGE_NUM = {'train': 223414, 'valid': 234}
    LABEL_IGNORE = -100
    LABEL_NEGATIVE = 1
    NUM_CLASSES = 3
    NUM_LABELS = 14

    def __init__(self, root, split=None, cache_image=False, multi_image=1, img_mode='center', img_augment=False,
                 ignore_blank=False, dump_dir=None):
        super().__init__(root, section='findings', split=split, cache_image=cache_image, cache_text=True,
                         multi_image=multi_image, dump_dir=dump_dir)
        self.ignore_blank = ignore_blank
        pre_transform, self.transform = CheXpertData.get_transform(cache_image, img_mode, img_augment)
        self.target_transform = None

        if dump_dir is not None:
            t = time.time()
            if self.load():
                print('Loaded data dump from %s (%.2fs)' % (dump_dir, time.time() - t))
                self.pre_processes()
                return

        images = os.path.join(root, '{0}.csv'.format(split))
        with open(images, encoding='utf-8') as f:
            f.readline()
            reader = csv.reader(f)
            with tqdm(total=self.IMAGE_NUM[split]) as pbar:
                pbar.set_description('Data ({0})'.format(split))
                count = 0
                interval = 1000
                inc = 1 if split == 'train' else 0

                for entry in reader:
                    sub_paths = entry[0].split('/')
                    image_id = '/'.join(sub_paths[2:5])
                    doc_id = '/'.join(sub_paths[2:4])
                    self.ids.append(image_id)
                    self.doc_ids.append(doc_id)
                    # image
                    image = os.path.join(root, '/'.join(sub_paths[1:]))
                    if cache_image:
                        image = self.bytes_image(image, pre_transform)
                    # labels
                    labels = []
                    for label in entry[5:]:
                        if len(label) == 0:
                            labels.append(self.LABEL_IGNORE)
                        else:
                            labels.append(int(float(label)) + inc)
                    self.samples.append((image, labels))
                    self.targets.append(labels)
                    count += 1
                    if count >= interval:
                        pbar.update(count)
                        count = 0
                if count > 0:
                    pbar.update(count)

        if dump_dir is not None:
            self.dump()
        self.pre_processes()

    def __getitem__(self, index):
        rid, sample, target, vp = super().__getitem__(index)
        target = torch.tensor(target)
        return rid, sample, target, vp

    @classmethod
    def get_transform(cls, cache_image=False, mode='center', augment=False):
        return cls._transform(cache_image, 224, mode, augment)

    def convert_blank_labels(self, print_num=True):
        t = time.time()
        if print_num:
            print('Converting blank labels ... ', end='', flush=True)
        new_samples, new_targets = [], []
        for i in range(len(self.samples)):
            new_labels = []
            image, labels = self.samples[i]
            for label in labels:
                if label == self.LABEL_IGNORE:
                    new_labels.append(self.LABEL_NEGATIVE)
                else:
                    new_labels.append(label)
            new_samples.append((image, new_labels))
            new_targets.append(new_labels)
        self.samples = new_samples
        self.targets = new_targets
        if print_num:
            print('done (%.2fs)' % (time.time() - t), flush=True)

    def pre_processes(self):
        if not self.ignore_blank:
            self.convert_blank_labels()
        if self.multi_image > 1:
            self.convert_to_multi_images()
--------------------------------------------------------------------------------
/clinicgen/data/flickr30k.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import gzip
import json
import os
import time
from tqdm import tqdm
from clinicgen.data.image2text import _CaptioningData


class Flickr30kData(_CaptioningData):
    IMAGE_NUM = 158915
    DIR_IMAGES = 'flickr30k-images'
    FILE_CAPTIONS = os.path.join('flickr30k', 'results_20130124.token')

    def __init__(self, root, meta=None, split=None, target_transform=None, cache_image=False, img_mode='center',
                 img_augment=False, cache_text=False, dump_dir=None):
        if not cache_text:
            raise ValueError('Flickr30k data only supports cached texts')
        super().__init__(root, split=split, cache_image=cache_image, cache_text=cache_text, dump_dir=dump_dir)
        pre_transform, self.transform = Flickr30kData.get_transform(cache_image, img_mode, img_augment)
        self.target_transform = target_transform
        self.multi_instance = True

        if dump_dir is not None:
            t = time.time()
            if self.load():
                print('Loaded data dump from %s (%.2fs)' % (dump_dir, time.time() - t))
                self.pre_processes()
                return

        splits = None
        if meta is not None:
            splits = {}
            with open(meta, encoding='utf-8') as f:
                meta_data = json.load(f)
            for entry in meta_data['images']:
                splits[entry['filename']] = entry['split']

        captions = os.path.join(root, self.FILE_CAPTIONS)
        with open(captions, encoding='utf-8') as f:
            with tqdm(total=self.IMAGE_NUM) as pbar:
                pbar.set_description('Data ({0})'.format(split))
                count = 0
                interval = 1000
                prev_image, buffer = None, []

                for line in f:
                    entry = line.rstrip().split('\t')
                    image = entry[0].split('#')[0]

                    if split is None or (image in splits and splits[image] == split):
                        report = gzip.compress(entry[1].encode('utf-8'))
                        if prev_image is not None and image != prev_image:
                            count = self._append_image(prev_image, buffer, count, pre_transform)
                            buffer = []
                        prev_image = image
                        buffer.append((entry[0], report))
                    else:
                        count += 1
                    if count >= interval:
                        pbar.update(count)
                        count = 0

                if len(buffer) > 0:
                    count = self._append_image(prev_image, buffer, count, pre_transform)
                if count > 0:
                    pbar.update(count)

        if dump_dir is not None:
            self.dump()
        self.pre_processes()

    def _append_image(self, image_id, buffer, count, pre_transform):
        if len(buffer) > 0:
            image = os.path.join(self.root, self.DIR_IMAGES, image_id)
            if self.cache_image:
                image = self.bytes_image(image, pre_transform)
            if self.split == 'test' or self.split == 'val':
                buffer = [e[1] for e in buffer]
                self.ids.append(image_id)
                self.samples.append((image, buffer))
                self.targets.append(buffer)
                count += len(buffer)
            else:
                for entry_id, report in buffer:
                    self.ids.append(entry_id)
                    self.samples.append((image, report))
                    self.targets.append(report)
                    count += 1
        return count

    def pre_processes(self):
        self.pre_transform_texts(self.split)
--------------------------------------------------------------------------------
/clinicgen/data/mednli.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import csv
import json
import os
import torch.utils.data as data


class MedNLIData(data.Dataset):
    def __init__(self):
        self.ids = []
        self.samples = []

    def __getitem__(self, index):
        sentence1, sentence2, gold_label = self.samples[index]
        iid = self.ids[index]
        return iid, sentence1, sentence2, gold_label

    def __len__(self):
        return len(self.samples)

    def load(self, root, split=None):
        if split == 'validation':
            split = 'dev'
        path = os.path.join(root, 'mednli_bionlp19_shared_task_ground_truth.csv')
        if os.path.exists(path):
            form = 'jsonl-csv'
        else:
            path = os.path.join(root, 'mli_{0}_v1.jsonl'.format(split))
            if os.path.exists(path):
                form = 'jsonl'
            else:
                form = 'tsv'

        if form == 'tsv':
            if split == 'dev' or split == 'test':
                path = os.path.join(root, '{0}.tsv'.format(split))
                if not os.path.exists(path):
                    path = os.path.join(root, '{0}_fact240.tsv'.format(split))
            elif split == 'train':
                path = os.path.join(root, 'train.tsv')
                if not os.path.exists(path):
                    items = []
                    for item in os.listdir(root):
                        if not item.startswith('.') and item.endswith('.tsv'):
                            items.append(item)
                    assert len(items) == 1
                    path = os.path.join(root, items[0])
            else:
                raise ValueError('Unknown split {0}'.format(split))
            with open(path, encoding='utf-8') as f:
                for line in f:
                    entry = line.rstrip().split('\t')
                    self.ids.append(entry[0])
                    self.samples.append((entry[1], entry[2], entry[-1]))
        elif form == 'jsonl':
            path = os.path.join(root, 'mli_{0}_v1.jsonl'.format(split))
            with open(path, encoding='utf-8') as f:
                for line in f:
                    entry = json.loads(line)
                    self.ids.append(entry['pairID'])
                    self.samples.append((entry['sentence1'], entry['sentence2'], entry['gold_label']))
        elif form == 'jsonl-csv':
            path = os.path.join(root, 'mednli_bionlp19_shared_task_ground_truth.csv')
            labels = {}
            with open(path, encoding='utf-8') as f:
                f.readline()
                reader = csv.reader(f)
                for row in reader:
                    labels[row[0]] = row[1]
            path = os.path.join(root, 'mednli_bionlp19_shared_task.jsonl')
            with open(path, encoding='utf-8') as f:
                for line in f:
                    entry = json.loads(line)
                    pid = entry['pairID']
                    self.ids.append(pid)
                    self.samples.append((entry['sentence1'], entry['sentence2'], labels[pid]))
        else:
            raise ValueError('Unknown format {0}'.format(form))
--------------------------------------------------------------------------------
/clinicgen/data/mimiccxr.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys
sys.path.insert(1, '/home/otabek.nazarov/Downloads/thesis/ifcc')

import csv
import gzip
import os
import pickle
import time
import torch
from tqdm import tqdm
from clinicgen.data.image2text import _RadiologyReportData


class MIMICCXRData(_RadiologyReportData):
    IMAGE_NUM = 377110
    LABEL_CHEXPERT = 'chexpert'
    CHEXPERT_MAP = [13, 4, 1, 7, 6, 3, 2, 9, 0, 10, 8, 11, 5, 12]

    CHEXPERT_PATH = 'mimic-cxr-2.0.0-chexpert.csv.gz'
    META_PATH = 'mimic-cxr-2.0.0-metadata.csv.gz'
    SECTIONED_PATH = 'mimic_cxr_sectioned.csv.gz'
    SPLITS_PATH = 'mimic-cxr-2.0.0-split.csv.gz'

    def __init__(self, root, section='findings', split=None, target_transform=None, cache_image=False, cache_text=True,
                 multi_image=1, img_mode='center', img_augment=False, single_image_doc=False, dump_dir=None,
                 filter_reports=True):
        if not cache_text:
            raise ValueError('MIMIC-CXR data only supports cached texts')
        super().__init__(root, section, split, cache_image, cache_text, multi_image=multi_image,
                         single_image_doc=single_image_doc, dump_dir=dump_dir)
        pre_transform, self.transform = MIMICCXRData.get_transform(cache_image, img_mode, img_augment)
        self.target_transform = target_transform
        self.chexpert_labels_path = os.path.join(root, 'mimic-cxr-jpg', '2.0.0', self.CHEXPERT_PATH)

        self.view_positions = {}
        doc_image_map = {}
        with gzip.open(os.path.join(root, 'mimic-cxr-resized', '2.0.0', self.META_PATH), 'rt', encoding='utf-8') as f:
            f.readline()
            reader = csv.reader(f)
            for row in reader:
                self.view_positions[row[0]] = row[4]
                if row[2] in doc_image_map:
                    doc_image_map[row[2]].append(row[0])
                else:
                    doc_image_map[row[2]] = [row[0]]

        if dump_dir is not None:
            t = time.time()
            if self.load():
                print('Loaded data dump from %s (%.2fs)' % (dump_dir, time.time() - t))
                self.pre_processes(filter_reports)
                return

        with gzip.open(os.path.join(root, 'mimic-cxr-resized', '2.0.0', self.SECTIONED_PATH), 'rt',
                       encoding='utf-8') as f:
            header = f.readline().strip().split(',')
            sections = {}
            reader = csv.reader(f)
            for row in reader:
                report = {}
                for i, sec in enumerate(header):
                    report[sec] = row[i]
                sections[row[0]] = gzip.compress(pickle.dumps(report))

        with gzip.open(os.path.join(root, 'mimic-cxr-resized', '2.0.0', self.SPLITS_PATH), 'rt',
                       encoding='utf-8') as f:
            f.readline()
            reader = csv.reader(f)
            interval = 1000
            with tqdm(total=self.IMAGE_NUM) as pbar:
                pbar.set_description('Data ({0})'.format(split))
                count = 0
                for row in reader:
                    try:
                        if split is None or split == row[3]:
                            did = row[0]
                            sid = row[1]
                            pid = row[2]
                            self.ids.append(did)
                            self.doc_ids.append(sid)
                            # image
                            image = os.path.join(root, 'mimic-cxr-resized', '2.0.0', 'files', 'p{0}'.format(pid[:2]),
                                                 'p' + pid, 's' + sid, did + '.png')
                            if cache_image:
                                image = self.bytes_image(image, pre_transform)
                            # report
                            report = os.path.join(root, 'mimic-cxr', '2.0.0', 'files', 'p{0}'.format(pid[:2]),
                                                  'p' + pid, 's{0}.txt'.format(sid))
                            if cache_text:
                                sid = 's' + sid
                                report = sections[sid] if sid in sections else gzip.compress(pickle.dumps({}))
                            self.samples.append((image, report))
                            self.targets.append(report)
                            count += 1
                            if count >= interval:
                                pbar.update(count)
                                count = 0
                    except Exception:
                        # Skip malformed rows or missing files instead of aborting the whole load.
                        print(f'failed {row}')
                        continue
                if count > 0:
                    pbar.update(count)

        if dump_dir is not None:
            self.dump()
        self.pre_processes(filter_reports)

    def __getitem__(self, index):
        rid, sample, target, _ = super().__getitem__(index)
        did = self.doc_ids[index]
        # View position features
        if self.multi_image > 1:
            vp = [self.view_position_embedding(self.view_positions[iid]) for iid in self.image_ids[index]]
            vp = [p.unsqueeze(dim=0) for p in vp]
            if len(vp) > self.multi_image:
                vp = vp[:self.multi_image]
            elif len(vp) < self.multi_image:
                first_vp = vp[0]
                for _ in range(self.multi_image - len(vp)):
                    vp.append(first_vp.new_zeros(first_vp.size()))
            vp = torch.cat(vp, dim=0)
        else:
            vp = self.view_position_embedding(self.view_positions[rid])
        return did + '__' + rid, sample, target, vp

    @classmethod
    def get_transform(cls, cache_image=False, mode='center', augment=False):
        return cls._transform(cache_image, 294, mode, augment)

    def compare_texts(self, text1, text2):
        if 'study' in text1 and 'study' in text2:
            return text1['study'] == text2['study']
        else:
            return True

    def decompress_text(self, text):
        return pickle.loads(gzip.decompress(text))

    def extract_section(self, text):
        if self.section in text:
            return text[self.section].replace('\n', ' ')
        else:
            return ''

    def pre_processes(self, filter_reports):
        if filter_reports:
            self.filter_empty_reports()
        if self.multi_image > 1:
            self.convert_to_multi_images()
        elif self.single_image_doc:
            self.convert_to_single_image()
        self.pre_transform_texts(self.split)


if __name__ == "__main__":
    path = '/home/otabek.nazarov/Downloads/'
    datasets = MIMICCXRData(path, section='findings', split='test',
                            cache_image=False, cache_text=True, multi_image=1)
    print(datasets[0])
--------------------------------------------------------------------------------
/clinicgen/data/mimiccxr_custom.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys
sys.path.insert(1, '/home/otabek.nazarov/Downloads/thesis/ifcc')

import csv
import gzip
import os
import pickle
import time
import torch
from tqdm import tqdm
from clinicgen.data.image2text import _RadiologyReportData


class MIMICCXRData(_RadiologyReportData):
    IMAGE_NUM = 377110
    LABEL_CHEXPERT = 'chexpert'
    CHEXPERT_MAP = [13, 4, 1, 7, 6, 3, 2, 9, 0, 10, 8, 11, 5, 12]

    CHEXPERT_PATH = 'mimic-cxr-2.0.0-chexpert.csv.gz'
    META_PATH = 'mimic-cxr-2.0.0-metadata.csv.gz'
    SECTIONED_PATH = 'mimic_cxr_sectioned.csv.gz'
    SPLITS_PATH = 'mimic-cxr-2.0.0-split.csv.gz'

    def __init__(self, root, section='findings', split=None, target_transform=None, cache_image=False, cache_text=True,
                 multi_image=1, img_mode='center', img_augment=False, single_image_doc=False, dump_dir=None,
                 filter_reports=True):
        if not cache_text:
            raise ValueError('MIMIC-CXR data only supports cached texts')
        super().__init__(root, section, split, cache_image, cache_text, multi_image=multi_image,
                         single_image_doc=single_image_doc, dump_dir=dump_dir)
        pre_transform, self.transform = MIMICCXRData.get_transform(cache_image, img_mode, img_augment)
        self.target_transform = target_transform
        self.chexpert_labels_path = os.path.join(root, 'mimic-cxr-jpg', '2.0.0', self.CHEXPERT_PATH)

        self.view_positions = {}
        doc_image_map = {}
        with gzip.open(os.path.join(root, 'mimic-cxr-resized', '2.0.0', self.META_PATH), 'rt', encoding='utf-8') as f:
            f.readline()
            reader = csv.reader(f)
            for row in reader:
                self.view_positions[row[0]] = row[4]
                if row[2] in doc_image_map:
                    doc_image_map[row[2]].append(row[0])
                else:
                    doc_image_map[row[2]] = [row[0]]

        if dump_dir is not None:
            t = time.time()
            if self.load():
                print('Loaded data dump from %s (%.2fs)' % (dump_dir, time.time() - t))
                self.pre_processes(filter_reports)
                return

        with gzip.open(os.path.join(root, 'mimic-cxr-resized', '2.0.0', self.SECTIONED_PATH), 'rt',
                       encoding='utf-8') as f:
            header = f.readline().strip().split(',')
            sections = {}
            reader = csv.reader(f)
            for row in reader:
                report = {}
                for i, sec in enumerate(header):
                    report[sec] = row[i]
                sections[row[0]] = gzip.compress(pickle.dumps(report))

        with gzip.open(os.path.join(root, 'mimic-cxr-resized', '2.0.0', self.SPLITS_PATH), 'rt',
                       encoding='utf-8') as f:
            f.readline()
            reader = csv.reader(f)
            interval = 1000
            with tqdm(total=self.IMAGE_NUM) as pbar:
                pbar.set_description('Data ({0})'.format(split))
                count = 0
                for row in reader:
                    try:
                        if split is None or split == row[3]:
                            did = row[0]
                            sid = row[1]
                            pid = row[2]
                            self.ids.append(did)
                            self.doc_ids.append(sid)
                            # image
                            image = os.path.join(root, 'mimic-cxr-resized', '2.0.0', 'files', 'p{0}'.format(pid[:2]),
                                                 'p' + pid, 's' + sid, did + '.png')
                            if cache_image:
                                image = self.bytes_image(image, pre_transform)
                            # report
                            report = os.path.join(root, 'mimic-cxr', '2.0.0', 'files', 'p{0}'.format(pid[:2]),
                                                  'p' + pid, 's{0}.txt'.format(sid))
                            if cache_text:
                                sid = 's' + sid
                                report = sections[sid] if sid in sections else gzip.compress(pickle.dumps({}))
                            self.samples.append((image, report))
                            self.targets.append(report)
                            count += 1
                            if count >= interval:
                                pbar.update(count)
                                count = 0
                    except Exception:
                        # Skip malformed rows or missing files instead of aborting the whole load.
                        print(f'failed {row}')
                        continue
                if count > 0:
                    pbar.update(count)

        if dump_dir is not None:
            self.dump()
        self.pre_processes(filter_reports)

    def __getitem__(self, index):
        rid, sample, target, _ = super().__getitem__(index)
        did = self.doc_ids[index]
        # View position features
        if self.multi_image > 1:
            vp = [self.view_position_embedding(self.view_positions[iid]) for iid in self.image_ids[index]]
            vp = [p.unsqueeze(dim=0) for p in vp]
            if len(vp) > self.multi_image:
                vp = vp[:self.multi_image]
            elif len(vp) < self.multi_image:
                first_vp = vp[0]
                for _ in range(self.multi_image - len(vp)):
                    vp.append(first_vp.new_zeros(first_vp.size()))
            vp = torch.cat(vp, dim=0)
        else:
            vp = self.view_position_embedding(self.view_positions[rid])
        return did + '__' + rid, sample, target, vp

    @classmethod
    def get_transform(cls, cache_image=False, mode='center', augment=False):
        return cls._transform(cache_image, 294, mode, augment)

    def compare_texts(self, text1, text2):
        if 'study' in text1 and 'study' in text2:
            return text1['study'] == text2['study']
        else:
            return True

    def decompress_text(self, text):
        return pickle.loads(gzip.decompress(text))

    def extract_section(self, text):
        if self.section in text:
            return text[self.section].replace('\n', ' ')
        else:
            return ''

    def pre_processes(self, filter_reports):
        if filter_reports:
            self.filter_empty_reports()
        if self.multi_image > 1:
            self.convert_to_multi_images()
        elif self.single_image_doc:
            self.convert_to_single_image()
        self.pre_transform_texts(self.split)


if __name__ == "__main__":
    path = '/home/otabek.nazarov/Downloads/'
    datasets = MIMICCXRData(path, section='findings', split='test',
                            cache_image=False, cache_text=True, multi_image=1)
    print(datasets[0])
--------------------------------------------------------------------------------
/clinicgen/data/openi.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import csv
import os
import sys
import time
import xml.etree.ElementTree as etree
from tqdm import tqdm
from clinicgen.data.image2text import _RadiologyReportData


class OpenIData(_RadiologyReportData):
    CHEXPERT_PATH = 'open-i-chexpert.csv.gz'
    IMAGES_DIR = 'NLMCXR_png'
    REPORTS_DIR = 'ecgen-radiology'

    def __init__(self, root, section='findings', meta=None, split=None, target_transform=None, cache_image=False,
                 cache_text=False, multi_image=1, img_mode='center', img_augment=False, single_image_doc=False,
                 dump_dir=None):
        super().__init__(root, section, split, cache_image, cache_text, multi_image=multi_image,
                         single_image_doc=single_image_doc, dump_dir=dump_dir)
        pre_transform, self.transform = OpenIData.get_transform(cache_image, img_mode, img_augment)
        self.target_transform = target_transform
        if meta is not None:
            self.chexpert_labels_path = os.path.join(os.path.dirname(meta), self.CHEXPERT_PATH)
        if dump_dir is not None:
            t = time.time()
            if self.load():
                print('Loaded data dump from %s (%.2fs)' % (dump_dir, time.time() - t))
                self.pre_processes()
                return

        splits = {}
        if meta is not None:
            with open(meta, encoding='utf-8') as f:
                reader = csv.reader(f)
                for row in reader:
                    splits[row[0]] = row[1]

        total = 3999
        interval = 100
        with tqdm(total=total) as pbar:
            pbar.set_description('Data ({0})'.format(split))
            count = 0
            for i in range(1, total + 1):
                path = os.path.join(root, OpenIData.REPORTS_DIR, '{0}.xml'.format(i))
                if os.path.exists(path):
                    tree = etree.parse(path)
                    rt = tree.getroot()
                    uid = None
                    for ele in rt.findall(".//uId"):
                        uid = ele.attrib['id']
                    if split is None or (uid in splits and splits[uid] == split):
                        image_ids = []
                        for ele in rt.findall(".//parentImage"):
                            image_ids.append(ele.attrib['id'])
                        findings = ''
                        for ele in rt.findall(".//AbstractText[@Label='{0}']".format('FINDINGS')):
                            findings = ele.text
                        findings = findings.strip() if findings is not None else ''
                        if len(findings) > 0 and len(image_ids) > 0:
                            for image_id in image_ids:
                                image = os.path.join(root, OpenIData.IMAGES_DIR, '{0}.png'.format(image_id))
                                report = self.extract_text(path, compress=True) if cache_text else path
                                if cache_image:
                                    image = self.bytes_image(image, pre_transform)
                                self.ids.append(image_id)
                                self.doc_ids.append(uid)
                                self.samples.append((image, report))
                                self.targets.append(report)
                                count += 1
                                if count >= interval:
                                    pbar.update(count)
                                    count = 0
            if count > 0:
                pbar.update(count)
        if len(self.samples) == 0 and len(self.targets) == 0:
            sys.stderr.write("WARNING: Found 0 files in subfolders of: " + root + "\n")

        if dump_dir is not None:
            self.dump()
        self.pre_processes()

    def extract_section(self, text):
        rt = etree.fromstring(text)
        report = ''
        if self.section is None:
            for ele in rt.findall(".//AbstractText"):
                s = ele.text
                if s is not None:
                    report += s.strip() + '\n'
        else:
            for ele in rt.findall(".//AbstractText[@Label='{0}']".format(self.section.upper())):
                report = ele.text
            if report is None:
                report = ''
            else:
                report = report.strip()
        return report

    def pre_processes(self):
        if self.multi_image > 1:
            self.convert_to_multi_images()
        elif self.single_image_doc:
            self.convert_to_single_image()
        self.pre_transform_texts(self.split)
--------------------------------------------------------------------------------
/clinicgen/external/LICENSE_bleu-cider-rouge-spice:
--------------------------------------------------------------------------------
Copyright (c) 2015, Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

The views and conclusions contained in the software and documentation are those
of the authors and should not be interpreted as representing official policies,
either expressed or implied, of the FreeBSD Project.
--------------------------------------------------------------------------------
/clinicgen/external/bleu/__init__.py:
--------------------------------------------------------------------------------
__author__ = 'tylin'
--------------------------------------------------------------------------------
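A minimal usage sketch for the vendored BLEU wrapper defined below: `gts` and `res` map an id to a list of strings, each hypothesis list must have length 1, and the ids and sentences here are invented:

from clinicgen.external.bleu.bleu import Bleu

gts = {'r1': ['no acute cardiopulmonary process .']}
res = {'r1': ['no acute cardiopulmonary abnormality .']}
score, scores = Bleu(n=4).compute_score(gts, res, verbose=0)
print(score)  # corpus-level BLEU-1 to BLEU-4
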
https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/bleu/__pycache__/bleu_scorer.cpython-37.pyc
--------------------------------------------------------------------------------
/clinicgen/external/bleu/__pycache__/bleu_scorer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/bleu/__pycache__/bleu_scorer.cpython-38.pyc
--------------------------------------------------------------------------------
/clinicgen/external/bleu/bleu.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : bleu.py
4 | #
5 | # Description : Wrapper for BLEU scorer.
6 | #
7 | # Creation Date : 06-01-2015
8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT
9 | # Authors : Hao Fang and Tsung-Yi Lin
10 |
11 | from clinicgen.external.bleu.bleu_scorer import BleuScorer
12 |
13 |
14 | class Bleu:
15 | def __init__(self, n=4):
16 | # default compute BLEU score up to 4
17 | self._n = n
18 | self._hypo_for_image = {}
19 | self.ref_for_image = {}
20 |
21 | def compute_score(self, gts, res, verbose=1):
22 |
23 | assert(gts.keys() == res.keys())
24 | imgIds = gts.keys()
25 |
26 | bleu_scorer = BleuScorer(n=self._n)
27 | for id in imgIds:
28 | hypo = res[id]
29 | ref = gts[id]
30 |
31 | # Sanity check.
32 | assert(type(hypo) is list)
33 | assert(len(hypo) == 1)
34 | assert(type(ref) is list)
35 | assert(len(ref) >= 1)
36 |
37 | bleu_scorer += (hypo[0], ref)
38 |
39 | #score, scores = bleu_scorer.compute_score(option='shortest')
40 | score, scores = bleu_scorer.compute_score(option='closest', verbose=verbose)
41 | #score, scores = bleu_scorer.compute_score(option='average', verbose=verbose)
42 |
43 | # return (bleu, bleu_info)
44 | return score, scores
45 |
46 | def method(self):
47 | return "Bleu"
48 |
--------------------------------------------------------------------------------
/clinicgen/external/bleu/bleu_scorer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # bleu_scorer.py
4 | # David Chiang
5 |
6 | # Copyright (c) 2004-2006 University of Maryland. All rights
7 | # reserved. Do not redistribute without permission from the
8 | # author. Not for commercial use.
9 |
10 | # Modified by:
11 | # Hao Fang
12 | # Tsung-Yi Lin
13 |
14 | '''Provides:
15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
17 | '''
18 |
19 | import copy
20 | import sys, math, re
21 | from collections import defaultdict
22 |
23 | def precook(s, n=4, out=False):
24 | """Takes a string as input and returns an object that can be given to
25 | either cook_refs or cook_test.
This is optional: cook_refs and cook_test
26 | can take string arguments as well."""
27 | words = s.split()
28 | counts = defaultdict(int)
29 | for k in range(1,n+1):
30 | for i in range(len(words)-k+1):
31 | ngram = tuple(words[i:i+k])
32 | counts[ngram] += 1
33 | return (len(words), counts)
34 |
35 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average"
36 | '''Takes a list of reference sentences for a single segment
37 | and returns an object that encapsulates everything that BLEU
38 | needs to know about them.'''
39 |
40 | reflen = []
41 | maxcounts = {}
42 | for ref in refs:
43 | rl, counts = precook(ref, n)
44 | reflen.append(rl)
45 | for (ngram,count) in counts.items():
46 | maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
47 |
48 | # Calculate effective reference sentence length.
49 | if eff == "shortest":
50 | reflen = min(reflen)
51 | elif eff == "average":
52 | reflen = float(sum(reflen))/len(reflen)
53 |
54 | ## lhuang: N.B.: leave reflen computation to the very end!!
55 |
56 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design)
57 |
58 | return (reflen, maxcounts)
59 |
60 | def cook_test(test, ref, eff=None, n=4):
61 | '''Takes a test sentence and returns an object that
62 | encapsulates everything that BLEU needs to know about it.'''
63 |
64 | reflen, refmaxcounts = ref
65 | testlen, counts = precook(test, n, True)
66 |
67 | result = {}
68 |
69 | # Calculate effective reference sentence length.
70 |
71 | if eff == "closest":
72 | result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1]
73 | else: ## i.e., "average" or "shortest" or None
74 | result["reflen"] = reflen
75 |
76 | result["testlen"] = testlen
77 |
78 | result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)]
79 |
80 | result['correct'] = [0]*n
81 | for (ngram, count) in counts.items():
82 | result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count)
83 |
84 | return result
85 |
86 | class BleuScorer(object):
87 | """Bleu scorer.
88 | """
89 |
90 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"
91 | # special_reflen is used in oracle (proportional effective ref len for a node).
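# Internal state, as built by precook/cook_refs/cook_test above: each entry of
# self.crefs is a (reflen_list, max_ngram_counts) pair for one segment's
# references, and the matching entry of self.ctest is a dict holding 'testlen',
# 'reflen', 'guess' (candidate n-gram counts per order) and 'correct' (clipped
# n-gram matches per order), or None when no test sentence was supplied.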
92 | 93 | def copy(self): 94 | ''' copy the refs.''' 95 | new = BleuScorer(n=self.n) 96 | new.ctest = copy.copy(self.ctest) 97 | new.crefs = copy.copy(self.crefs) 98 | new._score = None 99 | return new 100 | 101 | def __init__(self, test=None, refs=None, n=4, special_reflen=None): 102 | ''' singular instance ''' 103 | 104 | self.n = n 105 | self.crefs = [] 106 | self.ctest = [] 107 | self.cook_append(test, refs) 108 | self.special_reflen = special_reflen 109 | 110 | def cook_append(self, test, refs): 111 | '''called by constructor and __iadd__ to avoid creating new instances.''' 112 | 113 | if refs is not None: 114 | self.crefs.append(cook_refs(refs)) 115 | if test is not None: 116 | cooked_test = cook_test(test, self.crefs[-1]) 117 | self.ctest.append(cooked_test) ## N.B.: -1 118 | else: 119 | self.ctest.append(None) # lens of crefs and ctest have to match 120 | 121 | self._score = None ## need to recompute 122 | 123 | def ratio(self, option=None): 124 | self.compute_score(option=option) 125 | return self._ratio 126 | 127 | def score_ratio(self, option=None): 128 | '''return (bleu, len_ratio) pair''' 129 | return (self.fscore(option=option), self.ratio(option=option)) 130 | 131 | def score_ratio_str(self, option=None): 132 | return "%.4f (%.2f)" % self.score_ratio(option) 133 | 134 | def reflen(self, option=None): 135 | self.compute_score(option=option) 136 | return self._reflen 137 | 138 | def testlen(self, option=None): 139 | self.compute_score(option=option) 140 | return self._testlen 141 | 142 | def retest(self, new_test): 143 | if type(new_test) is str: 144 | new_test = [new_test] 145 | assert len(new_test) == len(self.crefs), new_test 146 | self.ctest = [] 147 | for t, rs in zip(new_test, self.crefs): 148 | self.ctest.append(cook_test(t, rs)) 149 | self._score = None 150 | 151 | return self 152 | 153 | def rescore(self, new_test): 154 | ''' replace test(s) with new test(s), and returns the new score.''' 155 | 156 | return self.retest(new_test).compute_score() 157 | 158 | def size(self): 159 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 160 | return len(self.crefs) 161 | 162 | def __iadd__(self, other): 163 | '''add an instance (e.g., from another sentence).''' 164 | 165 | if type(other) is tuple: 166 | ## avoid creating new BleuScorer instances 167 | self.cook_append(other[0], other[1]) 168 | else: 169 | assert self.compatible(other), "incompatible BLEUs." 
170 | self.ctest.extend(other.ctest) 171 | self.crefs.extend(other.crefs) 172 | self._score = None ## need to recompute 173 | 174 | return self 175 | 176 | def compatible(self, other): 177 | return isinstance(other, BleuScorer) and self.n == other.n 178 | 179 | def single_reflen(self, option="average"): 180 | return self._single_reflen(self.crefs[0][0], option) 181 | 182 | def _single_reflen(self, reflens, option=None, testlen=None): 183 | 184 | if option == "shortest": 185 | reflen = min(reflens) 186 | elif option == "average": 187 | reflen = float(sum(reflens))/len(reflens) 188 | elif option == "closest": 189 | reflen = min((abs(l-testlen), l) for l in reflens)[1] 190 | else: 191 | assert False, "unsupported reflen option %s" % option 192 | 193 | return reflen 194 | 195 | def recompute_score(self, option=None, verbose=0): 196 | self._score = None 197 | return self.compute_score(option, verbose) 198 | 199 | def compute_score(self, option=None, verbose=0): 200 | n = self.n 201 | small = 1e-9 202 | tiny = 1e-15 ## so that if guess is 0 still return 0 203 | bleu_list = [[] for _ in range(n)] 204 | 205 | if self._score is not None: 206 | return self._score 207 | 208 | if option is None: 209 | option = "average" if len(self.crefs) == 1 else "closest" 210 | 211 | self._testlen = 0 212 | self._reflen = 0 213 | totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n} 214 | 215 | # for each sentence 216 | for comps in self.ctest: 217 | testlen = comps['testlen'] 218 | self._testlen += testlen 219 | 220 | if self.special_reflen is None: ## need computation 221 | reflen = self._single_reflen(comps['reflen'], option, testlen) 222 | else: 223 | reflen = self.special_reflen 224 | 225 | self._reflen += reflen 226 | 227 | for key in ['guess','correct']: 228 | for k in range(n): 229 | totalcomps[key][k] += comps[key][k] 230 | 231 | # append per image bleu score 232 | bleu = 1. 233 | for k in range(n): 234 | bleu *= (float(comps['correct'][k]) + tiny) \ 235 | /(float(comps['guess'][k]) + small) 236 | bleu_list[k].append(bleu ** (1./(k+1))) 237 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division 238 | if ratio < 1: 239 | for k in range(n): 240 | bleu_list[k][-1] *= math.exp(1 - 1/ratio) 241 | 242 | if verbose > 1: 243 | print(comps, reflen) 244 | 245 | totalcomps['reflen'] = self._reflen 246 | totalcomps['testlen'] = self._testlen 247 | 248 | bleus = [] 249 | bleu = 1. 
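# Corpus-level BLEU-k below: the geometric mean of the modified n-gram
# precisions up to order k, multiplied by the brevity penalty
# exp(1 - reflen/testlen) when the candidate is shorter than the effective
# reference length (ratio = testlen/reflen < 1).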
250 | for k in range(n): 251 | bleu *= float(totalcomps['correct'][k] + tiny) \ 252 | / (totalcomps['guess'][k] + small) 253 | bleus.append(bleu ** (1./(k+1))) 254 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division 255 | if ratio < 1: 256 | for k in range(n): 257 | bleus[k] *= math.exp(1 - 1/ratio) 258 | 259 | if verbose > 0: 260 | print(totalcomps) 261 | print("ratio:", ratio) 262 | 263 | self._score = bleus 264 | return self._score, bleu_list 265 | -------------------------------------------------------------------------------- /clinicgen/external/cider/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /clinicgen/external/cider/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/cider/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /clinicgen/external/cider/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/cider/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /clinicgen/external/cider/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/cider/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /clinicgen/external/cider/__pycache__/cider.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/cider/__pycache__/cider.cpython-36.pyc -------------------------------------------------------------------------------- /clinicgen/external/cider/__pycache__/cider.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/cider/__pycache__/cider.cpython-37.pyc -------------------------------------------------------------------------------- /clinicgen/external/cider/__pycache__/cider.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/cider/__pycache__/cider.cpython-38.pyc -------------------------------------------------------------------------------- /clinicgen/external/cider/__pycache__/cider_scorer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/cider/__pycache__/cider_scorer.cpython-36.pyc -------------------------------------------------------------------------------- 
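A minimal usage sketch for the Bleu wrapper defined in clinicgen/external/bleu/bleu.py above (the image ids and sentences are illustrative; inputs are assumed to be pre-tokenized, whitespace-separated strings, with exactly one hypothesis and at least one reference per id):

from clinicgen.external.bleu.bleu import Bleu

# Keys of gts and res must match; values are lists of tokenized strings.
gts = {'img1': ['the heart is normal in size', 'heart size is normal']}
res = {'img1': ['the heart is normal']}

scorer = Bleu(n=4)
score, scores = scorer.compute_score(gts, res, verbose=0)
# score is a list of corpus-level BLEU-1..BLEU-4 values; scores holds the
# corresponding per-image score lists for each n-gram order.
print(score)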
/clinicgen/external/cider/__pycache__/cider_scorer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/cider/__pycache__/cider_scorer.cpython-37.pyc
--------------------------------------------------------------------------------
/clinicgen/external/cider/__pycache__/cider_scorer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/cider/__pycache__/cider_scorer.cpython-38.pyc
--------------------------------------------------------------------------------
/clinicgen/external/cider/cider.py:
--------------------------------------------------------------------------------
1 | # Filename: cider.py
2 | #
3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric
4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
5 | #
6 | # Creation Date: Sun Feb 8 14:16:54 2015
7 | #
8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin
9 |
10 | from clinicgen.external.cider.cider_scorer import CiderScorer
11 | import pdb
12 |
13 | class Cider:
14 | """
15 | Main Class to compute the CIDEr metric
16 |
17 | """
18 | def __init__(self, test=None, refs=None, n=4, sigma=6.0, df=None):
19 | # set cider to sum over 1 to 4-grams
20 | self._n = n
21 | # set the standard deviation parameter for gaussian penalty
22 | self._sigma = sigma
23 | # Added for the use of external document frequencies
24 | self._df = df
25 |
26 | def compute_score(self, gts, res):
27 | """
28 | Main function to compute CIDEr score
29 | :param hypo_for_image (dict) : dictionary with key <image> and value <tokenized hypothesis / candidate sentence>
30 | ref_for_image (dict) : dictionary with key <image> and value <tokenized reference sentence>
31 | :return: cider (float) : computed CIDEr score for the corpus
32 | """
33 |
34 | assert(gts.keys() == res.keys())
35 | imgIds = gts.keys()
36 |
37 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma, df=self._df)
38 |
39 | for id in imgIds:
40 | hypo = res[id]
41 | ref = gts[id]
42 |
43 | # Sanity check.
44 | assert(type(hypo) is list)
45 | assert(len(hypo) == 1)
46 | assert(type(ref) is list)
47 | assert(len(ref) > 0)
48 |
49 | cider_scorer += (hypo[0], ref)
50 |
51 | (score, scores) = cider_scorer.compute_score()
52 |
53 | return score, scores
54 |
55 | def method(self):
56 | return "CIDEr"
--------------------------------------------------------------------------------
/clinicgen/external/cider/cider_scorer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Tsung-Yi Lin
3 | # Ramakrishna Vedantam
4 |
5 | import copy
6 | from collections import defaultdict
7 | import numpy as np
8 | import pdb
9 | import math
10 |
11 | def precook(s, n=4, out=False):
12 | """
13 | Takes a string as input and returns an object that can be given to
14 | either cook_refs or cook_test. This is optional: cook_refs and cook_test
15 | can take string arguments as well.
16 | :param s: string : sentence to be converted into ngrams
17 | :param n: int : number of ngrams for which representation is calculated
18 | :return: term frequency vector for occurring ngrams
19 | """
20 | words = s.split()
21 | counts = defaultdict(int)
22 | for k in range(1,n+1):
23 | for i in range(len(words)-k+1):
24 | ngram = tuple(words[i:i+k])
25 | counts[ngram] += 1
26 | return counts
27 |
28 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
29 | '''Takes a list of reference sentences for a single segment
30 | and returns an object that encapsulates everything that CIDEr
31 | needs to know about them.
32 | :param refs: list of string : reference sentences for some image
33 | :param n: int : number of ngrams for which (ngram) representation is calculated
34 | :return: result (list of dict)
35 | '''
36 | return [precook(ref, n) for ref in refs]
37 |
38 | def cook_test(test, n=4):
39 | '''Takes a test sentence and returns an object that
40 | encapsulates everything that CIDEr needs to know about it.
41 | :param test: string : hypothesis sentence for some image
42 | :param n: int : number of ngrams for which (ngram) representation is calculated
43 | :return: result (dict)
44 | '''
45 | return precook(test, n, True)
46 |
47 | class CiderScorer(object):
48 | """CIDEr scorer.
49 | """
50 |
51 | def copy(self):
52 | ''' copy the refs.'''
53 | new = CiderScorer(n=self.n)
54 | new.ctest = copy.copy(self.ctest)
55 | new.crefs = copy.copy(self.crefs)
56 | return new
57 |
58 | def __init__(self, test=None, refs=None, n=4, sigma=6.0, df=None):
59 | ''' singular instance '''
60 | self.n = n
61 | self.sigma = sigma
62 | self.crefs = []
63 | self.ctest = []
64 | self.document_frequency = defaultdict(float) if df is None else df
65 | self.cook_append(test, refs)
66 | self.ref_len = None
67 |
68 | def cook_append(self, test, refs):
69 | '''called by constructor and __iadd__ to avoid creating new instances.'''
70 |
71 | if refs is not None:
72 | self.crefs.append(cook_refs(refs))
73 | if test is not None:
74 | self.ctest.append(cook_test(test)) ## N.B.: -1
75 | else:
76 | self.ctest.append(None) # lens of crefs and ctest have to match
77 |
78 | def size(self):
79 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
80 | return len(self.crefs)
81 |
82 | def __iadd__(self, other):
83 | '''add an instance (e.g., from another sentence).'''
84 |
85 | if type(other) is tuple:
86 | ## avoid creating new CiderScorer instances
87 | self.cook_append(other[0], other[1])
88 | else:
89 | self.ctest.extend(other.ctest)
90 | self.crefs.extend(other.crefs)
91 |
92 | return self
93 | def compute_doc_freq(self):
94 | '''
95 | Compute document frequency for reference data.
96 | This will be used to compute idf (inverse document frequency) later.
97 | The document frequency is stored in the object.
98 | :return: None
99 | '''
100 | for refs in self.crefs:
101 | # refs, k ref captions of one image
102 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
103 | self.document_frequency[ngram] += 1
104 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
105 |
106 | def compute_cider(self):
107 | def counts2vec(cnts):
108 | """
109 | Function maps counts of ngrams to a vector of tf-idf weights.
110 | The function returns vec, an array of dictionaries that maps each n-gram to its tf-idf weight.
111 | The n-th entry of the array corresponds to n-grams of length n.
112 | :param cnts:
113 | :return: vec (array of dict), norm (array of float), length (int)
114 | """
115 | vec = [defaultdict(float) for _ in range(self.n)]
116 | length = 0
117 | norm = [0.0 for _ in range(self.n)]
118 | for (ngram,term_freq) in cnts.items():
119 | # give word count 1 if it doesn't appear in reference corpus
120 | df = np.log(max(1.0, self.document_frequency[ngram]))
121 | # ngram index
122 | n = len(ngram)-1
123 | # tf (term_freq) * idf (precomputed idf) for n-grams
124 | vec[n][ngram] = float(term_freq)*(self.ref_len - df)
125 | # compute norm for the vector. the norm will be used for computing similarity
126 | norm[n] += pow(vec[n][ngram], 2)
127 |
128 | if n == 1:
129 | length += term_freq
130 | norm = [np.sqrt(n) for n in norm]
131 | return vec, norm, length
132 |
133 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
134 | '''
135 | Compute the cosine similarity of two vectors.
136 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis
137 | :param vec_ref: array of dictionary for vector corresponding to reference
138 | :param norm_hyp: array of float for vector corresponding to hypothesis
139 | :param norm_ref: array of float for vector corresponding to reference
140 | :param length_hyp: int containing length of hypothesis
141 | :param length_ref: int containing length of reference
142 | :return: array of score for each n-grams cosine similarity
143 | '''
144 | delta = float(length_hyp - length_ref)
145 | # measure cosine similarity
146 | val = np.array([0.0 for _ in range(self.n)])
147 | for n in range(self.n):
148 | # ngram
149 | for (ngram,count) in vec_hyp[n].items():
150 | # vrama91 : added clipping
151 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]
152 |
153 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
154 | val[n] /= (norm_hyp[n]*norm_ref[n])
155 |
156 | assert(not math.isnan(val[n]))
157 | # vrama91: added a length based gaussian penalty
158 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
159 | return val
160 |
161 | # compute log reference length
162 | self.ref_len = np.log(float(len(self.crefs)))
163 |
164 | scores = []
165 | for test, refs in zip(self.ctest, self.crefs):
166 | # compute vector for test captions
167 | vec, norm, length = counts2vec(test)
168 | # compute vector for ref captions
169 | score = np.array([0.0 for _ in range(self.n)])
170 | for ref in refs:
171 | vec_ref, norm_ref, length_ref = counts2vec(ref)
172 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
173 | # change by vrama91 - mean of ngram scores, instead of sum
174 | score_avg = np.mean(score)
175 | # divide by number of references
176 | score_avg /= len(refs)
177 | # multiply score by 10
178 | score_avg *= 10.0
179 | # append score of an image to the score list
180 | scores.append(score_avg)
181 | return scores
182 |
183 | def compute_score(self, option=None, verbose=0):
184 | # compute idf
185 | if len(self.document_frequency) == 0:
186 | self.compute_doc_freq()
187 | # assert to check document frequency
188 | assert(len(self.ctest) >= max(self.document_frequency.values()))
189 | # compute cider score
190 | score = self.compute_cider()
191 | # debug
192 | # print score
193 | return np.mean(np.array(score)), np.array(score)
--------------------------------------------------------------------------------
/clinicgen/external/rouge/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'vrama91'
2 |
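A matching usage sketch for the Cider wrapper in clinicgen/external/cider/cider.py above (illustrative ids and sentences; when no external df is supplied, CiderScorer estimates idf from the supplied references themselves, so scores are only meaningful over a reasonably large set of images — the repository's cider-df.py is presumably there to precompute document frequencies for this reason):

from clinicgen.external.cider.cider import Cider

# Same gts/res format as the BLEU wrapper: one tokenized hypothesis per id,
# one or more tokenized references per id.
gts = {'img1': ['no acute cardiopulmonary abnormality'],
       'img2': ['the lungs are clear bilaterally']}
res = {'img1': ['no acute abnormality'],
       'img2': ['lungs are clear']}

scorer = Cider(n=4, sigma=6.0)  # df=None, so idf is computed from gts
avg, per_image = scorer.compute_score(gts, res)
# avg is the corpus mean; per_image is a numpy array of per-image CIDEr scores.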
--------------------------------------------------------------------------------
/clinicgen/external/rouge/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/rouge/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/clinicgen/external/rouge/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/rouge/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/clinicgen/external/rouge/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/rouge/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/clinicgen/external/rouge/__pycache__/rouge.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/rouge/__pycache__/rouge.cpython-36.pyc
--------------------------------------------------------------------------------
/clinicgen/external/rouge/__pycache__/rouge.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/rouge/__pycache__/rouge.cpython-37.pyc
--------------------------------------------------------------------------------
/clinicgen/external/rouge/__pycache__/rouge.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/rouge/__pycache__/rouge.cpython-38.pyc
--------------------------------------------------------------------------------
/clinicgen/external/rouge/rouge.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : rouge.py
4 | #
5 | # Description : Computes ROUGE-L metric as described by Lin and Hovy (2004)
6 | #
7 | # Creation Date : 2015-01-07 06:03
8 | # Author : Ramakrishna Vedantam
9 |
10 | import numpy as np
11 | import pdb
12 |
13 | def my_lcs(string, sub):
14 | """
15 | Calculates longest common subsequence for a pair of tokenized strings
16 | :param string : list of str : tokens from a string split using whitespace
17 | :param sub : list of str : shorter string, also split using whitespace
18 | :returns: length (int): length of the longest common subsequence between the two strings
19 |
20 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS
21 | """
22 | if(len(string)< len(sub)):
23 | sub, string = string, sub
24 |
25 | lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)]
26 |
27 | for j in range(1,len(sub)+1):
28 | for i in range(1,len(string)+1):
29 | if(string[i-1] == sub[j-1]):
30 | lengths[i][j] = lengths[i-1][j-1] + 1
31 | else:
32 | lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1])
33 |
34 | return lengths[len(string)][len(sub)]
35 |
36 | class Rouge():
37 | '''
38 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set
39 |
40 | '''
41 | def __init__(self):
42 | # vrama91: updated the value below based on discussion with Hovy
43 | self.beta = 1.2
44 |
45 | def calc_score(self, candidate, refs):
46 | """
47 | Compute ROUGE-L score given one candidate and references for an image
48 | :param candidate: list of str : candidate sentence to be evaluated (singleton list)
49 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated
50 | :returns score: float (ROUGE-L score for the candidate evaluated against references)
51 | """
52 | assert(len(candidate)==1)
53 | assert(len(refs)>0)
54 | prec = []
55 | rec = []
56 |
57 | # split into tokens
58 | token_c = candidate[0].split(" ")
59 |
60 | for reference in refs:
61 | # split into tokens
62 | token_r = reference.split(" ")
63 | # compute the longest common subsequence
64 | lcs = my_lcs(token_r, token_c)
65 | prec.append(lcs/float(len(token_c)))
66 | rec.append(lcs/float(len(token_r)))
67 |
68 | prec_max = max(prec)
69 | rec_max = max(rec)
70 |
71 | if(prec_max!=0 and rec_max !=0):
72 | score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max)
73 | else:
74 | score = 0.0
75 | return score
76 |
77 | def compute_score(self, gts, res):
78 | """
79 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset
80 | Invoked by evaluate_captions.py
81 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values
82 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values
83 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images)
84 | """
85 | assert(gts.keys() == res.keys())
86 | imgIds = gts.keys()
87 |
88 | score = []
89 | for id in imgIds:
90 | hypo = res[id]
91 | ref = gts[id]
92 |
93 | score.append(self.calc_score(hypo, ref))
94 |
95 | # Sanity check.
96 | assert(type(hypo) is list) 97 | assert(len(hypo) == 1) 98 | assert(type(ref) is list) 99 | assert(len(ref) > 0) 100 | 101 | average_score = np.mean(np.array(score)) 102 | return average_score, np.array(score) 103 | 104 | def method(self): 105 | return "Rouge" 106 | -------------------------------------------------------------------------------- /clinicgen/external/spice/.gitignore: -------------------------------------------------------------------------------- 1 | cache 2 | tmp 3 | -------------------------------------------------------------------------------- /clinicgen/external/spice/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/spice/__init__.py -------------------------------------------------------------------------------- /clinicgen/external/spice/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/spice/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /clinicgen/external/spice/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/spice/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /clinicgen/external/spice/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/spice/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /clinicgen/external/spice/__pycache__/spice.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/spice/__pycache__/spice.cpython-36.pyc -------------------------------------------------------------------------------- /clinicgen/external/spice/__pycache__/spice.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/spice/__pycache__/spice.cpython-37.pyc -------------------------------------------------------------------------------- /clinicgen/external/spice/__pycache__/spice.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/spice/__pycache__/spice.cpython-38.pyc -------------------------------------------------------------------------------- /clinicgen/external/spice/lib/.gitignore: -------------------------------------------------------------------------------- 1 | stanford-corenlp-*.jar 2 | -------------------------------------------------------------------------------- /clinicgen/external/spice/lib/Meteor-1.5.jar: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/spice/lib/Meteor-1.5.jar -------------------------------------------------------------------------------- /clinicgen/external/spice/lib/SceneGraphParser-1.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/spice/lib/SceneGraphParser-1.0.jar -------------------------------------------------------------------------------- /clinicgen/external/spice/lib/ejml-0.23.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/spice/lib/ejml-0.23.jar -------------------------------------------------------------------------------- /clinicgen/external/spice/lib/fst-2.47.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/spice/lib/fst-2.47.jar -------------------------------------------------------------------------------- /clinicgen/external/spice/lib/guava-19.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/spice/lib/guava-19.0.jar -------------------------------------------------------------------------------- /clinicgen/external/spice/lib/hamcrest-core-1.3.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/spice/lib/hamcrest-core-1.3.jar -------------------------------------------------------------------------------- /clinicgen/external/spice/lib/jackson-core-2.5.3.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/spice/lib/jackson-core-2.5.3.jar -------------------------------------------------------------------------------- /clinicgen/external/spice/lib/javassist-3.19.0-GA.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/spice/lib/javassist-3.19.0-GA.jar -------------------------------------------------------------------------------- /clinicgen/external/spice/lib/json-simple-1.1.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/spice/lib/json-simple-1.1.1.jar -------------------------------------------------------------------------------- /clinicgen/external/spice/lib/junit-4.12.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/spice/lib/junit-4.12.jar 
-------------------------------------------------------------------------------- /clinicgen/external/spice/lib/lmdbjni-0.4.6.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/spice/lib/lmdbjni-0.4.6.jar -------------------------------------------------------------------------------- /clinicgen/external/spice/lib/lmdbjni-linux64-0.4.6.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/spice/lib/lmdbjni-linux64-0.4.6.jar -------------------------------------------------------------------------------- /clinicgen/external/spice/lib/lmdbjni-osx64-0.4.6.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/spice/lib/lmdbjni-osx64-0.4.6.jar -------------------------------------------------------------------------------- /clinicgen/external/spice/lib/lmdbjni-win64-0.4.6.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/spice/lib/lmdbjni-win64-0.4.6.jar -------------------------------------------------------------------------------- /clinicgen/external/spice/lib/objenesis-2.4.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/spice/lib/objenesis-2.4.jar -------------------------------------------------------------------------------- /clinicgen/external/spice/lib/slf4j-api-1.7.12.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/spice/lib/slf4j-api-1.7.12.jar -------------------------------------------------------------------------------- /clinicgen/external/spice/lib/slf4j-simple-1.7.21.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/spice/lib/slf4j-simple-1.7.21.jar -------------------------------------------------------------------------------- /clinicgen/external/spice/spice-1.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/external/spice/spice-1.0.jar -------------------------------------------------------------------------------- /clinicgen/external/spice/spice.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import sys 4 | import subprocess 5 | import threading 6 | import json 7 | import numpy as np 8 | import ast 9 | import tempfile 10 | 11 | # Assumes spice.jar is in the same directory as spice.py. Change as needed. 
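# The scorer shells out to the bundled Java tool (see the subprocess call in
# compute_score below), so a Java runtime must be on PATH. spice-1.0.jar is
# expected to find its dependencies under lib/ (Meteor-1.5.jar,
# SceneGraphParser-1.0.jar and the other jars listed above, plus a
# stanford-corenlp jar, which lib/.gitignore suggests is downloaded separately
# rather than committed to the repository).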
12 | SPICE_JAR = 'spice-1.0.jar' 13 | TEMP_DIR = 'tmp' 14 | CACHE_DIR = 'cache' 15 | 16 | class Spice: 17 | """ 18 | Main Class to compute the SPICE metric 19 | """ 20 | 21 | def float_convert(self, obj): 22 | try: 23 | return float(obj) 24 | except: 25 | return np.nan 26 | 27 | def compute_score(self, gts, res): 28 | assert(sorted(gts.keys()) == sorted(res.keys())) 29 | imgIds = sorted(gts.keys()) 30 | 31 | # Prepare temp input file for the SPICE scorer 32 | input_data = [] 33 | for id in imgIds: 34 | hypo = res[id] 35 | ref = gts[id] 36 | 37 | # Sanity check. 38 | assert(type(hypo) is list) 39 | assert(len(hypo) == 1) 40 | assert(type(ref) is list) 41 | assert(len(ref) >= 1) 42 | 43 | input_data.append({ 44 | "image_id" : id, 45 | "test" : hypo[0], 46 | "refs" : ref 47 | }) 48 | 49 | cwd = os.path.dirname(os.path.abspath(__file__)) 50 | temp_dir=os.path.join(cwd, TEMP_DIR) 51 | if not os.path.exists(temp_dir): 52 | os.makedirs(temp_dir) 53 | in_file = tempfile.NamedTemporaryFile(mode='w', delete=False, dir=temp_dir) 54 | json.dump(input_data, in_file, indent=2) 55 | in_file.close() 56 | 57 | # Start job 58 | out_file = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) 59 | out_file.close() 60 | cache_dir=os.path.join(cwd, CACHE_DIR) 61 | if not os.path.exists(cache_dir): 62 | os.makedirs(cache_dir) 63 | spice_cmd = ['java', '-jar', '-Xmx8G', SPICE_JAR, in_file.name, 64 | '-cache', cache_dir, 65 | '-out', out_file.name, 66 | '-subset', 67 | '-silent' 68 | ] 69 | subprocess.check_call(spice_cmd, 70 | cwd=os.path.dirname(os.path.abspath(__file__))) 71 | 72 | # Read and process results 73 | with open(out_file.name) as data_file: 74 | results = json.load(data_file) 75 | os.remove(in_file.name) 76 | os.remove(out_file.name) 77 | 78 | imgId_to_scores = {} 79 | spice_scores = [] 80 | for item in results: 81 | imgId_to_scores[item['image_id']] = item['scores'] 82 | spice_scores.append(self.float_convert(item['scores']['All']['f'])) 83 | average_score = np.mean(np.array(spice_scores)) 84 | scores = [] 85 | for image_id in imgIds: 86 | # Convert none to NaN before saving scores over subcategories 87 | score_set = {} 88 | for category,score_tuple in imgId_to_scores[image_id].items(): 89 | score_set[category] = {k: self.float_convert(v) for k, v in score_tuple.items()} 90 | scores.append(score_set) 91 | return average_score, scores 92 | 93 | def method(self): 94 | return "SPICE" 95 | 96 | 97 | -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/bertnli.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/bertnli.cpython-36.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/bertnli.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/bertnli.cpython-37.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/bertnli.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/bertnli.cpython-38.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/cnnrnnrnn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/cnnrnnrnn.cpython-36.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/cnnrnnrnn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/cnnrnnrnn.cpython-37.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/cnnrnnrnn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/cnnrnnrnn.cpython-38.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/image.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/image.cpython-36.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/image.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/image.cpython-37.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/image.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/image.cpython-38.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/image2text.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/image2text.cpython-36.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/image2text.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/image2text.cpython-37.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/image2text.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/image2text.cpython-38.pyc 
-------------------------------------------------------------------------------- /clinicgen/models/__pycache__/kwl.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/kwl.cpython-36.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/kwl.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/kwl.cpython-37.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/kwl.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/kwl.cpython-38.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/m2transformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/m2transformer.cpython-36.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/m2transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/m2transformer.cpython-37.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/m2transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/m2transformer.cpython-38.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/sat.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/sat.cpython-36.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/sat.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/sat.cpython-37.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/sat.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/sat.cpython-38.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/tienet.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/tienet.cpython-36.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/tienet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/tienet.cpython-37.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/tienet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/tienet.cpython-38.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/transformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/transformer.cpython-36.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/transformer.cpython-37.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/transformer.cpython-38.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /clinicgen/models/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/models/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /clinicgen/models/bertnli.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch 5 | from torch.nn import Dropout, Linear 
6 | from torch.nn.functional import cross_entropy
7 | from torch.nn.utils import clip_grad_norm_
8 | from transformers import AutoModel, AutoTokenizer
9 | from clinicgen.utils import data_cuda
10 | 
11 | 
12 | MY_DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # fall back to CPU when no GPU is present
13 | 
14 | class BERTNLI(torch.nn.Module):
15 |     LABEL_CONTRADICTION = 2
16 |     LABEL_ENTAILMENT = 0
17 |     LABEL_NEUTRAL = 1
18 | 
19 |     def __init__(self, name_or_path, bert_type='bert', cls='linear', length=128, force_lowercase=False, device='cuda',
20 |                  verbose=False):
21 |         super(BERTNLI, self).__init__()
22 |         self.bert_type = bert_type
23 |         self.cls = cls
24 |         self.force_lowercase = force_lowercase
25 |         self.bert = AutoModel.from_pretrained(name_or_path)
26 |         self.tokenizer = AutoTokenizer.from_pretrained(name_or_path)
27 |         self.dropout = Dropout(0.1)
28 |         self.linear = Linear(self.bert.config.hidden_size, 3)
29 |         self.length = length
30 |         self.device = device
31 |         self.verbose = verbose
32 | 
33 |     @classmethod
34 |     def train_step(cls, logits, gold_labels, optimizer, model=None, grad_clip=None):
35 |         optimizer.zero_grad()
36 |         target = logits.new_zeros((logits.shape[0],), dtype=torch.long)
37 |         for i, gold_label in enumerate(gold_labels):
38 |             if gold_label == 'entailment':
39 |                 target[i] = cls.LABEL_ENTAILMENT
40 |             elif gold_label == 'neutral':
41 |                 target[i] = cls.LABEL_NEUTRAL
42 |             elif gold_label == 'contradiction':
43 |                 target[i] = cls.LABEL_CONTRADICTION
44 |             else:
45 |                 raise ValueError('Unknown label {0}'.format(gold_label))
46 |         loss = cross_entropy(logits, target)
47 |         loss.backward()
48 |         if grad_clip is not None:
49 |             clip_grad_norm_(model.parameters(), grad_clip)
50 |         optimizer.step()
51 |         return float(loss.detach().cpu())
52 | 
53 |     def cuda(self, device=None):
54 |         super(BERTNLI, self).cuda(device)
55 |         self.device = 'gpu'
56 | 
57 |     def forward(self, sent1s, sent2s):
58 |         buffer, boundaries, max_len = [], [], 0
59 |         for sent1, sent2 in zip(sent1s, sent2s):
60 |             if self.force_lowercase:
61 |                 sent1 = sent1.lower()
62 |                 sent2 = sent2.lower()
63 |             toks1 = self.tokenizer.tokenize(sent1)
64 |             toks2 = self.tokenizer.tokenize(sent2)
65 |             tokens = ['[CLS]'] + toks1 + ['[SEP]'] + toks2 + ['[SEP]']
66 |             buffer.append(tokens)
67 |             boundaries.append(len(toks1) + 2)
68 |             if len(tokens) > max_len:
69 |                 max_len = len(tokens)
70 |         if max_len > self.length:
71 |             max_len = self.length
72 |         token_ids, attn_mask = [], []
73 |         seg_ids = [] if self.bert_type != 'distilbert' else None
74 |         for idx, tokens in enumerate(buffer):
75 |             if len(tokens) < max_len:
76 |                 for _ in range(max_len - len(tokens)):
77 |                     tokens.append('[PAD]')
78 |             elif len(tokens) > max_len:
79 |                 if self.verbose:
80 |                     print('Truncating pair from {0}->{1}'.format(len(tokens), max_len))
81 |                 tokens = tokens[:max_len]
82 |             attn_mask.append(torch.tensor([1 if token != '[PAD]' else 0 for token in tokens]))
83 |             token_ids.append(torch.tensor(self.tokenizer.convert_tokens_to_ids(tokens)))
84 |             if seg_ids is not None:
85 |                 seg_ids.append(torch.tensor([0 if i < boundaries[idx] else 1 for i in range(len(tokens))]))
86 |         token_ids = torch.stack(token_ids, dim=0)
87 |         attn_mask = torch.stack(attn_mask, dim=0)
88 |         token_ids, attn_mask = token_ids.to(MY_DEVICE), attn_mask.to(MY_DEVICE)
89 |         if seg_ids is not None:
90 |             seg_ids = torch.stack(seg_ids, dim=0)
91 |             seg_ids = seg_ids.to(MY_DEVICE)
92 |             # token_ids, attn_mask, seg_ids = data_cuda(token_ids, attn_mask, seg_ids, device=self.device)
93 |         # else:
94 |         #     token_ids, attn_mask = data_cuda(token_ids, attn_mask, device=self.device)
95 |         if self.bert_type == 'distilbert':
96 |             reps = self.bert(token_ids, attention_mask=attn_mask)
97 |             reps = reps[0][:, 0]
98 |             reps = self.dropout(reps)
99 |             return self.linear(reps)
100 |         else:
101 |             reps, cls = self.bert(token_ids, attention_mask=attn_mask, token_type_ids=seg_ids)
102 |             if self.cls == 'token':
103 |                 reps = reps[:, 0]
104 |                 reps = self.dropout(reps)
105 |                 return self.linear(reps)
106 |             else:
107 |                 cls = self.dropout(cls)
108 |                 return self.linear(cls)
109 | 
--------------------------------------------------------------------------------
/clinicgen/models/image.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | 
4 | import gzip
5 | import torch
6 | from torch import relu
7 | from torch.nn import Dropout, Linear, Sequential
8 | from torch.nn.functional import adaptive_avg_pool2d, cross_entropy
9 | from torch.nn.utils import clip_grad_norm_
10 | from torchvision import models
11 | from clinicgen.utils import data_cuda
12 | 
13 | from custom_models import CustomEncoder  # explicit import instead of a wildcard
14 | 
15 | 
16 | class ImageClassification(torch.nn.Module):
17 |     def __init__(self, model, num_labels, num_classes, multi_image=1, dropout=0.0, pretrained=True):
18 |         super(ImageClassification, self).__init__()
19 |         self.image_feats, self.image_dim = self.image_features(model, False, pretrained)
20 |         for i in range(num_labels):
21 |             setattr(self, 'linear{0}'.format(i), Linear(self.image_dim, num_classes))
22 |         self.num_labels = num_labels
23 |         self.multi_image = multi_image
24 |         self.dropout = Dropout(p=dropout)
25 | 
26 |     @classmethod
27 |     def fix_layers(cls, model):
28 |         for param in model.parameters():
29 |             param.requires_grad = False
30 | 
31 |     @classmethod
32 |     def image_features(cls, name, fixed_weight=False, pretrained=True, pretrained_model=None, device='gpu'):
33 |         m = CustomEncoder()  # always the custom encoder; the arguments are kept for interface compatibility
34 |         feature_dim = m.feature_dim
35 |         return m, feature_dim
36 | 
37 |     def deflatten_image(self, x):
38 |         if self.multi_image > 1:
39 |             x = x.view(int(x.shape[0] / self.multi_image), self.multi_image, x.shape[1])
40 |             x, _ = torch.max(x, dim=1)
41 |         return x
42 | 
43 |     def flatten_image(self, x):
44 |         if self.multi_image > 1:
45 |             return x.flatten(start_dim=0, end_dim=1)
46 |         else:
47 |             return x
48 | 
49 |     def forward(self, x):
50 |         x = self.flatten_image(x)
51 |         x = self.image_feats(x)
52 |         x = relu(x)
53 |         x = adaptive_avg_pool2d(x, (1, 1))
54 |         x = torch.flatten(x, 1)
55 |         x = self.deflatten_image(x)
56 |         xs = []
57 |         for i in range(self.num_labels):
58 |             xi = self.dropout(x)
59 |             xi = getattr(self, 'linear{0}'.format(i))(xi).unsqueeze(dim=2)
60 |             xs.append(xi)
61 |         x = torch.cat(xs, dim=2)
62 |         return x
63 | 
64 |     def train_step(self, inp, targ, optimizer, clip_grad=None, device='gpu'):
65 |         optimizer.zero_grad()
66 |         inp, targ = data_cuda(inp, targ, device=device, non_blocking=False)
67 |         targ = targ.squeeze(dim=-1)
68 |         out = self.forward(inp)
69 |         out = out.squeeze(dim=-1)
70 |         loss = cross_entropy(out, targ, ignore_index=-100, reduction='mean')
71 |         loss.backward()
72 |         loss_val = loss.detach().cpu()
73 |         if clip_grad is not None:
74 |             clip_grad_norm_(self.parameters(), clip_grad)
75 |         optimizer.step()
76 |         return loss_val
77 | 
--------------------------------------------------------------------------------
/clinicgen/models/image_orig.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | 
4 | import gzip
5 | import torch
6 | from torch import relu
7 | from torch.nn import Dropout, Linear, Sequential
8 | from torch.nn.functional import adaptive_avg_pool2d, cross_entropy
9 | from torch.nn.utils import clip_grad_norm_
10 | from torchvision import models
11 | from clinicgen.utils import data_cuda
12 | 
13 | 
14 | class ImageClassification(torch.nn.Module):
15 |     def __init__(self, model, num_labels, num_classes, multi_image=1, dropout=0.0, pretrained=True):
16 |         super(ImageClassification, self).__init__()
17 |         self.image_feats, self.image_dim = self.image_features(model, False, pretrained)
18 |         for i in range(num_labels):
19 |             setattr(self, 'linear{0}'.format(i), Linear(self.image_dim, num_classes))
20 |         self.num_labels = num_labels
21 |         self.multi_image = multi_image
22 |         self.dropout = Dropout(p=dropout)
23 | 
24 |     @classmethod
25 |     def fix_layers(cls, model):
26 |         for param in model.parameters():
27 |             param.requires_grad = False
28 | 
29 |     @classmethod
30 |     def image_features(cls, name, fixed_weight=False, pretrained=True, pretrained_model=None, device='gpu'):
31 |         if pretrained_model is None:
32 |             if name == 'densenet121' or name == 'densenet':
33 |                 m = models.densenet121(pretrained=pretrained)
34 |                 if fixed_weight:
35 |                     cls.fix_layers(m)
36 |                 return Sequential(*list(m.features.children())), 1024
37 |             elif name == 'resnet50':
38 |                 m = models.resnet50(pretrained=pretrained)
39 |                 if fixed_weight:
40 |                     cls.fix_layers(m)
41 |                 return Sequential(*list(m.children())[:-2]), 2048
42 |             elif name == 'resnet152' or name == 'resnet':
43 |                 m = models.resnet152(pretrained=pretrained)
44 |                 if fixed_weight:
45 |                     cls.fix_layers(m)
46 |                 return Sequential(*list(m.children())[:-2]), 2048
47 |             elif name == 'vgg19' or name == 'vgg':
48 |                 m = models.vgg19(pretrained=pretrained)
49 |                 if fixed_weight:
50 |                     cls.fix_layers(m)
51 |                 return Sequential(*list(m.features.children())[:-1]), 512
52 |             else:
53 |                 raise ValueError('Unknown model {0}'.format(name))
54 |         else:
55 |             d = torch.device('cpu') if device == 'cpu' else torch.device('cuda:0')
56 |             with gzip.open(pretrained_model) as f:
57 |                 state = torch.load(f, map_location=d)
58 |             m = ImageClassification(name, 14, 3, pretrained=False)
59 |             m.load_state_dict(state['model'])
60 |             if fixed_weight:
61 |                 cls.fix_layers(m)
62 |             return m.image_feats, m.image_dim
63 | 
64 |     def deflatten_image(self, x):
65 |         if self.multi_image > 1:
66 |             x = x.view(int(x.shape[0] / self.multi_image), self.multi_image, x.shape[1])
67 |             x, _ = torch.max(x, dim=1)
68 |         return x
69 | 
70 |     def flatten_image(self, x):
71 |         if self.multi_image > 1:
72 |             return x.flatten(start_dim=0, end_dim=1)
73 |         else:
74 |             return x
75 | 
76 |     def forward(self, x):
77 |         x = self.flatten_image(x)
78 |         x = self.image_feats(x)
79 |         x = relu(x)
80 |         x = adaptive_avg_pool2d(x, (1, 1))
81 |         x = torch.flatten(x, 1)
82 |         x = self.deflatten_image(x)
83 |         xs = []
84 |         for i in range(self.num_labels):
85 |             xi = self.dropout(x)
86 |             xi = getattr(self, 'linear{0}'.format(i))(xi).unsqueeze(dim=2)
87 |             xs.append(xi)
88 |         x = torch.cat(xs, dim=2)
89 |         return x
90 | 
91 |     def train_step(self, inp, targ, optimizer, clip_grad=None, device='gpu'):
92 |         optimizer.zero_grad()
93 |         inp, targ = data_cuda(inp, targ, device=device, non_blocking=False)
94 |         targ = targ.squeeze(dim=-1)
95 |         out = self.forward(inp)
96 |         out = out.squeeze(dim=-1)
97 |         loss = cross_entropy(out, targ, ignore_index=-100, reduction='mean')
98 |         loss.backward()
99 |         loss_val = loss.detach().cpu()
100 |         if clip_grad is not None:
101 |             clip_grad_norm_(self.parameters(), clip_grad)
102 |         optimizer.step()
103 |         return loss_val
104 | 
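A minimal, hypothetical usage sketch of the backbone selection above, not a file in this repository: it resolves a frozen DenseNet-121 feature extractor through ImageClassification.image_features and runs a dummy batch. It assumes torchvision can load (or download) the pretrained weights, and the shapes shown are for 224x224 RGB inputs; note that clinicgen/models/image.py bypasses this switch and always returns CustomEncoder.

# Hypothetical sketch, not part of the repository.
import torch
from clinicgen.models.image_orig import ImageClassification

feats, dim = ImageClassification.image_features('densenet121', fixed_weight=True)
x = torch.randn(2, 3, 224, 224)  # dummy batch of two RGB images
with torch.no_grad():
    fmap = feats(x)  # DenseNet-121 feature maps of shape (2, 1024, 7, 7); dim == 1024
print(fmap.shape, dim)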
-------------------------------------------------------------------------------- /clinicgen/models/kwl.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import random 5 | import torch 6 | from torch import sigmoid, tanh 7 | from torch.distributions.categorical import Categorical 8 | from torch.nn import Dropout, Embedding, Linear, LSTMCell 9 | from torch.nn.functional import cross_entropy, relu, softmax 10 | from clinicgen.models.image import ImageClassification 11 | from clinicgen.models.image2text import _Image2Text, PretrainedEmbeddings 12 | from clinicgen.utils import data_cuda 13 | 14 | 15 | class KnowingWhenToLook(_Image2Text): 16 | VISUAL_NUM = 49 17 | 18 | def __init__(self, embeddings, feat_dim=512, max_word=32, multi_image=1, multi_merge='att', teacher_forcing=False, 19 | image_model=None, image_pretrained=None, finetune_image=False, image_finetune_epoch=None, rl_opts=None, 20 | word_idxs=None, device='gpu', verbose=False): 21 | super(KnowingWhenToLook, self).__init__(max_word, multi_image, multi_merge, teacher_forcing, 22 | image_finetune_epoch, rl_opts, word_idxs, verbose) 23 | self.feat_dim = feat_dim 24 | 25 | self.dropout = Dropout(0.5) 26 | # Image processes 27 | if image_model is None: 28 | image_model = 'resnet' 29 | self.image_feats, image_dim = ImageClassification.image_features(image_model, not finetune_image, True, 30 | image_pretrained, device) 31 | self._init_multi_image(image_dim, self.VISUAL_NUM, feat_dim) 32 | self.image_proj_l = Linear(image_dim, feat_dim) 33 | self.image_proj_g = Linear(image_dim, feat_dim) 34 | # Visual sentinel 35 | input_dim = feat_dim + embeddings.shape[1] 36 | self.vs_att_h = Linear(self.VISUAL_NUM, 1, bias=False) 37 | self.vs_att_v = Linear(feat_dim, self.VISUAL_NUM, bias=False) 38 | self.vs_att_g = Linear(feat_dim, self.VISUAL_NUM, bias=False) 39 | self.vs_att_s = Linear(feat_dim, self.VISUAL_NUM, bias=False) 40 | self.vs_dense1 = Linear(input_dim, feat_dim, bias=False) 41 | self.vs_dense2 = Linear(feat_dim, feat_dim, bias=False) 42 | # Word processes 43 | self.lstm_word = LSTMCell(input_dim, feat_dim) 44 | self.embeddings = Embedding.from_pretrained(embeddings, freeze=False, 45 | padding_idx=PretrainedEmbeddings.INDEX_PAD) 46 | self.embed_num = self.embeddings.num_embeddings 47 | self.word_dense = Linear(feat_dim, embeddings.shape[0], bias=False) 48 | 49 | def _init_multi_image(self, image_dim, visual_num, rnn_dim): 50 | super(KnowingWhenToLook, self)._init_multi_image(image_dim, visual_num, rnn_dim) 51 | if self.multi_image > 1 and self.multi_merge == self.MULTI_MERGE_ATT: 52 | sentinel_num = visual_num + 1 53 | self.att_z_z = Linear(sentinel_num, sentinel_num) 54 | self.att_z_h = Linear(rnn_dim, sentinel_num) 55 | self.att_z_a = Linear(sentinel_num, 1) 56 | 57 | def _nll_step(self, encoded_data, targ, device, non_blocking, ids=None): 58 | vl, vg = encoded_data 59 | y = data_cuda(targ, device=device, non_blocking=non_blocking) 60 | words = self.decode_teacher_forcing(y, vl, vg) 61 | return self.loss_nll(y, words), [] 62 | 63 | def _rl_step(self, encoded_data, targ, device, non_blocking, ids=None): 64 | self.eval() 65 | with torch.no_grad(): 66 | _, words_greedy, _ = self.decode_beam(encoded_data, beam_size=1) 67 | gens_greedy, _ = self.evaluator.recover_words(words_greedy.squeeze(dim=1).squeeze(dim=1)) 68 | self.train() 69 | words, log_probs = self.sample(encoded_data) 70 | gens_sample, masks_sample = self.evaluator.recover_words(words) 71 
| gens_ref, _ = self.evaluator.recover_words(targ.squeeze(dim=1)) 72 | rewards, loss_acc = self.self_critical_reward(gens_sample, gens_greedy, gens_ref, masks_sample, log_probs, 73 | ids=ids) 74 | return loss_acc, rewards 75 | 76 | def decode_beam(self, encoded_data, beam_size, allow_stop=True, recover_words=None, diversity_rate=0.0): 77 | vl, vg = encoded_data 78 | # Initialize word process states 79 | w = vg.new_full((vg.shape[0], 1), PretrainedEmbeddings.INDEX_START, dtype=torch.long) 80 | h, c = vg.new_zeros((vg.shape[0], self.feat_dim)), vg.new_zeros((vg.shape[0], self.feat_dim)) 81 | states = (vg, vl, h, c) 82 | # Decode words 83 | beam_words, logs = self._decode_words_beam(w, states, beam_size, allow_stop, recover_words, diversity_rate) 84 | dummy_stops = self.dummy_stops(beam_words) 85 | return dummy_stops, beam_words.unsqueeze(dim=1), logs 86 | 87 | def decode_teacher_forcing(self, y, vl, vg): 88 | # Masks 89 | not_masked = y.new_ones(1, dtype=torch.bool)[0] 90 | mask = ((y > 0).sum(dim=(0, 1)) > 0) 91 | 92 | h, m = vg.new_zeros((vg.shape[0], self.feat_dim)), vg.new_zeros((vg.shape[0], self.feat_dim)) 93 | w = vg.new_full((vg.shape[0], 1), PretrainedEmbeddings.INDEX_START, dtype=torch.long) 94 | states = (vg, vl, h, m) 95 | 96 | words = [] 97 | for j in range(self.max_word): 98 | if torch.equal(mask[j], not_masked): 99 | p, states = self.proc_word(w, states) 100 | words.append(p) 101 | if self.teacher_forcing is None or self.teacher_forcing.get_tfr() >= random.random(): 102 | w = y[:, 0, j] 103 | else: 104 | p = softmax(p, dim=-1) 105 | cat = Categorical(probs=p) 106 | w = cat.sample() 107 | else: 108 | p = vg.new_ones(vg.shape[0], self.embed_num) / self.embed_num 109 | words.append(p) 110 | if self.teacher_forcing is None or self.teacher_forcing.get_tfr() >= random.random(): 111 | w = y[:, 0, j] 112 | else: 113 | w = vg.new_zeros(vg.shape[0]) 114 | return torch.stack(words, dim=1) 115 | 116 | def encode(self, x, meta): 117 | return self.encode_image(x) 118 | 119 | def encode_image(self, x): 120 | # Resnet-152 features 121 | x = self.flatten_image(x) 122 | x, mask = self.image_features_with_mask(x, self.image_feats) 123 | # Nx2048x7x7 -> Nx49x512 124 | if len(x.shape) > 3: 125 | x = x.flatten(start_dim=-2, end_dim=-1) 126 | x = x.permute(0, 2, 1) 127 | vl = self.dropout(x) 128 | vl = relu(self.image_proj_l(vl)) 129 | vl *= mask 130 | vl = self.deflatten_image(vl) 131 | # Mean visual feature (Nx512) 132 | vg = (x * mask).mean(dim=1) 133 | vg = self.deflatten_image(vg) 134 | vg = self.multi_vg(vg) 135 | vg = self.dropout(vg) 136 | vg = relu(self.image_proj_g(vg)) 137 | return vl, vg 138 | 139 | def forward(self, x, y, meta): 140 | # Encode image 141 | vl, vg = self.encode(x, meta) 142 | # Decode text 143 | return self.decode_teacher_forcing(y, vl, vg) 144 | 145 | def loss_nll(self, y, words): 146 | # Word loss 147 | words = words.unsqueeze(dim=1).permute(0, 3, 1, 2) 148 | loss_word = cross_entropy(words, y, ignore_index=PretrainedEmbeddings.INDEX_PAD, reduction='mean') 149 | return loss_word 150 | 151 | def multi_cb(self, zs, cv, b, h): 152 | if self.multi_image > 1: 153 | if self.multi_merge == self.MULTI_MERGE_ATT: 154 | h_a = self.att_z_h(h).unsqueeze(dim=1).repeat(1, zs.shape[1], 1) 155 | alpha = self.att_z_a(tanh(self.att_z_z(zs) + h_a)).squeeze(dim=2) 156 | alpha = softmax(alpha, dim=-1) 157 | alpha = alpha.unsqueeze(dim=2) 158 | cv = (cv * alpha).sum(dim=1) 159 | b = (b * alpha).sum(dim=1) 160 | elif self.multi_merge == self.MULTI_MERGE_MAX: 161 | cv, _ = torch.max(cv, dim=1) 
162 | b, _ = torch.min(b, dim=1) 163 | else: 164 | raise ValueError('Unknown multi merge {0}'.format(self.multi_merge)) 165 | return cv, b 166 | 167 | def proc_word(self, w, states): 168 | vg, vl, hw, mw = states 169 | z, hw, mw, s = self.visual_sentinel(w, vg, vl, hw, mw) 170 | a, b, c = self.visual_sentinel_attention(vl, hw, z, s) 171 | # Word generation probability (NxW), note that softmax is not applied here. 172 | p = self.word_dense(self.dropout(c + hw)) 173 | return p, (vg, vl, hw, mw) 174 | 175 | def sample(self, encoded_data, nucleus_p=None): 176 | vl, vg = encoded_data 177 | w = vg.new_full((vg.shape[0], 1), PretrainedEmbeddings.INDEX_START, dtype=torch.long) 178 | h, c = vg.new_zeros((vg.shape[0], self.feat_dim)), vg.new_zeros((vg.shape[0], self.feat_dim)) 179 | states = (vg, vl, h, c) 180 | return self._sample_words(w, states, nucleus_p) 181 | 182 | def visual_sentinel(self, w, vg, vl, h, m): 183 | # Visual features (Nx49) 184 | z1 = self.vs_att_v(vl) 185 | if self.multi_image > 1: 186 | z2 = self.vs_att_g(h).unsqueeze(dim=1).unsqueeze(dim=1) 187 | z2 = z2.repeat(1, vl.shape[1], vl.shape[2], 1) 188 | else: 189 | z2 = self.vs_att_g(h).unsqueeze(dim=1) 190 | z2 = z2.repeat(1, vl.shape[1], 1) 191 | z = self.vs_att_h(tanh(z1 + z2)).squeeze(dim=-1) 192 | # Visual sentinel (Nx512) 193 | e = self.embeddings(w).squeeze(1) 194 | xw = torch.cat((vg, e), dim=1) 195 | g = sigmoid(self.vs_dense1(xw) + self.vs_dense2(h)) 196 | h, m = self.lstm_word(xw, (h, m)) 197 | s = g * tanh(m) 198 | return z, h, m, s 199 | 200 | def visual_sentinel_attention(self, vl, h, z, s): 201 | # Attention to visual sentinel & visual features (Nx50) 202 | a1 = softmax(z, dim=-1) 203 | zs = tanh(self.vs_att_s(s) + self.vs_att_g(h)) 204 | if self.multi_image > 1: 205 | zs = zs.unsqueeze(dim=1).repeat(1, vl.shape[1], 1) 206 | zs = torch.cat((z, self.vs_att_h(zs)), dim=-1) 207 | a2 = softmax(zs, dim=-1) 208 | # Mixture of visual sentinel and visual features (Nx512) 209 | cv = torch.sum(a1.unsqueeze(-1) * vl, dim=-2) 210 | b = a2[:, :, -1].unsqueeze(dim=-1) if self.multi_image > 1 else a2[:, -1].unsqueeze(dim=-1) 211 | cv, b = self.multi_cb(zs, cv, b, h) 212 | c = b * s + (cv.new_ones((cv.shape[0], 1)) - b) * cv 213 | return a2, b, c 214 | -------------------------------------------------------------------------------- /clinicgen/models/utils.py: -------------------------------------------------------------------------------- 1 | #! 
/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | from clinicgen.models.cnnrnnrnn import CNNRNNRNN 5 | from clinicgen.models.kwl import KnowingWhenToLook 6 | from clinicgen.models.m2transformer import M2Transformer 7 | from clinicgen.models.sat import ShowAttendAndTell 8 | from clinicgen.models.tienet import TieNet 9 | from clinicgen.models.transformer import TransformerCaptioner, TransformerSimpleCaptioner 10 | from clinicgen.nli import SimpleNLI 11 | 12 | 13 | class Models: 14 | @classmethod 15 | def get_model(cls, name, embeddings, hidden_size, max_word, max_sent=None, multi_image=1, multi_merge='att', 16 | teacher_forcing=None, image_model=None, image_pretrained=None, finetune_image=True, 17 | view_position=False, image_finetune_epoch=None, rl_opts=None, word_idxs=None, device='gpu', 18 | parallel_sent=True, cnnrnnrnn_topic_state=False, cnnrnnrnn_simple_proj=False, sat_lstm_dim=1000, 19 | trans_image_pe=True, trans_layers=6, trans_enc_layers=None, trans_layer_norm=False, m2_memory=40, 20 | tienet_labels=None, verbose=False): 21 | if trans_enc_layers is None: 22 | trans_enc_layers = trans_layers 23 | if name == 'cnnrnnrnn': 24 | model = CNNRNNRNN(embeddings, feat_dim=hidden_size, max_sent=max_sent, max_word=max_word, 25 | multi_image=multi_image, multi_merge=multi_merge, topic_as_state=cnnrnnrnn_topic_state, 26 | simple_proj=cnnrnnrnn_simple_proj, teacher_forcing=teacher_forcing, 27 | parallel_sent=parallel_sent, view_position=view_position, image_model=image_model, 28 | image_pretrained=image_pretrained, finetune_image=finetune_image, 29 | image_finetune_epoch=image_finetune_epoch, rl_opts=rl_opts, word_idxs=word_idxs, 30 | device=device, verbose=verbose) 31 | elif name == 'kwl': 32 | model = KnowingWhenToLook(embeddings, feat_dim=hidden_size, max_word=max_word, multi_image=multi_image, 33 | multi_merge=multi_merge, teacher_forcing=teacher_forcing, image_model=image_model, 34 | image_pretrained=image_pretrained, finetune_image=finetune_image, 35 | image_finetune_epoch=image_finetune_epoch, rl_opts=rl_opts, word_idxs=word_idxs, 36 | device=device, verbose=verbose) 37 | elif name == 'm2trans': 38 | model = M2Transformer(embeddings, feat_dim=hidden_size, max_word=max_word, multi_image=multi_image, 39 | layer_norm=trans_layer_norm, num_enc_layers=trans_enc_layers, 40 | num_dec_layers=trans_layers, num_memory=m2_memory, teacher_forcing=teacher_forcing, 41 | image_model=image_model, image_pretrained=image_pretrained, 42 | finetune_image=finetune_image, image_finetune_epoch=image_finetune_epoch, 43 | rl_opts=rl_opts, word_idxs=word_idxs, device=device, verbose=verbose) 44 | elif name == 'sat': 45 | model = ShowAttendAndTell(embeddings, context_dim=hidden_size, lstm_dim=sat_lstm_dim, max_word=max_word, 46 | multi_image=multi_image, multi_merge=multi_merge, teacher_forcing=teacher_forcing, 47 | image_model=image_model, image_pretrained=image_pretrained, 48 | finetune_image=finetune_image, image_finetune_epoch=image_finetune_epoch, 49 | rl_opts=rl_opts, word_idxs=word_idxs, device=device, verbose=verbose) 50 | elif name == 'tienet': 51 | model = TieNet(embeddings, lstm_dim=hidden_size, max_word=max_word, multi_image=multi_image, 52 | multi_merge=multi_merge, labels=tienet_labels, teacher_forcing=teacher_forcing, 53 | image_model=image_model, image_pretrained=image_pretrained, 54 | finetune_image=finetune_image, image_finetune_epoch=image_finetune_epoch, 55 | rl_opts=rl_opts, word_idxs=word_idxs, device=device, verbose=verbose) 56 | elif name == 'trans': 57 | model = 
TransformerCaptioner(embeddings, feat_dim=hidden_size, max_word=max_word, multi_image=multi_image, 58 | image_pe=trans_image_pe, layer_norm=trans_layer_norm, 59 | num_enc_layers=trans_enc_layers, num_dec_layers=trans_layers, 60 | teacher_forcing=teacher_forcing, image_model=image_model, 61 | image_pretrained=image_pretrained, finetune_image=finetune_image, 62 | image_finetune_epoch=image_finetune_epoch, rl_opts=rl_opts, 63 | word_idxs=word_idxs, device=device, verbose=verbose) 64 | elif name == 'trans-s': 65 | model = TransformerSimpleCaptioner(embeddings, feat_dim=hidden_size, max_word=max_word, 66 | multi_image=multi_image, image_pe=trans_image_pe, 67 | layer_norm=trans_layer_norm, num_layers=trans_layers, 68 | teacher_forcing=teacher_forcing, image_model=image_model, 69 | image_pretrained=image_pretrained, finetune_image=finetune_image, 70 | image_finetune_epoch=image_finetune_epoch, rl_opts=rl_opts, 71 | word_idxs=word_idxs, device=device, verbose=verbose) 72 | else: 73 | raise ValueError('Unknown model {0}'.format(name)) 74 | return model 75 | 76 | @classmethod 77 | def hierarchical(cls, name): 78 | if name == 'cnnrnnrnn': 79 | return True 80 | else: 81 | return False 82 | 83 | 84 | class RLOptions: 85 | def __init__(self, epoch=None, metrics=None, weights=None, cider_df=None, tfidf=None, bert_score=None, 86 | bert_score_penalty=False, op='add', nli='mednli', nli_label='entailment', nli_neutral_score=(1.0 / 3), 87 | nli_prf='f', nli_batch=16, entity_match=None, entity_mode='exact', nthreads=2, pin_memory=False): 88 | self.epoch = epoch 89 | self.metrics = metrics.split(',') if metrics is not None else None 90 | self.weights = list(map(lambda v: float(v), weights.split(','))) if weights is not None else None 91 | self.cider_df = cider_df 92 | self.tfidf = tfidf 93 | self.bert_score_penalty = bert_score_penalty 94 | self.nli = nli 95 | self.nli_label = nli_label 96 | self.nli_neutral_score = nli_neutral_score 97 | self.nli_prf = nli_prf 98 | self.nli_batch = nli_batch 99 | self.op = op 100 | self.nthreads = nthreads 101 | self.pin_memory = pin_memory 102 | 103 | # Self-critical RL metrics 104 | self.rl_train = False 105 | self.bleu, self.rouge, self.cider, self.spice, self.bert_score = False, False, False, False, None 106 | self.qa_score = False 107 | self.chexpert = False 108 | self.nli_compare = [] 109 | self.entity_match, self.entity_mode = None, None 110 | 111 | if metrics is not None: 112 | for metric in self.metrics: 113 | if metric.startswith('BLEU'): 114 | self.bleu = True 115 | elif metric == 'ROUGE': 116 | self.rouge = True 117 | elif metric == 'CIDEr': 118 | self.cider = True 119 | elif metric == 'SPICE': 120 | self.spice = True 121 | elif metric.startswith('chexpert'): 122 | self.chexpert = True 123 | elif metric.startswith('BERTScore'): 124 | self.bert_score = bert_score 125 | elif metric.startswith('QAScore'): 126 | self.qa_score = True 127 | elif metric.startswith('NLI'): 128 | if metric == 'NLISentBERTScore': 129 | self.nli_compare.append(SimpleNLI.COMPARE_BERT_SCORE) 130 | elif metric == 'NLISentBERTScoreT': 131 | self.nli_compare.append(SimpleNLI.COMPARE_BERT_SCORE_FIX_THRESH) 132 | elif metric == 'NLISentTFIDF': 133 | self.nli_compare.append(SimpleNLI.COMPARE_TFIDF) 134 | elif metric == 'NLISentAll': 135 | self.nli_compare.append(SimpleNLI.COMPARE_ALL) 136 | else: 137 | self.nli_compare.append(SimpleNLI.COMPARE_DOC) 138 | elif metric == 'EntityMatchExact': 139 | self.entity_match = entity_match 140 | if self.entity_mode is None: 141 | buf = entity_mode.split('-') 142 
| buf[0] = 'exact'
143 |                         self.entity_mode = '-'.join(buf)
144 |                 elif metric == 'EntityMatchNLI':
145 |                     self.entity_match = entity_match
146 |                     self.entity_mode = entity_mode
147 |                 else:
148 |                     raise ValueError('Unknown RL metric {0}'.format(metric))
149 | 
--------------------------------------------------------------------------------
/clinicgen/optmizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | 
4 | from torch.optim import Adam
5 | from torch.optim.lr_scheduler import _LRScheduler, StepLR
6 | 
7 | 
8 | class Optimizers:
9 |     IMAGE = 'image'
10 |     TEXT = 'text'
11 |     TRANSFORMER = 'trans'
12 | 
13 |     @classmethod
14 |     def get_optmizers(cls, model, lr, lr_img=None, lr_step=None, lr_scheduler='linear', lr_decay_rate=0.5, beta1=0.9,
15 |                       beta2=0.999, train_steps=None, d_train=None, steps_per_epoch=None, warmup=None):
16 |         if lr_scheduler == cls.TRANSFORMER:
17 |             optimizer = Adam(model.parameters(), lr=0.0, betas=(0.9, 0.98), eps=1e-9)
18 |             optimizers = {cls.TEXT: optimizer}
19 |             schedulers = {}
20 |             batch_schedulers = {cls.TEXT: TransformerScheduler(optimizer, d_train, steps_per_epoch, warmup)}
21 |         else:
22 |             if lr_img is None or lr == lr_img:
23 |                 optimizers = {cls.TEXT: Adam(model.parameters(), lr=lr, betas=(beta1, beta2))}
24 |                 schedulers = {cls.TEXT: StepLR(optimizers[cls.TEXT], lr_step, lr_decay_rate)}
25 |             else:
26 |                 if lr_img is None:
27 |                     lr_img = lr
28 |                 text_params, img_params = [], []
29 |                 for name, param in model.named_parameters():
30 |                     if name.startswith('image_feats'):
31 |                         img_params.append(param)
32 |                     else:
33 |                         text_params.append(param)
34 |                 optimizers = {cls.TEXT: Adam(text_params, lr=lr, betas=(beta1, beta2)),
35 |                               cls.IMAGE: Adam(img_params, lr=lr_img, betas=(beta1, beta2))}
36 |                 schedulers = {cls.TEXT: StepLR(optimizers[cls.TEXT], lr_step, lr_decay_rate),
37 |                               cls.IMAGE: StepLR(optimizers[cls.IMAGE], lr_step, lr_decay_rate)}
38 |             batch_schedulers = None
39 |         return optimizers, schedulers, batch_schedulers
40 | 
41 | 
42 | class TransformerScheduler(_LRScheduler):
43 |     def __init__(self, optimizer, d_train, steps_per_epoch, warmup=4000, last_epoch=-1):
44 |         self.epoch_step = 0
45 |         self.d_train = d_train
46 |         self.steps_per_epoch = steps_per_epoch
47 |         self.warmup = warmup
48 |         super(TransformerScheduler, self).__init__(optimizer, last_epoch)
49 | 
50 |     def batch_step(self):
51 |         self.epoch_step += 1
52 |         for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
53 |             param_group['lr'] = lr
54 | 
55 |     def get_lr(self):
56 |         step = max(0, self.last_epoch * self.steps_per_epoch)  # max, not min: min(0, ...) reset the global step every epoch
57 |         step += self.epoch_step
58 |         step = max(1, step)
59 |         return [self.d_train ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5)) for _ in self.base_lrs]
60 | 
61 | 
62 | 
63 | 
--------------------------------------------------------------------------------
/clinicgen/text/__pycache__/sentsplit.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/text/__pycache__/sentsplit.cpython-36.pyc
--------------------------------------------------------------------------------
/clinicgen/text/__pycache__/sentsplit.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/text/__pycache__/sentsplit.cpython-37.pyc -------------------------------------------------------------------------------- /clinicgen/text/__pycache__/sentsplit.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/text/__pycache__/sentsplit.cpython-38.pyc -------------------------------------------------------------------------------- /clinicgen/text/__pycache__/textfilter.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/text/__pycache__/textfilter.cpython-36.pyc -------------------------------------------------------------------------------- /clinicgen/text/__pycache__/textfilter.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/text/__pycache__/textfilter.cpython-37.pyc -------------------------------------------------------------------------------- /clinicgen/text/__pycache__/textfilter.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/text/__pycache__/textfilter.cpython-38.pyc -------------------------------------------------------------------------------- /clinicgen/text/__pycache__/tokenfilter.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/text/__pycache__/tokenfilter.cpython-36.pyc -------------------------------------------------------------------------------- /clinicgen/text/__pycache__/tokenfilter.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/text/__pycache__/tokenfilter.cpython-37.pyc -------------------------------------------------------------------------------- /clinicgen/text/__pycache__/tokenfilter.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/text/__pycache__/tokenfilter.cpython-38.pyc -------------------------------------------------------------------------------- /clinicgen/text/__pycache__/tokenizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/text/__pycache__/tokenizer.cpython-36.pyc -------------------------------------------------------------------------------- /clinicgen/text/__pycache__/tokenizer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/text/__pycache__/tokenizer.cpython-37.pyc 
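Relating to clinicgen/optmizer.py above: a hypothetical sketch, not a repository file, of how the Noam-style TransformerScheduler is driven, with batch_step called once per batch and step once per epoch. The toy model, the batch counts, and the manual epoch_step reset are assumptions for illustration; the repository's actual training loop is not shown in this section.

# Hypothetical sketch, not part of the repository.
import torch
from clinicgen.optmizer import Optimizers

model = torch.nn.Linear(512, 512)  # stand-in for a captioning model
optimizers, schedulers, batch_schedulers = Optimizers.get_optmizers(
    model, lr=0.0, lr_scheduler='trans', d_train=512, steps_per_epoch=100, warmup=4000)
for epoch in range(2):
    for _ in range(100):  # assumed number of batches per epoch
        optimizers['text'].zero_grad()
        model(torch.randn(8, 512)).sum().backward()
        optimizers['text'].step()
        batch_schedulers['text'].batch_step()  # per-batch LR update (warmup, then decay)
    batch_schedulers['text'].epoch_step = 0    # assumed reset of the in-epoch counter
    batch_schedulers['text'].step()            # advance last_epoch at the epoch boundary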
-------------------------------------------------------------------------------- /clinicgen/text/__pycache__/tokenizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mudabek/encoding-cxr-report-gen/1ee49b484fb57bc6575a83b9a4c827b01932e1d3/clinicgen/text/__pycache__/tokenizer.cpython-38.pyc -------------------------------------------------------------------------------- /clinicgen/text/parser.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import random 5 | from nltk import Tree 6 | from nltk.compat import unicode_repr 7 | from six import string_types 8 | from stanza.server import CoreNLPClient 9 | 10 | 11 | class CoreNLPBinaryParser: 12 | DEFAULT_PORT = 9003 13 | 14 | def __init__(self, threads=1, port=None): 15 | sid = random.randint(0, 65535) 16 | if port is None: 17 | port = self.DEFAULT_PORT 18 | self.corenlp = CoreNLPClient(endpoint='http://localhost:{0}'.format(port), annotators=['parse'], 19 | output_format='json', properties={'ssplit.eolonly': 'true'}, timeout=300000, 20 | memory='8G', threads=threads, server_id='clinicgen{0}'.format(sid)) 21 | self.corenlp.start() 22 | self.run = True 23 | 24 | def __del__(self): 25 | self.stop() 26 | 27 | @classmethod 28 | def _format(cls, tree): 29 | childstrs = [] 30 | for child in tree: 31 | if isinstance(child, Tree): 32 | childstrs.append(cls._format(child)) 33 | elif isinstance(child, tuple): 34 | childstrs.append("/".join(child)) 35 | elif isinstance(child, string_types): 36 | childstrs.append('%s' % child) 37 | else: 38 | childstrs.append(unicode_repr(child)) 39 | if len(childstrs) > 1: 40 | return '( %s )' % ' '.join(childstrs) 41 | else: 42 | return childstrs[0] 43 | 44 | @classmethod 45 | def binarize(cls, tree): 46 | # collapse 47 | t = Tree.fromstring(tree) 48 | # chomsky normal form transformation 49 | Tree.collapse_unary(t, collapsePOS=True, collapseRoot=True) 50 | Tree.chomsky_normal_form(t) 51 | s = cls._format(t) 52 | return s 53 | 54 | def parse(self, text): 55 | ann = self.corenlp.annotate(text) 56 | return self.binarize(ann['sentences'][0]['parse']) 57 | 58 | def stop(self): 59 | if self.run: 60 | self.corenlp.stop() 61 | self.run = False 62 | -------------------------------------------------------------------------------- /clinicgen/text/sentsplit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import spacy 5 | from nltk.tokenize import sent_tokenize 6 | from stanza import Pipeline 7 | 8 | 9 | def get_sentsplitter(name, linebreak=True): 10 | if name == 'linebreak': 11 | return LineBreakSplitter() 12 | elif name == 'nltk': 13 | return NLTKSentenceSplitter(linebreak) 14 | elif name == 'none': 15 | return NullSentenceSplitter() 16 | elif name == 'scispacy': 17 | return SpaCySentenceSplitter('en_core_sci_md', linebreak) 18 | elif name == 'spacy': 19 | return SpaCySentenceSplitter('en_core_web_sm', linebreak) 20 | elif name == 'stanford': 21 | return StanzaSentenceSplitter(linebreak) 22 | else: 23 | return None 24 | 25 | 26 | class LineBreakSplitter: 27 | @staticmethod 28 | def split(text): 29 | return text.split('\n') 30 | 31 | 32 | class NLTKSentenceSplitter: 33 | def __init__(self, linebreak=True): 34 | self.linebreak = linebreak 35 | 36 | def split(self, text): 37 | sents = [] 38 | for sent in sent_tokenize(text): 39 | if self.linebreak: 40 | for 
sent2 in sent.split('\n'):
41 |                     sents.append(sent2)
42 |             else:
43 |                 sents.append(sent)
44 |         return sents
45 | 
46 | 
47 | class NullSentenceSplitter:
48 |     @staticmethod
49 |     def split(text):
50 |         return [text]
51 | 
52 | 
53 | class SpaCySentenceSplitter:
54 |     def __init__(self, model, linebreak=True):
55 |         self.nlp = spacy.load(model)
56 |         self.linebreak = linebreak
57 | 
58 |     def split(self, text):
59 |         sents = []
60 |         doc = self.nlp(text)
61 |         for sent in doc.sents:
62 |             if self.linebreak:
63 |                 for sent2 in sent.text.split('\n'):
64 |                     sents.append(sent2)
65 |             else:
66 |                 sents.append(sent.text)
67 |         return sents
68 | 
69 | 
70 | class StanzaSentenceSplitter:
71 |     def __init__(self, linebreak=True):
72 |         self.nlp = Pipeline(lang='en', processors='tokenize')
73 |         self.linebreak = linebreak
74 | 
75 |     def split(self, text):
76 |         sents = []
77 |         doc = self.nlp(text)
78 |         for sent in doc.sentences:
79 |             if self.linebreak:
80 |                 for sent2 in sent.text.split('\n'):
81 |                     sents.append(sent2)
82 |             else:
83 |                 sents.append(sent.text)
84 |         return sents
85 | 
--------------------------------------------------------------------------------
/clinicgen/text/textfilter.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | 
4 | 
5 | def get_textfilter(name):
6 |     if name == 'lower':
7 |         return LowerTextFilter()
8 |     elif name == 'none':
9 |         return NullTextFilter()
10 |     else:
11 |         return None
12 | 
13 | 
14 | class LowerTextFilter:
15 |     @staticmethod
16 |     def filter(text):
17 |         return text.lower()
18 | 
19 | 
20 | class NullTextFilter:
21 |     @staticmethod
22 |     def filter(text):
23 |         return text
24 | 
--------------------------------------------------------------------------------
/clinicgen/text/tokenfilter.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | 
4 | import re
5 | 
6 | 
7 | def get_tokenfilter(name):
8 |     if name == 'alphanum':
9 |         return AlphaNumFilter()
10 |     elif name == 'none':
11 |         return NoneTokenFilter()
12 |     else:
13 |         return None
14 | 
15 | 
16 | class AlphaNumFilter:
17 |     NUMBER_TOKEN = '__NUM__'
18 |     ALPHANUM_PATTERN = re.compile('^[a-zA-Z0-9\\\.]+$')
19 |     NUM_PATTERN = re.compile('^[0-9\\\.]+$')
20 | 
21 |     @classmethod
22 |     def filter(cls, tokens):
23 |         new_tokens = []
24 |         for token in tokens:
25 |             if token != '.':
26 |                 if cls.ALPHANUM_PATTERN.search(token) is not None:
27 |                     new_tokens.append(token)
28 |                 elif cls.NUM_PATTERN.search(token) is not None:
29 |                     new_tokens.append(cls.NUMBER_TOKEN)
30 |         return new_tokens
31 | 
32 | 
33 | class NoneTokenFilter:
34 |     @staticmethod
35 |     def filter(tokens):
36 |         return tokens
37 | 
--------------------------------------------------------------------------------
/clinicgen/text/tokenizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | 
4 | import spacy
5 | from nltk.tokenize import wordpunct_tokenize
6 | from stanza import Pipeline
7 | 
8 | 
9 | def get_tokenizer(name, port=None):
10 |     if name == 'nltk':
11 |         return NLTKTokenizer()
12 |     elif name == 'scispacy':
13 |         return SpaCyTokenizer('en_core_sci_md')
14 |     elif name == 'spacy':
15 |         return SpaCyTokenizer('en_core_web_sm')
16 |     elif name == 'stanford':
17 |         return StanzaTokenizer()
18 |     elif name == 'whitespace':
19 |         return WhiteSpaceTokenizer()
20 |     else:
21 |         return None
22 | 
23 | 
24 | class NLTKTokenizer:
25 |     @staticmethod
26 |     def tokenize(text):
27 | return wordpunct_tokenize(text) 28 | 29 | 30 | class SpaCyTokenizer: 31 | def __init__(self, model): 32 | self.nlp = spacy.load(model) 33 | 34 | def tokenize(self, text): 35 | toks = [] 36 | for tok in self.nlp(text): 37 | toks.append(tok.text) 38 | return toks 39 | 40 | 41 | class StanzaTokenizer: 42 | def __init__(self): 43 | self.nlp = Pipeline(lang='en', processors='tokenize') 44 | 45 | def tokenize(self, text): 46 | toks = [] 47 | doc = self.nlp(text) 48 | for sentence in doc.sentences: 49 | for token in sentence.tokens: 50 | toks.append(token.text) 51 | return toks 52 | 53 | 54 | class WhiteSpaceTokenizer: 55 | @staticmethod 56 | def tokenize(text): 57 | return text.split() 58 | -------------------------------------------------------------------------------- /clinicgen/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import numpy as np 5 | from torch import sigmoid 6 | from torch.nn import DataParallel 7 | from clinicgen.data.image2text import PretrainedEmbeddings 8 | 9 | 10 | def data_cuda(*tensors, device='gpu', non_blocking=False): 11 | if device == 'gpu': 12 | cuda_tensors = [] 13 | for tensor in tensors: 14 | cuda_tensors.append(tensor.cuda(device=0, non_blocking=non_blocking)) 15 | else: 16 | cuda_tensors = tensors 17 | 18 | if len(cuda_tensors) > 1: 19 | return cuda_tensors 20 | else: 21 | return cuda_tensors[0] 22 | 23 | 24 | class RecoverWords: 25 | def __init__(self, word_indexes): 26 | self.index_words = {} 27 | for word, index in word_indexes.items(): 28 | self.index_words[index] = word 29 | 30 | def __call__(self, *inputs, normalized=False): 31 | reports = [] 32 | 33 | if len(inputs) == 2: 34 | stops, samples = inputs 35 | if not normalized: 36 | stops = sigmoid(stops) 37 | stops = stops.detach().cpu().numpy() 38 | samples = samples.detach().cpu().numpy() 39 | masks = np.zeros((samples.shape[0], samples.shape[1], samples.shape[2]), dtype='float') 40 | 41 | for i in range(stops.shape[0]): 42 | stop_sent = False 43 | sentences = [] 44 | for j in range(stops.shape[1]): 45 | if not stop_sent: 46 | if j > 0 and stops[i][j] >= 0.5: 47 | stop_sent = True 48 | else: 49 | stop_word = False 50 | words = [] 51 | for k in range(samples.shape[2]): 52 | if not stop_word: 53 | index = samples[i][j][k] 54 | masks[i][j][k] = 1.0 55 | if index == PretrainedEmbeddings.INDEX_EOS: 56 | stop_word = True 57 | elif index != PretrainedEmbeddings.INDEX_PAD: 58 | words.append(self.index_words[index]) 59 | sentences.append(' '.join(words)) 60 | reports.append('\n'.join(sentences)) 61 | else: 62 | samples = inputs[0] 63 | samples = samples.detach().cpu().numpy() 64 | masks = np.zeros((samples.shape[0], samples.shape[1]), dtype='float') 65 | 66 | for i in range(samples.shape[0]): 67 | stop_word = False 68 | words = [] 69 | for k in range(samples.shape[1]): 70 | if not stop_word: 71 | index = samples[i][k] 72 | masks[i][k] = 1.0 73 | if index == PretrainedEmbeddings.INDEX_EOS: 74 | stop_word = True 75 | elif index != PretrainedEmbeddings.INDEX_PAD: 76 | words.append(self.index_words[index]) 77 | reports.append(' '.join(words)) 78 | return reports, masks 79 | 80 | def array(self, idxs): 81 | return list(map(lambda idx: self.index_words[idx], idxs)) 82 | 83 | 84 | class DataParallelSwitch(DataParallel): 85 | def __init__(self, module, device_ids=None, output_device=None, dim=0): 86 | super(DataParallelSwitch, self).__init__(module, device_ids, output_device, dim) 87 | self.parallel = False 88 | 
89 |     def forward(self, *inputs, **kwargs):
90 |         if not self.parallel:
91 |             return self.module(*inputs, **kwargs)
92 |         return super(DataParallelSwitch, self).forward(*inputs, **kwargs)  # unpack; passing the tuple/dict positionally breaks DataParallel
93 | 
--------------------------------------------------------------------------------
/convert_generated.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | 
4 | import argparse
5 | import csv
6 | import gzip
7 | import os
8 | 
9 | 
10 | def main(args):
11 |     with open(args.output, 'w', encoding='utf-8') as out:
12 |         writer = csv.writer(out)
13 |         writer.writerow(['DOC_ID', 'Report Generated'])
14 |         with gzip.open(args.gen, 'rt', encoding='utf-8') as f:
15 |             for line in f:
16 |                 entry = line.rstrip().split(' ')
17 |                 did = entry[0].split('__')[0]
18 |                 text = rewrite(' '.join(entry[2:]))
19 |                 writer.writerow([did, text])
20 | 
21 | 
22 | def parse_args():
23 |     parser = argparse.ArgumentParser()
24 |     parser.add_argument('gen', type=str)
25 |     parser.add_argument('output', type=str)
26 |     return parser.parse_args()
27 | 
28 | 
29 | def rewrite(text):
30 |     text = text.replace(" ' ", "'")
31 |     text = text.replace(" n't", "n't")
32 |     text = text.replace(' - ', '-')
33 |     text = text.replace(' .', '.')
34 |     text = text.replace(' ,', ',')
35 |     return text
36 | 
37 | 
38 | if __name__ == '__main__':
39 |     args = parse_args()
40 |     main(args)
41 | 
--------------------------------------------------------------------------------
/create_sections.py:
--------------------------------------------------------------------------------
1 | # This script extracts the conclusion section from MIMIC-CXR reports
2 | # It outputs them into individual files with at most 10,000 reports.
3 | import sys
4 | import os
5 | import argparse
6 | import csv
7 | from pathlib import Path
8 | 
9 | from tqdm import tqdm
10 | 
11 | # local folder import
12 | import section_parser as sp
13 | 
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--reports_path',
16 |                     required=True,
17 |                     help=('Path to file with radiology reports,'
18 |                           ' e.g.
/data/mimic-cxr/files')) 19 | parser.add_argument('--output_path', 20 | required=True, 21 | help='Path to output CSV files.') 22 | parser.add_argument('--no_split', action='store_true', 23 | help='Do not output batched CSV files.') 24 | 25 | 26 | def list_rindex(l, s): 27 | """Helper function: *last* matching element in a list""" 28 | return len(l) - l[-1::-1].index(s) - 1 29 | 30 | 31 | def main(args): 32 | args = parser.parse_args(args) 33 | 34 | reports_path = Path(args.reports_path) 35 | output_path = Path(args.output_path) 36 | 37 | if not output_path.exists(): 38 | output_path.mkdir() 39 | 40 | # not all reports can be automatically sectioned 41 | # we load in some dictionaries which have manually determined sections 42 | custom_section_names, custom_indices = sp.custom_mimic_cxr_rules() 43 | 44 | # get all higher up folders (p00, p01, etc) 45 | p_grp_folders = os.listdir(reports_path) 46 | p_grp_folders = [p for p in p_grp_folders 47 | if p.startswith('p') and len(p) == 3] 48 | p_grp_folders.sort() 49 | 50 | # patient_studies will hold the text for use in NLP labeling 51 | patient_studies = [] 52 | 53 | # study_sections will have an element for each study 54 | # this element will be a list, each element having text for a specific section 55 | study_sections = [] 56 | for p_grp in p_grp_folders: 57 | # get patient folders, usually around ~6k per group folder 58 | cxr_path = reports_path / p_grp 59 | p_folders = os.listdir(cxr_path) 60 | p_folders = [p for p in p_folders if p.startswith('p')] 61 | p_folders.sort() 62 | 63 | # For each patient in this grouping folder 64 | print(p_grp) 65 | for p in tqdm(p_folders): 66 | patient_path = cxr_path / p 67 | 68 | # get the filename for all their free-text reports 69 | studies = os.listdir(patient_path) 70 | studies = [s for s in studies 71 | if s.endswith('.txt') and s.startswith('s')] 72 | 73 | for s in studies: 74 | # load in the free-text report 75 | with open(patient_path / s, 'r') as fp: 76 | text = ''.join(fp.readlines()) 77 | 78 | # get study string name without the txt extension 79 | s_stem = s[0:-4] 80 | 81 | # custom rules for some poorly formatted reports 82 | if s_stem in custom_indices: 83 | idx = custom_indices[s_stem] 84 | patient_studies.append([s_stem, text[idx[0]:idx[1]]]) 85 | continue 86 | 87 | # split text into sections 88 | sections, section_names, section_idx = sp.section_text( 89 | text 90 | ) 91 | 92 | # check to see if this has mis-named sections 93 | # e.g. sometimes the impression is in the comparison section 94 | if s_stem in custom_section_names: 95 | sn = custom_section_names[s_stem] 96 | idx = list_rindex(section_names, sn) 97 | patient_studies.append([s_stem, sections[idx].strip()]) 98 | continue 99 | 100 | # grab the *last* section with the given title 101 | # prioritizes impression > findings, etc. 
102 | 103 | # "last_paragraph" is text up to the end of the report 104 | # many reports are simple, and have a single section 105 | # header followed by a few paragraphs 106 | # these paragraphs are grouped into section "last_paragraph" 107 | 108 | # note also comparison seems unusual but if no other sections 109 | # exist the radiologist has usually written the report 110 | # in the comparison section 111 | idx = -1 112 | for sn in ('impression', 'findings', 113 | 'last_paragraph', 'comparison'): 114 | if sn in section_names: 115 | idx = list_rindex(section_names, sn) 116 | break 117 | 118 | if idx == -1: 119 | # we didn't find any sections we can use :( 120 | patient_studies.append([s_stem, '']) 121 | print(f'no impression/findings: {patient_path / s}') 122 | else: 123 | # store the text of the conclusion section 124 | patient_studies.append([s_stem, sections[idx].strip()]) 125 | 126 | study_sectioned = [s_stem] 127 | for sn in ('impression', 'findings', 128 | 'last_paragraph', 'comparison'): 129 | if sn in section_names: 130 | idx = list_rindex(section_names, sn) 131 | study_sectioned.append(sections[idx].strip()) 132 | else: 133 | study_sectioned.append(None) 134 | study_sections.append(study_sectioned) 135 | # write distinct files to facilitate modular processing 136 | if len(patient_studies) > 0: 137 | # write out a single CSV with the sections 138 | with open(output_path / 'mimic_cxr_sectioned.csv', 'w') as fp: 139 | csvwriter = csv.writer(fp) 140 | # write header 141 | csvwriter.writerow(['study', 'impression', 'findings', 142 | 'last_paragraph', 'comparison']) 143 | for row in study_sections: 144 | csvwriter.writerow(row) 145 | 146 | if args.no_split: 147 | # write all the reports out to a single file 148 | with open(output_path / f'mimic_cxr_sections.csv', 'w') as fp: 149 | csvwriter = csv.writer(fp) 150 | for row in patient_studies: 151 | csvwriter.writerow(row) 152 | else: 153 | # write ~22 files with ~10k reports each 154 | n = 0 155 | jmp = 10000 156 | 157 | while n < len(patient_studies): 158 | n_fn = n // jmp 159 | with open(output_path / f'mimic_cxr_{n_fn:02d}.csv', 'w') as fp: 160 | csvwriter = csv.writer(fp) 161 | for row in patient_studies[n:n+jmp]: 162 | csvwriter.writerow(row) 163 | n += jmp 164 | 165 | 166 | if __name__ == '__main__': 167 | main(sys.argv[1:]) -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: py37-ifcc 2 | channels: 3 | - stanfordnlp 4 | - pytorch 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - absl-py=0.9.0=py37_0 9 | - blas=1.0=mkl 10 | - blinker=1.4=py37_0 11 | - brotlipy=0.7.0=py37h7b6447c_1000 12 | - c-ares=1.15.0=h7b6447c_1001 13 | - ca-certificates=2020.7.22=0 14 | - cachetools=4.1.0=py_1 15 | - certifi=2020.6.20=py37_0 16 | - cffi=1.14.0=py37he30daa8_1 17 | - chardet=3.0.4=py37_1003 18 | - click=7.1.2=py_0 19 | - cryptography=2.9.2=py37h1ba5d50_0 20 | - cudatoolkit=10.1.243=h6bb024c_0 21 | - freetype=2.10.2=h5ab3b9f_0 22 | - google-auth=1.14.1=py_0 23 | - google-auth-oauthlib=0.4.1=py_2 24 | - grpcio=1.27.2=py37hf8bcb03_0 25 | - idna=2.9=py_1 26 | - intel-openmp=2020.1=217 27 | - joblib=0.15.1=py_0 28 | - jpeg=9b=h024ee3a_2 29 | - ld_impl_linux-64=2.33.1=h53a641e_7 30 | - libedit=3.1.20191231=h7b6447c_0 31 | - libffi=3.3=he6710b0_1 32 | - libgcc-ng=9.1.0=hdf63c60_0 33 | - libgfortran-ng=7.3.0=hdf63c60_0 34 | - libpng=1.6.37=hbc83047_0 35 | - libprotobuf=3.12.3=hd408876_0 36 | - 
libstdcxx-ng=9.1.0=hdf63c60_0 37 | - libtiff=4.1.0=h2733197_1 38 | - lz4-c=1.9.2=he6710b0_0 39 | - markdown=3.1.1=py37_0 40 | - mkl=2020.1=217 41 | - mkl-service=2.3.0=py37he904b0f_0 42 | - mkl_fft=1.1.0=py37h23d657b_0 43 | - mkl_random=1.1.1=py37h0573a6f_0 44 | - ncurses=6.2=he6710b0_1 45 | - ninja=1.9.0=py37hfd86e86_0 46 | - nltk=3.4.5=py37_0 47 | - numpy=1.18.5=py37ha1c710e_0 48 | - numpy-base=1.18.5=py37hde5b4d6_0 49 | - oauthlib=3.1.0=py_0 50 | - olefile=0.46=py37_0 51 | - openssl=1.1.1h=h7b6447c_0 52 | - pandas=1.0.1=py37h0573a6f_0 53 | - pillow=7.1.2=py37hb39fc2d_0 54 | - pip=20.1.1=py37_1 55 | - pyasn1=0.4.8=py_0 56 | - pyasn1-modules=0.2.7=py_0 57 | - pycparser=2.20=py_0 58 | - pyjwt=1.7.1=py37_0 59 | - pyopenssl=19.1.0=py37_0 60 | - pysocks=1.7.1=py37_0 61 | - python=3.7.7=hcff3b4d_5 62 | - python-dateutil=2.8.1=py_0 63 | - pytorch=1.5.0=py3.7_cuda10.1.243_cudnn7.6.3_0 64 | - pytz=2020.1=py_0 65 | - readline=8.0=h7b6447c_0 66 | - requests=2.24.0=py_0 67 | - requests-oauthlib=1.3.0=py_0 68 | - scikit-learn=0.22.1=py37hd81dba3_0 69 | - scipy=1.5.0=py37h0b6359f_0 70 | - setuptools=47.3.1=py37_0 71 | - six=1.15.0=py_0 72 | - sqlite=3.32.3=h62c20be_0 73 | - stanza=1.1.1=py37_0 74 | - tensorboard=2.2.1=pyh532a8cf_0 75 | - tensorboard-plugin-wit=1.6.0=py_0 76 | - tk=8.6.10=hbc83047_0 77 | - torchvision=0.6.0=py37_cu101 78 | - urllib3=1.25.9=py_0 79 | - werkzeug=1.0.1=py_0 80 | - wheel=0.34.2=py37_0 81 | - xz=5.2.5=h7b6447c_0 82 | - zlib=1.2.11=h7b6447c_3 83 | - zstd=1.4.4=h0b5b093_3 84 | - pip: 85 | - awscli==1.18.88 86 | - bert-score==0.3.0 87 | - blis==0.2.4 88 | - botocore==1.17.11 89 | - colorama==0.4.3 90 | - conllu==3.0 91 | - cycler==0.10.0 92 | - cymem==2.0.3 93 | - docutils==0.15.2 94 | - filelock==3.0.12 95 | - future==0.18.2 96 | - jmespath==0.10.0 97 | - jsonschema==2.6.0 98 | - kiwisolver==1.2.0 99 | - matplotlib==3.2.2 100 | - murmurhash==1.0.2 101 | - plac==0.9.6 102 | - preshed==2.0.1 103 | - protobuf==3.12.2 104 | - pyparsing==2.4.7 105 | - pyyaml==5.3.1 106 | - regex==2020.6.8 107 | - rouge==0.3.2 108 | - rsa==3.4.2 109 | - s3transfer==0.3.3 110 | - sacremoses==0.0.43 111 | - scispacy==0.2.0 112 | - sentencepiece==0.1.91 113 | - spacy==2.1.3 114 | - srsly==1.0.2 115 | - stanfordnlp==0.2.0 116 | - thinc==7.0.8 117 | - tokenizers==0.7.0 118 | - tqdm==4.46.1 119 | - transformers==2.9.0 120 | - wasabi==0.7.0 121 | prefix: //anaconda3/envs/py37-ifcc 122 | -------------------------------------------------------------------------------- /eval_prf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | import csv 6 | import os 7 | import torch 8 | import torch.nn as nn 9 | import utils 10 | from collections import OrderedDict 11 | from sklearn.metrics import accuracy_score, precision_recall_fscore_support 12 | from transformers import BertTokenizer 13 | from constants import * 14 | from models.bert_labeler import bert_labeler 15 | 16 | 17 | def label(checkpoint_path, texts): 18 | model = bert_labeler() 19 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 20 | if torch.cuda.device_count() > 0: #works even if only 1 GPU available 21 | print("Using", torch.cuda.device_count(), "GPUs!") 22 | model = nn.DataParallel(model) #to utilize multiple GPU's 23 | model = model.to(device) 24 | checkpoint = torch.load(checkpoint_path) 25 | model.load_state_dict(checkpoint['model_state_dict']) 26 | else: 27 | checkpoint = torch.load(checkpoint_path, 
map_location=torch.device('cpu')) 28 | new_state_dict = OrderedDict() 29 | for k, v in checkpoint['model_state_dict'].items(): 30 | name = k[7:] # remove `module.` 31 | new_state_dict[name] = v 32 | model.load_state_dict(new_state_dict) 33 | 34 | was_training = model.training 35 | model.eval() 36 | y_pred = [[] for _ in range(len(CONDITIONS))] 37 | 38 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 39 | encoded_imp = tokenize(texts, tokenizer) 40 | 41 | with torch.no_grad(): 42 | batches, ls, buffer = [], [], [] 43 | for data in encoded_imp: 44 | buffer.append(data) 45 | if len(buffer) >= BATCH_SIZE: 46 | batch, bl = make_batch(buffer) 47 | batches.append(batch) 48 | ls.append(bl) 49 | buffer = [] 50 | if len(buffer) > 0: 51 | batch, bl = make_batch(buffer) 52 | batches.append(batch) 53 | ls.append(bl) 54 | 55 | for batch, bl in zip(batches, ls): 56 | batch = batch.to(device) 57 | src_len = bl 58 | batch_size = batch.shape[0] 59 | attn_mask = utils.generate_attention_masks(batch, src_len, device) 60 | 61 | out = model(batch, attn_mask) 62 | 63 | for j in range(len(out)): 64 | curr_y_pred = out[j].argmax(dim=1) #shape is (batch_size) 65 | y_pred[j].append(curr_y_pred) 66 | 67 | for j in range(len(y_pred)): 68 | y_pred[j] = torch.cat(y_pred[j], dim=0) 69 | 70 | if was_training: 71 | model.train() 72 | 73 | y_pred = [t.tolist() for t in y_pred] 74 | return y_pred 75 | 76 | 77 | def main(args): 78 | if args.uncertain: 79 | avg, pos = 'micro', [1, 3] 80 | else: 81 | avg, pos = 'binary', 1 82 | 83 | ids = {} 84 | path = os.path.join(args.ref, 'reports.csv') 85 | with open(path, encoding='utf-8') as f: 86 | f.readline() 87 | reader = csv.reader(f) 88 | idx = 1 89 | for row in reader: 90 | ids[idx] = row[0] 91 | idx += 1 92 | refs = {} 93 | path = os.path.join(args.ref, 'labeled_reports.csv') 94 | with open(path, encoding='utf-8') as f: 95 | f.readline() 96 | reader = csv.reader(f) 97 | idx = 1 98 | for row in reader: 99 | labels = {} 100 | for lidx, l in enumerate(row[1:]): 101 | labels[CONDITIONS[lidx]] = l 102 | refs[ids[idx]] = labels 103 | idx += 1 104 | print('{0} references'.format(len(refs))) 105 | 106 | gids, texts = [], [] 107 | with open(args.gen, 'rt', encoding='utf-8') as f: 108 | f.readline() 109 | reader = csv.reader(f) 110 | for row in reader: 111 | gids.append(row[0]) 112 | texts.append(row[1]) 113 | print('{0} generated'.format(len(texts))) 114 | 115 | with open('chexbert.pth', 'rb') as f: 116 | rs = label(f, texts) 117 | 118 | result = {} 119 | for lidx, labels in enumerate(rs): 120 | for idx, l in enumerate(labels): 121 | did = gids[idx] 122 | if did not in result: 123 | result[did] = {} 124 | if CONDITIONS[lidx] not in result[did]: 125 | result[did][CONDITIONS[lidx]] = l 126 | 127 | with open(args.out, 'w', encoding='utf-8') as out: 128 | out.write('DOC_ID,{0}\n'.format(','.join(CONDITIONS))) 129 | for did in gids: 130 | l = [] 131 | for c in CONDITIONS: 132 | v = result[did][c] 133 | l.append(str(v)) 134 | out.write('{0},{1}\n'.format(did, ','.join(l) )) 135 | 136 | trues, preds = {}, {} 137 | for did, gen_labels in result.items(): 138 | for gen_label, v in gen_labels.items(): 139 | if gen_label not in preds: 140 | preds[gen_label] = [] 141 | if v == 3 and not args.uncertain: 142 | v = 1 143 | elif v == 2: 144 | v = 0 145 | preds[gen_label].append(v) 146 | for true_label, v in refs[did].items(): 147 | if true_label not in trues: 148 | trues[true_label] = [] 149 | if v == '-1.0': 150 | v = 1 if not args.uncertain else 3 151 | elif v == '1.0': 152 | v = 1 153 | 
153 |             elif v == '0.0' or v == '':
154 |                 v = 0
155 |             trues[true_label].append(v)
156 | 
157 |     prs, rcs, fb1s = [], [], []
158 |     for c in CONDITIONS:
159 |         if c in trues and c in preds:
160 |             acc = accuracy_score(trues[c], preds[c])
161 |             pr, rc, fb1, _ = precision_recall_fscore_support(trues[c], preds[c], labels=pos, average=avg)
162 |             prs.append(pr)
163 |             rcs.append(rc)
164 |             fb1s.append(fb1)
165 |             print('{0} {1} {2} {3} {4}'.format(c, acc, pr, rc, fb1))
166 | 
167 |     trues5, preds5 = [], []
168 |     for c in ['Cardiomegaly', 'Edema', 'Consolidation', 'Atelectasis', 'Pleural Effusion']:
169 |         trues5 += trues[c]
170 |         preds5 += preds[c]
171 |     pr, rc, fb1, _ = precision_recall_fscore_support(trues5, preds5, labels=pos, average=avg)
172 |     print('5-micro {0} {1} {2}'.format(pr, rc, fb1))
173 |     print('5-acc {0}'.format(accuracy_score(trues5, preds5)))
174 | 
175 | 
176 | def make_batch(buffer):
177 |     max_len = 0
178 |     for data in buffer:
179 |         if data.shape[0] > max_len:
180 |             max_len = data.shape[0]
181 |     batch = torch.zeros((len(buffer), max_len), dtype=torch.long)
182 |     bl = []
183 |     for i, data in enumerate(buffer):
184 |         batch[i][:data.shape[0]] = data
185 |         bl.append(data.shape[0])
186 |     return batch, bl
187 | 
188 | 
189 | def parse_args():
190 |     parser = argparse.ArgumentParser()
191 |     parser.add_argument('ref', type=str, help='A path to reference reports')
192 |     parser.add_argument('gen', type=str, help='A path to generated reports')
193 |     parser.add_argument('out', type=str, help='An output path for CheXbert outputs')
194 |     parser.add_argument('--uncertain', default=False, action='store_true', help='Treat uncertain as an independent class')
195 |     return parser.parse_args()
196 | 
197 | 
198 | def tokenize(impressions, tokenizer):
199 |     new_impressions = []
200 |     for impression in impressions:
201 |         tokenized_imp = tokenizer.tokenize(impression)
202 |         if tokenized_imp:  # not an empty report
203 |             res = tokenizer.encode_plus(tokenized_imp)['input_ids']
204 |             if len(res) > 512:  # length exceeds BERT's maximum input size
205 |                 print('Report length exceeds 512 tokens; truncating')
206 |                 res = res[:511] + [tokenizer.sep_token_id]
207 |             new_impressions.append(torch.LongTensor(res))
208 |         else:  # an empty report
209 |             new_impressions.append(torch.LongTensor([tokenizer.cls_token_id, tokenizer.sep_token_id]))
210 |     return new_impressions
211 | 
212 | 
213 | if __name__ == '__main__':
214 |     args = parse_args()
215 |     main(args)
216 | 
--------------------------------------------------------------------------------
/extract_reports.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | 
4 | import argparse
5 | import csv
6 | import gzip
7 | import os
8 | import re
9 | 
10 | 
11 | def main(args):
12 |     space_pattern = re.compile('\\s+')
13 | 
14 |     doc_ids = {}
15 |     with gzip.open(args.split, 'rt', encoding='utf-8') as f:
16 |         f.readline()
17 |         reader = csv.reader(f)
18 |         for row in reader:
19 |             if row[3] == 'test':
20 |                 sid = row[1]
21 |                 doc_ids[sid] = True
22 |     print('{0} test reports'.format(len(doc_ids)))
23 | 
24 |     with gzip.open(args.sections, 'rt', encoding='utf-8') as f:
25 |         header = f.readline().strip().split(',')
26 |         sections = {}
27 |         reader = csv.reader(f)
28 |         for row in reader:
29 |             report = {}
30 |             for i, sec in enumerate(header):
31 |                 report[sec] = row[i]
32 |             sections[row[0]] = report
33 | 
34 |     if not os.path.exists(args.output):
35 |         os.mkdir(args.output)
36 |     with open(os.path.join(args.output, 'reports.csv'), 'w', encoding='utf-8') as out:
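        # Each test study becomes one CSV row of (DOC_ID, Report Impression);
        # the impression section is preferred and the findings section is used
        # as a fallback when the impression is empty (see the loop below).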
37 |         writer = csv.writer(out)
38 |         writer.writerow(['DOC_ID', 'Report Impression'])
39 |         for sid in doc_ids:
40 |             text = sections['s' + sid]['impression']
41 |             if len(text) == 0:
42 |                 text = sections['s' + sid]['findings']
43 |             text = text.replace('\n', ' ')
44 |             text = space_pattern.sub(' ', text)
45 |             if len(text) > 0:
46 |                 writer.writerow([sid, text])
47 | 
48 | 
49 | def parse_args():
50 |     parser = argparse.ArgumentParser(description='Extract test-split reports from the MIMIC-CXR sectioned file')
51 |     parser.add_argument('sections', type=str, help='A path to the MIMIC-CXR sectioned file')
52 |     parser.add_argument('split', type=str, help='A path to the MIMIC-CXR split file')
53 |     parser.add_argument('output', type=str, help='An output path')
54 |     return parser.parse_args()
55 | 
56 | 
57 | if __name__ == '__main__':
58 |     args = parse_args()
59 |     main(args)
60 | 
--------------------------------------------------------------------------------
/libs.yml:
--------------------------------------------------------------------------------
1 | channels:
2 | - anaconda
3 | - conda-forge
4 | - defaults
5 | dependencies:
6 | - blas=1.0=mkl
7 | - ca-certificates=2020.1.1=0
8 | - certifi=2019.11.28=py36_0
9 | - decorator=4.3.0=py_0
10 | - docopt=0.6.2=py_1
11 | - gcc_linux-64=7.3.0=h553295d_9
12 | - gxx_linux-64=7.3.0=h553295d_9
13 | - intel-openmp=2019.1=144
14 | - jpype1=0.6.3=py36h9de70de_1001
15 | - libblas=3.8.0=14_mkl
16 | - libcblas=3.8.0=14_mkl
17 | - libedit=3.1.20170329=hf8c457e_1001
18 | - libffi=3.2.1=he1b5a44_1006
19 | - libgcc-ng=8.2.0=hdf63c60_1
20 | - libgfortran-ng=7.3.0=hdf63c60_5
21 | - liblapack=3.8.0=14_mkl
22 | - libstdcxx-ng=8.2.0=hdf63c60_1
23 | - mkl=2019.1=144
24 | - mkl_fft=1.0.6=py36_0
25 | - mkl_random=1.0.1=py36_0
26 | - ncurses=6.1=hf484d3e_1002
27 | - networkx=1.11=py36_0
28 | - nltk=3.3.0=py36_0
29 | - numpy=1.15.4=py36h8b7e671_1002
30 | - numpy-base=1.15.4=py36hde5b4d6_0
31 | - openssl=1.1.1f=h7b6447c_0
32 | - pandas=0.23.4=py36h637b7d7_1000
33 | - pathlib2=2.3.5=py36_0
34 | - pip=18.1=py36_1000
35 | - ply=3.11=py_1
36 | - python=3.6.7=h0371630_0
37 | - python-dateutil=2.7.5=py_0
38 | - pytz=2018.7=py_0
39 | - readline=7.0=hf8c457e_1001
40 | - setuptools=40.6.2=py36_0
41 | - six=1.11.0=py36_1001
42 | - sqlite=3.25.3=h67949de_1000
43 | - tk=8.6.8=h84994c4_1000
44 | - tqdm=4.28.1=py_0
45 | - wheel=0.32.3=py36_0
46 | - xz=5.2.4=h14c3975_1001
47 | - zlib=1.2.11=h516909a_1006
48 | - pip:
49 |   - bioc==1.1.dev3
50 |   - bllipparser==2016.9.11
51 |   - deprecation==2.0.6
52 |   - docutils==0.13.1
53 |   - lxml==3.7.3
54 |   - packaging==18.0
55 |   - pyparsing==2.3.0
56 |   - pystanforddependencies==0.3.1
--------------------------------------------------------------------------------
/make_radnli-pseudo-train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | 
4 | import argparse
5 | import csv
6 | import gzip
7 | import json
8 | import os
9 | import re
10 | from collections import OrderedDict
11 | from nltk.tokenize import sent_tokenize
12 | 
13 | 
14 | def main(args):
15 |     reports = {}
16 |     with gzip.open(args.sections, 'rt', encoding='utf-8') as f:
17 |         f.readline()
18 |         reader = csv.reader(f)
19 |         for row in reader:
20 |             findings = row[2]
21 |             findings = findings.replace('\n', ' ')
22 |             findings = re.sub('\\s+', ' ', findings)
23 |             reports[row[0]] = findings
24 |     with open('radnli_pseudo-train.jsonl', 'w', encoding='utf-8') as out:
25 |         with open(os.path.join('resources', 'radnli_pseudo-train_indexes.jsonl'), encoding='utf-8') as f:
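            # Each JSONL entry references MIMIC-CXR text by character offsets
            # rather than by raw sentences, e.g. (illustrative values only):
            #   {"sentence1": "s50000000,10:35", "sentence2": "s50000000,40:72", ...}
            # i.e. '<study id>,<start>:<end>' into the findings loaded above;
            # the loop below resolves the offsets back to sentence text.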
26 |             for line in f:
27 |                 entry = json.loads(line)
28 |                 sent1 = entry['sentence1'].split(',')
29 |                 idxs = sent1[1].split(':')
30 |                 entry['sentence1'] = reports[sent1[0]][int(idxs[0]):int(idxs[1])]
31 |                 sent2 = entry['sentence2'].split(',')
32 |                 idxs = sent2[1].split(':')
33 |                 entry['sentence2'] = reports[sent2[0]][int(idxs[0]):int(idxs[1])]
34 |                 out.write(json.dumps(entry) + '\n')
35 |     print('Wrote: radnli_pseudo-train.jsonl')
36 | 
37 | def parse_args():
38 |     parser = argparse.ArgumentParser(description='Make RadNLI pseudo training data from MIMIC-CXR')
39 |     parser.add_argument('sections', type=str, help='A path to the MIMIC-CXR sectioned file')
40 |     return parser.parse_args()
41 | 
42 | if __name__ == '__main__':
43 |     args = parse_args()
44 |     main(args)
45 | 
--------------------------------------------------------------------------------
/metric_analysis.py:
--------------------------------------------------------------------------------
1 | from transformers import pipeline
2 | from clinicgen.nli import BERTScorer
3 | import pandas as pd
4 | import numpy as np
5 | import re
6 | from tqdm import tqdm
7 | import torch
8 | 
9 | # Set visible gpu
10 | # import os
11 | # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
12 | # os.environ["CUDA_VISIBLE_DEVICES"] = "2"
13 | 
14 | # Questions for QA model
15 | QUESTIONS = [
16 |     'Is there pneumonia?',
17 |     'Is there edema?',
18 |     'Is there pneumothorax?',
19 |     'Are there devices?',
20 |     'Is there opacity?',
21 |     'Is there atelectasis?',
22 |     'Is there cardiomegaly?',
23 |     'Is there lung lesion?',
24 |     'Is there consolidation?',
25 |     'Is there fracture?',
26 | ]
27 | 
28 | # Model for BERTScore
29 | bert_model = 'distilbert-base-uncased'
30 | 
31 | # Read original reports
32 | reports_df = pd.read_csv('/home/otabek.nazarov/Downloads/thesis/ifcc/labeled_reports_test.csv')
33 | 
34 | # Batch size configuration for model
35 | samples_cnt = 1500  # 3854 for the full set
36 | batch_size = 50  # use 47 or 41 for the full 3854 samples (both divide it evenly)
37 | batch_count = int(samples_cnt / batch_size)
38 | 
39 | # Load QA model
40 | device_id = 2  # -1 for cpu
41 | qa_model = pipeline("question-answering",
42 |                     model='franklu/pubmed_bert_squadv2',
43 |                     framework='pt',
44 |                     device=device_id)
45 | qa_model.model.to(torch.device('cuda:2'))
46 | 
47 | QA_THRESHOLD = 0.25
48 | 
49 | # Load BERTScore model
50 | bert_score_qa_model = BERTScorer(model_type=bert_model, batch_size=batch_size,
51 |                                  nthreads=2, lang='en', rescale_with_baseline=True,
52 |                                  penalty=False)
53 | # bert_score_qa_model.cuda()  # .model.to(torch.device('cuda:2'))
54 | 
55 | # Dictionary for final dataframe
56 | data_dict = {
57 |     'mask_prob': [],
58 |     'f1_full': [],
59 |     'f1_qa': [],
60 |     'prec_full': [],
61 |     'prec_qa': [],
62 |     'recall_full': [],
63 |     'recall_qa': [],
64 | }
65 | 
66 | bert_scores_detailed = []
67 | 
68 | mask_reports_dict = {}
69 | 
70 | for percent in tqdm(range(0, 100, 4)):
71 |     # Mask out all the reports
72 |     orig_reports = reports_df['Report Impression'].values[:samples_cnt]
73 |     mask_reports = []
74 |     masking_prob = percent / 100
75 |     data_dict['mask_prob'].append(masking_prob)
76 | 
77 |     for cur_report in orig_reports:
78 |         # Split report into list of words
79 |         words = cur_report.split()
80 |         words_array = np.array(words)
81 |         length = len(words_array)
82 | 
83 |         # Mask out the words with masking_prob
84 |         mask = np.random.choice([0, 1], size=length, replace=True, p=[1-masking_prob, masking_prob]).astype(bool)
85 |         mask_vals = np.array([''] * length)
86 |         words_array[mask] = mask_vals[mask]
87 | 
88 |         # Append masked report to the list
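        # Illustrative example: with masking_prob=0.5 a report like
        #   "no acute cardiopulmonary process"
        # may become "no  cardiopulmonary " after masking; the re.sub below
        # collapses the leftover runs of spaces.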
masked_report = ' '.join(words_array.tolist()) 90 | masked_report = re.sub(' +', ' ', masked_report) 91 | mask_reports.append(masked_report) 92 | 93 | # Save masked reports for dataframe 94 | mask_reports_dict[f'masked_{percent}'] = mask_reports 95 | 96 | # Turn into batches for fast processing 97 | orig_reports = np.reshape(orig_reports, (batch_count, batch_size)) 98 | mask_reports = np.reshape(np.array(mask_reports), (batch_count, batch_size)) 99 | 100 | 101 | f1_score_means = [] 102 | f1_score_means_orig = [] 103 | prec_means = [] 104 | prec_means_orig = [] 105 | recall_means = [] 106 | recall_means_orig = [] 107 | bert_scores = [] 108 | for idx in range(batch_count): 109 | 110 | refs_l = orig_reports[idx,:].tolist() 111 | hypos_l = mask_reports[idx,:].tolist() 112 | 113 | f1_scores = np.empty((len(refs_l), len(QUESTIONS))) 114 | f1_scores.fill(np.nan) 115 | 116 | f1_scores_orig = np.empty((len(refs_l), len(QUESTIONS))) 117 | f1_scores_orig.fill(np.nan) 118 | 119 | prec_scores = np.empty((len(refs_l), len(QUESTIONS))) 120 | prec_scores.fill(np.nan) 121 | 122 | prec_scores_orig = np.empty((len(refs_l), len(QUESTIONS))) 123 | prec_scores_orig.fill(np.nan) 124 | 125 | recall_scores = np.empty((len(refs_l), len(QUESTIONS))) 126 | recall_scores.fill(np.nan) 127 | 128 | recall_scores_orig = np.empty((len(refs_l), len(QUESTIONS))) 129 | recall_scores_orig.fill(np.nan) 130 | 131 | for q_idx, cur_question in enumerate(QUESTIONS): 132 | # Copy questions for batch forwarding to the model 133 | question_batch = [cur_question] * len(hypos_l) 134 | 135 | # Get results from QA model 136 | refs_cur_results = qa_model(question=question_batch, context=refs_l) 137 | hypo_cur_results = qa_model(question=question_batch, context=hypos_l) 138 | 139 | # Get bert scores for given answers 140 | bert_score_refs = [] 141 | bert_score_hypo = [] 142 | for sample_idx, (cur_ref_res, cur_hypo_res) in enumerate(zip(refs_cur_results, hypo_cur_results)): 143 | bert_score_refs.append(cur_ref_res['answer']) 144 | bert_score_hypo.append(cur_hypo_res['answer']) 145 | 146 | b_prec, b_recall, b_f1 = bert_score_qa_model.score(bert_score_hypo, bert_score_refs) 147 | b_prec, b_recall, b_f1 = b_prec.numpy(), b_recall.numpy(), b_f1.numpy() 148 | 149 | full_prec, full_recall, full_f1 = bert_score_qa_model.score(hypos_l, refs_l) 150 | full_prec, full_recall, full_f1 = full_prec.numpy(), full_recall.numpy(), full_f1.numpy() 151 | 152 | # Select scores for loss based on threshold 153 | for sample_idx, (cur_ref_res, cur_hypo_res) in enumerate(zip(refs_cur_results, hypo_cur_results)): 154 | if cur_ref_res['score'] > QA_THRESHOLD or cur_hypo_res['score'] > QA_THRESHOLD: 155 | f1_scores[sample_idx, q_idx] = b_f1[sample_idx] 156 | f1_scores_orig[sample_idx, q_idx] = full_f1[sample_idx] 157 | 158 | prec_scores[sample_idx, q_idx] = b_prec[sample_idx] 159 | prec_scores_orig[sample_idx, q_idx] = full_prec[sample_idx] 160 | 161 | recall_scores[sample_idx, q_idx] = b_recall[sample_idx] 162 | recall_scores_orig[sample_idx, q_idx] = full_recall[sample_idx] 163 | 164 | bert_scores.append(np.nanmean(f1_scores, axis=0)) 165 | f1_score_means.append(np.nanmean(f1_scores)) 166 | f1_score_means_orig.append(np.nanmean(f1_scores_orig)) 167 | prec_means.append(np.nanmean(prec_scores)) 168 | prec_means_orig.append(np.nanmean(prec_scores_orig)) 169 | recall_means.append(np.nanmean(recall_scores)) 170 | recall_means_orig.append(np.nanmean(recall_scores_orig)) 171 | 172 | # Save data for final dataframe 173 | 
bert_scores_detailed.append(np.array(bert_scores).mean(axis=0)) 174 | data_dict['f1_full'].append(np.array(f1_score_means_orig).mean()) 175 | data_dict['f1_qa'].append(np.array(f1_score_means).mean()) 176 | data_dict['prec_full'].append(np.array(prec_means_orig).mean()) 177 | data_dict['prec_qa'].append(np.array(prec_means).mean()) 178 | data_dict['recall_full'].append(np.array(recall_means_orig).mean()) 179 | data_dict['recall_qa'].append(np.array(recall_means).mean()) 180 | 181 | # Save metrics dataframe 182 | save_df = pd.DataFrame(data_dict) 183 | save_df.to_csv(f'metric_experiments_{QA_THRESHOLD}.csv', index=False) 184 | 185 | # Save masked reports dataframe 186 | save_df = pd.DataFrame(mask_reports_dict) 187 | save_df.to_csv(f'masked_reports_{QA_THRESHOLD}.csv', index=False) 188 | 189 | # Save detailed bert scores 190 | bert_scores_np = np.array(bert_scores_detailed) 191 | save_df = pd.DataFrame(bert_scores_np, columns=QUESTIONS) 192 | save_df.to_csv(f'bert_scores_{QA_THRESHOLD}.csv', index=False) -------------------------------------------------------------------------------- /ner_reports.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | import csv 6 | import gzip 7 | import json 8 | import os 9 | import stanza 10 | from stanza import Pipeline 11 | from clinicgen.data.mimiccxr import MIMICCXRData 12 | from clinicgen.data.openi import OpenIData 13 | from clinicgen.text.textfilter import get_textfilter 14 | 15 | 16 | def main(args): 17 | # Download 18 | if args.stanza_download: 19 | stanza.download('en', processors='tokenize,lemma,pos,ner') 20 | stanza.download('en', package='radiology') 21 | # Text processors 22 | textfilter = get_textfilter(args.textfilter) 23 | nlp = Pipeline(lang='en', package='radiology', processors={'lemma': 'default', 'pos': 'default', 24 | 'tokenize': 'default', 'ner': 'radiology'}) 25 | # Extract texts 26 | texts = [] 27 | if args.corpus == 'mimic-cxr': 28 | path = os.path.join(args.data, 'mimic-cxr-resized', '2.0.0', MIMICCXRData.SECTIONED_PATH) 29 | with gzip.open(path, 'rt', encoding='utf-8') as f: 30 | header = f.readline() 31 | reader = csv.reader(f) 32 | for row in reader: 33 | if args.section == 'impression': 34 | text = row[1] 35 | else: 36 | text = row[2] 37 | if len(text) > 0: 38 | texts.append((row[0][1:], text)) 39 | elif args.corpus == 'open-i': 40 | dataset = OpenIData(args.data, section=args.section, meta=args.splits, split='train', multi_image=2) 41 | for tid, target in zip(dataset.doc_ids, dataset.targets): 42 | target = gzip.decompress(target).decode('utf-8') 43 | target = dataset.extract_section(target) 44 | if len(target) > 0: 45 | texts.append((tid, target)) 46 | else: 47 | raise ValueError('Unknown corpus {0}'.format(args.corpus)) 48 | print('{0} texts'.format(len(texts))) 49 | # Extract NEs 50 | count = 0 51 | with gzip.open(args.output, 'wt', encoding='utf-8') as out: 52 | for tid, text in texts: 53 | ftext = textfilter.filter(text) 54 | doc = nlp(ftext) 55 | i = 0 56 | for sentence in doc.sentences: 57 | token_starts, token_ends = {}, {} 58 | j = 0 59 | text_tokens = [] 60 | for token in sentence.tokens: 61 | token_starts[token.start_char] = j 62 | token_ends[token.end_char] = j 63 | text_tokens.append(token.text) 64 | j += 1 65 | lemmas, poses = [], [] 66 | for word in sentence.words: 67 | lemmas.append(word.lemma) 68 | poses.append(word.pos) 69 | ne_tuples = [] 70 | for entity in sentence.ents: 71 | 
ne_tuples.append({'text': entity.text, 'type': entity.type,
72 |                                       'start': token_starts[entity.start_char],
73 |                                       'end': token_ends[entity.end_char] + 1})
74 |                 ins = {'id': '{0}__{1}'.format(tid, i), 'nes': ne_tuples, 'text': sentence.text,
75 |                        'tokens': text_tokens, 'lemmas': lemmas, 'poses': poses}
76 |                 out.write('{0}\n'.format(json.dumps(ins)))
77 |                 i += 1
78 |             count += 1
79 |             if count % 10000 == 0:
80 |                 print('Processed {0}'.format(count))
81 | 
82 | 
83 | def parse_args():
84 |     parser = argparse.ArgumentParser(description='Extract radiology named entities')
85 |     parser.add_argument('data', type=str, help='A path to a clinical dataset')
86 |     parser.add_argument('output', type=str, help='An output path')
87 |     parser.add_argument('--cache', type=str, default=None, help='A cache path')
88 |     parser.add_argument('--corpus', type=str, default='mimic-cxr', help='Corpus name')
89 |     parser.add_argument('--section', type=str, default='findings', help='Target section')
90 |     parser.add_argument('--splits', type=str, default=None, help='A path to a file defining splits')
91 |     parser.add_argument('--stanza-download', default=False, action='store_true', help='Download Stanza clinical model')
92 |     parser.add_argument('--textfilter', type=str, default='lower', help='Text filter')
93 |     return parser.parse_args()
94 | 
95 | 
96 | if __name__ == '__main__':
97 |     args = parse_args()
98 |     main(args)
99 | 
--------------------------------------------------------------------------------
/resize_mimic-cxr-jpg.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | 
4 | import argparse
5 | import os
6 | import shutil
7 | from torchvision.datasets.folder import default_loader
8 | from torchvision.transforms import Resize
9 | 
10 | 
11 | def main(args):
12 |     # in_files = os.path.join(args.jpgs, '2.0.0', 'files')
13 |     in_files = os.path.join(args.jpgs, 'images_clean/files')
14 |     resized_dir = os.path.join(os.path.dirname(args.jpgs), '..', 'mimic-cxr-resized')
15 |     # resized_dir = os.path.dirname('/l/users/20020038/mimic-cxr-resized')
16 |     out_files = os.path.join(resized_dir, '2.0.0', 'files')
17 |     resize = Resize(256)
18 |     # Make mimic-cxr-resized directory
19 |     if not os.path.exists(resized_dir):
20 |         os.mkdir(resized_dir)
21 |         os.mkdir(os.path.join(resized_dir, '2.0.0'))
22 |         os.mkdir(out_files)
23 |     for item in os.listdir(args.jpgs):
24 |         # if item.endswith('.csv.gz'):
25 |         if item.endswith('.csv'):
26 |             shutil.copy(os.path.join(args.jpgs, item), os.path.join(resized_dir, '2.0.0', item))
27 |     # Copy various MIMIC-CXR-JPG data
28 |     # shutil.copy(os.path.join(args.jpgs, '2.0.0', 'mimic-cxr-2.0.0-chexpert.csv.gz'), os.path.join(resized_dir, '2.0.0'))
29 |     # shutil.copy(os.path.join(args.jpgs, '2.0.0', 'mimic-cxr-2.0.0-metadata.csv.gz'), os.path.join(resized_dir, '2.0.0'))
30 |     # shutil.copy(os.path.join(args.jpgs, '2.0.0', 'mimic-cxr-2.0.0-split.csv.gz'), os.path.join(resized_dir, '2.0.0'))
31 |     shutil.copy(os.path.join(args.jpgs, 'mimic-cxr-2.0.0-chexpert.csv'), os.path.join(resized_dir, '2.0.0'))
32 |     shutil.copy(os.path.join(args.jpgs, 'mimic-cxr-2.0.0-metadata.csv'), os.path.join(resized_dir, '2.0.0'))
33 |     shutil.copy(os.path.join(args.jpgs, 'mimic-cxr-2.0.0-split.csv'), os.path.join(resized_dir, '2.0.0'))
34 |     # Resize images
35 |     count_resized, count_skipped = 0, 0
36 |     for item1 in os.listdir(in_files):
37 |         if item1.startswith('p') and os.path.isdir(os.path.join(in_files, item1)):
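            # MIMIC-CXR-JPG stores images as files/pXX/p<patient>/s<study>/<image>.jpg;
            # the nested loops below mirror that hierarchy under the resized
            # directory and write each image back out as a 256-pixel PNG.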
38 |             print('Processing {0} ...'.format(item1))
39 |             if not os.path.exists(os.path.join(out_files, item1)):
40 |                 os.mkdir(os.path.join(out_files, item1))
41 |             for item2 in os.listdir(os.path.join(in_files, item1)):
42 |                 if item2.startswith('p'):
43 |                     if not os.path.exists(os.path.join(out_files, item1, item2)):
44 |                         os.mkdir(os.path.join(out_files, item1, item2))
45 |                     for item3 in os.listdir(os.path.join(in_files, item1, item2)):
46 |                         if item3.startswith('s'):
47 |                             if not os.path.exists(os.path.join(out_files, item1, item2, item3)):
48 |                                 os.mkdir(os.path.join(out_files, item1, item2, item3))
49 |                             for item4 in os.listdir(os.path.join(in_files, item1, item2, item3)):
50 |                                 if item4.endswith('.jpg'):
51 |                                     out_image_path = os.path.join(out_files, item1, item2, item3,
52 |                                                                   item4.replace('.jpg', '.png'))
53 |                                     if not os.path.exists(out_image_path):
54 |                                         image_path = os.path.join(in_files, item1, item2, item3, item4)
55 |                                         image = default_loader(image_path)
56 |                                         image = resize(image)
57 |                                         with open(out_image_path, 'wb') as out:
58 |                                             image.save(out, 'png')
59 |                                         count_resized += 1
60 |                                         if count_resized % 1000 == 0:
61 |                                             print('Resized {0} images'.format(count_resized))
62 |                                     else:
63 |                                         count_skipped += 1
64 |     print('Total {0} images'.format(count_resized + count_skipped))
65 | 
66 | 
67 | def parse_args():
68 |     parser = argparse.ArgumentParser(description='Resize MIMIC-CXR-JPG JPEGs to 256-pixel PNGs')
69 |     parser.add_argument('jpgs', type=str, help='A path to MIMIC-CXR-JPG')
70 |     return parser.parse_args()
71 | 
72 | 
73 | if __name__ == '__main__':
74 |     args = parse_args()
75 |     main(args)
76 | 
--------------------------------------------------------------------------------
/resources/.gitignore:
--------------------------------------------------------------------------------
1 | model_medrad_19k
2 | chexpert_auc14.dict.gz
3 | glove_mimic-cxr_train.512.txt.gz
--------------------------------------------------------------------------------
/resources/download.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | wget https://nlp.stanford.edu/ysmiura/ifcc/glove_mimic-cxr_train.512.txt.gz
3 | wget https://nlp.stanford.edu/ysmiura/ifcc/model_medrad_19k.tar.gz
4 | tar xvzf model_medrad_19k.tar.gz
5 | rm model_medrad_19k.tar.gz
6 | wget https://nlp.stanford.edu/ysmiura/ifcc/chexpert_auc14.dict.gz
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | 
3 | 
4 | with open('README.md', 'r') as fh:
5 |     long_description = fh.read()
6 | 
7 | setuptools.setup(
8 |     name='ifcc',
9 |     version='0.2.0',
10 |     author='Yasuhide Miura',
11 |     author_email='ysmiura@stanford.edu',
12 |     description='The code of: Improving Factual Completeness and Consistency of Image-to-text Radiology Report Generation',
13 |     long_description=long_description,
14 |     long_description_content_type='text/markdown',
15 |     url='https://github.com/ysmiura/ifcc',
16 |     packages=['clinicgen'],
17 |     python_requires='>=3.7',
18 |     install_requires=[
19 |         'bert-score==0.3.0',
20 |         'bioc==1.3.4',
21 |         'bllipparser==2016.9.11',
22 |         'cachetools==4.1.0',
23 |         'flask==1.1.1',
24 |         'jpype1==0.6.3',
25 |         'networkx==1.11',
26 |         'nltk==3.4.5',
27 |         'numpy==1.18.5',
28 |         'pandas==1.0.1',
29 |         'pathlib2==2.3.5',
30 |         'ply==3.11',
31 |         'pystanforddependencies==0.3.1',
32 |         'rouge==0.3.2',
33 |         'scispacy==0.2.0',
34 |         'spacy==2.1.3',
35 |         'stanza==1.1.1',
36 |         'tensorboard==2.0.0',
37 | 
'torch==1.5.0', 38 | 'torchvision==0.6.0', 39 | 'tqdm==4.45.0', 40 | 'transformers==2.9.0' 41 | ] 42 | ) 43 | -------------------------------------------------------------------------------- /tests/test_eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import unittest 5 | import numpy as np 6 | import torch 7 | from clinicgen.data.image2text import PretrainedEmbeddings as PreEmb 8 | from clinicgen.eval import EntityMatcher, RecoverWords 9 | 10 | 11 | class TestEntityMatcher(unittest.TestCase): 12 | def setUp(self): 13 | sentences = {'1': {0: 'No pleural effusions.'}, '2': {0: 'Enlarged heart.'}} 14 | entities = {'1': {'pleural': [0], 'effusions': [0]}, '2': {'heart': [0]}} 15 | target_types = {'ANATOMY': True, 'OBSERVATION': True} 16 | self.matcher = EntityMatcher(sentences, entities, target_types) 17 | 18 | def test_score(self): 19 | rs = self.matcher.score(['1', '2'], ['No pleural effusion.', 'Normal heart size.']) 20 | self.assertEqual(rs[1][0], 0.5) 21 | self.assertEqual(rs[1][1], 1.0) 22 | 23 | 24 | class TestRecoverWords(unittest.TestCase): 25 | def setUp(self): 26 | word_idxs = {'__PAD__': PreEmb.INDEX_PAD, '__START__': PreEmb.INDEX_START, '__UNK__': PreEmb.INDEX_UNKNOWN, 27 | PreEmb.TOKEN_EOS: PreEmb.INDEX_EOS, 'Hello': 4, 'world': 5, '!': 6, 'NLP': 7} 28 | self.recover_words = RecoverWords(word_idxs) 29 | 30 | def test___call__(self): 31 | stops = torch.tensor([[-1.0, -0.2, 1.0], [-2.0, 0.1, 0.5]]) 32 | samples = np.zeros((2, 3, 4)) 33 | samples[0][0] = np.array([4, 5, 6, 0]) 34 | samples[0][1] = np.array([4, 7, 6, 0]) 35 | samples[1][0] = np.array([4, 5, 7, 6]) 36 | samples = torch.tensor(samples).type(torch.long) 37 | rec, _ = self.recover_words(stops, samples) 38 | self.assertEqual(rec[0], 'Hello world !\nHello NLP !') 39 | self.assertEqual(rec[1], 'Hello world NLP !') 40 | -------------------------------------------------------------------------------- /tests/test_nli.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | import unittest 6 | from clinicgen.nli import SimpleNLI 7 | 8 | 9 | class TestSimpleNLI(unittest.TestCase): 10 | def setUp(self): 11 | resource_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'resources') 12 | model = os.path.join(resource_dir, SimpleNLI.RADNLI_STATES) 13 | model = SimpleNLI.load_model(model) 14 | self.nli = SimpleNLI(model, bert_score='distilbert-base-uncased') 15 | 16 | def test_predict(self): 17 | sent1s = ['No pleural effusions.', 'Enlarged heart.', 'Pulmonary edema.'] 18 | sent2s = ['No pleural effusion.', 'Normal heart.', 'Clear lungs.'] 19 | rs = self.nli.predict(sent1s, sent2s) 20 | self.assertEqual(rs[1][0], 'entailment') 21 | self.assertEqual(rs[1][1], 'contradiction') 22 | self.assertEqual(rs[1][2], 'neutral') 23 | -------------------------------------------------------------------------------- /tests/text/test_sentsplit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import unittest 5 | from clinicgen.text.sentsplit import LineBreakSplitter, NLTKSentenceSplitter, SpaCySentenceSplitter, StanzaSentenceSplitter 6 | 7 | 8 | class TestLineBreakSplitter(unittest.TestCase): 9 | def test_split(self): 10 | splitter = LineBreakSplitter() 11 | text = 'Hello NLP! Running a test\nof sentence splitting. 
Line breaks are considered as sentence boundaries.' 12 | sents = splitter.split(text) 13 | self.assertEqual(len(sents), 2) 14 | self.assertTrue(sents[0].startswith('Hello')) 15 | self.assertTrue(sents[1].startswith('of')) 16 | 17 | 18 | class TestNLTKSentenceSplitter(unittest.TestCase): 19 | def test_split(self): 20 | splitter = NLTKSentenceSplitter() 21 | text = 'Hello NLP! Running a test\nof sentence splitting. Line breaks are considered as sentence boundaries.' 22 | sents = splitter.split(text) 23 | self.assertEqual(len(sents), 4) 24 | self.assertTrue(sents[0].startswith('Hello')) 25 | self.assertTrue(sents[1].startswith('Running')) 26 | self.assertTrue(sents[2].startswith('of')) 27 | self.assertTrue(sents[3].startswith('Line')) 28 | 29 | 30 | class TestSpaCySentenceSplitter(unittest.TestCase): 31 | def test_split(self): 32 | splitter = SpaCySentenceSplitter('en_core_web_sm') 33 | text = 'Hello NLP! Running a test\nof sentence splitting. Line breaks are considered as sentence boundaries.' 34 | sents = splitter.split(text) 35 | self.assertEqual(len(sents), 4) 36 | self.assertTrue(sents[0].startswith('Hello')) 37 | self.assertTrue(sents[1].startswith('Running')) 38 | self.assertTrue(sents[2].startswith('of')) 39 | self.assertTrue(sents[3].startswith('Line')) 40 | 41 | 42 | class TestStanzaSentenceSplitter(unittest.TestCase): 43 | def test_split(self): 44 | splitter = StanzaSentenceSplitter() 45 | text = 'Hello NLP! Running a test\nof sentence splitting. Line breaks are considered as sentence boundaries.' 46 | sents = splitter.split(text) 47 | self.assertEqual(len(sents), 4) 48 | self.assertTrue(sents[0].startswith('Hello')) 49 | self.assertTrue(sents[1].startswith('Running')) 50 | self.assertTrue(sents[2].startswith('of')) 51 | self.assertTrue(sents[3].startswith('Line')) -------------------------------------------------------------------------------- /tests/text/test_textfilter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import unittest 5 | from clinicgen.text.textfilter import LowerTextFilter 6 | 7 | 8 | class TestLowerTextFilter(unittest.TestCase): 9 | def test_filter(self): 10 | text = 'Hello NLP!' 11 | tfilter = LowerTextFilter() 12 | self.assertEqual(tfilter.filter(text), 'hello nlp!') 13 | -------------------------------------------------------------------------------- /tests/text/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import unittest 5 | from clinicgen.text.tokenizer import NLTKTokenizer, SpaCyTokenizer, StanzaTokenizer, WhiteSpaceTokenizer 6 | 7 | 8 | class TestNLTKTokenizer(unittest.TestCase): 9 | def test_tokenize(self): 10 | tokenizer = NLTKTokenizer() 11 | text = 'Hello NLP! Running a (tokenization) test.' 12 | tokens = tokenizer.tokenize(text) 13 | self.assertEqual(len(tokens), 10) 14 | self.assertEqual(tokens[0], 'Hello') 15 | self.assertEqual(tokens[2], '!') 16 | self.assertEqual(tokens[3], 'Running') 17 | self.assertEqual(tokens[5], '(') 18 | self.assertEqual(tokens[6], 'tokenization') 19 | self.assertEqual(tokens[9], '.') 20 | 21 | 22 | class TestSpaCyTokenizer(unittest.TestCase): 23 | def test_tokenize(self): 24 | tokenizer = SpaCyTokenizer('en_core_web_sm') 25 | text = 'Hello NLP! Running a (tokenization) test.' 
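        # As with the NLTK test above, spaCy is expected to split off '!', '('
        # and ')' as separate tokens, yielding the same 10-token output.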
26 | tokens = tokenizer.tokenize(text) 27 | self.assertEqual(len(tokens), 10) 28 | self.assertEqual(tokens[0], 'Hello') 29 | self.assertEqual(tokens[2], '!') 30 | self.assertEqual(tokens[3], 'Running') 31 | self.assertEqual(tokens[5], '(') 32 | self.assertEqual(tokens[6], 'tokenization') 33 | self.assertEqual(tokens[9], '.') 34 | 35 | 36 | class TestStanzaTokenizer(unittest.TestCase): 37 | def test_tokenize(self): 38 | tokenizer = StanzaTokenizer() 39 | text = 'Hello NLP! Running a (tokenization) test.' 40 | tokens = tokenizer.tokenize(text) 41 | self.assertEqual(len(tokens), 10) 42 | self.assertEqual(tokens[0], 'Hello') 43 | self.assertEqual(tokens[2], '!') 44 | self.assertEqual(tokens[3], 'Running') 45 | self.assertEqual(tokens[5], '(') 46 | self.assertEqual(tokens[6], 'tokenization') 47 | self.assertEqual(tokens[9], '.') 48 | 49 | 50 | class TestWhiteSpaceTokenizer(unittest.TestCase): 51 | def test_tokenize(self): 52 | tokenizer = WhiteSpaceTokenizer() 53 | text = 'Hello NLP! Running a (tokenization) test.' 54 | tokens = tokenizer.tokenize(text) 55 | self.assertEqual(len(tokens), 6) 56 | self.assertEqual(tokens[0], 'Hello') 57 | self.assertEqual(tokens[1], 'NLP!') 58 | self.assertEqual(tokens[2], 'Running') 59 | self.assertEqual(tokens[4], '(tokenization)') 60 | self.assertEqual(tokens[5], 'test.') 61 | 62 | -------------------------------------------------------------------------------- /train_image.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | import gzip 6 | import os 7 | import random 8 | import time 9 | import numpy as np 10 | import torch 11 | from sklearn.metrics import roc_auc_score 12 | from tqdm import tqdm 13 | from torch import softmax 14 | from torch.optim import Adam 15 | from torch.optim.lr_scheduler import StepLR 16 | from torch.utils.data import DataLoader 17 | from clinicgen.data.chexpert import CheXpertData 18 | from clinicgen.data.utils import Data 19 | from clinicgen.models.image import ImageClassification 20 | from clinicgen.utils import data_cuda 21 | 22 | 23 | def eval_model(pbar_vals, outs, epoch, data_n, model, optimizer, scheduler, val_loader, test_loader, bests, 24 | device=None): 25 | for split, data_loader in [('val', val_loader), ('test', test_loader)]: 26 | if data_loader is not None: 27 | scores = eval_split(model, data_loader, device=device) 28 | pbar_vals['{0}_score'.format(split)] = scores[0] 29 | outs[split].write('{0}-{1} {2} {3}\n'.format(epoch, data_n, scores[0], scores[1])) 30 | outs[split].flush() 31 | if split == 'val': 32 | updates = update_bests(bests, scores) 33 | for update in updates: 34 | save_model(os.path.join(args.out, 'model_{0}.dict.gz'.format(update)), epoch, model, optimizer, 35 | scheduler, bests) 36 | 37 | 38 | def eval_split(model, data_loader, device=None): 39 | with torch.no_grad(): 40 | model.eval() 41 | y_true5, y_true14, y_score5, y_score14 = [], [], [], [] 42 | for _, inp, targ, _, _, _ in data_loader: 43 | inp, _ = data_cuda(inp, targ, device=device, non_blocking=False) 44 | out = model(inp) 45 | probs = softmax(out.permute(0, 2, 1)[:, :, 1:3], dim=-1) 46 | probs = probs.detach().cpu().numpy() 47 | targ = targ.numpy() 48 | for i in range(probs.shape[0]): 49 | for j in range(probs.shape[1]): 50 | true_val = 1 if targ[i][j] == 1 else 0 51 | score_val = probs[i][j][1] 52 | y_true14.append(true_val) 53 | y_score14.append(score_val) 54 | if j == 2 or j == 5 or j == 6 or j == 8 or j == 10: 55 | 
y_true5.append(true_val) 56 | y_score5.append(score_val) 57 | y_true5 = np.array(y_true5) 58 | y_score5 = np.array(y_score5) 59 | y_true14 = np.array(y_true14) 60 | y_score14 = np.array(y_score14) 61 | rocauc5 = roc_auc_score(y_true5, y_score5, average='macro') 62 | rocauc14 = roc_auc_score(y_true14, y_score14, average='macro') 63 | model.train() 64 | return rocauc5, rocauc14 65 | 66 | 67 | def main(args): 68 | if not os.path.exists(args.out): 69 | os.makedirs(args.out) 70 | else: 71 | print('ERROR: {0} already exists'.format(args.out)) 72 | exit(1) 73 | 74 | # Set random seeds 75 | random.seed(args.seed) 76 | np.random.seed(args.seed) 77 | torch.manual_seed(args.seed) 78 | 79 | # Model configurations 80 | model = ImageClassification(args.model, CheXpertData.NUM_LABELS, CheXpertData.NUM_CLASSES, args.multi_image, 81 | dropout=args.dropout, pretrained=args.pretrained) 82 | if args.cuda: 83 | device = 'gpu' 84 | model = model.cuda(0) 85 | else: 86 | device = 'cpu' 87 | 88 | # Data 89 | t = time.time() 90 | datasets = Data.get_datasets(args.data, args.corpus, None, None, None, None, None, None, None, 91 | multi_image=args.multi_image, img_mode=args.img_trans, img_augment=args.img_augment, 92 | cache_data=args.cache_data, anatomy=args.anatomy, meta=args.splits, 93 | ignore_blank=args.ignore_blank, exclude_ids=args.exclude_ids, filter_reports=False) 94 | nw = 0 if args.cache_data else args.num_workers 95 | train_loader = DataLoader(datasets['train'], batch_size=args.batch_size, shuffle=True, num_workers=nw, 96 | pin_memory=False) 97 | batch_size_test = args.batch_size if args.batch_size_test is None else args.batch_size_test 98 | val_loader = DataLoader(datasets['validation'], batch_size=batch_size_test, shuffle=False, num_workers=nw, 99 | pin_memory=False) 100 | if 'test' in datasets: 101 | test_loader = DataLoader(datasets['test'], batch_size=batch_size_test, shuffle=False, num_workers=nw, 102 | pin_memory=False) 103 | test_size = len(test_loader.dataset.samples) 104 | else: 105 | test_loader, test_size = None, 0 106 | print('Data: train={0}, validation={1}, test={2} (load time {3:.2f}s)'.format(len(train_loader.dataset.samples), 107 | len(val_loader.dataset.samples), 108 | test_size, time.time() - t)) 109 | 110 | # Train and test processes 111 | optimizer = Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999)) 112 | scheduler = StepLR(optimizer, args.lr_step, args.lr_gamma) 113 | pbar_vals = {'loss': None, 'val_score': None, 'test_score': None} 114 | outs, bests = {}, {'auc5': 0.0, 'auc14': 0.0} 115 | try: 116 | outs['val'] = open(os.path.join(args.out, 'val.txt'), 'w', encoding='utf-8') 117 | outs['test'] = open(os.path.join(args.out, 'test.txt'), 'w', encoding='utf-8') 118 | for epoch in range(args.epochs): 119 | loss_log = [] 120 | 121 | with tqdm(total=len(train_loader.dataset.samples)) as pbar: 122 | pbar.set_description('Epoch {0}/{1}'.format(epoch + 1, args.epochs)) 123 | data_n, eval_interval, tqdm_interval = 0, 0, 0 124 | 125 | for _, inp, targ, _ in train_loader: 126 | # Train 127 | loss_val = model.train_step(inp, targ, optimizer, clip_grad=args.clip_grad, device=device) 128 | loss_log.append(loss_val) 129 | # Validation / Test 130 | data_n += inp.shape[0] 131 | eval_interval += inp.shape[0] 132 | if args.eval_interval is not None and eval_interval >= args.eval_interval: 133 | eval_model(pbar_vals, outs, epoch, data_n, model, optimizer, scheduler, val_loader, test_loader, 134 | bests, device) 135 | eval_interval -= args.eval_interval 136 | # Progress updates 137 | tqdm_interval 
+= inp.shape[0] 138 | if args.tqdm_interval is None or tqdm_interval >= args.tqdm_interval: 139 | pbar_vals['loss'] = np.mean(loss_log) 140 | pbar.set_postfix(**pbar_vals) 141 | pbar.update(tqdm_interval) 142 | tqdm_interval -= args.tqdm_interval if args.tqdm_interval is not None else 0 143 | # Epoch end processes 144 | scheduler.step() 145 | eval_model(pbar_vals, outs, epoch, None, model, optimizer, scheduler, val_loader, test_loader, bests, 146 | device) 147 | finally: 148 | for _, out in outs.items(): 149 | out.close() 150 | 151 | 152 | def parse_args(): 153 | parser = argparse.ArgumentParser(description='Train a model for image classification') 154 | parser.add_argument('data', type=str, help='A path to clinical data') 155 | parser.add_argument('model', type=str, help='A model name') 156 | parser.add_argument('out', type=str, help='An output path') 157 | parser.add_argument('--anatomy', type=str, default=None, help='A specific anatomy to target') 158 | parser.add_argument('--batch-size', type=int, default=16, help='Batch size') 159 | parser.add_argument('--batch-size-test', type=int, default=None, help='Batch size (test)') 160 | parser.add_argument('--cache-data', type=str, default=None, help='Cache images and texts to memory and disk') 161 | parser.add_argument('--clip-grad', type=float, default=None, help='Clip gradients') 162 | parser.add_argument('--corpus', type=str, default='chexpert', choices=['a', 'chexpert', 'mimic-cxr', 'open-i'], help='Corpus name') 163 | parser.add_argument('--cuda', default=False, action='store_true', help='Use GPU') 164 | parser.add_argument('--dropout', type=float, default=0.0, help='Dropout probability') 165 | parser.add_argument('--epochs', type=int, default=12, help='Epoch num') 166 | parser.add_argument('--eval-interval', type=int, default=None, help='Evaluation interval') 167 | parser.add_argument('--exclude-ids', type=str, default=None, help='IDs to exclude from the data') 168 | parser.add_argument('--ignore-blank', default=False, action='store_true', help='Ignore blank labels') 169 | parser.add_argument('--img-no-augment', dest='img_augment', default=True, action='store_false', help='Do not augment images') 170 | parser.add_argument('--img-trans', type=str, default='pad', help='Image transformation mode') 171 | parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate') 172 | parser.add_argument('--lr-gamma', type=float, default=0.1, help='A learning rate scheduler gamma') 173 | parser.add_argument('--lr-step', type=int, default=16, help='A learning rate scheduler step') 174 | parser.add_argument('--multi-image', type=int, default=2, help='Multi image number') 175 | parser.add_argument('--scratch', dest='pretrained', default=True, action='store_false', help='Train a model from scratch') 176 | parser.add_argument('--seed', type=int, default=1, help='Random seed') 177 | parser.add_argument('--splits', type=str, default=None, help='A path to a file defining splits') 178 | parser.add_argument('--tqdm-interval', type=int, default=None, help='tqdm interval') 179 | return parser.parse_args() 180 | 181 | 182 | def save_model(path, epoch, model, optimizer, scheduler, bests): 183 | with gzip.open(path, 'wb') as out: 184 | state = {'epoch': epoch, 'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 185 | 'scheduler': scheduler.state_dict(), 'bests': bests} 186 | torch.save(state, out) 187 | 188 | 189 | def update_bests(bests, scores): 190 | updates = [] 191 | if scores[0] > bests['auc5']: 192 | bests['auc5'] = scores[0] 193 | 
updates.append('auc5') 194 | if scores[1] > bests['auc14']: 195 | bests['auc14'] = scores[1] 196 | updates.append('auc14') 197 | return updates 198 | 199 | 200 | if __name__ == '__main__': 201 | args = parse_args() 202 | main(args) 203 | --------------------------------------------------------------------------------
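# Usage sketch for train_image.py above (illustrative, not part of the
# repository): 'densenet121' is assumed here to be a model name accepted by
# clinicgen.models.image.ImageClassification; check that module for the
# names it actually supports.
#
#   python train_image.py /path/to/CheXpert-v1.0 densenet121 out_chexpert \
#       --cuda --batch-size 16 --epochs 12
#
# The best checkpoints by 5-label and 14-label macro AUC are saved as
# out_chexpert/model_auc5.dict.gz and out_chexpert/model_auc14.dict.gz.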