├── README.md
├── assets
└── unet.png
├── dataset
├── psf_0.fits
├── psf_1.fits
├── psf_2.fits
├── psf_3.fits
└── psf_4.fits
├── examples
├── data.ipynb
├── evaluation.ipynb
└── training.ipynb
└── src
├── NVCC_monitoring.py
├── algorithms
├── Gerchberg–Saxton.py
├── Input-Output.py
├── animation.py
└── utils.py
├── generation
├── generator.py
├── plots.py
├── psf.yaml
└── radial.py
├── processing
├── plot3D.py
└── zoom.py
└── pytorch
├── criterion.py
├── dataset.py
├── lr_analyzer.py
├── models
├── Densenet.py
├── InceptionV3.py
├── Resnet.py
├── Unet.py
├── Unet_PP.py
├── VGG.py
└── __pycache__
│ └── Unet.cpython-36.pyc
├── train.py
├── utils.py
├── utils_model.py
└── utils_visdom.py
/README.md:
--------------------------------------------------------------------------------
1 | # Machine learning for image-based wavefront sensing
2 |
3 | Astronomical images are often degraded by the disturbance of the Earth’s atmosphere. This thesis proposes to improve image-based wavefront sensing techniques using machine learning algorithms. Deep convolutional neural networks (CNN) have thus been trained to estimate the wavefront using one or multiple intensity measurements.
4 |
5 |
6 |
7 |
8 |
9 |
10 | ## Getting Started
11 |
12 | ### Prerequisites
13 |
14 | First, make sure the following Python libraries are installed.
15 |
16 | ```
17 | Aotools
18 | Astropy
19 | Soapy
20 | Scipy
21 | Pytorch
22 | Visdom
23 | ```
24 | ### Examples
25 |
26 | Generate the dataset with the command below. The dataset size and other parameters are set at the top of that script.
27 |
28 | ```
29 | python src/generation/generator.py
30 | ```
31 |
32 | Some notebooks highlighting the networks and the dataset:
33 |
34 | - [Overview of the dataset](examples/data.ipynb)
35 | - [Network Training](examples/training.ipynb)
36 | - [Network evaluation](examples/evaluation.ipynb)
37 |
38 | Finally, some classical algorithms (Gerchberg–Saxton) can be tested directly on the dataset.
39 |
40 | ```
41 | python src/algorithms/Gerchberg–Saxton.py
42 | ```
43 |
--------------------------------------------------------------------------------
/assets/unet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/povanberg/Machine-learning-for-image-based-wavefront-sensing/a687f422d822a7c1db76375a0a4a67a60b03721d/assets/unet.png
--------------------------------------------------------------------------------
/dataset/psf_0.fits:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/povanberg/Machine-learning-for-image-based-wavefront-sensing/a687f422d822a7c1db76375a0a4a67a60b03721d/dataset/psf_0.fits
--------------------------------------------------------------------------------
/dataset/psf_1.fits:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/povanberg/Machine-learning-for-image-based-wavefront-sensing/a687f422d822a7c1db76375a0a4a67a60b03721d/dataset/psf_1.fits
--------------------------------------------------------------------------------
/dataset/psf_2.fits:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/povanberg/Machine-learning-for-image-based-wavefront-sensing/a687f422d822a7c1db76375a0a4a67a60b03721d/dataset/psf_2.fits
--------------------------------------------------------------------------------
/dataset/psf_3.fits:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/povanberg/Machine-learning-for-image-based-wavefront-sensing/a687f422d822a7c1db76375a0a4a67a60b03721d/dataset/psf_3.fits
--------------------------------------------------------------------------------
/dataset/psf_4.fits:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/povanberg/Machine-learning-for-image-based-wavefront-sensing/a687f422d822a7c1db76375a0a4a67a60b03721d/dataset/psf_4.fits
--------------------------------------------------------------------------------
/examples/training.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "# Example of network training using Pytorch (1.1) and Cuda (9).\n",
10 | "\n",
11 | "# NB: This code use real time monitoring based on Visdom\n",
12 | "# an open source webserver allowing real time monitoring\n",
13 | "#\n",
14 | "# https://github.com/facebookresearch/visdom\n",
15 | "#\n",
16 | "# Start the webserver using:\n",
17 | "# python -m visdom.server\n",
18 | "#\n",
19 | "# Access it on: (by default)\n",
20 | "# http://localhost:8097\n",
21 | "\n",
22 | "# Global import\n",
23 | "import sys\n",
24 | "import torch\n",
25 | "import torch.nn as nn\n",
26 | "import torch.nn.functional as F\n",
27 | "import torch.optim as optim\n",
28 | "from torchvision import transforms\n",
29 | "from collections import OrderedDict\n",
30 | "\n",
31 | "# Local import\n",
32 | "sys.path.insert(0, '../src/pytorch/models/')\n",
33 | "from Unet import UNet\n",
34 | "\n",
35 | "sys.path.insert(0, '../src/pytorch/')\n",
36 | "from dataset import *\n",
37 | "from train import *\n",
38 | "from lr_analyzer import *\n",
39 | "from criterion import *"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 2,
45 | "metadata": {},
46 | "outputs": [
47 | {
48 | "data": {
49 | "image/png": "\n",
50 | "text/plain": [
51 | ""
52 | ]
53 | },
54 | "metadata": {
55 | "needs_background": "light"
56 | },
57 | "output_type": "display_data"
58 | }
59 | ],
60 | "source": [
61 | "# Load dataset\n",
62 | "\n",
63 | "data_dir = '../dataset/'\n",
64 | "dataset_size = 100000\n",
65 | "dataset = psf_dataset(root_dir = data_dir, \n",
66 | " size = dataset_size,\n",
67 | " transform = transforms.Compose([Noise(), Normalize(), ToTensor()]))\n",
68 | "\n",
69 | "# Check everything works as expected\n",
70 | "import matplotlib.pyplot as plt\n",
71 | "\n",
72 | "id = 0\n",
73 | "sample = dataset[id]\n",
74 | "phase = sample['phase']\n",
75 | "image_in = sample['image'][0]\n",
76 | "image_out = sample['image'][1]\n",
77 | "\n",
78 | "f, axarr = plt.subplots(1, 3, figsize=(15, 10))\n",
79 | "im1 = axarr[0].imshow(phase, cmap=plt.cm.jet)\n",
80 | "im1.set_clim(-np.pi, np.pi)\n",
81 | "axarr[0].set_title(\"Phase\")\n",
82 | "plt.colorbar(im1, ax = axarr[0], fraction=0.046)\n",
83 | "im2 = axarr[1].imshow(image_in, cmap=plt.cm.jet)\n",
84 | "axarr[1].set_title(\"In\")\n",
85 | "plt.colorbar(im2, ax = axarr[1], fraction=0.046)\n",
86 | "im3 = axarr[2].imshow(image_out, cmap=plt.cm.jet)\n",
87 | "axarr[2].set_title(\"Out\")\n",
88 | "plt.colorbar(im3, ax = axarr[2], fraction=0.046)\n",
89 | "plt.show()"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": 3,
95 | "metadata": {},
96 | "outputs": [],
97 | "source": [
98 | "# Load model architecture, in this example: Unet \n",
99 | "\n",
100 | "model = UNet(2, 1)\n",
101 | "criterion = RMSELoss()\n",
102 | "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
103 | "\n",
104 | "# Move Network to GPU\n",
105 | "\n",
106 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
107 | "if torch.cuda.device_count() > 1:\n",
108 | " model = nn.DataParallel(model)\n",
109 | " model.cuda()\n",
110 | "\n",
111 | "# Eventually load existing weights\n",
112 | "\n",
113 | "#model_dir = 'ADAM_it2/model.pth'\n",
114 | "#state_dict = torch.load(model_dir)\n",
115 | "#new_state_dict = OrderedDict()\n",
116 | "#for k, v in state_dict.items():\n",
117 | "# name = k[7:] # remove module.\n",
118 | "# new_state_dict[name] = v\n",
119 | "#model.load_state_dict(state_dict)"
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": null,
125 | "metadata": {
126 | "scrolled": true
127 | },
128 | "outputs": [],
129 | "source": [
130 | "# Launch training script. The network weights are automatically saved \n",
131 | "# at the end of an epoch (if the test error is reduced). The metrics are also\n",
132 | "# saved at the end of each epoch in JSON format. All outputs are also stored in a \n",
133 | "# log file.\n",
134 | "#\n",
135 | "# - model = network to train\n",
136 | "# - dataset = dataset object\n",
137 | "# - optimizer = gradient descent optimizer (Adam, SGD, RMSProp)\n",
138 | "# - criterion = loss function\n",
139 | "# - split[x, 1-x] = Division train/test. 'x' is the proportion of the test set.\n",
140 | "# - batch_size = batch size\n",
141 | "# - n_epochs = number of epochs\n",
142 | "# - model_dir = where to save the results\n",
143 | "# - visdom = enable real time monitoring\n",
144 | "\n",
145 | "train(model, \n",
146 | " dataset, \n",
147 | " optimizer, \n",
148 | " criterion,\n",
149 | " split = [0.50, 0.50],\n",
150 | " batch_size = 64,\n",
151 | " n_epochs = 500,\n",
152 | " model_dir = './',\n",
153 | " visdom = True)"
154 | ]
155 | }
156 | ],
157 | "metadata": {
158 | "kernelspec": {
159 | "display_name": "Python (myenv)",
160 | "language": "python",
161 | "name": "myenv"
162 | },
163 | "language_info": {
164 | "codemirror_mode": {
165 | "name": "ipython",
166 | "version": 3
167 | },
168 | "file_extension": ".py",
169 | "mimetype": "text/x-python",
170 | "name": "python",
171 | "nbconvert_exporter": "python",
172 | "pygments_lexer": "ipython3",
173 | "version": "3.6.7"
174 | }
175 | },
176 | "nbformat": 4,
177 | "nbformat_minor": 2
178 | }
179 |
--------------------------------------------------------------------------------
/src/NVCC_monitoring.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import json
4 | from threading import Timer
5 | from pynvml import *
6 | import matplotlib.pyplot as plt
7 |
8 | # Small codes dedicated to the monitoring of nvidia GPUs.
9 |
def getGPUMetrics():
    """Sample the current state of every NVIDIA GPU through NVML.

    Metrics accuracy is within 1%, see docs:
    docs.nvidia.com/deploy/nvml-api

    Exits the process with status 1 if NVML cannot be initialised or shut
    down.

    :return: dict mapping ``'gpu<i>'`` to a dict with the keys
        ``'temperatures'`` (NVML_TEMPERATURE_GPU reading), ``'fanSpeed'``,
        ``'memory'`` (used memory as reported by NVML), ``'power'`` (watts)
        and ``'time'`` (``HH:MM:SS`` sampling time).
    """
    import sys

    try:
        nvmlInit()
    except NVMLError as err:
        print("Failed to initialize NVML: ", err)
        sys.exit(1)

    try:
        deviceCount = nvmlDeviceGetCount()
        GPUs = [nvmlDeviceGetHandleByIndex(i) for i in range(deviceCount)]
        temperatures = [nvmlDeviceGetTemperature(gpu, NVML_TEMPERATURE_GPU) for gpu in GPUs]
        memory = [nvmlDeviceGetMemoryInfo(gpu).used for gpu in GPUs]
        fanSpeed = [nvmlDeviceGetFanSpeed(gpu) for gpu in GPUs]
        power = [nvmlDeviceGetPowerUsage(gpu) / 1000 for gpu in GPUs]  # milliwatt -> watt
    finally:
        # Release NVML even if one of the queries above raised.
        try:
            nvmlShutdown()
        except NVMLError as err:
            print("Error shutting down NVML:", err)
            sys.exit(1)

    timestamp = time.strftime("%H:%M:%S")
    return {
        'gpu%i' % i: {
            'temperatures': temperatures[i],
            'fanSpeed': fanSpeed[i],
            'memory': memory[i],
            'power': power[i],
            'time': timestamp,
        }
        for i in range(deviceCount)
    }
