├── .gitignore
├── .gitmodules
├── README.md
├── __init__.py
├── assets
│   ├── process_sde.png
│   └── sampling_sdf.gif
├── cfg
│   ├── __init__.py
│   ├── glas.yaml
│   └── monuseg.yaml
├── datasets
│   ├── __init__.py
│   ├── config_dl.py
│   ├── glas_dataset.py
│   ├── monuseg_dataset.py
│   └── transform_factory.py
├── env.yml
├── main.py
├── models
│   ├── __init__.py
│   └── ddpm.py
├── preprocess_data
│   └── precompute_sdf.ipynb
├── sampler.py
└── trainer.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.directory
2 | *pycache*

--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "SimulationHelper"]
2 | 	path = SimulationHelper
3 | 	url = https://github.com/f-ilic/SimulationHelper.git
4 |

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | **Score-Based Generative Models for Medical Image Segmentation using Signed Distance Functions**
3 | GCPR 2023
4 | Lea Bogensperger, Dominik Narnhofer, Filip Ilic, Thomas Pock
5 |
6 | ---
7 |
8 | [[Project Page]](https://github.com/leabogensperger/generative-segmentation-sdf)
9 | [[Paper]](https://arxiv.org/abs/2303.05966)
10 |
11 |
12 | Environment Setup:
13 | ```bash
14 | git clone --recurse-submodules git@github.com:leabogensperger/generative-segmentation-sdf.git
15 | conda env create -f env.yml
16 | conda activate generative_segmentation_sdf
17 | ```
18 |
19 | # Score-Based Generative Models for Medical Image Segmentation using Signed Distance Functions
20 |
21 | This repository contains the code to train a generative model that learns the conditional distribution of implicit segmentation masks, represented as signed distance functions (SDFs), given a specific input image. The generative model is a score-based diffusion model; the GCPR paper uses a variance-exploding (VE) scheme. Later experiments indicated that the variance-preserving (VP) scheme is numerically somewhat more stable in this setting, so it is now also supported: set the parameter *sde* in the *SMLD* section of the config file to either *ve* or *vp*.
22 |
23 | [Figure: assets/process_sde.png]
24 |
25 | # Instructions
26 |
27 | 1) Train by specifying a config file:
28 | ```bash
29 | python main.py --config "cfg/monuseg.yaml"
30 | ```
31 |
32 | 2) Sample (set the experiment folder via *load_exp* in the *inference* section of the config file):
33 | ```bash
34 | python sampler.py --config "cfg/monuseg.yaml"
35 | ```
36 |
37 | Note: the pre-processed data sets will be uploaded later. The data set is selected in the config file, which also sets the root directory (*data_path*). This directory must contain CSV files for train and test mode (*csv_train*, *csv_test*) that list all pre-processed patches in the columns *filename* and *maskname*; when training on SDFs (*corr_mode: diffusion_ls*), the column *maskdtname* with the path to the SDF file is required as well. Moreover, the root directory must contain the folders *Training_patches* and *Test_patches*, which hold, for each patch, a .png file of the input image and a .npy file of the SDF-transformed segmentation mask.
38 |
39 | # Sampling
40 |
41 | The sampling process of the proposed approach is illustrated with the predictor-corrector sampling algorithm (see Algorithm 1 in the paper).
42 | The top row shows four different conditioning images, and the center row shows the corresponding generated (predicted) SDF masks.
43 | The bottom row displays the resulting binary masks, which are obtained only indirectly by thresholding the predicted SDF masks.
44 |
45 |
46 |
47 | # Cite
48 |
49 | ```bibtex
50 | @misc{
51 | bogensperger2023scorebased,
52 | title={Score-Based Generative Models for Medical Image Segmentation using Signed Distance Functions},
53 | author={Lea Bogensperger and Dominik Narnhofer and Filip Ilic and Thomas Pock},
54 | year={2023},
55 | eprint={2303.05966},
56 | archivePrefix={arXiv},
57 | primaryClass={cs.CV}
58 | }
59 |

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leabogensperger/generative-segmentation-sdf/3b8037e3e03d347a41f361f6ba844e6928240200/__init__.py

--------------------------------------------------------------------------------
/assets/process_sde.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leabogensperger/generative-segmentation-sdf/3b8037e3e03d347a41f361f6ba844e6928240200/assets/process_sde.png

--------------------------------------------------------------------------------
/assets/sampling_sdf.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leabogensperger/generative-segmentation-sdf/3b8037e3e03d347a41f361f6ba844e6928240200/assets/sampling_sdf.gif

--------------------------------------------------------------------------------
/cfg/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leabogensperger/generative-segmentation-sdf/3b8037e3e03d347a41f361f6ba844e6928240200/cfg/__init__.py
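The README states that every patch comes with a .npy file holding the SDF-transformed segmentation mask (produced by `preprocess_data/precompute_sdf.ipynb`, whose contents are not part of this dump), and that binary masks are recovered by thresholding predicted SDFs. The sketch below illustrates that round trip under stated assumptions; it is not the repository's preprocessing code. It assumes SciPy's `scipy.ndimage.distance_transform_edt`, a sign convention of negative inside and positive outside the object, and truncation at ±5 pixels. The truncation value is only a guess motivated by the `ret['image'] /= 5.` that both dataset classes apply when the data path contains `trunc`; the helper names `binary_to_sdf` and `sdf_to_binary` are hypothetical.

```python
import numpy as np
from scipy.ndimage import distance_transform_edt


def binary_to_sdf(mask: np.ndarray, trunc: float = 5.0) -> np.ndarray:
    """Truncated signed distance function of a binary mask (1 = foreground).

    Assumed convention (hypothetical, not taken from precompute_sdf.ipynb):
    negative inside the object, positive outside, measured in pixels and
    clipped to [-trunc, trunc].
    """
    fg = mask.astype(bool)
    # For background pixels: Euclidean distance to the nearest foreground pixel.
    dist_outside = distance_transform_edt(~fg)
    # For foreground pixels: Euclidean distance to the nearest background pixel.
    dist_inside = distance_transform_edt(fg)
    sdf = dist_outside - dist_inside
    return np.clip(sdf, -trunc, trunc).astype(np.float32)


def sdf_to_binary(sdf: np.ndarray, level: float = 0.0) -> np.ndarray:
    """Binary mask from an SDF by thresholding at the given level set."""
    return (sdf < level).astype(np.uint8)


if __name__ == "__main__":
    mask = np.zeros((16, 16), dtype=np.uint8)
    mask[4:12, 4:12] = 1              # an 8x8 square "object"
    sdf = binary_to_sdf(mask)         # assumed on-disk format, in pixels
    model_target = sdf / 5.0          # mirrors `ret['image'] /= 5.` for 'trunc' data
    print(model_target.min(), model_target.max())
    print(np.array_equal(sdf_to_binary(model_target), mask))
```

Under this convention, thresholding at the zero level set recovers the input mask exactly, because truncation and rescaling change only magnitudes, never signs; an SDF predicted in [-1, 1] can therefore be turned into a binary mask with the same threshold.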
--------------------------------------------------------------------------------
/cfg/glas.yaml:
--------------------------------------------------------------------------------
1 | general:
2 |   modality: glas
3 |   corr_mode: diffusion_ls # diffusion on level sets
4 |   img_cond: 1 # condition on image to obtain segmentation mask
5 |   data_path: "/home/lea/Data/GlaS_trunc"
6 |   csv_train: "train.csv"
7 |   csv_test: "test.csv"
8 |   batch_size: 32
9 |   sz: 128
10 |   resume_training: False
11 |   load_path: ''
12 |   class_label_cond: False
13 |   num_classes: 0
14 |   with_class_label_emb: False
15 |
16 | inference:
17 |   latest: False
18 |   load_exp: ''
19 |   n_samples: 4
20 |
21 | model:
22 |   type: 'unet'
23 |   n_cin: 1 # 1, n_classes
24 |   n_cin_cond: 1 # 1, 3 for gray/color valued input
25 |   n_fm: 10
26 |   dim: 128
27 |   embedding: 'sinusoidal'
28 |   mults:
29 |     - 1
30 |     - 2
31 |     - 4
32 |     - 4
33 |
34 | learning:
35 |   epochs: 500000
36 |   lr: 1.0E-4
37 |   loss: 2
38 |   n_val: 8
39 |   clip: 100000.
40 |   gpus:
41 |     - 1
42 |
43 | SMLD:
44 |   sde: 'vp' # VE used in GCPR paper, VP scheme is like classic DDPM
45 |   beta_1: 1.E-4 # default from DDPM
46 |   beta_T: 0.02 # default from DDPM
47 |   T: 1000 # default from DDPM
48 |   n_steps: 100 # VE params
49 |   sigma_1_m: 5. # VE params: heuristic
50 |   sigma_L_m: 0.001 # VE params: heuristic
51 |   objective: 'cont'
52 |   sampler: 'pc' # for VE scheme with reverse SDE
53 |   eps: 2.0E-5 # annealed Langevin
54 |   N: 200 # Predictor steps
55 |   M: 1 # Corrector steps
56 |   r: 0.15 # "snr" for PC sampling

--------------------------------------------------------------------------------
/cfg/monuseg.yaml:
--------------------------------------------------------------------------------
1 | general:
2 |   modality: monuseg
3 |   corr_mode: diffusion_ls # diffusion on level sets
4 |   img_cond: 1 # condition on image to obtain segmentation mask
5 |   data_path: "/home/lea/Data/MonuSeg_spcn_trunc"
6 |   csv_train: "train.csv"
7 |   csv_test: "test.csv"
8 |   batch_size: 32
9 |   sz: 128 # 128, 256
10 |   resume_training: False
11 |   load_path: ''
12 |   class_label_cond: False
13 |   num_classes: 0
14 |   with_class_label_emb: False
15 |
16 | inference:
17 |   latest: False
18 |   load_exp: ''
19 |   n_samples: 4
20 |
21 | model:
22 |   type: 'unet'
23 |   n_cin: 1 # 1, n_classes
24 |   n_cin_cond: 1 # 1, 3 for gray/color valued input
25 |   n_fm: 10
26 |   dim: 128
27 |   embedding: 'sinusoidal'
28 |   mults:
29 |     - 1
30 |     - 2
31 |     - 4
32 |     - 4
33 |
34 | learning:
35 |   epochs: 500000
36 |   lr: 1.0E-4
37 |   loss: 2
38 |   n_val: 8
39 |   clip: 40000.
40 |   gpus:
41 |     - 1
42 |
43 | SMLD:
44 |   sde: 'vp' # VE used in GCPR paper, VP scheme is like classic DDPM
45 |   beta_1: 1.E-4 # default from DDPM
46 |   beta_T: 0.02 # default from DDPM
47 |   T: 1000 # default from DDPM
48 |   n_steps: 100 # VE params
49 |   sigma_1_m: 5. # VE params: heuristic
50 |   sigma_L_m: 0.001 # VE params: heuristic
51 |   objective: 'cont'
52 |   sampler: 'pc' # for VE scheme with reverse SDE
53 |   eps: 2.0E-5 # annealed Langevin
54 |   N: 200 # Predictor steps
55 |   M: 1 # Corrector steps
56 |   r: 0.15 # "snr" for PC sampling
57 |

--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leabogensperger/generative-segmentation-sdf/3b8037e3e03d347a41f361f6ba844e6928240200/datasets/__init__.py

--------------------------------------------------------------------------------
/datasets/config_dl.py:
--------------------------------------------------------------------------------
1 | from types import SimpleNamespace
2 | from torch.utils.data.dataloader import DataLoader
3 | import yaml
4 | import json
5 | from os.path import join, isfile
6 |
7 | from datasets.transform_factory import inv_normalize, transform_factory
8 | from datasets.monuseg_dataset import MoNuSegDataset
9 | from datasets.glas_dataset import GlaSDataset
10 |
11 |
12 | def config_dl(cfg):
13 |     if cfg.general.modality == 'monuseg':
14 |         DatasetType = MoNuSegDataset
15 |         stats = {'mean': 0., 'std': 1.}
16 |
17 |     elif cfg.general.modality == 'glas':
18 |         DatasetType = GlaSDataset
19 |         stats = {'mean': 0., 'std': 1.}
20 |
21 |     else:
22 |         raise ValueError('Unknown modality %s specified!' % cfg.general.modality)
23 |
24 |     train_dataset = DatasetType(cfg.general.data_path, f'{cfg.general.data_path}/{cfg.general.csv_train}', cfg=cfg.general)
25 |     test_dataset = DatasetType(cfg.general.data_path, f'{cfg.general.data_path}/{cfg.general.csv_test}', cfg=cfg.general)
26 |
27 |     train_dataloader = DataLoader(train_dataset, batch_size=cfg.general.batch_size, shuffle=True, drop_last=True)
28 |     test_dataloader = DataLoader(test_dataset, batch_size=cfg.inference.n_samples, shuffle=False, drop_last=False)
29 |
30 |     dbdict = {"train_dl": train_dataloader, "test_dl": test_dataloader}
31 |
32 |     train_dl, test_dl = dbdict["train_dl"], dbdict["test_dl"]
33 |     tfdict = transform_factory(cfg.general)
34 |     T_train, T_test = tfdict["train"](stats["mean"], stats["std"]), tfdict["test"](
35 |         stats["mean"], stats["std"]
36 |     )
37 |     train_dl.dataset.transform, test_dl.dataset.transform = T_train, T_test
38 |     train_dl.inv_normalize, test_dl.inv_normalize = inv_normalize(
39 |         stats["mean"], stats["std"]
40 |     ), inv_normalize(stats["mean"], stats["std"])
41 |
42 |     return train_dl, test_dl
43 |

--------------------------------------------------------------------------------
/datasets/glas_dataset.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageOps
2 | import numpy as np
3 | import cv2
4 | import pandas as pd
5 | import matplotlib.pyplot as plt
6 | from types import SimpleNamespace
7 |
8 | import torch
9 | from torch.utils.data import Dataset, DataLoader
10 | from datasets.transform_factory import transform_factory
11 |
12 | class GlaSDataset(Dataset):
13 |     def __init__(self, data_path, csv_file, cfg):
14 |         self.data_path = data_path
15 |         self.csv_file = csv_file
16 |         self.data = pd.read_csv(self.csv_file)
17 |
18 |         self.transform = None
19 |         self.inv_normalize = None
20 |
21 |         self.corr_mode = cfg.corr_mode
22 |         self.img_cond = cfg.img_cond
23 |         self.sz = cfg.sz
24 |
25 |     def __len__(self):
26 |         return len(self.data)
27 |
28 |     def __getitem__(self, idx):
29 |         img_path = self.data_path + self.data.loc[idx]['filename']
30 |         mask_path = self.data_path + self.data.loc[idx]['maskname']
31 |
32 |         # load grayscale input image, scaled to [0, 1]
33 |         img = cv2.imread(img_path,0).astype(np.float32)/255.
34 |
35 |         # load level sets mask
36 |         if self.corr_mode == 'diffusion_ls':
37 |             mask_ls_path = self.data_path + self.data.loc[idx]['maskdtname']
38 |             mask = np.load(mask_ls_path).astype(np.float32)
39 |         else:
40 |             mask = cv2.imread(mask_path,0).astype(np.float32)
41 |             mask = mask/255.
42 |
43 |         if self.corr_mode == 'diffusion':
44 |             corr_type = 0
45 |         else:
46 |             corr_type = 1
47 |
48 |         transform_cfg = {
49 |             'hflip': np.random.rand(),
50 |             'vflip': np.random.rand(),
51 |             'corr_type': corr_type, # 0 is diffusion
52 |             'img_cond': self.img_cond,
53 |         }
54 |
55 |         if self.img_cond: # condition on image
56 |             ret = {'image': mask, 'mask': img, 'name': str(img_path.split('/')[-1][:-4])}
57 |         else:
58 |             ret = {'image': img, 'mask': mask, 'name': str(img_path.split('/')[-1][:-4])}
59 |
60 |         self.transform(transform_cfg)(ret)
61 |
62 |         if self.img_cond and self.corr_mode == 'diffusion_ls':
63 |             if 'trunc' in self.data_path:
64 |                 ret['image'] /= 5.
65 |             else:
66 |                 ret['image'] *= 1.6
67 |         return ret

--------------------------------------------------------------------------------
/datasets/monuseg_dataset.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageOps
2 | import numpy as np
3 | import cv2
4 | import pandas as pd
5 | import matplotlib.pyplot as plt
6 | from types import SimpleNamespace
7 |
8 | import torch
9 | from torch.utils.data import Dataset, DataLoader
10 | from datasets.transform_factory import transform_factory
11 |
12 | class MoNuSegDataset(Dataset):
13 |     def __init__(self, data_path, csv_file, cfg):
14 |         self.data_path = data_path
15 |         self.csv_file = csv_file
16 |         self.data = pd.read_csv(self.csv_file)
17 |
18 |         self.transform = None
19 |         self.inv_normalize = None
20 |
21 |         self.corr_mode = cfg.corr_mode
22 |         self.img_cond = cfg.img_cond
23 |         self.sz = cfg.sz
24 |
25 |     def __len__(self):
26 |         return len(self.data)
27 |
28 |     def __getitem__(self, idx):
29 |         img_path = self.data_path + self.data.loc[idx]['filename']
30 |         mask_path = self.data_path + self.data.loc[idx]['maskname']
31 |
32 |         # load input image (color if 'rgb' is in the data path, else grayscale), scaled to [0, 1]
33 |         if 'rgb' in self.data_path:
34 |             img = cv2.imread(img_path).astype(np.float32)/255.
35 |         else:
36 |             img = cv2.imread(img_path,0).astype(np.float32)/255.
37 |
38 |         # load level sets mask
39 |         if self.corr_mode == 'diffusion_ls':
40 |             mask_ls_path = self.data_path + self.data.loc[idx]['maskdtname']
41 |             mask = np.load(mask_ls_path).astype(np.float32)
42 |         else:
43 |             mask = cv2.imread(mask_path,0).astype(np.float32)
44 |             mask[mask > 200] = 255.
45 |             mask[mask <= 200] = 0.
46 |             # mask to [0,1]
47 |             mask = mask/255.
48 |
49 |         if self.corr_mode == 'diffusion':
50 |             corr_type = 0
51 |         else:
52 |             corr_type = 1
53 |
54 |         transform_cfg = {
55 |             'hflip': np.random.rand(),
56 |             'vflip': np.random.rand(),
57 |             'corr_type': corr_type, # 0 is diffusion
58 |             'img_cond': self.img_cond,
59 |         }
60 |
61 |         if self.img_cond: # condition on image
62 |             ret = {'image': mask, 'mask': img, 'name': str(img_path.split('/')[-1][:-4])}
63 |         else:
64 |             ret = {'image': img, 'mask': mask, 'name': str(img_path.split('/')[-1][:-4])}
65 |
66 |         self.transform(transform_cfg)(ret)
67 |
68 |         if self.img_cond and self.corr_mode == 'diffusion_ls':
69 |             if 'trunc' in self.data_path:
70 |                 ret['image'] /= 5.
71 |             else:
72 |                 ret['image'] *= 10./3. # heuristic from intensity distribution of histogram
73 |
74 |         return ret
75 |

--------------------------------------------------------------------------------
/datasets/transform_factory.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Dict, List, Optional, Tuple
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from torchvision.transforms import Compose, Lambda, transforms, InterpolationMode, CenterCrop
7 | from torchvision.transforms.functional import crop, hflip, vflip, equalize
8 |
9 | class ApplyTransformToKey:
10 |     def __init__(self, key: str, transform: Callable):
11 |         self._key = key
12 |         self._transform = transform
13 |
14 |     def __call__(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
15 |         x[self._key] = self._transform(x[self._key])
16 |         return x
17 |
18 | # ----------------------------------------------------------------
19 |
20 |
21 | def train_glas_transform(mean, std):
22 |     def configured_transform(transform_config):
23 |         p_hflip = transform_config['hflip']
24 |         p_vflip = transform_config['vflip']
25 |         corr_type = transform_config['corr_type']
26 |         img_cond = transform_config['img_cond']
27 |
28 |         def hflip_closure(img):
29 |             return hflip(img) if p_hflip > 0.5 else img
30 |
31 |         def vflip_closure(img):
32 |             return vflip(img) if p_vflip > 0.5 else img
33 |
34 |         def normalization_mask(mask):
35 |             return (mask-0.5)*2 # normalize all to be in [-1,1] for guidance image
36 |
37 |         return Compose([
38 |             ApplyTransformToKey(
39 |                 key="image",
40 |                 transform=Compose([
41 |                     transforms.ToTensor(),
42 |                     hflip_closure,
43 |                     vflip_closure,
44 |                 ]),
45 |             ),
46 |
47 |             ApplyTransformToKey(
48 |                 key="mask",
49 |                 transform=Compose([
50 |                     transforms.ToTensor(),
51 |                     normalization_mask,
52 |                     hflip_closure,
53 |                     vflip_closure,
54 |                 ]),
55 |             ),
56 |         ])
57 |     return configured_transform
58 |
59 | def test_glas_transform(mean, std):
60 |     def configured_transform(transform_config):
61 |         def normalization_mask(mask):
62 |             return (mask-0.5)*2 # normalize all to be in [-1,1] for guidance image
63 |
64 |         return Compose([
65 |             ApplyTransformToKey(
66 |                 key="image",
67 |                 transform=Compose([
68 |                     transforms.ToTensor(),
69 |                 ]),
70 |             ),
71 |
72 |             ApplyTransformToKey(
73 |                 key="mask",
74 |                 transform=Compose([
75 |                     transforms.ToTensor(),
76 |                     normalization_mask,
77 |                 ]),
78 |             ),
79 |         ])
80 |     return configured_transform
81 |
82 | def train_monuseg_transform(mean, std):
83 |     def configured_transform(transform_config):
84 |         p_hflip = transform_config['hflip']
85 |         p_vflip = transform_config['vflip']
86 |         corr_type = transform_config['corr_type']
87 |         img_cond = transform_config['img_cond']
88 |
89 |         def normalization(img):
90 |             return (img-0.5)*2 if corr_type == 0 else img
91 |
92 |         def hflip_closure(img):
93 |             return hflip(img) if p_hflip > 0.5 else img
94 |
95 |         def vflip_closure(img):
96 |             return vflip(img) if p_vflip > 0.5 else img
97 |
98 |         def normalization(img):
99 |             return img # placeholder for different normalization procedure
100 |
101 |         def normalization_mask(mask):
102 |             return (mask-0.5)*2 # normalize all to be in [-1,1] for guidance image
103 |
104 |         interp_mode_img = InterpolationMode.NEAREST if img_cond == 1 else InterpolationMode.BILINEAR
105 |         interp_mode_mask = InterpolationMode.BILINEAR if img_cond == 1 else InterpolationMode.NEAREST
106 |
107 |         return Compose([
108 |             ApplyTransformToKey(
109 |                 key="image",
110 |                 transform=Compose([
111 |                     transforms.ToTensor(),
112 |                     hflip_closure,
113 |                     vflip_closure,
114 |                 ]),
115 |             ),
116 |
117 |             ApplyTransformToKey(
118 |                 key="mask",
119 |                 transform=Compose([
120 |                     transforms.ToTensor(),
121 |                     normalization_mask,
122 |                     hflip_closure,
123 |                     vflip_closure,
124 |                 ]),
125 |             ),
126 |         ])
127 |     return configured_transform
128 |
129 | def test_monuseg_transform(mean, std):
130 |     def configured_transform(transform_config):
131 |         p_hflip = transform_config['hflip']
132 |         p_vflip = transform_config['vflip']
133 |         corr_type = transform_config['corr_type']
134 |         img_cond = transform_config['img_cond']
135 |
136 |         def normalization(img):
137 |             return (img-0.5)*2 if corr_type == 0 else img
138 |
139 |         def normalization(img):
140 |             return img # placeholder for different normalization procedure
141 |
142 |         def normalization_mask(mask):
143 |             return (mask-0.5)*2 # normalize all to be in [-1,1] for guidance image
144 |
145 |         interp_mode_img = InterpolationMode.NEAREST if img_cond == 1 else InterpolationMode.BILINEAR
146 |         interp_mode_mask = InterpolationMode.BILINEAR if img_cond == 1 else InterpolationMode.NEAREST
147 |
148 |         return Compose([
149 |             ApplyTransformToKey(
150 |                 key="image",
151 |                 transform=Compose([
152 |                     transforms.ToTensor(),
153 |                 ]),
154 |             ),
155 |
156 |             ApplyTransformToKey(
157 |                 key="mask",
158 |                 transform=Compose([
159 |                     transforms.ToTensor(),
160 |                     normalization_mask,
161 |                 ]),
162 |             ),
163 |         ])
164 |     return configured_transform
165 |
166 | # ----------------------------------------------------------------
167 | def inv_normalize(mean, std):
168 |     return transforms.Normalize(mean=-mean/std, std=1/std)
169 |
170 | def transform_factory(cfg):
171 |     if cfg.modality == 'monuseg':
172 |         ret = {
173 |             'train': train_monuseg_transform,
174 |             'test' : test_monuseg_transform
175 |         }
176 |
177 |     elif cfg.modality == 'glas':
178 |         ret = {
179 |             'train': train_glas_transform,
180 |             'test': test_glas_transform
181 |         }
182 |
183 |     else:
184 |         raise ValueError('Unknown modality %s specified!' %cfg.modality)
185 |
186 |     return ret
187 |

--------------------------------------------------------------------------------
/env.yml:
--------------------------------------------------------------------------------
1 | name: generative_segmentation_sdf
2 | channels:
3 |   - anaconda
4 |   - pytorch
5 |   - nvidia
6 |   - conda-forge
7 |   - defaults
8 | dependencies:
9 |   - _libgcc_mutex=0.1=conda_forge
10 |   - _openmp_mutex=4.5=1_gnu
11 |   - _tflow_select=2.3.0=mkl
12 |   - absl-py=0.13.0=pyhd8ed1ab_0
13 |   - aiohttp=3.7.4.post0=py38h497a2fe_0
14 |   - appdirs=1.4.4=pyh9f0ad1d_0
15 |   - astor=0.8.1=pyh9f0ad1d_0
16 |   - astunparse=1.6.3=pyhd8ed1ab_0
17 |   - async-timeout=3.0.1=py_1000
18 |   - attrs=21.2.0=pyhd8ed1ab_0
19 |   - autopep8=1.5.7=pyhd8ed1ab_0
20 |   - blas=1.0=mkl
21 |   - blinker=1.4=py_1
22 |   - brotlipy=0.7.0=py38h497a2fe_1001
23 |   - bzip2=1.0.8=h7f98852_4
24 |   - c-ares=1.17.1=h7f98852_1
25 |   - ca-certificates=2022.9.24=ha878542_0
26 |   - cachetools=4.2.2=pyhd8ed1ab_0
27 |   - certifi=2022.9.24=pyhd8ed1ab_0
28 |   - cffi=1.14.6=py38ha65f79e_0
29 |   - chardet=3.0.4=py38h924ce5b_1008
30 |   - click=8.0.1=py38h578d9bd_0
31 |   - cloudpickle=2.0.0=pyhd8ed1ab_0
32 |   - coverage=5.5=py38h497a2fe_0
33 |   - cryptography=3.4.7=py38ha5dfef3_0
34 |   - cudatoolkit=11.1.74=h6bb024c_0
35 |   - cycler=0.10.0=py_2
36 |   - cython=0.29.24=py38h709712a_0
37 |   - cytoolz=0.11.0=py38h497a2fe_3
38 |   - dask-core=2021.8.1=pyhd8ed1ab_0
39 |   - dbus=1.13.18=hb2f20db_0
40 |   - decorator=4.4.2=py_0
41 |   - dominate=2.6.0=pyhd8ed1ab_0
42 |   - einops=0.4.1=pyhd8ed1ab_0
43 |   - expat=2.4.1=h9c3ff4c_0
44 |   - ffmpeg=4.3=hf484d3e_0
45 |   - fontconfig=2.13.1=h6c09931_0
46 |   - freetype=2.10.4=h0708190_1
47 |   - fsspec=2021.8.1=pyhd8ed1ab_0
48 |   - gast=0.4.0=pyh9f0ad1d_0
49 |   - gettext=0.19.8.1=h0b5b191_1005
gettext=0.19.8.1=h0b5b191_1005 50 | - glib=2.69.0=h5202010_0 51 | - gmp=6.2.1=h58526e2_0 52 | - gnutls=3.6.15=he1e5248_0 53 | - google-auth=1.33.0=pyh6c4a22f_0 54 | - google-auth-oauthlib=0.4.4=pyhd8ed1ab_0 55 | - google-pasta=0.2.0=pyh8c360ce_0 56 | - grpcio=1.36.1=py38hdd6454d_0 57 | - gst-plugins-base=1.14.0=h8213a91_2 58 | - gstreamer=1.14.0=h28cd5cc_2 59 | - h5py=2.10.0=nompi_py38h9915d05_106 60 | - hdf5=1.10.6=nompi_h7c3c948_1111 61 | - icu=58.2=hf484d3e_1000 62 | - idna=2.10=pyh9f0ad1d_0 63 | - imageio=2.9.0=py_0 64 | - importlib-metadata=3.10.0=py38h578d9bd_0 65 | - intel-openmp=2021.3.0=h06a4308_3350 66 | - joblib=1.0.1=pyhd8ed1ab_0 67 | - jpeg=9b=h024ee3a_2 68 | - keras-preprocessing=1.1.2=pyhd8ed1ab_0 69 | - kiwisolver=1.3.1=py38h1fd1430_1 70 | - krb5=1.19.2=hcc1bbae_0 71 | - lame=3.100=h7f98852_1001 72 | - lcms2=2.12=h3be6417_0 73 | - ld_impl_linux-64=2.35.1=hea4e1c9_2 74 | - libblas=3.9.0=11_linux64_mkl 75 | - libcblas=3.9.0=11_linux64_mkl 76 | - libcurl=7.78.0=h2574ce0_0 77 | - libedit=3.1.20191231=he28a2e2_2 78 | - libev=4.33=h516909a_1 79 | - libffi=3.3=h58526e2_2 80 | - libgcc-ng=9.3.0=h2828fa1_19 81 | - libgfortran-ng=7.5.0=h14aa051_19 82 | - libgfortran4=7.5.0=h14aa051_19 83 | - libgomp=9.3.0=h2828fa1_19 84 | - libiconv=1.15=h516909a_1006 85 | - libidn2=2.3.2=h7f98852_0 86 | - libllvm10=10.0.1=he513fc3_3 87 | - libnghttp2=1.43.0=h812cca2_0 88 | - libpng=1.6.37=h21135ba_2 89 | - libprotobuf=3.17.2=h780b84a_1 90 | - libssh2=1.9.0=ha56f1ee_6 91 | - libstdcxx-ng=9.3.0=h6de172a_19 92 | - libtasn1=4.16.0=h27cfd23_0 93 | - libtiff=4.2.0=h85742a9_0 94 | - libunistring=0.9.10=h7f98852_0 95 | - libuuid=1.0.3=h7f8727e_2 96 | - libuv=1.40.0=h7f98852_0 97 | - libwebp-base=1.2.0=h7f98852_2 98 | - libxcb=1.14=h7b6447c_0 99 | - libxml2=2.9.12=h03d6c58_0 100 | - libxslt=1.1.34=hc22bd24_0 101 | - libzlib=1.2.11=h36c2ea0_1013 102 | - llvmlite=0.36.0=py38h4630a5e_0 103 | - locket=0.2.1=py38h06a4308_1 104 | - lxml=4.8.0=py38h1f438cf_0 105 | - lz4-c=1.9.3=h9c3ff4c_1 
106 |   - markdown=3.3.4=pyhd8ed1ab_0
107 |   - matplotlib=3.3.4=py38h578d9bd_0
108 |   - matplotlib-base=3.3.4=py38h0efea84_0
109 |   - mkl=2021.3.0=h06a4308_520
110 |   - mkl-service=2.4.0=py38h497a2fe_0
111 |   - mkl_fft=1.3.0=py38h42c9631_2
112 |   - mkl_random=1.2.2=py38h1abd341_0
113 |   - multidict=5.1.0=py38h497a2fe_1
114 |   - mypy=0.910=py38h497a2fe_0
115 |   - mypy_extensions=0.4.3=py38h578d9bd_4
116 |   - ncurses=6.2=h58526e2_4
117 |   - nettle=3.7.3=hbbd107a_1
118 |   - networkx=2.6.3=pyhd8ed1ab_1
119 |   - ninja=1.10.2=h4bd325d_0
120 |   - numba=0.53.1=py38h8b71fd7_1
121 |   - numpy=1.20.3=py38hf144106_0
122 |   - numpy-base=1.20.3=py38h74d4b33_0
123 |   - oauthlib=3.1.1=pyhd8ed1ab_0
124 |   - olefile=0.46=pyh9f0ad1d_1
125 |   - openh264=2.1.0=hd408876_0
126 |   - openjpeg=2.3.0=hf38bd82_1003
127 |   - openssl=1.1.1q=h7f8727e_0
128 |   - opt_einsum=3.3.0=pyhd8ed1ab_1
129 |   - packaging=21.0=pyhd8ed1ab_0
130 |   - pandas=1.3.1=py38h1abd341_0
131 |   - partd=1.2.0=pyhd8ed1ab_0
132 |   - pcre=8.45=h9c3ff4c_0
133 |   - pillow=8.3.1=py38h2c7a002_0
134 |   - pip=21.1.3=pyhd8ed1ab_0
135 |   - pooch=1.5.2=pyhd8ed1ab_0
136 |   - protobuf=3.17.2=py38h709712a_0
137 |   - psutil=5.8.0=py38h497a2fe_1
138 |   - pyasn1=0.4.8=py_0
139 |   - pyasn1-modules=0.2.8=py_0
140 |   - pycodestyle=2.7.0=pyhd8ed1ab_0
141 |   - pycparser=2.20=pyh9f0ad1d_2
142 |   - pyjwt=2.1.0=pyhd8ed1ab_0
143 |   - pyopenssl=20.0.1=pyhd8ed1ab_0
144 |   - pyparsing=2.4.7=pyhd8ed1ab_1
145 |   - pyqt=5.9.2=py38h05f1152_4
146 |   - pysocks=1.7.1=py38h578d9bd_4
147 |   - python=3.8.10=h49503c6_1_cpython
148 |   - python-dateutil=2.8.2=pyhd8ed1ab_0
149 |   - python-flatbuffers=1.12=pyhd8ed1ab_1
150 |   - python_abi=3.8=2_cp38
151 |   - pytorch=1.9.0=py3.8_cuda11.1_cudnn8.0.5_0
152 |   - pytz=2021.3=pyhd8ed1ab_0
153 |   - pyu2f=0.1.5=pyhd8ed1ab_0
154 |   - pywavelets=1.1.1=py38h5c078b8_3
155 |   - pyyaml=5.4.1=py38h497a2fe_0
156 |   - qt=5.9.7=h5867ecd_1
157 |   - readline=8.1=h46c0cb4_0
158 |   - requests=2.25.1=pyhd3deb0d_0
159 |   - requests-oauthlib=1.3.0=pyh9f0ad1d_0
160 |   - rsa=4.7.2=pyh44b312d_0
161 |   - scikit-image=0.18.1=py38h51da96c_0
162 |   - scikit-learn=0.24.2=py38hdc147b9_0
163 |   - scipy=1.6.2=py38had2a1c9_1
164 |   - seaborn=0.11.2=pyhd3eb1b0_0
165 |   - setuptools=52.0.0=py38h06a4308_1
166 |   - sip=4.19.13=py38he6710b0_0
167 |   - six=1.16.0=pyh6c4a22f_0
168 |   - sqlite=3.36.0=h9cd32fc_0
169 |   - tbb=2020.3=hfd86e86_0
170 |   - tensorboard=2.6.0=pyhd8ed1ab_1
171 |   - tensorboard-data-server=0.6.0=py38h2b97feb_0
172 |   - tensorboard-plugin-wit=1.6.0=pyh9f0ad1d_0
173 |   - tensorboardx=2.5.1=pyhd8ed1ab_0
174 |   - tensorflow=2.4.1=mkl_py38hb2083e0_0
175 |   - tensorflow-base=2.4.1=mkl_py38h43e0292_0
176 |   - tensorflow-estimator=2.5.0=pyh81a9013_1
177 |   - termcolor=1.1.0=py_2
178 |   - threadpoolctl=2.2.0=pyh8a188c0_0
179 |   - tifffile=2020.10.1=py38hdd07704_2
180 |   - tk=8.6.10=h21135ba_1
181 |   - toml=0.10.2=pyhd8ed1ab_0
182 |   - toolz=0.11.1=py_0
183 |   - torchaudio=0.9.0=py38
184 |   - torchvision=0.10.0=py38_cu111
185 |   - tornado=6.1=py38h497a2fe_1
186 |   - typing-extensions=3.10.0.0=hd8ed1ab_0
187 |   - typing_extensions=3.10.0.0=pyha770c72_0
188 |   - urllib3=1.26.6=pyhd8ed1ab_0
189 |   - werkzeug=1.0.1=pyh9f0ad1d_0
190 |   - wheel=0.36.2=pyhd3deb0d_0
191 |   - wrapt=1.12.1=py38h497a2fe_3
192 |   - xz=5.2.5=h516909a_1
193 |   - yaml=0.2.5=h516909a_0
194 |   - yarl=1.6.3=py38h497a2fe_2
195 |   - zipp=3.5.0=pyhd8ed1ab_0
196 |   - zlib=1.2.11=h36c2ea0_1013
197 |   - zstd=1.4.9=ha95c52a_0
198 |   - pip:
199 |     - anyio==3.6.1
200 |     - argon2-cffi==21.3.0
201 |     - argon2-cffi-bindings==21.2.0
202 |     - asttokens==2.0.5
203 |     - av==8.0.3
204 |     - babel==2.10.1
205 |     - backcall==0.2.0
206 |     - beautifulsoup4==4.11.1
207 |     - bleach==5.0.0
208 |     - bottle==0.12.19
209 |     - bottle-websocket==0.2.9
210 |     - debugpy==1.6.0
211 |     - defusedxml==0.7.1
212 |     - eel==0.14.0
213 |     - entrypoints==0.4
214 |     - executing==0.8.3
215 |     - fastjsonschema==2.15.3
216 |     - future==0.18.2
217 |     - fvcore==0.1.5.post20211019
218 |     - gevent==21.1.2
219 |     - gevent-websocket==0.10.1
220 |     - greenlet==1.1.0
221 |     - higher==0.2.1
222 |     - imageio-ffmpeg==0.4.5
223 |     - importlib-resources==5.7.1
224 |     - iopath==0.1.9
225 |     - ipykernel==6.13.0
226 |     - ipython==8.3.0
227 |     - ipython-genutils==0.2.0
228 |     - jedi==0.18.1
229 |     - jinja2==3.1.2
230 |     - json5==0.9.8
231 |     - jsonschema==4.5.1
232 |     - jupyter-client==7.3.1
233 |     - jupyter-core==4.10.0
234 |     - jupyter-server==1.17.0
235 |     - jupyterlab==3.4.2
236 |     - jupyterlab-pygments==0.2.2
237 |     - jupyterlab-server==2.13.0
238 |     - markupsafe==2.1.1
239 |     - matplotlib-inline==0.1.3
240 |     - mistune==0.8.4
241 |     - moviepy==1.0.3
242 |     - nbclassic==0.3.7
243 |     - nbclient==0.6.3
244 |     - nbconvert==6.5.0
245 |     - nbformat==5.4.0
246 |     - nest-asyncio==1.5.5
247 |     - notebook==6.4.11
248 |     - notebook-shim==0.1.0
249 |     - opencv-python==4.5.3.56
250 |     - optoth==0.2.0
251 |     - pad2d-op-v1==1.0
252 |     - pandocfilters==1.5.0
253 |     - parameterized==0.8.1
254 |     - parso==0.8.3
255 |     - perlin-noise==1.12
256 |     - pexpect==4.8.0
257 |     - pickleshare==0.7.5
258 |     - portalocker==2.3.2
259 |     - proglog==0.1.9
260 |     - prometheus-client==0.14.1
261 |     - prompt-toolkit==3.0.29
262 |     - ptyprocess==0.7.0
263 |     - pure-eval==0.2.2
264 |     - pygments==2.12.0
265 |     - pyrsistent==0.18.1
266 |     - pytorchvideo==0.1.3
267 |     - pyzmq==22.3.0
268 |     - send2trash==1.8.0
269 |     - sniffio==1.2.0
270 |     - soupsieve==2.3.2.post1
271 |     - stack-data==0.2.0
272 |     - tabulate==0.8.9
273 |     - terminado==0.15.0
274 |     - tinycss2==1.1.1
275 |     - torch-dct==0.1.5
276 |     - torch-tb-profiler==0.3.1
277 |     - torchmetrics==0.11.1
278 |     - tqdm==4.62.0
279 |     - traitlets==5.2.1.post0
280 |     - ttach==0.0.3
281 |     - wcwidth==0.2.5
282 |     - webencodings==0.5.1
283 |     - websocket-client==1.3.2
284 |     - whichcraft==0.6.1
285 |     - yacs==0.1.8
286 |     - zope-event==4.5.0
287 |     - zope-interface==5.4.0
288 | prefix: /opt/python_envs/anaconda3/envs/granules
289 |

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import random
2 | from types import SimpleNamespace
3 | import imageio
4 | import numpy as np
5 | import argparse
6 | import sys
7 | import os
8 | import json
9 | from tqdm.auto import tqdm
10 | import matplotlib.pyplot as plt
11 | import yaml
12 |
13 | import einops
14 | import torch
15 | import torch.nn as nn
16 | from torch.optim import Adam
17 | from torch.utils.data import DataLoader
18 | from torch.utils.tensorboard import SummaryWriter
19 |
20 | from SimulationHelper.simulation import Simulation
21 | from datasets.config_dl import config_dl
22 | from models import ddpm
23 | import trainer
24 |
25 | # Setting reproducibility
26 | SEED = 0
27 | random.seed(SEED)
28 | np.random.seed(SEED)
29 | torch.manual_seed(SEED)
30 |
31 | parser = argparse.ArgumentParser("")
32 | parser.add_argument(
33 |     "--config", default="cfg/monuseg.yaml", type=str, help="path to .yaml config" # glas, monuseg
34 | )
35 | args = parser.parse_args()
36 |
37 | def count_parameters(net):
38 |     return sum(p.numel() for p in net.parameters() if p.requires_grad)
39 |
40 | if __name__ == "__main__":
41 |     # program arguments
42 |     with open(args.config) as file:
43 |         yaml_cfg = yaml.safe_load(file)
44 |         cfg = json.loads(
45 |             json.dumps(yaml_cfg), object_hook=lambda d: SimpleNamespace(**d)
46 |         )
47 |
48 |     device = torch.device("cuda")
49 |     print(f"Using device: {device}\t" + (f"{torch.cuda.get_device_name(0)}"))
50 |
51 |     # set up dataloader, model
52 |     train_dl, test_dl = config_dl(cfg)
53 |     if cfg.model.type == 'unet':
54 |         model = ddpm.Network(
55 |             dim=cfg.model.dim,
56 |             channels=cfg.model.n_cin,
57 |             cond_channels=cfg.model.n_cin_cond,
58 |             init_dim=cfg.model.n_fm,
59 |             dim_mults=tuple(cfg.model.mults),
60 |             embedding=cfg.model.embedding,
61 |             img_cond=cfg.general.img_cond,
62 |             with_class_label_emb=cfg.general.with_class_label_emb,
63 |             class_label_cond=cfg.general.class_label_cond,
64 |             num_classes=cfg.general.num_classes,
65 |         ).to(device)
66 |
67 |     else:
68 |         raise ValueError('Unknown model type!')
69 | 70 | # optimizer 71 | optim = Adam(model.parameters(), cfg.learning.lr) 72 | 73 | # Optionally, load a pre-trained model that will be further trained 74 | if cfg.general.resume_training: 75 | load_path = os.getcwd() + "/runs/" + cfg.general.modality 76 | load_path += "/" + cfg.general.load_path + "/models/" 77 | fnames = sorted([fname for fname in os.listdir(load_path) if fname.endswith(".pt")]) 78 | ckpt = torch.load(load_path + fnames[-1], map_location=device) # read the checkpoint file once and reuse it below 79 | model.load_state_dict( 80 | ckpt["state_dict"], 81 | strict=False, 82 | ) 83 | print("\nINFO: successfully retrieved learned model params from specified cfg dir/epoch!") 84 | 85 | # load optimizer state dict 86 | optim.load_state_dict(ckpt["optimizer"]) 87 | print("\nINFO: successfully retrieved optimizer state dict from specified cfg dir/epoch!") 88 | 89 | # network params 90 | print("\nNetwork has %i params" % count_parameters(model)) 91 | 92 | # simulation 93 | sim_name = str(cfg.general.modality) 94 | with Simulation( 95 | sim_name=sim_name, output_root=f'{os.path.join(os.getcwd(), "runs/")}' 96 | ) as simulation: 97 | writer = SummaryWriter(os.path.join(simulation.outdir, "tensorboard")) 98 | cfg.inference.load_exp = simulation.outdir.split("/")[-1] 99 | with open(os.path.join(simulation.outdir, "cfg.yaml"), "w") as f: 100 | yaml.dump({k: v.__dict__ for k, v in cfg.__dict__.items()}, f) 101 | 102 | # training 103 | if (cfg.general.corr_mode == "diffusion" or cfg.general.corr_mode == "diffusion_ls"): 104 | noise_level_dict={'s1': cfg.SMLD.sigma_1_m, 'sL': cfg.SMLD.sigma_L_m, 'L': cfg.SMLD.n_steps} 105 | beta_dict = {'beta1': cfg.SMLD.beta_1, 'betaT': cfg.SMLD.beta_T, 'T': cfg.SMLD.T} 106 | 107 | trainer.TrainScoreNetwork(noise_level_dict,beta_dict,sde=cfg.SMLD.sde,model_type=cfg.model.type,train_objective=cfg.SMLD.objective,loss_power=cfg.learning.loss,n_val=cfg.learning.n_val,val_dl=train_dl).do( 108 | model, 109 | train_dl, 110 | cfg.learning.epochs, 111 | 
cfg.learning.clip, 112 | optim=optim, 113 | device=device, 114 | simulation=simulation, 115 | writer=writer, 116 | img_cond=cfg.general.img_cond, 117 | class_label_cond=cfg.general.class_label_cond, 118 | ) 119 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leabogensperger/generative-segmentation-sdf/3b8037e3e03d347a41f361f6ba844e6928240200/models/__init__.py -------------------------------------------------------------------------------- /models/ddpm.py: -------------------------------------------------------------------------------- 1 | import math 2 | from inspect import isfunction 3 | from functools import partial 4 | import matplotlib.pyplot as plt 5 | from tqdm.auto import tqdm 6 | from einops import rearrange 7 | import numpy as np 8 | 9 | import torch 10 | from torch import nn, einsum 11 | import torch.nn.functional as F 12 | 13 | # taken and adapted from https://huggingface.co/blog/annotated-diffusion 14 | 15 | def exists(x): 16 | return x is not None 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | class Residual(nn.Module): 26 | def __init__(self, fn): 27 | super().__init__() 28 | self.fn = fn 29 | 30 | def forward(self, x, *args, **kwargs): 31 | return self.fn(x, *args, **kwargs) + x 32 | 33 | 34 | def Upsample(dim): 35 | return nn.ConvTranspose2d(dim, dim, 4, 2, 1) 36 | 37 | 38 | def Downsample(dim): 39 | return nn.Conv2d(dim, dim, 4, 2, 1) 40 | 41 | 42 | class SinusoidalPositionEmbeddings(nn.Module): 43 | def __init__(self, dim): 44 | super().__init__() 45 | self.dim = dim 46 | 47 | def forward(self, time): 48 | device = time.device 49 | half_dim = self.dim // 2 50 | embeddings = math.log(10000) / (half_dim - 1) 51 | embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) 52 | 
embeddings = ( 53 | time * embeddings[None, :] 54 | ) # time is already in the shape [batchsize,1] 55 | embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) 56 | return embeddings 57 | 58 | class GaussianFourierProjection(nn.Module): # for continuous training 59 | """Gaussian Fourier embeddings for noise levels. 60 | Taken from https://github.com/yang-song/score_sde_pytorch/blob/cb1f359f4aadf0ff9a5e122fe8fffc9451fd6e44/models/layerspp.py#L32 61 | """ 62 | 63 | def __init__(self, dim=256, scale=16.0): 64 | super().__init__() 65 | dim = dim//2 66 | self.W = nn.Parameter(torch.randn(dim) * scale, requires_grad=False) 67 | 68 | def forward(self, x): 69 | x_proj = x* self.W[None, :] * 2 * np.pi 70 | return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 71 | 72 | 73 | class Block(nn.Module): 74 | def __init__(self, dim, dim_out, groups=8): 75 | super().__init__() 76 | self.proj = nn.Conv2d(dim, dim_out, 3, padding=1) 77 | self.norm = nn.GroupNorm(groups, dim_out) 78 | self.act = nn.SiLU() 79 | 80 | def forward(self, x, scale_shift=None): 81 | x = self.proj(x) 82 | x = self.norm(x) 83 | 84 | if exists(scale_shift): 85 | scale, shift = scale_shift 86 | x = x * (scale + 1) + shift 87 | 88 | x = self.act(x) 89 | return x 90 | 91 | 92 | class ResnetBlock(nn.Module): 93 | """https://arxiv.org/abs/1512.03385""" 94 | 95 | def __init__( 96 | self, 97 | dim, 98 | dim_out, 99 | *, 100 | time_emb_dim=None, 101 | groups=8, 102 | class_label_cond=False, 103 | num_classes=None 104 | ): 105 | super().__init__() 106 | self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out)) if exists(time_emb_dim) else None 107 | 108 | if class_label_cond: 109 | # TODO: time_emb_dim and class_emb_dim should have separate parameters for their dimensionality.
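# The class-label MLP below expects inputs of width time_emb_dim. This works only because
# Network builds its class-label embedding with class_label_dim = dim * 4, the same width as time_dim = dim * 4.
# If the two embedding widths are ever decoupled (see the TODO), this Linear layer needs its own input size.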
110 | self.class_label_mlp = nn.Sequential( 111 | nn.SiLU(), nn.Linear(time_emb_dim, dim_out) 112 | ) 113 | 114 | self.block1 = Block(dim, dim_out, groups=groups) 115 | self.block2 = Block(dim_out, dim_out, groups=groups) 116 | self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() 117 | 118 | self.num_classes = num_classes 119 | self.class_label_cond = class_label_cond 120 | 121 | def forward(self, x, time_emb=None, class_lbl=None): 122 | h = self.block1(x) 123 | 124 | if exists(self.mlp) and exists(time_emb): 125 | time_emb = self.mlp(time_emb) 126 | h = rearrange(time_emb, "b c -> b c 1 1") + h 127 | 128 | if self.class_label_cond is True: 129 | class_lbl_emb = self.class_label_mlp( 130 | class_lbl 131 | ) # project the class-label embedding to dim_out channels 132 | h = rearrange(class_lbl_emb, "b c -> b c 1 1") + h 133 | 134 | h = self.block2(h) 135 | return h + self.res_conv(x) 136 | 137 | 138 | class Attention(nn.Module): 139 | def __init__(self, dim, heads=4, dim_head=32): 140 | super().__init__() 141 | self.scale = dim_head**-0.5 142 | self.heads = heads 143 | hidden_dim = dim_head * heads 144 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 145 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 146 | 147 | def forward(self, x): 148 | b, c, h, w = x.shape 149 | qkv = self.to_qkv(x).chunk(3, dim=1) 150 | q, k, v = map( 151 | lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv 152 | ) 153 | q = q * self.scale 154 | 155 | sim = einsum("b h d i, b h d j -> b h i j", q, k) 156 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 157 | attn = sim.softmax(dim=-1) 158 | 159 | out = einsum("b h i j, b h d j -> b h i d", attn, v) 160 | out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w) 161 | return self.to_out(out) 162 | 163 | 164 | class LinearAttention(nn.Module): 165 | def __init__(self, dim, heads=4, dim_head=32): 166 | super().__init__() 167 | self.scale = dim_head**-0.5 168 | self.heads = heads 169 | hidden_dim = 
dim_head * heads 170 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 171 | 172 | self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim)) 173 | 174 | def forward(self, x): 175 | b, c, h, w = x.shape 176 | qkv = self.to_qkv(x).chunk(3, dim=1) 177 | q, k, v = map( 178 | lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv 179 | ) 180 | 181 | q = q.softmax(dim=-2) 182 | k = k.softmax(dim=-1) 183 | 184 | q = q * self.scale 185 | context = torch.einsum("b h d n, b h e n -> b h d e", k, v) 186 | 187 | out = torch.einsum("b h d e, b h d n -> b h e n", context, q) 188 | out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w) 189 | return self.to_out(out) 190 | 191 | 192 | class PreNorm(nn.Module): 193 | def __init__(self, dim, fn): 194 | super().__init__() 195 | self.fn = fn 196 | self.norm = nn.GroupNorm(1, dim) 197 | 198 | def forward(self, x): 199 | x = self.norm(x) 200 | return self.fn(x) 201 | 202 | 203 | class Network(nn.Module): 204 | def __init__( 205 | self, 206 | dim, 207 | init_dim=None, 208 | out_dim=None, 209 | dim_mults=(1, 2, 4, 8), 210 | channels=1, # channels of sdf input 211 | cond_channels=1, # rgb vs gray input for conditioning 212 | embedding='sinusoidal', 213 | with_time_emb=True, 214 | resnet_block_groups=2, 215 | img_cond=None, 216 | with_class_label_emb=False, 217 | class_label_cond=False, 218 | num_classes=None, 219 | ): 220 | super().__init__() 221 | 222 | self.embedding = embedding 223 | assert self.embedding == 'fourier' or self.embedding == 'sinusoidal' 224 | 225 | self.class_label_cond = class_label_cond 226 | self.num_classes = num_classes 227 | self.with_class_label_emb = with_class_label_emb 228 | 229 | block_klass = partial(ResnetBlock, groups=resnet_block_groups) 230 | 231 | # time embeddings 232 | if with_time_emb: 233 | time_dim = dim * 4 234 | self.time_mlp = nn.Sequential( 235 | GaussianFourierProjection(dim) if self.embedding == 'fourier' else 
SinusoidalPositionEmbeddings(dim), # embedding type is selected via the `embedding` argument 236 | # SinusoidalPositionEmbeddings(dim), 237 | nn.Linear(dim, time_dim), 238 | nn.GELU(), 239 | nn.Linear(time_dim, time_dim), 240 | ) 241 | else: 242 | raise ValueError("with_time_emb=False is not supported: the rest of the network requires a time/noise-level embedding.") 243 | 244 | # class_label embeddings 245 | if class_label_cond == True: 246 | if with_class_label_emb: 247 | class_label_dim = dim * 4 248 | self.class_label_embedding_mlp = nn.Sequential( 249 | SinusoidalPositionEmbeddings(dim), 250 | nn.Linear(dim, class_label_dim), 251 | nn.GELU(), 252 | nn.Linear(class_label_dim, class_label_dim), 253 | ) 254 | else: 255 | class_label_dim = dim * 4 256 | self.class_label_embedding_mlp = nn.Sequential( 257 | nn.Linear(dim, class_label_dim), 258 | nn.GELU(), 259 | nn.Linear(class_label_dim, class_label_dim), 260 | ) 261 | else: 262 | self.class_label_embedding_mlp = None 263 | 264 | # determine dimensions 265 | self.channels = channels 266 | self.cond_channels = cond_channels 267 | 268 | # conditioning branch at beginning (SegDiff style) 269 | if img_cond == 1: 270 | self.encoder_img = nn.Sequential( 271 | block_klass(channels, init_dim, time_emb_dim=time_dim) 272 | ) 273 | self.encoder_mask = nn.Sequential( 274 | block_klass(cond_channels, init_dim, time_emb_dim=time_dim) 275 | ) 276 | init_dim *= 2 277 | channels = init_dim 278 | 279 | init_dim = default(init_dim, dim // 3 * 2) 280 | self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3) 281 | 282 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)] 283 | in_out = list(zip(dims[:-1], dims[1:])) 284 | 285 | # layers 286 | self.downs = nn.ModuleList([]) 287 | self.ups = nn.ModuleList([]) 288 | num_resolutions = len(in_out) 289 | 290 | for ind, (dim_in, dim_out) in enumerate(in_out): 291 | is_last = ind >= (num_resolutions - 1) 292 | 293 | self.downs.append( 294 | nn.ModuleList( 295 | [ 296 | block_klass( 297 | dim_in, 298 | 
dim_out, 299 | time_emb_dim=time_dim, 300 | class_label_cond=class_label_cond, 301 | num_classes=num_classes, 302 | ), 303 | block_klass( 304 | dim_out, 305 | dim_out, 306 | time_emb_dim=time_dim, 307 | class_label_cond=class_label_cond, 308 | num_classes=num_classes, 309 | ), 310 | Residual(PreNorm(dim_out, LinearAttention(dim_out))), 311 | Downsample(dim_out) if not is_last else nn.Identity(), 312 | ] 313 | ) 314 | ) 315 | 316 | mid_dim = dims[-1] 317 | self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) 318 | self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) 319 | self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) 320 | 321 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): 322 | is_last = ind >= (num_resolutions - 1) 323 | 324 | self.ups.append( 325 | nn.ModuleList( 326 | [ 327 | block_klass( 328 | dim_out * 2, 329 | dim_in, 330 | time_emb_dim=time_dim, 331 | class_label_cond=class_label_cond, 332 | num_classes=num_classes, 333 | ), 334 | block_klass( 335 | dim_in, 336 | dim_in, 337 | time_emb_dim=time_dim, 338 | class_label_cond=class_label_cond, 339 | num_classes=num_classes, 340 | ), 341 | Residual(PreNorm(dim_in, LinearAttention(dim_in))), 342 | Upsample(dim_in) if not is_last else nn.Identity(), 343 | ] 344 | ) 345 | ) 346 | 347 | out_dim = default(out_dim, self.channels) 348 | self.final_conv = nn.Sequential( 349 | block_klass(dim, dim, time_emb_dim=None), nn.Conv2d(dim, out_dim, 1) 350 | ) 351 | 352 | def forward(self, x, time, img_cond=None, class_lbl=None): 353 | if img_cond is not None: 354 | # x = self.encoder_img(x) + self.encoder_mask(cond) 355 | x = torch.cat((self.encoder_img(x), self.encoder_mask(img_cond)), 1) 356 | 357 | x = self.init_conv(x) 358 | 359 | t = self.time_mlp(time) if exists(self.time_mlp) else None 360 | class_lbl = ( 361 | self.class_label_embedding_mlp(class_lbl) 362 | if exists(self.class_label_embedding_mlp) 363 | else None 364 | ) 365 | 366 | h = [] 367 | 368 
| # downsample 369 | for block1, block2, attn, downsample in self.downs: 370 | x = block1(x, t, class_lbl) 371 | x = block2(x, t, class_lbl) 372 | x = attn(x) 373 | h.append(x) 374 | x = downsample(x) 375 | 376 | # bottleneck 377 | x = self.mid_block1(x, t, class_lbl) 378 | x = self.mid_attn(x) 379 | x = self.mid_block2(x, t, class_lbl) 380 | 381 | # upsample 382 | for block1, block2, attn, upsample in self.ups: 383 | x = torch.cat((x, h.pop()), dim=1) 384 | x = block1(x, t, class_lbl) 385 | x = block2(x, t, class_lbl) 386 | x = attn(x) 387 | x = upsample(x) 388 | 389 | return self.final_conv(x) -------------------------------------------------------------------------------- /preprocess_data/precompute_sdf.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 4, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "from skimage.draw import ellipse\n", 12 | "from skimage.morphology import binary_erosion\n", 13 | "from scipy.ndimage import distance_transform_edt" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 37, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "# generate toy mask\n", 23 | "N = 128\n", 24 | "m = np.zeros((N, N), dtype=int)\n", 25 | "\n", 26 | "# Define parameters for the ellipses (center, semi-axes, and orientation)\n", 27 | "ellipses = [\n", 28 | " (30, 40, 20, 20, np.deg2rad(30)), # (row, col, semi-major, semi-minor, rotation)\n", 29 | " (20, 80, 10, 20, np.deg2rad(75)),\n", 30 | " (90, 60, 30, 40, np.deg2rad(15)),\n", 31 | "]\n", 32 | "\n", 33 | "# Draw each ellipse on the mask\n", 34 | "for (r, c, r_radius, c_radius, angle) in ellipses:\n", 35 | " rr, cc = ellipse(r, c, r_radius, c_radius, rotation=angle, shape=m.shape)\n", 36 | " m[rr, cc] = 1" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 38, 42 | 
"metadata": {}, 43 | "outputs": [ 44 | { 45 | "data": { 46 | "image/png": "", 47 | "text/plain": [ 48 | "
" 49 | ] 50 | }, 51 | "metadata": {}, 52 | "output_type": "display_data" 53 | } 54 | ], 55 | "source": [ 56 | "# extract boundaries\n", 57 | "m_bd = np.abs(binary_erosion(m) - m)\n", 58 | "\n", 59 | "fig, ax = plt.subplots(1,2)\n", 60 | "ax[0].imshow(m), ax[0].set_title('binary mask')\n", 61 | "ax[1].imshow(m_bd), ax[1].set_title('extracted boundaries')\n", 62 | "plt.show()" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 55, 68 | "metadata": {}, 69 | "outputs": [ 70 | { 71 | "data": { 72 | "image/png": "", 73 | "text/plain": [ 74 | "
" 75 | ] 76 | }, 77 | "metadata": {}, 78 | "output_type": "display_data" 79 | } 80 | ], 81 | "source": [ 82 | "# extract signed distance function (SDF)\n", 83 | "distance = distance_transform_edt(np.where(m_bd==0., np.ones_like(m_bd), np.zeros_like(m_bd)))\n", 84 | "m_sdf = np.where(m == 1, distance*-1, distance) # ensure signed DT\n", 85 | "\n", 86 | "# truncate at threshold and normalize between [-1,1]\n", 87 | "thresh = 15\n", 88 | "m_sdf[m_sdf >= thresh] = thresh\n", 89 | "m_sdf[m_sdf <= -thresh] = -thresh\n", 90 | "m_sdf /= thresh\n", 91 | "\n", 92 | "fig, ax = plt.subplots(1,2)\n", 93 | "ax[0].imshow(m), ax[0].set_title('binary mask')\n", 94 | "ax[1].imshow(m_sdf), ax[1].set_title('SDF transformed mask')\n", 95 | "plt.show()" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 56, 101 | "metadata": {}, 102 | "outputs": [ 103 | { 104 | "name": "stdout", 105 | "output_type": "stream", 106 | "text": [ 107 | "SDF mask values living in [-1,1]\n" 108 | ] 109 | } 110 | ], 111 | "source": [ 112 | "# check SDF normalization\n", 113 | "print('SDF mask values living in [%i,%i]' %(m_sdf.min(),m_sdf.max()))" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 57, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "name": "stdout", 123 | "output_type": "stream", 124 | "text": [ 125 | "Retrieved mask from thresholding SDF agrees with original binary mask: True\n" 126 | ] 127 | } 128 | ], 129 | "source": [ 130 | "# retrieve binary mask from SDF mask by thresholding \n", 131 | "m_retrieved = np.where(m_sdf <= 0, np.ones_like(m_sdf), np.zeros_like(m_sdf))\n", 132 | "print('Retrieved mask from thresholding SDF agrees with original binary mask: %s' %np.allclose(m_retrieved,m)) " 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [] 141 | } 142 | ], 143 | "metadata": { 144 | "kernelspec": { 145 | "display_name": "misc", 146 | "language": 
"python", 147 | "name": "python3" 148 | }, 149 | "language_info": { 150 | "codemirror_mode": { 151 | "name": "ipython", 152 | "version": 3 153 | }, 154 | "file_extension": ".py", 155 | "mimetype": "text/x-python", 156 | "name": "python", 157 | "nbconvert_exporter": "python", 158 | "pygments_lexer": "ipython3", 159 | "version": "3.12.2" 160 | } 161 | }, 162 | "nbformat": 4, 163 | "nbformat_minor": 2 164 | } 165 | -------------------------------------------------------------------------------- /sampler.py: -------------------------------------------------------------------------------- 1 | import random 2 | from types import SimpleNamespace 3 | import imageio 4 | import numpy as np 5 | import argparse 6 | import sys 7 | import os 8 | import json 9 | from tqdm.auto import tqdm 10 | import matplotlib.pyplot as plt 11 | from PIL import ImageDraw 12 | import numpy.ma as ma 13 | 14 | import einops 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | from torch.optim import Adam 19 | from torch.utils.data import DataLoader 20 | from torch.utils.tensorboard import SummaryWriter 21 | import yaml 22 | from torchvision.transforms import PILToTensor, ToPILImage 23 | from PIL import ImageFont, ImageDraw, Image 24 | from torchvision.utils import make_grid 25 | from torchmetrics import JaccardIndex, Dice, F1Score 26 | import torchmetrics 27 | from torchvision.utils import save_image 28 | from torch.nn.functional import one_hot 29 | 30 | from SimulationHelper.simulation import Simulation 31 | from datasets.config_dl import config_dl 32 | from models import ddpm 33 | 34 | parser = argparse.ArgumentParser("") 35 | parser.add_argument( 36 | "--config", default="cfg/monuseg.yaml", type=str, help="path to .yaml config" 37 | ) 38 | parser.add_argument("--seed", default=0, type=int, help="seed for reproducibility") # 1 39 | args = parser.parse_args() 40 | 41 | # Setting reproducibility 42 | def set_seed(SEED=0): 43 | random.seed(SEED) 44 | 
np.random.seed(SEED) 45 | torch.manual_seed(SEED) 46 | 47 | def store_gif(frames, frames_per_gif, load_path, sample_str=''): 48 | gif_name = load_path + "/samples/samples" + sample_str + ".gif" 49 | 50 | with imageio.get_writer(gif_name, mode="I") as writer: 51 | for idx, frame in enumerate(frames): 52 | writer.append_data(frame) 53 | if idx == len(frames) - 1: 54 | for _ in range(frames_per_gif // 3): 55 | writer.append_data(frames[-1]) 56 | 57 | def show_images(images, vmin=None, vmax=None, save_name="", overlay=None): 58 | """Shows the provided images as sub-pictures in a square""" 59 | alpha=0.6 if overlay is not None else 1. # alpha channel if additional overlay image is given 60 | if vmin is None: 61 | vmin = images.min().item() 62 | if vmax is None: 63 | vmax = images.max().item() 64 | 65 | if overlay is not None: 66 | overlay = overlay.detach().cpu().numpy() 67 | 68 | # Converting images to CPU numpy arrays 69 | if type(images) is torch.Tensor: 70 | images = images.detach().cpu().numpy() 71 | 72 | # Defining number of rows and columns 73 | fig = plt.figure(figsize=(8, 8)) 74 | rows = int(len(images) ** (1 / 2)) 75 | cols = round(len(images) / rows) 76 | 77 | # Populating figure with sub-plots 78 | idx = 0 79 | for r in range(rows): 80 | for c in range(cols): 81 | fig.add_subplot(rows, cols, idx + 1) 82 | 83 | if idx < len(images): 84 | if overlay is not None: 85 | plt.imshow(overlay[idx][0], cmap="gray") 86 | images[:,:,0,0] = vmax # this is just for plotting! 
87 | images[:,:,0,1] = 1 88 | mask = np.ma.masked_where(images[idx][0] == 0, images[idx][0]) 89 | plt.imshow(mask, alpha=alpha), plt.axis("off") 90 | else: 91 | plt.imshow(images[idx][0], alpha=alpha, cmap="gray", vmin=vmin, vmax=vmax), plt.axis("off") 92 | idx += 1 93 | 94 | # Showing the figure 95 | plt.savefig(save_name, bbox_inches="tight", dpi=250) 96 | plt.close() 97 | 98 | 99 | def compute_metrics(x,x_gt,thresh,corr_mode,num_classes): 100 | # compute IoU between thresholded x and x_gt 101 | x_thresh = torch.where(x > thresh, torch.zeros_like(x), torch.ones_like(x)).type(torch.int8).squeeze().cpu() 102 | x_gt_thresh = torch.where(x_gt > 0., torch.zeros_like(x_gt), torch.ones_like(x_gt)).type(torch.int8).squeeze().cpu() 103 | 104 | jaccard = JaccardIndex(task="binary") 105 | dice = Dice(task='binary',average='macro',num_classes=2,ignore_index=0) # dice=f1 in binary segmentation 106 | iou, dice = jaccard(x_thresh, x_gt_thresh), dice(x_thresh, x_gt_thresh) 107 | return iou, dice 108 | 109 | 110 | def plot_all(x,cond,x_gt,img_cond,load_path,std_min,corr_mode,sample_str=''): 111 | sdf_min = x_gt.min().item() 112 | sdf_max = x_gt.max().item() 113 | 114 | show_images(x, vmin=sdf_min, vmax=sdf_max, save_name=load_path + "/samples/samples_" + str(sample_str) + ".png") 115 | if img_cond == 1: 116 | show_images(cond, save_name=load_path + "/samples/condition.png") 117 | show_images(x_gt, vmin=sdf_min, vmax=sdf_max, save_name=load_path + "/samples/groundtruth.png") 118 | 119 | x_thresh = torch.where(x > 3.*std_min, torch.zeros_like(x), torch.ones_like(x)) 120 | show_images(x_thresh, save_name=load_path + "/samples/samples_thresholded_.png") 121 | 122 | x_gt_thresh = torch.where(x_gt > 0., torch.zeros_like(x_gt), torch.ones_like(x_gt)) 123 | show_images(x_gt_thresh, save_name=load_path + "/samples/groundtruth_thresholded.png") 124 | 125 | # show thresholded maps on top of conditioning image 126 | vmax = x_gt.shape[1] 127 | show_images(x_thresh, save_name=load_path + 
"/samples/samples_thresholded_overlay.png", vmax=vmax, overlay=cond) 128 | show_images(x_gt_thresh, save_name=load_path + "/samples/groundtruth_thresholded_overlay.png", vmax=vmax, overlay=cond) 129 | 130 | class Sampling: 131 | def __init__(self, scorenet, model_type, device, load_path, sz, noise_level_dict, beta_dict, sde, img_cond, corr_mode, save_images=True): 132 | # general params 133 | self.scorenet = scorenet 134 | self.device = device 135 | self.load_path = load_path 136 | self.sz = sz 137 | self.sde = sde 138 | self.img_cond = img_cond 139 | self.model_type = model_type 140 | self.corr_mode = corr_mode 141 | 142 | # if set to False, no images are saved 143 | self.save_images = save_images 144 | 145 | if self.sde == 've': 146 | self.s1, self.sL, self.L = noise_level_dict['s1'], noise_level_dict['sL'], noise_level_dict['L'] 147 | self.sigmas = torch.tensor(np.exp(np.linspace(np.log(self.s1),np.log(self.sL), self.L))).type(torch.float32) 148 | elif self.sde == 'vp': 149 | self.beta1, self.betaT, self.T = beta_dict['beta1'], beta_dict['betaT'], beta_dict['T'] 150 | self.betas = np.linspace(1.E-4, 0.02, 1000, dtype=np.float32) 151 | self.alphas = 1 - self.betas 152 | self.alpha_bars = torch.from_numpy(np.asarray([np.prod(self.alphas[:i + 1]) for i in range(len(self.alphas))])) 153 | 154 | def get_sigma(self,t): 155 | return self.sigmas[-1]*(self.sigmas[0]/self.sigmas[-1])**t 156 | 157 | def sample(self, x, m_gt, n_samples, N, M, r, num_classes=2): 158 | if self.sde == 've': 159 | return self._sample_ve(x, m_gt, n_samples=n_samples, N=N, M=M, r=r,num_classes=num_classes) 160 | elif self.sde == 'vp': 161 | return self._sample_vp(x, m_gt, n_samples=n_samples,num_classes=num_classes) 162 | 163 | def _sample_vp(self,x, m_gt, n_samples, num_classes): 164 | # TODO: sample according to DDPM paper, note up to date this is fixed to 1000 time steps but could be adapted with a continuous loss function 165 | """ 166 | Sample according to DDPM paper 167 | """ 168 | m = 
torch.randn(n_samples,1,self.sz,self.sz).float().to(self.device) 169 | device = x.device 170 | m_list = [] 171 | with torch.no_grad(): 172 | for i, t in tqdm(enumerate(list(range(self.T))[::-1])): 173 | # Estimating noise to be removed 174 | time_tensor = (torch.ones(n_samples, 1) * t).to(self.device).long() 175 | eta_theta = self.scorenet(m,t*torch.ones((n_samples,1)).to(self.device),img_cond=x) 176 | alpha_t = self.alphas[t] 177 | alpha_t_bar = self.alpha_bars[t] 178 | 179 | # Partially denoising the image 180 | m = (1 / np.sqrt(alpha_t)) * ( 181 | m - (1 - alpha_t) / np.sqrt(1 - alpha_t_bar) * eta_theta 182 | ) 183 | 184 | m_list.append(m.detach().cpu()) 185 | if t > 0: # no noise added in last sampling step 186 | z = torch.randn(n_samples, 1, self.sz, self.sz).to(device) 187 | beta_t = self.betas[t] 188 | # # Option 1: sigma_t squared = beta_t 189 | # sigma_t = np.sqrt(beta_t) 190 | 191 | # Option 2: sigma_t squared = beta_tilde_t 192 | prev_alpha_t_bar = self.alpha_bars[t-1] if t > 0 else self.alphas[0] 193 | beta_tilde_t = ((1 - prev_alpha_t_bar)/(1 - alpha_t_bar)) * beta_t 194 | sigma_t = np.sqrt(beta_tilde_t) 195 | 196 | # Adding some more noise like in Langevin Dynamics fashion 197 | m = m + sigma_t * z 198 | 199 | if self.save_images: 200 | plot_all(m,x,m_gt,self.img_cond,self.load_path,std_min=1e-3,corr_mode=self.corr_mode,sample_str='vp') 201 | 202 | # x_list.append(x.detach().cpu()) 203 | return m, m_list 204 | 205 | def _sample_ve(self, cond, x_gt, n_samples, N, M, r, num_classes): 206 | """ 207 | Sample using reverse-time SDE (VE-SDE) 208 | N number of predictor steps 209 | M number of corrector steps 210 | r "signal-to-noise" ratio 211 | """ 212 | 213 | frames = [] 214 | frames_thresh = [] 215 | frames_all = [] 216 | frames_per_gif = 100 217 | frame_idxs = np.linspace(0, N, frames_per_gif).astype(np.uint) 218 | 219 | t = torch.linspace(1-(1./N),0,N) # TODO: fix start at 1 or 1-dt 220 | sigma_t = self.get_sigma(t[0]) 221 | x_list = [] 222 | 223 | # 
initialize x and sample 224 | n_samples = x_gt.shape[0] 225 | x = self.sigmas[0]*torch.clip(torch.randn(n_samples,num_classes,self.sz,self.sz),-2.,2.).to(self.device) 226 | 227 | with torch.no_grad(): 228 | for i, t_curr in enumerate(t): 229 | if i % 20 == 0: 230 | iou, dice = compute_metrics(x,x_gt,thresh=3*self.sigmas[-1].item(),corr_mode=self.corr_mode,num_classes=num_classes) 231 | print('PC sampling it [%i]:\t IoU [%.6f], Dice [%.6f]' %(i,iou,dice)) 232 | 233 | # set sigma(t) 234 | sigma_t_prev = sigma_t.clone() 235 | sigma_t = self.get_sigma(t_curr) 236 | 237 | # get scores, sample noise 238 | if self.model_type == 'unet': 239 | scores = self.scorenet(x,sigma_t_prev*torch.ones((n_samples,1)).to(self.device),img_cond=cond) 240 | elif self.model_type == 'tdv': 241 | scores = self.scorenet.grad(torch.cat([x,cond],1),sigma_t_prev*torch.ones((n_samples,1,1,1)).to(self.device))[:,0:1] 242 | 243 | z = torch.clip(torch.randn_like(x),-2.,2.) 244 | tau = (sigma_t_prev**2 - sigma_t**2) 245 | 246 | # predictor step 247 | x = x + tau*scores 248 | x_list.append(x.detach().cpu()) 249 | x += np.sqrt(tau)*z 250 | 251 | # corrector steps 252 | for j in range(M): 253 | # z = torch.randn_like(x) 254 | z = torch.clip(torch.randn_like(x),-2.,2.) 
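# Langevin corrector at the fixed noise level sigma_t, as implemented in the lines that follow:
#   eps = 2 * (r * ||z|| / ||s_theta(x, sigma_t)||)^2   (r is the "signal-to-noise" ratio)
#   x  <- x + eps * s_theta(x, sigma_t) + sqrt(2 * eps) * z
# torch.norm(z) and torch.norm(scores_corr) are taken over the whole batch tensor, not per sample,
# and z is a Gaussian sample clipped to [-2, 2] rather than an exact standard normal draw.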
255 | 256 | # compute eps 257 | if self.model_type == 'unet': 258 | scores_corr = self.scorenet(x,sigma_t*torch.ones((n_samples,1)).to(self.device),img_cond=cond) 259 | elif self.model_type == 'tdv': 260 | scores_corr = self.scorenet.grad(torch.cat([x,cond],1),sigma_t*torch.ones((n_samples,1,1,1)).to(self.device))[:,0:1] 261 | 262 | eps = 2*(r*torch.norm(z).item()/torch.norm(scores_corr).item())**2 263 | x = x + eps*scores_corr 264 | 265 | x_list.append(x.detach().cpu()) 266 | x += np.sqrt(2*eps)*z 267 | 268 | if self.save_images and (i in frame_idxs or t_curr == 0): # TODO: if samplers other than PC are used, make sure that the gif is also generated for them 269 | # Rescale each sample to [0, 255] for the GIF frames 270 | normalized = x.clone() 271 | if self.corr_mode == 'diffusion_ls': 272 | normalized_thresh = torch.where(x > 3*self.sigmas[-1], torch.zeros_like(x), torch.ones_like(x)) 273 | elif self.corr_mode == 'diffusion': 274 | normalized_thresh = torch.where(x < 0.5, torch.zeros_like(x), torch.ones_like(x)) 275 | 276 | for b in range(len(normalized)): # use b, not i, so the outer sampling-step index is not shadowed 277 | normalized[b] -= torch.min(normalized[b]) 278 | normalized[b] *= 255 / torch.max(normalized[b]) 279 | 280 | normalized_thresh[b] -= torch.min(normalized_thresh[b]) 281 | normalized_thresh[b] *= 255 / torch.max(normalized_thresh[b]) 282 | 283 | # Reshape the batch (n, c, h, w) into a frame that is as close to square as possible 284 | frame = einops.rearrange( 285 | normalized, 286 | "(b1 b2) c h w -> (b1 h) (b2 w) c", 287 | b1=int(n_samples**0.5), 288 | ) 289 | frame = frame.cpu().numpy().astype(np.uint8) 290 | frames.append(frame) 291 | 292 | # append thresholded 293 | frame_thresh = einops.rearrange( 294 | normalized_thresh, 295 | "(b1 b2) c h w -> (b1 h) (b2 w) c", 296 | b1=int(n_samples**0.5), 297 | ) 298 | frame_thresh = frame_thresh.cpu().numpy().astype(np.uint8) 299 | frames_thresh.append(frame_thresh) 300 | 301 | # plotting 302 | if self.save_images: 303 | 
plot_all(x_list[-1],cond,x_gt,self.img_cond,self.load_path,std_min=self.sigmas[-1].item(),corr_mode=self.corr_mode,sample_str='ve') 304 | 305 | return x_list[-1], x_list 306 | 307 | if __name__ == "__main__": 308 | with open(args.config) as file: 309 | yaml_cfg = yaml.safe_load(file) 310 | cfg = json.loads(json.dumps(yaml_cfg), object_hook=lambda d: SimpleNamespace(**d)) 311 | 312 | device = torch.device("cuda") 313 | print(f"Using device: {device}\t" + (f"{torch.cuda.get_device_name(0)}")) 314 | 315 | set_seed(SEED=args.seed) 316 | 317 | # set up dataloader, model 318 | train_dl, test_dl = config_dl(cfg) 319 | if cfg.model.type == 'unet': 320 | model = ddpm.Network( 321 | dim=cfg.model.dim, 322 | channels=cfg.model.n_cin, 323 | cond_channels=cfg.model.n_cin_cond, 324 | init_dim=cfg.model.n_fm, 325 | dim_mults=tuple(cfg.model.mults), 326 | embedding=cfg.model.embedding, 327 | img_cond=cfg.general.img_cond, 328 | with_class_label_emb=cfg.general.with_class_label_emb, 329 | class_label_cond=cfg.general.class_label_cond, 330 | num_classes=cfg.general.num_classes, 331 | ).to(device) 332 | 333 | else: 334 | raise ValueError('Unknown model type!') 335 | 336 | load_path = os.getcwd() + "/runs/" + cfg.general.modality 337 | 338 | if cfg.inference.latest: 339 | print("\nINFO: running inference on the latest experiment!") 340 | load_path += "/" + sorted(os.listdir(load_path))[-1] 341 | else: 342 | print("\nINFO: running inference on the selected experiment, *not* the latest!") 343 | load_path += "/" + cfg.inference.load_exp 344 | 345 | print(f"Loading *latest* checkpoint from {load_path + '/models/'}") 346 | if not os.path.exists(load_path + "/samples"): # makedir for samples 347 | os.mkdir(load_path + "/samples") 348 | 349 | # load the model weights 350 | fnames = sorted( 351 | [fname for fname in os.listdir(load_path + "/models/") if fname.endswith(".pt")] 352 | ) 353 | model.load_state_dict(torch.load(load_path + "/models/" + fnames[-1], map_location=device)["state_dict"],strict=True) # strict=False 354 | 355 | 
| model.eval() 356 | print("\nModel loaded from %s" % (load_path + "/models/" + fnames[-1])) 357 | 358 | # sample and save generated images 359 | if cfg.general.corr_mode == "diffusion" or cfg.general.corr_mode == "diffusion_ls": 360 | noise_level_dict={'s1': cfg.SMLD.sigma_1_m, 'sL': cfg.SMLD.sigma_L_m, 'L': cfg.SMLD.n_steps} 361 | beta_dict = {'beta1': cfg.SMLD.beta_1, 'betaT': cfg.SMLD.beta_T, 'T': cfg.SMLD.T} 362 | 363 | Sampler = Sampling( 364 | scorenet=model, 365 | model_type=cfg.model.type, 366 | device=device, 367 | load_path=load_path, 368 | sz=cfg.general.sz, 369 | noise_level_dict=noise_level_dict, 370 | beta_dict=beta_dict, 371 | sde=cfg.SMLD.sde, 372 | img_cond=cfg.general.img_cond, 373 | corr_mode=cfg.general.corr_mode, 374 | save_images=True) 375 | 376 | # load conditioning image (key 'mask') and ground-truth SDF (key 'image') 377 | it_test_dl = iter(test_dl) 378 | batch = next(it_test_dl) 379 | x = None if cfg.general.img_cond==0 else batch['mask'].to(device) 380 | m_gt = batch['image'].to(device) # ground truth is always needed by compute_metrics below 381 | 382 | # generate samples 383 | samples, samples_list = Sampler.sample(x, m_gt, n_samples=cfg.inference.n_samples, N=cfg.SMLD.N,M=cfg.SMLD.M,r=cfg.SMLD.r, num_classes=cfg.model.n_cin) 384 | 385 | # eval metrics 386 | iou, dice = compute_metrics(samples, m_gt.cpu(), corr_mode=cfg.general.corr_mode,thresh=3.*cfg.SMLD.sigma_L_m,num_classes=cfg.model.n_cin) 387 | print('\nFinal metrics: IoU [%f], Dice [%f]' %(iou,dice)) 388 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import random 2 | import imageio 3 | import numpy as np 4 | import argparse 5 | import sys 6 | import os 7 | import json 8 | from tqdm.auto import tqdm 9 | import matplotlib.pyplot as plt 10 | 11 | import einops 12 | import torch 13 | import torch.nn as nn 14 | from torch.optim import Adam 15 | from torch.utils.data import DataLoader 16 | from 
torch.utils.tensorboard import SummaryWriter 17 | from torchvision.utils import make_grid 18 | 19 | from SimulationHelper.simulation import Simulation 20 | from datasets.config_dl import config_dl 21 | from datasets.transform_factory import inv_normalize, transform_factory 22 | from models import ddpm 23 | 24 | class TrainScoreNetwork: 25 | def __init__(self, noise_level_dict, beta_dict, sde, model_type, train_objective, anneal_power=2, loss_power=2, n_val=8, val_dl=None): 26 | self.sde = sde 27 | self.model_type = model_type 28 | self.x_val, self.cond_val = None, None # set below only if a validation dataloader is given 29 | if self.sde == 've': 30 | self.s1, self.sL, self.L = noise_level_dict['s1'], noise_level_dict['sL'], noise_level_dict['L'] 31 | self.sigmas = torch.tensor(np.exp(np.linspace(np.log(self.s1),np.log(self.sL), self.L))).type(torch.float32) # geometric noise levels from s1 down to sL 32 | 33 | 34 | self.anneal_power = anneal_power 35 | self.loss_power = loss_power 36 | self.train_objective = train_objective 37 | assert train_objective == 'disc' or train_objective == 'cont' 38 | 39 | if val_dl: # fixed batch from the test dataloader for the sanity-check plots 40 | val_batch = next(iter(val_dl)) 41 | self.x_val = val_batch['image'][:n_val] 42 | self.cond_val = val_batch['mask'][:n_val] 43 | 44 | eta_val = torch.randn_like(self.x_val) 45 | self.used_sigmas_val = torch.linspace(self.sigmas[0],self.sigmas[-1], self.x_val.shape[0])[:,None,None,None] 46 | self.z_val = self.x_val + eta_val*self.used_sigmas_val 47 | 48 | elif self.sde == 'vp': 49 | self.beta1, self.betaT, self.T = beta_dict['beta1'], beta_dict['betaT'], beta_dict['T'] 50 | self.betas = np.linspace(1.E-4, 0.02, 1000, dtype=np.float32) # NOTE: hard-coded schedule; beta1/betaT from beta_dict are unused and T must not exceed 1000 51 | self.alphas = 1 - self.betas 52 | self.alpha_bars = torch.from_numpy(np.cumprod(self.alphas)) # alpha_bar_t = prod_{s<=t} alpha_s 53 | 54 | if val_dl: # TODO: implement validation 55 | val_batch = next(iter(val_dl)) 56 | self.x_val = val_batch['image'][:n_val] 57 | pass 58 | 59 | else: 60 | raise ValueError('Unknown SDE type!') 61 | 62 | def get_grad_norm(self, model): 63 | 
parameters = [p for p in model.parameters() if p.grad is not None and p.requires_grad] 64 | norms = [p.grad.detach().abs().max().item() for p in parameters] # per-parameter max |grad| 65 | return np.asarray(norms).max() # overall L-inf norm of the gradient 66 | 67 | def do(self, scorenet, dl, n_epochs, clip, optim, device, simulation, writer, img_cond, class_label_cond=False): 68 | if self.sde == 've': 69 | self._do_ve(scorenet, dl, n_epochs, clip, optim, device, simulation, writer, img_cond, class_label_cond) 70 | 71 | elif self.sde == 'vp': 72 | self._do_vp(scorenet, dl, n_epochs, clip, optim, device, simulation, writer, img_cond, class_label_cond) 73 | 74 | def _do_ve(self, scorenet, dl, n_epochs, clip, optim, device, simulation, writer, img_cond=0, class_label_cond=False): 75 | if img_cond == 0 or getattr(self, 'cond_val', None) is None: # no conditioning, or no validation batch available 76 | self.cond_val = None 77 | else: 78 | self.cond_val = self.cond_val.to(device) 79 | best_loss = float("inf") 80 | 81 | for epoch in tqdm(range(n_epochs), desc=f"Training progress", colour="#00ff00"): 82 | epoch_loss = 0.0 83 | grad_norms_epoch = [] 84 | 85 | for step, batch in enumerate(tqdm(dl, leave=False, desc=f"Epoch {epoch + 1}/{n_epochs}", colour="#005500")): 86 | # Loading data 87 | x = batch['image'].to(device) 88 | cond = None if img_cond==0 else batch['mask'].to(device) 89 | lbl = None if class_label_cond is False else batch['label'].to(device).unsqueeze(1) 90 | n = len(x) 91 | 92 | # noise-conditional score network corruption 93 | if self.train_objective == 'disc': 94 | sigmas_idx = torch.randint(0, self.L, (n,)) 95 | used_sigmas = (self.sigmas[sigmas_idx][:,None,None,None]).to(device) 96 | elif self.train_objective == 'cont': # continuous training objective (SDE style) 97 | t = torch.from_numpy(np.random.uniform(1e-5,1,(n,))).float() 98 | used_sigmas = (self.sigmas[-1]*(self.sigmas[0]/self.sigmas[-1])**t)[:,None,None,None].to(device) 99 | 100 | # noise corruption 101 | eta = torch.randn_like(x).to(device) 102 | z = x + eta*used_sigmas.to(device) 103 | 104 | # compute score matching loss 105 | target = 
1/(used_sigmas**2) * (x-z) # DSM target: grad_z log N(z; x, sigma^2 I) = (x - z)/sigma^2 = -eta/sigma 106 | if self.model_type == 'unet': 107 | scores = scorenet(z, used_sigmas.reshape(n,-1), img_cond=cond, class_lbl=lbl) 108 | elif self.model_type == 'tdv': 109 | scores = scorenet.grad(torch.cat([z,cond],1), used_sigmas.reshape(n,1,1,1))[:,0:1] 110 | 111 | if step % 100 == 0: # sanity check: plot clean, noisy and denoised samples 112 | with torch.no_grad(): 113 | scorenet.eval() 114 | if self.x_val is not None: # always use the same val/test batch 115 | if self.model_type == 'unet': 116 | scores_val = scorenet(self.z_val.to(device), self.used_sigmas_val.to(device).reshape(self.z_val.shape[0],-1), img_cond=self.cond_val, class_lbl=lbl) 117 | elif self.model_type == 'tdv': 118 | scores_val = scorenet.grad(torch.cat([self.z_val.to(device),self.cond_val],1), self.used_sigmas_val.to(device).reshape(self.z_val.shape[0],1,1,1))[:,0:1] 119 | 120 | x_mmse_val = self.z_val.to(device) + self.used_sigmas_val.to(device)**2 * scores_val 121 | 122 | # for multi-class masks, plot only one fixed class (index 4) 123 | if self.x_val.shape[1] > 1: 124 | class_idx = 4 125 | x_val, z_val, x_mmse_val = self.x_val[:,class_idx][:,None], self.z_val[:,class_idx][:,None], x_mmse_val[:,class_idx][:,None] 126 | else: 127 | x_val, z_val = self.x_val, self.z_val 128 | 129 | all_stacked = torch.cat([ 130 | make_grid(x_val, nrow=x_val.shape[0], normalize=True, scale_each=True).cpu(), 131 | make_grid(z_val, nrow=x_val.shape[0], normalize=True,scale_each=True).cpu(), 132 | make_grid(x_mmse_val, nrow=self.x_val.shape[0], normalize=True, scale_each=True).cpu()], dim=1) 133 | 134 | else: # check on random input data 135 | x_mmse = z + used_sigmas**2*scores 136 | all_stacked = torch.cat([ 137 | make_grid(x, nrow=x.shape[0], normalize=True, scale_each=True).cpu(), 138 | make_grid(z, nrow=x.shape[0], normalize=True,scale_each=True).cpu(), 139 | make_grid(x_mmse, nrow=x.shape[0], normalize=True, scale_each=True).cpu()], dim=1) 140 | 141 | # plot clean, noisy, and denoised (using Tweedie's formula) 
142 | if step % 100 == 0: 143 | writer.add_image(f'training', all_stacked, global_step=epoch) 144 | writer.flush() 145 | scorenet.train() 146 | 147 | # weighted denoising score-matching loss: |target - scores|^loss_power, weighted by sigma^anneal_power 148 | loss_batches = ((torch.abs(target - scores))**self.loss_power).sum((-3,-2,-1))*used_sigmas.squeeze()**self.anneal_power # NOTE: keep anneal_power equal to loss_power (e.g. 1 for an L1 loss) 149 | loss = loss_batches.mean() 150 | 151 | optim.zero_grad() 152 | loss.backward() 153 | 154 | if isinstance(clip,float): 155 | torch.nn.utils.clip_grad_norm_(scorenet.parameters(), max_norm=clip, norm_type='inf') 156 | grad_norms_epoch.append(self.get_grad_norm(scorenet)) 157 | 158 | optim.step() 159 | epoch_loss += loss.item() * len(x) / len(dl.dataset) 160 | 161 | 162 | log_string = f"Loss at epoch {epoch + 1}: {epoch_loss:.8f}" 163 | if epoch % 50 == 0: 164 | writer.add_scalar(f'train/epoch_loss', epoch_loss, epoch) 165 | 166 | writer.add_scalar(f'train/epoch_max_grad', np.asarray(grad_norms_epoch).max(), epoch) 167 | writer.add_scalar(f'train/epoch_mean_grad', np.asarray(grad_norms_epoch).mean(), epoch) 168 | 169 | # Storing the model 170 | if epoch % 5000 == 0: # additionally store a snapshot every 5000 epochs, regardless of the loss 
171 | checkpoint = {'state_dict': scorenet.state_dict()} 172 | simulation.save_pytorch(checkpoint, overwrite=False, subdir='models_sanity', epoch='_'+'{0:07}'.format(epoch)) 173 | 174 | if best_loss > epoch_loss: 175 | best_loss = epoch_loss 176 | 177 | # keep at most 3 best checkpoints: delete the oldest before storing the new one 178 | if epoch > 0: 179 | cp_dir = simulation._outdir + '/models' 180 | if len([name for name in os.listdir(cp_dir) if os.path.isfile(os.path.join(cp_dir,name))]) == 3: 181 | fnames = sorted([fname for fname in os.listdir(cp_dir) if fname.endswith('.pt')]) 182 | os.remove(os.path.join(cp_dir,fnames[0])) 183 | checkpoint = {'epoch': epoch, 184 | 'state_dict': scorenet.state_dict(), 185 | 'optimizer': optim.state_dict()} 186 | simulation.save_pytorch(checkpoint, overwrite=False, epoch='_'+'{0:07}'.format(epoch)) 187 | log_string += " --> Best model ever (stored)" 188 | 189 | print(log_string) 190 | 191 | def _do_vp(self, scorenet, dl, n_epochs, clip, optim, device, simulation, writer, img_cond, class_label_cond=False): 192 | best_loss = float("inf") 193 | mse = nn.MSELoss() 194 | 195 | for epoch in tqdm(range(n_epochs), desc=f"Training progress", colour="#00ff00"): 196 | epoch_loss = 0.0 197 | grad_norms_epoch = [] 198 | 199 | for step, batch in enumerate(tqdm(dl, leave=False, desc=f"Epoch {epoch + 1}/{n_epochs}", colour="#005500")): 200 | # Loading data 201 | m = batch['image'].to(device) 202 | m *= 5.
# TODO: comment out this line if a data loader *without* scaling to [-0.2,0.2] is used; the factor 5 maps the mask to [-1,1] like the image 203 | x = None if img_cond==0 else batch['mask'].to(device) 204 | lbl = None if class_label_cond is False else batch['label'].to(device).unsqueeze(1) 205 | n = len(m) 206 | 207 | # noise corruption 208 | t = torch.randint(0, self.T, (n,)).to(device) 209 | a_bar = self.alpha_bars.to(device)[t] 210 | eta = torch.randn_like(m) # noise must match the mask m, not the conditioning image x 211 | m_noisy = a_bar.sqrt().reshape(n, 1, 1, 1) * m + (1 - a_bar).sqrt().reshape(n, 1, 1, 1) * eta 212 | 213 | # predict the injected noise (epsilon parameterization) 214 | if self.model_type == 'unet': 215 | eta_estimated = scorenet(m_noisy, t.reshape(n,-1), img_cond=x, class_lbl=lbl) 216 | 217 | elif self.model_type == 'uvit': 218 | eta_estimated = scorenet(m_noisy, t.reshape(n,-1), img_cond=x) 219 | 220 | elif self.model_type == 'tdv': 221 | raise NotImplementedError 222 | 223 | # MSE between the injected noise and the predicted noise 224 | loss = mse(eta_estimated, eta) 225 | optim.zero_grad() 226 | loss.backward() 227 | 228 | if isinstance(clip,float): 229 | torch.nn.utils.clip_grad_norm_(scorenet.parameters(), max_norm=clip, norm_type='inf') 230 | grad_norms_epoch.append(self.get_grad_norm(scorenet)) 231 | 232 | optim.step() 233 | epoch_loss += loss.item() * len(m) / len(dl.dataset) 234 | 235 | log_string = f"Loss at epoch {epoch + 1}: {epoch_loss:.8f}" 236 | if epoch % 10 == 0: 237 | writer.add_scalar(f'train/epoch_loss', epoch_loss, epoch) 238 | writer.add_scalar(f'train/epoch_max_grad', np.asarray(grad_norms_epoch).max(), epoch) 239 | writer.add_scalar(f'train/epoch_mean_grad', np.asarray(grad_norms_epoch).mean(), epoch) 240 | 241 | if best_loss > epoch_loss: 242 | best_loss = epoch_loss 243 | # keep at most 3 best checkpoints: delete the oldest before storing the new one 244 | if epoch > 0: 245 | cp_dir = simulation._outdir + '/models' 246 | if len([name for name in os.listdir(cp_dir) if os.path.isfile(os.path.join(cp_dir,name))]) == 3: 247 | 
fnames = sorted([fname for fname in os.listdir(cp_dir) if fname.endswith('.pt')]) 248 | os.remove(os.path.join(cp_dir,fnames[0])) 249 | checkpoint = {'epoch': epoch, 250 | 'state_dict': scorenet.state_dict(), 251 | 'optimizer': optim.state_dict()} 252 | simulation.save_pytorch(checkpoint, overwrite=False, epoch='_'+'{0:07}'.format(epoch)) 253 | log_string += " --> Best model ever (stored)" 254 | print(log_string) --------------------------------------------------------------------------------
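The sampling script loads its YAML config and converts it to nested attribute-access objects by round-tripping it through JSON with an `object_hook`, which is why code such as `cfg.SMLD.sde` works. Below is a minimal stdlib-only sketch of that trick. The dictionary stands in for `yaml.safe_load(file)`, and its keys and values are illustrative only, not the repository's actual config.

```python
import json
from types import SimpleNamespace


def to_namespace(cfg_dict):
    """Recursively turn nested dicts into SimpleNamespace objects.

    json.loads calls object_hook on every decoded JSON object, innermost first,
    so each nested dict becomes a SimpleNamespace. Lists stay plain lists.
    """
    return json.loads(json.dumps(cfg_dict), object_hook=lambda d: SimpleNamespace(**d))


# Stand-in for yaml.safe_load(file); keys and values are illustrative only.
yaml_cfg = {
    "general": {"modality": "monuseg", "img_cond": 1, "corr_mode": "diffusion"},
    "SMLD": {"sde": "ve", "sigma_1_m": 1.0, "sigma_L_m": 0.01, "n_steps": 10},
    "model": {"type": "unet", "mults": [1, 2, 4]},
}

cfg = to_namespace(yaml_cfg)
print(cfg.SMLD.sde, cfg.model.mults)  # ve [1, 2, 4]
```

A JSON round trip is used rather than a hand-written recursive conversion because `object_hook` already visits every nested mapping. The cost is that only JSON-serializable values survive (for example, YAML dates would fail in `json.dumps`).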
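For the variance-exploding (`ve`) branch, `TrainScoreNetwork` builds geometrically spaced noise levels, corrupts the SDF with Gaussian noise, regresses the score onto the target `(x - z) / sigma**2`, and visualizes denoised samples with Tweedie's formula. The NumPy sketch below re-expresses those pieces so that they can be checked in isolation. The helper names (`geometric_sigmas`, `continuous_sigma`, `dsm_loss`, `tweedie_denoise`) are hypothetical, and the random arrays stand in for SDF patches. If the "score" is set equal to the target, the loss is exactly zero and Tweedie's formula recovers the clean input.

```python
import numpy as np


def geometric_sigmas(s1, sL, L):
    # L noise levels evenly spaced in log-space from s1 down to sL (as in __init__ for 've').
    return np.exp(np.linspace(np.log(s1), np.log(sL), L)).astype(np.float32)


def continuous_sigma(t, sigmas):
    # 'cont' objective: sigma(t) = sigma_L * (sigma_1 / sigma_L) ** t for t in (0, 1].
    return sigmas[-1] * (sigmas[0] / sigmas[-1]) ** t


def dsm_loss(x, z, sigma, scores, loss_power=2, anneal_power=2):
    # Denoising score-matching target for z = x + sigma * eta:
    # grad_z log N(z; x, sigma^2 I) = (x - z) / sigma^2 = -eta / sigma.
    target = (x - z) / sigma**2
    per_sample = (np.abs(target - scores) ** loss_power).sum(axis=(-3, -2, -1))
    # Weight each sample by sigma**anneal_power so all noise levels contribute comparably.
    return (per_sample * sigma.reshape(-1) ** anneal_power).mean()


def tweedie_denoise(z, sigma, scores):
    # Tweedie's formula: E[x | z] = z + sigma^2 * score(z).
    return z + sigma**2 * scores


rng = np.random.default_rng(0)
sigmas = geometric_sigmas(s1=1.0, sL=0.01, L=10)
x = rng.standard_normal((4, 1, 8, 8)).astype(np.float32)          # stand-in SDF patches
sigma = sigmas[rng.integers(0, len(sigmas), size=4)][:, None, None, None]
eta = rng.standard_normal(x.shape).astype(np.float32)
z = x + sigma * eta
oracle_scores = (x - z) / sigma**2   # a "perfect" score, for checking only

t = rng.uniform(1e-5, 1.0, size=4)
sigma_cont = continuous_sigma(t, sigmas)  # continuous-time noise levels in [sigma_L, sigma_1]

print(dsm_loss(x, z, sigma, oracle_scores))                                   # 0.0
print(np.allclose(tweedie_denoise(z, sigma, oracle_scores), x, atol=1e-4))    # True
```

The target has magnitude on the order of `1/sigma`, so `|target - score|**loss_power` scales roughly like `sigma**(-loss_power)`. Multiplying by `sigma**anneal_power` with `anneal_power == loss_power` keeps every noise level on a comparable scale. This matches the `NOTE` next to the loss in `_do_ve`.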
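For the variance-preserving (`vp`) branch, the trainer builds a linear beta schedule, takes cumulative products of `alphas = 1 - betas`, and noises the mask in closed form before regressing the injected noise with an MSE loss. The NumPy sketch below assumes the schedule values hard-coded in `__init__` (`1e-4`, `0.02`, `1000` steps). The names `vp_schedule` and `vp_forward` are hypothetical. `np.cumprod` yields the same `alpha_bars` as the per-index `np.prod` loop in a single pass.

```python
import numpy as np


def vp_schedule(beta1, betaT, T):
    # Linear beta schedule, as in TrainScoreNetwork.__init__ for sde == 'vp'.
    betas = np.linspace(beta1, betaT, T)
    alphas = 1.0 - betas
    # alpha_bar_t = prod_{s<=t} alpha_s; np.cumprod computes all T prefixes at once.
    alpha_bars = np.cumprod(alphas)
    return betas, alpha_bars


def vp_forward(m0, t, alpha_bars, rng):
    # Closed-form sample m_t ~ N(sqrt(a_bar_t) * m0, (1 - a_bar_t) * I).
    a_bar = alpha_bars[t].reshape(-1, 1, 1, 1)
    # The noise must have the shape of the mask m0, not of the conditioning image.
    eta = rng.standard_normal(m0.shape)
    m_t = np.sqrt(a_bar) * m0 + np.sqrt(1.0 - a_bar) * eta
    return m_t, eta


rng = np.random.default_rng(0)
betas, alpha_bars = vp_schedule(beta1=1e-4, betaT=0.02, T=1000)
m0 = rng.uniform(-1.0, 1.0, size=(2, 1, 16, 16))   # stand-in masks scaled to [-1, 1]
t = rng.integers(0, len(alpha_bars), size=2)        # one random timestep per sample
m_t, eta = vp_forward(m0, t, alpha_bars, rng)
# A network would then be trained with MSE(eta_estimated, eta), as in _do_vp.
```

Because the timestep `t` indexes `alpha_bars` directly, any `T` drawn in the training loop must be smaller than the schedule length.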
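Both training loops keep only a few best-model checkpoints. They delete the oldest `.pt` file before writing a new one and presumably rely on the zero-padded epoch suffix (`'{0:07}'.format(epoch)`) so that lexicographic order matches epoch order. The stdlib-only sketch below shows that rotation. `save_with_rotation` is a hypothetical helper that writes raw bytes where the trainer calls `simulation.save_pytorch`, and the file-name pattern is an assumption.

```python
import os
import tempfile


def save_with_rotation(cp_dir, epoch, payload, keep=3):
    """Hypothetical helper: write checkpoint_<epoch>.pt and keep only the newest `keep` files."""
    os.makedirs(cp_dir, exist_ok=True)
    # Zero-padded epochs make sorted() list the oldest checkpoint first.
    existing = sorted(f for f in os.listdir(cp_dir) if f.endswith(".pt"))
    # Remove as many of the oldest files as needed to make room for the new one.
    for fname in existing[: max(0, len(existing) - keep + 1)]:
        os.remove(os.path.join(cp_dir, fname))
    path = os.path.join(cp_dir, "checkpoint_{0:07}.pt".format(epoch))
    with open(path, "wb") as fh:  # stand-in for torch.save / simulation.save_pytorch
        fh.write(payload)
    return path


with tempfile.TemporaryDirectory() as tmp:
    for ep in [1, 5, 12, 40, 300]:
        save_with_rotation(tmp, ep, b"weights")
    remaining = sorted(os.listdir(tmp))
```

Unlike the `== 3` check in the trainer, which counts every regular file in the folder, this sketch counts only `.pt` files and deletes as many as necessary. Stray files or leftovers from an earlier run therefore cannot stall the rotation.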