├── LICENSE
├── README.md
├── ext_data.py
├── pretrain
│   ├── config.yaml
│   └── main.py
├── requirements.txt
└── utils
    ├── __pycache__
    │   ├── dcl_loss.cpython-39.pyc
    │   ├── utils_builder.cpython-38.pyc
    │   ├── utils_builder.cpython-39.pyc
    │   ├── utils_dataset.cpython-38.pyc
    │   ├── utils_dataset.cpython-39.pyc
    │   ├── utils_trainer.cpython-38.pyc
    │   └── utils_trainer.cpython-39.pyc
    ├── utils_builder.py
    ├── utils_dataset.py
    ├── utils_optimizer.py
    └── utils_trainer.py
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Che Liu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# M-FLAG-MICCAI2023

[M-FLAG: Medical Vision-Language Pre-training with Frozen Language Models and Latent Space Geometry Optimization](https://arxiv.org/abs/2307.08347), MICCAI 2023.

### Installation
To clone this repository:
```
git clone https://github.com/cheliu-computation/M-FLAG-MICCAI2023.git
```
To install Python dependencies:
```
pip install -r requirements.txt
```
All experiments were run on NVIDIA A100 GPUs.

### Pre-train Dataset downloading
Datasets we used are as follows:
- **MIMIC-CXR**: We downloaded the [MIMIC-CXR-JPG](https://physionet.org/content/mimic-cxr-jpg/2.0.0/) dataset for the radiographs. The paired medical reports can be downloaded from [MIMIC-CXR](https://physionet.org/content/mimic-cxr/2.0.0/mimic-cxr-reports.zip).

### Preprocessing
- First, we follow the [MGCA](https://github.com/HKU-MedAI/MGCA) preprocessing to extract a master csv that lists every CXR scan together with its associated report. The script can be found in [Preprocessing](https://github.com/HKU-MedAI/MGCA/blob/main/mgca/preprocess/mimic_cxr.py).
- Then, run `ext_data.py` to extract all scans and save them as a single npy file; this accelerates the pre-training stage. A quick way to sanity-check the master csv before running it is sketched below.
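The master csv itself is not shipped with this repository. As a minimal sketch (column names are taken from `ext_data.py` and `utils/utils_dataset.py`; the csv filename is a placeholder), you can verify that the preprocessing output exposes the expected columns:

```python
# Hypothetical sanity check for the master csv produced by MGCA preprocessing.
# Column names come from ext_data.py / utils_dataset.py; the csv path is a placeholder.
import pandas as pd

master = pd.read_csv("master.csv", low_memory=False)  # placeholder path
expected = {"dicom_id", "subject_id", "study_id", "findings", "impression"}
missing = expected - set(master.columns)
if missing:
    raise ValueError(f"master csv is missing columns: {sorted(missing)}")
print(f"{len(master)} image-report pairs ready for ext_data.py")
```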
### Pre-training
We pre-trained M-FLAG on MIMIC-CXR using this command:
```
cd M-FLAG-MICCAI2023/pretrain
torchrun --nnodes=1 --nproc_per_node=2 main.py
```

### Finetune on downstream tasks
We evaluate the performance of M-FLAG on three downstream tasks: classification, object detection and semantic segmentation.

For the classification task, we follow [CheXclusion](https://github.com/LalehSeyyed/CheXclusion); please follow their official code to extract the data and run the classification experiments.

For semantic segmentation and object detection, we follow the official [MGCA](https://github.com/HKU-MedAI/MGCA) configuration and code. The datasets can be found in the MGCA repository.
--------------------------------------------------------------------------------
/ext_data.py:
--------------------------------------------------------------------------------
import argparse

import numpy as np
import pandas as pd
import skimage
from PIL import Image
from skimage import transform
from tqdm import tqdm


def get_MIMIC_img(subject_id, study_id, dicom):
    path = 'xx'  # root of the MIMIC-CXR-JPG images
    report_path = 'xx'  # root of the MIMIC-CXR reports

    sub_dir = 'p' + subject_id[0:2] + '/' + 'p' + subject_id + '/' + 's' + study_id + '/' + dicom + '.jpg'
    report_sub_dir = 'p' + subject_id[0:2] + '/' + 'p' + subject_id + '/' + 's' + study_id + '.txt'
    jpg_path = path + sub_dir
    report_path = report_path + report_sub_dir

    img = Image.open(jpg_path)
    img = np.array(img)
    return img


parser = argparse.ArgumentParser(description='extract_data')
parser.add_argument('--resize', type=int)

if __name__ == "__main__":
    args = parser.parse_args()
    resize = args.resize

    metacsv = pd.read_csv('xx')  # master csv from the MGCA preprocessing stage

    temp_npy = np.zeros((metacsv.shape[0], resize, resize), dtype=np.uint8)
    print(metacsv.shape, temp_npy.shape)

    for i in tqdm(range(temp_npy.shape[0])):
        dicom_idx = metacsv['dicom_id'][i]
        subject_idx = str(int(metacsv['subject_id'][i]))
        study_idx = str(int(metacsv['study_id'][i]))

        img = get_MIMIC_img(subject_id=subject_idx, study_id=study_idx, dicom=dicom_idx)

        # crop away the empty border around the scan
        x, y = np.nonzero(img)
        xl, xr = x.min(), x.max()
        yl, yr = y.min(), y.max()
        img = img[xl:xr + 1, yl:yr + 1]

        # min-max normalise to [0, 255] so the values fit in uint8 without wrap-around
        img = (img - img.min()) * (255.0 / (img.max() - img.min()))

        img = skimage.transform.resize(img, (resize, resize),
                                       order=1, preserve_range=True, anti_aliasing=False)
        img = img.astype(np.uint8)

        temp_npy[i, :, :] = img

    np.save('xx', temp_npy)  # save to the ext_data folder
--------------------------------------------------------------------------------
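A possible invocation of `ext_data.py` and a quick check of its output are sketched below. The `--resize` value and file names are illustrative only (the `xx` placeholders inside the script still have to be set to real paths); 256 simply leaves room for the `RandomCrop(224)` applied in `utils/utils_dataset.py` during pre-training.

```python
# Run after filling in the 'xx' placeholders in ext_data.py:
#   python ext_data.py --resize 256
#
# Quick check of the resulting array; the output filename is a placeholder.
import numpy as np

scans = np.load('mimic_cxr_256.npy', mmap_mode='r')  # placeholder filename
print(scans.shape, scans.dtype)  # (num_pairs, 256, 256) uint8
```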
/pretrain/config.yaml:
--------------------------------------------------------------------------------
network:
  img_model: resnet50
  ### this part does not control builder/trainer
  text_model: bert
  free_layers: 12 # set to 12 to freeze all 12 BERT encoder layers
  text_model_arch: general # specialized/general
  feature_dim: 768

projection_head:
  mlp_hidden_size: 2048
  projection_size: 768
###

img_path: 'xx' # add your image (.npy) file path here
text_path: 'xx' # add your text (csv) file path here

# params for trainer
trainer:
  batch_size: 512
  test_batch_size: 200
  checkpoint_interval: 100000
  max_epochs: 50
  lr: 2.0e-5
  num_workers: 8
  test_interval: 2

optimizer:
  params:
    lr: 2.0e-5
    # momentum: 0.9
    weight_decay: 5.0e-2

# your model name
wandb_name: 'xx'
--------------------------------------------------------------------------------
/pretrain/main.py:
--------------------------------------------------------------------------------
import os
import random
import sys

import numpy as np
import torch
import torch.distributed as dist
import yaml
from torch.nn.parallel import DistributedDataParallel as DDP

sys.path.append("../utils")
from utils_trainer import trainer_wBert
import utils_dataset
import utils_builder

# import wandb


os.environ["TOKENIZERS_PARALLELISM"] = "true"


def ddp_main():
    dist.init_process_group("nccl")
    torch.cuda.empty_cache()
    rank = dist.get_rank()

    print(f"Start running basic DDP example on rank {rank}.")
    device_id = rank % torch.cuda.device_count()

    # set up
    config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)

    torch.manual_seed(42)
    random.seed(0)
    np.random.seed(0)

    # loading data path
    text_path = config['text_path']
    img_path = config['img_path']

    # define image-text dataset
    train_dataset = utils_dataset.I_T_emb_dataset(
        image_path=img_path, csv_path=text_path)
    train_dataset = train_dataset.get_dataset(train_test='train')

    # building model part
    # --------------------
    if config['network']['img_model'] == 'resnet50':
        model = utils_builder.ResNet_CXRBert()

    '''
    BERT is frozen from the first encoder layer up to `free_layers`;
    set the number of layers in config.yaml (12 freezes all 12 layers).
    '''
    if config['network']['free_layers'] is not None:
        for layer_idx in range(int(config['network']['free_layers'])):
            for param in list(model.lm_model.encoder.layer[layer_idx].parameters()):
                param.requires_grad = False

    model = model.to(device_id)
    model = DDP(model, device_ids=[device_id], find_unused_parameters=True)
    # --------------------

    # choose optimizer (no LARS, AdamW with small batch)
    # --------------------
    optimizer = torch.optim.AdamW(
        model.parameters(),
        **config['optimizer']['params'],
        betas=(0.9, 0.999)
    )

    # --------------------
    trainer = trainer_wBert(model=model,
                            optimizer=optimizer,
                            device=rank,
                            model_name=config['wandb_name'],
                            **config['trainer'])
    # --------------------

    # I_T_P_trainer
    trainer.train_w_TextEmb(train_dataset)


if __name__ == "__main__":
    ddp_main()
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.0.0 2 | accelerate==0.16.0 3 | ansi2html==1.8.0 4 | antlr4-python3-runtime==4.9.3 5 | anyio==3.6.2 6 | argcomplete==2.0.0 7 | argon2-cffi==21.3.0 8 | argon2-cffi-bindings==21.2.0 9 | asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1618968359944/work 10 | attrs==21.4.0 11 | backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work 12 | backports.functools-lru-cache @ file:///home/conda/feedstock_root/build_artifacts/backports.functools_lru_cache_1618230623929/work 13 | beautifulsoup4==4.11.1 14 | biopython==1.79 15 | black==22.10.0 16 | bleach==5.0.1 17 | blis==0.7.7 18 | Bottleneck @ file:///tmp/build/80754af9/bottleneck_1648028898966/work
19 | braceexpand==0.1.7 20 | Brotli==1.0.9 21 | brotlipy==0.7.0 22 | cachetools==5.2.0 23 | catalogue==2.0.7 24 | certifi==2023.7.22 25 | cffi @ file:///opt/conda/conda-bld/cffi_1642701102775/work 26 | chardet==4.0.0 27 | charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work 28 | click==8.1.3 29 | cloudpickle==2.2.0 30 | cmake==3.26.3 31 | colorama @ file:///tmp/build/80754af9/colorama_1607707115595/work 32 | coloredlogs==15.0.1 33 | contextlib2==21.6.0 34 | cryptography @ file:///tmp/build/80754af9/cryptography_1639414572950/work 35 | cycler @ file:///tmp/build/80754af9/cycler_1637851556182/work 36 | cymem==2.0.6 37 | dash==2.2.0 38 | dash-core-components==2.0.0 39 | dash-html-components==2.0.0 40 | dash-table==5.0.0 41 | datamodel-code-generator==0.14.0 42 | debugpy @ file:///tmp/build/80754af9/debugpy_1637091799509/work 43 | decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work 44 | deepdish==0.3.7 45 | DeepXDE==1.1.4 46 | defusedxml==0.7.1 47 | diffusers==0.13.1 48 | dnspython==2.2.1 49 | docker-pycreds==0.4.0 50 | easyDataverse==0.3.7 51 | einops==0.4.1 52 | email-validator==1.3.0 53 | entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work 54 | et-xmlfile==1.1.0 55 | executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1646044401614/work 56 | fastapi==0.87.0 57 | fastcore==1.4.4 58 | fastdownload==0.0.6 59 | fastjsonschema==2.15.3 60 | fastprogress==1.0.2 61 | filelock==3.7.0 62 | fire==0.4.0 63 | Flask==2.2.2 64 | Flask-Compress==1.13 65 | fonttools==4.25.0 66 | ftfy==6.1.1 67 | genson==1.2.2 68 | geojson==3.0.0 69 | gitdb==4.0.9 70 | GitPython==3.1.27 71 | google-auth==2.6.6 72 | google-auth-oauthlib==0.4.6 73 | grad-cam==1.3.9 74 | grpcio==1.46.3 75 | h11==0.14.0 76 | h5py==3.6.0 77 | huggingface-hub==0.12.1 78 | humanfriendly==10.0 79 | hydra-core==1.2.0 80 | hydra-joblib-launcher==1.2.0 81 | hydra-submitit-launcher==1.1.6 82 | idna @ file:///tmp/build/80754af9/idna_1637925883363/work 83 | imageio==2.19.5 84 | imbalanced-learn==0.9.1 85 | importlib-metadata==4.11.4 86 | inflect==5.6.2 87 | ipykernel==6.9.2 88 | ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1648413562175/work 89 | ipython-genutils==0.2.0 90 | ipywidgets==8.0.2 91 | isodate==0.6.1 92 | isort==5.10.1 93 | itsdangerous==2.1.2 94 | jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1649067096717/work 95 | Jinja2==3.1.2 96 | joblib @ file:///tmp/build/80754af9/joblib_1635411271373/work 97 | jsonschema==3.2.0 98 | jupyter==1.0.0 99 | jupyter-client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1633454794268/work 100 | jupyter-console==6.4.4 101 | jupyter-core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1645024267147/work 102 | jupyter-dash==0.4.2 103 | jupyter-server==1.23.3 104 | jupyterlab-pygments==0.2.2 105 | jupyterlab-widgets==3.0.3 106 | kaleido==0.2.1 107 | kiwisolver @ file:///opt/conda/conda-bld/kiwisolver_1638569886207/work 108 | langcodes==3.3.0 109 | lit==16.0.2 110 | llvmlite==0.38.1 111 | Markdown==3.3.7 112 | MarkupSafe==2.1.1 113 | matplotlib==3.5.3 114 | matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1631080358261/work 115 | medmnist==2.1.0 116 | mistune==2.0.4 117 | mkl-fft==1.3.1 118 | mkl-random @ file:///tmp/build/80754af9/mkl_random_1626186066731/work 119 | mkl-service==2.4.0 120 | ml-collections==0.1.1 121 | mlxtend==0.22.0 122 | mpmath==1.3.0 123 | 
munkres==1.1.4 124 | murmurhash==1.0.7 125 | mypy-extensions==0.4.3 126 | nbclassic==0.4.8 127 | nbclient==0.7.0 128 | nbconvert==7.2.5 129 | nbformat==5.4.0 130 | nbstripout==0.5.0 131 | nest-asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1648959695634/work 132 | networkx==2.8 133 | nltk==3.5 134 | notebook==6.5.3 135 | notebook_shim==0.2.2 136 | numba==0.55.2 137 | numexpr @ file:///tmp/build/80754af9/numexpr_1640689833592/work 138 | numpy==1.19.5 139 | nvidia-cublas-cu11==11.10.3.66 140 | nvidia-cuda-cupti-cu11==11.7.101 141 | nvidia-cuda-nvrtc-cu11==11.7.99 142 | nvidia-cuda-runtime-cu11==11.7.99 143 | nvidia-cudnn-cu11==8.5.0.96 144 | nvidia-cufft-cu11==10.9.0.58 145 | nvidia-curand-cu11==10.2.10.91 146 | nvidia-cusolver-cu11==11.4.0.1 147 | nvidia-cusparse-cu11==11.7.4.91 148 | nvidia-nccl-cu11==2.14.3 149 | nvidia-nvtx-cu11==11.7.91 150 | oauthlib==3.2.0 151 | omegaconf==2.2.3 152 | openapi-schema-validator==0.1.6 153 | openapi-spec-validator==0.3.3 154 | opencv-python==4.5.5.64 155 | openpyxl==3.1.0 156 | openyxl==0.1 157 | opt-einsum==3.3.0 158 | packaging @ file:///tmp/build/80754af9/packaging_1637314298585/work 159 | pandas==1.4.4 160 | pandocfilters==1.5.0 161 | parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work 162 | pathspec==0.10.2 163 | pathtools==0.1.2 164 | pathy==0.6.1 165 | patsy==0.5.2 166 | pdebench @ file:///home/cl522/PDEBench 167 | pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1602535608087/work 168 | phiflow==2.0.3 169 | pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work 170 | Pillow==9.0.1 171 | platformdirs==2.5.4 172 | plotly==5.11.0 173 | prance==0.22.11.4.0 174 | preshed==3.0.6 175 | prometheus-client==0.15.0 176 | promise==2.3 177 | prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1649130487073/work 178 | protobuf==3.20.1 179 | psutil @ file:///tmp/build/80754af9/psutil_1612297992929/work 180 | ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl 181 | pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work 182 | pyaml==21.10.1 183 | pyasn1==0.4.8 184 | pyasn1-modules==0.2.8 185 | pycocotools==2.0.6 186 | pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work 187 | pydantic==1.9.0 188 | pyDaRUS==1.0.5 189 | pyDataverse==0.3.1 190 | Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1641580240686/work 191 | pynvml==11.4.1 192 | pyOpenSSL @ file:///opt/conda/conda-bld/pyopenssl_1643788558760/work 193 | pyparsing @ file:///tmp/build/80754af9/pyparsing_1635766073266/work 194 | pyre-extensions==0.0.29 195 | pyro-api==0.1.2 196 | pyro-ppl==1.8.3 197 | pyrsistent==0.16.1 198 | pyshp==2.3.1 199 | PySnooper==1.1.1 200 | PySocks @ file:///tmp/build/80754af9/pysocks_1605305812635/work 201 | python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work 202 | python-dotenv==0.20.0 203 | python-forge==18.6.0 204 | pytorch-pretrained-vit==0.0.7 205 | pytorch-ranger==0.1.1 206 | pyts==0.12.0 207 | pytz==2021.3 208 | PyWavelets==1.3.0 209 | PyYAML==6.0 210 | pyzmq==19.0.2 211 | qtconsole==5.4.0 212 | QtPy==2.3.0 213 | regex==2022.4.24 214 | requests @ file:///opt/conda/conda-bld/requests_1641824580448/work 215 | requests-oauthlib==1.3.1 216 | retrying==1.3.3 217 | rsa==4.8 218 | ruamel.yaml==0.17.21 219 | ruamel.yaml.clib==0.2.7 220 | scikit-image==0.19.2 
221 | scikit-learn==1.1.1 222 | scikit-optimize==0.9.0 223 | scipy==1.8.1 224 | seaborn @ file:///tmp/build/80754af9/seaborn_1629307859561/work 225 | semver==2.13.0 226 | Send2Trash==1.8.0 227 | sentry-sdk==1.5.12 228 | setproctitle==1.2.3 229 | shapely==2.0.1 230 | shortuuid==1.0.9 231 | sip==4.19.13 232 | six @ file:///tmp/build/80754af9/six_1644875935023/work 233 | smart-open==5.2.1 234 | smmap==5.0.0 235 | sniffio==1.3.0 236 | soupsieve==2.3.2.post1 237 | spacy==3.3.1 238 | spacy-legacy==3.0.9 239 | spacy-loggers==1.0.2 240 | srsly==2.4.3 241 | stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1644872665635/work 242 | starlette==0.21.0 243 | statsmodels==0.13.2 244 | submitit==1.4.5 245 | sympy==1.11.1 246 | tables==3.7.0 247 | tenacity==8.1.0 248 | tensorboard-data-server==0.6.1 249 | tensorboard-plugin-wit==1.8.1 250 | termcolor==2.1.0 251 | terminado==0.17.0 252 | thinc==8.0.17 253 | threadpoolctl @ file:///Users/ktietz/demo/mc3/conda-bld/threadpoolctl_1629802263681/work 254 | tifffile==2022.4.8 255 | # Editable install with no version control (timm==0.8.3.dev0) 256 | -e /home/cl522/reproduce_work/FDN/imagenet 257 | tinycss2==1.2.1 258 | tokenizers==0.12.1 259 | toml==0.10.2 260 | tomli==2.0.1 261 | torch==1.13.1+cu116 262 | torch-optimizer==0.3.0 263 | torchaudio==0.13.1+cu116 264 | torchvision==0.14.1+cu116 265 | tornado @ file:///tmp/build/80754af9/tornado_1606942317143/work 266 | tqdm==4.64.1 267 | traitlets==5.5.0 268 | transformers==4.26.1 269 | triton==2.0.0 270 | ttach==0.0.3 271 | typed-ast==1.5.4 272 | typer==0.4.1 273 | typing-inspect==0.8.0 274 | typing_extensions @ file:///opt/conda/conda-bld/typing_extensions_1647553014482/work 275 | urllib3 @ file:///opt/conda/conda-bld/urllib3_1643638302206/work 276 | uvicorn==0.20.0 277 | vit-pytorch==0.35.2 278 | wandb==0.12.16 279 | wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1600965781394/work 280 | webdataset==0.2.39 281 | webencodings==0.5.1 282 | websocket-client==1.4.2 283 | Werkzeug==2.2.2 284 | wfdb==3.4.1 285 | widgetsnbextension==4.0.3 286 | xmltodict==0.13.0 287 | zipp==3.8.0 288 | -------------------------------------------------------------------------------- /utils/__pycache__/dcl_loss.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheliu-computation/M-FLAG-MICCAI2023/a7fea3ca69dfe2ff03d0cd7d945d1c6fff835618/utils/__pycache__/dcl_loss.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_builder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheliu-computation/M-FLAG-MICCAI2023/a7fea3ca69dfe2ff03d0cd7d945d1c6fff835618/utils/__pycache__/utils_builder.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_builder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cheliu-computation/M-FLAG-MICCAI2023/a7fea3ca69dfe2ff03d0cd7d945d1c6fff835618/utils/__pycache__/utils_builder.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cheliu-computation/M-FLAG-MICCAI2023/a7fea3ca69dfe2ff03d0cd7d945d1c6fff835618/utils/__pycache__/utils_dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils_dataset.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cheliu-computation/M-FLAG-MICCAI2023/a7fea3ca69dfe2ff03d0cd7d945d1c6fff835618/utils/__pycache__/utils_dataset.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils_trainer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cheliu-computation/M-FLAG-MICCAI2023/a7fea3ca69dfe2ff03d0cd7d945d1c6fff835618/utils/__pycache__/utils_trainer.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils_trainer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cheliu-computation/M-FLAG-MICCAI2023/a7fea3ca69dfe2ff03d0cd7d945d1c6fff835618/utils/__pycache__/utils_trainer.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/utils_builder.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torchvision
from transformers import AutoModel, AutoTokenizer

# raw ResNet-50 image encoder paired with CXR-BERT-general


class ResNet_CXRBert(torch.nn.Module):
    def __init__(self):
        super(ResNet_CXRBert, self).__init__()
        resnet = torchvision.models.resnet50(pretrained=True)

        self.encoder = resnet
        self.encoder.fc = nn.Identity()

        # visual projection head: 2048 -> 2048 -> 1024
        self.proj_v = nn.Sequential(
            nn.Linear(2048, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, 1024),
            nn.BatchNorm1d(1024, affine=False))

        # text projection head: 768 -> 2048 -> 1024
        self.proj_t = nn.Sequential(
            nn.Linear(768, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, 1024),
            nn.BatchNorm1d(1024, affine=False))

        url = 'microsoft/BiomedVLP-CXR-BERT-general'
        self.lm_model = AutoModel.from_pretrained(
            url, trust_remote_code=True, revision='main')
        self.tokenizer = AutoTokenizer.from_pretrained(
            url, trust_remote_code=True, revision='main')

    def _tokenize(self, text):
        tokenizer_output = self.tokenizer.batch_encode_plus(batch_text_or_text_pairs=text,
                                                            add_special_tokens=True,
                                                            truncation=True,
                                                            max_length=512,
                                                            padding='max_length',
                                                            return_tensors='pt')

        return tokenizer_output

    @torch.no_grad()
    def get_text_emb(self, input_ids, attention_mask):
        # the language model runs without gradients: it stays frozen during pre-training
        text_emb = self.lm_model(input_ids=input_ids,
                                 attention_mask=attention_mask).last_hidden_state
        return text_emb

    def forward(self, img, input_ids, attention_mask):
        img_emb = self.encoder(img)
        # reshape to (b, 2048)
        img_emb = img_emb.view(img_emb.shape[0], img_emb.shape[1])

        # last_hidden_state: [b, seq_len, 768]
        text_emb = self.get_text_emb(input_ids, attention_mask)

        # project both modalities to the shared 1024-dim latent space;
        # the text branch uses the [CLS] token embedding
        proj_img_emb = self.proj_v(img_emb)
        proj_text_emb = self.proj_t(text_emb[:, 0].contiguous())

        return {'img_emb': img_emb,
                'proj_img_emb': proj_img_emb,
                'proj_text_emb': proj_text_emb}
--------------------------------------------------------------------------------
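A minimal smoke test for `ResNet_CXRBert` is sketched below; it downloads the CXR-BERT-general weights from the Hugging Face Hub on first use, and the example reports and input size are placeholders rather than anything from the original training setup.

```python
# Minimal smoke test for ResNet_CXRBert (downloads CXR-BERT-general on first run).
import torch
from utils_builder import ResNet_CXRBert

model = ResNet_CXRBert().eval()
reports = ["No acute cardiopulmonary process.", "Mild cardiomegaly, no effusion."]
tokens = model._tokenize(reports)

with torch.no_grad():
    out = model(torch.randn(2, 3, 224, 224),
                tokens.input_ids, tokens.attention_mask)

print(out['img_emb'].shape)        # torch.Size([2, 2048])
print(out['proj_img_emb'].shape)   # torch.Size([2, 1024])
print(out['proj_text_emb'].shape)  # torch.Size([2, 1024])
```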
/utils/utils_dataset.py:
--------------------------------------------------------------------------------
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import transforms


class IaT_embed_dataset(Dataset):
    def __init__(self, image_data, transform=None, **args):
        self.img_data = image_data

        self.text_csv = args['text']
        self.mode = args['train_test']
        self.transform = transform

    def __len__(self):
        return (self.img_data.shape[0])

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # get image
        image = self.img_data[idx]
        image = Image.fromarray(image).convert("RGB")

        # get raw text: use the impression section, and append the findings
        # section when it is present (missing findings are stored as 'dumb' or NaN)
        findings = self.text_csv['findings'].iloc[idx]
        impression = self.text_csv['impression'].iloc[idx]
        if findings == 'dumb' or type(findings) == float:
            pass
        else:
            impression += findings
        text = impression

        sample = {'image': image, 'raw_text': text}

        if self.transform:
            # for 2 branch contrastive vision model (not useful for CLIP)
            if self.mode == 'train':
                sample['image'] = self.transform(sample['image'])

        return sample


class I_T_emb_dataset:

    def __init__(self, image_path, csv_path):
        self.image_path = image_path
        self.csv_path = csv_path

    def get_dataset(self, train_test, T=None):
        normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                         std=[0.5, 0.5, 0.5])

        if train_test == 'train':
            print('Apply Train-stage Transform!')

            Transforms = transforms.Compose([
                transforms.RandomCrop(224),
                transforms.RandomRotation(degrees=(0, 90)),
                transforms.RandomGrayscale(p=0.5),
                transforms.RandomPerspective(distortion_scale=0.5,
                                             p=0.5,
                                             interpolation=3),
                transforms.RandomAutocontrast(p=0.5),
                transforms.ToTensor(),
                normalize
            ])
        else:
            # pre-training only uses the train split; fall back to no augmentation
            Transforms = None
            print('No test stage in pretrain!')

        img_data = np.load(
            self.image_path, allow_pickle=True, mmap_mode='r')
        text_csv = pd.read_csv(
            self.csv_path, low_memory=False)

        misc_args = {'train_test': train_test,
                     'text': text_csv}

        dataset = IaT_embed_dataset(image_data=img_data,
                                    transform=Transforms,
                                    **misc_args)

        return dataset
--------------------------------------------------------------------------------
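A single-process sketch of the dataset pipeline, assuming the `.npy` file produced by `ext_data.py` and the master csv; both file names below are placeholders, not files shipped with the repository.

```python
# Hypothetical single-process check of the dataset pipeline; paths are placeholders.
from torch.utils.data import DataLoader
from utils_dataset import I_T_emb_dataset

dataset = I_T_emb_dataset(image_path='mimic_cxr_256.npy',   # placeholder
                          csv_path='master.csv')            # placeholder
train_set = dataset.get_dataset(train_test='train')

loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=0)
batch = next(iter(loader))
print(batch['image'].shape)   # torch.Size([4, 3, 224, 224])
print(batch['raw_text'][:1])  # one raw report string
```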
/utils/utils_optimizer.py:
--------------------------------------------------------------------------------
import torch
from torch import optim as optim


class LARS(optim.Optimizer):
    def __init__(self, params, lr, weight_decay=0, momentum=0.9, eta=0.001,
                 weight_decay_filter=False, lars_adaptation_filter=False):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
                        eta=eta, weight_decay_filter=weight_decay_filter,
                        lars_adaptation_filter=lars_adaptation_filter)
        super().__init__(params, defaults)

    def exclude_bias_and_norm(self, p):
        return p.ndim == 1

    @torch.no_grad()
    def step(self):
        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad

                if dp is None:
                    continue

                if not g['weight_decay_filter'] or not self.exclude_bias_and_norm(p):
                    dp = dp.add(p, alpha=g['weight_decay'])

                if not g['lars_adaptation_filter'] or not self.exclude_bias_and_norm(p):
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['eta'] * param_norm / update_norm), one), one)
                    dp = dp.mul(q)

                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)

                p.add_(mu, alpha=-g['lr'])
--------------------------------------------------------------------------------
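`pretrain/main.py` uses AdamW rather than this LARS implementation. If you wanted to experiment with LARS (typically paired with much larger batch sizes), a hedged sketch of the swap might look like the following; the `use_lars` flag, helper name, and hyper-parameter values are placeholders, not settings from the paper.

```python
# Illustrative only: how LARS from utils_optimizer.py could replace AdamW.
# The lr/weight_decay values below are placeholders, not the paper's settings.
import sys
import torch

sys.path.append("../utils")
from utils_optimizer import LARS


def build_optimizer(model, use_lars=False, lr=2.0e-5, weight_decay=5.0e-2):
    if use_lars:
        return LARS(model.parameters(), lr=lr, weight_decay=weight_decay,
                    weight_decay_filter=True, lars_adaptation_filter=True)
    return torch.optim.AdamW(model.parameters(), lr=lr,
                             weight_decay=weight_decay, betas=(0.9, 0.999))
```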
/utils/utils_trainer.py:
--------------------------------------------------------------------------------
# package import
import os

import torch
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

# import wandb

# image-text embedding diagnosis style trainer class (with language model)


class trainer_wBert:
    def __init__(self, model,
                 optimizer, device, model_name, **args):
        self.model = model
        self.optimizer = optimizer
        self.device = device
        self.model_name = model_name
        self.train_batch_size = args['batch_size']
        self.test_batch_size = args['test_batch_size']
        self.max_epochs = args['max_epochs']
        self.lr_max = args['lr']
        self.num_workers = args['num_workers']
        self.checkpoint_interval = args['checkpoint_interval']

    def orthogonal_loss(self, x1, x2):
        def off_diagonal(x):
            # return a flattened view of the off-diagonal elements of a square matrix
            n, m = x.shape
            assert n == m
            return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

        # feature cross-correlation matrix, normalised by the per-GPU batch size
        logits = torch.mm(x1.T, x2).to(self.device)
        logits.div_(self.train_batch_size)

        # push the diagonal towards 1 and the off-diagonal towards 0
        on_diag = torch.diagonal(logits).add_(-1).pow_(2).sum()
        off_diag = off_diagonal(logits).pow_(2).sum()
        loss = on_diag + 0.0051 * off_diag
        return loss / 2

    def align_loss(self, x, y):
        # symmetric cosine alignment between the projected image and text embeddings;
        # the two terms are identical, so this equals 2 * (2 - 2 * cosine similarity)
        x = F.normalize(x, dim=1)
        y = F.normalize(y, dim=1)
        loss = 2 - 2 * (x * y).sum(dim=-1)
        loss += 2 - 2 * (y * x).sum(dim=-1)
        return loss.mean()

    # training process
    def train_w_TextEmb(self, train_dataset):

        train_loader = DataLoader(train_dataset, batch_size=self.train_batch_size,
                                  num_workers=self.num_workers,
                                  drop_last=True, shuffle=False,
                                  sampler=DistributedSampler(train_dataset))

        model_checkpoints_folder = os.path.join('../checkpoints')
        if not os.path.exists(model_checkpoints_folder):
            print('create directory "{}" to save checkpoints!'.format(
                model_checkpoints_folder))
            print('---------------------------')
            os.mkdir(model_checkpoints_folder)
        else:
            print('directory "{}" already exists for saving checkpoints!'.format(
                model_checkpoints_folder))

        # automatically resume from checkpoint if it exists
        print('#########################################')
        print('Be patient..., checking checkpoint now...')
        resume_path = os.path.join(model_checkpoints_folder,
                                   self.model_name + '_checkpoint.pth')
        if os.path.exists(resume_path):
            ckpt = torch.load(resume_path, map_location='cpu')
            start_epoch = ckpt['epoch']
            self.model.load_state_dict(ckpt['model_state_dict'])
            self.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
            print('continue training successful!')
        else:
            start_epoch = 0
            print('Start training from epoch 0')

        print('#########################################')
        print('training start!')

        # cosine schedule with warm restarts; the hard-coded 4 matches the number
        # of DDP processes that the dataset is sharded across
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer,
            T_0=int(self.max_epochs * len(train_dataset) //
                    4 // self.train_batch_size * 0.4),
            T_mult=1,
            eta_min=1e-8,
        )
        niter = 1

        skip_scheduler = False
        scaler = GradScaler()

        for epoch_counter in tqdm(range(start_epoch, self.max_epochs + 1)):

            epoch_loss = 0
            epoch_loss_orthogonal, epoch_loss_align = 0, 0

            for data in tqdm(train_loader):
                # get raw text
                imp = data['raw_text']

                # get image
                img = data['image'].to(torch.float32).to(
                    self.device).contiguous()

                self.optimizer.zero_grad()

                # mixed-precision forward pass
                with autocast():
                    imp_tokenize_output = self.model.module._tokenize(imp)

                    input_ids = imp_tokenize_output.input_ids.to(
                        self.device).contiguous()
                    attention_mask = imp_tokenize_output.attention_mask.to(
                        self.device).contiguous()

                    output_dict = self.model(img, input_ids, attention_mask)
                    img_emb, proj_img_emb, proj_text_emb = output_dict['img_emb'], output_dict['proj_img_emb'], output_dict['proj_text_emb']

                    # orthogonality loss on the image latent space + image-text alignment loss
                    loss_orthogonal = self.orthogonal_loss(img_emb, img_emb)
                    loss_align = self.align_loss(proj_img_emb, proj_text_emb)

                    loss = loss_orthogonal + loss_align

                # accumulate loss for logging
                epoch_loss += loss.item()
                epoch_loss_orthogonal += loss_orthogonal.item()
                epoch_loss_align += loss_align.item()

                scaler.scale(loss).backward()
                scaler.step(self.optimizer)
                scaler.update()

                if not skip_scheduler:
                    scheduler.step()
                niter += 1

            if self.device == 0:

                epoch_iter = (len(train_dataset) // self.train_batch_size // 4)
                print(f'{epoch_counter} epoch loss is {epoch_loss / epoch_iter}!')

                if epoch_counter % 10 == 0:
                    torch.save(self.model.module.state_dict(),
                               os.path.join(model_checkpoints_folder,
                                            self.model_name + f'_{epoch_counter}' + '_total.pth'))

                # save final vision encoder
                torch.save(self.model.module.encoder.state_dict(),
                           os.path.join(model_checkpoints_folder,
                                        self.model_name + '_encoder.pth'))
                # save final total model
                torch.save(self.model.module.state_dict(),
                           os.path.join(model_checkpoints_folder,
                                        self.model_name + '_total.pth'))

    def save_checkpoints(self, epoch, PATH):

        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict()},
            PATH)
--------------------------------------------------------------------------------
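A hedged sketch of loading the saved vision encoder for a downstream task: the checkpoint name follows the `'<wandb_name>_encoder.pth'` pattern used in `utils_trainer.py` (here `xx`, matching the placeholder `wandb_name` in `config.yaml`), while the classification head and class count are illustrative placeholders.

```python
# Loading the pre-trained vision encoder for a downstream task (illustrative).
import torch
import torch.nn as nn
import torchvision

encoder = torchvision.models.resnet50(pretrained=False)
encoder.fc = nn.Identity()
state_dict = torch.load('../checkpoints/xx_encoder.pth', map_location='cpu')
encoder.load_state_dict(state_dict)

# attach a task-specific head, e.g. a 14-label CheXclusion-style classifier
classifier = nn.Sequential(encoder, nn.Linear(2048, 14))
```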