├── README.md ├── assets └── unet.png ├── dataset ├── psf_0.fits ├── psf_1.fits ├── psf_2.fits ├── psf_3.fits └── psf_4.fits ├── examples ├── data.ipynb ├── evaluation.ipynb └── training.ipynb └── src ├── NVCC_monitoring.py ├── algorithms ├── Gerchberg–Saxton.py ├── Input-Output.py ├── animation.py └── utils.py ├── generation ├── generator.py ├── plots.py ├── psf.yaml └── radial.py ├── processing ├── plot3D.py └── zoom.py └── pytorch ├── criterion.py ├── dataset.py ├── lr_analyzer.py ├── models ├── Densenet.py ├── InceptionV3.py ├── Resnet.py ├── Unet.py ├── Unet_PP.py ├── VGG.py └── __pycache__ │ └── Unet.cpython-36.pyc ├── train.py ├── utils.py ├── utils_model.py └── utils_visdom.py /README.md: -------------------------------------------------------------------------------- 1 | # Machine learning for image-based wavefront sensing 2 | 3 | Astronomical images are often degraded by the disturbance of the Earth’s atmosphere. This thesis proposes to improve image-based wavefront sensing techniques using machine learning algorithms. Deep convolutional neural networks (CNN) have thus been trained to estimate the wavefront using one or multiple intensity measurements. 4 | 5 | 6 |

7 | 8 |