49 |
50 |
def saveGPUMetrics(metrics, saving_dir='', name='monitoring_metrics.json', deviceCount=2):
    """Append one GPU metrics sample to a JSON history file.

    The file maps ``'gpu<i>'`` to a dict of metric name -> list of samples.
    If the file does not exist yet it is created with one-element lists.
    Only ``gpu0`` .. ``gpu<deviceCount-1>`` are recorded.

    :param metrics: one sample, as returned by ``getGPUMetrics``; not modified.
    :param saving_dir: directory of the JSON file.
    :param name: JSON file name.
    :param deviceCount: number of GPUs to record.
    :raises KeyError: if ``metrics`` (or an existing file) lacks one of the
        recorded GPUs or metric names.
    """
    json_path = os.path.join(saving_dir, name)
    gpu_keys = ['gpu%i' % i for i in range(deviceCount)]

    if os.path.exists(json_path):
        # Load existing metrics and append the new sample to every series.
        with open(json_path, 'r') as f:
            history = json.load(f)
        for gpu in gpu_keys:
            for key in history[gpu]:
                history[gpu][key].append(metrics[gpu][key])
    else:
        # Build a fresh structure instead of wrapping the caller's dict in place.
        history = {gpu: {key: [value] for key, value in metrics[gpu].items()}
                   for gpu in gpu_keys}

    # Open for writing only once the new content is fully built, so a bad
    # sample cannot leave a truncated JSON file behind.
    with open(json_path, 'w') as f:
        json.dump(history, f, indent=4)
def plotMetrics(json_path, key_name, limit=93):
    """Plot one metric for every GPU stored in a JSON history file.

    :param json_path: file written by ``saveGPUMetrics``.
    :param key_name: metric to plot, e.g. ``'temperatures'`` or ``'power'``.
    :param limit: y value of the dashed red reference line.
    """
    with open(json_path, 'r') as f:
        metrics = json.load(f)

    # Plot every GPU present in the file rather than assuming exactly two.
    n_samples = 0
    for gpu, series in metrics.items():
        values = series[key_name]
        plt.plot(values, label=gpu.replace('gpu', 'GPU ', 1))
        n_samples = max(n_samples, len(values))

    plt.hlines(limit, 0, n_samples, color='red', linestyle='--')
    plt.grid()
    plt.legend()
    plt.title(key_name)
    plt.show()
# Usage:
#   monitor = monitoringGPU(30)  # starts immediately; interval in seconds
#   ... work to be monitored ...
#   monitor.stop()
class monitoringGPU(object):
    """Periodically sample GPU metrics and append them to a JSON file.

    Extra positional and keyword arguments are forwarded unchanged to
    ``saveGPUMetrics``. Sampling is driven by a chain of
    ``threading.Timer`` objects: each tick arms the next timer before
    collecting metrics.
    """

    def __init__(self, interval, *args, **kwargs):
        self._timer = None
        self.interval = interval
        self.args, self.kwargs = args, kwargs
        self.is_running = False
        self.start()

    def _run(self):
        # Arm the next tick first, then collect and store this sample.
        self.is_running = False
        self.start()
        saveGPUMetrics(getGPUMetrics(), *self.args, **self.kwargs)

    def start(self):
        """Arm a timer for the next sample unless one is already pending."""
        if self.is_running:
            return
        self._timer = Timer(self.interval, self._run)
        self._timer.start()
        self.is_running = True

    def stop(self):
        """Cancel the pending timer."""
        self._timer.cancel()
        self.is_running = False
114 |
--------------------------------------------------------------------------------
/src/algorithms/Gerchberg–Saxton.py:
--------------------------------------------------------------------------------
1 | import aotools
2 | from astropy.io import fits
3 | import numpy as np
4 | import utils
5 | from animation import *
6 | from time import time
7 |
def GerchbergSaxton(target, source, phase, n_max=200, animation=True):
    '''
    Phase retrieval, Gerchberg-Saxton algorithm.

    [1] R. W. Gerchberg and W. O. Saxton, "A practical algorithm
    for the determination of the phase from image and diffraction
    plane pictures," Optik 35, 237 (1972)

    [2] J. R. Fienup, "Phase retrieval algorithms: a comparison,"
    Appl. Opt. 21, 2758-2769 (1982)

    :param target: focal-plane intensity (PSF), unpadded; its square root
        is used as the focal-plane modulus constraint
    :param source: pupil-plane amplitude (e.g. the pupil mask), unpadded
    :param phase: Algorithm goal (unpadded), provided for visualization and metrics
    :param n_max: Maximum number of iterations
    :param animation: if True, redraw the progress figure at every iteration
    :return: list of (cumulative algorithm time [s], RMSE) tuples, one per iteration
    '''

    # Add padding
    target = utils.addPadding(np.sqrt(target))
    source = utils.addPadding(source)

    # Metrics: tuple -> (time, error)
    metrics = []

    # Initialize animation
    if animation:
        f, axarr = initAnimation()
        # Loop-invariant, only needed for the residual PSF display.
        padded_phase = utils.addPadding(phase)

    # Timer (only the algorithm steps are timed, not the metrics/plots)
    timer = 0.0

    # The 0.0 factor makes the initial phase flat; the random draw still
    # consumes the RNG.
    A = source * np.exp(1j * 0.0 * np.pi * (np.random.rand(source.shape[0], source.shape[1]) * 2 - 1))

    for n in range(n_max):
        t0 = time()
        B = np.absolute(source) * np.exp(1j * np.angle(A))   # pupil amplitude constraint
        C = utils.fft(B)
        D = np.absolute(target) * np.exp(1j * np.angle(C))   # focal-plane modulus constraint
        A = utils.ifft(D)
        timer += time() - t0

        phaseEst = source * np.angle(A)
        error = utils.rootMeanSquaredError(phase, utils.removePadding(phaseEst), mask=True)
        metrics.append((timer, error))

        if animation:
            # Residual PSF of the padded pupil (`source`) with the phase error.
            H = source * np.exp(1j * (phaseEst - padded_phase))
            h = utils.fft(H)
            psf = utils.removePadding(np.abs(h) ** 2)
            updateAnimation(f, axarr, metrics, phase, utils.removePadding(phaseEst), psf, timer)

    return metrics
68 |
if __name__ == '__main__':

    # Files
    psf_file = '../../dataset/psf_1.fits'

    # Data
    wavelength = 2200 * (10**(-9))  # [m]
    # [0-1] entrance pupil: disc of radius 64 px on a 128 x 128 grid.
    mask = aotools.circle(64, 128)

    # HDU 1 ('PHASE', written by generator.py) holds the in-focus aberration
    # in nm, so the wavelength is passed in nm as well.
    with fits.open(psf_file) as HDU:
        phase = utils.meterToRadian(HDU[1].data, wavelength * (10**(9)))

    # Simulated, noise-free in-focus PSF of the aberrated pupil.
    H = utils.addPadding(mask) * np.exp(1j * utils.addPadding(phase))
    h = utils.fft(H)
    psf_test = utils.removePadding(np.abs(h)**2)

    metrics = GerchbergSaxton(psf_test, mask, phase, n_max=200, animation=True)
93 |
94 |
--------------------------------------------------------------------------------
/src/algorithms/Input-Output.py:
--------------------------------------------------------------------------------
1 | import aotools
2 | from astropy.io import fits
3 | import numpy as np
4 | import utils
5 | from animation import *
6 | from time import time
7 |
def HybridInputOutput(target, source, phase, n_max=200, animation=True):
    '''
    Phase retrieval with a hybrid input-output style iteration.

    [1] E. Osherovich, Numerical methods for phase retrieval, 2012,
        https://arxiv.org/abs/1203.4756
    [2] J. R. Fienup, Phase retrieval algorithms: a comparison, 1982,
        https://www.osapublishing.org/ao/abstract.cfm?uri=ao-21-15-2758

    :param target: focal-plane intensity (PSF), unpadded; its square root
        is used as the focal-plane modulus constraint
    :param source: pupil-plane amplitude / support (e.g. the pupil mask), unpadded
    :param phase: Algorithm goal (unpadded), provided for visualization and metrics
    :param n_max: Maximum number of iterations
    :param animation: if True, redraw the progress figure at every iteration
    :return: list of (cumulative algorithm time [s], RMSE) tuples, one per iteration
    '''

    # Add padding
    target = utils.addPadding(np.sqrt(target))
    source = utils.addPadding(source)

    # Metrics: tuple -> (time, rmse)
    metrics = []

    # Initialize animation
    if animation:
        f, axarr = initAnimation()
        # Loop-invariant, only needed for the residual PSF display.
        padded_phase = utils.addPadding(phase)

    # Timer (only the algorithm steps are timed, not the metrics/plots)
    timer = 0.0

    # The 0.0 factor makes the initial phase flat; the random draw still
    # consumes the RNG.
    g_k_prime = np.exp(1j * 0.0 * np.pi * (np.random.rand(source.shape[0], source.shape[1]) * 2 - 1))

    # Previous iteration
    g_k_previous = None

    for n in range(n_max):
        t0 = time()

        g_k = source * np.exp(1j * np.angle(g_k_prime))
        G_k = utils.fft(g_k)
        G_k_prime = np.absolute(target) * np.exp(1j * np.angle(G_k))
        g_k_prime = utils.ifft(G_k_prime)

        if g_k_previous is None:
            g_k_previous = g_k_prime
        else:
            g_k_previous = g_k

        # NOTE(review): g_k is rebuilt from g_k_prime at the top of the next
        # iteration, so this feedback step only influences phaseEst below and
        # not the iteration itself (which therefore behaves like
        # Gerchberg-Saxton). g_k_previous is also the current g_k rather than
        # the previous iterate. TODO confirm the intended HIO update.
        indices = np.logical_or(np.logical_and(g_k < 0, source), np.logical_not(source))
        g_k[indices] = g_k_previous[indices] - 0.9 * np.real(g_k_prime[indices])

        timer += time() - t0

        phaseEst = source * np.angle(g_k)
        error = utils.rootMeanSquaredError(phase, utils.removePadding(phaseEst), mask=True)
        metrics.append((timer, error))

        if animation:
            # Residual PSF of the padded pupil (`source`) with the phase error.
            H = source * np.exp(1j * (phaseEst - padded_phase))
            h = utils.fft(H)
            psf = utils.removePadding(np.abs(h) ** 2)
            updateAnimation(f, axarr, metrics, phase, utils.removePadding(phaseEst), psf, timer)

    return metrics
80 |
if __name__ == '__main__':

    # Files (same dataset file as the Gerchberg-Saxton script)
    psf_file = '../../dataset/psf_1.fits'

    # Data
    wavelength = 2200 * (10**(-9))  # [m]
    # [0-1] entrance pupil: disc of radius 64 px on a 128 x 128 grid.
    mask = aotools.circle(64, 128)

    # HDU 1 ('PHASE', written by generator.py) holds the in-focus aberration
    # in nm, so the wavelength is passed in nm as well.
    with fits.open(psf_file) as HDU:
        phase = utils.meterToRadian(HDU[1].data, wavelength * (10**(9)))

    # Simulated, noise-free in-focus PSF of the aberrated pupil.
    H = utils.addPadding(mask) * np.exp(1j * utils.addPadding(phase))
    h = utils.fft(H)
    psf_test = utils.removePadding(np.abs(h)**2)

    metrics = HybridInputOutput(psf_test, mask, phase, n_max=300, animation=True)
    metrics = np.array(metrics)
