├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── bedroom.yml
│   ├── celeba.yml
│   └── tower.yml
├── datasets
│   ├── __init__.py
│   ├── celeba.py
│   ├── utils.py
│   └── vision.py
├── environment.yml
├── filter_builder.py
├── main.py
├── models
│   ├── __init__.py
│   ├── ema.py
│   ├── layers.py
│   ├── ncsnv2.py
│   └── normalization.py
└── runners
    ├── __init__.py
    └── ncsn_runner.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.gif
2 | *.pdf
3 | exp/
4 | .idea/
5 | __pycache__
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Yang Song (yangsong@cs.stanford.edu)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SNIPS: Solving Noisy Inverse Problems Stochastically
2 |
3 | This repo contains the official implementation for the paper [SNIPS: Solving Noisy Inverse Problems Stochastically](http://arxiv.org/abs/2105.14951),
4 |
5 | by Bahjat Kawar, Gregory Vaksman, and Michael Elad, Computer Science Department, Technion.
6 |
7 | ## Running Experiments
8 |
9 | ### Dependencies
10 |
11 | Run the following conda command to install all the Python packages needed for our code and to set up the `snips` environment:
12 |
13 | ```bash
14 | conda env create -f environment.yml
15 | ```
16 |
17 | The environment includes `cudatoolkit=11.0`; you may change that depending on your hardware.
18 |
19 | ### Project structure
20 |
21 | `main.py` is the file that you should run for sampling. Execute `python main.py --help` to get its usage description:
22 |
23 | ```
24 | usage: main.py [-h] --config CONFIG [--seed SEED] [--exp EXP] --doc DOC
25 |                [--comment COMMENT] [--verbose VERBOSE] [-i IMAGE_FOLDER]
26 |                [-n NUM_VARIATIONS] [-s SIGMA_0] [--degradation DEGRADATION]
27 |
28 | optional arguments:
29 |   -h, --help            show this help message and exit
30 |   --config CONFIG       Path to the config file
31 |   --seed SEED           Random seed
32 |   --exp EXP             Path for saving running related data.
33 |   --doc DOC             A string for documentation purpose. Will be the name
34 |                         of the log folder.
35 |   --comment COMMENT     A string for experiment comment
36 |   --verbose VERBOSE     Verbose level: info | debug | warning | critical
37 |   -i IMAGE_FOLDER, --image_folder IMAGE_FOLDER
38 |                         The folder name of samples
39 |   -n NUM_VARIATIONS, --num_variations NUM_VARIATIONS
40 |                         Number of variations to produce
41 |   -s SIGMA_0, --sigma_0 SIGMA_0
42 |                         Noise std to add to observation
43 |   --degradation DEGRADATION
44 |                         Degradation: inp | deblur_uni | deblur_gauss | sr2 |
45 |                         sr4 | cs4 | cs8 | cs16
46 |
47 | ```
48 |
49 | Configuration files are in `configs/`. You don't need to include the prefix `configs/` when specifying `--config`. All files generated when running the code are placed under the directory specified by `--exp`. They are structured as:
50 |
51 | ```bash
52 | # a folder named by the argument `--exp` given to main.py
53 | ├── datasets # all dataset files
54 | │   ├── celeba # all CelebA files
55 | │   └── lsun # all LSUN files
56 | ├── logs # contains checkpoints and samples produced during training
57 | │   └── <doc> # a folder named by the argument `--doc` specified to main.py
58 | │       └── checkpoint_x.pth # the checkpoint file saved at the x-th training iteration
59 | ├── image_samples # contains generated samples
60 | │   └── <i> # a folder named by the argument `-i` given to main.py
61 | │       ├── stochastic_variation.png # samples generated from checkpoint_x.pth, including original, degraded, mean, and std
62 | │       ├── results.pt # the pytorch tensor corresponding to stochastic_variation.png
63 | │       └── y_0.pt # the pytorch tensor containing the input y of SNIPS
64 | ```
65 |
66 | ### Downloading data
67 |
68 | You can download the aligned and cropped CelebA files from their official source [here](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html). The LSUN files can be downloaded using [this script](https://github.com/fyu/lsun). For our purposes, only the validation sets of LSUN bedroom and tower need to be downloaded.
69 |
70 | ### Running SNIPS
71 |
72 | For example, to run SNIPS on CelebA for the problem of super resolution by 2, with added noise of standard deviation 0.1, and to obtain 3 variations, run the following:
73 |
74 | ```bash
75 | python main.py -i celeba --config celeba.yml --doc celeba -n 3 --degradation sr2 --sigma_0 0.1
76 | ```
77 |
78 | Samples will be saved in `<exp>/image_samples/celeba`.
79 |
80 | The available degradations are: inpainting (`inp`), uniform deblurring (`deblur_uni`), Gaussian deblurring (`deblur_gauss`), super resolution by 2 (`sr2`) or by 4 (`sr4`), and compressive sensing by 4 (`cs4`), 8 (`cs8`), or 16 (`cs16`). `sigma_0` can be any value from 0 to 1.
81 |
82 | ## Pretrained Checkpoints
83 |
84 | Link: https://drive.google.com/drive/folders/1217uhIvLg9ZrYNKOR3XTRFSurt4miQrd?usp=sharing
85 |
86 | These checkpoint files are provided as-is by the authors of [NCSNv2](https://github.com/ermongroup/ncsnv2). You can use the pretrained checkpoints for the CelebA, LSUN-bedroom, and LSUN-tower datasets. We assume the `--exp` argument is set to `exp`.
87 |
88 | ## Acknowledgement
89 |
90 | This repo is largely based on the [NCSNv2](https://github.com/ermongroup/ncsnv2) repo, and uses modified code from [this repo](https://github.com/alisaaalehi/convolution_as_multiplication) for implementing the blurring matrix.
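As an illustration of that blurring matrix, the Gaussian operator behind `deblur_gauss` can be built and applied with the helpers in `filter_builder.py`. This is a minimal sketch rather than the repo's actual entry point (the degradation matrix is consumed inside the runner), and the image here is a random stand-in:

```python
import numpy as np
from filter_builder import get_custom_kernel, matrix_to_vector, vector_to_matrix

dim = 64                                      # CelebA image size from configs/celeba.yml
H = get_custom_kernel(type="gauss", dim=dim)  # (dim*dim, dim*dim) blur operator
x = np.random.rand(dim, dim)                  # stand-in for one image channel in [0, 1]
y = vector_to_matrix(H @ matrix_to_vector(x), (dim, dim))  # blurred channel
y_0 = y + 0.1 * np.random.randn(dim, dim)     # noisy observation with sigma_0 = 0.1
```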
91 | 92 | ## References 93 | 94 | If you find the code/idea useful for your research, please consider citing 95 | 96 | ```bib 97 | @article{kawar2021snips, 98 | title={{SNIPS}: Solving noisy inverse problems stochastically}, 99 | author={Kawar, Bahjat and Vaksman, Gregory and Elad, Michael}, 100 | journal={Advances in Neural Information Processing Systems}, 101 | volume={34}, 102 | pages={21757--21769}, 103 | year={2021} 104 | } 105 | ``` 106 | 107 | -------------------------------------------------------------------------------- /configs/bedroom.yml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 128 3 | n_epochs: 500000 4 | n_iters: 150001 5 | snapshot_freq: 5000 6 | snapshot_sampling: true 7 | anneal_power: 2 8 | log_all_sigmas: false 9 | 10 | sampling: 11 | batch_size: 6 12 | data_init: false 13 | step_lr: 0.0000018 14 | n_steps_each: 3 15 | ckpt_id: 150000 16 | final_only: true 17 | fid: false 18 | denoise: true 19 | num_samples4fid: 10000 20 | inpainting: false 21 | interpolation: false 22 | n_interpolations: 10 23 | 24 | fast_fid: 25 | batch_size: 1000 26 | num_samples: 1000 27 | step_lr: 0.0000018 28 | n_steps_each: 3 29 | begin_ckpt: 100000 30 | end_ckpt: 150000 31 | verbose: false 32 | ensemble: false 33 | 34 | test: 35 | begin_ckpt: 5000 36 | end_ckpt: 150000 37 | batch_size: 100 38 | 39 | data: 40 | dataset: "LSUN" 41 | category: "bedroom" 42 | image_size: 128 43 | channels: 3 44 | logit_transform: false 45 | uniform_dequantization: false 46 | gaussian_dequantization: false 47 | random_flip: true 48 | rescaled: false 49 | num_workers: 32 50 | 51 | model: 52 | sigma_begin: 190 53 | num_classes: 1086 54 | ema: true 55 | ema_rate: 0.999 56 | spec_norm: false 57 | sigma_dist: geometric 58 | sigma_end: 0.01 59 | normalization: InstanceNorm++ 60 | nonlinearity: elu 61 | ngf: 128 62 | 63 | optim: 64 | weight_decay: 0.000 65 | optimizer: "Adam" 66 | lr: 0.0001 67 | beta1: 0.9 68 | amsgrad: false 69 | eps: 0.00000001 70 | -------------------------------------------------------------------------------- /configs/celeba.yml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 128 3 | n_epochs: 500000 4 | n_iters: 210001 5 | snapshot_freq: 5000 6 | snapshot_sampling: true 7 | anneal_power: 2 8 | log_all_sigmas: false 9 | 10 | sampling: 11 | batch_size: 8 12 | data_init: false 13 | step_lr: 0.0000033 14 | n_steps_each: 5 15 | ckpt_id: 210000 16 | final_only: true 17 | fid: false 18 | denoise: true 19 | num_samples4fid: 10000 20 | inpainting: false 21 | interpolation: false 22 | n_interpolations: 15 23 | 24 | fast_fid: 25 | batch_size: 1000 26 | num_samples: 1000 27 | step_lr: 0.0000033 28 | n_steps_each: 5 29 | begin_ckpt: 5000 30 | end_ckpt: 210000 31 | verbose: false 32 | ensemble: false 33 | 34 | test: 35 | begin_ckpt: 5000 36 | end_ckpt: 210000 37 | batch_size: 100 38 | 39 | data: 40 | dataset: "CELEBA" 41 | image_size: 64 42 | channels: 3 43 | logit_transform: false 44 | uniform_dequantization: false 45 | gaussian_dequantization: false 46 | random_flip: true 47 | rescaled: false 48 | num_workers: 32 49 | 50 | model: 51 | sigma_begin: 90 52 | num_classes: 500 53 | ema: true 54 | ema_rate: 0.999 55 | spec_norm: false 56 | sigma_dist: geometric 57 | sigma_end: 0.01 58 | normalization: InstanceNorm++ 59 | nonlinearity: elu 60 | ngf: 128 61 | 62 | optim: 63 | weight_decay: 0.000 64 | optimizer: "Adam" 65 | lr: 0.0001 66 | beta1: 0.9 67 | amsgrad: false 68 | eps: 
0.00000001 69 | -------------------------------------------------------------------------------- /configs/tower.yml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 128 3 | n_epochs: 500000 4 | n_iters: 150001 5 | snapshot_freq: 5000 6 | snapshot_sampling: true 7 | anneal_power: 2 8 | log_all_sigmas: false 9 | 10 | sampling: 11 | batch_size: 6 12 | data_init: false 13 | step_lr: 0.0000018 14 | n_steps_each: 3 15 | ckpt_id: 150000 16 | final_only: false 17 | fid: false 18 | denoise: true 19 | num_samples4fid: 10000 20 | inpainting: false 21 | interpolation: false 22 | n_interpolations: 10 23 | 24 | fast_fid: 25 | batch_size: 1000 26 | num_samples: 1000 27 | step_lr: 0.0000018 28 | n_steps_each: 3 29 | begin_ckpt: 100000 30 | end_ckpt: 150000 31 | verbose: false 32 | ensemble: false 33 | 34 | test: 35 | begin_ckpt: 5000 36 | end_ckpt: 150000 37 | batch_size: 100 38 | 39 | data: 40 | dataset: "LSUN" 41 | category: "tower" 42 | image_size: 128 43 | channels: 3 44 | logit_transform: false 45 | uniform_dequantization: false 46 | gaussian_dequantization: false 47 | random_flip: true 48 | rescaled: false 49 | num_workers: 32 50 | 51 | model: 52 | sigma_begin: 190 53 | num_classes: 1086 54 | ema: true 55 | ema_rate: 0.999 56 | spec_norm: false 57 | sigma_dist: geometric 58 | sigma_end: 0.01 59 | normalization: InstanceNorm++ 60 | nonlinearity: elu 61 | ngf: 128 62 | 63 | optim: 64 | weight_decay: 0.000 65 | optimizer: "Adam" 66 | lr: 0.0001 67 | beta1: 0.9 68 | amsgrad: false 69 | eps: 0.00000001 70 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.transforms as transforms 4 | from torchvision.datasets import LSUN 5 | from datasets.celeba import CelebA 6 | from torch.utils.data import Subset 7 | import numpy as np 8 | 9 | def get_dataset(args, config): 10 | if config.data.random_flip is False: 11 | tran_transform = test_transform = transforms.Compose([ 12 | transforms.Resize(config.data.image_size), 13 | transforms.ToTensor() 14 | ]) 15 | else: 16 | tran_transform = transforms.Compose([ 17 | transforms.Resize(config.data.image_size), 18 | transforms.RandomHorizontalFlip(p=0.5), 19 | transforms.ToTensor() 20 | ]) 21 | test_transform = transforms.Compose([ 22 | transforms.Resize(config.data.image_size), 23 | transforms.ToTensor() 24 | ]) 25 | 26 | if config.data.dataset == 'CELEBA': 27 | if config.data.random_flip: 28 | dataset = CelebA(root=os.path.join(args.exp, 'datasets'), split='test', 29 | transform=transforms.Compose([ 30 | transforms.CenterCrop(140), 31 | transforms.Resize(config.data.image_size), 32 | transforms.RandomHorizontalFlip(), 33 | transforms.ToTensor(), 34 | ]), download=False) 35 | else: 36 | dataset = CelebA(root=os.path.join(args.exp, 'datasets'), split='test', 37 | transform=transforms.Compose([ 38 | transforms.CenterCrop(140), 39 | transforms.Resize(config.data.image_size), 40 | transforms.ToTensor(), 41 | ]), download=False) 42 | 43 | elif config.data.dataset == 'LSUN': 44 | train_folder = '{}_train'.format(config.data.category) 45 | val_folder = '{}_val'.format(config.data.category) 46 | 47 | dataset = LSUN(root=os.path.join(args.exp, 'datasets', 'lsun'), classes=[val_folder], 48 | transform=transforms.Compose([ 49 | transforms.Resize(config.data.image_size), 50 | transforms.CenterCrop(config.data.image_size), 51 | 
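# Resize scales the shorter edge of each LSUN image to image_size, and the
# CenterCrop above then keeps the central image_size x image_size square.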
transforms.ToTensor(),
52 |                        ]))
53 |
54 |     return dataset
55 |
56 | def logit_transform(image, lam=1e-6):
57 |     image = lam + (1 - 2 * lam) * image
58 |     return torch.log(image) - torch.log1p(-image)
59 |
60 | def data_transform(config, X):
61 |     if config.data.uniform_dequantization:
62 |         X = X / 256. * 255. + torch.rand_like(X) / 256.
63 |     if config.data.gaussian_dequantization:
64 |         X = X + torch.randn_like(X) * 0.01
65 |
66 |     if config.data.rescaled:
67 |         X = 2 * X - 1.
68 |     elif config.data.logit_transform:
69 |         X = logit_transform(X)
70 |
71 |     if hasattr(config, 'image_mean'):
72 |         return X - config.image_mean.to(X.device)[None, ...]
73 |
74 |     return X
75 |
76 | def inverse_data_transform(config, X):
77 |     if hasattr(config, 'image_mean'):
78 |         X = X + config.image_mean.to(X.device)[None, ...]
79 |
80 |     if config.data.logit_transform:
81 |         X = torch.sigmoid(X)
82 |     elif config.data.rescaled:
83 |         X = (X + 1.) / 2.
84 |
85 |     return torch.clamp(X, 0.0, 1.0)
86 |
--------------------------------------------------------------------------------
/datasets/celeba.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import PIL
4 | from .vision import VisionDataset
5 | from .utils import download_file_from_google_drive, check_integrity
6 |
7 |
8 | class CelebA(VisionDataset):
9 |     """`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.
10 |
11 |     Args:
12 |         root (string): Root directory where images are downloaded to.
13 |         split (string): One of {'train', 'valid', 'test'}.
14 |             The dataset split is selected accordingly.
15 |         target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
16 |             or ``landmarks``. Can also be a list to output a tuple with all specified target types.
17 |             The targets represent:
18 |                 ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes
19 |                 ``identity`` (int): label for each person (data points with the same identity are the same person)
20 |                 ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height)
21 |                 ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
22 |                 righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)
23 |             Defaults to ``attr``.
24 |         transform (callable, optional): A function/transform that takes in a PIL image
25 |             and returns a transformed version. E.g., ``transforms.ToTensor``
26 |         target_transform (callable, optional): A function/transform that takes in the
27 |             target and transforms it.
28 |         download (bool, optional): If true, downloads the dataset from the internet and
29 |             puts it in root directory. If dataset is already downloaded, it is not
30 |             downloaded again.
31 |     """
32 |
33 |     base_folder = "celeba"
34 |     # There currently does not appear to be an easy way to extract 7z in Python (without introducing additional
35 |     # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
36 |     # right now.
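    # Each tuple in file_list below is (Google Drive file id, MD5 checksum of the download,
    # filename under root/celeba/), matching the arguments of download_file_from_google_drive.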
37 | file_list = [ 38 | # File ID MD5 Hash Filename 39 | ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"), 40 | # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"), 41 | # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"), 42 | ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"), 43 | ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"), 44 | ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"), 45 | ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"), 46 | # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"), 47 | ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"), 48 | ] 49 | 50 | def __init__(self, root, 51 | split="train", 52 | target_type="attr", 53 | transform=None, target_transform=None, 54 | download=False): 55 | import pandas 56 | super(CelebA, self).__init__(root) 57 | self.split = split 58 | if isinstance(target_type, list): 59 | self.target_type = target_type 60 | else: 61 | self.target_type = [target_type] 62 | self.transform = transform 63 | self.target_transform = target_transform 64 | 65 | '''if download: 66 | self.download() 67 | 68 | if not self._check_integrity(): 69 | raise RuntimeError('Dataset not found or corrupted.' + 70 | ' You can use download=True to download it')''' 71 | 72 | self.transform = transform 73 | self.target_transform = target_transform 74 | 75 | if split.lower() == "train": 76 | split = 0 77 | elif split.lower() == "valid": 78 | split = 1 79 | elif split.lower() == "test": 80 | split = 2 81 | else: 82 | raise ValueError('Wrong split entered! 
Please use split="train" ' 83 | 'or split="valid" or split="test"') 84 | 85 | with open(os.path.join(self.root, self.base_folder, "list_eval_partition.txt"), "r") as f: 86 | splits = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0) 87 | 88 | with open(os.path.join(self.root, self.base_folder, "identity_CelebA.txt"), "r") as f: 89 | self.identity = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0) 90 | 91 | with open(os.path.join(self.root, self.base_folder, "list_bbox_celeba.txt"), "r") as f: 92 | self.bbox = pandas.read_csv(f, delim_whitespace=True, header=1, index_col=0) 93 | 94 | with open(os.path.join(self.root, self.base_folder, "list_landmarks_align_celeba.txt"), "r") as f: 95 | self.landmarks_align = pandas.read_csv(f, delim_whitespace=True, header=1) 96 | 97 | with open(os.path.join(self.root, self.base_folder, "list_attr_celeba.txt"), "r") as f: 98 | self.attr = pandas.read_csv(f, delim_whitespace=True, header=1) 99 | 100 | mask = (splits[1] == split) 101 | self.filename = splits[mask].index.values 102 | self.identity = torch.as_tensor(self.identity[mask].values) 103 | self.bbox = torch.as_tensor(self.bbox[mask].values) 104 | self.landmarks_align = torch.as_tensor(self.landmarks_align[mask].values) 105 | self.attr = torch.as_tensor(self.attr[mask].values) 106 | self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1} 107 | 108 | def _check_integrity(self): 109 | for (_, md5, filename) in self.file_list: 110 | fpath = os.path.join(self.root, self.base_folder, filename) 111 | _, ext = os.path.splitext(filename) 112 | # Allow original archive to be deleted (zip and 7z) 113 | # Only need the extracted images 114 | if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5): 115 | return False 116 | 117 | # Should check a hash of the images 118 | return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba")) 119 | 120 | def download(self): 121 | import zipfile 122 | 123 | if self._check_integrity(): 124 | print('Files already downloaded and verified') 125 | return 126 | 127 | for (file_id, md5, filename) in self.file_list: 128 | download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5) 129 | 130 | with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f: 131 | f.extractall(os.path.join(self.root, self.base_folder)) 132 | 133 | def __getitem__(self, index): 134 | X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index])) 135 | 136 | target = [] 137 | for t in self.target_type: 138 | if t == "attr": 139 | target.append(self.attr[index, :]) 140 | elif t == "identity": 141 | target.append(self.identity[index, 0]) 142 | elif t == "bbox": 143 | target.append(self.bbox[index, :]) 144 | elif t == "landmarks": 145 | target.append(self.landmarks_align[index, :]) 146 | else: 147 | raise ValueError("Target type \"{}\" is not recognized.".format(t)) 148 | target = tuple(target) if len(target) > 1 else target[0] 149 | 150 | if self.transform is not None: 151 | X = self.transform(X) 152 | 153 | if self.target_transform is not None: 154 | target = self.target_transform(target) 155 | 156 | return X, target 157 | 158 | def __len__(self): 159 | return len(self.attr) 160 | 161 | def extra_repr(self): 162 | lines = ["Target type: {target_type}", "Split: {split}"] 163 | return '\n'.join(lines).format(**self.__dict__) 164 | -------------------------------------------------------------------------------- 
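For reference, a minimal sketch of how this class is used (mirroring `get_dataset` in `datasets/__init__.py`, and assuming the CelebA files were extracted under `exp/datasets/celeba/`):

```python
import torchvision.transforms as transforms
from datasets.celeba import CelebA

dataset = CelebA(root='exp/datasets', split='test',
                 transform=transforms.Compose([
                     transforms.CenterCrop(140),  # crop the 178x218 aligned faces
                     transforms.Resize(64),       # down to the 64x64 used by configs/celeba.yml
                     transforms.ToTensor(),
                 ]), download=False)

x, attr = dataset[0]  # x: (3, 64, 64) float tensor in [0, 1]; attr: 40 binary labels
```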
/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import hashlib 4 | import errno 5 | from torch.utils.model_zoo import tqdm 6 | 7 | 8 | def gen_bar_updater(): 9 | pbar = tqdm(total=None) 10 | 11 | def bar_update(count, block_size, total_size): 12 | if pbar.total is None and total_size: 13 | pbar.total = total_size 14 | progress_bytes = count * block_size 15 | pbar.update(progress_bytes - pbar.n) 16 | 17 | return bar_update 18 | 19 | 20 | def check_integrity(fpath, md5=None): 21 | if md5 is None: 22 | return True 23 | if not os.path.isfile(fpath): 24 | return False 25 | md5o = hashlib.md5() 26 | with open(fpath, 'rb') as f: 27 | # read in 1MB chunks 28 | for chunk in iter(lambda: f.read(1024 * 1024), b''): 29 | md5o.update(chunk) 30 | md5c = md5o.hexdigest() 31 | if md5c != md5: 32 | return False 33 | return True 34 | 35 | 36 | def makedir_exist_ok(dirpath): 37 | """ 38 | Python2 support for os.makedirs(.., exist_ok=True) 39 | """ 40 | try: 41 | os.makedirs(dirpath) 42 | except OSError as e: 43 | if e.errno == errno.EEXIST: 44 | pass 45 | else: 46 | raise 47 | 48 | 49 | def download_url(url, root, filename=None, md5=None): 50 | """Download a file from a url and place it in root. 51 | 52 | Args: 53 | url (str): URL to download file from 54 | root (str): Directory to place downloaded file in 55 | filename (str, optional): Name to save the file under. If None, use the basename of the URL 56 | md5 (str, optional): MD5 checksum of the download. If None, do not check 57 | """ 58 | from six.moves import urllib 59 | 60 | root = os.path.expanduser(root) 61 | if not filename: 62 | filename = os.path.basename(url) 63 | fpath = os.path.join(root, filename) 64 | 65 | makedir_exist_ok(root) 66 | 67 | # downloads file 68 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 69 | print('Using downloaded and verified file: ' + fpath) 70 | else: 71 | try: 72 | print('Downloading ' + url + ' to ' + fpath) 73 | urllib.request.urlretrieve( 74 | url, fpath, 75 | reporthook=gen_bar_updater() 76 | ) 77 | except OSError: 78 | if url[:5] == 'https': 79 | url = url.replace('https:', 'http:') 80 | print('Failed download. Trying https -> http instead.' 81 | ' Downloading ' + url + ' to ' + fpath) 82 | urllib.request.urlretrieve( 83 | url, fpath, 84 | reporthook=gen_bar_updater() 85 | ) 86 | 87 | 88 | def list_dir(root, prefix=False): 89 | """List all directories at a given root 90 | 91 | Args: 92 | root (str): Path to directory whose folders need to be listed 93 | prefix (bool, optional): If true, prepends the path to each result, otherwise 94 | only returns the name of the directories found 95 | """ 96 | root = os.path.expanduser(root) 97 | directories = list( 98 | filter( 99 | lambda p: os.path.isdir(os.path.join(root, p)), 100 | os.listdir(root) 101 | ) 102 | ) 103 | 104 | if prefix is True: 105 | directories = [os.path.join(root, d) for d in directories] 106 | 107 | return directories 108 | 109 | 110 | def list_files(root, suffix, prefix=False): 111 | """List all files ending with a suffix at a given root 112 | 113 | Args: 114 | root (str): Path to directory whose folders need to be listed 115 | suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). 
116 | It uses the Python "str.endswith" method and is passed directly 117 | prefix (bool, optional): If true, prepends the path to each result, otherwise 118 | only returns the name of the files found 119 | """ 120 | root = os.path.expanduser(root) 121 | files = list( 122 | filter( 123 | lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix), 124 | os.listdir(root) 125 | ) 126 | ) 127 | 128 | if prefix is True: 129 | files = [os.path.join(root, d) for d in files] 130 | 131 | return files 132 | 133 | 134 | def download_file_from_google_drive(file_id, root, filename=None, md5=None): 135 | """Download a Google Drive file from and place it in root. 136 | 137 | Args: 138 | file_id (str): id of file to be downloaded 139 | root (str): Directory to place downloaded file in 140 | filename (str, optional): Name to save the file under. If None, use the id of the file. 141 | md5 (str, optional): MD5 checksum of the download. If None, do not check 142 | """ 143 | # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url 144 | import requests 145 | url = "https://docs.google.com/uc?export=download" 146 | 147 | root = os.path.expanduser(root) 148 | if not filename: 149 | filename = file_id 150 | fpath = os.path.join(root, filename) 151 | 152 | makedir_exist_ok(root) 153 | 154 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 155 | print('Using downloaded and verified file: ' + fpath) 156 | else: 157 | session = requests.Session() 158 | 159 | response = session.get(url, params={'id': file_id}, stream=True) 160 | token = _get_confirm_token(response) 161 | 162 | if token: 163 | params = {'id': file_id, 'confirm': token} 164 | response = session.get(url, params=params, stream=True) 165 | 166 | _save_response_content(response, fpath) 167 | 168 | 169 | def _get_confirm_token(response): 170 | for key, value in response.cookies.items(): 171 | if key.startswith('download_warning'): 172 | return value 173 | 174 | return None 175 | 176 | 177 | def _save_response_content(response, destination, chunk_size=32768): 178 | with open(destination, "wb") as f: 179 | pbar = tqdm(total=None) 180 | progress = 0 181 | for chunk in response.iter_content(chunk_size): 182 | if chunk: # filter out keep-alive new chunks 183 | f.write(chunk) 184 | progress += len(chunk) 185 | pbar.update(progress - pbar.n) 186 | pbar.close() 187 | -------------------------------------------------------------------------------- /datasets/vision.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | 5 | 6 | class VisionDataset(data.Dataset): 7 | _repr_indent = 4 8 | 9 | def __init__(self, root): 10 | if isinstance(root, torch._six.string_classes): 11 | root = os.path.expanduser(root) 12 | self.root = root 13 | 14 | def __getitem__(self, index): 15 | raise NotImplementedError 16 | 17 | def __len__(self): 18 | raise NotImplementedError 19 | 20 | def __repr__(self): 21 | head = "Dataset " + self.__class__.__name__ 22 | body = ["Number of datapoints: {}".format(self.__len__())] 23 | if self.root is not None: 24 | body.append("Root location: {}".format(self.root)) 25 | body += self.extra_repr().splitlines() 26 | if hasattr(self, 'transform') and self.transform is not None: 27 | body += self._format_transform_repr(self.transform, 28 | "Transforms: ") 29 | if hasattr(self, 'target_transform') and self.target_transform is not None: 30 | body += 
self._format_transform_repr(self.target_transform,
31 |                                                "Target transforms: ")
32 |         lines = [head] + [" " * self._repr_indent + line for line in body]
33 |         return '\n'.join(lines)
34 |
35 |     def _format_transform_repr(self, transform, head):
36 |         lines = transform.__repr__().splitlines()
37 |         return (["{}{}".format(head, lines[0])] +
38 |                 ["{}{}".format(" " * len(head), line) for line in lines[1:]])
39 |
40 |     def extra_repr(self):
41 |         return ""
42 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: snips
2 | channels:
3 |   - defaults
4 |   - pytorch
5 |   - conda-forge
6 | dependencies:
7 |   - python=3.7.7
8 |   - pytorch=1.7.0
9 |   - scipy
10 |   - torchvision
11 |   - hdf5
12 |   - tqdm
13 |   - h5py
14 |   - cudatoolkit=11.0
15 |   - pyyaml
16 |   - numpy
17 |   - tensorboard
18 |   - pandas
19 |   - python-lmdb
20 |
--------------------------------------------------------------------------------
/filter_builder.py:
--------------------------------------------------------------------------------
1 | # source: https://github.com/alisaaalehi/convolution_as_multiplication/blob/master/Convolution_as_multiplication.ipynb
2 | import numpy as np
3 | from scipy.linalg import toeplitz
4 |
5 |
6 | def matrix_to_vector(input):
7 |     """
8 |     Converts the input matrix to a vector by stacking the rows bottom-up (last row first)
9 |
10 |     Arg:
11 |         input -- a numpy matrix
12 |
13 |     Returns:
14 |         output_vector -- a column vector with size input.shape[0]*input.shape[1]
15 |     """
16 |     input_h, input_w = input.shape
17 |     output_vector = np.zeros(input_h*input_w, dtype=input.dtype)
18 |     # flip the input matrix up-down because the last row should go first
19 |     input = np.flipud(input)
20 |     for i, row in enumerate(input):
21 |         st = i*input_w
22 |         nd = st + input_w
23 |         output_vector[st:nd] = row
24 |     return output_vector
25 |
26 |
27 | def vector_to_matrix(input, output_shape):
28 |     """
29 |     Reshapes the output of the matrix multiplication to the shape "output_shape"
30 |
31 |     Arg:
32 |         input -- a numpy vector
33 |
34 |     Returns:
35 |         output -- numpy matrix with shape "output_shape"
36 |     """
37 |     output_h, output_w = output_shape
38 |     output = np.zeros(output_shape, dtype=input.dtype)
39 |     for i in range(output_h):
40 |         st = i*output_w
41 |         nd = st + output_w
42 |         output[i, :] = input[st:nd]
43 |     # flip the output matrix up-down to get the correct result
44 |     output = np.flipud(output)
45 |     return output
46 |
47 |
48 | def convolution_as_maultiplication(I, F, print_ir=False):
49 |     """
50 |     Performs 2D convolution between input I and filter F by converting F to a doubly blocked
51 |     Toeplitz matrix and multiplying it with a vectorized version of I
52 |     By : AliSaaalehi@gmail.com
53 |
54 |     Arg:
55 |         I -- 2D numpy matrix
56 |         F -- numpy 2D matrix
57 |         print_ir -- if True, all intermediate results will be printed after each step of the algorithm
58 |
59 |     Returns:
60 |         output -- 2D numpy matrix, result of convolving I with F
61 |     """
62 |     # number of columns and rows of the input
63 |     I_row_num, I_col_num = I.shape
64 |
65 |     # number of columns and rows of the filter
66 |     F_row_num, F_col_num = F.shape
67 |
68 |     # calculate the output dimensions
69 |     output_row_num = I_row_num + F_row_num - 1
70 |     output_col_num = I_col_num + F_col_num - 1
71 |     if print_ir: print('output dimension:', output_row_num, output_col_num)
72 |
73 |     # zero pad the filter
74 |     F_zero_padded = np.pad(F, ((output_row_num - F_row_num, 0),
75 |                                (0,
output_col_num - F_col_num)),
76 |                            'constant', constant_values=0)
77 |     if print_ir: print('F_zero_padded: ', F_zero_padded)
78 |
79 |     # use each row of the zero-padded F to create a Toeplitz matrix.
80 |     # The number of columns in these matrices equals the number of columns of the input signal
81 |     toeplitz_list = []
82 |     for i in range(F_zero_padded.shape[0]-1, -1, -1):  # iterate from the last row to the first row
83 |         c = F_zero_padded[i, :]  # i-th row of F
84 |         r = np.r_[c[0], np.zeros(I_col_num-1)]  # the first row for the toeplitz function should be defined,
85 |                                                 # otherwise the result is wrong
86 |         toeplitz_m = toeplitz(c, r)  # this function is in the scipy.linalg library
87 |         toeplitz_list.append(toeplitz_m)
88 |         if print_ir: print('F ' + str(i) + '\n', toeplitz_m)
89 |
90 |     # doubly blocked toeplitz indices:
91 |     # this matrix defines which toeplitz matrix from toeplitz_list goes to which part of the doubly blocked matrix
92 |     c = range(1, F_zero_padded.shape[0]+1)
93 |     r = np.r_[c[0], np.zeros(I_row_num-1, dtype=int)]
94 |     doubly_indices = toeplitz(c, r)
95 |     if print_ir: print('doubly indices \n', doubly_indices)
96 |
97 |     ## create the doubly blocked matrix with zero values
98 |     toeplitz_shape = toeplitz_list[0].shape  # shape of one toeplitz matrix
99 |     h = toeplitz_shape[0]*doubly_indices.shape[0]
100 |     w = toeplitz_shape[1]*doubly_indices.shape[1]
101 |     doubly_blocked_shape = [h, w]
102 |     doubly_blocked = np.zeros(doubly_blocked_shape)
103 |
104 |     # tile toeplitz matrices for each row in the doubly blocked matrix
105 |     b_h, b_w = toeplitz_shape  # height and width of each block
106 |     for i in range(doubly_indices.shape[0]):
107 |         for j in range(doubly_indices.shape[1]):
108 |             start_i = i * b_h
109 |             start_j = j * b_w
110 |             end_i = start_i + b_h
111 |             end_j = start_j + b_w
112 |             doubly_blocked[start_i: end_i, start_j:end_j] = toeplitz_list[doubly_indices[i,j]-1]
113 |
114 |     if print_ir: print('doubly_blocked: ', doubly_blocked)
115 |
116 |     # convert I to a vector
117 |     vectorized_I = matrix_to_vector(I)
118 |     if print_ir: print('vectorized_I: ', vectorized_I)
119 |
120 |     # get the result of the convolution by matrix multiplication
121 |     result_vector = np.matmul(doubly_blocked, vectorized_I)
122 |     if print_ir: print('result_vector: ', result_vector)
123 |
124 |     # reshape the raw result to the desired matrix form
125 |     out_shape = [output_row_num, output_col_num]
126 |     output = vector_to_matrix(result_vector, out_shape)
127 |     if print_ir: print('Result of implemented method: \n', output)
128 |
129 |     return output
130 |
131 | def kernel_to_matrix(F, I_row_num, I_col_num):
132 |     """
133 |     Arg:
134 |         F -- numpy 2D matrix - kernel of the blur
135 |         I_row_num - number of rows in the signal
136 |         I_col_num - number of cols in the signal
137 |
138 |     Returns:
139 |         output -- 2D numpy matrix, which operates on a vectorized signal I of size I_row_num x I_col_num
140 |     """
141 |
142 |     # number of columns and rows of the filter
143 |     F_row_num, F_col_num = F.shape
144 |
145 |     # calculate the output dimensions
146 |     output_row_num = I_row_num + F_row_num - 1
147 |     output_col_num = I_col_num + F_col_num - 1
148 |     #if print_ir: print('output dimension:', output_row_num, output_col_num)
149 |
150 |     # zero pad the filter
151 |     F_zero_padded = np.pad(F, ((output_row_num - F_row_num, 0),
152 |                                (0, output_col_num - F_col_num)),
153 |                            'constant', constant_values=0)
154 |     #if print_ir: print('F_zero_padded: ', F_zero_padded)
155 |
156 |     # use each row of the zero-padded F to create a Toeplitz matrix.
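    # Shape note: for an I_row_num x I_col_num signal and an F_row_num x F_col_num kernel, the
    # doubly blocked matrix built below is (output_row_num*output_col_num) x (I_row_num*I_col_num):
    # it maps a vectorized input to the vectorized 'full' convolution output.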
157 |     # The number of columns in these matrices equals the number of columns of the input signal
158 |     toeplitz_list = []
159 |     for i in range(F_zero_padded.shape[0]-1, -1, -1):  # iterate from the last row to the first row
160 |         c = F_zero_padded[i, :]  # i-th row of F
161 |         r = np.r_[c[0], np.zeros(I_col_num-1)]  # the first row for the toeplitz function should be defined,
162 |                                                 # otherwise the result is wrong
163 |         toeplitz_m = toeplitz(c, r)  # this function is in the scipy.linalg library
164 |         toeplitz_list.append(toeplitz_m)
165 |         #if print_ir: print('F ' + str(i) + '\n', toeplitz_m)
166 |
167 |     # doubly blocked toeplitz indices:
168 |     # this matrix defines which toeplitz matrix from toeplitz_list goes to which part of the doubly blocked matrix
169 |     c = range(1, F_zero_padded.shape[0]+1)
170 |     r = np.r_[c[0], np.zeros(I_row_num-1, dtype=int)]
171 |     doubly_indices = toeplitz(c, r)
172 |     #if print_ir: print('doubly indices \n', doubly_indices)
173 |
174 |     ## create the doubly blocked matrix with zero values
175 |     toeplitz_shape = toeplitz_list[0].shape  # shape of one toeplitz matrix
176 |     h = toeplitz_shape[0]*doubly_indices.shape[0]
177 |     w = toeplitz_shape[1]*doubly_indices.shape[1]
178 |     doubly_blocked_shape = [h, w]
179 |     doubly_blocked = np.zeros(doubly_blocked_shape)
180 |
181 |     # tile toeplitz matrices for each row in the doubly blocked matrix
182 |     b_h, b_w = toeplitz_shape  # height and width of each block
183 |     for i in range(doubly_indices.shape[0]):
184 |         for j in range(doubly_indices.shape[1]):
185 |             start_i = i * b_h
186 |             start_j = j * b_w
187 |             end_i = start_i + b_h
188 |             end_j = start_j + b_w
189 |             doubly_blocked[start_i: end_i, start_j:end_j] = toeplitz_list[doubly_indices[i,j]-1]
190 |
191 |     return doubly_blocked
192 |
193 | def get_custom_kernel(type="gauss", dim=64):
194 |     kernel = 0
195 |     if type == "14641":
196 |         kernel = np.array([[1, 4, 6, 4, 1]])
197 |         kernel = np.matmul(kernel.transpose(1, 0), kernel) / 256.0
198 |     elif type == "uniform":
199 |         kernel = np.array([[1, 1, 1, 1, 1]])
200 |         kernel = np.matmul(kernel.transpose(1, 0), kernel) / 25.0
201 |     elif type == "gauss":
202 |         # gaussian kernel with sigma 10
203 |         kernel = np.array([[0.03920520445985253,0.03979771524812676,0.0399972021259645,0.03979771524812676,0.03920520445985253],
204 |                            [0.03979771524812676,0.04039918069022969,0.04060168242614218,0.04039918069022969,0.03979771524812676],
205 |                            [0.0399972021259645,0.04060168242614218,0.04080519920622999,0.04060168242614218,0.0399972021259645],
206 |                            [0.03979771524812676,0.04039918069022969,0.04060168242614218,0.04039918069022969,0.03979771524812676],
207 |                            [0.03920520445985253,0.03979771524812676,0.0399972021259645,0.03979771524812676,0.03920520445985253]])
208 |     H = kernel_to_matrix(kernel, dim, dim)
209 |     H = H[[row*(dim+4)+col for row in range(2, dim+2) for col in range(2, dim+2)], :]  # keep only the central 'same'-convolution rows
210 |     return H
--------------------------------------------------------------------------------
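`main.py`, listed next, parses the YAML config with `dict2namespace` so that nested keys become attribute access. A minimal sketch of the result, assuming it is run from the repo root inside the `snips` environment:

```python
import yaml
from main import dict2namespace

with open('configs/celeba.yml', 'r') as f:
    config = dict2namespace(yaml.load(f, Loader=yaml.FullLoader))

print(config.model.sigma_begin)  # 90  -- the largest noise level
print(config.model.num_classes)  # 500 -- the number of noise levels
print(config.data.image_size)    # 64
```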
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import traceback
3 | import time
4 | import shutil
5 | import logging
6 | import yaml
7 | import sys
8 | import os
9 | import torch
10 | import numpy as np
11 | import torch.utils.tensorboard as tb
12 | import copy
13 | from runners import *
14 |
15 |
16 |
17 | def parse_args_and_config():
18 |     parser = argparse.ArgumentParser(description=globals()['__doc__'])
19 |
20 |     parser.add_argument('--config', type=str, required=True, help='Path to the config file')
21 |     parser.add_argument('--seed', type=int, default=1234, help='Random seed')
22 |     parser.add_argument('--exp', type=str, default='exp', help='Path for saving running related data.')
23 |     parser.add_argument('--doc', type=str, required=True, help='A string for documentation purpose. '
24 |                                                                'Will be the name of the log folder.')
25 |     parser.add_argument('--comment', type=str, default='', help='A string for experiment comment')
26 |     parser.add_argument('--verbose', type=str, default='info', help='Verbose level: info | debug | warning | critical')
27 |     parser.add_argument('-i', '--image_folder', type=str, default='images', help="The folder name of samples")
28 |
29 |     parser.add_argument('-n', '--num_variations', type=int, default=1, help='Number of variations to produce')
30 |     parser.add_argument('-s', '--sigma_0', type=float, default=0.1, help='Noise std to add to observation')
31 |     parser.add_argument('--degradation', type=str, default='sr4', help='Degradation: inp | deblur_uni | deblur_gauss | sr2 | sr4 | cs4 | cs8 | cs16')
32 |
33 |     args = parser.parse_args()
34 |     args.log_path = os.path.join(args.exp, 'logs', args.doc)
35 |
36 |     # parse config file
37 |     with open(os.path.join('configs', args.config), 'r') as f:
38 |         config = yaml.load(f, Loader=yaml.FullLoader)  # an explicit Loader is required by recent PyYAML versions
39 |     new_config = dict2namespace(config)
40 |
41 |     tb_path = os.path.join(args.exp, 'tensorboard', args.doc)
42 |
43 |     level = getattr(logging, args.verbose.upper(), None)
44 |     if not isinstance(level, int):
45 |         raise ValueError('level {} not supported'.format(args.verbose))
46 |
47 |     handler1 = logging.StreamHandler()
48 |     formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s')
49 |     handler1.setFormatter(formatter)
50 |     logger = logging.getLogger()
51 |     logger.addHandler(handler1)
52 |     logger.setLevel(level)
53 |     overwrite = False  # initialized up front so declining the prompt below cannot raise a NameError
54 |     os.makedirs(os.path.join(args.exp, 'image_samples'), exist_ok=True)
55 |     args.image_folder = os.path.join(args.exp, 'image_samples', args.image_folder)
56 |     if not os.path.exists(args.image_folder):
57 |         os.makedirs(args.image_folder)
58 |     else:
59 |         response = input("Image folder already exists. Overwrite? (Y/N)")
60 |         if response.upper() == 'Y':
61 |             overwrite = True
62 |
63 |         if overwrite:
64 |             shutil.rmtree(args.image_folder)
65 |             os.makedirs(args.image_folder)
66 |         else:
67 |             print("Output image folder exists.
Program halted.") 68 | sys.exit(0) 69 | 70 | # add device 71 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 72 | logging.info("Using device: {}".format(device)) 73 | new_config.device = device 74 | 75 | # set random seed 76 | torch.manual_seed(args.seed) 77 | np.random.seed(args.seed) 78 | if torch.cuda.is_available(): 79 | torch.cuda.manual_seed_all(args.seed) 80 | 81 | torch.backends.cudnn.benchmark = True 82 | 83 | return args, new_config 84 | 85 | 86 | def dict2namespace(config): 87 | namespace = argparse.Namespace() 88 | for key, value in config.items(): 89 | if isinstance(value, dict): 90 | new_value = dict2namespace(value) 91 | else: 92 | new_value = value 93 | setattr(namespace, key, new_value) 94 | return namespace 95 | 96 | 97 | def main(): 98 | args, config = parse_args_and_config() 99 | logging.info("Writing log file to {}".format(args.log_path)) 100 | logging.info("Exp instance id = {}".format(os.getpid())) 101 | logging.info("Exp comment = {}".format(args.comment)) 102 | logging.info("Config =") 103 | print(">" * 80) 104 | config_dict = copy.copy(vars(config)) 105 | print(yaml.dump(config_dict, default_flow_style=False)) 106 | print("<" * 80) 107 | 108 | try: 109 | runner = NCSNRunner(args, config) 110 | runner.sample() 111 | except: 112 | logging.error(traceback.format_exc()) 113 | 114 | return 0 115 | 116 | 117 | if __name__ == '__main__': 118 | sys.exit(main()) 119 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import tqdm 4 | 5 | def get_sigmas(config): 6 | if config.model.sigma_dist == 'geometric': 7 | sigmas = torch.tensor( 8 | np.exp(np.linspace(np.log(config.model.sigma_begin), np.log(config.model.sigma_end), 9 | config.model.num_classes))).float().to(config.device) 10 | elif config.model.sigma_dist == 'uniform': 11 | sigmas = torch.tensor( 12 | np.linspace(config.model.sigma_begin, config.model.sigma_end, config.model.num_classes) 13 | ).float().to(config.device) 14 | 15 | else: 16 | raise NotImplementedError('sigma distribution not supported') 17 | 18 | return sigmas 19 | 20 | def mat_by_vec(M, v): 21 | vshape = v.shape[2] 22 | if len(v.shape) > 3: vshape = vshape * v.shape[3] 23 | return torch.matmul(M, v.view(v.shape[0] * v.shape[1], vshape, 24 | 1)).view(v.shape[0], v.shape[1], M.shape[0]) 25 | 26 | def vec_to_image(v, img_dim): 27 | return v.view(v.shape[0], v.shape[1], img_dim, img_dim) 28 | 29 | def invert_diag(M): 30 | M_inv = torch.zeros_like(M) 31 | M_inv[M != 0] = 1 / M[M != 0] 32 | return M_inv 33 | 34 | 35 | @torch.no_grad() 36 | def general_anneal_Langevin_dynamics(H, y_0, x_mod, scorenet, sigmas, n_steps_each=200, step_lr=0.000008, 37 | final_only=False, verbose=False, denoise=True, c_begin = 0, sigma_0 = 1): 38 | U, singulars, V = torch.svd(H, some=False) 39 | V_t = V.transpose(0, 1) 40 | 41 | ZERO = 1e-3 42 | singulars[singulars < ZERO] = 0 43 | 44 | Sigma = torch.zeros_like(H) 45 | for i in range(singulars.shape[0]): Sigma[i, i] = singulars[i] 46 | S_1, S_n = singulars[0], singulars[-1] 47 | 48 | S_S_t = torch.zeros_like(U) 49 | for i in range(singulars.shape[0]): S_S_t[i, i] = singulars[i] ** 2 50 | 51 | num_missing = V.shape[0] - torch.count_nonzero(singulars) 52 | 53 | s0_2_I = ((sigma_0 ** 2) * torch.eye(U.shape[0])).to(x_mod.device) 54 | 55 | V_t_x = mat_by_vec(V_t, x_mod) 56 | U_t_y = mat_by_vec(U.transpose(0,1), y_0) 57 | 58 
| img_dim = x_mod.shape[2] 59 | 60 | images = [] 61 | 62 | with torch.no_grad(): 63 | for c, sigma in tqdm.tqdm(enumerate(sigmas), total=len(sigmas), desc='general annealed Langevin sampling'): 64 | 65 | labels = torch.ones(x_mod.shape[0], device=x_mod.device) * (c + c_begin) 66 | labels = labels.long() 67 | step_size = step_lr * ((1 / sigmas[-1]) ** 2) 68 | 69 | falses = torch.zeros(V_t_x.shape[2] - singulars.shape[0], dtype=torch.bool, device=x_mod.device) 70 | cond_before_lite = singulars * sigma > sigma_0 71 | cond_after_lite = singulars * sigma < sigma_0 72 | cond_before = torch.hstack((cond_before_lite, falses)) 73 | cond_after = torch.hstack((cond_after_lite, falses)) 74 | 75 | step_vector = torch.zeros_like(V_t_x) 76 | step_vector[:, :, :] = step_size * (sigma**2) 77 | step_vector[:, :, cond_before] = step_size * ((sigma**2) - (sigma_0 / singulars[cond_before_lite])**2) 78 | step_vector[:, :, cond_after] = step_size * (sigma**2) * (1 - (singulars[cond_after_lite] * sigma / sigma_0)**2) 79 | 80 | for s in range(n_steps_each): 81 | grad = torch.zeros_like(V_t_x) 82 | score = mat_by_vec(V_t, scorenet(x_mod, labels)) 83 | 84 | diag_mat = S_S_t * (sigma ** 2) - s0_2_I 85 | diag_mat[cond_after_lite, cond_after_lite] = diag_mat[cond_after_lite, cond_after_lite] * (-1) 86 | 87 | first_vec = U_t_y - mat_by_vec(Sigma, V_t_x) 88 | cond_grad = mat_by_vec(invert_diag(diag_mat), first_vec) 89 | cond_grad = mat_by_vec(Sigma.transpose(0,1), cond_grad) 90 | grad = torch.zeros_like(cond_grad) 91 | grad[:, :, cond_before] = cond_grad[:, :, cond_before] 92 | grad[:, :, cond_after] = cond_grad[:, :, cond_after] + score[:, :, cond_after] 93 | grad[:, :, -num_missing:] = score[:, :, -num_missing:] 94 | 95 | noise = torch.randn_like(V_t_x) 96 | V_t_x = V_t_x + step_vector * grad + noise * torch.sqrt(step_vector * 2) 97 | x_mod = vec_to_image(mat_by_vec(V, V_t_x), img_dim) 98 | 99 | if not final_only: 100 | images.append(x_mod.to('cpu')) 101 | 102 | if final_only: 103 | return [x_mod.to('cpu')] 104 | else: 105 | return images -------------------------------------------------------------------------------- /models/ema.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch.nn as nn 3 | 4 | class EMAHelper(object): 5 | def __init__(self, mu=0.999): 6 | self.mu = mu 7 | self.shadow = {} 8 | 9 | def register(self, module): 10 | if isinstance(module, nn.DataParallel): 11 | module = module.module 12 | for name, param in module.named_parameters(): 13 | if param.requires_grad: 14 | self.shadow[name] = param.data.clone() 15 | 16 | def update(self, module): 17 | if isinstance(module, nn.DataParallel): 18 | module = module.module 19 | for name, param in module.named_parameters(): 20 | if param.requires_grad: 21 | self.shadow[name].data = (1. 
- self.mu) * param.data + self.mu * self.shadow[name].data 22 | 23 | def ema(self, module): 24 | if isinstance(module, nn.DataParallel): 25 | module = module.module 26 | for name, param in module.named_parameters(): 27 | if param.requires_grad: 28 | param.data.copy_(self.shadow[name].data) 29 | 30 | def ema_copy(self, module): 31 | if isinstance(module, nn.DataParallel): 32 | inner_module = module.module 33 | module_copy = type(inner_module)(inner_module.config).to(inner_module.config.device) 34 | module_copy.load_state_dict(inner_module.state_dict()) 35 | module_copy = nn.DataParallel(module_copy) 36 | else: 37 | module_copy = type(module)(module.config).to(module.config.device) 38 | module_copy.load_state_dict(module.state_dict()) 39 | # module_copy = copy.deepcopy(module) 40 | self.ema(module_copy) 41 | return module_copy 42 | 43 | def state_dict(self): 44 | return self.shadow 45 | 46 | def load_state_dict(self, state_dict): 47 | self.shadow = state_dict -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from torch.nn.parameter import Parameter 4 | import torch.nn.functional as F 5 | from .normalization import * 6 | from functools import partial 7 | import math 8 | import torch.nn.init as init 9 | 10 | 11 | def get_act(config): 12 | if config.model.nonlinearity.lower() == 'elu': 13 | return nn.ELU() 14 | elif config.model.nonlinearity.lower() == 'relu': 15 | return nn.ReLU() 16 | elif config.model.nonlinearity.lower() == 'lrelu': 17 | return nn.LeakyReLU(negative_slope=0.2) 18 | elif config.model.nonlinearity.lower() == 'swish': 19 | def swish(x): 20 | return x * torch.sigmoid(x) 21 | return swish 22 | else: 23 | raise NotImplementedError('activation function does not exist!') 24 | 25 | def spectral_norm(layer, n_iters=1): 26 | return torch.nn.utils.spectral_norm(layer, n_power_iterations=n_iters) 27 | 28 | def conv1x1(in_planes, out_planes, stride=1, bias=True, spec_norm=False): 29 | "1x1 convolution" 30 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 31 | padding=0, bias=bias) 32 | if spec_norm: 33 | conv = spectral_norm(conv) 34 | return conv 35 | 36 | 37 | def conv3x3(in_planes, out_planes, stride=1, bias=True, spec_norm=False): 38 | "3x3 convolution with padding" 39 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 40 | padding=1, bias=bias) 41 | if spec_norm: 42 | conv = spectral_norm(conv) 43 | 44 | return conv 45 | 46 | 47 | def stride_conv3x3(in_planes, out_planes, kernel_size, bias=True, spec_norm=False): 48 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, 49 | padding=kernel_size // 2, bias=bias) 50 | if spec_norm: 51 | conv = spectral_norm(conv) 52 | return conv 53 | 54 | 55 | def dilated_conv3x3(in_planes, out_planes, dilation, bias=True, spec_norm=False): 56 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=dilation, dilation=dilation, bias=bias) 57 | if spec_norm: 58 | conv = spectral_norm(conv) 59 | 60 | return conv 61 | 62 | class CRPBlock(nn.Module): 63 | def __init__(self, features, n_stages, act=nn.ReLU(), maxpool=True, spec_norm=False): 64 | super().__init__() 65 | self.convs = nn.ModuleList() 66 | for i in range(n_stages): 67 | self.convs.append(conv3x3(features, features, stride=1, bias=False, spec_norm=spec_norm)) 68 | self.n_stages = n_stages 69 | if maxpool: 70 | self.maxpool = 
nn.MaxPool2d(kernel_size=5, stride=1, padding=2) 71 | else: 72 | self.maxpool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2) 73 | 74 | self.act = act 75 | 76 | def forward(self, x): 77 | x = self.act(x) 78 | path = x 79 | for i in range(self.n_stages): 80 | path = self.maxpool(path) 81 | path = self.convs[i](path) 82 | x = path + x 83 | return x 84 | 85 | 86 | class CondCRPBlock(nn.Module): 87 | def __init__(self, features, n_stages, num_classes, normalizer, act=nn.ReLU(), spec_norm=False): 88 | super().__init__() 89 | self.convs = nn.ModuleList() 90 | self.norms = nn.ModuleList() 91 | self.normalizer = normalizer 92 | for i in range(n_stages): 93 | self.norms.append(normalizer(features, num_classes, bias=True)) 94 | self.convs.append(conv3x3(features, features, stride=1, bias=False, spec_norm=spec_norm)) 95 | 96 | self.n_stages = n_stages 97 | self.maxpool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2) 98 | self.act = act 99 | 100 | def forward(self, x, y): 101 | x = self.act(x) 102 | path = x 103 | for i in range(self.n_stages): 104 | path = self.norms[i](path, y) 105 | path = self.maxpool(path) 106 | path = self.convs[i](path) 107 | 108 | x = path + x 109 | return x 110 | 111 | 112 | class RCUBlock(nn.Module): 113 | def __init__(self, features, n_blocks, n_stages, act=nn.ReLU(), spec_norm=False): 114 | super().__init__() 115 | 116 | for i in range(n_blocks): 117 | for j in range(n_stages): 118 | setattr(self, '{}_{}_conv'.format(i + 1, j + 1), conv3x3(features, features, stride=1, bias=False, 119 | spec_norm=spec_norm)) 120 | 121 | self.stride = 1 122 | self.n_blocks = n_blocks 123 | self.n_stages = n_stages 124 | self.act = act 125 | 126 | def forward(self, x): 127 | for i in range(self.n_blocks): 128 | residual = x 129 | for j in range(self.n_stages): 130 | x = self.act(x) 131 | x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x) 132 | 133 | x += residual 134 | return x 135 | 136 | 137 | class CondRCUBlock(nn.Module): 138 | def __init__(self, features, n_blocks, n_stages, num_classes, normalizer, act=nn.ReLU(), spec_norm=False): 139 | super().__init__() 140 | 141 | for i in range(n_blocks): 142 | for j in range(n_stages): 143 | setattr(self, '{}_{}_norm'.format(i + 1, j + 1), normalizer(features, num_classes, bias=True)) 144 | setattr(self, '{}_{}_conv'.format(i + 1, j + 1), 145 | conv3x3(features, features, stride=1, bias=False, spec_norm=spec_norm)) 146 | 147 | self.stride = 1 148 | self.n_blocks = n_blocks 149 | self.n_stages = n_stages 150 | self.act = act 151 | self.normalizer = normalizer 152 | 153 | def forward(self, x, y): 154 | for i in range(self.n_blocks): 155 | residual = x 156 | for j in range(self.n_stages): 157 | x = getattr(self, '{}_{}_norm'.format(i + 1, j + 1))(x, y) 158 | x = self.act(x) 159 | x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x) 160 | 161 | x += residual 162 | return x 163 | 164 | 165 | class MSFBlock(nn.Module): 166 | def __init__(self, in_planes, features, spec_norm=False): 167 | """ 168 | :param in_planes: tuples of input planes 169 | """ 170 | super().__init__() 171 | assert isinstance(in_planes, list) or isinstance(in_planes, tuple) 172 | self.convs = nn.ModuleList() 173 | self.features = features 174 | 175 | for i in range(len(in_planes)): 176 | self.convs.append(conv3x3(in_planes[i], features, stride=1, bias=True, spec_norm=spec_norm)) 177 | 178 | def forward(self, xs, shape): 179 | sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device) 180 | for i in range(len(self.convs)): 181 | h = 
self.convs[i](xs[i]) 182 | h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True) 183 | sums += h 184 | return sums 185 | 186 | 187 | class CondMSFBlock(nn.Module): 188 | def __init__(self, in_planes, features, num_classes, normalizer, spec_norm=False): 189 | """ 190 | :param in_planes: tuples of input planes 191 | """ 192 | super().__init__() 193 | assert isinstance(in_planes, list) or isinstance(in_planes, tuple) 194 | 195 | self.convs = nn.ModuleList() 196 | self.norms = nn.ModuleList() 197 | self.features = features 198 | self.normalizer = normalizer 199 | 200 | for i in range(len(in_planes)): 201 | self.convs.append(conv3x3(in_planes[i], features, stride=1, bias=True, spec_norm=spec_norm)) 202 | self.norms.append(normalizer(in_planes[i], num_classes, bias=True)) 203 | 204 | def forward(self, xs, y, shape): 205 | sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device) 206 | for i in range(len(self.convs)): 207 | h = self.norms[i](xs[i], y) 208 | h = self.convs[i](h) 209 | h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True) 210 | sums += h 211 | return sums 212 | 213 | 214 | class RefineBlock(nn.Module): 215 | def __init__(self, in_planes, features, act=nn.ReLU(), start=False, end=False, maxpool=True, spec_norm=False): 216 | super().__init__() 217 | 218 | assert isinstance(in_planes, tuple) or isinstance(in_planes, list) 219 | self.n_blocks = n_blocks = len(in_planes) 220 | 221 | self.adapt_convs = nn.ModuleList() 222 | for i in range(n_blocks): 223 | self.adapt_convs.append( 224 | RCUBlock(in_planes[i], 2, 2, act, spec_norm=spec_norm) 225 | ) 226 | 227 | self.output_convs = RCUBlock(features, 3 if end else 1, 2, act, spec_norm=spec_norm) 228 | 229 | if not start: 230 | self.msf = MSFBlock(in_planes, features, spec_norm=spec_norm) 231 | 232 | self.crp = CRPBlock(features, 2, act, maxpool=maxpool, spec_norm=spec_norm) 233 | 234 | def forward(self, xs, output_shape): 235 | assert isinstance(xs, tuple) or isinstance(xs, list) 236 | hs = [] 237 | for i in range(len(xs)): 238 | h = self.adapt_convs[i](xs[i]) 239 | hs.append(h) 240 | 241 | if self.n_blocks > 1: 242 | h = self.msf(hs, output_shape) 243 | else: 244 | h = hs[0] 245 | 246 | h = self.crp(h) 247 | h = self.output_convs(h) 248 | 249 | return h 250 | 251 | 252 | 253 | class CondRefineBlock(nn.Module): 254 | def __init__(self, in_planes, features, num_classes, normalizer, act=nn.ReLU(), start=False, end=False, spec_norm=False): 255 | super().__init__() 256 | 257 | assert isinstance(in_planes, tuple) or isinstance(in_planes, list) 258 | self.n_blocks = n_blocks = len(in_planes) 259 | 260 | self.adapt_convs = nn.ModuleList() 261 | for i in range(n_blocks): 262 | self.adapt_convs.append( 263 | CondRCUBlock(in_planes[i], 2, 2, num_classes, normalizer, act, spec_norm=spec_norm) 264 | ) 265 | 266 | self.output_convs = CondRCUBlock(features, 3 if end else 1, 2, num_classes, normalizer, act, spec_norm=spec_norm) 267 | 268 | if not start: 269 | self.msf = CondMSFBlock(in_planes, features, num_classes, normalizer, spec_norm=spec_norm) 270 | 271 | self.crp = CondCRPBlock(features, 2, num_classes, normalizer, act, spec_norm=spec_norm) 272 | 273 | def forward(self, xs, y, output_shape): 274 | assert isinstance(xs, tuple) or isinstance(xs, list) 275 | hs = [] 276 | for i in range(len(xs)): 277 | h = self.adapt_convs[i](xs[i], y) 278 | hs.append(h) 279 | 280 | if self.n_blocks > 1: 281 | h = self.msf(hs, y, output_shape) 282 | else: 283 | h = hs[0] 284 | 285 | h = self.crp(h, y) 286 | h = 
self.output_convs(h, y) 287 | 288 | return h 289 | 290 | 291 | class ConvMeanPool(nn.Module): 292 | def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, adjust_padding=False, spec_norm=False): 293 | super().__init__() 294 | if not adjust_padding: 295 | conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases) 296 | if spec_norm: 297 | conv = spectral_norm(conv) 298 | self.conv = conv 299 | else: 300 | conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases) 301 | if spec_norm: 302 | conv = spectral_norm(conv) 303 | 304 | self.conv = nn.Sequential( 305 | nn.ZeroPad2d((1, 0, 1, 0)), 306 | conv 307 | ) 308 | 309 | def forward(self, inputs): 310 | output = self.conv(inputs) 311 | output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2], 312 | output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4. 313 | return output 314 | 315 | class MeanPoolConv(nn.Module): 316 | def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, spec_norm=False): 317 | super().__init__() 318 | self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases) 319 | if spec_norm: 320 | self.conv = spectral_norm(self.conv) 321 | 322 | def forward(self, inputs): 323 | output = inputs 324 | output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2], 325 | output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4. 326 | return self.conv(output) 327 | 328 | 329 | class UpsampleConv(nn.Module): 330 | def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, spec_norm=False): 331 | super().__init__() 332 | self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases) 333 | if spec_norm: 334 | self.conv = spectral_norm(self.conv) 335 | self.pixelshuffle = nn.PixelShuffle(upscale_factor=2) 336 | 337 | def forward(self, inputs): 338 | output = inputs 339 | output = torch.cat([output, output, output, output], dim=1) 340 | output = self.pixelshuffle(output) 341 | return self.conv(output) 342 | 343 | 344 | class ConditionalResidualBlock(nn.Module): 345 | def __init__(self, input_dim, output_dim, num_classes, resample=None, act=nn.ELU(), 346 | normalization=ConditionalBatchNorm2d, adjust_padding=False, dilation=None, spec_norm=False): 347 | super().__init__() 348 | self.non_linearity = act 349 | self.input_dim = input_dim 350 | self.output_dim = output_dim 351 | self.resample = resample 352 | self.normalization = normalization 353 | if resample == 'down': 354 | if dilation is not None: 355 | self.conv1 = dilated_conv3x3(input_dim, input_dim, dilation=dilation, spec_norm=spec_norm) 356 | self.normalize2 = normalization(input_dim, num_classes) 357 | self.conv2 = dilated_conv3x3(input_dim, output_dim, dilation=dilation, spec_norm=spec_norm) 358 | conv_shortcut = partial(dilated_conv3x3, dilation=dilation, spec_norm=spec_norm) 359 | else: 360 | self.conv1 = conv3x3(input_dim, input_dim, spec_norm=spec_norm) 361 | self.normalize2 = normalization(input_dim, num_classes) 362 | self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding, spec_norm=spec_norm) 363 | conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding, spec_norm=spec_norm) 364 | 365 | elif resample is None: 366 | if dilation is not None: 367 | conv_shortcut = partial(dilated_conv3x3, dilation=dilation, spec_norm=spec_norm) 368 | self.conv1 = dilated_conv3x3(input_dim, output_dim, dilation=dilation, 
spec_norm=spec_norm)
369 |                 self.normalize2 = normalization(output_dim, num_classes)
370 |                 self.conv2 = dilated_conv3x3(output_dim, output_dim, dilation=dilation, spec_norm=spec_norm)
371 |             else:
372 |                 conv_shortcut = partial(conv1x1, spec_norm=spec_norm)  # a bare nn.Conv2d here would be missing its required kernel_size
373 |                 self.conv1 = conv3x3(input_dim, output_dim, spec_norm=spec_norm)
374 |                 self.normalize2 = normalization(output_dim, num_classes)
375 |                 self.conv2 = conv3x3(output_dim, output_dim, spec_norm=spec_norm)
376 |         else:
377 |             raise Exception('invalid resample value')
378 | 
379 |         if output_dim != input_dim or resample is not None:
380 |             self.shortcut = conv_shortcut(input_dim, output_dim)
381 | 
382 |         self.normalize1 = normalization(input_dim, num_classes)
383 | 
384 | 
385 |     def forward(self, x, y):
386 |         output = self.normalize1(x, y)
387 |         output = self.non_linearity(output)
388 |         output = self.conv1(output)
389 |         output = self.normalize2(output, y)
390 |         output = self.non_linearity(output)
391 |         output = self.conv2(output)
392 | 
393 |         if self.output_dim == self.input_dim and self.resample is None:
394 |             shortcut = x
395 |         else:
396 |             shortcut = self.shortcut(x)
397 | 
398 |         return shortcut + output
399 | 
400 | 
401 | class ResidualBlock(nn.Module):
402 |     def __init__(self, input_dim, output_dim, resample=None, act=nn.ELU(),
403 |                  normalization=nn.BatchNorm2d, adjust_padding=False, dilation=None, spec_norm=False):
404 |         super().__init__()
405 |         self.non_linearity = act
406 |         self.input_dim = input_dim
407 |         self.output_dim = output_dim
408 |         self.resample = resample
409 |         self.normalization = normalization
410 |         if resample == 'down':
411 |             if dilation is not None:
412 |                 self.conv1 = dilated_conv3x3(input_dim, input_dim, dilation=dilation, spec_norm=spec_norm)
413 |                 self.normalize2 = normalization(input_dim)
414 |                 self.conv2 = dilated_conv3x3(input_dim, output_dim, dilation=dilation, spec_norm=spec_norm)
415 |                 conv_shortcut = partial(dilated_conv3x3, dilation=dilation, spec_norm=spec_norm)
416 |             else:
417 |                 self.conv1 = conv3x3(input_dim, input_dim, spec_norm=spec_norm)
418 |                 self.normalize2 = normalization(input_dim)
419 |                 self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding, spec_norm=spec_norm)
420 |                 conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding, spec_norm=spec_norm)
421 | 
422 |         elif resample is None:
423 |             if dilation is not None:
424 |                 conv_shortcut = partial(dilated_conv3x3, dilation=dilation, spec_norm=spec_norm)
425 |                 self.conv1 = dilated_conv3x3(input_dim, output_dim, dilation=dilation, spec_norm=spec_norm)
426 |                 self.normalize2 = normalization(output_dim)
427 |                 self.conv2 = dilated_conv3x3(output_dim, output_dim, dilation=dilation, spec_norm=spec_norm)
428 |             else:
429 |                 # conv_shortcut = nn.Conv2d ### Something weird here.
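# The line commented out above is the quirk the original note flagged: a bare
# nn.Conv2d would later be called as conv_shortcut(input_dim, output_dim) with
# no kernel_size argument and raise a TypeError, so a learned 1x1 convolution
# is used as the shortcut projection instead: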
430 | conv_shortcut = partial(conv1x1, spec_norm=spec_norm) 431 | self.conv1 = conv3x3(input_dim, output_dim, spec_norm=spec_norm) 432 | self.normalize2 = normalization(output_dim) 433 | self.conv2 = conv3x3(output_dim, output_dim, spec_norm=spec_norm) 434 | else: 435 | raise Exception('invalid resample value') 436 | 437 | if output_dim != input_dim or resample is not None: 438 | self.shortcut = conv_shortcut(input_dim, output_dim) 439 | 440 | self.normalize1 = normalization(input_dim) 441 | 442 | 443 | def forward(self, x): 444 | output = self.normalize1(x) 445 | output = self.non_linearity(output) 446 | output = self.conv1(output) 447 | output = self.normalize2(output) 448 | output = self.non_linearity(output) 449 | output = self.conv2(output) 450 | 451 | if self.output_dim == self.input_dim and self.resample is None: 452 | shortcut = x 453 | else: 454 | shortcut = self.shortcut(x) 455 | 456 | return shortcut + output 457 | -------------------------------------------------------------------------------- /models/ncsnv2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | import torch.nn.functional as F 4 | import torch 5 | from functools import partial 6 | from . import get_sigmas 7 | from .layers import * 8 | from .normalization import get_normalization 9 | 10 | 11 | class NCSNv2(nn.Module): 12 | def __init__(self, config): 13 | super().__init__() 14 | self.logit_transform = config.data.logit_transform 15 | self.rescaled = config.data.rescaled 16 | self.norm = get_normalization(config, conditional=False) 17 | self.ngf = ngf = config.model.ngf 18 | self.num_classes = num_classes = config.model.num_classes 19 | 20 | self.act = act = get_act(config) 21 | self.register_buffer('sigmas', get_sigmas(config)) 22 | self.config = config 23 | 24 | self.begin_conv = nn.Conv2d(config.data.channels, ngf, 3, stride=1, padding=1) 25 | 26 | self.normalizer = self.norm(ngf, self.num_classes) 27 | self.end_conv = nn.Conv2d(ngf, config.data.channels, 3, stride=1, padding=1) 28 | 29 | self.res1 = nn.ModuleList([ 30 | ResidualBlock(self.ngf, self.ngf, resample=None, act=act, 31 | normalization=self.norm), 32 | ResidualBlock(self.ngf, self.ngf, resample=None, act=act, 33 | normalization=self.norm)] 34 | ) 35 | 36 | self.res2 = nn.ModuleList([ 37 | ResidualBlock(self.ngf, 2 * self.ngf, resample='down', act=act, 38 | normalization=self.norm), 39 | ResidualBlock(2 * self.ngf, 2 * self.ngf, resample=None, act=act, 40 | normalization=self.norm)] 41 | ) 42 | 43 | self.res3 = nn.ModuleList([ 44 | ResidualBlock(2 * self.ngf, 2 * self.ngf, resample='down', act=act, 45 | normalization=self.norm, dilation=2), 46 | ResidualBlock(2 * self.ngf, 2 * self.ngf, resample=None, act=act, 47 | normalization=self.norm, dilation=2)] 48 | ) 49 | 50 | if config.data.image_size == 28: 51 | self.res4 = nn.ModuleList([ 52 | ResidualBlock(2 * self.ngf, 2 * self.ngf, resample='down', act=act, 53 | normalization=self.norm, adjust_padding=True, dilation=4), 54 | ResidualBlock(2 * self.ngf, 2 * self.ngf, resample=None, act=act, 55 | normalization=self.norm, dilation=4)] 56 | ) 57 | else: 58 | self.res4 = nn.ModuleList([ 59 | ResidualBlock(2 * self.ngf, 2 * self.ngf, resample='down', act=act, 60 | normalization=self.norm, adjust_padding=False, dilation=4), 61 | ResidualBlock(2 * self.ngf, 2 * self.ngf, resample=None, act=act, 62 | normalization=self.norm, dilation=4)] 63 | ) 64 | 65 | self.refine1 = RefineBlock([2 * self.ngf], 2 * self.ngf, act=act, 
start=True) 66 | self.refine2 = RefineBlock([2 * self.ngf, 2 * self.ngf], 2 * self.ngf, act=act) 67 | self.refine3 = RefineBlock([2 * self.ngf, 2 * self.ngf], self.ngf, act=act) 68 | self.refine4 = RefineBlock([self.ngf, self.ngf], self.ngf, act=act, end=True) 69 | 70 | def _compute_cond_module(self, module, x): 71 | for m in module: 72 | x = m(x) 73 | return x 74 | 75 | def forward(self, x, y): 76 | if not self.logit_transform and not self.rescaled: 77 | h = 2 * x - 1. 78 | else: 79 | h = x 80 | 81 | output = self.begin_conv(h) 82 | 83 | layer1 = self._compute_cond_module(self.res1, output) 84 | layer2 = self._compute_cond_module(self.res2, layer1) 85 | layer3 = self._compute_cond_module(self.res3, layer2) 86 | layer4 = self._compute_cond_module(self.res4, layer3) 87 | 88 | ref1 = self.refine1([layer4], layer4.shape[2:]) 89 | ref2 = self.refine2([layer3, ref1], layer3.shape[2:]) 90 | ref3 = self.refine3([layer2, ref2], layer2.shape[2:]) 91 | output = self.refine4([layer1, ref3], layer1.shape[2:]) 92 | 93 | output = self.normalizer(output) 94 | output = self.act(output) 95 | output = self.end_conv(output) 96 | 97 | used_sigmas = self.sigmas[y].view(x.shape[0], *([1] * len(x.shape[1:]))) 98 | 99 | output = output / used_sigmas 100 | 101 | return output 102 | 103 | 104 | class NCSNv2Deeper(nn.Module): 105 | def __init__(self, config): 106 | super().__init__() 107 | self.logit_transform = config.data.logit_transform 108 | self.rescaled = config.data.rescaled 109 | self.norm = get_normalization(config, conditional=False) 110 | self.ngf = ngf = config.model.ngf 111 | self.num_classes = config.model.num_classes 112 | self.act = act = get_act(config) 113 | self.register_buffer('sigmas', get_sigmas(config)) 114 | self.config = config 115 | 116 | self.begin_conv = nn.Conv2d(config.data.channels, ngf, 3, stride=1, padding=1) 117 | self.normalizer = self.norm(ngf, self.num_classes) 118 | 119 | self.end_conv = nn.Conv2d(ngf, config.data.channels, 3, stride=1, padding=1) 120 | 121 | self.res1 = nn.ModuleList([ 122 | ResidualBlock(self.ngf, self.ngf, resample=None, act=act, 123 | normalization=self.norm), 124 | ResidualBlock(self.ngf, self.ngf, resample=None, act=act, 125 | normalization=self.norm)] 126 | ) 127 | 128 | self.res2 = nn.ModuleList([ 129 | ResidualBlock(self.ngf, 2 * self.ngf, resample='down', act=act, 130 | normalization=self.norm), 131 | ResidualBlock(2 * self.ngf, 2 * self.ngf, resample=None, act=act, 132 | normalization=self.norm)] 133 | ) 134 | 135 | self.res3 = nn.ModuleList([ 136 | ResidualBlock(2 * self.ngf, 2 * self.ngf, resample='down', act=act, 137 | normalization=self.norm), 138 | ResidualBlock(2 * self.ngf, 2 * self.ngf, resample=None, act=act, 139 | normalization=self.norm)] 140 | ) 141 | 142 | self.res4 = nn.ModuleList([ 143 | ResidualBlock(2 * self.ngf, 4 * self.ngf, resample='down', act=act, 144 | normalization=self.norm, dilation=2), 145 | ResidualBlock(4 * self.ngf, 4 * self.ngf, resample=None, act=act, 146 | normalization=self.norm, dilation=2)] 147 | ) 148 | 149 | self.res5 = nn.ModuleList([ 150 | ResidualBlock(4 * self.ngf, 4 * self.ngf, resample='down', act=act, 151 | normalization=self.norm, dilation=4), 152 | ResidualBlock(4 * self.ngf, 4 * self.ngf, resample=None, act=act, 153 | normalization=self.norm, dilation=4)] 154 | ) 155 | 156 | self.refine1 = RefineBlock([4 * self.ngf], 4 * self.ngf, act=act, start=True) 157 | self.refine2 = RefineBlock([4 * self.ngf, 4 * self.ngf], 2 * self.ngf, act=act) 158 | self.refine3 = RefineBlock([2 * self.ngf, 2 * self.ngf], 2 * 
self.ngf, act=act) 159 | self.refine4 = RefineBlock([2 * self.ngf, 2 * self.ngf], self.ngf, act=act) 160 | self.refine5 = RefineBlock([self.ngf, self.ngf], self.ngf, act=act, end=True) 161 | 162 | def _compute_cond_module(self, module, x): 163 | for m in module: 164 | x = m(x) 165 | return x 166 | 167 | def forward(self, x, y): 168 | if not self.logit_transform and not self.rescaled: 169 | h = 2 * x - 1. 170 | else: 171 | h = x 172 | 173 | output = self.begin_conv(h) 174 | 175 | layer1 = self._compute_cond_module(self.res1, output) 176 | layer2 = self._compute_cond_module(self.res2, layer1) 177 | layer3 = self._compute_cond_module(self.res3, layer2) 178 | layer4 = self._compute_cond_module(self.res4, layer3) 179 | layer5 = self._compute_cond_module(self.res5, layer4) 180 | 181 | ref1 = self.refine1([layer5], layer5.shape[2:]) 182 | ref2 = self.refine2([layer4, ref1], layer4.shape[2:]) 183 | ref3 = self.refine3([layer3, ref2], layer3.shape[2:]) 184 | ref4 = self.refine4([layer2, ref3], layer2.shape[2:]) 185 | output = self.refine5([layer1, ref4], layer1.shape[2:]) 186 | 187 | output = self.normalizer(output) 188 | output = self.act(output) 189 | output = self.end_conv(output) 190 | 191 | used_sigmas = self.sigmas[y].view(x.shape[0], *([1] * len(x.shape[1:]))) 192 | 193 | output = output / used_sigmas 194 | 195 | return output 196 | 197 | 198 | class NCSNv2Deepest(nn.Module): 199 | def __init__(self, config): 200 | super().__init__() 201 | self.logit_transform = config.data.logit_transform 202 | self.rescaled = config.data.rescaled 203 | self.norm = get_normalization(config, conditional=False) 204 | self.ngf = ngf = config.model.ngf 205 | self.num_classes = config.model.num_classes 206 | self.act = act = get_act(config) 207 | self.register_buffer('sigmas', get_sigmas(config)) 208 | self.config = config 209 | 210 | self.begin_conv = nn.Conv2d(config.data.channels, ngf, 3, stride=1, padding=1) 211 | self.normalizer = self.norm(ngf, self.num_classes) 212 | 213 | self.end_conv = nn.Conv2d(ngf, config.data.channels, 3, stride=1, padding=1) 214 | 215 | self.res1 = nn.ModuleList([ 216 | ResidualBlock(self.ngf, self.ngf, resample=None, act=act, 217 | normalization=self.norm), 218 | ResidualBlock(self.ngf, self.ngf, resample=None, act=act, 219 | normalization=self.norm)] 220 | ) 221 | 222 | self.res2 = nn.ModuleList([ 223 | ResidualBlock(self.ngf, 2 * self.ngf, resample='down', act=act, 224 | normalization=self.norm), 225 | ResidualBlock(2 * self.ngf, 2 * self.ngf, resample=None, act=act, 226 | normalization=self.norm)] 227 | ) 228 | 229 | self.res3 = nn.ModuleList([ 230 | ResidualBlock(2 * self.ngf, 2 * self.ngf, resample='down', act=act, 231 | normalization=self.norm), 232 | ResidualBlock(2 * self.ngf, 2 * self.ngf, resample=None, act=act, 233 | normalization=self.norm)] 234 | ) 235 | 236 | self.res31 = nn.ModuleList([ 237 | ResidualBlock(2 * self.ngf, 2 * self.ngf, resample='down', act=act, 238 | normalization=self.norm), 239 | ResidualBlock(2 * self.ngf, 2 * self.ngf, resample=None, act=act, 240 | normalization=self.norm)] 241 | ) 242 | 243 | self.res4 = nn.ModuleList([ 244 | ResidualBlock(2 * self.ngf, 4 * self.ngf, resample='down', act=act, 245 | normalization=self.norm, dilation=2), 246 | ResidualBlock(4 * self.ngf, 4 * self.ngf, resample=None, act=act, 247 | normalization=self.norm, dilation=2)] 248 | ) 249 | 250 | self.res5 = nn.ModuleList([ 251 | ResidualBlock(4 * self.ngf, 4 * self.ngf, resample='down', act=act, 252 | normalization=self.norm, dilation=4), 253 | ResidualBlock(4 * self.ngf, 
4 * self.ngf, resample=None, act=act, 254 | normalization=self.norm, dilation=4)] 255 | ) 256 | 257 | self.refine1 = RefineBlock([4 * self.ngf], 4 * self.ngf, act=act, start=True) 258 | self.refine2 = RefineBlock([4 * self.ngf, 4 * self.ngf], 2 * self.ngf, act=act) 259 | self.refine3 = RefineBlock([2 * self.ngf, 2 * self.ngf], 2 * self.ngf, act=act) 260 | self.refine31 = RefineBlock([2 * self.ngf, 2 * self.ngf], 2 * self.ngf, act=act) 261 | self.refine4 = RefineBlock([2 * self.ngf, 2 * self.ngf], self.ngf, act=act) 262 | self.refine5 = RefineBlock([self.ngf, self.ngf], self.ngf, act=act, end=True) 263 | 264 | def _compute_cond_module(self, module, x): 265 | for m in module: 266 | x = m(x) 267 | return x 268 | 269 | def forward(self, x, y): 270 | if not self.logit_transform and not self.rescaled: 271 | h = 2 * x - 1. 272 | else: 273 | h = x 274 | 275 | output = self.begin_conv(h) 276 | 277 | layer1 = self._compute_cond_module(self.res1, output) 278 | layer2 = self._compute_cond_module(self.res2, layer1) 279 | layer3 = self._compute_cond_module(self.res3, layer2) 280 | layer31 = self._compute_cond_module(self.res31, layer3) 281 | layer4 = self._compute_cond_module(self.res4, layer31) 282 | layer5 = self._compute_cond_module(self.res5, layer4) 283 | 284 | ref1 = self.refine1([layer5], layer5.shape[2:]) 285 | ref2 = self.refine2([layer4, ref1], layer4.shape[2:]) 286 | ref31 = self.refine31([layer31, ref2], layer31.shape[2:]) 287 | ref3 = self.refine3([layer3, ref31], layer3.shape[2:]) 288 | ref4 = self.refine4([layer2, ref3], layer2.shape[2:]) 289 | output = self.refine5([layer1, ref4], layer1.shape[2:]) 290 | 291 | output = self.normalizer(output) 292 | output = self.act(output) 293 | output = self.end_conv(output) 294 | 295 | used_sigmas = self.sigmas[y].view(x.shape[0], *([1] * len(x.shape[1:]))) 296 | 297 | output = output / used_sigmas 298 | 299 | return output 300 | -------------------------------------------------------------------------------- /models/normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def get_normalization(config, conditional=True): 6 | norm = config.model.normalization 7 | if conditional: 8 | if norm == 'NoneNorm': 9 | return ConditionalNoneNorm2d 10 | elif norm == 'InstanceNorm++': 11 | return ConditionalInstanceNorm2dPlus 12 | elif norm == 'InstanceNorm': 13 | return ConditionalInstanceNorm2d 14 | elif norm == 'BatchNorm': 15 | return ConditionalBatchNorm2d 16 | elif norm == 'VarianceNorm': 17 | return ConditionalVarianceNorm2d 18 | else: 19 | raise NotImplementedError("{} does not exist!".format(norm)) 20 | else: 21 | if norm == 'BatchNorm': 22 | return nn.BatchNorm2d 23 | elif norm == 'InstanceNorm': 24 | return nn.InstanceNorm2d 25 | elif norm == 'InstanceNorm++': 26 | return InstanceNorm2dPlus 27 | elif norm == 'VarianceNorm': 28 | return VarianceNorm2d 29 | elif norm == 'NoneNorm': 30 | return NoneNorm2d 31 | elif norm is None: 32 | return None 33 | else: 34 | raise NotImplementedError("{} does not exist!".format(norm)) 35 | 36 | class ConditionalBatchNorm2d(nn.Module): 37 | def __init__(self, num_features, num_classes, bias=True): 38 | super().__init__() 39 | self.num_features = num_features 40 | self.bias = bias 41 | self.bn = nn.BatchNorm2d(num_features, affine=False) 42 | if self.bias: 43 | self.embed = nn.Embedding(num_classes, num_features * 2) 44 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) 45 | 
self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 46 | else: 47 | self.embed = nn.Embedding(num_classes, num_features) 48 | self.embed.weight.data.uniform_() 49 | 50 | def forward(self, x, y): 51 | out = self.bn(x) 52 | if self.bias: 53 | gamma, beta = self.embed(y).chunk(2, dim=1) 54 | out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1) 55 | else: 56 | gamma = self.embed(y) 57 | out = gamma.view(-1, self.num_features, 1, 1) * out 58 | return out 59 | 60 | 61 | class ConditionalInstanceNorm2d(nn.Module): 62 | def __init__(self, num_features, num_classes, bias=True): 63 | super().__init__() 64 | self.num_features = num_features 65 | self.bias = bias 66 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False) 67 | if bias: 68 | self.embed = nn.Embedding(num_classes, num_features * 2) 69 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) 70 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 71 | else: 72 | self.embed = nn.Embedding(num_classes, num_features) 73 | self.embed.weight.data.uniform_() 74 | 75 | def forward(self, x, y): 76 | h = self.instance_norm(x) 77 | if self.bias: 78 | gamma, beta = self.embed(y).chunk(2, dim=-1) 79 | out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1) 80 | else: 81 | gamma = self.embed(y) 82 | out = gamma.view(-1, self.num_features, 1, 1) * h 83 | return out 84 | 85 | 86 | class ConditionalVarianceNorm2d(nn.Module): 87 | def __init__(self, num_features, num_classes, bias=False): 88 | super().__init__() 89 | self.num_features = num_features 90 | self.bias = bias 91 | self.embed = nn.Embedding(num_classes, num_features) 92 | self.embed.weight.data.normal_(1, 0.02) 93 | 94 | def forward(self, x, y): 95 | vars = torch.var(x, dim=(2, 3), keepdim=True) 96 | h = x / torch.sqrt(vars + 1e-5) 97 | 98 | gamma = self.embed(y) 99 | out = gamma.view(-1, self.num_features, 1, 1) * h 100 | return out 101 | 102 | 103 | class VarianceNorm2d(nn.Module): 104 | def __init__(self, num_features, bias=False): 105 | super().__init__() 106 | self.num_features = num_features 107 | self.bias = bias 108 | self.alpha = nn.Parameter(torch.zeros(num_features)) 109 | self.alpha.data.normal_(1, 0.02) 110 | 111 | def forward(self, x): 112 | vars = torch.var(x, dim=(2, 3), keepdim=True) 113 | h = x / torch.sqrt(vars + 1e-5) 114 | 115 | out = self.alpha.view(-1, self.num_features, 1, 1) * h 116 | return out 117 | 118 | 119 | class ConditionalNoneNorm2d(nn.Module): 120 | def __init__(self, num_features, num_classes, bias=True): 121 | super().__init__() 122 | self.num_features = num_features 123 | self.bias = bias 124 | if bias: 125 | self.embed = nn.Embedding(num_classes, num_features * 2) 126 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) 127 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 128 | else: 129 | self.embed = nn.Embedding(num_classes, num_features) 130 | self.embed.weight.data.uniform_() 131 | 132 | def forward(self, x, y): 133 | if self.bias: 134 | gamma, beta = self.embed(y).chunk(2, dim=-1) 135 | out = gamma.view(-1, self.num_features, 1, 1) * x + beta.view(-1, self.num_features, 1, 1) 136 | else: 137 | gamma = self.embed(y) 138 | out = gamma.view(-1, self.num_features, 1, 1) * x 139 | return out 140 | 141 | 142 | class NoneNorm2d(nn.Module): 143 | def __init__(self, num_features, bias=True): 144 | super().__init__() 145 
| 146 | def forward(self, x): 147 | return x 148 | 149 | 150 | class InstanceNorm2dPlus(nn.Module): 151 | def __init__(self, num_features, bias=True): 152 | super().__init__() 153 | self.num_features = num_features 154 | self.bias = bias 155 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False) 156 | self.alpha = nn.Parameter(torch.zeros(num_features)) 157 | self.gamma = nn.Parameter(torch.zeros(num_features)) 158 | self.alpha.data.normal_(1, 0.02) 159 | self.gamma.data.normal_(1, 0.02) 160 | if bias: 161 | self.beta = nn.Parameter(torch.zeros(num_features)) 162 | 163 | def forward(self, x): 164 | means = torch.mean(x, dim=(2, 3)) 165 | m = torch.mean(means, dim=-1, keepdim=True) 166 | v = torch.var(means, dim=-1, keepdim=True) 167 | means = (means - m) / (torch.sqrt(v + 1e-5)) 168 | h = self.instance_norm(x) 169 | 170 | if self.bias: 171 | h = h + means[..., None, None] * self.alpha[..., None, None] 172 | out = self.gamma.view(-1, self.num_features, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1) 173 | else: 174 | h = h + means[..., None, None] * self.alpha[..., None, None] 175 | out = self.gamma.view(-1, self.num_features, 1, 1) * h 176 | return out 177 | 178 | 179 | class ConditionalInstanceNorm2dPlus(nn.Module): 180 | def __init__(self, num_features, num_classes, bias=True): 181 | super().__init__() 182 | self.num_features = num_features 183 | self.bias = bias 184 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False) 185 | if bias: 186 | self.embed = nn.Embedding(num_classes, num_features * 3) 187 | self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02) # Initialise scale at N(1, 0.02) 188 | self.embed.weight.data[:, 2 * num_features:].zero_() # Initialise bias at 0 189 | else: 190 | self.embed = nn.Embedding(num_classes, 2 * num_features) 191 | self.embed.weight.data.normal_(1, 0.02) 192 | 193 | def forward(self, x, y): 194 | means = torch.mean(x, dim=(2, 3)) 195 | m = torch.mean(means, dim=-1, keepdim=True) 196 | v = torch.var(means, dim=-1, keepdim=True) 197 | means = (means - m) / (torch.sqrt(v + 1e-5)) 198 | h = self.instance_norm(x) 199 | 200 | if self.bias: 201 | gamma, alpha, beta = self.embed(y).chunk(3, dim=-1) 202 | h = h + means[..., None, None] * alpha[..., None, None] 203 | out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1) 204 | else: 205 | gamma, alpha = self.embed(y).chunk(2, dim=-1) 206 | h = h + means[..., None, None] * alpha[..., None, None] 207 | out = gamma.view(-1, self.num_features, 1, 1) * h 208 | return out 209 | -------------------------------------------------------------------------------- /runners/__init__.py: -------------------------------------------------------------------------------- 1 | from runners.ncsn_runner import * 2 | -------------------------------------------------------------------------------- /runners/ncsn_runner.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import glob 3 | import tqdm 4 | 5 | import torch.nn.functional as F 6 | import torch 7 | import os 8 | from torchvision.utils import make_grid, save_image 9 | from torch.utils.data import DataLoader 10 | from models.ncsnv2 import NCSNv2Deeper, NCSNv2, NCSNv2Deepest 11 | from datasets import get_dataset, data_transform, inverse_data_transform 12 | from models import general_anneal_Langevin_dynamics 13 | from models import get_sigmas 14 | from models.ema import EMAHelper 15 | from 
filter_builder import get_custom_kernel 16 | 17 | __all__ = ['NCSNRunner'] 18 | 19 | 20 | def get_model(config): 21 | if config.data.dataset == 'CELEBA': 22 | return NCSNv2(config).to(config.device) 23 | elif config.data.dataset == 'LSUN': 24 | return NCSNv2Deeper(config).to(config.device) 25 | 26 | class NCSNRunner(): 27 | def __init__(self, args, config): 28 | self.args = args 29 | self.config = config 30 | args.log_sample_path = os.path.join(args.log_path, 'samples') 31 | os.makedirs(args.log_sample_path, exist_ok=True) 32 | 33 | def sample_general(self, score, samples, init_samples, sigma_0, sigmas, num_variations = 8, deg = 'sr4'): 34 | ## show stochastic variation ## 35 | stochastic_variations = torch.zeros((4 + num_variations) * self.config.sampling.batch_size, self.config.data.channels, self.config.data.image_size, 36 | self.config.data.image_size) 37 | 38 | clean = samples.view(samples.shape[0], self.config.data.channels, 39 | self.config.data.image_size, 40 | self.config.data.image_size) 41 | sample = inverse_data_transform(self.config, clean) 42 | stochastic_variations[0 : self.config.sampling.batch_size,:,:,:] = sample 43 | 44 | img_dim = self.config.data.image_size ** 2 45 | 46 | ## get degradation matrix ## 47 | H = 0 48 | if deg[:2] == 'cs': 49 | ## random with set singular values ## 50 | compress_by = int(deg[2:]) 51 | Vt = torch.rand(img_dim, img_dim).to(self.config.device) 52 | Vt, _ = torch.qr(Vt, some=False) 53 | U = torch.rand(img_dim // compress_by, img_dim // compress_by).to(self.config.device) 54 | U, _ = torch.qr(U, some=False) 55 | S = torch.hstack((torch.eye(img_dim // compress_by), torch.zeros(img_dim // compress_by, (compress_by-1) * img_dim // compress_by))).to(self.config.device) 56 | H = torch.matmul(U, torch.matmul(S, Vt)) 57 | elif deg == 'inp': 58 | ## crop ## 59 | H = torch.eye(img_dim).to(self.config.device) 60 | H = H[:-(self.config.data.image_size*20), :] 61 | elif deg == 'deblur_uni': 62 | ## blur ## 63 | H = torch.from_numpy(get_custom_kernel(type="uniform", dim = self.config.data.image_size)).type(torch.FloatTensor).to(self.config.device) 64 | elif deg == 'deblur_gauss': 65 | ## blur ## 66 | H = torch.from_numpy(get_custom_kernel(type="gauss", dim = self.config.data.image_size)).type(torch.FloatTensor).to(self.config.device) 67 | elif deg[:2] == 'sr': 68 | ## downscale - super resolution ## 69 | blur_by = int(deg[2:]) 70 | H = torch.zeros((img_dim // (blur_by**2), img_dim)).to(self.config.device) 71 | for i in range(self.config.data.image_size // blur_by): 72 | for j in range(self.config.data.image_size // blur_by): 73 | for i_inc in range(blur_by): 74 | for j_inc in range(blur_by): 75 | H[i * self.config.data.image_size // blur_by + j, (blur_by*i + i_inc) * self.config.data.image_size + (blur_by*j + j_inc)] = (1/blur_by**2) 76 | else: 77 | print("ERROR: degradation type not supported") 78 | quit() 79 | 80 | ## set up input for the problem ## 81 | y_0 = torch.matmul(H, samples.view(samples.shape[0] * self.config.data.channels, 82 | img_dim, 1)).view(samples.shape[0], self.config.data.channels, H.shape[0]) 83 | y_0 = y_0 + sigma_0 * torch.randn_like(y_0) 84 | torch.save(y_0, os.path.join(self.args.image_folder, "y_0.pt")) 85 | 86 | H_t = H.transpose(0,1) 87 | H_cross = torch.matmul(H_t, torch.inverse(torch.matmul(H, H_t))) 88 | pinv_y_0 = torch.matmul(H_cross, y_0.view(samples.shape[0] * self.config.data.channels, 89 | H.shape[0], 1)) 90 | if deg == 'deblur_uni' or deg == 'deblur_gauss': pinv_y_0 = y_0 91 | sample = 
inverse_data_transform(self.config, pinv_y_0.view(samples.shape[0], self.config.data.channels, 92 | self.config.data.image_size, 93 | self.config.data.image_size)) 94 | stochastic_variations[1 * self.config.sampling.batch_size : 2 * self.config.sampling.batch_size,:,:,:] = sample 95 | 96 | ## apply SNIPS ## 97 | for i in range(num_variations): 98 | all_samples = general_anneal_Langevin_dynamics(H, y_0, init_samples, score, sigmas, 99 | self.config.sampling.n_steps_each, 100 | self.config.sampling.step_lr, verbose=True, 101 | final_only=self.config.sampling.final_only, 102 | denoise=self.config.sampling.denoise, c_begin=0, sigma_0 = sigma_0) 103 | 104 | sample = all_samples[-1].view(all_samples[-1].shape[0], self.config.data.channels, 105 | self.config.data.image_size, 106 | self.config.data.image_size).to(self.config.device) 107 | stochastic_variations[(self.config.sampling.batch_size) * (i+2) : (self.config.sampling.batch_size) * (i+3),:,:,:] = inverse_data_transform(self.config, sample) 108 | 109 | ## calculate mean and std ## 110 | runs = stochastic_variations[(self.config.sampling.batch_size) * (2) : (self.config.sampling.batch_size) * (2+num_variations),:,:,:] 111 | runs = runs.view(-1, self.config.sampling.batch_size, self.config.data.channels, 112 | self.config.data.image_size, 113 | self.config.data.image_size) 114 | 115 | stochastic_variations[(self.config.sampling.batch_size) * (-2) : (self.config.sampling.batch_size) * (-1),:,:,:] = torch.mean(runs, dim=0) 116 | stochastic_variations[(self.config.sampling.batch_size) * (-1) : ,:,:,:] = torch.std(runs, dim=0) 117 | 118 | torch.save(stochastic_variations, os.path.join(self.args.image_folder, "results.pt")) 119 | 120 | image_grid = make_grid(stochastic_variations, self.config.sampling.batch_size) 121 | save_image(image_grid, os.path.join(self.args.image_folder, 'stochastic_variation.png')) 122 | 123 | ## report PSNRs ## 124 | clean = stochastic_variations[0 * self.config.sampling.batch_size : 1 * self.config.sampling.batch_size,:,:,:] 125 | 126 | for i in range(num_variations): 127 | general = stochastic_variations[(2+i) * self.config.sampling.batch_size : (3+i) * self.config.sampling.batch_size,:,:,:] 128 | mse = torch.mean((general - clean) ** 2) 129 | instance_mse = ((general - clean) ** 2).view(general.shape[0], -1).mean(1) 130 | psnr = torch.mean(10 * torch.log10(1/instance_mse)) 131 | print("MSE/PSNR of the general #%d: %f, %f" % (i, mse, psnr)) 132 | 133 | mean = stochastic_variations[(2+num_variations) * self.config.sampling.batch_size : (3+num_variations) * self.config.sampling.batch_size,:,:,:] 134 | mse = torch.mean((mean - clean) ** 2) 135 | instance_mse = ((mean - clean) ** 2).view(mean.shape[0], -1).mean(1) 136 | psnr = torch.mean(10 * torch.log10(1/instance_mse)) 137 | print("MSE/PSNR of the mean: %f, %f" % (mse, psnr)) 138 | 139 | 140 | def sample(self): 141 | score, states = 0, 0 142 | if self.config.sampling.ckpt_id is None: 143 | states = torch.load(os.path.join(self.args.log_path, 'checkpoint.pth'), map_location=self.config.device) 144 | else: 145 | states = torch.load(os.path.join(self.args.log_path, f'checkpoint_{self.config.sampling.ckpt_id}.pth'), 146 | map_location=self.config.device) 147 | 148 | score = get_model(self.config) 149 | score = torch.nn.DataParallel(score) 150 | 151 | score.load_state_dict(states[0], strict=True) 152 | 153 | if self.config.model.ema: 154 | ema_helper = EMAHelper(mu=self.config.model.ema_rate) 155 | ema_helper.register(score) 156 | ema_helper.load_state_dict(states[-1]) 157 | 
ema_helper.ema(score) 158 | 159 | sigmas_th = get_sigmas(self.config) 160 | sigmas = sigmas_th.cpu().numpy() 161 | 162 | sigma_0 = self.args.sigma_0 163 | 164 | dataset = get_dataset(self.args, self.config) 165 | dataloader = DataLoader(dataset, batch_size=self.config.sampling.batch_size, shuffle=True, 166 | num_workers=4) 167 | 168 | score.eval() 169 | 170 | data_iter = iter(dataloader) 171 | samples, _ = next(data_iter) 172 | samples = samples.to(self.config.device) 173 | samples = data_transform(self.config, samples) 174 | init_samples = torch.rand_like(samples) 175 | 176 | self.sample_general(score, samples, init_samples, sigma_0, sigmas, num_variations=self.args.num_variations, deg=self.args.degradation) 177 | --------------------------------------------------------------------------------
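The sketches below are illustrative usage and sanity checks for the modules above; they are not part of the repository, and the sizes, seeds, and import paths in them are assumptions. First, the `RefineBlock` modules in `models/layers.py` fuse feature maps of different resolutions: each input passes through adaptive RCU stages, `MSFBlock` projects all inputs to a common channel count, bilinearly resizes them to `output_shape`, and sums them, and a CRP stage plus a final RCU produce the output. A minimal shape check, assuming the repository root is on `PYTHONPATH`:

```python
import torch
import torch.nn as nn
from models.layers import RefineBlock

# Two feature maps at different resolutions, in the style of adjacent res stages.
coarse = torch.randn(2, 128, 8, 8)    # deeper, lower-resolution features
fine = torch.randn(2, 128, 16, 16)    # shallower, higher-resolution features

# Fuse both inputs into 64 channels at the finer spatial size.
block = RefineBlock([128, 128], 64, act=nn.ELU())
out = block([coarse, fine], output_shape=(16, 16))
assert out.shape == (2, 64, 16, 16)
```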
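The three NCSNv2 variants in `models/ncsnv2.py` end identically: the label tensor `y` indexes the registered `sigmas` buffer, and the raw network output is divided by the selected noise level, so a single network represents the score at every noise scale. A toy sketch of just that broadcasting step (the schedule and sizes below are stand-ins, not the values from the configs):

```python
import torch

batch, channels, size, num_classes = 8, 3, 32, 100           # stand-in sizes
sigmas = torch.exp(torch.linspace(0.0, -5.0, num_classes))   # stand-in noise schedule
x = torch.randn(batch, channels, size, size)
y = torch.randint(num_classes, (batch,))                     # per-sample noise-level index
raw_output = torch.randn_like(x)                             # stand-in for the network output

# Reshape sigma to (batch, 1, 1, 1) so each sample is scaled by its own noise level.
used_sigmas = sigmas[y].view(x.shape[0], *([1] * len(x.shape[1:])))
assert used_sigmas.shape == (batch, 1, 1, 1)
score = raw_output / used_sigmas
```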
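`InstanceNorm2dPlus` in `models/normalization.py` (InstanceNorm++) re-injects the per-channel means that plain instance norm removes, after standardizing them across channels. The sketch below recomputes the forward pass by hand and checks it against the module, again assuming the repository root is importable:

```python
import torch
import torch.nn.functional as F
from models.normalization import InstanceNorm2dPlus

torch.manual_seed(0)
norm = InstanceNorm2dPlus(num_features=4)
x = torch.randn(2, 4, 8, 8)

means = x.mean(dim=(2, 3))                      # (batch, channels) per-map means
m = means.mean(dim=-1, keepdim=True)            # statistics of the means across channels
v = means.var(dim=-1, keepdim=True)
means = (means - m) / torch.sqrt(v + 1e-5)      # standardized channel means
h = F.instance_norm(x)                          # per-map normalization (eps=1e-5, no affine)
h = h + means[..., None, None] * norm.alpha[..., None, None]
out = norm.gamma.view(1, -1, 1, 1) * h + norm.beta.view(1, -1, 1, 1)

assert torch.allclose(out, norm(x), atol=1e-5)
```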
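Finally, the degradation operators in `NCSNRunner.sample_general` are explicit matrices acting on flattened images. The sanity check below rebuilds the `sr2`-style matrix for a toy 4x4 single-channel image: applying `H` should equal 2x2 average pooling, and the pseudo-inverse `H_cross` used for the initial `pinv_y_0` estimate is an exact right inverse here because the rows of `H` touch disjoint pixel blocks.

```python
import torch
import torch.nn.functional as F

image_size, blur_by = 4, 2     # toy sizes; the runner uses the config's image size
img_dim = image_size ** 2

# Same construction as the `sr` branch of sample_general, for one channel.
H = torch.zeros(img_dim // blur_by ** 2, img_dim)
for i in range(image_size // blur_by):
    for j in range(image_size // blur_by):
        for i_inc in range(blur_by):
            for j_inc in range(blur_by):
                H[i * image_size // blur_by + j,
                  (blur_by * i + i_inc) * image_size + (blur_by * j + j_inc)] = 1 / blur_by ** 2

x = torch.arange(img_dim, dtype=torch.float32)
pooled = F.avg_pool2d(x.view(1, 1, image_size, image_size), blur_by).flatten()
assert torch.allclose(H @ x, pooled)            # H acts as 2x2 block averaging

# Moore-Penrose right inverse, as used to initialise pinv_y_0.
H_t = H.t()
H_cross = H_t @ torch.inverse(H @ H_t)
assert torch.allclose(H @ H_cross, torch.eye(img_dim // blur_by ** 2), atol=1e-5)
```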