├── stylegan2
│   ├── __init__.py
│   ├── op
│   │   ├── __init__.py
│   │   ├── fused_bias_act.cpp
│   │   ├── upfirdn2d.cpp
│   │   ├── fused_act.py
│   │   ├── fused_bias_act_kernel.cu
│   │   ├── upfirdn2d.py
│   │   └── upfirdn2d_kernel.cu
│   └── model.py
├── .gitmodules
├── include
│   ├── __init__.py
│   ├── visualize.py
│   ├── compression.py
│   ├── wavelet.py
│   ├── helpers.py
│   ├── decoder.py
│   └── fit.py
├── .gitignore
├── requirements.txt
├── scripts
│   ├── fairness
│   │   ├── ffhq.sh
│   │   ├── cat10.sh
│   │   ├── cat30.sh
│   │   ├── cat40.sh
│   │   ├── cat60.sh
│   │   ├── cat70.sh
│   │   ├── cat90.sh
│   │   ├── cat80.sh
│   │   ├── cat50.sh
│   │   └── cat20.sh
│   └── compressed-sensing
│       ├── celeba_map.sh
│       ├── celeba_langevin.sh
│       ├── ffhq_map.sh
│       └── ffhq_langevin.sh
├── src
│   ├── SphericalOptimizer.py
│   ├── loss.py
│   ├── view_estimated_ffhq_cs.py
│   ├── view_estimated_celebA_cs.py
│   ├── metrics_utils.py
│   ├── cs_metrics.ipynb
│   ├── PULSE.py
│   ├── compressed_sensing.py
│   ├── utils.py
│   └── estimators.py
├── download.sh
├── LICENSE
├── shuffle_catdog.sh
├── README.md
└── glow
    └── model.py
/stylegan2/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "ncsnv2"] 2 | path = ncsnv2 3 | url = git@github.com:ermongroup/ncsnv2.git 4 | -------------------------------------------------------------------------------- /stylegan2/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /include/__init__.py: -------------------------------------------------------------------------------- 1 | from .wavelet import * 2 | from .decoder import * 3 | from .visualize import * 4 | from .fit import * 5 | from .helpers import * 6 | from .compression import * -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.ipynb 2 | *.pyc 3 | env/ 4 | estimated/ 5 | test_images/ 6 | checkpoints/ 7 | datasets/ 8 | *tar* 9 | *.png 10 | lpips/weights 11 | *backup* 12 | *.pdf 13 | *.npy 14 | *.txt 15 | *.pt 16 | *.sw* 17 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.0 2 | torchvision==0.8.0 3 | numpy 4 | scipy 5 | matplotlib 6 | ipdb 7 | gdown 8 | ipython 9 | tensorflow-gpu==1.8.0 10 | keras==2.2.0 11 | pillow==5.2.0 12 | scikit-image 13 | ninja 14 | -------------------------------------------------------------------------------- /scripts/fairness/ffhq.sh: -------------------------------------------------------------------------------- 1 | python src/compressed_sensing.py --checkpoint-path checkpoints/ncsnv2_ffhq/checkpoint_80000.pth --net ncsnv2 --dataset ffhq-69000 --num-input-images 1 --batch-size 1 --ncsnv2-configs-file ./ncsnv2/configs/ffhq.yml --measurement-type superres --noise-std 0 --downsample 8 --model-type langevin --print-stats --checkpoint-iter 1 --cuda --mloss-weight 1.0 --learning-rate 9e-7 --sigma-init 348 --sigma-final 0.01 --L 2311 --T 3 2 | -------------------------------------------------------------------------------- /scripts/fairness/cat10.sh: 
-------------------------------------------------------------------------------- 1 | python src/compressed_sensing.py --checkpoint-path checkpoints/stylegan2_catdog/network-snapshot-cat10dog90-001843.pt --net stylegan2 --dataset cat10dog90 --input-type full-input --num-input-images 555 --batch-size 1 --image-size 512 --measurement-type superres --noise-std 0.1 --downsample 128 --model-type langevin --annealed --mloss-weight -1 --zprior-weight -1 --zprior-sdev 1 --zprior-init-sdev 1.0 --T 3 --L 500 --sigma-init 1.0 --sigma-final 0.1 --num-noise-variables 18 --optimizer-type sgd --learning-rate 5e-6 --momentum 0. --num-random-restarts 4 --checkpoint-iter 1 --cuda --max-update-iter -1 --error-threshold 1.0 2 | 3 | -------------------------------------------------------------------------------- /scripts/fairness/cat30.sh: -------------------------------------------------------------------------------- 1 | python src/compressed_sensing.py --checkpoint-path checkpoints/stylegan2_catdog/network-snapshot-cat30dog70-001843.pt --net stylegan2 --dataset cat30dog70 --input-type full-input --num-input-images 715 --batch-size 1 --image-size 512 --measurement-type superres --noise-std 0.1 --downsample 128 --model-type langevin --annealed --mloss-weight -1 --zprior-weight -1 --zprior-sdev 1 --zprior-init-sdev 1.0 --T 3 --L 500 --sigma-init 1.0 --sigma-final 0.1 --num-noise-variables 18 --optimizer-type sgd --learning-rate 5e-6 --momentum 0. --num-random-restarts 4 --checkpoint-iter 1 --cuda --max-update-iter -1 --error-threshold 1.0 2 | 3 | -------------------------------------------------------------------------------- /scripts/fairness/cat40.sh: -------------------------------------------------------------------------------- 1 | python src/compressed_sensing.py --checkpoint-path checkpoints/stylegan2_catdog/network-snapshot-cat40dog60-001843.pt --net stylegan2 --dataset cat40dog60 --input-type full-input --num-input-images 833 --batch-size 1 --image-size 512 --measurement-type superres --noise-std 0.1 --downsample 128 --model-type langevin --annealed --mloss-weight -1 --zprior-weight -1 --zprior-sdev 1 --zprior-init-sdev 1.0 --T 3 --L 500 --sigma-init 1.0 --sigma-final 0.1 --num-noise-variables 18 --optimizer-type sgd --learning-rate 5e-6 --momentum 0. --num-random-restarts 4 --checkpoint-iter 1 --cuda --max-update-iter -1 --error-threshold 1.0 2 | 3 | -------------------------------------------------------------------------------- /scripts/fairness/cat60.sh: -------------------------------------------------------------------------------- 1 | python src/compressed_sensing.py --checkpoint-path checkpoints/stylegan2_catdog/network-snapshot-cat60dog40-001843.pt --net stylegan2 --dataset cat60dog40 --input-type full-input --num-input-images 833 --batch-size 1 --image-size 512 --measurement-type superres --noise-std 0.1 --downsample 128 --model-type langevin --annealed --mloss-weight -1 --zprior-weight -1 --zprior-sdev 1 --zprior-init-sdev 1.0 --T 3 --L 500 --sigma-init 1.0 --sigma-final 0.1 --num-noise-variables 18 --optimizer-type sgd --learning-rate 5e-6 --momentum 0. 
--num-random-restarts 4 --checkpoint-iter 1 --cuda --max-update-iter -1 --error-threshold 1.0 2 | -------------------------------------------------------------------------------- /scripts/fairness/cat70.sh: -------------------------------------------------------------------------------- 1 | python src/compressed_sensing.py --checkpoint-path checkpoints/stylegan2_catdog/network-snapshot-cat70dog30-001843.pt --net stylegan2 --dataset cat70dog30 --input-type full-input --num-input-images 715 --batch-size 1 --image-size 512 --measurement-type superres --noise-std 0.1 --downsample 128 --model-type langevin --annealed --mloss-weight -1 --zprior-weight -1 --zprior-sdev 1 --zprior-init-sdev 1.0 --T 3 --L 500 --sigma-init 1.0 --sigma-final 0.1 --num-noise-variables 18 --optimizer-type sgd --learning-rate 5e-6 --momentum 0. --num-random-restarts 4 --checkpoint-iter 1 --cuda --max-update-iter -1 --error-threshold 1.0 2 | 3 | -------------------------------------------------------------------------------- /scripts/fairness/cat90.sh: -------------------------------------------------------------------------------- 1 | python src/compressed_sensing.py --checkpoint-path checkpoints/stylegan2_catdog/network-snapshot-cat90dog10-001843.pt --net stylegan2 --dataset cat90dog10 --input-type full-input --num-input-images 555 --batch-size 1 --image-size 512 --measurement-type superres --noise-std 0.1 --downsample 128 --model-type langevin --annealed --mloss-weight -1 --zprior-weight -1 --zprior-sdev 1 --zprior-init-sdev 1.0 --T 3 --L 500 --sigma-init 1.0 --sigma-final 0.1 --num-noise-variables 18 --optimizer-type sgd --learning-rate 5e-6 --momentum 0. --num-random-restarts 4 --checkpoint-iter 1 --cuda --max-update-iter -1 --error-threshold 1.0 2 | 3 | -------------------------------------------------------------------------------- /stylegan2/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /src/SphericalOptimizer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim import Optimizer 4 | 5 | # Spherical Optimizer Class 6 | # Uses the first two dimensions as batch information 7 | # Optimizes over the surface of a sphere using the initial radius throughout 8 | # 9 | # Example Usage: 10 | # opt = SphericalOptimizer(torch.optim.SGD, [x], lr=0.01) 11 | 12 | class SphericalOptimizer(Optimizer): 13 | def __init__(self, optimizer, params, **kwargs): 14 | self.opt = optimizer(params, **kwargs) 15 | self.params = params 16 | with torch.no_grad(): 17 | 
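# Cache each parameter's initial radius: its L2 norm over all non-batch
# dimensions (the first two dims carry batch information). step() below
# renormalizes every parameter back onto this sphere after each update.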
self.radii = {param: (param.pow(2).sum(tuple(range(2,param.ndim)),keepdim=True)+1e-9).sqrt() for param in params} 18 | 19 | @torch.no_grad() 20 | def step(self, closure=None): 21 | loss = self.opt.step(closure) 22 | for param in self.params: 23 | param.data.div_((param.pow(2).sum(tuple(range(2,param.ndim)),keepdim=True)+1e-9).sqrt()) 24 | param.mul_(self.radii[param]) 25 | 26 | return loss -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | mkdir -p checkpoints/glow 2 | mkdir datasets 3 | 4 | # afhq dataset 5 | wget -N https://www.dropbox.com/s/t9l9o3vsx2jai3z/afhq.zip?dl=0 -O datasets/afhq.zip 6 | # glow checkpoint 7 | curl https://openaipublic.azureedge.net/glow-demo/large3/graph_unoptimized.pb > checkpoints/glow/graph_unoptimized.pb 8 | # test images 9 | gdown https://drive.google.com/uc?id=1FrijKOZ0Fu3V_SpI2GnILqOY2DfNbRPX 10 | # ncsnv2 checkpoint 11 | gdown https://drive.google.com/uc?id=151V3yt-JYDd298rZ2i8ORSSVHY1HRcUT 12 | # stylegan checkpoints 13 | gdown https://drive.google.com/uc?id=14urG8mZN9ap8ZyHTA-DBJ9NNJvRRfGE1 14 | 15 | # extract stuff 16 | unzip datasets/afhq.zip -d ./datasets 17 | tar -zxvf test_images.tar.gz 18 | 19 | # accidentally included broken symbolic links in the tar archive, so 20 | # delete folders before shuffling and creating cat/dog validation data 21 | # over different biases 22 | rm -r test_images/cat* 23 | bash shuffle_catdog.sh 24 | 25 | # extract ncsnv2 and stylegan2 checkpoints 26 | tar -zxvf ncsnv2_checkpoint.tar.gz 27 | tar -zxvf stylegan2_checkpoints.tar.gz 28 | 29 | -------------------------------------------------------------------------------- /stylegan2/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 ajiljalal 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 
all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /include/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from torch.autograd import Variable 3 | import torch 4 | import torch.optim 5 | import numpy as np 6 | from collections.abc import Iterable 7 | 8 | 9 | dtype = torch.cuda.FloatTensor 10 | #dtype = torch.FloatTensor 11 | 12 | def save_np_img(img,filename): 13 | if(img.shape[0] == 1): 14 | plt.imshow(np.clip(img[0],0,1)) 15 | else: 16 | plt.imshow(np.clip(img.transpose(1, 2, 0),0,1)) 17 | plt.axis('off') 18 | plt.savefig(filename, bbox_inches='tight') 19 | plt.close() 20 | 21 | def apply_until(net_input,net,n = 100): 22 | # applies the network function by function 23 | for i,fun in enumerate(net): 24 | if i>=n: 25 | break 26 | if i==0: 27 | out = fun(net_input.type(dtype)) 28 | else: 29 | out = fun(out) 30 | print(i, "last func. applied:", net[i-1]) 31 | if n == 0: 32 | return net_input 33 | else: 34 | return out 35 | 36 | 37 | from math import ceil 38 | 39 | 40 | # given a list of images as np-arrays, plot them as a grid 41 | def plot_image_grid(imgs,nrows=10): 42 | ncols = ceil( len(imgs)/nrows ) 43 | nrows = min(nrows,len(imgs)) 44 | fig, axes = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True,figsize=(ncols,nrows),squeeze=False) 45 | for i, row in enumerate(axes): 46 | for j, ax in enumerate(row): 47 | ax.imshow(imgs[j*nrows+i], cmap='Greys_r', interpolation='none') 48 | ax.get_xaxis().set_visible(False) 49 | ax.get_yaxis().set_visible(False) 50 | fig.tight_layout(pad=0.1) 51 | return fig 52 | 53 | def save_tensor(out,filename,nrows=8): 54 | imgs = [img for img in out.data.cpu().numpy()[0]] 55 | fig = plot_image_grid(imgs,nrows=nrows) 56 | plt.savefig(filename) 57 | plt.close() 58 | 59 | -------------------------------------------------------------------------------- /include/compression.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | import torch 3 | import torch.optim 4 | import copy 5 | import numpy as np 6 | 7 | from .helpers import * 8 | from .decoder import * 9 | from .fit import * 10 | from .wavelet import * 11 | 12 | def rep_error_deep_decoder(img_np,k=128,convert2ycbcr=False): 13 | ''' 14 | mse obtained by representing img_np with the deep decoder 15 | ''' 16 | output_depth = img_np.shape[0] 17 | if output_depth == 3 and convert2ycbcr: 18 | img = rgb2ycbcr(img_np) 19 | else: 20 | img = img_np 21 | img_var = np_to_var(img).type(dtype) 22 | 23 | num_channels = [k]*5 24 | net = decodernwv2(output_depth,num_channels_up=num_channels,bn_before_act=True).type(dtype) 25 | rnd = 500 26 | numit = 15000 27 | rn = 0.005 28 | mse_n, mse_t, ni, net = fit( num_channels=num_channels, 29 | reg_noise_std=rn, 30 | reg_noise_decayevery = rnd, 31 | num_iter=numit, 32 | LR=0.004, 33 | img_noisy_var=img_var, 34 | net=net, 35 | 
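# note: both img_noisy_var (above) and img_clean_var (below) are the same
# clean image, since this routine measures pure representation error
# rather than denoising performance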
img_clean_var=img_var, 36 | find_best=True, 37 | ) 38 | out_img = net(ni.type(dtype)).data.cpu().numpy()[0] 39 | if output_depth == 3 and convert2ycbcr: 40 | out_img = ycbcr2rgb(out_img) 41 | return psnr(out_img,img_np), out_img, num_param(net) 42 | 43 | def rep_error_wavelet(img_np,ncoeff=300): 44 | ''' 45 | mse obtained by representing img_np with wavelet thresholding 46 | ncoeff coefficients are retained per color channel 47 | ''' 48 | if img_np.shape[0] == 1: 49 | img_np = img_np[0,:,:] 50 | out_img_np = denoise_wavelet(img_np, ncoeff=ncoeff, multichannel=False, convert2ycbcr=True, mode='hard') 51 | else: 52 | img_np = np.transpose(img_np) 53 | out_img_np = denoise_wavelet(img_np, ncoeff=ncoeff, multichannel=True, convert2ycbcr=True, mode='hard') 54 | # img_np = np.array([img_np[:,:,0],img_np[:,:,1],img_np[:,:,2]]) 55 | return psnr(out_img_np,img_np), out_img_np 56 | 57 | def myimgshow(plt,img): 58 | if(img.shape[0] == 1): 59 | plt.imshow(np.clip(img[0],0,1),cmap='Greys',interpolation='none') 60 | else: 61 | plt.imshow(np.clip(img.transpose(1, 2, 0),0,1),interpolation='none') 62 | 63 | -------------------------------------------------------------------------------- /scripts/compressed-sensing/celeba_map.sh: -------------------------------------------------------------------------------- 1 | python src/compressed_sensing.py --checkpoint-path ./checkpoints/glow/graph_unoptimized.pb --net glow --dataset celebA --num-input-images 1 --batch-size 1 --measurement-type circulant --noise-std 16.0 --num-measurements 2500 --model-type map --print-stats --checkpoint-iter 1 --cuda --mloss-weight -1 --learning-rate 0.001 --zprior-weight -1 --no-annealed --optimizer-type adam --max-update-iter 3000 2 | 3 | python src/compressed_sensing.py --checkpoint-path ./checkpoints/glow/graph_unoptimized.pb --net glow --dataset celebA --num-input-images 1 --batch-size 1 --measurement-type circulant --noise-std 16.0 --num-measurements 5000 --model-type map --print-stats --checkpoint-iter 1 --cuda --mloss-weight -1 --learning-rate 0.001 --zprior-weight -1 --no-annealed --optimizer-type adam --max-update-iter 3000 4 | 5 | python src/compressed_sensing.py --checkpoint-path ./checkpoints/glow/graph_unoptimized.pb --net glow --dataset celebA --num-input-images 1 --batch-size 1 --measurement-type circulant --noise-std 16.0 --num-measurements 10000 --model-type map --print-stats --checkpoint-iter 1 --cuda --mloss-weight -1 --learning-rate 0.001 --zprior-weight -1 --no-annealed --optimizer-type adam --max-update-iter 3000 6 | 7 | python src/compressed_sensing.py --checkpoint-path ./checkpoints/glow/graph_unoptimized.pb --net glow --dataset celebA --num-input-images 1 --batch-size 1 --measurement-type circulant --noise-std 16.0 --num-measurements 20000 --model-type map --print-stats --checkpoint-iter 1 --cuda --mloss-weight -1 --learning-rate 0.001 --zprior-weight -1 --no-annealed --optimizer-type adam --max-update-iter 3000 8 | 9 | python src/compressed_sensing.py --checkpoint-path ./checkpoints/glow/graph_unoptimized.pb --net glow --dataset celebA --num-input-images 1 --batch-size 1 --measurement-type circulant --noise-std 16.0 --num-measurements 30000 --model-type map --print-stats --checkpoint-iter 1 --cuda --mloss-weight -1 --learning-rate 0.001 --zprior-weight -1 --no-annealed --optimizer-type adam --max-update-iter 3000 10 | 11 | python src/compressed_sensing.py --checkpoint-path ./checkpoints/glow/graph_unoptimized.pb --net glow --dataset celebA --num-input-images 1 --batch-size 1 --measurement-type circulant 
--noise-std 16.0 --num-measurements 35000 --model-type map --print-stats --checkpoint-iter 1 --cuda --mloss-weight -1 --learning-rate 0.001 --zprior-weight -1 --no-annealed --optimizer-type adam --max-update-iter 3000 12 | 13 | -------------------------------------------------------------------------------- /src/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | #from bicubic import BicubicDownSample 3 | 4 | class LossBuilder(torch.nn.Module): 5 | def __init__(self, ref_im, loss_str, eps, hr_size): 6 | super(LossBuilder, self).__init__() 7 | assert ref_im.shape[2]==ref_im.shape[3] 8 | im_size = ref_im.shape[2] 9 | factor=hr_size//im_size 10 | assert im_size*factor==hr_size 11 | self.D = torch.nn.AvgPool2d(factor) #BicubicDownSample(factor=factor) 12 | self.ref_im = ref_im 13 | self.parsed_loss = [loss_term.split('*') for loss_term in loss_str.split('+')] 14 | self.eps = eps 15 | 16 | # Takes a list of tensors, flattens them, and concatenates them into a vector 17 | # Used to calculate euclidean distance between lists of tensors 18 | def flatcat(self, l): 19 | l = l if(isinstance(l, list)) else [l] 20 | return torch.cat([x.flatten() for x in l], dim=0) 21 | 22 | def _loss_l2(self, gen_im_lr, ref_im, **kwargs): 23 | return ((gen_im_lr - ref_im).pow(2).mean((1, 2, 3)).clamp(min=self.eps).sum()) 24 | 25 | def _loss_l1(self, gen_im_lr, ref_im, **kwargs): 26 | return 10*((gen_im_lr - ref_im).abs().mean((1, 2, 3)).clamp(min=self.eps).sum()) 27 | 28 | # Uses geodesic distance on sphere to sum pairwise distances of the 18 vectors 29 | def _loss_geocross(self, latent, **kwargs): 30 | if(latent.shape[1] == 1): 31 | return 0 32 | else: 33 | X = latent.view(-1, 1, 18, 512) 34 | Y = latent.view(-1, 18, 1, 512) 35 | A = ((X-Y).pow(2).sum(-1)+1e-9).sqrt() 36 | B = ((X+Y).pow(2).sum(-1)+1e-9).sqrt() 37 | D = 2*torch.atan2(A, B) 38 | D = ((D.pow(2)*512).mean((1, 2))/8.).sum() 39 | return D 40 | 41 | def forward(self, latent, gen_im): 42 | var_dict = {'latent': latent, 43 | 'gen_im_lr': self.D(gen_im), 44 | 'ref_im': self.ref_im, 45 | } 46 | loss = 0 47 | loss_fun_dict = { 48 | 'L2': self._loss_l2, 49 | 'L1': self._loss_l1, 50 | 'GEOCROSS': self._loss_geocross, 51 | } 52 | losses = {} 53 | for weight, loss_type in self.parsed_loss: 54 | tmp_loss = loss_fun_dict[loss_type](**var_dict) 55 | losses[loss_type] = tmp_loss 56 | loss += float(weight)*tmp_loss 57 | return loss, losses 58 | -------------------------------------------------------------------------------- /scripts/compressed-sensing/celeba_langevin.sh: -------------------------------------------------------------------------------- 1 | python src/compressed_sensing.py --checkpoint-path ./checkpoints/glow/graph_unoptimized.pb --net glow --dataset celebA --num-input-images 1 --batch-size 1 --measurement-type circulant --noise-std 16.0 --num-measurements 2500 --model-type langevin --print-stats --checkpoint-iter 1 --cuda --mloss-weight -1 --learning-rate 5e-5 --sigma-init 64 --sigma-final 16 --L 10 --T 200 --zprior-weight -1 --annealed 2 | 3 | python src/compressed_sensing.py --checkpoint-path ./checkpoints/glow/graph_unoptimized.pb --net glow --dataset celebA --num-input-images 1 --batch-size 1 --measurement-type circulant --noise-std 16.0 --num-measurements 5000 --model-type langevin --print-stats --checkpoint-iter 1 --cuda --mloss-weight -1 --learning-rate 5e-5 --sigma-init 64 --sigma-final 16 --L 10 --T 200 --zprior-weight -1 --annealed 4 | 5 | python src/compressed_sensing.py 
--checkpoint-path ./checkpoints/glow/graph_unoptimized.pb --net glow --dataset celebA --num-input-images 1 --batch-size 1 --measurement-type circulant --noise-std 16.0 --num-measurements 10000 --model-type langevin --print-stats --checkpoint-iter 1 --cuda --mloss-weight -1 --learning-rate 5e-5 --sigma-init 64 --sigma-final 16 --L 10 --T 200 --zprior-weight -1 --annealed 6 | 7 | python src/compressed_sensing.py --checkpoint-path ./checkpoints/glow/graph_unoptimized.pb --net glow --dataset celebA --num-input-images 1 --batch-size 1 --measurement-type circulant --noise-std 16.0 --num-measurements 20000 --model-type langevin --print-stats --checkpoint-iter 1 --cuda --mloss-weight -1 --learning-rate 1e-5 --sigma-init 64 --sigma-final 16 --L 10 --T 200 --zprior-weight -1 --annealed 8 | 9 | python src/compressed_sensing.py --checkpoint-path ./checkpoints/glow/graph_unoptimized.pb --net glow --dataset celebA --num-input-images 1 --batch-size 1 --measurement-type circulant --noise-std 16.0 --num-measurements 30000 --model-type langevin --print-stats --checkpoint-iter 1 --cuda --mloss-weight -1 --learning-rate 1e-5 --sigma-init 64 --sigma-final 16 --L 10 --T 200 --zprior-weight -1 --annealed 10 | 11 | python src/compressed_sensing.py --checkpoint-path ./checkpoints/glow/graph_unoptimized.pb --net glow --dataset celebA --num-input-images 1 --batch-size 1 --measurement-type circulant --noise-std 16.0 --num-measurements 35000 --model-type langevin --print-stats --checkpoint-iter 1 --cuda --mloss-weight -1 --learning-rate 1e-5 --sigma-init 64 --sigma-final 16 --L 10 --T 200 --zprior-weight -1 --annealed 12 | 13 | 14 | -------------------------------------------------------------------------------- /scripts/fairness/cat80.sh: -------------------------------------------------------------------------------- 1 | python src/compressed_sensing.py --checkpoint-path checkpoints/stylegan2_catdog/network-snapshot-cat80dog20-002867.pt --net stylegan2 --dataset cat80dog20 --input-type full-input --num-input-images 625 --batch-size 1 --image-size 512 --measurement-type superres --noise-std 0.1 --downsample 64 --model-type langevin --annealed --mloss-weight -1 --zprior-weight -1 --zprior-sdev 1 --zprior-init-sdev 1.0 --T 3 --L 500 --sigma-init 1.0 --sigma-final 0.1 --num-noise-variables 18 --optimizer-type sgd --learning-rate 5e-6 --momentum 0. --num-random-restarts 1 --checkpoint-iter 1 --cuda --max-update-iter -1 --error-threshold 1.0 2 | 3 | python src/compressed_sensing.py --checkpoint-path checkpoints/stylegan2_catdog/network-snapshot-cat80dog20-002867.pt --net stylegan2 --dataset cat80dog20 --input-type full-input --num-input-images 625 --batch-size 1 --image-size 512 --measurement-type superres --noise-std 0.1 --downsample 128 --model-type langevin --annealed --mloss-weight -1 --zprior-weight -1 --zprior-sdev 1 --zprior-init-sdev 1.0 --T 3 --L 500 --sigma-init 1.0 --sigma-final 0.1 --num-noise-variables 18 --optimizer-type sgd --learning-rate 5e-6 --momentum 0. 
--num-random-restarts 4 --checkpoint-iter 1 --cuda --max-update-iter -1 --error-threshold 1.0 4 | 5 | python src/compressed_sensing.py --checkpoint-path checkpoints/stylegan2_catdog/network-snapshot-cat80dog20-002867.pt --net stylegan2 --dataset cat80dog20 --input-type full-input --num-input-images 625 --batch-size 1 --image-size 512 --measurement-type superres --noise-std 0.01 --downsample 256 --model-type langevin --annealed --mloss-weight -1 --zprior-weight -1 --zprior-sdev 1 --zprior-init-sdev 1.0 --T 3 --L 500 --sigma-init 1.0 --sigma-final 0.01 --num-noise-variables 18 --optimizer-type sgd --learning-rate 5e-6 --momentum 0. --num-random-restarts 4 --checkpoint-iter 1 --cuda --max-update-iter -1 --error-threshold 1.0 6 | 7 | python src/compressed_sensing.py --checkpoint-path checkpoints/stylegan2_catdog/network-snapshot-cat80dog20-002867.pt --net stylegan2 --dataset cat80dog20 --input-type full-input --num-input-images 625 --batch-size 1 --image-size 512 --measurement-type superres --noise-std 0.01 --downsample 512 --model-type langevin --annealed --mloss-weight -1 --zprior-weight -1 --zprior-sdev 1 --zprior-init-sdev 1.0 --T 3 --L 500 --sigma-init 1.0 --sigma-final 0.01 --num-noise-variables 18 --optimizer-type sgd --learning-rate 5e-6 --momentum 0. --num-random-restarts 4 --checkpoint-iter 1 --cuda --max-update-iter -1 --error-threshold 1.0 8 | -------------------------------------------------------------------------------- /scripts/fairness/cat50.sh: -------------------------------------------------------------------------------- 1 | python src/compressed_sensing.py --checkpoint-path checkpoints/stylegan2_catdog/network-snapshot-cat50dog50-003481.pt --net stylegan2 --dataset cat50dog50 --input-type full-input --num-input-images 1000 --batch-size 1 --image-size 512 --measurement-type superres --noise-std 0.1 --downsample 64 --model-type langevin --annealed --mloss-weight -1 --zprior-weight -1 --zprior-sdev 1 --zprior-init-sdev 1.0 --T 3 --L 500 --sigma-init 1.0 --sigma-final 0.1 --num-noise-variables 18 --optimizer-type sgd --learning-rate 5e-6 --momentum 0. --num-random-restarts 4 --checkpoint-iter 1 --cuda --max-update-iter -1 --error-threshold 1.0 2 | 3 | python src/compressed_sensing.py --checkpoint-path checkpoints/stylegan2_catdog/network-snapshot-cat50dog50-003481.pt --net stylegan2 --dataset cat50dog50 --input-type full-input --num-input-images 1000 --batch-size 1 --image-size 512 --measurement-type superres --noise-std 0.1 --downsample 128 --model-type langevin --annealed --mloss-weight -1 --zprior-weight -1 --zprior-sdev 1 --zprior-init-sdev 1.0 --T 3 --L 500 --sigma-init 1.0 --sigma-final 0.1 --num-noise-variables 18 --optimizer-type sgd --learning-rate 5e-6 --momentum 0. --num-random-restarts 4 --checkpoint-iter 1 --cuda --max-update-iter -1 --error-threshold 1.0 4 | 5 | python src/compressed_sensing.py --checkpoint-path checkpoints/stylegan2_catdog/network-snapshot-cat50dog50-003481.pt --net stylegan2 --dataset cat50dog50 --input-type full-input --num-input-images 1000 --batch-size 1 --image-size 512 --measurement-type superres --noise-std 0.01 --downsample 256 --model-type langevin --annealed --mloss-weight -1 --zprior-weight -1 --zprior-sdev 1 --zprior-init-sdev 1.0 --T 3 --L 500 --sigma-init 1.0 --sigma-final 0.01 --num-noise-variables 18 --optimizer-type sgd --learning-rate 5e-6 --momentum 0. 
--num-random-restarts 4 --checkpoint-iter 1 --cuda --max-update-iter -1 --error-threshold 1.0 6 | 7 | python src/compressed_sensing.py --checkpoint-path checkpoints/stylegan2_catdog/network-snapshot-cat50dog50-003481.pt --net stylegan2 --dataset cat50dog50 --input-type full-input --num-input-images 1000 --batch-size 1 --image-size 512 --measurement-type superres --noise-std 0.01 --downsample 512 --model-type langevin --annealed --mloss-weight -1 --zprior-weight -1 --zprior-sdev 1 --zprior-init-sdev 1.0 --T 3 --L 500 --sigma-init 1.0 --sigma-final 0.01 --num-noise-variables 18 --optimizer-type sgd --learning-rate 5e-6 --momentum 0. --num-random-restarts 4 --checkpoint-iter 1 --cuda --max-update-iter -1 --error-threshold 1.0 8 | 9 | -------------------------------------------------------------------------------- /scripts/fairness/cat20.sh: -------------------------------------------------------------------------------- 1 | # 64x undersampling 2 | python src/compressed_sensing.py --checkpoint-path checkpoints/stylegan2_catdog/network-snapshot-cat20dog80-003481.pt --net stylegan2 --dataset cat20dog80 --input-type full-input --num-input-images 625 --batch-size 1 --image-size 512 --measurement-type superres --noise-std 0.1 --downsample 64 --model-type langevin --annealed --mloss-weight -1 --zprior-weight -1 --zprior-sdev 1 --zprior-init-sdev 1.0 --T 3 --L 500 --sigma-init 1.0 --sigma-final 0.1 --num-noise-variables 18 --optimizer-type sgd --learning-rate 5e-6 --momentum 0. --num-random-restarts 4 --checkpoint-iter 1 --cuda --max-update-iter -1 --error-threshold 1.0 3 | 4 | # 128x undersampling 5 | python src/compressed_sensing.py --checkpoint-path checkpoints/stylegan2_catdog/network-snapshot-cat20dog80-003481.pt --net stylegan2 --dataset cat20dog80 --input-type full-input --num-input-images 625 --batch-size 1 --image-size 512 --measurement-type superres --noise-std 0.1 --downsample 128 --model-type langevin --annealed --mloss-weight -1 --zprior-weight -1 --zprior-sdev 1 --zprior-init-sdev 1.0 --T 3 --L 500 --sigma-init 1.0 --sigma-final 0.1 --num-noise-variables 18 --optimizer-type sgd --learning-rate 5e-6 --momentum 0. --num-random-restarts 4 --checkpoint-iter 1 --cuda --max-update-iter -1 --error-threshold 1.0 6 | 7 | # 256x undersampling 8 | python src/compressed_sensing.py --checkpoint-path checkpoints/stylegan2_catdog/network-snapshot-cat20dog80-003481.pt --net stylegan2 --dataset cat20dog80 --input-type full-input --num-input-images 625 --batch-size 1 --image-size 512 --measurement-type superres --noise-std 0.01 --downsample 256 --model-type langevin --annealed --mloss-weight -1 --zprior-weight -1 --zprior-sdev 1 --zprior-init-sdev 1.0 --T 3 --L 500 --sigma-init 1.0 --sigma-final 0.01 --num-noise-variables 18 --optimizer-type sgd --learning-rate 5e-6 --momentum 0. --num-random-restarts 4 --checkpoint-iter 1 --cuda --max-update-iter -1 --error-threshold 1.0 9 | 10 | # 512x undersampling 11 | python src/compressed_sensing.py --checkpoint-path checkpoints/stylegan2_catdog/network-snapshot-cat20dog80-003481.pt --net stylegan2 --dataset cat20dog80 --input-type full-input --num-input-images 625 --batch-size 1 --image-size 512 --measurement-type superres --noise-std 0.01 --downsample 512 --model-type langevin --annealed --mloss-weight -1 --zprior-weight -1 --zprior-sdev 1 --zprior-init-sdev 1.0 --T 3 --L 500 --sigma-init 1.0 --sigma-final 0.01 --num-noise-variables 18 --optimizer-type sgd --learning-rate 5e-6 --momentum 0. 
--num-random-restarts 4 --checkpoint-iter 1 --cuda --max-update-iter -1 --error-threshold 1.0 12 | -------------------------------------------------------------------------------- /stylegan2/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act( 30 | grad_output, empty, out, 3, 1, negative_slope, scale 31 | ) 32 | 33 | dim = [0] 34 | 35 | if grad_input.ndim > 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | grad_bias = grad_input.sum(dim).detach() 39 | 40 | return grad_input, grad_bias 41 | 42 | @staticmethod 43 | def backward(ctx, gradgrad_input, gradgrad_bias): 44 | out, = ctx.saved_tensors 45 | gradgrad_out = fused.fused_bias_act( 46 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 47 | ) 48 | 49 | return gradgrad_out, None, None, None 50 | 51 | 52 | class FusedLeakyReLUFunction(Function): 53 | @staticmethod 54 | def forward(ctx, input, bias, negative_slope, scale): 55 | empty = input.new_empty(0) 56 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 57 | ctx.save_for_backward(out) 58 | ctx.negative_slope = negative_slope 59 | ctx.scale = scale 60 | 61 | return out 62 | 63 | @staticmethod 64 | def backward(ctx, grad_output): 65 | out, = ctx.saved_tensors 66 | 67 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 68 | grad_output, out, ctx.negative_slope, ctx.scale 69 | ) 70 | 71 | return grad_input, grad_bias, None, None 72 | 73 | 74 | class FusedLeakyReLU(nn.Module): 75 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 76 | super().__init__() 77 | 78 | self.bias = nn.Parameter(torch.zeros(channel)) 79 | self.negative_slope = negative_slope 80 | self.scale = scale 81 | 82 | def forward(self, input): 83 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 84 | 85 | 86 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 87 | if input.device.type == "cpu": 88 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 89 | return ( 90 | F.leaky_relu( 91 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 92 | ) 93 | * scale 94 | ) 95 | 96 | else: 97 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 98 | -------------------------------------------------------------------------------- /shuffle_catdog.sh: -------------------------------------------------------------------------------- 1 | mkdir -p test_images/cat10dog90/imgs/ 2 | cd test_images/cat10dog90/imgs/ 3 | shuf -ezn 500 ../../../datasets/afhq/val/dog/* | xargs -0 ln -st . 4 | shuf -ezn 55 ../../../datasets/afhq/val/cat/* | xargs -0 ln -st . 
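# shuffle the symlinked images and rename them 000000.jpg, 000001.jpg, ...
# so each biased validation set is a randomly interleaved cat/dog sequence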
5 | list=$(ls * | shuf); i=0; for name in $list; do new=$(printf "%06d.jpg" "$i"); mv $name $new; let i=i+1; done 6 | 7 | cd ../../../ 8 | mkdir -p test_images/cat20dog80/imgs/ 9 | cd test_images/cat20dog80/imgs/ 10 | shuf -ezn 500 ../../../datasets/afhq/val/dog/* | xargs -0 ln -st . 11 | shuf -ezn 125 ../../../datasets/afhq/val/cat/* | xargs -0 ln -st . 12 | list=$(ls * | shuf); i=0; for name in $list; do new=$(printf "%06d.jpg" "$i"); mv $name $new; let i=i+1; done 13 | 14 | cd ../../../ 15 | mkdir -p test_images/cat30dog70/imgs/ 16 | cd test_images/cat30dog70/imgs/ 17 | shuf -ezn 500 ../../../datasets/afhq/val/dog/* | xargs -0 ln -st . 18 | shuf -ezn 215 ../../../datasets/afhq/val/cat/* | xargs -0 ln -st . 19 | list=$(ls * | shuf); i=0; for name in $list; do new=$(printf "%06d.jpg" "$i"); mv $name $new; let i=i+1; done 20 | 21 | cd ../../../ 22 | mkdir -p test_images/cat40dog60/imgs/ 23 | cd test_images/cat40dog60/imgs/ 24 | shuf -ezn 500 ../../../datasets/afhq/val/dog/* | xargs -0 ln -st . 25 | shuf -ezn 333 ../../../datasets/afhq/val/cat/* | xargs -0 ln -st . 26 | list=$(ls * | shuf); i=0; for name in $list; do new=$(printf "%06d.jpg" "$i"); mv $name $new; let i=i+1; done 27 | 28 | cd ../../../ 29 | mkdir -p test_images/cat50dog50/imgs/ 30 | cd test_images/cat50dog50/imgs/ 31 | shuf -ezn 500 ../../../datasets/afhq/val/dog/* | xargs -0 ln -st . 32 | shuf -ezn 500 ../../../datasets/afhq/val/cat/* | xargs -0 ln -st . 33 | list=$(ls * | shuf); i=0; for name in $list; do new=$(printf "%06d.jpg" "$i"); mv $name $new; let i=i+1; done 34 | 35 | cd ../../../ 36 | mkdir -p test_images/cat60dog40/imgs/ 37 | cd test_images/cat60dog40/imgs/ 38 | shuf -ezn 333 ../../../datasets/afhq/val/dog/* | xargs -0 ln -st . 39 | shuf -ezn 500 ../../../datasets/afhq/val/cat/* | xargs -0 ln -st . 40 | list=$(ls * | shuf); i=0; for name in $list; do new=$(printf "%06d.jpg" "$i"); mv $name $new; let i=i+1; done 41 | 42 | cd ../../../ 43 | mkdir -p test_images/cat70dog30/imgs/ 44 | cd test_images/cat70dog30/imgs/ 45 | shuf -ezn 215 ../../../datasets/afhq/val/dog/* | xargs -0 ln -st . 46 | shuf -ezn 500 ../../../datasets/afhq/val/cat/* | xargs -0 ln -st . 47 | list=$(ls * | shuf); i=0; for name in $list; do new=$(printf "%06d.jpg" "$i"); mv $name $new; let i=i+1; done 48 | 49 | cd ../../../ 50 | mkdir -p test_images/cat80dog20/imgs/ 51 | cd test_images/cat80dog20/imgs/ 52 | shuf -ezn 125 ../../../datasets/afhq/val/dog/* | xargs -0 ln -st . 53 | shuf -ezn 500 ../../../datasets/afhq/val/cat/* | xargs -0 ln -st . 54 | list=$(ls * | shuf); i=0; for name in $list; do new=$(printf "%06d.jpg" "$i"); mv $name $new; let i=i+1; done 55 | 56 | cd ../../../ 57 | mkdir -p test_images/cat90dog10/imgs/ 58 | cd test_images/cat90dog10/imgs/ 59 | shuf -ezn 55 ../../../datasets/afhq/val/dog/* | xargs -0 ln -st . 60 | shuf -ezn 500 ../../../datasets/afhq/val/cat/* | xargs -0 ln -st . 61 | list=$(ls * | shuf); i=0; for name in $list; do new=$(printf "%06d.jpg" "$i"); mv $name $new; let i=i+1; done 62 | 63 | -------------------------------------------------------------------------------- /stylegan2/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include <torch/types.h> 8 | 9 | #include <ATen/ATen.h> 10 | #include <ATen/AccumulateType.h> 11 | #include <ATen/cuda/CUDAContext.h> 12 | #include <ATen/cuda/CUDAApplyUtils.cuh> 13 | 14 | #include <cuda.h> 15 | #include <cuda_runtime.h> 16 | 17 | 18 | template <typename scalar_t> 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>( 81 | y.data_ptr<scalar_t>(), 82 | x.data_ptr<scalar_t>(), 83 | b.data_ptr<scalar_t>(), 84 | ref.data_ptr<scalar_t>(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /scripts/compressed-sensing/ffhq_map.sh: -------------------------------------------------------------------------------- 1 | python src/compressed_sensing.py --checkpoint-path checkpoints/ncsnv2_ffhq/checkpoint_80000.pth --net ncsnv2 --dataset ffhq-69000 --num-input-images 1 --batch-size 1 --ncsnv2-configs-file ./ncsnv2/configs/ffhq.yml --measurement-type circulant --noise-std 4 --num-measurements 5000 --model-type map --print-stats --checkpoint-iter 1 --cuda --mloss-weight 1.0 --learning-rate 9e-7 --sigma-init 348 --sigma-final 0.01 --L 1150 --T 3 2 | 3 | python src/compressed_sensing.py --checkpoint-path checkpoints/ncsnv2_ffhq/checkpoint_80000.pth --net ncsnv2 --dataset ffhq-69000 --num-input-images 1 --batch-size 1 --ncsnv2-configs-file ./ncsnv2/configs/ffhq.yml --measurement-type circulant --noise-std 4 --num-measurements 10000 --model-type map --print-stats --checkpoint-iter 1 --cuda --mloss-weight 1.0 --learning-rate 9e-7 --sigma-init 348 --sigma-final 0.01 --L 1150 --T 3 4 | 5 | python src/compressed_sensing.py --checkpoint-path 
checkpoints/ncsnv2_ffhq/checkpoint_80000.pth --net ncsnv2 --dataset ffhq-69000 --num-input-images 1 --batch-size 1 --ncsnv2-configs-file ./ncsnv2/configs/ffhq.yml --measurement-type circulant --noise-std 4 --num-measurements 15000 --model-type map --print-stats --checkpoint-iter 1 --cuda --mloss-weight 1.0 --learning-rate 9e-7 --sigma-init 348 --sigma-final 0.01 --L 1150 --T 3 6 | 7 | python src/compressed_sensing.py --checkpoint-path checkpoints/ncsnv2_ffhq/checkpoint_80000.pth --net ncsnv2 --dataset ffhq-69000 --num-input-images 1 --batch-size 1 --ncsnv2-configs-file ./ncsnv2/configs/ffhq.yml --measurement-type circulant --noise-std 4 --num-measurements 20000 --model-type map --print-stats --checkpoint-iter 1 --cuda --mloss-weight 1.0 --learning-rate 9e-7 --sigma-init 348 --sigma-final 0.01 --L 2311 --T 3 8 | 9 | python src/compressed_sensing.py --checkpoint-path checkpoints/ncsnv2_ffhq/checkpoint_80000.pth --net ncsnv2 --dataset ffhq-69000 --num-input-images 1 --batch-size 1 --ncsnv2-configs-file ./ncsnv2/configs/ffhq.yml --measurement-type circulant --noise-std 4 --num-measurements 30000 --model-type map --print-stats --checkpoint-iter 1 --cuda --mloss-weight 1.0 --learning-rate 9e-7 --sigma-init 348 --sigma-final 0.001 --L 2311 --T 3 10 | 11 | python src/compressed_sensing.py --checkpoint-path checkpoints/ncsnv2_ffhq/checkpoint_80000.pth --net ncsnv2 --dataset ffhq-69000 --num-input-images 1 --batch-size 1 --ncsnv2-configs-file ./ncsnv2/configs/ffhq.yml --measurement-type circulant --noise-std 4 --num-measurements 40000 --model-type map --print-stats --checkpoint-iter 1 --cuda --mloss-weight 1.0 --learning-rate 9e-7 --sigma-init 348 --sigma-final 0.01 --L 2311 --T 3 12 | 13 | python src/compressed_sensing.py --checkpoint-path checkpoints/ncsnv2_ffhq/checkpoint_80000.pth --net ncsnv2 --dataset ffhq-69000 --num-input-images 1 --batch-size 1 --ncsnv2-configs-file ./ncsnv2/configs/ffhq.yml --measurement-type circulant --noise-std 4 --num-measurements 50000 --model-type map --print-stats --checkpoint-iter 1 --cuda --mloss-weight 1.0 --learning-rate 9e-7 --sigma-init 348 --sigma-final 0.01 --L 2311 --T 3 14 | 15 | python src/compressed_sensing.py --checkpoint-path checkpoints/ncsnv2_ffhq/checkpoint_80000.pth --net ncsnv2 --dataset ffhq-69000 --num-input-images 1 --batch-size 1 --ncsnv2-configs-file ./ncsnv2/configs/ffhq.yml --measurement-type circulant --noise-std 4 --num-measurements 75000 --model-type map --print-stats --checkpoint-iter 1 --cuda --mloss-weight 1.0 --learning-rate 9e-6 --sigma-init 348 --sigma-final 0.01 --L 2311 --T 3 16 | 17 | -------------------------------------------------------------------------------- /scripts/compressed-sensing/ffhq_langevin.sh: -------------------------------------------------------------------------------- 1 | python src/compressed_sensing.py --checkpoint-path checkpoints/ncsnv2_ffhq/checkpoint_80000.pth --net ncsnv2 --dataset ffhq-69000 --num-input-images 1 --batch-size 1 --ncsnv2-configs-file ./ncsnv2/configs/ffhq.yml --measurement-type circulant --noise-std 4.0 --num-measurements 5000 --model-type langevin --print-stats --checkpoint-iter 1 --cuda --mloss-weight 0.1 --learning-rate 9e-6 --sigma-init 348 --sigma-final 0.01 --L 2311 --T 3 2 | 3 | python src/compressed_sensing.py --checkpoint-path checkpoints/ncsnv2_ffhq/checkpoint_80000.pth --net ncsnv2 --dataset ffhq-69000 --num-input-images 1 --batch-size 1 --ncsnv2-configs-file ./ncsnv2/configs/ffhq.yml --measurement-type circulant --noise-std 4.0 --num-measurements 10000 --model-type langevin 
--print-stats --checkpoint-iter 1 --cuda --mloss-weight 0.1 --learning-rate 9e-6 --sigma-init 348 --sigma-final 0.01 --L 2311 --T 3 4 | 5 | python src/compressed_sensing.py --checkpoint-path checkpoints/ncsnv2_ffhq/checkpoint_80000.pth --net ncsnv2 --dataset ffhq-69000 --num-input-images 1 --batch-size 1 --ncsnv2-configs-file ./ncsnv2/configs/ffhq.yml --measurement-type circulant --noise-std 4.0 --num-measurements 15000 --model-type langevin --print-stats --checkpoint-iter 1 --cuda --mloss-weight 0.1 --learning-rate 9e-6 --sigma-init 348 --sigma-final 0.01 --L 2311 --T 3 6 | 7 | python src/compressed_sensing.py --checkpoint-path checkpoints/ncsnv2_ffhq/checkpoint_80000.pth --net ncsnv2 --dataset ffhq-69000 --num-input-images 1 --batch-size 1 --ncsnv2-configs-file ./ncsnv2/configs/ffhq.yml --measurement-type circulant --noise-std 4.0 --num-measurements 20000 --model-type langevin --print-stats --checkpoint-iter 1 --cuda --mloss-weight 0.1 --learning-rate 9e-6 --sigma-init 348 --sigma-final 0.01 --L 2311 --T 3 8 | 9 | python src/compressed_sensing.py --checkpoint-path checkpoints/ncsnv2_ffhq/checkpoint_80000.pth --net ncsnv2 --dataset ffhq-69000 --num-input-images 1 --batch-size 1 --ncsnv2-configs-file ./ncsnv2/configs/ffhq.yml --measurement-type circulant --noise-std 4.0 --num-measurements 30000 --model-type langevin --print-stats --checkpoint-iter 1 --cuda --mloss-weight 0.09 --learning-rate 9e-7 --sigma-init 348 --sigma-final 0.001 --L 2311 --T 3 10 | 11 | python src/compressed_sensing.py --checkpoint-path checkpoints/ncsnv2_ffhq/checkpoint_80000.pth --net ncsnv2 --dataset ffhq-69000 --num-input-images 1 --batch-size 1 --ncsnv2-configs-file ./ncsnv2/configs/ffhq.yml --measurement-type circulant --noise-std 4.0 --num-measurements 40000 --model-type langevin --print-stats --checkpoint-iter 1 --cuda --mloss-weight 0.1 --learning-rate 9e-7 --sigma-init 348 --sigma-final 0.001 --L 2311 --T 3 12 | 13 | python src/compressed_sensing.py --checkpoint-path checkpoints/ncsnv2_ffhq/checkpoint_80000.pth --net ncsnv2 --dataset ffhq-69000 --num-input-images 1 --batch-size 1 --ncsnv2-configs-file ./ncsnv2/configs/ffhq.yml --measurement-type circulant --noise-std 4.0 --num-measurements 50000 --model-type langevin --print-stats --checkpoint-iter 1 --cuda --mloss-weight 0.1 --learning-rate 9e-7 --sigma-init 348 --sigma-final 0.001 --L 2311 --T 3 14 | 15 | python src/compressed_sensing.py --checkpoint-path checkpoints/ncsnv2_ffhq/checkpoint_80000.pth --net ncsnv2 --dataset ffhq-69000 --num-input-images 1 --batch-size 1 --ncsnv2-configs-file ./ncsnv2/configs/ffhq.yml --measurement-type circulant --noise-std 4.0 --num-measurements 75000 --model-type langevin --print-stats --checkpoint-iter 1 --cuda --mloss-weight 0.1 --learning-rate 9e-7 --sigma-init 348 --sigma-final 0.001 --L 2311 --T 3 16 | 17 | -------------------------------------------------------------------------------- /include/wavelet.py: -------------------------------------------------------------------------------- 1 | #import matplotlib.pyplot as plt 2 | import numpy as np 3 | import numbers 4 | import pywt 5 | import scipy 6 | import skimage.color as color 7 | from skimage.restoration import (denoise_wavelet, estimate_sigma) 8 | from skimage import data, img_as_float 9 | from skimage.util import random_noise 10 | from skimage.measure import compare_psnr 11 | from include import * 12 | 13 | def _wavelet_threshold(image, wavelet, ncoeff = None, threshold=None, mode='soft', wavelet_levels=None): 14 | 15 | wavelet = pywt.Wavelet(wavelet) 16 | 17 | # 
original_extent is used to workaround PyWavelets issue #80 18 | # odd-sized input results in an image with 1 extra sample after waverecn 19 | original_extent = [slice(s) for s in image.shape] 20 | 21 | # Determine the number of wavelet decomposition levels 22 | if wavelet_levels is None: 23 | # Determine the maximum number of possible levels for image 24 | dlen = wavelet.dec_len 25 | wavelet_levels = np.min( 26 | [pywt.dwt_max_level(s, dlen) for s in image.shape]) 27 | 28 | # Skip coarsest wavelet scales (see Notes in docstring). 29 | wavelet_levels = max(wavelet_levels - 3, 1) 30 | 31 | coeffs = pywt.wavedecn(image, wavelet=wavelet, level=wavelet_levels) 32 | # Detail coefficients at each decomposition level 33 | dcoeffs = coeffs[1:] 34 | 35 | a = [] 36 | for level in dcoeffs: 37 | for key in level: 38 | a += [np.ndarray.flatten(level[key])] 39 | a = np.concatenate(a) 40 | a = np.sort( np.abs(a) ) 41 | 42 | sh = coeffs[0].shape 43 | basecoeffs = sh[0]*sh[1] 44 | threshold = a[- (ncoeff - basecoeffs)] 45 | 46 | # A single threshold for all coefficient arrays 47 | denoised_detail = [{key: pywt.threshold(level[key],value=threshold, 48 | mode=mode) for key in level} for level in dcoeffs] 49 | 50 | denoised_coeffs = [coeffs[0]] + denoised_detail 51 | return pywt.waverecn(denoised_coeffs, wavelet)[original_extent] 52 | 53 | 54 | def denoise_wavelet(image, ncoeff=None, wavelet='db1', mode='hard', 55 | wavelet_levels=None, multichannel=False, 56 | convert2ycbcr=False): 57 | 58 | image = img_as_float(image) 59 | 60 | 61 | if multichannel: 62 | if convert2ycbcr: 63 | out = color.rgb2ycbcr(image) 64 | for i in range(3): 65 | # renormalizing this color channel to live in [0, 1] 66 | min, max = out[..., i].min(), out[..., i].max() 67 | channel = out[..., i] - min 68 | channel /= max - min 69 | out[..., i] = denoise_wavelet(channel, wavelet=wavelet,ncoeff=ncoeff, 70 | mode=mode, 71 | wavelet_levels=wavelet_levels) 72 | 73 | out[..., i] = out[..., i] * (max - min) 74 | out[..., i] += min 75 | out = color.ycbcr2rgb(out) 76 | else: 77 | out = np.empty_like(image) 78 | for c in range(image.shape[-1]): 79 | out[..., c] = _wavelet_threshold(image[..., c],ncoeff=ncoeff, 80 | wavelet=wavelet, mode=mode, 81 | wavelet_levels=wavelet_levels) 82 | else: 83 | out = _wavelet_threshold(image, wavelet=wavelet, mode=mode,ncoeff=ncoeff, 84 | wavelet_levels=wavelet_levels) 85 | 86 | clip_range = (-1, 1) if image.min() < 0 else (0, 1) 87 | return np.clip(out, *clip_range) 88 | 89 | 90 | -------------------------------------------------------------------------------- /src/view_estimated_ffhq_cs.py: -------------------------------------------------------------------------------- 1 | """View estimated images for ffhq-69000""" 2 | # pylint: disable = C0301, R0903, R0902 3 | 4 | import numpy as np 5 | import utils 6 | import matplotlib.pyplot as plt 7 | import pickle as pkl 8 | from metrics_utils import int_or_float, find_best 9 | import glob 10 | 11 | class Hparams(object): 12 | """Hyperparameters""" 13 | def __init__(self): 14 | self.input_type = 'full-input' 15 | self.input_path_pattern = './test_images/ffhq-69000' 16 | self.input_path = './test_images/ffhq-69000' 17 | self.num_input_images = 30 18 | self.image_matrix = 0 19 | self.image_shape = (3,256,256) 20 | self.image_size = 256 21 | self.noise_std = 4.0 22 | self.n_input = np.prod(self.image_shape) 23 | self.measurement_type = 'circulant' 24 | self.model_types = ['MAP', 'Langevin'] 25 | 26 | 27 | def view(xs_dict, patterns_images, patterns_l2, images_nums, hparams, 
**kws): 28 | """View the images""" 29 | x_hats_dict = {} 30 | l2_dict = {} 31 | for model_type, pattern_image, pattern_l2 in zip(hparams.model_types, patterns_images, patterns_l2): 32 | outfiles = [pattern_image.format(i) for i in images_nums] 33 | x_hats_dict[model_type] = {i: plt.imread(outfile) for i, outfile in enumerate(outfiles)} 34 | with open(pattern_l2, 'rb') as f: 35 | l2_dict[model_type] = pkl.load(f) 36 | xs_dict_temp = {i : xs_dict[i] for i in images_nums} 37 | utils.image_matrix(xs_dict_temp, x_hats_dict, l2_dict, hparams, **kws) 38 | 39 | 40 | def get_image_nums(start, stop, hparams): 41 | """Get range of images""" 42 | assert start >= 0 43 | assert stop <= hparams.num_input_images 44 | images_nums = list(range(start, stop)) 45 | return images_nums 46 | 47 | 48 | def main(): 49 | """Make and save image matrices""" 50 | hparams = Hparams() 51 | xs_dict = utils.model_input(hparams) 52 | start, stop = 0, 1 53 | images_nums = get_image_nums(start, stop, hparams) 54 | is_save = True 55 | 56 | def formatted(f): 57 | return format(f, '.4f').rstrip('0').rstrip('.') 58 | legend_base_regexs = [ 59 | ('MAP', 60 | f'./estimated/ffhq-69000/full-input/circulant/{hparams.noise_std}/', 61 | '/ncsnv2/map/*'), 62 | ('Langevin', 63 | f'./estimated/ffhq-69000/full-input/circulant/{hparams.noise_std}/', 64 | '/ncsnv2/langevin/*') 65 | 66 | ] 67 | retrieve_list = [['l2', 'mean'], ['l2', 'std']] 68 | 69 | for num_measurements in [5000,10000,15000,20000,30000,40000,50000,75000]: 70 | patterns_images, patterns_l2 = [], [] 71 | exists = True 72 | for legend, base, regex in legend_base_regexs: 73 | keys = map(int_or_float, [a.split('/')[-1] for a in glob.glob(base + '*')]) 74 | list_keys = [key for key in keys] 75 | if num_measurements not in list_keys: 76 | exists = False 77 | break 78 | pattern = base + str(num_measurements) + regex 79 | criterion = ['l2', 'mean'] 80 | 81 | _, best_dir = find_best(pattern, criterion, retrieve_list) 82 | pattern_images = best_dir + '/images/{:06d}.png' 83 | pattern_l2 = best_dir + '/l2_losses.pkl' 84 | patterns_images.append(pattern_images) 85 | patterns_l2.append(pattern_l2) 86 | if exists: 87 | view(xs_dict, patterns_images, patterns_l2, images_nums, hparams) 88 | save_path = f'./results/ffhq-69000-reconstr-{num_measurements}-{criterion[0]}.pdf' 89 | utils.save_plot(is_save, save_path) 90 | else: 91 | print(f'Could not find reconstructions for {num_measurements}') 92 | 93 | if __name__ == '__main__': 94 | main() 95 | -------------------------------------------------------------------------------- /src/view_estimated_celebA_cs.py: -------------------------------------------------------------------------------- 1 | """View estimated images for celebA""" 2 | # pylint: disable = C0301, R0903, R0902 3 | 4 | import numpy as np 5 | import utils 6 | import matplotlib.pyplot as plt 7 | import pickle as pkl 8 | from metrics_utils import int_or_float, find_best 9 | import glob 10 | 11 | class Hparams(object): 12 | """Hyperparameters""" 13 | def __init__(self): 14 | self.input_type = 'full-input' 15 | self.input_path_pattern = './test_images/celebA' 16 | self.input_path = './test_images/celebA' 17 | self.num_input_images = 30 18 | self.image_matrix = 0 19 | self.image_shape = (3,256,256) 20 | self.image_size = 256 21 | self.noise_std = 16.0 22 | self.n_input = np.prod(self.image_shape) 23 | self.measurement_type = 'circulant' 24 | self.model_types = ['MAP', 'Langevin'] 25 | 26 | 27 | def view(xs_dict, patterns_images, patterns_l2, images_nums, hparams, **kws): 28 | """View 
the images""" 29 | x_hats_dict = {} 30 | l2_dict = {} 31 | for model_type, pattern_image, pattern_l2 in zip(hparams.model_types, patterns_images, patterns_l2): 32 | outfiles = [pattern_image.format(i) for i in images_nums] 33 | x_hats_dict[model_type] = {i: plt.imread(outfile) for i, outfile in enumerate(outfiles)} 34 | with open(pattern_l2, 'rb') as f: 35 | l2_dict[model_type] = pkl.load(f) 36 | xs_dict_temp = {i : xs_dict[i] for i in images_nums} 37 | utils.image_matrix(xs_dict_temp, x_hats_dict, l2_dict, hparams, **kws) 38 | 39 | 40 | def get_image_nums(start, stop, hparams): 41 | """Get range of images""" 42 | assert start >= 0 43 | assert stop <= hparams.num_input_images 44 | images_nums = list(range(start, stop)) 45 | return images_nums 46 | 47 | 48 | def main(): 49 | """Make and save image matrices""" 50 | hparams = Hparams() 51 | xs_dict = utils.model_input(hparams) 52 | start, stop = 0, 1 53 | images_nums = get_image_nums(start, stop, hparams) 54 | is_save = True 55 | 56 | def formatted(f): 57 | return format(f, '.4f').rstrip('0').rstrip('.') 58 | legend_base_regexs = [ 59 | ('MAP', 60 | f'./estimated/celebA/full-input/circulant/{hparams.noise_std}/', 61 | '/glow/map/None_None*'), 62 | ('Langevin', 63 | f'./estimated*/celebA/full-input/circulant/{hparams.noise_std}/', 64 | '/glow/*langevin/None*'), 65 | ] 66 | retrieve_list = [['l2', 'mean'], ['l2', 'std']] 67 | 68 | for num_measurements in [2500,5000,10000,20000,30000,35000] : 69 | patterns_images, patterns_l2 = [], [] 70 | exists = True 71 | for legend, base, regex in legend_base_regexs: 72 | keys = map(int_or_float, [a.split('/')[-1] for a in glob.glob(base + '*')]) 73 | list_keys = [key for key in keys] 74 | if num_measurements not in list_keys: 75 | exists = False 76 | break 77 | pattern = base + str(num_measurements) + regex 78 | if legend in ['Modified-MAP']: 79 | criterion = ['l2', 'mean'] 80 | else: 81 | criterion = ['likelihood', 'mean'] 82 | 83 | _, best_dir = find_best(pattern, criterion, retrieve_list) 84 | pattern_images = best_dir + '/images/{:06d}.png' 85 | pattern_l2 = best_dir + '/l2_losses.pkl' 86 | patterns_images.append(pattern_images) 87 | patterns_l2.append(pattern_l2) 88 | if exists: 89 | view(xs_dict, patterns_images, patterns_l2, images_nums, hparams) 90 | save_path = f'./results/celebA-reconstr-{num_measurements}-{criterion[0]}.pdf' 91 | utils.save_plot(is_save, save_path) 92 | else: 93 | print(f'Could not find reconstructions for {num_measurements}') 94 | 95 | if __name__ == '__main__': 96 | main() 97 | -------------------------------------------------------------------------------- /include/helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | import sys 5 | 6 | import numpy as np 7 | from PIL import Image 8 | import PIL 9 | import numpy as np 10 | 11 | from torch.autograd import Variable 12 | 13 | 14 | 15 | 16 | 17 | import numpy as np 18 | import torch 19 | import matplotlib.pyplot as plt 20 | 21 | from PIL import Image 22 | import PIL 23 | 24 | from torch.autograd import Variable 25 | 26 | def load_and_crop(imgname,target_width=512,target_height=512): 27 | ''' 28 | imgname: string of image location 29 | load an image, and center-crop if the image is large enough, else return none 30 | ''' 31 | img = Image.open(imgname) 32 | width, height = img.size 33 | if width <= target_width or height <= target_height: 34 | return None 35 | 36 | left = (width - target_width)/2 37 | top = (height - 
target_height)/2 38 | right = (width + target_width)/2 39 | bottom = (height + target_height)/2 40 | 41 | return img.crop((left, top, right, bottom)) 42 | 43 | def save_np_img(img,filename): 44 | if(img.shape[0] == 1): 45 | plt.imshow(np.clip(img[0],0,1),cmap='Greys',interpolation='nearest') 46 | else: 47 | plt.imshow(np.clip(img.transpose(1, 2, 0),0,1)) 48 | plt.axis('off') 49 | plt.savefig(filename, bbox_inches='tight') 50 | plt.close() 51 | 52 | def np_to_tensor(img_np): 53 | '''Converts image in numpy.array to torch.Tensor. 54 | 55 | From C x W x H [0..1] to C x W x H [0..1] 56 | ''' 57 | return torch.from_numpy(img_np) 58 | 59 | def np_to_var(img_np, dtype = torch.cuda.FloatTensor): 60 | '''Converts image in numpy.array to torch.Variable. 61 | 62 | From C x W x H [0..1] to 1 x C x W x H [0..1] 63 | ''' 64 | return Variable(np_to_tensor(img_np)[None, :]) 65 | 66 | def var_to_np(img_var): 67 | '''Converts an image in torch.Variable format to np.array. 68 | 69 | From 1 x C x W x H [0..1] to C x W x H [0..1] 70 | ''' 71 | return img_var.data.cpu().numpy()[0] 72 | 73 | 74 | def pil_to_np(img_PIL): 75 | '''Converts image in PIL format to np.array. 76 | 77 | From W x H x C [0...255] to C x W x H [0..1] 78 | ''' 79 | ar = np.array(img_PIL) 80 | 81 | if len(ar.shape) == 3: 82 | ar = ar.transpose(2,0,1) 83 | else: 84 | ar = ar[None, ...] 85 | 86 | return ar.astype(np.float32) / 255. 87 | 88 | 89 | def rgb2ycbcr(img): 90 | #out = color.rgb2ycbcr( img.transpose(1, 2, 0) ) 91 | #return out.transpose(2,0,1)/256. 92 | r,g,b = img[0],img[1],img[2] 93 | y = 0.299*r+0.587*g+0.114*b 94 | cb = 0.5 - 0.168736*r - 0.331264*g + 0.5*b 95 | cr = 0.5 + 0.5*r - 0.418588*g - 0.081312*b 96 | return np.array([y,cb,cr]) 97 | 98 | def ycbcr2rgb(img): 99 | #out = color.ycbcr2rgb( 256.*img.transpose(1, 2, 0) ) 100 | #return (out.transpose(2,0,1) - np.min(out))/(np.max(out)-np.min(out)) 101 | y,cb,cr = img[0],img[1],img[2] 102 | r = y + 1.402*(cr-0.5) 103 | g = y - 0.344136*(cb-0.5) - 0.714136*(cr-0.5) 104 | b = y + 1.772*(cb - 0.5) 105 | return np.array([r,g,b]) 106 | 107 | 108 | 109 | def mse(x_hat,x_true,maxv=1.): 110 | x_hat = x_hat.flatten() 111 | x_true = x_true.flatten() 112 | mse = np.mean(np.square(x_hat-x_true)) 113 | energy = np.mean(np.square(x_true)) 114 | return mse/energy 115 | 116 | def psnr(x_hat,x_true,maxv=1.): 117 | x_hat = x_hat.flatten() 118 | x_true = x_true.flatten() 119 | mse=np.mean(np.square(x_hat-x_true)) 120 | psnr_ = 10.*np.log(maxv**2/mse)/np.log(10.) 
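# Note: psnr_ above is the standard PSNR in decibels, 10*log10(maxv**2 / MSE),
# written with natural logs since np.log is base-e. Unlike psnr(), the mse()
# helper above returns MSE normalized by the energy of x_true, i.e. a relative
# squared error rather than a raw MSE.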
121 | return psnr_ 122 | 123 | def num_param(net): 124 | s = sum([np.prod(list(p.size())) for p in net.parameters()]); 125 | return s 126 | #print('Number of params: %d' % s) 127 | 128 | def rgb2gray(rgb): 129 | r, g, b = rgb[0,:,:], rgb[1,:,:], rgb[2,:,:] 130 | gray = 0.2989 * r + 0.5870 * g + 0.1140 * b 131 | return np.array([gray]) 132 | 133 | def savemtx_for_logplot(A,filename = "exp.dat"): 134 | ind = sorted(list(set([int(i) for i in np.geomspace(1, len(A[0])-1 ,num=700)]))) 135 | A = [ [a[i] for i in ind] for a in A] 136 | X = np.array([ind] + A) 137 | np.savetxt(filename, X.T, delimiter=' ') 138 | -------------------------------------------------------------------------------- /src/metrics_utils.py: -------------------------------------------------------------------------------- 1 | """Some utils for plotting metrics""" 2 | # pylint: disable = C0111 3 | 4 | 5 | import glob 6 | import numpy as np 7 | import utils 8 | import matplotlib.pyplot as plt 9 | import matplotlib 10 | matplotlib.rcParams['pdf.fonttype'] = 42 11 | matplotlib.rcParams['ps.fonttype'] = 42 12 | 13 | 14 | 15 | def int_or_float(val): 16 | try: 17 | return int(val) 18 | except ValueError: 19 | return float(val) 20 | 21 | 22 | def get_figsize(is_save): 23 | if is_save: 24 | figsize = [6, 4] 25 | else: 26 | figsize = None 27 | return figsize 28 | 29 | def get_data(expt_dir): 30 | data = {} 31 | measurement_losses = utils.load_if_pickled(expt_dir + '/measurement_losses.pkl') 32 | l2_losses = utils.load_if_pickled(expt_dir + '/l2_losses.pkl') 33 | likelihoods = utils.load_if_pickled(expt_dir + '/likelihoods.pkl') 34 | z_norms = utils.load_if_pickled(expt_dir + '/z_norms.pkl') 35 | data = {'measurement': measurement_losses.values(), 36 | 'l2': l2_losses.values(), 37 | 'likelihood': likelihoods.values(), 38 | 'norm': z_norms.values()} 39 | return data 40 | 41 | 42 | def get_metrics(expt_dir): 43 | data = get_data(expt_dir) 44 | 45 | metrics = {} 46 | 47 | measurement_list = np.array(list(data['measurement'])) 48 | measurement_list = measurement_list/(3*256*256) 49 | m_loss_mean = np.mean(measurement_list) 50 | m_loss_std = np.std(measurement_list) / np.sqrt(len(data['measurement'])) 51 | metrics['measurement'] = {'mean': m_loss_mean, 'std': m_loss_std} 52 | 53 | l2_list = list(data['l2']) 54 | l2_loss_mean = np.mean(l2_list) 55 | l2_loss_std = np.std(np.array(l2_list)) / np.sqrt(len(data['l2'])) 56 | metrics['l2'] = {'mean':l2_loss_mean, 'std':l2_loss_std} 57 | 58 | 59 | likelihood_list = list(data['likelihood']) 60 | likelihood_loss_mean = np.mean(likelihood_list) 61 | likelihood_loss_std = np.std(np.array(likelihood_list)) / np.sqrt(len(data['likelihood'])) 62 | metrics['likelihood'] = {'mean':likelihood_loss_mean, 'std':likelihood_loss_std} 63 | 64 | norm_list = list(data['norm']) 65 | norm_loss_mean = np.mean(norm_list) 66 | norm_loss_std = np.std(np.array(norm_list)) / np.sqrt(len(data['norm'])) 67 | metrics['norm'] = {'mean':norm_loss_mean, 'std':norm_loss_std} 68 | return metrics 69 | 70 | 71 | def get_expt_metrics(expt_dirs): 72 | expt_metrics = {} 73 | for expt_dir in expt_dirs: 74 | metrics = get_metrics(expt_dir) 75 | expt_metrics[expt_dir] = metrics 76 | return expt_metrics 77 | 78 | 79 | def get_nested_value(dic, field): 80 | answer = dic 81 | for key in field: 82 | answer = answer[key] 83 | return answer 84 | 85 | 86 | def find_best(pattern, criterion, retrieve_list): 87 | dirs = glob.glob(pattern) 88 | metrics = get_expt_metrics(dirs) 89 | best_merit = 1e10 90 | answer = [None]*len(retrieve_list) 91 | for 
dir, val in metrics.items(): 92 | merit = get_nested_value(val, criterion) 93 | if merit < best_merit: 94 | best_merit = merit 95 | best_dir = dir 96 | for i, field in enumerate(retrieve_list): 97 | answer[i] = get_nested_value(val, field) 98 | 99 | try: 100 | print(best_dir) 101 | except: 102 | best_dir = None 103 | pass 104 | return answer, best_dir 105 | 106 | 107 | def plot(base, regex, criterion, retrieve_list, label): 108 | keys = map(int_or_float, [a.split('/')[-1] for a in glob.glob(base + '*')]) 109 | means, std_devs = {}, {} 110 | for i, key in enumerate(keys): 111 | pattern = base + str(key) + regex 112 | answer, _ = find_best(pattern, criterion, retrieve_list) 113 | if answer[0] is not None: 114 | means[key], std_devs[key] = answer 115 | plot_keys = sorted(means.keys()) 116 | if retrieve_list[0][0] != 'measurement': 117 | means = np.asarray([ means[key] for key in plot_keys]) 118 | std_devs = np.asarray([ std_devs[key] for key in plot_keys]) 119 | (lines, caps, _) = plt.errorbar(plot_keys, means, yerr=1.96*std_devs, 120 | marker='o', markersize=5, capsize=5, label=label) 121 | 122 | elif retrieve_list[0][0] == 'measurement': 123 | means = np.asarray([means[key] for key in plot_keys]) 124 | std_devs = np.asarray([std_devs[key] for key in plot_keys]) 125 | (lines, caps, _) = plt.errorbar(plot_keys, means, yerr=1.96*std_devs, 126 | fmt=':^', markersize=5, capsize=5, label=label) 127 | 128 | for cap in caps: 129 | cap.set_markeredgewidth(1) 130 | return lines.get_color() 131 | -------------------------------------------------------------------------------- /include/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def add_module(self, module): 5 | self.add_module(str(len(self) + 1), module) 6 | 7 | torch.nn.Module.add = add_module 8 | 9 | 10 | def conv(in_f, out_f, kernel_size, stride=1, pad='zero'): 11 | padder = None 12 | to_pad = int((kernel_size - 1) / 2) 13 | if pad == 'reflection': 14 | padder = nn.ReflectionPad2d(to_pad) 15 | to_pad = 0 16 | 17 | convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=False) 18 | 19 | layers = filter(lambda x: x is not None, [padder, convolver]) 20 | return nn.Sequential(*layers) 21 | 22 | def decodernw( 23 | num_output_channels=3, 24 | num_channels_up=[128]*5, 25 | filter_size_up=1, 26 | need_sigmoid=True, 27 | pad ='reflection', 28 | upsample_mode='bilinear', 29 | act_fun=nn.ReLU(), # nn.LeakyReLU(0.2, inplace=True) 30 | bn_before_act = False, 31 | bn_affine = True, 32 | upsample_first = True, 33 | ): 34 | 35 | num_channels_up = num_channels_up + [num_channels_up[-1],num_channels_up[-1]] 36 | n_scales = len(num_channels_up) 37 | 38 | if not (isinstance(filter_size_up, list) or isinstance(filter_size_up, tuple)) : 39 | filter_size_up = [filter_size_up]*n_scales 40 | model = nn.Sequential() 41 | 42 | 43 | for i in range(len(num_channels_up)-1): 44 | 45 | if upsample_first: 46 | model.add(conv( num_channels_up[i], num_channels_up[i+1], filter_size_up[i], 1, pad=pad)) 47 | if upsample_mode!='none' and i != len(num_channels_up)-2: 48 | model.add(nn.Upsample(scale_factor=2, mode=upsample_mode)) 49 | #model.add(nn.functional.interpolate(size=None,scale_factor=2, mode=upsample_mode)) 50 | else: 51 | if upsample_mode!='none' and i!=0: 52 | model.add(nn.Upsample(scale_factor=2, mode=upsample_mode)) 53 | #model.add(nn.functional.interpolate(size=None,scale_factor=2, mode=upsample_mode)) 54 | model.add(conv( num_channels_up[i], 
num_channels_up[i+1], filter_size_up[i], 1, pad=pad)) 55 | 56 | if i != len(num_channels_up)-1: 57 | if(bn_before_act): 58 | model.add(nn.BatchNorm2d( num_channels_up[i+1] ,affine=bn_affine)) 59 | model.add(act_fun) 60 | if(not bn_before_act): 61 | model.add(nn.BatchNorm2d( num_channels_up[i+1], affine=bn_affine)) 62 | 63 | model.add(conv( num_channels_up[-1], num_output_channels, 1, pad=pad)) 64 | if need_sigmoid: 65 | model.add(nn.Sigmoid()) 66 | 67 | return model 68 | 69 | 70 | 71 | # Residual block 72 | class ResidualBlock(nn.Module): 73 | def __init__(self, in_f, out_f): 74 | super(ResidualBlock, self).__init__() 75 | self.conv = nn.Conv2d(in_f, out_f, 1, 1, padding=0, bias=False) 76 | 77 | def forward(self, x): 78 | residual = x 79 | out = self.conv(x) 80 | out += residual 81 | return out 82 | 83 | def resdecoder( 84 | num_output_channels=3, 85 | num_channels_up=[128]*5, 86 | filter_size_up=1, 87 | need_sigmoid=True, 88 | pad='reflection', 89 | upsample_mode='bilinear', 90 | act_fun=nn.ReLU(), # nn.LeakyReLU(0.2, inplace=True) 91 | bn_before_act = False, 92 | bn_affine = True, 93 | ): 94 | 95 | num_channels_up = num_channels_up + [num_channels_up[-1],num_channels_up[-1]] 96 | n_scales = len(num_channels_up) 97 | 98 | if not (isinstance(filter_size_up, list) or isinstance(filter_size_up, tuple)) : 99 | filter_size_up = [filter_size_up]*n_scales 100 | 101 | model = nn.Sequential() 102 | 103 | for i in range(len(num_channels_up)-2): 104 | 105 | model.add( ResidualBlock( num_channels_up[i], num_channels_up[i+1]) ) 106 | 107 | if upsample_mode!='none': 108 | model.add(nn.Upsample(scale_factor=2, mode=upsample_mode)) 109 | #model.add(nn.functional.interpolate(size=None,scale_factor=2, mode=upsample_mode)) 110 | 111 | if i != len(num_channels_up)-1: 112 | model.add(act_fun) 113 | #model.add(nn.BatchNorm2d( num_channels_up[i+1], affine=bn_affine)) 114 | 115 | # new 116 | model.add(ResidualBlock( num_channels_up[-1], num_channels_up[-1])) 117 | #model.add(nn.BatchNorm2d( num_channels_up[-1] ,affine=bn_affine)) 118 | model.add(act_fun) 119 | # end new 120 | 121 | model.add(conv( num_channels_up[-1], num_output_channels, 1, pad=pad)) 122 | 123 | if need_sigmoid: 124 | model.add(nn.Sigmoid()) 125 | 126 | return model 127 | 128 | 129 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Instance-Optimal Compressed Sensing via Posterior Sampling & Fairness for Image Generation with Uncertain Sensitive Attributes 2 | 3 | This repo contains code for our papers [Instance-Optimal Compressed Sensing via Posterior Sampling](https://arxiv.org/abs/2106.11438) & [Fairness for Image Generation with Uncertain Sensitive Attributes]() 4 | 5 | NOTE: Please run **all** commands from the root directory of the repository, i.e from ```code-cs-fairness/``` 6 | 7 | ## Preliminaries 8 | 9 | 1. Clone repo and install dependencies 10 | 11 | ```shell 12 | $ git clone git@github.com:ajiljalal/code-cs-fairness.git 13 | $ cd code-cs-fairness 14 | $ python3.6 -m venv env 15 | $ source env/bin/activate 16 | $ pip install -U pip 17 | $ pip install -r requirements.txt 18 | $ git submodule update --init --recursive 19 | ``` 20 | 21 | 2. 
Download data, checkpoints, and setup validation images 22 | ```shell 23 | $ bash download.sh 24 | $ bash shuffle_catdog.sh 25 | ``` 26 | 27 | ## Reproducing quantitative results 28 | The scripts for compressed sensing results are in ```scripts/compressed-sensing```, and the scripts for fairness are in ```scripts/fairness```. 29 | Please adjust the command line arguments according to your requirements. ```--num-input-images``` and ```--batch-size``` need to be adjusted according to your compute capabilities and requirements. 30 | 31 | ## Visualizing results 32 | The files ```src/view_estimated_celebA_cs.py```, ```src/view_estimated_ffhq_cs.py``` will plot qualitative reconstructions for compressed sensing. The Jupyter notebook ```src/cs_metrics.ipynb``` will plot quantitative metrics. 33 | 34 | A similar notebook for fairness will be added shortly. 35 | 36 | You can manually access the results under appropriately named folders in ```estimated/```. 37 | 38 | ## Citations 39 | 40 | If you find this repo helpful, please cite the following papers: 41 | ``` 42 | @article{jalal2021instance, 43 | title={Instance-Optimal Compressed Sensing via Posterior Sampling}, 44 | author={Jalal, Ajil and Karmalkar, Sushrut and Dimakis, Alexandros G and Price, Eric}, 45 | journal={arXiv preprint arXiv:2106.11438}, 46 | year={2021} 47 | } 48 | 49 | @inproceedings{jalal2021fairness, 50 | title={Fairness for Image Generation with Uncertain Sensitive Attributes}, 51 | author={Jalal, Ajil and Karmalkar, Sushrut and Hoffmann, Jessica and Dimakis, Alex and Price, Eric}, 52 | booktitle={International Conference on Machine Learning}, 53 | pages={4721--4732}, 54 | year={2021}, 55 | organization={PMLR} 56 | } 57 | ``` 58 | 59 | Our work uses data, code, and models from the following prior work, which must be cited according to what you use: 60 | ``` 61 | @inproceedings{song2020improved, 62 | author = {Yang Song and Stefano Ermon}, 63 | editor = {Hugo Larochelle and 64 | Marc'Aurelio Ranzato and 65 | Raia Hadsell and 66 | Maria{-}Florina Balcan and 67 | Hsuan{-}Tien Lin}, 68 | title = {Improved Techniques for Training Score-Based Generative Models}, 69 | booktitle = {Advances in Neural Information Processing Systems 33: Annual Conference 70 | on Neural Information Processing Systems 2020, NeurIPS 2020, December 71 | 6-12, 2020, virtual}, 72 | year = {2020} 73 | } 74 | 75 | @article{kingma2018glow, 76 | title={Glow: Generative flow with invertible 1x1 convolutions}, 77 | author={Kingma, Diederik P and Dhariwal, Prafulla}, 78 | journal={arXiv preprint arXiv:1807.03039}, 79 | year={2018} 80 | } 81 | 82 | @inproceedings{Karras2020ada, 83 | title = {Training Generative Adversarial Networks with Limited Data}, 84 | author = {Tero Karras and Miika Aittala and Janne Hellsten and Samuli Laine and Jaakko Lehtinen and Timo Aila}, 85 | booktitle = {Proc. 
NeurIPS}, 86 | year = {2020} 87 | } 88 | 89 | @inproceedings{choi2020starganv2, 90 | title={StarGAN v2: Diverse Image Synthesis for Multiple Domains}, 91 | author={Yunjey Choi and Youngjung Uh and Jaejun Yoo and Jung-Woo Ha}, 92 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 93 | year={2020} 94 | } 95 | 96 | @article{heckel_deep_2018, 97 | author = {Reinhard Heckel and Paul Hand}, 98 | title = {Deep Decoder: Concise Image Representations from Untrained Non-convolutional Networks}, 99 | journal = {International Conference on Learning Representations}, 100 | year = {2019} 101 | } 102 | 103 | ``` 104 | 105 | 106 | ## Acknowledgments 107 | The FFHQ NCSNv2 model was obtained from the official repo: [https://github.com/ermongroup/ncsnv2](https://github.com/ermongroup/ncsnv2) 108 | 109 | We trained StyleGAN2 models on cats and dogs via the official repo: [https://github.com/NVlabs/stylegan2-ada](https://github.com/NVlabs/stylegan2-ada) using the AFHQ dataset (Choi et al, 2020). The FFHQ StyleGAN2 model was obtained from the official repo. 110 | 111 | We used code from ```https://github.com/rosinality/stylegan2-pytorch``` to convert Stylegan2-ADA models from Tensorflow checkpoints to PyTorch checkpoints. 112 | 113 | The GLOW model and code is from the official repo: [https://github.com/openai/glow](https://github.com/openai/glow). Unfortunately the provided .pb file uses placeholders for the _z_ tensors, which makes them non-differentiable and the model cannot be directly used in our experiments. In order to address this issue, we used the solution found here: [https://stackoverflow.com/a/57528005](https://stackoverflow.com/a/57528005). 114 | 115 | -------------------------------------------------------------------------------- /include/fit.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | import torch 3 | import torch.optim 4 | import copy 5 | import numpy as np 6 | from scipy.linalg import hadamard 7 | 8 | from .helpers import * 9 | 10 | dtype = torch.cuda.FloatTensor 11 | #dtype = torch.FloatTensor 12 | 13 | 14 | def exp_lr_scheduler(optimizer, epoch, init_lr=0.001, lr_decay_epoch=500): 15 | """Decay learning rate by a factor of 0.1 every lr_decay_epoch epochs.""" 16 | lr = init_lr * (0.65**(epoch // lr_decay_epoch)) 17 | 18 | if epoch % lr_decay_epoch == 0: 19 | print('LR is set to {}'.format(lr)) 20 | 21 | for param_group in optimizer.param_groups: 22 | param_group['lr'] = lr 23 | 24 | return optimizer 25 | 26 | 27 | 28 | def fit(net, 29 | img_noisy_var, 30 | num_channels, 31 | img_clean_var, 32 | num_iter = 5000, 33 | LR = 0.01, 34 | OPTIMIZER='adam', 35 | opt_input = False, 36 | reg_noise_std = 0, 37 | reg_noise_decayevery = 100000, 38 | mask_var = None, 39 | apply_f = None, 40 | lr_decay_epoch = 0, 41 | net_input = None, 42 | net_input_gen = "random", 43 | find_best=False, 44 | weight_decay=0, 45 | ): 46 | 47 | if net_input is not None: 48 | print("input provided") 49 | else: 50 | # feed uniform noise into the network 51 | totalupsample = 2**len(num_channels) 52 | width = 2 # int(img_clean_var.data.shape[2]/totalupsample) 53 | height = 2 # int(img_clean_var.data.shape[3]/totalupsample) 54 | shape = [1,num_channels[0], width, height] 55 | print("shape: ", shape) 56 | net_input = Variable(torch.zeros(shape)) 57 | net_input.data.uniform_() 58 | net_input.data *= 1./10 59 | 60 | net_input_saved = net_input.data.clone() 61 | noise = net_input.data.clone() 62 | p = [x for 
x in net.parameters() ] 63 | 64 | if(opt_input == True): # optimizer over the input as well 65 | net_input.requires_grad = True 66 | p += [net_input] 67 | 68 | mse_wrt_noisy = np.zeros(num_iter) 69 | mse_wrt_truth = np.zeros(num_iter) 70 | 71 | if OPTIMIZER == 'SGD': 72 | print("optimize with SGD", LR) 73 | optimizer = torch.optim.SGD(p, lr=LR,momentum=0.9,weight_decay=weight_decay) 74 | elif OPTIMIZER == 'adam': 75 | print("optimize with adam", LR) 76 | optimizer = torch.optim.Adam(p, lr=LR,weight_decay=weight_decay) 77 | elif OPTIMIZER == 'LBFGS': 78 | print("optimize with LBFGS", LR) 79 | optimizer = torch.optim.LBFGS(p, lr=LR) 80 | 81 | mse = torch.nn.MSELoss() #.type(dtype) 82 | noise_energy = mse(img_noisy_var, img_clean_var) 83 | 84 | if find_best: 85 | best_net = copy.deepcopy(net) 86 | best_mse = 1000000.0 87 | 88 | for i in range(num_iter): 89 | 90 | if lr_decay_epoch is not 0: 91 | optimizer = exp_lr_scheduler(optimizer, i, init_lr=LR, lr_decay_epoch=lr_decay_epoch) 92 | if reg_noise_std > 0: 93 | if i % reg_noise_decayevery == 0: 94 | reg_noise_std *= 0.7 95 | net_input = Variable(net_input_saved + (noise.normal_() * reg_noise_std)) 96 | 97 | def closure(): 98 | optimizer.zero_grad() 99 | out = net(net_input.type(dtype)) 100 | 101 | # training loss 102 | if mask_var is not None: 103 | loss = mse( out * mask_var , img_noisy_var * mask_var ) 104 | elif apply_f: 105 | loss = mse( apply_f(out) , img_noisy_var ) 106 | else: 107 | loss = mse(out, img_noisy_var) 108 | 109 | loss.backward() 110 | mse_wrt_noisy[i] = loss.data.cpu().numpy() 111 | 112 | 113 | # the actual loss 114 | true_loss = torch.zeros_like(loss) # mse(Variable(out.data, requires_grad=False), img_clean_var) 115 | mse_wrt_truth[i] = true_loss.data.cpu().numpy() 116 | if i % 10 == 0: 117 | out2 = net(Variable(net_input_saved).type(dtype)) 118 | loss2 = torch.zeros_like(loss) # mse(out2, img_clean_var) 119 | print ('Iteration %05d Train loss %f Actual loss %f Actual loss orig %f Noise Energy %f' % (i, loss.data,true_loss.data,loss2.data,noise_energy.data), '\r', end='') 120 | return loss 121 | 122 | 123 | #if OPTIMIZER == 'LBFGS': 124 | # if i < 100: 125 | # optimizer = torch.optim.Adam(p, lr=LR) 126 | # else: 127 | # optimizer = torch.optim.LBFGS(p, lr=LR) 128 | 129 | 130 | loss = optimizer.step(closure) 131 | 132 | if find_best: 133 | # if training loss improves by at least one percent, we found a new best net 134 | if best_mse > 1.005*loss.data: 135 | best_mse = loss.data 136 | best_net = copy.deepcopy(net) 137 | 138 | 139 | if find_best: 140 | net = best_net 141 | return mse_wrt_noisy, mse_wrt_truth,net_input_saved, net 142 | 143 | 144 | 145 | 146 | 147 | ### weight regularization 148 | #if orth_reg > 0: 149 | # for name, param in net.named_parameters(): 150 | # consider all the conv weights, but the last one which only combines colors 151 | # if '.1.weight' in name and str( len(net)-1 ) not in name: 152 | # param_flat = param.view(param.shape[0], -1) 153 | # sym = torch.mm(param_flat, torch.t(param_flat)) 154 | # sym -= Variable(torch.eye(param_flat.shape[0])).type(dtype) 155 | # loss = loss + (orth_reg * sym.sum().type(dtype) ) 156 | ### 157 | -------------------------------------------------------------------------------- /src/cs_metrics.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import glob\n", 10 | "import matplotlib\n", 11 | 
"import matplotlib.pyplot as plt\n", 12 | "# import seaborn as sns\n", 13 | "%matplotlib inline\n", 14 | "from matplotlib.backends.backend_pdf import PdfPages\n", 15 | "from PIL import Image\n", 16 | "import numpy as np\n", 17 | "matplotlib.rcParams['pdf.fonttype'] = 42\n", 18 | "matplotlib.rcParams['ps.fonttype'] = 42\n", 19 | "\n", 20 | "import utils\n", 21 | "import metrics_utils\n", 22 | "\n", 23 | "matplotlib.rcParams.update({'font.size': 15, 'font.weight':'normal'})\n", 24 | "\n", 25 | "is_save = True\n", 26 | "figsize = metrics_utils.get_figsize(is_save)" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "# CelebA results" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "# hyperparameters that gave best validation error\n", 43 | "# MAP and Ours\n", 44 | "\n", 45 | "dset = 'celebA'\n", 46 | "noise_std = 16.0\n", 47 | "legend_base_regexs = [\n", 48 | " ('Langevin(Ours)',\n", 49 | " f'../estimated/{dset}/full-input/circulant/{noise_std}/',\n", 50 | " '/glow/*langevin/None*'),\n", 51 | " ('MAP',\n", 52 | " f'../estimated*/celebA/full-input/circulant/{noise_std}/',\n", 53 | " '/glow/map/None_None*'),\n", 54 | "]\n", 55 | "\n" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "## Plot\n", 65 | "retrieve = 'l2'\n", 66 | "crit= 'likelihood'\n", 67 | "pdf = PdfPages(f'../results/dset={dset}_noise={noise_std}_retrieve={retrieve}_criterion={crit}.pdf')\n", 68 | "ax = plt.axes()\n", 69 | "criterion = [crit, 'mean']\n", 70 | "retrieve_list = [[retrieve, 'mean'], [retrieve, 'std']]\n", 71 | "legends = []\n", 72 | "colors = dict()\n", 73 | "for legend, base, regex in legend_base_regexs:\n", 74 | " if legend == 'Modified-MAP':\n", 75 | " criterion = ['l2', 'mean']\n", 76 | " colors[legend] = metrics_utils.plot(base, regex, criterion, retrieve_list, legend)\n", 77 | " legends.append(legend)\n", 78 | " \n", 79 | "\n", 80 | "\n", 81 | "\n", 82 | "plt.gca().set_xscale(\"log\", nonposx='clip')\n", 83 | "plt.gca().set_ylim([0,0.008])\n", 84 | "plt.gca().xaxis.set_major_formatter(plt.NullFormatter())\n", 85 | "plt.gca().xaxis.set_minor_formatter(plt.NullFormatter())\n", 86 | "\n", 87 | "# labels, ticks, titles\n", 88 | "ticks = [2500,5000,10000,20000,30000,35000]\n", 89 | "labels = ticks\n", 90 | "\n", 91 | "\n", 92 | "# plt.gca().set_xticks(ticks, labels)#, rotation='vertical')\n", 93 | "plt.xticks(ticks, labels, rotation='vertical')\n", 94 | "plt.ylabel('Reconstruction error (per pixel)')\n", 95 | "plt.xlabel('Number of measurements')\n", 96 | "\n", 97 | "\n", 98 | "# Legends\n", 99 | "plt.legend(fontsize=8)\n", 100 | "\n", 101 | "pdf.savefig(bbox_inches='tight')\n", 102 | "pdf.close()\n", 103 | "\n", 104 | "\n" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "metadata": {}, 110 | "source": [ 111 | "# FFHQ results" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "# hyperparameters that gave best validation error\n", 121 | "# MAP and Ours\n", 122 | "\n", 123 | "\n", 124 | "\n", 125 | "dset = 'ffhq-69000'\n", 126 | "noise_std = 4.0\n", 127 | "legend_base_regexs = [\n", 128 | " ('MAP',\n", 129 | " f'../estimated/{dset}/full-input/circulant/{noise_std}/',\n", 130 | " '/ncsnv2/map/*'),\n", 131 | " ('Langevin(Ours)',\n", 132 | " 
f'../estimated/{dset}/full-input/circulant/{noise_std}/',\n", 133 | " '/ncsnv2/langevin/*')\n", 134 | "]" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "## Plot\n", 144 | "retrieve = 'l2'\n", 145 | "crit = 'l2'\n", 146 | "pdf = PdfPages(f'../results/dset={dset}_noise={noise_std}_retrieve={retrieve}_criterion={crit}.pdf')\n", 147 | "ax = plt.axes()\n", 148 | "criterion = [crit, 'mean']\n", 149 | "retrieve_list = [[retrieve, 'mean'], [retrieve, 'std']]\n", 150 | "legends = []\n", 151 | "colors = dict()\n", 152 | "for legend, base, regex in legend_base_regexs:\n", 153 | " colors[legend] = metrics_utils.plot(base, regex, criterion, retrieve_list, legend)\n", 154 | " legends.append(legend)\n", 155 | " \n", 156 | "\n", 157 | "\n", 158 | "\n", 159 | "plt.gca().set_xscale(\"log\", nonposx='clip')\n", 160 | "plt.gca().set_ylim([0,0.005])\n", 161 | "plt.gca().xaxis.set_major_formatter(plt.NullFormatter())\n", 162 | "plt.gca().xaxis.set_minor_formatter(plt.NullFormatter())\n", 163 | "\n", 164 | "# # labels, ticks, titles\n", 165 | "ticks = [5000, 10000, 15000, 20000, 30000, 40000, 50000, 75000]\n", 166 | "labels = ticks\n", 167 | "\n", 168 | "plt.xticks(ticks, labels, rotation='vertical')\n", 169 | "plt.ylabel('Reconstruction error (per pixel)')\n", 170 | "plt.xlabel('Number of measurements')\n", 171 | "\n", 172 | "\n", 173 | "# Legends\n", 174 | "plt.legend(fontsize=8)\n", 175 | "\n", 176 | "pdf.savefig(bbox_inches='tight')\n", 177 | "pdf.close()\n", 178 | "\n", 179 | "\n" 180 | ] 181 | } 182 | ], 183 | "metadata": { 184 | "kernelspec": { 185 | "display_name": "cs-full", 186 | "language": "python", 187 | "name": "cs-full" 188 | }, 189 | "language_info": { 190 | "codemirror_mode": { 191 | "name": "ipython", 192 | "version": 3 193 | }, 194 | "file_extension": ".py", 195 | "mimetype": "text/x-python", 196 | "name": "python", 197 | "nbconvert_exporter": "python", 198 | "pygments_lexer": "ipython3", 199 | "version": "3.6.12" 200 | } 201 | }, 202 | "nbformat": 4, 203 | "nbformat_minor": 1 204 | } 205 | -------------------------------------------------------------------------------- /stylegan2/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | 9 | module_path = os.path.dirname(__file__) 10 | upfirdn2d_op = load( 11 | "upfirdn2d", 12 | sources=[ 13 | os.path.join(module_path, "upfirdn2d.cpp"), 14 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 15 | ], 16 | ) 17 | 18 | 19 | class UpFirDn2dBackward(Function): 20 | @staticmethod 21 | def forward( 22 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 23 | ): 24 | 25 | up_x, up_y = up 26 | down_x, down_y = down 27 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 28 | 29 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 30 | 31 | grad_input = upfirdn2d_op.upfirdn2d( 32 | grad_output, 33 | grad_kernel, 34 | down_x, 35 | down_y, 36 | up_x, 37 | up_y, 38 | g_pad_x0, 39 | g_pad_x1, 40 | g_pad_y0, 41 | g_pad_y1, 42 | ) 43 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 44 | 45 | ctx.save_for_backward(kernel) 46 | 47 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 48 | 49 | ctx.up_x = up_x 50 | ctx.up_y = up_y 51 | ctx.down_x = down_x 52 | ctx.down_y = down_y 53 | ctx.pad_x0 = 
pad_x0 54 | ctx.pad_x1 = pad_x1 55 | ctx.pad_y0 = pad_y0 56 | ctx.pad_y1 = pad_y1 57 | ctx.in_size = in_size 58 | ctx.out_size = out_size 59 | 60 | return grad_input 61 | 62 | @staticmethod 63 | def backward(ctx, gradgrad_input): 64 | kernel, = ctx.saved_tensors 65 | 66 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 67 | 68 | gradgrad_out = upfirdn2d_op.upfirdn2d( 69 | gradgrad_input, 70 | kernel, 71 | ctx.up_x, 72 | ctx.up_y, 73 | ctx.down_x, 74 | ctx.down_y, 75 | ctx.pad_x0, 76 | ctx.pad_x1, 77 | ctx.pad_y0, 78 | ctx.pad_y1, 79 | ) 80 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 81 | gradgrad_out = gradgrad_out.view( 82 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 83 | ) 84 | 85 | return gradgrad_out, None, None, None, None, None, None, None, None 86 | 87 | 88 | class UpFirDn2d(Function): 89 | @staticmethod 90 | def forward(ctx, input, kernel, up, down, pad): 91 | up_x, up_y = up 92 | down_x, down_y = down 93 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 94 | 95 | kernel_h, kernel_w = kernel.shape 96 | batch, channel, in_h, in_w = input.shape 97 | ctx.in_size = input.shape 98 | 99 | input = input.reshape(-1, in_h, in_w, 1) 100 | 101 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 102 | 103 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 104 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 105 | ctx.out_size = (out_h, out_w) 106 | 107 | ctx.up = (up_x, up_y) 108 | ctx.down = (down_x, down_y) 109 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 110 | 111 | g_pad_x0 = kernel_w - pad_x0 - 1 112 | g_pad_y0 = kernel_h - pad_y0 - 1 113 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 114 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 115 | 116 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 117 | 118 | out = upfirdn2d_op.upfirdn2d( 119 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 120 | ) 121 | # out = out.view(major, out_h, out_w, minor) 122 | out = out.view(-1, channel, out_h, out_w) 123 | 124 | return out 125 | 126 | @staticmethod 127 | def backward(ctx, grad_output): 128 | kernel, grad_kernel = ctx.saved_tensors 129 | 130 | grad_input = UpFirDn2dBackward.apply( 131 | grad_output, 132 | kernel, 133 | grad_kernel, 134 | ctx.up, 135 | ctx.down, 136 | ctx.pad, 137 | ctx.g_pad, 138 | ctx.in_size, 139 | ctx.out_size, 140 | ) 141 | 142 | return grad_input, None, None, None, None 143 | 144 | 145 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 146 | if input.device.type == "cpu": 147 | out = upfirdn2d_native( 148 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 149 | ) 150 | 151 | else: 152 | out = UpFirDn2d.apply( 153 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 154 | ) 155 | 156 | return out 157 | 158 | 159 | def upfirdn2d_native( 160 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 161 | ): 162 | _, channel, in_h, in_w = input.shape 163 | input = input.reshape(-1, in_h, in_w, 1) 164 | 165 | _, in_h, in_w, minor = input.shape 166 | kernel_h, kernel_w = kernel.shape 167 | 168 | out = input.view(-1, in_h, 1, in_w, 1, minor) 169 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 170 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 171 | 172 | out = F.pad( 173 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 174 | ) 175 | out = out[ 176 | :, 177 | max(-pad_y0, 0) 
: out.shape[1] - max(-pad_y1, 0), 178 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 179 | :, 180 | ] 181 | 182 | out = out.permute(0, 3, 1, 2) 183 | out = out.reshape( 184 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 185 | ) 186 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 187 | out = F.conv2d(out, w) 188 | out = out.reshape( 189 | -1, 190 | minor, 191 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 192 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 193 | ) 194 | out = out.permute(0, 2, 3, 1) 195 | out = out[:, ::down_y, ::down_x, :] 196 | 197 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 198 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 199 | 200 | return out.view(-1, channel, out_h, out_w) 201 | -------------------------------------------------------------------------------- /glow/model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from PIL import Image 4 | from tensorflow.core.framework import graph_pb2 5 | import copy 6 | import tensorflow.contrib.graph_editor as ge 7 | import os 8 | from threading import Lock 9 | 10 | 11 | def flatten_eps(eps): 12 | # [BS, eps_size] 13 | return np.concatenate([np.reshape(e, (e.shape[0], -1)) for e in eps], axis=-1) 14 | 15 | 16 | def unflatten_eps(feps): 17 | index = 0 18 | eps = [] 19 | bs = feps.shape[0] # feps.size // eps_size 20 | for shape in eps_shapes: 21 | eps.append(np.reshape( 22 | feps[:, index: index+np.prod(shape)], (bs, *shape))) 23 | index += np.prod(shape) 24 | return eps 25 | 26 | def load_pb(path_to_pb): 27 | with tf.gfile.GFile(path_to_pb, 'rb') as f: 28 | graph_def_optimized = tf.GraphDef() 29 | graph_def_optimized.ParseFromString(f.read()) 30 | 31 | with tf.Graph().as_default() as graph: 32 | tf.import_graph_def(graph_def_optimized, name='') 33 | return graph 34 | 35 | 36 | def run(sess, fetches, feed_dict): 37 | lock = Lock() 38 | with lock: 39 | # Locked tensorflow so average server response time to user is lower 40 | result = sess.run(fetches, feed_dict) 41 | return result 42 | 43 | def get_model(model_path, batch_size, z_sdev): 44 | assert os.path.exists(model_path), f'model_path does not exist: {model_path}' 45 | 46 | with tf.gfile.GFile(model_path, 'rb') as f: 47 | graph_def = tf.GraphDef() 48 | graph_def.ParseFromString(f.read()) 49 | 50 | tf.import_graph_def(graph_def) 51 | 52 | 53 | inputs = { 54 | 'dec_eps_0': 'Placeholder', 55 | 'dec_eps_1': 'Placeholder_1', 56 | 'dec_eps_2': 'Placeholder_2', 57 | 'dec_eps_3': 'Placeholder_3', 58 | 'dec_eps_4': 'Placeholder_4', 59 | 'dec_eps_5': 'Placeholder_5', 60 | 'enc_x': 'input/image', 61 | 'enc_x_d': 'input/downsampled_image', 62 | 'enc_y': 'input/label' 63 | } 64 | outputs = { 65 | 'dec_x': 'model_1/Reshape_4', 66 | 'enc_eps_0': 'model/pool0/truediv_1', 67 | 'enc_eps_1': 'model/pool1/truediv_1', 68 | 'enc_eps_2': 'model/pool2/truediv_1', 69 | 'enc_eps_3': 'model/pool3/truediv_1', 70 | 'enc_eps_4': 'model/pool4/truediv_1', 71 | 'enc_eps_5': 'model/truediv_4' 72 | } 73 | 74 | eps_shapes = [(128, 128, 6), (64, 64, 12), (32, 32, 24), 75 | (16, 16, 48), (8, 8, 96), (4, 4, 384)] 76 | eps_sizes = [np.prod(e) for e in eps_shapes] 77 | eps_size = 256 * 256 * 3 78 | 79 | 80 | dec_eps = [] 81 | dec_eps_shapes = [(batch_size,128, 128, 6), (batch_size,64, 64, 12), (batch_size,32, 32, 24), 82 | (batch_size,16, 16, 48), (batch_size,8, 8, 96), (batch_size,4, 4, 384)] 83 | 84 | # replace the decoder 
placeholders with differentiable variables 85 | target_var_name_pairs = [] 86 | for i in range(6): 87 | # name of i-th decoder placeholder 88 | name = 'import/' + inputs[f'dec_eps_{i}'] 89 | var_shape = dec_eps_shapes[i] 90 | 91 | # Give each variable a name that doesn't already exist in the graph 92 | var_name = f'dec_eps_{i}_turned_var' 93 | # Create TensorFlow variable initialized by values of original const. 94 | # var = tf.get_variable(name=var_name, dtype='float32', shape=var_shape,initializer=tf.constant_initializer(tensor_as_numpy_array)) 95 | var = tf.get_variable(name=var_name, dtype='float32', shape=var_shape,initializer=tf.random_normal_initializer(stddev=z_sdev)) 96 | 97 | # We want to keep track of our variables names for later. 98 | target_var_name_pairs.append((name, var_name)) 99 | 100 | # add new variable to list 101 | dec_eps.append(var) 102 | 103 | # At this point, we added a bunch of tf.Variables to the graph, but they're 104 | # not connected to anything. 105 | 106 | # The magic: we use TF Graph Editor to swap the Constant nodes' outputs with 107 | # the outputs of our newly created Variables. 108 | 109 | for const_name, var_name in target_var_name_pairs: 110 | const_op = tf.get_default_graph().get_operation_by_name(const_name) 111 | var_reader_op = tf.get_default_graph().get_operation_by_name(var_name + '/read') 112 | ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op)) 113 | 114 | 115 | # remove floor operations from the graph 116 | floor_op = tf.get_default_graph().get_operation_by_name('import/model/Floor_1') 117 | div_op = tf.get_default_graph().get_operation_by_name('import/model/truediv_2') 118 | 119 | ge.swap_outputs(ge.sgv(floor_op), ge.sgv(div_op)) 120 | 121 | # remove random noise from encoder 122 | c_op = tf.get_default_graph().get_operation_by_name('import/model/random_uniform/sub/_4108__cf__4108') 123 | c_ = tf.constant(0, dtype=tf.float32) 124 | c_zero_op = tf.get_default_graph().get_operation_by_name('Const') 125 | ge.swap_outputs(ge.sgv(c_op), ge.sgv(c_zero_op)) 126 | 127 | n_eps = 6 128 | 129 | 130 | def get(name): 131 | return tf.get_default_graph().get_tensor_by_name('import/' + name + ':0') 132 | 133 | # Encoder 134 | enc_x = get(inputs['enc_x']) 135 | enc_eps = [get(outputs['enc_eps_' + str(i)]) for i in range(n_eps)] 136 | enc_x_d = get(inputs['enc_x_d']) 137 | enc_y = get(inputs['enc_y']) 138 | 139 | # Decoder 140 | dec_x = get(outputs['dec_x']) 141 | 142 | 143 | 144 | def encode(img): 145 | if len(img.shape) == 3: 146 | img = np.expand_dims(img, 0) 147 | bs = img.shape[0] 148 | assert img.shape[1:] == (256, 256, 3) 149 | feed_dict = {enc_x: img} 150 | update_feed(feed_dict, bs) # For unoptimized model 151 | 152 | return flatten_eps(run(sess, enc_eps, feed_dict)) 153 | 154 | 155 | def decode(feps): 156 | if len(feps.shape) == 1: 157 | feps = np.expand_dims(feps, 0) 158 | bs = feps.shape[0] 159 | # assert len(eps) == n_eps 160 | # for i in range(n_eps): 161 | # shape = (BATCH_SIZE, 128 // (2 ** i), 128 // (2 ** i), 6 * (2 ** i) * (2 ** (i == (n_eps - 1)))) 162 | # assert eps[i].shape == shape 163 | eps = unflatten_eps(feps) 164 | 165 | feed_dict = {} 166 | for i in range(n_eps): 167 | feed_dict[dec_eps[i]] = eps[i] 168 | update_feed(feed_dict, bs) # For unoptimized model 169 | 170 | return run(sess, dec_x, feed_dict) 171 | 172 | def random(bs=1, eps_std=0.7): 173 | feps = np.random.normal(scale=eps_std, size=[bs, eps_size]) 174 | return decode(feps), feps 175 | # function that updates the feed_dict to include a downsampled image 176 | # and 
a conditional label set to all zeros. 177 | def update_feed(feed_dict, bs): 178 | x_d = 128 * np.ones([bs, 128, 128, 3], dtype=np.uint8) 179 | y = np.zeros([bs], dtype=np.int32) 180 | feed_dict[enc_x_d] = x_d 181 | feed_dict[enc_y] = y 182 | return feed_dict 183 | 184 | feed_dict = {} 185 | update_feed(feed_dict, batch_size) 186 | return dec_x, dec_eps, feed_dict, run 187 | 188 | 189 | -------------------------------------------------------------------------------- /src/PULSE.py: -------------------------------------------------------------------------------- 1 | #from stylegan import G_synthesis,G_mapping 2 | import sys 3 | import os 4 | from dataclasses import dataclass 5 | from SphericalOptimizer import SphericalOptimizer 6 | from pathlib import Path 7 | import numpy as np 8 | import time 9 | import torch 10 | from loss import LossBuilder 11 | from functools import partial 12 | 13 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 14 | 15 | from stylegan2.model import Generator 16 | 17 | class PULSE(torch.nn.Module): 18 | def __init__(self, image_size, checkpoint_path, dataset, verbose=True): 19 | super(PULSE, self).__init__() 20 | 21 | self.image_size = image_size 22 | self.model = Generator(image_size, 512,8) 23 | self.model.load_state_dict(torch.load(checkpoint_path)["g_ema"], strict=False) 24 | self.model.eval() 25 | self.model.cuda() 26 | for p in self.model.parameters(): 27 | p.requires_grad = False 28 | 29 | #self.synthesis = G_synthesis().cuda() 30 | self.verbose = verbose 31 | self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2) 32 | if Path(f"{dataset}_gaussian_fit.pt").exists(): 33 | self.gaussian_fit = torch.load(f"{dataset}_gaussian_fit.pt") 34 | else: 35 | with torch.no_grad(): 36 | latent = torch.randn((1000000,512),dtype=torch.float32, device="cuda") 37 | latent_out = torch.nn.LeakyReLU(5)(self.model.style(latent)) 38 | self.gaussian_fit = {"mean": latent_out.mean(0), "std": latent_out.std(0)} 39 | torch.save(self.gaussian_fit,f"{dataset}_gaussian_fit.pt") 40 | if self.verbose: print(f"\tSaved {dataset}_gaussian_fit.pt") 41 | 42 | #cache_dir = Path(cache_dir) 43 | #cache_dir.mkdir(parents=True, exist_ok = True) 44 | #if self.verbose: print("Loading Synthesis Network") 45 | #with open_url("https://drive.google.com/uc?id=1TCViX1YpQyRsklTVYEJwdbmK91vklCo8", cache_dir=cache_dir, verbose=verbose) as f: 46 | # self.synthesis.load_state_dict(torch.load(f)) 47 | 48 | #for param in self.synthesis.parameters(): 49 | # param.requires_grad = False 50 | 51 | #self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2) 52 | 53 | #if Path("gaussian_fit.pt").exists(): 54 | # self.gaussian_fit = torch.load("gaussian_fit.pt") 55 | #else: 56 | # if self.verbose: print("\tLoading Mapping Network") 57 | # mapping = G_mapping().cuda() 58 | 59 | # with open_url("https://drive.google.com/uc?id=14R6iHGf5iuVx3DMNsACAl7eBr7Vdpd0k", cache_dir=cache_dir, verbose=verbose) as f: 60 | # mapping.load_state_dict(torch.load(f)) 61 | 62 | # if self.verbose: print("\tRunning Mapping Network") 63 | # with torch.no_grad(): 64 | # torch.manual_seed(0) 65 | # latent = torch.randn((1000000,512),dtype=torch.float32, device="cuda") 66 | # latent_out = torch.nn.LeakyReLU(5)(mapping(latent)) 67 | # self.gaussian_fit = {"mean": latent_out.mean(0), "std": latent_out.std(0)} 68 | # torch.save(self.gaussian_fit,"gaussian_fit.pt") 69 | # if self.verbose: print("\tSaved \"gaussian_fit.pt\"") 70 | 71 | def forward(self, ref_im, 72 | seed, 73 | loss_str, 74 | pulse_eps, 75 | noise_type, 76 | num_noise_variables, 77 | 
tile_latent, 78 | optimizer_type, 79 | learning_rate, 80 | max_update_iter, 81 | lr_schedule, 82 | **kwargs): 83 | 84 | if seed: 85 | torch.manual_seed(seed) 86 | torch.cuda.manual_seed(seed) 87 | torch.backends.cudnn.deterministic = True 88 | 89 | batch_size = ref_im.shape[0] 90 | 91 | # Generate latent tensor 92 | if(tile_latent): 93 | latent = torch.randn( 94 | (batch_size, 1, 512), dtype=torch.float, requires_grad=True, device='cuda') 95 | else: 96 | latent = torch.randn( 97 | (batch_size, 18, 512), dtype=torch.float, requires_grad=True, device='cuda') 98 | 99 | # Generate list of noise tensors 100 | #noise = [] # stores all of the noise tensors 101 | #noise_vars = [] # stores the noise tensors that we want to optimize on 102 | 103 | #for i in range(18): 104 | # # dimension of the ith noise tensor 105 | # res = (batch_size, 1, 2**(i//2+2), 2**(i//2+2)) 106 | 107 | # if(noise_type == 'zero' or i in [int(layer) for layer in bad_noise_layers.split('.')]): 108 | # new_noise = torch.zeros(res, dtype=torch.float, device='cuda') 109 | # new_noise.requires_grad = False 110 | # elif(noise_type == 'fixed'): 111 | # new_noise = torch.randn(res, dtype=torch.float, device='cuda') 112 | # new_noise.requires_grad = False 113 | # elif (noise_type == 'trainable'): 114 | # new_noise = torch.randn(res, dtype=torch.float, device='cuda') 115 | # if (i < num_noise_variables): 116 | # new_noise.requires_grad = True 117 | # noise_vars.append(new_noise) 118 | # else: 119 | # new_noise.requires_grad = False 120 | # else: 121 | # raise Exception("unknown noise type") 122 | 123 | # noise.append(new_noise) 124 | 125 | noises_single = self.model.make_noise() 126 | noises = [] 127 | noise_vars = [] 128 | for idx, noise in enumerate(noises_single): 129 | res = noise.shape 130 | 131 | if(noise_type == 'zero'): 132 | new_noise = torch.zeros(res, dtype=torch.float, device='cuda') 133 | new_noise.requires_grad = False 134 | elif(noise_type == 'fixed'): 135 | new_noise = torch.randn(res, dtype=torch.float, device='cuda') 136 | new_noise.requires_grad = False 137 | elif (noise_type == 'trainable'): 138 | new_noise = torch.randn(res, dtype=torch.float, device='cuda') 139 | if (idx < num_noise_variables): 140 | new_noise.requires_grad = True 141 | noise_vars.append(new_noise) 142 | else: 143 | new_noise.requires_grad = False 144 | else: 145 | raise Exception("unknown noise type") 146 | noises.append(new_noise) 147 | 148 | print(len(noise_vars)) 149 | var_list = [latent]+noise_vars 150 | 151 | opt_dict = { 152 | 'sgd': torch.optim.SGD, 153 | 'adam': torch.optim.Adam, 154 | 'sgdm': partial(torch.optim.SGD, momentum=0.9), 155 | 'adamax': torch.optim.Adamax 156 | } 157 | opt_func = opt_dict[optimizer_type] 158 | opt = SphericalOptimizer(opt_func, var_list, lr=learning_rate) 159 | 160 | schedule_dict = { 161 | 'fixed': lambda x: 1, 162 | 'linear1cycle': lambda x: (9*(1-np.abs(x/max_update_iter-1/2)*2)+1)/10, 163 | 'linear1cycledrop': lambda x: (9*(1-np.abs(x/(0.9*max_update_iter)-1/2)*2)+1)/10 if x < 0.9*max_update_iter else 1/10 + (x-0.9*max_update_iter)/(0.1*max_update_iter)*(1/1000-1/10), 164 | } 165 | schedule_func = schedule_dict[lr_schedule] 166 | scheduler = torch.optim.lr_scheduler.LambdaLR(opt.opt, schedule_func) 167 | 168 | loss_builder = LossBuilder(ref_im, loss_str, pulse_eps, self.image_size).cuda() 169 | 170 | min_loss = np.inf 171 | min_l2 = np.inf 172 | best_summary = "" 173 | start_t = time.time() 174 | gen_im = None 175 | 176 | 177 | if self.verbose: print("Optimizing") 178 | for j in range(max_update_iter): 179 
| opt.opt.zero_grad() 180 | 181 | # Duplicate latent in case tile_latent = True 182 | if (tile_latent): 183 | latent_in = latent.expand(-1, 18, -1) 184 | else: 185 | latent_in = latent 186 | 187 | # Apply learned linear mapping to match latent distribution to that of the mapping network 188 | latent_in = self.lrelu(latent_in*self.gaussian_fit["std"] + self.gaussian_fit["mean"]) 189 | 190 | # Normalize image to [0,1] instead of [-1,1] 191 | #gen_im = (self.synthesis(latent_in, noise)+1)/2 192 | gen_im = self.model([latent_in],input_is_latent=True, noise=noises)[0] * 0.5 + 0.5 193 | 194 | # Calculate Losses 195 | loss, loss_dict = loss_builder(latent_in, gen_im) 196 | loss_dict['TOTAL'] = loss 197 | 198 | # Save best summary for log 199 | if(loss < min_loss): 200 | min_loss = loss 201 | best_summary = f'BEST ({j+1}) | '+' | '.join( 202 | [f'{x}: {y:.4f}' for x, y in loss_dict.items()]) 203 | best_im = gen_im.clone() 204 | 205 | print(best_summary) 206 | loss_l2 = loss_dict['L2'] 207 | 208 | if(loss_l2 < min_l2): 209 | min_l2 = loss_l2 210 | 211 | 212 | loss.backward() 213 | opt.step() 214 | scheduler.step() 215 | 216 | total_t = time.time()-start_t 217 | current_info = f' | time: {total_t:.1f} | it/s: {(j+1)/total_t:.2f} | batchsize: {batch_size}' 218 | if self.verbose: print(best_summary+current_info) 219 | if(min_l2 <= pulse_eps): 220 | pass 221 | else: 222 | print("Could not find a face that downscales correctly within epsilon") 223 | yield (gen_im.clone().cpu().detach().clamp(0, 1),loss_builder.D(best_im).cpu().detach().clamp(0, 1)) 224 | -------------------------------------------------------------------------------- /src/compressed_sensing.py: -------------------------------------------------------------------------------- 1 | """Compressed sensing main script""" 2 | # pylint: disable=C0301,C0103,C0111 3 | 4 | from __future__ import division 5 | import os 6 | from argparse import ArgumentParser 7 | import numpy as np 8 | import utils 9 | import yaml 10 | 11 | 12 | def main(hparams): 13 | 14 | # Set up some stuff according to hparams 15 | hparams.n_input = np.prod(hparams.image_shape) 16 | utils.print_hparams(hparams) 17 | 18 | # get inputs 19 | xs_dict = utils.model_input(hparams) 20 | 21 | # get estimator 22 | estimator = utils.get_estimator(hparams, hparams.model_type) 23 | 24 | # set up folders, etc. for checkpointing 25 | utils.setup_checkpointing(hparams) 26 | 27 | # get saved results 28 | measurement_losses, l2_losses, z_hats, likelihoods = utils.load_checkpoints(hparams) 29 | 30 | x_batch_dict = {} 31 | x_hats_dict = {} 32 | 33 | A = utils.get_A(hparams) 34 | 35 | 36 | 37 | for key, x in xs_dict.items(): 38 | if not hparams.not_lazy: 39 | # If lazy, first check if the image has already been 40 | # saved before. If yes, then skip this image. 
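# (Assumption: utils.get_save_path maps each image key to a deterministic
# output file, so a previously completed image can be detected with a plain
# os.path.isfile check and skipped on re-runs.)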
41 | save_path = utils.get_save_path(hparams, key) 42 | is_saved = os.path.isfile(save_path) 43 | if is_saved: 44 | continue 45 | 46 | x_batch_dict[key] = x 47 | if len(x_batch_dict) < hparams.batch_size: 48 | continue 49 | 50 | # Reshape input 51 | x_batch_list = [x.reshape(1, hparams.n_input) for _, x in x_batch_dict.items()] 52 | x_batch = np.concatenate(x_batch_list) 53 | 54 | # Construct noise and measurements 55 | noise_batch = utils.get_noise(hparams) 56 | y_batch = utils.get_measurements(x_batch, A, noise_batch, hparams) 57 | 58 | # Construct estimates 59 | x_hat_batch, z_hat_batch, likelihood_batch = estimator(A, y_batch, hparams) 60 | 61 | for i, key in enumerate(x_batch_dict.keys()): 62 | x = xs_dict[key] 63 | y_train = y_batch[i] 64 | x_hat = x_hat_batch[i] 65 | 66 | # Save the estimate 67 | x_hats_dict[key] = x_hat 68 | 69 | # Compute and store measurement and l2 loss 70 | measurement_losses[key] = utils.get_measurement_loss(x_hat, A, y_train, hparams) 71 | l2_losses[key] = utils.get_l2_loss(x_hat, x) 72 | z_hats[key] = z_hat_batch[i] 73 | likelihoods[key] = likelihood_batch[i] 74 | 75 | print('Processed up to image {0} / {1}'.format(key+1, len(xs_dict))) 76 | 77 | # Checkpointing 78 | if (not hparams.debug) and ((key+1) % hparams.checkpoint_iter == 0): 79 | utils.checkpoint(x_hats_dict, measurement_losses, l2_losses, z_hats, likelihoods, hparams) 80 | x_hats_dict = {} 81 | print('\nProcessed and saved first ', key+1, 'images\n') 82 | 83 | x_batch_dict = {} 84 | 85 | # Final checkpoint 86 | if not hparams.debug: 87 | utils.checkpoint(x_hats_dict, measurement_losses, l2_losses, z_hats, likelihoods, hparams) 88 | print('\nProcessed and saved all {0} image(s)\n'.format(len(xs_dict))) 89 | 90 | if hparams.print_stats: 91 | print(hparams.model_type) 92 | measurement_loss_list = list(measurement_losses.values()) 93 | l2_loss_list = list(l2_losses.values()) 94 | mean_m_loss = np.mean(measurement_loss_list) 95 | mean_l2_loss = np.mean(l2_loss_list) 96 | print('mean measurement loss = {0}'.format(mean_m_loss)) 97 | print('mean l2 loss = {0}'.format(mean_l2_loss)) 98 | 99 | if hparams.image_matrix > 0: 100 | utils.image_matrix(xs_dict, x_hats_dict, hparams) 101 | 102 | # Warn the user that some things were not processed 103 | if len(x_batch_dict) > 0: 104 | print('\nDid NOT process last {} images because they did not fill up the last batch.'.format(len(x_batch_dict))) 105 | print('Consider rerunning lazily with a smaller batch size.') 106 | 107 | if __name__ == '__main__': 108 | PARSER = ArgumentParser() 109 | 110 | # Pretrained model 111 | PARSER.add_argument('--checkpoint-path', type=str, default='models/ncsnv2_ffhq/checkpoint_80000.pth', help='Path to pretrained model') 112 | PARSER.add_argument('--net', type=str, default='ncsnv2', help='Name of model. 
options = [glow, stylegan2, ncsnv2, dd]') 113 | 114 | # Input 115 | PARSER.add_argument('--dataset', type=str, default='celebA', help='Dataset to use') 116 | PARSER.add_argument('--image-size', type=int, default=256, help='size of image') 117 | PARSER.add_argument('--input-type', type=str, default='full-input', help='Where to take input from') 118 | PARSER.add_argument('--num-input-images', type=int, default=2, help='number of input images') 119 | PARSER.add_argument('--batch-size', type=int, default=2, help='How many examples are processed together') 120 | PARSER.add_argument('--cache-dir', type=str, default='cache', help='cache directory for model weights') 121 | PARSER.add_argument('--ncsnv2-configs-file', type=str, default='./ncsnv2/configs/ffhq.yml', help='location of ncsnv2 config file') 122 | 123 | 124 | # Problem definition 125 | PARSER.add_argument('--measurement-type', type=str, default='gaussian', help='measurement type') 126 | PARSER.add_argument('--noise-std', type=float, default=1, help='expected norm of noise') 127 | PARSER.add_argument('--measurement-noise-type', type=str, default='gaussian', help='type of noise') 128 | 129 | # Measurement type specific hparams 130 | PARSER.add_argument('--num-measurements', type=int, default=200, help='number of gaussian measurements') 131 | PARSER.add_argument('--downsample', type=int, default=None, help='downsampling factor') 132 | 133 | # Model 134 | PARSER.add_argument('--model-type', type=str, default=None, required=True, help='model used for estimation. options=[map, langevin, pulse, dd]') 135 | PARSER.add_argument('--mloss-weight', type=float, default=-1, help='L2 measurement loss weight') 136 | PARSER.add_argument('--zprior-weight', type=float, default=-1, help='weight on z prior') 137 | PARSER.add_argument('--zprior-sdev', type=float, default=1.0, help='standard deviation for target distribution of z') 138 | PARSER.add_argument('--zprior-init-sdev', type=float, default=1.0, help='standard deviation to initialize z') 139 | PARSER.add_argument('--T', type=float, default=-1, help='number of iterations for each level of noise in Langevin annealing') 140 | PARSER.add_argument('--L', type=float, default=-1, help='number of noise levels for annealing Langevin') 141 | PARSER.add_argument('--sigma-init', type=float, default=-1, help='initial noise level for annealing Langevin') 142 | PARSER.add_argument('--sigma-final', type=float, default=-1, help='final noise level for annealing Langevin') 143 | PARSER.add_argument('--error-threshold', type=float, default=0., help='threshold for measurement error before restart') 144 | PARSER.add_argument('--num-noise-variables', type=int, default=5, help='StyleGAN2: number of noise variables to optimize') 145 | 146 | # NN-specific hparams 147 | PARSER.add_argument('--optimizer-type', type=str, default='sgd', help='Optimizer type') 148 | PARSER.add_argument('--learning-rate', type=float, default=0.4, help='learning rate') 149 | PARSER.add_argument('--momentum', type=float, default=0., help='momentum value') 150 | PARSER.add_argument('--max-update-iter', type=int, default=1000, help='maximum updates to z') 151 | PARSER.add_argument('--num-random-restarts', type=int, default=1, help='number of random restarts') 152 | PARSER.add_argument('--decay-lr', action='store_true', help='whether to decay learning rate') 153 | 154 | #PULSE arguments 155 | PARSER.add_argument('--seed', type=int, help='manual seed to use') 156 | PARSER.add_argument('--loss-str', type=str, default="100*L2+0.05*GEOCROSS", help='Loss 
function to use') 157 | PARSER.add_argument('--pulse-eps', type=float, default=2e-3, help='Target for downscaling loss (L2)') 158 | PARSER.add_argument('--noise-type', type=str, default='trainable', help='zero, fixed, or trainable') 159 | PARSER.add_argument('--tile-latent', action='store_true', help='Whether to forcibly tile the same latent 18 times') 160 | PARSER.add_argument('--lr-schedule', type=str, default='linear1cycledrop', help='fixed, linear1cycledrop, linear1cycle') 161 | 162 | # Output 163 | PARSER.add_argument('--not-lazy', action='store_true', help='whether to disable lazy evaluation') 164 | PARSER.add_argument('--debug', action='store_true', help='debug mode does not save images or stats') 165 | PARSER.add_argument('--print-stats', action='store_true', help='whether to print statistics') 166 | PARSER.add_argument('--checkpoint-iter', type=int, default=50, help='checkpoint every x batches') 167 | PARSER.add_argument('--image-matrix', type=int, default=0, 168 | help=''' 169 | 0 = 00 = no image matrix, 170 | 1 = 01 = show image matrix 171 | 2 = 10 = save image matrix 172 | 3 = 11 = save and show image matrix 173 | ''' 174 | ) 175 | PARSER.add_argument('--gif', action='store_true', help='whether to create a gif') 176 | PARSER.add_argument('--gif-iter', type=int, default=1, help='save gif frame every x iter') 177 | PARSER.add_argument('--gif-dir', type=str, default='', help='where to store gif frames') 178 | 179 | PARSER.add_argument('--cuda', dest='cuda', action='store_true') 180 | PARSER.add_argument('--no-cuda', dest='cuda', action='store_false') 181 | PARSER.set_defaults(cuda=True) 182 | 183 | PARSER.add_argument('--project', dest='project', action='store_true') 184 | PARSER.add_argument('--no-project', dest='project', action='store_false') 185 | PARSER.set_defaults(project=True) 186 | 187 | PARSER.add_argument('--annealed', dest='annealed', action='store_true') 188 | PARSER.add_argument('--no-annealed', dest='annealed', action='store_false') 189 | PARSER.set_defaults(annealed=False) 190 | 191 | HPARAMS = PARSER.parse_args() 192 | HPARAMS.input_path = f'./test_images/{HPARAMS.dataset}' 193 | if HPARAMS.cuda: 194 | HPARAMS.device = 'cuda:0' 195 | else: 196 | HPARAMS.device = 'cpu:0' 197 | 198 | 199 | if HPARAMS.net == 'ncsnv2': 200 | with open(HPARAMS.ncsnv2_configs_file, 'r') as f: 201 | HPARAMS.ncsnv2_configs = yaml.safe_load(f)  # explicit safe loader; a bare yaml.load(f) fails on modern PyYAML 202 | HPARAMS.ncsnv2_configs['sampling']['step_lr'] = HPARAMS.learning_rate 203 | HPARAMS.ncsnv2_configs['sampling']['n_steps_each'] = int(HPARAMS.T) 204 | HPARAMS.ncsnv2_configs['model']['sigma_begin'] = int(HPARAMS.sigma_init) 205 | HPARAMS.ncsnv2_configs['model']['sigma_end'] = HPARAMS.sigma_final 206 | 207 | HPARAMS.image_shape = (3, HPARAMS.image_size, HPARAMS.image_size) 208 | HPARAMS.n_input = np.prod(HPARAMS.image_shape) 209 | 210 | if HPARAMS.measurement_type == 'circulant': 211 | HPARAMS.train_indices = np.random.randint(0, HPARAMS.n_input, HPARAMS.num_measurements) 212 | HPARAMS.sign_pattern = np.float32((np.random.rand(1,HPARAMS.n_input) < 0.5)*2 - 1.)
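# (Worked example of the superres bookkeeping below, with hypothetical values:
# --image-size 256 and --downsample 8 give y_shape = (batch_size, 3, 32, 32), so
# num_measurements = 3 * 32 * 32 = 3072 out of n_input = 3 * 256 * 256 = 196608.)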
213 | elif HPARAMS.measurement_type == 'superres': 214 | HPARAMS.y_shape = (HPARAMS.batch_size, HPARAMS.image_shape[0], HPARAMS.image_size//HPARAMS.downsample, HPARAMS.image_size//HPARAMS.downsample) 215 | HPARAMS.num_measurements = np.prod(HPARAMS.y_shape[1:]) 216 | elif HPARAMS.measurement_type == 'project': 217 | HPARAMS.y_shape = (HPARAMS.batch_size, HPARAMS.image_shape[0], HPARAMS.image_size, HPARAMS.image_size) 218 | HPARAMS.num_measurements = np.prod(HPARAMS.y_shape[1:]) 219 | 220 | 221 | from utils import view_image 222 | 223 | if HPARAMS.mloss_weight < 0: 224 | HPARAMS.mloss_weight = None 225 | if HPARAMS.zprior_weight < 0: 226 | HPARAMS.zprior_weight = None 227 | if HPARAMS.annealed: 228 | if HPARAMS.T < 0: 229 | HPARAMS.T = 200 230 | if HPARAMS.L < 0: 231 | HPARAMS.L = 10 232 | if HPARAMS.sigma_final < 0: 233 | HPARAMS.sigma_final = HPARAMS.noise_std 234 | if HPARAMS.sigma_init < 0: 235 | HPARAMS.sigma_init = 100 * HPARAMS.sigma_final 236 | HPARAMS.max_update_iter = int(HPARAMS.T * HPARAMS.L) 237 | 238 | main(HPARAMS) 239 | 240 | 241 | -------------------------------------------------------------------------------- /stylegan2/op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include <torch/types.h> 8 | 9 | #include <ATen/ATen.h> 10 | #include <ATen/AccumulateType.h> 11 | #include <ATen/cuda/CUDAContext.h> 12 | #include <ATen/cuda/CUDAApplyUtils.cuh> 13 | 14 | #include <cuda.h> 15 | #include <cuda_runtime.h> 16 | 17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 18 | int c = a / b; 19 | 20 | if (c * b > a) { 21 | c--; 22 | } 23 | 24 | return c; 25 | } 26 | 27 | struct UpFirDn2DKernelParams { 28 | int up_x; 29 | int up_y; 30 | int down_x; 31 | int down_y; 32 | int pad_x0; 33 | int pad_x1; 34 | int pad_y0; 35 | int pad_y1; 36 | 37 | int major_dim; 38 | int in_h; 39 | int in_w; 40 | int minor_dim; 41 | int kernel_h; 42 | int kernel_w; 43 | int out_h; 44 | int out_w; 45 | int loop_major; 46 | int loop_x; 47 | }; 48 | 49 | template <typename scalar_t> 50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, 51 | const scalar_t *kernel, 52 | const UpFirDn2DKernelParams p) { 53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; 54 | int out_y = minor_idx / p.minor_dim; 55 | minor_idx -= out_y * p.minor_dim; 56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; 57 | int major_idx_base = blockIdx.z * p.loop_major; 58 | 59 | if (out_x_base >= p.out_w || out_y >= p.out_h || 60 | major_idx_base >= p.major_dim) { 61 | return; 62 | } 63 | 64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; 65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); 66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; 67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; 68 | 69 | for (int loop_major = 0, major_idx = major_idx_base; 70 | loop_major < p.loop_major && major_idx < p.major_dim; 71 | loop_major++, major_idx++) { 72 | for (int loop_x = 0, out_x = out_x_base; 73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { 74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; 75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); 76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; 77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; 78 | 79 | const scalar_t *x_p =
 80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + 81 | minor_idx]; 82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; 83 | int x_px = p.minor_dim; 84 | int k_px = -p.up_x; 85 | int x_py = p.in_w * p.minor_dim; 86 | int k_py = -p.up_y * p.kernel_w; 87 | 88 | scalar_t v = 0.0f; 89 | 90 | for (int y = 0; y < h; y++) { 91 | for (int x = 0; x < w; x++) { 92 | v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p); 93 | x_p += x_px; 94 | k_p += k_px; 95 | } 96 | 97 | x_p += x_py - w * x_px; 98 | k_p += k_py - w * k_px; 99 | } 100 | 101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 102 | minor_idx] = v; 103 | } 104 | } 105 | } 106 | 107 | template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, 108 | int kernel_h, int kernel_w, int tile_out_h, int tile_out_w> 109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, 110 | const scalar_t *kernel, 111 | const UpFirDn2DKernelParams p) { 112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 114 | 115 | __shared__ volatile float sk[kernel_h][kernel_w]; 116 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 117 | 118 | int minor_idx = blockIdx.x; 119 | int tile_out_y = minor_idx / p.minor_dim; 120 | minor_idx -= tile_out_y * p.minor_dim; 121 | tile_out_y *= tile_out_h; 122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 123 | int major_idx_base = blockIdx.z * p.loop_major; 124 | 125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | 126 | major_idx_base >= p.major_dim) { 127 | return; 128 | } 129 | 130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; 131 | tap_idx += blockDim.x) { 132 | int ky = tap_idx / kernel_w; 133 | int kx = tap_idx - ky * kernel_w; 134 | scalar_t v = 0.0; 135 | 136 | if (kx < p.kernel_w & ky < p.kernel_h) { 137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 138 | } 139 | 140 | sk[ky][kx] = v; 141 | } 142 | 143 | for (int loop_major = 0, major_idx = major_idx_base; 144 | loop_major < p.loop_major & major_idx < p.major_dim; 145 | loop_major++, major_idx++) { 146 | for (int loop_x = 0, tile_out_x = tile_out_x_base; 147 | loop_x < p.loop_x & tile_out_x < p.out_w; 148 | loop_x++, tile_out_x += tile_out_w) { 149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 151 | int tile_in_x = floor_div(tile_mid_x, up_x); 152 | int tile_in_y = floor_div(tile_mid_y, up_y); 153 | 154 | __syncthreads(); 155 | 156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 157 | in_idx += blockDim.x) { 158 | int rel_in_y = in_idx / tile_in_w; 159 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 160 | int in_x = rel_in_x + tile_in_x; 161 | int in_y = rel_in_y + tile_in_y; 162 | 163 | scalar_t v = 0.0; 164 | 165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * 167 | p.minor_dim + 168 | minor_idx]; 169 | } 170 | 171 | sx[rel_in_y][rel_in_x] = v; 172 | } 173 | 174 | __syncthreads(); 175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; 176 | out_idx += blockDim.x) { 177 | int rel_out_y = out_idx / tile_out_w; 178 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 179 | int out_x = rel_out_x + tile_out_x; 180 | int out_y = rel_out_y + tile_out_y; 181 | 182 | int mid_x = tile_mid_x + rel_out_x * down_x; 183 | int mid_y = tile_mid_y + rel_out_y * down_y; 184 | int in_x = floor_div(mid_x, up_x); 185 | int in_y = floor_div(mid_y, up_y); 186 | int 
rel_in_x = in_x - tile_in_x; 187 | int rel_in_y = in_y - tile_in_y; 188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 190 | 191 | scalar_t v = 0.0; 192 | 193 | #pragma unroll 194 | for (int y = 0; y < kernel_h / up_y; y++) 195 | #pragma unroll 196 | for (int x = 0; x < kernel_w / up_x; x++) 197 | v += sx[rel_in_y + y][rel_in_x + x] * 198 | sk[kernel_y + y * up_y][kernel_x + x * up_x]; 199 | 200 | if (out_x < p.out_w & out_y < p.out_h) { 201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 202 | minor_idx] = v; 203 | } 204 | } 205 | } 206 | } 207 | } 208 | 209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 210 | const torch::Tensor &kernel, int up_x, int up_y, 211 | int down_x, int down_y, int pad_x0, int pad_x1, 212 | int pad_y0, int pad_y1) { 213 | int curDevice = -1; 214 | cudaGetDevice(&curDevice); 215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 216 | 217 | UpFirDn2DKernelParams p; 218 | 219 | auto x = input.contiguous(); 220 | auto k = kernel.contiguous(); 221 | 222 | p.major_dim = x.size(0); 223 | p.in_h = x.size(1); 224 | p.in_w = x.size(2); 225 | p.minor_dim = x.size(3); 226 | p.kernel_h = k.size(0); 227 | p.kernel_w = k.size(1); 228 | p.up_x = up_x; 229 | p.up_y = up_y; 230 | p.down_x = down_x; 231 | p.down_y = down_y; 232 | p.pad_x0 = pad_x0; 233 | p.pad_x1 = pad_x1; 234 | p.pad_y0 = pad_y0; 235 | p.pad_y1 = pad_y1; 236 | 237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / 238 | p.down_y; 239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / 240 | p.down_x; 241 | 242 | auto out = 243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 244 | 245 | int mode = -1; 246 | 247 | int tile_out_h = -1; 248 | int tile_out_w = -1; 249 | 250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 251 | p.kernel_h <= 4 && p.kernel_w <= 4) { 252 | mode = 1; 253 | tile_out_h = 16; 254 | tile_out_w = 64; 255 | } 256 | 257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 258 | p.kernel_h <= 3 && p.kernel_w <= 3) { 259 | mode = 2; 260 | tile_out_h = 16; 261 | tile_out_w = 64; 262 | } 263 | 264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 265 | p.kernel_h <= 4 && p.kernel_w <= 4) { 266 | mode = 3; 267 | tile_out_h = 16; 268 | tile_out_w = 64; 269 | } 270 | 271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 272 | p.kernel_h <= 2 && p.kernel_w <= 2) { 273 | mode = 4; 274 | tile_out_h = 16; 275 | tile_out_w = 64; 276 | } 277 | 278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 279 | p.kernel_h <= 4 && p.kernel_w <= 4) { 280 | mode = 5; 281 | tile_out_h = 8; 282 | tile_out_w = 32; 283 | } 284 | 285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 286 | p.kernel_h <= 2 && p.kernel_w <= 2) { 287 | mode = 6; 288 | tile_out_h = 8; 289 | tile_out_w = 32; 290 | } 291 | 292 | dim3 block_size; 293 | dim3 grid_size; 294 | 295 | if (tile_out_h > 0 && tile_out_w > 0) { 296 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 297 | p.loop_x = 1; 298 | block_size = dim3(32 * 8, 1, 1); 299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 301 | (p.major_dim - 1) / p.loop_major + 1); 302 | } else { 303 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 304 | p.loop_x = 4; 305 | block_size = dim3(4, 32, 1); 306 | grid_size = dim3((p.out_h * 
p.minor_dim - 1) / block_size.x + 1, 307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1, 308 | (p.major_dim - 1) / p.loop_major + 1); 309 | } 310 | 311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 312 | switch (mode) { 313 | case 1: 314 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64> 315 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 316 | x.data_ptr<scalar_t>(), 317 | k.data_ptr<scalar_t>(), p); 318 | 319 | break; 320 | 321 | case 2: 322 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64> 323 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 324 | x.data_ptr<scalar_t>(), 325 | k.data_ptr<scalar_t>(), p); 326 | 327 | break; 328 | 329 | case 3: 330 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64> 331 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 332 | x.data_ptr<scalar_t>(), 333 | k.data_ptr<scalar_t>(), p); 334 | 335 | break; 336 | 337 | case 4: 338 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64> 339 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 340 | x.data_ptr<scalar_t>(), 341 | k.data_ptr<scalar_t>(), p); 342 | 343 | break; 344 | 345 | case 5: 346 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32> 347 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 348 | x.data_ptr<scalar_t>(), 349 | k.data_ptr<scalar_t>(), p); 350 | 351 | break; 352 | 353 | case 6: 354 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32> 355 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 356 | x.data_ptr<scalar_t>(), 357 | k.data_ptr<scalar_t>(), p); 358 | 359 | break; 360 | 361 | default: 362 | upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>( 363 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), 364 | k.data_ptr<scalar_t>(), p); 365 | } 366 | }); 367 | 368 | return out; 369 | }
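(Aside: upfirdn2d_op above is the classic upsample-FIR-downsample primitive. The following NumPy sketch of its semantics is illustrative only — the helper name upfirdn2d_ref, the single-channel shapes, and the non-negative-padding restriction are assumptions of this sketch, not the repository's own fallback path:)

import numpy as np
from scipy.signal import convolve2d

def upfirdn2d_ref(x, k, up=1, down=1, pad=(0, 0)):
    # zero-stuff by `up`, pad, then truly convolve with the FIR kernel k
    # (the CUDA kernel loads k flipped into sk[][], i.e. convolution, not correlation),
    # and finally keep every `down`-th sample
    up_x = np.zeros((x.shape[0] * up, x.shape[1] * up), dtype=x.dtype)
    up_x[::up, ::up] = x
    up_x = np.pad(up_x, ((pad[0], pad[1]), (pad[0], pad[1])))
    out = convolve2d(up_x, k, mode='valid')
    return out[::down, ::down]

# The output size agrees with the formula above:
# out_h = (in_h * up + pad_y0 + pad_y1 - kernel_h + down) // down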
-------------------------------------------------------------------------------- /stylegan2/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import functools 4 | import operator 5 | 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | from torch.autograd import Function 10 | 11 | from .op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d 12 | 13 | 14 | class PixelNorm(nn.Module): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def forward(self, input): 19 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) 20 | 21 | 22 | def make_kernel(k): 23 | k = torch.tensor(k, dtype=torch.float32) 24 | 25 | if k.ndim == 1: 26 | k = k[None, :] * k[:, None] 27 | 28 | k /= k.sum() 29 | 30 | return k 31 | 32 | 33 | class Upsample(nn.Module): 34 | def __init__(self, kernel, factor=2): 35 | super().__init__() 36 | 37 | self.factor = factor 38 | kernel = make_kernel(kernel) * (factor ** 2) 39 | self.register_buffer('kernel', kernel) 40 | 41 | p = kernel.shape[0] - factor 42 | 43 | pad0 = (p + 1) // 2 + factor - 1 44 | pad1 = p // 2 45 | 46 | self.pad = (pad0, pad1) 47 | 48 | def forward(self, input): 49 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) 50 | 51 | return out 52 | 53 | 54 | class Downsample(nn.Module): 55 | def __init__(self, kernel, factor=2): 56 | super().__init__() 57 | 58 | self.factor = factor 59 | kernel = make_kernel(kernel) 60 | self.register_buffer('kernel', kernel) 61 | 62 | p = kernel.shape[0] - factor 63 | 64 | pad0 = (p + 1) // 2 65 | pad1 = p // 2 66 | 67 | self.pad = (pad0, pad1) 68 | 69 | def forward(self, input): 70 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) 71 | 72 | return out 73 | 74 | 75 | class Blur(nn.Module): 76 | def __init__(self, kernel, pad, upsample_factor=1): 77 | super().__init__() 78 | 79 | kernel = make_kernel(kernel) 80 | 81 | if upsample_factor > 1: 82 | kernel = kernel * (upsample_factor ** 2) 83 | 84 | self.register_buffer('kernel', kernel) 85 | 86 | self.pad = pad 87 | 88 | def forward(self, input): 89 | out = upfirdn2d(input, self.kernel, pad=self.pad) 90 | 91 | return out 92 | 93 | 94 | class EqualConv2d(nn.Module): 95 | def __init__( 96 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True 97 | ): 98 | super().__init__() 99 | 100 | self.weight = nn.Parameter( 101 | torch.randn(out_channel, in_channel, kernel_size, kernel_size) 102 | ) 103 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) 104 | 105 | self.stride = stride 106 | self.padding = padding 107 | 108 | if bias: 109 | self.bias = nn.Parameter(torch.zeros(out_channel)) 110 | 111 | else: 112 | self.bias = None 113 | 114 | def forward(self, input): 115 | out = F.conv2d( 116 | input, 117 | self.weight * self.scale, 118 | bias=self.bias, 119 | stride=self.stride, 120 | padding=self.padding, 121 | ) 122 | 123 | return out 124 | 125 | def __repr__(self): 126 | return ( 127 | f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' 128 | f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' 129 | ) 130 | 131 | 132 | class EqualLinear(nn.Module): 133 | def __init__( 134 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None 135 | ): 136 | super().__init__() 137 | 138 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) 139 | 140 | if bias: 141 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) 142 | 143 | else: 144 | self.bias = None 145 | 146 | self.activation = activation 147 | 148 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul 149 | self.lr_mul = lr_mul 150 | 151 | def forward(self, input): 152 | if self.activation: 153 | out = F.linear(input, self.weight * self.scale) 154 | out = fused_leaky_relu(out, self.bias * self.lr_mul) 155 | 156 | else: 157 | out = F.linear( 158 | input, self.weight * self.scale, bias=self.bias * self.lr_mul 159 | ) 160 | 161 | return out 162 | 163 | def __repr__(self): 164 | return ( 165 | f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' 166 | ) 167 | 168 | 169 | class ScaledLeakyReLU(nn.Module): 170 | def __init__(self, negative_slope=0.2): 171 | super().__init__() 172 | 173 | self.negative_slope = negative_slope 174 | 175 | def forward(self, input): 176 | out = F.leaky_relu(input, negative_slope=self.negative_slope) 177 | 178 | return out * math.sqrt(2) 179 | 180 | 181 | class ModulatedConv2d(nn.Module): 182 | def __init__( 183 | self, 184 | in_channel, 185 | out_channel, 186 | kernel_size, 187 | style_dim, 188 | demodulate=True, 189 | upsample=False, 190 | downsample=False, 191 | blur_kernel=[1, 3, 3, 1], 192 | ): 193 | super().__init__() 194 | 195 | self.eps = 1e-8 196 | self.kernel_size = kernel_size 197 | self.in_channel = in_channel 198 | self.out_channel = out_channel 199 | self.upsample = upsample 200 | self.downsample = downsample 201 | 202 | if upsample: 203 | factor = 2 204 | p = (len(blur_kernel) - factor) - (kernel_size - 1) 205 | pad0 = (p + 1) // 2 + factor - 1 206 | pad1 = p // 2 + 1 207 | 208 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) 209 | 210 | if downsample: 211 | factor = 2 212 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 213 | pad0 = (p + 1) // 2 214 | pad1 = p // 2 215 | 216 | self.blur = Blur(blur_kernel, pad=(pad0, pad1)) 217 | 218 | fan_in = in_channel * kernel_size ** 2 219 | self.scale = 1 / math.sqrt(fan_in) 220 | self.padding = kernel_size // 2 221 | 222 | self.weight = nn.Parameter( 223 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) 224 | ) 225 | 226 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) 227 | 228 | self.demodulate = demodulate 229 | 230 | def __repr__(self): 231 | return ( 232 | 
f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' 233 | f'upsample={self.upsample}, downsample={self.downsample})' 234 | ) 235 | 236 | def forward(self, input, style): 237 | batch, in_channel, height, width = input.shape 238 | 239 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1) 240 | weight = self.scale * self.weight * style 241 | 242 | if self.demodulate: 243 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) 244 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) 245 | 246 | weight = weight.view( 247 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size 248 | ) 249 | 250 | if self.upsample: 251 | input = input.view(1, batch * in_channel, height, width) 252 | weight = weight.view( 253 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size 254 | ) 255 | weight = weight.transpose(1, 2).reshape( 256 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size 257 | ) 258 | out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) 259 | _, _, height, width = out.shape 260 | out = out.view(batch, self.out_channel, height, width) 261 | out = self.blur(out) 262 | 263 | elif self.downsample: 264 | input = self.blur(input) 265 | _, _, height, width = input.shape 266 | input = input.view(1, batch * in_channel, height, width) 267 | out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) 268 | _, _, height, width = out.shape 269 | out = out.view(batch, self.out_channel, height, width) 270 | 271 | else: 272 | input = input.view(1, batch * in_channel, height, width) 273 | out = F.conv2d(input, weight, padding=self.padding, groups=batch) 274 | _, _, height, width = out.shape 275 | out = out.view(batch, self.out_channel, height, width) 276 | 277 | return out 278 | 279 | 280 | class NoiseInjection(nn.Module): 281 | def __init__(self): 282 | super().__init__() 283 | 284 | self.weight = nn.Parameter(torch.zeros(1)) 285 | 286 | def forward(self, image, noise=None): 287 | if noise is None: 288 | batch, _, height, width = image.shape 289 | noise = image.new_empty(batch, 1, height, width).normal_() 290 | 291 | return image + self.weight * noise 292 | 293 | 294 | class ConstantInput(nn.Module): 295 | def __init__(self, channel, size=4): 296 | super().__init__() 297 | 298 | self.input = nn.Parameter(torch.randn(1, channel, size, size)) 299 | 300 | def forward(self, input): 301 | batch = input.shape[0] 302 | out = self.input.repeat(batch, 1, 1, 1) 303 | 304 | return out 305 | 306 | 307 | class StyledConv(nn.Module): 308 | def __init__( 309 | self, 310 | in_channel, 311 | out_channel, 312 | kernel_size, 313 | style_dim, 314 | upsample=False, 315 | blur_kernel=[1, 3, 3, 1], 316 | demodulate=True, 317 | ): 318 | super().__init__() 319 | 320 | self.conv = ModulatedConv2d( 321 | in_channel, 322 | out_channel, 323 | kernel_size, 324 | style_dim, 325 | upsample=upsample, 326 | blur_kernel=blur_kernel, 327 | demodulate=demodulate, 328 | ) 329 | 330 | self.noise = NoiseInjection() 331 | # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 332 | # self.activate = ScaledLeakyReLU(0.2) 333 | self.activate = FusedLeakyReLU(out_channel) 334 | 335 | def forward(self, input, style, noise=None): 336 | out = self.conv(input, style) 337 | out = self.noise(out, noise=noise) 338 | # out = out + self.bias 339 | out = self.activate(out) 340 | 341 | return out 342 | 343 | 344 | class ToRGB(nn.Module): 345 | def __init__(self, in_channel, style_dim, upsample=True, 
blur_kernel=[1, 3, 3, 1]): 346 | super().__init__() 347 | 348 | if upsample: 349 | self.upsample = Upsample(blur_kernel) 350 | 351 | self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) 352 | self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) 353 | 354 | def forward(self, input, style, skip=None): 355 | out = self.conv(input, style) 356 | out = out + self.bias 357 | 358 | if skip is not None: 359 | skip = self.upsample(skip) 360 | 361 | out = out + skip 362 | 363 | return out 364 | 365 | 366 | class Generator(nn.Module): 367 | def __init__( 368 | self, 369 | size, 370 | style_dim, 371 | n_mlp, 372 | channel_multiplier=2, 373 | blur_kernel=[1, 3, 3, 1], 374 | lr_mlp=0.01, 375 | ): 376 | super().__init__() 377 | 378 | self.size = size 379 | 380 | self.style_dim = style_dim 381 | 382 | layers = [PixelNorm()] 383 | 384 | for i in range(n_mlp): 385 | layers.append( 386 | EqualLinear( 387 | style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu' 388 | ) 389 | ) 390 | 391 | self.style = nn.Sequential(*layers) 392 | 393 | self.channels = { 394 | 4: 512, 395 | 8: 512, 396 | 16: 512, 397 | 32: 512, 398 | 64: 256 * channel_multiplier, 399 | 128: 128 * channel_multiplier, 400 | 256: 64 * channel_multiplier, 401 | 512: 32 * channel_multiplier, 402 | 1024: 16 * channel_multiplier, 403 | } 404 | 405 | self.input = ConstantInput(self.channels[4]) 406 | self.conv1 = StyledConv( 407 | self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel 408 | ) 409 | self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) 410 | 411 | self.log_size = int(math.log(size, 2)) 412 | self.num_layers = (self.log_size - 2) * 2 + 1 413 | 414 | self.convs = nn.ModuleList() 415 | self.upsamples = nn.ModuleList() 416 | self.to_rgbs = nn.ModuleList() 417 | self.noises = nn.Module() 418 | 419 | in_channel = self.channels[4] 420 | 421 | for layer_idx in range(self.num_layers): 422 | res = (layer_idx + 5) // 2 423 | shape = [1, 1, 2 ** res, 2 ** res] 424 | self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape)) 425 | 426 | for i in range(3, self.log_size + 1): 427 | out_channel = self.channels[2 ** i] 428 | 429 | self.convs.append( 430 | StyledConv( 431 | in_channel, 432 | out_channel, 433 | 3, 434 | style_dim, 435 | upsample=True, 436 | blur_kernel=blur_kernel, 437 | ) 438 | ) 439 | 440 | self.convs.append( 441 | StyledConv( 442 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel 443 | ) 444 | ) 445 | 446 | self.to_rgbs.append(ToRGB(out_channel, style_dim)) 447 | 448 | in_channel = out_channel 449 | 450 | self.n_latent = self.log_size * 2 - 2 451 | 452 | def make_noise(self): 453 | device = self.input.input.device 454 | 455 | noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] 456 | 457 | for i in range(3, self.log_size + 1): 458 | for _ in range(2): 459 | noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) 460 | 461 | return noises 462 | 463 | def mean_latent(self, n_latent): 464 | latent_in = torch.randn( 465 | n_latent, self.style_dim, device=self.input.input.device 466 | ) 467 | latent = self.style(latent_in).mean(0, keepdim=True) 468 | 469 | return latent 470 | 471 | def get_latent(self, input): 472 | return self.style(input) 473 | 474 | def forward( 475 | self, 476 | styles, 477 | return_latents=False, 478 | inject_index=None, 479 | truncation=1, 480 | truncation_latent=None, 481 | input_is_latent=False, 482 | noise=None, 483 | randomize_noise=True, 484 | ): 485 | if not input_is_latent: 486 | styles = [self.style(s) for s in 
styles] 487 | 488 | if noise is None: 489 | if randomize_noise: 490 | noise = [None] * self.num_layers 491 | else: 492 | noise = [ 493 | getattr(self.noises, f'noise_{i}') for i in range(self.num_layers) 494 | ] 495 | 496 | if truncation < 1: 497 | style_t = [] 498 | 499 | for style in styles: 500 | style_t.append( 501 | truncation_latent + truncation * (style - truncation_latent) 502 | ) 503 | 504 | styles = style_t 505 | 506 | if len(styles) < 2: 507 | inject_index = self.n_latent 508 | 509 | if styles[0].ndim < 3: 510 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 511 | 512 | else: 513 | latent = styles[0] 514 | 515 | else: 516 | if inject_index is None: 517 | inject_index = random.randint(1, self.n_latent - 1) 518 | 519 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 520 | latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) 521 | 522 | latent = torch.cat([latent, latent2], 1) 523 | 524 | out = self.input(latent) 525 | out = self.conv1(out, latent[:, 0], noise=noise[0]) 526 | 527 | skip = self.to_rgb1(out, latent[:, 1]) 528 | 529 | i = 1 530 | for conv1, conv2, noise1, noise2, to_rgb in zip( 531 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs 532 | ): 533 | out = conv1(out, latent[:, i], noise=noise1) 534 | out = conv2(out, latent[:, i + 1], noise=noise2) 535 | skip = to_rgb(out, latent[:, i + 2], skip) 536 | 537 | i += 2 538 | 539 | image = skip 540 | 541 | if return_latents: 542 | return image, latent 543 | 544 | else: 545 | return image, None 546 | 547 | 548 | class ConvLayer(nn.Sequential): 549 | def __init__( 550 | self, 551 | in_channel, 552 | out_channel, 553 | kernel_size, 554 | downsample=False, 555 | blur_kernel=[1, 3, 3, 1], 556 | bias=True, 557 | activate=True, 558 | ): 559 | layers = [] 560 | 561 | if downsample: 562 | factor = 2 563 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 564 | pad0 = (p + 1) // 2 565 | pad1 = p // 2 566 | 567 | layers.append(Blur(blur_kernel, pad=(pad0, pad1))) 568 | 569 | stride = 2 570 | self.padding = 0 571 | 572 | else: 573 | stride = 1 574 | self.padding = kernel_size // 2 575 | 576 | layers.append( 577 | EqualConv2d( 578 | in_channel, 579 | out_channel, 580 | kernel_size, 581 | padding=self.padding, 582 | stride=stride, 583 | bias=bias and not activate, 584 | ) 585 | ) 586 | 587 | if activate: 588 | if bias: 589 | layers.append(FusedLeakyReLU(out_channel)) 590 | 591 | else: 592 | layers.append(ScaledLeakyReLU(0.2)) 593 | 594 | super().__init__(*layers) 595 | 596 | 597 | class ResBlock(nn.Module): 598 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): 599 | super().__init__() 600 | 601 | self.conv1 = ConvLayer(in_channel, in_channel, 3) 602 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) 603 | 604 | self.skip = ConvLayer( 605 | in_channel, out_channel, 1, downsample=True, activate=False, bias=False 606 | ) 607 | 608 | def forward(self, input): 609 | out = self.conv1(input) 610 | out = self.conv2(out) 611 | 612 | skip = self.skip(input) 613 | out = (out + skip) / math.sqrt(2) 614 | 615 | return out 616 | 617 | 618 | class Discriminator(nn.Module): 619 | def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): 620 | super().__init__() 621 | 622 | channels = { 623 | 4: 512, 624 | 8: 512, 625 | 16: 512, 626 | 32: 512, 627 | 64: 256 * channel_multiplier, 628 | 128: 128 * channel_multiplier, 629 | 256: 64 * channel_multiplier, 630 | 512: 32 * channel_multiplier, 631 | 1024: 16 * channel_multiplier, 632 | 
} 633 | 634 | convs = [ConvLayer(3, channels[size], 1)] 635 | 636 | log_size = int(math.log(size, 2)) 637 | 638 | in_channel = channels[size] 639 | 640 | for i in range(log_size, 2, -1): 641 | out_channel = channels[2 ** (i - 1)] 642 | 643 | convs.append(ResBlock(in_channel, out_channel, blur_kernel)) 644 | 645 | in_channel = out_channel 646 | 647 | self.convs = nn.Sequential(*convs) 648 | 649 | self.stddev_group = 4 650 | self.stddev_feat = 1 651 | 652 | self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) 653 | self.final_linear = nn.Sequential( 654 | EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), 655 | EqualLinear(channels[4], 1), 656 | ) 657 | 658 | def forward(self, input): 659 | out = self.convs(input) 660 | 661 | batch, channel, height, width = out.shape 662 | group = min(batch, self.stddev_group) 663 | stddev = out.view( 664 | group, -1, self.stddev_feat, channel // self.stddev_feat, height, width 665 | ) 666 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) 667 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) 668 | stddev = stddev.repeat(group, 1, height, width) 669 | out = torch.cat([out, stddev], 1) 670 | 671 | out = self.final_conv(out) 672 | 673 | out = out.view(batch, -1) 674 | out = self.final_linear(out) 675 | 676 | return out 677 | 678 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | """Some common utils""" 2 | # pylint: disable = C0301, C0103, C0111 3 | 4 | import os 5 | import pickle 6 | import shutil 7 | import torch 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | from matplotlib.backends.backend_pdf import PdfPages 11 | from skimage.transform import downscale_local_mean 12 | import sys 13 | 14 | import glob 15 | from torchvision import transforms, datasets 16 | import torch 17 | 18 | from PIL import Image 19 | 20 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 21 | 22 | try: 23 | import tensorflow as tf 24 | except ImportError: 25 | print('did not import tensorflow') 26 | 27 | import estimators 28 | font = { 'weight' : 'bold', 29 | 'size' : 30} 30 | import matplotlib 31 | matplotlib.rc('font', **font) 32 | matplotlib.rcParams['pdf.fonttype'] = 42 33 | matplotlib.rcParams['ps.fonttype'] = 42 34 | 35 | 36 | class BestKeeper(object): 37 | """Class to keep the best stuff""" 38 | def __init__(self, batch_size, n_input): 39 | self.batch_size = batch_size 40 | self.losses_val_best = [1e100 for _ in range(batch_size)] 41 | self.x_hat_batch_val_best = np.zeros((batch_size, n_input)) 42 | 43 | def report(self, x_hat_batch_val, losses_val): 44 | for i in range(self.batch_size): 45 | if losses_val[i] < self.losses_val_best[i]: 46 | self.x_hat_batch_val_best[i, :] = x_hat_batch_val[i, :] 47 | self.losses_val_best[i] = losses_val[i] 48 | 49 | def get_best(self): 50 | return self.x_hat_batch_val_best 51 | 52 | 53 | def get_l2_loss(image1, image2): 54 | """Get L2 loss between the two images""" 55 | assert image1.shape == image2.shape 56 | return np.mean((image1 - image2)**2) 57 | 58 | 59 | def get_measurement_loss(x_hat, A, y, hparams): 60 | """Get measurement loss of the estimated image""" 61 | y_hat = get_measurements(x_hat, A, 0 , hparams) 62 | # measurements are in a batch of size 1. 
63 | y_hat = y_hat.reshape(y.shape) 64 | # if A is None: 65 | # y_hat = x_hat 66 | # else: 67 | # y_hat = np.matmul(x_hat, A) 68 | assert y_hat.shape == y.shape 69 | return np.sum((y - y_hat) ** 2) 70 | 71 | 72 | def save_to_pickle(data, pkl_filepath): 73 | """Save the data to a pickle file""" 74 | with open(pkl_filepath, 'wb') as pkl_file: 75 | pickle.dump(data, pkl_file) 76 | 77 | 78 | def load_if_pickled(pkl_filepath): 79 | """Load if the pickle file exists. Else return empty dict""" 80 | if os.path.isfile(pkl_filepath): 81 | with open(pkl_filepath, 'rb') as pkl_file: 82 | data = pickle.load(pkl_file) 83 | else: 84 | data = {} 85 | return data 86 | 87 | 88 | def get_estimator(hparams, model_type): 89 | if model_type == 'map' and hparams.net == 'glow': 90 | estimator = estimators.glow_annealed_map_estimator(hparams) 91 | elif model_type == 'langevin' and hparams.net == 'glow': 92 | estimator = estimators.glow_annealed_langevin_estimator(hparams) 93 | elif model_type == 'langevin' and hparams.net == 'stylegan2': 94 | estimator = estimators.stylegan_langevin_estimator(hparams) 95 | elif model_type == 'pulse': 96 | assert hparams.net.lower() == 'stylegan2' 97 | estimator = estimators.stylegan_pulse_estimator(hparams) 98 | elif model_type == 'map' and hparams.net == 'ncsnv2': 99 | estimator = estimators.ncsnv2_langevin_estimator(hparams, MAP=True) 100 | elif model_type == 'langevin' and hparams.net == 'ncsnv2': 101 | estimator = estimators.ncsnv2_langevin_estimator(hparams, MAP=False) 102 | elif hparams.net == 'dd' or hparams.model_type == 'dd': 103 | estimator = estimators.deep_decoder_estimator(hparams) 104 | else: 105 | raise NotImplementedError 106 | return estimator 107 | 108 | def setup_checkpointing(hparams): 109 | # Set up checkpoint directories 110 | checkpoint_dir = get_checkpoint_dir(hparams, hparams.model_type) 111 | set_up_dir(checkpoint_dir) 112 | 113 | 114 | def save_images(est_images, save_image, hparams): 115 | """Save a batch of images to png files""" 116 | for image_num, image in est_images.items(): 117 | save_path = get_save_path(hparams, image_num) 118 | image = image.reshape(hparams.image_shape) 119 | save_image(image, save_path) 120 | 121 | 122 | def checkpoint(est_images, measurement_losses, l2_losses, z_hats, likelihoods, hparams): 123 | """Save images, measurement losses and L2 losses for a batch""" 124 | if not hparams.debug: 125 | save_images(est_images, save_image, hparams) 126 | 127 | m_losses_filepath, l2_losses_filepath, z_hats_filepath, likelihoods_filepath = get_pkl_filepaths(hparams, hparams.model_type) 128 | save_to_pickle(measurement_losses, m_losses_filepath) 129 | save_to_pickle(l2_losses, l2_losses_filepath) 130 | save_to_pickle(z_hats, z_hats_filepath) 131 | save_to_pickle(likelihoods, likelihoods_filepath) 132 | 133 | 134 | def load_checkpoints(hparams): 135 | measurement_losses, l2_losses, z_hats, likelihoods = {}, {}, {}, {} 136 | if not hparams.debug: 137 | # Load pickled loss dictionaries 138 | m_losses_filepath, l2_losses_filepath, z_hats_filepath, likelihoods_filepath = get_pkl_filepaths(hparams, hparams.model_type) 139 | measurement_losses = load_if_pickled(m_losses_filepath) 140 | l2_losses = load_if_pickled(l2_losses_filepath) 141 | z_hats = load_if_pickled(z_hats_filepath) 142 | likelihoods = load_if_pickled(likelihoods_filepath) 143 | else: 144 | measurement_losses = {} 145 | l2_losses = {} 146 | z_hats = {} 147 | likelihoods = {} 148 | return measurement_losses, l2_losses, z_hats, likelihoods 149 | 150 | 151 | def image_matrix(images, 
est_images, l2_losses, hparams, alg_labels=True): 152 | """Display images""" 153 | 154 | if hparams.measurement_type in ['inpaint', 'superres']: 155 | figure_height = 2 + len(hparams.model_types) 156 | else: 157 | figure_height = 1 + len(hparams.model_types) 158 | 159 | fig = plt.figure(figsize=[4*len(images), 4.3*figure_height]) 160 | 161 | outer_counter = 0 162 | inner_counter = 0 163 | 164 | # Show original images 165 | outer_counter += 1 166 | for image in images.values(): 167 | inner_counter += 1 168 | ax = fig.add_subplot(figure_height, 1, outer_counter, frameon=False) 169 | ax.get_xaxis().set_visible(False) 170 | ax.get_yaxis().set_ticks([]) 171 | if alg_labels: 172 | ax.set_ylabel('Original')#, fontsize=14) 173 | _ = fig.add_subplot(figure_height, len(images), inner_counter) 174 | view_image(image, hparams) 175 | 176 | # Show original images with inpainting mask 177 | if hparams.measurement_type == 'inpaint': 178 | mask = get_inpaint_mask(hparams) 179 | outer_counter += 1 180 | for image in images.values(): 181 | inner_counter += 1 182 | ax = fig.add_subplot(figure_height, 1, outer_counter, frameon=False) 183 | ax.get_xaxis().set_visible(False) 184 | ax.get_yaxis().set_ticks([]) 185 | if alg_labels: 186 | ax.set_ylabel('Masked') #, fontsize=14) 187 | _ = fig.add_subplot(figure_height, len(images), inner_counter) 188 | view_image(image, hparams, mask) 189 | 190 | # Show original images with blurring 191 | if hparams.measurement_type == 'superres': 192 | factor = hparams.superres_factor 193 | A = get_A_superres(hparams) 194 | outer_counter += 1 195 | for image in images.values(): 196 | image_low_res = np.matmul(image, A) / np.sqrt(hparams.n_input/(factor**2)) / (factor**2) 197 | low_res_shape = (int(hparams.image_shape[0]/factor), int(hparams.image_shape[1]/factor), hparams.image_shape[2]) 198 | image_low_res = np.reshape(image_low_res, low_res_shape) 199 | inner_counter += 1 200 | ax = fig.add_subplot(figure_height, 1, outer_counter, frameon=False) 201 | ax.get_xaxis().set_visible(False) 202 | ax.get_yaxis().set_ticks([]) 203 | if alg_labels: 204 | ax.set_ylabel('Blurred') #, fontsize=14) 205 | _ = fig.add_subplot(figure_height, len(images), inner_counter) 206 | view_image(image_low_res, hparams) 207 | 208 | for model_type in hparams.model_types: 209 | outer_counter += 1 210 | for image, l2 in zip(est_images[model_type].values(), l2_losses[model_type].values()): 211 | inner_counter += 1 212 | ax = fig.add_subplot(figure_height, 1, outer_counter, frameon=False) 213 | # ax.get_xaxis().set_visible(False) 214 | ax.get_yaxis().set_ticks([]) 215 | ax.get_xaxis().set_ticks([]) 216 | if alg_labels: 217 | ax.set_ylabel(model_type) #, fontsize=14) 218 | _ = fig.add_subplot(figure_height, len(images), inner_counter) 219 | #_.set_title(f'PSNR={-10*np.log10(l2):.3f}', fontsize=8) 220 | view_image(image, hparams) 221 | 222 | if hparams.image_matrix >= 2: 223 | save_path = get_matrix_save_path(hparams) 224 | plt.savefig(save_path) 225 | 226 | if hparams.image_matrix in [1, 3]: 227 | plt.show() 228 | 229 | 230 | def plot_image(image, cmap=None): 231 | """Show the image""" 232 | frame = plt.gca() 233 | frame.axes.get_xaxis().set_visible(False) 234 | frame.axes.get_yaxis().set_visible(False) 235 | frame = frame.imshow(image, cmap=cmap) 236 | 237 | 238 | def get_checkpoint_dir(hparams, model_type): 239 | if hparams.annealed : 240 | base_dir = './estimated/{0}/{1}/{2}/{3}/{4}/{5}/annealed_{6}/'.format( 241 | hparams.dataset, 242 | hparams.input_type, 243 | hparams.measurement_type, 244 | 
hparams.noise_std, 245 | hparams.num_measurements, 246 | hparams.net, 247 | model_type 248 | ) 249 | else: 250 | base_dir = './estimated/{0}/{1}/{2}/{3}/{4}/{5}/{6}/'.format( 251 | hparams.dataset, 252 | hparams.input_type, 253 | hparams.measurement_type, 254 | hparams.noise_std, 255 | hparams.num_measurements, 256 | hparams.net, 257 | model_type 258 | ) 259 | 260 | 261 | if hparams.net == 'ncsnv2': 262 | dir_name = '{}_{}_{}_{}_{}_{}_{}_{}_{}'.format( 263 | hparams.mloss_weight, 264 | hparams.ncsnv2_configs['sampling']['step_lr'], 265 | hparams.ncsnv2_configs['sampling']['n_steps_each'], 266 | hparams.ncsnv2_configs['model']['sigma_begin'], 267 | int(hparams.L), 268 | hparams.ncsnv2_configs['model']['ema'], 269 | hparams.ncsnv2_configs['model']['ema_rate'], 270 | hparams.ncsnv2_configs['model']['sigma_dist'], 271 | hparams.ncsnv2_configs['model']['sigma_end'] 272 | ) 273 | elif hparams.net == 'dd' or hparams.model_type == 'dd': 274 | dir_name = '{}_{}_{}'.format( 275 | hparams.optimizer_type, 276 | hparams.learning_rate, 277 | hparams.max_update_iter 278 | ) 279 | 280 | else: 281 | 282 | if model_type in ['map'] and (not hparams.annealed): 283 | dir_name = '{}_{}_{}_{}_{}_{}_{}_{}'.format( 284 | hparams.mloss_weight, 285 | hparams.zprior_weight, 286 | hparams.optimizer_type, 287 | hparams.project, 288 | hparams.learning_rate, 289 | hparams.momentum, 290 | hparams.max_update_iter, 291 | hparams.num_random_restarts, 292 | ) 293 | elif model_type in ['map'] and hparams.annealed: 294 | if hparams.net == 'stylegan2': 295 | dir_name = '{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}'.format( 296 | hparams.zprior_weight, 297 | hparams.T, 298 | hparams.L, 299 | hparams.sigma_init, 300 | hparams.sigma_final, 301 | hparams.optimizer_type, 302 | hparams.project, 303 | hparams.learning_rate, 304 | hparams.momentum, 305 | hparams.max_update_iter, 306 | hparams.num_random_restarts, 307 | hparams.num_noise_variables 308 | ) 309 | else: 310 | dir_name = '{}_{}_{}_{}_{}_{}_{}_{}_{}_{}'.format( 311 | hparams.zprior_weight, 312 | hparams.T, 313 | hparams.L, 314 | hparams.sigma_init, 315 | hparams.sigma_final, 316 | hparams.optimizer_type, 317 | hparams.learning_rate, 318 | hparams.momentum, 319 | hparams.max_update_iter, 320 | hparams.num_random_restarts, 321 | ) 322 | elif model_type in ['langevin'] and not hparams.annealed: 323 | dir_name = '{}_{}_{}_{}_{}_{}_{}'.format( 324 | hparams.mloss_weight, 325 | hparams.zprior_weight, 326 | hparams.optimizer_type, 327 | hparams.learning_rate, 328 | hparams.momentum, 329 | hparams.max_update_iter, 330 | hparams.num_random_restarts, 331 | ) 332 | elif model_type in ['langevin'] and hparams.annealed: 333 | if hparams.net == 'stylegan2': 334 | dir_name = '{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}'.format( 335 | hparams.zprior_weight, 336 | hparams.T, 337 | hparams.L, 338 | hparams.sigma_init, 339 | hparams.sigma_final, 340 | hparams.zprior_init_sdev, 341 | hparams.zprior_sdev, 342 | hparams.optimizer_type, 343 | hparams.learning_rate, 344 | hparams.momentum, 345 | hparams.max_update_iter, 346 | hparams.num_random_restarts, 347 | hparams.num_noise_variables 348 | ) 349 | else: 350 | dir_name = '{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}'.format( 351 | hparams.zprior_weight, 352 | hparams.T, 353 | hparams.L, 354 | hparams.sigma_init, 355 | hparams.sigma_final, 356 | hparams.zprior_init_sdev, 357 | hparams.zprior_sdev, 358 | hparams.optimizer_type, 359 | hparams.learning_rate, 360 | hparams.momentum, 361 | hparams.max_update_iter, 362 | hparams.num_random_restarts 363 | ) 364 | elif 
model_type in ['pulse']: 365 | dir_name = '{}_{}_{}_{}_{}_{}_{}_{}_{}_{}'.format( 366 | hparams.seed, 367 | hparams.loss_str, 368 | hparams.pulse_eps, 369 | hparams.noise_type, 370 | hparams.tile_latent, 371 | hparams.num_noise_variables, 372 | hparams.optimizer_type, 373 | hparams.learning_rate, 374 | hparams.max_update_iter, 375 | hparams.lr_schedule, 376 | ) 377 | else: 378 | raise NotImplementedError 379 | 380 | ckpt_dir = base_dir + dir_name + '/' 381 | 382 | return ckpt_dir 383 | 384 | 385 | def get_pkl_filepaths(hparams, model_type): 386 | """Return paths for the pickle files""" 387 | checkpoint_dir = get_checkpoint_dir(hparams, model_type) 388 | m_losses_filepath = checkpoint_dir + 'measurement_losses.pkl' 389 | l2_losses_filepath = checkpoint_dir + 'l2_losses.pkl' 390 | z_hats_filepath = checkpoint_dir + 'z.pkl' 391 | likelihoods_filepath = checkpoint_dir + 'likelihoods.pkl' 392 | return m_losses_filepath, l2_losses_filepath, z_hats_filepath, likelihoods_filepath 393 | 394 | 395 | def get_save_path(hparams, image_num): 396 | save_paths = {} 397 | checkpoint_dir = get_checkpoint_dir(hparams, hparams.model_type) 398 | image_dir = os.path.join(checkpoint_dir, 'images') 399 | set_up_dir(image_dir) 400 | save_path = os.path.join(image_dir , '{0:06d}.png'.format(image_num)) 401 | return save_path 402 | 403 | 404 | def get_matrix_save_path(hparams): 405 | save_path = './estimated/{0}/{1}/{2}/{3}/{4}/matrix_{5}.png'.format( 406 | hparams.dataset, 407 | hparams.input_type, 408 | hparams.measurement_type, 409 | hparams.noise_std, 410 | hparams.num_measurements, 411 | '_'.join(hparams.model_types) 412 | ) 413 | return save_path 414 | 415 | 416 | def set_up_dir(directory, clean=False): 417 | if os.path.exists(directory): 418 | if clean: 419 | shutil.rmtree(directory) 420 | else: 421 | os.makedirs(directory) 422 | 423 | 424 | def print_hparams(hparams): 425 | print('') 426 | for temp in dir(hparams): 427 | if temp[:1] != '_': 428 | print('{0} = {1}'.format(temp, getattr(hparams, temp))) 429 | print('') 430 | 431 | 432 | def get_learning_rate(global_step, hparams): 433 | if hparams.decay_lr: 434 | return tf.train.exponential_decay(hparams.learning_rate, 435 | global_step, 436 | 50, 437 | 0.7, 438 | staircase=True) 439 | else: 440 | return tf.constant(hparams.learning_rate) 441 | 442 | 443 | def get_optimizer(z, learning_rate, hparams): 444 | if hparams.optimizer_type == 'sgd': 445 | if hparams.net == 'realnvp': 446 | return torch.optim.SGD([z], learning_rate, momentum=hparams.momentum) 447 | elif hparams.net == 'stylegan2': 448 | return torch.optim.SGD(z, learning_rate, momentum=hparams.momentum) 449 | 450 | elif hparams.net == 'glow': 451 | if hparams.momentum == 0.: 452 | return tf.train.GradientDescentOptimizer(learning_rate) 453 | else: 454 | return tf.train.MomentumOptimizer(learning_rate, hparams.momentum) 455 | elif hparams.optimizer_type == 'rmsprop': 456 | return torch.optim.RMSprop([z], lr=learning_rate, momentum=hparams.momentum) 457 | elif hparams.optimizer_type == 'adam': 458 | if hparams.net == 'realnvp': 459 | return torch.optim.Adam([z], lr=learning_rate) 460 | elif hparams.net == 'stylegan2': 461 | return torch.optim.Adam(z, lr=learning_rate) 462 | elif hparams.net == 'glow': 463 | return tf.train.AdamOptimizer(learning_rate) 464 | elif hparams.optimizer_type == 'adagrad': 465 | return torch.optim.Adagrad([z], lr=learning_rate) 466 | elif hparams.optimizer_type == 'lbfgs': 467 | if hparams.net == 'realnvp': 468 | return torch.optim.LBFGS([z], lr=learning_rate) 469 | elif 
hparams.net == 'glow': 470 | raise Exception('Tensorflow does not support ' + hparams.optimizer_type) 471 | else: 472 | raise Exception('Optimizer ' + hparams.optimizer_type + ' not supported') 473 | 474 | 475 | def get_inpaint_mask(hparams):  # NB: this and the A-matrix helpers below assume HWC image shapes (legacy layout; the driver script uses CHW) 476 | image_size = hparams.image_shape[0] 477 | margin = (image_size - hparams.inpaint_size) // 2  # integer division so the slices below get int indices 478 | mask = np.ones(hparams.image_shape) 479 | mask[margin:margin+hparams.inpaint_size, margin:margin+hparams.inpaint_size] = 0 480 | return mask 481 | 482 | 483 | def get_A_inpaint(hparams): 484 | mask = get_inpaint_mask(hparams) 485 | mask = mask.reshape(1, -1) 486 | A = np.eye(np.prod(mask.shape)) * np.tile(mask, [np.prod(mask.shape), 1]) 487 | A = np.asarray([a for a in A if np.sum(a) != 0]) 488 | 489 | # Make sure that the norm of each row of A is hparams.n_input 490 | A = np.sqrt(hparams.n_input) * A 491 | assert all(np.abs(np.sum(A**2, 1) - hparams.n_input) < 1e-6) 492 | 493 | return A.T 494 | 495 | 496 | def get_A_superres(hparams): 497 | factor = hparams.superres_factor 498 | A = np.zeros((int(hparams.n_input/(factor**2)), hparams.n_input)) 499 | l = 0 500 | for i in range(hparams.image_shape[0] // factor): 501 | for j in range(hparams.image_shape[1] // factor): 502 | for k in range(hparams.image_shape[2]): 503 | a = np.zeros(hparams.image_shape) 504 | a[factor*i:factor*(i+1), factor*j:factor*(j+1), k] = 1 505 | A[l, :] = np.reshape(a, [1, -1]) 506 | l += 1 507 | 508 | # Make sure that the norm of each row of A is hparams.n_input 509 | A = np.sqrt(hparams.n_input/(factor**2)) * A 510 | assert all(np.abs(np.sum(A**2, 1) - hparams.n_input) < 1e-6) 511 | 512 | return A.T 513 | 514 | def get_A(hparams): 515 | if hparams.measurement_type == 'gaussian': 516 | A = np.random.randn(hparams.n_input, hparams.num_measurements)/np.sqrt(hparams.num_measurements) 517 | elif hparams.measurement_type == 'superres': 518 | A = None 519 | # A = get_A_superres(hparams) 520 | elif hparams.measurement_type == 'inpaint': 521 | A = get_A_inpaint(hparams) 522 | elif hparams.measurement_type == 'project': 523 | A = None 524 | elif hparams.measurement_type == 'circulant': 525 | temp = np.random.randn(1, hparams.n_input) 526 | A = temp / np.sqrt(hparams.num_measurements) 527 | else: 528 | raise NotImplementedError 529 | return A 530 | 531 | 532 | def set_num_measurements(hparams): 533 | if hparams.measurement_type == 'project': 534 | hparams.num_measurements = hparams.n_input 535 | else: 536 | hparams.num_measurements = get_A(hparams).shape[1] 537 | 538 | 539 | def get_checkpoint_path(ckpt_dir): 540 | ckpt_dir = os.path.abspath(ckpt_dir) 541 | ckpt = tf.train.get_checkpoint_state(ckpt_dir) 542 | if ckpt and ckpt.model_checkpoint_path: 543 | ckpt_path = os.path.join(ckpt_dir, 544 | ckpt.model_checkpoint_path) 545 | else: 546 | print('No checkpoint file found') 547 | ckpt_path = '' 548 | return ckpt_path 549 | 550 | 551 | 552 | 553 | def save_plot(is_save, save_path): 554 | if is_save: 555 | pdf = PdfPages(save_path) 556 | pdf.savefig(bbox_inches='tight') 557 | pdf.close() 558 | 559 | def get_opt_reinit_op(opt, var_list, global_step): 560 | opt_slots = [opt.get_slot(var, name) for name in opt.get_slot_names() for var in var_list] 561 | if isinstance(opt, tf.train.AdamOptimizer): 562 | opt_slots.extend([opt._beta1_power, opt._beta2_power]) #pylint: disable = W0212 563 | all_opt_variables = opt_slots + var_list + [global_step] 564 | opt_reinit_op = tf.variables_initializer(all_opt_variables) 565 | return opt_reinit_op 566 | 567 | 
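# (Aside: a quick self-check of the subsampled-circulant route used by get_measurements
# and the partial_circulant_* helpers below — ifft(fft(d*x) * fft(a)) is circular
# convolution, i.e. multiplication by the circulant matrix whose first column is a.
# The sizes here are hypothetical:)
import numpy as np

rng = np.random.RandomState(0)
n, m = 8, 4
a = rng.randn(n)                     # the single row of randomness from get_A('circulant')
d = (rng.rand(n) < 0.5) * 2 - 1.     # Bernoulli sign pattern
idx = rng.randint(0, n, m)           # train_indices
x = rng.randn(n)

y_fft = np.real(np.fft.ifft(np.fft.fft(x * d) * np.fft.fft(a)))[idx]
C = np.array([[a[(i - j) % n] for j in range(n)] for i in range(n)])
y_mat = (C @ (d * x))[idx]
assert np.allclose(y_fft, y_mat)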
568 | def partial_circulant_tf(inputs, filters, indices, sign_pattern): 569 | n = np.prod(inputs.get_shape().as_list()[1:]) 570 | bs = inputs.get_shape().as_list()[0] 571 | input_reshape = tf.reshape(inputs, (-1,n)) 572 | input_sign = tf.multiply(input_reshape, sign_pattern) 573 | 574 | zeros_input = tf.zeros_like(input_sign) 575 | zeros_filter = tf.zeros_like(filters) 576 | complex_input = tf.complex(input_sign, zeros_input) 577 | complex_filter = tf.complex(filters, zeros_filter) 578 | output_fft = tf.multiply(tf.fft(complex_input), tf.fft(complex_filter)) 579 | output_ifft = tf.ifft(output_fft) 580 | output = tf.real(output_ifft) 581 | return tf.gather(output, indices, axis=1) 582 | 583 | # later versions of pytorch make FFT a module rather 584 | # than a function. uncomment the following for different versions 585 | 586 | #def partial_circulant_torch(inputs, filters, indices, sign_pattern): 587 | # n = np.prod(inputs.shape[1:]) 588 | # bs = inputs.shape[0] 589 | # input_reshape = inputs.view(bs,n) 590 | # input_sign = input_reshape * sign_pattern 591 | # 592 | # input_fft, filter_fft = torch.fft.fft(input_sign, dim=1), torch.fft.fft(filters, dim=1) 593 | # 594 | # output_fft = input_fft * filter_fft 595 | # output_ifft = torch.fft.ifft(output_fft, dim=1) 596 | # output_real = torch.real(output_ifft) 597 | # return output_real[:, indices] 598 | 599 | def partial_circulant_torch(inputs, filters, indices, sign_pattern): 600 | n = np.prod(inputs.shape[1:]) 601 | # vectorize stuff 602 | bs = inputs.shape[0] 603 | input_reshape = inputs.view(bs,n) 604 | 605 | # multiply input with random bernoulli 606 | input_sign = input_reshape * sign_pattern 607 | 608 | def to_complex(tensor): 609 | zeros = torch.zeros_like(tensor) 610 | concat = torch.cat((tensor, zeros), axis=0) 611 | reshape = concat.view(2,-1,n) 612 | return reshape.permute(1,2,0) 613 | 614 | # convert to two-dimensional complex value 615 | complex_input = to_complex(input_sign) 616 | complex_filter = to_complex(filters) 617 | 618 | # do fft of measurement row and input 619 | input_fft = torch.fft(complex_input, 1) 620 | filter_fft = torch.fft(complex_filter, 1) 621 | output_fft = torch.zeros_like(input_fft) 622 | 623 | # perform complex multiplication between FFTs of input and 624 | # row of measurement matrix 625 | output_fft[:,:,0] = input_fft[:,:,0]*filter_fft[:,:,0] - input_fft[:,:,1] * filter_fft[:,:,1] 626 | output_fft[:,:,1] = input_fft[:,:,1] * filter_fft[:,:,0] + input_fft[:,:,0] * filter_fft[:,:,1] 627 | 628 | # take IFFT to get the convolved output 629 | output_ifft = torch.ifft(output_fft, 1) 630 | # get real value 631 | output_real = output_ifft[:,:,0] 632 | return output_real[:, indices] 633 | 634 | def blur(image, factor): 635 | meas = tf.nn.avg_pool(image,[1,1,factor,factor],strides=[1,1,factor,factor],padding='VALID', data_format='NCHW') 636 | return meas 637 | 638 | def get_measurements(x_batch, A, noise_batch, hparams): 639 | if hparams.measurement_type == 'project': 640 | y_batch = x_batch + noise_batch 641 | elif hparams.measurement_type == 'circulant': 642 | full_measurements = np.real(np.fft.ifft(np.fft.fft(x_batch*hparams.sign_pattern) * np.fft.fft(A))) 643 | indices = hparams.train_indices 644 | y_batch = full_measurements[:,indices] + noise_batch 645 | 646 | elif hparams.measurement_type == 'superres': 647 | x_reshape = x_batch.reshape((-1,) + hparams.image_shape) 648 | x_downsample = downscale_local_mean(x_reshape, (1,1, hparams.downsample,hparams.downsample)) 649 | y_batch = x_downsample.reshape(-1, hparams.num_measurements) + noise_batch 650 | else: 651 | y_batch = np.matmul(x_batch, A) + noise_batch 652 | return y_batch 653 | 
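# (Note on the convention in get_noise below: scaling i.i.d. Gaussian noise by
# noise_std / sqrt(num_measurements) makes the expected norm of each noise vector
# approximately noise_std, which is what the --noise-std help text promises.
# A quick check with hypothetical values:)
import numpy as np

m, noise_std = 10000, 0.5
noise = (noise_std / np.sqrt(m)) * np.random.randn(1000, m)
print(np.linalg.norm(noise, axis=1).mean())   # ~= 0.5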
654 | def tensorflow_session(): 655 | # Init session and params 656 | config = tf.ConfigProto() 657 | config.gpu_options.allow_growth = True 658 | # Pin GPU to local rank (one GPU per process) 659 | config.gpu_options.visible_device_list = str(0) 660 | sess = tf.Session(config=config) 661 | return sess 662 | 663 | def get_noise(hparams): 664 | if hparams.measurement_noise_type == 'gaussian': 665 | noise_batch = (hparams.noise_std/np.sqrt(hparams.num_measurements)) * np.random.randn(hparams.batch_size, hparams.num_measurements) 666 | else: 667 | raise NotImplementedError 668 | return noise_batch 669 | 670 | 671 | 672 | def get_full_input(hparams): 673 | """Create input tensors""" 674 | trans = transforms.Compose([transforms.Resize((hparams.image_size,hparams.image_size)),transforms.ToTensor()]) 675 | dataset = datasets.ImageFolder(hparams.input_path, transform=trans) 676 | if hparams.input_type == 'full-input': 677 | dataloader = torch.utils.data.DataLoader(dataset,batch_size=1,drop_last=False,shuffle=False) 678 | elif hparams.input_type == 'random-test': 679 | dataloader = torch.utils.data.DataLoader(dataset,batch_size=1,drop_last=False,shuffle=True) 680 | else: 681 | raise NotImplementedError 682 | 683 | dataiter = iter(dataloader) 684 | images = {i: next(dataiter)[0].view(-1).numpy() for i in range(hparams.num_input_images)} 685 | 686 | return images 687 | 688 | 689 | def model_input(hparams): 690 | """Create input tensors""" 691 | 692 | if hparams.input_type in ['full-input', 'random-test']: 693 | images = get_full_input(hparams) 694 | else: 695 | raise NotImplementedError 696 | return images 697 | 698 | 699 | def view_image(image, hparams): 700 | """Process and show the image""" 701 | if len(image) == hparams.n_input: 702 | image = image.reshape(hparams.image_shape) 703 | if image.shape == hparams.image_shape: 704 | image = np.array(image).transpose(1,2,0) 705 | min_image = image.min() 706 | max_image = image.max() 707 | plot_image((image - min_image)/(max_image - min_image)) 708 | 709 | 710 | def save_image(image, path): 711 | """Save an image as a png file""" 712 | x_png = np.uint8(np.clip(image*256,0,255)) 713 | x_png = x_png.transpose(1,2,0) 714 | if x_png.shape[-1] == 1: 715 | x_png = x_png[:,:,0] 716 | Image.fromarray(x_png).save(path) 717 | -------------------------------------------------------------------------------- /src/estimators.py: -------------------------------------------------------------------------------- 1 | """Estimators for compressed sensing""" 2 | # pylint: disable = C0301, C0103, C0111, R0914 3 | 4 | import copy 5 | import heapq 6 | import torch 7 | import numpy as np 8 | import utils 9 | import scipy.fftpack as fftpack 10 | import sys 11 | import os 12 | import utils 13 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 14 | 15 | from stylegan2.model import Generator 16 | from ncsnv2.models import get_sigmas, ema 17 | from ncsnv2.models.ncsnv2 import NCSNv2, NCSNv2Deepest 18 | from glow import model as glow_model 19 | 20 | try: 21 | import tensorflow as tf 22 | except ImportError: 23 | print('did not import tensorflow') 24 | 25 | import torch.nn.functional as F 26 | 27 | from PULSE import PULSE 28 | import yaml 29 | import argparse 30 | import time 31 | from include import fit, decoder 32 | 33 | def get_measurements_torch(x_hat_batch, A, measurement_type, hparams): 34 | batch_size = hparams.batch_size 35 | if measurement_type == 'project': 36 | y_hat_batch = x_hat_batch 37 | elif measurement_type == 'gaussian': 38 | y_hat_batch = torch.mm(x_hat_batch.view(batch_size,-1), A) 39 | elif measurement_type == 'circulant': 40 | sign_pattern = torch.Tensor(hparams.sign_pattern).to(hparams.device) 41 | y_hat_batch = utils.partial_circulant_torch(x_hat_batch, A, hparams.train_indices, sign_pattern) 42 | elif measurement_type == 'superres': 43 | x_hat_reshape_batch = x_hat_batch.view((batch_size,) + hparams.image_shape) 44 | y_hat_batch = F.avg_pool2d(x_hat_reshape_batch, hparams.downsample) 45 | return y_hat_batch.view(batch_size, -1) 46 | 
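# (Aside: the torch superres branch above must agree with the NumPy route in
# utils.get_measurements, which uses skimage's downscale_local_mean. A minimal
# self-contained check, with hypothetical shapes — average pooling with
# kernel = stride = downsample is exactly a non-overlapping block mean:)
import numpy as np
import torch
import torch.nn.functional as F
from skimage.transform import downscale_local_mean

x = np.random.rand(1, 3, 16, 16).astype(np.float32)
y_np = downscale_local_mean(x, (1, 1, 4, 4))            # NumPy route used to form y
y_pt = F.avg_pool2d(torch.from_numpy(x), 4).numpy()     # torch route used to form y_hat
assert np.allclose(y_np, y_pt, atol=1e-6)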
47 | def stylegan_langevin_estimator(hparams):
48 | 
49 |     model = Generator(hparams.image_size, 512, 8)
50 |     model.load_state_dict(torch.load(hparams.checkpoint_path)["g_ema"], strict=False)
51 |     model.eval()
52 |     model = model.to(hparams.device)
53 | 
54 |     for p in model.parameters():
55 |         p.requires_grad = False
56 | 
57 |     #mse = torch.nn.SmoothL1Loss(reduction='none')
58 |     mse = torch.nn.MSELoss(reduction='none')
59 |     l1 = torch.nn.L1Loss(reduction='none')
60 |     annealed = hparams.annealed
61 | 
62 | 
63 |     model = torch.nn.DataParallel(model)
64 | 
65 |     def estimator(A_val, y_val, hparams):
66 |         """Function that returns the estimated image"""
67 | 
68 |         if A_val is not None:
69 |             A = torch.Tensor(A_val).cuda()
70 |         else:
71 |             A = None
72 |         y = torch.Tensor(y_val).cuda()
73 | 
74 |         best_keeper = utils.BestKeeper(hparams.batch_size, hparams.n_input)
75 |         noises_single = model.module.make_noise()
76 |         count = 512
77 |         for noise in noises_single:
78 |             count += np.prod(noise.shape)
79 |         best_keeper_z = utils.BestKeeper(hparams.batch_size, count)
80 | 
81 |         # run T steps of langevin for L different noise levels
82 |         T = hparams.T
83 |         L = hparams.L
84 |         sigma1 = hparams.sigma_init
85 |         sigmaT = hparams.sigma_final
86 |         # geometric factor for tuning sigma and learning rate
87 |         factor = np.power(sigmaT / sigma1, 1/(L-1))
88 | 
89 |         # if you're running regular langevin, step size is fixed
90 |         # and noise std doesn't change
91 |         if annealed:
92 |             lr_lambda = lambda i: (sigma1 * np.power(factor, (i-1)//T))**2 / (sigmaT**2)
93 |             sigma_lambda = lambda i: sigma1 * np.power(factor, i//T)
94 |         else:
95 |             lr_lambda = lambda i: 1
96 |             sigma_lambda = lambda i: hparams.noise_std
97 | 
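        # The annealed schedule is geometric: with
        # factor = (sigma_final / sigma_init)**(1 / (L - 1)), the noise level
        # sigma(i) = sigma1 * factor**(i // T) is held for T steps per level
        # and decays from sigma_init to sigma_final over the L levels. The
        # learning rate is scaled by sigma(i)**2 / sigmaT**2, the usual
        # annealed-Langevin step-size rule (the NCSNv2 sampler below uses the
        # same scaling, step_size = step_lr * (sigma / sigmas[-1]) ** 2).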
98 |         for i in range(hparams.num_random_restarts):
99 | 
100 |             z = hparams.zprior_init_sdev * torch.randn(hparams.batch_size, 512, device=hparams.device)
101 |             z.requires_grad_()
102 |             noises_single = model.module.make_noise()
103 |             noises = []
104 |             noise_vars = []
105 | 
106 |             count = 512
107 |             # optimize over a certain number of noise variables
108 |             for idx, noise in enumerate(noises_single):
109 |                 noises.append(hparams.zprior_init_sdev * noise.repeat(hparams.batch_size, 1, 1, 1).normal_())
110 |                 if idx < hparams.num_noise_variables:
111 |                     noise_vars.append(noises[-1])
112 |                     noises[-1].requires_grad = True
113 |                 count += np.prod(noises[-1].shape)
114 |             print('total latent dimension:', count)
115 | 
116 |             opt = utils.get_optimizer([z] + noise_vars, hparams.learning_rate, hparams)
117 |             # check whether cosine annealing of lr helps
118 |             scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
119 | 
120 |             for j in range(hparams.max_update_iter):
121 |                 opt.zero_grad()
122 |                 # stylegan2 adds some noise to the latent; it is unclear
123 |                 # whether this matters here
124 | 
125 |                 # noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_ramp) ** 2
126 |                 # latent_n = latent_noise(latent_in, noise_strength.item())
127 | 
128 |                 # the flag input_is_latent determines whether z is passed through the
129 |                 # styling network
130 |                 # can be done explicitly as well
131 |                 x_hat_batch = 0.5 * model([z], input_is_latent=False, noise=noises)[0] + 0.5
132 |                 if hparams.gif and ((j % hparams.gif_iter) == 0):
133 |                     images = x_hat_batch.detach().cpu().numpy()
134 |                     for im_num, image in enumerate(images):
135 |                         save_dir = '{0}/{1}/'.format(hparams.gif_dir, im_num)
136 |                         utils.set_up_dir(save_dir)
137 |                         save_path = save_dir + '{0}.png'.format(j)
138 |                         image = image.reshape(hparams.image_shape)
139 |                         save_image(image, save_path)
140 |                 y_hat_batch = get_measurements_torch(x_hat_batch, A, hparams.measurement_type, hparams)
141 |                 y_hat_batch_nchw = y_hat_batch.view(hparams.y_shape)
142 |                 y_batch_nchw = y.view(hparams.y_shape)
143 |                 m_loss_batch = mse(y_hat_batch, y).sum(dim=1)
144 |                 p_loss_batch = torch.norm(z, dim=-1).pow(2)
145 |                 for noise in noise_vars:
146 |                     p_loss_batch += torch.norm(noise.view(hparams.batch_size, -1), dim=-1).pow(2)
147 |                 p_loss_batch = p_loss_batch.view(-1)
148 |                 if (hparams.mloss_weight is not None) and (not annealed):
149 |                     mloss_weight = hparams.mloss_weight
150 |                 else:
151 |                     sigma = sigma_lambda(j)
152 |                     mloss_weight = hparams.num_measurements / (2 * sigma**2)
153 | 
154 |                 if (hparams.zprior_weight is not None):
155 |                     zprior_weight = hparams.zprior_weight
156 |                 else:
157 |                     zprior_weight = 0.5 / (hparams.zprior_sdev**2)
158 |                 total_loss_batch = mloss_weight * m_loss_batch + zprior_weight * p_loss_batch
159 | 
160 |                 m_loss = m_loss_batch.sum()
161 |                 p_loss = p_loss_batch.sum()
162 |                 total_loss = total_loss_batch.sum()
163 | 
164 |                 total_loss.backward()
165 |                 opt.step()
166 | 
167 |                 gradient_noise_weight = np.sqrt(2*opt.param_groups[0]['lr']/(1 - hparams.momentum))
168 |                 z.data += gradient_noise_weight*torch.randn_like(z).to(hparams.device)
169 |                 for noise in noise_vars:
170 |                     noise.data += gradient_noise_weight * torch.randn_like(noise).to(hparams.device)
171 | 
172 |                 scheduler.step()
173 | 
174 |                 logging_format = 'rr {} iter {} lr {} total_loss {} m_loss {} p_loss {}'
175 |                 print(logging_format.format(i, j, opt.param_groups[0]['lr'], total_loss.item(), m_loss.item(), p_loss.item()))
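            # The post-step noise injection in the loop above is what makes
            # this Langevin sampling rather than plain gradient descent:
            #
            #     v <- v - lr * grad(loss) + sqrt(2 * lr / (1 - momentum)) * eps,  eps ~ N(0, I)
            #
            # With momentum = 0 this is the standard sqrt(2 * lr) Langevin
            # noise scale; the 1 / (1 - momentum) factor roughly compensates
            # for the step-size amplification of SGD with momentum.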
176 | 
177 |             x_hat_batch = 0.5 * model([z], input_is_latent=False, noise=noises)[0] + 0.5
178 |             y_hat_batch = get_measurements_torch(x_hat_batch, A, hparams.measurement_type, hparams)
179 |             y_hat_batch_nchw = y_hat_batch.view(hparams.y_shape)
180 |             y_batch_nchw = y.view(hparams.y_shape)
181 |             m_loss_batch = mse(y_hat_batch, y).sum(dim=1)
182 |             p_loss_batch = torch.norm(z, dim=-1).pow(2)
183 |             for noise in noise_vars:
184 |                 p_loss_batch += torch.norm(noise.view(hparams.batch_size, -1), dim=-1).pow(2)
185 |             p_loss_batch = p_loss_batch.view(-1)
186 |             total_loss_batch = mloss_weight * m_loss_batch \
187 |                     + zprior_weight * p_loss_batch
188 |             best_keeper.report(x_hat_batch.view(hparams.batch_size,-1).detach().cpu().numpy(), total_loss_batch.detach().cpu().numpy())
189 |             z_hat_batch = z.view(hparams.batch_size,-1).detach().cpu().numpy()
190 |             for noise in noises:
191 |                 z_hat_batch = np.c_[z_hat_batch, noise.view(hparams.batch_size,-1).detach().cpu().numpy()]
192 | 
193 |             best_keeper_z.report(z_hat_batch, total_loss_batch.detach().cpu().numpy())
194 |             if m_loss_batch.mean() <= hparams.error_threshold:
195 |                 break
196 |         return best_keeper.get_best(), best_keeper_z.get_best(), best_keeper.losses_val_best
197 | 
198 |     return estimator
199 | 
200 | def stylegan_pulse_estimator(hparams):
201 | 
202 |     model = PULSE(image_size=hparams.image_size, checkpoint_path=hparams.checkpoint_path, dataset=hparams.dataset)
203 |     def estimator(A_val, y_val, hparams):
204 |         """Function that returns the estimated image"""
205 |         kwargs = vars(hparams)
206 | 
207 |         y_val_tensor = torch.Tensor(y_val.reshape(hparams.y_shape)).cuda()
208 | 
209 |         for (HR, LR) in model(y_val_tensor, **kwargs):
210 | 
211 |             return HR.detach().cpu().numpy().reshape(hparams.batch_size,-1), np.zeros(hparams.batch_size), np.zeros(hparams.batch_size)
212 | 
213 |     return estimator
214 | 
215 | 
216 | def glow_annealed_map_estimator(hparams):
217 | 
218 |     annealed = hparams.annealed
219 |     # set up model and session
220 |     dec_x, dec_eps, hparams.feed_dict, run = glow_model.get_model(hparams.checkpoint_path, hparams.batch_size, hparams.zprior_sdev)
221 | 
222 |     x_hat_batch_nhwc = dec_x + 0.5
223 | 
224 | 
225 |     # Set up placeholders
226 |     A = tf.placeholder(tf.float32, shape=(1,hparams.n_input), name='A')
227 |     y_batch = tf.placeholder(tf.float32, shape=(hparams.batch_size, hparams.num_measurements), name='y_batch')
228 | 
229 |     # convert from NHWC to NCHW
230 |     # since I used pytorch for reading data, the measurements
231 |     # are from a ground truth of data format NCHW
232 |     # Meanwhile GLOW's output has format NHWC
233 |     x_hat_batch_nchw = tf.transpose(x_hat_batch_nhwc, perm=[0,3,1,2])
234 | 
235 |     # measure the estimate
236 |     if hparams.measurement_type == 'project':
237 |         y_hat_batch = tf.identity(x_hat_batch_nchw, name='y2_batch')
238 |     elif hparams.measurement_type == 'circulant':
239 |         sign_pattern_tf = tf.constant(hparams.sign_pattern, name='sign_pattern')
240 |         y_hat_batch = utils.partial_circulant_tf(x_hat_batch_nchw, A, hparams.train_indices, sign_pattern_tf)
241 |     elif hparams.measurement_type == 'superres':
242 |         y_hat_batch = tf.reshape(utils.blur(x_hat_batch_nchw, hparams.downsample), (hparams.batch_size, -1))
243 | 
244 |     # define all losses
245 |     z_list = [tf.reshape(dec_eps[i], (hparams.batch_size,-1)) for i in range(6)]
246 |     z_stack = tf.concat(z_list, axis=1)
247 |     z_loss_batch = tf.reduce_sum(z_stack**2, 1)
248 |     z_loss = tf.reduce_sum(z_loss_batch)
249 |     y_loss_batch = tf.reduce_sum((y_batch - y_hat_batch)**2, 1)
250 |     y_loss = tf.reduce_sum(y_loss_batch)
251 | 
252 |     # mloss_weight should be m/2sigma^2 for proper langevin
253 |     # zprior_weight should be 1/(2*0.49) for proper langevin
254 |     sigma = tf.placeholder(tf.float32, shape=[])
255 |     if (hparams.mloss_weight is not None) and (not annealed):
256 |         mloss_weight = hparams.mloss_weight
257 |     else:
258 |         mloss_weight = 0.5 * hparams.num_measurements / sigma**2
259 |     if hparams.zprior_weight is None:
260 |         if hparams.zprior_sdev != 0:
261 |             zprior_weight = 1/(2 * hparams.zprior_sdev**2)
262 |         else:
263 |             zprior_weight = 0
264 |     else:
265 |         zprior_weight = hparams.zprior_weight
266 |     total_loss_batch = mloss_weight * y_loss_batch + zprior_weight * z_loss_batch
267 |     total_loss = tf.reduce_sum(total_loss_batch)
268 | 
269 |     # Set up gradient descent
270 |     global_step = tf.Variable(0, trainable=False, name='global_step')
271 |     learning_rate = tf.placeholder(tf.float32, shape=[])
272 |     with tf.variable_scope(tf.get_variable_scope(), reuse=False):
273 |         opt = utils.get_optimizer(dec_eps, learning_rate, hparams)
274 |         update_op = opt.minimize(total_loss, var_list=dec_eps, global_step=global_step, name='update_op')
275 | 
276 |     sess = utils.tensorflow_session()
277 | 
278 |     # initialize variables
279 |     uninitialized_vars = set(sess.run(tf.report_uninitialized_variables()))
280 |     init_op = tf.variables_initializer(
281 |         [v for v in tf.global_variables() if v.op.name.encode('UTF-8') in uninitialized_vars])
282 |     # sess.run(init_op)
283 | 
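    # Interpretation (up to additive constants): utils.get_noise draws
    # measurement noise with per-entry std noise_std / sqrt(m), so
    #
    #     -log p(y|x) = (m / (2 sigma^2)) * ||y - A(x)||^2 + const
    #     -log p(z)   = (1 / (2 sdev^2)) * ||z||^2 + const
    #
    # which matches mloss_weight = 0.5 * m / sigma**2 and
    # zprior_weight = 1 / (2 * zprior_sdev**2). Minimizing the weighted sum is
    # MAP estimation; the Langevin estimator below samples this posterior
    # instead of maximizing it.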
284 |     def estimator(A_val, y_val, hparams):
285 |         """Function that returns the estimated image"""
286 |         best_keeper = utils.BestKeeper(hparams.batch_size, hparams.n_input)
287 |         best_keeper_z = utils.BestKeeper(hparams.batch_size, hparams.n_input)
288 | 
289 |         feed_dict = hparams.feed_dict.copy()
290 | 
291 |         if hparams.measurement_type == 'circulant':
292 |             feed_dict.update({A: A_val, y_batch: y_val})
293 |         else:
294 |             feed_dict.update({y_batch: y_val})
295 | 
296 |         for i in range(hparams.num_random_restarts):
297 |             sess.run(init_op)
298 |             # sess.run(opt_reinit_op)
299 |             for j in range(hparams.max_update_iter):
300 |                 if hparams.gif and ((j % hparams.gif_iter) == 0):
301 |                     images = sess.run(x_hat_batch_nchw, feed_dict=feed_dict)
302 |                     for im_num, image in enumerate(images):
303 |                         save_dir = '{0}/{1}/'.format(hparams.gif_dir, im_num)
304 |                         utils.set_up_dir(save_dir)
305 |                         save_path = save_dir + '{0}.png'.format(j)
306 |                         image = image.reshape(hparams.image_shape)
307 |                         save_image(image, save_path)
308 | 
309 |                 factor = np.power(hparams.sigma_final / hparams.sigma_init, 1/(hparams.L-1))
310 | 
311 |                 if annealed:
312 |                     lr_lambda = lambda t: (hparams.sigma_init * np.power(factor, t//hparams.T))**2 / (hparams.sigma_final**2)
313 |                     sigma_value = hparams.sigma_init * np.power(factor, j//hparams.T)
314 |                     lr_value = hparams.learning_rate * lr_lambda(j)
315 |                 else:
316 |                     sigma_value = hparams.sigma_final
317 |                     lr_value = hparams.learning_rate
318 | 
319 |                 feed_dict.update({learning_rate: lr_value, sigma: sigma_value})
320 |                 _, total_loss_value, z_loss_value, y_loss_value = run(sess, [update_op, total_loss, z_loss, y_loss], feed_dict)
321 |                 logging_format = 'rr {} iter {} lr {} total_loss {} y_loss {} z_loss {}'
322 |                 print(logging_format.format(i, j, lr_value, total_loss_value, y_loss_value, z_loss_value))
323 | 
324 | 
325 |             x_hat_batch_value, z_hat_batch_value, total_loss_batch_value = run(sess, [x_hat_batch_nchw, z_stack, total_loss_batch], feed_dict=feed_dict)
326 | 
327 |             x_hat_batch_value = x_hat_batch_value.reshape(hparams.batch_size, -1)
328 |             best_keeper.report(x_hat_batch_value, total_loss_batch_value)
329 |             best_keeper_z.report(z_hat_batch_value, total_loss_batch_value)
330 |         return best_keeper.get_best(), best_keeper_z.get_best(), best_keeper.losses_val_best
331 | 
332 | 
333 |     return estimator
334 | 
335 | 
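# glow_annealed_langevin_estimator below builds essentially the same graph as
# the MAP estimator above, with two differences: extra placeholders feed fresh
# Gaussian noise into the latents through assign_add ops after every gradient
# step, and zprior_weight defaults to 0.5 (the "proper langevin" weight for
# unit-variance latents) rather than 1 / (2 * zprior_sdev**2).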
336 | def glow_annealed_langevin_estimator(hparams):
337 | 
338 |     annealed = hparams.annealed
339 |     # set up model and session
340 |     dec_x, dec_eps, hparams.feed_dict, run = glow_model.get_model(hparams.checkpoint_path, hparams.batch_size, hparams.zprior_sdev)
341 | 
342 |     x_hat_batch_nhwc = dec_x + 0.5
343 | 
344 | 
345 |     # Set up placeholders
346 |     A = tf.placeholder(tf.float32, shape=(1,hparams.n_input), name='A')
347 |     y_batch = tf.placeholder(tf.float32, shape=(hparams.batch_size, hparams.num_measurements), name='y_batch')
348 | 
349 | 
350 |     # convert from NHWC to NCHW
351 |     # since I used pytorch for reading data, the measurements
352 |     # are from a ground truth of data format NCHW
353 |     # Meanwhile GLOW's output has format NHWC
354 |     x_hat_batch_nchw = tf.transpose(x_hat_batch_nhwc, perm=[0,3,1,2])
355 | 
356 |     # measure the estimate
357 |     if hparams.measurement_type == 'project':
358 |         y_hat_batch = tf.identity(x_hat_batch_nchw, name='y2_batch')
359 |     elif hparams.measurement_type == 'circulant':
360 |         sign_pattern_tf = tf.constant(hparams.sign_pattern, name='sign_pattern')
361 |         y_hat_batch = utils.partial_circulant_tf(x_hat_batch_nchw, A, hparams.train_indices, sign_pattern_tf)
362 |     elif hparams.measurement_type == 'superres':
363 |         y_hat_batch = tf.reshape(utils.blur(x_hat_batch_nchw, hparams.downsample), (hparams.batch_size, -1))
364 | 
365 |     # create noise placeholders for langevin
366 |     noise_vars = [tf.placeholder(tf.float32, shape=dec_eps[i].get_shape()) for i in range(len(dec_eps))]
367 | 
368 |     # define all losses
369 |     z_list = [tf.reshape(dec_eps[i], (hparams.batch_size,-1)) for i in range(6)]
370 |     z_stack = tf.concat(z_list, axis=1)
371 |     z_loss_batch = tf.reduce_sum(z_stack**2, 1)
372 |     z_loss = tf.reduce_sum(z_loss_batch)
373 |     y_loss_batch = tf.reduce_sum((y_batch - y_hat_batch)**2, 1)
374 |     y_loss = tf.reduce_sum(y_loss_batch)
375 | 
376 |     # mloss_weight should be m/2sigma^2 for proper langevin
377 |     # zprior_weight should be 0.5 for proper langevin
378 |     sigma = tf.placeholder(tf.float32, shape=[])
379 |     mloss_weight = 0.5 * hparams.num_measurements / (sigma ** 2)
380 |     if hparams.zprior_weight is None:
381 |         zprior_weight = 0.5
382 |     else:
383 |         zprior_weight = hparams.zprior_weight
384 |     total_loss_batch = mloss_weight * y_loss_batch + zprior_weight * z_loss_batch
385 |     total_loss = tf.reduce_sum(total_loss_batch)
386 | 
387 |     # Set up gradient descent
388 |     global_step = tf.Variable(0, trainable=False, name='global_step')
389 |     learning_rate = tf.placeholder(tf.float32, shape=[])
390 |     with tf.variable_scope(tf.get_variable_scope(), reuse=False):
391 |         opt = utils.get_optimizer(dec_eps, learning_rate, hparams)
392 |         update_op = opt.minimize(total_loss, var_list=dec_eps, global_step=global_step, name='update_op')
393 |     noise_ops = [dec_eps[i].assign_add(noise_vars[i]) for i in range(len(dec_eps))]
394 | 
395 |     sess = utils.tensorflow_session()
396 | 
397 |     # initialize variables
398 |     uninitialized_vars = set(sess.run(tf.report_uninitialized_variables()))
399 |     init_op = tf.variables_initializer(
400 |         [v for v in tf.global_variables() if v.op.name.encode('UTF-8') in uninitialized_vars])
401 | 
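    # Design note: learning_rate, sigma, and the per-latent noise all enter
    # through placeholders, so the annealing schedule and the Langevin noise
    # can change at every iteration without rebuilding the TF1 graph; the
    # assign_add noise_ops are run as a separate fetch right after update_op.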
402 |     def estimator(A_val, y_val, hparams):
403 |         """Function that returns the estimated image"""
404 |         best_keeper = utils.BestKeeper(hparams.batch_size, hparams.n_input)
405 |         best_keeper_z = utils.BestKeeper(hparams.batch_size, hparams.n_input)
406 |         feed_dict = hparams.feed_dict.copy()
407 | 
408 |         if hparams.measurement_type == 'circulant':
409 |             feed_dict.update({A: A_val, y_batch: y_val})
410 |         else:
411 |             feed_dict.update({y_batch: y_val})
412 | 
413 |         for i in range(hparams.num_random_restarts):
414 |             sess.run(init_op)
415 |             for j in range(hparams.max_update_iter):
416 |                 if hparams.gif and ((j % hparams.gif_iter) == 0):
417 |                     images = sess.run(x_hat_batch_nchw, feed_dict=feed_dict)
418 |                     for im_num, image in enumerate(images):
419 |                         save_dir = '{0}/{1}/'.format(hparams.gif_dir, im_num)
420 |                         utils.set_up_dir(save_dir)
421 |                         save_path = save_dir + '{0}.png'.format(j)
422 |                         image = image.reshape(hparams.image_shape)
423 |                         save_image(image, save_path)
424 | 
425 | 
426 |                 factor = np.power(hparams.sigma_final / hparams.sigma_init, 1/(hparams.L-1))
427 | 
428 |                 if annealed:
429 |                     lr_lambda = lambda t: (hparams.sigma_init * np.power(factor, t//hparams.T))**2 / (hparams.sigma_final**2)
430 |                     sigma_value = hparams.sigma_init * np.power(factor, j//hparams.T)
431 |                 else:
432 |                     lr_lambda = lambda t: 1
433 |                     sigma_value = hparams.sigma_final
434 |                 lr_value = hparams.learning_rate * lr_lambda(j)
435 | 
436 |                 feed_dict.update({learning_rate: lr_value, sigma: sigma_value})
437 |                 _, total_loss_value, z_loss_value, y_loss_value = run(sess, [update_op, total_loss, z_loss, y_loss], feed_dict)
438 |                 logging_format = 'rr {} iter {} lr {} total_loss {} y_loss {} z_loss {}'
439 |                 print(logging_format.format(i, j, lr_value, total_loss_value, y_loss_value, z_loss_value))
440 | 
441 |                 # gradient_noise_weight should be sqrt(2*lr) for proper langevin
442 |                 gradient_noise_weight = np.sqrt(2*lr_value/(1 - hparams.momentum))
443 |                 for noise_var in noise_vars:
444 |                     noise_shape = noise_var.get_shape().as_list()
445 |                     feed_dict.update({noise_var: gradient_noise_weight * np.random.randn(hparams.batch_size, noise_shape[1], noise_shape[2], noise_shape[3])})
446 |                 run(sess, noise_ops, feed_dict)
447 | 
448 | 
449 |             x_hat_batch_value, z_hat_batch_value, total_loss_batch_value = run(sess, [x_hat_batch_nchw, z_stack, total_loss_batch], feed_dict=feed_dict)
450 | 
451 |             x_hat_batch_value = x_hat_batch_value.reshape(hparams.batch_size, -1)
452 |             best_keeper.report(x_hat_batch_value, total_loss_batch_value)
453 |             best_keeper_z.report(z_hat_batch_value, total_loss_batch_value)
454 |         return best_keeper.get_best(), best_keeper_z.get_best(), best_keeper.losses_val_best
455 | 
456 |     return estimator
457 | 
458 | 
459 | def ncsnv2_langevin_estimator(hparams, MAP=False):
460 |     def dict2namespace(config):
461 |         namespace = argparse.Namespace()
462 |         for key, value in config.items():
463 |             if isinstance(value, dict):
464 |                 new_value = dict2namespace(value)
465 |             else:
466 |                 new_value = value
467 |             setattr(namespace, key, new_value)
468 |         return namespace
469 | 
470 |     def gradient_log_conditional_likelihood(A, y_batch, y_hat_batch):
471 |         err = y_batch - y_hat_batch
472 |         if hparams.measurement_type == 'superres':
473 |             err = err.view(hparams.y_shape)
474 |             ans = F.interpolate(err, scale_factor=hparams.downsample)
475 |         elif hparams.measurement_type == 'circulant':
476 |             err_padded = torch.zeros(hparams.batch_size, hparams.n_input).to(hparams.device)
477 |             err_padded[:,hparams.train_indices] = err
478 |             A_shift = torch.zeros_like(A)
479 |             A_shift[0,0] = A[0,0]
480 |             A_shift[0,1:] = A[0,1:].flip(dims=[0])
481 | 
482 |             err_A = utils.partial_circulant_torch(err_padded, A_shift, range(hparams.n_input), sign_pattern=ones_torch)
483 | 
484 |             ans = err_A * sign_pattern_torch
485 |             ans = ans.view((-1,) + hparams.image_shape)
486 |         else:
487 |             raise NotImplementedError
488 | 
489 |         return ans
490 | 
491 | 
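    # gradient_log_conditional_likelihood returns (up to scaling) the adjoint
    # applied to the residual, A^T (y - A x), without materializing a dense A:
    # for 'superres' the adjoint of d x d average pooling is, up to the 1/d^2
    # factor, nearest-neighbor upsampling of the residual (F.interpolate); for
    # 'circulant', the transpose of a circulant matrix is circulant with the
    # flipped filter (A_shift above), followed by re-applying the sign pattern.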
492 |     batch_size = hparams.batch_size
493 |     if hparams.measurement_type == 'circulant':
494 |         if hparams.sign_pattern is not None:
495 |             sign_pattern_torch = torch.Tensor(hparams.sign_pattern).to(hparams.device)
496 |             ones_torch = torch.ones(1, hparams.n_input).to(hparams.device)
497 |         else:
498 |             pass
499 |     new_config = dict2namespace(hparams.ncsnv2_configs)
500 |     new_config.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
501 |     new_config.sampling.batch_size = batch_size
502 | 
503 |     if 'ffhq' in hparams.dataset:
504 |         score = NCSNv2Deepest(new_config).to(new_config.device)
505 |     elif hparams.dataset == 'celebA':
506 |         score = NCSNv2(new_config).to(new_config.device)
507 | 
508 | 
509 |     sigmas_torch = get_sigmas(new_config)
510 |     sigmas = sigmas_torch.cpu().numpy()
511 | 
512 |     states = torch.load(hparams.checkpoint_path,
513 |                         map_location=new_config.device)
514 | 
515 |     score = torch.nn.DataParallel(score)
516 | 
517 |     score.load_state_dict(states[0], strict=True)
518 | 
519 |     for p in score.parameters():
520 |         p.requires_grad = False
521 | 
522 |     if new_config.model.ema:
523 |         ema_helper = ema.EMAHelper(mu=new_config.model.ema_rate)
524 |         ema_helper.register(score)
525 |         ema_helper.load_state_dict(states[-1])
526 |         ema_helper.ema(score)
527 | 
528 | 
529 |     score.eval()
530 | 
531 |     mse = torch.nn.MSELoss(reduction='none')
532 | 
533 |     def estimator(A_val, y_val, hparams):
534 |         x_hat_nchw_batch = torch.rand((hparams.batch_size,) + hparams.image_shape,
535 |                                       device=new_config.device)
536 |         zeros = torch.zeros_like(x_hat_nchw_batch)
537 |         n_steps_each = new_config.sampling.n_steps_each
538 |         step_lr = new_config.sampling.step_lr
539 | 
540 |         y_batch = torch.Tensor(y_val).to(new_config.device)
541 |         if A_val is not None:
542 |             A = torch.Tensor(A_val).to(new_config.device)
543 |         else:
544 |             A = None
545 |         with torch.no_grad():
546 |             start = time.time()
547 |             for c, sigma in enumerate(sigmas[:int(hparams.L)]):
548 |                 labels = torch.ones(x_hat_nchw_batch.shape[0], device=x_hat_nchw_batch.device) * c
549 |                 labels = labels.long()
550 |                 step_size = step_lr * (sigma / sigmas[-1]) ** 2
551 | 
552 |                 for s in range(n_steps_each):
553 |                     j = c*n_steps_each + s
554 |                     if hparams.gif and ((j % hparams.gif_iter) == 0):
555 |                         images = x_hat_nchw_batch.detach().cpu().numpy()
556 |                         for im_num, image in enumerate(images):
557 |                             save_dir = '{0}/{1}/'.format(hparams.gif_dir, im_num)
558 |                             utils.set_up_dir(save_dir)
559 |                             save_path = save_dir + '{0}.png'.format(j)
560 |                             image = image.reshape(hparams.image_shape)
561 |                             save_image(image, save_path)
562 |                     noise = torch.randn_like(x_hat_nchw_batch) * np.sqrt(step_size * 2)
563 |                     grad = score(x_hat_nchw_batch, labels)
564 |                     #error = (up(torch_downsample(x_hat_nchw_batch,factor) - y) / sigma**2)
565 | 
566 |                     y_hat_batch = get_measurements_torch(x_hat_nchw_batch.view(hparams.batch_size, -1), A, hparams.measurement_type, hparams)
567 |                     m_loss_grad_nchw = gradient_log_conditional_likelihood(A, y_batch, y_hat_batch)/(sigma**2 + hparams.noise_std**2/hparams.num_measurements)
568 |                     if hparams.mloss_weight is None:
569 |                         mloss_weight = 1.0
570 |                     else:
571 |                         mloss_weight = hparams.mloss_weight
572 |                     if MAP:
573 |                         x_hat_nchw_batch = x_hat_nchw_batch + step_size * (grad + mloss_weight * m_loss_grad_nchw)
574 |                     else:
575 |                         x_hat_nchw_batch = x_hat_nchw_batch + step_size * (grad + mloss_weight * m_loss_grad_nchw) + noise
576 | 
577 |                 m_loss_batch = mse(y_hat_batch, y_batch).sum(dim=1)
578 | 
579 |                 print("class: {}, step_size: {}, mean {}, max {}, y_mse {}".format(c, step_size, grad.abs().mean(),
580 |                                                                                    grad.abs().max(), m_loss_batch.mean()))
581 |             end = time.time()
582 |             print(f'Time on batch:{(end - start)/60:.3f} minutes')
583 |         return x_hat_nchw_batch.view(hparams.batch_size,-1).cpu().numpy(), np.zeros(hparams.batch_size), m_loss_batch.cpu().numpy()
584 | 
585 |     return estimator
586 | 
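# The NCSNv2 loop above is annealed Langevin dynamics on the posterior: at
# noise level sigma,
#
#     x <- x + alpha * (s(x, sigma) + mloss_weight * grad_x log p(y|x)) + sqrt(2 * alpha) * eps
#
# where s is the learned prior score and the likelihood gradient is divided by
# sigma**2 + noise_std**2 / m, the variance of the annealing noise plus the
# measurement noise. With MAP=True the eps term is dropped, turning the
# sampler into (approximate) MAP ascent under the same schedule.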
587 | def deep_decoder_estimator(hparams):
588 |     num_channels = [700]*7
589 |     output_depth = 3  # number of output channels (RGB)
590 | 
591 |     def estimator(A_val, y_val, hparams):
592 | 
593 |         y_batch = torch.Tensor(y_val).to(hparams.device)
594 |         if A_val is not None:
595 |             A = torch.Tensor(A_val).to(hparams.device)
596 |         else:
597 |             A = None
598 | 
599 |         def apply_f(x):
600 |             return get_measurements_torch(x.view(hparams.batch_size,-1), A, hparams.measurement_type, hparams)
601 | 
602 |         net = decoder.decodernw(output_depth, num_channels_up=num_channels, upsample_first=True).cuda()
603 | 
604 |         rn = 0.005  # regularization noise std passed to fit
605 |         rnd = 500   # decay the regularization noise every rnd iterations
606 |         numit = 4000  # unused; num_iter below comes from hparams.max_update_iter
607 | 
608 |         print('deep decoder iterations:', hparams.max_update_iter)
609 |         mse_n, mse_t, ni, net = fit(
610 |             num_channels=num_channels,
611 |             reg_noise_std=rn,
612 |             reg_noise_decayevery=rnd,
613 |             num_iter=hparams.max_update_iter,
614 |             LR=hparams.learning_rate,
615 |             OPTIMIZER=hparams.optimizer_type,
616 |             img_noisy_var=y_batch,
617 |             net=net,
618 |             img_clean_var=torch.zeros_like(y_batch),
619 |             find_best=True,
620 |             apply_f=apply_f,
621 |         )
622 |         return net(ni.cuda()).view(hparams.batch_size,-1).detach().cpu().numpy(), np.zeros(hparams.batch_size), np.zeros(hparams.batch_size)
623 | 
624 |     return estimator
625 | 
626 | 
--------------------------------------------------------------------------------