├── common
│   ├── __init__.py
│   ├── w0_utils.py
│   ├── args.py
│   └── utils.py
├── data
│   ├── __init__.py
│   ├── preproc_lidc-idri.py
│   └── dataset.py
├── eval
│   ├── __init__.py
│   ├── maml_full_fit.py
│   ├── maml_full_eval.py
│   └── maml_scale.py
├── models
│   ├── __init__.py
│   ├── layers.py
│   ├── inrs.py
│   └── model_wrapper.py
├── train
│   ├── __init__.py
│   ├── trainer.py
│   └── maml_boot.py
├── assets
│   ├── MedNF.png
│   └── overview.png
├── environment.yaml
├── configs
│   ├── experiments
│   │   ├── 1d_timeseries
│   │   │   └── default.yaml
│   │   ├── 3d_imgs
│   │   │   └── default.yaml
│   │   └── 2d_imgs
│   │       ├── default_224.yaml
│   │       ├── default_128.yaml
│   │       └── default_64.yaml
│   ├── eval
│   │   ├── 1d_timeseries
│   │   │   └── default.yaml
│   │   ├── 3d_imgs
│   │   │   └── default.yaml
│   │   └── 2d_imgs
│   │       ├── default_128.yaml
│   │       ├── default_224.yaml
│   │       └── default_64.yaml
│   └── fit
│       └── default_64.yaml
├── LICENSE
├── eval.py
├── .gitignore
├── fit_NFset.py
├── train.py
├── README.md
└── downstream_tasks
    └── classification.py
/common/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/eval/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/train/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/MedNF.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfriedri/medfuncta/HEAD/assets/MedNF.png
--------------------------------------------------------------------------------
/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfriedri/medfuncta/HEAD/assets/overview.png
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: medfuncta
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | dependencies:
6 | - python=3.10
7 | - pytorch=2.5.1
8 | - torchvision=0.20.1
9 | - numpy=2.1.3
10 | - matplotlib=3.10.1
11 | - pip
12 | - pip:
13 | - tensorboard
14 | - pandas
15 | - nibabel
16 | - einops
17 | - medmnist
18 | - dicom2nifti
19 | - pytorch_msssim
20 | - lpips
21 | - torchmetrics
22 | - opencv-python
--------------------------------------------------------------------------------
/configs/experiments/1d_timeseries/default.yaml:
--------------------------------------------------------------------------------
1 | {
2 | # Training configuration
3 | 'dataset': ecg,
4 | 'img_size': 187,
5 | 'batch_size': 64,
6 | 'inner_steps': 10,
7 | 'max_iter': 250000,
8 | 'lr_scheduler': True,
9 |
10 | # Context reduction
11 | 'data_ratio': 1.0,
12 | 'sample_type': random,
13 |
14 | # Omega schedule
15 | 'w0_sched_type': linear,
16 | 'w0': 20,
17 | 'wK': 200,
18 |
19 | # Test configuration
20 | 'test_batch_size': 64,
21 | 'num_test_signals': 256,
22 | 'inner_steps_test': 20,
23 |
24 | # Model config
25 | 'min_hidden_dim': 64,
26 | 'max_hidden_dim': 64,
27 | 'progression_type': linear,
28 | 'num_layers': 8,
29 | 'latent_modulation_dim': 64,
30 | }
--------------------------------------------------------------------------------
/configs/experiments/3d_imgs/default.yaml:
--------------------------------------------------------------------------------
1 | {
2 |     # Training configuration
3 | 'task': rec,
4 | 'dataset': brats,
5 | 'img_size': 32,
6 | 'batch_size': 4,
7 | 'inner_steps': 10,
8 | 'max_iter': 250000,
9 | 'lr_scheduler': True,
10 |
11 | # Context reduction
12 | 'data_ratio': 0.25,
13 | 'sample_type': random,
14 |
15 | # Omega schedule
16 | 'w0_sched_type': linear,
17 | 'w0': 20,
18 | 'wK': 300,
19 |
20 | # Test configuration
21 | 'test_batch_size': 4,
22 | 'num_test_signals': 64,
23 | 'inner_steps_test': 20,
24 |
25 | # Model config
26 | 'min_hidden_dim': 256,
27 | 'max_hidden_dim': 256,
28 | 'progression_type': linear,
29 | 'num_layers': 15,
30 | 'latent_modulation_dim': 8192,
31 | }
--------------------------------------------------------------------------------
/configs/experiments/2d_imgs/default_224.yaml:
--------------------------------------------------------------------------------
1 | {
2 |     # Training configuration
3 | 'task': rec,
4 | 'dataset': chestmnist,
5 | 'img_size': 224,
6 | 'batch_size': 4,
7 | 'inner_steps': 10,
8 | 'max_iter': 250000,
9 | 'lr_scheduler': True,
10 |
11 | # Context reduction
12 | 'data_ratio': 0.1,
13 | 'sample_type': random,
14 |
15 | # Omega schedule
16 | 'w0_sched_type': linear,
17 | 'w0': 30,
18 | 'wK': 300,
19 |
20 | # Test configuration
21 | 'test_batch_size': 4,
22 | 'num_test_signals': 256,
23 | 'inner_steps_test': 20,
24 |
25 | # Model config
26 | 'min_hidden_dim': 256,
27 | 'max_hidden_dim': 256,
28 | 'progression_type': linear,
29 | 'num_layers': 15,
30 | 'latent_modulation_dim': 16384,
31 | }
--------------------------------------------------------------------------------
/configs/experiments/2d_imgs/default_128.yaml:
--------------------------------------------------------------------------------
1 | {
2 | # Training configuration
3 | 'task': rec,
4 | 'dataset': chestmnist,
5 | 'img_size': 128,
6 | 'batch_size': 8,
7 | 'inner_steps': 10,
8 | 'max_iter': 250000,
9 | 'lr_scheduler': True,
10 |
11 | # Context reduction
12 | 'data_ratio': 0.25,
13 | 'sample_type': random,
14 |
15 | # Omega schedule
16 | 'w0_sched_type': linear,
17 | 'w0': 30,
18 | 'wK': 300,
19 |
20 | # Test configuration
21 | 'test_batch_size': 8,
22 | 'num_test_signals': 256,
23 | 'inner_steps_test': 20,
24 |
25 | # Model config
26 | 'min_hidden_dim': 256,
27 | 'max_hidden_dim': 256,
28 | 'progression_type': linear,
29 | 'num_layers': 15,
30 | 'latent_modulation_dim': 8192,
31 | }
--------------------------------------------------------------------------------
/configs/experiments/2d_imgs/default_64.yaml:
--------------------------------------------------------------------------------
1 | {
2 | # Training configuration
3 | 'task': rec,
4 | 'dataset': chestmnist,
5 | 'img_size': 64,
6 | 'batch_size': 24,
7 | 'inner_steps': 10,
8 | 'max_iter': 250000,
9 | 'lr_scheduler': True,
10 |
11 | # Context reduction
12 | 'data_ratio': 0.25,
13 | 'sample_type': random,
14 |
15 | # Omega schedule
16 | 'w0_sched_type': linear,
17 | 'w0': 20,
18 | 'wK': 400,
19 |
20 | # Test configuration
21 | 'test_batch_size': 24,
22 | 'num_test_signals': 256,
23 | 'inner_steps_test': 20,
24 |
25 | # Model config
26 | 'min_hidden_dim': 256,
27 | 'max_hidden_dim': 256,
28 | 'progression_type': linear,
29 | 'num_layers': 15,
30 | 'latent_modulation_dim': 2048,
31 | }
--------------------------------------------------------------------------------
/configs/eval/1d_timeseries/default.yaml:
--------------------------------------------------------------------------------
1 | {
2 | # Training configuration
3 | 'dataset': ecg,
4 | 'img_size': 187,
5 | 'batch_size': 64,
6 | 'inner_steps': 10,
7 | 'max_iter': 250000,
8 | 'lr_scheduler': True,
9 |
10 | # Context reduction
11 | 'data_ratio': 1.0,
12 | 'sample_type': random,
13 |
14 | # Omega schedule
15 | 'w0_sched_type': linear,
16 | 'w0': 20,
17 | 'wK': 200,
18 |
19 | # Test configuration
20 | 'test_batch_size': 64,
21 | 'num_test_signals': 256,
22 | 'inner_steps_test': 20,
23 |
24 | # Model config
25 | 'min_hidden_dim': 64,
26 | 'max_hidden_dim': 64,
27 | 'progression_type': linear,
28 | 'num_layers': 8,
29 | 'latent_modulation_dim': 64,
30 |
31 | # Load model from checkpoint
32 | 'load_path': /path/to/best.model or /path/to/stepXXXX.model,
33 | }
--------------------------------------------------------------------------------
/configs/eval/3d_imgs/default.yaml:
--------------------------------------------------------------------------------
1 | {
2 |     # Training configuration
3 | 'dataset': brats,
4 | 'img_size': 32,
5 | 'batch_size': 4,
6 | 'inner_steps': 10,
7 | 'max_iter': 250000,
8 | 'lr_scheduler': True,
9 |
10 | # Context reduction
11 | 'data_ratio': 0.25,
12 | 'sample_type': random,
13 |
14 | # Omega schedule
15 | 'w0_sched_type': linear,
16 | 'w0': 20,
17 | 'wK': 300,
18 |
19 | # Test configuration
20 | 'test_batch_size': 4,
21 | 'num_test_signals': 64,
22 | 'inner_steps_test': 20,
23 |
24 | # Model config
25 | 'min_hidden_dim': 256,
26 | 'max_hidden_dim': 256,
27 | 'progression_type': linear,
28 | 'num_layers': 15,
29 | 'latent_modulation_dim': 8192,
30 |
31 | # Load model from checkpoint
32 | 'load_path': /path/to/best.model or /path/to/stepXXXX.model,
33 | }
--------------------------------------------------------------------------------
/configs/eval/2d_imgs/default_128.yaml:
--------------------------------------------------------------------------------
1 | {
2 |     # Training configuration
3 | 'dataset': chestmnist,
4 | 'img_size': 128,
5 | 'batch_size': 8,
6 | 'inner_steps': 10,
7 | 'max_iter': 250000,
8 | 'lr_scheduler': True,
9 |
10 | # Context reduction
11 | 'data_ratio': 0.25,
12 | 'sample_type': random,
13 |
14 | # Omega schedule
15 | 'w0_sched_type': linear,
16 | 'w0': 30,
17 | 'wK': 300,
18 |
19 | # Test configuration
20 | 'test_batch_size': 8,
21 | 'num_test_signals': 256,
22 | 'inner_steps_test': 20,
23 |
24 | # Model config
25 | 'min_hidden_dim': 256,
26 | 'max_hidden_dim': 256,
27 | 'progression_type': linear,
28 | 'num_layers': 15,
29 | 'latent_modulation_dim': 8192,
30 |
31 | # Load model from checkpoint
32 | 'load_path': /path/to/best.model or /path/to/stepXXXX.model,
33 | }
--------------------------------------------------------------------------------
/configs/eval/2d_imgs/default_224.yaml:
--------------------------------------------------------------------------------
1 | {
2 | # Training configuration
3 | 'dataset': chestmnist,
4 | 'img_size': 224,
5 | 'batch_size': 4,
6 | 'inner_steps': 10,
7 | 'max_iter': 250000,
8 | 'lr_scheduler': True,
9 |
10 | # Context reduction
11 | 'data_ratio': 0.1,
12 | 'sample_type': random,
13 |
14 | # Omega schedule
15 | 'w0_sched_type': linear,
16 | 'w0': 30,
17 | 'wK': 300,
18 |
19 | # Test configuration
20 | 'test_batch_size': 4,
21 | 'num_test_signals': 256,
22 | 'inner_steps_test': 20,
23 |
24 | # Model config
25 | 'min_hidden_dim': 256,
26 | 'max_hidden_dim': 256,
27 | 'progression_type': linear,
28 | 'num_layers': 15,
29 | 'latent_modulation_dim': 16384,
30 |
31 | # Load model from checkpoint
32 | 'load_path': /path/to/best.model or /path/to/stepXXXX.model,
33 | }
--------------------------------------------------------------------------------
/configs/eval/2d_imgs/default_64.yaml:
--------------------------------------------------------------------------------
1 | {
2 | # Training configuration
3 | 'dataset': chestmnist,
4 | 'img_size': 64,
5 | 'batch_size': 24,
6 | 'inner_steps': 10,
7 | 'max_iter': 250000,
8 | 'lr_scheduler': True,
9 |
10 | # Context reduction
11 | 'data_ratio': 0.25,
12 | 'sample_type': random,
13 |
14 | # Omega schedule
15 | 'w0_sched_type': linear,
16 | 'w0': 20,
17 | 'wK': 400,
18 |
19 | # Test configuration
20 | 'test_batch_size': 24,
21 | 'num_test_signals': 256,
22 | 'inner_steps_test': 20,
23 |
24 | # Model config
25 | 'min_hidden_dim': 256,
26 | 'max_hidden_dim': 256,
27 | 'progression_type': linear,
28 | 'num_layers': 15,
29 | 'latent_modulation_dim': 2048,
30 |
31 | # Load model from checkpoint
32 | 'load_path': /path/to/best.model or /path/to/stepXXXX.model,
33 | }
--------------------------------------------------------------------------------
/configs/fit/default_64.yaml:
--------------------------------------------------------------------------------
1 | {
2 | # Training configuration
3 | 'task': rec,
4 | 'dataset': chestmnist,
5 | 'img_size': 64,
6 | 'batch_size': 1,
7 | 'inner_steps': 10,
8 | 'max_iter': 250000,
9 | 'lr_scheduler': True,
10 |
11 | # Context reduction
12 | 'data_ratio': 0.25,
13 | 'sample_type': random,
14 |
15 | # Omega schedule
16 | 'w0_sched_type': linear,
17 | 'w0': 20,
18 | 'wK': 400,
19 |
20 | # Test configuration
21 | 'test_batch_size': 1,
22 | 'num_test_signals': 256,
23 | 'inner_steps_test': 20,
24 |
25 | # Model config
26 | 'min_hidden_dim': 256,
27 | 'max_hidden_dim': 256,
28 | 'progression_type': linear,
29 | 'num_layers': 15,
30 | 'latent_modulation_dim': 2048,
31 |
32 | # Load model from checkpoint
33 | 'load_path': /path/to/best.model or /path/to/stepXXXX.model,
34 |
35 | #Save NFs to
36 | 'save_dir': /path/to/save/medfuncta/set
37 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Paul Friedrich
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/common/w0_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import math
4 | from typing import Callable
5 |
6 | def get_w0s(args):
7 | sched_type = args.w0_sched_type
8 | num_layers = args.num_layers - 1
9 | device = args.device
10 |
11 | def linear_schedule(i):
12 | w0_0 = args.w0
13 | w0_K = args.wK
14 | a = (w0_K - w0_0) / (num_layers - 1)
15 | b = w0_0
16 | return a * i + b
17 |
18 | def exponential_schedule(i):
19 | w0_0 = args.w0
20 | w0_K = args.wK
21 | a = w0_0
22 | b = math.log(w0_K / w0_0) / (num_layers - 1)
23 | return a * torch.exp(b * i)
24 |
25 | def const_manual_schedule(i):
26 | return torch.tensor(args.w0, device=device)
27 |
28 | sched_map: dict[str, Callable[[torch.Tensor], torch.Tensor]] = {
29 | 'linear': linear_schedule,
30 | 'exponential': exponential_schedule,
31 | }
32 |
33 | w0_fn = sched_map.get(sched_type, const_manual_schedule)
34 | w0s = [w0_fn(torch.tensor(i, device=device)) for i in range(num_layers)]
35 |
36 | return w0s
37 |
38 | def save_w0s(w0s, logger):
39 | logdir = logger.logdir
40 | w0_path = os.path.join(logdir, 'w0s.sched')
41 | w0s = [t.cpu() for t in w0s]
42 | torch.save(w0s, w0_path)
43 |
44 | def load_w0s(args):
45 | load_dir = os.path.join(os.path.dirname(args.load_path), 'w0s.sched')
46 | w0s = torch.load(load_dir, weights_only=True)
47 | return w0s
48 |
--------------------------------------------------------------------------------
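The schedule helpers above are easiest to understand with concrete numbers. A minimal sketch (not part of the repository), assuming the `w0`/`wK` values from `configs/experiments/2d_imgs/default_64.yaml`:

```python
from argparse import Namespace

import torch
from common.w0_utils import get_w0s

# Hypothetical args mirroring configs/experiments/2d_imgs/default_64.yaml.
args = Namespace(w0_sched_type='linear', w0=20, wK=400, num_layers=15, device='cpu')

w0s = get_w0s(args)  # one omega per modulated layer (num_layers - 1 = 14 values)
print([round(w.item(), 1) for w in w0s])  # 20.0, 49.2, ..., 370.8, 400.0
```

Each entry is a 0-d tensor, which is why `save_w0s` moves them to CPU before calling `torch.save`.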
/eval/maml_full_fit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms as transforms
3 |
4 | from tqdm import tqdm
5 | from train.maml_boot import inner_adapt_test_scale
6 |
7 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8 |
9 |
10 | def fit_nfs(args, model_wrapper, dataloader, set=None):
11 |
12 | model_wrapper.model.eval()
13 | model_wrapper.coord_init()
14 |
15 | for n, data in enumerate(tqdm(dataloader)):
16 | data, label = data
17 | data = data.to(device)
18 | batch_size = data.size(0)
19 | model_wrapper.model.reset_modulations()
20 |
21 | _ = inner_adapt_test_scale(model_wrapper=model_wrapper, data=data, step_size=args.inner_lr,
22 | num_steps=args.inner_steps_test, first_order=True, sample_type=args.sample_type,
23 | scale_type='grad')
24 |
25 | if set == 'test':
26 | with torch.no_grad():
27 | pred = model_wrapper().clamp(0, 1)
28 | if n < 100:
29 | # Convert to PIL image
30 | to_pil = transforms.ToPILImage()
31 | image = to_pil(data.squeeze())
32 | input_path = args.save_dir + f'test/imgs/{n}_input.png'
33 | image.save(input_path)
34 | image = to_pil(pred.squeeze())
35 | recon_path = args.save_dir + f'test/imgs/{n}_recon.png'
36 | image.save(recon_path)
37 |
38 |         else:
39 |             pass  # no image export for the train/val splits
40 |
41 | for i in range(batch_size):
42 | datapoint = {
43 | 'modulations': model_wrapper.model.modulations[i].detach().cpu(),
44 | 'label': label[i].detach().cpu()
45 | }
46 | sdir = args.save_dir + f'/{set}/' + f'datapoint_{(n * batch_size) + i}.pt'
47 | torch.save(datapoint, sdir)
48 | return
49 |
--------------------------------------------------------------------------------
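`fit_nfs` writes one `datapoint_*.pt` file per signal, each holding the fitted modulation vector and its label. A hedged sketch of reading one back (the directory is a placeholder following the `save_dir` convention in `configs/fit/default_64.yaml`):

```python
import torch

# Placeholder path; substitute your own save_dir.
dp = torch.load('/path/to/save/medfuncta/set/train/datapoint_0.pt')
modulations = dp['modulations']  # 1D latent vector, e.g. shape [2048]
label = dp['label']              # label tensor stored alongside it
```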
/models/layers.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 |
5 |
6 | class LatentModulatedSIRENLayer(nn.Module):
7 |     def __init__(self, in_size, out_size, latent_modulation_dim=512, w0=30.,
8 | modulate_shift=True, modulate_scale=False, is_first=False, is_last=False):
9 | super().__init__()
10 | self.in_size = in_size
11 | self.out_size = out_size
12 | self.latent_modulation_dim = latent_modulation_dim
13 | self.w0 = w0
14 | self.modulate_shift = modulate_shift
15 | self.modulate_scale = modulate_scale
16 | self.is_first = is_first
17 | self.is_last = is_last
18 |
19 | self.linear = nn.Linear(in_size, out_size, bias=True)
20 |
21 | if modulate_shift and not is_first and not is_last:
22 | self.modulate_shift_layer = nn.Linear(latent_modulation_dim, out_size)
23 | if modulate_scale and not is_first and not is_last:
24 | self.modulate_scale_layer = nn.Linear(latent_modulation_dim, out_size)
25 |
26 | self._init(w0, is_first)
27 |
28 | def _init(self, w0, is_first):
29 | dim_in = self.in_size
30 |         w_std = 1 / dim_in if is_first else math.sqrt(6.0 / dim_in) / float(w0)
31 | nn.init.uniform_(self.linear.weight, -w_std, w_std)
32 | nn.init.uniform_(self.linear.bias, -w_std, w_std)
33 |
34 | def forward(self, x, latent):
35 | x = self.linear(x)
36 |
37 | if not self.is_first and not self.is_last:
38 | shift = 0.0 if not self.modulate_shift else self.modulate_shift_layer(latent)
39 | scale = 1.0 if not self.modulate_scale else self.modulate_scale_layer(latent)
40 |
41 | if self.modulate_shift:
42 | if len(shift.shape) == 2:
43 | shift = shift.unsqueeze(dim=1)
44 | if self.modulate_scale:
45 | if len(scale.shape) == 2:
46 | scale = scale.unsqueeze(dim=1)
47 |
48 | x = scale * x + shift
49 |
50 | if not self.is_last:
51 | x = torch.sin(self.w0 * x)
52 | return x
53 |
--------------------------------------------------------------------------------
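A small usage sketch for a single hidden layer (the shapes are assumptions based on the default 2D configs). Since `_init` calls `float(w0)`, both plain floats and the 0-d tensors produced by `get_w0s` work:

```python
import torch
from models.layers import LatentModulatedSIRENLayer

layer = LatentModulatedSIRENLayer(in_size=256, out_size=256,
                                  latent_modulation_dim=2048,
                                  w0=torch.tensor(30.0))

x = torch.randn(4, 4096, 256)  # (batch, num_coordinates, features)
latent = torch.zeros(4, 2048)  # one modulation vector per signal

y = layer(x, latent)           # linear -> shift modulation -> sin(w0 * x)
print(y.shape)                 # torch.Size([4, 4096, 256])
```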
/eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 |
4 | from common.args import parse_args
5 | from common.utils import set_random_seed, load_model
6 | from common.w0_utils import load_w0s
7 | from data.dataset import get_dataset
8 | from eval.maml_full_eval import test_model
9 | from models.inrs import LatentModulatedSIREN
10 | from models.model_wrapper import ModelWrapper
11 |
12 |
13 | def main(args):
14 | """
15 | Main function to call for running an evaluation procedure (evaluate performance on test set).
16 |     :param args: parameters parsed from the command line / a config.yaml.
17 | :return: Nothing.
18 | """
19 |
20 | """ Set a device to use """
21 | if torch.cuda.is_available():
22 | torch.cuda.set_device(args.gpu_id)
23 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
24 | args.device = device
25 |
26 | """ Enable determinism """
27 | set_random_seed(args.seed)
28 | torch.backends.cudnn.deterministic = True
29 | torch.backends.cudnn.benchmark = False
30 |
31 | """ Define test dataset """
32 | test_set = get_dataset(args, only_test=True)
33 | test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False, num_workers=4, pin_memory=True,
34 | drop_last=True)
35 |
36 | """ Get w0s to initialize the model """
37 | w0s = load_w0s(args)
38 | args.w0s = w0s
39 |
40 | """ Initialize model and optimizer """
41 | model = LatentModulatedSIREN(
42 | in_size=args.in_size,
43 | out_size=args.out_size,
44 | min_hidden_size=args.min_hidden_dim,
45 | max_hidden_size=args.max_hidden_dim,
46 | progression_type=args.progression_type,
47 | num_layers=args.num_layers,
48 | latent_modulation_dim=args.latent_modulation_dim,
49 | w0s=args.w0s,
50 | modulate_shift=args.modulate_shift,
51 | modulate_scale=args.modulate_scale,
52 | enable_skip_connections=args.enable_skip_connections,
53 | ).to(device)
54 |
55 | """ Initialize modulation vectors (signal-specific parameter vector) """
56 | model.modulations = torch.zeros(size=[args.test_batch_size, args.latent_modulation_dim], requires_grad=True).to(device)
57 |
58 | """ Wrap the model """
59 | model = ModelWrapper(args, model)
60 | load_model(args, model)
61 |
62 | """ Define test function """
63 | test_model(args, model, test_loader, logger=None)
64 |
65 |
66 | if __name__ == "__main__":
67 | args = parse_args()
68 | main(args)
69 |
--------------------------------------------------------------------------------
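Assuming a finished training run whose checkpoint path has been entered as `load_path` in the eval config, evaluation would be launched roughly like this (a sketch, not an official command):

```sh
python eval.py --config configs/eval/2d_imgs/default_64.yaml
```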
/train/trainer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import torch
4 | from common.utils import resume_training, MetricLogger, save_checkpoint, save_checkpoint_step
5 |
6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7 |
8 |
9 | def trainer(args, train_function, test_function, model_wrapper, meta_optimizer, train_loader, test_loader, logger, scheduler=None):
10 | """
11 | The main function that performs the training. Iteratively calls training steps (train_function) and evaluations (test_function).
12 | :param args: parameters parsed from the command line.
13 | :param train_function: function that performs a single meta-update step.
14 | :param test_function: function that performs the evaluation.
15 | :param model_wrapper: the wrapped model.
16 | :param meta_optimizer: optimizer used for the meta-learning (global optimizer).
17 | :param train_loader: data loader for training data.
18 | :param test_loader: data loader for testing data (usually this will be your validation set).
19 | :param logger: a logger.
20 |     :param scheduler: optional learning-rate scheduler for the meta-optimizer (None disables lr scheduling).
21 | :return: Nothing.
22 | """
23 |
24 | metric_logger = MetricLogger(delimiter=" ")
25 |
26 | """ Resume training (optional with '--resume_path' flag) """
27 | is_best, start_step, best_psnr, psnr = resume_training(args, model_wrapper, meta_optimizer)
28 |
29 | """ Start Training """
30 |     logger.log_dirname("Start training")
31 |
32 | """ Load training data """
33 | for it, train_batch in enumerate(train_loader):
34 | step = start_step + it + 1
35 | if step > args.outer_steps:
36 | break
37 |
38 | train_batch, _ = train_batch
39 | train_batch = train_batch.float().to(device, non_blocking=True)
40 |
41 | """ Perform a single meta-update training step """
42 | train_function(args, step, model_wrapper, meta_optimizer, train_batch, metric_logger, logger, scheduler)
43 |
44 | """ Evaluate and save model every eval_step steps """
45 | if step == 1 or step % args.eval_step == 0:
46 | psnr, lpips, ssim = test_function(args, step, model_wrapper, test_loader, logger)
47 |
48 | if best_psnr < psnr:
49 | best_psnr = psnr
50 | save_checkpoint(args, step, best_psnr, model_wrapper, meta_optimizer.state_dict(), logger.logdir,
51 | is_best=True)
52 |
53 | logger.scalar_summary('eval/best_psnr', best_psnr, step)
54 | logger.log('[EVAL] [Step %3d] [PSNR %5.2f] [BestPSNR %5.2f]' % (step, psnr, best_psnr))
55 |
56 | """ Save model every save_step steps"""
57 | if step == 1 or step % args.save_step == 0:
58 |             save_checkpoint_step(args, step, best_psnr, model_wrapper, meta_optimizer.state_dict(), logger.logdir)
59 |
60 | """ Finish training after max_iter steps """
61 | if step >= args.max_iter:
62 | break
63 |
64 | """ Save the last model"""
65 | save_checkpoint(args, args.outer_steps, best_psnr, model_wrapper, meta_optimizer.state_dict(), logger.logdir)
66 |
--------------------------------------------------------------------------------
/eval/maml_full_eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import lpips
3 | import torch.nn.functional as F
4 | from tqdm import tqdm
5 | from pytorch_msssim import ssim
6 |
7 | from common.utils import MetricLogger, psnr, ssim_1d
8 | from train.maml_boot import inner_adapt_test_scale
9 |
10 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11 |
12 |
13 | def test_model(args, model_wrapper, test_loader, logger=None):
14 | metric_logger = MetricLogger(delimiter=" ")
15 |
16 | if logger is None:
17 | log_ = print
18 | else:
19 | log_ = logger.log
20 |
21 | model_wrapper.model.eval()
22 | model_wrapper.coord_init()
23 |
24 | lpips_score = lpips.LPIPS(net='alex').to(device)
25 |
26 | for n, data in enumerate(tqdm(test_loader)):
27 | data, _ = data
28 | data = data.to(device)
29 | batch_size = data.size(0)
30 | model_wrapper.model.reset_modulations()
31 |
32 | _ = inner_adapt_test_scale(model_wrapper=model_wrapper, data=data, step_size=args.inner_lr,
33 | num_steps=args.inner_steps_test, first_order=True, sample_type=args.sample_type,
34 | scale_type='grad')
35 |
36 | with torch.no_grad():
37 | pred = model_wrapper().clamp(0, 1)
38 |
39 | if args.data_type == 'img':
40 | lpips_results = lpips_score((pred * 2 - 1), (data * 2 - 1)).mean()
41 |             mse_results = F.mse_loss(data.view(batch_size, -1), pred.reshape(batch_size, -1), reduction='none').mean()
42 |             psnr_results = psnr(
43 |                 F.mse_loss(data.view(batch_size, -1), pred.reshape(batch_size, -1), reduction='none').mean(dim=1)
44 |             ).mean()
45 | ssim_results = ssim(pred, data, data_range=1.).mean()
46 |
47 | elif args.data_type == 'img3d':
48 |             mse_results = F.mse_loss(data.view(batch_size, -1), pred.reshape(batch_size, -1), reduction='none').mean()
49 |             psnr_results = psnr(
50 |                 F.mse_loss(data.view(batch_size, -1), pred.reshape(batch_size, -1), reduction='none').mean(dim=1)
51 |             ).mean()
52 | ssim_results = ssim(pred, data, data_range=1.).mean()
53 | lpips_results = torch.zeros_like(psnr_results)
54 |
55 | elif args.data_type == 'timeseries':
56 |             mse_results = F.mse_loss(data.view(batch_size, -1), pred.reshape(batch_size, -1), reduction='none').mean()
57 |             psnr_results = psnr(
58 |                 F.mse_loss(data.view(batch_size, -1), pred.reshape(batch_size, -1), reduction='none').mean(dim=1)
59 |             ).mean()
60 | ssim_results = ssim_1d(pred.squeeze(), data.squeeze(), data_range=1.).mean()
61 | lpips_results = torch.zeros_like(psnr_results)
62 |
63 | else:
64 | raise NotImplementedError()
65 |
66 | metric_logger.meters['lpips'].update(lpips_results.item(), n=batch_size)
67 | metric_logger.meters['psnr'].update(psnr_results.item(), n=batch_size)
68 | metric_logger.meters['mse'].update(mse_results.item(), n=batch_size)
69 | metric_logger.meters['ssim'].update(ssim_results.item(), n=batch_size)
70 |
71 | if n % 10 == 0:
72 | # gather the stats from all processes
73 | metric_logger.synchronize_between_processes()
74 |
75 | log_(f'*[EVAL {n}][PSNR %.6f][LPIPS %.6f][SSIM %.6f][MSE %.6f]' %
76 | (metric_logger.psnr.global_avg, metric_logger.lpips.global_avg,
77 | metric_logger.ssim.global_avg, metric_logger.mse.global_avg))
78 |
79 | # gather the stats from all processes
80 | metric_logger.synchronize_between_processes()
81 | log_(f'*[EVAL Final][PSNR %.8f][LPIPS %.8f][SSIM %.8f][MSE %.8f]' %
82 | (metric_logger.psnr.global_avg, metric_logger.lpips.global_avg,
83 | metric_logger.ssim.global_avg, metric_logger.mse.global_avg))
84 |
85 | return
86 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Log files
10 | logs/
11 | nfsets/
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 | cover/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | .pybuilder/
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | # For a library or package, you might want to ignore these files since the code is
91 | # intended to run in multiple environments; otherwise, check them in:
92 | # .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # poetry
102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103 | # This is especially recommended for binary packages to ensure reproducibility, and is more
104 | # commonly ignored for libraries.
105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106 | #poetry.lock
107 |
108 | # pdm
109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110 | #pdm.lock
111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112 | # in version control.
113 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
114 | .pdm.toml
115 | .pdm-python
116 | .pdm-build/
117 |
118 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
119 | __pypackages__/
120 |
121 | # Celery stuff
122 | celerybeat-schedule
123 | celerybeat.pid
124 |
125 | # SageMath parsed files
126 | *.sage.py
127 |
128 | # Environments
129 | .env
130 | .venv
131 | env/
132 | venv/
133 | ENV/
134 | env.bak/
135 | venv.bak/
136 |
137 | # Spyder project settings
138 | .spyderproject
139 | .spyproject
140 |
141 | # Rope project settings
142 | .ropeproject
143 |
144 | # mkdocs documentation
145 | /site
146 |
147 | # mypy
148 | .mypy_cache/
149 | .dmypy.json
150 | dmypy.json
151 |
152 | # Pyre type checker
153 | .pyre/
154 |
155 | # pytype static type analyzer
156 | .pytype/
157 |
158 | # Cython debug symbols
159 | cython_debug/
160 |
161 | # PyCharm
162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
164 | # and can be added to the global gitignore or merged into this file. For a more nuclear
165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
166 | .idea/
--------------------------------------------------------------------------------
/data/preproc_lidc-idri.py:
--------------------------------------------------------------------------------
1 | """
2 | Script for preprocessing the LIDC-IDRI dataset.
3 | """
4 | import argparse
5 | import os
6 | import shutil
7 | import dicom2nifti
8 | import nibabel as nib
9 | import numpy as np
10 | from scipy.ndimage import zoom
11 |
12 |
13 | def preprocess_nifti(input_path, output_path):
14 | # Load the Nifti image
15 | print('Process image: {}'.format(input_path))
16 | img = nib.load(input_path)
17 |
18 | # Get the current voxel sizes
19 | voxel_sizes = img.header.get_zooms()
20 |
21 | # Calculate the target voxel size (1mm x 1mm x 1mm)
22 | target_voxel_size = (1.0, 1.0, 1.0)
23 |
24 | # Calculate the resampling factor
25 | zoom_factors = [current / target for target, current in zip(target_voxel_size, voxel_sizes)]
26 |
27 | # Resample the image
28 | print("[1] Resample the image ...")
29 | resampled_data = zoom(img.get_fdata(), zoom_factors, order=3, mode='nearest')
30 |
31 | print("[2] Center crop the image ...")
32 | crop_size = (256, 256, 256)
33 | depth, height, width = resampled_data.shape
34 |
35 | d_start = (depth - crop_size[0]) // 2
36 | h_start = (height - crop_size[1]) // 2
37 | w_start = (width - crop_size[2]) // 2
38 | cropped_arr = resampled_data[d_start:d_start + crop_size[0], h_start:h_start + crop_size[1], w_start:w_start + crop_size[2]]
39 |
40 | print("[3] Clip all values below -1000 ...")
41 | cropped_arr[cropped_arr < -1000] = -1000
42 |
43 | print("[4] Clip the upper quantile (0.999) to remove outliers ...")
44 | out_clipped = np.clip(cropped_arr, -1000, np.quantile(cropped_arr, 0.999))
45 |
46 |     print("[5] Shift the image to positive values and cast to int16 ...")
47 | out_pos = out_clipped + 1000
48 | out_pos = np.int16(out_pos)
49 |
50 | assert out_pos.shape == (256, 256, 256), "The output shape should be (256,256,256)"
51 |
52 | print("[6] FINAL REPORT: Min value: {}, Max value: {}, Shape: {}".format(out_pos.min(),
53 | out_pos.max(),
54 | out_pos.shape))
55 | print("-------------------------------------------------------------------------------")
56 |
57 | # Save the resampled image
58 | resampled_img = nib.Nifti1Image(out_pos, np.eye(4))
59 | nib.save(resampled_img, output_path)
60 |
61 |
62 | if __name__ == "__main__":
63 | parser = argparse.ArgumentParser()
64 | parser.add_argument('--dicom_dir', type=str, required=True,
65 | help='Directory containing the original dicom data')
66 | parser.add_argument('--nifti_dir', type=str, required=True,
67 | help='Directory to store the processed nifti files')
68 | parser.add_argument('--delete_unprocessed', type=eval, default=False,
69 | help='Set true to delete the unprocessed nifti files')
70 | args = parser.parse_args()
71 |
72 | # Convert DICOM to nifti
73 | for patient in os.listdir(args.dicom_dir):
74 | print('Convert {} to nifti'.format(patient))
75 | if not os.path.exists(os.path.join(args.nifti_dir, patient)):
76 | os.makedirs(os.path.join(args.nifti_dir, patient))
77 | dicom2nifti.convert_directory(os.path.join(args.dicom_dir, patient),
78 | os.path.join(args.nifti_dir, patient))
79 | shutil.rmtree(os.path.join(args.dicom_dir, patient))
80 |
81 | # Preprocess nifti files
82 | for root, dirs, files in os.walk(args.nifti_dir):
83 | for file in files:
84 |             try:
85 |                 preprocess_nifti(os.path.join(root, file), os.path.join(root, 'processed.nii.gz'))
86 |             except Exception as e:
87 |                 print("Error occurred for file: {} ({})".format(file, e))
88 |
89 | # Delete unprocessed nifti files
90 | if args.delete_unprocessed:
91 | for root, dirs, files in os.walk(args.nifti_dir):
92 | for file in files:
93 | if not file == 'processed.nii.gz':
94 | os.remove(os.path.join(root, file))
--------------------------------------------------------------------------------
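A hedged invocation example; both directories are placeholders:

```sh
python data/preproc_lidc-idri.py \
    --dicom_dir /path/to/LIDC-IDRI/dicom \
    --nifti_dir /path/to/LIDC-IDRI/nifti \
    --delete_unprocessed False
```

Note that the script deletes each patient's DICOM folder after conversion (`shutil.rmtree`), so run it on a copy of the raw data.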
/fit_NFset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 |
3 | import torch
4 | from torch.utils.data import DataLoader
5 |
6 | from common.args import parse_args
7 | from common.utils import set_random_seed, load_model
8 | from common.w0_utils import get_w0s, save_w0s
9 | from data.dataset import get_dataset
10 | from eval.maml_full_fit import fit_nfs
11 | from models.inrs import LatentModulatedSIREN
12 | from models.model_wrapper import ModelWrapper
13 |
14 |
15 | def main(args):
16 | """
17 | Main function to call for fitting neural fields to a whole dataset (having pretrained shared weights).
18 | :param args: parameters parsed from the command line/a config.yaml.
19 | :return: Nothing.
20 | """
21 |
22 | """ Set a device to use """
23 | if torch.cuda.is_available():
24 | torch.cuda.set_device(args.gpu_id)
25 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
26 | args.device = device
27 |
28 | """ Enable determinism """
29 | set_random_seed(args.seed)
30 | torch.backends.cudnn.deterministic = True
31 | torch.backends.cudnn.benchmark = False
32 |
33 | """ Define dataset that you want to convert to NFs """
34 | train, val, test = get_dataset(args, all=True)
35 | train_loader = DataLoader(train, batch_size=args.test_batch_size, shuffle=False, num_workers=4, pin_memory=True,
36 | drop_last=True)
37 | val_loader = DataLoader(val, batch_size=args.test_batch_size, shuffle=False, num_workers=4, pin_memory=True,
38 | drop_last=True)
39 | test_loader = DataLoader(test, batch_size=args.test_batch_size, shuffle=False, num_workers=4, pin_memory=True,
40 | drop_last=True)
41 |
42 | """ Get w0s to initialize the model """
43 | w0s = get_w0s(args)
44 | args.w0s = w0s
45 |
46 | """ Initialize model and optimizer """
47 | model = LatentModulatedSIREN(
48 | in_size=args.in_size,
49 | out_size=args.out_size,
50 | min_hidden_size=args.min_hidden_dim,
51 | max_hidden_size=args.max_hidden_dim,
52 | progression_type=args.progression_type,
53 | num_layers=args.num_layers,
54 | latent_modulation_dim=args.latent_modulation_dim,
55 | w0s=args.w0s,
56 | modulate_shift=args.modulate_shift,
57 | modulate_scale=args.modulate_scale,
58 | enable_skip_connections=args.enable_skip_connections,
59 | ).to(device)
60 |
61 | """ Initialize modulation vectors """
62 | model.modulations = torch.zeros(size=[args.test_batch_size, args.latent_modulation_dim], requires_grad=True).to(device)
63 | model = ModelWrapper(args, model)
64 | load_model(args, model)
65 |
66 | if not os.path.exists(args.save_dir):
67 | print(f'Create: {args.save_dir}')
68 | os.mkdir(args.save_dir)
69 |
70 | """ Create training set """
71 | if not os.path.exists(args.save_dir + 'train'):
72 | print(f'Create: {args.save_dir}'+'train/')
73 | os.mkdir(args.save_dir + 'train/')
74 | fit_nfs(args, model, train_loader, set='train')
75 | print("Created MedFuncta Set: Training")
76 |
77 | """ Create validation set """
78 | if not os.path.exists(args.save_dir + 'val'):
79 | print(f'Create: {args.save_dir}' + 'val/')
80 | os.mkdir(args.save_dir + 'val/')
81 | fit_nfs(args, model, val_loader, set='val')
82 | print("Created MedFuncta Set: Validation")
83 |
84 | """ Create test set """
85 | if not os.path.exists(args.save_dir + 'test'):
86 | print(f'Create: {args.save_dir}' + 'test/')
87 | os.mkdir(args.save_dir + 'test')
88 | os.mkdir(args.save_dir + 'test/imgs')
89 | fit_nfs(args, model, test_loader, set='test')
90 | print("Created MedFuncta Set: Test")
91 |
92 | """ Save the model to save_dir folder """
93 | model_path = args.save_dir + 'model.pt'
94 | torch.save(model.model, model_path)
95 | print("DONE")
96 |
97 |
98 | if __name__ == "__main__":
99 | args = parse_args()
100 | main(args)
101 |
--------------------------------------------------------------------------------
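With `load_path` pointing at pretrained shared weights and `save_dir` set, converting a whole dataset into neural fields would look like this (a sketch based on `configs/fit/default_64.yaml`):

```sh
python fit_NFset.py --config configs/fit/default_64.yaml
```

Since the output directories are built by plain string concatenation above, `save_dir` should end with a trailing slash.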
/models/inrs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch import nn
4 | from models.layers import LatentModulatedSIRENLayer
5 |
6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7 |
8 |
9 | class LatentModulatedSIREN(nn.Module):
10 | def __init__(self, in_size, out_size, w0s, min_hidden_size=256, max_hidden_size=256, num_layers=15,
11 | latent_modulation_dim=2048, modulate_shift=True, modulate_scale=False, enable_skip_connections=False,
12 | progression_type='linear'):
13 | super().__init__()
14 | self.num_layers = num_layers
15 | self.hidden_sizes = self._calculate_progressive_sizes(
16 | min_hidden_size, max_hidden_size, num_layers, progression_type
17 | )
18 | print(f"Progressive layer widths: {self.hidden_sizes}")
19 | layers = []
20 | for i in range(num_layers - 1):
21 | is_first = i == 0
22 | layer_in_size = in_size if is_first else self.hidden_sizes[i - 1]
23 | layer_out_size = self.hidden_sizes[i]
24 | layers.append(LatentModulatedSIRENLayer(in_size=layer_in_size, out_size=layer_out_size,
25 | latent_modulation_dim=latent_modulation_dim, w0=w0s[i],
26 | modulate_shift=modulate_shift, modulate_scale=modulate_scale,
27 | is_first=is_first))
28 | self.layers = nn.ModuleList(layers)
29 | self.last_layer = LatentModulatedSIRENLayer(in_size=self.hidden_sizes[-1], out_size=out_size,
30 | latent_modulation_dim=latent_modulation_dim, w0=w0s[-1],
31 | modulate_shift=modulate_shift, modulate_scale=modulate_scale,
32 | is_last=True)
33 | self.enable_skip_connections = enable_skip_connections
34 | self.modulations = torch.zeros(size=[latent_modulation_dim], requires_grad=True).to(device)
35 |
36 | def reset_modulations(self):
37 | self.modulations = self.modulations.detach() * 0
38 | self.modulations.requires_grad = True
39 |
40 | def forward(self, x, get_features=False):
41 | x = self.layers[0](x, self.modulations)
42 | for layer in self.layers[1:]:
43 | y = layer(x, self.modulations)
44 | if self.enable_skip_connections:
45 | x = x + y
46 | else:
47 | x = y
48 | features = x
49 | out = self.last_layer(features, self.modulations) + 0.5
50 |
51 | if get_features:
52 | return out, features
53 | else:
54 | return out
55 |
56 |
57 | def _calculate_progressive_sizes(self, min_size, max_size, num_layers, progression_type='linear'):
58 | """Calculate progressive hidden layer sizes."""
59 | if num_layers <= 1:
60 | return [min_size]
61 |
62 | # We have num_layers-1 hidden layers (excluding output layer)
63 | n_hidden = num_layers - 1
64 |
65 | if progression_type == 'linear':
66 | # Linear interpolation
67 | sizes = np.linspace(min_size, max_size, n_hidden)
68 | elif progression_type == 'exponential':
69 | # Exponential growth
70 | log_min = np.log(min_size)
71 | log_max = np.log(max_size)
72 | log_sizes = np.linspace(log_min, log_max, n_hidden)
73 | sizes = np.exp(log_sizes)
74 | elif progression_type == 'cosine':
75 | # Cosine schedule (slower at beginning and end)
76 | t = np.linspace(0, 1, n_hidden)
77 | cosine_factor = (1 - np.cos(t * np.pi)) / 2
78 | sizes = min_size + (max_size - min_size) * cosine_factor
79 | else:
80 | raise ValueError(f"Unknown progression_type: {progression_type}")
81 |
82 | # Round to nearest multiple of 8 for efficiency (optional)
83 | sizes = [int(8 * round(size / 8)) for size in sizes]
84 |
85 | return sizes
86 |
--------------------------------------------------------------------------------
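A minimal end-to-end sketch (assumptions: 2D single-channel signals and the linear omega-schedule of the 64x64 config), querying the untrained model on a coordinate grid:

```python
import torch
from models.inrs import LatentModulatedSIREN

num_layers = 15
# Linear omega schedule from 20 to 400 over the 14 modulated layers.
w0s = [torch.tensor(20.0 + i * (400.0 - 20.0) / 13.0) for i in range(num_layers - 1)]

model = LatentModulatedSIREN(in_size=2, out_size=1, w0s=w0s,
                             min_hidden_size=256, max_hidden_size=256,
                             num_layers=num_layers, latent_modulation_dim=2048)

device = model.modulations.device        # cuda if available, else cpu
model = model.to(device)
model.reset_modulations()

coords = torch.rand(4096, 2, device=device)  # (x, y) coordinates of a 64x64 grid
intensities = model(coords)                  # predicted values, shape (4096, 1)
```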
/eval/maml_scale.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import lpips
3 | from common.utils import MetricLogger, psnr
4 | from train.maml_boot import inner_adapt_test_scale
5 | from pytorch_msssim import ssim
6 |
7 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8 |
9 |
10 | def test_model(args, step, model_wrapper, test_loader, logger=None):
11 | """
12 | Function that performs the model evaluation
13 | """
14 | metric_logger = MetricLogger(delimiter=" ")
15 | lpips_score = lpips.LPIPS(net='alex').to(device)
16 |
17 | if logger is None:
18 | log_ = print
19 | else:
20 | log_ = logger.log
21 |
22 | model_wrapper.model.eval() # Enter evaluation mode
23 | model_wrapper.coord_init() # Reset coordinates
24 |
25 | """ Iterate over test loader """
26 | for n, data in enumerate(test_loader):
27 | if n * args.test_batch_size > args.num_test_signals:
28 | break
29 |
30 | data, _ = data # Discard label
31 | data = data.float().to(device)
32 | batch_size = data.size(0)
33 | model_wrapper.model.reset_modulations()
34 | input = data
35 |
36 | _ = inner_adapt_test_scale(model_wrapper=model_wrapper, data=data, step_size=args.inner_lr,
37 | num_steps=args.inner_steps_test, first_order=True, sample_type=args.sample_type,
38 | scale_type='grad')
39 |
40 | """ Outer loss aggregation """
41 | with torch.no_grad():
42 | loss_out_tt_gradscale = model_wrapper(data)
43 | loss_out = loss_out_tt_gradscale
44 | psnr_out = psnr(loss_out).mean()
45 |
46 | if args.data_type == 'img':
47 |             out = model_wrapper().clamp(0, 1)
48 | lpips_result = lpips_score((out * 2 - 1), (input * 2 - 1)).mean()
49 | ssim_result = ssim(out, input, data_range=1.).mean()
50 | metric_logger.meters['lpips'].update(lpips_result.item(), n=batch_size)
51 | metric_logger.meters['ssim'].update(ssim_result.item(), n=batch_size)
52 | metric_logger.meters['psnr'].update(psnr_out.item(), n=batch_size)
53 | metric_logger.meters['loss'].update(loss_out.mean().item(), n=batch_size)
54 |
55 | if args.data_type == 'img3d':
56 | out = model_wrapper().clamp(0, 1)
57 | lpips_result = torch.zeros_like(loss_out_tt_gradscale).mean()
58 | ssim_result = ssim(out, input, data_range=1.).mean()
59 | metric_logger.meters['lpips'].update(lpips_result.item(), n=batch_size)
60 | metric_logger.meters['ssim'].update(ssim_result.item(), n=batch_size)
61 | metric_logger.meters['psnr'].update(psnr_out.item(), n=batch_size)
62 | metric_logger.meters['loss'].update(loss_out.mean().item(), n=batch_size)
63 |
64 | if args.data_type == 'timeseries':
65 | lpips_result = torch.zeros_like(loss_out_tt_gradscale).mean()
66 | ssim_result = torch.zeros_like(loss_out_tt_gradscale).mean()
67 | metric_logger.meters['lpips'].update(lpips_result.item(), n=batch_size)
68 | metric_logger.meters['ssim'].update(ssim_result.item(), n=batch_size)
69 | metric_logger.meters['psnr'].update(psnr_out.item(), n=batch_size)
70 | metric_logger.meters['loss'].update(loss_out.mean().item(), n=batch_size)
71 |
72 | metric_logger.synchronize_between_processes()
73 |
74 | """ Log to tensorboard & console """
75 | if args.data_type == 'img' or args.data_type == 'img3d':
76 | log_('*[EVAL-GTT-REC] [Loss %f] [PSNR %.3f] [LPIPS %.3f] [SSIM %.3f]' %
77 | (metric_logger.loss.global_avg, metric_logger.psnr.global_avg, metric_logger.lpips.global_avg,
78 | metric_logger.ssim.global_avg))
79 | if logger is not None:
80 | logger.scalar_summary('eval/loss', metric_logger.loss.global_avg, step)
81 | logger.scalar_summary('eval/psnr', metric_logger.psnr.global_avg, step)
82 | logger.scalar_summary('eval/ssim', metric_logger.ssim.global_avg, step)
83 | logger.scalar_summary('eval/lpips', metric_logger.lpips.global_avg, step)
84 | logger.log_image('eval/img_in', input, step)
85 | logger.log_image('eval/img_out', out, step)
86 |
87 |
88 | return metric_logger.psnr.global_avg, metric_logger.lpips.global_avg, metric_logger.ssim.global_avg
89 |
--------------------------------------------------------------------------------
/common/args.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import os.path
3 | from argparse import ArgumentParser
4 |
5 |
6 | def load_cfg(args):
7 | with open(args.config, "rb") as f:
8 | cfg = yaml.safe_load(f)
9 |
10 | for key, value in cfg.items():
11 | args.__dict__[key] = value
12 |
13 | return args
14 |
15 |
16 | def parse_args():
17 | parser = ArgumentParser()
18 |
19 | """ Config """
20 | parser.add_argument('--config', help='Path to config .yaml file', type=str, default=None)
21 |
22 | """ System configuration """
23 | parser.add_argument('--gpu_id', help='GPU ID', type=int, default=0)
24 | parser.add_argument('--seed', help='Random seed', type=int, default=42)
25 |
26 | """ Resume training """
27 | parser.add_argument('--resume_path', help='Path to the logdir of training to resume', type=str)
28 |
29 | """ Training configuration """
30 | parser.add_argument('--dataset', help='Dataset', type=str)
31 | parser.add_argument('--img_size', help='Image size', type=int, default=64)
32 | parser.add_argument('--batch_size', help='Batch size (number of images per batch) used for training', type=int, default=24)
33 |     parser.add_argument('--outer_steps', help='Number of meta-learning steps to perform', type=int, default=250000)
34 | parser.add_argument('--inner_steps', help='Number of inner loop optimization steps (G)', type=int, default=10)
35 | parser.add_argument('--meta_lr', help='Learning rate for meta-learning updates (beta)', type=float, default=3e-6)
36 | parser.add_argument('--inner_lr', help='Learning rate for inner loop (alpha)', type=float, default=1e-2)
37 | parser.add_argument('--lr_scheduler', help='If True, a global lr-schedule is applied', type=eval, default=True)
38 | parser.add_argument('--data_ratio', help='Ratio of data used for training (gamma)', type=float, default=0.25)
39 |
40 | """ Testing configuration """
41 | parser.add_argument('--test_batch_size', help='Batch size used for testing', type=int, default=24)
42 | parser.add_argument('--num_test_signals', help='Number of signals used for testing', default=256, type=int)
43 | parser.add_argument('--inner_steps_test', help='Number of inner loop update steps at test-time (H)', type=int, default=20)
44 |
45 | """ Model configuration """
46 | parser.add_argument('--min_hidden_dim', help='MLP hidden size start', type=int, default=256)
47 |     parser.add_argument('--max_hidden_dim', help='MLP hidden size end', type=int, default=256)
48 | parser.add_argument('--progression_type', help='Progression type hidden_dim [linear, exponential, cosine]', type=str, default='linear')
49 | parser.add_argument('--num_layers', help='Number of MLP layers (K)', type=int, default=15)
50 | parser.add_argument('--latent_modulation_dim', help='Representation size (P)', type=int, default=2048)
51 | parser.add_argument('--w0', help='SIREN parameter w0', type=float, default=30.)
52 | parser.add_argument('--wK', help='SIREN parameter wK', type=float, default=300.)
53 | parser.add_argument('--w0_sched_type', help='Type of w0 schedule', type=str, default='linear')
54 | parser.add_argument('--modulate_shift', help='Set True to use shift modulations', type=eval, default=True)
55 | parser.add_argument('--modulate_scale', help='Set True to use scale modulations (not recommended)', type=eval, default=False)
56 | parser.add_argument('--enable_skip_connections', help='Set True to enable skip-connections', type=eval, default=False)
57 |
58 | """ Logging configuration """
59 | parser.add_argument('--print_step', help='Print every x steps', type=int, default=100)
60 | parser.add_argument('--print_img_step', help='Print images every x steps', type=int, default=100)
61 | parser.add_argument('--eval_step', help='Evaluate every x steps', type=int, default=1000)
62 | parser.add_argument('--save_step', help='Save model every x steps', type=int, default=50000)
63 |     parser.add_argument('--advanced_step', help='Log advanced statistics every x steps', type=int, default=1000)
64 | parser.add_argument('--log_advanced', help='Activate to log advanced statistics', type=eval, default=False)
65 |
66 | """ Eval configuration """
67 | parser.add_argument('--load_path', help='Load model from this path', type=str, default=None)
68 |
69 | """ Fitting configuration """
70 | parser.add_argument('--save_dir', help='Directory to store shared model, modulations and labels', type=str, default=None)
71 |
72 | """ Parse Arguments """
73 | args = parser.parse_args()
74 |
75 | """ Load config files """
76 | if args.config is not None and os.path.exists(args.config):
77 | load_cfg(args)
78 |
79 | return args
80 |
--------------------------------------------------------------------------------
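Note the precedence implied by `parse_args`: `load_cfg` runs after the command line is parsed, so values in the YAML config overwrite same-named CLI flags. Flags absent from the config (e.g. `--gpu_id`, `--seed`) still take effect, as in this hedged example:

```sh
python train.py --config configs/experiments/2d_imgs/default_64.yaml --gpu_id 1 --seed 0
```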
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | from torch.utils.data import DataLoader
4 |
5 | from common.args import parse_args
6 | from common.utils import set_random_seed, Logger, InfiniteSampler
7 | from common.w0_utils import get_w0s, save_w0s
8 | from data.dataset import get_dataset
9 | from models.inrs import LatentModulatedSIREN
10 | from models.model_wrapper import ModelWrapper
11 | from train.trainer import trainer
12 | from train.maml_boot import train_step
13 | from eval.maml_scale import test_model
14 |
15 |
16 | def main(args):
17 | """
18 |     Main function to call for running a training procedure (meta-learning a shared network).
19 | :param args: parameters parsed from the command line/a config.yaml.
20 | :return: Nothing.
21 | """
22 |
23 | """ Set a device to use """
24 | if torch.cuda.is_available():
25 | torch.cuda.set_device(args.gpu_id)
26 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
27 | args.device = device
28 |
29 |     """ Enable determinism (note: cudnn.benchmark = True trades strict reproducibility for speed) """
30 | set_random_seed(args.seed)
31 | torch.backends.cudnn.deterministic = True
32 | torch.backends.cudnn.benchmark = True
33 |
34 | """ Define dataset """
35 | train_set, val_set = get_dataset(args)
36 |
37 | """ Define dataloader """
38 | infinite_sampler = InfiniteSampler(train_set, rank=0, num_replicas=1, shuffle=True, seed=args.seed)
39 | train_loader = DataLoader(train_set, sampler=infinite_sampler, batch_size=args.batch_size, num_workers=4,
40 | prefetch_factor=2)
41 | val_loader = DataLoader(val_set, batch_size=args.test_batch_size, shuffle=False, num_workers=4)
42 |
43 | """ Get w0s to initialize the model """
44 | w0s = get_w0s(args)
45 | args.w0s = w0s
46 |
47 | """ Initialize model """
48 | model = LatentModulatedSIREN(
49 | in_size=args.in_size, # Input dimension (coordinate dim) C
50 | out_size=args.out_size, # Output dimension (signal dim) D
51 | min_hidden_size=args.min_hidden_dim, # First layers hidden dimension
52 | max_hidden_size=args.max_hidden_dim, # Last layers hidden dimension (usually min_hidden_dim)
53 | progression_type=args.progression_type, # Defines how hidden dimension progresses in model
54 | num_layers=args.num_layers, # Number of layers K
55 | latent_modulation_dim=args.latent_modulation_dim, # Representation size P
56 | w0s=args.w0s, # Per-layer omega parameters
57 | modulate_shift=args.modulate_shift, # If shift modulation is used (default: True)
58 | modulate_scale=args.modulate_scale, # If scale modulation is used (default: False)
59 | enable_skip_connections=args.enable_skip_connections, # Set True to enable skip-connections (default: False)
60 | ).to(device)
61 |
62 | """ Initialize modulation vectors (signal-specific parameter vector) """
63 | model.modulations = torch.zeros(size=[args.batch_size, args.latent_modulation_dim], requires_grad=True).to(device)
64 | model.modulation_init = model.modulations.clone().detach()
65 |
66 | """ Wrap the model """
67 | model = ModelWrapper(args, model)
68 |
69 | """ Define training and test functions """
70 | train_function = train_step
71 | test_function = test_model
72 |
73 | """ Define logger """
74 | fname = (f'{args.dataset}_size{args.img_size}_bs{args.batch_size}_inner{args.inner_steps}_gamma{args.data_ratio}_'
75 | f'{args.config.split("/")[-1].split(".yaml")[0]}')
76 | logger = Logger(fname, ask=args.resume_path is None, rank=args.gpu_id)
77 | logger.log(args)
78 | logger.log(w0s)
79 | logger.log(model)
80 | logger.log_hyperparameters(args)
81 |
82 | """ Save w0s """
83 | save_w0s(w0s, logger)
84 |
85 | """ Initialize meta-optimizer """
86 | meta_optimizer = optim.AdamW(params=model.model.parameters(), lr=args.meta_lr)
87 |
88 | """ Initialize a global lr-scheduler (recommended) """
89 | scheduler = None
90 | if args.lr_scheduler:
91 | scheduler = optim.lr_scheduler.CosineAnnealingLR(meta_optimizer, eta_min=1e-7, T_max=args.max_iter)
92 |
93 | """ Start training """
94 | trainer(args, train_function, test_function, model, meta_optimizer, train_loader, val_loader, logger, scheduler)
95 |
96 | """ Close logger """
97 | logger.close_writer()
98 |
99 |
100 | if __name__ == "__main__":
101 | args = parse_args()
102 | main(args)
103 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MedFuncta: A Unified Framework for Learning Efficient Medical Neural Fields
2 | [](https://opensource.org/licenses/MIT)
3 | [](https://pfriedri.github.io/medfuncta-io/)
4 | [](https://arxiv.org/abs/2502.14401)
5 |
6 | This is the official PyTorch implementation of the paper **MedFuncta: A Unified Framework for Learning Efficient Medical Neural Fields** by [Paul Friedrich](https://pfriedri.github.io/), [Florentin Bieder](https://dbe.unibas.ch/en/persons/florentin-bieder/), [Julian McGinnis](https://www.kiinformatik.mri.tum.de/de/team/julian_mcginnis), [Julia Wolleb](https://medicine.yale.edu/profile/julia-wolleb/), [Daniel Rueckert](https://www.professoren.tum.de/rueckert-daniel) and [Philippe C. Cattin](https://dbe.unibas.ch/en/persons/philippe-claude-cattin/).
7 |
8 | If you find our work useful, please consider :star: **starring this repository** and :memo: **citing our paper**:
9 | ```bibtex
10 | @article{friedrich2025medfuncta,
11 | title={MedFuncta: A Unified Framework for Learning Efficient Medical Neural Fields},
12 | author={Friedrich, Paul and Bieder, Florentin and McGinnis, Julian and Wolleb, Julia and Rueckert, Daniel and Cattin, Philippe C},
13 | journal={arXiv preprint arXiv:2502.14401},
14 | year={2025}
15 | }
16 | ```
17 | ## Paper Abstract
18 | Research in medical imaging primarily focuses on discrete data representations that poorly scale with grid resolution and fail to capture the often continuous nature of the underlying signal.
19 | Neural Fields (NFs) offer a powerful alternative by modeling data as continuous functions.
20 | While single-instance NFs have successfully been applied in medical contexts, extending them to large-scale medical datasets remains an open challenge.
21 | We therefore introduce [**MedFuncta**](https://arxiv.org/abs/2502.14401), a unified framework for large-scale NF training on diverse medical signals.
22 | Building on Functa, our approach encodes data into a unified representation, namely a 1D latent vector, that modulates a shared, meta-learned NF, enabling generalization across a dataset.
23 | We revisit common design choices, introducing a non-constant frequency parameter $\omega$ in widely used SIREN activations, and establish a connection between this $\omega$-schedule and layer-wise learning rates, relating our findings to recent work in theoretical learning dynamics.
24 | We additionally introduce a scalable meta-learning strategy for shared network learning that employs sparse supervision during training, thereby reducing memory consumption and computational overhead while maintaining competitive performance.
25 | Finally, we evaluate MedFuncta across a diverse range of medical datasets and show how to solve relevant downstream tasks on our neural data representation.
26 | To promote further research in this direction, we release our code, model weights and the first large-scale dataset - [**MedNF**](https://doi.org/10.5281/zenodo.14898708) - containing > 500 k latent vectors for multi-instance medical NFs.
27 |
28 | <p align="center">
29 |   <img src="assets/overview.png" alt="MedFuncta method overview" width="800">
30 | </p>
31 | 
32 | ## Dependencies
33 | We recommend using a [conda](https://github.com/conda-forge/miniforge#mambaforge) environment to install the required dependencies.
34 | You can create and activate such an environment called `medfuncta` by running the following commands:
35 | ```sh
36 | mamba env create -f environment.yaml
37 | mamba activate medfuncta
38 | ```
39 |
40 | ## Training (Meta-Learning)
41 | To obtain meta-learned shared model parameters, simply run the following command with the correct `config.yaml`:
42 | ```sh
43 | python train.py --config ./configs/experiments/DATASET_RESOLUTION.yaml
44 | ```
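45 | 
46 | For example, to train on 64×64 2D images using one of the provided configurations:
47 | ```sh
48 | python train.py --config ./configs/experiments/2d_imgs/default_64.yaml
49 | ```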
45 |
46 | ## Evaluation (Reconstruction Experiments)
47 | To perform reconstruction experiments (evaluate the reconstruction quality), simply run the following command with the correct `config.yaml`:
48 | ```sh
49 | python eval.py --config ./configs/eval/DATASET_RESOLUTION.yaml
50 | ```
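51 | 
52 | For example, to evaluate reconstruction quality on 64×64 2D images:
53 | ```sh
54 | python eval.py --config ./configs/eval/2d_imgs/default_64.yaml
55 | ```
56 | 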
51 | ## Create a MedNF Dataset
52 | To convert a dataset into our MedFuncta representation, simply run the following command with the correct `config.yaml`:
53 | ```sh
54 | python fit_NFset.py --config ./configs/fit/DATASET_RESOLUTION.yaml
55 | ```
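56 | 
57 | For example, with the provided default configuration:
58 | ```sh
59 | python fit_NFset.py --config ./configs/fit/default_64.yaml
60 | ```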
56 |
57 | ## Classification Experiments
58 | The source code for reproducing our classification experiments can be found in `/downstream_tasks/classification.py`.
59 | All arguments can be set in the `Args` class in this script.
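60 | Since the script appends the working directory to `sys.path`, run it from the repository root:
61 | ```sh
62 | python downstream_tasks/classification.py
63 | ```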
60 |
61 |
62 | ## MedNF Dataset
63 | We release **MedNF**, a large-scale dataset containing more than 500 k medical NFs.
64 | More information on the dataset can be found in our paper (Appendix D).
65 | The dataset can be accessed here: [https://doi.org/10.5281/zenodo.14898708](https://doi.org/10.5281/zenodo.14898708).
66 |
67 | The dataset consists of the following 7 sub-datasets:
68 | <p align="center">
69 |   <img src="assets/MedNF.png" alt="Overview of the MedNF sub-datasets" width="800">
70 | </p>
71 | 
72 | ## Data
73 | To ensure good reproducibility, we trained and evaluated our network on publicly available datasets:
74 | * **MedMNIST**, a large-scale MNIST-like collection of standardized biomedical images. More information is available [here](https://medmnist.com/).
75 |
76 | * **MIT-BIH Arrhythmia**, a heartbeat classification dataset. We use a preprocessed version that is available [here](https://www.kaggle.com/datasets/shayanfazeli/heartbeat).
77 |
78 | * **BRATS 2023: Adult Glioma**, a dataset containing routine clinically-acquired, multi-site multiparametric magnetic resonance imaging (MRI) scans of brain tumor patients. We use only the T1-weighted images for training. The data is available [here](https://www.synapse.org/#!Synapse:syn51514105).
79 |
80 | * **LIDC-IDRI**, a dataset containing multi-site, thoracic computed tomography (CT) scans of lung cancer patients. The data is available [here](https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=1966254).
81 |
82 | The provided code works for the following data structure (you might need to adapt the directories in `data/dataset.py`):
83 | ```
84 | data
85 | └───BRATS
86 | └───BraTS-GLI-00000-000
87 | └───BraTS-GLI-00000-000-seg.nii.gz
88 | └───BraTS-GLI-00000-000-t1c.nii.gz
89 | └───BraTS-GLI-00000-000-t1n.nii.gz
90 | └───BraTS-GLI-00000-000-t2f.nii.gz
91 | └───BraTS-GLI-00000-000-t2w.nii.gz
92 | └───BraTS-GLI-00001-000
93 | └───BraTS-GLI-00002-000
94 | ...
95 |
96 | └───LIDC-IDRI
97 | └───LIDC-IDRI-0001
98 | └───preprocessed.nii.gz
99 | └───LIDC-IDRI-0002
100 | └───LIDC-IDRI-0003
101 | ...
102 |
103 | └───MIT-BIH
104 | └───mitbih_test.csv
105 | └───mitbih_train.csv
106 |
107 | ...
108 | ```
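109 | 
110 | The dataset classes in `data/dataset.py` hard-code absolute dataset paths. A minimal sketch of pointing them at the structure above (the `DATA_ROOT` variable is hypothetical; adjust it to your setup):
111 | ```python
112 | from data.dataset import BRATSVolumes, LIDCVolumes, ECG1D
113 | 
114 | DATA_ROOT = './data'  # hypothetical root directory
115 | brats = BRATSVolumes(f'{DATA_ROOT}/BRATS/', img_size=64)
116 | lidc = LIDCVolumes(f'{DATA_ROOT}/LIDC-IDRI/', img_size=64)
117 | ecg = ECG1D(f'{DATA_ROOT}/MIT-BIH/', test=False)
118 | ```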
109 | We provide a script for preprocessing LIDC-IDRI. Simply run the following command with the correct path to the downloaded DICOM files `DICOM_PATH` and the directory you want to store the processed nifti files `NIFTI_PATH`:
110 | ```sh
111 | python data/preproc_lidc-idri.py --dicom_dir DICOM_PATH --nifti_dir NIFTI_PATH
112 | ```
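113 | 
114 | For example (with hypothetical paths):
115 | ```sh
116 | python data/preproc_lidc-idri.py --dicom_dir ./LIDC-IDRI-DICOM --nifti_dir ./data/LIDC-IDRI
117 | ```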
113 |
114 | ## Acknowledgements
115 | Our code is based on / inspired by the following repositories:
116 | * https://github.com/jihoontack/GradNCP
117 | * https://github.com/google-deepmind/functa
118 | * https://github.com/pfriedri/wdm-3d
119 |
--------------------------------------------------------------------------------
/train/maml_boot.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from common.utils import psnr
3 |
4 |
5 | def get_grad_norm(grads, detach=True):
6 | grad_norm_list = []
7 | for grad in grads:
8 | if grad is None:
9 | grad_norm = 0
10 | else:
11 | if detach:
12 | grad_norm = torch.norm(grad.data, p=2, keepdim=True).unsqueeze(dim=0)
13 | else:
14 | grad_norm = torch.norm(grad, p=2, keepdim=True).unsqueeze(dim=0)
15 |
16 | grad_norm_list.append(grad_norm)
17 | return torch.norm(torch.cat(grad_norm_list, dim=0), p=2, dim=1)
18 |
19 |
20 | def train_step(args, step, model_wrapper, optimizer, data, metric_logger, logger, scheduler=None):
21 | """
22 | Function that performs a single meta update step.
23 | """
24 | model_wrapper.model.train() # Enable model training
25 | model_wrapper.coord_init() # Reset coordinates
26 | model_wrapper.model.reset_modulations() # Reset modulations (zero-initialization)
27 |
28 | batch_size = data.size(0)
29 |
30 | if step % args.print_img_step == 0:
31 |         input_data = data  # Save input data for logging (avoid shadowing the builtin input())
32 |
33 | """ Inner-loop optimization for G steps """
34 | _ = inner_adapt(model_wrapper=model_wrapper, data=data, step_size=args.inner_lr,
35 | num_steps=args.inner_steps, first_order=False, sample_type=args.sample_type)
36 |
37 | """ Compute loss using full context set"""
38 | model_wrapper.coord_init() # Reset coordinates
39 | loss_out = model_wrapper(data) # Compute reconstruction loss
40 | psnr_out = psnr(loss_out)
41 |
42 | if step % args.print_img_step == 0:
43 | images = model_wrapper() # Sample images
44 | loss = loss_out.mean()
45 |
46 | """ Meta update (optimize shared weights) """
47 | optimizer.zero_grad()
48 | loss.backward()
49 | grad_norm = torch.nn.utils.clip_grad_norm_(model_wrapper.model.parameters(), 0.5) # Clip gradient
50 | optimizer.step()
51 | torch.cuda.synchronize()
52 |
53 | """ Update scheduler """
54 | if scheduler:
55 | scheduler.step()
56 |
57 | """ Track stats """
58 | metric_logger.meters['loss'].update(loss_out.mean().item(), n=batch_size)
59 | metric_logger.meters['psnr'].update(psnr_out.mean().item(), n=batch_size)
60 | metric_logger.meters['grad_norm'].update(grad_norm.item(), n=batch_size)
61 | metric_logger.synchronize_between_processes()
62 |
63 | """ Log scalars to tensorboard & console """
64 | if step % args.print_step == 0:
65 | logger.scalar_summary('train/loss', metric_logger.loss.global_avg, step)
66 | logger.scalar_summary('train/psnr', metric_logger.psnr.global_avg, step)
67 |
68 | logger.log('[TRAIN-REC] [Step %3d] [Loss %f] [PSNR %.3f]' %
69 | (step, metric_logger.loss.global_avg, metric_logger.psnr.global_avg))
70 |
71 | logger.scalar_summary('supp/grad_norm', metric_logger.grad_norm.global_avg, step)
72 | if scheduler:
73 | logger.scalar_summary('supp/lr', scheduler.get_last_lr()[0], step)
74 |
75 | """ Log images to tensorboard"""
76 | if step % args.print_img_step == 0:
77 |         logger.log_image('train/img_in', input_data, step)
78 | logger.log_image('train/img_pred', images, step)
79 |
80 | """ Log activation distributions and weight dynamics """
81 | if step % args.advanced_step == 0:
82 | if args.log_advanced:
83 | # Weight dynamics
84 | for name, param in model_wrapper.model.named_parameters():
85 | if 'weight' in name:
86 | logger.log_hist(f'weights/{name}', param.data, step)
87 | if param.grad is not None:
88 | logger.log_hist(f'grads/{name}', param.grad, step)
89 |
90 | metric_logger.reset()
91 |
92 |
93 | def inner_adapt(model_wrapper, data, step_size=1e-2, num_steps=3, first_order=False, sample_type='none'):
94 | """
95 | Performs the inner loop optimization.
96 | :param model_wrapper: the wrapped model.
97 | :param data: the data used for training.
98 | :param step_size: the inner_loop learning rate (alpha).
99 |     :param num_steps: number of inner-loop update steps G.
100 | :param first_order: if True, first order MAML is used (not recommended).
101 | :param sample_type: coordinate sample type.
102 | :return: loss
103 | """
104 | loss = 0. # Initialize outer_loop loss
105 |
106 | """ Perform num_step (G) inner-loop updates """
107 | for _ in range(num_steps):
108 | if sample_type != 'none':
109 |             model_wrapper.sample_coordinates(sample_type=sample_type, data=data)  # Sample coordinates for this step
110 | loss = inner_loop_step(model_wrapper, data, step_size, first_order)
111 | return loss
112 |
113 |
114 | def inner_loop_step(model_wrapper, data, inner_lr=1e-2, first_order=False):
115 | """ Performs a single inner-loop update. """
116 | batch_size = data.size(0)
117 |
118 | with torch.enable_grad():
119 | loss = model_wrapper(data)
120 | grads = torch.autograd.grad(
121 | loss.mean() * batch_size,
122 | model_wrapper.model.modulations,
123 | create_graph=not first_order,
124 | )[0]
125 | model_wrapper.model.modulations = model_wrapper.model.modulations - inner_lr * grads
126 | return loss
127 |
128 |
129 | def inner_adapt_test_scale(model_wrapper, data, step_size=1e-2, num_steps=3, first_order=False,
130 | sample_type='none', scale_type='grad'):
131 | """ Similar to inner_adapt, but with rescaled gradients at test-time """
132 | loss = 0. # Initialize outer_loop loss
133 |
134 | """ Perform num_step (H) inner-loop updates """
135 | for _ in range(num_steps):
136 | if sample_type != 'none':
137 |             model_wrapper.sample_coordinates(sample_type=sample_type, data=data)
138 | loss = inner_loop_step_tt_gradscale(model_wrapper, data, step_size, first_order, scale_type)
139 | return loss
140 |
141 |
142 | def inner_loop_step_tt_gradscale(model_wrapper, data, inner_lr=1e-2, first_order=False, scale_type='grad'):
143 | """ Similar to inner_loop_step, but with rescaled gradients at test-time. """
144 | batch_size = data.size(0)
145 | model_wrapper.model.zero_grad()
146 |
147 | """ Get gradients with sparse supervision (as in training) """
148 | with torch.enable_grad():
149 | subsample_loss = model_wrapper(data)
150 | subsample_grad = torch.autograd.grad(
151 | subsample_loss.mean() * batch_size,
152 | model_wrapper.model.modulations,
153 | create_graph=False,
154 | allow_unused=True
155 | )[0]
156 |
157 | model_wrapper.model.zero_grad()
158 | model_wrapper.coord_init() # Reset coordinates
159 |
160 | """ Get gradients wit full supervision (during inference)"""
161 | with torch.enable_grad():
162 | loss = model_wrapper(data)
163 | grads = torch.autograd.grad(
164 | loss.mean() * batch_size,
165 | model_wrapper.model.modulations,
166 | create_graph=not first_order,
167 | allow_unused=True
168 | )[0]
169 |
170 | """ Rescale the gradient """
171 | if scale_type == 'grad':
172 | subsample_grad_norm = get_grad_norm(subsample_grad, detach=True)
173 | grad_norm = get_grad_norm(grads, detach=True)
174 | grad_scale = subsample_grad_norm / (grad_norm + 1e-16)
175 | grad_scale_ = grad_scale.view((batch_size,) + (1,) * (len(grads.shape) - 1)).detach()
176 | else:
177 | raise NotImplementedError()
178 |
179 | model_wrapper.model.modulations = model_wrapper.model.modulations - inner_lr * grad_scale_ * grads
180 |
181 | return loss
182 |
--------------------------------------------------------------------------------
/models/model_wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import random
5 | from einops import rearrange
6 |
7 |
8 | def exists(val):
9 | return val is not None
10 |
11 |
12 | class ModelWrapper(nn.Module):
13 | def __init__(self, args, model):
14 | super().__init__()
15 | self.args = args
16 | self.model = model
17 | self.data_type = args.data_type
18 |
19 | self.sampled_coord = None
20 | self.sampled_index = None
21 | self.gradncp_coord = None
22 | self.gradncp_index = None
23 |
24 | if self.data_type == 'img':
25 | self.width = args.data_size[1]
26 | self.height = args.data_size[2]
27 |
28 | mgrid = self.shape_to_coords((self.width, self.height))
29 | mgrid = rearrange(mgrid, 'h w c -> (h w) c')
30 |
31 | elif self.data_type == 'img3d':
32 | self.width = args.data_size[1]
33 | self.height = args.data_size[2]
34 | self.depth = args.data_size[3]
35 |
36 | mgrid = self.shape_to_coords((self.width, self.height, self.depth))
37 | mgrid = rearrange(mgrid, 'h w d c -> (h w d) c')
38 |
39 | elif self.data_type == 'timeseries':
40 | self.length = args.data_size[-1]
41 | mgrid = self.shape_to_coords([self.length])
42 |
43 | else:
44 | raise NotImplementedError()
45 |
46 | self.register_buffer('grid', mgrid)
47 |
48 | def coord_init(self):
49 | self.sampled_coord = None
50 | self.sampled_index = None
51 | self.gradncp_coord = None
52 | self.gradncp_index = None
53 |
54 | def get_batch_coords(self, x=None):
55 | if x is None:
56 | meta_batch_size = 1
57 | else:
58 | meta_batch_size = x.size(0)
59 |
60 | # batch of coordinates
61 | if self.sampled_coord is None and self.gradncp_coord is None:
62 | coords = self.grid
63 | elif self.gradncp_coord is not None:
64 | return self.gradncp_coord, meta_batch_size
65 | else:
66 | coords = self.sampled_coord
67 | coords = coords.clone().detach()[None, ...].repeat((meta_batch_size,) + (1,) * len(coords.shape))
68 | return coords, meta_batch_size
69 |
70 | def shape_to_coords(self, spatial_dims):
71 | coords = []
72 | for i in range(len(spatial_dims)):
73 | coords.append(torch.linspace(-1.0, 1.0, spatial_dims[i]))
74 | return torch.stack(torch.meshgrid(*coords, indexing='ij'), dim=-1)
75 |
76 | def sample_coordinates(self, sample_type, data):
77 | if sample_type == 'random':
78 | self.random_sample()
79 | elif sample_type == 'gradncp':
80 | self.gradncp(data)
81 | else:
82 | raise NotImplementedError()
83 |
84 | def gradncp(self, x):
85 | ratio = self.args.data_ratio
86 | meta_batch_size = x.size(0)
87 | coords = self.grid
88 | coords = coords.clone().detach()[None, ...].repeat((meta_batch_size,) + (1,) * len(coords.shape))
89 | coords = coords.to(self.args.device)
90 | with torch.no_grad():
91 | out, feature = self.model(coords, get_features=True)
92 |
93 | if self.data_type == 'img':
94 | out = rearrange(out, 'b hw c -> b c hw')
95 | feature = rearrange(feature, 'b hw f -> b f hw')
96 | x = rearrange(x, 'b c h w -> b c (h w)')
97 | elif self.data_type == 'img3d':
98 | out = rearrange(out, 'b hwd c -> b c hwd')
99 | feature = rearrange(feature, 'b hwd f -> b f hwd')
100 | x = rearrange(x, 'b c h w d -> b c (h w d)')
101 | elif self.data_type == 'timeseries':
102 | out = rearrange(out, 'b l c -> b c l')
103 | feature = rearrange(feature, 'b l f -> b f l')
104 | else:
105 | raise NotImplementedError()
106 |
107 |         error = x - out
108 |         # GradNCP-style ranking: per-coordinate gradient of the loss w.r.t. the last layer's weight and bias
109 |         gradient = -1 * feature.unsqueeze(dim=1) * error.unsqueeze(dim=2)
110 |         gradient_bias = -1 * error.unsqueeze(dim=2)
111 |         gradient = torch.cat([gradient, gradient_bias], dim=2)
112 |         gradient = rearrange(gradient, 'b c f hw -> b (c f) hw')
113 |         gradient_norm = torch.norm(gradient, dim=1)
114 |
115 | coords_len = gradient_norm.size(1)
116 |
117 | self.gradncp_index = torch.sort(gradient_norm, dim=1, descending=True)[1][:, :int(coords_len * ratio)]
118 | self.gradncp_coord = torch.gather(coords, 1, self.gradncp_index.unsqueeze(dim=2).repeat(1, 1, self.args.in_size))
119 | self.gradncp_index = self.gradncp_index.unsqueeze(dim=1).repeat(1, self.args.out_size, 1)
120 |
121 | def random_sample(self):
122 | coord_size = self.grid.size(0)
123 | perm = torch.randperm(coord_size)
124 | self.sampled_index = perm[:int(self.args.data_ratio * coord_size)]
125 | self.sampled_coord = self.grid[self.sampled_index]
126 | return self.sampled_coord
127 |
128 |     def forward(self, x=None):
129 |         if self.data_type == 'img':
130 |             return self.forward_img(x)
131 |         elif self.data_type == 'img3d':
132 |             return self.forward_img3d(x)
133 |         elif self.data_type == 'timeseries':
134 |             return self.forward_timeseries(x)
135 |         else:
136 |             raise NotImplementedError()
137 |
138 | def forward_img(self, x):
139 | coords, meta_batch_size = self.get_batch_coords(x)
140 | coords = coords.to(self.args.device)
141 |
142 | out = self.model(coords)
143 | out = rearrange(out, 'b hw c -> b c hw')
144 |
145 |         if exists(x):
146 |             if self.sampled_coord is None and self.gradncp_coord is None:
147 |                 return F.mse_loss(x.view(meta_batch_size, -1), out.reshape(meta_batch_size, -1), reduction='none').mean(dim=1)
148 |             elif self.gradncp_coord is not None:
149 |                 x = rearrange(x, 'b c h w -> b c (h w)')
150 |                 x = torch.gather(x, 2, self.gradncp_index)
151 |                 return F.mse_loss(x.view(meta_batch_size, -1), out.reshape(meta_batch_size, -1), reduction='none').mean(dim=1)
152 |             else:
153 |                 x = rearrange(x, 'b c h w -> b c (h w)')[:, :, self.sampled_index]
154 |                 return F.mse_loss(x.view(meta_batch_size, -1), out.reshape(meta_batch_size, -1), reduction='none').mean(dim=1)
155 |
156 | out = rearrange(out, 'b c (h w) -> b c h w', h=self.height, w=self.width)
157 | return out
158 |
159 | def forward_img3d(self, x):
160 | coords, meta_batch_size = self.get_batch_coords(x)
161 | coords = coords.to(self.args.device)
162 |
163 | out = self.model(coords)
164 | out = rearrange(out, 'b hwd c -> b c hwd')
165 |
166 |         if exists(x):
167 |             if self.sampled_coord is None and self.gradncp_coord is None:
168 |                 return F.mse_loss(x.view(meta_batch_size, -1), out.reshape(meta_batch_size, -1), reduction='none').mean(dim=1)
169 |             elif self.gradncp_coord is not None:
170 |                 x = rearrange(x, 'b c h w d -> b c (h w d)')
171 |                 x = torch.gather(x, 2, self.gradncp_index)
172 |                 return F.mse_loss(x.view(meta_batch_size, -1), out.reshape(meta_batch_size, -1), reduction='none').mean(dim=1)
173 |             else:
174 |                 x = rearrange(x, 'b c h w d -> b c (h w d)')[:, :, self.sampled_index]
175 |                 return F.mse_loss(x.view(meta_batch_size, -1), out.reshape(meta_batch_size, -1), reduction='none').mean(dim=1)
176 |
177 | out = rearrange(out, 'b c (h w d) -> b c h w d', h=self.height, w=self.width, d=self.depth)
178 | return out
179 |
180 | def forward_timeseries(self, x):
181 | coords, meta_batch_size = self.get_batch_coords(x)
182 | coords = coords.to(self.args.device)
183 |
184 | out = self.model(coords)
185 | out = rearrange(out, 'b l c -> b c l')
186 |
187 |         if exists(x):
188 |             if self.sampled_coord is None and self.gradncp_coord is None:
189 |                 return F.mse_loss(x.view(meta_batch_size, -1), out.reshape(meta_batch_size, -1), reduction='none').mean(dim=1)
190 |             elif self.gradncp_coord is not None:
191 |                 x = torch.gather(x, 2, self.gradncp_index)
192 |                 return F.mse_loss(x.view(meta_batch_size, -1), out.reshape(meta_batch_size, -1), reduction='none').mean(dim=1)
193 |             else:
194 |                 x = x[:, :, self.sampled_index]
195 |                 return F.mse_loss(x.view(meta_batch_size, -1), out.reshape(meta_batch_size, -1), reduction='none').mean(dim=1)
196 | return out
197 |
--------------------------------------------------------------------------------
/downstream_tasks/classification.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import sys
4 |
5 | sys.path.append(".")
6 |
7 | import time
8 | import random
9 | import numpy as np
10 | import torch.nn as nn
11 | import torch.optim as optim
12 | import torchvision.transforms as T
13 | import torchvision.models as models
14 | from sklearn.neighbors import KNeighborsClassifier
15 | from sklearn.metrics import accuracy_score, f1_score
16 | from torch.utils.data import DataLoader, Dataset
17 | from torchmetrics.classification import MulticlassF1Score
18 |
19 |
20 | def set_random_seed(seed):
21 | random.seed(seed)
22 | np.random.seed(seed)
23 | torch.manual_seed(seed)
24 |
25 |
26 | class NFDataset(Dataset):
27 | def __init__(self, root_dir):
28 | self.root_dir = root_dir
29 | self.files = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.endswith('.pt')]
30 |
31 | def __len__(self):
32 | return len(self.files)
33 |
34 | def __getitem__(self, idx):
35 | data = torch.load(self.files[idx], weights_only=False)
36 | return data['modulations'].float(), data['label']
37 |
38 |
39 | class SimpleClassifier(nn.Module):
40 | def __init__(self, input_dim, num_classes):
41 | super(SimpleClassifier, self).__init__()
42 | self.network = nn.Sequential(
43 | nn.Linear(input_dim, 512),
44 | nn.ReLU(),
45 | nn.Dropout(0.3),
46 | nn.Linear(512, 256),
47 | nn.ReLU(),
48 | nn.Dropout(0.3),
49 | nn.Linear(256, num_classes),
50 | )
51 |
52 | def forward(self, x):
53 | return self.network(x)
54 |
55 |
56 | class ResNet50Classifier(nn.Module):
57 | def __init__(self, num_classes, mode='rgb'):
58 | super(ResNet50Classifier, self).__init__()
59 | self.resnet50 = models.resnet50()
60 | self.resnet50.fc = nn.Linear(self.resnet50.fc.in_features, num_classes)
61 | if mode == 'grayscale':
62 | self.resnet50.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7,
63 | stride=2, padding=3, bias=False)
64 |
65 | def forward(self, x):
66 | return self.resnet50(x)
67 |
68 |
69 | class EfficientNetB0Classifier(nn.Module):
70 | def __init__(self, num_classes, mode='rgb'):
71 | super(EfficientNetB0Classifier, self).__init__()
72 | self.efficientnet_b0 = models.efficientnet_b0()
73 | self.efficientnet_b0.classifier[1] = nn.Linear(self.efficientnet_b0.classifier[1].in_features, num_classes)
74 | if mode == 'grayscale':
75 | self.efficientnet_b0.features[0][0] = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3,
76 | stride=2, padding=1, bias=False)
77 |
78 | def forward(self, x):
79 | return self.efficientnet_b0(x)
80 |
81 |
82 | def load_nf_data_for_knn(dataset):
83 | """Load all data from NFDataset into numpy arrays for KNN"""
84 | features = []
85 | labels = []
86 |
87 | for i in range(len(dataset)):
88 | data, label = dataset[i]
89 | features.append(data.numpy())
90 | labels.append(label.item() if torch.is_tensor(label) else label)
91 |
92 | return np.array(features), np.array(labels)
93 |
94 |
95 | def train_knn(train_dataset, args):
96 | """Train KNN classifier"""
97 | print("Loading training data for KNN...")
98 | X_train, y_train = load_nf_data_for_knn(train_dataset)
99 |
100 | print(f"Training KNN with k={args.k_neighbors}...")
101 | knn = KNeighborsClassifier(n_neighbors=args.k_neighbors, n_jobs=-1)
102 | knn.fit(X_train, y_train)
103 |
104 | return knn
105 |
106 |
107 | def evaluate_knn(knn_model, dataset, num_classes):
108 | """Evaluate KNN classifier"""
109 | print("Loading evaluation data for KNN...")
110 | X, y = load_nf_data_for_knn(dataset)
111 |
112 | predictions = knn_model.predict(X)
113 | accuracy = accuracy_score(y, predictions) * 100
114 | f1 = f1_score(y, predictions, average='macro' if num_classes > 2 else 'binary')
115 |
116 | return accuracy, f1
117 |
118 |
119 | def train(model, train_loader, criterion, optimizer, device, num_classes):
120 | model.train()
121 | total_loss = 0.0
122 |
123 | for data, labels in train_loader:
124 | data, labels = data.to(device), labels.to(device)
125 | labels = labels.squeeze()
126 | optimizer.zero_grad()
127 | out = model(data)
128 | loss = criterion(out, labels)
129 | loss.backward()
130 | optimizer.step()
131 |
132 | total_loss += loss.item()
133 |
134 | return total_loss / len(train_loader)
135 |
136 |
137 | def evaluate(model, val_loader, criterion, device, num_classes):
138 | model.eval()
139 | total_loss = 0.0
140 | correct = 0
141 | total = 0
142 | f1_metric = MulticlassF1Score(num_classes=num_classes).to(device)
143 |
144 | with torch.no_grad():
145 | for data, labels in val_loader:
146 | data, labels = data.to(device), labels.to(device)
147 | labels = labels.squeeze()
148 | out = model(data)
149 | loss = criterion(out, labels)
150 |
151 | total_loss += loss.item()
152 | _, predicted = torch.max(out, 1)
153 | correct += (predicted == labels).sum().item()
154 | total += labels.size(0)
155 |
156 | f1_metric.update(predicted, labels)
157 |
158 | accuracy = 100 * correct / total
159 |     f1_value = f1_metric.compute().item()  # renamed to avoid shadowing sklearn's f1_score import
160 | 
161 |     return total_loss / len(val_loader), accuracy, f1_value
162 |
163 |
164 | class Args:
165 | def __init__(self):
166 | self.data_dir = "your/data/dir" # Data directory
167 | self.classifier = 'simple' # knn, simple, resnet, efficientnet
168 | self.mode = 'grayscale' # grayscale or rgb
169 | self.batch_size = 32
170 | self.input_dim = 2048
171 | self.num_classes = 7 # Number of classes
172 | self.learning_rate = 1e-3
173 | self.num_epochs = 50
174 | self.seed = 42
175 | self.k_neighbors = 3 # Number of neighbors for KNN
176 |
177 |
178 | def main(args):
179 | device = torch.device(f'cuda' if torch.cuda.is_available() else 'cpu')
180 |
181 | """ Enable determinism """
182 | set_random_seed(args.seed)
183 | torch.backends.cudnn.deterministic = True
184 | torch.backends.cudnn.benchmark = False
185 |
186 | """ Define Dataset and Dataloader """
187 | if args.classifier in ['simple', 'knn']:
188 | train_set = NFDataset(os.path.join(args.data_dir, "train"))
189 | val_set = NFDataset(os.path.join(args.data_dir, "val"))
190 | test_set = NFDataset(os.path.join(args.data_dir, "test"))
191 |
192 | elif args.classifier == 'resnet' or args.classifier == 'efficientnet':
193 | from medmnist import PneumoniaMNIST
194 | # from medmnist import DermaMNIST
195 | transforms = T.Compose([
196 | T.ToTensor(),
197 | ])
198 | train_set = PneumoniaMNIST(split='train', transform=transforms, download='True', size=64)
199 | val_set = PneumoniaMNIST(split='val', transform=transforms, download='True', size=64)
200 | test_set = PneumoniaMNIST(split='test', transform=transforms, download='True', size=64)
201 | # train_set = DermaMNIST(split='train', transform=transforms, download='True', size=64)
202 | # val_set = DermaMNIST(split='val', transform=transforms, download='True', size=64)
203 | # test_set = DermaMNIST(split='test', transform=transforms, download='True', size=64)
204 | else:
205 | raise NotImplementedError()
206 |
207 | # Handle KNN separately since it doesn't use PyTorch training loop
208 | if args.classifier == 'knn':
209 | print("Training KNN Classifier...")
210 | start_time = time.time()
211 |
212 | # Train KNN
213 | knn_model = train_knn(train_set, args)
214 |
215 | # Evaluate on validation set
216 | val_acc, val_f1 = evaluate_knn(knn_model, val_set, args.num_classes)
217 | print(f"Validation Accuracy: {val_acc:.2f}%, Validation F1: {val_f1:.4f}")
218 |
219 | # Evaluate on test set
220 | test_acc, test_f1 = evaluate_knn(knn_model, test_set, args.num_classes)
221 |
222 | end_time = time.time()
223 | elapsed_time = end_time - start_time
224 | print(f"Elapsed time: {elapsed_time:.2f} seconds")
225 | print(f"Test Accuracy: {test_acc:.2f}%, Test F1 Score: {test_f1:.4f}")
226 |
227 | return
228 |
229 | # For neural network classifiers
230 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
231 | val_loader = DataLoader(val_set, batch_size=args.batch_size)
232 | test_loader = DataLoader(test_set, batch_size=args.batch_size)
233 |
234 | """ Select classification model """
235 | if args.classifier == 'simple':
236 | model = SimpleClassifier(args.input_dim, args.num_classes).to(device)
237 | elif args.classifier == 'resnet':
238 | model = ResNet50Classifier(args.num_classes, mode=args.mode).to(device)
239 | elif args.classifier == 'efficientnet':
240 | model = EfficientNetB0Classifier(args.num_classes, mode=args.mode).to(device)
241 |
242 | pytorch_total_params = sum(p.numel() for p in model.parameters())
243 | print(f"Parameters: {pytorch_total_params}")
244 |
245 | """ Define optimization criterion and optimizer """
246 | criterion = nn.CrossEntropyLoss()
247 | optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
248 |
249 | """ Run training and validation loop """
250 | best_val_acc = 0.0
251 | start_time = time.time()
252 | for epoch in range(args.num_epochs):
253 | train_loss = train(model, train_loader, criterion, optimizer, device, args.num_classes)
254 | val_loss, val_acc, val_f1 = evaluate(model, val_loader, criterion, device, args.num_classes)
255 | if val_acc > best_val_acc:
256 | best_val_acc = val_acc
257 |             best_model = {k: v.detach().clone() for k, v in model.state_dict().items()}  # snapshot; state_dict() alone aliases the live tensors
258 |
259 | print(f"Epoch {epoch + 1}/{args.num_epochs}: ",
260 | f"Train Loss: {train_loss:.4f} ",
261 | f"| Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%, Val F1: {val_f1:.4f}")
262 |
263 | end_time = time.time()
264 | elapsed_time = end_time - start_time
265 | print(f"Elapsed time: {elapsed_time} seconds")
266 |
267 | """ Final evaluation on test set """
268 | model.load_state_dict(best_model)
269 | test_loss, test_acc, test_f1 = evaluate(model, test_loader, criterion, device, args.num_classes)
270 | print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%, Test F1 Score: {test_f1:.4f}")
271 |
272 |
273 | if __name__ == "__main__":
274 | args = Args()
275 | main(args)
276 |
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import torchvision.transforms as T
5 | import pandas as pd
6 | import os
7 | import nibabel
8 | import cv2 as cv
9 |
10 |
11 | class ECG1D(torch.utils.data.Dataset):
12 | def __init__(self, directory, test=False):
13 | super().__init__()
14 |
15 | if not test:
16 | self.df = pd.read_csv(directory + '/mitbih_train.csv')
17 | else:
18 | self.df = pd.read_csv(directory + '/mitbih_test.csv')
19 |
20 | def __len__(self):
21 | return len(self.df)
22 |
23 | def __getitem__(self, idx):
24 | sample = self.df.iloc[idx, :-1].values.astype(float)
25 | label = self.df.iloc[idx, -1]
26 | sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(dim=0)
27 | label = torch.tensor(label, dtype=torch.long)
28 |
29 | return sample, label
30 |
31 |
32 | class BRATSVolumes(torch.utils.data.Dataset):
33 | def __init__(self, directory, normalize=None, img_size=32):
34 | super().__init__()
35 | self.directory = os.path.expanduser(directory)
36 | self.normalize = normalize or (lambda x: x)
37 | self.img_size = img_size
38 | self.seqtypes = ['t1n', 't1c', 't2w', 't2f', 'seg']
39 | self.seqtypes_set = set(self.seqtypes)
40 | self.database = []
41 |
42 | for root, dirs, files in os.walk(self.directory):
43 | # Ensure determinism
44 | dirs.sort()
45 | files.sort()
46 | # if there are no subdirs, we have a datadir
47 | if not dirs:
48 | datapoint = dict()
49 | # extract all files as channels
50 | for f in files:
51 | seqtype = f.split('-')[4].split('.')[0]
52 | datapoint[seqtype] = os.path.join(root, f)
53 | self.database.append(datapoint)
54 |
55 | def __getitem__(self, x):
56 | filedict = self.database[x]
57 | name = filedict['t1n']
58 | nib_img = nibabel.load(name) # We only use t1 weighted images
59 | out = nib_img.get_fdata()
60 |
61 | # Clip and normalize the images
62 | out_clipped = np.clip(out, np.quantile(out, 0.001), np.quantile(out, 0.999))
63 | out_normalized = (out_clipped - np.min(out_clipped)) / (np.max(out_clipped) - np.min(out_clipped))
64 | out = torch.tensor(out_normalized)
65 |
66 | # Zero pad images
67 | image = torch.zeros(1, 256, 256, 256)
68 | image[:, 8:-8, 8:-8, 50:-51] = out
69 |
70 | # Downsampling
71 | if self.img_size == 32:
72 | downsample = nn.AvgPool3d(kernel_size=8, stride=8)
73 | image = downsample(image)
74 |
75 | if self.img_size == 64:
76 | downsample = nn.AvgPool3d(kernel_size=4, stride=4)
77 | image = downsample(image)
78 |
79 | # Normalization
80 | image = self.normalize(image)
81 |
82 | # Insert dummy label
83 | label = 1
84 |
85 | return image, label
86 |
87 | def __len__(self):
88 | return len(self.database)
89 |
90 | class LIDCVolumes(torch.utils.data.Dataset):
91 | def __init__(self, directory, normalize=None, img_size=32):
92 | super().__init__()
93 | self.directory = os.path.expanduser(directory)
94 | self.normalize = normalize or (lambda x: x)
95 | self.img_size = img_size
96 | self.database = []
97 |
98 | for root, dirs, files in os.walk(self.directory):
99 | # Ensure determinism
100 | dirs.sort()
101 | files.sort()
102 | # if there are no subdirs, we have a datadir
103 | if not dirs:
104 | datapoint = dict()
105 | for f in files:
106 | datapoint['image'] = os.path.join(root, f)
107 | if len(datapoint) != 0:
108 | self.database.append(datapoint)
109 |
110 | def __getitem__(self, x):
111 | filedict = self.database[x]
112 | name = filedict['image']
113 | nib_img = nibabel.load(name)
114 | out = nib_img.get_fdata()
115 |
116 | # Clip and normalize the images
117 | out_clipped = np.clip(out, np.quantile(out, 0.001), np.quantile(out, 0.999))
118 | out_normalized = (out_clipped - np.min(out_clipped)) / (np.max(out_clipped) - np.min(out_clipped))
119 | out = torch.tensor(out_normalized)
120 |
121 | image = torch.zeros(1, 256, 256, 256)
122 | image[:, :, :, :] = out
123 |
124 | if self.img_size == 32:
125 | downsample = nn.AvgPool3d(kernel_size=8, stride=8)
126 | image = downsample(image)
127 |
128 | if self.img_size == 64:
129 | downsample = nn.AvgPool3d(kernel_size=4, stride=4)
130 | image = downsample(image)
131 |
132 | # normalization
133 | image = self.normalize(image)
134 |
135 | # Insert dummy label
136 | label = 1
137 |
138 | return image, label
139 |
140 | def __len__(self):
141 | return len(self.database)
142 |
143 |
144 | class EchoNet(torch.utils.data.Dataset):
145 | def __init__(self, directory, split='TRAIN'):
146 | self.data = pd.read_csv(directory + '/FileList.csv')
147 | self.data = self.data[self.data['Split'] == split]
148 | self.video_dir = directory + '/Videos/'
149 | self.max_frames = 10
150 |
151 | def __len__(self):
152 | return len(self.data)
153 |
154 | def __getitem__(self, idx):
155 | row = self.data.iloc[idx]
156 | filename = row['FileName']
157 | ef = torch.tensor(row['EF'], dtype=torch.float32)
158 |
159 | video_path = os.path.join(self.video_dir, f"{filename}.avi")
160 | video = self.load_video(video_path)
161 |
162 |         if getattr(self, 'transform', None) is not None:  # __init__ defines no transform; guard against AttributeError
163 |             video = self.transform(video)
164 |
165 | return video, ef
166 |
167 | def load_video(self, path):
168 | cap = cv.VideoCapture(path)
169 | frames = []
170 |
171 | while True:
172 | ret, frame = cap.read()
173 | if not ret:
174 | break
175 | frame = cv.cvtColor(frame, cv.COLOR_BGR2GRAY) # Convert to grayscale
176 | frame = cv.resize(frame, (112, 112)) # Resize if needed
177 | frames.append(frame)
178 |
179 | cap.release()
180 | frames = np.stack(frames, axis=0)
181 | frames = torch.tensor(frames, dtype=torch.float32) / 255.0
182 | frames = frames.unsqueeze(1)
183 |
184 |         # Crop/pad to max_frames
185 |         if self.max_frames:
186 |             num_frames = frames.shape[0]  # avoid shadowing the torchvision.transforms alias T
187 |             if num_frames > self.max_frames:
188 |                 frames = frames[:self.max_frames]
189 |             elif num_frames < self.max_frames:
190 |                 pad = torch.zeros((self.max_frames - num_frames, 1, 112, 112))
191 |                 frames = torch.cat([frames, pad], dim=0)
192 |
193 | return frames
194 |
195 |
196 | def get_dataset(args, only_test=False, all=False):
197 | train_set = None
198 | val_set = None
199 | test_set = None
200 |
201 | #############################################
202 | ############# 2D Image Datasets #############
203 | #############################################
204 | if args.dataset == 'chestmnist':
205 | from medmnist import ChestMNIST
206 | transforms = T.Compose([
207 | T.ToTensor(),
208 | T.Grayscale()
209 | ])
210 |
211 | train_set = ChestMNIST(split='train', transform=transforms, download='True', size=args.img_size)
212 | val_set = ChestMNIST(split='val', transform=transforms, download='True', size=args.img_size)
213 | test_set = ChestMNIST(split='test', transform=transforms, download='True', size=args.img_size)
214 |
215 | print(f'Training set containing {len(train_set)} images.')
216 | print(f'Validation set containing {len(val_set)} images.')
217 | print(f'Test set containing {len(test_set)} images.')
218 |
219 | args.data_type = 'img'
220 | args.in_size, args.out_size = 2, 1
221 | args.data_size = (1, args.img_size, args.img_size)
222 |
223 | elif args.dataset == 'pneumoniamnist':
224 | from medmnist import PneumoniaMNIST
225 | transforms = T.Compose([
226 | T.ToTensor(),
227 | T.Grayscale()
228 | ])
229 |
230 | train_set = PneumoniaMNIST(split='train', transform=transforms, download='True', size=args.img_size)
231 | val_set = PneumoniaMNIST(split='val', transform=transforms, download='True', size=args.img_size)
232 | test_set = PneumoniaMNIST(split='test', transform=transforms, download='True', size=args.img_size)
233 |
234 | print(f'Training set containing {len(train_set)} images.')
235 | print(f'Validation set containing {len(val_set)} images.')
236 | print(f'Test set containing {len(test_set)} images.')
237 |
238 | args.data_type = 'img'
239 | args.in_size, args.out_size = 2, 1
240 | args.data_size = (1, args.img_size, args.img_size)
241 |
242 | elif args.dataset == 'retinamnist':
243 | from medmnist import RetinaMNIST
244 | transforms = T.Compose([
245 | T.ToTensor(),
246 | ])
247 | train_set = RetinaMNIST(split='train', transform=transforms, download='True', size=args.img_size)
248 | val_set = RetinaMNIST(split='val', transform=transforms, download='True', size=args.img_size)
249 | test_set = RetinaMNIST(split='test', transform=transforms, download='True', size=args.img_size)
250 |
251 | print(f'Training set containing {len(train_set)} images.')
252 | print(f'Validation set containing {len(val_set)} images.')
253 | print(f'Test set containing {len(test_set)} images.')
254 |
255 | args.data_type = 'img'
256 | args.in_size, args.out_size = 2, 3
257 | args.data_size = (3, args.img_size, args.img_size)
258 |
259 | elif args.dataset == 'dermamnist':
260 | from medmnist import DermaMNIST
261 | transforms = T.Compose([
262 | T.ToTensor(),
263 | ])
264 | train_set = DermaMNIST(split='train', transform=transforms, download='True', size=args.img_size)
265 | val_set = DermaMNIST(split='val', transform=transforms, download='True', size=args.img_size)
266 | test_set = DermaMNIST(split='test', transform=transforms, download='True', size=args.img_size)
267 |
268 | print(f'Training set containing {len(train_set)} images.')
269 | print(f'Validation set containing {len(val_set)} images.')
270 | print(f'Test set containing {len(test_set)} images.')
271 |
272 | args.data_type = 'img'
273 | args.in_size, args.out_size = 2, 3
274 | args.data_size = (3, args.img_size, args.img_size)
275 |
276 | elif args.dataset == 'octmnist':
277 | from medmnist import OCTMNIST
278 | transforms = T.Compose([
279 | T.ToTensor(),
280 | T.Grayscale()
281 | ])
282 | train_set = OCTMNIST(split='train', transform=transforms, download='True', size=args.img_size)
283 | val_set = OCTMNIST(split='val', transform=transforms, download='True', size=args.img_size)
284 | test_set = OCTMNIST(split='test', transform=transforms, download='True', size=args.img_size)
285 |
286 | print(f'Training set containing {len(train_set)} images.')
287 | print(f'Validation set containing {len(val_set)} images.')
288 | print(f'Test set containing {len(test_set)} images.')
289 |
290 | args.data_type = 'img'
291 | args.in_size, args.out_size = 2, 1
292 | args.data_size = (1, args.img_size, args.img_size)
293 |
294 | elif args.dataset == 'pathmnist':
295 | from medmnist import PathMNIST
296 | transforms = T.Compose([
297 | T.ToTensor(),
298 | ])
299 | train_set = PathMNIST(split='train', transform=transforms, download='True', size=args.img_size)
300 | val_set = PathMNIST(split='val', transform=transforms, download='True', size=args.img_size)
301 | test_set = PathMNIST(split='test', transform=transforms, download='True', size=args.img_size)
302 |
303 | print(f'Training set containing {len(train_set)} images.')
304 | print(f'Validation set containing {len(val_set)} images.')
305 | print(f'Test set containing {len(test_set)} images.')
306 |
307 | args.data_type = 'img'
308 | args.in_size, args.out_size = 2, 3
309 | args.data_size = (3, args.img_size, args.img_size)
310 |
311 | elif args.dataset == 'tissuemnist':
312 | from medmnist import TissueMNIST
313 | transforms = T.Compose([
314 | T.ToTensor(),
315 | T.Grayscale()
316 | ])
317 | train_set = TissueMNIST(split='train', transform=transforms, download='True', size=args.img_size)
318 | val_set = TissueMNIST(split='val', transform=transforms, download='True', size=args.img_size)
319 | test_set = TissueMNIST(split='test', transform=transforms, download='True', size=args.img_size)
320 |
321 | print(f'Training set containing {len(train_set)} images.')
322 | print(f'Validation set containing {len(val_set)} images.')
323 | print(f'Test set containing {len(test_set)} images.')
324 |
325 | args.data_type = 'img'
326 | args.in_size, args.out_size = 2, 1
327 | args.data_size = (1, args.img_size, args.img_size)
328 |
329 | #############################################
330 | ############# 3D Image Datasets #############
331 | #############################################
332 | elif args.dataset == 'nodulemnist':
333 | from medmnist import NoduleMNIST3D
334 | train_set = NoduleMNIST3D(split='train', download='True', size=args.img_size)
335 | val_set = NoduleMNIST3D(split='val', download='True', size=args.img_size)
336 | test_set = NoduleMNIST3D(split='test', download='True', size=args.img_size)
337 |
338 | print(f'Training set containing {len(train_set)} images.')
339 | print(f'Validation set containing {len(val_set)} images.')
340 | print(f'Test set containing {len(test_set)} images.')
341 |
342 | args.data_type = 'img3d'
343 | args.in_size, args.out_size = 3, 1
344 | args.data_size = (1, args.img_size, args.img_size, args.img_size)
345 |
346 | elif args.dataset == 'organmnist':
347 | from medmnist import OrganMNIST3D
348 | train_set = OrganMNIST3D(split='train', download='True', size=args.img_size)
349 | val_set = OrganMNIST3D(split='val', download='True', size=args.img_size)
350 | test_set = OrganMNIST3D(split='test', download='True', size=args.img_size)
351 |
352 | print(f'Training set containing {len(train_set)} images.')
353 | print(f'Validation set containing {len(val_set)} images.')
354 | print(f'Test set containing {len(test_set)} images.')
355 |
356 | args.data_type = 'img3d'
357 | args.in_size, args.out_size = 3, 1
358 | args.data_size = (1, args.img_size, args.img_size, args.img_size)
359 |
360 | elif args.dataset == 'vesselmnist':
361 | from medmnist import VesselMNIST3D
362 | train_set = VesselMNIST3D(split='train', download='True', size=args.img_size)
363 | val_set = VesselMNIST3D(split='val', download='True', size=args.img_size)
364 | test_set = VesselMNIST3D(split='test', download='True', size=args.img_size)
365 |
366 | print(f'Training set containing {len(train_set)} images.')
367 | print(f'Validation set containing {len(val_set)} images.')
368 | print(f'Test set containing {len(test_set)} images.')
369 |
370 | args.data_type = 'shape3d'
371 | args.in_size, args.out_size = 3, 1
372 | args.data_size = (1, args.img_size, args.img_size, args.img_size)
373 |
374 | elif args.dataset == 'brats':
375 | dataset = BRATSVolumes('/raid/cian/user/paul.friedrich/datasets/BRATS2023-GLI/', img_size=args.img_size)
376 |
377 | # Define split sizes
378 | train_size = int(0.7 * len(dataset)) # 70% for training
379 | test_size = (len(dataset) - train_size) // 2 # 15% for testing
380 | val_size = len(dataset) - train_size - test_size # 15% for validation
381 |
382 | generator = torch.Generator().manual_seed(42)
383 | train_set, test_set, val_set = torch.utils.data.random_split(dataset, [train_size, test_size, val_size], generator=generator)
384 |
385 | print(f'Training set containing {len(train_set)} images.')
386 | print(f'Validation set containing {len(val_set)} images.')
387 | print(f'Test set containing {len(test_set)} images.')
388 |
389 | args.data_type = 'img3d'
390 | args.in_size, args.out_size = 3, 1
391 | args.data_size = (1, args.img_size, args.img_size, args.img_size)
392 |
393 | elif args.dataset == 'lidc-idri':
394 | dataset = LIDCVolumes('/raid/cian/user/paul.friedrich/datasets/lidc-nifti/', img_size=args.img_size)
395 |
396 | # Define split sizes
397 | train_size = int(0.7 * len(dataset)) # 70% for training
398 | test_size = (len(dataset) - train_size) // 2 # 15% for testing
399 | val_size = len(dataset) - train_size - test_size # 15% for validation
400 |
401 | generator = torch.Generator().manual_seed(42)
402 | train_set, test_set, val_set = torch.utils.data.random_split(dataset, [train_size, test_size, val_size], generator=generator)
403 |
404 | print(f'Training set containing {len(train_set)} images.')
405 | print(f'Validation set containing {len(val_set)} images.')
406 | print(f'Test set containing {len(test_set)} images.')
407 |
408 | args.data_type = 'img3d'
409 | args.in_size, args.out_size = 3, 1
410 | args.data_size = (1, args.img_size, args.img_size, args.img_size)
411 |
412 | #############################################
413 | ############ 2D+t Video Datasets ############
414 | #############################################
415 | elif args.dataset == 'echonet':
416 | train_set = EchoNet('/raid/cian/user/paul.friedrich/datasets/EchoNet/', split='TRAIN')
417 | val_set = EchoNet('/raid/cian/user/paul.friedrich/datasets/EchoNet/', split='VAL')
418 | test_set = EchoNet('/raid/cian/user/paul.friedrich/datasets/EchoNet/', split='TEST')
419 |
420 | print(f'Training set containing {len(train_set)} videos.')
421 | print(f'Validation set containing {len(val_set)} videos.')
422 | print(f'Test set containing {len(test_set)} videos.')
423 |
424 |         args.data_type = 'img3d'
425 | args.in_size, args.out_size = 3, 1
426 | args.data_size = (1, 10, args.img_size, args.img_size)
427 |
428 | #############################################
429 | ########## 1D Timeseries Datasets ###########
430 | #############################################
431 | elif args.dataset == 'ecg':
432 | train_set = ECG1D('/home/paul.friedrich/ecg_classification/', test=False)
433 | valtest_set = ECG1D('/home/paul.friedrich/ecg_classification/', test=True)
434 |
435 | test_size = len(valtest_set) // 2
436 | val_size = len(valtest_set) - test_size
437 |
438 | generator = torch.Generator().manual_seed(42)
439 | test_set, val_set = torch.utils.data.random_split(valtest_set, [test_size, val_size], generator=generator)
440 |
441 | print(f'Training set containing {len(train_set)} ECG signals.')
442 | print(f'Validation set containing {len(val_set)} ECG signals.')
443 | print(f'Test set containing {len(test_set)} ECG signals.')
444 |
445 | args.data_type = 'timeseries'
446 | args.in_size, args.out_size = 1, 1
447 | args.data_size = (1, args.img_size)
448 |
449 | else:
450 | raise NotImplementedError()
451 |
452 | if only_test:
453 | return test_set
454 |
455 | elif all:
456 | return train_set, val_set, test_set
457 |
458 | else:
459 | return train_set, val_set
460 |
--------------------------------------------------------------------------------
/common/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import shutil
4 | import time
5 | import pickle
6 | import numpy as np
7 | import random
8 | import torch
9 | import matplotlib.pyplot as plt
10 | import torch.distributed as dist
11 |
12 | from collections import OrderedDict, defaultdict, deque
13 | from datetime import datetime
14 | from torch.utils.tensorboard import SummaryWriter
15 |
16 |
17 | def set_random_seed(seed):
18 | random.seed(seed)
19 | np.random.seed(seed)
20 | torch.manual_seed(seed)
21 | torch.cuda.manual_seed(seed)
22 | torch.cuda.manual_seed_all(seed)
23 |
24 |
25 | def load_checkpoint(logdir, mode='last'):
26 | model_path = os.path.join(logdir, f'{mode}.model')
27 | optim_path = os.path.join(logdir, f'{mode}.optim')
28 | config_path = os.path.join(logdir, f'{mode}.configs')
29 | lr_path = os.path.join(logdir, f'{mode}.lr')
30 |
31 | print(model_path)
32 | print(optim_path)
33 |
34 | print("=> Loading checkpoint from '{}'".format(logdir))
35 | if os.path.exists(model_path):
36 | model_state = torch.load(model_path)
37 | optim_state = torch.load(optim_path)
38 | with open(config_path, 'rb') as handle:
39 | cfg = pickle.load(handle)
40 | else:
41 | return None, None, None, None
42 |
43 | if os.path.exists(lr_path):
44 | lr_dict = torch.load(lr_path)
45 | else:
46 | lr_dict = None
47 |
48 | return model_state, optim_state, cfg, lr_dict
49 |
50 |
51 | def save_checkpoint(args, step, best_psnr, model, optim_state, logdir, is_best=False, suffix=''):
52 | if is_best:
53 | prefix = 'best'
54 | else:
55 | prefix = 'last'
56 |
57 | model_state = model.state_dict()
58 |
59 | last_model = os.path.join(logdir, f'{prefix}{suffix}.model')
60 | last_optim = os.path.join(logdir, f'{prefix}{suffix}.optim')
61 | last_config = os.path.join(logdir, f'{prefix}{suffix}.configs')
62 |
63 | if isinstance(args.inner_lr, OrderedDict):
64 | last_lr = os.path.join(logdir, f'{prefix}{suffix}.lr')
65 | torch.save(args.inner_lr, last_lr)
66 | if hasattr(args, 'moving_average'):
67 | last_ema = os.path.join(logdir, f'{prefix}{suffix}.ema')
68 | torch.save(args.moving_average, last_ema)
69 | if hasattr(args, 'moving_inner_lr'):
70 | last_lr_ema = os.path.join(logdir, f'{prefix}{suffix}.lr_ema')
71 | torch.save(args.moving_inner_lr, last_lr_ema)
72 |
73 | opt = {
74 | 'step': step,
75 | 'best': best_psnr
76 | }
77 | torch.save(model_state, last_model)
78 | torch.save(optim_state, last_optim)
79 | with open(last_config, 'wb') as handle:
80 | pickle.dump(opt, handle, protocol=pickle.HIGHEST_PROTOCOL)
81 |
82 |
83 | def save_checkpoint_step(args, step, best_psnr, model, optim_state, logdir, suffix=''):
84 | model_state = model.state_dict()
85 |
86 | last_model = os.path.join(logdir, f'step{step}{suffix}.model')
87 | last_optim = os.path.join(logdir, f'step{step}{suffix}.optim')
88 | last_config = os.path.join(logdir, f'step{step}{suffix}.configs')
89 |
90 | if isinstance(args.inner_lr, OrderedDict):
91 | last_lr = os.path.join(logdir, f'step{step}{suffix}.lr')
92 | torch.save(args.inner_lr, last_lr)
93 | if hasattr(args, 'moving_average'):
94 | last_ema = os.path.join(logdir, f'step{step}{suffix}.ema')
95 | torch.save(args.moving_average, last_ema)
96 | if hasattr(args, 'moving_inner_lr'):
97 | last_lr_ema = os.path.join(logdir, f'step{step}{suffix}.lr_ema')
98 | torch.save(args.moving_inner_lr, last_lr_ema)
99 |
100 | opt = {
101 | 'step': step,
102 | 'best': best_psnr
103 | }
104 | torch.save(model_state, last_model)
105 | torch.save(optim_state, last_optim)
106 | with open(last_config, 'wb') as handle:
107 | pickle.dump(opt, handle, protocol=pickle.HIGHEST_PROTOCOL)
108 |
109 |
110 | def resume_training(args, model, optimizer):
111 | if args.resume_path is not None:
112 | model_state, optimizer_state, config, lr_dict = load_checkpoint(args.resume_path, mode='best')
113 | model.load_state_dict(model_state)
114 | optimizer.load_state_dict(optimizer_state)
115 | start_step = config['step']
116 | best_psnr = config['best']
117 | is_best = False
118 | psnr = 0.
119 |
120 | if lr_dict is not None:
121 | args.inner_lr = lr_dict
122 |
123 | else:
124 | is_best = False
125 | start_step = 1
126 | best_psnr = 0.
127 | psnr = 0.
128 | return is_best, start_step, best_psnr, psnr
129 |
130 |
131 | def is_dist_avail_and_initialized():
132 | if not dist.is_available():
133 | return False
134 | if not dist.is_initialized():
135 | return False
136 | return True
137 |
138 |
139 | class Logger(object):
140 | """
141 | Reference: https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
142 | """
143 |
144 | def __init__(self, fn, ask=True, today=True, rank=0, grid_search=False):
145 | self.rank = rank
146 | self.log_path = './logs/'
147 | self.logdir = None
148 |
149 | if grid_search:
150 | self.log_path = './logs_gridsearch/'
151 |
152 | if self.rank == 0:
153 | if not os.path.exists(self.log_path):
154 | os.mkdir(self.log_path)
155 | self.today = today
156 |
157 | logdir = self._make_dir(fn)
158 |
159 | if not os.path.exists(logdir):
160 | os.mkdir(logdir)
161 |
162 | if len(os.listdir(logdir)) != 0 and ask:
163 | ans = input("log_dir is not empty. All data inside log_dir will be deleted. "
164 | "Will you proceed [y/N]? ")
165 | if ans in ['y', 'Y']:
166 | shutil.rmtree(logdir)
167 | else:
168 | exit(1)
169 |
170 | self.set_dir(logdir)
171 |
172 | def _make_dir(self, fn):
173 | if self.today:
174 | today = datetime.today().strftime("%y%m%d")
175 | logdir = self.log_path + today + '_' + fn
176 | else:
177 | logdir = self.log_path + fn
178 | return logdir
179 |
180 | def set_dir(self, logdir, log_fn='log.txt'):
181 | self.logdir = logdir
182 | if not os.path.exists(logdir):
183 | os.mkdir(logdir)
184 | self.writer = SummaryWriter(logdir)
185 | self.log_file = open(os.path.join(logdir, log_fn), 'a')
186 |
187 | def close_writer(self):
188 | if self.rank == 0:
189 | self.writer.close()
190 |
191 | def log(self, string):
192 | if self.rank == 0:
193 | self.log_file.write('[%s] %s' % (datetime.now(), string) + '\n')
194 | self.log_file.flush()
195 |
196 | print('[%s] %s' % (datetime.now(), string))
197 | sys.stdout.flush()
198 |
199 | def log_dirname(self, string):
200 | if self.rank == 0:
201 | self.log_file.write('%s (%s)' % (string, self.logdir) + '\n')
202 | self.log_file.flush()
203 |
204 | print('%s (%s)' % (string, self.logdir))
205 | sys.stdout.flush()
206 |
207 | def save_df(self, dataframe):
208 | if self.rank == 0:
209 | filename = self.logdir + '/ww.csv'
210 | print(filename)
211 | dataframe.to_csv(filename, sep='\t', header=True)
212 |
213 | def scalar_summary(self, tag, value, step):
214 | """Log a scalar variable."""
215 | if self.rank == 0:
216 | self.writer.add_scalar(tag, value, step)
217 |
218 | def log_hist(self, tag, value, step):
219 | if self.rank == 0:
220 | self.writer.add_histogram(tag, value, step)
221 |
222 | def log_hyperparameters(self, args):
223 | if self.rank == 0:
224 | self.writer.add_text(
225 | 'config',
226 |                 '\n'.join([f'--{k}={repr(v)}' for k, v in vars(args).items()])
227 | )
228 |
229 | def log_hparams(self, h_dict, m_dict):
230 | if self.rank == 0:
231 | self.writer.add_hparams(h_dict, m_dict, run_name='.')
232 |
233 | def log_image(self, tag, images, step):
234 | """Log an image tensor."""
235 | if self.rank == 0:
236 | if len(images.shape) == 3: # Timeseries
237 | x = torch.arange(1, images.shape[2]+1).numpy()
238 | plt.figure(figsize=(10, 6))
239 |                 for i in range(min(6, images.shape[0])):  # guard against batches smaller than 6
240 |                     y = images[i, 0, :].detach().cpu().numpy()
241 |                     plt.plot(x, y, label=f"ECG {i+1}")
242 |                 plt.ylabel("Signal Value")
243 |                 plt.grid(True)
244 |                 plt.legend()
245 |                 plt.tight_layout()
246 |                 self.writer.add_figure(tag, plt.gcf(), step)
247 |                 plt.close()  # release the figure once it has been logged
248 | if len(images.shape) == 4: # 2D Images
249 | self.writer.add_images(tag, images, step)
250 |
251 | if len(images.shape) == 5: # 3D Images
252 | # Log middle slices along all 3 dimensions
253 | batch_size, channels, depth, height, width = images.shape
254 |
255 | # Select the middle slices
256 | middle_depth = depth // 2
257 | middle_height = height // 2
258 | middle_width = width // 2
259 |
260 | # Extract middle slices along each axis
261 | slices_depth = images[:, :, middle_depth, :, :] # Middle slice along depth
262 | slices_height = images[:, :, :, middle_height, :] # Middle slice along height
263 | slices_width = images[:, :, :, :, middle_width] # Middle slice along width
264 |
265 | # Log slices with meaningful tags
266 | self.writer.add_images(f"{tag}_slice_depth", slices_depth, step)
267 | self.writer.add_images(f"{tag}_slice_height", slices_height, step)
268 | self.writer.add_images(f"{tag}_slice_width", slices_width, step)
269 |
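# A minimal usage sketch for Logger (run name and values are hypothetical;
# only rank 0 writes to TensorBoard and log.txt, every rank prints):
#
#   logger = Logger(fn='demo_run', ask=False, rank=0)
#   logger.log('training started')                       # timestamped line in log.txt
#   logger.scalar_summary('train/psnr', 32.5, step=100)
#   logger.close_writer()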
270 |
271 | class SmoothedValue(object):
272 | """
273 | Track a series of values and provide access to smoothed values over a
274 | window or the global series average.
275 | """
276 |
277 | def __init__(self, window_size=20, fmt=None):
278 | if fmt is None:
279 | fmt = "{median:.4f} ({global_avg:.4f})"
280 | self.deque = deque(maxlen=window_size)
281 | self.total = 0.0
282 | self.count = 0
283 | self.fmt = fmt
284 |
285 | def update(self, value, n=1):
286 | self.deque.append(value)
287 | self.count += n
288 | self.total += value * n
289 |
290 | def reset(self):
291 | self.deque.clear()
292 | self.total = 0.0
293 | self.count = 0
294 |
295 | def synchronize_between_processes(self):
296 | """
297 | Warning: does not synchronize the deque!
298 | """
299 | if not is_dist_avail_and_initialized():
300 | return
301 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
302 | dist.barrier()
303 | dist.all_reduce(t)
304 | t = t.tolist()
305 | self.count = int(t[0])
306 | self.total = t[1]
307 |
308 | @property
309 | def median(self):
310 | d = torch.tensor(list(self.deque))
311 | return d.median().item()
312 |
313 | @property
314 | def avg(self):
315 | d = torch.tensor(list(self.deque), dtype=torch.float32)
316 | return d.mean().item()
317 |
318 | @property
319 | def global_avg(self):
320 | return self.total / self.count
321 |
322 | @property
323 | def max(self):
324 | return max(self.deque)
325 |
326 | @property
327 | def value(self):
328 | return self.deque[-1]
329 |
330 | def __str__(self):
331 | return self.fmt.format(
332 | median=self.median,
333 | avg=self.avg,
334 | global_avg=self.global_avg,
335 | max=self.max,
336 | value=self.value)
337 |
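# A minimal usage sketch for SmoothedValue (values are hypothetical):
#
#   meter = SmoothedValue(window_size=20)
#   for loss in (0.9, 0.7, 0.6):
#       meter.update(loss)
#   str(meter)   # '0.7000 (0.7333)' -> windowed median (global average)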
338 |
339 | class MetricLogger(object):
340 | def __init__(self, delimiter="\t"):
341 | self.meters = defaultdict(SmoothedValue)
342 | self.delimiter = delimiter
343 |
344 | def update(self, **kwargs):
345 | for k, v in kwargs.items():
346 | if v is None:
347 | continue
348 | if isinstance(v, torch.Tensor):
349 | v = v.item()
350 | assert isinstance(v, (float, int))
351 | self.meters[k].update(v)
352 |
353 | def __getattr__(self, attr):
354 | if attr in self.meters:
355 | return self.meters[attr]
356 | if attr in self.__dict__:
357 | return self.__dict__[attr]
358 | raise AttributeError("'{}' object has no attribute '{}'".format(
359 | type(self).__name__, attr))
360 |
361 | def __str__(self):
362 | loss_str = []
363 | for name, meter in self.meters.items():
364 | loss_str.append(
365 | "{}: {}".format(name, str(meter))
366 | )
367 | return self.delimiter.join(loss_str)
368 |
369 | def synchronize_between_processes(self):
370 | for meter in self.meters.values():
371 | meter.synchronize_between_processes()
372 |
373 | def add_meter(self, name, meter):
374 | self.meters[name] = meter
375 |
376 | def reset(self):
377 | for meter in self.meters.values():
378 | meter.reset()
379 |
380 | def log_every(self, iterable, print_freq, header=None):
381 | i = 0
382 | if not header:
383 | header = ''
384 | start_time = time.time()
385 | end = time.time()
386 | iter_time = SmoothedValue(fmt='{avg:.4f}')
387 | data_time = SmoothedValue(fmt='{avg:.4f}')
388 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
389 | log_msg = [
390 | header,
391 | '[{0' + space_fmt + '}/{1}]',
392 | 'eta: {eta}',
393 | '{meters}',
394 | 'time: {time}',
395 | 'data: {data}'
396 | ]
397 | if torch.cuda.is_available():
398 | log_msg.append('max mem: {memory:.0f}')
399 | log_msg = self.delimiter.join(log_msg)
400 | MB = 1024.0 * 1024.0
401 | for obj in iterable:
402 | data_time.update(time.time() - end)
403 | yield obj
404 | iter_time.update(time.time() - end)
405 | if i % print_freq == 0 or i == len(iterable) - 1:
406 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
407 |                 eta_string = str(timedelta(seconds=int(eta_seconds)))  # requires: from datetime import timedelta
408 | if torch.cuda.is_available():
409 | print(log_msg.format(
410 | i, len(iterable), eta=eta_string,
411 | meters=str(self),
412 | time=str(iter_time), data=str(data_time),
413 | memory=torch.cuda.max_memory_allocated() / MB))
414 | else:
415 | print(log_msg.format(
416 | i, len(iterable), eta=eta_string,
417 | meters=str(self),
418 | time=str(iter_time), data=str(data_time)))
419 | i += 1
420 | end = time.time()
421 | total_time = time.time() - start_time
422 |         total_time_str = str(timedelta(seconds=int(total_time)))
423 | print('{} Total time: {} ({:.4f} s / it)'.format(
424 | header, total_time_str, total_time / len(iterable)))
425 |
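# A minimal usage sketch for MetricLogger (`train_loader` and `train_step`
# are hypothetical):
#
#   metric_logger = MetricLogger(delimiter='  ')
#   for batch in metric_logger.log_every(train_loader, print_freq=50, header='Epoch 0'):
#       loss = train_step(batch)
#       metric_logger.update(loss=loss)
#   metric_logger.synchronize_between_processes()   # aggregate counts across ranks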
426 |
427 | def psnr(mse):
428 |     return -10.0 * torch.log10(mse + 1e-24)  # PSNR in dB; assumes signals scaled to [0, 1]
429 |
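# Worked example: for signals in [0, 1], an MSE of 1e-3 yields
# psnr(torch.tensor(1e-3)) = -10 * log10(1e-3 + 1e-24) ≈ 30.0 dB.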
430 | def dice_score(pred, target, epsilon=1e-6):
431 | """
432 | Computes the Dice score between two batched image tensors.
433 | Args:
434 | pred (torch.Tensor): Tensor of shape (B, ...) containing predicted masks.
435 | target (torch.Tensor): Tensor of shape (B, ...) containing ground truth masks.
436 | epsilon (float): Small constant to avoid division by zero.
437 | Returns:
438 | torch.Tensor: Dice score for each item in the batch (shape: [B])
439 | """
440 | # Ensure the input shapes are the same
441 | if pred.shape != target.shape:
442 | raise ValueError(f"Shape mismatch: pred {pred.shape}, target {target.shape}")
443 |
444 | # Flatten each spatial dimension per batch item
445 | B = pred.shape[0]
446 | pred_flat = pred.view(B, -1)
447 | target_flat = target.view(B, -1)
448 |
449 | # Compute Dice coefficient
450 | intersection = (pred_flat * target_flat).sum(dim=1)
451 | union = pred_flat.sum(dim=1) + target_flat.sum(dim=1)
452 |
453 | dice = (2.0 * intersection + epsilon) / (union + epsilon)
454 | return dice
455 |
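# A minimal usage sketch for dice_score with binary masks (values are hypothetical):
#
#   pred   = torch.tensor([[[1., 1.], [0., 0.]]])   # shape (B=1, 2, 2)
#   target = torch.tensor([[[1., 0.], [0., 0.]]])
#   dice_score(pred, target)   # (2*1 + eps) / (2 + 1 + eps) ≈ 0.6667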
456 |
457 | def _gaussian_window(window_size: int, sigma: float, device: torch.device, dtype: torch.dtype):
458 | coords = torch.arange(window_size, device=device, dtype=dtype) - window_size // 2
459 | g = torch.exp(-(coords**2) / (2 * sigma**2))
460 | g /= g.sum()
461 | return g.view(1, 1, -1)
462 |
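# Worked example: _gaussian_window(5, 1.5, device, dtype) returns a kernel of
# shape (1, 1, 5) whose entries sum to 1, shaped for use with F.conv1d below.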
463 |
464 | def ssim_1d(x: torch.Tensor, y: torch.Tensor, window_size: int = 11, sigma: float = 1.5,
465 |             data_range: float | None = None, K1: float = 0.01, K2: float = 0.03) -> torch.Tensor:
466 | """
467 | Compute Structural Similarity Index (SSIM) for 1D time series.
468 | Parameters
469 | ----------
470 | x, y : torch.Tensor
471 | Input 1D tensors (same length) or 2D (batch, length).
472 | window_size : int
473 | Size of sliding window (odd number).
474 | sigma : float
475 | Gaussian kernel standard deviation for local statistics.
476 | data_range : float
477 | Value range of the input (max - min). If None, inferred from data.
478 | K1, K2 : float
479 | Constants for stability in SSIM formula.
480 | Returns
481 | -------
482 | ssim : torch.Tensor
483 | Mean SSIM over the signal (scalar).
484 | """
485 | if x.shape != y.shape:
486 | raise ValueError("Input signals must have the same shape")
487 | if x.dim() == 1:
488 | x = x.unsqueeze(0).unsqueeze(0) # (1,1,L)
489 | y = y.unsqueeze(0).unsqueeze(0)
490 | elif x.dim() == 2:
491 | x = x.unsqueeze(1) # (B,1,L)
492 | y = y.unsqueeze(1)
493 | else:
494 | raise ValueError("Input tensors must be 1D or 2D (batch, length)")
495 | if data_range is None:
496 | data_range = torch.max(torch.cat([x.max().unsqueeze(0) - x.min().unsqueeze(0),
497 | y.max().unsqueeze(0) - y.min().unsqueeze(0)]))
498 | C1 = (K1 * data_range) ** 2
499 | C2 = (K2 * data_range) ** 2
500 | # Gaussian kernel
501 | window = _gaussian_window(window_size, sigma, x.device, x.dtype)
502 | # Local means
503 | mu_x = F.conv1d(x, window, padding=window_size//2)
504 | mu_y = F.conv1d(y, window, padding=window_size//2)
505 | mu_x_sq = mu_x ** 2
506 | mu_y_sq = mu_y ** 2
507 | mu_xy = mu_x * mu_y
508 | # Variances and covariance
509 | sigma_x_sq = F.conv1d(x * x, window, padding=window_size//2) - mu_x_sq
510 | sigma_y_sq = F.conv1d(y * y, window, padding=window_size//2) - mu_y_sq
511 | sigma_xy = F.conv1d(x * y, window, padding=window_size//2) - mu_xy
512 | # SSIM map
513 | numerator = (2 * mu_xy + C1) * (2 * sigma_xy + C2)
514 | denominator = (mu_x_sq + mu_y_sq + C1) * (sigma_x_sq + sigma_y_sq + C2)
515 | ssim_map = numerator / denominator
516 |     return ssim_map.mean()  # scalar mean SSIM, matching the documented return value
517 |
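# A minimal usage sketch for ssim_1d (random inputs are hypothetical; identical
# signals score exactly 1.0):
#
#   x = torch.rand(4, 187)                      # (batch, length), e.g. ECG beats
#   ssim_1d(x, x)                               # -> tensor(1.)
#   ssim_1d(x, x + 0.1 * torch.randn_like(x))   # < 1.0, degrades with noise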
518 |
519 | class InfiniteSampler(torch.utils.data.Sampler):
520 | """
521 | A PyTorch Sampler that provides an infinite stream of indices from the dataset,
522 | optionally shuffling and allowing distributed sampling across replicas.
523 | """
524 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
525 | # Ensure dataset and configuration are valid
526 | assert len(dataset) > 0
527 | assert num_replicas > 0
528 | assert 0 <= rank < num_replicas
529 | assert 0 <= window_size <= 1
530 |
531 | # Initialize base sampler and store parameters
532 | super().__init__(dataset)
533 | self.dataset = dataset
534 | self.rank = rank
535 | self.num_replicas = num_replicas
536 | self.shuffle = shuffle
537 | self.seed = seed
538 |         self.window_size = window_size  # fraction of the dataset forming the shuffle window
539 |
540 | def __iter__(self):
541 | # Generate a sequence of indices corresponding to the dataset
542 | order = np.arange(len(self.dataset))
543 |
544 | # Initialize random number generator and window size for shuffling
545 | rnd = None
546 | window = 0
547 | if self.shuffle:
548 | # Shuffle the dataset indices
549 | rnd = np.random.RandomState(self.seed)
550 | rnd.shuffle(order)
551 | window = int(np.rint(order.size * self.window_size))
552 |
553 | # Start iterating over the dataset
554 | idx = 0
555 | while True:
556 | i = idx % order.size
557 | if idx % self.num_replicas == self.rank:
558 | yield order[i]
559 | if window >= 2:
560 | j = (i - rnd.randint(window)) % order.size
561 | order[i], order[j] = order[j], order[i]
562 | idx += 1
563 |
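# A minimal usage sketch for InfiniteSampler (`dataset` is hypothetical;
# itertools.islice bounds the otherwise endless index stream):
#
#   import itertools
#   sampler = InfiniteSampler(dataset, rank=0, num_replicas=1, shuffle=True)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=64, sampler=sampler)
#   for batch in itertools.islice(loader, 1000):   # take exactly 1000 batches
#       ...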
564 |
565 | def load_model(args, model, logger=None):
566 | if logger is None:
567 | log_ = print
568 | else:
569 | log_ = logger.log
570 |
571 | if args.load_path is not None:
572 | log_(f'Load model from {args.load_path}')
573 | checkpoint = torch.load(args.load_path, weights_only=True)
574 |
575 | not_loaded = model.load_state_dict(checkpoint)
576 |         log_(f'State-dict keys not loaded: {not_loaded}')
577 |
578 |         if os.path.exists(args.load_path[:-5] + 'lr'):  # Meta-SGD; guarded by the load_path check above
579 |             log_(f'Load lr from {args.load_path[:-5]}lr')
580 |             lr = torch.load(args.load_path[:-5] + 'lr', weights_only=True)
581 |             # Tensor.to() is not in-place; rebuild the dict on args.device
582 |             lr = {k: v.to(args.device) for k, v in lr.items()}
583 |             args.inner_lr = lr
584 |
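# A minimal usage sketch for load_model (checkpoint path is hypothetical; the
# '.lr' companion file is only present for Meta-SGD checkpoints):
#
#   args.load_path = './logs/250101_run/best.model'
#   load_model(args, model, logger=logger)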
--------------------------------------------------------------------------------