├── .gitignore ├── LICENSE ├── PARAMS.md ├── README.md ├── __pycache__ ├── propagation.cpython-37.pyc └── utils.cpython-37.pyc ├── analysis └── analyze_models.ipynb ├── data ├── lamb.png ├── penguin.png ├── sbd_train_img.jpg └── sbd_val_img.jpg ├── dataio.py ├── denoising_unet.py ├── environment.yml ├── optics.py ├── params.json ├── propagation.py ├── pytorch_prototyping ├── __pycache__ │ └── pytorch_prototyping.cpython-37.pyc └── pytorch_prototyping.py ├── ranges.json ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | val/ 3 | __pycache__/ 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Vincent Sitzmann 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/PARAMS.md:
--------------------------------------------------------------------------------
1 | # Guide to hyperparameters
2 | 
3 | * `data_root`: folder containing the image dataset
4 | * `logging_root`: folder in which all experiments are saved
5 | * `train_test`: specify `train` or `val` to select which dataset split to load
6 | * `exp_name`: name of the experiment
7 | * `checkpoint`: checkpoint file to load if you want to continue training from a previous session
8 | * `max_epoch`: maximum number of epochs
9 | * `lr`: learning rate
10 | * `batch_size`: batch size
11 | * `reg_weight`: regularization weight
12 | * `init_K`: initial value of the Wiener filtering damping factor K
13 | * `use_wiener`: `true` or `false`; whether to apply and learn the Wiener filtering damping factor K
14 | * `resolution`: int, resolution of the point spread function
15 | * `pixel_pitch`: float, pixel size
16 | * `focal_length`: float, distance between phase mask and sensor
17 | * `r_cutoff`: `null` or an int, radius of the aperture
18 | * `refractive_idc`: refractive index of the phase mask
19 | * `wavelength`: float, wavelength of interest
20 | * `init_lens`: `random`, `fresnel`, `plano`, or `flat`, how to initialize the height map of the phase mask
21 | * `single_image`: boolean, test optimizing the phase mask over a single image
22 | * `download_data`: boolean, download the data (should be `false` after the first download)
23 | 
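As a quick orientation, the sketch below shows one way these values might be read at startup. The `json.load` pattern here is an illustration (see `train.py` for the actual loading logic); only the key names come from this guide.

```python
import json

# Load the hyperparameters documented above
with open('params.json') as f:
    hyps = json.load(f)

assert hyps['train_test'] in ('train', 'val')
print(hyps['exp_name'], hyps['lr'], hyps['batch_size'])
```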
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # End-to-end optimization of a lensless imaging system
2 | 
3 | Final project for the Winter 2020 iteration of EE 367: Computational Imaging at Stanford.
4 | 
5 | Author: Cindy Nguyen, cindyn at stanford.edu
6 | 
7 | This repo contains code to perform end-to-end optimization of a plastic phase mask placed close to the sensor (<= 25 mm focal distance).
8 | 
9 | ## Pipeline
10 | ![pipeline](https://user-images.githubusercontent.com/21781041/76365430-8a440300-62e4-11ea-8903-5979883f99ee.png)
11 | 
12 | We implement an optics module, sensor module, and Wiener deconvolution for image reconstruction. The loss is backpropagated into the height map to optimize a coded phase mask.
13 | 
14 | ## File guide
15 | * `train.py` is the main training script (implements a sequence of experiments as a queue; each experiment trains on one particular set of hyperparameters)
16 | * `params.json` is used to load hyperparameters relevant for training and initialization (see `PARAMS.md` for details)
17 | * `ranges.json` is where you specify the hyperparameters to switch out between experiments (each experiment stops according to its early stopping criteria)
18 | * `dataio.py` contains the dataloader that loads the SBDataset for training
19 | * `denoising_unet.py` contains the model and model helper functions
20 | * `analysis/analyze_models.ipynb` loads models and plots losses
21 | 
22 | ## Running code
23 | * `git clone` this repository
24 | * Edit `params.json` and `ranges.json` as needed for your experiment (see `PARAMS.md` for details)
25 | * If you are running the script just to test that things are set up properly, you can run with the settings as is. This will load a single image and begin optimizing a height map and damping factor, starting from an in-focus Fresnel lens height map.
26 | * If you want to run with the full dataset, set the parameter `download_data` to `true`. After the first run, set this to `false`.
27 | * In console, run
28 | ```shell
29 | $ CUDA_VISIBLE_DEVICES=# python3 train.py
30 | ```
31 | where `#` specifies the GPU device number. If running on CPU, you can simply run `python3 train.py`.
32 | * Data generated from the experiment (saved models and Tensorboard files) will be saved in `runs/exp_name/exp_name_#`, where `exp_name` is as specified in the hyperparameters and `#` is automatically determined.
33 | * The training script is set up so that a new experiment is created for each hyperparameter setting in `ranges.json`; each experiment is run sequentially to completion, and model checkpoints are saved during training in the `runs` folder.
34 | * Data files are not included to save space.
35 | 
36 | ## Results
37 | ![doe](https://user-images.githubusercontent.com/21781041/76365740-54534e80-62e5-11ea-81c6-d718e3d0cd54.png)
38 | Results from optimizing the height map only.
39 | 
40 | ![wiener](https://user-images.githubusercontent.com/21781041/76365750-5cab8980-62e5-11ea-93b1-b138503c378b.png)
41 | Results from optimizing both the height map and the Wiener deconvolution damping factor.
42 | 
43 | ## Dependencies
44 | * pytorch
45 | * numpy
46 | * tensorboard
47 | * [U-Net repo](https://github.com/vsitzmann/cifar10_denoising)
48 | * [Propagation and utils repo](https://github.com/computational-imaging/citorch)
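For intuition, here is a heavily simplified sketch of the training loop the pipeline above describes. The optimizer choice and loop structure are assumptions for illustration (the real logic lives in `train.py`); `DenoisingUnet`, `NoisySBDataset`, and `get_distortion_loss` are from this repo.

```python
import json
import torch
from torch.utils.data import DataLoader
from dataio import NoisySBDataset
from denoising_unet import DenoisingUnet

with open('params.json') as f:
    hyps = json.load(f)

dataloader = DataLoader(NoisySBDataset(hyps), batch_size=hyps['batch_size'])
model = DenoisingUnet(hyps)  # heightmap and damping factor K are nn.Parameters
optimizer = torch.optim.Adam(model.parameters(), lr=hyps['lr'])

for img, gt in dataloader:
    prediction = model(img)  # optics (PSF) -> sensor (convolution) -> Wiener deconvolution
    loss = model.get_distortion_loss(prediction, gt)
    optimizer.zero_grad()
    loss.backward()          # gradients flow back into the height map
    optimizer.step()
```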
--------------------------------------------------------------------------------
/__pycache__/propagation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/COMHTVM/lensless/0d67a310bab08551d7422fa792f3422a7ee7d9cb/__pycache__/propagation.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/COMHTVM/lensless/0d67a310bab08551d7422fa792f3422a7ee7d9cb/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/data/lamb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/COMHTVM/lensless/0d67a310bab08551d7422fa792f3422a7ee7d9cb/data/lamb.png
--------------------------------------------------------------------------------
/data/penguin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/COMHTVM/lensless/0d67a310bab08551d7422fa792f3422a7ee7d9cb/data/penguin.png
--------------------------------------------------------------------------------
/data/sbd_train_img.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/COMHTVM/lensless/0d67a310bab08551d7422fa792f3422a7ee7d9cb/data/sbd_train_img.jpg
--------------------------------------------------------------------------------
/data/sbd_val_img.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/COMHTVM/lensless/0d67a310bab08551d7422fa792f3422a7ee7d9cb/data/sbd_val_img.jpg
--------------------------------------------------------------------------------
/dataio.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | from torchvision.transforms import *
3 | import torch
4 | import optics
5 | from utils import *
6 | 
7 | 
8 | class NoisySBDataset(torch.utils.data.Dataset):
9 |     def __init__(self, hyps):
10 |         super().__init__()
11 | 
12 |         self.transforms = Compose([
13 |             CenterCrop(size=(256,256)),
14 |             Resize(size=(512,512)),
15 |             ToTensor()
16 |         ])
17 | 
18 |         # if you set download=True AND you've already downloaded the files,
19 |         # it'll never finish running :-(
20 |         self.dataset = torchvision.datasets.SBDataset(root=hyps['data_root'],
21 |                                                       image_set=hyps['train_test'],
22 |                                                       download=hyps['download_data'])
23 | 
24 |     def __len__(self):
25 |         return len(self.dataset)
26 | 
27 |     def __getitem__(self, idx):
28 |         """Returns a tuple of (model_input, ground_truth);
29 |         modifies each item of the dataset upon retrieval.
30 |         Here the input and the ground truth are the same image.
31 |         """
32 |         img, _ = self.dataset[idx]
33 |         if self.transforms:
34 |             img = self.transforms(img)
35 | 
36 |         img = torch.Tensor(optics.srgb_to_linear(img))
37 | 
38 |         return img, img
--------------------------------------------------------------------------------
/denoising_unet.py:
--------------------------------------------------------------------------------
1 | import skimage.measure
2 | import torchvision
3 | import utils
4 | from pytorch_prototyping.pytorch_prototyping import *
5 | import matplotlib.pyplot as plt
6 | import torch.nn
7 | import optics
8 | from propagation import Propagation
9 | 
10 | if torch.cuda.is_available():
11 |     DEVICE = torch.device("cuda:0")
12 | else:
13 |     DEVICE = torch.device("cpu")
14 | 
15 | 
16 | def num_divisible_by_2(number):
17 |     return np.floor(np.log2(number)).astype(int)
18 | 
19 | 
20 | def get_num_net_params(net):
21 |     '''Counts number of trainable parameters in a pytorch module'''
22 |     model_parameters = filter(lambda p: p.requires_grad, net.parameters())
23 |     params = sum([np.prod(p.size()) for p in model_parameters])
24 |     return params
25 | 
26 | 
27 | class ConvolveImage(nn.Module):
28 |     def __init__(self, hyps, K, heightmap):
29 |         super(ConvolveImage, self).__init__()
30 |         self.resolution = hyps['resolution']
31 |         self.r_cutoff = hyps['r_cutoff']
32 |         self.wavelength = hyps['wavelength']
33 |         self.focal_length = hyps['focal_length']
34 |         self.pixel_pitch = hyps['pixel_pitch']
35 |         self.refractive_idc = hyps['refractive_idc']
36 |         self.use_wiener = hyps['use_wiener']
37 |         self.heightmap = heightmap
38 |         self.K = K
39 | 
40 |     def forward(self, x):
41 |         # model a point source at infinity as a uniform plane wave
42 |         input_field = torch.ones((self.resolution, self.resolution))
43 | 
44 |         phase_delay = utils.heightmap_to_phase(self.heightmap,
45 |                                                self.wavelength,
46 |                                                self.refractive_idc)
47 | 
48 |         field = optics.propagate_through_lens(input_field, phase_delay)
49 | 
50 |         field = optics.circular_aperture(field, self.r_cutoff)
51 | 
52 |         # kernel_type = 'fresnel_conv' leads to nans
53 |         element = Propagation(kernel_type='fresnel',
54 |                               propagation_distances=self.focal_length,
55 |                               slm_resolution=[self.resolution, self.resolution],
56 |                               slm_pixel_pitch=[self.pixel_pitch, self.pixel_pitch],
57 |                               wavelength=self.wavelength)
58 | 
59 |         field = element.forward(field)
60 |         psf = utils.field_to_intensity(field)
61 | 
62 |         psf /= psf.sum()
63 | 
64 |         final = optics.convolve_img(x, psf)
65 |         if not self.use_wiener:
66 |             return final.to(DEVICE)
67 |         else:
68 |             # perform Wiener filtering
69 |             final = final.to(DEVICE)
70 |             imag = torch.zeros(final.shape).to(DEVICE)
71 |             img = utils.stack_complex(final, imag)
72 |             img_fft = torch.fft(utils.ifftshift(img), 2)
73 | 
74 |             otf = optics.psf2otf(psf, output_size=img.shape[2:4])
75 | 
76 |             otf = torch.stack((otf, otf, otf), 0)
77 |             otf = torch.unsqueeze(otf, 0)
78 |             conj_otf = utils.conj(otf)
79 | 
80
| otf_img = utils.mul_complex(img_fft, conj_otf) 81 | 82 | denominator = optics.abs_complex(otf) 83 | denominator[:, :, :, :, 0] += self.K 84 | product = utils.div_complex(otf_img, denominator) 85 | 86 | filtered = utils.ifftshift(torch.ifft(product, 2)) 87 | filtered = torch.clamp(filtered, min=1e-5) 88 | 89 | return filtered[:, :, :, :, 0] 90 | 91 | 92 | # class WienerFilter(nn.Module): 93 | # """Perform Wiener Filtering with learnable damping factor 94 | # CUDA backprop issues with module as is 95 | # """ 96 | # 97 | # def __init__(self, hyps, heightmap, K): 98 | # super(WienerFilter, self).__init__() 99 | # self.psf = optics.heightmap_to_psf(hyps, heightmap).to(DEVICE) 100 | # self.K = K 101 | # 102 | # def forward(self, x): 103 | # return optics.wiener_filter(x, self.psf, K=self.K ** 2) 104 | 105 | 106 | class DenoisingUnet(nn.Module): 107 | """U-Net-based deconvolution 108 | Assumes images are scaled from -1 to 1. 109 | """ 110 | 111 | def __init__(self, hyps): 112 | super().__init__() 113 | 114 | self.norm = nn.InstanceNorm2d 115 | self.img_sidelength = hyps['resolution'] 116 | 117 | num_downs_unet = num_divisible_by_2(512) 118 | 119 | self.nf0 = 64 # Number of features to use in the outermost layer of U-Net 120 | 121 | init_heightmap = optics.heightmap_initializer(focal_length=hyps['focal_length'], 122 | resolution=hyps['resolution'], 123 | pixel_pitch=hyps['pixel_pitch'], 124 | refractive_idc=hyps['refractive_idc'], 125 | wavelength=hyps['wavelength'], 126 | init_lens=hyps['init_lens']) 127 | 128 | self.heightmap = nn.Parameter(init_heightmap, requires_grad=True) 129 | self.K = nn.Parameter(torch.ones(1) * hyps['init_K']) 130 | 131 | torch.random.manual_seed(0) 132 | 133 | modules = [] 134 | 135 | modules.append(ConvolveImage(hyps, 136 | K=self.K, 137 | heightmap=self.heightmap)) 138 | 139 | # TODO: implement wiener filtering as a separate module 140 | # if hyps['learn_wiener']: 141 | # modules.append(WienerFilter(psf, K=self.K)) 142 | # else: 143 | # modules.append(WienerFilter(psf, K=hyps['K'])) 144 | 145 | # if hyps["use_wiener"]: 146 | # modules.append(WienerFilter(hyps, heightmap=self.heightmap, K=self.K)) 147 | 148 | # modules.append(Unet(in_channels=3, 149 | # out_channels=3, 150 | # use_dropout=False, 151 | # nf0=self.nf0, 152 | # max_channels=8 * self.nf0, 153 | # norm=self.norm, 154 | # num_down=num_downs_unet, 155 | # outermost_linear=True)) 156 | # modules.append(nn.Tanh()) 157 | 158 | self.denoising_net = nn.Sequential(*modules) 159 | 160 | # Losses 161 | self.loss = nn.MSELoss() 162 | 163 | # List of logs 164 | self.counter = 0 # A counter to enable logging every nth iteration 165 | self.logs = list() 166 | self.learned_gamma = list() 167 | 168 | self.to(DEVICE) 169 | 170 | # print("*" * 100) 171 | # print(self) # Prints the model 172 | # print("*" * 100) 173 | print("Number of parameters: %d" % get_num_net_params(self)) 174 | print("*" * 100) 175 | 176 | def get_distortion_loss(self, prediction, ground_truth): 177 | trgt_imgs = ground_truth.to(DEVICE) 178 | 179 | return self.loss(prediction, trgt_imgs) 180 | 181 | def get_regularization_loss(self, prediction, ground_truth): 182 | return torch.Tensor([0]).to(DEVICE) 183 | 184 | def write_updates(self, writer, predictions, ground_truth, input, iter, hyps): 185 | """Writes out tensorboard scalar and figures.""" 186 | batch_size, _, _, _ = predictions.shape 187 | ground_truth = ground_truth.to(DEVICE) 188 | 189 | output_input_gt = torch.cat((predictions, ground_truth), dim=0) 190 | grid = 
torchvision.utils.make_grid(output_input_gt,
191 |                                                 scale_each=True,
192 |                                                 nrow=batch_size,
193 |                                                 normalize=True).cpu().detach().numpy()
194 |         writer.add_image("Output_vs_gt", grid, iter)
195 | 
196 |         writer.add_scalar("psnr", self.get_psnr(predictions, ground_truth), iter)
197 |         writer.add_scalar("damp", self.get_damp(), iter)
198 |         writer.add_figure("heightmap", self.get_heightmap_fig(), iter)
199 | 
200 |         psf = self.get_psf(hyps)
201 |         plt.figure()
202 |         plt.imshow(psf)
203 |         plt.colorbar()
204 |         fig = plt.gcf()
205 |         plt.close()
206 |         writer.add_figure("psf", fig, iter)
207 | 
208 |     def get_psnr(self, predictions, ground_truth):
209 |         """Calculates the PSNR of the model's prediction."""
210 |         batch_size, _, _, _ = predictions.shape
211 |         pred = predictions.detach().cpu().numpy()
212 |         gt = ground_truth.detach().cpu().numpy()
213 | 
214 |         return skimage.measure.compare_psnr(gt, pred, data_range=2)
215 | 
216 |     def get_damp(self):
217 |         """Returns the current Wiener filtering damping factor."""
218 |         return self.K.data.cpu()
219 | 
220 |     def get_psf(self, hyps):
221 |         """Returns the PSF of the current heightmap."""
222 |         psf = optics.heightmap_to_psf(hyps, self.get_heightmap())
223 |         return psf.cpu().numpy()
224 | 
225 |     def get_heightmap_fig(self):
226 |         """Wrapper function for getting the heightmap and returning a
227 |         figure handle."""
228 |         x = self.heightmap.data.cpu().numpy()
229 |         plt.figure()
230 |         plt.imshow(x)
231 |         plt.colorbar()
232 |         fig = plt.gcf()
233 |         return fig
234 | 
235 |     def get_heightmap(self):
236 |         """Returns the heightmap."""
237 |         return self.heightmap.data.cpu()
238 | 
239 |     def forward(self, input):
240 |         self.logs = list()  # Resets the logs
241 |         return self.denoising_net(input)
242 | 
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: lensless
2 | channels:
3 |   - pytorch
4 |   - defaults
5 | dependencies:
6 |   - imageio=2.6.1=py37_0
7 |   - intel-openmp=2019.4=233
8 |   - jpeg=9b=he5867d9_2
9 |   - matplotlib=3.1.0=py37h54f8f79_0
10 |   - numpy=1.16.4=py37hacdab7b_0
11 |   - numpy-base=1.16.4=py37h6575580_0
12 |   - pandas=0.24.2=py37h0a44026_0
13 |   - pillow=6.2.0=py37hb68e598_0
14 |   - pip=19.1.1=py37_0
15 |   - psutil=5.6.3=py37h1de35cc_0
16 |   - pycparser=2.19=py37_0
17 |   - pyparsing=2.3.1=py37_0
18 |   - python=3.7.3=h359304d_0
19 |   - python-dateutil=2.7.5=py37_0
20 |   - scikit-image=0.16.2=py37h6c726b0_0
21 |   - scipy=1.3.0=py37h1410ff5_0
22 |   - setuptools=41.0.1=py37_0
23 |   - pytorch=1.4.0=py3.7_0
24 |   - torchvision=0.5.0=py37_cpu
25 |   - pip:
26 |     - torch==1.4.0
27 |     - tensorboard==2.0.2
28 | prefix: //anaconda3/envs/lensless
--------------------------------------------------------------------------------
/optics.py:
--------------------------------------------------------------------------------
1 | """
2 | Custom optics functions written in PyTorch
3 | Author: Cindy Nguyen
4 | """
5 | 
6 | import utils
7 | import torch
8 | import numpy as np
9 | from propagation import Propagation
10 | import torch.nn
11 | 
12 | if torch.cuda.is_available():
13 |     DEVICE = torch.device("cuda:0")
14 | else:
15 |     DEVICE = torch.device("cpu")
16 | 
17 | 
18 | def linear_to_srgb(img):
19 |     return np.where(img <= 0.0031308, 12.92 * img, 1.055 * img ** (0.41666) - 0.055)
20 | 
21 | 
22 | def srgb_to_linear(img):
23 |     return np.where(img <= 0.04045, img / 12.92, ((img + 0.055) / 1.055) ** 2.4)
24 | 
25 | 
26 | def wiener_filter(img, psf, K):
27 |     """ Performs Wiener filtering on an RGB image, channel by channel
28 |     :param img: pytorch tensor of image (N,C,H,W)
29 |     :param psf: pytorch tensor of psf (H,W)
30 |     :param K: damping factor (can be input through hyps or learned)
31 |     :return: Wiener filtered image (N,C,H,W)
32 |     """
33 |     img = img.to(DEVICE)
34 |     psf = psf.to(DEVICE)
35 |     imag = torch.zeros(img.shape).to(DEVICE)
36 |     img = utils.stack_complex(img, imag)
37 |     img_fft = torch.fft(utils.ifftshift(img), 2)
38 |     img_fft = img_fft.to(DEVICE)
39 | 
40 |     otf = psf2otf(psf, output_size=img.shape[2:4])
41 |     otf = torch.stack((otf, otf, otf), 0)
42 |     otf = torch.unsqueeze(otf, 0)
43 | 
44 |     conj_otf = utils.conj(otf)
45 | 
46 |     otf_img = utils.mul_complex(conj_otf, img_fft)
47 | 
48 |     denominator = abs_complex(otf)
49 |     denominator[:, :, :, :, 0] += K
50 |     product = utils.div_complex(otf_img, denominator)
51 |     filtered = utils.ifftshift(torch.ifft(product, 2))
52 |     filtered = torch.clamp(filtered, min=1e-5)
53 | 
54 |     return filtered[:,:,:,:,0]
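# For reference, wiener_filter above implements the regularized inverse filter
# X_hat = conj(H) * Y / (|H|^2 + K), where H is the OTF of the PSF and the
# damping factor K plays the role of a noise-to-signal ratio. Below is a
# minimal NumPy sketch of the same math on a single channel (illustration
# only; this helper is hypothetical and is not used anywhere in this repo):
def _wiener_filter_reference(y, h, K):
    H = np.fft.fft2(h, s=y.shape)  # OTF of the blur kernel, zero-padded to the image size
    Y = np.fft.fft2(y)
    return np.real(np.fft.ifft2(np.conj(H) * Y / (np.abs(H) ** 2 + K)))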
55 | 
56 | 
57 | def convolve_img(image, psf):
58 |     """Convolves an image with a PSF kernel, convolving each color channel
59 |     :param image: pytorch tensor of image (B,C,H,W)
60 |     :param psf: pytorch tensor of psf (H,W)
61 |     :return: final convolved image (B,C,H,W)
62 |     """
63 |     image = image.cpu()
64 |     psf = torch.stack((psf, psf, psf), 0)
65 |     psf = torch.unsqueeze(psf, 0)
66 |     psf_stack = utils.stack_complex(psf, torch.zeros(psf.shape))
67 |     img_stack = utils.stack_complex(image, torch.zeros(image.shape))
68 |     convolved = utils.conv_fft(img_stack, psf_stack, padval=0)
69 |     return convolved[:,:,:,:,0]
70 | 
71 | def circular_aperture(input_field, r_cutoff):
72 |     """
73 |     :param input_field: (H,W,2) - input field
74 |     :param r_cutoff: int or None - radius cutoff for the incoming light field
75 |     :return: light field filtered by the aperture
76 |     """
77 |     input_shape = input_field.shape
78 |     [x, y] = np.mgrid[-(input_shape[0] // 2): (input_shape[0] + 1) // 2,
79 |                       -(input_shape[1] // 2):(input_shape[1] + 1) // 2].astype(np.float64)
80 |     if r_cutoff is None:
81 |         r_cutoff = np.amax(x)
82 |     r = np.sqrt(x ** 2 + y ** 2)
83 |     aperture = (r < r_cutoff)
84 |     aperture = torch.Tensor(aperture)
85 |     aperture = utils.stack_complex(aperture, aperture)
86 |     return aperture * input_field
87 | 
88 | 
89 | def propagate_through_lens(input_field, phase_delay):
90 |     """
91 |     Provides the complex-valued wave field after hitting an optical element
92 |     :param input_field: (H,W) tensor of the incoming light field
93 |     :param phase_delay: (H,W) tensor of the phase delay of the optical element
94 |     :return: (H,W,2) complex-valued outgoing light field
95 |     """
96 |     real, imag = utils.polar_to_rect(1, phase_delay)
97 |     phase_delay = utils.stack_complex(real, imag)
98 | 
99 |     input_field = utils.stack_complex(input_field,
100 |                                       torch.zeros(input_field.shape))
101 |     return utils.mul_complex(input_field.cpu(), phase_delay.cpu())
102 | 
103 | 
104 | def heightmap_to_psf(hyps, height_map):
105 |     resolution = hyps['resolution']
106 |     focal_length = hyps['focal_length']
107 |     wavelength = hyps['wavelength']
108 |     pixel_pitch = hyps['pixel_pitch']
109 |     refractive_idc = hyps['refractive_idc']
110 |     r_cutoff = hyps['r_cutoff']
111 | 
112 |     input_field = torch.ones((resolution, resolution))
113 | 
114 |     phase_delay = utils.heightmap_to_phase(height_map,
115 |                                            wavelength,
116 |                                            refractive_idc)
117 | 
118 |     field = propagate_through_lens(input_field, phase_delay)
119 | 
120 |     field = circular_aperture(field, r_cutoff)
121 | 
122 |     # propagate field from aperture to sensor
123 |     element = 
Propagation(kernel_type='fresnel', 124 | propagation_distances=focal_length, 125 | slm_resolution=[resolution, resolution], 126 | slm_pixel_pitch=[pixel_pitch, pixel_pitch], 127 | wavelength=wavelength) 128 | field = element.forward(field) 129 | psf = utils.field_to_intensity(field) 130 | psf /= psf.sum() 131 | return psf.to(DEVICE) 132 | 133 | 134 | def fspecial_gauss(size, sigma): 135 | """ 136 | Function to mimic the 'fspecial' gaussian MATLAB function 137 | :param size: int - size of blur filter 138 | :param sigma: float - standard deviation of blur 139 | :return: normalized blur filter 140 | """ 141 | x, y = np.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1] 142 | g = np.exp(-((x**2 + y**2)/(2.0*sigma**2))) 143 | return torch.Tensor(g/g.sum()) 144 | 145 | 146 | def heightmap_initializer(focal_length, 147 | resolution=1248, 148 | pixel_pitch=6.4e-6, 149 | refractive_idc=1.43, 150 | wavelength=530e-9, 151 | init_lens='fresnel'): 152 | """ 153 | Initialize heightmap before training 154 | :param focal_length: float - distance between phase mask and sensor 155 | :param resolution: int - size of phase mask 156 | :param pixel_pitch: float - pixel size of phase mask 157 | :param refractive_idc: float - refractive index of phase mask 158 | :param wavelength: float - wavelength of light 159 | :param init_lens: str - type of lens to initialize 160 | :return: height map 161 | """ 162 | if init_lens == 'fresnel' or init_lens == 'plano': 163 | convex_radius = (refractive_idc - 1.) * focal_length # based on lens maker formula 164 | 165 | N = resolution 166 | M = resolution 167 | [x, y] = np.mgrid[-(N // 2): (N + 1) // 2, 168 | -(M // 2):(M + 1) // 2].astype(np.float64) 169 | 170 | x = x * pixel_pitch 171 | y = y * pixel_pitch 172 | 173 | # get lens thickness by paraxial approximations 174 | heightmap = -(x ** 2 + y ** 2) / 2. * (1. 
/ convex_radius)
175 |         if init_lens == 'fresnel':
176 |             phases = utils.heightmap_to_phase(heightmap, wavelength, refractive_idc)
177 |             fresnel = simple_to_fresnel_lens(phases)
178 |             heightmap = utils.phase_to_heightmap(fresnel, wavelength, refractive_idc)
179 | 
180 |     elif init_lens == 'flat':
181 |         heightmap = torch.ones((resolution, resolution)) * 0.0001
182 |     else:
183 |         heightmap = torch.rand((resolution, resolution)) * pixel_pitch
184 |         gauss_filter = fspecial_gauss(10, 5)
185 | 
186 |         heightmap = utils.stack_complex(heightmap, torch.zeros_like(heightmap))  # real input, zero imaginary part
187 |         gauss_filter = utils.stack_complex(gauss_filter, torch.zeros_like(gauss_filter))
188 |         heightmap = utils.conv_fft(heightmap, gauss_filter)
189 |         heightmap = heightmap[:,:,0]
190 | 
191 |     return torch.Tensor(heightmap)
192 | 
193 | 
194 | def psf2otf(input_filter, output_size):
195 |     """
196 |     Converts a PSF to an OTF with the same size as output_size
197 |     :param input_filter: (H,W) PSF
198 |     :param output_size: [int, int] - size of the output filter
199 |     :return: OTF (H,W)
200 |     """
201 |     fh, fw = input_filter.shape
202 | 
203 |     padder = torch.nn.ZeroPad2d((0, output_size[1]-fw, 0, output_size[0]-fh))
204 |     padded_filter = padder(input_filter)
205 | 
206 |     # shift left
207 |     left = padded_filter[:,0:(fw-1)//2]
208 |     right = padded_filter[:,(fw-1)//2:]
209 |     padded = torch.cat([right, left], 1)
210 | 
211 |     # shift down
212 |     up = padded[0:(fh-1)//2,:]
213 |     down = padded[(fh-1)//2:,:]
214 |     padded = torch.cat([down, up], 0)
215 | 
216 |     tmp = utils.stack_complex(padded.to(DEVICE), torch.zeros(padded.shape).to(DEVICE))
217 |     tmp = torch.fft(tmp, 2)
218 |     return tmp.to(DEVICE)
219 | 
220 | 
221 | def abs_complex(input_field):
222 |     """
223 |     Computes the squared magnitude |z|^2 of a complex input field
224 |     :param input_field: tensor of size (B,C,H,W,2), last dimension is
225 |     real and imag
226 |     :return: squared magnitude |z|^2 as a (B,C,H,W,2) tensor with zero imaginary part
227 |     """
228 |     real, imag = utils.unstack_complex(input_field)
229 |     real = real ** 2 + imag ** 2
230 |     imag = torch.zeros(real.shape)
231 |     return utils.stack_complex(real.to(DEVICE), imag.to(DEVICE))
232 | 
233 | 
234 | def simple_to_fresnel_lens(phase_delay):
235 |     """
236 |     Converts a plano-convex lens phase delay to a Fresnel phase delay
237 |     through 2*pi phase wrapping
238 |     :param phase_delay: (H,W) phase delay of a plano-convex lens
239 |     :return: phase delay of a Fresnel lens
240 |     """
241 |     phase_delay -= phase_delay.min()
242 |     return (phase_delay) % (2 * np.pi) - 2 * np.pi
--------------------------------------------------------------------------------
/params.json:
--------------------------------------------------------------------------------
1 | {
2 |     "data_root": "./data",
3 |     "logging_root": "./runs",
4 |     "train_test": "train",
5 |     "exp_name": "learn-doe",
6 |     "checkpoint": null,
7 |     "max_epoch": 100,
8 |     "lr": 1e-8,
9 |     "batch_size": 4,
10 |     "reg_weight": 0.0,
11 |     "init_K": 0.01,
12 |     "use_wiener": false,
13 |     "resolution": 512,
14 |     "pixel_pitch": 2e-6,
15 |     "focal_length": 0.1e-2,
16 |     "r_cutoff": null,
17 |     "refractive_idc": 1.4349,
18 |     "wavelength": 530e-9,
19 |     "init_lens": "fresnel",
20 |     "single_image": false,
21 |     "download_data": false
22 | }
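A quick sanity check one can do with these values: under the thin-element model that `utils.heightmap_to_phase` appears to implement (phi = 2*pi*(n - 1)*h / wavelength; the exact form is an assumption here), the height needed for a full 2*pi of phase delay is:

```python
wavelength = 530e-9      # from params.json
refractive_idc = 1.4349  # from params.json

# Height producing 2*pi of phase delay: h = wavelength / (n - 1)
h_2pi = wavelength / (refractive_idc - 1.0)
print(f"{h_2pi * 1e6:.2f} um")  # ~1.22 um, the scale of the Fresnel lens steps
```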
--------------------------------------------------------------------------------
/propagation.py:
--------------------------------------------------------------------------------
1 | """Functions for propagation through free space
2 | 
3 | Propagation class initialization options:
4 |     kernel_type: 'fraunhofer' (alias 'fourier'), 'fresnel', 'fresnel_conv',
5 |         'asm' (alias 'angular_spectrum'), or 'kirchhoff'. The transfer
6 |         function approaches may be more accurate
7 |         fraunhofer: far-field diffraction, purely a Fourier transform
8 |         fresnel: near-field diffraction with Fresnel approximation, implemented
9 |             as a multiplication with a transfer function in Fourier domain
10 |         fresnel_conv: same as fresnel, but implemented as a convolution with a
11 |             spatial kernel, via FFT conv for speed
12 |         asm: near-field diffraction with the Angular Spectrum Method,
13 |             implemented as a transfer function. Note that this may have a 1px
14 |             shift relative to the others due to the source paper padding the
15 |             input by an extra pixel (for linear convolution) for derivations
16 |         kirchhoff: near-field diffraction with the Kirchhoff equations,
17 |             implemented with a spatial kernel
18 |     propagation_distances: distance or distances from SLM to image plane.
19 |         Accepts scalars or lists.
20 |     slm_resolution: number of pixels on SLM
21 |     slm_pixel_pitch: size of pixels on SLM.
22 |     image_resolution: number of sampling locations at image plane (optional,
23 |         default matches SLM resolution)
24 |     wavelength: laser wavelength, (optional, default 532e-9).
25 |     propagation_parameters: override parameters for kernel/transfer function
26 |         construction. Optional. Possible parameters, with
27 |         defaults given:
28 |         # for all methods
29 |         'padding_type', 'zero': pad complex field with 'median' or 'zero'.
30 |                                 Using median may have less ringing, but zero
31 |                                 is probably more accurate
32 |         # for the spatial kernel convolution methods
33 |         'circular_prop_mask', True: circular mask for propagation kernels, for
34 |                                     bandlimiting the phase function
35 |         'apodize_kernel', True: smooth the circular mask
36 |         'apodization_width', 50: width of cosine dropoff at edge, in pixels
37 |         'prop_mask_fraction', 1: artificially reduces the size of propagation
38 |                                  mask (e.g., 2 will use half the radius)
39 |         'normalize_output', True: forces output field to have the same average
40 |                                   amplitudes as the input when True. Only valid
41 |                                   when using a single propagation distance
42 |         # for the transfer function multiplication methods
43 |         'circular_padding', False: doesn't pad the field when True, resulting in
44 |                                    implicit circular padding in the Fourier
45 |                                    domain for the input field. May reduce
46 |                                    ringing at the edges
47 |         'normalize_output', False: same as for the spatial kernel methods, but
48 |                                    defaults to False because the transfer
49 |                                    functions do a better job at energy
50 |                                    preservation by default
51 |         # only for the Angular Spectrum Method
52 |         'extra_pixel', True: when not using circular_padding, i.e., for a linear
53 |                              convolution, pad one extra pixel more than required
54 |                              (i.e., linear conv to length a + b instead of the
55 |                              minimum valid a + b - 1). The derivation from
56 |                              Matsushima and Shimobaba (2009) has an extra pixel,
57 |                              may not be correct without it, but set if the pixel
58 |                              shift is important
59 |         # only for Fraunhofer
60 |         'fraunhofer_crop_image', True: when resolution changes, crop image
61 |                                        plane instead of SLM plane, details in
62 |                                        __init__ for FraunhoferPropagation
63 |         # only for Fraunhofer with multiple distances
64 |         'focal_length', no default: required to determine plane for Fourier
65 |                                     relationship (e.g., lens focal length)
66 |                                     relative to which the other distances are
67 |                                     propagated.
68 |     device: torch parameter for the device to place the convolution kernel on.
69 |         If not given, will default to the device of the input_field.
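    Example: Constructing a valid input_field with the stacked [real, imag]
        convention described below, then propagating it (an illustrative
        sketch; the field values here are arbitrary)
            import torch
            field = torch.zeros(1080, 1920, 2)
            field[..., 0] = 1  # unit amplitude, zero phase
            prop = Propagation('fresnel', 10e-2, [1080, 1920], [5e-6, 5e-6])
            out = prop.forward(field)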
70 | 
71 | Propagation.forward and Propagation.backward:
72 |     input_field: complex field at starting plane (e.g. SLM for forward)
73 | 
74 |     Returns: output_field at the ending plane matching the specified resolution
75 |         (for single distance) or output_fields, a dictionary of fields at
76 |         each propagation distance (keys are distances)
77 | 
78 | All units are in meters and radians unless explicitly stated as otherwise.
79 | Terms for resolution are in ij (matrix) order, not xy (cartesian) order.
80 | 
81 | input_field should be a torch Tensor, everything else can be either numpy or
82 | native python types. input_field is assumed to be a stack of [real, imag] for
83 | input to the fft (see the torch.fft implementation for details). The
84 | output_field follows the same convention.
85 | 
86 | Example: Propagate some input_field by 10cm with Fresnel approx, 5um pixel pitch
87 |     on the SLM, with a 1080p SLM and image size equal to it
88 |         prop = Propagation('fresnel', 10e-2, [1080, 1920], [5e-6, 5e-6])
89 |         output_field = prop.forward(input_field)
90 |         output_field = prop.backward(input_field)
91 | 
92 | Example: Propagate some input_field to multiple distances, using Kirchhoff
93 |     propagation.
94 |         prop = Propagation('kirchhoff', [10e-2, 20e-2, 30e-2], [1080, 1920],
95 |                            [5e-6, 5e-6])
96 | 
97 | Example: Setting non-default parameters, e.g. wavelength of 632nm, image
98 |     resolution of 720p, some of the extra propagation
99 |     parameters, or device to gpu 0
100 |         propagation_parameters = {'circular_prop_mask': True,
101 |                                   'apodize_kernel': True}
102 |         prop = Propagation('fresnel', 10e-2, [1080, 1920], [5e-6, 5e-6],
103 |                            [720, 1280], 632e-9,
104 |                            propagation_parameters, torch.device('cuda:0'))
105 |         # or with named parameters
106 |         prop = Propagation(kernel_type='fresnel',
107 |                            propagation_distances=10e-2,
108 |                            slm_resolution=[1080, 1920],
109 |                            slm_pixel_pitch=[5e-6, 5e-6],
110 |                            image_resolution=[720, 1280],
111 |                            wavelength=632e-9,
112 |                            propagation_parameters=propagation_parameters,
113 |                            device=torch.device('cuda:0'))
114 | 
115 | Example: Other propagation kernels, alternate ways to define it
116 |     prop = Propagation('Fresnel', ...)  # not case sensitive
117 |     prop = Propagation('fraunhofer', ...)  # Fraunhofer
118 |     prop = Propagation('asm', ...)
# Angular Spectrum Method 119 | 120 | Author: Nitish Padmanaban 121 | """ 122 | 123 | import numpy as np 124 | from scipy.signal import fftconvolve 125 | import torch 126 | import torch.nn as nn 127 | import warnings 128 | import utils 129 | 130 | 131 | class Propagation: 132 | """Convenience class for using different propagation kernels and sets of 133 | propagation distances""" 134 | def __new__(cls, kernel_type, propagation_distances, slm_resolution, 135 | slm_pixel_pitch, image_resolution=None, wavelength=532e-9, 136 | propagation_parameters=None, device=None): 137 | # process input types for propagation distances 138 | if isinstance(propagation_distances, (np.ndarray, torch.Tensor)): 139 | propagation_distances = propagation_distances.flatten().tolist() 140 | # singleton lists should be made into scalars 141 | if (isinstance(propagation_distances, (tuple, list)) 142 | and len(propagation_distances) == 1): 143 | propagation_distances = propagation_distances[0] 144 | 145 | # scalar means this is a single distance propagation 146 | if not isinstance(propagation_distances, (tuple, list)): 147 | cls_out = {'fresnel': FresnelPropagation, 148 | 'fresnel_conv': FresnelConvPropagation, 149 | 'asm': AngularSpectrumPropagation, 150 | 'angular_spectrum': AngularSpectrumPropagation, 151 | 'kirchhoff': KirchhoffPropagation, 152 | 'fraunhofer': FraunhoferPropagation, 153 | 'fourier': FraunhoferPropagation}[kernel_type.lower()] 154 | return cls_out(propagation_distances, slm_resolution, 155 | slm_pixel_pitch, image_resolution, wavelength, 156 | propagation_parameters, device) 157 | else: 158 | return MultiDistancePropagation( 159 | kernel_type, propagation_distances, slm_resolution, 160 | slm_pixel_pitch, image_resolution, wavelength, 161 | propagation_parameters, device) 162 | 163 | 164 | class PropagationBase(nn.Module): 165 | image_native_pitch = None 166 | 167 | """Interface for propagation functions, with some shared functions""" 168 | def __init__(self, propagation_distance, slm_resolution, slm_pixel_pitch, 169 | image_resolution=None, wavelength=532e-9, 170 | propagation_parameters=None, device=None): 171 | super().__init__() 172 | self.slm_resolution = np.array(slm_resolution) 173 | self.slm_pixel_pitch = np.array(slm_pixel_pitch) 174 | self.propagation_distance = propagation_distance 175 | self.wavelength = wavelength 176 | self.dev = device 177 | 178 | # default image dimensions to slm dimensions 179 | if image_resolution is None: 180 | self.image_resolution = self.slm_resolution 181 | else: 182 | self.image_resolution = np.array(image_resolution) 183 | 184 | # native image sampling matches slm pitch, unless overridden by a 185 | # deriving class (e.g. 
FraunhoferPropagation) 186 | if self.image_native_pitch is None: 187 | self.image_native_pitch = self.slm_pixel_pitch 188 | 189 | # set image pixel pitch to native image sampling 190 | self.image_pixel_pitch = self.image_native_pitch 191 | 192 | # physical size of planes in meters 193 | self.slm_size = self.slm_pixel_pitch * self.slm_resolution 194 | self.image_size = self.image_pixel_pitch * self.image_resolution 195 | 196 | # dictionary for extra parameters particular to base class 197 | self.propagation_parameters = propagation_parameters 198 | if self.propagation_parameters is None: 199 | self.propagation_parameters = {} 200 | 201 | # set default for padding type when convolving 202 | try: 203 | self.padding_type = self.propagation_parameters.pop('padding_type') 204 | except KeyError: 205 | self.padding_type = 'zero' 206 | 207 | def forward(self, input_field): 208 | """Returns output_field, which is input_field propagated by 209 | propagation_distance, from slm_resolution to image_resolution""" 210 | raise NotImplementedError('Must implement in derived class') 211 | 212 | def backward(self, input_field): 213 | """Returns output_field, which is input_field propagated by 214 | -propagation_distance, from image_resolution to slm_resolution""" 215 | raise NotImplementedError('Must implement in derived class') 216 | 217 | def to(self, *args, **kwargs): 218 | """Moves non-parameter tensors needed for propagation to device 219 | 220 | Also updates the internal self.dev added to this class 221 | """ 222 | slf = super().to(*args, **kwargs) 223 | 224 | device_arg = torch._C._nn._parse_to(*args, **kwargs)[0] 225 | if device_arg is not None: 226 | slf.dev = device_arg 227 | 228 | return slf 229 | 230 | def pad_smaller_dims(self, field, target_shape, pytorch=True, padval=None): 231 | if padval is None: 232 | padval = self.get_pad_value(field, pytorch) 233 | return utils.pad_smaller_dims(field, target_shape, pytorch, 234 | padval=padval) 235 | 236 | def crop_larger_dims(self, field, target_shape, pytorch=True): 237 | return utils.crop_larger_dims(field, target_shape, pytorch) 238 | 239 | def get_pad_value(self, field, pytorch=True): 240 | if self.padding_type == 'zero': 241 | return 0 242 | elif self.padding_type == 'median': 243 | if pytorch: 244 | return torch.median(stacked_abs(field)) 245 | else: 246 | return np.median(np.abs(field)) 247 | else: 248 | raise ValueError('Unknown padding type') 249 | 250 | 251 | class NearFieldConvPropagationBase(PropagationBase): 252 | """Defines functions shared across propagation near field approximations 253 | based on convolving a kernel 254 | """ 255 | def __init__(self, propagation_distance, slm_resolution, slm_pixel_pitch, 256 | image_resolution=None, wavelength=532e-9, 257 | propagation_parameters=None, device=None): 258 | super().__init__(propagation_distance, slm_resolution, slm_pixel_pitch, 259 | image_resolution, wavelength, propagation_parameters, 260 | device) 261 | # diffraction pattern calculations 262 | self.max_diffraction_angle = np.arcsin(wavelength 263 | / self.slm_pixel_pitch / 2) 264 | self.prop_mask_radius = (propagation_distance 265 | * np.tan(self.max_diffraction_angle)) 266 | 267 | # limit zone plate to maximum usable size 268 | slm_diagonal = np.sqrt((self.slm_size**2).sum()) 269 | image_diagonal = np.sqrt((self.image_size**2).sum()) 270 | max_usable_distance = slm_diagonal / 2 + image_diagonal / 2 271 | self.prop_mask_radius = np.minimum(self.prop_mask_radius, 272 | max_usable_distance) 273 | 274 | # force input and output of 
forward/backward
275 |         # operations to have the same absolute sum
276 |         try:
277 |             self.normalize_output = self.propagation_parameters.pop(
278 |                 'normalize_output')
279 |         except KeyError:
280 |             self.normalize_output = True
281 | 
282 |         # sets self.forward_kernel and self.backward_kernel
283 |         self.compute_conv_kernels(**self.propagation_parameters)
284 | 
285 |         if self.dev is not None:
286 |             self.forward_kernel = self.forward_kernel.to(self.dev)
287 |             self.backward_kernel = self.backward_kernel.to(self.dev)
288 | 
289 |     def compute_conv_kernels(self, *, circular_prop_mask=True,
290 |                              apodize_kernel=True, apodization_width=50,
291 |                              prop_mask_fraction=1., **kwargs):
292 |         # sampling positions along the x and y dims
293 |         coords_x = np.arange(self.slm_pixel_pitch[1],
294 |                              self.prop_mask_radius[1] / prop_mask_fraction,
295 |                              self.slm_pixel_pitch[1])
296 |         coords_x = np.concatenate((-coords_x[::-1], [0], coords_x))
297 |         coords_y = np.arange(self.slm_pixel_pitch[0],
298 |                              self.prop_mask_radius[0] / prop_mask_fraction,
299 |                              self.slm_pixel_pitch[0])
300 |         coords_y = np.concatenate((-coords_y[::-1], [0], coords_y))
301 | 
302 |         samples_x, samples_y = np.meshgrid(coords_x, coords_y)
303 | 
304 |         # compute complex forward propagation at sampled points
305 |         forward = self.forward_prop_at_points(samples_x, samples_y)
306 | 
307 |         if circular_prop_mask:
308 |             forward = self.apply_circular_mask(forward,
309 |                                                np.sqrt(samples_x**2
310 |                                                        + samples_y**2),
311 |                                                apodize_kernel,
312 |                                                apodization_width)
313 | 
314 |         # rescale for approx energy preservation even when normalization off
315 |         # forward *= self.wavelength / np.sum(self.slm_resolution)
316 |         forward /= np.sum(np.abs(forward))
317 | 
318 |         # convert to stacked real and imaginary for pytorch fft format
319 |         forward_stacked = np.stack((np.real(forward), np.imag(forward)), -1)
320 |         self.forward_kernel = torch.from_numpy(forward_stacked).float()
321 |         # reverse prop is just the conjugate
322 |         backward_stacked = np.stack((np.real(forward), -np.imag(forward)), -1)
323 |         self.backward_kernel = torch.from_numpy(backward_stacked).float()
324 | 
325 |     def forward_prop_at_points(self, samples_x, samples_y):
326 |         """computes the convolution kernel for the deriving class's
327 |         particular approximation
328 |         """
329 |         raise NotImplementedError('Must implement in derived class')
330 | 
331 |     def apply_circular_mask(self, pattern, distances, apodize=True,
332 |                             apodization_width=50):
333 |         # furthest point along smaller dimension, max usable radius
334 |         max_radius = min(distances[0, :].min(), distances[:, 0].min())
335 | 
336 |         if apodize:
337 |             # set the width of apodization based on the wider pixel pitch
338 |             pixel_pitch = max(self.slm_pixel_pitch)
339 |             apodization_width *= pixel_pitch
340 | 
341 |             if apodization_width > max_radius:
342 |                 apodization_width = max_radius
343 | 
344 |             # ramp that rises to 1 over a length of apodization_width
345 |             normalized_edge_dist = (max_radius - distances) / apodization_width
346 |             normalized_edge_dist = normalized_edge_dist.clip(min=0, max=1)
347 | 
348 |             # convert ramp to smooth cos
349 |             mask = 1 / 2 + np.cos(np.pi * normalized_edge_dist - np.pi) / 2
350 |             mask /= mask.max()  # make sure it's max 1, probably not needed
351 |         else:
352 |             mask = (distances <= max_radius).astype(np.float64)
353 | 
354 |         return pattern * mask
355 | 
356 |     def forward(self, input_field):
357 |         # force kernel device to input's device if this module specifies nothing
358 |         if (self.dev is None
359 |                 and self.forward_kernel.device != input_field.device):
360 |             self.forward_kernel = self.forward_kernel.to(input_field.device)
361 | 
362 |         if self.normalize_output:
363 |             input_magnitude_sum = magnitude_sum(input_field)
364 | 
365 |         padval = self.get_pad_value(input_field)
366 |         input_padded = self.pad_smaller_dims(input_field, self.image_resolution,
367 |                                              padval=padval)
368 |         output_field = utils.conv_fft(input_padded, self.forward_kernel,
369 |                                       padval=padval)
370 |         output_cropped = self.crop_larger_dims(output_field,
371 |                                                self.image_resolution)
372 |         if self.normalize_output:
373 |             output_magnitude_sum = magnitude_sum(output_cropped)
374 |             output_cropped = output_cropped * (input_magnitude_sum
375 |                                                / output_magnitude_sum)
376 | 
377 |         return output_cropped
378 | 
379 |     def backward(self, input_field):
380 |         # force kernel device to input's device if this module specifies nothing
381 |         if (self.dev is None
382 |                 and self.backward_kernel.device != input_field.device):
383 |             self.backward_kernel = self.backward_kernel.to(input_field.device)
384 | 
385 |         if self.normalize_output:
386 |             input_magnitude_sum = magnitude_sum(input_field)
387 | 
388 |         padval = self.get_pad_value(input_field)
389 |         input_padded = self.pad_smaller_dims(input_field, self.slm_resolution,
390 |                                              padval=padval)
391 |         output_field = utils.conv_fft(input_padded, self.backward_kernel,
392 |                                       padval=padval)
393 |         output_cropped = self.crop_larger_dims(output_field,
394 |                                                self.slm_resolution)
395 |         if self.normalize_output:
396 |             output_magnitude_sum = magnitude_sum(output_cropped)
397 |             output_cropped = output_cropped * (input_magnitude_sum
398 |                                                / output_magnitude_sum)
399 | 
400 |         return output_cropped
401 | 
402 |     def to(self, *args, **kwargs):
403 |         slf = super().to(*args, **kwargs)
404 | 
405 |         if slf.dev is not None:
406 |             slf.forward_kernel = slf.forward_kernel.to(slf.dev)
407 |             slf.backward_kernel = slf.backward_kernel.to(slf.dev)
408 | 
409 |         return slf
410 | 
411 | 
412 | class NearFieldTransferFnPropagationBase(PropagationBase):
413 |     """Defines functions shared across propagation near field approximations
414 |     based on applying the transfer function in Fourier domain
415 |     """
416 |     def __init__(self, propagation_distance, slm_resolution, slm_pixel_pitch,
417 |                  image_resolution=None, wavelength=532e-9,
418 |                  propagation_parameters=None, device=None):
419 |         super().__init__(propagation_distance, slm_resolution, slm_pixel_pitch,
420 |                          image_resolution, wavelength, propagation_parameters,
421 |                          device)
422 |         # force input and output of forward/backward
423 |         # operations to have the same absolute sum
424 |         try:
425 |             self.normalize_output = self.propagation_parameters.pop(
426 |                 'normalize_output')
427 |         except KeyError:
428 |             self.normalize_output = False
429 | 
430 |         # sets self.forward_transfer_fn and self.backward_transfer_fn
431 |         self.compute_transfer_fn(**self.propagation_parameters)
432 | 
433 |         if self.dev is not None:
434 |             self.forward_transfer_fn = self.forward_transfer_fn.to(self.dev)
435 |             self.backward_transfer_fn = self.backward_transfer_fn.to(self.dev)
436 | 
437 |     def compute_transfer_fn(self, *, circular_padding=False, **kwargs):
438 |         """computes the Fourier transfer function for the deriving class's
439 |         particular approximation
440 |         """
441 |         raise NotImplementedError('Must implement in derived class')
442 | 
443 |     def forward(self, input_field):
444 |         # force transfer function device to input's device if this module
445 |         # specifies nothing
446 |         if (self.dev is None
447 |                 and self.forward_transfer_fn.device != input_field.device):
448 | 
self.forward_transfer_fn = self.forward_transfer_fn.to( 449 | input_field.device) 450 | 451 | if self.normalize_output: 452 | input_magnitude_sum = magnitude_sum(input_field) 453 | 454 | # compute Fourier transform of input field 455 | fourier_input = self.padded_fft(input_field) 456 | 457 | # apply transfer function for forward prop 458 | fourier_output = utils.mul_complex(fourier_input, 459 | self.forward_transfer_fn) 460 | 461 | # Fourier transform back to get output 462 | output_cropped = self.cropped_ifft(fourier_output, 463 | self.image_resolution) 464 | 465 | if self.normalize_output: 466 | output_magnitude_sum = magnitude_sum(output_cropped) 467 | output_cropped = output_cropped * (input_magnitude_sum 468 | / output_magnitude_sum) 469 | 470 | return output_cropped 471 | 472 | def backward(self, input_field): 473 | # force transfer function device to input's device if this module 474 | # specifies nothing 475 | if (self.dev is None 476 | and self.backward_transfer_fn.device != input_field.device): 477 | self.backward_transfer_fn = self.backward_transfer_fn.to( 478 | input_field.device) 479 | 480 | if self.normalize_output: 481 | input_magnitude_sum = magnitude_sum(input_field) 482 | 483 | # compute Fourier transform of input field 484 | fourier_input = self.padded_fft(input_field) 485 | 486 | # apply transfer function for backward prop 487 | fourier_output = utils.mul_complex(fourier_input, 488 | self.backward_transfer_fn) 489 | 490 | # Fourier transform back to get output 491 | output_cropped = self.cropped_ifft(fourier_output, self.slm_resolution) 492 | 493 | if self.normalize_output: 494 | output_magnitude_sum = magnitude_sum(output_cropped) 495 | output_cropped = output_cropped * (input_magnitude_sum 496 | / output_magnitude_sum) 497 | 498 | return output_cropped 499 | 500 | def padded_fft(self, input_field): 501 | input_padded = self.pad_smaller_dims(input_field, self.conv_resolution) 502 | return utils.fft(input_padded) 503 | 504 | def cropped_ifft(self, fourier_output, output_res): 505 | output_field = utils.ifft(fourier_output) 506 | return self.crop_larger_dims(output_field, output_res) 507 | 508 | def to(self, *args, **kwargs): 509 | slf = super().to(*args, **kwargs) 510 | 511 | if slf.dev is not None: 512 | slf.forward_transfer_fn = slf.forward_transfer_fn.to(slf.dev) 513 | slf.backward_transfer_fn = slf.backward_transfer_fn.to(slf.dev) 514 | 515 | return slf 516 | 517 | 518 | class FresnelConvPropagation(NearFieldConvPropagationBase): 519 | """Implements the Fresnel approximation for the kernel""" 520 | def forward_prop_at_points(self, samples_x, samples_y): 521 | # prevent 0 522 | if abs(self.propagation_distance) < 1e-10: 523 | prop_dist = -1e-10 if self.propagation_distance < 0 else 1e-10 524 | else: 525 | prop_dist = self.propagation_distance 526 | wave_number = 2 * np.pi / self.wavelength 527 | 528 | # exclude propagation_distance for zero phase at center 529 | phase_term = ((samples_x**2 + samples_y**2) / (2 * prop_dist)) 530 | # ignore 1/j term 531 | amplitude_term = 1 / prop_dist / self.wavelength 532 | return amplitude_term * np.exp(1j * wave_number * phase_term) 533 | 534 | 535 | class KirchhoffPropagation(NearFieldConvPropagationBase): 536 | """Implements the Kirchhoff approximation for the kernel""" 537 | def forward_prop_at_points(self, samples_x, samples_y): 538 | # prevent 0 539 | if abs(self.propagation_distance) < 1e-10: 540 | prop_dist = -1e-10 if self.propagation_distance < 0 else 1e-10 541 | else: 542 | prop_dist = self.propagation_distance 543 | 
wave_number = 2 * np.pi / self.wavelength
544 | 
545 |         radius = np.sqrt(prop_dist**2 + samples_x**2 + samples_y**2)
546 |         phase_term = radius - prop_dist  # zero phase at center
547 |         # ignore 1/j term
548 |         amplitude_term = prop_dist / self.wavelength / radius**2
549 |         return amplitude_term * np.exp(1j * wave_number * phase_term)
550 | 
551 | 
552 | class FresnelPropagation(NearFieldTransferFnPropagationBase):
553 |     """Implements the Fresnel approximation for the transfer function"""
554 |     def compute_transfer_fn(self, *, circular_padding=False, **kwargs):
555 |         # we always convolve at the size of the larger dimensions
556 |         self.conv_resolution = np.maximum(self.slm_resolution,
557 |                                           self.image_resolution)
558 |         # for linear convolution, otherwise the input
559 |         # field is implicitly circularly padded
560 |         if not circular_padding:
561 |             self.conv_resolution = self.conv_resolution * 2 - 1
562 |         # physical dimensions
563 |         self.conv_size = self.slm_pixel_pitch * self.conv_resolution
564 | 
565 |         # sampling positions along the x and y dims
566 |         min_coords = -1 / (2 * self.slm_pixel_pitch)
567 |         max_coords = 1 / (2 * self.slm_pixel_pitch) - 1 / self.conv_size
568 | 
569 |         coords_fx = np.linspace(min_coords[1],
570 |                                 max_coords[1],
571 |                                 self.conv_resolution[1])
572 |         coords_fy = np.linspace(min_coords[0],
573 |                                 max_coords[0],
574 |                                 self.conv_resolution[0])
575 | 
576 |         samples_fx, samples_fy = np.meshgrid(coords_fx, coords_fy)
577 | 
578 |         forward_phases = (np.pi * -self.propagation_distance * self.wavelength
579 |                           * (samples_fx**2 + samples_fy**2))
580 | 
581 |         forward = np.exp(1j * forward_phases)
582 | 
583 |         # convert to stacked real and imaginary for pytorch fft format
584 |         forward_stacked = np.stack((np.real(forward), np.imag(forward)), -1)
585 |         self.forward_transfer_fn = torch.from_numpy(forward_stacked).float()
586 |         # reverse prop is just the conjugate
587 |         backward_stacked = np.stack((np.real(forward), -np.imag(forward)), -1)
588 |         self.backward_transfer_fn = torch.from_numpy(backward_stacked).float()
589 | 
590 | 
591 | class AngularSpectrumPropagation(NearFieldTransferFnPropagationBase):
592 |     """Implements the angular spectrum method for the transfer function"""
593 |     def compute_transfer_fn(self, *, circular_padding=False, extra_pixel=True,
594 |                             **kwargs):
595 |         # we always convolve at the size of the larger dimensions
596 |         self.conv_resolution = np.maximum(self.slm_resolution,
597 |                                           self.image_resolution)
598 |         # for linear convolution, otherwise the input
599 |         # field is implicitly circularly padded
600 |         if not circular_padding:
601 |             self.conv_resolution *= 2
602 |             # Note: Matsushima and Shimobaba (2009) only discuss 2x padding,
603 |             # unclear if this is correct without the extra pixel
604 |             if not extra_pixel:
605 |                 self.conv_resolution -= 1
606 |         # physical dimensions
607 |         self.conv_size = self.slm_pixel_pitch * self.conv_resolution
608 | 
609 |         # sampling positions along the x and y dims
610 |         max_coords = 1 / (2 * self.slm_pixel_pitch) - 0.5 / (2 * self.conv_size)
611 |         coords_fx = np.linspace(-max_coords[1],
612 |                                 max_coords[1],
613 |                                 self.conv_resolution[1])
614 |         coords_fy = np.linspace(-max_coords[0],
615 |                                 max_coords[0],
616 |                                 self.conv_resolution[0])
617 | 
618 |         samples_fx, samples_fy = np.meshgrid(coords_fx, coords_fy)
619 | 
620 |         forward_phases = (2 * np.pi * self.propagation_distance
621 |                           * np.sqrt(self.wavelength**-2 - (samples_fx**2
622 |                                                            + samples_fy**2)))
623 | 
624 |         # bandlimit the transfer function, Matsushima and Shimobaba (2009)
625 |         f_max = 1 / 
np.sqrt((2 * self.propagation_distance / self.conv_size)**2 626 | + 1) / self.wavelength 627 | freq_support = ((np.abs(samples_fx) < f_max[1]) 628 | & (np.abs(samples_fy) < f_max[0])) 629 | 630 | forward = freq_support * np.exp(1j * forward_phases) 631 | 632 | # convert to stacked real and imaginary for pytorch fft format 633 | forward_stacked = np.stack((np.real(forward), np.imag(forward)), -1) 634 | self.forward_transfer_fn = torch.from_numpy(forward_stacked).float() 635 | # reverse prop is just the conjugate 636 | backward_stacked = np.stack((np.real(forward), -np.imag(forward)), -1) 637 | self.backward_transfer_fn = torch.from_numpy(backward_stacked).float() 638 | 639 | 640 | class FraunhoferPropagation(PropagationBase): 641 | """Implements Fraunhofer propagation, where lens focal length is given by 642 | propagation_distance""" 643 | def __init__(self, propagation_distance, slm_resolution, slm_pixel_pitch, 644 | image_resolution=None, wavelength=532e-9, 645 | propagation_parameters=None, device=None): 646 | # Fraunhofer propagation has a different native resolution at image 647 | # plane, defined by the transform relating the SLM and image planes. It 648 | # uses frequencies of x/(lambda*f), which changes the sampling density 649 | self.focal_length = propagation_distance 650 | # extent of slm 651 | slm_bandwidth = np.array(slm_pixel_pitch) * np.array(slm_resolution) 652 | slm_fourier_sampling = 1 / slm_bandwidth 653 | self.image_native_pitch = (slm_fourier_sampling * wavelength 654 | * self.focal_length) 655 | 656 | super().__init__(propagation_distance, slm_resolution, slm_pixel_pitch, 657 | image_resolution, wavelength, propagation_parameters, 658 | device) 659 | 660 | # for Fraunhofer propagation, etendue fixes the output physical 661 | # dimensions based on SLM pixel pitch. For a bigger image resolution, we 662 | # just pad the SLM field before propagation. For a smaller image, we can 663 | # either crop first to use part of the SLM to produce a low resolution, 664 | # but full physical size output, or crop after to use less of the 665 | # physical area, but keep the high resolution by using the full SLM. 
666 |         # Default is to crop the image so that we have more degrees of freedom
667 |         # on the SLM
668 |         try:
669 |             self.fraunhofer_crop_image = self.propagation_parameters.pop(
670 |                 'fraunhofer_crop_image')
671 |         except KeyError:
672 |             self.fraunhofer_crop_image = True
673 | 
674 |     def forward(self, input_field):
675 |         input_padded = self.pad_smaller_dims(input_field, self.image_resolution)
676 | 
677 |         if self.fraunhofer_crop_image:
678 |             output_field = utils.fft(input_padded, normalized=True)
679 |             return self.crop_larger_dims(output_field, self.image_resolution)
680 |         else:
681 |             input_padded_cropped = self.crop_larger_dims(input_padded,
682 |                                                          self.image_resolution)
683 |             return utils.fft(input_padded_cropped, normalized=True)
684 | 
685 |     def backward(self, input_field):
686 |         # reverse the operations of the forward field
687 |         if self.fraunhofer_crop_image:
688 |             input_padded = self.pad_smaller_dims(input_field,
689 |                                                  self.slm_resolution)
690 |             output_field = utils.ifft(input_padded, normalized=True)
691 |         else:
692 |             output_field_unpadded = utils.ifft(input_field, normalized=True)
693 |             output_field = self.pad_smaller_dims(output_field_unpadded,
694 |                                                  self.slm_resolution)
695 | 
696 |         return self.crop_larger_dims(output_field, self.slm_resolution)
697 | 
698 | 
699 | class MultiDistancePropagation(nn.Module):
700 |     """Container class that handles propagating to multiple distances"""
701 |     def __init__(self, kernel_type, propagation_distances, slm_resolution,
702 |                  slm_pixel_pitch, image_resolution=None, wavelength=532e-9,
703 |                  propagation_parameters=None, device=None):
704 |         super().__init__()
705 |         self.kernel_type = kernel_type.lower()
706 |         self.slm_resolution = slm_resolution
707 |         self.slm_pixel_pitch = slm_pixel_pitch
708 |         self.image_resolution = image_resolution
709 |         self.wavelength = wavelength
710 |         self.propagation_parameters = propagation_parameters
711 |         self.dev = device
712 |         if self.propagation_parameters is None:
713 |             self.propagation_parameters = {}
714 | 
715 |         # for near field distances, turn off internal normalization
716 |         # so it can be applied uniformly across all distances
717 |         self.propagation_parameters['normalize_output'] = False
718 | 
719 |         self.has_fraunhofer = self.kernel_type in ('fraunhofer', 'fourier')
720 | 
721 |         # process input types for propagation distances
722 |         if isinstance(propagation_distances, (np.ndarray, torch.Tensor)):
723 |             propagation_distances = propagation_distances.flatten().tolist()
724 |         # unique values only
725 |         self.propagation_distances = set(propagation_distances)
726 |         # mappings if modified for Fourier plane
727 |         self.get_original_dist = {d: d for d in self.propagation_distances}
728 |         self.get_internal_dist = {d: d for d in self.propagation_distances}
729 | 
730 |         # dictionary for the set of propagators
731 |         self.propagators = {}
732 | 
733 |         if self.has_fraunhofer:
734 |             self.create_fraunhofer_propagator()
735 |             # all other planes will be propagated from the Fourier plane,
736 |             # keeping its resolution and pixel pitch
737 |             self.kernel_type = 'kirchhoff'
738 |             self.create_near_field_propagators(self.fourier_resolution,
739 |                                                self.fourier_pixel_pitch,
740 |                                                None)
741 |         else:
742 |             self.create_near_field_propagators(self.slm_resolution,
743 |                                                self.slm_pixel_pitch,
744 |                                                self.image_resolution)
745 | 
746 |     def create_near_field_propagators(self, start_resolution, start_pixel_pitch,
747 |                                       image_resolution):
748 |         prop_cls = {'fresnel': FresnelPropagation,
749 |                     'fresnel_conv': FresnelConvPropagation,
750 |                     'asm': AngularSpectrumPropagation,
751 |                     'angular_spectrum': AngularSpectrumPropagation,
752 |                     'kirchhoff': KirchhoffPropagation}[self.kernel_type]
753 |         for d in self.propagation_distances:
754 |             if d == 0:
755 |                 continue
756 |             self.propagators[d] = prop_cls(
757 |                 d, start_resolution, start_pixel_pitch, image_resolution,
758 |                 self.wavelength, self.propagation_parameters.copy(),
759 |                 self.dev)
760 | 
761 |     def create_fraunhofer_propagator(self):
762 |         try:
763 |             self.focal_length = self.propagation_parameters.pop('focal_length')
764 |         except KeyError:
765 |             raise ValueError("Multi-distance Fraunhofer propagation requires "
766 |                              "'focal_length' in propagation_parameters to "
767 |                              "specify which propagation_distance has the "
768 |                              "Fourier relationship.")
769 | 
770 |         if self.focal_length not in self.propagation_distances:
771 |             warnings.warn('focal_length is not in the list of '
772 |                           'propagation_distances. Add it if you also want '
773 |                           'the Fourier plane output field.')
774 | 
775 |         # set the propagation distances relative to Fourier plane
776 |         self.get_original_dist = {d - self.focal_length: d
777 |                                   for d in self.propagation_distances}
778 |         self.get_internal_dist = {d: d - self.focal_length
779 |                                   for d in self.propagation_distances}
780 |         # make sure 0 doesn't have a rounding error
781 |         if 0 not in self.get_original_dist:
782 |             zero_value = None
783 |             for d in self.get_original_dist:  # Fourier-relative distances
784 |                 if abs(d) < 1e-10:
785 |                     zero_value = d
786 |                     break
787 |             if zero_value is not None:
788 |                 orig_dist = self.get_original_dist.pop(zero_value)
789 |                 self.get_original_dist[0] = orig_dist
790 |                 self.get_internal_dist[orig_dist] = 0
791 |         # update the propagation distances for internal use
792 |         self.propagation_distances = set(self.get_original_dist.keys())
793 | 
794 |         # make Fraunhofer propagator
795 |         self.fraunhofer_propagator = FraunhoferPropagation(
796 |             self.focal_length, self.slm_resolution, self.slm_pixel_pitch,
797 |             self.image_resolution, self.wavelength,
798 |             self.propagation_parameters.copy(), self.dev)
799 |         self.fourier_resolution = self.fraunhofer_propagator.image_resolution
800 |         self.fourier_pixel_pitch = (self.fraunhofer_propagator
801 |                                     .image_pixel_pitch)
802 | 
803 |     def forward(self, input_field):
804 |         # for normalization
805 |         input_magnitude_sum = magnitude_sum(input_field)
806 | 
807 |         # do Fraunhofer propagation first if needed
808 |         if self.has_fraunhofer:
809 |             input_field = self.fraunhofer_propagator.forward(input_field)
810 | 
811 |         output_fields = {}
812 |         output_sums = {}
813 |         for d in self.propagation_distances:
814 |             if d == 0:
815 |                 output_fields[d] = input_field
816 |             else:
817 |                 output_fields[d] = self.propagators[d].forward(input_field)
818 |             output_sums[d] = magnitude_sum(output_fields[d])
819 | 
820 |         # give the 0 distance layer twice the weight of the highest other layer.
821 | # This is mainly for the Fraunhofer propagation case, since we want the 822 | # layers to have the correct relative radiometric fall-off, but the 823 | # Fourier plane itself would dominate the backprop, so we compensate 824 | if 0 in self.propagation_distances: 825 | sum_max = max(output_sums[d] for d in output_sums if d != 0) 826 | output_fields[0] = output_fields[0] * (2 * sum_max / output_sums[0]) 827 | output_sums[0] = 2 * sum_max 828 | 829 | # normalize output based on input 830 | output_magnitude_sum = sum(output_sums[d] for d in output_sums) 831 | scale_factor = (input_magnitude_sum / output_magnitude_sum 832 | * len(self.propagation_distances)) 833 | for d in output_fields: 834 | output_fields[d].mul_(scale_factor) 835 | 836 | # return using original distances as keys 837 | return {self.get_original_dist[d]: output_fields[d] 838 | for d in output_fields} 839 | 840 | def backward(self, input_fields): 841 | input_magnitude_sum = sum(magnitude_sum(input_fields[d]) 842 | for d in input_fields) 843 | 844 | output_fields = {} 845 | output_sums = {} 846 | for d_orig in input_fields: 847 | d = self.get_internal_dist[d_orig] 848 | if d == 0: 849 | output_fields[d] = input_fields[d_orig] 850 | else: 851 | output_fields[d] = self.propagators[d].backward( 852 | input_fields[d_orig]) 853 | output_sums[d] = magnitude_sum(output_fields[d]) 854 | 855 | # compensate for a 0 distance propagation layer (see self.forward()) 856 | if 0 in self.propagation_distances: 857 | sum_max = max(output_sums[d] for d in output_sums if d != 0) 858 | output_fields[0] = output_fields[0] * (2 * sum_max / output_sums[0]) 859 | output_sums[0] = 2 * sum_max 860 | 861 | # combine the fields 862 | output_field = torch.stack(list(output_fields.values()), -1).sum(-1) 863 | 864 | # reverse Fraunhofer propagation if needed 865 | if self.has_fraunhofer: 866 | output_field = self.fraunhofer_propagator.backward(output_field) 867 | 868 | # normalize output based on input 869 | output_magnitude_sum = magnitude_sum(output_field) 870 | output_field.mul_(input_magnitude_sum / output_magnitude_sum 871 | / len(self.propagation_distances)) 872 | 873 | return output_field 874 | 875 | def to(self, *args, **kwargs): 876 | """Moves non-parameter tensors needed for propagation to device 877 | 878 | Also updates the internal self.dev added to this class 879 | """ 880 | slf = super().to(*args, **kwargs) 881 | 882 | device_arg = torch._C._nn._parse_to(*args, **kwargs)[0] 883 | if device_arg is not None: 884 | slf.dev = device_arg 885 | 886 | if slf.has_fraunhofer: 887 | slf.fraunhofer_propagator.to(slf.dev) 888 | 889 | for d in slf.propagation_distances: 890 | slf.propagators[d].to(slf.dev) 891 | 892 | return slf 893 | 894 | 895 | def stacked_abs(field): 896 | # for a complex field stacked according to pytorch fft format, computes 897 | # the magnitude for each pixel 898 | return torch.pow(utils.field_to_intensity(field), 0.5) 899 | 900 | 901 | def magnitude_sum(field): 902 | # for a complex field stacked according to pytorch fft format, computes 903 | # a normalization factor over the magnitudes 904 | return stacked_abs(field).mean() 905 | -------------------------------------------------------------------------------- /pytorch_prototyping/__pycache__/pytorch_prototyping.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/COMHTVM/lensless/0d67a310bab08551d7422fa792f3422a7ee7d9cb/pytorch_prototyping/__pycache__/pytorch_prototyping.cpython-37.pyc 
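A minimal sketch of driving the propagation module above, for orientation. Hedged: fields follow the stacked real/imaginary convention from `utils.py`, but the exact batch shape expected by the transfer-function base classes defined earlier in `propagation.py` is an assumption here, as are the pitch and distance values:

import numpy as np
import torch
import utils
from propagation import MultiDistancePropagation

# hypothetical SLM: 512x512 pixels at an 8 um pitch, 532 nm source
resolution = np.array([512, 512])
pitch = np.array([8e-6, 8e-6])
prop = MultiDistancePropagation('asm', [0.09, 0.10, 0.11],
                                resolution, pitch, wavelength=532e-9)

# uniform-amplitude, zero-phase field in stacked real/imaginary format
field = utils.stack_complex(torch.ones(512, 512), torch.zeros(512, 512))
outputs = prop.forward(field)  # dict mapping each distance to its output field
intensity = utils.field_to_intensity(outputs[0.10])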
-------------------------------------------------------------------------------- /pytorch_prototyping/pytorch_prototyping.py: -------------------------------------------------------------------------------- 1 | '''A number of custom pytorch modules with sane defaults that I find useful for model prototyping.''' 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | import torchvision.utils 6 | 7 | import numpy as np 8 | 9 | import math 10 | import numbers 11 | 12 | class FCLayer(nn.Module): 13 | def __init__(self, in_features, out_features): 14 | super().__init__() 15 | self.net = nn.Sequential( 16 | nn.Linear(in_features, out_features), 17 | nn.LayerNorm([out_features]), 18 | nn.ReLU(inplace=True) 19 | ) 20 | 21 | def forward(self, input): 22 | return self.net(input) 23 | 24 | 25 | # From https://gist.github.com/wassname/ecd2dac6fc8f9918149853d17e3abf02 26 | class LayerNormConv2d(nn.Module): 27 | 28 | def __init__(self, num_features, eps=1e-5, affine=True): 29 | super().__init__() 30 | self.num_features = num_features 31 | self.affine = affine 32 | self.eps = eps 33 | 34 | if self.affine: 35 | self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_()) 36 | self.beta = nn.Parameter(torch.zeros(num_features)) 37 | 38 | def forward(self, x): 39 | shape = [-1] + [1] * (x.dim() - 1) 40 | mean = x.view(x.size(0), -1).mean(1).view(*shape) 41 | std = x.view(x.size(0), -1).std(1).view(*shape) 42 | 43 | y = (x - mean) / (std + self.eps) 44 | if self.affine: 45 | shape = [1, -1] + [1] * (x.dim() - 2) 46 | y = self.gamma.view(*shape) * y + self.beta.view(*shape) 47 | return y 48 | 49 | 50 | class FCBlock(nn.Module): 51 | def __init__(self, 52 | hidden_ch, 53 | num_hidden_layers, 54 | in_features, 55 | out_features, 56 | outermost_linear=False): 57 | super().__init__() 58 | 59 | self.net = [] 60 | self.net.append(FCLayer(in_features=in_features, out_features=hidden_ch)) 61 | 62 | for i in range(num_hidden_layers): 63 | self.net.append(FCLayer(in_features=hidden_ch, out_features=hidden_ch)) 64 | 65 | if outermost_linear: 66 | self.net.append(nn.Linear(in_features=hidden_ch, out_features=out_features)) 67 | else: 68 | self.net.append(FCLayer(in_features=hidden_ch, out_features=out_features)) 69 | 70 | self.net = nn.Sequential(*self.net) 71 | self.net.apply(self.init_weights) 72 | 73 | def __getitem__(self,item): 74 | return self.net[item] 75 | 76 | def init_weights(self, m): 77 | if type(m) == nn.Linear: 78 | nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in') 79 | 80 | def forward(self, input): 81 | return self.net(input) 82 | 83 | 84 | class DownBlock3D(nn.Module): 85 | '''A 3D convolutional downsampling block. 86 | ''' 87 | 88 | def __init__(self, in_channels, out_channels, norm=nn.BatchNorm3d): 89 | super().__init__() 90 | 91 | self.net = [ 92 | nn.ReplicationPad3d(1), 93 | nn.Conv3d(in_channels, 94 | out_channels, 95 | kernel_size=4, 96 | padding=0, 97 | stride=2, 98 | bias=False if norm is not None else True), 99 | ] 100 | 101 | if norm is not None: 102 | self.net += [norm(out_channels, affine=True)] 103 | 104 | self.net += [nn.LeakyReLU(0.2, True)] 105 | self.net = nn.Sequential(*self.net) 106 | 107 | def forward(self, x): 108 | return self.net(x) 109 | 110 | 111 | class UpBlock3D(nn.Module): 112 | '''A 3D convolutional upsampling block. 
113 |     '''
114 | 
115 |     def __init__(self, in_channels, out_channels, norm=nn.BatchNorm3d):
116 |         super().__init__()
117 | 
118 |         self.net = [
119 |             nn.ConvTranspose3d(in_channels,
120 |                                out_channels,
121 |                                kernel_size=4,
122 |                                stride=2,
123 |                                padding=1,
124 |                                bias=False if norm is not None else True),
125 |         ]
126 | 
127 |         if norm is not None:
128 |             self.net += [norm(out_channels, affine=True)]
129 | 
130 |         self.net += [nn.ReLU(True)]
131 |         self.net = nn.Sequential(*self.net)
132 | 
133 |     def forward(self, x, skipped=None):
134 |         if skipped is not None:
135 |             input = torch.cat([skipped, x], dim=1)
136 |         else:
137 |             input = x
138 |         return self.net(input)
139 | 
140 | 
141 | class Conv3dSame(torch.nn.Module):
142 |     '''3D convolution that pads to keep spatial dimensions equal.
143 |     Cannot deal with stride. Only quadratic kernels (=scalar kernel_size).
144 |     '''
145 | 
146 |     def __init__(self, in_channels, out_channels, kernel_size, bias=True, padding_layer=nn.ReplicationPad3d):
147 |         '''
148 |         :param in_channels: Number of input channels
149 |         :param out_channels: Number of output channels
150 |         :param kernel_size: Scalar. Spatial dimensions of kernel (only quadratic kernels supported).
151 |         :param bias: Whether or not to use bias.
152 |         :param padding_layer: Which padding to use. Default is replication padding.
153 |         '''
154 |         super().__init__()
155 |         ka = kernel_size // 2
156 |         kb = ka - 1 if kernel_size % 2 == 0 else ka
157 |         self.net = nn.Sequential(
158 |             padding_layer((ka, kb, ka, kb, ka, kb)),
159 |             nn.Conv3d(in_channels, out_channels, kernel_size, bias=bias, stride=1)
160 |         )
161 | 
162 |     def forward(self, x):
163 |         return self.net(x)
164 | 
165 | 
166 | class Conv2dSame(torch.nn.Module):
167 |     '''2D convolution that pads to keep spatial dimensions equal.
168 |     Cannot deal with stride. Only quadratic kernels (=scalar kernel_size).
169 |     '''
170 | 
171 |     def __init__(self, in_channels, out_channels, kernel_size, bias=True, padding_layer=nn.ReflectionPad2d):
172 |         '''
173 |         :param in_channels: Number of input channels
174 |         :param out_channels: Number of output channels
175 |         :param kernel_size: Scalar. Spatial dimensions of kernel (only quadratic kernels supported).
176 |         :param bias: Whether or not to use bias.
177 |         :param padding_layer: Which padding to use. Default is reflection padding.
178 |         '''
179 |         super().__init__()
180 |         ka = kernel_size // 2
181 |         kb = ka - 1 if kernel_size % 2 == 0 else ka
182 |         self.net = nn.Sequential(
183 |             padding_layer((ka, kb, ka, kb)),
184 |             nn.Conv2d(in_channels, out_channels, kernel_size, bias=bias, stride=1)
185 |         )
186 | 
187 |         self.weight = self.net[1].weight
188 |         self.bias = self.net[1].bias
189 | 
190 |     def forward(self, x):
191 |         return self.net(x)
192 | 
193 | 
194 | class UpBlock(nn.Module):
195 |     '''A 2d-conv upsampling block with a variety of options for upsampling, and following best practices / with
196 |     reasonable defaults. (ReLU, kernel size multiple of stride)
197 |     '''
198 | 
199 |     def __init__(self,
200 |                  in_channels,
201 |                  out_channels,
202 |                  post_conv=True,
203 |                  use_dropout=False,
204 |                  dropout_prob=0.1,
205 |                  norm=nn.BatchNorm2d,
206 |                  upsampling_mode='transpose'):
207 |         '''
208 |         :param in_channels: Number of input channels
209 |         :param out_channels: Number of output channels
210 |         :param post_conv: Whether to have another convolutional layer after the upsampling layer.
211 |         :param use_dropout: bool. Whether to use dropout or not.
212 |         :param dropout_prob: Float. The dropout probability (if use_dropout is True)
213 |         :param norm: Which norm to use. If None, no norm is used. Default is Batchnorm with affinity.
214 |         :param upsampling_mode: Which upsampling mode:
215 |                 transpose: Upsampling with stride-2, kernel size 4 transpose convolutions.
216 |                 bilinear: Feature map is upsampled with bilinear upsampling, then a conv layer.
217 |                 nearest: Feature map is upsampled with nearest neighbor upsampling, then a conv layer.
218 |                 shuffle: Feature map is upsampled with pixel shuffling, then a conv layer.
219 |         '''
220 |         super().__init__()
221 | 
222 |         net = list()
223 | 
224 |         if upsampling_mode == 'transpose':
225 |             net += [nn.ConvTranspose2d(in_channels,
226 |                                        out_channels,
227 |                                        kernel_size=4,
228 |                                        stride=2,
229 |                                        padding=1,
230 |                                        bias=True if norm is None else False)]
231 |         elif upsampling_mode == 'bilinear':
232 |             net += [nn.UpsamplingBilinear2d(scale_factor=2)]
233 |             net += [
234 |                 Conv2dSame(in_channels, out_channels, kernel_size=3, bias=True if norm is None else False)]
235 |         elif upsampling_mode == 'nearest':
236 |             net += [nn.UpsamplingNearest2d(scale_factor=2)]
237 |             net += [
238 |                 Conv2dSame(in_channels, out_channels, kernel_size=3, bias=True if norm is None else False)]
239 |         elif upsampling_mode == 'shuffle':
240 |             net += [nn.PixelShuffle(upscale_factor=2)]
241 |             net += [
242 |                 Conv2dSame(in_channels // 4, out_channels, kernel_size=3,
243 |                            bias=True if norm is None else False)]
244 |         else:
245 |             raise ValueError("Unknown upsampling mode!")
246 | 
247 |         if norm is not None:
248 |             net += [norm(out_channels, affine=True)]
249 | 
250 |         net += [nn.ReLU(True)]
251 | 
252 |         if use_dropout:
253 |             net += [nn.Dropout2d(dropout_prob, False)]
254 | 
255 |         if post_conv:
256 |             net += [Conv2dSame(out_channels,
257 |                                out_channels,
258 |                                kernel_size=3,
259 |                                bias=True if norm is None else False)]
260 | 
261 |             if norm is not None:
262 |                 net += [norm(out_channels, affine=True)]
263 | 
264 |             net += [nn.ReLU(True)]
265 | 
266 |             if use_dropout:
267 |                 net += [nn.Dropout2d(dropout_prob, False)]
268 | 
269 |         self.net = nn.Sequential(*net)
270 | 
271 |     def forward(self, x, skipped=None):
272 |         if skipped is not None:
273 |             input = torch.cat([skipped, x], dim=1)
274 |         else:
275 |             input = x
276 |         return self.net(input)
277 | 
278 | 
279 | class DownBlock(nn.Module):
280 |     '''A 2D-conv downsampling block following best practices / with reasonable defaults
281 |     (LeakyReLU, kernel size multiple of stride)
282 |     '''
283 | 
284 |     def __init__(self,
285 |                  in_channels,
286 |                  out_channels,
287 |                  prep_conv=True,
288 |                  middle_channels=None,
289 |                  use_dropout=False,
290 |                  dropout_prob=0.1,
291 |                  norm=nn.BatchNorm2d):
292 |         '''
293 |         :param in_channels: Number of input channels
294 |         :param out_channels: Number of output channels
295 |         :param prep_conv: Whether to have another convolutional layer before the downsampling layer.
296 |         :param middle_channels: If prep_conv is true, this sets the number of channels between the prep and downsampling
297 |                                 convs.
298 |         :param use_dropout: bool. Whether to use dropout or not.
299 |         :param dropout_prob: Float. The dropout probability (if use_dropout is True)
300 |         :param norm: Which norm to use. If None, no norm is used. Default is Batchnorm with affinity.
301 | ''' 302 | super().__init__() 303 | 304 | if middle_channels is None: 305 | middle_channels = in_channels 306 | 307 | net = list() 308 | 309 | if prep_conv: 310 | net += [nn.ReflectionPad2d(1), 311 | nn.Conv2d(in_channels, 312 | middle_channels, 313 | kernel_size=3, 314 | padding=0, 315 | stride=1, 316 | bias=True if norm is None else False)] 317 | 318 | if norm is not None: 319 | net += [norm(middle_channels, affine=True)] 320 | 321 | net += [nn.LeakyReLU(0.2, True)] 322 | 323 | if use_dropout: 324 | net += [nn.Dropout2d(dropout_prob, False)] 325 | 326 | net += [nn.ReflectionPad2d(1), 327 | nn.Conv2d(middle_channels, 328 | out_channels, 329 | kernel_size=4, 330 | padding=0, 331 | stride=2, 332 | bias=True if norm is None else False)] 333 | 334 | if norm is not None: 335 | net += [norm(out_channels, affine=True)] 336 | 337 | net += [nn.LeakyReLU(0.2, True)] 338 | 339 | if use_dropout: 340 | net += [nn.Dropout2d(dropout_prob, False)] 341 | 342 | self.net = nn.Sequential(*net) 343 | 344 | def forward(self, x): 345 | return self.net(x) 346 | 347 | 348 | class Unet3d(nn.Module): 349 | '''A 3d-Unet implementation with sane defaults. 350 | ''' 351 | 352 | def __init__(self, 353 | in_channels, 354 | out_channels, 355 | nf0, 356 | num_down, 357 | max_channels, 358 | norm=nn.BatchNorm3d, 359 | outermost_linear=False): 360 | ''' 361 | :param in_channels: Number of input channels 362 | :param out_channels: Number of output channels 363 | :param nf0: Number of features at highest level of U-Net 364 | :param num_down: Number of downsampling stages. 365 | :param max_channels: Maximum number of channels (channels multiply by 2 with every downsampling stage) 366 | :param norm: Which norm to use. If None, no norm is used. Default is Batchnorm with affinity. 367 | :param outermost_linear: Whether the output layer should be a linear layer or a nonlinear one. 368 | ''' 369 | super().__init__() 370 | 371 | assert (num_down > 0), "Need at least one downsampling layer in UNet3d." 372 | 373 | # Define the in block 374 | self.in_layer = [Conv3dSame(in_channels, nf0, kernel_size=3, bias=False)] 375 | 376 | if norm is not None: 377 | self.in_layer += [norm(nf0, affine=True)] 378 | 379 | self.in_layer += [nn.LeakyReLU(0.2, True)] 380 | self.in_layer = nn.Sequential(*self.in_layer) 381 | 382 | # Define the center UNet block. The feature map has height and width 1 --> no batchnorm. 383 | self.unet_block = UnetSkipConnectionBlock3d(int(min(2 ** (num_down - 1) * nf0, max_channels)), 384 | int(min(2 ** (num_down - 1) * nf0, max_channels)), 385 | norm=None) 386 | for i in list(range(0, num_down - 1))[::-1]: 387 | self.unet_block = UnetSkipConnectionBlock3d(int(min(2 ** i * nf0, max_channels)), 388 | int(min(2 ** (i + 1) * nf0, max_channels)), 389 | submodule=self.unet_block, 390 | norm=norm) 391 | 392 | # Define the out layer. Each unet block concatenates its inputs with its outputs - so the output layer 393 | # automatically receives the output of the in_layer and the output of the last unet layer. 
394 |         self.out_layer = [Conv3dSame(2 * nf0,
395 |                                      out_channels,
396 |                                      kernel_size=3,
397 |                                      bias=outermost_linear)]
398 | 
399 |         if not outermost_linear:
400 |             if norm is not None:
401 |                 self.out_layer += [norm(out_channels, affine=True)]
402 |             self.out_layer += [nn.ReLU(True)]
403 |         self.out_layer = nn.Sequential(*self.out_layer)
404 | 
405 |     def forward(self, x):
406 |         in_layer = self.in_layer(x)
407 |         unet = self.unet_block(in_layer)
408 |         out_layer = self.out_layer(unet)
409 |         return out_layer
410 | 
411 | 
412 | class UnetSkipConnectionBlock3d(nn.Module):
413 |     '''Helper class for building a 3D unet.
414 |     '''
415 | 
416 |     def __init__(self,
417 |                  outer_nc,
418 |                  inner_nc,
419 |                  norm=nn.BatchNorm3d,
420 |                  submodule=None):
421 |         super().__init__()
422 | 
423 |         if submodule is None:
424 |             model = [DownBlock3D(outer_nc, inner_nc, norm=norm),
425 |                      UpBlock3D(inner_nc, outer_nc, norm=norm)]
426 |         else:
427 |             model = [DownBlock3D(outer_nc, inner_nc, norm=norm),
428 |                      submodule,
429 |                      UpBlock3D(2 * inner_nc, outer_nc, norm=norm)]
430 | 
431 |         self.model = nn.Sequential(*model)
432 | 
433 |     def forward(self, x):
434 |         forward_passed = self.model(x)
435 |         return torch.cat([x, forward_passed], 1)
436 | 
437 | 
438 | class UnetSkipConnectionBlock(nn.Module):
439 |     '''Helper class for building a 2D unet.
440 |     '''
441 | 
442 |     def __init__(self,
443 |                  outer_nc,
444 |                  inner_nc,
445 |                  upsampling_mode,
446 |                  norm=nn.BatchNorm2d,
447 |                  submodule=None,
448 |                  use_dropout=False,
449 |                  dropout_prob=0.1):
450 |         super().__init__()
451 | 
452 |         if submodule is None:
453 |             model = [DownBlock(outer_nc, inner_nc, use_dropout=use_dropout, dropout_prob=dropout_prob, norm=norm),
454 |                      UpBlock(inner_nc, outer_nc, use_dropout=use_dropout, dropout_prob=dropout_prob, norm=norm,
455 |                              upsampling_mode=upsampling_mode)]
456 |         else:
457 |             model = [DownBlock(outer_nc, inner_nc, use_dropout=use_dropout, dropout_prob=dropout_prob, norm=norm),
458 |                      submodule,
459 |                      UpBlock(2 * inner_nc, outer_nc, use_dropout=use_dropout, dropout_prob=dropout_prob, norm=norm,
460 |                              upsampling_mode=upsampling_mode)]
461 | 
462 |         self.model = nn.Sequential(*model)
463 | 
464 |     def forward(self, x):
465 |         forward_passed = self.model(x)
466 |         return torch.cat([x, forward_passed], 1)
467 | 
468 | 
469 | class Unet(nn.Module):
470 |     '''A 2d-Unet implementation with sane defaults.
471 |     '''
472 | 
473 |     def __init__(self,
474 |                  in_channels,
475 |                  out_channels,
476 |                  nf0,
477 |                  num_down,
478 |                  max_channels,
479 |                  use_dropout,
480 |                  upsampling_mode='transpose',
481 |                  dropout_prob=0.1,
482 |                  norm=nn.BatchNorm2d,
483 |                  outermost_linear=False):
484 |         '''
485 |         :param in_channels: Number of input channels
486 |         :param out_channels: Number of output channels
487 |         :param nf0: Number of features at highest level of U-Net
488 |         :param num_down: Number of downsampling stages.
489 |         :param max_channels: Maximum number of channels (channels multiply by 2 with every downsampling stage)
490 |         :param use_dropout: Whether to use dropout or not.
491 |         :param dropout_prob: Dropout probability if use_dropout=True.
492 |         :param upsampling_mode: Which type of upsampling should be used. See "UpBlock" for documentation.
493 |         :param norm: Which norm to use. If None, no norm is used. Default is Batchnorm with affinity.
494 |         :param outermost_linear: Whether the output layer should be a linear layer or a nonlinear one.
495 |         '''
496 |         super().__init__()
497 | 
498 |         assert (num_down > 0), "Need at least one downsampling layer in UNet."
499 | 
500 |         # Define the in block
501 |         self.in_layer = [Conv2dSame(in_channels, nf0, kernel_size=3, bias=True if norm is None else False)]
502 |         if norm is not None:
503 |             self.in_layer += [norm(nf0, affine=True)]
504 |         self.in_layer += [nn.LeakyReLU(0.2, True)]
505 | 
506 |         if use_dropout:
507 |             self.in_layer += [nn.Dropout2d(dropout_prob)]
508 |         self.in_layer = nn.Sequential(*self.in_layer)
509 | 
510 |         # Define the center UNet block
511 |         self.unet_block = UnetSkipConnectionBlock(min(2 ** (num_down-1) * nf0, max_channels),
512 |                                                   min(2 ** (num_down-1) * nf0, max_channels),
513 |                                                   use_dropout=use_dropout,
514 |                                                   dropout_prob=dropout_prob,
515 |                                                   norm=None,  # Innermost has no norm (spatial dimension 1)
516 |                                                   upsampling_mode=upsampling_mode)
517 | 
518 |         for i in list(range(0, num_down - 1))[::-1]:
519 |             self.unet_block = UnetSkipConnectionBlock(min(2 ** i * nf0, max_channels),
520 |                                                       min(2 ** (i + 1) * nf0, max_channels),
521 |                                                       use_dropout=use_dropout,
522 |                                                       dropout_prob=dropout_prob,
523 |                                                       submodule=self.unet_block,
524 |                                                       norm=norm,
525 |                                                       upsampling_mode=upsampling_mode)
526 | 
527 |         # Define the out layer. Each unet block concatenates its inputs with its outputs - so the output layer
528 |         # automatically receives the output of the in_layer and the output of the last unet layer.
529 |         self.out_layer = [Conv2dSame(2 * nf0,
530 |                                      out_channels,
531 |                                      kernel_size=3,
532 |                                      bias=outermost_linear or (norm is None))]
533 | 
534 |         if not outermost_linear:
535 |             if norm is not None:
536 |                 self.out_layer += [norm(out_channels, affine=True)]
537 |             self.out_layer += [nn.ReLU(True)]
538 | 
539 |             if use_dropout:
540 |                 self.out_layer += [nn.Dropout2d(dropout_prob)]
541 |         self.out_layer = nn.Sequential(*self.out_layer)
542 | 
543 |         self.out_layer_weight = self.out_layer[0].weight
544 | 
545 |     def forward(self, x):
546 |         in_layer = self.in_layer(x)
547 |         unet = self.unet_block(in_layer)
548 |         out_layer = self.out_layer(unet)
549 |         return out_layer
550 | 
551 | 
552 | class Identity(nn.Module):
553 |     '''Helper module to allow Downsampling and Upsampling nets to default to identity if they receive an empty list.'''
554 | 
555 |     def __init__(self):
556 |         super().__init__()
557 | 
558 |     def forward(self, input):
559 |         return input
560 | 
561 | 
562 | class DownsamplingNet(nn.Module):
563 |     '''A subnetwork that downsamples a 2D feature map with strided convolutions.
564 |     '''
565 | 
566 |     def __init__(self,
567 |                  per_layer_out_ch,
568 |                  in_channels,
569 |                  use_dropout,
570 |                  dropout_prob=0.1,
571 |                  last_layer_one=False,
572 |                  norm=nn.BatchNorm2d):
573 |         '''
574 |         :param per_layer_out_ch: python list of integers. Defines the number of output channels per layer. Length of
575 |                                  list defines number of downsampling steps (each step downsamples by factor of 2.)
576 |         :param in_channels: Number of input channels.
577 |         :param use_dropout: Whether or not to use dropout.
578 |         :param dropout_prob: Dropout probability.
579 |         :param last_layer_one: Whether the output of the last layer will have a spatial size of 1. In that case,
580 |                                the last layer will not have batchnorm, else, it will.
581 |         :param norm: Which norm to use. Defaults to BatchNorm.
582 |         '''
583 |         super().__init__()
584 | 
585 |         if not len(per_layer_out_ch):
586 |             self.downs = Identity()
587 |         else:
588 |             self.downs = list()
589 |             self.downs.append(DownBlock(in_channels, per_layer_out_ch[0], use_dropout=use_dropout,
590 |                                         dropout_prob=dropout_prob, middle_channels=per_layer_out_ch[0], norm=norm))
591 |             for i in range(0, len(per_layer_out_ch) - 1):
592 |                 if last_layer_one and (i == len(per_layer_out_ch) - 2):
593 |                     norm = None
594 |                 self.downs.append(DownBlock(per_layer_out_ch[i],
595 |                                             per_layer_out_ch[i + 1],
596 |                                             dropout_prob=dropout_prob,
597 |                                             use_dropout=use_dropout,
598 |                                             norm=norm))
599 |             self.downs = nn.Sequential(*self.downs)
600 | 
601 |     def forward(self, input):
602 |         return self.downs(input)
603 | 
604 | 
605 | class UpsamplingNet(nn.Module):
606 |     '''A subnetwork that upsamples a 2D feature map with a variety of upsampling options.
607 |     '''
608 | 
609 |     def __init__(self,
610 |                  per_layer_out_ch,
611 |                  in_channels,
612 |                  upsampling_mode,
613 |                  use_dropout,
614 |                  dropout_prob=0.1,
615 |                  first_layer_one=False,
616 |                  norm=nn.BatchNorm2d):
617 |         '''
618 |         :param per_layer_out_ch: python list of integers. Defines the number of output channels per layer. Length of
619 |                                  list defines number of upsampling steps (each step upsamples by factor of 2.)
620 |         :param in_channels: Number of input channels.
621 |         :param upsampling_mode: Mode of upsampling. For documentation, see class "UpBlock"
622 |         :param use_dropout: Whether or not to use dropout.
623 |         :param dropout_prob: Dropout probability.
624 |         :param first_layer_one: Whether the input to the first layer will have a spatial size of 1. In that case,
625 |                                 the first layer will not have a norm, else, it will.
626 |         :param norm: Which norm to use. Defaults to BatchNorm.
627 |         '''
628 |         super().__init__()
629 | 
630 |         if not len(per_layer_out_ch):
631 |             self.ups = Identity()
632 |         else:
633 |             self.ups = list()
634 |             self.ups.append(UpBlock(in_channels,
635 |                                     per_layer_out_ch[0],
636 |                                     use_dropout=use_dropout,
637 |                                     dropout_prob=dropout_prob,
638 |                                     norm=None if first_layer_one else norm,
639 |                                     upsampling_mode=upsampling_mode))
640 |             for i in range(0, len(per_layer_out_ch) - 1):
641 |                 self.ups.append(
642 |                     UpBlock(per_layer_out_ch[i],
643 |                             per_layer_out_ch[i + 1],
644 |                             use_dropout=use_dropout,
645 |                             dropout_prob=dropout_prob,
646 |                             norm=norm,
647 |                             upsampling_mode=upsampling_mode))
648 |             self.ups = nn.Sequential(*self.ups)
649 | 
650 |     def forward(self, input):
651 |         return self.ups(input)
--------------------------------------------------------------------------------
/ranges.json:
--------------------------------------------------------------------------------
1 | {
2 |     "lr": [1e-8],
3 |     "focal_length": [0.5e-2]
4 | }
5 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from dataio import *
4 | from torch.utils.data import DataLoader
5 | from denoising_unet import DenoisingUnet
6 | from torch.utils.tensorboard import SummaryWriter
7 | import torchvision.transforms as transforms
8 | from PIL import Image
9 | import torch.nn
10 | import time
11 | from queue import Queue
12 | 
13 | if torch.cuda.is_available():
14 |     DEVICE = torch.device("cuda:0")
15 | else:
16 |     DEVICE = torch.device("cpu")
17 | 
18 | 
19 | def load_json(file_name):
20 |     """ Loads a JSON file into a dictionary
21 |     :param file_name: str - path to json file
22 |     :return: json file loaded into python dict
23 | """ 24 | file_name = os.path.expanduser(file_name) 25 | with open(file_name) as f: 26 | s = f.read() 27 | j = json.loads(s) 28 | return j 29 | 30 | 31 | def image_loader(img_name): 32 | """ For optimizing over one image (testing) 33 | Usage: model_input, ground_truth = image_loader('input.png') 34 | :param img_name: str - path to single image file to load 35 | :return: Variable tensor of image in the format (1,C,H,W) 36 | """ 37 | loader = transforms.Compose([transforms.CenterCrop(size=(512,512)), 38 | transforms.ToTensor()]) 39 | 40 | image = Image.open(img_name) 41 | image = loader(image).float().cpu() 42 | image = torch.Tensor(optics.srgb_to_linear(image)) 43 | blurred_image = image.unsqueeze(0) # specify a batch size of 1 44 | image = image.unsqueeze(0) 45 | return blurred_image.to(DEVICE), image.to(DEVICE) 46 | 47 | 48 | def get_lr(optimizer): 49 | """ 50 | :param optimizer: optimizer object 51 | :return: Current learning rate 52 | """ 53 | for param_group in optimizer.param_groups: 54 | return param_group['lr'] 55 | 56 | 57 | def get_exp_num(file_path, exp_name): 58 | """ 59 | Find the next open experiment ID number. 60 | exp_name: str path to the main experiment folder that contains the model folder 61 | WARNING: don't name experiments with underscores! 62 | 63 | :param file_path: str - path to folder 64 | :param exp_name: str - name of exp 65 | :return: e.g. runs/fresnel50/ 66 | """ 67 | exp_folder = os.path.expanduser(file_path) 68 | _, dirs, _ = next(os.walk(exp_folder)) 69 | exp_nums = set() 70 | for d in dirs: 71 | splt = d.split("_") 72 | if len(splt) >= 2 and splt[0] == exp_name: 73 | try: 74 | exp_nums.add(int(splt[1])) 75 | except: 76 | pass 77 | for i in range(len(exp_nums)): 78 | if i not in exp_nums: 79 | return i 80 | return len(exp_nums) 81 | 82 | 83 | def train(hyps): 84 | torch.cuda.empty_cache() 85 | 86 | # *** load model and data set **** 87 | model = DenoisingUnet(hyps=hyps) 88 | 89 | if not hyps['single_image']: 90 | dataset = NoisySBDataset(hyps=hyps) 91 | dataloader = DataLoader(dataset, batch_size=hyps['batch_size']) 92 | print('Data loader size: ', len(dataloader)) 93 | 94 | if hyps['checkpoint'] is not None: # if trained model is not given, start new checkpoint 95 | model.load_state_dict(torch.load(hyps['checkpoint'])) 96 | 97 | model.to(DEVICE) 98 | 99 | # *** establish folders for saving experiment *** 100 | run_init = os.path.join(hyps['logging_root'], hyps['exp_name']) 101 | os.makedirs(run_init, exist_ok=True) 102 | 103 | file_str = hyps['logging_root'] + '/' + hyps['exp_name'] 104 | hyps['exp_num'] = get_exp_num(file_path=file_str, exp_name=hyps['exp_name']) 105 | dir_name = "{}/{}_{}".format(hyps['exp_name'], hyps['exp_name'], hyps['exp_num']) 106 | dir_name += hyps['search_keys'] 107 | print('Saving information to ', dir_name) 108 | 109 | run_dir = os.path.join(hyps['logging_root'], dir_name) 110 | 111 | os.makedirs(run_dir, exist_ok=True) 112 | 113 | # *** set up optimizer and scheduler *** 114 | optimizer = torch.optim.Adam(model.parameters(), lr=hyps['lr']) 115 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', 116 | patience=10, 117 | threshold=1e-4, 118 | factor=0.1) 119 | writer = SummaryWriter(run_dir) # run directory for tensorboard information 120 | iter = 0 121 | 122 | print('Beginning training...') 123 | if hyps['single_image']: 124 | print('Optimizing over a single image...') 125 | 126 | # early stopping criteria 127 | # TODO: move these params to params.json 128 | prev_loss = 1000 129 | stop_count = 0 
130 | tolerance = 1e-4 131 | early_stop = 800 132 | epoch_loss = 0 133 | 134 | if hyps['single_image']: # MINI-LOOP for testing 135 | model_input, ground_truth = image_loader('data/lamb.png') 136 | ground_truth = ground_truth.to(DEVICE) 137 | model_input = model_input.to(DEVICE) 138 | 139 | for epoch in range(hyps['max_epoch']): 140 | model_outputs = model(model_input) 141 | optimizer.zero_grad() 142 | 143 | total_loss = model.get_distortion_loss(model_outputs, ground_truth) 144 | 145 | total_loss.backward() 146 | optimizer.step() 147 | scheduler.step(total_loss) 148 | 149 | print("Epoch %03d total_loss %0.4f" % (epoch, total_loss)) 150 | 151 | if not iter: # on the first iteration 152 | # Save parameters used into the log directory. 153 | results_file = run_dir + "/params.txt" 154 | with open(results_file, 'a') as f: 155 | for k in hyps.keys(): 156 | f.write(str(k) + ": " + str(hyps[k]) + '\n') 157 | f.write("\n") 158 | 159 | iter += 1 160 | if iter % 10 == 0: 161 | save_dict = { 162 | "model_state_dict": model.state_dict(), 163 | "heightmap": model.get_heightmap().numpy(), 164 | "psf": model.get_psf(hyps), 165 | "epoch": epoch, 166 | "iter": iter, 167 | "hyps": hyps, 168 | "loss": total_loss, 169 | } 170 | 171 | torch.save(save_dict, os.path.join(run_dir, 'model_epoch_%d_iter_%s.pth' % (epoch, iter))) 172 | results = {"epoch": epoch, 173 | "loss": total_loss} 174 | return results 175 | 176 | for epoch in range(hyps['max_epoch']): 177 | for model_input, ground_truth in dataloader: 178 | 179 | ground_truth = ground_truth.to(DEVICE) 180 | model_input = model_input.to(DEVICE) 181 | 182 | model_outputs = model(model_input) 183 | model.write_updates(writer, model_outputs, ground_truth, model_input, iter, hyps) 184 | 185 | optimizer.zero_grad() 186 | 187 | psnr = model.get_psnr(model_outputs, ground_truth) 188 | dist_loss = model.get_distortion_loss(model_outputs, ground_truth) 189 | reg_loss = model.get_regularization_loss(model_outputs, ground_truth) 190 | total_loss = dist_loss # can include reg_loss in the future 191 | epoch_loss += total_loss 192 | 193 | total_loss.backward() 194 | optimizer.step() 195 | scheduler.step(total_loss) 196 | 197 | print("Iter %07d Epoch %03d dist_loss %0.4f reg_loss %0.4f" % 198 | (iter, epoch, dist_loss, reg_loss * hyps['reg_weight'])) 199 | 200 | writer.add_scalar("scaled_regularization_loss", reg_loss * hyps['reg_weight'], iter) 201 | writer.add_scalar("distortion_loss", dist_loss, iter) 202 | writer.add_scalar("learning_rate", get_lr(optimizer), iter) 203 | 204 | if prev_loss - total_loss <= tolerance: 205 | stop_count += 1 206 | if stop_count >= early_stop: 207 | break 208 | elif stop_count >= 1: 209 | stop_count = 0 210 | prev_loss = total_loss 211 | 212 | if not iter: # on the first iteration 213 | # Save parameters used into the log directory. 
214 |                 results_file = run_dir + "/params.txt"
215 |                 with open(results_file, 'a') as f:
216 |                     for k in hyps.keys():
217 |                         f.write(str(k) + ": " + str(hyps[k]) + '\n')
218 |                     f.write("\n")
219 | 
220 |             iter += 1
221 |             if iter % 10 == 0:  # used to be 10,000
222 |                 save_dict = {
223 |                     "model_state_dict": model.state_dict(),
224 |                     "optim_state_dict": optimizer.state_dict(),
225 |                     "heightmap": model.get_heightmap().numpy(),
226 |                     "psf": model.get_psf(hyps),
227 |                     "epoch": epoch,
228 |                     "iter": iter,
229 |                     "hyps": hyps,
230 |                     "avg_loss": epoch_loss/iter,
231 |                     "loss": total_loss,
232 |                     "psnr": psnr,
233 |                     "K": model.get_damp()
234 |                 }
235 | 
236 |                 for k in hyps.keys():
237 |                     if k not in save_dict:
238 |                         save_dict[k] = hyps[k]
239 | 
240 |                 torch.save(save_dict, os.path.join(run_dir, 'model_epoch_%d_iter_%s.pth' % (epoch, iter)))
241 | 
242 |         if stop_count >= early_stop:
243 |             break
244 |     torch.save(save_dict, os.path.join(run_dir, 'model_epoch_%d_iter_%s.pth' % (epoch, iter)))
245 | 
246 |     results = {"epoch": epoch,
247 |                "iter": iter,
248 |                "loss": total_loss}
249 |     return results
250 | 
251 | 
252 | def fill_hyper_q(hyps, ranges, keys, hyper_q, idx=0):
253 |     """
254 |     Recursive function to fill queue of specified hyperparameter ranges
255 |     :param hyps: dict of hyperparameters
256 |     :param ranges: dict of different hyperparameters to test
257 |     :param keys: list of hyperparameter names in ranges to iterate over
258 |     :param hyper_q: queue of dictionary of hyperparameters
259 |     :param idx: current index of hyperparameter being added
260 |     :return: queue of dictionary of hyperparameters
261 |     """
262 |     if idx >= len(keys):
263 |         hyps['search_keys'] = ""
264 |         for k in keys:
265 |             hyps['search_keys'] += '_' + str(k)+str(hyps[k])
266 |         hyper_q.put({k:v for k,v in hyps.items()})
267 |     else:
268 |         key = keys[idx]
269 |         for param in ranges[key]:
270 |             hyps[key] = param
271 |             hyper_q = fill_hyper_q(hyps, ranges, keys, hyper_q, idx+1)
272 |     return hyper_q
273 | 
274 | 
275 | def hyper_search(hyps, ranges):
276 |     """
277 |     Creates a queue of experiments to test (experiment is one set of hyperparameters)
278 |     Saves results
279 |     :param hyps: dictionary of hyperparameters
280 |     :param ranges: dictionary of ranges of hyperparameters to test
281 |     """
282 |     starttime = time.time()
283 | 
284 |     # make results file
285 |     if not os.path.exists("runs/"+hyps['exp_name']):
286 |         os.mkdir("runs/"+hyps['exp_name'])
287 | 
288 |     results_file = "runs/"+hyps['exp_name']+"/results.txt"
289 | 
290 |     with open(results_file,'a') as f:
291 |         f.write("Hyperparameters:\n")
292 |         for k in hyps.keys():
293 |             if k not in ranges:
294 |                 f.write(str(k) + ": " + str(hyps[k]) + '\n')
295 |         f.write("\nHyperranges:\n")
296 |         for k in ranges.keys():
297 |             rs = ",".join([str(v) for v in ranges[k]])
298 |             s = str(k) + ": ["+ rs + ']\n'
299 |             f.write(s)
300 |         f.write('\n')
301 | 
302 |     hyper_q = Queue()
303 |     hyper_q = fill_hyper_q(hyps, ranges, list(ranges.keys()), hyper_q, idx=0)
304 | 
305 |     print("n_searches:", hyper_q.qsize())
306 | 
307 |     while not hyper_q.empty():
308 |         print()
309 |         print("Searches left:", hyper_q.qsize(), "-- Running Time:", time.time()-starttime)
310 |         hyps = hyper_q.get()
311 |         results = train(hyps)
312 |         with open(results_file, 'a') as f:
313 |             results = " -- ".join([str(k) + ":" + str(results[k])
314 |                                    for k in sorted(results.keys())])
315 |             f.write("\n"+results+"\n")
316 | 
317 | if __name__ == '__main__':
318 |     # *** load params ***
319 |     params_file = "params.json"
320 |     ranges_file = "ranges.json"
321 |     print()
322 |     print("Using params file:", params_file)
ranges files:", ranges_file) 324 | print() 325 | hyps = load_json(params_file) 326 | ranges = load_json(ranges_file) 327 | hyps_str = "" 328 | for k, v in hyps.items(): 329 | hyps_str += "{}: {}\n".format(k, v) 330 | print("Hyperparameters:") 331 | print(hyps_str) 332 | print("\nSearching over:") 333 | print("\n".join(["{}: {}".format(k, v) for k, v in ranges.items()])) 334 | 335 | os.makedirs(hyps['data_root'], exist_ok=True) 336 | os.makedirs(hyps['logging_root'], exist_ok=True) 337 | 338 | start_time = time.time() 339 | hyper_search(hyps, ranges) 340 | print("Total Execution Time: ", time.time() - start_time) 341 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """A collection of functions for use with complex numbers in pytorch 2 | 3 | Author: Nitish Padmanaban 4 | """ 5 | 6 | import numpy as np 7 | import torch 8 | 9 | 10 | def stack_complex(real_mag, imag_phase): 11 | return torch.stack((real_mag, imag_phase), -1) 12 | 13 | 14 | def unstack_complex(stacked_array): 15 | return stacked_array[..., 0], stacked_array[..., 1] 16 | 17 | 18 | def heightmap_to_phase(height, wavelength, refractive_index): 19 | return height * (2 * np.pi / wavelength) * (refractive_index - 1) 20 | 21 | 22 | def phase_to_heightmap(phase, wavelength, refractive_index): 23 | return phase / (2 * np.pi / wavelength) / (refractive_index - 1) 24 | 25 | 26 | def rect_to_polar(real, imag): 27 | mag = torch.pow(real**2 + imag**2, 0.5) 28 | ang = torch.atan2(imag, real) 29 | return mag, ang 30 | 31 | 32 | def polar_to_rect(mag, ang): 33 | real = mag * torch.cos(ang) 34 | imag = mag * torch.sin(ang) 35 | return real, imag 36 | 37 | 38 | def rect_to_polar_stacked(real_imag): 39 | mag, ang = rect_to_polar(*unstack_complex(real_imag)) 40 | return stack_complex(mag, ang) 41 | 42 | 43 | def polar_to_rect_stacked(mag_ang): 44 | real, imag = polar_to_rect(*unstack_complex(mag_ang)) 45 | return stack_complex(real, imag) 46 | 47 | 48 | def field_to_intensity(real_imag): 49 | return (real_imag ** 2).sum(-1) 50 | 51 | 52 | def field_to_intensity_polar(mag_ang): 53 | return mag_ang[..., 0] ** 2 54 | 55 | 56 | def conj(real_imag): 57 | # also works the same for mag_ang representation 58 | real, imag = unstack_complex(real_imag) 59 | return stack_complex(real, -imag) 60 | 61 | 62 | def mul_complex(field1, field2): 63 | real1, imag1 = unstack_complex(field1) 64 | real2, imag2 = unstack_complex(field2) 65 | 66 | real = real1 * real2 - imag1 * imag2 67 | imag = real1 * imag2 + imag1 * real2 68 | 69 | return stack_complex(real, imag) 70 | 71 | 72 | def mul_complex_polar(field1, field2): 73 | mag1, ang1 = unstack_complex(field1) 74 | mag2, ang2 = unstack_complex(field2) 75 | 76 | mag = mag1 * mag2 77 | ang = ang1 + ang2 78 | 79 | over = ang > np.pi 80 | ang[over].sub_(2 * np.pi) 81 | under = ang <= -np.pi 82 | ang[under].add_(2 * np.pi) 83 | 84 | return stack_complex(mag, ang) 85 | 86 | 87 | def div_complex(field1, field2): 88 | real1, imag1 = unstack_complex(field1) 89 | real2, imag2 = unstack_complex(field2) 90 | 91 | mag_squared = (real2 ** 2) + (imag2 ** 2) 92 | 93 | real = (real1 * real2 + imag1 * imag2) / mag_squared 94 | imag = (-real1 * imag2 + imag1 * real2) / mag_squared 95 | 96 | return stack_complex(real, imag) 97 | 98 | 99 | def div_complex_polar(field1, field2): 100 | mag1, ang1 = unstack_complex(field1) 101 | mag2, ang2 = unstack_complex(field2) 102 | 103 | mag = mag1 / mag2 104 | ang = ang1 - ang2 
105 | 
106 |     over = ang > np.pi
107 |     ang[over] -= 2 * np.pi  # masked assignment writes back; ang[over].sub_() would only modify a copy
108 |     under = ang <= -np.pi
109 |     ang[under] += 2 * np.pi
110 | 
111 |     return stack_complex(mag, ang)
112 | 
113 | 
114 | def recip_complex(field):
115 |     real, imag = unstack_complex(field)
116 | 
117 |     mag_squared = (real ** 2) + (imag ** 2)
118 | 
119 |     real_inv = real / mag_squared
120 |     imag_inv = -imag / mag_squared
121 | 
122 |     return stack_complex(real_inv, imag_inv)
123 | 
124 | 
125 | def recip_complex_polar(field):
126 |     mag, ang = unstack_complex(field)
127 |     return stack_complex(1 / mag, -ang)
128 | 
129 | 
130 | def conv_fft(img_real_imag, kernel_real_imag, padval=0):
131 |     img_pad, kernel_pad, output_pad = conv_pad_sizes(img_real_imag.shape,
132 |                                                      kernel_real_imag.shape)
133 | 
134 |     # fft
135 |     img_fft = fft(img_real_imag, pad=img_pad, padval=padval)
136 |     kernel_fft = fft(kernel_real_imag, pad=kernel_pad, padval=0)
137 | 
138 |     # ifft, using img_pad to bring output to img input size
139 |     return ifft(mul_complex(img_fft, kernel_fft), pad=output_pad)
140 | 
141 | 
142 | def conv_fft_polar(img_mag_ang, kernel_mag_ang, padval=0):
143 |     img_pad, kernel_pad, output_pad = conv_pad_sizes(img_mag_ang.shape,
144 |                                                      kernel_mag_ang.shape)
145 | 
146 |     # fft
147 |     img_fft = fft_polar(img_mag_ang, pad=img_pad, padval=padval)
148 |     kernel_fft = fft_polar(kernel_mag_ang, pad=kernel_pad, padval=0)
149 | 
150 |     # ifft, using img_pad to bring output to img input size
151 |     return ifft_polar(mul_complex_polar(img_fft, kernel_fft), pad=output_pad)
152 | 
153 | 
154 | def fft(real_imag, ndims=2, normalized=False, pad=None, padval=0):
155 |     if pad is not None:
156 |         real_imag = pad_stacked(real_imag, pad, padval=padval)
157 |     return fftshift(torch.fft(ifftshift(real_imag, ndims), ndims,
158 |                               normalized=normalized), ndims)
159 | 
160 | 
161 | def fft_polar(mag_ang, ndims=2, normalized=False, pad=None, padval=0):
162 |     real_imag = polar_to_rect_stacked(mag_ang)
163 |     real_imag_fft = fft(real_imag, ndims, normalized, pad, padval)
164 |     return rect_to_polar_stacked(real_imag_fft)
165 | 
166 | 
167 | def ifft(real_imag, ndims=2, normalized=False, pad=None):
168 |     transformed = fftshift(torch.ifft(ifftshift(real_imag, ndims), ndims,
169 |                                       normalized=normalized), ndims)
170 |     if pad is not None:
171 |         transformed = crop(transformed, pad)
172 | 
173 |     return transformed
174 | 
175 | 
176 | def ifft_polar(mag_ang, ndims=2, normalized=False, pad=None):
177 |     real_imag = polar_to_rect_stacked(mag_ang)
178 |     real_imag_ifft = ifft(real_imag, ndims, normalized, pad)
179 |     return rect_to_polar_stacked(real_imag_ifft)
180 | 
181 | 
182 | def fftshift(array, ndims=2, invert=False):
183 |     shift_adjust = 0 if invert else 1
184 | 
185 |     # skips the last dimension, assuming stacked fft output
186 |     if ndims >= 1:
187 |         shift_len = (array.shape[-2] + shift_adjust) // 2
188 |         array = torch.cat((array[..., shift_len:, :],
189 |                            array[..., :shift_len, :]), -2)
190 |     if ndims >= 2:
191 |         shift_len = (array.shape[-3] + shift_adjust) // 2
192 |         array = torch.cat((array[..., shift_len:, :, :],
193 |                            array[..., :shift_len, :, :]), -3)
194 |     if ndims == 3:
195 |         shift_len = (array.shape[-4] + shift_adjust) // 2
196 |         array = torch.cat((array[..., shift_len:, :, :, :],
197 |                            array[..., :shift_len, :, :, :]), -4)
198 |     return array
199 | 
200 | 
201 | def ifftshift(array, ndims=2):
202 |     return fftshift(array, ndims, invert=True)
203 | 
204 | 
205 | def conv_pad_sizes(image_shape, kernel_shape):
206 |     # skips the last dimension, assuming stacked fft output
207 |     # minimum required padding is to img.shape + kernel.shape - 1
208 |     # padding based on matching fftconvolve output
209 | 
210 |     # when kernels are even, padding the extra 1 before/after matters
211 |     img_pad_end = (1 - ((kernel_shape[-2] % 2) | (image_shape[-2] % 2)),
212 |                    1 - ((kernel_shape[-3] % 2) | (image_shape[-3] % 2)))
213 | 
214 |     image_pad = ((kernel_shape[-2] - img_pad_end[0]) // 2,
215 |                  (kernel_shape[-2] - 1 + img_pad_end[0]) // 2,
216 |                  (kernel_shape[-3] - img_pad_end[1]) // 2,
217 |                  (kernel_shape[-3] - 1 + img_pad_end[1]) // 2)
218 |     kernel_pad = (image_shape[-2] // 2, (image_shape[-2] - 1) // 2,
219 |                   image_shape[-3] // 2, (image_shape[-3] - 1) // 2)
220 |     output_pad = ((kernel_shape[-2] - 1) // 2, kernel_shape[-2] // 2,
221 |                   (kernel_shape[-3] - 1) // 2, kernel_shape[-3] // 2)
222 |     return image_pad, kernel_pad, output_pad
223 | 
224 | 
225 | def pad_stacked(field, pad_width, padval=0):
226 |     if padval == 0:
227 |         pad_width = (0, 0, *pad_width)  # add 0 padding for stacked dimension
228 |         return torch.nn.functional.pad(field, pad_width)
229 |     else:
230 |         if isinstance(padval, torch.Tensor):
231 |             padval = padval.item()
232 | 
233 |         real, imag = unstack_complex(field)
234 |         real = torch.nn.functional.pad(real, pad_width, value=padval)
235 |         imag = torch.nn.functional.pad(imag, pad_width, value=0)
236 |         return stack_complex(real, imag)
237 | 
238 | 
239 | def crop(array, pad):
240 |     # skips the last dimension, assuming stacked fft output
241 |     if len(pad) >= 2 and (pad[0] or pad[1]):
242 |         if pad[1]:
243 |             array = array[..., pad[0]:-pad[1], :]
244 |         else:
245 |             array = array[..., pad[0]:, :]
246 | 
247 |     if len(pad) >= 4 and (pad[2] or pad[3]):
248 |         if pad[3]:
249 |             array = array[..., pad[2]:-pad[3], :, :]
250 |         else:
251 |             array = array[..., pad[2]:, :, :]
252 | 
253 |     if len(pad) == 6 and (pad[4] or pad[5]):
254 |         if pad[5]:
255 |             array = array[..., pad[4]:-pad[5], :, :, :]
256 |         else:
257 |             array = array[..., pad[4]:, :, :, :]
258 | 
259 |     return array
260 | 
261 | 
262 | def pad_smaller_dims(field, target_shape, pytorch=True, stacked=True, padval=0):
263 |     if pytorch:
264 |         if stacked:
265 |             size_diff = np.array(target_shape) - np.array(field.shape[-3:-1])
266 |             odd_dim = np.array(field.shape[-3:-1]) % 2
267 |         else:
268 |             size_diff = np.array(target_shape) - np.array(field.shape[-2:])
269 |             odd_dim = np.array(field.shape[-2:]) % 2
270 |     else:
271 |         size_diff = np.array(target_shape) - np.array(field.shape)
272 |         odd_dim = np.array(field.shape) % 2
273 | 
274 |     # pad the dimensions that need to increase in size
275 |     if (size_diff > 0).any():
276 |         pad_total = np.maximum(size_diff, 0)
277 |         pad_front = (pad_total + odd_dim) // 2
278 |         pad_end = (pad_total + 1 - odd_dim) // 2
279 | 
280 |         if pytorch:
281 |             pad_axes = [int(p)  # convert from np.int64
282 |                         for tple in zip(pad_front[::-1], pad_end[::-1])
283 |                         for p in tple]
284 |             if stacked:
285 |                 return pad_stacked(field, pad_axes, padval=padval)
286 |             else:
287 |                 return torch.nn.functional.pad(field, pad_axes, value=padval)
288 |         else:
289 |             return np.pad(field, tuple(zip(pad_front, pad_end)), 'constant',
290 |                           constant_values=padval)
291 |     else:
292 |         return field
293 | 
294 | 
295 | def crop_larger_dims(field, target_shape, pytorch=True, stacked=True):
296 |     if pytorch:
297 |         if stacked:
298 |             size_diff = np.array(field.shape[-3:-1]) - np.array(target_shape)
299 |             odd_dim = np.array(field.shape[-3:-1]) % 2
300 |         else:
301 |             size_diff = np.array(field.shape[-2:]) - np.array(target_shape)
302 |             odd_dim = np.array(field.shape[-2:]) % 2
303 |     else:
304 |         size_diff = np.array(field.shape) - np.array(target_shape)
305 |         odd_dim = np.array(field.shape) % 2
306 | 
307 |     # crop dimensions that need to decrease in size
308 |     if (size_diff > 0).any():
309 |         crop_total = np.maximum(size_diff, 0)
310 |         crop_front = (crop_total + 1 - odd_dim) // 2
311 |         crop_end = (crop_total + odd_dim) // 2
312 | 
313 |         crop_slices = [slice(int(f), int(-e) if e else None)
314 |                        for f, e in zip(crop_front, crop_end)]
315 |         if pytorch and stacked:
316 |             return field[(..., *crop_slices, slice(None))]
317 |         else:
318 |             return field[(..., *crop_slices)]
319 |     else:
320 |         return field
321 | 
322 | 
--------------------------------------------------------------------------------
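As a sanity check on the stacked-complex conventions in utils.py, a hedged usage sketch (shapes and the tolerance are illustrative assumptions; note that `utils.fft` relies on the function-style `torch.fft`/`torch.ifft` API of PyTorch versions before 1.8, which this repo's vintage appears to target):

import torch
import utils

# two stacked fields: the last dimension holds [real, imag]
a = utils.stack_complex(torch.randn(4, 4), torch.randn(4, 4))
b = utils.stack_complex(torch.randn(4, 4), torch.randn(4, 4))

# the rectangular and polar multiplication paths should agree
prod_rect = utils.mul_complex(a, b)
prod_polar = utils.mul_complex_polar(utils.rect_to_polar_stacked(a),
                                     utils.rect_to_polar_stacked(b))
assert torch.allclose(prod_rect, utils.polar_to_rect_stacked(prod_polar),
                      atol=1e-4)

# conv_fft pads both inputs to linear-convolution size, multiplies in the
# frequency domain, and crops the result back to the image size
kernel = utils.stack_complex(torch.randn(3, 3), torch.randn(3, 3))
out = utils.conv_fft(a, kernel)
print(out.shape)  # expected: torch.Size([4, 4, 2])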