105 |
--------------------------------------------------------------------------------
/src/algorithms/animation.py:
--------------------------------------------------------------------------------
1 | import aotools
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import matplotlib as mpl
5 | import utils
6 |
7 |
def initAnimation():
    """Reset matplotlib to its default style and open the 2x2 progress figure.

    :return: ``(figure, axes)`` where ``axes`` is the 2x2 array expected by
        ``updateAnimation``.
    """
    mpl.style.use('default')
    return plt.subplots(nrows=2, ncols=2, figsize=(10, 10))
12 |
def updateAnimation(f, axarr, error, phase, phaseEst, psf, timer):
    """Redraw the 2x2 progress figure created by ``initAnimation``.

    Panels: error history (top-left), ``psf**(1/3)`` (top-right), exact phase
    (bottom-left) and recovered phase (bottom-right). Pixels outside the
    pupil are left blank in the phase panels. ``phase`` and ``phaseEst``
    are not modified.

    :param f: matplotlib figure
    :param axarr: 2x2 array of axes
    :param error: list of (time, error) tuples; the error column is plotted
    :param phase: exact phase [rad], square array
    :param phaseEst: recovered phase [rad], same shape as ``phase``
    :param psf: point spread function to display
    :param timer: cumulative algorithm time [s] shown in the title
    """
    f.suptitle('Algorithm time: {0:.5}s'.format(timer))
    cmap = plt.cm.jet
    error = np.array(error)

    # Top-left: error history.
    axarr[0, 0].plot(error[:, 1], linewidth=2.5)
    axarr[0, 0].grid(color='lightgrey', linestyle='--')
    axarr[0, 0].set_title("Wavefront error")
    axarr[0, 0].set_xlabel('iterations')
    axarr[0, 0].set_ylabel('RMSE')

    # Top-right: PSF with a cube-root intensity stretch.
    im2 = axarr[0, 1].imshow(psf**(1/3), cmap=cmap)
    cb2 = plt.colorbar(im2, ax=axarr[0, 1], fraction=0.046)
    axarr[0, 1].set_title("Point Spread function (strehl={0:.5f})".format(utils.strehl(phase - phaseEst)))
    axarr[0, 1].set_axis_off()

    # Blank out pixels outside the pupil on copies (NaN pixels are not
    # drawn), so the caller's arrays are left untouched.
    size = phase.shape[0]
    outside = aotools.circle(size // 2, size) < 0.1
    phase_shown = np.where(outside, np.nan, phase)
    phaseEst_shown = np.where(outside, np.nan, phaseEst)

    im3 = axarr[1, 0].imshow(phase_shown, cmap=cmap)
    im3.set_clim(-np.pi, np.pi)
    cb3 = plt.colorbar(im3, ax=axarr[1, 0], fraction=0.046)
    axarr[1, 0].set_title("Exact Phase")
    axarr[1, 0].set_axis_off()

    im4 = axarr[1, 1].imshow(phaseEst_shown, cmap=cmap)
    im4.set_clim(-np.pi, np.pi)
    axarr[1, 1].set_title("Recovered phase")
    axarr[1, 1].set_axis_off()
    cb4 = plt.colorbar(im4, ax=axarr[1, 1], fraction=0.046)

    plt.pause(1e-5)

    # Reset for the next frame. Colorbars are removed first, while their
    # images are still attached; clearing the axes then drops this frame's
    # images so they do not accumulate under the next frame's.
    cb2.remove()
    cb3.remove()
    cb4.remove()
    for ax in axarr.flat:
        ax.cla()
47 |
--------------------------------------------------------------------------------
/src/algorithms/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import numpy.fft as FFT
3 | import aotools
4 |
def meterToRadian(array, wavelength):
    '''
    Convert an optical path difference into a phase.

    Despite the name, ``array`` and ``wavelength`` only need to share a
    length unit; the scripts in this directory pass nanometres.

    :param array: optical path difference, same unit as ``wavelength``
    :param wavelength: wavelength
    :return: phase [rad]
    '''
    cycles = array / wavelength
    return cycles * (2 * np.pi)
15 |
def getPhase(z_coeffs, z_basis):
    '''
    Build a phase map as a weighted sum of Zernike modes.

    :param z_coeffs: 1-D array of mode weights [rad]
    :param z_basis: 3-D array of modes stacked along axis 0 [rad]
    :return: sum over modes, with length-1 axes squeezed out [rad]
    '''
    weighted = z_basis[:, :, :] * z_coeffs[:, np.newaxis, np.newaxis]
    return np.squeeze(weighted.sum(axis=0))
28 |
def fft(array):
    '''
    Centred 2-D discrete Fourier transform.

    The origin is at index ``n // 2`` along each axis in both the input and
    the output. ``ifftshift`` moves that origin to index 0 before the
    transform and ``fftshift`` moves the zero frequency back to the centre
    afterwards. For even sizes this equals applying ``fftshift`` on both
    sides; for odd sizes it avoids a spurious linear phase ramp.

    :param array: 2-D array (centred)
    :return: complex array of the same shape (centred)
    '''
    return FFT.fftshift(FFT.fft2(FFT.ifftshift(array)))


def ifft(array):
    '''
    Centred 2-D inverse discrete Fourier transform; the exact inverse of ``fft``.

    :param array: 2-D array (centred)
    :return: complex array of the same shape (centred)
    '''
    return FFT.fftshift(FFT.ifft2(FFT.ifftshift(array)))
48 |
def pad_with(vector, pad_width, iaxis, kwargs):
    '''
    Fill the padded ends of one 1-D slice with a constant (``np.pad`` callable mode).

    ``np.pad`` calls this for every 1-D slice along each axis; the slice
    already holds the original values in its middle and is modified in place.

    :param vector: 1-D slice of the padded array (modified in place)
    :param pad_width: (before, after) number of padded entries
    :param iaxis: axis currently being padded (unused)
    :param kwargs: keyword arguments forwarded by ``np.pad``; ``padder`` is
        the fill value (default 10)
    :return: ``vector``
    '''
    pad_value = kwargs.get('padder', 10)
    before, after = pad_width
    vector[:before] = pad_value
    # Index from the end explicitly: ``vector[-0:]`` would select the whole
    # slice and overwrite the data when no trailing padding is requested.
    vector[len(vector) - after:] = pad_value
    return vector
63 |
def addPadding(array, padding=2):
    '''
    Zero-pad an array on every side.

    Each side grows by ``padding * array.shape[1]`` pixels, so a square
    ``n x n`` input becomes ``(2*padding + 1) * n`` wide. Uses NumPy's
    constant mode, which also returns the data unchanged for ``padding=0``.

    :param array: array to pad (width taken from ``shape[1]``)
    :param padding: pad width on each side, in multiples of the input width
    :return: padded copy
    '''
    width = padding * array.shape[1]
    return np.pad(array, width, mode='constant', constant_values=0)
75 |
def removePadding(array, padding=2):
    '''
    Crop the central block that ``addPadding`` surrounded with zeros.

    :param array: padded array whose width is ``(2*padding + 1) * n``
    :param padding: same value as passed to ``addPadding``
    :return: central ``n x n`` block (a view, not a copy)
    '''
    n = array.shape[1] // (2 * padding + 1)
    start = padding * n
    window = slice(start, start + n)
    return array[window, window]
87 |
88 |
def rootMeanSquaredError(array1, array2, mask=True):
    '''
    RMSE between array1 and array2.

    With ``mask is True`` the error is restricted to the disc of radius
    ``size // 2`` centred on pixel ``(size // 2, size // 2)``, where ``size``
    is ``array1.shape[1]`` (only the leading ``size x size`` block is used).
    Otherwise every element is used.

    :param array1: reference array
    :param array2: array to compare, same shape
    :param mask: True to restrict the error to the central disc
    :return: root mean squared difference
    '''
    if mask is True:
        size = array1.shape[1]
        center = size // 2
        radius = size // 2
        rows, cols = np.ogrid[:size, :size]
        # Squared distance is compared with the squared radius.
        inside = (rows - center) ** 2 + (cols - center) ** 2 <= radius ** 2
        diff = array1[:size, :size] - array2[:size, :size]
        rms_error = np.sqrt(np.sum(diff[inside] ** 2) / np.count_nonzero(inside))
    else:
        rms_error = np.sqrt(((array1 - array2) ** 2).mean())
    return rms_error
113 |
def strehl(phase):
    '''
    Strehl ratio estimate from a residual phase map.

    Uses the phasor average over the pupil only,
    ``S = |mean_pupil(exp(i * phase))| ** 2``, where the pupil is
    ``aotools.circle(n // 2, n)`` for an ``n x n`` phase. A constant
    (piston) offset does not change the result.

    :param phase: residual phase [rad], square array
    :return: Strehl ratio in [0, 1]
    '''
    size = phase.shape[0]
    pupil = aotools.circle(size // 2, size) >= 0.001
    return np.abs(np.mean(np.exp(1j * phase[pupil]))) ** 2
--------------------------------------------------------------------------------
/src/generation/generator.py:
--------------------------------------------------------------------------------
1 | import time
2 | import aotools
3 | from radial import radial_data
4 | import numpy as np
5 | from scipy import fftpack
6 | from astropy.io import fits
7 | from soapy import SCI, confParse
8 | from matplotlib import pyplot as plt
9 |
# ------------------------------------------------------------------------
# Generate Point Spread Functions from randomly drawn non-common path
# aberrations. The aberration follows a 1/f^2 law.
# One PSF in focus and one PSF out of focus are saved in FITS format
# (see astropy docs), together with the corresponding phase and Zernike
# coefficients. Each output file psf_<i>.fits contains:
#   HDU 0 (primary)  Zernike coefficients [nm] for i_zernike (Noll 2..n_zernike+1)
#   HDU 1 'PHASE'    in-focus aberration [nm]
#   HDU 2 'INFOCUS'  in-focus PSF
#   HDU 3 'OUTFOCUS' PSF with an additional lambda/4 rms defocus
# ------------------------------------------------------------------------

np.random.seed(seed=0)

SOAPY_CONF = "psf.yaml"   # Soapy config
gridsize = 128            # Pixel size of science camera
wavelength = 2.2e-6       # Observational wavelength [m]
diameter = 10             # Telescope diameter [m]
pixelScale = 0.01         # [''/px]

n_psfs = 5                # Number of PSFs
n_zernike = 100           # Number of Zernike polynomials
i_zernike = np.arange(2, n_zernike + 2)  # Zernike polynomial indices (piston excluded)

# Radial order of each Noll index (2, 3 -> 1; 4, 5, 6 -> 2; ...), see
# J. Noll, "Zernike polynomials and atmospheric turbulence", 1975.
o_zernike = []
for i in range(1, n_zernike):
    for j in range(i + 1):
        if len(o_zernike) < n_zernike:
            o_zernike.append(i)

# Draw coefficients uniformly in [-1, 1) and divide each by its radial
# order to produce the expected 1/f^2 law. Each PSF's coefficients are then
# scaled so that their absolute values sum to one wavelength [nm].
c_zernike = 2 * np.random.random((n_psfs, n_zernike)) - 1
for j in range(n_psfs):
    for i in range(n_zernike):
        c_zernike[j, i] = c_zernike[j, i] / o_zernike[i]
c_zernike = np.array([c_zernike[k, :] / np.abs(c_zernike[k, :]).sum()
                      * wavelength * (10**9) for k in range(n_psfs)])

# Update scientific camera parameters
config = confParse.loadSoapyConfig(SOAPY_CONF)
config.scis[0].pxlScale = pixelScale
config.tel.telDiam = diameter
config.calcParams()

mask = aotools.circle(config.sim.pupilSize / 2., config.sim.simSize).astype(np.float64)
zernike_basis = aotools.zernikeArray(n_zernike + 1, config.sim.pupilSize, norm='rms')

psfObj = SCI.PSF(config, nSci=0, mask=mask)

psfs_in = np.zeros((n_psfs, psfObj.detector.shape[0], psfObj.detector.shape[1]))
psfs_out = np.zeros((n_psfs, psfObj.detector.shape[0], psfObj.detector.shape[1]))

# Index 3 of the basis is Noll index 4 (defocus); rms-normalised, so this
# is a lambda/4 rms defocus expressed in nm.
defocus = (wavelength / 4) * (10 ** 9) * zernike_basis[3, :, :]

t0 = time.time()
n_fail = 0

for i in range(n_psfs):

    # zernike_basis[1:] drops piston so the modes line up with i_zernike.
    aberrations_in = np.squeeze(np.sum(c_zernike[i, :, None, None] * zernike_basis[1:, :, :], axis=0))
    psfs_in[i, :, :] = np.copy(psfObj.frame(aberrations_in.astype(np.float64)))

    aberrations_out = aberrations_in + defocus
    psfs_out[i, :, :] = np.copy(psfObj.frame(aberrations_out.astype(np.float64)))

    # Optional photon noise:
    # psfs_in[i, :, :] = np.random.poisson(lam=100000*psfs_in[i, :, :], size=None)
    # psfs_out[i, :, :] = np.random.poisson(lam=100000*psfs_out[i, :, :], size=None)

    # Save
    outfile = "psf_" + str(i) + ".fits"
    hdu_primary = fits.PrimaryHDU(c_zernike[i, :].astype(np.float32))
    hdu_phase = fits.ImageHDU(aberrations_in.astype(np.float32), name='PHASE')
    hdu_In = fits.ImageHDU(psfs_in[i, :, :].astype(np.float32), name='INFOCUS')
    hdu_Out = fits.ImageHDU(psfs_out[i, :, :].astype(np.float32), name='OUTFOCUS')
    hdu = fits.HDUList([hdu_primary, hdu_phase, hdu_In, hdu_Out])
    hdu.writeto(outfile, overwrite=True)

t_soapy = time.time() - t0
print('Propagation and saving finished in {0:.2f}s'.format(t_soapy))
print('Failed: {0:d}'.format(n_fail))
87 |
--------------------------------------------------------------------------------
/src/generation/plots.py:
--------------------------------------------------------------------------------
1 | import time
2 | import aotools
3 | from radial import radial_data
4 | import numpy as np
5 | from scipy import fftpack
6 | from astropy.io import fits
7 | from soapy import SCI, confParse
8 | from matplotlib import pyplot as plt
9 |
# ------------------------------------------------------------------------
# Figures for the generated dataset: PSFs, phase, Zernike coefficient
# distribution and phase power spectral density.
#
# NOTE: this script does not define c_zernike, zernike_basis, i_zernike,
# psfs_in or psfs_out; they come from generator.py. For example, in one
# IPython session run `%run generator.py`, then `%run -i plots.py`.
# ------------------------------------------------------------------------

psf_id = 0
phase = np.squeeze(np.sum(c_zernike[psf_id, :, None, None] * zernike_basis[1:, :, :], axis=0))
F1 = fftpack.fft2(phase)
F2 = fftpack.fftshift(F1)
psd2D = np.abs(F2)**2

plt.imshow(np.sqrt(psfs_in[psf_id, :, :]), cmap=plt.cm.jet)
plt.axis('off')
plt.savefig('psf_in.pdf')
plt.imshow(np.sqrt(psfs_out[psf_id, :, :]), cmap=plt.cm.jet)
plt.axis('off')
plt.savefig('psf_out.pdf')
plt.imshow(phase, cmap=plt.cm.jet)
plt.axis('off')
plt.savefig('phase_in.pdf')

# Zernike coefficient magnitudes, converted from nm to rad at 2200 nm.
fig, ax = plt.subplots(figsize=(15, 5))
width = 0.4
plt.bar(i_zernike[:100], np.abs(c_zernike[psf_id]/2200*2*np.pi)[:100], color='#32526e', width=width, zorder=3)
plt.xlabel('zernike coefficients', fontsize=16)
plt.ylabel('magnitude [rad]', fontsize=16)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
plt.grid(zorder=0, color='lightgray', linestyle='--')
plt.ylim(0, 0.4)
plt.savefig('z_distrib.pdf')
plt.show()

# One label per tick: 9 log-spaced ticks must get exactly 9 labels
# (matplotlib rejects a label/tick count mismatch).
n_ticks = 9
decade_labels = ('10⁰', '', '', '', '10¹', '', '', '', '')

# NOTE(review): the unit string in the PSD y-labels looks garbled; confirm
# the intended PSD unit.
rad_obj = radial_data(psd2D, rmax=64)
fig, ax = plt.subplots()
plt.xlabel('Spatial frequency (cycles/pupil)', fontsize=13)
plt.ylabel('PSD (nm²nm²)', fontsize=13)
plt.loglog(rad_obj.r[1:], psd2D[65:128, 64])
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
plt.grid(zorder=0, color='lightgray', linestyle='--')
start, end = ax.get_xlim()
plt.xticks(np.logspace(np.log10(start), np.log10(end), num=n_ticks, base=10), decade_labels)
plt.savefig('PSD_rad.pdf')
plt.show()

fig, ax = plt.subplots()
plt.xlabel('Spatial frequency (cycles/pupil)', fontsize=13)
plt.ylabel('PSD (nm²nm²)', fontsize=13)
plt.loglog(rad_obj.r[1:], rad_obj.mean[1:])
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
plt.grid(zorder=0, color='lightgray', linestyle='--')
start, end = ax.get_xlim()
plt.xticks(np.logspace(np.log10(start), np.log10(end), num=n_ticks, base=10), decade_labels)
plt.savefig('PSD_rad_avg.pdf')
plt.show()
66 |
--------------------------------------------------------------------------------
/src/generation/psf.yaml:
--------------------------------------------------------------------------------
1 | simName:
2 | pupilSize: 128
3 |
4 | nSci: 1
5 | nIters: 5000
6 | loopTime: 0.0025
7 | threads: 4
8 |
9 | verbosity: 2
10 |
11 |
12 | Atmosphere:
13 | scrnNo: 1
14 | scrnHeights: [0]
15 | scrnStrengths: [1]
16 | windDirs: [0]
17 | windSpeeds: [5]
18 | wholeScrnSize: 2048
19 | r0: 0.1
20 | L0: [100]
21 | infinite: True
22 |
23 | Telescope:
24 | telDiam: 10
25 | obsDiam: 0
26 | mask: circle
27 |
28 | Reconstructor:
29 | type: MVM
30 |
31 |
32 | Science:
33 | 0:
34 | position: [0, 0]
35 | FOV: 10.0
36 | #pxlScale: 0.2
37 | wavelength: 2.2e-6
38 | pxls: 128
39 | fftOversamp: 2
40 | fftwThreads: 0
41 |
42 | fftwFlag: "FFTW_MEASURE"
43 |
44 |
--------------------------------------------------------------------------------
/src/generation/radial.py:
--------------------------------------------------------------------------------
1 | def radial_data(data,annulus_width=1,working_mask=None,x=None,y=None,rmax=None):
2 | """
3 | r = radial_data(data,annulus_width,working_mask,x,y)
4 |
5 | A function to reduce an image to a radial cross-section.
6 |
7 | INPUT:
8 | ------
9 | data - whatever data you are radially averaging. Data is
10 | binned into a series of annuli of width 'annulus_width'
11 | pixels.
12 | annulus_width - width of each annulus. Default is 1.
13 | working_mask - array of same size as 'data', with zeros at
14 | whichever 'data' points you don't want included
15 | in the radial data computations.
16 | x,y - coordinate system in which the data exists (used to set
17 | the center of the data). By default, these are set to
18 | integer meshgrids
19 | rmax -- maximum radial value over which to compute statistics
20 |
21 | OUTPUT:
22 | -------
23 | r - a data structure containing the following
24 | statistics, computed across each annulus:
25 | .r - the radial coordinate used (outer edge of annulus)
26 | .mean - mean of the data in the annulus
27 | .std - standard deviation of the data in the annulus
28 | .median - median value in the annulus
29 | .max - maximum value in the annulus
30 | .min - minimum value in the annulus
31 | .numel - number of elements in the annulus
32 | """
33 |
34 | # 2010-03-10 19:22 IJC: Ported to python from Matlab
35 | # 2005/12/19 Added 'working_region' option (IJC)
36 | # 2005/12/15 Switched order of outputs (IJC)
37 | # 2005/12/12 IJC: Removed decifact, changed name, wrote comments.
38 | # 2005/11/04 by Ian Crossfield at the Jet Propulsion Laboratory
39 |
40 | import numpy as ny
41 |
42 | class radialDat:
43 | """Empty object container.
44 | """
45 | def __init__(self):
46 | self.mean = None
47 | self.std = None
48 | self.median = None
49 | self.numel = None
50 | self.max = None
51 | self.min = None
52 | self.r = None
53 |
54 | #---------------------
55 | # Set up input parameters
56 | #---------------------
57 | data = ny.array(data)
58 |
59 | if working_mask==None:
60 | working_mask = ny.ones(data.shape,bool)
61 |
62 | npix, npiy = data.shape
63 | if x==None or y==None:
64 | x1 = ny.arange(-npix/2.,npix/2.)
65 | y1 = ny.arange(-npiy/2.,npiy/2.)
66 | x,y = ny.meshgrid(y1,x1)
67 |
68 | r = abs(x+1j*y)
69 |
70 | if rmax==None:
71 | rmax = r[working_mask].max()
72 |
73 | #---------------------
74 | # Prepare the data container
75 | #---------------------
76 | dr = ny.abs([x[0,0] - x[0,1]]) * annulus_width
77 | radial = ny.arange(rmax/dr)*dr + dr/2.
78 | nrad = len(radial)
79 | radialdata = radialDat()
80 | radialdata.mean = ny.zeros(nrad)
81 | radialdata.std = ny.zeros(nrad)
82 | radialdata.median = ny.zeros(nrad)
83 | radialdata.numel = ny.zeros(nrad)
84 | radialdata.max = ny.zeros(nrad)
85 | radialdata.min = ny.zeros(nrad)
86 | radialdata.r = radial
87 |
88 | #---------------------
89 | # Loop through the bins
90 | #---------------------
91 | for irad in range(nrad): #= 1:numel(radial)
92 | minrad = irad*dr
93 | maxrad = minrad + dr
94 | thisindex = (r>=minrad) * (r None:
39 | # pylint: disable=invalid-name
40 | self.T_max = T_max
41 | self.eta_min = eta_min
42 | self.factor = factor
43 | self._last_restart: int = 0
44 | self._cycle_counter: int = 0
45 | self._cycle_factor: float = 1.
46 | self._updated_cycle_len: int = T_max
47 | self._initialized: bool = False
48 | super(CosineWithRestarts, self).__init__(optimizer, last_epoch)
49 |
def get_lr(self):
    """Return the learning rate of every param group for the current step.

    Cosine annealing with restarts: inside a cycle of ``self._updated_cycle_len``
    steps each rate follows half a cosine from its base value down towards
    ``self.eta_min``. When a cycle completes, its length is rescaled by
    ``self.factor`` and a new cycle starts from the current step.
    """
    # HACK: the base scheduler calls get_lr() once while it is being
    # constructed. Return the base rates on that first call so training
    # starts at step 0.
    if not self._initialized:
        self._initialized = True
        return self.base_lrs

    step = self.last_epoch + 1
    self._cycle_counter = step - self._last_restart

    # Same value for every param group: cos(pi * position_in_cycle / cycle_len) + 1.
    cosine = np.cos(
        np.pi * (self._cycle_counter % self._updated_cycle_len) / self._updated_cycle_len
    ) + 1
    lrs = [self.eta_min + ((lr - self.eta_min) / 2) * cosine for lr in self.base_lrs]

    if self._cycle_counter % self._updated_cycle_len == 0:
        # Cycle finished: rescale its length for the next one.
        self._cycle_factor *= self.factor
        self._cycle_counter = 0
        # With factor < 1 the product eventually truncates to 0. The modulo
        # above would then raise ZeroDivisionError, so keep at least 1 step.
        self._updated_cycle_len = max(1, int(self._cycle_factor * self.T_max))
        self._last_restart = step

    return lrs
83 |
--------------------------------------------------------------------------------
/src/pytorch/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from astropy.io import fits
3 | from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
4 | from astropy.visualization import SqrtStretch, MinMaxInterval
5 | import numpy as np
6 |
class psf_dataset(Dataset):
    """Dataset of simulated samples stored as ``<root_dir>psf_<index>.fits``.

    Each file provides:
      * HDU 1: the phase map (float32), returned under ``'phase'``;
      * HDUs 2 and 3: two images stacked along a new first axis (float32),
        returned under ``'image'``.
    """

    def __init__(self, root_dir, size, transform=None):
        # root_dir is concatenated directly with the file name, so it must end
        # with a path separator (e.g. 'dataset/').
        self.size = size
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return self.size

    def __getitem__(self, id):
        """Load sample ``id`` as ``{'phase': ..., 'image': ...}``, applying ``transform`` if set.

        Raises:
            ValueError: if ``id`` is outside ``[0, size)``.
        """
        if id < 0 or id >= self.size:
            raise ValueError('[Dataset] Index out of bounds')

        sample_name = self.root_dir + 'psf_' + str(int(id)) + '.fits'
        # Close the FITS file once the data are copied out (np.stack and astype
        # both copy). Leaving it open leaks one file handle per sample.
        with fits.open(sample_name) as sample_hdu:
            image = np.stack((sample_hdu[2].data, sample_hdu[3].data)).astype(np.float32)
            phase = sample_hdu[1].data.astype(np.float32)

        sample = {'phase': phase, 'image': image}

        if self.transform:
            sample = self.transform(sample)

        return sample
36 |
37 |
class Normalize(object):
    """Rescale a sample for training.

    Both images are square-rooted and min-max rescaled to [0, 1]. The phase is
    converted to radians via ``phase / wavelength * 2 * pi``.

    Args:
        wavelength: divisor for the phase. The default of 2200. presumably
            means nanometres, matching ``wavelength: 2.2e-6`` (m) in psf.yaml.
    """

    def __init__(self, wavelength=2200.):
        self.wavelength = wavelength

    def __call__(self, sample):
        phase, image = sample['phase'], sample['image']

        # Writes into the incoming image array in place.
        image[0] = minmax(np.sqrt(image[0]))
        image[1] = minmax(np.sqrt(image[1]))

        phase = (phase / self.wavelength) * 2 * np.pi

        return {'phase': phase, 'image': image}
48 |
49 |
def minmax(array):
    """Linearly rescale *array* so its minimum maps to 0 and its maximum to 1.

    A constant array (zero range) returns all zeros instead of the NaNs a 0/0
    division would produce. The zero-filled result uses the same floating
    dtype the division would have produced.
    """
    array = np.asarray(array)
    a_min = np.min(array)
    a_max = np.max(array)
    span = a_max - a_min
    if span == 0:
        return np.zeros_like(array, dtype=np.result_type(array.dtype, np.float32))
    return (array - a_min) / span
54 |
class ToTensor(object):
    """Turn the numpy ``phase`` and ``image`` arrays of a sample into torch tensors.

    ``torch.from_numpy`` shares memory with the arrays; other keys are dropped.
    """

    def __call__(self, sample):
        return {key: torch.from_numpy(sample[key]) for key in ('phase', 'image')}
60 |
class Noise(object):
    """Replace both images of a sample with Poisson-noise realisations.

    Each image is min-max rescaled to [0, 1], then replaced by a Poisson draw
    with mean ``noise_intensity * rescaled_image``. The brightest pixel
    therefore has an expected count of ``noise_intensity``.

    Args:
        noise_intensity: expected count at the brightest pixel (default 1000).
    """

    def __init__(self, noise_intensity=1000):
        self.noise_intensity = noise_intensity

    def __call__(self, sample):
        phase, image = sample['phase'], sample['image']

        image[0] = minmax(image[0])
        image[1] = minmax(image[1])
        image[0] = np.random.poisson(lam=self.noise_intensity * image[0], size=None)
        image[1] = np.random.poisson(lam=self.noise_intensity * image[1], size=None)

        return {'phase': phase, 'image': image}
72 |
73 |
def splitDataLoader(dataset, split=(0.9, 0.1), batch_size=32, random_seed=None, shuffle=True,
                    num_workers=4):
    """Split *dataset* into a training and a validation ``DataLoader``.

    ``floor(split[1] * len(dataset))`` indices go to validation and all the
    others to training (``split[0]`` is not read). Each loader draws its
    subset in random order through ``SubsetRandomSampler``.

    Args:
        shuffle: permute the indices before splitting. This seeds numpy's
            *global* RNG with *random_seed*, or from OS entropy when it is
            None, so it affects any later ``np.random`` use.
        num_workers: worker processes per loader (default 4).

    Returns:
        (train_dataloader, val_dataloader)
    """
    indices = list(range(len(dataset)))
    val_size = int(np.floor(split[1] * len(dataset)))
    if shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(indices)
    train_indices, val_indices = indices[val_size:], indices[:val_size]

    train_sampler = SubsetRandomSampler(train_indices)
    val_sampler = SubsetRandomSampler(val_indices)

    train_dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
                                  sampler=train_sampler)
    val_dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
                                sampler=val_sampler)

    return train_dataloader, val_dataloader
88 |
--------------------------------------------------------------------------------
/src/pytorch/lr_analyzer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import utils
4 | import json
5 | import logging
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch.optim as optim
10 | import numpy as np
11 | from torchvision import transforms
12 | from dataset import psf_dataset, splitDataLoader, ToTensor, Normalize
13 | from utils_visdom import VisdomWebServer
14 | import aotools
15 |
def lr_analyzer(model, dataset, optimizer, criterion, split=(0.9, 0.1), batch_size=64,
                lr=(1e-5, 1e-1), n_steps=100):
    """Run a learning-rate range test.

    Trains *model* on at most *n_steps* mini-batches of the training split.
    The learning rate of every param group rises geometrically from ``lr[0]``
    to ``lr[1]``, one value per batch, and each batch's loss is recorded.

    Returns:
        (losses, lrs): lists of equal length. ``losses[i]`` is the loss of the
        batch trained with learning rate ``lrs[i]``. The lists are shorter than
        *n_steps* if the training split runs out of batches first.
    """
    lr_schedule = np.geomspace(lr[0], lr[1], n_steps)

    # Only the training split is used.
    train_loader, _ = splitDataLoader(dataset, split=split, batch_size=batch_size)

    # NOTE(review): inputs are moved to `device` but the model is not; the
    # caller presumably moves the model beforehand — confirm.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model.train()
    losses = []
    lrs = []

    # zip stops at whichever runs out first: the schedule or the data. Each
    # rate is used exactly once. The original repeated lr[0] and took 101 steps.
    for lr_value, sample in zip(lr_schedule, train_loader):
        for p in optimizer.param_groups:
            p['lr'] = lr_value

        inputs = sample['image'].to(device)
        phase_0 = sample['phase'].to(device)

        optimizer.zero_grad()
        with torch.set_grad_enabled(True):
            phase_estimation = model(inputs)
            loss = criterion(torch.squeeze(phase_estimation), phase_0)
            loss.backward()
            optimizer.step()

        losses.append(loss.item())
        lrs.append(get_lr(optimizer))

    return losses, lrs
65 |
def get_lr(optimizer):
    """Return the learning rate stored in the optimizer's last parameter group."""
    last_group = optimizer.param_groups[-1]
    return last_group['lr']
70 |
--------------------------------------------------------------------------------
/src/pytorch/models/Densenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import aotools
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torchvision.models as models
7 |
class Net(nn.Module):
    """DenseNet-161 regressor of 20 Zernike coefficients.

    A learned 1x1 conv maps the 2 input channels to 3, and an adaptive
    max-pool resizes to 224x224 before the pretrained stem. The classifier is
    replaced by a 2208 -> 20 linear layer. ``forward`` returns
    ``(phase, coefficients)``.
    """

    def __init__(self):
        super(Net, self).__init__()

        backbone = models.densenet161(pretrained=True)
        for weight in backbone.parameters():
            weight.requires_grad = True

        features = backbone.features
        features.conv0 = nn.Sequential(
            nn.Conv2d(2, 3, kernel_size=1, stride=1, bias=True),
            nn.AdaptiveMaxPool2d(224),
            features.conv0,
        )
        backbone.classifier = nn.Sequential(nn.Linear(2208, 20, bias=True))

        self.densenet = backbone
        self.phase2dlayer = Phase2DLayer(20, 128)

    def forward(self, x):
        coefficients = self.densenet(x)
        return self.phase2dlayer(coefficients), coefficients
42 |
class Phase2D(torch.autograd.Function):
    """Linear map from basis coefficients to a 2-D phase map.

    ``forward(input, z_basis)`` returns ``sum_k input[:, k] * z_basis[k + 1]``.
    ``input`` is ``(batch, n)``, ``z_basis`` is ``(n + 1, H, W)`` and the
    result is ``(batch, H, W)``. ``z_basis[0]`` is skipped (presumably the
    piston term of the aotools basis).
    """

    @staticmethod
    def forward(ctx, input, z_basis):
        # Follow the input's device and dtype instead of a hard-coded device,
        # so the layer works on CPU and GPU alike.
        ctx.z_basis = z_basis.to(device=input.device, dtype=input.dtype)
        output = input[:, :, None, None] * ctx.z_basis[None, 1:, :, :]
        return torch.sum(output, dim=1)

    @staticmethod
    def backward(ctx, grad_output):
        # d(output)/d(input[:, k]) = z_basis[k + 1], so each coefficient's
        # gradient is grad_output projected onto the matching basis map.
        dL_dy = grad_output.unsqueeze(1)
        dy_dz = ctx.z_basis[1:, :, :].unsqueeze(0)
        grad_input = torch.sum(dL_dy * dy_dz, dim=(2, 3))
        return grad_input, None
57 |
class Phase2DLayer(nn.Module):
    """Module wrapper applying ``Phase2D`` with a fixed Zernike basis.

    The basis (``input_features + 1`` maps, ``output_features`` pixels per side,
    RMS-normalised, from ``aotools.zernikeArray``) is built once here.
    NOTE(review): it is a plain attribute, not a registered buffer, so
    ``module.to(device)`` does not move it.
    """

    def __init__(self, input_features, output_features):
        super(Phase2DLayer, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        zernike_maps = aotools.zernikeArray(input_features + 1, output_features, norm='rms')
        self.z_basis = torch.as_tensor(zernike_maps, dtype=torch.float32)

    def forward(self, input):
        return Phase2D.apply(input, self.z_basis)
68 |
69 |
class BasicConv2d(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d (eps=1e-3) -> in-place ReLU.

    Extra keyword arguments (kernel_size, stride, padding, ...) go to Conv2d.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)), inplace=True)
