├── common
│   ├── __init__.py
│   ├── w0_utils.py
│   ├── args.py
│   └── utils.py
├── data
│   ├── __init__.py
│   ├── preproc_lidc-idri.py
│   └── dataset.py
├── eval
│   ├── __init__.py
│   ├── maml_full_fit.py
│   ├── maml_full_eval.py
│   └── maml_scale.py
├── models
│   ├── __init__.py
│   ├── layers.py
│   ├── inrs.py
│   └── model_wrapper.py
├── train
│   ├── __init__.py
│   ├── trainer.py
│   └── maml_boot.py
├── assets
│   ├── MedNF.png
│   └── overview.png
├── environment.yaml
├── configs
│   ├── experiments
│   │   ├── 1d_timeseries
│   │   │   └── default.yaml
│   │   ├── 3d_imgs
│   │   │   └── default.yaml
│   │   └── 2d_imgs
│   │       ├── default_224.yaml
│   │       ├── default_128.yaml
│   │       └── default_64.yaml
│   ├── eval
│   │   ├── 1d_timeseries
│   │   │   └── default.yaml
│   │   ├── 3d_imgs
│   │   │   └── default.yaml
│   │   └── 2d_imgs
│   │       ├── default_128.yaml
│   │       ├── default_224.yaml
│   │       └── default_64.yaml
│   └── fit
│       └── default_64.yaml
├── LICENSE
├── eval.py
├── .gitignore
├── fit_NFset.py
├── train.py
├── README.md
└── downstream_tasks
    └── classification.py

/common/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/eval/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/train/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/assets/MedNF.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfriedri/medfuncta/HEAD/assets/MedNF.png
--------------------------------------------------------------------------------
/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfriedri/medfuncta/HEAD/assets/overview.png
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: medfuncta
2 | channels:
3 |   - pytorch
4 |   - conda-forge
5 | dependencies:
6 |   - python=3.10
7 |   - pytorch=2.5.1
8 |   - torchvision=0.20.1
9 |   - numpy=2.1.3
10 |   - matplotlib=3.10.1
11 |   - pip
12 |   - pip:
13 |     - tensorboard
14 |     - pandas
15 |     - nibabel
16 |     - einops
17 |     - medmnist
18 |     - dicom2nifti
19 |     - pytorch_msssim
20 |     - lpips
21 |     - torchmetrics
22 |     - opencv-python
--------------------------------------------------------------------------------
/configs/experiments/1d_timeseries/default.yaml:
--------------------------------------------------------------------------------
1 | {
2 | # Training configuration
3 | 'dataset': ecg,
4 | 'img_size': 187,
5 | 'batch_size': 64,
6 | 'inner_steps': 10,
7 | 'max_iter': 250000,
8 | 'lr_scheduler': True,
9 | 
10 | # Context reduction
11 | 'data_ratio': 1.0,
12 | 'sample_type': random,
13 | 
14 | # Omega schedule
15 | 'w0_sched_type': linear,
16 | 'w0': 20,
17 | 'wK': 200,
18 | 
19 | # Test configuration
20 | 'test_batch_size': 64,
21 | 'num_test_signals': 256,
22 | 'inner_steps_test': 20,
23 | 
24 | # Model config
25 | 'min_hidden_dim': 64,
26 | 'max_hidden_dim': 64,
27 | 'progression_type': linear,
28 | 'num_layers': 8,
29 | 'latent_modulation_dim': 64,
30 | }
--------------------------------------------------------------------------------
/configs/experiments/3d_imgs/default.yaml:
--------------------------------------------------------------------------------
1 | {
2 | # Training configuration
3 | 'task': rec,
4 | 'dataset': brats,
5 | 'img_size': 32,
6 | 'batch_size': 4,
7 | 'inner_steps': 10,
8 | 'max_iter': 250000,
9 | 'lr_scheduler': True,
10 | 
11 | # Context reduction
12 | 'data_ratio': 0.25,
13 | 'sample_type': random,
14 | 
15 | # Omega schedule
16 | 'w0_sched_type': linear,
17 | 'w0': 20,
18 | 'wK': 300,
19 | 
20 | # Test configuration
21 | 'test_batch_size': 4,
22 | 'num_test_signals': 64,
23 | 'inner_steps_test': 20,
24 | 
25 | # Model config
26 | 'min_hidden_dim': 256,
27 | 'max_hidden_dim': 256,
28 | 'progression_type': linear,
29 | 'num_layers': 15,
30 | 'latent_modulation_dim': 8192,
31 | }
--------------------------------------------------------------------------------
/configs/experiments/2d_imgs/default_224.yaml:
--------------------------------------------------------------------------------
1 | {
2 | # Training configuration
3 | 'task': rec,
4 | 'dataset': chestmnist,
5 | 'img_size': 224,
6 | 'batch_size': 4,
7 | 'inner_steps': 10,
8 | 'max_iter': 250000,
9 | 'lr_scheduler': True,
10 | 
11 | # Context reduction
12 | 'data_ratio': 0.1,
13 | 'sample_type': random,
14 | 
15 | # Omega schedule
16 | 'w0_sched_type': linear,
17 | 'w0': 30,
18 | 'wK': 300,
19 | 
20 | # Test configuration
21 | 'test_batch_size': 4,
22 | 'num_test_signals': 256,
23 | 'inner_steps_test': 20,
24 | 
25 | # Model config
26 | 'min_hidden_dim': 256,
27 | 'max_hidden_dim': 256,
28 | 'progression_type': linear,
29 | 'num_layers': 15,
30 | 'latent_modulation_dim': 16384,
31 | }
--------------------------------------------------------------------------------
/configs/experiments/2d_imgs/default_128.yaml:
--------------------------------------------------------------------------------
1 | {
2 | # Training configuration
3 | 'task': rec,
4 | 'dataset': chestmnist,
5 | 'img_size': 128,
6 | 'batch_size': 8,
7 | 'inner_steps': 10,
8 | 'max_iter': 250000,
9 | 'lr_scheduler': True,
10 | 
11 | # Context reduction
12 | 'data_ratio': 0.25,
13 | 'sample_type': random,
14 | 
15 | # Omega schedule
16 | 'w0_sched_type': linear,
17 | 'w0': 30,
18 | 'wK': 300,
19 | 
20 | # Test configuration
21 | 'test_batch_size': 8,
22 | 'num_test_signals': 256,
23 | 'inner_steps_test': 20,
24 | 
25 | # Model config
26 | 'min_hidden_dim': 256,
27 | 'max_hidden_dim': 256,
28 | 'progression_type': linear,
29 | 'num_layers': 15,
30 | 'latent_modulation_dim': 8192,
31 | }
--------------------------------------------------------------------------------
/configs/experiments/2d_imgs/default_64.yaml:
--------------------------------------------------------------------------------
1 | {
2 | # Training configuration
3 | 'task': rec,
4 | 'dataset': chestmnist,
5 | 'img_size': 64,
6 | 'batch_size': 24,
7 | 'inner_steps': 10,
8 | 'max_iter': 250000,
9 | 'lr_scheduler': True,
10 | 
11 | # Context reduction
12 | 'data_ratio': 0.25,
13 | 'sample_type': random,
14 | 
15 | # Omega schedule
16 | 'w0_sched_type': linear,
17 | 'w0': 20,
18 | 'wK': 400,
19 | 
20 | # Test configuration
21 | 'test_batch_size': 24,
22 | 'num_test_signals': 256,
23 | 'inner_steps_test': 20,
24 | 
25 | # Model config
26 | 'min_hidden_dim': 256,
27 | 'max_hidden_dim': 256,
28 | 'progression_type': linear,
29 | 'num_layers': 15,
30 | 'latent_modulation_dim': 2048,
31 | }
--------------------------------------------------------------------------------
/configs/eval/1d_timeseries/default.yaml:
--------------------------------------------------------------------------------
1 | {
2 | # Training configuration
3 | 'dataset': ecg,
4 | 'img_size': 187,
5 | 'batch_size': 64,
6 | 'inner_steps': 10,
7 | 'max_iter': 250000,
8 | 'lr_scheduler': True,
9 | 
10 | # Context reduction
11 | 'data_ratio': 1.0,
12 | 'sample_type': random,
13 | 
14 | # Omega schedule
15 | 'w0_sched_type': linear,
16 | 'w0': 20,
17 | 'wK': 200,
18 | 
19 | # Test configuration
20 | 'test_batch_size': 64,
21 | 'num_test_signals': 256,
22 | 'inner_steps_test': 20,
23 | 
24 | # Model config
25 | 'min_hidden_dim': 64,
26 | 'max_hidden_dim': 64,
27 | 'progression_type': linear,
28 | 'num_layers': 8,
29 | 'latent_modulation_dim': 64,
30 | 
31 | # Load model from checkpoint
32 | 'load_path': /path/to/best.model or /path/to/stepXXXX.model,
33 | }
--------------------------------------------------------------------------------
/configs/eval/3d_imgs/default.yaml:
--------------------------------------------------------------------------------
1 | {
2 | # Training configuration
3 | 'dataset': brats,
4 | 'img_size': 32,
5 | 'batch_size': 4,
6 | 'inner_steps': 10,
7 | 'max_iter': 250000,
8 | 'lr_scheduler': True,
9 | 
10 | # Context reduction
11 | 'data_ratio': 0.25,
12 | 'sample_type': random,
13 | 
14 | # Omega schedule
15 | 'w0_sched_type': linear,
16 | 'w0': 20,
17 | 'wK': 300,
18 | 
19 | # Test configuration
20 | 'test_batch_size': 4,
21 | 'num_test_signals': 64,
22 | 'inner_steps_test': 20,
23 | 
24 | # Model config
25 | 'min_hidden_dim': 256,
26 | 'max_hidden_dim': 256,
27 | 'progression_type': linear,
28 | 'num_layers': 15,
29 | 'latent_modulation_dim': 8192,
30 | 
31 | # Load model from checkpoint
32 | 'load_path': /path/to/best.model or /path/to/stepXXXX.model,
33 | }
--------------------------------------------------------------------------------
/configs/eval/2d_imgs/default_128.yaml:
--------------------------------------------------------------------------------
1 | {
2 | # Training configuration
3 | 'dataset': chestmnist,
4 | 'img_size': 128,
5 | 'batch_size': 8,
6 | 'inner_steps': 10,
7 | 'max_iter': 250000,
8 | 'lr_scheduler': True,
9 | 
10 | # Context reduction
11 | 'data_ratio': 0.25,
12 | 'sample_type': random,
13 | 
14 | # Omega schedule
15 | 'w0_sched_type': linear,
16 | 'w0': 30,
17 | 'wK': 300,
18 | 
19 | # Test configuration
20 | 'test_batch_size': 8,
21 | 'num_test_signals': 256,
22 | 'inner_steps_test': 20,
23 | 
24 | # Model config
25 | 'min_hidden_dim': 256,
26 | 'max_hidden_dim': 256,
27 | 'progression_type': linear,
28 | 'num_layers': 15,
29 | 'latent_modulation_dim': 8192,
30 | 
31 | # Load model from checkpoint
32 | 'load_path': /path/to/best.model or /path/to/stepXXXX.model,
33 | }
--------------------------------------------------------------------------------
/configs/eval/2d_imgs/default_224.yaml:
--------------------------------------------------------------------------------
1 | {
2 | # Training configuration
3 | 'dataset': chestmnist,
4 | 'img_size': 224,
5 | 'batch_size': 4,
6 | 'inner_steps': 10,
7 | 'max_iter': 250000,
8 | 'lr_scheduler': True,
9 | 
10 | # Context reduction
11 | 'data_ratio': 0.1,
12 | 'sample_type': random,
13 | 
14 | # Omega schedule
15 | 'w0_sched_type': linear,
16 | 'w0': 30,
17 | 'wK': 300,
18 | 
19 | # Test configuration
20 | 'test_batch_size': 4,
21 | 
'num_test_signals': 256, 22 | 'inner_steps_test': 20, 23 | 24 | # Model config 25 | 'min_hidden_dim': 256, 26 | 'max_hidden_dim': 256, 27 | 'progression_type': linear, 28 | 'num_layers': 15, 29 | 'latent_modulation_dim': 16384, 30 | 31 | # Load model from checkpoint 32 | 'load_path': /path/to/best.model or /path/to/stepXXXX.model, 33 | } -------------------------------------------------------------------------------- /configs/eval/2d_imgs/default_64.yaml: -------------------------------------------------------------------------------- 1 | { 2 | # Training configuration 3 | 'dataset': chestmnist, 4 | 'img_size': 64, 5 | 'batch_size': 24, 6 | 'inner_steps': 10, 7 | 'max_iter': 250000, 8 | 'lr_scheduler': True, 9 | 10 | # Context reduction 11 | 'data_ratio': 0.25, 12 | 'sample_type': random, 13 | 14 | # Omega schedule 15 | 'w0_sched_type': linear, 16 | 'w0': 20, 17 | 'wK': 400, 18 | 19 | # Test configuration 20 | 'test_batch_size': 24, 21 | 'num_test_signals': 256, 22 | 'inner_steps_test': 20, 23 | 24 | # Model config 25 | 'min_hidden_dim': 256, 26 | 'max_hidden_dim': 256, 27 | 'progression_type': linear, 28 | 'num_layers': 15, 29 | 'latent_modulation_dim': 2048, 30 | 31 | # Load model from checkpoint 32 | 'load_path': /path/to/best.model or /path/to/stepXXXX.model, 33 | } -------------------------------------------------------------------------------- /configs/fit/default_64.yaml: -------------------------------------------------------------------------------- 1 | { 2 | # Training configuration 3 | 'task': rec, 4 | 'dataset': chestmnist, 5 | 'img_size': 64, 6 | 'batch_size': 1, 7 | 'inner_steps': 10, 8 | 'max_iter': 250000, 9 | 'lr_scheduler': True, 10 | 11 | # Context reduction 12 | 'data_ratio': 0.25, 13 | 'sample_type': random, 14 | 15 | # Omega schedule 16 | 'w0_sched_type': linear, 17 | 'w0': 20, 18 | 'wK': 400, 19 | 20 | # Test configuration 21 | 'test_batch_size': 1, 22 | 'num_test_signals': 256, 23 | 'inner_steps_test': 20, 24 | 25 | # Model config 26 | 'min_hidden_dim': 256, 27 | 'max_hidden_dim': 256, 28 | 'progression_type': linear, 29 | 'num_layers': 15, 30 | 'latent_modulation_dim': 2048, 31 | 32 | # Load model from checkpoint 33 | 'load_path': /path/to/best.model or /path/to/stepXXXX.model, 34 | 35 | #Save NFs to 36 | 'save_dir': /path/to/save/medfuncta/set 37 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Paul Friedrich 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /common/w0_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import math 4 | from typing import Callable 5 | 6 | def get_w0s(args): 7 | sched_type = args.w0_sched_type 8 | num_layers = args.num_layers - 1 9 | device = args.device 10 | 11 | def linear_schedule(i): 12 | w0_0 = args.w0 13 | w0_K = args.wK 14 | a = (w0_K - w0_0) / (num_layers - 1) 15 | b = w0_0 16 | return a * i + b 17 | 18 | def exponential_schedule(i): 19 | w0_0 = args.w0 20 | w0_K = args.wK 21 | a = w0_0 22 | b = math.log(w0_K / w0_0) / (num_layers - 1) 23 | return a * torch.exp(b * i) 24 | 25 | def const_manual_schedule(i): 26 | return torch.tensor(args.w0, device=device) 27 | 28 | sched_map: dict[str, Callable[[torch.Tensor], torch.Tensor]] = { 29 | 'linear': linear_schedule, 30 | 'exponential': exponential_schedule, 31 | } 32 | 33 | w0_fn = sched_map.get(sched_type, const_manual_schedule) 34 | w0s = [w0_fn(torch.tensor(i, device=device)) for i in range(num_layers)] 35 | 36 | return w0s 37 | 38 | def save_w0s(w0s, logger): 39 | logdir = logger.logdir 40 | w0_path = os.path.join(logdir, 'w0s.sched') 41 | w0s = [t.cpu() for t in w0s] 42 | torch.save(w0s, w0_path) 43 | 44 | def load_w0s(args): 45 | load_dir = os.path.join(os.path.dirname(args.load_path), 'w0s.sched') 46 | w0s = torch.load(load_dir, weights_only=True) 47 | return w0s 48 | -------------------------------------------------------------------------------- /eval/maml_full_fit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | 4 | from tqdm import tqdm 5 | from train.maml_boot import inner_adapt_test_scale 6 | 7 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 8 | 9 | 10 | def fit_nfs(args, model_wrapper, dataloader, set=None): 11 | 12 | model_wrapper.model.eval() 13 | model_wrapper.coord_init() 14 | 15 | for n, data in enumerate(tqdm(dataloader)): 16 | data, label = data 17 | data = data.to(device) 18 | batch_size = data.size(0) 19 | model_wrapper.model.reset_modulations() 20 | 21 | _ = inner_adapt_test_scale(model_wrapper=model_wrapper, data=data, step_size=args.inner_lr, 22 | num_steps=args.inner_steps_test, first_order=True, sample_type=args.sample_type, 23 | scale_type='grad') 24 | 25 | if set == 'test': 26 | with torch.no_grad(): 27 | pred = model_wrapper().clamp(0, 1) 28 | if n < 100: 29 | # Convert to PIL image 30 | to_pil = transforms.ToPILImage() 31 | image = to_pil(data.squeeze()) 32 | input_path = args.save_dir + f'test/imgs/{n}_input.png' 33 | image.save(input_path) 34 | image = to_pil(pred.squeeze()) 35 | recon_path = args.save_dir + f'test/imgs/{n}_recon.png' 36 | image.save(recon_path) 37 | 38 | else: 39 | input('done') 40 | 41 | for i in range(batch_size): 42 | datapoint = { 43 | 'modulations': model_wrapper.model.modulations[i].detach().cpu(), 44 | 'label': label[i].detach().cpu() 45 | } 46 | sdir = args.save_dir + f'/{set}/' + f'datapoint_{(n * batch_size) + i}.pt' 47 | torch.save(datapoint, sdir) 48 | return 49 | 
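# --- Illustrative usage note (a sketch added for clarity, not part of the original file) ---
# The loop above serializes one dict per signal with the keys 'modulations' and 'label'.
# A downstream task (e.g. classification on the latent vectors) can reload such a file as
# sketched below; the path and filename are hypothetical examples:
#
#   dp = torch.load('/path/to/save/medfuncta/set/train/datapoint_0.pt')
#   z = dp['modulations']   # 1D latent vector of size latent_modulation_dim
#   y = dp['label']         # label copied from the source dataset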
--------------------------------------------------------------------------------
/models/layers.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | 
5 | 
6 | class LatentModulatedSIRENLayer(nn.Module):
7 |     def __init__(self, in_size, out_size, latent_modulation_dim=512, w0=30.,
8 |                  modulate_shift=True, modulate_scale=False, is_first=False, is_last=False):
9 |         super().__init__()
10 |         self.in_size = in_size
11 |         self.out_size = out_size
12 |         self.latent_modulation_dim = latent_modulation_dim
13 |         self.w0 = w0
14 |         self.modulate_shift = modulate_shift
15 |         self.modulate_scale = modulate_scale
16 |         self.is_first = is_first
17 |         self.is_last = is_last
18 | 
19 |         self.linear = nn.Linear(in_size, out_size, bias=True)
20 | 
21 |         if modulate_shift and not is_first and not is_last:
22 |             self.modulate_shift_layer = nn.Linear(latent_modulation_dim, out_size)
23 |         if modulate_scale and not is_first and not is_last:
24 |             self.modulate_scale_layer = nn.Linear(latent_modulation_dim, out_size)
25 | 
26 |         self._init(w0, is_first)
27 | 
28 |     def _init(self, w0, is_first):
29 |         dim_in = self.in_size
30 |         w_std = 1 / dim_in if is_first else math.sqrt(6.0 / dim_in) / (w0.item() if torch.is_tensor(w0) else float(w0))  # w0 may be a tensor (from the schedule) or a plain float
31 |         nn.init.uniform_(self.linear.weight, -w_std, w_std)
32 |         nn.init.uniform_(self.linear.bias, -w_std, w_std)
33 | 
34 |     def forward(self, x, latent):
35 |         x = self.linear(x)
36 | 
37 |         if not self.is_first and not self.is_last:
38 |             shift = 0.0 if not self.modulate_shift else self.modulate_shift_layer(latent)
39 |             scale = 1.0 if not self.modulate_scale else self.modulate_scale_layer(latent)
40 | 
41 |             if self.modulate_shift:
42 |                 if len(shift.shape) == 2:
43 |                     shift = shift.unsqueeze(dim=1)
44 |             if self.modulate_scale:
45 |                 if len(scale.shape) == 2:
46 |                     scale = scale.unsqueeze(dim=1)
47 | 
48 |             x = scale * x + shift
49 | 
50 |         if not self.is_last:
51 |             x = torch.sin(self.w0 * x)
52 |         return x
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | 
4 | from common.args import parse_args
5 | from common.utils import set_random_seed, load_model
6 | from common.w0_utils import load_w0s
7 | from data.dataset import get_dataset
8 | from eval.maml_full_eval import test_model
9 | from models.inrs import LatentModulatedSIREN
10 | from models.model_wrapper import ModelWrapper
11 | 
12 | 
13 | def main(args):
14 |     """
15 |     Main function to call for running an evaluation procedure (evaluate performance on the test set).
16 |     :param args: parameters parsed from the command line or a config.yaml.
17 |     :return: Nothing.
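    Example usage (illustrative sketch; assumes 'load_path' in the config points to a trained checkpoint):
        python eval.py --config ./configs/eval/2d_imgs/default_64.yaml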
18 |     """
19 | 
20 |     """ Set a device to use """
21 |     if torch.cuda.is_available():
22 |         torch.cuda.set_device(args.gpu_id)
23 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
24 |     args.device = device
25 | 
26 |     """ Enable determinism """
27 |     set_random_seed(args.seed)
28 |     torch.backends.cudnn.deterministic = True
29 |     torch.backends.cudnn.benchmark = False
30 | 
31 |     """ Define test dataset """
32 |     test_set = get_dataset(args, only_test=True)
33 |     test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False, num_workers=4, pin_memory=True,
34 |                              drop_last=True)
35 | 
36 |     """ Get w0s to initialize the model """
37 |     w0s = load_w0s(args)
38 |     args.w0s = w0s
39 | 
40 |     """ Initialize model and optimizer """
41 |     model = LatentModulatedSIREN(
42 |         in_size=args.in_size,
43 |         out_size=args.out_size,
44 |         min_hidden_size=args.min_hidden_dim,
45 |         max_hidden_size=args.max_hidden_dim,
46 |         progression_type=args.progression_type,
47 |         num_layers=args.num_layers,
48 |         latent_modulation_dim=args.latent_modulation_dim,
49 |         w0s=args.w0s,
50 |         modulate_shift=args.modulate_shift,
51 |         modulate_scale=args.modulate_scale,
52 |         enable_skip_connections=args.enable_skip_connections,
53 |     ).to(device)
54 | 
55 |     """ Initialize modulation vectors (signal-specific parameter vector) """
56 |     model.modulations = torch.zeros(size=[args.test_batch_size, args.latent_modulation_dim], requires_grad=True).to(device)
57 | 
58 |     """ Wrap the model """
59 |     model = ModelWrapper(args, model)
60 |     load_model(args, model)
61 | 
62 |     """ Run the test function """
63 |     test_model(args, model, test_loader, logger=None)
64 | 
65 | 
66 | if __name__ == "__main__":
67 |     args = parse_args()
68 |     main(args)
--------------------------------------------------------------------------------
/train/trainer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import torch
4 | from common.utils import resume_training, MetricLogger, save_checkpoint, save_checkpoint_step
5 | 
6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7 | 
8 | 
9 | def trainer(args, train_function, test_function, model_wrapper, meta_optimizer, train_loader, test_loader, logger, scheduler=None):
10 |     """
11 |     The main function that performs the training. It iteratively calls training steps (train_function) and evaluations (test_function).
12 |     :param args: parameters parsed from the command line.
13 |     :param train_function: function that performs a single meta-update step.
14 |     :param test_function: function that performs the evaluation.
15 |     :param model_wrapper: the wrapped model.
16 |     :param meta_optimizer: optimizer used for the meta-learning (global optimizer).
17 |     :param train_loader: data loader for training data.
18 |     :param test_loader: data loader for testing data (usually this will be your validation set).
19 |     :param logger: a logger.
20 |     :param scheduler: optional learning-rate scheduler; pass None to train without a schedule.
21 |     :return: Nothing.
22 | """ 23 | 24 | metric_logger = MetricLogger(delimiter=" ") 25 | 26 | """ Resume training (optional with '--resume_path' flag) """ 27 | is_best, start_step, best_psnr, psnr = resume_training(args, model_wrapper, meta_optimizer) 28 | 29 | """ Start Training """ 30 | logger.log_dirname(f"Start training") 31 | 32 | """ Load training data """ 33 | for it, train_batch in enumerate(train_loader): 34 | step = start_step + it + 1 35 | if step > args.outer_steps: 36 | break 37 | 38 | train_batch, _ = train_batch 39 | train_batch = train_batch.float().to(device, non_blocking=True) 40 | 41 | """ Perform a single meta-update training step """ 42 | train_function(args, step, model_wrapper, meta_optimizer, train_batch, metric_logger, logger, scheduler) 43 | 44 | """ Evaluate and save model every eval_step steps """ 45 | if step == 1 or step % args.eval_step == 0: 46 | psnr, lpips, ssim = test_function(args, step, model_wrapper, test_loader, logger) 47 | 48 | if best_psnr < psnr: 49 | best_psnr = psnr 50 | save_checkpoint(args, step, best_psnr, model_wrapper, meta_optimizer.state_dict(), logger.logdir, 51 | is_best=True) 52 | 53 | logger.scalar_summary('eval/best_psnr', best_psnr, step) 54 | logger.log('[EVAL] [Step %3d] [PSNR %5.2f] [BestPSNR %5.2f]' % (step, psnr, best_psnr)) 55 | 56 | """ Save model every save_step steps""" 57 | if step == 1 or step % args.save_step == 0: 58 | save_checkpoint_step(args, step, best_psnr, model_wrapper,meta_optimizer.state_dict(), logger.logdir) 59 | 60 | """ Finish training after max_iter steps """ 61 | if step >= args.max_iter: 62 | break 63 | 64 | """ Save the last model""" 65 | save_checkpoint(args, args.outer_steps, best_psnr, model_wrapper, meta_optimizer.state_dict(), logger.logdir) 66 | -------------------------------------------------------------------------------- /eval/maml_full_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import lpips 3 | import torch.nn.functional as F 4 | from tqdm import tqdm 5 | from pytorch_msssim import ssim 6 | 7 | from common.utils import MetricLogger, psnr, ssim_1d 8 | from train.maml_boot import inner_adapt_test_scale 9 | 10 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 11 | 12 | 13 | def test_model(args, model_wrapper, test_loader, logger=None): 14 | metric_logger = MetricLogger(delimiter=" ") 15 | 16 | if logger is None: 17 | log_ = print 18 | else: 19 | log_ = logger.log 20 | 21 | model_wrapper.model.eval() 22 | model_wrapper.coord_init() 23 | 24 | lpips_score = lpips.LPIPS(net='alex').to(device) 25 | 26 | for n, data in enumerate(tqdm(test_loader)): 27 | data, _ = data 28 | data = data.to(device) 29 | batch_size = data.size(0) 30 | model_wrapper.model.reset_modulations() 31 | 32 | _ = inner_adapt_test_scale(model_wrapper=model_wrapper, data=data, step_size=args.inner_lr, 33 | num_steps=args.inner_steps_test, first_order=True, sample_type=args.sample_type, 34 | scale_type='grad') 35 | 36 | with torch.no_grad(): 37 | pred = model_wrapper().clamp(0, 1) 38 | 39 | if args.data_type == 'img': 40 | lpips_results = lpips_score((pred * 2 - 1), (data * 2 - 1)).mean() 41 | mse_results = F.mse_loss(data.view(batch_size, -1), pred.reshape(batch_size, -1), reduce=False).mean() 42 | psnr_results = psnr( 43 | F.mse_loss(data.view(batch_size, -1), pred.reshape(batch_size, -1), reduce=False).mean(dim=1) 44 | ).mean() 45 | ssim_results = ssim(pred, data, data_range=1.).mean() 46 | 47 | elif args.data_type == 'img3d': 48 | mse_results = 
F.mse_loss(data.view(batch_size, -1), pred.reshape(batch_size, -1), reduce=False).mean() 49 | psnr_results = psnr( 50 | F.mse_loss(data.view(batch_size, -1), pred.reshape(batch_size, -1), reduce=False).mean(dim=1) 51 | ).mean() 52 | ssim_results = ssim(pred, data, data_range=1.).mean() 53 | lpips_results = torch.zeros_like(psnr_results) 54 | 55 | elif args.data_type == 'timeseries': 56 | mse_results = F.mse_loss(data.view(batch_size, -1), pred.reshape(batch_size, -1), reduce=False).mean() 57 | psnr_results = psnr( 58 | F.mse_loss(data.view(batch_size, -1), pred.reshape(batch_size, -1), reduce=False).mean(dim=1) 59 | ).mean() 60 | ssim_results = ssim_1d(pred.squeeze(), data.squeeze(), data_range=1.).mean() 61 | lpips_results = torch.zeros_like(psnr_results) 62 | 63 | else: 64 | raise NotImplementedError() 65 | 66 | metric_logger.meters['lpips'].update(lpips_results.item(), n=batch_size) 67 | metric_logger.meters['psnr'].update(psnr_results.item(), n=batch_size) 68 | metric_logger.meters['mse'].update(mse_results.item(), n=batch_size) 69 | metric_logger.meters['ssim'].update(ssim_results.item(), n=batch_size) 70 | 71 | if n % 10 == 0: 72 | # gather the stats from all processes 73 | metric_logger.synchronize_between_processes() 74 | 75 | log_(f'*[EVAL {n}][PSNR %.6f][LPIPS %.6f][SSIM %.6f][MSE %.6f]' % 76 | (metric_logger.psnr.global_avg, metric_logger.lpips.global_avg, 77 | metric_logger.ssim.global_avg, metric_logger.mse.global_avg)) 78 | 79 | # gather the stats from all processes 80 | metric_logger.synchronize_between_processes() 81 | log_(f'*[EVAL Final][PSNR %.8f][LPIPS %.8f][SSIM %.8f][MSE %.8f]' % 82 | (metric_logger.psnr.global_avg, metric_logger.lpips.global_avg, 83 | metric_logger.ssim.global_avg, metric_logger.mse.global_avg)) 84 | 85 | return 86 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Log files 10 | logs/ 11 | nfsets/ 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 114 | .pdm.toml 115 | .pdm-python 116 | .pdm-build/ 117 | 118 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 119 | __pypackages__/ 120 | 121 | # Celery stuff 122 | celerybeat-schedule 123 | celerybeat.pid 124 | 125 | # SageMath parsed files 126 | *.sage.py 127 | 128 | # Environments 129 | .env 130 | .venv 131 | env/ 132 | venv/ 133 | ENV/ 134 | env.bak/ 135 | venv.bak/ 136 | 137 | # Spyder project settings 138 | .spyderproject 139 | .spyproject 140 | 141 | # Rope project settings 142 | .ropeproject 143 | 144 | # mkdocs documentation 145 | /site 146 | 147 | # mypy 148 | .mypy_cache/ 149 | .dmypy.json 150 | dmypy.json 151 | 152 | # Pyre type checker 153 | .pyre/ 154 | 155 | # pytype static type analyzer 156 | .pytype/ 157 | 158 | # Cython debug symbols 159 | cython_debug/ 160 | 161 | # PyCharm 162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 164 | # and can be added to the global gitignore or merged into this file. For a more nuclear 165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
166 | .idea/
--------------------------------------------------------------------------------
/data/preproc_lidc-idri.py:
--------------------------------------------------------------------------------
1 | """
2 | Script for preprocessing the LIDC-IDRI dataset.
3 | """
4 | import argparse
5 | import os
6 | import shutil
7 | import dicom2nifti
8 | import nibabel as nib
9 | import numpy as np
10 | from scipy.ndimage import zoom
11 | 
12 | 
13 | def preprocess_nifti(input_path, output_path):
14 |     # Load the NIfTI image
15 |     print('Process image: {}'.format(input_path))
16 |     img = nib.load(input_path)
17 | 
18 |     # Get the current voxel sizes
19 |     voxel_sizes = img.header.get_zooms()
20 | 
21 |     # Define the target voxel size (1mm x 1mm x 1mm)
22 |     target_voxel_size = (1.0, 1.0, 1.0)
23 | 
24 |     # Calculate the resampling factor
25 |     zoom_factors = [current / target for target, current in zip(target_voxel_size, voxel_sizes)]
26 | 
27 |     # Resample the image
28 |     print("[1] Resample the image ...")
29 |     resampled_data = zoom(img.get_fdata(), zoom_factors, order=3, mode='nearest')
30 | 
31 |     print("[2] Center crop the image ...")
32 |     crop_size = (256, 256, 256)
33 |     depth, height, width = resampled_data.shape
34 | 
35 |     d_start = (depth - crop_size[0]) // 2
36 |     h_start = (height - crop_size[1]) // 2
37 |     w_start = (width - crop_size[2]) // 2
38 |     cropped_arr = resampled_data[d_start:d_start + crop_size[0], h_start:h_start + crop_size[1], w_start:w_start + crop_size[2]]
39 | 
40 |     print("[3] Clip all values below -1000 ...")
41 |     cropped_arr[cropped_arr < -1000] = -1000
42 | 
43 |     print("[4] Clip the upper quantile (0.999) to remove outliers ...")
44 |     out_clipped = np.clip(cropped_arr, -1000, np.quantile(cropped_arr, 0.999))
45 | 
46 |     print("[5] Shift values into the positive range and cast to int16 ...")
47 |     out_pos = out_clipped + 1000
48 |     out_pos = np.int16(out_pos)
49 | 
50 |     assert out_pos.shape == (256, 256, 256), "The output shape should be (256, 256, 256)"
51 | 
52 |     print("[6] FINAL REPORT: Min value: {}, Max value: {}, Shape: {}".format(out_pos.min(),
53 |                                                                              out_pos.max(),
54 |                                                                              out_pos.shape))
55 |     print("-------------------------------------------------------------------------------")
56 | 
57 |     # Save the resampled image
58 |     resampled_img = nib.Nifti1Image(out_pos, np.eye(4))
59 |     nib.save(resampled_img, output_path)
60 | 
61 | 
62 | if __name__ == "__main__":
63 |     parser = argparse.ArgumentParser()
64 |     parser.add_argument('--dicom_dir', type=str, required=True,
65 |                         help='Directory containing the original DICOM data')
66 |     parser.add_argument('--nifti_dir', type=str, required=True,
67 |                         help='Directory to store the processed nifti files')
68 |     parser.add_argument('--delete_unprocessed', type=eval, default=False,
69 |                         help='Set true to delete the unprocessed nifti files')
70 |     args = parser.parse_args()
71 | 
72 |     # Convert DICOM to nifti
73 |     for patient in os.listdir(args.dicom_dir):
74 |         print('Convert {} to nifti'.format(patient))
75 |         if not os.path.exists(os.path.join(args.nifti_dir, patient)):
76 |             os.makedirs(os.path.join(args.nifti_dir, patient))
77 |         dicom2nifti.convert_directory(os.path.join(args.dicom_dir, patient),
78 |                                       os.path.join(args.nifti_dir, patient))
79 |         shutil.rmtree(os.path.join(args.dicom_dir, patient))
80 | 
81 |     # Preprocess nifti files
82 |     for root, dirs, files in os.walk(args.nifti_dir):
83 |         for file in files:
84 |             try:
85 |                 preprocess_nifti(os.path.join(root, file), os.path.join(root, 'processed.nii.gz'))
86 |             except Exception as e:
87 |                 print("Error occurred for file {}: {}".format(file, e))
88 | 
89 |     # 
Delete unprocessed nifti files 90 | if args.delete_unprocessed: 91 | for root, dirs, files in os.walk(args.nifti_dir): 92 | for file in files: 93 | if not file == 'processed.nii.gz': 94 | os.remove(os.path.join(root, file)) -------------------------------------------------------------------------------- /fit_NFset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | from common.args import parse_args 7 | from common.utils import set_random_seed, load_model 8 | from common.w0_utils import get_w0s, save_w0s 9 | from data.dataset import get_dataset 10 | from eval.maml_full_fit import fit_nfs 11 | from models.inrs import LatentModulatedSIREN 12 | from models.model_wrapper import ModelWrapper 13 | 14 | 15 | def main(args): 16 | """ 17 | Main function to call for fitting neural fields to a whole dataset (having pretrained shared weights). 18 | :param args: parameters parsed from the command line/a config.yaml. 19 | :return: Nothing. 20 | """ 21 | 22 | """ Set a device to use """ 23 | if torch.cuda.is_available(): 24 | torch.cuda.set_device(args.gpu_id) 25 | device = torch.device(f'cuda' if torch.cuda.is_available() else 'cpu') 26 | args.device = device 27 | 28 | """ Enable determinism """ 29 | set_random_seed(args.seed) 30 | torch.backends.cudnn.deterministic = True 31 | torch.backends.cudnn.benchmark = False 32 | 33 | """ Define dataset that you want to convert to NFs """ 34 | train, val, test = get_dataset(args, all=True) 35 | train_loader = DataLoader(train, batch_size=args.test_batch_size, shuffle=False, num_workers=4, pin_memory=True, 36 | drop_last=True) 37 | val_loader = DataLoader(val, batch_size=args.test_batch_size, shuffle=False, num_workers=4, pin_memory=True, 38 | drop_last=True) 39 | test_loader = DataLoader(test, batch_size=args.test_batch_size, shuffle=False, num_workers=4, pin_memory=True, 40 | drop_last=True) 41 | 42 | """ Get w0s to initialize the model """ 43 | w0s = get_w0s(args) 44 | args.w0s = w0s 45 | 46 | """ Initialize model and optimizer """ 47 | model = LatentModulatedSIREN( 48 | in_size=args.in_size, 49 | out_size=args.out_size, 50 | min_hidden_size=args.min_hidden_dim, 51 | max_hidden_size=args.max_hidden_dim, 52 | progression_type=args.progression_type, 53 | num_layers=args.num_layers, 54 | latent_modulation_dim=args.latent_modulation_dim, 55 | w0s=args.w0s, 56 | modulate_shift=args.modulate_shift, 57 | modulate_scale=args.modulate_scale, 58 | enable_skip_connections=args.enable_skip_connections, 59 | ).to(device) 60 | 61 | """ Initialize modulation vectors """ 62 | model.modulations = torch.zeros(size=[args.test_batch_size, args.latent_modulation_dim], requires_grad=True).to(device) 63 | model = ModelWrapper(args, model) 64 | load_model(args, model) 65 | 66 | if not os.path.exists(args.save_dir): 67 | print(f'Create: {args.save_dir}') 68 | os.mkdir(args.save_dir) 69 | 70 | """ Create training set """ 71 | if not os.path.exists(args.save_dir + 'train'): 72 | print(f'Create: {args.save_dir}'+'train/') 73 | os.mkdir(args.save_dir + 'train/') 74 | fit_nfs(args, model, train_loader, set='train') 75 | print("Created MedFuncta Set: Training") 76 | 77 | """ Create validation set """ 78 | if not os.path.exists(args.save_dir + 'val'): 79 | print(f'Create: {args.save_dir}' + 'val/') 80 | os.mkdir(args.save_dir + 'val/') 81 | fit_nfs(args, model, val_loader, set='val') 82 | print("Created MedFuncta Set: Validation") 83 | 84 | """ Create test set """ 85 | 
if not os.path.exists(args.save_dir + 'test'): 86 | print(f'Create: {args.save_dir}' + 'test/') 87 | os.mkdir(args.save_dir + 'test') 88 | os.mkdir(args.save_dir + 'test/imgs') 89 | fit_nfs(args, model, test_loader, set='test') 90 | print("Created MedFuncta Set: Test") 91 | 92 | """ Save the model to save_dir folder """ 93 | model_path = args.save_dir + 'model.pt' 94 | torch.save(model.model, model_path) 95 | print("DONE") 96 | 97 | 98 | if __name__ == "__main__": 99 | args = parse_args() 100 | main(args) 101 | -------------------------------------------------------------------------------- /models/inrs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import nn 4 | from models.layers import LatentModulatedSIRENLayer 5 | 6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | 9 | class LatentModulatedSIREN(nn.Module): 10 | def __init__(self, in_size, out_size, w0s, min_hidden_size=256, max_hidden_size=256, num_layers=15, 11 | latent_modulation_dim=2048, modulate_shift=True, modulate_scale=False, enable_skip_connections=False, 12 | progression_type='linear'): 13 | super().__init__() 14 | self.num_layers = num_layers 15 | self.hidden_sizes = self._calculate_progressive_sizes( 16 | min_hidden_size, max_hidden_size, num_layers, progression_type 17 | ) 18 | print(f"Progressive layer widths: {self.hidden_sizes}") 19 | layers = [] 20 | for i in range(num_layers - 1): 21 | is_first = i == 0 22 | layer_in_size = in_size if is_first else self.hidden_sizes[i - 1] 23 | layer_out_size = self.hidden_sizes[i] 24 | layers.append(LatentModulatedSIRENLayer(in_size=layer_in_size, out_size=layer_out_size, 25 | latent_modulation_dim=latent_modulation_dim, w0=w0s[i], 26 | modulate_shift=modulate_shift, modulate_scale=modulate_scale, 27 | is_first=is_first)) 28 | self.layers = nn.ModuleList(layers) 29 | self.last_layer = LatentModulatedSIRENLayer(in_size=self.hidden_sizes[-1], out_size=out_size, 30 | latent_modulation_dim=latent_modulation_dim, w0=w0s[-1], 31 | modulate_shift=modulate_shift, modulate_scale=modulate_scale, 32 | is_last=True) 33 | self.enable_skip_connections = enable_skip_connections 34 | self.modulations = torch.zeros(size=[latent_modulation_dim], requires_grad=True).to(device) 35 | 36 | def reset_modulations(self): 37 | self.modulations = self.modulations.detach() * 0 38 | self.modulations.requires_grad = True 39 | 40 | def forward(self, x, get_features=False): 41 | x = self.layers[0](x, self.modulations) 42 | for layer in self.layers[1:]: 43 | y = layer(x, self.modulations) 44 | if self.enable_skip_connections: 45 | x = x + y 46 | else: 47 | x = y 48 | features = x 49 | out = self.last_layer(features, self.modulations) + 0.5 50 | 51 | if get_features: 52 | return out, features 53 | else: 54 | return out 55 | 56 | 57 | def _calculate_progressive_sizes(self, min_size, max_size, num_layers, progression_type='linear'): 58 | """Calculate progressive hidden layer sizes.""" 59 | if num_layers <= 1: 60 | return [min_size] 61 | 62 | # We have num_layers-1 hidden layers (excluding output layer) 63 | n_hidden = num_layers - 1 64 | 65 | if progression_type == 'linear': 66 | # Linear interpolation 67 | sizes = np.linspace(min_size, max_size, n_hidden) 68 | elif progression_type == 'exponential': 69 | # Exponential growth 70 | log_min = np.log(min_size) 71 | log_max = np.log(max_size) 72 | log_sizes = np.linspace(log_min, log_max, n_hidden) 73 | sizes = np.exp(log_sizes) 74 | elif 
progression_type == 'cosine': 75 | # Cosine schedule (slower at beginning and end) 76 | t = np.linspace(0, 1, n_hidden) 77 | cosine_factor = (1 - np.cos(t * np.pi)) / 2 78 | sizes = min_size + (max_size - min_size) * cosine_factor 79 | else: 80 | raise ValueError(f"Unknown progression_type: {progression_type}") 81 | 82 | # Round to nearest multiple of 8 for efficiency (optional) 83 | sizes = [int(8 * round(size / 8)) for size in sizes] 84 | 85 | return sizes 86 | -------------------------------------------------------------------------------- /eval/maml_scale.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import lpips 3 | from common.utils import MetricLogger, psnr 4 | from train.maml_boot import inner_adapt_test_scale 5 | from pytorch_msssim import ssim 6 | 7 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 8 | 9 | 10 | def test_model(args, step, model_wrapper, test_loader, logger=None): 11 | """ 12 | Function that performs the model evaluation 13 | """ 14 | metric_logger = MetricLogger(delimiter=" ") 15 | lpips_score = lpips.LPIPS(net='alex').to(device) 16 | 17 | if logger is None: 18 | log_ = print 19 | else: 20 | log_ = logger.log 21 | 22 | model_wrapper.model.eval() # Enter evaluation mode 23 | model_wrapper.coord_init() # Reset coordinates 24 | 25 | """ Iterate over test loader """ 26 | for n, data in enumerate(test_loader): 27 | if n * args.test_batch_size > args.num_test_signals: 28 | break 29 | 30 | data, _ = data # Discard label 31 | data = data.float().to(device) 32 | batch_size = data.size(0) 33 | model_wrapper.model.reset_modulations() 34 | input = data 35 | 36 | _ = inner_adapt_test_scale(model_wrapper=model_wrapper, data=data, step_size=args.inner_lr, 37 | num_steps=args.inner_steps_test, first_order=True, sample_type=args.sample_type, 38 | scale_type='grad') 39 | 40 | """ Outer loss aggregation """ 41 | with torch.no_grad(): 42 | loss_out_tt_gradscale = model_wrapper(data) 43 | loss_out = loss_out_tt_gradscale 44 | psnr_out = psnr(loss_out).mean() 45 | 46 | if args.data_type == 'img': 47 | out = model_wrapper().clamp(0,1) 48 | lpips_result = lpips_score((out * 2 - 1), (input * 2 - 1)).mean() 49 | ssim_result = ssim(out, input, data_range=1.).mean() 50 | metric_logger.meters['lpips'].update(lpips_result.item(), n=batch_size) 51 | metric_logger.meters['ssim'].update(ssim_result.item(), n=batch_size) 52 | metric_logger.meters['psnr'].update(psnr_out.item(), n=batch_size) 53 | metric_logger.meters['loss'].update(loss_out.mean().item(), n=batch_size) 54 | 55 | if args.data_type == 'img3d': 56 | out = model_wrapper().clamp(0, 1) 57 | lpips_result = torch.zeros_like(loss_out_tt_gradscale).mean() 58 | ssim_result = ssim(out, input, data_range=1.).mean() 59 | metric_logger.meters['lpips'].update(lpips_result.item(), n=batch_size) 60 | metric_logger.meters['ssim'].update(ssim_result.item(), n=batch_size) 61 | metric_logger.meters['psnr'].update(psnr_out.item(), n=batch_size) 62 | metric_logger.meters['loss'].update(loss_out.mean().item(), n=batch_size) 63 | 64 | if args.data_type == 'timeseries': 65 | lpips_result = torch.zeros_like(loss_out_tt_gradscale).mean() 66 | ssim_result = torch.zeros_like(loss_out_tt_gradscale).mean() 67 | metric_logger.meters['lpips'].update(lpips_result.item(), n=batch_size) 68 | metric_logger.meters['ssim'].update(ssim_result.item(), n=batch_size) 69 | metric_logger.meters['psnr'].update(psnr_out.item(), n=batch_size) 70 | 
            metric_logger.meters['loss'].update(loss_out.mean().item(), n=batch_size)
71 | 
72 |     metric_logger.synchronize_between_processes()
73 | 
74 |     """ Log to tensorboard & console """
75 |     if args.data_type == 'img' or args.data_type == 'img3d':
76 |         log_('*[EVAL-GTT-REC] [Loss %f] [PSNR %.3f] [LPIPS %.3f] [SSIM %.3f]' %
77 |              (metric_logger.loss.global_avg, metric_logger.psnr.global_avg, metric_logger.lpips.global_avg,
78 |               metric_logger.ssim.global_avg))
79 |         if logger is not None:
80 |             logger.scalar_summary('eval/loss', metric_logger.loss.global_avg, step)
81 |             logger.scalar_summary('eval/psnr', metric_logger.psnr.global_avg, step)
82 |             logger.scalar_summary('eval/ssim', metric_logger.ssim.global_avg, step)
83 |             logger.scalar_summary('eval/lpips', metric_logger.lpips.global_avg, step)
84 |             logger.log_image('eval/img_in', input, step)
85 |             logger.log_image('eval/img_out', out, step)
86 | 
87 | 
88 |     return metric_logger.psnr.global_avg, metric_logger.lpips.global_avg, metric_logger.ssim.global_avg
--------------------------------------------------------------------------------
/common/args.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import os.path
3 | from argparse import ArgumentParser
4 | 
5 | 
6 | def load_cfg(args):
7 |     with open(args.config, "rb") as f:
8 |         cfg = yaml.safe_load(f)
9 | 
10 |     for key, value in cfg.items():
11 |         args.__dict__[key] = value
12 | 
13 |     return args
14 | 
15 | 
16 | def parse_args():
17 |     parser = ArgumentParser()
18 | 
19 |     """ Config """
20 |     parser.add_argument('--config', help='Path to config .yaml file', type=str, default=None)
21 | 
22 |     """ System configuration """
23 |     parser.add_argument('--gpu_id', help='GPU ID', type=int, default=0)
24 |     parser.add_argument('--seed', help='Random seed', type=int, default=42)
25 | 
26 |     """ Resume training """
27 |     parser.add_argument('--resume_path', help='Path to the logdir of training to resume', type=str)
28 | 
29 |     """ Training configuration """
30 |     parser.add_argument('--dataset', help='Dataset', type=str)
31 |     parser.add_argument('--img_size', help='Image size', type=int, default=64)
32 |     parser.add_argument('--batch_size', help='Batch size (number of images per batch) used for training', type=int, default=24)
33 |     parser.add_argument('--outer_steps', help='Number of meta-learning steps to perform', type=int, default=250000)
34 |     parser.add_argument('--inner_steps', help='Number of inner loop optimization steps (G)', type=int, default=10)
35 |     parser.add_argument('--meta_lr', help='Learning rate for meta-learning updates (beta)', type=float, default=3e-6)
36 |     parser.add_argument('--inner_lr', help='Learning rate for inner loop (alpha)', type=float, default=1e-2)
37 |     parser.add_argument('--lr_scheduler', help='If True, a global lr-schedule is applied', type=eval, default=True)
38 |     parser.add_argument('--data_ratio', help='Ratio of data used for training (gamma)', type=float, default=0.25)
39 | 
40 |     """ Testing configuration """
41 |     parser.add_argument('--test_batch_size', help='Batch size used for testing', type=int, default=24)
42 |     parser.add_argument('--num_test_signals', help='Number of signals used for testing', default=256, type=int)
43 |     parser.add_argument('--inner_steps_test', help='Number of inner loop update steps at test-time (H)', type=int, default=20)
44 | 
45 |     """ Model configuration """
46 |     parser.add_argument('--min_hidden_dim', help='MLP hidden size of the first layer', type=int, default=256)
47 |     parser.add_argument('--max_hidden_dim', help='MLP hidden size of the last layer', type=int, default=256)
48 |     parser.add_argument('--progression_type', help='Progression type for hidden_dim [linear, exponential, cosine]', type=str, default='linear')
49 |     parser.add_argument('--num_layers', help='Number of MLP layers (K)', type=int, default=15)
50 |     parser.add_argument('--latent_modulation_dim', help='Representation size (P)', type=int, default=2048)
51 |     parser.add_argument('--w0', help='SIREN parameter w0', type=float, default=30.)
52 |     parser.add_argument('--wK', help='SIREN parameter wK', type=float, default=300.)
53 |     parser.add_argument('--w0_sched_type', help='Type of w0 schedule', type=str, default='linear')
54 |     parser.add_argument('--modulate_shift', help='Set True to use shift modulations', type=eval, default=True)
55 |     parser.add_argument('--modulate_scale', help='Set True to use scale modulations (not recommended)', type=eval, default=False)
56 |     parser.add_argument('--enable_skip_connections', help='Set True to enable skip-connections', type=eval, default=False)
57 | 
58 |     """ Logging configuration """
59 |     parser.add_argument('--print_step', help='Print every x steps', type=int, default=100)
60 |     parser.add_argument('--print_img_step', help='Print images every x steps', type=int, default=100)
61 |     parser.add_argument('--eval_step', help='Evaluate every x steps', type=int, default=1000)
62 |     parser.add_argument('--save_step', help='Save model every x steps', type=int, default=50000)
63 |     parser.add_argument('--advanced_step', type=int, default=1000)
64 |     parser.add_argument('--log_advanced', help='Activate to log advanced statistics', type=eval, default=False)
65 | 
66 |     """ Eval configuration """
67 |     parser.add_argument('--load_path', help='Load model from this path', type=str, default=None)
68 | 
69 |     """ Fitting configuration """
70 |     parser.add_argument('--save_dir', help='Directory to store shared model, modulations and labels', type=str, default=None)
71 | 
72 |     """ Parse Arguments """
73 |     args = parser.parse_args()
74 | 
75 |     """ Load config files """
76 |     if args.config is not None and os.path.exists(args.config):
77 |         load_cfg(args)
78 | 
79 |     return args
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | from torch.utils.data import DataLoader
4 | 
5 | from common.args import parse_args
6 | from common.utils import set_random_seed, Logger, InfiniteSampler
7 | from common.w0_utils import get_w0s, save_w0s
8 | from data.dataset import get_dataset
9 | from models.inrs import LatentModulatedSIREN
10 | from models.model_wrapper import ModelWrapper
11 | from train.trainer import trainer
12 | from train.maml_boot import train_step
13 | from eval.maml_scale import test_model
14 | 
15 | 
16 | def main(args):
17 |     """
18 |     Main function to call for running a training procedure (meta-learning a shared network).
19 |     :param args: parameters parsed from the command line or a config.yaml.
20 |     :return: Nothing.
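    Example usage (illustrative sketch; config paths as listed under configs/experiments):
        python train.py --config ./configs/experiments/2d_imgs/default_64.yaml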
21 | """ 22 | 23 | """ Set a device to use """ 24 | if torch.cuda.is_available(): 25 | torch.cuda.set_device(args.gpu_id) 26 | device = torch.device(f'cuda' if torch.cuda.is_available() else 'cpu') 27 | args.device = device 28 | 29 | """ Enable determinism """ 30 | set_random_seed(args.seed) 31 | torch.backends.cudnn.deterministic = True 32 | torch.backends.cudnn.benchmark = True 33 | 34 | """ Define dataset """ 35 | train_set, val_set = get_dataset(args) 36 | 37 | """ Define dataloader """ 38 | infinite_sampler = InfiniteSampler(train_set, rank=0, num_replicas=1, shuffle=True, seed=args.seed) 39 | train_loader = DataLoader(train_set, sampler=infinite_sampler, batch_size=args.batch_size, num_workers=4, 40 | prefetch_factor=2) 41 | val_loader = DataLoader(val_set, batch_size=args.test_batch_size, shuffle=False, num_workers=4) 42 | 43 | """ Get w0s to initialize the model """ 44 | w0s = get_w0s(args) 45 | args.w0s = w0s 46 | 47 | """ Initialize model """ 48 | model = LatentModulatedSIREN( 49 | in_size=args.in_size, # Input dimension (coordinate dim) C 50 | out_size=args.out_size, # Output dimension (signal dim) D 51 | min_hidden_size=args.min_hidden_dim, # First layers hidden dimension 52 | max_hidden_size=args.max_hidden_dim, # Last layers hidden dimension (usually min_hidden_dim) 53 | progression_type=args.progression_type, # Defines how hidden dimension progresses in model 54 | num_layers=args.num_layers, # Number of layers K 55 | latent_modulation_dim=args.latent_modulation_dim, # Representation size P 56 | w0s=args.w0s, # Per-layer omega parameters 57 | modulate_shift=args.modulate_shift, # If shift modulation is used (default: True) 58 | modulate_scale=args.modulate_scale, # If scale modulation is used (default: False) 59 | enable_skip_connections=args.enable_skip_connections, # Set True to enable skip-connections (default: False) 60 | ).to(device) 61 | 62 | """ Initialize modulation vectors (signal-specific parameter vector) """ 63 | model.modulations = torch.zeros(size=[args.batch_size, args.latent_modulation_dim], requires_grad=True).to(device) 64 | model.modulation_init = model.modulations.clone().detach() 65 | 66 | """ Wrap the model """ 67 | model = ModelWrapper(args, model) 68 | 69 | """ Define training and test functions """ 70 | train_function = train_step 71 | test_function = test_model 72 | 73 | """ Define logger """ 74 | fname = (f'{args.dataset}_size{args.img_size}_bs{args.batch_size}_inner{args.inner_steps}_gamma{args.data_ratio}_' 75 | f'{args.config.split("/")[-1].split(".yaml")[0]}') 76 | logger = Logger(fname, ask=args.resume_path is None, rank=args.gpu_id) 77 | logger.log(args) 78 | logger.log(w0s) 79 | logger.log(model) 80 | logger.log_hyperparameters(args) 81 | 82 | """ Save w0s """ 83 | save_w0s(w0s, logger) 84 | 85 | """ Initialize meta-optimizer """ 86 | meta_optimizer = optim.AdamW(params=model.model.parameters(), lr=args.meta_lr) 87 | 88 | """ Initialize a global lr-scheduler (recommended) """ 89 | scheduler = None 90 | if args.lr_scheduler: 91 | scheduler = optim.lr_scheduler.CosineAnnealingLR(meta_optimizer, eta_min=1e-7, T_max=args.max_iter) 92 | 93 | """ Start training """ 94 | trainer(args, train_function, test_function, model, meta_optimizer, train_loader, val_loader, logger, scheduler) 95 | 96 | """ Close logger """ 97 | logger.close_writer() 98 | 99 | 100 | if __name__ == "__main__": 101 | args = parse_args() 102 | main(args) 103 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # MedFuncta: A Unified Framework for Learning Efficient Medical Neural Fields 2 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 3 | [![Static Badge](https://img.shields.io/badge/Project-page-blue)](https://pfriedri.github.io/medfuncta-io/) 4 | [![arXiv](https://img.shields.io/badge/arXiv-2502.14401-b31b1b.svg)](https://arxiv.org/abs/2502.14401) 5 | 6 | This is the official PyTorch implementation of the paper **MedFuncta: A Unified Framework for Learning Efficient Medical Neural Fields** by [Paul Friedrich](https://pfriedri.github.io/), [Florentin Bieder](https://dbe.unibas.ch/en/persons/florentin-bieder/), [Julian McGinnis](https://www.kiinformatik.mri.tum.de/de/team/julian_mcginnis), [Julia Wolleb](https://medicine.yale.edu/profile/julia-wolleb/), [Daniel Rueckert](https://www.professoren.tum.de/rueckert-daniel) and [Philippe C. Cattin](https://dbe.unibas.ch/en/persons/philippe-claude-cattin/). 7 | 8 | If you find our work useful, please consider :star: **starring this repository** and :memo: **citing our paper**: 9 | ```bibtex 10 | @article{friedrich2025medfuncta, 11 | title={MedFuncta: A Unified Framework for Learning Efficient Medical Neural Fields}, 12 | author={Friedrich, Paul and Bieder, Florentin and McGinnis, Julian and Wolleb, Julia and Rueckert, Daniel and Cattin, Philippe C}, 13 | journal={arXiv preprint arXiv:2502.14401}, 14 | year={2025} 15 | } 16 | ``` 17 | ## Paper Abstract 18 | Research in medical imaging primarily focuses on discrete data representations that poorly scale with grid resolution and fail to capture the often continuous nature of the underlying signal. 19 | Neural Fields (NFs) offer a powerful alternative by modeling data as continuous functions. 20 | While single-instance NFs have successfully been applied in medical contexts, extending them to large-scale medical datasets remains an open challenge. 21 | We therefore introduce [**MedFuncta**](https://arxiv.org/abs/2502.14401), a unified framework for large-scale NF training on diverse medical signals. 22 | Building on Functa, our approach encodes data into a unified representation, namely a 1D latent vector, that modulates a shared, meta-learned NF, enabling generalization across a dataset. 23 | We revisit common design choices, introducing a non-constant frequency parameter $\omega$ in widely used SIREN activations, and establish a connection between this $\omega$-schedule and layer-wise learning rates, relating our findings to recent work in theoretical learning dynamics. 24 | We additionally introduce a scalable meta-learning strategy for shared network learning that employs sparse supervision during training, thereby reducing memory consumption and computational overhead while maintaining competitive performance. 25 | Finally, we evaluate MedFuncta across a diverse range of medical datasets and show how to solve relevant downstream tasks on our neural data representation. 26 | To promote further research in this direction, we release our code, model weights and the first large-scale dataset - [**MedNF**](https://doi.org/10.5281/zenodo.14898708) - containing > 500 k latent vectors for multi-instance medical NFs. 27 | 28 |
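The $\omega$-schedule mentioned above can be summarized in a few lines: instead of a single frequency $\omega_0$ shared by all SIREN layers, each of the K layers gets its own value. Below is a minimal sketch of the linear variant using the command-line defaults (w0=30, wK=300, K=15); the released schedules live in `common/w0_utils.py`, and the function name here is purely illustrative:

```python
import numpy as np

def linear_w0_schedule(w0: float, wK: float, num_layers: int) -> list:
    """Linearly interpolate the per-layer SIREN frequency from w0 (first layer) to wK (last layer)."""
    return np.linspace(w0, wK, num_layers).tolist()

# With the argument defaults, layer 0 uses omega=30 and layer 14 uses omega=300.
print(linear_w0_schedule(30.0, 300.0, 15))
```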

29 | 30 |

31 | 32 | ## Dependencies 33 | We recommend using a [conda](https://github.com/conda-forge/miniforge#mambaforge) environment to install the required dependencies. 34 | You can create and activate such an environment called `medfuncta` by running the following commands: 35 | ```sh 36 | mamba env create -f environment.yaml 37 | mamba activate medfuncta 38 | ``` 39 | 40 | ## Training (Meta-Learning) 41 | To obtain meta-learned shared model parameters, simply run the following command with the correct `config.yaml`: 42 | ```sh 43 | python train.py --config ./configs/experiments/DATASET_RESOLUTION.yaml 44 | ``` 45 | 46 | ## Evaluation (Reconstruction Experiments) 47 | To perform reconstruction experiments (evaluate the reconstruction quality), simply run the following command with the correct `config.yaml`: 48 | ```sh 49 | python eval.py --config ./configs/eval/DATASET_RESOLUTION.yaml 50 | ``` 51 | ## Create a MedNF Dataset 52 | To convert a dataset into our MedFuncta representation, simply run the following command with the correct `config.yaml`: 53 | ```sh 54 | python fit_NFset.py --config ./configs/fit/DATASET_RESOLUTION.yaml 55 | ``` 56 | 57 | ## Classification Experiments 58 | The source code for reproducing our classification experiments can be found in `/downstream_tasks/classification.py`. 59 | All arguments can be set in the `Args` class in this script. 60 | 61 | 62 | ## MedNF Dataset 63 | We release **MedNF**, a large-scale dataset containing more than 500 k medical NFs. 64 | More information on the dataset can be found in our paper (Appendix D). 65 | The dataset can be accessed here: [https://doi.org/10.5281/zenodo.14898708](https://doi.org/10.5281/zenodo.14898708). 66 | 67 | The dataset consists of the following 7 sub-datasets: 68 |

69 | 70 |

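Each NF in these sub-datasets is stored as a single `.pt` file. A minimal sketch for inspecting one, assuming the files follow the layout consumed by `NFDataset` in `downstream_tasks/classification.py` (a dict with a `'modulations'` vector and a `'label'`); the file name below is a placeholder:

```python
import torch

# Placeholder file name; use any .pt file from a MedNF sub-dataset.
sample = torch.load("sample_0000.pt", weights_only=False)

modulations = sample["modulations"].float()  # 1D latent vector of size P
label = sample["label"]                      # label of the encoded signal
print(modulations.shape, label)
```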
71 | 72 | ## Data 73 | To ensure good reproducibility, we trained and evaluated our network on publicly available datasets: 74 | * **MedMNIST**, a large-scale MNIST-like collection of standardized biomedical images. More information is available [here](https://medmnist.com/). 75 | 76 | * **MIT-BIH Arrhythmia**, a heartbeat classification dataset. We use a preprocessed version that is available [here](https://www.kaggle.com/datasets/shayanfazeli/heartbeat). 77 | 78 | * **BRATS 2023: Adult Glioma**, a dataset containing routine clinically-acquired, multi-site multiparametric magnetic resonance imaging (MRI) scans of brain tumor patients. We only use the T1-weighted images for training. The data is available [here](https://www.synapse.org/#!Synapse:syn51514105). 79 | 80 | * **LIDC-IDRI**, a dataset containing multi-site, thoracic computed tomography (CT) scans of lung cancer patients. The data is available [here](https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=1966254). 81 | 82 | The provided code works for the following data structure (you might need to adapt the directories in `data/dataset.py`): 83 | ``` 84 | data 85 | └───BRATS 86 | └───BraTS-GLI-00000-000 87 | └───BraTS-GLI-00000-000-seg.nii.gz 88 | └───BraTS-GLI-00000-000-t1c.nii.gz 89 | └───BraTS-GLI-00000-000-t1n.nii.gz 90 | └───BraTS-GLI-00000-000-t2f.nii.gz 91 | └───BraTS-GLI-00000-000-t2w.nii.gz 92 | └───BraTS-GLI-00001-000 93 | └───BraTS-GLI-00002-000 94 | ... 95 | 96 | └───LIDC-IDRI 97 | └───LIDC-IDRI-0001 98 | └───preprocessed.nii.gz 99 | └───LIDC-IDRI-0002 100 | └───LIDC-IDRI-0003 101 | ... 102 | 103 | └───MIT-BIH 104 | └───mitbih_test.csv 105 | └───mitbih_train.csv 106 | 107 | ... 108 | ``` 109 | We provide a script for preprocessing LIDC-IDRI. Simply run the following command with the correct path to the downloaded DICOM files `DICOM_PATH` and the directory where you want to store the processed NIfTI files `NIFTI_PATH`: 110 | ```sh 111 | python data/preproc_lidc-idri.py --dicom_dir DICOM_PATH --nifti_dir NIFTI_PATH 112 | ``` 113 | 114 | ## Acknowledgements 115 | Our code is based on / inspired by the following repositories: 116 | * https://github.com/jihoontack/GradNCP 117 | * https://github.com/google-deepmind/functa 118 | * https://github.com/pfriedri/wdm-3d 119 | -------------------------------------------------------------------------------- /train/maml_boot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from common.utils import psnr 3 | 4 | 5 | def get_grad_norm(grads, detach=True): 6 | grad_norm_list = [] 7 | for grad in grads: 8 | if grad is None: 9 | grad_norm = 0 10 | else: 11 | if detach: 12 | grad_norm = torch.norm(grad.data, p=2, keepdim=True).unsqueeze(dim=0) 13 | else: 14 | grad_norm = torch.norm(grad, p=2, keepdim=True).unsqueeze(dim=0) 15 | 16 | grad_norm_list.append(grad_norm) 17 | return torch.norm(torch.cat(grad_norm_list, dim=0), p=2, dim=1) 18 | 19 | 20 | def train_step(args, step, model_wrapper, optimizer, data, metric_logger, logger, scheduler=None): 21 | """ 22 | Function that performs a single meta update step.
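A single step runs G inner-loop updates of the per-signal modulation vectors on a (sub-sampled) coordinate set, followed by one outer (meta) update of the shared network weights on the full context set.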
23 | """ 24 | model_wrapper.model.train() # Enable model training 25 | model_wrapper.coord_init() # Reset coordinates 26 | model_wrapper.model.reset_modulations() # Reset modulations (zero-initialization) 27 | 28 | batch_size = data.size(0) 29 | 30 | if step % args.print_img_step == 0: 31 | input_data = data # Save input data for logging 32 | 33 | """ Inner-loop optimization for G steps """ 34 | _ = inner_adapt(model_wrapper=model_wrapper, data=data, step_size=args.inner_lr, 35 | num_steps=args.inner_steps, first_order=False, sample_type=args.sample_type) 36 | 37 | """ Compute loss using full context set """ 38 | model_wrapper.coord_init() # Reset coordinates 39 | loss_out = model_wrapper(data) # Compute reconstruction loss 40 | psnr_out = psnr(loss_out) 41 | 42 | if step % args.print_img_step == 0: 43 | images = model_wrapper() # Sample images 44 | loss = loss_out.mean() 45 | 46 | """ Meta update (optimize shared weights) """ 47 | optimizer.zero_grad() 48 | loss.backward() 49 | grad_norm = torch.nn.utils.clip_grad_norm_(model_wrapper.model.parameters(), 0.5) # Clip gradients 50 | optimizer.step() 51 | torch.cuda.synchronize() 52 | 53 | """ Update scheduler """ 54 | if scheduler: 55 | scheduler.step() 56 | 57 | """ Track stats """ 58 | metric_logger.meters['loss'].update(loss_out.mean().item(), n=batch_size) 59 | metric_logger.meters['psnr'].update(psnr_out.mean().item(), n=batch_size) 60 | metric_logger.meters['grad_norm'].update(grad_norm.item(), n=batch_size) 61 | metric_logger.synchronize_between_processes() 62 | 63 | """ Log scalars to tensorboard & console """ 64 | if step % args.print_step == 0: 65 | logger.scalar_summary('train/loss', metric_logger.loss.global_avg, step) 66 | logger.scalar_summary('train/psnr', metric_logger.psnr.global_avg, step) 67 | 68 | logger.log('[TRAIN-REC] [Step %3d] [Loss %f] [PSNR %.3f]' % 69 | (step, metric_logger.loss.global_avg, metric_logger.psnr.global_avg)) 70 | 71 | logger.scalar_summary('supp/grad_norm', metric_logger.grad_norm.global_avg, step) 72 | if scheduler: 73 | logger.scalar_summary('supp/lr', scheduler.get_last_lr()[0], step) 74 | 75 | """ Log images to tensorboard """ 76 | if step % args.print_img_step == 0: 77 | logger.log_image('train/img_in', input_data, step) 78 | logger.log_image('train/img_pred', images, step) 79 | 80 | """ Log activation distributions and weight dynamics """ 81 | if step % args.advanced_step == 0: 82 | if args.log_advanced: 83 | # Weight dynamics 84 | for name, param in model_wrapper.model.named_parameters(): 85 | if 'weight' in name: 86 | logger.log_hist(f'weights/{name}', param.data, step) 87 | if param.grad is not None: 88 | logger.log_hist(f'grads/{name}', param.grad, step) 89 | 90 | metric_logger.reset() 91 | 92 | 93 | def inner_adapt(model_wrapper, data, step_size=1e-2, num_steps=3, first_order=False, sample_type='none'): 94 | """ 95 | Performs the inner-loop optimization. 96 | :param model_wrapper: the wrapped model. 97 | :param data: the data used for training. 98 | :param step_size: the inner-loop learning rate (alpha). 99 | :param num_steps: number of inner-loop update steps G. 100 | :param first_order: if True, first-order MAML is used (not recommended). 101 | :param sample_type: coordinate sample type. 102 | :return: loss 103 | """ 104 | loss = 0.
# Initialize loss returned to the outer loop 105 | 106 | """ Perform num_step (G) inner-loop updates """ 107 | for _ in range(num_steps): 108 | if sample_type != 'none': 109 | model_wrapper.sample_coordinates(sample_type=sample_type, data=data) # Sample coordinates for the training step 110 | loss = inner_loop_step(model_wrapper, data, step_size, first_order) 111 | return loss 112 | 113 | 114 | def inner_loop_step(model_wrapper, data, inner_lr=1e-2, first_order=False): 115 | """ Performs a single inner-loop update. """ 116 | batch_size = data.size(0) 117 | 118 | with torch.enable_grad(): 119 | loss = model_wrapper(data) 120 | grads = torch.autograd.grad( 121 | loss.mean() * batch_size, 122 | model_wrapper.model.modulations, 123 | create_graph=not first_order, 124 | )[0] 125 | model_wrapper.model.modulations = model_wrapper.model.modulations - inner_lr * grads 126 | return loss 127 | 128 | 129 | def inner_adapt_test_scale(model_wrapper, data, step_size=1e-2, num_steps=3, first_order=False, 130 | sample_type='none', scale_type='grad'): 131 | """ Similar to inner_adapt, but with rescaled gradients at test-time. """ 132 | loss = 0. # Initialize loss returned to the outer loop 133 | 134 | """ Perform num_step (H) inner-loop updates """ 135 | for _ in range(num_steps): 136 | if sample_type != 'none': 137 | model_wrapper.sample_coordinates(sample_type=sample_type, data=data) 138 | loss = inner_loop_step_tt_gradscale(model_wrapper, data, step_size, first_order, scale_type) 139 | return loss 140 | 141 | 142 | def inner_loop_step_tt_gradscale(model_wrapper, data, inner_lr=1e-2, first_order=False, scale_type='grad'): 143 | """ Similar to inner_loop_step, but with rescaled gradients at test-time. """ 144 | batch_size = data.size(0) 145 | model_wrapper.model.zero_grad() 146 | 147 | """ Get gradients with sparse supervision (as in training) """ 148 | with torch.enable_grad(): 149 | subsample_loss = model_wrapper(data) 150 | subsample_grad = torch.autograd.grad( 151 | subsample_loss.mean() * batch_size, 152 | model_wrapper.model.modulations, 153 | create_graph=False, 154 | allow_unused=True 155 | )[0] 156 | 157 | model_wrapper.model.zero_grad() 158 | model_wrapper.coord_init() # Reset coordinates 159 | 160 | """ Get gradients with full supervision (during inference) """ 161 | with torch.enable_grad(): 162 | loss = model_wrapper(data) 163 | grads = torch.autograd.grad( 164 | loss.mean() * batch_size, 165 | model_wrapper.model.modulations, 166 | create_graph=not first_order, 167 | allow_unused=True 168 | )[0] 169 | 170 | """ Rescale the gradient """ 171 | if scale_type == 'grad': 172 | subsample_grad_norm = get_grad_norm(subsample_grad, detach=True) 173 | grad_norm = get_grad_norm(grads, detach=True) 174 | grad_scale = subsample_grad_norm / (grad_norm + 1e-16) 175 | grad_scale_ = grad_scale.view((batch_size,) + (1,) * (len(grads.shape) - 1)).detach() 176 | else: 177 | raise NotImplementedError() 178 | 179 | model_wrapper.model.modulations = model_wrapper.model.modulations - inner_lr * grad_scale_ * grads 180 | 181 | return loss 182 | -------------------------------------------------------------------------------- /models/model_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import random 5 | from einops import rearrange 6 | 7 | 8 | def exists(val): 9 | return val is not None 10 | 11 | 12 | class ModelWrapper(nn.Module): 13 | def __init__(self, args, model): 14 | super().__init__() 15 | self.args = args 16 | self.model =
model 17 | self.data_type = args.data_type 18 | 19 | self.sampled_coord = None 20 | self.sampled_index = None 21 | self.gradncp_coord = None 22 | self.gradncp_index = None 23 | 24 | if self.data_type == 'img': 25 | self.width = args.data_size[1] 26 | self.height = args.data_size[2] 27 | 28 | mgrid = self.shape_to_coords((self.width, self.height)) 29 | mgrid = rearrange(mgrid, 'h w c -> (h w) c') 30 | 31 | elif self.data_type == 'img3d': 32 | self.width = args.data_size[1] 33 | self.height = args.data_size[2] 34 | self.depth = args.data_size[3] 35 | 36 | mgrid = self.shape_to_coords((self.width, self.height, self.depth)) 37 | mgrid = rearrange(mgrid, 'h w d c -> (h w d) c') 38 | 39 | elif self.data_type == 'timeseries': 40 | self.length = args.data_size[-1] 41 | mgrid = self.shape_to_coords([self.length]) 42 | 43 | else: 44 | raise NotImplementedError() 45 | 46 | self.register_buffer('grid', mgrid) 47 | 48 | def coord_init(self): 49 | self.sampled_coord = None 50 | self.sampled_index = None 51 | self.gradncp_coord = None 52 | self.gradncp_index = None 53 | 54 | def get_batch_coords(self, x=None): 55 | if x is None: 56 | meta_batch_size = 1 57 | else: 58 | meta_batch_size = x.size(0) 59 | 60 | # batch of coordinates 61 | if self.sampled_coord is None and self.gradncp_coord is None: 62 | coords = self.grid 63 | elif self.gradncp_coord is not None: 64 | return self.gradncp_coord, meta_batch_size 65 | else: 66 | coords = self.sampled_coord 67 | coords = coords.clone().detach()[None, ...].repeat((meta_batch_size,) + (1,) * len(coords.shape)) 68 | return coords, meta_batch_size 69 | 70 | def shape_to_coords(self, spatial_dims): 71 | coords = [] 72 | for i in range(len(spatial_dims)): 73 | coords.append(torch.linspace(-1.0, 1.0, spatial_dims[i])) 74 | return torch.stack(torch.meshgrid(*coords, indexing='ij'), dim=-1) 75 | 76 | def sample_coordinates(self, sample_type, data): 77 | if sample_type == 'random': 78 | self.random_sample() 79 | elif sample_type == 'gradncp': 80 | self.gradncp(data) 81 | else: 82 | raise NotImplementedError() 83 | 84 | def gradncp(self, x): 85 | ratio = self.args.data_ratio 86 | meta_batch_size = x.size(0) 87 | coords = self.grid 88 | coords = coords.clone().detach()[None, ...].repeat((meta_batch_size,) + (1,) * len(coords.shape)) 89 | coords = coords.to(self.args.device) 90 | with torch.no_grad(): 91 | out, feature = self.model(coords, get_features=True) 92 | 93 | if self.data_type == 'img': 94 | out = rearrange(out, 'b hw c -> b c hw') 95 | feature = rearrange(feature, 'b hw f -> b f hw') 96 | x = rearrange(x, 'b c h w -> b c (h w)') 97 | elif self.data_type == 'img3d': 98 | out = rearrange(out, 'b hwd c -> b c hwd') 99 | feature = rearrange(feature, 'b hwd f -> b f hwd') 100 | x = rearrange(x, 'b c h w d -> b c (h w d)') 101 | elif self.data_type == 'timeseries': 102 | out = rearrange(out, 'b l c -> b c l') 103 | feature = rearrange(feature, 'b l f -> b f l') 104 | else: 105 | raise NotImplementedError() 106 | 107 | error = x - out 108 | 109 | gradient = -1 * feature.unsqueeze(dim=1) * error.unsqueeze(dim=2) 110 | gradient_bias = -1 * error.unsqueeze(dim=2) 111 | gradient = torch.cat([gradient, gradient_bias], dim=2) 112 | gradient = rearrange(gradient, 'b c f hw -> b (c f) hw') 113 | gradient_norm = torch.norm(gradient, dim=1) 114 | 115 | coords_len = gradient_norm.size(1) 116 | 117 | self.gradncp_index = torch.sort(gradient_norm, dim=1, descending=True)[1][:, :int(coords_len * ratio)] 118 | self.gradncp_coord = torch.gather(coords, 1, 
self.gradncp_index.unsqueeze(dim=2).repeat(1, 1, self.args.in_size)) 119 | self.gradncp_index = self.gradncp_index.unsqueeze(dim=1).repeat(1, self.args.out_size, 1) 120 | 121 | def random_sample(self): 122 | coord_size = self.grid.size(0) 123 | perm = torch.randperm(coord_size) 124 | self.sampled_index = perm[:int(self.args.data_ratio * coord_size)] 125 | self.sampled_coord = self.grid[self.sampled_index] 126 | return self.sampled_coord 127 | 128 | def forward(self, x=None): 129 | if self.data_type == 'img': 130 | return self.forward_img(x) 131 | elif self.data_type == 'img3d': 132 | return self.forward_img3d(x) 133 | elif self.data_type == 'timeseries': 134 | return self.forward_timeseries(x) 135 | else: 136 | raise NotImplementedError() 137 | 138 | def forward_img(self, x): 139 | coords, meta_batch_size = self.get_batch_coords(x) 140 | coords = coords.to(self.args.device) 141 | 142 | out = self.model(coords) 143 | out = rearrange(out, 'b hw c -> b c hw') 144 | 145 | if exists(x): 146 | if self.sampled_coord is None and self.gradncp_coord is None: 147 | return F.mse_loss(x.view(meta_batch_size, -1), out.reshape(meta_batch_size, -1), reduction='none').mean(dim=1) 148 | elif self.gradncp_coord is not None: 149 | x = rearrange(x, 'b c h w -> b c (h w)') 150 | x = torch.gather(x, 2, self.gradncp_index) 151 | return F.mse_loss(x.view(meta_batch_size, -1), out.reshape(meta_batch_size, -1), reduction='none').mean(dim=1) 152 | else: 153 | x = rearrange(x, 'b c h w -> b c (h w)')[:, :, self.sampled_index] 154 | return F.mse_loss(x.view(meta_batch_size, -1), out.reshape(meta_batch_size, -1), reduction='none').mean(dim=1) 155 | 156 | out = rearrange(out, 'b c (h w) -> b c h w', h=self.height, w=self.width) 157 | return out 158 | 159 | def forward_img3d(self, x): 160 | coords, meta_batch_size = self.get_batch_coords(x) 161 | coords = coords.to(self.args.device) 162 | 163 | out = self.model(coords) 164 | out = rearrange(out, 'b hwd c -> b c hwd') 165 | 166 | if exists(x): 167 | if self.sampled_coord is None and self.gradncp_coord is None: 168 | return F.mse_loss(x.view(meta_batch_size, -1), out.reshape(meta_batch_size, -1), reduction='none').mean(dim=1) 169 | elif self.gradncp_coord is not None: 170 | x = rearrange(x, 'b c h w d -> b c (h w d)') 171 | x = torch.gather(x, 2, self.gradncp_index) 172 | return F.mse_loss(x.view(meta_batch_size, -1), out.reshape(meta_batch_size, -1), reduction='none').mean(dim=1) 173 | else: 174 | x = rearrange(x, 'b c h w d -> b c (h w d)')[:, :, self.sampled_index] 175 | return F.mse_loss(x.view(meta_batch_size, -1), out.reshape(meta_batch_size, -1), reduction='none').mean(dim=1) 176 | 177 | out = rearrange(out, 'b c (h w d) -> b c h w d', h=self.height, w=self.width, d=self.depth) 178 | return out 179 | 180 | def forward_timeseries(self, x): 181 | coords, meta_batch_size = self.get_batch_coords(x) 182 | coords = coords.to(self.args.device) 183 | 184 | out = self.model(coords) 185 | out = rearrange(out, 'b l c -> b c l') 186 | 187 | if exists(x): 188 | if self.sampled_coord is None and self.gradncp_coord is None: 189 | return F.mse_loss(x.view(meta_batch_size, -1), out.reshape(meta_batch_size, -1), reduction='none').mean(dim=1) 190 | elif self.gradncp_coord is not None: 191 | x = torch.gather(x, 2, self.gradncp_index) 192 | return F.mse_loss(x.view(meta_batch_size, -1), out.reshape(meta_batch_size, -1), reduction='none').mean(dim=1) 193 | else: 194 | x = x[:, :, self.sampled_index] 195 | return F.mse_loss(x.view(meta_batch_size, -1), out.reshape(meta_batch_size, -1), reduction='none').mean(dim=1) 196 | return
out 197 | -------------------------------------------------------------------------------- /downstream_tasks/classification.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import sys 4 | 5 | sys.path.append(".") 6 | 7 | import time 8 | import random 9 | import numpy as np 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | import torchvision.transforms as T 13 | import torchvision.models as models 14 | from sklearn.neighbors import KNeighborsClassifier 15 | from sklearn.metrics import accuracy_score, f1_score 16 | from torch.utils.data import DataLoader, Dataset 17 | from torchmetrics.classification import MulticlassF1Score 18 | 19 | 20 | def set_random_seed(seed): 21 | random.seed(seed) 22 | np.random.seed(seed) 23 | torch.manual_seed(seed) 24 | 25 | 26 | class NFDataset(Dataset): 27 | def __init__(self, root_dir): 28 | self.root_dir = root_dir 29 | self.files = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.endswith('.pt')] 30 | 31 | def __len__(self): 32 | return len(self.files) 33 | 34 | def __getitem__(self, idx): 35 | data = torch.load(self.files[idx], weights_only=False) 36 | return data['modulations'].float(), data['label'] 37 | 38 | 39 | class SimpleClassifier(nn.Module): 40 | def __init__(self, input_dim, num_classes): 41 | super(SimpleClassifier, self).__init__() 42 | self.network = nn.Sequential( 43 | nn.Linear(input_dim, 512), 44 | nn.ReLU(), 45 | nn.Dropout(0.3), 46 | nn.Linear(512, 256), 47 | nn.ReLU(), 48 | nn.Dropout(0.3), 49 | nn.Linear(256, num_classes), 50 | ) 51 | 52 | def forward(self, x): 53 | return self.network(x) 54 | 55 | 56 | class ResNet50Classifier(nn.Module): 57 | def __init__(self, num_classes, mode='rgb'): 58 | super(ResNet50Classifier, self).__init__() 59 | self.resnet50 = models.resnet50() 60 | self.resnet50.fc = nn.Linear(self.resnet50.fc.in_features, num_classes) 61 | if mode == 'grayscale': 62 | self.resnet50.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7, 63 | stride=2, padding=3, bias=False) 64 | 65 | def forward(self, x): 66 | return self.resnet50(x) 67 | 68 | 69 | class EfficientNetB0Classifier(nn.Module): 70 | def __init__(self, num_classes, mode='rgb'): 71 | super(EfficientNetB0Classifier, self).__init__() 72 | self.efficientnet_b0 = models.efficientnet_b0() 73 | self.efficientnet_b0.classifier[1] = nn.Linear(self.efficientnet_b0.classifier[1].in_features, num_classes) 74 | if mode == 'grayscale': 75 | self.efficientnet_b0.features[0][0] = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, 76 | stride=2, padding=1, bias=False) 77 | 78 | def forward(self, x): 79 | return self.efficientnet_b0(x) 80 | 81 | 82 | def load_nf_data_for_knn(dataset): 83 | """Load all data from NFDataset into numpy arrays for KNN""" 84 | features = [] 85 | labels = [] 86 | 87 | for i in range(len(dataset)): 88 | data, label = dataset[i] 89 | features.append(data.numpy()) 90 | labels.append(label.item() if torch.is_tensor(label) else label) 91 | 92 | return np.array(features), np.array(labels) 93 | 94 | 95 | def train_knn(train_dataset, args): 96 | """Train KNN classifier""" 97 | print("Loading training data for KNN...") 98 | X_train, y_train = load_nf_data_for_knn(train_dataset) 99 | 100 | print(f"Training KNN with k={args.k_neighbors}...") 101 | knn = KNeighborsClassifier(n_neighbors=args.k_neighbors, n_jobs=-1) 102 | knn.fit(X_train, y_train) 103 | 104 | return knn 105 | 106 | 107 | def evaluate_knn(knn_model, dataset, num_classes): 108 | 
"""Evaluate KNN classifier""" 109 | print("Loading evaluation data for KNN...") 110 | X, y = load_nf_data_for_knn(dataset) 111 | 112 | predictions = knn_model.predict(X) 113 | accuracy = accuracy_score(y, predictions) * 100 114 | f1 = f1_score(y, predictions, average='macro' if num_classes > 2 else 'binary') 115 | 116 | return accuracy, f1 117 | 118 | 119 | def train(model, train_loader, criterion, optimizer, device, num_classes): 120 | model.train() 121 | total_loss = 0.0 122 | 123 | for data, labels in train_loader: 124 | data, labels = data.to(device), labels.to(device) 125 | labels = labels.squeeze() 126 | optimizer.zero_grad() 127 | out = model(data) 128 | loss = criterion(out, labels) 129 | loss.backward() 130 | optimizer.step() 131 | 132 | total_loss += loss.item() 133 | 134 | return total_loss / len(train_loader) 135 | 136 | 137 | def evaluate(model, val_loader, criterion, device, num_classes): 138 | model.eval() 139 | total_loss = 0.0 140 | correct = 0 141 | total = 0 142 | f1_metric = MulticlassF1Score(num_classes=num_classes).to(device) 143 | 144 | with torch.no_grad(): 145 | for data, labels in val_loader: 146 | data, labels = data.to(device), labels.to(device) 147 | labels = labels.squeeze() 148 | out = model(data) 149 | loss = criterion(out, labels) 150 | 151 | total_loss += loss.item() 152 | _, predicted = torch.max(out, 1) 153 | correct += (predicted == labels).sum().item() 154 | total += labels.size(0) 155 | 156 | f1_metric.update(predicted, labels) 157 | 158 | accuracy = 100 * correct / total 159 | f1_score = f1_metric.compute().item() 160 | 161 | return total_loss / len(val_loader), accuracy, f1_score 162 | 163 | 164 | class Args: 165 | def __init__(self): 166 | self.data_dir = "your/data/dir" # Data directory 167 | self.classifier = 'simple' # knn, simple, resnet, efficientnet 168 | self.mode = 'grayscale' # grayscale or rgb 169 | self.batch_size = 32 170 | self.input_dim = 2048 171 | self.num_classes = 7 # Number of classes 172 | self.learning_rate = 1e-3 173 | self.num_epochs = 50 174 | self.seed = 42 175 | self.k_neighbors = 3 # Number of neighbors for KNN 176 | 177 | 178 | def main(args): 179 | device = torch.device(f'cuda' if torch.cuda.is_available() else 'cpu') 180 | 181 | """ Enable determinism """ 182 | set_random_seed(args.seed) 183 | torch.backends.cudnn.deterministic = True 184 | torch.backends.cudnn.benchmark = False 185 | 186 | """ Define Dataset and Dataloader """ 187 | if args.classifier in ['simple', 'knn']: 188 | train_set = NFDataset(os.path.join(args.data_dir, "train")) 189 | val_set = NFDataset(os.path.join(args.data_dir, "val")) 190 | test_set = NFDataset(os.path.join(args.data_dir, "test")) 191 | 192 | elif args.classifier == 'resnet' or args.classifier == 'efficientnet': 193 | from medmnist import PneumoniaMNIST 194 | # from medmnist import DermaMNIST 195 | transforms = T.Compose([ 196 | T.ToTensor(), 197 | ]) 198 | train_set = PneumoniaMNIST(split='train', transform=transforms, download='True', size=64) 199 | val_set = PneumoniaMNIST(split='val', transform=transforms, download='True', size=64) 200 | test_set = PneumoniaMNIST(split='test', transform=transforms, download='True', size=64) 201 | # train_set = DermaMNIST(split='train', transform=transforms, download='True', size=64) 202 | # val_set = DermaMNIST(split='val', transform=transforms, download='True', size=64) 203 | # test_set = DermaMNIST(split='test', transform=transforms, download='True', size=64) 204 | else: 205 | raise NotImplementedError() 206 | 207 | # Handle KNN separately since it 
doesn't use PyTorch training loop 208 | if args.classifier == 'knn': 209 | print("Training KNN Classifier...") 210 | start_time = time.time() 211 | 212 | # Train KNN 213 | knn_model = train_knn(train_set, args) 214 | 215 | # Evaluate on validation set 216 | val_acc, val_f1 = evaluate_knn(knn_model, val_set, args.num_classes) 217 | print(f"Validation Accuracy: {val_acc:.2f}%, Validation F1: {val_f1:.4f}") 218 | 219 | # Evaluate on test set 220 | test_acc, test_f1 = evaluate_knn(knn_model, test_set, args.num_classes) 221 | 222 | end_time = time.time() 223 | elapsed_time = end_time - start_time 224 | print(f"Elapsed time: {elapsed_time:.2f} seconds") 225 | print(f"Test Accuracy: {test_acc:.2f}%, Test F1 Score: {test_f1:.4f}") 226 | 227 | return 228 | 229 | # For neural network classifiers 230 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True) 231 | val_loader = DataLoader(val_set, batch_size=args.batch_size) 232 | test_loader = DataLoader(test_set, batch_size=args.batch_size) 233 | 234 | """ Select classification model """ 235 | if args.classifier == 'simple': 236 | model = SimpleClassifier(args.input_dim, args.num_classes).to(device) 237 | elif args.classifier == 'resnet': 238 | model = ResNet50Classifier(args.num_classes, mode=args.mode).to(device) 239 | elif args.classifier == 'efficientnet': 240 | model = EfficientNetB0Classifier(args.num_classes, mode=args.mode).to(device) 241 | 242 | pytorch_total_params = sum(p.numel() for p in model.parameters()) 243 | print(f"Parameters: {pytorch_total_params}") 244 | 245 | """ Define optimization criterion and optimizer """ 246 | criterion = nn.CrossEntropyLoss() 247 | optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) 248 | 249 | """ Run training and validation loop """ 250 | best_val_acc = 0.0 251 | start_time = time.time() 252 | for epoch in range(args.num_epochs): 253 | train_loss = train(model, train_loader, criterion, optimizer, device, args.num_classes) 254 | val_loss, val_acc, val_f1 = evaluate(model, val_loader, criterion, device, args.num_classes) 255 | if val_acc > best_val_acc: 256 | best_val_acc = val_acc 257 | best_model = model.state_dict() 258 | 259 | print(f"Epoch {epoch + 1}/{args.num_epochs}: ", 260 | f"Train Loss: {train_loss:.4f} ", 261 | f"| Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%, Val F1: {val_f1:.4f}") 262 | 263 | end_time = time.time() 264 | elapsed_time = end_time - start_time 265 | print(f"Elapsed time: {elapsed_time} seconds") 266 | 267 | """ Final evaluation on test set """ 268 | model.load_state_dict(best_model) 269 | test_loss, test_acc, test_f1 = evaluate(model, test_loader, criterion, device, args.num_classes) 270 | print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%, Test F1 Score: {test_f1:.4f}") 271 | 272 | 273 | if __name__ == "__main__": 274 | args = Args() 275 | main(args) 276 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torchvision.transforms as T 5 | import pandas as pd 6 | import os 7 | import nibabel 8 | import cv2 as cv 9 | 10 | 11 | class ECG1D(torch.utils.data.Dataset): 12 | def __init__(self, directory, test=False): 13 | super().__init__() 14 | 15 | if not test: 16 | self.df = pd.read_csv(directory + '/mitbih_train.csv') 17 | else: 18 | self.df = pd.read_csv(directory + '/mitbih_test.csv') 19 | 20 | def __len__(self): 21 | 
return len(self.df) 22 | 23 | def __getitem__(self, idx): 24 | sample = self.df.iloc[idx, :-1].values.astype(float) 25 | label = self.df.iloc[idx, -1] 26 | sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(dim=0) 27 | label = torch.tensor(label, dtype=torch.long) 28 | 29 | return sample, label 30 | 31 | 32 | class BRATSVolumes(torch.utils.data.Dataset): 33 | def __init__(self, directory, normalize=None, img_size=32): 34 | super().__init__() 35 | self.directory = os.path.expanduser(directory) 36 | self.normalize = normalize or (lambda x: x) 37 | self.img_size = img_size 38 | self.seqtypes = ['t1n', 't1c', 't2w', 't2f', 'seg'] 39 | self.seqtypes_set = set(self.seqtypes) 40 | self.database = [] 41 | 42 | for root, dirs, files in os.walk(self.directory): 43 | # Ensure determinism 44 | dirs.sort() 45 | files.sort() 46 | # if there are no subdirs, we have a datadir 47 | if not dirs: 48 | datapoint = dict() 49 | # extract all files as channels 50 | for f in files: 51 | seqtype = f.split('-')[4].split('.')[0] 52 | datapoint[seqtype] = os.path.join(root, f) 53 | self.database.append(datapoint) 54 | 55 | def __getitem__(self, x): 56 | filedict = self.database[x] 57 | name = filedict['t1n'] 58 | nib_img = nibabel.load(name) # We only use t1 weighted images 59 | out = nib_img.get_fdata() 60 | 61 | # Clip and normalize the images 62 | out_clipped = np.clip(out, np.quantile(out, 0.001), np.quantile(out, 0.999)) 63 | out_normalized = (out_clipped - np.min(out_clipped)) / (np.max(out_clipped) - np.min(out_clipped)) 64 | out = torch.tensor(out_normalized) 65 | 66 | # Zero pad images 67 | image = torch.zeros(1, 256, 256, 256) 68 | image[:, 8:-8, 8:-8, 50:-51] = out 69 | 70 | # Downsampling 71 | if self.img_size == 32: 72 | downsample = nn.AvgPool3d(kernel_size=8, stride=8) 73 | image = downsample(image) 74 | 75 | if self.img_size == 64: 76 | downsample = nn.AvgPool3d(kernel_size=4, stride=4) 77 | image = downsample(image) 78 | 79 | # Normalization 80 | image = self.normalize(image) 81 | 82 | # Insert dummy label 83 | label = 1 84 | 85 | return image, label 86 | 87 | def __len__(self): 88 | return len(self.database) 89 | 90 | class LIDCVolumes(torch.utils.data.Dataset): 91 | def __init__(self, directory, normalize=None, img_size=32): 92 | super().__init__() 93 | self.directory = os.path.expanduser(directory) 94 | self.normalize = normalize or (lambda x: x) 95 | self.img_size = img_size 96 | self.database = [] 97 | 98 | for root, dirs, files in os.walk(self.directory): 99 | # Ensure determinism 100 | dirs.sort() 101 | files.sort() 102 | # if there are no subdirs, we have a datadir 103 | if not dirs: 104 | datapoint = dict() 105 | for f in files: 106 | datapoint['image'] = os.path.join(root, f) 107 | if len(datapoint) != 0: 108 | self.database.append(datapoint) 109 | 110 | def __getitem__(self, x): 111 | filedict = self.database[x] 112 | name = filedict['image'] 113 | nib_img = nibabel.load(name) 114 | out = nib_img.get_fdata() 115 | 116 | # Clip and normalize the images 117 | out_clipped = np.clip(out, np.quantile(out, 0.001), np.quantile(out, 0.999)) 118 | out_normalized = (out_clipped - np.min(out_clipped)) / (np.max(out_clipped) - np.min(out_clipped)) 119 | out = torch.tensor(out_normalized) 120 | 121 | image = torch.zeros(1, 256, 256, 256) 122 | image[:, :, :, :] = out 123 | 124 | if self.img_size == 32: 125 | downsample = nn.AvgPool3d(kernel_size=8, stride=8) 126 | image = downsample(image) 127 | 128 | if self.img_size == 64: 129 | downsample = nn.AvgPool3d(kernel_size=4, stride=4) 130 | image 
= downsample(image) 131 | 132 | # Normalization 133 | image = self.normalize(image) 134 | 135 | # Insert dummy label 136 | label = 1 137 | 138 | return image, label 139 | 140 | def __len__(self): 141 | return len(self.database) 142 | 143 | 144 | class EchoNet(torch.utils.data.Dataset): 145 | def __init__(self, directory, split='TRAIN', transform=None): 146 | self.data = pd.read_csv(directory + '/FileList.csv') 147 | self.data = self.data[self.data['Split'] == split] 148 | self.video_dir = directory + '/Videos/' 149 | self.max_frames = 10 150 | self.transform = transform # Fix: __getitem__ references self.transform, which was never set 151 | def __len__(self): 152 | return len(self.data) 153 | 154 | def __getitem__(self, idx): 155 | row = self.data.iloc[idx] 156 | filename = row['FileName'] 157 | ef = torch.tensor(row['EF'], dtype=torch.float32) 158 | 159 | video_path = os.path.join(self.video_dir, f"{filename}.avi") 160 | video = self.load_video(video_path) 161 | 162 | if self.transform: 163 | video = self.transform(video) 164 | 165 | return video, ef 166 | 167 | def load_video(self, path): 168 | cap = cv.VideoCapture(path) 169 | frames = [] 170 | 171 | while True: 172 | ret, frame = cap.read() 173 | if not ret: 174 | break 175 | frame = cv.cvtColor(frame, cv.COLOR_BGR2GRAY) # Convert to grayscale 176 | frame = cv.resize(frame, (112, 112)) # Resize if needed 177 | frames.append(frame) 178 | 179 | cap.release() 180 | frames = np.stack(frames, axis=0) 181 | frames = torch.tensor(frames, dtype=torch.float32) / 255.0 182 | frames = frames.unsqueeze(1) 183 | 184 | # Crop/pad to max_frames (num_frames instead of T, which would shadow the transforms alias) 185 | if self.max_frames: 186 | num_frames = frames.shape[0] 187 | if num_frames > self.max_frames: 188 | frames = frames[:self.max_frames] 189 | elif num_frames < self.max_frames: 190 | pad = torch.zeros((self.max_frames - num_frames, 1, 112, 112)) 191 | frames = torch.cat([frames, pad], dim=0) 192 | 193 | return frames 194 | 195 | 196 | def get_dataset(args, only_test=False, all=False): 197 | train_set = None 198 | val_set = None 199 | test_set = None 200 | 201 | ############################################# 202 | ############# 2D Image Datasets ############# 203 | ############################################# 204 | if args.dataset == 'chestmnist': 205 | from medmnist import ChestMNIST 206 | transforms = T.Compose([ 207 | T.ToTensor(), 208 | T.Grayscale() 209 | ]) 210 | 211 | train_set = ChestMNIST(split='train', transform=transforms, download='True', size=args.img_size) 212 | val_set = ChestMNIST(split='val', transform=transforms, download='True', size=args.img_size) 213 | test_set = ChestMNIST(split='test', transform=transforms, download='True', size=args.img_size) 214 | 215 | print(f'Training set containing {len(train_set)} images.') 216 | print(f'Validation set containing {len(val_set)} images.') 217 | print(f'Test set containing {len(test_set)} images.') 218 | 219 | args.data_type = 'img' 220 | args.in_size, args.out_size = 2, 1 221 | args.data_size = (1, args.img_size, args.img_size) 222 | 223 | elif args.dataset == 'pneumoniamnist': 224 | from medmnist import PneumoniaMNIST 225 | transforms = T.Compose([ 226 | T.ToTensor(), 227 | T.Grayscale() 228 | ]) 229 | 230 | train_set = PneumoniaMNIST(split='train', transform=transforms, download='True', size=args.img_size) 231 | val_set = PneumoniaMNIST(split='val', transform=transforms, download='True', size=args.img_size) 232 | test_set = PneumoniaMNIST(split='test', transform=transforms, download='True', size=args.img_size) 233 | 234 | print(f'Training set containing {len(train_set)} images.') 235 | print(f'Validation set containing {len(val_set)} images.') 236 | print(f'Test set containing
{len(test_set)} images.') 237 | 238 | args.data_type = 'img' 239 | args.in_size, args.out_size = 2, 1 240 | args.data_size = (1, args.img_size, args.img_size) 241 | 242 | elif args.dataset == 'retinamnist': 243 | from medmnist import RetinaMNIST 244 | transforms = T.Compose([ 245 | T.ToTensor(), 246 | ]) 247 | train_set = RetinaMNIST(split='train', transform=transforms, download='True', size=args.img_size) 248 | val_set = RetinaMNIST(split='val', transform=transforms, download='True', size=args.img_size) 249 | test_set = RetinaMNIST(split='test', transform=transforms, download='True', size=args.img_size) 250 | 251 | print(f'Training set containing {len(train_set)} images.') 252 | print(f'Validation set containing {len(val_set)} images.') 253 | print(f'Test set containing {len(test_set)} images.') 254 | 255 | args.data_type = 'img' 256 | args.in_size, args.out_size = 2, 3 257 | args.data_size = (3, args.img_size, args.img_size) 258 | 259 | elif args.dataset == 'dermamnist': 260 | from medmnist import DermaMNIST 261 | transforms = T.Compose([ 262 | T.ToTensor(), 263 | ]) 264 | train_set = DermaMNIST(split='train', transform=transforms, download='True', size=args.img_size) 265 | val_set = DermaMNIST(split='val', transform=transforms, download='True', size=args.img_size) 266 | test_set = DermaMNIST(split='test', transform=transforms, download='True', size=args.img_size) 267 | 268 | print(f'Training set containing {len(train_set)} images.') 269 | print(f'Validation set containing {len(val_set)} images.') 270 | print(f'Test set containing {len(test_set)} images.') 271 | 272 | args.data_type = 'img' 273 | args.in_size, args.out_size = 2, 3 274 | args.data_size = (3, args.img_size, args.img_size) 275 | 276 | elif args.dataset == 'octmnist': 277 | from medmnist import OCTMNIST 278 | transforms = T.Compose([ 279 | T.ToTensor(), 280 | T.Grayscale() 281 | ]) 282 | train_set = OCTMNIST(split='train', transform=transforms, download='True', size=args.img_size) 283 | val_set = OCTMNIST(split='val', transform=transforms, download='True', size=args.img_size) 284 | test_set = OCTMNIST(split='test', transform=transforms, download='True', size=args.img_size) 285 | 286 | print(f'Training set containing {len(train_set)} images.') 287 | print(f'Validation set containing {len(val_set)} images.') 288 | print(f'Test set containing {len(test_set)} images.') 289 | 290 | args.data_type = 'img' 291 | args.in_size, args.out_size = 2, 1 292 | args.data_size = (1, args.img_size, args.img_size) 293 | 294 | elif args.dataset == 'pathmnist': 295 | from medmnist import PathMNIST 296 | transforms = T.Compose([ 297 | T.ToTensor(), 298 | ]) 299 | train_set = PathMNIST(split='train', transform=transforms, download='True', size=args.img_size) 300 | val_set = PathMNIST(split='val', transform=transforms, download='True', size=args.img_size) 301 | test_set = PathMNIST(split='test', transform=transforms, download='True', size=args.img_size) 302 | 303 | print(f'Training set containing {len(train_set)} images.') 304 | print(f'Validation set containing {len(val_set)} images.') 305 | print(f'Test set containing {len(test_set)} images.') 306 | 307 | args.data_type = 'img' 308 | args.in_size, args.out_size = 2, 3 309 | args.data_size = (3, args.img_size, args.img_size) 310 | 311 | elif args.dataset == 'tissuemnist': 312 | from medmnist import TissueMNIST 313 | transforms = T.Compose([ 314 | T.ToTensor(), 315 | T.Grayscale() 316 | ]) 317 | train_set = TissueMNIST(split='train', transform=transforms, download='True', size=args.img_size) 318 | 
val_set = TissueMNIST(split='val', transform=transforms, download='True', size=args.img_size) 319 | test_set = TissueMNIST(split='test', transform=transforms, download='True', size=args.img_size) 320 | 321 | print(f'Training set containing {len(train_set)} images.') 322 | print(f'Validation set containing {len(val_set)} images.') 323 | print(f'Test set containing {len(test_set)} images.') 324 | 325 | args.data_type = 'img' 326 | args.in_size, args.out_size = 2, 1 327 | args.data_size = (1, args.img_size, args.img_size) 328 | 329 | ############################################# 330 | ############# 3D Image Datasets ############# 331 | ############################################# 332 | elif args.dataset == 'nodulemnist': 333 | from medmnist import NoduleMNIST3D 334 | train_set = NoduleMNIST3D(split='train', download='True', size=args.img_size) 335 | val_set = NoduleMNIST3D(split='val', download='True', size=args.img_size) 336 | test_set = NoduleMNIST3D(split='test', download='True', size=args.img_size) 337 | 338 | print(f'Training set containing {len(train_set)} images.') 339 | print(f'Validation set containing {len(val_set)} images.') 340 | print(f'Test set containing {len(test_set)} images.') 341 | 342 | args.data_type = 'img3d' 343 | args.in_size, args.out_size = 3, 1 344 | args.data_size = (1, args.img_size, args.img_size, args.img_size) 345 | 346 | elif args.dataset == 'organmnist': 347 | from medmnist import OrganMNIST3D 348 | train_set = OrganMNIST3D(split='train', download='True', size=args.img_size) 349 | val_set = OrganMNIST3D(split='val', download='True', size=args.img_size) 350 | test_set = OrganMNIST3D(split='test', download='True', size=args.img_size) 351 | 352 | print(f'Training set containing {len(train_set)} images.') 353 | print(f'Validation set containing {len(val_set)} images.') 354 | print(f'Test set containing {len(test_set)} images.') 355 | 356 | args.data_type = 'img3d' 357 | args.in_size, args.out_size = 3, 1 358 | args.data_size = (1, args.img_size, args.img_size, args.img_size) 359 | 360 | elif args.dataset == 'vesselmnist': 361 | from medmnist import VesselMNIST3D 362 | train_set = VesselMNIST3D(split='train', download='True', size=args.img_size) 363 | val_set = VesselMNIST3D(split='val', download='True', size=args.img_size) 364 | test_set = VesselMNIST3D(split='test', download='True', size=args.img_size) 365 | 366 | print(f'Training set containing {len(train_set)} images.') 367 | print(f'Validation set containing {len(val_set)} images.') 368 | print(f'Test set containing {len(test_set)} images.') 369 | 370 | args.data_type = 'shape3d' 371 | args.in_size, args.out_size = 3, 1 372 | args.data_size = (1, args.img_size, args.img_size, args.img_size) 373 | 374 | elif args.dataset == 'brats': 375 | dataset = BRATSVolumes('/raid/cian/user/paul.friedrich/datasets/BRATS2023-GLI/', img_size=args.img_size) 376 | 377 | # Define split sizes 378 | train_size = int(0.7 * len(dataset)) # 70% for training 379 | test_size = (len(dataset) - train_size) // 2 # 15% for testing 380 | val_size = len(dataset) - train_size - test_size # 15% for validation 381 | 382 | generator = torch.Generator().manual_seed(42) 383 | train_set, test_set, val_set = torch.utils.data.random_split(dataset, [train_size, test_size, val_size], generator=generator) 384 | 385 | print(f'Training set containing {len(train_set)} images.') 386 | print(f'Validation set containing {len(val_set)} images.') 387 | print(f'Test set containing {len(test_set)} images.') 388 | 389 | args.data_type = 'img3d' 390 | 
args.in_size, args.out_size = 3, 1 391 | args.data_size = (1, args.img_size, args.img_size, args.img_size) 392 | 393 | elif args.dataset == 'lidc-idri': 394 | dataset = LIDCVolumes('/raid/cian/user/paul.friedrich/datasets/lidc-nifti/', img_size=args.img_size) 395 | 396 | # Define split sizes 397 | train_size = int(0.7 * len(dataset)) # 70% for training 398 | test_size = (len(dataset) - train_size) // 2 # 15% for testing 399 | val_size = len(dataset) - train_size - test_size # 15% for validation 400 | 401 | generator = torch.Generator().manual_seed(42) 402 | train_set, test_set, val_set = torch.utils.data.random_split(dataset, [train_size, test_size, val_size], generator=generator) 403 | 404 | print(f'Training set containing {len(train_set)} images.') 405 | print(f'Validation set containing {len(val_set)} images.') 406 | print(f'Test set containing {len(test_set)} images.') 407 | 408 | args.data_type = 'img3d' 409 | args.in_size, args.out_size = 3, 1 410 | args.data_size = (1, args.img_size, args.img_size, args.img_size) 411 | 412 | ############################################# 413 | ############ 2D+t Video Datasets ############ 414 | ############################################# 415 | elif args.dataset == 'echonet': 416 | train_set = EchoNet('/raid/cian/user/paul.friedrich/datasets/EchoNet/', split='TRAIN') 417 | val_set = EchoNet('/raid/cian/user/paul.friedrich/datasets/EchoNet/', split='VAL') 418 | test_set = EchoNet('/raid/cian/user/paul.friedrich/datasets/EchoNet/', split='TEST') 419 | 420 | print(f'Training set containing {len(train_set)} videos.') 421 | print(f'Validation set containing {len(val_set)} videos.') 422 | print(f'Test set containing {len(test_set)} videos.') 423 | 424 | args.data_type='img3d' 425 | args.in_size, args.out_size = 3, 1 426 | args.data_size = (1, 10, args.img_size, args.img_size) 427 | 428 | ############################################# 429 | ########## 1D Timeseries Datasets ########### 430 | ############################################# 431 | elif args.dataset == 'ecg': 432 | train_set = ECG1D('/home/paul.friedrich/ecg_classification/', test=False) 433 | valtest_set = ECG1D('/home/paul.friedrich/ecg_classification/', test=True) 434 | 435 | test_size = len(valtest_set) // 2 436 | val_size = len(valtest_set) - test_size 437 | 438 | generator = torch.Generator().manual_seed(42) 439 | test_set, val_set = torch.utils.data.random_split(valtest_set, [test_size, val_size], generator=generator) 440 | 441 | print(f'Training set containing {len(train_set)} ECG signals.') 442 | print(f'Validation set containing {len(val_set)} ECG signals.') 443 | print(f'Test set containing {len(test_set)} ECG signals.') 444 | 445 | args.data_type = 'timeseries' 446 | args.in_size, args.out_size = 1, 1 447 | args.data_size = (1, args.img_size) 448 | 449 | else: 450 | raise NotImplementedError() 451 | 452 | if only_test: 453 | return test_set 454 | 455 | elif all: 456 | return train_set, val_set, test_set 457 | 458 | else: 459 | return train_set, val_set 460 | -------------------------------------------------------------------------------- /common/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | import time 5 | import pickle 6 | import numpy as np 7 | import random 8 | import torch 9 | import matplotlib.pyplot as plt 10 | import torch.distributed as dist 11 | 12 | from collections import OrderedDict, defaultdict, deque 13 | from datetime import datetime 14 | from torch.utils.tensorboard import 
SummaryWriter 15 | 16 | 17 | def set_random_seed(seed): 18 | random.seed(seed) 19 | np.random.seed(seed) 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | 24 | 25 | def load_checkpoint(logdir, mode='last'): 26 | model_path = os.path.join(logdir, f'{mode}.model') 27 | optim_path = os.path.join(logdir, f'{mode}.optim') 28 | config_path = os.path.join(logdir, f'{mode}.configs') 29 | lr_path = os.path.join(logdir, f'{mode}.lr') 30 | 31 | print(model_path) 32 | print(optim_path) 33 | 34 | print("=> Loading checkpoint from '{}'".format(logdir)) 35 | if os.path.exists(model_path): 36 | model_state = torch.load(model_path) 37 | optim_state = torch.load(optim_path) 38 | with open(config_path, 'rb') as handle: 39 | cfg = pickle.load(handle) 40 | else: 41 | return None, None, None, None 42 | 43 | if os.path.exists(lr_path): 44 | lr_dict = torch.load(lr_path) 45 | else: 46 | lr_dict = None 47 | 48 | return model_state, optim_state, cfg, lr_dict 49 | 50 | 51 | def save_checkpoint(args, step, best_psnr, model, optim_state, logdir, is_best=False, suffix=''): 52 | if is_best: 53 | prefix = 'best' 54 | else: 55 | prefix = 'last' 56 | 57 | model_state = model.state_dict() 58 | 59 | last_model = os.path.join(logdir, f'{prefix}{suffix}.model') 60 | last_optim = os.path.join(logdir, f'{prefix}{suffix}.optim') 61 | last_config = os.path.join(logdir, f'{prefix}{suffix}.configs') 62 | 63 | if isinstance(args.inner_lr, OrderedDict): 64 | last_lr = os.path.join(logdir, f'{prefix}{suffix}.lr') 65 | torch.save(args.inner_lr, last_lr) 66 | if hasattr(args, 'moving_average'): 67 | last_ema = os.path.join(logdir, f'{prefix}{suffix}.ema') 68 | torch.save(args.moving_average, last_ema) 69 | if hasattr(args, 'moving_inner_lr'): 70 | last_lr_ema = os.path.join(logdir, f'{prefix}{suffix}.lr_ema') 71 | torch.save(args.moving_inner_lr, last_lr_ema) 72 | 73 | opt = { 74 | 'step': step, 75 | 'best': best_psnr 76 | } 77 | torch.save(model_state, last_model) 78 | torch.save(optim_state, last_optim) 79 | with open(last_config, 'wb') as handle: 80 | pickle.dump(opt, handle, protocol=pickle.HIGHEST_PROTOCOL) 81 | 82 | 83 | def save_checkpoint_step(args, step, best_psnr, model, optim_state, logdir, suffix=''): 84 | model_state = model.state_dict() 85 | 86 | last_model = os.path.join(logdir, f'step{step}{suffix}.model') 87 | last_optim = os.path.join(logdir, f'step{step}{suffix}.optim') 88 | last_config = os.path.join(logdir, f'step{step}{suffix}.configs') 89 | 90 | if isinstance(args.inner_lr, OrderedDict): 91 | last_lr = os.path.join(logdir, f'step{step}{suffix}.lr') 92 | torch.save(args.inner_lr, last_lr) 93 | if hasattr(args, 'moving_average'): 94 | last_ema = os.path.join(logdir, f'step{step}{suffix}.ema') 95 | torch.save(args.moving_average, last_ema) 96 | if hasattr(args, 'moving_inner_lr'): 97 | last_lr_ema = os.path.join(logdir, f'step{step}{suffix}.lr_ema') 98 | torch.save(args.moving_inner_lr, last_lr_ema) 99 | 100 | opt = { 101 | 'step': step, 102 | 'best': best_psnr 103 | } 104 | torch.save(model_state, last_model) 105 | torch.save(optim_state, last_optim) 106 | with open(last_config, 'wb') as handle: 107 | pickle.dump(opt, handle, protocol=pickle.HIGHEST_PROTOCOL) 108 | 109 | 110 | def resume_training(args, model, optimizer): 111 | if args.resume_path is not None: 112 | model_state, optimizer_state, config, lr_dict = load_checkpoint(args.resume_path, mode='best') 113 | model.load_state_dict(model_state) 114 | optimizer.load_state_dict(optimizer_state) 115 | 
start_step = config['step'] 116 | best_psnr = config['best'] 117 | is_best = False 118 | psnr = 0. 119 | 120 | if lr_dict is not None: 121 | args.inner_lr = lr_dict 122 | 123 | else: 124 | is_best = False 125 | start_step = 1 126 | best_psnr = 0. 127 | psnr = 0. 128 | return is_best, start_step, best_psnr, psnr 129 | 130 | 131 | def is_dist_avail_and_initialized(): 132 | if not dist.is_available(): 133 | return False 134 | if not dist.is_initialized(): 135 | return False 136 | return True 137 | 138 | 139 | class Logger(object): 140 | """ 141 | Reference: https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514 142 | """ 143 | 144 | def __init__(self, fn, ask=True, today=True, rank=0, grid_search=False): 145 | self.rank = rank 146 | self.log_path = './logs/' 147 | self.logdir = None 148 | 149 | if grid_search: 150 | self.log_path = './logs_gridsearch/' 151 | 152 | if self.rank == 0: 153 | if not os.path.exists(self.log_path): 154 | os.mkdir(self.log_path) 155 | self.today = today 156 | 157 | logdir = self._make_dir(fn) 158 | 159 | if not os.path.exists(logdir): 160 | os.mkdir(logdir) 161 | 162 | if len(os.listdir(logdir)) != 0 and ask: 163 | ans = input("log_dir is not empty. All data inside log_dir will be deleted. " 164 | "Will you proceed [y/N]? ") 165 | if ans in ['y', 'Y']: 166 | shutil.rmtree(logdir) 167 | else: 168 | exit(1) 169 | 170 | self.set_dir(logdir) 171 | 172 | def _make_dir(self, fn): 173 | if self.today: 174 | today = datetime.today().strftime("%y%m%d") 175 | logdir = self.log_path + today + '_' + fn 176 | else: 177 | logdir = self.log_path + fn 178 | return logdir 179 | 180 | def set_dir(self, logdir, log_fn='log.txt'): 181 | self.logdir = logdir 182 | if not os.path.exists(logdir): 183 | os.mkdir(logdir) 184 | self.writer = SummaryWriter(logdir) 185 | self.log_file = open(os.path.join(logdir, log_fn), 'a') 186 | 187 | def close_writer(self): 188 | if self.rank == 0: 189 | self.writer.close() 190 | 191 | def log(self, string): 192 | if self.rank == 0: 193 | self.log_file.write('[%s] %s' % (datetime.now(), string) + '\n') 194 | self.log_file.flush() 195 | 196 | print('[%s] %s' % (datetime.now(), string)) 197 | sys.stdout.flush() 198 | 199 | def log_dirname(self, string): 200 | if self.rank == 0: 201 | self.log_file.write('%s (%s)' % (string, self.logdir) + '\n') 202 | self.log_file.flush() 203 | 204 | print('%s (%s)' % (string, self.logdir)) 205 | sys.stdout.flush() 206 | 207 | def save_df(self, dataframe): 208 | if self.rank == 0: 209 | filename = self.logdir + '/ww.csv' 210 | print(filename) 211 | dataframe.to_csv(filename, sep='\t', header=True) 212 | 213 | def scalar_summary(self, tag, value, step): 214 | """Log a scalar variable.""" 215 | if self.rank == 0: 216 | self.writer.add_scalar(tag, value, step) 217 | 218 | def log_hist(self, tag, value, step): 219 | if self.rank == 0: 220 | self.writer.add_histogram(tag, value, step) 221 | 222 | def log_hyperparameters(self, args): 223 | if self.rank == 0: 224 | self.writer.add_text( 225 | 'config', 226 | '\n'.join([f'--{k}={repr(v)}
    def log_hparams(self, h_dict, m_dict):
        if self.rank == 0:
            self.writer.add_hparams(h_dict, m_dict, run_name='.')

    def log_image(self, tag, images, step):
        """Log an image tensor."""
        if self.rank == 0:
            if len(images.shape) == 3:  # Timeseries
                x = torch.arange(1, images.shape[2] + 1).numpy()
                plt.figure(figsize=(10, 6))
                # Plot up to six example signals (guard against smaller batches).
                for i in range(min(6, images.shape[0])):
                    y = images[i, 0, :].detach().cpu().numpy()
                    plt.plot(x, y, label=f"ECG {i+1}")
                plt.ylabel("Signal Value")
                plt.grid(True)
                plt.legend()

                plt.tight_layout()
                self.writer.add_figure(tag, plt.gcf(), step)

            elif len(images.shape) == 4:  # 2D Images
                self.writer.add_images(tag, images, step)

            elif len(images.shape) == 5:  # 3D Images
                # Log middle slices along all 3 dimensions
                batch_size, channels, depth, height, width = images.shape

                # Select the middle slices
                middle_depth = depth // 2
                middle_height = height // 2
                middle_width = width // 2

                # Extract middle slices along each axis
                slices_depth = images[:, :, middle_depth, :, :]    # Middle slice along depth
                slices_height = images[:, :, :, middle_height, :]  # Middle slice along height
                slices_width = images[:, :, :, :, middle_width]    # Middle slice along width

                # Log slices with meaningful tags
                self.writer.add_images(f"{tag}_slice_depth", slices_depth, step)
                self.writer.add_images(f"{tag}_slice_height", slices_height, step)
                self.writer.add_images(f"{tag}_slice_width", slices_width, step)


class SmoothedValue(object):
    """
    Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def reset(self):
        self.deque.clear()
        self.total = 0.0
        self.count = 0

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
298 | """ 299 | if not is_dist_avail_and_initialized(): 300 | return 301 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 302 | dist.barrier() 303 | dist.all_reduce(t) 304 | t = t.tolist() 305 | self.count = int(t[0]) 306 | self.total = t[1] 307 | 308 | @property 309 | def median(self): 310 | d = torch.tensor(list(self.deque)) 311 | return d.median().item() 312 | 313 | @property 314 | def avg(self): 315 | d = torch.tensor(list(self.deque), dtype=torch.float32) 316 | return d.mean().item() 317 | 318 | @property 319 | def global_avg(self): 320 | return self.total / self.count 321 | 322 | @property 323 | def max(self): 324 | return max(self.deque) 325 | 326 | @property 327 | def value(self): 328 | return self.deque[-1] 329 | 330 | def __str__(self): 331 | return self.fmt.format( 332 | median=self.median, 333 | avg=self.avg, 334 | global_avg=self.global_avg, 335 | max=self.max, 336 | value=self.value) 337 | 338 | 339 | class MetricLogger(object): 340 | def __init__(self, delimiter="\t"): 341 | self.meters = defaultdict(SmoothedValue) 342 | self.delimiter = delimiter 343 | 344 | def update(self, **kwargs): 345 | for k, v in kwargs.items(): 346 | if v is None: 347 | continue 348 | if isinstance(v, torch.Tensor): 349 | v = v.item() 350 | assert isinstance(v, (float, int)) 351 | self.meters[k].update(v) 352 | 353 | def __getattr__(self, attr): 354 | if attr in self.meters: 355 | return self.meters[attr] 356 | if attr in self.__dict__: 357 | return self.__dict__[attr] 358 | raise AttributeError("'{}' object has no attribute '{}'".format( 359 | type(self).__name__, attr)) 360 | 361 | def __str__(self): 362 | loss_str = [] 363 | for name, meter in self.meters.items(): 364 | loss_str.append( 365 | "{}: {}".format(name, str(meter)) 366 | ) 367 | return self.delimiter.join(loss_str) 368 | 369 | def synchronize_between_processes(self): 370 | for meter in self.meters.values(): 371 | meter.synchronize_between_processes() 372 | 373 | def add_meter(self, name, meter): 374 | self.meters[name] = meter 375 | 376 | def reset(self): 377 | for meter in self.meters.values(): 378 | meter.reset() 379 | 380 | def log_every(self, iterable, print_freq, header=None): 381 | i = 0 382 | if not header: 383 | header = '' 384 | start_time = time.time() 385 | end = time.time() 386 | iter_time = SmoothedValue(fmt='{avg:.4f}') 387 | data_time = SmoothedValue(fmt='{avg:.4f}') 388 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 389 | log_msg = [ 390 | header, 391 | '[{0' + space_fmt + '}/{1}]', 392 | 'eta: {eta}', 393 | '{meters}', 394 | 'time: {time}', 395 | 'data: {data}' 396 | ] 397 | if torch.cuda.is_available(): 398 | log_msg.append('max mem: {memory:.0f}') 399 | log_msg = self.delimiter.join(log_msg) 400 | MB = 1024.0 * 1024.0 401 | for obj in iterable: 402 | data_time.update(time.time() - end) 403 | yield obj 404 | iter_time.update(time.time() - end) 405 | if i % print_freq == 0 or i == len(iterable) - 1: 406 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 407 | eta_string = str(datetime.datetime.timedelta(seconds=int(eta_seconds))) 408 | if torch.cuda.is_available(): 409 | print(log_msg.format( 410 | i, len(iterable), eta=eta_string, 411 | meters=str(self), 412 | time=str(iter_time), data=str(data_time), 413 | memory=torch.cuda.max_memory_allocated() / MB)) 414 | else: 415 | print(log_msg.format( 416 | i, len(iterable), eta=eta_string, 417 | meters=str(self), 418 | time=str(iter_time), data=str(data_time))) 419 | i += 1 420 | end = time.time() 421 | total_time = time.time() 
- start_time 422 | total_time_str = str(datetime.datetime.timedelta(seconds=int(total_time))) 423 | print('{} Total time: {} ({:.4f} s / it)'.format( 424 | header, total_time_str, total_time / len(iterable))) 425 | 426 | 427 | def psnr(mse): 428 | return -10.0 * torch.log10(mse+1e-24) 429 | 430 | def dice_score(pred, target, epsilon=1e-6): 431 | """ 432 | Computes the Dice score between two batched image tensors. 433 | Args: 434 | pred (torch.Tensor): Tensor of shape (B, ...) containing predicted masks. 435 | target (torch.Tensor): Tensor of shape (B, ...) containing ground truth masks. 436 | epsilon (float): Small constant to avoid division by zero. 437 | Returns: 438 | torch.Tensor: Dice score for each item in the batch (shape: [B]) 439 | """ 440 | # Ensure the input shapes are the same 441 | if pred.shape != target.shape: 442 | raise ValueError(f"Shape mismatch: pred {pred.shape}, target {target.shape}") 443 | 444 | # Flatten each spatial dimension per batch item 445 | B = pred.shape[0] 446 | pred_flat = pred.view(B, -1) 447 | target_flat = target.view(B, -1) 448 | 449 | # Compute Dice coefficient 450 | intersection = (pred_flat * target_flat).sum(dim=1) 451 | union = pred_flat.sum(dim=1) + target_flat.sum(dim=1) 452 | 453 | dice = (2.0 * intersection + epsilon) / (union + epsilon) 454 | return dice 455 | 456 | 457 | def _gaussian_window(window_size: int, sigma: float, device: torch.device, dtype: torch.dtype): 458 | coords = torch.arange(window_size, device=device, dtype=dtype) - window_size // 2 459 | g = torch.exp(-(coords**2) / (2 * sigma**2)) 460 | g /= g.sum() 461 | return g.view(1, 1, -1) 462 | 463 | 464 | def ssim_1d(x: torch.Tensor, y: torch.Tensor, window_size: int = 11, sigma: float = 1.5, 465 | data_range: float = None, K1: float = 0.01, K2: float = 0.03) -> torch.Tensor: 466 | """ 467 | Compute Structural Similarity Index (SSIM) for 1D time series. 468 | Parameters 469 | ---------- 470 | x, y : torch.Tensor 471 | Input 1D tensors (same length) or 2D (batch, length). 472 | window_size : int 473 | Size of sliding window (odd number). 474 | sigma : float 475 | Gaussian kernel standard deviation for local statistics. 476 | data_range : float 477 | Value range of the input (max - min). If None, inferred from data. 478 | K1, K2 : float 479 | Constants for stability in SSIM formula. 480 | Returns 481 | ------- 482 | ssim : torch.Tensor 483 | Mean SSIM over the signal (scalar). 
484 | """ 485 | if x.shape != y.shape: 486 | raise ValueError("Input signals must have the same shape") 487 | if x.dim() == 1: 488 | x = x.unsqueeze(0).unsqueeze(0) # (1,1,L) 489 | y = y.unsqueeze(0).unsqueeze(0) 490 | elif x.dim() == 2: 491 | x = x.unsqueeze(1) # (B,1,L) 492 | y = y.unsqueeze(1) 493 | else: 494 | raise ValueError("Input tensors must be 1D or 2D (batch, length)") 495 | if data_range is None: 496 | data_range = torch.max(torch.cat([x.max().unsqueeze(0) - x.min().unsqueeze(0), 497 | y.max().unsqueeze(0) - y.min().unsqueeze(0)])) 498 | C1 = (K1 * data_range) ** 2 499 | C2 = (K2 * data_range) ** 2 500 | # Gaussian kernel 501 | window = _gaussian_window(window_size, sigma, x.device, x.dtype) 502 | # Local means 503 | mu_x = F.conv1d(x, window, padding=window_size//2) 504 | mu_y = F.conv1d(y, window, padding=window_size//2) 505 | mu_x_sq = mu_x ** 2 506 | mu_y_sq = mu_y ** 2 507 | mu_xy = mu_x * mu_y 508 | # Variances and covariance 509 | sigma_x_sq = F.conv1d(x * x, window, padding=window_size//2) - mu_x_sq 510 | sigma_y_sq = F.conv1d(y * y, window, padding=window_size//2) - mu_y_sq 511 | sigma_xy = F.conv1d(x * y, window, padding=window_size//2) - mu_xy 512 | # SSIM map 513 | numerator = (2 * mu_xy + C1) * (2 * sigma_xy + C2) 514 | denominator = (mu_x_sq + mu_y_sq + C1) * (sigma_x_sq + sigma_y_sq + C2) 515 | ssim_map = numerator / denominator 516 | return ssim_map 517 | 518 | 519 | class InfiniteSampler(torch.utils.data.Sampler): 520 | """ 521 | A PyTorch Sampler that provides an infinite stream of indices from the dataset, 522 | optionally shuffling and allowing distributed sampling across replicas. 523 | """ 524 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): 525 | # Ensure dataset and configuration are valid 526 | assert len(dataset) > 0 527 | assert num_replicas > 0 528 | assert 0 <= rank < num_replicas 529 | assert 0 <= window_size <= 1 530 | 531 | # Initialize base sampler and store parameters 532 | super().__init__(dataset) 533 | self.dataset = dataset 534 | self.rank = rank 535 | self.num_replicas = num_replicas 536 | self.shuffle = shuffle 537 | self.seed = seed 538 | self.window_size = window_size 539 | 540 | def __iter__(self): 541 | # Generate a sequence of indices corresponding to the dataset 542 | order = np.arange(len(self.dataset)) 543 | 544 | # Initialize random number generator and window size for shuffling 545 | rnd = None 546 | window = 0 547 | if self.shuffle: 548 | # Shuffle the dataset indices 549 | rnd = np.random.RandomState(self.seed) 550 | rnd.shuffle(order) 551 | window = int(np.rint(order.size * self.window_size)) 552 | 553 | # Start iterating over the dataset 554 | idx = 0 555 | while True: 556 | i = idx % order.size 557 | if idx % self.num_replicas == self.rank: 558 | yield order[i] 559 | if window >= 2: 560 | j = (i - rnd.randint(window)) % order.size 561 | order[i], order[j] = order[j], order[i] 562 | idx += 1 563 | 564 | 565 | def load_model(args, model, logger=None): 566 | if logger is None: 567 | log_ = print 568 | else: 569 | log_ = logger.log 570 | 571 | if args.load_path is not None: 572 | log_(f'Load model from {args.load_path}') 573 | checkpoint = torch.load(args.load_path, weights_only=True) 574 | 575 | not_loaded = model.load_state_dict(checkpoint) 576 | print(not_loaded) 577 | 578 | if os.path.exists(args.load_path[:-5] + 'lr'): # Meta-SGD 579 | log_(f'Load lr from {args.load_path[:-5]}lr') 580 | lr = torch.load(args.load_path[:-5] + 'lr') 581 | for (_, param) in lr.items(): 582 | 
param.to(args.device) 583 | args.inner_lr = lr 584 | --------------------------------------------------------------------------------
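The metric helpers above are plain functions with no project-specific state, so here is a minimal usage sketch. It assumes this file is importable as `common.utils` (as the repository tree suggests); all shapes and values are illustrative only.

import torch
import torch.nn.functional as F

from common.utils import psnr, dice_score, ssim_1d  # import path assumed from the repo tree

# PSNR from an MSE value: -10 * log10(mse).
mse = F.mse_loss(torch.rand(4, 1, 64, 64), torch.rand(4, 1, 64, 64))
print(psnr(mse))  # scalar tensor, in dB

# Per-item Dice on binary masks of shape (B, ...); returns a tensor of shape (B,).
pred = (torch.rand(4, 1, 64, 64) > 0.5).float()
target = (torch.rand(4, 1, 64, 64) > 0.5).float()
print(dice_score(pred, target))

# ssim_1d returns a local SSIM map; reduce it for a scalar score.
x, y = torch.rand(8, 187), torch.rand(8, 187)  # 187 matches the ECG config's img_size
print(ssim_1d(x, y).mean())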
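Similarly, a short sketch of the SmoothedValue/MetricLogger pair, which mirrors the torchvision reference implementation: meters are created lazily on first `update`, and `__str__` renders each meter via its format string. The meter names below are illustrative, not taken from the training code.

import torch
from common.utils import MetricLogger, SmoothedValue  # import path assumed from the repo tree

metrics = MetricLogger(delimiter="  ")
# Optional: pre-register a meter with a custom window/format; otherwise
# update() creates a default SmoothedValue (window_size=20) on first use.
metrics.add_meter('loss', SmoothedValue(window_size=50, fmt='{median:.4f} ({global_avg:.4f})'))

for step in range(100):
    fake_loss = torch.rand(())                   # stand-in for a real training loss
    metrics.update(loss=fake_loss, psnr=30.0 + 0.01 * step)

print(str(metrics))                              # e.g. "loss: 0.4821 (0.5013)  psnr: 30.99 (30.50)"
print(metrics.loss.median, metrics.psnr.global_avg)  # meters are attribute-accessible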
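Finally, InfiniteSampler exists because training here is driven by an iteration budget (`max_iter` in the configs) rather than epochs; wrapped in a DataLoader it yields batches indefinitely. A hedged sketch, again assuming the `common.utils` import path:

import torch
from torch.utils.data import DataLoader, TensorDataset
from common.utils import InfiniteSampler

dataset = TensorDataset(torch.rand(100, 1, 64, 64))            # toy stand-in dataset
sampler = InfiniteSampler(dataset, rank=0, num_replicas=1, shuffle=True)
loader = iter(DataLoader(dataset, sampler=sampler, batch_size=24))

batch = next(loader)   # can be called forever; no StopIteration at epoch boundaries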