├── .gitignore
├── .gitmodules
├── README.md
├── __init__.py
├── assets
│   ├── process_sde.png
│   └── sampling_sdf.gif
├── cfg
│   ├── __init__.py
│   ├── glas.yaml
│   └── monuseg.yaml
├── datasets
│   ├── __init__.py
│   ├── config_dl.py
│   ├── glas_dataset.py
│   ├── monuseg_dataset.py
│   └── transform_factory.py
├── env.yml
├── main.py
├── models
│   ├── __init__.py
│   └── ddpm.py
├── preprocess_data
│   └── precompute_sdf.ipynb
├── sampler.py
└── trainer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.directory
2 | *pycache*
3 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "SimulationHelper"]
2 | path = SimulationHelper
3 | url = https://github.com/f-ilic/SimulationHelper.git
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | **Score-Based Generative Models for Medical Image Segmentation using Signed Distance Functions**
3 | GCPR 2023
4 | Lea Bogensperger, Dominik Narnhofer, Filip Ilic, Thomas Pock
5 |
6 | ---
7 |
8 | [[Project Page]](https://github.com/leabogensperger/generative-segmentation-sdf)
9 | [[Paper]](https://arxiv.org/abs/2303.05966)
10 |
11 |
12 | Environment Setup:
13 | ```bash
14 | git clone --recurse-submodules git@github.com:leabogensperger/generative-segmentation-sdf.git
15 | conda env create -f env.yml
16 | conda activate generative_segmentation_sdf
17 | ```
18 |
19 | # Score-Based Generative Models for Medical Image Segmentation using Signed Distance Functions
20 |
21 | This repository contains the code to train a generative model that learns the conditional distribution of implicit segmentation masks, represented as signed distance functions (SDFs), conditioned on a specific input image. The generative model is a score-based diffusion model that was originally set up with a variance-exploding (VE) scheme. Later experiments showed that the variance-preserving (VP) scheme appears numerically somewhat more stable in this setting, so it is now included as well: set the parameter *sde* in the *SMLD* section of the config file to either *ve* or *vp*.
22 |
23 |
24 |
25 | # Instructions
26 |
27 | 1) Train by specifying a config file:
28 | ```bash
29 | python main.py --config "cfg/monuseg.yaml"
30 | ```
31 |
32 | 2) Sample (set experiment folder in config file):
33 | ```bash
34 | python sampler.py --config "cfg/monuseg.yaml"
35 | ```
36 |
37 | Note: the pre-processed data sets will be uploaded later. The data set is specified in the config file, which also sets the root directory (*data_path*). The root directory must contain csv files for train and test mode that list all pre-processed patches with the columns *filename* and *maskname* (the dataset classes additionally read a *maskdtname* column when *corr_mode* is *diffusion_ls*). Moreover, it must contain the folders *Trainig_patches* and *Test_patches*, which include for each patch a .png file of the input image and a .npy file of the SDF-transformed segmentation mask.
38 |
39 | # Sampling
40 |
41 | The sampling process of the proposed approach is shown using the predictor-corrector sampling algorithm (see Algorithm 1 in the paper).
42 | In the top row there are four different condition images and the center row contains the generated/predicted SDF masks.
43 | Further, the bottom row displays the corresponding binary masks, which are obtained only indirectly from thresholding the predicted SDF masks.
44 |
45 |
46 |
47 | # Cite
48 |
49 | ```bibtex
50 | @misc{
51 | bogensperger2023scorebased,
52 | title={Score-Based Generative Models for Medical Image Segmentation using Signed Distance Functions},
53 | author={Lea Bogensperger and Dominik Narnhofer and Filip Ilic and Thomas Pock},
54 | year={2023},
55 | eprint={2303.05966},
56 | archivePrefix={arXiv},
57 | primaryClass={cs.CV}
58 | }
59 | ```
--------------------------------------------------------------------------------
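The data layout described in the README note can be checked before training starts. The following is a minimal sketch and not part of the repository: `missing_columns`, `REQUIRED_COLUMNS`, `SDF_COLUMN` and the example row are hypothetical. It assumes the column names given in the README (*filename*, *maskname*) plus *maskdtname*, which the dataset classes read when *corr_mode* is *diffusion_ls*.

```python
import csv
import io

REQUIRED_COLUMNS = {"filename", "maskname"}  # named in the README
SDF_COLUMN = "maskdtname"  # read by the dataset classes for corr_mode == 'diffusion_ls'


def missing_columns(csv_text, use_sdf=True):
    """Return the sorted list of expected columns that the CSV header lacks."""
    reader = csv.DictReader(io.StringIO(csv_text))
    header = set(reader.fieldnames or [])
    expected = REQUIRED_COLUMNS | ({SDF_COLUMN} if use_sdf else set())
    return sorted(expected - header)


# Hypothetical index file with a single patch; the paths are made up.
example = (
    "filename,maskname,maskdtname\n"
    "/patches/a.png,/patches/a_mask.png,/patches/a_sdf.npy\n"
)
print(missing_columns(example))
```

Note that the dataset classes build file paths as `data_path + filename` by plain string concatenation, so the csv entries presumably need to start with a path separator when `data_path` has no trailing slash.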
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leabogensperger/generative-segmentation-sdf/3b8037e3e03d347a41f361f6ba844e6928240200/__init__.py
--------------------------------------------------------------------------------
/assets/process_sde.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leabogensperger/generative-segmentation-sdf/3b8037e3e03d347a41f361f6ba844e6928240200/assets/process_sde.png
--------------------------------------------------------------------------------
/assets/sampling_sdf.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leabogensperger/generative-segmentation-sdf/3b8037e3e03d347a41f361f6ba844e6928240200/assets/sampling_sdf.gif
--------------------------------------------------------------------------------
/cfg/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leabogensperger/generative-segmentation-sdf/3b8037e3e03d347a41f361f6ba844e6928240200/cfg/__init__.py
--------------------------------------------------------------------------------
/cfg/glas.yaml:
--------------------------------------------------------------------------------
1 | general:
2 |   modality: glas
3 |   corr_mode: diffusion_ls # diffusion on level sets
4 |   img_cond: 1 # condition on image to obtain segmentation mask
5 |   data_path: "/home/lea/Data/GlaS_trunc"
6 |   csv_train: "train.csv"
7 |   csv_test: "test.csv"
8 |   batch_size: 32
9 |   sz: 128 #
10 |   resume_training: False
11 |   load_path: ''
12 |   class_label_cond: False
13 |   num_classes: 0
14 |   with_class_label_emb: False
15 |
16 | inference:
17 |   latest: False
18 |   load_exp: ''
19 |   n_samples: 4
20 |
21 | model:
22 |   type: 'unet'
23 |   n_cin: 1 # 1, n_classes
24 |   n_cin_cond: 1 # 1, 3 for gray/color valued input
25 |   n_fm: 10
26 |   dim: 128
27 |   embedding: 'sinusoidal'
28 |   mults:
29 |     - 1
30 |     - 2
31 |     - 4
32 |     - 4
33 |
34 | learning:
35 |   epochs: 500000
36 |   lr: 1.0E-4
37 |   loss: 2
38 |   n_val: 8
39 |   clip: 100000.
40 |   gpus:
41 |     - 1
42 |
43 | SMLD:
44 |   sde: 'vp' # VE used in GCPR paper, VP scheme is like classic DDPM
45 |   beta_1: 1.E-4 # default from DDPM
46 |   beta_T: 0.02 # default from DDPM
47 |   T: 1000 # default from DDPM
48 |   n_steps: 100 # VE params
49 |   sigma_1_m: 5. # VE params: heuristic
50 |   sigma_L_m: 0.001 # VE params: heuristic
51 |   objective: 'cont'
52 |   sampler: 'pc' # for VE scheme with reverse SDE
53 |   eps: 2.0E-5 # annealed Langevin
54 |   N: 200 # Predictor steps
55 |   M: 1 # Corrector steps
56 |   r: 0.15 # "snr" for PC sampling
--------------------------------------------------------------------------------
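`main.py` turns the parsed YAML into nested `SimpleNamespace` objects by round-tripping through `json` with an `object_hook`, which is why the code reads settings as `cfg.general.modality` rather than `cfg["general"]["modality"]`. The sketch below shows that conversion in isolation. The helper name `to_namespace` is hypothetical, and the dictionary is a hand-copied excerpt of the config above, standing in for the result of `yaml.safe_load`.

```python
import json
from types import SimpleNamespace


def to_namespace(nested):
    """Convert nested dicts into nested SimpleNamespace objects."""
    # json.loads calls object_hook for every decoded JSON object (dict),
    # innermost first, so each level becomes attribute-accessible.
    return json.loads(json.dumps(nested), object_hook=lambda d: SimpleNamespace(**d))


# Hand-copied excerpt; in the repository this dict comes from yaml.safe_load.
cfg = to_namespace({
    "general": {"modality": "glas", "sz": 128},
    "model": {"mults": [1, 2, 4, 4]},
    "SMLD": {"sde": "vp", "T": 1000},
})
print(cfg.general.modality, cfg.SMLD.sde, tuple(cfg.model.mults))
```

Lists are left as plain Python lists by this conversion, which matches the `tuple(cfg.model.mults)` call in `main.py`.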
/cfg/monuseg.yaml:
--------------------------------------------------------------------------------
1 | general:
2 |   modality: monuseg
3 |   corr_mode: diffusion_ls # diffusion on level sets
4 |   img_cond: 1 # condition on image to obtain segmentation mask
5 |   data_path: "/home/lea/Data/MonuSeg_spcn_trunc"
6 |   csv_train: "train.csv"
7 |   csv_test: "test.csv"
8 |   batch_size: 32
9 |   sz: 128 # 128, 256
10 |   resume_training: False
11 |   load_path: ''
12 |   class_label_cond: False
13 |   num_classes: 0
14 |   with_class_label_emb: False
15 |
16 | inference:
17 |   latest: False
18 |   load_exp: ''
19 |   n_samples: 4
20 |
21 | model:
22 |   type: 'unet'
23 |   n_cin: 1 # 1, n_classes
24 |   n_cin_cond: 1 # 1, 3 for gray/color valued input
25 |   n_fm: 10
26 |   dim: 128
27 |   embedding: 'sinusoidal'
28 |   mults:
29 |     - 1
30 |     - 2
31 |     - 4
32 |     - 4
33 |
34 | learning:
35 |   epochs: 500000
36 |   lr: 1.0E-4
37 |   loss: 2
38 |   n_val: 8
39 |   clip: 40000.
40 |   gpus:
41 |     - 1
42 |
43 | SMLD:
44 |   sde: 'vp' # VE used in GCPR paper, VP scheme is like classic DDPM
45 |   beta_1: 1.E-4 # default from DDPM
46 |   beta_T: 0.02 # default from DDPM
47 |   T: 1000 # default from DDPM
48 |   n_steps: 100 # VE params
49 |   sigma_1_m: 5. # VE params: heuristic
50 |   sigma_L_m: 0.001 # VE params: heuristic
51 |   objective: 'cont'
52 |   sampler: 'pc' # for VE scheme with reverse SDE
53 |   eps: 2.0E-5 # annealed Langevin
54 |   N: 200 # Predictor steps
55 |   M: 1 # Corrector steps
56 |   r: 0.15 # "snr" for PC sampling
57 |
--------------------------------------------------------------------------------
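To give a feel for the *SMLD* values above, the sketch below turns them into discrete noise schedules. This is an illustration under assumptions, not the repository's method: `trainer.py` is not shown here, and with `objective: 'cont'` the actual training may well use continuous-time formulations. The linear beta schedule follows the usual DDPM convention, and the geometric sigma sequence follows the usual SMLD convention. All function names are hypothetical.

```python
def linear_betas(beta_1=1e-4, beta_T=0.02, T=1000):
    """Linearly spaced beta_t for t = 1..T (assumed DDPM-style VP schedule)."""
    step = (beta_T - beta_1) / (T - 1)
    return [beta_1 + i * step for i in range(T)]


def alpha_bars(betas):
    """Cumulative products of (1 - beta_t): the squared signal scale at step t."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out


def geometric_sigmas(sigma_1=5.0, sigma_L=1e-3, L=100):
    """Geometrically decreasing noise levels (assumed SMLD-style VE schedule)."""
    ratio = (sigma_L / sigma_1) ** (1.0 / (L - 1))
    return [sigma_1 * ratio ** i for i in range(L)]


betas = linear_betas()       # beta_1, beta_T, T from the config
abar = alpha_bars(betas)
sigmas = geometric_sigmas()  # sigma_1_m, sigma_L_m, n_steps from the config

print(f"VP: alpha_bar after the last step = {abar[-1]:.2e}")
print(f"VE: sigma range = [{sigmas[-1]:.1e}, {sigmas[0]:.1e}]")
```

Under these assumptions the VP scheme keeps the perturbed data at bounded variance while the signal fraction decays towards zero, whereas the VE scheme adds noise of growing magnitude up to `sigma_1_m` without rescaling the signal.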
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leabogensperger/generative-segmentation-sdf/3b8037e3e03d347a41f361f6ba844e6928240200/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/config_dl.py:
--------------------------------------------------------------------------------
1 | from types import SimpleNamespace
2 | from torch.utils.data.dataloader import DataLoader
3 | import yaml
4 | import json
5 | from os.path import join, isfile
6 |
7 | from datasets.transform_factory import inv_normalize, transform_factory
8 | from datasets.monuseg_dataset import MoNuSegDataset
9 | from datasets.glas_dataset import GlaSDataset
10 |
11 |
12 | def config_dl(cfg):
13 |     if cfg.general.modality == 'monuseg':
14 |         DatasetType = MoNuSegDataset
15 |         stats = {'mean': 0., 'std': 1.}
16 |
17 |     elif cfg.general.modality == 'glas':
18 |         DatasetType = GlaSDataset
19 |         stats = {'mean': 0., 'std': 1.}
20 |
21 |     else:
22 |         raise ValueError('Unknown modality %s specified!' % cfg.general.modality)
23 |
24 |     train_dataset = DatasetType(cfg.general.data_path, f'{cfg.general.data_path}/{cfg.general.csv_train}', cfg=cfg.general)
25 |     test_dataset = DatasetType(cfg.general.data_path, f'{cfg.general.data_path}/{cfg.general.csv_test}', cfg=cfg.general)
26 |
27 |     train_dataloader = DataLoader(train_dataset, batch_size=cfg.general.batch_size, shuffle=True, drop_last=True)
28 |     test_dataloader = DataLoader(test_dataset, batch_size=cfg.inference.n_samples, shuffle=False, drop_last=False)
29 |
30 |     dbdict = {"train_dl": train_dataloader, "test_dl": test_dataloader}
31 |
32 |     train_dl, test_dl = dbdict["train_dl"], dbdict["test_dl"]
33 |     tfdict = transform_factory(cfg.general)
34 |     T_train, T_test = tfdict["train"](stats["mean"], stats["std"]), tfdict["test"](
35 |         stats["mean"], stats["std"]
36 |     )
37 |     train_dl.dataset.transform, test_dl.dataset.transform = T_train, T_test
38 |     train_dl.inv_normalize, test_dl.inv_normalize = inv_normalize(
39 |         stats["mean"], stats["std"]
40 |     ), inv_normalize(stats["mean"], stats["std"])
41 |
42 |     return train_dl, test_dl
43 |
--------------------------------------------------------------------------------
/datasets/glas_dataset.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageOps
2 | import numpy as np
3 | import cv2
4 | import pandas as pd
5 | import matplotlib.pyplot as plt
6 | from types import SimpleNamespace
7 |
8 | import torch
9 | from torch.utils.data import Dataset, DataLoader
10 | from datasets.transform_factory import transform_factory
11 |
12 | class GlaSDataset(Dataset):
13 |     def __init__(self, data_path, csv_file, cfg):
14 |         self.data_path = data_path
15 |         self.csv_file = csv_file
16 |         self.data = pd.read_csv(self.csv_file)
17 |
18 |         self.transform = None
19 |         self.inv_normalize = None
20 |
21 |         self.corr_mode = cfg.corr_mode
22 |         self.img_cond = cfg.img_cond
23 |         self.sz = cfg.sz
24 |
25 |     def __len__(self):
26 |         return len(self.data)
27 |
28 |     def __getitem__(self, idx):
29 |         img_path = self.data_path + self.data.loc[idx]['filename']
30 |         mask_path = self.data_path + self.data.loc[idx]['maskname']
31 |
32 |         # load image and mask
33 |         img = cv2.imread(img_path,0).astype(np.float32)/255.
34 |
35 |         # load level sets mask
36 |         if self.corr_mode == 'diffusion_ls':
37 |             mask_ls_path = self.data_path + self.data.loc[idx]['maskdtname']
38 |             mask = np.load(mask_ls_path).astype(np.float32)
39 |         else:
40 |             mask = cv2.imread(mask_path,0).astype(np.float32)
41 |             mask = mask/255.
42 |
43 |         if self.corr_mode == 'diffusion':
44 |             corr_type = 0
45 |         else:
46 |             corr_type = 1
47 |
48 |         transform_cfg = {
49 |             'hflip': np.random.rand(),
50 |             'vflip': np.random.rand(),
51 |             'corr_type': corr_type, # 0 is diffusion
52 |             'img_cond': self.img_cond,
53 |         }
54 |
55 |         if self.img_cond: # condition on image
56 |             ret = {'image': mask, 'mask': img, 'name': str(img_path.split('/')[-1][:-4])}
57 |         else:
58 |             ret = {'image': img, 'mask': mask, 'name': str(img_path.split('/')[-1][:-4])}
59 |
60 |         self.transform(transform_cfg)(ret)
61 |
62 |         if self.img_cond and self.corr_mode == 'diffusion_ls':
63 |             if 'trunc' in self.data_path:
64 |                 ret['image'] /= 5.
65 |             else:
66 |                 ret['image'] *= 1.6
67 |         return ret
--------------------------------------------------------------------------------
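In `__getitem__` above, the random numbers for `hflip` and `vflip` are drawn once per sample and handed to the transform, so the same flip decision is applied to both the `image` and the `mask` entry. The sketch below isolates that pattern on nested lists instead of tensors. `make_paired_flip` is a hypothetical helper, and the 0.5 threshold mirrors the closures in `transform_factory.py`.

```python
import random


def make_paired_flip(p_hflip, p_vflip):
    """Build one flip function from pre-drawn random numbers."""
    def apply(grid):
        if p_hflip > 0.5:  # horizontal flip: reverse every row
            grid = [row[::-1] for row in grid]
        if p_vflip > 0.5:  # vertical flip: reverse the row order
            grid = grid[::-1]
        return grid
    return apply


# Draw the decisions once, then reuse them for both entries of a sample.
flip = make_paired_flip(random.random(), random.random())
image = [[1, 2], [3, 4]]
mask = [[10, 20], [30, 40]]
flipped_image, flipped_mask = flip(image), flip(mask)

# Whatever was drawn, image and mask stay pixel-aligned.
print(flipped_image, flipped_mask)
```

Drawing the random numbers inside each per-key transform instead would flip the image and its mask independently and break their correspondence.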
/datasets/monuseg_dataset.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageOps
2 | import numpy as np
3 | import cv2
4 | import pandas as pd
5 | import matplotlib.pyplot as plt
6 | from types import SimpleNamespace
7 |
8 | import torch
9 | from torch.utils.data import Dataset, DataLoader
10 | from datasets.transform_factory import transform_factory
11 |
12 | class MoNuSegDataset(Dataset):
13 |     def __init__(self, data_path, csv_file, cfg):
14 |         self.data_path = data_path
15 |         self.csv_file = csv_file
16 |         self.data = pd.read_csv(self.csv_file)
17 |
18 |         self.transform = None
19 |         self.inv_normalize = None
20 |
21 |         self.corr_mode = cfg.corr_mode
22 |         self.img_cond = cfg.img_cond
23 |         self.sz = cfg.sz
24 |
25 |     def __len__(self):
26 |         return len(self.data)
27 |
28 |     def __getitem__(self, idx):
29 |         img_path = self.data_path + self.data.loc[idx]['filename']
30 |         mask_path = self.data_path + self.data.loc[idx]['maskname']
31 |
32 |         # load image and mask
33 |         if 'rgb' in self.data_path:
34 |             img = cv2.imread(img_path).astype(np.float32)/255.
35 |         else:
36 |             img = cv2.imread(img_path,0).astype(np.float32)/255.
37 |
38 |         # load level sets mask
39 |         if self.corr_mode == 'diffusion_ls':
40 |             mask_ls_path = self.data_path + self.data.loc[idx]['maskdtname']
41 |             mask = np.load(mask_ls_path)
42 |         else:
43 |             mask = cv2.imread(mask_path,0).astype(np.float32)
44 |             mask[mask > 200] = 255.
45 |             mask[mask <= 200] = 0.
46 |             # mask to [0,1]
47 |             mask = mask/255.
48 |
49 |         if self.corr_mode == 'diffusion':
50 |             corr_type = 0
51 |         else:
52 |             corr_type = 1
53 |
54 |         transform_cfg = {
55 |             'hflip': np.random.rand(),
56 |             'vflip': np.random.rand(),
57 |             'corr_type': corr_type, # 0 is diffusion
58 |             'img_cond': self.img_cond,
59 |         }
60 |
61 |         if self.img_cond: # condition on image
62 |             ret = {'image': mask, 'mask': img, 'name': str(img_path.split('/')[-1][:-4])}
63 |         else:
64 |             ret = {'image': img, 'mask': mask, 'name': str(img_path.split('/')[-1][:-4])}
65 |
66 |         self.transform(transform_cfg)(ret)
67 |
68 |         if self.img_cond and self.corr_mode == 'diffusion_ls':
69 |             if 'trunc' in self.data_path:
70 |                 ret['image'] /= 5.
71 |             else:
72 |                 ret['image'] *= 10./3. # heuristic from intensity distribution of histogram
73 |
74 |         return ret
75 |
--------------------------------------------------------------------------------
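When `corr_mode` is `diffusion_ls` and the data path contains `trunc`, the dataset divides the SDF by 5 before it is used as the diffusion target, and the README states that binary masks are obtained by thresholding the predicted SDF. The sketch below illustrates both steps on nested lists. It is an illustration only: the helper names are hypothetical, and the sign convention (negative values inside the object, threshold at zero) is an assumption, since the preprocessing notebook is not shown here.

```python
SDF_SCALE = 5.0  # divisor used by the dataset classes for truncated SDFs


def rescale_sdf(sdf, scale=SDF_SCALE):
    """Divide every SDF value by a positive constant, as the dataset does."""
    return [[value / scale for value in row] for row in sdf]


def sdf_to_binary(sdf, threshold=0.0):
    """Assumed convention: values below the threshold lie inside the object."""
    return [[1 if value < threshold else 0 for value in row] for row in sdf]


# Tiny hand-made SDF patch (made-up values).
sdf = [[2.0, 0.5, -1.0],
       [1.0, -0.5, -3.0]]
scaled = rescale_sdf(sdf)
mask = sdf_to_binary(scaled)
print(mask)
```

Because the scale factor is positive, the zero level set is unchanged, so thresholding the rescaled SDF at zero yields the same mask as thresholding the original one.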
/datasets/transform_factory.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Dict, List, Optional, Tuple
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from torchvision.transforms import Compose, Lambda, transforms, InterpolationMode, CenterCrop
7 | from torchvision.transforms.functional import crop, hflip, vflip, equalize
8 |
9 | class ApplyTransformToKey:
10 |     def __init__(self, key: str, transform: Callable):
11 |         self._key = key
12 |         self._transform = transform
13 |
14 |     def __call__(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
15 |         x[self._key] = self._transform(x[self._key])
16 |         return x
17 |
18 | # ----------------------------------------------------------------
19 |
20 |
21 | def train_glas_transform(mean, std):
22 |     def configured_transform(transform_config):
23 |         p_hflip = transform_config['hflip']
24 |         p_vflip = transform_config['vflip']
25 |         corr_type = transform_config['corr_type']
26 |         img_cond = transform_config['img_cond']
27 |
28 |         def hflip_closure(img):
29 |             return hflip(img) if p_hflip > 0.5 else img
30 |
31 |         def vflip_closure(img):
32 |             return vflip(img) if p_vflip > 0.5 else img
33 |
34 |         def normalization_mask(mask):
35 |             return (mask-0.5)*2 # normalize all to be in [-1,1] for guidance image
36 |
37 |         return Compose([
38 |             ApplyTransformToKey(
39 |                 key="image",
40 |                 transform=Compose([
41 |                     transforms.ToTensor(),
42 |                     hflip_closure,
43 |                     vflip_closure,
44 |                 ]),
45 |             ),
46 |
47 |             ApplyTransformToKey(
48 |                 key="mask",
49 |                 transform=Compose([
50 |                     transforms.ToTensor(),
51 |                     normalization_mask,
52 |                     hflip_closure,
53 |                     vflip_closure,
54 |                 ]),
55 |             ),
56 |         ])
57 |     return configured_transform
58 |
59 | def test_glas_transform(mean, std):
60 |     def configured_transform(transform_config):
61 |         def normalization_mask(mask):
62 |             return (mask-0.5)*2 # normalize all to be in [-1,1] for guidance image
63 |
64 |         return Compose([
65 |             ApplyTransformToKey(
66 |                 key="image",
67 |                 transform=Compose([
68 |                     transforms.ToTensor(),
69 |                 ]),
70 |             ),
71 |
72 |             ApplyTransformToKey(
73 |                 key="mask",
74 |                 transform=Compose([
75 |                     transforms.ToTensor(),
76 |                     normalization_mask,
77 |                 ]),
78 |             ),
79 |         ])
80 |     return configured_transform
81 |
82 | def train_monuseg_transform(mean, std):
83 |     def configured_transform(transform_config):
84 |         p_hflip = transform_config['hflip']
85 |         p_vflip = transform_config['vflip']
86 |         corr_type = transform_config['corr_type']
87 |         img_cond = transform_config['img_cond']
88 |
89 |         def normalization(img):
90 |             return (img-0.5)*2 if corr_type == 0 else img
91 |
92 |         def hflip_closure(img):
93 |             return hflip(img) if p_hflip > 0.5 else img
94 |
95 |         def vflip_closure(img):
96 |             return vflip(img) if p_vflip > 0.5 else img
97 |
98 |         def normalization(img):
99 |             return img # placeholder for different normalization procedure
100 |
101 |         def normalization_mask(mask):
102 |             return (mask-0.5)*2 # normalize all to be in [-1,1] for guidance image
103 |
104 |         interp_mode_img = InterpolationMode.NEAREST if img_cond == 1 else InterpolationMode.BILINEAR
105 |         interp_mode_mask = InterpolationMode.BILINEAR if img_cond == 1 else InterpolationMode.NEAREST
106 |
107 |         return Compose([
108 |             ApplyTransformToKey(
109 |                 key="image",
110 |                 transform=Compose([
111 |                     transforms.ToTensor(),
112 |                     hflip_closure,
113 |                     vflip_closure,
114 |                 ]),
115 |             ),
116 |
117 |             ApplyTransformToKey(
118 |                 key="mask",
119 |                 transform=Compose([
120 |                     transforms.ToTensor(),
121 |                     normalization_mask,
122 |                     hflip_closure,
123 |                     vflip_closure,
124 |                 ]),
125 |             ),
126 |         ])
127 |     return configured_transform
128 |
129 | def test_monuseg_transform(mean, std):
130 |     def configured_transform(transform_config):
131 |         p_hflip = transform_config['hflip']
132 |         p_vflip = transform_config['vflip']
133 |         corr_type = transform_config['corr_type']
134 |         img_cond = transform_config['img_cond']
135 |
136 |         def normalization(img):
137 |             return (img-0.5)*2 if corr_type == 0 else img
138 |
139 |         def normalization(img):
140 |             return img # placeholder for different normalization procedure
141 |
142 |         def normalization_mask(mask):
143 |             return (mask-0.5)*2 # normalize all to be in [-1,1] for guidance image
144 |
145 |         interp_mode_img = InterpolationMode.NEAREST if img_cond == 1 else InterpolationMode.BILINEAR
146 |         interp_mode_mask = InterpolationMode.BILINEAR if img_cond == 1 else InterpolationMode.NEAREST
147 |
148 |         return Compose([
149 |             ApplyTransformToKey(
150 |                 key="image",
151 |                 transform=Compose([
152 |                     transforms.ToTensor(),
153 |                 ]),
154 |             ),
155 |
156 |             ApplyTransformToKey(
157 |                 key="mask",
158 |                 transform=Compose([
159 |                     transforms.ToTensor(),
160 |                     normalization_mask,
161 |                 ]),
162 |             ),
163 |         ])
164 |     return configured_transform
165 |
166 | # ----------------------------------------------------------------
167 | def inv_normalize(mean, std):
168 |     return transforms.Normalize(mean=-mean/std, std=1/std)
169 |
170 | def transform_factory(cfg):
171 |     if cfg.modality == 'monuseg':
172 |         ret = {
173 |             'train': train_monuseg_transform,
174 |             'test' : test_monuseg_transform
175 |         }
176 |
177 |     elif cfg.modality == 'glas':
178 |         ret = {
179 |             'train': train_glas_transform,
180 |             'test': test_glas_transform
181 |         }
182 |
183 |     else:
184 |         raise ValueError('Unknown modality %s specified!' %cfg.modality)
185 |
186 |     return ret
187 |
--------------------------------------------------------------------------------
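`inv_normalize` above reuses `transforms.Normalize`, which computes `(x - mean) / std`, with the parameters `-mean/std` and `1/std`. Substituting these gives `(y + mean/std) * std = y * std + mean`, which undoes a normalization with `mean` and `std`. The plain-Python sketch below checks that arithmetic on a scalar. The function names and the values `0.5` and `0.25` are made up for illustration.

```python
def normalize(x, mean, std):
    """Scalar version of what transforms.Normalize computes per channel."""
    return (x - mean) / std


def inverse_params(mean, std):
    """Parameters that make `normalize` undo a normalization with (mean, std)."""
    return -mean / std, 1.0 / std


mean, std = 0.5, 0.25        # illustrative values only
x = 0.8
y = normalize(x, mean, std)  # forward normalization
inv_mean, inv_std = inverse_params(mean, std)
x_back = normalize(y, inv_mean, inv_std)  # same formula, inverse parameters

print(y, x_back)
```

With the statistics set in `config_dl.py` (`mean = 0`, `std = 1`), both the forward and the inverse normalization reduce to the identity.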
/env.yml:
--------------------------------------------------------------------------------
1 | name: generative_segmentation_sdf
2 | channels:
3 | - anaconda
4 | - pytorch
5 | - nvidia
6 | - conda-forge
7 | - defaults
8 | dependencies:
9 | - _libgcc_mutex=0.1=conda_forge
10 | - _openmp_mutex=4.5=1_gnu
11 | - _tflow_select=2.3.0=mkl
12 | - absl-py=0.13.0=pyhd8ed1ab_0
13 | - aiohttp=3.7.4.post0=py38h497a2fe_0
14 | - appdirs=1.4.4=pyh9f0ad1d_0
15 | - astor=0.8.1=pyh9f0ad1d_0
16 | - astunparse=1.6.3=pyhd8ed1ab_0
17 | - async-timeout=3.0.1=py_1000
18 | - attrs=21.2.0=pyhd8ed1ab_0
19 | - autopep8=1.5.7=pyhd8ed1ab_0
20 | - blas=1.0=mkl
21 | - blinker=1.4=py_1
22 | - brotlipy=0.7.0=py38h497a2fe_1001
23 | - bzip2=1.0.8=h7f98852_4
24 | - c-ares=1.17.1=h7f98852_1
25 | - ca-certificates=2022.9.24=ha878542_0
26 | - cachetools=4.2.2=pyhd8ed1ab_0
27 | - certifi=2022.9.24=pyhd8ed1ab_0
28 | - cffi=1.14.6=py38ha65f79e_0
29 | - chardet=3.0.4=py38h924ce5b_1008
30 | - click=8.0.1=py38h578d9bd_0
31 | - cloudpickle=2.0.0=pyhd8ed1ab_0
32 | - coverage=5.5=py38h497a2fe_0
33 | - cryptography=3.4.7=py38ha5dfef3_0
34 | - cudatoolkit=11.1.74=h6bb024c_0
35 | - cycler=0.10.0=py_2
36 | - cython=0.29.24=py38h709712a_0
37 | - cytoolz=0.11.0=py38h497a2fe_3
38 | - dask-core=2021.8.1=pyhd8ed1ab_0
39 | - dbus=1.13.18=hb2f20db_0
40 | - decorator=4.4.2=py_0
41 | - dominate=2.6.0=pyhd8ed1ab_0
42 | - einops=0.4.1=pyhd8ed1ab_0
43 | - expat=2.4.1=h9c3ff4c_0
44 | - ffmpeg=4.3=hf484d3e_0
45 | - fontconfig=2.13.1=h6c09931_0
46 | - freetype=2.10.4=h0708190_1
47 | - fsspec=2021.8.1=pyhd8ed1ab_0
48 | - gast=0.4.0=pyh9f0ad1d_0
49 | - gettext=0.19.8.1=h0b5b191_1005
50 | - glib=2.69.0=h5202010_0
51 | - gmp=6.2.1=h58526e2_0
52 | - gnutls=3.6.15=he1e5248_0
53 | - google-auth=1.33.0=pyh6c4a22f_0
54 | - google-auth-oauthlib=0.4.4=pyhd8ed1ab_0
55 | - google-pasta=0.2.0=pyh8c360ce_0
56 | - grpcio=1.36.1=py38hdd6454d_0
57 | - gst-plugins-base=1.14.0=h8213a91_2
58 | - gstreamer=1.14.0=h28cd5cc_2
59 | - h5py=2.10.0=nompi_py38h9915d05_106
60 | - hdf5=1.10.6=nompi_h7c3c948_1111
61 | - icu=58.2=hf484d3e_1000
62 | - idna=2.10=pyh9f0ad1d_0
63 | - imageio=2.9.0=py_0
64 | - importlib-metadata=3.10.0=py38h578d9bd_0
65 | - intel-openmp=2021.3.0=h06a4308_3350
66 | - joblib=1.0.1=pyhd8ed1ab_0
67 | - jpeg=9b=h024ee3a_2
68 | - keras-preprocessing=1.1.2=pyhd8ed1ab_0
69 | - kiwisolver=1.3.1=py38h1fd1430_1
70 | - krb5=1.19.2=hcc1bbae_0
71 | - lame=3.100=h7f98852_1001
72 | - lcms2=2.12=h3be6417_0
73 | - ld_impl_linux-64=2.35.1=hea4e1c9_2
74 | - libblas=3.9.0=11_linux64_mkl
75 | - libcblas=3.9.0=11_linux64_mkl
76 | - libcurl=7.78.0=h2574ce0_0
77 | - libedit=3.1.20191231=he28a2e2_2
78 | - libev=4.33=h516909a_1
79 | - libffi=3.3=h58526e2_2
80 | - libgcc-ng=9.3.0=h2828fa1_19
81 | - libgfortran-ng=7.5.0=h14aa051_19
82 | - libgfortran4=7.5.0=h14aa051_19
83 | - libgomp=9.3.0=h2828fa1_19
84 | - libiconv=1.15=h516909a_1006
85 | - libidn2=2.3.2=h7f98852_0
86 | - libllvm10=10.0.1=he513fc3_3
87 | - libnghttp2=1.43.0=h812cca2_0
88 | - libpng=1.6.37=h21135ba_2
89 | - libprotobuf=3.17.2=h780b84a_1
90 | - libssh2=1.9.0=ha56f1ee_6
91 | - libstdcxx-ng=9.3.0=h6de172a_19
92 | - libtasn1=4.16.0=h27cfd23_0
93 | - libtiff=4.2.0=h85742a9_0
94 | - libunistring=0.9.10=h7f98852_0
95 | - libuuid=1.0.3=h7f8727e_2
96 | - libuv=1.40.0=h7f98852_0
97 | - libwebp-base=1.2.0=h7f98852_2
98 | - libxcb=1.14=h7b6447c_0
99 | - libxml2=2.9.12=h03d6c58_0
100 | - libxslt=1.1.34=hc22bd24_0
101 | - libzlib=1.2.11=h36c2ea0_1013
102 | - llvmlite=0.36.0=py38h4630a5e_0
103 | - locket=0.2.1=py38h06a4308_1
104 | - lxml=4.8.0=py38h1f438cf_0
105 | - lz4-c=1.9.3=h9c3ff4c_1
106 | - markdown=3.3.4=pyhd8ed1ab_0
107 | - matplotlib=3.3.4=py38h578d9bd_0
108 | - matplotlib-base=3.3.4=py38h0efea84_0
109 | - mkl=2021.3.0=h06a4308_520
110 | - mkl-service=2.4.0=py38h497a2fe_0
111 | - mkl_fft=1.3.0=py38h42c9631_2
112 | - mkl_random=1.2.2=py38h1abd341_0
113 | - multidict=5.1.0=py38h497a2fe_1
114 | - mypy=0.910=py38h497a2fe_0
115 | - mypy_extensions=0.4.3=py38h578d9bd_4
116 | - ncurses=6.2=h58526e2_4
117 | - nettle=3.7.3=hbbd107a_1
118 | - networkx=2.6.3=pyhd8ed1ab_1
119 | - ninja=1.10.2=h4bd325d_0
120 | - numba=0.53.1=py38h8b71fd7_1
121 | - numpy=1.20.3=py38hf144106_0
122 | - numpy-base=1.20.3=py38h74d4b33_0
123 | - oauthlib=3.1.1=pyhd8ed1ab_0
124 | - olefile=0.46=pyh9f0ad1d_1
125 | - openh264=2.1.0=hd408876_0
126 | - openjpeg=2.3.0=hf38bd82_1003
127 | - openssl=1.1.1q=h7f8727e_0
128 | - opt_einsum=3.3.0=pyhd8ed1ab_1
129 | - packaging=21.0=pyhd8ed1ab_0
130 | - pandas=1.3.1=py38h1abd341_0
131 | - partd=1.2.0=pyhd8ed1ab_0
132 | - pcre=8.45=h9c3ff4c_0
133 | - pillow=8.3.1=py38h2c7a002_0
134 | - pip=21.1.3=pyhd8ed1ab_0
135 | - pooch=1.5.2=pyhd8ed1ab_0
136 | - protobuf=3.17.2=py38h709712a_0
137 | - psutil=5.8.0=py38h497a2fe_1
138 | - pyasn1=0.4.8=py_0
139 | - pyasn1-modules=0.2.8=py_0
140 | - pycodestyle=2.7.0=pyhd8ed1ab_0
141 | - pycparser=2.20=pyh9f0ad1d_2
142 | - pyjwt=2.1.0=pyhd8ed1ab_0
143 | - pyopenssl=20.0.1=pyhd8ed1ab_0
144 | - pyparsing=2.4.7=pyhd8ed1ab_1
145 | - pyqt=5.9.2=py38h05f1152_4
146 | - pysocks=1.7.1=py38h578d9bd_4
147 | - python=3.8.10=h49503c6_1_cpython
148 | - python-dateutil=2.8.2=pyhd8ed1ab_0
149 | - python-flatbuffers=1.12=pyhd8ed1ab_1
150 | - python_abi=3.8=2_cp38
151 | - pytorch=1.9.0=py3.8_cuda11.1_cudnn8.0.5_0
152 | - pytz=2021.3=pyhd8ed1ab_0
153 | - pyu2f=0.1.5=pyhd8ed1ab_0
154 | - pywavelets=1.1.1=py38h5c078b8_3
155 | - pyyaml=5.4.1=py38h497a2fe_0
156 | - qt=5.9.7=h5867ecd_1
157 | - readline=8.1=h46c0cb4_0
158 | - requests=2.25.1=pyhd3deb0d_0
159 | - requests-oauthlib=1.3.0=pyh9f0ad1d_0
160 | - rsa=4.7.2=pyh44b312d_0
161 | - scikit-image=0.18.1=py38h51da96c_0
162 | - scikit-learn=0.24.2=py38hdc147b9_0
163 | - scipy=1.6.2=py38had2a1c9_1
164 | - seaborn=0.11.2=pyhd3eb1b0_0
165 | - setuptools=52.0.0=py38h06a4308_1
166 | - sip=4.19.13=py38he6710b0_0
167 | - six=1.16.0=pyh6c4a22f_0
168 | - sqlite=3.36.0=h9cd32fc_0
169 | - tbb=2020.3=hfd86e86_0
170 | - tensorboard=2.6.0=pyhd8ed1ab_1
171 | - tensorboard-data-server=0.6.0=py38h2b97feb_0
172 | - tensorboard-plugin-wit=1.6.0=pyh9f0ad1d_0
173 | - tensorboardx=2.5.1=pyhd8ed1ab_0
174 | - tensorflow=2.4.1=mkl_py38hb2083e0_0
175 | - tensorflow-base=2.4.1=mkl_py38h43e0292_0
176 | - tensorflow-estimator=2.5.0=pyh81a9013_1
177 | - termcolor=1.1.0=py_2
178 | - threadpoolctl=2.2.0=pyh8a188c0_0
179 | - tifffile=2020.10.1=py38hdd07704_2
180 | - tk=8.6.10=h21135ba_1
181 | - toml=0.10.2=pyhd8ed1ab_0
182 | - toolz=0.11.1=py_0
183 | - torchaudio=0.9.0=py38
184 | - torchvision=0.10.0=py38_cu111
185 | - tornado=6.1=py38h497a2fe_1
186 | - typing-extensions=3.10.0.0=hd8ed1ab_0
187 | - typing_extensions=3.10.0.0=pyha770c72_0
188 | - urllib3=1.26.6=pyhd8ed1ab_0
189 | - werkzeug=1.0.1=pyh9f0ad1d_0
190 | - wheel=0.36.2=pyhd3deb0d_0
191 | - wrapt=1.12.1=py38h497a2fe_3
192 | - xz=5.2.5=h516909a_1
193 | - yaml=0.2.5=h516909a_0
194 | - yarl=1.6.3=py38h497a2fe_2
195 | - zipp=3.5.0=pyhd8ed1ab_0
196 | - zlib=1.2.11=h36c2ea0_1013
197 | - zstd=1.4.9=ha95c52a_0
198 | - pip:
199 |   - anyio==3.6.1
200 |   - argon2-cffi==21.3.0
201 |   - argon2-cffi-bindings==21.2.0
202 |   - asttokens==2.0.5
203 |   - av==8.0.3
204 |   - babel==2.10.1
205 |   - backcall==0.2.0
206 |   - beautifulsoup4==4.11.1
207 |   - bleach==5.0.0
208 |   - bottle==0.12.19
209 |   - bottle-websocket==0.2.9
210 |   - debugpy==1.6.0
211 |   - defusedxml==0.7.1
212 |   - eel==0.14.0
213 |   - entrypoints==0.4
214 |   - executing==0.8.3
215 |   - fastjsonschema==2.15.3
216 |   - future==0.18.2
217 |   - fvcore==0.1.5.post20211019
218 |   - gevent==21.1.2
219 |   - gevent-websocket==0.10.1
220 |   - greenlet==1.1.0
221 |   - higher==0.2.1
222 |   - imageio-ffmpeg==0.4.5
223 |   - importlib-resources==5.7.1
224 |   - iopath==0.1.9
225 |   - ipykernel==6.13.0
226 |   - ipython==8.3.0
227 |   - ipython-genutils==0.2.0
228 |   - jedi==0.18.1
229 |   - jinja2==3.1.2
230 |   - json5==0.9.8
231 |   - jsonschema==4.5.1
232 |   - jupyter-client==7.3.1
233 |   - jupyter-core==4.10.0
234 |   - jupyter-server==1.17.0
235 |   - jupyterlab==3.4.2
236 |   - jupyterlab-pygments==0.2.2
237 |   - jupyterlab-server==2.13.0
238 |   - markupsafe==2.1.1
239 |   - matplotlib-inline==0.1.3
240 |   - mistune==0.8.4
241 |   - moviepy==1.0.3
242 |   - nbclassic==0.3.7
243 |   - nbclient==0.6.3
244 |   - nbconvert==6.5.0
245 |   - nbformat==5.4.0
246 |   - nest-asyncio==1.5.5
247 |   - notebook==6.4.11
248 |   - notebook-shim==0.1.0
249 |   - opencv-python==4.5.3.56
250 |   - optoth==0.2.0
251 |   - pad2d-op-v1==1.0
252 |   - pandocfilters==1.5.0
253 |   - parameterized==0.8.1
254 |   - parso==0.8.3
255 |   - perlin-noise==1.12
256 |   - pexpect==4.8.0
257 |   - pickleshare==0.7.5
258 |   - portalocker==2.3.2
259 |   - proglog==0.1.9
260 |   - prometheus-client==0.14.1
261 |   - prompt-toolkit==3.0.29
262 |   - ptyprocess==0.7.0
263 |   - pure-eval==0.2.2
264 |   - pygments==2.12.0
265 |   - pyrsistent==0.18.1
266 |   - pytorchvideo==0.1.3
267 |   - pyzmq==22.3.0
268 |   - send2trash==1.8.0
269 |   - sniffio==1.2.0
270 |   - soupsieve==2.3.2.post1
271 |   - stack-data==0.2.0
272 |   - tabulate==0.8.9
273 |   - terminado==0.15.0
274 |   - tinycss2==1.1.1
275 |   - torch-dct==0.1.5
276 |   - torch-tb-profiler==0.3.1
277 |   - torchmetrics==0.11.1
278 |   - tqdm==4.62.0
279 |   - traitlets==5.2.1.post0
280 |   - ttach==0.0.3
281 |   - wcwidth==0.2.5
282 |   - webencodings==0.5.1
283 |   - websocket-client==1.3.2
284 |   - whichcraft==0.6.1
285 |   - yacs==0.1.8
286 |   - zope-event==4.5.0
287 |   - zope-interface==5.4.0
288 | prefix: /opt/python_envs/anaconda3/envs/granules
289 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import random
2 | from types import SimpleNamespace
3 | import imageio
4 | import numpy as np
5 | import argparse
6 | import sys
7 | import os
8 | import json
9 | from tqdm.auto import tqdm
10 | import matplotlib.pyplot as plt
11 | import yaml
12 |
13 | import einops
14 | import torch
15 | import torch.nn as nn
16 | from torch.optim import Adam
17 | from torch.utils.data import DataLoader
18 | from torch.utils.tensorboard import SummaryWriter
19 |
20 | from SimulationHelper.simulation import Simulation
21 | from datasets.config_dl import config_dl
22 | from models import ddpm
23 | import trainer
24 |
25 | # Setting reproducibility
26 | SEED = 0
27 | random.seed(SEED)
28 | np.random.seed(SEED)
29 | torch.manual_seed(SEED)
30 |
31 | parser = argparse.ArgumentParser("")
32 | parser.add_argument(
33 | "--config", default="cfg/monuseg.yaml", type=str, help="path to .yaml config" # glas, monuseg
34 | )
35 | args = parser.parse_args()
36 |
37 | def count_parameters(net):
38 | return sum(p.numel() for p in net.parameters() if p.requires_grad)
39 |
40 | if __name__ == "__main__":
41 | # program arguments
42 | with open(args.config) as file:
43 | yaml_cfg = yaml.safe_load(file)
44 | cfg = json.loads(
45 | json.dumps(yaml_cfg), object_hook=lambda d: SimpleNamespace(**d)
46 | )
47 |
48 | device = torch.device("cuda")
49 | print(f"Using device: {device}\t" + (f"{torch.cuda.get_device_name(0)}"))
50 |
51 | # set up dataloader, model
52 | train_dl, test_dl = config_dl(cfg)
53 | if cfg.model.type == 'unet':
54 | model = ddpm.Network(
55 | dim=cfg.model.dim,
56 | channels=cfg.model.n_cin,
57 | cond_channels=cfg.model.n_cin_cond,
58 | init_dim=cfg.model.n_fm,
59 | dim_mults=tuple(cfg.model.mults),
60 | embedding=cfg.model.embedding,
61 | img_cond=cfg.general.img_cond,
62 | with_class_label_emb=cfg.general.with_class_label_emb,
63 | class_label_cond=cfg.general.class_label_cond,
64 | num_classes=cfg.general.num_classes,
65 | ).to(device)
66 |
67 | else:
68 | raise ValueError('Unknown model type!')
69 |
70 | # optimizer
71 | optim = Adam(model.parameters(), cfg.learning.lr)
72 |
73 | # Optionally, load a pre-trained model that will be further trained
74 | if cfg.general.resume_training:
75 | load_path = os.getcwd() + "/runs/" + cfg.general.modality
76 | load_path += "/" + cfg.general.load_path + "/models/"
77 | fnames = sorted([fname for fname in os.listdir(load_path) if fname.endswith(".pt")])
78 |
79 | model.load_state_dict(
80 | torch.load(load_path + fnames[-1], map_location=device)["state_dict"],
81 | strict=False,
82 | )
83 | print("\nINFO: successfully retrieved learned model params from specified cfg dir/epoch!")
84 |
85 | # load optimizer state dict
86 | optim.load_state_dict(torch.load(load_path + fnames[-1], map_location=device)["optimizer"])
87 | print("\nINFO: successfully retrieved optim state dict from specified cfg dir/epoch!")
88 |
89 | # network params
90 | print("\nNetwork has %i params" % count_parameters(model))
91 |
92 | # simulation
93 | sim_name = str(cfg.general.modality)
94 | with Simulation(
95 | sim_name=sim_name, output_root=f'{os.path.join(os.getcwd(), "runs/")}'
96 | ) as simulation:
97 | writer = SummaryWriter(os.path.join(simulation.outdir, "tensorboard"))
98 | cfg.inference.load_exp = simulation.outdir.split("/")[-1]
99 | with open(os.path.join(simulation.outdir, "cfg.yaml"), "w") as f:
100 | yaml.dump({k: v.__dict__ for k, v in cfg.__dict__.items()}, f)
101 |
102 | # training
103 | if (cfg.general.corr_mode == "diffusion" or cfg.general.corr_mode == "diffusion_ls"):
104 | noise_level_dict={'s1': cfg.SMLD.sigma_1_m, 'sL': cfg.SMLD.sigma_L_m, 'L': cfg.SMLD.n_steps}
105 | beta_dict = {'beta1': cfg.SMLD.beta_1, 'betaT': cfg.SMLD.beta_T, 'T': cfg.SMLD.T}
106 |
107 | trainer.TrainScoreNetwork(noise_level_dict,beta_dict,sde=cfg.SMLD.sde,model_type=cfg.model.type,train_objective=cfg.SMLD.objective,loss_power=cfg.learning.loss,n_val=cfg.learning.n_val,val_dl=train_dl).do(
108 | model,
109 | train_dl,
110 | cfg.learning.epochs,
111 | cfg.learning.clip,
112 | optim=optim,
113 | device=device,
114 | simulation=simulation,
115 | writer=writer,
116 | img_cond=cfg.general.img_cond,
117 | class_label_cond=cfg.general.class_label_cond,
118 | )
119 |
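
The config handling in `main.py` above turns the parsed YAML dict into nested attribute-style objects through a JSON round-trip, and converts it back to a plain dict before dumping it into the run directory. The following is a minimal sketch of that pattern using only the standard library. The dict `raw_cfg` is a hypothetical stand-in for what `yaml.safe_load` would return (its values are made up), and the helper names `to_namespace` and `to_dict` are illustrative, not part of the repository.

```python
import json
from types import SimpleNamespace

# Hypothetical stand-in for the dict returned by yaml.safe_load;
# the keys mirror names used in main.py, the values are invented.
raw_cfg = {
    "general": {"modality": "monuseg", "img_cond": 1},
    "model": {"type": "unet", "dim": 64, "mults": [1, 2, 4, 8]},
}


def to_namespace(d):
    """Convert nested dicts to SimpleNamespace objects via a JSON round-trip."""
    # object_hook is called for every decoded JSON object, innermost first,
    # so nested dicts become nested namespaces.
    return json.loads(json.dumps(d), object_hook=lambda x: SimpleNamespace(**x))


def to_dict(ns):
    """Inverse for two-level configs, mirroring the dict built before yaml.dump."""
    return {k: v.__dict__ for k, v in ns.__dict__.items()}


cfg = to_namespace(raw_cfg)
print(cfg.model.dim, cfg.general.modality)
```

Note that the inverse only unwraps two levels, so a config with deeper nesting would need a recursive variant.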
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leabogensperger/generative-segmentation-sdf/3b8037e3e03d347a41f361f6ba844e6928240200/models/__init__.py
--------------------------------------------------------------------------------
/models/ddpm.py:
--------------------------------------------------------------------------------
1 | import math
2 | from inspect import isfunction
3 | from functools import partial
4 | import matplotlib.pyplot as plt
5 | from tqdm.auto import tqdm
6 | from einops import rearrange
7 | import numpy as np
8 |
9 | import torch
10 | from torch import nn, einsum
11 | import torch.nn.functional as F
12 |
13 | # taken and adapted from https://huggingface.co/blog/annotated-diffusion
14 |
15 | def exists(x):
16 | return x is not None
17 |
18 |
19 | def default(val, d):
20 | if exists(val):
21 | return val
22 | return d() if isfunction(d) else d
23 |
24 |
25 | class Residual(nn.Module):
26 | def __init__(self, fn):
27 | super().__init__()
28 | self.fn = fn
29 |
30 | def forward(self, x, *args, **kwargs):
31 | return self.fn(x, *args, **kwargs) + x
32 |
33 |
34 | def Upsample(dim):
35 | return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
36 |
37 |
38 | def Downsample(dim):
39 | return nn.Conv2d(dim, dim, 4, 2, 1)
40 |
41 |
42 | class SinusoidalPositionEmbeddings(nn.Module):
43 | def __init__(self, dim):
44 | super().__init__()
45 | self.dim = dim
46 |
47 | def forward(self, time):
48 | device = time.device
49 | half_dim = self.dim // 2
50 | embeddings = math.log(10000) / (half_dim - 1)
51 | embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
52 | embeddings = (
53 | time * embeddings[None, :]
54 | ) # time is already in the shape [batchsize,1]
55 | embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
56 | return embeddings
57 |
58 | class GaussianFourierProjection(nn.Module): # for continuous training
59 | """Gaussian Fourier embeddings for noise levels.
60 | taken from https://github.com/yang-song/score_sde_pytorch/blob/cb1f359f4aadf0ff9a5e122fe8fffc9451fd6e44/models/layerspp.py#L32
61 | """
62 |
63 | def __init__(self, dim=256, scale=16.0):
64 | super().__init__()
65 | dim = dim//2
66 | self.W = nn.Parameter(torch.randn(dim) * scale, requires_grad=False)
67 |
68 | def forward(self, x):
69 | x_proj = x* self.W[None, :] * 2 * np.pi
70 | return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
71 |
72 |
73 | class Block(nn.Module):
74 | def __init__(self, dim, dim_out, groups=8):
75 | super().__init__()
76 | self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
77 | self.norm = nn.GroupNorm(groups, dim_out)
78 | self.act = nn.SiLU()
79 |
80 | def forward(self, x, scale_shift=None):
81 | x = self.proj(x)
82 | x = self.norm(x)
83 |
84 | if exists(scale_shift):
85 | scale, shift = scale_shift
86 | x = x * (scale + 1) + shift
87 |
88 | x = self.act(x)
89 | return x
90 |
91 |
92 | class ResnetBlock(nn.Module):
93 | """https://arxiv.org/abs/1512.03385"""
94 |
95 | def __init__(
96 | self,
97 | dim,
98 | dim_out,
99 | *,
100 | time_emb_dim=None,
101 | groups=8,
102 | class_label_cond=False,
103 | num_classes=None
104 | ):
105 | super().__init__()
106 | self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out)) if exists(time_emb_dim) else None
107 |
108 | if class_label_cond:
109 | # TODO: time_emb_dim and class_emb_dim should have their own parameters for dimensionality.
110 | self.class_label_mlp = nn.Sequential(
111 | nn.SiLU(), nn.Linear(time_emb_dim, dim_out)
112 | )
113 |
114 | self.block1 = Block(dim, dim_out, groups=groups)
115 | self.block2 = Block(dim_out, dim_out, groups=groups)
116 | self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
117 |
118 | self.num_classes = num_classes
119 | self.class_label_cond = class_label_cond
120 |
121 | def forward(self, x, time_emb=None, class_lbl=None):
122 | h = self.block1(x)
123 |
124 | if exists(self.mlp) and exists(time_emb):
125 | time_emb = self.mlp(time_emb)
126 | h = rearrange(time_emb, "b c -> b c 1 1") + h
127 |
128 | if self.class_label_cond is True:
129 | class_lbl_emb = self.class_label_mlp(
130 | class_lbl
131 | ) # Bring the lbl_emb to correct shape.
132 | h = rearrange(class_lbl_emb, "b c -> b c 1 1") + h
133 |
134 | h = self.block2(h)
135 | return h + self.res_conv(x)
136 |
137 |
138 | class Attention(nn.Module):
139 | def __init__(self, dim, heads=4, dim_head=32):
140 | super().__init__()
141 | self.scale = dim_head**-0.5
142 | self.heads = heads
143 | hidden_dim = dim_head * heads
144 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
145 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
146 |
147 | def forward(self, x):
148 | b, c, h, w = x.shape
149 | qkv = self.to_qkv(x).chunk(3, dim=1)
150 | q, k, v = map(
151 | lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
152 | )
153 | q = q * self.scale
154 |
155 | sim = einsum("b h d i, b h d j -> b h i j", q, k)
156 | sim = sim - sim.amax(dim=-1, keepdim=True).detach()
157 | attn = sim.softmax(dim=-1)
158 |
159 | out = einsum("b h i j, b h d j -> b h i d", attn, v)
160 | out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
161 | return self.to_out(out)
162 |
163 |
164 | class LinearAttention(nn.Module):
165 | def __init__(self, dim, heads=4, dim_head=32):
166 | super().__init__()
167 | self.scale = dim_head**-0.5
168 | self.heads = heads
169 | hidden_dim = dim_head * heads
170 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
171 |
172 | self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))
173 |
174 | def forward(self, x):
175 | b, c, h, w = x.shape
176 | qkv = self.to_qkv(x).chunk(3, dim=1)
177 | q, k, v = map(
178 | lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
179 | )
180 |
181 | q = q.softmax(dim=-2)
182 | k = k.softmax(dim=-1)
183 |
184 | q = q * self.scale
185 | context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
186 |
187 | out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
188 | out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
189 | return self.to_out(out)
190 |
191 |
192 | class PreNorm(nn.Module):
193 | def __init__(self, dim, fn):
194 | super().__init__()
195 | self.fn = fn
196 | self.norm = nn.GroupNorm(1, dim)
197 |
198 | def forward(self, x):
199 | x = self.norm(x)
200 | return self.fn(x)
201 |
202 |
203 | class Network(nn.Module):
204 | def __init__(
205 | self,
206 | dim,
207 | init_dim=None,
208 | out_dim=None,
209 | dim_mults=(1, 2, 4, 8),
210 | channels=1, # channels of sdf input
211 | cond_channels=1, # rgb vs gray input for conditioning
212 | embedding='sinusoidal',
213 | with_time_emb=True,
214 | resnet_block_groups=2,
215 | img_cond=None,
216 | with_class_label_emb=False,
217 | class_label_cond=False,
218 | num_classes=None,
219 | ):
220 | super().__init__()
221 |
222 | self.embedding = embedding
223 | assert self.embedding == 'fourier' or self.embedding == 'sinusoidal'
224 |
225 | self.class_label_cond = class_label_cond
226 | self.num_classes = num_classes
227 | self.with_class_label_emb = with_class_label_emb
228 |
229 | block_klass = partial(ResnetBlock, groups=resnet_block_groups)
230 |
231 | # time embeddings
232 | if with_time_emb:
233 | time_dim = dim * 4
234 | self.time_mlp = nn.Sequential(
235 | GaussianFourierProjection(dim) if self.embedding == 'fourier' else SinusoidalPositionEmbeddings(dim), # embedding type is selected via the `embedding` argument
236 | # SinusoidalPositionEmbeddings(dim),
237 | nn.Linear(dim, time_dim),
238 | nn.GELU(),
239 | nn.Linear(time_dim, time_dim),
240 | )
241 | else:
242 | raise Exception("with_time_emb=False is not supported: the rest of the network requires a time embedding.")
243 |
244 | # class_label embeddings
245 | if class_label_cond == True:
246 | if with_class_label_emb:
247 | class_label_dim = dim * 4
248 | self.class_label_embedding_mlp = nn.Sequential(
249 | SinusoidalPositionEmbeddings(dim),
250 | nn.Linear(dim, class_label_dim),
251 | nn.GELU(),
252 | nn.Linear(class_label_dim, class_label_dim),
253 | )
254 | else:
255 | class_label_dim = dim * 4
256 | self.class_label_embedding_mlp = nn.Sequential(
257 | nn.Linear(dim, class_label_dim),
258 | nn.GELU(),
259 | nn.Linear(class_label_dim, class_label_dim),
260 | )
261 | else:
262 | self.class_label_embedding_mlp = None
263 |
264 | # determine dimensions
265 | self.channels = channels
266 | self.cond_channels = cond_channels
267 |
268 | # conditioning branch at beginning (SegDiff style)
269 | if img_cond == 1:
270 | self.encoder_img = nn.Sequential(
271 | block_klass(channels, init_dim, time_emb_dim=time_dim)
272 | )
273 | self.encoder_mask = nn.Sequential(
274 | block_klass(cond_channels, init_dim, time_emb_dim=time_dim)
275 | )
276 | init_dim *= 2
277 | channels = init_dim
278 |
279 | init_dim = default(init_dim, dim // 3 * 2)
280 | self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)
281 |
282 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
283 | in_out = list(zip(dims[:-1], dims[1:]))
284 |
285 | # layers
286 | self.downs = nn.ModuleList([])
287 | self.ups = nn.ModuleList([])
288 | num_resolutions = len(in_out)
289 |
290 | for ind, (dim_in, dim_out) in enumerate(in_out):
291 | is_last = ind >= (num_resolutions - 1)
292 |
293 | self.downs.append(
294 | nn.ModuleList(
295 | [
296 | block_klass(
297 | dim_in,
298 | dim_out,
299 | time_emb_dim=time_dim,
300 | class_label_cond=class_label_cond,
301 | num_classes=num_classes,
302 | ),
303 | block_klass(
304 | dim_out,
305 | dim_out,
306 | time_emb_dim=time_dim,
307 | class_label_cond=class_label_cond,
308 | num_classes=num_classes,
309 | ),
310 | Residual(PreNorm(dim_out, LinearAttention(dim_out))),
311 | Downsample(dim_out) if not is_last else nn.Identity(),
312 | ]
313 | )
314 | )
315 |
316 | mid_dim = dims[-1]
317 | self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
318 | self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
319 | self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
320 |
321 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
322 | is_last = ind >= (num_resolutions - 1)
323 |
324 | self.ups.append(
325 | nn.ModuleList(
326 | [
327 | block_klass(
328 | dim_out * 2,
329 | dim_in,
330 | time_emb_dim=time_dim,
331 | class_label_cond=class_label_cond,
332 | num_classes=num_classes,
333 | ),
334 | block_klass(
335 | dim_in,
336 | dim_in,
337 | time_emb_dim=time_dim,
338 | class_label_cond=class_label_cond,
339 | num_classes=num_classes,
340 | ),
341 | Residual(PreNorm(dim_in, LinearAttention(dim_in))),
342 | Upsample(dim_in) if not is_last else nn.Identity(),
343 | ]
344 | )
345 | )
346 |
347 | out_dim = default(out_dim, self.channels)
348 | self.final_conv = nn.Sequential(
349 | block_klass(dim, dim, time_emb_dim=None), nn.Conv2d(dim, out_dim, 1)
350 | )
351 |
352 | def forward(self, x, time, img_cond=None, class_lbl=None):
353 | if img_cond is not None:
354 | # x = self.encoder_img(x) + self.encoder_mask(cond)
355 | x = torch.cat((self.encoder_img(x), self.encoder_mask(img_cond)), 1)
356 |
357 | x = self.init_conv(x)
358 |
359 | t = self.time_mlp(time) if exists(self.time_mlp) else None
360 | class_lbl = (
361 | self.class_label_embedding_mlp(class_lbl)
362 | if exists(self.class_label_embedding_mlp)
363 | else None
364 | )
365 |
366 | h = []
367 |
368 | # downsample
369 | for block1, block2, attn, downsample in self.downs:
370 | x = block1(x, t, class_lbl)
371 | x = block2(x, t, class_lbl)
372 | x = attn(x)
373 | h.append(x)
374 | x = downsample(x)
375 |
376 | # bottleneck
377 | x = self.mid_block1(x, t, class_lbl)
378 | x = self.mid_attn(x)
379 | x = self.mid_block2(x, t, class_lbl)
380 |
381 | # upsample
382 | for block1, block2, attn, upsample in self.ups:
383 | x = torch.cat((x, h.pop()), dim=1)
384 | x = block1(x, t, class_lbl)
385 | x = block2(x, t, class_lbl)
386 | x = attn(x)
387 | x = upsample(x)
388 |
389 | return self.final_conv(x)
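
The `SinusoidalPositionEmbeddings` module in `models/ddpm.py` above maps a scalar time (or noise level) to a vector of sines and cosines at geometrically spaced frequencies. As a rough illustration of the same arithmetic without PyTorch, here is a pure-Python sketch for a single scalar input. The function name `sinusoidal_embedding` is hypothetical, and the sketch assumes `dim >= 4` so that `half_dim - 1` is nonzero.

```python
import math


def sinusoidal_embedding(t, dim):
    """Return [sin(t*f_0), ..., sin(t*f_{h-1}), cos(t*f_0), ..., cos(t*f_{h-1})]."""
    half_dim = dim // 2
    # Frequencies decay geometrically from 1 down to 1/10000.
    scale = math.log(10000) / (half_dim - 1)
    freqs = [math.exp(-k * scale) for k in range(half_dim)]
    args = [t * f for f in freqs]
    # Sines first, then cosines, matching the concatenation order in the module.
    return [math.sin(a) for a in args] + [math.cos(a) for a in args]


emb = sinusoidal_embedding(t=0.0, dim=8)
print(len(emb))
```

At `t = 0` the sine half is all zeros and the cosine half is all ones, which is a quick sanity check for the ordering of the two halves.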
--------------------------------------------------------------------------------
/preprocess_data/precompute_sdf.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 4,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import numpy as np\n",
10 | "import matplotlib.pyplot as plt\n",
11 | "from skimage.draw import ellipse\n",
12 | "from skimage.morphology import binary_erosion\n",
13 | "from scipy.ndimage import distance_transform_edt"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 37,
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "# generate toy mask\n",
23 | "N = 128\n",
24 | "m = np.zeros((N, N), dtype=int)\n",
25 | "\n",
26 | "# Define parameters for the ellipses (center, semi-axes, and orientation)\n",
27 | "ellipses = [\n",
28 | " (30, 40, 20, 20, np.deg2rad(30)), # (row, col, semi-major, semi-minor, rotation)\n",
29 | " (20, 80, 10, 20, np.deg2rad(75)),\n",
30 | " (90, 60, 30, 40, np.deg2rad(15)),\n",
31 | "]\n",
32 | "\n",
33 | "# Draw each ellipse on the mask\n",
34 | "for (r, c, r_radius, c_radius, angle) in ellipses:\n",
35 | " rr, cc = ellipse(r, c, r_radius, c_radius, rotation=angle, shape=m.shape)\n",
36 | " m[rr, cc] = 1"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": 38,
42 | "metadata": {},
43 | "outputs": [
44 | {
45 | "data": {
46 | "image/png": "",
47 | "text/plain": [
48 | ""
49 | ]
50 | },
51 | "metadata": {},
52 | "output_type": "display_data"
53 | }
54 | ],
55 | "source": [
56 | "# extract boundaries\n",
57 | "m_bd = np.abs(binary_erosion(m) - m)\n",
58 | "\n",
59 | "fig, ax = plt.subplots(1,2)\n",
60 | "ax[0].imshow(m), ax[0].set_title('binary mask')\n",
61 | "ax[1].imshow(m_bd), ax[1].set_title('extracted boundaries')\n",
62 | "plt.show()"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": 55,
68 | "metadata": {},
69 | "outputs": [
70 | {
71 | "data": {
72 | "image/png": "",
73 | "text/plain": [
74 | ""
75 | ]
76 | },
77 | "metadata": {},
78 | "output_type": "display_data"
79 | }
80 | ],
81 | "source": [
82 | "# extract signed distance function (SDF)\n",
83 | "distance = distance_transform_edt(np.where(m_bd==0., np.ones_like(m_bd), np.zeros_like(m_bd)))\n",
84 | "m_sdf = np.where(m == 1, distance*-1, distance) # ensure signed DT\n",
85 | "\n",
86 | "# truncate at threshold and normalize between [-1,1]\n",
87 | "thresh = 15\n",
88 | "m_sdf[m_sdf >= thresh] = thresh\n",
89 | "m_sdf[m_sdf <= -thresh] = -thresh\n",
90 | "m_sdf /= thresh\n",
91 | "\n",
92 | "fig, ax = plt.subplots(1,2)\n",
93 | "ax[0].imshow(m), ax[0].set_title('binary mask')\n",
94 | "ax[1].imshow(m_sdf), ax[1].set_title('SDF transformed mask')\n",
95 | "plt.show()"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": 56,
101 | "metadata": {},
102 | "outputs": [
103 | {
104 | "name": "stdout",
105 | "output_type": "stream",
106 | "text": [
107 | "SDF mask values living in [-1,1]\n"
108 | ]
109 | }
110 | ],
111 | "source": [
112 | "# check SDF normalization\n",
113 | "print('SDF mask values living in [%i,%i]' %(m_sdf.min(),m_sdf.max()))"
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": 57,
119 | "metadata": {},
120 | "outputs": [
121 | {
122 | "name": "stdout",
123 | "output_type": "stream",
124 | "text": [
125 | "Retrieved mask from thresholding SDF agrees with original binary mask: True\n"
126 | ]
127 | }
128 | ],
129 | "source": [
130 | "# retrieve binary mask from SDF mask by thresholding \n",
131 | "m_retrieved = np.where(m_sdf <= 0, np.ones_like(m_sdf), np.zeros_like(m_sdf))\n",
132 | "print('Retrieved mask from thresholding SDF agrees with original binary mask: %s' %np.allclose(m_retrieved,m)) "
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "execution_count": null,
138 | "metadata": {},
139 | "outputs": [],
140 | "source": []
141 | }
142 | ],
143 | "metadata": {
144 | "kernelspec": {
145 | "display_name": "misc",
146 | "language": "python",
147 | "name": "python3"
148 | },
149 | "language_info": {
150 | "codemirror_mode": {
151 | "name": "ipython",
152 | "version": 3
153 | },
154 | "file_extension": ".py",
155 | "mimetype": "text/x-python",
156 | "name": "python",
157 | "nbconvert_exporter": "python",
158 | "pygments_lexer": "ipython3",
159 | "version": "3.12.2"
160 | }
161 | },
162 | "nbformat": 4,
163 | "nbformat_minor": 2
164 | }
165 |
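
The notebook above computes a truncated, normalized signed distance function (SDF) from a binary mask with `skimage` and `scipy`. To make the steps explicit, the following dependency-free sketch does the same thing by brute force on a tiny grid. It is an illustration under stated assumptions, not the notebook's implementation: the function name `truncated_sdf` is hypothetical, boundary pixels are approximated as foreground pixels with at least one in-bounds 4-neighbour in the background (which may differ from `binary_erosion` at the image border), and the mask is assumed to contain at least one such boundary pixel.

```python
import math


def truncated_sdf(mask, thresh):
    """Signed distance to the object boundary, clipped at thresh and scaled to [-1, 1]."""
    rows, cols = len(mask), len(mask[0])

    def is_boundary(r, c):
        # A foreground pixel that touches the background through a 4-neighbour.
        if mask[r][c] != 1:
            return False
        for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
            rr, cc = r + dr, c + dc
            if 0 <= rr < rows and 0 <= cc < cols and mask[rr][cc] == 0:
                return True
        return False

    boundary = [(r, c) for r in range(rows) for c in range(cols) if is_boundary(r, c)]

    sdf = []
    for r in range(rows):
        row = []
        for c in range(cols):
            # Euclidean distance to the nearest boundary pixel, truncated at thresh.
            d = min(math.hypot(r - br, c - bc) for br, bc in boundary)
            d = min(d, thresh)
            # Negative inside the object, positive outside, then normalize.
            row.append((-d if mask[r][c] == 1 else d) / thresh)
        sdf.append(row)
    return sdf


# Toy mask: a 5x5 square inside a 9x9 grid.
mask = [[1 if 2 <= r <= 6 and 2 <= c <= 6 else 0 for c in range(9)] for r in range(9)]
sdf = truncated_sdf(mask, thresh=2)

# Thresholding the SDF at zero recovers the binary mask.
retrieved = [[1 if v <= 0 else 0 for v in row] for row in sdf]
print(retrieved == mask)
```

The brute-force search is quadratic in the number of pixels, so for real images a distance transform such as the one used in the notebook is the practical choice.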
--------------------------------------------------------------------------------
/sampler.py:
--------------------------------------------------------------------------------
1 | import random
2 | from types import SimpleNamespace
3 | import imageio
4 | import numpy as np
5 | import argparse
6 | import sys
7 | import os
8 | import json
9 | from tqdm.auto import tqdm
10 | import matplotlib.pyplot as plt
11 | from PIL import ImageDraw
12 | import numpy.ma as ma
13 |
14 | import einops
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 | from torch.optim import Adam
19 | from torch.utils.data import DataLoader
20 | from torch.utils.tensorboard import SummaryWriter
21 | import yaml
22 | from torchvision.transforms import PILToTensor, ToPILImage
23 | from PIL import ImageFont, ImageDraw, Image
24 | from torchvision.utils import make_grid
25 | from torchmetrics import JaccardIndex, Dice, F1Score
26 | import torchmetrics
27 | from torchvision.utils import save_image
28 | from torch.nn.functional import one_hot
29 |
30 | from SimulationHelper.simulation import Simulation
31 | from datasets.config_dl import config_dl
32 | from models import ddpm
33 |
34 | parser = argparse.ArgumentParser("")
35 | parser.add_argument(
36 | "--config", default="cfg/monuseg.yaml", type=str, help="path to .yaml config"
37 | )
38 | parser.add_argument("--seed", default=0, type=int, help="seed for reproducibility") # 1
39 | args = parser.parse_args()
40 |
41 | # Setting reproducibility
42 | def set_seed(SEED=0):
43 | random.seed(SEED)
44 | np.random.seed(SEED)
45 | torch.manual_seed(SEED)
46 |
47 | def store_gif(frames, frames_per_gif, load_path, sample_str=''):
48 | gif_name = load_path + "/samples/samples" + sample_str + ".gif"
49 |
50 | with imageio.get_writer(gif_name, mode="I") as writer:
51 | for idx, frame in enumerate(frames):
52 | writer.append_data(frame)
53 | if idx == len(frames) - 1:
54 | for _ in range(frames_per_gif // 3):
55 | writer.append_data(frames[-1])
56 |
57 | def show_images(images, vmin=None, vmax=None, save_name="", overlay=None):
58 | """Shows the provided images as sub-pictures in a square"""
59 | alpha=0.6 if overlay is not None else 1. # alpha channel if additional overlay image is given
60 | if vmin is None:
61 | vmin = images.min().item()
62 | if vmax is None:
63 | vmax = images.max().item()
64 |
65 | if overlay is not None:
66 | overlay = overlay.detach().cpu().numpy()
67 |
68 | # Converting images to CPU numpy arrays
69 | if type(images) is torch.Tensor:
70 | images = images.detach().cpu().numpy()
71 |
72 | # Defining number of rows and columns
73 | fig = plt.figure(figsize=(8, 8))
74 | rows = int(len(images) ** (1 / 2))
75 | cols = round(len(images) / rows)
76 |
77 | # Populating figure with sub-plots
78 | idx = 0
79 | for r in range(rows):
80 | for c in range(cols):
81 | fig.add_subplot(rows, cols, idx + 1)
82 |
83 | if idx < len(images):
84 | if overlay is not None:
85 | plt.imshow(overlay[idx][0], cmap="gray")
86 | images[:,:,0,0] = vmax # this is just for plotting!
87 | images[:,:,0,1] = 1
88 | mask = np.ma.masked_where(images[idx][0] == 0, images[idx][0])
89 | plt.imshow(mask, alpha=alpha), plt.axis("off")
90 | else:
91 | plt.imshow(images[idx][0], alpha=alpha, cmap="gray", vmin=vmin, vmax=vmax), plt.axis("off")
92 | idx += 1
93 |
94 | # Showing the figure
95 | plt.savefig(save_name, bbox_inches="tight", dpi=250)
96 | plt.close()
97 |
98 |
99 | def compute_metrics(x,x_gt,thresh,corr_mode,num_classes):
100 | # compute IoU between thresholded x and x_gt
101 | x_thresh = torch.where(x > thresh, torch.zeros_like(x), torch.ones_like(x)).type(torch.int8).squeeze().cpu()
102 | x_gt_thresh = torch.where(x_gt > 0., torch.zeros_like(x_gt), torch.ones_like(x_gt)).type(torch.int8).squeeze().cpu()
103 |
104 | jaccard = JaccardIndex(task="binary")
105 | dice = Dice(task='binary',average='macro',num_classes=2,ignore_index=0) # dice=f1 in binary segmentation
106 | iou, dice = jaccard(x_thresh, x_gt_thresh), dice(x_thresh, x_gt_thresh)
107 | return iou, dice
108 |
109 |
110 | def plot_all(x,cond,x_gt,img_cond,load_path,std_min,corr_mode,sample_str=''):
111 | sdf_min = x_gt.min().item()
112 | sdf_max = x_gt.max().item()
113 |
114 | show_images(x, vmin=sdf_min, vmax=sdf_max, save_name=load_path + "/samples/samples_" + str(sample_str) + ".png")
115 | if img_cond == 1:
116 | show_images(cond, save_name=load_path + "/samples/condition.png")
117 | show_images(x_gt, vmin=sdf_min, vmax=sdf_max, save_name=load_path + "/samples/groundtruth.png")
118 |
119 | x_thresh = torch.where(x > 3.*std_min, torch.zeros_like(x), torch.ones_like(x))
120 | show_images(x_thresh, save_name=load_path + "/samples/samples_thresholded_.png")
121 |
122 | x_gt_thresh = torch.where(x_gt > 0., torch.zeros_like(x_gt), torch.ones_like(x_gt))
123 | show_images(x_gt_thresh, save_name=load_path + "/samples/groundtruth_thresholded.png")
124 |
125 | # show thresholded maps on top of conditioning image
126 | vmax = x_gt.shape[1]
127 | show_images(x_thresh, save_name=load_path + "/samples/samples_thresholded_overlay.png", vmax=vmax, overlay=cond)
128 | show_images(x_gt_thresh, save_name=load_path + "/samples/groundtruth_thresholded_overlay.png", vmax=vmax, overlay=cond)
129 |
130 | class Sampling:
131 | def __init__(self, scorenet, model_type, device, load_path, sz, noise_level_dict, beta_dict, sde, img_cond, corr_mode, save_images=True):
132 | # general params
133 | self.scorenet = scorenet
134 | self.device = device
135 | self.load_path = load_path
136 | self.sz = sz
137 | self.sde = sde
138 | self.img_cond = img_cond
139 | self.model_type = model_type
140 | self.corr_mode = corr_mode
141 |
142 | # if set to False, no images are saved
143 | self.save_images = save_images
144 |
145 | if self.sde == 've':
146 | self.s1, self.sL, self.L = noise_level_dict['s1'], noise_level_dict['sL'], noise_level_dict['L']
147 | self.sigmas = torch.tensor(np.exp(np.linspace(np.log(self.s1),np.log(self.sL), self.L))).type(torch.float32)
148 | elif self.sde == 'vp':
149 | self.beta1, self.betaT, self.T = beta_dict['beta1'], beta_dict['betaT'], beta_dict['T']
150 | self.betas = np.linspace(1.E-4, 0.02, 1000, dtype=np.float32) # NOTE: hard-coded schedule; beta1, betaT and T from beta_dict are not used here
151 | self.alphas = 1 - self.betas
152 | self.alpha_bars = torch.from_numpy(np.asarray([np.prod(self.alphas[:i + 1]) for i in range(len(self.alphas))]))
153 |
154 | def get_sigma(self,t):
155 | return self.sigmas[-1]*(self.sigmas[0]/self.sigmas[-1])**t
156 |
157 | def sample(self, x, m_gt, n_samples, N, M, r, num_classes=2):
158 | if self.sde == 've':
159 | return self._sample_ve(x, m_gt, n_samples=n_samples, N=N, M=M, r=r,num_classes=num_classes)
160 | elif self.sde == 'vp':
161 | return self._sample_vp(x, m_gt, n_samples=n_samples,num_classes=num_classes)
162 |
163 | def _sample_vp(self,x, m_gt, n_samples, num_classes):
164 | # TODO: the schedule is currently fixed to 1000 time steps; it could be generalized with a continuous-time loss function
165 | """
166 | Sample according to DDPM paper
167 | """
168 | m = torch.randn(n_samples,1,self.sz,self.sz).float().to(self.device)
169 | device = x.device
170 | m_list = []
171 | with torch.no_grad():
172 | for i, t in tqdm(enumerate(list(range(self.T))[::-1])):
173 | # Estimating noise to be removed
174 | time_tensor = (torch.ones(n_samples, 1) * t).to(self.device).long()
175 | eta_theta = self.scorenet(m,t*torch.ones((n_samples,1)).to(self.device),img_cond=x)
176 | alpha_t = self.alphas[t]
177 | alpha_t_bar = self.alpha_bars[t]
178 |
179 | # Partially denoising the image
180 | m = (1 / np.sqrt(alpha_t)) * (
181 | m - (1 - alpha_t) / np.sqrt(1 - alpha_t_bar) * eta_theta
182 | )
183 |
184 | m_list.append(m.detach().cpu())
185 | if t > 0: # no noise added in last sampling step
186 | z = torch.randn(n_samples, 1, self.sz, self.sz).to(device)
187 | beta_t = self.betas[t]
188 | # # Option 1: sigma_t squared = beta_t
189 | # sigma_t = np.sqrt(beta_t)
190 |
191 | # Option 2: sigma_t squared = beta_tilde_t
192 | prev_alpha_t_bar = self.alpha_bars[t-1] if t > 0 else self.alphas[0]
193 | beta_tilde_t = ((1 - prev_alpha_t_bar)/(1 - alpha_t_bar)) * beta_t
194 | sigma_t = np.sqrt(beta_tilde_t)
195 |
196 | # Add fresh noise, in the fashion of Langevin dynamics
197 | m = m + sigma_t * z
198 |
199 | if self.save_images:
200 | plot_all(m,x,m_gt,self.img_cond,self.load_path,std_min=1e-3,corr_mode=self.corr_mode,sample_str='vp')
201 |
202 | # x_list.append(x.detach().cpu())
203 | return m, m_list
204 |
205 | def _sample_ve(self, cond, x_gt, n_samples, N, M, r, num_classes):
206 | """
207 | Sample using reverse-time SDE (VE-SDE)
208 | N number of predictor steps
209 | M number of corrector steps
210 | r "signal-to-noise" ratio
211 | """
212 |
213 | frames = []
214 | frames_thresh = []
215 | frames_all = []
216 | frames_per_gif = 100
217 | frame_idxs = np.linspace(0, N, frames_per_gif).astype(np.uint)
218 |
219 | t = torch.linspace(1-(1./N),0,N) # TODO: fix start at 1 or 1-dt
220 | sigma_t = self.get_sigma(t[0])
221 | x_list = []
222 |
223 | # initialize x and sample
224 | n_samples = x_gt.shape[0]
225 | x = self.sigmas[0]*torch.clip(torch.randn(n_samples,num_classes,self.sz,self.sz),-2.,2.).to(self.device)
226 |
227 | with torch.no_grad():
228 | for i, t_curr in enumerate(t):
229 | if i % 20 == 0:
230 | iou, dice = compute_metrics(x,x_gt,thresh=3*self.sigmas[-1].item(),corr_mode=self.corr_mode,num_classes=num_classes)
231 | print('PC sampling it [%i]:\t IoU [%.6f], Dice [%.6f]' %(i,iou,dice))
232 |
233 | # set sigma(t)
234 | sigma_t_prev = sigma_t.clone()
235 | sigma_t = self.get_sigma(t_curr)
236 |
237 | # get scores, sample noise
238 | if self.model_type == 'unet':
239 | scores = self.scorenet(x,sigma_t_prev*torch.ones((n_samples,1)).to(self.device),img_cond=cond)
240 | elif self.model_type == 'tdv':
241 | scores = self.scorenet.grad(torch.cat([x,cond],1),sigma_t_prev*torch.ones((n_samples,1,1,1)).to(self.device))[:,0:1]
242 |
243 | z = torch.clip(torch.randn_like(x),-2.,2.)
244 | tau = (sigma_t_prev**2 - sigma_t**2)
245 |
246 | # predictor step
247 | x = x + tau*scores
248 | x_list.append(x.detach().cpu())
249 | x += np.sqrt(tau)*z
250 |
251 | # corrector steps
252 | for j in range(M):
253 | # z = torch.randn_like(x)
254 | z = torch.clip(torch.randn_like(x),-2.,2.)
255 |
256 | # compute eps
257 | if self.model_type == 'unet':
258 | scores_corr = self.scorenet(x,sigma_t*torch.ones((n_samples,1)).to(self.device),img_cond=cond)
259 | elif self.model_type == 'tdv':
260 | scores_corr = self.scorenet.grad(torch.cat([x,cond],1),sigma_t*torch.ones((n_samples,1,1,1)).to(self.device))[:,0:1]
261 |
262 | eps = 2*(r*torch.norm(z).item()/torch.norm(scores_corr).item())**2
263 | x = x + eps*scores_corr
264 |
265 | x_list.append(x.detach().cpu())
266 | x += np.sqrt(2*eps)*z
267 |
268 | if self.save_images and (i in frame_idxs or t_curr == 0): # TODO: if other samplers than PC are used, make sure that the gif is also generated for them
269 | # Rescale each sample to the range [0, 255]
270 | normalized = x.clone()
271 | if self.corr_mode == 'diffusion_ls':
272 | normalized_thresh = torch.where(x > 3*self.sigmas[-1], torch.zeros_like(x), torch.ones_like(x))
273 | elif self.corr_mode == 'diffusion':
274 | normalized_thresh = torch.where(x < 0.5, torch.zeros_like(x), torch.ones_like(x))
275 |
276 | for k in range(len(normalized)): # k avoids shadowing the outer loop index i
277 | normalized[k] -= torch.min(normalized[k])
278 | normalized[k] *= 255 / torch.max(normalized[k])
279 |
280 | normalized_thresh[k] -= torch.min(normalized_thresh[k])
281 | normalized_thresh[k] *= 255 / torch.max(normalized_thresh[k])
282 |
283 | # Reshaping batch (n, c, h, w) to be a (as much as it gets) square frame
284 | frame = einops.rearrange(
285 | normalized,
286 | "(b1 b2) c h w -> (b1 h) (b2 w) c",
287 | b1=int(n_samples**0.5),
288 | )
289 | frame = frame.cpu().numpy().astype(np.uint8)
290 | frames.append(frame)
291 |
292 | # append thresholded
293 | frame_thresh = einops.rearrange(
294 | normalized_thresh,
295 | "(b1 b2) c h w -> (b1 h) (b2 w) c",
296 | b1=int(n_samples**0.5),
297 | )
298 | frame_thresh = frame_thresh.cpu().numpy().astype(np.uint8)
299 | frames_thresh.append(frame_thresh)
300 |
301 | # plotting
302 | if self.save_images:
303 | plot_all(x_list[-1],cond,x_gt,self.img_cond,self.load_path,std_min=cfg.SMLD.sigma_L,corr_mode=self.corr_mode,sample_str='ve')
304 |
305 | return x_list[-1], x_list
306 |
307 | if __name__ == "__main__":
308 | with open(args.config) as file:
309 | yaml_cfg = yaml.safe_load(file)
310 | cfg = json.loads(json.dumps(yaml_cfg), object_hook=lambda d: SimpleNamespace(**d))
311 |
312 | device = torch.device("cuda")
313 | print(f"Using device: {device}\t" + (f"{torch.cuda.get_device_name(0)}"))
314 |
315 | set_seed(SEED=0)
316 |
317 | # set up dataloader, model
318 | train_dl, test_dl = config_dl(cfg)
319 | if cfg.model.type == 'unet':
320 | model = ddpm.Network(
321 | dim=cfg.model.dim,
322 | channels=cfg.model.n_cin,
323 | cond_channels=cfg.model.n_cin_cond,
324 | init_dim=cfg.model.n_fm,
325 | dim_mults=tuple(cfg.model.mults),
326 | embedding=cfg.model.embedding,
327 | img_cond=cfg.general.img_cond,
328 | with_class_label_emb=cfg.general.with_class_label_emb,
329 | class_label_cond=cfg.general.class_label_cond,
330 | num_classes=cfg.general.num_classes,
331 | ).to(device)
332 |
333 | else:
334 | raise ValueError('Unknown model type!')
335 |
336 | load_path = os.getcwd() + "/runs/" + cfg.general.modality
337 |
338 | if cfg.inference.latest:
339 | print("\nINFO: running inference on the latest experiment!")
340 | load_path += "/" + sorted(os.listdir(load_path))[-1]
341 | else:
342 | print("\nINFO: inference from selected experiment, *not* the latest!")
343 | load_path += "/" + cfg.inference.load_exp
344 |
345 | print(f"Loading *latest* checkpoint from {load_path + '/models/'}")
346 | if not os.path.exists(load_path + "/samples"): # makedir for samples
347 | os.mkdir(load_path + "/samples")
348 |
349 | # load the model weights
350 | fnames = sorted(
351 | [fname for fname in os.listdir(load_path + "/models/") if fname.endswith(".pt")]
352 | )
353 | model.load_state_dict(torch.load(load_path + "/models/" + fnames[-1], map_location=device)["state_dict"],strict=True) # strict=False
354 |
355 | model.eval()
356 | print("\nModel loaded from %s" % (load_path + "/models/" + fnames[-1]))
357 |
358 | # sample and save generated images
359 | if cfg.general.corr_mode == "diffusion" or cfg.general.corr_mode == "diffusion_ls":
360 | noise_level_dict={'s1': cfg.SMLD.sigma_1_m, 'sL': cfg.SMLD.sigma_L_m, 'L': cfg.SMLD.n_steps}
361 | beta_dict = {'beta1': cfg.SMLD.beta_1, 'betaT': cfg.SMLD.beta_T, 'T': cfg.SMLD.T}
362 |
363 | Sampler = Sampling(
364 | scorenet=model,
365 | model_type=cfg.model.type,
366 | device=device,
367 | load_path=load_path,
368 | sz=cfg.general.sz,
369 | noise_level_dict=noise_level_dict,
370 | beta_dict=beta_dict,
371 | sde=cfg.SMLD.sde,
372 | img_cond=cfg.general.img_cond,
373 | corr_mode=cfg.general.corr_mode,
374 | save_images=True)
375 |
376 | # load conditioning image (batch['mask']) and ground-truth SDF (batch['image'])
377 | it_test_dl = iter(test_dl)
378 | batch = next(it_test_dl)
379 | x = None if cfg.general.img_cond==0 else batch['mask'].to(device)
380 | m_gt = None if cfg.general.img_cond==0 else batch['image'].to(device)
381 |
382 | # generate samples
383 | samples, samples_list = Sampler.sample(x, m_gt, n_samples=cfg.inference.n_samples, N=cfg.SMLD.N,M=cfg.SMLD.M,r=cfg.SMLD.r, num_classes=cfg.model.n_cin)
384 |
385 | # eval metrics
386 | iou, dice = compute_metrics(samples, m_gt.cpu(), corr_mode=cfg.general.corr_mode,thresh=3.*cfg.SMLD.sigma_L_m,num_classes=cfg.model.n_cin)
387 | print('\nFinal metrics: IoU [%f], Dice [%f]' %(iou,dice))
388 |
--------------------------------------------------------------------------------
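The frame-building step in the sampler above (per-sample rescaling to [0, 255], then tiling the batch into a near-square image) can be sketched without torch or einops. This is a minimal NumPy stand-in, not the repository's code: the helper name `to_uint8_frame` and the toy batch are hypothetical, and the guard for constant samples is an addition that the original loop does not have.

```python
import numpy as np

def to_uint8_frame(batch: np.ndarray) -> np.ndarray:
    """Min-max scale each sample to [0, 255] and tile the batch into one image.

    batch has shape (n, c, h, w); n is assumed to factor as b1 * b2
    with b1 = int(sqrt(n)), mirroring the sampler's einops pattern.
    """
    n, c, h, w = batch.shape
    b1 = int(n ** 0.5)   # number of tile rows
    b2 = n // b1         # number of tile columns
    assert b1 * b2 == n, "batch size must factor into b1 * b2"

    # Per-sample minimum and range, kept broadcastable against the batch.
    lo = batch.min(axis=(1, 2, 3), keepdims=True)
    span = batch.max(axis=(1, 2, 3), keepdims=True) - lo
    span = np.where(span == 0, 1.0, span)  # assumption: constant samples map to 0
    scaled = (batch - lo) / span * 255.0

    # Same layout as "(b1 b2) c h w -> (b1 h) (b2 w) c".
    frame = scaled.reshape(b1, b2, c, h, w).transpose(0, 3, 1, 4, 2)
    frame = frame.reshape(b1 * h, b2 * w, c)
    return np.rint(frame).astype(np.uint8)

# Toy batch: four single-channel 2x3 samples, the last one constant.
samples = np.arange(24, dtype=np.float32).reshape(4, 1, 2, 3)
samples[3] = 7.0
frame = to_uint8_frame(samples)
print(frame.shape, frame.dtype)
```

Dividing by the range before multiplying by 255 keeps the per-sample maximum at exactly 255 before the cast, whereas multiplying by `255 / max` can land slightly below 255 and truncate to 254 under `astype(np.uint8)`.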
/trainer.py:
--------------------------------------------------------------------------------
1 | import random
2 | import imageio
3 | import numpy as np
4 | import argparse
5 | import sys
6 | import os
7 | import json
8 | from tqdm.auto import tqdm
9 | import matplotlib.pyplot as plt
10 |
11 | import einops
12 | import torch
13 | import torch.nn as nn
14 | from torch.optim import Adam
15 | from torch.utils.data import DataLoader
16 | from torch.utils.tensorboard import SummaryWriter
17 | from torchvision.utils import make_grid
18 |
19 | from SimulationHelper.simulation import Simulation
20 | from datasets.config_dl import config_dl
21 | from datasets.transform_factory import inv_normalize, transform_factory
22 | from models import ddpm
23 |
24 | class TrainScoreNetwork:
25 | def __init__(self, noise_level_dict, beta_dict, sde, model_type, train_objective, anneal_power=2, loss_power=2, n_val=8, val_dl=None):
26 | self.sde = sde
27 | self.model_type = model_type
28 |
29 | if self.sde == 've':
30 | self.s1, self.sL, self.L = noise_level_dict['s1'], noise_level_dict['sL'], noise_level_dict['L']
31 | self.sigmas = torch.tensor(np.exp(np.linspace(np.log(self.s1),np.log(self.sL), self.L))).type(torch.float32)
32 |
33 | self.model_type = model_type
34 | self.anneal_power = anneal_power
35 | self.loss_power = loss_power
36 | self.train_objective = train_objective
37 | assert train_objective == 'disc' or train_objective == 'cont'
38 |
39 | if val_dl: # then use test dataloader
40 | val_batch = next(iter(val_dl))
41 | self.x_val = val_batch['image'][:n_val]
42 | self.cond_val = val_batch['mask'][:n_val]
43 |
44 | eta_val = torch.randn_like(self.x_val)
45 | self.used_sigmas_val = torch.linspace(self.sigmas[0],self.sigmas[-1], self.x_val.shape[0])[:,None,None,None]
46 | self.z_val = self.x_val + eta_val*self.used_sigmas_val
47 |
48 | elif self.sde == 'vp':
49 | self.beta1, self.betaT, self.T = beta_dict['beta1'], beta_dict['betaT'], beta_dict['T']
50 | self.betas = np.linspace(self.beta1, self.betaT, self.T, dtype=np.float32) # linear schedule taken from beta_dict instead of hard-coded values
51 | self.alphas = 1 - self.betas
52 | self.alpha_bars = torch.from_numpy(np.cumprod(self.alphas)) # cumulative product of the alphas
53 |
54 | if val_dl: # TODO: implement validation
55 | val_batch = next(iter(val_dl))
56 | self.x_val = val_batch['image'][:n_val]
57 | pass
58 |
59 | else:
60 | raise ValueError('Unknown SDE type!')
61 |
62 | def get_grad_norm(self, model):
63 | parameters = [p for p in model.parameters() if p.grad is not None and p.requires_grad]
64 | norms = [p.grad.detach().abs().max().item() for p in parameters]
65 | return np.asarray(norms).max()
66 |
67 | def do(self, scorenet, dl, n_epochs, clip, optim, device, simulation, writer, img_cond, class_label_cond=False):
68 | if self.sde == 've':
69 | self._do_ve(scorenet, dl, n_epochs, clip, optim, device, simulation, writer, img_cond, class_label_cond)
70 |
71 | elif self.sde == 'vp':
72 | self._do_vp(scorenet, dl, n_epochs, clip, optim, device, simulation, writer, img_cond, class_label_cond)
73 |
74 | def _do_ve(self, scorenet, dl, n_epochs, clip, optim, device, simulation, writer, img_cond=0, class_label_cond=False):
75 | if img_cond == 0:
76 | self.cond_val = None
77 | else:
78 | self.cond_val = self.cond_val.to(device)
79 | best_loss = float("inf")
80 |
81 | for epoch in tqdm(range(n_epochs), desc=f"Training progress", colour="#00ff00"):
82 | epoch_loss = 0.0
83 | grad_norms_epoch = []
84 |
85 | for step, batch in enumerate(tqdm(dl, leave=False, desc=f"Epoch {epoch + 1}/{n_epochs}", colour="#005500")):
86 | # Loading data
87 | x = batch['image'].to(device)
88 | cond = None if img_cond==0 else batch['mask'].to(device)
89 | lbl = None if class_label_cond is False else batch['label'].to(device).unsqueeze(1)
90 | n = len(x)
91 |
92 | # noise-conditional score network corruption
93 | if self.train_objective == 'disc':
94 | sigmas_idx = torch.randint(0, self.L, (n,))#.to(device)
95 | used_sigmas = (self.sigmas[sigmas_idx][:,None,None,None]).to(device)
96 | elif self.train_objective == 'cont': # continuous training objective (SDE style)
97 | t = torch.from_numpy(np.random.uniform(1e-5,1,(n,))).float()
98 | used_sigmas = (self.sigmas[-1]*(self.sigmas[0]/self.sigmas[-1])**t)[:,None,None,None].to(device)
99 |
100 | # noise corruption
101 | eta = torch.randn_like(x).to(device)
102 | z = x + eta*used_sigmas.to(device)
103 |
104 | # compute score matching loss
105 | target = 1/(used_sigmas**2) * (x-z)
106 | if self.model_type == 'unet':
107 | scores = scorenet(z, used_sigmas.reshape(n,-1), img_cond=cond, class_lbl=lbl)
108 | elif self.model_type == 'tdv':
109 | scores = scorenet.grad(torch.cat([z,cond],1), used_sigmas.reshape(n,1,1,1))[:,0:1]
110 |
111 | if step % 100 == 0: # Sanity check: what is going into the network?
112 | with torch.no_grad():
113 | scorenet.eval()
114 | if self.x_val is not None: # always take same val/test batch
115 | if self.model_type == 'unet':
116 | scores_val = scorenet(self.z_val.to(device), self.used_sigmas_val.to(device).reshape(self.z_val.shape[0],-1), img_cond=self.cond_val, class_lbl=lbl)
117 | elif self.model_type == 'tdv':
118 | scores_val = scorenet.grad(torch.cat([self.z_val.to(device),self.cond_val],1), self.used_sigmas_val.to(device).reshape(self.z_val.shape[0],1,1,1))[:,0:1]
119 |
120 | x_mmse_val = self.z_val.to(device) + self.used_sigmas_val.to(device)**2 * scores_val
121 |
122 | # for multi-class plotting, show a single fixed class (index 4)
123 | if self.x_val.shape[1] > 1:
124 | class_idx = 4
125 | x_val, z_val, x_mmse_val = self.x_val[:,class_idx][:,None], self.z_val[:,class_idx][:,None], x_mmse_val[:,class_idx][:,None]
126 | else:
127 | x_val, z_val = self.x_val, self.z_val
128 |
129 | all_stacked = torch.cat([
130 | make_grid(x_val, nrow=x_val.shape[0], normalize=True, scale_each=True).cpu(),
131 | make_grid(z_val, nrow=x_val.shape[0], normalize=True,scale_each=True).cpu(),
132 | make_grid(x_mmse_val, nrow=self.x_val.shape[0], normalize=True, scale_each=True).cpu()], dim=1)
133 |
134 | else: # check on random input data
135 | x_mmse = z + used_sigmas**2*scores
136 | all_stacked = torch.cat([
137 | make_grid(x, nrow=x.shape[0], normalize=True, scale_each=True).cpu(),
138 | make_grid(z, nrow=x.shape[0], normalize=True,scale_each=True).cpu(),
139 | make_grid(x_mmse, nrow=x.shape[0], normalize=True, scale_each=True).cpu()], dim=1)
140 |
141 | # plot clean, noisy, and denoised (using Tweedie's formula)
142 | if step % 100 == 0:
143 | writer.add_image(f'training', all_stacked, global_step=epoch)
144 | writer.flush()
145 | scorenet.train()
146 |
147 | # Denoising score matching loss: sigma-weighted distance between the target score and the predicted score
148 | loss_batches = ((torch.abs(target - scores))**self.loss_power).sum((-3,-2,-1))*used_sigmas.squeeze()**self.anneal_power # NOTE: loss_power and anneal_power should match (e.g. both 1 for an L1 loss)
149 | loss = loss_batches.mean()
150 |
151 | optim.zero_grad()
152 | loss.backward()
153 |
154 | if isinstance(clip,float):
155 | torch.nn.utils.clip_grad_norm_(scorenet.parameters(), max_norm=clip, norm_type='inf')
156 | grad_norms_epoch.append(self.get_grad_norm(scorenet))
157 |
158 | optim.step()
159 | epoch_loss += loss.item() * len(x) / len(dl.dataset)
160 |
161 |
162 | log_string = f"Loss at epoch {epoch + 1}: {epoch_loss:.8f}"
163 | if epoch % 50 == 0:
164 | writer.add_scalar(f'train/epoch_loss', epoch_loss, epoch)
165 |
166 | writer.add_scalar(f'train/epoch_max_grad', np.asarray(grad_norms_epoch).max(), epoch)
167 | writer.add_scalar(f'train/epoch_mean_grad', np.asarray(grad_norms_epoch).mean(), epoch)
168 |
169 | # Storing the model
170 | if epoch % 5000 == 0: # additionally save a checkpoint every 5000 epochs, regardless of the loss
171 | checkpoint = {'state_dict': scorenet.state_dict()}
172 | simulation.save_pytorch(checkpoint, overwrite=False, subdir='models_sanity', epoch='_'+'{0:07}'.format(epoch))
173 |
174 | if best_loss > epoch_loss:
175 | best_loss = epoch_loss
176 |
177 | # save last 3 checkpoints
178 | if epoch > 0:
179 | cp_dir = simulation._outdir + '/models'
180 | if len([name for name in os.listdir(cp_dir) if os.path.isfile(os.path.join(cp_dir,name))]) == 3:
181 | fnames = sorted([fname for fname in os.listdir(cp_dir) if fname.endswith('.pt')])
182 | os.remove(os.path.join(cp_dir,fnames[0]))
183 | checkpoint = {'epoch': epoch,
184 | 'state_dict': scorenet.state_dict(),
185 | 'optimizer': optim.state_dict()}
186 | simulation.save_pytorch(checkpoint, overwrite=False, epoch='_'+'{0:07}'.format(epoch))
187 | log_string += " --> Best model ever (stored)"
188 |
189 | print(log_string)
190 |
191 | def _do_vp(self, scorenet, dl, n_epochs, clip, optim, device, simulation, writer, img_cond, class_label_cond=False):
192 | best_loss = float("inf")
193 | mse = nn.MSELoss()
194 |
195 | for epoch in tqdm(range(n_epochs), desc=f"Training progress", colour="#00ff00"):
196 | epoch_loss = 0.0
197 | grad_norms_epoch = []
198 |
199 | for step, batch in enumerate(tqdm(dl, leave=False, desc=f"Epoch {epoch + 1}/{n_epochs}", colour="#005500")):
200 | # Loading data
201 | m = batch['image'].to(device)
202 | m *= 5. # TODO: remove if the data loader does *not* scale the mask to [-0.2,0.2]; this rescales the mask to [-1,1] like the image
203 | x = None if img_cond==0 else batch['mask'].to(device)
204 | lbl = None if class_label_cond is False else batch['label'].to(device).unsqueeze(1)
205 | n = len(m)
206 |
207 | # noise corruption
208 | t = torch.randint(0, self.T, (n,)).to(device)
209 | a_bar = self.alpha_bars.to(device)[t]#.to(x.device)
210 | eta = torch.randn_like(m) # noise must match the target m; x is the conditioning image and is None when img_cond==0
211 | m_noisy = a_bar.sqrt().reshape(n, 1, 1, 1) * m + (1 - a_bar).sqrt().reshape(n, 1, 1, 1) * eta
212 |
213 | # compute score matching loss
214 | if self.model_type == 'unet':
215 | eta_estimated = scorenet(m_noisy, t.reshape(n,-1), img_cond=x, class_lbl=lbl)
216 |
217 | elif self.model_type == 'uvit':
218 | eta_estimated = scorenet(m_noisy, t.reshape(n,-1), img_cond=x)
219 |
220 | elif self.model_type == 'tdv':
221 | raise NotImplementedError
222 |
223 | # Optimize the MSE between the injected noise and the predicted noise
224 | loss = mse(eta_estimated, eta)
225 | optim.zero_grad()
226 | loss.backward()
227 |
228 | if isinstance(clip,float):
229 | torch.nn.utils.clip_grad_norm_(scorenet.parameters(), max_norm=clip, norm_type='inf')
230 | grad_norms_epoch.append(self.get_grad_norm(scorenet))
231 |
232 | optim.step()
233 | epoch_loss += loss.item() * n / len(dl.dataset) # use n = len(m); x is None when img_cond==0
234 |
235 | log_string = f"Loss at epoch {epoch + 1}: {epoch_loss:.8f}"
236 | if epoch % 10 == 0:
237 | writer.add_scalar(f'train/epoch_loss', epoch_loss, epoch)
238 | writer.add_scalar(f'train/epoch_max_grad', np.asarray(grad_norms_epoch).max(), epoch)
239 | writer.add_scalar(f'train/epoch_mean_grad', np.asarray(grad_norms_epoch).mean(), epoch)
240 |
241 | if best_loss > epoch_loss:
242 | best_loss = epoch_loss
243 | # save last 3 checkpoints
244 | if epoch > 0:
245 | cp_dir = simulation._outdir + '/models'
246 | if len([name for name in os.listdir(cp_dir) if os.path.isfile(os.path.join(cp_dir,name))]) == 3:
247 | fnames = sorted([fname for fname in os.listdir(cp_dir) if fname.endswith('.pt')])
248 | os.remove(os.path.join(cp_dir,fnames[0]))
249 | checkpoint = {'epoch': epoch,
250 | 'state_dict': scorenet.state_dict(),
251 | 'optimizer': optim.state_dict()}
252 | simulation.save_pytorch(checkpoint, overwrite=False, epoch='_'+'{0:07}'.format(epoch))
253 | log_string += " --> Best model ever (stored)"
254 | print(log_string)
--------------------------------------------------------------------------------
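The two corruption schemes used in trainer.py above can be sketched side by side. This is a minimal NumPy stand-in under stated assumptions, not the training code: no network is involved, the schedule values (5.0, 0.01, 10 levels; 1e-4, 0.02, 1000 steps) and the toy batch are hypothetical, and the helper names `ve_sigmas` and `vp_alpha_bars` are made up for illustration. For the VE case it also checks that plugging the exact target score into Tweedie's formula recovers the clean input.

```python
import numpy as np

def ve_sigmas(s1: float, sL: float, L: int) -> np.ndarray:
    """Geometric sequence of noise levels from s1 down to sL (the 've' case)."""
    return np.exp(np.linspace(np.log(s1), np.log(sL), L)).astype(np.float32)

def vp_alpha_bars(beta1: float, betaT: float, T: int) -> np.ndarray:
    """Cumulative products of (1 - beta_t) for a linear beta schedule ('vp')."""
    betas = np.linspace(beta1, betaT, T, dtype=np.float32)
    return np.cumprod(1.0 - betas)

rng = np.random.default_rng(0)
# Stand-in for a batch of SDF masks with shape (n, c, h, w).
x = rng.uniform(-0.2, 0.2, size=(4, 1, 8, 8)).astype(np.float32)
eta = rng.standard_normal(x.shape).astype(np.float32)

# Variance-exploding corruption: z = x + sigma * eta.
sigmas = ve_sigmas(5.0, 0.01, 10)  # hypothetical schedule values
idx = rng.integers(0, len(sigmas), size=len(x))
used_sigmas = sigmas[idx][:, None, None, None]
z = x + used_sigmas * eta
target = (x - z) / used_sigmas**2        # score of the Gaussian perturbation kernel
x_tweedie = z + used_sigmas**2 * target  # Tweedie's formula with the exact score

# Variance-preserving corruption: sqrt(a_bar) * x + sqrt(1 - a_bar) * eta.
alpha_bars = vp_alpha_bars(1e-4, 0.02, 1000)  # hypothetical schedule values
t = rng.integers(0, len(alpha_bars), size=len(x))
a_bar = alpha_bars[t][:, None, None, None]
m_noisy = np.sqrt(a_bar) * x + np.sqrt(1.0 - a_bar) * eta

print(sigmas[0], sigmas[-1], alpha_bars[0], alpha_bars[-1])
```

In the VE case the regression target is the score itself, so the loss is weighted by a power of sigma. In the VP case the network regresses the noise `eta` directly, which is why a plain MSE suffices there.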