--------------------------------------------------------------------------------
/src/pytorch/models/InceptionV3.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import aotools
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torchvision.models as models
7 |
class Net(nn.Module):
    """Inception-v3 regressor of 20 Zernike coefficients.

    A learned 1x1 conv maps the 2 input channels to 3, and an adaptive
    max-pool resizes to 299x299 before the pretrained stem. ``fc`` is replaced
    by a 2048 -> 20 linear layer. ``forward`` returns ``(phase, coefficients)``.
    """

    def __init__(self):
        super(Net, self).__init__()

        backbone = models.inception_v3(pretrained=True, transform_input=False)
        for weight in backbone.parameters():
            weight.requires_grad = True

        backbone.Conv2d_1a_3x3 = nn.Sequential(
            nn.Conv2d(2, 3, kernel_size=1, stride=1, bias=True),
            nn.AdaptiveMaxPool2d(299),
            backbone.Conv2d_1a_3x3,
        )
        backbone.fc = nn.Sequential(nn.Linear(2048, 20))

        self.inception = backbone
        self.phase2dlayer = Phase2DLayer(20, 128)

    def forward(self, x):
        z = self.inception(x)
        if self.inception.training:
            # In training mode the backbone returns (main, auxiliary) outputs.
            z, _ = z
        return self.phase2dlayer(z), z
