├── figs
│   ├── urfound_p0.jpg
│   ├── urfound_p1.jpg
│   ├── urfound_p2.jpg
│   └── urfound_p3.jpg
├── util
│   ├── lr_sched.py
│   ├── lr_decay.py
│   ├── engine_pretrain.py
│   ├── flair_dataloader
│   │   ├── dataset.py
│   │   ├── dataloader.py
│   │   ├── transforms.py
│   │   └── dictionary.py
│   ├── pos_embed.py
│   ├── dataset.py
│   ├── model_urfound.py
│   └── misc.py
├── LICENSE
├── requirements.txt
├── finetune
│   ├── datasets_finetune.py
│   ├── models_vit.py
│   └── engine_finetune.py
├── .gitignore
├── README.md
├── bert
│   ├── bert_encoder.py
│   └── bert.py
├── main_pretrain_urfound.py
└── main_finetune.py

/figs/urfound_p0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukkai/UrFound/HEAD/figs/urfound_p0.jpg
--------------------------------------------------------------------------------
/figs/urfound_p1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukkai/UrFound/HEAD/figs/urfound_p1.jpg
--------------------------------------------------------------------------------
/figs/urfound_p2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukkai/UrFound/HEAD/figs/urfound_p2.jpg
--------------------------------------------------------------------------------
/figs/urfound_p3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukkai/UrFound/HEAD/figs/urfound_p3.jpg
--------------------------------------------------------------------------------
/util/lr_sched.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate with half-cycle cosine after warmup"""
    if epoch < args.warmup_epochs:
        lr = args.lr * epoch / args.warmup_epochs
    else:
        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
            (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
    for param_group in optimizer.param_groups:
        if "lr_scale" in param_group:
            param_group["lr"] = lr * param_group["lr_scale"]
        else:
            param_group["lr"] = lr
    return lr
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 yukkai

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
absl-py==2.1.0
art==6.2
cachetools==4.2.4
charset-normalizer==3.3.2
cycler==0.11.0
filelock==3.12.2
fonttools==4.38.0
fsspec==2023.1.0
google-auth==1.35.0
google-auth-oauthlib==0.4.6
grpcio==1.62.2
huggingface-hub==0.16.4
idna==3.7
imageio==2.9.0
importlib-metadata==6.7.0
joblib==1.3.2
kiwisolver==1.4.5
Markdown==3.4.4
MarkupSafe==2.1.5
matplotlib==3.5.3
networkx==2.6.3
numpy==1.21.6
oauthlib==3.2.2
opencv-python==4.5.3.56
packaging==24.0
pandas==0.25.3
parameterized==0.9.0
Pillow==8.3.1
protobuf==3.17.3
pyasn1==0.5.1
pyasn1-modules==0.3.0
pycm==3.2
pydicom==2.3.0
pyparsing==3.1.2
python-dateutil==2.9.0.post0
pytz==2024.1
PyWavelets==1.3.0
PyYAML==6.0.1
regex==2024.4.16
requests==2.31.0
requests-oauthlib==2.0.0
rsa==4.9
safetensors==0.4.3
scikit-image==0.17.2
scikit-learn==0.24.2
scipy==1.5.4
six==1.16.0
tensorboard==2.6.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
threadpoolctl==3.1.0
tifffile==2021.11.2
timm==0.3.2
tokenizers==0.13.3
tqdm==4.62.1
transformers==4.30.2
typing_extensions==4.7.1
uncertainty-calibration==0.1.4
urllib3==2.0.7
Werkzeug==2.2.3
zipp==3.15.0
--------------------------------------------------------------------------------
/finetune/datasets_finetune.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Partly revised by YZ @UCL&Moorfields
# --------------------------------------------------------

import os
from torchvision import datasets, transforms
from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms.functional import InterpolationMode


def build_dataset(is_train, args):

    transform = build_transform(is_train, args)
    root = os.path.join(args.data_path, is_train)
    dataset = datasets.ImageFolder(root, transform=transform)

    return dataset


def build_transform(is_train, args):
    mean = IMAGENET_DEFAULT_MEAN
    std = IMAGENET_DEFAULT_STD
    # train transform
    # if is_train=='train':
    if 'train' in is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            # interpolation='bicubic',
            interpolation=InterpolationMode.BICUBIC,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )
        return transform

    # eval transform
    t = []
    if args.input_size <= 224:
        crop_pct = 224 / 256
    else:
        crop_pct = 1.0
    size = int(args.input_size / crop_pct)
    t.append(
        transforms.Resize(size, interpolation=InterpolationMode.BICUBIC),
    )
    t.append(transforms.CenterCrop(args.input_size))
    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(mean, std))
    return transforms.Compose(t)

--------------------------------------------------------------------------------
/finetune/models_vit.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Partly revised by YZ @UCL&Moorfields
# --------------------------------------------------------

from functools import partial

import torch
import torch.nn as nn

import timm.models.vision_transformer


class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
    """ Vision Transformer with support for global average pooling
    """
    def __init__(self, global_pool=False, **kwargs):
        super(VisionTransformer, self).__init__(**kwargs)

        self.global_pool = global_pool
        if self.global_pool:
            norm_layer = kwargs['norm_layer']
            embed_dim = kwargs['embed_dim']
            self.fc_norm = norm_layer(embed_dim)

            del self.norm  # remove the original norm

    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        if self.global_pool:
            x = x[:, 1:, :].mean(dim=1)  # global pool without cls token
            outcome = self.fc_norm(x)
        else:
            x = self.norm(x)
            outcome = x[:, 0]

        return outcome


def vit_large_patch16(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_base_patch16(**kwargs):
    model = VisionTransformer(
        patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model
--------------------------------------------------------------------------------
/util/lr_decay.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# ELECTRA https://github.com/google-research/electra
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------

import json


def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
    """
    Parameter groups for layer-wise lr decay
    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
    """
    param_group_names = {}
    param_groups = {}

    num_layers = len(model.blocks) + 1

    layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))

    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue

        # no decay: all 1D parameters and model specific ones
        if p.ndim == 1 or n in no_weight_decay_list:
            g_decay = "no_decay"
            this_decay = 0.
        else:
            g_decay = "decay"
            this_decay = weight_decay

        layer_id = get_layer_id_for_vit(n, num_layers)
        group_name = "layer_%d_%s" % (layer_id, g_decay)

        if group_name not in param_group_names:
            this_scale = layer_scales[layer_id]

            param_group_names[group_name] = {
                "lr_scale": this_scale,
                "weight_decay": this_decay,
                "params": [],
            }
            param_groups[group_name] = {
                "lr_scale": this_scale,
                "weight_decay": this_decay,
                "params": [],
            }

        param_group_names[group_name]["params"].append(n)
        param_groups[group_name]["params"].append(p)

    # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))

    return list(param_groups.values())


def get_layer_id_for_vit(name, num_layers):
    """
    Assign a parameter with its layer id
    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
    """
    if name in ['cls_token', 'pos_embed']:
        return 0
    elif name.startswith('patch_embed'):
        return 0
    elif name.startswith('blocks'):
        return int(name.split('.')[1]) + 1
    else:
        return num_layers
--------------------------------------------------------------------------------
/util/engine_pretrain.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------

from typing import Iterable

import torch

import util.misc as misc
import util.lr_sched as lr_sched


def train_one_epoch(model: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    log_writer=None,
                    args=None):
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    mask_ratio = args.mask_ratio
    for data_iter_step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
        with torch.cuda.amp.autocast():
            loss, _, _ = model(batch, mask_ratio=mask_ratio)

        loss_value1 = loss[0].item()
        loss_value2 = loss[1].item()
        loss = loss[0] + loss[1]
        loss = loss / accum_iter
        loss_scaler(loss, optimizer, parameters=model.parameters(),
                    update_grad=(data_iter_step + 1) % accum_iter == 0)

        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(loss1=loss_value1)
        metric_logger.update(loss2=loss_value2)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce1 = misc.all_reduce_mean(loss_value1)
        loss_value_reduce2 = misc.all_reduce_mean(loss_value2)
        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
            """ We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('train_loss1', loss_value_reduce1, epoch_1000x)
            log_writer.add_scalar('train_loss2', loss_value_reduce2, epoch_1000x)
            log_writer.add_scalar('lr', lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

.DS_Store
--------------------------------------------------------------------------------
/util/flair_dataloader/dataset.py:
--------------------------------------------------------------------------------

"""
Specialized Dataset implementations (built on PyTorch) that create balanced
datasets with regard to the datasets used for pretraining.
"""

import collections.abc
import numpy as np

from torch.utils.data import Dataset as _TorchDataset
from typing import Any, Callable, Optional, Sequence, Union
from torch.utils.data import Subset


class Dataset(_TorchDataset):
    """
    A generic dataset with a length property and an optional callable data transform
    applied when fetching a data sample.
    If passing slicing indices, will return a PyTorch Subset, for example: `data: Subset = data[1:4]`,
    for more details, please check: https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset

    For example, typical input data can be a list of dictionaries::

        [{                            {                            {
             'img': 'image1.nii.gz',      'img': 'image2.nii.gz',      'img': 'image3.nii.gz',
             'seg': 'label1.nii.gz',      'seg': 'label2.nii.gz',      'seg': 'label3.nii.gz',
             'extra': 123                 'extra': 456                 'extra': 789
         },                           },                           }]
    """

    def __init__(self, data: Sequence, transform: Optional[Callable] = None) -> None:
        """
        Args:
            data: input data to load and transform to generate data for model.
            transform: a callable data transform on input data.

        """
        self.data = data
        self.transform: Any = transform

    def __len__(self) -> int:
        return len(self.data)

    def _transform(self, index: int):
        """
        Fetch single data item from `self.data`.
        """
        data_i = self.data[index]
        return self.transform(data_i) if self.transform is not None else data_i

    def __getitem__(self, index: Union[int, slice, Sequence[int]]):
        """
        Returns a `Subset` if `index` is a slice or Sequence, a data item otherwise.
        """
        if isinstance(index, slice):
            # data[:42]
            start, stop, step = index.indices(len(self))
            indices = range(start, stop, step)
            return Subset(dataset=self, indices=indices)
        if isinstance(index, collections.abc.Sequence):
            # data[[1, 3, 4]]
            return Subset(dataset=self, indices=index)
        return self._transform(index)


class UniformDataset(Dataset):
    def __init__(self, data, transform):
        super().__init__(data=data, transform=transform)
        self.datasetkey = []
        self.data_dic = []
        self.datasetnum = []
        self.datasetlen = 0
        self.dataset_split(data)

    def dataset_split(self, data):
        keys = []
        for img in data:
            keys.append(img["image_name"].split("/")[0])

        self.datasetkey = list(np.unique(keys))

        data_dic = {}
        for iKey in self.datasetkey:
            data_dic[iKey] = [data[iSample] for iSample in range(len(keys)) if keys[iSample] == iKey]
        self.data_dic = data_dic

        self.datasetnum = []
        for key, item in self.data_dic.items():
            assert len(item) != 0, f'the dataset {key} has no data'
            self.datasetnum.append(len(item))
        self.datasetlen = len(self.datasetkey)

    def _transform(self, set_key, data_index):
        data_i = self.data_dic[set_key][data_index]
        return self.transform(data_i) if self.transform is not None else data_i

    def __getitem__(self, index):
        ## the index generated outside is only used to select the dataset;
        ## the corresponding sample within that dataset is selected by the np.random.randint call below
        set_index = index % self.datasetlen
        set_key = self.datasetkey[set_index]

        data_index = np.random.randint(self.datasetnum[set_index], size=1)[0]
        return self._transform(set_key, data_index)
--------------------------------------------------------------------------------
/util/pos_embed.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------

import numpy as np

import torch

# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)  # np.float is a deprecated alias; plain float is equivalent
    omega /= embed_dim / 2.  # i.e. omega / (embed_dim / 2.), matching the MAE reference implementation
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed

--------------------------------------------------------------------------------
/util/flair_dataloader/dataloader.py:
--------------------------------------------------------------------------------
"""
Dataset and DataLoader preparation for vision-language pre-training.
"""

import pandas as pd
import ast
from torchvision.transforms import Compose
from torch.utils.data import DataLoader

from util.flair_dataloader.dataset import Dataset, UniformDataset
from util.flair_dataloader.transforms import LoadImage, ImageScaling, SelectRelevantKeys, CopyDict,\
    ProduceDescription, AugmentDescription


def get_loader(dataframes_path, data_root_path, datasets, balance=False, batch_size=8, num_workers=0,
               banned_categories=None, caption="A fundus photograph of [CLS]", augment_description=True):

    """
    Dataloader generation for vision-language pretraining. Reads the per-dataset
    dataframes of the assembly, combines them into a unified data structure, and
    then configures a dataloader for training.
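
    Illustrative usage (a sketch; the paths are placeholders, not files shipped with this repo):
        loaders = get_loader("./pretrain_data/fundus_labels/", "./pretrain_data/fundus_oct_images/",
                             datasets=["04_RFMid"], balance=True)
        train_loader = loaders["train"]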
    """

    # Prepare data sample pre-processing transforms
    transforms = Compose([
        CopyDict(),
        LoadImage(),
        ImageScaling(),
        ProduceDescription(caption=caption),
        AugmentDescription(augment=augment_description),
        SelectRelevantKeys()
    ])

    # Assemble dataframes into a combined data structure
    print("Setting assembly data...")
    data = []
    for iDataset in datasets:
        print("Processing data: " + iDataset)

        dataframe = pd.read_csv(dataframes_path + iDataset + ".csv")

        for i in range(len(dataframe)):
            data_i = dataframe.loc[i, :].to_dict()
            data_i["categories"] = ast.literal_eval(data_i["categories"])
            data_i["atributes"] = ast.literal_eval(data_i["atributes"])

            # Remove banned words - for evaluating on incremental categories
            banned = False
            if banned_categories is not None:
                for iCat in data_i["categories"]:
                    for iiCat in banned_categories:
                        if iiCat in iCat:
                            banned = True
            if banned:
                continue

            # Add sample to general data
            data_i["image_name"] = data_i["image"]
            data_i["image_path"] = data_root_path + data_i["image"]
            data.append(data_i)

    print('Total assembly data samples: {}'.format(len(data)))

    # Set data
    if balance:
        train_dataset = UniformDataset(data=data, transform=transforms)
    else:
        train_dataset = Dataset(data=data, transform=transforms)

    # Set dataloader
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    # Set dataloaders in dict
    dataloaders = {"train": train_loader}

    return dataloaders




def get_data_list(dataframes_path, data_root_path, datasets, balance=False, batch_size=8, num_workers=0,
                  banned_categories=None, caption="A fundus photograph of [CLS]", augment_description=True):

    """
    Sample-list assembly for vision-language pretraining. Reads the per-dataset
    dataframes of the assembly and combines them into a unified list of sample
    dictionaries (no dataloader is built here).
    """

    # Assemble dataframes into a combined data structure
    print("Setting assembly data...")
    data = []
    for iDataset in datasets:

        dataframe = pd.read_csv(dataframes_path + iDataset + ".csv")
        print("Processing data: " + iDataset, len(dataframe))

        for i in range(len(dataframe)):
            data_i = dataframe.loc[i, :].to_dict()
            data_i["categories"] = ast.literal_eval(data_i["categories"])
            data_i["atributes"] = ast.literal_eval(data_i["atributes"])

            # Remove banned words - for evaluating on incremental categories
            banned = False
            if banned_categories is not None:
                for iCat in data_i["categories"]:
                    for iiCat in banned_categories:
                        if iiCat in iCat:
                            banned = True
            if banned:
                continue

            # Add sample to general data
            data_i["image_name"] = data_i["image"]
            data_i["image_path"] = data_root_path + data_i["image"]
            data.append(data_i)

    print('Total assembly data samples: {}'.format(len(data)))

    return data

--------------------------------------------------------------------------------
/util/flair_dataloader/transforms.py:
--------------------------------------------------------------------------------
"""
Methods for image and text loading, pre-processing and generation
for vision-language pretraining. Also includes data augmentation
utilities.
"""

import numpy as np
import random
import torch
import copy

from PIL import Image
from torchvision.transforms import Resize
from util.flair_dataloader.dictionary import definitions
import torchvision.transforms as transforms
from torchvision.transforms.functional import InterpolationMode


class LoadImage():
    """
    Load images, organize channels, and standardize intensity.
    """
    def __init__(self, target="image_path"):
        self.target = target

    def __call__(self, data):
        img = Image.open(data[self.target]).convert('RGB')
        data[self.target.replace("_path", "")] = img
        return data


class ImageScaling():

    """
    Method for image scaling. It includes two options: scaling from canvas, to avoid image distortions,
    and regular scaling through resizing.
    """

    def __init__(self, size=(512, 512), canvas=True, target="image"):
        self.size = size
        self.canvas = canvas
        self.target = target

        # self.transforms = torch.nn.Sequential(
        #     Resize(self.size),
        # )
        self.transforms = transforms.Compose([
            transforms.RandomResizedCrop(448, scale=(0.2, 1.0), interpolation=InterpolationMode.BICUBIC),  # 3 is bicubic
            transforms.RandomHorizontalFlip(),
            # transforms.Grayscale(num_output_channels=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
            # transforms.Normalize(mean=[0.4978], std=[0.2449])])

    def __call__(self, data):
        img = data[self.target]
        img = self.transforms(img)
        # if not self.canvas or (img.shape[-1] == img.shape[-2]):
        #     img = self.transforms(img)
        # else:
        #     sizes = img.shape[-2:]
        #     max_size = max(sizes)
        #     scale = max_size/self.size[0]
        #     img = Resize((int(img.shape[-2]/scale), int((img.shape[-1]/scale))))(img)
        #     img = torch.nn.functional.pad(img, (0, self.size[0] - img.shape[-1], 0, self.size[1] - img.shape[-2], 0, 0))

        data[self.target] = img
        return data


class ProduceDescription():

    """
    Method that creates naive text prompts combining a prompt template, atributes (e.g. noisy), and categories
    (e.g. cataract). Also, this method is used to integrate text data with the modality prompt template.
    """

    def __init__(self, caption):
        self.caption = caption

    def __call__(self, data):

        # Create text
        atr_sample = random.sample(data['atributes'], 1)[0] if len(data['atributes']) > 0 else ""
        cat_sample = random.sample(data['categories'], 1)[0] if len(data['categories']) > 0 else ""

        data["sel_category"] = cat_sample
        if 'OCT' in atr_sample:
            data["report"] = ['An Optical Coherence Tomography Image shows ' + cat_sample.lower()]
        else:
            data["report"] = [self.caption.replace("[ATR]", atr_sample).replace("[CLS]", cat_sample).replace("  ", " ")]

        return data


class AugmentDescription():

    """
    Method that augments naive text prompts into expert knowledge prompts by replacing the category name
    with expert descriptions of the target category.
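
    Illustrative sketch: given the naive prompt "A fundus photograph of cataract",
    if "cataract" appears as a key in `dictionary.definitions`, the category name may
    be swapped for one of its expert descriptions (or kept, since the original name
    stays in the candidate pool sampled below).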
    """

    def __init__(self, augment=False):
        self.augment = augment

    def __call__(self, data):

        if self.augment:
            if data["image_name"].split("/")[0] not in ["00_OCTCELL", "06_EYENET", "11_STARE", "08_ODIR-5K", "31_JICHI"]:
                if data["sel_category"] in list(definitions.keys()):
                    prompts = [data["sel_category"]] + definitions[data["sel_category"]]
                    new_cat = random.sample(prompts, 1)[0]
                    data["report"][0] = data["report"][0].replace(data["sel_category"], new_cat)
                    data["augmented_category"] = new_cat

        return data


class CopyDict():
    def __call__(self, data):
        d = copy.deepcopy(data)
        return d


class SelectRelevantKeys():

    def __init__(self, target_keys=None):
        if target_keys is None:
            target_keys = ['image', 'report', 'sel_category', 'atributes']
        self.target_keys = target_keys

    def __call__(self, data):
        d = {key: data[key] for key in self.target_keys}
        return d
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# 【MICCAI 2024】 UrFound: Towards Universal Retinal Foundation Models via Knowledge-Guided Masked Modeling

This repo is the official implementation of [UrFound](https://arxiv.org/pdf/2408.05618).

![UrFound overview](figs/urfound_p0.jpg)


## Abstract

Retinal foundation models aim to learn generalizable representations from diverse retinal images, facilitating label-efficient model adaptation across various ophthalmic tasks. Despite their success, current retinal foundation models are generally restricted to a single imaging modality, such as Color Fundus Photography (CFP) or Optical Coherence Tomography (OCT), limiting their versatility. Moreover, these models may struggle to fully leverage expert annotations and overlook the valuable domain knowledge essential for domain-specific representation learning. To overcome these limitations, we introduce UrFound, a retinal foundation model designed to learn universal representations from both multimodal retinal images and domain knowledge. UrFound is equipped with a modality-agnostic image encoder and accepts either CFP or OCT images as inputs. To integrate domain knowledge into representation learning, we encode expert annotations in text supervision and propose a knowledge-guided masked modeling strategy for model pre-training. It involves reconstructing randomly masked patches of retinal images while predicting masked text tokens conditioned on the corresponding retinal image. This approach aligns multimodal images and textual expert annotations within a unified latent space, facilitating generalizable and domain-specific representation learning. Experimental results demonstrate that UrFound exhibits strong generalization ability and data efficiency when adapting to various tasks in retinal image analysis. By training on ~180k retinal images, UrFound significantly outperforms the state-of-the-art retinal foundation model trained on up to 1.6 million unlabelled images across 8 public retinal datasets.

## Framework

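In brief (a summary of the paper, reflected in the code): a modality-agnostic image encoder embeds randomly masked CFP or OCT patches, an MAE-style decoder reconstructs the masked image patches, and a lightweight BERT-style masked language model (`bert/bert_encoder.py`) predicts masked tokens of the textual expert annotation conditioned on the image latent; the two objectives are logged as `train_loss1`/`train_loss2` in `util/engine_pretrain.py`.
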
![UrFound framework](figs/urfound_p1.jpg)


## Get started

### Installation

```bash
# Clone this repo
git clone https://github.com/yukkai/UrFound.git
cd UrFound

# Create a conda environment
conda create -n urfound python=3.7.5

# Activate the environment
conda activate urfound

# Install dependencies
pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirements.txt
```

### Datasets

* Pretrain dataset (see [FLAIR](https://github.com/jusiro/FLAIR) for more details)

![Pretraining datasets](figs/urfound_p2.jpg)

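The pretraining loader (`util/dataset.py` together with `util/flair_dataloader/dataloader.py`) reads one CSV of labels per source dataset and resolves each CSV's `image` column against a common image root. A minimal sketch of the assumed on-disk layout (the two directory names are the defaults hard-coded near the top of `util/dataset.py`; replace the `TODO` path prefixes there with your own path):

```
pretrain_data/
├── fundus_labels/         # one <DATASET>.csv per source, e.g. 04_RFMid.csv
└── fundus_oct_images/     # images referenced by each CSV's "image" column
```
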

* Finetune dataset (see [RETFound](https://github.com/rmaphoh/RETFound_MAE) for more details)

![Finetuning datasets](figs/urfound_p3.jpg)

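For finetuning, `finetune/datasets_finetune.py` builds a `torchvision.datasets.ImageFolder` from `os.path.join(args.data_path, split)`, so each downstream task is assumed to be arranged as class-named subfolders per split (the exact split names are a sketch; the training split only needs to contain the substring "train"):

```
<data_path>/
├── train/<class_name>/*.jpg
├── val/<class_name>/*.jpg
└── test/<class_name>/*.jpg
```
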

### How to Run

* Pretrain
```bash
pretrained_model='model pre-trained on ImageNet'
pretrain_data='downloaded pretrain dataset'
output_path='output path'

CUDA_VISIBLE_DEVICES=0 python ./main_pretrain_urfound.py \
    --num_workers 32 \
    --accum_iter 2 \
    --batch_size 128 \
    --model urmodel \
    --norm_pix_loss \
    --mask_ratio 0.75 \
    --epochs 200 \
    --warmup_epochs 40 \
    --blr 1.5e-4 --weight_decay 0.05 \
    --resume ${pretrained_model} \
    --data_path ${pretrain_data} \
    --output_dir ${output_path} \
    --data_mode fundus_oct
```

* Finetune
```bash
data_i='downstream task dataset'
nb_classes='class num'
Pretraining_model='pretrained model'
Out_folder='output path'

CUDA_VISIBLE_DEVICES=0 python ./main_finetune.py \
    --batch_size 16 \
    --world_size 1 \
    --model vit_base_patch16 \
    --epochs 50 \
    --blr 5e-3 --layer_decay 0.65 \
    --weight_decay 0.05 --drop_path 0.2 \
    --nb_classes ${nb_classes} \
    --data_path ./${data_i}/ \
    --task ${data_i}/ \
    --finetune ${Pretraining_model} \
    --input_size 224 \
    --log_dir ${Out_folder}/ \
    --output_dir ${Out_folder}/
```
## Release
* Pretrained model [[Checkpoints](https://huggingface.co/yyyyk/UrFound)]

## Citation

```
@article{yu2024urfound,
  title={UrFound: Towards Universal Retinal Foundation Models via Knowledge-Guided Masked Modeling},
  author={Yu, Kai and Zhou, Yang and Bai, Yang and Da Soh, Zhi and Xu, Xinxing and Goh, Rick Siow Mong and Cheng, Ching-Yu and Liu, Yong},
  journal={arXiv preprint arXiv:2408.05618},
  year={2024}
}
```

## Acknowledgements

We extend our appreciation to the developers of the [RETFound](https://github.com/rmaphoh/RETFound_MAE), [FLAIR](https://github.com/jusiro/FLAIR) and [MRM](https://github.com/RL4M/MRM-pytorch) projects for sharing their open-source implementations and providing guidance on preparing the data.
--------------------------------------------------------------------------------
/bert/bert_encoder.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from .bert import MyBertMaskedLM
from transformers.configuration_utils import PretrainedConfig

class BertConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`BertModel`] or a [`TFBertModel`]. It is used to
    instantiate a BERT model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the BERT
    [bert-base-uncased](https://huggingface.co/bert-base-uncased) architecture.
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`BertModel`] or [`TFBertModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        type_vocab_size (`int`, *optional*, defaults to 2):
            The vocabulary size of the `token_type_ids` passed when calling [`BertModel`] or [`TFBertModel`].
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
            Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
            positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
            [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
            For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
            with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        classifier_dropout (`float`, *optional*):
            The dropout ratio for the classification head.
    Examples:
    ```python
    >>> from transformers import BertModel, BertConfig
    >>> # Initializing a BERT bert-base-uncased style configuration
    >>> configuration = BertConfig()
    >>> # Initializing a model from the bert-base-uncased style configuration
    >>> model = BertModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    model_type = "bert"

    def __init__(
        self,
        vocab_size=30000,
        hidden_size=384,
        num_hidden_layers=6,
        num_attention_heads=6,
        intermediate_size=1536,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=100,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        use_cache=True,
        classifier_dropout=None,
        **kwargs
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        self.classifier_dropout = classifier_dropout

class BertEncoder(nn.Module):
    def __init__(self):
        super(BertEncoder, self).__init__()

        self.model = MyBertMaskedLM(BertConfig())

    def forward(self, latent, ids, labels, attn_mask, token_type):

        outputs = self.model(latent, ids, attn_mask, token_type, labels=labels)

        return outputs
--------------------------------------------------------------------------------
/util/dataset.py:
--------------------------------------------------------------------------------
from copy import deepcopy
import os
from typing import List, Tuple
from PIL import Image
import pandas as pd
import torch
from torch.utils.data import Dataset
import tokenizers
import random

from transformers import AutoModel, AutoTokenizer

from util.flair_dataloader import dataloader


# ===========
dataframes_path = 'TODO CHANGE TO YOUR DATAPATH./pretrain_data/fundus_labels/'
data_root_path = 'TODO CHANGE TO YOUR DATAPATH./pretrain_data/fundus_oct_images/'
datasets_debug = ["04_RFMid", "00_OCTCELL"]

datasets_oct = ["00_OCTCELL"]
datasets_fundus = ["01_EYEPACS", "04_RFMid",
                   "06_DEN", "07_LAG", "08_ODIR", "10_PARAGUAY",
                   "11_STARE", "12_ARIA", "14_AGAR300", "16_FUND-OCT",
                   "18_DRIONS-DB", "19_Drishti-GS1",
                   "20_E-ophta", "21_G1020", "23_HRF", "24_ORIGA", "26_ROC",
                   "28_OIA-DDR", "30_SUSTech-SYSU", "31_JICHI",
                   "32_CHAKSU", "33_DR1-2", "35_ScarDat", "36_ACRIMA", "37_DeepDRiD_test", "37_DeepDRiD_train_eval"]
datasets_fundus_oct = ["00_OCTCELL", "01_EYEPACS", "04_RFMid",
                       "06_DEN", "07_LAG", "08_ODIR", "10_PARAGUAY",
                       "11_STARE", "12_ARIA", "14_AGAR300", "16_FUND-OCT",
                       "18_DRIONS-DB", "19_Drishti-GS1",
                       "20_E-ophta", "21_G1020", "23_HRF", "24_ORIGA", "26_ROC",
"24_ORIGA", "26_ROC", 34 | "28_OIA-DDR", "30_SUSTech-SYSU", "31_JICHI", 35 | "32_CHAKSU", "33_DR1-2", "35_ScarDat", "36_ACRIMA", "37_DeepDRiD_test", "37_DeepDRiD_train_eval"] 36 | 37 | balance = True 38 | batch_size = 16 39 | num_workers = 10 40 | banned_categories = [] 41 | caption = "A [ATR] fundus photograph of [CLS]" 42 | augment_description = True 43 | from torchvision.transforms import Compose 44 | from util.flair_dataloader.transforms import LoadImage, ImageScaling, SelectRelevantKeys, CopyDict,\ 45 | ProduceDescription, AugmentDescription 46 | # =========== 47 | 48 | 49 | def pil_loader(path: str) -> Image.Image: 50 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 51 | with open(path, 'rb') as f: 52 | img = Image.open(f) 53 | return img.convert('RGB') 54 | 55 | 56 | class MultimodalBertDataset_flair(Dataset): 57 | def __init__( 58 | self, 59 | data_mode, 60 | max_caption_length: int = 100 61 | ): 62 | if data_mode == 'debug': 63 | datasets_train = datasets_debug 64 | elif data_mode == 'fundus': 65 | datasets_train = datasets_fundus 66 | elif data_mode == 'oct': 67 | datasets_train = datasets_oct 68 | elif data_mode == 'fundus_oct': 69 | datasets_train = datasets_fundus_oct 70 | 71 | self.data_list = dataloader.get_data_list(dataframes_path, 72 | data_root_path, datasets_train, balance, 73 | batch_size, num_workers, banned_categories, 74 | caption, augment_description) 75 | 76 | self.transforms = Compose([ 77 | CopyDict(), 78 | LoadImage(), 79 | ImageScaling(), 80 | ProduceDescription(caption=caption), 81 | AugmentDescription(augment=augment_description), 82 | SelectRelevantKeys() 83 | ]) 84 | 85 | self.max_caption_length = max_caption_length 86 | # self.data_root = data_root 87 | # self.transform = transform 88 | # self.images_list, self.report_list = self.read_csv() 89 | # # random 90 | # random_seed = 42 91 | # random.seed(random_seed) 92 | # random.shuffle(self.images_list) 93 | # random.seed(random_seed) 94 | # random.shuffle(self.report_list) 95 | self.tokenizer = tokenizers.Tokenizer.from_pretrained('bert-base-uncased') 96 | # self.tokenizer = AutoTokenizer.from_pretrained('emilyalsentzer/Bio_ClinicalBERT') 97 | # self.tokenizer.model_max_length = 77 98 | 99 | # self.tokenizer = tokenizers.Tokenizer.from_file("mimic_wordpiece.json") 100 | self.idxtoword = {v: k for k, v in self.tokenizer.get_vocab().items()} 101 | self.tokenizer.enable_truncation(max_length=self.max_caption_length) 102 | self.tokenizer.enable_padding(length=self.max_caption_length) 103 | 104 | def __len__(self): 105 | return len(self.data_list) 106 | 107 | def _random_mask(self,tokens): 108 | masked_tokens = deepcopy(tokens) 109 | for i in range(1, masked_tokens.shape[1]-1): 110 | if masked_tokens[0][i] == 0: 111 | break 112 | 113 | if masked_tokens[0][i-1] == 3 and self.idxtoword[masked_tokens[0][i].item()][0:2] == '##': 114 | masked_tokens[0][i] = 3 115 | continue 116 | 117 | if masked_tokens[0][i-1] != 3 and self.idxtoword[masked_tokens[0][i].item()][0:2] == '##': 118 | continue 119 | 120 | prob = random.random() 121 | if prob < 0.5: 122 | masked_tokens[0][i] = 3 123 | 124 | return masked_tokens 125 | 126 | def __getitem__(self, index): 127 | batch = self.transforms(self.data_list[index]) 128 | image = batch['image'] 129 | sent = batch['report'][0] 130 | data_moda = torch.tensor(0) 131 | if 'OCT' in batch['atributes']: 132 | data_moda = torch.tensor(1) 133 | 134 | # image = pil_loader(self.images_list[index]) 135 | # image = self.transform(image) 136 | # 
        # sent = '[CLS] ' + sent

        encoded = self.tokenizer.encode(sent)
        ids = torch.tensor(encoded.ids).unsqueeze(0)
        attention_mask = torch.tensor(encoded.attention_mask).unsqueeze(0)
        type_ids = torch.tensor(encoded.type_ids).unsqueeze(0)
        masked_ids = self._random_mask(ids)
        return image, ids, attention_mask, type_ids, masked_ids, data_moda

    # def read_csv(self):
    #     csv_path = os.path.join(self.data_root, 'training.csv')
    #     df = pd.read_csv(csv_path, sep=',')
    #     return df["image_path"], df["report_content"]

    def collate_fn(self, instances: List[Tuple]):
        image_list, ids_list, attention_mask_list, type_ids_list, masked_ids_list, datamoda_list = [], [], [], [], [], []
        # flatten
        for b in instances:
            image, ids, attention_mask, type_ids, masked_ids, moda_ids = b
            image_list.append(image)
            ids_list.append(ids)
            attention_mask_list.append(attention_mask)
            type_ids_list.append(type_ids)
            masked_ids_list.append(masked_ids)
            datamoda_list.append(moda_ids)

        # stack
        image_stack = torch.stack(image_list)
        ids_stack = torch.stack(ids_list).squeeze()
        attention_mask_stack = torch.stack(attention_mask_list).squeeze()
        type_ids_stack = torch.stack(type_ids_list).squeeze()
        masked_ids_stack = torch.stack(masked_ids_list).squeeze()
        moda_ids_stack = torch.stack(datamoda_list).squeeze()

        # sort and add to dictionary
        return_dict = {
            "image": image_stack,
            "labels": ids_stack,
            "attention_mask": attention_mask_stack,
            "type_ids": type_ids_stack,
            "ids": masked_ids_stack,
            'tag': moda_ids_stack
        }

        return return_dict
--------------------------------------------------------------------------------
/finetune/engine_finetune.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Partly revised by YZ @UCL&Moorfields
# --------------------------------------------------------

import math
import sys
import csv
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import Mixup
from timm.utils import accuracy
from typing import Iterable, Optional
import util.misc as misc
import util.lr_sched as lr_sched
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, average_precision_score, multilabel_confusion_matrix
from pycm import *
import matplotlib.pyplot as plt
import numpy as np




def misc_measures(confusion_matrix):

    acc = []
    sensitivity = []
    specificity = []
    precision = []
    G = []
    F1_score_2 = []
    mcc_ = []

    for i in range(1, confusion_matrix.shape[0]):
        cm1 = confusion_matrix[i]
        acc.append(1. * (cm1[0, 0] + cm1[1, 1]) / np.sum(cm1))
        sensitivity_ = 1. * cm1[1, 1] / (cm1[1, 0] + cm1[1, 1])
        sensitivity.append(sensitivity_)
        specificity_ = 1. * cm1[0, 0] / (cm1[0, 1] + cm1[0, 0])
        specificity.append(specificity_)
        precision_ = 1. * cm1[1, 1] / (cm1[1, 1] + cm1[0, 1])
        precision.append(precision_)
        G.append(np.sqrt(sensitivity_ * specificity_))
        F1_score_2.append(2 * precision_ * sensitivity_ / (precision_ + sensitivity_))
        mcc = (cm1[0, 0] * cm1[1, 1] - cm1[0, 1] * cm1[1, 0]) / np.sqrt((cm1[0, 0] + cm1[0, 1]) * (cm1[0, 0] + cm1[1, 0]) * (cm1[1, 1] + cm1[1, 0]) * (cm1[1, 1] + cm1[0, 1]))
        mcc_.append(mcc)

    acc = np.array(acc).mean()
    sensitivity = np.array(sensitivity).mean()
    specificity = np.array(specificity).mean()
    precision = np.array(precision).mean()
    G = np.array(G).mean()
    F1_score_2 = np.array(F1_score_2).mean()
    mcc_ = np.array(mcc_).mean()

    return acc, sensitivity, specificity, precision, G, F1_score_2, mcc_





def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    mixup_fn: Optional[Mixup] = None, log_writer=None,
                    args=None):
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        with torch.cuda.amp.autocast():
            outputs = model(samples)
            loss = criterion(outputs, targets)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        loss /= accum_iter
        loss_scaler(loss, optimizer, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=False,
                    update_grad=(data_iter_step + 1) % accum_iter == 0)
update_grad=(data_iter_step + 1) % accum_iter == 0) 108 | if (data_iter_step + 1) % accum_iter == 0: 109 | optimizer.zero_grad() 110 | 111 | torch.cuda.synchronize() 112 | 113 | metric_logger.update(loss=loss_value) 114 | min_lr = 10. 115 | max_lr = 0. 116 | for group in optimizer.param_groups: 117 | min_lr = min(min_lr, group["lr"]) 118 | max_lr = max(max_lr, group["lr"]) 119 | 120 | metric_logger.update(lr=max_lr) 121 | 122 | loss_value_reduce = misc.all_reduce_mean(loss_value) 123 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 124 | """ We use epoch_1000x as the x-axis in tensorboard. 125 | This calibrates different curves when batch size changes. 126 | """ 127 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 128 | log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x) 129 | log_writer.add_scalar('lr', max_lr, epoch_1000x) 130 | 131 | # gather the stats from all processes 132 | metric_logger.synchronize_between_processes() 133 | print("Averaged stats:", metric_logger) 134 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 135 | 136 | 137 | 138 | 139 | @torch.no_grad() 140 | def evaluate(data_loader, model, device, task, epoch, mode, num_class): 141 | criterion = torch.nn.CrossEntropyLoss() 142 | 143 | metric_logger = misc.MetricLogger(delimiter=" ") 144 | header = 'Test:' 145 | 146 | if not os.path.exists(task): 147 | os.makedirs(task) 148 | 149 | prediction_decode_list = [] 150 | prediction_list = [] 151 | true_label_decode_list = [] 152 | true_label_onehot_list = [] 153 | 154 | # switch to evaluation mode 155 | model.eval() 156 | 157 | for batch in metric_logger.log_every(data_loader, 10, header): 158 | images = batch[0] 159 | target = batch[-1] 160 | images = images.to(device, non_blocking=True) 161 | target = target.to(device, non_blocking=True) 162 | true_label=F.one_hot(target.to(torch.int64), num_classes=num_class) 163 | 164 | # compute output 165 | with torch.cuda.amp.autocast(): 166 | output = model(images) 167 | loss = criterion(output, target) 168 | prediction_softmax = nn.Softmax(dim=1)(output) 169 | _,prediction_decode = torch.max(prediction_softmax, 1) 170 | _,true_label_decode = torch.max(true_label, 1) 171 | 172 | prediction_decode_list.extend(prediction_decode.cpu().detach().numpy()) 173 | true_label_decode_list.extend(true_label_decode.cpu().detach().numpy()) 174 | true_label_onehot_list.extend(true_label.cpu().detach().numpy()) 175 | prediction_list.extend(prediction_softmax.cpu().detach().numpy()) 176 | 177 | acc1,_ = accuracy(output, target, topk=(1,2)) 178 | 179 | batch_size = images.shape[0] 180 | metric_logger.update(loss=loss.item()) 181 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 182 | # gather the stats from all processes 183 | true_label_decode_list = np.array(true_label_decode_list) 184 | prediction_decode_list = np.array(prediction_decode_list) 185 | confusion_matrix = multilabel_confusion_matrix(true_label_decode_list, prediction_decode_list,labels=[i for i in range(num_class)]) 186 | acc, sensitivity, specificity, precision, G, F1, mcc = misc_measures(confusion_matrix) 187 | 188 | auc_roc = roc_auc_score(true_label_onehot_list, prediction_list,multi_class='ovr',average='macro') 189 | auc_pr = average_precision_score(true_label_onehot_list, prediction_list,average='macro') 190 | 191 | metric_logger.synchronize_between_processes() 192 | 193 | print('Sklearn Metrics - Acc: {:.4f} AUC-roc: {:.4f} AUC-pr: {:.4f} F1-score: {:.4f} MCC: {:.4f}'.format(acc, 
auc_roc, auc_pr, F1, mcc)) 194 | results_path = task+'/_metrics_{}.csv'.format(mode) 195 | with open(results_path,mode='a',newline='',encoding='utf8') as cfa: 196 | wf = csv.writer(cfa) 197 | data2=[[acc,sensitivity,specificity,precision,auc_roc,auc_pr,F1,mcc,metric_logger.loss]] 198 | for i in data2: 199 | wf.writerow(i) 200 | 201 | 202 | if 'test' in mode: 203 | cm = ConfusionMatrix(actual_vector=true_label_decode_list, predict_vector=prediction_decode_list) 204 | cm.plot(cmap=plt.cm.Blues,number_label=True,normalized=True,plot_lib="matplotlib") 205 | plt.savefig(task+'/confusion_matrix_{}.jpg'.format(mode),dpi=600,bbox_inches ='tight') 206 | 207 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()},auc_roc 208 | 209 | -------------------------------------------------------------------------------- /main_pretrain_urfound.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | import argparse 12 | import datetime 13 | import json 14 | import numpy as np 15 | import os 16 | import time 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.backends.cudnn as cudnn 21 | from torch.utils.tensorboard import SummaryWriter 22 | 23 | import timm 24 | 25 | assert timm.__version__ == "0.3.2" # version check 26 | import timm.optim.optim_factory as optim_factory 27 | 28 | import util.misc as misc 29 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 30 | 31 | import util.model_urfound as model_urfound 32 | 33 | from util.engine_pretrain import train_one_epoch 34 | from util.dataset import MultimodalBertDataset_flair 35 | 36 | def get_args_parser(): 37 | parser = argparse.ArgumentParser('UrFound pre-training', add_help=False) 38 | parser.add_argument('--batch_size', default=64, type=int, 39 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 40 | parser.add_argument('--epochs', default=400, type=int) 41 | parser.add_argument('--accum_iter', default=1, type=int, 42 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 43 | 44 | # Model parameters 45 | parser.add_argument('--model', default='urmodel', type=str, metavar='MODEL', 46 | help='Name of model to train') 47 | 48 | parser.add_argument('--input_size', default=224, type=int, 49 | help='images input size') 50 | 51 | parser.add_argument('--mask_ratio', default=0.75, type=float, 52 | help='Masking ratio (percentage of removed patches).') 53 | 54 | parser.add_argument('--norm_pix_loss', action='store_true', 55 | help='Use (per-patch) normalized pixels as targets for computing loss') 56 | parser.set_defaults(norm_pix_loss=False) 57 | 58 | # Optimizer parameters 59 | parser.add_argument('--weight_decay', type=float, default=0.05, 60 | help='weight decay (default: 0.05)') 61 | 62 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 63 | help='learning rate (absolute lr)') 64 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', 65 | help='base learning rate: absolute_lr = 
base_lr * total_batch_size / 256') 66 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 67 | help='lower lr bound for cyclic schedulers that hit 0') 68 | 69 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', 70 | help='epochs to warmup LR') 71 | 72 | # Dataset parameters 73 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 74 | help='dataset path') 75 | 76 | parser.add_argument('--output_dir', default='./output_dir', 77 | help='path where to save, empty for no saving') 78 | parser.add_argument('--log_dir', default='./output_dir', 79 | help='path where to tensorboard log') 80 | parser.add_argument('--device', default='cuda', 81 | help='device to use for training / testing') 82 | parser.add_argument('--seed', default=0, type=int) 83 | parser.add_argument('--resume', default='', 84 | help='resume from checkpoint') 85 | 86 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 87 | help='start epoch') 88 | parser.add_argument('--num_workers', default=10, type=int) 89 | parser.add_argument('--pin_mem', action='store_true', 90 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 91 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 92 | parser.set_defaults(pin_mem=True) 93 | 94 | # distributed training parameters 95 | parser.add_argument('--world_size', default=1, type=int, 96 | help='number of distributed processes') 97 | parser.add_argument('--local_rank', default=-1, type=int) 98 | parser.add_argument('--dist_on_itp', action='store_true') 99 | parser.add_argument('--dist_url', default='env://', 100 | help='url used to set up distributed training') 101 | 102 | parser.add_argument('--data_mode', default='fundus', type=str, 103 | help='dataset mode: debug / fundus / oct / fundus_oct') 104 | return parser 105 | 106 | 107 | def main(args): 108 | misc.init_distributed_mode(args) 109 | 110 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 111 | print("{}".format(args).replace(', ', ',\n')) 112 | 113 | device = torch.device(args.device) 114 | 115 | # fix the seed for reproducibility 116 | seed = args.seed + misc.get_rank() 117 | torch.manual_seed(seed) 118 | np.random.seed(seed) 119 | 120 | cudnn.benchmark = True 121 | 122 | # dataset 123 | dataset_train = MultimodalBertDataset_flair(args.data_mode) 124 | 125 | print(dataset_train) 126 | 127 | if True: # args.distributed: 128 | num_tasks = misc.get_world_size() 129 | global_rank = misc.get_rank() 130 | sampler_train = torch.utils.data.DistributedSampler( 131 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 132 | ) 133 | print("Sampler_train = %s" % str(sampler_train)) 134 | else: 135 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 136 | 137 | args.log_dir = os.path.join(args.output_dir, "logs") 138 | if global_rank == 0 and args.log_dir is not None: 139 | os.makedirs(args.log_dir, exist_ok=True) 140 | log_writer = SummaryWriter(log_dir=args.log_dir) 141 | else: 142 | log_writer = None 143 | 144 | data_loader_train = torch.utils.data.DataLoader( 145 | dataset_train, sampler=sampler_train, 146 | batch_size=args.batch_size, 147 | num_workers=args.num_workers, 148 | pin_memory=args.pin_mem, 149 | drop_last=True, 150 | collate_fn=dataset_train.collate_fn 151 | ) 152 | 153 | # define the model 154 | model = model_urfound.__dict__[args.model](norm_pix_loss=args.norm_pix_loss) 155 | 156 | model.to(device) 157 | 158 | 
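# Illustrative numbers for the linear lr scaling rule applied below (assumed
# values, not the paper's settings): with batch_size=64, accum_iter=2 and 4 GPUs,
# eff_batch_size = 64 * 2 * 4 = 512, so blr=1e-3 yields lr = 1e-3 * 512 / 256 = 2e-3.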
model_without_ddp = model 159 | print("Model = %s" % str(model_without_ddp)) 160 | 161 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 162 | 163 | if args.lr is None: # only base_lr is specified 164 | args.lr = args.blr * eff_batch_size / 256 165 | 166 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 167 | print("actual lr: %.2e" % args.lr) 168 | 169 | print("accumulate grad iterations: %d" % args.accum_iter) 170 | print("effective batch size: %d" % eff_batch_size) 171 | 172 | if args.distributed: 173 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 174 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 175 | model_without_ddp = model.module 176 | 177 | # following timm: set wd as 0 for bias and norm layers 178 | param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay) 179 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) 180 | print(optimizer) 181 | loss_scaler = NativeScaler() 182 | 183 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 184 | 185 | print(f"Start training for {args.epochs} epochs") 186 | start_time = time.time() 187 | for epoch in range(args.start_epoch, args.epochs): 188 | if args.distributed: 189 | data_loader_train.sampler.set_epoch(epoch) 190 | train_stats = train_one_epoch( 191 | model, data_loader_train, 192 | optimizer, device, epoch, loss_scaler, 193 | log_writer=log_writer, 194 | args=args 195 | ) 196 | if args.output_dir and (epoch % 20 == 0 or epoch + 1 == args.epochs): 197 | misc.save_model_pretrain( 198 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 199 | loss_scaler=loss_scaler, epoch=epoch) 200 | 201 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 202 | 'epoch': epoch,} 203 | 204 | if args.output_dir and misc.is_main_process(): 205 | if log_writer is not None: 206 | log_writer.flush() 207 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 208 | f.write(json.dumps(log_stats) + "\n") 209 | 210 | total_time = time.time() - start_time 211 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 212 | print('Training time {}'.format(total_time_str)) 213 | 214 | 215 | if __name__ == '__main__': 216 | args = get_args_parser() 217 | args = args.parse_args() 218 | if args.output_dir: 219 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 220 | main(args) -------------------------------------------------------------------------------- /bert/bert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import BertModel, BertForMaskedLM 3 | from transformers.modeling_outputs import MaskedLMOutput, BaseModelOutputWithPoolingAndCrossAttentions 4 | from torch.nn import CrossEntropyLoss 5 | 6 | 7 | class MyBertModel(BertModel): 8 | def __init__(self, config, add_pooling_layer=True): 9 | super().__init__(config) 10 | 11 | def forward( 12 | self, 13 | latent = None, 14 | input_ids=None, 15 | attention_mask=None, 16 | token_type_ids=None, 17 | position_ids=None, 18 | head_mask=None, 19 | inputs_embeds=None, 20 | encoder_hidden_states=None, 21 | encoder_attention_mask=None, 22 | past_key_values=None, 23 | use_cache=None, 24 | output_attentions=None, 25 | output_hidden_states=None, 26 | return_dict=None, 27 | ): 28 | r""" 29 | encoder_hidden_states (`torch.FloatTensor` of 
shape `(batch_size, sequence_length, hidden_size)`, *optional*): 30 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if 31 | the model is configured as a decoder. 32 | encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): 33 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in 34 | the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: 35 | - 1 for tokens that are **not masked**, 36 | - 0 for tokens that are **masked**. 37 | past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 38 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 39 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that 40 | don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all 41 | `decoder_input_ids` of shape `(batch_size, sequence_length)`. 42 | use_cache (`bool`, *optional*): 43 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 44 | `past_key_values`). 45 | """ 46 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 47 | output_hidden_states = ( 48 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 49 | ) 50 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 51 | 52 | if self.config.is_decoder: 53 | use_cache = use_cache if use_cache is not None else self.config.use_cache 54 | else: 55 | use_cache = False 56 | 57 | if input_ids is not None and inputs_embeds is not None: 58 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 59 | elif input_ids is not None: 60 | input_shape = input_ids.size() 61 | elif inputs_embeds is not None: 62 | input_shape = inputs_embeds.size()[:-1] 63 | else: 64 | raise ValueError("You have to specify either input_ids or inputs_embeds") 65 | 66 | batch_size, seq_length = input_shape 67 | device = input_ids.device if input_ids is not None else inputs_embeds.device 68 | 69 | # past_key_values_length 70 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 71 | 72 | if attention_mask is None: 73 | attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) 74 | 75 | if token_type_ids is None: 76 | if hasattr(self.embeddings, "token_type_ids"): 77 | buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] 78 | buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) 79 | token_type_ids = buffered_token_type_ids_expanded 80 | else: 81 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) 82 | 83 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 84 | # ourselves in which case we just need to make it broadcastable to all heads. 
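# UrFound's image conditioning happens further down in this forward pass: once the
# token embeddings are computed, the image latent (here a single pooled vector per
# sample) is broadcast via latent.unsqueeze(1) to [batch, 1, hidden] and added to
# every token embedding, so the masked language model decodes the report
# conditioned on the image.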
85 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) 86 | 87 | # If a 2D or 3D attention mask is provided for the cross-attention 88 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 89 | if self.config.is_decoder and encoder_hidden_states is not None: 90 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 91 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 92 | if encoder_attention_mask is None: 93 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 94 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) 95 | else: 96 | encoder_extended_attention_mask = None 97 | 98 | # Prepare head mask if needed 99 | # 1.0 in head_mask indicate we keep the head 100 | # attention_probs has shape bsz x n_heads x N x N 101 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] 102 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] 103 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) 104 | 105 | embedding_output = self.embeddings( 106 | input_ids=input_ids, 107 | position_ids=position_ids, 108 | token_type_ids=token_type_ids, 109 | inputs_embeds=inputs_embeds, 110 | past_key_values_length=past_key_values_length, 111 | ) 112 | embedding_output = embedding_output + latent.unsqueeze(1) 113 | encoder_outputs = self.encoder( 114 | embedding_output, 115 | attention_mask=extended_attention_mask, 116 | head_mask=head_mask, 117 | encoder_hidden_states=encoder_hidden_states, 118 | encoder_attention_mask=encoder_extended_attention_mask, 119 | past_key_values=past_key_values, 120 | use_cache=use_cache, 121 | output_attentions=output_attentions, 122 | output_hidden_states=output_hidden_states, 123 | return_dict=return_dict, 124 | ) 125 | sequence_output = encoder_outputs[0] 126 | pooled_output = self.pooler(sequence_output) if self.pooler is not None else None 127 | 128 | if not return_dict: 129 | return (sequence_output, pooled_output) + encoder_outputs[1:] 130 | 131 | return BaseModelOutputWithPoolingAndCrossAttentions( 132 | last_hidden_state=sequence_output, 133 | pooler_output=pooled_output, 134 | past_key_values=encoder_outputs.past_key_values, 135 | hidden_states=encoder_outputs.hidden_states, 136 | attentions=encoder_outputs.attentions, 137 | cross_attentions=encoder_outputs.cross_attentions, 138 | ) 139 | 140 | class MyBertMaskedLM(BertForMaskedLM): 141 | def __init__(self, config): 142 | super().__init__(config) 143 | self.bert = MyBertModel(config, add_pooling_layer=False) 144 | 145 | def forward( 146 | self, 147 | latent=None, 148 | input_ids=None, 149 | attention_mask=None, 150 | token_type_ids=None, 151 | position_ids=None, 152 | head_mask=None, 153 | inputs_embeds=None, 154 | encoder_hidden_states=None, 155 | encoder_attention_mask=None, 156 | labels=None, 157 | output_attentions=None, 158 | output_hidden_states=None, 159 | return_dict=None, 160 | ): 161 | r""" 162 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 163 | Labels for computing the masked language modeling loss. 
Indices should be in `[-100, 0, ..., 164 | config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the 165 | loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` 166 | """ 167 | 168 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 169 | 170 | outputs = self.bert( 171 | latent, 172 | input_ids, 173 | attention_mask=attention_mask, 174 | token_type_ids=token_type_ids, 175 | position_ids=position_ids, 176 | head_mask=head_mask, 177 | inputs_embeds=inputs_embeds, 178 | encoder_hidden_states=encoder_hidden_states, 179 | encoder_attention_mask=encoder_attention_mask, 180 | output_attentions=output_attentions, 181 | output_hidden_states=output_hidden_states, 182 | return_dict=return_dict, 183 | ) 184 | 185 | sequence_output = outputs[0] 186 | prediction_scores = self.cls(sequence_output) 187 | 188 | masked_lm_loss = None 189 | if labels is not None: 190 | loss_fct = CrossEntropyLoss() # -100 index = padding token 191 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 192 | 193 | if not return_dict: 194 | output = (prediction_scores,) + outputs[2:] 195 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 196 | 197 | return MaskedLMOutput( 198 | loss=masked_lm_loss, 199 | logits=prediction_scores, 200 | hidden_states=outputs.hidden_states, 201 | attentions=outputs.attentions, 202 | ) -------------------------------------------------------------------------------- /util/model_urfound.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
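# A minimal smoke test for MyBertMaskedLM above (the config geometry is an
# assumption for illustration; the repo's actual wrapper lives in
# bert/bert_encoder.py and is not shown in this excerpt). Note that
# MyBertModel.__init__ does not forward add_pooling_layer to super().__init__,
# so the (unused) pooler is always created.
import torch
from transformers import BertConfig
from bert.bert import MyBertMaskedLM

config = BertConfig(hidden_size=384, num_hidden_layers=6, num_attention_heads=6,
                    intermediate_size=1536)         # assumed small-BERT geometry
model = MyBertMaskedLM(config)
latent = torch.randn(2, config.hidden_size)         # one image vector per sample
ids = torch.randint(0, config.vocab_size, (2, 16))  # (masked) input token ids
out = model(latent=latent, input_ids=ids, labels=ids)
print(out.loss)                                     # MLM cross-entropy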
6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | 15 | import torch 16 | import torchvision 17 | import torch.nn as nn 18 | from torchvision.transforms.functional import InterpolationMode 19 | from timm.models.vision_transformer import PatchEmbed, Block 20 | 21 | from util.pos_embed import get_2d_sincos_pos_embed 22 | from bert.bert_encoder import BertEncoder 23 | 24 | class UrModel(nn.Module): 25 | """ Masked Autoencoder with VisionTransformer backbone 26 | """ 27 | def __init__(self, img_size=224, patch_size=16, in_chans=3, 28 | embed_dim=1024, depth=24, num_heads=16, 29 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 30 | mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False): 31 | super().__init__() 32 | 33 | # -------------------------------------------------------------------------- 34 | # image encoder specifics 35 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) 36 | num_patches = self.patch_embed.num_patches 37 | 38 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 39 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding 40 | 41 | self.blocks = nn.ModuleList([ 42 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) 43 | for i in range(depth)]) 44 | self.norm = norm_layer(embed_dim) 45 | # -------------------------------------------------------------------------- 46 | 47 | # -------------------------------------------------------------------------- 48 | # image decoder specifics 49 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) 50 | 51 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 52 | 53 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding 54 | 55 | self.decoder_blocks = nn.ModuleList([ 56 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) 57 | for i in range(decoder_depth)]) 58 | 59 | self.decoder_norm = norm_layer(decoder_embed_dim) 60 | self.decoder_pred = nn.Linear(decoder_embed_dim, (patch_size*2)**2 * in_chans, bias=True) 61 | # -------------------------------------------------------------------------- 62 | # Bert encoder 63 | self.bert_encoder = BertEncoder() 64 | self.bert_mlp = nn.Linear(embed_dim, 384, bias=True) 65 | self.norm_pix_loss = norm_pix_loss 66 | 67 | self.initialize_weights() 68 | 69 | def initialize_weights(self): 70 | # initialization 71 | # initialize (and freeze) pos_embed by sin-cos embedding 72 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 73 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 74 | 75 | decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 76 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) 77 | 78 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 79 | w = self.patch_embed.proj.weight.data 80 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 81 | 82 | # timm's 
trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 83 | torch.nn.init.normal_(self.cls_token, std=.02) 84 | torch.nn.init.normal_(self.mask_token, std=.02) 85 | 86 | # initialize nn.Linear and nn.LayerNorm 87 | self.apply(self._init_weights) 88 | 89 | def _init_weights(self, m): 90 | if isinstance(m, nn.Linear): 91 | # we use xavier_uniform following official JAX ViT: 92 | torch.nn.init.xavier_uniform_(m.weight) 93 | if isinstance(m, nn.Linear) and m.bias is not None: 94 | nn.init.constant_(m.bias, 0) 95 | elif isinstance(m, nn.LayerNorm): 96 | nn.init.constant_(m.bias, 0) 97 | nn.init.constant_(m.weight, 1.0) 98 | 99 | def patchify(self, imgs): 100 | """ 101 | imgs: (N, 3, H, W) 102 | x: (N, L, patch_size**2 *3) 103 | """ 104 | 105 | p = self.patch_embed.patch_size[0]*2 106 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 107 | 108 | h = w = imgs.shape[2] // p 109 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) 110 | x = torch.einsum('nchpwq->nhwpqc', x) 111 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) 112 | return x 113 | 114 | def unpatchify(self, x): 115 | """ 116 | x: (N, L, patch_size**2 *3) 117 | imgs: (N, 3, H, W) 118 | """ 119 | p = self.patch_embed.patch_size[0] * 2 120 | h = w = int(x.shape[1]**.5) 121 | assert h * w == x.shape[1] 122 | 123 | x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) 124 | x = torch.einsum('nhwpqc->nchpwq', x) 125 | imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) 126 | return imgs 127 | 128 | def random_masking(self, x, mask_ratio): 129 | """ 130 | Perform per-sample random masking by per-sample shuffling. 131 | Per-sample shuffling is done by argsort random noise. 132 | x: [N, L, D], sequence 133 | """ 134 | N, L, D = x.shape # batch, length, dim 135 | len_keep = int(L * (1 - mask_ratio)) 136 | 137 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 138 | 139 | # sort noise for each sample 140 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 141 | ids_restore = torch.argsort(ids_shuffle, dim=1) 142 | 143 | # keep the first subset 144 | ids_keep = ids_shuffle[:, :len_keep] 145 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 146 | 147 | # generate the binary mask: 0 is keep, 1 is remove 148 | mask = torch.ones([N, L], device=x.device) 149 | mask[:, :len_keep] = 0 150 | # unshuffle to get the binary mask 151 | mask = torch.gather(mask, dim=1, index=ids_restore) 152 | 153 | return x_masked, mask, ids_restore 154 | 155 | def forward_encoder(self, x, mask_ratio): 156 | # embed patches 157 | x = self.patch_embed(x) 158 | # add pos embed w/o cls token 159 | x = x + self.pos_embed[:, 1:, :] 160 | 161 | # masking: length -> length * mask_ratio 162 | x, mask, ids_restore = self.random_masking(x, mask_ratio) 163 | 164 | # append cls token 165 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 166 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 167 | x = torch.cat((cls_tokens, x), dim=1) 168 | 169 | # apply Transformer blocks 170 | for blk in self.blocks: 171 | x = blk(x) 172 | x = self.norm(x) 173 | 174 | return x, mask, ids_restore 175 | 176 | def forward_decoder(self, x, ids_restore): 177 | # embed tokens 178 | x = self.decoder_embed(x) 179 | 180 | # append mask tokens to sequence 181 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) 182 | x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token 183 | x_ = torch.gather(x_, dim=1, 
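# gathering along dim=1 with ids_restore undoes the random shuffle, so visible
# tokens and mask tokens land back at their original patch positions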
index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle 184 | x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token 185 | 186 | # add pos embed 187 | x = x + self.decoder_pos_embed 188 | 189 | # apply Transformer blocks 190 | for blk in self.decoder_blocks: 191 | x = blk(x) 192 | x = self.decoder_norm(x) 193 | 194 | # predictor projection 195 | x = self.decoder_pred(x) 196 | 197 | # remove cls token 198 | x = x[:, 1:, :] 199 | 200 | return x 201 | 202 | def forward_report_decoder(self, latent, caption_ids, labels, attention_mask, token_type_ids): 203 | latent = self.bert_mlp(latent) 204 | latent = latent[:, 1:, :].mean(dim=1) 205 | outputs = self.bert_encoder(latent, caption_ids, labels, attention_mask, token_type_ids) 206 | return outputs.loss 207 | 208 | def forward_loss(self, imgs, pred, mask): 209 | """ 210 | imgs: [N, 3, H, W] 211 | pred: [N, L, p*p*3] 212 | mask: [N, L], 0 is keep, 1 is remove, 213 | """ 214 | target = self.patchify(imgs) 215 | if self.norm_pix_loss: 216 | mean = target.mean(dim=-1, keepdim=True) 217 | var = target.var(dim=-1, keepdim=True) 218 | target = (target - mean) / (var + 1.e-6)**.5 219 | 220 | loss = (pred - target) ** 2 221 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 222 | 223 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 224 | return loss 225 | 226 | def forward(self, batch, mask_ratio=0.75): 227 | big_imgs = batch["image"] 228 | 229 | ids, labels, attention_mask, type_ids = batch["ids"], batch["labels"], batch["attention_mask"], batch["type_ids"] 230 | 231 | big_imgs = big_imgs.cuda() 232 | ids = ids.cuda() 233 | labels = labels.cuda() 234 | attention_mask = attention_mask.cuda() 235 | type_ids = type_ids.cuda() 236 | imgs = torchvision.transforms.Resize([224,224], interpolation=InterpolationMode.BICUBIC)(big_imgs) 237 | 238 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) 239 | report_loss = self.forward_report_decoder(latent, ids, labels, attention_mask, type_ids) 240 | pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] 241 | loss = self.forward_loss(big_imgs, pred, mask) 242 | return (loss, report_loss), pred, mask 243 | 244 | def urmodel(**kwargs): 245 | model = UrModel( 246 | patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, 247 | decoder_embed_dim=768, decoder_depth=4, decoder_num_heads=6, 248 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 249 | return model 250 | 251 | def mae_vit_base_patch16_dec512d8b(**kwargs): 252 | model = UrModel( 253 | patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, 254 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 255 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 256 | return model 257 | 258 | mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks 259 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
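# A self-contained replay of UrModel.random_masking above to make the bookkeeping
# concrete (ViT-B/16 numbers for a 224x224 input are assumed):
import torch

def random_masking(x, mask_ratio):
    N, L, D = x.shape
    len_keep = int(L * (1 - mask_ratio))
    noise = torch.rand(N, L, device=x.device)
    ids_shuffle = torch.argsort(noise, dim=1)    # low noise = keep
    ids_restore = torch.argsort(ids_shuffle, dim=1)
    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).repeat(1, 1, D))
    mask = torch.ones(N, L, device=x.device)
    mask[:, :len_keep] = 0                       # 0 = keep, 1 = remove
    mask = torch.gather(mask, 1, ids_restore)
    return x_masked, mask, ids_restore

x = torch.randn(2, 196, 768)                     # 14*14 patches, ViT-B width
x_masked, mask, _ = random_masking(x, mask_ratio=0.75)
print(x_masked.shape)                            # torch.Size([2, 49, 768])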
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import builtins 13 | import datetime 14 | import os 15 | import time 16 | from collections import defaultdict, deque 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.distributed as dist 21 | from torch._six import inf 22 | 23 | 24 | class SmoothedValue(object): 25 | """Track a series of values and provide access to smoothed values over a 26 | window or the global series average. 27 | """ 28 | 29 | def __init__(self, window_size=20, fmt=None): 30 | if fmt is None: 31 | fmt = "{median:.4f} ({global_avg:.4f})" 32 | self.deque = deque(maxlen=window_size) 33 | self.total = 0.0 34 | self.count = 0 35 | self.fmt = fmt 36 | 37 | def update(self, value, n=1): 38 | self.deque.append(value) 39 | self.count += n 40 | self.total += value * n 41 | 42 | def synchronize_between_processes(self): 43 | """ 44 | Warning: does not synchronize the deque! 45 | """ 46 | if not is_dist_avail_and_initialized(): 47 | return 48 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 49 | dist.barrier() 50 | dist.all_reduce(t) 51 | t = t.tolist() 52 | self.count = int(t[0]) 53 | self.total = t[1] 54 | 55 | @property 56 | def median(self): 57 | d = torch.tensor(list(self.deque)) 58 | return d.median().item() 59 | 60 | @property 61 | def avg(self): 62 | d = torch.tensor(list(self.deque), dtype=torch.float32) 63 | return d.mean().item() 64 | 65 | @property 66 | def global_avg(self): 67 | return self.total / self.count 68 | 69 | @property 70 | def max(self): 71 | return max(self.deque) 72 | 73 | @property 74 | def value(self): 75 | return self.deque[-1] 76 | 77 | def __str__(self): 78 | return self.fmt.format( 79 | median=self.median, 80 | avg=self.avg, 81 | global_avg=self.global_avg, 82 | max=self.max, 83 | value=self.value) 84 | 85 | 86 | class MetricLogger(object): 87 | def __init__(self, delimiter="\t"): 88 | self.meters = defaultdict(SmoothedValue) 89 | self.delimiter = delimiter 90 | 91 | def update(self, **kwargs): 92 | for k, v in kwargs.items(): 93 | if v is None: 94 | continue 95 | if isinstance(v, torch.Tensor): 96 | v = v.item() 97 | assert isinstance(v, (float, int)) 98 | self.meters[k].update(v) 99 | 100 | def __getattr__(self, attr): 101 | if attr in self.meters: 102 | return self.meters[attr] 103 | if attr in self.__dict__: 104 | return self.__dict__[attr] 105 | raise AttributeError("'{}' object has no attribute '{}'".format( 106 | type(self).__name__, attr)) 107 | 108 | def __str__(self): 109 | loss_str = [] 110 | for name, meter in self.meters.items(): 111 | loss_str.append( 112 | "{}: {}".format(name, str(meter)) 113 | ) 114 | return self.delimiter.join(loss_str) 115 | 116 | def synchronize_between_processes(self): 117 | for meter in self.meters.values(): 118 | meter.synchronize_between_processes() 119 | 120 | def add_meter(self, name, meter): 121 | self.meters[name] = meter 122 | 123 | def log_every(self, iterable, print_freq, header=None): 124 | i = 0 125 | if not header: 126 | header = '' 127 | start_time = time.time() 128 | end = time.time() 129 | iter_time = SmoothedValue(fmt='{avg:.4f}') 130 | data_time = SmoothedValue(fmt='{avg:.4f}') 131 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 132 | log_msg = [ 133 | header, 134 | '[{0' + space_fmt + '}/{1}]', 135 | 'eta: {eta}', 136 
| '{meters}', 137 | 'time: {time}', 138 | 'data: {data}' 139 | ] 140 | if torch.cuda.is_available(): 141 | log_msg.append('max mem: {memory:.0f}') 142 | log_msg = self.delimiter.join(log_msg) 143 | MB = 1024.0 * 1024.0 144 | for obj in iterable: 145 | data_time.update(time.time() - end) 146 | yield obj 147 | iter_time.update(time.time() - end) 148 | if i % print_freq == 0 or i == len(iterable) - 1: 149 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 150 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 151 | if torch.cuda.is_available(): 152 | print(log_msg.format( 153 | i, len(iterable), eta=eta_string, 154 | meters=str(self), 155 | time=str(iter_time), data=str(data_time), 156 | memory=torch.cuda.max_memory_allocated() / MB)) 157 | else: 158 | print(log_msg.format( 159 | i, len(iterable), eta=eta_string, 160 | meters=str(self), 161 | time=str(iter_time), data=str(data_time))) 162 | i += 1 163 | end = time.time() 164 | total_time = time.time() - start_time 165 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 166 | print('{} Total time: {} ({:.4f} s / it)'.format( 167 | header, total_time_str, total_time / len(iterable))) 168 | 169 | 170 | def setup_for_distributed(is_master): 171 | """ 172 | This function disables printing when not in master process 173 | """ 174 | builtin_print = builtins.print 175 | 176 | def print(*args, **kwargs): 177 | force = kwargs.pop('force', False) 178 | force = force or (get_world_size() > 8) 179 | if is_master or force: 180 | now = datetime.datetime.now().time() 181 | builtin_print('[{}] '.format(now), end='') # print with time stamp 182 | builtin_print(*args, **kwargs) 183 | 184 | builtins.print = print 185 | 186 | 187 | def is_dist_avail_and_initialized(): 188 | if not dist.is_available(): 189 | return False 190 | if not dist.is_initialized(): 191 | return False 192 | return True 193 | 194 | 195 | def get_world_size(): 196 | if not is_dist_avail_and_initialized(): 197 | return 1 198 | return dist.get_world_size() 199 | 200 | 201 | def get_rank(): 202 | if not is_dist_avail_and_initialized(): 203 | return 0 204 | return dist.get_rank() 205 | 206 | 207 | def is_main_process(): 208 | return get_rank() == 0 209 | 210 | 211 | def save_on_master(*args, **kwargs): 212 | if is_main_process(): 213 | torch.save(*args, **kwargs) 214 | 215 | 216 | def init_distributed_mode(args): 217 | if args.dist_on_itp: 218 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 219 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 220 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 221 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 222 | os.environ['LOCAL_RANK'] = str(args.gpu) 223 | os.environ['RANK'] = str(args.rank) 224 | os.environ['WORLD_SIZE'] = str(args.world_size) 225 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 226 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 227 | args.rank = int(os.environ["RANK"]) 228 | args.world_size = int(os.environ['WORLD_SIZE']) 229 | args.gpu = int(os.environ['LOCAL_RANK']) 230 | elif 'SLURM_PROCID' in os.environ: 231 | args.rank = int(os.environ['SLURM_PROCID']) 232 | args.gpu = args.rank % torch.cuda.device_count() 233 | else: 234 | print('Not using distributed mode') 235 | setup_for_distributed(is_master=True) # hack 236 | args.distributed = False 237 | return 238 | 239 | args.distributed = True 240 | 241 | torch.cuda.set_device(args.gpu) 242 | args.dist_backend = 'nccl' 243 | print('| distributed 
init (rank {}): {}, gpu {}'.format( 244 | args.rank, args.dist_url, args.gpu), flush=True) 245 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 246 | world_size=args.world_size, rank=args.rank) 247 | torch.distributed.barrier() 248 | setup_for_distributed(args.rank == 0) 249 | 250 | 251 | class NativeScalerWithGradNormCount: 252 | state_dict_key = "amp_scaler" 253 | 254 | def __init__(self): 255 | self._scaler = torch.cuda.amp.GradScaler() 256 | 257 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 258 | self._scaler.scale(loss).backward(create_graph=create_graph) 259 | if update_grad: 260 | if clip_grad is not None: 261 | assert parameters is not None 262 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 263 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 264 | else: 265 | self._scaler.unscale_(optimizer) 266 | norm = get_grad_norm_(parameters) 267 | self._scaler.step(optimizer) 268 | self._scaler.update() 269 | else: 270 | norm = None 271 | return norm 272 | 273 | def state_dict(self): 274 | return self._scaler.state_dict() 275 | 276 | def load_state_dict(self, state_dict): 277 | self._scaler.load_state_dict(state_dict) 278 | 279 | 280 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 281 | if isinstance(parameters, torch.Tensor): 282 | parameters = [parameters] 283 | parameters = [p for p in parameters if p.grad is not None] 284 | norm_type = float(norm_type) 285 | if len(parameters) == 0: 286 | return torch.tensor(0.) 287 | device = parameters[0].grad.device 288 | if norm_type == inf: 289 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 290 | else: 291 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 292 | return total_norm 293 | 294 | 295 | def save_model_pretrain(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 296 | output_dir = Path(args.output_dir) 297 | epoch_name = str(epoch) 298 | if loss_scaler is not None: 299 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 300 | for checkpoint_path in checkpoint_paths: 301 | to_save = { 302 | 'model': model_without_ddp.state_dict(), 303 | 'optimizer': optimizer.state_dict(), 304 | 'epoch': epoch, 305 | 'scaler': loss_scaler.state_dict(), 306 | 'args': args, 307 | } 308 | 309 | save_on_master(to_save, checkpoint_path) 310 | else: 311 | client_state = {'epoch': epoch} 312 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 313 | 314 | 315 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 316 | output_dir = Path(args.output_dir) 317 | epoch_name = str(epoch) 318 | if loss_scaler is not None: 319 | checkpoint_paths = [args.output_dir+args.task+'checkpoint-best.pth'] 320 | for checkpoint_path in checkpoint_paths: 321 | to_save = { 322 | 'model': model_without_ddp.state_dict(), 323 | 'optimizer': optimizer.state_dict(), 324 | 'epoch': epoch, 325 | 'scaler': loss_scaler.state_dict(), 326 | 'args': args, 327 | } 328 | 329 | save_on_master(to_save, checkpoint_path) 330 | else: 331 | client_state = {'epoch': epoch} 332 | model.save_checkpoint(save_dir=args.output_dir+args.task, tag="checkpoint-best", client_state=client_state) 333 | 334 | 335 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 336 | if args.resume: 337 | 
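# The loop a few lines below backfills any keys missing from the checkpoint with
# the model's own (randomly initialised) tensors, so the subsequent
# load_state_dict(..., strict=False) tolerates heads added after pre-training.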
if args.resume.startswith('https'): 338 | checkpoint = torch.hub.load_state_dict_from_url( 339 | args.resume, map_location='cpu', check_hash=True) 340 | else: 341 | checkpoint = torch.load(args.resume, map_location='cpu') 342 | # if 'ft' in args.output_dir and args.eval: 343 | for i in model_without_ddp.state_dict(): 344 | if i not in checkpoint['model']: 345 | checkpoint['model'][i] = model_without_ddp.state_dict()[i] 346 | 347 | model_without_ddp.load_state_dict(checkpoint['model'], strict=False) 348 | print("Resume checkpoint %s" % args.resume) 349 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 350 | optimizer.load_state_dict(checkpoint['optimizer']) 351 | args.start_epoch = checkpoint['epoch'] + 1 352 | if 'scaler' in checkpoint: 353 | loss_scaler.load_state_dict(checkpoint['scaler']) 354 | print("With optim & sched!") 355 | 356 | 357 | def all_reduce_mean(x): 358 | world_size = get_world_size() 359 | if world_size > 1: 360 | x_reduce = torch.tensor(x).cuda() 361 | dist.all_reduce(x_reduce) 362 | x_reduce = x_reduce / world_size 363 | return x_reduce.item() 364 | else: 365 | return x -------------------------------------------------------------------------------- /util/flair_dataloader/dictionary.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | definition: 4 | This script contains dictionaries for expert knowledge prompts 5 | of several fundus image conditions for pre-training. 6 | 7 | ensemble_prompts: 8 | Also, it presents a dictionary for creating prompt ensembles for 9 | zero-shot classification and transferability. 10 | 11 | datasets/abbreviations: 12 | Finally, the script contains abbreviations of several relevant 13 | conditions used for FLAIR pre-training, and datasets names. 
14 | """ 15 | 16 | # Expert knowledge definitions dictionary 17 | definitions = {"no diabetic retinopathy": ["no diabetic retinopathy", "no microaneurysms"], 18 | "mild diabetic retinopathy": ["only few microaneurysms"], 19 | "moderate diabetic retinopathy": ["many exudates near the macula", 20 | "many haemorrhages near the macula", 21 | "retinal thickening near the macula", 22 | "hard exudates", 23 | "cotton wool spots", 24 | "few severe haemorrhages"], 25 | "severe diabetic retinopathy": ["venous beading", 26 | "many severe haemorrhages", 27 | "intraretinal microvascular abnormality"], 28 | "proliferative diabetic retinopathy": ["preretinal or vitreous haemorrhage", 29 | "neovascularization"], 30 | "no referable diabetic macular edema": ["no apparent exudates"], 31 | "hard exudates": ["small white or yellowish deposits with sharp margins", "bright lesion"], 32 | "soft exudates": ["pale yellow or white areas with ill-defined edges", "cotton-wool spot", 33 | "small, whitish or grey, cloud-like, linear or serpentine, slightly elevated lesions" 34 | " with fimbriated edges"], 35 | "microaneurysms": ["small red dots"], 36 | "haemorrhages": ["dense, dark red, sharply outlined lesion"], 37 | "non clinically significant diabetic macular edema": ["presence of exudates outside the radius of one" 38 | " disc diameter from the macula center", 39 | "presence of exudates"], 40 | "age related macular degeneration": ["many small drusen", "few medium-sized drusen", "large drusen", 41 | "macular degeneration"], 42 | "media haze": ["vitreous haze", "pathological opacity", "the obscuration of fundus details by vitreous" 43 | " cells and protein exudation"], 44 | "drusens": ["yellow deposits under the retina", "numerous uniform round yellow-white lesions"], 45 | "pathologic myopia": ["anomalous disc, macular atrophy and possible tessellation"], 46 | "branch retinal vein occlusion": ["occlusion of one of the four major branch retinal veins"], 47 | "tessellation": ["large choroidal vessels at the posterior fundus"], 48 | "epiretinal membrane": ["greyish semi-translucent avascular membrane"], 49 | "laser scar": ["round or oval, yellowish-white with variable black pigment centrally", 50 | "50 to 200 micron diameter lesions"], 51 | "no laser scar": ["no laser scar"], 52 | "macular scar": ["macular scar"], 53 | "central serous retinopathy": ["subretinal fluid involving the fovea", "leakage"], 54 | "optic disc cupping": ["optic disc cupping"], 55 | "central retinal vein occlusion": ["central retinal vein occlusion"], 56 | "tortuous vessels": ["tortuous vessels"], 57 | "asteroid hyalosis": ["multiple sparking, yellow-white, and refractile opacities in the vitreous cavity", 58 | "vitreous opacities"], 59 | "optic disc pallor": ["pale yellow discoloration that can be segmental or generalized on optic disc"], 60 | "optic disc edema": ["optic disc edema"], 61 | "shunt": ["collateral vessels connecting the choroidal and the retinal vasculature", 62 | "collateral vessels of large caliber and lack of leakage"], 63 | "anterior ischemic optic neuropathy": ["anterior ischemic optic neuropathy"], 64 | "parafoveal telangiectasia": ["parafoveal telangiectasia"], 65 | "retinal traction": ["retinal traction"], 66 | "retinitis": ["retinitis"], 67 | "chorioretinitis": ["chorioretinitis"], 68 | "exudates": ["small white or yellowish white deposits with sharp margins", "bright lesion"], 69 | "retinal pigment epithelium changes": ["retinal pigment epithelium changes"], 70 | "macular hole": ["lesion in the macula", "grayish fovea"], 71 
| "retinitis pigmentosa": ["pigment deposits are present in the periphery"], 72 | "cotton wool spots": ["cotton wool spots", "soft exudates"], 73 | "colobomas": ["colobomas"], 74 | "optic disc pit maculopathy": ["optic disc pit maculopathy"], 75 | "preretinal haemorrhage": ["preretinal haemorrhage"], 76 | "myelinated nerve fibers": ["myelinated nerve fibers"], 77 | "haemorrhagic retinopathy": ["haemorrhagic retinopathy"], 78 | "central retinal artery occlusion": ["central retinal artery occlusion"], 79 | "tilted disc": ["tilted disc"], 80 | "cystoid macular edema": ["cysts in the macula region"], 81 | "post traumatic choroidal rupture": ["post traumatic choroidal rupture"], 82 | "choroidal folds": ["choroidal folds"], 83 | "vitreous haemorrhage": ["vitreous haemorrhage"], 84 | "macroaneurysm": ["macroaneurysm"], 85 | "vasculitis": ["vasculitis"], 86 | "branch retinal artery occlusion": ["branch retinal artery occlusion"], 87 | "plaque": ["plaque"], 88 | "haemorrhagic pigment epithelial detachment": ["haemorrhagic pigment epithelial detachment"], 89 | "collaterals": ["collaterals"], 90 | "normal": ["healthy", "no findings", "no lesion signs", "no glaucoma", "no retinopathy"], 91 | "large optic cup": ["abnormality in optic cup"], 92 | "retina detachment": ["retina detachment"], 93 | "Vogt-Koyanagi syndrome": ["Vogt-Koyanagi syndrome"], 94 | "maculopathy": ["maculopathy"], 95 | "glaucoma": ["optic nerve abnormalities", "abnormal size of the optic cup", 96 | "anomalous size in the optic disc"], 97 | "optic atrophy": ["optic atrophy"], 98 | "severe hypertensive retinopathy": ["flame shaped hemorrhages at the disc margin, blurred disc margins," 99 | " congested retinal veins, papilledema, and secondary macular " 100 | "exudates", "arterio-venous crossing changes, macular star and " 101 | "cotton wool spots"], 102 | "disc swelling and elevation": ["disc swelling and elevation"], 103 | "dragged disk": ["dragged disk"], 104 | "congenital disk abnormality": ["disk abnormality", "optic disk lesion"], 105 | "Bietti crystalline dystrophy": ["Bietti crystalline dystrophy"], 106 | "peripheral retinal degeneration and break": ["peripheral retinal degeneration and break"], 107 | "neoplasm": ["neoplasm"], 108 | "yellow-white spots flecks": ["yellow-white spots flecks"], 109 | "fibrosis": ["fibrosis"], 110 | "silicon oil": ["silicon oil"], 111 | "no proliferative diabetic retinopathy": ["diabetic retinopathy with no neovascularization", 112 | "no neovascularization"], 113 | "no glaucoma": ["no glaucoma"], 114 | "cataract": ["opacity in the macular area"], 115 | "hypertensive retinopathy": ["possible signs of haemorraghe with blot, dot, or flame-shaped", 116 | "possible presence of microaneurysm, cotton-wool spot, or hard exudate", 117 | "arteriolar narrowing", "vascular wall changes", "optic disk edema"], 118 | "neovascular age related macular degeneration": ["neovascular age-related macular degeneration"], 119 | "geographical age related macular degeneration": ["geographical age-related macular degeneration"], 120 | "acute central serous retinopathy": ["acute central serous retinopathy"], 121 | "chronic central serous retinopathy": ["chronic central serous retinopathy"], 122 | "no cataract": ["no cataract signs", "no obscure opacities"], 123 | "abnormal optic disc": ["abnormal optic disc"], 124 | "abnormal vessels": ["abnormal vessels"], 125 | "abnormal macula": ["abnormal macula"], 126 | "macular edema": ["macular edema"], 127 | "scar": ["scar"], 128 | "nevus": ["darkly pigmented lesion found in the back 
of the eye"], 129 | "increased cup disc": ["increased cup disc"], 130 | "intraretinal microvascular abnormalities": ["shunt vessels and appear as abnormal branching or" 131 | " dilation of existing blood vessels (capillaries) " 132 | "within the retina", "deeper in the retina than" 133 | " neovascularization, has blurrier edges, is more" 134 | " of a burgundy than a red, does not appear on the " 135 | "optic disc", "vascular loops confined within the" 136 | " retina"], 137 | "red small dots": ["microaneurysms"], 138 | "neovascularisation": ["neovascularisation"], 139 | "a disease": ["no healthy", "lesions"], 140 | "superficial haemorrhages": ["superficial haemorrhages"], 141 | "deep haemorrhages": ["deep haemorrhages"], 142 | "ungradable": ["no fundus", "very noisy", "noisy"], 143 | "noisy": ["noisy"], 144 | "normal macula": ["normal macula"], 145 | "macular degeneration": ["macular degeneration"], 146 | "diabetic retinopathy": ["diabetic retinopathy"], 147 | "no hypertensive retinopathy": ["no presence of hypertensive retinopathy"], 148 | "mild hypertensive retinopathy": ["mild arteriovenous ratio", "mild tortuosity", 149 | "focal arteriolar narrowing", 150 | "arteriovenous nicking"], 151 | "moderate hypertensive retinopathy": ["moderate arteriovenous ratio", "moderate tortuosity", 152 | "cotton wool spots", 153 | "flame-shaped haemorrhages"], 154 | "malignant hypertensive retinopathy": ["severe arteriovenous ratio", "severe tortuosity", 155 | "swelling optical disk", 156 | "flame-shaped haemorrhages"] 157 | } 158 | 159 | # Datasets names 160 | datasets = ["01_EYEPACS", "03_IDRID", "04_RFMid", "05_1000x39", "07_LAG", "09_PAPILA", "10_PARAGUAY", "12_ARIA", 161 | "14_AGAR300", "15_APTOS", "16_FUND-OCT", "17_DiaRetDB1", "18_DRIONS-DB", "19_Drishti-GS1", "20_E-ophta", 162 | "20_E-ophta", "21_G1020", "23_HRF", "24_ORIGA", "25_REFUGE", "26_ROC", "27_BRSET", "28_OIA-DDR", 163 | "02_MESIDOR", "05_20x3", "08_ODIR200x3", "13_FIVES"] 164 | 165 | # Categories abbreviations 166 | abbreviations = {"no diabetic retinopathy": "noDR", "mild diabetic retinopathy": "mildDR", 167 | "moderate diabetic retinopathy": "modDR", "severe diabetic retinopathy": "sevDR", 168 | "proliferative diabetic retinopathy": "prolDR", "diabetic macular edema": "DME", 169 | "no referable diabetic macular edema": "noDME", "hard exudates": "hEX", 170 | "soft exudates": "sEX", "microaneurysms": "MA", "haemorrhages": "HE", 171 | "non clinically significant diabetic macular edema": "nonCSDME", 172 | "age-related macular degeneration": "ARMD", "media haze": "MH", "drusens": "DN", 173 | "pathologic myopia": "MYA", "branch retinal vein occlusion": "BRVO", "tessellation": "TSLN", 174 | "epiretinal membrane": "ERM", "laser scar": "LS", "macular scar": "MS", 175 | "central serous retinopathy": "CSR", "optic disc cupping": "ODC", 176 | "central retinal vein occlusion": "CRVO", "tortuous vessels": "TV", "asteroid hyalosis": "AH", 177 | "optic disc pallor": "ODP", "optic disc edema": "ODE", 178 | "shunt": "ST", "anterior ischemic optic neuropathy": "AION", "parafoveal telangiectasia": "PT", 179 | "retinal traction": "RT", "retinitis": "RS", "chorioretinitis": "CRS", "exudates": "EX", 180 | "retinal pigment epithelium changes": "RPEC", "macular hole": "MHL", "retinitis pigmentosa": "RP", 181 | "cotton wool spots": "CWS", "colobomas": "CB", "optic disc pit maculopathy": "ODM", 182 | "preretinal haemorrhage": "PRH", "myelinated nerve fibers": "MNF", "haemorrhagic retinopathy": "HR", 183 | "central retinal artery occlusion": "CRAO", "tilted 
disc": "TD", "cystoid macular edema": "CME", 184 | "post traumatic choroidal rupture": "PTCR", "choroidal folds": "CF", "vitreous haemorrhage": "VH", 185 | "macroaneurysm": "MCA", "vasculitis": "VS", "branch retinal artery occlusion": "BRAO", "plaque": "PLQ", 186 | "haemorrhagic pigment epithelial detachment": "HPED", "collaterals": "CL", "normal": "N", 187 | "large optic cup": "LOC", "retina detachment": "RD", "Vogt-Koyanagi syndrome": "VKH", 188 | "maculopathy": "M", "glaucoma": "G", "optic atrophy": "OA", "severe hypertensive retinopathy": "sevHR", 189 | "disc swelling and elevation": "DSE", "dragged disk": "DD", "congenital disk abnormality": "CDA", 190 | "Bietti crystalline dystrophy": "BCD", "peripheral retinal degeneration and break": "PRDB", 191 | "neoplasm": "NP", "yellow-white spots flecks": "YWSF", "fibrosis": "fibrosis", "silicon oil": "SO", 192 | "no proliferative diabetic retinopathy": "noProlDR", "no glaucoma": "noG", "cataract": "CAT", 193 | "hypertensive retinopathy": "HR", "neovascular age-related macular degeneration": "neovARMD", 194 | "geographical age-related macular degeneration": "geoARMD", 195 | "acute central serous retinopathy": "acCSR", "chronic central serous retinopathy": "chCSR", 196 | "no cataract": "noCAT", "abnormal optic disc": "AOD", "abnormal vessels": "AV", 197 | "abnormal macula": "AM", "macular edema": "ME", "scar": "S", "nevus": "NE", 198 | "increased cup disc": "ICD", "intraretinal microvascular abnormalities": "IrMA", 199 | "red small dots": "ReSD", "neovascularisation": "neoV", "a disease": "Dis"} -------------------------------------------------------------------------------- /main_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
/main_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # Partly revised by YZ @UCL&Moorfields 4 | # -------------------------------------------------------- 5 | 6 | import argparse 7 | import datetime 8 | import json 9 | import numpy as np 10 | import os 11 | import time 12 | from pathlib import Path 13 | 14 | import torch 15 | import torch.backends.cudnn as cudnn 16 | from torch.utils.tensorboard import SummaryWriter 17 | 18 | import timm 19 | 20 | assert timm.__version__ == "0.3.2" # version check 21 | from timm.models.layers import trunc_normal_ 22 | from timm.data.mixup import Mixup 23 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 24 | 25 | import util.lr_decay as lrd 26 | import util.misc as misc 27 | from finetune.datasets_finetune import build_dataset, build_transform 28 | from util.pos_embed import interpolate_pos_embed 29 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 30 | 31 | import finetune.models_vit as models_vit 32 | 33 | from finetune.engine_finetune import train_one_epoch, evaluate 34 | from torchvision import datasets 35 | import random 36 | from torch.utils.data import Subset 37 | 38 | def get_args_parser(): 39 | parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False) 40 | parser.add_argument('--batch_size', default=64, type=int, 41 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)') 42 | parser.add_argument('--epochs', default=50, type=int) 43 | parser.add_argument('--accum_iter', default=1, type=int, 44 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 45 | 46 | # Model parameters 47 | parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', 48 | help='Name of model to train') 49 | 50 | parser.add_argument('--input_size', default=224, type=int, 51 | help='images input size') 52 | 53 | parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', 54 | help='Drop path rate (default: 0.1)') 55 | 56 | # Optimizer parameters 57 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 58 | help='Clip gradient norm (default: None, no clipping)') 59 | parser.add_argument('--weight_decay', type=float, default=0.05, 60 | help='weight decay (default: 0.05)') 61 | 62 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 63 | help='learning rate (absolute lr)') 64 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', 65 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 66 | parser.add_argument('--layer_decay', type=float, default=0.75, 67 | help='layer-wise lr decay from ELECTRA/BEiT') 68 | 69 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', 70 | help='lower lr bound for cyclic schedulers that hit 0') 71 | 72 | parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N', 73 | help='epochs to warmup LR') 74 |
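# Aside (not part of main_finetune.py): --layer_decay applies the ELECTRA/BEiT
# layer-wise rule that util.lr_decay.param_groups_lrd implements further down,
# scaling the lr of layer i by layer_decay ** (num_layers + 1 - i). A minimal
# sketch of the resulting scale schedule, assuming the 24-block ViT-Large named above:
num_layers = 24
layer_decay = 0.75
scales = [layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)]
# scales[0] (patch embedding) is ~7.5e-4 while scales[-1] (classifier head) is 1.0,
# so pre-trained early layers barely move and the fresh head trains at full lr.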
" + "(default: rand-m9-mstd0.5-inc1)'), 80 | parser.add_argument('--smoothing', type=float, default=0.1, 81 | help='Label smoothing (default: 0.1)') 82 | 83 | # * Random Erase params 84 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 85 | help='Random erase prob (default: 0.25)') 86 | parser.add_argument('--remode', type=str, default='pixel', 87 | help='Random erase mode (default: "pixel")') 88 | parser.add_argument('--recount', type=int, default=1, 89 | help='Random erase count (default: 1)') 90 | parser.add_argument('--resplit', action='store_true', default=False, 91 | help='Do not random erase first (clean) augmentation split') 92 | 93 | # * Mixup params 94 | parser.add_argument('--mixup', type=float, default=0, 95 | help='mixup alpha, mixup enabled if > 0.') 96 | parser.add_argument('--cutmix', type=float, default=0, 97 | help='cutmix alpha, cutmix enabled if > 0.') 98 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, 99 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 100 | parser.add_argument('--mixup_prob', type=float, default=1.0, 101 | help='Probability of performing mixup or cutmix when either/both is enabled') 102 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 103 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 104 | parser.add_argument('--mixup_mode', type=str, default='batch', 105 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') 106 | 107 | # * Finetuning params 108 | parser.add_argument('--finetune', default='',type=str, 109 | help='finetune from checkpoint') 110 | parser.add_argument('--task', default='',type=str, 111 | help='finetune from checkpoint') 112 | parser.add_argument('--global_pool', action='store_true') 113 | parser.set_defaults(global_pool=True) 114 | parser.add_argument('--cls_token', action='store_false', dest='global_pool', 115 | help='Use class token instead of global pool for classification') 116 | 117 | # Dataset parameters 118 | parser.add_argument('--data_path', default='/home/jupyter/Mor_DR_data/data/data/IDRID/Disease_Grading/', type=str, 119 | help='dataset path') 120 | parser.add_argument('--nb_classes', default=1000, type=int, 121 | help='number of the classification types') 122 | 123 | parser.add_argument('--output_dir', default='./output_dir', 124 | help='path where to save, empty for no saving') 125 | parser.add_argument('--log_dir', default='./output_dir', 126 | help='path where to tensorboard log') 127 | parser.add_argument('--device', default='cuda', 128 | help='device to use for training / testing') 129 | parser.add_argument('--seed', default=0, type=int) 130 | parser.add_argument('--resume', default='', 131 | help='resume from checkpoint') 132 | 133 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 134 | help='start epoch') 135 | parser.add_argument('--eval', action='store_true', 136 | help='Perform evaluation only') 137 | parser.add_argument('--dist_eval', action='store_true', default=False, 138 | help='Enabling distributed evaluation (recommended during training for faster monitor') 139 | parser.add_argument('--num_workers', default=10, type=int) 140 | parser.add_argument('--pin_mem', action='store_true', 141 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 142 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 143 | parser.set_defaults(pin_mem=True) 144 | 145 | # distributed 
156 | def main(args): 157 | misc.init_distributed_mode(args) 158 | 159 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 160 | print("{}".format(args).replace(', ', ',\n')) 161 | 162 | device = torch.device(args.device) 163 | 164 | # fix the seed for reproducibility 165 | seed = args.seed + misc.get_rank() 166 | torch.manual_seed(seed) 167 | np.random.seed(seed) 168 | 169 | cudnn.benchmark = True 170 | 171 | if args.partial_p: 172 | dataset_train = build_dataset(is_train=f'train_{args.partial_p}', args=args) 173 | else: 174 | dataset_train = build_dataset(is_train='train', args=args) 175 | dataset_val = build_dataset(is_train='val', args=args) 176 | dataset_test = build_dataset(is_train='test', args=args) 177 | 178 | if True: # args.distributed: 179 | num_tasks = misc.get_world_size() 180 | global_rank = misc.get_rank() 181 | sampler_train = torch.utils.data.DistributedSampler( 182 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 183 | ) 184 | print("Sampler_train = %s" % str(sampler_train)) 185 | if args.dist_eval: 186 | if len(dataset_val) % num_tasks != 0: 187 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 188 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 189 | 'equal num of samples per-process.') 190 | sampler_val = torch.utils.data.DistributedSampler( 191 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias 192 | else: 193 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 194 | 195 | if args.dist_eval: 196 | if len(dataset_test) % num_tasks != 0: 197 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. 
' 198 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 199 | 'equal num of samples per-process.') 200 | sampler_test = torch.utils.data.DistributedSampler( 201 | dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias 202 | else: 203 | sampler_test = torch.utils.data.SequentialSampler(dataset_test) 204 | 205 | 206 | args.log_dir = os.path.join(args.output_dir) 207 | if global_rank == 0 and args.log_dir is not None and not args.eval: 208 | # os.makedirs(args.log_dir, exist_ok=True) 209 | log_writer = SummaryWriter(log_dir=args.log_dir+args.task) 210 | else: 211 | log_writer = None 212 | 213 | data_loader_train = torch.utils.data.DataLoader( 214 | dataset_train, sampler=sampler_train, 215 | batch_size=args.batch_size, 216 | num_workers=args.num_workers, 217 | pin_memory=args.pin_mem, 218 | drop_last=False, 219 | ) 220 | 221 | data_loader_val = torch.utils.data.DataLoader( 222 | dataset_val, sampler=sampler_val, 223 | batch_size=args.batch_size, 224 | num_workers=args.num_workers, 225 | pin_memory=args.pin_mem, 226 | drop_last=False 227 | ) 228 | 229 | data_loader_test = torch.utils.data.DataLoader( 230 | dataset_test, sampler=sampler_test, 231 | batch_size=args.batch_size, 232 | num_workers=args.num_workers, 233 | pin_memory=args.pin_mem, 234 | drop_last=False 235 | ) 236 | 237 | mixup_fn = None 238 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None 239 | if mixup_active: 240 | print("Mixup is activated!") 241 | mixup_fn = Mixup( 242 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 243 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 244 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 245 | 246 | model = models_vit.__dict__[args.model]( 247 | img_size=args.input_size, 248 | num_classes=args.nb_classes, 249 | drop_path_rate=args.drop_path, 250 | global_pool=args.global_pool, 251 | ) 252 | 253 | if args.finetune and not args.eval: 254 | checkpoint = torch.load(args.finetune, map_location='cpu') 255 | 256 | print("Load pre-trained checkpoint from: %s" % args.finetune) 257 | checkpoint_model = checkpoint['model'] 258 | state_dict = model.state_dict() 259 | for k in ['head.weight', 'head.bias']: 260 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 261 | print(f"Removing key {k} from pretrained checkpoint") 262 | del checkpoint_model[k] 263 | 264 | # interpolate position embedding 265 | interpolate_pos_embed(model, checkpoint_model) 266 | 267 | # load pre-trained model 268 | msg = model.load_state_dict(checkpoint_model, strict=False) 269 | print(msg) 270 | 271 | # if args.global_pool: 272 | # assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'} 273 | # else: 274 | # assert set(msg.missing_keys) == {'head.weight', 'head.bias'} 275 | 276 | # manually initialize fc layer 277 | trunc_normal_(model.head.weight, std=2e-5) 278 | 279 | model.to(device) 280 | 281 | model_without_ddp = model 282 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 283 | 284 | print("Model = %s" % str(model_without_ddp)) 285 | print('number of params (M): %.2f' % (n_parameters / 1.e6)) 286 | 287 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 288 | 289 | if args.lr is None: # only base_lr is specified 290 | args.lr = args.blr * eff_batch_size / 256 291 | 292 | print("base lr: %.2e" 
% (args.lr * 256 / eff_batch_size)) 293 | print("actual lr: %.2e" % args.lr) 294 | 295 | print("accumulate grad iterations: %d" % args.accum_iter) 296 | print("effective batch size: %d" % eff_batch_size) 297 | 298 | if args.distributed: 299 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 300 | model_without_ddp = model.module 301 | 302 | # build optimizer with layer-wise lr decay (lrd) 303 | param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay, 304 | no_weight_decay_list=model_without_ddp.no_weight_decay(), 305 | layer_decay=args.layer_decay 306 | ) 307 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr) 308 | loss_scaler = NativeScaler() 309 | 310 | if mixup_fn is not None: 311 | # smoothing is handled with mixup label transform 312 | criterion = SoftTargetCrossEntropy() 313 | elif args.smoothing > 0.: 314 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 315 | else: 316 | criterion = torch.nn.CrossEntropyLoss() 317 | 318 | print("criterion = %s" % str(criterion)) 319 | 320 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 321 | 322 | if args.eval: 323 | test_stats,auc_roc = evaluate(data_loader_test, model, device, args.log_dir+args.task, epoch=0, mode='test',num_class=args.nb_classes) 324 | exit(0) 325 | 326 | print(f"Start training for {args.epochs} epochs") 327 | start_time = time.time() 328 | max_accuracy = 0.0 329 | max_auc = 0.0 330 | for epoch in range(args.start_epoch, args.epochs): 331 | if args.distributed: 332 | data_loader_train.sampler.set_epoch(epoch) 333 | train_stats = train_one_epoch( 334 | model, criterion, data_loader_train, 335 | optimizer, device, epoch, loss_scaler, 336 | args.clip_grad, mixup_fn, 337 | log_writer=log_writer, 338 | args=args 339 | ) 340 | 341 | val_stats,val_auc_roc = evaluate(data_loader_val, model, device, args.log_dir+args.task, epoch, mode='val',num_class=args.nb_classes) 342 | if max_auc