├── .gitignore ├── LICENSE.md ├── README.md ├── assets └── architecture.png ├── checkpoints └── segment_retinaUCSF_seed1.pty ├── comparison └── interpolation │ └── run_baseline_interp.py ├── data └── placeholder.md ├── external_src ├── I2SB │ ├── __init__.py │ ├── guided_diffusion │ │ ├── LICENSE_GUIDED_DIFFUSION │ │ ├── __init__.py │ │ ├── fp16_util.py │ │ ├── gaussian_diffusion.py │ │ ├── logger.py │ │ ├── losses.py │ │ ├── nn.py │ │ ├── respace.py │ │ ├── script_util.py │ │ └── unet.py │ └── i2sb │ │ ├── __init__.py │ │ ├── ckpt_util.py │ │ ├── diffusion.py │ │ ├── network.py │ │ ├── runner.py │ │ └── util.py └── SuperRetina │ ├── README.md │ ├── common │ ├── common_util.py │ ├── eval_util.py │ └── train_util.py │ ├── config │ ├── test.yaml │ ├── test_VARIA.yaml │ └── train.yaml │ ├── dataset │ └── retina_dataset.py │ ├── loss │ ├── dice_loss.py │ └── triplet_loss.py │ ├── model │ ├── pke_module.py │ ├── record_module.py │ └── super_retina.py │ ├── predictor.py │ ├── requirements.txt │ ├── save │ └── .placehold │ ├── test_on_FIRE.py │ ├── test_on_VARIA.py │ └── train.py ├── results └── placeholder.md └── src ├── data_utils ├── __init__.py ├── extend.py ├── prepare_dataset.py └── split.py ├── datasets ├── __init__.py ├── brain_gbm.py ├── brain_ms.py ├── retina_areds.py ├── retina_ucsf.py └── synthetic.py ├── nn ├── __init__.py ├── autoencoder.py ├── autoencoder_ode.py ├── autoencoder_t_emb.py ├── aux_net.py ├── base.py ├── common_encoder.py ├── imageflownet_ode.py ├── imageflownet_sde.py ├── nn_utils.py ├── off_the_shelf_encoder.py ├── scheduler.py ├── unet_i2sb.py ├── unet_ode.py ├── unet_ode_simple.py ├── unet_ode_simple_position_parametrized.py ├── unet_sode.py └── unet_t_emb.py ├── plotting └── demo_gradient_field.py ├── preprocessing ├── 01_preprocess_brain_GBM.py ├── 01_preprocess_brain_MS.py ├── 01_preprocess_retina_AREDS.py ├── 01_preprocess_retina_UCSF.py ├── 02_register_retina_AREDS.py ├── 02_register_retina_UCSF.py ├── 03_crop_retina_UCSF.py ├── 03_generate_eye_mask_retina_AREDS.py ├── archive_register_retina_AREDS.py ├── deprecated_03_crop_retina_AREDS.py ├── synthesize_dataset.py └── test_registration.py ├── scripts └── test_time_optimization.py ├── train_2pt_all.py ├── train_npt_cde.py ├── train_npt_sode.py ├── train_predictor.py ├── train_segmentor.py └── utils ├── __init__.py ├── attribute_hashmap.py ├── early_stop.py ├── log_util.py ├── metrics.py ├── parse.py └── seed.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | data 3 | **/__pycache__ 4 | 5 | # Figures 6 | **/*.png 7 | 8 | # Models 9 | **/*.pt 10 | **/*.pkl 11 | **/*.pth 12 | **/*.pth.tar 13 | 14 | # Data 15 | **/*.npy 16 | **/*.npz 17 | 18 | # Slurm jobs 19 | **/*.out 20 | **/bash_* -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Non-Commercial License Yale Copyright © 2024 Yale University. 2 | 3 | Permission is hereby granted to use, copy, modify, and distribute this Software for any non-commercial purpose. Any distribution or modification or derivations of the Software (together “Derivative Works”) must be made available on GitHub and shall include this copyright notice and this permission notice in all copies or substantial portions of the Software. 
For the purposes of this license, "non-commercial" means not intended for or directed towards commercial advantage or monetary compensation either via the Software itself or Derivative Works or uses of either which lead to or generate any commercial products. In any event, the use and modification of the Software or Derivative Works shall remain governed by the terms and conditions of this Agreement; Any commercial use of the Software requires a separate commercial license from the copyright holder at Yale University. Direct any requests for commercial licenses to Yale Ventures at yaleventures@yale.edu. 4 | 5 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 6 | -------------------------------------------------------------------------------- /assets/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrishnaswamyLab/ImageFlowNet/61412a9a014776e653de5bd713b6b817852be3d8/assets/architecture.png -------------------------------------------------------------------------------- /checkpoints/segment_retinaUCSF_seed1.pty: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrishnaswamyLab/ImageFlowNet/61412a9a014776e653de5bd713b6b817852be3d8/checkpoints/segment_retinaUCSF_seed1.pty -------------------------------------------------------------------------------- /data/placeholder.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrishnaswamyLab/ImageFlowNet/61412a9a014776e653de5bd713b6b817852be3d8/data/placeholder.md -------------------------------------------------------------------------------- /external_src/I2SB/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrishnaswamyLab/ImageFlowNet/61412a9a014776e653de5bd713b6b817852be3d8/external_src/I2SB/__init__.py -------------------------------------------------------------------------------- /external_src/I2SB/guided_diffusion/LICENSE_GUIDED_DIFFUSION: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /external_src/I2SB/guided_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Taken from the following link as is from: 3 | # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/__init__.py 4 | # 5 | # The license for the original version of this file can be 6 | # found in this directory (LICENSE_GUIDED_DIFFUSION). 7 | # --------------------------------------------------------------- 8 | 9 | """ 10 | Codebase for "Improved Denoising Diffusion Probabilistic Models". 11 | """ 12 | -------------------------------------------------------------------------------- /external_src/I2SB/guided_diffusion/fp16_util.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Taken from the following link as is from: 3 | # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/fp16_util.py 4 | # 5 | # The license for the original version of this file can be 6 | # found in this directory (LICENSE_GUIDED_DIFFUSION). 7 | # --------------------------------------------------------------- 8 | 9 | """ 10 | Helpers to train with 16-bit precision. 11 | """ 12 | 13 | import numpy as np 14 | import torch as th 15 | import torch.nn as nn 16 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 17 | 18 | from . import logger 19 | 20 | INITIAL_LOG_LOSS_SCALE = 20.0 21 | 22 | 23 | def convert_module_to_f16(l): 24 | """ 25 | Convert primitive modules to float16. 26 | """ 27 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 28 | l.weight.data = l.weight.data.half() 29 | if l.bias is not None: 30 | l.bias.data = l.bias.data.half() 31 | 32 | 33 | def convert_module_to_f32(l): 34 | """ 35 | Convert primitive modules to float32, undoing convert_module_to_f16(). 36 | """ 37 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 38 | l.weight.data = l.weight.data.float() 39 | if l.bias is not None: 40 | l.bias.data = l.bias.data.float() 41 | 42 | 43 | def make_master_params(param_groups_and_shapes): 44 | """ 45 | Copy model parameters into a (differently-shaped) list of full-precision 46 | parameters. 47 | """ 48 | master_params = [] 49 | for param_group, shape in param_groups_and_shapes: 50 | master_param = nn.Parameter( 51 | _flatten_dense_tensors( 52 | [param.detach().float() for (_, param) in param_group] 53 | ).view(shape) 54 | ) 55 | master_param.requires_grad = True 56 | master_params.append(master_param) 57 | return master_params 58 | 59 | 60 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 61 | """ 62 | Copy the gradients from the model parameters into the master parameters 63 | from make_master_params(). 
64 | """ 65 | for master_param, (param_group, shape) in zip( 66 | master_params, param_groups_and_shapes 67 | ): 68 | master_param.grad = _flatten_dense_tensors( 69 | [param_grad_or_zeros(param) for (_, param) in param_group] 70 | ).view(shape) 71 | 72 | 73 | def master_params_to_model_params(param_groups_and_shapes, master_params): 74 | """ 75 | Copy the master parameter data back into the model parameters. 76 | """ 77 | # Without copying to a list, if a generator is passed, this will 78 | # silently not copy any parameters. 79 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 80 | for (_, param), unflat_master_param in zip( 81 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 82 | ): 83 | param.detach().copy_(unflat_master_param) 84 | 85 | 86 | def unflatten_master_params(param_group, master_param): 87 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 88 | 89 | 90 | def get_param_groups_and_shapes(named_model_params): 91 | named_model_params = list(named_model_params) 92 | scalar_vector_named_params = ( 93 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 94 | (-1), 95 | ) 96 | matrix_named_params = ( 97 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 98 | (1, -1), 99 | ) 100 | return [scalar_vector_named_params, matrix_named_params] 101 | 102 | 103 | def master_params_to_state_dict( 104 | model, param_groups_and_shapes, master_params, use_fp16 105 | ): 106 | if use_fp16: 107 | state_dict = model.state_dict() 108 | for master_param, (param_group, _) in zip( 109 | master_params, param_groups_and_shapes 110 | ): 111 | for (name, _), unflat_master_param in zip( 112 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 113 | ): 114 | assert name in state_dict 115 | state_dict[name] = unflat_master_param 116 | else: 117 | state_dict = model.state_dict() 118 | for i, (name, _value) in enumerate(model.named_parameters()): 119 | assert name in state_dict 120 | state_dict[name] = master_params[i] 121 | return state_dict 122 | 123 | 124 | def state_dict_to_master_params(model, state_dict, use_fp16): 125 | if use_fp16: 126 | named_model_params = [ 127 | (name, state_dict[name]) for name, _ in model.named_parameters() 128 | ] 129 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) 130 | master_params = make_master_params(param_groups_and_shapes) 131 | else: 132 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 133 | return master_params 134 | 135 | 136 | def zero_master_grads(master_params): 137 | for param in master_params: 138 | param.grad = None 139 | 140 | 141 | def zero_grad(model_params): 142 | for param in model_params: 143 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 144 | if param.grad is not None: 145 | param.grad.detach_() 146 | param.grad.zero_() 147 | 148 | 149 | def param_grad_or_zeros(param): 150 | if param.grad is not None: 151 | return param.grad.data.detach() 152 | else: 153 | return th.zeros_like(param) 154 | 155 | 156 | class MixedPrecisionTrainer: 157 | def __init__( 158 | self, 159 | *, 160 | model, 161 | use_fp16=False, 162 | fp16_scale_growth=1e-3, 163 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 164 | ): 165 | self.model = model 166 | self.use_fp16 = use_fp16 167 | self.fp16_scale_growth = fp16_scale_growth 168 | 169 | self.model_params = list(self.model.parameters()) 170 | self.master_params = self.model_params 171 | 
self.param_groups_and_shapes = None 172 | self.lg_loss_scale = initial_lg_loss_scale 173 | 174 | if self.use_fp16: 175 | self.param_groups_and_shapes = get_param_groups_and_shapes( 176 | self.model.named_parameters() 177 | ) 178 | self.master_params = make_master_params(self.param_groups_and_shapes) 179 | self.model.convert_to_fp16() 180 | 181 | def zero_grad(self): 182 | zero_grad(self.model_params) 183 | 184 | def backward(self, loss: th.Tensor): 185 | if self.use_fp16: 186 | loss_scale = 2 ** self.lg_loss_scale 187 | (loss * loss_scale).backward() 188 | else: 189 | loss.backward() 190 | 191 | def optimize(self, opt: th.optim.Optimizer): 192 | if self.use_fp16: 193 | return self._optimize_fp16(opt) 194 | else: 195 | return self._optimize_normal(opt) 196 | 197 | def _optimize_fp16(self, opt: th.optim.Optimizer): 198 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) 199 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) 200 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale) 201 | if check_overflow(grad_norm): 202 | self.lg_loss_scale -= 1 203 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 204 | zero_master_grads(self.master_params) 205 | return False 206 | 207 | logger.logkv_mean("grad_norm", grad_norm) 208 | logger.logkv_mean("param_norm", param_norm) 209 | 210 | for p in self.master_params: 211 | p.grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 212 | opt.step() 213 | zero_master_grads(self.master_params) 214 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params) 215 | self.lg_loss_scale += self.fp16_scale_growth 216 | return True 217 | 218 | def _optimize_normal(self, opt: th.optim.Optimizer): 219 | grad_norm, param_norm = self._compute_norms() 220 | logger.logkv_mean("grad_norm", grad_norm) 221 | logger.logkv_mean("param_norm", param_norm) 222 | opt.step() 223 | return True 224 | 225 | def _compute_norms(self, grad_scale=1.0): 226 | grad_norm = 0.0 227 | param_norm = 0.0 228 | for p in self.master_params: 229 | with th.no_grad(): 230 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 231 | if p.grad is not None: 232 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 233 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) 234 | 235 | def master_params_to_state_dict(self, master_params): 236 | return master_params_to_state_dict( 237 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16 238 | ) 239 | 240 | def state_dict_to_master_params(self, state_dict): 241 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16) 242 | 243 | 244 | def check_overflow(value): 245 | return (value == float("inf")) or (value == -float("inf")) or (value != value) 246 | -------------------------------------------------------------------------------- /external_src/I2SB/guided_diffusion/losses.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Taken from the following link as is from: 3 | # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/losses.py 4 | # 5 | # The license for the original version of this file can be 6 | # found in this directory (LICENSE_GUIDED_DIFFUSION). 7 | # --------------------------------------------------------------- 8 | 9 | """ 10 | Helpers for various likelihood-based losses. These are ported from the original 11 | Ho et al. 
diffusion models codebase: 12 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 13 | """ 14 | 15 | import numpy as np 16 | 17 | import torch as th 18 | 19 | 20 | def normal_kl(mean1, logvar1, mean2, logvar2): 21 | """ 22 | Compute the KL divergence between two gaussians. 23 | 24 | Shapes are automatically broadcasted, so batches can be compared to 25 | scalars, among other use cases. 26 | """ 27 | tensor = None 28 | for obj in (mean1, logvar1, mean2, logvar2): 29 | if isinstance(obj, th.Tensor): 30 | tensor = obj 31 | break 32 | assert tensor is not None, "at least one argument must be a Tensor" 33 | 34 | # Force variances to be Tensors. Broadcasting helps convert scalars to 35 | # Tensors, but it does not work for th.exp(). 36 | logvar1, logvar2 = [ 37 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 38 | for x in (logvar1, logvar2) 39 | ] 40 | 41 | return 0.5 * ( 42 | -1.0 43 | + logvar2 44 | - logvar1 45 | + th.exp(logvar1 - logvar2) 46 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 47 | ) 48 | 49 | 50 | def approx_standard_normal_cdf(x): 51 | """ 52 | A fast approximation of the cumulative distribution function of the 53 | standard normal. 54 | """ 55 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 56 | 57 | 58 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 59 | """ 60 | Compute the log-likelihood of a Gaussian distribution discretizing to a 61 | given image. 62 | 63 | :param x: the target images. It is assumed that this was uint8 values, 64 | rescaled to the range [-1, 1]. 65 | :param means: the Gaussian mean Tensor. 66 | :param log_scales: the Gaussian log stddev Tensor. 67 | :return: a tensor like x of log probabilities (in nats). 68 | """ 69 | assert x.shape == means.shape == log_scales.shape 70 | centered_x = x - means 71 | inv_stdv = th.exp(-log_scales) 72 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 73 | cdf_plus = approx_standard_normal_cdf(plus_in) 74 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 75 | cdf_min = approx_standard_normal_cdf(min_in) 76 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 77 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 78 | cdf_delta = cdf_plus - cdf_min 79 | log_probs = th.where( 80 | x < -0.999, 81 | log_cdf_plus, 82 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 83 | ) 84 | assert log_probs.shape == x.shape 85 | return log_probs 86 | -------------------------------------------------------------------------------- /external_src/I2SB/guided_diffusion/nn.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Taken from the following link as is from: 3 | # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/nn.py 4 | # 5 | # The license for the original version of this file can be 6 | # found in this directory (LICENSE_GUIDED_DIFFUSION). 7 | # --------------------------------------------------------------- 8 | 9 | """ 10 | Various utilities for neural networks. 11 | """ 12 | 13 | import math 14 | 15 | import torch as th 16 | import torch.nn as nn 17 | 18 | 19 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 
20 | class SiLU(nn.Module): 21 | def forward(self, x): 22 | return x * th.sigmoid(x) 23 | 24 | 25 | class GroupNorm32(nn.GroupNorm): 26 | def forward(self, x): 27 | return super().forward(x.float()).type(x.dtype) 28 | 29 | 30 | def conv_nd(dims, *args, **kwargs): 31 | """ 32 | Create a 1D, 2D, or 3D convolution module. 33 | """ 34 | if dims == 1: 35 | return nn.Conv1d(*args, **kwargs) 36 | elif dims == 2: 37 | return nn.Conv2d(*args, **kwargs) 38 | elif dims == 3: 39 | return nn.Conv3d(*args, **kwargs) 40 | raise ValueError(f"unsupported dimensions: {dims}") 41 | 42 | 43 | def linear(*args, **kwargs): 44 | """ 45 | Create a linear module. 46 | """ 47 | return nn.Linear(*args, **kwargs) 48 | 49 | 50 | def avg_pool_nd(dims, *args, **kwargs): 51 | """ 52 | Create a 1D, 2D, or 3D average pooling module. 53 | """ 54 | if dims == 1: 55 | return nn.AvgPool1d(*args, **kwargs) 56 | elif dims == 2: 57 | return nn.AvgPool2d(*args, **kwargs) 58 | elif dims == 3: 59 | return nn.AvgPool3d(*args, **kwargs) 60 | raise ValueError(f"unsupported dimensions: {dims}") 61 | 62 | 63 | def update_ema(target_params, source_params, rate=0.99): 64 | """ 65 | Update target parameters to be closer to those of source parameters using 66 | an exponential moving average. 67 | 68 | :param target_params: the target parameter sequence. 69 | :param source_params: the source parameter sequence. 70 | :param rate: the EMA rate (closer to 1 means slower). 71 | """ 72 | for targ, src in zip(target_params, source_params): 73 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 74 | 75 | 76 | def zero_module(module): 77 | """ 78 | Zero out the parameters of a module and return it. 79 | """ 80 | for p in module.parameters(): 81 | p.detach().zero_() 82 | return module 83 | 84 | 85 | def scale_module(module, scale): 86 | """ 87 | Scale the parameters of a module and return it. 88 | """ 89 | for p in module.parameters(): 90 | p.detach().mul_(scale) 91 | return module 92 | 93 | 94 | def mean_flat(tensor): 95 | """ 96 | Take the mean over all non-batch dimensions. 97 | """ 98 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 99 | 100 | 101 | def normalization(channels): 102 | """ 103 | Make a standard normalization layer. 104 | 105 | :param channels: number of input channels. 106 | :return: an nn.Module for normalization. 107 | """ 108 | return GroupNorm32(32, channels) 109 | 110 | 111 | def timestep_embedding(timesteps, dim, max_period=10000): 112 | """ 113 | Create sinusoidal timestep embeddings. 114 | 115 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 116 | These may be fractional. 117 | :param dim: the dimension of the output. 118 | :param max_period: controls the minimum frequency of the embeddings. 119 | :return: an [N x dim] Tensor of positional embeddings. 120 | """ 121 | half = dim // 2 122 | freqs = th.exp( 123 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 124 | ).to(device=timesteps.device) 125 | args = timesteps[:, None].float() * freqs[None] 126 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 127 | if dim % 2: 128 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 129 | return embedding 130 | 131 | 132 | def checkpoint(func, inputs, params, flag): 133 | """ 134 | Evaluate a function without caching intermediate activations, allowing for 135 | reduced memory at the expense of extra compute in the backward pass. 136 | 137 | :param func: the function to evaluate. 138 | :param inputs: the argument sequence to pass to `func`. 
139 | :param params: a sequence of parameters `func` depends on but does not 140 | explicitly take as arguments. 141 | :param flag: if False, disable gradient checkpointing. 142 | """ 143 | if flag and any(param.requires_grad for param in params): 144 | args = tuple(inputs) + tuple(params) 145 | return CheckpointFunction.apply(func, len(inputs), *args) 146 | else: 147 | return func(*inputs) 148 | 149 | 150 | class CheckpointFunction(th.autograd.Function): 151 | @staticmethod 152 | def forward(ctx, run_function, length, *args): 153 | ctx.run_function = run_function 154 | ctx.input_tensors = list(args[:length]) 155 | ctx.input_params = list(args[length:]) 156 | with th.no_grad(): 157 | output_tensors = ctx.run_function(*ctx.input_tensors) 158 | return output_tensors 159 | 160 | @staticmethod 161 | def backward(ctx, *output_grads): 162 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 163 | with th.enable_grad(): 164 | # Fixes a bug where the first op in run_function modifies the 165 | # Tensor storage in place, which is not allowed for detach()'d 166 | # Tensors. 167 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 168 | output_tensors = ctx.run_function(*shallow_copies) 169 | input_grads = th.autograd.grad( 170 | output_tensors, 171 | ctx.input_tensors + ctx.input_params, 172 | output_grads, 173 | allow_unused=True, 174 | ) 175 | del ctx.input_tensors 176 | del ctx.input_params 177 | del output_tensors 178 | return (None, None) + input_grads 179 | -------------------------------------------------------------------------------- /external_src/I2SB/guided_diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Taken from the following link as is from: 3 | # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py 4 | # 5 | # The license for the original version of this file can be 6 | # found in this directory (LICENSE_GUIDED_DIFFUSION). 7 | # --------------------------------------------------------------- 8 | 9 | import numpy as np 10 | import torch as th 11 | 12 | from .gaussian_diffusion import GaussianDiffusion 13 | 14 | 15 | def space_timesteps(num_timesteps, section_counts): 16 | """ 17 | Create a list of timesteps to use from an original diffusion process, 18 | given the number of timesteps we want to take from equally-sized portions 19 | of the original process. 20 | 21 | For example, if there's 300 timesteps and the section counts are [10,15,20] 22 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 23 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 24 | 25 | If the stride is a string starting with "ddim", then the fixed striding 26 | from the DDIM paper is used, and only one section is allowed. 27 | 28 | :param num_timesteps: the number of diffusion steps in the original 29 | process to divide up. 30 | :param section_counts: either a list of numbers, or a string containing 31 | comma-separated numbers, indicating the step count 32 | per section. As a special case, use "ddimN" where N 33 | is a number of steps to use the striding from the 34 | DDIM paper. 35 | :return: a set of diffusion steps from the original process to use. 
36 | """ 37 | if isinstance(section_counts, str): 38 | if section_counts.startswith("ddim"): 39 | desired_count = int(section_counts[len("ddim") :]) 40 | for i in range(1, num_timesteps): 41 | if len(range(0, num_timesteps, i)) == desired_count: 42 | return set(range(0, num_timesteps, i)) 43 | raise ValueError( 44 | f"cannot create exactly {num_timesteps} steps with an integer stride" 45 | ) 46 | section_counts = [int(x) for x in section_counts.split(",")] 47 | size_per = num_timesteps // len(section_counts) 48 | extra = num_timesteps % len(section_counts) 49 | start_idx = 0 50 | all_steps = [] 51 | for i, section_count in enumerate(section_counts): 52 | size = size_per + (1 if i < extra else 0) 53 | if size < section_count: 54 | raise ValueError( 55 | f"cannot divide section of {size} steps into {section_count}" 56 | ) 57 | if section_count <= 1: 58 | frac_stride = 1 59 | else: 60 | frac_stride = (size - 1) / (section_count - 1) 61 | cur_idx = 0.0 62 | taken_steps = [] 63 | for _ in range(section_count): 64 | taken_steps.append(start_idx + round(cur_idx)) 65 | cur_idx += frac_stride 66 | all_steps += taken_steps 67 | start_idx += size 68 | return set(all_steps) 69 | 70 | 71 | class SpacedDiffusion(GaussianDiffusion): 72 | """ 73 | A diffusion process which can skip steps in a base diffusion process. 74 | 75 | :param use_timesteps: a collection (sequence or set) of timesteps from the 76 | original diffusion process to retain. 77 | :param kwargs: the kwargs to create the base diffusion process. 78 | """ 79 | 80 | def __init__(self, use_timesteps, **kwargs): 81 | self.use_timesteps = set(use_timesteps) 82 | self.timestep_map = [] 83 | self.original_num_steps = len(kwargs["betas"]) 84 | 85 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 86 | last_alpha_cumprod = 1.0 87 | new_betas = [] 88 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 89 | if i in self.use_timesteps: 90 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 91 | last_alpha_cumprod = alpha_cumprod 92 | self.timestep_map.append(i) 93 | kwargs["betas"] = np.array(new_betas) 94 | super().__init__(**kwargs) 95 | 96 | def p_mean_variance( 97 | self, model, *args, **kwargs 98 | ): # pylint: disable=signature-differs 99 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 100 | 101 | def training_losses( 102 | self, model, *args, **kwargs 103 | ): # pylint: disable=signature-differs 104 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 105 | 106 | def condition_mean(self, cond_fn, *args, **kwargs): 107 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 108 | 109 | def condition_score(self, cond_fn, *args, **kwargs): 110 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 111 | 112 | def _wrap_model(self, model): 113 | if isinstance(model, _WrappedModel): 114 | return model 115 | return _WrappedModel( 116 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 117 | ) 118 | 119 | def _scale_timesteps(self, t): 120 | # Scaling is done by the wrapped model. 
121 | return t 122 | 123 | 124 | class _WrappedModel: 125 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 126 | self.model = model 127 | self.timestep_map = timestep_map 128 | self.rescale_timesteps = rescale_timesteps 129 | self.original_num_steps = original_num_steps 130 | 131 | def __call__(self, x, ts, **kwargs): 132 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 133 | new_ts = map_tensor[ts] 134 | if self.rescale_timesteps: 135 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 136 | return self.model(x, new_ts, **kwargs) 137 | -------------------------------------------------------------------------------- /external_src/I2SB/i2sb/__init__.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for I2SB. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | from .runner import Runner 9 | from .ckpt_util import download_ckpt, download -------------------------------------------------------------------------------- /external_src/I2SB/i2sb/ckpt_util.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for I2SB. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | import os 9 | import requests 10 | from tqdm import tqdm 11 | 12 | import pickle 13 | 14 | import torch 15 | 16 | from guided_diffusion.script_util import ( 17 | model_and_diffusion_defaults, 18 | create_model, 19 | args_to_dict, 20 | ) 21 | 22 | from argparse import Namespace 23 | 24 | from pathlib import Path 25 | 26 | 27 | ADM_IMG256_UNCOND_CKPT = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt" 28 | I2SB_IMG256_UNCOND_PKL = "256x256_diffusion_uncond_fixedsigma.pkl" 29 | I2SB_IMG256_UNCOND_CKPT = "256x256_diffusion_uncond_fixedsigma.pt" 30 | I2SB_IMG256_COND_PKL = "256x256_diffusion_cond_fixedsigma.pkl" 31 | I2SB_IMG256_COND_CKPT = "256x256_diffusion_cond_fixedsigma.pt" 32 | 33 | def download(url, local_path, chunk_size=1024): 34 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 35 | with requests.get(url, stream=True) as r: 36 | total_size = int(r.headers.get("content-length", 0)) 37 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 38 | with open(local_path, "wb") as f: 39 | for data in r.iter_content(chunk_size=chunk_size): 40 | if data: 41 | f.write(data) 42 | pbar.update(chunk_size) 43 | 44 | def create_argparser(): 45 | return Namespace( 46 | attention_resolutions='32,16,8', 47 | batch_size=4, 48 | channel_mult='', 49 | class_cond=False, 50 | clip_denoised=True, 51 | diffusion_steps=1000, 52 | dropout=0.0, 53 | image_size=256, 54 | learn_sigma=True, 55 | adm_ckpt='256x256_diffusion_uncond.pt', 56 | noise_schedule='linear', 57 | num_channels=256, 58 | num_head_channels=64, 59 | num_heads=4, 60 | num_heads_upsample=-1, 61 | num_res_blocks=2, 62 | num_samples=4, 63 | predict_xstart=False, 64 | resblock_updown=True, 65 | 
rescale_learned_sigmas=False, 66 | rescale_timesteps=False, 67 | timestep_respacing='250', 68 | use_checkpoint=False, 69 | use_ddim=False, 70 | use_fp16=True, 71 | use_kl=False, 72 | use_new_attention_order=False, 73 | use_scale_shift_norm=True 74 | ) 75 | 76 | def extract_model_kwargs(kwargs): 77 | return { 78 | "image_size": kwargs["image_size"], 79 | "num_channels": kwargs["num_channels"], 80 | "num_res_blocks": kwargs["num_res_blocks"], 81 | "channel_mult": kwargs["channel_mult"], 82 | "learn_sigma": kwargs["learn_sigma"], 83 | "class_cond": kwargs["class_cond"], 84 | "use_checkpoint": kwargs["use_checkpoint"], 85 | "attention_resolutions": kwargs["attention_resolutions"], 86 | "num_heads": kwargs["num_heads"], 87 | "num_head_channels": kwargs["num_head_channels"], 88 | "num_heads_upsample": kwargs["num_heads_upsample"], 89 | "use_scale_shift_norm": kwargs["use_scale_shift_norm"], 90 | "dropout": kwargs["dropout"], 91 | "resblock_updown": kwargs["resblock_updown"], 92 | "use_fp16": kwargs["use_fp16"], 93 | "use_new_attention_order": kwargs["use_new_attention_order"], 94 | } 95 | 96 | def extract_diffusion_kwargs(kwargs): 97 | return { 98 | "diffusion_steps": kwargs["diffusion_steps"], 99 | "learn_sigma": False, 100 | "noise_schedule": kwargs["noise_schedule"], 101 | "use_kl": kwargs["use_kl"], 102 | "predict_xstart": kwargs["predict_xstart"], 103 | "rescale_timesteps": kwargs["rescale_timesteps"], 104 | "rescale_learned_sigmas": kwargs["rescale_learned_sigmas"], 105 | "timestep_respacing": kwargs["timestep_respacing"], 106 | } 107 | 108 | def download_adm_image256_uncond_ckpt(ckpt_dir="data/"): 109 | ckpt_pkl = os.path.join(ckpt_dir, I2SB_IMG256_UNCOND_PKL) 110 | ckpt_pt = os.path.join(ckpt_dir, I2SB_IMG256_UNCOND_CKPT) 111 | if os.path.exists(ckpt_pkl) and os.path.exists(ckpt_pt): 112 | return 113 | 114 | opt = create_argparser() 115 | 116 | adm_ckpt = os.path.join(ckpt_dir, opt.adm_ckpt) 117 | if not os.path.exists(adm_ckpt): 118 | print("Downloading ADM checkpoint to {} ...".format(adm_ckpt)) 119 | download(ADM_IMG256_UNCOND_CKPT, adm_ckpt) 120 | ckpt_state_dict = torch.load(adm_ckpt, map_location="cpu") 121 | 122 | # pt: remove the sigma prediction 123 | ckpt_state_dict["out.2.weight"] = ckpt_state_dict["out.2.weight"][:3] 124 | ckpt_state_dict["out.2.bias"] = ckpt_state_dict["out.2.bias"][:3] 125 | torch.save(ckpt_state_dict, ckpt_pt) 126 | 127 | # pkl 128 | kwargs = args_to_dict(opt, model_and_diffusion_defaults().keys()) 129 | kwargs['learn_sigma'] = False 130 | model_kwargs = extract_model_kwargs(kwargs) 131 | with open(ckpt_pkl, "wb") as f: 132 | pickle.dump(model_kwargs, f) 133 | 134 | print(f"Saved adm uncond pretrain models at {ckpt_pkl=} and {ckpt_pt}!") 135 | 136 | def download_adm_image256_cond_ckpt(ckpt_dir="data/"): 137 | ckpt_pkl = os.path.join(ckpt_dir, I2SB_IMG256_COND_PKL) 138 | ckpt_pt = os.path.join(ckpt_dir, I2SB_IMG256_COND_CKPT) 139 | if os.path.exists(ckpt_pkl) and os.path.exists(ckpt_pt): 140 | return 141 | 142 | opt = create_argparser() 143 | 144 | adm_ckpt = os.path.join(ckpt_dir, opt.adm_ckpt) 145 | if not os.path.exists(adm_ckpt): 146 | print("Downloading ADM checkpoint to {} ...".format(adm_ckpt)) 147 | download(ADM_IMG256_UNCOND_CKPT, adm_ckpt) 148 | ckpt_state_dict = torch.load(adm_ckpt, map_location="cpu") 149 | 150 | # pkl 151 | kwargs = args_to_dict(opt, model_and_diffusion_defaults().keys()) 152 | kwargs['learn_sigma'] = False 153 | model_kwargs = extract_model_kwargs(kwargs) 154 | model_kwargs.update(extract_diffusion_kwargs(kwargs)) 155 | 
model_kwargs["use_fp16"] = False 156 | model_kwargs["in_channels"] = 6 157 | with open(ckpt_pkl, "wb") as f: 158 | pickle.dump(model_kwargs, f) 159 | 160 | # pt: remove the sigma prediction and add concat module 161 | ckpt_state_dict["out.2.weight"] = ckpt_state_dict["out.2.weight"][:3] 162 | ckpt_state_dict["out.2.bias"] = ckpt_state_dict["out.2.bias"][:3] 163 | model = create_model(**model_kwargs) 164 | ckpt_state_dict['input_blocks.0.0.weight'] = torch.cat([ 165 | ckpt_state_dict['input_blocks.0.0.weight'], 166 | model.input_blocks[0][0].weight.data[:, 3:] 167 | ], dim=1) 168 | model.load_state_dict(ckpt_state_dict) 169 | torch.save(ckpt_state_dict, ckpt_pt) 170 | 171 | print(f"Saved adm cond pretrain models at {ckpt_pkl=} and {ckpt_pt}!") 172 | 173 | def download_ckpt(ckpt_dir="data/"): 174 | os.makedirs(ckpt_dir, exist_ok=True) 175 | download_adm_image256_uncond_ckpt(ckpt_dir=ckpt_dir) 176 | download_adm_image256_cond_ckpt(ckpt_dir=ckpt_dir) 177 | 178 | def build_ckpt_option(opt, log, ckpt_path): 179 | ckpt_path = Path(ckpt_path) 180 | opt_pkl_path = ckpt_path / "options.pkl" 181 | assert opt_pkl_path.exists() 182 | with open(opt_pkl_path, "rb") as f: 183 | ckpt_opt = pickle.load(f) 184 | log.info(f"Loaded options from {opt_pkl_path=}!") 185 | 186 | overwrite_keys = ["use_fp16", "device"] 187 | for k in overwrite_keys: 188 | assert hasattr(opt, k) 189 | setattr(ckpt_opt, k, getattr(opt, k)) 190 | 191 | ckpt_opt.load = ckpt_path / "latest.pt" 192 | return ckpt_opt 193 | -------------------------------------------------------------------------------- /external_src/I2SB/i2sb/diffusion.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for I2SB. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | import numpy as np 9 | from tqdm import tqdm 10 | from functools import partial 11 | import torch 12 | 13 | from .util import unsqueeze_xdim 14 | 15 | 16 | def compute_gaussian_product_coef(sigma1, sigma2): 17 | """ Given p1 = N(x_t|x_0, sigma_1**2) and p2 = N(x_t|x_1, sigma_2**2) 18 | return p1 * p2 = N(x_t| coef1 * x0 + coef2 * x1, var) """ 19 | 20 | denom = sigma1**2 + sigma2**2 21 | coef1 = sigma2**2 / denom 22 | coef2 = sigma1**2 / denom 23 | var = (sigma1**2 * sigma2**2) / denom 24 | return coef1, coef2, var 25 | 26 | class Diffusion(): 27 | def __init__(self, betas, device): 28 | 29 | self.device = device 30 | 31 | # compute analytic std: eq 11 32 | std_fwd = np.sqrt(np.cumsum(betas)) 33 | std_bwd = np.sqrt(np.flip(np.cumsum(np.flip(betas)))) 34 | mu_x0, mu_x1, var = compute_gaussian_product_coef(std_fwd, std_bwd) 35 | std_sb = np.sqrt(var) 36 | 37 | # tensorize everything 38 | to_torch = partial(torch.tensor, dtype=torch.float32) 39 | self.betas = to_torch(betas).to(device) 40 | self.std_fwd = to_torch(std_fwd).to(device) 41 | self.std_bwd = to_torch(std_bwd).to(device) 42 | self.std_sb = to_torch(std_sb).to(device) 43 | self.mu_x0 = to_torch(mu_x0).to(device) 44 | self.mu_x1 = to_torch(mu_x1).to(device) 45 | 46 | def get_std_fwd(self, step, xdim=None): 47 | std_fwd = self.std_fwd[step] 48 | return std_fwd if xdim is None else unsqueeze_xdim(std_fwd, xdim) 49 | 50 | def q_sample(self, step, x0, x1, ot_ode=False): 51 | """ Sample q(x_t | x_0, x_1), i.e. 
eq 11 """ 52 | 53 | assert x0.shape == x1.shape 54 | batch, *xdim = x0.shape 55 | 56 | mu_x0 = unsqueeze_xdim(self.mu_x0[step], xdim) 57 | mu_x1 = unsqueeze_xdim(self.mu_x1[step], xdim) 58 | std_sb = unsqueeze_xdim(self.std_sb[step], xdim) 59 | 60 | xt = mu_x0 * x0 + mu_x1 * x1 61 | if not ot_ode: 62 | xt = xt + std_sb * torch.randn_like(xt) 63 | return xt.detach() 64 | 65 | def p_posterior(self, nprev, n, x_n, x0, ot_ode=False): 66 | """ Sample p(x_{nprev} | x_n, x_0), i.e. eq 4""" 67 | 68 | assert nprev < n 69 | std_n = self.std_fwd[n] 70 | std_nprev = self.std_fwd[nprev] 71 | std_delta = (std_n**2 - std_nprev**2).sqrt() 72 | 73 | mu_x0, mu_xn, var = compute_gaussian_product_coef(std_nprev, std_delta) 74 | 75 | xt_prev = mu_x0 * x0 + mu_xn * x_n 76 | if not ot_ode and nprev > 0: 77 | xt_prev = xt_prev + var.sqrt() * torch.randn_like(xt_prev) 78 | 79 | return xt_prev 80 | 81 | def ddpm_sampling(self, steps, pred_x0_fn, x1, mask=None, ot_ode=False, log_steps=None, verbose=True): 82 | xt = x1.detach().to(self.device) 83 | 84 | xs = [] 85 | pred_x0s = [] 86 | 87 | log_steps = log_steps or steps 88 | assert steps[0] == log_steps[0] == 0 89 | 90 | steps = steps[::-1] 91 | 92 | pair_steps = zip(steps[1:], steps[:-1]) 93 | pair_steps = tqdm(pair_steps, desc='DDPM sampling', total=len(steps)-1) if verbose else pair_steps 94 | for prev_step, step in pair_steps: 95 | assert prev_step < step, f"{prev_step=}, {step=}" 96 | 97 | pred_x0 = pred_x0_fn(xt, step) 98 | xt = self.p_posterior(prev_step, step, xt, pred_x0, ot_ode=ot_ode) 99 | 100 | if mask is not None: 101 | xt_true = x1 102 | if not ot_ode: 103 | _prev_step = torch.full((xt.shape[0],), prev_step, device=self.device, dtype=torch.long) 104 | std_sb = unsqueeze_xdim(self.std_sb[_prev_step], xdim=x1.shape[1:]) 105 | xt_true = xt_true + std_sb * torch.randn_like(xt_true) 106 | xt = (1. - mask) * xt_true + mask * xt 107 | 108 | if prev_step in log_steps: 109 | pred_x0s.append(pred_x0.detach().cpu()) 110 | xs.append(xt.detach().cpu()) 111 | 112 | stack_bwd_traj = lambda z: torch.flip(torch.stack(z, dim=1), dims=(1,)) 113 | return stack_bwd_traj(xs), stack_bwd_traj(pred_x0s) 114 | -------------------------------------------------------------------------------- /external_src/I2SB/i2sb/network.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for I2SB. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | import os 9 | import pickle 10 | import torch 11 | 12 | from guided_diffusion.script_util import create_model 13 | 14 | from . 
import util 15 | from .ckpt_util import ( 16 | I2SB_IMG256_UNCOND_PKL, 17 | I2SB_IMG256_UNCOND_CKPT, 18 | I2SB_IMG256_COND_PKL, 19 | I2SB_IMG256_COND_CKPT, 20 | ) 21 | 22 | from ipdb import set_trace as debug 23 | 24 | class Image256Net(torch.nn.Module): 25 | def __init__(self, log, noise_levels, use_fp16=False, cond=False, pretrained_adm=True, ckpt_dir="data/"): 26 | super(Image256Net, self).__init__() 27 | 28 | # initialize model 29 | ckpt_pkl = os.path.join(ckpt_dir, I2SB_IMG256_COND_PKL if cond else I2SB_IMG256_UNCOND_PKL) 30 | with open(ckpt_pkl, "rb") as f: 31 | kwargs = pickle.load(f) 32 | kwargs["use_fp16"] = use_fp16 33 | self.diffusion_model = create_model(**kwargs) 34 | log.info(f"[Net] Initialized network from {ckpt_pkl=}! Size={util.count_parameters(self.diffusion_model)}!") 35 | 36 | # load (modified) adm ckpt 37 | if pretrained_adm: 38 | ckpt_pt = os.path.join(ckpt_dir, I2SB_IMG256_COND_CKPT if cond else I2SB_IMG256_UNCOND_CKPT) 39 | out = torch.load(ckpt_pt, map_location="cpu") 40 | self.diffusion_model.load_state_dict(out) 41 | log.info(f"[Net] Loaded pretrained adm {ckpt_pt=}!") 42 | 43 | self.diffusion_model.eval() 44 | self.cond = cond 45 | self.noise_levels = noise_levels 46 | 47 | def forward(self, x, steps, cond=None): 48 | 49 | t = self.noise_levels[steps].detach() 50 | assert t.dim()==1 and t.shape[0] == x.shape[0] 51 | 52 | x = torch.cat([x, cond], dim=1) if self.cond else x 53 | return self.diffusion_model(x, t) 54 | -------------------------------------------------------------------------------- /external_src/I2SB/i2sb/util.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for I2SB. To view a copy of this license, see the LICENSE file. 
6 | # --------------------------------------------------------------- 7 | 8 | import os 9 | 10 | import torch 11 | from torch.utils.data import DataLoader 12 | 13 | 14 | class DataLoaderX(DataLoader): 15 | def __iter__(self): 16 | from prefetch_generator import BackgroundGenerator 17 | return BackgroundGenerator(super().__iter__()) 18 | 19 | def setup_loader(dataset, batch_size, num_workers=4): 20 | loader = DataLoaderX( 21 | dataset, 22 | batch_size=batch_size, 23 | shuffle=True, 24 | pin_memory=True, 25 | num_workers=num_workers, 26 | drop_last=True, 27 | ) 28 | 29 | while True: 30 | yield from loader 31 | 32 | class BaseWriter(object): 33 | def __init__(self, opt): 34 | self.rank = opt.global_rank 35 | def add_scalar(self, step, key, val): 36 | pass # do nothing 37 | def add_image(self, step, key, image): 38 | pass # do nothing 39 | def close(self): pass 40 | 41 | class WandBWriter(BaseWriter): 42 | def __init__(self, opt): 43 | import wandb 44 | 45 | super(WandBWriter,self).__init__(opt) 46 | if self.rank == 0: 47 | assert wandb.login(key=opt.wandb_api_key) 48 | wandb.init(dir=str(opt.log_dir), project="i2sb", entity=opt.wandb_user, name=opt.name, config=vars(opt)) 49 | 50 | def add_scalar(self, step, key, val): 51 | import wandb 52 | 53 | if self.rank == 0: wandb.log({key: val}, step=step) 54 | 55 | def add_image(self, step, key, image): 56 | import wandb 57 | 58 | if self.rank == 0: 59 | # adopt from torchvision.utils.save_image 60 | image = image.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() 61 | wandb.log({key: wandb.Image(image)}, step=step) 62 | 63 | 64 | class TensorBoardWriter(BaseWriter): 65 | def __init__(self, opt): 66 | from torch.utils.tensorboard import SummaryWriter 67 | 68 | super(TensorBoardWriter,self).__init__(opt) 69 | if self.rank == 0: 70 | run_dir = str(opt.log_dir / opt.name) 71 | os.makedirs(run_dir, exist_ok=True) 72 | self.writer=SummaryWriter(log_dir=run_dir, flush_secs=20) 73 | 74 | def add_scalar(self, global_step, key, val): 75 | if self.rank == 0: self.writer.add_scalar(key, val, global_step=global_step) 76 | 77 | def add_image(self, global_step, key, image): 78 | if self.rank == 0: 79 | image = image.mul(255).add_(0.5).clamp_(0, 255).to("cpu", torch.uint8) 80 | self.writer.add_image(key, image, global_step=global_step) 81 | 82 | def close(self): 83 | if self.rank == 0: self.writer.close() 84 | 85 | def build_log_writer(opt): 86 | if opt.log_writer == 'wandb': return WandBWriter(opt) 87 | elif opt.log_writer == 'tensorboard': return TensorBoardWriter(opt) 88 | else: return BaseWriter(opt) # do nothing 89 | 90 | def count_parameters(model): 91 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 92 | 93 | def space_indices(num_steps, count): 94 | assert count <= num_steps 95 | 96 | if count <= 1: 97 | frac_stride = 1 98 | else: 99 | frac_stride = (num_steps - 1) / (count - 1) 100 | 101 | cur_idx = 0.0 102 | taken_steps = [] 103 | for _ in range(count): 104 | taken_steps.append(round(cur_idx)) 105 | cur_idx += frac_stride 106 | 107 | return taken_steps 108 | 109 | def unsqueeze_xdim(z, xdim): 110 | bc_dim = (...,) + (None,) * len(xdim) 111 | return z[bc_dim] 112 | -------------------------------------------------------------------------------- /external_src/SuperRetina/README.md: -------------------------------------------------------------------------------- 1 | # SuperRetina for Retinal Image Matching 2 | 3 | 
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/semi-supervised-keypoint-detector-and/image-registration-on-fire)](https://paperswithcode.com/sota/image-registration-on-fire?p=semi-supervised-keypoint-detector-and) 4 | 5 | This is the official source code of our ECCV2022 paper: [Semi-Supervised Keypoint Detector and Descriptor for Retinal Image Matching](https://arxiv.org/abs/2207.07932). 6 | 7 | ![illustration](./image/illustration.png) 8 | 9 | ## Environment 10 | We used Anaconda to set up a deep learning workspace that supports PyTorch. Run the following script to install all the required packages. 11 | 12 | ``` conda 13 | conda create -n SuperRetina python==3.8 -y 14 | conda activate SuperRetina 15 | git clone https://github.com/ruc-aimc-lab/SuperRetina.git 16 | cd SuperRetina 17 | pip install -r requirements.txt 18 | ``` 19 | 20 | ## Downloads 21 | 22 | ### Data 23 | See the [data](data) page. SuperRetina is trained on a small amount of keypoint annotations, which can be either manually labeled or auto-labeled by a specific keypoint detection algorithm. Check [notebooks/read_keypoint_labels.ipynb](notebooks/read_keypoint_labels.ipynb) to see our data format of keypoint annotations. 24 | 25 | ### Models 26 | 27 | You may skip the training stage and use our provided models for keypoint detection and description on retinal images. 28 | + [Google drive](https://drive.google.com/drive/folders/1h-MH3wEiN7BoLyMRjF1OAwABKqq6gVFL?usp=sharing) 29 | 30 | Put the trained model into the `save/` folder. 31 | 32 | 33 | ## Code 34 | 35 | ### Training 36 | 37 | Write the [config/train.yaml](config/train.yaml) file before training SuperRetina. Here we provide a demo training config file. Then you can train SuperRetina on your own data by using the following command. 38 | 39 | ``` 40 | python train.py 41 | ``` 42 | 43 | ### Inference 44 | 45 | #### Registration Performance 46 | The [test_on_FIRE.py](test_on_FIRE.py) code shows how image registration is performed on the FIRE dataset. 47 | ``` 48 | python test_on_FIRE.py 49 | ``` 50 | If everything goes well, you shall see the following message on your screen: 51 | ``` 52 | ---------------------------------------- 53 | Failed:0.00%, Inaccurate:1.50%, Acceptable:98.50% 54 | ---------------------------------------- 55 | S: 0.950, P: 0.554, A: 0.783, mAUC: 0.762 56 | ``` 57 | 58 | #### Identity Verification Performance 59 | 60 | The [test_on_VARIA.py](./test_on_VARIA.py) code shows how identity verification is performed on the VARIA dataset. 61 | ``` 62 | python test_on_VARIA.py 63 | ``` 64 | If everything goes well, you shall see the following message on your screen: 65 | ``` 66 | VARIA DATASET 67 | EER: 0.00%, threshold: 40 68 | ``` 69 | 70 | --- 71 | 72 | We have also provided some tutorial notebooks showing step-by-step usage of SuperRetina: 73 | + [notebooks/tutorial-inference.ipynb](notebooks/tutorial-inference.ipynb): Perform registration for a given pair of images. 74 | + [notebooks/eval-registration-on-FIRE.ipynb](notebooks/eval-registration-on-FIRE.ipynb): Evaluation on the [FIRE](https://projects.ics.forth.gr/cvrl/fire/) dataset.
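
For a concrete sense of what the registration step above amounts to, here is a minimal, illustrative sketch of keypoint-based retinal registration using plain OpenCV. It is not part of SuperRetina's released API: ORB stands in for the learned detector/descriptor (in the actual pipeline the keypoints and descriptors come from the trained network), and the image file names are placeholders.

```python
import cv2
import numpy as np

# Placeholder file names; substitute any pair of retinal images to register.
fixed = cv2.imread("fixed.png", cv2.IMREAD_GRAYSCALE)
moving = cv2.imread("moving.png", cv2.IMREAD_GRAYSCALE)

# Stand-in detector/descriptor. In the actual SuperRetina pipeline the
# keypoints and descriptors come from the trained network instead.
orb = cv2.ORB_create(nfeatures=2000)
kp_fixed, des_fixed = orb.detectAndCompute(fixed, None)
kp_moving, des_moving = orb.detectAndCompute(moving, None)

# Match descriptors and keep pairs that pass Lowe's ratio test.
matcher = cv2.BFMatcher(cv2.NORM_HAMMING)
matches = matcher.knnMatch(des_moving, des_fixed, k=2)
good = [m for m, n in matches if m.distance < 0.8 * n.distance]
assert len(good) >= 4, "Need at least 4 matches to fit a homography."

# Fit a homography with RANSAC and warp the moving image onto the fixed one.
src = np.float32([kp_moving[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
dst = np.float32([kp_fixed[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
H, inlier_mask = cv2.findHomography(src, dst, cv2.RANSAC, 5.0)
registered = cv2.warpPerspective(moving, H, (fixed.shape[1], fixed.shape[0]))
```

The FIRE evaluation reported above essentially scores how accurately such an estimated transform maps annotated control points from one image onto the other.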
75 | 76 | ## Citations 77 | If you find this repository useful, please consider citing: 78 | ``` 79 | @inproceedings{liu2022SuperRetina, 80 | title={Semi-Supervised Keypoint Detector and Descriptor for Retinal Image Matching}, 81 | author={Jiazhen Liu and Xirong Li and Qijie Wei and Jie Xu and Dayong Ding}, 82 | booktitle={Proceedings of the 17th European Conference on Computer Vision (ECCV)}, 83 | year={2022} 84 | } 85 | ``` 86 | 87 | ## Contact 88 | If you encounter any issue when running the code, please feel free to reach us either by creating a new issue in the GitHub or by emailing 89 | 90 | + Jiazhen Liu (liujiazhen@ruc.edu.cn) 91 | 92 | -------------------------------------------------------------------------------- /external_src/SuperRetina/common/common_util.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional as F 2 | import torch 3 | import matplotlib.pyplot as plt 4 | 5 | from torchvision import transforms 6 | import numpy as np 7 | from PIL import Image 8 | import cv2 9 | import torch.nn as nn 10 | import scipy.stats as st 11 | 12 | 13 | def remove_borders(keypoints, scores, border: int, height: int, width: int): 14 | """ Removes keypoints too close to the border """ 15 | mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border)) 16 | mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border)) 17 | mask = mask_h & mask_w 18 | return keypoints[mask], scores[mask] 19 | 20 | 21 | def simple_nms(scores, nms_radius: int): 22 | """ Fast Non-maximum suppression to remove nearby points """ 23 | assert (nms_radius >= 0) 24 | 25 | size = nms_radius * 2 + 1 26 | avg_size = 2 27 | def max_pool(x): 28 | return torch.nn.functional.max_pool2d( 29 | x, kernel_size=size, stride=1, padding=nms_radius) 30 | 31 | zeros = torch.zeros_like(scores) 32 | # max_map = max_pool(scores) 33 | 34 | max_mask = scores == max_pool(scores) 35 | max_mask_ = torch.rand(max_mask.shape).to(max_mask.device) / 10 36 | max_mask_[~max_mask] = 0 37 | mask = ((max_mask_ == max_pool(max_mask_)) & (max_mask_ > 0)) 38 | 39 | return torch.where(mask, scores, zeros) 40 | 41 | 42 | def pre_processing(data): 43 | """ Enhance retinal images """ 44 | train_imgs = datasets_normalized(data) 45 | train_imgs = clahe_equalized(train_imgs) 46 | train_imgs = adjust_gamma(train_imgs, 1.2) 47 | 48 | train_imgs = train_imgs / 255. 
49 | 50 | return train_imgs.astype(np.float32) 51 | 52 | 53 | def rgb2gray(rgb): 54 | """ Convert RGB image to gray image """ 55 | r, g, b = rgb.split() 56 | return g 57 | 58 | 59 | def clahe_equalized(images): 60 | clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) 61 | images_equalized = np.empty(images.shape) 62 | images_equalized[:, :] = clahe.apply(np.array(images[:, :], 63 | dtype=np.uint8)) 64 | 65 | return images_equalized 66 | 67 | 68 | def datasets_normalized(images): 69 | # images_normalized = np.empty(images.shape) 70 | images_std = np.std(images) 71 | images_mean = np.mean(images) 72 | images_normalized = (images - images_mean) / (images_std + 1e-6) 73 | minv = np.min(images_normalized) 74 | images_normalized = ((images_normalized - minv) / 75 | (np.max(images_normalized) - minv)) * 255 76 | 77 | return images_normalized 78 | 79 | 80 | def adjust_gamma(images, gamma=1.0): 81 | invGamma = 1.0 / gamma 82 | table = np.array([((i / 255.0) ** invGamma) * 255 83 | for i in np.arange(0, 256)]).astype("uint8") 84 | new_images = np.empty(images.shape) 85 | new_images[:, :] = cv2.LUT(np.array(images[:, :], 86 | dtype=np.uint8), table) 87 | 88 | return new_images 89 | 90 | 91 | def nms(detector_pred, nms_thresh=0.1, nms_size=10, detector_label=None, mask=False): 92 | """ Apply NMS on predictions, if mask, then remove geo_points that appearing in labels """ 93 | detector_pred = detector_pred.clone().detach() 94 | 95 | B, _, h, w = detector_pred.shape 96 | 97 | # if mask: 98 | # assert detector_label is not None 99 | # detector_pred[detector_pred < nms_thresh] = 0 100 | # label_mask = detector_label 101 | # 102 | # # more area 103 | # 104 | # detector_label = detector_label.long().cpu().numpy() 105 | # detector_label = detector_label.astype(np.uint8) 106 | # kernel = np.ones((3, 3), np.uint8) 107 | # label_mask = np.array([cv2.dilate(detector_label[s, 0], kernel, iterations=1) 108 | # for s in range(len(detector_label))]) 109 | # label_mask = torch.from_numpy(label_mask).unsqueeze(1) 110 | # detector_pred[label_mask > 1e-6] = 0 111 | 112 | scores = simple_nms(detector_pred, nms_size) 113 | 114 | scores = scores.reshape(B, h, w) 115 | 116 | points = [ 117 | torch.nonzero(s > nms_thresh) 118 | for s in scores] 119 | 120 | scores = [s[tuple(k.t())] for s, k in zip(scores, points)] 121 | 122 | points, scores = list(zip(*[ 123 | remove_borders(k, s, 8, h, w) 124 | for k, s in zip(points, scores)])) 125 | points = [torch.flip(k, [1]).long() for k in points] 126 | 127 | return points 128 | 129 | 130 | def sample_keypoint_desc(keypoints, descriptors, s: int = 8): 131 | """ Interpolate descriptors at keypoint locations """ 132 | b, c, h, w = descriptors.shape 133 | keypoints = keypoints.clone().float() 134 | 135 | keypoints /= torch.tensor([(w * s - 1), (h * s - 1)]).to(keypoints)[None] 136 | keypoints = keypoints * 2 - 1 # normalize to (-1, 1) 137 | 138 | args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {} 139 | descriptors = torch.nn.functional.grid_sample( 140 | descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args) 141 | 142 | descriptors = torch.nn.functional.normalize( 143 | descriptors.reshape(b, c, -1), p=2, dim=1) 144 | return descriptors 145 | 146 | 147 | def sample_descriptors(detector_pred, descriptor_pred, affine_descriptor_pred, grid_inverse, 148 | nms_size=10, nms_thresh=0.1, scale=8, affine_detector_pred=None): 149 | """ 150 | sample descriptors based on keypoints 151 | :param affine_descriptor_pred: 152 | :param descriptor_pred: 153 | :param 
detector_pred: 154 | :param grid_inverse: used for inverse transformation of affine 155 | :param nms_size 156 | :param nms_thresh 157 | :param scale: down sampling size of detector 158 | :return: sampled descriptors 159 | """ 160 | B, _, h, w = detector_pred.shape 161 | keypoints = nms(detector_pred, nms_size=nms_size, nms_thresh=nms_thresh) 162 | 163 | affine_keypoints = [(grid_inverse[s, k[:, 1].long(), k[:, 0].long()]) for s, k in 164 | enumerate(keypoints)] 165 | 166 | kp = [] 167 | affine_kp = [] 168 | for s, k in enumerate(affine_keypoints): 169 | idx = (k[:, 0] < 1) & (k[:, 0] > -1) & (k[:, 1] < 1) & ( 170 | k[:, 1] > -1) 171 | kp.append(keypoints[s][idx]) 172 | ak = k[idx] 173 | ak[:, 0] = (ak[:, 0] + 1) / 2 * (w - 1) 174 | ak[:, 1] = (ak[:, 1] + 1) / 2 * (h - 1) 175 | affine_kp.append(ak) 176 | 177 | descriptors = [sample_keypoint_desc(k[None], d[None], s=scale)[0] 178 | for k, d in zip(kp, descriptor_pred)] 179 | affine_descriptors = [sample_keypoint_desc(k[None], d[None], s=scale)[0] 180 | for k, d in zip(affine_kp, affine_descriptor_pred)] 181 | return descriptors, affine_descriptors, keypoints 182 | -------------------------------------------------------------------------------- /external_src/SuperRetina/common/eval_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | # Compute AUC scores for image registration on the FIRE dataset 5 | def compute_auc(s_error, p_error, a_error): 6 | assert (len(s_error) == 71) # Easy pairs 7 | assert (len(p_error) == 48) # Hard pairs. Note file control_points_P37_1_2.txt is ignored 8 | assert (len(a_error) == 14) # Moderate pairs 9 | 10 | s_error = np.array(s_error) 11 | p_error = np.array(p_error) 12 | a_error = np.array(a_error) 13 | 14 | limit = 25 15 | gs_error = np.zeros(limit + 1) 16 | gp_error = np.zeros(limit + 1) 17 | ga_error = np.zeros(limit + 1) 18 | 19 | accum_s = 0 20 | accum_p = 0 21 | accum_a = 0 22 | 23 | for i in range(1, limit + 1): 24 | gs_error[i] = np.sum(s_error < i) * 100 / len(s_error) 25 | gp_error[i] = np.sum(p_error < i) * 100 / len(p_error) 26 | ga_error[i] = np.sum(a_error < i) * 100 / len(a_error) 27 | 28 | accum_s = accum_s + gs_error[i] 29 | accum_p = accum_p + gp_error[i] 30 | accum_a = accum_a + ga_error[i] 31 | 32 | auc_s = accum_s / (limit * 100) 33 | auc_p = accum_p / (limit * 100) 34 | auc_a = accum_a / (limit * 100) 35 | mAUC = (auc_s + auc_p + auc_a) / 3.0 36 | return {'s': auc_s, 'p': auc_p, 'a': auc_a, 'mAUC': mAUC} 37 | -------------------------------------------------------------------------------- /external_src/SuperRetina/config/test.yaml: -------------------------------------------------------------------------------- 1 | PREDICT: 2 | device: cuda:0 3 | model_save_path: ./save/SuperRetina.pth 4 | model_image_width: 768 5 | model_image_height: 768 6 | use_matching_trick: True 7 | 8 | nms_size: 10 9 | nms_thresh: 0.01 10 | 11 | knn_thresh: 0.9 12 | 13 | -------------------------------------------------------------------------------- /external_src/SuperRetina/config/test_VARIA.yaml: -------------------------------------------------------------------------------- 1 | PREDICT: 2 | device: cuda:0 3 | model_save_path: ./save/SuperRetina.pth 4 | model_image_width: 512 5 | model_image_height: 512 6 | use_matching_trick: false 7 | 8 | nms_size: 10 9 | nms_thresh: 0.1 10 | 11 | knn_thresh: 0.8 12 | 13 | -------------------------------------------------------------------------------- /external_src/SuperRetina/config/train.yaml: 
-------------------------------------------------------------------------------- 1 | MODEL: 2 | nms_size: 10 3 | nms_thresh: 0.1 4 | 5 | batch_size: 4 6 | num_epoch: 150 7 | 8 | device: cuda:0 9 | load_pre_trained_model: false 10 | pretrained_path: ./save/SuperRetina.pth 11 | 12 | model_save_path: ./save/new_model.pth 13 | model_save_epoch: 1 14 | 15 | DATASET: 16 | model_image_width: 768 17 | model_image_height: 768 18 | dataset_path: ./data/Lab 19 | train_split_file: eccv22_train.txt 20 | val_split_file: eccv22_val.txt 21 | auxiliary: ./data/Auxiliary 22 | 23 | PKE: 24 | pke_start_epoch: 0 # the epoch to start PKE learn 25 | geometric_thresh: 0.5 26 | content_thresh: 0.7 27 | gaussian_kernel_size: 13 # used to generate heatmap 28 | gaussian_sigma: 2 29 | 30 | pke_show: true 31 | pke_show_epoch: 5 32 | pke_show_list: [] # if pke_show_list == [], then randomly select one sample to show 33 | 34 | 35 | # value_map is used to record history learned points. The value of the position of the newly added keypoint will increase. 36 | VALUE_MAP: 37 | area: 8 # if the neighbor of the keypoint have non-zero values, the values will increase by `value_increase_area` 38 | value_increase_point: 5 # if the neighbor of the keypoint is all zero, the keypoint position's value will increase by `value_increase_point` 39 | value_increase_area: 1 40 | value_decay: 1 # if the history keypoints don't appear in this epoch, the correspoding value will decay 41 | is_value_map_save: false # if false, running in RAM, otherwise, the value map will store as a temp file in `value_map_save_dir` 42 | value_map_save_dir: ./data/lab_values 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /external_src/SuperRetina/dataset/retina_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import scipy.stats as st 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | from PIL import Image 6 | import numpy as np 7 | import imgaug.augmenters as iaa 8 | import os 9 | from torch.utils.data import Dataset, DataLoader 10 | from torchvision import transforms 11 | 12 | from common.common_util import pre_processing 13 | from common.train_util import get_gaussian_kernel 14 | import time 15 | 16 | class RetinaDataset(Dataset): 17 | def __init__(self, data_path, split_file='eccv22_train.txt', 18 | is_train=True, data_shape=(768, 768), auxiliary=None): 19 | self.data = [] 20 | 21 | self.image_path = os.path.join(data_path, 'ImageData') 22 | self.label_path = os.path.join(data_path, 'Annotations') 23 | self.split_file = os.path.join(data_path, 'ImageSets', split_file) 24 | 25 | self.enhancement_sequential = iaa.Sequential([ 26 | iaa.Multiply((1.0, 1.2)), # change brightness, doesn't affect keypoints 27 | iaa.Sometimes( 28 | 0.2, 29 | iaa.GaussianBlur(sigma=(0, 6)) 30 | ), 31 | iaa.Sometimes( 32 | 0.2, 33 | iaa.LinearContrast((0.75, 1.2)) 34 | ), 35 | ], random_order=True) 36 | 37 | self.is_train = is_train 38 | self.model_image_height, self.model_image_width = data_shape[0], data_shape[1] 39 | 40 | with open(self.split_file) as f: 41 | files = f.readlines() 42 | 43 | for file in files: 44 | file = file.strip() 45 | try: 46 | image, label = file.split(', ') 47 | image = os.path.join(self.image_path, image) 48 | label = os.path.join(self.label_path, label) 49 | assert os.path.exists(image) and os.path.exists(label) 50 | self.data.append((image, label)) 51 | except Exception: 52 | pass 53 | if auxiliary is not None and 
os.path.exists(auxiliary): 54 | print('-'*10+f"Load Lab {'train' if is_train else 'eval'} data and auxiliary data without labels"+'-'*10) 55 | auxiliaries_tmp = os.listdir(auxiliary) 56 | auxiliaries = [] 57 | auxiliaries_tmp = [(os.path.join(auxiliary, a), None) for a in auxiliaries_tmp] 58 | for a, b in auxiliaries_tmp: 59 | if a.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')): 60 | auxiliaries.append((a, b)) 61 | self.data = self.data + auxiliaries 62 | else: 63 | print('-' * 10 + f"Load Lab {'train' if is_train else 'eval'} data, and there is no auxiliary data" + '-' * 10) 64 | self.transforms = transforms.Compose([ 65 | transforms.Resize((self.model_image_width, self.model_image_height)), 66 | transforms.ToTensor(), 67 | ]) 68 | 69 | def __len__(self): 70 | return len(self.data) 71 | 72 | def __getitem__(self, index): 73 | input_with_label = False 74 | 75 | image_path, label_path = self.data[index] 76 | label_name = '[None]' 77 | 78 | image = Image.open(image_path).convert('RGB') 79 | image = np.asarray(image) 80 | image = image[:, :, 1] 81 | 82 | image = pre_processing(image) 83 | # if self.is_train: 84 | # image = self.enhancement_sequential(image=image) 85 | image = Image.fromarray(image) 86 | image_tensor = self.transforms(image) 87 | 88 | if label_path is not None: 89 | label_name = os.path.split(label_path)[-1] 90 | keypoint_position = np.loadtxt(label_path) # (2, n): (x, y).T 91 | keypoint_position[:, 0] *= image_tensor.shape[-1] 92 | keypoint_position[:, 1] *= image_tensor.shape[-2] 93 | 94 | tensor_position = torch.zeros([self.model_image_height, self.model_image_width]) 95 | 96 | tensor_position[keypoint_position[:, 1], keypoint_position[:, 0]] = 1 97 | tensor_position = tensor_position.unsqueeze(0) 98 | input_with_label = True 99 | 100 | return image_tensor, input_with_label, tensor_position, label_name 101 | 102 | # without labels, only train descriptor 103 | tensor_position = torch.empty(image_tensor.shape) 104 | 105 | return image_tensor, input_with_label, tensor_position, label_name 106 | 107 | 108 | 109 | -------------------------------------------------------------------------------- /external_src/SuperRetina/loss/dice_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | class DiceBCELoss(nn.Module): 7 | def __init__(self, weight=None, size_average=True): 8 | super(DiceBCELoss, self).__init__() 9 | 10 | def forward(self, inputs, targets, smooth=1): 11 | # comment out if your model contains a sigmoid or equivalent activation layer 12 | # inputs = F.sigmoid(inputs) 13 | 14 | # flatten label and prediction tensors 15 | inputs = inputs.view(-1) 16 | targets = targets.view(-1) 17 | 18 | intersection = (inputs * targets).sum() 19 | dice_loss = 1 - (2. 
* intersection + smooth) / (inputs.sum() + targets.sum() + smooth) 20 | BCE = F.binary_cross_entropy(inputs, targets, reduction='mean') 21 | Dice_BCE = 0.01*BCE + dice_loss 22 | 23 | return Dice_BCE 24 | 25 | class DiceLoss(nn.Module): 26 | def __init__(self, smooth=1e-3, p=2, reduction='mean'): 27 | super(DiceLoss, self).__init__() 28 | self.smooth = smooth 29 | self.p = p 30 | self.reduction = reduction 31 | 32 | def forward(self, predict, target): 33 | assert predict.shape[0] == target.shape[0], "predict & target batch size don't match" 34 | predict = predict.contiguous().view(-1) 35 | target = target.contiguous().view(-1) 36 | a = torch.mul(predict, target) 37 | b = predict.pow(self.p) + target.pow(self.p) 38 | 39 | num = 2 * torch.sum(a, dim=0) + self.smooth 40 | 41 | den = torch.sum(b, dim=0) + self.smooth 42 | 43 | loss = 1 - num / den 44 | 45 | if self.reduction == 'mean': 46 | return loss.mean() 47 | elif self.reduction == 'sum': 48 | return loss.sum() 49 | elif self.reduction == 'none': 50 | return loss 51 | else: 52 | raise Exception('Unexpected reduction {}'.format(self.reduction)) -------------------------------------------------------------------------------- /external_src/SuperRetina/loss/triplet_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def pairwise_distance(x1, x2, p=2, eps=1e-6): 5 | r""" 6 | Computes the batchwise pairwise distance between vectors v1,v2: 7 | .. math :: 8 | \Vert x \Vert _p := \left( \sum_{i=1}^n \vert x_i \vert ^ p \right) ^ {1/p} 9 | Args: 10 | x1: first input tensor 11 | x2: second input tensor 12 | p: the norm degree. Default: 2 13 | Shape: 14 | - Input: :math:`(N, D)` where `D = vector dimension` 15 | - Output: :math:`(N, 1)` 16 | >>> input1 = autograd.Variable(torch.randn(100, 128)) 17 | >>> input2 = autograd.Variable(torch.randn(100, 128)) 18 | >>> output = F.pairwise_distance(input1, input2, p=2) 19 | >>> output.backward() 20 | """ 21 | assert x1.size() == x2.size(), "Input sizes must be equal." 22 | assert x1.dim() == 2, "Input must be a 2D matrix." 23 | 24 | return 1 - torch.cosine_similarity(x1, x2, dim=1) 25 | # diff = torch.abs(x1 - x2) 26 | # out = torch.sum(torch.pow(diff + eps, p), dim=1) 27 | # 28 | # return torch.pow(out, 1. / p) 29 | 30 | 31 | def triplet_margin_loss_gor_one(anchor, positive, negative, beta=1.0, margin=1.0, p=2, eps=1e-6, swap=False): 32 | assert anchor.size() == positive.size(), "Input sizes between positive and negative must be equal." 33 | assert anchor.size() == negative.size(), "Input sizes between anchor and negative must be equal." 34 | assert positive.size() == negative.size(), "Input sizes between positive and negative must be equal." 35 | assert anchor.dim() == 2, "Inputd must be a 2D matrix." 36 | assert margin > 0.0, 'Margin should be positive value.' 37 | d_p = pairwise_distance(anchor, positive, p, eps) 38 | d_n = pairwise_distance(anchor, negative, p, eps) 39 | 40 | dist_hinge = torch.clamp(margin + d_p - d_n, min=0.0) 41 | 42 | neg_dis = torch.pow(torch.sum(torch.mul(anchor, negative), 1), 2) 43 | gor = torch.mean(neg_dis) 44 | 45 | loss = torch.mean(dist_hinge) + beta * (gor) 46 | 47 | return loss 48 | 49 | 50 | def triplet_margin_loss_gor(anchor, positive, negative1, negative2, beta=1.0, margin=1.0, p=2, eps=1e-6, swap=False): 51 | assert anchor.size() == positive.size(), "Input sizes between positive and negative must be equal." 
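# (For reference, the two-negative objective computed below, using the cosine-based
#  pairwise_distance defined above with d_p = d(anchor, positive) and
#  d_ni = d(anchor, negative_i), is:
#      mean(clamp(margin + d_p - 0.5 * (d_n1 + d_n2), min=0))
#          + beta * (mean((anchor * negative1).sum(1) ** 2)
#                    + mean((anchor * negative2).sum(1) ** 2))
#  i.e. a hinge term against the averaged negative distances plus the two
#  global orthogonal regularization (GOR) terms.)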
52 | assert anchor.size() == negative1.size(), "Input sizes between anchor and negative must be equal." 53 | assert positive.size() == negative2.size(), "Input sizes between positive and negative must be equal." 54 | assert anchor.dim() == 2, "Inputd must be a 2D matrix." 55 | assert margin > 0.0, 'Margin should be positive value.' 56 | 57 | # loss1 = triplet_margin_loss_gor_one(anchor, positive, negative1) 58 | # loss2 = triplet_margin_loss_gor_one(anchor, positive, negative2) 59 | # 60 | # return 0.5*(loss1+loss2) 61 | 62 | d_p = pairwise_distance(anchor, positive, p, eps) 63 | d_n1 = pairwise_distance(anchor, negative1, p, eps) 64 | d_n2 = pairwise_distance(anchor, negative2, p, eps) 65 | 66 | dist_hinge = torch.clamp(margin + d_p - 0.5 * (d_n1 + d_n2), min=0.0) 67 | 68 | neg_dis1 = torch.pow(torch.sum(torch.mul(anchor, negative1), 1), 2) 69 | gor1 = torch.mean(neg_dis1) 70 | neg_dis2 = torch.pow(torch.sum(torch.mul(anchor, negative2), 1), 2) 71 | gor2 = torch.mean(neg_dis2) 72 | 73 | loss = torch.mean(dist_hinge) + beta * (gor1 + gor2) 74 | 75 | return loss 76 | 77 | 78 | def distance_matrix_vector(anchor, positive): 79 | """Given batch of anchor descriptors and positive descriptors calculate distance matrix""" 80 | D = anchor.shape[-1] 81 | d1_sq = torch.sum(anchor * anchor, dim=1).unsqueeze(-1) 82 | d2_sq = torch.sum(positive * positive, dim=1).unsqueeze(-1) 83 | 84 | eps = 1e-3 85 | return torch.sqrt((d1_sq.repeat(1, positive.size(0)) + torch.t(d2_sq.repeat(1, anchor.size(0))) 86 | - 2.0 * torch.bmm(anchor.unsqueeze(0), torch.t(positive).unsqueeze(0)).squeeze(0))+eps) 87 | 88 | # anchor = anchor.permute(1, 0).view(D, -1, 1) 89 | # positive = positive.permute(1, 0).view(D, 1, -1) 90 | # return torch.norm(anchor - positive, dim=0) 91 | 92 | 93 | def percentile(t, q): 94 | """ 95 | Return the ``q``-th percentile of the flattened input tensor's data. 96 | 97 | CAUTION: 98 | * Needs PyTorch >= 1.1.0, as ``torch.kthvalue()`` is used. 99 | * Values are not interpolated, which corresponds to 100 | ``numpy.percentile(..., interpolation="nearest")``. 101 | 102 | :param t: Input tensor. 103 | :param q: Percentile to compute, which must be between 0 and 100 inclusive. 104 | :return: Resulting value (scalar). 105 | """ 106 | # Note that ``kthvalue()`` works one-based, i.e. the first sorted value 107 | # indeed corresponds to k=1, not k=0! Use float(q) instead of q directly, 108 | # so that ``round()`` returns an integer, even if q is a np.float32. 
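# Worked check of the formula below (assuming t has 101 elements and q = 25):
#     k = 1 + round(0.01 * 25 * (101 - 1)) = 26,
# so kthvalue(26) returns the 26th-smallest entry, which matches
# numpy.percentile(t, 25, interpolation="nearest").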
109 | k = 1 + round(.01 * float(q) * (t.numel() - 1)) 110 | result = t.view(-1).kthvalue(int(k)).values.item() 111 | return result 112 | 113 | 114 | """ Triplet loss usd in SOSNet """ 115 | def sos_reg(anchor, positive, KNN=True, k=1, eps=1e-8): 116 | dist_matrix_a = distance_matrix_vector(anchor, anchor) + eps 117 | dist_matrix_b = distance_matrix_vector(positive, positive) + eps 118 | if KNN: 119 | k_max = percentile(dist_matrix_b, k) 120 | #print("k_max:", k_max) 121 | mask = dist_matrix_b.lt(k_max) 122 | dist_matrix_a = dist_matrix_a*mask.int().float() 123 | dist_matrix_b = dist_matrix_b*mask.int().float() 124 | SOS_temp = torch.sqrt(torch.sum(torch.pow(dist_matrix_a-dist_matrix_b, 2))) 125 | return torch.mean(SOS_temp) 126 | -------------------------------------------------------------------------------- /external_src/SuperRetina/model/pke_module.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | from common.common_util import sample_keypoint_desc, nms 6 | from model.record_module import update_value_map 7 | 8 | 9 | def mapping_points(grid, points, h, w): 10 | """ Using grid_inverse to apply affine transform on geo_points 11 | :return point set and its corresponding affine point set 12 | """ 13 | 14 | grid_points = [(grid[s, k[:, 1].long(), k[:, 0].long()]) for s, k in 15 | enumerate(points)] 16 | filter_points = [] 17 | affine_points = [] 18 | for s, k in enumerate(grid_points): # filter bad geo_points 19 | idx = (k[:, 0] < 1) & (k[:, 0] > -1) & (k[:, 1] < 1) & ( 20 | k[:, 1] > -1) 21 | gp = grid_points[s][idx] 22 | gp[:, 0] = (gp[:, 0] + 1) / 2 * (w - 1) 23 | gp[:, 1] = (gp[:, 1] + 1) / 2 * (h - 1) 24 | affine_points.append(gp) 25 | filter_points.append(points[s][idx]) 26 | 27 | return filter_points, affine_points 28 | 29 | 30 | def content_filter(descriptor_pred, affine_descriptor_pred, geo_points, 31 | affine_geo_points, content_thresh=0.7, scale=8): 32 | """ 33 | content-based matching in paper 34 | :param descriptor_pred: descriptors of input_image images 35 | :param affine_descriptor_pred: descriptors of affine images 36 | :param geo_points: 37 | :param affine_geo_points: 38 | :param content_thresh: 39 | :param scale: down sampling size of descriptor_pred 40 | :return: content-filtered keypoints 41 | """ 42 | 43 | descriptors = [sample_keypoint_desc(k[None], d[None], scale)[0].permute(1, 0) 44 | for k, d in zip(geo_points, descriptor_pred)] 45 | aff_descriptors = [sample_keypoint_desc(k[None], d[None], scale)[0].permute(1, 0) 46 | for k, d in zip(affine_geo_points, affine_descriptor_pred)] 47 | content_points = [] 48 | affine_content_points = [] 49 | dist = [torch.norm(descriptors[d][:, None] - aff_descriptors[d], dim=2, p=2) 50 | for d in range(len(descriptors))] 51 | for i in range(len(dist)): 52 | D = dist[i] 53 | if len(D) <= 1: 54 | content_points.append([]) 55 | affine_content_points.append([]) 56 | continue 57 | val, ind = torch.topk(D, 2, dim=1, largest=False) 58 | 59 | arange = torch.arange(len(D)) 60 | # rule1 spatial correspondence 61 | c1 = ind[:, 0] == arange.to(ind.device) 62 | # rule2 pass the ratio test 63 | c2 = val[:, 0] < val[:, 1] * content_thresh 64 | 65 | check = c2 * c1 66 | content_points.append(geo_points[i][check]) 67 | affine_content_points.append(affine_geo_points[i][check]) 68 | return content_points, affine_content_points 69 | 70 | 71 | def geometric_filter(affine_detector_pred, points, affine_points, max_num=1024, 
geometric_thresh=0.5): 72 | """ 73 | geometric matching in paper 74 | :param affine_detector_pred: geo_points probability of affine image 75 | :param points: nms results of input_image image 76 | :param affine_points: nms results of affine image 77 | :param max_num: maximum number of learned keypoints 78 | :param geometric_thresh: 79 | :return: geometric-filtered keypoints 80 | """ 81 | geo_points = [] 82 | affine_geo_points = [] 83 | for s, k in enumerate(affine_points): 84 | sample_aff_values = affine_detector_pred[s, 0, k[:, 1].long(), k[:, 0].long()] 85 | check = sample_aff_values.squeeze() >= geometric_thresh 86 | geo_points.append(points[s][check][:max_num]) 87 | affine_geo_points.append(k[check][:max_num]) 88 | 89 | return geo_points, affine_geo_points 90 | 91 | 92 | def pke_learn(detector_pred, descriptor_pred, grid_inverse, affine_detector_pred, 93 | affine_descriptor_pred, kernel, loss_cal, label_point_positions, 94 | value_map, config, PKE_learn=True): 95 | """ 96 | pke process used for detector 97 | :param detector_pred: probability map from raw image 98 | :param descriptor_pred: prediction of descriptor_pred network 99 | :param kernel: used for gaussian heatmaps 100 | :param mask_kernel: used for masking initial keypoints 101 | :param grid_inverse: used for inverse 102 | :param loss_cal: loss (default is dice) 103 | :param label_point_positions: positions of keypoints on labels 104 | :param value_map: value map for recoding and selecting learned geo_points 105 | :param pke_learn: whether to use PKE 106 | :return: loss of detector, num of additional geo_points, updated value maps and enhanced labels 107 | """ 108 | # used for masking initial keypoints on enhanced labels 109 | initial_label = F.conv2d(label_point_positions, kernel, 110 | stride=1, padding=(kernel.shape[-1] - 1) // 2) 111 | initial_label[initial_label > 1] = 1 112 | 113 | if not PKE_learn: 114 | return loss_cal(detector_pred, initial_label.to(detector_pred)), 0, None, None, initial_label 115 | 116 | nms_size = config['nms_size'] 117 | nms_thresh = config['nms_thresh'] 118 | scale = 8 119 | 120 | enhanced_label = None 121 | geometric_thresh = config['geometric_thresh'] 122 | content_thresh = config['content_thresh'] 123 | with torch.no_grad(): 124 | h, w = detector_pred.shape[2:] 125 | 126 | # number of learned points 127 | number_pts = 0 128 | points = nms(detector_pred, nms_thresh=nms_thresh, nms_size=nms_size, 129 | detector_label=initial_label, mask=True) 130 | 131 | # geometric matching 132 | points, affine_points = mapping_points(grid_inverse, points, h, w) 133 | geo_points, affine_geo_points = geometric_filter(affine_detector_pred, points, affine_points, 134 | geometric_thresh=geometric_thresh) 135 | 136 | 137 | # content matching 138 | content_points, affine_contend_points = content_filter(descriptor_pred, affine_descriptor_pred, geo_points, 139 | affine_geo_points, content_thresh=content_thresh, 140 | scale=scale) 141 | enhanced_label_pts = [] 142 | for step in range(len(content_points)): 143 | # used to combine initial points and learned points 144 | positions = torch.where(label_point_positions[step, 0] == 1) 145 | if len(positions) == 2: 146 | positions = torch.cat((positions[1].unsqueeze(-1), positions[0].unsqueeze(-1)), -1) 147 | else: 148 | positions = positions[0] 149 | 150 | final_points = update_value_map(value_map[step], content_points[step], config) 151 | 152 | # final_points = torch.cat((final_points, positions)) 153 | 154 | temp_label = torch.zeros([h, w]).to(detector_pred.device) 155 | 156 | 
temp_label[final_points[:, 1], final_points[:, 0]] = 0.5 157 | temp_label[positions[:, 1], positions[:, 0]] = 1 158 | 159 | enhanced_kps = nms(temp_label.unsqueeze(0).unsqueeze(0), 0.1, 10)[0] 160 | if len(enhanced_kps) < len(positions): 161 | enhanced_kps = positions 162 | # print(len(final_points), len(positions), len(enhanced_kps)) 163 | number_pts += (len(enhanced_kps) - len(positions)) 164 | # number_pts += (len(enhanced_kps) - len(positions)) if (len(enhanced_kps) - len(positions)) > 0 else 0 165 | 166 | temp_label[:] = 0 167 | temp_label[enhanced_kps[:, 1], enhanced_kps[:, 0]] = 1 168 | 169 | enhanced_label_pts.append(temp_label.unsqueeze(0).unsqueeze(0)) 170 | 171 | temp_label = F.conv2d(temp_label.unsqueeze(0).unsqueeze(0), kernel, stride=1, 172 | padding=(kernel.shape[-1] - 1) // 2) # generating gaussian heatmaps 173 | temp_label[temp_label > 1] = 1 174 | 175 | if enhanced_label is None: 176 | enhanced_label = temp_label 177 | else: 178 | enhanced_label = torch.cat((enhanced_label, temp_label)) 179 | 180 | enhanced_label_pts = torch.cat(enhanced_label_pts) 181 | affine_pred_inverse = F.grid_sample(affine_detector_pred, grid_inverse, align_corners=True) 182 | 183 | loss1 = loss_cal(detector_pred, enhanced_label) # L_geo 184 | loss2 = loss_cal(detector_pred, affine_pred_inverse) # L_clf 185 | # pred_mask = (enhanced_label > 0) & (affine_pred_inverse != 0) 186 | # loss2 = loss_cal(detector_pred[pred_mask], affine_pred_inverse[pred_mask]) # L_clf 187 | 188 | # mask_pred = grid_inverse 189 | # loss2 = loss_cal(detector_pred[mask_pred], affine_pred_inverse[mask_pred]) # L_clf 190 | 191 | loss = loss1+loss2 192 | 193 | return loss, number_pts, value_map, enhanced_label_pts, enhanced_label 194 | -------------------------------------------------------------------------------- /external_src/SuperRetina/model/record_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from common.common_util import simple_nms 4 | 5 | 6 | def update_value_map(value_map, points, value_map_config): 7 | """ 8 | Update value maps used for recording learned keypoints from PKE, 9 | and getting the final learned keypoints which are combined of previous learned keypoints. 
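(In short: positions that keep being re-detected across epochs accumulate value,
positions that stop appearing decay, and any position whose value is at least
`value_increase_point` after NMS is emitted as a learned keypoint.)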
10 | :param value_map: previous value maps 11 | :param points: the learned keypoints in this epoch 12 | :param value_map_config: 13 | :return: the final learned keypoints combined of previous learning points 14 | """ 15 | 16 | raw_value_map = value_map.clone() 17 | # used for record areas of value=0 18 | raw_value_map[value_map == 0] = -1 19 | area_set = value_map_config['area'] 20 | area = area_set // 2 21 | 22 | value_increase_point = value_map_config['value_increase_point'] 23 | value_increase_area = value_map_config['value_increase_area'] 24 | value_decay = value_map_config['value_decay'] 25 | 26 | h, w = value_map[0].shape 27 | for (x, y) in points: 28 | y_d = y - area // 2 if y - area // 2 > 0 else 0 29 | y_u = y + area // 2 if y + area // 2 < h else h 30 | x_l = x - area // 2 if x - area // 2 > 0 else 0 31 | x_r = x + area // 2 if x + area // 2 < w else w 32 | tmp = value_map[0, y_d:y_u, x_l:x_r] 33 | if value_map[0, y, x] != 0 or tmp.sum() == 0: 34 | value_map[0, y, x] += value_increase_point # if there is no learned point before, then add a high value 35 | else: 36 | tmp[tmp > 0] += value_increase_area 37 | value_map[0, y_d:y_u, x_l:x_r] = tmp 38 | 39 | value_map[torch.where( 40 | value_map == raw_value_map)] -= value_decay # value decay of positions that don't appear this time 41 | 42 | tmp = value_map.detach().clone() 43 | 44 | tmp = simple_nms(tmp.unsqueeze(0).float(), area_set*2) 45 | tmp = tmp.squeeze() 46 | 47 | final_points = torch.nonzero(tmp >= value_increase_point) 48 | final_points = torch.flip(final_points, [1]).long() # to x, y 49 | return final_points 50 | -------------------------------------------------------------------------------- /external_src/SuperRetina/requirements.txt: -------------------------------------------------------------------------------- 1 | imgaug==0.4.0 2 | matplotlib==3.5.1 3 | numpy==1.22.3 4 | opencv_python==4.6.0.66 5 | Pillow==9.2.0 6 | PyYAML==6.0 7 | scikit_learn==1.1.1 8 | scipy==1.8.0 9 | torch==1.8.1 10 | torchvision==0.9.1 11 | tqdm==4.64.0 12 | -------------------------------------------------------------------------------- /external_src/SuperRetina/save/.placehold: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrishnaswamyLab/ImageFlowNet/61412a9a014776e653de5bd713b6b817852be3d8/external_src/SuperRetina/save/.placehold -------------------------------------------------------------------------------- /external_src/SuperRetina/test_on_FIRE.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | from common.eval_util import compute_auc 7 | from predictor import Predictor 8 | import os 9 | import cv2 10 | import yaml 11 | 12 | config_path = './config/test.yaml' 13 | if os.path.exists(config_path): 14 | with open(config_path) as f: 15 | config = yaml.safe_load(f) 16 | else: 17 | raise FileNotFoundError("Config File doesn't Exist") 18 | 19 | Pred = Predictor(config) 20 | 21 | data_path = './data/' # Change the data_path according to your own setup 22 | testset = 'FIRE' 23 | use_matching_trick = config['PREDICT']['use_matching_trick'] 24 | gt_dir = os.path.join(data_path, testset, 'Ground Truth') 25 | im_dir = os.path.join(data_path, testset, 'Images') 26 | 27 | match_pairs = [x for x in os.listdir(gt_dir) if x.endswith('.txt') 28 | and not x.endswith('P37_1_2.txt')] 29 | 30 | match_pairs.sort() 31 | big_num = 1e6 32 | good_nums_rate = [] 33 | image_num = 0 34 
| 35 | failed = 0 36 | inaccurate = 0 37 | mae = 0 38 | mee = 0 39 | 40 | # category: S, P, A, corresponding to Easy, Hard, Mod in paper 41 | auc_record = dict([(category, []) for category in ['S', 'P', 'A']]) 42 | 43 | for pair_file in tqdm(match_pairs): 44 | gt_file = os.path.join(gt_dir, pair_file) 45 | file_name = pair_file.replace('.txt', '') 46 | 47 | category = file_name.split('_')[2][0] 48 | 49 | refer = file_name.split('_')[2] + '_' + file_name.split('_')[3] 50 | query = file_name.split('_')[2] + '_' + file_name.split('_')[4] 51 | 52 | query_im_path = os.path.join(im_dir, query + '.jpg') 53 | refer_im_path = os.path.join(im_dir, refer + '.jpg') 54 | H_m1, inliers_num_rate, query_image, _ = Pred.compute_homography(query_im_path, refer_im_path) 55 | H_m2 = None 56 | if use_matching_trick: 57 | if H_m1 is not None: 58 | h, w = Pred.image_height, Pred.image_width 59 | query_align_first = cv2.warpPerspective(query_image, H_m1, (w, h), borderMode=cv2.BORDER_CONSTANT, 60 | borderValue=(0)) 61 | query_align_first = query_align_first.astype(float) 62 | query_align_first /= 255. 63 | H_m2, inliers_num_rate, _, _ = Pred.compute_homography(query_align_first, refer_im_path, query_is_image=True) 64 | 65 | good_nums_rate.append(inliers_num_rate) 66 | image_num += 1 67 | 68 | if inliers_num_rate < 1e-6: 69 | failed += 1 70 | avg_dist = big_num 71 | else: 72 | points_gd = np.loadtxt(gt_file) 73 | raw = np.zeros([len(points_gd), 2]) 74 | dst = np.zeros([len(points_gd), 2]) 75 | raw[:, 0] = points_gd[:, 2] 76 | raw[:, 1] = points_gd[:, 3] 77 | dst[:, 0] = points_gd[:, 0] 78 | dst[:, 1] = points_gd[:, 1] 79 | dst_pred = cv2.perspectiveTransform(raw.reshape(-1, 1, 2), H_m1) 80 | if H_m2 is not None: 81 | dst_pred = cv2.perspectiveTransform(dst_pred.reshape(-1, 1, 2), H_m2) 82 | 83 | dst_pred = dst_pred.squeeze() 84 | 85 | dis = (dst - dst_pred) ** 2 86 | dis = np.sqrt(dis[:, 0] + dis[:, 1]) 87 | avg_dist = dis.mean() 88 | 89 | mae = dis.max() 90 | mee = np.median(dis) 91 | if mae > 50 or mee > 20: 92 | inaccurate += 1 93 | 94 | auc_record[category].append(avg_dist) 95 | 96 | 97 | print('-'*40) 98 | print(f"Failed:{'%.2f' % (100*failed/image_num)}%, Inaccurate:{'%.2f' % (100*inaccurate/image_num)}%, " 99 | f"Acceptable:{'%.2f' % (100*(image_num-inaccurate-failed)/image_num)}%") 100 | 101 | print('-'*40) 102 | 103 | auc = compute_auc(auc_record['S'], auc_record['P'], auc_record['A']) 104 | print('S: %.3f, P: %.3f, A: %.3f, mAUC: %.3f' % (auc['s'], auc['p'], auc['a'], auc['mAUC'])) 105 | -------------------------------------------------------------------------------- /external_src/SuperRetina/test_on_VARIA.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | 3 | import torch 4 | from scipy.optimize import brentq 5 | from scipy.interpolate import interp1d 6 | from sklearn.metrics import roc_curve 7 | 8 | import numpy as np 9 | from predictor import Predictor 10 | import os 11 | import yaml 12 | from tqdm import tqdm 13 | 14 | config_path = './config/test_VARIA.yaml' 15 | if os.path.exists(config_path): 16 | with open(config_path) as f: 17 | config = yaml.safe_load(f) 18 | else: 19 | raise FileNotFoundError("Config File doesn't Exist") 20 | 21 | Pred = Predictor(config) 22 | 23 | data_path = './data/' # Change the data_path according to your own setup 24 | testset = 'VARIA' 25 | save_tmp_info = 'VARIA_info' # Save keypoints and descriptors 26 | id_path = os.path.join(data_path, testset, 'pair_index.txt') 27 | im_dir = os.path.join(data_path, testset, 
'Images') 28 | 29 | save_tmp_info = os.path.join(data_path, testset, save_tmp_info) 30 | 31 | if os.path.exists(save_tmp_info): 32 | shutil.rmtree(save_tmp_info) 33 | os.makedirs(save_tmp_info) 34 | 35 | with open(id_path, 'r') as f: 36 | pairs = f.readlines() 37 | 38 | inliers = [] 39 | classes = [] 40 | 41 | print('Getting Predictions of All Images') 42 | for img in tqdm(os.listdir(im_dir)): 43 | if img.endswith('.pgm'): 44 | image_path = os.path.join(im_dir, img) 45 | save_path = os.path.join(save_tmp_info, img.replace('.pgm', '.pt')) 46 | 47 | Pred.model_run_one_image(image_path, save_path) 48 | 49 | print('Matching Pairs') 50 | for pair in tqdm(pairs): 51 | pair = pair.strip() 52 | query, refer, is_accepted = pair.split(', ') 53 | 54 | query_im_path = os.path.join(im_dir, query) 55 | refer_im_path = os.path.join(im_dir, refer) 56 | 57 | query_info_path = os.path.join(save_tmp_info, query.replace('.pgm', '.pt')) 58 | refer_info_path = os.path.join(save_tmp_info, refer.replace('.pgm', '.pt')) 59 | 60 | query_info = torch.load(query_info_path) 61 | refer_info = torch.load(refer_info_path) 62 | 63 | _, inliers_num = Pred.homography_from_tensor(query_info, refer_info) 64 | 65 | inliers.append(inliers_num) 66 | classes.append(int(is_accepted)) 67 | 68 | inliers = np.array(inliers) 69 | classes = np.array(classes) 70 | 71 | fpr, tpr, threshold = roc_curve(classes, inliers) 72 | eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.) 73 | thresh = interp1d(fpr, threshold)(eer) 74 | 75 | print('VARIA DATASET') 76 | print('EER: %.2f%%, threshold: %d' % (eer*100, thresh)) 77 | 78 | -------------------------------------------------------------------------------- /external_src/SuperRetina/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | from torch.utils.data import DataLoader 5 | 6 | from common.train_util import train_model 7 | from dataset.retina_dataset import RetinaDataset 8 | from model.super_retina import SuperRetina 9 | import torch.optim as optim 10 | import yaml 11 | from torch.optim import lr_scheduler 12 | import warnings 13 | 14 | warnings.filterwarnings('ignore') 15 | 16 | config_path = './config/train.yaml' 17 | 18 | if os.path.exists(config_path): 19 | with open(config_path) as f: 20 | config = yaml.safe_load(f) 21 | else: 22 | raise FileNotFoundError("Config File doesn't Exist") 23 | 24 | assert 'MODEL' in config 25 | assert 'PKE' in config 26 | assert 'DATASET' in config 27 | assert 'VALUE_MAP' in config 28 | train_config = {**config['MODEL'], **config['PKE'], **config['DATASET'], **config['VALUE_MAP']} 29 | 30 | batch_size = train_config['batch_size'] 31 | num_epoch = train_config['num_epoch'] 32 | device = train_config['device'] 33 | device = torch.device(device if torch.cuda.is_available() else "cpu") 34 | 35 | dataset_path = train_config['dataset_path'] 36 | data_shape = (train_config['model_image_height'], train_config['model_image_width']) 37 | 38 | train_split_file = train_config['train_split_file'] 39 | val_split_file = train_config['val_split_file'] 40 | auxiliary = train_config['auxiliary'] 41 | train_set = RetinaDataset(dataset_path, split_file=train_split_file, 42 | is_train=True, data_shape=data_shape, auxiliary=auxiliary) 43 | val_set = RetinaDataset(dataset_path, split_file=val_split_file, is_train=False, data_shape=data_shape) 44 | 45 | load_pre_trained_model = train_config['load_pre_trained_model'] 46 | pretrained_path = train_config['pretrained_path'] 47 | 48 | model = 
SuperRetina(train_config, device=device) 49 | if load_pre_trained_model: 50 | if not os.path.exists(pretrained_path): 51 | raise Exception('Pretrained model doesn\'t exist') 52 | checkpoint = torch.load(pretrained_path, map_location=device) 53 | model.load_state_dict(checkpoint['net']) 54 | 55 | optimizer = optim.Adam(model.parameters(), lr=1e-4) 56 | 57 | dataloaders = { 58 | 'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=8), 59 | 'val': DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=8) 60 | } 61 | 62 | model = train_model(model, optimizer, dataloaders, device, num_epochs=num_epoch, train_config=train_config) 63 | 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /results/placeholder.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrishnaswamyLab/ImageFlowNet/61412a9a014776e653de5bd713b6b817852be3d8/results/placeholder.md -------------------------------------------------------------------------------- /src/data_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrishnaswamyLab/ImageFlowNet/61412a9a014776e653de5bd713b6b817852be3d8/src/data_utils/__init__.py -------------------------------------------------------------------------------- /src/data_utils/extend.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class ExtendedDataset(Dataset): 8 | def __init__(self, 9 | dataset: Dataset, 10 | desired_len: int): 11 | self.dataset = dataset 12 | self.desired_len = desired_len 13 | 14 | def __len__(self) -> int: 15 | return self.desired_len 16 | 17 | def __getitem__(self, idx) -> Tuple[np.array, np.array]: 18 | return self.dataset.__getitem__(idx % len(self.dataset)) 19 | -------------------------------------------------------------------------------- /src/data_utils/split.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset, random_split 6 | 7 | 8 | def split_dataset(dataset: Dataset, 9 | splits: Tuple[float, ] = (0.8, 0.1, 0.1), 10 | random_seed: int = 0) -> Tuple[Dataset, ]: 11 | """ 12 | Splits data into non-overlapping datasets of given proportions. 13 | 14 | Either a "train/validation/test" split 15 | Or a "train/validation" split is supported. 
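Example (illustrative; assumes a Dataset instance named `dataset`):
    train_set, val_set, test_set = split_dataset(dataset, splits=(0.8, 0.1, 0.1), random_seed=1)
Splits are normalized to sum to 1, so (8, 1, 1) behaves the same as (0.8, 0.1, 0.1).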
16 | """ 17 | assert len(splits) in [2, 3] 18 | 19 | splits = np.array(splits) 20 | splits = splits / np.sum(splits) 21 | 22 | n = len(dataset) 23 | if len(splits) == 2: 24 | val_size = int(splits[1] * n) 25 | train_size = n - val_size 26 | train_set, val_set = random_split( 27 | dataset, [train_size, val_size], 28 | generator=torch.Generator().manual_seed(random_seed)) 29 | return train_set, val_set 30 | else: 31 | val_size = int(splits[1] * n) 32 | test_size = int(splits[2] * n) 33 | train_size = n - val_size - test_size 34 | train_set, val_set, test_set = random_split( 35 | dataset, [train_size, val_size, test_size], 36 | generator=torch.Generator().manual_seed(random_seed)) 37 | return train_set, val_set, test_set 38 | 39 | 40 | def split_indices(indices: List[int] = None, 41 | splits: Tuple[float, ] = (0.8, 0.1, 0.1), 42 | random_seed: int = 0): 43 | """ 44 | Splits indices into non-overlapping subsets of given proportions. 45 | 46 | Either a "train/validation/test" split 47 | Or a "train/validation" split is supported. 48 | """ 49 | assert len(splits) in [2, 3] 50 | 51 | splits = np.array(splits) 52 | splits = splits / np.sum(splits) 53 | 54 | rng = np.random.default_rng(seed=random_seed) 55 | indices = rng.permutation(indices) 56 | 57 | n = len(indices) 58 | if len(splits) == 2: 59 | val_size = int(splits[1] * n) 60 | train_size = n - val_size 61 | train_indices = sorted(indices[:train_size]) 62 | val_indices = sorted(indices[train_size:]) 63 | return train_indices, val_indices 64 | else: 65 | val_size = int(splits[1] * n) 66 | test_size = int(splits[2] * n) 67 | train_size = n - val_size - test_size 68 | train_indices = sorted(indices[:train_size]) 69 | val_indices = sorted(indices[train_size:train_size + val_size]) 70 | test_indices = sorted(indices[train_size + val_size:]) 71 | return train_indices, val_indices, test_indices 72 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrishnaswamyLab/ImageFlowNet/61412a9a014776e653de5bd713b6b817852be3d8/src/datasets/__init__.py -------------------------------------------------------------------------------- /src/datasets/synthetic.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import os 3 | from typing import Literal 4 | from glob import glob 5 | from typing import List, Tuple 6 | 7 | import cv2 8 | import numpy as np 9 | from torch.utils.data import Dataset 10 | 11 | 12 | class SyntheticDataset(Dataset): 13 | 14 | def __init__(self, 15 | base_path: str = '../../data/synthesized/', 16 | image_folder: str = 'base/', 17 | target_dim: Tuple[int] = (256, 256)): 18 | ''' 19 | NOTE: since different patients may have different number of visits, the returned array will 20 | not necessarily be of the same shape. Due to the concatenation requirements, we can only 21 | set batch size to 1 in the downstream Dataloader. 22 | ''' 23 | super().__init__() 24 | 25 | self.target_dim = target_dim 26 | all_image_folders = sorted(glob('%s/%s/*/' % (base_path, image_folder))) 27 | 28 | self.image_by_patient = [] 29 | 30 | for folder in all_image_folders: 31 | paths = sorted(glob('%s/*.png' % (folder))) 32 | if len(paths) >= 2: 33 | self.image_by_patient.append(paths) 34 | 35 | def __len__(self) -> int: 36 | return len(self.image_by_patient) 37 | 38 | def num_image_channel(self) -> int: 39 | ''' Number of image channels. 
''' 40 | return 3 41 | 42 | 43 | class SyntheticSubset(SyntheticDataset): 44 | 45 | def __init__(self, 46 | main_dataset: SyntheticDataset = None, 47 | subset_indices: List[int] = None, 48 | return_format: str = Literal['one_pair', 'all_pairs', 49 | 'array']): 50 | ''' 51 | A subset of SyntheticDataset. 52 | 53 | In SyntheticDataset, we carefully isolated the (variable number of) images from 54 | different patients, and in train/val/test split we split the data by 55 | patient rather than by image. 56 | 57 | Now we have 3 instances of SyntheticSubset, one for each train/val/test set. 58 | In each set, we can safely unpack the images out. 59 | We want to organize the images such that each time `__getitem__` is called, 60 | it gets a pair of [x_start, x_end] and [t_start, t_end]. 61 | ''' 62 | super().__init__() 63 | 64 | self.target_dim = main_dataset.target_dim 65 | self.return_format = return_format 66 | 67 | self.image_by_patient = [ 68 | main_dataset.image_by_patient[i] for i in subset_indices 69 | ] 70 | 71 | self.all_image_pairs = [] 72 | for image_list in self.image_by_patient: 73 | pair_indices = list( 74 | itertools.combinations(np.arange(len(image_list)), r=2)) 75 | for (idx1, idx2) in pair_indices: 76 | self.all_image_pairs.append( 77 | [image_list[idx1], image_list[idx2]]) 78 | 79 | def __len__(self) -> int: 80 | if self.return_format == 'one_pair': 81 | # If we only return 1 pair of images per patient... 82 | return len(self.image_by_patient) 83 | elif self.return_format == 'all_pairs': 84 | # If we return all pairs of images per patient... 85 | return len(self.all_image_pairs) 86 | elif self.return_format == 'array': 87 | # If we return all images as an array per patient... 88 | return len(self.image_by_patient) 89 | 90 | def __getitem__(self, idx) -> Tuple[np.array, np.array]: 91 | if self.return_format == 'one_pair': 92 | image_list = self.image_by_patient[idx] 93 | pair_indices = list( 94 | itertools.combinations(np.arange(len(image_list)), r=2)) 95 | sampled_pair = [ 96 | image_list[i] 97 | for i in pair_indices[np.random.choice(len(pair_indices))] 98 | ] 99 | images = np.array([ 100 | load_image(p, target_dim=self.target_dim) for p in sampled_pair 101 | ]) 102 | timestamps = np.array([get_time(p) for p in sampled_pair]) 103 | 104 | elif self.return_format == 'all_pairs': 105 | queried_pair = self.all_image_pairs[idx] 106 | images = np.array([ 107 | load_image(p, target_dim=self.target_dim) for p in queried_pair 108 | ]) 109 | timestamps = np.array([get_time(p) for p in queried_pair]) 110 | 111 | elif self.return_format == 'array': 112 | queried_patient = self.image_by_patient[idx] 113 | images = np.array([ 114 | load_image(p, target_dim=self.target_dim) 115 | for p in queried_patient 116 | ]) 117 | timestamps = np.array([get_time(p) for p in queried_patient]) 118 | 119 | return images, timestamps 120 | 121 | 122 | def load_image(path: str, target_dim: Tuple[int] = None) -> np.array: 123 | ''' Load image as numpy array from a path string.''' 124 | if target_dim is not None: 125 | image = np.array( 126 | cv2.resize( 127 | cv2.cvtColor(cv2.imread(path, cv2.IMREAD_COLOR), 128 | code=cv2.COLOR_BGR2RGB), target_dim)) 129 | else: 130 | image = np.array( 131 | cv2.cvtColor(cv2.imread(path, cv2.IMREAD_COLOR), 132 | code=cv2.COLOR_BGR2RGB)) 133 | 134 | # Normalize image. 135 | image = (image / 255 * 2) - 1 136 | 137 | # Channel last to channel first to comply with Torch. 
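# (At this point `image` is an RGB array of shape (H, W, 3) scaled to [-1, 1];
#  the moveaxis call below turns it into the channel-first (3, H, W) layout used by PyTorch.)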
138 | image = np.moveaxis(image, -1, 0) 139 | 140 | return image 141 | 142 | 143 | def get_time(path: str) -> float: 144 | ''' Get the timestamp information from a path string. ''' 145 | time = path.split('time_')[1].replace('.png', '') 146 | # Shall be 3 digits 147 | assert len(time) == 3 148 | time = float(time) 149 | return time 150 | -------------------------------------------------------------------------------- /src/nn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrishnaswamyLab/ImageFlowNet/61412a9a014776e653de5bd713b6b817852be3d8/src/nn/__init__.py -------------------------------------------------------------------------------- /src/nn/autoencoder.py: -------------------------------------------------------------------------------- 1 | from .base import BaseNetwork 2 | from .common_encoder import Encoder 3 | from .nn_utils import ConvBlock, ResConvBlock 4 | import torch 5 | 6 | 7 | class AutoEncoder(BaseNetwork): 8 | 9 | def __init__(self, 10 | device: torch.device = torch.device('cpu'), 11 | num_filters: int = 16, 12 | depth: int = 5, 13 | use_residual: bool = False, 14 | in_channels: int = 3, 15 | out_channels: int = 3, 16 | non_linearity: str = 'relu'): 17 | ''' 18 | A vanilla AutoEncoder mode. 19 | 20 | Parameters 21 | ---------- 22 | device: torch.device 23 | num_filters : int 24 | Number of convolutional filters. 25 | depth: int 26 | Depth of the model (encoding or decoding) 27 | use_residual: bool 28 | Whether to use residual connection within the same conv block 29 | in_channels: int 30 | Number of input image channels. 31 | out_channels: int 32 | Number of output image channels. 33 | non_linearity : string 34 | One of 'relu' and 'softplus' 35 | ''' 36 | super().__init__() 37 | 38 | self.device = device 39 | self.depth = depth 40 | self.use_residual = use_residual 41 | self.in_channels = in_channels 42 | self.non_linearity_str = non_linearity 43 | if self.non_linearity_str == 'relu': 44 | self.non_linearity = torch.nn.ReLU(inplace=True) 45 | elif self.non_linearity_str == 'softplus': 46 | self.non_linearity = torch.nn.Softplus() 47 | 48 | n_f = num_filters # shorthand 49 | 50 | if self.use_residual: 51 | conv_block = ResConvBlock 52 | upconv_block = ResConvBlock 53 | else: 54 | conv_block = ConvBlock 55 | upconv_block = ConvBlock 56 | 57 | # This is for the encoder. 58 | self.encoder = Encoder(in_channels=in_channels, 59 | n_f=n_f, 60 | depth=self.depth, 61 | conv_block=conv_block, 62 | non_linearity=self.non_linearity) 63 | 64 | # This is for the decoder. 65 | self.up_list = torch.nn.ModuleList([]) 66 | self.up_conn_list = torch.nn.ModuleList([]) 67 | for d in range(self.depth): 68 | self.up_conn_list.append(torch.nn.Conv2d(n_f * 2 ** (d + 1), n_f * 2 ** d, 1, 1)) 69 | self.up_list.append(upconv_block(n_f * 2 ** d)) 70 | self.up_list = self.up_list[::-1] 71 | self.up_conn_list = self.up_conn_list[::-1] 72 | 73 | self.out_layer = torch.nn.Conv2d(n_f, out_channels, 1) 74 | 75 | 76 | def forward(self, x: torch.Tensor, t: torch.Tensor): 77 | ''' 78 | No time embedding. `t` is treated as a dummy variable. 
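Expects a single-sample batch: `x` of shape (1, in_channels, H, W), with H and W
divisible by 2 ** depth so the decoder recovers the input resolution; returns a
tensor of shape (1, out_channels, H, W).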
79 | ''' 80 | 81 | assert x.shape[0] == 1 82 | 83 | x, _ = self.encoder(x) 84 | 85 | for d in range(self.depth): 86 | x = torch.nn.functional.interpolate(x, 87 | scale_factor=2, 88 | mode='bilinear', 89 | align_corners=True) 90 | x = self.non_linearity(self.up_conn_list[d](x)) 91 | x = self.up_list[d](x) 92 | 93 | output = self.out_layer(x) 94 | 95 | return output 96 | -------------------------------------------------------------------------------- /src/nn/autoencoder_ode.py: -------------------------------------------------------------------------------- 1 | from .base import BaseNetwork 2 | from .nn_utils import ConvBlock, ResConvBlock, ODEfunc, ODEBlock 3 | from .common_encoder import Encoder 4 | import torch 5 | 6 | 7 | class ODEAutoEncoder(BaseNetwork): 8 | 9 | def __init__(self, 10 | device: torch.device = torch.device('cpu'), 11 | num_filters: int = 16, 12 | depth: int = 5, 13 | use_residual: bool = False, 14 | in_channels: int = 3, 15 | out_channels: int = 3, 16 | non_linearity: str = 'relu'): 17 | ''' 18 | An AutoEncoder model with ODE. 19 | 20 | Parameters 21 | ---------- 22 | device: torch.device 23 | num_filters : int 24 | Number of convolutional filters. 25 | depth: int 26 | Depth of the model (encoding or decoding) 27 | use_residual: bool 28 | Whether to use residual connection within the same conv block 29 | in_channels: int 30 | Number of input image channels. 31 | out_channels: int 32 | Number of output image channels. 33 | non_linearity : string 34 | One of 'relu' and 'softplus' 35 | ''' 36 | super().__init__() 37 | 38 | self.device = device 39 | self.depth = depth 40 | self.use_residual = use_residual 41 | self.in_channels = in_channels 42 | self.non_linearity_str = non_linearity 43 | if self.non_linearity_str == 'relu': 44 | self.non_linearity = torch.nn.ReLU(inplace=True) 45 | elif self.non_linearity_str == 'softplus': 46 | self.non_linearity = torch.nn.Softplus() 47 | 48 | n_f = num_filters # shorthand 49 | 50 | if self.use_residual: 51 | conv_block = ResConvBlock 52 | upconv_block = ResConvBlock 53 | else: 54 | conv_block = ConvBlock 55 | upconv_block = ConvBlock 56 | 57 | # This is for the encoder. 58 | self.encoder = Encoder(in_channels=in_channels, 59 | n_f=n_f, 60 | depth=self.depth, 61 | conv_block=conv_block, 62 | non_linearity=self.non_linearity) 63 | 64 | # This is for the decoder. 65 | self.up_list = torch.nn.ModuleList([]) 66 | self.up_conn_list = torch.nn.ModuleList([]) 67 | for d in range(self.depth): 68 | self.up_conn_list.append(torch.nn.Conv2d(n_f * 2 ** (d + 1), n_f * 2 ** d, 1, 1)) 69 | self.up_list.append(upconv_block(n_f * 2 ** d)) 70 | self.up_list = self.up_list[::-1] 71 | self.up_conn_list = self.up_conn_list[::-1] 72 | 73 | self.ode_bottleneck = ODEBlock(ODEfunc(dim=n_f * 2 ** self.depth)) 74 | self.out_layer = torch.nn.Conv2d(n_f, out_channels, 1) 75 | 76 | 77 | def forward(self, x: torch.Tensor, t: torch.Tensor): 78 | ''' 79 | Time embedding through ODE. 80 | ''' 81 | 82 | assert x.shape[0] == 1 83 | 84 | # Skip ODE if no time difference. 
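# (The ODE block integrates the bottleneck features from time 0 to t.item();
#  when t == 0 the flow reduces to the identity, so the integration is skipped.)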
85 | use_ode = t.item() != 0 86 | if use_ode: 87 | integration_time = torch.tensor([0, t.item()]).float().to(t.device) 88 | 89 | x, _ = self.encoder(x) 90 | 91 | if use_ode: 92 | x = self.ode_bottleneck(x, integration_time) 93 | 94 | for d in range(self.depth): 95 | x = torch.nn.functional.interpolate(x, 96 | scale_factor=2, 97 | mode='bilinear', 98 | align_corners=True) 99 | x = self.non_linearity(self.up_conn_list[d](x)) 100 | x = self.up_list[d](x) 101 | 102 | output = self.out_layer(x) 103 | 104 | return output 105 | -------------------------------------------------------------------------------- /src/nn/autoencoder_t_emb.py: -------------------------------------------------------------------------------- 1 | from .base import BaseNetwork 2 | from .nn_utils import ConvBlock, ResConvBlock, timestep_embedding 3 | from .common_encoder import Encoder 4 | import torch 5 | 6 | 7 | class T_AutoEncoder(BaseNetwork): 8 | 9 | def __init__(self, 10 | device: torch.device = torch.device('cpu'), 11 | num_filters: int = 16, 12 | depth: int = 5, 13 | use_residual: bool = False, 14 | in_channels: int = 3, 15 | out_channels: int = 3, 16 | non_linearity: str = 'relu'): 17 | ''' 18 | An AutoEncoder model with time embedding. 19 | 20 | Parameters 21 | ---------- 22 | device: torch.device 23 | num_filters : int 24 | Number of convolutional filters. 25 | depth: int 26 | Depth of the model (encoding or decoding) 27 | use_residual: bool 28 | Whether to use residual connection within the same conv block 29 | in_channels: int 30 | Number of input image channels. 31 | out_channels: int 32 | Number of output image channels. 33 | non_linearity : string 34 | One of 'relu' and 'softplus' 35 | ''' 36 | super().__init__() 37 | 38 | self.device = device 39 | self.depth = depth 40 | self.use_residual = use_residual 41 | self.in_channels = in_channels 42 | self.non_linearity_str = non_linearity 43 | if self.non_linearity_str == 'relu': 44 | self.non_linearity = torch.nn.ReLU(inplace=True) 45 | elif self.non_linearity_str == 'softplus': 46 | self.non_linearity = torch.nn.Softplus() 47 | 48 | n_f = num_filters # shorthand 49 | 50 | if self.use_residual: 51 | conv_block = ResConvBlock 52 | upconv_block = ResConvBlock 53 | else: 54 | conv_block = ConvBlock 55 | upconv_block = ConvBlock 56 | 57 | # This is for the encoder. 58 | self.encoder = Encoder(in_channels=in_channels, 59 | n_f=n_f, 60 | depth=self.depth, 61 | conv_block=conv_block, 62 | non_linearity=self.non_linearity) 63 | 64 | # This is for the decoder. 65 | self.up_list = torch.nn.ModuleList([]) 66 | self.up_conn_list = torch.nn.ModuleList([]) 67 | for d in range(self.depth): 68 | self.up_conn_list.append(torch.nn.Conv2d(n_f * 2 ** (d + 1), n_f * 2 ** d, 1, 1)) 69 | self.up_list.append(upconv_block(n_f * 2 ** d)) 70 | self.up_list = self.up_list[::-1] 71 | self.up_conn_list = self.up_conn_list[::-1] 72 | 73 | self.time_embed_dim = n_f * 2 ** self.depth 74 | self.time_embed = torch.nn.Sequential( 75 | torch.torch.nn.Linear(self.time_embed_dim, self.time_embed_dim), 76 | torch.nn.SiLU(), 77 | torch.torch.nn.Linear(self.time_embed_dim, self.time_embed_dim), 78 | ) 79 | self.out_layer = torch.nn.Conv2d(n_f, out_channels, 1) 80 | 81 | 82 | def forward(self, x: torch.Tensor, t: torch.Tensor): 83 | ''' 84 | Time embedding through sinusoidal embedding. 85 | ''' 86 | 87 | assert x.shape[0] == 1 88 | 89 | x, _ = self.encoder(x) 90 | 91 | # Time embedding through feature space addition. 
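# (A sinusoidal timestep embedding of width n_f * 2 ** depth is passed through the
#  small MLP defined in __init__, broadcast over the spatial grid, and added to the
#  bottleneck features before decoding.)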
92 | assert x.shape[0] == 1 and x.shape[1] == self.time_embed_dim 93 | t_emb = self.time_embed(timestep_embedding(t, dim=self.time_embed_dim)) 94 | t_emb = t_emb[:, :, None, None].repeat((1, 1, x.shape[2], x.shape[3])) 95 | x = x + t_emb 96 | 97 | for d in range(self.depth): 98 | x = torch.nn.functional.interpolate(x, 99 | scale_factor=2, 100 | mode='bilinear', 101 | align_corners=True) 102 | x = self.non_linearity(self.up_conn_list[d](x)) 103 | x = self.up_list[d](x) 104 | 105 | output = self.out_layer(x) 106 | 107 | return output 108 | 109 | 110 | -------------------------------------------------------------------------------- /src/nn/aux_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .base import BaseNetwork 4 | from .nn_utils import ConvBlock, ResConvBlock 5 | from .common_encoder import Encoder 6 | 7 | 8 | class AuxNet(BaseNetwork): 9 | 10 | def __init__(self, 11 | device: torch.device = torch.device('cpu'), 12 | num_filters: int = 16, 13 | depth: int = 5, 14 | use_residual: bool = False, 15 | in_channels: int = 3, 16 | out_channels: int = 1, 17 | dim_proj: int = 256, 18 | non_linearity: str = 'relu'): 19 | ''' 20 | Auxiliary Network that performs projection and segmentation. 21 | The projection head brings the model to a vector space of `dim_proj` dimensions, 22 | and performs contrastive learning. It then serves as a "discriminator" to train 23 | the main network through loss backpropagation. 24 | 25 | Parameters 26 | ---------- 27 | device: torch.device 28 | num_filters : int 29 | Number of convolutional filters. 30 | depth: int 31 | Depth of the model (encoding or decoding) 32 | use_residual: bool 33 | Whether to use residual connection within the same conv block 34 | in_channels: int 35 | Number of input image channels. 36 | out_channels: int 37 | Number of output image channels. 38 | non_linearity : string 39 | One of 'relu' and 'softplus' 40 | ''' 41 | super().__init__() 42 | 43 | self.device = device 44 | self.depth = depth 45 | self.use_residual = use_residual 46 | self.in_channels = in_channels 47 | self.dim_proj = dim_proj 48 | self.non_linearity_str = non_linearity 49 | if self.non_linearity_str == 'relu': 50 | self.non_linearity = torch.nn.ReLU(inplace=True) 51 | elif self.non_linearity_str == 'softplus': 52 | self.non_linearity = torch.nn.Softplus() 53 | 54 | n_f = num_filters # shorthand 55 | 56 | if self.use_residual: 57 | conv_block = ResConvBlock 58 | upconv_block = ResConvBlock 59 | else: 60 | conv_block = ConvBlock 61 | upconv_block = ConvBlock 62 | 63 | # This is for the encoder. 64 | self.encoder = Encoder(in_channels=in_channels, 65 | n_f=n_f, 66 | depth=self.depth, 67 | conv_block=conv_block, 68 | non_linearity=self.non_linearity) 69 | 70 | # This is for the segmentation head. 
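# (Each decoder stage upsamples, concatenates the matching encoder residual, and
#  projects back with a 1x1 conv; the 3x input channels below come from the 2x
#  channels of the upsampled coarser level plus the 1x channels of the skip
#  connection. The final `seg_head` then maps to `out_channels` and applies a sigmoid.)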
71 | self.up_list = torch.nn.ModuleList([]) 72 | self.up_conn_list = torch.nn.ModuleList([]) 73 | for d in range(self.depth): 74 | self.up_conn_list.append(torch.nn.Conv2d(n_f * 3 * 2 ** d, n_f * 2 ** d, 1, 1)) 75 | self.up_list.append(upconv_block(n_f * 2 ** d)) 76 | self.up_list = self.up_list[::-1] 77 | self.up_conn_list = self.up_conn_list[::-1] 78 | 79 | self.seg_head = torch.nn.ModuleList([ 80 | conv_block(n_f), 81 | torch.nn.Conv2d(n_f, out_channels, 1), 82 | torch.nn.Sigmoid(), 83 | ]) 84 | 85 | # This is for the projection head 86 | self.proj_head = torch.nn.ModuleList([ 87 | conv_block(n_f * 2 ** self.depth), 88 | torch.nn.AdaptiveAvgPool2d((1, 1)), 89 | torch.nn.Flatten(), 90 | torch.nn.Linear(n_f * 2 ** self.depth, n_f * 2 ** self.depth, bias=False), 91 | torch.nn.ReLU(), 92 | torch.nn.Linear(n_f * 2 ** self.depth, self.dim_proj, bias=False), 93 | ]) 94 | 95 | def forward_seg(self, x: torch.Tensor): 96 | ''' 97 | Forward through the segmentation path. 98 | ''' 99 | 100 | x, residual_list = self.encoder(x) 101 | 102 | for d in range(self.depth): 103 | x = torch.nn.functional.interpolate(x, 104 | scale_factor=2, 105 | mode='bilinear', 106 | align_corners=False) 107 | x = torch.cat([x, residual_list.pop(-1)], dim=1) 108 | x = self.non_linearity(self.up_conn_list[d](x)) 109 | x = self.up_list[d](x) 110 | 111 | for module in self.seg_head: 112 | x = module(x) 113 | 114 | return x 115 | 116 | def forward_proj(self, x: torch.Tensor): 117 | ''' 118 | Forward through the projection path. 119 | ''' 120 | 121 | x, _ = self.encoder(x) 122 | 123 | for module in self.proj_head: 124 | x = module(x) 125 | 126 | # L2-normalize to put it onto the unit hypersphere. 127 | x = torch.nn.functional.normalize(x, dim=-1) 128 | return x 129 | 130 | def forward(self, *args, **kwargs): 131 | raise NotImplementedError('Please use `forward_seg` or `forward_proj` instead.') 132 | -------------------------------------------------------------------------------- /src/nn/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | class BaseNetwork(torch.nn.Module): 6 | ''' 7 | An base network class. For defining common utilities such as loading/saving. 8 | ''' 9 | 10 | def __init__(self, **kwargs): 11 | super(BaseNetwork, self).__init__() 12 | pass 13 | 14 | def forward(self, *args, **kwargs): 15 | pass 16 | 17 | def save_weights(self, model_save_path: str) -> None: 18 | os.makedirs(os.path.dirname(model_save_path), exist_ok=True) 19 | torch.save(self.state_dict(), model_save_path) 20 | return 21 | 22 | def load_weights(self, model_save_path: str, device: torch.device) -> None: 23 | self.load_state_dict(torch.load(model_save_path, map_location=device)) 24 | return 25 | 26 | def init_params(self): 27 | ''' 28 | Parameter initialization. 29 | ''' 30 | for m in self.modules(): 31 | if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.ConvTranspose2d): 32 | torch.nn.init.kaiming_normal_(m.weight, mode='fan_in') 33 | if m.bias is not None: 34 | torch.nn.init.constant_(m.bias, 0) 35 | elif isinstance(m, torch.nn.BatchNorm2d): 36 | torch.nn.init.constant_(m.weight, 1) 37 | torch.nn.init.constant_(m.bias, 0) 38 | elif isinstance(m, torch.nn.Linear): 39 | torch.nn.init.normal_(m.weight, std=1e-3) 40 | if m.bias is not None: 41 | torch.nn.init.constant_(m.bias, 0) 42 | 43 | def freeze(self): 44 | ''' 45 | Freeze parameters. 
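        (Sets requires_grad = False on every parameter; unfreeze() reverses this.)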
46 | ''' 47 | for p in self.parameters(): 48 | p.requires_grad = False 49 | 50 | def unfreeze(self): 51 | ''' 52 | Freeze parameters. 53 | ''' 54 | for p in self.parameters(): 55 | p.requires_grad = True 56 | -------------------------------------------------------------------------------- /src/nn/common_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base import BaseNetwork 3 | 4 | 5 | class Encoder(BaseNetwork): 6 | 7 | def __init__(self, 8 | in_channels: int, 9 | n_f: int, 10 | depth: int, 11 | conv_block: torch.nn.Module, 12 | non_linearity: torch.nn.Module, 13 | bilinear: bool = True, 14 | use_bn: bool = True): 15 | super().__init__() 16 | 17 | self.depth = depth 18 | self.non_linearity = non_linearity 19 | self.bilinear = bilinear 20 | self.use_bn = use_bn 21 | 22 | self.conv1x1 = torch.nn.Conv2d(in_channels, n_f, 1, 1) 23 | self.down_list = torch.nn.ModuleList([]) 24 | self.down_conn_list = torch.nn.ModuleList([]) 25 | for d in range(self.depth): 26 | self.down_list.append(conv_block(n_f * 2 ** d)) 27 | if self.use_bn: 28 | self.down_conn_list.append(torch.nn.Sequential( 29 | torch.nn.Conv2d(n_f * 2 ** d, n_f * 2 ** (d + 1), 1, 1), 30 | torch.nn.BatchNorm2d(n_f * 2 ** (d + 1)), 31 | )) 32 | else: 33 | self.down_conn_list.append(torch.nn.Conv2d(n_f * 2 ** d, n_f * 2 ** (d + 1), 1, 1)) 34 | 35 | self.bottleneck = conv_block(n_f * 2 ** self.depth) 36 | 37 | if not self.bilinear: 38 | self.pooling = torch.nn.MaxPool2d(2) 39 | 40 | def forward(self, x: torch.Tensor): 41 | x = self.non_linearity(self.conv1x1(x)) 42 | 43 | residual_list = [] 44 | for d in range(self.depth): 45 | x = self.down_list[d](x) 46 | residual_list.append(x.clone()) 47 | x = self.non_linearity(self.down_conn_list[d](x)) 48 | if self.bilinear: 49 | x = torch.nn.functional.interpolate(x, 50 | scale_factor=0.5, 51 | mode='bilinear', 52 | align_corners=False) 53 | else: 54 | x = self.pooling(x) 55 | 56 | x = self.bottleneck(x) 57 | 58 | return x, residual_list 59 | 60 | def freeze_weights(self) -> None: 61 | ''' 62 | Freeze the weights and make them unchangable during training. 63 | ''' 64 | for param in self.parameters(): 65 | param.requires_grad = False 66 | 67 | def copy_weights(self, other_instance: torch.nn.Module) -> None: 68 | ''' 69 | Copy the weights from a given instance 70 | ''' 71 | with torch.no_grad(): 72 | for this_param, other_param in zip(self.parameters(), other_instance.parameters()): 73 | this_param.data.copy_(other_param.data) 74 | -------------------------------------------------------------------------------- /src/nn/imageflownet_ode.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import torch 5 | from .base import BaseNetwork 6 | from .nn_utils import PPODEfunc, ODEBlock 7 | 8 | import_dir = '/'.join(os.path.realpath(__file__).split('/')[:-3]) 9 | sys.path.insert(0, import_dir + '/external_src/I2SB/') 10 | from guided_diffusion.script_util import create_model 11 | from guided_diffusion.unet import timestep_embedding 12 | 13 | 14 | class ImageFlowNetODE(BaseNetwork): 15 | 16 | def __init__(self, 17 | device: torch.device, 18 | in_channels: int, 19 | ode_location: str = 'all_connections', 20 | contrastive: bool = False, 21 | **kwargs): 22 | ''' 23 | ImageFlowNet model with ODE. 24 | NOTE: This is a UNet with a position-paramterized ODE vector field. 
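        Here "position-parametrized" means the ODE vector field is a function of the latent
        state (position) rather than of the integration time; accordingly, forward() feeds
        the UNet blocks a dummy zero time embedding and passes the true time gap only to the
        ODE blocks.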
25 | 26 | Parameters 27 | ---------- 28 | device: torch.device 29 | in_channels: int 30 | Number of input image channels. 31 | 32 | ode_location: str 33 | If 'bottleneck', only perform ODE on the bottleneck layer. 34 | If 'all_resolutions', skip connections with the same resolution share the same ODE. 35 | If 'all_connections', perform ODE separately in all skip connections. 36 | 37 | contrastive: bool 38 | Whether or not to perform contrastive learning (SimSiam) on bottleneck layer. 39 | 40 | All other kwargs will be ignored. 41 | ''' 42 | super().__init__() 43 | 44 | self.device = device 45 | self.ode_location = ode_location 46 | assert self.ode_location in ['bottleneck', 'all_resolutions', 'all_connections'] 47 | self.contrastive = contrastive 48 | 49 | image_size = 256 # TODO: currently hard coded 50 | 51 | # NOTE: This model is smaller than the other counterparts, 52 | # because running NeuralODE require some significant GPU space. 53 | # initialize model 54 | self.unet = create_model( 55 | image_size=image_size, 56 | in_channels=in_channels, 57 | num_channels=64, 58 | num_res_blocks=1, 59 | channel_mult='', 60 | learn_sigma=False, 61 | class_cond=False, 62 | use_checkpoint=False, 63 | attention_resolutions='32,16,8', 64 | num_heads=4, 65 | num_head_channels=16, 66 | num_heads_upsample=-1, 67 | use_scale_shift_norm=True, 68 | dropout=0.0, 69 | resblock_updown=True, 70 | use_fp16=False, 71 | use_new_attention_order=False) 72 | 73 | # Record the channel dimensions by passing in a dummy tensor. 74 | self.dim_list = [] 75 | h_dummy = torch.zeros((1, 1, image_size, image_size)).type(self.unet.dtype) 76 | t_dummy = torch.zeros((1)).type(self.unet.dtype) 77 | emb = self.unet.time_embed(timestep_embedding(t_dummy, self.unet.model_channels)) 78 | for module in self.unet.input_blocks: 79 | h_dummy = module(h_dummy, emb) 80 | self.dim_list.append(h_dummy.shape[1]) 81 | h_dummy_bottleneck = self.unet.middle_block(h_dummy, emb) 82 | self.dim_list.append(h_dummy_bottleneck.shape[1]) 83 | 84 | # Construct the ODE modules. 85 | self.ode_list = torch.nn.ModuleList([]) 86 | if self.ode_location == 'bottleneck': 87 | self.ode_list.append(ODEBlock(PPODEfunc(dim=h_dummy_bottleneck.shape[1]))) 88 | elif self.ode_location == 'all_resolutions': 89 | for dim in np.unique(self.dim_list): 90 | self.ode_list.append(ODEBlock(PPODEfunc(dim=dim))) 91 | elif self.ode_location == 'all_connections': 92 | for dim in self.dim_list: 93 | self.ode_list.append(ODEBlock(PPODEfunc(dim=dim))) 94 | 95 | self.unet.to(self.device) 96 | self.ode_list.to(self.device) 97 | 98 | if self.contrastive: 99 | pred_dim = 256 100 | self.projector = torch.nn.Sequential( 101 | torch.nn.Linear(h_dummy_bottleneck.shape[1] * 102 | h_dummy_bottleneck.shape[2] * 103 | h_dummy_bottleneck.shape[3], pred_dim) 104 | ) 105 | self.predictor = torch.nn.Sequential( 106 | torch.nn.Linear(pred_dim, pred_dim, bias=False), 107 | torch.nn.ReLU(inplace=True), 108 | torch.nn.Linear(pred_dim, pred_dim), 109 | ) 110 | self.projector.to(self.device) 111 | self.predictor.to(self.device) 112 | 113 | def time_independent_parameters(self): 114 | ''' 115 | Parameters related to ODE. 116 | ''' 117 | return set(self.parameters()) - set(self.ode_list.parameters()) 118 | 119 | def freeze_time_independent(self): 120 | ''' 121 | Freeze paramters that are time-independent. 
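        (This freezes the UNet backbone; only the ODE blocks in self.ode_list remain trainable.)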
122 | ''' 123 | for p in self.time_independent_parameters(): 124 | p.requires_grad = False 125 | 126 | def forward(self, x: torch.Tensor, t: torch.Tensor, return_grad: bool = False): 127 | """ 128 | Apply the model to an input batch. 129 | 130 | :param x: an [N x C x ...] Tensor of inputs. 131 | :param t: a 1-D batch of timesteps. 132 | :return: an [N x C x ...] Tensor of outputs. 133 | """ 134 | 135 | # Skip ODE if no time difference. 136 | use_ode = t.item() != 0 137 | if use_ode: 138 | integration_time = torch.tensor([0, t.item()]).float().to(t.device) 139 | 140 | # Provide a dummy time embedding, since we are learning a position-paramterized ODE vector field. 141 | dummy_t = torch.zeros_like(t).to(t.device) 142 | emb = self.unet.time_embed(timestep_embedding(dummy_t, self.unet.model_channels)) 143 | 144 | h = x.type(self.unet.dtype) 145 | 146 | # Contraction path. 147 | h_skip_connection = [] 148 | for module in self.unet.input_blocks: 149 | h = module(h, emb) 150 | h_skip_connection.append(h) 151 | 152 | # Bottleneck 153 | h = self.unet.middle_block(h, emb) 154 | 155 | # ODE on bottleneck 156 | if use_ode: 157 | h = self.ode_list[-1](h, integration_time) 158 | 159 | # Expansion path. 160 | for module_idx, module in enumerate(self.unet.output_blocks): 161 | h_skip = h_skip_connection.pop(-1) 162 | 163 | # ODE over skip connections. 164 | if use_ode and self.ode_location in ['all_resolutions', 'all_connections']: 165 | if self.ode_location == 'all_connections': 166 | curr_ode_block = self.ode_list[::-1][module_idx + 1] 167 | else: 168 | resolution_idx = np.argwhere(np.unique(self.dim_list) == h_skip.shape[1]).item() 169 | curr_ode_block = self.ode_list[resolution_idx] 170 | h_skip = curr_ode_block(h_skip, integration_time) 171 | 172 | h = torch.cat([h, h_skip], dim=1) 173 | h = module(h, emb) 174 | 175 | # Output. 176 | h = h.type(x.dtype) 177 | output = self.unet.out(h) 178 | 179 | if return_grad: 180 | vec_field_gradients = 0 181 | for i in range(len(self.ode_list)): 182 | vec_field_gradients += self.ode_list[i].vec_grad() 183 | return output, vec_field_gradients.mean() / len(self.ode_list) 184 | else: 185 | return output 186 | 187 | def simsiam_project(self, x: torch.Tensor): 188 | # Provide a dummy time embedding, since we are learning a position-paramterized ODE vector field. 189 | dummy_t = torch.zeros(1).to(x.device) 190 | emb = self.unet.time_embed(timestep_embedding(dummy_t, self.unet.model_channels)) 191 | 192 | h = x.type(self.unet.dtype) 193 | # Contraction path. 194 | for module in self.unet.input_blocks: 195 | h = module(h, emb) 196 | # Bottleneck 197 | h = self.unet.middle_block(h, emb) 198 | 199 | h = h.reshape(h.shape[0], -1) 200 | 201 | z = self.projector(h) 202 | return z 203 | 204 | def simsiam_predict(self, z: torch.Tensor): 205 | p = self.predictor(z) 206 | return p 207 | 208 | @torch.no_grad() 209 | def return_embeddings(self, x: torch.Tensor, t: torch.Tensor): 210 | """ 211 | Store and return the embedding vectors. 212 | 213 | :param x: an [N x C x ...] Tensor of inputs. 214 | :param t: a 1-D batch of timesteps. 215 | """ 216 | embeddings_before = [] 217 | embeddings_after = [] 218 | 219 | # Skip ODE if no time difference. 220 | use_ode = t.item() != 0 221 | if use_ode: 222 | integration_time = torch.tensor([0, t.item()]).float().to(t.device) 223 | 224 | # Provide a dummy time embedding, since we are learning a position-paramterized ODE vector field. 
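        # Same treatment as in forward(): the UNet blocks always see t = 0, while the ODE blocks
        # integrate over [0, t]. embeddings_before / embeddings_after collect the bottleneck and
        # skip-connection features immediately before and after ODE integration.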
225 | dummy_t = torch.zeros_like(t).to(t.device) 226 | emb = self.unet.time_embed(timestep_embedding(dummy_t, self.unet.model_channels)) 227 | 228 | h = x.type(self.unet.dtype) 229 | 230 | # Contraction path. 231 | h_skip_connection = [] 232 | for module in self.unet.input_blocks: 233 | h = module(h, emb) 234 | h_skip_connection.append(h) 235 | 236 | # Bottleneck 237 | h = self.unet.middle_block(h, emb) 238 | 239 | # ODE on bottleneck 240 | embeddings_before.append(h) 241 | if use_ode: 242 | h = self.ode_list[-1](h, integration_time) 243 | embeddings_after.append(h) 244 | 245 | # Expansion path. 246 | for module_idx, module in enumerate(self.unet.output_blocks): 247 | h_skip = h_skip_connection.pop(-1) 248 | 249 | # ODE over skip connections. 250 | embeddings_before.append(h_skip) 251 | if use_ode and self.ode_location in ['all_resolutions', 'all_connections']: 252 | if self.ode_location == 'all_connections': 253 | curr_ode_block = self.ode_list[::-1][module_idx + 1] 254 | else: 255 | resolution_idx = np.argwhere(np.unique(self.dim_list) == h_skip.shape[1]).item() 256 | curr_ode_block = self.ode_list[resolution_idx] 257 | 258 | h_skip = curr_ode_block(h_skip, integration_time) 259 | embeddings_after.append(h_skip) 260 | 261 | h = torch.cat([h, h_skip], dim=1) 262 | h = module(h, emb) 263 | 264 | return embeddings_before, embeddings_after 265 | -------------------------------------------------------------------------------- /src/nn/off_the_shelf_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import os 4 | FILE_DIR = '/'.join(os.path.realpath(__file__).split('/')[:-3]) 5 | 6 | class VisionEncoder(): 7 | def __init__(self, 8 | pretrained_model: str, 9 | device: torch.device): 10 | if pretrained_model == 'resnet18': 11 | self.backbone = torchvision.models.resnet18(weights='DEFAULT') 12 | self.backbone.fc = torch.nn.Identity() 13 | elif pretrained_model == 'convnext_tiny': 14 | self.backbone = torchvision.models.convnext_tiny(weights='DEFAULT') 15 | self.backbone.classifier[-1] = torch.nn.Identity() 16 | elif pretrained_model == 'mobilenetv3_small': 17 | self.backbone = torchvision.models.mobilenet_v3_small(weights='DEFAULT') 18 | self.backbone.classifier[-1] = torch.nn.Identity() 19 | elif pretrained_model == 'retinal': 20 | self.backbone = torchvision.models.resnet50(weights=None) 21 | flair_model_weights = torch.load(FILE_DIR + '/external_src/FLAIR_retina/flair_resnet.pth', map_location=device) 22 | vision_model_weights = {} 23 | for key in flair_model_weights.keys(): 24 | if 'vision_model' in key: 25 | vision_model_weights[key.replace('vision_model.model.', '')] = flair_model_weights[key] 26 | self.backbone.load_state_dict(vision_model_weights, strict=False) 27 | self.backbone.fc = torch.nn.Identity() 28 | self.backbone.eval() 29 | self.backbone.to(device) 30 | self.device = device 31 | 32 | def embed(self, image: torch.Tensor) -> torch.Tensor: 33 | assert len(image.shape) == 4 34 | assert image.shape[1] in [1, 3] 35 | if image.shape[1] == 1: 36 | image = image.repeat(1, 3, 1, 1) 37 | 38 | latent_embedding = self.backbone(image.float().to(self.device)) 39 | return latent_embedding 40 | -------------------------------------------------------------------------------- /src/nn/scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright The PyTorch Lightning team. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # 16 | # Implemented by @ananyahjha93 17 | # also found at: https://github.com/Lightning-AI/lightning-bolts/blob/master/pl_bolts/optimizers/lr_scheduler.py 18 | import math 19 | import warnings 20 | from typing import List 21 | 22 | from torch.optim import Optimizer 23 | from torch.optim.lr_scheduler import _LRScheduler 24 | 25 | 26 | class LinearWarmupCosineAnnealingLR(_LRScheduler): 27 | """Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr 28 | and base_lr followed by a cosine annealing schedule between base_lr and eta_min. 29 | 30 | .. warning:: 31 | It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR` 32 | after each iteration as calling it after each epoch will keep the starting lr at 33 | warmup_start_lr for the first epoch which is 0 in most cases. 34 | 35 | .. warning:: 36 | passing epoch to :func:`.step()` is being deprecated and comes with an EPOCH_DEPRECATION_WARNING. 37 | It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of 38 | :func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing 39 | epoch param to :func:`.step()`, the user should call the :func:`.step()` function before calling 40 | train and validation methods. 41 | 42 | Example: 43 | >>> layer = nn.Linear(10, 1) 44 | >>> optimizer = Adam(layer.parameters(), lr=0.02) 45 | >>> scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40) 46 | >>> # 47 | >>> # the default case 48 | >>> for epoch in range(40): 49 | ... # train(...) 50 | ... # validate(...) 51 | ... scheduler.step() 52 | >>> # 53 | >>> # passing epoch param case 54 | >>> for epoch in range(40): 55 | ... scheduler.step(epoch) 56 | ... # train(...) 57 | ... # validate(...) 58 | """ 59 | 60 | def __init__( 61 | self, 62 | optimizer: Optimizer, 63 | warmup_epochs: int, 64 | max_epochs: int, 65 | warmup_start_lr: float = 0.0, 66 | eta_min: float = 0.0, 67 | last_epoch: int = -1, 68 | ) -> None: 69 | """ 70 | Args: 71 | optimizer (Optimizer): Wrapped optimizer. 72 | warmup_epochs (int): Maximum number of iterations for linear warmup 73 | max_epochs (int): Maximum number of iterations 74 | warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0. 75 | eta_min (float): Minimum learning rate. Default: 0. 76 | last_epoch (int): The index of last epoch. Default: -1. 
77 | """ 78 | self.warmup_epochs = warmup_epochs 79 | self.max_epochs = max_epochs 80 | self.warmup_start_lr = warmup_start_lr 81 | self.eta_min = eta_min 82 | 83 | super().__init__(optimizer, last_epoch) 84 | 85 | def get_lr(self) -> List[float]: 86 | """Compute learning rate using chainable form of the scheduler.""" 87 | if not self._get_lr_called_within_step: 88 | warnings.warn( 89 | "To get the last learning rate computed by the scheduler, " 90 | "please use `get_last_lr()`.", 91 | UserWarning, 92 | ) 93 | 94 | if self.last_epoch == self.warmup_epochs: 95 | return self.base_lrs 96 | if self.last_epoch == 0: 97 | return [self.warmup_start_lr] * len(self.base_lrs) 98 | if self.last_epoch < self.warmup_epochs: 99 | return [ 100 | group["lr"] + (base_lr - self.warmup_start_lr) / 101 | (self.warmup_epochs - 1) for base_lr, group in zip( 102 | self.base_lrs, self.optimizer.param_groups) 103 | ] 104 | if (self.last_epoch - 1 - self.max_epochs) % ( 105 | 2 * (self.max_epochs - self.warmup_epochs)) == 0: 106 | return [ 107 | group["lr"] + (base_lr - self.eta_min) * 108 | (1 - math.cos(math.pi / 109 | (self.max_epochs - self.warmup_epochs))) / 2 110 | for base_lr, group in zip(self.base_lrs, 111 | self.optimizer.param_groups) 112 | ] 113 | 114 | return [ 115 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / 116 | (self.max_epochs - self.warmup_epochs))) / 117 | (1 + math.cos(math.pi * 118 | (self.last_epoch - self.warmup_epochs - 1) / 119 | (self.max_epochs - self.warmup_epochs))) * 120 | (group["lr"] - self.eta_min) + self.eta_min 121 | for group in self.optimizer.param_groups 122 | ] 123 | 124 | def _get_closed_form_lr(self) -> List[float]: 125 | """Called when epoch is passed as a param to the `step` function of the scheduler.""" 126 | if self.last_epoch < self.warmup_epochs: 127 | return [ 128 | self.warmup_start_lr + self.last_epoch * 129 | (base_lr - self.warmup_start_lr) / 130 | max(1, self.warmup_epochs - 1) for base_lr in self.base_lrs 131 | ] 132 | 133 | return [ 134 | self.eta_min + 0.5 * (base_lr - self.eta_min) * 135 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / 136 | (self.max_epochs - self.warmup_epochs))) 137 | for base_lr in self.base_lrs 138 | ] -------------------------------------------------------------------------------- /src/nn/unet_i2sb.py: -------------------------------------------------------------------------------- 1 | from .base import BaseNetwork 2 | from .nn_utils import ConvBlock, ResConvBlock, timestep_embedding 3 | from .common_encoder import Encoder 4 | import os 5 | import sys 6 | import torch 7 | 8 | import_dir = '/'.join(os.path.realpath(__file__).split('/')[:-3]) 9 | sys.path.insert(0, import_dir + '/external_src/I2SB/') 10 | from guided_diffusion.script_util import create_model 11 | 12 | 13 | class I2SBUNet(BaseNetwork): 14 | 15 | def __init__(self, 16 | device: torch.device, 17 | in_channels: int, 18 | step_to_t, 19 | diffusion, 20 | **kwargs): 21 | ''' 22 | An UNet model for I2SB: Image-to-Image Schrodinger Bridge. 23 | 24 | Parameters 25 | ---------- 26 | device: torch.device 27 | in_channels: int 28 | Number of input image channels. 29 | step_to_t: List 30 | A mapping from step index to time t. 31 | diffusion: 32 | A Diffusion object. 33 | All other kwargs will be ignored. 
34 | ''' 35 | super().__init__() 36 | 37 | self.device = device 38 | self.step_to_t = step_to_t 39 | self.diffusion = diffusion 40 | 41 | # initialize model 42 | self.model = create_model( 43 | image_size=256, # TODO: currently hard coded 44 | in_channels=in_channels, 45 | num_channels=256, 46 | num_res_blocks=2, 47 | channel_mult='', 48 | learn_sigma=False, 49 | class_cond=False, 50 | use_checkpoint=False, 51 | attention_resolutions='32,16,8', 52 | num_heads=4, 53 | num_head_channels=64, 54 | num_heads_upsample=-1, 55 | use_scale_shift_norm=True, 56 | dropout=0.0, 57 | resblock_updown=True, 58 | use_fp16=False, 59 | use_new_attention_order=False) 60 | 61 | self.model.eval() 62 | self.model.to(self.device) 63 | 64 | def forward(self, x: torch.Tensor, t: torch.Tensor): 65 | 66 | assert t.dim()==1 and t.shape[0] == x.shape[0] 67 | return self.model(x, t) 68 | 69 | @torch.no_grad() 70 | def ddpm_sampling(self, x_start, steps): 71 | ''' 72 | Inference. 73 | ''' 74 | 75 | x_start = x_start.to(self.device) 76 | 77 | def pred_x_end_fn(x, step): 78 | step = torch.full((x.shape[0],), step, device=self.device, dtype=torch.long) 79 | t = self.step_to_t[step] 80 | out = self.model(x, t) 81 | return self.compute_pred_x0(step, x, out, clip_denoise=False) 82 | 83 | xs, x_end_pred = self.diffusion.ddpm_sampling( 84 | steps, pred_x_end_fn, x_start, mask=None, ot_ode=False, log_steps=None, verbose=False, 85 | ) 86 | 87 | return xs, x_end_pred 88 | 89 | def compute_pred_x0(self, step, xt, net_out, clip_denoise=False): 90 | """ Given network output, recover x0. This should be the inverse of I2SB Eq 12 """ 91 | std_fwd = self.diffusion.get_std_fwd(step, xdim=xt.shape[1:]) 92 | pred_x0 = xt - std_fwd * net_out 93 | if clip_denoise: pred_x0.clamp_(-1., 1.) 94 | return pred_x0 95 | -------------------------------------------------------------------------------- /src/nn/unet_ode.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import torch 5 | from .base import BaseNetwork 6 | from .nn_utils import ODEfunc, ODEBlock 7 | 8 | import_dir = '/'.join(os.path.realpath(__file__).split('/')[:-3]) 9 | sys.path.insert(0, import_dir + '/external_src/I2SB/') 10 | from guided_diffusion.script_util import create_model 11 | from guided_diffusion.unet import timestep_embedding 12 | 13 | 14 | class ODEUNet(BaseNetwork): 15 | 16 | def __init__(self, 17 | device: torch.device, 18 | in_channels: int, 19 | ode_location: str = 'all_connections', 20 | **kwargs): 21 | ''' 22 | A UNet model with ODE. 23 | 24 | ode_location: str 25 | If 'bottleneck', only perform ODE on the bottleneck layer. 26 | If 'all_resolutions', skip connections with the same resolution share the same ODE. 27 | If 'all_connections', perform ODE separately in all skip connections. 28 | 29 | Parameters 30 | ---------- 31 | device: torch.device 32 | in_channels: int 33 | Number of input image channels. 34 | All other kwargs will be ignored. 35 | ''' 36 | super().__init__() 37 | 38 | self.device = device 39 | self.ode_location = ode_location 40 | image_size = 256 # TODO: currently hard coded 41 | 42 | # NOTE: This model is smaller than the other counterparts, 43 | # because running NeuralODE require some significant GPU space. 
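        # (num_channels=64 and num_res_blocks=1 below, versus 256 and 2 in I2SBUNet.)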
44 | # initialize model 45 | self.unet = create_model( 46 | image_size=image_size, 47 | in_channels=in_channels, 48 | num_channels=64, 49 | num_res_blocks=1, 50 | channel_mult='', 51 | learn_sigma=False, 52 | class_cond=False, 53 | use_checkpoint=False, 54 | attention_resolutions='32,16,8', 55 | num_heads=4, 56 | num_head_channels=16, 57 | num_heads_upsample=-1, 58 | use_scale_shift_norm=True, 59 | dropout=0.0, 60 | resblock_updown=True, 61 | use_fp16=False, 62 | use_new_attention_order=False) 63 | 64 | # Record the channel dimensions by passing in a dummy tensor. 65 | self.dim_list = [] 66 | h_dummy = torch.zeros((1, 1, image_size, image_size)).type(self.unet.dtype) 67 | t_dummy = torch.zeros((1)).type(self.unet.dtype) 68 | emb = self.unet.time_embed(timestep_embedding(t_dummy, self.unet.model_channels)) 69 | for module in self.unet.input_blocks: 70 | h_dummy = module(h_dummy, emb) 71 | self.dim_list.append(h_dummy.shape[1]) 72 | h_dummy_bottleneck = self.unet.middle_block(h_dummy, emb) 73 | self.dim_list.append(h_dummy_bottleneck.shape[1]) 74 | 75 | # Construct the ODE modules. 76 | self.ode_list = torch.nn.ModuleList([]) 77 | if self.ode_location == 'bottleneck': 78 | self.ode_list.append(ODEBlock(ODEfunc(dim=h_dummy_bottleneck.shape[1]))) 79 | elif self.ode_location == 'all_resolutions': 80 | for dim in np.unique(self.dim_list): 81 | self.ode_list.append(ODEBlock(ODEfunc(dim=dim))) 82 | elif self.ode_location == 'all_connections': 83 | for dim in self.dim_list: 84 | self.ode_list.append(ODEBlock(ODEfunc(dim=dim))) 85 | 86 | self.unet.to(self.device) 87 | self.ode_list.to(self.device) 88 | 89 | def time_independent_parameters(self): 90 | ''' 91 | Parameters related to ODE. 92 | ''' 93 | return set(self.parameters()) - set(self.ode_list.parameters()) 94 | 95 | def freeze_time_independent(self): 96 | ''' 97 | Freeze paramters that are time-independent. 98 | ''' 99 | for p in self.time_independent_parameters(): 100 | p.requires_grad = False 101 | 102 | def forward(self, x: torch.Tensor, t: torch.Tensor, return_grad: bool = False): 103 | """ 104 | Apply the model to an input batch. 105 | 106 | :param x: an [N x C x ...] Tensor of inputs. 107 | :param t: a 1-D batch of timesteps. 108 | :return: an [N x C x ...] Tensor of outputs. 109 | """ 110 | 111 | # Skip ODE if no time difference. 112 | use_ode = t.item() != 0 113 | if use_ode: 114 | integration_time = torch.tensor([0, t.item()]).float().to(t.device) 115 | 116 | emb = self.unet.time_embed(timestep_embedding(t, self.unet.model_channels)) 117 | 118 | h = x.type(self.unet.dtype) 119 | 120 | # Contraction path. 121 | h_skip_connection = [] 122 | for module in self.unet.input_blocks: 123 | h = module(h, emb) 124 | h_skip_connection.append(h) 125 | 126 | # Bottleneck 127 | h = self.unet.middle_block(h, emb) 128 | 129 | # ODE on bottleneck 130 | if use_ode: 131 | h = self.ode_list[-1](h, integration_time) 132 | 133 | # Expansion path. 134 | for module_idx, module in enumerate(self.unet.output_blocks): 135 | h_skip = h_skip_connection.pop(-1) 136 | 137 | # ODE over skip connections. 
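            # self.ode_list stores one ODE per input block in encoder order, with the bottleneck
            # ODE appended last. Indexing the reversed list at module_idx + 1 skips the bottleneck
            # entry and pairs each output block with the ODE of the skip connection it consumes.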
138 | if use_ode and self.ode_location in ['all_resolutions', 'all_connections']: 139 | if self.ode_location == 'all_connections': 140 | curr_ode_block = self.ode_list[::-1][module_idx + 1] 141 | else: 142 | resolution_idx = np.argwhere(np.unique(self.dim_list) == h_skip.shape[1]).item() 143 | curr_ode_block = self.ode_list[resolution_idx] 144 | h_skip = curr_ode_block(h_skip, integration_time) 145 | 146 | h = torch.cat([h, h_skip], dim=1) 147 | h = module(h, emb) 148 | 149 | # Output. 150 | h = h.type(x.dtype) 151 | output = self.unet.out(h) 152 | 153 | if return_grad: 154 | vec_field_gradients = 0 155 | for i in range(len(self.ode_list)): 156 | vec_field_gradients += self.ode_list[i].vec_grad() 157 | return output, vec_field_gradients.mean() / len(self.ode_list) 158 | else: 159 | return output 160 | -------------------------------------------------------------------------------- /src/nn/unet_ode_simple.py: -------------------------------------------------------------------------------- 1 | from .base import BaseNetwork 2 | from .nn_utils import ConvBlock, ResConvBlock, ODEfunc, ODEBlock 3 | import torch 4 | 5 | 6 | class ODEUNetSimple(BaseNetwork): 7 | 8 | def __init__(self, 9 | device: torch.device = torch.device('cpu'), 10 | num_filters: int = 64, 11 | depth: int = 6, 12 | use_residual: bool = False, 13 | in_channels: int = 3, 14 | out_channels: int = 3, 15 | non_linearity: str = 'relu', 16 | use_bn: bool = True): 17 | ''' 18 | A UNet model with ODE. 19 | 20 | Parameters 21 | ---------- 22 | device: torch.device 23 | num_filters : int 24 | Number of convolutional filters. 25 | depth: int 26 | Depth of the model (encoding or decoding) 27 | use_residual: bool 28 | Whether to use residual connection within the same conv block 29 | in_channels: int 30 | Number of input image channels. 31 | out_channels: int 32 | Number of output image channels. 33 | non_linearity : string 34 | One of 'relu' and 'softplus' 35 | ''' 36 | super().__init__() 37 | 38 | self.device = device 39 | self.depth = depth 40 | self.use_residual = use_residual 41 | self.in_channels = in_channels 42 | self.non_linearity_str = non_linearity 43 | if self.non_linearity_str == 'relu': 44 | self.non_linearity = torch.nn.ReLU(inplace=True) 45 | elif self.non_linearity_str == 'softplus': 46 | self.non_linearity = torch.nn.Softplus() 47 | self.use_bn = use_bn 48 | 49 | n_f = num_filters # shorthand 50 | 51 | if self.use_residual: 52 | conv_block = ResConvBlock 53 | upconv_block = ResConvBlock 54 | else: 55 | conv_block = ConvBlock 56 | upconv_block = ConvBlock 57 | 58 | # This is for the contraction path. 59 | self.conv1x1 = torch.nn.Conv2d(in_channels, n_f, 1, 1) 60 | self.down_list = torch.nn.ModuleList([]) 61 | self.down_conn_list = torch.nn.ModuleList([]) 62 | for d in range(self.depth): 63 | self.down_list.append(conv_block(n_f * 2 ** d)) 64 | if self.use_bn: 65 | self.down_conn_list.append(torch.nn.Sequential( 66 | torch.nn.Conv2d(n_f * 2 ** d, n_f * 2 ** (d + 1), 1, 1), 67 | torch.nn.BatchNorm2d(n_f * 2 ** (d + 1)), 68 | )) 69 | else: 70 | self.down_conn_list.append(torch.nn.Conv2d(n_f * 2 ** d, n_f * 2 ** (d + 1), 1, 1)) 71 | self.bottleneck = conv_block(n_f * 2 ** self.depth) 72 | 73 | # This is for the expansion path. 
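        # Each up_conn 1x1 conv takes n_f * 3 * 2**d input channels: n_f * 2**(d+1) from the
        # upsampled decoder features plus n_f * 2**d from the (optionally ODE-evolved) skip.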
74 | self.ode_list = torch.nn.ModuleList([]) 75 | self.up_list = torch.nn.ModuleList([]) 76 | self.up_conn_list = torch.nn.ModuleList([]) 77 | for d in range(self.depth): 78 | self.ode_list.append(ODEBlock(ODEfunc(dim=n_f * 2 ** d))) 79 | self.up_list.append(upconv_block(n_f * 2 ** d)) 80 | if self.use_bn: 81 | self.up_conn_list.append(torch.nn.Sequential( 82 | torch.nn.Conv2d(n_f * 3 * 2 ** d, n_f * 2 ** d, 1, 1), 83 | torch.nn.BatchNorm2d(n_f * 2 ** d), 84 | )) 85 | else: 86 | self.up_conn_list.append(torch.nn.Conv2d(n_f * 3 * 2 ** d, n_f * 2 ** d, 1, 1)) 87 | self.ode_list = self.ode_list[::-1] 88 | self.up_list = self.up_list[::-1] 89 | self.up_conn_list = self.up_conn_list[::-1] 90 | 91 | self.ode_bottleneck = ODEBlock(ODEfunc(dim=n_f * 2 ** self.depth)) 92 | self.out_layer = torch.nn.Conv2d(n_f, out_channels, 1) 93 | 94 | def time_independent_parameters(self): 95 | ''' 96 | Parameters related to ODE. 97 | ''' 98 | return set(self.parameters()) - set(self.ode_list.parameters()) - set(self.ode_bottleneck.parameters()) 99 | 100 | def freeze_time_independent(self): 101 | ''' 102 | Freeze paramters that are time-independent. 103 | ''' 104 | for p in self.time_independent_parameters(): 105 | p.requires_grad = False 106 | 107 | def forward(self, x: torch.Tensor, t: torch.Tensor, return_grad: bool = False): 108 | ''' 109 | Time embedding through ODE. 110 | ''' 111 | 112 | assert x.shape[0] == 1 113 | 114 | # Skip ODE if no time difference. 115 | use_ode = t.item() != 0 116 | if use_ode: 117 | integration_time = torch.tensor([0, t.item()]).float().to(t.device) 118 | 119 | ###################### 120 | # Contraction path. 121 | ###################### 122 | x = self.non_linearity(self.conv1x1(x)) 123 | residual_list = [] 124 | for d in range(self.depth): 125 | x = self.down_list[d](x) 126 | residual_list.append(x.clone()) 127 | x = self.non_linearity(self.down_conn_list[d](x)) 128 | x = torch.nn.functional.interpolate(x, 129 | scale_factor=0.5, 130 | mode='bilinear', 131 | align_corners=False) 132 | x = self.bottleneck(x) 133 | 134 | ###################### 135 | # Expansion path. 
136 | ###################### 137 | if use_ode: 138 | x = self.ode_bottleneck(x, integration_time) 139 | 140 | for d in range(self.depth): 141 | x = torch.nn.functional.interpolate(x, 142 | scale_factor=2, 143 | mode='bilinear', 144 | align_corners=False) 145 | if use_ode: 146 | res = self.ode_list[d](residual_list.pop(-1), integration_time) 147 | else: 148 | res = residual_list.pop(-1) 149 | x = torch.cat([x, res], dim=1) 150 | x = self.non_linearity(self.up_conn_list[d](x)) 151 | x = self.up_list[d](x) 152 | 153 | output = self.out_layer(x) 154 | 155 | if return_grad: 156 | vec_field_gradients = 0 157 | for i in range(len(self.ode_list)): 158 | vec_field_gradients += self.ode_list[i].vec_grad() 159 | return output, vec_field_gradients.mean() / len(self.ode_list) 160 | else: 161 | return output 162 | -------------------------------------------------------------------------------- /src/nn/unet_ode_simple_position_parametrized.py: -------------------------------------------------------------------------------- 1 | from .base import BaseNetwork 2 | from .nn_utils import ConvBlock, ResConvBlock, PPODEfunc, ODEBlock 3 | import torch 4 | 5 | 6 | class PPODEUNetSimple(BaseNetwork): 7 | 8 | def __init__(self, 9 | device: torch.device = torch.device('cpu'), 10 | num_filters: int = 64, 11 | depth: int = 6, 12 | use_residual: bool = False, 13 | in_channels: int = 3, 14 | out_channels: int = 3, 15 | non_linearity: str = 'relu', 16 | use_bn: bool = True): 17 | ''' 18 | A UNet model with ODE. 19 | 20 | Parameters 21 | ---------- 22 | device: torch.device 23 | num_filters : int 24 | Number of convolutional filters. 25 | depth: int 26 | Depth of the model (encoding or decoding) 27 | use_residual: bool 28 | Whether to use residual connection within the same conv block 29 | in_channels: int 30 | Number of input image channels. 31 | out_channels: int 32 | Number of output image channels. 33 | non_linearity : string 34 | One of 'relu' and 'softplus' 35 | ''' 36 | super().__init__() 37 | 38 | self.device = device 39 | self.depth = depth 40 | self.use_residual = use_residual 41 | self.in_channels = in_channels 42 | self.non_linearity_str = non_linearity 43 | if self.non_linearity_str == 'relu': 44 | self.non_linearity = torch.nn.ReLU(inplace=True) 45 | elif self.non_linearity_str == 'softplus': 46 | self.non_linearity = torch.nn.Softplus() 47 | self.use_bn = use_bn 48 | 49 | n_f = num_filters # shorthand 50 | 51 | if self.use_residual: 52 | conv_block = ResConvBlock 53 | upconv_block = ResConvBlock 54 | else: 55 | conv_block = ConvBlock 56 | upconv_block = ConvBlock 57 | 58 | # This is for the contraction path. 59 | self.conv1x1 = torch.nn.Conv2d(in_channels, n_f, 1, 1) 60 | self.down_list = torch.nn.ModuleList([]) 61 | self.down_conn_list = torch.nn.ModuleList([]) 62 | for d in range(self.depth): 63 | self.down_list.append(conv_block(n_f * 2 ** d)) 64 | if self.use_bn: 65 | self.down_conn_list.append(torch.nn.Sequential( 66 | torch.nn.Conv2d(n_f * 2 ** d, n_f * 2 ** (d + 1), 1, 1), 67 | torch.nn.BatchNorm2d(n_f * 2 ** (d + 1)), 68 | )) 69 | else: 70 | self.down_conn_list.append(torch.nn.Conv2d(n_f * 2 ** d, n_f * 2 ** (d + 1), 1, 1)) 71 | self.bottleneck = conv_block(n_f * 2 ** self.depth) 72 | 73 | # This is for the expansion path. 
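        # Identical layout to ODEUNetSimple; the only difference is that the bottleneck and
        # skip-connection ODEs use the position-parametrized vector field PPODEfunc.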
74 | self.ode_list = torch.nn.ModuleList([]) 75 | self.up_list = torch.nn.ModuleList([]) 76 | self.up_conn_list = torch.nn.ModuleList([]) 77 | for d in range(self.depth): 78 | self.ode_list.append(ODEBlock(PPODEfunc(dim=n_f * 2 ** d))) 79 | self.up_list.append(upconv_block(n_f * 2 ** d)) 80 | if self.use_bn: 81 | self.up_conn_list.append(torch.nn.Sequential( 82 | torch.nn.Conv2d(n_f * 3 * 2 ** d, n_f * 2 ** d, 1, 1), 83 | torch.nn.BatchNorm2d(n_f * 2 ** d), 84 | )) 85 | else: 86 | self.up_conn_list.append(torch.nn.Conv2d(n_f * 3 * 2 ** d, n_f * 2 ** d, 1, 1)) 87 | self.ode_list = self.ode_list[::-1] 88 | self.up_list = self.up_list[::-1] 89 | self.up_conn_list = self.up_conn_list[::-1] 90 | 91 | self.ode_bottleneck = ODEBlock(PPODEfunc(dim=n_f * 2 ** self.depth)) 92 | self.out_layer = torch.nn.Conv2d(n_f, out_channels, 1) 93 | 94 | def time_independent_parameters(self): 95 | ''' 96 | Parameters related to ODE. 97 | ''' 98 | return set(self.parameters()) - set(self.ode_list.parameters()) - set(self.ode_bottleneck.parameters()) 99 | 100 | def freeze_time_independent(self): 101 | ''' 102 | Freeze paramters that are time-independent. 103 | ''' 104 | for p in self.time_independent_parameters(): 105 | p.requires_grad = False 106 | 107 | def forward(self, x: torch.Tensor, t: torch.Tensor, return_grad: bool = False): 108 | ''' 109 | Time embedding through ODE. 110 | ''' 111 | 112 | assert x.shape[0] == 1 113 | 114 | # Skip ODE if no time difference. 115 | use_ode = t.item() != 0 116 | if use_ode: 117 | integration_time = torch.tensor([0, t.item()]).float().to(t.device) 118 | 119 | ###################### 120 | # Contraction path. 121 | ###################### 122 | x = self.non_linearity(self.conv1x1(x)) 123 | residual_list = [] 124 | for d in range(self.depth): 125 | x = self.down_list[d](x) 126 | residual_list.append(x.clone()) 127 | x = self.non_linearity(self.down_conn_list[d](x)) 128 | x = torch.nn.functional.interpolate(x, 129 | scale_factor=0.5, 130 | mode='bilinear', 131 | align_corners=False) 132 | x = self.bottleneck(x) 133 | 134 | ###################### 135 | # Expansion path. 
136 | ###################### 137 | if use_ode: 138 | x = self.ode_bottleneck(x, integration_time) 139 | 140 | for d in range(self.depth): 141 | x = torch.nn.functional.interpolate(x, 142 | scale_factor=2, 143 | mode='bilinear', 144 | align_corners=False) 145 | if use_ode: 146 | res = self.ode_list[d](residual_list.pop(-1), integration_time) 147 | else: 148 | res = residual_list.pop(-1) 149 | x = torch.cat([x, res], dim=1) 150 | x = self.non_linearity(self.up_conn_list[d](x)) 151 | x = self.up_list[d](x) 152 | 153 | output = self.out_layer(x) 154 | 155 | if return_grad: 156 | vec_field_gradients = 0 157 | for i in range(len(self.ode_list)): 158 | vec_field_gradients += self.ode_list[i].vec_grad() 159 | return output, vec_field_gradients.mean() / len(self.ode_list) 160 | else: 161 | return output 162 | -------------------------------------------------------------------------------- /src/nn/unet_sode.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import torch 5 | from .base import BaseNetwork 6 | from .nn_utils import StaticODEfunc, Combine2Channels, LatentClassifier, SODEBlock 7 | 8 | import_dir = '/'.join(os.path.realpath(__file__).split('/')[:-3]) 9 | sys.path.insert(0, import_dir + '/external_src/I2SB/') 10 | from guided_diffusion.script_util import create_model 11 | from guided_diffusion.unet import timestep_embedding 12 | 13 | 14 | class SODEUNet(BaseNetwork): 15 | 16 | def __init__(self, 17 | device: torch.device, 18 | in_channels: int, 19 | **kwargs): 20 | ''' 21 | A UNet model with State-augmented Ordinary Differential Equation (SODE). 22 | 23 | Parameters 24 | ---------- 25 | device: torch.device 26 | in_channels: int 27 | Number of input image channels. 28 | All other kwargs will be ignored. 29 | ''' 30 | super().__init__() 31 | 32 | self.device = device 33 | image_size = 256 # TODO: currently hard coded 34 | 35 | # NOTE: This model is smaller than the other counterparts, 36 | # because running NeuralODE require some significant GPU space. 37 | # initialize model 38 | self.unet = create_model( 39 | image_size=image_size, 40 | in_channels=in_channels, 41 | num_channels=128, 42 | num_res_blocks=1, 43 | channel_mult='', 44 | learn_sigma=False, 45 | class_cond=False, 46 | use_checkpoint=False, 47 | attention_resolutions='32,16,8', 48 | num_heads=4, 49 | num_head_channels=64, 50 | num_heads_upsample=-1, 51 | use_scale_shift_norm=True, 52 | dropout=0.0, 53 | resblock_updown=True, 54 | use_fp16=False, 55 | use_new_attention_order=False) 56 | 57 | # Record the channel dimensions by passing in a dummy tensor. 58 | self.dim_list = [] 59 | h_dummy = torch.zeros((1, 1, image_size, image_size)).type(self.unet.dtype) 60 | t_dummy = torch.zeros((1)).type(self.unet.dtype) 61 | emb = self.unet.time_embed(timestep_embedding(t_dummy, self.unet.model_channels)) 62 | for module in self.unet.input_blocks: 63 | h_dummy = module(h_dummy, emb) 64 | if h_dummy.shape[1] not in self.dim_list: 65 | self.dim_list.append(h_dummy.shape[1]) 66 | h_dummy = self.unet.middle_block(h_dummy, emb) 67 | if h_dummy.shape[1] not in self.dim_list: 68 | self.dim_list.append(h_dummy.shape[1]) 69 | 70 | # Construct the SODE modules. 71 | self.sode_list = torch.nn.ModuleList([]) 72 | for dim in self.dim_list: 73 | # NOTE: not a typo. ODEfunc inside CDEBlock. 
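            # (Concretely: a StaticODEfunc wrapped in an SODEBlock together with Combine2Channels
            # and a LatentClassifier.)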
74 | self.sode_list.append(SODEBlock(StaticODEfunc(dim=dim), 75 | Combine2Channels(dim=dim), 76 | LatentClassifier(dim=dim, emb_channels=emb.shape[1]))) 77 | 78 | self.unet.to(self.device) 79 | self.sode_list.to(self.device) 80 | 81 | def time_independent_parameters(self): 82 | ''' 83 | Parameters related to CDE. 84 | ''' 85 | return set(self.parameters()) - set(self.sode_list.parameters()) 86 | 87 | def freeze_time_independent(self): 88 | ''' 89 | Freeze paramters that are time-independent. 90 | ''' 91 | for p in self.time_independent_parameters(): 92 | p.requires_grad = False 93 | 94 | def forward(self, x: torch.Tensor, t: torch.Tensor, return_grad: bool = False): 95 | """ 96 | Apply the model to an input batch. 97 | 98 | :param x: an [N x C x ...] Tensor of inputs. 99 | :param t: a 1-D batch of timesteps. 100 | :return: an [N x C x ...] Tensor of outputs. 101 | """ 102 | 103 | # Skip ODE if no time difference. 104 | use_ode = not (len(t) == 1 and t.item() == 0) 105 | if use_ode: 106 | integration_time = t.float() 107 | 108 | h_skip_connection = [] 109 | 110 | # Provide a dummy time embedding, since we are learning a static ODE vector field. 111 | dummy_t = torch.zeros_like(t).to(t.device) 112 | emb = self.unet.time_embed(timestep_embedding(dummy_t, self.unet.model_channels)) 113 | 114 | # State-augmented ODE actually needs the proper time embedding. 115 | emb_sode = self.unet.time_embed(timestep_embedding(t, self.unet.model_channels)) 116 | 117 | h = x.type(self.unet.dtype) 118 | for module in self.unet.input_blocks: 119 | h = module(h, emb) 120 | if use_ode: 121 | ode_idx = np.argwhere(np.array(self.dim_list) == h.shape[1]).item() 122 | h_skip = self.sode_list[ode_idx](h, emb_sode, integration_time) 123 | h_skip_connection.append(h_skip) 124 | else: 125 | h_skip_connection.append(h) 126 | 127 | h = self.unet.middle_block(h, emb) 128 | if use_ode: 129 | ode_idx = np.argwhere(np.array(self.dim_list) == h.shape[1]).item() 130 | h = self.sode_list[ode_idx](h, emb_sode, integration_time) 131 | 132 | for module in self.unet.output_blocks: 133 | # When modeling multiple observations, we only keep the first observation. 134 | if h.shape[0] > 1: 135 | h = h[0].unsqueeze(0) 136 | h = torch.cat([h, h_skip_connection.pop(-1)], dim=1) 137 | h = module(h, emb) 138 | 139 | h = h.type(x.dtype) 140 | 141 | if return_grad: 142 | vec_field_gradients = 0 143 | for i in range(len(self.sode_list)): 144 | vec_field_gradients += self.sode_list[i].vec_grad() 145 | return self.unet.out(h), vec_field_gradients.mean() / len(self.sode_list) 146 | else: 147 | return self.unet.out(h) 148 | 149 | -------------------------------------------------------------------------------- /src/nn/unet_t_emb.py: -------------------------------------------------------------------------------- 1 | from .base import BaseNetwork 2 | # from .nn_utils import ConvBlock, ResConvBlock, timestep_embedding 3 | # from .common_encoder import Encoder 4 | import os 5 | import sys 6 | import torch 7 | 8 | import_dir = '/'.join(os.path.realpath(__file__).split('/')[:-3]) 9 | sys.path.insert(0, import_dir + '/external_src/I2SB/') 10 | from guided_diffusion.script_util import create_model 11 | 12 | 13 | class T_UNet(BaseNetwork): 14 | 15 | def __init__(self, 16 | device: torch.device, 17 | in_channels: int, 18 | **kwargs): 19 | ''' 20 | An UNet model with time embedding. 21 | This is equivalent to Image-to-Image Schrodinger Bridge without distribution estimation and with only 1 step. 
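        Time conditioning is handled inside the guided_diffusion UNet through its sinusoidal
        timestep embedding; forward(x, t) expects t to be a 1-D tensor with one entry per
        sample in the batch.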
22 | 23 | Parameters 24 | ---------- 25 | device: torch.device 26 | in_channels: int 27 | Number of input image channels. 28 | All other kwargs will be ignored. 29 | ''' 30 | super().__init__() 31 | 32 | self.device = device 33 | 34 | # initialize model 35 | self.model = create_model( 36 | image_size=256, # TODO: currently hard coded 37 | in_channels=in_channels, 38 | num_channels=128, 39 | num_res_blocks=1, 40 | channel_mult='', 41 | learn_sigma=False, 42 | class_cond=False, 43 | use_checkpoint=False, 44 | attention_resolutions='32,16,8', 45 | num_heads=4, 46 | num_head_channels=16, 47 | num_heads_upsample=-1, 48 | use_scale_shift_norm=True, 49 | dropout=0.0, 50 | resblock_updown=True, 51 | use_fp16=False, 52 | use_new_attention_order=False) 53 | 54 | self.model.eval() 55 | self.model.to(self.device) 56 | 57 | def forward(self, x: torch.Tensor, t: torch.Tensor): 58 | 59 | assert t.dim()==1 and t.shape[0] == x.shape[0] 60 | return self.model(x, t) 61 | 62 | def freeze_time_independent(self): 63 | ''' 64 | Freeze paramters that are time-independent. 65 | ''' 66 | pass 67 | 68 | 69 | # class T_UNet(BaseNetwork): 70 | 71 | # def __init__(self, 72 | # device: torch.device = torch.device('cpu'), 73 | # num_filters: int = 16, 74 | # depth: int = 5, 75 | # use_residual: bool = False, 76 | # in_channels: int = 3, 77 | # out_channels: int = 3, 78 | # non_linearity: str = 'relu'): 79 | # ''' 80 | # An UNet model with time embedding. 81 | 82 | # Parameters 83 | # ---------- 84 | # device: torch.device 85 | # num_filters : int 86 | # Number of convolutional filters. 87 | # depth: int 88 | # Depth of the model (encoding or decoding) 89 | # use_residual: bool 90 | # Whether to use residual connection within the same conv block 91 | # in_channels: int 92 | # Number of input image channels. 93 | # out_channels: int 94 | # Number of output image channels. 95 | # non_linearity : string 96 | # One of 'relu' and 'softplus' 97 | # ''' 98 | # super().__init__() 99 | 100 | # self.device = device 101 | # self.depth = depth 102 | # self.use_residual = use_residual 103 | # self.in_channels = in_channels 104 | # self.non_linearity_str = non_linearity 105 | # if self.non_linearity_str == 'relu': 106 | # self.non_linearity = torch.nn.ReLU(inplace=True) 107 | # elif self.non_linearity_str == 'softplus': 108 | # self.non_linearity = torch.nn.Softplus() 109 | 110 | # n_f = num_filters # shorthand 111 | 112 | # if self.use_residual: 113 | # conv_block = ResConvBlock 114 | # upconv_block = ResConvBlock 115 | # else: 116 | # conv_block = ConvBlock 117 | # upconv_block = ConvBlock 118 | 119 | # # This is for the encoder. 120 | # self.encoder = Encoder(in_channels=in_channels, 121 | # n_f=n_f, 122 | # depth=self.depth, 123 | # conv_block=conv_block, 124 | # non_linearity=self.non_linearity) 125 | 126 | # # This is for the decoder. 
127 | # bottleneck_channel = n_f * 2 ** self.depth 128 | # self.t_emb_list = torch.nn.ModuleList([]) 129 | # self.up_list = torch.nn.ModuleList([]) 130 | # self.up_conn_list = torch.nn.ModuleList([]) 131 | # for d in range(self.depth): 132 | # self.t_emb_list.append(self._t_mlp_layer(bottleneck_channel, n_f * 2 ** d)) 133 | # self.up_conn_list.append(torch.nn.Conv2d(n_f * 3 * 2 ** d, n_f * 2 ** d, 1, 1)) 134 | # self.up_list.append(upconv_block(n_f * 2 ** d)) 135 | # self.t_emb_list = self.t_emb_list[::-1] 136 | # self.up_list = self.up_list[::-1] 137 | # self.up_conn_list = self.up_conn_list[::-1] 138 | 139 | # self.t_emb_bottleneck = self._t_mlp_layer(bottleneck_channel, bottleneck_channel) 140 | # self.t_emb_common = self._t_mlp_common(bottleneck_channel) 141 | # self.out_layer = torch.nn.Conv2d(n_f, out_channels, 1) 142 | 143 | # def _t_mlp_common(self, time_embed_dim: int): 144 | # ''' 145 | # Construct a block for time embedding. 146 | # ''' 147 | # return torch.nn.Sequential( 148 | # torch.torch.nn.Linear(time_embed_dim, time_embed_dim), 149 | # torch.nn.SiLU(), 150 | # torch.torch.nn.Linear(time_embed_dim, time_embed_dim), 151 | # ) 152 | 153 | # def _t_mlp_layer(self, time_embed_dim_common: int, time_embed_dim_layer: int): 154 | # ''' 155 | # Construct a block for time embedding. 156 | # ''' 157 | # return torch.nn.Sequential( 158 | # torch.nn.SiLU(), 159 | # torch.torch.nn.Linear(time_embed_dim_common, time_embed_dim_layer), 160 | # ) 161 | 162 | # def time_independent_parameters(self): 163 | # ''' 164 | # Parameters related to time embedding. 165 | # ''' 166 | # return set(self.parameters()) - set(self.t_emb_list.parameters()) - set(self.t_emb_bottleneck.parameters()) - set(self.t_emb_common.parameters()) 167 | 168 | # def freeze_time_independent(self): 169 | # ''' 170 | # Freeze paramters that are time-independent. 171 | # ''' 172 | # for p in self.time_independent_parameters(): 173 | # p.requires_grad = False 174 | 175 | # def forward(self, x: torch.Tensor, t: torch.Tensor): 176 | # ''' 177 | # Time embedding through sinusoidal embedding. 178 | # ''' 179 | 180 | # assert x.shape[0] == 1 181 | 182 | # x, residual_list = self.encoder(x) 183 | 184 | # # Time embedding through feature space addition. 
185 | # assert x.shape[0] == 1 186 | # t_emb_common = self.t_emb_common(timestep_embedding(t, dim=x.shape[1])) 187 | 188 | # t_emb = self.t_emb_bottleneck(t_emb_common) 189 | # t_emb = t_emb[:, :, None, None].repeat((1, 1, x.shape[2], x.shape[3])) 190 | # x = x + t_emb 191 | 192 | # for d in range(self.depth): 193 | # x = torch.nn.functional.interpolate(x, 194 | # scale_factor=2, 195 | # mode='bilinear', 196 | # align_corners=True) 197 | # res = residual_list.pop(-1) 198 | # t_emb = self.t_emb_list[d](t_emb_common) 199 | # t_emb = t_emb[:, :, None, None].repeat((1, 1, res.shape[2], res.shape[3])) 200 | # res = res + t_emb 201 | # x = torch.cat([x, res], dim=1) 202 | # x = self.non_linearity(self.up_conn_list[d](x)) 203 | # x = self.up_list[d](x) 204 | 205 | # output = self.out_layer(x) 206 | 207 | # return output 208 | 209 | 210 | -------------------------------------------------------------------------------- /src/plotting/demo_gradient_field.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from mpl_toolkits.mplot3d import Axes3D 4 | 5 | # Define the function whose gradient we want to plot 6 | def func(x, y, z): 7 | # return np.cos(x) * np.sin(y) * np.cos(z) 8 | return np.sin(x**2) + np.cos(z) * np.sin(y) 9 | # Define the range and step size for each axis 10 | x_range = np.linspace(-1, 1, 6) 11 | y_range = np.linspace(-1, 1, 6) 12 | z_range = np.linspace(-1, 1, 6) 13 | 14 | # Create a meshgrid from the ranges 15 | X, Y, Z = np.meshgrid(x_range, y_range, z_range) 16 | 17 | # Calculate the gradient of the function at each point 18 | grad_x, grad_y, grad_z = np.gradient(func(X, Y, Z), x_range, y_range, z_range) 19 | 20 | # Plotting 21 | fig = plt.figure() 22 | ax = fig.add_subplot(111, projection='3d') 23 | 24 | # Plot the vector field 25 | ax.quiver(X, Y, Z, grad_x, grad_y, grad_z, 26 | length=0.12, 27 | color='k', 28 | linewidths=1.0, 29 | arrow_length_ratio=0.5, 30 | normalize=False) 31 | 32 | # Remove grid. 33 | ax.grid(False) 34 | # Remove labels. 35 | ax.set_xticks([]) 36 | ax.set_yticks([]) 37 | ax.set_zticks([]) 38 | # Remove axes lines. 39 | ax.xaxis.line.set_color((1.0, 1.0, 1.0, 0.0)) 40 | ax.yaxis.line.set_color((1.0, 1.0, 1.0, 0.0)) 41 | ax.zaxis.line.set_color((1.0, 1.0, 1.0, 0.0)) 42 | # Change box color. 
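# The three background panes are filled with a translucent light-peach RGBA color.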
43 | ax.xaxis.set_pane_color((251/255, 229/255, 214/255, 0.4)) 44 | ax.yaxis.set_pane_color((251/255, 229/255, 214/255, 0.4)) 45 | ax.zaxis.set_pane_color((251/255, 229/255, 214/255, 0.4)) 46 | 47 | ax.set_title('Gradient Field') 48 | 49 | plt.savefig('demo_gradient_field.png') 50 | -------------------------------------------------------------------------------- /src/preprocessing/01_preprocess_brain_MS.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import numpy as np 4 | import cv2 5 | from tqdm import tqdm 6 | import nibabel as nib 7 | 8 | 9 | def normalize(mri_scan): 10 | assert np.min(mri_scan) == 0 11 | lower_bound = 0 12 | upper_bound = np.percentile(mri_scan, 99.90) 13 | mri_scan = np.clip(mri_scan, lower_bound, upper_bound) 14 | mri_scan = mri_scan / upper_bound 15 | return np.uint8(mri_scan * 255) 16 | 17 | 18 | if __name__ == '__main__': 19 | 20 | image_folder = '../../data/brain_MS/brain_MS_images/' 21 | out_shape = np.array((256, 256)) 22 | file_extension = '_flair_pp.nii' 23 | 24 | min_num_pixel_for_lesion = 250 25 | 26 | subject_dirs = sorted(glob(image_folder + '*')) 27 | 28 | for folder in tqdm(subject_dirs): 29 | subject_name = folder.split('/')[-1] 30 | 31 | scan_paths = sorted(glob(image_folder + subject_name + '/*%s' % file_extension)) 32 | mask_paths = sorted(glob(image_folder + subject_name + '/*mask1.nii')) 33 | 34 | assert len(scan_paths) == len(mask_paths) 35 | 36 | slices_with_ms = [] 37 | for m_pth in (mask_paths): 38 | # Only use the mask to find which slices have MS. 39 | mask_nii = nib.load(m_pth) 40 | mask = mask_nii.get_fdata() 41 | assert mask.shape == (181, 217, 181) 42 | slices_with_ms.extend(np.argwhere(mask.sum(axis=(0, 1)) > min_num_pixel_for_lesion).flatten()) 43 | 44 | slices_with_ms = np.unique(slices_with_ms) 45 | 46 | for s_pth, m_pth in zip(scan_paths, mask_paths): 47 | out_path_image = s_pth.replace('brain_MS_images', 'brain_MS_images_256x256') 48 | out_path_mask = s_pth.replace('brain_MS_images', 'brain_MS_masks_256x256') 49 | 50 | scan_nii = nib.load(s_pth) 51 | scan = scan_nii.get_fdata() 52 | scan = normalize(scan) 53 | assert scan.shape == (181, 217, 181) 54 | 55 | mask_nii = nib.load(m_pth) 56 | mask = mask_nii.get_fdata() 57 | mask = mask * 255 58 | assert mask.shape == (181, 217, 181) 59 | 60 | # Get the slices with MS. 
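        # Each selected slice is resized with bicubic interpolation to fit inside 256x256
        # (aspect ratio preserved) and zero-padded to a square; the lesion mask is resized
        # with nearest-neighbor interpolation so lesion labels are not blurred.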
61 | for i in slices_with_ms: 62 | out_fname_image = out_path_image.replace(subject_name, '%s_slice%s' % (subject_name, str(i).zfill(3))).replace(file_extension, '.png') 63 | out_fname_mask = out_path_mask.replace(subject_name, '%s_slice%s' % (subject_name, str(i).zfill(3))).replace(file_extension, '_MS_mask.png') 64 | os.makedirs(os.path.dirname(out_fname_image), exist_ok=True) 65 | os.makedirs(os.path.dirname(out_fname_mask), exist_ok=True) 66 | 67 | img = scan[:, :, i] 68 | msk = mask[:, :, i] 69 | reshape_ratio = img.shape[:2] / out_shape 70 | tmp_out_shape = np.int16(img.shape[:2] / reshape_ratio.max()) 71 | 72 | img = cv2.resize(img, 73 | dsize=tmp_out_shape[::-1], 74 | interpolation=cv2.INTER_CUBIC) 75 | msk = cv2.resize(msk, 76 | dsize=tmp_out_shape[::-1], 77 | interpolation=cv2.INTER_NEAREST) 78 | 79 | if img.shape[0] == img.shape[1]: 80 | final_img = img 81 | final_mask = msk 82 | 83 | elif img.shape[0] > img.shape[1]: 84 | final_img = np.zeros(out_shape, dtype=np.uint8) 85 | final_mask = np.zeros(out_shape, dtype=np.uint8) 86 | delta_size = final_img.shape[1] - img.shape[1] 87 | final_img[:, delta_size // 2 + 1 : final_img.shape[1] - delta_size // 2] = img 88 | final_mask[:, delta_size // 2 + 1 : final_img.shape[1] - delta_size // 2] = msk 89 | else: 90 | final_img = np.zeros(out_shape, dtype=np.uint8) 91 | final_mask = np.zeros(out_shape, dtype=np.uint8) 92 | delta_size = final_img.shape[0] - img.shape[0] 93 | final_img[delta_size // 2 + 1 : final_img.shape[0] - delta_size // 2, :] = img 94 | final_mask[delta_size // 2 + 1 : final_img.shape[0] - delta_size // 2, :] = msk 95 | 96 | cv2.imwrite(out_fname_image, final_img) 97 | cv2.imwrite(out_fname_mask, final_mask) 98 | -------------------------------------------------------------------------------- /src/preprocessing/01_preprocess_retina_AREDS.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import numpy as np 4 | import cv2 5 | from tqdm import tqdm 6 | 7 | if __name__ == '__main__': 8 | image_folder = '../../data/retina_areds/AREDS_2014_images/' 9 | 10 | image_paths = sorted(glob(image_folder + '*/*.jpg')) 11 | out_shape = np.array((512, 512)) 12 | 13 | for pth in tqdm(image_paths): 14 | out_path = pth.replace('AREDS_2014_images', 15 | 'AREDS_2014_images_512x512') 16 | os.makedirs(os.path.dirname(out_path), exist_ok=True) 17 | 18 | img = cv2.imread(pth) 19 | reshape_ratio = img.shape[:2] / out_shape 20 | tmp_out_shape = np.int16(img.shape[:2] / reshape_ratio.max()) 21 | 22 | img = cv2.resize(img, 23 | dsize=tmp_out_shape[::-1], 24 | interpolation=cv2.INTER_CUBIC) 25 | 26 | if img.shape[0] == img.shape[1]: 27 | final_img = img 28 | elif img.shape[0] > img.shape[1]: 29 | final_img = np.zeros((*out_shape, 3), dtype=np.uint8) 30 | delta_size = final_img.shape[1] - img.shape[1] 31 | final_img[:, delta_size // 2:final_img.shape[1] - 32 | delta_size // 2, :] = img 33 | else: 34 | final_img = np.zeros((*out_shape, 3), dtype=np.uint8) 35 | delta_size = final_img.shape[0] - img.shape[0] 36 | final_img[delta_size // 2:final_img.shape[0] - 37 | delta_size // 2, :, :] = img 38 | 39 | cv2.imwrite(out_path, final_img) 40 | -------------------------------------------------------------------------------- /src/preprocessing/01_preprocess_retina_UCSF.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import numpy as np 4 | import cv2 5 | from tqdm import tqdm 6 | from read_roi import 
read_roi_file 7 | from read_roi import read_roi_zip 8 | from PIL import Image, ImageDraw 9 | 10 | 11 | def polygon_to_mask(x_list, y_list, mask_shape): 12 | polygon_coords = [(x, y) for (x, y) in zip(x_list, y_list)] 13 | mask = Image.new('L', mask_shape[:2], 0) 14 | ImageDraw.Draw(mask).polygon(polygon_coords, outline=1, fill=1) 15 | mask = np.array(mask) 16 | return mask 17 | 18 | 19 | def resize_and_pad(img, out_shape): 20 | reshape_ratio = img.shape[:2] / out_shape[:2] 21 | tmp_out_shape = np.int16(img.shape[:2] / reshape_ratio.max()) 22 | 23 | img = cv2.resize(img, 24 | dsize=tmp_out_shape[::-1], 25 | interpolation=cv2.INTER_CUBIC) 26 | 27 | if img.shape[0] == img.shape[1]: 28 | final_img = img 29 | elif img.shape[0] > img.shape[1]: 30 | final_img = np.zeros((out_shape), dtype=np.uint8) 31 | delta_size = final_img.shape[1] - img.shape[1] 32 | final_img[:, delta_size // 2:final_img.shape[1] - 33 | delta_size // 2, ...] = img 34 | else: 35 | final_img = np.zeros((out_shape), dtype=np.uint8) 36 | delta_size = final_img.shape[0] - img.shape[0] 37 | final_img[delta_size // 2:final_img.shape[0] - 38 | delta_size // 2, :, ...] = img 39 | 40 | return final_img 41 | 42 | 43 | if __name__ == '__main__': 44 | image_folder = '../../data/retina_ucsf/Images/Raw Images/' 45 | roi_folder = '../../data/retina_ucsf/Images/Graded Images and ROI Files/LS_review/ImageJROI/All/' 46 | 47 | image_paths = sorted(glob(image_folder + '*/*.tif')) 48 | # Currently, shape has to be a square. 49 | out_shape_image = np.array((512, 512, 3)) 50 | out_shape_mask = np.array((512, 512)) 51 | 52 | for pth in tqdm(image_paths): 53 | 54 | folder_name = pth.split('/')[-2] 55 | grader = os.path.basename(pth).split(folder_name)[0] 56 | unique_identifier = '_'.join(os.path.basename(pth).split('_')[:3]) 57 | 58 | # Ignore the "graded" images. 59 | if len(grader) > 0: 60 | continue 61 | 62 | # Save the image 63 | out_path_image = pth.replace('Images/Raw Images/', 64 | 'UCSF_images_512x512/').replace('.tif', '.png') 65 | os.makedirs(os.path.dirname(out_path_image), exist_ok=True) 66 | 67 | img = cv2.imread(pth) 68 | raw_image_shape = img.shape 69 | 70 | img = resize_and_pad(img, out_shape_image) 71 | cv2.imwrite(out_path_image, img) 72 | 73 | # Find the corresponding ROI files. 74 | roi_files = sorted(glob(roi_folder + folder_name + '/' + unique_identifier + '*.roi')) 75 | roi_zip_files = sorted(glob(roi_folder + folder_name + '/' + unique_identifier + '*.zip')) 76 | 77 | # Convert the ROI files to masks. 78 | for roi_file in roi_files: 79 | _roi = read_roi_file(roi_file) 80 | 81 | assert len(_roi.keys()) == 1 82 | for k in _roi.keys(): 83 | _roi_item = _roi[k] 84 | 85 | roi_save_name = _roi_item['name'] 86 | 87 | assert unique_identifier in _roi_item['name'] 88 | 89 | mask = np.zeros(raw_image_shape[:2]) 90 | 91 | assert _roi_item['type'] in ['point', 'polygon'] 92 | if _roi_item['type'] == 'point': 93 | #TODO: Use a small square to represent the point. 94 | continue 95 | 96 | elif _roi_item['type'] == 'polygon': 97 | # Convert the polygon to a mask. 98 | mask = polygon_to_mask(_roi_item['x'], _roi_item['y'], raw_image_shape) 99 | mask = resize_and_pad(mask, out_shape_mask) 100 | 101 | mask = np.uint8(mask * 255) 102 | 103 | out_path_mask = pth.replace('Images/Raw Images/', 104 | 'UCSF_masks_512x512/').replace( 105 | os.path.basename(pth), roi_save_name + '_mask.png') 106 | os.makedirs(os.path.dirname(out_path_mask), exist_ok=True) 107 | cv2.imwrite(out_path_mask, mask) 108 | 109 | # Convert the ROI zip files to masks. 
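# [Editorial note, not part of the original script] Each ImageJ ROI .zip bundles several
# polygon ROIs for the same image. The loop below rasterizes every polygon at the raw
# image resolution with polygon_to_mask(), resizes/pads each one to 512x512 with
# resize_and_pad(), folds them together with a running np.logical_or, and saves the union
# as a single 0/255 mask PNG alongside the other masks.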
110 | for roi_zip_file in roi_zip_files: 111 | _roi = read_roi_zip(roi_zip_file) 112 | roi_save_name = os.path.basename(roi_zip_file).replace('.roi', '').replace('.zip', '') 113 | assert unique_identifier in roi_save_name 114 | 115 | assert len(_roi.keys()) > 1 116 | 117 | mask = None 118 | for k in _roi.keys(): 119 | _roi_item = _roi[k] 120 | 121 | assert _roi_item['type'] == 'polygon' 122 | 123 | curr_mask = polygon_to_mask(_roi_item['x'], _roi_item['y'], raw_image_shape) 124 | curr_mask = resize_and_pad(curr_mask, out_shape_mask) 125 | 126 | if mask is None: 127 | mask = curr_mask 128 | else: 129 | mask = np.logical_or(mask, curr_mask) 130 | 131 | mask = np.uint8(mask * 255) 132 | 133 | out_path_mask = pth.replace('Images/Raw Images/', 134 | 'UCSF_masks_512x512/').replace( 135 | os.path.basename(pth), roi_save_name + '_mask.png') 136 | os.makedirs(os.path.dirname(out_path_mask), exist_ok=True) 137 | cv2.imwrite(out_path_mask, mask) 138 | 139 | -------------------------------------------------------------------------------- /src/preprocessing/03_crop_retina_UCSF.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import cv2 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | 8 | def find_biggest_intersecting_square(object_mask: np.array) -> np.array: 9 | assert len(object_mask.shape) == 2 10 | assert object_mask.min() == 0 and object_mask.max() == 1 11 | 12 | # `square_size[i, j]` is the side length of the biggest square 13 | # whose bottom right corner is at [i, j]. 14 | square_size = np.int16(np.zeros_like(object_mask)) 15 | for i in range(square_size.shape[0]): 16 | for j in range(square_size.shape[1]): 17 | if i == 0 or j == 0: 18 | square_size[i, j] = np.int16(object_mask[i, j]) 19 | elif object_mask[i, j] == 0: 20 | square_size[i, j] = 0 21 | else: 22 | square_size[i, j] = 1 + np.min([ 23 | square_size[i-1, j], square_size[i, j-1], square_size[i-1, j-1] 24 | ]) 25 | 26 | # Return the top-left and the bottom-right squares, in case one is better than the other 27 | max_size = square_size.max() 28 | bottomright_arr = np.where(square_size == max_size) 29 | # top left square. 30 | h_br, w_br = bottomright_arr[0][0], bottomright_arr[1][0] 31 | h_tl, w_tl = h_br - max_size, w_br - max_size 32 | square_mask_tl = np.zeros_like(object_mask) 33 | square_mask_tl[h_tl + 1: h_br, w_tl + 1 : w_br] = 1 34 | 35 | # bottom right square. 36 | h_br, w_br = bottomright_arr[0][-1], bottomright_arr[1][-1] 37 | h_tl, w_tl = h_br - max_size, w_br - max_size 38 | square_mask_br = np.zeros_like(object_mask) 39 | square_mask_br[h_tl + 1: h_br, w_tl + 1 : w_br] = 1 40 | 41 | return square_mask_tl, square_mask_br 42 | 43 | def crop_longitudinal(output_shape, 44 | base_folder_source: str, 45 | base_mask_folder_source: str, 46 | base_fg_mask_folder_source: str, 47 | base_folder_target: str, 48 | base_mask_folder_target: str): 49 | ''' 50 | Crop the longitudinal images. 51 | 52 | These longitudinal images are already registered. 53 | 54 | We need to crop them such that the final images only have foreground regions (`fg_mask`), 55 | and they contain (hopefully) the entire geographic atropy mask (`mask`). 56 | ''' 57 | 58 | source_image_folders = sorted(glob(base_folder_source + '/*')) 59 | 60 | for folder in tqdm(source_image_folders): 61 | image_list = sorted(glob(folder + '/*.png')) 62 | if len(image_list) <= 2: 63 | # Can ignore this folder if there is fewer than 2 images. 
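# [Editorial note, not part of the original script] As written, the `pass` below is a
# no-op, so folders matching the condition above are still processed; if the intent is to
# skip them, `continue` would be the usual choice. Note also that the guard `<= 2`
# matches series with two or fewer images, whereas the comment above says "fewer than 2".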
64 | pass 65 | 66 | # Find the corresponding geographic atrophy masks and foreground masks. 67 | ga_mask_list = ['_'.join(img_path.split('_')[:-1]).replace(base_folder_source, base_mask_folder_source) + '_GA_mask.png' for img_path in image_list] 68 | fg_mask_list = [img_path.replace(base_folder_source, base_fg_mask_folder_source).replace('.png', '_foreground_mask.png') for img_path in image_list] 69 | 70 | all_ga_intersection = None 71 | all_fg_intersection = None 72 | 73 | for i in range(len(image_list)): 74 | ga_mask_path = ga_mask_list[i] 75 | fg_mask_path = fg_mask_list[i] 76 | 77 | ga_mask = cv2.imread(ga_mask_path, cv2.IMREAD_GRAYSCALE) 78 | fg_mask = cv2.imread(fg_mask_path, cv2.IMREAD_GRAYSCALE) 79 | 80 | assert ga_mask.max() == 255 81 | assert fg_mask.max() == 255 82 | 83 | if all_ga_intersection is None: 84 | all_ga_intersection = ga_mask > 128 85 | all_fg_intersection = fg_mask > 128 86 | else: 87 | all_ga_intersection = np.logical_and(all_ga_intersection, ga_mask > 128) 88 | all_fg_intersection = np.logical_and(all_fg_intersection, fg_mask > 128) 89 | 90 | # Check that `all_ga_intersection` is fully inside `all_fg_intersection`. 91 | assert (np.logical_or(all_ga_intersection, all_fg_intersection) == all_fg_intersection).all() 92 | 93 | # Find a square that is fully inside `all_fg_intersection`. 94 | assert all_fg_intersection.min() in [0, 1] 95 | 96 | if all_fg_intersection.min() == 1: 97 | common_fg_mask = all_fg_intersection 98 | else: 99 | common_fg_mask_tl, common_fg_mask_br = find_biggest_intersecting_square(all_fg_intersection) 100 | if np.logical_and(all_ga_intersection, common_fg_mask_tl).sum() > np.logical_and(all_ga_intersection, common_fg_mask_br).sum(): 101 | common_fg_mask = common_fg_mask_tl 102 | else: 103 | common_fg_mask = common_fg_mask_br 104 | 105 | mask_arr = np.where(common_fg_mask) 106 | h_tl, w_tl = mask_arr[0][0], mask_arr[1][0] 107 | h_br, w_br = mask_arr[0][-1], mask_arr[1][-1] 108 | 109 | for i, image_path in enumerate(image_list): 110 | image_path = image_list[i] 111 | ga_mask_path = ga_mask_list[i] 112 | 113 | image = cv2.imread(image_path, cv2.IMREAD_COLOR) 114 | ga_mask = cv2.imread(ga_mask_path, cv2.IMREAD_GRAYSCALE) 115 | 116 | image = image[h_tl : h_br, w_tl : w_br, :] 117 | ga_mask = ga_mask[h_tl : h_br, w_tl : w_br] 118 | 119 | image = cv2.resize(image, dsize=output_shape[::-1], interpolation=cv2.INTER_CUBIC) 120 | ga_mask = cv2.resize(ga_mask, dsize=output_shape[::-1], interpolation=cv2.INTER_CUBIC) 121 | ga_mask = np.uint8((ga_mask > 128) * 255) 122 | 123 | final_image_path = image_path.replace(base_folder_source, base_folder_target) 124 | final_ga_mask_path = ga_mask_path.replace(base_mask_folder_source, base_mask_folder_target) 125 | 126 | os.makedirs(os.path.dirname(final_image_path), exist_ok=True) 127 | cv2.imwrite(final_image_path, image) 128 | os.makedirs(os.path.dirname(final_ga_mask_path), exist_ok=True) 129 | cv2.imwrite(final_ga_mask_path, ga_mask) 130 | 131 | 132 | if __name__ == '__main__': 133 | base_folder_source = '../../data/retina_ucsf/UCSF_images_aligned_512x512/' 134 | base_mask_folder_source = '../../data/retina_ucsf/UCSF_masks_aligned_512x512/' 135 | base_fg_mask_folder_source = '../../data/retina_ucsf/UCSF_FG_masks_aligned_512x512/' 136 | base_folder_target = '../../data/retina_ucsf/UCSF_images_final_512x512/' 137 | base_mask_folder_target = '../../data/retina_ucsf/UCSF_masks_final_512x512/' 138 | 139 | output_shape = (512, 512) 140 | 141 | crop_longitudinal(output_shape=output_shape, 142 | 
base_folder_source=base_folder_source, 143 | base_mask_folder_source=base_mask_folder_source, 144 | base_fg_mask_folder_source=base_fg_mask_folder_source, 145 | base_folder_target=base_folder_target, 146 | base_mask_folder_target=base_mask_folder_target) 147 | -------------------------------------------------------------------------------- /src/preprocessing/03_generate_eye_mask_retina_AREDS.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from glob import glob 4 | import torch 5 | import cv2 6 | import numpy as np 7 | from torchvision import transforms 8 | from PIL import Image 9 | from tqdm import tqdm 10 | from typing import Tuple, List, Iterable 11 | from segment_anything import SamPredictor, sam_model_registry 12 | 13 | 14 | class SAM_Segmenter(object): 15 | def __init__(self, device: torch.device, checkpoint: str): 16 | ''' 17 | Initialize a Segment Anything Model (SAM) model. 18 | ''' 19 | sam_model = sam_model_registry["default"](checkpoint=checkpoint).to(device) 20 | self.predictor = SamPredictor(sam_model) 21 | 22 | def segment(self, image: np.array): 23 | ''' 24 | Run Segment Anything Model (SAM) using a box prompt. 25 | ''' 26 | # Estimate the prompt box. 27 | image_green = image[:, :, 1] 28 | x_array, y_array = np.where(image_green > np.percentile(image_green, 50)) 29 | prompt_box = np.array([x_array.min(), y_array.min(), x_array.max(), y_array.max()]) 30 | 31 | self.predictor.set_image(image) 32 | segments, _, _ = self.predictor.predict(box=prompt_box) 33 | segments = segments.transpose(1, 2, 0) 34 | 35 | mask_idx = segments.sum(axis=(0, 1)).argmax() 36 | mask = segments[..., mask_idx] 37 | return mask 38 | 39 | 40 | def crop_longitudinal(base_folder_source: str, base_folder_target: str): 41 | ''' 42 | Crop the longitudinal images. 43 | These images are already spatially registered. 44 | We only need to crop them such that only the overlapping region of the images remain. 45 | 46 | For the case in `data/retina_areds/AREDS_2014_images_aligned_512x512/`, 47 | each folder represents a series of longitudinal images to be cropped. 48 | ''' 49 | 50 | # SAM config 51 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 52 | sam_segmenter = SAM_Segmenter(device=device, checkpoint='../../external_src/SAM/sam_vit_h_4b8939.pth') 53 | 54 | source_image_folders = sorted(glob(base_folder_source + '/*')) 55 | 56 | for folder in tqdm(source_image_folders): 57 | subject_folder_name = os.path.basename(folder) 58 | os.makedirs(base_folder_target + '/' + subject_folder_name + '/', exist_ok=True) 59 | 60 | image_list = sorted(glob(folder + '/*.jpg')) 61 | assert len(image_list) > 0 62 | 63 | # Generate the eye mask and save it to the target folder. 
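# [Editorial note, not part of the original script] The loop below runs SAM once per
# registered image: SAM_Segmenter.segment() estimates a box prompt from the pixels above
# the median of the green channel, asks SamPredictor for its candidate masks, keeps the
# one covering the largest area, and the result is saved next to the images as
# *_eye_mask.jpg with values 0/255.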
64 | for image_path in image_list: 65 | image_name = os.path.basename(image_path) 66 | image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB) 67 | mask = sam_segmenter.segment(image) 68 | mask = np.uint8(mask) * 255 69 | mask_name = image_name.replace('.jpg', '_eye_mask.jpg') 70 | cv2.imwrite(base_folder_target + '/' + subject_folder_name + '/' + mask_name, mask) 71 | 72 | if __name__ == '__main__': 73 | base_folder_source = '../../data/retina_areds/AREDS_2014_images_aligned_512x512/' 74 | base_folder_target = '../../data/retina_areds/AREDS_2014_eye_masks_aligned_512x512/' 75 | 76 | crop_longitudinal(base_folder_source=base_folder_source, 77 | base_folder_target=base_folder_target) 78 | -------------------------------------------------------------------------------- /src/preprocessing/deprecated_03_crop_retina_AREDS.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from glob import glob 4 | import torch 5 | import cv2 6 | import numpy as np 7 | from torchvision import transforms 8 | from PIL import Image 9 | from tqdm import tqdm 10 | from typing import Tuple, List, Iterable 11 | from segment_anything import SamPredictor, sam_model_registry 12 | 13 | 14 | class SAM_Segmenter(object): 15 | def __init__(self, device: torch.device, checkpoint: str): 16 | ''' 17 | Initialize a Segment Anything Model (SAM) model. 18 | ''' 19 | sam_model = sam_model_registry["default"](checkpoint=checkpoint).to(device) 20 | self.predictor = SamPredictor(sam_model) 21 | 22 | def segment(self, image: np.array): 23 | ''' 24 | Run Segment Anything Model (SAM) using a box prompt. 25 | ''' 26 | # Estimate the prompt box. 27 | image_gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) 28 | x_array, y_array = np.where(image_gray > np.percentile(image_gray, 75)) 29 | prompt_box = np.array([x_array.min(), y_array.min(), x_array.max(), y_array.max()]) 30 | 31 | self.predictor.set_image(image) 32 | segments, _, _ = self.predictor.predict(box=prompt_box) 33 | segments = segments.transpose(1, 2, 0) 34 | 35 | mask_idx = segments.sum(axis=(0, 1)).argmax() 36 | mask = segments[..., mask_idx] 37 | return mask 38 | 39 | 40 | def crop_longitudinal(base_folder_source: str, base_folder_target: str): 41 | ''' 42 | Crop the longitudinal images. 43 | These images are already spatially registered. 44 | We only need to crop them such that only the overlapping region of the images remain. 45 | 46 | For the case in `data/retina_areds/AREDS_2014_images_aligned_512x512/`, 47 | each folder represents a series of longitudinal images to be cropped. 48 | ''' 49 | 50 | # SuperRetina config 51 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 52 | sam_segmenter = SAM_Segmenter(device=device, checkpoint='../../external_src/SAM/sam_vit_h_4b8939.pth') 53 | 54 | source_image_folders = sorted(glob(base_folder_source + '/*')) 55 | 56 | for folder in tqdm(source_image_folders): 57 | subject_folder_name = os.path.basename(folder) 58 | os.makedirs(base_folder_target + '/' + subject_folder_name + '/', exist_ok=True) 59 | 60 | image_list = sorted(glob(folder + '/*.jpg')) 61 | assert len(image_list) > 0 62 | 63 | if len(image_list) == 1: 64 | # Directly copy the image over if there is only 1 image. 65 | image_path = image_list[0] 66 | image_name = os.path.basename(image_path) 67 | image = cv2.imread(image_path, cv2.IMREAD_COLOR) 68 | cv2.imwrite(base_folder_target + '/' + subject_folder_name + '/' + image_name, image) 69 | 70 | # Build a common mask for the images. 
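# [Editorial note, not part of the original script] The block below intersects the
# per-image SAM eye masks with np.logical_and so that only pixels covered in every visit
# survive, then zeroes everything outside that common region in each image before saving.
# Note that single-image folders are copied above but still fall through to this block,
# where the "common" mask is simply that image's own mask.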
71 | common_mask = None 72 | for image_path in image_list: 73 | image_name = os.path.basename(image_path) 74 | image = cv2.imread(image_path, cv2.IMREAD_COLOR) 75 | mask = sam_segmenter.segment(image) 76 | if common_mask is None: 77 | common_mask = mask 78 | else: 79 | common_mask = np.logical_and(common_mask, mask) 80 | 81 | # Apply the common mask on all images. 82 | assert common_mask is not None 83 | for image_path in image_list: 84 | image_name = os.path.basename(image_path) 85 | image = cv2.imread(image_path, cv2.IMREAD_COLOR) 86 | image[~common_mask] = 0 87 | cv2.imwrite(base_folder_target + '/' + subject_folder_name + '/' + image_name, image) 88 | 89 | 90 | if __name__ == '__main__': 91 | base_folder_source = '../../data/retina_areds/AREDS_2014_images_aligned_512x512/' 92 | base_folder_target = '../../data/retina_areds/AREDS_2014_images_aligned_cropped_512x512/' 93 | 94 | crop_longitudinal(base_folder_source=base_folder_source, 95 | base_folder_target=base_folder_target) 96 | -------------------------------------------------------------------------------- /src/preprocessing/synthesize_dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import numpy as np 4 | from tqdm import tqdm 5 | from typing import Tuple 6 | from matplotlib import colormaps 7 | 8 | 9 | def _generate_longitudinal(image_shape: Tuple[int] = (256, 256), 10 | num_images: int = 10, 11 | initial_radius: Tuple[int] = (18, 16), 12 | final_radius: Tuple[int] = (36, 48), 13 | random_seed: int = None): 14 | ''' 15 | Generate longitudinal images of an big rectangle containing a small ellipse. 16 | The big square (eye) remains unchanged, while the small ellipse (geographic atrophy) grows. 17 | ''' 18 | 19 | images = [np.zeros((*image_shape, 3), dtype=np.uint8) for _ in range(num_images)] 20 | 21 | if random_seed is not None: 22 | np.random.seed(random_seed) 23 | 24 | color_rectangle = np.uint8(np.array(colormaps['copper'](np.random.choice(range(colormaps['copper'].N)))[:3]) * 255) 25 | color_ellipse = np.uint8(np.array(colormaps['Wistia'](np.random.choice(range(colormaps['Wistia'].N)))[:3]) * 255) 26 | 27 | # First generate the big rectangle. 28 | square_tl = [int(np.random.uniform(1/8*image_shape[i], 1/4*image_shape[i]/4)) for i in range(2)] 29 | square_br = [int(np.random.uniform(3/4*image_shape[i], 7/8*image_shape[i])) for i in range(2)] 30 | square_centroid = np.mean([square_tl, square_br], axis=0) 31 | for image in images: 32 | image[square_tl[0]:square_br[0], 33 | square_tl[1]:square_br[1], :] = color_rectangle 34 | 35 | # Then generate the increasingly bigger ellipses. 36 | ellipse_centroid = [int(np.random.uniform(square_tl[i]+final_radius[i], 37 | square_br[i]-final_radius[i])) for i in range(2)] 38 | radius_x_list = np.linspace(initial_radius[0], final_radius[0], num_images) 39 | radius_y_list = np.linspace(initial_radius[1], final_radius[1], num_images) 40 | for i, image in enumerate(images): 41 | x_arr = np.linspace(0, image_shape[0]-1, image_shape[0])[:, None] 42 | y_arr = np.linspace(0, image_shape[1]-1, image_shape[1]) 43 | ellipse_mask = ((x_arr-ellipse_centroid[0])/radius_x_list[i])**2 + \ 44 | ((y_arr-ellipse_centroid[1])/radius_y_list[i])**2 <= 1 45 | image[ellipse_mask, :] = color_ellipse 46 | 47 | # OpenCV color channel convertion for saving. 
48 | images = [cv2.cvtColor(image, cv2.COLOR_RGB2BGR) for image in images] 49 | return images, square_centroid 50 | 51 | 52 | def synthesize_dataset(save_folder: str = '../../data/synthesized/', num_subjects: int = 200): 53 | ''' 54 | Synthesize 4 datasets. 55 | 1. The first dataset has no spatial variation. It has pixel-level alignment temporally. 56 | 2. The second dataset has a predictable translation factor. 57 | 3. The third dataset has a predictable rotation factor. 58 | 4. The fourth dataset is irregular. At each time point, we randomly pick an image from 1/2/3 at that time point. 59 | ''' 60 | 61 | for subject_idx in tqdm(range(num_subjects)): 62 | images, square_centroid = _generate_longitudinal(random_seed=subject_idx) 63 | images_trans, images_rot = [], [] 64 | 65 | # Do nothing. 66 | dataset = 'base' 67 | os.makedirs(save_folder + dataset + '/subject_%s' % str(subject_idx).zfill(5), exist_ok=True) 68 | for time_idx, img in enumerate(images): 69 | cv2.imwrite(save_folder + dataset + '/subject_%s' % str(subject_idx).zfill(5) + '/subject_%s_time_%s.png' % ( 70 | str(subject_idx).zfill(5), str(time_idx).zfill(3)), img) 71 | 72 | # Add translation. 73 | dataset = 'translation' 74 | os.makedirs(save_folder + dataset + '/subject_%s' % str(subject_idx).zfill(5), exist_ok=True) 75 | max_trans_x, max_trans_y = 32, 32 76 | for time_idx, img in enumerate(images): 77 | translation_x = int(2 * max_trans_x / (len(images) - 1) * time_idx - max_trans_x) 78 | translation_y = int(max_trans_y * np.cos(time_idx / len(images) * 2*np.pi)) 79 | translation_matrix = np.float32([[1, 0, translation_x], [0, 1, translation_y]]) 80 | img_trans = cv2.warpAffine(img, translation_matrix, (img.shape[0], img.shape[1])) 81 | cv2.imwrite(save_folder + dataset + '/subject_%s' % str(subject_idx).zfill(5) + '/subject_%s_time_%s.png' % ( 82 | str(subject_idx).zfill(5), str(time_idx).zfill(3)), img_trans) 83 | images_trans.append(img_trans) 84 | 85 | # Add rotation. 86 | dataset = 'rotation' 87 | os.makedirs(save_folder + dataset + '/subject_%s' % str(subject_idx).zfill(5), exist_ok=True) 88 | for time_idx, img in enumerate(images): 89 | angle = np.linspace(0, 180, len(images))[time_idx] 90 | rotation_matrix = cv2.getRotationMatrix2D((square_centroid[1], square_centroid[0]), angle, 1) 91 | img_rot = cv2.warpAffine(img, rotation_matrix, (img.shape[0], img.shape[1])) 92 | cv2.imwrite(save_folder + dataset + '/subject_%s' % str(subject_idx).zfill(5) + '/subject_%s_time_%s.png' % ( 93 | str(subject_idx).zfill(5), str(time_idx).zfill(3)), img_rot) 94 | images_rot.append(img_rot) 95 | 96 | # Randomly pick from previous lists. 
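# [Editorial note, not part of the original script] The 'mixing' block below builds the
# fourth dataset by drawing, independently at each time point, one frame from the base /
# translation / rotation versions generated above; the source list is picked by name and
# dereferenced with eval(). A sketch of an eval-free equivalent (same behavior, assuming
# the same three lists):
#     frames = {'images': images, 'images_trans': images_trans, 'images_rot': images_rot}
#     chosen = frames[np.random.choice(list(frames.keys()))]
#     cv2.imwrite(..., chosen[time_idx])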
97 | dataset = 'mixing' 98 | os.makedirs(save_folder + dataset + '/subject_%s' % str(subject_idx).zfill(5), exist_ok=True) 99 | for time_idx in range(len(images)): 100 | chosen_list = np.random.choice(['images', 'images_trans', 'images_rot']) 101 | cv2.imwrite(save_folder + dataset + '/subject_%s' % str(subject_idx).zfill(5) + '/subject_%s_time_%s.png' % ( 102 | str(subject_idx).zfill(5), str(time_idx).zfill(3)), eval(chosen_list)[time_idx]) 103 | return 104 | 105 | 106 | if __name__ == '__main__': 107 | synthesize_dataset() -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrishnaswamyLab/ImageFlowNet/61412a9a014776e653de5bd713b6b817852be3d8/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/attribute_hashmap.py: -------------------------------------------------------------------------------- 1 | class AttributeHashmap(dict): 2 | """ 3 | A Specialized hashmap such that: 4 | hash_map = AttributeHashmap(dictionary) 5 | `hash_map.key` is equivalent to `dictionary[key]` 6 | """ 7 | def __init__(self, *args, **kwargs): 8 | super(AttributeHashmap, self).__init__(*args, **kwargs) 9 | self.__dict__ = self 10 | -------------------------------------------------------------------------------- /src/utils/early_stop.py: -------------------------------------------------------------------------------- 1 | class EarlyStopping(object): 2 | """ 3 | Early Stopping pytorch implementation from Stefano Nardo 4 | https://gist.github.com/stefanonardo/693d96ceb2f531fa05db530f3e21517d 5 | """ 6 | 7 | def __init__(self, mode='min', min_delta=0, patience=8, percentage=False): 8 | 9 | self.mode = mode 10 | self.min_delta = min_delta 11 | self.patience = patience 12 | self.best = None 13 | self.num_bad_epochs = 0 14 | self.is_better = None 15 | self._init_is_better(mode, min_delta, percentage) 16 | 17 | if patience == 0: 18 | self.is_better = lambda a, b: True 19 | self.step = lambda a: False 20 | 21 | def step(self, metrics): 22 | 23 | if self.best is None: 24 | self.best = metrics 25 | return False 26 | 27 | # slight modification from source, to handle non-tensor metrics. If NAN, return True. 
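# [Editorial note, not part of the original script] `metrics != metrics` below is the
# standard NaN self-inequality test (NaN is the only float that is not equal to itself),
# so a NaN validation metric triggers early stopping immediately without needing
# math.isnan or torch.isnan.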
28 | if metrics != metrics: 29 | return True 30 | 31 | if self.is_better(metrics, self.best): 32 | self.num_bad_epochs = 0 33 | self.best = metrics 34 | else: 35 | self.num_bad_epochs += 1 36 | 37 | if self.num_bad_epochs >= self.patience: 38 | return True 39 | 40 | return False 41 | 42 | def _init_is_better(self, mode, min_delta, percentage): 43 | 44 | if mode not in {'min', 'max'}: 45 | raise ValueError('mode ' + mode + ' is unknown!') 46 | if not percentage: 47 | if mode == 'min': 48 | self.is_better = lambda a, best: a < best - min_delta 49 | if mode == 'max': 50 | self.is_better = lambda a, best: a > best + min_delta 51 | else: 52 | if mode == 'min': 53 | self.is_better = lambda a, best: a < best - ( 54 | best * min_delta / 100) 55 | if mode == 'max': 56 | self.is_better = lambda a, best: a > best + ( 57 | best * min_delta / 100) 58 | -------------------------------------------------------------------------------- /src/utils/log_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from matplotlib import pyplot as plt 5 | 6 | 7 | def log(s, filepath=None, to_console=True): 8 | ''' 9 | Logs a string to either file or console 10 | 11 | Arg(s): 12 | s : str 13 | string to log 14 | filepath 15 | output filepath for logging 16 | to_console : bool 17 | log to console 18 | ''' 19 | 20 | if to_console: 21 | print(s) 22 | 23 | if filepath is not None: 24 | if not os.path.isdir(os.path.dirname(filepath)): 25 | os.makedirs(os.path.dirname(filepath)) 26 | with open(filepath, 'w+') as o: 27 | o.write(s + '\n') 28 | else: 29 | with open(filepath, 'a+') as o: 30 | o.write(s + '\n') 31 | 32 | def colorize(T, colormap='magma'): 33 | ''' 34 | Colorizes a 1-channel tensor with matplotlib colormaps 35 | 36 | Arg(s): 37 | T : torch.Tensor[float32] 38 | 1-channel tensor 39 | colormap : str 40 | matplotlib colormap 41 | ''' 42 | 43 | cm = plt.cm.get_cmap(colormap) 44 | shape = T.shape 45 | 46 | # Convert to numpy array and transpose 47 | if shape[0] > 1: 48 | T = np.squeeze(np.transpose(T.cpu().numpy(), (0, 2, 3, 1))) 49 | else: 50 | T = np.squeeze(np.transpose(T.cpu().numpy(), (0, 2, 3, 1)), axis=-1) 51 | 52 | # Colorize using colormap and transpose back 53 | color = np.concatenate([ 54 | np.expand_dims(cm(T[n, ...])[..., 0:3], 0) for n in range(T.shape[0])], 55 | axis=0) 56 | color = np.transpose(color, (0, 3, 1, 2)) 57 | 58 | # Convert back to tensor 59 | return torch.from_numpy(color.astype(np.float32)) 60 | -------------------------------------------------------------------------------- /src/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage.metrics import hausdorff_distance, structural_similarity 3 | 4 | 5 | def psnr(image1, image2, max_value=2): 6 | ''' 7 | Assuming data range is [-1, 1]. 8 | ''' 9 | assert image1.shape == image2.shape 10 | 11 | eps = 1e-12 12 | 13 | mse = np.mean((image1 - image2)**2) 14 | return 20 * np.log10(max_value / np.sqrt(mse + eps)) 15 | 16 | 17 | def ssim(image1: np.array, image2: np.array, data_range=2, **kwargs) -> float: 18 | ''' 19 | Please make sure the data are provided in [H, W, C] shape. 20 | 21 | Assuming data range is [-1, 1] --> `data_range` = 1. 
22 | ''' 23 | assert image1.shape == image2.shape 24 | 25 | H, W = image1.shape[:2] 26 | 27 | if min(H, W) < 7: 28 | win_size = min(H, W) 29 | if win_size % 2 == 0: 30 | win_size -= 1 31 | else: 32 | win_size = None 33 | 34 | if len(image1.shape) == 3: 35 | channel_axis = -1 36 | else: 37 | channel_axis = None 38 | 39 | return structural_similarity(image1, 40 | image2, 41 | data_range=data_range, 42 | channel_axis=channel_axis, 43 | win_size=win_size, 44 | **kwargs) 45 | 46 | 47 | def dice_coeff(label_pred: np.array, label_true: np.array) -> float: 48 | epsilon = 1e-12 49 | intersection = np.logical_and(label_pred, label_true).sum() 50 | dice = (2 * intersection + epsilon) / (label_pred.sum() + 51 | label_true.sum() + epsilon) 52 | return dice 53 | 54 | 55 | def hausdorff(label_pred: np.array, label_true: np.array) -> float: 56 | if np.sum(label_pred) == 0 and np.sum(label_true) == 0: 57 | # If both of `label_pred` or `label_true` are all zeros, 58 | # return 0. 59 | return 0 60 | 61 | elif np.sum(label_pred) == 0 or np.sum(label_true) == 0: 62 | # If one of `label_pred` or `label_true` is all zeros, 63 | # but not both, 64 | # return the max Euclidean distance. 65 | H, W = label_true.shape[:2] 66 | return np.sqrt((H**2 + W**2)) 67 | 68 | else: 69 | return hausdorff_distance(label_pred, label_true) 70 | -------------------------------------------------------------------------------- /src/utils/parse.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import os 3 | from glob import glob 4 | 5 | from utils.attribute_hashmap import AttributeHashmap 6 | from utils.log_util import log 7 | 8 | 9 | def parse_settings(config: AttributeHashmap, segmentor: bool = False, 10 | log_settings: bool = True, run_count: int = None): 11 | # fix typing issues 12 | for key in ['learning_rate', 'ode_tol']: 13 | if key in config.keys(): 14 | config[key] = float(config[key]) 15 | 16 | # fix path issues 17 | ROOT = '/'.join( 18 | os.path.dirname(os.path.abspath(__file__)).split('/')[:-2]) 19 | for key in config.keys(): 20 | if type(config[key]) == str and '$ROOT' in config[key]: 21 | config[key] = config[key].replace('$ROOT', ROOT) 22 | 23 | if segmentor: 24 | config.save_folder = os.path.dirname(config.segmentor_ckpt) + '/' 25 | os.makedirs(config.save_folder, exist_ok=True) 26 | config.model_save_path = config.segmentor_ckpt 27 | 28 | else: 29 | setting_str = '%s_%s_%ssmoothness-%.3f_latent-%.3f_contrastive-%.3f_invariance-%.3f_seed_%s' % ( 30 | config.dataset_name, 31 | config.model, 32 | 'NoL2_' if config.no_l2 else '', 33 | config.coeff_smoothness, 34 | config.coeff_latent, 35 | config.coeff_contrastive, 36 | config.coeff_invariance, 37 | config.random_seed, 38 | ) 39 | 40 | output_save_path = '%s/%s' % (config.output_save_folder, setting_str) 41 | 42 | # Initialize save folder. 43 | if run_count is None: 44 | existing_runs = glob(output_save_path + '/run_*/') 45 | if len(existing_runs) > 0: 46 | run_counts = [int(item.split('/')[-2].split('run_')[1]) for item in existing_runs] 47 | run_count = max(run_counts) + 1 48 | else: 49 | run_count = 1 50 | 51 | config.save_folder = '%s/run_%d/' % (output_save_path, run_count) 52 | config.model_save_path = config.save_folder + setting_str + '.pty' 53 | 54 | # Initialize log file. 
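# [Editorial note, not part of the original script] The block below points the run's log
# at <save_folder>/log.txt and, when log_settings is True, writes every resolved config
# key/value pair once at the top of that file via utils.log_util.log, which also echoes
# the same text to the console.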
55 | config.log_dir = config.save_folder + 'log.txt' 56 | if log_settings: 57 | log_str = 'Config: \n' 58 | for key in config.keys(): 59 | log_str += '%s: %s\n' % (key, config[key]) 60 | log_str += '\nTraining History:' 61 | log(log_str, filepath=config.log_dir, to_console=True) 62 | 63 | return config 64 | -------------------------------------------------------------------------------- /src/utils/seed.py: -------------------------------------------------------------------------------- 1 | def seed_everything(seed: int) -> None: 2 | """ 3 | https://gist.github.com/ihoromi4/b681a9088f348942b01711f251e5f964 4 | """ 5 | 6 | import os 7 | import random 8 | 9 | import numpy as np 10 | import torch 11 | 12 | random.seed(seed) 13 | os.environ['PYTHONHASHSEED'] = str(seed) 14 | np.random.seed(seed) 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | torch.backends.cudnn.deterministic = True 18 | torch.backends.cudnn.benchmark = True --------------------------------------------------------------------------------
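Editorial note (not a file in this repository): src/utils/seed.py above seeds Python's `random`, `PYTHONHASHSEED`, NumPy, and PyTorch (CPU and CUDA), but it also sets `torch.backends.cudnn.benchmark = True`, which favors speed; projects that need bit-exact reproducibility across runs usually set that flag to False. Below is a minimal, hypothetical sketch of how these utilities are typically wired together in a training script; `validate()` is a placeholder, not a name from this codebase:

    # Hypothetical usage sketch; run from src/ so that `utils` is importable.
    from utils.seed import seed_everything
    from utils.early_stop import EarlyStopping

    seed_everything(1)
    early_stopper = EarlyStopping(mode='max', patience=8)

    for epoch in range(100):
        val_metric = validate()            # placeholder: returns a float validation score
        if early_stopper.step(val_metric): # True once the score stops improving for 8 epochs
            break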