├── GPT-IE ├── run.py ├── pre_process.py ├── GPT-IE.py └── post_process.py ├── LLaMA-IE ├── llama3_sft.sh ├── instruction_generator.py ├── inference.py └── post_process.py ├── train ├── models.py ├── constants.py └── train.py ├── downstream ├── retrieve.py ├── classify.py ├── models.py └── segment.py ├── README.md ├── .gitignore └── LICENSE /GPT-IE/run.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import subprocess 3 | from typing import List 4 | import time 5 | import os 6 | 7 | def task_func(args): 8 | while True: 9 | try: 10 | subprocess.check_call(['python3', "prompt.py",args]) 11 | except Exception as e: 12 | time.sleep(10) 13 | print(f"Error occurred: {e}. Restarting...") 14 | 15 | 16 | if __name__ == '__main__': 17 | processes = [] 18 | for i in range(10,20): 19 | path = "./results/"+"p"+str(i) 20 | if not os.path.exists(path): 21 | os.makedirs(path, exist_ok=True) 22 | 23 | for i in range(10,20): 24 | p = multiprocessing.Process(target=task_func, args=("p"+str(i),)) 25 | p.start() 26 | processes.append(p) 27 | 28 | for p in processes: 29 | p.join() -------------------------------------------------------------------------------- /LLaMA-IE/llama3_sft.sh: -------------------------------------------------------------------------------- 1 | # 12GB GPU memory 2 | 3 | PYTHONPATH=../../.. 
\ 4 | CUDA_VISIBLE_DEVICES=1,2 \ 5 | swift sft \ 6 | --model_id_or_path LLM-Research/Meta-Llama-3-8B-Instruct \ 7 | --model_revision master \ 8 | --sft_type lora \ 9 | --tuner_backend peft \ 10 | --template_type AUTO \ 11 | --dtype AUTO \ 12 | --output_dir output \ 13 | --dataset ./data/instruction.json \ 14 | --train_dataset_sample 1000 \ 15 | --num_train_epochs 5 \ 16 | --max_length 512 \ 17 | --check_dataset_strategy warning \ 18 | --quantization_bit 4 \ 19 | --bnb_4bit_comp_dtype AUTO \ 20 | --lora_rank 8 \ 21 | --lora_alpha 32 \ 22 | --lora_dropout_p 0.05 \ 23 | --lora_target_modules q_proj v_proj \ 24 | --gradient_checkpointing true \ 25 | --batch_size 1 \ 26 | --weight_decay 0.1 \ 27 | --learning_rate 1e-4 \ 28 | --gradient_accumulation_steps 16 \ 29 | --max_grad_norm 0.5 \ 30 | --warmup_ratio 0.03 \ 31 | --eval_steps 100 \ 32 | --save_steps 100 \ 33 | --save_total_limit 2 \ 34 | --logging_steps 10 35 | -------------------------------------------------------------------------------- /train/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models 4 | import numpy as np 5 | import timm 6 | from transformers import CLIPProcessor, CLIPModel 7 | 8 | 9 | class VLM(nn.Module): 10 | def __init__(self, embed_dim=768, vision_model="resnet"): 11 | super().__init__() 12 | self.vision_model = vision_model 13 | if self.vision_model == "vit": 14 | self.image_model = timm.create_model( 15 | "eva02_base_patch14_448.mim_in22k_ft_in22k_in1k", 16 | pretrained=True, 17 | num_classes=0, 18 | ) 19 | elif self.vision_model == "resnet": 20 | self.image_model = models.resnet50(weights=True) 21 | num_ftrs = self.image_model.fc.in_features 22 | self.image_model.fc = nn.Linear(num_ftrs, embed_dim) 23 | 24 | self.text_model = CLIPModel.from_pretrained( 25 | "openai/clip-vit-base-patch32" 26 | ).text_model 27 | self.processor = 
CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") 28 | self.text_projection = nn.Linear(512, embed_dim, bias=False) 29 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 30 | 31 | def forward(self, images, texts, device): 32 | image_feat = self.image_model(images) 33 | tokens = self.processor( 34 | text=texts, 35 | padding=True, 36 | )["input_ids"] 37 | tokens = torch.tensor(tokens) 38 | tokens = tokens.to(device) 39 | text_feat = self.text_model(tokens).pooler_output 40 | text_feat = self.text_projection(text_feat) 41 | image_feat = image_feat / image_feat.norm(dim=1, keepdim=True) 42 | text_feat = text_feat / text_feat.norm(dim=1, keepdim=True) 43 | logit_scale = self.logit_scale.exp() 44 | logits_per = logit_scale * image_feat @ text_feat.t() 45 | return logits_per 46 | -------------------------------------------------------------------------------- /downstream/retrieve.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from models import VLM 4 | 5 | 6 | device = torch.device("cuda:{}".format(4) if torch.cuda.is_available() else "cpu") 7 | net = VLM() 8 | net.to(device) 9 | ckpt_path = "~/MedFILIP/train/pretrained/model.pt" 10 | checkpoints = torch.load(ckpt_path, map_location=device) 11 | net.load_state_dict(checkpoints["network"]) 12 | net.eval() 13 | dataset = np.load("~/MedFILIP/FT/mimic_retrieve.npy", allow_pickle=True) 14 | text_datas = [] 15 | image_datas = [] 16 | for data in dataset: 17 | text_datas.append(data["labels"]) 18 | image_datas.append(np.array(data["image_npy"])) 19 | labels = [ 20 | [text_datas[i] == text_datas[j] for j in range(len(text_datas))] 21 | for i in range(len(text_datas)) 22 | ] 23 | labels = np.array(labels) 24 | text_embeddings = [] 25 | image_embeddings = [] 26 | b = 20 27 | for i in range(int(len(text_datas) / b)): 28 | text_data = text_datas[i * b : (i + 1) * b] 29 | text_data = [sublist[0] for sublist in text_data] 30 
| image_data = image_datas[i * b : (i + 1) * b] 31 | image_data = torch.tensor(image_data) 32 | image_data = image_data.to(device) 33 | image_embedding, text_embedding = net(image_data, text_data, device, True) 34 | text_embeddings.append(text_embedding.detach().cpu().numpy()) 35 | image_embeddings.append(image_embedding.detach().cpu().numpy()) 36 | text_embeddings = np.concatenate(text_embeddings, axis=0) 37 | image_embeddings = np.concatenate(image_embeddings, axis=0) 38 | similarities = np.dot(image_embeddings, text_embeddings.T) 39 | similarities = np.array(similarities) 40 | dic_ItoT = {} 41 | 42 | for k in range(1, 11): 43 | top_k_pred = np.argsort(similarities, axis=1)[:, -k:] 44 | top_k_pred = list(top_k_pred) 45 | label = [np.where(row == 1)[0] for row in labels] 46 | overlap_counts = [ 47 | len(set(row1) & set(row2)) for row1, row2 in zip(top_k_pred, label) 48 | ] 49 | dic_ItoT[k] = sum(overlap_counts) 50 | print(dic_ItoT) 51 | 52 | dic_TtoI = {} 53 | similarities, labels = similarities.T, labels.T 54 | for k in range(1, 11): 55 | top_k_pred = np.argsort(similarities, axis=1)[:, -k:] 56 | top_k_pred = list(top_k_pred) 57 | label = [np.where(row == 1)[0] for row in labels] 58 | overlap_counts = [ 59 | len(set(row1) & set(row2)) for row1, row2 in zip(top_k_pred, label) 60 | ] 61 | dic_TtoI[k] = sum(overlap_counts) 62 | print(dic_TtoI) 63 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [IEEE-JBHI 2025] Pytorch Implementation of the Paper "MedFILIP: Medical Fine-Grained Language-Image Pre-Training" 2 | 3 | ## Requirements 4 | 5 | - python=3.10.12 6 | - pytorch-cuda=11.7 7 | - tensorflow=2.14.0 8 | - transformers=4.24.0 9 | 10 | ## Code Architecture 11 | 12 | ### downstream 13 | Contains modules for fine-tuning and inference: 14 | - `classifi.py`: Fine-tuning for classification tasks 15 | - `models.py`: Contrastive learning 
models and segmentation models 16 | - `retrieve.py`: Zero-shot retrieval tasks 17 | - `segment.py`: Fine-tuning for segmentation tasks 18 | 19 | ### GPT-IE 20 | Information extraction using GPT-3.5 and related preprocessing and post-processing: 21 | - `GPT-IE.py`: Entity extraction using GPT-3.5 22 | - `post_process.py`: Post-processing of extracted entities 23 | - `pre_process.py`: Preprocessing of diagnostic reports 24 | - `run.py`: Parallel (multi-process) execution of GPT-IE 25 | 26 | ### LLaMA-IE 27 | Information extraction using LLaMA-3-8B 28 | - **data folder**: Contains the instruction fine-tuning dataset for LLaMA-3-8B 29 | - **inference.py**: Code for inference using the fine-tuned LLaMA-3-8B 30 | - **instruction_generator.py**: Code for constructing the instruction fine-tuning dataset 31 | - **llama3_sft.sh**: Shell script for LLaMA-3-8B fine-tuning 32 | - Configuration file: `.\LLM\ckpt\sft_args.json` 33 | - **post_process.py**: Post-processes LLaMA-3-8B's output, converting structured disease information to JSON format 34 | 35 | ### train 36 | Training of contrastive learning models and related configurations: 37 | - `constants.py`: Sets of disease categories, disease severity levels, disease locations, and disease-description mapping dictionaries 38 | - `models.py`: Contrastive learning models 39 | - `data_GPT.json`: Entities extracted by GPT-3.5 40 | - `data_llama3_8B.json`: Entities extracted by LLaMA-3-8B 41 | - `train.py`: Training script for contrastive learning models 42 | 43 | ## Citation 44 | 45 | If you use this project in your research, please consider citing it.
Below is the BibTeX entry for referencing this work: 46 | 47 | ```bibtex 48 | @article{liang2025medfilip, 49 | title={MedFILIP: Medical Fine-Grained Language-Image Pre-Training}, 50 | author={Liang, Xinjie and Li, Xiangyu and Li, Fanding and Jiang, Jie and Dong, Qing and Wang, Wei and Wang, Kuanquan and Dong, Suyu and Luo, Gongning and Li, Shuo}, 51 | journal={IEEE Journal of Biomedical and Health Informatics}, 52 | year={2025}, 53 | publisher={IEEE} 54 | } -------------------------------------------------------------------------------- /GPT-IE/pre_process.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import numpy as np 3 | import re 4 | import os 5 | 6 | 7 | def read_report(csv_path, root_dir, save_path): 8 | """ 9 | Read and process medical reports. 10 | 11 | This function reads report information from the given CSV file, filters reports containing specific diseases, 12 | and saves these reports to the specified path. 13 | 14 | Parameters: 15 | csv_path: Path to the CSV file containing report information. 16 | root_dir: Root directory of the original report files. 17 | save_path: Path to save the processed reports. 
18 | """ 19 | # Initialize a dictionary to store report paths for each patient 20 | paths_dict = {} 21 | for i in range(10, 20): 22 | paths_dict[str(i)] = [] 23 | 24 | # Open the CSV file and read report information 25 | with open(csv_path, "r") as f: 26 | reader = csv.reader(f) 27 | for row in reader: 28 | # Read reports containing diseases 29 | if row[10] == "": 30 | paths_dict[row[0][:2]].append( 31 | root_dir + "p" + row[0][:2] + "/p" + row[0] + "/s" + row[1] + ".txt" 32 | ) 33 | 34 | # Traverse each patient's report paths, read and process report contents 35 | for files in paths_dict: 36 | reports = [] 37 | for file_path in paths_dict[files]: 38 | with open(file_path, "r") as f: 39 | contents = f.read() 40 | report = contents.replace("\n", "") 41 | if report.find("FINAL REPORT") != -1: 42 | report = report.split("FINAL REPORT")[1] 43 | findings_index = report.find("FINDINGS:") 44 | if findings_index == -1: 45 | findings_index = len(report) 46 | impression_index = report.find("IMPRESSION:") 47 | if impression_index == -1: 48 | impression_index = len(report) 49 | index = min(findings_index, impression_index) 50 | report = report[index:] 51 | reports.append({"path": file_path, "text": report}) 52 | 53 | # Save the processed reports to the specified path 54 | reports_path = save_path + "/p" + files 55 | if not os.path.exists(reports_path): 56 | os.makedirs(reports_path) 57 | reports_path = reports_path + "/reports.npy" 58 | print(reports_path) 59 | np.save(reports_path, reports) 60 | 61 | 62 | if __name__ == "__main__": 63 | root_dir = "/root/reports/files/" 64 | csv_path = "/root/mimic-cxr-2.0.0-chexpert.csv" 65 | save_path = "reports.npy" 66 | save_path = "./reports" 67 | read_report(csv_path, root_dir, save_path) 68 | 69 | reports = np.load("./reports/p10/reports.npy", allow_pickle=True) 70 | print(len(reports)) 71 | print(reports[4]) -------------------------------------------------------------------------------- /LLaMA-IE/instruction_generator.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | reports_path = "~/DILLIM/reports_process/reports/p18/reports.npy" 4 | reports = np.load(reports_path, allow_pickle=True) 5 | prompt = "Now you need to use your medical knowledge to help me. I will give you some medical diagnosis reports, and you need to extract disease-related information from the reports. The output should only contain the extracted disease information and should not include any other text. When extracting information, you need to follow six rules as follows: 1. Each piece of disease-related information extracted must meet the format: {severity or stage of disease}{location or organ of disease}{category of disease}, for example, the sentence is ‘New nodular opacities are clustered within the left upper lobe.’, then the disease-related information extracted should be ‘{New}{left upper lobe}{nodular opacities}’. 2. Some sentences may lack some information, in this case, you need to use {mask} instead, for example, if the sentence is ‘There is left lung pneumonia.’, then the disease-related information extracted should be {mask}{left lung}{pneumonia}. 3. If the disease was negatively mentioned in the report, for example, if the sentence is ‘There is no pneumothorax or left-sided pleural effusion.’, then the disease-related information extracted should be ‘{No}{mask}{pneumothorax}’ and ‘{No}{left-sided pleural}{effusion}’. 4. There may be multiple disease descriptions in one sentence, you need to find them all and extract disease-related information. 5. Ignore words irrelevant to disease description, for example, the sentence is ‘The heart size is normal.’, and there is no disease-related information, so you don’t need to extract information from this sentence. 6. Separate information with commas. 
\nNext, you can extract information from the report.\n report:\n" 6 | 7 | k = 25 8 | results = np.array([]) 9 | for p in range(int(1000 / k)): 10 | content = [] 11 | m, n = 1, 76 12 | with open( 13 | "~/DILLIM/reports_process/examples/reports" + str(p) + ".txt", "r" 14 | ) as f: 15 | for line in f: 16 | content.append(line) 17 | for i in range(25): 18 | m = m + 3 19 | n = n + 2 20 | query = content[m] 21 | response = content[n] 22 | if query[-1] == "\n": 23 | query = query[:-1] 24 | if response[-1] == "\n": 25 | response = response[:-1] 26 | print(m, n, "query", query, "response", response) 27 | results = np.append(results, {"query": prompt + query, "response": response}) 28 | # print('reports'+str(n)+':'+str(len(content))) 29 | 30 | print(len(results)) 31 | import json 32 | 33 | results_json = results.tolist() 34 | results_json = json.dumps(results_json) 35 | with open("~/DILLIM/reports_process/data/instruction.json", "w") as file: 36 | file.write(results_json) 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | -------------------------------------------------------------------------------- /LLaMA-IE/inference.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import json 3 | import sys 4 | import numpy as np 5 | import time 6 | import argparse 7 | import os 8 | 9 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 10 | from swift.llm import ( 11 | get_model_tokenizer, 12 | get_template, 13 | inference, 14 | ModelType, 15 | get_default_template_type, 16 | ) 17 | from swift.utils import seed_everything 18 | from swift.tuners import Swift 19 | 20 | prompt = "Now you need to use your medical knowledge to help me. I will give you some medical diagnosis reports, and you need to extract disease-related information from the reports. The output should only contain the extracted disease information and should not include any other text. When extracting information, you need to follow six rules as follows: 1. Each piece of disease-related information extracted must meet the format: {severity or stage of disease}{location or organ of disease}{category of disease}, for example, the sentence is \u2018New nodular opacities are clustered within the left upper lobe.\u2019, then the disease-related information extracted should be \u2018{New}{left upper lobe}{nodular opacities}\u2019. 2. Some sentences may lack some information, in this case, you need to use {mask} instead, for example, if the sentence is \u2018There is left lung pneumonia.\u2019, then the disease-related information extracted should be {mask}{left lung}{pneumonia}. 3. If the disease was negatively mentioned in the report, for example, if the sentence is \u2018There is no pneumothorax or left-sided pleural effusion.\u2019, then the disease-related information extracted should be \u2018{No}{mask}{pneumothorax}\u2019 and \u2018{No}{left-sided pleural}{effusion}\u2019. 4. 
There may be multiple disease descriptions in one sentence, you need to find them all and extract disease-related information. 5. Ignore words irrelevant to disease description, for example, the sentence is \u2018The heart size is normal.\u2019, and there is no disease-related information, so you don\u2019t need to extract information from this sentence. 6. Separate information with commas. \nNext, you can extract information from the report.\n report:\n" 21 | 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser(description="Inference") 25 | parser.add_argument( 26 | "--checkpoint", type=str, help="checkpoint_path", default="./ckpt/" 27 | ) 28 | parser.add_argument( 29 | "--reports", type=str, help="reports_path", default="./reports/" 30 | ) 31 | parser.add_argument( 32 | "--extracted_entity", 33 | type=str, 34 | help="results_path", 35 | default="./extracted_entity/llama3_fine_tuned/", 36 | ) 37 | args = parser.parse_args() 38 | return args 39 | 40 | 41 | if __name__ == "__main__": 42 | args = parse_args() 43 | chekpoint_path = args.checkpoint 44 | reports_path = args.reports 45 | results_path = args.results 46 | 47 | model_type = ModelType.llama3_8b_instruct 48 | template_type = get_default_template_type(model_type) 49 | print(f"template_type: {template_type}") 50 | 51 | kwargs = {} 52 | # kwargs['use_flash_attn'] = True # 使用flash_attn 53 | 54 | model_type = ModelType.llama3_8b_instruct 55 | template_type = get_default_template_type(model_type) 56 | 57 | model, tokenizer = get_model_tokenizer( 58 | model_type, model_kwargs={"device_map": "auto"} 59 | ) 60 | 61 | model = Swift.from_pretrained(model, chekpoint_path, inference_mode=True) 62 | template = get_template(template_type, tokenizer) 63 | seed_everything(42) 64 | for group in range(10, 20): 65 | group = "p" + str(group) 66 | reports = reports_path + group + "/reports.npy" 67 | reports = np.load(reports, allow_pickle=True) 68 | results_path = results_path + group + "/results.npy" 69 | if not 
os.path.exists(results_path + group): 70 | os.makedirs(results_path + group, exist_ok=True) 71 | 72 | try: 73 | results = np.load(results_path, allow_pickle=True) 74 | except: 75 | results = np.array([]) 76 | 77 | for report in reports[len(results) :]: 78 | query = report["text"] 79 | response, history = inference(model, template, prompt + query) 80 | history.pop() 81 | results = np.append( 82 | results, 83 | { 84 | "path": report["path"], 85 | "report": report["text"], 86 | "response": response, 87 | }, 88 | ) 89 | print(report["path"]) 90 | print(query) 91 | print(response) 92 | print(len(results)) 93 | -------------------------------------------------------------------------------- /downstream/classify.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torchvision import models 5 | from torch.utils.data import DataLoader 6 | from torch.utils.data import Dataset, DataLoader 7 | import numpy as np 8 | import argparse 9 | import timm 10 | 11 | 12 | def collate_fn(data): 13 | images, labels = [], [] 14 | for batch in range(0, len(data)): 15 | images.append(data[batch][0]) 16 | labels.append(data[batch][1]) 17 | data_copy = (np.array(images), np.array(labels)) 18 | return data_copy 19 | 20 | 21 | class CXR(Dataset): 22 | def __init__(self, data, portion=1, mode="train"): 23 | self.data = data 24 | if mode == "train": 25 | self.data = self.data[: int(12800 * portion / 100)] 26 | else: 27 | self.data = self.data[12800:14848] 28 | 29 | def __len__(self): 30 | return len(self.data) 31 | 32 | def __getitem__(self, idx): 33 | return np.array(self.data[idx]["image_npy"]), np.array( 34 | self.data[idx]["label"], dtype=float 35 | ) 36 | 37 | 38 | def compute_metrics(predicted, labels): 39 | correct_predictions = np.all(predicted == labels) 40 | num_correct = np.sum(correct_predictions) 41 | num_total = predicted.shape[0] 42 | accuracy_strict = num_correct / 
num_total 43 | TP = np.sum((predicted == 1) & (labels == 1)) 44 | FP = np.sum((predicted == 1) & (labels == 0)) 45 | TN = np.sum((predicted == 0) & (labels == 0)) 46 | FN = np.sum((predicted == 0) & (labels == 1)) 47 | 48 | accuracy = (TP + TN) / (TP + FP + TN + FN) 49 | precision = TP / (TP + FP) 50 | recall = TP / (TP + FN) 51 | f1_score = 2 * precision * recall / (precision + recall) 52 | TPR = TP / (TP + FN) 53 | FPR = FP / (FP + TN) 54 | return ( 55 | accuracy_strict, 56 | accuracy, 57 | precision, 58 | recall, 59 | f1_score, 60 | TPR, 61 | FPR, 62 | TP, 63 | FP, 64 | TN, 65 | FN, 66 | ) 67 | 68 | 69 | def parse_args(): 70 | parser = argparse.ArgumentParser(description="Train a classifier") 71 | parser.add_argument("--dataset", type=str, help="数据集", default="pneumonia") 72 | parser.add_argument("--backbone", type=str, help="视觉编码器", default="resnet") 73 | parser.add_argument("--gpu", type=int, help="gpu序号", default=0) 74 | parser.add_argument("--portion", type=int, help="微调数据比例", default=1) 75 | parser.add_argument("--pretrain", type=bool, help="微调数据比例", default=False) 76 | args = parser.parse_args() 77 | return args 78 | 79 | 80 | args = parse_args() 81 | gpu = args.gpu 82 | device = torch.device("cuda:{}".format(gpu) if torch.cuda.is_available() else "cpu") 83 | dataset = args.dataset 84 | portion = args.portion 85 | pretrain = args.pretrain 86 | 87 | 88 | if dataset == "chest14": 89 | datas = np.load("~/data/chest14_train.npy", allow_pickle=True) 90 | elif dataset == "pneumonia": 91 | datas = np.load("~/data/pneumonia_train.npy", allow_pickle=True) 92 | elif dataset == "vinbigdata": 93 | datas = np.load("~/data/vinbigdata_train.npy", allow_pickle=True) 94 | train_dataset = CXR(datas, portion) 95 | print(train_dataset.__len__()) 96 | train_dataloader = DataLoader( 97 | train_dataset, 98 | batch_size=32, 99 | shuffle=True, 100 | num_workers=0, 101 | pin_memory=False, 102 | collate_fn=collate_fn, 103 | drop_last=True, 104 | ) 105 | eval_dataset = CXR(datas, 
mode="eval") 106 | print(eval_dataset.__len__()) 107 | eval_dataloader = DataLoader( 108 | eval_dataset, 109 | batch_size=32, 110 | shuffle=False, 111 | num_workers=0, 112 | pin_memory=False, 113 | collate_fn=collate_fn, 114 | drop_last=True, 115 | ) 116 | 117 | if args.backbone == "convnext": 118 | model = timm.create_model("convnext_base", pretrained=pretrain) 119 | # 修改第一层卷积以接受单通道输入 120 | # 原始的第一层为: Conv2d(3, 96, kernel_size=4, stride=4) 121 | # 修改为: Conv2d(1, 96, kernel_size=4, stride=4) 122 | model.stem[0] = nn.Conv2d( 123 | 1, 124 | model.stem[0].out_channels, 125 | kernel_size=model.stem[0].kernel_size, 126 | stride=model.stem[0].stride, 127 | padding=model.stem[0].padding, 128 | bias=False, 129 | ) 130 | if dataset == "chest14": 131 | model.head.fc = nn.Linear(model.head.in_features, 10) 132 | elif dataset == "pneumonia": 133 | model.head.fc = nn.Linear(model.head.in_features, 1) 134 | elif dataset == "vinbigdata": 135 | model.head.fc = nn.Linear(model.head.in_features, 8) 136 | model = model.to(device) 137 | elif args.backbone == "resnet": 138 | model = models.resnet50(weights=pretrain) 139 | new_conv1 = nn.Conv2d( 140 | 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False 141 | ) 142 | new_conv1.weight.data = model.conv1.weight.mean(dim=1, keepdim=True) 143 | model.conv1 = new_conv1 144 | num_ftrs = model.fc.in_features 145 | if dataset == "chest14": 146 | model.fc = nn.Linear(num_ftrs, 10) 147 | elif dataset == "pneumonia": 148 | model.fc = nn.Linear(num_ftrs, 1) 149 | elif dataset == "vinbigdata": 150 | model.fc = nn.Linear(num_ftrs, 8) 151 | model = model.to(device) 152 | 153 | 154 | if dataset == "chest14": 155 | criterion = nn.CrossEntropyLoss() 156 | elif dataset == "pneumonia": 157 | criterion = nn.BCEWithLogitsLoss() 158 | elif dataset == "vinbigdata": 159 | criterion = nn.CrossEntropyLoss() 160 | optimizer = optim.Adam(model.parameters(), lr=0.001) 161 | 162 | num_epochs = 20 163 | with torch.cuda.amp.autocast(enabled=True): 164 | 
for epoch in range(num_epochs): 165 | model.train() 166 | running_loss = 0.0 167 | for inputs, labels in train_dataloader: 168 | inputs, labels = torch.tensor(inputs), torch.tensor(labels) 169 | inputs = inputs.to(device) 170 | labels = labels.to(device) 171 | optimizer.zero_grad() 172 | outputs = model(inputs) 173 | loss = criterion(outputs, labels) 174 | loss.backward() 175 | optimizer.step() 176 | running_loss += loss.item() * inputs.size(0) 177 | 178 | epoch_loss = running_loss / len(train_dataset) 179 | print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}") 180 | 181 | model.eval() 182 | correct = 0 183 | total = 0 184 | threshold = 0.5 185 | predicted_concat = [] 186 | label_concat = [] 187 | with torch.no_grad(): 188 | for inputs, labels in eval_dataloader: 189 | label_concat.append(np.array(labels)) 190 | inputs, labels = torch.tensor(inputs), torch.tensor(labels) 191 | inputs = inputs.to(device) 192 | labels = labels.to(device) 193 | outputs = model(inputs) 194 | predicted = (outputs.data > threshold).int() 195 | predicted_concat.append(predicted.detach().cpu().numpy()) 196 | 197 | predicted_concat = np.concatenate(predicted_concat, axis=0) 198 | label_concat = np.concatenate(label_concat, axis=0) 199 | metric = compute_metrics(predicted_concat, label_concat) 200 | print(metric) 201 | -------------------------------------------------------------------------------- /downstream/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | from torchvision import models 5 | import numpy as np 6 | import timm 7 | from transformers import CLIPProcessor, CLIPModel 8 | 9 | 10 | class VLM(nn.Module): 11 | def __init__(self, embed_dim=768, vision_model="resnet"): 12 | super().__init__() 13 | self.vision_model = vision_model 14 | if self.vision_model == "vit": 15 | self.image_model = timm.create_model( 16 | "eva02_base_patch14_448.mim_in22k_ft_in22k_in1k", 
17 | pretrained=True, 18 | num_classes=0, 19 | ) 20 | elif self.vision_model == "resnet": 21 | self.image_model = models.resnet50(weights=True) 22 | num_ftrs = self.image_model.fc.in_features 23 | self.image_model.fc = nn.Linear(num_ftrs, embed_dim) 24 | 25 | self.text_model = CLIPModel.from_pretrained( 26 | "openai/clip-vit-base-patch32" 27 | ).text_model 28 | self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") 29 | # self.text_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") 30 | self.text_projection = nn.Linear(512, embed_dim, bias=False) 31 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 32 | 33 | def forward(self, images, texts, device, embedding=False): 34 | image_feat = self.image_model(images) 35 | tokens = self.processor( 36 | text=texts, 37 | padding=True, 38 | )["input_ids"] 39 | tokens = torch.tensor(tokens) 40 | tokens = tokens.to(device) 41 | text_feat = self.text_model(tokens).pooler_output 42 | # text_feat = self.text_model.get_text_features(tokens) 43 | text_feat = self.text_projection(text_feat) 44 | image_feat = image_feat / image_feat.norm(dim=1, keepdim=True) 45 | text_feat = text_feat / text_feat.norm(dim=1, keepdim=True) 46 | logit_scale = self.logit_scale.exp() 47 | logits_per = logit_scale * image_feat @ text_feat.t() 48 | if embedding: 49 | return image_feat, text_feat 50 | return logits_per 51 | 52 | 53 | class ConvBlock(nn.Module): 54 | """ 55 | Helper module that consists of a Conv -> BN -> ReLU 56 | """ 57 | 58 | def __init__( 59 | self, 60 | in_channels, 61 | out_channels, 62 | padding=1, 63 | kernel_size=3, 64 | stride=1, 65 | with_nonlinearity=True, 66 | ): 67 | super().__init__() 68 | self.conv = nn.Conv2d( 69 | in_channels, 70 | out_channels, 71 | padding=padding, 72 | kernel_size=kernel_size, 73 | stride=stride, 74 | ) 75 | self.bn = nn.BatchNorm2d(out_channels) 76 | self.relu = nn.ReLU() 77 | self.with_nonlinearity = with_nonlinearity 78 | 79 | def forward(self, 
x): 80 | x = self.conv(x) 81 | x = self.bn(x) 82 | if self.with_nonlinearity: 83 | x = self.relu(x) 84 | return x 85 | 86 | 87 | class Bridge(nn.Module): 88 | """ 89 | This is the middle layer of the UNet which just consists of some 90 | """ 91 | 92 | def __init__(self, in_channels, out_channels): 93 | super().__init__() 94 | self.bridge = nn.Sequential( 95 | ConvBlock(in_channels, out_channels), ConvBlock(out_channels, out_channels) 96 | ) 97 | 98 | def forward(self, x): 99 | return self.bridge(x) 100 | 101 | 102 | class UpBlockForUNetWithResNet50(nn.Module): 103 | """ 104 | Up block that encapsulates one up-sampling step which consists of Upsample -> ConvBlock -> ConvBlock 105 | """ 106 | 107 | def __init__( 108 | self, 109 | in_channels, 110 | out_channels, 111 | up_conv_in_channels=None, 112 | up_conv_out_channels=None, 113 | upsampling_method="conv_transpose", 114 | ): 115 | super().__init__() 116 | 117 | if up_conv_in_channels == None: 118 | up_conv_in_channels = in_channels 119 | if up_conv_out_channels == None: 120 | up_conv_out_channels = out_channels 121 | 122 | if upsampling_method == "conv_transpose": 123 | self.upsample = nn.ConvTranspose2d( 124 | up_conv_in_channels, up_conv_out_channels, kernel_size=2, stride=2 125 | ) 126 | elif upsampling_method == "bilinear": 127 | self.upsample = nn.Sequential( 128 | nn.Upsample(mode="bilinear", scale_factor=2), 129 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1), 130 | ) 131 | self.conv_block_1 = ConvBlock(in_channels, out_channels) 132 | self.conv_block_2 = ConvBlock(out_channels, out_channels) 133 | 134 | def forward(self, up_x, down_x): 135 | """ 136 | 137 | :param up_x: this is the output from the previous up block 138 | :param down_x: this is the output from the down block 139 | :return: upsampled feature map 140 | """ 141 | x = self.upsample(up_x) 142 | x = torch.cat([x, down_x], 1) 143 | x = self.conv_block_1(x) 144 | x = self.conv_block_2(x) 145 | return x 146 | 147 | 148 | # This code 
uses a function from the pytorch-unet-resnet-50-encoder library available at https://github.com/rawmarshmellows/pytorch-unet-resnet-50-encoder/tree/master 149 | 150 | 151 | class UNetWithResnet50Encoder(nn.Module): 152 | DEPTH = 6 153 | 154 | def __init__(self, n_classes=1, weights=None): 155 | super().__init__() 156 | if weights: 157 | resnet = VLM(vision_model="resnet").image_model 158 | resnet.load_state_dict(weights) 159 | else: 160 | resnet = torchvision.models.resnet.resnet50(weights=None) 161 | down_blocks = [] 162 | up_blocks = [] 163 | self.input_block = nn.Sequential(*list(resnet.children()))[:3] 164 | self.input_pool = list(resnet.children())[3] 165 | for bottleneck in list(resnet.children()): 166 | if isinstance(bottleneck, nn.Sequential): 167 | down_blocks.append(bottleneck) 168 | self.down_blocks = nn.ModuleList(down_blocks) 169 | self.bridge = Bridge(2048, 2048) 170 | up_blocks.append(UpBlockForUNetWithResNet50(2048, 1024)) 171 | up_blocks.append(UpBlockForUNetWithResNet50(1024, 512)) 172 | up_blocks.append(UpBlockForUNetWithResNet50(512, 256)) 173 | up_blocks.append( 174 | UpBlockForUNetWithResNet50( 175 | in_channels=128 + 64, 176 | out_channels=128, 177 | up_conv_in_channels=256, 178 | up_conv_out_channels=128, 179 | ) 180 | ) 181 | up_blocks.append( 182 | UpBlockForUNetWithResNet50( 183 | in_channels=64 + 3, 184 | out_channels=64, 185 | up_conv_in_channels=128, 186 | up_conv_out_channels=64, 187 | ) 188 | ) 189 | 190 | self.up_blocks = nn.ModuleList(up_blocks) 191 | 192 | self.out = nn.Conv2d(64, n_classes, kernel_size=1, stride=1) 193 | 194 | def forward(self, x, with_output_feature_map=False): 195 | pre_pools = dict() 196 | pre_pools[f"layer_0"] = x 197 | x = self.input_block(x) 198 | pre_pools[f"layer_1"] = x 199 | x = self.input_pool(x) 200 | 201 | for i, block in enumerate(self.down_blocks, 2): 202 | x = block(x) 203 | if i == (UNetWithResnet50Encoder.DEPTH - 1): 204 | continue 205 | pre_pools[f"layer_{i}"] = x 206 | 207 | x = self.bridge(x) 
208 | 209 | for i, block in enumerate(self.up_blocks, 1): 210 | key = f"layer_{UNetWithResnet50Encoder.DEPTH - 1 - i}" 211 | x = block(x, pre_pools[key]) 212 | output_feature_map = x 213 | x = self.out(x) 214 | del pre_pools 215 | if with_output_feature_map: 216 | return x, output_feature_map 217 | else: 218 | return x 219 | -------------------------------------------------------------------------------- /GPT-IE/GPT-IE.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import json 3 | import sys 4 | import numpy as np 5 | import time 6 | import argparse 7 | import openai 8 | 9 | prompts = [ 10 | "Now you need to use your medical knowledge to help me. I will give you some medical image diagnosis reports, and you need to extract disease-related information from the reports sentence by sentence. When extracting information, you need to follow six rules as follows: \n 1. Each piece of disease-related information extracted must meet the format: {descriptor of disease}{location of disease}{organ of disease}{category of disease}, for example, the sentence is 'New nodular opacities are clustered within the left upper lobe.', then the disease-related information extracted should be '{New}{left upper}{lobe}{nodular opacities}'.\n2. Some sentences may lack some information, in this case, you need to use {NA} instead, for example, if the sentence is 'There is left lung pneumonia.', then the disease-related information extracted should be {NA}{left}{lung}{pneumonia}.\n3. If the disease was negatively mentioned in the report, for example, if the sentence is 'There is no pneumothorax or left-sided pleural effusion.', then the disease-related information extracted should be '{No}{NA}{NA}{pneumothorax}' and '{No}{left-sided}{pleural}{effusion}'.\n4. There may be multiple disease descriptions in one sentence, you need to find them all and extract disease-related information.\n5. 
Ignore words irrelevant to disease description, for example, the sentence is 'The heart size is normal.', and there is no disease-related information, so you don't need to extract information from this sentence .\n6. Separate information with commas.", 11 | "Now you are a medical professional.Each report describes up to 13 diseases: Atelectasis, Cardiomegaly, Consolidation, Edema, Enlarged Cardiomediastinum, Fracture, Lung Lesion, Lung Opacity, Pleural Effusion, Pneumonia ,Pneumothorax, Pleural Other, Support Devices and No Finding. I will give you some reports, you need to label these 13 diseases, if the label is not 3, you need to provide the relevant text in the report to support the reason for this label. The rules for labeling are as follows: \nThe label has four values: 0,1,-1,3. These values have the following interpretation: \n1: The disease was positively mentioned in the report, for example, ‘A large pleural effusion’. 0: The disease was negatively mentioned in the report, for example, ‘No pneumothorax.’. -1: The disease was mentioned with uncertainty in the report, for example, ‘The cardiac size cannot be evaluated.’ , or mentioned with ambiguous language in the report and it is unclear if the pathology exists or not, for example, ‘The cardiac contours are stable.’. 3: No mention of the disease was made in the report.", 12 | "Now you are a medical professional.I will give you some reports, you need to label the diseases mentioned in the reports. The rules for labeling are as follows: \nThe label has three values: 0,1,-1. 
These values have the following interpretation: \n1: The disease was positively mentioned in the report, for example, ‘A large pleural effusion’.\n 0: The disease was negatively mentioned in the report, for example, ‘No pneumothorax.’.\n -1: The disease was mentioned with uncertainty in the report, for example, ‘The cardiac size cannot be evaluated.’ , or mentioned with ambiguous language in the report and it is unclear if the pathology exists or not, for example, ‘The cardiac contours are stable.’.\n You need to give the disease type and its label and the relevant text in the report to support the reason for the label, for example, 'pleural effusion:1 (Small pleural effusion in the right middle fissure is new)'", 13 | "Now you are a medical professional. I will give you some reports, you need to extract information from the report according to the following prompt template:{disease adjective}{disease location}{disease type}.For example,{acute}{heart}{heart consolidation}.Note that multiple diseases may be described in one report, you need to find them all and extract them with the prompt template.In addition, if 'disease adjective' or 'disease location' information is missing, use {NA} to indicate, but the 'disease type' shoule not be {NA}.If disease was negatively mentioned in the report, for example,‘No pneumothorax.’,the {disease adjective} should be {No}. Common types of diseases include: Engorgement,Consolidation, Opacity, Aerate, Deformity, Fractures, Thicken, Calcification, Aspiration, Pneumonia, Effusion, Pneumothorax. You should strictly follow the prompt template. 
One message per line, don't put multiple diseases in one message, ignore all text in the report except the disease description.I will use the information you provide for image classification, try to generate information that fits my task.", 14 | ] 15 | 16 | 17 | class SimChatGPT: 18 | 19 | def __init__(self, api_key: str, messages: List = None): 20 | openai.api_key = api_key 21 | if messages: 22 | self.messages = messages 23 | else: 24 | self.messages = [ 25 | {"role": "system", "content": "You are a helpful assistant."}, 26 | {"role": "user", "content": prompts[0]}, 27 | ] 28 | 29 | def ask(self) -> str: 30 | response = openai.ChatCompletion.create( 31 | model="gpt-3.5-turbo", 32 | # model="gpt-4", 33 | messages=self.messages, 34 | temperature=0, 35 | ) 36 | response_content = response["choices"][0]["message"]["content"] 37 | return response_content 38 | 39 | def predict(self, report: str) -> str: 40 | self.messages.append({"role": "user", "content": report}) 41 | response_content = self.ask() 42 | self.messages.pop() 43 | return response_content 44 | 45 | 46 | if __name__ == "__main__": 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument("group", type=str, help="p10-p19") 49 | args = parser.parse_args() 50 | # Reports for test 51 | reports = [ 52 | "The lung volumes are low. The cardiac, mediastinal and hilar contours appear unchanged, allowing for differences in technique. There are a number of round nodular densities projecting over each upper lung, but more numerous and discretely visualized in the left upper lobe, similar to prior study. However, in addition, there is a more hazy widespread opacity projecting over the left mid upper lung which could be compatible with a coinciding pneumonia. Pulmonary nodules in the left upper lobe are also not completely characterized on this study. There is no pleural effusion or pneumothorax. Post-operative changes are similar along the right chest wall.", 53 | "Lung volumes remain low. 
There are innumerable bilateral scattered small pulmonary nodules which are better demonstrated on recent CT. Mild pulmonary vascular congestion is stable. The cardiomediastinal silhouette and hilar contours are unchanged. Small pleural effusion in the right middle fissure is new. There is no new focal opacity to suggest pneumonia. There is no pneumothorax.", 54 | ] 55 | 56 | api_key = "" 57 | sim_chatgpt = SimChatGPT(api_key=api_key) 58 | print(sim_chatgpt.ask()) 59 | reports_path = "./reports/" + args.group + "/reports.npy" 60 | reports = np.load(reports_path, allow_pickle=True) 61 | results_path = "./results/" + args.group + "/results.npy" 62 | resume_path = "./results/" + args.group + "/resumes.npy" 63 | 64 | try: 65 | results = np.load(results_path, allow_pickle=True) 66 | except: 67 | try: 68 | results = np.load(resume_path, allow_pickle=True) 69 | except: 70 | results = np.array([]) 71 | 72 | for report in reports[len(results) :]: 73 | try: 74 | if len(results) % 100 == 0 and len(results) > 0: 75 | np.save(results_path, results) 76 | print(sim_chatgpt.ask()) 77 | if len(results) % 1000 == 0 and len(results) > 0: 78 | np.save(resume_path, results) 79 | 80 | contents = sim_chatgpt.predict(report["text"]) 81 | results = np.append( 82 | results, 83 | {"path": report["path"], "report": report["text"], "prompt": contents}, 84 | ) 85 | print(report["path"]) 86 | print(report["text"]) 87 | print(contents) 88 | print(len(results)) 89 | except Exception as e: 90 | print(f"Error occurred: {e}") 91 | print(report["path"]) 92 | if len(results) > 0: 93 | np.save(resume_path, results) 94 | break 95 | 96 | if len(results) > 0: 97 | np.save(results_path, results) 98 | if len(results) == len(reports): 99 | time.sleep(1000) 100 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | 
http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /downstream/segment.py: -------------------------------------------------------------------------------- 1 | # The code is from https://www.kaggle.com/code/nayem163/unet-with-se-resnet50-stage-2 2 | # The dataset is available at https://www.kaggle.com/datasets/vbookshelf/pneumothorax-chest-xray-images-and-masks/data 3 | import os 4 | import cv2 5 | import pdb 6 | import time 7 | import warnings 8 | import random 9 | import numpy as np 10 | import pandas as pd 11 | from tqdm import tqdm_notebook as tqdm 12 | from torch.optim.lr_scheduler import ReduceLROnPlateau 13 | from sklearn.model_selection import StratifiedKFold 14 | import torch 15 | import torch.nn as nn 16 | from torch.nn import functional as F 17 | import torch.optim as optim 18 | import torch.backends.cudnn as cudnn 19 | from torch.utils.data import DataLoader, Dataset, sampler 20 | from matplotlib import pyplot as plt 21 | import segmentation_models_pytorch as smp 22 | from albumentations import ( 23 | MultiplicativeNoise, 24 | HorizontalFlip, 25 | OpticalDistortion, 26 | 
VerticalFlip, 27 | GridDistortion, 28 | RandomBrightnessContrast, 29 | OneOf, 30 | ElasticTransform, 31 | RandomGamma, 32 | IAAEmboss, 33 | Blur, 34 | RandomRotate90, 35 | Transpose, 36 | ShiftScaleRotate, 37 | Normalize, 38 | Resize, 39 | Compose, 40 | GaussNoise, 41 | ) 42 | from albumentations.pytorch import ToTensorV2 43 | 44 | warnings.filterwarnings("ignore") 45 | import albumentations as A 46 | 47 | A.MultiplicativeNoise() 48 | 49 | import math 50 | import torch 51 | from torch.optim.optimizer import Optimizer, required 52 | 53 | 54 | class RAdam(Optimizer): 55 | 56 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 57 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 58 | self.buffer = [[None, None, None] for ind in range(10)] 59 | super(RAdam, self).__init__(params, defaults) 60 | 61 | def __setstate__(self, state): 62 | super(RAdam, self).__setstate__(state) 63 | 64 | def step(self, closure=None): 65 | 66 | loss = None 67 | if closure is not None: 68 | loss = closure() 69 | 70 | for group in self.param_groups: 71 | 72 | for p in group["params"]: 73 | if p.grad is None: 74 | continue 75 | grad = p.grad.data.float() 76 | if grad.is_sparse: 77 | raise RuntimeError("RAdam does not support sparse gradients") 78 | 79 | p_data_fp32 = p.data.float() 80 | 81 | state = self.state[p] 82 | 83 | if len(state) == 0: 84 | state["step"] = 0 85 | state["exp_avg"] = torch.zeros_like(p_data_fp32) 86 | state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) 87 | else: 88 | state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32) 89 | state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32) 90 | 91 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 92 | beta1, beta2 = group["betas"] 93 | 94 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 95 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 96 | 97 | state["step"] += 1 98 | buffered = self.buffer[int(state["step"] % 10)] 99 | if state["step"] == 
buffered[0]: 100 | N_sma, step_size = buffered[1], buffered[2] 101 | else: 102 | buffered[0] = state["step"] 103 | beta2_t = beta2 ** state["step"] 104 | N_sma_max = 2 / (1 - beta2) - 1 105 | N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t) 106 | buffered[1] = N_sma 107 | 108 | # more conservative since it's an approximated value 109 | if N_sma >= 5: 110 | step_size = ( 111 | group["lr"] 112 | * math.sqrt( 113 | (1 - beta2_t) 114 | * (N_sma - 4) 115 | / (N_sma_max - 4) 116 | * (N_sma - 2) 117 | / N_sma 118 | * N_sma_max 119 | / (N_sma_max - 2) 120 | ) 121 | / (1 - beta1 ** state["step"]) 122 | ) 123 | else: 124 | step_size = group["lr"] / (1 - beta1 ** state["step"]) 125 | buffered[2] = step_size 126 | 127 | if group["weight_decay"] != 0: 128 | p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32) 129 | 130 | # more conservative since it's an approximated value 131 | if N_sma >= 5: 132 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 133 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 134 | else: 135 | p_data_fp32.add_(-step_size, exp_avg) 136 | 137 | p.data.copy_(p_data_fp32) 138 | 139 | return loss 140 | 141 | 142 | class SIIMDataset(Dataset): 143 | def __init__(self, fnames, data_folder, size, mean, std, phase): 144 | self.root = data_folder 145 | self.size = size 146 | self.mean = mean 147 | self.std = std 148 | self.phase = phase 149 | self.transforms = get_transforms(phase, size, mean, std) 150 | self.fnames = fnames 151 | 152 | def __getitem__(self, idx): 153 | image_id = self.fnames[idx] 154 | image_path = self.root + "/png_images/" + image_id 155 | image = cv2.imread(image_path) 156 | mask_path = self.root + "/png_masks/" + image_id 157 | mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) 158 | mask = np.array((mask > 0), dtype=float) 159 | augmented = self.transforms(image=image, mask=mask) 160 | image = augmented["image"] 161 | mask = augmented["mask"] 162 | return image, mask 163 | 164 | def __len__(self): 165 | return 
len(self.fnames) 166 | 167 | 168 | def get_transforms(phase, size, mean, std): 169 | list_transforms = [] 170 | if phase == "train": 171 | list_transforms.extend( 172 | [ 173 | HorizontalFlip(p=0.5), 174 | ShiftScaleRotate( 175 | shift_limit=0, # no resizing 176 | scale_limit=0.1, 177 | rotate_limit=10, # rotate 178 | p=0.5, 179 | border_mode=cv2.BORDER_CONSTANT, 180 | ), 181 | GaussNoise(), 182 | A.MultiplicativeNoise(multiplier=1.5, p=1), 183 | ] 184 | ) 185 | list_transforms.extend( 186 | [ 187 | Resize(size, size), 188 | Normalize(mean=mean, std=std, p=1), 189 | ToTensorV2(), 190 | ] 191 | ) 192 | 193 | list_trfms = Compose(list_transforms) 194 | return list_trfms 195 | 196 | 197 | def provider( 198 | fold, 199 | total_folds, 200 | data_folder, 201 | df_path, 202 | phase, 203 | size, 204 | mean=None, 205 | std=None, 206 | batch_size=8, 207 | num_workers=2, 208 | ): 209 | df_all = pd.read_csv(df_path) 210 | df = df_all.drop_duplicates("new_filename") 211 | df_with_mask = df[df["has_pneumo"] == 1] 212 | df_with_mask["has_mask"] = 1 213 | df_without_mask = df[df["has_pneumo"] == 0] 214 | df_without_mask["has_mask"] = 0 215 | df_without_mask_sampled = df_without_mask.sample( 216 | len(df_with_mask) + 1500, random_state=2019 217 | ) # random state is imp 218 | df = pd.concat([df_with_mask, df_without_mask_sampled]) 219 | 220 | # NOTE: equal number of positive and negative cases are chosen. 
221 | 222 | kfold = StratifiedKFold(total_folds, shuffle=True, random_state=43) 223 | train_idx, val_idx = list(kfold.split(df["ImageId"], df["has_mask"]))[fold] 224 | train_df, val_df = df.iloc[train_idx], df.iloc[val_idx] 225 | df = train_df if phase == "train" else val_df 226 | # NOTE: total_folds=5 -> train/val : 80%/20% 227 | 228 | fnames = df["new_filename"].values 229 | 230 | image_dataset = SIIMDataset(fnames, data_folder, size, mean, std, phase) 231 | 232 | dataloader = DataLoader( 233 | image_dataset, 234 | batch_size=batch_size, 235 | num_workers=num_workers, 236 | pin_memory=True, 237 | shuffle=True, 238 | ) 239 | return dataloader 240 | 241 | 242 | def dice_loss(input, target): 243 | input = torch.sigmoid(input) 244 | smooth = 1.0 245 | iflat = input.view(-1) 246 | tflat = target.view(-1) 247 | # tflat = np.reshape(target,(4, 1, 512, 512)) 248 | intersection = (iflat * tflat).sum() 249 | return (2.0 * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth) 250 | 251 | 252 | class FocalLoss(nn.Module): 253 | def __init__(self, gamma): 254 | super().__init__() 255 | self.gamma = gamma 256 | 257 | def forward(self, input, target): 258 | if not (target.size() == input.size()): 259 | raise ValueError( 260 | "Target size ({}) must be the same as input size ({})".format( 261 | target.size(), input.size() 262 | ) 263 | ) 264 | max_val = (-input).clamp(min=0) 265 | loss = ( 266 | input 267 | - input * target 268 | + max_val 269 | + ((-max_val).exp() + (-input - max_val).exp()).log() 270 | ) 271 | invprobs = F.logsigmoid(-input * (target * 2.0 - 1.0)) 272 | loss = (invprobs * self.gamma).exp() * loss 273 | return loss.mean() 274 | 275 | 276 | class MixedLoss(nn.Module): 277 | def __init__(self, alpha, gamma): 278 | super().__init__() 279 | self.alpha = alpha 280 | self.focal = FocalLoss(gamma) 281 | 282 | def forward(self, input, target): 283 | loss = self.alpha * self.focal(input, target) - torch.log( 284 | dice_loss(input, target) 285 | ) 286 | return 
loss.mean() 287 | 288 | 289 | def predict(X, threshold): 290 | X_p = np.copy(X) 291 | preds = (X_p > threshold).astype("uint8") 292 | return preds 293 | 294 | 295 | def metric(probability, truth, threshold=0.9, reduction="none"): 296 | """Calculates dice of positive and negative images seperately""" 297 | """probability and truth must be torch tensors""" 298 | batch_size = len(truth) 299 | with torch.no_grad(): 300 | probability = probability.view(batch_size, -1) 301 | truth = truth.view(batch_size, -1) 302 | assert probability.shape == truth.shape 303 | 304 | p = (probability > threshold).float() 305 | t = (truth > 0.5).float() 306 | 307 | t_sum = t.sum(-1) 308 | p_sum = p.sum(-1) 309 | neg_index = torch.nonzero(t_sum == 0) 310 | pos_index = torch.nonzero(t_sum >= 1) 311 | 312 | dice_neg = (p_sum == 0).float() 313 | dice_pos = 2 * (p * t).sum(-1) / ((p + t).sum(-1)) 314 | 315 | dice_neg = dice_neg[neg_index] 316 | dice_pos = dice_pos[pos_index] 317 | dice = torch.cat([dice_pos, dice_neg]) 318 | 319 | dice_neg = np.nan_to_num(dice_neg.mean().item(), 0) 320 | dice_pos = np.nan_to_num(dice_pos.mean().item(), 0) 321 | dice = dice.mean().item() 322 | 323 | num_neg = len(neg_index) 324 | num_pos = len(pos_index) 325 | 326 | return dice, dice_neg, dice_pos, num_neg, num_pos 327 | 328 | 329 | def compute_ious(pred, label, classes, ignore_index=255, only_present=True): 330 | """computes iou for one ground truth mask and predicted mask""" 331 | pred[label == ignore_index] = 0 332 | ious = [] 333 | for c in classes: 334 | label_c = label == c 335 | if only_present and np.sum(label_c) == 0: 336 | ious.append(np.nan) 337 | continue 338 | pred_c = pred == c 339 | intersection = np.logical_and(pred_c, label_c).sum() 340 | union = np.logical_or(pred_c, label_c).sum() 341 | if union != 0: 342 | ious.append(intersection / union) 343 | return ious if ious else [1] 344 | 345 | 346 | def compute_iou_batch(outputs, labels, classes=None): 347 | """computes mean iou for a batch of ground 
truth masks and predicted masks""" 348 | ious = [] 349 | preds = np.copy(outputs) # copy is imp 350 | labels = np.array(labels) # tensor to np 351 | for pred, label in zip(preds, labels): 352 | ious.append(np.nanmean(compute_ious(pred, label, classes))) 353 | iou = np.nanmean(ious) 354 | return iou 355 | 356 | 357 | class Meter: 358 | """A meter to keep track of iou and dice scores throughout an epoch""" 359 | 360 | def __init__(self, phase, epoch): 361 | self.base_threshold = 0.55 # <<<<<<<<<<< here's the threshold 362 | self.base_dice_scores = [] 363 | self.dice_neg_scores = [] 364 | self.dice_pos_scores = [] 365 | self.iou_scores = [] 366 | 367 | def update(self, targets, outputs): 368 | probs = torch.sigmoid(outputs) 369 | dice, dice_neg, dice_pos, _, _ = metric(probs, targets, self.base_threshold) 370 | self.base_dice_scores.append(dice) 371 | self.dice_pos_scores.append(dice_pos) 372 | self.dice_neg_scores.append(dice_neg) 373 | preds = predict(probs, self.base_threshold) 374 | iou = compute_iou_batch(preds, targets, classes=[1]) 375 | self.iou_scores.append(iou) 376 | 377 | def get_metrics(self): 378 | dice = np.mean(self.base_dice_scores) 379 | dice_neg = np.mean(self.dice_neg_scores) 380 | dice_pos = np.mean(self.dice_pos_scores) 381 | dices = [dice, dice_neg, dice_pos] 382 | iou = np.nanmean(self.iou_scores) 383 | return dices, iou 384 | 385 | 386 | def epoch_log(phase, epoch, epoch_loss, meter, start): 387 | """logging the metrics at the end of an epoch""" 388 | dices, iou = meter.get_metrics() 389 | dice, dice_neg, dice_pos = dices 390 | print( 391 | "Loss: %0.4f | dice: %0.4f | dice_neg: %0.4f | dice_pos: %0.4f | IoU: %0.4f" 392 | % (epoch_loss, dice, dice_neg, dice_pos, iou) 393 | ) 394 | return dice, iou 395 | 396 | 397 | class Trainer(object): 398 | """This class takes care of training and validation of our model""" 399 | 400 | def __init__(self, model, df_path, data_folder, device): 401 | self.fold = 1 402 | self.total_folds = 5 403 | 
self.num_workers = 4 404 | self.batch_size = {"train": 4, "val": 4} 405 | self.accumulation_steps = 32 // self.batch_size["train"] 406 | self.lr = 5e-4 407 | self.num_epochs = 32 408 | self.best_loss = float("inf") 409 | self.phases = ["train", "val"] 410 | self.device = device 411 | # torch.set_default_tensor_type("torch.cuda.FloatTensor") 412 | self.net = model 413 | self.criterion = MixedLoss(2.0, 2.0) 414 | # self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr) 415 | self.optimizer = RAdam(model.parameters(), lr=self.lr) 416 | 417 | self.scheduler = ReduceLROnPlateau( 418 | self.optimizer, mode="min", patience=3, verbose=True 419 | ) 420 | self.net = self.net.to(self.device) 421 | cudnn.benchmark = True 422 | self.dataloaders = { 423 | phase: provider( 424 | fold=0, 425 | total_folds=5, 426 | data_folder=data_folder, 427 | df_path=df_path, 428 | phase="train", 429 | size=512, 430 | mean=(0.493, 0.493, 0.493), 431 | std=(0.250, 0.250, 0.250), 432 | batch_size=4, 433 | num_workers=2, 434 | ) 435 | for phase in self.phases 436 | } 437 | self.losses = {phase: [] for phase in self.phases} 438 | self.iou_scores = {phase: [] for phase in self.phases} 439 | self.dice_scores = {phase: [] for phase in self.phases} 440 | 441 | def forward(self, images, targets): 442 | images = images.to(self.device) 443 | masks = targets.to(self.device) 444 | outputs = self.net(images) 445 | loss = self.criterion(outputs, masks) 446 | return loss, outputs 447 | 448 | def iterate(self, epoch, phase): 449 | meter = Meter(phase, epoch) 450 | start = time.strftime("%H:%M:%S") 451 | print(f"Starting epoch: {epoch} | phase: {phase} | ⏰: {start}") 452 | batch_size = self.batch_size[phase] 453 | self.net.train(phase == "train") 454 | dataloader = self.dataloaders[phase] 455 | running_loss = 0.0 456 | total_batches = len(dataloader) 457 | # tk0 = tqdm(dataloader, total=total_batches) 458 | self.optimizer.zero_grad() 459 | for itr, batch in enumerate(dataloader): 460 | images, targets = 
batch 461 | targets = targets.unsqueeze(1) 462 | loss, outputs = self.forward(images, targets) 463 | loss = loss / self.accumulation_steps 464 | if phase == "train": 465 | loss.backward() 466 | if (itr + 1) % self.accumulation_steps == 0: 467 | self.optimizer.step() 468 | self.optimizer.zero_grad() 469 | running_loss += loss.item() 470 | outputs = outputs.detach().cpu() 471 | meter.update(targets, outputs) 472 | # tk0.set_postfix(loss=(running_loss / ((itr + 1)))) 473 | epoch_loss = (running_loss * self.accumulation_steps) / total_batches 474 | dice, iou = epoch_log(phase, epoch, epoch_loss, meter, start) 475 | self.losses[phase].append(epoch_loss) 476 | self.dice_scores[phase].append(dice) 477 | self.iou_scores[phase].append(iou) 478 | torch.cuda.empty_cache() 479 | return epoch_loss 480 | 481 | def start(self): 482 | for epoch in range(self.num_epochs): 483 | self.iterate(epoch, "train") 484 | state = { 485 | "epoch": epoch, 486 | "best_loss": self.best_loss, 487 | "state_dict": self.net.state_dict(), 488 | "optimizer": self.optimizer.state_dict(), 489 | } 490 | val_loss = self.iterate(epoch, "val") 491 | self.scheduler.step(val_loss) 492 | if val_loss < self.best_loss: 493 | print("******** New optimal found, saving state ********") 494 | state["best_loss"] = self.best_loss = val_loss 495 | torch.save(state, "./model.pth") 496 | print() 497 | 498 | 499 | from models import UNetWithResnet50Encoder 500 | 501 | 502 | df_path = "~/data/siim-acr-pneumothorax/stage_1_train_images.csv" 503 | data_folder = "~/data/siim-acr-pneumothorax" 504 | 505 | 506 | device = torch.device("cuda:{}".format(6) if torch.cuda.is_available() else "cpu") 507 | ckpt_path = "~/MedFILIP/train/pretrained/image_model.pt" 508 | checkpoints = torch.load(ckpt_path, map_location="cpu") 509 | model = UNetWithResnet50Encoder(weights=checkpoints["image_model"]).to(device) 510 | model = UNetWithResnet50Encoder().to(device) 511 | 512 | 513 | print(model) 514 | model_trainer = Trainer(model, df_path, 
data_folder, device) 515 | model_trainer.start() 516 | -------------------------------------------------------------------------------- /train/constants.py: -------------------------------------------------------------------------------- 1 | view_position = ["AP", "PA", "LATERAL", "LL", ""] 2 | disease_location_organ = [ 3 | "costophrenic sinus", 4 | "left base", 5 | "ventricle", 6 | "mediastinal contour", 7 | "left basal", 8 | "right lung apex", 9 | "hemidiaphragms", 10 | "bronchovascular", 11 | "lobe", 12 | "parenchymal", 13 | "hila", 14 | "stomach", 15 | "bony structures", 16 | "right upper lobe", 17 | "vein", 18 | "pulmonary vascularity", 19 | "vessels", 20 | "aortic", 21 | "aortic valve", 22 | "thoracic", 23 | "pectoral", 24 | "left hemidiaphragm", 25 | "bony", 26 | "hemidiaphragmatic contour", 27 | "right atrium", 28 | "both lower", 29 | "apical", 30 | "aortic knob", 31 | "displaced", 32 | "abnormalities", 33 | "mediastinal and hilar contours", 34 | "left chest wall", 35 | "right hemidiaphragm", 36 | "distal", 37 | "left apical", 38 | "tube", 39 | "basilar", 40 | "low", 41 | "right apex", 42 | "port-a-cath", 43 | "et tube", 44 | "silhouette", 45 | "nasogastric tube", 46 | "right middle lobe", 47 | "vasculature", 48 | "pleura", 49 | "thoracic spine", 50 | "upper zone", 51 | "surfaces", 52 | "picc line", 53 | "diaphragm", 54 | "silhouettes", 55 | "subcutaneous", 56 | "pulmonary vasculature", 57 | "pleural", 58 | "mediastinal", 59 | "osseous", 60 | "pulmonary artery", 61 | "right hilar", 62 | "alveolar", 63 | "lower thoracic", 64 | "lateral", 65 | "right lower lobe", 66 | "lingula", 67 | "aorta", 68 | "venous", 69 | "right apical", 70 | "right chest wall", 71 | "bilateral lower", 72 | "upper", 73 | "right infrahilar", 74 | "cardiopulmonary", 75 | "hiatal", 76 | "perihilar", 77 | "contour", 78 | "lower", 79 | "carinal", 80 | "focal", 81 | "mid thoracic", 82 | "vertebral body", 83 | "central", 84 | "both", 85 | "costophrenic angles", 86 | "region", 87 | "chest", 
88 | "left mid and lower", 89 | "interstitial", 90 | "internal jugular", 91 | "lymphadenopathy", 92 | "left upper", 93 | "airspace", 94 | "multifocal", 95 | "lung base", 96 | "pulmonary venous", 97 | "hilar", 98 | "ribs", 99 | "left lower lobe", 100 | "base", 101 | "lung", 102 | "right basal", 103 | "cardiomediastinal silhouette", 104 | "right middle and lower", 105 | "right lower", 106 | "catheter", 107 | "bilateral", 108 | "trachea", 109 | "hemi thorax", 110 | "right lateral", 111 | "chest wall", 112 | "bone", 113 | "contours", 114 | "descending", 115 | "overt", 116 | "right lung", 117 | "bibasal", 118 | "veins", 119 | "mid and lower", 120 | "endotracheal tube", 121 | "retrocardiac", 122 | "humerus", 123 | "left retrocardiac", 124 | "sternal wires", 125 | "right mid lung", 126 | "tracheostomy tube", 127 | "left", 128 | "subclavian", 129 | "left lung base", 130 | "left subclavian", 131 | "abdomen", 132 | "descending aorta", 133 | "right upper quadrant", 134 | "left mid", 135 | "carina", 136 | "lower lobes", 137 | "intrathoracic", 138 | "apex", 139 | "lungs", 140 | "right-sided", 141 | "cavoatrial junction", 142 | "middle", 143 | "lower lobe", 144 | "spine", 145 | "bones", 146 | "right upper", 147 | "mediastinal and hilar", 148 | "na", 149 | "adenopathy", 150 | "hilus", 151 | "clavicle", 152 | "upper lobe", 153 | "thoracic aorta", 154 | "mid", 155 | "right", 156 | "cardiac silhouette", 157 | "right perihilar", 158 | "cardiac", 159 | "right ventricle", 160 | "line", 161 | "lingular", 162 | "left perihilar", 163 | "adjacent", 164 | "underlying", 165 | "right base", 166 | "hilum", 167 | "aortic arch", 168 | "pericardial", 169 | "right basilar", 170 | "left-sided", 171 | "pulmonary", 172 | "pulmonary arteries", 173 | "osseous structures", 174 | "lobes", 175 | "shoulder", 176 | "right mid and lower", 177 | "zone", 178 | "left lung", 179 | "mid svc", 180 | "subdiaphragmatic", 181 | "mediastinum", 182 | "posterior", 183 | "vascular", 184 | "left mid lung", 185 | 
"thoracolumbar junction", 186 | "rib", 187 | "lung bases", 188 | "cardiomediastinal", 189 | "pulmonary venous pressure", 190 | "basal", 191 | "heart", 192 | "ng tube", 193 | "areas", 194 | "biapical", 195 | "right lung bases", 196 | "hemithorax", 197 | "left lateral", 198 | "atrium", 199 | "hilar and mediastinal", 200 | "left lower", 201 | "upper abdomen", 202 | "left upper quadrant", 203 | "pulmonary vascular", 204 | "bases", 205 | "lymph nodes", 206 | "both bases", 207 | "pericardium", 208 | "costophrenic angle", 209 | "mediastinal contours", 210 | "right lung base", 211 | "infrahilar", 212 | "bibasilar", 213 | "left pectoral", 214 | "soft tissues", 215 | "right middle", 216 | "esophagus", 217 | "lung apices", 218 | "pulmonary vessels", 219 | "hemidiaphragm", 220 | "right mid", 221 | "right internal jugular", 222 | "cardiac and mediastinal", 223 | "left basilar", 224 | "lymph node", 225 | "cardiomediastinal and hilar", 226 | "lung parenchyma", 227 | "both lung bases", 228 | ] 229 | disease_adjective = [ 230 | "concurrent", 231 | "vague", 232 | "probably", 233 | "dilated", 234 | "biapical", 235 | "small if any", 236 | "less prominent", 237 | "widening", 238 | "bibasal", 239 | "resolving", 240 | "consistent with", 241 | "subsequent", 242 | "smaller", 243 | "decreasing", 244 | "minimally increased", 245 | "postoperative", 246 | "healed", 247 | "worse", 248 | "worrisome", 249 | "significant", 250 | "potentially", 251 | "old healed", 252 | "asymmetrical", 253 | "scarring", 254 | "similar", 255 | "complete", 256 | "little", 257 | "acute", 258 | "prior", 259 | "moderate-to-severe", 260 | "perihilar", 261 | "loculated", 262 | "progressive", 263 | "possibility of", 264 | "basilar", 265 | "slightly improved", 266 | "infectious", 267 | "ill-defined", 268 | "lungs", 269 | "central", 270 | "supervening", 271 | "subsegmental", 272 | "massive", 273 | "hyperinflation", 274 | "partial", 275 | "crowding", 276 | "some", 277 | "moderate to large", 278 | "minor", 279 | "tortuosity", 
280 | "vascular", 281 | "degenerative", 282 | "subcutaneous", 283 | "consolidation", 284 | "heterogeneous", 285 | "right lower", 286 | "presumed", 287 | "slightly increased", 288 | "moderately severe", 289 | "mediastinal", 290 | "small-to-moderate", 291 | "standard", 292 | "atelectatic", 293 | "upper", 294 | "concerning", 295 | "diffuse bilateral", 296 | "hazy", 297 | "interstitial", 298 | "top-normal", 299 | "moderately", 300 | "infection", 301 | "pre-existing", 302 | "associated", 303 | "blunting", 304 | "layering", 305 | "streaky", 306 | "left-sided", 307 | "moderately enlarged", 308 | "early", 309 | "dense", 310 | "widespread", 311 | "right-sided", 312 | "intact", 313 | "old", 314 | "interval improvement", 315 | "tiny", 316 | "improvement", 317 | "tortuous", 318 | "decrease", 319 | "left lower", 320 | "elevated", 321 | "elevation", 322 | "unremarkable", 323 | "interval increase", 324 | "moderate to severe", 325 | "previous", 326 | "marked", 327 | "pleural", 328 | "asymmetric", 329 | "small to moderate", 330 | "hyperinflated", 331 | "increase", 332 | "prominence", 333 | "resolved", 334 | "focal", 335 | "developing", 336 | "prominent", 337 | "improving", 338 | "volume loss", 339 | "worsened", 340 | "mild-to-moderate", 341 | "interval", 342 | "residual", 343 | "adjacent", 344 | "aspiration", 345 | "compressive", 346 | "retrocardiac", 347 | "lower", 348 | "constant", 349 | "linear", 350 | "known", 351 | "underlying", 352 | "subtle", 353 | "slight", 354 | "multiple", 355 | "likely", 356 | "trace", 357 | "continued", 358 | "mild to moderate", 359 | "calcified", 360 | "mildly enlarged", 361 | "mildly", 362 | "probable", 363 | "multifocal", 364 | "superimposed", 365 | "enlargement", 366 | "extensive", 367 | "diffuse", 368 | "increasing", 369 | "clear", 370 | "patchy", 371 | "substantial", 372 | "chronic", 373 | "enlarged", 374 | "borderline", 375 | "worsening", 376 | "large", 377 | "possible", 378 | "decreased", 379 | "bibasilar", 380 | "severe", 381 | "improved", 382 
| "persistent", 383 | "bilateral", 384 | "left", 385 | "right", 386 | "increased", 387 | "minimal", 388 | "stable", 389 | "low", 390 | "normal", 391 | "new", 392 | "unchanged", 393 | "moderate", 394 | "small", 395 | "mild", 396 | ] 397 | disease_type = [ 398 | "ascites", 399 | "pulmonary embolism", 400 | "pleural fluid", 401 | "pseudoaneurysm", 402 | "pulmonary nodules", 403 | "pneumoperitoneum", 404 | "dislocation", 405 | "inflammation", 406 | "hyperinflation", 407 | "tracheostomy tube", 408 | "pulmonary fibrosis", 409 | "infectious process", 410 | "cardiac enlargement", 411 | "bronchiectasis", 412 | "lesion", 413 | "cyst", 414 | "abscess", 415 | "embolism", 416 | "pneumomediastinum", 417 | "density", 418 | "granuloma", 419 | "adenopathy", 420 | "pacemaker", 421 | "subcutaneous emphysema", 422 | "hemorrhage", 423 | "dilatation", 424 | "nodules", 425 | "infiltrate", 426 | "aneurysm", 427 | "tension", 428 | "interstitial markings", 429 | "scoliosis", 430 | "fibrosis", 431 | "copd", 432 | "hernia", 433 | "hematoma", 434 | "hiatal hernia", 435 | "lymphadenopathy", 436 | "vascular engorgement", 437 | "occlusion", 438 | "thrombosis", 439 | "engorgement", 440 | "aeration", 441 | "nodule", 442 | "calcification", 443 | "scarring", 444 | "infection", 445 | "fracture", 446 | "emphysema", 447 | "pneumothoraces", 448 | "abnormalities", 449 | "vascular congestion", 450 | "opacification", 451 | "consolidation", 452 | "opacity", 453 | "cardiomegaly", 454 | "edema", 455 | "pneumonia", 456 | "atelectasis", 457 | "effusion", 458 | "pneumothorax", 459 | ] 460 | description_book = { 461 | "ascites": "May show elevation of the diaphragm or a fluid wave on ultrasound due to accumulation of fluid in the abdominal cavity.", 462 | "pulmonary embolism": "Can cause areas of the lung to appear darker due to infarction or a wedge-shaped (Hampton hump) opacity; may also cause enlargement of pulmonary arteries.", 463 | "pleural fluid": "Appears as a homogenous density, usually in the lower lung 
fields or as a meniscus sign along the chest wall, indicating fluid accumulation in the pleural space.", 464 | "pseudoaneurysm": "An outpouching of an arterial wall that can appear as a rounded or saccular density adjacent to blood vessels, often requiring other imaging modalities to confirm.", 465 | "pulmonary nodules": "Small, rounded opacities within the lung parenchyma, which can vary in size and number.", 466 | "pneumoperitoneum": "Free air in the abdominal cavity that can accumulate under the diaphragm, appearing as a sharp line of lucency on an upright X-ray.", 467 | "dislocation": "Misalignment of bone joints, usually visible in the shoulder or neck area, where bones are not in their normal positions.", 468 | "inflammation": "May manifest as increased opacity due to fluid accumulation or swelling in the lung tissue or surrounding structures.", 469 | "hyperinflation": "Over-expanded lung fields, flattened diaphragms, and increased retrosternal air space, commonly seen in obstructive lung diseases.", 470 | "tracheostomy tube": "A radiopaque line or shadow within the trachea indicating the presence of a tracheostomy tube used for ventilation.", 471 | "pulmonary fibrosis": "Reticular markings, honeycombing, and possibly traction bronchiectasis, indicating stiff, scarred lung tissue.", 472 | "infectious process": "Variable appearance, from focal to diffuse opacities, consolidation, or cavitation, depending on the type and extent of the infection.", 473 | "cardiac enlargement": "Enlarged cardiac silhouette exceeding the normal size limits for the heart on the chest X-ray.", 474 | "bronchiectasis": "Abnormally dilated and thick-walled bronchi that can appear as ring-like or tubular structures on the X-ray.", 475 | "lesion": "A general term for an abnormal area; on an X-ray, it appears as an area of increased density or opacity.", 476 | "cyst": "A fluid-filled space that can appear as a round, well-defined lucency or opacity, depending on its content and wall 
thickness.", 477 | "abscess": "Appears as a round or ovoid density, often with a fluid level indicating the presence of pus in a new or pre-existing cavity.", 478 | "embolism": "A blockage in a blood vessel that may not be directly visible on X-ray but could lead to areas of increased or decreased opacity in the lung.", 479 | "pneumomediastinum": "Free air in the mediastinal space, appearing as linear or streaky lucencies outlining the mediastinal structures.", 480 | "density": "A region that appears more white on the X-ray film, indicating something denser than air, such as fluid, bone, or mass.", 481 | "granuloma": "A small, localized, rounded density often resulting from past inflammation or infection.", 482 | "adenopathy": "Enlarged lymph nodes that appear as rounded opacities, typically in the mediastinal or hilar regions.", 483 | "pacemaker": "Visible as wires and a generator case that are radiopaque, usually seen in the left upper chest area.", 484 | "subcutaneous emphysema": "Air in the subcutaneous tissues that appears as streaky or bubbly lucencies under the skin, often around the neck or chest wall.", 485 | "hemorrhage": "Depending on its location, may cause increased opacity due to blood in the lung tissue or an air-fluid level if in a body cavity.", 486 | "dilatation": "Enlargement of airways or blood vessels, appearing wider than normal on the X-ray.", 487 | "nodules": "Similar to pulmonary nodules, these are small, rounded densities within the lung fields.", 488 | "infiltrate": "A diffuse area of increased opacity in the lung, suggesting inflammation, infection, or other processes affecting the lung parenchyma.", 489 | "aneurysm": "Localized dilatation of a blood vessel that may appear as an abnormal rounded or oval shadow, but is more clearly seen on CT or MRI.", 490 | "tension": "This term is not typically descriptive of an X-ray finding, but in the context of tension pneumothorax, it would show a collapsed lung and shift of mediastinal structures 
away from the affected side due to high intrapleural pressure.", 491 | "interstitial markings": "Normally invisible lines and dots representing the lung support structure become more prominent, looking like a network or mesh of lines crisscrossing the lung fields.", 492 | "scoliosis": "A sideways curvature of the spine that can make the ribs appear uneven on an X-ray and may cause the lung fields to look asymmetrical.", 493 | "fibrosis": "Scarred and stiff lung tissue that appears denser (whiter) than healthy tissue, often with a reticular pattern or honeycombing in advanced cases.", 494 | "copd": "Over-expanded lungs with flattened diaphragms and large, dark (lucent) areas that indicate air trapping.", 495 | "hernia": "An abnormal protrusion of an organ or tissue that may show up as an unexpected mass or bulge in the areas adjacent to the diaphragm.", 496 | "hematoma": "A collection of blood outside blood vessels that appears as a localized, denser (whiter) area on X-ray.", 497 | "hiatal hernia": "Part of the stomach protrudes upward through the diaphragm and may appear as an abnormal shadow above the diaphragm.", 498 | "lymphadenopathy": "Enlarged lymph nodes that may show up as rounded, denser (whiter) spots or clusters, usually near the center of the chest in the mediastinal or hilar regions.", 499 | "vascular engorgement": "Blood vessels, especially those leading to the heart, look more prominent and wider than usual, indicating increased blood flow or volume.", 500 | "occlusion": "A blockage in a blood vessel or hollow organ that can result in a lack of blood flow and a corresponding area of increased opacity if infarction occurs.", 501 | "thrombosis": "A blood clot within a vessel that may not be directly visible but could cause a region of the lung to appear abnormal due to decreased blood flow.", 502 | "engorgement": "Similar to vascular engorgement, where the vessels, particularly around the heart and lungs, appear fuller and more pronounced.", 503 | 
"aeration": "A term for how well the lungs are filled with air; poorly aerated lungs may look more solid (whiter), while well-aerated lungs are dark (lucent).", 504 | "nodule": "A small, rounded shadow that stands out from the surrounding lung tissue; can be solitary or multiple.", 505 | "calcification": "A white, dense spot or area within a nodule or other tissue indicating deposition of calcium.", 506 | "scarring": "Appears as irregular lines or bands of denser tissue, often as a result of past inflammation or injury.", 507 | "infection": "Can vary but often shows up as areas of consolidation or opacity where the lung tissue is filled with fluid or pus.", 508 | "fracture": "Breaks in bones that appear as dark lines across the white bone structure on X-ray.", 509 | "emphysema": "Areas of the lung that look unusually clear (dark on X-ray) due to the destruction of lung tissue and air sacs, with a decrease in vascular markings.", 510 | "pneumothoraces": "Presence of air in the pleural space that shows up as a clear space (dark area) between the lung and the chest wall.", 511 | "abnormalities": "General term for any unusual findings on an X-ray which can include unexpected shapes, sizes, or densities in the lung fields.", 512 | "vascular congestion": "The blood vessels in the lung appear more prominent and numerous, often associated with heart failure.", 513 | "opacification": "Areas that appear white or grayish and cannot be seen through, suggesting the presence of fluid, cells, or other material.", 514 | "consolidation": "A region of lung tissue that has filled with liquid instead of air, appearing as a uniform area of increased whiteness.", 515 | "opacity": "Any area that appears whiter or more solid than it should, indicating something is blocking the passage of X-rays.", 516 | "cardiomegaly": "An enlarged heart silhouette that takes up more space than normal on the X-ray film.", 517 | "edema": "Excess fluid in the lung tissue, which may present as a diffuse 
haziness across the lung fields.", 518 | "pneumonia": "An infection causing consolidation, often seen as patchy or confluent white areas within the lung fields.", 519 | "atelectasis": "Collapse of lung tissue, showing as streaky opacities or anarea of increased density (whiteness) on the X-ray, sometimes with a shift of the surrounding structures.", 520 | "effusion": "An accumulation of fluid between the lung and chest wall that appears as a homogenous, dense area, often with a meniscus at the lung base.", 521 | "pneumothorax": "Air in the pleural space causing part of the lung to collapse, visible as a sharp line with no lung markings beyond it and a clear space where the lung has collapsed.", 522 | } 523 | -------------------------------------------------------------------------------- /GPT-IE/post_process.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import numpy as np 3 | import re 4 | import os 5 | import copy 6 | from collections import Counter 7 | 8 | disease_type = [ 9 | "ascites", 10 | "pulmonary embolism", 11 | "patchy opacity", 12 | "pleural fluid", 13 | "nodular opacity", 14 | "pseudoaneurysm", 15 | "pulmonary nodules", 16 | "pneumoperitoneum", 17 | "dislocation", 18 | "inflammation", 19 | "hyperinflation", 20 | "tracheostomy tube", 21 | "pulmonary fibrosis", 22 | "infectious process", 23 | "patchy opacities", 24 | "cardiac enlargement", 25 | "bronchiectasis", 26 | "lesion", 27 | "cyst", 28 | "abscess", 29 | "embolism", 30 | "pneumomediastinum", 31 | "nodular opacities", 32 | "interstitial opacities", 33 | "density", 34 | "granuloma", 35 | "large effusion", 36 | "adenopathy", 37 | "pacemaker", 38 | "interstitial pulmonary edema", 39 | "subcutaneous emphysema", 40 | "parenchymal opacities", 41 | "hemorrhage", 42 | "dilatation", 43 | "nodules", 44 | "opacifications", 45 | "infiltrate", 46 | "aneurysm", 47 | "tension", 48 | "interstitial markings", 49 | "scoliosis", 50 | "fibrosis", 51 | "copd", 52 
| "hernia", 53 | "hematoma", 54 | "hiatal hernia", 55 | "lymphadenopathy", 56 | "vascular engorgement", 57 | "occlusion", 58 | "airspace opacities", 59 | "thrombosis", 60 | "consolidations", 61 | "engorgement", 62 | "aeration", 63 | "nodule", 64 | "calcification", 65 | "calcifications", 66 | "scarring", 67 | "infection", 68 | "pulmonary vascular congestion", 69 | "fractures", 70 | "fracture", 71 | "pleural effusions", 72 | "emphysema", 73 | "pneumothoraces", 74 | "abnormalities", 75 | "congestion", 76 | "focal consolidation", 77 | "vascular congestion", 78 | "opacification", 79 | "pleural effusion", 80 | "consolidation", 81 | "opacities", 82 | "opacity", 83 | "edema", 84 | "cardiomegaly", 85 | "pulmonary edema", 86 | "pneumonia", 87 | "effusions", 88 | "atelectasis", 89 | "effusion", 90 | "pneumothorax", 91 | ] 92 | disease_type_ = [ 93 | [ 94 | "effusion", 95 | "effusions", 96 | "pleural effusion", 97 | "pleural effusions", 98 | "large effusion", 99 | ], 100 | ["edema", "pulmonary edema", "interstitial pulmonary edema"], 101 | [ 102 | "opacity", 103 | "nodular opacity", 104 | "patchy opacity", 105 | "patchy opacities", 106 | "opacities", 107 | "airspace opacities", 108 | "nodular opacities", 109 | "interstitial opacities", 110 | "parenchymal opacities", 111 | ], 112 | ["opacification", "opacifications"], 113 | ["consolidation", "consolidations", "focal consolidation"], 114 | ["vascular congestion", "congestion", "pulmonary vascular congestion"], 115 | ["fracture", "fractures"], 116 | ["calcification", "calcifications"], 117 | ] 118 | disease_location_organ = [ 119 | "costophrenic sinus", 120 | "left base", 121 | "ventricle", 122 | "mediastinal contour", 123 | "left basal", 124 | "right lung apex", 125 | "hemidiaphragms", 126 | "bronchovascular", 127 | "lobe", 128 | "parenchymal", 129 | "hila", 130 | "stomach", 131 | "bony structures", 132 | "right upper lobe", 133 | "vein", 134 | "pulmonary vascularity", 135 | "vessels", 136 | "aortic", 137 | "aortic valve", 138 | 
"thoracic", 139 | "pectoral", 140 | "left hemidiaphragm", 141 | "bony", 142 | "hemidiaphragmatic contour", 143 | "right atrium", 144 | "both lower", 145 | "apical", 146 | "aortic knob", 147 | "displaced", 148 | "abnormalities", 149 | "mediastinal and hilar contours", 150 | "left chest wall", 151 | "right hemidiaphragm", 152 | "distal", 153 | "left apical", 154 | "tube", 155 | "basilar", 156 | "low", 157 | "right apex", 158 | "port-a-cath", 159 | "et tube", 160 | "silhouette", 161 | "nasogastric tube", 162 | "right middle lobe", 163 | "vasculature", 164 | "pleura", 165 | "thoracic spine", 166 | "upper zone", 167 | "surfaces", 168 | "picc line", 169 | "diaphragm", 170 | "silhouettes", 171 | "subcutaneous", 172 | "pulmonary vasculature", 173 | "pleural", 174 | "mediastinal", 175 | "osseous", 176 | "pulmonary artery", 177 | "right hilar", 178 | "alveolar", 179 | "lower thoracic", 180 | "lateral", 181 | "right lower lobe", 182 | "lingula", 183 | "aorta", 184 | "venous", 185 | "right apical", 186 | "right chest wall", 187 | "bilateral lower", 188 | "upper", 189 | "right infrahilar", 190 | "cardiopulmonary", 191 | "hiatal", 192 | "perihilar", 193 | "contour", 194 | "lower", 195 | "carinal", 196 | "focal", 197 | "mid thoracic", 198 | "vertebral body", 199 | "central", 200 | "both", 201 | "costophrenic angles", 202 | "region", 203 | "chest", 204 | "left mid and lower", 205 | "interstitial", 206 | "internal jugular", 207 | "lymphadenopathy", 208 | "left upper", 209 | "airspace", 210 | "multifocal", 211 | "lung base", 212 | "pulmonary venous", 213 | "hilar", 214 | "ribs", 215 | "left lower lobe", 216 | "base", 217 | "lung", 218 | "right basal", 219 | "cardiomediastinal silhouette", 220 | "right middle and lower", 221 | "right lower", 222 | "catheter", 223 | "bilateral", 224 | "trachea", 225 | "hemi thorax", 226 | "right lateral", 227 | "chest wall", 228 | "bone", 229 | "contours", 230 | "descending", 231 | "overt", 232 | "right lung", 233 | "bibasal", 234 | "veins", 235 | 
"mid and lower", 236 | "endotracheal tube", 237 | "retrocardiac", 238 | "humerus", 239 | "left retrocardiac", 240 | "sternal wires", 241 | "right mid lung", 242 | "tracheostomy tube", 243 | "left", 244 | "subclavian", 245 | "left lung base", 246 | "left subclavian", 247 | "abdomen", 248 | "descending aorta", 249 | "right upper quadrant", 250 | "left mid", 251 | "carina", 252 | "lower lobes", 253 | "intrathoracic", 254 | "apex", 255 | "lungs", 256 | "right-sided", 257 | "cavoatrial junction", 258 | "middle", 259 | "lower lobe", 260 | "spine", 261 | "bones", 262 | "right upper", 263 | "mediastinal and hilar", 264 | "na", 265 | "adenopathy", 266 | "hilus", 267 | "clavicle", 268 | "upper lobe", 269 | "thoracic aorta", 270 | "mid", 271 | "right", 272 | "cardiac silhouette", 273 | "right perihilar", 274 | "cardiac", 275 | "right ventricle", 276 | "line", 277 | "lingular", 278 | "left perihilar", 279 | "adjacent", 280 | "underlying", 281 | "right base", 282 | "hilum", 283 | "aortic arch", 284 | "pericardial", 285 | "right basilar", 286 | "left-sided", 287 | "pulmonary", 288 | "pulmonary arteries", 289 | "osseous structures", 290 | "lobes", 291 | "shoulder", 292 | "right mid and lower", 293 | "zone", 294 | "left lung", 295 | "mid svc", 296 | "subdiaphragmatic", 297 | "mediastinum", 298 | "posterior", 299 | "vascular", 300 | "left mid lung", 301 | "thoracolumbar junction", 302 | "rib", 303 | "lung bases", 304 | "cardiomediastinal", 305 | "pulmonary venous pressure", 306 | "basal", 307 | "heart", 308 | "ng tube", 309 | "areas", 310 | "biapical", 311 | "right lung bases", 312 | "hemithorax", 313 | "left lateral", 314 | "atrium", 315 | "hilar and mediastinal", 316 | "left lower", 317 | "upper abdomen", 318 | "left upper quadrant", 319 | "pulmonary vascular", 320 | "bases", 321 | "lymph nodes", 322 | "both bases", 323 | "pericardium", 324 | "costophrenic angle", 325 | "mediastinal contours", 326 | "right lung base", 327 | "infrahilar", 328 | "bibasilar", 329 | "left pectoral", 
330 | "soft tissues", 331 | "right middle", 332 | "esophagus", 333 | "lung apices", 334 | "pulmonary vessels", 335 | "hemidiaphragm", 336 | "right mid", 337 | "right internal jugular", 338 | "cardiac and mediastinal", 339 | "left basilar", 340 | "lymph node", 341 | "cardiomediastinal and hilar", 342 | "lung parenchyma", 343 | "both lung bases", 344 | ] 345 | disease_adjective = [ 346 | "concurrent", 347 | "vague", 348 | "probably", 349 | "dilated", 350 | "biapical", 351 | "small if any", 352 | "less prominent", 353 | "widening", 354 | "bibasal", 355 | "resolving", 356 | "consistent with", 357 | "subsequent", 358 | "smaller", 359 | "decreasing", 360 | "minimally increased", 361 | "postoperative", 362 | "healed", 363 | "worse", 364 | "worrisome", 365 | "significant", 366 | "potentially", 367 | "old healed", 368 | "asymmetrical", 369 | "scarring", 370 | "similar", 371 | "complete", 372 | "little", 373 | "acute", 374 | "prior", 375 | "moderate-to-severe", 376 | "perihilar", 377 | "loculated", 378 | "progressive", 379 | "possibility of", 380 | "basilar", 381 | "slightly improved", 382 | "infectious", 383 | "ill-defined", 384 | "lungs", 385 | "central", 386 | "supervening", 387 | "subsegmental", 388 | "massive", 389 | "hyperinflation", 390 | "partial", 391 | "crowding", 392 | "some", 393 | "moderate to large", 394 | "minor", 395 | "tortuosity", 396 | "vascular", 397 | "degenerative", 398 | "subcutaneous", 399 | "consolidation", 400 | "heterogeneous", 401 | "right lower", 402 | "presumed", 403 | "slightly increased", 404 | "moderately severe", 405 | "mediastinal", 406 | "small-to-moderate", 407 | "standard", 408 | "atelectatic", 409 | "upper", 410 | "concerning", 411 | "diffuse bilateral", 412 | "hazy", 413 | "interstitial", 414 | "top-normal", 415 | "moderately", 416 | "infection", 417 | "pre-existing", 418 | "associated", 419 | "blunting", 420 | "layering", 421 | "streaky", 422 | "left-sided", 423 | "moderately enlarged", 424 | "early", 425 | "dense", 426 | 
"widespread", 427 | "right-sided", 428 | "intact", 429 | "old", 430 | "interval improvement", 431 | "tiny", 432 | "improvement", 433 | "tortuous", 434 | "decrease", 435 | "left lower", 436 | "elevated", 437 | "elevation", 438 | "unremarkable", 439 | "interval increase", 440 | "moderate to severe", 441 | "previous", 442 | "marked", 443 | "pleural", 444 | "asymmetric", 445 | "small to moderate", 446 | "hyperinflated", 447 | "increase", 448 | "prominence", 449 | "resolved", 450 | "focal", 451 | "developing", 452 | "prominent", 453 | "improving", 454 | "volume loss", 455 | "worsened", 456 | "mild-to-moderate", 457 | "interval", 458 | "residual", 459 | "adjacent", 460 | "aspiration", 461 | "compressive", 462 | "retrocardiac", 463 | "lower", 464 | "constant", 465 | "linear", 466 | "known", 467 | "underlying", 468 | "subtle", 469 | "slight", 470 | "multiple", 471 | "likely", 472 | "trace", 473 | "continued", 474 | "mild to moderate", 475 | "calcified", 476 | "mildly enlarged", 477 | "mildly", 478 | "probable", 479 | "multifocal", 480 | "superimposed", 481 | "enlargement", 482 | "extensive", 483 | "diffuse", 484 | "increasing", 485 | "clear", 486 | "patchy", 487 | "substantial", 488 | "chronic", 489 | "enlarged", 490 | "borderline", 491 | "worsening", 492 | "large", 493 | "possible", 494 | "decreased", 495 | "bibasilar", 496 | "severe", 497 | "improved", 498 | "persistent", 499 | "bilateral", 500 | "left", 501 | "right", 502 | "increased", 503 | "minimal", 504 | "stable", 505 | "low", 506 | "normal", 507 | "new", 508 | "unchanged", 509 | "moderate", 510 | "small", 511 | "mild", 512 | "no", 513 | ] 514 | disease_adjective = disease_adjective[:-1] 515 | disease_type_2 = [ 516 | "ascites", 517 | "pulmonary embolism", 518 | "pleural fluid", 519 | "pseudoaneurysm", 520 | "pulmonary nodules", 521 | "pneumoperitoneum", 522 | "dislocation", 523 | "inflammation", 524 | "hyperinflation", 525 | "tracheostomy tube", 526 | "pulmonary fibrosis", 527 | "infectious process", 528 | 
"cardiac enlargement", 529 | "bronchiectasis", 530 | "lesion", 531 | "cyst", 532 | "abscess", 533 | "embolism", 534 | "pneumomediastinum", 535 | "density", 536 | "granuloma", 537 | "adenopathy", 538 | "pacemaker", 539 | "subcutaneous emphysema", 540 | "hemorrhage", 541 | "dilatation", 542 | "nodules", 543 | "infiltrate", 544 | "aneurysm", 545 | "tension", 546 | "interstitial markings", 547 | "scoliosis", 548 | "fibrosis", 549 | "copd", 550 | "hernia", 551 | "hematoma", 552 | "hiatal hernia", 553 | "lymphadenopathy", 554 | "vascular engorgement", 555 | "occlusion", 556 | "thrombosis", 557 | "engorgement", 558 | "aeration", 559 | "nodule", 560 | "calcification", 561 | "scarring", 562 | "infection", 563 | "fracture", 564 | "emphysema", 565 | "pneumothoraces", 566 | "abnormalities", 567 | "vascular congestion", 568 | "opacification", 569 | "consolidation", 570 | "opacity", 571 | "cardiomegaly", 572 | "edema", 573 | "pneumonia", 574 | "atelectasis", 575 | "effusion", 576 | "pneumothorax", 577 | ] 578 | view_position = ["AP", "PA", "LATERAL", "LL", ""] 579 | 580 | 581 | def post_process(root_dir): 582 | """ 583 | Post-process the processing results, including reading metadata, creating directories, loading and processing results, and saving the processed results. 584 | 585 | Parameters: 586 | root_dir: The root directory path used to construct image paths. 
587 | """ 588 | # Initialize a dictionary to store metadata 589 | paths_dict = {} 590 | # Open and read the metadata CSV file 591 | with open("./mimic-cxr-2.0.0-metadata.csv", "r") as f: 592 | reader = csv.reader(f) 593 | for row in reader: 594 | # Store metadata combined by subject_id and study_id 595 | if str(row[1]) + str(row[2]) not in paths_dict: 596 | paths_dict[str(row[1]) + str(row[2])] = [ 597 | [str(row[0]), str(row[1]), str(row[2]), str(row[4])] 598 | ] 599 | else: 600 | paths_dict[str(row[1]) + str(row[2])].append( 601 | [str(row[0]), str(row[1]), str(row[2]), str(row[4])] 602 | ) 603 | 604 | # Create directories to store post-processed results 605 | for i in range(10, 20): 606 | path = "./post_processed_results/" + "p" + str(i) 607 | if not os.path.exists(path): 608 | os.makedirs(path, exist_ok=True) 609 | 610 | # Load and process each result file 611 | for p in range(10, 20): 612 | results = np.load("./results/p" + str(p) + "/results.npy", allow_pickle=True) 613 | results_not_empty = np.array([]) 614 | save_path = "./post_processed_results/p" + str(p) + "/results.npy" 615 | for i in range(len(results)): 616 | # Split the text by commas or periods 617 | parts = re.split("[,.]", results[i]["prompt"]) 618 | checked_result = [] 619 | # Traverse the split parts and check if they meet the requirements 620 | for part in parts: 621 | if len(part) < 1: 622 | continue 623 | # Organize the results returned by GPT 624 | if len(re.findall("\{[a-zA-Z\s\-/]+\}", part.strip())) == 4: 625 | result = list( 626 | filter( 627 | lambda s: s != "", re.split("[{}]", part.strip().lower()) 628 | ) 629 | ) 630 | if result[3] == "na" and result[2] in disease_type: 631 | result[3] = result[2] 632 | result[2] = "na" 633 | 634 | if result[0] == "na" and result[1] in disease_adjective: 635 | result[0] = result[1] 636 | result[1] = "na" 637 | 638 | # Structured labels need to include disease description, disease location, disease organ, and disease type information, and these labels 
must be within specific ranges 639 | if ( 640 | result[3] not in disease_type 641 | or result[2] not in disease_location_organ 642 | or result[1] not in disease_location_organ 643 | or result[0] not in disease_adjective 644 | ): 645 | continue 646 | # Group different descriptions of the same disease into one category 647 | for original_disease_type in disease_type_: 648 | if result[3] in original_disease_type: 649 | result[3] = original_disease_type[0] 650 | checked_result.append(result) 651 | results[i]["result"] = checked_result 652 | 653 | # Retain only records with valid results and add additional information 654 | if len(results[i]["result"]) >= 1: 655 | if len(results[i]["result"]) > 10: 656 | print(results[i]["path"]) 657 | print(len(results[i]["result"])) 658 | if len(results[i]["result"]) > 100: 659 | print(results[i]["result"]) 660 | # Remove results that do not contain structured labels after processing, and add image paths and view information 661 | subject_id_and_study_id = str(results[i]["path"][25:33]) + str( 662 | results[i]["path"][35:43] 663 | ) 664 | for j in range(len(paths_dict[subject_id_and_study_id])): 665 | result = copy.deepcopy(results[i]) 666 | result["image_path"] = ( 667 | root_dir 668 | + "p" 669 | + paths_dict[subject_id_and_study_id][j][1][:2] 670 | + "/p" 671 | + paths_dict[subject_id_and_study_id][j][1] 672 | + "/s" 673 | + paths_dict[subject_id_and_study_id][j][2] 674 | + "/" 675 | + paths_dict[subject_id_and_study_id][j][0] 676 | + ".jpg" 677 | ) 678 | result["view_position"] = paths_dict[subject_id_and_study_id][j][3] 679 | if ( 680 | result["path"] 681 | == "/root/reports/files/p10/p10002559/s52212843.txt" 682 | ): 683 | print(result) 684 | results_not_empty = np.append(results_not_empty, result) 685 | # Save the processed results 686 | np.save(save_path, results_not_empty) 687 | 688 | 689 | def concat_npy(): 690 | # Combine into one .npy file 691 | data_dir = "./post_processed_results" 692 | folder_names = 
os.listdir(data_dir) 693 | result = np.empty((0,)) 694 | 695 | for folder_name in folder_names: 696 | file_path = os.path.join(data_dir, folder_name, "results.npy") 697 | data = np.load(file_path, allow_pickle=True) 698 | result = np.concatenate((result, data)) 699 | np.save("./mimic.npy", result) 700 | 701 | 702 | if __name__ == "__main__": 703 | # Set the root directory to the MIMIC-CXR dataset file path 704 | root_dir = "~/data/physionet.org/files/mimic-cxr-jpg/2.0.0/files/" 705 | 706 | # Call the post_process function to post-process the data 707 | post_process(root_dir) 708 | 709 | # Load the post-processed result file 710 | path = "./post_processed_results/p10/results.npy" 711 | results = np.load(path, allow_pickle=True) 712 | 713 | # Initialize an empty list to store disease information 714 | disease = [] 715 | 716 | # Display the processed label information (only display results for a specific path) 717 | for i in range(len(results)): 718 | if results[i]["path"] == "/root/reports/files/p10/p10002559/s52212843.txt": 719 | print(results[i]) 720 | 721 | # Traverse all results, extract and record disease information from each result 722 | for i in range(len(results)): 723 | for result in results[i]["result"]: 724 | disease.append(result[3]) # Assume result[3] is the disease name 725 | 726 | # Count the occurrences of each disease 727 | counts = Counter(disease) 728 | 729 | # Sort the counts in descending order 730 | sorted_counts = sorted(counts.items(), key=lambda x: x[1], reverse=True) 731 | 732 | # Initialize three lists to store the top 100 diseases, their occurrences, and disease types 733 | disease_100 = [] 734 | disease_num = [] 735 | disease_type = [] 736 | 737 | # Print and record the information of the last 50 diseases (i.e., the 50 diseases with the fewest occurrences) 738 | for i in range(len(sorted_counts) - 50, len(sorted_counts)): 739 | print( 740 | sorted_counts[len(sorted_counts) - 1 - i][0], 741 | sorted_counts[len(sorted_counts) - 1 - i][1], 
742 | ) 743 | disease_100.append( 744 | [ 745 | sorted_counts[len(sorted_counts) - 1 - i][0], 746 | sorted_counts[len(sorted_counts) - 1 - i][1], 747 | ] 748 | ) 749 | disease_num.append(sorted_counts[len(sorted_counts) - 1 - i][1]) 750 | disease_type.append(sorted_counts[len(sorted_counts) - 1 - i][0]) 751 | 752 | # Define the headers for the CSV file 753 | headers = ["disease", "occurrences"] 754 | 755 | # Write the information of the top 100 diseases to a CSV file 756 | with open("disease_100.csv", "w") as f: 757 | writer = csv.writer(f) 758 | writer.writerow(headers) 759 | writer.writerows(disease_100) 760 | 761 | # Call the concat_npy function to merge Numpy files 762 | concat_npy() 763 | 764 | # Load the merged Numpy file 765 | results = np.load("./mimic.npy", allow_pickle=True) 766 | 767 | # Print the first 10 loaded results 768 | print(results[:10]) 769 | -------------------------------------------------------------------------------- /train/train.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | from tqdm import tqdm 4 | import timeit 5 | import os 6 | import copy 7 | import random 8 | from datetime import datetime 9 | import argparse 10 | import json 11 | import seaborn as sns 12 | import matplotlib.pyplot as plt 13 | from torch.utils.data import Dataset, DataLoader 14 | from torch import nn 15 | from torchvision import transforms 16 | from transformers import AutoTokenizer, AutoModel 17 | from transformers import BertTokenizer, BertModel 18 | from torch.utils.tensorboard import SummaryWriter 19 | import torch 20 | import torch.nn.functional as F 21 | import clip 22 | from models import VLM 23 | from constants import disease_type, description_book 24 | 25 | IKI_dic, disease_type = description_book, disease_type[-16:] 26 | 27 | 28 | def collate_fn(data): 29 | """ 30 | Custom data collation function to process batches of data from the data loader. 
31 | This function separates images, labels, and texts into individual lists for further processing. 32 | 33 | Parameters: 34 | - data: A list of tuples, each containing an image (NumPy array), a label (integer), and a text (string). 35 | 36 | Returns: 37 | - data_copy: A tuple containing three elements: 38 | 1. images: A NumPy array of all images. 39 | 2. labels: A list of all labels. 40 | 3. texts: A list of all texts. 41 | """ 42 | # Initialize lists for images, labels, and texts 43 | images, labels, texts = [], [], [] 44 | 45 | # Iterate over the batch and append images, labels, and texts to their respective lists 46 | for batch in range(0, len(data)): 47 | images.append(data[batch][0]) 48 | labels.append(data[batch][1]) 49 | texts.append(data[batch][2]) 50 | 51 | # Convert the list of images to a NumPy array and return it with labels and texts 52 | data_copy = (np.array(images), labels, texts) 53 | return data_copy 54 | 55 | 56 | def mean_pooling(model_output, attention_mask): 57 | token_embeddings = model_output[0] 58 | input_mask_expanded = ( 59 | attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 60 | ) 61 | return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( 62 | input_mask_expanded.sum(1), min=1e-9 63 | ) 64 | 65 | 66 | def draw_similarity_matrix(similarity_matrix, name, size=(80, 60)): 67 | plt.figure(figsize=size) 68 | sns.heatmap(similarity_matrix) 69 | plt.savefig(name) 70 | plt.close() 71 | 72 | 73 | def SSM(text, entity, lengths, tokenizer, model, device): 74 | """ 75 | Calculate the similarity between entities in the text. 76 | 77 | Parameters: 78 | text (List[str]): List of segmented texts. 79 | entity (List[str]): List of entities. 80 | lengths (List[int]): List of lengths for each entity. 81 | tokenizer: Tokenizer object used for tokenization. 82 | model: Model object used to generate embeddings. 83 | device: Device information (GPU or CPU). 
84 | 85 | Returns: 86 | torch.Tensor: Similarity matrix between entities. 87 | """ 88 | # Create a boolean matrix indicating whether each character in the entity matches others 89 | bool_matrix = [[item2 == item for item2 in entity] for item in entity] 90 | 91 | # Tokenize the text 92 | encoded_input = tokenizer(text, padding=True, truncation=True, return_tensors="pt") 93 | encoded_input = {key: tensor.to(device) for key, tensor in encoded_input.items()} 94 | 95 | # Compute embeddings using the model without gradient calculation 96 | with torch.no_grad(): 97 | model_output = model(**encoded_input) 98 | 99 | # Perform mean pooling to obtain sentence embeddings 100 | sentence_embeddings = mean_pooling(model_output, encoded_input["attention_mask"]) 101 | sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1) 102 | 103 | # Compute and retrieve the similarity matrix of sentence embeddings 104 | similarity_matrix = torch.matmul(sentence_embeddings, sentence_embeddings.T) 105 | similarity_matrix = similarity_matrix.cpu().numpy() 106 | 107 | # Apply the boolean matrix to filter non-similar items 108 | similarity_matrix = bool_matrix * (similarity_matrix) 109 | 110 | # Insert 0 at the beginning of the lengths list and convert it to a cumulative sum array 111 | lengths.insert(0, 0) 112 | lengths = np.array(lengths) 113 | lengths = np.cumsum(lengths) 114 | 115 | # Split and compute the maximum value of the similarity matrix based on entity lengths 116 | similarity_matrix = [ 117 | np.max(similarity_matrix[lengths[i] : lengths[i + 1]], axis=0) 118 | for i in range(len(lengths) - 1) 119 | ] 120 | 121 | # Return the processed similarity matrix 122 | return torch.tensor(np.array(similarity_matrix)) 123 | 124 | 125 | def load_png(img_path): 126 | img = Image.open(img_path).convert("L") 127 | transform = transforms.Compose( 128 | [ 129 | transforms.ToTensor(), 130 | transforms.Resize((448, 448), antialias=True), 131 | ] 132 | ) 133 | return transform(img) 134 | 135 | 
def data_prepare(save_path):
    """
    Build the training array from ``./data_llama3_8B.json`` and save it.

    Per-channel mean/std are estimated from (at most) the first 1000 images,
    and every image is normalised with those statistics. Each ``"response"``
    string is parsed into ``[first_word, middle_words, last_word]`` triples
    stored under ``"labels"``. The ``"response"`` and ``"query"`` fields are
    removed.

    Parameters:
    - save_path: str, destination ``.npy`` path (saved as an object array of dicts).
    """
    with open("./data_llama3_8B.json", "r", encoding="utf-8") as file:
        datas = json.load(file)

    # Normalisation statistics from the first 1000 images
    image_list = np.array([load_png(data["images"][0]) for data in datas[:1000]])
    mean_value = np.mean(image_list, axis=(0, 2, 3))
    std_value = np.std(image_list, axis=(0, 2, 3))
    transform = transforms.Compose(
        [transforms.Normalize(mean=mean_value, std=std_value)]
    )

    print(f"Mean: {mean_value}")
    print(f"Standard Deviation: {std_value}")

    processed = []
    time0 = timeit.default_timer()

    for i, data in enumerate(datas):
        # Progress every 1000 entries
        if i % 1000 == 0:
            print("i", i)

        # "a b c d, e f g" -> [["a", "b c", "d"], ["e", "f", "g"]]
        labels = []
        for s in data["response"].split(", "):
            words = s.split(" ")
            labels.append([words[0], " ".join(words[1:-1]), words[-1]])

        data["labels"] = labels
        del data["response"]
        del data["query"]

        data["image_npy"] = transform(load_png(data["images"][0]))
        processed.append(data)

    # Build the object array once (np.append per record was quadratic).
    if processed:
        data_tosave = np.empty(len(processed), dtype=object)
        for k, record in enumerate(processed):
            data_tosave[k] = record
    else:
        data_tosave = np.array([])
    np.save(save_path, data_tosave)

    time1 = timeit.default_timer()
    print("Data loading: ", time1 - time0)


def random_mask(labels, p=0.5):
    """
    Replace label fields with ``"mask"`` in place, each with probability *p*.

    Every field of every label is considered independently (the last field
    gets its own draw after the others). Returns the same, mutated list.
    """
    for label in labels:
        for i in range(len(label) - 1):
            if random.random() < p:
                label[i] = "mask"
        if random.random() < p:
            label[-1] = "mask"
    return labels


class Mimic_CXR(Dataset):
    """
    Dataset over prepared records, split 90% train / 10% eval by position.

    Each item is ``(image, labels, texts)``: the stored image tensor repeated
    to 3 channels as a NumPy array, a deep copy of the record's labels, and
    one text per label.
    """

    def __init__(self, data, mode="train", mask="random_mask"):
        split = int(len(data) * 0.9)
        self.data = data[:split] if mode == "train" else data[split:]
        # NOTE(review): `mask` is stored but never used by __getitem__.
        self.mask = mask

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        label = copy.deepcopy(self.data[idx]["labels"])
        # Image Knowledge Injector: append the entity's (label[2]) description
        # from description_book when one exists.
        text = [
            " ".join(sublist) + ". " + description_book[sublist[2]]
            if sublist[2] in description_book
            else " ".join(sublist)
            for sublist in label
        ]
        # load_png yields one channel; repeat it to 3 for the vision encoder
        return np.array(self.data[idx]["image_npy"].repeat(3, 1, 1)), label, text


def calculate_auc(fpr, tpr):
    """
    Area under the curve ``tpr(fpr)`` by the trapezoidal rule.

    Points should be ordered by increasing *fpr* for a non-negative area.
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    return trapezoid(np.asarray(tpr), np.asarray(fpr))


def compute_metrics(predicted, labels):
    """
    Compute evaluation metrics for binary predictions.

    Parameters:
        predicted: 0/1 predictions, same shape as labels. The first axis
            indexes samples.
        labels: 0/1 ground truth, same shape as predicted.

    Returns:
        A tuple ``(accuracy_strict, accuracy, precision, recall, f1_score,
        TPR, FPR, TP, FP, TN, FN)``. ``accuracy_strict`` is the fraction of
        samples whose predictions are all correct. Ratios with a zero
        denominator are NaN; callers such as eval() rely on that to skip
        undefined scores.
    """
    predicted = np.asarray(predicted)
    labels = np.asarray(labels)

    # Per-sample correctness. For multi-column input a sample counts only if
    # every column matches. (np.all over the whole array collapsed this to a
    # single boolean.)
    matches = predicted == labels
    if matches.ndim > 1:
        matches = np.all(matches, axis=tuple(range(1, matches.ndim)))
    num_total = predicted.shape[0]

    # Confusion-matrix components
    TP = np.sum((predicted == 1) & (labels == 1))
    FP = np.sum((predicted == 1) & (labels == 0))
    TN = np.sum((predicted == 0) & (labels == 0))
    FN = np.sum((predicted == 0) & (labels == 1))

    # 0/0 yields NaN by design; silence the per-call RuntimeWarnings
    with np.errstate(divide="ignore", invalid="ignore"):
        accuracy_strict = np.sum(matches) / num_total
        accuracy = (TP + TN) / (TP + FP + TN + FN)
        precision = TP / (TP + FP)
        recall = TP / (TP + FN)
        f1_score = 2 * precision * recall / (precision + recall)
        TPR = TP / (TP + FN)
        FPR = FP / (FP + TN)

    return (
        accuracy_strict,
        accuracy,
        precision,
        recall,
        f1_score,
        TPR,
        FPR,
        TP,
        FP,
        TN,
        FN,
    )


def filter_array(arr):
    """Return the entries of *arr* within [0, 1]; NaN entries are dropped."""
    return arr[(arr >= 0) & (arr <= 1)]


def eval(net, dataloader, device):
    """
    Evaluate model performance per class of ``disease_type``.

    For each class the decision threshold is swept over [-1, 1) in steps of
    0.001. The reported values are the best F1 across thresholds, the
    accuracy at that F1, and the AUC of the swept (FPR, TPR) curve. Classes
    whose F1 is never within [0, 1] are skipped, so the lists may be shorter
    than ``disease_type``.

    Parameters:
    - net: Model network, called as ``net(images, prompt_text, device)``.
    - dataloader: Data loader yielding ``(images, labels, texts)`` batches.
    - device: Device information (CPU or GPU).

    Returns:
    - f1_list: List of F1 scores.
    - auc_list: List of AUC scores.
    - acc_list: List of accuracies.
    """
    # One prompt per candidate class; identical for every batch.
    prompt_text = [
        disease + ". " + description_book[disease] for disease in disease_type
    ]

    with torch.set_grad_enabled(False):
        predicted_concat = []
        label_concat = []

        for images, labels, texts in tqdm(dataloader):
            images = torch.tensor(images).to(device)

            # Entity (slot 2) of every label, flattened across the batch
            entity = [[arr[2] for arr in label] for label in labels]
            lengths = [len(sublist) for sublist in entity]
            entity = [item for sublist in entity for item in sublist]

            # One-hot row per entity over disease_type
            ground_truth = np.array(
                [[item2 == item for item2 in disease_type] for item in entity]
            ).astype(float)

            # Multi-hot row per sample: max over that sample's entity rows
            offsets = np.cumsum([0, *lengths])
            labels_np = np.array(
                [
                    np.max(ground_truth[offsets[i] : offsets[i + 1]], axis=0)
                    for i in range(len(offsets) - 1)
                ]
            )

            predicted_np = net(images, prompt_text, device).detach().cpu().numpy()
            predicted_concat.append(predicted_np)
            label_concat.append(labels_np)

        predicted_concat = np.concatenate(predicted_concat, axis=0)
        label_concat = np.concatenate(label_concat, axis=0)

    num_classes = predicted_concat.shape[1]
    thresholds = np.arange(-1, 1, 0.001)

    # results[t][c] holds the compute_metrics() tuple for threshold t, class c
    results = []
    for threshold in thresholds:
        predicted_binary = np.where(predicted_concat >= threshold, 1, 0)
        results.append(
            [
                compute_metrics(predicted_binary[:, c], label_concat[:, c])
                for c in range(num_classes)
            ]
        )
    results = np.array(results)

    auc_list = []
    f1_list = []
    acc_list = []
    for c in range(num_classes):
        result = results[:, c, :]

        f1_ = filter_array(result[:, 4])
        if len(f1_) == 0:
            continue
        best_f1 = np.max(f1_)
        f1_list.append(best_f1)

        # Best accuracy among the thresholds that reach the best F1
        acc_list.append(np.max(result[:, 1][result[:, 4] == best_f1]))

        # Thresholds ascend, so FPR descends; reverse to integrate over ascending FPR
        auc_list.append(calculate_auc(result[:, 6][::-1], result[:, 5][::-1]))

    return f1_list, auc_list, acc_list


def parse_args():
    """Parse the training command-line options."""
    parser = argparse.ArgumentParser(description="Train a classifier")
    parser.add_argument("--tag", type=str, help="Experiment tag", default="experiment")
    parser.add_argument("--gpu", type=int, help="GPU index", default=0)
    parser.add_argument("--vision_model", type=str, help="Vision encoder", default="resnet")
    parser.add_argument("--log_path", type=str, help="Training log path", default="./log")
    parser.add_argument(
        "--checkpoints_path", type=str, help="Model checkpoint path", default="./checkpoints"
    )
    parser.add_argument(
        "--pretrained_path",
        type=str,
        help="Pretrained model path",
        default="",
    )
    return parser.parse_args()


if __name__ == "__main__":
    # Set PyTorch CUDA memory allocation configuration to limit max split size to 128MB
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

    # If torch.cuda has an empty_cache method, call it to clear the cache
    if hasattr(torch.cuda, "empty_cache"):
        torch.cuda.empty_cache()

    # Record start time for data loading
    time0 = timeit.default_timer()

    # Initialize a dictionary to record time consumption for different stages
    time = {"Data Loading": 0, "Label Generation": 0, "Model Training": 0}

    # Set CUDA launch blocking mode for debugging
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

    # Disable parallelism in tokenizers to avoid potential conflicts
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Parse command-line arguments
    args = parse_args()

    # Get the tag information from command-line arguments
    tag = str(args.tag)

    # Set the maximum number of epochs for training
    epoches = 20

    # Select
device (CPU or GPU) based on availability 473 | device = torch.device( 474 | "cuda:{}".format(args.gpu) if torch.cuda.is_available() else "cpu" 475 | ) 476 | 477 | # Get log path and checkpoint path from command-line arguments 478 | log_path = args.log_path 479 | checkpoints_path = args.checkpoints_path 480 | 481 | # Get current timestamp to generate unique folder names 482 | current_time = datetime.now().strftime("%b%d_%H-%M-%S") 483 | 484 | # Construct log and checkpoint paths with timestamps 485 | log_path = os.path.join(log_path, "{}_".format(tag) + current_time) 486 | checkpoints_path = os.path.join(checkpoints_path, "{}_".format(tag) + current_time) 487 | 488 | # Create directories for logs and checkpoints 489 | os.makedirs(log_path) 490 | os.makedirs(checkpoints_path) 491 | 492 | # Initialize TensorBoard SummaryWriter for logging training progress 493 | writer = SummaryWriter(log_dir=log_path) 494 | 495 | # Define training data path and call data_prepare function for data preprocessing 496 | data_path = "./mimic_train.npy" 497 | data_prepare(data_path) 498 | 499 | # Load preprocessed data 500 | datas = np.load(data_path, allow_pickle=True) 501 | 502 | # Create training dataset and data loader 503 | train_dataset = Mimic_CXR(datas) 504 | train_dataloader = DataLoader( 505 | train_dataset, 506 | batch_size=32, 507 | shuffle=True, 508 | num_workers=0, 509 | pin_memory=False, 510 | collate_fn=collate_fn, 511 | drop_last=True, 512 | ) 513 | 514 | # Create evaluation dataset and data loader 515 | eval_dataset = Mimic_CXR( 516 | datas, 517 | mode="eval", 518 | mask="all_mask", 519 | ) 520 | eval_dataloader = DataLoader( 521 | eval_dataset, 522 | batch_size=64, 523 | shuffle=False, 524 | num_workers=0, 525 | pin_memory=False, 526 | collate_fn=collate_fn, 527 | drop_last=True, 528 | ) 529 | 530 | # Load pretrained tokenizer and model 531 | tokenizer = AutoTokenizer.from_pretrained("~/.cache/all-mpnet-base-v2/") 532 | similarity_model = 
AutoModel.from_pretrained("~/.cache/all-mpnet-base-v2/") 533 | 534 | # Move similarity model to the specified device (CPU or GPU) 535 | similarity_model.to(device) 536 | 537 | # Initialize the custom VLM model 538 | net = VLM(vision_model=args.vision_model) 539 | 540 | # Move VLM model to the specified device (CPU or GPU) 541 | net.to(device) 542 | 543 | # If a pretrained model path is provided, load the pretrained model weights 544 | if args.pretrained_path: 545 | checkpoints = torch.load(args.pretrained_path, map_location=device) 546 | net.load_state_dict(checkpoints["network"]) 547 | 548 | # Define loss functions: Mean Squared Error Loss, Cross Entropy Loss, and Binary Cross Entropy Loss 549 | MSE_Loss = nn.MSELoss() 550 | CE_Loss = nn.CrossEntropyLoss() 551 | BCE_Loss = nn.BCEWithLogitsLoss() 552 | 553 | # Define optimizer: Stochastic Gradient Descent (SGD) with learning rate 0.0001 and weight decay 1e-4 554 | optimizer = torch.optim.SGD(net.parameters(), lr=0.0001, weight_decay=1e-4) 555 | 556 | # Define learning rate scheduler: Cosine Annealing Warm Restarts 557 | scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( 558 | optimizer, 559 | T_0=2 * len(train_dataloader), 560 | T_mult=1, 561 | eta_min=1e-6, 562 | last_epoch=-1, 563 | ) 564 | 565 | # Record the end time of data loading and print the time taken for data loading 566 | time1 = timeit.default_timer() 567 | print("Data loading: ", time1 - time0) 568 | time["Data Loading"] += time1 - time0 569 | 570 | # Reinitialize the timer to record time for other stages 571 | time1 = timeit.default_timer() 572 | 573 | # Initialize the best score to 0 574 | best_score = 0 575 | 576 | # Start the training loop, iterating for the specified number of epochs 577 | for epoch in range(epoches): 578 | # Initialize loss list for each epoch 579 | loss_epoch = [0, 0, 0, 0] 580 | 581 | # Iterate over each batch of data in the training data loader 582 | for images, labels, texts in tqdm(train_dataloader): 583 | # 
Record the start time of label generation 584 | time1 = timeit.default_timer() 585 | 586 | # Convert image data to tensor and move to the specified device (CPU or GPU) 587 | images = torch.tensor(images) 588 | images = images.to(device) 589 | 590 | # Flatten texts into a single list of strings 591 | texts = [sequence for sublist in texts for sequence in sublist] 592 | 593 | # Extract entity information from labels 594 | entity = [[arr[2] for arr in label] for label in labels] 595 | 596 | # Calculate the length of each entity in the samples 597 | lengths = [len(sublist) for sublist in entity] 598 | 599 | # Flatten the entity list 600 | entity = [item for sublist in entity for item in sublist] 601 | 602 | # Generate semantic similarity label matrix 603 | mse_label = SSM( 604 | copy.deepcopy(texts), 605 | copy.deepcopy(entity), 606 | copy.deepcopy(lengths), 607 | tokenizer, 608 | similarity_model, 609 | device, 610 | ) 611 | 612 | # Move the semantic similarity label matrix to the specified device (CPU or GPU) 613 | mse_label = mse_label.to(device) 614 | 615 | # Insert 0 at the beginning of the lengths list and compute the cumulative sum 616 | lengths.insert(0, 0) 617 | lengths = np.array(lengths) 618 | lengths = np.cumsum(lengths) 619 | 620 | # Record the end time of label generation and update label generation time 621 | time2 = timeit.default_timer() 622 | time["Label Generation"] += time2 - time1 623 | 624 | # Clear gradients 625 | optimizer.zero_grad() 626 | 627 | # Forward pass to get model output 628 | logits_per = net(images, texts, device) 629 | 630 | # Compute MSE loss 631 | mse_loss = MSE_Loss(logits_per, mse_label) 632 | 633 | # Compute CE loss 634 | ce_loss = CE_Loss(logits_per, mse_label) / logits_per.size()[0] 635 | 636 | # Compute BCE loss 637 | bce_loss = BCE_Loss(logits_per, mse_label) 638 | 639 | # Compute total loss 640 | total_loss = mse_loss + bce_loss 641 | 642 | # Backpropagation and update model parameters 643 | total_loss.backward() 644 | 
optimizer.step() 645 | scheduler.step() 646 | 647 | # Record the end time of model training and update model training time 648 | time3 = timeit.default_timer() 649 | time["Model Training"] += time3 - time2 650 | 651 | # Accumulate losses for each batch 652 | loss_epoch[0] += total_loss.item() 653 | loss_epoch[1] += mse_loss.item() 654 | loss_epoch[2] += ce_loss.item() 655 | loss_epoch[3] += bce_loss.item() 656 | 657 | # Evaluate model performance at the end of each epoch 658 | scores = eval(net, eval_dataloader, device) 659 | f1_score = np.mean(scores[0]) 660 | auc_score = np.mean(scores[1]) 661 | acc_score = np.mean(scores[2]) 662 | 663 | # Print information for the current epoch 664 | len_ = len(train_dataloader) 665 | print("epoch: ", epoch) 666 | print(time) 667 | print("epoch_loss: ", loss_epoch[0] / len_) 668 | print(scores[0]) 669 | print("f1_score: ", f1_score) 670 | print(scores[1]) 671 | print("auc_score: ", auc_score) 672 | print(scores[2]) 673 | print("acc_score: ", acc_score) 674 | 675 | # Write evaluation results to the log file 676 | with open(os.path.join(log_path, "logs.txt"), "a") as logs: 677 | if epoch % 20 == 0: 678 | logs.write( 679 | "{:<10}|{:<8}|{:<9}|{:<9}|{:<9}|{:<9}|{:<9}|{:<9}|{:<9}\n".format( 680 | " lr", 681 | " epoch", 682 | " total", 683 | " mse", 684 | " ce", 685 | " bce", 686 | " f1", 687 | " auc", 688 | " acc", 689 | ) 690 | ) 691 | logs.write( 692 | str( 693 | " %0.6f | %6d | %0.5f | %0.5f | %0.5f | %0.5f | %0.5f | %0.5f | %0.5f " 694 | % ( 695 | optimizer.state_dict()["param_groups"][0]["lr"], 696 | epoch, 697 | loss_epoch[0] / len_, 698 | loss_epoch[1] / len_, 699 | loss_epoch[2] / len_, 700 | loss_epoch[3] / len_, 701 | f1_score, 702 | auc_score, 703 | acc_score, 704 | ) 705 | ) 706 | + "\n" 707 | ) 708 | 709 | # Use TensorBoard to log evaluation results 710 | writer.add_scalars( 711 | "time", 712 | { 713 | "epoch_loss": loss_epoch[0] / len_, 714 | "mse_loss": loss_epoch[1] / len_, 715 | "ce_loss": loss_epoch[2] / len_, 716 | 
"bce_loss": loss_epoch[3] / len_, 717 | "f1_score": f1_score, 718 | "auc_score": auc_score, 719 | "acc_score": acc_score, 720 | }, 721 | epoch, 722 | ) 723 | 724 | # Save the model if the current F1 score is better than the best score 725 | if f1_score > best_score: 726 | best_score = f1_score 727 | torch.save( 728 | { 729 | "it": epoch, 730 | "network": net.state_dict(), 731 | "image_model": net.image_model.state_dict(), 732 | "optimizer": optimizer.state_dict(), 733 | "scheduler": scheduler.state_dict(), 734 | }, 735 | os.path.join(checkpoints_path, f"score_{best_score:.3f}.pt"), 736 | ) 737 | 738 | # Save the model every 5 epochs 739 | if epoch % 5 == 0: 740 | torch.save( 741 | { 742 | "it": epoch, 743 | "network": net.state_dict(), 744 | "image_model": net.image_model.state_dict(), 745 | "optimizer": optimizer.state_dict(), 746 | "scheduler": scheduler.state_dict(), 747 | }, 748 | os.path.join( 749 | checkpoints_path, f"epoch_{epoch}_score_{f1_score:.3f}.pt" 750 | ), 751 | ) -------------------------------------------------------------------------------- /LLaMA-IE/post_process.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import numpy as np 3 | import re 4 | import os 5 | import copy 6 | import argparse 7 | from collections import Counter 8 | 9 | view_position = ["AP", "PA", "LATERAL", "LL", ""] 10 | disease_adjective = [ 11 | "new large", 12 | "minimal residual", 13 | "slightly larger", 14 | "similar", 15 | "significant", 16 | "slightly more prominent", 17 | "interval development", 18 | "mild elevation", 19 | "small amount", 20 | "slightly worse", 21 | "more pronounced", 22 | "more prominent", 23 | "persistent mild", 24 | "difficult to exclude", 25 | "borderline size", 26 | "very small", 27 | "interval", 28 | "potential", 29 | "minimal decrease", 30 | "suspected", 31 | "slightly lower", 32 | "volume loss", 33 | "almost completely resolved", 34 | "slightly worsened", 35 | "less prominent", 36 | 
"unchanged small", 37 | "new moderate", 38 | "suggestive of", 39 | "persistent small", 40 | "substantial increase", 41 | "resolution of", 42 | "little", 43 | "resolution", 44 | "borderline enlarged", 45 | "slightly enlarged", 46 | "moderate size", 47 | "mild increase", 48 | "minimal improvement", 49 | "larger", 50 | "ill-defined", 51 | "moderate-sized", 52 | "near complete", 53 | "markedly enlarged", 54 | "widespread", 55 | "slight worsening", 56 | "possible mild", 57 | "newly appeared", 58 | "moderate-to-large", 59 | "bilateral small", 60 | "interval decrease", 61 | "substantially decreased", 62 | "resolving", 63 | "no change", 64 | "slight decrease", 65 | "faint", 66 | "asymmetric", 67 | "mildly", 68 | "left greater than right", 69 | "scattered", 70 | "interval increase", 71 | "mildly increased", 72 | "substantial decrease", 73 | "loculated", 74 | "worrisome for", 75 | "worse", 76 | "likely small", 77 | "small residual", 78 | "minimally improved", 79 | "progressed", 80 | "hyperexpansion", 81 | "massive", 82 | "lungs", 83 | "interval improvement", 84 | "presumed", 85 | "minimal increase", 86 | "complete", 87 | "subsequent", 88 | "improvement", 89 | "minimally increased", 90 | "minimally decreased", 91 | "stable small", 92 | "probable small", 93 | "top-normal", 94 | "decreasing", 95 | "dense", 96 | "developing", 97 | "heterogeneous", 98 | "stable mild", 99 | "partial", 100 | "some", 101 | "slight improvement", 102 | "superimposed", 103 | "normal", 104 | "smaller", 105 | "adjacent", 106 | "moderate-to-severe", 107 | "tortuosity", 108 | "mild to moderately enlarged", 109 | "possible trace", 110 | "small-to-moderate", 111 | "mild improvement", 112 | "bibasal", 113 | "slight", 114 | "known", 115 | "unchanged mild", 116 | "blunting", 117 | "minor", 118 | "tortuous", 119 | "moderately enlarged", 120 | "slightly decreased", 121 | "slightly improved", 122 | "moderately severe", 123 | "new mild", 124 | "multiple", 125 | "multifocal", 126 | "little change", 127 | "possible 
small", 128 | "constant", 129 | "upper limits of normal", 130 | "focal", 131 | "residual", 132 | "new small", 133 | "layering", 134 | "slight increase", 135 | "stable moderate", 136 | "diffuse", 137 | "hazy", 138 | "calcified", 139 | "increase", 140 | "elevated", 141 | "atelectatic", 142 | "marked", 143 | "decrease", 144 | "moderate to large", 145 | "associated", 146 | "unchanged moderate", 147 | "slightly increased", 148 | "lower", 149 | "basilar", 150 | "moderate to severe", 151 | "small bilateral", 152 | "right basilar", 153 | "hyperinflated", 154 | "improving", 155 | "streaky", 156 | "left retrocardiac", 157 | "elevation", 158 | "small to moderate", 159 | "mild-to-moderate", 160 | "continued", 161 | "probable", 162 | "right", 163 | "left basilar", 164 | "linear", 165 | "subtle", 166 | "extensive", 167 | "borderline", 168 | "left", 169 | "tiny", 170 | "enlargement", 171 | "mild to moderate", 172 | "pleural", 173 | "compressive", 174 | "possible", 175 | "chronic", 176 | "trace", 177 | "likely", 178 | "mildly enlarged", 179 | "worsening", 180 | "resolved", 181 | "increasing", 182 | "retrocardiac", 183 | "worsened", 184 | "substantial", 185 | "patchy", 186 | "persistent", 187 | "decreased", 188 | "enlarged", 189 | "severe", 190 | "large", 191 | "bilateral", 192 | "stable", 193 | "bibasilar", 194 | "low", 195 | "minimal", 196 | "new", 197 | "increased", 198 | "atelectasis", 199 | "improved", 200 | "unchanged", 201 | "moderate", 202 | "small", 203 | "mask", 204 | "mild", 205 | ] 206 | disease_location_organ = [ 207 | "left lung base laterally", 208 | "parenchymal", 209 | "upper lobe", 210 | "central pulmonary vasculature", 211 | "bronchial wall", 212 | "posteriorly", 213 | "bibasilar lung", 214 | "bilateral alveolar", 215 | "retrocardiac lung areas", 216 | "right costophrenic sulcus", 217 | "bilateral basilar", 218 | "right subpulmonic", 219 | "pulmonary interstitial", 220 | "biapical", 221 | "costophrenic angle", 222 | "right mid lung field", 223 | "left 
costophrenic sinus", 224 | "pericardial", 225 | "bilateral lung", 226 | "bilateral upper lobe", 227 | "bilateral bases", 228 | "bilateral hilar", 229 | "bilateral basal", 230 | "bibasilar subsegmental", 231 | "left lung basis", 232 | "right lung basis", 233 | "main pulmonary artery", 234 | "right perihilar region", 235 | "left retrocardiac region", 236 | "mediastinal and hilar", 237 | "diffuse bilateral pulmonary", 238 | "left lateral chest wall", 239 | "right major fissure", 240 | "right pleural space", 241 | "widespread parenchymal", 242 | "left mid to lower lung", 243 | "left infrahilar", 244 | "left hemidiaphragmatic", 245 | "left ventricular", 246 | "bilateral costophrenic angles", 247 | "right upper zone", 248 | "cardiomediastinal", 249 | "lung", 250 | "right medial lung base", 251 | "right mid zone", 252 | "lung parenchyma", 253 | "left chest wall", 254 | "right mid-to-lower lung", 255 | "middle lobe", 256 | "right base medially", 257 | "costophrenic angles", 258 | "left and right lung base", 259 | "upper zone", 260 | "bilateral basal parenchymal", 261 | "heart size", 262 | "lingular", 263 | "ascending aorta", 264 | "left midlung", 265 | "hila", 266 | "aortic arch", 267 | "left pneumothorax", 268 | "small bilateral", 269 | "small right", 270 | "retrocardiac lung regions", 271 | "small left", 272 | "both lower lungs", 273 | "opacities", 274 | "right and left lung bases", 275 | "heart size and mediastinum", 276 | "mediastinal", 277 | "right mid to lower lung", 278 | "bilateral lower lung", 279 | "bilateral interstitial", 280 | "bilateral lung bases", 281 | "hemidiaphragms", 282 | "right cardiophrenic angle", 283 | "base", 284 | "both lower lobes", 285 | "lung apices", 286 | "left and right lung bases", 287 | "central pulmonary vascular", 288 | "bilateral airspace", 289 | "bilateral pleural effusions", 290 | "right chest wall", 291 | "left mid lung field", 292 | "right middle and lower lobes", 293 | "base of the left lung", 294 | "right lower hemithorax", 295 | 
"stomach", 296 | "left hilus", 297 | "interstitial markings", 298 | "bibasilar opacities", 299 | "bilateral lungs", 300 | "lower lung", 301 | "right hemidiaphragmatic", 302 | "bilateral effusions", 303 | "left hilar", 304 | "bilateral perihilar", 305 | "interstitial pulmonary", 306 | "right middle and lower lobe", 307 | "bibasilar airspace", 308 | "right hilus", 309 | "right midlung", 310 | "base of the right lung", 311 | "pulmonary vasculature", 312 | "left lung apex", 313 | "upper lobes", 314 | "biapical pleural", 315 | "left base retrocardiac", 316 | "hilar", 317 | "both bases", 318 | "thoracic spine", 319 | "left heart border", 320 | "right perihilar", 321 | "right hilum", 322 | "aortic knob", 323 | "vascular", 324 | "left apex", 325 | "right heart border", 326 | "right infrahilar region", 327 | "lower lobe", 328 | "left mid and lower lung", 329 | "small bilateral pleural", 330 | "pulmonary vessels", 331 | "lower lobes", 332 | "descending aorta", 333 | "cardiomediastinal silhouette", 334 | "upper lungs", 335 | "right mid and lower lung", 336 | "right hilar", 337 | "left-sided", 338 | "right apex", 339 | "perihilar", 340 | "right infrahilar", 341 | "left upper lung", 342 | "bibasal", 343 | "pleural effusions", 344 | "right greater than left", 345 | "small right pleural", 346 | "pulmonary venous", 347 | "right lung apex", 348 | "right pleural effusion", 349 | "aorta", 350 | "lower lungs", 351 | "both lungs", 352 | "apical", 353 | "lingula", 354 | "small left pleural", 355 | "mediastinum", 356 | "bilateral lower lobe", 357 | "right-sided", 358 | "retrocardiac region", 359 | "thoracic aorta", 360 | "bilaterally", 361 | "left perihilar", 362 | "right lung bases", 363 | "cardiomegaly", 364 | "bilateral pulmonary", 365 | "left lung bases", 366 | "right costophrenic angle", 367 | "right upper lung", 368 | "left costophrenic angle", 369 | "right hemithorax", 370 | "bilateral parenchymal", 371 | "left pleural effusion", 372 | "basilar", 373 | "lungs", 374 | "right mid 
lung", 375 | "left hemithorax", 376 | "both lung bases", 377 | "left basal", 378 | "left mid lung", 379 | "right-sided pleural", 380 | "interstitial", 381 | "left-sided pleural", 382 | "left hemidiaphragm", 383 | "pleural", 384 | "right basal", 385 | "left retrocardiac", 386 | "left lower lung", 387 | "right basilar", 388 | "pulmonary", 389 | "right middle lobe", 390 | "bases", 391 | "left upper lobe", 392 | "left apical", 393 | "left basilar", 394 | "right hemidiaphragm", 395 | "pulmonary vascular", 396 | "right lower lung", 397 | "left lung", 398 | "heart", 399 | "lung bases", 400 | "bilateral", 401 | "right apical", 402 | "retrocardiac", 403 | "right lung", 404 | "right upper lobe", 405 | "left lung base", 406 | "right base", 407 | "right", 408 | "right lung base", 409 | "left", 410 | "left base", 411 | "cardiac", 412 | "bibasilar", 413 | "right lower lobe", 414 | "cardiac silhouette", 415 | "bilateral pleural", 416 | "left lower lobe", 417 | "right pleural", 418 | "left pleural", 419 | "mask", 420 | ] 421 | disease_type_original = [ 422 | "opacity with air bronchograms", 423 | "effusion with atelectasis", 424 | "central lymph node enlargement", 425 | "osseous metastatic disease", 426 | "dextroscoliosis", 427 | "perihilar edema", 428 | "tube", 429 | "space", 430 | "hemorrhage", 431 | "opacification concerning for pneumonia", 432 | "consistent with pulmonary disease", 433 | "pulmonary vascular re-distribution", 434 | "cardiomyopathy or pericardial effusion", 435 | "hiatus hernia", 436 | "central pulmonary vascular congestion", 437 | "lesion", 438 | "interstitial opacity", 439 | "ground-glass opacification", 440 | "opacities compatible with pneumonia", 441 | "centralized pulmonary edema", 442 | "septal thickening", 443 | "underlying atelectasis", 444 | "distention", 445 | "scar", 446 | "nodular densities", 447 | "consistent with copd", 448 | "component", 449 | "poorly defined opacity", 450 | "mediastinal venous engorgement", 451 | "thoracic aorta", 452 | "hila", 
453 | "cavitation", 454 | "volume overload", 455 | "metastatic disease", 456 | "opacities represent atelectasis", 457 | "streaks", 458 | "engorged", 459 | "tension", 460 | "pulmonary vascularity", 461 | "mediastinal and pulmonary vascular engorgement", 462 | "rib fracture", 463 | "early pneumonia", 464 | "overt pulmonary edema", 465 | "lung disease", 466 | "consolidations concerning for pneumonia", 467 | "pneumonic infiltrate", 468 | "opacity reflecting atelectasis", 469 | "hematoma", 470 | "enlarging effusions", 471 | "compression fracture", 472 | "underlying consolidation", 473 | "opacities due to atelectasis", 474 | "versus scarring", 475 | "bronchial inflammation", 476 | "mediastinal vascular engorgement", 477 | "pulmonary arteries", 478 | "hyperinflated", 479 | "opacification consistent with pneumonia", 480 | "thickening or effusion", 481 | "opacification consistent with effusion and atelectasis", 482 | "of the thoracic aorta", 483 | "heart failure", 484 | "opacities consistent with pneumonia", 485 | "opacity suggesting pneumonia", 486 | "postoperative appearance", 487 | "ards", 488 | "relative elevation", 489 | "cardiac decompensation", 490 | "pulmonary hemorrhage", 491 | "impression", 492 | "gas", 493 | "fluid collection", 494 | "prominence of interstitial markings", 495 | "abnormalities", 496 | "picc line", 497 | "rounded density", 498 | "biapical thickening", 499 | "diameter", 500 | "airspace process", 501 | "densities", 502 | "appearance", 503 | "intrathoracic malignancy", 504 | "flattening", 505 | "venous pressure", 506 | "pulmonary vascular redistribution", 507 | "plaque", 508 | "eventration", 509 | "dilatation", 510 | "of the costophrenic angle", 511 | "interstitial changes", 512 | "collapsed", 513 | "air", 514 | "osteopenia", 515 | "hilar congestion", 516 | "compression deformity", 517 | "alveolar opacities", 518 | "volume", 519 | "chf findings", 520 | "haziness", 521 | "hyperexpanded", 522 | "acute intrathoracic process", 523 | "catheter", 524 | 
"vascular redistribution", 525 | "clear", 526 | "peribronchial cuffing", 527 | "apical thickening", 528 | "lung nodules", 529 | "lucency", 530 | "pulmonary venous hypertension", 531 | "or pneumonia", 532 | "perihilar opacities", 533 | "acute cardiopulmonary abnormality", 534 | "mediastinum", 535 | "lung opacities", 536 | "obscuration", 537 | "pulmonary hypertension", 538 | "aspiration", 539 | "shift", 540 | "emphysematous changes", 541 | "silhouette", 542 | "process", 543 | "pulmonary arterial hypertension", 544 | "pulmonary abnormality", 545 | "of the lungs", 546 | "markings", 547 | "masses", 548 | "interstitial prominence", 549 | "infiltrative pulmonary abnormality", 550 | "atherosclerotic calcifications", 551 | "lobe pneumonia", 552 | "pulmonary vasculature", 553 | "heart enlargement", 554 | "granulomas", 555 | "alveolar infiltrate", 556 | "radiodensity", 557 | "infiltrates", 558 | "elongation", 559 | "relevant change", 560 | "air-fluid level", 561 | "contour", 562 | "bronchiectasis", 563 | "of pulmonary venous pressure", 564 | "and opacities", 565 | "collection", 566 | "airspace disease", 567 | "confluent opacity", 568 | "contours", 569 | "interstitial pulmonary abnormality", 570 | "pulmonary opacities", 571 | "redistribution", 572 | "obstructive pulmonary disease", 573 | "degenerative changes", 574 | "mediastinal widening", 575 | "calcification", 576 | "aorta", 577 | "pulmonary congestion", 578 | "vascular engorgement", 579 | "indistinctness", 580 | "pulmonary opacifications", 581 | "fullness", 582 | "adenopathy", 583 | "fracture", 584 | "bronchial wall thickening", 585 | "fractures", 586 | "interstitial lung disease", 587 | "infection", 588 | "hydropneumothorax", 589 | "fibrosis", 590 | "plaques", 591 | "ventilation", 592 | "widening", 593 | "scoliosis", 594 | "rib fractures", 595 | "pulmonary disease", 596 | "lymphadenopathy", 597 | "pulmonary nodules", 598 | "pressure", 599 | "pneumoperitoneum", 600 | "reticular opacities", 601 | "acute cardiopulmonary 
disease", 602 | "airspace opacity", 603 | "change", 604 | "overinflation", 605 | "peribronchial opacification", 606 | "pneumomediastinum", 607 | "plate-like atelectasis", 608 | "calcifications", 609 | "granuloma", 610 | "pulmonary venous pressure", 611 | "interstitial markings", 612 | "interstitial abnormality", 613 | "congestive heart failure", 614 | "pulmonary fibrosis", 615 | "airspace consolidation", 616 | "pulmonary vascular engorgement", 617 | "infiltrate", 618 | "abnormality", 619 | "parenchymal opacity", 620 | "interstitial opacities", 621 | "acute cardiopulmonary process", 622 | "tortuosity", 623 | "hyperexpansion", 624 | "engorgement", 625 | "subcutaneous emphysema", 626 | "heart", 627 | "lungs", 628 | "density", 629 | "opacifications", 630 | "chf", 631 | "hyperinflation", 632 | "fluid", 633 | "atelectaxic changes", 634 | "blunting", 635 | "heart size", 636 | "prominence", 637 | "airspace opacities", 638 | "parenchymal opacities", 639 | "mass", 640 | "enlarged", 641 | "hiatal hernia", 642 | "copd", 643 | "interstitial pulmonary edema", 644 | "volume loss", 645 | "consolidations", 646 | "emphysema", 647 | "changes", 648 | "aeration", 649 | "scarring", 650 | "collapse", 651 | "thickening", 652 | "size", 653 | "fluid overload", 654 | "elevation", 655 | "congestion", 656 | "vascular congestion", 657 | "nodule", 658 | "lung volumes", 659 | "pulmonary vascular congestion", 660 | "opacification", 661 | "enlargement", 662 | "consolidation", 663 | "pneumonia", 664 | "cardiac", 665 | "edema", 666 | "opacity", 667 | "pneumothorax", 668 | "atelectasis", 669 | "effusion", 670 | ] 671 | 672 | disease_type = [ 673 | "degenerative changes", 674 | "adenopathy", 675 | "collection", 676 | "calcification", 677 | "airspace disease", 678 | "mediastinum", 679 | "rib fractures", 680 | "vascular engorgement", 681 | "interstitial pulmonary abnormality", 682 | "mediastinal widening", 683 | "plaques", 684 | "hydropneumothorax", 685 | "bronchial wall thickening", 686 | 
"calcifications", 687 | "pulmonary opacifications", 688 | "scoliosis", 689 | "infection", 690 | "ventilation", 691 | "lymphadenopathy", 692 | "pneumoperitoneum", 693 | "reticular opacities", 694 | "airspace opacity", 695 | "granuloma", 696 | "overinflation", 697 | "peribronchial opacification", 698 | "pneumomediastinum", 699 | "parenchymal opacity", 700 | "airspace consolidation", 701 | "congestive heart failure", 702 | "pulmonary fibrosis", 703 | "infiltrate", 704 | "engorgement", 705 | "interstitial abnormality", 706 | "pulmonary vascular engorgement", 707 | "subcutaneous emphysema", 708 | "interstitial opacities", 709 | "pulmonary venous pressure", 710 | "interstitial markings", 711 | "tortuosity", 712 | "hyperexpansion", 713 | "acute cardiopulmonary process", 714 | "fluid", 715 | "prominence", 716 | "opacifications", 717 | "chf", 718 | "hyperinflation", 719 | "airspace opacities", 720 | "blunting", 721 | "parenchymal opacities", 722 | "atelectaxic changes", 723 | "mass", 724 | "consolidations", 725 | "copd", 726 | "emphysema", 727 | "thickening", 728 | "hiatal hernia", 729 | "volume loss", 730 | "interstitial pulmonary edema", 731 | "scarring", 732 | "aeration", 733 | "collapse", 734 | "fluid overload", 735 | "nodule", 736 | "vascular congestion", 737 | "opacification", 738 | "consolidation", 739 | "pneumonia", 740 | "opacity", 741 | "edema", 742 | "pneumothorax", 743 | "atelectasis", 744 | "effusion", 745 | ] 746 | 747 | if "mask" in disease_type: 748 | disease_type = disease_type.remove("mask") 749 | 750 | disease_divide = { 751 | "effusion and atelectasis": ["effusion", "atelectasis"], 752 | "collapse and/or consolidation": ["collapse", "consolidation"], 753 | "consolidation compatible with pneumonia": ["consolidation", "pneumonia"], 754 | "consolidation concerning for pneumonia": ["consolidation", "pneumonia"], 755 | "consolidative opacity": ["consolidation", "opacity"], 756 | "consolidative opacities": ["consolidation", "opacity"], 757 | "opacity 
compatible with pneumonia": ["opacity", "pneumonia"], 758 | "opacity consistent with pneumonia": ["opacity", "pneumonia"], 759 | "opacities concerning for pneumonia": ["opacity", "pneumonia"], 760 | "opacity concerning for pneumonia": ["opacity", "pneumonia"], 761 | "opacity compatible with atelectasis": ["opacity", "atelectasis"], 762 | "opacities suggestive of atelectasis": ["opacity", "atelectasis"], 763 | "opacities reflect atelectasis": ["opacity", "atelectasis"], 764 | "opacities atelectasis": ["opacity", "atelectasis"], 765 | "opacities reflecting atelectasis": ["opacity", "atelectasis"], 766 | "nodular opacification": ["nodular", "opacification"], 767 | } 768 | 769 | disease_type_repeat = [ 770 | [ 771 | "atelectasis", 772 | "volume loss/infiltrate", 773 | "basal atelectasis", 774 | "plate atelectasis", 775 | "scarring or atelectasis", 776 | "areas of atelectasis", 777 | "lobe atelectasis", 778 | "atelectasis/scarring", 779 | "platelike atelectasis", 780 | "plate-like atelectasis", 781 | "subsegmental atelectasis", 782 | ], 783 | [ 784 | "effusion", 785 | "to effusions", 786 | "effusion or thickening", 787 | "to effusion", 788 | "and effusion", 789 | "pericardial effusion", 790 | "effusions", 791 | "effusion", 792 | ], 793 | [ 794 | "consolidation", 795 | "region of consolidation", 796 | "and/or consolidation", 797 | "areas of consolidation", 798 | "or consolidation", 799 | "pulmonary consolidation", 800 | "airspace consolidation", 801 | "consolidations", 802 | "consolidation", 803 | ], 804 | [ 805 | "pneumonia", 806 | "aspiration pneumonia", 807 | "pneumonia", 808 | "acute pneumonia", 809 | ], 810 | [ 811 | "pneumothorax", 812 | "apical pneumothorax", 813 | "hydro pneumothorax", 814 | "pneumothoraces", 815 | "pneumothorax", 816 | ], 817 | [ 818 | "opacity", 819 | "rounded opacity", 820 | "airspace opacification", 821 | "pulmonary opacification", 822 | "nodular opacification", 823 | "reticulonodular opacities", 824 | "pulmonary opacities", 825 | "confluent 
opacity", 826 | "opacities", 827 | "opacification", 828 | "opacifications", 829 | "opacity", 830 | ], 831 | [ 832 | "scarring", 833 | "fibrotic changes", 834 | "scarring or atelectasis", 835 | "atelectasis/scarring", 836 | "or scarring", 837 | "scarring", 838 | ], 839 | [ 840 | "cardiac", 841 | "cardiomediastinal silhouette", 842 | "to cardiomegaly", 843 | "cardiac silhouette enlargement", 844 | "of the cardiac silhouette", 845 | "cardiac silhouette", 846 | "cardiac enlargement", 847 | "cardiomegaly", 848 | ], 849 | [ 850 | "edema", 851 | "pulmonary interstitial edema", 852 | "to pulmonary edema", 853 | "interstitial edema", 854 | "pulmonary edema", 855 | "edema", 856 | ], 857 | [ 858 | "nodule", 859 | "lung nodule", 860 | "nodules", 861 | "pulmonary nodule", 862 | "nodular density", 863 | "nodular opacities", 864 | "pulmonary nodules", 865 | "nodular opacity", 866 | "nodular", 867 | "nodule", 868 | ], 869 | [ 870 | "vascular congestion", 871 | "pulmonary vascular congestion", 872 | "congestion", 873 | "pulmonary congestion", 874 | ], 875 | ["pulmonary fibrosis", "fibrosis"], 876 | ] 877 | 878 | # disease_type_repeat = [ 879 | # [ 880 | # "atelectasis", 881 | # "volume loss/infiltrate", 882 | # "region of consolidation", 883 | # "opacity compatible with atelectasis", 884 | # "basal atelectasis", 885 | # "opacities suggestive of atelectasis", 886 | # "plate atelectasis", 887 | # "opacities reflect atelectasis", 888 | # "scarring or atelectasis", 889 | # "opacities atelectasis", 890 | # "areas of atelectasis", 891 | # "effusion and atelectasis", 892 | # "lobe atelectasis", 893 | # "opacities reflecting atelectasis", 894 | # "atelectasis/scarring", 895 | # "platelike atelectasis", 896 | # "plate-like atelectasis", 897 | # "subsegmental atelectasis", 898 | # ], 899 | # [ 900 | # "effusion", 901 | # "to effusions", 902 | # "effusion or thickening", 903 | # "to effusion", 904 | # "and effusion", 905 | # "effusion and atelectasis", 906 | # "pericardial effusion", 907 | # 
"effusions", 908 | # "effusion", 909 | # ], 910 | # [ 911 | # "consolidation", 912 | # "region of consolidation", 913 | # "and/or consolidation", 914 | # "areas of consolidation", 915 | # "or consolidation", 916 | # "collapse and/or consolidation", 917 | # "consolidation compatible with pneumonia", 918 | # "consolidation concerning for pneumonia", 919 | # "pulmonary consolidation", 920 | # "consolidative opacity", 921 | # "consolidation concerning for pneumonia", 922 | # "consolidative opacities", 923 | # "consolidation", 924 | # ], 925 | # [ 926 | # "pneumonia", 927 | # "opacity compatible with pneumonia", 928 | # "opacity consistent with pneumonia", 929 | # "opacities concerning for pneumonia", 930 | # "aspiration pneumonia", 931 | # "consolidation compatible with pneumonia", 932 | # "consolidation concerning for pneumonia", 933 | # "pneumonia", 934 | # "acute pneumonia", 935 | # "opacity concerning for pneumonia", 936 | # ], 937 | # [ 938 | # "pneumothorax", 939 | # "apical pneumothorax", 940 | # "hydro pneumothorax", 941 | # "pneumothoraces", 942 | # "pneumothorax", 943 | # ], 944 | # [ 945 | # "opacity", 946 | # "opacity compatible with pneumonia", 947 | # "opacity compatible with atelectasis", 948 | # "rounded opacity", 949 | # "opacity consistent with pneumonia", 950 | # "opacities suggestive of atelectasis", 951 | # "airspace opacification", 952 | # "pulmonary opacification", 953 | # "nodular opacification", 954 | # "opacities concerning for pneumonia", 955 | # "opacities reflect atelectasis", 956 | # "opacities atelectasis", 957 | # "reticulonodular opacities", 958 | # "opacities reflecting atelectasis", 959 | # "opacities concerning for pneumonia", 960 | # "opacity concerning for pneumonia", 961 | # "pulmonary opacities", 962 | # "confluent opacity", 963 | # "opacities", 964 | # "opacity", 965 | # ], 966 | # [ 967 | # "scarring", 968 | # "fibrotic changes", 969 | # "scarring or atelectasis", 970 | # "atelectasis/scarring", 971 | # "or scarring", 972 | # 
def _index_metadata(csv_dir):
    """Group the rows of the MIMIC-CXR metadata CSV by study.

    Each row is keyed by ``row[1] + row[2]`` (subject_id and study_id joined as
    plain strings) and contributes ``[row[0], row[1], row[2], row[4]]``, i.e.
    image name, subject_id, study_id and view information.  Every row of the
    file is indexed, including any header row.
    """
    paths_dict = {}
    # newline="" is how the csv module expects its input files to be opened.
    with open(csv_dir, "r", newline="") as f:
        for row in csv.reader(f):
            entry = [str(row[0]), str(row[1]), str(row[2]), str(row[4])]
            paths_dict.setdefault(entry[1] + entry[2], []).append(entry)
    return paths_dict


# One "{...}" field of an extracted entity; group 1 is the text inside braces.
_ENTITY_FIELD = re.compile(r"\{([a-zA-Z\s\-/]+)\}")


def _parse_response(response, adjectives, locations, types):
    """Extract validated ``[severity, location, type]`` triples from a response.

    The response is split on commas and periods.  Every piece containing
    exactly three ``{...}`` fields is lower-cased, repaired when fields were
    shifted into the wrong slot, normalised via ``disease_type_repeat``, split
    via ``disease_divide`` when compound, and kept only when each field is in
    its vocabulary.  "mask" is the placeholder for an empty field.

    Args:
        response: raw text produced by the extraction model.
        adjectives: set of accepted severity values (``disease_adjective``).
        locations: set of accepted location values (``disease_location_organ``).
        types: set of accepted finding types (``disease_type``).

    Returns:
        List of ``[severity, location, type]`` lists (possibly empty).
    """
    checked = []
    for part in re.split("[,.]", response):
        if not part:
            continue
        # Use the brace contents directly so text between groups (e.g. spaces)
        # cannot shift the fields out of position.
        fields = _ENTITY_FIELD.findall(part.strip())
        if len(fields) != 3:
            continue
        severity, location, finding = (field.lower() for field in fields)

        # Type left as "mask" while the type text landed in the location slot.
        if finding == "mask" and location in types:
            finding, location = location, "mask"
        # Severity left as "mask" while the severity text landed in the location slot.
        if severity == "mask" and location in adjectives:
            severity, location = location, "mask"
        # Location left as "mask" while the location text landed in the severity slot.
        if location == "mask" and severity in locations:
            location, severity = severity, "mask"

        # Severity words embedded in the type (e.g. "small effusion"): strip
        # them from the type and use them as severity if it is still "mask".
        for word in finding.split(" "):
            adjective = ""
            if word in adjectives:
                finding = finding.replace(word + " ", "")
                adjective = word
            if severity == "mask" and adjective != "":
                severity = adjective
            # NOTE(review): substring test, so it also fires for "normal",
            # "known" and "top-normal"; "no" is not in disease_adjective, so
            # such entities are rejected by the vocabulary checks below.
            if "no" in adjective:
                severity = "no"

        # Map synonyms onto the canonical (first) name of their group.
        for group in disease_type_repeat:
            if finding in group:
                finding = group[0]

        if finding in disease_divide:
            # Compound types yield one entity per component; only severity
            # and location are checked against the vocabularies.
            if location not in locations or severity not in adjectives:
                continue
            checked.extend(
                [severity, location, component]
                for component in disease_divide[finding]
            )
        elif finding in types and location in locations and severity in adjectives:
            checked.append([severity, location, finding])
    return checked


def post_process(
    response_dir, image_dir, csv_dir, output_dir="./post_processed_results"
):
    """Validate extracted entities and attach image paths for partitions p10-p19.

    For each partition ``pNN`` (NN = 10..19) this loads
    ``<response_dir>/pNN/results.npy`` (an array of dict records with at least
    ``"response"`` and ``"path"`` keys) and stores the validated triples in each
    record's ``"result"`` key.  Records with no valid triple are dropped.  Every
    remaining record is copied once per image of its study, with
    ``"image_path"`` and ``"view_position"`` added.  The copies are saved to
    ``<output_dir>/pNN/results.npy``.

    Args:
        response_dir: directory holding the per-partition extraction outputs.
        image_dir: root of the MIMIC-CXR-JPG ``files/`` tree.
        csv_dir: path of ``mimic-cxr-2.0.0-metadata.csv``.
        output_dir: directory receiving the per-partition results.

    Raises:
        KeyError: if a kept record's study is missing from the metadata CSV.
    """
    paths_dict = _index_metadata(csv_dir)
    # Sets give O(1) membership tests in the per-entity checks.
    adjectives = set(disease_adjective)
    locations = set(disease_location_organ)
    types = set(disease_type)

    partitions = ["p" + str(p) for p in range(10, 20)]
    for partition in partitions:
        os.makedirs(os.path.join(output_dir, partition), exist_ok=True)

    for partition in partitions:
        results = np.load(
            os.path.join(response_dir, partition, "results.npy"), allow_pickle=True
        )
        # Accumulate in a list: np.append inside the loop would copy the whole
        # array on every call.
        samples = []
        for record in results:
            record["result"] = _parse_response(
                record["response"], adjectives, locations, types
            )
            if not record["result"]:
                continue
            # NOTE(review): assumes subject_id and study_id sit at the fixed
            # offsets [14:22] and [24:32] of the report path; confirm the layout.
            path = record["path"]
            key = str(path[14:22]) + str(path[24:32])
            for dicom_id, subject_id, study_id, view in paths_dict[key]:
                sample = copy.deepcopy(record)
                sample["image_path"] = os.path.join(
                    image_dir,
                    "p" + subject_id[:2],
                    "p" + subject_id,
                    "s" + study_id,
                    dicom_id + ".jpg",
                )
                sample["view_position"] = view
                samples.append(sample)
        np.save(
            os.path.join(output_dir, partition, "results.npy"),
            np.array(samples, dtype=object),
        )


def concat_npy(save_path, data_dir="./post_processed_results"):
    """Concatenate ``<data_dir>/<folder>/results.npy`` for every sub-folder.

    Sub-folders are visited in sorted order so the merged file is
    reproducible.  Plain files directly inside ``data_dir`` are ignored.
    The merged array is written to ``save_path``.
    """
    # The leading empty float array keeps the result dtype identical to a
    # step-by-step concatenation that starts from np.empty((0,)).
    arrays = [np.empty((0,))]
    for folder_name in sorted(os.listdir(data_dir)):
        folder = os.path.join(data_dir, folder_name)
        if not os.path.isdir(folder):
            continue
        arrays.append(
            np.load(os.path.join(folder, "results.npy"), allow_pickle=True)
        )
    np.save(save_path, np.concatenate(arrays))


def parse_args():
    """Parse the command-line options of this post-processing script."""
    parser = argparse.ArgumentParser(description="Inference")
    parser.add_argument(
        "--image_dir",
        type=str,
        help="physionet.org/files/mimic-cxr-jpg/2.0.0/files/",
        default="~/data/physionet.org/files/mimic-cxr-jpg/2.0.0/files/",
    )
    parser.add_argument(
        "--csv_dir",
        type=str,
        help="mimic-cxr-2.0.0-metadata.csv",
        default="./mimic-cxr-2.0.0-metadata.csv",
    )
    parser.add_argument(
        "--extracted_entity",
        type=str,
        help="extracted_entity",
        default="./extracted_entity/llama3_fine_tuned/",
    )
    parser.add_argument(
        "--save_path", type=str, help="save_path", default="./mimic_cxr.npy"
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    # open(), np.load() and np.save() do not expand "~" (used by the
    # image_dir default), so expand user paths here.
    image_dir = os.path.expanduser(args.image_dir)
    csv_dir = os.path.expanduser(args.csv_dir)
    extracted_entity = os.path.expanduser(args.extracted_entity)
    save_path = os.path.expanduser(args.save_path)
    post_process(extracted_entity, image_dir, csv_dir)
    concat_npy(save_path)
    results = np.load(save_path, allow_pickle=True)
    print(results[:10])