├── .gitignore ├── LICENSE ├── README.md ├── args.py ├── dataset.py ├── experimental ├── ckpt │ ├── checkpoint │ ├── ckpt-98.data-00000-of-00002 │ ├── ckpt-98.data-00001-of-00002 │ └── ckpt-98.index └── data │ ├── captures │ ├── 102302.npy │ ├── 110802.npy │ └── 138301.npy │ ├── psf │ └── psf.npy │ └── vignette_factor.npy ├── loss.py ├── metasurface ├── conv.py └── solver.py ├── networks ├── G │ ├── FP.py │ └── Wiener.py └── select.py ├── run_train.sh ├── test.ipynb ├── train.py └── training ├── ckpt └── ckpt.txt └── data ├── test ├── 100600.jpg ├── 116601.jpg └── 137901.jpg └── train ├── 106000.jpg ├── 106001.jpg ├── 106002.jpg ├── 106100.jpg ├── 106101.jpg ├── 106200.jpg ├── 106201.jpg ├── 106202.jpg ├── 106300.jpg ├── 106301.jpg ├── 106700.jpg ├── 106701.jpg ├── 106702.jpg ├── 106703.jpg ├── 106704.jpg ├── 106800.jpg ├── 106801.jpg ├── 106900.jpg ├── 106901.jpg ├── 106902.jpg └── 106903.jpg /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | **/*.out 3 | **/__pycache__ 4 | **/.ipynb_checkpoints 5 | **/*.pyc 6 | **/.DS_Store 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Permission is hereby granted, free of charge, to any person or organization 2 | obtaining a copy of the software and accompanying documentation covered by 3 | this license (the "Software") to use, reproduce, display, distribute, 4 | execute, and transmit the Software, and to prepare derivative works of the 5 | Software, and to permit third-parties to whom the Software is furnished to 6 | do so, all subject to the following: 7 | 8 | The copyright notices in the Software and this entire statement, including 9 | the above license grant, this restriction and the following disclaimer, 10 | must be included in all copies of the Software, in whole or in part, and 11 | all derivative works of the Software, unless such copies or derivative 12 | works are solely in the form of machine-executable object code generated by 13 | a source language processor. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT 18 | SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE 19 | FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, 20 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 21 | DEALINGS IN THE SOFTWARE. 
22 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Neural Nano-Optics for High-quality Thin Lens Imaging
2 | ### [Project Page](https://light.princeton.edu/neural-nano-optics/) | [Paper](https://www.nature.com/articles/s41467-021-26443-0) | [Data](https://drive.google.com/drive/folders/1fsAvN9MPtN5jJPeIFjWuLUY9Hp8NNkar?usp=sharing)
3 | 
4 | [![DOI: 10.5281/zenodo.5637678](https://zenodo.org/badge/doi/10.5281/zenodo.5637678.svg)](https://doi.org/10.5281/zenodo.5637678)
5 | 
6 | [Ethan Tseng](https://ethan-tseng.github.io), [Shane Colburn](https://scholar.google.com/citations?user=WLnx6NkAAAAJ&hl=en), [James Whitehead](https://scholar.google.com/citations?user=Hpcg0h4AAAAJ&hl=en), [Luocheng Huang](https://scholar.google.com/citations?user=x9UDJHgAAAAJ&hl=en), [Seung-Hwan Baek](https://sites.google.com/view/shbaek/), [Arka Majumdar](https://scholar.google.com/citations?user=DpIGlW4AAAAJ&hl=en), [Felix Heide](https://www.cs.princeton.edu/~fheide/)
7 | 
8 | This code implements a differentiable proxy model for simulating meta-optics and a neural feature propagation deconvolution method. These components are optimized end-to-end using machine learning optimizers.
9 | 
10 | The experimental results from the manuscript and the supplemental information are reproducible with this implementation. The proposed differentiable proxy model, neural feature propagation, and end-to-end optimization framework are implemented completely in TensorFlow, without dependency on third-party optics simulation libraries.
11 | 
12 | ## Training
13 | To perform end-to-end training of the meta-optic and the deconvolution network, execute the 'run_train.sh' script. The model checkpoint, which includes saved parameters for both the meta-optic and the deconvolution network, will be saved to 'training/ckpt'. The folder 'training/data' contains a subset of the training and test data that we used for optimizing our end-to-end imaging pipeline.
14 | 
15 | ## Testing
16 | To perform inference on real-world captures, launch the "test.ipynb" notebook in Jupyter Notebook and step through the cells. The notebook will load a fine-tuned checkpoint of our neural feature propagation network from 'experimental/ckpt', which is then used to process captured sensor measurements located in 'experimental/data'. The reconstructed images will be displayed within the notebook.
17 | 
18 | Additional captured sensor measurements can be found in the [data repository](https://drive.google.com/drive/folders/1fsAvN9MPtN5jJPeIFjWuLUY9Hp8NNkar?usp=sharing).
19 | 
20 | ## Requirements
21 | This code has been tested with Python 3.6.10 and TensorFlow 2.2.0, running on Linux with an Nvidia P100 GPU with 16 GB of RAM.
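
For reference, an environment along these lines can be set up with pip. The exact pins below are an assumption based on the tested configuration and the package list below, not an official requirements file:
```
pip install tensorflow-gpu==2.2.0 tensorflow-probability tensorflow-addons numpy scipy matplotlib notebook
```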
22 | 
23 | We installed the following library packages to run this code:
24 | ```
25 | TensorFlow >= 2.2
26 | TensorFlow Probability
27 | TensorFlow Addons
28 | Numpy
29 | Scipy
30 | matplotlib
31 | jupyter-notebook
32 | ```
33 | 
34 | ## Citation
35 | If you find our work useful in your research, please cite:
36 | ```
37 | @article{Tseng2021NeuralNanoOptics,
38 | author={Tseng, Ethan and Colburn, Shane and Whitehead, James and Huang, Luocheng
39 | and Baek, Seung-Hwan and Majumdar, Arka and Heide, Felix},
40 | title={Neural nano-optics for high-quality thin lens imaging},
41 | journal={Nature Communications},
42 | year={2021},
43 | month={Nov},
44 | day={29},
45 | volume={12},
46 | number={1},
47 | pages={6493}
48 | }
49 | ```
50 | 
51 | ## License
52 | Our code is licensed under BSL-1.0. By downloading the software, you agree to the terms of this License. The training data in the folder 'training/data' comes from the [INRIA Holidays Dataset](https://lear.inrialpes.fr/~jegou/data.php).
53 | 
-------------------------------------------------------------------------------- /args.py: --------------------------------------------------------------------------------
1 | # Argument parameters file
2 | 
3 | import argparse
4 | 
5 | def parse_args():
6 |     def str2bool(v):
7 |         assert v in ('True', 'False')
8 |         return v == 'True'
9 | 
10 |     def none_or_str(value):
11 |         if value.lower() == 'none':
12 |             return None
13 |         return value
14 | 
15 |     parser = argparse.ArgumentParser(description='Parameter settings for end-to-end optimization of neural nano-optics')
16 | 
17 |     # Data loading arguments
18 |     parser.add_argument('--train_dir'  , type=str, required=True, help='Directory of training input images')
19 |     parser.add_argument('--test_dir'   , type=str, required=True, help='Directory of testing input images')
20 | 
21 |     # Saving and logging arguments
22 |     parser.add_argument('--save_dir'   , type=str, required=True, help='Directory for saving ckpts and TensorBoard file')
23 |     parser.add_argument('--save_freq'  , type=int, default=1000, help='Interval to save model')
24 |     parser.add_argument('--log_freq'   , type=int, default=500, help='Interval to write to TensorBoard')
25 |     parser.add_argument('--ckpt_dir'   , type=none_or_str, default='None', help='Restoring from a checkpoint')
26 |     parser.add_argument('--max_to_keep', type=int, default=2, help='Number of checkpoints to save')
27 | 
28 |     # Loss arguments
29 |     parser.add_argument('--loss_mode'          , type=str, default='L1')
30 |     parser.add_argument('--batch_weights'      , type=str, default='1.0')
31 |     parser.add_argument('--Norm_loss_weight'   , type=float, default=1.0)
32 |     parser.add_argument('--P_loss_weight'      , type=float, default=0.0)
33 |     parser.add_argument('--Spatial_loss_weight', type=float, default=0.0)
34 |     parser.add_argument('--vgg_layers'         , type=str, default='block2_conv2,block3_conv2')
35 | 
36 |     # Training arguments
37 |     parser.add_argument('--steps'     , type=int, default=1000000000, help='Total number of optimization cycles')
38 |     parser.add_argument('--aug_rotate', type=str2bool, default=False, help='True to rotate PSF during training')
39 | 
40 |     # Convolution arguments
41 |     parser.add_argument('--real_psf'  , type=str, help='Npy of experimentally measured PSF')
42 |     parser.add_argument('--psf_mode'  , type=str, default='SIM_PSF', help="'SIM_PSF' to use the simulated PSF, 'REAL_PSF' to use the captured PSF")
43 |     parser.add_argument('--conv_mode' , type=str, default='SIM', help="'SIM' to apply the convolution forward model, 'REAL' to pass captures through unchanged")
44 |     parser.add_argument('--conv'      , type=str, default='patch_size',
                        help='patch_size for memory efficiency, full_size for full image')
45 |     parser.add_argument('--do_taper'     , type=str2bool, default=True, help='Activate edge tapering')
46 |     parser.add_argument('--offset'       , type=str2bool, default=True, help='True to use offset convolution mode')
47 |     parser.add_argument('--normalize_psf', type=str2bool, default=False, help='True to normalize PSF')
48 |     parser.add_argument('--theta_base'   , type=str, default = '0.0,5.0,10.0,15.0', help='Field angles')
49 | 
50 |     # Metasurface arguments
51 |     parser.add_argument('--num_coeffs'       , type=int, default=8, help='Number of optimizable phase coefficients')
52 |     parser.add_argument('--use_general_phase', type=str2bool, default=False, help='Set to true to use a pre-determined phase pattern')
53 |     parser.add_argument('--metasurface'      , type=str  , default='zeros', help='Metasurface initialization')
54 |     parser.add_argument('--s1'               , type=float, default=0.9e-3, help='s1 parameter for log-asphere/saxicon')
55 |     parser.add_argument('--s2'               , type=float, default=1.4e-3, help='s2 parameter for log-asphere/saxicon')
56 |     parser.add_argument('--alpha'            , type=float, default=270.176968209, help='Alpha value for cubic (set to 86*pi)')
57 |     parser.add_argument('--target_wavelength', type=float, default=511.0e-9, help='Target wavelength for hyperboloidal and squbic')
58 |     parser.add_argument('--bound_val'        , type=float, default=1000.0, help='Absolute value of range for phase coeff')
59 | 
60 |     # Sensor arguments
61 |     parser.add_argument('--a_poisson', type=float, default=0.00004, help='Poisson noise component')
62 |     parser.add_argument('--b_sqrt'   , type=float, default=0.00001, help='Gaussian noise standard deviation')
63 |     parser.add_argument('--mag'      , type=float, default=8.1, help='Relay system magnification factor (slightly less than 10x)')
64 | 
65 |     # Optimization arguments
66 |     parser.add_argument('--Phase_iters', type=int, default=1, help='Number of meta-optic optimization iterations per cycle')
67 |     parser.add_argument('--Phase_lr'   , type=float, default=5e-3, help='Meta-optic learning rate')
68 |     parser.add_argument('--Phase_beta1', type=float, default=0.9, help='Meta-optic beta1 term for Adam optimizer')
69 |     parser.add_argument('--G_iters'    , type=int, default=1, help='Number of deconvolution optimization iterations per cycle')
70 |     parser.add_argument('--G_lr'       , type=float, default=1e-4, help='Deconvolution learning rate')
71 |     parser.add_argument('--G_beta1'    , type=float, default=0.9, help='Deconvolution beta1 term for Adam optimizer')
72 |     parser.add_argument('--G_network'  , type=str, default='FP', help='Select deconvolution method')
73 |     parser.add_argument('--snr_opt'    , type=str2bool, default=False, help='True to optimize SNR parameter')
74 |     parser.add_argument('--snr_init'   , type=float, default=4.0, help='Initial value of SNR parameter')
75 | 
76 |     args = parser.parse_args()
77 |     args.theta_base = [float(w) for w in args.theta_base.split(',')]
78 |     args.batch_weights = [float(w) for w in args.batch_weights.split(',')]
79 |     print(args)
80 |     return args
81 | 
-------------------------------------------------------------------------------- /dataset.py: --------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | def load(image_width, image_width_padded, augment):
4 |     # image_width        = Width for image content
5 |     # image_width_padded = Width including padding to accommodate PSF
6 |     def load_fn(image_file):
7 |         image = tf.io.read_file(image_file)
8 |         image = tf.image.decode_jpeg(image)
9 | image = tf.cast(image, tf.float32) 10 | image = image / 255. 11 | 12 | if augment: 13 | image = tf.image.random_flip_left_right(image) 14 | image = tf.image.random_flip_up_down(image) 15 | image = tf.image.resize_with_crop_or_pad(image, image_width_padded, image_width_padded) 16 | return (image, image) # Input and GT 17 | return load_fn 18 | 19 | def train_dataset_sim(image_width, image_width_padded, args): 20 | load_fn = load(image_width, image_width_padded, augment=True) 21 | ds = tf.data.Dataset.list_files(args.train_dir+'*.jpg') 22 | ds = ds.map(load_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE) 23 | ds = ds.shuffle(20) 24 | ds = ds.repeat() # Repeat forever 25 | ds = ds.batch(1) # Batch size = 1 26 | ds = ds.prefetch(tf.data.experimental.AUTOTUNE) 27 | return ds 28 | 29 | def test_dataset_sim(image_width, image_width_padded, args): 30 | load_fn = load(image_width, image_width_padded, augment=False) 31 | ds = tf.data.Dataset.list_files(args.test_dir+'*.jpg', shuffle=False) 32 | ds = ds.map(load_fn, num_parallel_calls=None) 33 | ds = ds.batch(1) # Batch size = 1 34 | return ds 35 | -------------------------------------------------------------------------------- /experimental/ckpt/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "ckpt-98" 2 | all_model_checkpoint_paths: "ckpt-98" 3 | all_model_checkpoint_timestamps: 1608870002.7160997 4 | last_preserved_timestamp: 1608870001.6286473 5 | -------------------------------------------------------------------------------- /experimental/ckpt/ckpt-98.data-00000-of-00002: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/experimental/ckpt/ckpt-98.data-00000-of-00002 -------------------------------------------------------------------------------- /experimental/ckpt/ckpt-98.data-00001-of-00002: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/experimental/ckpt/ckpt-98.data-00001-of-00002 -------------------------------------------------------------------------------- /experimental/ckpt/ckpt-98.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/experimental/ckpt/ckpt-98.index -------------------------------------------------------------------------------- /experimental/data/captures/102302.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/experimental/data/captures/102302.npy -------------------------------------------------------------------------------- /experimental/data/captures/110802.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/experimental/data/captures/110802.npy -------------------------------------------------------------------------------- /experimental/data/captures/138301.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/experimental/data/captures/138301.npy -------------------------------------------------------------------------------- /experimental/data/psf/psf.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/experimental/data/psf/psf.npy -------------------------------------------------------------------------------- /experimental/data/vignette_factor.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/experimental/data/vignette_factor.npy -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | # Loss functions 2 | 3 | import tensorflow as tf 4 | 5 | # Per-Pixel loss 6 | def Norm_loss(G_img, gt_img, args): 7 | if args.loss_mode == 'L1': metric = tf.abs 8 | elif args.loss_mode == 'L2': metric = tf.square 9 | else: assert False, ("Mode needs to be L1 or L2") 10 | 11 | loss = 0.0 12 | for i, weight in enumerate(args.batch_weights): 13 | loss = loss + weight * tf.reduce_mean(metric(G_img[i,:,:,:] - gt_img[0,:,:,:])) 14 | return loss 15 | 16 | # Perceptual loss (VGG19) 17 | def P_loss(G_img, gt_img, vgg_model, args): 18 | if args.loss_mode == 'L1': metric = tf.abs 19 | elif args.loss_mode == 'L2': metric = tf.square 20 | else: assert False, ("Mode needs to be L1 or L2") 21 | 22 | preprocessed_G_img = tf.keras.applications.vgg19.preprocess_input(G_img*255.0) 23 | preprocessed_gt_img = tf.keras.applications.vgg19.preprocess_input(gt_img*255.0) 24 | 25 | G_layer_outs = vgg_model(preprocessed_G_img) 26 | gt_layer_outs = vgg_model(preprocessed_gt_img) 27 | 28 | loss = 0.0 29 | for i, weight in enumerate(args.batch_weights): 30 | loss = loss + weight * tf.add_n([tf.reduce_mean(metric( (G_layer_out[i,:,:,:] - gt_layer_out[0,:,:,:]) / 255. 
)) 31 | for G_layer_out, gt_layer_out in zip(G_layer_outs, gt_layer_outs)]) 32 | return loss 33 | 34 | # Spatial gradient loss 35 | def Spatial_loss(output_img, GT_img, args): 36 | if args.loss_mode == 'L1': metric = tf.abs 37 | elif args.loss_mode == 'L2': metric = tf.square 38 | else: assert False, ("Mode needs to be L1 or L2") 39 | 40 | def spatial_gradient(x): 41 | diag_down = x[:, 1:, 1:, :] - x[:, :-1, :-1, :] 42 | dv = x[:, 1:, :, :] - x[:, :-1, :, :] 43 | dh = x[:, :, 1:, :] - x[:, :, :-1, :] 44 | diag_up = x[:, :-1, 1:, :] - x[:, 1:, :-1, :] 45 | 46 | return [dh, dv, diag_down, diag_up] 47 | 48 | total_loss = 0.0 49 | for i, weight in enumerate(args.batch_weights): 50 | gx = spatial_gradient(output_img[i:i+1,:,:,:]) 51 | gy = spatial_gradient(GT_img) 52 | loss = 0 53 | for xx, yy in zip(gx, gy): 54 | loss = loss + tf.reduce_mean(metric(xx - yy)) 55 | total_loss = total_loss + weight * loss 56 | return total_loss 57 | 58 | # Loss for the entire end-to-end imaging pipeline 59 | def G_loss(G_img, gt_img, vgg_model, args): 60 | # Compute metrics 61 | PSNR = tf.reduce_mean(tf.image.psnr(G_img, gt_img, max_val=1.0)) 62 | SSIM = tf.reduce_mean(tf.image.ssim(G_img, gt_img, max_val=1.0)) 63 | metrics = {'PSNR':PSNR, 'SSIM':SSIM} 64 | 65 | # Compute losses 66 | Norm_loss_val = 0.0 67 | P_loss_val = 0.0 68 | Spatial_loss_val = 0.0 69 | if not args.Norm_loss_weight == 0.0: 70 | Norm_loss_val = args.Norm_loss_weight * Norm_loss(G_img, gt_img, args) 71 | if not args.P_loss_weight == 0.0: 72 | P_loss_val = args.P_loss_weight * P_loss(G_img, gt_img, vgg_model, args) 73 | if not args.Spatial_loss_weight == 0.0: 74 | Spatial_loss_val = args.Spatial_loss_weight * Spatial_loss(G_img, gt_img, args) 75 | Content_loss_val = Norm_loss_val + P_loss_val + Spatial_loss_val 76 | loss_components = {'Norm':Norm_loss_val, 'P':P_loss_val, 'Spatial':Spatial_loss_val} 77 | 78 | return Content_loss_val, loss_components, metrics 79 | -------------------------------------------------------------------------------- /metasurface/conv.py: -------------------------------------------------------------------------------- 1 | # Convolution and Fourier operations 2 | 3 | import tensorflow as tf 4 | import time 5 | 6 | def fft(img): 7 | img = tf.transpose(tf.cast(img, dtype = tf.complex64), perm = [0, 3, 1, 2]) 8 | Fimg = tf.signal.fft2d(img) 9 | return Fimg 10 | 11 | def ifft(Fimg): 12 | img = tf.cast(tf.abs(tf.signal.ifft2d(Fimg)), dtype=tf.float32) 13 | img = tf.transpose(img, perm = [0, 2, 3, 1]) 14 | return img 15 | 16 | def psf2otf(psf, h, w): 17 | psf = tf.image.resize_with_crop_or_pad(psf, h, w) 18 | psf = tf.transpose(tf.cast(psf, dtype = tf.complex64), perm = [0, 3, 1, 2]) 19 | psf = tf.signal.fftshift(psf, axes=(2,3)) 20 | otf = tf.signal.fft2d(psf) 21 | return otf 22 | 23 | # Assume non-padded PSF as input 24 | def get_edgetaper_weight(psf, im_h, im_w, mode='autocorrelation'): 25 | assert(im_h // 2 >= psf.shape[1]) 26 | assert(im_w // 2 >= psf.shape[2]) 27 | 28 | if mode == 'autocorrelation': 29 | padding_h = im_h - 1 - psf.shape[1] 30 | padding_w = im_w - 1 - psf.shape[2] 31 | 32 | psf_prj_h = tf.reduce_sum(psf, axis=1, keepdims=True) 33 | psf_prj_h = tf.transpose(psf_prj_h, perm=[0,3,1,2]) # Move dimension to inner-most location 34 | psf_prj_h = tf.pad(psf_prj_h, paddings=((0,0),(0,0),(0,0),(padding_h,0)),mode='constant') 35 | psf_prj_h = tf.square(tf.abs(tf.signal.fft(tf.cast(psf_prj_h[:,:,:,:], dtype=tf.complex64)))) 36 | psf_prj_h = tf.math.real(tf.signal.ifft(tf.cast(psf_prj_h, dtype=tf.complex64))) 37 | 
psf_prj_h = tf.concat([psf_prj_h, psf_prj_h[:,:,:,0:1]],axis=-1)
38 |         psf_prj_h = psf_prj_h / tf.reduce_max(psf_prj_h, axis=-1, keepdims=True)
39 | 
40 |         psf_prj_w = tf.reduce_sum(psf, axis=2, keepdims=True)
41 |         psf_prj_w = tf.transpose(psf_prj_w, perm=[0,3,2,1]) # Move dimension to inner-most location
42 |         psf_prj_w = tf.pad(psf_prj_w, paddings=((0,0),(0,0),(0,0),(padding_w,0)),mode='constant')
43 |         psf_prj_w = tf.square(tf.abs(tf.signal.fft(tf.cast(psf_prj_w[:,:,:,:], dtype=tf.complex64))))
44 |         psf_prj_w = tf.math.real(tf.signal.ifft(tf.cast(psf_prj_w, dtype=tf.complex64)))
45 |         psf_prj_w = tf.concat([psf_prj_w, psf_prj_w[:,:,:,0:1]],axis=-1)
46 |         psf_prj_w = psf_prj_w / tf.reduce_max(psf_prj_w, axis=-1, keepdims=True)
47 | 
48 |         psf_prj_h = tf.transpose(psf_prj_h, perm=[0,3,2,1])
49 |         psf_prj_w = tf.transpose(psf_prj_w, perm=[0,2,3,1])
50 | 
51 |         weight = (1 - psf_prj_h) * (1 - psf_prj_w)
52 |         return weight
53 |     elif mode == 'bilinear':
54 |         # im_h and im_w are the function arguments; the channel count comes from the PSF
55 |         channels = int(psf.shape[3])
56 |         psf_h = int(psf.shape[1])
57 |         psf_w = int(psf.shape[2])
58 | 
59 |         # Flat interior window sized to the image minus half the PSF extent
60 |         window_vec_h = tf.ones([im_h - psf_h//2], dtype=tf.float32)
61 |         window_vec_w = tf.ones([im_w - psf_w//2], dtype=tf.float32)
62 | 
63 |         window_vec_h = tf.concat([tf.zeros(psf_h//4, dtype=tf.float32),
64 |                                   tf.linspace(0.0,1.0, num=psf_h//4),
65 |                                   window_vec_h,
66 |                                   tf.linspace(1.0,0.0, num=psf_h//4),
67 |                                   tf.zeros(psf_h//4, dtype=tf.float32)], axis=0)
68 |         window_vec_w = tf.concat([tf.zeros(psf_w//4, dtype=tf.float32),
69 |                                   tf.linspace(0.0,1.0, num=psf_w//4),
70 |                                   window_vec_w,
71 |                                   tf.linspace(1.0,0.0, num=psf_w//4),
72 |                                   tf.zeros(psf_w//4, dtype=tf.float32)], axis=0)
73 |         window_vec_h = tf.cast(window_vec_h, dtype=tf.float32)
74 |         window_vec_w = tf.cast(window_vec_w, dtype=tf.float32)
75 |         window_vec_h = window_vec_h[tf.newaxis,:,tf.newaxis,tf.newaxis]
76 |         window_vec_w = window_vec_w[tf.newaxis,tf.newaxis,:,tf.newaxis]
77 |         edgetaper_weight = window_vec_h * window_vec_w
78 |         edgetaper_weight = tf.concat(channels * [edgetaper_weight], axis=3)
79 |         edgetaper_weight = tf.image.resize_with_crop_or_pad(edgetaper_weight, im_h, im_w)
80 |         return edgetaper_weight
81 |     else:
82 |         assert 0
83 | 
84 | 
85 | # Wiener filter deconvolution with optional edgetapering
86 | # otf - Precomputed Optical Transfer Function
87 | # ew  - Precomputed edgetaper weight
88 | def deconvolve_wnr(blur, snr, otf, ew, do_taper=False):
89 |     if do_taper:
90 |         blur_tapered = ifft(fft(blur) * otf)
91 |         blur = ew * blur + (1 - ew) * blur_tapered
92 |     blur_debug = blur # Defined unconditionally so it exists even when do_taper is False
93 | 
94 |     wiener_filter = tf.math.conj(otf) / (tf.cast(tf.abs(otf) ** 2, tf.complex64) + tf.cast(1 / tf.abs(snr), tf.complex64))
95 |     output = tf.cast(tf.abs(ifft(wiener_filter * fft(blur))), tf.float32)
96 |     return output, blur_debug
97 | 
98 | 
99 | # Forward pass
100 | def convolution_tf(params, args):
101 |     def conv_fn(image, psf):
102 |         if args.conv_mode == 'REAL':
103 |             return image
104 |         assert((image.shape[1]) == params['load_width'])
105 |         assert((image.shape[2]) == params['load_width'])
106 |         otf = psf2otf(psf, params['load_width'], params['load_width'])
107 |         blur = ifft(fft(image) * otf)
108 |         blur = tf.image.resize_with_crop_or_pad(blur, params['network_width'], params['network_width'])
109 |         return blur
110 |     return conv_fn
111 | 
112 | 
113 | # Backward pass
114 | def deconvolution_tf(params, args):
115 |     def deconv_fn(blur, psf, snr, G, training):
116 |         h = blur.shape[1]
117 |         w = blur.shape[2]
118 | 
119 |         # Pre-compute optical transfer function and edgetaper
weights at different scales 120 | psf_1x = psf 121 | otf_1x = psf2otf(psf_1x, h , w ) 122 | ew_1x = get_edgetaper_weight(psf_1x, h , w ) 123 | 124 | psf_2x = tf.image.resize(psf, [tf.constant(psf.shape[1]//2, dtype=tf.int32), 125 | tf.constant(psf.shape[2]//2, dtype=tf.int32)], 126 | method='bilinear', preserve_aspect_ratio=True) 127 | psf_2x = psf_2x / tf.reduce_sum(psf_2x, axis=[1,2], keepdims=True) 128 | otf_2x = psf2otf(psf_2x, h//2, w//2) 129 | ew_2x = get_edgetaper_weight(psf_2x, h//2, w//2) 130 | 131 | psf_4x = tf.image.resize(psf, [tf.constant(psf.shape[1]//4, dtype=tf.int32), 132 | tf.constant(psf.shape[2]//4, dtype=tf.int32)], 133 | method='bilinear', preserve_aspect_ratio=True) 134 | psf_4x = psf_4x / tf.reduce_sum(psf_4x, axis=[1,2], keepdims=True) 135 | otf_4x = psf2otf(psf_4x, h//4, w//4) 136 | ew_4x = get_edgetaper_weight(psf_4x, h//4, w//4) 137 | 138 | # Apply deconvolution algorithm and return time spent 139 | start = time.time() 140 | G_img, *G_debug = G([blur, tf.expand_dims(snr, 0), otf_1x, ew_1x, otf_2x, ew_2x, otf_4x, ew_4x], training=training) 141 | end = time.time() 142 | t = end - start 143 | return t, G_img, G_debug 144 | return deconv_fn 145 | -------------------------------------------------------------------------------- /metasurface/solver.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_probability as tfp 3 | import tensorflow_addons as tfa 4 | import numpy as np 5 | 6 | # Make the phase function for even asphere polynomials. 7 | #num_coeffs = 8 8 | #def phase_func(x, a2, a4, a6, a8, a10, a12, a14, a16): 9 | # return a2 * x ** 2 + a4 * x ** 4 + a6 * x ** 6 + a8 * x ** 8 + a10 * x ** 10 + a12 * x ** 12 + a14 * x ** 14 + a16 * x ** 16 10 | def make_phase_func(num_coeffs): 11 | func_str = 'def phase_func(x,' 12 | for i in range(num_coeffs): 13 | func_str = func_str + 'a' + str(2*(i+1)) 14 | if i < num_coeffs - 1: 15 | func_str = func_str + ',' 16 | func_str = func_str + '): return ' 17 | for i in range(num_coeffs): 18 | func_str = func_str + 'a' + str(2*(i+1)) + '*x**' + str(2*(i+1)) 19 | if i < num_coeffs - 1: 20 | func_str = func_str + ' + ' 21 | ldict = {} 22 | print(func_str) 23 | exec(func_str, globals(), ldict) 24 | return ldict['phase_func'] 25 | 26 | # Initializes parameters used in the simulation and optimization. 27 | def initialize_params(args): 28 | 29 | theta_base = args.theta_base 30 | phi_base = 0.0 # Phi angle for full field simulation. Currently unused. 31 | 32 | # Define the `params` dictionary. 33 | params = dict({}) 34 | 35 | # Number of optimizable phase coefficients 36 | params['num_coeffs'] = args.num_coeffs 37 | params['phase_func'] = make_phase_func(params['num_coeffs']) 38 | 39 | # Units and tensor dimensions. 40 | params['nanometers'] = 1E-9 41 | params['degrees'] = np.pi / 180 42 | 43 | # Upsampling for Fourier optics propagation 44 | params['upsample'] = 1 45 | params['normalize_psf'] = args.normalize_psf 46 | 47 | # Sensor parameters 48 | params['magnification'] = args.mag # Image magnification 49 | params['sensor_pixel'] = 5.86E-6 # Meters 50 | params['sensor_height'] = 1216 # Sensor pixels 51 | params['sensor_width'] = 1936 # Sensor pixels 52 | params['a_poisson'] = args.a_poisson # Poisson noise component 53 | params['b_sqrt'] = args.b_sqrt # Gaussian noise standard deviation 54 | 55 | # Focal length 56 | params['f'] = 1E-3 57 | 58 | # Tensor shape parameters and upsampling. 
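    # Illustrative sizing check (numbers computed here from the defaults, not
    # quoted from the paper): with f = 1 mm, a 20 degree max field angle,
    # mag = 8.1, 5.86 um sensor pixels, and the offset grid dim = 2*(4-1)-1 = 5,
    # the raw width below is f*tan(20deg)*sqrt(2)*mag/sensor_pixel ~= 711.5 px,
    # which the rounding to a multiple of 2*dim = 10 brings to a 720 x 720 image.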
59 | lambda_base = [606.0, 511.0, 462.0] 60 | params['lambda_base'] = lambda_base # Screen wavelength 61 | params['lambda_base_weights'] = np.array([1.0, 1.0, 1.0]) 62 | params['theta_base'] = theta_base 63 | params['phi_base'] = [0.0] 64 | 65 | # PSF grid shape. 66 | # dim is set to work with the offset PSF training scheme 67 | if args.offset: 68 | dim = np.int(2 * (np.size(params['theta_base']) - 1) - 1) 69 | else: 70 | dim = 5 # <-- TODO: Hack to get image size to be 720 x 720 71 | psfs_grid_shape = [dim, dim] 72 | params['psfs_grid_shape'] = psfs_grid_shape 73 | 74 | # Square input image width based on max field angle (20 degrees) 75 | image_width = params['f'] * np.tan(20.0 * np.pi / 180.0) * np.sqrt(2) 76 | image_width = image_width * params['magnification'] / params['sensor_pixel'] 77 | params['image_width'] = np.int(2*dim * np.ceil(image_width / (2*dim) )) 78 | 79 | if args.conv == 'patch_size': 80 | # Patch sized image for training efficiency 81 | params['psf_width'] = (params['image_width'] // dim) 82 | assert(params['psf_width'] % 2 == 0) 83 | params['hw'] = (params['psf_width']) // 2 84 | params['load_width'] = (params['image_width'] // params['psfs_grid_shape'][0]) + 2*params['psf_width'] 85 | params['network_width'] = (params['image_width'] // params['psfs_grid_shape'][0]) + params['psf_width'] 86 | params['out_width'] = (params['image_width'] // params['psfs_grid_shape'][0]) 87 | elif args.conv == 'full_size': 88 | # Full size image for inference 89 | params['psf_width'] = (params['image_width'] // 2) 90 | print(params['psf_width']) 91 | assert(params['psf_width'] % 2 == 0) 92 | params['hw'] = (params['psf_width']) // 2 93 | params['load_width'] = params['image_width'] + 2*params['psf_width'] 94 | params['network_width'] = params['image_width'] + params['psf_width'] 95 | params['out_width'] = params['image_width'] 96 | else: 97 | assert 0 98 | 99 | print('Image width: {}'.format(params['image_width'])) 100 | print('PSF width: {}'.format(params['psf_width'])) 101 | print('Load width: {}'.format(params['load_width'])) 102 | print('Network width: {}'.format(params['network_width'])) 103 | print('Out width: {}'.format(params['out_width'])) 104 | 105 | params['batchSize'] = np.size(lambda_base) * np.size(theta_base) * np.size(phi_base) 106 | batchSize = params['batchSize'] 107 | fwhm = np.array([35.0, 34.0, 21.0]) # Screen fwhm 108 | params['sigma'] = fwhm / 2.355 109 | calib_fwhm = np.array([15.0, 30.0, 14.0]) # Calibration fwhm 110 | params['calib_sigma'] = calib_fwhm / 2.355 111 | num_pixels = 1429 # Needed for 0.5 mm diameter aperture 112 | params['pixels_aperture'] = num_pixels 113 | pixelsX = num_pixels 114 | pixelsY = num_pixels 115 | params['pixelsX'] = pixelsX 116 | params['pixelsY'] = pixelsY 117 | params['upsample'] = 1 118 | 119 | # Simulation grid. 120 | params['wavelength_nominal'] = 452E-9 121 | params['pitch'] = 350E-9 122 | params['Lx'] = 1 * params['pitch'] 123 | params['Ly'] = params['Lx'] 124 | dx = params['Lx'] # grid resolution along x 125 | dy = params['Ly'] # grid resolution along x 126 | xa = np.linspace(0, pixelsX - 1, pixelsX) * dx # x axis array 127 | xa = xa - np.mean(xa) # center x axis at zero 128 | ya = np.linspace(0, pixelsY - 1, pixelsY) * dy # y axis vector 129 | ya = ya - np.mean(ya) # center y axis at zero 130 | [y_mesh, x_mesh] = np.meshgrid(ya, xa) 131 | params['x_mesh'] = x_mesh 132 | params['y_mesh'] = y_mesh 133 | 134 | # Wavelengths and field angles. 
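    # Batch layout sketch (assuming the defaults, 3 wavelengths x 4 thetas x 1 phi):
    # np.repeat orders lam0 as  [l0,l0,l0,l0, l1,l1,l1,l1, l2,l2,l2,l2] while
    # np.tile orders theta as   [t0,t1,t2,t3, t0,t1,t2,t3, t0,t1,t2,t3],
    # so batch element b = 4*i + j pairs wavelength i with field angle j.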
135 |     lam0 = params['nanometers'] * tf.convert_to_tensor(np.repeat(lambda_base, np.size(theta_base) * np.size(phi_base)), dtype = tf.float32)
136 |     lam0 = lam0[:, tf.newaxis, tf.newaxis]
137 |     lam0 = tf.tile(lam0, multiples = (1, pixelsX, pixelsY))
138 |     params['lam0'] = lam0
139 | 
140 |     theta = params['degrees'] * tf.convert_to_tensor(np.tile(theta_base, np.size(lambda_base) * np.size(phi_base)), dtype = tf.float32)
141 |     theta = theta[:, tf.newaxis, tf.newaxis]
142 |     theta = tf.tile(theta, multiples = (1, pixelsX, pixelsY))
143 |     params['theta'] = theta
144 | 
145 |     phi = np.repeat(phi_base, np.size(theta_base))
146 |     phi = np.tile(phi, np.size(lambda_base))
147 |     phi = params['degrees'] * tf.convert_to_tensor(phi, dtype = tf.float32)
148 |     phi = phi[:, tf.newaxis, tf.newaxis]
149 |     phi = tf.tile(phi, multiples = (1, pixelsX, pixelsY))
150 |     params['phi'] = phi
151 | 
152 |     # Propagation parameters.
153 |     params['propagator'] = make_propagator(params)
154 |     params['input'] = define_input_fields(params)
155 | 
156 |     # Metasurface proxy phase model.
157 |     params['phase_to_structure_coeffs'] = [-0.1484, 0.6809, 0.2923]
158 |     params['structure_to_phase_coeffs'] = [6.051, -0.02033, 2.26, 1.371E-5, -0.002947, 0.797]
159 |     params['use_proxy_phase'] = True
160 | 
161 |     # Compute the PSFs on the full field grid without exploiting azimuthal symmetry.
162 |     params['full_field'] = False # Not currently used
163 | 
164 |     # Use a predefined phase pattern (cubic, log-asphere, shifted axicon)
165 |     params['use_general_phase'] = args.use_general_phase
166 | 
167 |     # Manufacturing considerations.
168 |     params['fab_tolerancing'] = False #True
169 |     params['fab_error_global'] = 0.03 # +/- 6% duty cycle variation globally (2*sigma)
170 |     params['fab_error_local'] = 0.015 # +/- 3% duty cycle variation locally (2*sigma)
171 | 
172 |     return params
173 | 
174 | # Makes the reciprocal space propagator to use for the specified input conditions.
175 | def make_propagator(params):
176 | 
177 |     batchSize = params['batchSize']
178 |     pixelsX = params['pixelsX']
179 |     pixelsY = params['pixelsY']
180 |     upsample = params['upsample']
181 | 
182 |     # Propagator definition.
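    # In equation form (a transcription of the code below, no new assumptions):
    #   H(kx, ky) = exp( i * ( kz*f + kx*x0 + ky*y0 ) ),  kz = sqrt(k^2 - kx^2 - ky^2),
    # i.e. angular-spectrum propagation over the focal distance f plus a lateral
    # shift (x0, y0) = f * (tan(phi), tan(theta)) that recenters each oblique-angle PSF.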
183 | k = 2 * np.pi / params['lam0'][:, 0, 0] 184 | k = k[:, np.newaxis, np.newaxis] 185 | samp = params['upsample'] * pixelsX 186 | k = tf.tile(k, multiples = (1, 2 * samp - 1, 2 * samp - 1)) 187 | k = tf.cast(k, dtype = tf.complex64) 188 | k_xlist_pos = 2 * np.pi * np.linspace(0, 1 / (2 * params['Lx'] / params['upsample']), samp) 189 | front = k_xlist_pos[-(samp - 1):] 190 | front = -front[::-1] 191 | k_xlist = np.hstack((front, k_xlist_pos)) 192 | k_x = np.kron(k_xlist, np.ones((2 * samp - 1, 1))) 193 | k_x = k_x[np.newaxis, :, :] 194 | k_y = np.transpose(k_x, axes = [0, 2, 1]) 195 | k_x = tf.convert_to_tensor(k_x, dtype = tf.complex64) 196 | k_x = tf.tile(k_x, multiples = (batchSize, 1, 1)) 197 | k_y = tf.convert_to_tensor(k_y, dtype = tf.complex64) 198 | k_y = tf.tile(k_y, multiples = (batchSize, 1, 1)) 199 | k_z_arg = tf.square(k) - (tf.square(k_x) + tf.square(k_y)) 200 | k_z = tf.sqrt(k_z_arg) 201 | 202 | # Find shift amount 203 | theta = params['theta'][:, 0, 0] 204 | theta = theta[:, np.newaxis, np.newaxis] 205 | y0 = np.tan(theta) * params['f'] 206 | y0 = tf.tile(y0, multiples = (1, 2 * samp - 1, 2 * samp - 1)) 207 | y0 = tf.cast(y0, dtype = tf.complex64) 208 | 209 | phi = params['phi'][:, 0, 0] 210 | phi = phi[:, np.newaxis, np.newaxis] 211 | x0 = np.tan(phi) * params['f'] 212 | x0 = tf.tile(x0, multiples = (1, 2 * samp - 1, 2 * samp - 1)) 213 | x0 = tf.cast(x0, dtype = tf.complex64) 214 | 215 | propagator_arg = 1j * (k_z * params['f'] + k_x * x0 + k_y * y0) 216 | propagator = tf.exp(propagator_arg) 217 | 218 | return propagator 219 | 220 | # Propagate the specified fields to the sensor plane. 221 | def propagate(field, params): 222 | # Field has dimensions of (batchSize, pixelsX, pixelsY) 223 | # Each element corresponds to the zero order planewave component on the output 224 | propagator = params['propagator'] 225 | 226 | # Zero pad `field` to be a stack of 2n-1 x 2n-1 matrices 227 | # Put batch parameter last for padding then transpose back 228 | _, _, m = field.shape 229 | n = params['upsample'] * m 230 | field = tf.transpose(field, perm = [1, 2, 0]) 231 | field_real = tf.math.real(field) 232 | field_imag = tf.math.imag(field) 233 | field_real = tf.image.resize(field_real, [n, n], method = 'nearest') 234 | field_imag = tf.image.resize(field_imag, [n, n], method = 'nearest') 235 | field = tf.cast(field_real, dtype = tf.complex64) + 1j * tf.cast(field_imag, dtype = tf.complex64) 236 | field = tf.image.resize_with_crop_or_pad(field, 2 * n - 1, 2 * n - 1) 237 | field = tf.transpose(field, perm = [2, 0, 1]) 238 | 239 | field_freq = tf.signal.fftshift(tf.signal.fft2d(field), axes = (1, 2)) 240 | field_filtered = tf.signal.ifftshift(field_freq * propagator, axes = (1, 2)) 241 | out = tf.signal.ifft2d(field_filtered) 242 | 243 | # Crop back down to n x n matrices 244 | out = tf.transpose(out, perm = [1, 2, 0]) 245 | out = tf.image.resize_with_crop_or_pad(out, n, n) 246 | out = tf.transpose(out, perm = [2, 0, 1]) 247 | return out 248 | 249 | # Defines the input electric fields for the given wavelengths and field angles. 
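# In equation form (a transcription of phase_def below, no new assumptions), each
# batch element is a unit-amplitude tilted plane wave:
#   E(x, y) = exp( i * (2*pi/lambda) * ( sin(theta) * x + sin(phi) * y ) ).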
250 | def define_input_fields(params): 251 | 252 | # Define the cartesian cross section 253 | pixelsX = params['pixelsX'] 254 | pixelsY = params['pixelsY'] 255 | dx = params['Lx'] # grid resolution along x 256 | dy = params['Ly'] # grid resolution along y 257 | xa = np.linspace(0, pixelsX - 1, pixelsX) * dx # x axis array 258 | xa = xa - np.mean(xa) # center x axis at zero 259 | ya = np.linspace(0, pixelsY - 1, pixelsY) * dy # y axis vector 260 | ya = ya - np.mean(ya) # center y axis at zero 261 | [y_mesh, x_mesh] = np.meshgrid(ya, xa) 262 | x_mesh = x_mesh[np.newaxis, :, :] 263 | y_mesh = y_mesh[np.newaxis, :, :] 264 | lam_phase_test = params['lam0'][:, 0, 0] 265 | lam_phase_test = lam_phase_test[:, tf.newaxis, tf.newaxis] 266 | theta_phase_test = params['theta'][:, 0, 0] 267 | theta_phase_test = theta_phase_test[:, tf.newaxis, tf.newaxis] 268 | phi_phase_test = params['phi'][:, 0, 0] 269 | phi_phase_test = phi_phase_test[:, tf.newaxis, tf.newaxis] 270 | phase_def = 2 * np.pi / lam_phase_test * (np.sin(theta_phase_test) * x_mesh + np.sin(phi_phase_test) * y_mesh) 271 | phase_def = tf.cast(phase_def, dtype = tf.complex64) 272 | 273 | return tf.exp(1j * phase_def) 274 | 275 | # Generates a phase distribution modelling a metasurface given some phase coefficients. 276 | def metasurface_phase_generator(phase_coeffs, params): 277 | x_mesh = params['x_mesh'] 278 | y_mesh = params['y_mesh'] 279 | x_mesh = x_mesh[np.newaxis, :, :] 280 | y_mesh = y_mesh[np.newaxis, :, :] 281 | phase_def = tf.zeros(shape = np.shape(x_mesh), dtype = tf.float32) 282 | r_phase = np.sqrt(x_mesh ** 2 + y_mesh ** 2) / (params['pixels_aperture'] * params['Lx'] / 2.0) 283 | if params['use_general_phase'] == True: 284 | phase_def = params['general_phase'] 285 | else: 286 | for j in range(np.size(phase_coeffs.numpy())): 287 | power = tf.constant(2 * (j + 1), dtype = tf.float32) 288 | r_power = tf.math.pow(r_phase, power) 289 | phase_def = phase_def + phase_coeffs[j] * r_power 290 | 291 | phase_def = tf.math.floormod(phase_def, 2 * np.pi) 292 | if params['use_proxy_phase'] == True: 293 | # Determine the duty cycle distribution first. 294 | duty = duty_cycle_from_phase(phase_def, params) 295 | 296 | # Accounts for global and local process variations in grating duty cycle. 297 | if params['fab_tolerancing'] == True: 298 | global_error = tf.random.normal(shape = [1], mean = 0.0, stddev = params['fab_error_global'], dtype = tf.float32) 299 | local_error = tf.random.normal(shape = tf.shape(duty), mean = 0.0, stddev = params['fab_error_local'], dtype = tf.float32) 300 | duty = duty + global_error + local_error 301 | 302 | # Duty cycle is fit to this range and querying outside is not physically meaningful so we need to clip it. 303 | duty = tf.clip_by_value(duty, clip_value_min = 0.3, clip_value_max = 0.82) 304 | 305 | phase_def = phase_from_duty_and_lambda(duty, params) 306 | else: 307 | phase_def = phase_def * params['wavelength_nominal'] / params['lam0'] 308 | 309 | mask = ((x_mesh ** 2 + y_mesh ** 2) < (params['pixels_aperture'] * params['Lx'] / 2.0) ** 2) 310 | phase_def = phase_def * mask 311 | return phase_def 312 | 313 | # Calculates the required duty cycle distribution at the nominal wavelength given 314 | # a specified phase function using a pre-fit polynomial proxy for the mapping. 
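# Illustrative check with the default fit p = [-0.1484, 0.6809, 0.2923] (numbers
# computed here, not taken from the paper): a phase of pi gives phase/(2*pi) = 0.5,
# so duty = -0.1484*0.25 + 0.6809*0.5 + 0.2923 ~= 0.596, i.e. roughly 60% duty cycle,
# well inside the clipped [0.3, 0.82] fit range used elsewhere in this file.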
315 | def duty_cycle_from_phase(phase, params): 316 | phase = phase / (2 * np.pi) 317 | p = params['phase_to_structure_coeffs'] 318 | return p[0] * phase ** 2 + p[1] * phase + p[2] 319 | 320 | # Calculates the phase shift for a distribution of diameters at all the desired 321 | # simulation wavelengths using a pre-fit polynomial proxy for the mapping. 322 | def phase_from_duty_and_lambda(duty, params): 323 | p = params['structure_to_phase_coeffs'] 324 | lam = params['lam0'] / params['nanometers'] 325 | phase = p[0] + p[1]*lam + p[2]*duty + p[3]*lam**2 + p[4]*lam*duty + p[5]*duty**2 326 | return phase * 2 * np.pi 327 | 328 | # Finds the intensity at the sensor given the input fields. 329 | def compute_intensity_at_sensor(field, params): 330 | coherent_psf = propagate(params['input'] * field, params) 331 | return tf.math.abs(coherent_psf) ** 2 332 | 333 | # Determines the PSF from the intensity at the sensor, accounting for image magnification. 334 | def calculate_psf(intensity, params): 335 | # Transpose for subsequent reshaping 336 | intensity = tf.transpose(intensity, perm = [1, 2, 0]) 337 | aperture = params['pixels_aperture'] 338 | sensor_pixel = params['sensor_pixel'] 339 | magnification = params['magnification'] 340 | period = params['Lx'] 341 | 342 | # Determine PSF shape after optical magnification 343 | mag_width = int(np.round(aperture * period * magnification / sensor_pixel)) 344 | mag_intensity = tf.image.resize(intensity, [mag_width, mag_width], method='bilinear') # Sample onto sensor pixels 345 | 346 | # Maintain same energy as before optical magnification 347 | denom = tf.math.reduce_sum(mag_intensity, axis = [0, 1], keepdims = False) 348 | denom = denom[tf.newaxis, tf.newaxis, :] 349 | mag_intensity = mag_intensity * tf.math.reduce_sum(intensity, axis = [0, 1], keepdims = True) / denom 350 | 351 | # Crop to sensor dimensions 352 | sensor_psf = mag_intensity 353 | #sensor_psf = tf.image.resize_with_crop_or_pad(sensor_psf, params['sensor_height'], params['sensor_width']) 354 | sensor_psf = tf.transpose(sensor_psf, perm = [2, 0, 1]) 355 | sensor_psf = tf.clip_by_value(sensor_psf, 0.0, 1.0) 356 | return sensor_psf 357 | 358 | # Defines a metasurface, including phase and amplitude variation. 359 | def define_metasurface(phase_var, params): 360 | phase_def = metasurface_phase_generator(phase_var, params) 361 | phase_def = tf.cast(phase_def, dtype = tf.complex64) 362 | amp = ((params['x_mesh'] ** 2 + params['y_mesh'] ** 2) < (params['pixels_aperture'] * params['Lx'] / 2.0) ** 2) 363 | I = 1.0 / np.sum(amp) 364 | E_amp = np.sqrt(I) 365 | return amp * E_amp * tf.exp(1j * phase_def) 366 | 367 | # Shifts the raw PSF to be centered, cropped to the patch size, and stacked 368 | # along the channels dimension 369 | def shift_and_segment_psf(psf, params): 370 | # Calculate the shift amounts for each PSF. 371 | b, h, w = psf.shape 372 | shifted_psf = psf 373 | 374 | # Reshape the PSFs based on the color channel. 
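    # Shape walkthrough (assuming the default 3 wavelengths x 4 thetas x 1 phi):
    # the incoming (12, h, w) PSF stack is reshaped to (3, 4, h, w) and then
    # transposed to (4, h, w, 3), so each field angle carries a trailing RGB
    # channel axis for the per-color wavelength averaging below.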
375 | psf_channels_shape = (params['batchSize'] // (np.size(params['theta_base']) * np.size(params['phi_base'])), 376 | np.size(params['theta_base']) * np.size(params['phi_base']) , 377 | h, w) 378 | shifted_psf_c_channels = tf.reshape(shifted_psf, shape = psf_channels_shape) 379 | shifted_psf_c_channels = tf.transpose(shifted_psf_c_channels, perm = (1, 2, 3, 0)) 380 | 381 | samples = np.size(params['lambda_base']) // 3 382 | for j in range(np.size(params['theta_base']) * np.size(params['phi_base'])): 383 | psfs_j = shifted_psf_c_channels[j, :, :, :] 384 | for k in range(3): 385 | psfs_jk = psfs_j[:, :, k * samples : (k + 1) * samples] 386 | psfs_jk_avg = tf.math.reduce_sum(psfs_jk, axis = 2, keepdims = False) 387 | psfs_jk_avg = psfs_jk_avg[:, :, tf.newaxis] 388 | if k == 0: 389 | psfs_channels = psfs_jk_avg 390 | else: 391 | psfs_channels = tf.concat([psfs_channels, psfs_jk_avg], axis = 2) 392 | 393 | psfs_channels_expanded = psfs_channels[tf.newaxis, :, :, :] 394 | if j == 0: 395 | psfs_thetas_channels = psfs_channels_expanded 396 | else: 397 | psfs_thetas_channels = tf.concat([psfs_thetas_channels, psfs_channels_expanded], axis = 0) 398 | 399 | psfs_thetas_channels = psfs_thetas_channels[:, h // 2 - params['hw'] : h // 2 + params['hw'], 400 | w // 2 - params['hw'] : w // 2 + params['hw'], :] 401 | 402 | # Normalize to unit power per channel since multiple wavelengths are now combined into each channel 403 | if params['normalize_psf']: 404 | psfs_thetas_channels_sum = tf.math.reduce_sum(psfs_thetas_channels, axis = (1, 2), keepdims = True) 405 | psfs_thetas_channels = psfs_thetas_channels / psfs_thetas_channels_sum 406 | return psfs_thetas_channels 407 | 408 | 409 | # Rotate PSF (non-SVOLA) 410 | def rotate_psfs(psf, params, rotate=True): 411 | #psfs_grid_shape = params['psfs_grid_shape'] 412 | #rotations = np.zeros(np.prod(psfs_grid_shape)) 413 | psfs = shift_and_segment_psf(psf, params) 414 | rot_angle = 0.0 415 | if rotate: 416 | angles = np.array([0.0, 45.0, 90.0, 135.0, 180.0, 225.0, 270.0, 315.0], dtype=np.float32) 417 | rot_angle = (np.random.choice(angles) * np.pi / 180.0).astype(np.float32) 418 | rot_angles = tf.fill([np.size(params['theta_base']) * np.size(params['phi_base'])], rot_angle) 419 | psfs_rot = tfa.image.rotate(psfs, angles = rot_angles, interpolation = 'NEAREST') 420 | return psfs_rot 421 | 422 | # PSF patches are determined by rotating them into the different patch regions 423 | # for subsequent SVOLA convolution. 424 | def rotate_psf_patches(psf, params): 425 | psfs_grid_shape = params['psfs_grid_shape'] 426 | rotations = np.zeros(np.prod(psfs_grid_shape)) 427 | psfs = shift_and_segment_psf(psf, params) 428 | 429 | # Iterate through all positions in the PSF grid. 430 | mid_y = (psfs_grid_shape[0] - 1) // 2 431 | mid_x = (psfs_grid_shape[1] - 1) // 2 432 | for i in range(psfs_grid_shape[0]): 433 | for j in range(psfs_grid_shape[1]): 434 | r_idx = i - mid_y 435 | c_idx = j - mid_x 436 | 437 | if params['full_field'] == True: 438 | index = psfs_grid_shape[0] * j + i 439 | psf_ij = psfs[index, :, :, :] 440 | else: 441 | # Calculate the required rotation angle. 442 | rotations[i * psfs_grid_shape[0] + j] = np.arctan2(-r_idx, c_idx) + np.pi / 2 443 | 444 | # Set the PSF based on the normalized radial distance. 
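                # Illustrative example for a 5x5 grid (mid_y = mid_x = 2): the
                # corner (i, j) = (0, 0) has (r_idx, c_idx) = (-2, -2), so it
                # reuses the PSF at radial index max(2, 2) = 2 and is rotated
                # by arctan2(2, -2) + pi/2 = 5*pi/4.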
445 | psf_ij = psfs[max(abs(r_idx), abs(c_idx)),:,:,:] 446 | 447 | psf_ij = psf_ij[tf.newaxis, :, :, :] 448 | 449 | if (i == 0 and j == 0): 450 | psf_patches = psf_ij 451 | else: 452 | psf_patches = tf.concat([psf_patches, psf_ij], axis = 0) 453 | 454 | # Apply the rotations as a batch operation. 455 | psf_patches = tfa.image.rotate(psf_patches, angles = rotations, interpolation = 'NEAREST') 456 | return psf_patches 457 | 458 | 459 | def get_psfs(phase_var, params, conv_mode, aug_rotate): 460 | metasurface_mask = define_metasurface(phase_var, params) 461 | intensity = compute_intensity_at_sensor(metasurface_mask, params) 462 | psf = calculate_psf(intensity, params) 463 | psfs_single = rotate_psfs(psf, params, rotate=False) 464 | psfs_conv = rotate_psfs(psf, params, rotate=aug_rotate) 465 | return psfs_single, psfs_conv 466 | 467 | 468 | # Applies Poisson noise and adds Gaussian noise. 469 | def sensor_noise(input_layer, params, clip = (1E-20,1.)): 470 | 471 | # Apply Poisson noise. 472 | if (params['a_poisson'] > 0): 473 | a_poisson_tf = tf.constant(params['a_poisson'], dtype = tf.float32) 474 | 475 | input_layer = tf.clip_by_value(input_layer, clip[0], 100.0) 476 | p = tfp.distributions.Poisson(rate = input_layer / a_poisson_tf, validate_args = True) 477 | sampled = tfp.monte_carlo.expectation(f = lambda x: x, samples = p.sample(1), log_prob = p.log_prob, use_reparameterization = False) 478 | output = sampled * a_poisson_tf 479 | else: 480 | output = input_layer 481 | 482 | # Add Gaussian readout noise. 483 | gauss_noise = tf.random.normal(shape=tf.shape(output), mean = 0.0, stddev = params['b_sqrt'], dtype = tf.float32) 484 | output = output + gauss_noise 485 | 486 | # Clipping. 487 | output = tf.clip_by_value(output, clip[0], clip[1]) 488 | return output 489 | 490 | 491 | # Samples wavelengths from a random normal distribution centered about the peak 492 | # wavelengths in the spectra based on the FWHM of each peak. 493 | def randomize_wavelengths(params, lambda_base, sigma): 494 | pixelsX = params['pixelsX'] 495 | pixelsY = params['pixelsY'] 496 | thetas = params['theta_base'] 497 | phis = params['phi_base'] 498 | lambdas = np.random.normal(lambda_base, sigma) 499 | lam0 = params['nanometers'] * tf.convert_to_tensor(np.repeat(lambdas, np.size(thetas) * np.size(phis)), dtype = tf.float32) 500 | lam0 = lam0[:, tf.newaxis, tf.newaxis] 501 | lam0 = tf.tile(lam0, multiples = (1, pixelsX, pixelsY)) 502 | params['lam0'] = lam0 503 | 504 | # Reset wavelengths back to nominal wavelength 505 | def set_wavelengths(params, lambda_base): 506 | lam0 = params['nanometers'] * tf.convert_to_tensor(np.repeat(lambda_base, np.size(params['theta_base']) * np.size(params['phi_base'])), dtype = tf.float32) 507 | lam0 = lam0[:, tf.newaxis, tf.newaxis] 508 | lam0 = tf.tile(lam0, multiples = (1, params['pixelsX'], params['pixelsY'])) 509 | params['lam0'] = lam0 510 | 511 | 512 | ## General Phase Functions ## 513 | 514 | # Calculates the phase for a log-asphere based on the s1 and s2 parameters. 
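# As implemented below, with q = (s2 - s1) / R^2 and u(r) = s1 + q * r^2:
#   phi(r) = -(pi / (lambda_nominal * q)) *
#            [ ln( 2q (sqrt(r^2 + u^2) + u) + 1 ) - ln( 4q s1 + 1 ) ],
# (a direct transcription of the code, written out here only for readability),
# spreading the focus over axial distances s1 to s2 for extended depth of field.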
515 | def log_asphere_phase(s1, s2, params): 516 | x_mesh = params['x_mesh'] 517 | y_mesh = params['y_mesh'] 518 | x_mesh = x_mesh[np.newaxis, :, :] 519 | y_mesh = y_mesh[np.newaxis, :, :] 520 | r_phase = np.sqrt(x_mesh ** 2 + y_mesh ** 2) 521 | R = params['pixels_aperture'] * params['Lx'] / 2.0 # Aperture radius 522 | quo = (s2 - s1) / R ** 2 523 | quo_large = s1 + quo * r_phase**2 524 | term1 = np.pi / params['wavelength_nominal'] / quo 525 | term2 = np.log(2 * quo * (np.sqrt(r_phase**2 + quo_large**2) + quo_large) + 1) - np.log(4*quo*s1 + 1) 526 | phase_def = -term1 * term2 527 | phase_def = tf.convert_to_tensor(phase_def, dtype = tf.float32) 528 | mask = ((x_mesh ** 2 + y_mesh ** 2) < R ** 2) 529 | phase_def = phase_def * mask 530 | return phase_def 531 | 532 | 533 | # Calculates the phase for a shifted axicon based on the s1 and s2 parameters. 534 | def shifted_axicon_phase(s1, s2, params): 535 | x_mesh = params['x_mesh'] 536 | y_mesh = params['y_mesh'] 537 | x_mesh = x_mesh[np.newaxis, :, :] 538 | y_mesh = y_mesh[np.newaxis, :, :] 539 | r_phase = np.sqrt(x_mesh ** 2 + y_mesh ** 2) 540 | R = params['pixels_aperture'] * params['Lx'] / 2.0 # Aperture radius 541 | samples = 1 * params['pixels_aperture'] 542 | dr = R / samples 543 | phase_def = np.zeros((1, params['pixels_aperture'], params['pixels_aperture'])) 544 | for j in range(params['pixels_aperture']): 545 | for k in range(params['pixels_aperture']): 546 | r_max = r_phase[0, j, k] 547 | if r_max < R: 548 | if j <= params['pixels_aperture'] // 2 and k <= params['pixels_aperture'] // 2: 549 | r_vector = np.linspace(0, r_max, np.int(samples * r_max / R)) 550 | numerator = r_vector * dr 551 | denominator = np.sqrt(r_vector ** 2 + (s1 + (s2 - s1) * r_vector / R) ** 2) 552 | integrand = numerator / denominator 553 | phase_def[0, j, k] = np.sum(integrand) 554 | else: # Copy the previously computed result 555 | phase_def[0, j, k] = phase_def[0, min(j, params['pixels_aperture'] - j - 1), \ 556 | min(k, params['pixels_aperture'] - k - 1)] 557 | phase_def = -2 * np.pi / params['wavelength_nominal'] * phase_def 558 | phase_def = tf.convert_to_tensor(phase_def, dtype = tf.float32) 559 | mask = ((x_mesh ** 2 + y_mesh ** 2) < R ** 2) 560 | phase_def = phase_def * mask 561 | return phase_def 562 | 563 | 564 | # Calculates the phase for a cubic phase mask with a hyperboloidal lens term assuming focusing 565 | # for the specified wavelength. 566 | def cubic_phase(alpha, wavelength, params): 567 | x_mesh = params['x_mesh'] 568 | y_mesh = params['y_mesh'] 569 | x_mesh = x_mesh[np.newaxis, :, :] 570 | y_mesh = y_mesh[np.newaxis, :, :] 571 | 572 | # As we intend for the focusing term to be for the provided 'wavelength' parameter, we need to scale 573 | # the focal length because we are effectively designing an intentionally defocused lens at the 574 | # nominal wavelength. Output phase from this function needs to be the phase at the nominal wavelength. 
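    # e.g. (illustrative numbers only): target_wavelength = 511 nm with the
    # 452 nm nominal wavelength gives f = 1 mm * 511/452 ~= 1.13 mm.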
575 | f = params['f'] * wavelength / params['wavelength_nominal'] 576 | 577 | R = params['pixels_aperture'] * params['Lx'] / 2.0 # Aperture radius 578 | focusing_term = 2 * np.pi / params['wavelength_nominal'] * (f - np.sqrt(x_mesh ** 2 + y_mesh ** 2 + f ** 2)) 579 | edof_term = alpha / R ** 3 * (x_mesh ** 3 + y_mesh ** 3) 580 | phase_def = focusing_term + edof_term 581 | phase_def = tf.convert_to_tensor(phase_def, dtype = tf.float32) 582 | mask = ((x_mesh ** 2 + y_mesh ** 2) < R ** 2) 583 | phase_def = phase_def * mask 584 | return phase_def 585 | 586 | 587 | # Calculates the phase for a hyperboloidal lens term assuming focusing for the specified wavelength. 588 | def hyperboidal_phase(wavelength, params): 589 | x_mesh = params['x_mesh'] 590 | y_mesh = params['y_mesh'] 591 | x_mesh = x_mesh[np.newaxis, :, :] 592 | y_mesh = y_mesh[np.newaxis, :, :] 593 | 594 | # As we intend for the focusing term to be for the provided 'wavelength' parameter, we need to scale 595 | # the focal length because we are effectively designing an intentionally defocused lens at the 596 | # nominal wavelength. Output phase from this function needs to be the phase at the nominal wavelength. 597 | f = params['f'] * wavelength / params['wavelength_nominal'] 598 | 599 | R = params['pixels_aperture'] * params['Lx'] / 2.0 # Aperture radius 600 | phase_def = 2 * np.pi / params['wavelength_nominal'] * (f - np.sqrt(x_mesh ** 2 + y_mesh ** 2 + f ** 2)) 601 | phase_def = tf.convert_to_tensor(phase_def, dtype = tf.float32) 602 | mask = ((x_mesh ** 2 + y_mesh ** 2) < R ** 2) 603 | phase_def = phase_def * mask 604 | return phase_def 605 | -------------------------------------------------------------------------------- /networks/G/FP.py: -------------------------------------------------------------------------------- 1 | # Neural feature propagator network 2 | 3 | import tensorflow as tf 4 | import tensorflow_addons as tfa 5 | from metasurface.conv import deconvolve_wnr 6 | 7 | def conv(filters, size, stride, activation, apply_instnorm=True): 8 | result = tf.keras.Sequential() 9 | result.add(tf.keras.layers.Conv2D(filters, size, stride, padding='same', use_bias=True)) 10 | if apply_instnorm: 11 | result.add(tfa.layers.InstanceNormalization()) 12 | if not activation == None: 13 | result.add(activation()) 14 | return result 15 | 16 | def conv_transp(filters, size, stride, activation, apply_instnorm=True): 17 | result = tf.keras.Sequential() 18 | result.add(tf.keras.layers.Conv2DTranspose(filters, size, strides=2, padding='same', use_bias=True)) 19 | if not activation == None: 20 | result.add(activation()) 21 | return result 22 | 23 | def feat_extract(img, snr, otf_1x, ew_1x, otf_2x, ew_2x, otf_4x, ew_4x, params, args): 24 | LReLU = tf.keras.layers.LeakyReLU 25 | ReLU = tf.keras.layers.ReLU 26 | 27 | down_l0 = conv(15, 7, 1, LReLU, apply_instnorm=False)(img) 28 | down_l0 = conv(15, 7, 1, LReLU, apply_instnorm=False)(down_l0) 29 | 30 | down_l1 = conv(30, 5, 2, LReLU, apply_instnorm=False)(down_l0) 31 | down_l1 = conv(30, 3, 1, LReLU, apply_instnorm=False)(down_l1) 32 | down_l1 = conv(30, 3, 1, LReLU, apply_instnorm=False)(down_l1) 33 | 34 | down_l2 = conv(60, 5, 2, LReLU, apply_instnorm=False)(down_l1) 35 | down_l2 = conv(60, 3, 1, LReLU, apply_instnorm=False)(down_l2) 36 | down_l2 = conv(60, 3, 1, LReLU, apply_instnorm=False)(down_l2) 37 | 38 | # 4x 39 | conv_l2_k0 = conv(60, 3, 1, LReLU, apply_instnorm=False)(down_l2) 40 | conv_l2_k1 = conv(60, 3, 1, LReLU, apply_instnorm=False)(conv_l2_k0) 41 | 42 | conv_l2_k2 = conv(60, 3, 1, 
LReLU, apply_instnorm=False)(tf.concat([down_l2, conv_l2_k1], axis=3)) 43 | conv_l2_k3 = conv(60, 3, 1, LReLU, apply_instnorm=False)(conv_l2_k2) 44 | 45 | conv_l2_k4 = conv(60, 3, 1, LReLU, apply_instnorm=False)(conv_l2_k3) 46 | conv_l2_k5 = conv(60, 3, 1, LReLU, apply_instnorm=False)(conv_l2_k4) 47 | 48 | wien_l2_b, _, = deconvolve_wnr(conv_l2_k5, snr, tf.tile(otf_4x, [1, 20, 1, 1]), tf.tile(ew_4x, [1, 1, 1, 20]), do_taper=(args.do_taper)) 49 | 50 | # 2x 51 | conv_l1_k0 = conv(30, 3, 1, LReLU, apply_instnorm=False)(down_l1) 52 | conv_l1_k1 = conv(30, 3, 1, LReLU, apply_instnorm=False)(conv_l1_k0) 53 | 54 | conv_l1_k2 = conv(30, 3, 1, LReLU, apply_instnorm=False)(tf.concat([down_l1, conv_l1_k1], axis=3)) 55 | conv_l1_k3 = conv(30, 3, 1, LReLU, apply_instnorm=False)(conv_l1_k2) 56 | 57 | conv_l1_k4 = conv(30, 3, 1, LReLU, apply_instnorm=False)(conv_l1_k3) 58 | conv_l1_k5 = conv(30, 3, 1, LReLU, apply_instnorm=False)(conv_l1_k4) 59 | 60 | up_l2 = conv_transp(30, 2, 2, LReLU, apply_instnorm=False)(conv_l2_k5) 61 | conv_l1_k6 = conv(30, 3, 1, LReLU, apply_instnorm=False)(tf.concat([up_l2, conv_l1_k5], axis=3)) 62 | conv_l1_k7 = conv(30, 3, 1, LReLU, apply_instnorm=False)(conv_l1_k6) 63 | 64 | wien_l1_b, _ = deconvolve_wnr(conv_l1_k7, snr, tf.tile(otf_2x, [1, 10, 1, 1]), tf.tile(ew_2x, [1, 1, 1, 10]), do_taper=(args.do_taper)) 65 | 66 | # 1x 67 | conv_l0_k0 = conv(15, 5, 1, LReLU, apply_instnorm=False)(down_l0) 68 | conv_l0_k1 = conv(15, 5, 1, LReLU, apply_instnorm=False)(conv_l0_k0) 69 | 70 | conv_l0_k2 = conv(15, 5, 1, LReLU, apply_instnorm=False)(tf.concat([down_l0, conv_l0_k1], axis=3)) 71 | conv_l0_k3 = conv(15, 5, 1, LReLU, apply_instnorm=False)(conv_l0_k2) 72 | 73 | conv_l0_k4 = conv(15, 5, 1, LReLU, apply_instnorm=False)(conv_l0_k3) 74 | conv_l0_k5 = conv(15, 5, 1, LReLU, apply_instnorm=False)(conv_l0_k4) 75 | 76 | up_l1 = conv_transp(15, 2, 2, LReLU, apply_instnorm=False)(conv_l1_k5) 77 | conv_l0_k6 = conv(15, 5, 1, LReLU, apply_instnorm=False)(tf.concat([up_l1, conv_l0_k5], axis=3)) 78 | conv_l0_k7 = conv(15, 5, 1, LReLU, apply_instnorm=False)(conv_l0_k6) 79 | 80 | wiener_1x = tf.math.conj(otf_1x) / (tf.cast(tf.abs(otf_1x) ** 2, tf.complex64) + tf.cast(1 / tf.abs(snr), tf.complex64)) 81 | wiener_1x = tf.tile(wiener_1x, [1, 5, 1, 1]) 82 | wien_l0_b, _ = deconvolve_wnr(conv_l0_k7, snr, tf.tile(otf_1x, [1, 5, 1, 1]), tf.tile(ew_1x, [1, 1, 1, 5]), do_taper=(args.do_taper)) 83 | 84 | return wien_l0_b, wien_l1_b, wien_l2_b, ew_1x, ew_2x, ew_4x 85 | 86 | def FP(params, args): 87 | LReLU = tf.keras.layers.LeakyReLU 88 | ReLU = tf.keras.layers.ReLU 89 | 90 | h = params['network_width'] 91 | w = params['network_width'] 92 | inputs = tf.keras.layers.Input(shape=[h ,w ,3]) 93 | snr = tf.keras.layers.Input(shape=[]) 94 | otf_1x = tf.keras.layers.Input(shape=[3, h , w ], dtype=tf.complex64) 95 | ew_1x = tf.keras.layers.Input(shape=[h ,w ,3]) 96 | otf_2x = tf.keras.layers.Input(shape=[3,h//2,w//2], dtype=tf.complex64) 97 | ew_2x = tf.keras.layers.Input(shape=[h//2,w//2,3]) 98 | otf_4x = tf.keras.layers.Input(shape=[3,h//4,w//4], dtype=tf.complex64) 99 | ew_4x = tf.keras.layers.Input(shape=[h//4,w//4,3]) 100 | 101 | ## Feature Extractor 102 | deconv0, deconv1, deconv2, edge0, edge1, edge2 = \ 103 | feat_extract(inputs, tf.math.pow(10.0, snr), otf_1x, ew_1x, otf_2x, ew_2x, otf_4x, ew_4x, params, args) 104 | side = (h - params['out_width']) // 2 105 | deconv0 = deconv0[:,side:-side,side:-side,:] 106 | deconv1 = deconv1[:,side//2:-side//2,side//2:-side//2,:] 107 | deconv2 = 
--------------------------------------------------------------------------------
/networks/G/Wiener.py:
--------------------------------------------------------------------------------
# Wiener deconvolution

import tensorflow as tf
from metasurface.conv import deconvolve_wnr

def Wiener(params, args):
    print('Loading Wiener model')
    h = params['network_width']
    w = params['network_width']
    inputs = tf.keras.layers.Input(shape=[h, w, 3])
    snr = tf.keras.layers.Input(shape=[])
    otf_1x = tf.keras.layers.Input(shape=[3, h, w], dtype=tf.complex64)
    ew_1x = tf.keras.layers.Input(shape=[h, w, 3])
    otf_2x = tf.keras.layers.Input(shape=[3, h//2, w//2], dtype=tf.complex64)
    ew_2x = tf.keras.layers.Input(shape=[h//2, w//2, 3])
    otf_4x = tf.keras.layers.Input(shape=[3, h//4, w//4], dtype=tf.complex64)
    ew_4x = tf.keras.layers.Input(shape=[h//4, w//4, 3])

    snr10 = tf.math.pow(10.0, snr)
    outputs, blur = deconvolve_wnr(inputs, snr10, otf_1x, ew_1x, do_taper=args.do_taper)

    outputs = tf.clip_by_value(outputs, 0.0, 1.0)
    outputs = tf.image.resize_with_crop_or_pad(outputs, params['out_width'], params['out_width'])
    return tf.keras.Model(inputs=[inputs, snr, otf_1x, ew_1x, otf_2x, ew_2x, otf_4x, ew_4x], outputs=[outputs, blur, ew_1x])
--------------------------------------------------------------------------------
/networks/select.py:
--------------------------------------------------------------------------------
# Select deconvolution method

from networks.G.FP import FP
from networks.G.Wiener import Wiener

def select_G(params, args):
    if args.G_network == 'FP':
        return FP(params, args)
    elif args.G_network == 'Wiener':
        return Wiener(params, args)
    else:
        assert False, ("Unsupported generator network")
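For reference, a hypothetical sketch of how the factory above is driven. The argparse namespace is stubbed with only the fields the generators read, and the widths mirror the test notebook's configuration; a real run should obtain params from solver.initialize_params(args) and args from args.parse_args(), which is why the constructor call is left commented.

from types import SimpleNamespace
# from networks.select import select_G  # assumes the repository root is on PYTHONPATH

params = {'network_width': 1080, 'out_width': 720}     # values from test.ipynb
args = SimpleNamespace(G_network='FP', do_taper=True)  # stub with only the fields read above
# G = select_G(params, args)
# G.summary()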
False, ("Unsupported generator network") 13 | -------------------------------------------------------------------------------- /run_train.sh: -------------------------------------------------------------------------------- 1 | # Data loading 2 | TRAIN_DIR=./training/data/train/ 3 | TEST_DIR=./training/data/test/ 4 | 5 | # Saving and logging 6 | SAVE_DIR=./training/ckpt/ 7 | LOG_FREQ=20 8 | SAVE_FREQ=50 9 | CKPT_DIR=None 10 | MAX_TO_KEEP=2 11 | 12 | # Loss 13 | LOSS_MODE=L1 14 | BATCH_WEIGHTS=1.0,1.0,1.0 15 | NORM_LOSS_WEIGHT=1.0 16 | P_LOSS_WEIGHT=0.1 17 | VGG_LAYERS=block2_conv2,block3_conv2 18 | SPATIAL_LOSS_WEIGHT=0.0 19 | 20 | # Training 21 | AUG_ROTATE=True 22 | 23 | # Convolution 24 | REAL_PSF=./experimental/data/psf/psf.npy 25 | PSF_MODE=SIM_PSF 26 | CONV_MODE=SIM 27 | CONV=patch_size 28 | DO_TAPER=True 29 | OFFSET=True 30 | NORMALIZE_PSF=True 31 | THETA_BASE=0.0,5.0,10.0,15.0 32 | 33 | # Metasurface 34 | NUM_COEFFS=8 35 | USE_GENERAL_PHASE=False 36 | METASURFACE=log_asphere #zeros 37 | S1=0.9e-3 38 | S2=1.4e-3 39 | ALPHA=270.176968209 40 | TARGET_WAVELENGTH=511.0e-9 41 | BOUND_VAL=1000.0 42 | 43 | # Sensor 44 | A_POISSON=0.00004 45 | B_SQRT=0.00001 46 | MAG=8.1 # Set so that image size is 720 x 720 47 | 48 | # Optimization 49 | PHASE_LR=0.005 50 | PHASE_ITERS=0 51 | G_LR=0.0001 52 | G_ITERS=10 53 | G_NETWORK=FP 54 | SNR_OPT=False #True 55 | SNR_INIT=3.0 56 | 57 | python train.py --train_dir $TRAIN_DIR --test_dir $TEST_DIR --save_dir $SAVE_DIR --log_freq $LOG_FREQ --save_freq $SAVE_FREQ --ckpt_dir $CKPT_DIR --max_to_keep $MAX_TO_KEEP --loss_mode $LOSS_MODE --batch_weights $BATCH_WEIGHTS --Norm_loss_weight $NORM_LOSS_WEIGHT --P_loss_weight $P_LOSS_WEIGHT --vgg_layers $VGG_LAYERS --Spatial_loss_weight $SPATIAL_LOSS_WEIGHT --aug_rotate $AUG_ROTATE --real_psf $REAL_PSF --psf_mode $PSF_MODE --conv_mode $CONV_MODE --conv $CONV --do_taper $DO_TAPER --offset $OFFSET --normalize_psf $NORMALIZE_PSF --theta_base $THETA_BASE --num_coeffs $NUM_COEFFS --use_general_phase $USE_GENERAL_PHASE --metasurface $METASURFACE --s1 $S1 --s2 $S2 --alpha $ALPHA --target_wavelength $TARGET_WAVELENGTH --bound_val $BOUND_VAL --a_poisson $A_POISSON --b_sqrt $B_SQRT --mag $MAG --Phase_lr $PHASE_LR --Phase_iters $PHASE_ITERS --G_lr $G_LR --G_iters $G_ITERS --G_network $G_NETWORK --snr_opt $SNR_OPT --snr_init $SNR_INIT 58 | -------------------------------------------------------------------------------- /test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Inference Code for 'Neural Nano-Optics for High-quality Thin Lens Imaging'\n", 8 | "\n", 9 | "#### This notebook can be used to produce the experimental reconstructions shown in the manuscript and in the supplemental information." 
--------------------------------------------------------------------------------
/test.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference Code for 'Neural Nano-Optics for High-quality Thin Lens Imaging'\n",
    "\n",
    "#### This notebook can be used to produce the experimental reconstructions shown in the manuscript and in the supplemental information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import tensorflow_addons as tfa\n",
    "import numpy as np\n",
    "from networks.select import select_G\n",
    "from args import parse_args\n",
    "import metasurface.solver as solver\n",
    "import metasurface.conv as conv\n",
    "import matplotlib.pyplot as plt\n",
    "import sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up the arguments for real inference\n",
    "sys.argv=['','--train_dir','.',\\\n",
    "          '--test_dir' ,'.',\\\n",
    "          '--save_dir' ,'.',\\\n",
    "          '--ckpt_dir' ,'experimental/ckpt/',\\\n",
    "          '--real_psf' ,'./experimental/data/psf/psf.npy',\\\n",
    "          '--psf_mode' ,'REAL_PSF',\\\n",
    "          '--conv_mode','REAL',\\\n",
    "          '--conv'     ,'full_size']\n",
    "args = parse_args()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize and restore deconvolution method\n",
    "params = solver.initialize_params(args)\n",
    "params['conv_fn'] = conv.convolution_tf(params, args)\n",
    "params['deconv_fn'] = conv.deconvolution_tf(params, args)\n",
    "\n",
    "snr = tf.Variable(args.snr_init, dtype=tf.float32)\n",
    "G = select_G(params, args)\n",
    "checkpoint = tf.train.Checkpoint(G=G, snr=snr)\n",
    "\n",
    "status = checkpoint.restore(tf.train.latest_checkpoint(args.ckpt_dir, latest_filename=None))\n",
    "status.expect_partial()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Perform deconvolution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check that the dimensions agree with experimental captures\n",
    "assert(params['image_width'] == 720)\n",
    "assert(params['psf_width'] == 360)\n",
    "assert(params['network_width'] == 1080)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load in experimentally measured PSFs\n",
    "psf = (np.load('./experimental/data/psf/psf.npy'))\n",
    "psf = tf.constant(psf)\n",
    "psf = tf.image.resize_with_crop_or_pad(psf, params['psf_width'], params['psf_width'])\n",
    "psf = psf / tf.reduce_sum(psf, axis=(1,2), keepdims=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def reconstruct(img_name, psf, snr, G):\n",
    "    img = np.load(img_name)\n",
    "    _, G_img, _ = params['deconv_fn'](img, psf, snr, G, training=False)\n",
    "    G_img_ = G_img.numpy()[0,:,:,:]\n",
    "\n",
    "    # Vignette Correct\n",
    "    vig_factor = np.load('experimental/data/vignette_factor.npy')[0,:,:,:]\n",
    "    G_img_ = G_img_ * vig_factor\n",
    "\n",
    "    # Gain\n",
    "    G_img_ = G_img_ * 1.2\n",
    "    G_img_[G_img_ > 1.0] = 1.0\n",
    "\n",
    "    # Contrast Normalization\n",
    "    minval = np.percentile(G_img_, 5)\n",
    "    maxval = np.percentile(G_img_, 95)\n",
    "    G_img_ = np.clip(G_img_, minval, maxval)\n",
    "    G_img_ = (G_img_ - minval) / (maxval - minval)\n",
    "    G_img_[G_img_ > 1.0] = 1.0\n",
    "\n",
    "    plt.figure(figsize=(6,6))\n",
    "    plt.imshow(G_img_)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reconstruct Images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure 3\n",
    "reconstruct('./experimental/data/captures/138301.npy', psf, snr, G)\n",
    "reconstruct('./experimental/data/captures/102302.npy', psf, snr, G)\n",
    "reconstruct('./experimental/data/captures/110802.npy', psf, snr, G)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
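The notebook's reconstruct() applies vignette correction, a fixed 1.2x gain, and a 5/95-percentile contrast stretch. For use outside Jupyter, a standalone sketch of the same post-processing chain (vignette correction omitted, since it needs the capture-specific factor array; the gain and percentiles mirror the notebook's values):

import numpy as np

def postprocess(img, gain=1.2, lo_pct=5, hi_pct=95):
    # Fixed gain, clipped to [0, 1] (the notebook's 1.2x gain step)
    img = np.clip(img * gain, 0.0, 1.0)
    # Percentile contrast normalization (the notebook's 5/95 stretch)
    lo, hi = np.percentile(img, lo_pct), np.percentile(img, hi_pct)
    return np.clip((np.clip(img, lo, hi) - lo) / (hi - lo), 0.0, 1.0)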
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
from networks.select import select_G
from dataset import train_dataset_sim, test_dataset_sim
from loss import G_loss
from args import parse_args

import metasurface.solver as solver
import metasurface.conv as conv
import scipy.optimize as scp_opt

import os
import time

## Logging for TensorBoard
def log(img, gt_img, Phase_var, G, snr, vgg_model, summary_writer, step, params, args):
    # Metasurface simulation

    if args.psf_mode == 'SIM_PSF':
        solver.set_wavelengths(params, params['lambda_base'])
        psfs_debug, psfs_conv_forward = solver.get_psfs(Phase_var * args.bound_val, params, conv_mode=args.conv, aug_rotate=args.aug_rotate)
        psfs_conv_deconv = psfs_conv_forward
        if args.offset:
            # This allows for spatial sensitivity training
            psfs_conv_forward = psfs_conv_forward[1:,:,:,:]
            psfs_conv_deconv = psfs_conv_deconv[:-1,:,:,:]
            assert(psfs_conv_forward.shape[0] == psfs_conv_deconv.shape[0])
    elif args.psf_mode == 'REAL_PSF':
        real_psf = np.load(args.real_psf)
        real_psf = tf.constant(real_psf, dtype=tf.float32)
        real_psf = tf.image.resize_with_crop_or_pad(real_psf, params['psf_width'], params['psf_width'])
        real_psf = real_psf / tf.reduce_sum(real_psf, axis=(1,2), keepdims=True)
        psfs_debug = real_psf
        psfs_conv_forward = real_psf
        psfs_conv_deconv = real_psf
    else:
        assert False, ("Unsupported PSF mode")

    conv_image = params['conv_fn'](img, psfs_conv_forward)
    sensor_img = solver.sensor_noise(conv_image, params)
    _, G_img, G_debug = params['deconv_fn'](sensor_img, psfs_conv_deconv, snr, G, training=False)

    # Losses
    gt_img = tf.image.resize_with_crop_or_pad(gt_img, params['out_width'], params['out_width'])
    G_Content_loss_val, G_loss_components, G_metrics = G_loss(G_img, gt_img, vgg_model, args)

    # Save records to TensorBoard
    with summary_writer.as_default():
        # Images
        tf.summary.image(name = 'Input/Input', data=img, step=step)
        tf.summary.image(name = 'Input/GT', data=gt_img, step=step)

        if args.offset:
            num_patches = np.size(params['theta_base']) - 1
        else:
            num_patches = np.size(params['theta_base'])
        for i in range(num_patches):
            tf.summary.image(name = 'Output/Output_'+str(i), data=G_img[i:i+1,:,:,:], step=step)
            tf.summary.image(name = 'Blur/Blur_'+str(i), data=conv_image[i:i+1,:,:,:], step=step)
            tf.summary.image(name = 'Sensor/Sensor_'+str(i), data=sensor_img[i:i+1,:,:,:], step=step)
            for j, debug in enumerate(G_debug):
                tf.summary.image(name = 'Debug/Debug_'+str(j)+'_'+str(i), data=debug[i:i+1,:,:,:], step=step)

        # PSF
        for i in range(np.size(params['theta_base'])):
            psf_patch = psfs_debug[i:i+1,:,:,:]
            tf.summary.image(name='PSF/PSF_'+str(i),
                             data=psf_patch / tf.reduce_max(psf_patch), step=step)
            for l in range(np.size(params['lambda_base'])):
                psf_patch = psfs_debug[i:i+1,:,:,l:l+1]
                tf.summary.image(name='PSF_'+str(params['lambda_base'][l])+'/PSF_'+str(i),
                                 data=psf_patch / tf.reduce_max(psf_patch), step=step)
        for i in range(Phase_var.shape[0]):
            tf.summary.scalar(name = 'Phase/Phase_'+str(i), data=Phase_var[i], step=step)

        # Metrics
        tf.summary.scalar(name = 'metrics/G_PSNR', data = G_metrics['PSNR'], step=step)
        tf.summary.scalar(name = 'metrics/G_SSIM', data = G_metrics['SSIM'], step=step)
        tf.summary.scalar(name = 'snr', data = snr, step=step)

        # Content losses
        tf.summary.scalar(name = 'loss/G_Content_loss', data = G_Content_loss_val, step=step)
        tf.summary.scalar(name = 'loss/G_Norm_loss', data = G_loss_components['Norm'], step=step)
        tf.summary.scalar(name = 'loss/G_P_loss', data = G_loss_components['P'], step=step)
        tf.summary.scalar(name = 'loss/G_Spatial_loss', data = G_loss_components['Spatial'], step=step)
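# Illustration only (not called anywhere in the pipeline): the 'offset' PSF
# pairing used in log() above and in train_step() below. Each image is blurred
# with the PSF of one field angle but deconvolved with the PSF of the
# neighboring angle, which pushes the reconstruction to tolerate spatial PSF
# variation. The default angles mirror THETA_BASE in run_train.sh.
def _offset_pairing_demo(theta_base=(0.0, 5.0, 10.0, 15.0)):
    forward = theta_base[1:]           # angles whose PSFs blur the image
    deconv  = theta_base[:-1]          # angles whose PSFs deconvolve it
    return list(zip(forward, deconv))  # [(5.0, 0.0), (10.0, 5.0), (15.0, 10.0)]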
## Optimization Step
def train_step(mode, img, gt_img, Phase_var, Phase_optimizer, G, G_optimizer, snr, vgg_model, params, args):
    with tf.GradientTape() as G_tape:
        # Metasurface simulation

        if args.psf_mode == 'SIM_PSF':
            solver.set_wavelengths(params, params['lambda_base'])
            psfs_debug, psfs_conv_forward = solver.get_psfs(Phase_var * args.bound_val, params, conv_mode=args.conv, aug_rotate=args.aug_rotate)
            psfs_conv_deconv = psfs_conv_forward
            if args.offset:
                # This allows for spatial sensitivity training
                psfs_conv_forward = psfs_conv_forward[1:,:,:,:]
                psfs_conv_deconv = psfs_conv_deconv[:-1,:,:,:]
                assert(psfs_conv_forward.shape[0] == psfs_conv_deconv.shape[0])

        elif args.psf_mode == 'REAL_PSF':
            real_psf = np.load(args.real_psf)
            real_psf = tf.constant(real_psf, dtype=tf.float32)
            real_psf = tf.image.resize_with_crop_or_pad(real_psf, params['psf_width'], params['psf_width'])
            real_psf = real_psf / tf.reduce_sum(real_psf, axis=(1,2), keepdims=True)
            psfs_debug = real_psf
            psfs_conv_forward = real_psf
            psfs_conv_deconv = real_psf
        else:
            assert False, ("Unsupported PSF mode")

        conv_image = params['conv_fn'](img, psfs_conv_forward)
        sensor_img = solver.sensor_noise(conv_image, params)
        _, G_img, _ = params['deconv_fn'](sensor_img, psfs_conv_deconv, snr, G, training=True)

        # Losses
        gt_img = tf.image.resize_with_crop_or_pad(gt_img, params['out_width'], params['out_width'])
        G_loss_val, G_loss_components, G_metrics = G_loss(G_img, gt_img, vgg_model, args)

    # Apply gradients
    if mode == 'Phase':
        Phase_gradients = G_tape.gradient(G_loss_val, Phase_var)
        Phase_optimizer.apply_gradients([(Phase_gradients, Phase_var)])
        Phase_var.assign(tf.clip_by_value(Phase_var, -1.0, 1.0)) # Clipped to normalized phase range
    elif mode == 'G':
        G_vars = G.trainable_variables
        if args.snr_opt:
            G_vars.append(snr)
        G_gradients = G_tape.gradient(G_loss_val, G_vars)
        G_optimizer.apply_gradients(zip(G_gradients, G_vars))
        if args.snr_opt:
            snr.assign(tf.clip_by_value(snr, 3.0, 4.0))
    else:
        assert False, "Non-existent training mode"
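# Sketch of the alternating schedule that train() below implements: each outer
# step runs Phase_iters phase-only updates and then G_iters network updates.
# Illustration only; the pipeline never calls this helper.
def _alternating_schedule_demo(steps, phase_iters, g_iters, phase_update, g_update):
    for _ in range(steps):
        for _ in range(phase_iters):
            phase_update()  # corresponds to train_step('Phase', ...)
        for _ in range(g_iters):
            g_update()      # corresponds to train_step('G', ...)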
## Training loop
def train(args):
    ## Metasurface
    params = solver.initialize_params(args)

    if args.metasurface == 'random':
        phase_initial = np.random.uniform(low = -args.bound_val, high = args.bound_val, size = params['num_coeffs'])
    elif args.metasurface == 'zeros':
        phase_initial = np.zeros(params['num_coeffs'], dtype=np.float32)
    elif args.metasurface == 'single':
        phase_initial = np.array([-np.pi * (params['Lx'] * params['pixelsX'] / 2) ** 2 / params['wavelength_nominal'] / params['f'], 0.0, 0.0, 0.0, 0.0], dtype=np.float32)
    elif args.metasurface == 'neural':
        # Best parameters with neural optimization
        phase_initial = np.array([-0.3494864 , -0.00324192, -1.        , -1.        ,
                                  -1.        , -1.        , -1.        , -1.        ], dtype=np.float32)
        phase_initial = phase_initial * args.bound_val # <-- assumes bound_val == 1000
        assert(args.bound_val == 1000)
    else:
        if args.metasurface == 'log_asphere':
            phase_log = solver.log_asphere_phase(args.s1, args.s2, params)
        elif args.metasurface == 'shifted_axicon':
            phase_log = solver.shifted_axicon_phase(args.s1, args.s2, params)
        elif args.metasurface == 'squbic':
            phase_log = solver.squbic_phase(args.A, params)
        elif args.metasurface == 'hyperboidal':
            phase_log = solver.hyperboidal_phase(args.target_wavelength, params)
        elif args.metasurface == 'cubic':
            phase_log = solver.cubic_phase(args.alpha, args.target_wavelength, params) # Only for direct inference
        else:
            assert False, ("Unsupported metasurface mode")
        params['general_phase'] = phase_log # For direct phase inference
        if args.use_general_phase:
            assert(args.Phase_iters == 0)

        # For optimization
        lb = (params['pixelsX'] - params['pixels_aperture']) // 2
        ub = (params['pixelsX'] + params['pixels_aperture']) // 2
        x = params['x_mesh'][lb : ub, 0] / (0.5 * params['pixels_aperture'] * params['Lx'])
        phase_slice = phase_log[0, lb : ub, params['pixelsX'] // 2]
        p_fit, _ = scp_opt.curve_fit(params['phase_func'], x, phase_slice, bounds=(-args.bound_val, args.bound_val))
        phase_initial = p_fit
    print('Initial Phase: {}'.format(phase_initial), flush=True)
    print('Image width: {}'.format(params['image_width']), flush=True)

    # Normalize the phases within the bounds
    phase_initial = phase_initial / args.bound_val
    Phase_var = tf.Variable(phase_initial, dtype = tf.float32)
    Phase_optimizer = tf.keras.optimizers.Adam(args.Phase_lr, beta_1=args.Phase_beta1)

    # SNR term for the deconvolution algorithm
    snr = tf.Variable(args.snr_init, dtype=tf.float32)

    # Do not optimize the phase during finetuning
    if args.psf_mode == 'REAL_PSF':
        assert(args.Phase_iters == 0)

    # Convolution mode
    if args.offset:
        assert(len(args.batch_weights) == len(args.theta_base) - 1)
    else:
        assert(len(args.batch_weights) == len(args.theta_base))

    params['conv_fn'] = conv.convolution_tf(params, args)
    params['deconv_fn'] = conv.deconvolution_tf(params, args)

    ## Network architectures
    G = select_G(params, args)
    G_optimizer = tf.keras.optimizers.Adam(args.G_lr, beta_1=args.G_beta1)

    ## Construct VGG for the perceptual loss
    if args.P_loss_weight != 0:
        vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
        vgg_layers = [vgg.get_layer(name).output for name in args.vgg_layers.split(',')]
        vgg_model = tf.keras.Model(inputs=vgg.input, outputs=vgg_layers)
        vgg_model.trainable = False
    else:
        vgg_model = None

    ## Saving the model
    checkpoint = tf.train.Checkpoint(Phase_optimizer=Phase_optimizer, Phase_var=Phase_var, G_optimizer=G_optimizer, G=G, snr=snr)

    max_to_keep = args.max_to_keep
    if args.max_to_keep == 0:
        max_to_keep = None
    manager = tf.train.CheckpointManager(checkpoint, directory=args.save_dir, max_to_keep=max_to_keep)

    ## Load a pre-trained model if one exists
    if args.ckpt_dir is not None:
        status = checkpoint.restore(tf.train.latest_checkpoint(args.ckpt_dir, latest_filename=None))
        status.expect_partial() # Silence warnings
        #status.assert_existing_objects_matched() # Only a partial load for the networks (we don't load the optimizers)
        #status.assert_consumed()

    ## Create a summary writer for TensorBoard
    summary_writer = tf.summary.create_file_writer(args.save_dir)

    ## Dataset
    train_ds = iter(train_dataset_sim(params['out_width'], params['load_width'], args))
    test_ds = list(test_dataset_sim(params['out_width'], params['load_width'], args).take(1))

    ## Do training
    for step in range(args.steps):
        start = time.time()
        if step % args.save_freq == 0:
            print('Saving', flush=True)
            manager.save()
        if step % args.log_freq == 0:
            print('Logging', flush=True)
            test_batch = test_ds[0]
            img = test_batch[0]
            gt_img = test_batch[1]
            log(img, gt_img, Phase_var, G, snr, vgg_model, summary_writer, step, params, args)
        for _ in range(args.Phase_iters):
            img_batch = next(train_ds)
            img = img_batch[0]
            gt_img = img_batch[1]
            train_step('Phase', img, gt_img, Phase_var, Phase_optimizer, G, G_optimizer, snr, vgg_model, params, args)
        for _ in range(args.G_iters):
            img_batch = next(train_ds)
            img = img_batch[0]
            gt_img = img_batch[1]
            train_step('G', img, gt_img, Phase_var, Phase_optimizer, G, G_optimizer, snr, vgg_model, params, args)
        print("Step time: {}\n".format(time.time() - start), flush=True)
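# Illustration only (never called): how the curve_fit initialization inside
# train() recovers polynomial coefficients from a sampled phase slice and then
# normalizes them by bound_val. The even polynomial below is a hypothetical
# stand-in for params['phase_func'].
def _phase_fit_demo(bound_val=1000.0):
    x = np.linspace(-1.0, 1.0, 201)          # normalized aperture coordinate
    target = 300.0 * x ** 2 - 40.0 * x ** 4  # illustrative target phase slice

    def phase_func(x, a2, a4):
        return a2 * x ** 2 + a4 * x ** 4

    p_fit, _ = scp_opt.curve_fit(phase_func, x, target, bounds=(-bound_val, bound_val))
    return p_fit / bound_val                 # ~[0.3, -0.04] after normalization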
## Entry point
def main():
    args = parse_args()
    train(args)

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/training/ckpt/ckpt.txt:
--------------------------------------------------------------------------------
Models are saved here during training.
--------------------------------------------------------------------------------
/training/data/test/100600.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/training/data/test/100600.jpg
--------------------------------------------------------------------------------
/training/data/test/116601.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/training/data/test/116601.jpg
--------------------------------------------------------------------------------
/training/data/test/137901.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/training/data/test/137901.jpg
--------------------------------------------------------------------------------
/training/data/train/106000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/training/data/train/106000.jpg
--------------------------------------------------------------------------------
/training/data/train/106001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/training/data/train/106001.jpg
--------------------------------------------------------------------------------
/training/data/train/106002.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/training/data/train/106002.jpg
--------------------------------------------------------------------------------
/training/data/train/106100.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/training/data/train/106100.jpg
--------------------------------------------------------------------------------
/training/data/train/106101.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/training/data/train/106101.jpg
--------------------------------------------------------------------------------
/training/data/train/106200.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/training/data/train/106200.jpg
--------------------------------------------------------------------------------
/training/data/train/106201.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/training/data/train/106201.jpg
--------------------------------------------------------------------------------
/training/data/train/106202.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/training/data/train/106202.jpg
--------------------------------------------------------------------------------
/training/data/train/106300.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/training/data/train/106300.jpg
--------------------------------------------------------------------------------
/training/data/train/106301.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/training/data/train/106301.jpg
--------------------------------------------------------------------------------
/training/data/train/106700.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/training/data/train/106700.jpg
--------------------------------------------------------------------------------
/training/data/train/106701.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/training/data/train/106701.jpg
--------------------------------------------------------------------------------
/training/data/train/106702.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/training/data/train/106702.jpg
--------------------------------------------------------------------------------
/training/data/train/106703.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/training/data/train/106703.jpg
--------------------------------------------------------------------------------
/training/data/train/106704.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/training/data/train/106704.jpg
--------------------------------------------------------------------------------
/training/data/train/106800.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/training/data/train/106800.jpg
--------------------------------------------------------------------------------
/training/data/train/106801.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/training/data/train/106801.jpg
--------------------------------------------------------------------------------
/training/data/train/106900.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/training/data/train/106900.jpg
--------------------------------------------------------------------------------
/training/data/train/106901.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/training/data/train/106901.jpg
--------------------------------------------------------------------------------
/training/data/train/106902.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/training/data/train/106902.jpg
--------------------------------------------------------------------------------
/training/data/train/106903.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/Neural_Nano-Optics/417a8a30c63d6e64331f8c188d39b927c69690c7/training/data/train/106903.jpg
--------------------------------------------------------------------------------