├── .DS_Store ├── .gitignore ├── .gitmodules ├── Docs ├── lensless_teaser.jpg └── overview.png ├── NullSpaceDiff ├── README.md ├── app.py ├── basicsr │ ├── __init__.py │ ├── archs │ │ ├── __init__.py │ │ ├── arch_util.py │ │ ├── basicvsr_arch.py │ │ ├── basicvsrpp_arch.py │ │ ├── degradat_arch.py │ │ ├── dfdnet_arch.py │ │ ├── dfdnet_util.py │ │ ├── discriminator_arch.py │ │ ├── duf_arch.py │ │ ├── ecbsr_arch.py │ │ ├── edsr_arch.py │ │ ├── edvr_arch.py │ │ ├── hifacegan_arch.py │ │ ├── hifacegan_util.py │ │ ├── inception.py │ │ ├── rcan_arch.py │ │ ├── ridnet_arch.py │ │ ├── rrdbnet_arch.py │ │ ├── spynet_arch.py │ │ ├── srresnet_arch.py │ │ ├── srvgg_arch.py │ │ ├── stylegan2_arch.py │ │ ├── stylegan2_bilinear_arch.py │ │ ├── swinir_arch.py │ │ ├── tof_arch.py │ │ └── vgg_arch.py │ ├── data │ │ ├── __init__.py │ │ ├── data_sampler.py │ │ ├── data_util.py │ │ ├── degradations.py │ │ ├── ffhq_dataset.py │ │ ├── ffhq_degradation_dataset.py │ │ ├── paired_image_dataset.py │ │ ├── prefetch_dataloader.py │ │ ├── realesrgan_dataset.py │ │ ├── realesrgan_paired_dataset.py │ │ ├── reds_dataset.py │ │ ├── single_image_dataset.py │ │ ├── transforms.py │ │ ├── video_test_dataset.py │ │ └── vimeo90k_dataset.py │ ├── losses │ │ ├── __init__.py │ │ ├── basic_loss.py │ │ ├── gan_loss.py │ │ └── loss_util.py │ ├── metrics │ │ ├── README.md │ │ ├── README_CN.md │ │ ├── __init__.py │ │ ├── fid.py │ │ ├── metric_util.py │ │ ├── niqe.py │ │ ├── niqe_pris_params.npz │ │ ├── psnr_ssim.py │ │ └── test_metrics │ │ │ └── test_psnr_ssim.py │ ├── models │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── edvr_model.py │ │ ├── esrgan_model.py │ │ ├── hifacegan_model.py │ │ ├── lr_scheduler.py │ │ ├── realesrgan_model.py │ │ ├── realesrnet_model.py │ │ ├── sr_model.py │ │ ├── srgan_model.py │ │ ├── stylegan2_model.py │ │ ├── swinir_model.py │ │ ├── video_base_model.py │ │ ├── video_gan_model.py │ │ ├── video_recurrent_gan_model.py │ │ └── video_recurrent_model.py │ ├── ops │ │ ├── __init__.py │ │ ├── dcn │ │ │ ├── __init__.py │ │ │ └── deform_conv.py │ │ ├── fused_act │ │ │ ├── __init__.py │ │ │ └── fused_act.py │ │ └── upfirdn2d │ │ │ ├── __init__.py │ │ │ └── upfirdn2d.py │ ├── test.py │ ├── train.py │ └── utils │ │ ├── __init__.py │ │ ├── color_util.py │ │ ├── diffjpeg.py │ │ ├── dist_util.py │ │ ├── download_util.py │ │ ├── file_client.py │ │ ├── flow_util.py │ │ ├── img_process_util.py │ │ ├── img_util.py │ │ ├── lmdb_util.py │ │ ├── logger.py │ │ ├── matlab_functions.py │ │ ├── misc.py │ │ ├── options.py │ │ ├── plot_util.py │ │ ├── realesrgan_utils.py │ │ └── registry.py ├── cog.yaml ├── configs │ ├── NullSpaceDiff │ │ └── phlatcam_decoded_sim_multi_T_512.yaml │ └── autoencoder │ │ ├── autoencoder_kl_64x64x4_resi.yaml │ │ └── autoencoder_kl_64x64x4_resi_face.yaml ├── environment.yaml ├── ldm │ ├── data │ │ ├── __init__.py │ │ ├── base.py │ │ ├── imagenet.py │ │ └── lsun.py │ ├── lr_scheduler.py │ ├── models │ │ ├── autoencoder.py │ │ ├── diffusion │ │ │ ├── __init__.py │ │ │ ├── classifier.py │ │ │ ├── ddim copy.py │ │ │ ├── ddim.py │ │ │ ├── ddim_with_grad.py │ │ │ ├── ddnm.py │ │ │ ├── ddpm.py │ │ │ ├── ddpm_cond.py │ │ │ ├── ddpm_inv.py │ │ │ └── plms.py │ │ └── respace.py │ ├── modules │ │ ├── attention.py │ │ ├── diffusionmodules │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ ├── openaimodel.py │ │ │ └── util.py │ │ ├── distributions │ │ │ ├── __init__.py │ │ │ └── distributions.py │ │ ├── ema.py │ │ ├── embedding_manager.py │ │ ├── encoders │ │ │ ├── __init__.py │ │ │ ├── modules.py │ │ │ └── 
transformer_utils.py │ │ ├── fftlayer.py │ │ ├── image_degradation │ │ │ ├── __init__.py │ │ │ ├── bsrgan.py │ │ │ ├── bsrgan_light.py │ │ │ ├── utils │ │ │ │ └── test.png │ │ │ └── utils_image.py │ │ ├── losses │ │ │ ├── __init__.py │ │ │ ├── contperceptual.py │ │ │ └── vqperceptual.py │ │ ├── spade.py │ │ ├── swinir.py │ │ └── x_transformer.py │ └── util.py ├── main.py ├── musiq │ ├── model │ │ ├── multiscale_transformer.py │ │ ├── multiscale_transformer_utils.py │ │ ├── preprocessing.py │ │ └── resnet.py │ ├── requirements.txt │ ├── run_predict_folder.py │ └── run_predict_image.py ├── predict.py ├── scripts │ ├── helper.py │ ├── infer.sh │ ├── sr_val_ddpm_lensless.py │ ├── train.sh │ ├── util_image.py │ └── wavelet_color_fix.py ├── setup.py └── utils │ └── lensless_utils.py ├── README.md ├── SVDeconv ├── config.py ├── config_diffusercam.py ├── data ├── dataloader.py ├── datasets │ ├── diffusercam.py │ └── phlatcam.py ├── loss.py ├── metrics.py ├── models │ ├── __init__.py │ ├── discriminator.py │ ├── fftlayer.py │ ├── fftlayer_diff.py │ ├── fftlayer_diff_original.py │ ├── get_model.py │ ├── multi_fftlayer.py │ ├── multi_fftlayer_diff.py │ ├── multi_fftlayer_new.py │ ├── unet.py │ └── unet_128.py ├── scripts │ └── train.sh ├── tools │ ├── decode_and_sim_rgb.py │ └── decode_and_sim_rgb_diff_padding.py ├── train.py ├── train_diff.py ├── utils │ ├── __init__.py │ ├── checkpoint_excluder.txt │ ├── contextual_loss.py │ ├── dir_helper.py │ ├── model_serialization.py │ ├── ops.py │ ├── train_helper.py │ ├── tupperware.py │ └── typing_alias.py └── val.py ├── requirements.txt └── tools ├── copy_gt.py └── data_process.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenImagingLab/PhoCoLens/154fe32aea5c2b623f6dd1c07e90c2900d076486/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | StableSR/* 2 | */data/* 3 | *ckpts* 4 | NullSpaceDiff/data_process/* 5 | NullSpaceDiff/logs/* 6 | NullSpaceDiff/scripts_bk/* 7 | __pycache__/* 8 | */__pycache__/* 9 | SVDeconv/output* 10 | SVDeconv/runs* 11 | SVDeconv/data/* 12 | *.pyc 13 | */logs -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenImagingLab/PhoCoLens/154fe32aea5c2b623f6dd1c07e90c2900d076486/.gitmodules -------------------------------------------------------------------------------- /Docs/lensless_teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenImagingLab/PhoCoLens/154fe32aea5c2b623f6dd1c07e90c2900d076486/Docs/lensless_teaser.jpg -------------------------------------------------------------------------------- /Docs/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenImagingLab/PhoCoLens/154fe32aea5c2b623f6dd1c07e90c2900d076486/Docs/overview.png -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/__init__.py: -------------------------------------------------------------------------------- 1 | # https://github.com/xinntao/BasicSR 2 | # flake8: noqa 3 | from .archs import * 4 | from .data import * 5 | from .losses import * 6 | from .metrics import * 7 | from .models 
import * 8 | from .ops import * 9 | from .test import * 10 | from .train import * 11 | from .utils import * 12 | # from .version import __gitsha__, __version__ 13 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import ARCH_REGISTRY 7 | 8 | __all__ = ['build_network'] 9 | 10 | # automatically scan and import arch modules for registry 11 | # scan all the files under the 'archs' folder and collect files ending with '_arch.py' 12 | arch_folder = osp.dirname(osp.abspath(__file__)) 13 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 14 | # import all the arch modules 15 | _arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames] 16 | 17 | 18 | def build_network(opt): 19 | opt = deepcopy(opt) 20 | network_type = opt.pop('type') 21 | net = ARCH_REGISTRY.get(network_type)(**opt) 22 | logger = get_root_logger() 23 | logger.info(f'Network [{net.__class__.__name__}] is created.') 24 | return net 25 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/archs/degradat_arch.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | 3 | from basicsr.archs.arch_util import ResidualBlockNoBN, default_init_weights 4 | from basicsr.utils.registry import ARCH_REGISTRY 5 | 6 | @ARCH_REGISTRY.register() 7 | class DEResNet(nn.Module): 8 | """Degradation Estimator with ResNetNoBN arch. v2.1, no vector anymore 9 | As shown in paper 'Towards Flexible Blind JPEG Artifacts Removal', 10 | resnet arch works for image quality estimation. 11 | Args: 12 | num_in_ch (int): channel number of inputs. Default: 3. 13 | num_degradation (int): num of degradation the DE should estimate. Default: 2(blur+noise). 14 | degradation_embed_size (int): embedding size of each degradation vector. 15 | degradation_degree_actv (int): activation function for degradation degree scalar. Default: sigmoid. 16 | num_feats (list): channel number of each stage. 17 | num_blocks (list): residual block of each stage. 18 | downscales (list): downscales of each stage. 
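        Example (a minimal usage sketch; the batch and input size are assumptions,
        and list-typed ``num_feats``/``num_blocks``/``downscales`` are passed
        explicitly to satisfy the ``isinstance`` asserts in ``__init__``):
            >>> import torch
            >>> net = DEResNet(num_in_ch=3, num_degradation=2,
            ...                num_feats=[64, 128, 256, 512],
            ...                num_blocks=[2, 2, 2, 2],
            ...                downscales=[2, 2, 2, 1])
            >>> degrees = net(torch.rand(4, 3, 128, 128))
            >>> len(degrees), degrees[0].shape  # one scalar per degradation and sample
            (2, torch.Size([4]))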
19 | """ 20 | 21 | def __init__(self, 22 | num_in_ch=3, 23 | num_degradation=2, 24 | degradation_degree_actv='sigmoid', 25 | num_feats=(64, 128, 256, 512), 26 | num_blocks=(2, 2, 2, 2), 27 | downscales=(2, 2, 2, 1)): 28 | super(DEResNet, self).__init__() 29 | 30 | assert isinstance(num_feats, list) 31 | assert isinstance(num_blocks, list) 32 | assert isinstance(downscales, list) 33 | assert len(num_feats) == len(num_blocks) and len(num_feats) == len(downscales) 34 | 35 | num_stage = len(num_feats) 36 | 37 | self.conv_first = nn.ModuleList() 38 | for _ in range(num_degradation): 39 | self.conv_first.append(nn.Conv2d(num_in_ch, num_feats[0], 3, 1, 1)) 40 | self.body = nn.ModuleList() 41 | for _ in range(num_degradation): 42 | body = list() 43 | for stage in range(num_stage): 44 | for _ in range(num_blocks[stage]): 45 | body.append(ResidualBlockNoBN(num_feats[stage])) 46 | if downscales[stage] == 1: 47 | if stage < num_stage - 1 and num_feats[stage] != num_feats[stage + 1]: 48 | body.append(nn.Conv2d(num_feats[stage], num_feats[stage + 1], 3, 1, 1)) 49 | continue 50 | elif downscales[stage] == 2: 51 | body.append(nn.Conv2d(num_feats[stage], num_feats[min(stage + 1, num_stage - 1)], 3, 2, 1)) 52 | else: 53 | raise NotImplementedError 54 | self.body.append(nn.Sequential(*body)) 55 | 56 | # self.body = nn.Sequential(*body) 57 | 58 | self.num_degradation = num_degradation 59 | self.fc_degree = nn.ModuleList() 60 | if degradation_degree_actv == 'sigmoid': 61 | actv = nn.Sigmoid 62 | elif degradation_degree_actv == 'tanh': 63 | actv = nn.Tanh 64 | else: 65 | raise NotImplementedError(f'only sigmoid and tanh are supported for degradation_degree_actv, ' 66 | f'{degradation_degree_actv} is not supported yet.') 67 | for _ in range(num_degradation): 68 | self.fc_degree.append( 69 | nn.Sequential( 70 | nn.Linear(num_feats[-1], 512), 71 | nn.ReLU(inplace=True), 72 | nn.Linear(512, 1), 73 | actv(), 74 | )) 75 | 76 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 77 | 78 | default_init_weights([self.conv_first, self.body, self.fc_degree], 0.1) 79 | 80 | def forward(self, x): 81 | degrees = [] 82 | for i in range(self.num_degradation): 83 | x_out = self.conv_first[i](x) 84 | feat = self.body[i](x_out) 85 | feat = self.avg_pool(feat) 86 | feat = feat.squeeze(-1).squeeze(-1) 87 | # for i in range(self.num_degradation): 88 | degrees.append(self.fc_degree[i](feat).squeeze(-1)) 89 | 90 | return degrees 91 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/archs/dfdnet_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Function 5 | from torch.nn.utils.spectral_norm import spectral_norm 6 | 7 | 8 | class BlurFunctionBackward(Function): 9 | 10 | @staticmethod 11 | def forward(ctx, grad_output, kernel, kernel_flip): 12 | ctx.save_for_backward(kernel, kernel_flip) 13 | grad_input = F.conv2d(grad_output, kernel_flip, padding=1, groups=grad_output.shape[1]) 14 | return grad_input 15 | 16 | @staticmethod 17 | def backward(ctx, gradgrad_output): 18 | kernel, _ = ctx.saved_tensors 19 | grad_input = F.conv2d(gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1]) 20 | return grad_input, None, None 21 | 22 | 23 | class BlurFunction(Function): 24 | 25 | @staticmethod 26 | def forward(ctx, x, kernel, kernel_flip): 27 | ctx.save_for_backward(kernel, kernel_flip) 28 | output = F.conv2d(x, kernel, padding=1, 
groups=x.shape[1]) 29 | return output 30 | 31 | @staticmethod 32 | def backward(ctx, grad_output): 33 | kernel, kernel_flip = ctx.saved_tensors 34 | grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip) 35 | return grad_input, None, None 36 | 37 | 38 | blur = BlurFunction.apply 39 | 40 | 41 | class Blur(nn.Module): 42 | 43 | def __init__(self, channel): 44 | super().__init__() 45 | kernel = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32) 46 | kernel = kernel.view(1, 1, 3, 3) 47 | kernel = kernel / kernel.sum() 48 | kernel_flip = torch.flip(kernel, [2, 3]) 49 | 50 | self.kernel = kernel.repeat(channel, 1, 1, 1) 51 | self.kernel_flip = kernel_flip.repeat(channel, 1, 1, 1) 52 | 53 | def forward(self, x): 54 | return blur(x, self.kernel.type_as(x), self.kernel_flip.type_as(x)) 55 | 56 | 57 | def calc_mean_std(feat, eps=1e-5): 58 | """Calculate mean and std for adaptive_instance_normalization. 59 | 60 | Args: 61 | feat (Tensor): 4D tensor. 62 | eps (float): A small value added to the variance to avoid 63 | divide-by-zero. Default: 1e-5. 64 | """ 65 | size = feat.size() 66 | assert len(size) == 4, 'The input feature should be 4D tensor.' 67 | n, c = size[:2] 68 | feat_var = feat.view(n, c, -1).var(dim=2) + eps 69 | feat_std = feat_var.sqrt().view(n, c, 1, 1) 70 | feat_mean = feat.view(n, c, -1).mean(dim=2).view(n, c, 1, 1) 71 | return feat_mean, feat_std 72 | 73 | 74 | def adaptive_instance_normalization(content_feat, style_feat): 75 | """Adaptive instance normalization. 76 | 77 | Adjust the reference features to have the similar color and illuminations 78 | as those in the degradate features. 79 | 80 | Args: 81 | content_feat (Tensor): The reference feature. 82 | style_feat (Tensor): The degradate features. 83 | """ 84 | size = content_feat.size() 85 | style_mean, style_std = calc_mean_std(style_feat) 86 | content_mean, content_std = calc_mean_std(content_feat) 87 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) 88 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 89 | 90 | 91 | def AttentionBlock(in_channel): 92 | return nn.Sequential( 93 | spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True), 94 | spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1))) 95 | 96 | 97 | def conv_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True): 98 | """Conv block used in MSDilationBlock.""" 99 | 100 | return nn.Sequential( 101 | spectral_norm( 102 | nn.Conv2d( 103 | in_channels, 104 | out_channels, 105 | kernel_size=kernel_size, 106 | stride=stride, 107 | dilation=dilation, 108 | padding=((kernel_size - 1) // 2) * dilation, 109 | bias=bias)), 110 | nn.LeakyReLU(0.2), 111 | spectral_norm( 112 | nn.Conv2d( 113 | out_channels, 114 | out_channels, 115 | kernel_size=kernel_size, 116 | stride=stride, 117 | dilation=dilation, 118 | padding=((kernel_size - 1) // 2) * dilation, 119 | bias=bias)), 120 | ) 121 | 122 | 123 | class MSDilationBlock(nn.Module): 124 | """Multi-scale dilation block.""" 125 | 126 | def __init__(self, in_channels, kernel_size=3, dilation=(1, 1, 1, 1), bias=True): 127 | super(MSDilationBlock, self).__init__() 128 | 129 | self.conv_blocks = nn.ModuleList() 130 | for i in range(4): 131 | self.conv_blocks.append(conv_block(in_channels, in_channels, kernel_size, dilation=dilation[i], bias=bias)) 132 | self.conv_fusion = spectral_norm( 133 | nn.Conv2d( 134 | in_channels * 4, 135 | in_channels, 136 | kernel_size=kernel_size, 137 | 
stride=1, 138 | padding=(kernel_size - 1) // 2, 139 | bias=bias)) 140 | 141 | def forward(self, x): 142 | out = [] 143 | for i in range(4): 144 | out.append(self.conv_blocks[i](x)) 145 | out = torch.cat(out, 1) 146 | out = self.conv_fusion(out) + x 147 | return out 148 | 149 | 150 | class UpResBlock(nn.Module): 151 | 152 | def __init__(self, in_channel): 153 | super(UpResBlock, self).__init__() 154 | self.body = nn.Sequential( 155 | nn.Conv2d(in_channel, in_channel, 3, 1, 1), 156 | nn.LeakyReLU(0.2, True), 157 | nn.Conv2d(in_channel, in_channel, 3, 1, 1), 158 | ) 159 | 160 | def forward(self, x): 161 | out = x + self.body(x) 162 | return out 163 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/archs/edsr_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | from basicsr.archs.arch_util import ResidualBlockNoBN, Upsample, make_layer 5 | from basicsr.utils.registry import ARCH_REGISTRY 6 | 7 | 8 | @ARCH_REGISTRY.register() 9 | class EDSR(nn.Module): 10 | """EDSR network structure. 11 | 12 | Paper: Enhanced Deep Residual Networks for Single Image Super-Resolution. 13 | Ref git repo: https://github.com/thstkdgus35/EDSR-PyTorch 14 | 15 | Args: 16 | num_in_ch (int): Channel number of inputs. 17 | num_out_ch (int): Channel number of outputs. 18 | num_feat (int): Channel number of intermediate features. 19 | Default: 64. 20 | num_block (int): Block number in the trunk network. Default: 16. 21 | upscale (int): Upsampling factor. Support 2^n and 3. 22 | Default: 4. 23 | res_scale (float): Used to scale the residual in residual block. 24 | Default: 1. 25 | img_range (float): Image range. Default: 255. 26 | rgb_mean (tuple[float]): Image mean in RGB orders. 27 | Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset. 28 | """ 29 | 30 | def __init__(self, 31 | num_in_ch, 32 | num_out_ch, 33 | num_feat=64, 34 | num_block=16, 35 | upscale=4, 36 | res_scale=1, 37 | img_range=255., 38 | rgb_mean=(0.4488, 0.4371, 0.4040)): 39 | super(EDSR, self).__init__() 40 | 41 | self.img_range = img_range 42 | self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) 43 | 44 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 45 | self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat, res_scale=res_scale, pytorch_init=True) 46 | self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 47 | self.upsample = Upsample(upscale, num_feat) 48 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 49 | 50 | def forward(self, x): 51 | self.mean = self.mean.type_as(x) 52 | 53 | x = (x - self.mean) * self.img_range 54 | x = self.conv_first(x) 55 | res = self.conv_after_body(self.body(x)) 56 | res += x 57 | 58 | x = self.conv_last(self.upsample(res)) 59 | x = x / self.img_range + self.mean 60 | 61 | return x 62 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/archs/rcan_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | from basicsr.utils.registry import ARCH_REGISTRY 5 | from .arch_util import Upsample, make_layer 6 | 7 | 8 | class ChannelAttention(nn.Module): 9 | """Channel attention used in RCAN. 10 | 11 | Args: 12 | num_feat (int): Channel number of intermediate features. 13 | squeeze_factor (int): Channel squeeze factor. Default: 16. 
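        Example (a minimal shape-check sketch; the 32x32 feature map is an assumption):
            >>> import torch
            >>> ca = ChannelAttention(num_feat=64, squeeze_factor=16)
            >>> ca(torch.rand(1, 64, 32, 32)).shape  # per-channel rescaling keeps the shape
            torch.Size([1, 64, 32, 32])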
14 | """ 15 | 16 | def __init__(self, num_feat, squeeze_factor=16): 17 | super(ChannelAttention, self).__init__() 18 | self.attention = nn.Sequential( 19 | nn.AdaptiveAvgPool2d(1), nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0), 20 | nn.ReLU(inplace=True), nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), nn.Sigmoid()) 21 | 22 | def forward(self, x): 23 | y = self.attention(x) 24 | return x * y 25 | 26 | 27 | class RCAB(nn.Module): 28 | """Residual Channel Attention Block (RCAB) used in RCAN. 29 | 30 | Args: 31 | num_feat (int): Channel number of intermediate features. 32 | squeeze_factor (int): Channel squeeze factor. Default: 16. 33 | res_scale (float): Scale the residual. Default: 1. 34 | """ 35 | 36 | def __init__(self, num_feat, squeeze_factor=16, res_scale=1): 37 | super(RCAB, self).__init__() 38 | self.res_scale = res_scale 39 | 40 | self.rcab = nn.Sequential( 41 | nn.Conv2d(num_feat, num_feat, 3, 1, 1), nn.ReLU(True), nn.Conv2d(num_feat, num_feat, 3, 1, 1), 42 | ChannelAttention(num_feat, squeeze_factor)) 43 | 44 | def forward(self, x): 45 | res = self.rcab(x) * self.res_scale 46 | return res + x 47 | 48 | 49 | class ResidualGroup(nn.Module): 50 | """Residual Group of RCAB. 51 | 52 | Args: 53 | num_feat (int): Channel number of intermediate features. 54 | num_block (int): Block number in the body network. 55 | squeeze_factor (int): Channel squeeze factor. Default: 16. 56 | res_scale (float): Scale the residual. Default: 1. 57 | """ 58 | 59 | def __init__(self, num_feat, num_block, squeeze_factor=16, res_scale=1): 60 | super(ResidualGroup, self).__init__() 61 | 62 | self.residual_group = make_layer( 63 | RCAB, num_block, num_feat=num_feat, squeeze_factor=squeeze_factor, res_scale=res_scale) 64 | self.conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 65 | 66 | def forward(self, x): 67 | res = self.conv(self.residual_group(x)) 68 | return res + x 69 | 70 | 71 | @ARCH_REGISTRY.register() 72 | class RCAN(nn.Module): 73 | """Residual Channel Attention Networks. 74 | 75 | ``Paper: Image Super-Resolution Using Very Deep Residual Channel Attention Networks`` 76 | 77 | Reference: https://github.com/yulunzhang/RCAN 78 | 79 | Args: 80 | num_in_ch (int): Channel number of inputs. 81 | num_out_ch (int): Channel number of outputs. 82 | num_feat (int): Channel number of intermediate features. 83 | Default: 64. 84 | num_group (int): Number of ResidualGroup. Default: 10. 85 | num_block (int): Number of RCAB in ResidualGroup. Default: 16. 86 | squeeze_factor (int): Channel squeeze factor. Default: 16. 87 | upscale (int): Upsampling factor. Support 2^n and 3. 88 | Default: 4. 89 | res_scale (float): Used to scale the residual in residual block. 90 | Default: 1. 91 | img_range (float): Image range. Default: 255. 92 | rgb_mean (tuple[float]): Image mean in RGB orders. 93 | Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset. 
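        Example (a minimal x4 super-resolution sketch; the 64x64 input is an assumption):
            >>> import torch
            >>> net = RCAN(num_in_ch=3, num_out_ch=3, num_feat=64, num_group=10,
            ...            num_block=16, squeeze_factor=16, upscale=4)
            >>> net(torch.rand(1, 3, 64, 64)).shape  # upscale=4 -> 4x spatial size
            torch.Size([1, 3, 256, 256])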
94 | """ 95 | 96 | def __init__(self, 97 | num_in_ch, 98 | num_out_ch, 99 | num_feat=64, 100 | num_group=10, 101 | num_block=16, 102 | squeeze_factor=16, 103 | upscale=4, 104 | res_scale=1, 105 | img_range=255., 106 | rgb_mean=(0.4488, 0.4371, 0.4040)): 107 | super(RCAN, self).__init__() 108 | 109 | self.img_range = img_range 110 | self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) 111 | 112 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 113 | self.body = make_layer( 114 | ResidualGroup, 115 | num_group, 116 | num_feat=num_feat, 117 | num_block=num_block, 118 | squeeze_factor=squeeze_factor, 119 | res_scale=res_scale) 120 | self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 121 | self.upsample = Upsample(upscale, num_feat) 122 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 123 | 124 | def forward(self, x): 125 | self.mean = self.mean.type_as(x) 126 | 127 | x = (x - self.mean) * self.img_range 128 | x = self.conv_first(x) 129 | res = self.conv_after_body(self.body(x)) 130 | res += x 131 | 132 | x = self.conv_last(self.upsample(res)) 133 | x = x / self.img_range + self.mean 134 | 135 | return x 136 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/archs/rrdbnet_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | 5 | from basicsr.utils.registry import ARCH_REGISTRY 6 | from .arch_util import default_init_weights, make_layer, pixel_unshuffle 7 | 8 | 9 | class ResidualDenseBlock(nn.Module): 10 | """Residual Dense Block. 11 | 12 | Used in RRDB block in ESRGAN. 13 | 14 | Args: 15 | num_feat (int): Channel number of intermediate features. 16 | num_grow_ch (int): Channels for each growth. 17 | """ 18 | 19 | def __init__(self, num_feat=64, num_grow_ch=32): 20 | super(ResidualDenseBlock, self).__init__() 21 | self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) 22 | self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) 23 | self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) 24 | self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) 25 | self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) 26 | 27 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 28 | 29 | # initialization 30 | default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 31 | 32 | def forward(self, x): 33 | x1 = self.lrelu(self.conv1(x)) 34 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 35 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 36 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 37 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 38 | # Empirically, we use 0.2 to scale the residual for better performance 39 | return x5 * 0.2 + x 40 | 41 | 42 | class RRDB(nn.Module): 43 | """Residual in Residual Dense Block. 44 | 45 | Used in RRDB-Net in ESRGAN. 46 | 47 | Args: 48 | num_feat (int): Channel number of intermediate features. 49 | num_grow_ch (int): Channels for each growth. 
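        Example (a minimal shape-check sketch; the feature map size is an assumption):
            >>> import torch
            >>> block = RRDB(num_feat=64, num_grow_ch=32)
            >>> block(torch.rand(1, 64, 32, 32)).shape  # residual block keeps the shape
            torch.Size([1, 64, 32, 32])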
50 | """ 51 | 52 | def __init__(self, num_feat, num_grow_ch=32): 53 | super(RRDB, self).__init__() 54 | self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) 55 | self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) 56 | self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) 57 | 58 | def forward(self, x): 59 | out = self.rdb1(x) 60 | out = self.rdb2(out) 61 | out = self.rdb3(out) 62 | # Empirically, we use 0.2 to scale the residual for better performance 63 | return out * 0.2 + x 64 | 65 | 66 | @ARCH_REGISTRY.register() 67 | class RRDBNet(nn.Module): 68 | """Networks consisting of Residual in Residual Dense Block, which is used 69 | in ESRGAN. 70 | 71 | ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. 72 | 73 | We extend ESRGAN for scale x2 and scale x1. 74 | Note: This is one option for scale 1, scale 2 in RRDBNet. 75 | We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size 76 | and enlarge the channel size before feeding inputs into the main ESRGAN architecture. 77 | 78 | Args: 79 | num_in_ch (int): Channel number of inputs. 80 | num_out_ch (int): Channel number of outputs. 81 | num_feat (int): Channel number of intermediate features. 82 | Default: 64 83 | num_block (int): Block number in the trunk network. Defaults: 23 84 | num_grow_ch (int): Channels for each growth. Default: 32. 85 | """ 86 | 87 | def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32): 88 | super(RRDBNet, self).__init__() 89 | self.scale = scale 90 | if scale == 2: 91 | num_in_ch = num_in_ch * 4 92 | elif scale == 1: 93 | num_in_ch = num_in_ch * 16 94 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 95 | self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) 96 | self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 97 | # upsample 98 | self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 99 | self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 100 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 101 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 102 | 103 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 104 | 105 | def forward(self, x): 106 | if self.scale == 2: 107 | feat = pixel_unshuffle(x, scale=2) 108 | elif self.scale == 1: 109 | feat = pixel_unshuffle(x, scale=4) 110 | else: 111 | feat = x 112 | feat = self.conv_first(feat) 113 | body_feat = self.conv_body(self.body(feat)) 114 | feat = feat + body_feat 115 | # upsample 116 | feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) 117 | feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) 118 | out = self.conv_last(self.lrelu(self.conv_hr(feat))) 119 | return out 120 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/archs/spynet_arch.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn as nn 4 | from torch.nn import functional as F 5 | 6 | from basicsr.utils.registry import ARCH_REGISTRY 7 | from .arch_util import flow_warp 8 | 9 | 10 | class BasicModule(nn.Module): 11 | """Basic Module for SpyNet. 
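        The 8 input channels are the 3-channel reference frame, the 3-channel
        supporting frame warped by the current flow estimate, and the 2-channel
        upsampled flow itself (assembled in ``SpyNet.process`` below); the module
        predicts a 2-channel flow residual.

        Example (a minimal shape-check sketch; the 24x24 resolution is an assumption):
            >>> import torch
            >>> BasicModule()(torch.rand(1, 8, 24, 24)).shape
            torch.Size([1, 2, 24, 24])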
12 | """ 13 | 14 | def __init__(self): 15 | super(BasicModule, self).__init__() 16 | 17 | self.basic_module = nn.Sequential( 18 | nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), 19 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), 20 | nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), 21 | nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), 22 | nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)) 23 | 24 | def forward(self, tensor_input): 25 | return self.basic_module(tensor_input) 26 | 27 | 28 | @ARCH_REGISTRY.register() 29 | class SpyNet(nn.Module): 30 | """SpyNet architecture. 31 | 32 | Args: 33 | load_path (str): path for pretrained SpyNet. Default: None. 34 | """ 35 | 36 | def __init__(self, load_path=None): 37 | super(SpyNet, self).__init__() 38 | self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)]) 39 | if load_path: 40 | self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params']) 41 | 42 | self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 43 | self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 44 | 45 | def preprocess(self, tensor_input): 46 | tensor_output = (tensor_input - self.mean) / self.std 47 | return tensor_output 48 | 49 | def process(self, ref, supp): 50 | flow = [] 51 | 52 | ref = [self.preprocess(ref)] 53 | supp = [self.preprocess(supp)] 54 | 55 | for level in range(5): 56 | ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False)) 57 | supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False)) 58 | 59 | flow = ref[0].new_zeros( 60 | [ref[0].size(0), 2, 61 | int(math.floor(ref[0].size(2) / 2.0)), 62 | int(math.floor(ref[0].size(3) / 2.0))]) 63 | 64 | for level in range(len(ref)): 65 | upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0 66 | 67 | if upsampled_flow.size(2) != ref[level].size(2): 68 | upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate') 69 | if upsampled_flow.size(3) != ref[level].size(3): 70 | upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate') 71 | 72 | flow = self.basic_module[level](torch.cat([ 73 | ref[level], 74 | flow_warp( 75 | supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'), 76 | upsampled_flow 77 | ], 1)) + upsampled_flow 78 | 79 | return flow 80 | 81 | def forward(self, ref, supp): 82 | assert ref.size() == supp.size() 83 | 84 | h, w = ref.size(2), ref.size(3) 85 | w_floor = math.floor(math.ceil(w / 32.0) * 32.0) 86 | h_floor = math.floor(math.ceil(h / 32.0) * 32.0) 87 | 88 | ref = F.interpolate(input=ref, size=(h_floor, w_floor), mode='bilinear', align_corners=False) 89 | supp = F.interpolate(input=supp, size=(h_floor, w_floor), mode='bilinear', align_corners=False) 90 | 91 | flow = F.interpolate(input=self.process(ref, supp), size=(h, w), mode='bilinear', align_corners=False) 92 | 93 | flow[:, 0, :, :] *= float(w) / float(w_floor) 94 | flow[:, 1, :, :] *= float(h) / float(h_floor) 95 | 96 | return flow 97 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/archs/srresnet_arch.py: 
-------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | from torch.nn import functional as F 3 | 4 | from basicsr.utils.registry import ARCH_REGISTRY 5 | from .arch_util import ResidualBlockNoBN, default_init_weights, make_layer 6 | 7 | 8 | @ARCH_REGISTRY.register() 9 | class MSRResNet(nn.Module): 10 | """Modified SRResNet. 11 | 12 | A compacted version modified from SRResNet in 13 | "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network" 14 | It uses residual blocks without BN, similar to EDSR. 15 | Currently, it supports x2, x3 and x4 upsampling scale factor. 16 | 17 | Args: 18 | num_in_ch (int): Channel number of inputs. Default: 3. 19 | num_out_ch (int): Channel number of outputs. Default: 3. 20 | num_feat (int): Channel number of intermediate features. Default: 64. 21 | num_block (int): Block number in the body network. Default: 16. 22 | upscale (int): Upsampling factor. Support x2, x3 and x4. Default: 4. 23 | """ 24 | 25 | def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4): 26 | super(MSRResNet, self).__init__() 27 | self.upscale = upscale 28 | 29 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 30 | self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat) 31 | 32 | # upsampling 33 | if self.upscale in [2, 3]: 34 | self.upconv1 = nn.Conv2d(num_feat, num_feat * self.upscale * self.upscale, 3, 1, 1) 35 | self.pixel_shuffle = nn.PixelShuffle(self.upscale) 36 | elif self.upscale == 4: 37 | self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1) 38 | self.upconv2 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1) 39 | self.pixel_shuffle = nn.PixelShuffle(2) 40 | 41 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 42 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 43 | 44 | # activation function 45 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 46 | 47 | # initialization 48 | default_init_weights([self.conv_first, self.upconv1, self.conv_hr, self.conv_last], 0.1) 49 | if self.upscale == 4: 50 | default_init_weights(self.upconv2, 0.1) 51 | 52 | def forward(self, x): 53 | feat = self.lrelu(self.conv_first(x)) 54 | out = self.body(feat) 55 | 56 | if self.upscale == 4: 57 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 58 | out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) 59 | elif self.upscale in [2, 3]: 60 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 61 | 62 | out = self.conv_last(self.lrelu(self.conv_hr(out))) 63 | base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False) 64 | out += base 65 | return out 66 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/archs/srvgg_arch.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | from torch.nn import functional as F 3 | 4 | from basicsr.utils.registry import ARCH_REGISTRY 5 | 6 | 7 | @ARCH_REGISTRY.register(suffix='basicsr') 8 | class SRVGGNetCompact(nn.Module): 9 | """A compact VGG-style network structure for super-resolution. 10 | 11 | It is a compact network structure, which performs upsampling in the last layer and no convolution is 12 | conducted on the HR feature space. 13 | 14 | Args: 15 | num_in_ch (int): Channel number of inputs. Default: 3. 16 | num_out_ch (int): Channel number of outputs. Default: 3. 17 | num_feat (int): Channel number of intermediate features. Default: 64. 
18 | num_conv (int): Number of convolution layers in the body network. Default: 16. 19 | upscale (int): Upsampling factor. Default: 4. 20 | act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu. 21 | """ 22 | 23 | def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'): 24 | super(SRVGGNetCompact, self).__init__() 25 | self.num_in_ch = num_in_ch 26 | self.num_out_ch = num_out_ch 27 | self.num_feat = num_feat 28 | self.num_conv = num_conv 29 | self.upscale = upscale 30 | self.act_type = act_type 31 | 32 | self.body = nn.ModuleList() 33 | # the first conv 34 | self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)) 35 | # the first activation 36 | if act_type == 'relu': 37 | activation = nn.ReLU(inplace=True) 38 | elif act_type == 'prelu': 39 | activation = nn.PReLU(num_parameters=num_feat) 40 | elif act_type == 'leakyrelu': 41 | activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) 42 | self.body.append(activation) 43 | 44 | # the body structure 45 | for _ in range(num_conv): 46 | self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1)) 47 | # activation 48 | if act_type == 'relu': 49 | activation = nn.ReLU(inplace=True) 50 | elif act_type == 'prelu': 51 | activation = nn.PReLU(num_parameters=num_feat) 52 | elif act_type == 'leakyrelu': 53 | activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) 54 | self.body.append(activation) 55 | 56 | # the last conv 57 | self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1)) 58 | # upsample 59 | self.upsampler = nn.PixelShuffle(upscale) 60 | 61 | def forward(self, x): 62 | out = x 63 | for i in range(0, len(self.body)): 64 | out = self.body[i](out) 65 | 66 | out = self.upsampler(out) 67 | # add the nearest upsampled image, so that the network learns the residual 68 | base = F.interpolate(x, scale_factor=self.upscale, mode='nearest') 69 | out += base 70 | return out 71 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.utils.data 6 | from copy import deepcopy 7 | from functools import partial 8 | from os import path as osp 9 | 10 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader 11 | from basicsr.utils import get_root_logger, scandir 12 | from basicsr.utils.dist_util import get_dist_info 13 | from basicsr.utils.registry import DATASET_REGISTRY 14 | 15 | __all__ = ['build_dataset', 'build_dataloader'] 16 | 17 | # automatically scan and import dataset modules for registry 18 | # scan all the files under the data folder with '_dataset' in file names 19 | data_folder = osp.dirname(osp.abspath(__file__)) 20 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] 21 | # import all the dataset modules 22 | _dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames] 23 | 24 | 25 | def build_dataset(dataset_opt): 26 | """Build dataset from options. 27 | 28 | Args: 29 | dataset_opt (dict): Configuration for dataset. It must contain: 30 | name (str): Dataset name. 31 | type (str): Dataset type. 
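        Example (a minimal sketch; the dataset type, paths and extra keys are
        assumptions and depend on the registered dataset being built):
            >>> dataset_opt = {
            ...     'name': 'demo', 'type': 'PairedImageDataset',
            ...     'dataroot_gt': 'datasets/demo/gt', 'dataroot_lq': 'datasets/demo/lq',
            ...     'io_backend': {'type': 'disk'}, 'phase': 'train',
            ...     'scale': 1, 'gt_size': 256, 'use_hflip': True, 'use_rot': True,
            ... }
            >>> dataset = build_dataset(dataset_opt)  # doctest: +SKIP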
32 | """ 33 | dataset_opt = deepcopy(dataset_opt) 34 | dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt) 35 | logger = get_root_logger() 36 | logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.') 37 | return dataset 38 | 39 | 40 | def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): 41 | """Build dataloader. 42 | 43 | Args: 44 | dataset (torch.utils.data.Dataset): Dataset. 45 | dataset_opt (dict): Dataset options. It contains the following keys: 46 | phase (str): 'train' or 'val'. 47 | num_worker_per_gpu (int): Number of workers for each GPU. 48 | batch_size_per_gpu (int): Training batch size for each GPU. 49 | num_gpu (int): Number of GPUs. Used only in the train phase. 50 | Default: 1. 51 | dist (bool): Whether in distributed training. Used only in the train 52 | phase. Default: False. 53 | sampler (torch.utils.data.sampler): Data sampler. Default: None. 54 | seed (int | None): Seed. Default: None 55 | """ 56 | phase = dataset_opt['phase'] 57 | rank, _ = get_dist_info() 58 | if phase == 'train': 59 | if dist: # distributed training 60 | batch_size = dataset_opt['batch_size_per_gpu'] 61 | num_workers = dataset_opt['num_worker_per_gpu'] 62 | else: # non-distributed training 63 | multiplier = 1 if num_gpu == 0 else num_gpu 64 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier 65 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier 66 | dataloader_args = dict( 67 | dataset=dataset, 68 | batch_size=batch_size, 69 | shuffle=False, 70 | num_workers=num_workers, 71 | sampler=sampler, 72 | drop_last=True) 73 | if sampler is None: 74 | dataloader_args['shuffle'] = True 75 | dataloader_args['worker_init_fn'] = partial( 76 | worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None 77 | elif phase in ['val', 'test']: # validation 78 | dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) 79 | else: 80 | raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.") 81 | 82 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) 83 | dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False) 84 | 85 | prefetch_mode = dataset_opt.get('prefetch_mode') 86 | if prefetch_mode == 'cpu': # CPUPrefetcher 87 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) 88 | logger = get_root_logger() 89 | logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}') 90 | return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) 91 | else: 92 | # prefetch_mode=None: Normal dataloader 93 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher 94 | return torch.utils.data.DataLoader(**dataloader_args) 95 | 96 | 97 | def worker_init_fn(worker_id, num_workers, rank, seed): 98 | # Set the worker seed to num_workers * rank + worker_id + seed 99 | worker_seed = num_workers * rank + worker_id + seed 100 | np.random.seed(worker_seed) 101 | random.seed(worker_seed) 102 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | class EnlargedSampler(Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset. 
8 | 9 | Modified from torch.utils.data.distributed.DistributedSampler 10 | Support enlarging the dataset for iteration-based training, for saving 11 | time when restart the dataloader after each epoch 12 | 13 | Args: 14 | dataset (torch.utils.data.Dataset): Dataset used for sampling. 15 | num_replicas (int | None): Number of processes participating in 16 | the training. It is usually the world_size. 17 | rank (int | None): Rank of the current process within num_replicas. 18 | ratio (int): Enlarging ratio. Default: 1. 19 | """ 20 | 21 | def __init__(self, dataset, num_replicas, rank, ratio=1): 22 | self.dataset = dataset 23 | self.num_replicas = num_replicas 24 | self.rank = rank 25 | self.epoch = 0 26 | self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) 27 | self.total_size = self.num_samples * self.num_replicas 28 | 29 | def __iter__(self): 30 | # deterministically shuffle based on epoch 31 | g = torch.Generator() 32 | g.manual_seed(self.epoch) 33 | indices = torch.randperm(self.total_size, generator=g).tolist() 34 | 35 | dataset_size = len(self.dataset) 36 | indices = [v % dataset_size for v in indices] 37 | 38 | # subsample 39 | indices = indices[self.rank:self.total_size:self.num_replicas] 40 | assert len(indices) == self.num_samples 41 | 42 | return iter(indices) 43 | 44 | def __len__(self): 45 | return self.num_samples 46 | 47 | def set_epoch(self, epoch): 48 | self.epoch = epoch 49 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/data/ffhq_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | from os import path as osp 4 | from torch.utils import data as data 5 | from torchvision.transforms.functional import normalize 6 | 7 | from basicsr.data.transforms import augment 8 | from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor 9 | from basicsr.utils.registry import DATASET_REGISTRY 10 | 11 | 12 | @DATASET_REGISTRY.register() 13 | class FFHQDataset(data.Dataset): 14 | """FFHQ dataset for StyleGAN. 15 | 16 | Args: 17 | opt (dict): Config for train datasets. It contains the following keys: 18 | dataroot_gt (str): Data root path for gt. 19 | io_backend (dict): IO backend type and other kwarg. 20 | mean (list | tuple): Image mean. 21 | std (list | tuple): Image std. 22 | use_hflip (bool): Whether to horizontally flip. 
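        Example (a minimal sketch; the folder path and normalization values are assumptions):
            >>> opt = {'dataroot_gt': 'datasets/ffhq/images',
            ...        'io_backend': {'type': 'disk'},
            ...        'mean': [0.5, 0.5, 0.5], 'std': [0.5, 0.5, 0.5],
            ...        'use_hflip': True}
            >>> dataset = FFHQDataset(opt)
            >>> sample = dataset[0]  # doctest: +SKIP
            >>> sorted(sample.keys())  # doctest: +SKIP
            ['gt', 'gt_path']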
23 | 24 | """ 25 | 26 | def __init__(self, opt): 27 | super(FFHQDataset, self).__init__() 28 | self.opt = opt 29 | # file client (io backend) 30 | self.file_client = None 31 | self.io_backend_opt = opt['io_backend'] 32 | 33 | self.gt_folder = opt['dataroot_gt'] 34 | self.mean = opt['mean'] 35 | self.std = opt['std'] 36 | 37 | if self.io_backend_opt['type'] == 'lmdb': 38 | self.io_backend_opt['db_paths'] = self.gt_folder 39 | if not self.gt_folder.endswith('.lmdb'): 40 | raise ValueError("'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}") 41 | with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: 42 | self.paths = [line.split('.')[0] for line in fin] 43 | else: 44 | # FFHQ has 70000 images in total 45 | self.paths = [osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)] 46 | 47 | def __getitem__(self, index): 48 | if self.file_client is None: 49 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 50 | 51 | # load gt image 52 | gt_path = self.paths[index] 53 | # avoid errors caused by high latency in reading files 54 | retry = 3 55 | while retry > 0: 56 | try: 57 | img_bytes = self.file_client.get(gt_path) 58 | except Exception as e: 59 | logger = get_root_logger() 60 | logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}') 61 | # change another file to read 62 | index = random.randint(0, self.__len__()) 63 | gt_path = self.paths[index] 64 | time.sleep(1) # sleep 1s for occasional server congestion 65 | else: 66 | break 67 | finally: 68 | retry -= 1 69 | img_gt = imfrombytes(img_bytes, float32=True) 70 | 71 | # random horizontal flip 72 | img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False) 73 | # BGR to RGB, HWC to CHW, numpy to tensor 74 | img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True) 75 | # normalize 76 | normalize(img_gt, self.mean, self.std, inplace=True) 77 | return {'gt': img_gt, 'gt_path': gt_path} 78 | 79 | def __len__(self): 80 | return len(self.paths) 81 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/data/paired_image_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data as data 2 | from torchvision.transforms.functional import normalize 3 | 4 | from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file, paired_paths_from_meta_info_file_2 5 | from basicsr.data.transforms import augment, paired_random_crop 6 | from basicsr.utils import FileClient, bgr2ycbcr, imfrombytes, img2tensor 7 | from basicsr.utils.registry import DATASET_REGISTRY 8 | import cv2 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class PairedImageDataset(data.Dataset): 13 | """Paired image dataset for image restoration. 14 | 15 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs. 16 | 17 | There are three modes: 18 | 19 | 1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb. 20 | 2. **meta_info_file**: Use meta information file to generate paths. \ 21 | If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None. 22 | 3. **folder**: Scan folders to generate paths. The rest. 23 | 24 | Args: 25 | opt (dict): Config for train datasets. It contains the following keys: 26 | dataroot_gt (str): Data root path for gt. 27 | dataroot_lq (str): Data root path for lq. 28 | meta_info_file (str): Path for meta information file. 
29 | io_backend (dict): IO backend type and other kwarg. 30 | filename_tmpl (str): Template for each filename. Note that the template excludes the file extension. 31 | Default: '{}'. 32 | gt_size (int): Cropped patched size for gt patches. 33 | use_hflip (bool): Use horizontal flips. 34 | use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). 35 | scale (bool): Scale, which will be added automatically. 36 | phase (str): 'train' or 'val'. 37 | """ 38 | 39 | def __init__(self, opt): 40 | super(PairedImageDataset, self).__init__() 41 | self.opt = opt 42 | # file client (io backend) 43 | self.file_client = None 44 | self.io_backend_opt = opt['io_backend'] 45 | self.mean = opt['mean'] if 'mean' in opt else None 46 | self.std = opt['std'] if 'std' in opt else None 47 | 48 | self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] 49 | if 'filename_tmpl' in opt: 50 | self.filename_tmpl = opt['filename_tmpl'] 51 | else: 52 | self.filename_tmpl = '{}' 53 | 54 | if self.io_backend_opt['type'] == 'lmdb': 55 | self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] 56 | self.io_backend_opt['client_keys'] = ['lq', 'gt'] 57 | self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) 58 | elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None: 59 | self.paths = paired_paths_from_meta_info_file_2([self.lq_folder, self.gt_folder], ['lq', 'gt'], 60 | self.opt['meta_info_file'], self.filename_tmpl) 61 | else: 62 | self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) 63 | 64 | def __getitem__(self, index): 65 | if self.file_client is None: 66 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 67 | 68 | scale = self.opt['scale'] 69 | 70 | # Load gt and lq images. Dimension order: HWC; channel order: BGR; 71 | # image range: [0, 1], float32. 
72 | gt_path = self.paths[index]['gt_path'] 73 | img_bytes = self.file_client.get(gt_path, 'gt') 74 | img_gt = imfrombytes(img_bytes, float32=True) 75 | lq_path = self.paths[index]['lq_path'] 76 | img_bytes = self.file_client.get(lq_path, 'lq') 77 | # print("gt_path, lq_path", gt_path, lq_path) 78 | img_lq = imfrombytes(img_bytes, float32=True) 79 | 80 | h, w = img_gt.shape[0:2] 81 | # print("---------------img_gt.shape[0:2]", img_gt.shape[0:2]) 82 | # pad 83 | if h < self.opt['gt_size'] or w < self.opt['gt_size']: 84 | pad_h = max(0, self.opt['gt_size'] - h) 85 | pad_w = max(0, self.opt['gt_size'] - w) 86 | img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101) 87 | img_lq = cv2.copyMakeBorder(img_lq, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101) 88 | 89 | # augmentation for training 90 | if self.opt['phase'] == 'train': 91 | gt_size = self.opt['gt_size'] 92 | # random crop 93 | img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) 94 | # flip, rotation 95 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) 96 | 97 | # color space transform 98 | if 'color' in self.opt and self.opt['color'] == 'y': 99 | img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None] 100 | img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None] 101 | 102 | # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets 103 | # TODO: It is better to update the datasets, rather than force to crop 104 | if self.opt['phase'] != 'train': 105 | img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :] 106 | # print("img_gt.shape[0:2]", img_gt.shape[0:2]) 107 | # print("img_lq.shape[0:2]", img_lq.shape[0:2]) 108 | 109 | # BGR to RGB, HWC to CHW, numpy to tensor 110 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) 111 | # normalize 112 | if self.mean is not None or self.std is not None: 113 | normalize(img_lq, self.mean, self.std, inplace=True) 114 | normalize(img_gt, self.mean, self.std, inplace=True) 115 | 116 | return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} 117 | 118 | def __len__(self): 119 | return len(self.paths) 120 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/data/prefetch_dataloader.py: -------------------------------------------------------------------------------- 1 | import queue as Queue 2 | import threading 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class PrefetchGenerator(threading.Thread): 8 | """A general prefetch generator. 9 | 10 | Reference: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch 11 | 12 | Args: 13 | generator: Python generator. 14 | num_prefetch_queue (int): Number of prefetch queue. 15 | """ 16 | 17 | def __init__(self, generator, num_prefetch_queue): 18 | threading.Thread.__init__(self) 19 | self.queue = Queue.Queue(num_prefetch_queue) 20 | self.generator = generator 21 | self.daemon = True 22 | self.start() 23 | 24 | def run(self): 25 | for item in self.generator: 26 | self.queue.put(item) 27 | self.queue.put(None) 28 | 29 | def __next__(self): 30 | next_item = self.queue.get() 31 | if next_item is None: 32 | raise StopIteration 33 | return next_item 34 | 35 | def __iter__(self): 36 | return self 37 | 38 | 39 | class PrefetchDataLoader(DataLoader): 40 | """Prefetch version of dataloader. 
41 | 42 | Reference: https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# 43 | 44 | TODO: 45 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in 46 | ddp. 47 | 48 | Args: 49 | num_prefetch_queue (int): Number of prefetch queue. 50 | kwargs (dict): Other arguments for dataloader. 51 | """ 52 | 53 | def __init__(self, num_prefetch_queue, **kwargs): 54 | self.num_prefetch_queue = num_prefetch_queue 55 | super(PrefetchDataLoader, self).__init__(**kwargs) 56 | 57 | def __iter__(self): 58 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) 59 | 60 | 61 | class CPUPrefetcher(): 62 | """CPU prefetcher. 63 | 64 | Args: 65 | loader: Dataloader. 66 | """ 67 | 68 | def __init__(self, loader): 69 | self.ori_loader = loader 70 | self.loader = iter(loader) 71 | 72 | def next(self): 73 | try: 74 | return next(self.loader) 75 | except StopIteration: 76 | return None 77 | 78 | def reset(self): 79 | self.loader = iter(self.ori_loader) 80 | 81 | 82 | class CUDAPrefetcher(): 83 | """CUDA prefetcher. 84 | 85 | Reference: https://github.com/NVIDIA/apex/issues/304# 86 | 87 | It may consume more GPU memory. 88 | 89 | Args: 90 | loader: Dataloader. 91 | opt (dict): Options. 92 | """ 93 | 94 | def __init__(self, loader, opt): 95 | self.ori_loader = loader 96 | self.loader = iter(loader) 97 | self.opt = opt 98 | self.stream = torch.cuda.Stream() 99 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') 100 | self.preload() 101 | 102 | def preload(self): 103 | try: 104 | self.batch = next(self.loader) # self.batch is a dict 105 | except StopIteration: 106 | self.batch = None 107 | return None 108 | # put tensors to gpu 109 | with torch.cuda.stream(self.stream): 110 | for k, v in self.batch.items(): 111 | if torch.is_tensor(v): 112 | self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) 113 | 114 | def next(self): 115 | torch.cuda.current_stream().wait_stream(self.stream) 116 | batch = self.batch 117 | self.preload() 118 | return batch 119 | 120 | def reset(self): 121 | self.loader = iter(self.ori_loader) 122 | self.preload() 123 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/data/realesrgan_paired_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils import data as data 3 | from torchvision.transforms.functional import normalize 4 | 5 | from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb 6 | from basicsr.data.transforms import augment, paired_random_crop 7 | from basicsr.utils import FileClient, imfrombytes, img2tensor 8 | from basicsr.utils.registry import DATASET_REGISTRY 9 | 10 | 11 | @DATASET_REGISTRY.register(suffix='basicsr') 12 | class RealESRGANPairedDataset(data.Dataset): 13 | """Paired image dataset for image restoration. 14 | 15 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs. 16 | 17 | There are three modes: 18 | 19 | 1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb. 20 | 2. **meta_info_file**: Use meta information file to generate paths. \ 21 | If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None. 22 | 3. **folder**: Scan folders to generate paths. The rest. 23 | 24 | Args: 25 | opt (dict): Config for train datasets. It contains the following keys: 26 | dataroot_gt (str): Data root path for gt. 27 | dataroot_lq (str): Data root path for lq. 28 | meta_info (str): Path for meta information file. 
29 | io_backend (dict): IO backend type and other kwarg. 30 | filename_tmpl (str): Template for each filename. Note that the template excludes the file extension. 31 | Default: '{}'. 32 | gt_size (int): Cropped patched size for gt patches. 33 | use_hflip (bool): Use horizontal flips. 34 | use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). 35 | scale (bool): Scale, which will be added automatically. 36 | phase (str): 'train' or 'val'. 37 | """ 38 | 39 | def __init__(self, opt): 40 | super(RealESRGANPairedDataset, self).__init__() 41 | self.opt = opt 42 | self.file_client = None 43 | self.io_backend_opt = opt['io_backend'] 44 | # mean and std for normalizing the input images 45 | self.mean = opt['mean'] if 'mean' in opt else None 46 | self.std = opt['std'] if 'std' in opt else None 47 | 48 | self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] 49 | self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}' 50 | 51 | # file client (lmdb io backend) 52 | if self.io_backend_opt['type'] == 'lmdb': 53 | self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] 54 | self.io_backend_opt['client_keys'] = ['lq', 'gt'] 55 | self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) 56 | elif 'meta_info' in self.opt and self.opt['meta_info'] is not None: 57 | # disk backend with meta_info 58 | # Each line in the meta_info describes the relative path to an image 59 | with open(self.opt['meta_info']) as fin: 60 | paths = [line.strip() for line in fin] 61 | self.paths = [] 62 | for path in paths: 63 | gt_path, lq_path = path.split(', ') 64 | gt_path = os.path.join(self.gt_folder, gt_path) 65 | lq_path = os.path.join(self.lq_folder, lq_path) 66 | self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)])) 67 | else: 68 | # disk backend 69 | # it will scan the whole folder to get meta info 70 | # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file 71 | self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) 72 | 73 | if 'num_pic' in self.opt: 74 | self.paths = self.paths[:self.opt['num_pic']] 75 | if 'phase' not in self.opt: 76 | self.opt['phase'] = 'test' 77 | if 'scale' not in self.opt: 78 | self.opt['scale'] = 1 79 | 80 | 81 | def __getitem__(self, index): 82 | if self.file_client is None: 83 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 84 | 85 | scale = self.opt['scale'] 86 | 87 | # Load gt and lq images. Dimension order: HWC; channel order: BGR; 88 | # image range: [0, 1], float32. 
89 | gt_path = self.paths[index]['gt_path'] 90 | img_bytes = self.file_client.get(gt_path, 'gt') 91 | img_gt = imfrombytes(img_bytes, float32=True) 92 | lq_path = self.paths[index]['lq_path'] 93 | img_bytes = self.file_client.get(lq_path, 'lq') 94 | img_lq = imfrombytes(img_bytes, float32=True) 95 | 96 | # augmentation for training 97 | if self.opt['phase'] == 'train': 98 | gt_size = self.opt['gt_size'] 99 | # random crop 100 | img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) 101 | # flip, rotation 102 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) 103 | 104 | # BGR to RGB, HWC to CHW, numpy to tensor 105 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) 106 | # normalize 107 | if self.mean is not None or self.std is not None: 108 | normalize(img_lq, self.mean, self.std, inplace=True) 109 | normalize(img_gt, self.mean, self.std, inplace=True) 110 | 111 | return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} 112 | 113 | def __len__(self): 114 | return len(self.paths) 115 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/losses/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import LOSS_REGISTRY 7 | from .gan_loss import g_path_regularize, gradient_penalty_loss, r1_penalty 8 | 9 | __all__ = ['build_loss', 'gradient_penalty_loss', 'r1_penalty', 'g_path_regularize'] 10 | 11 | # automatically scan and import loss modules for registry 12 | # scan all the files under the 'losses' folder and collect files ending with '_loss.py' 13 | loss_folder = osp.dirname(osp.abspath(__file__)) 14 | loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')] 15 | # import all the loss modules 16 | _model_modules = [importlib.import_module(f'basicsr.losses.{file_name}') for file_name in loss_filenames] 17 | 18 | 19 | def build_loss(opt): 20 | """Build loss from options. 21 | 22 | Args: 23 | opt (dict): Configuration. It must contain: 24 | type (str): Model type. 25 | """ 26 | opt = deepcopy(opt) 27 | loss_type = opt.pop('type') 28 | loss = LOSS_REGISTRY.get(loss_type)(**opt) 29 | logger = get_root_logger() 30 | logger.info(f'Loss [{loss.__class__.__name__}] is created.') 31 | return loss 32 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/losses/loss_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | 6 | def reduce_loss(loss, reduction): 7 | """Reduce loss as specified. 8 | 9 | Args: 10 | loss (Tensor): Elementwise loss tensor. 11 | reduction (str): Options are 'none', 'mean' and 'sum'. 12 | 13 | Returns: 14 | Tensor: Reduced loss tensor. 15 | """ 16 | reduction_enum = F._Reduction.get_enum(reduction) 17 | # none: 0, elementwise_mean:1, sum: 2 18 | if reduction_enum == 0: 19 | return loss 20 | elif reduction_enum == 1: 21 | return loss.mean() 22 | else: 23 | return loss.sum() 24 | 25 | 26 | def weight_reduce_loss(loss, weight=None, reduction='mean'): 27 | """Apply element-wise weight and reduce loss. 28 | 29 | Args: 30 | loss (Tensor): Element-wise loss. 31 | weight (Tensor): Element-wise weights. 
Default: None. 32 | reduction (str): Same as built-in losses of PyTorch. Options are 33 | 'none', 'mean' and 'sum'. Default: 'mean'. 34 | 35 | Returns: 36 | Tensor: Loss values. 37 | """ 38 | # if weight is specified, apply element-wise weight 39 | if weight is not None: 40 | assert weight.dim() == loss.dim() 41 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 42 | loss = loss * weight 43 | 44 | # if weight is not specified or reduction is sum, just reduce the loss 45 | if weight is None or reduction == 'sum': 46 | loss = reduce_loss(loss, reduction) 47 | # if reduction is mean, then compute mean over weight region 48 | elif reduction == 'mean': 49 | if weight.size(1) > 1: 50 | weight = weight.sum() 51 | else: 52 | weight = weight.sum() * loss.size(1) 53 | loss = loss.sum() / weight 54 | 55 | return loss 56 | 57 | 58 | def weighted_loss(loss_func): 59 | """Create a weighted version of a given loss function. 60 | 61 | To use this decorator, the loss function must have the signature like 62 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 63 | element-wise loss without any reduction. This decorator will add weight 64 | and reduction arguments to the function. The decorated function will have 65 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 66 | **kwargs)`. 67 | 68 | :Example: 69 | 70 | >>> import torch 71 | >>> @weighted_loss 72 | >>> def l1_loss(pred, target): 73 | >>> return (pred - target).abs() 74 | 75 | >>> pred = torch.Tensor([0, 2, 3]) 76 | >>> target = torch.Tensor([1, 1, 1]) 77 | >>> weight = torch.Tensor([1, 0, 1]) 78 | 79 | >>> l1_loss(pred, target) 80 | tensor(1.3333) 81 | >>> l1_loss(pred, target, weight) 82 | tensor(1.5000) 83 | >>> l1_loss(pred, target, reduction='none') 84 | tensor([1., 1., 2.]) 85 | >>> l1_loss(pred, target, weight, reduction='sum') 86 | tensor(3.) 87 | """ 88 | 89 | @functools.wraps(loss_func) 90 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs): 91 | # get element-wise loss 92 | loss = loss_func(pred, target, **kwargs) 93 | loss = weight_reduce_loss(loss, weight, reduction) 94 | return loss 95 | 96 | return wrapper 97 | 98 | 99 | def get_local_weights(residual, ksize): 100 | """Get local weights for generating the artifact map of LDL. 101 | 102 | It is only called by the `get_refined_artifact_map` function. 103 | 104 | Args: 105 | residual (Tensor): Residual between predicted and ground truth images. 106 | ksize (Int): size of the local window. 107 | 108 | Returns: 109 | Tensor: weight for each pixel to be discriminated as an artifact pixel 110 | """ 111 | 112 | pad = (ksize - 1) // 2 113 | residual_pad = F.pad(residual, pad=[pad, pad, pad, pad], mode='reflect') 114 | 115 | unfolded_residual = residual_pad.unfold(2, ksize, 1).unfold(3, ksize, 1) 116 | pixel_level_weight = torch.var(unfolded_residual, dim=(-1, -2), unbiased=True, keepdim=True).squeeze(-1).squeeze(-1) 117 | 118 | return pixel_level_weight 119 | 120 | 121 | def get_refined_artifact_map(img_gt, img_output, img_ema, ksize): 122 | """Calculate the artifact map of LDL 123 | (Details or Artifacts: A Locally Discriminative Learning Approach to Realistic Image Super-Resolution. In CVPR 2022) 124 | 125 | Args: 126 | img_gt (Tensor): ground truth images. 127 | img_output (Tensor): output images given by the optimizing model. 128 | img_ema (Tensor): output images given by the ema model. 129 | ksize (Int): size of the local window. 
130 | 
131 |     Returns:
132 |         overall_weight: weight for each pixel to be discriminated as an artifact pixel
133 |             (calculated based on both local and global observations).
134 |     """
135 | 
136 |     residual_ema = torch.sum(torch.abs(img_gt - img_ema), 1, keepdim=True)
137 |     residual_sr = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True)
138 | 
139 |     patch_level_weight = torch.var(residual_sr.clone(), dim=(-1, -2, -3), keepdim=True)**(1 / 5)
140 |     pixel_level_weight = get_local_weights(residual_sr.clone(), ksize)
141 |     overall_weight = patch_level_weight * pixel_level_weight
142 | 
143 |     overall_weight[residual_sr < residual_ema] = 0
144 | 
145 |     return overall_weight
146 | 
--------------------------------------------------------------------------------
/NullSpaceDiff/basicsr/metrics/README.md:
--------------------------------------------------------------------------------
1 | # Metrics
2 | 
3 | [English](README.md) **|** [简体中文](README_CN.md)
4 | 
5 | - [Conventions](#conventions)
6 | - [PSNR and SSIM](#psnr-and-ssim)
7 | 
8 | ## Conventions
9 | 
10 | Since different input types lead to different results, we adopt the following conventions for the inputs:
11 | 
12 | - Numpy type (usually the result of cv2)
13 |   - UINT8: BGR, [0, 255], (h, w, c)
14 |   - float: BGR, [0, 1], (h, w, c). Usually an intermediate result
15 | - Tensor type
16 |   - float: RGB, [0, 1], (n, c, h, w)
17 | 
18 | Other conventions:
19 | 
20 | - Results ending with `_pt` are PyTorch results
21 | - The PyTorch version supports batch computation
22 | - Color conversion is done in float32; metric computation is done in float64
23 | 
24 | ## PSNR and SSIM
25 | 
26 | PSNR and SSIM results follow the same trend: generally, a higher PSNR comes with a higher SSIM.
27 | The various PSNR implementations are all consistent with each other. SSIM has many different implementations; here we keep consistent with the original MATLAB version (see the [evaluation code](https://competitions.codalab.org/my/datasets/download/ebe960d8-0ec8-4846-a1a2-7c4a586a7378) of the [NTIRE17 challenge](https://competitions.codalab.org/competitions/16306#participate))
28 | 
29 | The result comparisons among the implementations are listed below.
30 | Summary: the PyTorch implementation is basically consistent with the MATLAB implementation, with slight differences when running on GPU
31 | 
32 | - PSNR comparison
33 | 
34 | |Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU |
35 | |:---| :---: | :---: | :---: | :---: | :---: |
36 | |baboon| RGB | 20.419710 | 20.419710 | 20.419710 |20.419710 |
37 | |baboon| Y | - |22.441898 | 22.441899 | 22.444916|
38 | |comic | RGB | 20.239912 | 20.239912 | 20.239912 | 20.239912 |
39 | |comic | Y | - | 21.720398 | 21.720398 | 21.721663|
40 | 
41 | - SSIM comparison
42 | 
43 | |Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU |
44 | |:---| :---: | :---: | :---: | :---: | :---: |
45 | |baboon| RGB | 0.391853 | 0.391853 | 0.391853|0.391853 |
46 | |baboon| Y | - |0.453097| 0.453097 | 0.453171|
47 | |comic | RGB | 0.567738 | 0.567738 | 0.567738 | 0.567738|
48 | |comic | Y | - | 0.585511 | 0.585511 | 0.585522 |
49 | 
--------------------------------------------------------------------------------
/NullSpaceDiff/basicsr/metrics/README_CN.md:
--------------------------------------------------------------------------------
1 | # Metrics
2 | 
3 | [English](README.md) **|** [简体中文](README_CN.md)
4 | 
5 | - [约定](#约定)
6 | - [PSNR 和 SSIM](#psnr-和-ssim)
7 | 
8 | ## 约定
9 | 
10 | 因为不同的输入类型会导致结果的不同,因此我们对输入做如下约定:
11 | 
12 | - Numpy 类型 (一般是 cv2 的结果)
13 | - UINT8: BGR, [0, 255], (h, w, c)
14 | - float: BGR, [0, 1], (h, w, c).
一般作为中间结果 15 | - Tensor 类型 16 | - float: RGB, [0, 1], (n, c, h, w) 17 | 18 | 其他约定: 19 | 20 | - 以 `_pt` 结尾的是 PyTorch 结果 21 | - PyTorch version 支持 batch 计算 22 | - 颜色转换在 float32 上做;metric计算在 float64 上做 23 | 24 | ## PSNR 和 SSIM 25 | 26 | PSNR 和 SSIM 的结果趋势是一致的,即一般 PSNR 高,则 SSIM 也高。 27 | 在实现上, PSNR 的各种实现都很一致。SSIM 有各种各样的实现,我们这里和 MATLAB 最原始版本保持 (参考 [NTIRE17比赛](https://competitions.codalab.org/competitions/16306#participate) 的 [evaluation代码](https://competitions.codalab.org/my/datasets/download/ebe960d8-0ec8-4846-a1a2-7c4a586a7378)) 28 | 29 | 下面列了各个实现的结果比对. 30 | 总结:PyTorch 实现和 MATLAB 实现基本一致,在 GPU 运行上会有稍许差异 31 | 32 | - PSNR 比对 33 | 34 | |Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU | 35 | |:---| :---: | :---: | :---: | :---: | :---: | 36 | |baboon| RGB | 20.419710 | 20.419710 | 20.419710 |20.419710 | 37 | |baboon| Y | - |22.441898 | 22.441899 | 22.444916| 38 | |comic | RGB | 20.239912 | 20.239912 | 20.239912 | 20.239912 | 39 | |comic | Y | - | 21.720398 | 21.720398 | 21.721663| 40 | 41 | - SSIM 比对 42 | 43 | |Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU | 44 | |:---| :---: | :---: | :---: | :---: | :---: | 45 | |baboon| RGB | 0.391853 | 0.391853 | 0.391853|0.391853 | 46 | |baboon| Y | - |0.453097| 0.453097 | 0.453171| 47 | |comic | RGB | 0.567738 | 0.567738 | 0.567738 | 0.567738| 48 | |comic | Y | - | 0.585511 | 0.585511 | 0.585522 | 49 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from basicsr.utils.registry import METRIC_REGISTRY 4 | from .niqe import calculate_niqe 5 | from .psnr_ssim import calculate_psnr, calculate_ssim, calculate_ssim_pt, calculate_psnr_pt 6 | 7 | __all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe'] 8 | 9 | 10 | def calculate_metric(data, opt): 11 | """Calculate metric from data and options. 12 | 13 | Args: 14 | opt (dict): Configuration. It must contain: 15 | type (str): Model type. 16 | """ 17 | opt = deepcopy(opt) 18 | metric_type = opt.pop('type') 19 | metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) 20 | return metric 21 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/metrics/fid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from scipy import linalg 5 | from tqdm import tqdm 6 | 7 | from basicsr.archs.inception import InceptionV3 8 | 9 | 10 | def load_patched_inception_v3(device='cuda', resize_input=True, normalize_input=False): 11 | # we may not resize the input, but in [rosinality/stylegan2-pytorch] it 12 | # does resize the input. 13 | inception = InceptionV3([3], resize_input=resize_input, normalize_input=normalize_input) 14 | inception = nn.DataParallel(inception).eval().to(device) 15 | return inception 16 | 17 | 18 | @torch.no_grad() 19 | def extract_inception_features(data_generator, inception, len_generator=None, device='cuda'): 20 | """Extract inception features. 21 | 22 | Args: 23 | data_generator (generator): A data generator. 24 | inception (nn.Module): Inception model. 25 | len_generator (int): Length of the data_generator to show the 26 | progressbar. Default: None. 27 | device (str): Device. Default: cuda. 28 | 29 | Returns: 30 | Tensor: Extracted features. 
31 |     """
32 |     if len_generator is not None:
33 |         pbar = tqdm(total=len_generator, unit='batch', desc='Extract')
34 |     else:
35 |         pbar = None
36 |     features = []
37 | 
38 |     for data in data_generator:
39 |         if pbar:
40 |             pbar.update(1)
41 |         data = data.to(device)
42 |         feature = inception(data)[0].view(data.shape[0], -1)
43 |         features.append(feature.to('cpu'))
44 |     if pbar:
45 |         pbar.close()
46 |     features = torch.cat(features, 0)
47 |     return features
48 | 
49 | 
50 | def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
51 |     """Numpy implementation of the Frechet Distance.
52 | 
53 |     The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) is:
54 |     d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
55 |     Stable version by Dougal J. Sutherland.
56 | 
57 |     Args:
58 |         mu1 (np.array): The sample mean over activations.
59 |         sigma1 (np.array): The covariance matrix over activations for generated samples.
60 |         mu2 (np.array): The sample mean over activations, precalculated on a representative data set.
61 |         sigma2 (np.array): The covariance matrix over activations, precalculated on a representative data set.
62 | 
63 |     Returns:
64 |         float: The Frechet Distance.
65 |     """
66 |     assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths'
67 |     assert sigma1.shape == sigma2.shape, ('Two covariances have different dimensions')
68 | 
69 |     cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
70 | 
71 |     # Product might be almost singular
72 |     if not np.isfinite(cov_sqrt).all():
73 |         print(f'Product of cov matrices is singular. Adding {eps} to diagonal of cov estimates')
74 |         offset = np.eye(sigma1.shape[0]) * eps
75 |         cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))
76 | 
77 |     # Numerical error might give slight imaginary component
78 |     if np.iscomplexobj(cov_sqrt):
79 |         if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
80 |             m = np.max(np.abs(cov_sqrt.imag))
81 |             raise ValueError(f'Imaginary component {m}')
82 |         cov_sqrt = cov_sqrt.real
83 | 
84 |     mean_diff = mu1 - mu2
85 |     mean_norm = mean_diff @ mean_diff
86 |     trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt)
87 |     fid = mean_norm + trace
88 | 
89 |     return fid
90 | 
--------------------------------------------------------------------------------
/NullSpaceDiff/basicsr/metrics/metric_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | from basicsr.utils import bgr2ycbcr
4 | 
5 | 
6 | def reorder_image(img, input_order='HWC'):
7 |     """Reorder images to 'HWC' order.
8 | 
9 |     If the input_order is (h, w), return (h, w, 1);
10 |     If the input_order is (c, h, w), return (h, w, c);
11 |     If the input_order is (h, w, c), return as it is.
12 | 
13 |     Args:
14 |         img (ndarray): Input image.
15 |         input_order (str): Whether the input order is 'HWC' or 'CHW'.
16 |             If the input image shape is (h, w), input_order will not have
17 |             effects. Default: 'HWC'.
18 | 
19 |     Returns:
20 |         ndarray: reordered image.
21 |     """
22 | 
23 |     if input_order not in ['HWC', 'CHW']:
24 |         raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'")
25 |     if len(img.shape) == 2:
26 |         img = img[..., None]
27 |     if input_order == 'CHW':
28 |         img = img.transpose(1, 2, 0)
29 |     return img
30 | 
31 | 
32 | def to_y_channel(img):
33 |     """Change to Y channel of YCbCr.
34 | 
35 |     Args:
36 |         img (ndarray): Images with range [0, 255].
37 | 
38 |     Returns:
39 |         (ndarray): Images with range [0, 255] (float type) without round.
40 | """ 41 | img = img.astype(np.float32) / 255. 42 | if img.ndim == 3 and img.shape[2] == 3: 43 | img = bgr2ycbcr(img, y_only=True) 44 | img = img[..., None] 45 | return img * 255. 46 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/metrics/niqe_pris_params.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenImagingLab/PhoCoLens/154fe32aea5c2b623f6dd1c07e90c2900d076486/NullSpaceDiff/basicsr/metrics/niqe_pris_params.npz -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/metrics/test_metrics/test_psnr_ssim.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | 4 | from basicsr.metrics import calculate_psnr, calculate_ssim 5 | from basicsr.metrics.psnr_ssim import calculate_psnr_pt, calculate_ssim_pt 6 | from basicsr.utils import img2tensor 7 | 8 | 9 | def test(img_path, img_path2, crop_border, test_y_channel=False): 10 | img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) 11 | img2 = cv2.imread(img_path2, cv2.IMREAD_UNCHANGED) 12 | 13 | # --------------------- Numpy --------------------- 14 | psnr = calculate_psnr(img, img2, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel) 15 | ssim = calculate_ssim(img, img2, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel) 16 | print(f'\tNumpy\tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}') 17 | 18 | # --------------------- PyTorch (CPU) --------------------- 19 | img = img2tensor(img / 255., bgr2rgb=True, float32=True).unsqueeze_(0) 20 | img2 = img2tensor(img2 / 255., bgr2rgb=True, float32=True).unsqueeze_(0) 21 | 22 | psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) 23 | ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) 24 | print(f'\tTensor (CPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}') 25 | 26 | # --------------------- PyTorch (GPU) --------------------- 27 | img = img.cuda() 28 | img2 = img2.cuda() 29 | psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) 30 | ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) 31 | print(f'\tTensor (GPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}') 32 | 33 | psnr_pth = calculate_psnr_pt( 34 | torch.repeat_interleave(img, 2, dim=0), 35 | torch.repeat_interleave(img2, 2, dim=0), 36 | crop_border=crop_border, 37 | test_y_channel=test_y_channel) 38 | ssim_pth = calculate_ssim_pt( 39 | torch.repeat_interleave(img, 2, dim=0), 40 | torch.repeat_interleave(img2, 2, dim=0), 41 | crop_border=crop_border, 42 | test_y_channel=test_y_channel) 43 | print(f'\tTensor (GPU batch) \tPSNR: {psnr_pth[0]:.6f}, {psnr_pth[1]:.6f} dB,' 44 | f'\tSSIM: {ssim_pth[0]:.6f}, {ssim_pth[1]:.6f}') 45 | 46 | 47 | if __name__ == '__main__': 48 | test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', crop_border=4, test_y_channel=False) 49 | test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', crop_border=4, test_y_channel=True) 50 | 51 | test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=False) 52 | test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=True) 53 | -------------------------------------------------------------------------------- 
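A minimal usage sketch of how the metrics above are typically invoked: `calculate_metric` (metrics/__init__.py) looks the metric up in `METRIC_REGISTRY` by its `type` key and forwards the remaining options as keyword arguments. The dummy inputs and option values below are illustrative assumptions only, and the data-dict keys assume the first two parameters of `calculate_psnr` are named `img` and `img2`, as in upstream BasicSR.

import numpy as np

from basicsr.metrics import calculate_metric, calculate_psnr

# Two dummy HWC, BGR, uint8 images in [0, 255] (see the conventions in metrics/README.md).
img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
img2 = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)

# Direct call.
psnr = calculate_psnr(img, img2, crop_border=4, input_order='HWC', test_y_channel=False)

# Registry-based call: 'type' selects the function registered in METRIC_REGISTRY,
# the remaining keys are forwarded as keyword arguments, and the image pair is
# passed through the data dict.
opt = {'type': 'calculate_psnr', 'crop_border': 4, 'test_y_channel': False}
psnr_again = calculate_metric({'img': img, 'img2': img2}, opt)
print(psnr, psnr_again)

--------------------------------------------------------------------------------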
/NullSpaceDiff/basicsr/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import MODEL_REGISTRY 7 | 8 | __all__ = ['build_model'] 9 | 10 | # automatically scan and import model modules for registry 11 | # scan all the files under the 'models' folder and collect files ending with '_model.py' 12 | model_folder = osp.dirname(osp.abspath(__file__)) 13 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] 14 | # import all the model modules 15 | _model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames] 16 | 17 | 18 | def build_model(opt): 19 | """Build model from options. 20 | 21 | Args: 22 | opt (dict): Configuration. It must contain: 23 | model_type (str): Model type. 24 | """ 25 | opt = deepcopy(opt) 26 | model = MODEL_REGISTRY.get(opt['model_type'])(opt) 27 | logger = get_root_logger() 28 | logger.info(f'Model [{model.__class__.__name__}] is created.') 29 | return model 30 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/models/edvr_model.py: -------------------------------------------------------------------------------- 1 | from basicsr.utils import get_root_logger 2 | from basicsr.utils.registry import MODEL_REGISTRY 3 | from .video_base_model import VideoBaseModel 4 | 5 | 6 | @MODEL_REGISTRY.register() 7 | class EDVRModel(VideoBaseModel): 8 | """EDVR Model. 9 | 10 | Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks. # noqa: E501 11 | """ 12 | 13 | def __init__(self, opt): 14 | super(EDVRModel, self).__init__(opt) 15 | if self.is_train: 16 | self.train_tsa_iter = opt['train'].get('tsa_iter') 17 | 18 | def setup_optimizers(self): 19 | train_opt = self.opt['train'] 20 | dcn_lr_mul = train_opt.get('dcn_lr_mul', 1) 21 | logger = get_root_logger() 22 | logger.info(f'Multiple the learning rate for dcn with {dcn_lr_mul}.') 23 | if dcn_lr_mul == 1: 24 | optim_params = self.net_g.parameters() 25 | else: # separate dcn params and normal params for different lr 26 | normal_params = [] 27 | dcn_params = [] 28 | for name, param in self.net_g.named_parameters(): 29 | if 'dcn' in name: 30 | dcn_params.append(param) 31 | else: 32 | normal_params.append(param) 33 | optim_params = [ 34 | { # add normal params first 35 | 'params': normal_params, 36 | 'lr': train_opt['optim_g']['lr'] 37 | }, 38 | { 39 | 'params': dcn_params, 40 | 'lr': train_opt['optim_g']['lr'] * dcn_lr_mul 41 | }, 42 | ] 43 | 44 | optim_type = train_opt['optim_g'].pop('type') 45 | self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) 46 | self.optimizers.append(self.optimizer_g) 47 | 48 | def optimize_parameters(self, current_iter): 49 | if self.train_tsa_iter: 50 | if current_iter == 1: 51 | logger = get_root_logger() 52 | logger.info(f'Only train TSA module for {self.train_tsa_iter} iters.') 53 | for name, param in self.net_g.named_parameters(): 54 | if 'fusion' not in name: 55 | param.requires_grad = False 56 | elif current_iter == self.train_tsa_iter: 57 | logger = get_root_logger() 58 | logger.warning('Train all the parameters.') 59 | for param in self.net_g.parameters(): 60 | param.requires_grad = True 61 | 62 | super(EDVRModel, self).optimize_parameters(current_iter) 63 | 
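# Usage sketch (a hedged illustration; `opt` is assumed to be a fully parsed YAML option
# dict, e.g. as produced by parse_options): because EDVRModel is decorated with
# @MODEL_REGISTRY.register(), build_model() in basicsr/models/__init__.py can instantiate
# it purely from the config:
#
#     from basicsr.models import build_model
#     opt['model_type'] = 'EDVRModel'  # the remaining keys (network_g, train, path, ...) come from the config
#     model = build_model(opt)         # resolves 'EDVRModel' in MODEL_REGISTRY and calls EDVRModel(opt)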
-------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/models/esrgan_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | 4 | from basicsr.utils.registry import MODEL_REGISTRY 5 | from .srgan_model import SRGANModel 6 | 7 | 8 | @MODEL_REGISTRY.register() 9 | class ESRGANModel(SRGANModel): 10 | """ESRGAN model for single image super-resolution.""" 11 | 12 | def optimize_parameters(self, current_iter): 13 | # optimize net_g 14 | for p in self.net_d.parameters(): 15 | p.requires_grad = False 16 | 17 | self.optimizer_g.zero_grad() 18 | self.output = self.net_g(self.lq) 19 | 20 | l_g_total = 0 21 | loss_dict = OrderedDict() 22 | if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): 23 | # pixel loss 24 | if self.cri_pix: 25 | l_g_pix = self.cri_pix(self.output, self.gt) 26 | l_g_total += l_g_pix 27 | loss_dict['l_g_pix'] = l_g_pix 28 | # perceptual loss 29 | if self.cri_perceptual: 30 | l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt) 31 | if l_g_percep is not None: 32 | l_g_total += l_g_percep 33 | loss_dict['l_g_percep'] = l_g_percep 34 | if l_g_style is not None: 35 | l_g_total += l_g_style 36 | loss_dict['l_g_style'] = l_g_style 37 | # gan loss (relativistic gan) 38 | real_d_pred = self.net_d(self.gt).detach() 39 | fake_g_pred = self.net_d(self.output) 40 | l_g_real = self.cri_gan(real_d_pred - torch.mean(fake_g_pred), False, is_disc=False) 41 | l_g_fake = self.cri_gan(fake_g_pred - torch.mean(real_d_pred), True, is_disc=False) 42 | l_g_gan = (l_g_real + l_g_fake) / 2 43 | 44 | l_g_total += l_g_gan 45 | loss_dict['l_g_gan'] = l_g_gan 46 | 47 | l_g_total.backward() 48 | self.optimizer_g.step() 49 | 50 | # optimize net_d 51 | for p in self.net_d.parameters(): 52 | p.requires_grad = True 53 | 54 | self.optimizer_d.zero_grad() 55 | # gan loss (relativistic gan) 56 | 57 | # In order to avoid the error in distributed training: 58 | # "Error detected in CudnnBatchNormBackward: RuntimeError: one of 59 | # the variables needed for gradient computation has been modified by 60 | # an inplace operation", 61 | # we separate the backwards for real and fake, and also detach the 62 | # tensor for calculating mean. 63 | 64 | # real 65 | fake_d_pred = self.net_d(self.output).detach() 66 | real_d_pred = self.net_d(self.gt) 67 | l_d_real = self.cri_gan(real_d_pred - torch.mean(fake_d_pred), True, is_disc=True) * 0.5 68 | l_d_real.backward() 69 | # fake 70 | fake_d_pred = self.net_d(self.output.detach()) 71 | l_d_fake = self.cri_gan(fake_d_pred - torch.mean(real_d_pred.detach()), False, is_disc=True) * 0.5 72 | l_d_fake.backward() 73 | self.optimizer_d.step() 74 | 75 | loss_dict['l_d_real'] = l_d_real 76 | loss_dict['l_d_fake'] = l_d_fake 77 | loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) 78 | loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) 79 | 80 | self.log_dict = self.reduce_loss_dict(loss_dict) 81 | 82 | if self.ema_decay > 0: 83 | self.model_ema(decay=self.ema_decay) 84 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | 6 | class MultiStepRestartLR(_LRScheduler): 7 | """ MultiStep with restarts learning rate scheme. 
8 | 9 | Args: 10 | optimizer (torch.nn.optimizer): Torch optimizer. 11 | milestones (list): Iterations that will decrease learning rate. 12 | gamma (float): Decrease ratio. Default: 0.1. 13 | restarts (list): Restart iterations. Default: [0]. 14 | restart_weights (list): Restart weights at each restart iteration. 15 | Default: [1]. 16 | last_epoch (int): Used in _LRScheduler. Default: -1. 17 | """ 18 | 19 | def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1): 20 | self.milestones = Counter(milestones) 21 | self.gamma = gamma 22 | self.restarts = restarts 23 | self.restart_weights = restart_weights 24 | assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.' 25 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) 26 | 27 | def get_lr(self): 28 | if self.last_epoch in self.restarts: 29 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 30 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 31 | if self.last_epoch not in self.milestones: 32 | return [group['lr'] for group in self.optimizer.param_groups] 33 | return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups] 34 | 35 | 36 | def get_position_from_periods(iteration, cumulative_period): 37 | """Get the position from a period list. 38 | 39 | It will return the index of the right-closest number in the period list. 40 | For example, the cumulative_period = [100, 200, 300, 400], 41 | if iteration == 50, return 0; 42 | if iteration == 210, return 2; 43 | if iteration == 300, return 2. 44 | 45 | Args: 46 | iteration (int): Current iteration. 47 | cumulative_period (list[int]): Cumulative period list. 48 | 49 | Returns: 50 | int: The position of the right-closest number in the period list. 51 | """ 52 | for i, period in enumerate(cumulative_period): 53 | if iteration <= period: 54 | return i 55 | 56 | 57 | class CosineAnnealingRestartLR(_LRScheduler): 58 | """ Cosine annealing with restarts learning rate scheme. 59 | 60 | An example of config: 61 | periods = [10, 10, 10, 10] 62 | restart_weights = [1, 0.5, 0.5, 0.5] 63 | eta_min=1e-7 64 | 65 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the 66 | scheduler will restart with the weights in restart_weights. 67 | 68 | Args: 69 | optimizer (torch.nn.optimizer): Torch optimizer. 70 | periods (list): Period for each cosine anneling cycle. 71 | restart_weights (list): Restart weights at each restart iteration. 72 | Default: [1]. 73 | eta_min (float): The minimum lr. Default: 0. 74 | last_epoch (int): Used in _LRScheduler. Default: -1. 75 | """ 76 | 77 | def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1): 78 | self.periods = periods 79 | self.restart_weights = restart_weights 80 | self.eta_min = eta_min 81 | assert (len(self.periods) == len( 82 | self.restart_weights)), 'periods and restart_weights should have the same length.' 
83 | self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))] 84 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) 85 | 86 | def get_lr(self): 87 | idx = get_position_from_periods(self.last_epoch, self.cumulative_period) 88 | current_weight = self.restart_weights[idx] 89 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] 90 | current_period = self.periods[idx] 91 | 92 | return [ 93 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * 94 | (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period))) 95 | for base_lr in self.base_lrs 96 | ] 97 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/models/swinir_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | from basicsr.utils.registry import MODEL_REGISTRY 5 | from .sr_model import SRModel 6 | 7 | 8 | @MODEL_REGISTRY.register() 9 | class SwinIRModel(SRModel): 10 | 11 | def test(self): 12 | # pad to multiplication of window_size 13 | window_size = self.opt['network_g']['window_size'] 14 | scale = self.opt.get('scale', 1) 15 | mod_pad_h, mod_pad_w = 0, 0 16 | _, _, h, w = self.lq.size() 17 | if h % window_size != 0: 18 | mod_pad_h = window_size - h % window_size 19 | if w % window_size != 0: 20 | mod_pad_w = window_size - w % window_size 21 | img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect') 22 | if hasattr(self, 'net_g_ema'): 23 | self.net_g_ema.eval() 24 | with torch.no_grad(): 25 | self.output = self.net_g_ema(img) 26 | else: 27 | self.net_g.eval() 28 | with torch.no_grad(): 29 | self.output = self.net_g(img) 30 | self.net_g.train() 31 | 32 | _, _, h, w = self.output.size() 33 | self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale] 34 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/models/video_gan_model.py: -------------------------------------------------------------------------------- 1 | from basicsr.utils.registry import MODEL_REGISTRY 2 | from .srgan_model import SRGANModel 3 | from .video_base_model import VideoBaseModel 4 | 5 | 6 | @MODEL_REGISTRY.register() 7 | class VideoGANModel(SRGANModel, VideoBaseModel): 8 | """Video GAN model. 9 | 10 | Use multiple inheritance. 11 | It will first use the functions of :class:`SRGANModel`: 12 | 13 | - :func:`init_training_settings` 14 | - :func:`setup_optimizers` 15 | - :func:`optimize_parameters` 16 | - :func:`save` 17 | 18 | Then find functions in :class:`VideoBaseModel`. 
19 | """ 20 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenImagingLab/PhoCoLens/154fe32aea5c2b623f6dd1c07e90c2900d076486/NullSpaceDiff/basicsr/ops/__init__.py -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/ops/dcn/__init__.py: -------------------------------------------------------------------------------- 1 | from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv, 2 | modulated_deform_conv) 3 | 4 | __all__ = [ 5 | 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv', 6 | 'modulated_deform_conv' 7 | ] 8 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/ops/fused_act/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | 3 | __all__ = ['FusedLeakyReLU', 'fused_leaky_relu'] 4 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/ops/fused_act/fused_act.py: -------------------------------------------------------------------------------- 1 | # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501 2 | 3 | import os 4 | import torch 5 | from torch import nn 6 | from torch.autograd import Function 7 | 8 | BASICSR_JIT = os.getenv('BASICSR_JIT') 9 | if BASICSR_JIT == 'True': 10 | from torch.utils.cpp_extension import load 11 | module_path = os.path.dirname(__file__) 12 | fused_act_ext = load( 13 | 'fused', 14 | sources=[ 15 | os.path.join(module_path, 'src', 'fused_bias_act.cpp'), 16 | os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'), 17 | ], 18 | ) 19 | else: 20 | try: 21 | from . import fused_act_ext 22 | except ImportError: 23 | pass 24 | # avoid annoying print output 25 | # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n ' 26 | # '1. compile with BASICSR_EXT=True. or\n ' 27 | # '2. 
set BASICSR_JIT=True during running') 28 | 29 | 30 | class FusedLeakyReLUFunctionBackward(Function): 31 | 32 | @staticmethod 33 | def forward(ctx, grad_output, out, negative_slope, scale): 34 | ctx.save_for_backward(out) 35 | ctx.negative_slope = negative_slope 36 | ctx.scale = scale 37 | 38 | empty = grad_output.new_empty(0) 39 | 40 | grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) 41 | 42 | dim = [0] 43 | 44 | if grad_input.ndim > 2: 45 | dim += list(range(2, grad_input.ndim)) 46 | 47 | grad_bias = grad_input.sum(dim).detach() 48 | 49 | return grad_input, grad_bias 50 | 51 | @staticmethod 52 | def backward(ctx, gradgrad_input, gradgrad_bias): 53 | out, = ctx.saved_tensors 54 | gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, 55 | ctx.scale) 56 | 57 | return gradgrad_out, None, None, None 58 | 59 | 60 | class FusedLeakyReLUFunction(Function): 61 | 62 | @staticmethod 63 | def forward(ctx, input, bias, negative_slope, scale): 64 | empty = input.new_empty(0) 65 | out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 66 | ctx.save_for_backward(out) 67 | ctx.negative_slope = negative_slope 68 | ctx.scale = scale 69 | 70 | return out 71 | 72 | @staticmethod 73 | def backward(ctx, grad_output): 74 | out, = ctx.saved_tensors 75 | 76 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale) 77 | 78 | return grad_input, grad_bias, None, None 79 | 80 | 81 | class FusedLeakyReLU(nn.Module): 82 | 83 | def __init__(self, channel, negative_slope=0.2, scale=2**0.5): 84 | super().__init__() 85 | 86 | self.bias = nn.Parameter(torch.zeros(channel)) 87 | self.negative_slope = negative_slope 88 | self.scale = scale 89 | 90 | def forward(self, input): 91 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 92 | 93 | 94 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): 95 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 96 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/ops/upfirdn2d/__init__.py: -------------------------------------------------------------------------------- 1 | from .upfirdn2d import upfirdn2d 2 | 3 | __all__ = ['upfirdn2d'] 4 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | from os import path as osp 4 | 5 | from basicsr.data import build_dataloader, build_dataset 6 | from basicsr.models import build_model 7 | from basicsr.utils import get_env_info, get_root_logger, get_time_str, make_exp_dirs 8 | from basicsr.utils.options import dict2str, parse_options 9 | 10 | 11 | def test_pipeline(root_path): 12 | # parse options, set distributed setting, set ramdom seed 13 | opt, _ = parse_options(root_path, is_train=False) 14 | 15 | torch.backends.cudnn.benchmark = True 16 | # torch.backends.cudnn.deterministic = True 17 | 18 | # mkdir and initialize loggers 19 | make_exp_dirs(opt) 20 | log_file = osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log") 21 | logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) 22 | logger.info(get_env_info()) 23 | logger.info(dict2str(opt)) 24 | 25 | # create test dataset and dataloader 26 | test_loaders = [] 27 | for _, 
dataset_opt in sorted(opt['datasets'].items()): 28 | test_set = build_dataset(dataset_opt) 29 | test_loader = build_dataloader( 30 | test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']) 31 | logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}") 32 | test_loaders.append(test_loader) 33 | 34 | # create model 35 | model = build_model(opt) 36 | 37 | for test_loader in test_loaders: 38 | test_set_name = test_loader.dataset.opt['name'] 39 | logger.info(f'Testing {test_set_name}...') 40 | model.validation(test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img']) 41 | 42 | 43 | if __name__ == '__main__': 44 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 45 | test_pipeline(root_path) 46 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb 2 | from .diffjpeg import DiffJPEG 3 | from .file_client import FileClient 4 | from .img_process_util import USMSharp, usm_sharp 5 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img 6 | from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger 7 | from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt 8 | from .options import yaml_load 9 | 10 | __all__ = [ 11 | # color_util.py 12 | 'bgr2ycbcr', 13 | 'rgb2ycbcr', 14 | 'rgb2ycbcr_pt', 15 | 'ycbcr2bgr', 16 | 'ycbcr2rgb', 17 | # file_client.py 18 | 'FileClient', 19 | # img_util.py 20 | 'img2tensor', 21 | 'tensor2img', 22 | 'imfrombytes', 23 | 'imwrite', 24 | 'crop_border', 25 | # logger.py 26 | 'MessageLogger', 27 | 'AvgTimer', 28 | 'init_tb_logger', 29 | 'init_wandb_logger', 30 | 'get_root_logger', 31 | 'get_env_info', 32 | # misc.py 33 | 'set_random_seed', 34 | 'get_time_str', 35 | 'mkdir_and_rename', 36 | 'make_exp_dirs', 37 | 'scandir', 38 | 'check_resume', 39 | 'sizeof_fmt', 40 | # diffjpeg 41 | 'DiffJPEG', 42 | # img_process_util 43 | 'USMSharp', 44 | 'usm_sharp', 45 | # options 46 | 'yaml_load' 47 | ] 48 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 2 | import functools 3 | import os 4 | import subprocess 5 | import torch 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | 9 | 10 | def init_dist(launcher, backend='nccl', **kwargs): 11 | if mp.get_start_method(allow_none=True) is None: 12 | mp.set_start_method('spawn') 13 | if launcher == 'pytorch': 14 | _init_dist_pytorch(backend, **kwargs) 15 | elif launcher == 'slurm': 16 | _init_dist_slurm(backend, **kwargs) 17 | else: 18 | raise ValueError(f'Invalid launcher type: {launcher}') 19 | 20 | 21 | def _init_dist_pytorch(backend, **kwargs): 22 | rank = int(os.environ['RANK']) 23 | num_gpus = torch.cuda.device_count() 24 | torch.cuda.set_device(rank % num_gpus) 25 | dist.init_process_group(backend=backend, **kwargs) 26 | 27 | 28 | def _init_dist_slurm(backend, port=None): 29 | """Initialize slurm distributed training environment. 
30 | 31 | If argument ``port`` is not specified, then the master port will be system 32 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 33 | environment variable, then a default port ``29500`` will be used. 34 | 35 | Args: 36 | backend (str): Backend of torch.distributed. 37 | port (int, optional): Master port. Defaults to None. 38 | """ 39 | proc_id = int(os.environ['SLURM_PROCID']) 40 | ntasks = int(os.environ['SLURM_NTASKS']) 41 | node_list = os.environ['SLURM_NODELIST'] 42 | num_gpus = torch.cuda.device_count() 43 | torch.cuda.set_device(proc_id % num_gpus) 44 | addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') 45 | # specify master port 46 | if port is not None: 47 | os.environ['MASTER_PORT'] = str(port) 48 | elif 'MASTER_PORT' in os.environ: 49 | pass # use MASTER_PORT in the environment variable 50 | else: 51 | # 29500 is torch.distributed default port 52 | os.environ['MASTER_PORT'] = '29500' 53 | os.environ['MASTER_ADDR'] = addr 54 | os.environ['WORLD_SIZE'] = str(ntasks) 55 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 56 | os.environ['RANK'] = str(proc_id) 57 | dist.init_process_group(backend=backend) 58 | 59 | 60 | def get_dist_info(): 61 | if dist.is_available(): 62 | initialized = dist.is_initialized() 63 | else: 64 | initialized = False 65 | if initialized: 66 | rank = dist.get_rank() 67 | world_size = dist.get_world_size() 68 | else: 69 | rank = 0 70 | world_size = 1 71 | return rank, world_size 72 | 73 | 74 | def master_only(func): 75 | 76 | @functools.wraps(func) 77 | def wrapper(*args, **kwargs): 78 | rank, _ = get_dist_info() 79 | if rank == 0: 80 | return func(*args, **kwargs) 81 | 82 | return wrapper 83 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/utils/download_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import requests 4 | from torch.hub import download_url_to_file, get_dir 5 | from tqdm import tqdm 6 | from urllib.parse import urlparse 7 | 8 | from .misc import sizeof_fmt 9 | 10 | 11 | def download_file_from_google_drive(file_id, save_path): 12 | """Download files from google drive. 13 | 14 | Reference: https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive 15 | 16 | Args: 17 | file_id (str): File id. 18 | save_path (str): Save path. 
19 | """ 20 | 21 | session = requests.Session() 22 | URL = 'https://docs.google.com/uc?export=download' 23 | params = {'id': file_id} 24 | 25 | response = session.get(URL, params=params, stream=True) 26 | token = get_confirm_token(response) 27 | if token: 28 | params['confirm'] = token 29 | response = session.get(URL, params=params, stream=True) 30 | 31 | # get file size 32 | response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) 33 | if 'Content-Range' in response_file_size.headers: 34 | file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) 35 | else: 36 | file_size = None 37 | 38 | save_response_content(response, save_path, file_size) 39 | 40 | 41 | def get_confirm_token(response): 42 | for key, value in response.cookies.items(): 43 | if key.startswith('download_warning'): 44 | return value 45 | return None 46 | 47 | 48 | def save_response_content(response, destination, file_size=None, chunk_size=32768): 49 | if file_size is not None: 50 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') 51 | 52 | readable_file_size = sizeof_fmt(file_size) 53 | else: 54 | pbar = None 55 | 56 | with open(destination, 'wb') as f: 57 | downloaded_size = 0 58 | for chunk in response.iter_content(chunk_size): 59 | downloaded_size += chunk_size 60 | if pbar is not None: 61 | pbar.update(1) 62 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') 63 | if chunk: # filter out keep-alive new chunks 64 | f.write(chunk) 65 | if pbar is not None: 66 | pbar.close() 67 | 68 | 69 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 70 | """Load file form http url, will download models if necessary. 71 | 72 | Reference: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 73 | 74 | Args: 75 | url (str): URL to be downloaded. 76 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. 77 | Default: None. 78 | progress (bool): Whether to show the download progress. Default: True. 79 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. 80 | 81 | Returns: 82 | str: The path to the downloaded file. 
83 | """ 84 | if model_dir is None: # use the pytorch hub_dir 85 | hub_dir = get_dir() 86 | model_dir = os.path.join(hub_dir, 'checkpoints') 87 | 88 | os.makedirs(model_dir, exist_ok=True) 89 | 90 | parts = urlparse(url) 91 | filename = os.path.basename(parts.path) 92 | if file_name is not None: 93 | filename = file_name 94 | cached_file = os.path.abspath(os.path.join(model_dir, filename)) 95 | if not os.path.exists(cached_file): 96 | print(f'Downloading: "{url}" to {cached_file}\n') 97 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 98 | return cached_file 99 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/utils/img_process_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | def filter2D(img, kernel): 8 | """PyTorch version of cv2.filter2D 9 | 10 | Args: 11 | img (Tensor): (b, c, h, w) 12 | kernel (Tensor): (b, k, k) 13 | """ 14 | k = kernel.size(-1) 15 | b, c, h, w = img.size() 16 | if k % 2 == 1: 17 | img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect') 18 | else: 19 | raise ValueError('Wrong kernel size') 20 | 21 | ph, pw = img.size()[-2:] 22 | 23 | if kernel.size(0) == 1: 24 | # apply the same kernel to all batch images 25 | img = img.view(b * c, 1, ph, pw) 26 | kernel = kernel.view(1, 1, k, k) 27 | return F.conv2d(img, kernel, padding=0).view(b, c, h, w) 28 | else: 29 | img = img.view(1, b * c, ph, pw) 30 | kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k) 31 | return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w) 32 | 33 | 34 | def usm_sharp(img, weight=0.5, radius=50, threshold=10): 35 | """USM sharpening. 36 | 37 | Input image: I; Blurry image: B. 38 | 1. sharp = I + weight * (I - B) 39 | 2. Mask = 1 if abs(I - B) > threshold, else: 0 40 | 3. Blur mask: 41 | 4. Out = Mask * sharp + (1 - Mask) * I 42 | 43 | 44 | Args: 45 | img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. 46 | weight (float): Sharp weight. Default: 1. 47 | radius (float): Kernel size of Gaussian blur. Default: 50. 
48 | threshold (int): 49 | """ 50 | if radius % 2 == 0: 51 | radius += 1 52 | blur = cv2.GaussianBlur(img, (radius, radius), 0) 53 | residual = img - blur 54 | mask = np.abs(residual) * 255 > threshold 55 | mask = mask.astype('float32') 56 | soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) 57 | 58 | sharp = img + weight * residual 59 | sharp = np.clip(sharp, 0, 1) 60 | return soft_mask * sharp + (1 - soft_mask) * img 61 | 62 | 63 | class USMSharp(torch.nn.Module): 64 | 65 | def __init__(self, radius=50, sigma=0): 66 | super(USMSharp, self).__init__() 67 | if radius % 2 == 0: 68 | radius += 1 69 | self.radius = radius 70 | kernel = cv2.getGaussianKernel(radius, sigma) 71 | kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0) 72 | self.register_buffer('kernel', kernel) 73 | 74 | def forward(self, img, weight=0.5, threshold=10): 75 | blur = filter2D(img, self.kernel) 76 | residual = img - blur 77 | 78 | mask = torch.abs(residual) * 255 > threshold 79 | mask = mask.float() 80 | soft_mask = filter2D(mask, self.kernel) 81 | sharp = img + weight * residual 82 | sharp = torch.clip(sharp, 0, 1) 83 | return soft_mask * sharp + (1 - soft_mask) * img 84 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/utils/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import random 4 | import time 5 | import torch 6 | from os import path as osp 7 | 8 | from .dist_util import master_only 9 | 10 | 11 | def set_random_seed(seed): 12 | """Set random seeds.""" 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | 19 | 20 | def get_time_str(): 21 | return time.strftime('%Y%m%d_%H%M%S', time.localtime()) 22 | 23 | 24 | def mkdir_and_rename(path): 25 | """mkdirs. If path exists, rename it with timestamp and create a new one. 26 | 27 | Args: 28 | path (str): Folder path. 29 | """ 30 | if osp.exists(path): 31 | new_name = path + '_archived_' + get_time_str() 32 | print(f'Path already exists. Rename it to {new_name}', flush=True) 33 | os.rename(path, new_name) 34 | os.makedirs(path, exist_ok=True) 35 | 36 | 37 | @master_only 38 | def make_exp_dirs(opt): 39 | """Make dirs for experiments.""" 40 | path_opt = opt['path'].copy() 41 | if opt['is_train']: 42 | mkdir_and_rename(path_opt.pop('experiments_root')) 43 | else: 44 | mkdir_and_rename(path_opt.pop('results_root')) 45 | for key, path in path_opt.items(): 46 | if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key): 47 | continue 48 | else: 49 | os.makedirs(path, exist_ok=True) 50 | 51 | 52 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 53 | """Scan a directory to find the interested files. 54 | 55 | Args: 56 | dir_path (str): Path of the directory. 57 | suffix (str | tuple(str), optional): File suffix that we are 58 | interested in. Default: None. 59 | recursive (bool, optional): If set to True, recursively scan the 60 | directory. Default: False. 61 | full_path (bool, optional): If set to True, include the dir_path. 62 | Default: False. 63 | 64 | Returns: 65 | A generator for all the interested files with relative paths. 
66 | """ 67 | 68 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 69 | raise TypeError('"suffix" must be a string or tuple of strings') 70 | 71 | root = dir_path 72 | 73 | def _scandir(dir_path, suffix, recursive): 74 | for entry in os.scandir(dir_path): 75 | if not entry.name.startswith('.') and entry.is_file(): 76 | if full_path: 77 | return_path = entry.path 78 | else: 79 | return_path = osp.relpath(entry.path, root) 80 | 81 | if suffix is None: 82 | yield return_path 83 | elif return_path.endswith(suffix): 84 | yield return_path 85 | else: 86 | if recursive: 87 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive) 88 | else: 89 | continue 90 | 91 | return _scandir(dir_path, suffix=suffix, recursive=recursive) 92 | 93 | 94 | def check_resume(opt, resume_iter): 95 | """Check resume states and pretrain_network paths. 96 | 97 | Args: 98 | opt (dict): Options. 99 | resume_iter (int): Resume iteration. 100 | """ 101 | if opt['path']['resume_state']: 102 | # get all the networks 103 | networks = [key for key in opt.keys() if key.startswith('network_')] 104 | flag_pretrain = False 105 | for network in networks: 106 | if opt['path'].get(f'pretrain_{network}') is not None: 107 | flag_pretrain = True 108 | if flag_pretrain: 109 | print('pretrain_network path will be ignored during resuming.') 110 | # set pretrained model paths 111 | for network in networks: 112 | name = f'pretrain_{network}' 113 | basename = network.replace('network_', '') 114 | if opt['path'].get('ignore_resume_networks') is None or (network 115 | not in opt['path']['ignore_resume_networks']): 116 | opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth') 117 | print(f"Set {name} to {opt['path'][name]}") 118 | 119 | # change param_key to params in resume 120 | param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')] 121 | for param_key in param_keys: 122 | if opt['path'][param_key] == 'params_ema': 123 | opt['path'][param_key] = 'params' 124 | print(f'Set {param_key} to params') 125 | 126 | 127 | def sizeof_fmt(size, suffix='B'): 128 | """Get human readable file size. 129 | 130 | Args: 131 | size (int): File size. 132 | suffix (str): Suffix. Default: 'B'. 133 | 134 | Return: 135 | str: Formatted file size. 136 | """ 137 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: 138 | if abs(size) < 1024.0: 139 | return f'{size:3.1f} {unit}{suffix}' 140 | size /= 1024.0 141 | return f'{size:3.1f} Y{suffix}' 142 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/utils/plot_util.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def read_data_from_tensorboard(log_path, tag): 5 | """Get raw data (steps and values) from tensorboard events. 6 | 7 | Args: 8 | log_path (str): Path to the tensorboard log. 9 | tag (str): tag to be read. 10 | """ 11 | from tensorboard.backend.event_processing.event_accumulator import EventAccumulator 12 | 13 | # tensorboard event 14 | event_acc = EventAccumulator(log_path) 15 | event_acc.Reload() 16 | scalar_list = event_acc.Tags()['scalars'] 17 | print('tag list: ', scalar_list) 18 | steps = [int(s.step) for s in event_acc.Scalars(tag)] 19 | values = [s.value for s in event_acc.Scalars(tag)] 20 | return steps, values 21 | 22 | 23 | def read_data_from_txt_2v(path, pattern, step_one=False): 24 | """Read data from txt with 2 returned values (usually [step, value]). 
25 | 26 | Args: 27 | path (str): path to the txt file. 28 | pattern (str): re (regular expression) pattern. 29 | step_one (bool): add 1 to steps. Default: False. 30 | """ 31 | with open(path) as f: 32 | lines = f.readlines() 33 | lines = [line.strip() for line in lines] 34 | steps = [] 35 | values = [] 36 | 37 | pattern = re.compile(pattern) 38 | for line in lines: 39 | match = pattern.match(line) 40 | if match: 41 | steps.append(int(match.group(1))) 42 | values.append(float(match.group(2))) 43 | if step_one: 44 | steps = [v + 1 for v in steps] 45 | return steps, values 46 | 47 | 48 | def read_data_from_txt_1v(path, pattern): 49 | """Read data from txt with 1 returned values. 50 | 51 | Args: 52 | path (str): path to the txt file. 53 | pattern (str): re (regular expression) pattern. 54 | """ 55 | with open(path) as f: 56 | lines = f.readlines() 57 | lines = [line.strip() for line in lines] 58 | data = [] 59 | 60 | pattern = re.compile(pattern) 61 | for line in lines: 62 | match = pattern.match(line) 63 | if match: 64 | data.append(float(match.group(1))) 65 | return data 66 | 67 | 68 | def smooth_data(values, smooth_weight): 69 | """ Smooth data using 1st-order IIR low-pass filter (what tensorflow does). 70 | 71 | Reference: https://github.com/tensorflow/tensorboard/blob/f801ebf1f9fbfe2baee1ddd65714d0bccc640fb1/tensorboard/plugins/scalar/vz_line_chart/vz-line-chart.ts#L704 # noqa: E501 72 | 73 | Args: 74 | values (list): A list of values to be smoothed. 75 | smooth_weight (float): Smooth weight. 76 | """ 77 | values_sm = [] 78 | last_sm_value = values[0] 79 | for value in values: 80 | value_sm = last_sm_value * smooth_weight + (1 - smooth_weight) * value 81 | values_sm.append(value_sm) 82 | last_sm_value = value_sm 83 | return values_sm 84 | -------------------------------------------------------------------------------- /NullSpaceDiff/basicsr/utils/registry.py: -------------------------------------------------------------------------------- 1 | # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 2 | 3 | 4 | class Registry(): 5 | """ 6 | The registry that provides name -> object mapping, to support third-party 7 | users' custom modules. 8 | 9 | To create a registry (e.g. a backbone registry): 10 | 11 | .. code-block:: python 12 | 13 | BACKBONE_REGISTRY = Registry('BACKBONE') 14 | 15 | To register an object: 16 | 17 | .. code-block:: python 18 | 19 | @BACKBONE_REGISTRY.register() 20 | class MyBackbone(): 21 | ... 22 | 23 | Or: 24 | 25 | .. code-block:: python 26 | 27 | BACKBONE_REGISTRY.register(MyBackbone) 28 | """ 29 | 30 | def __init__(self, name): 31 | """ 32 | Args: 33 | name (str): the name of this registry 34 | """ 35 | self._name = name 36 | self._obj_map = {} 37 | 38 | def _do_register(self, name, obj, suffix=None): 39 | if isinstance(suffix, str): 40 | name = name + '_' + suffix 41 | 42 | assert (name not in self._obj_map), (f"An object named '{name}' was already registered " 43 | f"in '{self._name}' registry!") 44 | self._obj_map[name] = obj 45 | 46 | def register(self, obj=None, suffix=None): 47 | """ 48 | Register the given object under the the name `obj.__name__`. 49 | Can be used as either a decorator or not. 50 | See docstring of this class for usage. 
51 | """ 52 | if obj is None: 53 | # used as a decorator 54 | def deco(func_or_class): 55 | name = func_or_class.__name__ 56 | self._do_register(name, func_or_class, suffix) 57 | return func_or_class 58 | 59 | return deco 60 | 61 | # used as a function call 62 | name = obj.__name__ 63 | self._do_register(name, obj, suffix) 64 | 65 | def get(self, name, suffix='basicsr'): 66 | ret = self._obj_map.get(name) 67 | if ret is None: 68 | ret = self._obj_map.get(name + '_' + suffix) 69 | print(f'Name {name} is not found, use name: {name}_{suffix}!') 70 | if ret is None: 71 | raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") 72 | return ret 73 | 74 | def __contains__(self, name): 75 | return name in self._obj_map 76 | 77 | def __iter__(self): 78 | return iter(self._obj_map.items()) 79 | 80 | def keys(self): 81 | return self._obj_map.keys() 82 | 83 | 84 | DATASET_REGISTRY = Registry('dataset') 85 | ARCH_REGISTRY = Registry('arch') 86 | MODEL_REGISTRY = Registry('model') 87 | LOSS_REGISTRY = Registry('loss') 88 | METRIC_REGISTRY = Registry('metric') 89 | -------------------------------------------------------------------------------- /NullSpaceDiff/cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | gpu: true 6 | system_packages: 7 | - "libgl1-mesa-glx" 8 | - "libglib2.0-0" 9 | python_version: "3.11" 10 | python_packages: 11 | - "torch==2.0.1" 12 | - "torchvision==0.15.2" 13 | - "numpy==1.25.1" 14 | - "opencv-python==4.8.0.74" 15 | - "imageio==2.31.1" 16 | - "omegaconf==2.3.0" 17 | - "transformers==4.31.0" 18 | - "torchmetrics==0.7.0" 19 | - "open_clip_torch==2.0.2" 20 | - "einops==0.6.1" 21 | - "pytorch_lightning==1.7.7" 22 | - "scipy==1.11.1" 23 | - "scikit-image==0.21.0" 24 | - "matplotlib==3.7.2" 25 | - "scikit-learn==1.3.0" 26 | - "kornia==0.6.12" 27 | - "xformers==0.0.20" 28 | - "clip @ git+https://github.com/openai/CLIP.git" 29 | run: 30 | - pip install git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 31 | - mkdir -p /root/.cache/torch/hub/checkpoints && wget --output-document "/root/.cache/torch/hub/checkpoints/vgg16-397923af.pth" "https://download.pytorch.org/models/vgg16-397923af.pth" 32 | predict: "predict.py:Predictor" 33 | -------------------------------------------------------------------------------- /NullSpaceDiff/configs/NullSpaceDiff/phlatcam_decoded_sim_multi_T_512.yaml: -------------------------------------------------------------------------------- 1 | sf: 4 2 | model: 3 | base_learning_rate: 5.0e-05 4 | target: ldm.models.diffusion.ddpm.LatentDiffusionSRTextWTFFHQ 5 | params: 6 | parameterization: "v" 7 | linear_start: 0.00085 8 | linear_end: 0.0120 9 | num_timesteps_cond: 1 10 | log_every_t: 200 11 | timesteps: 1000 12 | first_stage_key: image 13 | cond_stage_key: caption 14 | image_size: 512 15 | channels: 4 16 | cond_stage_trainable: False # Note: different from the one we trained before 17 | conditioning_key: crossattn 18 | monitor: val/loss_simple_ema 19 | scale_factor: 0.18215 20 | use_ema: False 21 | # for training only 22 | ckpt_path: ckpts/stablesr_000117.ckpt 23 | unfrozen_diff: False 24 | random_size: False 25 | time_replace: 1000 26 | use_usm: False 27 | # test_gt: True 28 | #P2 weighting 29 | p2_gamma: ~ 30 | p2_k: ~ 31 | 32 | unet_config: 33 | target: ldm.modules.diffusionmodules.openaimodel.UNetModelDualcondV2 34 | params: 35 | image_size: 
32 # unused 36 | in_channels: 4 37 | out_channels: 4 38 | model_channels: 320 39 | attention_resolutions: [ 4, 2, 1 ] 40 | num_res_blocks: 2 41 | channel_mult: [ 1, 2, 4, 4 ] 42 | num_head_channels: 64 43 | use_spatial_transformer: True 44 | use_linear_in_transformer: True 45 | transformer_depth: 1 46 | context_dim: 1024 47 | use_checkpoint: False 48 | legacy: False 49 | semb_channels: 256 50 | 51 | first_stage_config: 52 | target: ldm.models.autoencoder.AutoencoderKL 53 | params: 54 | # for training only 55 | ckpt_path: ckpts/stablesr_000117.ckpt 56 | embed_dim: 4 57 | monitor: val/rec_loss 58 | ddconfig: 59 | double_z: true 60 | z_channels: 4 61 | resolution: 512 62 | in_channels: 3 63 | out_ch: 3 64 | ch: 128 65 | ch_mult: 66 | - 1 67 | - 2 68 | - 4 69 | - 4 70 | num_res_blocks: 2 71 | attn_resolutions: [] 72 | dropout: 0.0 73 | lossconfig: 74 | target: torch.nn.Identity 75 | 76 | cond_stage_config: 77 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 78 | params: 79 | freeze: True 80 | layer: "penultimate" 81 | 82 | structcond_stage_config: 83 | target: ldm.modules.diffusionmodules.openaimodel.EncoderUNetModelWT 84 | params: 85 | image_size: 96 86 | in_channels: 4 87 | model_channels: 256 88 | out_channels: 256 89 | num_res_blocks: 2 90 | attention_resolutions: [ 4, 2, 1 ] 91 | dropout: 0 92 | channel_mult: [ 1, 1, 2, 2 ] 93 | conv_resample: True 94 | dims: 2 95 | use_checkpoint: False 96 | use_fp16: False 97 | num_heads: 4 98 | num_head_channels: -1 99 | num_heads_upsample: -1 100 | use_scale_shift_norm: False 101 | resblock_updown: False 102 | use_new_attention_order: False 103 | 104 | data: 105 | target: main.DataModuleFromConfig 106 | params: 107 | batch_size: 5 108 | num_workers: 5 109 | wrap: false 110 | train: 111 | target: basicsr.data.paired_image_dataset.PairedImageDataset 112 | params: 113 | dataroot_gt: data/fft-svd-1280-1408-meas-decoded_sim_spatial_weight/train/gts_512 114 | dataroot_lq: data/fft-svd-1280-1408-meas-decoded_sim_spatial_weight/train/inputs_512 115 | io_backend: 116 | type: disk 117 | phase: train 118 | gt_size: 512 119 | scale: 1 120 | use_rot: true 121 | use_hflip: true 122 | validation: 123 | target: basicsr.data.paired_image_dataset.PairedImageDataset 124 | params: 125 | dataroot_gt: data/fft-svd-1280-1408-meas-decoded_sim_spatial_weight/val/gts_512 126 | dataroot_lq: data/fft-svd-1280-1408-meas-decoded_sim_spatial_weight/val/inputs_512 127 | io_backend: 128 | type: disk 129 | phase: val 130 | gt_size: 512 131 | scale: 1 132 | use_rot: false 133 | use_hflip: false 134 | 135 | test_data: 136 | target: basicsr.data.paired_image_dataset.PairedImageDataset 137 | params: 138 | dataroot_gt: data/fft-svd-1280-1408-meas-decoded_sim_spatial_weight/val/gts_512 139 | dataroot_lq: data/fft-svd-1280-1408-meas-decoded_sim_spatial_weight/val/inputs_512 140 | io_backend: 141 | type: disk 142 | phase: val 143 | gt_size: 512 144 | scale: 1 145 | use_rot: false 146 | use_hflip: false 147 | 148 | lightning: 149 | modelcheckpoint: 150 | params: 151 | every_n_train_steps: 100 152 | callbacks: 153 | image_logger: 154 | target: main.ImageLogger 155 | params: 156 | batch_frequency: 1000 157 | max_images: 5 158 | log_on_batch_idx: True 159 | increase_log_steps: False 160 | 161 | trainer: 162 | benchmark: True 163 | max_steps: 2000000 164 | accumulate_grad_batches: 2 165 | val_check_interval: 1.0 166 | 167 | 168 | 169 | -------------------------------------------------------------------------------- /NullSpaceDiff/configs/autoencoder/autoencoder_kl_64x64x4_resi.yaml: 
-------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-5 3 | target: ldm.models.autoencoder.AutoencoderKLResi 4 | params: 5 | # for training only 6 | # ckpt_path: /mnt/lustre/jywang/code/stable_diffmodels/v2-1_512-ema-pruned.ckpt 7 | monitor: "val/rec_loss" 8 | embed_dim: 4 9 | fusion_w: 1.0 10 | freeze_dec: True 11 | synthesis_data: False 12 | lossconfig: 13 | target: ldm.modules.losses.LPIPSWithDiscriminator 14 | params: 15 | disc_start: 501 16 | kl_weight: 0 17 | disc_weight: 0.025 18 | disc_factor: 1.0 19 | 20 | ddconfig: 21 | double_z: true 22 | z_channels: 4 23 | resolution: 512 24 | in_channels: 3 25 | out_ch: 3 26 | ch: 128 27 | ch_mult: 28 | - 1 29 | - 2 30 | - 4 31 | - 4 32 | num_res_blocks: 2 33 | attn_resolutions: [] 34 | dropout: 0.0 35 | 36 | image_key: 'gt' 37 | 38 | 39 | data: 40 | target: main.DataModuleFromConfig 41 | params: 42 | batch_size: 1 43 | num_workers: 6 44 | wrap: True 45 | train: 46 | target: basicsr.data.single_image_dataset.SingleImageNPDataset 47 | params: 48 | gt_path: ['/mnt/lustre/share/jywang/ddpm_data/CFW_trainingdata/'] 49 | io_backend: 50 | type: disk 51 | validation: 52 | target: basicsr.data.single_image_dataset.SingleImageNPDataset 53 | params: 54 | gt_path: ['/mnt/lustre/share/jywang/ddpm_data/CFW_trainingdata/'] 55 | io_backend: 56 | type: disk 57 | 58 | lightning: 59 | modelcheckpoint: 60 | params: 61 | every_n_train_steps: 1500 62 | callbacks: 63 | image_logger: 64 | target: main.ImageLogger 65 | params: 66 | batch_frequency: 1500 67 | max_images: 4 68 | increase_log_steps: False 69 | 70 | trainer: 71 | benchmark: True 72 | max_steps: 800000 73 | accumulate_grad_batches: 8 74 | -------------------------------------------------------------------------------- /NullSpaceDiff/configs/autoencoder/autoencoder_kl_64x64x4_resi_face.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-5 3 | target: ldm.models.autoencoder.AutoencoderKLResi 4 | params: 5 | # for training only 6 | # ckpt_path: vqgan_finetune_00011.ckpt 7 | monitor: "val/rec_loss" 8 | embed_dim: 4 9 | fusion_w: 1.0 10 | freeze_dec: True 11 | synthesis_data: False 12 | lossconfig: 13 | target: ldm.modules.losses.LPIPSWithDiscriminator 14 | params: 15 | disc_start: 501 16 | kl_weight: 0 17 | disc_weight: 0.025 18 | disc_factor: 1.0 19 | 20 | ddconfig: 21 | double_z: true 22 | z_channels: 4 23 | resolution: 512 24 | in_channels: 3 25 | out_ch: 3 26 | ch: 128 27 | ch_mult: 28 | - 1 29 | - 2 30 | - 4 31 | - 4 32 | num_res_blocks: 2 33 | attn_resolutions: [] 34 | dropout: 0.0 35 | 36 | image_key: 'gt' 37 | 38 | 39 | data: 40 | target: main.DataModuleFromConfig 41 | params: 42 | batch_size: 1 43 | num_workers: 6 44 | wrap: True 45 | train: 46 | target: basicsr.data.single_image_dataset.SingleImageNPDataset 47 | params: 48 | gt_path: ['/mnt/lustre/jywang/code/ddpm_face_data/v2-T200-multistep/'] 49 | io_backend: 50 | type: disk 51 | validation: 52 | target: basicsr.data.single_image_dataset.SingleImageNPDataset 53 | params: 54 | gt_path: ['/mnt/lustre/jywang/code/ddpm_face_data/v2-T200-multistep/'] 55 | io_backend: 56 | type: disk 57 | 58 | lightning: 59 | modelcheckpoint: 60 | params: 61 | every_n_train_steps: 1500 62 | callbacks: 63 | image_logger: 64 | target: main.ImageLogger 65 | params: 66 | batch_frequency: 1500 67 | max_images: 4 68 | increase_log_steps: False 69 | 70 | trainer: 71 | benchmark: True 72 | max_steps: 800000 73 | accumulate_grad_batches: 8 
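The `target:`/`params:` blocks in the two autoencoder configs above follow the LDM convention of naming a class together with its constructor kwargs. Below is a minimal sketch of how such a node is typically resolved (assuming an OmegaConf-loaded file and a helper in the spirit of `ldm.util.instantiate_from_config`; the helper names and the commented path are illustrative, not the repository's exact code):

import importlib

from omegaconf import OmegaConf


def get_obj_from_str(name):
    # "pkg.module.Class" -> the Class object
    module, cls = name.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)


def instantiate_from_config(node):
    # Each config node names a class via `target` and its kwargs via `params`.
    return get_obj_from_str(node["target"])(**node.get("params", dict()))


# Illustrative usage:
# cfg = OmegaConf.load("configs/autoencoder/autoencoder_kl_64x64x4_resi.yaml")
# autoencoder = instantiate_from_config(cfg.model)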
74 | -------------------------------------------------------------------------------- /NullSpaceDiff/environment.yaml: -------------------------------------------------------------------------------- 1 | name: phocolens 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.9 7 | - pip=20.3 8 | - cudatoolkit=11.3 9 | - pytorch=1.12.1 10 | - torchvision=0.13.1 11 | - numpy=1.23.1 12 | - pip: 13 | - albumentations==1.3.0 14 | - opencv-python==4.6.0.66 15 | - imageio==2.9.0 16 | - imageio-ffmpeg==0.4.2 17 | - pytorch-lightning==1.4.2 18 | - omegaconf==2.1.1 19 | - test-tube>=0.7.5 20 | - streamlit==1.12.1 21 | - einops==0.3.0 22 | - transformers==4.19.2 23 | - webdataset==0.2.5 24 | - kornia==0.6 25 | - open_clip_torch==2.0.2 26 | - invisible-watermark>=0.1.5 27 | - streamlit-drawable-canvas==0.8.0 28 | - torchmetrics==0.6.0 29 | - triton 30 | - matplotlib 31 | - wandb 32 | - pillow 33 | - scikit-image 34 | - lpips 35 | - sacred 36 | - waveprop 37 | - recordclass 38 | -------------------------------------------------------------------------------- /NullSpaceDiff/ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenImagingLab/PhoCoLens/154fe32aea5c2b623f6dd1c07e90c2900d076486/NullSpaceDiff/ldm/data/__init__.py -------------------------------------------------------------------------------- /NullSpaceDiff/ldm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass -------------------------------------------------------------------------------- /NullSpaceDiff/ldm/data/lsun.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | 9 | class LSUNBase(Dataset): 10 | def __init__(self, 11 | txt_file, 12 | data_root, 13 | size=None, 14 | interpolation="bicubic", 15 | flip_p=0.5 16 | ): 17 | self.data_paths = txt_file 18 | self.data_root = data_root 19 | with open(self.data_paths, "r") as f: 20 | self.image_paths = f.read().splitlines() 21 | self._length = len(self.image_paths) 22 | self.labels = { 23 | "relative_file_path_": [l for l in self.image_paths], 24 | "file_path_": [os.path.join(self.data_root, l) 25 | for l in self.image_paths], 26 | } 27 | 28 | self.size = size 29 | self.interpolation = {"linear": PIL.Image.LINEAR, 30 | "bilinear": PIL.Image.BILINEAR, 31 | "bicubic": PIL.Image.BICUBIC, 32 | "lanczos": PIL.Image.LANCZOS, 33 | }[interpolation] 34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 35 | 36 | def __len__(self): 37 | return self._length 38 | 39 | def __getitem__(self, i): 40 | example = dict((k, self.labels[k][i]) for k in self.labels) 41 | image = 
Image.open(example["file_path_"]) 42 | if not image.mode == "RGB": 43 | image = image.convert("RGB") 44 | 45 | # default to score-sde preprocessing 46 | img = np.array(image).astype(np.uint8) 47 | crop = min(img.shape[0], img.shape[1]) 48 | h, w, = img.shape[0], img.shape[1] 49 | img = img[(h - crop) // 2:(h + crop) // 2, 50 | (w - crop) // 2:(w + crop) // 2] 51 | 52 | image = Image.fromarray(img) 53 | if self.size is not None: 54 | image = image.resize((self.size, self.size), resample=self.interpolation) 55 | 56 | image = self.flip(image) 57 | image = np.array(image).astype(np.uint8) 58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 59 | return example 60 | 61 | 62 | class LSUNChurchesTrain(LSUNBase): 63 | def __init__(self, **kwargs): 64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) 65 | 66 | 67 | class LSUNChurchesValidation(LSUNBase): 68 | def __init__(self, flip_p=0., **kwargs): 69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", 70 | flip_p=flip_p, **kwargs) 71 | 72 | 73 | class LSUNBedroomsTrain(LSUNBase): 74 | def __init__(self, **kwargs): 75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) 76 | 77 | 78 | class LSUNBedroomsValidation(LSUNBase): 79 | def __init__(self, flip_p=0.0, **kwargs): 80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", 81 | flip_p=flip_p, **kwargs) 82 | 83 | 84 | class LSUNCatsTrain(LSUNBase): 85 | def __init__(self, **kwargs): 86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) 87 | 88 | 89 | class LSUNCatsValidation(LSUNBase): 90 | def __init__(self, flip_p=0., **kwargs): 91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", 92 | flip_p=flip_p, **kwargs) 93 | -------------------------------------------------------------------------------- /NullSpaceDiff/ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 
40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /NullSpaceDiff/ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenImagingLab/PhoCoLens/154fe32aea5c2b623f6dd1c07e90c2900d076486/NullSpaceDiff/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /NullSpaceDiff/ldm/models/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | # from .gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 12 | 13 | For example, if there's 300 timesteps and the section counts are [10,15,20] 14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 15 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 16 | 17 | If the stride is a string starting with "ddim", then the fixed striding 18 | from the DDIM paper is used, and only one section is allowed. 
19 | 20 | :param num_timesteps: the number of diffusion steps in the original 21 | process to divide up. 22 | :param section_counts: either a list of numbers, or a string containing 23 | comma-separated numbers, indicating the step count 24 | per section. As a special case, use "ddimN" where N 25 | is a number of steps to use the striding from the 26 | DDIM paper. 27 | :return: a set of diffusion steps from the original process to use. 28 | """ 29 | if isinstance(section_counts, str): 30 | if section_counts.startswith("ddim"): 31 | desired_count = int(section_counts[len("ddim"):]) 32 | for i in range(1, num_timesteps): 33 | if len(range(0, num_timesteps, i)) == desired_count: 34 | return set(range(0, num_timesteps, i)) 35 | raise ValueError( 36 | f"cannot create exactly {num_timesteps} steps with an integer stride" 37 | ) 38 | section_counts = [int(x) for x in section_counts.split(",")] #[250,] 39 | size_per = num_timesteps // len(section_counts) 40 | extra = num_timesteps % len(section_counts) 41 | start_idx = 0 42 | all_steps = [] 43 | for i, section_count in enumerate(section_counts): 44 | size = size_per + (1 if i < extra else 0) 45 | if size < section_count: 46 | raise ValueError( 47 | f"cannot divide section of {size} steps into {section_count}" 48 | ) 49 | if section_count <= 1: 50 | frac_stride = 1 51 | else: 52 | frac_stride = (size - 1) / (section_count - 1) 53 | cur_idx = 0.0 54 | taken_steps = [] 55 | for _ in range(section_count): 56 | taken_steps.append(start_idx + round(cur_idx)) 57 | cur_idx += frac_stride 58 | all_steps += taken_steps 59 | start_idx += size 60 | return set(all_steps) 61 | 62 | # class SpacedDiffusion(GaussianDiffusion): 63 | # """ 64 | # A diffusion process which can skip steps in a base diffusion process. 65 | # 66 | # :param use_timesteps: a collection (sequence or set) of timesteps from the 67 | # original diffusion process to retain. 68 | # :param kwargs: the kwargs to create the base diffusion process. 69 | # """ 70 | # 71 | # def __init__(self, use_timesteps, **kwargs): 72 | # self.use_timesteps = set(use_timesteps) 73 | # self.timestep_map = [] 74 | # self.original_num_steps = len(kwargs["betas"]) 75 | # 76 | # base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 77 | # last_alpha_cumprod = 1.0 78 | # new_betas = [] 79 | # for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 80 | # if i in self.use_timesteps: 81 | # new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 82 | # last_alpha_cumprod = alpha_cumprod 83 | # self.timestep_map.append(i) 84 | # kwargs["betas"] = np.array(new_betas) 85 | # super().__init__(**kwargs) 86 | # 87 | # def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs 88 | # return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 89 | # 90 | # def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs 91 | # return super().training_losses(self._wrap_model(model), *args, **kwargs) 92 | # 93 | # def _wrap_model(self, model): 94 | # if isinstance(model, _WrappedModel): 95 | # return model 96 | # return _WrappedModel( 97 | # model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 98 | # ) 99 | # 100 | # def _scale_timesteps(self, t): 101 | # # Scaling is done by the wrapped model. 
102 | # return t 103 | 104 | class _WrappedModel: 105 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 106 | self.model = model 107 | self.timestep_map = timestep_map 108 | self.rescale_timesteps = rescale_timesteps 109 | self.original_num_steps = original_num_steps 110 | 111 | def __call__(self, x, ts, **kwargs): 112 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 113 | new_ts = map_tensor[ts] 114 | if self.rescale_timesteps: 115 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 116 | return self.model(x, new_ts, **kwargs) 117 | -------------------------------------------------------------------------------- /NullSpaceDiff/ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenImagingLab/PhoCoLens/154fe32aea5c2b623f6dd1c07e90c2900d076486/NullSpaceDiff/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /NullSpaceDiff/ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenImagingLab/PhoCoLens/154fe32aea5c2b623f6dd1c07e90c2900d076486/NullSpaceDiff/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /NullSpaceDiff/ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def sample_deterministic(self): 40 | x = self.mean 41 | return x 42 | 43 | def kl(self, other=None): 44 | if self.deterministic: 45 | return torch.Tensor([0.]) 46 | else: 47 | if other is None: 48 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 49 | + self.var - 1.0 - self.logvar, 50 | dim=[1, 2, 3]) 51 | else: 52 | return 0.5 * torch.sum( 53 | torch.pow(self.mean - other.mean, 2) / other.var 54 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 55 | dim=[1, 2, 3]) 56 | 57 | def nll(self, sample, dims=[1,2,3]): 58 | if self.deterministic: 59 | return torch.Tensor([0.]) 60 | logtwopi = np.log(2.0 * np.pi) 61 | return 0.5 * torch.sum( 62 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 63 | dim=dims) 64 | 65 | def mode(self): 66 | return self.mean 67 | 68 | 69 | def normal_kl(mean1, logvar1, mean2, logvar2): 70 | """ 71 | source: 
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 72 | Compute the KL divergence between two gaussians. 73 | Shapes are automatically broadcasted, so batches can be compared to 74 | scalars, among other use cases. 75 | """ 76 | tensor = None 77 | for obj in (mean1, logvar1, mean2, logvar2): 78 | if isinstance(obj, torch.Tensor): 79 | tensor = obj 80 | break 81 | assert tensor is not None, "at least one argument must be a Tensor" 82 | 83 | # Force variances to be Tensors. Broadcasting helps convert scalars to 84 | # Tensors, but it does not work for torch.exp(). 85 | logvar1, logvar2 = [ 86 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 87 | for x in (logvar1, logvar2) 88 | ] 89 | 90 | return 0.5 * ( 91 | -1.0 92 | + logvar2 93 | - logvar1 94 | + torch.exp(logvar1 - logvar2) 95 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 96 | ) 97 | -------------------------------------------------------------------------------- /NullSpaceDiff/ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | pass 45 | # assert not key in self.m_name2s_name 46 | 47 | def copy_to(self, model): 48 | m_param = dict(model.named_parameters()) 49 | shadow_params = dict(self.named_buffers()) 50 | for key in m_param: 51 | if m_param[key].requires_grad: 52 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 53 | else: 54 | pass 55 | # assert not key in self.m_name2s_name 56 | 57 | def store(self, parameters): 58 | """ 59 | Save the current parameters for restoring later. 60 | Args: 61 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 62 | temporarily stored. 63 | """ 64 | self.collected_params = [param.clone() for param in parameters] 65 | 66 | def restore(self, parameters): 67 | """ 68 | Restore the parameters stored with the `store` method. 69 | Useful to validate the model with EMA parameters without affecting the 70 | original optimization process. Store the parameters before the 71 | `copy_to` method. 
After validation (or model saving), use this to 72 | restore the former parameters. 73 | Args: 74 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 75 | updated with the stored parameters. 76 | """ 77 | for c_param, param in zip(self.collected_params, parameters): 78 | param.data.copy_(c_param.data) 79 | -------------------------------------------------------------------------------- /NullSpaceDiff/ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenImagingLab/PhoCoLens/154fe32aea5c2b623f6dd1c07e90c2900d076486/NullSpaceDiff/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /NullSpaceDiff/ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /NullSpaceDiff/ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenImagingLab/PhoCoLens/154fe32aea5c2b623f6dd1c07e90c2900d076486/NullSpaceDiff/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /NullSpaceDiff/ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /NullSpaceDiff/ldm/modules/spade.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | 6 | import re 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | # from models.networks.sync_batchnorm import SynchronizedBatchNorm2d 11 | import torch.nn.utils.spectral_norm as spectral_norm 12 | 13 | from ldm.modules.diffusionmodules.util import normalization 14 | 15 | 16 | # Returns a function that creates a normalization function 17 | # that does not condition on semantic map 18 | def get_nonspade_norm_layer(opt, norm_type='instance'): 19 | # helper function to get # output channels of the previous layer 20 | def get_out_channel(layer): 21 | if hasattr(layer, 'out_channels'): 22 | return getattr(layer, 'out_channels') 23 | return layer.weight.size(0) 24 | 25 | # this function will be returned 26 | def add_norm_layer(layer): 27 | nonlocal norm_type 28 | if norm_type.startswith('spectral'): 29 | layer = spectral_norm(layer) 30 | subnorm_type = norm_type[len('spectral'):] 31 | 32 | if subnorm_type == 'none' or len(subnorm_type) == 0: 33 | return layer 34 | 35 | # remove bias in the previous layer, which is meaningless 36 | # since it has no effect after normalization 37 | if getattr(layer, 'bias', None) is not None: 38 | delattr(layer, 'bias') 39 | layer.register_parameter('bias', None) 40 | 41 | if subnorm_type == 'batch': 42 | norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True) 43 | elif subnorm_type == 'sync_batch': 44 | norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True) 45 | elif subnorm_type == 'instance': 46 | norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False) 47 | else: 48 | raise ValueError('normalization layer %s is not recognized' % subnorm_type) 49 | 50 | return nn.Sequential(layer, norm_layer) 51 | 52 | return add_norm_layer 53 | 54 | 55 | # Creates SPADE normalization layer based on the given configuration 56 | # SPADE consists of two steps. First, it normalizes the activations using 57 | # your favorite normalization method, such as Batch Norm or Instance Norm. 58 | # Second, it applies scale and bias to the normalized output, conditioned on 59 | # the segmentation map. 60 | # The format of |config_text| is spade(norm)(ks), where 61 | # (norm) specifies the type of parameter-free normalization. 62 | # (e.g. syncbatch, batch, instance) 63 | # (ks) specifies the size of kernel in the SPADE module (e.g. 3x3) 64 | # Example |config_text| will be spadesyncbatch3x3, or spadeinstance5x5. 65 | # Also, the other arguments are 66 | # |norm_nc|: the #channels of the normalized activations, hence the output dim of SPADE 67 | # |label_nc|: the #channels of the input semantic map, hence the input dim of SPADE 68 | class SPADE(nn.Module): 69 | def __init__(self, norm_nc, label_nc, config_text='spadeinstance3x3'): 70 | super().__init__() 71 | 72 | assert config_text.startswith('spade') 73 | parsed = re.search('spade(\D+)(\d)x\d', config_text) 74 | param_free_norm_type = str(parsed.group(1)) 75 | ks = int(parsed.group(2)) 76 | 77 | self.param_free_norm = normalization(norm_nc) 78 | 79 | # The dimension of the intermediate embedding space. Yes, hardcoded. 
80 | nhidden = 128 81 | 82 | pw = ks // 2 83 | self.mlp_shared = nn.Sequential( 84 | nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), 85 | nn.ReLU() 86 | ) 87 | self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) 88 | self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) 89 | 90 | def forward(self, x_dic, segmap_dic, size=None): 91 | 92 | if size is None: 93 | segmap = segmap_dic[str(x_dic.size(-1))] 94 | x = x_dic 95 | else: 96 | x = x_dic[str(size)] 97 | segmap = segmap_dic[str(size)] 98 | 99 | # Part 1. generate parameter-free normalized activations 100 | normalized = self.param_free_norm(x) 101 | 102 | # Part 2. produce scaling and bias conditioned on semantic map 103 | # segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest') 104 | actv = self.mlp_shared(segmap) 105 | gamma = self.mlp_gamma(actv) 106 | beta = self.mlp_beta(actv) 107 | 108 | # apply scale and bias 109 | out = normalized * (1 + gamma) + beta 110 | 111 | return out 112 | -------------------------------------------------------------------------------- /NullSpaceDiff/musiq/model/multiscale_transformer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Multiscale image quality transformer. https://arxiv.org/abs/2108.05997.""" 17 | 18 | from flax import nn 19 | import jax.numpy as jnp 20 | import numpy as np 21 | 22 | import musiq.model.multiscale_transformer_utils as utils 23 | import musiq.model.resnet as resnet 24 | 25 | RESNET_TOKEN_DIM = 64 26 | 27 | 28 | class Model(nn.Module): 29 | """Multiscale patch transformer.""" 30 | 31 | def apply(self, 32 | x, 33 | num_classes=1, 34 | train=False, 35 | hidden_size=None, 36 | transformer=None, 37 | resnet_emb=None, 38 | representation_size=None): 39 | """Apply model on inputs. 40 | 41 | Args: 42 | x: the processed input patches and position annotations. 43 | num_classes: the number of output classes. 1 for single model. 44 | train: train or eval. 45 | hidden_size: the hidden dimension for patch embedding tokens. 46 | transformer: the model config for Transformer backbone. 47 | resnet_emb: the config for patch embedding w/ small resnet. 48 | representation_size: size of the last FC before prediction. 49 | 50 | Returns: 51 | Model prediction output. 52 | """ 53 | assert transformer is not None 54 | # Either 3: (batch size, seq len, channel) or 55 | # 4: (batch size, crops, seq len, channel) 56 | assert len(x.shape) in [3, 4] 57 | 58 | multi_crops_input = False 59 | if len(x.shape) == 4: 60 | multi_crops_input = True 61 | batch_size, num_crops, l, channel = x.shape 62 | x = jnp.reshape(x, [batch_size * num_crops, l, channel]) 63 | 64 | # We concat (x, spatial_positions, scale_posiitons, input_masks) 65 | # when preprocessing. 
66 | inputs_spatial_positions = x[:, :, -3] 67 | inputs_spatial_positions = inputs_spatial_positions.astype(jnp.int32) 68 | inputs_scale_positions = x[:, :, -2] 69 | inputs_scale_positions = inputs_scale_positions.astype(jnp.int32) 70 | inputs_masks = x[:, :, -1] 71 | inputs_masks = inputs_masks.astype(jnp.bool_) 72 | x = x[:, :, :-3] 73 | n, l, channel = x.shape 74 | if hidden_size: 75 | if resnet_emb: 76 | # channel = patch_size * patch_size * 3 77 | patch_size = int(np.sqrt(channel // 3)) 78 | x = jnp.reshape(x, [-1, patch_size, patch_size, 3]) 79 | x = resnet.StdConv( 80 | x, RESNET_TOKEN_DIM, (7, 7), (2, 2), bias=False, name="conv_root") 81 | x = nn.GroupNorm(x, name="gn_root") 82 | x = nn.relu(x) 83 | x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME") 84 | 85 | if resnet_emb.num_layers > 0: 86 | blocks, bottleneck = resnet.get_block_desc(resnet_emb.num_layers) 87 | if blocks: 88 | x = resnet.ResNetStage( 89 | x, 90 | blocks[0], 91 | RESNET_TOKEN_DIM, 92 | first_stride=(1, 1), 93 | bottleneck=bottleneck, 94 | name="block1") 95 | for i, block_size in enumerate(blocks[1:], 1): 96 | x = resnet.ResNetStage( 97 | x, 98 | block_size, 99 | RESNET_TOKEN_DIM * 2**i, 100 | first_stride=(2, 2), 101 | bottleneck=bottleneck, 102 | name=f"block{i + 1}") 103 | x = jnp.reshape(x, [n, l, -1]) 104 | 105 | x = nn.Dense(x, hidden_size, name="embedding") 106 | 107 | # Here, x is a list of embeddings. 108 | x = utils.Encoder( 109 | x, 110 | inputs_spatial_positions, 111 | inputs_scale_positions, 112 | inputs_masks, 113 | train=train, 114 | name="Transformer", 115 | **transformer) 116 | 117 | x = x[:, 0] 118 | 119 | if representation_size: 120 | x = nn.Dense(x, representation_size, name="pre_logits") 121 | x = nn.tanh(x) 122 | else: 123 | x = resnet.IdentityLayer(x, name="pre_logits") 124 | 125 | x = nn.Dense(x, num_classes, name="head", kernel_init=nn.initializers.zeros) 126 | if multi_crops_input: 127 | _, channel = x.shape 128 | x = jnp.reshape(x, [batch_size, num_crops, channel]) 129 | return x -------------------------------------------------------------------------------- /NullSpaceDiff/musiq/model/resnet.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """ResNet V1 with GroupNorm. 
https://arxiv.org/abs/1512.03385.""" 17 | 18 | from flax import nn 19 | import jax.numpy as jnp 20 | 21 | 22 | class IdentityLayer(nn.Module): 23 | """Identity layer, convenient for giving a name to an array.""" 24 | 25 | def apply(self, x): 26 | return x 27 | 28 | 29 | def weight_standardize(w, axis, eps): 30 | w = w - jnp.mean(w, axis=axis) 31 | w = w / (jnp.std(w, axis=axis) + eps) 32 | return w 33 | 34 | 35 | class StdConv(nn.Conv): 36 | 37 | def param(self, name, shape, initializer): 38 | param = super().param(name, shape, initializer) 39 | if name == "kernel": 40 | param = weight_standardize(param, axis=[0, 1, 2], eps=1e-5) 41 | return param 42 | 43 | 44 | class ResidualUnit(nn.Module): 45 | """Bottleneck ResNet block.""" 46 | 47 | def apply(self, x, nout, strides=(1, 1), bottleneck=True): 48 | features = nout 49 | nout = nout * 4 if bottleneck else nout 50 | needs_projection = x.shape[-1] != nout or strides != (1, 1) 51 | residual = x 52 | if needs_projection: 53 | residual = StdConv( 54 | residual, nout, (1, 1), strides, bias=False, name="conv_proj") 55 | residual = nn.GroupNorm(residual, epsilon=1e-4, name="gn_proj") 56 | 57 | if bottleneck: 58 | x = StdConv(x, features, (1, 1), bias=False, name="conv1") 59 | x = nn.GroupNorm(x, epsilon=1e-4, name="gn1") 60 | x = nn.relu(x) 61 | 62 | x = StdConv(x, features, (3, 3), strides, bias=False, name="conv2") 63 | x = nn.GroupNorm(x, epsilon=1e-4, name="gn2") 64 | x = nn.relu(x) 65 | 66 | last_kernel = (1, 1) if bottleneck else (3, 3) 67 | x = StdConv(x, nout, last_kernel, bias=False, name="conv3") 68 | x = nn.GroupNorm( 69 | x, epsilon=1e-4, name="gn3", scale_init=nn.initializers.zeros) 70 | x = nn.relu(residual + x) 71 | 72 | return x 73 | 74 | 75 | class ResNetStage(nn.Module): 76 | 77 | def apply(self, x, block_size, nout, first_stride, bottleneck=True): 78 | x = ResidualUnit( 79 | x, nout, strides=first_stride, bottleneck=bottleneck, name="unit1") 80 | for i in range(1, block_size): 81 | x = ResidualUnit( 82 | x, nout, strides=(1, 1), bottleneck=bottleneck, name=f"unit{i + 1}") 83 | return x 84 | 85 | 86 | class Model(nn.Module): 87 | """ResNetV1.""" 88 | 89 | def apply(self, 90 | x, 91 | num_classes=1000, 92 | train=False, 93 | width_factor=1, 94 | num_layers=50): 95 | del train 96 | blocks, bottleneck = get_block_desc(num_layers) 97 | width = int(64 * width_factor) 98 | 99 | # Root block 100 | x = StdConv(x, width, (7, 7), (2, 2), bias=False, name="conv_root") 101 | x = nn.GroupNorm(x, name="gn_root") 102 | x = nn.relu(x) 103 | x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME") 104 | 105 | # Stages 106 | x = ResNetStage( 107 | x, 108 | blocks[0], 109 | width, 110 | first_stride=(1, 1), 111 | bottleneck=bottleneck, 112 | name="block1") 113 | for i, block_size in enumerate(blocks[1:], 1): 114 | x = ResNetStage( 115 | x, 116 | block_size, 117 | width * 2**i, 118 | first_stride=(2, 2), 119 | bottleneck=bottleneck, 120 | name=f"block{i + 1}") 121 | 122 | # Head 123 | x = jnp.mean(x, axis=(1, 2)) 124 | x = IdentityLayer(x, name="pre_logits") 125 | x = nn.Dense(x, num_classes, kernel_init=nn.initializers.zeros, name="head") 126 | return x 127 | 128 | 129 | # A dictionary mapping the number of layers in a resnet to the number of 130 | # blocks in each stage of the model. The second argument indicates whether we 131 | # use bottleneck layers or not. 132 | def get_block_desc(num_layers): 133 | if isinstance(num_layers, list): # Be robust to silly mistakes. 
134 | num_layers = tuple(num_layers) 135 | return { 136 | 5: ([1], True), # Only strided blocks. Total stride 4. 137 | 8: ([1, 1], True), # Only strided blocks. Total stride 8. 138 | 11: ([1, 1, 1], True), # Only strided blocks. Total stride 16. 139 | 14: ([1, 1, 1, 1], True), # Only strided blocks. Total stride 32. 140 | 9: ([1, 1, 1, 1], False), # Only strided blocks. Total stride 32. 141 | 18: ([2, 2, 2, 2], False), 142 | 26: ([2, 2, 2, 2], True), 143 | 34: ([3, 4, 6, 3], False), 144 | 50: ([3, 4, 6, 3], True), 145 | 101: ([3, 4, 23, 3], True), 146 | 152: ([3, 8, 36, 3], True), 147 | 200: ([3, 24, 36, 3], True) 148 | }.get(num_layers, (num_layers, True)) -------------------------------------------------------------------------------- /NullSpaceDiff/musiq/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py>=0.12.0 2 | chex>=0.0.7 3 | clu>=0.0.3 4 | einops>=0.3.0 5 | flax==0.3.3 6 | ml-collections==0.1.0 7 | numpy>=1.19.5 8 | pandas>=1.1.0 9 | tensorflow>=2.0.0-beta1 10 | jax>=0.1.55 11 | jaxlib>=0.1.37 -------------------------------------------------------------------------------- /NullSpaceDiff/scripts/infer.sh: -------------------------------------------------------------------------------- 1 | # python ./scripts/sr_val_ddpm_lensless.py --init-img data/fft-svd-1280-1408-learn-1280-1408-meas-decoded_sim_spatial_weight/val/inputs_512 --outdir data/test-250317-554 --ckpt logs/2024-04-15T00-59-14_flatnet_decoded_sim_multi/checkpoints/epoch\=000554.ckpt --n_samples 5 --ddpm_steps 200 2 | python ./scripts/sr_val_ddpm_lensless.py --init-img data/fft-svd-1280-1408-meas-decoded_sim_spatial_weight/val/inputs_512 --outdir data/fft-svd-1280-1408-meas-decoded_sim_spatial_weight/val/test-250318-002 --ckpt logs/2025-03-17T15-30-48_svd_nullspace_diff/checkpoints/epoch=000002.ckpt --n_samples 5 --ddpm_steps 200 3 | -------------------------------------------------------------------------------- /NullSpaceDiff/scripts/train.sh: -------------------------------------------------------------------------------- 1 | python main.py --train --base configs/NullSpaceDiff/v2-finetune_phlatcam_decoded_sim_multi_T_512.yaml --gpus 0,1,2,3,4,5,6,7, --scale_lr False --name svd_nullspace_diff -------------------------------------------------------------------------------- /NullSpaceDiff/scripts/wavelet_color_fix.py: -------------------------------------------------------------------------------- 1 | ''' 2 | # -------------------------------------------------------------------------------- 3 | # Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py) 4 | # -------------------------------------------------------------------------------- 5 | ''' 6 | 7 | import torch 8 | from PIL import Image 9 | from torch import Tensor 10 | from torch.nn import functional as F 11 | 12 | from torchvision.transforms import ToTensor, ToPILImage 13 | 14 | def adain_color_fix(target: Image, source: Image): 15 | # Convert images to tensors 16 | to_tensor = ToTensor() 17 | target_tensor = to_tensor(target).unsqueeze(0) 18 | source_tensor = to_tensor(source).unsqueeze(0) 19 | 20 | # Apply adaptive instance normalization 21 | result_tensor = adaptive_instance_normalization(target_tensor, source_tensor) 22 | 23 | # Convert tensor back to image 24 | to_image = ToPILImage() 25 | result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0)) 26 | 27 | return result_image 28 | 29 | def wavelet_color_fix(target: Image, source: 
Image): 30 | # Convert images to tensors 31 | to_tensor = ToTensor() 32 | target_tensor = to_tensor(target).unsqueeze(0) 33 | source_tensor = to_tensor(source).unsqueeze(0) 34 | 35 | # Apply wavelet reconstruction 36 | result_tensor = wavelet_reconstruction(target_tensor, source_tensor) 37 | 38 | # Convert tensor back to image 39 | to_image = ToPILImage() 40 | result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0)) 41 | 42 | return result_image 43 | 44 | def calc_mean_std(feat: Tensor, eps=1e-5): 45 | """Calculate mean and std for adaptive_instance_normalization. 46 | Args: 47 | feat (Tensor): 4D tensor. 48 | eps (float): A small value added to the variance to avoid 49 | divide-by-zero. Default: 1e-5. 50 | """ 51 | size = feat.size() 52 | assert len(size) == 4, 'The input feature should be 4D tensor.' 53 | b, c = size[:2] 54 | feat_var = feat.reshape(b, c, -1).var(dim=2) + eps 55 | feat_std = feat_var.sqrt().reshape(b, c, 1, 1) 56 | feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1) 57 | return feat_mean, feat_std 58 | 59 | def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor): 60 | """Adaptive instance normalization. 61 | Adjust the reference features to have the similar color and illuminations 62 | as those in the degradate features. 63 | Args: 64 | content_feat (Tensor): The reference feature. 65 | style_feat (Tensor): The degradate features. 66 | """ 67 | size = content_feat.size() 68 | style_mean, style_std = calc_mean_std(style_feat) 69 | content_mean, content_std = calc_mean_std(content_feat) 70 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) 71 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 72 | 73 | def wavelet_blur(image: Tensor, radius: int): 74 | """ 75 | Apply wavelet blur to the input tensor. 76 | """ 77 | # input shape: (1, 3, H, W) 78 | # convolution kernel 79 | kernel_vals = [ 80 | [0.0625, 0.125, 0.0625], 81 | [0.125, 0.25, 0.125], 82 | [0.0625, 0.125, 0.0625], 83 | ] 84 | kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) 85 | # add channel dimensions to the kernel to make it a 4D tensor 86 | kernel = kernel[None, None] 87 | # repeat the kernel across all input channels 88 | kernel = kernel.repeat(3, 1, 1, 1) 89 | image = F.pad(image, (radius, radius, radius, radius), mode='replicate') 90 | # apply convolution 91 | output = F.conv2d(image, kernel, groups=3, dilation=radius) 92 | return output 93 | 94 | def wavelet_decomposition(image: Tensor, levels=5): 95 | """ 96 | Apply wavelet decomposition to the input tensor. 97 | This function only returns the low frequency & the high frequency. 98 | """ 99 | high_freq = torch.zeros_like(image) 100 | for i in range(levels): 101 | radius = 2 ** i 102 | low_freq = wavelet_blur(image, radius) 103 | high_freq += (image - low_freq) 104 | image = low_freq 105 | 106 | return high_freq, low_freq 107 | 108 | def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor): 109 | """ 110 | Apply wavelet decomposition, so that the content will have the same color as the style. 
111 | """ 112 | # calculate the wavelet decomposition of the content feature 113 | content_high_freq, content_low_freq = wavelet_decomposition(content_feat) 114 | del content_low_freq 115 | # calculate the wavelet decomposition of the style feature 116 | style_high_freq, style_low_freq = wavelet_decomposition(style_feat) 117 | del style_high_freq 118 | # reconstruct the content feature with the style's high frequency 119 | return content_high_freq + style_low_freq 120 | -------------------------------------------------------------------------------- /NullSpaceDiff/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='NullSpaceDiff', 5 | version='0.0.1', 6 | description='', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | 'tqdm', 12 | ], 13 | ) 14 | -------------------------------------------------------------------------------- /NullSpaceDiff/utils/lensless_utils.py: -------------------------------------------------------------------------------- 1 | from waveprop.simulation import FarFieldSimulator 2 | import argparse 3 | import os 4 | import cv2 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | import sys 9 | from pathlib import Path 10 | from waveprop.devices import SensorParam 11 | from torch import nn 12 | sensor = dict(size = np.array([4.8e-6 * 1518, 4.8e-6 * 2012])) 13 | 14 | class LenslessSimulator(nn.Module): 15 | def __init__(self, psf_path, object_height = 0.4, scene2mask = 0.434, mask2sensor = 2e-3, sensor = sensor): 16 | super(LenslessSimulator, self).__init__() 17 | psf = np.load(psf_path) 18 | psf = torch.from_numpy(psf).float() 19 | psf = psf.unsqueeze(0).unsqueeze(0) 20 | psf = self.crop_and_padding(psf) 21 | # psf = psf.unsqueeze(-1) 22 | # psf = psf[..., None] 23 | # print(psf.shape) 24 | self.simulator = FarFieldSimulator(object_height, scene2mask, mask2sensor, sensor, psf, is_torch = True, quantize = False) 25 | 26 | def crop_and_padding(self, img, 27 | meas_crop_size_x=1280, meas_crop_size_y=1408, meas_centre_x=808, meas_centre_y=965, psf_height=1518, psf_width=2012, pad_meas_mode="replicate"): 28 | crop_x = meas_centre_x - meas_crop_size_x // 2 29 | crop_y = meas_centre_y - meas_crop_size_y // 2 30 | crop_x_end = crop_x + meas_crop_size_x 31 | crop_y_end = crop_y + meas_crop_size_y 32 | if crop_x < 0: 33 | crop_x = 0 34 | if crop_y < 0: 35 | crop_y = 0 36 | # img shape: (B, C, H, W) 37 | img = img[:, :, crop_x:crop_x_end, crop_y:crop_y_end] 38 | # pad to psf size 39 | pad_x = psf_height - meas_crop_size_x 40 | pad_y = psf_width - meas_crop_size_y 41 | if pad_x < 0 or pad_y < 0: 42 | raise ValueError("psf size should be larger than meas_crop_size") 43 | #resize to half size 44 | # img = F.interpolate(img, scale_factor=0.5, mode="bilinear") 45 | # print("pad_x: {}, pad_y: {}".format(pad_x, pad_y)) 46 | # print("img shape: {}".format(img.shape)) 47 | img = F.pad(img, (pad_y // 2, pad_y // 2, pad_x // 2, pad_x // 2), mode=pad_meas_mode) 48 | return img 49 | 50 | 51 | def forward(self, x): 52 | # print("max: {}, min: {}".format(torch.max(x), torch.min(x))) 53 | x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0) 54 | # x = (x - torch.min(x)) / (torch.max(x) - torch.min(x)) 55 | x = self.simulator.propagate(x) 56 | # x = x / 1000 57 | 58 | # print("after sim conv max: {}, min: {}".format(torch.max(x), torch.min(x))) 59 | # x = (x - torch.min(x)) / (torch.max(x) - torch.min(x)) 60 | # transfer to -1-1 
61 | # x = x * 2 - 1 62 | # x = self.crop_and_padding(x) 63 | return x 64 | 65 | fft_args = { 66 | "psf_mat": Path("/root/caixin/flatnet/data/phase_psf/psf.npy"), 67 | "psf_height": 1518, 68 | "psf_width": 2012, 69 | "psf_centre_x": 808, 70 | "psf_centre_y": 965, 71 | "psf_crop_size_x": 1280, 72 | "psf_crop_size_y": 1408, 73 | "meas_height": 1518, 74 | "meas_width": 2012, 75 | "meas_centre_x": 808, 76 | "meas_centre_y": 965, 77 | "meas_crop_size_x": 1280, 78 | "meas_crop_size_y": 1408, 79 | "pad_meas_mode": "replicate", 80 | # Change meas_crop_size_{x,y} to crop the sensor measurement. This assumes your sensor is smaller than the 81 | # measurement size. The true measurement size is 1280x1408x4. Anything smaller requires padding the 82 | # cropped measurement and then multiplying it by a Gaussian-filtered rectangular box. For simplicity, use the arguments 83 | # already set. Currently we are using the full measurement. 84 | "image_height": 384, 85 | "image_width": 384, 86 | "fft_gamma": 20000, # Gamma for Wiener init 87 | "use_mask": False, # Use mask for cropped meas only 88 | "mask_path": Path("/root/caixin/flatnet/data/phase_psf/box_gaussian_1280_1408.npy"), 89 | # use Path("box_gaussian_1280_1408.npy") for controlled lighting 90 | # use Path("box_gaussian_1280_1408_big_mask.npy") for uncontrolled lighting 91 | "fft_requires_grad": False, 92 | } 93 | 94 | def load_real_capture_as_tensor(data_path, is_cuda = True): 95 | real_capture = cv2.imread(data_path).astype(np.float32) / 255 96 | real_capture = (real_capture - np.min(real_capture)) / (np.max(real_capture) - np.min(real_capture)) 97 | real_capture = torch.tensor(real_capture).permute(2, 0, 1).float().unsqueeze(0) 98 | if is_cuda: 99 | real_capture = real_capture.cuda() 100 | return real_capture -------------------------------------------------------------------------------- /SVDeconv/config_diffusercam.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convention 3 | 4 | ours/naive-fft-(fft_h-fft_w)-learn-(learn_h-learn_w)-meas-(meas_h-meas_w)-kwargs 5 | 6 | * Phlatcam: 1518 x 2012 (post demosaicking) 7 | """ 8 | from pathlib import Path 9 | import torch 10 | from types import SimpleNamespace 11 | 12 | # Define FFT arguments once at the module level 13 | 14 | height = 270 15 | width = 480 16 | fft_args_dict = { 17 | "psf_mat": Path("data/diffusercam/psf.tiff"), 18 | "psf_height": height, 19 | "psf_width": width, 20 | "psf_centre_x": height // 2, 21 | "psf_centre_y": width // 2, 22 | "psf_crop_size_x": height, 23 | "psf_crop_size_y": width, 24 | "meas_height": height, 25 | "meas_width": width, 26 | "meas_centre_x": height // 2, 27 | "meas_centre_y": width // 2, 28 | "meas_crop_size_x": height, 29 | "meas_crop_size_y": width, 30 | "pad_meas_mode": "replicate", 31 | "image_height": 270, 32 | "image_width": 480, 33 | "fft_gamma": 100, # Gamma for Wiener init 34 | "fft_requires_grad": False, 35 | "fft_epochs": 0, 36 | } 37 | 38 | def base_config(): 39 | exp_name = "fft-diffusercam" 40 | is_naive = "naive" in exp_name 41 | multi = 1 42 | use_spatial_weight = False 43 | weight_update = True 44 | dataset = "diffusercam" 45 | # Use FFT arguments from the global definition 46 | locals().update(fft_args_dict) 47 | # ---------------------------------------------------------------------------- # 48 | # Directories 49 | # ---------------------------------------------------------------------------- # 50 | 51 | image_dir = Path("data/diffusercam") 52 | output_dir = Path("output/diffusercam") / exp_name 53 | ckpt_dir =
Path("ckpts/diffusercam") / exp_name 54 | run_dir = Path("runs/diffusercam") / exp_name # Tensorboard 55 | 56 | # ---------------------------------------------------------------------------- # 57 | # Data 58 | # ---------------------------------------------------------------------------- # 59 | 60 | 61 | shuffle = True 62 | train_gaussian_noise = 5e-3 63 | 64 | 65 | model = "UNet270480" 66 | batch_size = 18 67 | num_threads = batch_size >> 1 # parallel workers 68 | 69 | # ---------------------------------------------------------------------------- # 70 | # Train Configs 71 | # ---------------------------------------------------------------------------- # 72 | # Schedules 73 | num_epochs = 100 74 | fft_epochs = num_epochs if is_naive else 0 75 | 76 | learning_rate = 1e-4 77 | fft_learning_rate = 3e-5 78 | 79 | # Betas for AdamW. We follow https://arxiv.org/pdf/1704.00028 80 | beta_1 = 0.9 # momentum 81 | beta_2 = 0.999 82 | 83 | lr_scheduler = "cosine" # or step 84 | 85 | # Cosine annealing 86 | T_0 = 1 87 | T_mult = 2 88 | step_size = 2 # For step lr 89 | 90 | # saving models 91 | save_filename_G = "model.pth" 92 | save_filename_FFT = "FFT.pth" 93 | save_filename_D = "D.pth" 94 | 95 | save_filename_latest_G = "model_latest.pth" 96 | save_filename_latest_FFT = "FFT_latest.pth" 97 | save_filename_latest_D = "D_latest.pth" 98 | 99 | log_interval = 100 # the number of iterations (default: 10) to print at 100 | save_ckpt_interval = log_interval * 10 101 | save_copy_every_epochs = 10 102 | # ---------------------------------------------------------------------------- # 103 | # Model 104 | # ---------------------------------------------------------------------------- # 105 | # See models/get_model.py for registry 106 | # model = "unet-128-pixelshuffle-invert" 107 | pixelshuffle_ratio = 2 108 | grad_lambda = 0.0 109 | 110 | G_finetune_layers = [] # None implies all 111 | 112 | num_groups = 8 # Group norm 113 | 114 | # ---------------------------------------------------------------------------- # 115 | # Loss 116 | # ---------------------------------------------------------------------------- # 117 | lambda_adversarial = 0.6 118 | lambda_contextual = 0.0 119 | lambda_perception = 1.2 # 0.006 120 | lambda_image = 1 # mse 121 | lambda_l1 = 0 # l1 122 | 123 | resume = False 124 | finetune = False # Wont load loss or epochs 125 | concat_input = False 126 | zero_conv = False 127 | # ---------------------------------------------------------------------------- # 128 | # Inference Args 129 | # ---------------------------------------------------------------------------- # 130 | inference_mode = "latest" 131 | assert inference_mode in ["latest", "best"] 132 | 133 | # ---------------------------------------------------------------------------- # 134 | # Distribution Args 135 | # ---------------------------------------------------------------------------- # 136 | # choose cpu or cuda:0 device 137 | device = "cuda" if torch.cuda.is_available() else "cpu" 138 | distdataparallel = False 139 | val_train = False 140 | static_val_image = "" 141 | 142 | 143 | 144 | 145 | def ours_diffusercam_mulnew_unet_padding_decode_sim(): 146 | exp_name = "fft-mulnew9-diffusercam_unet_padding_decode_sim" 147 | batch_size = 5 148 | num_threads = 5 149 | lambda_adversarial = 0.0 150 | multi = 9 151 | use_spatial_weight = True 152 | lambda_perception = 0.05 153 | preprocess_with_unet = True 154 | psf_height = 270 * 2 155 | psf_width = 480 * 2 156 | decode_sim = True 157 | 158 | 159 | def infer_train(): 160 | val_train = True 161 | 
162 | 163 | named_config_ll = [ 164 | ours_diffusercam_mulnew_unet_padding_decode_sim, 165 | infer_train 166 | ] 167 | 168 | 169 | def initialise(ex): 170 | ex.config(base_config) 171 | for named_config in named_config_ll: 172 | ex.named_config(named_config) 173 | return ex 174 | 175 | fft_args = SimpleNamespace(**fft_args_dict) 176 | 177 | if __name__ == "__main__": 178 | str_named_config_ll = [str(named_config) for named_config in named_config_ll] 179 | print("\n".join(str_named_config_ll)) 180 | -------------------------------------------------------------------------------- /SVDeconv/data: -------------------------------------------------------------------------------- 1 | /root/RawSense/flatnet/data -------------------------------------------------------------------------------- /SVDeconv/dataloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataloaders 3 | """ 4 | 5 | # Libs 6 | from dataclasses import dataclass 7 | import logging 8 | from typing import TYPE_CHECKING 9 | from sacred import Experiment 10 | 11 | # Torch modules 12 | import torch 13 | from torch.utils.data import DataLoader, Dataset 14 | import torch.nn.functional as F 15 | import torch.distributed as dist 16 | import os 17 | import cv2 18 | import numpy as np 19 | from config import initialise 20 | from pathlib import Path 21 | from datasets.phlatcam import PhaseMaskDataset 22 | from datasets.diffusercam import LenslessLearningCollection 23 | if TYPE_CHECKING: 24 | from utils.typing_alias import * 25 | 26 | ex = Experiment("data") 27 | ex = initialise(ex) 28 | @dataclass 29 | class Data: 30 | train_loader: DataLoader 31 | val_loader: DataLoader 32 | test_loader: DataLoader = None 33 | 34 | def get_dataloaders(args, is_local_rank_0: bool = True): 35 | """ 36 | Get dataloaders for train and val 37 | 38 | Returns: 39 | :data 40 | """ 41 | if "phlatcam" in args.dataset_name: 42 | train_dataset = PhaseMaskDataset( 43 | args, mode="train", is_local_rank_0=is_local_rank_0 44 | ) 45 | val_dataset = PhaseMaskDataset(args, mode="val", is_local_rank_0=is_local_rank_0) 46 | test_dataset = PhaseMaskDataset(args, mode="val", is_local_rank_0=is_local_rank_0) 47 | elif "diffusercam" in args.dataset_name: 48 | dataset = LenslessLearningCollection(args) 49 | train_dataset = dataset.train_dataset 50 | val_dataset = dataset.val_dataset 51 | test_dataset = dataset.val_dataset 52 | # print("here") 53 | if is_local_rank_0: 54 | logging.info( 55 | f"Dataset: {args.dataset_name} Len Train: {len(train_dataset)} Val: {len(val_dataset)} Test: {len(test_dataset)}" 56 | ) 57 | 58 | train_loader = None 59 | val_loader = None 60 | test_loader = None 61 | 62 | if len(train_dataset): 63 | if args.distdataparallel: 64 | train_sampler = torch.utils.data.distributed.DistributedSampler( 65 | train_dataset, num_replicas=dist.get_world_size(), shuffle=True 66 | ) 67 | shuffle = False 68 | 69 | else: 70 | train_sampler = None 71 | shuffle = True 72 | 73 | train_loader = DataLoader( 74 | train_dataset, 75 | batch_size=args.batch_size, 76 | shuffle=shuffle, 77 | num_workers=args.num_threads, 78 | pin_memory=True, 79 | # drop_last=True, 80 | sampler=train_sampler, 81 | ) 82 | 83 | if len(val_dataset): 84 | if args.distdataparallel: 85 | val_sampler = torch.utils.data.distributed.DistributedSampler( 86 | val_dataset, num_replicas=dist.get_world_size(), shuffle=True 87 | ) 88 | shuffle = False 89 | 90 | else: 91 | val_sampler = None 92 | shuffle = False 93 | 94 | val_loader = DataLoader( 95 | val_dataset, 96 | 
batch_size=args.batch_size, 97 | shuffle=shuffle, 98 | num_workers=args.num_threads, 99 | pin_memory=True, 100 | # drop_last=True, 101 | sampler=val_sampler, 102 | ) 103 | 104 | if len(test_dataset): 105 | if args.distdataparallel: 106 | test_sampler = torch.utils.data.distributed.DistributedSampler( 107 | test_dataset, num_replicas=dist.get_world_size(), shuffle=True 108 | ) 109 | shuffle = False 110 | 111 | else: 112 | test_sampler = None 113 | shuffle = True 114 | 115 | test_loader = DataLoader( 116 | test_dataset, 117 | batch_size=args.batch_size, 118 | shuffle=shuffle, 119 | num_workers=args.num_threads, 120 | pin_memory=False, 121 | drop_last=True, 122 | sampler=test_sampler, 123 | ) 124 | 125 | return Data( 126 | train_loader=train_loader, val_loader=val_loader, test_loader=test_loader 127 | ) 128 | 129 | 130 | @ex.automain 131 | def main(_run): 132 | from tqdm import tqdm 133 | from utils.tupperware import tupperware 134 | 135 | args = tupperware(_run.config) 136 | 137 | data = get_dataloaders(args) 138 | 139 | for _ in tqdm(data.train_loader.dataset): 140 | pass 141 | -------------------------------------------------------------------------------- /SVDeconv/datasets/diffusercam.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision.transforms.functional import ( 7 | to_tensor, 8 | resize, 9 | ) 10 | 11 | from dataclasses import dataclass 12 | import logging 13 | from typing import TYPE_CHECKING 14 | from sacred import Experiment 15 | 16 | # Torch modules 17 | import torch 18 | from torch.utils.data import DataLoader, Dataset 19 | import torch.nn.functional as F 20 | import torch.distributed as dist 21 | import os 22 | import cv2 23 | import numpy as np 24 | from config import initialise 25 | from pathlib import Path 26 | 27 | if TYPE_CHECKING: 28 | from utils.typing_alias import * 29 | 30 | 31 | ex = Experiment("data") 32 | ex = initialise(ex) 33 | 34 | 35 | SIZE = 270, 480 36 | 37 | def region_of_interest(x): 38 | return x[..., 60:270, 60:440] 39 | 40 | 41 | def transform(image, gray=False): 42 | # print(image.shape) 43 | image = np.flip(np.flipud(image), axis=2) 44 | image = image.copy() 45 | image = to_tensor(image) 46 | image = resize(image, SIZE) 47 | image = (image - 0.5) * 2 48 | return image 49 | 50 | 51 | def sort_key(x): 52 | return int(x[2:-4]) 53 | 54 | 55 | def load_psf(path): 56 | psf = np.array(Image.open(path)) 57 | return transform(psf) 58 | 59 | 60 | class LenslessLearning(Dataset): 61 | def __init__(self, diffuser_images, ground_truth_images): 62 | """ 63 | Everything is upside-down, and the colors are BGR... 
64 | """ 65 | self.xs = diffuser_images 66 | self.ys = ground_truth_images 67 | 68 | def read_image(self, filename): 69 | image = np.load(filename) 70 | 71 | def __len__(self): 72 | return len(self.xs) 73 | 74 | def __getitem__(self, idx): 75 | diffused = self.xs[idx] 76 | ground_truth = self.ys[idx] 77 | # print(diffused, ground_truth) 78 | # print("hello!", np.load(diffused).shape, np.load(ground_truth).shape) 79 | 80 | x = transform(np.load(diffused)) 81 | if ground_truth.name.endswith('.png'): 82 | y = np.array(Image.open(ground_truth)) 83 | y = transform(y) 84 | else: 85 | y = transform(np.load(ground_truth)) 86 | 87 | return x, y, str(diffused.name) 88 | 89 | 90 | class LenslessLearningInTheWild(Dataset): 91 | def __init__(self, path): 92 | xs = [] 93 | manifest = sorted((x.name for x in path.glob('*.npy'))) 94 | for filename in manifest: 95 | xs.append(path / filename) 96 | 97 | self.xs = xs 98 | 99 | def read_image(self, filename): 100 | image = np.load(filename) 101 | 102 | def __len__(self): 103 | return len(self.xs) 104 | 105 | def __getitem__(self, idx): 106 | diffused = self.xs[idx] 107 | x = transform(np.load(diffused)) 108 | return x 109 | 110 | 111 | class LenslessLearningCollection: 112 | def __init__(self, args): 113 | path = Path(args.image_dir) 114 | 115 | self.psf = load_psf(path / 'psf.tiff') 116 | 117 | train_diffused, train_ground_truth = load_manifest(path, 'dataset_train.csv', decode_sim = args.decode_sim) 118 | val_diffused, val_ground_truth = load_manifest(path, 'dataset_test.csv', decode_sim = args.decode_sim) 119 | 120 | self.train_dataset = LenslessLearning(train_diffused, train_ground_truth) 121 | self.val_dataset = LenslessLearning(val_diffused, val_ground_truth) 122 | self.region_of_interest = region_of_interest 123 | 124 | 125 | def load_manifest(path, csv_filename, decode_sim = False): 126 | with open(path / csv_filename) as f: 127 | manifest = f.read().split() 128 | 129 | xs, ys = [], [] 130 | for filename in manifest: 131 | x = path / 'diffuser_images' / filename.replace(".jpg.tiff", ".npy") 132 | if decode_sim: 133 | y = path / 'decode_sim_padding_png' / filename.replace(".jpg.tiff", ".png") 134 | else: 135 | y = path / 'ground_truth_lensed' / filename.replace(".jpg.tiff", ".npy") 136 | # if x.exists() and y.exists(): 137 | # print(f"Found {x} and {y}") 138 | xs.append(x) 139 | ys.append(y) 140 | # else: 141 | # print(f"No file named {x}") 142 | return xs, ys -------------------------------------------------------------------------------- /SVDeconv/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Metrics file 3 | """ 4 | import torch 5 | 6 | from typing import TYPE_CHECKING 7 | from sacred import Experiment 8 | 9 | 10 | if TYPE_CHECKING: 11 | from utils.typing_alias import * 12 | 13 | 14 | ex = Experiment("metrics") 15 | 16 | 17 | def PSNR(source: "Tensor", target: "Tensor"): 18 | """ 19 | Peak Signal to noise ratio 20 | 21 | Ref: https://www.mathworks.com/help/vision/ref/psnr.html 22 | 23 | Images between [-1,1] 24 | """ 25 | source = source.mul(0.5).add(0.5).clamp(0, 1) 26 | target = target.mul(0.5).add(0.5).clamp(0, 1) 27 | noise = ((source - target) ** 2).mean(dim=3).mean(dim=2).mean(dim=1) 28 | signal_max = 1.0 29 | return (10 * torch.log10(signal_max / noise)).mean() 30 | -------------------------------------------------------------------------------- /SVDeconv/models/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/OpenImagingLab/PhoCoLens/154fe32aea5c2b623f6dd1c07e90c2900d076486/SVDeconv/models/__init__.py -------------------------------------------------------------------------------- /SVDeconv/models/discriminator.py: -------------------------------------------------------------------------------- 1 | from sacred import Experiment 2 | from typing import TYPE_CHECKING 3 | 4 | import torch 5 | from torch import nn 6 | 7 | from config import initialise 8 | 9 | if TYPE_CHECKING: 10 | from utils.typing_alias import * 11 | 12 | ex = Experiment("Disc") 13 | ex = initialise(ex) 14 | 15 | 16 | class Discriminator(nn.Module): 17 | def __init__(self, args): 18 | super(Discriminator, self).__init__() 19 | 20 | self.args = args 21 | 22 | self.disc = nn.Sequential( 23 | nn.Conv2d(3, 64, kernel_size=3, padding=1), 24 | nn.LeakyReLU(0.2), 25 | nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), 26 | nn.GroupNorm(num_channels=128, num_groups=args.num_groups), 27 | nn.LeakyReLU(0.2), 28 | nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), 29 | nn.GroupNorm(num_channels=128, num_groups=args.num_groups), 30 | nn.LeakyReLU(0.2), 31 | nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), 32 | nn.GroupNorm(num_channels=256, num_groups=args.num_groups), 33 | nn.LeakyReLU(0.2), 34 | nn.AdaptiveAvgPool2d(1), 35 | nn.Conv2d(256, 1, kernel_size=1), 36 | ) 37 | 38 | def forward(self, img): 39 | logit = self.disc(img).squeeze() 40 | # print(logit.shape) 41 | return logit 42 | 43 | 44 | @ex.automain 45 | def main(_run): 46 | from utils.tupperware import tupperware 47 | 48 | args = tupperware(_run.config) 49 | from math import ceil 50 | 51 | batch = 2 52 | 53 | img_ll = [ 54 | torch.randn(batch, 3, ceil(args.image_height / 4), ceil(args.image_width / 4)), 55 | torch.randn(batch, 3, ceil(args.image_height / 2), ceil(args.image_width / 2)), 56 | torch.randn(batch, 3, args.image_height, args.image_width), 57 | ] 58 | D = Discriminator(args) 59 | 60 | D(img_ll) 61 | -------------------------------------------------------------------------------- /SVDeconv/models/get_model.py: -------------------------------------------------------------------------------- 1 | from models.multi_fftlayer_diff import MultiFFTLayer_diff as SVDeconvLayer_diff 2 | # from models.multi_fftlayer_new import MultiFFTLayer_new as SVDeconvLayer 3 | from models.multi_fftlayer import MultiFFTLayer as SVDeconvLayer 4 | 5 | from models.fftlayer import FFTLayer 6 | from models.fftlayer_diff import FFTLayer_diff 7 | from models.unet_128 import Unet as Unet_128 8 | from models.unet import UNet270480 as Unet_diff 9 | 10 | def get_inversion_and_channels(args): 11 | is_svd = "svd" in args.exp_name 12 | is_diff = "diff" in args.exp_name 13 | 14 | if is_svd and not is_diff: 15 | return SVDeconvLayer, 4 if args.load_raw else 3 16 | # return SVDeconvLayer, 8 if args.load_raw else 6 17 | 18 | elif is_svd and is_diff: 19 | return SVDeconvLayer_diff, 6 20 | elif not is_diff: 21 | return FFTLayer, 3 22 | else: 23 | return FFTLayer_diff, 3 24 | 25 | def model(args): 26 | Inversion, in_c = get_inversion_and_channels(args) 27 | 28 | if args.model == "unet-128-pixelshuffle-invert": 29 | return Unet_128(args, in_c=in_c), Inversion(args) 30 | elif args.model == "UNet270480": 31 | return Unet_diff(args, in_c=in_c), Inversion(args) -------------------------------------------------------------------------------- /SVDeconv/models/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | BN_EPS = 1e-4 6 | 7 | 8 | class ConvBnRelu2d(nn.Module): 9 | def __init__(self, in_channels, out_channels, kernel_size=(7, 7), padding=1): 10 | super(ConvBnRelu2d, self).__init__() 11 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=False) 12 | self.bn = nn.BatchNorm2d(out_channels, eps=BN_EPS) 13 | self.relu = nn.ReLU(inplace=True) 14 | 15 | def forward(self, x): 16 | x = self.conv(x) 17 | x = self.bn(x) 18 | x = self.relu(x) 19 | return x 20 | 21 | 22 | class StackEncoder(nn.Module): 23 | def __init__(self, x_channels, y_channels, kernel_size=7): 24 | super(StackEncoder, self).__init__() 25 | padding = (kernel_size - 1) // 2 26 | self.encode = nn.Sequential( 27 | ConvBnRelu2d(x_channels, y_channels, kernel_size=kernel_size, padding=padding), 28 | ConvBnRelu2d(y_channels, y_channels, kernel_size=kernel_size, padding=padding), 29 | ) 30 | 31 | def forward(self, x): 32 | x = self.encode(x) 33 | x_small = F.max_pool2d(x, kernel_size=2, stride=2) 34 | return x, x_small 35 | 36 | 37 | class StackDecoder(nn.Module): 38 | def __init__(self, x_big_channels, x_channels, y_channels, kernel_size=3): 39 | super(StackDecoder, self).__init__() 40 | padding = (kernel_size - 1) // 2 41 | 42 | self.decode = nn.Sequential( 43 | ConvBnRelu2d(x_big_channels + x_channels, y_channels, kernel_size=kernel_size, padding=padding), 44 | ConvBnRelu2d(y_channels, y_channels, kernel_size=kernel_size, padding=padding), 45 | ConvBnRelu2d(y_channels, y_channels, kernel_size=kernel_size, padding=padding), 46 | ) 47 | 48 | def forward(self, x, down_tensor): 49 | _, channels, height, width = down_tensor.size() 50 | x = F.interpolate(x, size=(height, width), mode='bilinear', align_corners=False) 51 | x = torch.cat([x, down_tensor], 1) 52 | x = self.decode(x) 53 | return x 54 | 55 | # 32x32 56 | 57 | class UNet270480(nn.Module): 58 | def __init__(self, args, in_c): 59 | super(UNet270480, self).__init__() 60 | # channels, height, width = in_shape 61 | 62 | self.down1 = StackEncoder(in_c * args.pixelshuffle_ratio ** 2, 64, kernel_size=7) # 256 63 | self.down2 = StackEncoder(64, 128, kernel_size=7) # 128 64 | self.down3 = StackEncoder(128, 256, kernel_size=7) # 64 65 | self.down4 = StackEncoder(256, 512, kernel_size=7) # 32 66 | self.down5 = StackEncoder(512, 1024, kernel_size=7) # 16 67 | 68 | 69 | self.up5 = StackDecoder(1024, 1024, 512, kernel_size=7) # 32 70 | self.up4 = StackDecoder(512, 512, 256, kernel_size=7) # 64 71 | self.up3 = StackDecoder(256, 256, 128, kernel_size=7) # 128 72 | self.up2 = StackDecoder(128, 128, 64, kernel_size=7) # 256 73 | self.up1 = StackDecoder(64, 64, 64, kernel_size=7) # 512 74 | self.classify = nn.Conv2d(64, 3 * args.pixelshuffle_ratio ** 2, kernel_size=1, bias=True) 75 | 76 | 77 | self.center = nn.Sequential(ConvBnRelu2d(1024, 1024, kernel_size=3, padding=1)) 78 | 79 | 80 | def forward(self, x): 81 | # print(x.shape) 82 | out = x 83 | down1, out = self.down1(out) 84 | down2, out = self.down2(out) 85 | down3, out = self.down3(out) 86 | down4, out = self.down4(out) 87 | down5, out = self.down5(out) 88 | 89 | out = self.center(out) 90 | out = self.up5(out, down5) 91 | out = self.up4(out, down4) 92 | out = self.up3(out, down3) 93 | out = self.up2(out, down2) 94 | out = self.up1(out, down1) 95 | 96 | out = self.classify(out) 97 | out = torch.squeeze(out, dim=1) 98 | # print(out.shape) 99 | return out 100 | 101 | 102 | class UNet_small(nn.Module): 103 | def __init__(self, in_shape): 104 |
super(UNet_small, self).__init__() 105 | channels, height, width = in_shape 106 | 107 | self.down1 = StackEncoder(3, 24, kernel_size=3) # 512 108 | 109 | self.up1 = StackDecoder(24, 24, 24, kernel_size=3) # 512 110 | self.classify = nn.Conv2d(24, 3, kernel_size=1, bias=True) 111 | 112 | self.center = nn.Sequential( 113 | ConvBnRelu2d(24, 24, kernel_size=3, padding=1), 114 | ) 115 | 116 | 117 | def forward(self, x): 118 | out = x 119 | down1, out = self.down1(out) 120 | out = self.center(out) 121 | out = self.up1(out, down1) 122 | out = self.classify(out) 123 | out = torch.squeeze(out, dim=1) 124 | return out 125 | -------------------------------------------------------------------------------- /SVDeconv/scripts/train.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py with $1 distdataparallel=True -p 2 | # python val.py with $1 -p 3 | # python train.py with $1 -p -------------------------------------------------------------------------------- /SVDeconv/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .ops import rggb_2_rgb 2 | 3 | __all__ = ['rggb_2_rgb'] -------------------------------------------------------------------------------- /SVDeconv/utils/checkpoint_excluder.txt: -------------------------------------------------------------------------------- 1 | Epoch_*.pth -------------------------------------------------------------------------------- /SVDeconv/utils/dir_helper.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper to create directories 3 | 4 | @py37+ 5 | """ 6 | 7 | import logging 8 | 9 | 10 | def dir_init(args, is_local_rank_0: bool = True): 11 | """ 12 | Creates paths for 13 | 14 | : save_filename 15 | : runs/[train,val] 16 | """ 17 | if is_local_rank_0: 18 | logging.info("Initialising folders ...") 19 | ckpt_dir = args.ckpt_dir / args.exp_name 20 | tensorboard_dump = args.run_dir / args.exp_name 21 | 22 | for dir in [ckpt_dir, tensorboard_dump]: 23 | if not dir.is_dir(): 24 | logging.info(f"Creating {dir.resolve()}") 25 | dir.mkdir(parents=True, exist_ok=True) 26 | -------------------------------------------------------------------------------- /SVDeconv/utils/model_serialization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code taken from 3 | 4 | maskrcnn-benchmark 5 | 6 | https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/utils/model_serialization.py 7 | """ 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 9 | from collections import OrderedDict 10 | import logging 11 | 12 | import torch 13 | 14 | 15 | def align_and_update_state_dicts(model_state_dict, loaded_state_dict): 16 | """ 17 | Strategy: suppose that the models that we will create will have prefixes appended 18 | to each of its keys, for example due to an extra level of nesting that the original 19 | pre-trained weights from ImageNet won't contain. For example, model.state_dict() 20 | might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains 21 | res2.conv1.weight. We thus want to match both parameters together. 22 | For that, we look for each model weight, look among all loaded keys if there is one 23 | that is a suffix of the current weight name, and use it if that's the case. 
24 | If multiple matches exist, take the one with longest size 25 | of the corresponding name. For example, for the same model as before, the pretrained 26 | weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case, 27 | we want to match backbone[0].body.conv1.weight to conv1.weight, and 28 | backbone[0].body.res2.conv1.weight to res2.conv1.weight. 29 | """ 30 | current_keys = sorted(list(model_state_dict.keys())) 31 | loaded_keys = sorted(list(loaded_state_dict.keys())) 32 | # get a matrix of string matches, where each (i, j) entry correspond to the size of the 33 | # loaded_key string, if it matches 34 | match_matrix = [ 35 | len(j) if i.endswith(j) else 0 for i in current_keys for j in loaded_keys 36 | ] 37 | match_matrix = torch.as_tensor(match_matrix).view( 38 | len(current_keys), len(loaded_keys) 39 | ) 40 | max_match_size, idxs = match_matrix.max(1) 41 | # remove indices that correspond to no-match 42 | idxs[max_match_size == 0] = -1 43 | 44 | # used for logging 45 | max_size = max([len(key) for key in current_keys]) if current_keys else 1 46 | max_size_loaded = max([len(key) for key in loaded_keys]) if loaded_keys else 1 47 | log_str_template = "{: <{}} loaded from {: <{}} of shape {}" 48 | logger = logging.getLogger(__name__) 49 | for idx_new, idx_old in enumerate(idxs.tolist()): 50 | if idx_old == -1: 51 | continue 52 | key = current_keys[idx_new] 53 | key_old = loaded_keys[idx_old] 54 | model_state_dict[key] = loaded_state_dict[key_old] 55 | logger.debug( 56 | log_str_template.format( 57 | key, 58 | max_size, 59 | key_old, 60 | max_size_loaded, 61 | tuple(loaded_state_dict[key_old].shape), 62 | ) 63 | ) 64 | 65 | 66 | def strip_prefix_if_present(state_dict, prefix): 67 | keys = sorted(state_dict.keys()) 68 | if not all(key.startswith(prefix) for key in keys): 69 | return state_dict 70 | stripped_state_dict = OrderedDict() 71 | for key, value in state_dict.items(): 72 | stripped_state_dict[key.replace(prefix, "")] = value 73 | return stripped_state_dict 74 | 75 | 76 | def load_state_dict(model, loaded_state_dict): 77 | model_state_dict = model.state_dict() 78 | # if the state_dict comes from a model that was wrapped in a 79 | # DataParallel or DistributedDataParallel during serialization, 80 | # remove the "module" prefix before performing the matching 81 | loaded_state_dict = strip_prefix_if_present(loaded_state_dict, prefix="module.") 82 | align_and_update_state_dicts(model_state_dict, loaded_state_dict) 83 | # use strict loading 84 | model.load_state_dict(model_state_dict, strict=True) -------------------------------------------------------------------------------- /SVDeconv/utils/ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def unpixel_shuffle(feature, r: int = 1): 5 | b, c, h, w = feature.shape 6 | out_channel = c * (r ** 2) 7 | out_h = h // r 8 | out_w = w // r 9 | feature_view = feature.contiguous().view(b, c, out_h, r, out_w, r) 10 | feature_prime = ( 11 | feature_view.permute(0, 1, 3, 5, 2, 4) 12 | .contiguous() 13 | .view(b, out_channel, out_h, out_w) 14 | ) 15 | return feature_prime 16 | 17 | 18 | def sample_patches( 19 | inputs: torch.Tensor, patch_size: int = 3, stride: int = 2 20 | ) -> torch.Tensor: 21 | """ 22 | Patch sampler for feature maps. 23 | Parameters 24 | --- 25 | inputs : torch.Tensor 26 | the input feature maps, shape: (c, h, w). 
27 | patch_size : int, optional 28 | the spatial size of sampled patches 29 | stride : int, optional 30 | the stride of sampling. 31 | Returns 32 | --- 33 | patches : torch.Tensor 34 | extracted patches, shape: (c, patch_size, patch_size, n_patches). 35 | """ 36 | 37 | n, c, h, w = inputs.shape 38 | patches = ( 39 | inputs.unfold(2, patch_size, stride) 40 | .unfold(3, patch_size, stride) 41 | .reshape(n, c, -1, patch_size, patch_size) 42 | .permute(0, 1, 3, 4, 2) 43 | ) 44 | return patches 45 | 46 | def roll_n(X, axis, n): 47 | f_idx = tuple( 48 | slice(None, None, None) if i != axis else slice(0, n, None) 49 | for i in range(X.dim()) 50 | ) 51 | b_idx = tuple( 52 | slice(None, None, None) if i != axis else slice(n, None, None) 53 | for i in range(X.dim()) 54 | ) 55 | front = X[f_idx] 56 | back = X[b_idx] 57 | return torch.cat([back, front], axis) 58 | 59 | def rggb_2_rgb(img: "Tensor[4,H,W]") -> "Tensor[3,H,W]": 60 | if img.shape[0] == 3: 61 | return img 62 | img_rgb = torch.zeros_like(img)[:3] 63 | img_rgb[0] = img[0] 64 | img_rgb[1] = 0.5 * (img[1] + img[2]) 65 | img_rgb[2] = img[3] 66 | return img_rgb -------------------------------------------------------------------------------- /SVDeconv/utils/tupperware.py: -------------------------------------------------------------------------------- 1 | from collections import UserDict 2 | import collections 3 | from recordclass import recordclass 4 | 5 | __author__ = 'github.com/hangtwenty' 6 | 7 | 8 | def tupperware(mapping): 9 | """ Convert mappings to 'tupperwares' recursively. 10 | Lets you use dicts like they're JavaScript Object Literals (~=JSON)... 11 | It recursively turns mappings (dictionaries) into namedtuples. 12 | Thus, you can cheaply create an object whose attributes are accessible 13 | by dotted notation (all the way down). 14 | Use cases: 15 | * Fake objects (useful for dependency injection when you're making 16 | fakes/stubs that are simpler than proper mocks) 17 | * Storing data (like fixtures) in a structured way, in Python code 18 | (data whose initial definition reads nicely like JSON). You could do 19 | this with dictionaries, but namedtuples are immutable, and their 20 | dotted notation can be clearer in some contexts. 21 | .. doctest:: 22 | >>> t = tupperware({ 23 | ... 'foo': 'bar', 24 | ... 'baz': {'qux': 'quux'}, 25 | ... 'tito': { 26 | ... 'tata': 'tutu', 27 | ... 'totoro': 'tots', 28 | ... 'frobnicator': ['this', 'is', 'not', 'a', 'mapping'] 29 | ... } 30 | ... }) 31 | >>> t # doctest: +ELLIPSIS 32 | Tupperware(tito=Tupperware(...), foo='bar', baz=Tupperware(qux='quux')) 33 | >>> t.tito # doctest: +ELLIPSIS 34 | Tupperware(frobnicator=[...], tata='tutu', totoro='tots') 35 | >>> t.tito.tata 36 | 'tutu' 37 | >>> t.tito.frobnicator 38 | ['this', 'is', 'not', 'a', 'mapping'] 39 | >>> t.foo 40 | 'bar' 41 | >>> t.baz.qux 42 | 'quux' 43 | Args: 44 | mapping: An object that might be a mapping. If it's a mapping, convert 45 | it (and all of its contents that are mappings) to namedtuples 46 | (called 'Tupperwares'). 47 | Returns: 48 | A tupperware (a namedtuple (of namedtuples (of namedtuples (...)))). 49 | If argument is not a mapping, it just returns it (this enables the 50 | recursion). 
51 | """ 52 | 53 | if (isinstance(mapping, collections.abc.Mapping) and 54 | not isinstance(mapping, ProtectedDict)): 55 | for key, value in mapping.items(): 56 | mapping[key] = tupperware(value) 57 | return namedtuple_from_mapping(mapping) 58 | return mapping 59 | 60 | 61 | def namedtuple_from_mapping(mapping, name="Tupperware"): 62 | # this_namedtuple_maker = collections.namedtuple(name, mapping.keys()) 63 | this_namedtuple_maker = recordclass(name, mapping.keys()) 64 | return this_namedtuple_maker(**mapping) 65 | 66 | 67 | class ProtectedDict(UserDict): 68 | """ A class that exists just to tell `tupperware` not to eat it. 69 | `tupperware` eats all dicts you give it, recursively; but what if you 70 | actually want a dictionary in there? This will stop it. Just do 71 | ProtectedDict({...}) or ProtectedDict(kwarg=foo). 72 | """ 73 | 74 | 75 | def tupperware_from_kwargs(**kwargs): 76 | return tupperware(kwargs) 77 | -------------------------------------------------------------------------------- /SVDeconv/utils/typing_alias.py: -------------------------------------------------------------------------------- 1 | """ 2 | Meant to be imported as 3 | from utils.typing_helper import * 4 | 5 | To ease # imports for typing. 6 | """ 7 | 8 | __all__ = [ 9 | "TYPE_CHECKING", 10 | "Any", 11 | "Dict", 12 | "DataLoader", 13 | "List", 14 | "lr_scheduler", 15 | "nn.Module", 16 | "optim", 17 | "SummaryWriter", 18 | "tupperware", 19 | "Tensor", 20 | "Tuple", 21 | "Union", 22 | ] 23 | 24 | 25 | from typing import Dict, List, Any, Tuple, Union 26 | from torch.utils.data import DataLoader 27 | from utils.tupperware import tupperware 28 | from torch.utils.tensorboard import SummaryWriter 29 | from torch import Tensor, nn, optim 30 | import torch.optim.lr_scheduler as lr_scheduler 31 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.0 2 | torchvision==0.15.1 3 | albumentations==1.3.0 4 | opencv-python==4.6.0.66 5 | imageio==2.9.0 6 | numpy==1.23.1 7 | imageio-ffmpeg==0.4.2 8 | pytorch-lightning==1.4.2 9 | omegaconf==2.1.1 10 | test-tube>=0.7.5 11 | streamlit==1.12.1 12 | einops==0.3.0 13 | transformers==4.19.2 14 | webdataset==0.2.5 15 | kornia==0.6 16 | open_clip_torch==2.0.2 17 | invisible-watermark>=0.1.5 18 | streamlit-drawable-canvas==0.8.0 19 | torchmetrics==0.6.0 20 | xformers==0.0.16 21 | triton 22 | matplotlib 23 | wandb 24 | pillow 25 | waveprop==0.0.10 26 | sacred 27 | recordclass 28 | lpips 29 | -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 30 | -e git+https://github.com/openai/CLIP.git@main#egg=clip 31 | # -e . 
32 | -------------------------------------------------------------------------------- /tools/copy_gt.py: -------------------------------------------------------------------------------- 1 | # copy gt data to data/phlatcam/val/gts and data/phlatcam/train/gts 2 | # the source data is from data/phlatcam/orig, and the train / val split is from data/phlatcam/text_files/train_target.txt and data/phlatcam/text_files/val_target.txt 3 | import os 4 | import random 5 | import cv2 6 | from tqdm import tqdm 7 | 8 | def copy_gt_data(source_dir, target_dir, target_list): 9 | with open(target_list, 'r') as f: 10 | lines = f.readlines() 11 | for line in tqdm(lines): 12 | line = line.strip() 13 | source_path = os.path.join(source_dir, line) 14 | img_name = os.path.basename(line) 15 | target_path = os.path.join(target_dir, img_name.replace('.JPEG', '.png')) 16 | os.makedirs(os.path.dirname(target_path), exist_ok=True) 17 | print(f"Copy {source_path} to {target_path}") 18 | img = cv2.imread(source_path) 19 | cv2.imwrite(target_path, img) 20 | # copy_gt_data('data/phlatcam', 'data/phlatcam/val/gts', 'data/phlatcam/text_files/val_target.txt') 21 | # copy_gt_data('data/phlatcam', 'data/phlatcam/train/gts', 'data/phlatcam/text_files/train_target.txt') 22 | 23 | # generate the train and val split for decoded simulated captures 24 | def generate_train_val_split(target_train_list, target_val_list, train_list, val_list): 25 | with open(train_list, 'r') as f: 26 | train_lines = f.readlines() 27 | with open(val_list, 'r') as f: 28 | val_lines = f.readlines() 29 | # create the target train and val list 30 | with open(target_train_list, 'w') as f: 31 | for line in train_lines: 32 | img_name = line.split(os.sep)[-1].replace('.JPEG', '.png') 33 | target_line = "train/decoded_sim_captures/" + img_name 34 | f.write(target_line) 35 | with open(target_val_list, 'w') as f: 36 | for line in val_lines: 37 | img_name = line.split(os.sep)[-1].replace('.JPEG', '.png') 38 | target_line = "val/decoded_sim_captures/" + img_name 39 | f.write(target_line) 40 | 41 | generate_train_val_split('data/phlatcam/text_files/decoded_sim_captures_train.txt', 'data/phlatcam/text_files/decoded_sim_captures_val.txt', 'data/phlatcam/text_files/train_target.txt', 'data/phlatcam/text_files/val_target.txt') -------------------------------------------------------------------------------- /tools/data_process.py: -------------------------------------------------------------------------------- 1 | # process svd output to dataset that can be used by NullSpaceDiff 2 | # first, move the output to a single folder 3 | # second, move the original images to a single folder 4 | # third, resize the images to 512x512 5 | 6 | #find the directory of the process file 7 | import os 8 | import shutil 9 | import numpy as np 10 | import cv2 11 | from tqdm import tqdm 12 | tools_dir = os.path.dirname(os.path.abspath(__file__)) 13 | svd_dir = os.path.join(os.path.dirname(tools_dir), 'SVDeconv') 14 | nullspace_dir = os.path.join(os.path.dirname(tools_dir), 'NullSpaceDiff') 15 | dataset = "phlatcam" 16 | 17 | exp_name = "fft-svd-1280-1408-meas-decoded_sim_spatial_weight" 18 | output_name = exp_name 19 | output_val_dir = os.path.join(nullspace_dir, "data/%s/%s/val"%(dataset,output_name)) 20 | output_train_dir = os.path.join(nullspace_dir, "data/%s/%s/train"%(dataset,output_name)) 21 | source_svd_train_dir = os.path.join(svd_dir, 'output/%s'%dataset, exp_name, "train") 22 | source_svd_val_dir = os.path.join(svd_dir, 'output/%s'%dataset, exp_name, "val") 23 | source_orig_dir = 
os.path.join(svd_dir, "data/%s/orig"%dataset) 24 | 25 | index = 0 26 | for cls in tqdm(os.listdir(source_svd_train_dir)): 27 | cls_dir = os.path.join(source_svd_train_dir, cls) 28 | if not os.path.isdir(cls_dir): 29 | continue 30 | for file in os.listdir(cls_dir): 31 | if file.endswith('png') and file.startswith('output_'): 32 | os.makedirs(os.path.join(output_train_dir, "inputs"), exist_ok=True) 33 | shutil.copy(os.path.join(cls_dir, file), os.path.join(os.path.join(output_train_dir, "inputs"), file[7:])) 34 | gt_file = os.path.join(source_orig_dir, cls, file[7:]).replace('png', 'JPEG') 35 | os.makedirs(os.path.join(output_train_dir, "gts"), exist_ok=True) 36 | shutil.copy(gt_file, os.path.join(os.path.join(output_train_dir, "gts"), file[7:])) 37 | 38 | for cls in tqdm(os.listdir(source_svd_val_dir)): 39 | cls_dir = os.path.join(source_svd_val_dir, cls) 40 | if not os.path.isdir(cls_dir): 41 | continue 42 | for file in os.listdir(cls_dir): 43 | if file.endswith('png') and file.startswith('output_'): 44 | os.makedirs(os.path.join(output_val_dir, "inputs"), exist_ok=True) 45 | shutil.copy(os.path.join(cls_dir, file), os.path.join(os.path.join(output_val_dir, "inputs"), file[7:])) 46 | gt_file = os.path.join(source_orig_dir, cls, file[7:]).replace('png', 'JPEG') 47 | os.makedirs(os.path.join(output_val_dir, "gts"), exist_ok=True) 48 | shutil.copy(gt_file, os.path.join(os.path.join(output_val_dir, "gts"), file[7:])) 49 | 50 | 51 | # resize the images to 512x512, and save them to a new folder: inputs_512, gts_512 52 | output_dirs = [output_train_dir, output_val_dir] 53 | for output_dir in output_dirs: 54 | inputs = os.path.join(output_dir, 'inputs') 55 | gts = os.path.join(output_dir, 'gts') 56 | 57 | # create the new folders for resized images 58 | inputs_512 = os.path.join(output_dir, 'inputs_512') 59 | gts_512 = os.path.join(output_dir, 'gts_512') 60 | for dir in [inputs_512, gts_512]: 61 | if not os.path.exists(dir): 62 | os.makedirs(dir) 63 | for file in tqdm(os.listdir(inputs)): 64 | img = cv2.imread(os.path.join(inputs, file)) 65 | img = cv2.resize(img, (512, 512)) 66 | cv2.imwrite(os.path.join(inputs_512, file), img) 67 | for file in tqdm(os.listdir(gts)): 68 | img = cv2.imread(os.path.join(gts, file)) 69 | img = cv2.resize(img, (512, 512)) 70 | cv2.imwrite(os.path.join(gts_512, file), img) 71 | 72 | 73 | --------------------------------------------------------------------------------
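
A minimal sanity-check sketch for the output of tools/data_process.py above; it is not a file in this repository. It assumes only the inputs_512/gts_512 folder layout created by that script, and the path below uses the script's default exp_name; everything else is illustrative.

import os

import cv2

EXP_NAME = "fft-svd-1280-1408-meas-decoded_sim_spatial_weight"  # default exp_name in tools/data_process.py
DATA_ROOT = os.path.join("NullSpaceDiff", "data", "phlatcam", EXP_NAME)


def check_split(split_dir):
    """Verify that inputs_512 and gts_512 hold matching, correctly resized image pairs."""
    inputs_dir = os.path.join(split_dir, "inputs_512")
    gts_dir = os.path.join(split_dir, "gts_512")
    inputs = sorted(os.listdir(inputs_dir))
    gts = sorted(os.listdir(gts_dir))
    assert inputs == gts, f"filename mismatch in {split_dir}"
    for name in inputs:
        # Every resized input should be a readable 512x512 image.
        img = cv2.imread(os.path.join(inputs_dir, name))
        assert img is not None and img.shape[:2] == (512, 512), f"bad image: {name}"
    print(f"{split_dir}: {len(inputs)} aligned 512x512 pairs")


if __name__ == "__main__":
    for split in ("train", "val"):
        check_split(os.path.join(DATA_ROOT, split))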