44 |
class Phase2D(torch.autograd.Function):
    """Linear map from basis coefficients to a 2-D phase map.

    ``forward(input, z_basis)`` returns ``sum_k input[:, k] * z_basis[k + 1]``.
    ``input`` is ``(batch, n)``, ``z_basis`` is ``(n + 1, H, W)`` and the
    result is ``(batch, H, W)``. ``z_basis[0]`` is skipped (presumably the
    piston term of the aotools basis).
    """

    @staticmethod
    def forward(ctx, input, z_basis):
        # Follow the input's device and dtype instead of a hard-coded device,
        # so the layer works on CPU and GPU alike.
        ctx.z_basis = z_basis.to(device=input.device, dtype=input.dtype)
        output = input[:, :, None, None] * ctx.z_basis[None, 1:, :, :]
        return torch.sum(output, dim=1)

    @staticmethod
    def backward(ctx, grad_output):
        # d(output)/d(input[:, k]) = z_basis[k + 1], so each coefficient's
        # gradient is grad_output projected onto the matching basis map.
        dL_dy = grad_output.unsqueeze(1)
        dy_dz = ctx.z_basis[1:, :, :].unsqueeze(0)
        grad_input = torch.sum(dL_dy * dy_dz, dim=(2, 3))
        return grad_input, None
