├── README.md
├── assets
│   ├── expected.jpg
│   ├── main.jpg
│   ├── rect_input.png
│   └── rect_output.png
├── eval.py
├── functions
│   ├── __init__.py
│   ├── ckpt_util.py
│   ├── conjugate_gradient.py
│   ├── degradation.py
│   ├── jpeg.py
│   ├── measurements.py
│   ├── nonuniform
│   │   └── kernels
│   │       └── 000001.npy
│   ├── svd_ddnm.py
│   └── svd_operators.py
├── requirements.txt
├── samples
│   ├── afhq_example.jpg
│   ├── div2k_example.png
│   └── ffhq_example.png
├── sd3_sampler.py
├── solve.py
├── solve_arbitrary.py
├── util.py
└── utils
    ├── admm.py
    ├── blur_util.py
    ├── diffpir_util.py
    ├── img_util.py
    ├── inpaint_util.py
    ├── log_util.py
    ├── motionblur.py
    ├── resizer.py
    └── utils_sisr.py
/README.md:
--------------------------------------------------------------------------------
1 | # FlowDPS: Flow-Driven Posterior Sampling for Inverse Problems
2 | 
3 | ![img](assets/main.jpg)
4 | 
5 | ## Abstract
6 | 
7 | 
8 | ❗️Flow matching is a recent state-of-the-art framework for generative modeling based on ordinary differential equations (ODEs). While closely related to diffusion models, __it provides a more general perspective__ on generative modeling.
9 | 
10 | ❓ Although inverse problem solving has been extensively explored using diffusion models, it has not been rigorously examined within the broader context of flow models. Therefore, __we extend the diffusion inverse solvers (DIS)— which perform posterior sampling by combining a denoising diffusion prior with a likelihood gradient—into the flow framework.__
11 | 
12 | 👍 Our proposed solver, Flow-Driven Posterior Sampling (FlowDPS), can also be seamlessly integrated into a latent flow model with a transformer architecture. Across four linear inverse problems, we confirm that FlowDPS outperforms state-of-the-art alternatives, all without requiring additional training.
13 | 
14 | 
15 | ## Quick Start
16 | 
17 | ### Environment Setup
18 | 
19 | First, clone this repository and install the requirements.
20 | 
21 | ```
22 | git clone https://github.com/FlowDPS-Inverse/FlowDPS.git
23 | cd FlowDPS
24 | conda create -n flowdps python==3.10
25 | conda activate flowdps
26 | pip install -r requirements.txt
27 | ```
28 | 
29 | > The provided requirements.txt installs torch built for CUDA 11.8. If you use a different CUDA version, adjust it accordingly.
30 | 
31 | For motion blur, we also need to clone the repository below.
32 | ```
33 | git clone https://github.com/LeviBorodenko/motionblur.git
34 | ```
35 | 
36 | ### Examples
37 | 
38 | You can quickly check the results using the following examples.
39 | 
40 | **Example 1. Super-resolution x 12 (avg-pool) / Dog**
41 | ```
42 | python solve.py \
43 |     --img_size 768 \
44 |     --img_path samples/afhq_example.jpg \
45 |     --prompt "a photo of a closed face of a dog" \
46 |     --task sr_avgpool \
47 |     --deg_scale 12 \
48 |     --efficient_memory;
49 | ```
50 | 
51 | **Example 2. Super-resolution x 12 (bicubic) / Animal**
52 | ```
53 | python solve.py \
54 |     --img_size 768 \
55 |     --img_path samples/div2k_example.png \
56 |     --prompt "a high quality photo of animal, bush, close-up, fox, grass, green, greenery, hide, panda, red, red panda, stare" \
57 |     --task sr_bicubic \
58 |     --deg_scale 12 \
59 |     --efficient_memory;
60 | ```
61 | > The prompt (after "a high quality photo of") is extracted by DAPE from the measurement.
62 | 
63 | **Example 3. Motion Deblur / Human**
64 | ```
65 | python solve.py \
66 |     --img_size 768 \
67 |     --img_path samples/ffhq_example.png \
68 |     --prompt "a photo of a closed face" \
69 |     --task deblur_motion \
70 |     --deg_scale 61 \
71 |     --efficient_memory;
72 | ```
73 | 
74 | 
75 | For each task, the expected results are
76 | ![expect](assets/expected.jpg)
77 | 
78 | 
79 | ### Arbitrary size problem
80 | You can solve the problem for rectangular images.
81 | 
82 | ```bash
83 | python solve_arbitrary.py \
84 |     --imgH 768 \
85 |     --imgW 1152 \
86 |     --img_path samples/div2k_example.png \
87 |     --prompt "a high quality photo of animal, bush, close-up, fox, grass, green, greenery, hide, panda, red, red panda, stare" \
88 |     --task deblur_motion \
89 |     --deg_scale 61 \
90 |     --efficient_memory;
91 | ```
92 | 
93 | Measurement                 | Reconstruction
94 | :-------------------------:|:-------------------------:
95 | ![](assets/rect_input.png) | ![](assets/rect_output.png)
96 | 
97 | ## How to choose task and solver
98 | 
99 | You can freely change the task and solver using the arguments:
100 | - `task` : sr_avgpool / sr_bicubic / deblur_gauss / deblur_motion
101 | - `method` : psld / flowchef / flowdps
102 | 
103 | If you want to change the amount of degradation, change `deg_scale`. For SR tasks, it refers to the downscale factor, and for deblurring tasks, it refers to the kernel size.
104 | 
105 | ## Efficient inference
106 | 
107 | If you use `--efficient_memory`, the text encoder pre-computes the text embeddings and is then removed from the GPU.
108 | 
109 | This allows us to solve inverse problems on a single GPU with 24GB of VRAM.
110 | 
--------------------------------------------------------------------------------
/assets/expected.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FlowDPS-Inverse/FlowDPS/3e73e443c987ba378b091515ba1e9b0963860e81/assets/expected.jpg
--------------------------------------------------------------------------------
/assets/main.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FlowDPS-Inverse/FlowDPS/3e73e443c987ba378b091515ba1e9b0963860e81/assets/main.jpg
--------------------------------------------------------------------------------
/assets/rect_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FlowDPS-Inverse/FlowDPS/3e73e443c987ba378b091515ba1e9b0963860e81/assets/rect_input.png
--------------------------------------------------------------------------------
/assets/rect_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FlowDPS-Inverse/FlowDPS/3e73e443c987ba378b091515ba1e9b0963860e81/assets/rect_output.png
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import List
3 | from pathlib import Path
4 | 
5 | import torch
6 | from torchvision import transforms
7 | import numpy as np
8 | from PIL import Image
9 | from skimage.metrics import peak_signal_noise_ratio as psnr
10 | 
11 | import lpips
12 | from pytorch_fid import fid_score
13 | from pytorch_msssim import ssim
14 | 
15 | 
16 | def tag(name:str):
17 |     def wrapper(func):
18 |         func.tag = name
19 |         return func
20 |     return wrapper
21 | 
22 | class Factory(object):
23 |     def __init__(self, name: List[str]):
24 |         self.name = name
25 |         
methods = {func for func in dir(self) if callable(getattr(self, func)) and hasattr(getattr(self, func), 'tag')} 26 | self.tagged_method = {getattr(self, func).tag : getattr(self, func) for func in methods} 27 | self._call_func = self.get_method(name) 28 | 29 | def retrieve(self, input_dir, pred_dir): 30 | input_path = sorted(list(Path(input_dir).glob('*.png'))) + sorted(list(Path(input_dir).glob('*.jpg'))) 31 | pred_path = sorted(list(Path(pred_dir).glob('*.png'))) + sorted(list(Path(pred_dir).glob('*.jpg'))) 32 | return input_path, pred_path 33 | 34 | def __call__(self, *args, **kwargs): 35 | output = [] 36 | for _func in self._call_func: 37 | output.append(_func(*args, **kwargs)) 38 | return output 39 | 40 | def get_method(self, name: list[str]): 41 | methods = [] 42 | for n in name: 43 | if n not in self.tagged_method: 44 | raise ValueError(f'Cannot find {self.__class__.__name__} ({n})') 45 | else: 46 | methods.append(self.tagged_method[n]) 47 | return methods 48 | 49 | class Metric(Factory): 50 | @tag('psnr') 51 | def _psnr(self, input_path, pred_path, transform=None, data_range:int=255, **kwargs): 52 | if transform is None: 53 | transform = transforms.Compose([ 54 | transforms.ToTensor() 55 | ]) 56 | 57 | values = [] 58 | in_fs, pred_fs = self.retrieve(input_path, pred_path) 59 | for in_f, pred_f in zip(in_fs, pred_fs): 60 | try: 61 | img1 = np.array(transform(Image.open(in_f).convert('RGB'))) * data_range 62 | img2 = np.array(transform(Image.open(pred_f).convert('RGB'))) * data_range 63 | values.append(psnr(img1, img2, data_range=data_range)) 64 | except: 65 | continue 66 | 67 | return np.mean(values) 68 | 69 | @tag('ssim') 70 | def _ssim(self, input_path, pred_path, transform=None, data_range:int=255, **kwargs): 71 | if transform is None: 72 | transform = transforms.Compose([ 73 | transforms.ToTensor() 74 | ]) 75 | 76 | values = [] 77 | in_fs, pred_fs = self.retrieve(input_path, pred_path) 78 | for in_f, pred_f in zip(in_fs, pred_fs): 79 | try: 80 | img1 = transform(Image.open(in_f).convert('RGB')).unsqueeze(0) * data_range 81 | img2 = transform(Image.open(pred_f).convert('RGB')).unsqueeze(0) * data_range 82 | values.append(ssim(img1, img2).item()) 83 | except: 84 | continue 85 | 86 | return np.mean(values) 87 | 88 | @tag('fid') 89 | def _fid(self, pred_path, label_path, **kwargs): 90 | return fid_score.calculate_fid_given_paths([str(pred_path), str(label_path)], 50, 'cuda', 2048).item() 91 | 92 | @tag('lpips') 93 | def _lpips(self, input_path, pred_path, transform=None, **kwargs): 94 | lpips_fn = lpips.LPIPS(net='vgg').to('cuda').eval() 95 | if transform is None: 96 | transform = transforms.Compose([ 97 | transforms.Resize((224, 224)), 98 | transforms.ToTensor() 99 | ]) 100 | 101 | values = [] 102 | in_fs, pred_fs = self.retrieve(input_path, pred_path) 103 | for in_f, pred_f in zip(in_fs, pred_fs): 104 | try: 105 | img1 = transform(Image.open(in_f).convert('RGB')).to('cuda') 106 | img2 = transform(Image.open(pred_f).convert('RGB')).to('cuda') 107 | values.append(lpips_fn(img1, img2).item()) 108 | except: 109 | continue 110 | 111 | return np.mean(values) 112 | 113 | 114 | if __name__ == '__main__': 115 | parser = argparse.ArgumentParser() 116 | parser.add_argument('--path1', type=Path) 117 | parser.add_argument('--path2', type=Path) 118 | parser.add_argument('--metric', type=str, nargs='+') 119 | parser.add_argument('--prompt', type=str) 120 | args = parser.parse_args() 121 | 122 | metric = Metric(args.metric) 123 | output = metric(args.path1, args.path2, prompt=args.prompt) 124 | 
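    # Usage sketch (directory names below are placeholders, not shipped with the repo):
    #   python eval.py --path1 <reconstruction_dir> --path2 <reference_dir> --metric psnr ssim lpips
    # Each name passed via --metric must match one of the @tag labels defined on Metric above.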
125 | for m, o in zip(args.metric, output): 126 | print(f'{m}: {o}') 127 | -------------------------------------------------------------------------------- /functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FlowDPS-Inverse/FlowDPS/3e73e443c987ba378b091515ba1e9b0963860e81/functions/__init__.py -------------------------------------------------------------------------------- /functions/ckpt_util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "cifar10": "https://heibox.uni-heidelberg.de/f/869980b53bf5416c8a28/?dl=1", 7 | "ema_cifar10": "https://heibox.uni-heidelberg.de/f/2e4f01e2d9ee49bab1d5/?dl=1", 8 | "lsun_bedroom": "https://heibox.uni-heidelberg.de/f/f179d4f21ebc4d43bbfe/?dl=1", 9 | "ema_lsun_bedroom": "https://heibox.uni-heidelberg.de/f/b95206528f384185889b/?dl=1", 10 | "lsun_cat": "https://heibox.uni-heidelberg.de/f/fac870bd988348eab88e/?dl=1", 11 | "ema_lsun_cat": "https://heibox.uni-heidelberg.de/f/0701aac3aa69457bbe34/?dl=1", 12 | "lsun_church": "https://heibox.uni-heidelberg.de/f/2711a6f712e34b06b9d8/?dl=1", 13 | "ema_lsun_church": "https://heibox.uni-heidelberg.de/f/44ccb50ef3c6436db52e/?dl=1", 14 | } 15 | CKPT_MAP = { 16 | "cifar10": "diffusion_cifar10_model/model-790000.ckpt", 17 | "ema_cifar10": "ema_diffusion_cifar10_model/model-790000.ckpt", 18 | "lsun_bedroom": "diffusion_lsun_bedroom_model/model-2388000.ckpt", 19 | "ema_lsun_bedroom": "ema_diffusion_lsun_bedroom_model/model-2388000.ckpt", 20 | "lsun_cat": "diffusion_lsun_cat_model/model-1761000.ckpt", 21 | "ema_lsun_cat": "ema_diffusion_lsun_cat_model/model-1761000.ckpt", 22 | "lsun_church": "diffusion_lsun_church_model/model-4432000.ckpt", 23 | "ema_lsun_church": "ema_diffusion_lsun_church_model/model-4432000.ckpt", 24 | } 25 | MD5_MAP = { 26 | "cifar10": "82ed3067fd1002f5cf4c339fb80c4669", 27 | "ema_cifar10": "1fa350b952534ae442b1d5235cce5cd3", 28 | "lsun_bedroom": "f70280ac0e08b8e696f42cb8e948ff1c", 29 | "ema_lsun_bedroom": "1921fa46b66a3665e450e42f36c2720f", 30 | "lsun_cat": "bbee0e7c3d7abfb6e2539eaf2fb9987b", 31 | "ema_lsun_cat": "646f23f4821f2459b8bafc57fd824558", 32 | "lsun_church": "eb619b8a5ab95ef80f94ce8a5488dae3", 33 | "ema_lsun_church": "fdc68a23938c2397caba4a260bc2445f", 34 | } 35 | 36 | 37 | def download(url, local_path, chunk_size=1024): 38 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 39 | with requests.get(url, stream=True) as r: 40 | total_size = int(r.headers.get("content-length", 0)) 41 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 42 | with open(local_path, "wb") as f: 43 | for data in r.iter_content(chunk_size=chunk_size): 44 | if data: 45 | f.write(data) 46 | pbar.update(chunk_size) 47 | 48 | 49 | def md5_hash(path): 50 | with open(path, "rb") as f: 51 | content = f.read() 52 | return hashlib.md5(content).hexdigest() 53 | 54 | 55 | def get_ckpt_path(name, root=None, check=False, prefix='exp'): 56 | if 'church_outdoor' in name: 57 | name = name.replace('church_outdoor', 'church') 58 | assert name in URL_MAP 59 | # Modify the path when necessary 60 | cachedir = os.environ.get("XDG_CACHE_HOME", os.path.join(prefix, "logs/")) 61 | root = ( 62 | root 63 | if root is not None 64 | else os.path.join(cachedir, "diffusion_models_converted") 65 | ) 66 | path = os.path.join(root, CKPT_MAP[name]) 67 | if not os.path.exists(path) or (check and not md5_hash(path) == 
MD5_MAP[name]): 68 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 69 | download(URL_MAP[name], path) 70 | md5 = md5_hash(path) 71 | assert md5 == MD5_MAP[name], md5 72 | return path 73 | -------------------------------------------------------------------------------- /functions/conjugate_gradient.py: -------------------------------------------------------------------------------- 1 | # From wikipedia. MATLAB code, 2 | # function x = conjgrad(A, b, x) 3 | # r = b - A * x; 4 | # p = r; 5 | # rsold = r' * r; 6 | # 7 | # for i = 1:length(b) 8 | # Ap = A * p; 9 | # alpha = rsold / (p' * Ap); 10 | # x = x + alpha * p; 11 | # r = r - alpha * Ap; 12 | # rsnew = r' * r; 13 | # if sqrt(rsnew) < 1e-10 14 | # break 15 | # end 16 | # p = r + (rsnew / rsold) * p; 17 | # rsold = rsnew; 18 | # end 19 | # end 20 | 21 | from typing import Callable, Optional 22 | 23 | import torch 24 | 25 | 26 | def CG(A: Callable, 27 | b: torch.Tensor, 28 | x: torch.Tensor, 29 | m: Optional[int]=5, 30 | eps: Optional[float]=1e-4, 31 | damping: float=0.0, 32 | use_mm: bool=False) -> torch.Tensor: 33 | 34 | if use_mm: 35 | mm_fn = lambda x, y: torch.mm(x.view(1, -1), y.view(1, -1).T) 36 | else: 37 | mm_fn = lambda x, y: (x * y).flatten().sum() 38 | 39 | orig_shape = x.shape 40 | x = x.view(x.shape[0], -1) 41 | 42 | r = b - A(x) 43 | p = r.clone() 44 | 45 | rsold = mm_fn(r, r) 46 | assert not (rsold != rsold).any(), print(f'NaN detected 1') 47 | 48 | for i in range(m): 49 | Ap = A(p) + damping * p 50 | alpha = rsold / mm_fn(p, Ap) 51 | assert not (alpha != alpha).any(), print(f'NaN detected 2') 52 | 53 | x = x + alpha * p 54 | r = r - alpha * Ap 55 | 56 | rsnew = mm_fn(r, r) 57 | assert not (rsnew != rsnew).any(), print('NaN detected 3') 58 | 59 | if rsnew.sqrt().abs() < eps: 60 | break 61 | 62 | p = r + (rsnew / rsold) * p 63 | rsold = rsnew.clone() 64 | 65 | return x.reshape(orig_shape) 66 | 67 | -------------------------------------------------------------------------------- /functions/degradation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from munch import Munch 4 | 5 | import functions.svd_operators as svd_op 6 | from functions import measurements 7 | from utils.inpaint_util import MaskGenerator 8 | 9 | __DEGRADATION__ = {} 10 | 11 | def register_degradation(name: str): 12 | def wrapper(fn): 13 | if __DEGRADATION__.get(name) is not None: 14 | raise NameError(f'DEGRADATION {name} is already registered') 15 | __DEGRADATION__[name]=fn 16 | return fn 17 | return wrapper 18 | 19 | def get_degradation(name: str, 20 | deg_config: Munch, 21 | device:torch.device): 22 | if __DEGRADATION__.get(name) is None: 23 | raise NameError(f'DEGRADATION {name} does not exist.') 24 | return __DEGRADATION__[name](deg_config, device) 25 | 26 | @register_degradation(name='cs_walshhadamard') 27 | def deg_cs_walshhadamard(deg_config, device): 28 | compressed_size = round(1/deg_config.deg_scale) 29 | A_funcs = svd_op.WalshHadamardCS(deg_config.channels, 30 | deg_config.image_size, 31 | compressed_size, 32 | torch.randperm(deg_config.image_size**2), 33 | device) 34 | return A_funcs 35 | 36 | @register_degradation(name='cs_blockbased') 37 | def deg_cs_blockbased(deg_config, device): 38 | cs_ratio = deg_config.deg_scale 39 | A_funcs = svd_op.CS(deg_config.channels, 40 | deg_config.image_size, 41 | cs_ratio, 42 | device) 43 | return A_funcs 44 | 45 | @register_degradation(name='inpainting') 46 | def deg_inpainting(deg_config, device): 
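    # The saved mask marks kept pixels; the zero (missing) entries are expanded below into
    # per-channel RGB indices and handed to svd_op.Inpainting.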
47 | # TODO: generate mask rather than load 48 | loaded = np.load("exp/inp_masks/mask_768_half.npy") # block 49 | # loaded = np.load("lip_mask_4.npy") 50 | mask = torch.from_numpy(loaded).to(device).reshape(-1) 51 | missing_r = torch.nonzero(mask == 0).long().reshape(-1) * 3 52 | missing_g = missing_r + 1 53 | missing_b = missing_g + 1 54 | missing = torch.cat([missing_r, missing_g, missing_b], dim=0) 55 | A_funcs = svd_op.Inpainting(deg_config.channels, 56 | deg_config.image_size, 57 | missing, 58 | device) 59 | return A_funcs 60 | 61 | @register_degradation(name='denoising') 62 | def deg_denoise(deg_config, device): 63 | A_funcs = svd_op.Denoising(deg_config.channels, 64 | deg_config.image_size, 65 | device) 66 | return A_funcs 67 | 68 | @register_degradation(name='colorization') 69 | def deg_colorization(deg_config, device): 70 | A_funcs = svd_op.Colorization(deg_config.image_size, 71 | device) 72 | return A_funcs 73 | 74 | 75 | @register_degradation(name='sr_avgpool') 76 | def deg_sr_avgpool(deg_config, device): 77 | blur_by = int(deg_config.deg_scale) 78 | A_funcs = svd_op.SuperResolution(deg_config.channels, 79 | deg_config.image_size, 80 | blur_by, 81 | device) 82 | return A_funcs 83 | 84 | @register_degradation(name='sr_bicubic') 85 | def deg_sr_bicubic(deg_config, device): 86 | def bicubic_kernel(x, a=-0.5): 87 | if abs(x) <= 1: 88 | return (a + 2) * abs(x) ** 3 - (a + 3) * abs(x) ** 2 + 1 89 | elif 1 < abs(x) and abs(x) < 2: 90 | return a * abs(x) ** 3 - 5 * a * abs(x) ** 2 + 8 * a * abs(x) - 4 * a 91 | else: 92 | return 0 93 | 94 | factor = int(deg_config.deg_scale) 95 | k = np.zeros((factor * 4)) 96 | for i in range(factor * 4): 97 | x = (1 / factor) * (i - np.floor(factor * 4 / 2) + 0.5) 98 | k[i] = bicubic_kernel(x) 99 | k = k / np.sum(k) 100 | kernel = torch.from_numpy(k).float().to(device) 101 | A_funcs = svd_op.SRConv(kernel / kernel.sum(), 102 | deg_config.channels, 103 | deg_config.image_size, 104 | device, 105 | stride=factor) 106 | return A_funcs 107 | 108 | @register_degradation(name='deblur_uni') 109 | def deg_deblur_uni(deg_config, device): 110 | A_funcs = svd_op.Deblurring(torch.tensor([1/deg_config.deg_scale]*deg_config.deg_scale).to(device), 111 | deg_config.channels, 112 | deg_config.image_size, 113 | device) 114 | return A_funcs 115 | 116 | @register_degradation(name='deblur_gauss') 117 | def deg_deblur_gauss(deg_config, device): 118 | sigma = 3.0 119 | pdf = lambda x: torch.exp(torch.Tensor([-0.5 * (x / sigma) ** 2])) 120 | size = deg_config.deg_scale 121 | ker = [] 122 | for k in range(-size//2, size//2): 123 | ker.append(pdf(k)) 124 | kernel = torch.Tensor(ker).to(device) 125 | A_funcs = svd_op.Deblurring(kernel / kernel.sum(), 126 | deg_config.channels, 127 | deg_config.image_size, 128 | device) 129 | return A_funcs 130 | 131 | @register_degradation(name='deblur_aniso') 132 | def deg_deblur_aniso(deg_config, device): 133 | sigma = 20 134 | pdf = lambda x: torch.exp(torch.Tensor([-0.5 * (x / sigma) ** 2])) 135 | kernel2 = torch.Tensor([pdf(-4), pdf(-3), pdf(-2), pdf(-1), pdf(0), pdf(1), pdf(2), pdf(3), pdf(4)]).to(device) 136 | 137 | sigma = 1 138 | pdf = lambda x: torch.exp(torch.Tensor([-0.5 * (x / sigma) ** 2])) 139 | kernel1 = torch.Tensor([pdf(-4), pdf(-3), pdf(-2), pdf(-1), pdf(0), pdf(1), pdf(2), pdf(3), pdf(4)]).to(device) 140 | 141 | A_funcs = svd_op.Deblurring2D(kernel1 / kernel1.sum(), 142 | kernel2 / kernel2.sum(), 143 | deg_config.channels, 144 | deg_config.image_size, 145 | device) 146 | return A_funcs 147 | 148 | 
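# Usage sketch for the registry above (illustrative values; the Munch fields mirror the
# attributes read by the registered functions):
#   from munch import Munch
#   cfg = Munch(channels=3, image_size=768, deg_scale=12)
#   A_funcs = get_degradation('sr_avgpool', cfg, torch.device('cuda'))
#   y = A_funcs.A(x.reshape(x.size(0), -1))  # flattened input, following the svd_operators convention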
@register_degradation(name='deblur_motion') 149 | def deg_deblur_motion(deg_config, device): 150 | A_funcs = measurements.MotionBlurOperator( 151 | kernel_size=deg_config.deg_scale, 152 | intensity=0.5, 153 | device=device 154 | ) 155 | return A_funcs 156 | 157 | @register_degradation(name='deblur_nonuniform') 158 | def deg_deblur_motion(deg_config, device, kernels=None, masks=None): 159 | A_funcs = measurements.NonuniformBlurOperator( 160 | deg_config.image_size, 161 | deg_config.deg_scale, 162 | device, 163 | kernels=kernels, 164 | masks=masks, 165 | ) 166 | return A_funcs 167 | 168 | 169 | # ======= FOR arbitraty image size ======= 170 | @register_degradation(name='sr_avgpool_gen') 171 | def deg_sr_avgpool_general(deg_config, device): 172 | blur_by = int(deg_config.deg_scale) 173 | A_funcs = svd_op.SuperResolutionGeneral(deg_config.channels, 174 | deg_config.imgH, 175 | deg_config.imgW, 176 | blur_by, 177 | device) 178 | return A_funcs 179 | 180 | @register_degradation(name='deblur_gauss_gen') 181 | def deg_deblur_guass_general(deg_config, device): 182 | A_funcs = measurements.GaussialBlurOperator( 183 | kernel_size=deg_config.deg_scale, 184 | intensity=3.0, 185 | device=device 186 | ) 187 | return A_funcs 188 | 189 | 190 | from functions.jpeg import jpeg_encode, jpeg_decode 191 | 192 | class JPEGOperator(): 193 | def __init__(self, qf: int, device): 194 | self.qf = qf 195 | self.device = device 196 | 197 | def A(self, img): 198 | x_luma, x_chroma = jpeg_encode(img, self.qf) 199 | return x_luma, x_chroma 200 | 201 | def At(self, encoded): 202 | return jpeg_decode(encoded, self.qf) 203 | 204 | 205 | @register_degradation(name='jpeg') 206 | def deg_jpeg(deg_config, device): 207 | A_funcs = JPEGOperator( 208 | qf = deg_config.deg_scale, 209 | device=device 210 | ) 211 | return A_funcs 212 | -------------------------------------------------------------------------------- /functions/jpeg.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from ddrm-jpeg. 5 | # 6 | # Source: 7 | # https://github.com/bahjat-kawar/ddrm-jpeg/blob/master/functions/jpeg_torch.py 8 | # 9 | # The license for the original version of this file can be 10 | # found in this directory (LICENSE_DDRM_JPEG). 11 | # The modifications to this file are subject to the same license. 12 | # --------------------------------------------------------------- 13 | 14 | import numpy as np 15 | import torch 16 | import torch.nn as nn 17 | 18 | def dct1(x): 19 | """ 20 | Discrete Cosine Transform, Type I 21 | :param x: the input signal 22 | :return: the DCT-I of the signal over the last dimension 23 | """ 24 | x_shape = x.shape 25 | x = x.view(-1, x_shape[-1]) 26 | 27 | return torch.fft.rfft(torch.cat([x, x.flip([1])[:, 1:-1]], dim=1))[:, :, 0].view(*x_shape) 28 | 29 | 30 | def idct1(X): 31 | """ 32 | The inverse of DCT-I, which is just a scaled DCT-I 33 | Our definition if idct1 is such that idct1(dct1(x)) == x 34 | :param X: the input signal 35 | :return: the inverse DCT-I of the signal over the last dimension 36 | """ 37 | n = X.shape[-1] 38 | return dct1(X) / (2 * (n - 1)) 39 | 40 | 41 | def dct(x, norm=None): 42 | """ 43 | Discrete Cosine Transform, Type II (a.k.a. 
the DCT) 44 | For the meaning of the parameter `norm`, see: 45 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 46 | :param x: the input signal 47 | :param norm: the normalization, None or 'ortho' 48 | :return: the DCT-II of the signal over the last dimension 49 | """ 50 | x_shape = x.shape 51 | N = x_shape[-1] 52 | x = x.contiguous().view(-1, N) 53 | 54 | v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1) 55 | 56 | Vc = torch.view_as_real(torch.fft.fft(v, dim=1)) 57 | 58 | k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N) 59 | W_r = torch.cos(k) 60 | W_i = torch.sin(k) 61 | 62 | V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i 63 | 64 | if norm == 'ortho': 65 | V[:, 0] /= np.sqrt(N) * 2 66 | V[:, 1:] /= np.sqrt(N / 2) * 2 67 | 68 | V = 2 * V.view(*x_shape) 69 | 70 | return V 71 | 72 | 73 | def idct(X, norm=None): 74 | """ 75 | The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III 76 | Our definition of idct is that idct(dct(x)) == x 77 | For the meaning of the parameter `norm`, see: 78 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 79 | :param X: the input signal 80 | :param norm: the normalization, None or 'ortho' 81 | :return: the inverse DCT-II of the signal over the last dimension 82 | """ 83 | 84 | x_shape = X.shape 85 | N = x_shape[-1] 86 | 87 | X_v = X.contiguous().view(-1, x_shape[-1]) / 2 88 | 89 | if norm == 'ortho': 90 | X_v[:, 0] *= np.sqrt(N) * 2 91 | X_v[:, 1:] *= np.sqrt(N / 2) * 2 92 | 93 | k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N) 94 | W_r = torch.cos(k) 95 | W_i = torch.sin(k) 96 | 97 | V_t_r = X_v 98 | V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1) 99 | 100 | V_r = V_t_r * W_r - V_t_i * W_i 101 | V_i = V_t_r * W_i + V_t_i * W_r 102 | 103 | V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2) 104 | 105 | v = torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1) 106 | x = v.new_zeros(v.shape) 107 | x[:, ::2] += v[:, :N - (N // 2)] 108 | x[:, 1::2] += v.flip([1])[:, :N // 2] 109 | 110 | return x.view(*x_shape) 111 | 112 | 113 | def dct_2d(x, norm=None): 114 | """ 115 | 2-dimentional Discrete Cosine Transform, Type II (a.k.a. the DCT) 116 | For the meaning of the parameter `norm`, see: 117 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 118 | :param x: the input signal 119 | :param norm: the normalization, None or 'ortho' 120 | :return: the DCT-II of the signal over the last 2 dimensions 121 | """ 122 | X1 = dct(x, norm=norm) 123 | X2 = dct(X1.transpose(-1, -2), norm=norm) 124 | return X2.transpose(-1, -2) 125 | 126 | 127 | def idct_2d(X, norm=None): 128 | """ 129 | The inverse to 2D DCT-II, which is a scaled Discrete Cosine Transform, Type III 130 | Our definition of idct is that idct_2d(dct_2d(x)) == x 131 | For the meaning of the parameter `norm`, see: 132 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 133 | :param X: the input signal 134 | :param norm: the normalization, None or 'ortho' 135 | :return: the DCT-II of the signal over the last 2 dimensions 136 | """ 137 | x1 = idct(X, norm=norm) 138 | x2 = idct(x1.transpose(-1, -2), norm=norm) 139 | return x2.transpose(-1, -2) 140 | 141 | 142 | def dct_3d(x, norm=None): 143 | """ 144 | 3-dimentional Discrete Cosine Transform, Type II (a.k.a. 
the DCT) 145 | For the meaning of the parameter `norm`, see: 146 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 147 | :param x: the input signal 148 | :param norm: the normalization, None or 'ortho' 149 | :return: the DCT-II of the signal over the last 3 dimensions 150 | """ 151 | X1 = dct(x, norm=norm) 152 | X2 = dct(X1.transpose(-1, -2), norm=norm) 153 | X3 = dct(X2.transpose(-1, -3), norm=norm) 154 | return X3.transpose(-1, -3).transpose(-1, -2) 155 | 156 | 157 | def idct_3d(X, norm=None): 158 | """ 159 | The inverse to 3D DCT-II, which is a scaled Discrete Cosine Transform, Type III 160 | Our definition of idct is that idct_3d(dct_3d(x)) == x 161 | For the meaning of the parameter `norm`, see: 162 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 163 | :param X: the input signal 164 | :param norm: the normalization, None or 'ortho' 165 | :return: the DCT-II of the signal over the last 3 dimensions 166 | """ 167 | x1 = idct(X, norm=norm) 168 | x2 = idct(x1.transpose(-1, -2), norm=norm) 169 | x3 = idct(x2.transpose(-1, -3), norm=norm) 170 | return x3.transpose(-1, -3).transpose(-1, -2) 171 | 172 | 173 | class LinearDCT(nn.Linear): 174 | """Implement any DCT as a linear layer; in practice this executes around 175 | 50x faster on GPU. Unfortunately, the DCT matrix is stored, which will 176 | increase memory usage. 177 | :param in_features: size of expected input 178 | :param type: which dct function in this file to use""" 179 | def __init__(self, in_features, type, norm=None, bias=False): 180 | self.type = type 181 | self.N = in_features 182 | self.norm = norm 183 | super(LinearDCT, self).__init__(in_features, in_features, bias=bias) 184 | 185 | def reset_parameters(self): 186 | # initialise using dct function 187 | I = torch.eye(self.N) 188 | if self.type == 'dct1': 189 | self.weight.data = dct1(I).data.t() 190 | elif self.type == 'idct1': 191 | self.weight.data = idct1(I).data.t() 192 | elif self.type == 'dct': 193 | self.weight.data = dct(I, norm=self.norm).data.t() 194 | elif self.type == 'idct': 195 | self.weight.data = idct(I, norm=self.norm).data.t() 196 | self.weight.requires_grad = False # don't learn this! 197 | 198 | 199 | def apply_linear_2d(x, linear_layer): 200 | """Can be used with a LinearDCT layer to do a 2D DCT. 201 | :param x: the input signal 202 | :param linear_layer: any PyTorch Linear layer 203 | :return: result of linear layer applied to last 2 dimensions 204 | """ 205 | X1 = linear_layer(x) 206 | X2 = linear_layer(X1.transpose(-1, -2)) 207 | return X2.transpose(-1, -2) 208 | 209 | 210 | def apply_linear_3d(x, linear_layer): 211 | """Can be used with a LinearDCT layer to do a 3D DCT. 
212 | :param x: the input signal 213 | :param linear_layer: any PyTorch Linear layer 214 | :return: result of linear layer applied to last 3 dimensions 215 | """ 216 | X1 = linear_layer(x) 217 | X2 = linear_layer(X1.transpose(-1, -2)) 218 | X3 = linear_layer(X2.transpose(-1, -3)) 219 | return X3.transpose(-1, -3).transpose(-1, -2) 220 | 221 | 222 | def torch_rgb2ycbcr(x): 223 | # Assume x is a batch of size (N x C x H x W) 224 | v = torch.tensor([[.299, .587, .114], [-.1687, -.3313, .5], [.5, -.4187, -.0813]]).to(x.device) 225 | ycbcr = torch.tensordot(x, v, dims=([1], [1])).transpose(3, 2).transpose(2, 1) 226 | ycbcr[:,1:] += 128 227 | return ycbcr 228 | 229 | 230 | def torch_ycbcr2rgb(x): 231 | # Assume x is a batch of size (N x C x H x W) 232 | v = torch.tensor([[ 1.00000000e+00, -3.68199903e-05, 1.40198758e+00], 233 | [ 1.00000000e+00, -3.44113281e-01, -7.14103821e-01], 234 | [ 1.00000000e+00, 1.77197812e+00, -1.34583413e-04]]).to(x.device) 235 | x[:, 1:] -= 128 236 | rgb = torch.tensordot(x, v, dims=([1], [1])).transpose(3, 2).transpose(2, 1) 237 | return rgb 238 | 239 | def chroma_subsample(x): 240 | return x[:, 0:1, :, :], x[:, 1:, ::2, ::2] 241 | 242 | 243 | def general_quant_matrix(qf = 10): 244 | q1 = torch.tensor([ 245 | 16, 11, 10, 16, 24, 40, 51, 61, 246 | 12, 12, 14, 19, 26, 58, 60, 55, 247 | 14, 13, 16, 24, 40, 57, 69, 56, 248 | 14, 17, 22, 29, 51, 87, 80, 62, 249 | 18, 22, 37, 56, 68, 109, 103, 77, 250 | 24, 35, 55, 64, 81, 104, 113, 92, 251 | 49, 64, 78, 87, 103, 121, 120, 101, 252 | 72, 92, 95, 98, 112, 100, 103, 99 253 | ]) 254 | q2 = torch.tensor([ 255 | 17, 18, 24, 47, 99, 99, 99, 99, 256 | 18, 21, 26, 66, 99, 99, 99, 99, 257 | 24, 26, 56, 99, 99, 99, 99, 99, 258 | 47, 66, 99, 99, 99, 99, 99, 99, 259 | 99, 99, 99, 99, 99, 99, 99, 99, 260 | 99, 99, 99, 99, 99, 99, 99, 99, 261 | 99, 99, 99, 99, 99, 99, 99, 99, 262 | 99, 99, 99, 99, 99, 99, 99, 99 263 | ]) 264 | s = (5000 / qf) if qf < 50 else (200 - 2 * qf) 265 | q1 = torch.floor((s * q1 + 50) / 100) 266 | q1[q1 <= 0] = 1 267 | q1[q1 > 255] = 255 268 | q2 = torch.floor((s * q2 + 50) / 100) 269 | q2[q2 <= 0] = 1 270 | q2[q2 > 255] = 255 271 | return q1, q2 272 | 273 | 274 | def quantization_matrix(qf): 275 | return general_quant_matrix(qf) 276 | # q1 = torch.tensor([[ 80, 55, 50, 80, 120, 200, 255, 255], 277 | # [ 60, 60, 70, 95, 130, 255, 255, 255], 278 | # [ 70, 65, 80, 120, 200, 255, 255, 255], 279 | # [ 70, 85, 110, 145, 255, 255, 255, 255], 280 | # [ 90, 110, 185, 255, 255, 255, 255, 255], 281 | # [120, 175, 255, 255, 255, 255, 255, 255], 282 | # [245, 255, 255, 255, 255, 255, 255, 255], 283 | # [255, 255, 255, 255, 255, 255, 255, 255]]) 284 | # q2 = torch.tensor([[ 85, 90, 120, 235, 255, 255, 255, 255], 285 | # [ 90, 105, 130, 255, 255, 255, 255, 255], 286 | # [120, 130, 255, 255, 255, 255, 255, 255], 287 | # [235, 255, 255, 255, 255, 255, 255, 255], 288 | # [255, 255, 255, 255, 255, 255, 255, 255], 289 | # [255, 255, 255, 255, 255, 255, 255, 255], 290 | # [255, 255, 255, 255, 255, 255, 255, 255], 291 | # [255, 255, 255, 255, 255, 255, 255, 255]]) 292 | # return q1, q2 293 | 294 | def jpeg_encode(x, qf): 295 | # Assume x is a batch of size (N x C x H x W) 296 | # [-1, 1] to [0, 255] 297 | x = (x + 1) / 2 * 255 298 | n_batch, _, n_size, _ = x.shape 299 | 300 | x = torch_rgb2ycbcr(x) 301 | x_luma, x_chroma = chroma_subsample(x) 302 | unfold = nn.Unfold(kernel_size=(8, 8), stride=(8, 8)) 303 | x_luma = unfold(x_luma).transpose(2, 1) 304 | x_chroma = unfold(x_chroma).transpose(2, 1) 305 | 306 | x_luma = 
x_luma.reshape(-1, 8, 8) - 128 307 | x_chroma = x_chroma.reshape(-1, 8, 8) - 128 308 | 309 | dct_layer = LinearDCT(8, 'dct', norm='ortho') 310 | dct_layer.to(x_luma.device) 311 | x_luma = apply_linear_2d(x_luma, dct_layer) 312 | x_chroma = apply_linear_2d(x_chroma, dct_layer) 313 | 314 | x_luma = x_luma.view(-1, 1, 8, 8) 315 | x_chroma = x_chroma.view(-1, 2, 8, 8) 316 | 317 | q1, q2 = quantization_matrix(qf) 318 | q1 = q1.to(x_luma.device) 319 | q2 = q2.to(x_luma.device) 320 | x_luma /= q1.view(1, 8, 8) 321 | x_chroma /= q2.view(1, 8, 8) 322 | 323 | x_luma = x_luma.round() 324 | x_chroma = x_chroma.round() 325 | 326 | x_luma = x_luma.reshape(n_batch, (n_size // 8) ** 2, 64).transpose(2, 1) 327 | x_chroma = x_chroma.reshape(n_batch, (n_size // 16) ** 2, 64 * 2).transpose(2, 1) 328 | 329 | fold = nn.Fold(output_size=(n_size, n_size), kernel_size=(8, 8), stride=(8, 8)) 330 | x_luma = fold(x_luma) 331 | fold = nn.Fold(output_size=(n_size // 2, n_size // 2), kernel_size=(8, 8), stride=(8, 8)) 332 | x_chroma = fold(x_chroma) 333 | 334 | return [x_luma, x_chroma] 335 | 336 | 337 | 338 | def jpeg_decode(x, qf): 339 | # Assume x[0] is a batch of size (N x 1 x H x W) (luma) 340 | # Assume x[1:] is a batch of size (N x 2 x H/2 x W/2) (chroma) 341 | x_luma, x_chroma = x 342 | n_batch, _, n_size, _ = x_luma.shape 343 | unfold = nn.Unfold(kernel_size=(8, 8), stride=(8, 8)) 344 | x_luma = unfold(x_luma).transpose(2, 1) 345 | x_luma = x_luma.reshape(-1, 1, 8, 8) 346 | x_chroma = unfold(x_chroma).transpose(2, 1) 347 | x_chroma = x_chroma.reshape(-1, 2, 8, 8) 348 | 349 | q1, q2 = quantization_matrix(qf) 350 | q1 = q1.to(x_luma.device) 351 | q2 = q2.to(x_luma.device) 352 | x_luma *= q1.view(1, 8, 8) 353 | x_chroma *= q2.view(1, 8, 8) 354 | 355 | x_luma = x_luma.reshape(-1, 8, 8) 356 | x_chroma = x_chroma.reshape(-1, 8, 8) 357 | 358 | dct_layer = LinearDCT(8, 'idct', norm='ortho') 359 | dct_layer.to(x_luma.device) 360 | x_luma = apply_linear_2d(x_luma, dct_layer) 361 | x_chroma = apply_linear_2d(x_chroma, dct_layer) 362 | 363 | x_luma = (x_luma + 128).reshape(n_batch, (n_size // 8) ** 2, 64).transpose(2, 1) 364 | x_chroma = (x_chroma + 128).reshape(n_batch, (n_size // 16) ** 2, 64 * 2).transpose(2, 1) 365 | 366 | fold = nn.Fold(output_size=(n_size, n_size), kernel_size=(8, 8), stride=(8, 8)) 367 | x_luma = fold(x_luma) 368 | fold = nn.Fold(output_size=(n_size // 2, n_size // 2), kernel_size=(8, 8), stride=(8, 8)) 369 | x_chroma = fold(x_chroma) 370 | 371 | x_chroma_repeated = torch.zeros(n_batch, 2, n_size, n_size, device = x_luma.device) 372 | x_chroma_repeated[:, :, 0::2, 0::2] = x_chroma 373 | x_chroma_repeated[:, :, 0::2, 1::2] = x_chroma 374 | x_chroma_repeated[:, :, 1::2, 0::2] = x_chroma 375 | x_chroma_repeated[:, :, 1::2, 1::2] = x_chroma 376 | 377 | x = torch.cat([x_luma, x_chroma_repeated], dim=1) 378 | 379 | x = torch_ycbcr2rgb(x) 380 | 381 | # [0, 255] to [-1, 1] 382 | x = x / 255 * 2 - 1 383 | 384 | return x 385 | 386 | 387 | def build_jpeg(qf): 388 | # log.info(f"[Corrupt] JPEG restoration: {qf=} ...") 389 | def jpeg(img): 390 | encoded = jpeg_encode(img, qf) 391 | return jpeg_decode(encoded, qf), encoded 392 | return jpeg -------------------------------------------------------------------------------- /functions/measurements.py: -------------------------------------------------------------------------------- 1 | '''This module handles task-dependent operations (A) and noises (n) to simulate a measurement y=Ax+n.''' 2 | 3 | from abc import ABC, abstractmethod 4 | from functools import partial 5 
| 6 | from torch.nn import functional as F 7 | from torchvision import torch 8 | 9 | from utils.blur_util import Blurkernel 10 | from utils.img_util import fft2d 11 | import numpy as np 12 | from utils.resizer import Resizer 13 | from utils.utils_sisr import pre_calculate_FK, pre_calculate_nonuniform 14 | from torch.fft import fft2, ifft2 15 | 16 | 17 | from motionblur.motionblur import Kernel 18 | 19 | # ================= 20 | # Operation classes 21 | # ================= 22 | 23 | __OPERATOR__ = {} 24 | _GAMMA_FACTOR = 2.2 25 | 26 | def register_operator(name: str): 27 | def wrapper(cls): 28 | if __OPERATOR__.get(name, None): 29 | raise NameError(f"Name {name} is already registered!") 30 | __OPERATOR__[name] = cls 31 | return cls 32 | return wrapper 33 | 34 | 35 | def get_operator(name: str, **kwargs): 36 | if __OPERATOR__.get(name, None) is None: 37 | raise NameError(f"Name {name} is not defined.") 38 | return __OPERATOR__[name](**kwargs) 39 | 40 | 41 | class LinearOperator(ABC): 42 | @abstractmethod 43 | def forward(self, data, **kwargs): 44 | # calculate A * X 45 | pass 46 | 47 | @abstractmethod 48 | def noisy_forward(self, data, **kwargs): 49 | # calculate A * X + n 50 | pass 51 | 52 | @abstractmethod 53 | def transpose(self, data, **kwargs): 54 | # calculate A^T * X 55 | pass 56 | 57 | def ortho_project(self, data, **kwargs): 58 | # calculate (I - A^T * A)X 59 | return data - self.transpose(self.forward(data, **kwargs), **kwargs) 60 | 61 | def project(self, data, measurement, **kwargs): 62 | # calculate (I - A^T * A)Y - AX 63 | return self.ortho_project(measurement, **kwargs) - self.forward(data, **kwargs) 64 | 65 | 66 | @register_operator(name='noise') 67 | class DenoiseOperator(LinearOperator): 68 | def __init__(self, device): 69 | self.device = device 70 | 71 | def forward(self, data): 72 | return data 73 | 74 | def noisy_forward(self, data): 75 | return data 76 | 77 | def transpose(self, data): 78 | return data 79 | 80 | def ortho_project(self, data): 81 | return data 82 | 83 | def project(self, data): 84 | return data 85 | 86 | 87 | @register_operator(name='sr_bicubic') 88 | class SuperResolutionOperator(LinearOperator): 89 | def __init__(self, 90 | in_shape, 91 | scale_factor, 92 | noise, 93 | noise_scale, 94 | device): 95 | self.device = device 96 | self.up_sample = partial(F.interpolate, scale_factor=scale_factor) 97 | self.down_sample = Resizer(in_shape, 1/scale_factor).to(device) 98 | self.noise = get_noise(name=noise, scale=noise_scale) 99 | 100 | def A(self, data, **kwargs): 101 | return self.forward(data, **kwargs) 102 | 103 | def forward(self, data, **kwargs): 104 | return self.down_sample(data) 105 | 106 | def noisy_forward(self, data, **kwargs): 107 | return self.noise.forward(self.down_sample(data)) 108 | 109 | def transpose(self, data, **kwargs): 110 | return self.up_sample(data) 111 | 112 | def project(self, data, measurement, **kwargs): 113 | return data - self.transpose(self.forward(data)) + self.transpose(measurement) 114 | 115 | @register_operator(name='deblur_motion') 116 | class MotionBlurOperator(LinearOperator): 117 | def __init__(self, 118 | kernel_size, 119 | intensity, 120 | device): 121 | self.device = device 122 | self.kernel_size = kernel_size 123 | self.conv = Blurkernel(blur_type='motion', 124 | kernel_size=kernel_size, 125 | std=intensity, 126 | device=device).to(device) # should we keep this device term? 
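        # The random motion kernel used below comes from the external `motionblur` package
        # (Kernel), which the README asks to clone alongside this repository.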
127 | 128 | self.kernel_size =kernel_size 129 | self.intensity = intensity 130 | self.kernel = Kernel(size=(kernel_size, kernel_size), intensity=intensity) 131 | kernel = torch.tensor(self.kernel.kernelMatrix, dtype=torch.float32) 132 | self.conv.update_weights(kernel) 133 | 134 | def forward(self, data, **kwargs): 135 | # A^T * A 136 | return self.conv(data) 137 | 138 | def noisy_forward(self, data, **kwargs): 139 | pass 140 | 141 | def transpose(self, data, **kwargs): 142 | return data 143 | 144 | def change_kernel(self): 145 | self.kernel = Kernel(size=(self.kernel_size, self.kernel_size), intensity=self.intensity) 146 | kernel = torch.tensor(self.kernel.kernelMatrix, dtype=torch.float32) 147 | self.conv.update_weights(kernel) 148 | 149 | def get_kernel(self): 150 | kernel = self.kernel.kernelMatrix.type(torch.float32).to(self.device) 151 | return kernel.view(1, 1, self.kernel_size, self.kernel_size) 152 | 153 | def A(self, data): 154 | return self.forward(data) 155 | 156 | def At(self, data): 157 | return self.transpose(data) 158 | 159 | @register_operator(name='deblur_gauss') 160 | class GaussialBlurOperator(LinearOperator): 161 | def __init__(self, 162 | kernel_size, 163 | intensity, 164 | device): 165 | self.device = device 166 | self.kernel_size = kernel_size 167 | self.conv = Blurkernel(blur_type='gaussian', 168 | kernel_size=kernel_size, 169 | std=intensity, 170 | device=device).to(device) 171 | self.kernel = self.conv.get_kernel() 172 | self.conv.update_weights(self.kernel.type(torch.float32)) 173 | 174 | def forward(self, data, **kwargs): 175 | return self.conv(data) 176 | 177 | def noisy_forward(self, data, **kwargs): 178 | pass 179 | 180 | def transpose(self, data, **kwargs): 181 | return data 182 | 183 | def get_kernel(self): 184 | return self.kernel.view(1, 1, self.kernel_size, self.kernel_size) 185 | 186 | def apply_kernel(self, data, kernel): 187 | self.conv.update_weights(kernel.type(torch.float32)) 188 | return self.conv(data) 189 | 190 | def A(self, data): 191 | return self.forward(data) 192 | 193 | def At(self, data): 194 | return self.transpose(data) 195 | 196 | @register_operator(name='inpainting') 197 | class InpaintingOperator(LinearOperator): 198 | '''This operator get pre-defined mask and return masked image.''' 199 | def __init__(self, 200 | noise, 201 | noise_scale, 202 | device): 203 | self.device = device 204 | self.noise = get_noise(name=noise, scale=noise_scale) 205 | 206 | def forward(self, data, **kwargs): 207 | try: 208 | return data * kwargs.get('mask', None).to(self.device) 209 | except: 210 | raise ValueError("Require mask") 211 | 212 | def noisy_forward(self, data, **kwargs): 213 | return self.noise.forward(self.forward(data, **kwargs)) 214 | 215 | def transpose(self, data, **kwargs): 216 | return data 217 | 218 | def ortho_project(self, data, **kwargs): 219 | return data - self.forward(data, **kwargs) 220 | 221 | # Operator for BlindDPS. 
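# (Blind deblurring treats the blur kernel as unknown, so forward() below takes the current
#  kernel estimate as an explicit argument instead of storing a fixed one.)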
222 | @register_operator(name='blind_blur') 223 | class BlindBlurOperator(LinearOperator): 224 | def __init__(self, device, **kwargs) -> None: 225 | self.device = device 226 | 227 | def forward(self, data, kernel, **kwargs): 228 | return self.apply_kernel(data, kernel) 229 | 230 | def transpose(self, data, **kwargs): 231 | return data 232 | 233 | def apply_kernel(self, data, kernel): 234 | #TODO: faster way to apply conv?:W 235 | 236 | b_img = torch.zeros_like(data).to(self.device) 237 | for i in range(3): 238 | b_img[:, i, :, :] = F.conv2d(data[:, i:i+1, :, :], kernel, padding='same') 239 | return b_img 240 | 241 | 242 | class NonLinearOperator(ABC): 243 | @abstractmethod 244 | def forward(self, data, **kwargs): 245 | pass 246 | 247 | @abstractmethod 248 | def noisy_forward(self, data, **kwargs): 249 | pass 250 | 251 | def project(self, data, measurement, **kwargs): 252 | return data + measurement - self.forward(data) 253 | 254 | @register_operator(name='phase_retrieval') 255 | class PhaseRetrievalOperator(NonLinearOperator): 256 | def __init__(self, 257 | oversample, 258 | noise, 259 | noise_scale, 260 | device): 261 | self.pad = int((oversample / 8.0) * 256) 262 | self.device = device 263 | self.noise = get_noise(name=noise, scale=noise_scale) 264 | 265 | def forward(self, data, **kwargs): 266 | padded = F.pad(data, (self.pad, self.pad, self.pad, self.pad)) 267 | amplitude = fft2d(padded).abs() 268 | return amplitude 269 | 270 | def noisy_forard(self, data, **kwargs): 271 | return self.noise.forward(self.forward(data, **kwargs)) 272 | 273 | @register_operator(name='nonuniform_blur') 274 | class NonuniformBlurOperator(LinearOperator): 275 | def __init__(self, in_shape, kernel_size, device, 276 | kernels=None, masks=None): 277 | self.device = device 278 | self.kernel_size = kernel_size 279 | self.in_shape = in_shape 280 | 281 | # TODO: generalize 282 | if kernels is None and masks is None: 283 | self.kernels = np.load('./functions/nonuniform/kernels/000001.npy') 284 | self.masks = np.load('./functions/nonuniform/masks/000001.npy') 285 | self.kernels = torch.tensor(self.kernels).to(device) 286 | self.masks = torch.tensor(self.masks).to(device) 287 | 288 | # approximate in image space 289 | def forward_img(self, data): 290 | K = self.kernel_size 291 | data = F.pad(data, (K//2, K//2, K//2, K//2), mode="reflect") 292 | kernels = self.kernels.transpose(0, 1) 293 | data_rgb_batch = data.transpose(0, 1) 294 | conv_rgb_batch = F.conv2d(data_rgb_batch, kernels) 295 | y_rgb_batch = conv_rgb_batch * self.masks 296 | y_rgb_batch = torch.sum(y_rgb_batch, dim=1, keepdim=True) 297 | y = y_rgb_batch.transpose(0, 1) 298 | return y 299 | 300 | # NOTE: Only using this operator will make the problem nonlinear (gamma-correction) 301 | def forward_nonlinear(self, data, flatten=False, noiseless=False): 302 | # 1. Usual nonuniform blurring degradataion pipeline 303 | kernels = self.kernels.transpose(0, 1) 304 | FK, FKC = pre_calculate_FK(kernels) 305 | y = ifft2(FK * fft2(data)).real 306 | y = y.transpose(0, 1) 307 | y_rgb_batch = self.masks * y 308 | y_rgb_batch = torch.sum(y_rgb_batch, dim=1, keepdim=True) 309 | y = y_rgb_batch.transpose(0, 1) 310 | F2KM, FKFMy = pre_calculate_nonuniform(data, y, FK, FKC, self.masks) 311 | self.pre_calculated = (FK, FKC, F2KM, FKFMy) 312 | # 2. 
Gamma-correction 313 | y = (y + 1) / 2 314 | y = y ** (1 / _GAMMA_FACTOR) 315 | y = (y - 0.5) / 0.5 316 | return y 317 | 318 | def noisy_forward(self, data, **kwargs): 319 | return self.noise.forward(self.forward(data)) 320 | 321 | # exact in Fourier 322 | def forward(self, data, flatten=False, noiseless=False): 323 | # [1, 25, 33, 33] -> [25, 1, 33, 33] 324 | kernels = self.kernels.transpose(0, 1) 325 | # [25, 1, 512, 512] 326 | FK, FKC = pre_calculate_FK(kernels) 327 | # [25, 3, 512, 512] 328 | y = ifft2(FK * fft2(data)).real 329 | # [3, 25, 512, 512] 330 | y = y.transpose(0, 1) 331 | y_rgb_batch = self.masks * y 332 | # [3, 1, 512, 512] 333 | y_rgb_batch = torch.sum(y_rgb_batch, dim=1, keepdim=True) 334 | # [1, 3, 512, 512] 335 | y = y_rgb_batch.transpose(0, 1) 336 | F2KM, FKFMy = pre_calculate_nonuniform(data, y, FK, FKC, self.masks) 337 | self.pre_calculated = (FK, FKC, F2KM, FKFMy) 338 | return y 339 | 340 | def transpose(self, y, flatten=False): 341 | kernels = self.kernels.transpose(0, 1) 342 | FK, FKC = pre_calculate_FK(kernels) 343 | # 1. braodcast and multiply 344 | # [3, 1, 512, 512] 345 | y_rgb_batch = y.transpose(0, 1) 346 | # [3, 25, 512, 512] 347 | y_rgb_batch = y_rgb_batch.repeat(1, 25, 1, 1) 348 | y = self.masks * y_rgb_batch 349 | # 2. transpose of convolution in Fourier 350 | # [25, 3, 512, 512] 351 | y = y.transpose(0, 1) 352 | ATy_broadcast = ifft2(FKC * fft2(y)).real 353 | # [1, 3, 512, 512] 354 | ATy = torch.sum(ATy_broadcast, dim=0, keepdim=True) 355 | return ATy 356 | 357 | def A(self, data): 358 | return self.forward(data) 359 | 360 | def At(self, data): 361 | return self.transpose(data) 362 | 363 | # ============= 364 | # Noise classes 365 | # ============= 366 | 367 | 368 | __NOISE__ = {} 369 | 370 | def register_noise(name: str): 371 | def wrapper(cls): 372 | if __NOISE__.get(name, None): 373 | raise NameError(f"Name {name} is already defined!") 374 | __NOISE__[name] = cls 375 | return cls 376 | return wrapper 377 | 378 | def get_noise(name: str, **kwargs): 379 | if __NOISE__.get(name, None) is None: 380 | raise NameError(f"Name {name} is not defined.") 381 | noiser = __NOISE__[name](**kwargs) 382 | noiser.__name__ = name 383 | return noiser 384 | 385 | class Noise(ABC): 386 | def __call__(self, data): 387 | return self.forward(data) 388 | 389 | @abstractmethod 390 | def forward(self, data): 391 | pass 392 | 393 | @register_noise(name='clean') 394 | class Clean(Noise): 395 | def __init__(self, **kwargs): 396 | pass 397 | 398 | def forward(self, data): 399 | return data 400 | 401 | @register_noise(name='gaussian') 402 | class GaussianNoise(Noise): 403 | def __init__(self, scale): 404 | self.scale = scale 405 | 406 | def forward(self, data): 407 | return data + torch.randn_like(data, device=data.device) * self.scale 408 | 409 | 410 | @register_noise(name='poisson') 411 | class PoissonNoise(Noise): 412 | def __init__(self, scale): 413 | self.scale = scale 414 | 415 | def forward(self, data): 416 | ''' 417 | Follow skimage.util.random_noise. 
418 | ''' 419 | 420 | # version 3 (stack-overflow) 421 | import numpy as np 422 | data = (data + 1.0) / 2.0 423 | data = data.clamp(0, 1) 424 | device = data.device 425 | data = data.detach().cpu() 426 | data = torch.from_numpy(np.random.poisson(data * 255.0 * self.scale) / 255.0 / self.scale) 427 | data = data * 2.0 - 1.0 428 | data = data.clamp(-1, 1) 429 | return data.to(device) 430 | -------------------------------------------------------------------------------- /functions/nonuniform/kernels/000001.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FlowDPS-Inverse/FlowDPS/3e73e443c987ba378b091515ba1e9b0963860e81/functions/nonuniform/kernels/000001.npy -------------------------------------------------------------------------------- /functions/svd_ddnm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import torchvision.utils as tvu 4 | import torchvision 5 | import os 6 | 7 | class_num = 951 8 | 9 | 10 | def compute_alpha(beta, t): 11 | beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0) 12 | a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1) 13 | return a 14 | 15 | def inverse_data_transform(x): 16 | x = (x + 1.0) / 2.0 17 | return torch.clamp(x, 0.0, 1.0) 18 | 19 | def ddnm_diffusion(x, model, b, eta, A_funcs, y, cls_fn=None, classes=None, config=None): 20 | with torch.no_grad(): 21 | 22 | # setup iteration variables 23 | skip = config.diffusion.num_diffusion_timesteps//config.time_travel.T_sampling 24 | n = x.size(0) 25 | x0_preds = [] 26 | xs = [x] 27 | 28 | # generate time schedule 29 | times = get_schedule_jump(config.time_travel.T_sampling, 30 | config.time_travel.travel_length, 31 | config.time_travel.travel_repeat, 32 | ) 33 | time_pairs = list(zip(times[:-1], times[1:])) 34 | 35 | # reverse diffusion sampling 36 | for i, j in tqdm(time_pairs): 37 | i, j = i*skip, j*skip 38 | if j<0: j=-1 39 | 40 | if j < i: # normal sampling 41 | t = (torch.ones(n) * i).to(x.device) 42 | next_t = (torch.ones(n) * j).to(x.device) 43 | at = compute_alpha(b, t.long()) 44 | at_next = compute_alpha(b, next_t.long()) 45 | xt = xs[-1].to('cuda') 46 | if cls_fn == None: 47 | et = model(xt, t) 48 | else: 49 | classes = torch.ones(xt.size(0), dtype=torch.long, device=torch.device("cuda"))*class_num 50 | et = model(xt, t, classes) 51 | et = et[:, :3] 52 | et = et - (1 - at).sqrt()[0, 0, 0, 0] * cls_fn(x, t, classes) 53 | 54 | if et.size(1) == 6: 55 | et = et[:, :3] 56 | 57 | x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt() 58 | 59 | x0_t_hat = x0_t - A_funcs.A_pinv( 60 | A_funcs.A(x0_t.reshape(x0_t.size(0), -1)) - y.reshape(y.size(0), -1) 61 | ).reshape(*x0_t.size()) 62 | 63 | c1 = (1 - at_next).sqrt() * eta 64 | c2 = (1 - at_next).sqrt() * ((1 - eta ** 2) ** 0.5) 65 | xt_next = at_next.sqrt() * x0_t_hat + c1 * torch.randn_like(x0_t) + c2 * et 66 | 67 | x0_preds.append(x0_t.to('cpu')) 68 | xs.append(xt_next.to('cpu')) 69 | else: # time-travel back 70 | next_t = (torch.ones(n) * j).to(x.device) 71 | at_next = compute_alpha(b, next_t.long()) 72 | x0_t = x0_preds[-1].to('cuda') 73 | 74 | xt_next = at_next.sqrt() * x0_t + torch.randn_like(x0_t) * (1 - at_next).sqrt() 75 | 76 | xs.append(xt_next.to('cpu')) 77 | 78 | return [xs[-1]], [x0_preds[-1]] 79 | 80 | def ddnm_plus_diffusion(x, model, b, eta, A_funcs, y, sigma_y, cls_fn=None, classes=None, config=None): 81 | with torch.no_grad(): 82 | 83 | # setup iteration variables 84 | 
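        # `skip` rescales the coarse T_sampling schedule onto the full num_diffusion_timesteps grid;
        # get_schedule_jump() below additionally inserts RePaint-style time-travel (back-and-forth) steps.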
skip = config.diffusion.num_diffusion_timesteps//config.time_travel.T_sampling 85 | n = x.size(0) 86 | x0_preds = [] 87 | xs = [x] 88 | 89 | # generate time schedule 90 | times = get_schedule_jump(config.time_travel.T_sampling, 91 | config.time_travel.travel_length, 92 | config.time_travel.travel_repeat, 93 | ) 94 | time_pairs = list(zip(times[:-1], times[1:])) 95 | 96 | # reverse diffusion sampling 97 | for i, j in tqdm(time_pairs): 98 | i, j = i*skip, j*skip 99 | if j<0: j=-1 100 | 101 | if j < i: # normal sampling 102 | t = (torch.ones(n) * i).to(x.device) 103 | next_t = (torch.ones(n) * j).to(x.device) 104 | at = compute_alpha(b, t.long()) 105 | at_next = compute_alpha(b, next_t.long()) 106 | xt = xs[-1].to('cuda') 107 | if cls_fn == None: 108 | et = model(xt, t) 109 | else: 110 | classes = torch.ones(xt.size(0), dtype=torch.long, device=torch.device("cuda"))*class_num 111 | et = model(xt, t, classes) 112 | et = et[:, :3] 113 | et = et - (1 - at).sqrt()[0, 0, 0, 0] * cls_fn(x, t, classes) 114 | 115 | if et.size(1) == 6: 116 | et = et[:, :3] 117 | 118 | # Eq. 12 119 | x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt() 120 | 121 | sigma_t = (1 - at_next).sqrt()[0, 0, 0, 0] 122 | 123 | # Eq. 17 124 | x0_t_hat = x0_t - A_funcs.Lambda(A_funcs.A_pinv( 125 | A_funcs.A(x0_t.reshape(x0_t.size(0), -1)) - y.reshape(y.size(0), -1) 126 | ).reshape(x0_t.size(0), -1), at_next.sqrt()[0, 0, 0, 0], sigma_y, sigma_t, eta).reshape(*x0_t.size()) 127 | 128 | # Eq. 51 129 | xt_next = at_next.sqrt() * x0_t_hat + A_funcs.Lambda_noise( 130 | torch.randn_like(x0_t).reshape(x0_t.size(0), -1), 131 | at_next.sqrt()[0, 0, 0, 0], sigma_y, sigma_t, eta, et.reshape(et.size(0), -1)).reshape(*x0_t.size()) 132 | 133 | x0_preds.append(x0_t.to('cpu')) 134 | xs.append(xt_next.to('cpu')) 135 | else: # time-travel back 136 | next_t = (torch.ones(n) * j).to(x.device) 137 | at_next = compute_alpha(b, next_t.long()) 138 | x0_t = x0_preds[-1].to('cuda') 139 | 140 | xt_next = at_next.sqrt() * x0_t + torch.randn_like(x0_t) * (1 - at_next).sqrt() 141 | 142 | xs.append(xt_next.to('cpu')) 143 | 144 | # #ablation 145 | # if i%50==0: 146 | # os.makedirs('/userhome/wyh/ddnm/debug/x0t', exist_ok=True) 147 | # tvu.save_image( 148 | # inverse_data_transform(x0_t[0]), 149 | # os.path.join('/userhome/wyh/ddnm/debug/x0t', f"x0_t_{i}.png") 150 | # ) 151 | 152 | # os.makedirs('/userhome/wyh/ddnm/debug/x0_t_hat', exist_ok=True) 153 | # tvu.save_image( 154 | # inverse_data_transform(x0_t_hat[0]), 155 | # os.path.join('/userhome/wyh/ddnm/debug/x0_t_hat', f"x0_t_hat_{i}.png") 156 | # ) 157 | 158 | # os.makedirs('/userhome/wyh/ddnm/debug/xt_next', exist_ok=True) 159 | # tvu.save_image( 160 | # inverse_data_transform(xt_next[0]), 161 | # os.path.join('/userhome/wyh/ddnm/debug/xt_next', f"xt_next_{i}.png") 162 | # ) 163 | 164 | return [xs[-1]], [x0_preds[-1]] 165 | 166 | # form RePaint 167 | def get_schedule_jump(T_sampling, travel_length, travel_repeat): 168 | 169 | jumps = {} 170 | for j in range(0, T_sampling - travel_length, travel_length): 171 | jumps[j] = travel_repeat - 1 172 | 173 | t = T_sampling 174 | ts = [] 175 | 176 | while t >= 1: 177 | t = t-1 178 | ts.append(t) 179 | 180 | if jumps.get(t, 0) > 0: 181 | jumps[t] = jumps[t] - 1 182 | for _ in range(travel_length): 183 | t = t + 1 184 | ts.append(t) 185 | 186 | ts.append(-1) 187 | 188 | _check_times(ts, -1, T_sampling) 189 | 190 | return ts 191 | 192 | def _check_times(times, t_0, T_sampling): 193 | # Check end 194 | assert times[0] > times[1], (times[0], times[1]) 195 | 196 | # Check beginning 
197 | assert times[-1] == -1, times[-1] 198 | 199 | # Steplength = 1 200 | for t_last, t_cur in zip(times[:-1], times[1:]): 201 | assert abs(t_last - t_cur) == 1, (t_last, t_cur) 202 | 203 | # Value range 204 | for t in times: 205 | assert t >= t_0, (t, t_0) 206 | assert t <= T_sampling, (t, T_sampling) 207 | -------------------------------------------------------------------------------- /functions/svd_operators.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class A_functions: 5 | """ 6 | A class replacing the SVD of a matrix A, perhaps efficiently. 7 | All input vectors are of shape (Batch, ...). 8 | All output vectors are of shape (Batch, DataDimension). 9 | """ 10 | 11 | def V(self, vec): 12 | """ 13 | Multiplies the input vector by V 14 | """ 15 | raise NotImplementedError() 16 | 17 | def Vt(self, vec): 18 | """ 19 | Multiplies the input vector by V transposed 20 | """ 21 | raise NotImplementedError() 22 | 23 | def U(self, vec): 24 | """ 25 | Multiplies the input vector by U 26 | """ 27 | raise NotImplementedError() 28 | 29 | def Ut(self, vec): 30 | """ 31 | Multiplies the input vector by U transposed 32 | """ 33 | raise NotImplementedError() 34 | 35 | def singulars(self): 36 | """ 37 | Returns a vector containing the singular values. The shape of the vector should be the same as the smaller dimension (like U) 38 | """ 39 | raise NotImplementedError() 40 | 41 | def add_zeros(self, vec): 42 | """ 43 | Adds trailing zeros to turn a vector from the small dimension (U) to the big dimension (V) 44 | """ 45 | raise NotImplementedError() 46 | 47 | def A(self, vec): 48 | """ 49 | Multiplies the input vector by A 50 | """ 51 | temp = self.Vt(vec) 52 | singulars = self.singulars() 53 | return self.U(singulars * temp[:, :singulars.shape[0]]) 54 | 55 | def At(self, vec): 56 | """ 57 | Multiplies the input vector by A transposed 58 | """ 59 | temp = self.Ut(vec) 60 | singulars = self.singulars() 61 | return self.V(self.add_zeros(singulars * temp[:, :singulars.shape[0]])) 62 | 63 | def A_pinv(self, vec): 64 | """ 65 | Multiplies the input vector by the pseudo inverse of A 66 | """ 67 | temp = self.Ut(vec) 68 | singulars = self.singulars() 69 | 70 | factors = 1. / singulars 71 | factors[singulars == 0] = 0. 
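        # Moore-Penrose pseudo-inverse via the SVD: apply 1/s_i in the spectral domain,
        # mapping zero singular values to 0 rather than inf, then project back with V
        # (A^+ = V Sigma^+ U^T)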
72 | 73 | # temp[:, :singulars.shape[0]] = temp[:, :singulars.shape[0]] / singulars 74 | temp[:, :singulars.shape[0]] = temp[:, :singulars.shape[0]] * factors 75 | return self.V(self.add_zeros(temp)) 76 | 77 | def A_pinv_eta(self, vec, eta): 78 | """ 79 | Multiplies the input vector by the pseudo inverse of A with factor eta 80 | """ 81 | temp = self.Ut(vec) 82 | singulars = self.singulars() 83 | factors = singulars / (singulars*singulars+eta) 84 | # print(temp.size(), factors.size(), singulars.size()) 85 | temp[:, :singulars.shape[0]] = temp[:, :singulars.shape[0]] * factors 86 | return self.V(self.add_zeros(temp)) 87 | 88 | def Lambda(self, vec, a, sigma_y, sigma_t, eta): 89 | raise NotImplementedError() 90 | 91 | def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon): 92 | raise NotImplementedError() 93 | 94 | 95 | # block-wise CS 96 | class CS(A_functions): 97 | def __init__(self, channels, img_dim, ratio, device): #ratio = 2 or 4 98 | self.img_dim = img_dim 99 | self.channels = channels 100 | self.y_dim = img_dim // 32 101 | self.ratio = 32 102 | A = torch.randn(32**2, 32**2).to(device) 103 | _, _, self.V_small = torch.svd(A, some=False) 104 | self.Vt_small = self.V_small.transpose(0, 1) 105 | self.singulars_small = torch.ones(int(32 * 32 * ratio), device=device) 106 | self.cs_size = self.singulars_small.size(0) 107 | 108 | def V(self, vec): 109 | #reorder the vector back into patches (because singulars are ordered descendingly) 110 | 111 | temp = vec.clone().reshape(vec.shape[0], -1) 112 | patches = torch.zeros(vec.size(0), self.channels * self.y_dim ** 2, self.ratio ** 2, device=vec.device) 113 | patches[:, :, :self.cs_size] = temp[:, :self.channels * self.y_dim ** 2 * self.cs_size].contiguous().reshape( 114 | vec.size(0), -1, self.cs_size) 115 | patches[:, :, self.cs_size:] = temp[:, self.channels * self.y_dim ** 2 * self.cs_size:].contiguous().reshape( 116 | vec.size(0), self.channels * self.y_dim ** 2, -1) 117 | 118 | #multiply each patch by the small V 119 | patches = torch.matmul(self.V_small, patches.reshape(-1, self.ratio**2, 1)).reshape(vec.shape[0], self.channels, -1, self.ratio**2) 120 | #repatch the patches into an image 121 | patches_orig = patches.reshape(vec.shape[0], self.channels, self.y_dim, self.y_dim, self.ratio, self.ratio) 122 | recon = patches_orig.permute(0, 1, 2, 4, 3, 5).contiguous() 123 | recon = recon.reshape(vec.shape[0], self.channels * self.img_dim ** 2) 124 | return recon 125 | 126 | def Vt(self, vec): 127 | #extract flattened patches 128 | patches = vec.clone().reshape(vec.shape[0], self.channels, self.img_dim, self.img_dim) 129 | patches = patches.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio) 130 | patches = patches.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio**2) 131 | #multiply each by the small V transposed 132 | patches = torch.matmul(self.Vt_small, patches.reshape(-1, self.ratio**2, 1)).reshape(vec.shape[0], self.channels, -1, self.ratio**2) 133 | #reorder the vector to have the first entry first (because singulars are ordered descendingly) 134 | recon = torch.zeros(vec.shape[0], self.channels * self.img_dim**2, device=vec.device) 135 | recon[:, :self.channels * self.y_dim ** 2 * self.cs_size] = patches[:, :, :, :self.cs_size].contiguous().reshape( 136 | vec.shape[0], -1) 137 | recon[:, self.channels * self.y_dim ** 2 * self.cs_size:] = patches[:, :, :, self.cs_size:].contiguous().reshape( 138 | vec.shape[0], -1) 139 | return recon 140 | 141 | def U(self, vec): 142 | return 
vec.clone().reshape(vec.shape[0], -1) 143 | 144 | def Ut(self, vec): #U is 1x1, so U^T = U 145 | return vec.clone().reshape(vec.shape[0], -1) 146 | 147 | def singulars(self): 148 | return self.singulars_small.repeat(self.channels * self.y_dim**2) 149 | 150 | def add_zeros(self, vec): 151 | reshaped = vec.clone().reshape(vec.shape[0], -1) 152 | temp = torch.zeros((vec.shape[0], self.channels * self.img_dim**2), device=vec.device) 153 | temp[:, :reshaped.shape[1]] = reshaped 154 | return temp 155 | 156 | 157 | def color2gray(x): 158 | x = x[:, 0:1, :, :] * 0.3333 + x[:, 1:2, :, :] * 0.3334 + x[:, 2:, :, :] * 0.3333 159 | return x 160 | 161 | 162 | def gray2color(x): 163 | base = 0.3333 ** 2 + 0.3334 ** 2 + 0.3333 ** 2 164 | return torch.stack((x * 0.3333 / base, x * 0.3334 / base, x * 0.3333 / base), 1) 165 | 166 | 167 | #a memory inefficient implementation for any general degradation A 168 | class GeneralA(A_functions): 169 | def mat_by_vec(self, M, v): 170 | vshape = v.shape[1] 171 | if len(v.shape) > 2: vshape = vshape * v.shape[2] 172 | if len(v.shape) > 3: vshape = vshape * v.shape[3] 173 | return torch.matmul(M, v.view(v.shape[0], vshape, 174 | 1)).view(v.shape[0], M.shape[0]) 175 | 176 | def __init__(self, A): 177 | self._U, self._singulars, self._V = torch.svd(A, some=False) 178 | self._Vt = self._V.transpose(0, 1) 179 | self._Ut = self._U.transpose(0, 1) 180 | 181 | ZERO = 1e-3 182 | self._singulars[self._singulars < ZERO] = 0 183 | print(len([x.item() for x in self._singulars if x == 0])) 184 | 185 | def V(self, vec): 186 | return self.mat_by_vec(self._V, vec.clone()) 187 | 188 | def Vt(self, vec): 189 | return self.mat_by_vec(self._Vt, vec.clone()) 190 | 191 | def U(self, vec): 192 | return self.mat_by_vec(self._U, vec.clone()) 193 | 194 | def Ut(self, vec): 195 | return self.mat_by_vec(self._Ut, vec.clone()) 196 | 197 | def singulars(self): 198 | return self._singulars 199 | 200 | def add_zeros(self, vec): 201 | out = torch.zeros(vec.shape[0], self._V.shape[0], device=vec.device) 202 | out[:, :self._U.shape[0]] = vec.clone().reshape(vec.shape[0], -1) 203 | return out 204 | 205 | #Walsh-Hadamard Compressive Sensing 206 | class WalshHadamardCS(A_functions): 207 | def fwht(self, vec): #the Fast Walsh Hadamard Transform is the same as its inverse 208 | a = vec.reshape(vec.shape[0], self.channels, self.img_dim**2) 209 | h = 1 210 | while h < self.img_dim**2: 211 | a = a.reshape(vec.shape[0], self.channels, -1, h * 2) 212 | b = a.clone() 213 | a[:, :, :, :h] = b[:, :, :, :h] + b[:, :, :, h:2*h] 214 | a[:, :, :, h:2*h] = b[:, :, :, :h] - b[:, :, :, h:2*h] 215 | h *= 2 216 | a = a.reshape(vec.shape[0], self.channels, self.img_dim**2) / self.img_dim 217 | return a 218 | 219 | def __init__(self, channels, img_dim, ratio, perm, device): 220 | self.channels = channels 221 | self.img_dim = img_dim 222 | self.ratio = ratio 223 | self.perm = perm 224 | self._singulars = torch.ones(channels * img_dim**2 // ratio, device=device) 225 | 226 | def V(self, vec): 227 | temp = torch.zeros(vec.shape[0], self.channels, self.img_dim**2, device=vec.device) 228 | temp[:, :, self.perm] = vec.clone().reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1) 229 | return self.fwht(temp).reshape(vec.shape[0], -1) 230 | 231 | def Vt(self, vec): 232 | return self.fwht(vec.clone())[:, :, self.perm].permute(0, 2, 1).reshape(vec.shape[0], -1) 233 | 234 | def U(self, vec): 235 | return vec.clone().reshape(vec.shape[0], -1) 236 | 237 | def Ut(self, vec): 238 | return vec.clone().reshape(vec.shape[0], -1) 239 | 240 | 
def singulars(self): 241 | return self._singulars 242 | 243 | def add_zeros(self, vec): 244 | out = torch.zeros(vec.shape[0], self.channels * self.img_dim**2, device=vec.device) 245 | out[:, :self.channels * self.img_dim**2 // self.ratio] = vec.clone().reshape(vec.shape[0], -1) 246 | return out 247 | 248 | def Lambda(self, vec, a, sigma_y, sigma_t, eta): 249 | temp_vec = self.fwht(vec.clone())[:, :, self.perm].permute(0, 2, 1).reshape(vec.shape[0], -1) 250 | 251 | singulars = self._singulars 252 | lambda_t = torch.ones(self.channels * self.img_dim ** 2, device=vec.device) 253 | temp = torch.zeros(self.channels * self.img_dim ** 2, device=vec.device) 254 | temp[:singulars.size(0)] = singulars 255 | singulars = temp 256 | inverse_singulars = 1. / singulars 257 | inverse_singulars[singulars == 0] = 0. 258 | 259 | if a != 0 and sigma_y != 0: 260 | change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0 261 | lambda_t = lambda_t * (-change_index + 1.0) + change_index * ( 262 | singulars * sigma_t * (1 - eta ** 2) ** 0.5 / a / sigma_y) 263 | 264 | lambda_t = lambda_t.reshape(1, -1) 265 | temp_vec = temp_vec * lambda_t 266 | 267 | temp_out = torch.zeros(vec.shape[0], self.channels, self.img_dim ** 2, device=vec.device) 268 | temp_out[:, :, self.perm] = temp_vec.clone().reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1) 269 | return self.fwht(temp_out).reshape(vec.shape[0], -1) 270 | 271 | def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon): 272 | temp_vec = vec.clone().reshape( 273 | vec.shape[0], self.channels, self.img_dim ** 2)[:, :, self.perm].permute(0, 2, 1).reshape(vec.shape[0], -1) 274 | temp_eps = epsilon.clone().reshape( 275 | vec.shape[0], self.channels, self.img_dim ** 2)[:, :, self.perm].permute(0, 2, 1).reshape(vec.shape[0], -1) 276 | 277 | d1_t = torch.ones(self.channels * self.img_dim ** 2, device=vec.device) * sigma_t * eta 278 | d2_t = torch.ones(self.channels * self.img_dim ** 2, device=vec.device) * sigma_t * (1 - eta ** 2) ** 0.5 279 | 280 | singulars = self._singulars 281 | temp = torch.zeros(self.channels * self.img_dim ** 2, device=vec.device) 282 | temp[:singulars.size(0)] = singulars 283 | singulars = temp 284 | inverse_singulars = 1. / singulars 285 | inverse_singulars[singulars == 0] = 0. 
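        # the masked updates below appear to implement the DDNM+ noise split: per singular
        # value s_i, sigma_t is compared with a * sigma_y / s_i to decide how much of the
        # injected noise comes from fresh Gaussian noise (d1_t, applied to vec) versus the
        # predicted noise epsilon (d2_t); directions with s_i == 0 keep the plain
        # DDIM-style eta split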
286 | 287 | if a != 0 and sigma_y != 0: 288 | change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0 289 | d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta 290 | d2_t = d2_t * (-change_index + 1.0) 291 | 292 | change_index = (sigma_t > a * sigma_y * inverse_singulars) * 1.0 293 | d1_t = d1_t * (-change_index + 1.0) + torch.sqrt( 294 | change_index * (sigma_t ** 2 - a ** 2 * sigma_y ** 2 * inverse_singulars ** 2)) 295 | d2_t = d2_t * (-change_index + 1.0) 296 | 297 | change_index = (singulars == 0) * 1.0 298 | d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta 299 | d2_t = d2_t * (-change_index + 1.0) + change_index * sigma_t * (1 - eta ** 2) ** 0.5 300 | 301 | d1_t = d1_t.reshape(1, -1) 302 | d2_t = d2_t.reshape(1, -1) 303 | 304 | temp_vec = temp_vec * d1_t 305 | temp_eps = temp_eps * d2_t 306 | 307 | temp_out_vec = torch.zeros(vec.shape[0], self.channels, self.img_dim ** 2, device=vec.device) 308 | temp_out_vec[:, :, self.perm] = temp_vec.clone().reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1) 309 | temp_out_vec = self.fwht(temp_out_vec).reshape(vec.shape[0], -1) 310 | 311 | temp_out_eps = torch.zeros(vec.shape[0], self.channels, self.img_dim ** 2, device=vec.device) 312 | temp_out_eps[:, :, self.perm] = temp_eps.clone().reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1) 313 | temp_out_eps = self.fwht(temp_out_eps).reshape(vec.shape[0], -1) 314 | 315 | return temp_out_vec + temp_out_eps 316 | 317 | 318 | #Inpainting 319 | class Inpainting(A_functions): 320 | def __init__(self, channels, img_dim, missing_indices, device): 321 | self.channels = channels 322 | self.img_dim = img_dim 323 | self._singulars = torch.ones(channels * img_dim**2 - missing_indices.shape[0]).to(device) 324 | self.missing_indices = missing_indices 325 | self.kept_indices = torch.Tensor([i for i in range(channels * img_dim**2) if i not in missing_indices]).to(device).long() 326 | 327 | def V(self, vec): 328 | temp = vec.clone().reshape(vec.shape[0], -1) 329 | out = torch.zeros_like(temp) 330 | out[:, self.kept_indices] = temp[:, :self.kept_indices.shape[0]] 331 | out[:, self.missing_indices] = temp[:, self.kept_indices.shape[0]:] 332 | return out.reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1).reshape(vec.shape[0], -1) 333 | 334 | def Vt(self, vec): 335 | temp = vec.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1).reshape(vec.shape[0], -1) 336 | out = torch.zeros_like(temp) 337 | out[:, :self.kept_indices.shape[0]] = temp[:, self.kept_indices] 338 | out[:, self.kept_indices.shape[0]:] = temp[:, self.missing_indices] 339 | return out 340 | 341 | def U(self, vec): 342 | return vec.clone().reshape(vec.shape[0], -1) 343 | 344 | def Ut(self, vec): 345 | return vec.clone().reshape(vec.shape[0], -1) 346 | 347 | def singulars(self): 348 | return self._singulars 349 | 350 | def add_zeros(self, vec): 351 | temp = torch.zeros((vec.shape[0], self.channels * self.img_dim**2), device=vec.device) 352 | reshaped = vec.clone().reshape(vec.shape[0], -1) 353 | temp[:, :reshaped.shape[1]] = reshaped 354 | return temp 355 | 356 | def Lambda(self, vec, a, sigma_y, sigma_t, eta): 357 | 358 | temp = vec.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1).reshape(vec.shape[0], -1) 359 | out = torch.zeros_like(temp) 360 | out[:, :self.kept_indices.shape[0]] = temp[:, self.kept_indices] 361 | out[:, self.kept_indices.shape[0]:] = temp[:, self.missing_indices] 362 | 363 | singulars = self._singulars 364 | lambda_t = torch.ones(temp.size(1), 
device=vec.device) 365 | temp_singulars = torch.zeros(temp.size(1), device=vec.device) 366 | temp_singulars[:singulars.size(0)] = singulars 367 | singulars = temp_singulars 368 | inverse_singulars = 1. / singulars 369 | inverse_singulars[singulars == 0] = 0. 370 | 371 | if a != 0 and sigma_y != 0: 372 | change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0 373 | lambda_t = lambda_t * (-change_index + 1.0) + change_index * ( 374 | singulars * sigma_t * (1 - eta ** 2) ** 0.5 / a / sigma_y) 375 | 376 | lambda_t = lambda_t.reshape(1, -1) 377 | out = out * lambda_t 378 | 379 | result = torch.zeros_like(temp) 380 | result[:, self.kept_indices] = out[:, :self.kept_indices.shape[0]] 381 | result[:, self.missing_indices] = out[:, self.kept_indices.shape[0]:] 382 | return result.reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1).reshape(vec.shape[0], -1) 383 | 384 | def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon): 385 | temp_vec = vec.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1).reshape(vec.shape[0], -1) 386 | out_vec = torch.zeros_like(temp_vec) 387 | out_vec[:, :self.kept_indices.shape[0]] = temp_vec[:, self.kept_indices] 388 | out_vec[:, self.kept_indices.shape[0]:] = temp_vec[:, self.missing_indices] 389 | 390 | temp_eps = epsilon.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1).reshape(vec.shape[0], -1) 391 | out_eps = torch.zeros_like(temp_eps) 392 | out_eps[:, :self.kept_indices.shape[0]] = temp_eps[:, self.kept_indices] 393 | out_eps[:, self.kept_indices.shape[0]:] = temp_eps[:, self.missing_indices] 394 | 395 | singulars = self._singulars 396 | d1_t = torch.ones(temp_vec.size(1), device=vec.device) * sigma_t * eta 397 | d2_t = torch.ones(temp_vec.size(1), device=vec.device) * sigma_t * (1 - eta ** 2) ** 0.5 398 | 399 | temp_singulars = torch.zeros(temp_vec.size(1), device=vec.device) 400 | temp_singulars[:singulars.size(0)] = singulars 401 | singulars = temp_singulars 402 | inverse_singulars = 1. / singulars 403 | inverse_singulars[singulars == 0] = 0. 
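        # for inpainting the singular values are 1 for kept pixels and the padded entries
        # are 0 for missing pixels, so the masked updates below effectively pick the noise
        # treatment per pixel: kept pixels follow the measurement-aware branch, missing
        # pixels keep the plain eta-weighted noise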
404 | 405 | if a != 0 and sigma_y != 0: 406 | change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0 407 | d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta 408 | d2_t = d2_t * (-change_index + 1.0) 409 | 410 | change_index = (sigma_t > a * sigma_y * inverse_singulars) * 1.0 411 | d1_t = d1_t * (-change_index + 1.0) + torch.sqrt( 412 | change_index * (sigma_t ** 2 - a ** 2 * sigma_y ** 2 * inverse_singulars ** 2)) 413 | d2_t = d2_t * (-change_index + 1.0) 414 | 415 | change_index = (singulars == 0) * 1.0 416 | d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta 417 | d2_t = d2_t * (-change_index + 1.0) + change_index * sigma_t * (1 - eta ** 2) ** 0.5 418 | 419 | d1_t = d1_t.reshape(1, -1) 420 | d2_t = d2_t.reshape(1, -1) 421 | out_vec = out_vec * d1_t 422 | out_eps = out_eps * d2_t 423 | 424 | result_vec = torch.zeros_like(temp_vec) 425 | result_vec[:, self.kept_indices] = out_vec[:, :self.kept_indices.shape[0]] 426 | result_vec[:, self.missing_indices] = out_vec[:, self.kept_indices.shape[0]:] 427 | result_vec = result_vec.reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1).reshape(vec.shape[0], -1) 428 | 429 | result_eps = torch.zeros_like(temp_eps) 430 | result_eps[:, self.kept_indices] = out_eps[:, :self.kept_indices.shape[0]] 431 | result_eps[:, self.missing_indices] = out_eps[:, self.kept_indices.shape[0]:] 432 | result_eps = result_eps.reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1).reshape(vec.shape[0], -1) 433 | 434 | return result_vec + result_eps 435 | 436 | #Denoising 437 | class Denoising(A_functions): 438 | def __init__(self, channels, img_dim, device): 439 | self._singulars = torch.ones(channels * img_dim**2, device=device) 440 | 441 | def V(self, vec): 442 | return vec.clone().reshape(vec.shape[0], -1) 443 | 444 | def Vt(self, vec): 445 | return vec.clone().reshape(vec.shape[0], -1) 446 | 447 | def U(self, vec): 448 | return vec.clone().reshape(vec.shape[0], -1) 449 | 450 | def Ut(self, vec): 451 | return vec.clone().reshape(vec.shape[0], -1) 452 | 453 | def singulars(self): 454 | return self._singulars 455 | 456 | def add_zeros(self, vec): 457 | return vec.clone().reshape(vec.shape[0], -1) 458 | 459 | def Lambda(self, vec, a, sigma_y, sigma_t, eta): 460 | if sigma_t < a * sigma_y: 461 | factor = (sigma_t * (1 - eta ** 2) ** 0.5 / a / sigma_y).item() 462 | return vec * factor 463 | else: 464 | return vec 465 | 466 | def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon): 467 | if sigma_t >= a * sigma_y: 468 | factor = torch.sqrt(sigma_t ** 2 - a ** 2 * sigma_y ** 2).item() 469 | return vec * factor 470 | else: 471 | return vec * sigma_t * eta 472 | 473 | #Super Resolution 474 | class SuperResolution(A_functions): 475 | def __init__(self, channels, img_dim, ratio, device): #ratio = 2 or 4 476 | assert img_dim % ratio == 0 477 | self.img_dim = img_dim 478 | self.channels = channels 479 | self.y_dim = img_dim // ratio 480 | self.ratio = ratio 481 | A = torch.Tensor([[1 / ratio**2] * ratio**2]).to(device) 482 | self.U_small, self.singulars_small, self.V_small = torch.svd(A, some=False) 483 | self.Vt_small = self.V_small.transpose(0, 1) 484 | 485 | def V(self, vec): 486 | #reorder the vector back into patches (because singulars are ordered descendingly) 487 | temp = vec.clone().reshape(vec.shape[0], -1) 488 | patches = torch.zeros(vec.shape[0], self.channels, self.y_dim**2, self.ratio**2, device=vec.device) 489 | patches[:, :, :, 0] = temp[:, :self.channels * self.y_dim**2].view(vec.shape[0], self.channels, -1) 
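        # the first channels * y_dim**2 entries hold the coefficient of the (single) nonzero
        # singular value of every patch; the loop below scatters the remaining coefficients
        # back into positions 1..ratio**2 - 1 of each patch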
490 | for idx in range(self.ratio**2-1): 491 | patches[:, :, :, idx+1] = temp[:, (self.channels*self.y_dim**2+idx)::self.ratio**2-1].view(vec.shape[0], self.channels, -1) 492 | #multiply each patch by the small V 493 | patches = torch.matmul(self.V_small, patches.reshape(-1, self.ratio**2, 1)).reshape(vec.shape[0], self.channels, -1, self.ratio**2) 494 | #repatch the patches into an image 495 | patches_orig = patches.reshape(vec.shape[0], self.channels, self.y_dim, self.y_dim, self.ratio, self.ratio) 496 | recon = patches_orig.permute(0, 1, 2, 4, 3, 5).contiguous() 497 | recon = recon.reshape(vec.shape[0], self.channels * self.img_dim ** 2) 498 | return recon 499 | 500 | def Vt(self, vec): 501 | #extract flattened patches 502 | patches = vec.clone().reshape(vec.shape[0], self.channels, self.img_dim, self.img_dim) 503 | patches = patches.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio) 504 | unfold_shape = patches.shape 505 | patches = patches.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio**2) 506 | #multiply each by the small V transposed 507 | patches = torch.matmul(self.Vt_small, patches.reshape(-1, self.ratio**2, 1)).reshape(vec.shape[0], self.channels, -1, self.ratio**2) 508 | #reorder the vector to have the first entry first (because singulars are ordered descendingly) 509 | recon = torch.zeros(vec.shape[0], self.channels * self.img_dim**2, device=vec.device) 510 | recon[:, :self.channels * self.y_dim**2] = patches[:, :, :, 0].view(vec.shape[0], self.channels * self.y_dim**2) 511 | for idx in range(self.ratio**2-1): 512 | recon[:, (self.channels*self.y_dim**2+idx)::self.ratio**2-1] = patches[:, :, :, idx+1].view(vec.shape[0], self.channels * self.y_dim**2) 513 | return recon 514 | 515 | def U(self, vec): 516 | return self.U_small[0, 0] * vec.clone().reshape(vec.shape[0], -1) 517 | 518 | def Ut(self, vec): #U is 1x1, so U^T = U 519 | return self.U_small[0, 0] * vec.clone().reshape(vec.shape[0], -1) 520 | 521 | def singulars(self): 522 | return self.singulars_small.repeat(self.channels * self.y_dim**2) 523 | 524 | def add_zeros(self, vec): 525 | reshaped = vec.clone().reshape(vec.shape[0], -1) 526 | temp = torch.zeros((vec.shape[0], reshaped.shape[1] * self.ratio**2), device=vec.device) 527 | temp[:, :reshaped.shape[1]] = reshaped 528 | return temp 529 | 530 | def Lambda(self, vec, a, sigma_y, sigma_t, eta): 531 | singulars = self.singulars_small 532 | 533 | patches = vec.clone().reshape(vec.shape[0], self.channels, self.img_dim, self.img_dim) 534 | patches = patches.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio) 535 | patches = patches.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio ** 2) 536 | 537 | patches = torch.matmul(self.Vt_small, patches.reshape(-1, self.ratio**2, 1)).reshape(vec.shape[0], self.channels, -1, self.ratio**2) 538 | 539 | lambda_t = torch.ones(self.ratio ** 2, device=vec.device) 540 | 541 | temp = torch.zeros(self.ratio ** 2, device=vec.device) 542 | temp[:singulars.size(0)] = singulars 543 | singulars = temp 544 | inverse_singulars = 1. / singulars 545 | inverse_singulars[singulars == 0] = 0. 
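        # lambda_t rescales the range-space correction per singular value: where
        # sigma_t < a * sigma_y / s_i (measurement noise would dominate), the correction is
        # damped by s_i * sigma_t * sqrt(1 - eta**2) / (a * sigma_y) instead of applied at
        # full strength (presumably the DDNM+ rule for noisy measurements)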
546 | 547 | if a != 0 and sigma_y != 0: 548 | change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0 549 | lambda_t = lambda_t * (-change_index + 1.0) + change_index * (singulars * sigma_t * (1 - eta ** 2) ** 0.5 / a / sigma_y) 550 | 551 | lambda_t = lambda_t.reshape(1, 1, 1, -1) 552 | # print("lambda_t:", lambda_t) 553 | # print("V:", self.V_small) 554 | # print(lambda_t.size(), self.V_small.size()) 555 | # print("Sigma_t:", torch.matmul(torch.matmul(self.V_small, torch.diag(lambda_t.reshape(-1))), self.Vt_small)) 556 | patches = patches * lambda_t 557 | 558 | 559 | patches = torch.matmul(self.V_small, patches.reshape(-1, self.ratio**2, 1)) 560 | 561 | patches = patches.reshape(vec.shape[0], self.channels, self.y_dim, self.y_dim, self.ratio, self.ratio) 562 | patches = patches.permute(0, 1, 2, 4, 3, 5).contiguous() 563 | patches = patches.reshape(vec.shape[0], self.channels * self.img_dim ** 2) 564 | 565 | return patches 566 | 567 | def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon): 568 | singulars = self.singulars_small 569 | 570 | patches_vec = vec.clone().reshape(vec.shape[0], self.channels, self.img_dim, self.img_dim) 571 | patches_vec = patches_vec.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio) 572 | patches_vec = patches_vec.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio ** 2) 573 | 574 | patches_eps = epsilon.clone().reshape(vec.shape[0], self.channels, self.img_dim, self.img_dim) 575 | patches_eps = patches_eps.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio) 576 | patches_eps = patches_eps.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio ** 2) 577 | 578 | d1_t = torch.ones(self.ratio ** 2, device=vec.device) * sigma_t * eta 579 | d2_t = torch.ones(self.ratio ** 2, device=vec.device) * sigma_t * (1 - eta ** 2) ** 0.5 580 | 581 | temp = torch.zeros(self.ratio ** 2, device=vec.device) 582 | temp[:singulars.size(0)] = singulars 583 | singulars = temp 584 | inverse_singulars = 1. / singulars 585 | inverse_singulars[singulars == 0] = 0. 
586 | 587 | if a != 0 and sigma_y != 0: 588 | 589 | change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0 590 | d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta 591 | d2_t = d2_t * (-change_index + 1.0) 592 | 593 | change_index = (sigma_t > a * sigma_y * inverse_singulars) * 1.0 594 | d1_t = d1_t * (-change_index + 1.0) + torch.sqrt(change_index * (sigma_t ** 2 - a ** 2 * sigma_y ** 2 * inverse_singulars ** 2)) 595 | d2_t = d2_t * (-change_index + 1.0) 596 | 597 | change_index = (singulars == 0) * 1.0 598 | d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta 599 | d2_t = d2_t * (-change_index + 1.0) + change_index * sigma_t * (1 - eta ** 2) ** 0.5 600 | 601 | d1_t = d1_t.reshape(1, 1, 1, -1) 602 | d2_t = d2_t.reshape(1, 1, 1, -1) 603 | patches_vec = patches_vec * d1_t 604 | patches_eps = patches_eps * d2_t 605 | 606 | patches_vec = torch.matmul(self.V_small, patches_vec.reshape(-1, self.ratio**2, 1)) 607 | 608 | patches_vec = patches_vec.reshape(vec.shape[0], self.channels, self.y_dim, self.y_dim, self.ratio, self.ratio) 609 | patches_vec = patches_vec.permute(0, 1, 2, 4, 3, 5).contiguous() 610 | patches_vec = patches_vec.reshape(vec.shape[0], self.channels * self.img_dim ** 2) 611 | 612 | patches_eps = torch.matmul(self.V_small, patches_eps.reshape(-1, self.ratio**2, 1)) 613 | 614 | patches_eps = patches_eps.reshape(vec.shape[0], self.channels, self.y_dim, self.y_dim, self.ratio, self.ratio) 615 | patches_eps = patches_eps.permute(0, 1, 2, 4, 3, 5).contiguous() 616 | patches_eps = patches_eps.reshape(vec.shape[0], self.channels * self.img_dim ** 2) 617 | 618 | return patches_vec + patches_eps 619 | 620 | class SuperResolutionGeneral(SuperResolution): 621 | def __init__(self, channels, imgH, imgW, ratio, device): #ratio = 2 or 4 622 | assert imgH % ratio == 0 and imgW % ratio == 0 623 | self.imgH = imgH 624 | self.imgW = imgW 625 | self.channels = channels 626 | self.yH = imgH // ratio 627 | self.yW = imgW // ratio 628 | self.ratio = ratio 629 | A = torch.Tensor([[1 / ratio**2] * ratio**2]).to(device) 630 | self.U_small, self.singulars_small, self.V_small = torch.svd(A, some=False) 631 | self.Vt_small = self.V_small.transpose(0, 1) 632 | 633 | def V(self, vec): 634 | #reorder the vector back into patches (because singulars are ordered descendingly) 635 | temp = vec.clone().reshape(vec.shape[0], -1) 636 | patches = torch.zeros(vec.shape[0], self.channels, self.yH*self.yW, self.ratio**2, device=vec.device) 637 | patches[:, :, :, 0] = temp[:, :self.channels * self.yH*self.yW].view(vec.shape[0], self.channels, -1) 638 | for idx in range(self.ratio**2-1): 639 | patches[:, :, :, idx+1] = temp[:, (self.channels*self.yH*self.yW+idx)::self.ratio**2-1].view(vec.shape[0], self.channels, -1) 640 | #multiply each patch by the small V 641 | patches = torch.matmul(self.V_small, patches.reshape(-1, self.ratio**2, 1)).reshape(vec.shape[0], self.channels, -1, self.ratio**2) 642 | #repatch the patches into an image 643 | patches_orig = patches.reshape(vec.shape[0], self.channels, self.yH, self.yW, self.ratio, self.ratio) 644 | recon = patches_orig.permute(0, 1, 2, 4, 3, 5).contiguous() 645 | recon = recon.reshape(vec.shape[0], self.channels * self.imgH * self.imgW) 646 | return recon 647 | 648 | def Vt(self, vec): 649 | #extract flattened patches 650 | patches = vec.clone().reshape(vec.shape[0], self.channels, self.imgH, self.imgW) 651 | patches = patches.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio) 652 | unfold_shape = patches.shape 653 | 
patches = patches.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio**2) 654 | #multiply each by the small V transposed 655 | patches = torch.matmul(self.Vt_small, patches.reshape(-1, self.ratio**2, 1)).reshape(vec.shape[0], self.channels, -1, self.ratio**2) 656 | #reorder the vector to have the first entry first (because singulars are ordered descendingly) 657 | recon = torch.zeros(vec.shape[0], self.channels * self.imgH*self.imgW, device=vec.device) 658 | recon[:, :self.channels * self.yH*self.yW] = patches[:, :, :, 0].view(vec.shape[0], self.channels * self.yH*self.yW) 659 | for idx in range(self.ratio**2-1): 660 | recon[:, (self.channels*self.yH*self.yW+idx)::self.ratio**2-1] = patches[:, :, :, idx+1].view(vec.shape[0], self.channels * self.yH*self.yW) 661 | return recon 662 | 663 | def U(self, vec): 664 | return self.U_small[0, 0] * vec.clone().reshape(vec.shape[0], -1) 665 | 666 | def Ut(self, vec): #U is 1x1, so U^T = U 667 | return self.U_small[0, 0] * vec.clone().reshape(vec.shape[0], -1) 668 | 669 | def singulars(self): 670 | return self.singulars_small.repeat(self.channels * self.yH*self.yW) 671 | 672 | def add_zeros(self, vec): 673 | reshaped = vec.clone().reshape(vec.shape[0], -1) 674 | temp = torch.zeros((vec.shape[0], reshaped.shape[1] * self.ratio**2), device=vec.device) 675 | temp[:, :reshaped.shape[1]] = reshaped 676 | return temp 677 | 678 | def Lambda(self, vec, a, sigma_y, sigma_t, eta): 679 | singulars = self.singulars_small 680 | 681 | patches = vec.clone().reshape(vec.shape[0], self.channels, self.imgH, self.imgW) 682 | patches = patches.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio) 683 | patches = patches.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio ** 2) 684 | 685 | patches = torch.matmul(self.Vt_small, patches.reshape(-1, self.ratio**2, 1)).reshape(vec.shape[0], self.channels, -1, self.ratio**2) 686 | 687 | lambda_t = torch.ones(self.ratio ** 2, device=vec.device) 688 | 689 | temp = torch.zeros(self.ratio ** 2, device=vec.device) 690 | temp[:singulars.size(0)] = singulars 691 | singulars = temp 692 | inverse_singulars = 1. / singulars 693 | inverse_singulars[singulars == 0] = 0. 
694 | 695 | if a != 0 and sigma_y != 0: 696 | change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0 697 | lambda_t = lambda_t * (-change_index + 1.0) + change_index * (singulars * sigma_t * (1 - eta ** 2) ** 0.5 / a / sigma_y) 698 | 699 | lambda_t = lambda_t.reshape(1, 1, 1, -1) 700 | # print("lambda_t:", lambda_t) 701 | # print("V:", self.V_small) 702 | # print(lambda_t.size(), self.V_small.size()) 703 | # print("Sigma_t:", torch.matmul(torch.matmul(self.V_small, torch.diag(lambda_t.reshape(-1))), self.Vt_small)) 704 | patches = patches * lambda_t 705 | 706 | 707 | patches = torch.matmul(self.V_small, patches.reshape(-1, self.ratio**2, 1)) 708 | 709 | patches = patches.reshape(vec.shape[0], self.channels, self.yH, self.yW, self.ratio, self.ratio) 710 | patches = patches.permute(0, 1, 2, 4, 3, 5).contiguous() 711 | patches = patches.reshape(vec.shape[0], self.channels * self.imgH * self.imgW) 712 | 713 | return patches 714 | 715 | def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon): 716 | singulars = self.singulars_small 717 | 718 | patches_vec = vec.clone().reshape(vec.shape[0], self.channels, self.imgH, self.imgW) 719 | patches_vec = patches_vec.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio) 720 | patches_vec = patches_vec.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio ** 2) 721 | 722 | patches_eps = epsilon.clone().reshape(vec.shape[0], self.channels, self.imgH, self.imgW) 723 | patches_eps = patches_eps.unfold(2, self.ratio, self.ratio).unfold(3, self.ratio, self.ratio) 724 | patches_eps = patches_eps.contiguous().reshape(vec.shape[0], self.channels, -1, self.ratio ** 2) 725 | 726 | d1_t = torch.ones(self.ratio ** 2, device=vec.device) * sigma_t * eta 727 | d2_t = torch.ones(self.ratio ** 2, device=vec.device) * sigma_t * (1 - eta ** 2) ** 0.5 728 | 729 | temp = torch.zeros(self.ratio ** 2, device=vec.device) 730 | temp[:singulars.size(0)] = singulars 731 | singulars = temp 732 | inverse_singulars = 1. / singulars 733 | inverse_singulars[singulars == 0] = 0. 
734 | 735 | if a != 0 and sigma_y != 0: 736 | 737 | change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0 738 | d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta 739 | d2_t = d2_t * (-change_index + 1.0) 740 | 741 | change_index = (sigma_t > a * sigma_y * inverse_singulars) * 1.0 742 | d1_t = d1_t * (-change_index + 1.0) + torch.sqrt(change_index * (sigma_t ** 2 - a ** 2 * sigma_y ** 2 * inverse_singulars ** 2)) 743 | d2_t = d2_t * (-change_index + 1.0) 744 | 745 | change_index = (singulars == 0) * 1.0 746 | d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta 747 | d2_t = d2_t * (-change_index + 1.0) + change_index * sigma_t * (1 - eta ** 2) ** 0.5 748 | 749 | d1_t = d1_t.reshape(1, 1, 1, -1) 750 | d2_t = d2_t.reshape(1, 1, 1, -1) 751 | patches_vec = patches_vec * d1_t 752 | patches_eps = patches_eps * d2_t 753 | 754 | patches_vec = torch.matmul(self.V_small, patches_vec.reshape(-1, self.ratio**2, 1)) 755 | 756 | patches_vec = patches_vec.reshape(vec.shape[0], self.channels, self.yH, self.yW, self.ratio, self.ratio) 757 | patches_vec = patches_vec.permute(0, 1, 2, 4, 3, 5).contiguous() 758 | patches_vec = patches_vec.reshape(vec.shape[0], self.channels * self.imgH * self.imgW) 759 | 760 | patches_eps = torch.matmul(self.V_small, patches_eps.reshape(-1, self.ratio**2, 1)) 761 | 762 | patches_eps = patches_eps.reshape(vec.shape[0], self.channels, self.yH, self.yW, self.ratio, self.ratio) 763 | patches_eps = patches_eps.permute(0, 1, 2, 4, 3, 5).contiguous() 764 | patches_eps = patches_eps.reshape(vec.shape[0], self.channels * self.imgH * self.imgW) 765 | 766 | return patches_vec + patches_eps 767 | 768 | #Colorization 769 | class Colorization(A_functions): 770 | def __init__(self, img_dim, device): 771 | self.channels = 3 772 | self.img_dim = img_dim 773 | #Do the SVD for the per-pixel matrix 774 | A = torch.Tensor([[0.3333, 0.3334, 0.3333]]).to(device) 775 | self.U_small, self.singulars_small, self.V_small = torch.svd(A, some=False) 776 | self.Vt_small = self.V_small.transpose(0, 1) 777 | 778 | def V(self, vec): 779 | #get the needles 780 | needles = vec.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1) #shape: B, WA, C' 781 | #multiply each needle by the small V 782 | needles = torch.matmul(self.V_small, needles.reshape(-1, self.channels, 1)).reshape(vec.shape[0], -1, self.channels) #shape: B, WA, C 783 | #permute back to vector representation 784 | recon = needles.permute(0, 2, 1) #shape: B, C, WA 785 | return recon.reshape(vec.shape[0], -1) 786 | 787 | def Vt(self, vec): 788 | #get the needles 789 | needles = vec.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1) #shape: B, WA, C 790 | #multiply each needle by the small V transposed 791 | needles = torch.matmul(self.Vt_small, needles.reshape(-1, self.channels, 1)).reshape(vec.shape[0], -1, self.channels) #shape: B, WA, C' 792 | #reorder the vector so that the first entry of each needle is at the top 793 | recon = needles.permute(0, 2, 1).reshape(vec.shape[0], -1) 794 | return recon 795 | 796 | def U(self, vec): 797 | return self.U_small[0, 0] * vec.clone().reshape(vec.shape[0], -1) 798 | 799 | def Ut(self, vec): #U is 1x1, so U^T = U 800 | return self.U_small[0, 0] * vec.clone().reshape(vec.shape[0], -1) 801 | 802 | def singulars(self): 803 | return self.singulars_small.repeat(self.img_dim**2) 804 | 805 | def add_zeros(self, vec): 806 | reshaped = vec.clone().reshape(vec.shape[0], -1) 807 | temp = torch.zeros((vec.shape[0], self.channels * self.img_dim**2), 
device=vec.device) 808 | temp[:, :self.img_dim**2] = reshaped 809 | return temp 810 | 811 | def Lambda(self, vec, a, sigma_y, sigma_t, eta): 812 | needles = vec.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1) 813 | 814 | needles = torch.matmul(self.Vt_small, needles.reshape(-1, self.channels, 1)).reshape(vec.shape[0], -1, self.channels) 815 | 816 | singulars = self.singulars_small 817 | lambda_t = torch.ones(self.channels, device=vec.device) 818 | temp = torch.zeros(self.channels, device=vec.device) 819 | temp[:singulars.size(0)] = singulars 820 | singulars = temp 821 | inverse_singulars = 1. / singulars 822 | inverse_singulars[singulars == 0] = 0. 823 | 824 | if a != 0 and sigma_y != 0: 825 | change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0 826 | lambda_t = lambda_t * (-change_index + 1.0) + change_index * ( 827 | singulars * sigma_t * (1 - eta ** 2) ** 0.5 / a / sigma_y) 828 | 829 | lambda_t = lambda_t.reshape(1, 1, self.channels) 830 | needles = needles * lambda_t 831 | 832 | needles = torch.matmul(self.V_small, needles.reshape(-1, self.channels, 1)).reshape(vec.shape[0], -1, self.channels) 833 | 834 | recon = needles.permute(0, 2, 1).reshape(vec.shape[0], -1) 835 | return recon 836 | 837 | def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon): 838 | singulars = self.singulars_small 839 | 840 | needles_vec = vec.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1) 841 | needles_epsilon = epsilon.clone().reshape(vec.shape[0], self.channels, -1).permute(0, 2, 1) 842 | 843 | d1_t = torch.ones(self.channels, device=vec.device) * sigma_t * eta 844 | d2_t = torch.ones(self.channels, device=vec.device) * sigma_t * (1 - eta ** 2) ** 0.5 845 | 846 | temp = torch.zeros(self.channels, device=vec.device) 847 | temp[:singulars.size(0)] = singulars 848 | singulars = temp 849 | inverse_singulars = 1. / singulars 850 | inverse_singulars[singulars == 0] = 0. 
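        # the per-pixel graying matrix [0.3333, 0.3334, 0.3333] has a single nonzero singular
        # value (about 0.577); the two padded entries are zero, so two of the three color
        # directions always take the plain eta-weighted noise branch below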
851 | 852 | if a != 0 and sigma_y != 0: 853 | change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0 854 | d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta 855 | d2_t = d2_t * (-change_index + 1.0) 856 | 857 | change_index = (sigma_t > a * sigma_y * inverse_singulars) * 1.0 858 | d1_t = d1_t * (-change_index + 1.0) + torch.sqrt( 859 | change_index * (sigma_t ** 2 - a ** 2 * sigma_y ** 2 * inverse_singulars ** 2)) 860 | d2_t = d2_t * (-change_index + 1.0) 861 | 862 | change_index = (singulars == 0) * 1.0 863 | d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta 864 | d2_t = d2_t * (-change_index + 1.0) + change_index * sigma_t * (1 - eta ** 2) ** 0.5 865 | 866 | d1_t = d1_t.reshape(1, 1, self.channels) 867 | d2_t = d2_t.reshape(1, 1, self.channels) 868 | 869 | needles_vec = needles_vec * d1_t 870 | needles_epsilon = needles_epsilon * d2_t 871 | 872 | needles_vec = torch.matmul(self.V_small, needles_vec.reshape(-1, self.channels, 1)).reshape(vec.shape[0], -1, self.channels) 873 | recon_vec = needles_vec.permute(0, 2, 1).reshape(vec.shape[0], -1) 874 | 875 | needles_epsilon = torch.matmul(self.V_small, needles_epsilon.reshape(-1, self.channels, 1)).reshape(vec.shape[0], -1,self.channels) 876 | recon_epsilon = needles_epsilon.permute(0, 2, 1).reshape(vec.shape[0], -1) 877 | 878 | return recon_vec + recon_epsilon 879 | 880 | #Walsh-Aadamard Compressive Sensing 881 | class WalshAadamardCS(A_functions): 882 | def fwht(self, vec): #the Fast Walsh Aadamard Transform is the same as its inverse 883 | a = vec.reshape(vec.shape[0], self.channels, self.img_dim**2) 884 | h = 1 885 | while h < self.img_dim**2: 886 | a = a.reshape(vec.shape[0], self.channels, -1, h * 2) 887 | b = a.clone() 888 | a[:, :, :, :h] = b[:, :, :, :h] + b[:, :, :, h:2*h] 889 | a[:, :, :, h:2*h] = b[:, :, :, :h] - b[:, :, :, h:2*h] 890 | h *= 2 891 | a = a.reshape(vec.shape[0], self.channels, self.img_dim**2) / self.img_dim 892 | return a 893 | 894 | def __init__(self, channels, img_dim, ratio, perm, device): 895 | self.channels = channels 896 | self.img_dim = img_dim 897 | self.ratio = ratio 898 | self.perm = perm 899 | self._singulars = torch.ones(channels * img_dim**2 // ratio, device=device) 900 | 901 | def V(self, vec): 902 | temp = torch.zeros(vec.shape[0], self.channels, self.img_dim**2, device=vec.device) 903 | temp[:, :, self.perm] = vec.clone().reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1) 904 | return self.fwht(temp).reshape(vec.shape[0], -1) 905 | 906 | def Vt(self, vec): 907 | return self.fwht(vec.clone())[:, :, self.perm].permute(0, 2, 1).reshape(vec.shape[0], -1) 908 | 909 | def U(self, vec): 910 | return vec.clone().reshape(vec.shape[0], -1) 911 | 912 | def Ut(self, vec): 913 | return vec.clone().reshape(vec.shape[0], -1) 914 | 915 | def singulars(self): 916 | return self._singulars 917 | 918 | def add_zeros(self, vec): 919 | out = torch.zeros(vec.shape[0], self.channels * self.img_dim**2, device=vec.device) 920 | out[:, :self.channels * self.img_dim**2 // self.ratio] = vec.clone().reshape(vec.shape[0], -1) 921 | return out 922 | 923 | def Lambda(self, vec, a, sigma_y, sigma_t, eta): 924 | temp_vec = self.fwht(vec.clone())[:, :, self.perm].permute(0, 2, 1).reshape(vec.shape[0], -1) 925 | 926 | singulars = self._singulars 927 | lambda_t = torch.ones(self.channels * self.img_dim ** 2, device=vec.device) 928 | temp = torch.zeros(self.channels * self.img_dim ** 2, device=vec.device) 929 | temp[:singulars.size(0)] = singulars 930 | singulars = temp 931 | 
inverse_singulars = 1. / singulars 932 | inverse_singulars[singulars == 0] = 0. 933 | 934 | if a != 0 and sigma_y != 0: 935 | change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0 936 | lambda_t = lambda_t * (-change_index + 1.0) + change_index * ( 937 | singulars * sigma_t * (1 - eta ** 2) ** 0.5 / a / sigma_y) 938 | 939 | lambda_t = lambda_t.reshape(1, -1) 940 | temp_vec = temp_vec * lambda_t 941 | 942 | temp_out = torch.zeros(vec.shape[0], self.channels, self.img_dim ** 2, device=vec.device) 943 | temp_out[:, :, self.perm] = temp_vec.clone().reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1) 944 | return self.fwht(temp_out).reshape(vec.shape[0], -1) 945 | 946 | def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon): 947 | temp_vec = vec.clone().reshape( 948 | vec.shape[0], self.channels, self.img_dim ** 2)[:, :, self.perm].permute(0, 2, 1).reshape(vec.shape[0], -1) 949 | temp_eps = epsilon.clone().reshape( 950 | vec.shape[0], self.channels, self.img_dim ** 2)[:, :, self.perm].permute(0, 2, 1).reshape(vec.shape[0], -1) 951 | 952 | d1_t = torch.ones(self.channels * self.img_dim ** 2, device=vec.device) * sigma_t * eta 953 | d2_t = torch.ones(self.channels * self.img_dim ** 2, device=vec.device) * sigma_t * (1 - eta ** 2) ** 0.5 954 | 955 | singulars = self._singulars 956 | temp = torch.zeros(self.channels * self.img_dim ** 2, device=vec.device) 957 | temp[:singulars.size(0)] = singulars 958 | singulars = temp 959 | inverse_singulars = 1. / singulars 960 | inverse_singulars[singulars == 0] = 0. 961 | 962 | if a != 0 and sigma_y != 0: 963 | change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0 964 | d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta 965 | d2_t = d2_t * (-change_index + 1.0) 966 | 967 | change_index = (sigma_t > a * sigma_y * inverse_singulars) * 1.0 968 | d1_t = d1_t * (-change_index + 1.0) + torch.sqrt( 969 | change_index * (sigma_t ** 2 - a ** 2 * sigma_y ** 2 * inverse_singulars ** 2)) 970 | d2_t = d2_t * (-change_index + 1.0) 971 | 972 | change_index = (singulars == 0) * 1.0 973 | d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta 974 | d2_t = d2_t * (-change_index + 1.0) + change_index * sigma_t * (1 - eta ** 2) ** 0.5 975 | 976 | d1_t = d1_t.reshape(1, -1) 977 | d2_t = d2_t.reshape(1, -1) 978 | 979 | temp_vec = temp_vec * d1_t 980 | temp_eps = temp_eps * d2_t 981 | 982 | temp_out_vec = torch.zeros(vec.shape[0], self.channels, self.img_dim ** 2, device=vec.device) 983 | temp_out_vec[:, :, self.perm] = temp_vec.clone().reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1) 984 | temp_out_vec = self.fwht(temp_out_vec).reshape(vec.shape[0], -1) 985 | 986 | temp_out_eps = torch.zeros(vec.shape[0], self.channels, self.img_dim ** 2, device=vec.device) 987 | temp_out_eps[:, :, self.perm] = temp_eps.clone().reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1) 988 | temp_out_eps = self.fwht(temp_out_eps).reshape(vec.shape[0], -1) 989 | 990 | return temp_out_vec + temp_out_eps 991 | 992 | #Convolution-based super-resolution 993 | class SRConv(A_functions): 994 | def mat_by_img(self, M, v, dim): 995 | return torch.matmul(M, v.reshape(v.shape[0] * self.channels, dim, 996 | dim)).reshape(v.shape[0], self.channels, M.shape[0], dim) 997 | 998 | def img_by_mat(self, v, M, dim): 999 | return torch.matmul(v.reshape(v.shape[0] * self.channels, dim, 1000 | dim), M).reshape(v.shape[0], self.channels, dim, M.shape[1]) 1001 | 1002 | def __init__(self, kernel, channels, img_dim, device, stride = 1): 1003 | 
self.img_dim = img_dim 1004 | self.channels = channels 1005 | self.ratio = stride 1006 | small_dim = img_dim // stride 1007 | self.small_dim = small_dim 1008 | #build 1D conv matrix 1009 | A_small = torch.zeros(small_dim, img_dim, device=device) 1010 | for i in range(stride//2, img_dim + stride//2, stride): 1011 | for j in range(i - kernel.shape[0]//2, i + kernel.shape[0]//2): 1012 | j_effective = j 1013 | #reflective padding 1014 | if j_effective < 0: j_effective = -j_effective-1 1015 | if j_effective >= img_dim: j_effective = (img_dim - 1) - (j_effective - img_dim) 1016 | #matrix building 1017 | A_small[i // stride, j_effective] += kernel[j - i + kernel.shape[0]//2] 1018 | #get the svd of the 1D conv 1019 | self.U_small, self.singulars_small, self.V_small = torch.svd(A_small, some=False) 1020 | ZERO = 3e-2 1021 | self.singulars_small[self.singulars_small < ZERO] = 0 1022 | #calculate the singular values of the big matrix 1023 | self._singulars = torch.matmul(self.singulars_small.reshape(small_dim, 1), self.singulars_small.reshape(1, small_dim)).reshape(small_dim**2) 1024 | #permutation for matching the singular values. See P_1 in Appendix D.5. 1025 | self._perm = torch.Tensor([self.img_dim * i + j for i in range(self.small_dim) for j in range(self.small_dim)] + \ 1026 | [self.img_dim * i + j for i in range(self.small_dim) for j in range(self.small_dim, self.img_dim)]).to(device).long() 1027 | 1028 | def V(self, vec): 1029 | #invert the permutation 1030 | temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device) 1031 | temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)[:, :self._perm.shape[0], :] 1032 | temp[:, self._perm.shape[0]:, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)[:, self._perm.shape[0]:, :] 1033 | temp = temp.permute(0, 2, 1) 1034 | #multiply the image by V from the left and by V^T from the right 1035 | out = self.mat_by_img(self.V_small, temp, self.img_dim) 1036 | out = self.img_by_mat(out, self.V_small.transpose(0, 1), self.img_dim).reshape(vec.shape[0], -1) 1037 | return out 1038 | 1039 | def Vt(self, vec): 1040 | #multiply the image by V^T from the left and by V from the right 1041 | temp = self.mat_by_img(self.V_small.transpose(0, 1), vec.clone(), self.img_dim) 1042 | temp = self.img_by_mat(temp, self.V_small, self.img_dim).reshape(vec.shape[0], self.channels, -1) 1043 | #permute the entries 1044 | temp[:, :, :self._perm.shape[0]] = temp[:, :, self._perm] 1045 | temp = temp.permute(0, 2, 1) 1046 | return temp.reshape(vec.shape[0], -1) 1047 | 1048 | def U(self, vec): 1049 | #invert the permutation 1050 | temp = torch.zeros(vec.shape[0], self.small_dim**2, self.channels, device=vec.device) 1051 | temp[:, :self.small_dim**2, :] = vec.clone().reshape(vec.shape[0], self.small_dim**2, self.channels) 1052 | temp = temp.permute(0, 2, 1) 1053 | #multiply the image by U from the left and by U^T from the right 1054 | out = self.mat_by_img(self.U_small, temp, self.small_dim) 1055 | out = self.img_by_mat(out, self.U_small.transpose(0, 1), self.small_dim).reshape(vec.shape[0], -1) 1056 | return out 1057 | 1058 | def Ut(self, vec): 1059 | #multiply the image by U^T from the left and by U from the right 1060 | temp = self.mat_by_img(self.U_small.transpose(0, 1), vec.clone(), self.small_dim) 1061 | temp = self.img_by_mat(temp, self.U_small, self.small_dim).reshape(vec.shape[0], self.channels, -1) 1062 | #permute the entries 1063 | temp = temp.permute(0, 2, 1) 1064 | return 
temp.reshape(vec.shape[0], -1) 1065 | 1066 | def singulars(self): 1067 | return self._singulars.repeat_interleave(3).reshape(-1) 1068 | 1069 | def add_zeros(self, vec): 1070 | reshaped = vec.clone().reshape(vec.shape[0], -1) 1071 | temp = torch.zeros((vec.shape[0], reshaped.shape[1] * self.ratio**2), device=vec.device) 1072 | temp[:, :reshaped.shape[1]] = reshaped 1073 | return temp 1074 | 1075 | #Deblurring 1076 | class Deblurring(A_functions): 1077 | def mat_by_img(self, M, v): 1078 | return torch.matmul(M, v.reshape(v.shape[0] * self.channels, self.img_dim, 1079 | self.img_dim)).reshape(v.shape[0], self.channels, M.shape[0], self.img_dim) 1080 | 1081 | def img_by_mat(self, v, M): 1082 | return torch.matmul(v.reshape(v.shape[0] * self.channels, self.img_dim, 1083 | self.img_dim), M).reshape(v.shape[0], self.channels, self.img_dim, M.shape[1]) 1084 | 1085 | def __init__(self, kernel, channels, img_dim, device, ZERO = 3e-2): 1086 | self.img_dim = img_dim 1087 | self.channels = channels 1088 | #build 1D conv matrix 1089 | A_small = torch.zeros(img_dim, img_dim, device=device) 1090 | for i in range(img_dim): 1091 | for j in range(i - kernel.shape[0]//2, i + kernel.shape[0]//2): 1092 | if j < 0 or j >= img_dim: continue 1093 | A_small[i, j] = kernel[j - i + kernel.shape[0]//2] 1094 | #get the svd of the 1D conv 1095 | self.U_small, self.singulars_small, self.V_small = torch.svd(A_small, some=False) 1096 | #ZERO = 3e-2 1097 | self.singulars_small_orig = self.singulars_small.clone() 1098 | self.singulars_small[self.singulars_small < ZERO] = 0 1099 | #calculate the singular values of the big matrix 1100 | self._singulars_orig = torch.matmul(self.singulars_small_orig.reshape(img_dim, 1), self.singulars_small_orig.reshape(1, img_dim)).reshape(img_dim**2) 1101 | self._singulars = torch.matmul(self.singulars_small.reshape(img_dim, 1), self.singulars_small.reshape(1, img_dim)).reshape(img_dim**2) 1102 | #sort the big matrix singulars and save the permutation 1103 | self._singulars, self._perm = self._singulars.sort(descending=True) #, stable=True) 1104 | self._singulars_orig = self._singulars_orig[self._perm] 1105 | 1106 | def V(self, vec): 1107 | #invert the permutation 1108 | temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device) 1109 | temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels) 1110 | temp = temp.permute(0, 2, 1) 1111 | #multiply the image by V from the left and by V^T from the right 1112 | out = self.mat_by_img(self.V_small, temp) 1113 | out = self.img_by_mat(out, self.V_small.transpose(0, 1)).reshape(vec.shape[0], -1) 1114 | return out 1115 | 1116 | def Vt(self, vec): 1117 | #multiply the image by V^T from the left and by V from the right 1118 | temp = self.mat_by_img(self.V_small.transpose(0, 1), vec.clone()) 1119 | temp = self.img_by_mat(temp, self.V_small).reshape(vec.shape[0], self.channels, -1) 1120 | #permute the entries according to the singular values 1121 | temp = temp[:, :, self._perm].permute(0, 2, 1) 1122 | return temp.reshape(vec.shape[0], -1) 1123 | 1124 | def U(self, vec): 1125 | #invert the permutation 1126 | temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device) 1127 | temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels) 1128 | temp = temp.permute(0, 2, 1) 1129 | #multiply the image by U from the left and by U^T from the right 1130 | out = self.mat_by_img(self.U_small, temp) 1131 | out = self.img_by_mat(out, self.U_small.transpose(0, 
1)).reshape(vec.shape[0], -1) 1132 | return out 1133 | 1134 | def Ut(self, vec): 1135 | #multiply the image by U^T from the left and by U from the right 1136 | temp = self.mat_by_img(self.U_small.transpose(0, 1), vec.clone()) 1137 | temp = self.img_by_mat(temp, self.U_small).reshape(vec.shape[0], self.channels, -1) 1138 | #permute the entries according to the singular values 1139 | temp = temp[:, :, self._perm].permute(0, 2, 1) 1140 | return temp.reshape(vec.shape[0], -1) 1141 | 1142 | def singulars(self): 1143 | return self._singulars.repeat(1, 3).reshape(-1) 1144 | 1145 | def add_zeros(self, vec): 1146 | return vec.clone().reshape(vec.shape[0], -1) 1147 | 1148 | def A_pinv(self, vec): 1149 | temp = self.Ut(vec) 1150 | singulars = self._singulars.repeat(1, 3).reshape(-1) 1151 | 1152 | factors = 1. / singulars 1153 | factors[singulars == 0] = 0. 1154 | 1155 | temp[:, :singulars.shape[0]] = temp[:, :singulars.shape[0]] * factors 1156 | return self.V(self.add_zeros(temp)) 1157 | 1158 | def Lambda(self, vec, a, sigma_y, sigma_t, eta): 1159 | temp_vec = self.mat_by_img(self.V_small.transpose(0, 1), vec.clone()) 1160 | temp_vec = self.img_by_mat(temp_vec, self.V_small).reshape(vec.shape[0], self.channels, -1) 1161 | temp_vec = temp_vec[:, :, self._perm].permute(0, 2, 1) 1162 | 1163 | singulars = self._singulars_orig 1164 | lambda_t = torch.ones(self.img_dim ** 2, device=vec.device) 1165 | temp_singulars = torch.zeros(self.img_dim ** 2, device=vec.device) 1166 | temp_singulars[:singulars.size(0)] = singulars 1167 | singulars = temp_singulars 1168 | inverse_singulars = 1. / singulars 1169 | inverse_singulars[singulars == 0] = 0. 1170 | 1171 | if a != 0 and sigma_y != 0: 1172 | change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0 1173 | lambda_t = lambda_t * (-change_index + 1.0) + change_index * ( 1174 | singulars * sigma_t * (1 - eta ** 2) ** 0.5 / a / sigma_y) 1175 | 1176 | lambda_t = lambda_t.reshape(1, -1, 1) 1177 | temp_vec = temp_vec * lambda_t 1178 | 1179 | temp = torch.zeros(vec.shape[0], self.img_dim ** 2, self.channels, device=vec.device) 1180 | temp[:, self._perm, :] = temp_vec.clone().reshape(vec.shape[0], self.img_dim ** 2, self.channels) 1181 | temp = temp.permute(0, 2, 1) 1182 | out = self.mat_by_img(self.V_small, temp) 1183 | out = self.img_by_mat(out, self.V_small.transpose(0, 1)).reshape(vec.shape[0], -1) 1184 | return out 1185 | 1186 | def Lambda_noise(self, vec, a, sigma_y, sigma_t, eta, epsilon): 1187 | temp_vec = vec.clone().reshape(vec.shape[0], self.channels, -1) 1188 | temp_vec = temp_vec[:, :, self._perm].permute(0, 2, 1) 1189 | 1190 | temp_eps = epsilon.clone().reshape(vec.shape[0], self.channels, -1) 1191 | temp_eps = temp_eps[:, :, self._perm].permute(0, 2, 1) 1192 | 1193 | singulars = self._singulars_orig 1194 | d1_t = torch.ones(self.img_dim ** 2, device=vec.device) * sigma_t * eta 1195 | d2_t = torch.ones(self.img_dim ** 2, device=vec.device) * sigma_t * (1 - eta ** 2) ** 0.5 1196 | 1197 | temp_singulars = torch.zeros(self.img_dim ** 2, device=vec.device) 1198 | temp_singulars[:singulars.size(0)] = singulars 1199 | singulars = temp_singulars 1200 | inverse_singulars = 1. / singulars 1201 | inverse_singulars[singulars == 0] = 0. 
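        # note that the un-thresholded singular values (_singulars_orig) are used here and in
        # Lambda above, so the noise split is decided from the full blur spectrum even though
        # values below ZERO are treated as zero in A_pinv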
1202 | 1203 | if a != 0 and sigma_y != 0: 1204 | change_index = (sigma_t < a * sigma_y * inverse_singulars) * 1.0 1205 | d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta 1206 | d2_t = d2_t * (-change_index + 1.0) 1207 | 1208 | change_index = (sigma_t > a * sigma_y * inverse_singulars) * 1.0 1209 | d1_t = d1_t * (-change_index + 1.0) + torch.sqrt( 1210 | change_index * (sigma_t ** 2 - a ** 2 * sigma_y ** 2 * inverse_singulars ** 2)) 1211 | d2_t = d2_t * (-change_index + 1.0) 1212 | 1213 | change_index = (singulars == 0) * 1.0 1214 | d1_t = d1_t * (-change_index + 1.0) + change_index * sigma_t * eta 1215 | d2_t = d2_t * (-change_index + 1.0) + change_index * sigma_t * (1 - eta ** 2) ** 0.5 1216 | 1217 | d1_t = d1_t.reshape(1, -1, 1) 1218 | d2_t = d2_t.reshape(1, -1, 1) 1219 | 1220 | temp_vec = temp_vec * d1_t 1221 | temp_eps = temp_eps * d2_t 1222 | 1223 | temp_vec_new = torch.zeros(vec.shape[0], self.img_dim ** 2, self.channels, device=vec.device) 1224 | temp_vec_new[:, self._perm, :] = temp_vec 1225 | out_vec = self.mat_by_img(self.V_small, temp_vec_new.permute(0, 2, 1)) 1226 | out_vec = self.img_by_mat(out_vec, self.V_small.transpose(0, 1)).reshape(vec.shape[0], -1) 1227 | 1228 | temp_eps_new = torch.zeros(vec.shape[0], self.img_dim ** 2, self.channels, device=vec.device) 1229 | temp_eps_new[:, self._perm, :] = temp_eps 1230 | out_eps = self.mat_by_img(self.V_small, temp_eps_new.permute(0, 2, 1)) 1231 | out_eps = self.img_by_mat(out_eps, self.V_small.transpose(0, 1)).reshape(vec.shape[0], -1) 1232 | 1233 | return out_vec + out_eps 1234 | 1235 | #Anisotropic Deblurring 1236 | class Deblurring2D(A_functions): 1237 | def mat_by_img(self, M, v): 1238 | return torch.matmul(M, v.reshape(v.shape[0] * self.channels, self.img_dim, 1239 | self.img_dim)).reshape(v.shape[0], self.channels, M.shape[0], self.img_dim) 1240 | 1241 | def img_by_mat(self, v, M): 1242 | return torch.matmul(v.reshape(v.shape[0] * self.channels, self.img_dim, 1243 | self.img_dim), M).reshape(v.shape[0], self.channels, self.img_dim, M.shape[1]) 1244 | 1245 | def __init__(self, kernel1, kernel2, channels, img_dim, device): 1246 | self.img_dim = img_dim 1247 | self.channels = channels 1248 | A_small1 = torch.zeros(img_dim, img_dim, device=device) 1249 | for i in range(img_dim): 1250 | for j in range(i - kernel1.shape[0]//2, i + kernel1.shape[0]//2): 1251 | if j < 0 or j >= img_dim: continue 1252 | A_small1[i, j] = kernel1[j - i + kernel1.shape[0]//2] 1253 | A_small2 = torch.zeros(img_dim, img_dim, device=device) 1254 | for i in range(img_dim): 1255 | for j in range(i - kernel2.shape[0]//2, i + kernel2.shape[0]//2): 1256 | if j < 0 or j >= img_dim: continue 1257 | A_small2[i, j] = kernel2[j - i + kernel2.shape[0]//2] 1258 | self.U_small1, self.singulars_small1, self.V_small1 = torch.svd(A_small1, some=False) 1259 | self.U_small2, self.singulars_small2, self.V_small2 = torch.svd(A_small2, some=False) 1260 | ZERO = 3e-2 1261 | self.singulars_small1[self.singulars_small1 < ZERO] = 0 1262 | self.singulars_small2[self.singulars_small2 < ZERO] = 0 1263 | #calculate the singular values of the big matrix 1264 | self._singulars = torch.matmul(self.singulars_small1.reshape(img_dim, 1), self.singulars_small2.reshape(1, img_dim)).reshape(img_dim**2) 1265 | #sort the big matrix singulars and save the permutation 1266 | self._singulars, self._perm = self._singulars.sort(descending=True) #, stable=True) 1267 | 1268 | def V(self, vec): 1269 | #invert the permutation 1270 | temp = torch.zeros(vec.shape[0], self.img_dim**2, 
self.channels, device=vec.device) 1271 | temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels) 1272 | temp = temp.permute(0, 2, 1) 1273 | #multiply the image by V from the left and by V^T from the right 1274 | out = self.mat_by_img(self.V_small1, temp) 1275 | out = self.img_by_mat(out, self.V_small2.transpose(0, 1)).reshape(vec.shape[0], -1) 1276 | return out 1277 | 1278 | def Vt(self, vec): 1279 | #multiply the image by V^T from the left and by V from the right 1280 | temp = self.mat_by_img(self.V_small1.transpose(0, 1), vec.clone()) 1281 | temp = self.img_by_mat(temp, self.V_small2).reshape(vec.shape[0], self.channels, -1) 1282 | #permute the entries according to the singular values 1283 | temp = temp[:, :, self._perm].permute(0, 2, 1) 1284 | return temp.reshape(vec.shape[0], -1) 1285 | 1286 | def U(self, vec): 1287 | #invert the permutation 1288 | temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device) 1289 | temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels) 1290 | temp = temp.permute(0, 2, 1) 1291 | #multiply the image by U from the left and by U^T from the right 1292 | out = self.mat_by_img(self.U_small1, temp) 1293 | out = self.img_by_mat(out, self.U_small2.transpose(0, 1)).reshape(vec.shape[0], -1) 1294 | return out 1295 | 1296 | def Ut(self, vec): 1297 | #multiply the image by U^T from the left and by U from the right 1298 | temp = self.mat_by_img(self.U_small1.transpose(0, 1), vec.clone()) 1299 | temp = self.img_by_mat(temp, self.U_small2).reshape(vec.shape[0], self.channels, -1) 1300 | #permute the entries according to the singular values 1301 | temp = temp[:, :, self._perm].permute(0, 2, 1) 1302 | return temp.reshape(vec.shape[0], -1) 1303 | 1304 | def singulars(self): 1305 | return self._singulars.repeat(1, 3).reshape(-1) 1306 | 1307 | def add_zeros(self, vec): 1308 | return vec.clone().reshape(vec.shape[0], -1) 1309 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.4.0 2 | certifi==2025.1.31 3 | charset-normalizer==3.4.1 4 | diffusers==0.30.1 5 | filelock==3.13.1 6 | fsspec==2024.6.1 7 | huggingface-hub==0.29.3 8 | idna==3.10 9 | importlib_metadata==8.6.1 10 | Jinja2==3.1.4 11 | MarkupSafe==2.1.5 12 | mpmath==1.3.0 13 | munch==4.0.0 14 | networkx==3.3 15 | numpy==2.1.2 16 | nvidia-cublas-cu11==11.11.3.6 17 | nvidia-cuda-cupti-cu11==11.8.87 18 | nvidia-cuda-nvrtc-cu11==11.8.89 19 | nvidia-cuda-runtime-cu11==11.8.89 20 | nvidia-cudnn-cu11==9.1.0.70 21 | nvidia-cufft-cu11==10.9.0.58 22 | nvidia-curand-cu11==10.3.0.86 23 | nvidia-cusolver-cu11==11.4.1.48 24 | nvidia-cusparse-cu11==11.7.5.86 25 | nvidia-nccl-cu11==2.20.5 26 | nvidia-nvtx-cu11==11.8.86 27 | packaging==24.2 28 | pillow==11.1.0 29 | protobuf==6.30.0 30 | psutil==7.0.0 31 | PyYAML==6.0.2 32 | regex==2024.11.6 33 | requests==2.32.3 34 | safetensors==0.5.3 35 | scipy==1.15.2 36 | sentencepiece==0.2.0 37 | sympy==1.13.1 38 | tokenizers==0.21.0 39 | torch==2.4.1 40 | torchvision==0.19.1 41 | tqdm==4.67.1 42 | transformers==4.49.0 43 | triton==3.0.0 44 | typing_extensions==4.12.2 45 | urllib3==2.3.0 46 | zipp==3.21.0 47 | -------------------------------------------------------------------------------- /samples/afhq_example.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FlowDPS-Inverse/FlowDPS/3e73e443c987ba378b091515ba1e9b0963860e81/samples/afhq_example.jpg -------------------------------------------------------------------------------- /samples/div2k_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FlowDPS-Inverse/FlowDPS/3e73e443c987ba378b091515ba1e9b0963860e81/samples/div2k_example.png -------------------------------------------------------------------------------- /samples/ffhq_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FlowDPS-Inverse/FlowDPS/3e73e443c987ba378b091515ba1e9b0963860e81/samples/ffhq_example.png -------------------------------------------------------------------------------- /sd3_sampler.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Optional 2 | import math 3 | import torch 4 | 5 | from tqdm import tqdm 6 | from diffusers import StableDiffusion3Pipeline 7 | 8 | 9 | # ======================================================================= 10 | # Factory 11 | # ======================================================================= 12 | 13 | __SOLVER__ = {} 14 | 15 | def register_solver(name:str): 16 | def wrapper(cls): 17 | if __SOLVER__.get(name, None) is not None: 18 | raise ValueError(f"Solver {name} already registered.") 19 | __SOLVER__[name] = cls 20 | return cls 21 | return wrapper 22 | 23 | def get_solver(name:str, **kwargs): 24 | if name not in __SOLVER__: 25 | raise ValueError(f"Solver {name} does not exist.") 26 | return __SOLVER__[name](**kwargs) 27 | 28 | # ======================================================================= 29 | 30 | 31 | class StableDiffusion3Base(): 32 | def __init__(self, model_key:str='stabilityai/stable-diffusion-3-medium-diffusers', device='cuda', dtype=torch.float16): 33 | self.device = device 34 | self.dtype = dtype 35 | 36 | pipe = StableDiffusion3Pipeline.from_pretrained(model_key, torch_dtype=self.dtype) 37 | 38 | self.scheduler = pipe.scheduler 39 | 40 | self.tokenizer_1 = pipe.tokenizer 41 | self.tokenizer_2 = pipe.tokenizer_2 42 | self.tokenizer_3 = pipe.tokenizer_3 43 | self.text_enc_1 = pipe.text_encoder 44 | self.text_enc_2 = pipe.text_encoder_2 45 | self.text_enc_3 = pipe.text_encoder_3 46 | 47 | self.vae=pipe.vae 48 | self.transformer = pipe.transformer 49 | self.transformer.eval() 50 | self.transformer.requires_grad_(False) 51 | 52 | self.vae_scale_factor = ( 53 | 2 ** (len(self.vae.config.block_out_channels)-1) if hasattr(self, "vae") and self.vae is not None else 8 54 | ) 55 | 56 | del pipe 57 | 58 | def encode_prompt(self, prompt: List[str], batch_size:int=1) -> List[torch.Tensor]: 59 | ''' 60 | We assume that 61 | 1. number of tokens < max_length 62 | 2. 
one prompt for one image 63 | ''' 64 | # CLIP encode (used for modulation of adaLN-zero) 65 | # now, we have two CLIPs 66 | text_clip1_ids = self.tokenizer_1(prompt, 67 | padding="max_length", 68 | max_length=77, 69 | truncation=True, 70 | return_tensors='pt').input_ids 71 | text_clip1_emb = self.text_enc_1(text_clip1_ids.to(self.text_enc_1.device), output_hidden_states=True) 72 | pool_clip1_emb = text_clip1_emb[0].to(dtype=self.dtype, device=self.text_enc_1.device) 73 | text_clip1_emb = text_clip1_emb.hidden_states[-2].to(dtype=self.dtype, device=self.text_enc_1.device) 74 | 75 | text_clip2_ids = self.tokenizer_2(prompt, 76 | padding="max_length", 77 | max_length=77, 78 | truncation=True, 79 | return_tensors='pt').input_ids 80 | text_clip2_emb = self.text_enc_2(text_clip2_ids.to(self.text_enc_2.device), output_hidden_states=True) 81 | pool_clip2_emb = text_clip2_emb[0].to(dtype=self.dtype, device=self.text_enc_2.device) 82 | text_clip2_emb = text_clip2_emb.hidden_states[-2].to(dtype=self.dtype, device=self.text_enc_2.device) 83 | 84 | # T5 encode (used for text condition) 85 | text_t5_ids = self.tokenizer_3(prompt, 86 | padding="max_length", 87 | max_length=77, 88 | truncation=True, 89 | add_special_tokens=True, 90 | return_tensors='pt').input_ids 91 | text_t5_emb = self.text_enc_3(text_t5_ids.to(self.text_enc_3.device))[0] 92 | text_t5_emb = text_t5_emb.to(dtype=self.dtype, device=self.text_enc_3.device) 93 | 94 | 95 | # Merge 96 | clip_prompt_emb = torch.cat([text_clip1_emb, text_clip2_emb], dim=-1) 97 | clip_prompt_emb = torch.nn.functional.pad( 98 | clip_prompt_emb, (0, text_t5_emb.shape[-1] - clip_prompt_emb.shape[-1]) 99 | ) 100 | prompt_emb = torch.cat([clip_prompt_emb, text_t5_emb], dim=-2) 101 | pooled_prompt_emb = torch.cat([pool_clip1_emb, pool_clip2_emb], dim=-1) 102 | 103 | return prompt_emb, pooled_prompt_emb 104 | 105 | 106 | def initialize_latent(self, img_size:Tuple[int], batch_size:int=1, **kwargs): 107 | H, W = img_size 108 | lH, lW = H//self.vae_scale_factor, W//self.vae_scale_factor 109 | lC = self.transformer.config.in_channels 110 | latent_shape = (batch_size, lC, lH, lW) 111 | 112 | z = torch.randn(latent_shape, device=self.device, dtype=self.dtype) 113 | 114 | return z 115 | 116 | def encode(self, image: torch.Tensor) -> torch.Tensor: 117 | z = self.vae.encode(image).latent_dist.sample() 118 | z = (z-self.vae.config.shift_factor) * self.vae.config.scaling_factor 119 | return z 120 | 121 | def decode(self, z: torch.Tensor) -> torch.Tensor: 122 | z = (z/self.vae.config.scaling_factor) + self.vae.config.shift_factor 123 | return self.vae.decode(z, return_dict=False)[0] 124 | 125 | def predict_vector(self, z, t, prompt_emb, pooled_emb): 126 | v = self.transformer(hidden_states=z, 127 | timestep=t, 128 | pooled_projections=pooled_emb, 129 | encoder_hidden_states=prompt_emb, 130 | return_dict=False)[0] 131 | return v 132 | 133 | class SD3Euler(StableDiffusion3Base): 134 | def __init__(self, model_key:str='stabilityai/stable-diffusion-3-medium-diffusers', device='cuda'): 135 | super().__init__(model_key=model_key, device=device) 136 | 137 | def inversion(self, src_img, prompts: List[str], NFE:int, cfg_scale: float=1.0, batch_size: int=1, 138 | prompt_emb:Optional[List[torch.Tensor]]=None, 139 | null_emb:Optional[List[torch.Tensor]]=None): 140 | 141 | # encode text prompts 142 | with torch.no_grad(): 143 | if prompt_emb is None: 144 | prompt_emb, pooled_emb = self.encode_prompt(prompts, batch_size) 145 | else: 146 | prompt_emb, pooled_emb = prompt_emb[0], prompt_emb[1] 
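        # The rest of inversion() mirrors sample() but runs the Euler integrator in the
        # opposite direction: timesteps are reversed so sigma grows from ~0 to 1, mapping
        # the VAE-encoded source image to an approximate noise latent for that image.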
147 | 148 | prompt_emb = prompt_emb.to(self.transformer.device) 149 | pooled_emb = pooled_emb.to(self.transformer.device) 150 | 151 | if null_emb is None: 152 | null_prompt_emb, null_pooled_emb = self.encode_prompt([""]) 153 | else: 154 | null_prompt_emb, null_pooled_emb = null_emb[0], null_emb[1] 155 | 156 | null_prompt_emb = null_prompt_emb.to(self.transformer.device) 157 | null_pooled_emb = null_pooled_emb.to(self.transformer.device) 158 | 159 | # initialize latent 160 | src_img = src_img.to(device=self.vae.device, dtype=self.dtype) 161 | with torch.no_grad(): 162 | z = self.encode(src_img).to(self.transformer.device) 163 | 164 | # timesteps (default option. You can make your custom here.) 165 | self.scheduler.set_timesteps(NFE, device=self.transformer.device) 166 | timesteps = self.scheduler.timesteps 167 | timesteps = torch.cat([timesteps, torch.zeros(1, device=self.transformer.device)]) 168 | timesteps = reversed(timesteps) 169 | sigmas = timesteps / self.scheduler.config.num_train_timesteps 170 | 171 | # Solve ODE 172 | pbar = tqdm(timesteps[:-1], total=NFE, desc='SD3 Euler Inversion') 173 | for i, t in enumerate(pbar): 174 | timestep = t.expand(z.shape[0]).to(self.transformer.device) 175 | with torch.no_grad(): 176 | pred_v = self.predict_vector(z, timestep, prompt_emb, pooled_emb) 177 | if cfg_scale != 1.0: 178 | pred_null_v = self.predict_vector(z, timestep, null_prompt_emb, null_pooled_emb) 179 | else: 180 | pred_null_v = 0.0 181 | 182 | sigma = sigmas[i] 183 | sigma_next = sigmas[i+1] 184 | 185 | z = z + (sigma_next - sigma) * (pred_null_v + cfg_scale * (pred_v - pred_null_v)) 186 | 187 | return z 188 | 189 | def sample(self, prompts: List[str], NFE:int, img_shape: Optional[Tuple[int]]=None, 190 | cfg_scale: float=1.0, batch_size: int = 1, 191 | latent:Optional[List[torch.Tensor]]=None, 192 | prompt_emb:Optional[List[torch.Tensor]]=None, 193 | null_emb:Optional[List[torch.Tensor]]=None): 194 | 195 | imgH, imgW = img_shape if img_shape is not None else (1024, 1024) 196 | 197 | # encode text prompts 198 | with torch.no_grad(): 199 | if prompt_emb is None: 200 | prompt_emb, pooled_emb = self.encode_prompt(prompts, batch_size) 201 | else: 202 | prompt_emb, pooled_emb = prompt_emb[0], prompt_emb[1] 203 | 204 | prompt_emb.to(self.transformer.device) 205 | pooled_emb.to(self.transformer.device) 206 | 207 | if null_emb is None: 208 | null_prompt_emb, null_pooled_emb = self.encode_prompt([""], batch_size) 209 | else: 210 | null_prompt_emb, null_pooled_emb = null_emb[0], null_emb[1] 211 | 212 | null_prompt_emb.to(self.transformer.device) 213 | null_pooled_emb.to(self.transformer.device) 214 | 215 | # initialize latent 216 | if latent is None: 217 | z = self.initialize_latent((imgH, imgW), batch_size) 218 | else: 219 | z = latent 220 | 221 | # timesteps (default option. You can make your custom here.) 
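        # The loop below integrates the rectified-flow ODE with plain Euler steps:
        # sigmas = t / num_train_timesteps shrink toward 0, classifier-free guidance
        # combines the predictions as v = v_null + cfg_scale * (v_cond - v_null), and
        # the latent is updated as z <- z + (sigma_next - sigma) * v.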
222 | self.scheduler.set_timesteps(NFE, device=self.device) 223 | timesteps = self.scheduler.timesteps 224 | sigmas = timesteps / self.scheduler.config.num_train_timesteps 225 | 226 | # Solve ODE 227 | pbar = tqdm(timesteps, total=NFE, desc='SD3 Euler') 228 | for i, t in enumerate(pbar): 229 | timestep = t.expand(z.shape[0]).to(self.device) 230 | pred_v = self.predict_vector(z, timestep, prompt_emb, pooled_emb) 231 | if cfg_scale != 1.0: 232 | pred_null_v = self.predict_vector(z, timestep, null_prompt_emb, null_pooled_emb) 233 | else: 234 | pred_null_v = 0.0 235 | 236 | sigma = sigmas[i] 237 | sigma_next = sigmas[i+1] if i+1 < NFE else 0.0 238 | 239 | z = z + (sigma_next - sigma) * (pred_null_v + cfg_scale * (pred_v - pred_null_v)) 240 | 241 | # decode 242 | with torch.no_grad(): 243 | img = self.decode(z) 244 | return img 245 | 246 | @register_solver("flowdps") 247 | class SD3FlowDPS(SD3Euler): 248 | def data_consistency(self, z0t, operator, measurement, task, stepsize:int=30.0): 249 | z0t = z0t.requires_grad_(True) 250 | num_iters = 3 251 | for _ in range(num_iters): 252 | x0t = self.decode(z0t).float() 253 | if "sr" in task: 254 | loss = torch.linalg.norm((operator.A_pinv(measurement) - operator.A_pinv(operator.A(x0t))).view(1, -1)) 255 | else: 256 | loss = torch.linalg.norm((operator.At(measurement) - operator.At(operator.A(x0t))).view(1, -1)) 257 | grad = torch.autograd.grad(loss, z0t)[0].half() 258 | z0t = z0t - stepsize*grad 259 | 260 | return z0t.detach() 261 | 262 | def sample(self, measurement, operator, task, 263 | prompts: List[str], NFE:int, 264 | img_shape: Optional[Tuple[int]]=None, 265 | cfg_scale: float=1.0, batch_size: int = 1, 266 | step_size: float=30.0, 267 | latent:Optional[List[torch.Tensor]]=None, 268 | prompt_emb:Optional[List[torch.Tensor]]=None, 269 | null_emb:Optional[List[torch.Tensor]]=None): 270 | 271 | imgH, imgW = img_shape if img_shape is not None else (1024, 1024) 272 | 273 | # encode text prompts 274 | with torch.no_grad(): 275 | if prompt_emb is None: 276 | prompt_emb, pooled_emb = self.encode_prompt(prompts, batch_size) 277 | else: 278 | prompt_emb, pooled_emb = prompt_emb[0], prompt_emb[1] 279 | 280 | prompt_emb.to(self.transformer.device) 281 | pooled_emb.to(self.transformer.device) 282 | 283 | if null_emb is None: 284 | null_prompt_emb, null_pooled_emb = self.encode_prompt([""], batch_size) 285 | else: 286 | null_prompt_emb, null_pooled_emb = null_emb[0], null_emb[1] 287 | 288 | null_prompt_emb.to(self.transformer.device) 289 | null_pooled_emb.to(self.transformer.device) 290 | 291 | # initialize latent 292 | if latent is None: 293 | z = self.initialize_latent((imgH, imgW), batch_size) 294 | else: 295 | z = latent 296 | 297 | # timesteps (default option. You can make your custom here.) 
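        # FlowDPS step structure (see the loop below): the guided velocity gives a clean
        # estimate z0t = z - sigma * v and a noise-side estimate z1t = z + (1 - sigma) * v;
        # data_consistency refines z0t with a few decoder-space gradient steps on the
        # measurement residual, blended back as (1 - sigma) * z0t + sigma * z0y; the latent
        # is then stochastically renoised, z <- (1 - sigma_next) * z0y + sigma_next * noise,
        # with noise = sqrt(sigma_next) * z1t + sqrt(1 - sigma_next) * eps.
        # The scheduler shift of 4.0 biases the sigma schedule toward higher noise levels.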
298 | self.scheduler.config.shift = 4.0 299 | self.scheduler.set_timesteps(NFE, device=self.device) 300 | timesteps = self.scheduler.timesteps 301 | sigmas = timesteps / self.scheduler.config.num_train_timesteps 302 | 303 | # Solve ODE 304 | pbar = tqdm(timesteps, total=NFE, desc='SD3-FlowDPS') 305 | for i, t in enumerate(pbar): 306 | timestep = t.expand(z.shape[0]).to(self.device) 307 | 308 | with torch.no_grad(): 309 | pred_v = self.predict_vector(z, timestep, prompt_emb, pooled_emb) 310 | if cfg_scale != 1.0: 311 | pred_null_v = self.predict_vector(z, timestep, null_prompt_emb, null_pooled_emb) 312 | else: 313 | pred_null_v = 0.0 314 | 315 | sigma = sigmas[i] 316 | sigma_next = sigmas[i+1] if i+1 < NFE else 0.0 317 | 318 | # denoising 319 | z0t = z - sigma * (pred_null_v + cfg_scale * (pred_v-pred_null_v)) 320 | z1t = z + (1-sigma) * (pred_null_v + cfg_scale * (pred_v-pred_null_v)) 321 | delta = sigma - sigma_next 322 | 323 | if i < NFE: 324 | z0y = self.data_consistency(z0t, operator, measurement, task=task, stepsize=step_size) 325 | z0y = (1-sigma) * z0t + sigma * z0y 326 | 327 | # renoising 328 | noise = math.sqrt(sigma_next) * z1t + math.sqrt(1-sigma_next) * torch.randn_like(z1t) 329 | z = z0y + (sigma-delta) * (noise - z0y) 330 | 331 | # decode 332 | with torch.no_grad(): 333 | img = self.decode(z) 334 | return img 335 | 336 | @register_solver("flowchef") 337 | class SD3FlowChef(SD3Euler): 338 | def data_consistency(self, z0t, operator, measurement, task): 339 | z0t = z0t.requires_grad_(True) 340 | x0t = self.decode(z0t).float() 341 | if "sr" in task: 342 | loss = torch.linalg.norm((operator.A_pinv(measurement) - operator.A_pinv(operator.A(x0t))).view(1, -1)) 343 | else: 344 | loss = torch.linalg.norm((operator.At(measurement) - operator.At(operator.A(x0t))).view(1, -1)) 345 | grad = torch.autograd.grad(loss, z0t)[0].half() 346 | return grad.detach() 347 | 348 | 349 | def sample(self, measurement, operator, task, 350 | prompts: List[str], NFE:int, 351 | img_shape: Optional[Tuple[int]]=None, 352 | cfg_scale: float=1.0, batch_size: int = 1, 353 | step_size: float=50.0, 354 | latent:Optional[List[torch.Tensor]]=None, 355 | prompt_emb:Optional[List[torch.Tensor]]=None, 356 | null_emb:Optional[List[torch.Tensor]]=None): 357 | 358 | imgH, imgW = img_shape if img_shape is not None else (1024, 1024) 359 | 360 | # encode text prompts 361 | with torch.no_grad(): 362 | if prompt_emb is None: 363 | prompt_emb, pooled_emb = self.encode_prompt(prompts, batch_size) 364 | else: 365 | prompt_emb, pooled_emb = prompt_emb[0], prompt_emb[1] 366 | 367 | prompt_emb.to(self.transformer.device) 368 | pooled_emb.to(self.transformer.device) 369 | 370 | if null_emb is None: 371 | null_prompt_emb, null_pooled_emb = self.encode_prompt([""], batch_size) 372 | else: 373 | null_prompt_emb, null_pooled_emb = null_emb[0], null_emb[1] 374 | 375 | null_prompt_emb.to(self.transformer.device) 376 | null_pooled_emb.to(self.transformer.device) 377 | 378 | # initialize latent 379 | if latent is None: 380 | z = self.initialize_latent((imgH, imgW), batch_size) 381 | else: 382 | z = latent 383 | 384 | # timesteps (default option. You can make your custom here.) 
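        # FlowChef (below) replaces FlowDPS's inner refinement loop with a single gradient
        # of the measurement residual w.r.t. the clean estimate z0t: the update
        # z <- z0t + (sigma - delta) * (z1t - z0t) - step_size * grad, with sigma - delta =
        # sigma_next, is exactly the deterministic Euler step nudged by that gradient,
        # without stochastic renoising.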
385 | self.scheduler.config.shift = 4.0 386 | self.scheduler.set_timesteps(NFE, device=self.device) 387 | timesteps = self.scheduler.timesteps 388 | sigmas = timesteps / self.scheduler.config.num_train_timesteps 389 | 390 | # Solve ODE 391 | pbar = tqdm(timesteps, total=NFE, desc='SD3-FlowChef') 392 | for i, t in enumerate(pbar): 393 | timestep = t.expand(z.shape[0]).to(self.device) 394 | 395 | with torch.no_grad(): 396 | pred_v = self.predict_vector(z, timestep, prompt_emb, pooled_emb) 397 | if cfg_scale != 1.0: 398 | pred_null_v = self.predict_vector(z, timestep, null_prompt_emb, null_pooled_emb) 399 | else: 400 | pred_null_v = 0.0 401 | 402 | sigma = sigmas[i] 403 | sigma_next = sigmas[i+1] if i+1 < NFE else 0.0 404 | 405 | # denoising 406 | z0t = z - sigma * (pred_null_v + cfg_scale * (pred_v-pred_null_v)) 407 | z1t = z + (1-sigma) * (pred_null_v + cfg_scale * (pred_v-pred_null_v)) 408 | delta = sigma - sigma_next 409 | 410 | if i < NFE: 411 | grad = self.data_consistency(z0t, operator, measurement, task=task) 412 | 413 | # renoising 414 | z = z0t + (sigma-delta) * (z1t - z0t) - step_size*grad 415 | 416 | # decode 417 | with torch.no_grad(): 418 | img = self.decode(z) 419 | return img 420 | 421 | @register_solver('psld') 422 | class SD3PSLD(SD3Euler): 423 | def sample(self, measurement, operator, task, 424 | prompts: List[str], NFE:int, 425 | img_shape: Optional[Tuple[int]]=None, 426 | cfg_scale: float=1.0, batch_size: int = 1, 427 | step_size: float=50.0, 428 | latent:Optional[List[torch.Tensor]]=None, 429 | prompt_emb:Optional[List[torch.Tensor]]=None, 430 | null_emb:Optional[List[torch.Tensor]]=None): 431 | 432 | imgH, imgW = img_shape if img_shape is not None else (1024, 1024) 433 | 434 | # encode text prompts 435 | with torch.no_grad(): 436 | if prompt_emb is None: 437 | prompt_emb, pooled_emb = self.encode_prompt(prompts, batch_size) 438 | else: 439 | prompt_emb, pooled_emb = prompt_emb[0], prompt_emb[1] 440 | 441 | prompt_emb.to(self.transformer.device) 442 | pooled_emb.to(self.transformer.device) 443 | 444 | if null_emb is None: 445 | null_prompt_emb, null_pooled_emb = self.encode_prompt([""], batch_size) 446 | else: 447 | null_prompt_emb, null_pooled_emb = null_emb[0], null_emb[1] 448 | 449 | null_prompt_emb.to(self.transformer.device) 450 | null_pooled_emb.to(self.transformer.device) 451 | 452 | # initialize latent 453 | if latent is None: 454 | z = self.initialize_latent((imgH, imgW), batch_size) 455 | else: 456 | z = latent 457 | 458 | # timesteps (default option. You can make your custom here.) 
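        # PSLD (below) keeps z differentiable through the velocity prediction: it decodes
        # the clean estimate z0t, forms the data residual ||A(x0) - y|| plus a "gluing"
        # residual ||z0t - encode(A^+ y + (x0 - A^+ A x0))|| weighted by 0.1, and takes the
        # Euler step minus step_size times the gradient of this combined residual w.r.t. z.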
459 | self.scheduler.config.shift = 4.0 460 | self.scheduler.set_timesteps(NFE, device=self.device) 461 | timesteps = self.scheduler.timesteps 462 | sigmas = timesteps / self.scheduler.config.num_train_timesteps 463 | 464 | # Solve ODE 465 | pbar = tqdm(timesteps, total=NFE, desc='SD3-PSLD') 466 | for i, t in enumerate(pbar): 467 | timestep = t.expand(z.shape[0]).to(self.device) 468 | 469 | z = z.requires_grad_(True) 470 | pred_v = self.predict_vector(z, timestep, prompt_emb, pooled_emb) 471 | if cfg_scale != 1.0: 472 | pred_null_v = self.predict_vector(z, timestep, null_prompt_emb, null_pooled_emb) 473 | else: 474 | pred_null_v = 0.0 475 | pred_v = pred_null_v + cfg_scale * (pred_v - pred_null_v) 476 | 477 | sigma = sigmas[i] 478 | sigma_next = sigmas[i+1] if i+1 < NFE else 0.0 479 | 480 | # denoising 481 | z0t = z - sigma * pred_v 482 | z1t = z + (1-sigma) * pred_v 483 | delta = sigma - sigma_next 484 | 485 | # DC & goodness of z0t 486 | x_pred = self.decode(z0t).float() 487 | y_pred = operator.A(x_pred) 488 | y_residue = torch.linalg.norm((y_pred-measurement).view(1, -1)) 489 | 490 | if "sr" in task: 491 | ortho_proj = x_pred.reshape(1, -1) - operator.A_pinv(y_pred).reshape(1, -1) 492 | parallel_proj = operator.A_pinv(measurement).reshape(1, -1) 493 | else: 494 | ortho_proj = x_pred.reshape(1, -1) - operator.At(y_pred).reshape(1, -1) 495 | parallel_proj = operator.At(measurement).reshape(1, -1) 496 | proj = parallel_proj + ortho_proj 497 | 498 | recon_z = self.encode(proj.reshape(1, 3, imgH, imgW).half()) 499 | z0_residue = torch.linalg.norm((z0t - recon_z).view(1, -1)) 500 | 501 | residue = 1.0 * y_residue + 0.1 * z0_residue 502 | grad = torch.autograd.grad(residue, z)[0] 503 | 504 | # renoising 505 | z = z0t + (sigma-delta) * (z1t - z0t) - step_size*grad 506 | z.detach() 507 | 508 | # decode 509 | with torch.no_grad(): 510 | img = self.decode(z) 511 | return img 512 | 513 | -------------------------------------------------------------------------------- /solve.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from typing import List 4 | 5 | from munch import munchify 6 | from PIL import Image 7 | from tqdm import tqdm 8 | import torch 9 | from torchvision.utils import save_image 10 | from torchvision import transforms 11 | 12 | from util import set_seed, get_img_list, process_text 13 | from sd3_sampler import get_solver 14 | from functions.degradation import get_degradation 15 | 16 | @torch.no_grad 17 | def precompute(args, prompts:List[str], solver) -> List[torch.Tensor]: 18 | prompt_emb_set = [] 19 | pooled_emb_set = [] 20 | 21 | num_samples = args.num_samples if args.num_samples > 0 else len(prompts) 22 | for prompt in prompts[:num_samples]: 23 | prompt_emb, pooled_emb = solver.encode_prompt(prompt, batch_size=1) 24 | prompt_emb_set.append(prompt_emb) 25 | pooled_emb_set.append(pooled_emb) 26 | 27 | return prompt_emb_set, pooled_emb_set 28 | 29 | def run(args): 30 | # load solver 31 | solver = get_solver(args.method) 32 | 33 | # load text prompts 34 | prompts = process_text(prompt=args.prompt, prompt_file=args.prompt_file) 35 | solver.text_enc_1.to('cuda') 36 | solver.text_enc_2.to('cuda') 37 | solver.text_enc_3.to('cuda') 38 | 39 | if args.efficient_memory: 40 | # precompute text embedding and remove encoders from GPU 41 | # This will allow us 1) fast inference 2) with lower memory requirement (<24GB) 42 | with torch.no_grad(): 43 | prompt_emb_set, pooled_emb_set = precompute(args, prompts, 
solver) 44 | null_emb, null_pooled_emb = solver.encode_prompt([''], batch_size=1) 45 | 46 | del solver.text_enc_1 47 | del solver.text_enc_2 48 | del solver.text_enc_3 49 | torch.cuda.empty_cache() 50 | 51 | prompt_embs = [[x, y] for x, y in zip(prompt_emb_set, pooled_emb_set)] 52 | null_embs = [null_emb, null_pooled_emb] 53 | else: 54 | prompt_embs = [[None, None]] * len(prompts) 55 | null_embs = [None, None] 56 | 57 | print("Prompts are processed.") 58 | 59 | solver.vae.to('cuda') 60 | solver.transformer.to('cuda') 61 | 62 | # problem setup 63 | deg_config = munchify({ 64 | 'channels': 3, 65 | 'image_size': args.img_size, 66 | 'deg_scale': args.deg_scale 67 | }) 68 | operator = get_degradation(args.task, deg_config, solver.transformer.device) 69 | 70 | # solve problem 71 | tf = transforms.Compose([ 72 | transforms.Resize(args.img_size), 73 | transforms.CenterCrop(args.img_size), 74 | transforms.ToTensor() 75 | ]) 76 | 77 | pbar = tqdm(get_img_list(args.img_path), desc="Solving") 78 | for i, path in enumerate(pbar): 79 | img = tf(Image.open(path).convert('RGB')) 80 | img = img.unsqueeze(0).to(solver.vae.device) 81 | img = img * 2 - 1 82 | 83 | y = operator.A(img) 84 | y = y + 0.03 * torch.randn_like(y) 85 | 86 | out = solver.sample(measurement=y, 87 | operator=operator, 88 | prompts=prompts[i] if len(prompts)>1 else prompts[0], 89 | NFE=args.NFE, 90 | img_shape=(args.img_size, args.img_size), 91 | cfg_scale=args.cfg_scale, 92 | step_size=args.step_size, 93 | task=args.task, 94 | prompt_emb=prompt_embs[i] if len(prompt_embs)>1 else prompt_embs[0], 95 | null_emb=null_embs 96 | ) 97 | # save results 98 | save_image(operator.At(y).reshape(img.shape), 99 | args.workdir.joinpath(f'input/{str(i).zfill(4)}.png'), 100 | normalize=True) 101 | save_image(out, 102 | args.workdir.joinpath(f'recon/{str(i).zfill(4)}.png'), 103 | normalize=True) 104 | save_image(img, 105 | args.workdir.joinpath(f'label/{str(i).zfill(4)}.png'), 106 | normalize=True) 107 | 108 | if (i+1) == args.num_samples: 109 | break 110 | 111 | 112 | if __name__ == "__main__": 113 | parser = argparse.ArgumentParser() 114 | # sampling params 115 | parser.add_argument('--seed', type=int, default=0) 116 | parser.add_argument('--NFE', type=int, default=28) 117 | parser.add_argument('--cfg_scale', type=float, default=2.0) 118 | parser.add_argument('--img_size', type=int, default=768) 119 | 120 | # workdir params 121 | parser.add_argument('--workdir', type=Path, default='workdir') 122 | 123 | # data params 124 | parser.add_argument('--img_path', type=Path) 125 | parser.add_argument('--prompt', type=str, default=None) 126 | parser.add_argument('--prompt_file', type=str, default=None) 127 | parser.add_argument('--num_samples', type=int, default=-1) 128 | 129 | # problem params 130 | parser.add_argument('--task', type=str, default='sr_avgpool') 131 | parser.add_argument('--method', type=str, default='flowdps') 132 | parser.add_argument('--deg_scale', type=int, default=12) 133 | 134 | # solver params 135 | parser.add_argument('--step_size', type=float, default=15.0) 136 | parser.add_argument('--efficient_memory',default=False, action='store_true') 137 | args = parser.parse_args() 138 | 139 | 140 | # workdir creation and seed setup 141 | set_seed(args.seed) 142 | args.workdir.joinpath('input').mkdir(parents=True, exist_ok=True) 143 | args.workdir.joinpath('recon').mkdir(parents=True, exist_ok=True) 144 | args.workdir.joinpath('label').mkdir(parents=True, exist_ok=True) 145 | 146 | # run main script 147 | run(args) 148 | 149 | 
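For reference, the same pipeline can also be driven without the CLI. The sketch below is a minimal programmatic mirror of the flow in solve.py (build solver, build degradation operator, call `sample`); the image path, prompt, and hyperparameters are placeholders, and it assumes a CUDA GPU with enough memory and access to the SD3 weights on Hugging Face.

```python
import torch
from munch import munchify
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image

from sd3_sampler import get_solver
from functions.degradation import get_degradation

solver = get_solver('flowdps')                      # or 'flowchef' / 'psld'
for enc in (solver.text_enc_1, solver.text_enc_2, solver.text_enc_3):
    enc.to('cuda')
solver.vae.to('cuda')
solver.transformer.to('cuda')

# degradation operator: x12 average-pooling super-resolution on 768x768 images
deg_config = munchify({'channels': 3, 'image_size': 768, 'deg_scale': 12})
operator = get_degradation('sr_avgpool', deg_config, solver.transformer.device)

# load the source image and map it to [-1, 1]
tf = transforms.Compose([transforms.Resize(768),
                         transforms.CenterCrop(768),
                         transforms.ToTensor()])
img = tf(Image.open('samples/afhq_example.jpg').convert('RGB'))
img = img.unsqueeze(0).to('cuda') * 2 - 1

# simulate a noisy measurement, then solve the inverse problem
y = operator.A(img)
y = y + 0.03 * torch.randn_like(y)
recon = solver.sample(measurement=y, operator=operator, task='sr_avgpool',
                      prompts='a photo of a closed face of a dog',
                      NFE=28, img_shape=(768, 768), cfg_scale=2.0, step_size=15.0)
save_image(recon, 'recon.png', normalize=True)
```

When memory is tight, one would instead precompute the prompt embeddings and delete the text encoders before calling `sample`, as `run()` above does under `--efficient_memory`.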
-------------------------------------------------------------------------------- /solve_arbitrary.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from typing import List 4 | 5 | from munch import munchify 6 | from PIL import Image 7 | from tqdm import tqdm 8 | import torch 9 | from torchvision.utils import save_image 10 | from torchvision import transforms 11 | 12 | from util import set_seed, get_img_list, process_text 13 | from sd3_sampler import get_solver 14 | from functions.degradation import get_degradation 15 | 16 | @torch.no_grad 17 | def precompute(args, prompts:List[str], solver) -> List[torch.Tensor]: 18 | prompt_emb_set = [] 19 | pooled_emb_set = [] 20 | 21 | num_samples = args.num_samples if args.num_samples > 0 else len(prompts) 22 | for prompt in prompts[:num_samples]: 23 | prompt_emb, pooled_emb = solver.encode_prompt(prompt, batch_size=1) 24 | prompt_emb_set.append(prompt_emb) 25 | pooled_emb_set.append(pooled_emb) 26 | 27 | return prompt_emb_set, pooled_emb_set 28 | 29 | def run(args): 30 | # load solver 31 | solver = get_solver(args.method) 32 | 33 | # load text prompts 34 | prompts = process_text(prompt=args.prompt, prompt_file=args.prompt_file) 35 | solver.text_enc_1.to('cuda') 36 | solver.text_enc_2.to('cuda') 37 | solver.text_enc_3.to('cuda') 38 | 39 | if args.efficient_memory: 40 | # precompute text embedding and remove encoders from GPU 41 | # This will allow us 1) fast inference 2) with lower memory requirement (<24GB) 42 | with torch.no_grad(): 43 | prompt_emb_set, pooled_emb_set = precompute(args, prompts, solver) 44 | null_emb, null_pooled_emb = solver.encode_prompt([''], batch_size=1) 45 | 46 | del solver.text_enc_1 47 | del solver.text_enc_2 48 | del solver.text_enc_3 49 | torch.cuda.empty_cache() 50 | 51 | prompt_embs = [[x, y] for x, y in zip(prompt_emb_set, pooled_emb_set)] 52 | null_embs = [null_emb, null_pooled_emb] 53 | else: 54 | prompt_embs = [[None, None]] * len(prompts) 55 | null_embs = [None, None] 56 | 57 | print("Prompts are processed.") 58 | 59 | solver.vae.to('cuda') 60 | solver.transformer.to('cuda') 61 | 62 | # problem setup 63 | deg_config = munchify({ 64 | 'channels': 3, 65 | 'imgH': args.imgH, 66 | 'imgW': args.imgW, 67 | 'deg_scale': args.deg_scale 68 | }) 69 | operator = get_degradation(args.task, deg_config, solver.transformer.device) 70 | 71 | # solve problem 72 | tf = transforms.Compose([ 73 | transforms.Resize(min(args.imgH, args.imgW)), 74 | transforms.CenterCrop((args.imgH, args.imgW)), 75 | transforms.ToTensor() 76 | ]) 77 | 78 | pbar = tqdm(get_img_list(args.img_path), desc="Solving") 79 | for i, path in enumerate(pbar): 80 | img = tf(Image.open(path).convert('RGB')) 81 | img = img.unsqueeze(0).to(solver.vae.device) 82 | img = img * 2 - 1 83 | 84 | y = operator.A(img) 85 | y = y + 0.03 * torch.randn_like(y) 86 | 87 | out = solver.sample(measurement=y, 88 | operator=operator, 89 | prompts=prompts[i] if len(prompts)>1 else prompts[0], 90 | NFE=args.NFE, 91 | img_shape=(args.imgH, args.imgW), 92 | cfg_scale=args.cfg_scale, 93 | step_size=args.step_size, 94 | task=args.task, 95 | prompt_emb=prompt_embs[i] if len(prompt_embs)>1 else prompt_embs[0], 96 | null_emb=null_embs 97 | ) 98 | # save results 99 | save_image(operator.At(y).reshape(img.shape), 100 | args.workdir.joinpath(f'input/{str(i).zfill(4)}.png'), 101 | normalize=True) 102 | save_image(out, 103 | args.workdir.joinpath(f'recon/{str(i).zfill(4)}.png'), 104 | normalize=True) 105 | 
save_image(img, 106 | args.workdir.joinpath(f'label/{str(i).zfill(4)}.png'), 107 | normalize=True) 108 | 109 | if (i+1) == args.num_samples: 110 | break 111 | 112 | 113 | if __name__ == "__main__": 114 | parser = argparse.ArgumentParser() 115 | # sampling params 116 | parser.add_argument('--seed', type=int, default=0) 117 | parser.add_argument('--NFE', type=int, default=28) 118 | parser.add_argument('--cfg_scale', type=float, default=2.0) 119 | parser.add_argument('--imgH', type=int, default=768) 120 | parser.add_argument('--imgW', type=int, default=1152) 121 | 122 | # workdir params 123 | parser.add_argument('--workdir', type=Path, default='workdir') 124 | 125 | # data params 126 | parser.add_argument('--img_path', type=Path) 127 | parser.add_argument('--prompt', type=str, default=None) 128 | parser.add_argument('--prompt_file', type=str, default=None) 129 | parser.add_argument('--num_samples', type=int, default=-1) 130 | 131 | # problem params 132 | parser.add_argument('--task', type=str, default='sr_avgpool') 133 | parser.add_argument('--method', type=str, default='flowdps') 134 | parser.add_argument('--deg_scale', type=int, default=12) 135 | 136 | # solver params 137 | parser.add_argument('--step_size', type=float, default=15.0) 138 | parser.add_argument('--efficient_memory',default=False, action='store_true') 139 | args = parser.parse_args() 140 | 141 | 142 | # workdir creation and seed setup 143 | set_seed(args.seed) 144 | args.workdir.joinpath('input').mkdir(parents=True, exist_ok=True) 145 | args.workdir.joinpath('recon').mkdir(parents=True, exist_ok=True) 146 | args.workdir.joinpath('label').mkdir(parents=True, exist_ok=True) 147 | 148 | # run main script 149 | run(args) 150 | 151 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import List 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import torch 7 | 8 | def set_seed(seed): 9 | torch.manual_seed(seed) 10 | torch.cuda.manual_seed(seed) 11 | torch.cuda.manual_seed_all(seed) # if use multi-GPU 12 | torch.backends.cudnn.deterministic = True 13 | torch.backends.cudnn.benchmark = False 14 | np.random.seed(seed) 15 | random.seed(seed) 16 | 17 | 18 | def process_prompt_file(prompt_file: str, **kwargs): 19 | parse_fn = kwargs.get('parse_fn', None) 20 | if parse_fn is None: 21 | # default setup for DIV2K prompt from DAPE 22 | def _parse_fn(text): 23 | tmp = text.split(": ")[-1].strip() 24 | if tmp == '': 25 | return "a high quality photo" 26 | else: 27 | return "a high quality photo of " + tmp 28 | 29 | parse_fn = lambda x: _parse_fn(x) 30 | 31 | with open(prompt_file, 'r') as f: 32 | prompts = f.readlines() 33 | prompts = [parse_fn(x) for x in prompts] 34 | return prompts 35 | 36 | 37 | def process_text(prompt: str=None, prompt_file: str=None, **kwargs) -> List[str]: 38 | assert prompt is not None or prompt_file is not None, \ 39 | print("Either prompt of prompt_file must be given.") 40 | 41 | if prompt is not None: 42 | if prompt_file is not None: 43 | print("Both prompt and prompt_file are given. 
We will use prompt.") 44 | prompts = [prompt] 45 | else: 46 | prompts = process_prompt_file(prompt_file, **kwargs) 47 | 48 | return prompts 49 | 50 | 51 | def get_img_list(root: Path): 52 | if root.is_dir(): 53 | files = list(sorted(root.glob('*.png'))) \ 54 | + list(sorted(root.glob('*.jpg'))) \ 55 | + list(sorted(root.glob('*.jpeg'))) 56 | else: 57 | files = [root] 58 | 59 | for f in files: 60 | yield f 61 | -------------------------------------------------------------------------------- /utils/admm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def solve_for_k(x, y, v, u, rho): 5 | Fx = torch.fft.fftn(x, dim=(-2, -1)) 6 | Fxc = torch.conj(Fx) 7 | FxcFx = torch.pow(torch.abs(Fx), 2) 8 | Fy = torch.fft.fftn(y, dim=(-2, -1)) 9 | 10 | Sec = torch.fft.fftn((v - u/rho), dim=(-2, -1)) 11 | Fk_new = (Fxc * Fy + rho * Sec) / (FxcFx + rho) 12 | 13 | return torch.real(torch.fft.ifftn(Fk_new, dim=(-2, -1))) 14 | 15 | def solve_for_k_v2(x, y, v, u, rho): 16 | Fx = torch.fft.fftn(x, dim=(-2, -1)) 17 | Fxc = torch.conj(Fx) 18 | FxcFx = torch.pow(torch.abs(Fx), 2) 19 | Fy = torch.fft.fftn(y, dim=(-2, -1)) 20 | Sec = torch.fft.fftn((v - u/rho), dim=(-2, -1)) 21 | 22 | FR = Fxc * Fy + rho * Sec 23 | FBR = Fx.mul(FR) 24 | invWBR = FBR.div(FxcFx + rho) 25 | FCBInvWBR = Fxc * invWBR 26 | Fk = (FR - FCBInvWBR) / rho 27 | return torch.real(torch.fft.ifftn(Fk, dim=(-2, -1))) 28 | 29 | def estimate_ker(x, y, k, m:int=15, rho=5e-2, beta=1e-3): 30 | # initialize 31 | v = torch.zeros_like(x).to('cuda') 32 | u = torch.zeros_like(x).to('cuda') 33 | 34 | for _ in range(m): 35 | v = torch.clamp((k + u/rho).abs() - beta/rho, 0) * torch.sign(k + u/rho) 36 | k = solve_for_k_v2(x, y, v, u, rho) 37 | u = u + 1.5*rho*(k - v) 38 | 39 | return torch.fft.ifftshift(k).flip(dims=(-2, -1)) -------------------------------------------------------------------------------- /utils/blur_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import scipy 5 | 6 | from utils.motionblur import Kernel as MotionKernel 7 | 8 | class Blurkernel(nn.Module): 9 | def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0, device=None): 10 | super().__init__() 11 | self.blur_type = blur_type 12 | self.kernel_size = kernel_size 13 | self.std = std 14 | self.device = device 15 | self.seq = nn.Sequential( 16 | nn.ReflectionPad2d(self.kernel_size//2), 17 | nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3) 18 | ) 19 | 20 | self.weights_init() 21 | 22 | def forward(self, x): 23 | return self.seq(x) 24 | 25 | def weights_init(self): 26 | if self.blur_type == "gaussian": 27 | n = np.zeros((self.kernel_size, self.kernel_size)) 28 | n[self.kernel_size // 2,self.kernel_size // 2] = 1 29 | k = scipy.ndimage.gaussian_filter(n, sigma=self.std) 30 | k = torch.from_numpy(k) 31 | self.k = k 32 | for name, f in self.named_parameters(): 33 | f.data.copy_(k) 34 | elif self.blur_type == "motion": 35 | k = MotionKernel(size=(self.kernel_size, self.kernel_size), intensity=self.std).kernelMatrix 36 | k = torch.from_numpy(k) 37 | self.k = k 38 | for name, f in self.named_parameters(): 39 | f.data.copy_(k) 40 | 41 | def update_weights(self, k): 42 | if not torch.is_tensor(k): 43 | k = torch.from_numpy(k).to(self.device) 44 | for name, f in self.named_parameters(): 45 | f.data.copy_(k) 46 | 47 | def get_kernel(self): 48 | return self.k 49 | 
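As a quick illustration of the helper above, the sketch below builds a Gaussian `Blurkernel` and applies it to a dummy image batch; the kernel size and std are arbitrary example values, not the settings used by the degradation operators.

```python
import torch
from utils.blur_util import Blurkernel

blur = Blurkernel(blur_type='gaussian', kernel_size=61, std=3.0, device='cpu')

x = torch.rand(1, 3, 256, 256)      # stand-in for an image batch in [0, 1]
with torch.no_grad():
    y = blur(x)                     # blurred output, same spatial size (reflection padding)

k = blur.get_kernel()               # (61, 61) kernel tensor, e.g. for visualization
print(y.shape, k.shape)
```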
-------------------------------------------------------------------------------- /utils/diffpir_util.py: -------------------------------------------------------------------------------- 1 | '''Taken and modified from https://github.com/yuanzhi-zhu/DiffPIR''' 2 | 3 | # -*- coding: utf-8 -*- 4 | import numpy as np 5 | import torch 6 | import torch.fft 7 | from scipy import ndimage 8 | from scipy.interpolate import interp2d 9 | 10 | 11 | def splits(a, sf): 12 | '''split a into sfxsf distinct blocks 13 | Args: 14 | a: NxCxWxH 15 | sf: split factor 16 | Returns: 17 | b: NxCx(W/sf)x(H/sf)x(sf^2) 18 | ''' 19 | b = torch.stack(torch.chunk(a, sf, dim=2), dim=4) 20 | b = torch.cat(torch.chunk(b, sf, dim=3), dim=4) 21 | return b 22 | 23 | 24 | def p2o(psf, shape): 25 | ''' 26 | Convert point-spread function to optical transfer function. 27 | otf = p2o(psf) computes the Fast Fourier Transform (FFT) of the 28 | point-spread function (PSF) array and creates the optical transfer 29 | function (OTF) array that is not influenced by the PSF off-centering. 30 | Args: 31 | psf: NxCxhxw 32 | shape: [H, W] 33 | Returns: 34 | otf: NxCxHxWx2 35 | ''' 36 | otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf) 37 | otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf) 38 | for axis, axis_size in enumerate(psf.shape[2:]): 39 | otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2) 40 | otf = torch.fft.fftn(otf, dim=(-2,-1)) 41 | #n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf))) 42 | #otf[..., 1][torch.abs(otf[..., 1]) < n_ops*2.22e-16] = torch.tensor(0).type_as(psf) 43 | return otf 44 | 45 | 46 | def upsample(x, sf=3): 47 | '''s-fold upsampler 48 | Upsampling the spatial size by filling the new entries with zeros 49 | x: tensor image, NxCxWxH 50 | ''' 51 | st = 0 52 | z = torch.zeros((x.shape[0], x.shape[1], x.shape[2]*sf, x.shape[3]*sf)).type_as(x) 53 | z[..., st::sf, st::sf].copy_(x) 54 | return z 55 | 56 | 57 | def downsample(x, sf=3): 58 | '''s-fold downsampler 59 | Keeping the upper-left pixel for each distinct sfxsf patch and discarding the others 60 | x: tensor image, NxCxWxH 61 | ''' 62 | st = 0 63 | return x[..., st::sf, st::sf] 64 | 65 | 66 | 67 | def data_solution(x, FB, FBC, F2B, FBFy, alpha, sf): 68 | FR = FBFy + torch.fft.fftn(alpha*x, dim=(-2,-1)) 69 | x1 = FB.mul(FR) 70 | FBR = torch.mean(splits(x1, sf), dim=-1, keepdim=False) 71 | invW = torch.mean(splits(F2B, sf), dim=-1, keepdim=False) 72 | invWBR = FBR.div(invW + alpha) 73 | FCBinvWBR = FBC*invWBR.repeat(1, 1, sf, sf) 74 | FX = (FR-FCBinvWBR)/alpha 75 | Xest = torch.real(torch.fft.ifftn(FX, dim=(-2, -1))) 76 | 77 | return Xest 78 | 79 | 80 | def pre_calculate(x, k, sf): 81 | ''' 82 | Args: 83 | x: NxCxHxW, LR input 84 | k: NxCxhxw 85 | sf: integer 86 | 87 | Returns: 88 | FB, FBC, F2B, FBFy 89 | will be reused during iterations 90 | ''' 91 | w, h = x.shape[-2:] 92 | FB = p2o(k, (w*sf, h*sf)) 93 | FBC = torch.conj(FB) 94 | F2B = torch.pow(torch.abs(FB), 2) 95 | STy = upsample(x, sf=sf) 96 | FBFy = FBC*torch.fft.fftn(STy, dim=(-2, -1)) 97 | return FB, FBC, F2B, FBFy 98 | 99 | 100 | def classical_degradation(x, k, sf=3): 101 | ''' blur + downsampling 102 | 103 | Args: 104 | x: HxWxC image, [0, 1]/[0, 255] 105 | k: hxw, double 106 | sf: down-scale factor 107 | 108 | Return: 109 | downsampled LR image 110 | ''' 111 | x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') 112 | #x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) 113 | st = 0 114 | return x[st::sf, st::sf, 
...] 115 | 116 | 117 | 118 | def shift_pixel(x, sf, upper_left=True): 119 | """shift pixel for super-resolution with different scale factors 120 | Args: 121 | x: WxHxC or WxH, image or kernel 122 | sf: scale factor 123 | upper_left: shift direction 124 | """ 125 | h, w = x.shape[:2] 126 | shift = (sf-1)*0.5 127 | xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) 128 | if upper_left: 129 | x1 = xv + shift 130 | y1 = yv + shift 131 | else: 132 | x1 = xv - shift 133 | y1 = yv - shift 134 | 135 | x1 = np.clip(x1, 0, w-1) 136 | y1 = np.clip(y1, 0, h-1) 137 | 138 | if x.ndim == 2: 139 | x = interp2d(xv, yv, x)(x1, y1) 140 | if x.ndim == 3: 141 | for i in range(x.shape[-1]): 142 | x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) 143 | 144 | return x -------------------------------------------------------------------------------- /utils/img_util.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import numpy as np 4 | import torch 5 | from torch.fft import fft2, fftshift, ifft2, ifftshift 6 | from torchvision.utils import save_image 7 | 8 | 9 | def draw_img(img: Union[torch.Tensor, np.ndarray], 10 | save_path:Optional[str]='test.png', 11 | nrow:Optional[int]=8, 12 | normalize:Optional[bool]=True): 13 | if isinstance(img, np.ndarray): 14 | img = torch.Tensor(img) 15 | 16 | save_image(img, fp=save_path, nrow=nrow, normalize=normalize) 17 | 18 | def normalize(img: Union[torch.Tensor, np.ndarray]) \ 19 | -> Union[torch.Tensor, np.ndarray]: 20 | 21 | return (img - img.min())/(img.max()-img.min()) 22 | 23 | def to_np(img: torch.Tensor, 24 | mode: Optional[str]='NCHW') -> np.ndarray: 25 | 26 | assert mode in ['NCHW', 'NHWC'] 27 | 28 | if mode == 'NCHW': 29 | img = img.permute(0,2,3,1) 30 | 31 | return img.detach().cpu().numpy() 32 | 33 | def fft2d(img: torch.Tensor, 34 | mode: Optional[str]='NCHW') -> torch.Tensor: 35 | 36 | assert mode in ['NCHW', 'NHWC'] 37 | 38 | if mode == 'NCHW': 39 | return fftshift(fft2(img)) 40 | elif mode == 'NHWC': 41 | img = img.permute(0,3,1,2) 42 | return fftshift(fft2(img)) 43 | else: 44 | raise NameError 45 | 46 | 47 | def ifft2d(img: torch.Tensor, 48 | mode: Optional[str]='NCHW') -> torch.Tensor: 49 | 50 | assert mode in ['NCHW', 'NHWC'] 51 | 52 | if mode == 'NCHW': 53 | return ifft2(ifftshift(img)) 54 | elif mode == 'NHWC': 55 | img = ifft2(ifftshift(img)) 56 | return img.permute(0,2,3,1) 57 | else: 58 | raise NameError 59 | 60 | 61 | """ 62 | Helper functions for new types of inverse problems 63 | """ 64 | 65 | def fft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: 66 | """ 67 | Apply centered 2 dimensional Fast Fourier Transform. 68 | Args: 69 | data: Complex valued input data containing at least 3 dimensions: 70 | dimensions -3 & -2 are spatial dimensions and dimension -1 has size 71 | 2. All other dimensions are assumed to be batch dimensions. 72 | norm: Normalization mode. See ``torch.fft.fft``. 73 | Returns: 74 | The FFT of the input. 75 | """ 76 | if not data.shape[-1] == 2: 77 | raise ValueError("Tensor does not have separate complex dim.") 78 | 79 | data = ifftshift(data, dim=[-3, -2]) 80 | data = torch.view_as_real( 81 | torch.fft.fftn( # type: ignore 82 | torch.view_as_complex(data), dim=(-2, -1), norm=norm 83 | ) 84 | ) 85 | data = fftshift(data, dim=[-3, -2]) 86 | 87 | return data 88 | 89 | 90 | def ifft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor: 91 | """ 92 | Apply centered 2-dimensional Inverse Fast Fourier Transform. 
93 | Args: 94 | data: Complex valued input data containing at least 3 dimensions: 95 | dimensions -3 & -2 are spatial dimensions and dimension -1 has size 96 | 2. All other dimensions are assumed to be batch dimensions. 97 | norm: Normalization mode. See ``torch.fft.ifft``. 98 | Returns: 99 | The IFFT of the input. 100 | """ 101 | if not data.shape[-1] == 2: 102 | raise ValueError("Tensor does not have separate complex dim.") 103 | 104 | data = ifftshift(data, dim=[-3, -2]) 105 | data = torch.view_as_real( 106 | torch.fft.ifftn( # type: ignore 107 | torch.view_as_complex(data), dim=(-2, -1), norm=norm 108 | ) 109 | ) 110 | data = fftshift(data, dim=[-3, -2]) 111 | 112 | return data 113 | 114 | def fft2(x): 115 | """ FFT with shifting DC to the center of the image""" 116 | return torch.fft.fftshift(torch.fft.fft2(x), dim=[-1, -2]) 117 | 118 | 119 | def ifft2(x): 120 | """ IFFT with shifting DC to the corner of the image prior to transform""" 121 | return torch.fft.ifft2(torch.fft.ifftshift(x, dim=[-1, -2])) 122 | 123 | 124 | def fft2_m(x): 125 | """ FFT for multi-coil """ 126 | if not torch.is_complex(x): 127 | x = x.type(torch.complex64) 128 | return torch.view_as_complex(fft2c_new(torch.view_as_real(x))) 129 | 130 | 131 | def ifft2_m(x): 132 | """ IFFT for multi-coil """ 133 | if not torch.is_complex(x): 134 | x = x.type(torch.complex64) 135 | return torch.view_as_complex(ifft2c_new(torch.view_as_real(x))) -------------------------------------------------------------------------------- /utils/inpaint_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def random_sq_bbox(img, mask_shape, image_size=256, margin=(16, 16)): 6 | """Generate a random sqaure mask for inpainting 7 | """ 8 | B, C, H, W = img.shape 9 | h, w = mask_shape 10 | margin_height, margin_width = margin 11 | maxt = image_size - margin_height - h 12 | maxl = image_size - margin_width - w 13 | 14 | # bb 15 | t = np.random.randint(margin_height, maxt) 16 | l = np.random.randint(margin_width, maxl) 17 | 18 | # make mask 19 | mask = torch.ones([B, C, H, W], device=img.device) 20 | mask[..., t:t+h, l:l+w] = 0 21 | 22 | return mask, t, t+h, l, l+w 23 | 24 | 25 | class MaskGenerator: 26 | def __init__(self, mask_type, mask_len_range=None, mask_prob_range=None, 27 | image_size=256, margin=(16, 16)): 28 | """ 29 | (mask_len_range): given in (min, max) tuple. 
30 | Specifies the range of box size in each dimension 31 | (mask_prob_range): for the case of random masking, 32 | specify the probability of individual pixels being masked 33 | """ 34 | assert mask_type in ['box', 'random', 'both', 'extreme'] 35 | self.mask_type = mask_type 36 | self.mask_len_range = mask_len_range 37 | self.mask_prob_range = mask_prob_range 38 | self.image_size = image_size 39 | self.margin = margin 40 | 41 | def _retrieve_box(self, img): 42 | l, h = self.mask_len_range 43 | l, h = int(l), int(h) 44 | mask_h = np.random.randint(l, h) 45 | mask_w = np.random.randint(l, h) 46 | mask, t, tl, w, wh = random_sq_bbox(img, 47 | mask_shape=(mask_h, mask_w), 48 | image_size=self.image_size, 49 | margin=self.margin) 50 | return mask, t, tl, w, wh 51 | 52 | def _retrieve_random(self, img): 53 | total = self.image_size ** 2 54 | # random pixel sampling 55 | l, h = self.mask_prob_range 56 | prob = np.random.uniform(l, h) 57 | mask_vec = torch.ones([1, self.image_size * self.image_size]) 58 | samples = np.random.choice(self.image_size * self.image_size, int(total * prob), replace=False) 59 | mask_vec[:, samples] = 0 60 | mask_b = mask_vec.view(1, self.image_size, self.image_size) 61 | mask_b = mask_b.repeat(3, 1, 1) 62 | mask = torch.ones_like(img, device=img.device) 63 | mask[:, ...] = mask_b 64 | return mask 65 | 66 | def __call__(self, img): 67 | if self.mask_type == 'random': 68 | mask = self._retrieve_random(img) 69 | return mask 70 | elif self.mask_type == 'box': 71 | mask, t, th, w, wl = self._retrieve_box(img) 72 | return mask 73 | elif self.mask_type == 'extreme': 74 | mask, t, th, w, wl = self._retrieve_box(img) 75 | mask = 1. - mask 76 | return mask 77 | 78 | -------------------------------------------------------------------------------- /utils/log_util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from PIL import Image 3 | from pathlib import Path 4 | from rich.logging import RichHandler 5 | 6 | class Logger: 7 | def __init__(self): 8 | self.name = "FlowDPS" 9 | 10 | def initLogger(self): 11 | __logger = logging.getLogger(self.name) 12 | 13 | FORMAT = f"[{self.name}] >> %(message)s" 14 | handler = RichHandler() 15 | handler.setFormatter(logging.Formatter(FORMAT)) 16 | 17 | __logger.addHandler(handler) 18 | 19 | __logger.setLevel(logging.INFO) 20 | 21 | return __logger 22 | 23 | def make_gif(input_path: Path, save_path: Path) -> None: 24 | files = sorted(input_path.glob('*.png')) 25 | frames = [] 26 | 27 | for f in files: 28 | frames.append(Image.open(f).convert('RGB')) 29 | 30 | frame_one = frames[0] 31 | frame_one.save(save_path, format="GIF", append_images=frames, 32 | save_all=True, duration=100, loop=0) 33 | -------------------------------------------------------------------------------- /utils/motionblur.py: -------------------------------------------------------------------------------- 1 | """ From https://github.com/LeviBorodenko/motionblur """ 2 | import numpy as np 3 | from PIL import Image, ImageDraw, ImageFilter 4 | from numpy.random import uniform, triangular, beta 5 | from math import pi 6 | from pathlib import Path 7 | from scipy.signal import convolve 8 | 9 | # tiny error used for nummerical stability 10 | eps = 0.1 11 | 12 | 13 | def softmax(x): 14 | """Compute softmax values for each sets of scores in x.""" 15 | e_x = np.exp(x - np.max(x)) 16 | return e_x / e_x.sum() 17 | 18 | 19 | def norm(lst: list) -> float: 20 | """[summary] 21 | L^2 norm of a list 22 | [description] 23 | Used for 
internals 24 | Arguments: 25 | lst {list} -- vector 26 | """ 27 | if not isinstance(lst, list): 28 | raise ValueError("Norm takes a list as its argument") 29 | 30 | if lst == []: 31 | return 0 32 | 33 | return (sum((i**2 for i in lst)))**0.5 34 | 35 | 36 | def polar2z(r: np.ndarray, θ: np.ndarray) -> np.ndarray: 37 | """[summary] 38 | Takes a list of radii and angles (radians) and 39 | converts them into a corresponding list of complex 40 | numbers x + yi. 41 | [description] 42 | 43 | Arguments: 44 | r {np.ndarray} -- radius 45 | θ {np.ndarray} -- angle 46 | 47 | Returns: 48 | [np.ndarray] -- list of complex numbers r e^(i theta) as x + iy 49 | """ 50 | return r * np.exp(1j * θ) 51 | 52 | 53 | class Kernel(object): 54 | """[summary] 55 | Class representing a motion blur kernel of a given intensity. 56 | 57 | [description] 58 | Keyword Arguments: 59 | size {tuple} -- Size of the kernel in px times px 60 | (default: {(100, 100)}) 61 | 62 | intensity {float} -- Float between 0 and 1. 63 | Intensity of the motion blur. 64 | 65 | : 0 means linear motion blur and 1 is a highly non linear 66 | and often convex motion blur path. (default: {0}) 67 | 68 | Attribute: 69 | kernelMatrix -- Numpy matrix of the kernel of given intensity 70 | 71 | Properties: 72 | applyTo -- Applies kernel to image 73 | (pass as path, pillow image or np array) 74 | 75 | Raises: 76 | ValueError 77 | """ 78 | 79 | def __init__(self, size: tuple = (100, 100), intensity: float=0): 80 | 81 | # checking if size is correctly given 82 | if not isinstance(size, tuple): 83 | raise ValueError("Size must be TUPLE of 2 positive integers") 84 | elif len(size) != 2 or type(size[0]) != type(size[1]) != int: 85 | raise ValueError("Size must be tuple of 2 positive INTEGERS") 86 | elif size[0] < 0 or size[1] < 0: 87 | raise ValueError("Size must be tuple of 2 POSITIVE integers") 88 | 89 | # check if intensity is float (int) between 0 and 1 90 | if type(intensity) not in [int, float, np.float32, np.float64]: 91 | raise ValueError("Intensity must be a number between 0 and 1") 92 | elif intensity < 0 or intensity > 1: 93 | raise ValueError("Intensity must be a number between 0 and 1") 94 | 95 | # saving args 96 | self.SIZE = size 97 | self.INTENSITY = intensity 98 | 99 | # deriving quantities 100 | 101 | # we super size first and then downscale at the end for better 102 | # anti-aliasing 103 | self.SIZEx2 = tuple([2 * i for i in size]) 104 | self.x, self.y = self.SIZEx2 105 | 106 | # getting length of kernel diagonal 107 | self.DIAGONAL = (self.x**2 + self.y**2)**0.5 108 | 109 | # flag to see if kernel has been calculated already 110 | self.kernel_is_generated = False 111 | 112 | def _createPath(self): 113 | """[summary] 114 | creates a motion blur path with the given intensity. 115 | [description] 116 | Proceede in 5 steps 117 | 1. Get a random number of random step sizes 118 | 2. For each step get a random angle 119 | 3. combine steps and angles into a sequence of increments 120 | 4. create path out of increments 121 | 5. translate path to fit the kernel dimensions 122 | 123 | NOTE: "random" means random but might depend on the given intensity 124 | """ 125 | 126 | # first we find the lengths of the motion blur steps 127 | def getSteps(): 128 | """[summary] 129 | Here we calculate the length of the steps taken by 130 | the motion blur 131 | [description] 132 | We want a higher intensity lead to a longer total motion 133 | blur path and more different steps along the way. 
134 | 135 | Hence we sample 136 | 137 | MAX_PATH_LEN =[U(0,1) + U(0, intensity^2)] * diagonal * 0.75 138 | 139 | and each step: beta(1, 30) * (1 - self.INTENSITY + eps) * diagonal) 140 | """ 141 | 142 | # getting max length of blur motion 143 | self.MAX_PATH_LEN = 0.75 * self.DIAGONAL * \ 144 | (uniform() + uniform(0, self.INTENSITY**2)) 145 | 146 | # getting step 147 | steps = [] 148 | 149 | while sum(steps) < self.MAX_PATH_LEN: 150 | 151 | # sample next step 152 | step = beta(1, 30) * (1 - self.INTENSITY + eps) * self.DIAGONAL 153 | if step < self.MAX_PATH_LEN: 154 | steps.append(step) 155 | 156 | # note the steps and the total number of steps 157 | self.NUM_STEPS = len(steps) 158 | self.STEPS = np.asarray(steps) 159 | 160 | def getAngles(): 161 | """[summary] 162 | Gets an angle for each step 163 | [description] 164 | The maximal angle should be larger the more 165 | intense the motion is. So we sample it from a 166 | U(0, intensity * pi) 167 | 168 | We sample "jitter" from a beta(2,20) which is the probability 169 | that the next angle has a different sign than the previous one. 170 | """ 171 | 172 | # same as with the steps 173 | 174 | # first we get the max angle in radians 175 | self.MAX_ANGLE = uniform(0, self.INTENSITY * pi) 176 | 177 | # now we sample "jitter" which is the probability that the 178 | # next angle has a different sign than the previous one 179 | self.JITTER = beta(2, 20) 180 | 181 | # initialising angles (and sign of angle) 182 | angles = [uniform(low=-self.MAX_ANGLE, high=self.MAX_ANGLE)] 183 | 184 | while len(angles) < self.NUM_STEPS: 185 | 186 | # sample next angle (absolute value) 187 | angle = triangular(0, self.INTENSITY * 188 | self.MAX_ANGLE, self.MAX_ANGLE + eps) 189 | 190 | # with jitter probability change sign wrt previous angle 191 | if uniform() < self.JITTER: 192 | angle *= - np.sign(angles[-1]) 193 | else: 194 | angle *= np.sign(angles[-1]) 195 | 196 | angles.append(angle) 197 | 198 | # save angles 199 | self.ANGLES = np.asarray(angles) 200 | 201 | # Get steps and angles 202 | getSteps() 203 | getAngles() 204 | 205 | # Turn them into a path 206 | #### 207 | 208 | # we turn angles and steps into complex numbers 209 | complex_increments = polar2z(self.STEPS, self.ANGLES) 210 | 211 | # generate path as the cumsum of these increments 212 | self.path_complex = np.cumsum(complex_increments) 213 | 214 | # find center of mass of path 215 | self.com_complex = sum(self.path_complex) / self.NUM_STEPS 216 | 217 | # Shift path s.t. center of mass lies in the middle of 218 | # the kernel and a apply a random rotation 219 | ### 220 | 221 | # center it on COM 222 | center_of_kernel = (self.x + 1j * self.y) / 2 223 | self.path_complex -= self.com_complex 224 | 225 | # randomly rotate path by an angle a in (0, pi) 226 | self.path_complex *= np.exp(1j * uniform(0, pi)) 227 | 228 | # center COM on center of kernel 229 | self.path_complex += center_of_kernel 230 | 231 | # convert complex path to final list of coordinate tuples 232 | self.path = [(i.real, i.imag) for i in self.path_complex] 233 | 234 | def _createKernel(self, save_to: Path=None, show: bool=False): 235 | """[summary] 236 | Finds a kernel (psf) of given intensity. 237 | [description] 238 | use displayKernel to actually see the kernel. 239 | 240 | Keyword Arguments: 241 | save_to {Path} -- Image file to save the kernel to. 
{None} 242 | show {bool} -- shows kernel if true 243 | """ 244 | 245 | # check if we haven't already generated a kernel 246 | if self.kernel_is_generated: 247 | return None 248 | 249 | # get the path 250 | self._createPath() 251 | 252 | # Initialise an image with super-sized dimensions 253 | # (pillow Image object) 254 | self.kernel_image = Image.new("RGB", self.SIZEx2) 255 | 256 | # ImageDraw instance that is linked to the kernel image that 257 | # we can use to draw on our kernel_image 258 | self.painter = ImageDraw.Draw(self.kernel_image) 259 | 260 | # draw the path 261 | self.painter.line(xy=self.path, width=int(self.DIAGONAL / 150)) 262 | 263 | # applying gaussian blur for realism 264 | self.kernel_image = self.kernel_image.filter( 265 | ImageFilter.GaussianBlur(radius=int(self.DIAGONAL * 0.01))) 266 | 267 | # Resize to actual size 268 | self.kernel_image = self.kernel_image.resize( 269 | self.SIZE, resample=Image.LANCZOS) 270 | 271 | # convert to gray scale 272 | self.kernel_image = self.kernel_image.convert("L") 273 | 274 | # flag that we have generated a kernel 275 | self.kernel_is_generated = True 276 | 277 | def displayKernel(self, save_to: Path=None, show: bool=True): 278 | """[summary] 279 | Finds a kernel (psf) of given intensity. 280 | [description] 281 | Saves the kernel to save_to if needed or shows it 282 | is show true 283 | 284 | Keyword Arguments: 285 | save_to {Path} -- Image file to save the kernel to. {None} 286 | show {bool} -- shows kernel if true 287 | """ 288 | 289 | # generate kernel if needed 290 | self._createKernel() 291 | 292 | # save if needed 293 | if save_to is not None: 294 | 295 | save_to_file = Path(save_to) 296 | 297 | # save Kernel image 298 | self.kernel_image.save(save_to_file) 299 | else: 300 | # Show kernel 301 | self.kernel_image.show() 302 | 303 | @property 304 | def kernelMatrix(self) -> np.ndarray: 305 | """[summary] 306 | Kernel matrix of motion blur of given intensity. 307 | [description] 308 | Once generated, it stays the same. 309 | Returns: 310 | numpy ndarray 311 | """ 312 | 313 | # generate kernel if needed 314 | self._createKernel() 315 | kernel = np.asarray(self.kernel_image, dtype=np.float32) 316 | kernel /= np.sum(kernel) 317 | 318 | return kernel 319 | 320 | @kernelMatrix.setter 321 | def kernelMatrix(self, *kargs): 322 | raise NotImplementedError("Can't manually set kernel matrix yet") 323 | 324 | def applyTo(self, image, keep_image_dim: bool = False) -> Image: 325 | """[summary] 326 | Applies kernel to one of the following: 327 | 328 | 1. Path to image file 329 | 2. Pillow image object 330 | 3. (H,W,3)-shaped numpy array 331 | [description] 332 | 333 | Arguments: 334 | image {[str, Path, Image, np.ndarray]} 335 | keep_image_dim {bool} -- If true, then we will 336 | conserve the image dimension after blurring 337 | by using "same" convolution instead of "valid" 338 | convolution inside the scipy convolve function. 339 | 340 | Returns: 341 | Image -- [description] 342 | """ 343 | # calculate kernel if haven't already 344 | self._createKernel() 345 | 346 | def applyToPIL(image: Image, keep_image_dim: bool = False) -> Image: 347 | """[summary] 348 | Applies the kernel to an PIL.Image instance 349 | [description] 350 | converts to RGB and applies the kernel to each 351 | band before recombining them. 
352 | Arguments: 353 | image {Image} -- Image to convolve 354 | keep_image_dim {bool} -- If true, then we will 355 | conserve the image dimension after blurring 356 | by using "same" convolution instead of "valid" 357 | convolution inside the scipy convolve function. 358 | 359 | Returns: 360 | Image -- blurred image 361 | """ 362 | # convert to RGB 363 | image = image.convert(mode="RGB") 364 | 365 | conv_mode = "valid" 366 | if keep_image_dim: 367 | conv_mode = "same" 368 | 369 | result_bands = () 370 | 371 | for band in image.split(): 372 | 373 | # convolve each band individually with kernel 374 | result_band = convolve( 375 | band, self.kernelMatrix, mode=conv_mode).astype("uint8") 376 | 377 | # collect bands 378 | result_bands += result_band, 379 | 380 | # stack bands back together 381 | result = np.dstack(result_bands) 382 | 383 | # Get image 384 | return Image.fromarray(result) 385 | 386 | # If image is Path 387 | if isinstance(image, str) or isinstance(image, Path): 388 | 389 | # open image as Image class 390 | image_path = Path(image) 391 | image = Image.open(image_path) 392 | 393 | return applyToPIL(image, keep_image_dim) 394 | 395 | elif isinstance(image, Image.Image): 396 | 397 | # apply kernel 398 | return applyToPIL(image, keep_image_dim) 399 | 400 | elif isinstance(image, np.ndarray): 401 | 402 | # ASSUMES we have an array of the form (H, W, 3) 403 | ### 404 | 405 | # initiate Image object from array 406 | image = Image.fromarray(image) 407 | 408 | return applyToPIL(image, keep_image_dim) 409 | 410 | else: 411 | 412 | raise ValueError("Cannot apply kernel to this type.") 413 | 414 | 415 | if __name__ == '__main__': 416 | image = Image.open("./images/moon.png") 417 | image.show() 418 | k = Kernel() 419 | 420 | k.applyTo(image, keep_image_dim=True).show() 421 | -------------------------------------------------------------------------------- /utils/resizer.py: -------------------------------------------------------------------------------- 1 | # This code was taken from: https://github.com/assafshocher/resizer by Assaf Shocher 2 | import numpy as np 3 | import torch 4 | from math import pi 5 | from torch import nn 6 | 7 | 8 | class Resizer(nn.Module): 9 | def __init__(self, in_shape, scale_factor=None, output_shape=None, kernel=None, antialiasing=True): 10 | super(Resizer, self).__init__() 11 | 12 | # First standardize values and fill missing arguments (if needed) by deriving scale from output shape or vice versa 13 | scale_factor, output_shape = self.fix_scale_and_size(in_shape, output_shape, scale_factor) 14 | 15 | # Choose interpolation method, each method has the matching kernel size 16 | method, kernel_width = { 17 | "cubic": (cubic, 4.0), 18 | "lanczos2": (lanczos2, 4.0), 19 | "lanczos3": (lanczos3, 6.0), 20 | "box": (box, 1.0), 21 | "linear": (linear, 2.0), 22 | None: (cubic, 4.0) # set default interpolation method as cubic 23 | }.get(kernel) 24 | 25 | # Antialiasing is only used when downscaling 26 | antialiasing *= (np.any(np.array(scale_factor) < 1)) 27 | 28 | # Sort indices of dimensions according to scale of each dimension. 
since we are going dim by dim this is efficient 29 | sorted_dims = np.argsort(np.array(scale_factor)) 30 | self.sorted_dims = [int(dim) for dim in sorted_dims if scale_factor[dim] != 1] 31 | 32 | # Iterate over dimensions to calculate local weights for resizing and resize each time in one direction 33 | field_of_view_list = [] 34 | weights_list = [] 35 | for dim in self.sorted_dims: 36 | # for each coordinate (along 1 dim), calculate which coordinates in the input image affect its result and the 37 | # weights that multiply the values there to get its result. 38 | weights, field_of_view = self.contributions(in_shape[dim], output_shape[dim], scale_factor[dim], method, 39 | kernel_width, antialiasing) 40 | 41 | # convert to torch tensor 42 | weights = torch.tensor(weights.T, dtype=torch.float32) 43 | 44 | # We add singleton dimensions to the weight matrix so we can multiply it with the big tensor we get for 45 | # tmp_im[field_of_view.T], (bsxfun style) 46 | weights_list.append( 47 | nn.Parameter(torch.reshape(weights, list(weights.shape) + (len(scale_factor) - 1) * [1]), 48 | requires_grad=False)) 49 | field_of_view_list.append( 50 | nn.Parameter(torch.tensor(field_of_view.T.astype(np.int32), dtype=torch.long), requires_grad=False)) 51 | 52 | self.field_of_view = nn.ParameterList(field_of_view_list) 53 | self.weights = nn.ParameterList(weights_list) 54 | 55 | def forward(self, in_tensor): 56 | x = in_tensor 57 | 58 | # Use the affecting position values and the set of weights to calculate the result of resizing along this 1 dim 59 | for dim, fov, w in zip(self.sorted_dims, self.field_of_view, self.weights): 60 | # To be able to act on each dim, we swap so that dim 0 is the wanted dim to resize 61 | x = torch.transpose(x, dim, 0) 62 | 63 | # This is a bit of a complicated multiplication: x[field_of_view.T] is a tensor of order image_dims+1. 64 | # for each pixel in the output-image it matches the positions the influence it from the input image (along 1 dim 65 | # only, this is why it only adds 1 dim to 5the shape). We then multiply, for each pixel, its set of positions with 66 | # the matching set of weights. we do this by this big tensor element-wise multiplication (MATLAB bsxfun style: 67 | # matching dims are multiplied element-wise while singletons mean that the matching dim is all multiplied by the 68 | # same number 69 | x = torch.sum(x[fov] * w, dim=0) 70 | 71 | # Finally we swap back the axes to the original order 72 | x = torch.transpose(x, dim, 0) 73 | 74 | return x 75 | 76 | def fix_scale_and_size(self, input_shape, output_shape, scale_factor): 77 | # First fixing the scale-factor (if given) to be standardized the function expects (a list of scale factors in the 78 | # same size as the number of input dimensions) 79 | if scale_factor is not None: 80 | # By default, if scale-factor is a scalar we assume 2d resizing and duplicate it. 
81 | if np.isscalar(scale_factor) and len(input_shape) > 1: 82 | scale_factor = [scale_factor, scale_factor] 83 | 84 | # We extend the size of scale-factor list to the size of the input by assigning 1 to all the unspecified scales 85 | scale_factor = list(scale_factor) 86 | scale_factor = [1] * (len(input_shape) - len(scale_factor)) + scale_factor 87 | 88 | # Fixing output-shape (if given): extending it to the size of the input-shape, by assigning the original input-size 89 | # to all the unspecified dimensions 90 | if output_shape is not None: 91 | output_shape = list(input_shape[len(output_shape):]) + list(np.uint(np.array(output_shape))) 92 | 93 | # Dealing with the case of non-give scale-factor, calculating according to output-shape. note that this is 94 | # sub-optimal, because there can be different scales to the same output-shape. 95 | if scale_factor is None: 96 | scale_factor = 1.0 * np.array(output_shape) / np.array(input_shape) 97 | 98 | # Dealing with missing output-shape. calculating according to scale-factor 99 | if output_shape is None: 100 | output_shape = np.uint(np.ceil(np.array(input_shape) * np.array(scale_factor))) 101 | 102 | return scale_factor, output_shape 103 | 104 | def contributions(self, in_length, out_length, scale, kernel, kernel_width, antialiasing): 105 | # This function calculates a set of 'filters' and a set of field_of_view that will later on be applied 106 | # such that each position from the field_of_view will be multiplied with a matching filter from the 107 | # 'weights' based on the interpolation method and the distance of the sub-pixel location from the pixel centers 108 | # around it. This is only done for one dimension of the image. 109 | 110 | # When anti-aliasing is activated (default and only for downscaling) the receptive field is stretched to size of 111 | # 1/sf. this means filtering is more 'low-pass filter'. 112 | fixed_kernel = (lambda arg: scale * kernel(scale * arg)) if antialiasing else kernel 113 | kernel_width *= 1.0 / scale if antialiasing else 1.0 114 | 115 | # These are the coordinates of the output image 116 | out_coordinates = np.arange(1, out_length + 1) 117 | 118 | # since both scale-factor and output size can be provided simulatneously, perserving the center of the image requires shifting 119 | # the output coordinates. the deviation is because out_length doesn't necesary equal in_length*scale. 120 | # to keep the center we need to subtract half of this deivation so that we get equal margins for boths sides and center is preserved. 121 | shifted_out_coordinates = out_coordinates - (out_length - in_length * scale) / 2 122 | 123 | # These are the matching positions of the output-coordinates on the input image coordinates. 124 | # Best explained by example: say we have 4 horizontal pixels for HR and we downscale by SF=2 and get 2 pixels: 125 | # [1,2,3,4] -> [1,2]. Remember each pixel number is the middle of the pixel. 126 | # The scaling is done between the distances and not pixel numbers (the right boundary of pixel 4 is transformed to 127 | # the right boundary of pixel 2. pixel 1 in the small image matches the boundary between pixels 1 and 2 in the big 128 | # one and not to pixel 2. This means the position is not just multiplication of the old pos by scale-factor). 129 | # So if we measure distance from the left border, middle of pixel 1 is at distance d=0.5, border between 1 and 2 is 130 | # at d=1, and so on (d = p - 0.5). 
we calculate (d_new = d_old / sf) which means: 131 | # (p_new-0.5 = (p_old-0.5) / sf) -> p_new = p_old/sf + 0.5 * (1-1/sf) 132 | match_coordinates = shifted_out_coordinates / scale + 0.5 * (1 - 1 / scale) 133 | 134 | # This is the left boundary to start multiplying the filter from, it depends on the size of the filter 135 | left_boundary = np.floor(match_coordinates - kernel_width / 2) 136 | 137 | # Kernel width needs to be enlarged because when covering has sub-pixel borders, it must 'see' the pixel centers 138 | # of the pixels it only covered a part from. So we add one pixel at each side to consider (weights can zeroize them) 139 | expanded_kernel_width = np.ceil(kernel_width) + 2 140 | 141 | # Determine a set of field_of_view for each output position, these are the pixels in the input image 142 | # that the pixel in the output image 'sees'. We get a matrix whose horizontal dim is the output pixels (big) and the 143 | # vertical dim is the pixels it 'sees' (kernel_size + 2) 144 | field_of_view = np.squeeze( 145 | np.int16(np.expand_dims(left_boundary, axis=1) + np.arange(expanded_kernel_width) - 1)) 146 | 147 | # Assign weight to each pixel in the field of view. A matrix whose horizontal dim is the output pixels and the 148 | # vertical dim is a list of weights matching to the pixel in the field of view (that are specified in 149 | # 'field_of_view') 150 | weights = fixed_kernel(1.0 * np.expand_dims(match_coordinates, axis=1) - field_of_view - 1) 151 | 152 | # Normalize weights to sum up to 1. be careful not to divide by 0 153 | sum_weights = np.sum(weights, axis=1) 154 | sum_weights[sum_weights == 0] = 1.0 155 | weights = 1.0 * weights / np.expand_dims(sum_weights, axis=1) 156 | 157 | # We use this mirror structure as a trick for reflection padding at the boundaries 158 | mirror = np.uint(np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1)))) 159 | field_of_view = mirror[np.mod(field_of_view, mirror.shape[0])] 160 | 161 | # Get rid of weights and pixel positions that are of zero weight 162 | non_zero_out_pixels = np.nonzero(np.any(weights, axis=0)) 163 | weights = np.squeeze(weights[:, non_zero_out_pixels]) 164 | field_of_view = np.squeeze(field_of_view[:, non_zero_out_pixels]) 165 | 166 | # Final products are the relative positions and the matching weights, both are output_size X fixed_kernel_size 167 | return weights, field_of_view 168 | 169 | 170 | # These next functions are all interpolation methods.
x is the distance from the left pixel center 171 | 172 | 173 | def cubic(x): 174 | absx = np.abs(x) 175 | absx2 = absx ** 2 176 | absx3 = absx ** 3 177 | return ((1.5 * absx3 - 2.5 * absx2 + 1) * (absx <= 1) + 178 | (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * ((1 < absx) & (absx <= 2))) 179 | 180 | 181 | def lanczos2(x): 182 | return (((np.sin(pi * x) * np.sin(pi * x / 2) + np.finfo(np.float32).eps) / 183 | ((pi ** 2 * x ** 2 / 2) + np.finfo(np.float32).eps)) 184 | * (abs(x) < 2)) 185 | 186 | 187 | def box(x): 188 | return ((-0.5 <= x) & (x < 0.5)) * 1.0 189 | 190 | 191 | def lanczos3(x): 192 | return (((np.sin(pi * x) * np.sin(pi * x / 3) + np.finfo(np.float32).eps) / 193 | ((pi ** 2 * x ** 2 / 3) + np.finfo(np.float32).eps)) 194 | * (abs(x) < 3)) 195 | 196 | 197 | def linear(x): 198 | return (x + 1) * ((-1 <= x) & (x < 0)) + (1 - x) * ((0 <= x) & (x <= 1)) -------------------------------------------------------------------------------- /utils/utils_sisr.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch.fft 3 | import torch 4 | 5 | import numpy as np 6 | from scipy import ndimage 7 | from scipy.interpolate import interp2d 8 | 9 | def splits(a, sf): 10 | '''split a into sfxsf distinct blocks 11 | Args: 12 | a: NxCxWxH 13 | sf: split factor 14 | Returns: 15 | b: NxCx(W/sf)x(H/sf)x(sf^2) 16 | ''' 17 | b = torch.stack(torch.chunk(a, sf, dim=2), dim=4) 18 | b = torch.cat(torch.chunk(b, sf, dim=3), dim=4) 19 | return b 20 | 21 | 22 | def p2o(psf, shape): 23 | ''' 24 | Convert point-spread function to optical transfer function. 25 | otf = p2o(psf) computes the Fast Fourier Transform (FFT) of the 26 | point-spread function (PSF) array and creates the optical transfer 27 | function (OTF) array that is not influenced by the PSF off-centering. 
28 | Args: 29 | psf: NxCxhxw 30 | shape: [H, W] 31 | Returns: 32 | otf: NxCxHxWx2 33 | ''' 34 | otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf) 35 | otf[...,:psf.shape[2],:psf.shape[3]].copy_(psf) 36 | for axis, axis_size in enumerate(psf.shape[2:]): 37 | otf = torch.roll(otf, -int(axis_size / 2), dims=axis+2) 38 | otf = torch.fft.fftn(otf, dim=(-2,-1)) 39 | #n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf))) 40 | #otf[..., 1][torch.abs(otf[..., 1]) < n_ops*2.22e-16] = torch.tensor(0).type_as(psf) 41 | return otf 42 | 43 | 44 | def upsample(x, sf=3): 45 | '''s-fold upsampler 46 | Upsampling the spatial size by filling the new entries with zeros 47 | x: tensor image, NxCxWxH 48 | ''' 49 | st = 0 50 | z = torch.zeros((x.shape[0], x.shape[1], x.shape[2]*sf, x.shape[3]*sf)).type_as(x) 51 | z[..., st::sf, st::sf].copy_(x) 52 | return z 53 | 54 | 55 | def downsample(x, sf=3): 56 | '''s-fold downsampler 57 | Keeping the upper-left pixel for each distinct sfxsf patch and discarding the others 58 | x: tensor image, NxCxWxH 59 | ''' 60 | st = 0 61 | return x[..., st::sf, st::sf] 62 | 63 | 64 | def data_solution_simple(x, F2K, FKFy, rho): 65 | rho = rho.clip(min=1e-2) 66 | numerator = FKFy + torch.fft.fftn(rho*x, dim=(-2,-1)) 67 | denominator = F2K + rho 68 | FX = numerator / denominator 69 | Xest = torch.real(torch.fft.ifftn(FX, dim=(-2, -1))) 70 | return Xest 71 | 72 | 73 | def data_solution_nonuniform(x, FK, FKC, F2KM, FKFMy, rho): 74 | rho = rho.clip(min=1e-2) 75 | numerator = FKFMy + torch.fft.fftn(rho*x, dim=(-2,-1)) 76 | denominator = F2KM + rho 77 | FX = numerator / denominator 78 | Xest = torch.real(torch.fft.ifftn(FX, dim=(-2, -1))) 79 | return Xest 80 | 81 | 82 | def pre_calculate(x, k): 83 | ''' 84 | Args: 85 | x: NxCxHxW, LR input 86 | k: NxCxhxw 87 | 88 | Returns: 89 | FK, FKC, F2K 90 | will be reused during iterations 91 | ''' 92 | w, h = x.shape[-2:] 93 | FK = p2o(k, (w, h)) 94 | FKC = torch.conj(FK) 95 | F2K = torch.pow(torch.abs(FK), 2) 96 | return FK, FKC, F2K 97 | 98 | 99 | def pre_calculate_FK(k): 100 | ''' 101 | Args: 102 | k: [25, 1, 33, 33] 25 is the number of filters 103 | 104 | Returns: 105 | FK: 106 | FKC: 107 | ''' 108 | # [25, 1, 512, 512] (expanded from) [25, 1, 33, 33] 109 | FK = p2o(k, (512, 512)) 110 | FKC = torch.conj(FK) 111 | return FK, FKC 112 | 113 | 114 | def pre_calculate_nonuniform(x, y, FK, FKC, mask): 115 | ''' 116 | Args: 117 | x: [1, 3, 512, 512] 118 | y: [1, 3, 512, 512] 119 | FK: [25, 1, 512, 512] 25 is the number of filters 120 | FKC: [25, 1, 512, 512] 121 | m: [1, 25, 512, 512] 122 | 123 | Returns: 124 | ''' 125 | mask = mask.transpose(0, 1) 126 | w, h = x.shape[-2:] 127 | # [1, 3, 512, 512] -> [25, 3, 512, 512] 128 | By = y.repeat(mask.shape[0], 1, 1, 1) 129 | # [25, 3, 512, 512] 130 | My = mask * By 131 | # or use just fft..? 
132 | FMy = torch.fft.fft2(My) 133 | 134 | # [25, 3, 512, 512] 135 | FKFMy = FK * FMy 136 | # [1, 3, 512, 512] 137 | FKFMy = torch.sum(FKFMy, dim=0, keepdim=True) 138 | 139 | # [25, 1, 512, 512] 140 | F2KM = torch.abs(FKC * (mask ** 2) * FK) 141 | # [1, 1, 512, 512] 142 | F2KM = torch.sum(F2KM, dim=0, keepdim=True) 143 | return F2KM, FKFMy 144 | 145 | 146 | def classical_degradation(x, k, sf=3): 147 | ''' blur + downsampling 148 | 149 | Args: 150 | x: HxWxC image, [0, 1]/[0, 255] 151 | k: hxw, double 152 | sf: down-scale factor 153 | 154 | Return: 155 | downsampled LR image 156 | ''' 157 | x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') 158 | #x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) 159 | st = 0 160 | return x[st::sf, st::sf, ...] 161 | 162 | 163 | 164 | def shift_pixel(x, sf, upper_left=True): 165 | """shift pixel for super-resolution with different scale factors 166 | Args: 167 | x: WxHxC or WxH, image or kernel 168 | sf: scale factor 169 | upper_left: shift direction 170 | """ 171 | h, w = x.shape[:2] 172 | shift = (sf-1)*0.5 173 | xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) 174 | if upper_left: 175 | x1 = xv + shift 176 | y1 = yv + shift 177 | else: 178 | x1 = xv - shift 179 | y1 = yv - shift 180 | 181 | x1 = np.clip(x1, 0, w-1) 182 | y1 = np.clip(y1, 0, h-1) 183 | 184 | if x.ndim == 2: 185 | x = interp2d(xv, yv, x)(x1, y1) 186 | if x.ndim == 3: 187 | for i in range(x.shape[-1]): 188 | x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) 189 | 190 | return x 191 | --------------------------------------------------------------------------------
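Usage note (not part of the repository files above): the following is a minimal, hypothetical sketch of how the degradation utilities shown in `utils/motionblur.py` and `utils/resizer.py` might be combined to synthesize a blurred or downsampled measurement. The image path, kernel size, intensity, and scale factor are illustrative assumptions only.

```python
import numpy as np
import torch
from PIL import Image

from utils.motionblur import Kernel   # motion-blur PSF generator (Kernel class above)
from utils.resizer import Resizer     # antialiased resizer (Resizer module above)

# Load any RGB test image (the path here is an example, not a file shipped with the repo).
img = Image.open("path/to/image.png").convert("RGB")

# Motion blur: sample a 61x61 kernel of moderate intensity and apply it.
# keep_image_dim=True uses "same" convolution, so the blurred output keeps H x W.
kernel = Kernel(size=(61, 61), intensity=0.5)
blurred = kernel.applyTo(img, keep_image_dim=True)
psf = kernel.kernelMatrix              # normalized numpy PSF (sums to 1)

# Downsampling measurement: x12 antialiased downscale with the Resizer module.
x = torch.from_numpy(np.asarray(img, dtype=np.float32) / 255.0).permute(2, 0, 1)[None]
downsampler = Resizer(x.shape, scale_factor=1 / 12)
y = downsampler(x)                     # low-resolution measurement tensor (H and W reduced ~12x)
```

Note that the Kernel instance generates its PSF lazily: the blur path and kernel image are only created on the first call to `applyTo` or `kernelMatrix` (see `_createKernel` above).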