├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── README.txt
│   ├── cat_video.mat
│   ├── icvl.mat
│   ├── pet.mat
│   └── weizzman.mat
├── matrix_decomp.py
├── modules
│   ├── cosine_annealing_with_warmup.py
│   ├── deep_decoder.py
│   ├── deep_models
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── dcgan.py
│   │   ├── downsampler.py
│   │   ├── resnet.py
│   │   ├── skip.py
│   │   ├── texture_nets.py
│   │   └── unet.py
│   ├── deep_prior.py
│   ├── lin_inverse.py
│   ├── losses.py
│   ├── models.py
│   ├── spectral.py
│   └── utils.py
├── requirements.txt
├── run_figure11.py
├── run_figure11_TV.py
├── run_figure12.py
├── run_figure12_TV.py
├── run_figure2.py
├── run_figure6.py
├── run_figure7.py
├── run_figure9.py
├── run_figure9_bm3d.py
└── run_table1.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Vishwanath S R V 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **DeepTensor: Low-Rank Tensor Decomposition with Deep Network Priors** 2 | 3 | **Paper**: https://arxiv.org/abs/2204.03145 4 | 5 | **Contents** 6 | - data/ -- contains mat files required to run some python scripts. 
The folder contains a README listing the sources. 7 | - cat_video.mat -- required for run_figure11.py and run_figure11_TV.py 8 | - icvl.mat -- required for run_figure9.py and run_figure9_bm3d.py 9 | - pet.mat -- required for run_figure12.py and run_figure12_TV.py 10 | - weizzman.mat -- required for run_figure2.py 11 | - modules/ -- contains several python scripts required for all experiments 12 | - deep_models/ -- folder downloaded from https://github.com/DmitryUlyanov/deep-image-prior 13 | - cosine_annealing_with_warmup.py -- code for cosine annealing scheduler downloaded from https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup 14 | - deep_decoder.py -- convenience code downloaded from https://github.com/reinhardh/supplement_deep_decoder 15 | - deep_prior.py -- convenience code downloaded from https://github.com/DmitryUlyanov/deep-image-prior 16 | - lin_inverse.py -- implements functions required for linear inverse problems 17 | - losses.py -- implements the losses required for training, including L1, L2, and 2D TV 18 | - models.py -- implements a multilayer perceptron and multi-dimensional (1D, 2D, and 3D) U-Net 19 | - spectral.py -- required for run_figure6.py 20 | - utils.py -- miscellaneous utilities 21 | - requirements.txt -- contains a list of requirements to run the code base. Install the required packages using `pip install -r requirements.txt` 22 | - run_figure2.py -- run this to replicate results in figure 2 23 | - run_figure6.py -- run this to replicate results in figure 6 24 | - run_figure7.py -- run this to replicate results in figure 7 25 | - run_figure9.py -- run this to replicate results with DeepTensor and SVD in figure 9 26 | - run_figure9_bm3d.py -- run this to replicate results with BM3D in figure 9 27 | - run_figure11.py -- run this to replicate results with DeepTensor in figure 11 28 | - run_figure11_TV.py -- run this to replicate results with TV in figure 11 29 | - run_figure12.py -- run this to replicate results with DeepTensor in figure 12 30 | - run_figure12_TV.py -- run this to replicate results with TV in figure 12 31 | - run_table1.py -- run this to replicate results in table 1 32 | -------------------------------------------------------------------------------- /data/README.txt: -------------------------------------------------------------------------------- 1 | Source for each dataset: 2 | 1. cat_video.mat -- Downloaded from https://drive.google.com/drive/folders/1_iq__37-hw7FJOEUK1tX7mdp8SKB368K. Video downsampled and clipped to fit supplementary requirements 3 | 2. icvl.mat -- Downloaded from http://icvl.cs.bgu.ac.il/hyperspectral/. Data downsampled from original resolution to fit supplementary requirements 4 | 3. pet.mat -- Downloaded from https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=70224216. Data downsampled to fit supplementary requirements 5 | 4. weizzman.mat -- Downloaded from http://www.wisdom.weizmann.ac.il/~vision/FaceBase/.
Data downsampled to fit supplementary requirements -------------------------------------------------------------------------------- /data/cat_video.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vishwa91/DeepTensor/32fd35b606a3a0d2fb6a377291143a687aaa2bbd/data/cat_video.mat -------------------------------------------------------------------------------- /data/icvl.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vishwa91/DeepTensor/32fd35b606a3a0d2fb6a377291143a687aaa2bbd/data/icvl.mat -------------------------------------------------------------------------------- /data/pet.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vishwa91/DeepTensor/32fd35b606a3a0d2fb6a377291143a687aaa2bbd/data/pet.mat -------------------------------------------------------------------------------- /data/weizzman.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vishwa91/DeepTensor/32fd35b606a3a0d2fb6a377291143a687aaa2bbd/data/weizzman.mat -------------------------------------------------------------------------------- /matrix_decomp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | ''' 4 | This script replicates instances of results in figure 4 in the main paper. 5 | To obtain the plots, please sweep the relevant parameters. 6 | ''' 7 | 8 | import os 9 | import sys 10 | import glob 11 | import tqdm 12 | import importlib 13 | import argparse 14 | 15 | import numpy as np 16 | from scipy import io 17 | from skimage.metrics import structural_similarity as ssim_func 18 | 19 | import torch 20 | 21 | import matplotlib.pyplot as plt 22 | plt.gray() 23 | import cv2 24 | 25 | sys.path.append('modules') 26 | 27 | import models 28 | import utils 29 | import losses 30 | import deep_prior 31 | import deep_decoder 32 | 33 | if __name__ == '__main__': 34 | # Set your simulation constants here 35 | nrows = 64 # Size of the matrix 36 | ncols = 64 37 | rank = 32 # Rank 38 | noise_type = 'gaussian' # Type of noise 39 | signal_type = 'gaussian' # Type of signal 40 | noise_snr = 0.2 # Std. dev for gaussian noise 41 | tau = 1000 # Max. lambda for photon noise 42 | 43 | # Network parameters 44 | nettype = 'dip' 45 | n_inputs = rank 46 | init_nconv = 128 47 | num_channels_up = 5 48 | 49 | sched_args = argparse.Namespace() 50 | # Learning constants 51 | # Important: The number of epochs is decided by noise levels. As a general 52 | # rule of thumb, higher the noise, fewer the epochs. 
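# (Concretely: the training loop below retains the reconstruction with the
# lowest *training* loss as `best_mat`, and with heavy noise the networks
# eventually start fitting the noise itself -- so running longer can hurt.)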
53 | scheduler_type = 'none' 54 | learning_rate = 1e-4 55 | epochs = 1000 56 | sched_args.step_size = 2000 57 | sched_args.gamma = pow(10, -1/epochs) 58 | sched_args.max_lr = learning_rate 59 | sched_args.min_lr = 1e-6 60 | sched_args.epochs = epochs 61 | 62 | # Generate data 63 | mat, mat_gt = utils.get_matrix(nrows, ncols, rank, noise_type, signal_type, 64 | noise_snr, tau) 65 | 66 | # Move them to device 67 | mat_ten = torch.tensor(mat)[None, ...].cuda() 68 | mat_gt_ten = torch.tensor(mat_gt)[None, ...].cuda() 69 | 70 | u_inp = utils.get_inp([1, n_inputs, nrows]) 71 | v_inp = utils.get_inp([1, n_inputs, ncols]) 72 | 73 | # Create networks 74 | if nettype == 'unet': 75 | u_net = models.UNetND(n_inputs, rank, 1, init_nconv).cuda() 76 | v_net = models.UNetND(n_inputs, rank, 1, init_nconv).cuda() 77 | elif nettype == 'dip': 78 | u_net = deep_prior.get_net(n_inputs, 'skip1d', 'reflection', 79 | upsample_mode='linear', 80 | skip_n33d=init_nconv, 81 | skip_n33u=init_nconv, 82 | num_scales=5, 83 | n_channels=rank).cuda() 84 | v_net = deep_prior.get_net(n_inputs, 'skip1d', 'reflection', 85 | upsample_mode='linear', 86 | skip_n33d=init_nconv, 87 | skip_n33u=init_nconv, 88 | num_scales=5, 89 | n_channels=rank).cuda() 90 | elif nettype == 'dd': 91 | u_net = deep_decoder.decodernw1d(rank, 92 | [init_nconv]*num_channels_up).cuda() 93 | v_net = deep_decoder.decodernw1d(rank, 94 | [init_nconv]*num_channels_up).cuda() 95 | 96 | # Deep decoder requires smaller inputs 97 | u_inp = utils.get_inp([1, init_nconv, nrows // pow(2, num_channels_up)]) 98 | v_inp = utils.get_inp([1, init_nconv, ncols // pow(2, num_channels_up)]) 99 | 100 | # Extract training parameters 101 | net_params = list(u_net.parameters()) + list(v_net.parameters()) 102 | inp_params = [u_inp] + [v_inp] 103 | 104 | # You can either optimize both net and inputs, or just net 105 | params = net_params + inp_params 106 | optimizer = torch.optim.Adam(lr=learning_rate, params=params) 107 | 108 | # Create a learning scheduler 109 | scheduler = utils.get_scheduler(scheduler_type, optimizer, sched_args) 110 | 111 | # Create loss functions -- loses.L1Norm() or losses.L2Norm() 112 | criterion = losses.L2Norm() 113 | 114 | mse_array = np.zeros(epochs) 115 | 116 | # Now start iterations 117 | best_mse = float('inf') 118 | tbar = tqdm.tqdm(range(epochs)) 119 | for idx in tbar: 120 | u_output = u_net(u_inp).permute(0, 2, 1) 121 | v_output = v_net(v_inp) 122 | 123 | mat_estim = torch.bmm(u_output, v_output) 124 | 125 | loss = criterion(mat_estim - mat_ten) 126 | 127 | optimizer.zero_grad() 128 | loss.backward() 129 | optimizer.step() 130 | scheduler.step() 131 | 132 | # Visualize the reconstruction 133 | diff = abs(mat_gt_ten - mat_estim).squeeze().detach().cpu() 134 | mat_cpu = mat_estim.squeeze().detach().cpu().numpy() 135 | 136 | cv2.imshow('Diff x10', diff.numpy().reshape(nrows, ncols)*10) 137 | cv2.imshow('Rec', np.hstack((mat_gt, mat_cpu))) 138 | cv2.waitKey(1) 139 | 140 | mse_array[idx] = ((mat_estim - mat_gt_ten)**2).mean().item() 141 | tbar.set_description('%.4e'%mse_array[idx]) 142 | tbar.refresh() 143 | 144 | if loss.item() < best_mse: 145 | best_epoch = idx 146 | best_mat = mat_cpu 147 | best_mse = loss.item() 148 | 149 | # Now compute accuracy 150 | psnr1 = utils.psnr(mat_gt, best_mat) 151 | psnr2 = utils.psnr(mat_gt, utils.lr_decompose(mat, rank)) 152 | 153 | print('DeepTensor: %.2fdB'%psnr1) 154 | print('SVD: %.2fdB'%psnr2) 155 | 156 | 157 | -------------------------------------------------------------------------------- 
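Note: stripped of the repository's convolutional priors, the core of matrix_decomp.py is just two networks whose product is fit to the noisy matrix. Below is a minimal, self-contained sketch of that idea in plain PyTorch; the two-layer MLPs and fixed latent codes are hypothetical stand-ins for illustration, not the repo's DIP/deep-decoder models:

import torch
import torch.nn as nn

nrows, ncols, rank = 64, 64, 32
noisy = torch.randn(nrows, ncols)   # stand-in for the measured matrix

# Hypothetical toy factor networks (the repo uses skip1d / deep decoder models)
u_net = nn.Sequential(nn.Linear(rank, 128), nn.ReLU(), nn.Linear(128, nrows))
v_net = nn.Sequential(nn.Linear(rank, 128), nn.ReLU(), nn.Linear(128, ncols))
u_inp = torch.randn(rank, rank)     # fixed latent codes, one per factor column
v_inp = torch.randn(rank, rank)

opt = torch.optim.Adam(list(u_net.parameters()) + list(v_net.parameters()),
                       lr=1e-4)
for _ in range(1000):
    U = u_net(u_inp).T                    # (nrows, rank)
    V = v_net(v_inp)                      # (rank, ncols)
    loss = ((U @ V - noisy)**2).mean()    # L2 fit, as with losses.L2Norm()
    opt.zero_grad(); loss.backward(); opt.step()

The architecture itself acts as the regularizer: swapping these MLPs for the convolutional priors used above is what gives DeepTensor its denoising behavior.

--------------------------------------------------------------------------------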
/modules/cosine_annealing_with_warmup.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | ''' 6 | This script was downloaded from https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup 7 | ''' 8 | 9 | class CosineAnnealingWarmupRestarts(_LRScheduler): 10 | """ 11 | optimizer (Optimizer): Wrapped optimizer. 12 | first_cycle_steps (int): First cycle step size. 13 | cycle_mult(float): Cycle steps magnification. Default: -1. 14 | max_lr(float): First cycle's max learning rate. Default: 0.1. 15 | min_lr(float): Min learning rate. Default: 0.001. 16 | warmup_steps(int): Linear warmup step size. Default: 0. 17 | gamma(float): Decrease rate of max learning rate by cycle. Default: 1. 18 | last_epoch (int): The index of last epoch. Default: -1. 19 | """ 20 | 21 | def __init__(self, 22 | optimizer : torch.optim.Optimizer, 23 | first_cycle_steps : int, 24 | cycle_mult : float = 1., 25 | max_lr : float = 0.1, 26 | min_lr : float = 0.001, 27 | warmup_steps : int = 0, 28 | gamma : float = 1., 29 | last_epoch : int = -1 30 | ): 31 | assert warmup_steps < first_cycle_steps 32 | 33 | self.first_cycle_steps = first_cycle_steps # first cycle step size 34 | self.cycle_mult = cycle_mult # cycle steps magnification 35 | self.base_max_lr = max_lr # first max learning rate 36 | self.max_lr = max_lr # max learning rate in the current cycle 37 | self.min_lr = min_lr # min learning rate 38 | self.warmup_steps = warmup_steps # warmup step size 39 | self.gamma = gamma # decrease rate of max learning rate by cycle 40 | 41 | self.cur_cycle_steps = first_cycle_steps # first cycle step size 42 | self.cycle = 0 # cycle count 43 | self.step_in_cycle = last_epoch # step size of the current cycle 44 | 45 | super(CosineAnnealingWarmupRestarts, self).__init__(optimizer, last_epoch) 46 | 47 | # set learning rate min_lr 48 | self.init_lr() 49 | 50 | def init_lr(self): 51 | self.base_lrs = [] 52 | for param_group in self.optimizer.param_groups: 53 | param_group['lr'] = self.min_lr 54 | self.base_lrs.append(self.min_lr) 55 | 56 | def get_lr(self): 57 | if self.step_in_cycle == -1: 58 | return self.base_lrs 59 | elif self.step_in_cycle < self.warmup_steps: 60 | return [(self.max_lr - base_lr)*self.step_in_cycle / self.warmup_steps + base_lr for base_lr in self.base_lrs] 61 | else: 62 | return [base_lr + (self.max_lr - base_lr) \ 63 | * (1 + math.cos(math.pi * (self.step_in_cycle-self.warmup_steps) \ 64 | / (self.cur_cycle_steps - self.warmup_steps))) / 2 65 | for base_lr in self.base_lrs] 66 | 67 | def step(self, epoch=None): 68 | if epoch is None: 69 | epoch = self.last_epoch + 1 70 | self.step_in_cycle = self.step_in_cycle + 1 71 | if self.step_in_cycle >= self.cur_cycle_steps: 72 | self.cycle += 1 73 | self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps 74 | self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps 75 | else: 76 | if epoch >= self.first_cycle_steps: 77 | if self.cycle_mult == 1.: 78 | self.step_in_cycle = epoch % self.first_cycle_steps 79 | self.cycle = epoch // self.first_cycle_steps 80 | else: 81 | n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult)) 82 | self.cycle = n 83 | self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1)) 84 | self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n) 85 | else: 86 
| self.cur_cycle_steps = self.first_cycle_steps 87 | self.step_in_cycle = epoch 88 | 89 | self.max_lr = self.base_max_lr * (self.gamma**self.cycle) 90 | self.last_epoch = math.floor(epoch) 91 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 92 | param_group['lr'] = lr -------------------------------------------------------------------------------- /modules/deep_decoder.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This is a modified version of the original code hosted at https://github.com/reinhardh/supplement_deep_decoder 3 | 4 | Specifically: 5 | - 1d and 3d versions of deep decoder are added 6 | - The final sigmoid layer is turned off by default 7 | ''' 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | def add_module(self, module): 13 | self.add_module(str(len(self) + 1), module) 14 | 15 | torch.nn.Module.add = add_module 16 | 17 | 18 | def conv(in_f, out_f, kernel_size, stride=1, pad='zero'): 19 | padder = None 20 | to_pad = int((kernel_size - 1) / 2) 21 | if pad == 'reflection': 22 | padder = nn.ReflectionPad2d(to_pad) 23 | to_pad = 0 24 | 25 | convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=False) 26 | 27 | layers = filter(lambda x: x is not None, [padder, convolver]) 28 | return nn.Sequential(*layers) 29 | 30 | def conv3d(in_f, out_f, kernel_size, stride=1, pad='zero'): 31 | padder = None 32 | to_pad = int((kernel_size - 1) / 2) 33 | if pad == 'reflection': 34 | padder = nn.ReplicationPad3d(to_pad) 35 | to_pad = 0 36 | 37 | convolver = nn.Conv3d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=False) 38 | 39 | layers = filter(lambda x: x is not None, [padder, convolver]) 40 | return nn.Sequential(*layers) 41 | 42 | def conv1d(in_f, out_f, kernel_size, stride=1, pad='zero'): 43 | padder = None 44 | to_pad = int((kernel_size - 1) / 2) 45 | if pad == 'reflection': 46 | padder = nn.ReflectionPad1d(to_pad) 47 | to_pad = 0 48 | 49 | convolver = nn.Conv1d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=False) 50 | 51 | layers = filter(lambda x: x is not None, [padder, convolver]) 52 | return nn.Sequential(*layers) 53 | 54 | def decodernw( 55 | num_output_channels=3, 56 | num_channels_up=[128]*5, 57 | filter_size_up=1, 58 | need_sigmoid=False, 59 | pad ='reflection', 60 | upsample_mode='bilinear', 61 | act_fun=nn.LeakyReLU(0.2, inplace=True), 62 | bn_before_act = False, 63 | bn_affine = True, 64 | upsample_first = True, 65 | ): 66 | 67 | num_channels_up = num_channels_up + [num_channels_up[-1],num_channels_up[-1]] 68 | n_scales = len(num_channels_up) 69 | 70 | if not (isinstance(filter_size_up, list) or isinstance(filter_size_up, tuple)) : 71 | filter_size_up = [filter_size_up]*n_scales 72 | model = nn.Sequential() 73 | 74 | 75 | for i in range(len(num_channels_up)-1): 76 | 77 | if upsample_first: 78 | model.add(conv( num_channels_up[i], num_channels_up[i+1], filter_size_up[i], 1, pad=pad)) 79 | if upsample_mode!='none' and i != len(num_channels_up)-2: 80 | model.add(nn.Upsample(scale_factor=2, mode=upsample_mode)) 81 | #model.add(nn.functional.interpolate(size=None,scale_factor=2, mode=upsample_mode)) 82 | else: 83 | if upsample_mode!='none' and i!=0: 84 | model.add(nn.Upsample(scale_factor=2, mode=upsample_mode)) 85 | #model.add(nn.functional.interpolate(size=None,scale_factor=2, mode=upsample_mode)) 86 | model.add(conv( num_channels_up[i], num_channels_up[i+1], filter_size_up[i], 1, pad=pad)) 87 | 88 | if i != len(num_channels_up)-1: 89 | if(bn_before_act): 90 | 
model.add(nn.BatchNorm2d( num_channels_up[i+1] ,affine=bn_affine)) 91 | model.add(act_fun) 92 | if(not bn_before_act): 93 | model.add(nn.BatchNorm2d( num_channels_up[i+1], affine=bn_affine)) 94 | 95 | model.add(conv( num_channels_up[-1], num_output_channels, 1, pad=pad)) 96 | if need_sigmoid: 97 | model.add(nn.Sigmoid()) 98 | 99 | return model 100 | 101 | 102 | def decodernw3d( 103 | num_output_channels=3, 104 | num_channels_up=[128]*5, 105 | filter_size_up=1, 106 | need_sigmoid=False, 107 | pad ='reflection', 108 | upsample_mode='trilinear', 109 | act_fun=nn.LeakyReLU(0.2, inplace=True), 110 | bn_before_act = False, 111 | bn_affine = True, 112 | upsample_first = True, 113 | ): 114 | 115 | num_channels_up = num_channels_up + [num_channels_up[-1],num_channels_up[-1]] 116 | n_scales = len(num_channels_up) 117 | 118 | if not (isinstance(filter_size_up, list) or isinstance(filter_size_up, tuple)) : 119 | filter_size_up = [filter_size_up]*n_scales 120 | model = nn.Sequential() 121 | 122 | 123 | for i in range(len(num_channels_up)-1): 124 | 125 | if upsample_first: 126 | model.add(conv3d( num_channels_up[i], num_channels_up[i+1], filter_size_up[i], 1, pad=pad)) 127 | if upsample_mode!='none' and i != len(num_channels_up)-2: 128 | model.add(nn.Upsample(scale_factor=2, mode=upsample_mode)) 129 | #model.add(nn.functional.interpolate(size=None,scale_factor=2, mode=upsample_mode)) 130 | else: 131 | if upsample_mode!='none' and i!=0: 132 | model.add(nn.Upsample(scale_factor=2, mode=upsample_mode)) 133 | #model.add(nn.functional.interpolate(size=None,scale_factor=2, mode=upsample_mode)) 134 | model.add(conv3d( num_channels_up[i], num_channels_up[i+1], filter_size_up[i], 1, pad=pad)) 135 | 136 | if i != len(num_channels_up)-1: 137 | if(bn_before_act): 138 | model.add(nn.BatchNorm3d( num_channels_up[i+1] ,affine=bn_affine)) 139 | model.add(act_fun) 140 | if(not bn_before_act): 141 | model.add(nn.BatchNorm3d( num_channels_up[i+1], affine=bn_affine)) 142 | 143 | model.add(conv3d( num_channels_up[-1], num_output_channels, 1, pad=pad)) 144 | if need_sigmoid: 145 | model.add(nn.Sigmoid()) 146 | 147 | return model 148 | 149 | def decodernw1d( 150 | num_output_channels=3, 151 | num_channels_up=[128]*5, 152 | filter_size_up=1, 153 | need_sigmoid=False, 154 | pad ='reflection', 155 | upsample_mode='linear', 156 | act_fun=nn.LeakyReLU(0.2, inplace=True), 157 | bn_before_act=False, 158 | bn_affine=True, 159 | upsample_first=True, 160 | out_size=None, 161 | ): 162 | 163 | num_channels_up = num_channels_up + [num_channels_up[-1],num_channels_up[-1]] 164 | n_scales = len(num_channels_up) 165 | 166 | if not (isinstance(filter_size_up, list) or isinstance(filter_size_up, tuple)) : 167 | filter_size_up = [filter_size_up]*n_scales 168 | model = nn.Sequential() 169 | 170 | 171 | for i in range(len(num_channels_up)-1): 172 | 173 | if upsample_first: 174 | model.add(conv1d( num_channels_up[i], num_channels_up[i+1], filter_size_up[i], 1, pad=pad)) 175 | if upsample_mode!='none' and i != len(num_channels_up)-2: 176 | model.add(nn.Upsample(scale_factor=2, mode=upsample_mode)) 177 | #model.add(nn.functional.interpolate(size=None,scale_factor=2, mode=upsample_mode)) 178 | else: 179 | if upsample_mode!='none' and i!=0: 180 | model.add(nn.Upsample(scale_factor=2, mode=upsample_mode)) 181 | #model.add(nn.functional.interpolate(size=None,scale_factor=2, mode=upsample_mode)) 182 | model.add(conv1d( num_channels_up[i], num_channels_up[i+1], filter_size_up[i], 1, pad=pad)) 183 | 184 | if i != len(num_channels_up)-1: 185 | 
if(bn_before_act): 186 | model.add(nn.BatchNorm1d( num_channels_up[i+1] ,affine=bn_affine)) 187 | model.add(act_fun) 188 | if(not bn_before_act): 189 | model.add(nn.BatchNorm1d( num_channels_up[i+1], affine=bn_affine)) 190 | 191 | if out_size is not None: 192 | model.add(nn.Upsample(size=out_size, mode=upsample_mode)) 193 | 194 | model.add(conv1d( num_channels_up[-1], num_output_channels, 1, pad=pad)) 195 | if need_sigmoid: 196 | model.add(nn.Sigmoid()) 197 | 198 | return model 199 | 200 | 201 | # Residual block 202 | class ResidualBlock(nn.Module): 203 | def __init__(self, in_f, out_f): 204 | super(ResidualBlock, self).__init__() 205 | self.conv = nn.Conv2d(in_f, out_f, 1, 1, padding=0, bias=False) 206 | 207 | def forward(self, x): 208 | residual = x 209 | out = self.conv(x) 210 | out += residual 211 | return out 212 | 213 | def resdecoder( 214 | num_output_channels=3, 215 | num_channels_up=[128]*5, 216 | filter_size_up=1, 217 | need_sigmoid=True, 218 | pad='reflection', 219 | upsample_mode='bilinear', 220 | act_fun=nn.ReLU(), # nn.LeakyReLU(0.2, inplace=True) 221 | bn_before_act = False, 222 | bn_affine = True, 223 | ): 224 | 225 | num_channels_up = num_channels_up + [num_channels_up[-1],num_channels_up[-1]] 226 | n_scales = len(num_channels_up) 227 | 228 | if not (isinstance(filter_size_up, list) or isinstance(filter_size_up, tuple)) : 229 | filter_size_up = [filter_size_up]*n_scales 230 | 231 | model = nn.Sequential() 232 | 233 | for i in range(len(num_channels_up)-2): 234 | 235 | model.add( ResidualBlock( num_channels_up[i], num_channels_up[i+1]) ) 236 | 237 | if upsample_mode!='none': 238 | model.add(nn.Upsample(scale_factor=2, mode=upsample_mode)) 239 | #model.add(nn.functional.interpolate(size=None,scale_factor=2, mode=upsample_mode)) 240 | 241 | if i != len(num_channels_up)-1: 242 | model.add(act_fun) 243 | #model.add(nn.BatchNorm2d( num_channels_up[i+1], affine=bn_affine)) 244 | 245 | # new 246 | model.add(ResidualBlock( num_channels_up[-1], num_channels_up[-1])) 247 | #model.add(nn.BatchNorm2d( num_channels_up[-1] ,affine=bn_affine)) 248 | model.add(act_fun) 249 | # end new 250 | 251 | model.add(conv( num_channels_up[-1], num_output_channels, 1, pad=pad)) 252 | 253 | if need_sigmoid: 254 | model.add(nn.Sigmoid()) 255 | 256 | return model 257 | 258 | -------------------------------------------------------------------------------- /modules/deep_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .skip import skip 2 | from .texture_nets import get_texture_nets 3 | from .resnet import ResNet 4 | from .unet import UNet 5 | 6 | import torch.nn as nn 7 | 8 | def get_net(input_depth, NET_TYPE, pad, upsample_mode, n_channels=3, act_fun='LeakyReLU', skip_n33d=128, skip_n33u=128, skip_n11=4, num_scales=5, downsample_mode='stride'): 9 | if NET_TYPE == 'ResNet': 10 | # TODO 11 | net = ResNet(input_depth, 3, 10, 16, 1, nn.BatchNorm2d, False) 12 | elif NET_TYPE == 'skip': 13 | net = skip(input_depth, n_channels, num_channels_down = [skip_n33d]*num_scales if isinstance(skip_n33d, int) else skip_n33d, 14 | num_channels_up = [skip_n33u]*num_scales if isinstance(skip_n33u, int) else skip_n33u, 15 | num_channels_skip = [skip_n11]*num_scales if isinstance(skip_n11, int) else skip_n11, 16 | upsample_mode=upsample_mode, downsample_mode=downsample_mode, 17 | need_sigmoid=True, need_bias=True, pad=pad, act_fun=act_fun) 18 | 19 | elif NET_TYPE == 'texture_nets': 20 | net = get_texture_nets(inp=input_depth, ratios = [32, 16, 8, 4, 2, 1], 
fill_noise=False,pad=pad) 21 | 22 | elif NET_TYPE =='UNet': 23 | net = UNet(num_input_channels=input_depth, num_output_channels=3, 24 | feature_scale=4, more_layers=0, concat_x=False, 25 | upsample_mode=upsample_mode, pad=pad, norm_layer=nn.BatchNorm2d, need_sigmoid=True, need_bias=True) 26 | elif NET_TYPE == 'identity': 27 | assert input_depth == 3 28 | net = nn.Sequential() 29 | else: 30 | assert False 31 | 32 | return net -------------------------------------------------------------------------------- /modules/deep_models/common.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from .downsampler import Downsampler 7 | 8 | def add_module(self, module): 9 | self.add_module(str(len(self) + 1), module) 10 | 11 | torch.nn.Module.add = add_module 12 | 13 | class Concat(nn.Module): 14 | def __init__(self, dim, *args): 15 | super(Concat, self).__init__() 16 | self.dim = dim 17 | 18 | for idx, module in enumerate(args): 19 | self.add_module(str(idx), module) 20 | 21 | def forward(self, input): 22 | inputs = [] 23 | for module in self._modules.values(): 24 | inputs.append(module(input)) 25 | 26 | inputs_shapes2 = [x.shape[2] for x in inputs] 27 | inputs_shapes3 = [x.shape[3] for x in inputs] 28 | 29 | if np.all(np.array(inputs_shapes2) == min(inputs_shapes2)) and np.all(np.array(inputs_shapes3) == min(inputs_shapes3)): 30 | inputs_ = inputs 31 | else: 32 | target_shape2 = min(inputs_shapes2) 33 | target_shape3 = min(inputs_shapes3) 34 | 35 | inputs_ = [] 36 | for inp in inputs: 37 | diff2 = (inp.size(2) - target_shape2) // 2 38 | diff3 = (inp.size(3) - target_shape3) // 2 39 | inputs_.append(inp[:, :, diff2: diff2 + target_shape2, diff3:diff3 + target_shape3]) 40 | 41 | return torch.cat(inputs_, dim=self.dim) 42 | 43 | def __len__(self): 44 | return len(self._modules) 45 | 46 | class Concat1d(nn.Module): 47 | def __init__(self, dim, *args): 48 | super(Concat1d, self).__init__() 49 | self.dim = dim 50 | 51 | for idx, module in enumerate(args): 52 | self.add_module(str(idx), module) 53 | 54 | def forward(self, input): 55 | inputs = [] 56 | for module in self._modules.values(): 57 | try: 58 | inputs.append(module(input)) 59 | except ValueError: 60 | pdb.set_trace() 61 | 62 | inputs_shapes2 = [x.shape[2] for x in inputs] 63 | 64 | if np.all(np.array(inputs_shapes2) == min(inputs_shapes2)): 65 | inputs_ = inputs 66 | else: 67 | target_shape2 = min(inputs_shapes2) 68 | 69 | inputs_ = [] 70 | for inp in inputs: 71 | diff2 = (inp.size(2) - target_shape2) // 2 72 | inputs_.append(inp[:, :, diff2: diff2 + target_shape2]) 73 | 74 | return torch.cat(inputs_, dim=self.dim) 75 | 76 | def __len__(self): 77 | return len(self._modules) 78 | 79 | class GenNoise(nn.Module): 80 | def __init__(self, dim2): 81 | super(GenNoise, self).__init__() 82 | self.dim2 = dim2 83 | 84 | def forward(self, input): 85 | a = list(input.size()) 86 | a[1] = self.dim2 87 | # print (input.data.type()) 88 | 89 | b = torch.zeros(a).type_as(input.data) 90 | b.normal_() 91 | 92 | x = torch.autograd.Variable(b) 93 | 94 | return x 95 | 96 | 97 | class Swish(nn.Module): 98 | """ 99 | https://arxiv.org/abs/1710.05941 100 | The hype was so huge that I could not help but try it 101 | """ 102 | def __init__(self): 103 | super(Swish, self).__init__() 104 | self.s = nn.Sigmoid() 105 | 106 | def forward(self, x): 107 | return x * self.s(x) 108 | 109 | class Sine(nn.Module): 110 | ''' 111 | Sinusoidal activation layer 112 | ''' 
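# Computes sin(omega_0 * x) with fixed omega_0 = 30, the SIREN-style
# activation (https://arxiv.org/abs/2006.09661).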
113 | def __init__(self): 114 | super(Sine, self).__init__() 115 | self.omega_0 = 30 116 | 117 | def forward(self, x): 118 | return torch.sin(self.omega_0 * x) 119 | 120 | def act(act_fun = 'LeakyReLU'): 121 | ''' 122 | Either string defining an activation function or module (e.g. nn.ReLU) 123 | ''' 124 | if isinstance(act_fun, str): 125 | if act_fun == 'LeakyReLU': 126 | return nn.LeakyReLU(0.2, inplace=True) 127 | elif act_fun == 'Swish': 128 | return Swish() 129 | elif act_fun == 'ELU': 130 | return nn.ELU() 131 | elif act_fun == 'none': 132 | return nn.Sequential() 133 | elif act_fun == 'sine': 134 | return Sine() 135 | else: 136 | assert False 137 | else: 138 | return act_fun() 139 | 140 | 141 | def bn(num_features): 142 | return nn.BatchNorm2d(num_features) 143 | 144 | def bn1d(num_features): 145 | return nn.BatchNorm1d(num_features) 146 | 147 | 148 | def conv(in_f, out_f, kernel_size, stride=1, bias=True, pad='zero', downsample_mode='stride'): 149 | downsampler = None 150 | if stride != 1 and downsample_mode != 'stride': 151 | 152 | if downsample_mode == 'avg': 153 | downsampler = nn.AvgPool2d(stride, stride) 154 | elif downsample_mode == 'max': 155 | downsampler = nn.MaxPool2d(stride, stride) 156 | elif downsample_mode in ['lanczos2', 'lanczos3']: 157 | downsampler = Downsampler(n_planes=out_f, factor=stride, kernel_type=downsample_mode, phase=0.5, preserve_size=True) 158 | else: 159 | assert False 160 | 161 | stride = 1 162 | 163 | padder = None 164 | to_pad = int((kernel_size - 1) / 2) 165 | if pad == 'reflection': 166 | padder = nn.ReflectionPad2d(to_pad) 167 | to_pad = 0 168 | 169 | convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias) 170 | 171 | 172 | layers = filter(lambda x: x is not None, [padder, convolver, downsampler]) 173 | return nn.Sequential(*layers) 174 | 175 | def conv1d(in_f, out_f, kernel_size, stride=1, bias=True, pad='zero', downsample_mode='stride'): 176 | downsampler = None 177 | if stride != 1 and downsample_mode != 'stride': 178 | 179 | if downsample_mode == 'avg': 180 | downsampler = nn.AvgPool1d(stride) 181 | elif downsample_mode == 'max': 182 | downsampler = nn.MaxPool1d(stride) 183 | else: 184 | assert False 185 | 186 | stride = 1 187 | 188 | padder = None 189 | to_pad = int((kernel_size - 1) / 2) 190 | if pad == 'reflection': 191 | padder = nn.ReflectionPad1d(to_pad) 192 | to_pad = 0 193 | 194 | convolver = nn.Conv1d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias) 195 | 196 | layers = filter(lambda x: x is not None, [padder, convolver, downsampler]) 197 | return nn.Sequential(*layers) -------------------------------------------------------------------------------- /modules/deep_models/dcgan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def dcgan(inp=2, 5 | ndf=32, 6 | num_ups=4, need_sigmoid=True, need_bias=True, pad='zero', upsample_mode='nearest', need_convT = True): 7 | 8 | layers= [nn.ConvTranspose2d(inp, ndf, kernel_size=3, stride=1, padding=0, bias=False), 9 | nn.BatchNorm2d(ndf), 10 | nn.LeakyReLU(True)] 11 | 12 | for i in range(num_ups-3): 13 | if need_convT: 14 | layers += [ nn.ConvTranspose2d(ndf, ndf, kernel_size=4, stride=2, padding=1, bias=False), 15 | nn.BatchNorm2d(ndf), 16 | nn.LeakyReLU(True)] 17 | else: 18 | layers += [ nn.Upsample(scale_factor=2, mode=upsample_mode), 19 | nn.Conv2d(ndf, ndf, kernel_size=3, stride=1, padding=1, bias=False), 20 | nn.BatchNorm2d(ndf), 21 | nn.LeakyReLU(True)] 22 | 23 | if 
need_convT: 24 | layers += [nn.ConvTranspose2d(ndf, 3, 4, 2, 1, bias=False),] 25 | else: 26 | layers += [nn.Upsample(scale_factor=2, mode='bilinear'), 27 | nn.Conv2d(ndf, 3, kernel_size=3, stride=1, padding=1, bias=False)] 28 | 29 | 30 | if need_sigmoid: 31 | layers += [nn.Sigmoid()] 32 | 33 | model =nn.Sequential(*layers) 34 | return model -------------------------------------------------------------------------------- /modules/deep_models/downsampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | class Downsampler(nn.Module): 6 | ''' 7 | http://www.realitypixels.com/turk/computergraphics/ResamplingFilters.pdf 8 | ''' 9 | def __init__(self, n_planes, factor, kernel_type, phase=0, kernel_width=None, support=None, sigma=None, preserve_size=False): 10 | super(Downsampler, self).__init__() 11 | 12 | assert phase in [0, 0.5], 'phase should be 0 or 0.5' 13 | 14 | if kernel_type == 'lanczos2': 15 | support = 2 16 | kernel_width = 4 * factor + 1 17 | kernel_type_ = 'lanczos' 18 | 19 | elif kernel_type == 'lanczos3': 20 | support = 3 21 | kernel_width = 6 * factor + 1 22 | kernel_type_ = 'lanczos' 23 | 24 | elif kernel_type == 'gauss12': 25 | kernel_width = 7 26 | sigma = 1/2 27 | kernel_type_ = 'gauss' 28 | 29 | elif kernel_type == 'gauss1sq2': 30 | kernel_width = 9 31 | sigma = 1./np.sqrt(2) 32 | kernel_type_ = 'gauss' 33 | 34 | elif kernel_type in ['lanczos', 'gauss', 'box']: 35 | kernel_type_ = kernel_type 36 | 37 | else: 38 | assert False, 'wrong name kernel' 39 | 40 | 41 | # note that `kernel width` will be different to actual size for phase = 1/2 42 | self.kernel = get_kernel(factor, kernel_type_, phase, kernel_width, support=support, sigma=sigma) 43 | 44 | downsampler = nn.Conv2d(n_planes, n_planes, kernel_size=self.kernel.shape, stride=factor, padding=0) 45 | downsampler.weight.data[:] = 0 46 | downsampler.bias.data[:] = 0 47 | 48 | kernel_torch = torch.from_numpy(self.kernel) 49 | for i in range(n_planes): 50 | downsampler.weight.data[i, i] = kernel_torch 51 | 52 | self.downsampler_ = downsampler 53 | 54 | if preserve_size: 55 | 56 | if self.kernel.shape[0] % 2 == 1: 57 | pad = int((self.kernel.shape[0] - 1) / 2.) 58 | else: 59 | pad = int((self.kernel.shape[0] - factor) / 2.) 60 | 61 | self.padding = nn.ReplicationPad2d(pad) 62 | 63 | self.preserve_size = preserve_size 64 | 65 | def forward(self, input): 66 | if self.preserve_size: 67 | x = self.padding(input) 68 | else: 69 | x= input 70 | self.x = x 71 | return self.downsampler_(x) 72 | 73 | def get_kernel(factor, kernel_type, phase, kernel_width, support=None, sigma=None): 74 | assert kernel_type in ['lanczos', 'gauss', 'box'] 75 | 76 | # factor = float(factor) 77 | if phase == 0.5 and kernel_type != 'box': 78 | kernel = np.zeros([kernel_width - 1, kernel_width - 1]) 79 | else: 80 | kernel = np.zeros([kernel_width, kernel_width]) 81 | 82 | 83 | if kernel_type == 'box': 84 | assert phase == 0.5, 'Box filter is always half-phased' 85 | kernel[:] = 1./(kernel_width * kernel_width) 86 | 87 | elif kernel_type == 'gauss': 88 | assert sigma, 'sigma is not specified' 89 | assert phase != 0.5, 'phase 1/2 for gauss not implemented' 90 | 91 | center = (kernel_width + 1.)/2. 92 | print(center, kernel_width) 93 | sigma_sq = sigma * sigma 94 | 95 | for i in range(1, kernel.shape[0] + 1): 96 | for j in range(1, kernel.shape[1] + 1): 97 | di = (i - center)/2. 98 | dj = (j - center)/2. 
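# Evaluate the 2D Gaussian at this tap; the kernel is renormalized to
# unit sum after the loops (kernel /= kernel.sum() below).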
99 | kernel[i - 1][j - 1] = np.exp(-(di * di + dj * dj)/(2 * sigma_sq)) 100 | kernel[i - 1][j - 1] = kernel[i - 1][j - 1]/(2. * np.pi * sigma_sq) 101 | elif kernel_type == 'lanczos': 102 | assert support, 'support is not specified' 103 | center = (kernel_width + 1) / 2. 104 | 105 | for i in range(1, kernel.shape[0] + 1): 106 | for j in range(1, kernel.shape[1] + 1): 107 | 108 | if phase == 0.5: 109 | di = abs(i + 0.5 - center) / factor 110 | dj = abs(j + 0.5 - center) / factor 111 | else: 112 | di = abs(i - center) / factor 113 | dj = abs(j - center) / factor 114 | 115 | 116 | pi_sq = np.pi * np.pi 117 | 118 | val = 1 119 | if di != 0: 120 | val = val * support * np.sin(np.pi * di) * np.sin(np.pi * di / support) 121 | val = val / (np.pi * np.pi * di * di) 122 | 123 | if dj != 0: 124 | val = val * support * np.sin(np.pi * dj) * np.sin(np.pi * dj / support) 125 | val = val / (np.pi * np.pi * dj * dj) 126 | 127 | kernel[i - 1][j - 1] = val 128 | 129 | 130 | else: 131 | assert False, 'wrong method name' 132 | 133 | kernel /= kernel.sum() 134 | 135 | return kernel 136 | 137 | #a = Downsampler(n_planes=3, factor=2, kernel_type='lanczos2', phase='1', preserve_size=True) 138 | 139 | 140 | 141 | 142 | 143 | 144 | ################# 145 | # Learnable downsampler 146 | 147 | # KS = 32 148 | # dow = nn.Sequential(nn.ReplicationPad2d(int((KS - factor) / 2.)), nn.Conv2d(1,1,KS,factor)) 149 | 150 | # class Apply(nn.Module): 151 | # def __init__(self, what, dim, *args): 152 | # super(Apply, self).__init__() 153 | # self.dim = dim 154 | 155 | # self.what = what 156 | 157 | # def forward(self, input): 158 | # inputs = [] 159 | # for i in range(input.size(self.dim)): 160 | # inputs.append(self.what(input.narrow(self.dim, i, 1))) 161 | 162 | # return torch.cat(inputs, dim=self.dim) 163 | 164 | # def __len__(self): 165 | # return len(self._modules) 166 | 167 | # downs = Apply(dow, 1) 168 | # downs.type(dtype)(net_input.type(dtype)).size() 169 | -------------------------------------------------------------------------------- /modules/deep_models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from numpy.random import normal 4 | from numpy.linalg import svd 5 | from math import sqrt 6 | import torch.nn.init 7 | from .common import * 8 | 9 | class ResidualSequential(nn.Sequential): 10 | def __init__(self, *args): 11 | super(ResidualSequential, self).__init__(*args) 12 | 13 | def forward(self, x): 14 | out = super(ResidualSequential, self).forward(x) 15 | # print(x.size(), out.size()) 16 | x_ = None 17 | if out.size(2) != x.size(2) or out.size(3) != x.size(3): 18 | diff2 = x.size(2) - out.size(2) 19 | diff3 = x.size(3) - out.size(3) 20 | # print(1) 21 | x_ = x[:, :, diff2 /2:out.size(2) + diff2 / 2, diff3 / 2:out.size(3) + diff3 / 2] 22 | else: 23 | x_ = x 24 | return out + x_ 25 | 26 | def eval(self): 27 | print(2) 28 | for m in self.modules(): 29 | m.eval() 30 | exit() 31 | 32 | 33 | def get_block(num_channels, norm_layer, act_fun): 34 | layers = [ 35 | nn.Conv2d(num_channels, num_channels, 3, 1, 1, bias=False), 36 | norm_layer(num_channels, affine=True), 37 | act(act_fun), 38 | nn.Conv2d(num_channels, num_channels, 3, 1, 1, bias=False), 39 | norm_layer(num_channels, affine=True), 40 | ] 41 | return layers 42 | 43 | 44 | class ResNet(nn.Module): 45 | def __init__(self, num_input_channels, num_output_channels, num_blocks, num_channels, need_residual=True, act_fun='LeakyReLU', need_sigmoid=True, norm_layer=nn.BatchNorm2d, 
pad='reflection'): 46 | ''' 47 | pad = 'start|zero|replication' 48 | ''' 49 | super(ResNet, self).__init__() 50 | 51 | if need_residual: 52 | s = ResidualSequential 53 | else: 54 | s = nn.Sequential 55 | 56 | stride = 1 57 | # First layers 58 | layers = [ 59 | # nn.ReplicationPad2d(num_blocks * 2 * stride + 3), 60 | conv(num_input_channels, num_channels, 3, stride=1, bias=True, pad=pad), 61 | act(act_fun) 62 | ] 63 | # Residual blocks 64 | # layers_residual = [] 65 | for i in range(num_blocks): 66 | layers += [s(*get_block(num_channels, norm_layer, act_fun))] 67 | 68 | layers += [ 69 | nn.Conv2d(num_channels, num_channels, 3, 1, 1), 70 | norm_layer(num_channels, affine=True) 71 | ] 72 | 73 | # if need_residual: 74 | # layers += [ResidualSequential(*layers_residual)] 75 | # else: 76 | # layers += [Sequential(*layers_residual)] 77 | 78 | # if factor >= 2: 79 | # # Do upsampling if needed 80 | # layers += [ 81 | # nn.Conv2d(num_channels, num_channels * 82 | # factor ** 2, 3, 1), 83 | # nn.PixelShuffle(factor), 84 | # act(act_fun) 85 | # ] 86 | layers += [ 87 | conv(num_channels, num_output_channels, 3, 1, bias=True, pad=pad), 88 | nn.Sigmoid() 89 | ] 90 | self.model = nn.Sequential(*layers) 91 | 92 | def forward(self, input): 93 | return self.model(input) 94 | 95 | def eval(self): 96 | self.model.eval() 97 | -------------------------------------------------------------------------------- /modules/deep_models/skip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .common import * 4 | 5 | def skip( 6 | num_input_channels=2, num_output_channels=3, 7 | num_channels_down=[16, 32, 64, 128, 128], num_channels_up=[16, 32, 64, 128, 128], num_channels_skip=[4, 4, 4, 4, 4], 8 | filter_size_down=3, filter_size_up=3, filter_skip_size=1, 9 | need_sigmoid=True, need_bias=True, 10 | pad='zero', upsample_mode='nearest', downsample_mode='stride', act_fun='LeakyReLU', 11 | need1x1_up=True): 12 | """Assembles encoder-decoder with skip connections. 13 | 14 | Arguments: 15 | act_fun: Either string 'LeakyReLU|Swish|ELU|none' or module (e.g. 
nn.ReLU) 16 | pad (string): zero|reflection (default: 'zero') 17 | upsample_mode (string): 'nearest|bilinear' (default: 'nearest') 18 | downsample_mode (string): 'stride|avg|max|lanczos2' (default: 'stride') 19 | 20 | """ 21 | assert len(num_channels_down) == len(num_channels_up) == len(num_channels_skip) 22 | 23 | n_scales = len(num_channels_down) 24 | 25 | if not (isinstance(upsample_mode, list) or isinstance(upsample_mode, tuple)) : 26 | upsample_mode = [upsample_mode]*n_scales 27 | 28 | if not (isinstance(downsample_mode, list)or isinstance(downsample_mode, tuple)): 29 | downsample_mode = [downsample_mode]*n_scales 30 | 31 | if not (isinstance(filter_size_down, list) or isinstance(filter_size_down, tuple)) : 32 | filter_size_down = [filter_size_down]*n_scales 33 | 34 | if not (isinstance(filter_size_up, list) or isinstance(filter_size_up, tuple)) : 35 | filter_size_up = [filter_size_up]*n_scales 36 | 37 | last_scale = n_scales - 1 38 | 39 | cur_depth = None 40 | 41 | model = nn.Sequential() 42 | model_tmp = model 43 | 44 | input_depth = num_input_channels 45 | for i in range(len(num_channels_down)): 46 | 47 | deeper = nn.Sequential() 48 | skip = nn.Sequential() 49 | 50 | if num_channels_skip[i] != 0: 51 | model_tmp.add(Concat(1, skip, deeper)) 52 | else: 53 | model_tmp.add(deeper) 54 | 55 | model_tmp.add(bn(num_channels_skip[i] + (num_channels_up[i + 1] if i < last_scale else num_channels_down[i]))) 56 | 57 | if num_channels_skip[i] != 0: 58 | skip.add(conv(input_depth, num_channels_skip[i], filter_skip_size, bias=need_bias, pad=pad)) 59 | skip.add(bn(num_channels_skip[i])) 60 | skip.add(act(act_fun)) 61 | 62 | # skip.add(Concat(2, GenNoise(nums_noise[i]), skip_part)) 63 | 64 | deeper.add(conv(input_depth, num_channels_down[i], filter_size_down[i], 2, bias=need_bias, pad=pad, downsample_mode=downsample_mode[i])) 65 | deeper.add(bn(num_channels_down[i])) 66 | deeper.add(act(act_fun)) 67 | 68 | deeper.add(conv(num_channels_down[i], num_channels_down[i], filter_size_down[i], bias=need_bias, pad=pad)) 69 | deeper.add(bn(num_channels_down[i])) 70 | deeper.add(act(act_fun)) 71 | 72 | deeper_main = nn.Sequential() 73 | 74 | if i == len(num_channels_down) - 1: 75 | # The deepest 76 | k = num_channels_down[i] 77 | else: 78 | deeper.add(deeper_main) 79 | k = num_channels_up[i + 1] 80 | 81 | deeper.add(nn.Upsample(scale_factor=2, mode=upsample_mode[i])) 82 | 83 | model_tmp.add(conv(num_channels_skip[i] + k, num_channels_up[i], filter_size_up[i], 1, bias=need_bias, pad=pad)) 84 | model_tmp.add(bn(num_channels_up[i])) 85 | model_tmp.add(act(act_fun)) 86 | 87 | 88 | if need1x1_up: 89 | model_tmp.add(conv(num_channels_up[i], num_channels_up[i], 1, bias=need_bias, pad=pad)) 90 | model_tmp.add(bn(num_channels_up[i])) 91 | model_tmp.add(act(act_fun)) 92 | 93 | input_depth = num_channels_down[i] 94 | model_tmp = deeper_main 95 | 96 | model.add(conv(num_channels_up[0], num_output_channels, 1, bias=need_bias, pad=pad)) 97 | if need_sigmoid: 98 | model.add(nn.Sigmoid()) 99 | 100 | return model 101 | 102 | def skip1d( 103 | num_input_channels=2, num_output_channels=3, 104 | num_channels_down=[16, 32, 64, 128, 128], num_channels_up=[16, 32, 64, 128, 128], num_channels_skip=[4, 4, 4, 4, 4], 105 | filter_size_down=3, filter_size_up=3, filter_skip_size=1, 106 | need_sigmoid=True, need_bias=True, 107 | pad='zero', upsample_mode='nearest', downsample_mode='stride', act_fun='LeakyReLU', 108 | need1x1_up=True): 109 | """Assembles encoder-decoder with skip connections. 
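This is the 1D counterpart of skip() above; matrix_decomp.py reaches it through deep_prior.get_net(..., 'skip1d', ...) to build its factor networks.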
110 | 111 | Arguments: 112 | act_fun: Either string 'LeakyReLU|Swish|ELU|none' or module (e.g. nn.ReLU) 113 | pad (string): zero|reflection (default: 'zero') 114 | upsample_mode (string): 'nearest|bilinear' (default: 'nearest') 115 | downsample_mode (string): 'stride|avg|max|lanczos2' (default: 'stride') 116 | 117 | """ 118 | assert len(num_channels_down) == len(num_channels_up) == len(num_channels_skip) 119 | 120 | n_scales = len(num_channels_down) 121 | 122 | if not (isinstance(upsample_mode, list) or isinstance(upsample_mode, tuple)) : 123 | upsample_mode = [upsample_mode]*n_scales 124 | 125 | if not (isinstance(downsample_mode, list)or isinstance(downsample_mode, tuple)): 126 | downsample_mode = [downsample_mode]*n_scales 127 | 128 | if not (isinstance(filter_size_down, list) or isinstance(filter_size_down, tuple)) : 129 | filter_size_down = [filter_size_down]*n_scales 130 | 131 | if not (isinstance(filter_size_up, list) or isinstance(filter_size_up, tuple)) : 132 | filter_size_up = [filter_size_up]*n_scales 133 | 134 | last_scale = n_scales - 1 135 | 136 | cur_depth = None 137 | 138 | model = nn.Sequential() 139 | model_tmp = model 140 | 141 | input_depth = num_input_channels 142 | for i in range(len(num_channels_down)): 143 | 144 | deeper = nn.Sequential() 145 | skip = nn.Sequential() 146 | 147 | if num_channels_skip[i] != 0: 148 | model_tmp.add(Concat1d(1, skip, deeper)) 149 | else: 150 | model_tmp.add(deeper) 151 | 152 | model_tmp.add(bn1d(num_channels_skip[i] + (num_channels_up[i + 1] if i < last_scale else num_channels_down[i]))) 153 | 154 | if num_channels_skip[i] != 0: 155 | skip.add(conv1d(input_depth, num_channels_skip[i], filter_skip_size, bias=need_bias, pad=pad)) 156 | skip.add(bn1d(num_channels_skip[i])) 157 | skip.add(act(act_fun)) 158 | 159 | # skip.add(Concat(2, GenNoise(nums_noise[i]), skip_part)) 160 | 161 | deeper.add(conv1d(input_depth, num_channels_down[i], filter_size_down[i], 2, bias=need_bias, pad=pad, downsample_mode=downsample_mode[i])) 162 | deeper.add(bn1d(num_channels_down[i])) 163 | deeper.add(act(act_fun)) 164 | 165 | deeper.add(conv1d(num_channels_down[i], num_channels_down[i], filter_size_down[i], bias=need_bias, pad=pad)) 166 | deeper.add(bn1d(num_channels_down[i])) 167 | deeper.add(act(act_fun)) 168 | 169 | deeper_main = nn.Sequential() 170 | 171 | if i == len(num_channels_down) - 1: 172 | # The deepest 173 | k = num_channels_down[i] 174 | else: 175 | deeper.add(deeper_main) 176 | k = num_channels_up[i + 1] 177 | 178 | deeper.add(nn.Upsample(scale_factor=2, mode=upsample_mode[i])) 179 | 180 | model_tmp.add(conv1d(num_channels_skip[i] + k, num_channels_up[i], filter_size_up[i], 1, bias=need_bias, pad=pad)) 181 | model_tmp.add(bn1d(num_channels_up[i])) 182 | model_tmp.add(act(act_fun)) 183 | 184 | 185 | if need1x1_up: 186 | model_tmp.add(conv1d(num_channels_up[i], num_channels_up[i], 1, bias=need_bias, pad=pad)) 187 | model_tmp.add(bn1d(num_channels_up[i])) 188 | model_tmp.add(act(act_fun)) 189 | 190 | input_depth = num_channels_down[i] 191 | model_tmp = deeper_main 192 | 193 | model.add(conv1d(num_channels_up[0], num_output_channels, 1, bias=need_bias, pad=pad)) 194 | if need_sigmoid: 195 | model.add(nn.Sigmoid()) 196 | 197 | return model -------------------------------------------------------------------------------- /modules/deep_models/texture_nets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .common import * 4 | 5 | 6 | normalization = nn.BatchNorm2d 7 | 8 
| 9 | def conv(in_f, out_f, kernel_size, stride=1, bias=True, pad='zero'): 10 | if pad == 'zero': 11 | return nn.Conv2d(in_f, out_f, kernel_size, stride, padding=(kernel_size - 1) / 2, bias=bias) 12 | elif pad == 'reflection': 13 | layers = [nn.ReflectionPad2d((kernel_size - 1) / 2), 14 | nn.Conv2d(in_f, out_f, kernel_size, stride, padding=0, bias=bias)] 15 | return nn.Sequential(*layers) 16 | 17 | def get_texture_nets(inp=3, ratios = [32, 16, 8, 4, 2, 1], fill_noise=False, pad='zero', need_sigmoid=False, conv_num=8, upsample_mode='nearest'): 18 | 19 | 20 | for i in range(len(ratios)): 21 | j = i + 1 22 | 23 | seq = nn.Sequential() 24 | 25 | tmp = nn.AvgPool2d(ratios[i], ratios[i]) 26 | 27 | seq.add(tmp) 28 | if fill_noise: 29 | seq.add(GenNoise(inp)) 30 | 31 | seq.add(conv(inp, conv_num, 3, pad=pad)) 32 | seq.add(normalization(conv_num)) 33 | seq.add(act()) 34 | 35 | seq.add(conv(conv_num, conv_num, 3, pad=pad)) 36 | seq.add(normalization(conv_num)) 37 | seq.add(act()) 38 | 39 | seq.add(conv(conv_num, conv_num, 1, pad=pad)) 40 | seq.add(normalization(conv_num)) 41 | seq.add(act()) 42 | 43 | if i == 0: 44 | seq.add(nn.Upsample(scale_factor=2, mode=upsample_mode)) 45 | cur = seq 46 | else: 47 | 48 | cur_temp = cur 49 | 50 | cur = nn.Sequential() 51 | 52 | # Batch norm before merging 53 | seq.add(normalization(conv_num)) 54 | cur_temp.add(normalization(conv_num * (j - 1))) 55 | 56 | cur.add(Concat(1, cur_temp, seq)) 57 | 58 | cur.add(conv(conv_num * j, conv_num * j, 3, pad=pad)) 59 | cur.add(normalization(conv_num * j)) 60 | cur.add(act()) 61 | 62 | cur.add(conv(conv_num * j, conv_num * j, 3, pad=pad)) 63 | cur.add(normalization(conv_num * j)) 64 | cur.add(act()) 65 | 66 | cur.add(conv(conv_num * j, conv_num * j, 1, pad=pad)) 67 | cur.add(normalization(conv_num * j)) 68 | cur.add(act()) 69 | 70 | if i == len(ratios) - 1: 71 | cur.add(conv(conv_num * j, 3, 1, pad=pad)) 72 | else: 73 | cur.add(nn.Upsample(scale_factor=2, mode=upsample_mode)) 74 | 75 | model = cur 76 | if need_sigmoid: 77 | model.add(nn.Sigmoid()) 78 | 79 | return model 80 | -------------------------------------------------------------------------------- /modules/deep_models/unet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .common import * 6 | 7 | class ListModule(nn.Module): 8 | def __init__(self, *args): 9 | super(ListModule, self).__init__() 10 | idx = 0 11 | for module in args: 12 | self.add_module(str(idx), module) 13 | idx += 1 14 | 15 | def __getitem__(self, idx): 16 | if idx >= len(self._modules): 17 | raise IndexError('index {} is out of range'.format(idx)) 18 | if idx < 0: 19 | idx = len(self) + idx 20 | 21 | it = iter(self._modules.values()) 22 | for i in range(idx): 23 | next(it) 24 | return next(it) 25 | 26 | def __iter__(self): 27 | return iter(self._modules.values()) 28 | 29 | def __len__(self): 30 | return len(self._modules) 31 | 32 | class UNet(nn.Module): 33 | ''' 34 | upsample_mode in ['deconv', 'nearest', 'bilinear'] 35 | pad in ['zero', 'replication', 'none'] 36 | ''' 37 | def __init__(self, num_input_channels=3, num_output_channels=3, 38 | feature_scale=4, more_layers=0, concat_x=False, 39 | upsample_mode='deconv', pad='zero', norm_layer=nn.InstanceNorm2d, need_sigmoid=True, need_bias=True): 40 | super(UNet, self).__init__() 41 | 42 | self.feature_scale = feature_scale 43 | self.more_layers = more_layers 44 | self.concat_x = concat_x 45 | 46 | 47 | 
filters = [64, 128, 256, 512, 1024] 48 | filters = [x // self.feature_scale for x in filters] 49 | 50 | self.start = unetConv2(num_input_channels, filters[0] if not concat_x else filters[0] - num_input_channels, norm_layer, need_bias, pad) 51 | 52 | self.down1 = unetDown(filters[0], filters[1] if not concat_x else filters[1] - num_input_channels, norm_layer, need_bias, pad) 53 | self.down2 = unetDown(filters[1], filters[2] if not concat_x else filters[2] - num_input_channels, norm_layer, need_bias, pad) 54 | self.down3 = unetDown(filters[2], filters[3] if not concat_x else filters[3] - num_input_channels, norm_layer, need_bias, pad) 55 | self.down4 = unetDown(filters[3], filters[4] if not concat_x else filters[4] - num_input_channels, norm_layer, need_bias, pad) 56 | 57 | # more downsampling layers 58 | if self.more_layers > 0: 59 | self.more_downs = [ 60 | unetDown(filters[4], filters[4] if not concat_x else filters[4] - num_input_channels , norm_layer, need_bias, pad) for i in range(self.more_layers)] 61 | self.more_ups = [unetUp(filters[4], upsample_mode, need_bias, pad, same_num_filt =True) for i in range(self.more_layers)] 62 | 63 | self.more_downs = ListModule(*self.more_downs) 64 | self.more_ups = ListModule(*self.more_ups) 65 | 66 | self.up4 = unetUp(filters[3], upsample_mode, need_bias, pad) 67 | self.up3 = unetUp(filters[2], upsample_mode, need_bias, pad) 68 | self.up2 = unetUp(filters[1], upsample_mode, need_bias, pad) 69 | self.up1 = unetUp(filters[0], upsample_mode, need_bias, pad) 70 | 71 | self.final = conv(filters[0], num_output_channels, 1, bias=need_bias, pad=pad) 72 | 73 | if need_sigmoid: 74 | self.final = nn.Sequential(self.final, nn.Sigmoid()) 75 | 76 | def forward(self, inputs): 77 | 78 | # Downsample 79 | downs = [inputs] 80 | down = nn.AvgPool2d(2, 2) 81 | for i in range(4 + self.more_layers): 82 | downs.append(down(downs[-1])) 83 | 84 | in64 = self.start(inputs) 85 | if self.concat_x: 86 | in64 = torch.cat([in64, downs[0]], 1) 87 | 88 | down1 = self.down1(in64) 89 | if self.concat_x: 90 | down1 = torch.cat([down1, downs[1]], 1) 91 | 92 | down2 = self.down2(down1) 93 | if self.concat_x: 94 | down2 = torch.cat([down2, downs[2]], 1) 95 | 96 | down3 = self.down3(down2) 97 | if self.concat_x: 98 | down3 = torch.cat([down3, downs[3]], 1) 99 | 100 | down4 = self.down4(down3) 101 | if self.concat_x: 102 | down4 = torch.cat([down4, downs[4]], 1) 103 | 104 | if self.more_layers > 0: 105 | prevs = [down4] 106 | for kk, d in enumerate(self.more_downs): 107 | # print(prevs[-1].size()) 108 | out = d(prevs[-1]) 109 | if self.concat_x: 110 | out = torch.cat([out, downs[kk + 5]], 1) 111 | 112 | prevs.append(out) 113 | 114 | up_ = self.more_ups[-1](prevs[-1], prevs[-2]) 115 | for idx in range(self.more_layers - 1): 116 | l = self.more_ups[self.more - idx - 2] 117 | up_= l(up_, prevs[self.more - idx - 2]) 118 | else: 119 | up_= down4 120 | 121 | up4= self.up4(up_, down3) 122 | up3= self.up3(up4, down2) 123 | up2= self.up2(up3, down1) 124 | up1= self.up1(up2, in64) 125 | 126 | return self.final(up1) 127 | 128 | 129 | 130 | class unetConv2(nn.Module): 131 | def __init__(self, in_size, out_size, norm_layer, need_bias, pad): 132 | super(unetConv2, self).__init__() 133 | 134 | print(pad) 135 | if norm_layer is not None: 136 | self.conv1= nn.Sequential(conv(in_size, out_size, 3, bias=need_bias, pad=pad), 137 | norm_layer(out_size), 138 | nn.ReLU(),) 139 | self.conv2= nn.Sequential(conv(out_size, out_size, 3, bias=need_bias, pad=pad), 140 | norm_layer(out_size), 141 | nn.ReLU(),) 142 | 
else: 143 | self.conv1= nn.Sequential(conv(in_size, out_size, 3, bias=need_bias, pad=pad), 144 | nn.ReLU(),) 145 | self.conv2= nn.Sequential(conv(out_size, out_size, 3, bias=need_bias, pad=pad), 146 | nn.ReLU(),) 147 | def forward(self, inputs): 148 | outputs= self.conv1(inputs) 149 | outputs= self.conv2(outputs) 150 | return outputs 151 | 152 | 153 | class unetDown(nn.Module): 154 | def __init__(self, in_size, out_size, norm_layer, need_bias, pad): 155 | super(unetDown, self).__init__() 156 | self.conv= unetConv2(in_size, out_size, norm_layer, need_bias, pad) 157 | self.down= nn.MaxPool2d(2, 2) 158 | 159 | def forward(self, inputs): 160 | outputs= self.down(inputs) 161 | outputs= self.conv(outputs) 162 | return outputs 163 | 164 | 165 | class unetUp(nn.Module): 166 | def __init__(self, out_size, upsample_mode, need_bias, pad, same_num_filt=False): 167 | super(unetUp, self).__init__() 168 | 169 | num_filt = out_size if same_num_filt else out_size * 2 170 | if upsample_mode == 'deconv': 171 | self.up= nn.ConvTranspose2d(num_filt, out_size, 4, stride=2, padding=1) 172 | self.conv= unetConv2(out_size * 2, out_size, None, need_bias, pad) 173 | elif upsample_mode=='bilinear' or upsample_mode=='nearest': 174 | self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode=upsample_mode), 175 | conv(num_filt, out_size, 3, bias=need_bias, pad=pad)) 176 | self.conv= unetConv2(out_size * 2, out_size, None, need_bias, pad) 177 | else: 178 | assert False 179 | 180 | def forward(self, inputs1, inputs2): 181 | in1_up= self.up(inputs1) 182 | 183 | if (inputs2.size(2) != in1_up.size(2)) or (inputs2.size(3) != in1_up.size(3)): 184 | diff2 = (inputs2.size(2) - in1_up.size(2)) // 2 185 | diff3 = (inputs2.size(3) - in1_up.size(3)) // 2 186 | inputs2_ = inputs2[:, :, diff2 : diff2 + in1_up.size(2), diff3 : diff3 + in1_up.size(3)] 187 | else: 188 | inputs2_ = inputs2 189 | 190 | output= self.conv(torch.cat([in1_up, inputs2_], 1)) 191 | 192 | return output 193 | -------------------------------------------------------------------------------- /modules/deep_prior.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env 2 | 3 | ''' 4 | This is a modified version of the original code hosted at https://github.com/DmitryUlyanov/deep-image-prior 5 | 6 | Specifically 7 | - 1D version is added 8 | - The final sigmoid layer is turned off by default 9 | ''' 10 | 11 | import os 12 | import sys 13 | import tqdm 14 | import pdb 15 | 16 | import numpy as np 17 | import torch 18 | from torch import nn 19 | 20 | import torchvision 21 | import cv2 22 | 23 | from deep_models.skip import skip, skip1d 24 | from deep_models.texture_nets import get_texture_nets 25 | from deep_models.resnet import ResNet 26 | from deep_models.unet import UNet 27 | 28 | class Downsampler(nn.Module): 29 | ''' 30 | http://www.realitypixels.com/turk/computergraphics/ResamplingFilters.pdf 31 | ''' 32 | def __init__(self, n_planes, factor, kernel_type, phase=0, 33 | kernel_width=None, support=None, sigma=None, 34 | preserve_size=False): 35 | super(Downsampler, self).__init__() 36 | 37 | assert phase in [0, 0.5], 'phase should be 0 or 0.5' 38 | 39 | if kernel_type == 'lanczos2': 40 | support = 2 41 | kernel_width = 4 * factor + 1 42 | kernel_type_ = 'lanczos' 43 | 44 | elif kernel_type == 'lanczos3': 45 | support = 3 46 | kernel_width = 6 * factor + 1 47 | kernel_type_ = 'lanczos' 48 | 49 | elif kernel_type == 'gauss12': 50 | kernel_width = 7 51 | sigma = 1/2 52 | kernel_type_ = 'gauss' 53 | 54 | elif kernel_type == 
'gauss1sq2': 55 | kernel_width = 9 56 | sigma = 1./np.sqrt(2) 57 | kernel_type_ = 'gauss' 58 | 59 | elif kernel_type in ['lanczos', 'gauss', 'box']: 60 | kernel_type_ = kernel_type 61 | 62 | else: 63 | assert False, 'wrong name kernel' 64 | 65 | 66 | # note that `kernel width` will be different to actual size for phase = 1/2 67 | self.kernel = get_kernel(factor, kernel_type_, phase, kernel_width, 68 | support=support, sigma=sigma) 69 | 70 | downsampler = nn.Conv2d(n_planes, n_planes, 71 | kernel_size=self.kernel.shape, 72 | stride=factor, padding=0) 73 | downsampler.weight.data[:] = 0 74 | downsampler.bias.data[:] = 0 75 | 76 | kernel_torch = torch.from_numpy(self.kernel) 77 | for i in range(n_planes): 78 | downsampler.weight.data[i, i] = kernel_torch 79 | 80 | self.downsampler_ = downsampler 81 | 82 | if preserve_size: 83 | 84 | if self.kernel.shape[0] % 2 == 1: 85 | pad = int((self.kernel.shape[0] - 1) / 2.) 86 | else: 87 | pad = int((self.kernel.shape[0] - factor) / 2.) 88 | 89 | self.padding = nn.ReplicationPad2d(pad) 90 | 91 | self.preserve_size = preserve_size 92 | 93 | def forward(self, input): 94 | if self.preserve_size: 95 | x = self.padding(input) 96 | else: 97 | x= input 98 | self.x = x 99 | return self.downsampler_(x) 100 | 101 | def get_kernel(factor, kernel_type, phase, kernel_width, 102 | support=None, sigma=None): 103 | assert kernel_type in ['lanczos', 'gauss', 'box'] 104 | 105 | # factor = float(factor) 106 | if phase == 0.5 and kernel_type != 'box': 107 | kernel = np.zeros([kernel_width - 1, kernel_width - 1]) 108 | else: 109 | kernel = np.zeros([kernel_width, kernel_width]) 110 | 111 | 112 | if kernel_type == 'box': 113 | assert phase == 0.5, 'Box filter is always half-phased' 114 | kernel[:] = 1./(kernel_width * kernel_width) 115 | 116 | elif kernel_type == 'gauss': 117 | assert sigma, 'sigma is not specified' 118 | assert phase != 0.5, 'phase 1/2 for gauss not implemented' 119 | 120 | center = (kernel_width + 1.)/2. 121 | print(center, kernel_width) 122 | sigma_sq = sigma * sigma 123 | 124 | for i in range(1, kernel.shape[0] + 1): 125 | for j in range(1, kernel.shape[1] + 1): 126 | di = (i - center)/2. 127 | dj = (j - center)/2. 128 | kernel[i - 1][j - 1] = np.exp(-(di * di + dj * dj)/(2 * sigma_sq)) 129 | kernel[i - 1][j - 1] = kernel[i - 1][j - 1]/(2. * np.pi * sigma_sq) 130 | elif kernel_type == 'lanczos': 131 | assert support, 'support is not specified' 132 | center = (kernel_width + 1) / 2. 133 | 134 | for i in range(1, kernel.shape[0] + 1): 135 | for j in range(1, kernel.shape[1] + 1): 136 | 137 | if phase == 0.5: 138 | di = abs(i + 0.5 - center) / factor 139 | dj = abs(j + 0.5 - center) / factor 140 | else: 141 | di = abs(i - center) / factor 142 | dj = abs(j - center) / factor 143 | 144 | 145 | pi_sq = np.pi * np.pi 146 | 147 | val = 1 148 | if di != 0: 149 | val = val * support * np.sin(np.pi * di) * np.sin(np.pi * di / support) 150 | val = val / (np.pi * np.pi * di * di) 151 | 152 | if dj != 0: 153 | val = val * support * np.sin(np.pi * dj) * np.sin(np.pi * dj / support) 154 | val = val / (np.pi * np.pi * dj * dj) 155 | 156 | kernel[i - 1][j - 1] = val 157 | 158 | 159 | else: 160 | assert False, 'wrong method name' 161 | 162 | kernel /= kernel.sum() 163 | 164 | return kernel 165 | 166 | def get_noise(input_depth, method, spatial_size, noise_type='u', var=1./10): 167 | """Returns a pytorch.Tensor of size 168 | (1 x `input_depth` x `spatial_size[0]` x `spatial_size[1]`) 169 | initialized in a specific way. 
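    Example (illustrative): `get_noise(32, 'noise', (64, 64), var=0.1)`
    returns a (1, 32, 64, 64) tensor with i.i.d. uniform entries in [0, 0.1).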
170 | Args: 171 | input_depth: number of channels in the tensor 172 | method: `noise` for fillting tensor with noise; `meshgrid` 173 | for np.meshgrid 174 | spatial_size: spatial size of the tensor to initialize 175 | noise_type: 'u' for uniform; 'n' for normal 176 | var: a factor, a noise will be multiplicated by. Basically it is 177 | standard deviation scaler. 178 | """ 179 | if isinstance(spatial_size, int): 180 | spatial_size = (spatial_size, spatial_size) 181 | if method == 'noise': 182 | shape = [1, input_depth, spatial_size[0], spatial_size[1]] 183 | net_input = torch.zeros(shape) 184 | 185 | fill_noise(net_input, noise_type) 186 | net_input *= var 187 | elif method == 'meshgrid': 188 | assert input_depth == 2 189 | X, Y = np.meshgrid(np.arange(0, spatial_size[1])/float(spatial_size[1]-1), np.arange(0, spatial_size[0])/float(spatial_size[0]-1)) 190 | meshgrid = np.concatenate([X[None,:], Y[None,:]]) 191 | net_input= np_to_torch(meshgrid) 192 | else: 193 | assert False 194 | 195 | return net_input 196 | 197 | def np_to_torch(img_np): 198 | '''Converts image in numpy.array to torch.Tensor. 199 | 200 | From C x W x H [0..1] to C x W x H [0..1] 201 | ''' 202 | return torch.from_numpy(img_np)[None, :] 203 | 204 | def torch_to_np(img_var): 205 | '''Converts an image in torch.Tensor format to np.array. 206 | 207 | From 1 x C x W x H [0..1] to C x W x H [0..1] 208 | ''' 209 | return img_var.detach().cpu().numpy()[0] 210 | 211 | def fill_noise(x, noise_type): 212 | """Fills tensor `x` with noise of type `noise_type`.""" 213 | if noise_type == 'u': 214 | x.uniform_() 215 | elif noise_type == 'n': 216 | x.normal_() 217 | else: 218 | assert False 219 | 220 | def get_image_grid(images_np, nrow=8): 221 | '''Creates a grid from a list of images by concatenating them.''' 222 | images_torch = [torch.from_numpy(x) for x in images_np] 223 | torch_grid = torchvision.utils.make_grid(images_torch, nrow) 224 | 225 | return torch_grid.numpy() 226 | 227 | def get_net(input_depth, NET_TYPE, pad, upsample_mode, n_channels=3, 228 | act_fun='LeakyReLU', skip_n33d=128, skip_n33u=128, skip_n11=4, 229 | num_scales=5, downsample_mode='stride', filter_size_up=3, 230 | filter_size_down=3, filter_size_skip=1): 231 | if NET_TYPE == 'ResNet': 232 | # TODO 233 | net = ResNet(input_depth, 3, 10, 16, 1, nn.BatchNorm2d, False) 234 | elif NET_TYPE == 'skip': 235 | net = skip(input_depth, n_channels, num_channels_down = [skip_n33d]*num_scales if isinstance(skip_n33d, int) else skip_n33d, 236 | num_channels_up = [skip_n33u]*num_scales if isinstance(skip_n33u, int) else skip_n33u, 237 | num_channels_skip = [skip_n11]*num_scales if isinstance(skip_n11, int) else skip_n11, 238 | upsample_mode=upsample_mode, downsample_mode=downsample_mode, 239 | need_sigmoid=False, need_bias=True, pad=pad, act_fun=act_fun, 240 | filter_size_up=filter_size_up, 241 | filter_size_down=filter_size_down, 242 | filter_skip_size=filter_size_skip) 243 | elif NET_TYPE == 'skip1d': 244 | net = skip1d(input_depth, n_channels, num_channels_down = [skip_n33d]*num_scales if isinstance(skip_n33d, int) else skip_n33d, 245 | num_channels_up = [skip_n33u]*num_scales if isinstance(skip_n33u, int) else skip_n33u, 246 | num_channels_skip = [skip_n11]*num_scales if isinstance(skip_n11, int) else skip_n11, 247 | upsample_mode=upsample_mode, downsample_mode=downsample_mode, 248 | need_sigmoid=False, need_bias=True, pad=pad, act_fun=act_fun, 249 | filter_size_up=filter_size_up, 250 | filter_size_down=filter_size_down, 251 | filter_skip_size=filter_size_skip) 252 | 253 | 
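    # Usage sketch (illustrative values; net and input must end up on the
    # same device):
    #   net = get_net(32, 'skip', 'reflection', upsample_mode='bilinear',
    #                 n_channels=3)
    #   out = net(torch.zeros(1, 32, 256, 256))   # -> (1, 3, 256, 256)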
elif NET_TYPE == 'texture_nets': 254 | net = get_texture_nets(inp=input_depth, ratios = [32, 16, 8, 4, 2, 1], fill_noise=False,pad=pad) 255 | 256 | elif NET_TYPE =='UNet': 257 | net = UNet(num_input_channels=input_depth, num_output_channels=3, 258 | feature_scale=4, more_layers=0, concat_x=False, 259 | upsample_mode=upsample_mode, pad=pad, norm_layer=nn.BatchNorm2d, need_sigmoid=True, need_bias=True) 260 | elif NET_TYPE == 'identity': 261 | assert input_depth == 3 262 | net = nn.Sequential() 263 | else: 264 | assert False 265 | 266 | return net 267 | -------------------------------------------------------------------------------- /modules/lin_inverse.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | ''' 4 | This script contains functions pertaining to linear inverse problems 5 | in section 4. 6 | ''' 7 | 8 | import os 9 | import sys 10 | import glob 11 | import tqdm 12 | import pdb 13 | 14 | import numpy as np 15 | from scipy import signal 16 | 17 | import torch 18 | from torch import nn 19 | import kornia 20 | 21 | import matplotlib.pyplot as plt 22 | import cv2 23 | 24 | def radon(imten, angles): 25 | ''' 26 | Compute forward radon operation 27 | 28 | Inputs: 29 | imten: (1, nimg, H, W) image tensor 30 | angles: (nangles) angles tensor -- should be on same device as 31 | imten 32 | 33 | Outputs: 34 | sinogram: (nimg, nangles, W) sinogram 35 | ''' 36 | nangles = len(angles) 37 | imten_rep = torch.repeat_interleave(imten, nangles, 0) 38 | 39 | imten_rot = kornia.rotate(imten_rep, angles) 40 | 41 | sinogram = imten_rot.sum(2).squeeze() 42 | 43 | return sinogram 44 | 45 | def get_video_coding_frames(video_size, nframes): 46 | ''' 47 | Get masks for video CS 48 | 49 | Inputs: 50 | video size: Size of the video cube 51 | nframes: Number of frames to combine into a single frame 52 | 53 | Outputs: 54 | masks: Binary masks of the same size as video_size 55 | ''' 56 | H, W, totalframes = video_size 57 | 58 | X, Y = np.mgrid[:H, :W] 59 | 60 | indices = np.random.randint(0, nframes, (H, W)) 61 | masks_sub = np.zeros((H, W, nframes)) 62 | masks_sub[X, Y, indices] = 1 63 | 64 | masks = np.tile(masks_sub, [1, 1, totalframes//nframes + 1]) 65 | 66 | return masks[..., :totalframes] 67 | 68 | def video2codedvideo(video_ten, masks_ten, nframes): 69 | ''' 70 | Convert video to coded video, similar to Hitomi et al. 
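    Each coded frame is a per-pixel shuttered sum over a temporal chunk,
    y_k = sum_t masks[t] * video[t] with t ranging over chunk k; with masks
    from get_video_coding_frames, every pixel is "on" for exactly one frame
    per chunk.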
71 | 72 | Inputs: 73 | video_ten: (1, totalframes, H, W) video tensor 74 | masks_ten: (1, totalframes, H, W) mask tensor 75 | nframes: Number of frames to combine together 76 | 77 | Outputs: 78 | codedvideo_ten: (1, totalframems//nframes + 1, H, W) coded video 79 | ''' 80 | codedvideo_list = [] 81 | 82 | for idx in range(0, video_ten.shape[1], nframes): 83 | video_chunk = video_ten[:, idx:idx+nframes, :, :] 84 | masks_chunk = masks_ten[:, idx:idx+nframes, :, :] 85 | 86 | codedvideo = (video_chunk*masks_chunk).sum(1, keepdim=True) 87 | codedvideo_list.append(codedvideo) 88 | 89 | if idx < video_ten.shape[1]: 90 | video_chunk = video_ten[:, idx:, :, :] 91 | masks_chunk = masks_ten[:, idx:, :, :] 92 | 93 | codedvideo = (video_chunk*masks_chunk).sum(1, keepdim=True) 94 | codedvideo_list.append(codedvideo) 95 | 96 | codedvideo_ten = torch.cat(codedvideo_list, dim=1) 97 | 98 | return codedvideo_ten -------------------------------------------------------------------------------- /modules/losses.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | ''' 4 | This script implements all loss functions utilized in our experiments 5 | including L1 loss, L2 loss, and TV norm. 6 | ''' 7 | 8 | import torch 9 | 10 | 11 | class TVNorm(): 12 | def __init__(self, mode='l1'): 13 | self.mode = mode 14 | def __call__(self, img): 15 | grad_x = img[..., 1:, 1:] - img[..., 1:, :-1] 16 | grad_y = img[..., 1:, 1:] - img[..., :-1, 1:] 17 | 18 | if self.mode == 'isotropic': 19 | return torch.sqrt(grad_x**2 + grad_y**2).mean() 20 | elif self.mode == 'l1': 21 | return abs(grad_x).mean() + abs(grad_y).mean() 22 | else: 23 | return (grad_x.pow(2) + grad_y.pow(2)).mean() 24 | 25 | class L1Norm(): 26 | def __init__(self): 27 | pass 28 | def __call__(self, x): 29 | return abs(x).mean() 30 | 31 | class L2Norm(): 32 | def __init__(self): 33 | pass 34 | def __call__(self, x): 35 | return (x.pow(2)).mean() 36 | -------------------------------------------------------------------------------- /modules/models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | ''' 4 | This script implements models other than the one proposed by [Ulyanov et al. 2018], 5 | and [Heckel et al. 2018]. 
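Included here: the SimpleForwardMLP, SimpleForward1D, and SimpleForward2D
feed-forward baselines, plus UNetND, an N-dimensional U-Net with skip
connections.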
6 | ''' 7 | 8 | import torch 9 | 10 | 11 | class SimpleForwardMLP(torch.nn.Module): 12 | """ 13 | Simple feed forward MLP 14 | 15 | Inputs: 16 | n_inputs: Number of input dim 17 | ksizes: (nlayers,) Kernel sizes for each layer 18 | 19 | """ 20 | 21 | def __init__(self, n_inputs, ksizes): 22 | super().__init__() 23 | 24 | self.relu = torch.nn.ReLU(inplace=True) 25 | 26 | modules = [torch.nn.Linear(n_inputs, ksizes[0]), self.relu] 27 | 28 | for idx in range(1, len(ksizes)): 29 | modules.append(torch.nn.Linear(ksizes[idx - 1], ksizes[idx])) 30 | if idx < len(ksizes) - 1: 31 | modules.append(self.relu) 32 | self.module_list = torch.nn.Sequential(*modules) 33 | 34 | def forward(self, X, **kwargs): 35 | return self.module_list(X) 36 | 37 | 38 | class SimpleForward2D(torch.nn.Module): 39 | """ 40 | Simple feed forward 2D convolutional neural network 41 | 42 | Inputs: 43 | n_inputs: Number of input channels 44 | n_outputs: Number of output channels 45 | nconvs: (nlayers,) filters for each layer 46 | 47 | """ 48 | 49 | def __init__(self, n_inputs, n_outputs, nconvs): 50 | super().__init__() 51 | 52 | self.relu = torch.nn.ReLU(inplace=True) 53 | 54 | modules = [ 55 | torch.nn.Conv2d(n_inputs, nconvs[0], (15, 15), padding=(7, 7)), 56 | self.relu, 57 | torch.nn.BatchNorm2d(nconvs[0]), 58 | ] 59 | 60 | for idx in range(len(nconvs) - 1): 61 | modules.append( 62 | torch.nn.Conv2d(nconvs[idx], nconvs[idx + 1], (3, 3), padding=(1, 1)) 63 | ) 64 | modules.append(self.relu) 65 | modules.append(torch.nn.BatchNorm2d(nconvs[idx + 1])) 66 | 67 | modules.append(torch.nn.Conv2d(nconvs[-1], n_outputs, (3, 3), padding=(1, 1))) 68 | modules.append(torch.nn.BatchNorm2d(n_outputs)) 69 | 70 | self.module_list = torch.nn.Sequential(*modules) 71 | 72 | def forward(self, X, **kwargs): 73 | return self.module_list(X) 74 | 75 | class SimpleForward1D(torch.nn.Module): 76 | """ 77 | Simple feed forward 2D convolutional neural network 78 | 79 | Inputs: 80 | n_inputs: Number of input channels 81 | n_outputs: Number of output channels 82 | nlayers: Number of layers 83 | ksizes: (nlayers,) Kernel sizes for each layer 84 | nconvs: (nlayers,) filters for each layer 85 | 86 | """ 87 | 88 | def __init__(self, n_inputs, n_outputs, nlayers, nconvs): 89 | super().__init__() 90 | 91 | self.relu = torch.nn.LeakyReLU(0.2, inplace=True) 92 | 93 | modules = [ 94 | torch.nn.Conv1d(n_inputs, nconvs[0], 15, padding=7), 95 | self.relu, 96 | torch.nn.BatchNorm1d(nconvs[0]), 97 | ] 98 | 99 | for idx in range(nlayers - 1): 100 | modules.append(torch.nn.Conv1d(nconvs[idx], nconvs[idx + 1], 3, padding=1)) 101 | modules.append(self.relu) 102 | modules.append(torch.nn.BatchNorm1d(nconvs[idx + 1])) 103 | 104 | modules.append(torch.nn.Conv1d(nconvs[-1], n_outputs, 3, padding=1)) 105 | modules.append(torch.nn.BatchNorm1d(n_outputs)) 106 | 107 | self.module_list = torch.nn.Sequential(*modules) 108 | 109 | def forward(self, X, **kwargs): 110 | return self.module_list(X) 111 | 112 | class UNetND(torch.nn.Module): 113 | """ 114 | A generalrized ND Unet that can be instantiated to 1D, 2D or 3D UNet 115 | 116 | Inputs: 117 | n_inputs: Number of input channels 118 | n_outputs: Number of output channels 119 | ndim: 1, 2, or 3 for 1d, 2d, 3d 120 | ndown: Number of downsampling layers 121 | init_nconv: Number of convolutions for first layer. Subsequent 122 | downsampled layers have twice the number of layers as the first 123 | one. 
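    Example (illustrative sketch):
        net = UNetND(n_inputs=8, n_outputs=3, ndim=2, init_nconv=32)
        y = net(torch.zeros(1, 8, 64, 64))   # -> (1, 3, 64, 64)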
124 | 125 | """ 126 | 127 | def __init__(self, n_inputs, n_outputs, ndim=2, init_nconv=32): 128 | super().__init__() 129 | 130 | self.relu = torch.nn.LeakyReLU(0.2, inplace=True) 131 | self.sigmoid = torch.nn.Sigmoid() 132 | 133 | # Create operatorsr 134 | if ndim == 1: 135 | self.conv = torch.nn.Conv1d 136 | self.bn = torch.nn.BatchNorm1d 137 | self.maxpool = torch.nn.MaxPool1d(2) 138 | self.upmode = "linear" 139 | ksize = 5 140 | stride = 2 141 | skip_ksize = 1 142 | pad = 2 143 | elif ndim == 2: 144 | self.conv = torch.nn.Conv2d 145 | self.bn = torch.nn.BatchNorm2d 146 | self.maxpool = torch.nn.MaxPool2d(2) 147 | self.upmode = "bilinear" 148 | ksize = (5, 5) 149 | stride = (2, 2) 150 | skip_ksize = (1, 1) 151 | pad = (2, 2) 152 | else: 153 | self.conv = torch.nn.Conv3d 154 | self.bn = torch.nn.BatchNorm3d 155 | self.maxpool = torch.nn.MaxPool3d(2) 156 | self.upmode = "trilinear" 157 | ksize = (5, 5, 5) 158 | stride = (2, 2, 2) 159 | skip_ksize = (1, 1, 1) 160 | pad = (2, 2, 2) 161 | 162 | # Create a static list of convolutions -- simpler 163 | nconvs = [init_nconv] * 5 164 | # nconvs = [16, 32, 64, 128, 128] 165 | 166 | # Downsampling 167 | self.do_conv1 = self.conv(n_inputs, nconvs[0], ksize, padding=pad) 168 | self.do_bn1 = self.bn(nconvs[0]) 169 | self.do_conv2 = self.conv( 170 | nconvs[0], nconvs[1], ksize, stride=stride, padding=pad 171 | ) 172 | self.do_bn2 = self.bn(nconvs[1]) 173 | self.do_conv3 = self.conv( 174 | nconvs[1], nconvs[2], ksize, stride=stride, padding=pad 175 | ) 176 | self.do_bn3 = self.bn(nconvs[2]) 177 | self.do_conv4 = self.conv( 178 | nconvs[2], nconvs[3], ksize, stride=stride, padding=pad 179 | ) 180 | self.do_bn4 = self.bn(nconvs[3]) 181 | self.do_conv5 = self.conv( 182 | nconvs[3], nconvs[4], ksize, stride=stride, padding=pad 183 | ) 184 | self.do_bn5 = self.bn(nconvs[4]) 185 | 186 | # Skip connection convolutions 187 | self.skip1 = self.conv(nconvs[0], nconvs[1] // 2, skip_ksize) 188 | self.skip2 = self.conv(nconvs[1], nconvs[2] // 2, skip_ksize) 189 | self.skip3 = self.conv(nconvs[2], nconvs[3] // 2, skip_ksize) 190 | self.skip4 = self.conv(nconvs[3], nconvs[4] // 2, skip_ksize) 191 | 192 | # Upsampling 193 | self.up_conv1 = self.conv(nconvs[4], nconvs[4] // 2, ksize, padding=pad) 194 | self.up_bn1 = self.bn(nconvs[4] // 2) 195 | self.up_conv2 = self.conv(nconvs[4], nconvs[3] // 2, ksize, padding=pad) 196 | self.up_bn2 = self.bn(nconvs[3] // 2) 197 | self.up_conv3 = self.conv(nconvs[3], nconvs[2] // 2, ksize, padding=pad) 198 | self.up_bn3 = self.bn(nconvs[2] // 2) 199 | self.up_conv4 = self.conv(nconvs[2], nconvs[1] // 2, ksize, padding=pad) 200 | self.up_bn4 = self.bn(nconvs[1] // 2) 201 | self.up_conv5 = self.conv(nconvs[1], n_outputs, ksize, padding=pad) 202 | 203 | def forward(self, X): 204 | # Downsample 205 | X1 = self.relu(self.do_bn1(self.do_conv1(X))) 206 | X2 = self.relu(self.do_bn2(self.do_conv2(X1))) 207 | X3 = self.relu(self.do_bn3(self.do_conv3(X2))) 208 | X4 = self.relu(self.do_bn4(self.do_conv4(X3))) 209 | X5 = self.relu(self.do_bn5(self.do_conv5(X4))) 210 | 211 | # Upsample 212 | X6 = self.relu(self.up_bn1(self.up_conv1(X5))) 213 | X6 = torch.nn.Upsample(size=X4.shape[2:], mode=self.upmode)(X6) 214 | 215 | X7 = torch.cat((X6, self.relu(self.skip4(X4))), dim=1) 216 | X7 = self.relu(self.up_bn2(self.up_conv2(X7))) 217 | X7 = torch.nn.Upsample(size=X3.shape[2:], mode=self.upmode)(X7) 218 | 219 | X8 = torch.cat((X7, self.relu(self.skip3(X3))), dim=1) 220 | X8 = self.relu(self.up_bn3(self.up_conv3(X8))) 221 | X8 = 
torch.nn.Upsample(size=X2.shape[2:], mode=self.upmode)(X8) 222 | 223 | X9 = torch.cat((X8, self.relu(self.skip2(X2))), dim=1) 224 | X9 = self.relu(self.up_bn4(self.up_conv4(X9))) 225 | X9 = torch.nn.Upsample(size=X1.shape[2:], mode=self.upmode)(X9) 226 | 227 | X10 = torch.cat((X9, self.relu(self.skip1(X1))), dim=1) 228 | X10 = self.up_conv5(X10) 229 | 230 | return X10 231 | -------------------------------------------------------------------------------- /modules/spectral.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | ''' 4 | This module implements functions required for hyperspectral image experiments 5 | in section 4. 6 | ''' 7 | 8 | import numpy as np 9 | import torch 10 | 11 | # We will use tensorly to compute outer product 12 | import tensorly 13 | tensorly.set_backend('pytorch') 14 | from tensorly import decomposition 15 | 16 | from scipy.sparse import linalg 17 | 18 | def lr_decompose(cube, rank): 19 | ''' 20 | Perform a truncated SVD 21 | 22 | Inputs: 23 | cube: (H, W, nwvl) hyperspectral cube 24 | rank: Rank to decompose the cube 25 | 26 | Outputs: 27 | cube_lr: Low rank decomposition 28 | ''' 29 | H, W, nwvl = cube.shape 30 | hsmat = cube.reshape(H*W, nwvl) 31 | 32 | u, s, vt = linalg.svds(hsmat, k=rank) 33 | 34 | hsmat_lr = u.dot(np.diag(s)).dot(vt) 35 | cube_lr = hsmat_lr.reshape(H, W, nwvl) 36 | 37 | return cube_lr 38 | 39 | def tucker_decompose(cube, rank, max_iters=1000): 40 | ''' 41 | Perform a truncated tucker decomposition 42 | 43 | Inputs: 44 | cube: (H, W, T) numpy array 45 | rank: Rank for tucker decomposition 46 | max_iters: Maximum iterations for tucker decomposition 47 | 48 | Outputs: 49 | cube_lr: Low tucker rank decompositoin 50 | ''' 51 | cube_ten = torch.tensor(cube).cuda() 52 | tucker_core, tucker_factors = decomposition.tucker(cube_ten, 53 | (rank, rank, rank), 54 | n_iter_max=max_iters) 55 | 56 | cube_approx_tucker = tensorly.tucker_to_tensor((tucker_core, 57 | tucker_factors)) 58 | cube_approx_tucker = cube_approx_tucker.cpu().numpy() 59 | 60 | return cube_approx_tucker 61 | -------------------------------------------------------------------------------- /modules/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | ''' 4 | Miscellaneous utilities that are required in all experiments. 5 | ''' 6 | 7 | # System imports 8 | import pickle 9 | 10 | import torch 11 | 12 | # Scientific computing 13 | import numpy as np 14 | import scipy.linalg as lin 15 | from scipy import io 16 | from scipy import signal 17 | from scipy.sparse import linalg 18 | 19 | # Plotting 20 | import cv2 21 | import matplotlib.pyplot as plt 22 | from matplotlib.lines import Line2D 23 | 24 | try: 25 | from cosine_annealing_with_warmup import CosineAnnealingWarmupRestarts 26 | except ModuleNotFoundError: 27 | from .cosine_annealing_with_warmup import CosineAnnealingWarmupRestarts 28 | 29 | def nextpow2(x): 30 | ''' 31 | Return smallest number larger than x and a power of 2. 32 | ''' 33 | logx = np.ceil(np.log2(x)) 34 | return pow(2, logx) 35 | 36 | def normalize(x, fullnormalize=False): 37 | ''' 38 | Normalize input to lie between 0, 1. 39 | 40 | Inputs: 41 | x: Input signal 42 | fullnormalize: If True, normalize such that minimum is 0 and 43 | maximum is 1. Else, normalize such that maximum is 1 alone. 44 | 45 | Outputs: 46 | xnormalized: Normalized x. 
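    Example: normalize(np.array([2., 4.])) gives [0.5, 1.0]; with
    fullnormalize=True it gives [0.0, 1.0].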
47 | ''' 48 | 49 | if x.sum() == 0: 50 | return x 51 | 52 | xmax = x.max() 53 | 54 | if fullnormalize: 55 | xmin = x.min() 56 | else: 57 | xmin = 0 58 | 59 | xnormalized = (x - xmin)/(xmax - xmin) 60 | 61 | return xnormalized 62 | 63 | def rsnr(x, xhat): 64 | ''' 65 | Compute reconstruction SNR for a given signal and its reconstruction. 66 | 67 | Inputs: 68 | x: Ground truth signal (ndarray) 69 | xhat: Approximation of x 70 | 71 | Outputs: 72 | rsnr_val: RSNR = 20log10(||x||/||x-xhat||) 73 | ''' 74 | xn = lin.norm(x.reshape(-1)) 75 | en = lin.norm((x-xhat).reshape(-1)) + 1e-12 76 | rsnr_val = 20*np.log10(xn/en) 77 | 78 | return rsnr_val 79 | 80 | def psnr(x, xhat): 81 | ''' Compute Peak Signal to Noise Ratio in dB 82 | 83 | Inputs: 84 | x: Ground truth signal 85 | xhat: Reconstructed signal 86 | 87 | Outputs: 88 | snrval: PSNR in dB 89 | ''' 90 | err = x - xhat 91 | denom = np.mean(pow(err, 2)) + 1e-12 92 | 93 | snrval = 10*np.log10(np.max(x)/denom) 94 | 95 | return snrval 96 | 97 | def measure(x, noise_snr=40, tau=100): 98 | ''' Realistic sensor measurement with readout and photon noise 99 | 100 | Inputs: 101 | noise_snr: Readout noise in electron count 102 | tau: Integration time. Poisson noise is created for x*tau. 103 | (Default is 100) 104 | 105 | Outputs: 106 | x_meas: x with added noise 107 | ''' 108 | x_meas = np.copy(x) 109 | 110 | noise = np.random.randn(x_meas.size).reshape(x_meas.shape)*noise_snr 111 | 112 | # First add photon noise, provided it is not infinity 113 | if tau != float('Inf'): 114 | x_meas = x_meas*tau 115 | 116 | x_meas[x > 0] = np.random.poisson(x_meas[x > 0]) 117 | x_meas[x <= 0] = -np.random.poisson(-x_meas[x <= 0]) 118 | 119 | x_meas = (x_meas + noise)/tau 120 | 121 | else: 122 | x_meas = x_meas + noise 123 | 124 | return x_meas 125 | 126 | def rician(sig, noise_snr): 127 | ''' 128 | Add Rician noise 129 | 130 | Inputs: 131 | sig: N dimensional signal 132 | noise_snr: std. 
dev of input noise 133 | 134 | Outputs: 135 | sig_rician: Rician corrupted signal 136 | ''' 137 | n1 = np.random.randn(*sig.shape)*noise_snr 138 | n2 = np.random.randn(*sig.shape)*noise_snr 139 | 140 | return np.sqrt((sig + n1)**2 + n2**2) 141 | 142 | def resize(cube, scale): 143 | ''' 144 | Resize a multi-channel image 145 | 146 | Inputs: 147 | cube: (H, W, nchan) image stack 148 | scale: Scaling 149 | ''' 150 | H, W, nchan = cube.shape 151 | 152 | im0_lr = cv2.resize(cube[..., 0], None, fx=scale, fy=scale) 153 | Hl, Wl = im0_lr.shape 154 | 155 | cube_lr = np.zeros((Hl, Wl, nchan), dtype=cube.dtype) 156 | 157 | for idx in range(nchan): 158 | cube_lr[..., idx] = cv2.resize(cube[..., idx], None, 159 | fx=scale, fy=scale, 160 | interpolation=cv2.INTER_AREA) 161 | return cube_lr 162 | 163 | def moduloclip(cube, mulsize): 164 | ''' 165 | Clip a cube to have multiples of mulsize 166 | 167 | Inputs: 168 | cube: (H, W, T) sized cube 169 | mulsize: (h, w) tuple having smallest stride size 170 | 171 | Outputs: 172 | cube_clipped: Clipped cube with size equal to multiples of h, w 173 | ''' 174 | H, W, T = cube.shape 175 | 176 | H1 = mulsize[0]*(H // mulsize[0]) 177 | W1 = mulsize[1]*(W // mulsize[1]) 178 | 179 | cube_clipped = cube[:H1, :W1, :] 180 | 181 | return cube_clipped 182 | 183 | def implay(cube, delay=20): 184 | ''' 185 | Play hyperspectral image as a video 186 | ''' 187 | if cube.dtype != np.uint8: 188 | cube = (255*cube/cube.max()).astype(np.uint8) 189 | 190 | T = cube.shape[-1] 191 | 192 | for idx in range(T): 193 | cv2.imshow('Video', cube[..., idx]) 194 | cv2.waitKey(delay) 195 | 196 | def build_montage(images): 197 | ''' 198 | Build a montage out of images 199 | ''' 200 | nimg, H, W = images.shape 201 | 202 | nrows = int(np.ceil(np.sqrt(nimg))) 203 | ncols = int(np.ceil(nimg/nrows)) 204 | 205 | montage_im = np.zeros((H*nrows, W*ncols), dtype=np.float32) 206 | 207 | cnt = 0 208 | for r in range(nrows): 209 | for c in range(ncols): 210 | h1 = r*H 211 | h2 = (r+1)*H 212 | w1 = c*W 213 | w2 = (c+1)*W 214 | 215 | if cnt == nimg: 216 | break 217 | 218 | montage_im[h1:h2, w1:w2] = images[cnt, ...] 219 | cnt += 1 220 | 221 | return montage_im 222 | 223 | def get_matrix(nrows, ncols, rank, noise_type, signal_type, 224 | noise_snr=5, tau=1000): 225 | ''' 226 | Get a matrix for simulations 227 | 228 | Inputs: 229 | nrows, ncols: Size of the matrix 230 | rank: Rank of the matrix 231 | noise_type: Type of the noise to add. Currently None, gaussian, 232 | and poisson 233 | signal_type: Type of the signal itself. 
Currently gaussian and 234 | piecewise constant 235 | noise_snr: Amount of noise to add in terms of 236 | ''' 237 | if signal_type == 'gaussian': 238 | U = np.random.randn(nrows, rank) 239 | V = np.random.randn(rank, ncols) 240 | elif signal_type == 'piecewise': 241 | nlocs = 10 242 | 243 | U = np.zeros((nrows, rank)) 244 | V = np.zeros((rank, ncols)) 245 | 246 | for idx in range(rank): 247 | u_locs = np.random.randint(0, nrows, nlocs) 248 | v_locs = np.random.randint(0, ncols, nlocs) 249 | 250 | U[u_locs, idx] = np.random.randn(nlocs) 251 | V[idx, v_locs] = np.random.randn(nlocs) 252 | 253 | U = np.cumsum(U, 0) 254 | V = np.cumsum(V, 1) 255 | else: 256 | raise AttributeError('Signal type not implemented') 257 | 258 | mat = normalize(U.dot(V), True) 259 | 260 | if noise_type == 'gaussian': 261 | mat_noisy = measure(mat, noise_snr, float('inf')) 262 | elif noise_type == 'poisson': 263 | mat_noisy = measure(mat, noise_snr, tau) 264 | elif noise_type == 'rician': 265 | noise1 = np.random.randn(nrows, ncols)*noise_snr 266 | noise2 = np.random.randn(nrows, ncols)*noise_snr 267 | 268 | mat_noisy = np.sqrt((mat + noise1)**2 + noise2**2) 269 | else: 270 | raise AttributeError('Noise type not implemented') 271 | 272 | return mat_noisy, mat 273 | 274 | def get_pca(nrows, ndata, rank, noise_type, signal_type, 275 | noise_snr=5, tau=1000): 276 | ''' 277 | Get PCA data 278 | 279 | Inputs: 280 | nrows: Number of rows in data 281 | ndata: Number of data points 282 | rank: Intrinsic dimension 283 | noise_type: Type of the noise to add. Currently None, gaussian, 284 | and poisson 285 | signal_type: Type of the signal itself. Currently gaussian and 286 | piecewise constant 287 | noise_snr: Amount of noise to add in terms of 288 | ''' 289 | # Generate normalized coefficients 290 | coefs = np.random.randn(rank, ndata) 291 | coefs_norm = np.sqrt((coefs*coefs).sum(0)).reshape(1, ndata) 292 | coefs = coefs/coefs_norm 293 | 294 | if signal_type == 'gaussian': 295 | basis = np.random.randn(nrows, rank) 296 | elif signal_type == 'piecewise': 297 | nlocs = 10 298 | 299 | basis = np.zeros((nrows, rank)) 300 | 301 | for idx in range(rank): 302 | u_locs = np.random.randint(0, nrows, nlocs) 303 | basis[u_locs, idx] = np.random.randn(nlocs) 304 | 305 | basis = np.cumsum(basis, 0) 306 | else: 307 | raise AttributeError('Signal type not implemented') 308 | 309 | # Compute orthogonal basis with QR decomposition 310 | basis, _ = np.linalg.qr(basis) 311 | mat = basis.dot(coefs) 312 | 313 | if noise_type == 'gaussian': 314 | mat_noisy = measure(mat, noise_snr, float('inf')) 315 | elif noise_type == 'poisson': 316 | mat_noisy = measure(mat, noise_snr, tau) 317 | elif noise_type == 'rician': 318 | noise1 = np.random.randn(nrows, ndata)*noise_snr 319 | noise2 = np.random.randn(nrows, ndata)*noise_snr 320 | 321 | mat_noisy = np.sqrt((mat + noise1)**2 + noise2**2) 322 | else: 323 | raise AttributeError('Noise type not implemented') 324 | 325 | return mat_noisy, mat, basis 326 | 327 | def get_inp(tensize, const=10.0): 328 | ''' 329 | Wrapper to get a variable on graph 330 | ''' 331 | inp = torch.rand(tensize).cuda()/const 332 | inp = torch.autograd.Variable(inp, requires_grad=True).cuda() 333 | inp = torch.nn.Parameter(inp) 334 | 335 | return inp 336 | 337 | def get_2d_posencode_inp(H, W, n_inputs): 338 | ''' 339 | Get positionally encoded inputs for inpainting tasks 340 | 341 | https://bmild.github.io/fourfeat/ 342 | ''' 343 | X, Y = np.mgrid[:H, :W] 344 | coords = np.hstack(((10*X/H).reshape(-1, 1), (10*Y/W).reshape(-1, 1))) 345 | 
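    # Random Fourier features (Tancik et al., fourfeat): coordinates v are
    # mapped to [sin(2*pi*B v), cos(2*pi*B v)] with B drawn i.i.d. from
    # N(0, 1), giving an (H, W, 2*n_inputs) encoding that is reshaped into
    # a network input below.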
346 | freqs = np.random.randn(2, n_inputs) 347 | 348 | angles = coords.dot(freqs) 349 | 350 | sin_vals = np.sin(2*np.pi*angles) 351 | cos_vals = np.cos(2*np.pi*angles) 352 | 353 | posencode_vals = np.hstack((sin_vals, cos_vals)).astype(np.float32) 354 | 355 | inp = posencode_vals.reshape(H, W, 2*n_inputs) 356 | inp = torch.autograd.Variable(torch.tensor(inp), requires_grad=True).cuda() 357 | inp = torch.nn.Parameter(inp.permute(2, 0, 1)[None, ...]) 358 | 359 | return inp 360 | 361 | def get_1d_posencode_inp(H, n_inputs): 362 | ''' 363 | Get positionally encoded inputs for inpainting tasks 364 | 365 | https://bmild.github.io/fourfeat/ 366 | ''' 367 | X = 10*np.arange(H).reshape(-1, 1)/H 368 | 369 | freqs = np.random.randn(1, n_inputs) 370 | 371 | angles = X.dot(freqs) 372 | 373 | sin_vals = np.sin(2*np.pi*angles) 374 | cos_vals = np.cos(2*np.pi*angles) 375 | 376 | posencode_vals = np.hstack((sin_vals, cos_vals)).astype(np.float32) 377 | 378 | inp = posencode_vals.reshape(H, 2*n_inputs) 379 | inp = torch.autograd.Variable(torch.tensor(inp), requires_grad=True).cuda() 380 | inp = torch.nn.Parameter(inp.permute(1, 0)[None, ...]) 381 | 382 | return inp 383 | 384 | def lr_decompose(mat, rank=6): 385 | ''' 386 | Low rank decomposition 387 | ''' 388 | u, s, vt = linalg.svds(mat, k=rank) 389 | mat_lr = u.dot(np.diag(s)).dot(vt) 390 | 391 | return mat_lr 392 | 393 | def get_scheduler(scheduler_type, optimizer, args): 394 | ''' 395 | Get a scheduler 396 | 397 | Inputs: 398 | scheduler_type: 'none', 'step', 'exponential', 'cosine' 399 | optimizer: One of torch.optim optimizers 400 | args: Namspace containing arguments relevant to each optimizer 401 | 402 | Outputs: 403 | scheduler: A torch learning rate scheduler 404 | ''' 405 | if scheduler_type == 'none': 406 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 407 | step_size=args.epochs) 408 | elif scheduler_type == 'step': 409 | # Compute gamma 410 | gamma = pow(10, -1/(args.epochs/args.step_size)) 411 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 412 | step_size=args.step_size, 413 | gamma=gamma) 414 | elif scheduler_type == 'exponential': 415 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 416 | gamma=args.gamma) 417 | elif scheduler_type == 'cosine': 418 | scheduler = CosineAnnealingWarmupRestarts(optimizer, 419 | first_cycle_steps=200, 420 | cycle_mult=1.0, 421 | max_lr=args.max_lr, 422 | min_lr=args.min_lr, 423 | warmup_steps=50, 424 | gamma=args.gamma) 425 | 426 | return scheduler -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | tqdm 3 | opencv_python 4 | scikit_image 5 | scipy 6 | torchvision==0.9.1+cu102 7 | torchaudio==0.8.1 8 | tensorly==0.6.0 9 | torch==1.8.1+cu102 10 | matplotlib 11 | bm3d==3.0.7 12 | kornia==0.4.1 13 | scikit_learn==0.24.2 14 | -------------------------------------------------------------------------------- /run_figure11.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | ''' 4 | This script replicates the DeepTensor results in Figure 7. The accuracy is 5 | different from the paper as the video has been downsampled and clipped to 6 | fit supplementary material requirements. 
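Forward model: video compressive sensing. Groups of `nframes` consecutive
frames are modulated by per-pixel binary masks and summed into one coded
measurement frame (see lin_inverse.video2codedvideo); the low-rank video
U @ V is then recovered from those measurements.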
7 | ''' 8 | 9 | import os 10 | import sys 11 | import tqdm 12 | import copy 13 | import time 14 | 15 | import numpy as np 16 | from scipy import io 17 | from skimage.metrics import structural_similarity as ssim_func 18 | 19 | import torch 20 | 21 | import matplotlib.pyplot as plt 22 | plt.gray() 23 | import cv2 24 | 25 | sys.path.append('../modules') 26 | 27 | import models 28 | import utils 29 | import losses 30 | import deep_prior 31 | import deep_decoder 32 | import lin_inverse 33 | 34 | if __name__ == '__main__': 35 | expname = 'cat_video' 36 | nframes = 8 37 | rank = 32 38 | 39 | # Network parameters 40 | nettype = 'dip' 41 | n_inputs = rank 42 | init_nconv = 64 43 | num_channels_up = 3 44 | 45 | # Noise parameters 46 | scaling = 1 47 | tau = 100 48 | noise_snr = 2 49 | 50 | # Learning constants 51 | learning_rate = 1e-3 52 | epochs = 3000 53 | reg_noise_std = 1.0/30.0 54 | exp_weight = 0.99 55 | 56 | # Load data 57 | data = io.loadmat('data/%s.mat'%expname) 58 | cube = data['hypercube'].astype(np.float32) 59 | cube = utils.resize(cube/cube.max(), scaling) 60 | 61 | H, W, totalframes = cube.shape 62 | 63 | if rank == 'max': 64 | rank = min(H, W, totalframes) 65 | n_inputs = rank 66 | 67 | cube_noisy = utils.measure(cube, noise_snr, tau) 68 | 69 | # Get masks 70 | masks = lin_inverse.get_video_coding_frames(cube.shape, nframes) 71 | 72 | # Get video measurements 73 | measurements = np.zeros((H, W, totalframes//nframes+1), dtype=np.float32) 74 | 75 | for idx in range(0, totalframes, nframes): 76 | cube_chunk = cube_noisy[..., idx:idx+nframes] 77 | masks_chunk = masks[..., idx:idx+nframes] 78 | measurements[..., idx//nframes] = (cube_chunk*masks_chunk).sum(2) 79 | 80 | if idx < totalframes: 81 | cube_chunk = cube_noisy[..., idx:] 82 | masks_chunk = masks[..., idx:] 83 | measurements[..., idx//nframes + 1] = (cube_chunk*masks_chunk).sum(2) 84 | 85 | # Send data to device 86 | measurements_ten = torch.tensor(measurements)[None, ...] 
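    # Reorder (1, H, W, nmeas) -> (1, nmeas, H, W) so the measurement frames
    # sit on the channel axis expected by the networks and video2codedvideo.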
87 | measurements_ten = measurements_ten.permute(0, 3, 1, 2).cuda() 88 | cube_ten = torch.tensor(cube)[None, ...].permute(0, 3, 1, 2).cuda() 89 | masks_ten = torch.tensor(masks)[None, ...].permute(0, 3, 1, 2).cuda() 90 | 91 | if nettype == 'unet': 92 | im_net = models.UNetND(n_inputs, rank, 2, init_nconv).cuda() 93 | spec_net = models.UNetND(n_inputs, rank, 1, init_nconv).cuda() 94 | 95 | im_inp = utils.get_inp([1, n_inputs, H, W]) 96 | spec_inp = utils.get_inp([1, n_inputs, totalframes]) 97 | elif nettype == 'dip': 98 | im_net = deep_prior.get_net(n_inputs, 'skip', 'reflection', 99 | upsample_mode='bilinear', 100 | skip_n33d=init_nconv, 101 | skip_n33u=init_nconv, 102 | num_scales=5, 103 | n_channels=rank).cuda() 104 | spec_net = deep_prior.get_net(n_inputs, 'skip1d', 'reflection', 105 | upsample_mode='linear', 106 | skip_n33d=init_nconv, 107 | skip_n33u=init_nconv, 108 | num_scales=5, 109 | n_channels=rank).cuda() 110 | 111 | im_inp = utils.get_inp([1, n_inputs, H, W]) 112 | spec_inp = utils.get_inp([1, n_inputs, totalframes]) 113 | elif nettype == 'dd': 114 | nchans = [init_nconv]*num_channels_up 115 | im_net = deep_decoder.decodernw(rank, nchans).cuda() 116 | spec_net = deep_decoder.decodernw1d(rank, nchans).cuda() 117 | 118 | H1 = H // pow(2, num_channels_up) 119 | W1 = W // pow(2, num_channels_up) 120 | nwvl1 = totalframes // pow(2, num_channels_up) 121 | im_inp = utils.get_inp([1, init_nconv, H1, W1]) 122 | spec_inp = utils.get_inp([1, init_nconv, nwvl1]) 123 | 124 | # Switch to trairning mode 125 | im_net.train() 126 | spec_net.train() 127 | 128 | net_params = list(im_net.parameters()) + list(spec_net.parameters()) 129 | inp_params = [im_inp] + [spec_inp] 130 | 131 | im_inp_per = im_inp.detach().clone() 132 | spec_inp_per = spec_inp.detach().clone() 133 | cube_estim_avg = None 134 | 135 | params = net_params + inp_params 136 | 137 | criterion_l1 = losses.L2Norm() 138 | 139 | loss_array = np.zeros(epochs) 140 | mse_array = np.zeros(epochs) 141 | time_array = np.zeros(epochs) 142 | 143 | optimizer = torch.optim.Adam(lr=learning_rate, params=params) 144 | 145 | best_loss = float('inf') 146 | best_epoch = 0 147 | tic = time.time() 148 | tbar = tqdm.tqdm(range(epochs)) 149 | for idx in tbar: 150 | # Perturb inputs 151 | im_inp_perturbed = im_inp + im_inp_per.normal_()*reg_noise_std 152 | spec_inp_perturbed = spec_inp + spec_inp_per.normal_()*reg_noise_std 153 | 154 | U_img = im_net(im_inp_perturbed) 155 | V = spec_net(spec_inp_perturbed) 156 | 157 | U = U_img.reshape(-1, rank, H*W).permute(0, 2, 1) 158 | mat_estim = torch.bmm(U, V) 159 | 160 | cube_estim = mat_estim.reshape(1, H, W, totalframes).permute(0, 3, 1, 2) 161 | 162 | meas_estim = lin_inverse.video2codedvideo(cube_estim, masks_ten, nframes) 163 | 164 | loss = criterion_l1(measurements_ten - meas_estim) 165 | 166 | optimizer.zero_grad() 167 | loss.backward() 168 | optimizer.step() 169 | 170 | loss_item = loss.item() 171 | mse_array[idx] = ((cube_ten - cube_estim)**2).mean().item() 172 | time_array[idx] = time.time() - tic 173 | 174 | tbar.set_description('%.4e'%mse_array[idx]) 175 | tbar.refresh() 176 | 177 | if loss_item < best_loss: 178 | best_loss = loss_item 179 | best_epoch = idx 180 | best_cube_estim = copy.deepcopy(cube_estim.detach()) 181 | 182 | # Averaging as per original code 183 | if cube_estim_avg is None: 184 | cube_estim_avg = cube_estim.detach() 185 | else: 186 | cube_estim_avg = exp_weight*cube_estim_avg +\ 187 | (1 - exp_weight)*cube_estim.detach() 188 | 189 | with torch.no_grad(): 190 | img_idx = 
idx%totalframes 191 | ref = cube[..., img_idx] 192 | diff = abs(cube_ten - cube_estim).mean(2).squeeze().detach().cpu() 193 | img = cube_estim[0, img_idx, ...].detach().cpu().numpy() 194 | 195 | if sys.platform == 'win32': 196 | cv2.imshow('Diff x10', diff.numpy()*10) 197 | cv2.imshow('Avg', np.hstack((ref, img))) 198 | cv2.waitKey(1) 199 | 200 | cube_estim = best_cube_estim.cpu().squeeze().permute(1, 2, 0).numpy() 201 | psnrval = utils.psnr(cube, cube_estim) 202 | ssimval = ssim_func(cube, cube_estim, multichannel=True) 203 | 204 | print('PSNR: %.2f'%psnrval) 205 | print('SSIM: %.2f'%ssimval) 206 | -------------------------------------------------------------------------------- /run_figure11_TV.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | 4 | 5 | ''' 6 | This script replicates the TV results in Figure 7. The accuracy is 7 | different from the paper as the video has been downsampled and clipped to 8 | fit supplementary material requirements. 9 | ''' 10 | 11 | 12 | import sys 13 | import tqdm 14 | 15 | import numpy as np 16 | from scipy import io 17 | from skimage.metrics import structural_similarity as ssim_func 18 | 19 | import torch 20 | 21 | import matplotlib.pyplot as plt 22 | plt.gray() 23 | import cv2 24 | 25 | sys.path.append('modules') 26 | 27 | import utils 28 | import losses 29 | import lin_inverse 30 | 31 | if __name__ == '__main__': 32 | expname = 'cat_video' 33 | nframes = 8 34 | rank = 200 35 | 36 | # Network parameters 37 | uv_decompose = False # Set this to true to optimize U, V instead 38 | 39 | # Noise parameters 40 | scaling = 1 41 | tau = 100 42 | noise_snr = 2 43 | 44 | # Learning constants 45 | learning_rate = 1e-1 46 | epochs = 500 47 | lambda_tv = 1e-2 48 | 49 | # Load data 50 | data = io.loadmat('data/%s.mat'%expname) 51 | cube = data['hypercube'].astype(np.float32) 52 | cube = utils.resize(cube/cube.max(), scaling) 53 | 54 | H, W, totalframes = cube.shape 55 | 56 | if rank == 'max': 57 | rank = min(H, W, totalframes) 58 | n_inputs = rank 59 | 60 | cube_noisy = utils.measure(cube, noise_snr, tau) 61 | 62 | # Get masks 63 | masks = lin_inverse.get_video_coding_frames(cube.shape, nframes) 64 | 65 | # Get video measurements 66 | measurements = np.zeros((H, W, totalframes//nframes+1), dtype=np.float32) 67 | 68 | for idx in range(0, totalframes, nframes): 69 | cube_chunk = cube_noisy[..., idx:idx+nframes] 70 | masks_chunk = masks[..., idx:idx+nframes] 71 | measurements[..., idx//nframes] = (cube_chunk*masks_chunk).sum(2) 72 | 73 | if idx < totalframes: 74 | cube_chunk = cube_noisy[..., idx:] 75 | masks_chunk = masks[..., idx:] 76 | measurements[..., idx//nframes + 1] = (cube_chunk*masks_chunk).sum(2) 77 | 78 | # Send data to device 79 | measurements_ten = torch.tensor(measurements)[None, ...] 
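    # TV baseline: the volume (or its rank-r factors U, V when uv_decompose
    # is True) is optimized directly with an L2 data term plus
    # lambda_tv * TV regularization -- no network prior involved.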
80 | measurements_ten = measurements_ten.permute(0, 3, 1, 2).cuda() 81 | cube_ten = torch.tensor(cube)[None, ...].permute(0, 3, 1, 2).cuda() 82 | masks_ten = torch.tensor(masks)[None, ...].permute(0, 3, 1, 2).cuda() 83 | 84 | # Generate variables 85 | if uv_decompose: 86 | U = utils.get_inp([1, H*W, rank]) 87 | V = utils.get_inp([1, rank, totalframes]) 88 | params = [U] + [V] 89 | else: 90 | mat_estim = utils.get_inp([1, H*W, totalframes]) 91 | params = [mat_estim] 92 | 93 | criterion_l1 = losses.L2Norm() 94 | criterion_tv = losses.TVNorm() 95 | 96 | loss_array = np.zeros(epochs) 97 | mse_array = np.zeros(epochs) 98 | 99 | optimizer = torch.optim.Adam(lr=learning_rate, params=params) 100 | 101 | for idx in tqdm.tqdm(range(epochs)): 102 | if uv_decompose: 103 | U_img = U.reshape(1, H, W, rank).permute(3, 0, 1, 2) 104 | 105 | mat_estim = torch.bmm(U, V) 106 | else: 107 | U_img = mat_estim 108 | cube_estim = mat_estim.reshape(1, H, W, totalframes).permute(0, 3, 1, 2) 109 | 110 | meas_estim = lin_inverse.video2codedvideo(cube_estim, masks_ten, nframes) 111 | 112 | loss1 = criterion_l1(measurements_ten - meas_estim) 113 | loss2 = criterion_tv(U_img) 114 | 115 | loss = loss1 + lambda_tv*loss2 116 | 117 | optimizer.zero_grad() 118 | loss.backward() 119 | optimizer.step() 120 | 121 | with torch.no_grad(): 122 | img_idx = idx%totalframes 123 | ref = cube[..., img_idx] 124 | diff = abs(cube_ten - cube_estim).mean(2).squeeze().detach().cpu() 125 | img = cube_estim[0, img_idx, ...].detach().cpu().numpy() 126 | cv2.imshow('Diff x10', diff.numpy()*10) 127 | cv2.imshow('Avg', np.hstack((ref, img))) 128 | cv2.waitKey(1) 129 | 130 | cube_estim = cube_estim.detach().squeeze().cpu().permute(1, 2, 0).numpy() 131 | 132 | psnrval = utils.psnr(cube, cube_estim) 133 | ssimval = ssim_func(cube, cube_estim, multichannel=True) 134 | 135 | print('PSNR: %.2f'%psnrval) 136 | print('SSIM: %.2f'%ssimval) -------------------------------------------------------------------------------- /run_figure12.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | ''' 4 | This script replicates the DeepTensor results in Figure 8. The accuracy is 5 | different from the paper as PET scan was downsampled. 
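Here the volume is parameterized as a rank-`rank` CP/PARAFAC tensor,
X = sum_k core[k] * (u_k o v_k o w_k), with each factor matrix generated
by a 1D network, and is fit to a noisy sparse-angle sinogram.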
6 | ''' 7 | 8 | import sys 9 | import tqdm 10 | import copy 11 | import time 12 | import argparse 13 | 14 | import numpy as np 15 | from scipy import io 16 | from skimage.metrics import structural_similarity as ssim_func 17 | 18 | import torch 19 | import tensorly 20 | tensorly.set_backend('pytorch') 21 | 22 | import matplotlib.pyplot as plt 23 | plt.gray() 24 | import cv2 25 | 26 | sys.path.append('../modules') 27 | 28 | import models 29 | import utils 30 | import losses 31 | import deep_prior 32 | import deep_decoder 33 | import lin_inverse 34 | 35 | if __name__ == '__main__': 36 | expname = 'pet' 37 | nangles = 40 38 | rank = 1000 39 | nettype = 'dip' 40 | 41 | # Network parameters 42 | n_inputs = rank 43 | init_nconv = 128 44 | num_channels_up = 5 45 | 46 | # Noise parameters 47 | scaling = 1 48 | tau = 1000 49 | noise_snr = 2 50 | 51 | # Learning constants 52 | sched_args = argparse.Namespace() 53 | scheduler_type = 'none' 54 | learning_rate = 1e-3 55 | lambda_tv = 1e1 56 | epochs = 5000 57 | reg_noise_std = 1.0/30.0 58 | exp_weight = 0.99 59 | sched_args.step_size = 2000 60 | sched_args.gamma = pow(10, -1/epochs) 61 | sched_args.max_lr = learning_rate 62 | sched_args.min_lr = 1e-4 63 | sched_args.epochs = epochs 64 | 65 | # Load data 66 | data = io.loadmat('data/%s.mat'%expname) 67 | cube = data['hypercube'].astype(np.float32) 68 | cube = utils.resize(cube/cube.max(), scaling) 69 | 70 | H, W, T = cube.shape 71 | 72 | cube_noisy = utils.measure(cube, noise_snr, tau).astype(np.float32) 73 | 74 | # Send data to device 75 | cube_gt_ten = torch.tensor(cube).cuda() 76 | cube_ten = torch.tensor(cube).cuda().permute(2, 0, 1)[None, ...] 77 | angles = torch.tensor(np.linspace(0, 180, nangles).astype(np.float32)).cuda() 78 | 79 | # Generate sinogram 80 | measurements = lin_inverse.radon(cube_ten, angles).detach().cpu().numpy() 81 | maxval = measurements.max() 82 | measurements = utils.measure(measurements, noise_snr, tau) 83 | measurements = torch.tensor(measurements).cuda() 84 | 85 | # Generate inputs -- positionally encoded inputs results in smoother reconstruction 86 | #u_inp = utils.get_inp([1, rank, H]) 87 | #v_inp = utils.get_inp([1, rank, W]) 88 | #w_inp = utils.get_inp([1, rank, T]) 89 | 90 | u_inp = utils.get_1d_posencode_inp(H, rank//2) 91 | v_inp = utils.get_1d_posencode_inp(W, rank//2) 92 | w_inp = utils.get_1d_posencode_inp(T, rank//2) 93 | 94 | H1 = rank 95 | W1 = rank 96 | T1 = rank 97 | factor = rank 98 | 99 | # Generate networks 100 | if nettype == 'unet': 101 | u_net = models.UNetND(H1, H1, 1, 16).cuda() 102 | v_net = models.UNetND(W1, W1, 1, 16).cuda() 103 | w_net = models.UNetND(T1, T1, 1, 16).cuda() 104 | elif nettype == 'dip': 105 | u_net = deep_prior.get_net(H1, 'skip1d', 'reflection', 106 | upsample_mode='linear', 107 | skip_n33d=init_nconv, 108 | skip_n33u=init_nconv, 109 | num_scales=num_channels_up, 110 | n_channels=H1).cuda() 111 | v_net = deep_prior.get_net(W1, 'skip1d', 'reflection', 112 | upsample_mode='linear', 113 | skip_n33d=init_nconv, 114 | skip_n33u=init_nconv, 115 | num_scales=num_channels_up, 116 | n_channels=W1).cuda() 117 | w_net = deep_prior.get_net(T1, 'skip1d', 'reflection', 118 | upsample_mode='linear', 119 | skip_n33d=init_nconv, 120 | skip_n33u=init_nconv, 121 | num_scales=num_channels_up, 122 | n_channels=T1).cuda() 123 | else: 124 | u_net = deep_decoder.decodernw1d(H1, [H//8, 64, 64], 125 | filter_size_up=3).cuda() 126 | v_net = deep_decoder.decodernw1d(W1, [W//8, 64, 64], 127 | filter_size_up=3).cuda() 128 | w_net = deep_decoder.decodernw1d(T1, 
[T//8, 64, 64], 129 | filter_size_up=3).cuda() 130 | # Deep decoder requires smaller inputs 131 | u_inp = utils.get_inp([1, H//8, H//8]) 132 | v_inp = utils.get_inp([1, W//8, W//8]) 133 | w_inp = utils.get_inp([1, T//8, T//8]) 134 | 135 | # Switch to trairning mode 136 | u_net.train() 137 | v_net.train() 138 | w_net.train() 139 | 140 | net_params = list(u_net.parameters()) + list(v_net.parameters()) +\ 141 | list(w_net.parameters()) 142 | inp_params = [u_inp] + [v_inp] + [w_inp] 143 | 144 | params = net_params + inp_params 145 | 146 | core = utils.get_inp(rank) 147 | with torch.no_grad(): 148 | core[...] = 1/rank 149 | params += [core] 150 | 151 | criterion_l1 = losses.L2Norm() 152 | 153 | loss_array = np.zeros(epochs) 154 | mse_array = np.zeros(epochs) 155 | time_array = np.zeros(epochs) 156 | 157 | X, Y = np.mgrid[:H, :W] 158 | mask = (np.hypot((X-H/2), (Y - W/2)) < min(H, W)/2).astype(np.float32) 159 | maskten = torch.tensor(mask)[..., None].cuda() 160 | 161 | optimizer = torch.optim.Adam(lr=learning_rate, params=params) 162 | 163 | # Create a learning scheduler 164 | scheduler = utils.get_scheduler(scheduler_type, optimizer, sched_args) 165 | 166 | u_inp_per = u_inp.detach().clone() 167 | v_inp_per = v_inp.detach().clone() 168 | w_inp_per = w_inp.detach().clone() 169 | 170 | best_loss = float('inf') 171 | best_epoch = 0 172 | tic = time.time() 173 | 174 | tbar = tqdm.tqdm(range(epochs)) 175 | for idx in tbar: 176 | # Perturbe inputs 177 | u_inp_perturbed = u_inp + u_inp_per.normal_()*reg_noise_std 178 | v_inp_perturbed = v_inp + v_inp_per.normal_()*reg_noise_std 179 | w_inp_perturbed = w_inp + w_inp_per.normal_()*reg_noise_std 180 | 181 | Uk = u_net(u_inp_perturbed) 182 | Vk = v_net(v_inp_perturbed) 183 | Wk = w_net(w_inp_perturbed) 184 | 185 | factors = [Uk.T[..., 0], Vk.T[..., 0], Wk.T[..., 0]] 186 | 187 | tensor_estim = tensorly.cp_to_tensor((core, factors)) 188 | 189 | measurements_estim = lin_inverse.radon( 190 | tensor_estim[None, ...].permute(0, 3, 1, 2), 191 | angles) 192 | 193 | loss = criterion_l1(measurements - measurements_estim) 194 | 195 | mse_array[idx] = ((tensor_estim - cube_gt_ten)**2).mean().item() 196 | loss_array[idx] = loss.item() 197 | time_array[idx] = time.time() - tic 198 | 199 | tbar.set_description('%.4e'%mse_array[idx]) 200 | tbar.refresh() 201 | if loss_array[idx] < best_loss: 202 | best_loss = loss_array[idx] 203 | best_epoch = idx 204 | best_cube_estim = copy.deepcopy(tensor_estim.detach().squeeze()) 205 | 206 | optimizer.zero_grad() 207 | loss.backward() 208 | optimizer.step() 209 | scheduler.step() 210 | 211 | with torch.no_grad(): 212 | img_idx = idx%T 213 | ref = cube[..., img_idx] 214 | diff = abs(cube_gt_ten - tensor_estim).mean(2) 215 | meas_diff = abs(measurements - measurements_estim).mean(0) 216 | diff = diff.squeeze().detach().cpu() 217 | meas_diff = meas_diff.detach().cpu() 218 | img = tensor_estim[..., img_idx].detach().cpu().numpy() 219 | 220 | if sys.platform == 'win32': 221 | cv2.imshow('Rec Diff x10', diff.numpy()*10) 222 | cv2.imshow('Meas Diff x10', 10*meas_diff.numpy()/maxval) 223 | cv2.imshow('Rec', np.hstack((ref, img))) 224 | cv2.waitKey(1) 225 | 226 | cube_estim = best_cube_estim.detach().squeeze().cpu().numpy() 227 | 228 | psnrval = utils.psnr(cube, cube_estim) 229 | ssimval = ssim_func(cube, cube_estim, multichannel=True) 230 | 231 | print('PSNR: %.2f'%psnrval) 232 | print('SSIM: %.2f'%ssimval) 233 | 234 | -------------------------------------------------------------------------------- /run_figure12_TV.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | ''' 4 | This script replicates the TV results in Figure 8. The accuracy is 5 | different from the paper as PET scan was downsampled. 6 | ''' 7 | 8 | import sys 9 | import tqdm 10 | 11 | import numpy as np 12 | from scipy import io 13 | from skimage.metrics import structural_similarity as ssim_func 14 | 15 | import torch 16 | import tensorly 17 | tensorly.set_backend('pytorch') 18 | 19 | import matplotlib.pyplot as plt 20 | plt.gray() 21 | import cv2 22 | 23 | sys.path.append('modules') 24 | 25 | import utils 26 | import losses 27 | import lin_inverse 28 | 29 | if __name__ == '__main__': 30 | expname = 'pet' 31 | nangles = 40 # Number of measurementes per z-slice 32 | rank = 1000 33 | 34 | # Network parameters 35 | uv_decompose = False # Set True to decompose the 3D volume into PARAFAC tensor 36 | 37 | # Noise parameters 38 | scaling = 1 39 | tau = 1000 40 | noise_snr = 2 41 | 42 | # Learning constants 43 | learning_rate = 1e-1 44 | epochs = 100 45 | lambda_tv = 1e0 46 | 47 | # Load data 48 | data = io.loadmat('data/%s.mat'%expname) 49 | cube = data['hypercube'].astype(np.float32)[:, :, :288] 50 | cube = utils.resize(cube/cube.max(), scaling) 51 | 52 | H, W, T = cube.shape 53 | 54 | if rank == 'max': 55 | rank = min(H, W, T) 56 | n_inputs = rank 57 | 58 | cube_noisy = utils.measure(cube, noise_snr, tau).astype(np.float32) 59 | 60 | # Send data to device 61 | cube_gt_ten = torch.tensor(cube).cuda() 62 | cube_ten = torch.tensor(cube).cuda().permute(2, 0, 1)[None, ...] 63 | angles = torch.tensor(np.linspace(0, 180, nangles).astype(np.float32)).cuda() 64 | 65 | # Generate sinogram 66 | measurements = lin_inverse.radon(cube_ten, angles).detach().cpu().numpy() 67 | measurements = utils.measure(measurements, noise_snr, tau) 68 | measurements = torch.tensor(measurements).cuda() 69 | 70 | if uv_decompose: 71 | U = utils.get_inp([H, rank]) 72 | V = utils.get_inp([W, rank]) 73 | W = utils.get_inp([T, rank]) 74 | core = utils.get_inp(rank) 75 | 76 | params = [U] + [V] + [W] + [core] 77 | else: 78 | tensor_estim = utils.get_inp([H, W, T]) 79 | params = [tensor_estim] 80 | 81 | criterion_l1 = losses.L2Norm() 82 | criterion_tv = losses.TVNorm() 83 | 84 | loss_array = np.zeros(epochs) 85 | mse_array = np.zeros(epochs) 86 | 87 | optimizer = torch.optim.Adam(lr=learning_rate, params=params) 88 | 89 | tbar = tqdm.tqdm(range(epochs)) 90 | for idx in tbar: 91 | if uv_decompose: 92 | factors = [U, V, W] 93 | tensor_estim = tensorly.cp_to_tensor((core, factors)) 94 | tensor_estim = tensor_estim 95 | 96 | measurements_estim = lin_inverse.radon( 97 | tensor_estim[None, ...].permute(0, 3, 1, 2), 98 | angles) 99 | 100 | loss1 = criterion_l1(measurements - measurements_estim) 101 | loss2 = criterion_tv(tensor_estim) 102 | 103 | loss = loss1 + lambda_tv*loss2 104 | 105 | mse_array[idx] = ((tensor_estim - cube_gt_ten)**2).mean().item() 106 | 107 | tbar.set_description('%.4e'%mse_array[idx]) 108 | 109 | optimizer.zero_grad() 110 | loss.backward() 111 | optimizer.step() 112 | 113 | with torch.no_grad(): 114 | img_idx = idx%T 115 | ref = cube[..., img_idx] 116 | diff = abs(cube_gt_ten - tensor_estim).mean(2).squeeze().detach().cpu() 117 | img = tensor_estim[..., img_idx].detach().cpu().numpy() 118 | 119 | if sys.platform == 'win32': 120 | cv2.imshow('Diff x10', diff.numpy()*10) 121 | cv2.imshow('Avg', np.hstack((ref, img))) 122 | cv2.waitKey(1) 123 | 124 | cube_estim = 
124 | cube_estim = tensor_estim.detach().squeeze().cpu().numpy() 125 | 126 | psnrval = utils.psnr(cube, cube_estim) 127 | ssimval = ssim_func(cube, cube_estim, multichannel=True) 128 | 129 | print('PSNR: %.2f'%psnrval) 130 | print('SSIM: %.2f'%ssimval) 131 | 132 | -------------------------------------------------------------------------------- /run_figure2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | ''' 4 | This script replicates results in Figure 2 of our main paper. 5 | ''' 6 | 7 | import os 8 | import sys 9 | import tqdm 10 | import importlib 11 | import copy 12 | import argparse 13 | 14 | import numpy as np 15 | from scipy import io 16 | from scipy.sparse import linalg 17 | from skimage.metrics import structural_similarity as ssim_func 18 | 19 | from sklearn.model_selection import train_test_split 20 | from sklearn.model_selection import GridSearchCV 21 | from sklearn.datasets import fetch_lfw_people, fetch_olivetti_faces 22 | from sklearn.metrics import classification_report 23 | from sklearn.metrics import confusion_matrix 24 | from sklearn.metrics import average_precision_score as apfunc 25 | from sklearn.decomposition import PCA, FastICA 26 | from sklearn.svm import SVC 27 | 28 | import torch 29 | 30 | import matplotlib.pyplot as plt 31 | plt.gray() 32 | import cv2 33 | 34 | sys.path.append('modules') 35 | 36 | import models 37 | import utils 38 | import losses 39 | import deep_prior 40 | import deep_decoder 41 | 42 | def do_pca_and_svm(X_train, y_train, X_test, n_components, param_grid): 43 | ''' 44 | Just a combined wrapper for PCA + SVM 45 | ''' 46 | 47 | X_train_centered = X_train - X_train.mean(1).reshape(-1, 1) 48 | X_train_centered /= X_train_centered.std(1).reshape(-1, 1) 49 | 50 | covmat = X_train_centered.T.dot(X_train_centered) 51 | _, eigenfaces = linalg.eigsh(covmat, k=n_components) 52 | eigenfaces = eigenfaces.T 53 | 54 | y_pred = do_svm(eigenfaces, X_train, y_train, X_test, param_grid) 55 | 56 | return eigenfaces, y_pred 57 | 58 | def do_svm(eigenfaces, X_train, y_train, X_test, param_grid): 59 | ''' 60 | Just do SVM 61 | ''' 62 | 63 | X_train_centered = X_train - X_train.mean(1).reshape(-1, 1) 64 | 65 | X_test_centered = X_test - X_test.mean(1).reshape(-1, 1) 66 | 67 | X_train_proj = X_train_centered.dot(eigenfaces.T) 68 | X_test_proj = X_test_centered.dot(eigenfaces.T) 69 | 70 | # Train an SVM classification model on the projected coefficients 71 | clf = GridSearchCV( 72 | SVC(kernel='linear', class_weight='balanced'), param_grid 73 | ) 74 | clf = clf.fit(X_train_proj, y_train) 75 | 76 | # Quantitative evaluation of the model quality on the test set 77 | y_pred = clf.predict(X_test_proj) 78 | 79 | return y_pred 80 | 81 | def do_ica_and_svm(X_train, y_train, X_test, n_components, param_grid): 82 | ''' 83 | Independent component analysis 84 | ''' 85 | ica = FastICA(n_components=n_components, random_state=0).fit(X_train) 86 | eigenfaces = ica.components_ 87 | 88 | y_pred = do_svm(eigenfaces, X_train, y_train, X_test, param_grid) 89 | 90 | return eigenfaces, y_pred 91 | 92 | def average_precision(y_test, y_pred, n_classes): 93 | return apfunc(np.eye(n_classes)[y_test], np.eye(n_classes)[y_pred]) 94 | 95 | if __name__ == '__main__': 96 | expname = 'weizzman' 97 | n_components = 84 98 | train_size = 0.25 99 | scale = 1 100 | 101 | # Noise constants 102 | tau = 20 103 | noise_snr = 2 104 | 105 | # Network constants 106 | n_inputs = n_components 107 | init_nconv = 128 108 | num_channels_up = 5 109 | nettype = 'dip' 110 |
111 | # SVM constants 112 | param_grid = {'C': [1e1, 1e2, 1e3, 5e3, 1e4, 5e4, 1e5], 113 | 'gamma': [0.0001, 0.0005, 0.001, 0.005, 0.01, 0.1, 1.0, 10.0], } 114 | 115 | # Learning constants 116 | sched_args = argparse.Namespace() 117 | scheduler_type = 'none' 118 | learning_rate = 1e-3 119 | epochs = 100 120 | reg_noise_std = 1.0/30.0 121 | exp_weight = 0.9 122 | sched_args.step_size = 2000 123 | sched_args.gamma = pow(10, -2/epochs) 124 | sched_args.max_lr = learning_rate 125 | sched_args.min_lr = 1e-6 126 | sched_args.epochs = epochs 127 | 128 | # Load data 129 | if expname == 'lfw': 130 | lfw_people = fetch_lfw_people(min_faces_per_person=70, resize=0.5) 131 | n_samples, h, w = lfw_people.images.shape 132 | 133 | X = lfw_people.data/255.0 134 | y = lfw_people.target 135 | target_names = lfw_people.target_names 136 | n_classes = target_names.shape[0] 137 | elif expname == 'olivetti': 138 | X, y = fetch_olivetti_faces(return_X_y=True) 139 | h = 64 140 | w = 64 141 | n_classes = np.unique(y).size 142 | n_samples = X.shape[0] 143 | 144 | else: 145 | data = io.loadmat('data/weizzman.mat') 146 | faces = data['faces'].astype(np.float32)/255 147 | y = data['labels'].ravel() 148 | n_classes = np.unique(y).size 149 | 150 | names = data['names'].ravel() 151 | target_names = [name[0] for name in names] 152 | 153 | # Resize to avoid flooding GPU 154 | data = np.transpose(utils.resize(np.transpose(faces, [1, 2, 0]), scale), 155 | [2, 0, 1]) 156 | n_samples, h, w = data.shape 157 | X = data.reshape(n_samples, h*w) 158 | 159 | # Split into a training and testing set 160 | X_train, X_test, y_train, y_test = train_test_split( 161 | X, y, train_size=train_size, random_state=42) 162 | 163 | # Corrupt data after splitting so that we can keep track of clean PCA 164 | X_train_noisy = utils.measure(X_train, noise_snr, tau) 165 | X_test_noisy = utils.measure(X_test, noise_snr, tau) 166 | 167 | # Generate networks 168 | if nettype == 'unet': 169 | eigen_net = models.UNetND(n_inputs, n_components, 2, 170 | init_nconv).cuda() 171 | net_inp = utils.get_inp([1, n_inputs, h, w]) 172 | elif nettype == 'dip': 173 | eigen_net = deep_prior.get_net(n_inputs, 'skip', 'reflection', 174 | upsample_mode='bilinear', 175 | skip_n33d=128, 176 | skip_n33u=128, 177 | num_scales=5, 178 | n_channels=n_components).cuda() 179 | net_inp = utils.get_inp([1, n_inputs, h, w]) 180 | elif nettype == 'dd': 181 | nchans = [init_nconv]*num_channels_up 182 | eigen_net = deep_decoder.decodernw(n_components).cuda() 183 | 184 | H1 = h // pow(2, num_channels_up) 185 | W1 = w // pow(2, num_channels_up) 186 | net_inp = utils.get_inp([1, init_nconv, H1, W1]) 187 | 188 | # Clean data 189 | print('Generating eigenfaces on clean data') 190 | eigenfaces_gt, y_gt = do_ica_and_svm(X_train, y_train, X_test, n_components, 191 | param_grid) 192 | 193 | print('Generating eigenfaces on noisy data') 194 | eigenfaces_ica, y_ica = do_ica_and_svm(X_train_noisy, y_train, X_test_noisy, 195 | n_components, param_grid) 196 | 197 | print('Computing PCA components') 198 | eigenfaces_pca, y_pca = do_pca_and_svm(X_train_noisy, y_train, X_test_noisy, 199 | n_components, param_grid) 200 | 201 | ## Part two: DeepTensor 202 | # Compute covariance matrix 203 | print('Now starting DeepTensor') 204 | X_train_noisy_centered = X_train_noisy - X_train_noisy.mean(1).reshape(-1, 1) 205 | 206 | covmat = X_train_noisy_centered.T.dot(X_train_noisy_centered) 207 | covmat_ten = torch.tensor(covmat)[None, ...].cuda() 208 | 209 | net_params = list(eigen_net.parameters()) 210 | inp_params = [net_inp] 211 |
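# DeepTensor models the (h*w x h*w) covariance matrix as a symmetric low-rank
# product: covmat ~= E^T E, where E = eigen_net(net_inp), reshaped to
# (n_components x h*w), plays the role of a bank of eigenfaces. The loop below
# fits E by minimizing ||covmat - E^T E||_2, and the eigenfaces are then
# recovered from the denoised covariance with a standard eigendecomposition.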
212 | params = net_params + inp_params 213 | optimizer = torch.optim.Adam(lr=learning_rate, params=params) 214 | 215 | # Create a learning scheduler 216 | scheduler = utils.get_scheduler(scheduler_type, optimizer, sched_args) 217 | 218 | net_inp_delta = net_inp.detach().clone() 219 | 220 | criterion_l1 = losses.L2Norm() # NOTE: an L2 loss, despite the name 221 | loss_array = np.zeros(epochs) 222 | best_loss = float('inf') 223 | best_epoch = 0 224 | 225 | eigenimg_array = np.zeros((h, w, n_components, epochs), dtype=np.float32) 226 | 227 | for idx in tqdm.tqdm(range(epochs)): 228 | # Perturb inputs 229 | net_inp_perturbed = net_inp + net_inp_delta.normal_()*reg_noise_std 230 | 231 | eigenimg = eigen_net(net_inp_perturbed) 232 | eigenvec = eigenimg.reshape(1, n_components, h*w) 233 | 234 | covmat_estim = torch.bmm(eigenvec.permute(0, 2, 1), eigenvec) 235 | 236 | loss = criterion_l1(covmat_ten - covmat_estim) 237 | 238 | optimizer.zero_grad() 239 | loss.backward() 240 | optimizer.step() 241 | scheduler.step() 242 | 243 | loss_array[idx] = loss.item() 244 | 245 | tmp = eigenimg[0, ...].permute(1, 2, 0).detach().cpu().numpy() 246 | eigenimg_array[:, :, :, idx] = tmp 247 | 248 | if loss.item() < best_loss: 249 | best_loss = loss.item() 250 | best_epoch = idx 251 | 252 | best_eigenvec = copy.deepcopy(eigenvec.detach()) 253 | best_covmat = copy.deepcopy(covmat_estim.detach()) 254 | 255 | if sys.platform == 'win32': 256 | img = utils.build_montage(eigenimg.squeeze().detach().cpu().numpy()) 257 | 258 | cv2.imshow('Eigenvector', utils.normalize(img, True)) 259 | cv2.waitKey(1) 260 | 261 | covmat_dlrf = best_covmat.squeeze().cpu().numpy() 262 | _, eigenface_dlrf = linalg.eigsh(covmat_dlrf, k=n_components) 263 | eigenface_dlrf = eigenface_dlrf.T 264 | 265 | # Now predict 266 | y_dlrf = do_svm(eigenface_dlrf, X_train_noisy, y_train, X_test_noisy, 267 | param_grid) 268 | 269 | accuracy_gt = (y_test == y_gt).sum()/y_test.size 270 | accuracy_pca = (y_test == y_pca).sum()/y_test.size 271 | accuracy_ica = (y_test == y_ica).sum()/y_test.size 272 | accuracy_dlrf = (y_test == y_dlrf).sum()/y_test.size 273 | 274 | print('Noiseless ICA accuracy: %.2f'%accuracy_gt) 275 | print('Noiseless ICA mAP: %.2f'%average_precision(y_test, y_gt, n_classes)) 276 | print('') 277 | 278 | print('PCA accuracy: %.2f'%accuracy_pca) 279 | print('PCA mAP: %.2f'%average_precision(y_test, y_pca, n_classes)) 280 | print('') 281 | 282 | print('ICA accuracy: %.2f'%accuracy_ica) 283 | print('ICA mAP: %.2f'%average_precision(y_test, y_ica, n_classes)) 284 | print('') 285 | 286 | print('DeepTensor accuracy: %.2f'%accuracy_dlrf) 287 | print('DeepTensor mAP: %.2f'%average_precision(y_test, y_dlrf, n_classes)) -------------------------------------------------------------------------------- /run_figure6.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | ''' 4 | This script replicates instances of results in Figure 4 in the main paper. 5 | To obtain the plots, please sweep the relevant parameters.
6 | ''' 7 | 8 | import os 9 | import sys 10 | import glob 11 | import tqdm 12 | import importlib 13 | import argparse 14 | 15 | import numpy as np 16 | from scipy import io 17 | from skimage.metrics import structural_similarity as ssim_func 18 | 19 | import torch 20 | 21 | import matplotlib.pyplot as plt 22 | plt.gray() 23 | import cv2 24 | 25 | sys.path.append('modules') 26 | 27 | import models 28 | import utils 29 | import losses 30 | import deep_prior 31 | import deep_decoder 32 | 33 | if __name__ == '__main__': 34 | # Set your simulation constants here 35 | nrows = 64 # Size of the matrix 36 | ncols = 64 37 | rank = 32 # Rank 38 | noise_type = 'gaussian' # Type of noise 39 | signal_type = 'gaussian' # Type of signal 40 | noise_snr = 0.2 # Std. dev for Gaussian noise 41 | tau = 1000 # Max. lambda for photon noise 42 | 43 | # Network parameters 44 | nettype = 'dip' 45 | n_inputs = rank 46 | init_nconv = 128 47 | num_channels_up = 5 48 | 49 | sched_args = argparse.Namespace() 50 | # Learning constants 51 | # Important: The number of epochs is decided by the noise level. As a general 52 | # rule of thumb, the higher the noise, the fewer the epochs. 53 | scheduler_type = 'none' 54 | learning_rate = 1e-4 55 | epochs = 1000 56 | sched_args.step_size = 2000 57 | sched_args.gamma = pow(10, -1/epochs) 58 | sched_args.max_lr = learning_rate 59 | sched_args.min_lr = 1e-6 60 | sched_args.epochs = epochs 61 | 62 | # Generate data 63 | mat, mat_gt = utils.get_matrix(nrows, ncols, rank, noise_type, signal_type, 64 | noise_snr, tau) 65 | 66 | # Move them to device 67 | mat_ten = torch.tensor(mat)[None, ...].cuda() 68 | mat_gt_ten = torch.tensor(mat_gt)[None, ...].cuda() 69 | 70 | u_inp = utils.get_inp([1, n_inputs, nrows]) 71 | v_inp = utils.get_inp([1, n_inputs, ncols]) 72 | 73 | # Create networks 74 | if nettype == 'unet': 75 | u_net = models.UNetND(n_inputs, rank, 1, init_nconv).cuda() 76 | v_net = models.UNetND(n_inputs, rank, 1, init_nconv).cuda() 77 | elif nettype == 'dip': 78 | u_net = deep_prior.get_net(n_inputs, 'skip1d', 'reflection', 79 | upsample_mode='linear', 80 | skip_n33d=init_nconv, 81 | skip_n33u=init_nconv, 82 | num_scales=5, 83 | n_channels=rank).cuda() 84 | v_net = deep_prior.get_net(n_inputs, 'skip1d', 'reflection', 85 | upsample_mode='linear', 86 | skip_n33d=init_nconv, 87 | skip_n33u=init_nconv, 88 | num_scales=5, 89 | n_channels=rank).cuda() 90 | elif nettype == 'dd': 91 | u_net = deep_decoder.decodernw1d(rank, 92 | [init_nconv]*num_channels_up).cuda() 93 | v_net = deep_decoder.decodernw1d(rank, 94 | [init_nconv]*num_channels_up).cuda() 95 | 96 | # Deep decoder requires smaller inputs 97 | u_inp = utils.get_inp([1, init_nconv, nrows // pow(2, num_channels_up)]) 98 | v_inp = utils.get_inp([1, init_nconv, ncols // pow(2, num_channels_up)]) 99 | 100 | # Extract training parameters 101 | net_params = list(u_net.parameters()) + list(v_net.parameters()) 102 | inp_params = [u_inp] + [v_inp] 103 | 104 | # You can either optimize both net and inputs, or just net 105 | params = net_params + inp_params 106 | optimizer = torch.optim.Adam(lr=learning_rate, params=params) 107 | 108 | # Create a learning scheduler 109 | scheduler = utils.get_scheduler(scheduler_type, optimizer, sched_args) 110 | 111 | # Create loss functions -- losses.L1Norm() or losses.L2Norm() 112 | criterion = losses.L2Norm() 113 | 114 | mse_array = np.zeros(epochs) 115 | 116 | # Now start iterations 117 | best_mse = float('inf') 118 | tbar = tqdm.tqdm(range(epochs)) 119 | for idx in tbar:
120 | u_output = u_net(u_inp).permute(0, 2, 1) 121 | v_output = v_net(v_inp) 122 | 123 | mat_estim = torch.bmm(u_output, v_output) 124 | 125 | loss = criterion(mat_estim - mat_ten) 126 | 127 | optimizer.zero_grad() 128 | loss.backward() 129 | optimizer.step() 130 | scheduler.step() 131 | 132 | # Visualize the reconstruction 133 | diff = abs(mat_gt_ten - mat_estim).squeeze().detach().cpu() 134 | mat_cpu = mat_estim.squeeze().detach().cpu().numpy() 135 | 136 | cv2.imshow('Diff x10', diff.numpy().reshape(nrows, ncols)*10) 137 | cv2.imshow('Rec', np.hstack((mat_gt, mat_cpu))) 138 | cv2.waitKey(1) 139 | 140 | mse_array[idx] = ((mat_estim - mat_gt_ten)**2).mean().item() 141 | tbar.set_description('%.4e'%mse_array[idx]) 142 | tbar.refresh() 143 | 144 | if loss.item() < best_mse: 145 | best_epoch = idx 146 | best_mat = mat_cpu 147 | best_mse = loss.item() 148 | 149 | # Now compute accuracy 150 | psnr1 = utils.psnr(mat_gt, best_mat) 151 | psnr2 = utils.psnr(mat_gt, utils.lr_decompose(mat, rank)) 152 | 153 | print('DeepTensor: %.2fdB'%psnr1) 154 | print('SVD: %.2fdB'%psnr2) 155 | 156 | 157 |
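# utils.lr_decompose (the SVD baseline above) is presumably a truncated SVD;
# a minimal numpy sketch of that baseline, for reference:
#
#     import numpy as np
#     def lr_decompose(mat, r):
#         U, s, Vt = np.linalg.svd(mat, full_matrices=False)
#         return (U[:, :r] * s[:r]) @ Vt[:r]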
-------------------------------------------------------------------------------- /run_figure7.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import sys 5 | import glob 6 | import tqdm 7 | import importlib 8 | 9 | import numpy as np 10 | from skimage.metrics import structural_similarity as ssim_func 11 | 12 | import torch 13 | 14 | import matplotlib.pyplot as plt 15 | plt.gray() 16 | import cv2 17 | 18 | sys.path.append('modules') 19 | 20 | import models 21 | import utils 22 | import losses 23 | import deep_prior 24 | import deep_decoder 25 | 26 | if __name__ == '__main__': 27 | # Set your simulation constants here 28 | nrows = 64 # Size of the matrix 29 | ncols = 64 30 | rank = 10 # Rank 31 | noise_type = 'gaussian' # Type of noise 32 | signal_type = 'gaussian' # Type of signal 33 | noise_snr = 0.2 # Std. dev for Gaussian noise 34 | tau = 1000 # Max. lambda for photon noise 35 | 36 | # Network parameters 37 | nettype = 'dip' 38 | n_inputs = 10 39 | init_nconv = 128 40 | num_channels_up = 5 41 | 42 | # Learning constants 43 | learning_rate = 1e-4 44 | 45 | # The number of epochs should be set according to the noise level and the 46 | # number of data points. As a general rule of thumb, higher noise or fewer 47 | # samples require fewer epochs 48 | epochs = 500 49 | 50 | # Generate data 51 | mat, mat_gt, basis = utils.get_pca(nrows, ncols, rank, 52 | noise_type, signal_type, 53 | noise_snr, tau) 54 | 55 | # Compute covariance matrix 56 | mat_centered = mat - mat.mean(0).reshape(1, ncols) 57 | mat_gt_centered = mat_gt - mat_gt.mean(0).reshape(1, ncols) 58 | covmat = mat_centered.dot(mat_centered.T) 59 | covmat_gt = mat_gt_centered.dot(mat_gt_centered.T) 60 | 61 | minval = covmat_gt.min() 62 | maxval = covmat_gt.max() 63 | 64 | covmat = (covmat - minval)/(maxval - minval) 65 | covmat_gt = (covmat_gt - minval)/(maxval - minval) 66 | 67 | # Move them to device 68 | covmat_ten = torch.tensor(covmat)[None, ...].cuda() 69 | covmat_gt_ten = torch.tensor(covmat_gt)[None, ...].cuda() 70 | 71 | # Since we are doing PCA, we need only one network 72 | u_inp = utils.get_inp([1, n_inputs, nrows]) 73 | 74 | # Create networks 75 | if nettype == 'unet': 76 | u_net = models.UNetND(n_inputs, rank, 1, init_nconv).cuda() 77 | elif nettype == 'dip': 78 | u_net = deep_prior.get_net(n_inputs, 'skip1d', 'reflection', 'linear', 79 | rank).cuda() 80 | elif nettype == 'dd': 81 | u_net = deep_decoder.decodernw1d(rank, 82 | [init_nconv]*num_channels_up).cuda() 83 | 84 | # Deep decoder requires smaller inputs 85 | u_inp = utils.get_inp([1, init_nconv, nrows // pow(2, num_channels_up)]) 86 | 87 | # Create optimizer 88 | net_params = list(u_net.parameters()) 89 | inp_params = [u_inp] 90 | 91 | # You can either optimize both net and inputs, or just net 92 | params = net_params + inp_params 93 | optimizer = torch.optim.Adam(lr=learning_rate, params=params) 94 | 95 | # Create loss functions -- losses.L1Norm() or losses.L2Norm() 96 | criterion = losses.L2Norm() 97 | 98 | mse_array = np.zeros(epochs) 99 | 100 | # Now start iterations 101 | best_mse = float('inf') 102 | tbar = tqdm.tqdm(range(epochs)) 103 | for idx in tbar: 104 | u_output = u_net(u_inp) 105 | 106 | covmat_estim = torch.bmm(u_output.permute(0, 2, 1), u_output) 107 | 108 | loss = criterion(covmat_estim - covmat_ten) 109 | 110 | optimizer.zero_grad() 111 | loss.backward() 112 | optimizer.step() 113 | 114 | # Visualize the reconstruction 115 | diff = abs(covmat_gt_ten - covmat_estim).squeeze().detach().cpu() 116 | mat_cpu = covmat_estim.squeeze().detach().cpu().numpy() 117 | 118 | cv2.imshow('Diff x10', diff.numpy().reshape(nrows, nrows)*10) 119 | cv2.imshow('Rec', np.hstack((covmat_gt, mat_cpu))) 120 | cv2.waitKey(1) 121 | 122 | mse_array[idx] = ((covmat_estim - covmat_gt_ten)**2).mean().item() 123 | tbar.set_description('%.4e'%mse_array[idx]) 124 | tbar.refresh() 125 | 126 | if loss.item() < best_mse: 127 | best_epoch = idx 128 | best_mat = mat_cpu 129 | best_mse = loss.item() 130 | 131 | # Now compute accuracy 132 | psnr1 = utils.psnr(covmat_gt, best_mat) 133 | psnr2 = utils.psnr(covmat_gt, utils.lr_decompose(covmat, rank)) 134 | 135 | print('PCA with DeepTensor: %.2fdB'%psnr1) 136 | print('PCA with SVD: %.2fdB'%psnr2) 137 | 138 | -------------------------------------------------------------------------------- /run_figure9.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | 4 | 5 | ''' 6 | This script replicates the DeepTensor results in Figure 6. The accuracy is 7 | different from the paper as the hyperspectral image was downsampled.
8 | ''' 9 | 10 | import sys 11 | import tqdm 12 | import copy 13 | import time 14 | 15 | import numpy as np 16 | from scipy import io 17 | from skimage.metrics import structural_similarity as ssim_func 18 | 19 | import torch 20 | 21 | import matplotlib.pyplot as plt 22 | plt.gray() 23 | import cv2 24 | 25 | sys.path.append('modules') 26 | 27 | import models 28 | import utils 29 | import losses 30 | import spectral 31 | import deep_prior 32 | import deep_decoder 33 | 34 | if __name__ == '__main__': 35 | expname = 'icvl' 36 | rank = 20 37 | n_inputs = rank 38 | init_nconv = 128 39 | num_channels_up = 5 40 | nettype = 'dip' 41 | 42 | scaling = 1 43 | tau = 100 44 | noise_snr = 2 45 | 46 | # Learning constants 47 | learning_rate = 1e-3 48 | epochs = 5000 49 | reg_noise_std = 1.0/30.0 50 | exp_weight = 0.99 51 | 52 | # Load data 53 | data = io.loadmat('data/%s.mat'%expname) 54 | cube = data['hypercube'].astype(np.float32) 55 | wavelengths = data['wavelengths'].astype(np.float32).ravel() 56 | cube = utils.resize(cube/cube.max(), scaling) 57 | 58 | H, W, nwvl = cube.shape 59 | 60 | if rank == 'max': 61 | rank = nwvl - 1 62 | n_inputs = rank 63 | 64 | hsmat = cube.reshape(H*W, nwvl) 65 | 66 | cube_noisy = utils.measure(cube, noise_snr, tau) 67 | 68 | hsten = torch.tensor(hsmat)[None, ...].cuda() 69 | hsten_noisy = torch.tensor(cube_noisy.reshape(H*W, nwvl))[None, ...].cuda() 70 | 71 | criterion_l1 = losses.L1Norm() 72 | 73 | loss_array = np.zeros(epochs) 74 | mse_array = np.zeros(epochs) 75 | 76 | # Generate networks 77 | if nettype == 'unet': 78 | im_net = models.UNetND(n_inputs, rank, 2, init_nconv).cuda() 79 | spec_net = models.UNetND(n_inputs, rank, 1, init_nconv).cuda() 80 | 81 | im_inp = utils.get_inp([1, n_inputs, H, W]) 82 | spec_inp = utils.get_inp([1, n_inputs, nwvl]) 83 | elif nettype == 'dip': 84 | im_net = deep_prior.get_net(n_inputs, 'skip', 'reflection', 85 | upsample_mode='bilinear', 86 | skip_n33d=128, 87 | skip_n33u=128, 88 | num_scales=5, 89 | n_channels=rank).cuda() 90 | spec_net = deep_prior.get_net(n_inputs, 'skip1d', 'reflection', 91 | upsample_mode='linear', 92 | skip_n33d=128, 93 | skip_n33u=128, 94 | num_scales=5, 95 | n_channels=rank).cuda() 96 | im_inp = utils.get_inp([1, n_inputs, H, W]) 97 | spec_inp = utils.get_inp([1, n_inputs, nwvl]) 98 | elif nettype == 'dd': 99 | nchans = [init_nconv]*num_channels_up 100 | im_net = deep_decoder.decodernw(rank).cuda() 101 | spec_net = deep_decoder.decodernw1d(rank).cuda() 102 | 103 | H1 = H // pow(2, num_channels_up) 104 | W1 = W // pow(2, num_channels_up) 105 | nwvl1 = nwvl // pow(2, num_channels_up) 106 | im_inp = utils.get_inp([1, init_nconv, H1, W1]) 107 | spec_inp = utils.get_inp([1, init_nconv, nwvl1]) 108 | 109 | # Switch to training mode 110 | im_net.train() 111 | spec_net.train() 112 | 113 | net_params = list(im_net.parameters()) + list(spec_net.parameters()) 114 | inp_params = [im_inp] + [spec_inp] 115 | 116 | im_inp_per = im_inp.detach().clone() 117 | spec_inp_per = spec_inp.detach().clone() 118 | 119 | hs_estim_avg = None 120 | best_loss = float('inf') 121 | best_epoch = 0 122 | 123 | params = net_params + inp_params 124 | optimizer = torch.optim.Adam(lr=learning_rate, params=params) 125 | 126 | tic = time.time() 127 | tbar = tqdm.tqdm(range(epochs)) 128 | for idx in tbar: 129 | # Perturb inputs 130 | im_inp_perturbed = im_inp + im_inp_per.normal_()*reg_noise_std 131 | spec_inp_perturbed = spec_inp + spec_inp_per.normal_()*reg_noise_std 132 | 133 | U_img = im_net(im_inp_perturbed) 134 | V = spec_net(spec_inp_perturbed) 
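# The two networks parameterize a rank-`rank` factorization of the
# hyperspectral matrix: U_img holds the spatial components (reshaped below to
# H*W x rank) and V the spectra (rank x nwvl), so hs_estim = U @ V
# approximates the noisy (H*W x nwvl) measurement matrix.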
135 | 136 | U = U_img.reshape(-1, rank, H*W).permute(0, 2, 1) 137 | 138 | hs_estim = torch.bmm(U, V) 139 | 140 | loss = criterion_l1(hsten_noisy - hs_estim) 141 | loss_array[idx] = loss.item() 142 | mse_array[idx] = ((hs_estim - hsten)**2).mean().item() 143 | 144 | tbar.set_description('%.4e'%mse_array[idx]) 145 | tbar.refresh() 146 | 147 | optimizer.zero_grad() 148 | loss.backward() 149 | optimizer.step() 150 | 151 | if loss.item() < best_loss: 152 | best_loss = loss.item() 153 | best_epoch = idx 154 | best_hs_estim = copy.deepcopy(hs_estim.detach().squeeze()) 155 | 156 | # Averaging as per original code 157 | if hs_estim_avg is None: 158 | hs_estim_avg = hs_estim.detach() 159 | else: 160 | hs_estim_avg = exp_weight*hs_estim_avg +\ 161 | (1 - exp_weight)*hs_estim.detach() 162 | 163 | if sys.platform == 'win32': 164 | im_idx = idx%nwvl 165 | with torch.no_grad(): 166 | diff = abs(hs_estim - hsten).mean(2).squeeze().detach().cpu() 167 | avg = hs_estim.mean(2).squeeze().detach().cpu().numpy() 168 | img = hs_estim[0, :, im_idx].detach().cpu().numpy().reshape(H, W) 169 | cv2.imshow('Diff x10', diff.numpy().reshape(H, W)*10) 170 | cv2.imshow('Band', 171 | np.sqrt(abs(np.hstack((cube[..., im_idx], img))))) 172 | cv2.waitKey(1) 173 | 174 | toc = time.time() 175 | 176 | cube_estim = best_hs_estim.cpu().numpy().reshape(H, W, nwvl) 177 | cube_lr = spectral.lr_decompose(cube_noisy, rank) 178 | diff = abs(hs_estim - hsten).mean(2).squeeze().detach().cpu() 179 | 180 | wvl_idx = nwvl // 2 181 | 182 | plt.subplot(2, 2, 1) 183 | plt.imshow(cube[..., wvl_idx]); plt.title('Ground truth') 184 | 185 | plt.subplot(2, 2, 2) 186 | snrval = utils.psnr(cube, cube_lr) 187 | ssimval = ssim_func(cube, cube_lr, multichannel=True) 188 | plt.imshow(cube_lr[..., wvl_idx]) 189 | plt.title('SVD | %.1f dB | %.2f'%(snrval, ssimval)) 190 | 191 | plt.subplot(2, 2, 3) 192 | snrval = utils.psnr(cube, cube_estim) 193 | ssimval = ssim_func(cube, cube_estim, multichannel=True) 194 | plt.imshow(cube_estim[..., wvl_idx]) 195 | plt.title('DeepTensor | %.1f dB | %.2f'%(snrval, ssimval)) 196 | 197 | plt.subplot(2, 2, 4) 198 | plt.imshow(diff.numpy().reshape(H, W)) 199 | plt.title('Diff image') 200 | plt.colorbar() 201 | 202 | plt.show() 203 | -------------------------------------------------------------------------------- /run_figure9_bm3d.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | 4 | ''' 5 | This script replicates the BM3D results in Figure 6. The accuracy is 6 | different from the paper as the hyperspectral image has been downsampled. 
7 | ''' 8 | 9 | 10 | import sys 11 | import tqdm 12 | 13 | import numpy as np 14 | from scipy import io 15 | from skimage.metrics import structural_similarity as ssim_func 16 | import bm3d 17 | 18 | import matplotlib.pyplot as plt 19 | plt.gray() 20 | 21 | sys.path.append('modules') 22 | import utils 23 | 24 | if __name__ == '__main__': 25 | expname = 'icvl' 26 | 27 | scaling = 1 28 | tau = 100 29 | noise_snr = 2 30 | 31 | sigma_psd = int(np.sqrt(tau))/255 32 | 33 | # Load data 34 | data = io.loadmat('data/%s.mat'%expname) 35 | cube = data['hypercube'].astype(np.float32) 36 | cube = utils.resize(cube/cube.max(), scaling) 37 | 38 | H, W, nwvl = cube.shape 39 | cube_noisy = utils.measure(cube, noise_snr, tau) 40 | 41 | cube_estim = np.zeros_like(cube_noisy) 42 | 43 | for idx in tqdm.tqdm(range(nwvl)): 44 | denoised_img = bm3d.bm3d(cube_noisy[..., idx], 45 | sigma_psd=sigma_psd, 46 | stage_arg=bm3d.BM3DStages.ALL_STAGES) 47 | cube_estim[..., idx] = denoised_img 48 | 49 | snrval = utils.psnr(cube, cube_estim) 50 | ssimval = ssim_func(cube, cube_estim, multichannel=True) 51 | 52 | plt.subplot(2, 2, 1) 53 | plt.imshow(cube[..., nwvl//2]) 54 | plt.title('Ground truth') 55 | 56 | plt.subplot(2, 2, 2) 57 | plt.imshow(cube_noisy[..., nwvl//2]) 58 | plt.title('Noisy') 59 | 60 | plt.subplot(2, 2, 3) 61 | plt.imshow(cube_estim[..., nwvl//2]) 62 | plt.title('BM3D denoised output') 63 | 64 | plt.subplot(2, 2, 4) 65 | plt.imshow(abs(cube - cube_estim).mean(2), cmap='jet') 66 | plt.colorbar() 67 | plt.title('Absolute error') 68 | 69 | print('PSNR: %.2f dB | SSIM: %.2f'%(snrval, ssimval)) 70 | plt.show() -------------------------------------------------------------------------------- /run_table1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | ''' 4 | This script replicates results in Table 1 in the main paper. 5 | 6 | Note: Running this script downloads datasets that can be up to 2.5GB in size. 7 | Make sure you have enough disk space before running this script. 8 | ''' 9 | 10 | import sys 11 | import tqdm 12 | import argparse 13 | import torchaudio 14 | import numpy as np 15 | 16 | 17 | import torch 18 | 19 | import matplotlib.pyplot as plt 20 | 21 | sys.path.append("modules") 22 | from torchvision import datasets, transforms 23 | from torch.utils.data import DataLoader 24 | from sklearn.decomposition import NMF 25 | import models 26 | import utils 27 | import losses 28 | import deep_prior 29 | import deep_decoder 30 | 31 | from scipy.special import factorial # needed by cmdet() and simplex() below 32 | import logging 33 | import logging.config 34 | import scipy.sparse 35 | from numpy.linalg import eigh 36 | 37 | _EPS = 1e-10 # small-eigenvalue cutoff used by eighk() and PyMFBase 38 | def eighk(M, k=0): 39 | """Returns ordered eigenvectors of a square matrix. Eigenvectors with 40 | too-small eigenvalues are ignored. Optionally only the first k vectors/values are returned.
41 | Arguments 42 | --------- 43 | M - square matrix 44 | k - (default 0): number of eigenvectors/values to return 45 | Returns 46 | ------- 47 | w : [:k] eigenvalues 48 | v : [:k] eigenvectors 49 | """ 50 | values, vectors = eigh(M) 51 | 52 | # get rid of too-low eigenvalues 53 | s = np.where(values > _EPS)[0] 54 | vectors = vectors[:, s] 55 | values = values[s] 56 | 57 | # sort eigenvectors according to largest value 58 | idx = np.argsort(values)[::-1] 59 | values = values[idx] 60 | vectors = vectors[:, idx] 61 | 62 | # select only the top k eigenvectors 63 | if k > 0: 64 | values = values[:k] 65 | vectors = vectors[:, :k] 66 | 67 | return values, vectors 68 | 69 | 70 | def cmdet(d): 71 | """Returns the volume of a simplex computed via the Cayley-Menger 72 | determinant. 73 | Arguments 74 | --------- 75 | d - euclidean distance matrix (shouldn't be squared) 76 | Returns 77 | ------- 78 | V - volume of the simplex given by d 79 | """ 80 | D = np.ones((d.shape[0] + 1, d.shape[0] + 1)) 81 | D[0, 0] = 0.0 82 | D[1:, 1:] = d ** 2 83 | j = np.float32(D.shape[0] - 2) 84 | f1 = (-1.0) ** (j + 1) / ((2 ** j) * ((factorial(j)) ** 2)) 85 | cmd = f1 * np.linalg.det(D) 86 | 87 | # sometimes, for very small values, "cmd" might be negative, thus we take 88 | # the absolute value 89 | return np.sqrt(np.abs(cmd)) 90 | 91 | 92 | def simplex(d): 93 | """Computes the volume of a simplex S given by a coordinate matrix D. 94 | Arguments 95 | --------- 96 | d - coordinate matrix (k x n, n samples in k dimensions) 97 | Returns 98 | ------- 99 | V - volume of the Simplex spanned by d 100 | """ 101 | # compute the simplex volume using coordinates 102 | D = np.ones((d.shape[0] + 1, d.shape[1])) 103 | D[1:, :] = d 104 | V = np.abs(np.linalg.det(D)) / factorial(d.shape[1] - 1) 105 | return V 106 | 107 | 108 | class PyMFBase: 109 | """ 110 | PyMF Base Class. Does nothing useful apart from providing some basic methods. 111 | """ 112 | 113 | # some small value 114 | 115 | _EPS = 1e-10 116 | 117 | def __init__(self, data, num_bases=4, **kwargs): 118 | """ """ 119 | 120 | def setup_logging(): 121 | # create logger 122 | self._logger = logging.getLogger("pymf") 123 | 124 | # add ch to logger 125 | if len(self._logger.handlers) < 1: 126 | # create console handler and set level to debug 127 | ch = logging.StreamHandler() 128 | ch.setLevel(logging.DEBUG) 129 | # create formatter 130 | formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s") 131 | 132 | # add formatter to ch 133 | ch.setFormatter(formatter) 134 | 135 | self._logger.addHandler(ch) 136 | 137 | setup_logging() 138 | 139 | # set variables 140 | self.data = data 141 | self._num_bases = num_bases 142 | 143 | # record data dimensions (W and H are initialized lazily in factorize) 144 | self._data_dimension, self._num_samples = self.data.shape 145 | 146 | def residual(self): 147 | """Returns the residual in % of the total amount of data 148 | Returns 149 | ------- 150 | residual : float 151 | """ 152 | res = np.sum(np.abs(self.data - np.dot(self.W, self.H))) 153 | total = 100.0 * res / np.sum(np.abs(self.data)) 154 | return total 155 | 156 | def frobenius_norm(self): 157 | """Frobenius norm (||data - WH||) of a data matrix and a low rank 158 | approximation given by WH. Minimizing the Fnorm is the most common
160 | Returns: 161 | ------- 162 | frobenius norm: F = ||data - WH|| 163 | """ 164 | # check if W and H exist 165 | if hasattr(self, "H") and hasattr(self, "W"): 166 | if scipy.sparse.issparse(self.data): 167 | tmp = self.data[:, :] - (self.W * self.H) 168 | tmp = tmp.multiply(tmp).sum() 169 | err = np.sqrt(tmp) 170 | else: 171 | err = np.sqrt(np.sum((self.data[:, :] - np.dot(self.W, self.H)) ** 2)) 172 | else: 173 | err = None 174 | 175 | return err 176 | 177 | def _init_w(self): 178 | """Initialize W to random values [0,1].""" 179 | # add a small value, otherwise nmf and related methods get into trouble as 180 | # they have difficulties recovering from zero. 181 | self.W = np.random.random((self._data_dimension, self._num_bases)) + 10 ** -4 182 | 183 | def _init_h(self): 184 | """Initialize H to random values [0,1].""" 185 | self.H = np.random.random((self._num_bases, self._num_samples)) + 10 ** -4 186 | 187 | def _update_h(self): 188 | """Override in a subclass to update H.""" 189 | pass 190 | 191 | def _update_w(self): 192 | """Override in a subclass to update W.""" 193 | pass 194 | 195 | def _converged(self, i): 196 | """ 197 | If the per-sample change in the approximation error is below _EPS, 198 | return True. 199 | Parameters 200 | ---------- 201 | i : index of the update step 202 | Returns 203 | ------- 204 | converged : boolean 205 | """ 206 | derr = np.abs(self.ferr[i] - self.ferr[i - 1]) / self._num_samples 207 | if derr < self._EPS: 208 | return True 209 | else: 210 | return False 211 | 212 | def factorize( 213 | self, 214 | niter=100, 215 | show_progress=False, 216 | compute_w=True, 217 | compute_h=True, 218 | compute_err=True, 219 | epoch_hook=None, 220 | ): 221 | """Factorize s.t. WH = data 222 | Parameters 223 | ---------- 224 | niter : int 225 | number of iterations. 226 | show_progress : bool 227 | print some extra information to stdout. 228 | compute_h : bool 229 | iteratively update values for H. 230 | compute_w : bool 231 | iteratively update values for W. 232 | compute_err : bool 233 | compute Frobenius norm |data-WH| after each update and store 234 | it to .ferr[k]. 235 | epoch_hook : function 236 | If this exists, evaluate it every iteration 237 | Updated Values 238 | -------------- 239 | .W : updated values for W. 240 | .H : updated values for H. 241 | .ferr : Frobenius norm |data-WH| for each iteration. 242 | """ 243 | 244 | if show_progress: 245 | self._logger.setLevel(logging.INFO) 246 | else: 247 | self._logger.setLevel(logging.ERROR) 248 | 249 | # create W and H if they don't already exist 250 | # -> any custom initialization to W,H should be done before 251 | if not hasattr(self, "W") and compute_w: 252 | self._init_w() 253 | 254 | if not hasattr(self, "H") and compute_h: 255 | self._init_h() 256 | 257 | # Computation of the error can take quite long for large matrices, 258 | # thus we make it optional.
259 | if compute_err: 260 | self.ferr = np.zeros(niter) 261 | 262 | for i in range(niter): 263 | if compute_w: 264 | self._update_w() 265 | 266 | if compute_h: 267 | self._update_h() 268 | 269 | if compute_err: 270 | self.ferr[i] = self.frobenius_norm() 271 | self._logger.info("FN: %s (%s/%s)" % (self.ferr[i], i + 1, niter)) 272 | else: 273 | self._logger.info("Iteration: (%s/%s)" % (i + 1, niter)) 274 | 275 | if epoch_hook is not None: 276 | epoch_hook(self) 277 | 278 | # check if the err is not changing anymore 279 | if i > 1 and compute_err: 280 | if self._converged(i): 281 | # adjust the error measure 282 | self.ferr = self.ferr[:i] 283 | break 284 | 285 | 286 | class SNMF(PyMFBase): 287 | """ 288 | SNMF(data, num_bases=4) 289 | 290 | Semi Non-negative Matrix Factorization. Factorize a data matrix into two 291 | matrices s.t. F = | data - W*H | is minimal. For Semi-NMF only H is 292 | constrained to non-negativity. 293 | 294 | Parameters 295 | ---------- 296 | data : array_like, shape (_data_dimension, _num_samples) 297 | the input data 298 | num_bases: int, optional 299 | Number of bases to compute (column rank of W and row rank of H). 300 | 4 (default) 301 | 302 | Attributes 303 | ---------- 304 | W : "data_dimension x num_bases" matrix of basis vectors 305 | H : "num bases x num_samples" matrix of coefficients 306 | ferr : frobenius norm (after calling .factorize()) 307 | 308 | Example 309 | ------- 310 | Applying Semi-NMF to some rather stupid data set: 311 | 312 | >>> import numpy as np 313 | >>> data = np.array([[1.0, 0.0, 2.0], [0.0, 1.0, 1.0]]) 314 | >>> snmf_mdl = SNMF(data, num_bases=2) 315 | >>> snmf_mdl.factorize(niter=10) 316 | 317 | The basis vectors are now stored in snmf_mdl.W, the coefficients in snmf_mdl.H. 318 | To compute coefficients for an existing set of basis vectors simply copy W 319 | to snmf_mdl.W, and set compute_w to False: 320 | 321 | >>> data = np.array([[1.5], [1.2]]) 322 | >>> W = np.array([[1.0, 0.0], [0.0, 1.0]]) 323 | >>> snmf_mdl = SNMF(data, num_bases=2) 324 | >>> snmf_mdl.W = W 325 | >>> snmf_mdl.factorize(niter=1, compute_w=False) 326 | 327 | The result is a set of coefficients snmf_mdl.H, s.t. data = W * snmf_mdl.H. 
328 | """ 329 | 330 | def _update_w(self): 331 | W1 = np.dot(self.data[:, :], self.H.T) 332 | W2 = np.dot(self.H, self.H.T) 333 | self.W = np.dot(W1, np.linalg.inv(W2)) 334 | 335 | def _update_h(self): 336 | def separate_positive(m): 337 | return (np.abs(m) + m) / 2.0 338 | 339 | def separate_negative(m): 340 | return (np.abs(m) - m) / 2.0 341 | 342 | XW = np.dot(self.data[:, :].T, self.W) 343 | 344 | WW = np.dot(self.W.T, self.W) 345 | WW_pos = separate_positive(WW) 346 | WW_neg = separate_negative(WW) 347 | 348 | XW_pos = separate_positive(XW) 349 | H1 = (XW_pos + np.dot(self.H.T, WW_neg)).T 350 | 351 | XW_neg = separate_negative(XW) 352 | H2 = (XW_neg + np.dot(self.H.T, WW_pos)).T + 10 ** -9 353 | 354 | self.H *= np.sqrt(H1 / H2) 355 | 356 | 357 | DATASET = "TFR" 358 | 359 | 360 | def get_nmf(data, noisy_data, n_clusters): 361 | flat_data = data.reshape((data.shape[0], -1)) 362 | flat_noisy_data = noisy_data.reshape((noisy_data.shape[0], -1)) 363 | scores = [] 364 | reconstructions = [] 365 | criterion = losses.L2Norm() 366 | 367 | print('Running NMF') 368 | nmf = NMF(n_components=n_clusters, alpha=0.0, l1_ratio=0.0) 369 | nmf.fit(flat_noisy_data) 370 | reconstructions = [nmf.inverse_transform(nmf.transform(flat_noisy_data))] 371 | scores = [criterion(torch.from_numpy(flat_data - reconstructions[-1]))] 372 | 373 | print('Running semi-NMF') 374 | snmf_mdl = SNMF(flat_noisy_data, num_bases=n_clusters) 375 | snmf_mdl.factorize(niter=500) 376 | reconstructions.append(np.dot(snmf_mdl.W, snmf_mdl.H)) 377 | scores.append(criterion(torch.from_numpy(flat_data - reconstructions[-1]))) 378 | 379 | return scores, reconstructions 380 | 381 | 382 | if __name__ == "__main__": 383 | 384 | # Generate data 385 | if DATASET == "MNIST": 386 | transform = transforms.Compose( 387 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] 388 | ) 389 | data = datasets.MNIST( 390 | root="./", train=True, download=True, transform=transform 391 | ) 392 | train_loader = DataLoader(data, batch_size=len(data)) 393 | data = next(iter(train_loader))[0].numpy()[:2048] 394 | # Generate data 395 | elif DATASET == "CIFAR": 396 | transform = transforms.Compose( 397 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] 398 | ) 399 | data = datasets.CIFAR10( 400 | root="./", train=True, download=True, transform=transform 401 | ) 402 | train_loader = DataLoader(data, batch_size=len(data)) 403 | data = next(iter(train_loader))[0].numpy()[:2048] 404 | elif DATASET == "TFR": 405 | data = torchaudio.datasets.SPEECHCOMMANDS("./", download=True) 406 | spectro = torchaudio.transforms.Spectrogram( 407 | n_fft=1024, 408 | win_length=512, 409 | hop_length=32, 410 | center=True, 411 | pad_mode="reflect", 412 | power=2.0, 413 | ) 414 | data = [data[i][0] for i in range(40)] 415 | data = torch.stack(data) 416 | data = spectro(data)[:, :, :500, :500] 417 | 418 | data = data.numpy() 419 | 420 | data = data.mean(1, keepdims=True) 421 | data -= data.min((1, 2, 3), keepdims=True) 422 | data /= data.max((1, 2, 3), keepdims=True) 423 | noisy_data = data + ( 424 | np.random.randn(*data.shape) * 0.3 + (np.random.randn(*data.shape) ** 2) * 0.3 425 | ) 426 | noisy_data = np.clip(noisy_data, 0, 10) 427 | 428 | print("Input PSNR: %.2f dB"%utils.psnr(data, noisy_data)) 429 | 430 | # Set your simulation constants here 431 | n_samples = data.shape[0] # Size of the matrix 432 | n_channels = data.shape[1] 433 | image_size = data.shape[2:] 434 | h, w = image_size 435 | n_clusters = 128 436 | 437 | # Network parameters 438 | 
nettype = "DD" 439 | reg_noise_std = 1 / 30.0 440 | 441 | def runner( 442 | activation_u, 443 | activation_v, 444 | nettype=nettype, 445 | data=data, 446 | noisy_data=noisy_data, 447 | n_clusters=n_clusters, 448 | image_size=image_size, 449 | n_channels=n_channels, 450 | n_samples=n_samples, 451 | ): 452 | sched_args = argparse.Namespace() 453 | # Learning constants 454 | scheduler_type = "none" 455 | learning_rate = 1e-2 456 | epochs = 1000 457 | sched_args.step_size = 2000 458 | sched_args.gamma = 0.9999 459 | sched_args.max_lr = learning_rate 460 | sched_args.min_lr = 1e-6 461 | sched_args.epochs = epochs 462 | 463 | u_inp = utils.get_inp([n_samples, n_clusters]) 464 | v_inp = utils.get_inp([n_clusters, n_channels, h, w]) 465 | 466 | # Create networks 467 | if nettype == "MLP": 468 | u_net = models.SimpleForwardMLP(n_samples, [32, 32, n_samples]).cuda() 469 | v_net = models.SimpleForwardMLP(n_clusters, [32, 32, n_clusters]).cuda() 470 | elif nettype == "CNN": 471 | v_net = models.SimpleForward2D( 472 | n_channels, n_channels, [512, 512, 512] 473 | ).cuda() 474 | elif nettype == "DIP": 475 | v_net = deep_prior.get_net( 476 | n_clusters, 477 | "skip", 478 | "reflection", 479 | upsample_mode="bilinear", 480 | skip_n33d=128, 481 | skip_n33u=128, 482 | num_scales=4, 483 | n_channels=n_clusters, 484 | ).cuda() 485 | u_net = deep_prior.get_net( 486 | n_clusters, 487 | "skip1d", 488 | "reflection", 489 | upsample_mode="linear", 490 | skip_n33d=128, 491 | skip_n33u=128, 492 | num_scales=5, 493 | n_channels=n_clusters, 494 | ).cuda() 495 | v_inp = utils.get_inp([1, n_clusters, h, w]) 496 | u_inp = utils.get_inp([1, n_clusters, n_samples]) 497 | elif nettype == "DD": 498 | # This is multichannel version 499 | v_inp = utils.get_inp([n_channels, n_clusters, h // 4, w // 4]) 500 | u_inp = utils.get_inp([n_channels, n_clusters, n_samples // 4]) 501 | v_net = deep_decoder.decodernw( 502 | num_output_channels=n_clusters, 503 | num_channels_up=[n_clusters] 504 | + [128] * int(np.log2(image_size[0] // v_inp.shape[2]) - 1), 505 | ).cuda() 506 | 507 | # probability membership network 508 | u_net = deep_decoder.decodernw1d( 509 | num_output_channels=n_clusters, num_channels_up=[n_clusters, 128] 510 | ) 511 | u_net = u_net.cuda() 512 | 513 | # Extract training parameters 514 | net_params = list(v_net.parameters()) + list(u_net.parameters()) 515 | inp_params = [u_inp] + [v_inp] 516 | 517 | # You can either optimize both net and inputs, or just net 518 | params = net_params + inp_params 519 | optimizer = torch.optim.Adam(lr=learning_rate, params=params) 520 | 521 | # Create a learning scheduler 522 | scheduler = utils.get_scheduler(scheduler_type, optimizer, sched_args) 523 | 524 | # Create loss functions -- loses.L1Norm() or losses.L2Norm() 525 | criterion = losses.L2Norm() 526 | 527 | mse_array = np.zeros(epochs) 528 | tmse_array = np.zeros(epochs) 529 | 530 | # Now start iterations 531 | best_mse = float("inf") 532 | best_epoch = 0 533 | 534 | # Move them to device 535 | scores, reconstructions = get_nmf(data, noisy_data, n_clusters) 536 | print(scores) 537 | psnr2 = utils.psnr(data, reconstructions[0].reshape(data.shape)) 538 | print("NMF: %.2f dB"%psnr2) 539 | psnr2 = utils.psnr(data, reconstructions[1].reshape(data.shape)) 540 | print("sNMF: %.2f dB"%psnr2) 541 | 542 | data = torch.tensor(data).cuda() 543 | noisy_data = torch.tensor(noisy_data).cuda() 544 | 545 | tbar = tqdm.tqdm(range(epochs)) 546 | for idx in tbar: 547 | u_inp_per = u_inp 548 | v_inp_per = v_inp 549 | 550 | u_output = 
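# Nonnegativity is imposed purely through the activations passed in as
# activation_u/activation_v: squashing both factors (softplus/abs/relu)
# gives an NMF-style model, while squashing only u and leaving v as the
# identity gives semi-NMF, matching the runner() calls at the bottom of
# this script.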
550 | u_output = activation_u(u_net(u_inp_per)) 551 | centroids = activation_v(v_net(v_inp_per)) 552 | 553 | centroids_mat = centroids.reshape(1, n_clusters, h * w) 554 | reconstruction = torch.bmm(u_output.permute(0, 2, 1), centroids_mat) 555 | reconstruction = reconstruction.reshape(1, n_samples, h, w).permute( 556 | 1, 0, 2, 3 557 | ) 558 | 559 | loss = criterion(reconstruction - noisy_data) 560 | 561 | loss_l1 = centroids.abs().mean() + u_output.abs().mean() 562 | loss_l2 = criterion(centroids) + criterion(u_output) 563 | 564 | loss = loss + 0.1 * 0.5 * loss_l1 + 0.5 * 0.1 * 0.5 * loss_l2 # sparsity (L1) and energy (L2) penalties on the factors 565 | 566 | optimizer.zero_grad() 567 | loss.backward() 568 | optimizer.step() 569 | scheduler.step() 570 | 571 | # Visualize the reconstruction 572 | mat_cpu = reconstruction.detach().cpu().numpy() 573 | centroids_cpu = centroids.detach().cpu().numpy() 574 | 575 | mseval = criterion(reconstruction - data).item() 576 | 577 | mse_array[idx] = mseval 578 | tmse_array[idx] = loss.item() 579 | 580 | tbar.set_description('%.4e'%mseval) 581 | tbar.refresh() 582 | 583 | if tmse_array[idx] < best_mse: 584 | best_mse = tmse_array[idx] 585 | best_epoch = idx 586 | best_mat = reconstruction.detach().cpu().numpy() 587 | 588 | # Now compute accuracy 589 | data = data.cpu().numpy() 590 | psnr1 = utils.psnr(data, best_mat) 591 | 592 | print("DeepTensor NMF: %.2fdB" % psnr1) 593 | 594 | print('NMF, softplus') 595 | runner(torch.nn.Softplus(), torch.nn.Softplus()) 596 | 597 | print('NMF, abs') 598 | runner(torch.abs, torch.abs) 599 | 600 | print('NMF, relu') 601 | runner(torch.relu, torch.relu) 602 | 603 | print('semi-NMF, softplus') 604 | runner(torch.nn.Softplus(), lambda x: x) 605 | 606 | print('semi-NMF, abs') 607 | runner(torch.abs, lambda x: x) 608 | 609 | print('semi-NMF, relu') 610 | runner(torch.relu, lambda x: x) 611 | --------------------------------------------------------------------------------
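The scripts above all share one "DeepTensor" pattern: each low-rank factor is the output of a small untrained network, and the product of the factors is fit to the noisy measurements. A minimal self-contained sketch of that pattern, with hypothetical sizes and plain MLPs standing in for the repo's models:

import torch

# Hypothetical sizes: denoise a 64x64 matrix with a rank-8 factorization
m, n, r = 64, 64, 8
noisy = torch.randn(m, n)  # stand-in for a noisy low-rank measurement

# Two small generators map fixed random codes to the factor matrices
u_net = torch.nn.Sequential(torch.nn.Linear(r, 64), torch.nn.ReLU(),
                            torch.nn.Linear(64, m))
v_net = torch.nn.Sequential(torch.nn.Linear(r, 64), torch.nn.ReLU(),
                            torch.nn.Linear(64, n))
zu, zv = torch.randn(r, r), torch.randn(r, r)

opt = torch.optim.Adam(list(u_net.parameters()) + list(v_net.parameters()),
                       lr=1e-3)
for _ in range(1000):
    U = u_net(zu).T                       # (m, r) factor
    V = v_net(zv)                         # (r, n) factor
    loss = ((U @ V - noisy)**2).mean()    # L2-style fit (cf. losses.L2Norm)
    opt.zero_grad(); loss.backward(); opt.step()

denoised = (U @ V).detach()

In the repo's scripts the best iterate (lowest loss) is tracked and reported rather than the final one, which acts as a simple form of early stopping.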