59 |
class Phase2DLayer(nn.Module):
    """Module wrapper applying ``Phase2D`` with a fixed Zernike basis.

    The basis (``input_features + 1`` maps, ``output_features`` pixels per side,
    RMS-normalised, from ``aotools.zernikeArray``) is built once here.
    NOTE(review): it is a plain attribute, not a registered buffer, so
    ``module.to(device)`` does not move it.
    """

    def __init__(self, input_features, output_features):
        super(Phase2DLayer, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        zernike_maps = aotools.zernikeArray(input_features + 1, output_features, norm='rms')
        self.z_basis = torch.as_tensor(zernike_maps, dtype=torch.float32)

    def forward(self, input):
        return Phase2D.apply(input, self.z_basis)
70 |
71 |
class BasicConv2d(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d (eps=1e-3) -> in-place ReLU.

    Extra keyword arguments (kernel_size, stride, padding, ...) go to Conv2d.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)), inplace=True)
--------------------------------------------------------------------------------
/src/pytorch/models/Resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import aotools
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torchvision.models as models
7 |
class Net(nn.Module):
    """ResNet-50 regressor of 100 Zernike coefficients, returning only the phase map.

    A learned 1x1 conv maps the 2 input channels to 3, and an adaptive
    max-pool resizes to 224x224 before the pretrained stem. ``fc`` is replaced
    by a 2048 -> 100 linear layer.
    """

    def __init__(self):
        super(Net, self).__init__()

        backbone = models.resnet50(pretrained=True)
        for weight in backbone.parameters():
            weight.requires_grad = True

        # Input size 2x128x128 -> 3x224x224 before the original first conv.
        backbone.conv1 = nn.Sequential(
            nn.Conv2d(2, 3, kernel_size=1, stride=1, bias=True),
            nn.AdaptiveMaxPool2d(224),
            backbone.conv1,
        )
        backbone.fc = nn.Sequential(nn.Linear(2048, 100))

        self.resnet = backbone
        self.phase2dlayer = Phase2DLayer(100, 128)

    def forward(self, x):
        return self.phase2dlayer(self.resnet(x))
42 |
class Phase2D(torch.autograd.Function):
    """Linear map from basis coefficients to a 2-D phase map.

    ``forward(input, z_basis)`` returns ``sum_k input[:, k] * z_basis[k + 1]``.
    ``input`` is ``(batch, n)``, ``z_basis`` is ``(n + 1, H, W)`` and the
    result is ``(batch, H, W)``. ``z_basis[0]`` is skipped (presumably the
    piston term of the aotools basis).
    """

    @staticmethod
    def forward(ctx, input, z_basis):
        # Follow the input's device and dtype. The previous hard-coded
        # `.cuda()` crashed on machines without a GPU.
        ctx.z_basis = z_basis.to(device=input.device, dtype=input.dtype)
        output = input[:, :, None, None] * ctx.z_basis[None, 1:, :, :]
        return torch.sum(output, dim=1)

    @staticmethod
    def backward(ctx, grad_output):
        # d(output)/d(input[:, k]) = z_basis[k + 1], so each coefficient's
        # gradient is grad_output projected onto the matching basis map.
        dL_dy = grad_output.unsqueeze(1)
        dy_dz = ctx.z_basis[1:, :, :].unsqueeze(0)
        grad_input = torch.sum(dL_dy * dy_dz, dim=(2, 3))
        return grad_input, None
57 |
class Phase2DLayer(nn.Module):
    """Module wrapper applying ``Phase2D`` with a fixed Zernike basis.

    The basis (``input_features + 1`` maps, ``output_features`` pixels per side,
    RMS-normalised, from ``aotools.zernikeArray``) is built once here.
    NOTE(review): it is a plain attribute, not a registered buffer, so
    ``module.to(device)`` does not move it.
    """

    def __init__(self, input_features, output_features):
        super(Phase2DLayer, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        zernike_maps = aotools.zernikeArray(input_features + 1, output_features, norm='rms')
        self.z_basis = torch.as_tensor(zernike_maps, dtype=torch.float32)

    def forward(self, input):
        return Phase2D.apply(input, self.z_basis)
68 |
--------------------------------------------------------------------------------
/src/pytorch/models/Unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
class UNet(nn.Module):
    """Four-level U-Net with bilinear upsampling.

    Encoder widths are 64-128-256-512-512. Each decoder stage upsamples,
    concatenates the matching encoder feature map and convolves. Maps
    ``n_channels_in`` input channels to ``n_channels_out`` output channels.
    """

    def __init__(self, n_channels_in, n_channels_out):
        super(UNet, self).__init__()
        self.inc = inconv(n_channels_in, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        # in_ch of each up stage = upsampled channels + skip channels.
        self.up1 = up(1024, 256, bilinear=True)
        self.up2 = up(512, 128, bilinear=True)
        self.up3 = up(256, 64, bilinear=True)
        self.up4 = up(128, 64, bilinear=True)
        self.outc = outconv(64, n_channels_out)

    def forward(self, x):
        # Encoder: keep every stage output for the skip connections.
        features = [self.inc(x)]
        for encode in (self.down1, self.down2, self.down3, self.down4):
            features.append(encode(features[-1]))

        # Decoder: start from the deepest map and consume the skips in reverse.
        out = features.pop()
        for decode, skip in zip((self.up1, self.up2, self.up3, self.up4), reversed(features)):
            out = decode(out, skip)
        return self.outc(out)
31 |
class double_conv(nn.Module):
    '''(conv => BN => ReLU) * 2'''

    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        layers = []
        # First conv maps in_ch -> out_ch, second keeps out_ch.
        for channels_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(channels_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
48 |
class inconv(nn.Module):
    """Input stage of the U-Net: a single ``double_conv`` block."""

    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        return self.conv(x)
57 |
class down(nn.Module):
    """Encoder stage: 2x2 max-pool (halves height and width), then ``double_conv``."""

    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(nn.MaxPool2d(2), double_conv(in_ch, out_ch))

    def forward(self, x):
        return self.mpconv(x)
69 |
class up(nn.Module):
    """Decoder stage: upsample x1 by 2, pad it to x2's size, concatenate, ``double_conv``.

    ``in_ch`` is the channel count after concatenation (skip + upsampled).
    """

    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()
        # Bilinear upsampling has no weights. The transposed-conv alternative
        # learns the upsampling but needs extra parameters and memory.
        self.up = (
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            if bilinear
            else nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)
        )
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        # Dims 2 and 3 are height and width: pad the upsampled map as evenly
        # as possible so it matches the skip connection x2.
        diff_h = x2.size()[2] - x1.size()[2]
        diff_w = x2.size()[3] - x1.size()[3]
        padding = (diff_w // 2, diff_w - diff_w // 2,
                   diff_h // 2, diff_h - diff_h // 2)
        x1 = F.pad(x1, padding)

        return self.conv(torch.cat([x2, x1], dim=1))
96 |
class outconv(nn.Module):
    """Output head: a single 3x3 convolution (padding 1) to the output channels."""

    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)
105 |
106 |
--------------------------------------------------------------------------------
/src/pytorch/models/Unet_PP.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 |
4 | from torch import nn
5 | from torch.nn import functional as F
6 | import torch
7 | from torchvision import models
8 | import torchvision
9 |
10 |
class VGGBlock(nn.Module):
    """Two (3x3 conv -> BatchNorm -> activation) layers.

    Args:
        in_channels: channels of the input map.
        middle_channels: channels after the first conv.
        out_channels: channels of the output map.
        act_func: activation module. Defaults to a fresh ``nn.ReLU(inplace=True)``
            per block. The old default expression was evaluated once at
            definition time, so one module instance was shared by every block.
    """

    def __init__(self, in_channels, middle_channels, out_channels, act_func=None):
        super(VGGBlock, self).__init__()
        if act_func is None:
            act_func = nn.ReLU(inplace=True)
        self.act_func = act_func
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.act_func(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.act_func(out)

        return out
30 |
31 |
class UNet(nn.Module):
    """Plain U-Net built from ``VGGBlock`` stages (widths 32-64-128-256-512).

    ``args.input_channels`` sets the number of input channels. The output has
    one channel at the input resolution.
    """

    def __init__(self, args):
        super().__init__()

        self.args = args

        nb_filter = [32, 64, 128, 256, 512]

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Encoder column (j = 0).
        self.conv0_0 = VGGBlock(args.input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        # Decoder: each stage sees skip channels + upsampled channels.
        self.conv3_1 = VGGBlock(nb_filter[3] + nb_filter[4], nb_filter[3], nb_filter[3])
        self.conv2_2 = VGGBlock(nb_filter[2] + nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv1_3 = VGGBlock(nb_filter[1] + nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv0_4 = VGGBlock(nb_filter[0] + nb_filter[1], nb_filter[0], nb_filter[0])

        self.final = nn.Conv2d(nb_filter[0], 1, kernel_size=1)

    def forward(self, input):
        encoded = [self.conv0_0(input)]
        for stage in (self.conv1_0, self.conv2_0, self.conv3_0, self.conv4_0):
            encoded.append(stage(self.pool(encoded[-1])))

        out = encoded.pop()
        decoders = (self.conv3_1, self.conv2_2, self.conv1_3, self.conv0_4)
        for stage, skip in zip(decoders, reversed(encoded)):
            out = stage(torch.cat([skip, self.up(out)], 1))

        return self.final(out)
71 |
72 |
class NestedUNet(nn.Module):
    """UNet++ (nested U-Net) with dense skip pathways, without deep supervision.

    Node ``conv{i}_{j}`` works at depth ``i`` (``i`` poolings) and column ``j``.
    It consumes every earlier node at its depth plus the upsampled
    ``conv{i+1}_{j-1}``. Only ``conv0_4`` feeds the 1x1 output head.
    NOTE(review): presumably H and W must be divisible by 16 for the
    concatenations to line up; confirm.

    Args:
        input_channels: channels of the input tensor (default 2).
        output_channels: channels of the output map (default 1).
    """

    def __init__(self, input_channels=2, output_channels=1):
        super().__init__()

        nb_filter = [32, 64, 96, 128, 256]

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])

        self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])

        self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])

        self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])

        self.final = nn.Conv2d(nb_filter[0], output_channels, kernel_size=1)

    def forward(self, input):
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))

        return self.final(x0_4)
