├── figs
│   ├── urfound_p0.jpg
│   ├── urfound_p1.jpg
│   ├── urfound_p2.jpg
│   └── urfound_p3.jpg
├── util
│   ├── lr_sched.py
│   ├── lr_decay.py
│   ├── engine_pretrain.py
│   ├── flair_dataloader
│   │   ├── dataset.py
│   │   ├── dataloader.py
│   │   ├── transforms.py
│   │   └── dictionary.py
│   ├── pos_embed.py
│   ├── dataset.py
│   ├── model_urfound.py
│   └── misc.py
├── LICENSE
├── requirements.txt
├── finetune
│   ├── datasets_finetune.py
│   ├── models_vit.py
│   └── engine_finetune.py
├── .gitignore
├── README.md
├── bert
│   ├── bert_encoder.py
│   └── bert.py
├── main_pretrain_urfound.py
└── main_finetune.py
/figs/urfound_p0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukkai/UrFound/HEAD/figs/urfound_p0.jpg
--------------------------------------------------------------------------------
/figs/urfound_p1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukkai/UrFound/HEAD/figs/urfound_p1.jpg
--------------------------------------------------------------------------------
/figs/urfound_p2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukkai/UrFound/HEAD/figs/urfound_p2.jpg
--------------------------------------------------------------------------------
/figs/urfound_p3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukkai/UrFound/HEAD/figs/urfound_p3.jpg
--------------------------------------------------------------------------------
/util/lr_sched.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 |
9 | def adjust_learning_rate(optimizer, epoch, args):
10 | """Decay the learning rate with half-cycle cosine after warmup"""
11 | if epoch < args.warmup_epochs:
12 | lr = args.lr * epoch / args.warmup_epochs
13 | else:
14 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
15 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
16 | for param_group in optimizer.param_groups:
17 | if "lr_scale" in param_group:
18 | param_group["lr"] = lr * param_group["lr_scale"]
19 | else:
20 | param_group["lr"] = lr
21 | return lr
--------------------------------------------------------------------------------
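
A minimal usage sketch (not part of the repository): the training engines below call `adjust_learning_rate` with a fractional epoch every iteration, but its warmup-then-cosine shape can be probed directly with a dummy optimizer. The hyper-parameter values are illustrative and loosely mirror the README pretraining command.

```python
# Illustrative only: probe the linear-warmup + half-cycle cosine schedule.
import argparse

import torch

from util.lr_sched import adjust_learning_rate

args = argparse.Namespace(lr=1.5e-4, min_lr=0.0, warmup_epochs=40, epochs=200)
optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=args.lr)

for epoch in [0, 20, 40, 120, 200]:
    lr = adjust_learning_rate(optimizer, epoch, args)
    print(f"epoch {epoch:>3}: lr = {lr:.2e}")  # ramps to 1.5e-4 at epoch 40, then decays to min_lr
```
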
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 yukkai
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | art==6.2
3 | cachetools==4.2.4
4 | charset-normalizer==3.3.2
5 | cycler==0.11.0
6 | filelock==3.12.2
7 | fonttools==4.38.0
8 | fsspec==2023.1.0
9 | google-auth==1.35.0
10 | google-auth-oauthlib==0.4.6
11 | grpcio==1.62.2
12 | huggingface-hub==0.16.4
13 | idna==3.7
14 | imageio==2.9.0
15 | importlib-metadata==6.7.0
16 | joblib==1.3.2
17 | kiwisolver==1.4.5
18 | Markdown==3.4.4
19 | MarkupSafe==2.1.5
20 | matplotlib==3.5.3
21 | networkx==2.6.3
22 | numpy==1.21.6
23 | oauthlib==3.2.2
24 | opencv-python==4.5.3.56
25 | packaging==24.0
26 | pandas==0.25.3
27 | parameterized==0.9.0
28 | Pillow==8.3.1
29 | protobuf==3.17.3
30 | pyasn1==0.5.1
31 | pyasn1-modules==0.3.0
32 | pycm==3.2
33 | pydicom==2.3.0
34 | pyparsing==3.1.2
35 | python-dateutil==2.9.0.post0
36 | pytz==2024.1
37 | PyWavelets==1.3.0
38 | PyYAML==6.0.1
39 | regex==2024.4.16
40 | requests==2.31.0
41 | requests-oauthlib==2.0.0
42 | rsa==4.9
43 | safetensors==0.4.3
44 | scikit-image==0.17.2
45 | scikit-learn==0.24.2
46 | scipy==1.5.4
47 | six==1.16.0
48 | tensorboard==2.6.0
49 | tensorboard-data-server==0.6.1
50 | tensorboard-plugin-wit==1.8.0
51 | threadpoolctl==3.1.0
52 | tifffile==2021.11.2
53 | timm==0.3.2
54 | tokenizers==0.13.3
55 | tqdm==4.62.1
56 | transformers==4.30.2
57 | typing_extensions==4.7.1
58 | uncertainty-calibration==0.1.4
59 | urllib3==2.0.7
60 | Werkzeug==2.2.3
61 | zipp==3.15.0
--------------------------------------------------------------------------------
/finetune/datasets_finetune.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | # Partly revised by YZ @UCL&Moorfields
4 | # --------------------------------------------------------
5 |
6 | import os
7 | from torchvision import datasets, transforms
8 | from timm.data import create_transform
9 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
10 | from torchvision.transforms.functional import InterpolationMode
11 |
12 |
13 | def build_dataset(is_train, args):
14 |
15 | transform = build_transform(is_train, args)
16 | root = os.path.join(args.data_path, is_train)
17 | dataset = datasets.ImageFolder(root, transform=transform)
18 |
19 | return dataset
20 |
21 |
22 | def build_transform(is_train, args):
23 | mean = IMAGENET_DEFAULT_MEAN
24 | std = IMAGENET_DEFAULT_STD
25 | # train transform
26 | # if is_train=='train':
27 | if 'train' in is_train:
28 | # this should always dispatch to transforms_imagenet_train
29 | transform = create_transform(
30 | input_size=args.input_size,
31 | is_training=True,
32 | color_jitter=args.color_jitter,
33 | auto_augment=args.aa,
34 | # interpolation='bicubic',
35 | interpolation = InterpolationMode.BICUBIC,
36 | re_prob=args.reprob,
37 | re_mode=args.remode,
38 | re_count=args.recount,
39 | mean=mean,
40 | std=std,
41 | )
42 | return transform
43 |
44 | # eval transform
45 | t = []
46 | if args.input_size <= 224:
47 | crop_pct = 224 / 256
48 | else:
49 | crop_pct = 1.0
50 | size = int(args.input_size / crop_pct)
51 | t.append(
52 | transforms.Resize(size, interpolation=InterpolationMode.BICUBIC),
53 | )
54 | t.append(transforms.CenterCrop(args.input_size))
55 | t.append(transforms.ToTensor())
56 | t.append(transforms.Normalize(mean, std))
57 | return transforms.Compose(t)
58 |
--------------------------------------------------------------------------------
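
A hypothetical call sketch: `build_dataset` expects `args.data_path/<split>` to follow the torchvision `ImageFolder` layout. The path and augmentation values below are placeholders, chosen only to match the fields that `build_transform` reads.

```python
import argparse

from finetune.datasets_finetune import build_dataset

args = argparse.Namespace(
    data_path="./IDRiD/",            # placeholder ImageFolder root containing train/ and val/
    input_size=224,
    color_jitter=None,
    aa="rand-m9-mstd0.5-inc1",       # illustrative AutoAugment policy
    reprob=0.25, remode="pixel", recount=1,
)
dataset_train = build_dataset(is_train="train", args=args)
dataset_val = build_dataset(is_train="val", args=args)
```
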
/finetune/models_vit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | # Partly revised by YZ @UCL&Moorfields
4 | # --------------------------------------------------------
5 |
6 | from functools import partial
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 | import timm.models.vision_transformer
12 |
13 |
14 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
15 | """ Vision Transformer with support for global average pooling
16 | """
17 | def __init__(self, global_pool=False, **kwargs):
18 | super(VisionTransformer, self).__init__(**kwargs)
19 |
20 | self.global_pool = global_pool
21 | if self.global_pool:
22 | norm_layer = kwargs['norm_layer']
23 | embed_dim = kwargs['embed_dim']
24 | self.fc_norm = norm_layer(embed_dim)
25 |
26 | del self.norm # remove the original norm
27 |
28 | def forward_features(self, x):
29 | B = x.shape[0]
30 | x = self.patch_embed(x)
31 |
32 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
33 | x = torch.cat((cls_tokens, x), dim=1)
34 | x = x + self.pos_embed
35 | x = self.pos_drop(x)
36 |
37 | for blk in self.blocks:
38 | x = blk(x)
39 |
40 | if self.global_pool:
41 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token
42 | outcome = self.fc_norm(x)
43 | else:
44 | x = self.norm(x)
45 | outcome = x[:, 0]
46 |
47 | return outcome
48 |
49 |
50 | def vit_large_patch16(**kwargs):
51 | model = VisionTransformer(
52 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
53 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
54 | return model
55 |
56 |
57 | def vit_base_patch16(**kwargs):
58 | model = VisionTransformer(
59 | patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
60 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
61 | return model
--------------------------------------------------------------------------------
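
A minimal sketch, assuming the pinned `timm==0.3.2`: build the ViT-B/16 backbone with global average pooling (as used for finetuning) and run a dummy forward pass.

```python
import torch

from finetune import models_vit

model = models_vit.vit_base_patch16(num_classes=5, drop_path_rate=0.2, global_pool=True)
with torch.no_grad():
    logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)  # torch.Size([2, 5])
```
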
/util/lr_decay.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # ELECTRA https://github.com/google-research/electra
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 |
12 | import json
13 |
14 |
15 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
16 | """
17 | Parameter groups for layer-wise lr decay
18 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
19 | """
20 | param_group_names = {}
21 | param_groups = {}
22 |
23 | num_layers = len(model.blocks) + 1
24 |
25 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
26 |
27 | for n, p in model.named_parameters():
28 | if not p.requires_grad:
29 | continue
30 |
31 | # no decay: all 1D parameters and model specific ones
32 | if p.ndim == 1 or n in no_weight_decay_list:
33 | g_decay = "no_decay"
34 | this_decay = 0.
35 | else:
36 | g_decay = "decay"
37 | this_decay = weight_decay
38 |
39 | layer_id = get_layer_id_for_vit(n, num_layers)
40 | group_name = "layer_%d_%s" % (layer_id, g_decay)
41 |
42 | if group_name not in param_group_names:
43 | this_scale = layer_scales[layer_id]
44 |
45 | param_group_names[group_name] = {
46 | "lr_scale": this_scale,
47 | "weight_decay": this_decay,
48 | "params": [],
49 | }
50 | param_groups[group_name] = {
51 | "lr_scale": this_scale,
52 | "weight_decay": this_decay,
53 | "params": [],
54 | }
55 |
56 | param_group_names[group_name]["params"].append(n)
57 | param_groups[group_name]["params"].append(p)
58 |
59 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
60 |
61 | return list(param_groups.values())
62 |
63 |
64 | def get_layer_id_for_vit(name, num_layers):
65 | """
66 | Assign a parameter with its layer id
67 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
68 | """
69 | if name in ['cls_token', 'pos_embed']:
70 | return 0
71 | elif name.startswith('patch_embed'):
72 | return 0
73 | elif name.startswith('blocks'):
74 | return int(name.split('.')[1]) + 1
75 | else:
76 | return num_layers
--------------------------------------------------------------------------------
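
A sketch of how these parameter groups are typically consumed (values illustrative): each group carries an `lr_scale` that `adjust_learning_rate` in `util/lr_sched.py` multiplies into the base learning rate, and the group list can be handed straight to AdamW, which keeps the extra keys attached.

```python
import torch

from finetune import models_vit
from util import lr_decay

model = models_vit.vit_base_patch16(num_classes=5, global_pool=True)
param_groups = lr_decay.param_groups_lrd(
    model,
    weight_decay=0.05,
    no_weight_decay_list=model.no_weight_decay(),  # {'pos_embed', 'cls_token'} for a timm ViT
    layer_decay=0.65,
)
optimizer = torch.optim.AdamW(param_groups, lr=5e-3)  # lr_scale stays attached to each group
```
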
/util/engine_pretrain.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # DeiT: https://github.com/facebookresearch/deit
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 |
12 | from typing import Iterable
13 |
14 | import torch
15 |
16 | import util.misc as misc
17 | import util.lr_sched as lr_sched
18 |
19 |
20 | def train_one_epoch(model: torch.nn.Module,
21 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
22 | device: torch.device, epoch: int, loss_scaler,
23 | log_writer=None,
24 | args=None):
25 | model.train(True)
26 | metric_logger = misc.MetricLogger(delimiter=" ")
27 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
28 | header = 'Epoch: [{}]'.format(epoch)
29 | print_freq = 20
30 |
31 | accum_iter = args.accum_iter
32 |
33 | optimizer.zero_grad()
34 |
35 | if log_writer is not None:
36 | print('log_dir: {}'.format(log_writer.log_dir))
37 |
38 | mask_ratio = args.mask_ratio
39 | for data_iter_step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
40 | # we use a per iteration (instead of per epoch) lr scheduler
41 | if data_iter_step % accum_iter == 0:
42 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
43 | with torch.cuda.amp.autocast():
44 | loss, _, _ = model(batch, mask_ratio=mask_ratio)
45 |
46 | loss_value1 = loss[0].item()
47 | loss_value2 = loss[1].item()
48 | loss = loss[0] + loss[1]
49 | loss = loss / accum_iter
50 | loss_scaler(loss, optimizer, parameters=model.parameters(),
51 | update_grad=(data_iter_step + 1) % accum_iter == 0)
52 |
53 | if (data_iter_step + 1) % accum_iter == 0:
54 | optimizer.zero_grad()
55 |
56 | torch.cuda.synchronize()
57 |
58 | metric_logger.update(loss1=loss_value1)
59 | metric_logger.update(loss2=loss_value2)
60 |
61 | lr = optimizer.param_groups[0]["lr"]
62 | metric_logger.update(lr=lr)
63 |
64 | loss_value_reduce1 = misc.all_reduce_mean(loss_value1)
65 | loss_value_reduce2 = misc.all_reduce_mean(loss_value2)
66 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
67 | """ We use epoch_1000x as the x-axis in tensorboard.
68 | This calibrates different curves when batch size changes.
69 | """
70 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
71 | log_writer.add_scalar('train_loss1', loss_value_reduce1, epoch_1000x)
72 | log_writer.add_scalar('train_loss2', loss_value_reduce2, epoch_1000x)
73 | log_writer.add_scalar('lr', lr, epoch_1000x)
74 |
75 | # gather the stats from all processes
76 | metric_logger.synchronize_between_processes()
77 | print("Averaged stats:", metric_logger)
78 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
--------------------------------------------------------------------------------
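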
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
164 | .DS_Store
--------------------------------------------------------------------------------
/util/flair_dataloader/dataset.py:
--------------------------------------------------------------------------------
1 |
2 | """
3 | Specialized PyTorch Dataset implementations to create balanced datasets
4 | with regard to the datasets used for pretraining.
5 | """
6 |
7 | import collections.abc
8 | import numpy as np
9 |
10 | from torch.utils.data import Dataset as _TorchDataset
11 | from typing import Any, Callable, Optional, Sequence, Union
12 | from torch.utils.data import Subset
13 |
14 |
15 | class Dataset(_TorchDataset):
16 | """
17 |     A generic dataset with a length property and an optional callable data transform
18 | when fetching a data sample.
19 | If passing slicing indices, will return a PyTorch Subset, for example: `data: Subset = data[1:4]`,
20 | for more details, please check: https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset
21 |
22 | For example, typical input data can be a list of dictionaries::
23 |
24 | [{ { {
25 | 'img': 'image1.nii.gz', 'img': 'image2.nii.gz', 'img': 'image3.nii.gz',
26 | 'seg': 'label1.nii.gz', 'seg': 'label2.nii.gz', 'seg': 'label3.nii.gz',
27 | 'extra': 123 'extra': 456 'extra': 789
28 | }, }, }]
29 | """
30 |
31 | def __init__(self, data: Sequence, transform: Optional[Callable] = None) -> None:
32 | """
33 | Args:
34 | data: input data to load and transform to generate data for model.
35 | transform: a callable data transform on input data.
36 |
37 | """
38 | self.data = data
39 | self.transform: Any = transform
40 |
41 | def __len__(self) -> int:
42 | return len(self.data)
43 |
44 | def _transform(self, index: int):
45 | """
46 | Fetch single data item from `self.data`.
47 | """
48 | data_i = self.data[index]
49 | return self.transform(data_i) if self.transform is not None else data_i
50 |
51 | def __getitem__(self, index: Union[int, slice, Sequence[int]]):
52 | """
53 | Returns a `Subset` if `index` is a slice or Sequence, a data item otherwise.
54 | """
55 | if isinstance(index, slice):
56 | # data[:42]
57 | start, stop, step = index.indices(len(self))
58 | indices = range(start, stop, step)
59 | return Subset(dataset=self, indices=indices)
60 | if isinstance(index, collections.abc.Sequence):
61 | # data[[1, 3, 4]]
62 | return Subset(dataset=self, indices=index)
63 | return self._transform(index)
64 |
65 |
66 | class UniformDataset(Dataset):
67 | def __init__(self, data, transform):
68 | super().__init__(data=data, transform=transform)
69 | self.datasetkey = []
70 | self.data_dic = []
71 | self.datasetnum = []
72 | self.datasetlen = 0
73 | self.dataset_split(data)
74 |
75 | def dataset_split(self, data):
76 | keys = []
77 | for img in data:
78 | keys.append(img["image_name"].split("/")[0])
79 |
80 | self.datasetkey = list(np.unique(keys))
81 |
82 | data_dic = {}
83 | for iKey in self.datasetkey:
84 | data_dic[iKey] = [data[iSample] for iSample in range(len(keys)) if keys[iSample]==iKey]
85 | self.data_dic = data_dic
86 |
87 | self.datasetnum = []
88 | for key, item in self.data_dic.items():
89 |             assert len(item) != 0, f'the dataset {key} has no data'
90 | self.datasetnum.append(len(item))
91 | self.datasetlen = len(self.datasetkey)
92 |
93 | def _transform(self, set_key, data_index):
94 | data_i = self.data_dic[set_key][data_index]
95 | return self.transform(data_i) if self.transform is not None else data_i
96 |
97 | def __getitem__(self, index):
98 |         ## the index generated outside is only used to select the sub-dataset
99 |         ## the sample within each sub-dataset is selected by np.random.randint
100 | set_index = index % self.datasetlen
101 | set_key = self.datasetkey[set_index]
102 |
103 | data_index = np.random.randint(self.datasetnum[set_index], size=1)[0]
104 | return self._transform(set_key, data_index)
--------------------------------------------------------------------------------
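
A toy illustration of the balanced sampling in `UniformDataset`: every `__getitem__` first picks a sub-dataset from the prefix of `image_name`, then draws a random sample within it. The records below are fabricated.

```python
from util.flair_dataloader.dataset import UniformDataset

data = [
    {"image_name": "00_OCTCELL/a.png"},
    {"image_name": "00_OCTCELL/b.png"},
    {"image_name": "04_RFMid/c.png"},
]
ds = UniformDataset(data=data, transform=None)
print(ds.datasetkey)        # ['00_OCTCELL', '04_RFMid']
print(ds[0]["image_name"])  # a random sample from 00_OCTCELL
print(ds[1]["image_name"])  # a random sample from 04_RFMid
```
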
/util/pos_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # Position embedding utils
8 | # --------------------------------------------------------
9 |
10 | import numpy as np
11 |
12 | import torch
13 |
14 | # --------------------------------------------------------
15 | # 2D sine-cosine position embedding
16 | # References:
17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
18 | # MoCo v3: https://github.com/facebookresearch/moco-v3
19 | # --------------------------------------------------------
20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
21 | """
22 | grid_size: int of the grid height and width
23 | return:
24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
25 | """
26 | grid_h = np.arange(grid_size, dtype=np.float32)
27 | grid_w = np.arange(grid_size, dtype=np.float32)
28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
29 | grid = np.stack(grid, axis=0)
30 |
31 | grid = grid.reshape([2, 1, grid_size, grid_size])
32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
33 | if cls_token:
34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
35 | return pos_embed
36 |
37 |
38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
39 | assert embed_dim % 2 == 0
40 |
41 | # use half of dimensions to encode grid_h
42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
44 |
45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
46 | return emb
47 |
48 |
49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
50 | """
51 | embed_dim: output dimension for each position
52 | pos: a list of positions to be encoded: size (M,)
53 | out: (M, D)
54 | """
55 | assert embed_dim % 2 == 0
56 |     omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float is deprecated in NumPy
57 |     omega = omega / (embed_dim / 2.)
58 | omega = 1. / 10000**omega # (D/2,)
59 |
60 | pos = pos.reshape(-1) # (M,)
61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
62 |
63 | emb_sin = np.sin(out) # (M, D/2)
64 | emb_cos = np.cos(out) # (M, D/2)
65 |
66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
67 | return emb
68 |
69 |
70 | # --------------------------------------------------------
71 | # Interpolate position embeddings for high-resolution
72 | # References:
73 | # DeiT: https://github.com/facebookresearch/deit
74 | # --------------------------------------------------------
75 | def interpolate_pos_embed(model, checkpoint_model):
76 | if 'pos_embed' in checkpoint_model:
77 | pos_embed_checkpoint = checkpoint_model['pos_embed']
78 | embedding_size = pos_embed_checkpoint.shape[-1]
79 | num_patches = model.patch_embed.num_patches
80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
81 | # height (== width) for the checkpoint position embedding
82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
83 | # height (== width) for the new position embedding
84 | new_size = int(num_patches ** 0.5)
85 | # class_token and dist_token are kept unchanged
86 | if orig_size != new_size:
87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
89 | # only the position tokens are interpolated
90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
92 | pos_tokens = torch.nn.functional.interpolate(
93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
96 | checkpoint_model['pos_embed'] = new_pos_embed
97 |
--------------------------------------------------------------------------------
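
A quick shape check of the fixed 2D sin-cos embedding (for example, ViT-B/16 at a 224x224 input gives a 14x14 patch grid):

```python
from util.pos_embed import get_2d_sincos_pos_embed

pos_embed = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
print(pos_embed.shape)  # (197, 768): 1 cls token + 14*14 patch positions
```
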
/util/flair_dataloader/dataloader.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataset and Dataloader preparation for vision-language pre-training
3 | """
4 |
5 | import pandas as pd
6 | import ast
7 | from torchvision.transforms import Compose
8 | from torch.utils.data import DataLoader
9 |
10 | from util.flair_dataloader.dataset import Dataset, UniformDataset
11 | from util.flair_dataloader.transforms import LoadImage, ImageScaling, SelectRelevantKeys, CopyDict,\
12 | ProduceDescription, AugmentDescription
13 |
14 |
15 | def get_loader(dataframes_path, data_root_path, datasets, balance=False, batch_size=8, num_workers=0,
16 | banned_categories=None, caption="A fundus photograph of [CLS]", augment_description=True):
17 |
18 | """
19 |     Dataloader generation for vision-language pretraining. Reads all dataframes from the assembly, combines
20 |     them into a unified data structure, and prepares a dataloader for training.
21 | """
22 |
23 | # Prepare data sample pre-processing transforms
24 | transforms = Compose([
25 | CopyDict(),
26 | LoadImage(),
27 | ImageScaling(),
28 | ProduceDescription(caption=caption),
29 | AugmentDescription(augment=augment_description),
30 | SelectRelevantKeys()
31 | ])
32 |
33 | # Assembly dataframes into a combined data structure
34 |     print("Setting assembly data...")
35 | data = []
36 | for iDataset in datasets:
37 | print("Processing data: " + iDataset)
38 |
39 | dataframe = pd.read_csv(dataframes_path + iDataset + ".csv")
40 |
41 | for i in range(len(dataframe)):
42 | data_i = dataframe.loc[i, :].to_dict()
43 |             data_i["categories"] = ast.literal_eval(data_i["categories"])
44 |             data_i["atributes"] = ast.literal_eval(data_i["atributes"])
45 |
46 | # Remove banned words - for evaluating on incremental categories
47 | banned = False
48 | if banned_categories is not None:
49 | for iCat in data_i["categories"]:
50 | for iiCat in banned_categories:
51 | if iiCat in iCat:
52 | banned = True
53 | if banned:
54 | continue
55 |
56 | # Add sample to general data
57 | data_i["image_name"] = data_i["image"]
58 | data_i["image_path"] = data_root_path + data_i["image"]
59 | data.append(data_i)
60 |
61 | print('Total assembly data samples: {}'.format(len(data)))
62 |
63 | # Set data
64 | if balance:
65 | train_dataset = UniformDataset(data=data, transform=transforms)
66 | else:
67 | train_dataset = Dataset(data=data, transform=transforms)
68 |
69 | # Set dataloader
70 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
71 |
72 | # Set dataloaders in dict
73 |     dataloaders = {"train": train_loader}
74 | 
75 |     return dataloaders
76 |
77 |
78 |
79 |
80 | def get_data_list(dataframes_path, data_root_path, datasets, balance=False, batch_size=8, num_workers=0,
81 | banned_categories=None, caption="A fundus photograph of [CLS]", augment_description=True):
82 |
83 | """
84 |     Data list generation for vision-language pretraining. Reads all dataframes from the assembly and combines
85 |     them into a unified list of samples.
86 | """
87 |
88 | # Assembly dataframes into a combined data structure
89 |     print("Setting assembly data...")
90 | data = []
91 | for iDataset in datasets:
92 |
93 | dataframe = pd.read_csv(dataframes_path + iDataset + ".csv")
94 | print("Processing data: " + iDataset, len(dataframe))
95 |
96 | for i in range(len(dataframe)):
97 | data_i = dataframe.loc[i, :].to_dict()
98 | data_i["categories"] = ast.literal_eval(data_i["categories"])
99 | data_i["atributes"] = ast.literal_eval(data_i["atributes"])
100 |
101 | # Remove banned words - for evaluating on incremental categories
102 | banned = False
103 | if banned_categories is not None:
104 | for iCat in data_i["categories"]:
105 | for iiCat in banned_categories:
106 | if iiCat in iCat:
107 | banned = True
108 | if banned:
109 | continue
110 |
111 | # Add sample to general data
112 | data_i["image_name"] = data_i["image"]
113 | data_i["image_path"] = data_root_path + data_i["image"]
114 | data.append(data_i)
115 |
116 | print('Total assembly data samples: {}'.format(len(data)))
117 |
118 | return data
119 |
--------------------------------------------------------------------------------
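
A hypothetical call sketch (the paths are placeholders): `dataframes_path` must contain one CSV per dataset name (e.g. `04_RFMid.csv`) with `image`, `categories` and `atributes` columns, and `data_root_path` must be the matching image root.

```python
from util.flair_dataloader.dataloader import get_loader

loaders = get_loader(
    dataframes_path="./pretrain_data/fundus_labels/",      # placeholder
    data_root_path="./pretrain_data/fundus_oct_images/",   # placeholder
    datasets=["00_OCTCELL", "04_RFMid"],
    balance=True,
    batch_size=8,
    num_workers=4,
    caption="A [ATR] fundus photograph of [CLS]",
)
train_loader = loaders["train"]
```
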
/util/flair_dataloader/transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | Methods for image and text loading, pre-processing and generation
3 | for vision-language pretraining. Also, it includes data augmentation
4 | utilities.
5 | """
6 |
7 | import numpy as np
8 | import random
9 | import torch
10 | import copy
11 |
12 | from PIL import Image
13 | from torchvision.transforms import Resize
14 | from util.flair_dataloader.dictionary import definitions
15 | import torchvision.transforms as transforms
16 | from torchvision.transforms.functional import InterpolationMode
17 | from PIL import Image
18 |
19 |
20 | class LoadImage():
21 |     """
22 |     Load, organize channels, and standardize intensity of images.
23 |     """
24 |     def __init__(self, target="image_path"):
25 |         self.target = target
26 |
27 | def __call__(self, data):
28 | img = Image.open(data[self.target]).convert('RGB')
29 | data[self.target.replace("_path", "")] = img
30 | return data
31 |
32 |
33 | class ImageScaling():
34 |
35 | """
36 | Method for image scaling. It includes two options: scaling from canvas, to avoid image distortions,
37 |     and regular scaling through resizing.
38 | """
39 |
40 | def __init__(self, size=(512, 512), canvas=True, target="image"):
41 | self.size = size
42 | self.canvas = canvas
43 | self.target = target
44 |
45 | # self.transforms = torch.nn.Sequential(
46 | # Resize(self.size),
47 | # )
48 | self.transforms = transforms.Compose([
49 | transforms.RandomResizedCrop(448, scale=(0.2, 1.0), interpolation=InterpolationMode.BICUBIC), # 3 is bicubic
50 | transforms.RandomHorizontalFlip(),
51 | # transforms.Grayscale(num_output_channels=3),
52 | transforms.ToTensor(),
53 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
54 | # transforms.Normalize(mean=[0.4978], std=[0.2449])])
55 |
56 | def __call__(self, data):
57 | img = data[self.target]
58 | img = self.transforms(img)
59 | # if not self.canvas or (img.shape[-1] == img.shape[-2]):
60 | # img = self.transforms(img)
61 | # else:
62 | # sizes = img.shape[-2:]
63 | # max_size = max(sizes)
64 | # scale = max_size/self.size[0]
65 | # img = Resize((int(img.shape[-2]/scale), int((img.shape[-1]/scale))))(img)
66 | # img = torch.nn.functional.pad(img, (0, self.size[0] - img.shape[-1], 0, self.size[1] - img.shape[-2], 0, 0))
67 |
68 | data[self.target] = img
69 | return data
70 |
71 |
72 | class ProduceDescription():
73 |
74 | """
75 |     Method that creates naive text prompts combining a prompt template, attributes (e.g. noisy), and categories
76 | (e.g. cataract). Also, this method is used to integrate text data with the modality prompt template.
77 | """
78 |
79 | def __init__(self, caption):
80 | self.caption = caption
81 |
82 | def __call__(self, data):
83 |
84 | # Create text
85 | atr_sample = random.sample(data['atributes'], 1)[0] if len(data['atributes']) > 0 else ""
86 | cat_sample = random.sample(data['categories'], 1)[0] if len(data['categories']) > 0 else ""
87 |
88 | data["sel_category"] = cat_sample
89 | if 'OCT' in atr_sample:
90 | data["report"] = ['An Optical Coherence Tomography Image shows '+cat_sample.lower()]
91 | else:
92 | data["report"] = [self.caption.replace("[ATR]", atr_sample).replace("[CLS]", cat_sample).replace(" ", " ")]
93 |
94 | return data
95 |
96 |
97 | class AugmentDescription():
98 |
99 | """
100 |     Method that augments naive text prompts into expert-knowledge prompts by replacing the category name
101 |     with an expert description of the target category.
102 | """
103 |
104 | def __init__(self, augment=False):
105 | self.augment = augment
106 |
107 | def __call__(self, data):
108 |
109 | if self.augment:
110 | if data["image_name"].split("/")[0] not in ["00_OCTCELL", "06_EYENET", "11_STARE", "08_ODIR-5K", "31_JICHI"]:
111 | if data["sel_category"] in list(definitions.keys()):
112 | prompts = [data["sel_category"]] + definitions[data["sel_category"]]
113 | new_cat = random.sample(prompts, 1)[0]
114 | data["report"][0] = data["report"][0].replace(data["sel_category"], new_cat)
115 | data["augmented_category"] = new_cat
116 |
117 | return data
118 |
119 |
120 | class CopyDict():
121 | def __call__(self, data):
122 | d = copy.deepcopy(data)
123 | return d
124 |
125 |
126 | class SelectRelevantKeys():
127 |
128 | def __init__(self, target_keys=None):
129 | if target_keys is None:
130 | target_keys = ['image', 'report', 'sel_category', 'atributes']
131 | self.target_keys = target_keys
132 |
133 | def __call__(self, data):
134 | d = {key: data[key] for key in self.target_keys}
135 | return d
--------------------------------------------------------------------------------
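
A self-contained walk-through of the pre-processing pipeline on a synthetic sample; the image file and labels below are fabricated for illustration.

```python
from PIL import Image
from torchvision.transforms import Compose

from util.flair_dataloader.transforms import (
    AugmentDescription, CopyDict, ImageScaling, LoadImage, ProduceDescription, SelectRelevantKeys)

Image.new("RGB", (512, 512)).save("toy_fundus.png")  # fabricated image
sample = {"image_path": "toy_fundus.png", "image_name": "04_RFMid/toy_fundus.png",
          "categories": ["diabetic retinopathy"], "atributes": []}

pipeline = Compose([CopyDict(), LoadImage(), ImageScaling(),
                    ProduceDescription(caption="A [ATR] fundus photograph of [CLS]"),
                    AugmentDescription(augment=True), SelectRelevantKeys()])
out = pipeline(sample)
print(out["image"].shape)  # torch.Size([3, 448, 448])
print(out["report"])       # naive caption, or an expert-knowledge description if one is defined
```
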
/README.md:
--------------------------------------------------------------------------------
1 | # 【MICCAI 2024】 UrFound: Towards Universal Retinal Foundation Models via Knowledge-Guided Masked Modeling
2 |
3 | This repo is the official implementation of [UrFound](https://arxiv.org/pdf/2408.05618).
4 |
5 |
6 |
7 |
8 |
9 | ## Abstract
10 |
11 | Retinal foundation models aim to learn generalizable representations from diverse retinal images, facilitating label-efficient model adaptation across various ophthalmic tasks. Despite their success, current retinal foundation models are generally restricted to a single imaging modality, such as Color Fundus Photography (CFP) or Optical Coherence Tomography (OCT), limiting their versatility. Moreover, these models may struggle to fully leverage expert annotations and overlook the valuable domain knowledge essential for domain-specific representation learning. To overcome these limitations, we introduce UrFound, a retinal foundation model designed to learn universal representations from both multimodal retinal images and domain knowledge. UrFound is equipped with a modality-agnostic image encoder and accepts either CFP or OCT images as inputs. To integrate domain knowledge into representation learning, we encode expert annotation in text supervision and propose a knowledge-guided masked modeling strategy for model pre-training. It involves reconstructing randomly masked patches of retinal images while predicting masked text tokens conditioned on the corresponding retinal image. This approach aligns multimodal images and textual expert annotations within a unified latent space, facilitating generalizable and domain-specific representation learning. Experimental results demonstrate that UrFound exhibits strong generalization ability and data efficiency when adapting to various tasks in retinal image analysis. By training on ~180k retinal images, UrFound significantly outperforms the state-of-the-art retinal foundation model trained on up to 1.6 million unlabelled images across 8 public retinal datasets.
12 |
13 | ## Framework
14 |
15 |
16 |
17 |
18 |
19 | ## Get started
20 |
21 | ### Installation
22 |
23 | ```bash
24 | # Clone this repo
25 | git clone https://github.com/yukkai/UrFound.git
26 | cd UrFound
27 |
28 | # Create a conda environment
29 | conda create -n urfound python=3.7.5
30 |
31 | # Activate the environment
32 | conda activate urfound
33 |
34 | # Install dependencies
35 | pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
36 | pip install -r requirements.txt
37 | ```
38 |
39 | ### Datasets
40 |
41 | * Pretrain dataset (see [FLAIR](https://github.com/jusiro/FLAIR) for more details)
42 |
43 |
44 |
45 |
46 |
47 | * Finetune dataset (see [RETFound](https://github.com/rmaphoh/RETFound_MAE) for more details)
48 |
49 |
50 |
51 |
52 |
53 | ### How to Run
54 |
55 | * Pretrain
56 | ```bash
57 | pretrained_model='model pre-trained on ImageNet'
58 | pretrain_data='downloaded pretrain dataset'
59 | output_path='output path'
60 |
61 | CUDA_VISIBLE_DEVICES=0 python ./main_pretrain_urfound.py \
62 | --num_workers 32 \
63 | --accum_iter 2 \
64 | --batch_size 128 \
65 | --model urmodel \
66 | --norm_pix_loss \
67 | --mask_ratio 0.75 \
68 | --epochs 200 \
69 | --warmup_epochs 40 \
70 | --blr 1.5e-4 --weight_decay 0.05 \
71 | --resume ${pretrained_model} \
72 | --data_path ${pretrain_data} \
73 | --output_dir ${output_path} \
74 |     --data_mode fundus_oct
75 | ```
76 |
77 | * Finetune
78 | ```bash
79 | data_i='downstream task dataset'
80 | nb_classes='class num'
81 | Pretraining_model='pretrained model'
82 | Out_folder='output path'
83 |
84 | CUDA_VISIBLE_DEVICES=0 python ./main_finetune.py \
85 | --batch_size 16 \
86 | --world_size 1 \
87 | --model vit_base_patch16 \
88 | --epochs 50 \
89 | --blr 5e-3 --layer_decay 0.65 \
90 | --weight_decay 0.05 --drop_path 0.2 \
91 | --nb_classes ${nb_classes} \
92 | --data_path ./${data_i}/ \
93 | --task ${data_i}/ \
94 | --finetune ${Pretraining_model} \
95 | --input_size 224 \
96 | --log_dir ${Out_folder}/ \
97 | --output_dir ${Out_folder}/
98 | ```
99 | ## Release
100 | * Pretrained model [[Checkpoints](https://huggingface.co/yyyyk/UrFound)]
101 |
102 | ## Citation
103 |
104 | ```
105 | @article{yu2024urfound,
106 | title={UrFound: Towards Universal Retinal Foundation Models via Knowledge-Guided Masked Modeling},
107 | author={Yu, Kai and Zhou, Yang and Bai, Yang and Da Soh, Zhi and Xu, Xinxing and Goh, Rick Siow Mong and Cheng, Ching-Yu and Liu, Yong},
108 | journal={arXiv preprint arXiv:2408.05618},
109 | year={2024}
110 | }
111 | ```
112 |
113 | ## Acknowledgements
114 |
115 | We extend our appreciation to the developers of the [RETFound](https://github.com/rmaphoh/RETFound_MAE), [FLAIR](https://github.com/jusiro/FLAIR) and [MRM](https://github.com/RL4M/MRM-pytorch) projects for sharing their open-source implementations and providing guidance on preparing the data.
--------------------------------------------------------------------------------
/bert/bert_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .bert import MyBertMaskedLM
4 | from transformers.configuration_utils import PretrainedConfig
5 |
6 | class BertConfig(PretrainedConfig):
7 | r"""
8 | This is the configuration class to store the configuration of a [`BertModel`] or a [`TFBertModel`]. It is used to
9 | instantiate a BERT model according to the specified arguments, defining the model architecture. Instantiating a
10 | configuration with the defaults will yield a similar configuration to that of the BERT
11 | [bert-base-uncased](https://huggingface.co/bert-base-uncased) architecture.
12 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
13 | documentation from [`PretrainedConfig`] for more information.
14 | Args:
15 | vocab_size (`int`, *optional*, defaults to 30522):
16 | Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
17 | `inputs_ids` passed when calling [`BertModel`] or [`TFBertModel`].
18 | hidden_size (`int`, *optional*, defaults to 768):
19 | Dimensionality of the encoder layers and the pooler layer.
20 | num_hidden_layers (`int`, *optional*, defaults to 12):
21 | Number of hidden layers in the Transformer encoder.
22 | num_attention_heads (`int`, *optional*, defaults to 12):
23 | Number of attention heads for each attention layer in the Transformer encoder.
24 | intermediate_size (`int`, *optional*, defaults to 3072):
25 | Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
26 | hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
27 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
28 | `"relu"`, `"silu"` and `"gelu_new"` are supported.
29 | hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
30 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
31 | attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
32 | The dropout ratio for the attention probabilities.
33 | max_position_embeddings (`int`, *optional*, defaults to 512):
34 | The maximum sequence length that this model might ever be used with. Typically set this to something large
35 | just in case (e.g., 512 or 1024 or 2048).
36 | type_vocab_size (`int`, *optional*, defaults to 2):
37 | The vocabulary size of the `token_type_ids` passed when calling [`BertModel`] or [`TFBertModel`].
38 | initializer_range (`float`, *optional*, defaults to 0.02):
39 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
40 | layer_norm_eps (`float`, *optional*, defaults to 1e-12):
41 | The epsilon used by the layer normalization layers.
42 | position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
43 | Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
44 | positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
45 | [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
46 | For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
47 | with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
48 | use_cache (`bool`, *optional*, defaults to `True`):
49 | Whether or not the model should return the last key/values attentions (not used by all models). Only
50 | relevant if `config.is_decoder=True`.
51 | classifier_dropout (`float`, *optional*):
52 | The dropout ratio for the classification head.
53 | Examples:
54 | ```python
55 | >>> from transformers import BertModel, BertConfig
56 | >>> # Initializing a BERT bert-base-uncased style configuration
57 | >>> configuration = BertConfig()
58 | >>> # Initializing a model from the bert-base-uncased style configuration
59 | >>> model = BertModel(configuration)
60 | >>> # Accessing the model configuration
61 | >>> configuration = model.config
62 | ```"""
63 | model_type = "bert"
64 |
65 | def __init__(
66 | self,
67 | vocab_size=30000,
68 | hidden_size=384,
69 | num_hidden_layers=6,
70 | num_attention_heads=6,
71 | intermediate_size=1536,
72 | hidden_act="gelu",
73 | hidden_dropout_prob=0.1,
74 | attention_probs_dropout_prob=0.1,
75 | max_position_embeddings=100,
76 | type_vocab_size=2,
77 | initializer_range=0.02,
78 | layer_norm_eps=1e-12,
79 | pad_token_id=0,
80 | position_embedding_type="absolute",
81 | use_cache=True,
82 | classifier_dropout=None,
83 | **kwargs
84 | ):
85 | super().__init__(pad_token_id=pad_token_id, **kwargs)
86 |
87 | self.vocab_size = vocab_size
88 | self.hidden_size = hidden_size
89 | self.num_hidden_layers = num_hidden_layers
90 | self.num_attention_heads = num_attention_heads
91 | self.hidden_act = hidden_act
92 | self.intermediate_size = intermediate_size
93 | self.hidden_dropout_prob = hidden_dropout_prob
94 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
95 | self.max_position_embeddings = max_position_embeddings
96 | self.type_vocab_size = type_vocab_size
97 | self.initializer_range = initializer_range
98 | self.layer_norm_eps = layer_norm_eps
99 | self.position_embedding_type = position_embedding_type
100 | self.use_cache = use_cache
101 | self.classifier_dropout = classifier_dropout
102 |
103 | class BertEncoder(nn.Module):
104 | def __init__(self):
105 | super(BertEncoder, self).__init__()
106 |
107 | self.model = MyBertMaskedLM(BertConfig())
108 |
109 | def forward(self, latent, ids, labels, attn_mask, token_type):
110 |
111 | outputs = self.model(latent, ids, attn_mask, token_type, labels = labels)
112 |
113 | return outputs
--------------------------------------------------------------------------------
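
A minimal sketch of the compact text-decoder configuration; the encoder itself wraps `MyBertMaskedLM` from `bert/bert.py` (not shown in this excerpt), so the bert-base defaults quoted in the copied docstring do not apply here.

```python
from bert.bert_encoder import BertConfig, BertEncoder

cfg = BertConfig()
print(cfg.vocab_size, cfg.hidden_size, cfg.num_hidden_layers, cfg.max_position_embeddings)
# 30000 384 6 100 -- a much smaller decoder than bert-base, sized for 100-token captions

text_decoder = BertEncoder()  # builds MyBertMaskedLM(BertConfig()) internally
```
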
/util/dataset.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | import os
3 | from typing import List, Tuple
4 | from PIL import Image
5 | import pandas as pd
6 | import torch
7 | from torch.utils.data import Dataset
8 | import tokenizers
9 | import random
10 |
11 | from transformers import AutoModel, AutoTokenizer
12 |
13 | from util.flair_dataloader import dataloader
14 |
15 |
16 | # ===========
17 | dataframes_path = 'TODO CHANGE TO YOUR DATAPATH./pretrain_data/fundus_labels/'
18 | data_root_path = 'TODO CHANGE TO YOUR DATAPATH./pretrain_data/fundus_oct_images/'
19 | datasets_debug = ["04_RFMid", "00_OCTCELL"]
20 |
21 | datasets_oct = ["00_OCTCELL"]
22 | datasets_fundus = ["01_EYEPACS", "04_RFMid",
23 | "06_DEN", "07_LAG", "08_ODIR", "10_PARAGUAY",
24 | "11_STARE", "12_ARIA", "14_AGAR300", "16_FUND-OCT",
25 | "18_DRIONS-DB", "19_Drishti-GS1",
26 | "20_E-ophta", "21_G1020", "23_HRF", "24_ORIGA", "26_ROC",
27 | "28_OIA-DDR", "30_SUSTech-SYSU", "31_JICHI",
28 | "32_CHAKSU", "33_DR1-2", "35_ScarDat", "36_ACRIMA", "37_DeepDRiD_test", "37_DeepDRiD_train_eval"]
29 | datasets_fundus_oct = ["00_OCTCELL", "01_EYEPACS", "04_RFMid",
30 | "06_DEN", "07_LAG", "08_ODIR", "10_PARAGUAY",
31 | "11_STARE", "12_ARIA", "14_AGAR300", "16_FUND-OCT",
32 | "18_DRIONS-DB", "19_Drishti-GS1",
33 | "20_E-ophta", "21_G1020", "23_HRF", "24_ORIGA", "26_ROC",
34 | "28_OIA-DDR", "30_SUSTech-SYSU", "31_JICHI",
35 | "32_CHAKSU", "33_DR1-2", "35_ScarDat", "36_ACRIMA", "37_DeepDRiD_test", "37_DeepDRiD_train_eval"]
36 |
37 | balance = True
38 | batch_size = 16
39 | num_workers = 10
40 | banned_categories = []
41 | caption = "A [ATR] fundus photograph of [CLS]"
42 | augment_description = True
43 | from torchvision.transforms import Compose
44 | from util.flair_dataloader.transforms import LoadImage, ImageScaling, SelectRelevantKeys, CopyDict,\
45 | ProduceDescription, AugmentDescription
46 | # ===========
47 |
48 |
49 | def pil_loader(path: str) -> Image.Image:
50 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
51 | with open(path, 'rb') as f:
52 | img = Image.open(f)
53 | return img.convert('RGB')
54 |
55 |
56 | class MultimodalBertDataset_flair(Dataset):
57 | def __init__(
58 | self,
59 | data_mode,
60 | max_caption_length: int = 100
61 | ):
62 | if data_mode == 'debug':
63 | datasets_train = datasets_debug
64 | elif data_mode == 'fundus':
65 | datasets_train = datasets_fundus
66 | elif data_mode == 'oct':
67 | datasets_train = datasets_oct
68 | elif data_mode == 'fundus_oct':
69 | datasets_train = datasets_fundus_oct
70 |
71 | self.data_list = dataloader.get_data_list(dataframes_path,
72 | data_root_path, datasets_train, balance,
73 | batch_size, num_workers, banned_categories,
74 | caption, augment_description)
75 |
76 | self.transforms = Compose([
77 | CopyDict(),
78 | LoadImage(),
79 | ImageScaling(),
80 | ProduceDescription(caption=caption),
81 | AugmentDescription(augment=augment_description),
82 | SelectRelevantKeys()
83 | ])
84 |
85 | self.max_caption_length = max_caption_length
86 | # self.data_root = data_root
87 | # self.transform = transform
88 | # self.images_list, self.report_list = self.read_csv()
89 | # # random
90 | # random_seed = 42
91 | # random.seed(random_seed)
92 | # random.shuffle(self.images_list)
93 | # random.seed(random_seed)
94 | # random.shuffle(self.report_list)
95 | self.tokenizer = tokenizers.Tokenizer.from_pretrained('bert-base-uncased')
96 | # self.tokenizer = AutoTokenizer.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')
97 | # self.tokenizer.model_max_length = 77
98 |
99 | # self.tokenizer = tokenizers.Tokenizer.from_file("mimic_wordpiece.json")
100 | self.idxtoword = {v: k for k, v in self.tokenizer.get_vocab().items()}
101 | self.tokenizer.enable_truncation(max_length=self.max_caption_length)
102 | self.tokenizer.enable_padding(length=self.max_caption_length)
103 |
104 | def __len__(self):
105 | return len(self.data_list)
106 |
107 | def _random_mask(self,tokens):
108 | masked_tokens = deepcopy(tokens)
109 | for i in range(1, masked_tokens.shape[1]-1):
110 | if masked_tokens[0][i] == 0:
111 | break
112 |
113 | if masked_tokens[0][i-1] == 3 and self.idxtoword[masked_tokens[0][i].item()][0:2] == '##':
114 | masked_tokens[0][i] = 3
115 | continue
116 |
117 | if masked_tokens[0][i-1] != 3 and self.idxtoword[masked_tokens[0][i].item()][0:2] == '##':
118 | continue
119 |
120 | prob = random.random()
121 | if prob < 0.5:
122 | masked_tokens[0][i] = 3
123 |
124 | return masked_tokens
125 |
126 | def __getitem__(self, index):
127 | batch = self.transforms(self.data_list[index])
128 | image = batch['image']
129 | sent = batch['report'][0]
130 | data_moda = torch.tensor(0)
131 | if 'OCT' in batch['atributes']:
132 | data_moda = torch.tensor(1)
133 |
134 | # image = pil_loader(self.images_list[index])
135 | # image = self.transform(image)
136 | # sent = self.report_list[index]
137 | # sent = '[CLS] '+ sent
138 |
139 | encoded = self.tokenizer.encode(sent)
140 | ids = torch.tensor(encoded.ids).unsqueeze(0)
141 | attention_mask = torch.tensor(encoded.attention_mask).unsqueeze(0)
142 | type_ids = torch.tensor(encoded.type_ids).unsqueeze(0)
143 | masked_ids = self._random_mask(ids)
144 | return image, ids, attention_mask, type_ids, masked_ids, data_moda
145 |
146 | # def read_csv(self):
147 | # csv_path = os.path.join(self.data_root,'training.csv')
148 | # df = pd.read_csv(csv_path,sep=',')
149 | # return df["image_path"], df["report_content"]
150 |
151 | def collate_fn(self, instances: List[Tuple]):
152 | image_list, ids_list, attention_mask_list, type_ids_list, masked_ids_list, datamoda_list = [], [], [], [], [], []
153 |         # flatten
154 | for b in instances:
155 | image, ids, attention_mask, type_ids, masked_ids, moda_ids = b
156 | image_list.append(image)
157 | ids_list.append(ids)
158 | attention_mask_list.append(attention_mask)
159 | type_ids_list.append(type_ids)
160 | masked_ids_list.append(masked_ids)
161 | datamoda_list.append(moda_ids)
162 |
163 | # stack
164 | image_stack = torch.stack(image_list)
165 | ids_stack = torch.stack(ids_list).squeeze()
166 | attention_mask_stack = torch.stack(attention_mask_list).squeeze()
167 | type_ids_stack = torch.stack(type_ids_list).squeeze()
168 | masked_ids_stack = torch.stack(masked_ids_list).squeeze()
169 | moda_ids_stack = torch.stack(datamoda_list).squeeze()
170 |
171 | # sort and add to dictionary
172 | return_dict = {
173 | "image": image_stack,
174 | "labels": ids_stack,
175 | "attention_mask": attention_mask_stack,
176 | "type_ids": type_ids_stack,
177 | "ids": masked_ids_stack,
178 | 'tag': moda_ids_stack
179 | }
180 |
181 | return return_dict
--------------------------------------------------------------------------------
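
A hedged sketch of how this dataset is typically paired with a DataLoader; the two `TODO` paths at the top of the file must first point at the prepared FLAIR-style data.

```python
from torch.utils.data import DataLoader

from util.dataset import MultimodalBertDataset_flair

dataset = MultimodalBertDataset_flair(data_mode="debug")  # or 'fundus', 'oct', 'fundus_oct'
loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4,
                    collate_fn=dataset.collate_fn, pin_memory=True, drop_last=True)

batch = next(iter(loader))
print(batch["image"].shape)  # stacked images
print(batch["ids"].shape)    # masked token ids fed to the text branch
```
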
/finetune/engine_finetune.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | # Partly revised by YZ @UCL&Moorfields
4 | # --------------------------------------------------------
5 |
6 | import math
7 | import sys
8 | import csv
9 | import os
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from timm.data import Mixup
14 | from timm.utils import accuracy
15 | from typing import Iterable, Optional
16 | import util.misc as misc
17 | import util.lr_sched as lr_sched
18 | from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, average_precision_score,multilabel_confusion_matrix
19 | from pycm import *
20 | import matplotlib.pyplot as plt
21 | import numpy as np
22 |
23 |
24 |
25 |
26 | def misc_measures(confusion_matrix):
27 |
28 | acc = []
29 | sensitivity = []
30 | specificity = []
31 | precision = []
32 | G = []
33 | F1_score_2 = []
34 | mcc_ = []
35 |
36 | for i in range(1, confusion_matrix.shape[0]):
37 | cm1=confusion_matrix[i]
38 | acc.append(1.*(cm1[0,0]+cm1[1,1])/np.sum(cm1))
39 | sensitivity_ = 1.*cm1[1,1]/(cm1[1,0]+cm1[1,1])
40 | sensitivity.append(sensitivity_)
41 | specificity_ = 1.*cm1[0,0]/(cm1[0,1]+cm1[0,0])
42 | specificity.append(specificity_)
43 | precision_ = 1.*cm1[1,1]/(cm1[1,1]+cm1[0,1])
44 | precision.append(precision_)
45 | G.append(np.sqrt(sensitivity_*specificity_))
46 | F1_score_2.append(2*precision_*sensitivity_/(precision_+sensitivity_))
47 | mcc = (cm1[0,0]*cm1[1,1]-cm1[0,1]*cm1[1,0])/np.sqrt((cm1[0,0]+cm1[0,1])*(cm1[0,0]+cm1[1,0])*(cm1[1,1]+cm1[1,0])*(cm1[1,1]+cm1[0,1]))
48 | mcc_.append(mcc)
49 |
50 | acc = np.array(acc).mean()
51 | sensitivity = np.array(sensitivity).mean()
52 | specificity = np.array(specificity).mean()
53 | precision = np.array(precision).mean()
54 | G = np.array(G).mean()
55 | F1_score_2 = np.array(F1_score_2).mean()
56 | mcc_ = np.array(mcc_).mean()
57 |
58 | return acc, sensitivity, specificity, precision, G, F1_score_2, mcc_
59 |
60 |
61 |
62 |
63 |
64 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
65 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
66 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
67 | mixup_fn: Optional[Mixup] = None, log_writer=None,
68 | args=None):
69 | model.train(True)
70 | metric_logger = misc.MetricLogger(delimiter=" ")
71 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
72 | header = 'Epoch: [{}]'.format(epoch)
73 | print_freq = 20
74 |
75 | accum_iter = args.accum_iter
76 |
77 | optimizer.zero_grad()
78 |
79 | if log_writer is not None:
80 | print('log_dir: {}'.format(log_writer.log_dir))
81 |
82 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
83 |
84 | # we use a per iteration (instead of per epoch) lr scheduler
85 | if data_iter_step % accum_iter == 0:
86 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
87 |
88 | samples = samples.to(device, non_blocking=True)
89 | targets = targets.to(device, non_blocking=True)
90 |
91 | if mixup_fn is not None:
92 | samples, targets = mixup_fn(samples, targets)
93 |
94 | with torch.cuda.amp.autocast():
95 | outputs = model(samples)
96 | loss = criterion(outputs, targets)
97 |
98 | loss_value = loss.item()
99 |
100 | if not math.isfinite(loss_value):
101 | print("Loss is {}, stopping training".format(loss_value))
102 | sys.exit(1)
103 |
104 | loss /= accum_iter
105 | loss_scaler(loss, optimizer, clip_grad=max_norm,
106 | parameters=model.parameters(), create_graph=False,
107 | update_grad=(data_iter_step + 1) % accum_iter == 0)
108 | if (data_iter_step + 1) % accum_iter == 0:
109 | optimizer.zero_grad()
110 |
111 | torch.cuda.synchronize()
112 |
113 | metric_logger.update(loss=loss_value)
114 | min_lr = 10.
115 | max_lr = 0.
116 | for group in optimizer.param_groups:
117 | min_lr = min(min_lr, group["lr"])
118 | max_lr = max(max_lr, group["lr"])
119 |
120 | metric_logger.update(lr=max_lr)
121 |
122 | loss_value_reduce = misc.all_reduce_mean(loss_value)
123 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
124 | """ We use epoch_1000x as the x-axis in tensorboard.
125 | This calibrates different curves when batch size changes.
126 | """
127 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
128 | log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x)
129 | log_writer.add_scalar('lr', max_lr, epoch_1000x)
130 |
131 | # gather the stats from all processes
132 | metric_logger.synchronize_between_processes()
133 | print("Averaged stats:", metric_logger)
134 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
135 |
136 |
137 |
138 |
139 | @torch.no_grad()
140 | def evaluate(data_loader, model, device, task, epoch, mode, num_class):
141 | criterion = torch.nn.CrossEntropyLoss()
142 |
143 | metric_logger = misc.MetricLogger(delimiter=" ")
144 | header = 'Test:'
145 |
146 | if not os.path.exists(task):
147 | os.makedirs(task)
148 |
149 | prediction_decode_list = []
150 | prediction_list = []
151 | true_label_decode_list = []
152 | true_label_onehot_list = []
153 |
154 | # switch to evaluation mode
155 | model.eval()
156 |
157 | for batch in metric_logger.log_every(data_loader, 10, header):
158 | images = batch[0]
159 | target = batch[-1]
160 | images = images.to(device, non_blocking=True)
161 | target = target.to(device, non_blocking=True)
162 | true_label=F.one_hot(target.to(torch.int64), num_classes=num_class)
163 |
164 | # compute output
165 | with torch.cuda.amp.autocast():
166 | output = model(images)
167 | loss = criterion(output, target)
168 | prediction_softmax = nn.Softmax(dim=1)(output)
169 | _,prediction_decode = torch.max(prediction_softmax, 1)
170 | _,true_label_decode = torch.max(true_label, 1)
171 |
172 | prediction_decode_list.extend(prediction_decode.cpu().detach().numpy())
173 | true_label_decode_list.extend(true_label_decode.cpu().detach().numpy())
174 | true_label_onehot_list.extend(true_label.cpu().detach().numpy())
175 | prediction_list.extend(prediction_softmax.cpu().detach().numpy())
176 |
177 | acc1,_ = accuracy(output, target, topk=(1,2))
178 |
179 | batch_size = images.shape[0]
180 | metric_logger.update(loss=loss.item())
181 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
182 | # gather the stats from all processes
183 | true_label_decode_list = np.array(true_label_decode_list)
184 | prediction_decode_list = np.array(prediction_decode_list)
185 | confusion_matrix = multilabel_confusion_matrix(true_label_decode_list, prediction_decode_list,labels=[i for i in range(num_class)])
186 | acc, sensitivity, specificity, precision, G, F1, mcc = misc_measures(confusion_matrix)
187 |
188 | auc_roc = roc_auc_score(true_label_onehot_list, prediction_list,multi_class='ovr',average='macro')
189 | auc_pr = average_precision_score(true_label_onehot_list, prediction_list,average='macro')
190 |
191 | metric_logger.synchronize_between_processes()
192 |
193 | print('Sklearn Metrics - Acc: {:.4f} AUC-roc: {:.4f} AUC-pr: {:.4f} F1-score: {:.4f} MCC: {:.4f}'.format(acc, auc_roc, auc_pr, F1, mcc))
194 | results_path = task+'/_metrics_{}.csv'.format(mode)
195 | with open(results_path,mode='a',newline='',encoding='utf8') as cfa:
196 | wf = csv.writer(cfa)
197 | data2=[[acc,sensitivity,specificity,precision,auc_roc,auc_pr,F1,mcc,metric_logger.loss]]
198 | for i in data2:
199 | wf.writerow(i)
200 |
201 |
202 | if 'test' in mode:
203 | cm = ConfusionMatrix(actual_vector=true_label_decode_list, predict_vector=prediction_decode_list)
204 | cm.plot(cmap=plt.cm.Blues,number_label=True,normalized=True,plot_lib="matplotlib")
205 | plt.savefig(task+'/confusion_matrix_{}.jpg'.format(mode),dpi=600,bbox_inches ='tight')
206 |
207 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()},auc_roc
208 |
209 |
--------------------------------------------------------------------------------
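For reference, a minimal sketch of how `misc_measures` above consumes the output of sklearn's `multilabel_confusion_matrix` (toy labels, chosen so no per-class denominator is zero; assumes the repository root is on PYTHONPATH and its pinned dependencies are installed). Note that the loop inside `misc_measures` starts at class index 1, so the first class is excluded from the averaged metrics.

import numpy as np
from sklearn.metrics import multilabel_confusion_matrix
from finetune.engine_finetune import misc_measures

y_true = np.array([0, 1, 2, 2, 1])
y_pred = np.array([0, 1, 1, 2, 1])

# one 2x2 confusion matrix per class, shape (3, 2, 2)
cm = multilabel_confusion_matrix(y_true, y_pred, labels=[0, 1, 2])
acc, sensitivity, specificity, precision, G, F1, mcc = misc_measures(cm)
print(acc, sensitivity, specificity, precision, G, F1, mcc)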
/main_pretrain_urfound.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # DeiT: https://github.com/facebookresearch/deit
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 | import argparse
12 | import datetime
13 | import json
14 | import numpy as np
15 | import os
16 | import time
17 | from pathlib import Path
18 |
19 | import torch
20 | import torch.backends.cudnn as cudnn
21 | from torch.utils.tensorboard import SummaryWriter
22 |
23 | import timm
24 |
25 | assert timm.__version__ == "0.3.2" # version check
26 | import timm.optim.optim_factory as optim_factory
27 |
28 | import util.misc as misc
29 | from util.misc import NativeScalerWithGradNormCount as NativeScaler
30 |
31 | import util.model_urfound as model_urfound
32 |
33 | from util.engine_pretrain import train_one_epoch
34 | from util.dataset import MultimodalBertDataset_flair
35 |
36 | def get_args_parser():
37 | parser = argparse.ArgumentParser('UrFound pre-training', add_help=False)
38 | parser.add_argument('--batch_size', default=64, type=int,
39 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
40 | parser.add_argument('--epochs', default=400, type=int)
41 | parser.add_argument('--accum_iter', default=1, type=int,
42 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
43 |
44 | # Model parameters
45 | parser.add_argument('--model', default='urmodel', type=str, metavar='MODEL',
46 | help='Name of model to train')
47 |
48 | parser.add_argument('--input_size', default=224, type=int,
49 | help='images input size')
50 |
51 | parser.add_argument('--mask_ratio', default=0.75, type=float,
52 | help='Masking ratio (percentage of removed patches).')
53 |
54 | parser.add_argument('--norm_pix_loss', action='store_true',
55 | help='Use (per-patch) normalized pixels as targets for computing loss')
56 | parser.set_defaults(norm_pix_loss=False)
57 |
58 | # Optimizer parameters
59 | parser.add_argument('--weight_decay', type=float, default=0.05,
60 | help='weight decay (default: 0.05)')
61 |
62 | parser.add_argument('--lr', type=float, default=None, metavar='LR',
63 | help='learning rate (absolute lr)')
64 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
65 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
66 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
67 | help='lower lr bound for cyclic schedulers that hit 0')
68 |
69 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N',
70 | help='epochs to warmup LR')
71 |
72 | # Dataset parameters
73 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
74 | help='dataset path')
75 |
76 | parser.add_argument('--output_dir', default='./output_dir',
77 | help='path where to save, empty for no saving')
78 | parser.add_argument('--log_dir', default='./output_dir',
79 | help='path where to tensorboard log')
80 | parser.add_argument('--device', default='cuda',
81 | help='device to use for training / testing')
82 | parser.add_argument('--seed', default=0, type=int)
83 | parser.add_argument('--resume', default='',
84 | help='resume from checkpoint')
85 |
86 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
87 | help='start epoch')
88 | parser.add_argument('--num_workers', default=10, type=int)
89 | parser.add_argument('--pin_mem', action='store_true',
90 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
91 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
92 | parser.set_defaults(pin_mem=True)
93 |
94 | # distributed training parameters
95 | parser.add_argument('--world_size', default=1, type=int,
96 | help='number of distributed processes')
97 | parser.add_argument('--local_rank', default=-1, type=int)
98 | parser.add_argument('--dist_on_itp', action='store_true')
99 | parser.add_argument('--dist_url', default='env://',
100 | help='url used to set up distributed training')
101 |
102 | parser.add_argument('--data_mode', default='fundus', type=str,
103 | help='dataset mode: debug / fundus / oct / fundus_oct')
104 | return parser
105 |
106 |
107 | def main(args):
108 | misc.init_distributed_mode(args)
109 |
110 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
111 | print("{}".format(args).replace(', ', ',\n'))
112 |
113 | device = torch.device(args.device)
114 |
115 | # fix the seed for reproducibility
116 | seed = args.seed + misc.get_rank()
117 | torch.manual_seed(seed)
118 | np.random.seed(seed)
119 |
120 | cudnn.benchmark = True
121 |
122 | # dataset
123 | dataset_train = MultimodalBertDataset_flair(args.data_mode)
124 |
125 | print(dataset_train)
126 |
127 | if True: # args.distributed:
128 | num_tasks = misc.get_world_size()
129 | global_rank = misc.get_rank()
130 | sampler_train = torch.utils.data.DistributedSampler(
131 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
132 | )
133 | print("Sampler_train = %s" % str(sampler_train))
134 | else:
135 | sampler_train = torch.utils.data.RandomSampler(dataset_train)
136 |
137 | args.log_dir = os.path.join(args.output_dir, "logs")
138 | if global_rank == 0 and args.log_dir is not None:
139 | os.makedirs(args.log_dir, exist_ok=True)
140 | log_writer = SummaryWriter(log_dir=args.log_dir)
141 | else:
142 | log_writer = None
143 |
144 | data_loader_train = torch.utils.data.DataLoader(
145 | dataset_train, sampler=sampler_train,
146 | batch_size=args.batch_size,
147 | num_workers=args.num_workers,
148 | pin_memory=args.pin_mem,
149 | drop_last=True,
150 | collate_fn=dataset_train.collate_fn
151 | )
152 |
153 | # define the model
154 | model = model_urfound.__dict__[args.model](norm_pix_loss=args.norm_pix_loss)
155 |
156 | model.to(device)
157 |
158 | model_without_ddp = model
159 | print("Model = %s" % str(model_without_ddp))
160 |
161 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
162 |
163 | if args.lr is None: # only base_lr is specified
164 | args.lr = args.blr * eff_batch_size / 256
165 |
166 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
167 | print("actual lr: %.2e" % args.lr)
168 |
169 | print("accumulate grad iterations: %d" % args.accum_iter)
170 | print("effective batch size: %d" % eff_batch_size)
171 |
172 | if args.distributed:
173 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
174 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
175 | model_without_ddp = model.module
176 |
177 | # following timm: set wd as 0 for bias and norm layers
178 | param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay)
179 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
180 | print(optimizer)
181 | loss_scaler = NativeScaler()
182 |
183 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
184 |
185 | print(f"Start training for {args.epochs} epochs")
186 | start_time = time.time()
187 | for epoch in range(args.start_epoch, args.epochs):
188 | if args.distributed:
189 | data_loader_train.sampler.set_epoch(epoch)
190 | train_stats = train_one_epoch(
191 | model, data_loader_train,
192 | optimizer, device, epoch, loss_scaler,
193 | log_writer=log_writer,
194 | args=args
195 | )
196 | if args.output_dir and (epoch % 20 == 0 or epoch + 1 == args.epochs):
197 | misc.save_model_pretrain(
198 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
199 | loss_scaler=loss_scaler, epoch=epoch)
200 |
201 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
202 | 'epoch': epoch,}
203 |
204 | if args.output_dir and misc.is_main_process():
205 | if log_writer is not None:
206 | log_writer.flush()
207 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
208 | f.write(json.dumps(log_stats) + "\n")
209 |
210 | total_time = time.time() - start_time
211 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
212 | print('Training time {}'.format(total_time_str))
213 |
214 |
215 | if __name__ == '__main__':
216 | args = get_args_parser()
217 | args = args.parse_args()
218 | if args.output_dir:
219 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
220 | main(args)
--------------------------------------------------------------------------------
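A small numeric illustration of the linear lr scaling rule applied in `main()` above (the values are placeholders, not recommended settings):

# effective batch size and absolute lr, exactly as computed in main_pretrain_urfound.main()
blr = 1e-3          # --blr
batch_size = 64     # --batch_size, per GPU
accum_iter = 1      # --accum_iter
world_size = 4      # number of distributed processes / GPUs

eff_batch_size = batch_size * accum_iter * world_size
lr = blr * eff_batch_size / 256
print(eff_batch_size, lr)   # 256 0.001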
/bert/bert.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import BertModel, BertForMaskedLM
3 | from transformers.modeling_outputs import MaskedLMOutput, BaseModelOutputWithPoolingAndCrossAttentions
4 | from torch.nn import CrossEntropyLoss
5 |
6 |
7 | class MyBertModel(BertModel):
8 | def __init__(self, config, add_pooling_layer=True):
9 | super().__init__(config)
10 |
11 | def forward(
12 | self,
13 | latent = None,
14 | input_ids=None,
15 | attention_mask=None,
16 | token_type_ids=None,
17 | position_ids=None,
18 | head_mask=None,
19 | inputs_embeds=None,
20 | encoder_hidden_states=None,
21 | encoder_attention_mask=None,
22 | past_key_values=None,
23 | use_cache=None,
24 | output_attentions=None,
25 | output_hidden_states=None,
26 | return_dict=None,
27 | ):
28 | r"""
29 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
30 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
31 | the model is configured as a decoder.
32 | encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
33 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
34 | the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
35 | - 1 for tokens that are **not masked**,
36 | - 0 for tokens that are **masked**.
37 | past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
38 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
39 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
40 | don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
41 | `decoder_input_ids` of shape `(batch_size, sequence_length)`.
42 | use_cache (`bool`, *optional*):
43 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
44 | `past_key_values`).
45 | """
46 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
47 | output_hidden_states = (
48 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
49 | )
50 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
51 |
52 | if self.config.is_decoder:
53 | use_cache = use_cache if use_cache is not None else self.config.use_cache
54 | else:
55 | use_cache = False
56 |
57 | if input_ids is not None and inputs_embeds is not None:
58 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
59 | elif input_ids is not None:
60 | input_shape = input_ids.size()
61 | elif inputs_embeds is not None:
62 | input_shape = inputs_embeds.size()[:-1]
63 | else:
64 | raise ValueError("You have to specify either input_ids or inputs_embeds")
65 |
66 | batch_size, seq_length = input_shape
67 | device = input_ids.device if input_ids is not None else inputs_embeds.device
68 |
69 | # past_key_values_length
70 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
71 |
72 | if attention_mask is None:
73 | attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
74 |
75 | if token_type_ids is None:
76 | if hasattr(self.embeddings, "token_type_ids"):
77 | buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
78 | buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
79 | token_type_ids = buffered_token_type_ids_expanded
80 | else:
81 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
82 |
83 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
84 | # ourselves in which case we just need to make it broadcastable to all heads.
85 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
86 |
87 | # If a 2D or 3D attention mask is provided for the cross-attention
88 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
89 | if self.config.is_decoder and encoder_hidden_states is not None:
90 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
91 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
92 | if encoder_attention_mask is None:
93 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
94 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
95 | else:
96 | encoder_extended_attention_mask = None
97 |
98 | # Prepare head mask if needed
99 | # 1.0 in head_mask indicate we keep the head
100 | # attention_probs has shape bsz x n_heads x N x N
101 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
102 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
103 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
104 |
105 | embedding_output = self.embeddings(
106 | input_ids=input_ids,
107 | position_ids=position_ids,
108 | token_type_ids=token_type_ids,
109 | inputs_embeds=inputs_embeds,
110 | past_key_values_length=past_key_values_length,
111 | )
112 | embedding_output = embedding_output + latent.unsqueeze(1)
113 | encoder_outputs = self.encoder(
114 | embedding_output,
115 | attention_mask=extended_attention_mask,
116 | head_mask=head_mask,
117 | encoder_hidden_states=encoder_hidden_states,
118 | encoder_attention_mask=encoder_extended_attention_mask,
119 | past_key_values=past_key_values,
120 | use_cache=use_cache,
121 | output_attentions=output_attentions,
122 | output_hidden_states=output_hidden_states,
123 | return_dict=return_dict,
124 | )
125 | sequence_output = encoder_outputs[0]
126 | pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
127 |
128 | if not return_dict:
129 | return (sequence_output, pooled_output) + encoder_outputs[1:]
130 |
131 | return BaseModelOutputWithPoolingAndCrossAttentions(
132 | last_hidden_state=sequence_output,
133 | pooler_output=pooled_output,
134 | past_key_values=encoder_outputs.past_key_values,
135 | hidden_states=encoder_outputs.hidden_states,
136 | attentions=encoder_outputs.attentions,
137 | cross_attentions=encoder_outputs.cross_attentions,
138 | )
139 |
140 | class MyBertMaskedLM(BertForMaskedLM):
141 | def __init__(self, config):
142 | super().__init__(config)
143 | self.bert = MyBertModel(config, add_pooling_layer=False)
144 |
145 | def forward(
146 | self,
147 | latent=None,
148 | input_ids=None,
149 | attention_mask=None,
150 | token_type_ids=None,
151 | position_ids=None,
152 | head_mask=None,
153 | inputs_embeds=None,
154 | encoder_hidden_states=None,
155 | encoder_attention_mask=None,
156 | labels=None,
157 | output_attentions=None,
158 | output_hidden_states=None,
159 | return_dict=None,
160 | ):
161 | r"""
162 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
163 | Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
164 | config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
165 | loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
166 | """
167 |
168 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
169 |
170 | outputs = self.bert(
171 | latent,
172 | input_ids,
173 | attention_mask=attention_mask,
174 | token_type_ids=token_type_ids,
175 | position_ids=position_ids,
176 | head_mask=head_mask,
177 | inputs_embeds=inputs_embeds,
178 | encoder_hidden_states=encoder_hidden_states,
179 | encoder_attention_mask=encoder_attention_mask,
180 | output_attentions=output_attentions,
181 | output_hidden_states=output_hidden_states,
182 | return_dict=return_dict,
183 | )
184 |
185 | sequence_output = outputs[0]
186 | prediction_scores = self.cls(sequence_output)
187 |
188 | masked_lm_loss = None
189 | if labels is not None:
190 | loss_fct = CrossEntropyLoss() # -100 index = padding token
191 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
192 |
193 | if not return_dict:
194 | output = (prediction_scores,) + outputs[2:]
195 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
196 |
197 | return MaskedLMOutput(
198 | loss=masked_lm_loss,
199 | logits=prediction_scores,
200 | hidden_states=outputs.hidden_states,
201 | attentions=outputs.attentions,
202 | )
--------------------------------------------------------------------------------
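A minimal sketch of how `MyBertMaskedLM` injects an image latent into the token embeddings before the BERT encoder. The tiny `BertConfig` below is an arbitrary illustration, not the configuration used by UrFound; the only real requirement is that `hidden_size` matches the latent dimension (384 in model_urfound.py).

import torch
from transformers import BertConfig
from bert.bert import MyBertMaskedLM

config = BertConfig(hidden_size=384, num_hidden_layers=2,
                    num_attention_heads=6, intermediate_size=768)
model = MyBertMaskedLM(config)

latent = torch.randn(2, 384)                # one pooled image feature per sample
input_ids = torch.randint(0, config.vocab_size, (2, 16))
labels = input_ids.clone()                  # score every token, for simplicity

# inside MyBertModel.forward, latent.unsqueeze(1) is broadcast-added to the embeddings
out = model(latent=latent, input_ids=input_ids, labels=labels)
print(out.loss, out.logits.shape)           # scalar MLM loss, (2, 16, vocab_size)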
/util/model_urfound.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
9 | # DeiT: https://github.com/facebookresearch/deit
10 | # --------------------------------------------------------
11 |
12 | from functools import partial
13 |
14 |
15 | import torch
16 | import torchvision
17 | import torch.nn as nn
18 | from torchvision.transforms.functional import InterpolationMode
19 | from timm.models.vision_transformer import PatchEmbed, Block
20 |
21 | from util.pos_embed import get_2d_sincos_pos_embed
22 | from bert.bert_encoder import BertEncoder
23 |
24 | class UrModel(nn.Module):
25 | """ Masked Autoencoder with VisionTransformer backbone
26 | """
27 | def __init__(self, img_size=224, patch_size=16, in_chans=3,
28 | embed_dim=1024, depth=24, num_heads=16,
29 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
30 | mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
31 | super().__init__()
32 |
33 | # --------------------------------------------------------------------------
34 | # image encoder specifics
35 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
36 | num_patches = self.patch_embed.num_patches
37 |
38 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
39 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding
40 |
41 | self.blocks = nn.ModuleList([
42 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
43 | for i in range(depth)])
44 | self.norm = norm_layer(embed_dim)
45 | # --------------------------------------------------------------------------
46 |
47 | # --------------------------------------------------------------------------
48 | # image decoder specifics
49 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
50 |
51 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
52 |
53 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding
54 |
55 | self.decoder_blocks = nn.ModuleList([
56 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
57 | for i in range(decoder_depth)])
58 |
59 | self.decoder_norm = norm_layer(decoder_embed_dim)
60 | self.decoder_pred = nn.Linear(decoder_embed_dim, (patch_size*2)**2 * in_chans, bias=True)
61 | # --------------------------------------------------------------------------
62 | # Bert encoder
63 | self.bert_encoder = BertEncoder()
64 | self.bert_mlp = nn.Linear(embed_dim, 384, bias=True)
65 | self.norm_pix_loss = norm_pix_loss
66 |
67 | self.initialize_weights()
68 |
69 | def initialize_weights(self):
70 | # initialization
71 | # initialize (and freeze) pos_embed by sin-cos embedding
72 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
73 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
74 |
75 | decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
76 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
77 |
78 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
79 | w = self.patch_embed.proj.weight.data
80 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
81 |
82 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
83 | torch.nn.init.normal_(self.cls_token, std=.02)
84 | torch.nn.init.normal_(self.mask_token, std=.02)
85 |
86 | # initialize nn.Linear and nn.LayerNorm
87 | self.apply(self._init_weights)
88 |
89 | def _init_weights(self, m):
90 | if isinstance(m, nn.Linear):
91 | # we use xavier_uniform following official JAX ViT:
92 | torch.nn.init.xavier_uniform_(m.weight)
93 | if isinstance(m, nn.Linear) and m.bias is not None:
94 | nn.init.constant_(m.bias, 0)
95 | elif isinstance(m, nn.LayerNorm):
96 | nn.init.constant_(m.bias, 0)
97 | nn.init.constant_(m.weight, 1.0)
98 |
99 | def patchify(self, imgs):
100 | """
101 | imgs: (N, 3, H, W)
102 | x: (N, L, patch_size**2 *3)
103 | """
104 |
105 | p = self.patch_embed.patch_size[0]*2
106 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
107 |
108 | h = w = imgs.shape[2] // p
109 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
110 | x = torch.einsum('nchpwq->nhwpqc', x)
111 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
112 | return x
113 |
114 | def unpatchify(self, x):
115 | """
116 | x: (N, L, patch_size**2 *3)
117 | imgs: (N, 3, H, W)
118 | """
119 | p = self.patch_embed.patch_size[0] * 2
120 | h = w = int(x.shape[1]**.5)
121 | assert h * w == x.shape[1]
122 |
123 | x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
124 | x = torch.einsum('nhwpqc->nchpwq', x)
125 | imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
126 | return imgs
127 |
128 | def random_masking(self, x, mask_ratio):
129 | """
130 | Perform per-sample random masking by per-sample shuffling.
131 | Per-sample shuffling is done by argsort random noise.
132 | x: [N, L, D], sequence
133 | """
134 | N, L, D = x.shape # batch, length, dim
135 | len_keep = int(L * (1 - mask_ratio))
136 |
137 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
138 |
139 | # sort noise for each sample
140 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
141 | ids_restore = torch.argsort(ids_shuffle, dim=1)
142 |
143 | # keep the first subset
144 | ids_keep = ids_shuffle[:, :len_keep]
145 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
146 |
147 | # generate the binary mask: 0 is keep, 1 is remove
148 | mask = torch.ones([N, L], device=x.device)
149 | mask[:, :len_keep] = 0
150 | # unshuffle to get the binary mask
151 | mask = torch.gather(mask, dim=1, index=ids_restore)
152 |
153 | return x_masked, mask, ids_restore
154 |
155 | def forward_encoder(self, x, mask_ratio):
156 | # embed patches
157 | x = self.patch_embed(x)
158 | # add pos embed w/o cls token
159 | x = x + self.pos_embed[:, 1:, :]
160 |
161 | # masking: length -> length * mask_ratio
162 | x, mask, ids_restore = self.random_masking(x, mask_ratio)
163 |
164 | # append cls token
165 | cls_token = self.cls_token + self.pos_embed[:, :1, :]
166 | cls_tokens = cls_token.expand(x.shape[0], -1, -1)
167 | x = torch.cat((cls_tokens, x), dim=1)
168 |
169 | # apply Transformer blocks
170 | for blk in self.blocks:
171 | x = blk(x)
172 | x = self.norm(x)
173 |
174 | return x, mask, ids_restore
175 |
176 | def forward_decoder(self, x, ids_restore):
177 | # embed tokens
178 | x = self.decoder_embed(x)
179 |
180 | # append mask tokens to sequence
181 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
182 | x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
183 | x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
184 | x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
185 |
186 | # add pos embed
187 | x = x + self.decoder_pos_embed
188 |
189 | # apply Transformer blocks
190 | for blk in self.decoder_blocks:
191 | x = blk(x)
192 | x = self.decoder_norm(x)
193 |
194 | # predictor projection
195 | x = self.decoder_pred(x)
196 |
197 | # remove cls token
198 | x = x[:, 1:, :]
199 |
200 | return x
201 |
202 | def forward_report_decoder(self, latent, caption_ids, labels, attention_mask, token_type_ids):
203 | latent = self.bert_mlp(latent)
204 | latent = latent[:, 1:, :].mean(dim=1)
205 | outputs = self.bert_encoder(latent, caption_ids, labels, attention_mask, token_type_ids)
206 | return outputs.loss
207 |
208 | def forward_loss(self, imgs, pred, mask):
209 | """
210 | imgs: [N, 3, H, W]
211 | pred: [N, L, p*p*3]
212 | mask: [N, L], 0 is keep, 1 is remove,
213 | """
214 | target = self.patchify(imgs)
215 | if self.norm_pix_loss:
216 | mean = target.mean(dim=-1, keepdim=True)
217 | var = target.var(dim=-1, keepdim=True)
218 | target = (target - mean) / (var + 1.e-6)**.5
219 |
220 | loss = (pred - target) ** 2
221 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch
222 |
223 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
224 | return loss
225 |
226 | def forward(self, batch, mask_ratio=0.75):
227 | big_imgs = batch["image"]
228 |
229 | ids, labels, attention_mask, type_ids = batch["ids"], batch["labels"], batch["attention_mask"], batch["type_ids"]
230 |
231 | big_imgs = big_imgs.cuda()
232 | ids = ids.cuda()
233 | labels = labels.cuda()
234 | attention_mask = attention_mask.cuda()
235 | type_ids = type_ids.cuda()
236 | imgs = torchvision.transforms.Resize([224,224], interpolation=InterpolationMode.BICUBIC)(big_imgs)
237 |
238 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
239 | report_loss = self.forward_report_decoder(latent, ids, labels, attention_mask, type_ids)
240 | pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3]
241 | loss = self.forward_loss(big_imgs, pred, mask)
242 | return (loss, report_loss), pred, mask
243 |
244 | def urmodel(**kwargs):
245 | model = UrModel(
246 | patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12,
247 | decoder_embed_dim=768, decoder_depth=4, decoder_num_heads=6,
248 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
249 | return model
250 |
251 | def mae_vit_base_patch16_dec512d8b(**kwargs):
252 | model = UrModel(
253 | patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12,
254 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
255 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
256 | return model
257 |
258 | mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks
259 |
--------------------------------------------------------------------------------
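The per-sample masking in `UrModel.random_masking` relies on an argsort-of-noise trick; the standalone sketch below mirrors that logic on a toy tensor (rather than instantiating the full model, which would also construct the BERT encoder), so the keep/remove bookkeeping is easy to inspect.

import torch

def random_masking(x, mask_ratio):
    # x: [N, L, D]; returns the kept tokens, a binary mask (1 = removed) and the restore indices
    N, L, D = x.shape
    len_keep = int(L * (1 - mask_ratio))
    noise = torch.rand(N, L)                         # per-sample random scores
    ids_shuffle = torch.argsort(noise, dim=1)        # ascending: small noise is kept
    ids_restore = torch.argsort(ids_shuffle, dim=1)  # inverse permutation
    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).repeat(1, 1, D))
    mask = torch.ones(N, L)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, 1, ids_restore)        # unshuffle back into the original order
    return x_masked, mask, ids_restore

x = torch.arange(2 * 8 * 4, dtype=torch.float32).reshape(2, 8, 4)
x_masked, mask, ids_restore = random_masking(x, mask_ratio=0.75)
print(x_masked.shape, mask)                          # torch.Size([2, 2, 4]) and a (2, 8) 0/1 mask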
/util/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # DeiT: https://github.com/facebookresearch/deit
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 |
12 | import builtins
13 | import datetime
14 | import os
15 | import time
16 | from collections import defaultdict, deque
17 | from pathlib import Path
18 |
19 | import torch
20 | import torch.distributed as dist
21 | from torch._six import inf
22 |
23 |
24 | class SmoothedValue(object):
25 | """Track a series of values and provide access to smoothed values over a
26 | window or the global series average.
27 | """
28 |
29 | def __init__(self, window_size=20, fmt=None):
30 | if fmt is None:
31 | fmt = "{median:.4f} ({global_avg:.4f})"
32 | self.deque = deque(maxlen=window_size)
33 | self.total = 0.0
34 | self.count = 0
35 | self.fmt = fmt
36 |
37 | def update(self, value, n=1):
38 | self.deque.append(value)
39 | self.count += n
40 | self.total += value * n
41 |
42 | def synchronize_between_processes(self):
43 | """
44 | Warning: does not synchronize the deque!
45 | """
46 | if not is_dist_avail_and_initialized():
47 | return
48 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
49 | dist.barrier()
50 | dist.all_reduce(t)
51 | t = t.tolist()
52 | self.count = int(t[0])
53 | self.total = t[1]
54 |
55 | @property
56 | def median(self):
57 | d = torch.tensor(list(self.deque))
58 | return d.median().item()
59 |
60 | @property
61 | def avg(self):
62 | d = torch.tensor(list(self.deque), dtype=torch.float32)
63 | return d.mean().item()
64 |
65 | @property
66 | def global_avg(self):
67 | return self.total / self.count
68 |
69 | @property
70 | def max(self):
71 | return max(self.deque)
72 |
73 | @property
74 | def value(self):
75 | return self.deque[-1]
76 |
77 | def __str__(self):
78 | return self.fmt.format(
79 | median=self.median,
80 | avg=self.avg,
81 | global_avg=self.global_avg,
82 | max=self.max,
83 | value=self.value)
84 |
85 |
86 | class MetricLogger(object):
87 | def __init__(self, delimiter="\t"):
88 | self.meters = defaultdict(SmoothedValue)
89 | self.delimiter = delimiter
90 |
91 | def update(self, **kwargs):
92 | for k, v in kwargs.items():
93 | if v is None:
94 | continue
95 | if isinstance(v, torch.Tensor):
96 | v = v.item()
97 | assert isinstance(v, (float, int))
98 | self.meters[k].update(v)
99 |
100 | def __getattr__(self, attr):
101 | if attr in self.meters:
102 | return self.meters[attr]
103 | if attr in self.__dict__:
104 | return self.__dict__[attr]
105 | raise AttributeError("'{}' object has no attribute '{}'".format(
106 | type(self).__name__, attr))
107 |
108 | def __str__(self):
109 | loss_str = []
110 | for name, meter in self.meters.items():
111 | loss_str.append(
112 | "{}: {}".format(name, str(meter))
113 | )
114 | return self.delimiter.join(loss_str)
115 |
116 | def synchronize_between_processes(self):
117 | for meter in self.meters.values():
118 | meter.synchronize_between_processes()
119 |
120 | def add_meter(self, name, meter):
121 | self.meters[name] = meter
122 |
123 | def log_every(self, iterable, print_freq, header=None):
124 | i = 0
125 | if not header:
126 | header = ''
127 | start_time = time.time()
128 | end = time.time()
129 | iter_time = SmoothedValue(fmt='{avg:.4f}')
130 | data_time = SmoothedValue(fmt='{avg:.4f}')
131 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
132 | log_msg = [
133 | header,
134 | '[{0' + space_fmt + '}/{1}]',
135 | 'eta: {eta}',
136 | '{meters}',
137 | 'time: {time}',
138 | 'data: {data}'
139 | ]
140 | if torch.cuda.is_available():
141 | log_msg.append('max mem: {memory:.0f}')
142 | log_msg = self.delimiter.join(log_msg)
143 | MB = 1024.0 * 1024.0
144 | for obj in iterable:
145 | data_time.update(time.time() - end)
146 | yield obj
147 | iter_time.update(time.time() - end)
148 | if i % print_freq == 0 or i == len(iterable) - 1:
149 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
150 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
151 | if torch.cuda.is_available():
152 | print(log_msg.format(
153 | i, len(iterable), eta=eta_string,
154 | meters=str(self),
155 | time=str(iter_time), data=str(data_time),
156 | memory=torch.cuda.max_memory_allocated() / MB))
157 | else:
158 | print(log_msg.format(
159 | i, len(iterable), eta=eta_string,
160 | meters=str(self),
161 | time=str(iter_time), data=str(data_time)))
162 | i += 1
163 | end = time.time()
164 | total_time = time.time() - start_time
165 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
166 | print('{} Total time: {} ({:.4f} s / it)'.format(
167 | header, total_time_str, total_time / len(iterable)))
168 |
169 |
170 | def setup_for_distributed(is_master):
171 | """
172 | This function disables printing when not in master process
173 | """
174 | builtin_print = builtins.print
175 |
176 | def print(*args, **kwargs):
177 | force = kwargs.pop('force', False)
178 | force = force or (get_world_size() > 8)
179 | if is_master or force:
180 | now = datetime.datetime.now().time()
181 | builtin_print('[{}] '.format(now), end='') # print with time stamp
182 | builtin_print(*args, **kwargs)
183 |
184 | builtins.print = print
185 |
186 |
187 | def is_dist_avail_and_initialized():
188 | if not dist.is_available():
189 | return False
190 | if not dist.is_initialized():
191 | return False
192 | return True
193 |
194 |
195 | def get_world_size():
196 | if not is_dist_avail_and_initialized():
197 | return 1
198 | return dist.get_world_size()
199 |
200 |
201 | def get_rank():
202 | if not is_dist_avail_and_initialized():
203 | return 0
204 | return dist.get_rank()
205 |
206 |
207 | def is_main_process():
208 | return get_rank() == 0
209 |
210 |
211 | def save_on_master(*args, **kwargs):
212 | if is_main_process():
213 | torch.save(*args, **kwargs)
214 |
215 |
216 | def init_distributed_mode(args):
217 | if args.dist_on_itp:
218 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
219 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
220 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
221 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
222 | os.environ['LOCAL_RANK'] = str(args.gpu)
223 | os.environ['RANK'] = str(args.rank)
224 | os.environ['WORLD_SIZE'] = str(args.world_size)
225 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
226 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
227 | args.rank = int(os.environ["RANK"])
228 | args.world_size = int(os.environ['WORLD_SIZE'])
229 | args.gpu = int(os.environ['LOCAL_RANK'])
230 | elif 'SLURM_PROCID' in os.environ:
231 | args.rank = int(os.environ['SLURM_PROCID'])
232 | args.gpu = args.rank % torch.cuda.device_count()
233 | else:
234 | print('Not using distributed mode')
235 | setup_for_distributed(is_master=True) # hack
236 | args.distributed = False
237 | return
238 |
239 | args.distributed = True
240 |
241 | torch.cuda.set_device(args.gpu)
242 | args.dist_backend = 'nccl'
243 | print('| distributed init (rank {}): {}, gpu {}'.format(
244 | args.rank, args.dist_url, args.gpu), flush=True)
245 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
246 | world_size=args.world_size, rank=args.rank)
247 | torch.distributed.barrier()
248 | setup_for_distributed(args.rank == 0)
249 |
250 |
251 | class NativeScalerWithGradNormCount:
252 | state_dict_key = "amp_scaler"
253 |
254 | def __init__(self):
255 | self._scaler = torch.cuda.amp.GradScaler()
256 |
257 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
258 | self._scaler.scale(loss).backward(create_graph=create_graph)
259 | if update_grad:
260 | if clip_grad is not None:
261 | assert parameters is not None
262 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
263 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
264 | else:
265 | self._scaler.unscale_(optimizer)
266 | norm = get_grad_norm_(parameters)
267 | self._scaler.step(optimizer)
268 | self._scaler.update()
269 | else:
270 | norm = None
271 | return norm
272 |
273 | def state_dict(self):
274 | return self._scaler.state_dict()
275 |
276 | def load_state_dict(self, state_dict):
277 | self._scaler.load_state_dict(state_dict)
278 |
279 |
280 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
281 | if isinstance(parameters, torch.Tensor):
282 | parameters = [parameters]
283 | parameters = [p for p in parameters if p.grad is not None]
284 | norm_type = float(norm_type)
285 | if len(parameters) == 0:
286 | return torch.tensor(0.)
287 | device = parameters[0].grad.device
288 | if norm_type == inf:
289 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
290 | else:
291 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
292 | return total_norm
293 |
294 |
295 | def save_model_pretrain(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
296 | output_dir = Path(args.output_dir)
297 | epoch_name = str(epoch)
298 | if loss_scaler is not None:
299 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
300 | for checkpoint_path in checkpoint_paths:
301 | to_save = {
302 | 'model': model_without_ddp.state_dict(),
303 | 'optimizer': optimizer.state_dict(),
304 | 'epoch': epoch,
305 | 'scaler': loss_scaler.state_dict(),
306 | 'args': args,
307 | }
308 |
309 | save_on_master(to_save, checkpoint_path)
310 | else:
311 | client_state = {'epoch': epoch}
312 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
313 |
314 |
315 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
316 | output_dir = Path(args.output_dir)
317 | epoch_name = str(epoch)
318 | if loss_scaler is not None:
319 | checkpoint_paths = [args.output_dir+args.task+'checkpoint-best.pth']
320 | for checkpoint_path in checkpoint_paths:
321 | to_save = {
322 | 'model': model_without_ddp.state_dict(),
323 | 'optimizer': optimizer.state_dict(),
324 | 'epoch': epoch,
325 | 'scaler': loss_scaler.state_dict(),
326 | 'args': args,
327 | }
328 |
329 | save_on_master(to_save, checkpoint_path)
330 | else:
331 | client_state = {'epoch': epoch}
332 | model.save_checkpoint(save_dir=args.output_dir+args.task, tag="checkpoint-best", client_state=client_state)
333 |
334 |
335 | def load_model(args, model_without_ddp, optimizer, loss_scaler):
336 | if args.resume:
337 | if args.resume.startswith('https'):
338 | checkpoint = torch.hub.load_state_dict_from_url(
339 | args.resume, map_location='cpu', check_hash=True)
340 | else:
341 | checkpoint = torch.load(args.resume, map_location='cpu')
342 | # fill in any parameters missing from the checkpoint with the model's current weights
343 | for i in model_without_ddp.state_dict():
344 | if i not in checkpoint['model']:
345 | checkpoint['model'][i] = model_without_ddp.state_dict()[i]
346 |
347 | model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
348 | print("Resume checkpoint %s" % args.resume)
349 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
350 | optimizer.load_state_dict(checkpoint['optimizer'])
351 | args.start_epoch = checkpoint['epoch'] + 1
352 | if 'scaler' in checkpoint:
353 | loss_scaler.load_state_dict(checkpoint['scaler'])
354 | print("With optim & sched!")
355 |
356 |
357 | def all_reduce_mean(x):
358 | world_size = get_world_size()
359 | if world_size > 1:
360 | x_reduce = torch.tensor(x).cuda()
361 | dist.all_reduce(x_reduce)
362 | x_reduce = x_reduce / world_size
363 | return x_reduce.item()
364 | else:
365 | return x
--------------------------------------------------------------------------------
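A sketch of the call pattern the training loops use with `NativeScalerWithGradNormCount` (toy model and data; assumes a CUDA device, since the scaler wraps torch.cuda.amp.GradScaler, and the repository's pinned PyTorch version for the torch._six import above).

import torch
import torch.nn.functional as F
from util.misc import NativeScalerWithGradNormCount as NativeScaler

device = torch.device('cuda')
model = torch.nn.Linear(8, 2).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_scaler = NativeScaler()

x = torch.randn(4, 8, device=device)
y = torch.randint(0, 2, (4,), device=device)

optimizer.zero_grad()
with torch.cuda.amp.autocast():
    loss = F.cross_entropy(model(x), y)

# scales the loss, backpropagates, clips gradients, steps the optimizer and updates the scaler
grad_norm = loss_scaler(loss, optimizer, clip_grad=1.0,
                        parameters=model.parameters(), update_grad=True)
print(loss.item(), grad_norm)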
/util/flair_dataloader/dictionary.py:
--------------------------------------------------------------------------------
1 |
2 | """
3 | definitions:
4 | This script contains dictionaries for expert knowledge prompts
5 | of several fundus image conditions for pre-training.
6 |
7 | ensemble_prompts:
8 | Also, it presents a dictionary for creating prompt ensembles for
9 | zero-shot classification and transferability.
10 |
11 | datasets/abbreviations:
12 | Finally, the script contains abbreviations of several relevant
13 | conditions used for FLAIR pre-training, and the dataset names.
14 | """
15 |
16 | # Expert knowledge definitions dictionary
17 | definitions = {"no diabetic retinopathy": ["no diabetic retinopathy", "no microaneurysms"],
18 | "mild diabetic retinopathy": ["only few microaneurysms"],
19 | "moderate diabetic retinopathy": ["many exudates near the macula",
20 | "many haemorrhages near the macula",
21 | "retinal thickening near the macula",
22 | "hard exudates",
23 | "cotton wool spots",
24 | "few severe haemorrhages"],
25 | "severe diabetic retinopathy": ["venous beading",
26 | "many severe haemorrhages",
27 | "intraretinal microvascular abnormality"],
28 | "proliferative diabetic retinopathy": ["preretinal or vitreous haemorrhage",
29 | "neovascularization"],
30 | "no referable diabetic macular edema": ["no apparent exudates"],
31 | "hard exudates": ["small white or yellowish deposits with sharp margins", "bright lesion"],
32 | "soft exudates": ["pale yellow or white areas with ill-defined edges", "cotton-wool spot",
33 | "small, whitish or grey, cloud-like, linear or serpentine, slightly elevated lesions"
34 | " with fimbriated edges"],
35 | "microaneurysms": ["small red dots"],
36 | "haemorrhages": ["dense, dark red, sharply outlined lesion"],
37 | "non clinically significant diabetic macular edema": ["presence of exudates outside the radius of one"
38 | " disc diameter from the macula center",
39 | "presence of exudates"],
40 | "age related macular degeneration": ["many small drusen", "few medium-sized drusen", "large drusen",
41 | "macular degeneration"],
42 | "media haze": ["vitreous haze", "pathological opacity", "the obscuration of fundus details by vitreous"
43 | " cells and protein exudation"],
44 | "drusens": ["yellow deposits under the retina", "numerous uniform round yellow-white lesions"],
45 | "pathologic myopia": ["anomalous disc, macular atrophy and possible tessellation"],
46 | "branch retinal vein occlusion": ["occlusion of one of the four major branch retinal veins"],
47 | "tessellation": ["large choroidal vessels at the posterior fundus"],
48 | "epiretinal membrane": ["greyish semi-translucent avascular membrane"],
49 | "laser scar": ["round or oval, yellowish-white with variable black pigment centrally",
50 | "50 to 200 micron diameter lesions"],
51 | "no laser scar": ["no laser scar"],
52 | "macular scar": ["macular scar"],
53 | "central serous retinopathy": ["subretinal fluid involving the fovea", "leakage"],
54 | "optic disc cupping": ["optic disc cupping"],
55 | "central retinal vein occlusion": ["central retinal vein occlusion"],
56 | "tortuous vessels": ["tortuous vessels"],
57 | "asteroid hyalosis": ["multiple sparking, yellow-white, and refractile opacities in the vitreous cavity",
58 | "vitreous opacities"],
59 | "optic disc pallor": ["pale yellow discoloration that can be segmental or generalized on optic disc"],
60 | "optic disc edema": ["optic disc edema"],
61 | "shunt": ["collateral vessels connecting the choroidal and the retinal vasculature",
62 | "collateral vessels of large caliber and lack of leakage"],
63 | "anterior ischemic optic neuropathy": ["anterior ischemic optic neuropathy"],
64 | "parafoveal telangiectasia": ["parafoveal telangiectasia"],
65 | "retinal traction": ["retinal traction"],
66 | "retinitis": ["retinitis"],
67 | "chorioretinitis": ["chorioretinitis"],
68 | "exudates": ["small white or yellowish white deposits with sharp margins", "bright lesion"],
69 | "retinal pigment epithelium changes": ["retinal pigment epithelium changes"],
70 | "macular hole": ["lesion in the macula", "grayish fovea"],
71 | "retinitis pigmentosa": ["pigment deposits are present in the periphery"],
72 | "cotton wool spots": ["cotton wool spots", "soft exudates"],
73 | "colobomas": ["colobomas"],
74 | "optic disc pit maculopathy": ["optic disc pit maculopathy"],
75 | "preretinal haemorrhage": ["preretinal haemorrhage"],
76 | "myelinated nerve fibers": ["myelinated nerve fibers"],
77 | "haemorrhagic retinopathy": ["haemorrhagic retinopathy"],
78 | "central retinal artery occlusion": ["central retinal artery occlusion"],
79 | "tilted disc": ["tilted disc"],
80 | "cystoid macular edema": ["cysts in the macula region"],
81 | "post traumatic choroidal rupture": ["post traumatic choroidal rupture"],
82 | "choroidal folds": ["choroidal folds"],
83 | "vitreous haemorrhage": ["vitreous haemorrhage"],
84 | "macroaneurysm": ["macroaneurysm"],
85 | "vasculitis": ["vasculitis"],
86 | "branch retinal artery occlusion": ["branch retinal artery occlusion"],
87 | "plaque": ["plaque"],
88 | "haemorrhagic pigment epithelial detachment": ["haemorrhagic pigment epithelial detachment"],
89 | "collaterals": ["collaterals"],
90 | "normal": ["healthy", "no findings", "no lesion signs", "no glaucoma", "no retinopathy"],
91 | "large optic cup": ["abnormality in optic cup"],
92 | "retina detachment": ["retina detachment"],
93 | "Vogt-Koyanagi syndrome": ["Vogt-Koyanagi syndrome"],
94 | "maculopathy": ["maculopathy"],
95 | "glaucoma": ["optic nerve abnormalities", "abnormal size of the optic cup",
96 | "anomalous size in the optic disc"],
97 | "optic atrophy": ["optic atrophy"],
98 | "severe hypertensive retinopathy": ["flame shaped hemorrhages at the disc margin, blurred disc margins,"
99 | " congested retinal veins, papilledema, and secondary macular "
100 | "exudates", "arterio-venous crossing changes, macular star and "
101 | "cotton wool spots"],
102 | "disc swelling and elevation": ["disc swelling and elevation"],
103 | "dragged disk": ["dragged disk"],
104 | "congenital disk abnormality": ["disk abnormality", "optic disk lesion"],
105 | "Bietti crystalline dystrophy": ["Bietti crystalline dystrophy"],
106 | "peripheral retinal degeneration and break": ["peripheral retinal degeneration and break"],
107 | "neoplasm": ["neoplasm"],
108 | "yellow-white spots flecks": ["yellow-white spots flecks"],
109 | "fibrosis": ["fibrosis"],
110 | "silicon oil": ["silicon oil"],
111 | "no proliferative diabetic retinopathy": ["diabetic retinopathy with no neovascularization",
112 | "no neovascularization"],
113 | "no glaucoma": ["no glaucoma"],
114 | "cataract": ["opacity in the macular area"],
115 | "hypertensive retinopathy": ["possible signs of haemorraghe with blot, dot, or flame-shaped",
116 | "possible presence of microaneurysm, cotton-wool spot, or hard exudate",
117 | "arteriolar narrowing", "vascular wall changes", "optic disk edema"],
118 | "neovascular age related macular degeneration": ["neovascular age-related macular degeneration"],
119 | "geographical age related macular degeneration": ["geographical age-related macular degeneration"],
120 | "acute central serous retinopathy": ["acute central serous retinopathy"],
121 | "chronic central serous retinopathy": ["chronic central serous retinopathy"],
122 | "no cataract": ["no cataract signs", "no obscure opacities"],
123 | "abnormal optic disc": ["abnormal optic disc"],
124 | "abnormal vessels": ["abnormal vessels"],
125 | "abnormal macula": ["abnormal macula"],
126 | "macular edema": ["macular edema"],
127 | "scar": ["scar"],
128 | "nevus": ["darkly pigmented lesion found in the back of the eye"],
129 | "increased cup disc": ["increased cup disc"],
130 | "intraretinal microvascular abnormalities": ["shunt vessels and appear as abnormal branching or"
131 | " dilation of existing blood vessels (capillaries) "
132 | "within the retina", "deeper in the retina than"
133 | " neovascularization, has blurrier edges, is more"
134 | " of a burgundy than a red, does not appear on the "
135 | "optic disc", "vascular loops confined within the"
136 | " retina"],
137 | "red small dots": ["microaneurysms"],
138 | "neovascularisation": ["neovascularisation"],
139 | "a disease": ["no healthy", "lesions"],
140 | "superficial haemorrhages": ["superficial haemorrhages"],
141 | "deep haemorrhages": ["deep haemorrhages"],
142 | "ungradable": ["no fundus", "very noisy", "noisy"],
143 | "noisy": ["noisy"],
144 | "normal macula": ["normal macula"],
145 | "macular degeneration": ["macular degeneration"],
146 | "diabetic retinopathy": ["diabetic retinopathy"],
147 | "no hypertensive retinopathy": ["no presence of hypertensive retinopathy"],
148 | "mild hypertensive retinopathy": ["mild arteriovenous ratio", "mild tortuosity",
149 | "focal arteriolar narrowing",
150 | "arteriovenous nicking"],
151 | "moderate hypertensive retinopathy": ["moderate arteriovenous ratio", "moderate tortuosity",
152 | "cotton wool spots",
153 | "flame-shaped haemorrhages"],
154 | "malignant hypertensive retinopathy": ["severe arteriovenous ratio", "severe tortuosity",
155 | "swelling optical disk",
156 | "flame-shaped haemorrhages"]
157 | }
158 |
159 | # Datasets names
160 | datasets = ["01_EYEPACS", "03_IDRID", "04_RFMid", "05_1000x39", "07_LAG", "09_PAPILA", "10_PARAGUAY", "12_ARIA",
161 | "14_AGAR300", "15_APTOS", "16_FUND-OCT", "17_DiaRetDB1", "18_DRIONS-DB", "19_Drishti-GS1", "20_E-ophta",
162 | "20_E-ophta", "21_G1020", "23_HRF", "24_ORIGA", "25_REFUGE", "26_ROC", "27_BRSET", "28_OIA-DDR",
163 | "02_MESIDOR", "05_20x3", "08_ODIR200x3", "13_FIVES"]
164 |
165 | # Categories abbreviations
166 | abbreviations = {"no diabetic retinopathy": "noDR", "mild diabetic retinopathy": "mildDR",
167 | "moderate diabetic retinopathy": "modDR", "severe diabetic retinopathy": "sevDR",
168 | "proliferative diabetic retinopathy": "prolDR", "diabetic macular edema": "DME",
169 | "no referable diabetic macular edema": "noDME", "hard exudates": "hEX",
170 | "soft exudates": "sEX", "microaneurysms": "MA", "haemorrhages": "HE",
171 | "non clinically significant diabetic macular edema": "nonCSDME",
172 | "age-related macular degeneration": "ARMD", "media haze": "MH", "drusens": "DN",
173 | "pathologic myopia": "MYA", "branch retinal vein occlusion": "BRVO", "tessellation": "TSLN",
174 | "epiretinal membrane": "ERM", "laser scar": "LS", "macular scar": "MS",
175 | "central serous retinopathy": "CSR", "optic disc cupping": "ODC",
176 | "central retinal vein occlusion": "CRVO", "tortuous vessels": "TV", "asteroid hyalosis": "AH",
177 | "optic disc pallor": "ODP", "optic disc edema": "ODE",
178 | "shunt": "ST", "anterior ischemic optic neuropathy": "AION", "parafoveal telangiectasia": "PT",
179 | "retinal traction": "RT", "retinitis": "RS", "chorioretinitis": "CRS", "exudates": "EX",
180 | "retinal pigment epithelium changes": "RPEC", "macular hole": "MHL", "retinitis pigmentosa": "RP",
181 | "cotton wool spots": "CWS", "colobomas": "CB", "optic disc pit maculopathy": "ODM",
182 | "preretinal haemorrhage": "PRH", "myelinated nerve fibers": "MNF", "haemorrhagic retinopathy": "HR",
183 | "central retinal artery occlusion": "CRAO", "tilted disc": "TD", "cystoid macular edema": "CME",
184 | "post traumatic choroidal rupture": "PTCR", "choroidal folds": "CF", "vitreous haemorrhage": "VH",
185 | "macroaneurysm": "MCA", "vasculitis": "VS", "branch retinal artery occlusion": "BRAO", "plaque": "PLQ",
186 | "haemorrhagic pigment epithelial detachment": "HPED", "collaterals": "CL", "normal": "N",
187 | "large optic cup": "LOC", "retina detachment": "RD", "Vogt-Koyanagi syndrome": "VKH",
188 | "maculopathy": "M", "glaucoma": "G", "optic atrophy": "OA", "severe hypertensive retinopathy": "sevHR",
189 | "disc swelling and elevation": "DSE", "dragged disk": "DD", "congenital disk abnormality": "CDA",
190 | "Bietti crystalline dystrophy": "BCD", "peripheral retinal degeneration and break": "PRDB",
191 | "neoplasm": "NP", "yellow-white spots flecks": "YWSF", "fibrosis": "fibrosis", "silicon oil": "SO",
192 | "no proliferative diabetic retinopathy": "noProlDR", "no glaucoma": "noG", "cataract": "CAT",
193 | "hypertensive retinopathy": "HR", "neovascular age-related macular degeneration": "neovARMD",
194 | "geographical age-related macular degeneration": "geoARMD",
195 | "acute central serous retinopathy": "acCSR", "chronic central serous retinopathy": "chCSR",
196 | "no cataract": "noCAT", "abnormal optic disc": "AOD", "abnormal vessels": "AV",
197 | "abnormal macula": "AM", "macular edema": "ME", "scar": "S", "nevus": "NE",
198 | "increased cup disc": "ICD", "intraretinal microvascular abnormalities": "IrMA",
199 | "red small dots": "ReSD", "neovascularisation": "neoV", "a disease": "Dis"}
--------------------------------------------------------------------------------
/main_finetune.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | # Partly revised by YZ @UCL&Moorfields
4 | # --------------------------------------------------------
5 |
6 | import argparse
7 | import datetime
8 | import json
9 | import numpy as np
10 | import os
11 | import time
12 | from pathlib import Path
13 |
14 | import torch
15 | import torch.backends.cudnn as cudnn
16 | from torch.utils.tensorboard import SummaryWriter
17 |
18 | import timm
19 |
20 | assert timm.__version__ == "0.3.2" # version check
21 | from timm.models.layers import trunc_normal_
22 | from timm.data.mixup import Mixup
23 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
24 |
25 | import util.lr_decay as lrd
26 | import util.misc as misc
27 | from finetune.datasets_finetune import build_dataset, build_transform
28 | from util.pos_embed import interpolate_pos_embed
29 | from util.misc import NativeScalerWithGradNormCount as NativeScaler
30 |
31 | import finetune.models_vit as models_vit
32 |
33 | from finetune.engine_finetune import train_one_epoch, evaluate
34 | from torchvision import datasets
35 | import random
36 | from torch.utils.data import Subset
37 |
38 | def get_args_parser():
39 | parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False)
40 | parser.add_argument('--batch_size', default=64, type=int,
41 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
42 | parser.add_argument('--epochs', default=50, type=int)
43 | parser.add_argument('--accum_iter', default=1, type=int,
44 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
45 |
46 | # Model parameters
47 | parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
48 | help='Name of model to train')
49 |
50 | parser.add_argument('--input_size', default=224, type=int,
51 | help='images input size')
52 |
53 | parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
54 | help='Drop path rate (default: 0.1)')
55 |
56 | # Optimizer parameters
57 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
58 | help='Clip gradient norm (default: None, no clipping)')
59 | parser.add_argument('--weight_decay', type=float, default=0.05,
60 | help='weight decay (default: 0.05)')
61 |
62 | parser.add_argument('--lr', type=float, default=None, metavar='LR',
63 | help='learning rate (absolute lr)')
64 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
65 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
66 | parser.add_argument('--layer_decay', type=float, default=0.75,
67 | help='layer-wise lr decay from ELECTRA/BEiT')
68 |
69 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
70 | help='lower lr bound for cyclic schedulers that hit 0')
71 |
72 | parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
73 | help='epochs to warmup LR')
74 |
75 | # Augmentation parameters
76 | parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
77 | help='Color jitter factor (enabled only when not using Auto/RandAug)')
78 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
79 | help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
80 | parser.add_argument('--smoothing', type=float, default=0.1,
81 | help='Label smoothing (default: 0.1)')
82 |
83 | # * Random Erase params
84 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
85 | help='Random erase prob (default: 0.25)')
86 | parser.add_argument('--remode', type=str, default='pixel',
87 | help='Random erase mode (default: "pixel")')
88 | parser.add_argument('--recount', type=int, default=1,
89 | help='Random erase count (default: 1)')
90 | parser.add_argument('--resplit', action='store_true', default=False,
91 | help='Do not random erase first (clean) augmentation split')
92 |
93 | # * Mixup params
94 | parser.add_argument('--mixup', type=float, default=0,
95 | help='mixup alpha, mixup enabled if > 0.')
96 | parser.add_argument('--cutmix', type=float, default=0,
97 | help='cutmix alpha, cutmix enabled if > 0.')
98 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
99 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
100 | parser.add_argument('--mixup_prob', type=float, default=1.0,
101 | help='Probability of performing mixup or cutmix when either/both is enabled')
102 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
103 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
104 | parser.add_argument('--mixup_mode', type=str, default='batch',
105 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
106 |
107 | # * Finetuning params
108 | parser.add_argument('--finetune', default='', type=str,
109 | help='finetune from checkpoint')
110 | parser.add_argument('--task', default='', type=str,
111 | help='task name appended to the log/output directory for this run')
112 | parser.add_argument('--global_pool', action='store_true')
113 | parser.set_defaults(global_pool=True)
114 | parser.add_argument('--cls_token', action='store_false', dest='global_pool',
115 | help='Use class token instead of global pool for classification')
116 |
117 | # Dataset parameters
118 | parser.add_argument('--data_path', default='/home/jupyter/Mor_DR_data/data/data/IDRID/Disease_Grading/', type=str,
119 | help='dataset path')
120 | parser.add_argument('--nb_classes', default=1000, type=int,
121 | help='number of classes for classification')
122 |
123 | parser.add_argument('--output_dir', default='./output_dir',
124 | help='path where to save, empty for no saving')
125 | parser.add_argument('--log_dir', default='./output_dir',
126 | help='path where to tensorboard log')
127 | parser.add_argument('--device', default='cuda',
128 | help='device to use for training / testing')
129 | parser.add_argument('--seed', default=0, type=int)
130 | parser.add_argument('--resume', default='',
131 | help='resume from checkpoint')
132 |
133 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
134 | help='start epoch')
135 | parser.add_argument('--eval', action='store_true',
136 | help='Perform evaluation only')
137 | parser.add_argument('--dist_eval', action='store_true', default=False,
138 | help='Enable distributed evaluation (recommended during training for faster monitoring)')
139 | parser.add_argument('--num_workers', default=10, type=int)
140 | parser.add_argument('--pin_mem', action='store_true',
141 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
142 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
143 | parser.set_defaults(pin_mem=True)
144 |
145 | # distributed training parameters
146 | parser.add_argument('--world_size', default=1, type=int,
147 | help='number of distributed processes')
148 | parser.add_argument('--local_rank', default=-1, type=int)
149 | parser.add_argument('--dist_on_itp', action='store_true')
150 | parser.add_argument('--dist_url', default='env://',
151 | help='url used to set up distributed training')
152 |
153 | parser.add_argument('--partial_p', type=str, default='',
154 | help='optional suffix selecting a partial training split (loads the train_<partial_p> split when set)')
154 | return parser
155 |
156 | def main(args):
157 | misc.init_distributed_mode(args)
158 |
159 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
160 | print("{}".format(args).replace(', ', ',\n'))
161 |
162 | device = torch.device(args.device)
163 |
164 | # fix the seed for reproducibility
165 | seed = args.seed + misc.get_rank()
166 | torch.manual_seed(seed)
167 | np.random.seed(seed)
168 |
169 | cudnn.benchmark = True
170 |
171 | if args.partial_p:
172 | dataset_train = build_dataset(is_train=f'train_{args.partial_p}', args=args)
173 | else:
174 | dataset_train = build_dataset(is_train='train', args=args)
175 | dataset_val = build_dataset(is_train='val', args=args)
176 | dataset_test = build_dataset(is_train='test', args=args)
177 |
178 | if True: # args.distributed:
179 | num_tasks = misc.get_world_size()
180 | global_rank = misc.get_rank()
181 | sampler_train = torch.utils.data.DistributedSampler(
182 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
183 | )
184 | print("Sampler_train = %s" % str(sampler_train))
185 | if args.dist_eval:
186 | if len(dataset_val) % num_tasks != 0:
187 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
188 | 'This will slightly alter validation results as extra duplicate entries are added to achieve '
189 | 'equal num of samples per-process.')
190 | sampler_val = torch.utils.data.DistributedSampler(
191 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias
192 | else:
193 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
194 |
195 | if args.dist_eval:
196 | if len(dataset_test) % num_tasks != 0:
197 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
198 | 'This will slightly alter validation results as extra duplicate entries are added to achieve '
199 | 'equal num of samples per-process.')
200 | sampler_test = torch.utils.data.DistributedSampler(
201 | dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias
202 | else:
203 | sampler_test = torch.utils.data.SequentialSampler(dataset_test)
204 |
205 |
206 | args.log_dir = args.output_dir  # override --log_dir so TensorBoard logs are written alongside the outputs
207 | if global_rank == 0 and args.log_dir is not None and not args.eval:
208 | # os.makedirs(args.log_dir, exist_ok=True)
209 | log_writer = SummaryWriter(log_dir=args.log_dir+args.task)
210 | else:
211 | log_writer = None
212 |
213 | data_loader_train = torch.utils.data.DataLoader(
214 | dataset_train, sampler=sampler_train,
215 | batch_size=args.batch_size,
216 | num_workers=args.num_workers,
217 | pin_memory=args.pin_mem,
218 | drop_last=False,
219 | )
220 |
221 | data_loader_val = torch.utils.data.DataLoader(
222 | dataset_val, sampler=sampler_val,
223 | batch_size=args.batch_size,
224 | num_workers=args.num_workers,
225 | pin_memory=args.pin_mem,
226 | drop_last=False
227 | )
228 |
229 | data_loader_test = torch.utils.data.DataLoader(
230 | dataset_test, sampler=sampler_test,
231 | batch_size=args.batch_size,
232 | num_workers=args.num_workers,
233 | pin_memory=args.pin_mem,
234 | drop_last=False
235 | )
236 |
237 | mixup_fn = None
238 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
239 | if mixup_active:
240 | print("Mixup is activated!")
241 | mixup_fn = Mixup(
242 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
243 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
244 | label_smoothing=args.smoothing, num_classes=args.nb_classes)
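# Note: when mixup/cutmix is active the targets become soft label distributions,
# which is why SoftTargetCrossEntropy is selected as the training criterion below.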
245 |
246 | model = models_vit.__dict__[args.model](
247 | img_size=args.input_size,
248 | num_classes=args.nb_classes,
249 | drop_path_rate=args.drop_path,
250 | global_pool=args.global_pool,
251 | )
252 |
253 | if args.finetune and not args.eval:
254 | checkpoint = torch.load(args.finetune, map_location='cpu')
255 |
256 | print("Load pre-trained checkpoint from: %s" % args.finetune)
257 | checkpoint_model = checkpoint['model']
258 | state_dict = model.state_dict()
259 | for k in ['head.weight', 'head.bias']:
260 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
261 | print(f"Removing key {k} from pretrained checkpoint")
262 | del checkpoint_model[k]
263 |
264 | # interpolate position embedding
265 | interpolate_pos_embed(model, checkpoint_model)
266 |
267 | # load pre-trained model
268 | msg = model.load_state_dict(checkpoint_model, strict=False)
269 | print(msg)
270 |
271 | # if args.global_pool:
272 | # assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}
273 | # else:
274 | # assert set(msg.missing_keys) == {'head.weight', 'head.bias'}
275 |
276 | # manually initialize fc layer
277 | trunc_normal_(model.head.weight, std=2e-5)
278 |
279 | model.to(device)
280 |
281 | model_without_ddp = model
282 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
283 |
284 | print("Model = %s" % str(model_without_ddp))
285 | print('number of params (M): %.2f' % (n_parameters / 1.e6))
286 |
287 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
288 |
289 | if args.lr is None: # only base_lr is specified
290 | args.lr = args.blr * eff_batch_size / 256
291 |
292 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
293 | print("actual lr: %.2e" % args.lr)
294 |
295 | print("accumulate grad iterations: %d" % args.accum_iter)
296 | print("effective batch size: %d" % eff_batch_size)
297 |
298 | if args.distributed:
299 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
300 | model_without_ddp = model.module
301 |
302 | # build optimizer with layer-wise lr decay (lrd)
303 | param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay,
304 | no_weight_decay_list=model_without_ddp.no_weight_decay(),
305 | layer_decay=args.layer_decay
306 | )
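# Note: param_groups_lrd (see util/lr_decay.py) attaches an lr_scale to each parameter
# group, roughly layer_decay ** (num_layers - layer_id), so earlier transformer blocks
# receive smaller updates than the classification head.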
307 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
308 | loss_scaler = NativeScaler()
309 |
310 | if mixup_fn is not None:
311 | # smoothing is handled with mixup label transform
312 | criterion = SoftTargetCrossEntropy()
313 | elif args.smoothing > 0.:
314 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
315 | else:
316 | criterion = torch.nn.CrossEntropyLoss()
317 |
318 | print("criterion = %s" % str(criterion))
319 |
320 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
321 |
322 | if args.eval:
323 | test_stats, auc_roc = evaluate(data_loader_test, model, device, args.log_dir + args.task, epoch=0, mode='test', num_class=args.nb_classes)
324 | exit(0)
325 |
326 | print(f"Start training for {args.epochs} epochs")
327 | start_time = time.time()
328 | max_accuracy = 0.0
329 | max_auc = 0.0
330 | for epoch in range(args.start_epoch, args.epochs):
331 | if args.distributed:
332 | data_loader_train.sampler.set_epoch(epoch)
333 | train_stats = train_one_epoch(
334 | model, criterion, data_loader_train,
335 | optimizer, device, epoch, loss_scaler,
336 | args.clip_grad, mixup_fn,
337 | log_writer=log_writer,
338 | args=args
339 | )
340 |
341 | val_stats, val_auc_roc = evaluate(data_loader_val, model, device, args.log_dir + args.task, epoch, mode='val', num_class=args.nb_classes)
342 | if max_auc