├── Demo_correction_aberration.ipynb
├── PyTorchAberrations
│   ├── __init__.py
│   ├── aberration_functions.py
│   ├── aberration_layers.py
│   └── aberration_models.py
└── README.md

--------------------------------------------------------------------------------
/PyTorchAberrations/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wavefrontshaping/PyTorchAberrations/8a5dfe677081971e0d26e4868119fecf29588f40/PyTorchAberrations/__init__.py
--------------------------------------------------------------------------------
/PyTorchAberrations/aberration_functions.py:
--------------------------------------------------------------------------------
import torch
import numpy as np

def complex_matmul(A, B):
    '''
    Matrix multiplication for complex tensors.
    Tensors must have a last dimension of size 2 holding the real and imaginary parts.
    The -3 and -2 dimensions are the two matrix dimensions being multiplied.
    Any leading dimensions are treated as batch dimensions (cf. the PyTorch matmul() function).
    '''
    return torch.stack((A[..., 0].matmul(B[..., 0]) - A[..., 1].matmul(B[..., 1]),
                        A[..., 0].matmul(B[..., 1]) + A[..., 1].matmul(B[..., 0])), dim=-1)

def complex_mul(A, B):
    '''
    Element-wise multiplication for complex tensors.
    Tensors must have a last dimension of size 2 holding the real and imaginary parts.
    All other dimensions broadcast as in the PyTorch mul() function.
    '''
    return torch.stack((A[..., 0].mul(B[..., 0]) - A[..., 1].mul(B[..., 1]),
                        A[..., 0].mul(B[..., 1]) + A[..., 1].mul(B[..., 0])), dim=-1)


def pi2_shift(A):
    # Multiply by i, i.e. shift the phase by pi/2.
    return torch.stack((-A[..., 1], A[..., 0]), dim=-1)


def conjugate(A):
    return torch.stack((A[..., 0], -A[..., 1]), dim=-1)

def complex_fft(A, signal_ndim, **kwargs):
    # The legacy torch.fft()/torch.ifft() functions used here originally were
    # removed in PyTorch 1.8; torch.fft.fftn() over the last `signal_ndim`
    # dimensions of the complex tensor is the equivalent call.
    return torch.fft.fftn(A, dim=tuple(range(-signal_ndim, 0)), **kwargs)

def complex_ifft(A, signal_ndim, **kwargs):
    return torch.fft.ifftn(A, dim=tuple(range(-signal_ndim, 0)), **kwargs)

def complex_fftshift(A):
    # Shift the zero-frequency component to the center of the last two
    # dimensions (matches the original manual implementation for even-sized inputs).
    return torch.fft.fftshift(A, dim=(-2, -1))

def complex_ifftshift(A):
    return torch.fft.ifftshift(A, dim=(-2, -1))

def crop_center(input, size):
    # Crop a (batch, x, y, ...) tensor to (batch, size, size, ...) around its center.
    x = input.shape[1]
    y = input.shape[2]
    start_x = x // 2 - size // 2
    start_y = y // 2 - size // 2
    return input[:, start_x:start_x + size, start_y:start_y + size, ...]
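
# A minimal, hypothetical self-check of the real-view convention used above
# (illustration only, not part of the original library): a complex (n, m)
# tensor is stored as a real (n, m, 2) tensor, so complex_matmul() should
# agree with PyTorch's native complex matrix multiplication:
#
#     A = torch.randn(4, 4, 2)
#     B = torch.randn(4, 4, 2)
#     C = complex_matmul(A, B)
#     C_ref = torch.view_as_real(torch.view_as_complex(A) @ torch.view_as_complex(B))
#     assert torch.allclose(C, C_ref, atol=1e-5)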

def pt_to_cpx(A):
    # Real-view tensor (..., 2) to complex numpy array.
    return np.array(A[..., 0]) + 1j * np.array(A[..., 1])


def cpx_to_pt(A, device, dtype=torch.float32):
    # Complex numpy array to real-view tensor (..., 2) on the given device.
    return torch.stack((torch.from_numpy(A.real),
                        torch.from_numpy(A.imag)), dim=-1).type(dtype).to(device)

def norm2(A, device):
    return torch.sqrt(torch.sum(complex_mul(A, conjugate(A))[..., 0], dim=1))

def normalize(A, device):
    b = norm2(A, device=device)

    mid_dim = A.shape[1]
    zeros = torch.zeros(mid_dim, device=device)
    divider = torch.meshgrid(b, zeros, indexing='ij')[0]
    normalized = torch.stack((A[:, :, 0] / divider,
                              A[:, :, 1] / divider), dim=-1)
    return normalized


def tm_to_pt(A, device, dtype=torch.float32):
    # Complex numpy transmission matrix to real-view tensor.
    if len(A.shape) == 2:
        return torch.stack((torch.from_numpy(A.real),
                            torch.from_numpy(A.imag))).permute((1, 2, 0)).type(dtype).to(device)
    else:
        return torch.stack((torch.from_numpy(A.real),
                            torch.from_numpy(A.imag))).permute((1, 2, 3, 0)).type(dtype).to(device)
--------------------------------------------------------------------------------
/PyTorchAberrations/aberration_layers.py:
--------------------------------------------------------------------------------
import torch
from torch.nn import Module, ZeroPad2d
from PyTorchAberrations.aberration_functions import complex_mul, conjugate, pi2_shift


################################################################
################### AUTOGRAD FUNCTIONS #########################
################################################################

class ComplexZernikeFunction(torch.autograd.Function):
    '''
    Function that applies a Zernike polynomial phase term to a batch
    of complex images (or a matrix).
    '''
    @staticmethod
    def forward(ctx, input, alpha, j):

        nx = torch.arange(0, 1, 1. / input.shape[1], dtype=torch.float32)
        ny = torch.arange(0, 1, 1. / input.shape[2], dtype=torch.float32)

        X0, Y0 = 0.5 + 0.5 / input.shape[1], 0.5 + 0.5 / input.shape[2]
        X, Y = torch.meshgrid(nx, ny, indexing='ij')
        X = X.to(input.device) - X0
        Y = Y.to(input.device) - Y0

        # see https://en.wikipedia.org/wiki/Zernike_polynomials
        if j == 0:
            # Piston
            F = torch.ones_like(X)
        elif j == 1:
            # Tilt (x direction)
            F = X
        elif j == 2:
            # Tilt (y direction)
            F = Y
        elif j == 3:
            # Oblique astigmatism
            F = 2. * X.mul(Y)
        elif j == 4:
            # Defocus
            F = X ** 2 + Y ** 2
        elif j == 5:
            # Vertical astigmatism
            F = X ** 2 - Y ** 2
        else:
            R = torch.sqrt(X ** 2 + Y ** 2)
            THETA = torch.atan2(Y, X)
            if j == 6:
                # Oblique trefoil: rho**3 * sin(3*theta)
                F = torch.mul(R ** 3, torch.sin(3. * THETA))
            elif j == 7:
                # Vertical coma: (3*rho**3 - 2*rho) * sin(theta)
                F = torch.mul(3. * R ** 3 - 2. * R, torch.sin(THETA))
            elif j == 8:
                # Horizontal coma: (3*rho**3 - 2*rho) * cos(theta)
                F = torch.mul(3. * R ** 3 - 2. * R, torch.cos(THETA))
            elif j == 9:
                # Vertical trefoil: rho**3 * cos(3*theta)
                F = torch.mul(R ** 3, torch.cos(3. * THETA))
            elif j == 10:
                # Oblique quadrafoil
                F = 2. * torch.mul(R ** 4, torch.sin(4. * THETA))
            elif j == 11:
                # Oblique secondary astigmatism
                F = 2. * torch.mul(4. * R ** 4 - 3. * R ** 2, torch.sin(2. * THETA))
            elif j == 12:
                # Primary spherical
                F = 6. * R ** 4 - 6. * R ** 2 + torch.ones_like(R)
            elif j == 13:
                # Vertical secondary astigmatism
                F = 2. * torch.mul(4. * R ** 4 - 3. * R ** 2, torch.cos(2. * THETA))
            elif j == 14:
                # Vertical quadrafoil
                F = 2. * torch.mul(R ** 4, torch.cos(4. * THETA))
            else:
                raise ValueError(f'Unsupported Zernike index j = {j}')

        weight = torch.exp(1j * alpha * F)

        ctx.save_for_backward(input, alpha, F)
        output = input * weight

        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, alpha, F = ctx.saved_tensors

        weight = torch.exp(1j * alpha * F)

        grad_input = grad_alpha = None
        if ctx.needs_input_grad[0]:
            # d(output)/d(input) = weight, so the adjoint multiplies by its conjugate
            grad_input = grad_output * weight.conj()

        if ctx.needs_input_grad[1]:
            # d(output)/d(alpha) = 1j * F * weight * input, projected onto the
            # real parameter alpha
            grad_alpha = torch.sum(grad_output * (1j * F * weight * input).conj()).real
            grad_alpha.unsqueeze_(0)

        return grad_input, grad_alpha, None


#######################################################
#################### MODULES ##########################
#######################################################

class ComplexZernike(Module):
    '''
    Layer that applies a Zernike polynomial phase term to a batch
    of complex images (or a matrix).
    Only one parameter, the strength of the polynomial, is learned.
    Its initial value is 0.
    '''
    def __init__(self, j):
        super(ComplexZernike, self).__init__()
        assert j in range(15)
        self.j = j
        self.alpha = torch.nn.Parameter(torch.zeros(1), requires_grad=True)

    def forward(self, input):
        return ComplexZernikeFunction.apply(input, self.alpha, self.j)

class ComplexScaling(Module):
    '''
    Layer that applies a global scaling to a stack of 2D complex images (or a matrix).
    Only one parameter, the scaling factor, is learned.
    Its initial value is 1 (i.e. theta = 0).
    '''
    def __init__(self):
        super(ComplexScaling, self).__init__()
        # A single parameter: (1 + theta) scales both x and y.
        self.theta = torch.nn.Parameter(torch.zeros(1), requires_grad=True)

    def forward(self, input):
        input = torch.view_as_real(input).permute((0, 3, 1, 2))

        grid = torch.nn.functional.affine_grid(
            ((1. + self.theta) * torch.tensor([1., 0., 0., 0., 1., 0.],
                                              dtype=input.dtype).to(input.device)
             ).reshape((2, 3)).expand((input.shape[0], 2, 3)),
            input.size(),
            align_corners=True)

        return torch.view_as_complex(
            torch.nn.functional.grid_sample(input, grid, align_corners=True)
            .permute((0, 2, 3, 1)).contiguous())

class ComplexDeformation(Module):
    '''
    Layer that applies a global affine transformation to a stack of 2D complex
    images (or a matrix). All 6 affine parameters are learned.
    '''
    def __init__(self):
        super(ComplexDeformation, self).__init__()

        self.theta = torch.nn.Parameter(torch.tensor([0., 0., 0., 0., 0., 0.]))
        # parameters 0 and 4 correspond to x and y scaling
        # parameters 1 and 3 correspond to shearing
        # parameters 2 and 5 correspond to x and y shifts

    def forward(self, input):
        input = torch.view_as_real(input).permute((0, 3, 1, 2))
        # Add theta to the identity transform so that every affine parameter
        # (scaling, shearing and shifts) stays learnable.
        grid = torch.nn.functional.affine_grid(
            (self.theta + torch.tensor([1., 0., 0., 0., 1., 0.],
                                       dtype=input.dtype).to(input.device)
             ).reshape((2, 3)).expand((input.shape[0], 2, 3)),
            input.size(),
            align_corners=True)

        return torch.view_as_complex(
            torch.nn.functional.grid_sample(input, grid, align_corners=True)
            .permute((0, 2, 3, 1)).contiguous())
--------------------------------------------------------------------------------
/PyTorchAberrations/aberration_models.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from torch.nn import Module, Sequential, Identity
from torch.nn import ZeroPad2d
from PyTorchAberrations.aberration_layers import ComplexDeformation
from PyTorchAberrations.aberration_layers import ComplexZernike, ComplexScaling
from PyTorchAberrations.aberration_functions import crop_center, complex_fftshift
from PyTorchAberrations.aberration_functions import complex_ifftshift, conjugate, normalize
from PyTorchAberrations.aberration_functions import complex_fft, complex_ifft

class AberrationModes(torch.nn.Module):
    '''
    Model for input and output aberrations.
    Applies an `Aberration` model to both the input and the output mode basis.
    '''
    def __init__(self,
                 inpoints,
                 onpoints,
                 padding_coeff=0.,
                 list_zernike_ft=list(range(3)),
                 list_zernike_direct=list(range(3)),
                 deformation='single'):
        super(AberrationModes, self).__init__()
        self.aberration_output = Aberration(onpoints,
                                            list_zernike_ft=list_zernike_ft,
                                            list_zernike_direct=list_zernike_direct,
                                            padding_coeff=padding_coeff,
                                            deformation=deformation)
        self.aberration_input = Aberration(inpoints,
                                           list_zernike_ft=list_zernike_ft,
                                           list_zernike_direct=list_zernike_direct,
                                           padding_coeff=padding_coeff,
                                           deformation=deformation)
        self.inpoints = inpoints
        self.onpoints = onpoints

    def forward(self, input, output):
        output_modes = self.aberration_output(output)
        input_modes = self.aberration_input(input)
        return output_modes, input_modes


class Aberration(torch.nn.Module):
    '''
    Model that applies aberrations (in the direct and Fourier planes) and a
    global deformation to one dimension of a matrix.
    '''
    def __init__(self,
                 shape,
                 list_zernike_ft,
                 list_zernike_direct,
                 padding_coeff=0.,
                 deformation='single',
                 features=None):
        # Defines which Zernike polynomials to use in each plane and whether
        # to apply a learnable deformation.
        super(Aberration, self).__init__()

        # The Zernike polynomials can be given either as explicit lists of
        # indices or simply as the total number of polynomials to use.
        if type(list_zernike_direct) not in [list, np.ndarray]:
            list_zernike_direct = range(0, list_zernike_direct)
        if type(list_zernike_ft) not in [list, np.ndarray]:
            list_zernike_ft = range(0, list_zernike_ft)

        self.nxy = shape

        # padding layer, to get a good FFT resolution
        # (requires cropping after the IFFT)
        padding = int(padding_coeff * self.nxy)
        self.pad = ZeroPad2d(padding)

        # global deformation: full affine transformation, scaling only, or none
        if deformation == 'single':
            self.deformation = ComplexDeformation()
        elif deformation == 'scaling':
            self.deformation = ComplexScaling()
        else:
            self.deformation = Identity()

        self.zernike_ft = Sequential(*(ComplexZernike(j=j + 1) for j in list_zernike_ft))
        self.zernike_direct = Sequential(*(ComplexZernike(j=j + 1) for j in list_zernike_direct))

    def forward(self, input):
        assert input.shape[1] == input.shape[2]

        # padding
        input = self.pad(torch.view_as_complex(input))

        # global deformation
        input = self.deformation(input)

        # to the Fourier domain
        input = complex_ifftshift(input)
        input = complex_fft(input, 2)
        input = complex_fftshift(input)

        # Zernike layers in the Fourier plane
        input = self.zernike_ft(input)

        # back to the direct domain
        input = complex_ifftshift(input)
        input = complex_ifft(input, 2)
        input = complex_fftshift(input)

        # Zernike layers in the direct plane
        input = self.zernike_direct(input)

        # crop at the center (to undo the padding)
        input = crop_center(input, self.nxy)

        return torch.view_as_real(input)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTorchAberrations
Differentiable aberration layers for PyTorch using Zernike polynomials.
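
## Usage

A minimal sketch of how the `Aberration` model can be used (the size, Zernike lists and `padding_coeff` below are illustrative values, not ones prescribed by the repository; complex fields are stored in real-view form, i.e. as a `(batch, n, n, 2)` real tensor):

```python
import torch
from PyTorchAberrations.aberration_models import Aberration

n = 64
model = Aberration(n,
                   list_zernike_ft=list(range(5)),      # Zernike terms in the Fourier plane
                   list_zernike_direct=list(range(5)),  # Zernike terms in the direct plane
                   padding_coeff=0.25,
                   deformation='scaling')

field = torch.randn(1, n, n, 2)   # batch of complex fields, real-view form
out = model(field)                # aberrated field, same shape
loss = out.pow(2).sum()           # dummy loss, just to exercise autograd
loss.backward()                   # gradients reach the Zernike and scaling parameters
```
--------------------------------------------------------------------------------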