144 |
--------------------------------------------------------------------------------
/src/pytorch/models/VGG.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import aotools
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
class Net(nn.Module):
    """VGG-style CNN regressing 20 Zernike coefficients and their phase map.

    Five stages of two ``BasicConv2d`` layers plus 2x2 max-pooling take a
    ``(N, 2, 128, 128)`` input to ``(N, 512, 4, 4)``. Three linear layers then
    predict 20 coefficients, which ``Phase2DLayer`` expands into a 128x128
    phase map. ``forward`` returns ``(phase, z_coeffs)``.
    """

    def __init__(self):
        super(Net, self).__init__()

        self.conv_a1 = BasicConv2d(2, 32, kernel_size=3, stride=1, padding=1)
        self.conv_a2 = BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)

        self.conv_b1 = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv_b2 = BasicConv2d(64, 64, kernel_size=3, stride=1, padding=1)

        self.conv_c1 = BasicConv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv_c2 = BasicConv2d(128, 128, kernel_size=3, stride=1, padding=1)

        self.conv_d1 = BasicConv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv_d2 = BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1)

        self.conv_e1 = BasicConv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.conv_e2 = BasicConv2d(512, 512, kernel_size=3, stride=1, padding=1)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        self.fc1 = torch.nn.Linear(4*4*512, 1024)
        self.fc2 = torch.nn.Linear(1024, 1024)
        self.fc3 = torch.nn.Linear(1024, 20)

        self.phase2dlayer = Phase2DLayer(20, 128)

    def forward(self, x):
        # Shapes after each stage (channels x H x W) for a 2x128x128 input.
        stages = (
            (self.conv_a1, self.conv_a2),  # -> 32 x 64 x 64
            (self.conv_b1, self.conv_b2),  # -> 64 x 32 x 32
            (self.conv_c1, self.conv_c2),  # -> 128 x 16 x 16
            (self.conv_d1, self.conv_d2),  # -> 256 x 8 x 8
            (self.conv_e1, self.conv_e2),  # -> 512 x 4 x 4
        )
        for first, second in stages:
            x = self.pool(second(first(x)))

        # Flatten per sample. Keeping the batch dimension makes a wrongly sized
        # input fail in fc1 instead of silently regrouping rows across samples.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        z_coeffs = self.fc3(x)
        phase = self.phase2dlayer(z_coeffs)
        return phase, z_coeffs
60 |
61 |
class Phase2D(torch.autograd.Function):
    """Linear map from basis coefficients to a 2-D phase map.

    ``forward(input, z_basis)`` returns ``sum_k input[:, k] * z_basis[k + 1]``.
    ``input`` is ``(batch, n)``, ``z_basis`` is ``(n + 1, H, W)`` and the
    result is ``(batch, H, W)``. ``z_basis[0]`` is skipped (presumably the
    piston term of the aotools basis).
    """

    @staticmethod
    def forward(ctx, input, z_basis):
        # Follow the input's device and dtype instead of a hard-coded device,
        # so the layer works on CPU and GPU alike.
        ctx.z_basis = z_basis.to(device=input.device, dtype=input.dtype)
        output = input[:, :, None, None] * ctx.z_basis[None, 1:, :, :]
        return torch.sum(output, dim=1)

    @staticmethod
    def backward(ctx, grad_output):
        # d(output)/d(input[:, k]) = z_basis[k + 1], so each coefficient's
        # gradient is grad_output projected onto the matching basis map.
        dL_dy = grad_output.unsqueeze(1)
        dy_dz = ctx.z_basis[1:, :, :].unsqueeze(0)
        grad_input = torch.sum(dL_dy * dy_dz, dim=(2, 3))
        return grad_input, None