9 | 10 | ## Getting Started 11 | 12 | ### Prerequisites 13 | 14 | First, make sure the following python libraries are installed. 15 | 16 | ``` 17 | Aotools 18 | Astropy 19 | Soapy 20 | Scipy 21 | Pytorch 22 | Visdom 23 | ``` 24 | ### Examples 25 | 26 | The dataset generation can be run using. The dataset size and other parameters can be set in the same file. 27 | 28 | ``` 29 | python src/generation/generator.py 30 | ``` 31 | 32 | Some notebooks to highlights the networks and the dataset. 33 | 34 | - [Overview of the dataset](examples/data.ipynb) 35 | - [Network Training](examples/training.ipynb) 36 | - [Network evaluation](examples/evaluation.ipynb) 37 | 38 | Finally some classical algorithms (Gerchberg–Saxton) can be directly tested on the dataset. 39 | 40 | ``` 41 | python src/algorithms/Gerchberg–Saxton.py 42 | ``` 43 | -------------------------------------------------------------------------------- /assets/unet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/povanberg/Machine-learning-for-image-based-wavefront-sensing/a687f422d822a7c1db76375a0a4a67a60b03721d/assets/unet.png -------------------------------------------------------------------------------- /dataset/psf_0.fits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/povanberg/Machine-learning-for-image-based-wavefront-sensing/a687f422d822a7c1db76375a0a4a67a60b03721d/dataset/psf_0.fits -------------------------------------------------------------------------------- /dataset/psf_1.fits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/povanberg/Machine-learning-for-image-based-wavefront-sensing/a687f422d822a7c1db76375a0a4a67a60b03721d/dataset/psf_1.fits -------------------------------------------------------------------------------- /dataset/psf_2.fits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/povanberg/Machine-learning-for-image-based-wavefront-sensing/a687f422d822a7c1db76375a0a4a67a60b03721d/dataset/psf_2.fits -------------------------------------------------------------------------------- /dataset/psf_3.fits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/povanberg/Machine-learning-for-image-based-wavefront-sensing/a687f422d822a7c1db76375a0a4a67a60b03721d/dataset/psf_3.fits -------------------------------------------------------------------------------- /dataset/psf_4.fits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/povanberg/Machine-learning-for-image-based-wavefront-sensing/a687f422d822a7c1db76375a0a4a67a60b03721d/dataset/psf_4.fits -------------------------------------------------------------------------------- /examples/training.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# Example of network training using Pytorch (1.1) and Cuda (9).\n", 10 | "\n", 11 | "# NB: This code use real time monitoring based on Visdom\n", 12 | "# an open source webserver allowing real time monitoring\n", 13 | "#\n", 14 | "# https://github.com/facebookresearch/visdom\n", 15 | "#\n", 16 | "# Start the webserver using:\n", 17 | "# python -m visdom.server\n", 18 | "#\n", 19 | "# Access it on: (by default)\n", 20 | "# http://localhost:8097\n", 21 | "\n", 22 | "# Global import\n", 23 | "import sys\n", 24 | "import torch\n", 25 | "import torch.nn as nn\n", 26 | "import torch.nn.functional as F\n", 27 | "import torch.optim as optim\n", 28 | "from torchvision import transforms\n", 29 | "from collections import OrderedDict\n", 30 | "\n", 31 | "# Local import\n", 32 | "sys.path.insert(0, '../src/pytorch/models/')\n", 33 | "from Unet import UNet\n", 34 | "\n", 35 | "sys.path.insert(0, '../src/pytorch/')\n", 36 | "from dataset import *\n", 37 | "from train import *\n", 38 | "from lr_analyzer import *\n", 39 | "from criterion import *" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 2, 45 | "metadata": {}, 46 | "outputs": [ 47 | { 48 | "data": { 49 | "image/png": "\n", 50 | "text/plain": [ 51 | "
" 52 | ] 53 | }, 54 | "metadata": { 55 | "needs_background": "light" 56 | }, 57 | "output_type": "display_data" 58 | } 59 | ], 60 | "source": [ 61 | "# Load dataset\n", 62 | "\n", 63 | "data_dir = '../dataset/'\n", 64 | "dataset_size = 100000\n", 65 | "dataset = psf_dataset(root_dir = data_dir, \n", 66 | " size = dataset_size,\n", 67 | " transform = transforms.Compose([Noise(), Normalize(), ToTensor()]))\n", 68 | "\n", 69 | "# Check everything works as expected\n", 70 | "import matplotlib.pyplot as plt\n", 71 | "\n", 72 | "id = 0\n", 73 | "sample = dataset[id]\n", 74 | "phase = sample['phase']\n", 75 | "image_in = sample['image'][0]\n", 76 | "image_out = sample['image'][1]\n", 77 | "\n", 78 | "f, axarr = plt.subplots(1, 3, figsize=(15, 10))\n", 79 | "im1 = axarr[0].imshow(phase, cmap=plt.cm.jet)\n", 80 | "im1.set_clim(-np.pi, np.pi)\n", 81 | "axarr[0].set_title(\"Phase\")\n", 82 | "plt.colorbar(im1, ax = axarr[0], fraction=0.046)\n", 83 | "im2 = axarr[1].imshow(image_in, cmap=plt.cm.jet)\n", 84 | "axarr[1].set_title(\"In\")\n", 85 | "plt.colorbar(im2, ax = axarr[1], fraction=0.046)\n", 86 | "im3 = axarr[2].imshow(image_out, cmap=plt.cm.jet)\n", 87 | "axarr[2].set_title(\"Out\")\n", 88 | "plt.colorbar(im3, ax = axarr[2], fraction=0.046)\n", 89 | "plt.show()" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 3, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "# Load model architecture, in this example: Unet \n", 99 | "\n", 100 | "model = UNet(2, 1)\n", 101 | "criterion = RMSELoss()\n", 102 | "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n", 103 | "\n", 104 | "# Move Network to GPU\n", 105 | "\n", 106 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 107 | "if torch.cuda.device_count() > 1:\n", 108 | " model = nn.DataParallel(model)\n", 109 | " model.cuda()\n", 110 | "\n", 111 | "# Eventually load existing weights\n", 112 | "\n", 113 | "#model_dir = 'ADAM_it2/model.pth'\n", 114 | "#state_dict = torch.load(model_dir)\n", 115 | "#new_state_dict = OrderedDict()\n", 116 | "#for k, v in state_dict.items():\n", 117 | "# name = k[7:] # remove module.\n", 118 | "# new_state_dict[name] = v\n", 119 | "#model.load_state_dict(state_dict)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": { 126 | "scrolled": true 127 | }, 128 | "outputs": [], 129 | "source": [ 130 | "# Launch training script. The network weights are automatically saved \n", 131 | "# at the end of an epoch (if the test error is reduced). The metrics are also\n", 132 | "# saved at the end of each epoch in JSON format. All outputs are also stored in a \n", 133 | "# log file.\n", 134 | "#\n", 135 | "# - model = network to train\n", 136 | "# - dataset = dataset object\n", 137 | "# - optimizer = gradient descent optimizer (Adam, SGD, RMSProp)\n", 138 | "# - criterion = loss function\n", 139 | "# - split[x, 1-x] = Division train/test. 'x' is the proportion of the test set.\n", 140 | "# - batch_size = batch size\n", 141 | "# - n_epochs = number of epochs\n", 142 | "# - model_dir = where to save the results\n", 143 | "# - visdom = enable real time monitoring\n", 144 | "\n", 145 | "train(model, \n", 146 | " dataset, \n", 147 | " optimizer, \n", 148 | " criterion,\n", 149 | " split = [0.50, 0.50],\n", 150 | " batch_size = 64,\n", 151 | " n_epochs = 500,\n", 152 | " model_dir = './',\n", 153 | " visdom = True)" 154 | ] 155 | } 156 | ], 157 | "metadata": { 158 | "kernelspec": { 159 | "display_name": "Python (myenv)", 160 | "language": "python", 161 | "name": "myenv" 162 | }, 163 | "language_info": { 164 | "codemirror_mode": { 165 | "name": "ipython", 166 | "version": 3 167 | }, 168 | "file_extension": ".py", 169 | "mimetype": "text/x-python", 170 | "name": "python", 171 | "nbconvert_exporter": "python", 172 | "pygments_lexer": "ipython3", 173 | "version": "3.6.7" 174 | } 175 | }, 176 | "nbformat": 4, 177 | "nbformat_minor": 2 178 | } 179 | -------------------------------------------------------------------------------- /src/NVCC_monitoring.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import json 4 | from threading import Timer 5 | from pynvml import * 6 | import matplotlib.pyplot as plt 7 | 8 | # Small codes dedicated to the monitoring of nvidia GPUs. 9 | 10 | def getGPUMetrics(): 11 | # Metrics accuracy within 1%, see docs: 12 | # docs.nvidia.com/deploy/nvml-api 13 | try: 14 | nvmlInit() 15 | except err: 16 | print("Failed to initialize NVML: ", err) 17 | os.exit(1) 18 | 19 | deviceCount = nvmlDeviceGetCount() 20 | GPUs = [nvmlDeviceGetHandleByIndex(i) for i in range(deviceCount)] 21 | 22 | temperatures = [] 23 | fanSpeed = [] 24 | power = [] 25 | memory = [] 26 | 27 | for i in range(deviceCount): 28 | temperatures.append(nvmlDeviceGetTemperature(GPUs[i], NVML_TEMPERATURE_GPU)) 29 | memory.append(nvmlDeviceGetMemoryInfo(GPUs[i]).used) 30 | fanSpeed.append(nvmlDeviceGetFanSpeed(GPUs[i])) 31 | power.append(nvmlDeviceGetPowerUsage(GPUs[i]) / 1000) # Miliwatt to watt 32 | 33 | try: 34 | nvmlShutdown() 35 | except err: 36 | print("Error shutting down NVML:", err) 37 | os.exit(1) 38 | 39 | metrics = { 'gpu%i'%i: { 40 | 'temperatures': temperatures[i], 41 | 'fanSpeed': fanSpeed[i], 42 | 'memory': memory[i], 43 | 'power': power[i], 44 | 'time': time.strftime("%H:%M:%S") 45 | } for i in range(deviceCount) 46 | } 47 | 48 | return metrics 49 | 50 | 51 | def saveGPUMetrics(metrics, saving_dir='', name='monitoring_metrics.json', deviceCount=2): 52 | # Save to json GPU metrics 53 | 54 | json_path = os.path.join(saving_dir, name) 55 | if os.path.exists(json_path): 56 | # Load existing metrics and add news 57 | with open(json_path, 'r') as f: 58 | old_metrics = json.load(f) 59 | for i in range(deviceCount): 60 | for key in old_metrics['gpu%i'%i]: 61 | old_metrics['gpu%i'%i][key].append(metrics['gpu%i'%i][key]) 62 | with open(json_path, 'w') as f: 63 | json.dump(old_metrics, f, indent=4) 64 | else: 65 | # If does not exist 66 | with open(json_path, 'w') as f: 67 | for i in range(deviceCount): 68 | for key in metrics['gpu%i'%i]: 69 | metrics['gpu%i'%i][key] = [metrics['gpu%i'%i][key]] 70 | json.dump(metrics, f, indent=4) 71 | 72 | def plotMetrics(json_path, key_name, limit=93): 73 | 74 | with open(json_path, 'r') as f: 75 | metrics = json.load(f) 76 | 77 | plt.plot(metrics['gpu0'][key_name], label='GPU 0') 78 | plt.plot(metrics['gpu1'][key_name], label='GPU 1') 79 | plt.hlines(limit,0, len(metrics['gpu0'][key_name]), color='red', linestyle='--') 80 | plt.grid() 81 | plt.legend() 82 | plt.title(key_name) 83 | plt.show() 84 | 85 | # How to use: 86 | # monitor = monitoring.monitoringGPU(30) # autostart, time in seconds 87 | # - Do fancy stuffs 88 | # monitor.stop() 89 | class monitoringGPU(object): 90 | def __init__(self, interval, *args, **kwargs): 91 | self._timer = None 92 | self.interval = interval 93 | self.args = args 94 | self.kwargs = kwargs 95 | self.is_running = False 96 | self.start() 97 | 98 | def _run(self): 99 | self.is_running = False 100 | self.start() 101 | metrics = getGPUMetrics() 102 | saveGPUMetrics(metrics, *self.args, **self.kwargs) 103 | 104 | def start(self): 105 | if not self.is_running: 106 | self._timer = Timer(self.interval, self._run) 107 | self._timer.start() 108 | self.is_running = True 109 | 110 | def stop(self): 111 | self._timer.cancel() 112 | self.is_running = False 113 | 114 | -------------------------------------------------------------------------------- /src/algorithms/Gerchberg–Saxton.py: -------------------------------------------------------------------------------- 1 | import aotools 2 | from astropy.io import fits 3 | import numpy as np 4 | import utils 5 | from animation import * 6 | from time import time 7 | 8 | def GerchbergSaxton(target, source, phase, n_max=200, animation=True): 9 | ''' 10 | Phase retrieval, Gerchberg-Saxton algorithm. 11 | 12 | [1] R. W. Gerchberg and W. O. Saxton, “A practical algorithm 13 | for the determination of the phase from image and diffraction 14 | plane pictures,” Optik 35, 237 (1972) 15 | 16 | [2] J. R. Fienup, "Phase retrieval algorithms: a comparison," 17 | Appl. Opt. 21, 2758-2769 (1982) 18 | 19 | :param target: 20 | :param source: 21 | :param phase: Algorithm goal, provided for visualization and metrics 22 | :param n_max: Maximum number of iteration 23 | :param animation: 24 | :return: 25 | ''' 26 | 27 | # Add padding 28 | target = utils.addPadding(np.sqrt(target)) 29 | source = utils.addPadding(source) 30 | 31 | # Metrics: tuple -> (time, error) 32 | metrics = [] 33 | 34 | # Initialize animation 35 | if animation: 36 | f, axarr = initAnimation() 37 | 38 | # Timer 39 | timer = 0.0 40 | 41 | # Random initializer 42 | A = source * np.exp(1j * 0.0 * np.pi * (np.random.rand(source.shape[0], source.shape[1])*2-1)) 43 | 44 | for n in range(n_max): 45 | t0 = time() 46 | B = np.absolute(source) * np.exp(1j * np.angle(A)) 47 | C = utils.fft(B) 48 | D = np.absolute(target) * np.exp(1j * np.angle(C)) 49 | A = utils.ifft(D) 50 | 51 | t1 = time() 52 | timer += t1-t0 53 | 54 | phaseEst = source * np.angle(A) 55 | #phaseEst = np.rot90(np.rot90(-phaseEst)) 56 | error = utils.rootMeanSquaredError(phase, utils.removePadding(phaseEst), mask=True) 57 | #error = utils.rootMeanSquaredError(C, D, mask=True) 58 | 59 | metrics.append((timer, error)) 60 | 61 | if animation: 62 | H = utils.addPadding(mask) * np.exp(1j * (phaseEst-utils.addPadding(phase))) 63 | h = utils.fft(H) 64 | psf = utils.removePadding(np.abs(h) ** 2) 65 | updateAnimation(f, axarr, metrics, phase, utils.removePadding(phaseEst), psf, timer) 66 | 67 | return metrics 68 | 69 | if __name__ == '__main__': 70 | 71 | # Files 72 | #reference_file = 'references.fits' 73 | psf_file = '../../dataset/psf_1.fits' 74 | 75 | # Data 76 | wavelength = 2200 * (10**(-9)) #[m] 77 | n=20 78 | z_basis = aotools.zernikeArray(n+1, 128, norm='rms') #[rad] 79 | mask = aotools.circle(64, 128) 80 | 81 | #rv_HDU = fits.open(reference_file) 82 | #mask = rv_HDU[0].data # [0-1] function defining entrance pupil 83 | #psf_reference = rv_HDU[1].data # diffraction limited point spread function 84 | 85 | HDU = fits.open(psf_file) 86 | phase = utils.meterToRadian(HDU[1].data, wavelength* (10**(9))) 87 | 88 | H = utils.addPadding(mask) * np.exp(1j * utils.addPadding(phase)) 89 | h = utils.fft(H) 90 | psf_test = utils.removePadding(np.abs(h)**2) 91 | 92 | metrics = GerchbergSaxton(psf_test, mask, phase, n_max=200, animation=True) 93 | 94 | -------------------------------------------------------------------------------- /src/algorithms/Input-Output.py: -------------------------------------------------------------------------------- 1 | import aotools 2 | from astropy.io import fits 3 | import numpy as np 4 | import utils 5 | from animation import * 6 | from time import time 7 | 8 | def HybridInputOutput(target, source, phase, n_max=200, animation=True): 9 | ''' 10 | 11 | [1] E. Osherovich, Numerical methods for phase retrieval, 2012, 12 | https://arxiv.org/abs/1203.4756 13 | [2] J. R. Fienup, Phase retrieval algorithms: a comparison, 1982, 14 | https://www.osapublishing.org/ao/abstract.cfm?uri=ao-21-15-2758 15 | 16 | :param target: 17 | :param source: 18 | :param phase: Algorithm goal, provided for visualization and metrics 19 | :param n_max: Maximum number of iteration 20 | :param animation: 21 | :return: 22 | ''' 23 | 24 | # Add padding 25 | target = utils.addPadding(np.sqrt(target)) 26 | source = utils.addPadding(source) 27 | 28 | # Metrics: tuple -> (time, rmse) 29 | metrics = [] 30 | 31 | # Initialize animation 32 | if animation: 33 | f, axarr = initAnimation() 34 | 35 | # Timer 36 | timer = 0.0 37 | 38 | # Random initializer 39 | g_k_prime = np.exp(1j * 0.0 * np.pi * (np.random.rand(source.shape[0], source.shape[1])*2-1)) 40 | 41 | 42 | # Previous iteration 43 | g_k_previous = None 44 | 45 | for n in range(n_max): 46 | t0 = time() 47 | 48 | g_k = source * np.exp(1j * np.angle(g_k_prime)) 49 | G_k= utils.fft(g_k) 50 | G_k_prime = np.absolute(target) * np.exp(1j * np.angle(G_k)) 51 | g_k_prime = utils.ifft(G_k_prime) 52 | 53 | 54 | if g_k_previous is None: 55 | g_k_previous = g_k_prime 56 | else: 57 | g_k_previous = g_k 58 | 59 | indices = np.logical_or(np.logical_and(g_k < 0, source), np.logical_not(source)) 60 | 61 | g_k[indices] = g_k_previous[indices] - 0.9 * np.real(g_k_prime[indices]) 62 | 63 | t1 = time() 64 | timer += t1-t0 65 | 66 | phaseEst = source * np.angle(g_k) 67 | #phaseEst = np.rot90(np.rot90(-phaseEst)) 68 | error = utils.rootMeanSquaredError(phase, utils.removePadding(phaseEst), mask=True) 69 | #error = utils.rootMeanSquaredError(G_k, G_k_prime, mask=True) 70 | 71 | metrics.append((timer, error)) 72 | 73 | if animation: 74 | H = utils.addPadding(mask) * np.exp(1j * (phaseEst-utils.addPadding(phase))) 75 | h = utils.fft(H) 76 | psf = utils.removePadding(np.abs(h) ** 2) 77 | updateAnimation(f, axarr, metrics, phase, utils.removePadding(phaseEst), psf, timer) 78 | 79 | return metrics 80 | 81 | if __name__ == '__main__': 82 | 83 | # Files 84 | reference_file = 'references.fits' 85 | psf_file = 'psf_1.fits' 86 | 87 | # Data 88 | wavelength = 2200 * (10**(-9)) #[m] 89 | n=20 90 | z_basis = aotools.zernikeArray(n+1, 128, norm='rms') #[rad] 91 | 92 | rv_HDU = fits.open(reference_file) 93 | mask = rv_HDU[0].data # [0-1] function defining entrance pupil 94 | psf_reference = rv_HDU[1].data # diffraction limited point spread function 95 | 96 | HDU = fits.open(psf_file) 97 | phase = utils.meterToRadian(HDU[1].data, wavelength* (10**(9))) 98 | 99 | H = utils.addPadding(mask) * np.exp(1j * utils.addPadding(phase)) 100 | h = utils.fft(H) 101 | psf_test = utils.removePadding(np.abs(h)**2) 102 | 103 | metrics = HybridInputOutput(psf_test, mask, phase, n_max=300, animation=True) 104 | metrics = np.array(metrics) 105 | -------------------------------------------------------------------------------- /src/algorithms/animation.py: -------------------------------------------------------------------------------- 1 | import aotools 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import matplotlib as mpl 5 | import utils 6 | 7 | 8 | def initAnimation(): 9 | mpl.style.use('default') 10 | f, axarr = plt.subplots(2, 2, figsize=(10, 10)) 11 | return f, axarr 12 | 13 | def updateAnimation(f, axarr, error, phase, phaseEst, psf, timer): 14 | 15 | f.suptitle('Algorithm time: {0:.5}s'.format(timer)) 16 | cmap = plt.cm.jet 17 | error = np.array(error) 18 | im1 = axarr[0, 0].plot(error[:, 1], linewidth=2.5) 19 | axarr[0, 0].grid(color='lightgrey', linestyle='--') 20 | axarr[0, 0].set_title("Wavefront error") 21 | axarr[0, 0].set_xlabel('iterations') 22 | axarr[0, 0].set_ylabel('RMSE') 23 | im2 = axarr[0, 1].imshow(psf**(1/3), cmap=cmap) 24 | cb2 = plt.colorbar(im2, ax=axarr[0, 1], fraction=0.046) 25 | axarr[0, 1].set_title("Point Spread function (strehl={0:.5f})".format(utils.strehl(phase-phaseEst))) 26 | axarr[0, 1].set_axis_off() 27 | mask=aotools.circle(64, 128).astype(np.float64) 28 | phase[mask<0.1]=None 29 | phaseEst[mask<0.1]=None 30 | im3 = axarr[1, 0].imshow(phase, cmap=cmap) 31 | im3.set_clim(-np.pi,np.pi) 32 | cb3 = plt.colorbar(im3, ax=axarr[1, 0], fraction=0.046) 33 | axarr[1, 0].set_title("Exact Phase") 34 | axarr[1, 0].set_axis_off() 35 | im4 = axarr[1, 1].imshow(phaseEst, cmap=cmap) 36 | im4.set_clim(-np.pi, np.pi) 37 | axarr[1, 1].set_title("Recovered phase") 38 | axarr[1, 1].set_axis_off() 39 | cb4 = plt.colorbar(im4, ax=axarr[1, 1], fraction=0.046) 40 | plt.pause(1e-5) 41 | axarr[0, 0].cla() 42 | cb2.remove() 43 | cb3.remove() 44 | cb4.remove() 45 | phase[mask<0.1]=0 46 | phaseEst[mask<0.1]=0 47 | -------------------------------------------------------------------------------- /src/algorithms/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.fft as FFT 3 | import aotools 4 | 5 | def meterToRadian(array, wavelength): 6 | ''' 7 | Convert array from meter to radian 8 | 9 | :param array: [nm] 10 | :param wavelength: [nm] 11 | :return: [rad] 12 | ''' 13 | array_rad = (array / wavelength) * (2*np.pi) 14 | return array_rad 15 | 16 | def getPhase(z_coeffs, z_basis): 17 | ''' 18 | Compute phase from Zernike basis and Zernike coeffs 19 | 20 | :param z_coeffs: [rad] 21 | :param z_basis: [rad] 22 | :return: [rad] 23 | ''' 24 | phase = z_coeffs[:, None, None] * z_basis[:, :, :] 25 | phase = np.sum(phase, axis=0) 26 | phase = np.squeeze(phase) 27 | return phase 28 | 29 | def fft(array): 30 | ''' 31 | Compute discrete fast fourier transform 32 | 33 | :param array: 34 | :return: 35 | ''' 36 | fft_array = FFT.fftshift(FFT.fft2(FFT.fftshift(array))) 37 | return fft_array 38 | 39 | def ifft(array): 40 | ''' 41 | Compute discrete inverse fast fourier transform 42 | 43 | :param array: 44 | :return: 45 | ''' 46 | fft_array = FFT.ifftshift(FFT.ifft2(FFT.ifftshift(array))) 47 | return fft_array 48 | 49 | def pad_with(vector, pad_width, iaxis, kwargs): 50 | ''' 51 | Padding utils 52 | 53 | :param vector: 54 | :param pad_width: 55 | :param iaxis: 56 | :param kwargs: 57 | :return: 58 | ''' 59 | pad_value = kwargs.get('padder', 10) 60 | vector[:pad_width[0]] = pad_value 61 | vector[-pad_width[1]:] = pad_value 62 | return vector 63 | 64 | def addPadding(array, padding=2): 65 | ''' 66 | Add padding to array 67 | 68 | :param array: 69 | :param padding: 70 | :return: 71 | ''' 72 | size = array.shape[1] 73 | padded_array = np.pad(array, padding*size, pad_with, padder = 0) 74 | return padded_array 75 | 76 | def removePadding(array, padding=2): 77 | ''' 78 | Remove padding from array 79 | 80 | :param array: 81 | :param padding: 82 | :return: 83 | ''' 84 | size = array.shape[1] // (2*padding + 1) 85 | rmPixel = padding*size 86 | return array[rmPixel:size+rmPixel,rmPixel:size+rmPixel] 87 | 88 | 89 | def rootMeanSquaredError(array1, array2, mask=True): 90 | ''' 91 | RMSE error between array1 and array2 92 | if mask=True computed over circle 93 | 94 | :param array: 95 | :return: 96 | ''' 97 | if mask is True: 98 | size = array1.shape[1] 99 | center = size//2 100 | radius = size//2 101 | 102 | n = 0 103 | error = 0.0 104 | for x in range(size): 105 | for y in range(size): 106 | if (x-center)**2 + (y-center)**2 <= radius: 107 | n += 1 108 | error += (array1[x, y]-array2[x, y])**2 109 | rms_error = np.sqrt((1/n)*(error)) 110 | else: 111 | rms_error = np.sqrt(((array1 - array2) ** 2).mean()) 112 | return rms_error 113 | 114 | def strehl(phase): 115 | mask = aotools.circle(64, 128) 116 | N= 0.0 117 | phase_mean = 0.0 118 | for i in range(128): 119 | for j in range(128): 120 | if(mask[i,j]>= 0.001): 121 | N += 1 122 | phase_mean += phase[i, j] 123 | phase_mean = phase_mean / N 124 | strehl = np.abs(np.mean(np.exp(1j*(phase-phase_mean))))**2 125 | return strehl -------------------------------------------------------------------------------- /src/generation/generator.py: -------------------------------------------------------------------------------- 1 | import time 2 | import aotools 3 | from radial import radial_data 4 | import numpy as np 5 | from scipy import fftpack 6 | from astropy.io import fits 7 | from soapy import SCI, confParse 8 | from matplotlib import pyplot as plt 9 | 10 | # ------------------------------------------------------------------------ 11 | # Generate Point Spread functions from randomly drawn non-common path 12 | # aberrations. The aberration follows a 1/f^2 law. 13 | # One PSF in focus as well as a PSF out of focus are saved in FITS format 14 | # (see astropy docs). The corresponding phase and Zernike Coefficient 15 | # are also saved. 16 | # ------------------------------------------------------------------------ 17 | 18 | np.random.seed(seed=0) 19 | 20 | SOAPY_CONF = "psf.yaml" # Soapy config 21 | gridsize = 128 # Pixel size of science camera 22 | wavelength = 2.2e-6 # Observational wavelength 23 | diameter = 10 # Telescope diameter 24 | pixelScale = 0.01 # [''/px]s 25 | 26 | n_psfs = 5 # Number of PSFs 27 | n_zernike = 100 # Number of Zernike polynomials 28 | i_zernike = np.arange(2, n_zernike + 2) # Zernike polynomial indices (piston excluded) 29 | o_zernike= [] # Zernike polynomial radial Order, see J. Noll paper : 30 | for i in range(1,n_zernike): # "Zernike polynomials and atmospheric turbulence", 1975 31 | for j in range(i+1): 32 | if len(o_zernike) < n_zernike: 33 | o_zernike.append(i) 34 | 35 | # Generate randomly Zernike coefficient. By dividing the value 36 | # by its radial order we produce a distribution following 37 | # the expected 1/f^-2 law. 38 | c_zernike = 2 * np.random.random((n_psfs, n_zernike)) - 1 39 | for j in range(n_psfs): 40 | for i in range(n_zernike): 41 | c_zernike[j, i] = c_zernike[j, i] / o_zernike[i] 42 | c_zernike = np.array([c_zernike[k, :] / np.abs(c_zernike[k, :]).sum() 43 | * wavelength*(10**9) for k in range(n_psfs)]) 44 | 45 | # Update scientific camera parameters 46 | config = confParse.loadSoapyConfig(SOAPY_CONF) 47 | config.scis[0].pxlScale = pixelScale 48 | config.tel.telDiam = diameter 49 | config.calcParams() 50 | 51 | mask = aotools.circle(config.sim.pupilSize / 2., config.sim.simSize).astype(np.float64) 52 | zernike_basis = aotools.zernikeArray(n_zernike + 1, config.sim.pupilSize, norm='rms') 53 | 54 | psfObj = SCI.PSF(config, nSci=0, mask=mask) 55 | 56 | psfs_in = np.zeros((n_psfs, psfObj.detector.shape[0], psfObj.detector.shape[1])) 57 | psfs_out = np.zeros((n_psfs, psfObj.detector.shape[0], psfObj.detector.shape[1])) 58 | 59 | defocus = (wavelength / 4) * (10 ** 9) * zernike_basis[3, :, :] 60 | 61 | t0 = time.time() 62 | n_fail = 0 63 | 64 | for i in range(n_psfs): 65 | 66 | aberrations_in = np.squeeze(np.sum(c_zernike[i, :, None, None] * zernike_basis[1:, :, :], axis=0)) 67 | psfs_in[i, :, :] = np.copy(psfObj.frame(aberrations_in.astype(np.float64))) 68 | 69 | aberations_out = np.squeeze(aberrations_in) + defocus 70 | psfs_out[i, :, :] = np.copy(psfObj.frame(aberations_out.astype(np.float64))) 71 | 72 | # psfs_in[i, :, :] = np.random.poisson(lam=100000*psfs_in[i, :, :], size=None) 73 | # psfs_out[i, :, :] = np.random.poisson(lam=100000*psfs_out[i, :, :], size=None) 74 | 75 | # Save 76 | outfile = "psf_" + str(i) + ".fits" 77 | hdu_primary = fits.PrimaryHDU(c_zernike[i, :].astype(np.float32)) 78 | hdu_phase = fits.ImageHDU(aberrations_in.astype(np.float32), name='PHASE') 79 | hdu_In = fits.ImageHDU(psfs_in[i, :, :].astype(np.float32), name='INFOCUS') 80 | hdu_Out = fits.ImageHDU(psfs_out[i, :, :].astype(np.float32), name='OUTFOCUS') 81 | hdu = fits.HDUList([hdu_primary, hdu_phase, hdu_In, hdu_Out]) 82 | hdu.writeto(outfile, overwrite=True) 83 | 84 | t_soapy = time.time() - t0 85 | print('Propagation and saving finished in {0:2f}s'.format(t_soapy)) 86 | print('Failed: {0:2f}'.format(n_fail)) 87 | -------------------------------------------------------------------------------- /src/generation/plots.py: -------------------------------------------------------------------------------- 1 | import time 2 | import aotools 3 | from radial import radial_data 4 | import numpy as np 5 | from scipy import fftpack 6 | from astropy.io import fits 7 | from soapy import SCI, confParse 8 | from matplotlib import pyplot as plt 9 | 10 | id = 0 11 | phase = np.squeeze(np.sum(c_zernike[id, :, None, None] * zernike_basis[1:, :, :], axis=0)) 12 | F1 = fftpack.fft2(phase) 13 | F2 = fftpack.fftshift( F1 ) 14 | psd2D = np.abs( F2 )**2 15 | 16 | plt.imshow(np.sqrt(psfs_in[id,:,:]), cmap=plt.cm.jet) 17 | plt.axis('off') 18 | plt.savefig('psf_in.pdf') 19 | plt.imshow(np.sqrt(psfs_out[id,:,:]), cmap=plt.cm.jet) 20 | plt.axis('off') 21 | plt.savefig('psf_out.pdf') 22 | plt.imshow(phase, cmap=plt.cm.jet) 23 | plt.axis('off') 24 | plt.savefig('phase_in.pdf') 25 | 26 | fig, ax = plt.subplots(figsize=(15, 5)) 27 | width = 0.4 28 | plt.bar(i_zernike[:100], np.abs(c_zernike[id]/2200*2*np.pi)[:100], color='#32526e', width=width, zorder=3) 29 | #plt.title('Zernike coefficient distribution', fontsize=19) 30 | plt.xlabel('zernike coefficients', fontsize=16) 31 | plt.ylabel('magnitude [rad]', fontsize=16) 32 | plt.xticks(fontsize=16) 33 | plt.yticks(fontsize=16) 34 | ax.spines['right'].set_visible(False) 35 | ax.spines['top'].set_visible(False) 36 | plt.grid(zorder=0, color='lightgray', linestyle='--') 37 | plt.ylim(0,0.4) 38 | plt.savefig('z_distrib.pdf') 39 | plt.show() 40 | 41 | rad_obj = radial_data(psd2D, rmax=64) 42 | fig, ax = plt.subplots() 43 | plt.xlabel('Spatial frequency (cycles/pupil)', fontsize=13) 44 | plt.ylabel('PSF (nm²nm²)', fontsize=13) 45 | plt.loglog(rad_obj.r[1:], psd2D[65:128, 64]) 46 | ax.spines['right'].set_visible(False) 47 | ax.spines['top'].set_visible(False) 48 | plt.grid(zorder=0, color='lightgray', linestyle='--') 49 | start, end = ax.get_xlim() 50 | plt.xticks(np.logspace(np.log10(start), np.log10(end), num=9, base=10),('10⁰','','','','10¹','','','','','')) 51 | plt.savefig('PSD_rad.pdf') 52 | plt.show() 53 | 54 | fig, ax = plt.subplots() 55 | #plt.title('1D PSD avg', fontsize=15) 56 | plt.xlabel('Spatial frequency (cycles/pupil)', fontsize=13) 57 | plt.ylabel('PSF (nm²nm²)', fontsize=13) 58 | plt.loglog(rad_obj.r[1:],rad_obj.mean[1:]) 59 | ax.spines['right'].set_visible(False) 60 | ax.spines['top'].set_visible(False) 61 | plt.grid(zorder=0, color='lightgray', linestyle='--') 62 | start, end = ax.get_xlim() 63 | plt.xticks(np.logspace(np.log10(start), np.log10(end), num=9, base=10),('10⁰','','','','10¹','','','','','')) 64 | plt.savefig('PSD_rad_avg.pdf') 65 | plt.show() 66 | -------------------------------------------------------------------------------- /src/generation/psf.yaml: -------------------------------------------------------------------------------- 1 | simName: 2 | pupilSize: 128 3 | 4 | nSci: 1 5 | nIters: 5000 6 | loopTime: 0.0025 7 | threads: 4 8 | 9 | verbosity: 2 10 | 11 | 12 | Atmosphere: 13 | scrnNo: 1 14 | scrnHeights: [0] 15 | scrnStrengths: [1] 16 | windDirs: [0] 17 | windSpeeds: [5] 18 | wholeScrnSize: 2048 19 | r0: 0.1 20 | L0: [100] 21 | infinite: True 22 | 23 | Telescope: 24 | telDiam: 10 25 | obsDiam: 0 26 | mask: circle 27 | 28 | Reconstructor: 29 | type: MVM 30 | 31 | 32 | Science: 33 | 0: 34 | position: [0, 0] 35 | FOV: 10.0 36 | #pxlScale: 0.2 37 | wavelength: 2.2e-6 38 | pxls: 128 39 | fftOversamp: 2 40 | fftwThreads: 0 41 | 42 | fftwFlag: "FFTW_MEASURE" 43 | 44 | -------------------------------------------------------------------------------- /src/generation/radial.py: -------------------------------------------------------------------------------- 1 | def radial_data(data,annulus_width=1,working_mask=None,x=None,y=None,rmax=None): 2 | """ 3 | r = radial_data(data,annulus_width,working_mask,x,y) 4 | 5 | A function to reduce an image to a radial cross-section. 6 | 7 | INPUT: 8 | ------ 9 | data - whatever data you are radially averaging. Data is 10 | binned into a series of annuli of width 'annulus_width' 11 | pixels. 12 | annulus_width - width of each annulus. Default is 1. 13 | working_mask - array of same size as 'data', with zeros at 14 | whichever 'data' points you don't want included 15 | in the radial data computations. 16 | x,y - coordinate system in which the data exists (used to set 17 | the center of the data). By default, these are set to 18 | integer meshgrids 19 | rmax -- maximum radial value over which to compute statistics 20 | 21 | OUTPUT: 22 | ------- 23 | r - a data structure containing the following 24 | statistics, computed across each annulus: 25 | .r - the radial coordinate used (outer edge of annulus) 26 | .mean - mean of the data in the annulus 27 | .std - standard deviation of the data in the annulus 28 | .median - median value in the annulus 29 | .max - maximum value in the annulus 30 | .min - minimum value in the annulus 31 | .numel - number of elements in the annulus 32 | """ 33 | 34 | # 2010-03-10 19:22 IJC: Ported to python from Matlab 35 | # 2005/12/19 Added 'working_region' option (IJC) 36 | # 2005/12/15 Switched order of outputs (IJC) 37 | # 2005/12/12 IJC: Removed decifact, changed name, wrote comments. 38 | # 2005/11/04 by Ian Crossfield at the Jet Propulsion Laboratory 39 | 40 | import numpy as ny 41 | 42 | class radialDat: 43 | """Empty object container. 44 | """ 45 | def __init__(self): 46 | self.mean = None 47 | self.std = None 48 | self.median = None 49 | self.numel = None 50 | self.max = None 51 | self.min = None 52 | self.r = None 53 | 54 | #--------------------- 55 | # Set up input parameters 56 | #--------------------- 57 | data = ny.array(data) 58 | 59 | if working_mask==None: 60 | working_mask = ny.ones(data.shape,bool) 61 | 62 | npix, npiy = data.shape 63 | if x==None or y==None: 64 | x1 = ny.arange(-npix/2.,npix/2.) 65 | y1 = ny.arange(-npiy/2.,npiy/2.) 66 | x,y = ny.meshgrid(y1,x1) 67 | 68 | r = abs(x+1j*y) 69 | 70 | if rmax==None: 71 | rmax = r[working_mask].max() 72 | 73 | #--------------------- 74 | # Prepare the data container 75 | #--------------------- 76 | dr = ny.abs([x[0,0] - x[0,1]]) * annulus_width 77 | radial = ny.arange(rmax/dr)*dr + dr/2. 78 | nrad = len(radial) 79 | radialdata = radialDat() 80 | radialdata.mean = ny.zeros(nrad) 81 | radialdata.std = ny.zeros(nrad) 82 | radialdata.median = ny.zeros(nrad) 83 | radialdata.numel = ny.zeros(nrad) 84 | radialdata.max = ny.zeros(nrad) 85 | radialdata.min = ny.zeros(nrad) 86 | radialdata.r = radial 87 | 88 | #--------------------- 89 | # Loop through the bins 90 | #--------------------- 91 | for irad in range(nrad): #= 1:numel(radial) 92 | minrad = irad*dr 93 | maxrad = minrad + dr 94 | thisindex = (r>=minrad) * (r None: 39 | # pylint: disable=invalid-name 40 | self.T_max = T_max 41 | self.eta_min = eta_min 42 | self.factor = factor 43 | self._last_restart: int = 0 44 | self._cycle_counter: int = 0 45 | self._cycle_factor: float = 1. 46 | self._updated_cycle_len: int = T_max 47 | self._initialized: bool = False 48 | super(CosineWithRestarts, self).__init__(optimizer, last_epoch) 49 | 50 | def get_lr(self): 51 | """Get updated learning rate.""" 52 | # HACK: We need to check if this is the first time get_lr() was called, since 53 | # we want to start with step = 0, but _LRScheduler calls get_lr with 54 | # last_epoch + 1 when initialized. 55 | if not self._initialized: 56 | self._initialized = True 57 | return self.base_lrs 58 | 59 | step = self.last_epoch + 1 60 | self._cycle_counter = step - self._last_restart 61 | 62 | lrs = [ 63 | ( 64 | self.eta_min + ((lr - self.eta_min) / 2) * 65 | ( 66 | np.cos( 67 | np.pi * 68 | ((self._cycle_counter) % self._updated_cycle_len) / 69 | self._updated_cycle_len 70 | ) + 1 71 | ) 72 | ) for lr in self.base_lrs 73 | ] 74 | 75 | if self._cycle_counter % self._updated_cycle_len == 0: 76 | # Adjust the cycle length. 77 | self._cycle_factor *= self.factor 78 | self._cycle_counter = 0 79 | self._updated_cycle_len = int(self._cycle_factor * self.T_max) 80 | self._last_restart = step 81 | 82 | return lrs 83 | -------------------------------------------------------------------------------- /src/pytorch/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from astropy.io import fits 3 | from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler 4 | from astropy.visualization import SqrtStretch, MinMaxInterval 5 | import numpy as np 6 | 7 | class psf_dataset(Dataset): 8 | 9 | def __init__(self, root_dir, size, transform=None): 10 | self.size = size 11 | self.root_dir = root_dir 12 | self.transform = transform 13 | 14 | def __len__(self): 15 | return self.size 16 | 17 | def __getitem__(self, id): 18 | 19 | if id >= self.size: 20 | raise ValueError('[Dataset] Index out of bounds') 21 | return None 22 | 23 | sample_name = self.root_dir + 'psf_' + str(int(id)) + '.fits' 24 | sample_hdu = fits.open(sample_name) 25 | 26 | image = np.stack((sample_hdu[2].data, sample_hdu[3].data)).astype(np.float32) 27 | 28 | phase = sample_hdu[1].data.astype(np.float32) 29 | 30 | sample = {'phase': phase, 'image': image} 31 | 32 | if self.transform: 33 | sample = self.transform(sample) 34 | 35 | return sample 36 | 37 | 38 | class Normalize(object): 39 | def __call__(self, sample): 40 | phase, image = sample['phase'], sample['image'] 41 | 42 | image[0] = minmax(np.sqrt(image[0])) 43 | image[1] = minmax(np.sqrt(image[1])) 44 | 45 | phase = (phase/2200.)*2*np.pi 46 | 47 | return {'phase': phase, 'image': image} 48 | 49 | 50 | def minmax(array): 51 | a_min = np.min(array) 52 | a_max = np.max(array) 53 | return (array-a_min)/(a_max-a_min) 54 | 55 | class ToTensor(object): 56 | def __call__(self, sample): 57 | phase, image = sample['phase'], sample['image'] 58 | 59 | return {'phase': torch.from_numpy(phase), 'image': torch.from_numpy(image)} 60 | 61 | class Noise(object): 62 | def __call__(self, sample): 63 | phase, image = sample['phase'], sample['image'] 64 | 65 | noise_intensity = 1000 66 | image[0] = minmax(image[0]) 67 | image[1] = minmax(image[1]) 68 | image[0] = np.random.poisson(lam=noise_intensity*image[0], size=None) 69 | image[1] = np.random.poisson(lam=noise_intensity*image[1], size=None) 70 | 71 | return {'phase': phase, 'image': image} 72 | 73 | 74 | def splitDataLoader(dataset, split=[0.9, 0.1], batch_size=32, random_seed=None, shuffle=True): 75 | indices = list(range(len(dataset))) 76 | s = int(np.floor(split[1] * len(dataset))) 77 | if shuffle: 78 | np.random.seed(random_seed) 79 | np.random.shuffle(indices) 80 | train_indices, val_indices = indices[s:], indices[:s] 81 | 82 | train_sampler, val_sampler = SubsetRandomSampler(train_indices), SubsetRandomSampler(val_indices) 83 | 84 | train_dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4, sampler=train_sampler) 85 | val_dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4, sampler=val_sampler) 86 | 87 | return train_dataloader, val_dataloader 88 | -------------------------------------------------------------------------------- /src/pytorch/lr_analyzer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import utils 4 | import json 5 | import logging 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | import numpy as np 11 | from torchvision import transforms 12 | from dataset import psf_dataset, splitDataLoader, ToTensor, Normalize 13 | from utils_visdom import VisdomWebServer 14 | import aotools 15 | 16 | def lr_analyzer(model, dataset, optimizer, criterion, split=[0.9, 0.1], batch_size=64, lr=[1e-5, 1e-1]): 17 | 18 | for p in optimizer.param_groups: 19 | p['lr'] = lr[0] 20 | 21 | lr_log = np.geomspace(lr[0], lr[1], 100) 22 | 23 | # Dataset 24 | dataloaders, _ = splitDataLoader(dataset, split=split, batch_size=batch_size) 25 | 26 | # Device 27 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 28 | 29 | losses = [] 30 | lrs = [] 31 | 32 | running_loss = 0.0 33 | it = 0 34 | 35 | for _, sample in enumerate(dataloaders): 36 | # GPU support 37 | inputs = sample['image'].to(device) 38 | phase_0 = sample['phase'].to(device) 39 | 40 | # zero the parameter gradients 41 | optimizer.zero_grad() 42 | 43 | # forward: track history if only in train 44 | with torch.set_grad_enabled(True): 45 | 46 | # Network return phase and zernike coeffs 47 | phase_estimation = model(inputs) 48 | loss = criterion(torch.squeeze(phase_estimation), phase_0) 49 | loss.backward() 50 | optimizer.step() 51 | 52 | losses.append(loss.item()) 53 | lrs.append(get_lr(optimizer)) 54 | 55 | if it == 100: 56 | break 57 | 58 | #update lr 59 | for p in optimizer.param_groups: 60 | p['lr'] = lr_log[it] 61 | 62 | it +=1 63 | 64 | return losses, lrs 65 | 66 | def get_lr(optimizer): 67 | for p in optimizer.param_groups: 68 | lr = p['lr'] 69 | return lr 70 | -------------------------------------------------------------------------------- /src/pytorch/models/Densenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import aotools 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.models as models 7 | 8 | class Net(nn.Module): 9 | 10 | def __init__(self): 11 | super(Net, self).__init__() 12 | 13 | self.densenet = models.densenet161(pretrained=True) 14 | 15 | for param in self.densenet.parameters(): 16 | param.requires_grad = True 17 | 18 | # Input reshape 19 | first_conv_layer = [nn.Conv2d(2, 3, kernel_size=1, stride=1, bias=True), 20 | nn.AdaptiveMaxPool2d(224), 21 | self.densenet.features.conv0] 22 | self.densenet.features.conv0 = nn.Sequential(*first_conv_layer) 23 | 24 | # Classifier 25 | self.densenet.classifier = nn.Sequential( 26 | nn.Linear(2208, 20, bias=True), 27 | #nn.ReLU(inplace=True), 28 | #nn.BatchNorm1d(1024), 29 | #nn.Linear(1024, 1024, bias=True), 30 | #nn.ReLU(inplace=True), 31 | #nn.BatchNorm1d(1024), 32 | #nn.Linear(1024, 20, bias=True) 33 | ) 34 | 35 | self.phase2dlayer = Phase2DLayer(20,128) 36 | 37 | def forward(self, x): 38 | # 128x128x2 39 | z = self.densenet(x) 40 | phase = self.phase2dlayer(z) 41 | return phase, z 42 | 43 | class Phase2D(torch.autograd.Function): 44 | 45 | @staticmethod 46 | def forward(ctx, input, z_basis): 47 | ctx.z_basis = z_basis.cpu()#.cuda() 48 | output = input[:,:, None, None] * ctx.z_basis[None, 1:,:,:] 49 | return torch.sum(output, dim=1) 50 | 51 | @staticmethod 52 | def backward(ctx, grad_output): 53 | dL_dy = grad_output.unsqueeze(1) 54 | dy_dz = ctx.z_basis[1:,:,:].unsqueeze(0) 55 | grad_input = torch.sum(dL_dy * dy_dz, dim=(2,3)) 56 | return grad_input, None 57 | 58 | class Phase2DLayer(nn.Module): 59 | def __init__(self, input_features, output_features): 60 | super(Phase2DLayer, self).__init__() 61 | self.input_features = input_features 62 | self.output_features = output_features 63 | self.z_basis = aotools.zernikeArray(input_features+1, output_features, norm='rms') 64 | self.z_basis = torch.as_tensor(self.z_basis, dtype=torch.float32) 65 | 66 | def forward(self, input): 67 | return Phase2D.apply(input, self.z_basis) 68 | 69 | 70 | class BasicConv2d(nn.Module): 71 | 72 | def __init__(self, in_channels, out_channels, **kwargs): 73 | super(BasicConv2d, self).__init__() 74 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 75 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 76 | 77 | def forward(self, x): 78 | x = self.conv(x) 79 | x = self.bn(x) 80 | return F.relu(x, inplace=True) -------------------------------------------------------------------------------- /src/pytorch/models/InceptionV3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import aotools 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.models as models 7 | 8 | class Net(nn.Module): 9 | 10 | def __init__(self): 11 | super(Net, self).__init__() 12 | 13 | self.inception = models.inception_v3(pretrained=True, transform_input=False) 14 | 15 | for param in self.inception.parameters(): 16 | param.requires_grad = True 17 | 18 | # Input size 19 | first_conv_layer = [nn.Conv2d(2, 3, kernel_size=1, stride=1, bias=True), 20 | nn.AdaptiveMaxPool2d(299), 21 | self.inception.Conv2d_1a_3x3] 22 | self.inception.Conv2d_1a_3x3= nn.Sequential(*first_conv_layer) 23 | 24 | # Fit classifier 25 | self.inception.fc = nn.Sequential( 26 | nn.Linear(2048, 20), 27 | #nn.ReLU(inplace=True), 28 | #nn.BatchNorm1d(2048), 29 | #nn.Linear(2048, 1024), 30 | #nn.ReLU(inplace=True), 31 | #nn.BatchNorm1d(2048), 32 | #nn.Linear(1024, 20) 33 | ) 34 | 35 | self.phase2dlayer = Phase2DLayer(20,128) 36 | 37 | def forward(self, x): 38 | if self.inception.training: 39 | z, _ = self.inception(x) 40 | else: 41 | z = self.inception(x) 42 | phase = self.phase2dlayer(z) 43 | return phase, z 44 | 45 | class Phase2D(torch.autograd.Function): 46 | 47 | @staticmethod 48 | def forward(ctx, input, z_basis): 49 | ctx.z_basis = z_basis.cpu()#.cuda() 50 | output = input[:,:, None, None] * ctx.z_basis[None, 1:,:,:] 51 | return torch.sum(output, dim=1) 52 | 53 | @staticmethod 54 | def backward(ctx, grad_output): 55 | dL_dy = grad_output.unsqueeze(1) 56 | dy_dz = ctx.z_basis[1:,:,:].unsqueeze(0) 57 | grad_input = torch.sum(dL_dy * dy_dz, dim=(2,3)) 58 | return grad_input, None 59 | 60 | class Phase2DLayer(nn.Module): 61 | def __init__(self, input_features, output_features): 62 | super(Phase2DLayer, self).__init__() 63 | self.input_features = input_features 64 | self.output_features = output_features 65 | self.z_basis = aotools.zernikeArray(input_features+1, output_features, norm='rms') 66 | self.z_basis = torch.as_tensor(self.z_basis, dtype=torch.float32) 67 | 68 | def forward(self, input): 69 | return Phase2D.apply(input, self.z_basis) 70 | 71 | 72 | class BasicConv2d(nn.Module): 73 | 74 | def __init__(self, in_channels, out_channels, **kwargs): 75 | super(BasicConv2d, self).__init__() 76 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 77 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 78 | 79 | def forward(self, x): 80 | x = self.conv(x) 81 | x = self.bn(x) 82 | return F.relu(x, inplace=True) -------------------------------------------------------------------------------- /src/pytorch/models/Resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import aotools 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.models as models 7 | 8 | class Net(nn.Module): 9 | 10 | def __init__(self): 11 | super(Net, self).__init__() 12 | 13 | self.resnet = models.resnet50(pretrained=True) 14 | 15 | for param in self.resnet.parameters(): 16 | param.requires_grad = True 17 | 18 | # Input size 2x128x128 -> 2x224x224 19 | first_conv_layer = [nn.Conv2d(2, 3, kernel_size=1, stride=1, bias=True), 20 | nn.AdaptiveMaxPool2d(224), 21 | self.resnet.conv1] 22 | self.resnet.conv1= nn.Sequential(*first_conv_layer) 23 | 24 | # Fit classifier 25 | self.resnet.fc = nn.Sequential( 26 | nn.Linear(2048, 100), 27 | #nn.ReLU(inplace=True), 28 | #nn.BatchNorm1d(1024), 29 | #nn.Linear(1024, 1024), 30 | #nn.ReLU(inplace=True), 31 | #nn.BatchNorm1d(1024), 32 | #nn.Linear(1024, 20) 33 | ) 34 | 35 | self.phase2dlayer = Phase2DLayer(100,128) 36 | 37 | def forward(self, x): 38 | # 128x128x2 39 | z = self.resnet(x) 40 | phase = self.phase2dlayer(z) 41 | return phase 42 | 43 | class Phase2D(torch.autograd.Function): 44 | 45 | @staticmethod 46 | def forward(ctx, input, z_basis): 47 | ctx.z_basis = z_basis.cuda() 48 | output = input[:,:, None, None] * ctx.z_basis[None, 1:,:,:] 49 | return torch.sum(output, dim=1) 50 | 51 | @staticmethod 52 | def backward(ctx, grad_output): 53 | dL_dy = grad_output.unsqueeze(1) 54 | dy_dz = ctx.z_basis[1:,:,:].unsqueeze(0) 55 | grad_input = torch.sum(dL_dy * dy_dz, dim=(2,3)) 56 | return grad_input, None 57 | 58 | class Phase2DLayer(nn.Module): 59 | def __init__(self, input_features, output_features): 60 | super(Phase2DLayer, self).__init__() 61 | self.input_features = input_features 62 | self.output_features = output_features 63 | self.z_basis = aotools.zernikeArray(input_features+1, output_features, norm='rms') 64 | self.z_basis = torch.as_tensor(self.z_basis, dtype=torch.float32) 65 | 66 | def forward(self, input): 67 | return Phase2D.apply(input, self.z_basis) 68 | -------------------------------------------------------------------------------- /src/pytorch/models/Unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class UNet(nn.Module): 6 | def __init__(self, n_channels_in, n_channels_out): 7 | super(UNet, self).__init__() 8 | self.inc = inconv(n_channels_in, 64) 9 | self.down1 = down(64, 128) 10 | self.down2 = down(128, 256) 11 | self.down3 = down(256, 512) 12 | self.down4 = down(512, 512) 13 | self.up1 = up(1024, 256, bilinear=True) 14 | self.up2 = up(512, 128, bilinear=True) 15 | self.up3 = up(256, 64, bilinear=True) 16 | self.up4 = up(128, 64, bilinear=True) 17 | self.outc = outconv(64, n_channels_out) 18 | 19 | def forward(self, x): 20 | x1 = self.inc(x) 21 | x2 = self.down1(x1) 22 | x3 = self.down2(x2) 23 | x4 = self.down3(x3) 24 | x5 = self.down4(x4) 25 | x = self.up1(x5, x4) 26 | x = self.up2(x, x3) 27 | x = self.up3(x, x2) 28 | x = self.up4(x, x1) 29 | x = self.outc(x) 30 | return x 31 | 32 | class double_conv(nn.Module): 33 | '''(conv => BN => ReLU) * 2''' 34 | def __init__(self, in_ch, out_ch): 35 | super(double_conv, self).__init__() 36 | self.conv = nn.Sequential( 37 | nn.Conv2d(in_ch, out_ch, 3, padding=1), 38 | nn.BatchNorm2d(out_ch), 39 | nn.ReLU(inplace=True), 40 | nn.Conv2d(out_ch, out_ch, 3, padding=1), 41 | nn.BatchNorm2d(out_ch), 42 | nn.ReLU(inplace=True) 43 | ) 44 | 45 | def forward(self, x): 46 | x = self.conv(x) 47 | return x 48 | 49 | class inconv(nn.Module): 50 | def __init__(self, in_ch, out_ch): 51 | super(inconv, self).__init__() 52 | self.conv = double_conv(in_ch, out_ch) 53 | 54 | def forward(self, x): 55 | x = self.conv(x) 56 | return x 57 | 58 | class down(nn.Module): 59 | def __init__(self, in_ch, out_ch): 60 | super(down, self).__init__() 61 | self.mpconv = nn.Sequential( 62 | nn.MaxPool2d(2), 63 | double_conv(in_ch, out_ch) 64 | ) 65 | 66 | def forward(self, x): 67 | x = self.mpconv(x) 68 | return x 69 | 70 | class up(nn.Module): 71 | def __init__(self, in_ch, out_ch, bilinear=True): 72 | super(up, self).__init__() 73 | 74 | # would be a nice idea if the upsampling could be learned too, 75 | # but my machine do not have enough memory to handle all those weights 76 | if bilinear: 77 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 78 | else: 79 | self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2) 80 | 81 | self.conv = double_conv(in_ch, out_ch) 82 | 83 | def forward(self, x1, x2): 84 | x1 = self.up(x1) 85 | 86 | # input is CHW 87 | diffY = x2.size()[2] - x1.size()[2] 88 | diffX = x2.size()[3] - x1.size()[3] 89 | 90 | x1 = F.pad(x1, (diffX // 2, diffX - diffX//2, 91 | diffY // 2, diffY - diffY//2)) 92 | 93 | x = torch.cat([x2, x1], dim=1) 94 | x = self.conv(x) 95 | return x 96 | 97 | class outconv(nn.Module): 98 | def __init__(self, in_ch, out_ch): 99 | super(outconv, self).__init__() 100 | self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1) 101 | 102 | def forward(self, x): 103 | x = self.conv(x) 104 | return x 105 | 106 | -------------------------------------------------------------------------------- /src/pytorch/models/Unet_PP.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | 4 | from torch import nn 5 | from torch.nn import functional as F 6 | import torch 7 | from torchvision import models 8 | import torchvision 9 | 10 | 11 | class VGGBlock(nn.Module): 12 | def __init__(self, in_channels, middle_channels, out_channels, act_func=nn.ReLU(inplace=True)): 13 | super(VGGBlock, self).__init__() 14 | self.act_func = act_func 15 | self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1) 16 | self.bn1 = nn.BatchNorm2d(middle_channels) 17 | self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1) 18 | self.bn2 = nn.BatchNorm2d(out_channels) 19 | 20 | def forward(self, x): 21 | out = self.conv1(x) 22 | out = self.bn1(out) 23 | out = self.act_func(out) 24 | 25 | out = self.conv2(out) 26 | out = self.bn2(out) 27 | out = self.act_func(out) 28 | 29 | return out 30 | 31 | 32 | class UNet(nn.Module): 33 | def __init__(self, args): 34 | super().__init__() 35 | 36 | self.args = args 37 | 38 | nb_filter = [32, 64, 128, 256, 512] 39 | 40 | self.pool = nn.MaxPool2d(2, 2) 41 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 42 | 43 | self.conv0_0 = VGGBlock(args.input_channels, nb_filter[0], nb_filter[0]) 44 | self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1]) 45 | self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2]) 46 | self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3]) 47 | self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4]) 48 | 49 | self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3]) 50 | self.conv2_2 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2]) 51 | self.conv1_3 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1]) 52 | self.conv0_4 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0]) 53 | 54 | self.final = nn.Conv2d(nb_filter[0], 1, kernel_size=1) 55 | 56 | 57 | def forward(self, input): 58 | x0_0 = self.conv0_0(input) 59 | x1_0 = self.conv1_0(self.pool(x0_0)) 60 | x2_0 = self.conv2_0(self.pool(x1_0)) 61 | x3_0 = self.conv3_0(self.pool(x2_0)) 62 | x4_0 = self.conv4_0(self.pool(x3_0)) 63 | 64 | x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1)) 65 | x2_2 = self.conv2_2(torch.cat([x2_0, self.up(x3_1)], 1)) 66 | x1_3 = self.conv1_3(torch.cat([x1_0, self.up(x2_2)], 1)) 67 | x0_4 = self.conv0_4(torch.cat([x0_0, self.up(x1_3)], 1)) 68 | 69 | output = self.final(x0_4) 70 | return output 71 | 72 | 73 | class NestedUNet(nn.Module): 74 | def __init__(self): 75 | super().__init__() 76 | 77 | #self.args = args 78 | 79 | nb_filter = [32, 64, 96, 128, 256] 80 | 81 | self.pool = nn.MaxPool2d(2, 2) 82 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 83 | 84 | self.conv0_0 = VGGBlock(2, nb_filter[0], nb_filter[0]) 85 | self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1]) 86 | self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2]) 87 | self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3]) 88 | self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4]) 89 | 90 | self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0]) 91 | self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1]) 92 | self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2]) 93 | self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3]) 94 | 95 | self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0]) 96 | self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1]) 97 | self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2]) 98 | 99 | self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0]) 100 | self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1]) 101 | 102 | self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0]) 103 | 104 | #if self.args.deepsupervision: 105 | # self.final1 = nn.Conv2d(nb_filter[0], 1, kernel_size=1) 106 | # self.final2 = nn.Conv2d(nb_filter[0], 1, kernel_size=1) 107 | # self.final3 = nn.Conv2d(nb_filter[0], 1, kernel_size=1) 108 | # self.final4 = nn.Conv2d(nb_filter[0], 1, kernel_size=1) 109 | #else: 110 | self.final = nn.Conv2d(nb_filter[0], 1, kernel_size=1) 111 | 112 | 113 | def forward(self, input): 114 | x0_0 = self.conv0_0(input) 115 | x1_0 = self.conv1_0(self.pool(x0_0)) 116 | x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1)) 117 | 118 | x2_0 = self.conv2_0(self.pool(x1_0)) 119 | x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1)) 120 | x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1)) 121 | 122 | x3_0 = self.conv3_0(self.pool(x2_0)) 123 | x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1)) 124 | x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1)) 125 | x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1)) 126 | 127 | x4_0 = self.conv4_0(self.pool(x3_0)) 128 | x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1)) 129 | x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1)) 130 | x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1)) 131 | x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1)) 132 | ''' 133 | if self.args.deepsupervision: 134 | output1 = self.final1(x0_1) 135 | output2 = self.final2(x0_2) 136 | output3 = self.final3(x0_3) 137 | output4 = self.final4(x0_4) 138 | return [output1, output2, output3, output4] 139 | 140 | else: 141 | ''' 142 | output = self.final(x0_4) 143 | return output 144 | -------------------------------------------------------------------------------- /src/pytorch/models/VGG.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import aotools 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | class Net(nn.Module): 8 | 9 | def __init__(self): 10 | super(Net, self).__init__() 11 | 12 | self.conv_a1 = BasicConv2d(2, 32, kernel_size=3, stride=1, padding=1) 13 | self.conv_a2 = BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1) 14 | 15 | self.conv_b1 = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1) 16 | self.conv_b2 = BasicConv2d(64, 64, kernel_size=3, stride=1, padding=1) 17 | 18 | self.conv_c1 = BasicConv2d(64, 128, kernel_size=3, stride=1, padding=1) 19 | self.conv_c2 = BasicConv2d(128, 128, kernel_size=3, stride=1, padding=1) 20 | 21 | self.conv_d1 = BasicConv2d(128, 256, kernel_size=3, stride=1, padding=1) 22 | self.conv_d2 = BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1) 23 | 24 | self.conv_e1 = BasicConv2d(256, 512, kernel_size=3, stride=1, padding=1) 25 | self.conv_e2 = BasicConv2d(512, 512, kernel_size=3, stride=1, padding=1) 26 | 27 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 28 | 29 | self.fc1 = torch.nn.Linear(4*4*512, 1024) 30 | #self.fc1_bn = nn.BatchNorm1d(1024) 31 | self.fc2 = torch.nn.Linear(1024, 1024) 32 | #self.fc2_bn = nn.BatchNorm1d(1024) 33 | self.fc3 = torch.nn.Linear(1024, 20) 34 | 35 | self.phase2dlayer = Phase2DLayer(20,128) 36 | 37 | def forward(self, x): 38 | # 128x128x2 39 | x = self.conv_a1(x) # 128x128x16 40 | x = self.conv_a2(x) 41 | x = self.pool(x) 42 | x = self.conv_b1(x) # 64x64x32 43 | x = self.conv_b2(x) 44 | x = self.pool(x) 45 | x = self.conv_c1(x) # 32x32x64 46 | x = self.conv_c2(x) 47 | x = self.pool(x) 48 | x = self.conv_d1(x) # 16x16x128 49 | x = self.conv_d2(x) 50 | x = self.pool(x) 51 | x = self.conv_e1(x) # 8x8x512 52 | x = self.conv_e2(x) # 8x8x512 53 | x = self.pool(x) 54 | x = x.view(-1, 4*4*512) 55 | x = F.relu(self.fc1(x)) 56 | x = F.relu(self.fc2(x)) 57 | z_coeffs = self.fc3(x) 58 | phase = self.phase2dlayer(z_coeffs) 59 | return phase, z_coeffs 60 | 61 | 62 | class Phase2D(torch.autograd.Function): 63 | 64 | @staticmethod 65 | def forward(ctx, input, z_basis): 66 | ctx.z_basis = z_basis.cpu() #.cuda() 67 | output = input[:,:, None, None] * ctx.z_basis[None, 1:,:,:] 68 | return torch.sum(output, dim=1) 69 | 70 | @staticmethod 71 | def backward(ctx, grad_output): 72 | dL_dy = grad_output.unsqueeze(1) 73 | dy_dz = ctx.z_basis[1:,:,:].unsqueeze(0) 74 | grad_input = torch.sum(dL_dy * dy_dz, dim=(2,3)) 75 | return grad_input, None 76 | 77 | class Phase2DLayer(nn.Module): 78 | def __init__(self, input_features, output_features): 79 | super(Phase2DLayer, self).__init__() 80 | self.input_features = input_features 81 | self.output_features = output_features 82 | self.z_basis = aotools.zernikeArray(input_features+1, output_features, norm='rms') 83 | self.z_basis = torch.as_tensor(self.z_basis, dtype=torch.float32) 84 | 85 | def forward(self, input): 86 | return Phase2D.apply(input, self.z_basis) 87 | 88 | 89 | class BasicConv2d(nn.Module): 90 | 91 | def __init__(self, in_channels, out_channels, **kwargs): 92 | super(BasicConv2d, self).__init__() 93 | self.conv = nn.Conv2d(in_channels, out_channels, bias=True, **kwargs) 94 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 95 | 96 | def forward(self, x): 97 | x = self.conv(x) 98 | x = self.bn(x) 99 | return F.relu(x, inplace=True) -------------------------------------------------------------------------------- /src/pytorch/models/__pycache__/Unet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/povanberg/Machine-learning-for-image-based-wavefront-sensing/a687f422d822a7c1db76375a0a4a67a60b03721d/src/pytorch/models/__pycache__/Unet.cpython-36.pyc -------------------------------------------------------------------------------- /src/pytorch/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import utils 4 | import json 5 | import logging 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | from torchvision import transforms 11 | from dataset import psf_dataset, splitDataLoader, ToTensor, Normalize 12 | from utils_visdom import VisdomWebServer 13 | import aotools 14 | from criterion import * 15 | 16 | def train(model, dataset, optimizer, criterion, split=[0.9, 0.1], batch_size=32, 17 | n_epochs=1, model_dir='./', random_seed=None, visdom=False): 18 | 19 | # Create directory if doesn't exist 20 | if not os.path.exists(model_dir): 21 | os.makedirs(model_dir) 22 | 23 | # Logging 24 | log_path = os.path.join(model_dir, 'logs.log') 25 | utils.set_logger(log_path) 26 | 27 | # Visdom support 28 | if visdom: 29 | vis = VisdomWebServer() 30 | 31 | # Dataset 32 | dataloaders = {} 33 | dataloaders['train'], dataloaders['val'] = splitDataLoader(dataset, split=split, 34 | batch_size=batch_size, random_seed=random_seed) 35 | 36 | # --- 37 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=150, gamma=0.1) 38 | #scheduler = CosineWithRestarts(optimizer, T_max=40, eta_min=1e-7, last_epoch=-1) 39 | #scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30, eta_min=1e-7, last_epoch=-1) 40 | 41 | # Metrics 42 | metrics_path = os.path.join(model_dir, 'metrics.json') 43 | 44 | metrics = { 45 | 'model': model_dir, 46 | 'optimizer': optimizer.__class__.__name__, 47 | 'criterion': criterion.__class__.__name__, 48 | 'scheduler': scheduler.__class__.__name__, 49 | 'dataset_size': int(len(dataset)), 50 | 'train_size': int(split[0]*len(dataset)), 51 | 'test_size': int(split[1]*len(dataset)), 52 | 'n_epoch': n_epochs, 53 | 'batch_size': batch_size, 54 | 'learning_rate': [], 55 | 'train_loss': [], 56 | 'val_loss': [], 57 | 'zernike_train_loss': [], 58 | 'zernike_val_loss': [] 59 | } 60 | 61 | # Zernike basis 62 | z_basis = torch.as_tensor(aotools.zernikeArray(100+1, 128, norm='rms'), dtype=torch.float32) 63 | 64 | # Device 65 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 66 | 67 | # Training 68 | since = time.time() 69 | dataset_size = { 70 | 'train':int(split[0]*len(dataset)), 71 | 'val':int(split[1]*len(dataset)) 72 | } 73 | 74 | 75 | best_loss = 0.0 76 | 77 | for epoch in range(n_epochs): 78 | 79 | logging.info('-'*30) 80 | epoch_time = time.time() 81 | 82 | # Each epoch has a training and validation phase 83 | for phase in ['train', 'val']: 84 | if phase == 'train': 85 | model.train() # Set model to training mode 86 | else: 87 | model.eval() # Set model to evaluate mode 88 | 89 | running_loss = 0.0 90 | zernike_loss =0.0 91 | 92 | for _, sample in enumerate(dataloaders[phase]): 93 | # GPU support 94 | inputs = sample['image'].to(device) 95 | phase_0 = sample['phase'].to(device) 96 | 97 | # zero the parameter gradients 98 | optimizer.zero_grad() 99 | 100 | # forward: track history if only in train 101 | with torch.set_grad_enabled(phase == 'train'): 102 | 103 | # Network return phase and zernike coeffs 104 | phase_estimation = model(inputs) 105 | loss = criterion(torch.squeeze(phase_estimation), phase_0) 106 | 107 | # backward 108 | if phase == 'train': 109 | loss.backward() 110 | optimizer.step() 111 | 112 | running_loss += 1 * loss.item() * inputs.size(0) 113 | 114 | logging.info('[%i/%i] %s loss: %f' % (epoch+1, n_epochs, phase, running_loss / dataset_size[phase])) 115 | 116 | # Update metrics 117 | metrics[phase+'_loss'].append(running_loss / dataset_size[phase]) 118 | #metrics['zernike_'+phase+'_loss'].append(zernike_loss / dataset_size[phase]) 119 | if phase=='train': 120 | metrics['learning_rate'].append(get_lr(optimizer)) 121 | 122 | # Adaptive learning rate 123 | if phase == 'val': 124 | scheduler.step() 125 | # Save weigths 126 | if epoch == 0 or running_loss < best_loss: 127 | best_loss = running_loss 128 | model_path = os.path.join(model_dir, 'model.pth') 129 | torch.save(model.state_dict(), model_path) 130 | # Save metrics 131 | with open(metrics_path, 'w') as f: 132 | json.dump(metrics, f, indent=4) 133 | # Visdom update 134 | if visdom: 135 | vis.update(metrics) 136 | 137 | logging.info('[%i/%i] Time: %f s' % (epoch + 1, n_epochs, time.time()-epoch_time)) 138 | 139 | time_elapsed = time.time() - since 140 | logging.info('[-----] All epochs completed in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 141 | 142 | 143 | 144 | def get_lr(optimizer): 145 | for p in optimizer.param_groups: 146 | lr = p['lr'] 147 | return lr 148 | -------------------------------------------------------------------------------- /src/pytorch/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import json 3 | import os 4 | import matplotlib.pyplot as plt 5 | 6 | def set_logger(log_path): 7 | """ 8 | Set the logger to log info in terminal and file `log_path`. 9 | 10 | Args: 11 | log_path: (string) where to log 12 | """ 13 | logger = logging.getLogger() 14 | logger.setLevel(logging.INFO) 15 | 16 | if not logger.handlers: 17 | # Logging to a file 18 | file_handler = logging.FileHandler(log_path) 19 | file_handler.setFormatter(logging.Formatter('[%(asctime)s][%(levelname)s] %(message)s')) 20 | logger.addHandler(file_handler) 21 | 22 | # Logging to console 23 | stream_handler = logging.StreamHandler() 24 | stream_handler.setFormatter(logging.Formatter('%(message)s')) 25 | logger.addHandler(stream_handler) 26 | 27 | 28 | 29 | 30 | class Params(): 31 | """ 32 | Class that loads hyperparameters from a json file. 33 | 34 | params = Params(json_path) 35 | print(params.learning_rate) 36 | params.learning_rate = 0.5 # change the value of learning_rate in params 37 | """ 38 | 39 | def __init__(self, json_path): 40 | 41 | if not os.path.exists(json_path): 42 | with open(json_path, 'w') as f: 43 | data = {} 44 | json.dump(data, f, indent=4) 45 | 46 | with open(json_path) as f: 47 | params = json.load(f) 48 | self.__dict__.update(params) 49 | 50 | def save(self, json_path): 51 | with open(json_path, 'w') as f: 52 | json.dump(self.__dict__, f, indent=4) 53 | 54 | def update(self, json_path): 55 | """Loads parameters from json file""" 56 | with open(json_path) as f: 57 | params = json.load(f) 58 | self.__dict__.update(params) 59 | 60 | def hasKey(self, json_path, key_name): 61 | bool_key = False 62 | with open(json_path) as f: 63 | params = json.load(f) 64 | if key_name in params: 65 | bool_key = True 66 | 67 | return bool_key 68 | 69 | @property 70 | def dict(self): 71 | """Gives dict-like access to Params instance by `params.dict['learning_rate']""" 72 | return self.__dict__ 73 | 74 | 75 | def plot_learningcurve(metrics, save=True, show=True, name='lrcurve.pdf', 76 | xlim=[None,None], ylim=[None,None], zernike=False): 77 | import numpy as np 78 | plt.figure() 79 | #x = np.arange(200) 80 | #plt.plot(x, np.array(metrics['train_loss' if not zernike else 'zernike_train_loss'])[x]/(0.8*np.log(x)), label='Training loss', color='blue') 81 | plt.plot(metrics['train_loss' if not zernike else 'zernike_train_loss'][:], label='Training loss', color='blue') 82 | plt.plot(metrics['val_loss' if not zernike else 'zernike_val_loss'][:], label='Validation loss', color='red') 83 | plt.legend() 84 | plt.grid() 85 | plt.xlim(xlim[0], xlim[1]) 86 | plt.ylim(ylim[0], ylim[1]) 87 | plt.xlabel('epochs') 88 | plt.ylabel('loss') 89 | if save: plt.savefig(name) 90 | if show: plt.show() 91 | 92 | 93 | def get_metrics(model_dir=''): 94 | 95 | metrics_path = os.path.join(model_dir, 'metrics.json') 96 | 97 | with open(metrics_path) as f: 98 | metrics = json.load(f) 99 | return metrics 100 | 101 | return None 102 | -------------------------------------------------------------------------------- /src/pytorch/utils_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class BasicConv2d(nn.Module): 6 | def __init__(self, in_channels, out_channels, **kwargs): 7 | super(BasicConv2d, self).__init__() 8 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 9 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 10 | 11 | def forward(self, x): 12 | x = self.conv(x) 13 | x = self.bn(x) 14 | return F.relu(x, inplace=True) 15 | 16 | 17 | class Phase2D(torch.autograd.Function): 18 | 19 | @staticmethod 20 | def forward(ctx, input, z_basis): 21 | ctx.z_basis = z_basis.cuda() 22 | output = input[:,:, None, None] * ctx.z_basis[None, 1:,:,:] 23 | return torch.sum(output, dim=1) 24 | 25 | @staticmethod 26 | def backward(ctx, grad_output): 27 | dL_dy = grad_output.unsqueeze(1) 28 | dy_dz = ctx.z_basis[1:,:,:].unsqueeze(0) 29 | grad_input = torch.sum(dL_dy * dy_dz, dim=(2,3)) 30 | return grad_input, None 31 | 32 | class Phase2DLayer(nn.Module): 33 | def __init__(self, input_features, output_features): 34 | super(Phase2DLayer, self).__init__() 35 | self.input_features = input_features 36 | self.output_features = output_features 37 | self.z_basis = aotools.zernikeArray(input_features+1, output_features, norm='rms') 38 | self.z_basis = torch.as_tensor(self.z_basis, dtype=torch.float32) 39 | 40 | def forward(self, input): 41 | return Phase2D.apply(input, self.z_basis) 42 | -------------------------------------------------------------------------------- /src/pytorch/utils_visdom.py: -------------------------------------------------------------------------------- 1 | from visdom import Visdom 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import json 5 | from utils import plot_learningcurve 6 | 7 | # Start web server with: python -m visdom.server 8 | 9 | class VisdomWebServer(object): 10 | 11 | def __init__(self): 12 | 13 | DEFAULT_PORT = 8097 14 | DEFAULT_HOSTNAME = "http://localhost" 15 | 16 | self.vis = Visdom(port=DEFAULT_PORT, server=DEFAULT_HOSTNAME) 17 | 18 | def update(self, metrics): 19 | 20 | if not self.vis.check_connection(): 21 | 'No connection could be formed quickly' 22 | return 23 | 24 | # Learning curve 25 | try: 26 | fig, ax = plt.subplots() 27 | plt.plot(metrics['train_loss'], label='Training loss', color='#32526e') 28 | plt.plot(metrics['val_loss'], label='Validation loss', color='#ff6b57') 29 | plt.legend() 30 | ax.spines['right'].set_visible(False) 31 | ax.spines['top'].set_visible(False) 32 | plt.grid(zorder=0, color='lightgray', linestyle='--') 33 | self.vis.matplot(plt, win='lrcurve') 34 | plt.close() 35 | plt.clf() 36 | 37 | fig, ax = plt.subplots() 38 | plt.plot(metrics['learning_rate'], color='#32526e') 39 | ax.spines['right'].set_visible(False) 40 | ax.spines['top'].set_visible(False) 41 | plt.grid(zorder=0, color='lightgray', linestyle='--') 42 | self.vis.matplot(plt, win='lr_rate') 43 | plt.close() 44 | plt.clf() 45 | 46 | #plt.figure() 47 | #plt.plot(metrics['zernike_train_loss'], label='Zernike train loss', color='blue') 48 | #plt.plot(metrics['zernike_val_loss'], label='Zernike val loss', color='red') 49 | #plt.legend() 50 | #plt.grid() 51 | #self.vis.matplot(plt, win='lrcurve_z') 52 | #plt.close() 53 | #plt.clf() 54 | except BaseException as err: 55 | print('Skipped matplotlib example') 56 | print('Error message: ', err) 57 | 58 | 59 | 60 | if __name__ == "__main__": 61 | 62 | from utils import get_metrics 63 | 64 | metrics = get_metrics('experiments/example') 65 | 66 | visdom = VisdomWebServer() 67 | visdom.update(metrics) 68 | 69 | --------------------------------------------------------------------------------