├── tests └── __init__.py ├── src └── spdc_inv │ ├── data │ ├── __init__.py │ ├── utils.py │ └── interaction.py │ ├── loss │ ├── __init__.py │ └── loss.py │ ├── models │ ├── __init__.py │ ├── utils.py │ └── spdc_model.py │ ├── optim │ ├── __init__.py │ └── optimizer.py │ ├── training │ ├── __init__.py │ ├── utils.py │ └── trainer.py │ ├── utils │ ├── __init__.py │ ├── defaults.py │ └── utils.py │ ├── experiments │ ├── __init__.py │ ├── utils.py │ ├── results_and_stats_utils.py │ └── experiment.py │ └── __init__.py ├── Optica.jpg ├── illustration.png ├── Data availability ├── 3a.npy ├── 3b.npy ├── 5a.npy ├── 5b.npy ├── 3c_imag.npy ├── 3c_real.npy ├── 7a_imag.npy ├── 7a_real.npy ├── 7b_imag.npy └── 7b_real.npy ├── .gitignore ├── setup.py ├── matlab ├── readNPY.m ├── LGfigures │ ├── readNPY.m │ ├── plotcube.m │ ├── readNPYheader.m │ └── LGfigures.m ├── readNPYheader.m └── LGfigures.m ├── environment.yml ├── README.md └── LICENSE /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/spdc_inv/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/spdc_inv/loss/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/spdc_inv/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/spdc_inv/optim/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/spdc_inv/training/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/spdc_inv/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/spdc_inv/experiments/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Optica.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EyalRozenberg1/SPDCinv/HEAD/Optica.jpg -------------------------------------------------------------------------------- /illustration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EyalRozenberg1/SPDCinv/HEAD/illustration.png -------------------------------------------------------------------------------- /Data availability/3a.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EyalRozenberg1/SPDCinv/HEAD/Data availability/3a.npy -------------------------------------------------------------------------------- /Data availability/3b.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EyalRozenberg1/SPDCinv/HEAD/Data availability/3b.npy -------------------------------------------------------------------------------- /Data availability/5a.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EyalRozenberg1/SPDCinv/HEAD/Data availability/5a.npy -------------------------------------------------------------------------------- /Data availability/5b.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/EyalRozenberg1/SPDCinv/HEAD/Data availability/5b.npy -------------------------------------------------------------------------------- /Data availability/3c_imag.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EyalRozenberg1/SPDCinv/HEAD/Data availability/3c_imag.npy -------------------------------------------------------------------------------- /Data availability/3c_real.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EyalRozenberg1/SPDCinv/HEAD/Data availability/3c_real.npy -------------------------------------------------------------------------------- /Data availability/7a_imag.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EyalRozenberg1/SPDCinv/HEAD/Data availability/7a_imag.npy -------------------------------------------------------------------------------- /Data availability/7a_real.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EyalRozenberg1/SPDCinv/HEAD/Data availability/7a_real.npy -------------------------------------------------------------------------------- /Data availability/7b_imag.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EyalRozenberg1/SPDCinv/HEAD/Data availability/7b_imag.npy -------------------------------------------------------------------------------- /Data availability/7b_real.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EyalRozenberg1/SPDCinv/HEAD/Data availability/7b_real.npy -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | 2 | helpers.py 3 | 4 | SPDC.py 5 | 6 | .idea/ 7 | 8 | __pycache__/ 9 | 10 | *.npy 11 | 12 | *.asv 13 | 14 | *.zip 15 | 16 | results/ 17 | 18 | logs/ 19 | 20 | .DS_Store 21 | 22 | matlab/HGfigures/1.fig 23 | 24 | *.png 25 | 26 | *.fig 27 | -------------------------------------------------------------------------------- /src/spdc_inv/utils/defaults.py: -------------------------------------------------------------------------------- 1 | COINCIDENCE_RATE = 'coincidence_rate' 2 | DENSITY_MATRIX = 'density_matrix' 3 | TOMOGRAPHY_MATRIX = 'tomography_matrix' 4 | REAL = 'real' 5 | IMAG = 'imag' 6 | QUBIT = 'qubit' 7 | QUTRIT = 'qutrit' 8 | qubit_projection_n_state2 = 6 9 | qubit_tomography_dimensions = 2 10 | qutrit_projection_n_state2 = 15 11 | qutrit_tomography_dimensions = 3 12 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | 4 | setup(name='spdc_inv', 5 | version='1.0', 6 | author='Eyal Rozenberg', 7 | author_email='eyalr@campus.technion.ac.il', 8 | url='https://github.com/EyalRozenberg1/SPDCinv', 9 | license='', 10 | classifiers=[ 11 | 'Programming Language :: Python :: 3.7', 12 | ], 13 | # src-layout: the importable package is src/spdc_inv (src/ itself has no __init__.py), so map the 14 | # root package namespace to the relative 'src' directory and discover packages there. 15 | package_dir={'': 'src'}, 16 | packages=find_packages(where='src'), 17 | install_requires=[ 18 | 'numpy', 19 | 'scipy', 20 | 'matplotlib', 21 | ], 22 | ) 23 | -------------------------------------------------------------------------------- /src/spdc_inv/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from pathlib import Path 3 | from pkg_resources import get_distribution, DistributionNotFound 4 | 5 | import os 6 | 7 | try: 8 | dist_name = "spdc_inv" 9 | 
__version__ = get_distribution(dist_name).version 10 | 11 | except DistributionNotFound: 12 | __version__ = 'unknown' 13 | 14 | finally: 15 | del get_distribution, DistributionNotFound 16 | 17 | PROJECT_ROOT = Path(__file__).resolve().parents[2] 18 | try: 19 | # Use a cwd-relative root only when the root really lies under cwd: relative_to() raises ValueError 20 | # otherwise, which a string-prefix test cannot rule out (e.g. cwd '/a/proj' vs root '/a/projX'). 21 | PROJECT_ROOT = PROJECT_ROOT.relative_to(os.getcwd()) 22 | except ValueError: 23 | pass # root is outside cwd: keep the absolute path 24 | 25 | SRC_ROOT = PROJECT_ROOT.joinpath('src') 26 | PACKAGE_ROOT = SRC_ROOT.joinpath('spdc_inv') 27 | DATA_DIR = PROJECT_ROOT.joinpath('data') 28 | LOGS_DIR = PROJECT_ROOT.joinpath('logs') 29 | RES_DIR = PROJECT_ROOT.joinpath('results') 30 | -------------------------------------------------------------------------------- /matlab/readNPY.m: -------------------------------------------------------------------------------- 1 | function data = readNPY(filename) 2 | % Function to read NPY files into matlab. 3 | % *** Only reads a subset of all possible NPY files, specifically N-D arrays of certain data types. 4 | % See https://github.com/kwikteam/npy-matlab/blob/master/tests/npy.ipynb for 5 | % more.
6 | % 7 | 8 | [shape, dataType, fortranOrder, littleEndian, totalHeaderLength, ~] = readNPYheader(filename); 9 | 10 | if littleEndian 11 | fid = fopen(filename, 'r', 'l'); 12 | else 13 | fid = fopen(filename, 'r', 'b'); 14 | end 15 | 16 | try 17 | 18 | [~] = fread(fid, totalHeaderLength, 'uint8'); 19 | 20 | % read the data 21 | data = fread(fid, prod(shape), [dataType '=>' dataType]); 22 | 23 | if length(shape)>1 && ~fortranOrder 24 | data = reshape(data, shape(end:-1:1)); 25 | data = permute(data, [length(shape):-1:1]); 26 | elseif length(shape)>1 27 | data = reshape(data, shape); 28 | end 29 | 30 | fclose(fid); 31 | 32 | catch me 33 | fclose(fid); 34 | rethrow(me); 35 | end -------------------------------------------------------------------------------- /matlab/LGfigures/readNPY.m: -------------------------------------------------------------------------------- 1 | function data = readNPY(filename) 2 | % Function to read NPY files into matlab. 3 | % *** Only reads a subset of all possible NPY files, specifically N-D arrays of certain data types. 4 | % See https://github.com/kwikteam/npy-matlab/blob/master/tests/npy.ipynb for 5 | % more. 
6 | % 7 | 8 | [shape, dataType, fortranOrder, littleEndian, totalHeaderLength, ~] = readNPYheader(filename); 9 | 10 | if littleEndian 11 | fid = fopen(filename, 'r', 'l'); 12 | else 13 | fid = fopen(filename, 'r', 'b'); 14 | end 15 | 16 | try 17 | 18 | [~] = fread(fid, totalHeaderLength, 'uint8'); 19 | 20 | % read the data 21 | data = fread(fid, prod(shape), [dataType '=>' dataType]); 22 | 23 | if length(shape)>1 && ~fortranOrder 24 | data = reshape(data, shape(end:-1:1)); 25 | data = permute(data, [length(shape):-1:1]); 26 | elseif length(shape)>1 27 | data = reshape(data, shape); 28 | end 29 | 30 | fclose(fid); 31 | 32 | catch me 33 | fclose(fid); 34 | rethrow(me); 35 | end -------------------------------------------------------------------------------- /src/spdc_inv/data/utils.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as np 2 | 3 | 4 | def nz_MgCLN_Gayer( 5 | lam: float, 6 | T: float, 7 | ax: str=None, 8 | ): 9 | """ 10 | Refractive index for MgCLN, based on Gayer et al, APB 2008 11 | 12 | Parameters 13 | ---------- 14 | lam: wavelength (lambda) [um] 15 | T: Temperature [Celsius Degrees] 16 | ax: polarization; ignored (presumably accepted so the index functions share a (lam, T, ax) signature). 17 | This fit always returns the z-polarization index. 18 | 19 | Returns 20 | ------- 21 | nz: Refractive index on z polarization 22 | 23 | """ 24 | a = np.array([5.756, 0.0983, 0.2020, 189.32, 12.52, 1.32 * 10 ** (-2)]) 25 | b = np.array([2.860 * 10 ** (-6), 4.700 * 10 ** (-8), 6.113 * 10 ** (-8), 1.516 * 10 ** (-4)]) 26 | f = (T - 24.5) * (T + 570.82) 27 | 28 | n1 = a[0] 29 | n2 = b[0] * f 30 | n3 = (a[1] + b[1] * f) / (lam ** 2 - (a[2] + b[2] * f) ** 2) 31 | n4 = (a[3] + b[3] * f) / (lam ** 2 - (a[4]) ** 2) 32 | n5 = -a[5] * lam ** 2 33 | 34 | # n1 + ... + n5 equals nz ** 2 35 | nz = np.sqrt(n1 + n2 + n3 + n4 + n5) 36 | return nz 37 | 38 | 39 | def n_KTP_Kato( 40 | lam: float, 41 | T: float, 42 | ax: str, 43 | ): 44 | """ 45 | Refractive index for KTP, based on K.
Kato 46 | 47 | Parameters 48 | ---------- 49 | lam: wavelength (lambda) [um] 50 | T: Temperature [Celsius Degrees] 51 | ax: polarization, either 'z' or 'y' 52 | 53 | Returns 54 | ------- 55 | n: Refractive index 56 | 57 | Raises 58 | ------ 59 | AssertionError: if ax is neither 'z' nor 'y' (ValueError instead when asserts are disabled by python -O) 60 | 61 | """ 62 | assert ax in ['z', 'y'], 'polarization must be either z or y' 63 | dT = (T - 20) 64 | if ax == "z": 65 | n_no_T_dep = np.sqrt(4.59423 + 0.06206 / (lam ** 2 - 0.04763) + 110.80672 / (lam ** 2 - 86.12171)) 66 | dn = (0.9221 / lam ** 3 - 2.9220 / lam ** 2 + 3.6677 / lam - 0.1897) * 1e-5 * dT 67 | elif ax == "y": 68 | n_no_T_dep = np.sqrt(3.45018 + 0.04341 / (lam ** 2 - 0.04597) + 16.98825 / (lam ** 2 - 39.43799)) 69 | dn = (0.1997 / lam ** 3 - 0.4063 / lam ** 2 + 0.5154 / lam + 0.5425) * 1e-5 * dT 70 | else: 71 | # only reachable with python -O (asserts stripped); otherwise n_no_T_dep would be unbound below 72 | raise ValueError('polarization must be either z or y') 73 | n = n_no_T_dep + dn 74 | return n -------------------------------------------------------------------------------- /src/spdc_inv/optim/optimizer.py: -------------------------------------------------------------------------------- 1 | try: 2 | from jax.example_libraries import optimizers # location in newer JAX releases 3 | except ImportError: 4 | from jax.experimental import optimizers # older JAX (environment.yml pins jax==0.1.66) 5 | 6 | 7 | def Optimizer( 8 | optimizer: str = 'adam', 9 | exp_decay_lr: bool = True, 10 | step_size: float = 0.05, 11 | decay_steps: int = 50, 12 | decay_rate: float = 0.5, 13 | ): 14 | """ 15 | Build a JAX optimizer by name, optionally with an exponentially decaying step size. 16 | 17 | Parameters 18 | ---------- 19 | optimizer: one of 'adam', 'sgd', 'adagrad', 'adamax', 'momentum', 'nesterov', 'rmsprop', 'rmsprop_momentum' 20 | exp_decay_lr: if True, the step size follows optimizers.exponential_decay(step_size, decay_steps, decay_rate); 21 | otherwise step_size is used as a constant 22 | step_size: (initial) step size / learning rate 23 | decay_steps, decay_rate: forwarded to optimizers.exponential_decay (ignored when exp_decay_lr is False) 24 | 25 | Returns 26 | ------- 27 | The result of the matching optimizers.<name> factory, built with the fixed hyper-parameters below 28 | (in JAX's optimizers module this is an (init_fun, update_fun, get_params) triple). 29 | 30 | """ 31 | # Lambdas defer construction so that only the requested optimizer is built. 32 | factories = { 33 | 'adam': lambda schedule: optimizers.adam(schedule, b1=0.9, b2=0.999, eps=1e-08), 34 | 'sgd': lambda schedule: optimizers.sgd(schedule), 35 | 'adagrad': lambda schedule: optimizers.adagrad(schedule), 36 | 'adamax': lambda schedule: optimizers.adamax(schedule, b1=0.9,
b2=0.999, eps=1e-08), 37 | 'momentum': lambda schedule: optimizers.momentum(schedule, mass=1e-02), 38 | 'nesterov': lambda schedule: optimizers.nesterov(schedule, mass=1e-02), 39 | 'rmsprop': lambda schedule: optimizers.rmsprop(schedule, gamma=0.9, eps=1e-08), 40 | 'rmsprop_momentum': lambda schedule: optimizers.rmsprop_momentum(schedule, gamma=0.9, eps=1e-08, momentum=0.9), 41 | } 42 | 43 | assert optimizer in factories, 'non-standard optimizer choice' 44 | 45 | if exp_decay_lr: 46 | step_schedule = optimizers.exponential_decay(step_size=step_size, 47 | decay_steps=decay_steps, 48 | decay_rate=decay_rate) 49 | else: 50 | step_schedule = step_size 51 | 52 | # Dict indexing still fails loudly (KeyError) if asserts are stripped with python -O, 53 | # whereas an if/elif chain without an else would silently return None. 54 | return factories[optimizer](step_schedule) 55 | -------------------------------------------------------------------------------- /matlab/LGfigures/plotcube.m: -------------------------------------------------------------------------------- 1 | function plotcube(varargin) 2 | % PLOTCUBE - Display a 3D-cube in the current axes 3 | % 4 | % PLOTCUBE(EDGES,ORIGIN,ALPHA,COLOR) displays a 3D-cube in the current axes 5 | % with the following properties: 6 | % * EDGES : 3-elements vector that defines the length of cube edges 7 | % * ORIGIN: 3-elements vector that defines the start point of the cube 8 | % * ALPHA : scalar that defines the transparency of the cube faces (from 0 9 | % to 1) 10 | % * COLOR : 3-elements vector that defines the faces color of the cube 11 | % 12 | % Example: 13 | % >> plotcube([5 5 5],[ 2 2 2],.8,[1 0 0]); 14 | % >> plotcube([5 5 5],[10 10 10],.8,[0 1 0]); 15 | % >> plotcube([5 5 5],[20 20 20],.8,[0 0 1]); 16 | % Default input arguments 17 | inArgs = { ... 18 | [10 56 100] , ... % Default edge sizes (x,y and z) 19 | [10 10 10] , ... % Default coordinates of the origin point of the cube 20 | .7 , ... % Default alpha value for the cube's faces 21 | [1 0 0] ... % Default Color for the cube 22 | }; 23 | % Replace default input arguments by input values 24 | inArgs(1:nargin) = varargin; 25 | % Create all variables 26 | [edges,origin,alpha,clr] = deal(inArgs{:}); 27 | XYZ = { ... 28 | [0 0 0 0] [0 0 1 1] [0 1 1 0] ; ... 29 | [1 1 1 1] [0 0 1 1] [0 1 1 0] ; ... 30 | [0 1 1 0] [0 0 0 0] [0 0 1 1] ; ...
31 | [0 1 1 0] [1 1 1 1] [0 0 1 1] ; ... 32 | [0 1 1 0] [0 0 1 1] [0 0 0 0] ; ... 33 | [0 1 1 0] [0 0 1 1] [1 1 1 1] ... 34 | }; 35 | XYZ = mat2cell(... 36 | cellfun( @(x,y,z) x*y+z , ... 37 | XYZ , ... 38 | repmat(mat2cell(edges,1,[1 1 1]),6,1) , ... 39 | repmat(mat2cell(origin,1,[1 1 1]),6,1) , ... 40 | 'UniformOutput',false), ... 41 | 6,[1 1 1]); 42 | cellfun(@patch,XYZ{1},XYZ{2},XYZ{3},... 43 | repmat({clr},6,1),... 44 | repmat({'FaceAlpha'},6,1),... 45 | repmat({alpha},6,1), ... 46 | repmat({'EdgeColor'},6,1), ... 47 | repmat({'None'},6,1) ... 48 | ); 49 | view(3); 50 | -------------------------------------------------------------------------------- /matlab/readNPYheader.m: -------------------------------------------------------------------------------- 1 | function [arrayShape, dataType, fortranOrder, littleEndian, totalHeaderLength, npyVersion] = readNPYheader(filename) 2 | % function [arrayShape, dataType, fortranOrder, littleEndian, ... 3 | % totalHeaderLength, npyVersion] = readNPYheader(filename) 4 | % 5 | % parse the header of a .npy file and return all the info contained 6 | % therein. 
7 | % 8 | % Based on spec at http://docs.scipy.org/doc/numpy-dev/neps/npy-format.html 9 | 10 | fid = fopen(filename); 11 | 12 | % verify that the file exists 13 | if (fid == -1) 14 | if ~isempty(dir(filename)) 15 | error('Permission denied: %s', filename); 16 | else 17 | error('File not found: %s', filename); 18 | end 19 | end 20 | 21 | try 22 | 23 | dtypesMatlab = {'uint8','uint16','uint32','uint64','int8','int16','int32','int64','single','double', 'logical'}; 24 | dtypesNPY = {'u1', 'u2', 'u4', 'u8', 'i1', 'i2', 'i4', 'i8', 'f4', 'f8', 'b1'}; 25 | 26 | 27 | magicString = fread(fid, [1 6], 'uint8=>uint8'); 28 | 29 | if ~all(magicString == [147,78,85,77,80,89]) 30 | error('readNPY:NotNUMPYFile', 'Error: This file does not appear to be NUMPY format based on the header.'); 31 | end 32 | 33 | majorVersion = fread(fid, [1 1], 'uint8=>uint8'); 34 | minorVersion = fread(fid, [1 1], 'uint8=>uint8'); 35 | 36 | npyVersion = [majorVersion minorVersion]; 37 | 38 | headerLength = fread(fid, [1 1], 'uint16=>uint16'); 39 | 40 | totalHeaderLength = 10+headerLength; 41 | 42 | arrayFormat = fread(fid, [1 headerLength], 'char=>char'); 43 | 44 | % to interpret the array format info, we make some fairly strict 45 | % assumptions about its format... 
46 | 47 | r = regexp(arrayFormat, '''descr''\s*:\s*''(.*?)''', 'tokens'); 48 | dtNPY = r{1}{1}; 49 | 50 | littleEndian = ~strcmp(dtNPY(1), '>'); 51 | 52 | dataType = dtypesMatlab{strcmp(dtNPY(2:3), dtypesNPY)}; 53 | 54 | r = regexp(arrayFormat, '''fortran_order''\s*:\s*(\w+)', 'tokens'); 55 | fortranOrder = strcmp(r{1}{1}, 'True'); 56 | 57 | r = regexp(arrayFormat, '''shape''\s*:\s*\((.*?)\)', 'tokens'); 58 | shapeStr = r{1}{1}; 59 | arrayShape = str2num(shapeStr(shapeStr~='L')); 60 | 61 | 62 | fclose(fid); 63 | 64 | catch me 65 | fclose(fid); 66 | rethrow(me); 67 | end -------------------------------------------------------------------------------- /matlab/LGfigures/readNPYheader.m: -------------------------------------------------------------------------------- 1 | function [arrayShape, dataType, fortranOrder, littleEndian, totalHeaderLength, npyVersion] = readNPYheader(filename) 2 | % function [arrayShape, dataType, fortranOrder, littleEndian, ... 3 | % totalHeaderLength, npyVersion] = readNPYheader(filename) 4 | % 5 | % parse the header of a .npy file and return all the info contained 6 | % therein. 
7 | % 8 | % Based on spec at http://docs.scipy.org/doc/numpy-dev/neps/npy-format.html 9 | 10 | fid = fopen(filename); 11 | 12 | % verify that the file exists 13 | if (fid == -1) 14 | if ~isempty(dir(filename)) 15 | error('Permission denied: %s', filename); 16 | else 17 | error('File not found: %s', filename); 18 | end 19 | end 20 | 21 | try 22 | 23 | dtypesMatlab = {'uint8','uint16','uint32','uint64','int8','int16','int32','int64','single','double', 'logical'}; 24 | dtypesNPY = {'u1', 'u2', 'u4', 'u8', 'i1', 'i2', 'i4', 'i8', 'f4', 'f8', 'b1'}; 25 | 26 | 27 | magicString = fread(fid, [1 6], 'uint8=>uint8'); 28 | 29 | if ~all(magicString == [147,78,85,77,80,89]) 30 | error('readNPY:NotNUMPYFile', 'Error: This file does not appear to be NUMPY format based on the header.'); 31 | end 32 | 33 | majorVersion = fread(fid, [1 1], 'uint8=>uint8'); 34 | minorVersion = fread(fid, [1 1], 'uint8=>uint8'); 35 | 36 | npyVersion = [majorVersion minorVersion]; 37 | 38 | headerLength = fread(fid, [1 1], 'uint16=>uint16'); 39 | 40 | totalHeaderLength = 10+headerLength; 41 | 42 | arrayFormat = fread(fid, [1 headerLength], 'char=>char'); 43 | 44 | % to interpret the array format info, we make some fairly strict 45 | % assumptions about its format... 
46 | 47 | r = regexp(arrayFormat, '''descr''\s*:\s*''(.*?)''', 'tokens'); 48 | dtNPY = r{1}{1}; 49 | 50 | littleEndian = ~strcmp(dtNPY(1), '>'); 51 | 52 | dataType = dtypesMatlab{strcmp(dtNPY(2:3), dtypesNPY)}; 53 | 54 | r = regexp(arrayFormat, '''fortran_order''\s*:\s*(\w+)', 'tokens'); 55 | fortranOrder = strcmp(r{1}{1}, 'True'); 56 | 57 | r = regexp(arrayFormat, '''shape''\s*:\s*\((.*?)\)', 'tokens'); 58 | shapeStr = r{1}{1}; 59 | arrayShape = str2num(shapeStr(shapeStr~='L')); 60 | 61 | 62 | fclose(fid); 63 | 64 | catch me 65 | fclose(fid); 66 | rethrow(me); 67 | end -------------------------------------------------------------------------------- /matlab/LGfigures.m: -------------------------------------------------------------------------------- 1 | l = 0:1:9; 2 | G2_sim = double(readNPY('G2 (29).npy')); 3 | G2_sim = reshape(G2_sim,[9,9]); 4 | G2_sim = abs(G2_sim)/sum(sum(abs(G2_sim))); 5 | figure; imagesc(l,l,G2_sim); axis square; colorbar; title('simulation') 6 | 7 | h=figure; b=bar3(G2_sim); 8 | zlim([0 0.22]) 9 | xlabel('j (idler)'); ylabel('u (signal)'); zlabel('probability'); 10 | title('Sim 30mm, MaxX=300, dx=8, dz=10, N=4000') 11 | set(gca, 'XTickLabel', {'-4', '-3', '-2', '-1', '0', '1', '2', '3', '4','5'}, 'FontSize', 20, 'FontName', 'Calibri') 12 | set(gca, 'YTickLabel',{'-4', '-3', '-2', '-1', '0', '1', '2', '3', '4','5'}, 'FontSize', 20, 'FontName', 'Calibri') 13 | % colormap pink; 14 | for k = 1:length(b) 15 | zdata = get(b(k),'ZData'); 16 | b(k).CData = zdata; 17 | b(k).FaceColor = 'interp'; 18 | end 19 | 20 | 21 | %% 22 | 23 | 24 | 25 | % PolingCoeffs = readNPY('PlusMinus2Pair_L1mm_pump70um\PolingCoeffs.npy'); 26 | % %PolingCoeffs = reshape(PolingCoeffs, [5, 9]); 27 | % 28 | % dx =1; %um 29 | % MaxX = 120; %um 30 | % pump_waist = 70; %um 31 | % x = -MaxX:dx:MaxX; 32 | % y = x; 33 | % r0 = sqrt(2)*pump_waist; 34 | % 35 | % 36 | % [X,Y] = meshgrid(x,y); 37 | % Phi = atan2(Y,X); 38 | % Rad = sqrt(X.^2+Y.^2)/r0; 39 | % Profile = 0; 40 | % alphas = 
[2.4048, 5.5201, 8.6537, 11.7915, 14.9309]; 41 | % 42 | % for p = 0:4 43 | % for ll = -4:4 44 | % Profile = Profile + PolingCoeffs(ll+5+9*p).*besselj(0, alphas(p+1)*Rad).*exp(-1j*ll*Phi); 45 | % end 46 | % end 47 | % Magnitude = abs(Profile)/max(max(abs(Profile))); 48 | % phase = angle(Profile); 49 | % dutycycle = asin(Magnitude)/pi; 50 | % figure; imagesc(real(Profile)); colorbar 51 | % 52 | % Z = -50:0.1:50; %um 53 | % Lambda = 6.9328; %um 54 | % DeltaK = 2*pi/Lambda; 55 | % Poling = zeros(length(x),length(y), length(Z)); 56 | % for i = 1:length(Z) 57 | % i 58 | % z = Z(i); 59 | % for m = 0:100 60 | % if m == 0 61 | % Poling(:,:,i) = Poling(:,:,i) + 2*dutycycle - 1; 62 | % else 63 | % Poling(:,:,i) = Poling(:,:,i) + (2/(m*pi)).*sin(pi*m*dutycycle).*2.*cos(m*DeltaK*z + m * phase); 64 | % end 65 | % end 66 | % end 67 | % 68 | % %% 69 | % figure; imagesc(sign(squeeze(Poling(:,:,502)))); colorbar 70 | % figure; 71 | % subplot(3,1,1); 72 | % imagesc(x,Z,sign(squeeze(Poling(:,round(end/3),:)))); ylabel('x[\mum]'); xlabel('z[\mum]'); text(-100,40,strcat('y=', num2str(y(round(end/3))),'um'),'Color','white','FontSize',16); set(gca, 'FontSize', 18); axis image 73 | % subplot(3,1,2); 74 | % imagesc(x,Z,sign(squeeze(Poling(:,round(end/2),:)))); ylabel('x[\mum]'); xlabel('z[\mum]'); text(-100,40,strcat('y=', num2str(y(round(end/2))),'um'),'Color','white','FontSize',16); set(gca, 'FontSize', 18); axis image 75 | % subplot(3,1,3); 76 | % imagesc(x,Z,sign(squeeze(Poling(:,round(2*end/3),:)))); ylabel('x[\mum]'); xlabel('z[\mum]'); text(-100,40,strcat('y=', num2str(y(round(2*end/3))),'um'),'Color','white','FontSize',16); set(gca, 'FontSize', 18); axis image 77 | 78 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: spdc 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _pytorch_select=0.2=gpu_0 7 | - 
_tflow_select=2.3.0=mkl 8 | - absl-py=0.9.0=py37_0 9 | - asn1crypto=1.3.0=py37_0 10 | - astor=0.8.0=py37_0 11 | - blas=1.0=mkl 12 | - blinker=1.4=py37_0 13 | - c-ares=1.15.0=h7b6447c_1001 14 | - ca-certificates=2020.1.1=0 15 | - cachetools=3.1.1=py_0 16 | - certifi=2019.11.28=py37_1 17 | - cffi=1.13.2=py37h2e261b9_0 18 | - chardet=3.0.4=py37_1003 19 | - click=7.0=py37_0 20 | - cryptography=2.8=py37h1ba5d50_0 21 | - cudatoolkit=10.0.130=0 22 | - cudnn=7.6.5=cuda10.0_0 23 | - cycler=0.10.0=py37_0 24 | - dbus=1.13.12=h746ee38_0 25 | - expat=2.2.6=he6710b0_0 26 | - fontconfig=2.13.0=h9420a91_0 27 | - freetype=2.9.1=h8a8886c_1 28 | - glib=2.63.1=h5a9c865_0 29 | - google-auth=1.11.2=py_0 30 | - google-auth-oauthlib=0.4.1=py_2 31 | - google-pasta=0.1.8=py_0 32 | - grpcio=1.27.2=py37hf8bcb03_0 33 | - gst-plugins-base=1.14.0=hbbd80ab_1 34 | - gstreamer=1.14.0=hb453b48_1 35 | - h5py=2.10.0=py37h7918eee_0 36 | - hdf5=1.10.4=hb1b8bf9_0 37 | - icu=58.2=h9c2bf20_1 38 | - idna=2.8=py37_0 39 | - intel-openmp=2019.4=243 40 | - jpeg=9b=h024ee3a_2 41 | - keras-applications=1.0.8=py_0 42 | - keras-preprocessing=1.1.0=py_1 43 | - kiwisolver=1.1.0=py37he6710b0_0 44 | - ld_impl_linux-64=2.33.1=h53a641e_7 45 | - libedit=3.1.20181209=hc058e9b_0 46 | - libffi=3.2.1=hd88cf55_4 47 | - libgcc-ng=9.1.0=hdf63c60_0 48 | - libgfortran-ng=7.3.0=hdf63c60_0 49 | - libpng=1.6.37=hbc83047_0 50 | - libprotobuf=3.11.4=hd408876_0 51 | - libstdcxx-ng=9.1.0=hdf63c60_0 52 | - libuuid=1.0.3=h1bed415_2 53 | - libxcb=1.13=h1bed415_1 54 | - libxml2=2.9.9=hea5a465_1 55 | - markdown=3.1.1=py37_0 56 | - matplotlib=3.1.1=py37h5429711_0 57 | - mkl=2019.4=243 58 | - mkl-service=2.3.0=py37he904b0f_0 59 | - mkl_fft=1.0.15=py37ha843d7b_0 60 | - mkl_random=1.1.0=py37hd6b4f25_0 61 | - ncurses=6.1=he6710b0_1 62 | - ninja=1.9.0=py37hfd86e86_0 63 | - numpy=1.18.1=py37h4f9e942_0 64 | - numpy-base=1.18.1=py37hde5b4d6_1 65 | - oauthlib=3.1.0=py_0 66 | - openssl=1.1.1e=h7b6447c_0 67 | - opt_einsum=3.1.0=py_0 68 | - 
pcre=8.43=he6710b0_0 69 | - pip=20.0.2=py37_1 70 | - protobuf=3.11.4=py37he6710b0_0 71 | - pyasn1=0.4.8=py_0 72 | - pyasn1-modules=0.2.7=py_0 73 | - pycparser=2.19=py37_0 74 | - pyjwt=1.7.1=py37_0 75 | - pyopenssl=19.1.0=py37_0 76 | - pyparsing=2.4.6=py_0 77 | - pyqt=5.9.2=py37h05f1152_2 78 | - pysocks=1.7.1=py37_0 79 | - python=3.7.6=h0371630_2 80 | - python-dateutil=2.8.1=py_0 81 | - pytorch=1.3.1=cuda100py37h53c1284_0 82 | - pytz=2019.3=py_0 83 | - qt=5.9.7=h5867ecd_1 84 | - readline=7.0=h7b6447c_5 85 | - requests=2.22.0=py37_1 86 | - requests-oauthlib=1.3.0=py_0 87 | - rsa=4.0=py_0 88 | - scipy=1.4.1=py37h0b6359f_0 89 | - setuptools=45.1.0=py37_0 90 | - sip=4.19.8=py37hf484d3e_0 91 | - six=1.14.0=py37_0 92 | - sqlite=3.30.1=h7b6447c_0 93 | - tensorboard=2.1.0=py3_0 94 | - tensorflow=2.1.0=mkl_py37h80a91df_0 95 | - tensorflow-base=2.1.0=mkl_py37h6d63fb7_0 96 | - tensorflow-estimator=2.1.0=pyhd54b08b_0 97 | - termcolor=1.1.0=py37_1 98 | - tk=8.6.8=hbc83047_0 99 | - tornado=6.0.3=py37h7b6447c_0 100 | - urllib3=1.25.8=py37_0 101 | - werkzeug=1.0.0=py_0 102 | - wheel=0.33.6=py37_0 103 | - wrapt=1.11.2=py37h7b6447c_0 104 | - xz=5.2.4=h14c3975_4 105 | - zlib=1.2.11=h7b6447c_3 106 | - pip: 107 | - astunparse==1.6.3 108 | - attrs==19.3.0 109 | - blessings==1.7 110 | - dill==0.3.1.1 111 | - future==0.18.2 112 | - gast==0.3.3 113 | - googleapis-common-protos==1.51.0 114 | - gpustat==0.6.0 115 | - jax==0.1.66 116 | - jax-loss==0.0.1 117 | - jaxlib==0.1.46 118 | - nvidia-ml-py3==7.352.0 119 | - promise==2.3 120 | - psutil==5.7.0 121 | - tb-nightly==2.2.0a20200324 122 | - tensorboard-plugin-wit==1.6.0.post2 123 | - tensorflow-metadata==0.21.2 124 | - tf-estimator-nightly==2.3.0.dev2020032401 125 | - tf-nightly==2.2.0.dev20200324 126 | - tfds-nightly==2.1.0.dev202003240105 127 | - tqdm==4.43.0 128 | prefix: /home/eyal/miniconda3/envs/spdc 129 | 130 | -------------------------------------------------------------------------------- /matlab/LGfigures/LGfigures.m: 
-------------------------------------------------------------------------------- 1 | l = -4:1:4; % azimuthal indices -4..4 (9 values), matching the 9x9 reshapes and tick labels below 2 | 3 | folder = 'C:\Users\avivk\OneDrive\Desktop\Aviv\SPDC\Machine Learning\Figures for the paper\ququart2_only_crystal'; % NOTE(review): hard-coded local path; point this at your results directory 4 | 5 | G2_sim = double(readNPY(strcat(folder,'\coincidence_rate_PumpLG0m2.npy'))); 6 | G2_sim = reshape(G2_sim,[9,9]); 7 | G2_sim = abs(G2_sim)/sum(sum(abs(G2_sim))); % normalize so the entries sum to 1 8 | % figure; imagesc(l,l,G2_sim); axis square; colorbar; title('simulation'); set(gca, 'FontSize', 20, 'FontName', 'Calibri'); xlabel('l_i (idler)'); ylabel('l_s (signal)'); 9 | 10 | G2_target = double(readNPY(strcat(folder,'\target.npy'))); 11 | G2_target = reshape(G2_target,[9,9]); 12 | G2_target = abs(G2_target)/sum(sum(abs(G2_target))); % normalize so the entries sum to 1 13 | 14 | 15 | h=figure; b=bar3(G2_sim); 16 | zlabel('probability'); 17 | set(gca, 'XTickLabel', {'-4', '-3', '-2', '-1', '0', '1', '2', '3', '4'}, 'FontSize', 28, 'FontName', 'Calibri') 18 | set(gca, 'YTickLabel',{'-4', '-3', '-2', '-1', '0', '1', '2', '3', '4'}, 'FontSize', 28, 'FontName', 'Calibri') 19 | colormap summer; 20 | for k = 1:length(b) 21 | zdata = get(b(k),'ZData'); 22 | b(k).CData = zdata; 23 | b(k).FaceColor = 'interp'; 24 | end 25 | zlim([0, 0.5]); 26 | 27 | 28 | h=figure; b=bar3(G2_target); 29 | zlabel('probability'); 30 | set(gca, 'XTickLabel', {'-4', '-3', '-2', '-1', '0', '1', '2', '3', '4'}, 'FontSize', 28, 'FontName', 'Calibri') 31 | set(gca, 'YTickLabel',{'-4', '-3', '-2', '-1', '0', '1', '2', '3', '4'}, 'FontSize', 28, 'FontName', 'Calibri') 32 | colormap summer; 33 | for k = 1:length(b) 34 | zdata = get(b(k),'ZData'); 35 | b(k).CData = zdata; 36 | b(k).FaceColor = 'interp'; 37 | end 38 | zlim([0, 0.5]); 39 | 40 | 41 | 42 | 43 | 44 | %% 45 | 46 | dx = 1e-6; %m 47 | MaxX = 120e-6; %m 48 | %pump_waist = 40e-6; %m 49 | x = -MaxX:dx:MaxX; 50 | y = x; 51 | 52 | Pump = 1; % set to 0 to skip the pump-profile section 53 | 54 | if Pump 55 | pump_coeffs_real = [0 0 1 0 0 0 0 0 0]; %readNPY(strcat(folder,'\parameters_pump_real.npy')); 56 | pump_coeffs_imag = [0 0 0 0 0 0 0 0
0];%readNPY(strcat(folder,'\parameters_pump_imag.npy')); 57 | pump_coeffs = pump_coeffs_real + 1i*pump_coeffs_imag; 58 | pump_waists_vec = 1e-6*[40 40 40 40 40 40 40 40 40]; %readNPY(strcat(folder,'\parameters_pump_waists.npy')); 59 | %*34/40 60 | 61 | MaxP = 0; 62 | MaxL = 4;%2; 63 | 64 | PumpCoeffs = pump_coeffs; 65 | 66 | [X,Y] = meshgrid(x,y); 67 | PumpProfile = 0; 68 | 69 | for p = 0:MaxP 70 | for ll = -MaxL:MaxL 71 | pump_waist = pump_waists_vec(ll+MaxL+1+(2*MaxL+1)*p); % linear index of mode (ll, p) in the coefficient/waist vectors 72 | PumpProfile = PumpProfile + PumpCoeffs(ll+MaxL+1+(2*MaxL+1)*p).*LaguerreGauss(404e-9, 2, pump_waist,ll,p,X,Y,0); 73 | end 74 | end 75 | 76 | figure; imagesc(x*1e6, y*1e6, abs(PumpProfile)); axis square; colorbar; set(gca, 'FontSize', 28, 'FontName', 'Calibri'); xlabel('x[um]'); ylabel('y[um]'); 77 | axes('Position',[.53 .7 .2 .2]) 78 | box on 79 | imagesc(x*1e6, y*1e6, angle(PumpProfile)); axis square; colorbar; set(gca, 'FontSize', 16, 'FontName', 'Calibri', 'XColor', 'w', 'YColor', 'w'); xlabel('x[um]', 'Color', 'w'); ylabel('y[um]', 'Color', 'w'); 80 | c = colorbar; c.Color = 'w'; 81 | end 82 | 83 | %% 84 | 85 | 86 | Poling = 1; % flag; NOTE: Poling is overwritten below with the 3D poling array 87 | 88 | 89 | if Poling 90 | poling_coeffs_real = readNPY(strcat(folder,'\parameters_crystal_real.npy')); 91 | poling_coeffs_imag = readNPY(strcat(folder,'\parameters_crystal_imag.npy')); 92 | poling_coeffs = poling_coeffs_real + 1i*poling_coeffs_imag; 93 | poling_waist = 1e-5*readNPY(strcat(folder,'\parameters_crystal_effective_waists.npy')); % NOTE(review): stored waists are scaled by 1e-5 here; confirm the units 94 | 95 | % poling_coeffs_real = zeros(size(poling_coeffs_real)); 96 | % poling_coeffs_imag = poling_coeffs_real; 97 | % 98 | % poling_coeffs_real(1+2+1+(2*2+1)*0) = 1; 99 | % 100 | % poling_coeffs = poling_coeffs_real + 1i*poling_coeffs_imag; 101 | % poling_waist = 40e-6*ones(size(poling_waist)); 102 | 103 | 104 | MaxP = 2;%9; 105 | MaxL = 2;%4; 106 | 107 | PolingCoeffs = poling_coeffs; % NOTE(review): MaxP/MaxL above must match the layout of the stored crystal coefficients 108 | 109 | [X,Y] = meshgrid(x,y); 110 | Profile = 0; 111 | 112 | for p = 0:MaxP 113 | for ll = -MaxL:MaxL 114 | waist_curr =
poling_waist(ll+MaxL+1+(2*MaxL+1)*p); 115 | Profile = Profile + PolingCoeffs(ll+MaxL+1+(2*MaxL+1)*p).*LaguerreGauss(808e-9, 2, waist_curr,ll,p,X,Y,0); 116 | end 117 | end 118 | 119 | Profile = Profile / max(max(abs(Profile))); % peak |Profile| = 1, so asin below stays defined 120 | 121 | Magnitude = abs(Profile); 122 | phase = angle(Profile); 123 | dutycycle = asin(Magnitude)/pi; % local duty cycle in [0, 0.5] encoding the magnitude 124 | % figure; imagesc(real(Profile)); title('Real(poling)'); colorbar 125 | % figure; imagesc(imag(Profile)); title('Imag(poling)'); colorbar 126 | % figure; imagesc(Magnitude); title('abs(poling)'); colorbar 127 | % figure; imagesc(phase); title('phase(poling)'); colorbar 128 | 129 | %% 130 | Lambda = 9.87; %um 131 | 132 | Z = -3*Lambda/2:1:1.1*3*Lambda/2; %um 133 | 134 | 135 | DeltaK = 2*pi/Lambda; 136 | Poling = zeros(length(x),length(y), length(Z)); 137 | for i = 1:length(Z) 138 | 100*i/length(Z) % progress percentage, echoed because there is no semicolon 139 | z = Z(i); 140 | for m = 0:100 % truncated Fourier series of a rectangular wave with local duty cycle and phase 141 | if m == 0 142 | Poling(:,:,i) = Poling(:,:,i) + 2*dutycycle - 1; 143 | else 144 | Poling(:,:,i) = Poling(:,:,i) + (2/(m*pi)).*sin(pi*m*dutycycle).*2.*cos(m*DeltaK*z + m * phase); 145 | end 146 | end 147 | end 148 | end 149 | 150 | 151 | %% 152 | 153 | Draw3DPoling = 0; % set to 1 to draw one cube per positive voxel; needs the Poling array built above 154 | 155 | if Draw3DPoling 156 | s_pol = size(Poling); 157 | z_factor = 20; 158 | figure; axis image; axis off; 159 | for ii = 1:s_pol(1) 160 | 100*ii/s_pol(1) % progress percentage, echoed because there is no semicolon 161 | for jj = 1:s_pol(2) 162 | for kk = 1:s_pol(3) 163 | if sign(Poling(ii,jj,kk)) == 1 164 | plotcube([1 z_factor 1],[ii z_factor*(kk-1) + 1 jj],1,[102, 163, 255]/256); 165 | hold on; 166 | end 167 | end 168 | end 169 | end 170 | 171 | plotcube([241 length(Z)*z_factor 241],[1 1 1],.1,[17, 103, 177]/256); % 241 = length(x) for MaxX = 120e-6 and dx = 1e-6 172 | 173 | 174 | view(130,20); 175 | camlight('right') 176 | end 177 | 178 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repo contains the official implementation of the paper: 2 | 3 | # Inverse Design of Spontaneous Parametric Down
Conversion for Generation of High-Dimensional Qudits 4 | ![illustration](Optica.jpg) 5 | 6 | ## about 7 | (**If you use the code, please _cite our papers_.**) 8 | We have introduced an algorithm for solving the inverse design problem of generating structured and entangled photon pairs in quantum optics, using tailored nonlinear interactions in the SPDC process. The *SPDCinv* algorithm extracts the optimal physical parameters which yield a desired quantum state or correlations between structured photon-pairs, that can then be used in future experiments. To ensure convergence to realizable results and to improve the predictive accuracy, our algorithm obeyed physical constraints through the integration of the time-unfolded propagation dynamics governing the interaction of the SPDC Hamiltonian. 9 | 10 | ### in this version 11 | We have shown how we can apply our algorithm to obtain the optimal nonlinear $\chi^{(2)}$ volume holograms (2D/3D) as well as different pump structures for generating the desired maximally-entangled states. Using this version, one can further obtain all-optical coherent control over the generated quantum states by actively changing the profile of the pump beam. 12 | 13 | ## extensions 14 | This work can readily be extended to the spectral-temporal domain, by allowing non-periodic volume holograms along the propagation axis -- making it possible to shape the joint spectral amplitude of the photon pairs. Furthermore, one can easily adopt our approach for other optical systems, such as: nonlinear waveguides and resonators, $\chi^{(3)}$ effects (e.g. spontaneous four wave mixing), spatial solitons, fiber optics communication systems, and even higher-order coincidence probabilities. Moreover, the algorithm can be upgraded to include passive optical elements such as beam-splitters, holograms, and mode sorters, thereby providing greater flexibility for generating and manipulating quantum optical states. 
The *SPDCinv* algorithm can incorporate decoherence mechanisms arising from non-perturbative high-order photon pair generation in the high gain regime. Other decoherence effects due to losses such as absorption and scattering can be incorporated into the model in the future. Finally, the current scheme can be adapted to other quantum systems sharing a similar Hamiltonian structure, such as superfluids and superconductors. 15 | 16 | ## Data Availability 17 | Data underlying the results presented in this paper are available at `SPDCinv/Data availability/`. 18 | 19 | ## running the code 20 | To understand and set the interaction variables and the learning hyperparameters, see `src/spdc_inv/experiments/experiment.py` and read the documentation there. 21 | 22 | Before running the code, add the repository's `src` directory to your `PYTHONPATH` by running the following line in bash (replace `/home/jupyter/src` with the path to `src` in your checkout): 23 | `export PYTHONPATH="${PYTHONPATH}:/home/jupyter/src"` 24 | Then run `experiment.py`: 25 | `python src/spdc_inv/experiments/experiment.py` 26 | 27 | ## Giving Credit 28 | If you use this code in your work, please cite the associated papers. 29 | 30 | ``` 31 | @article{Rozenberg:22, 32 | author = {Eyal Rozenberg and Aviv Karnieli and Ofir Yesharim and Joshua Foley-Comer and Sivan Trajtenberg-Mills and Daniel Freedman and Alex M.
Bronstein and Ady Arie}, 33 | journal = {Optica}, 34 | keywords = {Computation methods; Four wave mixing; Nonlinear photonic crystals; Quantum information processing; Quantum key distribution; Quantum optics}, 35 | number = {6}, 36 | pages = {602--615}, 37 | publisher = {Optica Publishing Group}, 38 | title = {Inverse design of spontaneous parametric downconversion for generation of high-dimensional qudits}, 39 | volume = {9}, 40 | month = {Jun}, 41 | year = {2022}, 42 | url = {http://opg.optica.org/optica/abstract.cfm?URI=optica-9-6-602}, 43 | doi = {10.1364/OPTICA.451115}, 44 | abstract = {Spontaneous parametric downconversion (SPDC) in quantum optics is an invaluable resource for the realization of high-dimensional qudits with spatial modes of light. One of the main open challenges is how to directly generate a desirable qudit state in the SPDC process. This problem can be addressed through advanced computational learning methods; however, due to difficulties in modeling the SPDC process by a fully differentiable algorithm, progress has been limited. Here, we overcome these limitations and introduce a physically constrained and differentiable model, validated against experimental results for shaped pump beams and structured crystals, capable of learning the relevant interaction parameters in the process. We avoid any restrictions induced by the stochastic nature of our physical model and integrate the dynamic equations governing the evolution under the SPDC Hamiltonian. We solve the inverse problem of designing a nonlinear quantum optical system that achieves the desired quantum state of downconverted photon pairs. The desired states are defined using either the second-order correlations between different spatial modes or by specifying the required density matrix. By learning nonlinear photonic crystal structures as well as different pump shapes, we successfully show how to generate maximally entangled states. 
Furthermore, we simulate all-optical coherent control over the generated quantum state by actively changing the profile of the pump beam. Our work can be useful for applications such as novel designs of high-dimensional quantum key distribution and quantum information processing protocols. In addition, our method can be readily applied for controlling other degrees of freedom of light in the SPDC process, such as spectral and temporal properties, and may even be used in condensed-matter systems having a similar interaction Hamiltonian.}, 45 | } 46 | ``` 47 | 48 | ``` 49 | @inproceedings{Rozenberg:21, 50 | author = {Eyal Rozenberg and Aviv Karnieli and Ofir Yesharim and Sivan Trajtenberg-Mills and Daniel Freedman and Alex M. Bronstein and Ady Arie}, 51 | booktitle = {Conference on Lasers and Electro-Optics}, 52 | journal = {Conference on Lasers and Electro-Optics}, 53 | keywords = {Light matter interactions; Nonlinear optical crystals; Nonlinear photonic crystals; Photonic crystals; Quantum communications; Quantum optics}, 54 | pages = {FM1N.7}, 55 | publisher = {Optica Publishing Group}, 56 | title = {Inverse Design of Quantum Holograms in Three-Dimensional Nonlinear Photonic Crystals}, 57 | year = {2021}, 58 | url = {http://opg.optica.org/abstract.cfm?URI=CLEO_QELS-2021-FM1N.7}, 59 | doi = {10.1364/CLEO_QELS.2021.FM1N.7}, 60 | abstract = {We introduce a systematic approach for designing 3D nonlinear photonic crystals and pump beams for generating desired quantum correlations between structured photon-pairs. 
Our model is fully differentiable, allowing accurate and efficient learning and discovery of novel designs.}, 61 | } 62 | ``` 63 | -------------------------------------------------------------------------------- /src/spdc_inv/models/utils.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from jax import jit 3 | from spdc_inv.utils.utils import h_bar, eps0, c 4 | from spdc_inv.utils.utils import PP_crystal_slab 5 | 6 | import jax.numpy as np 7 | 8 | 9 | class Field(ABC): 10 | """ 11 | A class that holds everything to do with the interaction values of a given beam 12 | vac - corresponding vacuum state coefficient 13 | kappa - coupling constant 14 | k - wave vector 15 | """ 16 | def __init__( 17 | self, 18 | beam, 19 | dx, 20 | dy, 21 | maxZ 22 | ): 23 | """ 24 | 25 | Parameters 26 | ---------- 27 | beam: A class that holds everything to do with a beam 28 | dx: transverse resolution in x [m] 29 | dy: transverse resolution in y [m] 30 | maxZ: Crystal's length in z [m] 31 | """ 32 | 33 | self.vac = np.sqrt(h_bar * beam.w / (2 * eps0 * beam.n ** 2 * dx * dy * maxZ))  # normalized by the grid-cell volume dx * dy * maxZ 34 | self.kappa = 2 * 1j * beam.w ** 2 / (beam.k * c ** 2)  # imaginary prefactor used in propagate_dz's coupled-wave updates 35 | self.k = beam.k 36 | 37 | 38 | def crystal_prop( 39 | pump_profile, 40 | pump, 41 | signal_field, 42 | idler_field, 43 | vacuum_states, 44 | interaction, 45 | poling_period, 46 | N, 47 | crystal_hologram, 48 | infer=None, 49 | signal_init=None, 50 | idler_init=None 51 | ): 52 | """ 53 | Crystal propagation 54 | propagate through crystal using split step Fourier for 4 fields: signal, idler and two vacuum states 55 | 56 | Parameters 57 | ---------- 58 | pump_profile: electromagnetic pump beam profile 59 | pump: A class that holds everything to do with the pump beam 60 | signal_field: A class that holds everything to do with the interaction values of the signal beam 61 | idler_field: A class that holds everything to do with the interaction values of the idler beam 62 | vacuum_states: The
vacuum and interaction fields 63 | interaction: A class that represents the SPDC interaction process, on all of its physical parameters. 64 | poling_period: Poling period (dk_offset * delta_k) 65 | N: number of vacuum_state elements 66 | crystal_hologram: 3D crystal hologram 67 | infer: (True/False) if in inference mode, we include more coefficients in the poling 68 | description for better validation 69 | signal_init: initial signal profile. If None, initiate to zero 70 | idler_init: initial idler profile. If None, initiate to zero 71 | 72 | Returns: signal_out, idler_out and idler_vac at the end of the interaction medium 73 | ------- 74 | 75 | """ 76 | 77 | x = interaction.x 78 | y = interaction.y 79 | Nx = interaction.Nx 80 | Ny = interaction.Ny 81 | dz = interaction.dz 82 | 83 | 84 | if signal_init is None: 85 | signal_out = np.zeros([N, Nx, Ny])  # no seed: one zero (Nx, Ny) plane per vacuum-state element 86 | else: 87 | assert len(signal_init.shape) == 3 88 | assert signal_init.shape[0] == N 89 | assert signal_init.shape[1] == Nx 90 | assert signal_init.shape[2] == Ny 91 | signal_out = signal_init 92 | 93 | if idler_init is None: 94 | idler_out = np.zeros([N, Nx, Ny])  # no seed: one zero (Nx, Ny) plane per vacuum-state element 95 | else: 96 | assert len(idler_init.shape) == 3 97 | assert idler_init.shape[0] == N 98 | assert idler_init.shape[1] == Nx 99 | assert idler_init.shape[2] == Ny 100 | idler_out = idler_init 101 | 102 | 103 | signal_vac = signal_field.vac * (vacuum_states[:, 0, 0] + 1j * vacuum_states[:, 0, 1]) / np.sqrt(2)  # complex amplitudes from vacuum_states[:, 0] (real, imag) 104 | idler_vac = idler_field.vac * (vacuum_states[:, 1, 0] + 1j * vacuum_states[:, 1, 1]) / np.sqrt(2)  # complex amplitudes from vacuum_states[:, 1] (real, imag) 105 | 106 | for z in interaction.z:  # one split step per longitudinal sample 107 | signal_out, signal_vac, idler_out, idler_vac = propagate_dz( 108 | pump_profile, 109 | x, 110 | y, 111 | z, 112 | dz, 113 | pump.k, 114 | signal_field.k, 115 | idler_field.k, 116 | signal_field.kappa, 117 | idler_field.kappa, 118 | poling_period, 119 | crystal_hologram, 120 | interaction.d33, 121 | signal_out, 122 | signal_vac, 123 | idler_out, 124 | idler_vac, 125 | infer 126 | ) 127 | 128 | return signal_out, idler_out,
idler_vac  # NOTE: signal_vac is propagated but not returned 129 | 130 | 131 | @jit 132 | def propagate_dz( 133 | pump_profile, 134 | x, 135 | y, 136 | z, 137 | dz, 138 | pump_k, 139 | signal_field_k, 140 | idler_field_k, 141 | signal_field_kappa, 142 | idler_field_kappa, 143 | poling_period, 144 | crystal_hologram, 145 | interaction_d33, 146 | signal_out, 147 | signal_vac, 148 | idler_out, 149 | idler_vac, 150 | infer=None, 151 | ): 152 | """ 153 | Single step of crystal propagation 154 | single split step Fourier for 4 fields: signal, idler and two vacuum states 155 | 156 | Parameters 157 | ---------- 158 | pump_profile: electromagnetic pump beam profile at the crystal entrance 159 | x: x axis, length 2*MaxX (transverse) 160 | y: y axis, length 2*MaxY (transverse) 161 | z: current longitudinal position in the crystal [m] 162 | dz: longitudinal resolution in z [m] 163 | pump_k: pump k vector 164 | signal_field_k: signal k vector 165 | idler_field_k: idler k vector 166 | signal_field_kappa: signal kappa 167 | idler_field_kappa: idler kappa 168 | poling_period: poling period 169 | crystal_hologram: Crystal 3D hologram (if None, ignore) 170 | interaction_d33: nonlinear coefficient [meter/Volt] 171 | signal_out: current signal profile 172 | signal_vac: current signal vacuum state profile 173 | idler_out: current idler profile 174 | idler_vac: current idler vacuum state profile 175 | infer: (True/False) if in inference mode, we include more coefficients in the poling 176 | description for better validation (currently not forwarded to PP_crystal_slab) 177 | 178 | Returns: the updated (signal_out, signal_vac, idler_out, idler_vac) after one dz step 179 | ------- 180 | 181 | """ 182 | 183 | # pump beam: 184 | E_pump = propagate(pump_profile, x, y, pump_k, z) * np.exp(-1j * pump_k * z)  # pump at plane z, including the exp(-i k_p z) carrier phase 185 | 186 | # crystal slab: 187 | PP = PP_crystal_slab(poling_period, z, crystal_hologram, inference=None)  # NOTE(review): infer is not forwarded (inference=None is hard-coded) - confirm intended 188 | 189 | # coupled wave equations - split step 190 | # signal: 191 | dEs_out_dz = signal_field_kappa * interaction_d33 * PP * E_pump * np.conj(idler_vac) 192 | dEs_vac_dz = signal_field_kappa * interaction_d33 * PP * E_pump * np.conj(idler_out) 193 | signal_out = signal_out + dEs_out_dz * dz 194 | signal_vac = signal_vac +
dEs_vac_dz * dz 195 | 196 | # idler: 197 | dEi_out_dz = idler_field_kappa * interaction_d33 * PP * E_pump * np.conj(signal_vac) 198 | dEi_vac_dz = idler_field_kappa * interaction_d33 * PP * E_pump * np.conj(signal_out) 199 | idler_out = idler_out + dEi_out_dz * dz 200 | idler_vac = idler_vac + dEi_vac_dz * dz 201 | 202 | # propagate 203 | signal_out = propagate(signal_out, x, y, signal_field_k, dz) * np.exp(-1j * signal_field_k * dz) 204 | signal_vac = propagate(signal_vac, x, y, signal_field_k, dz) * np.exp(-1j * signal_field_k * dz) 205 | idler_out = propagate(idler_out, x, y, idler_field_k, dz) * np.exp(-1j * idler_field_k * dz) 206 | idler_vac = propagate(idler_vac, x, y, idler_field_k, dz) * np.exp(-1j * idler_field_k * dz) 207 | 208 | return signal_out, signal_vac, idler_out, idler_vac 209 | 210 | 211 | @jit 212 | def propagate(A, x, y, k, dz): 213 | """ 214 | Free Space propagation using the free space transfer function, 215 | (two dimensional), according to Saleh 216 | Using CGS, or MKS, Boyd 2nd eddition 217 | 218 | Parameters 219 | ---------- 220 | A: electromagnetic beam profile 221 | x,y: spatial vectors 222 | k: wave vector 223 | dz: The distance to propagate 224 | 225 | Returns the propagated field 226 | ------- 227 | 228 | """ 229 | dx = np.abs(x[1] - x[0]) 230 | dy = np.abs(y[1] - y[0]) 231 | 232 | # define the fourier vectors 233 | X, Y = np.meshgrid(x, y, indexing='ij') 234 | KX = 2 * np.pi * (X / dx) / (np.size(X, 1) * dx) 235 | KY = 2 * np.pi * (Y / dy) / (np.size(Y, 1) * dy) 236 | 237 | # The Free space transfer function of propagation, using the Fresnel approximation 238 | # (from "Engineering optics with matlab"/ing-ChungPoon&TaegeunKim): 239 | H_w = np.exp(-1j * dz * (np.square(KX) + np.square(KY)) / (2 * k)) 240 | H_w = np.fft.ifftshift(H_w) 241 | 242 | # Fourier Transform: move to k-space 243 | G = np.fft.fft2(A) # The two-dimensional discrete Fourier transform (DFT) of A. 
244 | 245 | # propagate in the fourier space 246 | F = np.multiply(G, H_w) 247 | 248 | # inverse Fourier Transform: go back to real space 249 | Eout = np.fft.ifft2(F) # [in real space]. Eout is the two-dimensional INVERSE discrete Fourier transform (DFT) of F 250 | 251 | return Eout 252 | -------------------------------------------------------------------------------- /src/spdc_inv/training/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import jax.numpy as np 3 | from jax import lax 4 | from jax import jit 5 | from spdc_inv.utils.defaults import qubit_projection_n_state2, \ 6 | qubit_tomography_dimensions, qutrit_projection_n_state2, qutrit_tomography_dimensions 7 | import math 8 | 9 | 10 | @jit 11 | def project(projection_basis, beam_profile): 12 | """ 13 | The function projects some state beam_profile onto given projection_basis 14 | Parameters 15 | ---------- 16 | projection_basis: array of basis function 17 | beam_profile: beam profile (2d) 18 | 19 | Returns 20 | ------- 21 | 22 | """ 23 | Nxx2 = beam_profile.shape[1] ** 2  # assumes a square transverse grid (shape[1] == shape[2]) - TODO confirm 24 | N = beam_profile.shape[0] 25 | Nh = projection_basis.shape[0] 26 | projection = (np.conj(projection_basis) * beam_profile).reshape(Nh, N, Nxx2).sum(2)  # overlaps, shape (Nh, N) 27 | normalization1 = np.abs(beam_profile ** 2).reshape(N, Nxx2).sum(1) 28 | normalization2 = np.abs(projection_basis ** 2).reshape(Nh, Nxx2).sum(1) 29 | projection = projection / np.sqrt(normalization1[None, :] * normalization2[:, None])  # normalize by the norms of both profiles 30 | return projection 31 | 32 | 33 | @jit 34 | def decompose(beam_profile, projection_basis_arr): 35 | """ 36 | Decompose a given beam profile into modes defined in the dictionary 37 | Parameters 38 | ---------- 39 | beam_profile: beam profile (2d) 40 | projection_basis_arr: array of basis function 41 | 42 | Returns: beam profile as a decomposition of basis functions 43 | ------- 44 | 45 | """ 46 | projection = project(projection_basis_arr[:, None], beam_profile)  # singleton axis so each basis mode broadcasts against every beam profile 47 | return
np.transpose(projection)  # shape (N, Nh) 48 | 49 | 50 | @jit 51 | def fix_power(decomposed_profile, beam_profile): 52 | """ 53 | Normalize power and ignore higher modes 54 | Parameters 55 | ---------- 56 | decomposed_profile: the decomposed beam profile 57 | beam_profile: the original beam profile 58 | 59 | Returns a normalized decomposed profile 60 | ------- 61 | 62 | """ 63 | scale = np.sqrt(  # sqrt of the power ratio (original / decomposed), per profile 64 | np.sum(beam_profile * np.conj(beam_profile), (1, 2))) / np.sqrt( 65 | np.sum(decomposed_profile * np.conj(decomposed_profile), (1, 2))) 66 | 67 | return decomposed_profile * scale[:, None, None] 68 | 69 | 70 | @jit 71 | def kron(a, b, multiple_devices: bool = False): 72 | """ 73 | Calculates the kronecker (outer) product of a and b for every index along axis 0, summed over that axis 74 | Parameters 75 | ---------- 76 | a, b: 3d tensors (a stack of 2d matrices along axis 0) 77 | multiple_devices: (True/False) whether multiple devices are used 78 | 79 | Returns an array of shape (a.shape[1], b.shape[1], a.shape[2], b.shape[2]) 80 | ------- 81 | 82 | """ 83 | if multiple_devices:  # NOTE(review): a Python `if` on a traced argument fails under @jit unless it is static - confirm 84 | return lax.psum((a[:, :, None, :, None] * b[:, None, :, None, :]).sum(0), 'device') 85 | 86 | else: 87 | return (a[:, :, None, :, None] * b[:, None, :, None, :]).sum(0) 88 | 89 | 90 | @jit 91 | def projection_matrices_calc(a, b, c, N): 92 | """ 93 | 94 | Parameters 95 | ---------- 96 | a, b, c: the interacting fields (presumably signal_out, idler_out and idler_vac from crystal_prop - confirm) 97 | N: Total number of interacting vacuum state elements 98 | 99 | Returns the projective matrices 100 | ------- 101 | 102 | """ 103 | G1_ss = kron(np.conj(a), a) / N 104 | G1_ii = kron(np.conj(b), b) / N 105 | G1_si = kron(np.conj(b), a) / N 106 | G1_si_dagger = kron(np.conj(a), b) / N 107 | Q_si = kron(c, a) / N 108 | Q_si_dagger = kron(np.conj(a), np.conj(c)) / N 109 | 110 | return G1_ss, G1_ii, G1_si, G1_si_dagger, Q_si, Q_si_dagger 111 | 112 | 113 | @jit 114 | def projection_matrix_calc(G1_ss, G1_ii, G1_si, G1_si_dagger, Q_si, Q_si_dagger): 115 | """ 116 | 117 | Parameters 118 | ---------- 119 | G1_ss, G1_ii, G1_si, G1_si_dagger, Q_si, Q_si_dagger: the projective matrices 120 | Returns the 2nd order projective matrix
121 | ------- 122 | 123 | """ 124 | return (lax.psum(G1_ii, 'device') *  # lax.psum needs a mapped axis named 'device' (e.g. inside pmap) 125 | lax.psum(G1_ss, 'device') + 126 | lax.psum(Q_si_dagger, 'device') * 127 | lax.psum(Q_si, 'device') + 128 | lax.psum(G1_si_dagger, 'device') * 129 | lax.psum(G1_si, 'device') 130 | ).real 131 | 132 | 133 | # for coupling inefficiencies 134 | @jit 135 | def coupling_inefficiency_calc_G2(  # NOTE(review): max_mode_l feeds range(), so it must not be traced under @jit - confirm 136 | lam, 137 | SMF_waist, 138 | max_mode_l: int = 4, 139 | focal_length: float = 4.6e-3, 140 | SMF_mode_diam: float = 2.5e-6, 141 | ): 142 | waist = 46.07 * SMF_waist  # NOTE(review): 46.07 is an undocumented scale factor - record its origin 143 | a_0 = np.sqrt(2) * lam * focal_length / (np.pi * waist) 144 | A = 2 / (1 + (SMF_mode_diam ** 2 / a_0 ** 2)) 145 | B = 2 / (1 + (a_0 ** 2 / SMF_mode_diam ** 2)) 146 | inef_coeff = np.zeros([2 * max_mode_l + 1, 2 * max_mode_l + 1]) 147 | 148 | for l_i in range(-max_mode_l, max_mode_l + 1): 149 | inef_coeff_i = (math.factorial(abs(l_i)) ** 2) * (A ** (2 * abs(l_i) + 1) * B) / (math.factorial(2 * abs(l_i)))  # (|l|!)^2 * A^(2|l|+1) * B / (2|l|)! 150 | for l_s in range(-max_mode_l, max_mode_l + 1): 151 | inef_coeff_s = (math.factorial(abs(l_s)) ** 2) * (A ** (2 * abs(l_s) + 1) * B) / ( 152 | math.factorial(2 * abs(l_s))) 153 | inef_coeff = inef_coeff.at[l_i + max_mode_l, l_s + max_mode_l].set((inef_coeff_i + inef_coeff_s)) 154 | 155 | return inef_coeff.reshape(1, (2 * max_mode_l + 1) ** 2) 156 | 157 | 158 | @jit 159 | def coupling_inefficiency_calc_tomo( 160 | lam, 161 | SMF_waist, 162 | focal_length: float = 4.6e-3, 163 | SMF_mode_diam: float = 2.5e-6, 164 | ): 165 | waist = 46.07 * SMF_waist 166 | a_0 = np.sqrt(2) * lam * focal_length / (np.pi * waist) 167 | A = 2 / (1 + (SMF_mode_diam ** 2 / a_0 ** 2)) 168 | B = 2 / (1 + (a_0 ** 2 / SMF_mode_diam ** 2)) 169 | inef_coeff = np.zeros([qutrit_projection_n_state2, qutrit_projection_n_state2]) 170 | 171 | for base_1 in range(qutrit_projection_n_state2): 172 | # azimuthal modes l = {-1, 0, 1} defined according to order of MUBs 173 | if base_1 == 0 or base_1 == 2: 174 | l_1 = 1 175 | inef_coeff_i = (math.factorial(abs(l_1)) ** 2) *
(A ** (2 * abs(l_1) + 1) * B) / ( 176 | math.factorial(2 * abs(l_1))) 177 | elif base_1 == 1: 178 | l_1 = 0 179 | inef_coeff_i = (math.factorial(abs(l_1)) ** 2) * (A ** (2 * abs(l_1) + 1) * B) / ( 180 | math.factorial(2 * abs(l_1))) 181 | 182 | else: 183 | if base_1 < 7 or base_1 > 10: 184 | l_1, l_2 = 1, 0 185 | else: 186 | l_1, l_2 = 1, 1 187 | inef_coeff_1 = (math.factorial(abs(l_1)) ** 2) * (A ** (2 * abs(l_1) + 1) * B) / ( 188 | math.factorial(2 * abs(l_1))) 189 | inef_coeff_2 = (math.factorial(abs(l_2)) ** 2) * (A ** (2 * abs(l_2) + 1) * B) / ( 190 | math.factorial(2 * abs(l_2))) 191 | inef_coeff_i = 0.5 * (inef_coeff_1 + inef_coeff_2)  # superposition state: average the coupling of its two modes 192 | 193 | for base_2 in range(qutrit_projection_n_state2): 194 | if base_2 == 0 or base_2 == 2: 195 | l_1 = 1 196 | inef_coeff_s = (math.factorial(abs(l_1)) ** 2) * (A ** (2 * abs(l_1) + 1) * B) / ( 197 | math.factorial(2 * abs(l_1))) 198 | elif base_2 == 1: 199 | l_1 = 0 200 | inef_coeff_s = (math.factorial(abs(l_1)) ** 2) * (A ** (2 * abs(l_1) + 1) * B) / ( 201 | math.factorial(2 * abs(l_1))) 202 | 203 | else: 204 | if base_2 < 7 or base_2 > 10: 205 | l_1, l_2 = 1, 0 206 | else: 207 | l_1, l_2 = 1, 1 208 | inef_coeff_1 = (math.factorial(abs(l_1)) ** 2) * (A ** (2 * abs(l_1) + 1) * B) / ( 209 | math.factorial(2 * abs(l_1))) 210 | inef_coeff_2 = (math.factorial(abs(l_2)) ** 2) * (A ** (2 * abs(l_2) + 1) * B) / ( 211 | math.factorial(2 * abs(l_2))) 212 | inef_coeff_s = 0.5 * (inef_coeff_1 + inef_coeff_2)  # superposition state: average the coupling of its two modes 213 | 214 | inef_coeff = inef_coeff.at[base_1, base_2].set((inef_coeff_i + inef_coeff_s)) 215 | 216 | return inef_coeff.reshape(1, qutrit_projection_n_state2 ** 2) 217 | 218 | 219 | @jit 220 | def get_qubit_density_matrix( 221 | tomography_matrix, 222 | masks, 223 | rotation_mats 224 | ): 225 | 226 | tomography_matrix = tomography_matrix.reshape(qubit_projection_n_state2, qubit_projection_n_state2) 227 | 228 | dens_mat = (1 / (qubit_tomography_dimensions ** 2)) * (tomography_matrix * masks).sum(1).sum(1).reshape( 229 |
qubit_tomography_dimensions ** 4, 1, 1) 230 | dens_mat = (dens_mat * rotation_mats)  # weight each rotation matrix by its masked tomography sum 231 | dens_mat = dens_mat.sum(0) 232 | 233 | return dens_mat 234 | 235 | 236 | @jit 237 | def get_qutrit_density_matrix( 238 | tomography_matrix, 239 | masks, 240 | rotation_mats 241 | ): 242 | 243 | tomography_matrix = tomography_matrix.reshape(qutrit_projection_n_state2, qutrit_projection_n_state2) 244 | 245 | dens_mat = (1 / (qutrit_tomography_dimensions ** 2)) * (tomography_matrix * masks).sum(1).sum(1).reshape( 246 | qutrit_tomography_dimensions ** 4, 1, 1) 247 | dens_mat = (dens_mat * rotation_mats)  # weight each rotation matrix by its masked tomography sum 248 | dens_mat = dens_mat.sum(0) 249 | 250 | return dens_mat 251 | -------------------------------------------------------------------------------- /src/spdc_inv/experiments/utils.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from jax import numpy as np 3 | from typing import Tuple, Dict, Any, List, Union 4 | from spdc_inv.utils.utils import ( 5 | HermiteBank, LaguerreBank, TomographyBankLG, TomographyBankHG 6 | ) 7 | from spdc_inv.utils.defaults import QUBIT, QUTRIT 8 | from spdc_inv.utils.defaults import qubit_projection_n_state2, \ 9 | qubit_tomography_dimensions, qutrit_projection_n_state2, qutrit_tomography_dimensions 10 | 11 | 12 | class Projection_coincidence_rate(ABC): 13 | """ 14 | A class that represents the projective basis for 15 | calculating the coincidence rate observable of the interaction.
16 | """ 17 | 18 | def __init__( 19 | self, 20 | calculate_observable: Tuple[Dict[Any, bool], ...], 21 | waist_pump0: float, 22 | signal_wavelength: float, 23 | crystal_x: np.array, 24 | crystal_y: np.array, 25 | temperature: float, 26 | ctype, 27 | polarization: str, 28 | z: float = 0., 29 | projection_basis: str = 'LG', 30 | max_mode1: int = 1, 31 | max_mode2: int = 4, 32 | waist: float = None, 33 | wavelength: float = None, 34 | tau: float = 1e-9, 35 | SMF_waist: float = None, 36 | 37 | ): 38 | """ 39 | 40 | Parameters 41 | ---------- 42 | calculate_observable: True/False, will the observable be calculated in simulation 43 | waist_pump0: pump waists at the center of the crystal (initial-before training) 44 | signal_wavelength: signal wavelength at spdc interaction 45 | crystal_x: x axis linspace array (transverse) 46 | crystal_y: y axis linspace array (transverse) 47 | temperature: interaction temperature 48 | ctype: refractive index function 49 | polarization: polarization for calculating effective refractive index 50 | z: projection longitudinal position 51 | projection_basis: type of projection basis 52 | Can be: LG (Laguerre-Gauss) / HG (Hermite-Gauss) 53 | max_mode1: Maximum value of first mode of the 2D projection basis 54 | max_mode2: Maximum value of second mode of the 2D projection basis 55 | waist: waists of the projection basis functions 56 | wavelength: wavelength for generating projection basis 57 | tau: coincidence window [sec] (default 1e-9, i.e. 1 ns) 58 | SMF_waist: signal/idler beam radius at single mode fibre 59 | """ 60 | 61 | self.tau = tau 62 | 63 | if waist is None: 64 | self.waist = np.sqrt(2) * waist_pump0  # default projection waist: sqrt(2) times the initial pump waist 65 | else: 66 | self.waist = waist 67 | 68 | if wavelength is None: 69 | wavelength = signal_wavelength 70 | 71 | assert projection_basis.lower() in ['lg', 'hg'], 'The projection basis is LG or HG ' \ 72 | 'basis functions only' 73 | 74 | self.projection_basis = projection_basis 75 | self.max_mode1 = max_mode1 76 | self.max_mode2 = max_mode2 77 | 78 | #
number of modes for projection basis 79 | if projection_basis.lower() == 'lg': 80 | self.projection_n_modes1 = max_mode1 81 | self.projection_n_modes2 = 2 * max_mode2 + 1  # 2*max_mode2+1 values, presumably l = -max_mode2..max_mode2 82 | else: 83 | self.projection_n_modes1 = max_mode1 84 | self.projection_n_modes2 = max_mode2 85 | 86 | # Total number of projection modes 87 | self.projection_n_modes = self.projection_n_modes1 * self.projection_n_modes2 88 | 89 | refractive_index = ctype(wavelength * 1e6, temperature, polarization)  # passes wavelength * 1e6, presumably meters -> micrometers 90 | [x, y] = np.meshgrid(crystal_x, crystal_y) 91 | 92 | self.SMF_waist = SMF_waist 93 | 94 | if calculate_observable: 95 | if projection_basis.lower() == 'lg': 96 | self.basis_arr, self.basis_str = \ 97 | LaguerreBank( 98 | wavelength, 99 | refractive_index, 100 | self.waist, 101 | self.max_mode1, 102 | self.max_mode2, 103 | x, y, z) 104 | else: 105 | self.basis_arr, self.basis_str = \ 106 | HermiteBank( 107 | wavelength, 108 | refractive_index, 109 | self.waist, 110 | self.max_mode1, 111 | self.max_mode2, 112 | x, y, z) 113 | 114 | 115 | class Projection_tomography_matrix(ABC): 116 | """ 117 | A class that represents the projective basis for 118 | calculating the tomography matrix & density matrix observable of the interaction.
119 | """ 120 | 121 | def __init__( 122 | self, 123 | calculate_observable: Tuple[Dict[Any, bool], ...], 124 | waist_pump0: float, 125 | signal_wavelength: float, 126 | crystal_x: np.array, 127 | crystal_y: np.array, 128 | temperature: float, 129 | ctype, 130 | polarization: str, 131 | z: float = 0., 132 | projection_basis: str = 'LG', 133 | max_mode1: int = 1, 134 | max_mode2: int = 1, 135 | waist: float = None, 136 | wavelength: float = None, 137 | tau: float = 1e-9, 138 | relative_phase: List[Union[Union[int, float], Any]] = None, 139 | tomography_quantum_state: str = None, 140 | 141 | ): 142 | """ 143 | 144 | Parameters 145 | ---------- 146 | calculate_observable: True/False, will the observable be calculated in simulation 147 | waist_pump0: pump waists at the center of the crystal (initial-before training) 148 | signal_wavelength: signal wavelength at spdc interaction 149 | crystal_x: x axis linspace array (transverse) 150 | crystal_y: y axis linspace array (transverse) 151 | temperature: interaction temperature 152 | ctype: refractive index function 153 | polarization: polarization for calculating effective refractive index 154 | z: projection longitudinal position 155 | projection_basis: type of projection basis 156 | Can be: LG (Laguerre-Gauss) / HG (Hermite-Gauss) 157 | max_mode1: Maximum value of first mode of the 2D projection basis 158 | max_mode2: Maximum value of second mode of the 2D projection basis 159 | waist: waists of the projection basis functions 160 | wavelength: wavelength for generating projection basis 161 | tau: coincidence window [sec] (default 1e-9, i.e. 1 ns) 162 | relative_phase: The relative phase between the mutually unbiased bases (MUBs) states 163 | tomography_quantum_state: the quantum state whose tomography matrix we calculate.
164 | currently we support: qubit/qutrit 165 | """ 166 | 167 | self.tau = tau 168 | 169 | if waist is None: 170 | self.waist = np.sqrt(2) * waist_pump0 171 | else: 172 | self.waist = waist 173 | 174 | if wavelength is None: 175 | wavelength = signal_wavelength 176 | 177 | assert projection_basis.lower() in ['lg', 'hg'], 'The projection basis is LG or HG' \ 178 | 'basis functions only' 179 | 180 | assert max_mode1 == 1, 'for Tomography projections, max_mode1 must be 1' 181 | assert max_mode2 == 1, 'for Tomography projections, max_mode2 must be 1' 182 | 183 | self.projection_basis = projection_basis 184 | self.max_mode1 = max_mode1 185 | self.max_mode2 = max_mode2 186 | 187 | assert tomography_quantum_state in [QUBIT, QUTRIT], f'quantum state must be {QUBIT} or {QUTRIT}, ' \ 188 | 'but received {tomography_quantum_state}' 189 | self.tomography_quantum_state = tomography_quantum_state 190 | self.relative_phase = relative_phase 191 | 192 | self.projection_n_state1 = 1 193 | if self.tomography_quantum_state is QUBIT: 194 | self.projection_n_state2 = qubit_projection_n_state2 195 | self.tomography_dimensions = qubit_tomography_dimensions 196 | else: 197 | self.projection_n_state2 = qutrit_projection_n_state2 198 | self.tomography_dimensions = qutrit_tomography_dimensions 199 | 200 | refractive_index = ctype(wavelength * 1e6, temperature, polarization) 201 | [x, y] = np.meshgrid(crystal_x, crystal_y) 202 | if calculate_observable: 203 | if self.projection_basis == 'lg': 204 | self.basis_arr, self.basis_str = \ 205 | TomographyBankLG( 206 | wavelength, 207 | refractive_index, 208 | self.waist, 209 | self.max_mode1, 210 | self.max_mode2, 211 | x, y, z, 212 | self.relative_phase, 213 | self.tomography_quantum_state 214 | ) 215 | else: 216 | self.basis_arr, self.basis_str = \ 217 | TomographyBankHG( 218 | wavelength, 219 | refractive_index, 220 | self.waist, 221 | self.max_mode1, 222 | self.max_mode2, 223 | x, y, z, 224 | self.relative_phase, 225 | 
self.tomography_quantum_state 226 | ) 227 | -------------------------------------------------------------------------------- /src/spdc_inv/experiments/results_and_stats_utils.py: -------------------------------------------------------------------------------- 1 | from spdc_inv.utils.defaults import COINCIDENCE_RATE, DENSITY_MATRIX, TOMOGRAPHY_MATRIX 2 | from spdc_inv.utils.utils import G1_Normalization 3 | from spdc_inv import RES_DIR 4 | from jax import numpy as np 5 | 6 | import os 7 | import shutil 8 | import numpy as onp 9 | import matplotlib.pyplot as plt 10 | 11 | 12 | def save_training_statistics( 13 | logs_dir, 14 | fit_results, 15 | interaction, 16 | model_parameters, 17 | ): 18 | if fit_results is not None: 19 | loss_trn, best_loss = fit_results 20 | 21 | pump_coeffs_real, \ 22 | pump_coeffs_imag, \ 23 | waist_pump, \ 24 | crystal_coeffs_real, \ 25 | crystal_coeffs_imag, \ 26 | r_scale = model_parameters 27 | 28 | pump = open(os.path.join(logs_dir, 'pump.txt'), 'w') 29 | pump.write( 30 | type_coeffs_to_txt( 31 | interaction.pump_basis, 32 | interaction.pump_max_mode1, 33 | interaction.pump_max_mode2, 34 | pump_coeffs_real[0] if pump_coeffs_real is not None 35 | else interaction.initial_pump_coefficients()[0], 36 | pump_coeffs_imag[0] if pump_coeffs_imag is not None 37 | else interaction.initial_pump_coefficients()[1], 38 | waist_pump[0] if waist_pump is not None 39 | else interaction.initial_pump_waists(), 40 | ) 41 | ) 42 | 43 | if interaction.crystal_basis: 44 | 45 | crystal = open(os.path.join(logs_dir, 'crystal.txt'), 'w') 46 | crystal.write( 47 | type_coeffs_to_txt( 48 | interaction.crystal_basis, 49 | interaction.crystal_max_mode1, 50 | interaction.crystal_max_mode2, 51 | crystal_coeffs_real[0] if crystal_coeffs_real is not None 52 | else interaction.initial_crystal_coefficients()[0], 53 | crystal_coeffs_imag[0] if crystal_coeffs_imag is not None 54 | else interaction.initial_crystal_coefficients()[1], 55 | r_scale[0] if r_scale is not None 56 | 
else interaction.initial_crystal_waists(), 57 | ) 58 | ) 59 | 60 | if fit_results is not None: 61 | # print loss 62 | plt.plot(loss_trn, 'r', label='training') 63 | plt.ylabel('objective loss') 64 | plt.xlabel('#epoch') 65 | # plt.ylim(0.2, 1) 66 | plt.axhline(y=best_loss, color='gray', linestyle='--') 67 | plt.text(2, best_loss, f'best = {best_loss}', rotation=0, horizontalalignment='left', 68 | verticalalignment='top', multialignment='center') 69 | plt.legend() 70 | plt.savefig(os.path.join(logs_dir, 'loss')) 71 | plt.close() 72 | 73 | np.save(os.path.join(logs_dir, 'parameters_pump_real.npy'), 74 | pump_coeffs_real[0] if pump_coeffs_real is not None 75 | else interaction.initial_pump_coefficients()[0]) 76 | np.save(os.path.join(logs_dir, 'parameters_pump_imag.npy'), 77 | pump_coeffs_imag[0] if pump_coeffs_imag is not None 78 | else interaction.initial_pump_coefficients()[1]) 79 | np.save(os.path.join(logs_dir, 'parameters_pump_waists.npy'), 80 | waist_pump[0] if waist_pump is not None 81 | else interaction.initial_pump_waists()) 82 | if interaction.crystal_basis is not None: 83 | np.save(os.path.join(logs_dir, 'parameters_crystal_real.npy'), 84 | crystal_coeffs_real[0] if crystal_coeffs_real is not None 85 | else interaction.initial_crystal_coefficients()[0]) 86 | np.save(os.path.join(logs_dir, 'parameters_crystal_imag.npy'), 87 | crystal_coeffs_imag[0] if crystal_coeffs_imag is not None 88 | else interaction.initial_crystal_coefficients()[1]) 89 | np.save(os.path.join(logs_dir, 'parameters_crystal_effective_waists.npy'), 90 | r_scale[0] if r_scale is not None 91 | else interaction.initial_crystal_waists() 92 | ) 93 | 94 | return 95 | 96 | 97 | def save_results( 98 | run_name, 99 | observable_vec, 100 | observables, 101 | projection_coincidence_rate, 102 | projection_tomography_matrix, 103 | Signal, 104 | Idler, 105 | ): 106 | results_dir = os.path.join(RES_DIR, run_name) 107 | if os.path.exists(results_dir): 108 | shutil.rmtree(results_dir) 109 | 
os.makedirs(results_dir, exist_ok=True) 110 | 111 | (coincidence_rate, density_matrix, tomography_matrix) = observables 112 | 113 | if observable_vec[COINCIDENCE_RATE]: 114 | coincidence_rate = coincidence_rate[0] 115 | coincidence_rate = coincidence_rate / np.sum(np.abs(coincidence_rate)) 116 | np.save(os.path.join(results_dir, 'coincidence_rate.npy'), coincidence_rate) 117 | coincidence_rate_plots( 118 | results_dir, 119 | coincidence_rate, 120 | projection_coincidence_rate, 121 | Signal, 122 | Idler, 123 | ) 124 | 125 | if observable_vec[DENSITY_MATRIX]: 126 | density_matrix = density_matrix[0] 127 | density_matrix = density_matrix / np.trace(np.real(density_matrix)) 128 | np.save(os.path.join(results_dir, 'density_matrix_real.npy'), onp.real(density_matrix)) 129 | np.save(os.path.join(results_dir, 'density_matrix_imag.npy'), onp.imag(density_matrix)) 130 | density_matrix_plots( 131 | results_dir, 132 | density_matrix, 133 | ) 134 | 135 | if observable_vec[TOMOGRAPHY_MATRIX]: 136 | tomography_matrix = tomography_matrix[0] 137 | tomography_matrix = tomography_matrix / np.sum(np.abs(tomography_matrix)) 138 | np.save(os.path.join(results_dir, 'tomography_matrix.npy'), tomography_matrix) 139 | tomography_matrix_plots( 140 | results_dir, 141 | tomography_matrix, 142 | projection_tomography_matrix, 143 | Signal, 144 | Idler, 145 | ) 146 | 147 | 148 | def coincidence_rate_plots( 149 | results_dir, 150 | coincidence_rate, 151 | projection_coincidence_rate, 152 | Signal, 153 | Idler, 154 | ): 155 | # coincidence_rate = unwrap_kron(coincidence_rate, 156 | # projection_coincidence_rate.projection_n_modes1, 157 | # projection_coincidence_rate.projection_n_modes2) 158 | coincidence_rate = coincidence_rate[0, :].\ 159 | reshape(projection_coincidence_rate.projection_n_modes2, projection_coincidence_rate.projection_n_modes2) 160 | 161 | # Compute and plot reduced coincidence_rate 162 | g1_ss_normalization = G1_Normalization(Signal.w) 163 | g1_ii_normalization = 
G1_Normalization(Idler.w) 164 | coincidence_rate_reduced = coincidence_rate * \ 165 | projection_coincidence_rate.tau / (g1_ii_normalization * g1_ss_normalization) 166 | 167 | # plot coincidence_rate 2d 168 | plt.imshow(coincidence_rate_reduced) 169 | plt.xlabel(r'signal mode i') 170 | plt.ylabel(r'idle mode j') 171 | plt.colorbar() 172 | 173 | plt.savefig(os.path.join(results_dir, 'coincidence_rate')) 174 | plt.close() 175 | 176 | 177 | def tomography_matrix_plots( 178 | results_dir, 179 | tomography_matrix, 180 | projection_tomography_matrix, 181 | Signal, 182 | Idler, 183 | ): 184 | 185 | # tomography_matrix = unwrap_kron(tomography_matrix, 186 | # projection_tomography_matrix.projection_n_state1, 187 | # projection_tomography_matrix.projection_n_state2) 188 | 189 | tomography_matrix = tomography_matrix[0, :].\ 190 | reshape(projection_tomography_matrix.projection_n_state2, projection_tomography_matrix.projection_n_state2) 191 | 192 | # Compute and plot reduced tomography_matrix 193 | g1_ss_normalization = G1_Normalization(Signal.w) 194 | g1_ii_normalization = G1_Normalization(Idler.w) 195 | 196 | tomography_matrix_reduced = tomography_matrix * \ 197 | projection_tomography_matrix.tau / (g1_ii_normalization * g1_ss_normalization) 198 | 199 | # plot tomography_matrix 2d 200 | plt.imshow(tomography_matrix_reduced) 201 | plt.xlabel(r'signal mode i') 202 | plt.ylabel(r'idle mode j') 203 | plt.colorbar() 204 | 205 | plt.savefig(os.path.join(results_dir, 'tomography_matrix')) 206 | plt.close() 207 | 208 | 209 | def density_matrix_plots( 210 | results_dir, 211 | density_matrix, 212 | ): 213 | 214 | density_matrix_real = onp.real(density_matrix) 215 | density_matrix_imag = onp.imag(density_matrix) 216 | 217 | plt.imshow(density_matrix_real) 218 | plt.xlabel(r'signal mode i') 219 | plt.ylabel(r'idle mode j') 220 | plt.colorbar() 221 | plt.savefig(os.path.join(results_dir, 'density_matrix_real')) 222 | plt.close() 223 | 224 | plt.imshow(density_matrix_imag) 225 | 
plt.xlabel(r'signal mode i') 226 | plt.ylabel(r'idle mode j') 227 | plt.colorbar() 228 | plt.savefig(os.path.join(results_dir, 'density_matrix_imag')) 229 | plt.close() 230 | 231 | 232 | def type_coeffs_to_txt( 233 | basis, 234 | max_mode1, 235 | max_mode2, 236 | coeffs_real, 237 | coeffs_imag, 238 | waists): 239 | sign = {'1.0': '+', '-1.0': '-', '0.0': '+'} 240 | print_str = f'basis: {basis}({max_mode1},{max_mode2}):\n' 241 | for _real, _imag, _waist in zip(coeffs_real, coeffs_imag, waists): 242 | sign_imag = sign[str(onp.sign(_imag).item())] 243 | print_str += '{:.4} {} j{:.4} (waist: {:.4}[um])\n'.format(_real, sign_imag, onp.abs(_imag), _waist * 10) 244 | return print_str 245 | 246 | 247 | def unwrap_kron(G, M1, M2): 248 | ''' 249 | the function takes a Kronicker product of size M1^2 x M2^2 and turns is into an 250 | M1 x M2 x M1 x M2 tensor. It is used only for illustration and not during the learning 251 | Parameters 252 | ---------- 253 | G: the tensor we wish to reshape 254 | M1: first dimension 255 | M2: second dimension 256 | 257 | Returns a reshaped tensor with shape (M1, M2, M1, M2) 258 | ------- 259 | 260 | ''' 261 | 262 | C = onp.zeros((M1, M2, M1, M2), dtype=onp.float32) 263 | 264 | for i in range(M1): 265 | for j in range(M2): 266 | for k in range(M1): 267 | for l in range(M2): 268 | C[i, j, k, l] = G[k + M1 * i, l + M2 * j] 269 | return C 270 | -------------------------------------------------------------------------------- /src/spdc_inv/loss/loss.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from jax import jit 3 | from typing import Tuple, Sequence, Dict, Any, Union, Optional, List 4 | 5 | import jax.numpy as np 6 | import numpy as onp 7 | import itertools 8 | import os 9 | from spdc_inv import DATA_DIR 10 | 11 | 12 | class Loss(ABC): 13 | def __init__( 14 | self, 15 | observable_as_target: Tuple[Dict[Any, bool], ...], 16 | target: str = None, 17 | loss_arr: Union[Dict[Any, 
Optional[Tuple[str, str]]], 18 | Tuple[Dict[Any, Optional[Tuple[str, str]]], ...]] = None, 19 | loss_weights: Union[Dict[Any, Optional[Tuple[float, float]]], 20 | Tuple[Dict[Any, Optional[Tuple[float, float]]], ...]] = None, 21 | reg_observable: Union[Dict[Any, Optional[Tuple[str, str]]], 22 | Tuple[Dict[Any, Optional[Tuple[str, str]]], ...]] = None, 23 | reg_observable_w: Union[Dict[Any, Optional[Tuple[float, float]]], 24 | Tuple[Dict[Any, Optional[Tuple[float, float]]], ...]] = None, 25 | reg_observable_elements: Union[Dict[Any, Optional[Tuple[List[int], List[int]]]], 26 | Tuple[Dict[Any, Optional[Tuple[List[int], List[int]]]], ...]] = None, 27 | l2_reg: float = 0., 28 | 29 | ): 30 | self.LOSS = dict(l1=self.l1, 31 | l2=self.l2, 32 | kl=self.kl, 33 | bhattacharyya=self.bhattacharyya, 34 | trace_distance=self.trace_distance) 35 | 36 | self.REG_OBS = dict(sparsify=self.sparsify, 37 | equalize=self.equalize) 38 | 39 | if observable_as_target: 40 | assert (loss_arr is not None) or \ 41 | (reg_observable is not None), 'While observable_as_target is True, ' \ 42 | 'a loss must be selected.\n' \ 43 | 'Got loss_arr and reg_observable as None.' 
44 | 45 | if loss_arr is not None: 46 | assert len(loss_arr) == len(loss_weights), 'loss_arr and loss_weights must have equal number of elements' 47 | 48 | assert observable_as_target and os.path.exists(os.path.join(DATA_DIR, 'targets', target)), \ 49 | f' target file, {target}, is missing' 50 | 51 | for loss_arr_ in loss_arr: 52 | assert loss_arr_.lower() in self.LOSS, f'Loss must be defined as on of the following' \ 53 | f'options only: {list(self.LOSS.keys())}' 54 | 55 | if reg_observable is not None: 56 | assert len(reg_observable) == len(reg_observable_w) and \ 57 | len(reg_observable) == len(reg_observable_elements), 'reg_observable, reg_observable_w and' \ 58 | 'reg_observable_elements must have equal ' \ 59 | 'number of elements' 60 | for reg_obs in reg_observable: 61 | assert reg_obs.lower() in self.REG_OBS, f'Loss must be defined as on of the following' \ 62 | f'options only: {list(self.REG_OBS.keys())}' 63 | 64 | self.observable_as_target = observable_as_target 65 | self.loss_arr = loss_arr 66 | self.loss_weights = loss_weights 67 | self.reg_observable = reg_observable 68 | self.reg_observable_w = reg_observable_w 69 | self.reg_observable_elements = reg_observable_elements 70 | self.l2_reg = l2_reg 71 | 72 | self.target_str = None 73 | if observable_as_target and loss_arr is not None: 74 | self.target_str = target 75 | 76 | self.loss_stack = self.LOSS_stack() 77 | self.reg_obs_stack = self.REG_obs_stack() 78 | 79 | 80 | def apply( 81 | self, 82 | observable, 83 | model_parameters, 84 | target, 85 | ): 86 | 87 | loss = 0. 
88 | if self.observable_as_target: 89 | 90 | for loss_func, loss_weight in zip(self.loss_stack, self.loss_weights): 91 | loss = loss + loss_weight * loss_func(observable, target) 92 | 93 | for obs_func, weight, elements in zip(self.reg_obs_stack, 94 | self.reg_observable_w, self.reg_observable_elements): 95 | loss = loss + weight * obs_func(observable, elements) 96 | 97 | if self.l2_reg > 0.: 98 | loss = loss + self.l2_reg * self.l2_regularization(model_parameters) 99 | 100 | return loss 101 | 102 | def LOSS_stack(self): 103 | 104 | loss_stack = [] 105 | if self.loss_arr is None: 106 | self.loss_weights = [] 107 | return loss_stack 108 | 109 | for loss in self.loss_arr: 110 | loss_stack.append(self.LOSS[loss]) 111 | 112 | return loss_stack 113 | 114 | def REG_obs_stack(self): 115 | 116 | reg_obs_stack = [] 117 | if self.reg_observable is None: 118 | self.reg_observable_w = [] 119 | self.reg_observable_elements = [] 120 | return reg_obs_stack 121 | 122 | for reg_obs in self.reg_observable: 123 | reg_obs_stack.append(self.REG_OBS[reg_obs]) 124 | 125 | return reg_obs_stack 126 | 127 | @staticmethod 128 | @jit 129 | def l1(observable: np.array, 130 | target: np.array, 131 | ): 132 | """ 133 | L1 loss 134 | Parameters 135 | ---------- 136 | observable: tensor 137 | 138 | Returns the l1 distance between observable and the target 139 | ------- 140 | 141 | """ 142 | return np.sum(np.abs(observable - target)) 143 | 144 | @staticmethod 145 | @jit 146 | def l2(observable: np.array, 147 | target: np.array, 148 | ): 149 | """ 150 | L2 loss 151 | Parameters 152 | ---------- 153 | observable: tensor 154 | 155 | Returns the l2 distance between observable and the target 156 | ------- 157 | 158 | """ 159 | return np.sum((observable - target)**2) 160 | 161 | @staticmethod 162 | @jit 163 | def kl(observable: np.array, 164 | target: np.array, 165 | eps: float = 1e-2, 166 | ): 167 | """ 168 | kullback leibler divergence 169 | Parameters 170 | ---------- 171 | observable: tensor 172 | 
eps: Epsilon is used here to avoid conditional code for 173 | checking that neither observable nor the target is equal to 0 (or smaller) 174 | 175 | Returns the kullback leibler divergence between observable and the target 176 | ------- 177 | 178 | """ 179 | 180 | A = observable + eps 181 | B = target + eps 182 | return np.sum(B * np.log(B / A)) 183 | 184 | @staticmethod 185 | @jit 186 | def bhattacharyya(observable: np.array, 187 | target: np.array, 188 | eps: float = 1e-10, 189 | ): 190 | """ 191 | Bhattacharyya distance 192 | 193 | Parameters 194 | ---------- 195 | observable: tensor 196 | eps: Epsilon is used here to avoid conditional code for 197 | checking that neither observable nor the target is equal to 0 (or smaller) 198 | 199 | Returns the Bhattacharyya distance between observable and the target 200 | ------- 201 | 202 | """ 203 | 204 | return np.sqrt(1. - np.sum(np.sqrt(observable * target + eps))) 205 | 206 | @staticmethod 207 | @jit 208 | def trace_distance(rho: np.array, 209 | target: np.array, 210 | ): 211 | """ 212 | Trace Distance 213 | 214 | Calculate the Trace Distance between rho and the target density matrix 215 | as depict in: https://en.wikipedia.org/wiki/Trace_distance#Definition 216 | 217 | Parameters 218 | ---------- 219 | rho: density matrix rho 220 | Returns: Trace distance between rho and the target 221 | ------- 222 | """ 223 | 224 | td = 0.5 * np.linalg.norm(rho - target, ord='nuc') 225 | 226 | return td 227 | 228 | @staticmethod 229 | def sparsify(observable, 230 | elements, 231 | ): 232 | """ 233 | the method will penalize all other elements in tensor observable 234 | Parameters 235 | ---------- 236 | observable: tensor of size (1,projection_n_modes) 237 | elements: elements we don't want to penalize 238 | 239 | Returns l1 amplitude of all other elements 240 | ------- 241 | 242 | """ 243 | projection_n_modes = observable.shape[-1] 244 | sparsify_elements = onp.delete(onp.arange(projection_n_modes), elements) 245 | return 
np.sum(np.abs(observable[..., sparsify_elements])) 246 | 247 | @staticmethod 248 | def equalize( 249 | observable, 250 | elements 251 | ): 252 | """ 253 | the method will penalize if elements in observable doesn't have equal amplitude 254 | Parameters 255 | ---------- 256 | observable: tensor of size (1,projection_n_modes) 257 | elements: elements we wish to have equal energy in observable 258 | 259 | Returns the sum over all l1 distances fro all elements in observable 260 | ------- 261 | 262 | """ 263 | equalize_elements_combinations = list(itertools.combinations(elements, 2)) 264 | reg = 0. 265 | for el_comb in equalize_elements_combinations: 266 | reg = reg + np.sum(np.abs(observable[..., el_comb[0]] - observable[..., el_comb[1]])) 267 | 268 | return reg 269 | 270 | @staticmethod 271 | @jit 272 | def l2_regularization(model_parameters): 273 | """ 274 | l2 regularization 275 | Parameters 276 | ---------- 277 | model_parameters: model's learned parameters 278 | 279 | Returns l2 regularization 280 | ------- 281 | 282 | """ 283 | l2_reg = 0. 284 | for params in model_parameters: 285 | if params is not None: 286 | l2_reg = l2_reg + np.sum(np.abs(params)**2) 287 | return l2_reg 288 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/spdc_inv/models/spdc_model.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | import jax.random as random 3 | import jax.numpy as np 4 | 5 | from spdc_inv.models.utils import Field 6 | from spdc_inv.models.utils import crystal_prop, propagate 7 | from spdc_inv.utils.defaults import QUBIT 8 | from spdc_inv.utils.utils import DensMat 9 | from spdc_inv.training.utils import ( 10 | projection_matrix_calc, projection_matrices_calc, 11 | decompose, fix_power, get_qubit_density_matrix, get_qutrit_density_matrix, 12 | coupling_inefficiency_calc_G2, coupling_inefficiency_calc_tomo, 13 | ) 14 | 15 | 16 | class SPDCmodel(ABC): 17 | """ 18 | A differentiable SPDC forward model 19 | """ 20 | 21 | def __init__( 22 | self, 23 | pump, 24 | signal, 25 | idler, 26 | projection_coincidence_rate, 27 | projection_tomography_matrix, 28 | interaction, 29 | pump_structure, 30 | crystal_hologram, 31 | poling_period, 32 | DeltaZ, 33 | coincidence_rate_observable, 34 | density_matrix_observable, 35 | 
tomography_matrix_observable, 36 | coupling_inefficiencies: bool = False, 37 | ): 38 | 39 | self.pump = pump 40 | self.signal = signal 41 | self.idler = idler 42 | self.projection_coincidence_rate = projection_coincidence_rate 43 | self.projection_tomography_matrix = projection_tomography_matrix 44 | self.interaction = interaction 45 | self.pump_structure = pump_structure 46 | self.crystal_hologram = crystal_hologram 47 | self.poling_period = poling_period 48 | self.DeltaZ = DeltaZ 49 | self.coincidence_rate_observable = coincidence_rate_observable 50 | self.density_matrix_observable = density_matrix_observable 51 | self.tomography_matrix_observable = tomography_matrix_observable 52 | 53 | self.coupling_inefficiencies = coupling_inefficiencies 54 | 55 | self.N = None 56 | self.N_device = None 57 | self.learn_mode = None 58 | 59 | self.signal_f = Field(signal, interaction.dx, interaction.dy, interaction.maxZ) 60 | self.idler_f = Field(idler, interaction.dx, interaction.dy, interaction.maxZ) 61 | 62 | 63 | def forward( 64 | self, 65 | model_parameters, 66 | rand_key 67 | ): 68 | pump_coeffs_real, \ 69 | pump_coeffs_imag, \ 70 | waist_pump, \ 71 | crystal_coeffs_real, \ 72 | crystal_coeffs_imag, \ 73 | r_scale = model_parameters 74 | 75 | self.pump_structure.create_profile(pump_coeffs_real, pump_coeffs_imag, waist_pump) 76 | if self.crystal_hologram is not None: 77 | self.crystal_hologram.create_profile(crystal_coeffs_real, crystal_coeffs_imag, r_scale) 78 | 79 | rand_key, subkey = random.split(rand_key) 80 | # initialize the vacuum and interaction fields 81 | vacuum_states = random.normal( 82 | subkey, 83 | (self.N_device, 2, 2, self.interaction.Nx, self.interaction.Ny) 84 | ) 85 | signal_out,\ 86 | idler_out,\ 87 | idler_vac \ 88 | = crystal_prop(self.pump_structure.E, 89 | self.pump, 90 | self.signal_f, 91 | self.idler_f, 92 | vacuum_states, 93 | self.interaction, 94 | self.poling_period, 95 | self.N_device, 96 | None if self.crystal_hologram is None else 
self.crystal_hologram.crystal_profile, 97 | True if not self.learn_mode else None, 98 | signal_init=None, 99 | idler_init=None 100 | ) 101 | 102 | # Propagate generated fields back to the middle of the crystal 103 | signal_out_back_prop = propagate(signal_out, 104 | self.interaction.x, 105 | self.interaction.y, 106 | self.signal_f.k, 107 | self.DeltaZ 108 | ) * np.exp(-1j * self.signal_f.k * self.DeltaZ) 109 | 110 | idler_out_back_prop = propagate(idler_out, 111 | self.interaction.x, 112 | self.interaction.y, 113 | self.idler_f.k, 114 | self.DeltaZ 115 | ) * np.exp(-1j * self.idler_f.k * self.DeltaZ) 116 | 117 | idler_vac_back_prop = propagate(idler_vac, 118 | self.interaction.x, 119 | self.interaction.y, 120 | self.idler_f.k, 121 | self.DeltaZ 122 | ) * np.exp(-1j * self.idler_f.k * self.DeltaZ) 123 | 124 | coincidence_rate_projections, tomography_matrix_projections = \ 125 | self.get_1st_order_projections( 126 | signal_out, 127 | idler_out, 128 | idler_vac, 129 | signal_out_back_prop, 130 | idler_out_back_prop, 131 | idler_vac_back_prop, 132 | 133 | 134 | ) 135 | 136 | observables = self.get_observables(coincidence_rate_projections, tomography_matrix_projections) 137 | 138 | return observables 139 | 140 | def get_1st_order_projections( 141 | self, 142 | signal_out, 143 | idler_out, 144 | idler_vac, 145 | signal_out_back_prop, 146 | idler_out_back_prop, 147 | idler_vac_back_prop, 148 | ): 149 | """ 150 | the function calculates first order correlation functions. 
151 | According to https://doi.org/10.1002/lpor.201900321 152 | 153 | Parameters 154 | ---------- 155 | signal_out: the signal at the end of interaction 156 | idler_out: the idler at the end of interaction 157 | idler_vac: the idler vacuum state at the end of interaction 158 | 159 | Returns: first order correlation functions according to https://doi.org/10.1002/lpor.201900321 160 | ------- 161 | 162 | """ 163 | 164 | coincidence_rate_projections, tomography_matrix_projections = None, None 165 | if self.coincidence_rate_observable: 166 | coincidence_rate_projections = self.decompose_and_get_projections( 167 | signal_out, 168 | idler_out, 169 | idler_vac, 170 | signal_out_back_prop, 171 | idler_out_back_prop, 172 | idler_vac_back_prop, 173 | self.projection_coincidence_rate.basis_arr, 174 | self.projection_coincidence_rate.projection_n_modes1, 175 | self.projection_coincidence_rate.projection_n_modes2 176 | ) 177 | 178 | if self.tomography_matrix_observable or self.density_matrix_observable: 179 | tomography_matrix_projections = self.decompose_and_get_projections( 180 | signal_out, 181 | idler_out, 182 | idler_vac, 183 | signal_out_back_prop, 184 | idler_out_back_prop, 185 | idler_vac_back_prop, 186 | self.projection_tomography_matrix.basis_arr, 187 | self.projection_tomography_matrix.projection_n_state1, 188 | self.projection_tomography_matrix.projection_n_state2 189 | ) 190 | 191 | return coincidence_rate_projections, tomography_matrix_projections 192 | 193 | def decompose_and_get_projections( 194 | self, 195 | signal_out, 196 | idler_out, 197 | idler_vac, 198 | signal_out_back_prop, 199 | idler_out_back_prop, 200 | idler_vac_back_prop, 201 | basis_arr, 202 | projection_n_1, 203 | projection_n_2 204 | ): 205 | """ 206 | The function decompose the interacting fields onto selected basis array, and calculates first order 207 | correlation functions according to https://doi.org/10.1002/lpor.201900321 208 | 209 | Parameters 210 | ---------- 211 | signal_out 212 | 
idler_out 213 | idler_vac 214 | signal_out_back_prop 215 | idler_out_back_prop 216 | idler_vac_back_prop 217 | basis_arr 218 | projection_n_1 219 | projection_n_2 220 | 221 | Returns 222 | ------- 223 | 224 | """ 225 | 226 | signal_beam_decompose, idler_beam_decompose, idler_vac_decompose = \ 227 | self.decompose( 228 | signal_out, 229 | idler_out, 230 | idler_vac, 231 | signal_out_back_prop, 232 | idler_out_back_prop, 233 | idler_vac_back_prop, 234 | basis_arr, 235 | projection_n_1, 236 | projection_n_2 237 | ) 238 | 239 | G1_ss, G1_ii, G1_si, G1_si_dagger, Q_si, Q_si_dagger = projection_matrices_calc( 240 | signal_beam_decompose, 241 | idler_beam_decompose, 242 | idler_vac_decompose, 243 | self.N 244 | ) 245 | 246 | return G1_ss, G1_ii, G1_si, G1_si_dagger, Q_si, Q_si_dagger 247 | 248 | def decompose( 249 | self, 250 | signal_out, 251 | idler_out, 252 | idler_vac, 253 | signal_out_back_prop, 254 | idler_out_back_prop, 255 | idler_vac_back_prop, 256 | basis_arr, 257 | projection_n_1, 258 | projection_n_2 259 | ): 260 | 261 | signal_beam_decompose = decompose( 262 | signal_out_back_prop, 263 | basis_arr 264 | ).reshape( 265 | self.N_device, 266 | projection_n_1, 267 | projection_n_2) 268 | 269 | idler_beam_decompose = decompose( 270 | idler_out_back_prop, 271 | basis_arr 272 | ).reshape( 273 | self.N_device, 274 | projection_n_1, 275 | projection_n_2) 276 | 277 | idler_vac_decompose = decompose( 278 | idler_vac_back_prop, 279 | basis_arr 280 | ).reshape( 281 | self.N_device, 282 | projection_n_1, 283 | projection_n_2) 284 | 285 | # say there are no higher modes by normalizing the power 286 | signal_beam_decompose = fix_power(signal_beam_decompose, signal_out) 287 | idler_beam_decompose = fix_power(idler_beam_decompose, idler_out) 288 | idler_vac_decompose = fix_power(idler_vac_decompose, idler_vac) 289 | 290 | return signal_beam_decompose, idler_beam_decompose, idler_vac_decompose 291 | 292 | def get_observables( 293 | self, 294 | coincidence_rate_projections, 295 
| tomography_matrix_projections, 296 | 297 | ): 298 | coincidence_rate, density_matrix, tomography_matrix = None, None, None 299 | 300 | if self.coincidence_rate_observable: 301 | coincidence_rate = projection_matrix_calc( 302 | *coincidence_rate_projections 303 | ).reshape( 304 | self.projection_coincidence_rate.projection_n_modes1 ** 2, 305 | self.projection_coincidence_rate.projection_n_modes2 ** 2 306 | ) 307 | ## coupling inefficiences 308 | if self.coupling_inefficiencies: 309 | assert self.projection_coincidence_rate.projection_basis.lower() == 'lg', \ 310 | f'Only implemented for Laguerre-Gauss bases. ' \ 311 | f'We received {self.projection_coincidence_rate.projection_basis}' 312 | coincidence_rate = np.multiply(coupling_inefficiency_calc_G2( 313 | self.signal.lam, 314 | self.projection_coincidence_rate.SMF_waist, 315 | ), coincidence_rate 316 | ) 317 | 318 | if self.tomography_matrix_observable or self.density_matrix_observable: 319 | tomography_matrix = projection_matrix_calc( 320 | *tomography_matrix_projections 321 | ).reshape( 322 | self.projection_tomography_matrix.projection_n_state1 ** 2, 323 | self.projection_tomography_matrix.projection_n_state2 ** 2) 324 | if self.coupling_inefficiencies: 325 | if self.projection_tomography_matrix.tomography_quantum_state is not QUBIT: 326 | # in the case of qubit tomography, inefficiency factor is same for all modes 327 | # in the tomography matrix 328 | tomography_matrix = np.multiply(coupling_inefficiency_calc_tomo( 329 | self.signal.lam, 330 | self.projection_tomography_matrix.SMF_waist, 331 | ), tomography_matrix 332 | ) 333 | 334 | if self.density_matrix_observable: 335 | densmat = DensMat( 336 | self.projection_tomography_matrix.projection_n_state2, 337 | self.projection_tomography_matrix.tomography_dimensions 338 | ) 339 | 340 | if self.projection_tomography_matrix.tomography_quantum_state is QUBIT: 341 | density_matrix = get_qubit_density_matrix(tomography_matrix, 342 | densmat.masks, 343 | 
densmat.rotation_mats, 344 | ).reshape( 345 | self.projection_tomography_matrix.tomography_dimensions ** 2, 346 | self.projection_tomography_matrix.tomography_dimensions ** 2) 347 | else: 348 | density_matrix = get_qutrit_density_matrix(tomography_matrix, 349 | densmat.masks, 350 | densmat.rotation_mats, 351 | ).reshape( 352 | self.projection_tomography_matrix.tomography_dimensions ** 2, 353 | self.projection_tomography_matrix.tomography_dimensions ** 2) 354 | 355 | return coincidence_rate, density_matrix, tomography_matrix 356 | 357 | -------------------------------------------------------------------------------- /src/spdc_inv/training/trainer.py: -------------------------------------------------------------------------------- 1 | import jax.random as random 2 | import time 3 | import numpy as onp 4 | 5 | from abc import ABC 6 | from typing import Dict, Optional, Tuple, Any 7 | from jax import pmap 8 | from jax import value_and_grad 9 | from jax import lax 10 | from functools import partial 11 | from jax import numpy as np 12 | from jax.lax import stop_gradient 13 | from spdc_inv.utils.utils import Crystal_hologram, Beam_profile 14 | from spdc_inv.models.spdc_model import SPDCmodel 15 | from spdc_inv.utils.defaults import COINCIDENCE_RATE, DENSITY_MATRIX, TOMOGRAPHY_MATRIX 16 | 17 | 18 | class BaseTrainer(ABC): 19 | """ 20 | A class abstracting the various tasks of training models. 
21 | Provides methods at multiple levels of granularity 22 | """ 23 | def __init__( 24 | self, 25 | key: np.array, 26 | n_epochs: int, 27 | N_train: int, 28 | N_inference: int, 29 | N_train_device: int, 30 | N_inference_device: int, 31 | learn_pump_coeffs: bool, 32 | learn_pump_waists: bool, 33 | learn_crystal_coeffs: bool, 34 | learn_crystal_waists: bool, 35 | keep_best: bool, 36 | n_devices: int, 37 | projection_coincidence_rate, 38 | projection_tomography_matrix, 39 | interaction, 40 | pump, 41 | signal, 42 | idler, 43 | observable_vec: Optional[Tuple[Dict[Any, bool]]], 44 | coupling_inefficiencies: bool, 45 | ): 46 | 47 | self.key = key 48 | self.n_devices = n_devices 49 | self.n_epochs = n_epochs 50 | self.N_train = N_train 51 | self.N_inference = N_inference 52 | self.N_train_device = N_train_device 53 | self.N_inference_device = N_inference_device 54 | self.keep_best = keep_best 55 | self.delta_k = pump.k - signal.k - idler.k # phase mismatch 56 | self.poling_period = interaction.dk_offset * self.delta_k 57 | self.Nx = interaction.Nx 58 | self.Ny = interaction.Ny 59 | self.DeltaZ = - interaction.maxZ / 2 # DeltaZ: longitudinal middle of the crystal (with negative sign). 
60 | # To propagate generated fields back to the middle of the crystal 61 | 62 | self.projection_coincidence_rate = projection_coincidence_rate 63 | self.projection_tomography_matrix = projection_tomography_matrix 64 | 65 | assert list(observable_vec.keys()) == [COINCIDENCE_RATE, 66 | DENSITY_MATRIX, 67 | TOMOGRAPHY_MATRIX], 'observable_vec must only contain ' \ 68 | 'the keys [coincidence_rate,' \ 69 | 'density_matrix, tomography_matrix]' 70 | 71 | self.coincidence_rate_observable = observable_vec[COINCIDENCE_RATE] 72 | self.density_matrix_observable = observable_vec[DENSITY_MATRIX] 73 | self.tomography_matrix_observable = observable_vec[TOMOGRAPHY_MATRIX] 74 | self.coincidence_rate_loss = None 75 | self.density_matrix_loss = None 76 | self.tomography_matrix_loss = None 77 | self.opt_init, self.opt_update, self.get_params = None, None, None 78 | self.target_coincidence_rate = None 79 | self.target_density_matrix = None 80 | self.target_tomography_matrix = None 81 | 82 | self.coupling_inefficiencies = coupling_inefficiencies 83 | 84 | # Initialize pump and crystal coefficients 85 | pump_coeffs_real, \ 86 | pump_coeffs_imag = interaction.initial_pump_coefficients() 87 | waist_pump = interaction.initial_pump_waists() 88 | 89 | crystal_coeffs_real,\ 90 | crystal_coeffs_imag = interaction.initial_crystal_coefficients() 91 | r_scale = interaction.initial_crystal_waists() 92 | 93 | self.pump_coeffs_real, \ 94 | self.pump_coeffs_imag, \ 95 | self.waist_pump = None, None, None 96 | 97 | self.crystal_coeffs_real, \ 98 | self.crystal_coeffs_imag, \ 99 | self.r_scale = None, None, None 100 | 101 | self.learn_pump_coeffs = learn_pump_coeffs 102 | self.learn_pump_waists = learn_pump_waists 103 | self.learn_crystal_coeffs = learn_crystal_coeffs 104 | self.learn_crystal_waists = learn_crystal_waists 105 | 106 | if self.learn_pump_coeffs: 107 | self.pump_coeffs_real, \ 108 | self.pump_coeffs_imag = pump_coeffs_real, pump_coeffs_imag 109 | if self.learn_pump_waists: 110 | 
self.waist_pump = waist_pump 111 | if self.learn_crystal_coeffs: 112 | self.crystal_coeffs_real, \ 113 | self.crystal_coeffs_imag = crystal_coeffs_real, crystal_coeffs_imag 114 | if self.learn_crystal_waists: 115 | self.r_scale = r_scale 116 | 117 | 118 | self.model_parameters = pmap(lambda x: ( 119 | self.pump_coeffs_real, 120 | self.pump_coeffs_imag, 121 | self.waist_pump, 122 | self.crystal_coeffs_real, 123 | self.crystal_coeffs_imag, 124 | self.r_scale 125 | ))(np.arange(self.n_devices)) 126 | 127 | print(f"Interaction length [m]: {interaction.maxZ} \n") 128 | print(f"Pump beam basis coefficients: \n {pump_coeffs_real + 1j * pump_coeffs_imag}\n") 129 | print(f"Pump basis functions waists [um]: \n {waist_pump * 10}\n") 130 | 131 | if interaction.crystal_basis: 132 | print(f"3D hologram basis coefficients: \n {crystal_coeffs_real + 1j * crystal_coeffs_imag}\n") 133 | print("3D hologram basis functions-" 134 | f"effective waists (r_scale) [um]: \n {r_scale * 10}\n") 135 | self.crystal_hologram = Crystal_hologram(crystal_coeffs_real, 136 | crystal_coeffs_imag, 137 | r_scale, 138 | interaction.x, 139 | interaction.y, 140 | interaction.crystal_max_mode1, 141 | interaction.crystal_max_mode2, 142 | interaction.crystal_basis, 143 | signal.lam, 144 | signal.n, 145 | learn_crystal_coeffs, 146 | learn_crystal_waists) 147 | else: 148 | self.crystal_hologram = None 149 | 150 | self.pump_structure = Beam_profile(pump_coeffs_real, 151 | pump_coeffs_imag, 152 | waist_pump, 153 | interaction.power_pump, 154 | interaction.x, 155 | interaction.y, 156 | interaction.dx, 157 | interaction.dy, 158 | interaction.pump_max_mode1, 159 | interaction.pump_max_mode2, 160 | interaction.pump_basis, 161 | interaction.lam_pump, 162 | pump.n, 163 | learn_pump_coeffs, 164 | learn_pump_waists) 165 | 166 | self.model = SPDCmodel(pump, 167 | signal=signal, 168 | idler=idler, 169 | projection_coincidence_rate=projection_coincidence_rate, 170 | projection_tomography_matrix=projection_tomography_matrix, 
171 | interaction=interaction, 172 | pump_structure=self.pump_structure, 173 | crystal_hologram=self.crystal_hologram, 174 | poling_period=self.poling_period, 175 | DeltaZ=self.DeltaZ, 176 | coincidence_rate_observable=self.coincidence_rate_observable, 177 | density_matrix_observable=self.density_matrix_observable, 178 | tomography_matrix_observable=self.tomography_matrix_observable, 179 | coupling_inefficiencies=self.coupling_inefficiencies 180 | ) 181 | 182 | def inference(self): 183 | self.model.learn_mode = False 184 | self.model.N = self.N_inference 185 | self.model.N_device = self.N_inference_device 186 | 187 | # seed vacuum samples for each gpu 188 | self.key, subkey = random.split(self.key) 189 | keys = random.split(subkey, self.n_devices) 190 | # observables = pmap(self.model.forward, axis_name='device')(self.model_parameters, keys) 191 | observables = pmap(self.model.forward, axis_name='device')(stop_gradient(self.model_parameters), keys) 192 | 193 | return observables 194 | 195 | def fit(self): 196 | self.model.learn_mode = True 197 | self.model.N = self.N_train 198 | self.model.N_device = self.N_train_device 199 | 200 | opt_state = self.opt_init(self.model_parameters) 201 | 202 | loss_trn, best_loss = [], None 203 | epochs_without_improvement = 0 204 | 205 | for epoch in range(self.n_epochs): 206 | start_time_epoch = time.time() 207 | print(f'running epoch {epoch}/{self.n_epochs}') 208 | 209 | idx = np.array([epoch]).repeat(self.n_devices) 210 | self.key, subkey = random.split(self.key) 211 | training_subkeys = random.split(subkey, self.n_devices) 212 | 213 | training_loss, opt_state = self.update(opt_state, 214 | idx, 215 | training_subkeys,) 216 | 217 | loss_trn.append(training_loss[0].item()) 218 | 219 | print("in {:0.2f} sec".format(time.time() - start_time_epoch)) 220 | print("training objective loss:{:0.6f}".format(loss_trn[epoch])) 221 | 222 | if best_loss is None or loss_trn[epoch] < best_loss and not onp.isnan(loss_trn[epoch]): 223 | 224 | 
print(f'best objective loss is reached\n') 225 | 226 | model_parameters = self.get_params(opt_state) 227 | pump_coeffs_real, \ 228 | pump_coeffs_imag, \ 229 | waist_pump, \ 230 | crystal_coeffs_real, \ 231 | crystal_coeffs_imag, \ 232 | r_scale = model_parameters 233 | 234 | if self.learn_pump_coeffs: 235 | normalization = np.sqrt(np.sum(np.abs(pump_coeffs_real) ** 2 + 236 | np.abs(pump_coeffs_imag) ** 2, 1, keepdims=True)) 237 | pump_coeffs_real = pump_coeffs_real / normalization 238 | pump_coeffs_imag = pump_coeffs_imag / normalization 239 | 240 | print(f"Pump beam basis coefficients: \n " 241 | f"{pump_coeffs_real[0] + 1j * pump_coeffs_imag[0]}\n") 242 | 243 | if self.learn_pump_waists: 244 | print(f"Pump basis functions " 245 | f"waists [um]: \n {waist_pump[0] * 10}\n") 246 | 247 | if self.crystal_hologram: 248 | if self.learn_crystal_coeffs: 249 | normalization = np.sqrt(np.sum(np.abs(crystal_coeffs_real) ** 2 + 250 | np.abs(crystal_coeffs_imag) ** 2, 1, keepdims=True)) 251 | crystal_coeffs_real = crystal_coeffs_real / normalization 252 | crystal_coeffs_imag = crystal_coeffs_imag / normalization 253 | 254 | print(f"3D hologram basis coefficients: \n " 255 | f"{crystal_coeffs_real[0] + 1j * crystal_coeffs_imag[0]}\n") 256 | 257 | if self.learn_crystal_waists: 258 | print("3D hologram basis functions-" 259 | f"effective waists (r_scale) [um]: \n {r_scale[0] * 10}\n") 260 | 261 | best_loss = loss_trn[epoch] 262 | epochs_without_improvement = 0 263 | if self.keep_best: 264 | self.model_parameters = model_parameters 265 | print(f'parameters are updated\n') 266 | else: 267 | epochs_without_improvement += 1 268 | print(f'number of epochs without improvement {epochs_without_improvement}, ' 269 | f'best objective loss {best_loss}') 270 | 271 | if not self.keep_best: 272 | model_parameters = self.get_params(opt_state) 273 | self.model_parameters = model_parameters 274 | 275 | return loss_trn, best_loss 276 | 277 | @partial(pmap, axis_name='device', 
static_broadcasted_argnums=(0,)) 278 | def update( 279 | self, 280 | opt_state, 281 | i, 282 | training_subkeys, 283 | ): 284 | 285 | model_parameters = self.get_params(opt_state) 286 | loss, grads = value_and_grad(self.loss)(model_parameters, training_subkeys) 287 | 288 | grads_ = [] 289 | for g, grads_param in enumerate(grads): 290 | if grads_param is not None: 291 | grads_.append(np.array([lax.psum(dw, 'device') for dw in grads_param])) 292 | else: 293 | grads_.append(None) 294 | grads = tuple(grads_) 295 | 296 | opt_state = self.opt_update(i, grads, opt_state) 297 | training_loss = lax.pmean(loss, 'device') 298 | 299 | return training_loss, opt_state 300 | 301 | 302 | def loss( 303 | self, 304 | model_parameters, 305 | keys, 306 | ): 307 | pump_coeffs_real, \ 308 | pump_coeffs_imag, \ 309 | waist_pump, \ 310 | crystal_coeffs_real, \ 311 | crystal_coeffs_imag, \ 312 | r_scale = model_parameters 313 | 314 | if self.learn_pump_coeffs: 315 | normalization = np.sqrt(np.sum(np.abs(pump_coeffs_real) ** 2 + np.abs(pump_coeffs_imag) ** 2)) 316 | pump_coeffs_real = pump_coeffs_real / normalization 317 | pump_coeffs_imag = pump_coeffs_imag / normalization 318 | 319 | if self.crystal_hologram: 320 | if self.learn_crystal_coeffs: 321 | normalization = np.sqrt(np.sum(np.abs(crystal_coeffs_real) ** 2 + np.abs(crystal_coeffs_imag) ** 2)) 322 | crystal_coeffs_real = crystal_coeffs_real / normalization 323 | crystal_coeffs_imag = crystal_coeffs_imag / normalization 324 | 325 | model_parameters = (pump_coeffs_real, 326 | pump_coeffs_imag, 327 | waist_pump, 328 | crystal_coeffs_real, 329 | crystal_coeffs_imag, 330 | r_scale) 331 | 332 | observables = self.model.forward(model_parameters, keys) 333 | 334 | (coincidence_rate, density_matrix, tomography_matrix) = observables 335 | if self.coincidence_rate_loss.observable_as_target: 336 | coincidence_rate = coincidence_rate / np.sum(np.abs(coincidence_rate)) 337 | coincidence_rate_loss = self.coincidence_rate_loss.apply( 338 | 
coincidence_rate, model_parameters, self.target_coincidence_rate 339 | ) 340 | 341 | if self.density_matrix_loss.observable_as_target: 342 | density_matrix = density_matrix / np.trace(np.real(density_matrix)) 343 | density_matrix_loss = self.density_matrix_loss.apply( 344 | density_matrix, model_parameters, self.target_density_matrix 345 | ) 346 | 347 | if self.tomography_matrix_loss.observable_as_target: 348 | tomography_matrix = tomography_matrix / np.sum(np.abs(tomography_matrix)) 349 | tomography_matrix_loss = self.tomography_matrix_loss.apply( 350 | tomography_matrix, model_parameters, self.target_tomography_matrix 351 | ) 352 | 353 | return coincidence_rate_loss + density_matrix_loss + tomography_matrix_loss 354 | 355 | -------------------------------------------------------------------------------- /src/spdc_inv/data/interaction.py: -------------------------------------------------------------------------------- 1 | import jax.random as random 2 | import os 3 | 4 | from abc import ABC 5 | from jax import numpy as np 6 | from jax.ops import index_update 7 | from spdc_inv.data.utils import n_KTP_Kato, nz_MgCLN_Gayer 8 | from spdc_inv.utils.utils import PP_crystal_slab 9 | from spdc_inv.utils.utils import SFG_idler_wavelength 10 | from spdc_inv.utils.defaults import REAL, IMAG 11 | from typing import Dict 12 | 13 | 14 | class Interaction(ABC): 15 | """ 16 | A class that represents the SPDC interaction process, 17 | on all of its physical parameters. 
18 | """ 19 | 20 | def __init__( 21 | self, 22 | pump_basis: str = 'LG', 23 | pump_max_mode1: int = 5, 24 | pump_max_mode2: int = 2, 25 | initial_pump_coefficient: str = 'uniform', 26 | custom_pump_coefficient: Dict[str, Dict[int, int]] = None, 27 | pump_coefficient_path: str = None, 28 | initial_pump_waist: str = 'waist_pump0', 29 | pump_waists_path: str = None, 30 | crystal_basis: str = 'LG', 31 | crystal_max_mode1: int = 5, 32 | crystal_max_mode2: int = 2, 33 | initial_crystal_coefficient: str = 'uniform', 34 | custom_crystal_coefficient: Dict[str, Dict[int, int]] = None, 35 | crystal_coefficient_path: str = None, 36 | initial_crystal_waist: str = 'r_scale0', 37 | crystal_waists_path: str = None, 38 | lam_pump: float = 405e-9, 39 | crystal_str: str = 'ktp', 40 | power_pump: float = 1e-3, 41 | waist_pump0: float = None, 42 | r_scale0: float = 40e-6, 43 | dx: float = 4e-6, 44 | dy: float = 4e-6, 45 | dz: float = 10e-6, 46 | maxX: float = 120e-6, 47 | maxY: float = 120e-6, 48 | maxZ: float = 1e-4, 49 | R: float = 0.1, 50 | Temperature: float = 50, 51 | pump_polarization: str = 'y', 52 | signal_polarization: str = 'y', 53 | idler_polarization: str = 'z', 54 | dk_offset: float = 1, 55 | power_signal: float = 1, 56 | power_idler: float = 1, 57 | key: np.array = None, 58 | 59 | ): 60 | """ 61 | 62 | Parameters 63 | ---------- 64 | pump_basis: Pump's construction basis method 65 | Can be: LG (Laguerre-Gauss) / HG (Hermite-Gauss) 66 | pump_max_mode1: Maximum value of first mode of the 2D pump basis 67 | pump_max_mode2: Maximum value of second mode of the 2D pump basis 68 | initial_pump_coefficient: defines the initial distribution of coefficient-amplitudes for pump basis function 69 | can be: uniform- uniform distribution 70 | random- uniform distribution 71 | custom- as defined at custom_pump_coefficient 72 | load- will be loaded from np.arrays defined under path: pump_coefficient_path 73 | with names: parameters_pump_real.npy, parameters_pump_imag.npy 74 | 
pump_coefficient_path: path for loading waists for pump basis function 75 | custom_pump_coefficient: (dictionary) used only if initial_pump_coefficient=='custom' 76 | {'real': {indexes:coeffs}, 'imag': {indexes:coeffs}}. 77 | initial_pump_waist: defines the initial values of waists for pump basis function 78 | can be: waist_pump0- will be set according to waist_pump0 79 | load- will be loaded from np.arrays defined under path: pump_waists_path 80 | with name: parameters_pump_waists.npy 81 | pump_waists_path: path for loading coefficient-amplitudes for pump basis function 82 | crystal_basis: Crystal's construction basis method 83 | Can be: 84 | None / FT (Fourier-Taylor) / FB (Fourier-Bessel) / LG (Laguerre-Gauss) / HG (Hermite-Gauss) 85 | - if None, the crystal will contain NO hologram 86 | crystal_max_mode1: Maximum value of first mode of the 2D crystal basis 87 | crystal_max_mode2: Maximum value of second mode of the 2D crystal basis 88 | initial_crystal_coefficient: defines the initial distribution of coefficient-amplitudes for crystal basis function 89 | can be: uniform- uniform distribution 90 | random- uniform distribution 91 | custom- as defined at custom_crystal_coefficient 92 | load- will be loaded from np.arrays defined under path: crystal_coefficient_path 93 | with names: parameters_crystal_real.npy, parameters_crystal_imag.npy 94 | crystal_coefficient_path: path for loading coefficient-amplitudes for crystal basis function 95 | custom_crystal_coefficient: (dictionary) used only if initial_crystal_coefficient=='custom' 96 | {'real': {indexes:coeffs}, 'imag': {indexes:coeffs}}. 
97 | initial_crystal_waist: defines the initial values of waists for crystal basis function 98 | can be: r_scale0- will be set according to r_scale0 99 | load- will be loaded from np.arrays defined under path: crystal_waists_path 100 | with name: parameters_crystal_effective_waists.npy 101 | crystal_waists_path: path for loading waists for crystal basis function 102 | lam_pump: Pump wavelength 103 | crystal_str: Crystal type. Can be: KTP or MgCLN 104 | power_pump: Pump power [watt] 105 | waist_pump0: waists of the pump basis functions. 106 | -- If None, waist_pump0 = sqrt(maxZ / self.pump_k) 107 | r_scale0: effective waists of the crystal basis functions. 108 | -- If None, r_scale0 = waist_pump0 109 | dx: transverse resolution in x [m] 110 | dy: transverse resolution in y [m] 111 | dz: longitudinal resolution in z [m] 112 | maxX: Transverse cross-sectional size from the center of the crystal in x [m] 113 | maxY: Transverse cross-sectional size from the center of the crystal in y [m] 114 | maxZ: Crystal's length in z [m] 115 | R: distance to far-field screen [m] 116 | Temperature: crystal's temperature [Celsius Degrees] 117 | pump_polarization: Polarization of the pump beam 118 | signal_polarization: Polarization of the signal beam 119 | idler_polarization: Polarization of the idler beam 120 | dk_offset: delta_k offset 121 | power_signal: Signal power [watt] 122 | power_idler: Idler power [watt] 123 | key: Random key 124 | """ 125 | 126 | self.lam_pump = lam_pump 127 | self.dx = dx 128 | self.dy = dy 129 | self.dz = dz 130 | self.x = np.arange(-maxX, maxX, dx) # x axis, length 2*MaxX (transverse) 131 | self.y = np.arange(-maxY, maxY, dy) # y axis, length 2*MaxY (transverse) 132 | self.Nx = len(self.x) 133 | self.Ny = len(self.y) 134 | self.z = np.arange(-maxZ / 2, maxZ / 2, dz) # z axis, length MaxZ (propagation) 135 | self.maxX = maxX 136 | self.maxY = maxY 137 | self.maxZ = maxZ 138 | self.R = R 139 | self.Temperature = Temperature 140 | self.dk_offset = dk_offset 
141 | self.power_pump = power_pump 142 | self.power_signal = power_signal 143 | self.power_idler = power_idler 144 | self.lam_signal = 2 * lam_pump 145 | self.lam_idler = SFG_idler_wavelength(self.lam_pump, self.lam_signal) 146 | self.pump_polarization = pump_polarization 147 | self.signal_polarization = signal_polarization 148 | self.idler_polarization = idler_polarization 149 | self.key = key 150 | 151 | assert crystal_str.lower() in ['ktp', 'mgcln'], 'crystal must be either KTP or MgCLN' 152 | if crystal_str.lower() == 'ktp': 153 | self.ctype = n_KTP_Kato # refractive index function 154 | self.pump_k = 2 * np.pi * n_KTP_Kato(lam_pump * 1e6, Temperature, pump_polarization) / lam_pump 155 | self.d33 = 16.9e-12 # nonlinear coefficient [meter/Volt] 156 | else: 157 | self.ctype = nz_MgCLN_Gayer # refractive index function 158 | self.pump_k = 2 * np.pi * nz_MgCLN_Gayer(lam_pump * 1e6, Temperature) / lam_pump 159 | self.d33 = 23.4e-12 # [meter/Volt] 160 | # self.slab = PP_crystal_slab 161 | 162 | if waist_pump0 is None: 163 | self.waist_pump0 = np.sqrt(maxZ / self.pump_k) 164 | else: 165 | self.waist_pump0 = waist_pump0 166 | 167 | if r_scale0 is None: 168 | self.r_scale0 = self.waist_pump0 169 | else: 170 | self.r_scale0 = r_scale0 171 | 172 | assert pump_basis.lower() in ['lg', 'hg'], 'The beam structure is constructed as a combination ' \ 173 | 'of LG or HG basis functions only' 174 | self.pump_basis = pump_basis 175 | self.pump_max_mode1 = pump_max_mode1 176 | self.pump_max_mode2 = pump_max_mode2 177 | self.initial_pump_coefficient = initial_pump_coefficient 178 | self.custom_pump_coefficient = custom_pump_coefficient 179 | self.pump_coefficient_path = pump_coefficient_path 180 | 181 | # number of modes for pump basis 182 | if pump_basis.lower() == 'lg': 183 | self.pump_n_modes1 = pump_max_mode1 184 | self.pump_n_modes2 = 2 * pump_max_mode2 + 1 185 | else: 186 | self.pump_n_modes1 = pump_max_mode1 187 | self.pump_n_modes2 = pump_max_mode2 188 | 189 | # Total number 
of pump modes 190 | self.pump_n_modes = self.pump_n_modes1 * self.pump_n_modes2 191 | 192 | self.initial_pump_waist = initial_pump_waist 193 | self.pump_waists_path = pump_waists_path 194 | 195 | self.crystal_basis = crystal_basis 196 | if crystal_basis: 197 | assert crystal_basis.lower() in ['ft', 'fb', 'lg', 'hg'], 'The crystal structure was constructed ' \ 198 | 'as a combination of FT, FB, LG or HG ' \ 199 | 'basis functions only' 200 | 201 | self.crystal_max_mode1 = crystal_max_mode1 202 | self.crystal_max_mode2 = crystal_max_mode2 203 | self.initial_crystal_coefficient = initial_crystal_coefficient 204 | self.custom_crystal_coefficient = custom_crystal_coefficient 205 | self.crystal_coefficient_path = crystal_coefficient_path 206 | 207 | # number of modes for crystal basis 208 | if crystal_basis.lower() in ['ft', 'fb', 'lg']: 209 | self.crystal_n_modes1 = crystal_max_mode1 210 | self.crystal_n_modes2 = 2 * crystal_max_mode2 + 1 211 | else: 212 | self.crystal_n_modes1 = crystal_max_mode1 213 | self.crystal_n_modes2 = crystal_max_mode2 214 | 215 | # Total number of crystal modes 216 | self.crystal_n_modes = self.crystal_n_modes1 * self.crystal_n_modes2 217 | 218 | self.initial_crystal_waist = initial_crystal_waist 219 | self.crystal_waists_path = crystal_waists_path 220 | 221 | def initial_pump_coefficients( 222 | self, 223 | ): 224 | 225 | if self.initial_pump_coefficient == "uniform": 226 | coeffs_real = np.ones(self.pump_n_modes, dtype=np.float32) 227 | coeffs_imag = np.ones(self.pump_n_modes, dtype=np.float32) 228 | 229 | elif self.initial_pump_coefficient == "random": 230 | 231 | self.key, pump_coeff_key = random.split(self.key) 232 | rand_real, rand_imag = random.split(pump_coeff_key) 233 | coeffs_real = random.normal(rand_real, (self.pump_n_modes,)) 234 | coeffs_imag = random.normal(rand_imag, (self.pump_n_modes,)) 235 | 236 | elif self.initial_pump_coefficient == "custom": 237 | assert self.custom_pump_coefficient, 'for custom method, pump basis 
coefficients and ' \ 238 | 'indexes must be selected' 239 | coeffs_real = np.zeros(self.pump_n_modes, dtype=np.float32) 240 | coeffs_imag = np.zeros(self.pump_n_modes, dtype=np.float32) 241 | for index, coeff in self.custom_pump_coefficient[REAL].items(): 242 | assert type(index) is int, f'index {index} must be int type' 243 | assert type(coeff) is float, f'coeff {coeff} must be float type' 244 | assert index < self.pump_n_modes, 'index for custom pump (real) initialization must be smaller ' \ 245 | 'than total number of modes.' \ 246 | f'Got index {index} for total number of modes {self.pump_n_modes}' 247 | coeffs_real = index_update(coeffs_real, index, coeff) 248 | 249 | for index, coeff in self.custom_pump_coefficient[IMAG].items(): 250 | assert type(index) is int, f'index {index} must be int type' 251 | assert type(coeff) is float, f'coeff {coeff} must be float type' 252 | assert index < self.pump_n_modes, 'index for custom pump (imag) initialization must be smaller ' \ 253 | 'than total number of modes.' 
\ 254 | f'Got index {index} for total number of modes {self.pump_n_modes}' 255 | coeffs_imag = index_update(coeffs_imag, index, coeff) 256 | 257 | 258 | elif self.initial_pump_coefficient == "load": 259 | assert self.pump_coefficient_path, 'Path to pump coefficients must be defined' 260 | 261 | coeffs_real = np.load(os.path.join(self.pump_coefficient_path, 'parameters_pump_real.npy')) 262 | coeffs_imag = np.load(os.path.join(self.pump_coefficient_path, 'parameters_pump_imag.npy')) 263 | 264 | else: 265 | coeffs_real, coeffs_imag = None, None 266 | assert "ERROR: incompatible pump basis coefficients" 267 | 268 | normalization = np.sqrt(np.sum(np.abs(coeffs_real) ** 2 + np.abs(coeffs_imag) ** 2)) 269 | coeffs_real = coeffs_real / normalization 270 | coeffs_imag = coeffs_imag / normalization 271 | 272 | return coeffs_real, coeffs_imag 273 | 274 | 275 | def initial_pump_waists( 276 | self, 277 | ): 278 | if self.initial_pump_waist == "waist_pump0": 279 | waist_pump = np.ones(self.pump_n_modes, dtype=np.float32) * self.waist_pump0 * 1e5 280 | 281 | elif self.initial_pump_waist == "load": 282 | assert self.pump_waists_path, 'Path to pump waists must be defined' 283 | 284 | waist_pump = np.load(os.path.join(self.pump_coefficient_path, "parameters_pump_waists.npy")) * 1e-1 285 | 286 | else: 287 | waist_pump = None 288 | assert "ERROR: incompatible pump basis waists" 289 | 290 | return waist_pump 291 | 292 | 293 | def initial_crystal_coefficients( 294 | self, 295 | ): 296 | if not self.crystal_basis: 297 | return None, None 298 | 299 | elif self.initial_crystal_coefficient == "uniform": 300 | coeffs_real = np.ones(self.crystal_n_modes, dtype=np.float32) 301 | coeffs_imag = np.ones(self.crystal_n_modes, dtype=np.float32) 302 | 303 | elif self.initial_crystal_coefficient == "random": 304 | 305 | self.key, crystal_coeff_key = random.split(self.key) 306 | rand_real, rand_imag = random.split(crystal_coeff_key) 307 | coeffs_real = random.normal(rand_real, (self.crystal_n_modes,)) 
308 | coeffs_imag = random.normal(rand_imag, (self.crystal_n_modes,)) 309 | 310 | elif self.initial_crystal_coefficient == "custom": 311 | assert self.custom_crystal_coefficient, 'for custom method, crystal basis coefficients and ' \ 312 | 'indexes must be selected' 313 | coeffs_real = np.zeros(self.crystal_n_modes, dtype=np.float32) 314 | coeffs_imag = np.zeros(self.crystal_n_modes, dtype=np.float32) 315 | for index, coeff in self.custom_crystal_coefficient[REAL].items(): 316 | assert type(index) is int, f'index {index} must be int type' 317 | assert type(coeff) is float, f'coeff {coeff} must be float type' 318 | assert index < self.crystal_n_modes, 'index for custom crystal (real) initialization must be smaller ' \ 319 | 'than total number of modes.' \ 320 | f'Got index {index} for total number of modes {self.crystal_n_modes}' 321 | coeffs_real = index_update(coeffs_real, index, coeff) 322 | 323 | for index, coeff in self.custom_crystal_coefficient[IMAG].items(): 324 | assert type(index) is int, f'index {index} must be int type' 325 | assert type(coeff) is float, f'coeff {coeff} must be float type' 326 | assert index < self.crystal_n_modes, 'index for custom crystal (imag) initialization must be smaller ' \ 327 | 'than total number of modes.' 
\ 328 | f'Got index {index} for total number of modes {self.crystal_n_modes}' 329 | coeffs_imag = index_update(coeffs_imag, index, coeff) 330 | 331 | elif self.initial_crystal_coefficient == "load": 332 | assert self.crystal_coefficient_path, 'Path to crystal coefficients must be defined' 333 | 334 | coeffs_real = np.load(os.path.join(self.crystal_coefficient_path, 'parameters_crystal_real.npy')) 335 | coeffs_imag = np.load(os.path.join(self.crystal_coefficient_path, 'parameters_crystal_imag.npy')) 336 | 337 | else: 338 | coeffs_real, coeffs_imag = None, None 339 | assert "ERROR: incompatible crystal basis coefficients" 340 | 341 | normalization = np.sqrt(np.sum(np.abs(coeffs_real) ** 2 + np.abs(coeffs_imag) ** 2)) 342 | coeffs_real = coeffs_real / normalization 343 | coeffs_imag = coeffs_imag / normalization 344 | 345 | return coeffs_real, coeffs_imag 346 | 347 | 348 | def initial_crystal_waists( 349 | self, 350 | ): 351 | 352 | if not self.crystal_basis: 353 | return None 354 | 355 | if self.initial_crystal_waist == "r_scale0": 356 | r_scale = np.ones(self.crystal_n_modes, dtype=np.float32) * self.r_scale0 * 1e5 357 | 358 | elif self.initial_crystal_waist == "load": 359 | assert self.crystal_waists_path, 'Path to crystal waists must be defined' 360 | 361 | r_scale = np.load(os.path.join(self.crystal_waists_path, "parameters_crystal_effective_waists.npy")) * 1e-1 362 | 363 | else: 364 | r_scale = None 365 | assert "ERROR: incompatible crystal basis waists" 366 | 367 | return r_scale 368 | -------------------------------------------------------------------------------- /src/spdc_inv/experiments/experiment.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import shutil 4 | import time 5 | import numpy as onp 6 | from datetime import datetime 7 | from typing import Tuple, Dict, Any, Optional, List, Union 8 | from spdc_inv import DATA_DIR 9 | from spdc_inv import LOGS_DIR 10 | from 
spdc_inv.utils.defaults import REAL, IMAG 11 | from spdc_inv.utils.defaults import COINCIDENCE_RATE, DENSITY_MATRIX, TOMOGRAPHY_MATRIX 12 | 13 | def run_experiment( 14 | run_name: str, 15 | seed: int = 42, 16 | CUDA_VISIBLE_DEVICES: str = None, 17 | JAX_ENABLE_X64: str = 'True', 18 | minimal_GPU_memory: bool = False, 19 | learn_mode: bool = False, 20 | learn_pump_coeffs: bool = True, 21 | learn_pump_waists: bool = True, 22 | learn_crystal_coeffs: bool = True, 23 | learn_crystal_waists: bool = True, 24 | keep_best: bool = True, 25 | n_epochs: int = 500, 26 | N_train: int = 1000, 27 | N_inference: int = 1000, 28 | target: str = 'qutrit', 29 | observable_vec: Tuple[Dict[Any, bool]] = None, 30 | coupling_inefficiencies: bool = False, 31 | loss_arr: Tuple[Dict[Any, Optional[Tuple[str, str]]]] = None, 32 | loss_weights: Tuple[Dict[Any, Optional[Tuple[float, float]]]] = None, 33 | reg_observable: Tuple[Dict[Any, Optional[Tuple[str, str]]]] = None, 34 | reg_observable_w: Tuple[Dict[Any, Optional[Tuple[float, float]]]] = None, 35 | reg_observable_elements: Tuple[Dict[Any, Optional[Tuple[List[int], List[int]]]]] = None, 36 | tau: float = 1e-9, 37 | SMF_waist: float = 2.18e-6, 38 | l2_reg: float = 1e-5, 39 | optimizer: str = 'adam', 40 | exp_decay_lr: bool = True, 41 | step_size: float = 0.05, 42 | decay_steps: int = 50, 43 | decay_rate: float = 0.5, 44 | pump_basis: str = 'LG', 45 | pump_max_mode1: int = 5, 46 | pump_max_mode2: int = 1, 47 | initial_pump_coefficient: str = 'random', 48 | custom_pump_coefficient: Dict[str, Dict[int, int]] = None, 49 | pump_coefficient_path: str = None, 50 | initial_pump_waist: str = 'waist_pump0', 51 | pump_waists_path: str = None, 52 | crystal_basis: str = 'LG', 53 | crystal_max_mode1: int = 5, 54 | crystal_max_mode2: int = 2, 55 | initial_crystal_coefficient: str = 'random', 56 | custom_crystal_coefficient: Dict[str, Dict[int, int]] = None, 57 | crystal_coefficient_path: str = None, 58 | initial_crystal_waist: str = 'r_scale0', 59 | 
crystal_waists_path: str = None, 60 | lam_pump: float = 405e-9, 61 | crystal_str: str = 'ktp', 62 | power_pump: float = 1e-3, 63 | waist_pump0: float = 40e-6, 64 | r_scale0: float = 40e-6, 65 | dx: float = 4e-6, 66 | dy: float = 4e-6, 67 | dz: float = 10e-6, 68 | maxX: float = 120e-6, 69 | maxY: float = 120e-6, 70 | maxZ: float = 1e-4, 71 | R: float = 0.1, 72 | Temperature: float = 50, 73 | pump_polarization: str = 'y', 74 | signal_polarization: str = 'y', 75 | idler_polarization: str = 'z', 76 | dk_offset: float = 1., 77 | power_signal: float = 1., 78 | power_idler: float = 1., 79 | coincidence_projection_basis: str = 'LG', 80 | coincidence_projection_max_mode1: int = 1, 81 | coincidence_projection_max_mode2: int = 4, 82 | coincidence_projection_waist: float = None, 83 | coincidence_projection_wavelength: float = None, 84 | coincidence_projection_polarization: str = 'y', 85 | coincidence_projection_z: float = 0., 86 | tomography_projection_basis: str = 'LG', 87 | tomography_projection_max_mode1: int = 1, 88 | tomography_projection_max_mode2: int = 1, 89 | tomography_projection_waist: float = None, 90 | tomography_projection_wavelength: float = None, 91 | tomography_projection_polarization: str = 'y', 92 | tomography_projection_z: float = 0., 93 | tomography_relative_phase: List[Union[Union[int, float], Any]] = None, 94 | tomography_quantum_state: str = 'qubit' 95 | 96 | ): 97 | """ 98 | This function is the main function for running SPDC project 99 | 100 | Parameters 101 | ---------- 102 | run_name: selected name (will be used for naming the folder) 103 | seed: initial seed for random functions 104 | CUDA_VISIBLE_DEVICES: visible gpu devices to be used 105 | JAX_ENABLE_X64: if True, use double-precision numbers (enabling 64bit mode) 106 | minimal_GPU_memory: This makes JAX allocate exactly what is needed on demand, and deallocate memory that is no 107 | longer needed (note that this is the only configuration that will deallocate GPU memory, 108 | instead of 
reusing it). This is very slow, so is not recommended for general use, 109 | but may be useful for running with the minimal possible GPU memory footprint 110 | or debugging OOM failures. 111 | 112 | learn_mode: if True run learning method. False, inference. 113 | learn_pump_coeffs: if True, pump coefficients are learned in learning mode 114 | learn_pump_waists: if True, pump coefficients waists are learned in learning mode 115 | learn_crystal_coeffs: if True, crystal coefficients are learned in learning mode 116 | learn_crystal_waists: if True, crystal coefficients waists are learned in learning mode 117 | keep_best: if True, best learned result are kept 118 | n_epochs: number of epochs for learning 119 | N_train: size of vacuum states in training method 120 | N_inference: size of vacuum states in inference method 121 | target: name of target folder with observables, for training (should be placed under: SPDCinv/data/targets/) 122 | For any observable, the files in the folder must contain one of the corresponding names: 123 | 'coincidence_rate.npy', 'density_matrix.npy' 'tomography_matrix.npy' 124 | observable_vec: if an observable in the dictionary is True, 125 | the method will learn/infer the observable along the process 126 | loss_arr: if an observable in observable_vec is True, the following sequence of loss functions are used. 127 | The target observables in the 'target' folder are used. 128 | optional: l1, l2, kl, bhattacharyya, trace_distance, None 129 | if None (for any observable), the loss using target observable will be ignored, while the rest of the 130 | loss options, i.e reg_observable, can be still applied. 131 | loss_weights: the sequence loss_arr is weighted by loss_weights (for each observable) 132 | reg_observable: if an observable in observable_vec is True, the following sequence 133 | of regularization functions are used; applying regularization directly on observable elements. 
134 | optional: sparsify: he method will penalize all other elements in tensor observable 135 | equalize: the method will penalize if elements in observable doesn't have equal amplitude 136 | * elements are defined under: reg_observable_elements 137 | reg_observable_w: the sequence reg_observable is weighted by reg_observable_w (for each observable) 138 | reg_observable_elements: the elements for the sequence reg_observable are defined in reg_observable_elements 139 | (for each observable) 140 | 141 | l2_reg: l2 regularization coefficient for model leaned parameters (to reduce over-fitting). 142 | if 0,it will be ignored 143 | optimizer: optimized method. can be: adam 144 | sgd 145 | adagrad 146 | adamax 147 | momentum 148 | nesterov 149 | rmsprop 150 | rmsprop_momentum 151 | exp_decay_lr: the exponential decay rate for step size. calculated for each step i as: 152 | step_size * decay_rate ** (i / decay_steps) 153 | if False, this will be ignored 154 | step_size: learning step size 155 | decay_steps: decay steps for exp_decay_lr 156 | decay_rate: decay rate for exp_decay_lr 157 | 158 | pump_basis: Pump's construction basis method 159 | Can be: LG (Laguerre-Gauss) / HG (Hermite-Gauss) 160 | pump_max_mode1: Maximum value of first mode of the 2D pump basis 161 | pump_max_mode2: Maximum value of second mode of the 2D pump basis 162 | initial_pump_coefficient: defines the initial distribution of coefficient-amplitudes for pump basis function 163 | can be: uniform- uniform distribution 164 | random- uniform distribution 165 | custom- as defined at custom_pump_coefficient 166 | load- will be loaded from np.arrays defined under path: pump_coefficient_path 167 | with names: PumpCoeffs_real.npy, PumpCoeffs_imag.npy 168 | pump_coefficient_path: path for loading waists for pump basis function 169 | custom_pump_coefficient: (dictionary) used only if initial_pump_coefficient=='custom' 170 | {'real': {indexes:coeffs}, 'imag': {indexes:coeffs}}. 
171 | initial_pump_waist: defines the initial values of waists for pump basis function 172 | can be: waist_pump0- will be set according to waist_pump0 173 | load- will be loaded from np.arrays defined under path: pump_waists_path 174 | with name: PumpWaistCoeffs.npy 175 | pump_waists_path: path for loading coefficient-amplitudes for pump basis function 176 | crystal_basis: Crystal's construction basis method 177 | Can be: 178 | None / FT (Fourier-Taylor) / FB (Fourier-Bessel) / LG (Laguerre-Gauss) / HG (Hermite-Gauss) 179 | - if None, the crystal will contain NO hologram 180 | crystal_max_mode1: Maximum value of first mode of the 2D crystal basis 181 | crystal_max_mode2: Maximum value of second mode of the 2D crystal basis 182 | initial_crystal_coefficient: defines the initial distribution of coefficient-amplitudes for crystal basis function 183 | can be: uniform- uniform distribution 184 | random- uniform distribution 185 | custom- as defined at custom_crystal_coefficient 186 | load- will be loaded from np.arrays defined under path: crystal_coefficient_path 187 | with names: CrystalCoeffs_real.npy, CrystalCoeffs_imag.npy 188 | crystal_coefficient_path: path for loading coefficient-amplitudes for crystal basis function 189 | custom_crystal_coefficient: (dictionary) used only if initial_crystal_coefficient=='custom' 190 | {'real': {indexes:coeffs}, 'imag': {indexes:coeffs}}. 191 | initial_crystal_waist: defines the initial values of waists for crystal basis function 192 | can be: r_scale0- will be set according to r_scale0 193 | load- will be loaded from np.arrays defined under path: crystal_waists_path 194 | with name: CrystalWaistCoeffs.npy 195 | crystal_waists_path: path for loading waists for crystal basis function 196 | lam_pump: Pump wavelength 197 | crystal_str: Crystal type. Can be: KTP or MgCLN 198 | power_pump: Pump power [watt] 199 | waist_pump0: waists of the pump basis functions. 
200 | -- If None, waist_pump0 = sqrt(maxZ / self.pump_k) 201 | r_scale0: effective waists of the crystal basis functions. 202 | -- If None, r_scale0 = waist_pump0 203 | dx: transverse resolution in x [m] 204 | dy: transverse resolution in y [m] 205 | dz: longitudinal resolution in z [m] 206 | maxX: Transverse cross-sectional size from the center of the crystal in x [m] 207 | maxY: Transverse cross-sectional size from the center of the crystal in y [m] 208 | maxZ: Crystal's length in z [m] 209 | R: distance to far-field screen [m] 210 | Temperature: crystal's temperature [Celsius Degrees] 211 | pump_polarization: Polarization of the pump beam 212 | signal_polarization: Polarization of the signal beam 213 | idler_polarization: Polarization of the idler beam 214 | dk_offset: delta_k offset 215 | power_signal: Signal power [watt] 216 | power_idler: Idler power [watt] 217 | 218 | coincidence_projection_basis: represents the projective basis for calculating the coincidence rate observable 219 | of the interaction. Can be: LG (Laguerre-Gauss) / HG (Hermite-Gauss) 220 | coincidence_projection_max_mode1: Maximum value of first mode of the 2D projection basis for coincidence rate 221 | coincidence_projection_max_mode2: Maximum value of second mode of the 2D projection basis for coincidence rate 222 | coincidence_projection_waist: waists of the projection basis functions of coincidence rate. 223 | if None, np.sqrt(2) * waist_pump0 is used 224 | coincidence_projection_wavelength: wavelength for generating projection basis of coincidence rate. 225 | if None, the signal wavelength is used 226 | coincidence_projection_polarization: polarization for calculating effective refractive index 227 | coincidence_projection_z: projection longitudinal position 228 | tomography_projection_basis: represents the projective basis for calculating the tomography matrix & density matrix 229 | observables of the interaction. 
Can be: LG (Laguerre-Gauss) / HG (Hermite-Gauss) 230 | tomography_projection_max_mode1: Maximum value of first mode of the 2D projection basis for tomography matrix & 231 | density matrix 232 | tomography_projection_max_mode2: Maximum value of second mode of the 2D projection basis for tomography matrix & 233 | density matrix 234 | tomography_projection_waist: waists of the projection basis functions of tomography matrix & density matrix 235 | if None, np.sqrt(2) * waist_pump0 is used 236 | tomography_projection_wavelength: wavelength for generating projection basis of tomography matrix & density matrix. 237 | if None, the signal wavelength is used 238 | tomography_projection_polarization: polarization for calculating effective refractive index 239 | tomography_projection_z: projection longitudinal position 240 | tomography_relative_phase: The relative phase between the mutually unbiased bases (MUBs) states 241 | tomography_quantum_state: the current quantum state we calculate it tomography matrix. 
242 | currently we support: qubit/qutrit 243 | tau: coincidence window [nano sec] 244 | SMF_waist: signal/idler beam radius at single mode fibre 245 | ------- 246 | 247 | """ 248 | run_name = f'l_{run_name}_{str(datetime.today()).split()[0]}' if \ 249 | learn_mode else f'i_{run_name}_{str(datetime.today()).split()[0]}' 250 | 251 | now = datetime.now() 252 | date_and_time = now.strftime("%d/%m/%Y %H:%M:%S") 253 | print("date and time =", date_and_time) 254 | 255 | if CUDA_VISIBLE_DEVICES: 256 | os.environ["CUDA_VISIBLE_DEVICES"] = CUDA_VISIBLE_DEVICES 257 | if JAX_ENABLE_X64: 258 | os.environ["JAX_ENABLE_X64"] = JAX_ENABLE_X64 259 | 260 | if minimal_GPU_memory: 261 | os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = 'platform' 262 | 263 | import jax 264 | from jax.lib import xla_bridge 265 | from jax import numpy as np 266 | from spdc_inv.utils.utils import Beam 267 | from spdc_inv.loss.loss import Loss 268 | from spdc_inv.data.interaction import Interaction 269 | from spdc_inv.experiments.utils import Projection_coincidence_rate, Projection_tomography_matrix 270 | from spdc_inv.experiments.results_and_stats_utils import save_results, save_training_statistics 271 | from spdc_inv.training.trainer import BaseTrainer 272 | from spdc_inv.optim.optimizer import Optimizer 273 | 274 | if not seed: 275 | seed = random.randint(0, 2 ** 31) 276 | key = jax.random.PRNGKey(seed) 277 | 278 | n_devices = xla_bridge.device_count() 279 | print(f'Number of GPU devices: {n_devices} \n') 280 | 281 | if learn_mode: 282 | assert N_train % n_devices == 0, "The number of training examples should be " \ 283 | "divisible by the number of devices" 284 | 285 | assert N_inference % n_devices == 0, "The number of inference examples should be " \ 286 | "divisible by the number of devices" 287 | 288 | N_train_device = int(N_train / n_devices) 289 | N_inference_device = int(N_inference / n_devices) 290 | 291 | specs = { 292 | 'experiment name': run_name, 293 | 'seed': seed, 294 | 'date and time': 
date_and_time, 295 | 'number of gpu devices': n_devices, 296 | 'JAX_ENABLE_X64': JAX_ENABLE_X64, 297 | } 298 | specs.update({'----- Learning Parameters': '----- '}) 299 | specs.update(learning_params) 300 | specs.update({'----- Loss Parameters': '----- '}) 301 | specs.update(loss_params) 302 | specs.update({'----- Optimizer Parameters': '----- '}) 303 | specs.update(optimizer_params) 304 | specs.update({'----- Interaction Parameters': '----- '}) 305 | specs.update(interaction_params) 306 | specs.update({'----- Projection Parameters': '----- '}) 307 | specs.update(projection_params) 308 | 309 | logs_dir = os.path.join(LOGS_DIR, run_name) 310 | if os.path.exists(logs_dir): 311 | shutil.rmtree(logs_dir) 312 | os.makedirs(logs_dir, exist_ok=True) 313 | 314 | key, interaction_key = jax.random.split(key) 315 | interaction = Interaction( 316 | pump_basis=pump_basis, 317 | pump_max_mode1=pump_max_mode1, 318 | pump_max_mode2=pump_max_mode2, 319 | initial_pump_coefficient=initial_pump_coefficient, 320 | custom_pump_coefficient=custom_pump_coefficient, 321 | pump_coefficient_path=pump_coefficient_path, 322 | initial_pump_waist=initial_pump_waist, 323 | pump_waists_path=pump_waists_path, 324 | crystal_basis=crystal_basis, 325 | crystal_max_mode1=crystal_max_mode1, 326 | crystal_max_mode2=crystal_max_mode2, 327 | initial_crystal_coefficient=initial_crystal_coefficient, 328 | custom_crystal_coefficient=custom_crystal_coefficient, 329 | crystal_coefficient_path=crystal_coefficient_path, 330 | initial_crystal_waist=initial_crystal_waist, 331 | crystal_waists_path=crystal_waists_path, 332 | lam_pump=lam_pump, 333 | crystal_str=crystal_str, 334 | power_pump=power_pump, 335 | waist_pump0=waist_pump0, 336 | r_scale0=r_scale0, 337 | dx=dx, 338 | dy=dy, 339 | dz=dz, 340 | maxX=maxX, 341 | maxY=maxY, 342 | maxZ=maxZ, 343 | R=R, 344 | Temperature=Temperature, 345 | pump_polarization=pump_polarization, 346 | signal_polarization=signal_polarization, 347 | 
idler_polarization=idler_polarization, 348 | dk_offset=dk_offset, 349 | power_signal=power_signal, 350 | power_idler=power_idler, 351 | key=interaction_key, 352 | ) 353 | 354 | projection_coincidence_rate = Projection_coincidence_rate( 355 | calculate_observable=observable_vec[COINCIDENCE_RATE], 356 | waist_pump0=interaction.waist_pump0, 357 | signal_wavelength=interaction.lam_signal, 358 | crystal_x=interaction.x, 359 | crystal_y=interaction.y, 360 | temperature=interaction.Temperature, 361 | ctype=interaction.ctype, 362 | polarization=coincidence_projection_polarization, 363 | z=coincidence_projection_z, 364 | projection_basis=coincidence_projection_basis, 365 | max_mode1=coincidence_projection_max_mode1, 366 | max_mode2=coincidence_projection_max_mode2, 367 | waist=coincidence_projection_waist, 368 | wavelength=coincidence_projection_wavelength, 369 | tau=tau, 370 | SMF_waist=SMF_waist, 371 | ) 372 | 373 | projection_tomography_matrix = Projection_tomography_matrix( 374 | calculate_observable=observable_vec[DENSITY_MATRIX] or observable_vec[TOMOGRAPHY_MATRIX], 375 | waist_pump0=interaction.waist_pump0, 376 | signal_wavelength=interaction.lam_signal, 377 | crystal_x=interaction.x, 378 | crystal_y=interaction.y, 379 | temperature=interaction.Temperature, 380 | ctype=interaction.ctype, 381 | polarization=tomography_projection_polarization, 382 | z=tomography_projection_z, 383 | relative_phase=tomography_relative_phase, 384 | tomography_quantum_state=tomography_quantum_state, 385 | projection_basis=tomography_projection_basis, 386 | max_mode1=tomography_projection_max_mode1, 387 | max_mode2=tomography_projection_max_mode2, 388 | waist=tomography_projection_waist, 389 | wavelength=tomography_projection_wavelength, 390 | tau=tau, 391 | ) 392 | 393 | Pump = Beam(lam=interaction.lam_pump, 394 | ctype=interaction.ctype, 395 | polarization=interaction.pump_polarization, 396 | T=interaction.Temperature, 397 | power=interaction.power_pump) 398 | 399 | Signal = 
Beam(lam=interaction.lam_signal, 400 | ctype=interaction.ctype, 401 | polarization=interaction.signal_polarization, 402 | T=interaction.Temperature, 403 | power=interaction.power_signal) 404 | 405 | Idler = Beam(lam=interaction.lam_idler, 406 | ctype=interaction.ctype, 407 | polarization=interaction.idler_polarization, 408 | T=interaction.Temperature, 409 | power=interaction.power_idler) 410 | 411 | trainer = BaseTrainer( 412 | key=key, 413 | n_epochs=n_epochs, 414 | N_train=N_train, 415 | N_inference=N_inference, 416 | N_train_device=N_train_device, 417 | N_inference_device=N_inference_device, 418 | learn_pump_coeffs=learn_pump_coeffs, 419 | learn_pump_waists=learn_pump_waists, 420 | learn_crystal_coeffs=learn_crystal_coeffs, 421 | learn_crystal_waists=learn_crystal_waists, 422 | keep_best=keep_best, 423 | n_devices=n_devices, 424 | projection_coincidence_rate=projection_coincidence_rate, 425 | projection_tomography_matrix=projection_tomography_matrix, 426 | interaction=interaction, 427 | pump=Pump, 428 | signal=Signal, 429 | idler=Idler, 430 | observable_vec=observable_vec, 431 | coupling_inefficiencies=coupling_inefficiencies 432 | ) 433 | 434 | training_total_time = None 435 | if learn_mode: 436 | trainer.coincidence_rate_loss = Loss(observable_as_target=observable_vec[COINCIDENCE_RATE], 437 | target=os.path.join(target, f'{COINCIDENCE_RATE}.npy'), 438 | loss_arr=loss_arr[COINCIDENCE_RATE], 439 | loss_weights=loss_weights[COINCIDENCE_RATE], 440 | reg_observable=reg_observable[COINCIDENCE_RATE], 441 | reg_observable_w=reg_observable_w[COINCIDENCE_RATE], 442 | reg_observable_elements=reg_observable_elements[COINCIDENCE_RATE], 443 | l2_reg=l2_reg) 444 | 445 | trainer.density_matrix_loss = Loss(observable_as_target=observable_vec[DENSITY_MATRIX], 446 | target=os.path.join(target, f'{DENSITY_MATRIX}.npy'), 447 | loss_arr=loss_arr[DENSITY_MATRIX], 448 | loss_weights=loss_weights[DENSITY_MATRIX], 449 | reg_observable=reg_observable[DENSITY_MATRIX], 450 | 
reg_observable_w=reg_observable_w[DENSITY_MATRIX], 451 | reg_observable_elements=reg_observable_elements[DENSITY_MATRIX],) 452 | 453 | trainer.tomography_matrix_loss = Loss(observable_as_target=observable_vec[TOMOGRAPHY_MATRIX], 454 | target=os.path.join(target, f'{TOMOGRAPHY_MATRIX}.npy'), 455 | loss_arr=loss_arr[TOMOGRAPHY_MATRIX], 456 | loss_weights=loss_weights[TOMOGRAPHY_MATRIX], 457 | reg_observable=reg_observable[TOMOGRAPHY_MATRIX], 458 | reg_observable_w=reg_observable_w[TOMOGRAPHY_MATRIX], 459 | reg_observable_elements=reg_observable_elements[TOMOGRAPHY_MATRIX],) 460 | 461 | trainer.opt_init, trainer.opt_update, trainer.get_params = Optimizer(optimizer=optimizer, 462 | exp_decay_lr=exp_decay_lr, 463 | step_size=step_size, 464 | decay_steps=decay_steps, 465 | decay_rate=decay_rate) 466 | 467 | if trainer.coincidence_rate_loss.target_str and observable_vec[COINCIDENCE_RATE]: 468 | trainer.target_coincidence_rate = np.load(os.path.join( 469 | DATA_DIR, 'targets', trainer.coincidence_rate_loss.target_str)) 470 | 471 | if trainer.density_matrix_loss.target_str and observable_vec[DENSITY_MATRIX]: 472 | trainer.target_density_matrix = np.load(os.path.join( 473 | DATA_DIR, 'targets', trainer.density_matrix_loss.target_str)) 474 | 475 | if trainer.tomography_matrix_loss.target_str and observable_vec[TOMOGRAPHY_MATRIX]: 476 | trainer.target_tomography_matrix = np.load(os.path.join( 477 | DATA_DIR, 'targets', trainer.tomography_matrix_loss.target_str)) 478 | 479 | start_time = time.time() 480 | fit_results = trainer.fit() 481 | training_total_time = (time.time() - start_time) 482 | print("training is done after: %s seconds" % training_total_time) 483 | 484 | save_training_statistics( 485 | logs_dir, 486 | fit_results, 487 | interaction, 488 | trainer.model_parameters, 489 | ) 490 | 491 | start_time = time.time() 492 | observables = trainer.inference() 493 | inference_total_time = (time.time() - start_time) 494 | print("inference is done after: %s seconds" % 
inference_total_time) 495 | 496 | save_results( 497 | run_name, 498 | observable_vec, 499 | observables, 500 | projection_coincidence_rate, 501 | projection_tomography_matrix, 502 | Signal, 503 | Idler, 504 | ) 505 | 506 | else: 507 | start_time = time.time() 508 | observables = trainer.inference() 509 | inference_total_time = (time.time() - start_time) 510 | print("inference is done after: %s seconds" % inference_total_time) 511 | 512 | save_training_statistics( 513 | logs_dir, 514 | None, 515 | interaction, 516 | trainer.model_parameters, 517 | ) 518 | 519 | save_results( 520 | run_name, 521 | observable_vec, 522 | observables, 523 | projection_coincidence_rate, 524 | projection_tomography_matrix, 525 | Signal, 526 | Idler, 527 | ) 528 | 529 | specs_file = os.path.join(logs_dir, 'data_specs.txt') 530 | with open(specs_file, 'w') as f: 531 | if learn_mode: 532 | f.write(f"training running time: {training_total_time} sec\n") 533 | f.write(f"inference running time: {inference_total_time} sec\n") 534 | for k, v in specs.items(): 535 | f.write(f"{k}: {str(v)}\n") 536 | 537 | 538 | if __name__ == "__main__": 539 | 540 | learning_params = { 541 | 'learn_mode': True, 542 | 'learn_pump_coeffs': True, 543 | 'learn_pump_waists': True, 544 | 'learn_crystal_coeffs': True, 545 | 'learn_crystal_waists': True, 546 | 'keep_best': True, 547 | 'n_epochs': 500, 548 | 'N_train': 1000, 549 | 'N_inference': 4000, 550 | 'target': 'qutrit', 551 | 'coupling_inefficiencies': True, 552 | 'observable_vec': { 553 | COINCIDENCE_RATE: True, 554 | DENSITY_MATRIX: False, 555 | TOMOGRAPHY_MATRIX: False 556 | } 557 | } 558 | 559 | loss_params = { 560 | 'loss_arr': { 561 | COINCIDENCE_RATE: None, 562 | DENSITY_MATRIX: None, 563 | TOMOGRAPHY_MATRIX: None 564 | }, 565 | 'loss_weights': { 566 | COINCIDENCE_RATE: (1.,), 567 | DENSITY_MATRIX: None, 568 | TOMOGRAPHY_MATRIX: None 569 | }, 570 | 'reg_observable': { 571 | COINCIDENCE_RATE: ('sparsify', 'equalize'), 572 | DENSITY_MATRIX: None, 573 | 
TOMOGRAPHY_MATRIX: None 574 | }, 575 | 'reg_observable_w': { 576 | COINCIDENCE_RATE: (.5, .5), 577 | DENSITY_MATRIX: None, 578 | TOMOGRAPHY_MATRIX: None 579 | }, 580 | 'reg_observable_elements': { 581 | COINCIDENCE_RATE: ([30, 40, 50], [30, 40, 50]), 582 | DENSITY_MATRIX: None, 583 | TOMOGRAPHY_MATRIX: None 584 | }, 585 | } 586 | 587 | optimizer_params = { 588 | 'l2_reg': 0., 589 | 'optimizer': 'adam', 590 | 'exp_decay_lr': False, 591 | 'step_size': 0.05, 592 | 'decay_steps': 50, 593 | 'decay_rate': 0.5, 594 | } 595 | 596 | interaction_params = { 597 | 'pump_basis': 'LG', 598 | 'pump_max_mode1': 1, 599 | 'pump_max_mode2': 4, 600 | 'initial_pump_coefficient': 'custom', 601 | 'custom_pump_coefficient': {REAL: {0: 0., 1: 0., 2: 0., 3: 0., 4: 1., 5: 0., 6: 0., 7: 0., 8: 0.}, 602 | IMAG: {0: 0., 1: 0., 2: 0.}}, 603 | 'pump_coefficient_path': None, 604 | 'initial_pump_waist': 'waist_pump0', 605 | 'pump_waists_path': None, 606 | 'crystal_basis': 'LG', 607 | 'crystal_max_mode1': 10, 608 | 'crystal_max_mode2': 4, 609 | 'initial_crystal_coefficient': 'custom', 610 | 'custom_crystal_coefficient': {REAL: {4: 1.}, IMAG: {0: 0., 1: 0., 2: 0.}}, 611 | 'crystal_coefficient_path': None, 612 | 'initial_crystal_waist': 'r_scale0', 613 | 'crystal_waists_path': None, 614 | 'lam_pump': 405e-9, 615 | 'crystal_str': 'ktp', 616 | 'power_pump': 1e-3, 617 | 'waist_pump0': 40e-6, 618 | 'r_scale0': 40e-6, 619 | 'dx': 4e-6, 620 | 'dy': 4e-6, 621 | 'dz': 10e-6, 622 | 'maxX': 180e-6, 623 | 'maxY': 180e-6, 624 | 'maxZ': 1e-3, 625 | 'R': 0.1, 626 | 'Temperature': 50, 627 | 'pump_polarization': 'y', 628 | 'signal_polarization': 'y', 629 | 'idler_polarization': 'z', 630 | 'dk_offset': 1., 631 | 'power_signal': 1., 632 | 'power_idler': 1., 633 | } 634 | 635 | projection_params = { 636 | 'coincidence_projection_basis': 'LG', 637 | 'coincidence_projection_max_mode1': 1, 638 | 'coincidence_projection_max_mode2': 4, 639 | 'coincidence_projection_waist': None, 640 | 'coincidence_projection_wavelength': 
None, 641 | 'coincidence_projection_polarization': 'y', 642 | 'coincidence_projection_z': 0., 643 | 'tomography_projection_basis': 'LG', 644 | 'tomography_projection_max_mode1': 1, 645 | 'tomography_projection_max_mode2': 1, 646 | 'tomography_projection_waist': None, 647 | 'tomography_projection_wavelength': None, 648 | 'tomography_projection_polarization': 'y', 649 | 'tomography_projection_z': 0., 650 | 'tomography_relative_phase': [0, onp.pi, 3 * (onp.pi / 2), onp.pi / 2], 651 | 'tomography_quantum_state': 'qutrit', 652 | 'tau': 1e-9, 653 | 'SMF_waist': 2.18e-6, 654 | } 655 | 656 | run_experiment( 657 | run_name='test', 658 | seed=42, 659 | JAX_ENABLE_X64='True', 660 | minimal_GPU_memory=False, 661 | CUDA_VISIBLE_DEVICES='0, 1', 662 | **learning_params, 663 | **loss_params, 664 | **optimizer_params, 665 | **interaction_params, 666 | **projection_params, 667 | ) 668 | -------------------------------------------------------------------------------- /src/spdc_inv/utils/utils.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from jax.ops import index_update, index_add, index 3 | from typing import List, Union, Any 4 | from spdc_inv.utils.defaults import QUBIT 5 | 6 | import scipy.special as sp 7 | import jax.numpy as np 8 | import math 9 | 10 | 11 | # Constants: 12 | pi = np.pi 13 | c = 2.99792458e8 # speed of light [meter/sec] 14 | eps0 = 8.854187817e-12 # vacuum permittivity [Farad/meter] 15 | h_bar = 1.054571800e-34 # [m^2 kg / s], taken from http://physics.nist.gov/cgi-bin/cuu/Value?hbar|search_for=planck 16 | 17 | # lambda functions: 18 | G1_Normalization = lambda w: h_bar * w / (2 * eps0 * c) 19 | I = lambda A, n: 2 * n * eps0 * c * np.abs(A) ** 2 # Intensity 20 | Power2D = lambda A, n, dx, dy: np.sum(I(A, n)) * dx * dy 21 | 22 | # Compute the idler wavelength given pump and signal 23 | SFG_idler_wavelength = lambda lambda_p, lambda_s: lambda_p * lambda_s / (lambda_s - lambda_p) 24 | 25 | 26 | def 
PP_crystal_slab( 27 | delta_k, 28 | z, 29 | crystal_profile, 30 | inference=None 31 | ): 32 | """ 33 | Periodically poled crystal slab. 34 | create the crystal slab at point z in the crystal, for poling period 2pi/delta_k 35 | 36 | Parameters 37 | ---------- 38 | delta_k: k mismatch 39 | z: longitudinal point for generating poling pattern 40 | crystal_profile: Crystal 3D hologram (if None, ignore) 41 | inference: (True/False) if in inference mode, we include more coefficients in the poling 42 | description for better validation 43 | 44 | Returns Periodically poled crystal slab at point z 45 | ------- 46 | 47 | """ 48 | if crystal_profile is None: 49 | return np.sign(np.cos(np.abs(delta_k) * z)) 50 | else: 51 | magnitude = np.abs(crystal_profile) 52 | phase = np.angle(crystal_profile) 53 | if inference is not None: 54 | max_order_fourier = 20 55 | poling = 0 56 | magnitude = magnitude / magnitude.max() 57 | DutyCycle = np.arcsin(magnitude) / np.pi 58 | for m in range(max_order_fourier): 59 | if m == 0: 60 | poling = poling + 2 * DutyCycle - 1 61 | else: 62 | poling = poling + (2 / (m * np.pi)) * \ 63 | np.sin(m * pi * DutyCycle) * 2 * np.cos(m * phase + m * np.abs(delta_k) * z) 64 | return poling 65 | else: 66 | return (2 / np.pi) * np.exp(1j * (np.abs(delta_k) * z)) * magnitude * np.exp(1j * phase) 67 | 68 | 69 | def HermiteBank( 70 | lam, 71 | refractive_index, 72 | W0, 73 | max_mode_x, 74 | max_mode_y, 75 | x, 76 | y, 77 | z=0 78 | ): 79 | """ 80 | generates a dictionary of Hermite Gauss basis functions 81 | 82 | Parameters 83 | ---------- 84 | lam; wavelength 85 | refractive_index: refractive index 86 | W0: beam waist 87 | max_mode_x: maximum projection mode 1st axis 88 | max_mode_y: maximum projection mode 2nd axis 89 | x: transverse points, x axis 90 | y: transverse points, y axis 91 | z: projection longitudinal position 92 | 93 | Returns 94 | ------- 95 | dictionary of Hermite Gauss basis functions 96 | """ 97 | Hermite_dict = {} 98 | for nx in 
range(max_mode_x): 99 | for ny in range(max_mode_y): 100 | Hermite_dict[f'|HG{nx}{ny}>'] = Hermite_gauss(lam, refractive_index, W0, nx, ny, z, x, y) 101 | return np.array(list(Hermite_dict.values())), [*Hermite_dict] 102 | 103 | 104 | def LaguerreBank( 105 | lam, 106 | refractive_index, 107 | W0, 108 | max_mode_p, 109 | max_mode_l, 110 | x, 111 | y, 112 | z=0, 113 | get_dict: bool = False, 114 | ): 115 | """ 116 | generates a dictionary of Laguerre Gauss basis functions 117 | 118 | Parameters 119 | ---------- 120 | lam; wavelength 121 | refractive_index: refractive index 122 | W0: beam waist 123 | max_mode_p: maximum projection mode 1st axis 124 | max_mode_l: maximum projection mode 2nd axis 125 | x: transverse points, x axis 126 | y: transverse points, y axis 127 | z: projection longitudinal position 128 | get_dict: (True/False) if True, the function will return a dictionary, 129 | else the dictionary is splitted to basis functions np.array and list of dictionary keys. 130 | 131 | Returns 132 | ------- 133 | dictionary of Laguerre Gauss basis functions 134 | """ 135 | Laguerre_dict = {} 136 | for p in range(max_mode_p): 137 | for l in range(-max_mode_l, max_mode_l + 1): 138 | Laguerre_dict[f'|LG{p}{l}>'] = Laguerre_gauss(lam, refractive_index, W0, l, p, z, x, y) 139 | if get_dict: 140 | return Laguerre_dict 141 | 142 | return np.array(list(Laguerre_dict.values())), [*Laguerre_dict] 143 | 144 | 145 | def TomographyBankLG( 146 | lam, 147 | refractive_index, 148 | W0, 149 | max_mode_p, 150 | max_mode_l, 151 | x, 152 | y, 153 | z=0, 154 | relative_phase: List[Union[Union[int, float], Any]] = None, 155 | tomography_quantum_state: str = None, 156 | ): 157 | """ 158 | generates a dictionary of basis function with projections into two orthogonal LG bases and mutually unbiased 159 | bases (MUBs). The MUBs are constructed from superpositions of the two orthogonal LG bases. 
160 | according to: https://doi.org/10.1364/AOP.11.000067 161 | 162 | Parameters 163 | ---------- 164 | lam; wavelength 165 | refractive_index: refractive index 166 | W0: beam waist 167 | max_mode_p: maximum projection mode 1st axis 168 | max_mode_l: maximum projection mode 2nd axis 169 | x: transverse points, x axis 170 | y: transverse points, y axis 171 | z: projection longitudinal position 172 | relative_phase: The relative phase between the mutually unbiased bases (MUBs) states 173 | tomography_quantum_state: the current quantum state we calculate it tomography matrix. 174 | currently we support: qubit/qutrit 175 | 176 | Returns 177 | ------- 178 | dictionary of bases functions used for constructing the tomography matrix 179 | """ 180 | 181 | TOMO_dict = \ 182 | LaguerreBank( 183 | lam, 184 | refractive_index, 185 | W0, 186 | max_mode_p, 187 | max_mode_l, 188 | x, y, z, 189 | get_dict=True) 190 | 191 | if tomography_quantum_state is QUBIT: 192 | del TOMO_dict['|LG00>'] 193 | 194 | LG_modes, LG_string = np.array(list(TOMO_dict.values())), [*TOMO_dict] 195 | 196 | for m in range(len(TOMO_dict) - 1, -1, -1): 197 | for n in range(m - 1, -1, -1): 198 | for k in range(len(relative_phase)): 199 | TOMO_dict[f'{LG_string[m]}+e^j{str(relative_phase[k]/np.pi)}π{LG_string[n]}'] = \ 200 | (1 / np.sqrt(2)) * (LG_modes[m] + np.exp(1j * relative_phase[k]) * LG_modes[n]) 201 | 202 | return np.array(list(TOMO_dict.values())), [*TOMO_dict] 203 | 204 | def TomographyBankHG( 205 | lam, 206 | refractive_index, 207 | W0, 208 | max_mode_x, 209 | max_mode_y, 210 | x, 211 | y, 212 | z=0, 213 | relative_phase: List[Union[Union[int, float], Any]] = None, 214 | tomography_quantum_state: str = None, 215 | ): 216 | """ 217 | generates a dictionary of basis function with projections into two orthogonal HG bases and mutually unbiased 218 | bases (MUBs). The MUBs are constructed from superpositions of the two orthogonal HG bases. 
219 | according to: https://doi.org/10.1364/AOP.11.000067 220 | 221 | Parameters 222 | ---------- 223 | lam; wavelength 224 | refractive_index: refractive index 225 | W0: beam waist 226 | max_mode_x: maximum projection mode 1st axis 227 | max_mode_y: maximum projection mode 2nd axis 228 | x: transverse points, x axis 229 | y: transverse points, y axis 230 | z: projection longitudinal position 231 | relative_phase: The relative phase between the mutually unbiased bases (MUBs) states 232 | tomography_quantum_state: the current quantum state we calculate it tomography matrix. 233 | currently we support: qubit 234 | 235 | Returns 236 | ------- 237 | dictionary of bases functions used for constructing the tomography matrix 238 | """ 239 | 240 | TOMO_dict = \ 241 | HermiteBank( 242 | lam, 243 | refractive_index, 244 | W0, 245 | max_mode_x, 246 | max_mode_y, 247 | x, y, z, 248 | get_dict=True) 249 | 250 | if tomography_quantum_state is QUBIT: 251 | del TOMO_dict['|HG00>'] 252 | del TOMO_dict['|HG11>'] 253 | 254 | HG_modes, HG_string = np.array(list(TOMO_dict.values())), [*TOMO_dict] 255 | 256 | for m in range(len(TOMO_dict) - 1, -1, -1): 257 | for n in range(m - 1, -1, -1): 258 | for k in range(len(relative_phase)): 259 | TOMO_dict[f'{HG_string[m]}+e^j{str(relative_phase[k]/np.pi)}π{HG_string[n]}'] = \ 260 | (1 / np.sqrt(2)) * (HG_modes[m] + np.exp(1j * relative_phase[k]) * HG_modes[n]) 261 | 262 | return np.array(list(TOMO_dict.values())), [*TOMO_dict] 263 | 264 | 265 | def Hermite_gauss(lam, refractive_index, W0, nx, ny, z, X, Y, coef=None): 266 | """ 267 | Hermite Gauss in 2D 268 | 269 | Parameters 270 | ---------- 271 | lam: wavelength 272 | refractive_index: refractive index 273 | W0: beam waists 274 | n, m: order of the HG beam 275 | z: the place in z to calculate for 276 | x,y: matrices of x and y 277 | coef 278 | 279 | Returns 280 | ------- 281 | Hermite-Gaussian beam of order n,m in 2D 282 | """ 283 | k = 2 * np.pi * refractive_index / lam 284 | z0 = np.pi * W0 
** 2 * refractive_index / lam # Rayleigh range 285 | Wz = W0 * np.sqrt(1 + (z / z0) ** 2) # w(z), the variation of the spot size 286 | 287 | invR = z / ((z ** 2) + (z0 ** 2)) # radius of curvature 288 | gouy = (nx + ny + 1)*np.arctan(z/z0) 289 | if coef is None: 290 | coefx = np.sqrt(np.sqrt(2/pi) / (2**nx * math.factorial(nx))) 291 | coefy = np.sqrt(np.sqrt(2/pi) / (2**ny * math.factorial(ny))) 292 | coef = coefx * coefy 293 | U = coef * \ 294 | (W0/Wz) * np.exp(-(X**2 + Y**2) / Wz**2) * \ 295 | HermiteP(nx, np.sqrt(2) * X / Wz) * \ 296 | HermiteP(ny, np.sqrt(2) * Y / Wz) * \ 297 | np.exp(-1j * (k * (X**2 + Y**2) / 2) * invR) * \ 298 | np.exp(1j * gouy) 299 | 300 | return U 301 | 302 | 303 | def Laguerre_gauss(lam, refractive_index, W0, l, p, z, x, y, coef=None): 304 | """ 305 | Laguerre Gauss in 2D 306 | 307 | Parameters 308 | ---------- 309 | lam: wavelength 310 | refractive_index: refractive index 311 | W0: beam waists 312 | l, p: order of the LG beam 313 | z: the place in z to calculate for 314 | x,y: matrices of x and y 315 | coef 316 | 317 | Returns 318 | ------- 319 | Laguerre-Gaussian beam of order l,p in 2D 320 | """ 321 | k = 2 * np.pi * refractive_index / lam 322 | z0 = np.pi * W0 ** 2 * refractive_index / lam # Rayleigh range 323 | Wz = W0 * np.sqrt(1 + (z / z0) ** 2) # w(z), the variation of the spot size 324 | r = np.sqrt(x**2 + y**2) 325 | phi = np.arctan2(y, x) 326 | 327 | invR = z / ((z ** 2) + (z0 ** 2)) # radius of curvature 328 | gouy = (np.abs(l)+2*p+1)*np.arctan(z/z0) 329 | if coef is None: 330 | coef = np.sqrt(2*math.factorial(p)/(np.pi * math.factorial(p + np.abs(l)))) 331 | 332 | U = coef * \ 333 | (W0/Wz)*(r*np.sqrt(2)/Wz)**(np.abs(l)) * \ 334 | np.exp(-r**2 / Wz**2) * \ 335 | LaguerreP(p, l, 2 * r**2 / Wz**2) * \ 336 | np.exp(-1j * (k * r**2 / 2) * invR) * \ 337 | np.exp(-1j * l * phi) * \ 338 | np.exp(1j * gouy) 339 | return U 340 | 341 | 342 | def HermiteP(n, x): 343 | """ 344 | Hermite polynomial of rank n Hn(x) 345 | 346 | Parameters 
347 | ---------- 348 | n: order of the LG beam 349 | x: matrix of x 350 | 351 | Returns 352 | ------- 353 | Hermite polynomial 354 | """ 355 | if n == 0: 356 | return 1 357 | elif n == 1: 358 | return 2 * x 359 | else: 360 | return 2 * x * HermiteP(n - 1, x) - 2 * (n - 1) * HermiteP(n - 2, x) 361 | 362 | 363 | def LaguerreP(p, l, x): 364 | """ 365 | Generalized Laguerre polynomial of rank p,l L_p^|l|(x) 366 | 367 | Parameters 368 | ---------- 369 | l, p: order of the LG beam 370 | x: matrix of x 371 | 372 | Returns 373 | ------- 374 | Generalized Laguerre polynomial 375 | """ 376 | if p == 0: 377 | return 1 378 | elif p == 1: 379 | return 1 + np.abs(l)-x 380 | else: 381 | return ((2*p-1+np.abs(l)-x)*LaguerreP(p-1, l, x) - (p-1+np.abs(l))*LaguerreP(p-2, l, x))/p 382 | 383 | 384 | class Beam(ABC): 385 | """ 386 | A class that holds everything to do with a beam 387 | """ 388 | def __init__(self, 389 | lam: float, 390 | ctype, 391 | polarization: str, 392 | T: float, 393 | power: float = 0): 394 | 395 | """ 396 | 397 | Parameters 398 | ---------- 399 | lam: beam's wavelength 400 | ctype: function that holds crystal type fo calculating refractive index 401 | polarization: Polarization of the beam 402 | T: crystal's temperature [Celsius Degrees] 403 | power: beam power [watt] 404 | """ 405 | 406 | self.lam = lam 407 | self.n = ctype(lam * 1e6, T, polarization) # refractive index 408 | self.w = 2 * np.pi * c / lam # frequency 409 | self.k = 2 * np.pi * ctype(lam * 1e6, T, polarization) / lam # wave vector 410 | self.power = power # beam power 411 | 412 | 413 | class Beam_profile(ABC): 414 | def __init__( 415 | self, 416 | pump_coeffs_real, 417 | pump_coeffs_imag, 418 | waist_pump, 419 | power_pump, 420 | x, 421 | y, 422 | dx, 423 | dy, 424 | max_mode1, 425 | max_mode2, 426 | pump_basis: str, 427 | lam_pump, 428 | refractive_index, 429 | learn_pump_coeffs: bool = False, 430 | learn_pump_waists: bool = False, 431 | z: float = 0., 432 | ): 433 | 434 | 435 | self.x = x 436 | 
self.y = y 437 | self.z = z 438 | self.learn_pump_coeffs = learn_pump_coeffs 439 | self.learn_pump_waists = learn_pump_waists 440 | self.learn_pump = learn_pump_coeffs or learn_pump_waists 441 | self.lam_pump = lam_pump 442 | self.pump_basis = pump_basis 443 | self.max_mode1 = max_mode1 444 | self.max_mode2 = max_mode2 445 | self.power = power_pump 446 | self.crystal_dx = dx 447 | self.crystal_dy = dy 448 | self.refractive_index = refractive_index 449 | 450 | if not self.learn_pump_coeffs: 451 | self.pump_coeffs_real, \ 452 | self.pump_coeffs_imag = pump_coeffs_real, pump_coeffs_imag 453 | if not self.learn_pump_waists: 454 | self.waist_pump = waist_pump 455 | 456 | if self.pump_basis.lower() == 'lg': # Laguerre-Gauss 457 | self.coef = np.zeros(len(waist_pump), dtype=np.float32) 458 | idx = 0 459 | for p in range(self.max_mode1): 460 | for l in range(-self.max_mode2, self.max_mode2 + 1): 461 | 462 | self.coef = index_update( 463 | self.coef, idx, 464 | np.sqrt(2 * math.factorial(p) / (np.pi * math.factorial(p + np.abs(l)))) 465 | ) 466 | 467 | idx += 1 468 | 469 | if not self.learn_pump: 470 | self.E = self._profile_laguerre_gauss(pump_coeffs_real, pump_coeffs_imag, waist_pump) 471 | 472 | elif self.pump_basis.lower() == "hg": # Hermite-Gauss 473 | self.coef = np.zeros(len(waist_pump), dtype=np.float32) 474 | idx = 0 475 | for nx in range(self.max_mode1): 476 | for ny in range(self.max_mode2): 477 | self.coef = index_update( 478 | self.coef, idx, 479 | np.sqrt(np.sqrt(2 / pi) / (2 ** nx * math.factorial(nx))) * 480 | np.sqrt(np.sqrt(2 / pi) / (2 ** ny * math.factorial(ny)))) 481 | 482 | idx += 1 483 | 484 | if not self.learn_pump: 485 | self.E = self._profile_hermite_gauss(pump_coeffs_real, pump_coeffs_imag, waist_pump) 486 | 487 | 488 | def create_profile(self, pump_coeffs_real, pump_coeffs_imag, waist_pump): 489 | if self.learn_pump: 490 | if self.pump_basis.lower() == 'lg': # Laguerre-Gauss 491 | if self.learn_pump_coeffs and self.learn_pump_waists: 492 | self.E 
= self._profile_laguerre_gauss( 493 | pump_coeffs_real, pump_coeffs_imag, waist_pump 494 | ) 495 | elif self.learn_pump_coeffs: 496 | self.E = self._profile_laguerre_gauss( 497 | pump_coeffs_real, pump_coeffs_imag, self.waist_pump 498 | ) 499 | else: 500 | self.E = self._profile_laguerre_gauss( 501 | self.pump_coeffs_real, self.pump_coeffs_imag, waist_pump 502 | ) 503 | 504 | elif self.pump_basis.lower() == 'hg': # Hermite-Gauss 505 | if self.learn_pump_coeffs and self.learn_pump_waists: 506 | self.E = self._profile_hermite_gauss( 507 | pump_coeffs_real, pump_coeffs_imag, waist_pump 508 | ) 509 | elif self.learn_pump_coeffs: 510 | self.E = self._profile_hermite_gauss( 511 | pump_coeffs_real, pump_coeffs_imag, self.waist_pump 512 | ) 513 | else: 514 | self.E = self._profile_hermite_gauss( 515 | self.pump_coeffs_real, self.pump_coeffs_imag, waist_pump 516 | ) 517 | 518 | def _profile_laguerre_gauss( 519 | self, 520 | pump_coeffs_real, 521 | pump_coeffs_imag, 522 | waist_pump 523 | ): 524 | coeffs = pump_coeffs_real + 1j * pump_coeffs_imag 525 | [X, Y] = np.meshgrid(self.x, self.y) 526 | pump_profile = 0. 527 | idx = 0 528 | for p in range(self.max_mode1): 529 | for l in range(-self.max_mode2, self.max_mode2 + 1): 530 | pump_profile += coeffs[idx] * \ 531 | Laguerre_gauss(self.lam_pump, self.refractive_index, 532 | waist_pump[idx] * 1e-5, l, p, self.z, X, Y, self.coef[idx]) 533 | idx += 1 534 | 535 | pump_profile = fix_power(pump_profile, self.power, self.refractive_index, 536 | self.crystal_dx, self.crystal_dy)[np.newaxis, :, :] 537 | return pump_profile 538 | 539 | def _profile_hermite_gauss( 540 | self, 541 | pump_coeffs_real, 542 | pump_coeffs_imag, 543 | waist_pump 544 | ): 545 | 546 | coeffs = pump_coeffs_real + 1j * pump_coeffs_imag 547 | [X, Y] = np.meshgrid(self.x, self.y) 548 | pump_profile = 0. 
549 | idx = 0 550 | for nx in range(self.max_mode1): 551 | for ny in range(self.max_mode2): 552 | pump_profile += coeffs[idx] * \ 553 | Hermite_gauss(self.lam_pump, self.refractive_index, 554 | waist_pump[idx] * 1e-5, nx, ny, self.z, X, Y, self.coef[idx]) 555 | idx += 1 556 | 557 | pump_profile = fix_power(pump_profile, self.power, self.refractive_index, 558 | self.crystal_dx, self.crystal_dy)[np.newaxis, :, :] 559 | return pump_profile 560 | 561 | 562 | class Crystal_hologram(ABC): 563 | def __init__( 564 | self, 565 | crystal_coeffs_real, 566 | crystal_coeffs_imag, 567 | r_scale, 568 | x, 569 | y, 570 | max_mode1, 571 | max_mode2, 572 | crystal_basis, 573 | lam_signal, 574 | refractive_index, 575 | learn_crystal_coeffs: bool = False, 576 | learn_crystal_waists: bool = False, 577 | z: float = 0., 578 | ): 579 | 580 | self.x = x 581 | self.y = y 582 | self.z = z 583 | self.learn_crystal_coeffs = learn_crystal_coeffs 584 | self.learn_crystal_waists = learn_crystal_waists 585 | self.learn_crystal = learn_crystal_coeffs or learn_crystal_waists 586 | self.refractive_index = refractive_index 587 | self.lam_signal = lam_signal 588 | self.crystal_basis = crystal_basis 589 | self.max_mode1 = max_mode1 590 | self.max_mode2 = max_mode2 591 | 592 | if not self.learn_crystal_coeffs: 593 | self.crystal_coeffs_real, \ 594 | self.crystal_coeffs_imag = crystal_coeffs_real, crystal_coeffs_imag 595 | if not self.learn_crystal_waists: 596 | self.r_scale = r_scale 597 | 598 | 599 | 600 | if crystal_basis.lower() == 'ft': # Fourier-Taylor 601 | if not self.learn_crystal: 602 | self.crystal_profile = self._profile_fourier_taylor(crystal_coeffs_real, crystal_coeffs_imag, r_scale) 603 | 604 | elif crystal_basis.lower() == 'fb': # Fourier-Bessel 605 | 606 | [X, Y] = np.meshgrid(self.x, self.y) 607 | self.coef = np.zeros(len(r_scale), dtype=np.float32) 608 | idx = 0 609 | for p in range(self.max_mode1): 610 | for l in range(-self.max_mode2, self.max_mode2 + 1): 611 | rad = np.sqrt(X ** 2 + 
Y ** 2) / (r_scale[idx] * 1e-5) 612 | self.coef = index_update( 613 | self.coef, idx, 614 | sp.jv(0, sp.jn_zeros(0, p + 1)[-1] * rad) 615 | ) 616 | idx += 1 617 | 618 | if not self.learn_crystal: 619 | self.crystal_profile = self._profile_fourier_bessel(crystal_coeffs_real, crystal_coeffs_imag) 620 | 621 | elif crystal_basis.lower() == 'lg': # Laguerre-Gauss 622 | 623 | self.coef = np.zeros(len(r_scale), dtype=np.float32) 624 | idx = 0 625 | for p in range(self.max_mode1): 626 | for l in range(-self.max_mode2, self.max_mode2 + 1): 627 | self.coef = index_update( 628 | self.coef, idx, 629 | np.sqrt(2 * math.factorial(p) / (np.pi * math.factorial(p + np.abs(l)))) 630 | ) 631 | idx += 1 632 | 633 | if not self.learn_crystal: 634 | self.crystal_profile = self._profile_laguerre_gauss(crystal_coeffs_real, crystal_coeffs_imag, r_scale) 635 | 636 | elif crystal_basis.lower() == 'hg': # Hermite-Gauss 637 | 638 | self.coef = np.zeros(len(r_scale), dtype=np.float32) 639 | idx = 0 640 | for m in range(self.max_mode1): 641 | for n in range(self.max_mode2): 642 | self.coef = index_update( 643 | self.coef, idx, 644 | np.sqrt(np.sqrt(2 / pi) / (2 ** m * math.factorial(m))) * 645 | np.sqrt(np.sqrt(2 / pi) / (2 ** n * math.factorial(n))) 646 | ) 647 | 648 | idx += 1 649 | 650 | if not self.learn_crystal: 651 | self.crystal_profile = self._profile_hermite_gauss(crystal_coeffs_real, crystal_coeffs_imag, r_scale) 652 | 653 | def create_profile( 654 | self, 655 | crystal_coeffs_real, 656 | crystal_coeffs_imag, 657 | r_scale, 658 | ): 659 | if self.learn_crystal: 660 | if self.crystal_basis.lower() == 'ft': # Fourier-Taylor 661 | if self.learn_crystal_coeffs and self.learn_crystal_waists: 662 | self.crystal_profile = self._profile_fourier_taylor( 663 | crystal_coeffs_real, crystal_coeffs_imag, r_scale 664 | ) 665 | elif self.learn_crystal_coeffs: 666 | self.crystal_profile = self._profile_fourier_taylor( 667 | crystal_coeffs_real, crystal_coeffs_imag, self.r_scale 668 | ) 669 | else: 670 
| self.crystal_profile = self._profile_fourier_taylor( 671 | self.crystal_coeffs_real, self.crystal_coeffs_imag, r_scale 672 | ) 673 | 674 | elif self.crystal_basis.lower() == 'fb': # Fourier-Bessel 675 | if self.learn_crystal_coeffs: 676 | self.crystal_profile = self._profile_fourier_bessel( 677 | crystal_coeffs_real, crystal_coeffs_imag 678 | ) 679 | else: 680 | self.crystal_profile = self._profile_fourier_bessel( 681 | self.crystal_coeffs_real, self.crystal_coeffs_imag 682 | ) 683 | 684 | elif self.crystal_basis.lower() == 'lg': # Laguerre-Gauss 685 | if self.learn_crystal_coeffs and self.learn_crystal_waists: 686 | self.crystal_profile = self._profile_laguerre_gauss( 687 | crystal_coeffs_real, crystal_coeffs_imag, r_scale 688 | ) 689 | elif self.learn_crystal_coeffs: 690 | self.crystal_profile = self._profile_laguerre_gauss( 691 | crystal_coeffs_real, crystal_coeffs_imag, self.r_scale 692 | ) 693 | else: 694 | self.crystal_profile = self._profile_laguerre_gauss( 695 | self.crystal_coeffs_real, self.crystal_coeffs_imag, r_scale 696 | ) 697 | 698 | elif self.crystal_basis.lower() == 'hg': # Hermite-Gauss 699 | if self.learn_crystal_coeffs and self.learn_crystal_waists: 700 | self.crystal_profile = self._profile_hermite_gauss( 701 | crystal_coeffs_real, crystal_coeffs_imag, r_scale 702 | ) 703 | elif self.learn_crystal_coeffs: 704 | self.crystal_profile = self._profile_hermite_gauss( 705 | crystal_coeffs_real, crystal_coeffs_imag, self.r_scale 706 | ) 707 | else: 708 | self.crystal_profile = self._profile_hermite_gauss( 709 | self.crystal_coeffs_real, self.crystal_coeffs_imag, r_scale 710 | ) 711 | 712 | def _profile_fourier_taylor( 713 | self, 714 | crystal_coeffs_real, 715 | crystal_coeffs_imag, 716 | r_scale, 717 | ): 718 | coeffs = crystal_coeffs_real + 1j * crystal_coeffs_imag 719 | [X, Y] = np.meshgrid(self.x, self.y) 720 | phi_angle = np.arctan2(Y, X) 721 | crystal_profile = 0. 
722 | idx = 0 723 | for p in range(self.max_mode1): 724 | for l in range(-self.max_mode2, self.max_mode2 + 1): 725 | rad = np.sqrt(X**2 + Y**2) / (r_scale[idx] * 1e-5) 726 | crystal_profile += coeffs[idx] * rad**p * np.exp(-rad**2) * np.exp(-1j * l * phi_angle) 727 | idx += 1 728 | 729 | return crystal_profile 730 | 731 | def _profile_fourier_bessel( 732 | self, 733 | crystal_coeffs_real, 734 | crystal_coeffs_imag, 735 | ): 736 | coeffs = crystal_coeffs_real + 1j * crystal_coeffs_imag 737 | [X, Y] = np.meshgrid(self.x, self.y) 738 | phi_angle = np.arctan2(Y, X) 739 | crystal_profile = 0. 740 | idx = 0 741 | for p in range(self.max_mode1): 742 | for l in range(-self.max_mode2, self.max_mode2 + 1): 743 | crystal_profile += coeffs[idx] * self.coef[idx] * np.exp(-1j * l * phi_angle) 744 | idx += 1 745 | 746 | return crystal_profile 747 | 748 | def _profile_laguerre_gauss( 749 | self, 750 | crystal_coeffs_real, 751 | crystal_coeffs_imag, 752 | r_scale, 753 | ): 754 | coeffs = crystal_coeffs_real + 1j * crystal_coeffs_imag 755 | [X, Y] = np.meshgrid(self.x, self.y) 756 | idx = 0 757 | crystal_profile = 0. 758 | for p in range(self.max_mode1): 759 | for l in range(-self.max_mode2, self.max_mode2 + 1): 760 | crystal_profile += coeffs[idx] * \ 761 | Laguerre_gauss(self.lam_signal, self.refractive_index, 762 | r_scale[idx] * 1e-5, l, p, self.z, X, Y, self.coef[idx]) 763 | idx += 1 764 | 765 | return crystal_profile 766 | 767 | def _profile_hermite_gauss( 768 | self, 769 | crystal_coeffs_real, 770 | crystal_coeffs_imag, 771 | r_scale, 772 | ): 773 | coeffs = crystal_coeffs_real + 1j * crystal_coeffs_imag 774 | [X, Y] = np.meshgrid(self.x, self.y) 775 | idx = 0 776 | crystal_profile = 0. 
777 | for m in range(self.max_mode1): 778 | for n in range(self.max_mode2): 779 | crystal_profile += coeffs[idx] * \ 780 | Hermite_gauss(self.lam_signal, self.refractive_index, 781 | r_scale[idx] * 1e-5, m, n, self.z, X, Y, self.coef[idx]) 782 | 783 | idx += 1 784 | 785 | return crystal_profile 786 | 787 | 788 | def fix_power( 789 | A, 790 | power, 791 | n, 792 | dx, 793 | dy 794 | ): 795 | """ 796 | The function takes a field A and normalizes in to have the power indicated 797 | 798 | Parameters 799 | ---------- 800 | A 801 | power 802 | n 803 | dx 804 | dy 805 | 806 | Returns 807 | ------- 808 | 809 | """ 810 | output = A * np.sqrt(power) / np.sqrt(Power2D(A, n, dx, dy)) 811 | return output 812 | 813 | 814 | class DensMat(ABC): 815 | """ 816 | A class that holds tomography dimensions and 817 | tensors used for calculating qubit and qutrit tomography 818 | """ 819 | 820 | def __init__( 821 | self, 822 | projection_n_state2, 823 | tomography_dimension 824 | ): 825 | assert tomography_dimension in [2, 3], "tomography_dimension must be 2 or 3, " \ 826 | f"got {tomography_dimension}" 827 | 828 | self.projection_n_state2 = projection_n_state2 829 | self.tomography_dimension = tomography_dimension 830 | self.rotation_mats, self.masks = self.dens_mat_tensors() 831 | 832 | def dens_mat_tensors( 833 | self 834 | ): 835 | rot_mats_tensor = np.zeros([self.tomography_dimension ** 4, 836 | self.tomography_dimension ** 2, 837 | self.tomography_dimension ** 2], 838 | dtype='complex64') 839 | 840 | masks_tensor = np.zeros([self.tomography_dimension ** 4, 841 | self.projection_n_state2, 842 | self.projection_n_state2], 843 | dtype='complex64') 844 | 845 | if self.tomography_dimension == 2: 846 | mats = ( 847 | np.eye(2, dtype='complex64'), 848 | np.array([[0, 1], [1, 0]]), 849 | np.array([[0, -1j], [1j, 0]]), 850 | np.array([[1, 0], [0, -1]]) 851 | ) 852 | 853 | vecs = ( 854 | np.array([1, 1, 0, 0, 0, 0]), 855 | np.array([0, 0, 1, -1, 0, 0]), 856 | np.array([0, 0, 0, 0, 1, -1]), 
857 | np.array([1, -1, 0, 0, 0, 0]) 858 | ) 859 | 860 | else: # tomography_dimension == 3 861 | mats = ( 862 | np.eye(3, dtype='complex64'), 863 | np.array([[1, 0, 0], [0, -1, 0], [0, 0, 0]]), 864 | np.array([[0, 1, 0], [1, 0, 0], [0, 0, 0]]), 865 | np.array([[0, -1j, 0], [1j, 0, 0], [0, 0, 0]]), 866 | np.array([[0, 0, 1], [0, 0, 0], [1, 0, 0]]), 867 | np.array([[0, 0, -1j], [0, 0, 0], [1j, 0, 0]]), 868 | np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0]]), 869 | np.array([[0, 0, 0], [0, 0, -1j], [0, 1j, 0]]), 870 | (1 / np.sqrt(3)) * np.array([[1, 0, 0], [0, 1, 0], [0, 0, -2]]) 871 | ) 872 | 873 | vecs = ( 874 | np.array([1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 875 | np.array([1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 876 | np.array([0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 877 | np.array([0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0]), 878 | np.array([0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0]), 879 | np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0]), 880 | np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0]), 881 | np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1]), 882 | (np.sqrt(3) / 3) * np.array([1, 1, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) 883 | ) 884 | 885 | counter = 0 886 | 887 | for m in range(self.tomography_dimension ** 2): 888 | for n in range(self.tomography_dimension ** 2): 889 | norm1 = np.trace(mats[m] @ mats[m]) 890 | norm2 = np.trace(mats[n] @ mats[n]) 891 | mat1 = mats[m] / norm1 892 | mat2 = mats[n] / norm2 893 | rot_mats_tensor = index_add(rot_mats_tensor, index[counter, :, :], np.kron(mat1, mat2)) 894 | mask = np.dot(vecs[m].reshape(self.projection_n_state2, 1), 895 | np.transpose((vecs[n]).reshape(self.projection_n_state2, 1))) 896 | masks_tensor = index_add(masks_tensor, index[counter, :, :], mask) 897 | counter = counter + 1 898 | 899 | return rot_mats_tensor, masks_tensor 900 | --------------------------------------------------------------------------------