76 |
class Phase2DLayer(nn.Module):
    """Module wrapper applying ``Phase2D`` with a fixed Zernike basis.

    The basis (``input_features + 1`` maps, ``output_features`` pixels per side,
    RMS-normalised, from ``aotools.zernikeArray``) is built once here.
    NOTE(review): it is a plain attribute, not a registered buffer, so
    ``module.to(device)`` does not move it.
    """

    def __init__(self, input_features, output_features):
        super(Phase2DLayer, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        zernike_maps = aotools.zernikeArray(input_features + 1, output_features, norm='rms')
        self.z_basis = torch.as_tensor(zernike_maps, dtype=torch.float32)

    def forward(self, input):
        return Phase2D.apply(input, self.z_basis)
87 |
88 |
class BasicConv2d(nn.Module):
    """Conv2d (with bias) -> BatchNorm2d (eps=1e-3) -> in-place ReLU.

    Extra keyword arguments (kernel_size, stride, padding, ...) go to Conv2d.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=True, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)), inplace=True)
--------------------------------------------------------------------------------
/src/pytorch/models/__pycache__/Unet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/povanberg/Machine-learning-for-image-based-wavefront-sensing/a687f422d822a7c1db76375a0a4a67a60b03721d/src/pytorch/models/__pycache__/Unet.cpython-36.pyc
--------------------------------------------------------------------------------
/src/pytorch/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import utils
4 | import json
5 | import logging
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch.optim as optim
10 | from torchvision import transforms
11 | from dataset import psf_dataset, splitDataLoader, ToTensor, Normalize
12 | from utils_visdom import VisdomWebServer
13 | import aotools
14 | from criterion import *
15 |
def _run_phase(model, loader, optimizer, criterion, device, train):
    """
    Run one pass of `model` over `loader` and return the mean loss per sample.

    Args:
        model: (nn.Module) network to evaluate (and update when `train`)
        loader: iterable of batches, each a dict with 'image' and 'phase' tensors
        optimizer: (torch.optim.Optimizer) stepped once per batch when `train`
        criterion: loss called as criterion(prediction, target)
        device: (torch.device) where inputs and targets are moved
        train: (bool) True for a training pass, False for evaluation

    Returns:
        (float) sum of batch losses weighted by batch size, divided by the
        number of samples actually seen; nan if `loader` yields no batch
    """
    model.train(train)  # train(False) is the same as eval()

    total_loss = 0.0
    n_samples = 0
    for sample in loader:
        inputs = sample['image'].to(device)
        targets = sample['phase'].to(device)
        batch = inputs.size(0)

        if train:
            optimizer.zero_grad()

        # Track history only while training.
        with torch.set_grad_enabled(train):
            # NOTE(review): torch.squeeze drops *every* size-1 dim, including
            # the batch dim of a 1-sample batch; kept for compatibility with
            # the existing models/criteria.
            prediction = torch.squeeze(model(inputs))
            loss = criterion(prediction, targets)

            if train:
                loss.backward()
                optimizer.step()

        # Weight by batch size so a short last batch does not skew the mean
        # (assumes criterion returns a per-batch mean -- TODO confirm).
        total_loss += loss.item() * batch
        n_samples += batch

    return total_loss / n_samples if n_samples else float('nan')


def train(model, dataset, optimizer, criterion, split=None, batch_size=32,
          n_epochs=1, model_dir='./', random_seed=None, visdom=False):
    """
    Train `model` on `dataset`, validating after every epoch.

    Each epoch runs a training pass and a validation pass over the two loaders
    returned by `splitDataLoader`. The learning rate follows a StepLR schedule
    (multiplied by 0.1 every 150 epochs). After each validation pass the
    metrics are written to `model_dir/metrics.json` (and pushed to Visdom when
    enabled); the weights are saved to `model_dir/model.pth` after the first
    epoch and whenever the validation loss improves.

    Args:
        model: (nn.Module) network applied to batches of `sample['image']`;
            moved to the first GPU when one is available
        dataset: dataset whose loaders yield dicts with 'image' and 'phase'
            tensors
        optimizer: (torch.optim.Optimizer) optimizer over `model`'s parameters
        criterion: loss called as criterion(prediction, target)
        split: (sequence of 2 floats) train/validation fractions,
            [0.9, 0.1] when None
        batch_size: (int) mini-batch size
        n_epochs: (int) number of epochs
        model_dir: (string) output directory for logs, metrics and weights
        random_seed: seed forwarded to `splitDataLoader`
        visdom: (bool) push metrics to a Visdom server after each epoch
    """
    if split is None:
        split = [0.9, 0.1]

    # Create directory if doesn't exist
    os.makedirs(model_dir, exist_ok=True)

    # Logging
    utils.set_logger(os.path.join(model_dir, 'logs.log'))

    # Visdom support
    vis = VisdomWebServer() if visdom else None

    # Dataset
    dataloaders = {}
    dataloaders['train'], dataloaders['val'] = splitDataLoader(
        dataset, split=split, batch_size=batch_size, random_seed=random_seed)

    # Learning-rate schedule, stepped once per epoch after validation.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=150, gamma=0.1)

    # Metrics
    metrics_path = os.path.join(model_dir, 'metrics.json')
    metrics = {
        'model': model_dir,
        'optimizer': optimizer.__class__.__name__,
        'criterion': criterion.__class__.__name__,
        'scheduler': scheduler.__class__.__name__,
        'dataset_size': int(len(dataset)),
        'train_size': int(split[0] * len(dataset)),
        'test_size': int(split[1] * len(dataset)),
        'n_epoch': n_epochs,
        'batch_size': batch_size,
        'learning_rate': [],
        'train_loss': [],
        'val_loss': [],
        # Read by plot_learningcurve(zernike=True); currently left empty.
        'zernike_train_loss': [],
        'zernike_val_loss': [],
    }

    # Device: the inputs are moved to `device`, so the weights must be too.
    # Module.to() keeps the same Parameter objects, so `optimizer` still
    # tracks them.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Training
    since = time.time()
    best_loss = float('inf')

    for epoch in range(n_epochs):
        logging.info('-' * 30)
        epoch_time = time.time()

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            epoch_loss = _run_phase(model, dataloaders[phase], optimizer,
                                    criterion, device, train=(phase == 'train'))
            logging.info('[%i/%i] %s loss: %f' % (epoch + 1, n_epochs, phase, epoch_loss))

            # Update metrics
            metrics[phase + '_loss'].append(epoch_loss)
            if phase == 'train':
                metrics['learning_rate'].append(get_lr(optimizer))
                continue

            # Validation: adapt the learning rate, checkpoint, persist metrics.
            scheduler.step()
            # Always save after the first epoch, then only on improvement.
            if epoch == 0 or epoch_loss < best_loss:
                best_loss = epoch_loss
                torch.save(model.state_dict(), os.path.join(model_dir, 'model.pth'))
            with open(metrics_path, 'w') as f:
                json.dump(metrics, f, indent=4)
            if vis is not None:
                vis.update(metrics)

        logging.info('[%i/%i] Time: %f s' % (epoch + 1, n_epochs, time.time() - epoch_time))

    time_elapsed = time.time() - since
    logging.info('[-----] All epochs completed in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
141 |
142 |
143 |
def get_lr(optimizer):
    """
    Return the current learning rate of `optimizer`.

    When there are several parameter groups, the last group's rate is
    returned.

    Args:
        optimizer: object with a `param_groups` list of dicts holding an 'lr'
            key (e.g. a torch.optim.Optimizer)

    Raises:
        ValueError: if `optimizer` has no parameter groups
    """
    param_groups = optimizer.param_groups
    if not param_groups:
        # Fail with a clear message instead of an unbound-variable error.
        raise ValueError('optimizer has no parameter groups')
    return param_groups[-1]['lr']
148 |
--------------------------------------------------------------------------------
/src/pytorch/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import json
3 | import os
4 | import matplotlib.pyplot as plt
5 |
6 | def set_logger(log_path):
7 | """
8 | Set the logger to log info in terminal and file `log_path`.
9 |
10 | Args:
11 | log_path: (string) where to log
12 | """
13 | logger = logging.getLogger()
14 | logger.setLevel(logging.INFO)
15 |
16 | if not logger.handlers:
17 | # Logging to a file
18 | file_handler = logging.FileHandler(log_path)
19 | file_handler.setFormatter(logging.Formatter('[%(asctime)s][%(levelname)s] %(message)s'))
20 | logger.addHandler(file_handler)
21 |
22 | # Logging to console
23 | stream_handler = logging.StreamHandler()
24 | stream_handler.setFormatter(logging.Formatter('%(message)s'))
25 | logger.addHandler(stream_handler)
26 |
27 |
28 |
29 |
class Params:
    """
    Hyperparameters stored in a JSON file and exposed as attributes.

    Example:
        params = Params(json_path)
        print(params.learning_rate)
        params.learning_rate = 0.5  # change the value of learning_rate in params

    If `json_path` does not exist, an empty JSON object is written there
    first, so the new instance starts without attributes.
    """

    def __init__(self, json_path):
        if not os.path.exists(json_path):
            # Seed a missing file with an empty object so loading succeeds.
            with open(json_path, 'w') as f:
                json.dump({}, f, indent=4)
        self.__dict__.update(self._load_json(json_path))

    @staticmethod
    def _load_json(json_path):
        """Return the decoded contents of the JSON file at `json_path`."""
        with open(json_path) as f:
            return json.load(f)

    def save(self, json_path):
        """Write every current attribute to `json_path` as indented JSON."""
        with open(json_path, 'w') as f:
            json.dump(self.__dict__, f, indent=4)

    def update(self, json_path):
        """Merge the parameters stored in `json_path` into this instance."""
        self.__dict__.update(self._load_json(json_path))

    def hasKey(self, json_path, key_name):
        """Return True if `key_name` is a key of the JSON in `json_path`.

        The file is re-read; the instance's own attributes are not consulted.
        """
        return key_name in self._load_json(json_path)

    @property
    def dict(self):
        """Gives dict-like access to Params instance by `params.dict['learning_rate']"""
        return self.__dict__
73 |
74 |
def plot_learningcurve(metrics, save=True, show=True, name='lrcurve.pdf',
                       xlim=(None, None), ylim=(None, None), zernike=False):
    """
    Plot the training and validation loss curves, one point per epoch.

    Args:
        metrics: (dict) with 'train_loss' and 'val_loss' lists, or
            'zernike_train_loss' and 'zernike_val_loss' when `zernike` is True
        save: (bool) save the figure to `name`
        show: (bool) call plt.show() after drawing
        name: (string) output file used when `save` is True
        xlim: (pair) x-axis limits; a None entry leaves that bound unchanged
        ylim: (pair) y-axis limits; a None entry leaves that bound unchanged
        zernike: (bool) plot the Zernike-coefficient losses instead

    Returns:
        the matplotlib Figure that was drawn
    """
    # Defaults are tuples (immutable) but index exactly like the lists
    # callers may still pass.
    prefix = 'zernike_' if zernike else ''

    fig = plt.figure()
    plt.plot(metrics[prefix + 'train_loss'], label='Training loss', color='blue')
    plt.plot(metrics[prefix + 'val_loss'], label='Validation loss', color='red')
    plt.legend()
    plt.grid()
    plt.xlim(xlim[0], xlim[1])
    plt.ylim(ylim[0], ylim[1])
    plt.xlabel('epochs')
    plt.ylabel('loss')
    if save:
        plt.savefig(name)
    if show:
        plt.show()
    return fig
91 |
92 |
def get_metrics(model_dir=''):
    """
    Load the metrics dictionary stored in `model_dir/metrics.json`.

    Args:
        model_dir: (string) directory containing `metrics.json`

    Returns:
        (dict) the decoded JSON content

    Raises:
        FileNotFoundError: if `model_dir/metrics.json` does not exist
    """
    with open(os.path.join(model_dir, 'metrics.json')) as f:
        return json.load(f)
102 |
--------------------------------------------------------------------------------
/src/pytorch/utils_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
class BasicConv2d(nn.Module):
    """
    Convolution followed by batch normalisation and an in-place ReLU.

    The convolution has no bias, since a constant offset would be cancelled
    by the batch-norm mean subtraction. Extra keyword arguments (kernel_size,
    stride, padding, ...) are forwarded to nn.Conv2d.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # The attribute names `conv` / `bn` define the state_dict keys.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        normalized = self.bn(self.conv(x))
        return F.relu(normalized, inplace=True)
15 |
16 |
class Phase2D(torch.autograd.Function):
    """
    Differentiable map from Zernike coefficients to a 2-D phase map.

    For coefficients `input` of shape (B, K) and a basis `z_basis` of shape
    (K + 1, H, W), returns sum_k input[:, k] * z_basis[k + 1], of shape
    (B, H, W). The first basis map (index 0) is skipped. Only `input`
    receives a gradient.
    """

    @staticmethod
    def forward(ctx, input, z_basis):
        # Follow the coefficients' device/dtype instead of forcing CUDA, so the
        # op also runs on CPU-only machines.
        basis = z_basis.to(device=input.device, dtype=input.dtype)[1:]
        ctx.z_basis = basis
        # Contract over K directly instead of materialising a (B, K, H, W)
        # product.
        return torch.einsum('bk,khw->bhw', input, basis)

    @staticmethod
    def backward(ctx, grad_output):
        # d out[b, h, w] / d input[b, k] = basis[k, h, w]
        grad_input = torch.einsum('bhw,khw->bk', grad_output, ctx.z_basis)
        # The basis is a fixed constant: no gradient for it.
        return grad_input, None
31 |
class Phase2DLayer(nn.Module):
    """
    Layer that turns Zernike coefficients into a phase map via `Phase2D`.

    Args:
        input_features: (int) number of Zernike coefficients per sample (K);
            the basis holds K + 1 modes because `Phase2D` skips the first one
        output_features: (int) size argument passed to aotools.zernikeArray
            (presumably the side length, in pixels, of the square phase map --
            TODO confirm)
    """

    def __init__(self, input_features, output_features):
        super(Phase2DLayer, self).__init__()
        # aotools is not imported at module level in utils_model.py;
        # import it here so construction does not raise NameError.
        import aotools

        self.input_features = input_features
        self.output_features = output_features
        basis = aotools.zernikeArray(input_features + 1, output_features, norm='rms')
        # Kept as a plain attribute (not a registered buffer), so it is not part
        # of state_dict and existing checkpoints keep their keys.
        self.z_basis = torch.as_tensor(basis, dtype=torch.float32)

    def forward(self, input):
        return Phase2D.apply(input, self.z_basis)
42 |
--------------------------------------------------------------------------------
/src/pytorch/utils_visdom.py:
--------------------------------------------------------------------------------
1 | from visdom import Visdom
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import json
5 | from utils import plot_learningcurve
6 |
7 | # Start web server with: python -m visdom.server
8 |
class VisdomWebServer(object):
    """
    Publish training curves to a Visdom server.

    Start the server first with: python -m visdom.server
    """

    def __init__(self, port=8097, server="http://localhost"):
        """
        Create the Visdom client.

        Args:
            port: (int) server port, 8097 by default
            server: (string) server URL, "http://localhost" by default
        """
        self.vis = Visdom(port=port, server=server)

    def update(self, metrics):
        """
        Redraw the 'lrcurve' (train/val loss) and 'lr_rate' windows.

        Prints a message and returns if the server cannot be reached. Errors
        raised while plotting are printed and skipped.

        Args:
            metrics: (dict) providing the 'train_loss', 'val_loss' and
                'learning_rate' lists
        """
        if not self.vis.check_connection():
            print('No connection could be formed quickly')
            return

        try:
            # Learning curve
            self._send_figure('lrcurve', [
                (metrics['train_loss'], {'label': 'Training loss', 'color': '#32526e'}),
                (metrics['val_loss'], {'label': 'Validation loss', 'color': '#ff6b57'}),
            ], legend=True)
            # Learning rate
            self._send_figure('lr_rate', [
                (metrics['learning_rate'], {'color': '#32526e'}),
            ], legend=False)
        except Exception as err:  # not BaseException: let Ctrl-C through
            print('Skipped matplotlib example')
            print('Error message: ', err)

    def _send_figure(self, win, curves, legend):
        """Plot `curves` ((values, plot kwargs) pairs) on one figure and push it to window `win`."""
        fig, ax = plt.subplots()
        try:
            for values, style in curves:
                ax.plot(values, **style)
            if legend:
                ax.legend()
            ax.spines['right'].set_visible(False)
            ax.spines['top'].set_visible(False)
            ax.grid(zorder=0, color='lightgray', linestyle='--')
            self.vis.matplot(fig, win=win)
        finally:
            # Close exactly this figure, even on error, so repeated updates do
            # not accumulate open figures.
            plt.close(fig)
57 |
58 |
59 |
if __name__ == "__main__":
    # Manual check: replay the metrics of a saved run into the Visdom
    # dashboard (a server must already be running).
    from utils import get_metrics

    run_metrics = get_metrics('experiments/example')
    server = VisdomWebServer()
    server.update(run_metrics)
68 |
69 |
--------------------------------------------------------------------------------