├── runners
│   ├── __init__.py
│   └── diffusion.py
├── models
│   ├── __init__.py
│   ├── optimizer.py
│   ├── losses.py
│   ├── galerkin.py
│   ├── ema.py
│   ├── denoising.py
│   ├── rdn.py
│   ├── sronet.py
│   ├── edsr.py
│   └── diffusion.py
├── configs
│   ├── test.yml
│   └── train.yml
├── scheduler.py
├── README.md
├── datasets
│   ├── __init__.py
│   └── image_folder.py
├── init_weight.py
├── demo.py
├── utils.py
├── main.py
└── test.py

/runners/__init__.py:
--------------------------------------------------------------------------------
1 | from . import diffusion
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from . import edsr, rdn
2 | from . import sronet, galerkin
3 | from . import ema, diffusion
4 | from . import denoising, losses
5 | from . import optimizer
--------------------------------------------------------------------------------
/models/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 | 
3 | 
4 | def get_optimizer(config, parameters):
5 |     if config.optim.optimizer == 'Adam':
6 |         return optim.Adam(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay,
7 |                           betas=(config.optim.beta1, 0.999), amsgrad=config.optim.amsgrad,
8 |                           eps=config.optim.eps)
9 |     elif config.optim.optimizer == 'RMSProp':
10 |         return optim.RMSprop(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay)
11 |     elif config.optim.optimizer == 'SGD':
12 |         return optim.SGD(parameters, lr=config.optim.lr, momentum=0.9)
13 |     else:
14 |         raise NotImplementedError(
15 |             'Optimizer {} not understood.'.format(config.optim.optimizer))
16 | 
--------------------------------------------------------------------------------
/models/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from models.denoising import generalized_steps
4 | from utils import calc_ssim
5 | from datasets import inverse_data_transform
6 | 
7 | 
8 | def noise_estimation_loss(model,
9 |                           x_t: torch.Tensor,
10 |                           inp: torch.Tensor,
11 |                           hr_coord: torch.Tensor,
12 |                           cell: torch.Tensor,
13 |                           e: torch.Tensor,
14 |                           b: torch.Tensor,
15 |                           continuous_sqrt_alpha_cumprod,
16 |                           keepdim=False):
17 |     output = model(x_t, continuous_sqrt_alpha_cumprod, inp, hr_coord, cell)
18 |     sum_pixel = e.shape[1]*e.shape[2]*e.shape[3]
19 |     if keepdim:
20 |         return (e - output).abs().sum(dim=(1, 2, 3))/sum_pixel
21 |     else:
22 |         return (e - output).abs().sum(dim=(1, 2, 3)).mean(dim=0)/sum_pixel
23 | 
24 | 
25 | loss_registry = {
26 |     'simple': noise_estimation_loss,
27 | }
28 | 
--------------------------------------------------------------------------------
/configs/test.yml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 |   dataset:
3 |     root_path: ./data/test
4 |     repeat: 1
5 |     scale_max: 4
6 |     augment: False
7 |   batch_size: 1
8 | 
9 | data:
10 |   logit_transform: False
11 |   uniform_dequantization: False
12 |   gaussian_dequantization: False
13 |   rescaled: True
14 | 
15 | model:
16 |   type: "simple"
17 |   image_size: 256
18 |   in_channel: 6
19 |   out_channel: 3
20 |   inner_channel: 64
21 |   norm_groups: 16 #32
22 |   channel_mults: [1, 2, 2, 4, 4, 8]
23 |   res_blocks: 1 # 2
24 |   attn_res: [16, ]
25 |   dropout: 0
26 |   var_type: fixedsmall
27 |   ema_rate: 0.999
28 |   ema: True
29 |   with_noise_level_emb: True
30 |   srno:
31 |     encoder: edsr-baseline
32 |     no_upsampling: True
33 |     width: 256
34 |     blocks: 16
35 | 
36 | diffusion:
37 |   beta_schedule: linear
38 |   beta_start: 0.000001
39 |   beta_end: 0.01
40 |   num_diffusion_timesteps: 2000
41 | 
42 | srno_weight: ./srno_weight.pth
--------------------------------------------------------------------------------
/scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.lr_scheduler import _LRScheduler
3 | 
4 | 
5 | class CustomIterationScheduler(_LRScheduler):
6 |     def __init__(self, optimizer, lr_sequence, step_size, last_iteration=-1):
7 |         self.lr_sequence = lr_sequence
8 |         self.step_size = step_size
9 |         self.current_step = last_iteration
10 |         self.lr_index = 0
11 |         super(CustomIterationScheduler, self).__init__(optimizer, last_epoch=last_iteration)
12 | 
13 |     def get_lr(self):
14 |         if self.lr_index < len(self.lr_sequence):
15 |             return [self.lr_sequence[self.lr_index] for _ in self.optimizer.param_groups]
16 |         else:
17 |             return [self.lr_sequence[-1] for _ in self.optimizer.param_groups]
18 | 
19 |     def step(self, iteration=None):
20 |         if iteration is None:
21 |             self.current_step += 1
22 |         else:
23 |             self.current_step = iteration
24 | 
25 |         new_lr_index = self.current_step // self.step_size
26 | 
27 |         if new_lr_index > self.lr_index and self.lr_index < len(self.lr_sequence) - 1:
28 |             self.lr_index = new_lr_index
29 |             super(CustomIterationScheduler, self).step()
30 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Continuous Remote Sensing Image Super-Resolution via Neural Operator Diffusion
2 | 
3 | This repository is the official implementation of the paper "NeurOp-Diff: Continuous Remote Sensing Image Super-Resolution via Neural Operator Diffusion".
4 | 
5 | # Environment configuration
6 | 
7 | The code is based on Python 3.11, PyTorch 2.4.0, and CUDA 12.4.
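A minimal environment sketch for reference (the package list below is inferred from the imports used in this repository; the environment name and anything beyond the Python/PyTorch versions stated above are assumptions):

```
conda create -n neurop-diff python=3.11
conda activate neurop-diff
pip install torch==2.4.0 torchvision numpy opencv-python matplotlib pyyaml tqdm pillow tensorboard
```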
8 | 
9 | # Data preparation
10 | 
11 | - [UCMerced](https://drive.google.com/drive/folders/1Mknr0n4VjWIAk3yQwGewplDe4bUa-_D-?usp=drive_link) | [AID_256](https://drive.google.com/drive/folders/1Mknr0n4VjWIAk3yQwGewplDe4bUa-_D-?usp=drive_link) | [RSSCN7_256](https://drive.google.com/drive/folders/1Mknr0n4VjWIAk3yQwGewplDe4bUa-_D-?usp=drive_link)
12 | 
13 | # Checkpoint
14 | 
15 | The pre-trained weights for continuous SR can be found at this [link](https://drive.google.com/file/d/1A06iFZUyu1-CnYtIceBFmThhdhW65oH8/view?usp=sharing).
16 | 
17 | # Train
18 | 
19 | ```python main.py --config train.yml --exp ./result --doc pth --timesteps [steps] --ni```
20 | 
21 | # Test
22 | 
23 | ```python test.py --config ./configs/test.yml --model [checkpoint_path] --timesteps [steps]```
24 | 
25 | # Demo
26 | 
27 | ```python demo.py --config ./configs/test.yml --model [checkpoint_path] --path [image_path]```
28 | 
29 | # Acknowledgements
30 | 
31 | This code is mainly built on [DDIM](https://github.com/ermongroup/ddim), [SRNO](https://github.com/2y7c3/Super-Resolution-Neural-Operator) and [LIIF](https://github.com/yinboc/liif).
--------------------------------------------------------------------------------
/configs/train.yml:
--------------------------------------------------------------------------------
1 | train_dataset:
2 |   dataset:
3 |     root_path: ./data/train
4 |     repeat: 1
5 |     scale_min: 2
6 |     scale_max: 8
7 |     augment: True
8 |   batch_size: 10
9 | 
10 | val_dataset:
11 |   dataset:
12 |     root_path: ./data/val
13 |     repeat: 1
14 |     scale_min: 4
15 |     scale_max: 4
16 |     augment: False
17 |   batch_size: 1
18 | 
19 | data:
20 |   channels: 3
21 |   logit_transform: False
22 |   uniform_dequantization: False
23 |   gaussian_dequantization: False
24 |   rescaled: True
25 | 
26 | model:
27 |   type: "simple"
28 |   image_size: 256
29 |   in_channel: 6
30 |   out_channel: 3
31 |   inner_channel: 64
32 |   norm_groups: 16 #32
33 |   channel_mults: [1, 2, 2, 4, 4, 8] #[1, 2, 2, 2, 4, 4, 4]
34 |   res_blocks: 1 # 2
35 |   attn_res: [16, ]
36 |   dropout: 0.2
37 |   var_type: fixedsmall
38 |   ema_rate: 0.999
39 |   ema: True
40 |   with_noise_level_emb: True
41 |   srno:
42 |     encoder: edsr-baseline
43 |     no_upsampling: True
44 |     width: 256
45 |     blocks: 16
46 | 
47 | diffusion:
48 |   beta_schedule: linear
49 |   beta_start: 0.000001
50 |   beta_end: 0.01
51 |   num_diffusion_timesteps: 2000
52 | 
53 | training:
54 |   n_epochs: 10000
55 |   epoch_save_freq: 10
56 | 
57 | optim:
58 |   weight_decay: 0.000
59 |   optimizer: "Adam"
60 |   lr: 0.0001
61 |   beta1: 0.9
62 |   amsgrad: False
63 |   eps: 0.00000001
64 | 
65 | scheduler:
66 |   lr_sequence: [1e-4, 8e-5, 6e-5, 4e-5, 2e-5, 1e-5, 8e-6, 6e-6, 4e-6, 2e-6]
67 |   step_size: 100000
68 | 
69 | srno_weight: ./srno_weight.pth
70 | 
--------------------------------------------------------------------------------
/models/galerkin.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | 
5 | class LayerNorm(nn.Module):
6 |     def __init__(self, d_model, eps=1e-5):
7 |         super(LayerNorm, self).__init__()
8 |         self.weight = nn.Parameter(torch.ones(d_model))
9 |         self.bias = nn.Parameter(torch.zeros(d_model))
10 |         self.eps = eps
11 | 
12 |     def forward(self, x):
13 |         mean = x.mean(-1, keepdim=True)
14 |         std = x.std(-1, keepdim=True)
15 | 
16 |         out = (x - mean) / (std + self.eps)
17 |         out = self.weight * out + self.bias
18 |         return out
19 | 
20 | class simple_attn(nn.Module):
21 |     def __init__(self, midc, heads):
22 |         super().__init__()
23 | 
24 |         self.headc = midc // heads
25 |         self.heads = heads
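        # Galerkin-type (softmax-free) attention from SRNO: forward() below
        # layer-normalizes K and V and computes K^T V (a headc x headc matrix
        # per head) first, so the cost is linear in the number of pixels H*W
        # rather than quadratic as in standard softmax attention.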
26 | self.midc = midc 27 | 28 | self.qkv_proj = nn.Conv2d(midc, 3*midc, 1) 29 | self.o_proj1 = nn.Conv2d(midc, midc, 1) 30 | self.o_proj2 = nn.Conv2d(midc, midc, 1) 31 | 32 | self.kln = LayerNorm((self.heads, 1, self.headc)) 33 | self.vln = LayerNorm((self.heads, 1, self.headc)) 34 | 35 | self.act = nn.GELU() 36 | 37 | def forward(self, x, name='0'): 38 | B, C, H, W = x.shape 39 | bias = x 40 | 41 | qkv = self.qkv_proj(x).permute(0, 2, 3, 1).reshape(B, H*W, self.heads, 3*self.headc) 42 | qkv = qkv.permute(0, 2, 1, 3) 43 | q, k, v = qkv.chunk(3, dim=-1) 44 | 45 | k = self.kln(k) 46 | v = self.vln(v) 47 | 48 | v = torch.matmul(k.transpose(-2,-1), v) / (H*W) 49 | v = torch.matmul(q, v) 50 | v = v.permute(0, 2, 1, 3).reshape(B, H, W, C) 51 | 52 | ret = v.permute(0, 3, 1, 2) + bias 53 | bias = self.o_proj2(self.act(self.o_proj1(ret))) + bias 54 | 55 | return bias 56 | -------------------------------------------------------------------------------- /models/ema.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class EMAHelper(object): 5 | def __init__(self, mu=0.999): 6 | self.mu = mu 7 | self.shadow = {} 8 | 9 | def register(self, module): 10 | if isinstance(module, nn.DataParallel): 11 | module = module.module 12 | for name, param in module.named_parameters(): 13 | if param.requires_grad: 14 | self.shadow[name] = param.data.clone() 15 | 16 | def update(self, module): 17 | if isinstance(module, nn.DataParallel): 18 | module = module.module 19 | for name, param in module.named_parameters(): 20 | if param.requires_grad: 21 | self.shadow[name].data = ( 22 | 1. - self.mu) * param.data + self.mu * self.shadow[name].data 23 | 24 | def ema(self, module): 25 | if isinstance(module, nn.DataParallel): 26 | module = module.module 27 | for name, param in module.named_parameters(): 28 | if param.requires_grad: 29 | param.data.copy_(self.shadow[name].data) 30 | 31 | def ema_copy(self, module): 32 | if isinstance(module, nn.DataParallel): 33 | inner_module = module.module 34 | module_copy = type(inner_module)( 35 | inner_module.config).to(inner_module.config.device) 36 | module_copy.load_state_dict(inner_module.state_dict()) 37 | module_copy = nn.DataParallel(module_copy) 38 | else: 39 | module_copy = type(module)(module.config).to(module.config.device) 40 | module_copy.load_state_dict(module.state_dict()) 41 | # module_copy = copy.deepcopy(module) 42 | self.ema(module_copy) 43 | return module_copy 44 | 45 | def state_dict(self): 46 | return self.shadow 47 | 48 | def load_state_dict(self, state_dict): 49 | self.shadow = state_dict 50 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | 5 | from . 
import image_folder 6 | 7 | from utils import make_coord 8 | from torchvision import transforms 9 | from torchvision.transforms.functional import InterpolationMode 10 | 11 | 12 | def logit_transform(image, lam=1e-6): 13 | image = lam + (1 - 2 * lam) * image 14 | return torch.log(image) - torch.log1p(-image) 15 | 16 | 17 | def data_transform(config, X): 18 | if config.data.uniform_dequantization: 19 | X = X / 256.0 * 255.0 + torch.rand_like(X) / 256.0 20 | if config.data.gaussian_dequantization: 21 | X = X + torch.randn_like(X) * 0.0005 22 | 23 | if config.data.rescaled: 24 | X = 2 * X - 1.0 25 | elif config.data.logit_transform: 26 | X = logit_transform(X) 27 | 28 | return X 29 | 30 | 31 | def inverse_data_transform(config, X): 32 | if config.data.logit_transform: 33 | X = torch.sigmoid(X) 34 | elif config.data.rescaled: 35 | X = (X + 1.0) / 2.0 36 | 37 | return torch.clamp(X, 0.0, 1.0) 38 | 39 | 40 | def resize_fn(img, size): 41 | return transforms.Resize([size, size], interpolation=InterpolationMode.BICUBIC, antialias=True)(img) 42 | 43 | 44 | def feed_data(config, data): 45 | p = random.uniform(config.wrapper.scale_min, config.wrapper.scale_max) 46 | 47 | img_hr = data['hr'] 48 | # img_lr = data['lr'] 49 | w_lr = round(img_hr.shape[-1] / round(p)) 50 | img_lr = resize_fn(img_hr, w_lr) 51 | # print(p) 52 | # print(img_lr.shape) 53 | 54 | if config.dataset.augment: 55 | hflip = random.random() < 0.5 56 | vflip = random.random() < 0.5 57 | dflip = random.random() < 0.5 58 | 59 | def augment(x): 60 | if hflip: 61 | x = x.flip(-2) 62 | if vflip: 63 | x = x.flip(-1) 64 | if dflip: 65 | x = x.transpose(-2, -1) 66 | return x 67 | 68 | img_lr = augment(img_lr) 69 | 70 | hr_coord = make_coord(img_hr.shape[-2:], flatten=False) 71 | cell = torch.tensor([2 / img_hr.shape[-2], 2 / img_hr.shape[-1]], dtype=torch.float32) 72 | hr_coord = hr_coord.repeat(img_hr.shape[0], 1, 1, 1) 73 | cell = cell.repeat(img_hr.shape[0], 1) 74 | 75 | return { 76 | 'gt': img_hr, 77 | 'inp': img_lr, 78 | 'coord': hr_coord, 79 | 'cell': cell, 80 | } 81 | -------------------------------------------------------------------------------- /init_weight.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import logging 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import init 6 | from torch.nn import modules 7 | #################### 8 | # initialize 9 | #################### 10 | 11 | 12 | def weights_init_normal(m, std=0.02): 13 | classname = m.__class__.__name__ 14 | if classname.find('Conv') != -1: 15 | init.normal_(m.weight.data, 0.0, std) 16 | if m.bias is not None: 17 | m.bias.data.zero_() 18 | elif classname.find('Linear') != -1: 19 | init.normal_(m.weight.data, 0.0, std) 20 | if m.bias is not None: 21 | m.bias.data.zero_() 22 | elif classname.find('BatchNorm2d') != -1: 23 | init.normal_(m.weight.data, 1.0, std) # BN also uses norm 24 | init.constant_(m.bias.data, 0.0) 25 | 26 | 27 | def weights_init_kaiming(m, scale=1): 28 | classname = m.__class__.__name__ 29 | if classname.find('Conv2d') != -1: 30 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 31 | m.weight.data *= scale 32 | if m.bias is not None: 33 | m.bias.data.zero_() 34 | elif classname.find('Linear') != -1: 35 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 36 | m.weight.data *= scale 37 | if m.bias is not None: 38 | m.bias.data.zero_() 39 | elif classname.find('BatchNorm2d') != -1: 40 | init.constant_(m.weight.data, 1.0) 41 | init.constant_(m.bias.data, 0.0) 42 | 43 | 44 
| def weights_init_orthogonal(m): 45 | classname = m.__class__.__name__ 46 | if classname.find('Conv') != -1: 47 | init.orthogonal_(m.weight.data, gain=1) 48 | if m.bias is not None: 49 | m.bias.data.zero_() 50 | elif classname.find('Linear') != -1: 51 | init.orthogonal_(m.weight.data, gain=1) 52 | if m.bias is not None: 53 | m.bias.data.zero_() 54 | elif classname.find('BatchNorm2d') != -1: 55 | init.constant_(m.weight.data, 1.0) 56 | init.constant_(m.bias.data, 0.0) 57 | 58 | 59 | def init_weights(net, init_type='kaiming', scale=1, std=0.02): 60 | # scale for 'kaiming', std for 'normal'. 61 | logging.info('Initialization method [{:s}]'.format(init_type)) 62 | if init_type == 'normal': 63 | weights_init_normal_ = functools.partial(weights_init_normal, std=std) 64 | net.apply(weights_init_normal_) 65 | elif init_type == 'kaiming': 66 | weights_init_kaiming_ = functools.partial( 67 | weights_init_kaiming, scale=scale) 68 | net.apply(weights_init_kaiming_) 69 | elif init_type == 'orthogonal': 70 | net.apply(weights_init_orthogonal) 71 | else: 72 | raise NotImplementedError( 73 | 'initialization method [{:s}] not implemented'.format(init_type)) 74 | -------------------------------------------------------------------------------- /datasets/image_folder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import torch 5 | from PIL import Image 6 | 7 | from torch.utils.data import Dataset 8 | from torchvision import transforms 9 | from utils import make_coord 10 | from torchvision.transforms import InterpolationMode 11 | 12 | def resize_fn(img, size): 13 | return transforms.ToTensor()( 14 | transforms.Resize(size, InterpolationMode.BICUBIC)( 15 | transforms.ToPILImage()(img))) 16 | 17 | class ImageFolder(Dataset): 18 | 19 | def __init__(self, spec, image_path=None, split_file=None, split_key=None, first_k=None): 20 | self.repeat = spec.repeat 21 | self.augment = spec.augment 22 | self.image_path = image_path 23 | if self.augment: 24 | self.p = random.randint(spec.scale_min, spec.scale_max) 25 | else: 26 | self.p = spec.scale_max 27 | 28 | if image_path == None: 29 | self.root_path = spec.root_path 30 | else: 31 | self.root_path = image_path 32 | 33 | if image_path: 34 | self.files = [image_path] 35 | else: 36 | if split_file is None: 37 | filenames = sorted(os.listdir(self.root_path)) 38 | else: 39 | with open(split_file, 'r') as f: 40 | filenames = json.load(f)[split_key] 41 | if first_k is not None: 42 | filenames = filenames[:first_k] 43 | 44 | self.files = filenames 45 | 46 | def __len__(self): 47 | return len(self.files) * self.repeat 48 | 49 | def __getitem__(self, idx): 50 | filename = self.files[idx % len(self.files)] 51 | 52 | if self.image_path == None: 53 | hr_file = os.path.join(self.root_path, filename) 54 | else: 55 | hr_file = filename 56 | 57 | hr_image = Image.open(hr_file).convert('RGB') 58 | hr_image = transforms.ToTensor()(hr_image) 59 | 60 | w_lr = round(hr_image.shape[-1] / round(self.p)) 61 | lr_image = resize_fn(hr_image, w_lr) 62 | 63 | if self.augment: 64 | hflip = random.random() < 0.5 65 | vflip = random.random() < 0.5 66 | dflip = random.random() < 0.5 67 | 68 | def augment(x): 69 | if hflip: 70 | x = x.flip(-2) 71 | if vflip: 72 | x = x.flip(-1) 73 | if dflip: 74 | x = x.transpose(-2, -1) 75 | return x 76 | 77 | hr_image = augment(hr_image) 78 | lr_image = augment(lr_image) 79 | 80 | hr_coord = make_coord([hr_image.shape[-2], hr_image.shape[-1]], flatten=False) 81 | cell = 
torch.tensor([2 / hr_image.shape[-2], 2 / hr_image.shape[-1]], dtype=torch.float32) 82 | 83 | return { 84 | 'gt': hr_image, 85 | 'inp': lr_image, 86 | 'coord': hr_coord, 87 | 'cell': cell 88 | } 89 | -------------------------------------------------------------------------------- /models/denoising.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def compute_alpha(beta, t): 6 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0) 7 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1) 8 | return a 9 | 10 | 11 | def generalized_steps(x, seq, model, b, srno_input, coord, cell, sqrt_alphas_cumprod_prev, **kwargs): 12 | with torch.no_grad(): 13 | n = x.size(0) 14 | seq_next = [-1] + list(seq[:-1]) 15 | x0_preds = [] 16 | xs = [x] 17 | for i, j in zip(reversed(seq), reversed(seq_next)): 18 | t = (torch.ones(n) * i).to(x.device) 19 | next_t = (torch.ones(n) * j).to(x.device) 20 | at = compute_alpha(b, t.long()) 21 | at_next = compute_alpha(b, next_t.long()) 22 | xt = xs[-1].to('cuda') 23 | 24 | noise_level = torch.FloatTensor([sqrt_alphas_cumprod_prev[i+1]]).repeat(n, 1).to(x.device) 25 | et = model(xt, noise_level, srno_input, coord, cell) 26 | 27 | x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt() 28 | x0_preds.append(x0_t.to('cpu')) 29 | c1 = ( 30 | kwargs.get("eta", 0) * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt() 31 | ) 32 | c2 = ((1 - at_next) - c1 ** 2).sqrt() 33 | xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et 34 | xs.append(xt_next.to('cpu')) 35 | 36 | return xs, x0_preds 37 | 38 | 39 | def ddpm_steps(x, seq, model, b, srno_input, coord, cell, sqrt_alphas_cumprod_prev, **kwargs): 40 | with torch.no_grad(): 41 | n = x.size(0) 42 | seq_next = [-1] + list(seq[:-1]) 43 | xs = [x] 44 | x0_preds = [] 45 | betas = b 46 | for i, j in zip(reversed(seq), reversed(seq_next)): 47 | t = (torch.ones(n) * i).to(x.device) 48 | next_t = (torch.ones(n) * j).to(x.device) 49 | at = compute_alpha(betas, t.long()) 50 | atm1 = compute_alpha(betas, next_t.long()) 51 | beta_t = 1 - at / atm1 52 | x = xs[-1].to('cuda') 53 | 54 | noise_level = torch.FloatTensor([sqrt_alphas_cumprod_prev[i+1]]).repeat(n, 1).to(x.device) 55 | output = model(x, noise_level, srno_input, coord, cell) 56 | e = output 57 | 58 | x0_from_e = (1.0 / at).sqrt() * x - (1.0 / at - 1).sqrt() * e 59 | x0_from_e = torch.clamp(x0_from_e, -1, 1) 60 | x0_preds.append(x0_from_e.to('cpu')) 61 | mean_eps = ( 62 | (atm1.sqrt() * beta_t) * x0_from_e + ((1 - beta_t).sqrt() * (1 - atm1)) * x 63 | ) / (1.0 - at) 64 | 65 | mean = mean_eps 66 | noise = torch.randn_like(x) 67 | mask = 1 - (t == 0).float() 68 | mask = mask.view(-1, 1, 1, 1) 69 | logvar = beta_t.log() 70 | sample = mean + mask * torch.exp(0.5 * logvar) * noise 71 | xs.append(sample.to('cpu')) 72 | 73 | return xs, x0_preds 74 | -------------------------------------------------------------------------------- /models/rdn.py: -------------------------------------------------------------------------------- 1 | # Residual Dense Network for Image Super-Resolution 2 | # https://arxiv.org/abs/1802.08797 3 | # modified from: https://github.com/thstkdgus35/EDSR-PyTorch 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class RDB_Conv(nn.Module): 10 | def __init__(self, inChannels, growRate, kSize=3): 11 | super(RDB_Conv, self).__init__() 12 | Cin = inChannels 13 | G = growRate 14 | self.conv = nn.Sequential(*[ 15 | nn.Conv2d(Cin, G, kSize, 
padding=(kSize-1)//2, stride=1),
16 |             nn.ReLU()
17 |         ])
18 | 
19 |     def forward(self, x):
20 |         out = self.conv(x)
21 |         return torch.cat((x, out), 1)
22 | 
23 | class RDB(nn.Module):
24 |     def __init__(self, growRate0, growRate, nConvLayers, kSize=3):
25 |         super(RDB, self).__init__()
26 |         G0 = growRate0
27 |         G = growRate
28 |         C = nConvLayers
29 | 
30 |         convs = []
31 |         for c in range(C):
32 |             convs.append(RDB_Conv(G0 + c*G, G))
33 |         self.convs = nn.Sequential(*convs)
34 | 
35 |         # Local Feature Fusion
36 |         self.LFF = nn.Conv2d(G0 + C*G, G0, 1, padding=0, stride=1)
37 | 
38 |     def forward(self, x):
39 |         return self.LFF(self.convs(x)) + x
40 | 
41 | class RDN(nn.Module):
42 |     def __init__(self, scale, G0=64, RDNkSize=3, RDNconfig='B', no_upsampling=False, n_colors=3):
43 |         super(RDN, self).__init__()
44 |         r = scale
45 |         kSize = RDNkSize
46 | 
47 |         # number of RDB blocks, conv layers, out channels
48 |         self.D, C, G = {
49 |             'A': (20, 6, 32),
50 |             'B': (16, 8, 64),
51 |         }[RDNconfig]
52 | 
53 |         # Shallow feature extraction net
54 |         self.SFENet1 = nn.Conv2d(n_colors, G0, kSize, padding=(kSize-1)//2, stride=1)
55 |         self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)
56 | 
57 |         # Residual dense blocks and dense feature fusion
58 |         self.RDBs = nn.ModuleList()
59 |         for i in range(self.D):
60 |             self.RDBs.append(
61 |                 RDB(growRate0 = G0, growRate = G, nConvLayers = C)
62 |             )
63 | 
64 |         # Global Feature Fusion
65 |         self.GFF = nn.Sequential(*[
66 |             nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1),
67 |             nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)
68 |         ])
69 | 
70 |         if no_upsampling:
71 |             self.out_dim = G0
72 |         else:
73 |             self.out_dim = n_colors
74 |             # Up-sampling net
75 |             if r == 2 or r == 3:
76 |                 self.UPNet = nn.Sequential(*[
77 |                     nn.Conv2d(G0, G * r * r, kSize, padding=(kSize-1)//2, stride=1),
78 |                     nn.PixelShuffle(r),
79 |                     nn.Conv2d(G, n_colors, kSize, padding=(kSize-1)//2, stride=1)
80 |                 ])
81 |             elif r == 4:
82 |                 self.UPNet = nn.Sequential(*[
83 |                     nn.Conv2d(G0, G * 4, kSize, padding=(kSize-1)//2, stride=1),
84 |                     nn.PixelShuffle(2),
85 |                     nn.Conv2d(G, G * 4, kSize, padding=(kSize-1)//2, stride=1),
86 |                     nn.PixelShuffle(2),
87 |                     nn.Conv2d(G, n_colors, kSize, padding=(kSize-1)//2, stride=1)
88 |                 ])
89 |             else:
90 |                 raise ValueError("scale must be 2 or 3 or 4.")
91 | 
92 |         self.no_upsampling = no_upsampling
93 | 
94 |     def forward(self, x):
95 |         f__1 = self.SFENet1(x)
96 |         x = self.SFENet2(f__1)
97 | 
98 |         RDBs_out = []
99 |         for i in range(self.D):
100 |             x = self.RDBs[i](x)
101 |             RDBs_out.append(x)
102 | 
103 |         x = self.GFF(torch.cat(RDBs_out,1))
104 |         x += f__1
105 | 
106 |         if self.no_upsampling:
107 |             return x
108 |         else:
109 |             return self.UPNet(x)
110 | 
--------------------------------------------------------------------------------
/models/sronet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | import math
6 | import numpy as np
7 | 
8 | from models.galerkin import simple_attn
9 | from utils import make_coord
10 | from utils import show_feature_map
11 | 
12 | from models.edsr import EDSR
13 | from models.rdn import RDN
14 | 
15 | 
16 | def creat_encoder(type, no_upsampling):
17 |     if type == 'edsr':
18 |         return EDSR(n_resblocks=32, n_feats=256, res_scale=0.1, scale=2, no_upsampling=no_upsampling)
19 |     elif type == 'edsr-baseline':
20 |         return EDSR(n_resblocks=16, n_feats=64, res_scale=1, scale=2, no_upsampling=no_upsampling)
21 |     elif type == 'rdn':
22 |         return RDN(scale=2,
G0=64, RDNkSize=3, RDNconfig='B', no_upsampling=no_upsampling) 23 | else: 24 | raise ValueError(f"Unknown model type: {type}") 25 | 26 | 27 | class SRNO(nn.Module): 28 | def __init__(self, config, width=256, blocks=16): 29 | super(SRNO, self).__init__() 30 | self.width = width 31 | self.encoder = creat_encoder(config.model.srno.encoder, config.model.srno.no_upsampling) 32 | 33 | self.conv00 = nn.Conv2d((64 + 2)*4+2, self.width, 1) 34 | 35 | self.conv0 = simple_attn(self.width, blocks) 36 | self.conv1 = simple_attn(self.width, blocks) 37 | 38 | self.fc1 = nn.Conv2d(self.width, 256, 1) 39 | self.fc2 = nn.Conv2d(256, 3, 1) 40 | 41 | def gen_feat(self, inp): 42 | self.inp = inp 43 | self.feat = self.encoder(inp) 44 | return self.feat 45 | 46 | def query_features(self, coord, cell): 47 | feat = (self.feat) 48 | grid = 0 49 | 50 | pos_lr = make_coord(feat.shape[-2:], flatten=False).cuda() \ 51 | .permute(2, 0, 1) \ 52 | .unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:]) 53 | 54 | rx = 2 / feat.shape[-2] / 2 55 | ry = 2 / feat.shape[-1] / 2 56 | vx_lst = [-1, 1] 57 | vy_lst = [-1, 1] 58 | eps_shift = 1e-6 59 | 60 | rel_coords = [] 61 | feat_s = [] 62 | areas = [] 63 | for vx in vx_lst: 64 | for vy in vy_lst: 65 | 66 | coord_ = coord.clone() 67 | coord_[:, :, :, 0] += vx * rx + eps_shift 68 | coord_[:, :, :, 1] += vy * ry + eps_shift 69 | coord_.clamp_(-1 + 1e-6, 1 - 1e-6) 70 | 71 | feat_ = F.grid_sample(feat, coord_.flip(-1), mode='nearest', align_corners=False) 72 | 73 | old_coord = F.grid_sample(pos_lr, coord_.flip(-1), mode='nearest', align_corners=False) 74 | rel_coord = coord.permute(0, 3, 1, 2) - old_coord 75 | rel_coord[:, 0, :, :] *= feat.shape[-2] 76 | rel_coord[:, 1, :, :] *= feat.shape[-1] 77 | 78 | area = torch.abs(rel_coord[:, 0, :, :] * rel_coord[:, 1, :, :]) 79 | areas.append(area + 1e-9) 80 | 81 | rel_coords.append(rel_coord) 82 | feat_s.append(feat_) 83 | 84 | rel_cell = cell.clone() 85 | rel_cell[:,0] *= feat.shape[-2] 86 | rel_cell[:,1] *= feat.shape[-1] 87 | 88 | tot_area = torch.stack(areas).sum(dim=0) 89 | t = areas[0]; areas[0] = areas[3]; areas[3] = t 90 | t = areas[1]; areas[1] = areas[2]; areas[2] = t 91 | 92 | for index, area in enumerate(areas): 93 | feat_s[index] = feat_s[index] * (area / tot_area).unsqueeze(1) 94 | 95 | grid = torch.cat([*rel_coords, *feat_s, \ 96 | rel_cell.unsqueeze(-1).unsqueeze(-1).repeat(1,1,coord.shape[1],coord.shape[2])],dim=1) 97 | 98 | x = self.conv00(grid) 99 | x = self.conv0(x, 0) 100 | x = self.conv1(x, 1) 101 | 102 | feat = x 103 | ret = self.fc2(F.gelu(self.fc1(feat))) 104 | 105 | ret = ret + F.grid_sample(self.inp, coord.flip(-1), mode='bilinear',\ 106 | padding_mode='border', align_corners=False) 107 | 108 | return ret 109 | 110 | def forward(self, inp, coord, cell): 111 | self.gen_feat(inp) 112 | return self.query_features(coord, cell) 113 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import yaml 4 | import sys 5 | import torch 6 | import numpy as np 7 | 8 | from utils import calc_psnr, calc_ssim 9 | from utils import tensor2img, save_img 10 | 11 | from runners.diffusion import get_current_visuals 12 | from runners.diffusion import sample_image 13 | from runners.diffusion import make_data_loader 14 | 15 | from datasets import data_transform, inverse_data_transform 16 | 17 | from models.ema import EMAHelper 18 | from models.diffusion import Model 19 | 20 | 21 | def 
parse_args_and_config(): 22 | parser = argparse.ArgumentParser(description="Super-Resolution Diffusion Model Evaluation") 23 | parser.add_argument("--config", type=str, required=True, help="Path to the config file") 24 | parser.add_argument("--model", type=str, required=True, help="Path to the model checkpoint") 25 | parser.add_argument("--path", type=str, required=True, help="Path to the image") 26 | parser.add_argument("--seed", type=int, default=60000, help="Random seed") 27 | parser.add_argument( 28 | "--timesteps", type=int, default=50, help="number of steps involved" 29 | ) 30 | parser.add_argument( 31 | "--ni", 32 | action="store_true", 33 | help="No interaction. Suitable for Slurm Job launcher", 34 | ) 35 | parser.add_argument( 36 | "--sample_type", 37 | type=str, 38 | default="generalized", 39 | help="sampling approach (generalized or ddpm_noisy)", 40 | ) 41 | parser.add_argument( 42 | "--skip_type", 43 | type=str, 44 | default="uniform", 45 | help="skip according to (uniform or quadratic)", 46 | ) 47 | parser.add_argument( 48 | "--eta", 49 | type=float, 50 | default=0.0, 51 | help="eta used to control the variances of sigma", 52 | ) 53 | 54 | args = parser.parse_args() 55 | 56 | with open(args.config, "r") as f: 57 | config = yaml.safe_load(f) 58 | config = dict2namespace(config) 59 | 60 | # Set device 61 | config.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 62 | 63 | # set random seed 64 | torch.manual_seed(args.seed) 65 | np.random.seed(args.seed) 66 | if torch.cuda.is_available(): 67 | torch.cuda.manual_seed_all(args.seed) 68 | 69 | torch.backends.cudnn.benchmark = True 70 | 71 | return args, config 72 | 73 | 74 | def dict2namespace(config): 75 | namespace = argparse.Namespace() 76 | for key, value in config.items(): 77 | if isinstance(value, dict): 78 | new_value = dict2namespace(value) 79 | else: 80 | new_value = value 81 | setattr(namespace, key, new_value) 82 | return namespace 83 | 84 | 85 | def eval_psrn(config, args, model, test_loader): 86 | model.eval() 87 | 88 | with torch.no_grad(): 89 | for i, data in enumerate(test_loader): 90 | gt = data['gt'].to(config.device) 91 | lr = data['inp'].to(config.device) 92 | cell = data['cell'].to(config.device) 93 | hr_coord = data['coord'].to(config.device) 94 | 95 | lr = data_transform(config, lr) 96 | gt = data_transform(config, gt) 97 | 98 | x_t = torch.randn_like(gt, device=config.device) 99 | 100 | sr = sample_image(config, args, x_t, model, lr , hr_coord, cell).to(config.device) 101 | visuals = get_current_visuals(sr, data) 102 | sr_img = tensor2img(visuals['SR']) # uint8 103 | hr_img = tensor2img(visuals['GT']) 104 | save_img(sr_img, '{}/sr.png'.format('result')) 105 | 106 | sr = inverse_data_transform(config, sr) 107 | gt = inverse_data_transform(config, gt) 108 | 109 | psnr = calc_psnr(gt, sr) 110 | ssim = calc_ssim(sr_img, hr_img) 111 | 112 | return psnr, ssim 113 | 114 | 115 | def load_model(config, args, model): 116 | checkpoint = torch.load(args.model, map_location=config.device, weights_only=True) 117 | model.load_state_dict(checkpoint[0], strict=True) 118 | 119 | if config.model.ema: 120 | ema_helper = EMAHelper(mu=config.model.ema_rate) 121 | ema_helper.register(model) 122 | ema_helper.load_state_dict(checkpoint[-1]) 123 | ema_helper.ema(model) 124 | else: 125 | ema_helper = None 126 | 127 | 128 | def main(): 129 | args, config = parse_args_and_config() 130 | print(f"Starting evaluation with checkpoint: {args.model}") 131 | print(f"Using device: {config.device}") 132 | 133 | model = 
Model(config) 134 | model = model.to(config.device) 135 | test_loader = make_data_loader(config.test_dataset, args.path, tag='test') 136 | 137 | try: 138 | load_model(config, args, model) 139 | psnr, ssim = eval_psrn(config, args, model, test_loader) 140 | print(f"PSNR: {psnr:12.6f}, SSIM: {ssim:.4e}") 141 | except Exception as e: 142 | print(f"Error during evaluation: {str(e)}") 143 | raise 144 | 145 | return 0 146 | 147 | if __name__ == "__main__": 148 | 149 | sys.exit(main()) 150 | -------------------------------------------------------------------------------- /models/edsr.py: -------------------------------------------------------------------------------- 1 | # modified from: https://github.com/thstkdgus35/EDSR-PyTorch 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 10 | return nn.Conv2d( 11 | in_channels, out_channels, kernel_size, 12 | padding=(kernel_size//2), bias=bias) 13 | 14 | class MeanShift(nn.Conv2d): 15 | def __init__( 16 | self, rgb_range, 17 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 18 | 19 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 20 | std = torch.Tensor(rgb_std) 21 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 22 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 23 | for p in self.parameters(): 24 | p.requires_grad = False 25 | 26 | class ResBlock(nn.Module): 27 | def __init__( 28 | self, conv, n_feats, kernel_size, 29 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 30 | 31 | super(ResBlock, self).__init__() 32 | m = [] 33 | for i in range(2): 34 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 35 | if bn: 36 | m.append(nn.BatchNorm2d(n_feats)) 37 | if i == 0: 38 | m.append(act) 39 | 40 | self.body = nn.Sequential(*m) 41 | self.res_scale = res_scale 42 | 43 | def forward(self, x): 44 | res = self.body(x).mul(self.res_scale) 45 | res += x 46 | 47 | return res 48 | 49 | class Upsampler(nn.Sequential): 50 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): 51 | 52 | m = [] 53 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
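        # Note: (scale & (scale - 1)) == 0 is the standard bit trick for
        # testing that scale is a power of two (2, 4, 8, ...); each factor of
        # 2 is then realized by one conv + PixelShuffle(2) stage in the loop
        # below.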
54 | for _ in range(int(math.log(scale, 2))): 55 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 56 | m.append(nn.PixelShuffle(2)) 57 | if bn: 58 | m.append(nn.BatchNorm2d(n_feats)) 59 | if act == 'relu': 60 | m.append(nn.ReLU(True)) 61 | elif act == 'prelu': 62 | m.append(nn.PReLU(n_feats)) 63 | 64 | elif scale == 3: 65 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 66 | m.append(nn.PixelShuffle(3)) 67 | if bn: 68 | m.append(nn.BatchNorm2d(n_feats)) 69 | if act == 'relu': 70 | m.append(nn.ReLU(True)) 71 | elif act == 'prelu': 72 | m.append(nn.PReLU(n_feats)) 73 | else: 74 | raise NotImplementedError 75 | 76 | super(Upsampler, self).__init__(*m) 77 | 78 | 79 | url = { 80 | 'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt', 81 | 'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt', 82 | 'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt', 83 | 'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt', 84 | 'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt', 85 | 'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt' 86 | } 87 | 88 | class EDSR(nn.Module): 89 | def __init__(self, n_resblocks, n_feats, res_scale, scale, no_upsampling=False, rgb_range=1): 90 | super(EDSR, self).__init__() 91 | kernel_size = 3 92 | act = nn.ReLU(True) 93 | url_name = 'r{}f{}x{}'.format(n_resblocks, n_feats, scale) 94 | if url_name in url: 95 | self.url = url[url_name] 96 | else: 97 | self.url = None 98 | self.sub_mean = MeanShift(rgb_range) 99 | self.add_mean = MeanShift(rgb_range, sign=1) 100 | 101 | # define head module 102 | m_head = [default_conv(3, n_feats, kernel_size)] 103 | 104 | # define body module 105 | m_body = [ 106 | ResBlock( 107 | default_conv, n_feats, kernel_size, act=act, res_scale=res_scale 108 | ) for _ in range(n_resblocks) 109 | ] 110 | m_body.append(default_conv(n_feats, n_feats, kernel_size)) 111 | 112 | self.head = nn.Sequential(*m_head) 113 | self.body = nn.Sequential(*m_body) 114 | 115 | if no_upsampling: 116 | self.out_dim = n_feats 117 | else: 118 | self.out_dim = 3 119 | # define tail module 120 | m_tail = [ 121 | Upsampler(default_conv, scale, n_feats, act=False), 122 | default_conv(n_feats, 3, kernel_size) 123 | ] 124 | self.tail = nn.Sequential(*m_tail) 125 | 126 | self.no_upsampling = no_upsampling 127 | 128 | def forward(self, x): 129 | #x = self.sub_mean(x) 130 | x = self.head(x) 131 | 132 | res = self.body(x) 133 | res += x 134 | 135 | if self.no_upsampling: 136 | x = res 137 | else: 138 | x = self.tail(res) 139 | #x = self.add_mean(x) 140 | return x 141 | 142 | def load_state_dict(self, state_dict, strict=True): 143 | own_state = self.state_dict() 144 | for name, param in state_dict.items(): 145 | if name in own_state: 146 | if isinstance(param, nn.Parameter): 147 | param = param.data 148 | try: 149 | own_state[name].copy_(param) 150 | except Exception: 151 | if name.find('tail') == -1: 152 | raise RuntimeError('While copying the parameter named {}, ' 153 | 'whose dimensions in the model are {} and ' 154 | 'whose dimensions in the checkpoint are {}.' 
155 | .format(name, own_state[name].size(), param.size())) 156 | elif strict: 157 | if name.find('tail') == -1: 158 | raise KeyError('unexpected key "{}" in state_dict' 159 | .format(name)) 160 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # modified from: https://github.com/yinboc/liif 2 | 3 | import os 4 | import time 5 | import shutil 6 | import math 7 | import torch 8 | import numpy as np 9 | from torch.optim import SGD 10 | import cv2 11 | import matplotlib.pyplot as plt 12 | 13 | from torchvision import transforms 14 | from torchvision.transforms import InterpolationMode 15 | from torchvision.utils import make_grid 16 | import random 17 | import math 18 | 19 | def show_feature_map(feature_map,layer,name='rgb',rgb=False): 20 | feature_map = feature_map.squeeze(0) 21 | #if rgb: feature_map = feature_map.permute(1,2,0)*0.5+0.5 22 | feature_map = feature_map.cpu().numpy() 23 | feature_map_num = feature_map.shape[0] 24 | row_num = math.ceil(np.sqrt(feature_map_num)) 25 | if rgb: 26 | #plt.figure() 27 | #plt.imshow(feature_map) 28 | #plt.axis('off') 29 | feature_map = cv2.cvtColor(feature_map,cv2.COLOR_BGR2RGB) 30 | cv2.imwrite('data/'+layer+'/'+name+".png",feature_map*255) 31 | #plt.show() 32 | else: 33 | plt.figure() 34 | for index in range(1, feature_map_num+1): 35 | t = (feature_map[index-1]*255).astype(np.uint8) 36 | t = cv2.applyColorMap(t, cv2.COLORMAP_TWILIGHT) 37 | plt.subplot(row_num, row_num, index) 38 | plt.imshow(t, cmap='gray') 39 | plt.axis('off') 40 | #ensure_path('data/'+layer) 41 | cv2.imwrite('data/'+layer+'/'+str(name)+'_'+str(index)+".png",t) 42 | #plt.show() 43 | plt.savefig('data/'+layer+'/'+str(name)+".png") 44 | 45 | 46 | class Averager(): 47 | 48 | def __init__(self): 49 | self.n = 0.0 50 | self.v = 0.0 51 | 52 | def add(self, v, n=1.0): 53 | self.v = (self.v * self.n + v * n) / (self.n + n) 54 | self.n += n 55 | 56 | def item(self): 57 | return self.v 58 | 59 | 60 | def make_coord(shape, ranges=None, flatten=True): 61 | """ Make coordinates at grid centers. 62 | """ 63 | coord_seqs = [] 64 | for i, n in enumerate(shape): 65 | if ranges is None: 66 | v0, v1 = -1, 1 67 | else: 68 | v0, v1 = ranges[i] 69 | r = (v1 - v0) / (2 * n) 70 | seq = v0 + r + (2 * r) * torch.arange(n).float() 71 | coord_seqs.append(seq) 72 | #ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1) 73 | ret = torch.stack(torch.meshgrid(*coord_seqs,indexing='ij'), dim=-1) 74 | if flatten: 75 | ret = ret.view(-1, ret.shape[-1]) 76 | return ret 77 | 78 | 79 | def to_pixel_samples(img): 80 | """ Convert the image to coord-RGB pairs. 
81 |         img: Tensor, (3, H, W)
82 |     """
83 |     coord = make_coord(img.shape[-2:], flatten=False)
84 |     rgb = img.view(3, -1).permute(1, 0)
85 |     return coord, rgb
86 | 
87 | 
88 | def tensor2img(tensor, out_type=np.uint8, min_max=(-1, 1)):
89 |     '''
90 |     Converts a torch Tensor into an image Numpy array
91 |     Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
92 |     Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
93 |     '''
94 |     tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # clamp
95 |     tensor = (tensor - min_max[0]) / \
96 |         (min_max[1] - min_max[0])  # to range [0,1]
97 |     n_dim = tensor.dim()
98 |     if n_dim == 4:
99 |         n_img = len(tensor)
100 |         img_np = make_grid(tensor, nrow=int(
101 |             math.sqrt(n_img)), normalize=False).numpy()
102 |         img_np = np.transpose(img_np, (1, 2, 0))  # HWC, RGB
103 |     elif n_dim == 3:
104 |         img_np = tensor.numpy()
105 |         img_np = np.transpose(img_np, (1, 2, 0))  # HWC, RGB
106 |     elif n_dim == 2:
107 |         img_np = tensor.numpy()
108 |     else:
109 |         raise TypeError(
110 |             'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
111 |     if out_type == np.uint8:
112 |         img_np = (img_np * 255.0).round()
113 |         # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
114 |     return img_np.astype(out_type)
115 | 
116 | 
117 | def save_img(img, img_path, mode='RGB'):
118 |     cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
119 |     # cv2.imwrite(img_path, img)
120 | 
121 | 
122 | def calc_psnr(gt, pred_img):
123 |     assert gt.shape == pred_img.shape, "Input images must have the same dimensions"
124 | 
125 |     mse = torch.mean((gt - pred_img) ** 2, dim=[1, 2, 3])
126 |     if torch.any(mse == 0):
127 |         return float('inf')  # If MSE is 0, the images are identical
128 |     psnr = -10 * torch.log10(mse)
129 | 
130 |     return psnr.mean().item()
131 | 
132 | 
133 | def ssim(img1, img2):
134 |     C1 = (0.01 * 255)**2
135 |     C2 = (0.03 * 255)**2
136 | 
137 |     img1 = img1.astype(np.float64)
138 |     img2 = img2.astype(np.float64)
139 |     kernel = cv2.getGaussianKernel(11, 1.5)
140 |     window = np.outer(kernel, kernel.transpose())
141 | 
142 |     mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
143 |     mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
144 |     mu1_sq = mu1**2
145 |     mu2_sq = mu2**2
146 |     mu1_mu2 = mu1 * mu2
147 |     sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
148 |     sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
149 |     sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
150 | 
151 |     ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
152 |                                                             (sigma1_sq + sigma2_sq + C2))
153 |     return ssim_map.mean()
154 | 
155 | 
156 | def calc_ssim(img1, img2):
157 |     '''calculate SSIM
158 |     the same outputs as MATLAB's
159 |     img1, img2: [0, 255]
160 |     '''
161 |     if not img1.shape == img2.shape:
162 |         raise ValueError('Input images must have the same dimensions.')
163 |     if img1.ndim == 2:
164 |         return ssim(img1, img2)
165 |     elif img1.ndim == 3:
166 |         if img1.shape[2] == 3:
167 |             ssims = []
168 |             for i in range(3):
169 |                 ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
170 |             return np.array(ssims).mean()
171 |         elif img1.shape[2] == 1:
172 |             return ssim(np.squeeze(img1), np.squeeze(img2))
173 |         else:
174 |             raise ValueError('Wrong input image dimensions.')
175 | 
176 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import traceback
3 | import shutil
4 | 
import logging 5 | import yaml 6 | import sys 7 | import os 8 | import torch 9 | import numpy as np 10 | import torch.utils.tensorboard as tb 11 | 12 | from runners.diffusion import Diffusion 13 | 14 | torch.set_printoptions(sci_mode=False) 15 | 16 | 17 | def parse_args_and_config(): 18 | parser = argparse.ArgumentParser(description=globals()["__doc__"]) 19 | 20 | parser.add_argument( 21 | "--config", type=str, required=True, help="Path to the config file" 22 | ) 23 | parser.add_argument("--seed", type=int, default=1234, help="Random seed") 24 | parser.add_argument( 25 | "--exp", type=str, default="exp", help="Path for saving running related data." 26 | ) 27 | parser.add_argument( 28 | "--doc", 29 | type=str, 30 | required=True, 31 | help="A string for documentation purpose. " 32 | "Will be the name of the log folder.", 33 | ) 34 | parser.add_argument( 35 | "--comment", type=str, default="", help="A string for experiment comment" 36 | ) 37 | parser.add_argument( 38 | "--verbose", 39 | type=str, 40 | default="info", 41 | help="Verbose level: info | debug | warning | critical", 42 | ) 43 | parser.add_argument("--test", action="store_true", help="Whether to test the model") 44 | parser.add_argument( 45 | "--sample", 46 | action="store_true", 47 | help="Whether to produce samples from the model", 48 | ) 49 | parser.add_argument("--fid", action="store_true") 50 | parser.add_argument("--interpolation", action="store_true") 51 | parser.add_argument( 52 | "--resume_training", action="store_true", help="Whether to resume training" 53 | ) 54 | parser.add_argument( 55 | "-i", 56 | "--image_folder", 57 | type=str, 58 | default="images", 59 | help="The folder name of samples", 60 | ) 61 | parser.add_argument( 62 | "--ni", 63 | action="store_true", 64 | help="No interaction. Suitable for Slurm Job launcher", 65 | ) 66 | parser.add_argument("--use_pretrained", action="store_true") 67 | parser.add_argument( 68 | "--sample_type", 69 | type=str, 70 | default="generalized", 71 | help="sampling approach (generalized or ddpm_noisy)", 72 | ) 73 | parser.add_argument( 74 | "--skip_type", 75 | type=str, 76 | default="uniform", 77 | help="skip according to (uniform or quadratic)", 78 | ) 79 | parser.add_argument( 80 | "--timesteps", type=int, default=1000, help="number of steps involved" 81 | ) 82 | parser.add_argument( 83 | "--eta", 84 | type=float, 85 | default=0.0, 86 | help="eta used to control the variances of sigma", 87 | ) 88 | parser.add_argument("--sequence", action="store_true") 89 | 90 | args = parser.parse_args() 91 | args.log_path = os.path.join(args.exp, "logs", args.doc) 92 | 93 | # parse config file 94 | with open(os.path.join("configs", args.config), "r") as f: 95 | config = yaml.safe_load(f) 96 | new_config = dict2namespace(config) 97 | 98 | tb_path = os.path.join(args.exp, "tensorboard", args.doc) 99 | 100 | if not args.test and not args.sample: 101 | if not args.resume_training: 102 | if os.path.exists(args.log_path): 103 | overwrite = False 104 | if args.ni: 105 | overwrite = True 106 | else: 107 | response = input("Folder already exists. Overwrite? (Y/N)") 108 | if response.upper() == "Y": 109 | overwrite = True 110 | 111 | if overwrite: 112 | shutil.rmtree(args.log_path) 113 | shutil.rmtree(tb_path) 114 | os.makedirs(args.log_path) 115 | if os.path.exists(tb_path): 116 | shutil.rmtree(tb_path) 117 | else: 118 | print("Folder exists. 
Program halted.") 119 | sys.exit(0) 120 | else: 121 | os.makedirs(args.log_path) 122 | 123 | with open(os.path.join(args.log_path, "config.yml"), "w") as f: 124 | yaml.dump(new_config, f, default_flow_style=False) 125 | 126 | new_config.tb_logger = tb.SummaryWriter(log_dir=tb_path) 127 | # setup logger 128 | level = getattr(logging, args.verbose.upper(), None) 129 | if not isinstance(level, int): 130 | raise ValueError("level {} not supported".format(args.verbose)) 131 | 132 | handler1 = logging.StreamHandler() 133 | handler2 = logging.FileHandler(os.path.join(args.log_path, "stdout.txt")) 134 | formatter = logging.Formatter( 135 | "%(levelname)s - %(filename)s - %(asctime)s - %(message)s" 136 | ) 137 | handler1.setFormatter(formatter) 138 | handler2.setFormatter(formatter) 139 | logger = logging.getLogger() 140 | logger.addHandler(handler1) 141 | logger.addHandler(handler2) 142 | logger.setLevel(level) 143 | 144 | else: 145 | level = getattr(logging, args.verbose.upper(), None) 146 | if not isinstance(level, int): 147 | raise ValueError("level {} not supported".format(args.verbose)) 148 | 149 | handler1 = logging.StreamHandler() 150 | formatter = logging.Formatter( 151 | "%(levelname)s - %(filename)s - %(asctime)s - %(message)s" 152 | ) 153 | handler1.setFormatter(formatter) 154 | logger = logging.getLogger() 155 | logger.addHandler(handler1) 156 | logger.setLevel(level) 157 | 158 | if args.sample: 159 | os.makedirs(os.path.join(args.exp, "image_samples"), exist_ok=True) 160 | args.image_folder = os.path.join( 161 | args.exp, "image_samples", args.image_folder 162 | ) 163 | if not os.path.exists(args.image_folder): 164 | os.makedirs(args.image_folder) 165 | else: 166 | if not (args.fid or args.interpolation): 167 | overwrite = False 168 | if args.ni: 169 | overwrite = True 170 | else: 171 | response = input( 172 | f"Image folder {args.image_folder} already exists. Overwrite? (Y/N)" 173 | ) 174 | if response.upper() == "Y": 175 | overwrite = True 176 | 177 | if overwrite: 178 | shutil.rmtree(args.image_folder) 179 | os.makedirs(args.image_folder) 180 | else: 181 | print("Output image folder exists. 
Program halted.") 182 | sys.exit(0) 183 | 184 | # add device 185 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 186 | logging.info("Using device: {}".format(device)) 187 | new_config.device = device 188 | 189 | # set random seed 190 | torch.manual_seed(args.seed) 191 | np.random.seed(args.seed) 192 | if torch.cuda.is_available(): 193 | torch.cuda.manual_seed_all(args.seed) 194 | 195 | torch.backends.cudnn.benchmark = True 196 | 197 | return args, new_config 198 | 199 | 200 | def dict2namespace(config): 201 | namespace = argparse.Namespace() 202 | for key, value in config.items(): 203 | if isinstance(value, dict): 204 | new_value = dict2namespace(value) 205 | else: 206 | new_value = value 207 | setattr(namespace, key, new_value) 208 | return namespace 209 | 210 | 211 | def main(): 212 | args, config = parse_args_and_config() 213 | 214 | logging.info("Writing log file to {}".format(args.log_path)) 215 | logging.info("Exp instance id = {}".format(os.getpid())) 216 | logging.info("Exp comment = {}".format(args.comment)) 217 | 218 | try: 219 | runner = Diffusion(args, config) 220 | runner.train() 221 | except Exception: 222 | logging.error(traceback.format_exc()) 223 | 224 | return 0 225 | 226 | 227 | if __name__ == "__main__": 228 | sys.exit(main()) 229 | -------------------------------------------------------------------------------- /models/diffusion.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | from inspect import isfunction 6 | 7 | from models.sronet import SRNO 8 | 9 | 10 | def exists(x): 11 | return x is not None 12 | 13 | 14 | def default(val, d): 15 | if exists(val): 16 | return val 17 | return d() if isfunction(d) else d 18 | 19 | # PositionalEncoding Source: https://github.com/lmnt-com/wavegrad/blob/master/src/wavegrad/model.py 20 | class PositionalEncoding(nn.Module): 21 | def __init__(self, dim): 22 | super().__init__() 23 | self.dim = dim 24 | 25 | def forward(self, noise_level): 26 | count = self.dim // 2 27 | step = torch.arange(count, dtype=noise_level.dtype, 28 | device=noise_level.device) / count 29 | encoding = noise_level.unsqueeze( 30 | 1) * torch.exp(-math.log(1e4) * step.unsqueeze(0)) 31 | encoding = torch.cat( 32 | [torch.sin(encoding), torch.cos(encoding)], dim=-1) 33 | return encoding 34 | 35 | 36 | class FeatureWiseAffine(nn.Module): 37 | def __init__(self, in_channels, out_channels, use_affine_level=False): 38 | super(FeatureWiseAffine, self).__init__() 39 | self.use_affine_level = use_affine_level 40 | self.noise_func = nn.Sequential( 41 | nn.Linear(in_channels, out_channels*(1+self.use_affine_level)) 42 | ) 43 | 44 | def forward(self, x, noise_embed): 45 | batch = x.shape[0] 46 | if self.use_affine_level: 47 | gamma, beta = self.noise_func(noise_embed).view( 48 | batch, -1, 1, 1).chunk(2, dim=1) 49 | x = (1 + gamma) * x + beta 50 | else: 51 | x = x + self.noise_func(noise_embed).view(batch, -1, 1, 1) 52 | return x 53 | 54 | 55 | class Swish(nn.Module): 56 | def forward(self, x): 57 | return x * torch.sigmoid(x) 58 | 59 | 60 | class Upsample(nn.Module): 61 | def __init__(self, dim): 62 | super().__init__() 63 | self.up = nn.Upsample(scale_factor=2, mode="nearest") 64 | self.conv = nn.Conv2d(dim, dim, 3, padding=1) 65 | 66 | def forward(self, x): 67 | return self.conv(self.up(x)) 68 | 69 | 70 | class Downsample(nn.Module): 71 | def __init__(self, dim): 72 | super().__init__() 73 | self.conv = 
nn.Conv2d(dim, dim, 3, 2, 1) 74 | 75 | def forward(self, x): 76 | return self.conv(x) 77 | 78 | 79 | class Block(nn.Module): 80 | def __init__(self, dim, dim_out, groups=32, dropout=0): 81 | super().__init__() 82 | self.block = nn.Sequential( 83 | nn.GroupNorm(groups, dim), 84 | Swish(), 85 | nn.Dropout(dropout) if dropout != 0 else nn.Identity(), 86 | nn.Conv2d(dim, dim_out, 3, padding=1) 87 | ) 88 | 89 | def forward(self, x): 90 | return self.block(x) 91 | 92 | 93 | class ResnetBlock(nn.Module): 94 | def __init__(self, dim, dim_out, noise_level_emb_dim=None, dropout=0, use_affine_level=False, norm_groups=32): 95 | super().__init__() 96 | self.noise_func = FeatureWiseAffine( 97 | noise_level_emb_dim, dim_out, use_affine_level) 98 | 99 | self.block1 = Block(dim, dim_out, groups=norm_groups) 100 | self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout) 101 | self.res_conv = nn.Conv2d( 102 | dim, dim_out, 1) if dim != dim_out else nn.Identity() 103 | 104 | def forward(self, x, time_emb): 105 | b, c, h, w = x.shape 106 | h = self.block1(x) 107 | h = self.noise_func(h, time_emb) 108 | h = self.block2(h) 109 | return h + self.res_conv(x) 110 | 111 | 112 | class SelfAttention(nn.Module): 113 | def __init__(self, in_channel, n_head=1, norm_groups=32): 114 | super().__init__() 115 | 116 | self.n_head = n_head 117 | 118 | self.norm = nn.GroupNorm(norm_groups, in_channel) 119 | self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False) 120 | self.out = nn.Conv2d(in_channel, in_channel, 1) 121 | 122 | def forward(self, input): 123 | batch, channel, height, width = input.shape 124 | n_head = self.n_head 125 | head_dim = channel // n_head 126 | 127 | norm = self.norm(input) 128 | qkv = self.qkv(norm).view(batch, n_head, head_dim * 3, height, width) 129 | query, key, value = qkv.chunk(3, dim=2) # bhdyx 130 | 131 | attn = torch.einsum( 132 | "bnchw, bncyx -> bnhwyx", query, key 133 | ).contiguous() / math.sqrt(channel) 134 | attn = attn.view(batch, n_head, height, width, -1) 135 | attn = torch.softmax(attn, -1) 136 | attn = attn.view(batch, n_head, height, width, height, width) 137 | 138 | out = torch.einsum("bnhwyx, bncyx -> bnchw", attn, value).contiguous() 139 | out = self.out(out.view(batch, channel, height, width)) 140 | 141 | return out + input 142 | 143 | 144 | class ResnetBlocWithAttn(nn.Module): 145 | def __init__(self, dim, dim_out, *, noise_level_emb_dim=None, norm_groups=32, dropout=0, with_attn=False): 146 | super().__init__() 147 | self.with_attn = with_attn 148 | self.res_block = ResnetBlock( 149 | dim, dim_out, noise_level_emb_dim, norm_groups=norm_groups, dropout=dropout) 150 | if with_attn: 151 | self.attn = SelfAttention(dim_out, norm_groups=norm_groups) 152 | 153 | def forward(self, x, time_emb): 154 | x = self.res_block(x, time_emb) 155 | if(self.with_attn): 156 | x = self.attn(x) 157 | return x 158 | 159 | 160 | class UNet(nn.Module): 161 | def __init__(self, config): 162 | super().__init__() 163 | 164 | self.config = config 165 | in_channel = config.model.in_channel 166 | out_channel = config.model.out_channel 167 | inner_channel = config.model.inner_channel 168 | norm_groups = config.model.norm_groups 169 | channel_mults = tuple(config.model.channel_mults) 170 | attn_res = config.model.attn_res 171 | res_blocks = config.model.res_blocks 172 | dropout = config.model.dropout 173 | with_noise_level_emb = config.model.with_noise_level_emb 174 | image_size = config.model.image_size 175 | 176 | if with_noise_level_emb: 177 | noise_level_channel = inner_channel 
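            # Continuous noise-level conditioning in the SR3/WaveGrad style:
            # the scalar noise level passed to forward() (a square-root
            # cumulative-alpha value, see models/denoising.py) is lifted to a
            # sinusoidal PositionalEncoding and refined by this MLP, then
            # injected into every ResnetBlock through FeatureWiseAffine.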
178 | self.noise_level_mlp = nn.Sequential( 179 | PositionalEncoding(inner_channel), 180 | nn.Linear(inner_channel, inner_channel * 4), 181 | Swish(), 182 | nn.Linear(inner_channel * 4, inner_channel) 183 | ) 184 | else: 185 | noise_level_channel = None 186 | self.noise_level_mlp = None 187 | 188 | num_mults = len(channel_mults) 189 | pre_channel = inner_channel 190 | feat_channels = [pre_channel] 191 | now_res = image_size 192 | downs = [nn.Conv2d(in_channel, inner_channel, 193 | kernel_size=3, padding=1)] 194 | for ind in range(num_mults): 195 | is_last = (ind == num_mults - 1) 196 | use_attn = (now_res in attn_res) 197 | channel_mult = inner_channel * channel_mults[ind] 198 | for _ in range(0, res_blocks): 199 | downs.append(ResnetBlocWithAttn( 200 | pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, dropout=dropout, with_attn=use_attn)) 201 | feat_channels.append(channel_mult) 202 | pre_channel = channel_mult 203 | if not is_last: 204 | downs.append(Downsample(pre_channel)) 205 | feat_channels.append(pre_channel) 206 | now_res = now_res//2 207 | self.downs = nn.ModuleList(downs) 208 | 209 | self.mid = nn.ModuleList([ 210 | ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, 211 | dropout=dropout, with_attn=True), 212 | ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, 213 | dropout=dropout, with_attn=False) 214 | ]) 215 | 216 | ups = [] 217 | for ind in reversed(range(num_mults)): 218 | is_last = (ind < 1) 219 | use_attn = (now_res in attn_res) 220 | channel_mult = inner_channel * channel_mults[ind] 221 | for _ in range(0, res_blocks+1): 222 | ups.append(ResnetBlocWithAttn( 223 | pre_channel+feat_channels.pop(), channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, 224 | dropout=dropout, with_attn=use_attn)) 225 | pre_channel = channel_mult 226 | if not is_last: 227 | ups.append(Upsample(pre_channel)) 228 | now_res = now_res*2 229 | 230 | self.ups = nn.ModuleList(ups) 231 | 232 | self.final_conv = Block(pre_channel, default(out_channel, in_channel), groups=norm_groups) 233 | 234 | def forward(self, x, time): 235 | t = self.noise_level_mlp(time) if exists( 236 | self.noise_level_mlp) else None 237 | 238 | feats = [] 239 | for layer in self.downs: 240 | if isinstance(layer, ResnetBlocWithAttn): 241 | x = layer(x, t) 242 | else: 243 | x = layer(x) 244 | feats.append(x) 245 | 246 | for layer in self.mid: 247 | if isinstance(layer, ResnetBlocWithAttn): 248 | x = layer(x, t) 249 | else: 250 | x = layer(x) 251 | 252 | for layer in self.ups: 253 | if isinstance(layer, ResnetBlocWithAttn): 254 | x = layer(torch.cat((x, feats.pop()), dim=1), t) 255 | else: 256 | x = layer(x) 257 | 258 | return self.final_conv(x) 259 | 260 | 261 | class Model(nn.Module): 262 | def __init__(self, config): 263 | super().__init__() 264 | self.Unet = UNet(config) 265 | 266 | self.states = torch.load(config.srno_weight, weights_only=True) 267 | 268 | self.srno = SRNO(config) 269 | self.srno.load_state_dict(self.states['model']) 270 | self.srno.eval() 271 | 272 | def forward(self, x, time, inp, hr_coord, cell): 273 | with torch.no_grad(): 274 | srno_feat = self.srno(inp, hr_coord, cell) 275 | srno_feat = srno_feat.detach() 276 | x = torch.cat([x, srno_feat], dim=1) 277 | 278 | return self.Unet(x, time) 279 | 280 | -------------------------------------------------------------------------------- /test.py: 
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import yaml
4 | import sys
5 | import torch
6 | import numpy as np
7 | 
8 | from tqdm import tqdm
9 | 
10 | from utils import calc_psnr, calc_ssim
11 | from utils import Averager
12 | from utils import tensor2img, save_img
13 | 
14 | from runners.diffusion import get_current_visuals
15 | from runners.diffusion import sample_image
16 | from runners.diffusion import make_data_loader
17 | 
18 | from datasets import data_transform, inverse_data_transform
19 | 
20 | from models.ema import EMAHelper
21 | from models.diffusion import Model
22 | 
23 | 
24 | def parse_args_and_config():
25 |     parser = argparse.ArgumentParser(description="Super-Resolution Diffusion Model Evaluation")
26 |     parser.add_argument("--config", type=str, required=True, help="Path to the config file")
27 |     parser.add_argument("--model", type=str, required=True, help="Path to the model checkpoint")
28 |     parser.add_argument("--seed", type=int, default=60000, help="Random seed")
29 |     parser.add_argument(
30 |         "--timesteps", type=int, default=50, help="number of sampling steps"
31 |     )
32 |     parser.add_argument(
33 |         "--ni",
34 |         action="store_true",
35 |         help="No interaction. Suitable for Slurm job launchers",
36 |     )
37 |     parser.add_argument(
38 |         "--sample_type",
39 |         type=str,
40 |         default="generalized",
41 |         help="sampling approach (generalized or ddpm_noisy)",
42 |     )
43 |     parser.add_argument(
44 |         "--skip_type",
45 |         type=str,
46 |         default="uniform",
47 |         help="timestep skip schedule (uniform or quad)",
48 |     )
49 |     parser.add_argument(
50 |         "--eta",
51 |         type=float,
52 |         default=0.0,
53 |         help="eta controlling the sampling variance (0 gives deterministic DDIM)",
54 |     )
55 | 
56 |     args = parser.parse_args()
57 | 
58 |     with open(args.config, "r") as f:
59 |         config = yaml.safe_load(f)
60 |     config = dict2namespace(config)
61 | 
62 |     # Set device
63 |     config.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
64 | 
65 |     # set random seed
66 |     torch.manual_seed(args.seed)
67 |     np.random.seed(args.seed)
68 |     if torch.cuda.is_available():
69 |         torch.cuda.manual_seed_all(args.seed)
70 | 
71 |     torch.backends.cudnn.benchmark = True
72 | 
73 |     return args, config
74 | 
75 | 
76 | def dict2namespace(config):
77 |     namespace = argparse.Namespace()
78 |     for key, value in config.items():
79 |         if isinstance(value, dict):
80 |             new_value = dict2namespace(value)
81 |         else:
82 |             new_value = value
83 |         setattr(namespace, key, new_value)
84 |     return namespace
85 | 
86 | 
87 | def eval_psnr(config, args, model, test_loader, result_path, idx):
88 |     model.eval()
89 | 
90 |     with torch.no_grad():
91 |         for data in test_loader:
92 |             gt = data['gt'].to(config.device)
93 |             lr = data['inp'].to(config.device)
94 |             cell = data['cell'].to(config.device)
95 |             hr_coord = data['coord'].to(config.device)
96 | 
97 |             lr = data_transform(config, lr)
98 |             gt = data_transform(config, gt)
99 | 
100 |             x_t = torch.randn_like(gt, device=config.device)
101 | 
102 |             sr = sample_image(config, args, x_t, model, lr, hr_coord, cell).to(config.device)
103 |             visuals = get_current_visuals(sr, data)
104 |             sr_img = tensor2img(visuals['SR'])  # uint8
105 |             hr_img = tensor2img(visuals['GT'])
106 |             save_img(sr_img, '{}/{}_sr.png'.format(result_path, idx))
107 | 
108 |             sr = inverse_data_transform(config, sr)
109 |             gt = inverse_data_transform(config, gt)
110 | 
111 |             psnr = calc_psnr(gt, sr)
112 |             ssim = calc_ssim(sr_img, hr_img)
113 | 
114 |     return psnr, ssim  # metrics for the single image in this loader
115 | 
116 | 
117 | def load_model(config, args, model):
118 |     checkpoint = torch.load(args.model, map_location=config.device, weights_only=True)
119 |     model.load_state_dict(checkpoint[0], strict=True)
120 | 
121 |     if config.model.ema:
122 |         ema_helper = EMAHelper(mu=config.model.ema_rate)
123 |         ema_helper.register(model)
124 |         ema_helper.load_state_dict(checkpoint[-1])
125 |         ema_helper.ema(model)
126 |     else:
127 |         ema_helper = None
128 | 
129 | 
130 | def main(image_path, result_path, idx):
131 |     args, config = parse_args_and_config()
132 | 
133 |     model = Model(config)
134 |     model = model.to(config.device)
135 |     test_loader = make_data_loader(config.test_dataset, image_path, tag='test')
136 | 
137 |     try:
138 |         load_model(config, args, model)
139 |         psnr, ssim = eval_psnr(config, args, model, test_loader, result_path, idx)
140 |     except Exception as e:
141 |         print(f"Error during evaluation: {str(e)}")
142 |         raise
143 | 
144 |     return psnr, ssim
145 | 
146 | if __name__ == "__main__":
147 |     avg_psnr = Averager()
148 |     avg_ssim = Averager()
149 | 
150 |     idx = 0
151 |     result_path = '{}/{}_{}'.format('result', 'test', 'UCM')
152 |     os.makedirs(result_path, exist_ok=True)
153 | 
154 |     args, config = parse_args_and_config()
155 |     root_path = config.test_dataset.dataset.root_path
156 |     print(f"Starting evaluation with checkpoint: {args.model}")
157 |     print(f"Using device: {config.device}")
158 | 
159 |     image_files = [f for f in os.listdir(root_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
160 |     for image in tqdm(image_files, total=len(image_files), desc="Generating image samples for evaluation"):
161 |         image_path = os.path.join(root_path, image)
162 |         psnr, ssim = main(image_path, result_path, idx)
163 | 
164 |         idx += 1
165 |         avg_psnr.add(psnr)
166 |         avg_ssim.add(ssim)
167 | 
168 |     print(f"Avg_PSNR: {avg_psnr.item():6.3f}, Avg_SSIM: {avg_ssim.item():.4f}")
169 | 
170 |     sys.exit(0)
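171 | 
172 | # Example invocation (paths are illustrative, not taken from the repository):
173 | #   python test.py --config configs/test.yml --model step_best.pth --timesteps 50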
--------------------------------------------------------------------------------
/runners/diffusion.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import time
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.utils.data as data
8 | 
9 | from models.diffusion import Model
10 | from models.ema import EMAHelper
11 | from models.optimizer import get_optimizer
12 | from models.losses import loss_registry
13 | from models.denoising import generalized_steps, ddpm_steps
14 | 
15 | from tqdm import tqdm
16 | from init_weight import init_weights
17 | from datasets.image_folder import ImageFolder
18 | from datasets import data_transform, inverse_data_transform
19 | 
20 | from utils import calc_psnr, calc_ssim
21 | from utils import Averager
22 | from utils import tensor2img, save_img
23 | 
24 | from collections import OrderedDict
25 | from torch.utils.data import DataLoader, DistributedSampler
26 | from scheduler import CustomIterationScheduler
27 | 
28 | 
29 | 
30 | def make_data_loader(spec, image_path=None, tag=''):
31 |     if spec is None:
32 |         return None
33 | 
34 |     dataset = ImageFolder(spec.dataset, image_path)
35 |     loader = DataLoader(dataset, batch_size=spec.batch_size,
36 |                         shuffle=(tag == 'train'), num_workers=5, pin_memory=True, persistent_workers=True)
37 | 
38 |     return loader
39 | 
40 | 
41 | def make_data_loaders(config):
42 |     train_loader = make_data_loader(config.train_dataset, tag='train')
43 |     val_loader = make_data_loader(config.val_dataset, tag='val')
44 | 
45 |     return train_loader, val_loader
46 | 
47 | 
48 | def get_current_visuals(SR, data, need_LR=True, sample=False):
49 |     out_dict = OrderedDict()
50 | 
51 |     out_dict['SR'] = SR.detach().float().cpu()
52 |     out_dict['GT'] = data['gt'].detach().float().cpu()
53 |     return out_dict
54 | 
55 | 
56 | def torch2hwcuint8(x, clip=False):
57 |     if clip:
58 |         x = torch.clamp(x, -1, 1)
59 |     x = (x + 1.0) / 2.0
60 |     return x
61 | 
62 | 
63 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
64 |     def sigmoid(x):
65 |         return 1 / (np.exp(-x) + 1)
66 | 
67 |     if beta_schedule == "quad":
68 |         betas = (
69 |             np.linspace(
70 |                 beta_start ** 0.5,
71 |                 beta_end ** 0.5,
72 |                 num_diffusion_timesteps,
73 |                 dtype=np.float64,
74 |             )
75 |             ** 2
76 |         )
77 |     elif beta_schedule == "linear":
78 |         betas = np.linspace(
79 |             beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
80 |         )
81 |     elif beta_schedule == "const":
82 |         betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
83 |     elif beta_schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
84 |         betas = 1.0 / np.linspace(
85 |             num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
86 |         )
87 |     elif beta_schedule == "sigmoid":
88 |         betas = np.linspace(-6, 6, num_diffusion_timesteps)
89 |         betas = sigmoid(betas) * (beta_end - beta_start) + beta_start
90 |     else:
91 |         raise NotImplementedError(beta_schedule)
92 |     assert betas.shape == (num_diffusion_timesteps,)
93 |     return betas
94 | 
95 | 
96 | def sample_image(config, args, x, model, inp, coord, cell, last=True):
97 | 
98 |     try:
99 |         skip = args.skip
100 |     except Exception:
101 |         skip = 1
102 | 
103 |     betas = get_beta_schedule(
104 |         beta_schedule=config.diffusion.beta_schedule,
105 | 
beta_start=config.diffusion.beta_start, 106 | beta_end=config.diffusion.beta_end, 107 | num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps, 108 | ) 109 | betas = torch.from_numpy(betas).float().to(config.device) 110 | num_timesteps = betas.shape[0] 111 | alphas = 1.0 - betas 112 | alphas_cumprod = alphas.cumprod(dim=0) 113 | sqrt_alphas_cumprod_prev = np.sqrt( 114 | np.append(1., alphas_cumprod.cpu().numpy())) 115 | 116 | if args.sample_type == "generalized": 117 | if args.skip_type == "uniform": 118 | skip = num_timesteps // args.timesteps 119 | seq = range(0, num_timesteps, skip) 120 | elif args.skip_type == "quad": 121 | seq = ( 122 | np.linspace( 123 | 0, np.sqrt(num_timesteps * 0.8), args.timesteps 124 | ) 125 | ** 2 126 | ) 127 | seq = [int(s) for s in list(seq)] 128 | else: 129 | raise NotImplementedError 130 | 131 | xs = generalized_steps(x, seq, model, betas, inp, coord, cell, sqrt_alphas_cumprod_prev, eta=args.eta) 132 | x = xs 133 | 134 | elif args.sample_type == "ddpm_noisy": 135 | if args.skip_type == "uniform": 136 | skip = num_timesteps // args.timesteps 137 | seq = range(0, num_timesteps, skip) 138 | elif args.skip_type == "quad": 139 | seq = ( 140 | np.linspace( 141 | 0, np.sqrt(num_timesteps * 0.8), args.timesteps 142 | ) 143 | ** 2 144 | ) 145 | seq = [int(s) for s in list(seq)] 146 | else: 147 | raise NotImplementedError 148 | 149 | x = ddpm_steps(x, seq, model, betas, inp, coord, cell) 150 | 151 | else: 152 | raise NotImplementedError 153 | if last: 154 | x = x[0][-1] 155 | return x 156 | 157 | 158 | class Diffusion(object): 159 | def __init__(self, args, config, device=None): 160 | self.args = args 161 | self.config = config 162 | if device is None: 163 | device = ( 164 | torch.device("cuda") 165 | if torch.cuda.is_available() 166 | else torch.device("cpu") 167 | ) 168 | self.device = device 169 | 170 | self.model_var_type = config.model.var_type 171 | betas = get_beta_schedule( 172 | beta_schedule=config.diffusion.beta_schedule, 173 | beta_start=config.diffusion.beta_start, 174 | beta_end=config.diffusion.beta_end, 175 | num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps, 176 | ) 177 | betas = self.betas = torch.from_numpy(betas).float().to(device) 178 | self.num_timesteps = betas.shape[0] 179 | 180 | alphas = 1.0 - betas 181 | alphas_cumprod = alphas.cumprod(dim=0) 182 | self.sqrt_alphas_cumprod_prev = np.sqrt( 183 | np.append(1., alphas_cumprod.cpu().numpy())) 184 | alphas_cumprod_prev = torch.cat( 185 | [torch.ones(1).to(device), alphas_cumprod[:-1]], dim=0 186 | ) 187 | posterior_variance = ( 188 | betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) 189 | ) 190 | if self.model_var_type == "fixedlarge": 191 | self.logvar = betas.log() 192 | elif self.model_var_type == "fixedsmall": 193 | self.logvar = posterior_variance.clamp(min=1e-20).log() 194 | 195 | def train(self): 196 | args, config = self.args, self.config 197 | tb_logger = self.config.tb_logger 198 | 199 | train_loader, val_loader = make_data_loaders(config) 200 | 201 | model = Model(config) 202 | # init_weights(model.Unet, init_type="orthogonal") 203 | model = model.to(self.device) 204 | 205 | lr_sequence = [float(lr) for lr in config.scheduler.lr_sequence] 206 | step_size = config.scheduler.step_size 207 | 208 | optimizer = get_optimizer(self.config, model.parameters()) 209 | scheduler = CustomIterationScheduler(optimizer, lr_sequence, step_size) 210 | 211 | if self.config.model.ema: 212 | ema_helper = EMAHelper(mu=self.config.model.ema_rate) 213 | 
            ema_helper.register(model)
214 |         else:
215 |             ema_helper = None
216 | 
217 |         max_psnr = -1e18
218 |         start_epoch, step = 1, 1
219 |         if self.args.resume_training:
220 |             states = torch.load(os.path.join(self.args.log_path, "step_best.pth"), weights_only=True)
221 |             model.load_state_dict(states[0])
222 | 
223 |             states[1]["param_groups"][0]["eps"] = self.config.optim.eps
224 |             optimizer.load_state_dict(states[1])
225 |             start_epoch = states[2]+1
226 |             step = states[3]+1
227 |             if self.config.model.ema:
228 |                 ema_helper.load_state_dict(states[4])
229 | 
230 |         for epoch in range(start_epoch, self.config.training.n_epochs+1):
231 |             data_start = time.time()
232 |             data_time = 0
233 |             avg_loss = Averager()
234 | 
235 |             print(f"----------------- Training epoch {epoch} -----------------")
236 |             # iterate over paired LR inputs and HR targets with their coordinates/cells
237 |             for data in tqdm(train_loader, desc="Training"):
238 |                 x = data['gt'].to(self.device)
239 |                 inp = data['inp'].to(self.device)
240 |                 cell = data['cell'].to(self.device)
241 |                 hr_coord = data['coord'].to(self.device)
242 | 
243 |                 x = data_transform(self.config, x)
244 |                 inp = data_transform(self.config, inp)
245 | 
246 |                 n = x.size(0)
247 |                 data_time += time.time() - data_start
248 |                 model.train()
249 | 
250 |                 e = torch.randn_like(x).to(self.device)
251 |                 b = self.betas
252 |                 t_ = np.random.randint(1, self.num_timesteps + 1)
253 |                 continuous_sqrt_alpha_cumprod = torch.FloatTensor(
254 |                     np.random.uniform(
255 |                         self.sqrt_alphas_cumprod_prev[t_-1],
256 |                         self.sqrt_alphas_cumprod_prev[t_],
257 |                         size=n
258 |                     )
259 |                 ).to(self.device)
260 |                 continuous_sqrt_alpha_cumprod = continuous_sqrt_alpha_cumprod.view(n, -1)
261 |                 x_t = continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1) * x + (1 - continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1)**2).sqrt() * e
262 | 
263 |                 # NOTE: the loss below conditions the model on the continuous
264 |                 # noise level sampled above (continuous_sqrt_alpha_cumprod),
265 |                 # in the SR3 style, rather than on a discrete timestep index,
266 |                 # so no integer timestep tensor needs to be drawn for the
267 |                 # forward pass.
268 | 
269 |                 loss = loss_registry[config.model.type](model, x_t, inp, hr_coord, cell, e, b, continuous_sqrt_alpha_cumprod)
270 |                 avg_loss.add(loss.item())
271 |                 tb_logger.add_scalar("loss", loss.item(), global_step=step)
272 |                 # print(f"Training... step: {step:7d}, pix_loss: {loss.item():10.7f}")
273 | 
274 |                 optimizer.zero_grad()
275 |                 loss.backward()
276 | 
277 |                 try:
278 |                     torch.nn.utils.clip_grad_norm_(
279 |                         model.parameters(), config.optim.grad_clip
280 |                     )
281 |                 except Exception:
282 |                     pass
283 |                 optimizer.step()
284 | 
285 |                 if self.config.model.ema:
286 |                     ema_helper.update(model)
287 | 
288 |                 data_start = time.time()
289 | 
290 |                 # periodic validation and checkpointing
291 |                 if step % 20000 == 0:
292 |                     states = [
293 |                         model.state_dict(),
294 |                         optimizer.state_dict(),
295 |                         epoch,
296 |                         step,
297 |                     ]
298 |                     if self.config.model.ema:
299 |                         states.append(ema_helper.state_dict())
300 | 
301 |                     print("-------------------------- Begin validation ----------------------------")
302 | 
303 |                     current_psnr, ssim = self.validate(config, args, model, val_loader, epoch)
304 |                     logging.info(
305 |                         f"Validating... epoch: {epoch:2d} PSNR: {current_psnr:12.6f}, SSIM: {ssim:.4f}"
306 |                     )
307 | 
308 |                     print("------------------------ Model is being saved --------------------------")
309 | 
310 |                     torch.save(states, os.path.join(self.args.log_path, 'step_{}.pth'.format(step)))
311 |                     if current_psnr > max_psnr:
312 |                         max_psnr = current_psnr
313 |                         torch.save(states, os.path.join(self.args.log_path, 'step_best.pth'))
314 | 
315 |                     print("--------------------------- End of saving ------------------------------")
316 | 
317 |                 logging.info(f"Current learning rate: {optimizer.param_groups[0]['lr']}")
318 | 
319 |                 scheduler.step(step)
320 |                 step += 1
321 | 
322 |             logging.info(f"Training... epoch: {epoch:2d} Loss: {avg_loss.item():12.6f}")
323 | 
324 |     def validate(self, config, args, model, val_loader, epoch):
325 |         model.eval()
326 |         avg_psnr = Averager()
327 |         avg_ssim = Averager()
328 |         idx = 0
329 | 
330 |         result_path = '{}/{}'.format('result', epoch)
331 |         os.makedirs(result_path, exist_ok=True)
332 | 
333 |         with torch.no_grad():
334 |             for data in tqdm(val_loader, desc="Generating image samples for PSNR validation"):
335 |                 idx += 1
336 |                 gt = data['gt'].to(self.device)
337 |                 lr = data['inp'].to(self.device)
338 |                 cell = data['cell'].to(self.device)
339 |                 hr_coord = data['coord'].to(self.device)
340 | 
341 |                 gt = data_transform(config, gt)
342 |                 lr = data_transform(config, lr)
343 | 
344 |                 x_t = torch.randn_like(gt, device=self.device)
345 |                 sr = sample_image(config, args, x_t, model, lr, hr_coord, cell).to(self.device)
346 |                 visuals = get_current_visuals(sr, data)
347 |                 sr_img = tensor2img(visuals['SR'])  # uint8
348 |                 hr_img = tensor2img(visuals['GT'])
349 |                 save_img(sr_img, '{}/{}_sr.png'.format(result_path, idx))
350 | 
351 |                 gt = inverse_data_transform(config, gt)
352 |                 sr = inverse_data_transform(config, sr)
353 | 
354 |                 psnr = calc_psnr(gt, sr)
355 |                 ssim = calc_ssim(sr_img, hr_img)
356 |                 avg_psnr.add(psnr)
357 |                 avg_ssim.add(ssim)
358 | 
359 |         return avg_psnr.item(), avg_ssim.item()
360 | 
--------------------------------------------------------------------------------
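Note on the training objective: Diffusion.train() above uses the SR3/WaveGrad-style continuous noise-level parameterization, forming x_t in closed form from the clean image and a noise level drawn between two adjacent discrete levels of the schedule. Below is a minimal standalone sketch of that forward noising; it is not part of the repository, and the linear schedule constants and tensor shapes are assumed purely for illustration (the real values come from config.diffusion).

import numpy as np
import torch

# assumed linear beta schedule, for illustration only
betas = np.linspace(1e-4, 2e-2, 1000, dtype=np.float64)
alphas_cumprod = np.cumprod(1.0 - betas)
# prepend 1.0 so index t-1 is the level "before" timestep t,
# mirroring sqrt_alphas_cumprod_prev in Diffusion.__init__
sqrt_alphas_cumprod_prev = np.sqrt(np.append(1.0, alphas_cumprod))

x0 = torch.randn(4, 3, 64, 64)   # stand-in for a transformed HR batch
eps = torch.randn_like(x0)       # noise target, called e in the trainer

# draw a continuous sqrt(alpha_cumprod) between two adjacent discrete levels
t = np.random.randint(1, len(betas) + 1)
a = torch.from_numpy(
    np.random.uniform(
        sqrt_alphas_cumprod_prev[t - 1],
        sqrt_alphas_cumprod_prev[t],
        size=x0.shape[0],
    )
).float().view(-1, 1, 1, 1)

# closed-form forward noising, matching x_t in Diffusion.train()
x_t = a * x0 + (1.0 - a ** 2).sqrt() * eps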