├── DPIdiagram.jpg ├── planet_astrometry_diagrm.png ├── dataset ├── interferometry1 │ ├── gt.fits │ └── obs.uvfits ├── fastmri_sample │ ├── mask │ │ ├── mask1.npy │ │ ├── mask16.npy │ │ ├── mask32.npy │ │ ├── mask4.npy │ │ └── mask8.npy │ └── mri │ │ ├── brain │ │ └── scan_0.pkl │ │ └── knee │ │ └── scan_0.pkl ├── interferometry_m87 │ └── synthetic_crescentfloorgaussian2 │ │ ├── groundtruth.fits │ │ ├── model_params.npy │ │ ├── obs_mring_synthdata_allnoise.uvfits │ │ ├── obs_mring_synthdata_thermal_only.uvfits │ │ ├── obs_mring_synthdata_thermal_phase_only.uvfits │ │ ├── obs_mring_synthdata_allnoise_scanavg_sysnoise2.uvfits │ │ ├── obs_mring_synthdata_thermal_only_scanavg_sysnoise2.uvfits │ │ └── obs_mring_synthdata_thermal_phase_only_scanavg_sysnoise2.uvfits └── orbital_fit │ └── betapic_astrometry.csv ├── DPItorch ├── MRI_helpers.py ├── DPI_2Dtoydist.py ├── DPI_MRI.py ├── DPI_interferometry.py ├── interferometry_helpers.py ├── generative_model │ ├── glow_model.py │ ├── cond_glow_model.py │ ├── realnvpfc_model.py │ └── cond_realnvpfc_model.py ├── orbit_helpers.py └── DPIx_orbit.py ├── README.md └── DPI.yml /DPIdiagram.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HeSunPU/DPI/HEAD/DPIdiagram.jpg -------------------------------------------------------------------------------- /planet_astrometry_diagrm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HeSunPU/DPI/HEAD/planet_astrometry_diagrm.png -------------------------------------------------------------------------------- /dataset/interferometry1/gt.fits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HeSunPU/DPI/HEAD/dataset/interferometry1/gt.fits -------------------------------------------------------------------------------- /dataset/interferometry1/obs.uvfits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HeSunPU/DPI/HEAD/dataset/interferometry1/obs.uvfits -------------------------------------------------------------------------------- /dataset/fastmri_sample/mask/mask1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HeSunPU/DPI/HEAD/dataset/fastmri_sample/mask/mask1.npy -------------------------------------------------------------------------------- /dataset/fastmri_sample/mask/mask16.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HeSunPU/DPI/HEAD/dataset/fastmri_sample/mask/mask16.npy -------------------------------------------------------------------------------- /dataset/fastmri_sample/mask/mask32.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HeSunPU/DPI/HEAD/dataset/fastmri_sample/mask/mask32.npy -------------------------------------------------------------------------------- /dataset/fastmri_sample/mask/mask4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HeSunPU/DPI/HEAD/dataset/fastmri_sample/mask/mask4.npy -------------------------------------------------------------------------------- /dataset/fastmri_sample/mask/mask8.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HeSunPU/DPI/HEAD/dataset/fastmri_sample/mask/mask8.npy -------------------------------------------------------------------------------- /dataset/fastmri_sample/mri/brain/scan_0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HeSunPU/DPI/HEAD/dataset/fastmri_sample/mri/brain/scan_0.pkl -------------------------------------------------------------------------------- /dataset/fastmri_sample/mri/knee/scan_0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HeSunPU/DPI/HEAD/dataset/fastmri_sample/mri/knee/scan_0.pkl -------------------------------------------------------------------------------- /dataset/interferometry_m87/synthetic_crescentfloorgaussian2/groundtruth.fits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HeSunPU/DPI/HEAD/dataset/interferometry_m87/synthetic_crescentfloorgaussian2/groundtruth.fits -------------------------------------------------------------------------------- /dataset/interferometry_m87/synthetic_crescentfloorgaussian2/model_params.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HeSunPU/DPI/HEAD/dataset/interferometry_m87/synthetic_crescentfloorgaussian2/model_params.npy -------------------------------------------------------------------------------- /dataset/interferometry_m87/synthetic_crescentfloorgaussian2/obs_mring_synthdata_allnoise.uvfits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HeSunPU/DPI/HEAD/dataset/interferometry_m87/synthetic_crescentfloorgaussian2/obs_mring_synthdata_allnoise.uvfits -------------------------------------------------------------------------------- /dataset/interferometry_m87/synthetic_crescentfloorgaussian2/obs_mring_synthdata_thermal_only.uvfits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HeSunPU/DPI/HEAD/dataset/interferometry_m87/synthetic_crescentfloorgaussian2/obs_mring_synthdata_thermal_only.uvfits -------------------------------------------------------------------------------- /dataset/interferometry_m87/synthetic_crescentfloorgaussian2/obs_mring_synthdata_thermal_phase_only.uvfits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HeSunPU/DPI/HEAD/dataset/interferometry_m87/synthetic_crescentfloorgaussian2/obs_mring_synthdata_thermal_phase_only.uvfits -------------------------------------------------------------------------------- /dataset/interferometry_m87/synthetic_crescentfloorgaussian2/obs_mring_synthdata_allnoise_scanavg_sysnoise2.uvfits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HeSunPU/DPI/HEAD/dataset/interferometry_m87/synthetic_crescentfloorgaussian2/obs_mring_synthdata_allnoise_scanavg_sysnoise2.uvfits -------------------------------------------------------------------------------- /dataset/interferometry_m87/synthetic_crescentfloorgaussian2/obs_mring_synthdata_thermal_only_scanavg_sysnoise2.uvfits: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HeSunPU/DPI/HEAD/dataset/interferometry_m87/synthetic_crescentfloorgaussian2/obs_mring_synthdata_thermal_only_scanavg_sysnoise2.uvfits -------------------------------------------------------------------------------- /dataset/interferometry_m87/synthetic_crescentfloorgaussian2/obs_mring_synthdata_thermal_phase_only_scanavg_sysnoise2.uvfits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HeSunPU/DPI/HEAD/dataset/interferometry_m87/synthetic_crescentfloorgaussian2/obs_mring_synthdata_thermal_phase_only_scanavg_sysnoise2.uvfits -------------------------------------------------------------------------------- /DPItorch/MRI_helpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.init as init 8 | import torch.nn.functional as functional 9 | 10 | torch.set_default_dtype(torch.float32) 11 | import torch.optim as optim 12 | import pickle 13 | import math 14 | 15 | ############################################################################### 16 | # Define the loss functions for MRI imaging 17 | ############################################################################### 18 | def Loss_kspace_diff(sigma): 19 | def func(y_true, y_pred): 20 | return torch.mean(torch.abs(y_pred - y_true), (1, 2, 3)) / sigma 21 | return func 22 | 23 | def Loss_kspace_diff2(sigma): 24 | def func(y_true, y_pred): 25 | return torch.mean((y_pred - y_true)**2, (1, 2, 3)) / (sigma)**2 26 | return func 27 | 28 | def Loss_l1(y_pred): 29 | # image prior - sparsity loss 30 | return torch.mean(torch.abs(y_pred), (-1, -2)) 31 | 32 | def Loss_TSV(y_pred): 33 | # image prior - total squared variation loss 34 | return torch.mean((y_pred[:, 1::, :] - y_pred[:, 0:-1, :])**2, (-1, -2)) + torch.mean((y_pred[:, :, 1::] - y_pred[:, :, 0:-1])**2, (-1, -2)) 35 | 36 | def Loss_TV(y_pred): 37 | # image prior - total variation loss 38 | return torch.mean(torch.abs(y_pred[:, 1::, :] - y_pred[:, 0:-1, :]), (-1, -2)) + torch.mean(torch.abs(y_pred[:, :, 1::] - y_pred[:, :, 0:-1]), (-1, -2)) 39 | 40 | # def Loss_TV(y_pred): 41 | # # image prior - total variation loss 42 | # eps = 1e-24 43 | # return torch.mean(torch.sqrt((y_pred[:, 1::, :]-y_pred[:, 0:-1, :])**2+eps), (-1, -2)) + torch.mean(torch.sqrt((y_pred[:, :, 1::]-y_pred[:, :, 0:-1])**2+eps), (-1, -2)) 44 | -------------------------------------------------------------------------------- /dataset/orbital_fit/betapic_astrometry.csv: -------------------------------------------------------------------------------- 1 | epoch,object,raoff,raoff_err,decoff,decoff_err,sep,sep_err,pa,pa_err,rv,rv_err 2 | 52953,1,233.5,22,340.7,22,413,22,34.42,3.52,, 3 | 55129,1,-152.8,14,-257,14,299,14,210.74,2.89,, 4 | 55194,1,-162.5,9,-259.3,8,306,9,212.07,1.71,, 5 | 55296,1,-172.6,7,-299.9,7,346,7,209.93,1.15,, 6 | 55467,1,-193.1,11,-330.7,11,383,11,210.28,1.73,, 7 | 55516,1,-207.4,8,-326.7,10,387,8,212.41,1.35,, 8 | 55517,1,-208.6,13,-329.5,14,390,13,212.34,2.13,, 9 | 55593,1,-210.9,9,-349.2,10,408,9,211.13,1.48,, 10 | 55646,1,-213.8,12,-368.4,14,426,13,210.13,1.81,, 11 | 55168,1,-158.1,10,-281.7,10,323,10,209.3,1.8,, 12 | 55168,1,-165.4,10,-295.9,10,339,10,209.2,1.7,, 13 | 55555,1,-221.1,6.8,-341.7,8.8,407,5,212.9,1.4,, 14 | 55854,1,-240.4,3.1,-386.3,3.1,455,3,211.9,0.4,, 15 | 55854,1,-236.8,4.9,-385,4.8,452,5,211.6,0.6,, 16 | 
56015,1,-228.9,3,-384,3.1,447,3,210.8,0.4,, 17 | 56015,1,-236.1,4.9,-380.8,4.8,448,5,211.8,0.6,, 18 | 56263,1,-243.6,12.9,-391.4,11,461,14,211.9,1.2,, 19 | 56265,1,-249.1,10,-398.6,9.9,470,10,212,1.2,, 20 | 56612.01000000001,1,,,,,0430.757,1.535,212.43,0.1681,, 21 | 56612.01000000001,1,,,,,0429.093,1.03,212.577,0.1509,, 22 | 56614.0175,1,,,,,0430.193,0.9790000000000001,212.463,0.1476,, 23 | 56635.99049999996,1,,,,,0425.502,1.003,212.508,0.1547,, 24 | 56635.99049999996,1,,,,,0424.389,0.987,212.845,0.151,, 25 | 56637.01250000004,1,,,,,0425.29,1.03,212.466,0.156,, 26 | 56969.01650000003,1,,,,,0356.24699999999996,0.9990000000000001,213.025,0.1947,, 27 | 57113.9945028854,1,,,,,0317.28299999999996,0.9319999999999999,213.13,0.1977,, 28 | 57332.00899822457,1,,,,,0250.491,1.453,214.145,0.3426,, 29 | 57360.989999143494,1,,,,,0240.194,1.056,213.575,0.3418,, 30 | 57377.99899968289,1,,,,,0234.514,0.9600000000000001,213.807,0.3006,, 31 | 57407.983600631924,1,,,,,0222.575,2.1310000000000002,214.838,0.4433,, 32 | 58381.98249999996,1,,,,,0141.9,5.3,28.16,1.82,, 33 | 56998.98299999997,1,,,,,0350.51,3.2,212.6,0.66,, 34 | 57045.99500072921,1,,,,,0335.488,0.9148000000000001,212.882,0.2014,, 35 | 57058.00350110996,1,,,,,0332.42,1.7,212.58,0.35,, 36 | 57295.983497082205,1,,,,,0262.02,1.7799999999999998,213.02,0.48,, 37 | 57355.98949898494,1,,,,,0242.04999999999998,2.5100000000000002,213.3,0.74,, 38 | 57382.013999810195,1,,,,,0234.84,1.8,213.79,0.51,, 39 | 57406.99540060067,1,,,,,0227.23,1.55,213.15,0.46,, 40 | 57472.98520268747,1,,,,,0203.66,1.42,213.9,0.46,, 41 | 57493.993603351875,1,,,,,0197.49,2.3600000000000003,213.88,0.83,, 42 | 57647.018208190944,1,,,,,0142.35999999999999,2.34,214.62,1.1,, 43 | 57675.01720907641,1,,,,,0134.5,2.46,215.5,1.22,, 44 | 57710.006810182844,1,,,,,0127.12,6.44,215.8,3.37,, 45 | 58378.003999999964,1,,,,,0140.46,3.12,29.71,1.67,, 46 | 58441.003000000004,1,,,,,0164.5,1.8,28.64,0.7,, 47 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Probabilistic Imaging (DPI) 2 | ![overview image](https://github.com/HeSunPU/DPI/blob/main/DPIdiagram.jpg) 3 | Deep Probabilistic Imaging: Uncertainty Quantification and Multi-modal Solution Characterization for Computational Imaging, [AAAI 2021](https://arxiv.org/abs/2010.14462) 4 | 5 | ## Run Examples 6 | 1. The simple 2D example can be run using the ipython notebook ```DPItorch/notebook/DPI toy 2D results.ipynb``` 7 | 8 | 2. The DPI radio interferometric example can be trained using ```DPItorch/DPI_interferometry.py```, and analyzed using ```DPItorch/notebook/DPI interferometry results.ipynb``` 9 | 10 | ```python DPI_interferometry.py --lr 1e-4 --clip 1e-3 --n_epoch 30000 --npix 32 --n_flow 16 --logdet 1.0 --save_path ./checkpoint/interferometry --obspath ../dataset/interferometry1/obs.uvfits``` 11 | 12 | 3. 
The DPI MRI example can be trained using ```DPItorch/DPI_MRI.py```, and analyzed using ```DPItorch/notebook/DPI MRI results.ipynb``` 13 | 14 | ```python DPI_MRI.py --lr 1e-5 --clip 1e-3 --n_epoch 100000 --npix 64 --n_flow 16 --ratio 4 --logdet 1.0 --tv 1e3 --save_path ./checkpoint/mri --impath ../dataset/fastmri_sample/mri/knee/scan_0.pkl --maskpath ../dataset/fastmri_sample/mask/mask4.npy --sigma 5e-7``` 15 | 16 | **Arguments:** 17 | 18 | General: 19 | * lr (float) - learning rate 20 | * clip (float) - threshold for gradient clipping 21 | * n_epoch (int) - number of epochs 22 | * npix (int) - size of reconstruction images (npix * npix) 23 | * n_flow (int) - number of affine coupling blocks 24 | * logdet (float) - weight of the entropy loss (larger means more diverse samples) 25 | * save_path (str) - folder that saves the learned DPI normalizing flow model 26 | 27 | For radio interferometric imaging: 28 | * obspath (str) - observation data file 29 | 30 | For compressed sensing MRI: 31 | * impath (str) - fastMRI image used to generate the MRI measurements 32 | * maskpath (str) - compressed sensing sampling mask 33 | * sigma (float) - standard deviation of the additive measurement noise 34 | 35 | ## Requirements 36 | General requirements for PyTorch release: 37 | * [pytorch](https://pytorch.org/) 38 | * [torchkbnufft](https://pypi.org/project/torchkbnufft/) 39 | 40 | For radio interferometric imaging: 41 | * [eht-imaging](https://pypi.org/project/ehtim/) 42 | * [astropy](https://pypi.org/project/astropy/) 43 | * [pyNFFT](https://pypi.org/project/pyNFFT/) 44 | 45 | Please check ```DPI.yml``` for the detailed Anaconda environment information. TensorFlow release is coming soon! 46 | 47 | ## Citation 48 | ``` 49 | @inproceedings{sun2021deep, 50 | author = {He Sun and Katherine L. Bouman}, 51 | title = {Deep Probabilistic Imaging: Uncertainty Quantification and Multi-modal Solution Characterization for Computational Imaging}, 52 | booktitle = {AAAI Conference on Artificial Intelligence (AAAI)}, 53 | year = {2021}, 54 | } 55 | ``` 56 | 57 | # alpha-Deep Probabilistic Imaging (alpha-DPI) 58 | ![overview image](https://github.com/HeSunPU/DPI/blob/main/planet_astrometry_diagrm.png) 59 | alpha-Deep Probabilistic Inference (alpha-DPI): efficient uncertainty quantification from exoplanet astrometry to black hole feature extraction, [arXiv](https://arxiv.org/abs/2201.08506) 60 | 61 | ## Run Examples 62 | 1. The alpha-DPI radio interferometric example can be trained using ```DPItorch/DPIx_interferometry.py``` 63 | 64 | ```python DPIx_interferometry.py --n_gaussian 2 --divergence_type alpha --alpha_divergence 0.95 --n_epoch 20000 --lr 1e-4 --fov 160 --save_path ./checkpoint/interferometry_m87_mcfe/synthetic/crescentfloornuissance2/alpha095closure --obspath ../dataset/interferometry_m87/synthetic_crescentfloorgaussian2/obs_mring_synthdata_allnoise_scanavg_sysnoise2.uvfits``` 65 | 66 | 2.
The alpha-DPI planet direct imaging orbit fitting example can be trained using ```DPItorch/DPIx_orbit.py``` 67 | 68 | ```python DPIx_orbit.py --divergence_type alpha --alpha_divergence 0.6 --coordinate_type cartesian --save_path ./checkpoint/orbit_beta_pic_b/cartesian/alpha06``` 69 | 70 | ## Citation 71 | ``` 72 | @article{sun2022alpha, 73 | title={alpha-Deep Probabilistic Inference (alpha-DPI): efficient uncertainty quantification from exoplanet astrometry to black hole feature extraction}, 74 | author={Sun, He and Bouman, Katherine L and Tiede, Paul and Wang, Jason J and Blunt, Sarah and Mawet, Dimitri}, 75 | journal={arXiv preprint arXiv:2201.08506}, 76 | year={2022} 77 | } 78 | 79 | ``` 80 | -------------------------------------------------------------------------------- /DPItorch/DPI_2Dtoydist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | import torch.nn.functional as functional 6 | 7 | torch.set_default_dtype(torch.float32) 8 | import torch.optim as optim 9 | import pickle 10 | import math 11 | # from sys import exit 12 | import matplotlib.pyplot as plt 13 | from sklearn import datasets 14 | 15 | plt.ion() 16 | 17 | from generative_model import realnvpfc_model 18 | 19 | ############################################################################### 20 | # Three examples of 2D toy loglikelihood 21 | ############################################################################### 22 | # Example 1 - Gaussian mixture 23 | x1, y1, a1, sigma1 = -0.5, -0.5, 1, 0.4 24 | x2, y2, a2, sigma2 = -0.5, 0.5, 1, 0.4 25 | x3, y3, a3, sigma3 = 0.5, -0.5, 1, 0.4 26 | x4, y4, a4, sigma4 = 0.5, 0.5, 1, 0.4 27 | log_prob1 = lambda x, y: torch.log(a1 * torch.exp(-1/sigma1**2*((x-x1)**2+(y-y1)**2)) + \ 28 | a2 * torch.exp(-1/sigma2**2*((x-x2)**2+(y-y2)**2)) + \ 29 | a3 * torch.exp(-1/sigma3**2*((x-x3)**2+(y-y3)**2)) + \ 30 | a4 * torch.exp(-1/sigma4**2*((x-x4)**2+(y-y4)**2))) 31 | 32 | # Example 2 - double crescent 33 | log_prob2 = lambda x, y: -0.5 * ((torch.sqrt((4*x)**2 + (4*y)**2)-2)/0.6)**2 + \ 34 | torch.log(torch.exp(-0.5 * (4*x-2)**2/0.6**2)+torch.exp(-0.5 * (4*x+2)**2/0.6**2)) 35 | 36 | # Example 3 - sinusoidal 37 | log_prob3 = lambda x, y: -0.5 * ((4*y - torch.sin(2*np.pi*x))/0.4)**2 38 | 39 | 40 | 41 | ############################################################################### 42 | # training a normalizing flow to approximate probability function 43 | ############################################################################### 44 | # loss_func = -log_prob1 # select the likelihood function to estimate 45 | loss_func = lambda x, y: -log_prob1(x, y) # select the likelihood function to estimate 46 | 47 | # Define the architecture of normalizing flow 48 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 49 | 50 | affine = True 51 | n_flow = 16#8#32 52 | generator = realnvpfc_model.RealNVP(2, n_flow, affine=affine, seqfrac=1/128).to(device) 53 | 54 | # define the optimizer 55 | optimizer = optim.Adam(generator.parameters(), lr = 1e-5) 56 | 57 | n_epoch = 3000 # number of optimization steps 58 | diversity = 1 # weight of the diversity loss 59 | n_samples = 512 # number of samples in each optimization step 60 | 61 | # start optimization 62 | loss_list = [] 63 | sample_list = [] 64 | for k in range(n_epoch): 65 | x_samples_transformed, logdet = generator.reverse(torch.randn((n_samples, 2)).to(device)) 66 | x_samples = 2 * 
torch.sigmoid(x_samples_transformed) - 1 67 | det_sigmoid = torch.sum(-x_samples_transformed - 2 * torch.nn.Softplus()(-x_samples_transformed), -1) 68 | logdet = logdet + det_sigmoid 69 | 70 | loss = torch.mean(loss_func(x_samples[:, 0], x_samples[:, 1]) - diversity * logdet) 71 | loss_list.append(loss.detach().cpu().numpy()) 72 | 73 | optimizer.zero_grad() 74 | loss.backward() 75 | nn.utils.clip_grad_norm_(generator.parameters(), 1e-4) 76 | optimizer.step() 77 | 78 | if (k + 1) % 100 == 0: 79 | print(f"epoch: {k:}, loss: {loss.item():.5f}") 80 | sample_list.append(x_samples.detach().cpu().numpy()) 81 | 82 | 83 | generator.eval() 84 | 85 | x_samples_transformed, logdet = generator.reverse(torch.randn((n_samples, 2)).to(device)) 86 | x_samples = 2 * torch.sigmoid(x_samples_transformed) - 1 87 | x_generated = x_samples.detach().cpu().numpy() 88 | plt.figure(), plt.plot(x_generated[:, 0], x_generated[:, 1], 'rx') 89 | 90 | 91 | x_range = np.arange(-0.95, 0.95, 0.05) 92 | y_range = np.arange(-0.95, 0.95, 0.05) 93 | 94 | X, Y = np.meshgrid(x_range, y_range) 95 | value = loss_func(torch.tensor(X), torch.tensor(Y)).numpy() 96 | Z = np.array(value) 97 | 98 | fig = plt.figure(figsize=(6,5)) 99 | left, bottom, width, height = 0.1, 0.1, 0.8, 0.8 100 | ax = fig.add_axes([left, bottom, width, height]) 101 | prob_true = np.exp(-Z) / np.sum(np.exp(-Z)) 102 | cp = plt.contourf(X, Y, prob_true, 40) 103 | cp = plt.contour 104 | cp = plt.colorbar() 105 | 106 | 107 | X_reshape = X.reshape((-1, 1)) 108 | Y_reshape = Y.reshape((-1, 1)) 109 | XY_reshape = np.concatenate([X_reshape, Y_reshape], -1).astype(np.float32) 110 | XY_reshape_torch = torch.tensor(XY_reshape).to(device) 111 | XY_reshape_transoformed = torch.log(1+XY_reshape_torch+1e-8) - torch.log(1-XY_reshape_torch+1e-8) 112 | random_samples, logdet = generator.forward(XY_reshape_transoformed) 113 | det_sigmoid = torch.sum(XY_reshape_transoformed + 2 * torch.nn.Softplus()(-XY_reshape_transoformed), -1) 114 | logdet = logdet + det_sigmoid 115 | 116 | 117 | logprob = logdet.cpu().detach().numpy() - 0.5 * np.sum((random_samples.cpu().detach().numpy())**2+np.log(2*np.pi), -1) 118 | logprob = logprob.reshape(X.shape) 119 | 120 | 121 | fig = plt.figure(figsize=(6,5)) 122 | left, bottom, width, height = 0.1, 0.1, 0.8, 0.8 123 | ax = fig.add_axes([left, bottom, width, height]) 124 | 125 | prob = np.exp(logprob) / np.sum(np.exp(logprob)) 126 | # prob = logprob 127 | cp = plt.contourf(X_reshape.reshape((len(x_range), len(y_range))), Y_reshape.reshape((len(x_range), len(y_range))), prob, 40) 128 | cp = plt.contour 129 | cp = plt.colorbar() 130 | 131 | 132 | -------------------------------------------------------------------------------- /DPI.yml: -------------------------------------------------------------------------------- 1 | name: torch_proj 2 | channels: 3 | - anaconda 4 | - astropy 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _pytorch_select=0.2=gpu_0 10 | - _tflow_select=2.3.0=mkl 11 | - absl-py=0.12.0=py37h06a4308_0 12 | - anyio=2.2.0=py37h06a4308_1 13 | - argon2-cffi=20.1.0=py37h27cfd23_1 14 | - arviz=0.11.2=pyhd3eb1b0_0 15 | - astor=0.8.1=py37h06a4308_0 16 | - astropy=4.2.1=py37h27cfd23_1 17 | - async_generator=1.10=py37h28b3542_0 18 | - attrs=20.3.0=pyhd3eb1b0_0 19 | - babel=2.9.0=pyhd3eb1b0_0 20 | - backcall=0.2.0=pyhd3eb1b0_0 21 | - blas=1.0=mkl 22 | - bleach=3.3.0=pyhd3eb1b0_0 23 | - brotlipy=0.7.0=py37h27cfd23_1003 24 | - bzip2=1.0.8=h7b6447c_0 25 | - c-ares=1.17.1=h27cfd23_0 26 | - 
ca-certificates=2021.4.13=h06a4308_1 27 | - cairo=1.16.0=hf32fb01_1 28 | - certifi=2020.12.5=py37h06a4308_0 29 | - cffi=1.14.5=py37h261ae71_0 30 | - cftime=1.4.1=py37h6323ea4_0 31 | - chardet=3.0.4=py37h06a4308_1003 32 | - cloudpickle=1.6.0=py_0 33 | - colorama=0.4.4=pyhd3eb1b0_0 34 | - colorlog=5.0.1=py37h06a4308_1 35 | - colormap=1.0.1=py_2 36 | - corner=2.2.1=pyhd8ed1ab_0 37 | - coverage=5.5=py37h27cfd23_2 38 | - cryptography=3.4.7=py37hd23ed53_0 39 | - cudatoolkit=10.0.130=0 40 | - cudnn=7.6.5=cuda10.0_0 41 | - curl=7.71.1=hbc83047_1 42 | - cycler=0.10.0=py37_0 43 | - cython=0.29.23=py37h2531618_0 44 | - cytoolz=0.11.0=py37h7b6447c_0 45 | - dask-core=2021.4.0=pyhd3eb1b0_0 46 | - dbus=1.13.18=hb2f20db_0 47 | - decorator=5.0.6=pyhd3eb1b0_0 48 | - defusedxml=0.7.1=pyhd3eb1b0_0 49 | - easydev=0.10.1=pyh9f0ad1d_0 50 | - entrypoints=0.3=py37_0 51 | - expat=2.3.0=h2531618_2 52 | - ffmpeg=4.0=hcdf2ecd_0 53 | - fftw=3.3.9=h27cfd23_1 54 | - fontconfig=2.13.1=h6c09931_0 55 | - freeglut=3.0.0=hf484d3e_5 56 | - freetype=2.10.4=h5ab3b9f_0 57 | - fsspec=0.9.0=pyhd3eb1b0_0 58 | - gast=0.4.0=py_0 59 | - glib=2.68.1=h36276a3_0 60 | - google-pasta=0.2.0=py_0 61 | - graphite2=1.3.14=h23475e2_0 62 | - grpcio=1.36.1=py37h2157cd5_1 63 | - gst-plugins-base=1.14.0=h8213a91_2 64 | - gstreamer=1.14.0=h28cd5cc_2 65 | - harfbuzz=1.8.8=hffaf4a1_0 66 | - hdf4=4.2.13=h3ca952b_2 67 | - hdf5=1.10.2=hba1933b_1 68 | - icu=58.2=he6710b0_3 69 | - idna=2.10=pyhd3eb1b0_0 70 | - imageio=2.9.0=pyhd3eb1b0_0 71 | - importlib-metadata=3.10.0=py37h06a4308_0 72 | - importlib_metadata=3.10.0=hd3eb1b0_0 73 | - intel-openmp=2021.2.0=h06a4308_610 74 | - ipykernel=5.3.4=py37h5ca1d4c_0 75 | - ipython=7.21.0=py37hb070fc8_0 76 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 77 | - jasper=2.0.14=h07fcdf6_1 78 | - jedi=0.17.0=py37_0 79 | - jinja2=2.11.3=pyhd3eb1b0_0 80 | - joblib=1.0.1=pyhd3eb1b0_0 81 | - jpeg=9b=h024ee3a_2 82 | - json5=0.9.5=py_0 83 | - jsonschema=3.2.0=py_2 84 | - jupyter-packaging=0.7.12=pyhd3eb1b0_0 85 | - jupyter_client=6.1.12=pyhd3eb1b0_0 86 | - jupyter_core=4.7.1=py37h06a4308_0 87 | - jupyter_server=1.4.1=py37h06a4308_0 88 | - jupyterlab=3.0.14=pyhd8ed1ab_0 89 | - jupyterlab_pygments=0.1.2=py_0 90 | - jupyterlab_server=2.4.0=pyhd3eb1b0_0 91 | - keras-applications=1.0.8=py_1 92 | - keras-preprocessing=1.1.2=pyhd3eb1b0_0 93 | - kiwisolver=1.3.1=py37h2531618_0 94 | - krb5=1.18.2=h173b8e3_0 95 | - lcms2=2.12=h3be6417_0 96 | - ld_impl_linux-64=2.33.1=h53a641e_7 97 | - libcurl=7.71.1=h20c2e04_1 98 | - libedit=3.1.20210216=h27cfd23_1 99 | - libffi=3.3=he6710b0_2 100 | - libgcc-ng=9.1.0=hdf63c60_0 101 | - libgfortran-ng=7.3.0=hdf63c60_0 102 | - libglu=9.0.0=hf484d3e_1 103 | - libnetcdf=4.6.1=h10edf3e_2 104 | - libopencv=3.4.2=hb342d67_1 105 | - libopus=1.3.1=h7b6447c_0 106 | - libpng=1.6.37=hbc83047_0 107 | - libprotobuf=3.14.0=h8c45485_0 108 | - libsodium=1.0.18=h7b6447c_0 109 | - libssh2=1.9.0=h1ba5d50_1 110 | - libstdcxx-ng=9.1.0=hdf63c60_0 111 | - libtiff=4.1.0=h2733197_1 112 | - libuuid=1.0.3=h1bed415_2 113 | - libvpx=1.7.0=h439df22_0 114 | - libxcb=1.14=h7b6447c_0 115 | - libxml2=2.9.10=hb55368b_3 116 | - locket=0.2.1=py37h06a4308_1 117 | - lz4-c=1.9.3=h2531618_0 118 | - markdown=3.3.4=py37h06a4308_0 119 | - markupsafe=1.1.1=py37h14c3975_1 120 | - matplotlib=3.3.3=py37h89c1867_0 121 | - matplotlib-base=3.3.3=py37h4f6019d_0 122 | - mistune=0.8.4=py37h14c3975_1001 123 | - mkl=2020.2=256 124 | - mkl-service=2.3.0=py37he8ac12f_0 125 | - mkl_fft=1.3.0=py37h54f3939_0 126 | - mkl_random=1.1.1=py37h0573a6f_0 127 | - 
nbclassic=0.2.6=pyhd3eb1b0_0 128 | - nbclient=0.5.3=pyhd3eb1b0_0 129 | - nbconvert=6.0.7=py37_0 130 | - nbformat=5.1.3=pyhd3eb1b0_0 131 | - ncurses=6.2=he6710b0_1 132 | - nest-asyncio=1.5.1=pyhd3eb1b0_0 133 | - netcdf4=1.4.2=py37h4b4f87f_0 134 | - networkx=2.5=py_0 135 | - nfft=3.2.4=hf8c457e_1000 136 | - ninja=1.10.2=hff7bd54_1 137 | - notebook=6.3.0=py37h06a4308_0 138 | - olefile=0.46=py37_0 139 | - opencv=3.4.2=py37h6fd60c2_1 140 | - openssl=1.1.1k=h27cfd23_0 141 | - packaging=20.9=pyhd3eb1b0_0 142 | - pandoc=2.12=h06a4308_0 143 | - pandocfilters=1.4.3=py37h06a4308_1 144 | - parse=1.14.0=py_0 145 | - parso=0.8.2=pyhd3eb1b0_0 146 | - partd=1.2.0=pyhd3eb1b0_0 147 | - pcre=8.44=he6710b0_0 148 | - pexpect=4.8.0=pyhd3eb1b0_3 149 | - pickleshare=0.7.5=pyhd3eb1b0_1003 150 | - pillow=8.2.0=py37he98fc37_0 151 | - pip=21.0.1=py37h06a4308_0 152 | - pixman=0.40.0=h7b6447c_0 153 | - prometheus_client=0.10.0=pyhd8ed1ab_0 154 | - prompt-toolkit=3.0.17=pyh06a4308_0 155 | - protobuf=3.14.0=py37h2531618_1 156 | - ptyprocess=0.7.0=pyhd3eb1b0_2 157 | - py-opencv=3.4.2=py37hb342d67_1 158 | - pycparser=2.20=py_2 159 | - pyerfa=1.7.2=py37h27cfd23_0 160 | - pygments=2.8.1=pyhd3eb1b0_0 161 | - pynfft=1.3.2=py37h03ebfcd_1002 162 | - pyopenssl=20.0.1=pyhd3eb1b0_1 163 | - pyparsing=2.4.7=pyhd3eb1b0_0 164 | - pyqt=5.9.2=py37h05f1152_2 165 | - pyrsistent=0.17.3=py37h7b6447c_0 166 | - pysocks=1.7.1=py37_1 167 | - python=3.7.10=hdb3f193_0 168 | - python-dateutil=2.8.1=pyhd3eb1b0_0 169 | - python_abi=3.7=1_cp37m 170 | - pytorch=1.3.1=cuda100py37h53c1284_0 171 | - pytorch-gpu=1.3.1=0 172 | - pytz=2021.1=pyhd3eb1b0_0 173 | - pywavelets=1.1.1=py37h7b6447c_2 174 | - pyyaml=5.4.1=py37h27cfd23_1 175 | - qt=5.9.7=h5867ecd_1 176 | - readline=8.1=h27cfd23_0 177 | - requests=2.24.0=py_0 178 | - scikit-image=0.16.2=py37h0573a6f_0 179 | - scikit-learn=0.23.2=py37h0573a6f_0 180 | - scipy=1.5.2=py37h0b6359f_0 181 | - seaborn=0.11.0=py_0 182 | - send2trash=1.5.0=pyhd3eb1b0_1 183 | - setuptools=47.3.0=py37_0 184 | - sip=4.19.8=py37hf484d3e_0 185 | - six=1.15.0=py37h06a4308_0 186 | - sniffio=1.2.0=py37h06a4308_1 187 | - sqlite=3.35.4=hdfb4753_0 188 | - tensorboard=1.14.0=py37hf484d3e_0 189 | - tensorflow=1.14.0=mkl_py37h45c423b_0 190 | - tensorflow-base=1.14.0=mkl_py37h7ce6ba3_0 191 | - tensorflow-estimator=1.14.0=py_0 192 | - termcolor=1.1.0=py37h06a4308_1 193 | - terminado=0.9.4=py37h06a4308_0 194 | - testpath=0.4.4=pyhd3eb1b0_0 195 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 196 | - tk=8.6.10=hbc83047_0 197 | - toolz=0.11.1=pyhd3eb1b0_0 198 | - tornado=6.1=py37h27cfd23_0 199 | - tqdm=4.51.0=pyh9f0ad1d_0 200 | - traitlets=5.0.5=pyhd3eb1b0_0 201 | - typing_extensions=3.7.4.3=pyha847dfd_0 202 | - urllib3=1.25.11=py_0 203 | - wcwidth=0.2.5=py_0 204 | - werkzeug=1.0.1=pyhd3eb1b0_0 205 | - wheel=0.36.2=pyhd3eb1b0_0 206 | - wrapt=1.12.1=py37h7b6447c_1 207 | - xarray=0.17.0=pyhd3eb1b0_0 208 | - xz=5.2.5=h7b6447c_0 209 | - yaml=0.2.5=h7b6447c_0 210 | - zeromq=4.3.4=h2531618_0 211 | - zipp=3.4.1=pyhd3eb1b0_0 212 | - zlib=1.2.11=h7b6447c_3 213 | - zstd=1.4.9=haebb681_0 214 | - pip: 215 | - aiohttp==3.7.3 216 | - aiohttp-cors==0.7.0 217 | - aioredis==1.3.1 218 | - alabaster==0.7.12 219 | - async-timeout==3.0.1 220 | - blessings==1.7 221 | - cached-property==1.5.2 222 | - cachetools==4.2.0 223 | - click==7.1.2 224 | - colorful==0.5.4 225 | - dill==0.3.3 226 | - docutils==0.17 227 | - ehtim==1.2.2 228 | - emcee==3.0.2 229 | - ephem==3.7.7.1 230 | - filelock==3.0.12 231 | - future==0.18.2 232 | - google-api-core==1.24.1 233 | - google-auth==1.24.0 234 | 
- googleapis-common-protos==1.52.0 235 | - gpustat==0.6.0 236 | - h5py==3.2.1 237 | - hdrhistogram==0.8.0 238 | - hiredis==1.1.0 239 | - imagesize==1.2.0 240 | - msgpack==1.0.2 241 | - multidict==5.1.0 242 | - nbsphinx==0.8.2 243 | - numpy==1.20.2 244 | - nvidia-ml-py3==7.352.0 245 | - opencensus==0.7.11 246 | - opencensus-context==0.1.2 247 | - orbitize==1.15.1 248 | - pandas==1.1.4 249 | - pandas-appender==0.9.4 250 | - paramsurvey==0.4.9 251 | - pbr==5.5.1 252 | - ptemcee==1.0.0 253 | - py-spy==0.3.3 254 | - pyarrow==2.0.0 255 | - pyasn1==0.4.8 256 | - pyasn1-modules==0.2.8 257 | - pyzmq==22.0.3 258 | - radvel==1.4.4 259 | - ray==1.1.0 260 | - redis==3.5.3 261 | - rsa==4.7 262 | - snowballstemmer==2.1.0 263 | - sphinx==3.5.3 264 | - sphinxcontrib-applehelp==1.0.2 265 | - sphinxcontrib-devhelp==1.0.2 266 | - sphinxcontrib-htmlhelp==1.0.3 267 | - sphinxcontrib-jsmath==1.0.1 268 | - sphinxcontrib-qthelp==1.0.3 269 | - sphinxcontrib-serializinghtml==1.1.4 270 | - torchcubicspline==0.0.3 271 | - torchkbnufft==0.3.2.post1 272 | - webencodings==0.5.1 273 | - yarl==1.6.3 274 | prefix: /home/groot/anaconda3/envs/torch_proj 275 | 276 | -------------------------------------------------------------------------------- /DPItorch/DPI_MRI.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | # import matplotlib.pyplot as plt 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.init as init 8 | import torch.nn.functional as functional 9 | 10 | torch.set_default_dtype(torch.float32) 11 | import torch.optim as optim 12 | import pickle 13 | import math 14 | import cv2 15 | 16 | 17 | from generative_model import glow_model 18 | from generative_model import realnvpfc_model 19 | 20 | from MRI_helpers import * 21 | 22 | 23 | import argparse 24 | 25 | # plt.ion() 26 | 27 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 28 | parser = argparse.ArgumentParser(description="Deep Probabilistic Imaging Trainer for MRI") 29 | parser.add_argument("--cuda", default=0, type=int, help="cuda index in use") 30 | parser.add_argument("--impath", default='../dataset/fastmri_sample/mri/knee/scan_0.pkl', type=str, help="MRI image scan") 31 | parser.add_argument("--maskpath", default='../dataset/fastmri_sample/mask/mask8.npy', type=str, help="MRI image scan mask") 32 | parser.add_argument("--save_path", default='./save_path_mri', type=str, help="file save path") 33 | parser.add_argument("--npix", default=64, type=int, help="image shape (pixels)") 34 | parser.add_argument("--ratio", default=1/8, type=float, help="MRI compression ratio") 35 | parser.add_argument("--sigma", default=5e-7, type=float, help="std of MRI additive measurement noise") 36 | parser.add_argument("--model_form", default='realnvp', type=str, help="form of the deep generative model") 37 | parser.add_argument("--n_flow", default=16, type=int, help="number of flows in RealNVP or Glow") 38 | parser.add_argument("--n_block", default=4, type=int, help="number of blocks in Glow") 39 | parser.add_argument("--lr", default=1e-5, type=float, help="learning rate") 40 | parser.add_argument("--n_epoch", default=3000, type=int, help="number of epochs for training RealNVP") 41 | parser.add_argument("--logdet", default=1.0, type=float, help="logdet weight") 42 | # parser.add_argument("--l1", default=1e4, type=float, help="l1 prior weight") 43 | parser.add_argument("--l1", default=0.0, type=float, help="l1 prior weight") 44 | parser.add_argument("--tv", default=1e3, 
type=float, help="tv prior weight") 45 | # parser.add_argument("--flux", default=0.1, type=float, help="flux prior weight") 46 | parser.add_argument("--clip", default=1e-2, type=float, help="gradient clip for neural network training") 47 | 48 | 49 | # def readMRIdata(filepath): 50 | # with open(filepath, 'rb') as f: 51 | # obj = pickle.load(f) 52 | # kspace = obj['kspace'] 53 | # target_image = obj['target'] 54 | # attributes = obj['attributes'] 55 | # return kspace, target_image, attributes 56 | 57 | def readMRIdata(filepath): 58 | with open(filepath, 'rb') as f: 59 | obj = pickle.load(f) 60 | target_image = obj['target'] 61 | return target_image 62 | 63 | def fft2c(data): 64 | # data = np.fft.ifftshift(data) 65 | data = np.fft.fft2(data, norm="ortho") 66 | # data = np.fft.fftshift(data) 67 | return np.stack((data.real, data.imag), axis=-1) 68 | 69 | 70 | def fft2c_torch(img): 71 | x = img.unsqueeze(-1) 72 | x = torch.cat([x, torch.zeros_like(x)], -1) 73 | kspace_pred = torch.fft(x, 2, normalized=True) 74 | return kspace_pred 75 | 76 | 77 | class Img_logscale(nn.Module): 78 | """ Custom Linear layer but mimics a standard linear layer """ 79 | def __init__(self, scale=1): 80 | super().__init__() 81 | log_scale = torch.Tensor(np.log(scale)*np.ones(1)) 82 | self.log_scale = nn.Parameter(log_scale) 83 | 84 | def forward(self): 85 | return self.log_scale 86 | 87 | 88 | if __name__ == "__main__": 89 | args = parser.parse_args() 90 | impath = args.impath 91 | npix = args.npix 92 | ratio = args.ratio 93 | save_path = args.save_path 94 | sigma = args.sigma 95 | 96 | if device == 'cuda': 97 | device = device + ':{}'.format(args.cuda) 98 | 99 | save_path = args.save_path 100 | if not os.path.exists(save_path): 101 | os.makedirs(save_path) 102 | 103 | # _, img_true, _ = readMRIdata(impath) 104 | img_true = readMRIdata(impath) 105 | img_true = cv2.resize(img_true, (npix, npix), interpolation=cv2.INTER_AREA) 106 | kspace = fft2c(img_true) 107 | kspace = kspace + np.random.normal(size=kspace.shape) * sigma 108 | mask = np.load(args.maskpath) 109 | mask[24:40, 24:40] = 1 110 | mask = np.fft.fftshift(mask) 111 | mask = np.stack((mask, mask), axis=-1) 112 | 113 | args.flux = np.sum(img_true) 114 | 115 | if args.model_form == 'realnvp': 116 | n_flow = args.n_flow 117 | affine = True 118 | img_generator = realnvpfc_model.RealNVP(npix*npix, n_flow, affine=affine).to(device) 119 | # img_generator.load_state_dict(torch.load(save_path+'/generativemodel_'+args.model_form+'_ratio{}_res{}flow{}logdet{}_tv'.format(args.ratio, npix, n_flow, args.logdet))) 120 | elif args.model_form == 'glow': 121 | n_channel = 1 122 | n_flow = args.n_flow 123 | n_block = args.n_block 124 | affine = True 125 | no_lu = False#True 126 | z_shapes = glow_model.calc_z_shapes(n_channel, npix, n_flow, n_block) 127 | img_generator = glow_model.Glow(n_channel, n_flow, n_block, affine=affine, conv_lu=not no_lu).to(device) 128 | 129 | 130 | logscale_factor = Img_logscale(scale=args.flux/(0.8*npix*npix)).to(device) 131 | # logscale_factor.load_state_dict(torch.load(save_path+'/generativescale_'+args.model_form+'_ratio{}_res{}flow{}logdet{}_tv'.format(args.ratio, npix, n_flow, args.logdet))) 132 | 133 | # define the losses and weights for MRI 134 | # Loss_kspace_img = Loss_kspace_diff(sigma) 135 | Loss_kspace_img = Loss_kspace_diff2(sigma) 136 | 137 | kspace_weight = 1.0 138 | imgl1_weight = args.l1 / args.flux 139 | imgtv_weight = args.tv * npix / args.flux 140 | logdet_weight = args.logdet / (0.5 * np.sum(mask)) 141 | 142 | kspace_true = 
torch.Tensor(mask * kspace).to(device=device) 143 | 144 | # optimize both scale and image generator 145 | lr = args.lr 146 | optimizer = optim.Adam(list(img_generator.parameters())+list(logscale_factor.parameters()), lr = lr) 147 | # optimizer = optim.Adam(img_generator.parameters(), lr = lr) 148 | 149 | n_epoch = args.n_epoch#30000#10000#100000#50000#100# 150 | loss_list = [] 151 | loss_prior_list = [] 152 | loss_kspace_list = [] 153 | logdet_list = [] 154 | loss_tv_list = [] 155 | loss_l1_list = [] 156 | 157 | 158 | 159 | n_batch = 64#32#8 160 | for k in range(n_epoch): 161 | if args.model_form == 'realnvp': 162 | z_sample = torch.randn(n_batch, npix*npix).to(device=device) 163 | elif args.model_form == 'glow': 164 | z_sample = [] 165 | for z in z_shapes: 166 | z_new = torch.randn(n_batch, *z) 167 | z_sample.append(z_new.to(device)) 168 | 169 | # generate image samples 170 | img_samp, logdet = img_generator.reverse(z_sample) 171 | img_samp = img_samp.reshape((-1, npix, npix)) 172 | 173 | # apply scale factor and sigmoid/softplus layer for positivity constraint 174 | logscale_factor_value = logscale_factor.forward() 175 | scale_factor = torch.exp(logscale_factor_value) 176 | img = torch.nn.Softplus()(img_samp) * scale_factor 177 | det_softplus = torch.sum(img_samp - torch.nn.Softplus()(img_samp), (1, 2)) 178 | det_scale = logscale_factor_value * npix * npix 179 | logdet = logdet + det_softplus + det_scale 180 | 181 | kspace_pred = fft2c_torch(img) 182 | loss_data = Loss_kspace_img(kspace_true, kspace_pred * torch.Tensor(mask).to(device)) / np.mean(mask) 183 | 184 | loss_l1 = Loss_l1(img) if imgl1_weight>0 else 0 185 | loss_tv = Loss_TV(img) if imgtv_weight>0 else 0 186 | 187 | loss_prior = imgtv_weight * loss_tv + imgl1_weight * loss_l1 188 | 189 | # if k < 0.0 * n_epoch: 190 | # loss = torch.mean(loss_data) + torch.mean(loss_prior) - 10*logdet_weight*torch.mean(logdet) 191 | # else: 192 | # loss = torch.mean(loss_data) + torch.mean(loss_prior) - logdet_weight*torch.mean(logdet) 193 | loss = torch.mean(loss_data) + torch.mean(loss_prior) - logdet_weight*torch.mean(logdet) 194 | 195 | 196 | loss_list.append(loss.detach().cpu().numpy()) 197 | loss_kspace_list.append(torch.mean(loss_data).detach().cpu().numpy()) 198 | loss_prior_list.append(torch.mean(loss_prior).detach().cpu().numpy()) 199 | logdet_list.append(-torch.mean(logdet).detach().cpu().numpy() / (npix*npix)) 200 | loss_tv_list.append(imgtv_weight * torch.mean(loss_tv).detach().cpu().numpy() if imgtv_weight>0 else 0) 201 | loss_l1_list.append(imgl1_weight * torch.mean(loss_l1).detach().cpu().numpy() if imgl1_weight>0 else 0) 202 | 203 | optimizer.zero_grad() 204 | loss.backward() 205 | nn.utils.clip_grad_norm_(list(img_generator.parameters())+ list(logscale_factor.parameters()), args.clip) 206 | # nn.utils.clip_grad_norm_(img_generator.parameters(), args.clip) 207 | optimizer.step() 208 | 209 | 210 | print(f"epoch: {k:}, loss: {loss_list[-1]:.5f}, loss kspace: {loss_kspace_list[-1]:.5f}, logdet: {logdet_list[-1]:.5f}") 211 | print(f"loss tv: {loss_tv_list[-1]:.5f}, loss l1: {loss_l1_list[-1]:.5f}") 212 | 213 | 214 | torch.save(img_generator.state_dict(), save_path+'/generativemodel_'+args.model_form+'_ratio{}_res{}flow{}logdet{}_tv2'.format(args.ratio, npix, n_flow, args.logdet)) 215 | torch.save(logscale_factor.state_dict(), save_path+'/generativescale_'+args.model_form+'_ratio{}_res{}flow{}logdet{}_tv2'.format(args.ratio, npix, n_flow, args.logdet)) 216 | 
np.save(save_path+'/generativeimage_'+args.model_form+'_ratio{}_res{}flow{}logdet{}_tv2.npy'.format(args.ratio, npix, n_flow, args.logdet), img.cpu().detach().numpy().squeeze()) 217 | 218 | loss_all = {} 219 | loss_all['total'] = np.array(loss_list) 220 | loss_all['kspace'] = np.array(loss_kspace_list) 221 | loss_all['logdet'] = np.array(logdet_list) 222 | loss_all['tv'] = np.array(loss_tv_list) 223 | 224 | loss_all['l1'] = np.array(loss_l1_list) 225 | np.save(save_path+'/loss_'+args.model_form+'_ratio{}_res{}flow{}logdet{}_tv2.npy'.format(args.ratio, npix, n_flow, args.logdet), loss_all) 226 | 227 | -------------------------------------------------------------------------------- /DPItorch/DPI_interferometry.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.init as init 8 | import torch.nn.functional as functional 9 | 10 | torch.set_default_dtype(torch.float32) 11 | import torch.optim as optim 12 | import pickle 13 | import math 14 | 15 | from torchkbnufft import KbNufft, AdjKbNufft 16 | from torchkbnufft.mri.dcomp_calc import calculate_radial_dcomp_pytorch 17 | from torchkbnufft.math import absolute 18 | 19 | # from loupe import models_vlbi, layers_vlbi # loupe package 20 | import ehtim as eh # eht imaging package 21 | 22 | from ehtim.observing.obs_helpers import * 23 | import ehtim.const_def as ehc 24 | from scipy.ndimage import gaussian_filter 25 | import skimage.transform 26 | # import helpers as hp 27 | import csv 28 | import sys 29 | import datetime 30 | import warnings 31 | import copy 32 | 33 | import gc 34 | import cv2 35 | 36 | from astropy.io import fits 37 | from pynfft.nfft import NFFT 38 | 39 | from generative_model import glow_model 40 | from generative_model import realnvpfc_model 41 | # from interferometry_loss import * 42 | # from interferometry_obs import * 43 | from interferometry_helpers import * 44 | 45 | import argparse 46 | 47 | plt.ion() 48 | 49 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 50 | parser = argparse.ArgumentParser(description="Deep Probabilistic Imaging Trainer for Interferometry") 51 | parser.add_argument("--cuda", default=0, type=int, help="cuda index in use") 52 | parser.add_argument("--obspath", default='../dataset/interferometry1/obs.uvfits', type=str, help="EHT observation file path") 53 | parser.add_argument("--impath", default='../dataset/interferometry1/gt.fits', type=str, help="ground-truth EHT image file path") 54 | parser.add_argument("--save_path", default='./save_path', type=str, help="file save path") 55 | parser.add_argument("--npix", default=32, type=int, help="image shape (pixels)") 56 | parser.add_argument("--fov", default=160, type=float, help="field of view of the image in micro-arcsecond") 57 | parser.add_argument("--prior_fwhm", default=50, type=float, help="fwhm of image prior in micro-arcsecond") 58 | parser.add_argument("--model_form", default='realnvp', type=str, help="form of the deep generative model") 59 | parser.add_argument("--n_flow", default=16, type=int, help="number of flows in RealNVP or Glow") 60 | parser.add_argument("--n_block", default=4, type=int, help="number of blocks in Glow") 61 | parser.add_argument("--lr", default=1e-4, type=float, help="learning rate") 62 | parser.add_argument("--n_epoch", default=30000, type=int, help="number of epochs for training RealNVP") 63 | parser.add_argument("--logdet",
default=1.0, type=float, help="logdet weight") 64 | parser.add_argument("--l1", default=1.0, type=float, help="l1 prior weight") 65 | parser.add_argument("--tsv", default=100.0, type=float, help="tsv prior weight") 66 | parser.add_argument("--flux", default=1000.0, type=float, help="flux prior weight") 67 | parser.add_argument("--center", default=1.0, type=float, help="centering prior weight") 68 | parser.add_argument("--mem", default=1024.0, type=float, help="mem prior weight") 69 | parser.add_argument("--clip", default=0.1, type=float, help="gradient clip for neural network training") 70 | 71 | parser.add_argument("--ttype", default='nfft', type=str, help="fourier transform computation method") 72 | 73 | 74 | class Img_logscale(nn.Module): 75 | """ Custom Linear layer but mimics a standard linear layer """ 76 | def __init__(self, scale=1): 77 | super().__init__() 78 | log_scale = torch.Tensor(np.log(scale)*np.ones(1)) 79 | self.log_scale = nn.Parameter(log_scale) 80 | 81 | def forward(self): 82 | return self.log_scale 83 | 84 | 85 | if __name__ == "__main__": 86 | args = parser.parse_args() 87 | obs_path = args.obspath 88 | gt_path = args.impath 89 | npix = args.npix 90 | 91 | if torch.cuda.is_available(): 92 | device = torch.device('cuda:{}'.format(args.cuda)) 93 | 94 | obs = eh.obsdata.load_uvfits(obs_path) 95 | 96 | # define the prior image for MEM regularizer 97 | flux_const = np.median(obs.unpack_bl('APEX', 'ALMA', 'amp')['amp']) 98 | prior_fwhm = args.prior_fwhm*eh.RADPERUAS#60*eh.RADPERUAS# 99 | fov = args.fov*eh.RADPERUAS 100 | zbl = flux_const#2.0#0.8# 101 | prior = eh.image.make_square(obs, npix, fov).add_gauss(zbl, (prior_fwhm, prior_fwhm, 0, 0, 0)) 102 | prior = prior.add_gauss(zbl*1e-6, (prior_fwhm, prior_fwhm, 0, prior_fwhm, prior_fwhm)) 103 | 104 | simim = prior.copy() 105 | 106 | # simim = eh.image.load_fits(gt_path) 107 | # simim = simim.regrid_image(fov, npix) 108 | simim.ra = obs.ra 109 | simim.dec = obs.dec 110 | simim.rf = obs.rf 111 | 112 | save_path = args.save_path 113 | if not os.path.exists(save_path): 114 | os.makedirs(save_path) 115 | 116 | 117 | # define the eht observation function 118 | ttype = args.ttype 119 | nufft_ob = KbNufft(im_size=(npix, npix), numpoints=3) 120 | dft_mat, ktraj_vis, pulsefac_vis_torch, cphase_ind_list, cphase_sign_list, camp_ind_list = Obs_params_torch(obs, simim, snrcut=0.0, ttype=ttype) 121 | eht_obs_torch = eht_observation_pytorch(npix, nufft_ob, dft_mat, ktraj_vis, pulsefac_vis_torch, cphase_ind_list, cphase_sign_list, camp_ind_list, device, ttype=ttype) 122 | 123 | 124 | if args.model_form == 'realnvp': 125 | n_flow = args.n_flow 126 | affine = True 127 | img_generator = realnvpfc_model.RealNVP(npix*npix, n_flow, affine=affine).to(device) 128 | # img_generator.load_state_dict(torch.load(save_path+'/generativemodel_'+args.model_form+'_res{}flow{}logdet{}_closure_fluxcentermemtsv'.format(npix, n_flow, args.logdet))) 129 | elif args.model_form == 'glow': 130 | n_channel = 1 131 | n_flow = args.n_flow 132 | n_block = args.n_block 133 | affine = True 134 | no_lu = False#True 135 | z_shapes = glow_model.calc_z_shapes(n_channel, npix, n_flow, n_block) 136 | img_generator = glow_model.Glow(n_channel, n_flow, n_block, affine=affine, conv_lu=not no_lu).to(device) 137 | 138 | logscale_factor = Img_logscale(scale=flux_const/(0.8*npix*npix)).to(device) 139 | # logscale_factor.load_state_dict(torch.load(save_path+'/generativescale_'+args.model_form+'_res{}flow{}logdet{}_closure_fluxcentermemtsv'.format(npix, n_flow, args.logdet))) 140 | 141 | 
# define the losses and weights for very long baseline interferometric imaging 142 | Loss_center_img = Loss_center(device, center=npix/2-0.5, dim=npix) 143 | Loss_flux_img = Loss_flux(flux_const) 144 | Loss_vis_img = Loss_vis_diff(obs.data['sigma'], device) 145 | Loss_cphase_img = Loss_angle_diff(obs.cphase['sigmacp'], device) 146 | Loss_logamp_img = Loss_logamp_diff(obs.data['sigma'], device) 147 | # Loss_logca_img = Loss_logca_diff(obs.camp['sigmaca'], device) 148 | Loss_logca_img2 = Loss_logca_diff2(obs.logcamp['sigmaca'], device) 149 | 150 | 151 | camp_weight = 1.0 152 | cphase_weight = len(obs.cphase['cphase'])/len(obs.camp['camp'])#1.0#10.0# 153 | visamp_weight = 0.0#1e-3#1.0#1e-5# 154 | imgl1_weight = args.l1 * npix*npix/flux_const#npix*npix/flux_const#1.0 155 | imgtsv_weight = args.tsv * npix*npix#100*npix*npix 156 | imgflux_weight = args.flux#1024#0.0#npix*npix#1.0#10 157 | imgcenter_weight = args.center*1e5/(npix*npix)#1e5/(npix*npix)#100#0.0#1.0#npix*npix#10# 158 | imgcrossentropy_weight = args.mem#1024#10*npix*npix 159 | logdet_weight = 2.0 * args.logdet / len(obs.camp['camp'])#args.logdet / (npix*npix)#1.0 / (npix*npix) # 160 | 161 | 162 | vis_true = torch.Tensor(np.concatenate([np.expand_dims(obs.data['vis'].real, 0), 163 | np.expand_dims(obs.data['vis'].imag, 0)], 0)).to(device=device) 164 | visamp_true = torch.Tensor(np.array(np.abs(obs.data['vis']))).to(device=device) 165 | cphase_true = torch.Tensor(np.array(obs.cphase['cphase'])).to(device=device) 166 | camp_true = torch.Tensor(np.array(obs.camp['camp'])).to(device=device) 167 | logcamp_true = torch.Tensor(np.array(obs.logcamp['camp'])).to(device=device) 168 | prior_im = torch.Tensor(np.array(prior.imvec.reshape((npix, npix)))).to(device=device) 169 | 170 | # optimize both scale and image generator 171 | lr = args.lr 172 | optimizer = optim.Adam(list(img_generator.parameters())+list(logscale_factor.parameters()), lr = lr) 173 | # optimizer = optim.Adam(img_generator.parameters(), lr = lr) 174 | 175 | 176 | n_epoch = args.n_epoch#30000#10000#100000#50000#100# 177 | loss_list = [] 178 | loss_prior_list = [] 179 | loss_cphase_list = [] 180 | loss_logca_list = [] 181 | loss_visamp_list = [] 182 | # loss_vis_list = [] 183 | logdet_list = [] 184 | loss_center_list = [] 185 | loss_tsv_list = [] 186 | loss_flux_list = [] 187 | loss_cross_entropy_list = [] 188 | loss_l1_list = [] 189 | 190 | n_batch = 32#8 191 | for k in range(n_epoch): 192 | if args.model_form == 'realnvp': 193 | z_sample = torch.randn(n_batch, npix*npix).to(device=device) 194 | elif args.model_form == 'glow': 195 | z_sample = [] 196 | for z in z_shapes: 197 | z_new = torch.randn(n_batch, *z) 198 | z_sample.append(z_new.to(device)) 199 | 200 | # generate image samples 201 | img_samp, logdet = img_generator.reverse(z_sample) 202 | img_samp = img_samp.reshape((-1, npix, npix)) 203 | 204 | # apply scale factor and sigmoid/softplus layer for positivity constraint 205 | logscale_factor_value = logscale_factor.forward() 206 | scale_factor = torch.exp(logscale_factor_value) 207 | img = torch.nn.Softplus()(img_samp) * scale_factor 208 | det_softplus = torch.sum(img_samp - torch.nn.Softplus()(img_samp), (1, 2)) 209 | # logdet = logdet + det_softplus 210 | det_scale = logscale_factor_value * npix * npix 211 | logdet = logdet + det_softplus + det_scale 212 | 213 | vis, visamp, cphase, logcamp = eht_obs_torch(img) 214 | loss_center = Loss_center_img(img) if imgcenter_weight>0 else 0 215 | loss_l1 = Loss_l1(img) if imgl1_weight>0 else 0 216 | loss_tsv = Loss_TSV(img) if 
imgtsv_weight>0 else 0 217 | loss_cross_entropy = Loss_cross_entropy(prior_im, img) if imgcrossentropy_weight>0 else 0 218 | loss_flux = Loss_flux_img(img) if imgflux_weight>0 else 0 219 | # loss_vis = Loss_vis_img(vis_true, vis) 220 | loss_visamp = Loss_logamp_img(visamp_true, visamp) if visamp_weight>0 else 0 221 | loss_cphase = Loss_cphase_img(cphase_true, cphase) if cphase_weight>0 else 0 222 | loss_camp = Loss_logca_img2(logcamp_true, logcamp) if camp_weight>0 else 0 223 | 224 | loss_data = camp_weight * loss_camp + cphase_weight * loss_cphase + visamp_weight * loss_visamp 225 | # loss_prior = imgflux_weight * loss_flux + imgl1_weight * loss_l1 + imgcenter_weight * loss_center 226 | loss_prior = imgcrossentropy_weight*loss_cross_entropy + imgflux_weight * loss_flux + \ 227 | imgtsv_weight * loss_tsv + imgcenter_weight * loss_center + imgl1_weight * loss_l1 228 | 229 | loss = torch.mean(loss_data) + torch.mean(loss_prior) - logdet_weight*torch.mean(logdet) 230 | 231 | optimizer.zero_grad() 232 | loss.backward() 233 | nn.utils.clip_grad_norm_(list(img_generator.parameters())+ list(logscale_factor.parameters()), args.clip) 234 | # nn.utils.clip_grad_norm_(img_generator.parameters(), args.clip) 235 | optimizer.step() 236 | 237 | loss_list.append(loss.detach().cpu().numpy()) 238 | loss_cphase_list.append(torch.mean(loss_cphase).detach().cpu().numpy() if cphase_weight>0 else 0) 239 | loss_logca_list.append(torch.mean(loss_camp).detach().cpu().numpy() if camp_weight>0 else 0) 240 | loss_visamp_list.append(torch.mean(loss_visamp).detach().cpu().numpy() if visamp_weight>0 else 0) 241 | loss_prior_list.append(torch.mean(loss_prior).detach().cpu().numpy()) 242 | logdet_list.append(-torch.mean(logdet).detach().cpu().numpy() / (npix*npix)) 243 | loss_flux_list.append(torch.mean(loss_flux).detach().cpu().numpy() if imgflux_weight>0 else 0) 244 | loss_tsv_list.append(torch.mean(loss_tsv).detach().cpu().numpy() if imgtsv_weight>0 else 0) 245 | loss_center_list.append(torch.mean(loss_center).detach().cpu().numpy() if imgcenter_weight>0 else 0) 246 | loss_cross_entropy_list.append(torch.mean(loss_cross_entropy).detach().cpu().numpy() if imgcrossentropy_weight>0 else 0) 247 | loss_l1_list.append(torch.mean(loss_l1).detach().cpu().numpy() if imgl1_weight>0 else 0) 248 | 249 | 250 | print(f"epoch: {k:}, loss: {loss_list[-1]:.5f}, loss cphase: {loss_cphase_list[-1]:.5f}, loss camp: {loss_logca_list[-1]:.5f}, loss visamp: {loss_visamp_list[-1]:.5f}, logdet: {logdet_list[-1]:.5f}") 251 | print(f"loss cross entropy: {loss_cross_entropy_list[-1]:.5f}, loss tsv: {loss_tsv_list[-1]:.5f}, loss l1: {loss_l1_list[-1]:.5f}, loss center: {loss_center_list[-1]:.5f}, loss flux: {loss_flux_list[-1]:.5f}") 252 | 253 | 254 | # print(f"epoch: {((n_epoch//n_blur)*k_blur+k):}, loss: {loss_list[-1]:.5f}, loss cphase: {loss_cphase_list[-1]:.5f}, loss camp: {loss_logca_list[-1]:.5f}, loss visamp: {loss_visamp_list[-1]:.5f}, loss prior: {loss_prior_list[-1]:.5f}") 255 | 256 | torch.save(img_generator.state_dict(), save_path+'/generativemodel_'+args.model_form+'_res{}flow{}logdet{}_closure_fluxcentermemtsv'.format(npix, n_flow, args.logdet)) 257 | torch.save(logscale_factor.state_dict(), save_path+'/generativescale_'+args.model_form+'_res{}flow{}logdet{}_closure_fluxcentermemtsv'.format(npix, n_flow, args.logdet)) 258 | np.save(save_path+'/generativeimage_'+args.model_form+'_res{}flow{}logdet{}_closure_fluxcentermemtsv.npy'.format(npix, n_flow, args.logdet), img.cpu().detach().numpy().squeeze()) 259 | 260 | loss_all = {} 261 | 
loss_all['total'] = np.array(loss_list) 262 | loss_all['cphase'] = np.array(loss_cphase_list) 263 | loss_all['logca'] = np.array(loss_logca_list) 264 | loss_all['visamp'] = np.array(loss_visamp_list) 265 | loss_all['logdet'] = np.array(logdet_list) 266 | loss_all['flux'] = np.array(loss_flux_list) 267 | loss_all['tsv'] = np.array(loss_tsv_list) 268 | loss_all['center'] = np.array(loss_center_list) 269 | loss_all['mem'] = np.array(loss_cross_entropy_list) 270 | loss_all['l1'] = np.array(loss_l1_list) 271 | np.save(save_path+'/loss_'+args.model_form+'_res{}flow{}logdet{}_closure_fluxcentermemtsv.npy'.format(npix, n_flow, args.logdet), loss_all) 272 | 273 | -------------------------------------------------------------------------------- /DPItorch/interferometry_helpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.init as init 8 | import torch.nn.functional as functional 9 | 10 | torch.set_default_dtype(torch.float32) 11 | import torch.optim as optim 12 | import pickle 13 | import math 14 | 15 | from torchkbnufft import KbNufft, AdjKbNufft 16 | from torchkbnufft.mri.dcomp_calc import calculate_radial_dcomp_pytorch 17 | from torchkbnufft.math import absolute 18 | 19 | # from loupe import models_vlbi, layers_vlbi # loupe package 20 | import ehtim as eh # eht imaging package 21 | 22 | from ehtim.observing.obs_helpers import * 23 | import ehtim.const_def as ehc 24 | 25 | from astropy.io import fits 26 | from pynfft.nfft import NFFT 27 | 28 | 29 | 30 | def torch_complex_mul(x, y): 31 | # complex multiplication in pytorch 32 | xy_real = x[:, :, 0:1] * y[0:1] - x[:, :, 1::] * y[1::] 33 | xy_imag = x[:, :, 0:1] * y[1::] + x[:, :, 1::] * y[0:1] 34 | return torch.cat([xy_real, xy_imag], -2) 35 | 36 | def torch_complex_matmul(x, F): 37 | Fx_real = torch.matmul(x, F[:, :, 0]) 38 | Fx_imag = torch.matmul(x, F[:, :, 1]) 39 | return torch.cat([Fx_real.unsqueeze(1), Fx_imag.unsqueeze(1)], -2) 40 | 41 | ############################################################################### 42 | # extract interferometry observation parameters from obs and simim files 43 | ############################################################################### 44 | def Obs_params_torch(obs, simim, snrcut=0.0, ttype='nfft'): 45 | ############################################################################### 46 | # generate the discrete Fourier transform matrices or nfft variables for complex visibilities 47 | ############################################################################### 48 | obs_data = obs.unpack(['u', 'v', 'vis', 'sigma']) 49 | uv = np.hstack((obs_data['u'].reshape(-1,1), obs_data['v'].reshape(-1,1))) 50 | vu = np.hstack((obs_data['v'].reshape(-1,1), obs_data['u'].reshape(-1,1))) 51 | 52 | fft_pad_factor = ehc.FFT_PAD_DEFAULT 53 | p_rad = ehc.GRIDDER_P_RAD_DEFAULT 54 | npad = int(fft_pad_factor * np.max((simim.xdim, simim.ydim))) 55 | nfft_info_vis = NFFTInfo(simim.xdim, simim.ydim, simim.psize, simim.pulse, npad, p_rad, uv) 56 | pulsefac_vis = nfft_info_vis.pulsefac 57 | 58 | vu_scaled = np.array(vu * simim.psize * 2 * np.pi) 59 | ktraj_vis = torch.tensor(vu_scaled.T).unsqueeze(0) 60 | pulsefac_vis_torch = torch.tensor(np.concatenate([np.expand_dims(pulsefac_vis.real, 0), 61 | np.expand_dims(pulsefac_vis.imag, 0)], 0)) 62 | if ttype == 'direct': 63 | dft_mat = ftmatrix(simim.psize, simim.xdim, simim.ydim, uv, pulse=simim.pulse) 64
| dft_mat = np.expand_dims(dft_mat.T, -1) 65 | dft_mat = np.concatenate([dft_mat.real, dft_mat.imag], -1) 66 | dft_mat = torch.tensor(dft_mat, dtype=torch.float32) 67 | else: 68 | dft_mat = None 69 | 70 | ############################################################################### 71 | # generate the discrete Fourier transform matrices for closure phases 72 | ############################################################################### 73 | # if snrcut > 0: 74 | # obs.add_cphase(count='min', snrcut=snrcut) 75 | # else: 76 | # obs.add_cphase(count='min') 77 | 78 | # if snrcut > 0: 79 | # obs.add_cphase(count='max', snrcut=snrcut) 80 | # else: 81 | # obs.add_cphase(count='max') 82 | 83 | if snrcut > 0: 84 | obs.add_cphase(count='min-cut0bl', uv_min=.1e9, snrcut=snrcut) 85 | else: 86 | obs.add_cphase(count='min-cut0bl', uv_min=.1e9) 87 | 88 | 89 | 90 | tc1 = obs.cphase['t1'] 91 | tc2 = obs.cphase['t2'] 92 | tc3 = obs.cphase['t3'] 93 | 94 | cphase_map = np.zeros((len(obs.cphase['time']), 3)) 95 | 96 | zero_symbol = 100000 97 | for k1 in range(cphase_map.shape[0]): 98 | for k2 in list(np.where(obs.data['time']==obs.cphase['time'][k1])[0]): 99 | if obs.data['t1'][k2] == obs.cphase['t1'][k1] and obs.data['t2'][k2] == obs.cphase['t2'][k1]: 100 | cphase_map[k1, 0] = k2 101 | if k2 == 0: 102 | cphase_map[k1, 0] = zero_symbol 103 | elif obs.data['t2'][k2] == obs.cphase['t1'][k1] and obs.data['t1'][k2] == obs.cphase['t2'][k1]: 104 | cphase_map[k1, 0] = -k2 105 | if k2 == 0: 106 | cphase_map[k1, 0] = -zero_symbol 107 | elif obs.data['t1'][k2] == obs.cphase['t2'][k1] and obs.data['t2'][k2] == obs.cphase['t3'][k1]: 108 | cphase_map[k1, 1] = k2 109 | if k2 == 0: 110 | cphase_map[k1, 1] = zero_symbol 111 | elif obs.data['t2'][k2] == obs.cphase['t2'][k1] and obs.data['t1'][k2] == obs.cphase['t3'][k1]: 112 | cphase_map[k1, 1] = -k2 113 | if k2 == 0: 114 | cphase_map[k1, 1] = -zero_symbol 115 | elif obs.data['t1'][k2] == obs.cphase['t3'][k1] and obs.data['t2'][k2] == obs.cphase['t1'][k1]: 116 | cphase_map[k1, 2] = k2 117 | if k2 == 0: 118 | cphase_map[k1, 2] = zero_symbol 119 | elif obs.data['t2'][k2] == obs.cphase['t3'][k1] and obs.data['t1'][k2] == obs.cphase['t1'][k1]: 120 | cphase_map[k1, 2] = -k2 121 | if k2 == 0: 122 | cphase_map[k1, 2] = -zero_symbol 123 | 124 | cphase_ind1 = np.abs(cphase_map[:, 0]).astype(np.int) 125 | cphase_ind1[cphase_ind1==zero_symbol] = 0 126 | cphase_ind2 = np.abs(cphase_map[:, 1]).astype(np.int) 127 | cphase_ind2[cphase_ind2==zero_symbol] = 0 128 | cphase_ind3 = np.abs(cphase_map[:, 2]).astype(np.int) 129 | cphase_ind3[cphase_ind3==zero_symbol] = 0 130 | cphase_sign1 = np.sign(cphase_map[:, 0]) 131 | cphase_sign2 = np.sign(cphase_map[:, 1]) 132 | cphase_sign3 = np.sign(cphase_map[:, 2]) 133 | 134 | 135 | cphase_ind_list = [torch.tensor(cphase_ind1), torch.tensor(cphase_ind2), torch.tensor(cphase_ind3)] 136 | cphase_sign_list = [torch.tensor(cphase_sign1), torch.tensor(cphase_sign2), torch.tensor(cphase_sign3)] 137 | 138 | ############################################################################### 139 | # generate the discrete Fourier transform matrices for closure amp 140 | ############################################################################### 141 | if snrcut > 0: 142 | obs.add_camp(debias=True, count='min', snrcut=snrcut) 143 | obs.add_logcamp(debias=True, count='min', snrcut=snrcut) 144 | else: 145 | obs.add_camp(debias=True, count='min') 146 | obs.add_logcamp(debias=True, count='min') 147 | 148 | # if snrcut > 0: 149 | # obs.add_camp(debias=True, 
count='max', snrcut=snrcut) 150 | # obs.add_logcamp(debias=True, count='max', snrcut=snrcut) 151 | # else: 152 | # obs.add_camp(debias=True, count='max') 153 | # obs.add_logcamp(debias=True, count='max') 154 | 155 | # obs.add_camp(count='max') 156 | tca1 = obs.camp['t1'] 157 | tca2 = obs.camp['t2'] 158 | tca3 = obs.camp['t3'] 159 | tca4 = obs.camp['t4'] 160 | 161 | camp_map = np.zeros((len(obs.camp['time']), 6)) 162 | 163 | zero_symbol = 10000 164 | for k1 in range(camp_map.shape[0]): 165 | for k2 in list(np.where(obs.data['time']==obs.camp['time'][k1])[0]): 166 | if obs.data['t1'][k2] == obs.camp['t1'][k1] and obs.data['t2'][k2] == obs.camp['t2'][k1]: 167 | camp_map[k1, 0] = k2 168 | if k2 == 0: 169 | camp_map[k1, 0] = zero_symbol 170 | elif obs.data['t2'][k2] == obs.camp['t1'][k1] and obs.data['t1'][k2] == obs.camp['t2'][k1]: 171 | camp_map[k1, 0] = -k2 172 | if k2 == 0: 173 | camp_map[k1, 0] = -zero_symbol 174 | elif obs.data['t1'][k2] == obs.camp['t1'][k1] and obs.data['t2'][k2] == obs.camp['t3'][k1]: 175 | camp_map[k1, 1] = k2 176 | if k2 == 0: 177 | camp_map[k1, 1] = zero_symbol 178 | elif obs.data['t2'][k2] == obs.camp['t1'][k1] and obs.data['t1'][k2] == obs.camp['t3'][k1]: 179 | camp_map[k1, 1] = -k2 180 | if k2 == 0: 181 | camp_map[k1, 1] = -zero_symbol 182 | elif obs.data['t1'][k2] == obs.camp['t1'][k1] and obs.data['t2'][k2] == obs.camp['t4'][k1]: 183 | camp_map[k1, 2] = k2 184 | if k2 == 0: 185 | camp_map[k1, 2] = zero_symbol 186 | elif obs.data['t2'][k2] == obs.camp['t1'][k1] and obs.data['t1'][k2] == obs.camp['t4'][k1]: 187 | camp_map[k1, 2] = -k2 188 | if k2 == 0: 189 | camp_map[k1, 2] = -zero_symbol 190 | elif obs.data['t1'][k2] == obs.camp['t2'][k1] and obs.data['t2'][k2] == obs.camp['t3'][k1]: 191 | camp_map[k1, 3] = k2 192 | if k2 == 0: 193 | camp_map[k1, 3] = zero_symbol 194 | elif obs.data['t2'][k2] == obs.camp['t2'][k1] and obs.data['t1'][k2] == obs.camp['t3'][k1]: 195 | camp_map[k1, 3] = -k2 196 | if k2 == 0: 197 | camp_map[k1, 3] = -zero_symbol 198 | elif obs.data['t1'][k2] == obs.camp['t2'][k1] and obs.data['t2'][k2] == obs.camp['t4'][k1]: 199 | camp_map[k1, 4] = k2 200 | if k2 == 0: 201 | camp_map[k1, 4] = zero_symbol 202 | elif obs.data['t2'][k2] == obs.camp['t2'][k1] and obs.data['t1'][k2] == obs.camp['t4'][k1]: 203 | camp_map[k1, 4] = -k2 204 | if k2 == 0: 205 | camp_map[k1, 4] = -zero_symbol 206 | elif obs.data['t1'][k2] == obs.camp['t3'][k1] and obs.data['t2'][k2] == obs.camp['t4'][k1]: 207 | camp_map[k1, 5] = k2 208 | if k2 == 0: 209 | camp_map[k1, 5] = zero_symbol 210 | elif obs.data['t2'][k2] == obs.camp['t3'][k1] and obs.data['t1'][k2] == obs.camp['t4'][k1]: 211 | camp_map[k1, 5] = -k2 212 | if k2 == 0: 213 | camp_map[k1, 5] = -zero_symbol 214 | 215 | camp_ind1 = np.abs(camp_map[:, 0]).astype(np.int) 216 | camp_ind1[camp_ind1==zero_symbol] = 0 217 | camp_ind2 = np.abs(camp_map[:, 5]).astype(np.int) 218 | camp_ind2[camp_ind2==zero_symbol] = 0 219 | camp_ind3 = np.abs(camp_map[:, 2]).astype(np.int) 220 | camp_ind3[camp_ind3==zero_symbol] = 0 221 | camp_ind4 = np.abs(camp_map[:, 3]).astype(np.int) 222 | camp_ind4[camp_ind4==zero_symbol] = 0 223 | # camp_sign1 = np.sign(camp_map[:, 0]) 224 | # camp_sign2 = np.sign(camp_map[:, 5]) 225 | # camp_sign3 = np.sign(camp_map[:, 2]) 226 | # camp_sign4 = np.sign(camp_map[:, 3]) 227 | 228 | camp_ind_list = [torch.tensor(camp_ind1), torch.tensor(camp_ind2), torch.tensor(camp_ind3), torch.tensor(camp_ind4)] 229 | # camp_sign_list = [torch.tensor(camp_sign1), torch.tensor(camp_sign2), torch.tensor(camp_sign3), 
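# The mapping loops above index each closure triangle / quadrangle back into rows of
# obs.data: a positive entry means the baseline appears as recorded, a negative one
# means the stations are swapped so that visibility must be conjugated. zero_symbol is
# only a placeholder that lets row index 0 carry a sign (np.sign(0) would otherwise
# lose the orientation); it is mapped back to index 0 right after each loop.
# camp_ind_list picks the four baselines (1-2, 3-4, 1-4, 2-3) of each quadrangle; no
# signs are needed for closure amplitudes, which are invariant to conjugation.
#
# Rough usage sketch (hedged -- variable names are illustrative, not part of this
# file; assumes `obs` and `simim` were already loaded with ehtim):
#   dft_mat, ktraj, pulsefac, cp_ind, cp_sign, ca_ind = Obs_params_torch(obs, simim,
#                                                                        ttype='nfft')
#   nufft_ob = KbNufft(im_size=(simim.ydim, simim.xdim))
#   forward = eht_observation_pytorch(simim.xdim, nufft_ob, dft_mat, ktraj, pulsefac,
#                                     cp_ind, cp_sign, ca_ind, device, ttype='nfft')
#   vis, visamp, cphase, logcamp = forward(img)   # img: (batch, npix, npix) tensor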
torch.tensor(camp_sign4)] 230 | return dft_mat, ktraj_vis, pulsefac_vis_torch, cphase_ind_list, cphase_sign_list, camp_ind_list 231 | 232 | ############################################################################### 233 | # Define the interferometry observation function 234 | ############################################################################### 235 | def eht_observation_pytorch(npix, nufft_ob, dft_mat, ktraj_vis, pulsefac_vis_torch, cphase_ind_list, cphase_sign_list, camp_ind_list, device, ttype='nfft'): 236 | eps = 1e-16 237 | nufft_ob = nufft_ob.to(device=device) 238 | ktraj_vis = ktraj_vis.to(device=device) 239 | pulsefac_vis_torch = pulsefac_vis_torch.to(device=device) 240 | 241 | cphase_ind1 = cphase_ind_list[0].to(device=device) 242 | cphase_ind2 = cphase_ind_list[1].to(device=device) 243 | cphase_ind3 = cphase_ind_list[2].to(device=device) 244 | 245 | cphase_sign1 = cphase_sign_list[0].to(device=device) 246 | cphase_sign2 = cphase_sign_list[1].to(device=device) 247 | cphase_sign3 = cphase_sign_list[2].to(device=device) 248 | 249 | camp_ind1 = camp_ind_list[0].to(device=device) 250 | camp_ind2 = camp_ind_list[1].to(device=device) 251 | camp_ind3 = camp_ind_list[2].to(device=device) 252 | camp_ind4 = camp_ind_list[3].to(device=device) 253 | 254 | if ttype == 'direct': 255 | F = dft_mat.to(device=device) 256 | 257 | def func(x): 258 | if ttype == 'direct': 259 | x = torch.reshape(x, (-1, npix*npix)).type(torch.float32).to(device=device) 260 | vis_torch = torch_complex_matmul(x, F) 261 | elif ttype == 'nfft': 262 | x = torch.reshape(x, (-1, npix, npix)).type(torch.float32).to(device=device).unsqueeze(1) 263 | x = torch.cat([x, torch.zeros_like(x)], 1) 264 | x = x.unsqueeze(0) 265 | 266 | kdata = nufft_ob(x, ktraj_vis) 267 | vis_torch = torch_complex_mul(kdata, pulsefac_vis_torch).squeeze(0) 268 | vis_amp = torch.sqrt((vis_torch[:, 0, :])**2 + (vis_torch[:, 1, :])**2 + eps) 269 | 270 | vis1_torch = torch.index_select(vis_torch, -1, cphase_ind1) 271 | vis2_torch = torch.index_select(vis_torch, -1, cphase_ind2) 272 | vis3_torch = torch.index_select(vis_torch, -1, cphase_ind3) 273 | 274 | ang1 = torch.atan2(vis1_torch[:, 1, :], vis1_torch[:, 0, :]) 275 | ang2 = torch.atan2(vis2_torch[:, 1, :], vis2_torch[:, 0, :]) 276 | ang3 = torch.atan2(vis3_torch[:, 1, :], vis3_torch[:, 0, :]) 277 | cphase = (cphase_sign1*ang1 + cphase_sign2*ang2 + cphase_sign3*ang3) * 180 / np.pi 278 | 279 | 280 | vis12_torch = torch.index_select(vis_torch, -1, camp_ind1) 281 | vis12_amp = torch.sqrt((vis12_torch[:, 0, :])**2 + (vis12_torch[:, 1, :])**2 + eps) 282 | vis34_torch = torch.index_select(vis_torch, -1, camp_ind2) 283 | vis34_amp = torch.sqrt((vis34_torch[:, 0, :])**2 + (vis34_torch[:, 1, :])**2 + eps) 284 | vis14_torch = torch.index_select(vis_torch, -1, camp_ind3) 285 | vis14_amp = torch.sqrt((vis14_torch[:, 0, :])**2 + (vis14_torch[:, 1, :])**2 + eps) 286 | vis23_torch = torch.index_select(vis_torch, -1, camp_ind4) 287 | vis23_amp = torch.sqrt((vis23_torch[:, 0, :])**2 + (vis23_torch[:, 1, :])**2 + eps) 288 | 289 | logcamp = torch.log(vis12_amp) + torch.log(vis34_amp) - torch.log(vis14_amp) - torch.log(vis23_amp) 290 | 291 | return vis_torch, vis_amp, cphase, logcamp 292 | return func 293 | 294 | 295 | ############################################################################### 296 | # Define the loss functions for interferometry imaging 297 | ############################################################################### 298 | 299 | def Loss_angle_diff(sigma, device): 300 | # closure phase 
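# The data-fit losses below (closure phase, log closure amplitude, visibility,
# amplitude) each return one chi-squared-style value per image in the batch by
# averaging over the measurement axis. For closure phases the residual is measured
# on the circle: since 1 - cos(x) ~ x^2 / 2 for small x, the term
#   2 * mean((1 - cos(theta_true - theta_pred)) / sigma_rad^2)
# reduces to the familiar mean((dtheta / sigma)^2) while staying smooth across the
# +/-180 degree wrap; sigma is supplied in degrees and converted to radians inline.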
difference loss 301 | sigma = torch.Tensor(sigma).type(torch.float32).to(device=device) 302 | def func(y_true, y_pred): 303 | angle_true = y_true * np.pi / 180 304 | angle_pred = y_pred * np.pi / 180 305 | # return K.mean(1 - K.cos(angle_true - angle_pred)) 306 | return 2.0*torch.mean((1 - torch.cos(angle_true - angle_pred))/(sigma*np.pi/180)**2, 1) 307 | return func 308 | 309 | 310 | def Loss_logca_diff(sigma, device): 311 | # log closure amp difference loss (computed based on closure amp and its std) 312 | sigma = torch.Tensor(sigma).type(torch.float32).to(device=device) 313 | def func(y_true, y_pred): 314 | return torch.mean((y_true/sigma)**2 * (torch.log(y_true)-torch.log(y_pred))**2, 1) 315 | return func 316 | 317 | 318 | def Loss_logca_diff2(sigma, device): 319 | # log closure amp difference loss (computed based on log closure amp and its std) 320 | sigma = torch.Tensor(sigma).type(torch.float32).to(device=device) 321 | def func(y_true, y_pred): 322 | return torch.mean((y_true - y_pred)**2/(sigma)**2, 1) 323 | return func 324 | 325 | 326 | def Loss_vis_diff(sigma, device): 327 | # visibility difference loss 328 | sigma = torch.Tensor(sigma).type(torch.float32).to(device=device) 329 | def func(y_true, y_pred): 330 | return torch.mean(((y_true[0]-y_pred[:, 0])**2+(y_true[1]-y_pred[:, 1])**2)/(sigma)**2, 1) 331 | return func 332 | 333 | 334 | def Loss_logamp_diff(sigma, device): 335 | # log amp difference loss 336 | sigma = torch.Tensor(sigma).type(torch.float32).to(device=device) 337 | def func(y_true, y_pred): 338 | return torch.mean((y_true)**2/(sigma)**2 * (torch.log(y_true)-torch.log(y_pred))**2, 1) 339 | return func 340 | 341 | 342 | def Loss_visamp_diff(sigma, device): 343 | # log amp difference loss 344 | sigma = torch.Tensor(sigma).type(torch.float32).to(device=device) 345 | def func(y_true, y_pred): 346 | return torch.mean((y_true-y_pred)**2/(sigma)**2, 1) 347 | return func 348 | 349 | def Loss_l1(y_pred): 350 | # image prior - sparsity loss 351 | return torch.mean(torch.abs(y_pred), (-1, -2)) 352 | 353 | 354 | def Loss_TSV(y_pred): 355 | # image prior - total squared variation loss 356 | return torch.mean((y_pred[:, 1::, :] - y_pred[:, 0:-1, :])**2, (-1, -2)) + torch.mean((y_pred[:, :, 1::] - y_pred[:, :, 0:-1])**2, (-1, -2)) 357 | 358 | def Loss_TV(y_pred): 359 | # image prior - total variation loss 360 | return torch.mean(torch.abs(y_pred[:, 1::, :] - y_pred[:, 0:-1, :]), (-1, -2)) + torch.mean(torch.abs(y_pred[:, :, 1::] - y_pred[:, :, 0:-1]), (-1, -2)) 361 | 362 | 363 | def Loss_flux(flux): 364 | # image prior - flux loss 365 | def func(y_pred): 366 | return (torch.sum(y_pred, (-1, -2)) - flux)**2 367 | return func 368 | 369 | def Loss_center(device, center=15.5, dim=32): 370 | # image prior - centering loss 371 | X = np.concatenate([np.arange(dim).reshape((1, dim))] * dim, 0) 372 | Y = np.concatenate([np.arange(dim).reshape((dim, 1))] * dim, 1) 373 | X = torch.Tensor(X).type(torch.float32).to(device = device) 374 | Y = torch.Tensor(Y).type(torch.float32).to(device = device) 375 | def func(y_pred): 376 | y_pred_flux = torch.mean(y_pred, (-1, -2)) 377 | xc_pred_norm = torch.mean(y_pred * X, (-1, -2)) / y_pred_flux 378 | yc_pred_norm = torch.mean(y_pred * Y, (-1, -2)) / y_pred_flux 379 | 380 | loss = 0.5 * ((xc_pred_norm-center)**2 + (yc_pred_norm-center)**2) 381 | return loss 382 | return func 383 | 384 | def Loss_cross_entropy(y_true, y_pred): 385 | # image prior - cross entropy loss (measuring difference between recon and reference images) 386 | loss = torch.mean(y_pred 
* (torch.log(y_pred + 1e-12) - torch.log(y_true + 1e-12)), (-1, -2)) 387 | return loss -------------------------------------------------------------------------------- /DPItorch/generative_model/glow_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from math import log, pi, exp 5 | import numpy as np 6 | from scipy import linalg as la 7 | 8 | logabs = lambda x: torch.log(torch.abs(x)) 9 | 10 | 11 | # class ActNorm(nn.Module): 12 | # def __init__(self, in_channel, logdet=True): 13 | # super().__init__() 14 | 15 | # self.loc = nn.Parameter(torch.zeros(1, in_channel, 1, 1)) 16 | # self.scale = nn.Parameter(torch.ones(1, in_channel, 1, 1)) 17 | 18 | # self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) 19 | # self.logdet = logdet 20 | 21 | # def initialize(self, input, inv_init=False): 22 | # with torch.no_grad(): 23 | # flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 24 | # mean = ( 25 | # flatten.mean(1) 26 | # .unsqueeze(1) 27 | # .unsqueeze(2) 28 | # .unsqueeze(3) 29 | # .permute(1, 0, 2, 3) 30 | # ) 31 | # std = ( 32 | # flatten.std(1) 33 | # .unsqueeze(1) 34 | # .unsqueeze(2) 35 | # .unsqueeze(3) 36 | # .permute(1, 0, 2, 3) 37 | # ) 38 | # if inv_init: 39 | # self.loc.data.copy_(torch.zeros_like(mean)) 40 | # self.scale.data.copy_(torch.ones_like(std)) 41 | # else: 42 | # self.loc.data.copy_(-mean) 43 | # self.scale.data.copy_(1 / (std + 1e-6)) 44 | 45 | # def forward(self, input): 46 | # _, _, height, width = input.shape 47 | 48 | # if self.initialized.item() == 0: 49 | # self.initialize(input) 50 | # self.initialized.fill_(1) 51 | 52 | # log_abs = logabs(self.scale) 53 | 54 | # logdet = height * width * torch.sum(log_abs) 55 | 56 | # if self.logdet: 57 | # return self.scale * (input + self.loc), logdet 58 | 59 | # else: 60 | # return self.scale * (input + self.loc) 61 | 62 | # def reverse(self, output): 63 | # _, _, height, width = output.shape 64 | 65 | # if self.initialized.item() == 0: 66 | # self.initialize(output, inv_init=True) 67 | # self.initialized.fill_(1) 68 | 69 | # log_abs = logabs(self.scale) 70 | 71 | # logdet = -height * width * torch.sum(log_abs) 72 | 73 | # if self.logdet: 74 | # return output / self.scale - self.loc, logdet 75 | 76 | # else: 77 | # return output / self.scale - self.loc 78 | 79 | 80 | class ActNorm(nn.Module): 81 | def __init__(self, in_channel, logdet=True): 82 | super().__init__() 83 | 84 | self.loc = nn.Parameter(torch.zeros(1, in_channel, 1, 1)) 85 | self.scale_inv = nn.Parameter(torch.ones(1, in_channel, 1, 1)) 86 | 87 | self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) 88 | self.logdet = logdet 89 | 90 | def initialize(self, input, inv_init=False): 91 | with torch.no_grad(): 92 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 93 | mean = ( 94 | flatten.mean(1) 95 | .unsqueeze(1) 96 | .unsqueeze(2) 97 | .unsqueeze(3) 98 | .permute(1, 0, 2, 3) 99 | ) 100 | std = ( 101 | flatten.std(1) 102 | .unsqueeze(1) 103 | .unsqueeze(2) 104 | .unsqueeze(3) 105 | .permute(1, 0, 2, 3) 106 | ) 107 | if inv_init: 108 | self.loc.data.copy_(torch.zeros_like(mean)) 109 | self.scale_inv.data.copy_(torch.ones_like(std)) 110 | else: 111 | self.loc.data.copy_(-mean) 112 | self.scale_inv.data.copy_((std + 1e-6)) 113 | 114 | def forward(self, input): 115 | _, _, height, width = input.shape 116 | 117 | if self.initialized.item() == 0: 118 | 
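# Data-dependent initialization (as in Glow): the first batch that passes through
# sets loc to the negative per-channel mean and scale_inv to the per-channel std,
# so activations start roughly zero-mean / unit-variance. This variant stores the
# *inverse* scale, hence the division by scale_inv below and
#   logdet = height * width * sum(-log|scale_inv|).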
self.initialize(input) 119 | self.initialized.fill_(1) 120 | 121 | log_abs = -logabs(self.scale_inv) 122 | 123 | logdet = height * width * torch.sum(log_abs) 124 | 125 | if self.logdet: 126 | return (1.0 / self.scale_inv) * (input + self.loc), logdet 127 | 128 | else: 129 | return (1.0 / self.scale_inv) * (input + self.loc) 130 | 131 | def reverse(self, output): 132 | _, _, height, width = output.shape 133 | 134 | if self.initialized.item() == 0: 135 | self.initialize(output, inv_init=True) 136 | self.initialized.fill_(1) 137 | 138 | log_abs = -logabs(self.scale_inv) 139 | 140 | logdet = -height * width * torch.sum(log_abs) 141 | 142 | if self.logdet: 143 | return output * self.scale_inv - self.loc, logdet 144 | 145 | else: 146 | return output * self.scale_inv - self.loc 147 | 148 | class InvConv2d(nn.Module): 149 | def __init__(self, in_channel): 150 | super().__init__() 151 | 152 | weight = torch.randn(in_channel, in_channel) 153 | q, _ = torch.qr(weight) 154 | weight = q.unsqueeze(2).unsqueeze(3) 155 | self.weight = nn.Parameter(weight) 156 | 157 | def forward(self, input): 158 | _, _, height, width = input.shape 159 | 160 | out = F.conv2d(input, self.weight) 161 | logdet = ( 162 | height * width * torch.slogdet(self.weight.squeeze().double())[1].float() 163 | ) 164 | 165 | return out, logdet 166 | 167 | def reverse(self, output): 168 | _, _, height, width = output.shape 169 | 170 | in_recover = F.conv2d(output, self.weight.squeeze().inverse().unsqueeze(2).unsqueeze(3)) 171 | logdet = ( 172 | -height * width * torch.slogdet(self.weight.squeeze().double())[1].float() 173 | ) 174 | 175 | return in_recover, logdet 176 | 177 | 178 | class InvConv2dLU(nn.Module): 179 | def __init__(self, in_channel): 180 | super().__init__() 181 | 182 | weight = np.random.randn(in_channel, in_channel) 183 | q, _ = la.qr(weight) 184 | w_p, w_l, w_u = la.lu(q.astype(np.float32)) 185 | w_s = np.diag(w_u) 186 | w_u = np.triu(w_u, 1) 187 | u_mask = np.triu(np.ones_like(w_u), 1) 188 | l_mask = u_mask.T 189 | 190 | w_p = torch.from_numpy(w_p) 191 | w_l = torch.from_numpy(w_l) 192 | w_s = torch.from_numpy(w_s) 193 | w_u = torch.from_numpy(w_u) 194 | 195 | self.register_buffer("w_p", w_p) 196 | self.register_buffer("u_mask", torch.from_numpy(u_mask)) 197 | self.register_buffer("l_mask", torch.from_numpy(l_mask)) 198 | self.register_buffer("s_sign", torch.sign(w_s)) 199 | self.register_buffer("l_eye", torch.eye(l_mask.shape[0])) 200 | self.w_l = nn.Parameter(w_l) 201 | self.w_s = nn.Parameter(logabs(w_s)) 202 | self.w_u = nn.Parameter(w_u) 203 | 204 | def forward(self, input): 205 | _, _, height, width = input.shape 206 | 207 | weight = self.calc_weight() 208 | 209 | out = F.conv2d(input, weight) 210 | logdet = height * width * torch.sum(self.w_s) 211 | 212 | return out, logdet 213 | 214 | def calc_weight(self): 215 | weight = ( 216 | self.w_p 217 | @ (self.w_l * self.l_mask + self.l_eye) 218 | @ ((self.w_u * self.u_mask) + torch.diag(self.s_sign * torch.exp(self.w_s))) 219 | ) 220 | 221 | return weight.unsqueeze(2).unsqueeze(3) 222 | 223 | def reverse(self, output): 224 | _, _, height, width = output.shape 225 | 226 | weight = self.calc_weight() 227 | 228 | in_recover = F.conv2d(output, weight.squeeze().inverse().unsqueeze(2).unsqueeze(3)) 229 | logdet = -height * width * torch.sum(self.w_s) 230 | 231 | return in_recover, logdet 232 | 233 | 234 | class ZeroConv2d(nn.Module): 235 | def __init__(self, in_channel, out_channel, padding=1): 236 | super().__init__() 237 | 238 | self.conv = nn.Conv2d(in_channel, 
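# 3x3 convolution initialized to all zeros, with its output multiplied by
# exp(3 * scale) in forward() where the per-channel `scale` also starts at zero.
# With zeroed weights the layer initially outputs zeros, so each coupling network
# starts as the identity map and each prior head starts at a standard normal --
# the Glow trick for stable early training. Note the `padding` constructor argument
# is unused: forward() pads manually with F.pad(..., value=1) before the padding=0
# convolution.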
out_channel, 3, padding=0) 239 | self.conv.weight.data.zero_() 240 | self.conv.bias.data.zero_() 241 | self.scale = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 242 | 243 | def forward(self, input): 244 | out = F.pad(input, [1, 1, 1, 1], value=1) 245 | out = self.conv(out) 246 | out = out * torch.exp(self.scale * 3) 247 | 248 | return out 249 | 250 | 251 | class AffineCoupling(nn.Module): 252 | def __init__(self, in_channel, filter_size=512, affine=True): 253 | super().__init__() 254 | 255 | self.affine = affine 256 | 257 | # self.net = nn.Sequential( 258 | # nn.Conv2d(in_channel // 2, filter_size, 3, padding=1), 259 | # nn.ReLU(inplace=True), 260 | # nn.Conv2d(filter_size, filter_size, 1), 261 | # nn.ReLU(inplace=True), 262 | # ZeroConv2d(filter_size, in_channel if self.affine else in_channel // 2), 263 | # ) 264 | 265 | # self.net[0].weight.data.normal_(0, 0.05) 266 | # self.net[0].bias.data.zero_() 267 | 268 | # self.net[2].weight.data.normal_(0, 0.05) 269 | # self.net[2].bias.data.zero_() 270 | 271 | self.net = nn.Sequential( 272 | nn.Conv2d(in_channel // 2, filter_size, 3, padding=1), 273 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 274 | nn.BatchNorm2d(filter_size), 275 | nn.Conv2d(filter_size, filter_size, 1), 276 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 277 | nn.BatchNorm2d(filter_size), 278 | ZeroConv2d(filter_size, in_channel if self.affine else in_channel // 2), 279 | ) 280 | 281 | self.net[0].weight.data.normal_(0, 0.05) 282 | self.net[0].bias.data.zero_() 283 | 284 | self.net[3].weight.data.normal_(0, 0.05) 285 | self.net[3].bias.data.zero_() 286 | 287 | def forward(self, input): 288 | in_a, in_b = input.chunk(2, 1) 289 | 290 | if self.affine: 291 | # Case 1 292 | log_s0, t = self.net(in_a).chunk(2, 1) 293 | log_s = torch.tanh(log_s0) 294 | s = torch.exp(log_s) 295 | 296 | # # Case 2 297 | # log_s0, t0 = self.net(in_a).chunk(2, 1) 298 | # log_s = torch.tanh(log_s0) 299 | # t = torch.tanh(t0) 300 | # s = torch.exp(log_s) 301 | 302 | # # Case 3 303 | # log_s, t = self.net(in_a).chunk(2, 1) 304 | # s = torch.sigmoid(log_s + 2) 305 | # out_a = s * in_a + t 306 | out_b = (in_b + t) * s 307 | 308 | # logdet = torch.sum(torch.log(s).view(input.shape[0], -1), 1) 309 | logdet = torch.sum(log_s.view(input.shape[0], -1), 1) 310 | 311 | else: 312 | net_out = self.net(in_a) 313 | out_b = in_b + net_out 314 | logdet = None 315 | 316 | return torch.cat([in_a, out_b], 1), logdet 317 | 318 | def reverse(self, output): 319 | out_a, out_b = output.chunk(2, 1) 320 | 321 | if self.affine: 322 | # Case 1 323 | log_s0, t = self.net(out_a).chunk(2, 1) 324 | log_s = torch.tanh(log_s0) 325 | s = torch.exp(log_s) 326 | 327 | # # Case 2 328 | # log_s0, t0 = self.net(out_a).chunk(2, 1) 329 | # log_s = torch.tanh(log_s0) 330 | # t = torch.tanh(t0) 331 | # s = torch.exp(log_s) 332 | 333 | # # Case 3 334 | # log_s, t = self.net(out_a).chunk(2, 1) 335 | # s = torch.sigmoid(log_s + 2) 336 | # in_a = (out_a - t) / s 337 | in_b = out_b / s - t 338 | 339 | # logdet = -torch.sum(torch.log(s).view(output.shape[0], -1), 1) 340 | logdet = -torch.sum(log_s.view(output.shape[0], -1), 1) 341 | 342 | else: 343 | net_out = self.net(out_a) 344 | in_b = out_b - net_out 345 | 346 | logdet = None 347 | 348 | return torch.cat([out_a, in_b], 1), logdet 349 | 350 | 351 | class Flow(nn.Module): 352 | def __init__(self, in_channel, affine=True, conv_lu=True): 353 | super().__init__() 354 | 355 | self.actnorm = ActNorm(in_channel) 356 | 357 | if conv_lu: 358 | self.invconv = InvConv2dLU(in_channel) 359 | 360 | 
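# One Glow flow step = ActNorm -> invertible 1x1 convolution -> affine coupling.
# The LU-parameterized InvConv2dLU keeps the permutation fixed and learns L, U and
# log|s| (w_s), so its log-determinant is simply height * width * sum(w_s) instead
# of a full slogdet. The coupling leaves half of the channels untouched and maps the
# other half as out_b = (in_b + t) * exp(tanh(raw_log_s)), contributing sum(log_s)
# to the log-determinant; reverse() applies the three inverses in the opposite order.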
else: 361 | self.invconv = InvConv2d(in_channel) 362 | 363 | self.coupling = AffineCoupling(in_channel, affine=affine) 364 | 365 | def forward(self, input): 366 | out, logdet = self.actnorm(input) 367 | out, det1 = self.invconv(out) 368 | out, det2 = self.coupling(out) 369 | 370 | logdet = logdet + det1 371 | if det2 is not None: 372 | logdet = logdet + det2 373 | 374 | return out, logdet 375 | 376 | def reverse(self, output): 377 | input, logdet = self.coupling.reverse(output) 378 | input, det1 = self.invconv.reverse(input) 379 | input, det2 = self.actnorm.reverse(input) 380 | 381 | logdet = logdet + det1 382 | if det2 is not None: 383 | logdet = logdet + det2 384 | 385 | return input, logdet 386 | 387 | 388 | 389 | def gaussian_log_p(x, mean, log_sd): 390 | return -0.5 * log(2 * pi) - log_sd - 0.5 * (x - mean) ** 2 / torch.exp(2 * log_sd) 391 | 392 | def gaussian_log_p2(eps, mean, log_sd): 393 | return -0.5 * log(2 * pi) - log_sd - 0.5 * eps ** 2 394 | 395 | def gaussian_sample(eps, mean, log_sd): 396 | return mean + torch.exp(log_sd) * eps 397 | 398 | def gaussian_sample2(x, mean, log_sd): 399 | return (x - mean) * torch.exp(-log_sd) 400 | 401 | 402 | class Block(nn.Module): 403 | def __init__(self, in_channel, n_flow, split=True, affine=True, conv_lu=True): 404 | super().__init__() 405 | 406 | squeeze_dim = in_channel * 4 407 | 408 | self.flows = nn.ModuleList() 409 | for i in range(n_flow): 410 | self.flows.append(Flow(squeeze_dim, affine=affine, conv_lu=conv_lu)) 411 | 412 | self.split = split 413 | 414 | if split: 415 | self.prior = ZeroConv2d(in_channel * 2, in_channel * 4) 416 | 417 | else: 418 | self.prior = ZeroConv2d(in_channel * 4, in_channel * 8) 419 | 420 | def forward(self, input): 421 | b_size, n_channel, height, width = input.shape 422 | squeezed = input.view(b_size, n_channel, height // 2, 2, width // 2, 2) 423 | squeezed = squeezed.permute(0, 1, 3, 5, 2, 4) 424 | out = squeezed.contiguous().view(b_size, n_channel * 4, height // 2, width // 2) 425 | 426 | logdet = 0 427 | 428 | for flow in self.flows: 429 | out, det = flow(out) 430 | logdet = logdet + det 431 | 432 | if self.split: 433 | out, z_new = out.chunk(2, 1) 434 | mean, log_sd = self.prior(out).chunk(2, 1) 435 | # log_p = gaussian_log_p(z_new, mean, log_sd) 436 | # log_p = log_p.view(b_size, -1).sum(1) 437 | z_eps = gaussian_sample2(z_new, mean, log_sd) 438 | logdet = logdet - log_sd.view(b_size, -1).sum(1) 439 | log_p = -0.5 * log(2 * pi) - 0.5 * z_eps ** 2 440 | log_p = log_p.view(b_size, -1).sum(1) 441 | 442 | else: 443 | zero = torch.zeros_like(out) 444 | mean, log_sd = self.prior(zero).chunk(2, 1) 445 | # log_p = gaussian_log_p(out, mean, log_sd) 446 | # log_p = log_p.view(b_size, -1).sum(1) 447 | z_new = out 448 | z_eps = gaussian_sample2(z_new, mean, log_sd) 449 | logdet = logdet - log_sd.view(b_size, -1).sum(1) 450 | log_p = -0.5 * log(2 * pi) - 0.5 * z_eps ** 2 451 | log_p = log_p.view(b_size, -1).sum(1) 452 | 453 | 454 | return out, logdet, log_p, z_eps 455 | 456 | def reverse(self, output, eps=None): 457 | input = output 458 | 459 | logdet = 0 460 | 461 | if self.split: 462 | mean, log_sd = self.prior(input).chunk(2, 1) 463 | z = gaussian_sample(eps, mean, log_sd) 464 | # log_p = gaussian_log_p2(eps, mean, log_sd) 465 | logdet = logdet + log_sd.view(input.shape[0], -1).sum(1) 466 | input = torch.cat([output, z], 1) 467 | 468 | else: 469 | zero = torch.zeros_like(input) 470 | # zero = F.pad(zero, [1, 1, 1, 1], value=1) 471 | mean, log_sd = self.prior(zero).chunk(2, 1) 472 | z = gaussian_sample(eps, 
mean, log_sd) 473 | # log_p = gaussian_log_p2(eps, mean, log_sd) 474 | logdet = logdet + log_sd.view(input.shape[0], -1).sum(1) 475 | input = z 476 | 477 | 478 | 479 | for flow in self.flows[::-1]: 480 | input, det = flow.reverse(input) 481 | logdet = logdet + det 482 | 483 | b_size, n_channel, height, width = input.shape 484 | 485 | 486 | unsqueezed = input.view(b_size, n_channel // 4, 2, 2, height, width) 487 | unsqueezed = unsqueezed.permute(0, 1, 4, 2, 5, 3) 488 | unsqueezed = unsqueezed.contiguous().view( 489 | b_size, n_channel // 4, height * 2, width * 2 490 | ) 491 | 492 | return unsqueezed, logdet 493 | 494 | 495 | class Glow(nn.Module): 496 | def __init__( 497 | self, in_channel, n_flow, n_block, affine=True, conv_lu=True 498 | ): 499 | super().__init__() 500 | 501 | self.blocks = nn.ModuleList() 502 | n_channel = in_channel 503 | for i in range(n_block - 1): 504 | self.blocks.append(Block(n_channel, n_flow, affine=affine, conv_lu=conv_lu)) 505 | n_channel *= 2 506 | self.blocks.append(Block(n_channel, n_flow, split=False, affine=affine)) 507 | 508 | def forward(self, input): 509 | log_p_sum = 0 510 | logdet = 0 511 | out = input 512 | z_outs = [] 513 | 514 | for block in self.blocks: 515 | out, det, log_p, z_new = block(out) 516 | z_outs.append(z_new) 517 | # z_outs.append(z_new.view(input.shape[0], -1)) 518 | logdet = logdet + det 519 | 520 | if log_p is not None: 521 | log_p_sum = log_p_sum + log_p 522 | # z_eps = torch.cat(z_outs, -1) 523 | 524 | return log_p_sum, logdet, z_outs 525 | 526 | def reverse(self, z_list): 527 | log_p_sum = 0 528 | logdet = 0 529 | 530 | for i, block in enumerate(self.blocks[::-1]): 531 | if i == 0: 532 | input, det = block.reverse(z_list[-1], z_list[-1]) 533 | else: 534 | input, det = block.reverse(input, z_list[-(i + 1)]) 535 | 536 | logdet = logdet + det 537 | 538 | 539 | return input, logdet 540 | 541 | 542 | def calc_z_shapes(n_channel, input_size, n_flow, n_block): 543 | z_shapes = [] 544 | 545 | for i in range(n_block - 1): 546 | input_size //= 2 547 | n_channel *= 2 548 | 549 | z_shapes.append((n_channel, input_size, input_size)) 550 | 551 | input_size //= 2 552 | z_shapes.append((n_channel * 4, input_size, input_size)) 553 | 554 | return z_shapes -------------------------------------------------------------------------------- /DPItorch/generative_model/cond_glow_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from math import log, pi, exp 5 | import numpy as np 6 | from scipy import linalg as la 7 | 8 | logabs = lambda x: torch.log(torch.abs(x)) 9 | 10 | 11 | # class ActNorm(nn.Module): 12 | # def __init__(self, in_channel, logdet=True): 13 | # super().__init__() 14 | 15 | # self.loc = nn.Parameter(torch.zeros(1, in_channel, 1, 1)) 16 | # self.scale = nn.Parameter(torch.ones(1, in_channel, 1, 1)) 17 | 18 | # self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) 19 | # self.logdet = logdet 20 | 21 | # def initialize(self, input, inv_init=False): 22 | # with torch.no_grad(): 23 | # flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 24 | # mean = ( 25 | # flatten.mean(1) 26 | # .unsqueeze(1) 27 | # .unsqueeze(2) 28 | # .unsqueeze(3) 29 | # .permute(1, 0, 2, 3) 30 | # ) 31 | # std = ( 32 | # flatten.std(1) 33 | # .unsqueeze(1) 34 | # .unsqueeze(2) 35 | # .unsqueeze(3) 36 | # .permute(1, 0, 2, 3) 37 | # ) 38 | # if inv_init: 39 | # self.loc.data.copy_(torch.zeros_like(mean)) 
40 | # self.scale.data.copy_(torch.ones_like(std)) 41 | # else: 42 | # self.loc.data.copy_(-mean) 43 | # self.scale.data.copy_(1 / (std + 1e-6)) 44 | 45 | # def forward(self, input): 46 | # _, _, height, width = input.shape 47 | 48 | # if self.initialized.item() == 0: 49 | # self.initialize(input) 50 | # self.initialized.fill_(1) 51 | 52 | # log_abs = logabs(self.scale) 53 | 54 | # logdet = height * width * torch.sum(log_abs) 55 | 56 | # if self.logdet: 57 | # return self.scale * (input + self.loc), logdet 58 | 59 | # else: 60 | # return self.scale * (input + self.loc) 61 | 62 | # def reverse(self, output): 63 | # _, _, height, width = output.shape 64 | 65 | # if self.initialized.item() == 0: 66 | # self.initialize(output, inv_init=True) 67 | # self.initialized.fill_(1) 68 | 69 | # log_abs = logabs(self.scale) 70 | 71 | # logdet = -height * width * torch.sum(log_abs) 72 | 73 | # if self.logdet: 74 | # return output / self.scale - self.loc, logdet 75 | 76 | # else: 77 | # return output / self.scale - self.loc 78 | 79 | 80 | class ActNorm(nn.Module): 81 | def __init__(self, in_channel, logdet=True): 82 | super().__init__() 83 | 84 | self.loc = nn.Parameter(torch.zeros(1, in_channel, 1, 1)) 85 | self.scale_inv = nn.Parameter(torch.ones(1, in_channel, 1, 1)) 86 | 87 | self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) 88 | self.logdet = logdet 89 | 90 | def initialize(self, input, inv_init=False): 91 | with torch.no_grad(): 92 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 93 | mean = ( 94 | flatten.mean(1) 95 | .unsqueeze(1) 96 | .unsqueeze(2) 97 | .unsqueeze(3) 98 | .permute(1, 0, 2, 3) 99 | ) 100 | std = ( 101 | flatten.std(1) 102 | .unsqueeze(1) 103 | .unsqueeze(2) 104 | .unsqueeze(3) 105 | .permute(1, 0, 2, 3) 106 | ) 107 | if inv_init: 108 | self.loc.data.copy_(torch.zeros_like(mean)) 109 | self.scale_inv.data.copy_(torch.ones_like(std)) 110 | else: 111 | self.loc.data.copy_(-mean) 112 | self.scale_inv.data.copy_((std + 1e-6)) 113 | 114 | def forward(self, input): 115 | _, _, height, width = input.shape 116 | 117 | if self.initialized.item() == 0: 118 | self.initialize(input) 119 | self.initialized.fill_(1) 120 | 121 | log_abs = -logabs(self.scale_inv) 122 | 123 | logdet = height * width * torch.sum(log_abs) 124 | 125 | if self.logdet: 126 | return (1.0 / self.scale_inv) * (input + self.loc), logdet 127 | 128 | else: 129 | return (1.0 / self.scale_inv) * (input + self.loc) 130 | 131 | def reverse(self, output): 132 | _, _, height, width = output.shape 133 | 134 | if self.initialized.item() == 0: 135 | self.initialize(output, inv_init=True) 136 | self.initialized.fill_(1) 137 | 138 | log_abs = -logabs(self.scale_inv) 139 | 140 | logdet = -height * width * torch.sum(log_abs) 141 | 142 | if self.logdet: 143 | return output * self.scale_inv - self.loc, logdet 144 | 145 | else: 146 | return output * self.scale_inv - self.loc 147 | 148 | class InvConv2d(nn.Module): 149 | def __init__(self, in_channel): 150 | super().__init__() 151 | 152 | weight = torch.randn(in_channel, in_channel) 153 | q, _ = torch.qr(weight) 154 | weight = q.unsqueeze(2).unsqueeze(3) 155 | self.weight = nn.Parameter(weight) 156 | 157 | def forward(self, input): 158 | _, _, height, width = input.shape 159 | 160 | out = F.conv2d(input, self.weight) 161 | logdet = ( 162 | height * width * torch.slogdet(self.weight.squeeze().double())[1].float() 163 | ) 164 | 165 | return out, logdet 166 | 167 | def reverse(self, output): 168 | _, _, height, width = output.shape 169 | 170 | 
in_recover = F.conv2d(output, self.weight.squeeze().inverse().unsqueeze(2).unsqueeze(3)) 171 | logdet = ( 172 | -height * width * torch.slogdet(self.weight.squeeze().double())[1].float() 173 | ) 174 | 175 | return in_recover, logdet 176 | 177 | 178 | class InvConv2dLU(nn.Module): 179 | def __init__(self, in_channel): 180 | super().__init__() 181 | 182 | weight = np.random.randn(in_channel, in_channel) 183 | q, _ = la.qr(weight) 184 | w_p, w_l, w_u = la.lu(q.astype(np.float32)) 185 | w_s = np.diag(w_u) 186 | w_u = np.triu(w_u, 1) 187 | u_mask = np.triu(np.ones_like(w_u), 1) 188 | l_mask = u_mask.T 189 | 190 | w_p = torch.from_numpy(w_p) 191 | w_l = torch.from_numpy(w_l) 192 | w_s = torch.from_numpy(w_s) 193 | w_u = torch.from_numpy(w_u) 194 | 195 | self.register_buffer("w_p", w_p) 196 | self.register_buffer("u_mask", torch.from_numpy(u_mask)) 197 | self.register_buffer("l_mask", torch.from_numpy(l_mask)) 198 | self.register_buffer("s_sign", torch.sign(w_s)) 199 | self.register_buffer("l_eye", torch.eye(l_mask.shape[0])) 200 | self.w_l = nn.Parameter(w_l) 201 | self.w_s = nn.Parameter(logabs(w_s)) 202 | self.w_u = nn.Parameter(w_u) 203 | 204 | def forward(self, input): 205 | _, _, height, width = input.shape 206 | 207 | weight = self.calc_weight() 208 | 209 | out = F.conv2d(input, weight) 210 | logdet = height * width * torch.sum(self.w_s) 211 | 212 | return out, logdet 213 | 214 | def calc_weight(self): 215 | weight = ( 216 | self.w_p 217 | @ (self.w_l * self.l_mask + self.l_eye) 218 | @ ((self.w_u * self.u_mask) + torch.diag(self.s_sign * torch.exp(self.w_s))) 219 | ) 220 | 221 | return weight.unsqueeze(2).unsqueeze(3) 222 | 223 | def reverse(self, output): 224 | _, _, height, width = output.shape 225 | 226 | weight = self.calc_weight() 227 | 228 | in_recover = F.conv2d(output, weight.squeeze().inverse().unsqueeze(2).unsqueeze(3)) 229 | logdet = -height * width * torch.sum(self.w_s) 230 | 231 | return in_recover, logdet 232 | 233 | 234 | class ZeroConv2d(nn.Module): 235 | def __init__(self, in_channel, out_channel, padding=1): 236 | super().__init__() 237 | 238 | self.conv = nn.Conv2d(in_channel, out_channel, 3, padding=0) 239 | self.conv.weight.data.zero_() 240 | self.conv.bias.data.zero_() 241 | self.scale = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 242 | 243 | def forward(self, input): 244 | out = F.pad(input, [1, 1, 1, 1], value=1) 245 | out = self.conv(out) 246 | out = out * torch.exp(self.scale * 3) 247 | 248 | return out 249 | 250 | # class View(nn.Module): 251 | # def __init__(self, shape): 252 | # self.shape = shape 253 | 254 | # def forward(self, x): 255 | # return x.view(*self.shape) 256 | 257 | class AffineCoupling(nn.Module): 258 | def __init__(self, in_channel, filter_size=512, cond_filter_size=1, affine=True): 259 | super().__init__() 260 | 261 | self.affine = affine 262 | self.cond_filter_size = cond_filter_size 263 | 264 | self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) 265 | 266 | # self.net = nn.Sequential( 267 | # nn.Conv2d(in_channel // 2, filter_size, 3, padding=1), 268 | # nn.ReLU(inplace=True), 269 | # nn.Conv2d(filter_size, filter_size, 1), 270 | # nn.ReLU(inplace=True), 271 | # ZeroConv2d(filter_size, in_channel if self.affine else in_channel // 2), 272 | # ) 273 | 274 | # self.net[0].weight.data.normal_(0, 0.05) 275 | # self.net[0].bias.data.zero_() 276 | 277 | # self.net[2].weight.data.normal_(0, 0.05) 278 | # self.net[2].bias.data.zero_() 279 | 280 | 281 | 282 | self.net = nn.Sequential( 283 | nn.Conv2d(in_channel // 2 + 
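# Conditional coupling: the first conv takes the untouched half of the channels plus
# cond_filter_size extra feature maps. Those maps come from self.net2, a linear layer
# built lazily in initialize() on the first forward/reverse pass, once the dimension
# of the conditioning vector is known; it projects cond_input to
# height * width * cond_filter_size values that are reshaped and concatenated with
# in_a before predicting (log_s, t). This is how the observation information enters
# every coupling layer of the conditional Glow.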
cond_filter_size, filter_size, 3, padding=1), 284 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 285 | nn.BatchNorm2d(filter_size), 286 | nn.Conv2d(filter_size, filter_size, 1), 287 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 288 | nn.BatchNorm2d(filter_size), 289 | ZeroConv2d(filter_size, in_channel if self.affine else in_channel // 2), 290 | ) 291 | 292 | 293 | self.net[0].weight.data.normal_(0, 0.05) 294 | self.net[0].bias.data.zero_() 295 | 296 | self.net[3].weight.data.normal_(0, 0.05) 297 | self.net[3].bias.data.zero_() 298 | 299 | 300 | def initialize(self, input, cond_input, device='cpu', inv_init=False): 301 | with torch.no_grad(): 302 | _, in_channel, height, width = input.shape 303 | _, ndim_cond = cond_input.shape 304 | 305 | 306 | self.net2 = nn.Sequential( 307 | nn.Linear(ndim_cond, height*width*self.cond_filter_size)) 308 | self.net2.to(device) 309 | 310 | 311 | def forward(self, input, cond_input, device='cpu'): 312 | if self.initialized.item() == 0: 313 | self.initialize(input, cond_input, device=device) 314 | self.initialized.fill_(1) 315 | 316 | in_a, in_b = input.chunk(2, 1) 317 | cond_transformed = self.net2(cond_input) 318 | 319 | _, in_channel, height, width = input.shape 320 | cond_transformed = cond_transformed.reshape((-1, self.cond_filter_size, height, width)) 321 | 322 | if self.affine: 323 | # Case 1 324 | log_s0, t = self.net(torch.cat([in_a, cond_transformed], 1)).chunk(2, 1) 325 | log_s = torch.tanh(log_s0) 326 | s = torch.exp(log_s) 327 | 328 | # # Case 2 329 | # log_s0, t0 = self.net(in_a).chunk(2, 1) 330 | # log_s = torch.tanh(log_s0) 331 | # t = torch.tanh(t0) 332 | # s = torch.exp(log_s) 333 | 334 | # # Case 3 335 | # log_s, t = self.net(in_a).chunk(2, 1) 336 | # s = torch.sigmoid(log_s + 2) 337 | # out_a = s * in_a + t 338 | out_b = (in_b + t) * s 339 | 340 | # logdet = torch.sum(torch.log(s).view(input.shape[0], -1), 1) 341 | logdet = torch.sum(log_s.view(input.shape[0], -1), 1) 342 | 343 | else: 344 | net_out = self.net(torch.cat([in_a, cond_transformed], 1)) 345 | out_b = in_b + net_out 346 | logdet = None 347 | 348 | return torch.cat([in_a, out_b], 1), logdet 349 | 350 | def reverse(self, output, cond_input, device='cpu'): 351 | if self.initialized.item() == 0: 352 | self.initialize(output, cond_input, device=device) 353 | self.initialized.fill_(1) 354 | 355 | out_a, out_b = output.chunk(2, 1) 356 | cond_transformed = self.net2(cond_input) 357 | 358 | _, in_channel, height, width = output.shape 359 | cond_transformed = cond_transformed.reshape((-1, self.cond_filter_size, height, width)) 360 | 361 | if self.affine: 362 | # Case 1 363 | log_s0, t = self.net(torch.cat([out_a, cond_transformed], 1)).chunk(2, 1) 364 | log_s = torch.tanh(log_s0) 365 | s = torch.exp(log_s) 366 | 367 | # # Case 2 368 | # log_s0, t0 = self.net(out_a).chunk(2, 1) 369 | # log_s = torch.tanh(log_s0) 370 | # t = torch.tanh(t0) 371 | # s = torch.exp(log_s) 372 | 373 | # # Case 3 374 | # log_s, t = self.net(out_a).chunk(2, 1) 375 | # s = torch.sigmoid(log_s + 2) 376 | # in_a = (out_a - t) / s 377 | in_b = out_b / s - t 378 | 379 | # logdet = -torch.sum(torch.log(s).view(output.shape[0], -1), 1) 380 | logdet = -torch.sum(log_s.view(output.shape[0], -1), 1) 381 | 382 | else: 383 | net_out = self.net(torch.cat([out_a, cond_transformed], 1)) 384 | in_b = out_b - net_out 385 | 386 | logdet = None 387 | 388 | return torch.cat([out_a, in_b], 1), logdet 389 | 390 | 391 | class Flow(nn.Module): 392 | def __init__(self, in_channel, affine=True, conv_lu=True): 393 | 
super().__init__() 394 | 395 | self.actnorm = ActNorm(in_channel) 396 | 397 | if conv_lu: 398 | self.invconv = InvConv2dLU(in_channel) 399 | 400 | else: 401 | self.invconv = InvConv2d(in_channel) 402 | 403 | self.coupling = AffineCoupling(in_channel, cond_filter_size=in_channel//4, affine=affine) 404 | 405 | def forward(self, input, cond_input, device='cpu'): 406 | out, logdet = self.actnorm(input) 407 | out, det1 = self.invconv(out) 408 | out, det2 = self.coupling(out, cond_input, device=device) 409 | 410 | logdet = logdet + det1 411 | if det2 is not None: 412 | logdet = logdet + det2 413 | 414 | return out, logdet 415 | 416 | def reverse(self, output, cond_input, device='cpu'): 417 | input, logdet = self.coupling.reverse(output, cond_input, device=device) 418 | input, det1 = self.invconv.reverse(input) 419 | input, det2 = self.actnorm.reverse(input) 420 | 421 | logdet = logdet + det1 422 | if det2 is not None: 423 | logdet = logdet + det2 424 | 425 | return input, logdet 426 | 427 | 428 | 429 | def gaussian_log_p(x, mean, log_sd): 430 | return -0.5 * log(2 * pi) - log_sd - 0.5 * (x - mean) ** 2 / torch.exp(2 * log_sd) 431 | 432 | def gaussian_log_p2(eps, mean, log_sd): 433 | return -0.5 * log(2 * pi) - log_sd - 0.5 * eps ** 2 434 | 435 | def gaussian_sample(eps, mean, log_sd): 436 | return mean + torch.exp(log_sd) * eps 437 | 438 | def gaussian_sample2(x, mean, log_sd): 439 | return (x - mean) * torch.exp(-log_sd) 440 | 441 | 442 | class Block(nn.Module): 443 | def __init__(self, in_channel, n_flow, split=True, affine=True, conv_lu=True): 444 | super().__init__() 445 | 446 | squeeze_dim = in_channel * 4 447 | 448 | self.flows = nn.ModuleList() 449 | for i in range(n_flow): 450 | self.flows.append(Flow(squeeze_dim, affine=affine, conv_lu=conv_lu)) 451 | 452 | self.split = split 453 | 454 | if split: 455 | self.prior = ZeroConv2d(in_channel * 2, in_channel * 4) 456 | 457 | else: 458 | self.prior = ZeroConv2d(in_channel * 4, in_channel * 8) 459 | 460 | def forward(self, input, cond_input, device='cpu'): 461 | b_size, n_channel, height, width = input.shape 462 | squeezed = input.view(b_size, n_channel, height // 2, 2, width // 2, 2) 463 | squeezed = squeezed.permute(0, 1, 3, 5, 2, 4) 464 | out = squeezed.contiguous().view(b_size, n_channel * 4, height // 2, width // 2) 465 | 466 | logdet = 0 467 | 468 | for flow in self.flows: 469 | out, det = flow(out, cond_input, device=device) 470 | logdet = logdet + det 471 | 472 | if self.split: 473 | out, z_new = out.chunk(2, 1) 474 | mean, log_sd = self.prior(out).chunk(2, 1) 475 | # log_p = gaussian_log_p(z_new, mean, log_sd) 476 | # log_p = log_p.view(b_size, -1).sum(1) 477 | z_eps = gaussian_sample2(z_new, mean, log_sd) 478 | logdet = logdet - log_sd.view(b_size, -1).sum(1) 479 | log_p = -0.5 * log(2 * pi) - 0.5 * z_eps ** 2 480 | log_p = log_p.view(b_size, -1).sum(1) 481 | 482 | else: 483 | zero = torch.zeros_like(out) 484 | mean, log_sd = self.prior(zero).chunk(2, 1) 485 | # log_p = gaussian_log_p(out, mean, log_sd) 486 | # log_p = log_p.view(b_size, -1).sum(1) 487 | z_new = out 488 | z_eps = gaussian_sample2(z_new, mean, log_sd) 489 | logdet = logdet - log_sd.view(b_size, -1).sum(1) 490 | log_p = -0.5 * log(2 * pi) - 0.5 * z_eps ** 2 491 | log_p = log_p.view(b_size, -1).sum(1) 492 | 493 | 494 | return out, logdet, log_p, z_eps 495 | 496 | def reverse(self, output, cond_input, eps=None, device='cpu'): 497 | input = output 498 | 499 | logdet = 0 500 | 501 | if self.split: 502 | mean, log_sd = self.prior(input).chunk(2, 1) 503 | z = 
gaussian_sample(eps, mean, log_sd) 504 | # log_p = gaussian_log_p2(eps, mean, log_sd) 505 | logdet = logdet + log_sd.view(input.shape[0], -1).sum(1) 506 | input = torch.cat([output, z], 1) 507 | 508 | else: 509 | zero = torch.zeros_like(input) 510 | # zero = F.pad(zero, [1, 1, 1, 1], value=1) 511 | mean, log_sd = self.prior(zero).chunk(2, 1) 512 | z = gaussian_sample(eps, mean, log_sd) 513 | # log_p = gaussian_log_p2(eps, mean, log_sd) 514 | logdet = logdet + log_sd.view(input.shape[0], -1).sum(1) 515 | input = z 516 | 517 | 518 | 519 | for flow in self.flows[::-1]: 520 | input, det = flow.reverse(input, cond_input, device=device) 521 | logdet = logdet + det 522 | 523 | b_size, n_channel, height, width = input.shape 524 | 525 | 526 | unsqueezed = input.view(b_size, n_channel // 4, 2, 2, height, width) 527 | unsqueezed = unsqueezed.permute(0, 1, 4, 2, 5, 3) 528 | unsqueezed = unsqueezed.contiguous().view( 529 | b_size, n_channel // 4, height * 2, width * 2 530 | ) 531 | 532 | return unsqueezed, logdet 533 | 534 | 535 | class Glow(nn.Module): 536 | def __init__( 537 | self, in_channel, n_flow, n_block, affine=True, conv_lu=True 538 | ): 539 | super().__init__() 540 | 541 | self.blocks = nn.ModuleList() 542 | n_channel = in_channel 543 | for i in range(n_block - 1): 544 | self.blocks.append(Block(n_channel, n_flow, affine=affine, conv_lu=conv_lu)) 545 | n_channel *= 2 546 | self.blocks.append(Block(n_channel, n_flow, split=False, affine=affine)) 547 | 548 | def forward(self, input, cond_input, device='cpu'): 549 | log_p_sum = 0 550 | logdet = 0 551 | out = input 552 | z_outs = [] 553 | 554 | for block in self.blocks: 555 | out, det, log_p, z_new = block(out, cond_input, device=device) 556 | z_outs.append(z_new) 557 | # z_outs.append(z_new.view(input.shape[0], -1)) 558 | logdet = logdet + det 559 | 560 | if log_p is not None: 561 | log_p_sum = log_p_sum + log_p 562 | # z_eps = torch.cat(z_outs, -1) 563 | 564 | return log_p_sum, logdet, z_outs 565 | 566 | def reverse(self, z_list, cond_input, device='cpu'): 567 | log_p_sum = 0 568 | logdet = 0 569 | 570 | for i, block in enumerate(self.blocks[::-1]): 571 | if i == 0: 572 | input, det = block.reverse(z_list[-1], cond_input, z_list[-1], device=device) 573 | else: 574 | input, det = block.reverse(input, cond_input, z_list[-(i + 1)], device=device) 575 | 576 | logdet = logdet + det 577 | 578 | 579 | return input, logdet 580 | 581 | 582 | def calc_z_shapes(n_channel, input_size, n_flow, n_block): 583 | z_shapes = [] 584 | 585 | for i in range(n_block - 1): 586 | input_size //= 2 587 | n_channel *= 2 588 | 589 | z_shapes.append((n_channel, input_size, input_size)) 590 | 591 | input_size //= 2 592 | z_shapes.append((n_channel * 4, input_size, input_size)) 593 | 594 | return z_shapes -------------------------------------------------------------------------------- /DPItorch/generative_model/realnvpfc_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from math import log, pi, exp 5 | import numpy as np 6 | from scipy import linalg as la 7 | 8 | class ActNorm(nn.Module): 9 | def __init__(self, logdet=True): 10 | super().__init__() 11 | 12 | self.loc = nn.Parameter(torch.zeros(1, )) 13 | self.log_scale_inv = nn.Parameter(torch.zeros(1, )) 14 | 15 | self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) 16 | self.logdet = logdet 17 | 18 | def initialize(self, input, inv_init=False): 19 | with torch.no_grad(): 20 | 
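# Data-dependent init from the first batch: loc <- -mean(input) and
# log_scale_inv <- log(std(input)), both scalars shared across all coordinates, so
# activations start near zero mean / unit variance. Storing the log of the inverse
# scale keeps the log-determinant simple: forward() adds in_dim * (-log_scale_inv)
# and reverse() subtracts it.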
mean = input.mean().reshape((1, )) 21 | std = input.std().reshape((1, )) 22 | 23 | if inv_init: 24 | self.loc.data.copy_(torch.zeros_like(mean)) 25 | self.log_scale_inv.data.copy_(torch.zeros_like(std)) 26 | else: 27 | self.loc.data.copy_(-mean) 28 | self.log_scale_inv.data.copy_(torch.log(std + 1e-6)) 29 | 30 | def forward(self, input): 31 | _, in_dim = input.shape 32 | 33 | if self.initialized.item() == 0: 34 | self.initialize(input) 35 | self.initialized.fill_(1) 36 | 37 | scale_inv = torch.exp(self.log_scale_inv) 38 | 39 | log_abs = -self.log_scale_inv 40 | 41 | logdet = in_dim * torch.sum(log_abs) 42 | 43 | if self.logdet: 44 | return (1.0 / scale_inv) * (input + self.loc), logdet 45 | 46 | else: 47 | return (1.0 / scale_inv) * (input + self.loc) 48 | 49 | def reverse(self, output): 50 | _, in_dim = output.shape 51 | 52 | if self.initialized.item() == 0: 53 | self.initialize(output, inv_init=True) 54 | self.initialized.fill_(1) 55 | 56 | scale_inv = torch.exp(self.log_scale_inv) 57 | 58 | log_abs = -self.log_scale_inv 59 | 60 | logdet = -in_dim * torch.sum(log_abs) 61 | 62 | if self.logdet: 63 | return output * scale_inv - self.loc, logdet 64 | 65 | else: 66 | return output * scale_inv - self.loc 67 | 68 | # logabs = lambda x: torch.log(torch.abs(x)) 69 | 70 | # class ActNorm(nn.Module): 71 | # def __init__(self, logdet=True): 72 | # super().__init__() 73 | 74 | # self.loc = nn.Parameter(torch.zeros(1, )) 75 | # self.scale_inv = nn.Parameter(torch.ones(1, )) 76 | 77 | # self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) 78 | # self.logdet = logdet 79 | 80 | # def initialize(self, input, inv_init=False): 81 | # with torch.no_grad(): 82 | # mean = input.mean().reshape((1, )) 83 | # std = input.std().reshape((1, )) 84 | 85 | # if inv_init: 86 | # self.loc.data.copy_(torch.zeros_like(mean)) 87 | # self.scale_inv.data.copy_(torch.ones_like(std)) 88 | # else: 89 | # self.loc.data.copy_(-mean) 90 | # self.scale_inv.data.copy_((std + 1e-6)) 91 | 92 | # def forward(self, input): 93 | # _, in_dim = input.shape 94 | 95 | # if self.initialized.item() == 0: 96 | # self.initialize(input) 97 | # self.initialized.fill_(1) 98 | 99 | # log_abs = -logabs(self.scale_inv) 100 | 101 | # logdet = in_dim * torch.sum(log_abs) 102 | 103 | # if self.logdet: 104 | # return (1.0 / self.scale_inv) * (input + self.loc), logdet 105 | 106 | # else: 107 | # return (1.0 / self.scale_inv) * (input + self.loc) 108 | 109 | # def reverse(self, output): 110 | # _, in_dim = output.shape 111 | 112 | # if self.initialized.item() == 0: 113 | # self.initialize(output, inv_init=True) 114 | # self.initialized.fill_(1) 115 | 116 | # log_abs = -logabs(self.scale_inv) 117 | 118 | # logdet = -in_dim * torch.sum(log_abs) 119 | 120 | # if self.logdet: 121 | # return output * self.scale_inv - self.loc, logdet 122 | 123 | # else: 124 | # return output * self.scale_inv - self.loc 125 | 126 | 127 | 128 | class ZeroFC(nn.Module): 129 | def __init__(self, in_dim, out_dim): 130 | super().__init__() 131 | 132 | self.fc = nn.Linear(in_dim, out_dim) 133 | self.fc.weight.data.zero_() 134 | self.fc.bias.data.zero_() 135 | self.scale = nn.Parameter(torch.zeros(out_dim, )) 136 | 137 | def forward(self, input): 138 | out = self.fc(input) 139 | out = out * torch.exp(self.scale * 3) 140 | 141 | return out 142 | 143 | class AffineCoupling(nn.Module): 144 | def __init__(self, ndim, seqfrac=4, affine=True, batch_norm=True): 145 | super().__init__() 146 | 147 | self.affine = affine 148 | self.batch_norm = batch_norm 149 | 150 | # 
self.net = nn.Sequential( 151 | # nn.Linear(ndim-ndim//2, ndim // (2*seqfrac)), 152 | # nn.LeakyReLU(negative_slope=0.01, inplace=True), 153 | # nn.BatchNorm1d(ndim // (2*seqfrac)), 154 | # nn.Linear(ndim // (2*seqfrac), ndim // (2*seqfrac)), 155 | # nn.LeakyReLU(negative_slope=0.01, inplace=True), 156 | # nn.BatchNorm1d(ndim // (2*seqfrac)), 157 | # ZeroFC(ndim // (2*seqfrac), 2*(ndim // 2) if self.affine else ndim // 2), 158 | # ) 159 | 160 | # self.net = nn.Sequential( 161 | # nn.Linear(ndim-ndim//2, int(ndim / (2*seqfrac))), 162 | # nn.LeakyReLU(negative_slope=0.01, inplace=True), 163 | # nn.LayerNorm(int(ndim / (2*seqfrac))), 164 | # nn.Linear(int(ndim / (2*seqfrac)), int(ndim / (2*seqfrac))), 165 | # nn.LeakyReLU(negative_slope=0.01, inplace=True), 166 | # nn.LayerNorm(int(ndim / (2*seqfrac))), 167 | # ZeroFC(int(ndim / (2*seqfrac)), 2*(ndim // 2) if self.affine else ndim // 2), 168 | # ) 169 | if batch_norm: 170 | # older version has skip connection, but we find that not necessary 171 | self.net = nn.Sequential( 172 | nn.Linear(ndim-ndim//2, int(ndim / (2*seqfrac))), 173 | nn.LeakyReLU(negative_slope=0.01, inplace=True), 174 | # nn.Softplus(beta=1, threshold=20), 175 | # nn.Tanh(), 176 | # nn.ReLU(), 177 | # nn.GELU(), 178 | nn.BatchNorm1d(int(ndim / (2*seqfrac)), eps=1e-2, affine=True), 179 | nn.Linear(int(ndim / (2*seqfrac)), int(ndim / (2*seqfrac))), 180 | nn.LeakyReLU(negative_slope=0.01, inplace=True), 181 | # nn.Softplus(beta=1, threshold=20), 182 | # nn.Tanh(), 183 | # nn.ReLU(), 184 | # nn.GELU(), 185 | nn.BatchNorm1d(int(ndim / (2*seqfrac)), eps=1e-2, affine=True), 186 | ZeroFC(int(ndim / (2*seqfrac)), 2*(ndim // 2) if self.affine else ndim // 2), 187 | ) 188 | 189 | self.net[0].weight.data.normal_(0, 0.05) 190 | self.net[0].bias.data.zero_() 191 | 192 | self.net[3].weight.data.normal_(0, 0.05) 193 | self.net[3].bias.data.zero_() 194 | # self.net[2].weight.data.normal_(0, 0.05) 195 | # self.net[2].bias.data.zero_() 196 | else: 197 | # older version has skip connection, but we find that not necessary 198 | self.net = nn.Sequential( 199 | nn.Linear(ndim-ndim//2, int(ndim / (2*seqfrac))), 200 | nn.LeakyReLU(negative_slope=0.01, inplace=True), 201 | # nn.Softplus(beta=1, threshold=20), 202 | # nn.Tanh(), 203 | # nn.BatchNorm1d(int(ndim / (2*seqfrac)), eps=1e-2, affine=True), 204 | nn.Linear(int(ndim / (2*seqfrac)), int(ndim / (2*seqfrac))), 205 | nn.LeakyReLU(negative_slope=0.01, inplace=True), 206 | # nn.Softplus(beta=1, threshold=20), 207 | # nn.Tanh(), 208 | # nn.BatchNorm1d(int(ndim / (2*seqfrac)), eps=1e-2, affine=True), 209 | ZeroFC(int(ndim / (2*seqfrac)), 2*(ndim // 2) if self.affine else ndim // 2), 210 | ) 211 | 212 | self.net[0].weight.data.normal_(0, 0.05) 213 | self.net[0].bias.data.zero_() 214 | 215 | # self.net[3].weight.data.normal_(0, 0.05) 216 | # self.net[3].bias.data.zero_() 217 | self.net[2].weight.data.normal_(0, 0.05) 218 | self.net[2].bias.data.zero_() 219 | 220 | 221 | 222 | def forward(self, input): 223 | in_a, in_b = input.chunk(2, 1) 224 | 225 | if self.affine: 226 | log_s0, t = self.net(in_a).chunk(2, 1) 227 | log_s = torch.tanh(log_s0) 228 | s = torch.exp(log_s) 229 | out_b = (in_b + t) * s 230 | 231 | logdet = torch.sum(log_s.view(input.shape[0], -1), 1) 232 | 233 | else: 234 | net_out = self.net(in_a) 235 | out_b = in_b + net_out 236 | logdet = None 237 | 238 | return torch.cat([in_a, out_b], 1), logdet 239 | 240 | def reverse(self, output): 241 | out_a, out_b = output.chunk(2, 1) 242 | 243 | if self.affine: 244 | log_s0, t = 
self.net(out_a).chunk(2, 1) 245 | log_s = torch.tanh(log_s0) 246 | s = torch.exp(log_s) 247 | in_b = out_b / s - t 248 | 249 | logdet = -torch.sum(log_s.view(output.shape[0], -1), 1) 250 | 251 | else: 252 | net_out = self.net(out_a) 253 | in_b = out_b - net_out 254 | 255 | logdet = None 256 | 257 | return torch.cat([out_a, in_b], 1), logdet 258 | 259 | 260 | # class AffineCoupling(nn.Module): 261 | # def __init__(self, ndim, seqfrac=4, affine=True, batch_norm=True): 262 | # super().__init__() 263 | 264 | # self.affine = affine 265 | # self.batch_norm = batch_norm 266 | 267 | # # self.net = nn.Sequential( 268 | # # nn.Linear(ndim-ndim//2, ndim // (2*seqfrac)), 269 | # # nn.LeakyReLU(negative_slope=0.01, inplace=True), 270 | # # nn.BatchNorm1d(ndim // (2*seqfrac)), 271 | # # nn.Linear(ndim // (2*seqfrac), ndim // (2*seqfrac)), 272 | # # nn.LeakyReLU(negative_slope=0.01, inplace=True), 273 | # # nn.BatchNorm1d(ndim // (2*seqfrac)), 274 | # # ZeroFC(ndim // (2*seqfrac), 2*(ndim // 2) if self.affine else ndim // 2), 275 | # # ) 276 | 277 | # # self.net = nn.Sequential( 278 | # # nn.Linear(ndim-ndim//2, int(ndim / (2*seqfrac))), 279 | # # nn.LeakyReLU(negative_slope=0.01, inplace=True), 280 | # # nn.LayerNorm(int(ndim / (2*seqfrac))), 281 | # # nn.Linear(int(ndim / (2*seqfrac)), int(ndim / (2*seqfrac))), 282 | # # nn.LeakyReLU(negative_slope=0.01, inplace=True), 283 | # # nn.LayerNorm(int(ndim / (2*seqfrac))), 284 | # # ZeroFC(int(ndim / (2*seqfrac)), 2*(ndim // 2) if self.affine else ndim // 2), 285 | # # ) 286 | # if batch_norm: 287 | # # older version has skip connection, but we find that not necessary 288 | # self.net = nn.Sequential( 289 | # nn.Linear(ndim-ndim//2, int(ndim / (2*seqfrac))), 290 | # nn.LeakyReLU(negative_slope=0.01, inplace=True), 291 | # # nn.Softplus(beta=1, threshold=20), 292 | # # nn.Tanh(), 293 | # nn.BatchNorm1d(int(ndim / (2*seqfrac)), eps=1e-2, affine=True), 294 | # nn.Linear(int(ndim / (2*seqfrac)), int(ndim / (2*seqfrac))), 295 | # nn.LeakyReLU(negative_slope=0.01, inplace=True), 296 | # # nn.Softplus(beta=1, threshold=20), 297 | # # nn.Tanh(), 298 | # nn.BatchNorm1d(int(ndim / (2*seqfrac)), eps=1e-2, affine=True), 299 | # nn.Linear(int(ndim / (2*seqfrac)), int(ndim / (2*seqfrac))), 300 | # nn.LeakyReLU(negative_slope=0.01, inplace=True), 301 | # # nn.Softplus(beta=1, threshold=20), 302 | # # nn.Tanh(), 303 | # nn.BatchNorm1d(int(ndim / (2*seqfrac)), eps=1e-2, affine=True), 304 | # ZeroFC(int(ndim / (2*seqfrac)), 2*(ndim // 2) if self.affine else ndim // 2), 305 | # ) 306 | 307 | # self.net[0].weight.data.normal_(0, 0.05) 308 | # self.net[0].bias.data.zero_() 309 | 310 | # self.net[3].weight.data.normal_(0, 0.05) 311 | # self.net[3].bias.data.zero_() 312 | # # self.net[2].weight.data.normal_(0, 0.05) 313 | # # self.net[2].bias.data.zero_() 314 | 315 | # self.net[6].weight.data.normal_(0, 0.05) 316 | # self.net[6].bias.data.zero_() 317 | # else: 318 | # # older version has skip connection, but we find that not necessary 319 | # self.net = nn.Sequential( 320 | # nn.Linear(ndim-ndim//2, int(ndim / (2*seqfrac))), 321 | # nn.LeakyReLU(negative_slope=0.01, inplace=True), 322 | # # nn.Softplus(beta=1, threshold=20), 323 | # # nn.Tanh(), 324 | # # nn.BatchNorm1d(int(ndim / (2*seqfrac)), eps=1e-2, affine=True), 325 | # nn.Linear(int(ndim / (2*seqfrac)), int(ndim / (2*seqfrac))), 326 | # nn.LeakyReLU(negative_slope=0.01, inplace=True), 327 | # # nn.Softplus(beta=1, threshold=20), 328 | # # nn.Tanh(), 329 | # # nn.BatchNorm1d(int(ndim / (2*seqfrac)), eps=1e-2, affine=True), 
330 | # nn.Linear(int(ndim / (2*seqfrac)), int(ndim / (2*seqfrac))), 331 | # nn.LeakyReLU(negative_slope=0.01, inplace=True), 332 | # # nn.Softplus(beta=1, threshold=20), 333 | # # nn.Tanh(), 334 | # # nn.BatchNorm1d(int(ndim / (2*seqfrac)), eps=1e-2, affine=True), 335 | # ZeroFC(int(ndim / (2*seqfrac)), 2*(ndim // 2) if self.affine else ndim // 2), 336 | # ) 337 | 338 | # self.net[0].weight.data.normal_(0, 0.05) 339 | # self.net[0].bias.data.zero_() 340 | 341 | # # self.net[3].weight.data.normal_(0, 0.05) 342 | # # self.net[3].bias.data.zero_() 343 | # self.net[2].weight.data.normal_(0, 0.05) 344 | # self.net[2].bias.data.zero_() 345 | 346 | # self.net[4].weight.data.normal_(0, 0.05) 347 | # self.net[4].bias.data.zero_() 348 | 349 | # def forward(self, input): 350 | # in_a, in_b = input.chunk(2, 1) 351 | 352 | # if self.affine: 353 | # log_s0, t = self.net(in_a).chunk(2, 1) 354 | # log_s = torch.tanh(log_s0) 355 | # s = torch.exp(log_s) 356 | # out_b = (in_b + t) * s 357 | 358 | # logdet = torch.sum(log_s.view(input.shape[0], -1), 1) 359 | 360 | # else: 361 | # net_out = self.net(in_a) 362 | # out_b = in_b + net_out 363 | # logdet = None 364 | 365 | # return torch.cat([in_a, out_b], 1), logdet 366 | 367 | # def reverse(self, output): 368 | # out_a, out_b = output.chunk(2, 1) 369 | 370 | # if self.affine: 371 | # log_s0, t = self.net(out_a).chunk(2, 1) 372 | # log_s = torch.tanh(log_s0) 373 | # s = torch.exp(log_s) 374 | # in_b = out_b / s - t 375 | 376 | # logdet = -torch.sum(log_s.view(output.shape[0], -1), 1) 377 | 378 | # else: 379 | # net_out = self.net(out_a) 380 | # in_b = out_b - net_out 381 | 382 | # logdet = None 383 | 384 | # return torch.cat([out_a, in_b], 1), logdet 385 | 386 | 387 | # class Flow(nn.Module): 388 | # def __init__(self, ndim, affine=True, seqfrac=4, batch_norm=True): 389 | # super().__init__() 390 | 391 | # self.coupling = AffineCoupling(ndim, seqfrac=seqfrac, affine=affine, batch_norm=batch_norm) 392 | # self.coupling2 = AffineCoupling(ndim, seqfrac=seqfrac, affine=affine, batch_norm=batch_norm) 393 | 394 | # self.ndim = ndim 395 | 396 | # def forward(self, input): 397 | # logdet = 0 398 | # out, det2 = self.coupling(input) 399 | # out = out[:, np.arange(self.ndim-1, -1, -1)] 400 | # out, det4 = self.coupling2(out) 401 | # out = out[:, np.arange(self.ndim-1, -1, -1)] 402 | 403 | # if det2 is not None: 404 | # logdet = logdet + det2 405 | # if det4 is not None: 406 | # logdet = logdet + det4 407 | 408 | # return out, logdet 409 | 410 | # def reverse(self, output): 411 | # logdet = 0 412 | # input = output[:, np.arange(self.ndim-1, -1, -1)] 413 | # input, det1 = self.coupling2.reverse(input) 414 | # input = input[:, np.arange(self.ndim-1, -1, -1)] 415 | # input, det3 = self.coupling.reverse(input) 416 | 417 | 418 | # if det1 is not None: 419 | # logdet = logdet + det1 420 | # if det3 is not None: 421 | # logdet = logdet + det3 422 | 423 | # return input, logdet 424 | 425 | 426 | # alternating affine coupling 427 | class Flow(nn.Module): 428 | def __init__(self, ndim, affine=True, seqfrac=4, batch_norm=True): 429 | super().__init__() 430 | 431 | self.actnorm = ActNorm() 432 | self.actnorm2 = ActNorm() 433 | 434 | self.coupling = AffineCoupling(ndim, seqfrac=seqfrac, affine=affine, batch_norm=batch_norm) 435 | self.coupling2 = AffineCoupling(ndim, seqfrac=seqfrac, affine=affine, batch_norm=batch_norm) 436 | 437 | self.ndim = ndim 438 | 439 | def forward(self, input): 440 | logdet = 0 441 | out, det1 = self.actnorm(input) 442 | out, det2 = self.coupling(out) 443 | 
out = out[:, np.arange(self.ndim-1, -1, -1)] 444 | out, det3 = self.actnorm2(out) 445 | out, det4 = self.coupling2(out) 446 | out = out[:, np.arange(self.ndim-1, -1, -1)] 447 | 448 | logdet = logdet + det1 449 | if det2 is not None: 450 | logdet = logdet + det2 451 | logdet = logdet + det3 452 | if det4 is not None: 453 | logdet = logdet + det4 454 | 455 | return out, logdet 456 | 457 | def reverse(self, output): 458 | logdet = 0 459 | input = output[:, np.arange(self.ndim-1, -1, -1)] 460 | input, det1 = self.coupling2.reverse(input) 461 | input, det2 = self.actnorm2.reverse(input) 462 | input = input[:, np.arange(self.ndim-1, -1, -1)] 463 | input, det3 = self.coupling.reverse(input) 464 | input, det4 = self.actnorm.reverse(input) 465 | 466 | 467 | if det1 is not None: 468 | logdet = logdet + det1 469 | logdet = logdet + det2 470 | if det3 is not None: 471 | logdet = logdet + det3 472 | logdet = logdet + det4 473 | 474 | return input, logdet 475 | 476 | # # single affine coupling 477 | # class Flow(nn.Module): 478 | # def __init__(self, ndim, affine=True, seqfrac=4, batch_norm=True): 479 | # super().__init__() 480 | 481 | # self.actnorm = ActNorm() 482 | 483 | # self.coupling = AffineCoupling(ndim, seqfrac=seqfrac, affine=affine, batch_norm=batch_norm) 484 | 485 | # self.ndim = ndim 486 | 487 | # def forward(self, input): 488 | # logdet = 0 489 | # out, det1 = self.actnorm(input) 490 | # out, det2 = self.coupling(out) 491 | 492 | # logdet = logdet + det1 493 | # if det2 is not None: 494 | # logdet = logdet + det2 495 | 496 | # return out, logdet 497 | 498 | # def reverse(self, output): 499 | # logdet = 0 500 | # input = output 501 | # input, det1 = self.coupling.reverse(input) 502 | # input, det2 = self.actnorm.reverse(input) 503 | 504 | # if det1 is not None: 505 | # logdet = logdet + det1 506 | # logdet = logdet + det2 507 | 508 | # return input, logdet 509 | 510 | 511 | # class Flow(nn.Module): 512 | # def __init__(self, ndim, affine=True, seqfrac=4): 513 | # super().__init__() 514 | 515 | # self.actnorm = ActNorm() 516 | # self.actnorm2 = ActNorm() 517 | 518 | # self.coupling = AffineCoupling(ndim, seqfrac=seqfrac, affine=affine) 519 | # self.coupling2 = AffineCoupling(ndim, seqfrac=seqfrac, affine=affine) 520 | 521 | # self.ndim = ndim 522 | 523 | # def forward(self, input): 524 | # logdet = 0 525 | # out, det2 = self.coupling(input) 526 | # out = out[:, np.arange(self.ndim-1, -1, -1)] 527 | # out, det4 = self.coupling2(out) 528 | # out = out[:, np.arange(self.ndim-1, -1, -1)] 529 | 530 | # if det2 is not None: 531 | # logdet = logdet + det2 532 | # if det4 is not None: 533 | # logdet = logdet + det4 534 | 535 | # return out, logdet 536 | 537 | # def reverse(self, output): 538 | # logdet = 0 539 | # input = output[:, np.arange(self.ndim-1, -1, -1)] 540 | # input, det1 = self.coupling2.reverse(input) 541 | # input = input[:, np.arange(self.ndim-1, -1, -1)] 542 | # input, det3 = self.coupling.reverse(input) 543 | 544 | 545 | # if det1 is not None: 546 | # logdet = logdet + det1 547 | # if det3 is not None: 548 | # logdet = logdet + det3 549 | 550 | # return input, logdet 551 | 552 | 553 | def Order_inverse(order): 554 | order_inv = [] 555 | for k in range(len(order)): 556 | for i in range(len(order)): 557 | if order[i] == k: 558 | order_inv.append(i) 559 | return np.array(order_inv) 560 | 561 | 562 | class RealNVP(nn.Module): 563 | def __init__( 564 | self, ndim, n_flow, affine=True, seqfrac=4, permute='random', batch_norm=True 565 | ): 566 | super().__init__() 567 | self.blocks = 
nn.ModuleList() 568 | self.orders = [] 569 | for i in range(n_flow): 570 | self.blocks.append(Flow(ndim, affine=affine, seqfrac=seqfrac, batch_norm=batch_norm)) 571 | if permute == 'random': 572 | self.orders.append(np.random.RandomState(seed=i).permutation(ndim)) 573 | elif permute == 'reverse': 574 | self.orders.append(np.arange(ndim-1, -1, -1)) 575 | else: 576 | print('We can only do no permutation, random permutation or reverse permutation in affine coupling layer. Using no permutation by default!') 577 | self.orders.append(np.arange(ndim)) 578 | 579 | self.inverse_orders = [] 580 | for i in range(n_flow): 581 | self.inverse_orders.append(Order_inverse(self.orders[i])) 582 | 583 | def forward(self, input): 584 | logdet = 0 585 | out = input 586 | 587 | for i in range(len(self.blocks)): 588 | out, det = self.blocks[i](out) 589 | logdet = logdet + det 590 | out = out[:, self.orders[i]] 591 | 592 | return out, logdet 593 | 594 | def reverse(self, out): 595 | logdet = 0 596 | input = out 597 | 598 | for i in range(len(self.blocks)-1, -1, -1): 599 | input = input[:, self.inverse_orders[i]] 600 | input, det = self.blocks[i].reverse(input) 601 | logdet = logdet + det 602 | 603 | return input, logdet -------------------------------------------------------------------------------- /DPItorch/generative_model/cond_realnvpfc_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from math import log, pi, exp 5 | import numpy as np 6 | from scipy import linalg as la 7 | 8 | 9 | class ActNorm(nn.Module): 10 | def __init__(self, logdet=True): 11 | super().__init__() 12 | 13 | self.loc = nn.Parameter(torch.zeros(1, )) 14 | self.log_scale_inv = nn.Parameter(torch.zeros(1, )) 15 | 16 | self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) 17 | self.logdet = logdet 18 | 19 | def initialize(self, input, inv_init=False): 20 | with torch.no_grad(): 21 | mean = input.mean().reshape((1, )) 22 | std = input.std().reshape((1, )) 23 | 24 | if inv_init: 25 | self.loc.data.copy_(torch.zeros_like(mean)) 26 | self.log_scale_inv.data.copy_(torch.zeros_like(std)) 27 | else: 28 | self.loc.data.copy_(-mean) 29 | self.log_scale_inv.data.copy_(torch.log(std + 1e-6)) 30 | 31 | def forward(self, input): 32 | _, in_dim = input.shape 33 | 34 | if self.initialized.item() == 0: 35 | self.initialize(input) 36 | self.initialized.fill_(1) 37 | 38 | scale_inv = torch.exp(self.log_scale_inv) 39 | 40 | log_abs = -self.log_scale_inv 41 | 42 | logdet = in_dim * torch.sum(log_abs) 43 | 44 | if self.logdet: 45 | return (1.0 / scale_inv) * (input + self.loc), logdet 46 | 47 | else: 48 | return (1.0 / scale_inv) * (input + self.loc) 49 | 50 | def reverse(self, output): 51 | _, in_dim = output.shape 52 | 53 | if self.initialized.item() == 0: 54 | self.initialize(output, inv_init=True) 55 | self.initialized.fill_(1) 56 | 57 | scale_inv = torch.exp(self.log_scale_inv) 58 | 59 | log_abs = -self.log_scale_inv 60 | 61 | logdet = -in_dim * torch.sum(log_abs) 62 | 63 | if self.logdet: 64 | return output * scale_inv - self.loc, logdet 65 | 66 | else: 67 | return output * scale_inv - self.loc 68 | 69 | 70 | # logabs = lambda x: torch.log(torch.abs(x)) 71 | 72 | 73 | # class ActNorm(nn.Module): 74 | # def __init__(self, logdet=True): 75 | # super().__init__() 76 | 77 | # self.loc = nn.Parameter(torch.zeros(1, )) 78 | # self.scale_inv = nn.Parameter(torch.ones(1, )) 79 | 80 | # self.register_buffer("initialized", 
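# A minimal sketch of using the RealNVP class defined at the end of realnvpfc_model.py
# above, written the way a training script would import it (cf. DPIx_orbit.py). It only
# exercises the forward/reverse API shown there and assumes a standard-normal base density
# for the change-of-variables line; how the DPI scripts actually drive the flow is defined
# in the DPI_*.py files. Kept commented out:
# import torch
# from torch.distributions import Normal
# from generative_model import realnvpfc_model
#
# torch.manual_seed(0)
# ndim, n_flow = 8, 4
# flow = realnvpfc_model.RealNVP(ndim, n_flow, affine=True, seqfrac=1,
#                                permute='random', batch_norm=False)
#
# x = torch.randn(16, ndim)
# z, logdet = flow(x)                     # first call also triggers ActNorm's
#                                         # data-dependent initialization
# x_rec, logdet_rev = flow.reverse(z)
# assert torch.allclose(x, x_rec, atol=1e-4)
# assert torch.allclose(logdet, -logdet_rev, atol=1e-4)
#
# # log-density of x under the flow, standard change of variables with a N(0, I) base:
# log_px = Normal(0.0, 1.0).log_prob(z).sum(dim=1) + logdet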
torch.tensor(0, dtype=torch.uint8)) 81 | # self.logdet = logdet 82 | 83 | # def initialize(self, input, inv_init=False): 84 | # with torch.no_grad(): 85 | # mean = input.mean().reshape((1, )) 86 | # std = input.std().reshape((1, )) 87 | 88 | # if inv_init: 89 | # self.loc.data.copy_(torch.zeros_like(mean)) 90 | # self.scale_inv.data.copy_(torch.ones_like(std)) 91 | # else: 92 | # self.loc.data.copy_(-mean) 93 | # self.scale_inv.data.copy_((std + 1e-6)) 94 | 95 | # def forward(self, input): 96 | # _, in_dim = input.shape 97 | 98 | # if self.initialized.item() == 0: 99 | # self.initialize(input) 100 | # self.initialized.fill_(1) 101 | 102 | # log_abs = -logabs(self.scale_inv) 103 | 104 | # logdet = in_dim * torch.sum(log_abs) 105 | 106 | # if self.logdet: 107 | # return (1.0 / self.scale_inv) * (input + self.loc), logdet 108 | 109 | # else: 110 | # return (1.0 / self.scale_inv) * (input + self.loc) 111 | 112 | # def reverse(self, output): 113 | # _, in_dim = output.shape 114 | 115 | # if self.initialized.item() == 0: 116 | # self.initialize(output, inv_init=True) 117 | # self.initialized.fill_(1) 118 | 119 | # log_abs = -logabs(self.scale_inv) 120 | 121 | # logdet = -in_dim * torch.sum(log_abs) 122 | 123 | # if self.logdet: 124 | # return output * self.scale_inv - self.loc, logdet 125 | 126 | # else: 127 | # return output * self.scale_inv - self.loc 128 | 129 | 130 | 131 | class ZeroFC(nn.Module): 132 | def __init__(self, in_dim, out_dim): 133 | super().__init__() 134 | 135 | self.fc = nn.Linear(in_dim, out_dim) 136 | self.fc.weight.data.zero_() 137 | self.fc.bias.data.zero_() 138 | self.scale = nn.Parameter(torch.zeros(out_dim, )) 139 | 140 | def forward(self, input): 141 | out = self.fc(input) 142 | out = out * torch.exp(self.scale * 3) 143 | 144 | return out 145 | 146 | 147 | class AffineCoupling(nn.Module): 148 | def __init__(self, ndim, ndim_cond, seqfrac=4, affine=True, batch_norm=True): 149 | super().__init__() 150 | 151 | self.affine = affine 152 | self.batch_norm = batch_norm 153 | 154 | # self.net = nn.Sequential( 155 | # nn.Linear(ndim-ndim//2, ndim // (2*seqfrac)), 156 | # nn.LeakyReLU(negative_slope=0.01, inplace=True), 157 | # nn.BatchNorm1d(ndim // (2*seqfrac)), 158 | # nn.Linear(ndim // (2*seqfrac), ndim // (2*seqfrac)), 159 | # nn.LeakyReLU(negative_slope=0.01, inplace=True), 160 | # nn.BatchNorm1d(ndim // (2*seqfrac)), 161 | # ZeroFC(ndim // (2*seqfrac), 2*(ndim // 2) if self.affine else ndim // 2), 162 | # ) 163 | 164 | # self.net = nn.Sequential( 165 | # nn.Linear(ndim-ndim//2, int(ndim / (2*seqfrac))), 166 | # nn.LeakyReLU(negative_slope=0.01, inplace=True), 167 | # nn.LayerNorm(int(ndim / (2*seqfrac))), 168 | # nn.Linear(int(ndim / (2*seqfrac)), int(ndim / (2*seqfrac))), 169 | # nn.LeakyReLU(negative_slope=0.01, inplace=True), 170 | # nn.LayerNorm(int(ndim / (2*seqfrac))), 171 | # ZeroFC(int(ndim / (2*seqfrac)), 2*(ndim // 2) if self.affine else ndim // 2), 172 | # ) 173 | 174 | # older version has skip connection, but we find that not necessary 175 | # self.net = nn.Sequential( 176 | # nn.Linear(ndim_cond+ndim-ndim//2, int(ndim / (2*seqfrac))), 177 | # nn.LeakyReLU(negative_slope=0.01, inplace=True), 178 | # nn.BatchNorm1d(int(ndim / (2*seqfrac)), eps=1e-2, affine=True), 179 | # nn.Linear(int(ndim / (2*seqfrac)), int(ndim / (2*seqfrac))), 180 | # nn.LeakyReLU(negative_slope=0.01, inplace=True), 181 | # nn.BatchNorm1d(int(ndim / (2*seqfrac)), eps=1e-2, affine=True), 182 | # ZeroFC(int(ndim / (2*seqfrac)), 2*(ndim // 2) if self.affine else ndim // 2), 183 | # ) 184 | 185 
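# Note on ZeroFC above: fc.weight, fc.bias and scale are all zero-initialized, so ZeroFC
# returns exactly zero for any input at initialization. The affine coupling that uses it
# therefore starts as the identity map (log_s = tanh(0) = 0, t = 0) with zero
# log-determinant, a zero-initialization trick also used in Glow. A minimal check, kept
# commented out:
# zfc = ZeroFC(16, 8)
# assert torch.all(zfc(torch.randn(4, 16)) == 0)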
| if batch_norm: 186 | self.net = nn.Sequential( 187 | nn.Linear(ndim-ndim//2+ndim_cond, int(ndim / (2*seqfrac))), 188 | nn.LeakyReLU(negative_slope=0.01, inplace=True), 189 | nn.BatchNorm1d(int(ndim / (2*seqfrac)), eps=1e-2, affine=True), 190 | nn.Linear(int(ndim / (2*seqfrac)), int(ndim / (2*seqfrac))), 191 | nn.LeakyReLU(negative_slope=0.01, inplace=True), 192 | nn.BatchNorm1d(int(ndim / (2*seqfrac)), eps=1e-2, affine=True), 193 | ZeroFC(int(ndim / (2*seqfrac)), 2*(ndim // 2) if self.affine else ndim // 2), 194 | ) 195 | 196 | self.net[0].weight.data.normal_(0, 0.05) 197 | self.net[0].bias.data.zero_() 198 | 199 | self.net[3].weight.data.normal_(0, 0.05) 200 | self.net[3].bias.data.zero_() 201 | # self.net[2].weight.data.normal_(0, 0.05) 202 | # self.net[2].bias.data.zero_() 203 | 204 | else: 205 | self.net = nn.Sequential( 206 | nn.Linear(ndim-ndim//2+ndim_cond, int(ndim / (2*seqfrac))), 207 | nn.LeakyReLU(negative_slope=0.01, inplace=True), 208 | nn.Linear(int(ndim / (2*seqfrac)), int(ndim / (2*seqfrac))), 209 | nn.LeakyReLU(negative_slope=0.01, inplace=True), 210 | ZeroFC(int(ndim / (2*seqfrac)), 2*(ndim // 2) if self.affine else ndim // 2), 211 | ) 212 | 213 | 214 | self.net[0].weight.data.normal_(0, 0.05) 215 | self.net[0].bias.data.zero_() 216 | 217 | self.net[2].weight.data.normal_(0, 0.05) 218 | self.net[2].bias.data.zero_() 219 | 220 | 221 | 222 | def forward(self, input, cond_input): 223 | in_a, in_b = input.chunk(2, 1) 224 | 225 | in_a_cond = torch.cat([in_a, cond_input], -1) 226 | 227 | if self.affine: 228 | log_s0, t = self.net(in_a_cond).chunk(2, 1) 229 | log_s = torch.tanh(log_s0) 230 | s = torch.exp(log_s) 231 | out_b = (in_b + t) * s 232 | 233 | logdet = torch.sum(log_s.view(input.shape[0], -1), 1) 234 | 235 | else: 236 | net_out = self.net(in_a_cond) 237 | out_b = in_b + net_out 238 | logdet = None 239 | 240 | return torch.cat([in_a, out_b], 1), logdet 241 | 242 | def reverse(self, output, cond_input): 243 | out_a, out_b = output.chunk(2, 1) 244 | 245 | out_a_cond = torch.cat([out_a, cond_input], -1) 246 | 247 | if self.affine: 248 | log_s0, t = self.net(out_a_cond).chunk(2, 1) 249 | log_s = torch.tanh(log_s0) 250 | s = torch.exp(log_s) 251 | in_b = out_b / s - t 252 | 253 | logdet = -torch.sum(log_s.view(output.shape[0], -1), 1) 254 | 255 | else: 256 | net_out = self.net(out_a_cond) 257 | in_b = out_b - net_out 258 | 259 | logdet = None 260 | 261 | return torch.cat([out_a, in_b], 1), logdet 262 | 263 | 264 | # class AffineCoupling(nn.Module): 265 | # def __init__(self, ndim, ndim_cond, seqfrac=4, affine=True, batch_norm=True): 266 | # super().__init__() 267 | 268 | # self.affine = affine 269 | # self.batch_norm = batch_norm 270 | 271 | # # self.net = nn.Sequential( 272 | # # nn.Linear(ndim-ndim//2, ndim // (2*seqfrac)), 273 | # # nn.LeakyReLU(negative_slope=0.01, inplace=True), 274 | # # nn.BatchNorm1d(ndim // (2*seqfrac)), 275 | # # nn.Linear(ndim // (2*seqfrac), ndim // (2*seqfrac)), 276 | # # nn.LeakyReLU(negative_slope=0.01, inplace=True), 277 | # # nn.BatchNorm1d(ndim // (2*seqfrac)), 278 | # # ZeroFC(ndim // (2*seqfrac), 2*(ndim // 2) if self.affine else ndim // 2), 279 | # # ) 280 | 281 | # # self.net = nn.Sequential( 282 | # # nn.Linear(ndim-ndim//2, int(ndim / (2*seqfrac))), 283 | # # nn.LeakyReLU(negative_slope=0.01, inplace=True), 284 | # # nn.LayerNorm(int(ndim / (2*seqfrac))), 285 | # # nn.Linear(int(ndim / (2*seqfrac)), int(ndim / (2*seqfrac))), 286 | # # nn.LeakyReLU(negative_slope=0.01, inplace=True), 287 | # # nn.LayerNorm(int(ndim / (2*seqfrac))), 
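# A minimal sketch for the conditional AffineCoupling defined above: the conditioning
# vector is concatenated only to the unchanged half in_a before the scale/shift network,
# so the layer stays exactly invertible in `input` for any cond_input, as long as the
# same cond_input is passed to reverse(). Kept commented out:
# torch.manual_seed(0)
# coupling = AffineCoupling(ndim=8, ndim_cond=3, seqfrac=1, affine=True, batch_norm=False)
# for p in coupling.parameters():
#     p.data.normal_(0, 0.1)              # move away from the identity initialization
#
# x, c = torch.randn(16, 8), torch.randn(16, 3)
# y, logdet = coupling(x, c)
# x_rec, _ = coupling.reverse(y, c)       # same conditioning on the way back
# assert torch.allclose(x, x_rec, atol=1e-5)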
288 | # # ZeroFC(int(ndim / (2*seqfrac)), 2*(ndim // 2) if self.affine else ndim // 2), 289 | # # ) 290 | 291 | # # older version has skip connection, but we find that not necessary 292 | # # self.net = nn.Sequential( 293 | # # nn.Linear(ndim_cond+ndim-ndim//2, int(ndim / (2*seqfrac))), 294 | # # nn.LeakyReLU(negative_slope=0.01, inplace=True), 295 | # # nn.BatchNorm1d(int(ndim / (2*seqfrac)), eps=1e-2, affine=True), 296 | # # nn.Linear(int(ndim / (2*seqfrac)), int(ndim / (2*seqfrac))), 297 | # # nn.LeakyReLU(negative_slope=0.01, inplace=True), 298 | # # nn.BatchNorm1d(int(ndim / (2*seqfrac)), eps=1e-2, affine=True), 299 | # # ZeroFC(int(ndim / (2*seqfrac)), 2*(ndim // 2) if self.affine else ndim // 2), 300 | # # ) 301 | 302 | # if batch_norm: 303 | # self.net = nn.Sequential( 304 | # nn.Linear(ndim, int(ndim / (2*seqfrac))), 305 | # nn.LeakyReLU(negative_slope=0.01, inplace=True), 306 | # nn.BatchNorm1d(int(ndim / (2*seqfrac)), eps=1e-2, affine=True), 307 | # nn.Linear(int(ndim / (2*seqfrac)), int(ndim / (2*seqfrac))), 308 | # nn.LeakyReLU(negative_slope=0.01, inplace=True), 309 | # nn.BatchNorm1d(int(ndim / (2*seqfrac)), eps=1e-2, affine=True), 310 | # ZeroFC(int(ndim / (2*seqfrac)), 2*(ndim // 2) if self.affine else ndim // 2), 311 | # ) 312 | 313 | # self.net2 = nn.Sequential(nn.Linear(ndim_cond, ndim//2)) 314 | 315 | # self.net[0].weight.data.normal_(0, 0.05) 316 | # self.net[0].bias.data.zero_() 317 | 318 | # self.net[3].weight.data.normal_(0, 0.05) 319 | # self.net[3].bias.data.zero_() 320 | # # self.net[2].weight.data.normal_(0, 0.05) 321 | # # self.net[2].bias.data.zero_() 322 | 323 | # self.net2[0].weight.data.normal_(0, 0.05) 324 | # self.net2[0].bias.data.zero_() 325 | # else: 326 | # self.net = nn.Sequential( 327 | # nn.Linear(ndim, int(ndim / (2*seqfrac))), 328 | # nn.LeakyReLU(negative_slope=0.01, inplace=True), 329 | # nn.Linear(int(ndim / (2*seqfrac)), int(ndim / (2*seqfrac))), 330 | # nn.LeakyReLU(negative_slope=0.01, inplace=True), 331 | # ZeroFC(int(ndim / (2*seqfrac)), 2*(ndim // 2) if self.affine else ndim // 2), 332 | # ) 333 | 334 | # self.net2 = nn.Sequential(nn.Linear(ndim_cond, ndim//2)) 335 | 336 | # self.net[0].weight.data.normal_(0, 0.05) 337 | # self.net[0].bias.data.zero_() 338 | 339 | # self.net[2].weight.data.normal_(0, 0.05) 340 | # self.net[2].bias.data.zero_() 341 | 342 | # self.net2[0].weight.data.normal_(0, 0.05) 343 | # self.net2[0].bias.data.zero_() 344 | 345 | # def forward(self, input, cond_input): 346 | # in_a, in_b = input.chunk(2, 1) 347 | 348 | # cond_input = self.net2(cond_input) 349 | # in_a_cond = torch.cat([in_a, cond_input], -1) 350 | 351 | # if self.affine: 352 | # log_s0, t = self.net(in_a_cond).chunk(2, 1) 353 | # log_s = torch.tanh(log_s0) 354 | # s = torch.exp(log_s) 355 | # out_b = (in_b + t) * s 356 | 357 | # logdet = torch.sum(log_s.view(input.shape[0], -1), 1) 358 | 359 | # else: 360 | # net_out = self.net(in_a_cond) 361 | # out_b = in_b + net_out 362 | # logdet = None 363 | 364 | # return torch.cat([in_a, out_b], 1), logdet 365 | 366 | # def reverse(self, output, cond_input): 367 | # out_a, out_b = output.chunk(2, 1) 368 | 369 | # cond_input = self.net2(cond_input) 370 | # out_a_cond = torch.cat([out_a, cond_input], -1) 371 | 372 | # if self.affine: 373 | # log_s0, t = self.net(out_a_cond).chunk(2, 1) 374 | # log_s = torch.tanh(log_s0) 375 | # s = torch.exp(log_s) 376 | # in_b = out_b / s - t 377 | 378 | # logdet = -torch.sum(log_s.view(output.shape[0], -1), 1) 379 | 380 | # else: 381 | # net_out = self.net(out_a_cond) 382 | # 
in_b = out_b - net_out 383 | 384 | # logdet = None 385 | 386 | # return torch.cat([out_a, in_b], 1), logdet 387 | 388 | 389 | 390 | # class AffineCoupling(nn.Module): 391 | # def __init__(self, ndim, ndim_cond, seqfrac=4, affine=True): 392 | # super().__init__() 393 | 394 | # self.affine = affine 395 | 396 | # # self.net = nn.Sequential( 397 | # # nn.Linear(ndim-ndim//2, ndim // (2*seqfrac)), 398 | # # nn.LeakyReLU(negative_slope=0.01, inplace=True), 399 | # # nn.BatchNorm1d(ndim // (2*seqfrac)), 400 | # # nn.Linear(ndim // (2*seqfrac), ndim // (2*seqfrac)), 401 | # # nn.LeakyReLU(negative_slope=0.01, inplace=True), 402 | # # nn.BatchNorm1d(ndim // (2*seqfrac)), 403 | # # ZeroFC(ndim // (2*seqfrac), 2*(ndim // 2) if self.affine else ndim // 2), 404 | # # ) 405 | 406 | # # self.net = nn.Sequential( 407 | # # nn.Linear(ndim-ndim//2, int(ndim / (2*seqfrac))), 408 | # # nn.LeakyReLU(negative_slope=0.01, inplace=True), 409 | # # nn.LayerNorm(int(ndim / (2*seqfrac))), 410 | # # nn.Linear(int(ndim / (2*seqfrac)), int(ndim / (2*seqfrac))), 411 | # # nn.LeakyReLU(negative_slope=0.01, inplace=True), 412 | # # nn.LayerNorm(int(ndim / (2*seqfrac))), 413 | # # ZeroFC(int(ndim / (2*seqfrac)), 2*(ndim // 2) if self.affine else ndim // 2), 414 | # # ) 415 | 416 | # # older version has skip connection, but we find that not necessary 417 | # self.net = nn.Sequential( 418 | # nn.Linear(ndim-ndim//2, int(ndim / (2*seqfrac))), 419 | # nn.LeakyReLU(negative_slope=0.01, inplace=True), 420 | # nn.BatchNorm1d(int(ndim / (2*seqfrac)), eps=1e-2, affine=True), 421 | # nn.Linear(int(ndim / (2*seqfrac)), int(ndim / (2*seqfrac))), 422 | # nn.LeakyReLU(negative_slope=0.01, inplace=True), 423 | # nn.BatchNorm1d(int(ndim / (2*seqfrac)), eps=1e-2, affine=True), 424 | # ZeroFC(int(ndim / (2*seqfrac)), 2*(ndim // 2) if self.affine else ndim // 2), 425 | # ) 426 | 427 | # self.net2 = nn.Sequential( 428 | # nn.Linear(ndim_cond, int(ndim / (2*seqfrac))), 429 | # nn.LeakyReLU(negative_slope=0.01, inplace=True), 430 | # # nn.BatchNorm1d(int(ndim / (2*seqfrac)), eps=1e-2, affine=True), 431 | # nn.Linear(int(ndim / (2*seqfrac)), int(ndim / (2*seqfrac))), 432 | # nn.LeakyReLU(negative_slope=0.01, inplace=True), 433 | # # nn.BatchNorm1d(int(ndim / (2*seqfrac)), eps=1e-2, affine=True), 434 | # ZeroFC(int(ndim / (2*seqfrac)), 2*(ndim // 2) if self.affine else ndim // 2), 435 | # ) 436 | 437 | 438 | # self.net[0].weight.data.normal_(0, 0.05) 439 | # self.net[0].bias.data.zero_() 440 | 441 | # self.net[3].weight.data.normal_(0, 0.05) 442 | # self.net[3].bias.data.zero_() 443 | 444 | # self.net2[0].weight.data.normal_(0, 0.05) 445 | # self.net2[0].bias.data.zero_() 446 | 447 | # self.net2[2].weight.data.normal_(0, 0.05) 448 | # self.net2[2].bias.data.zero_() 449 | # # self.net[2].weight.data.normal_(0, 0.05) 450 | # # self.net[2].bias.data.zero_() 451 | 452 | # def forward(self, input, cond_input): 453 | # in_a, in_b = input.chunk(2, 1) 454 | 455 | # if self.affine: 456 | # log_s0, t0 = self.net(in_a).chunk(2, 1) 457 | # log_s1, t1 = self.net2(cond_input).chunk(2, 1) 458 | # log_s = torch.tanh(log_s0+log_s1) 459 | # s = torch.exp(log_s) 460 | # t = t0 + t1 461 | # out_b = (in_b + t) * s 462 | 463 | # logdet = torch.sum(log_s.view(input.shape[0], -1), 1) 464 | 465 | # else: 466 | # net_out0 = self.net(in_a) 467 | # net_out1 = self.net2(cond_input) 468 | # net_out = net_out0 + net_out1 469 | # out_b = in_b + net_out 470 | # logdet = None 471 | 472 | # return torch.cat([in_a, out_b], 1), logdet 473 | 474 | # def reverse(self, output, 
cond_input): 475 | # out_a, out_b = output.chunk(2, 1) 476 | 477 | # if self.affine: 478 | # log_s0, t0 = self.net(out_a).chunk(2, 1) 479 | # log_s1, t1 = self.net2(cond_input).chunk(2, 1) 480 | # log_s = torch.tanh(log_s0+log_s1) 481 | # s = torch.exp(log_s) 482 | # t = t0 + t1 483 | # in_b = out_b / s - t 484 | 485 | # logdet = -torch.sum(log_s.view(output.shape[0], -1), 1) 486 | 487 | # else: 488 | # net_out0 = self.net(out_a) 489 | # net_out1 = self.net2(cond_input) 490 | # net_out = net_out0 + net_out1 491 | # in_b = out_b - net_out 492 | 493 | # logdet = None 494 | 495 | # return torch.cat([out_a, in_b], 1), logdet 496 | 497 | 498 | 499 | # class Flow(nn.Module): 500 | # def __init__(self, ndim, ndim_cond, affine=True, seqfrac=4, batch_norm=True): 501 | # super().__init__() 502 | 503 | 504 | # self.coupling = AffineCoupling(ndim, ndim_cond, seqfrac=seqfrac, affine=affine, batch_norm=batch_norm) 505 | # self.coupling2 = AffineCoupling(ndim, ndim_cond, seqfrac=seqfrac, affine=affine, batch_norm=batch_norm) 506 | 507 | # self.ndim = ndim 508 | 509 | # def forward(self, input, cond_input): 510 | # logdet = 0 511 | # out, det2 = self.coupling(input, cond_input) 512 | # out = out[:, np.arange(self.ndim-1, -1, -1)] 513 | # out, det4 = self.coupling2(out, cond_input) 514 | # out = out[:, np.arange(self.ndim-1, -1, -1)] 515 | 516 | # if det2 is not None: 517 | # logdet = logdet + det2 518 | # if det4 is not None: 519 | # logdet = logdet + det4 520 | 521 | # return out, logdet 522 | 523 | # def reverse(self, output, cond_input): 524 | # logdet = 0 525 | # input = output[:, np.arange(self.ndim-1, -1, -1)] 526 | # input, det1 = self.coupling2.reverse(input, cond_input) 527 | # input = input[:, np.arange(self.ndim-1, -1, -1)] 528 | # input, det3 = self.coupling.reverse(input, cond_input) 529 | 530 | 531 | # if det1 is not None: 532 | # logdet = logdet + det1 533 | # if det3 is not None: 534 | # logdet = logdet + det3 535 | 536 | # return input, logdet 537 | 538 | 539 | 540 | class Flow(nn.Module): 541 | def __init__(self, ndim, ndim_cond, affine=True, seqfrac=4, batch_norm=True): 542 | super().__init__() 543 | 544 | self.actnorm = ActNorm() 545 | self.actnorm2 = ActNorm() 546 | 547 | self.coupling = AffineCoupling(ndim, ndim_cond, seqfrac=seqfrac, affine=affine, batch_norm=batch_norm) 548 | self.coupling2 = AffineCoupling(ndim, ndim_cond, seqfrac=seqfrac, affine=affine, batch_norm=batch_norm) 549 | 550 | self.ndim = ndim 551 | 552 | def forward(self, input, cond_input): 553 | logdet = 0 554 | out, det1 = self.actnorm(input) 555 | out, det2 = self.coupling(out, cond_input) 556 | out = out[:, np.arange(self.ndim-1, -1, -1)] 557 | out, det3 = self.actnorm2(out) 558 | out, det4 = self.coupling2(out, cond_input) 559 | out = out[:, np.arange(self.ndim-1, -1, -1)] 560 | 561 | logdet = logdet + det1 562 | if det2 is not None: 563 | logdet = logdet + det2 564 | logdet = logdet + det3 565 | if det4 is not None: 566 | logdet = logdet + det4 567 | 568 | return out, logdet 569 | 570 | def reverse(self, output, cond_input): 571 | logdet = 0 572 | input = output[:, np.arange(self.ndim-1, -1, -1)] 573 | input, det1 = self.coupling2.reverse(input, cond_input) 574 | input, det2 = self.actnorm2.reverse(input) 575 | input = input[:, np.arange(self.ndim-1, -1, -1)] 576 | input, det3 = self.coupling.reverse(input, cond_input) 577 | input, det4 = self.actnorm.reverse(input) 578 | 579 | 580 | if det1 is not None: 581 | logdet = logdet + det1 582 | logdet = logdet + det2 583 | if det3 is not None: 584 | logdet = logdet + 
det3 585 | logdet = logdet + det4 586 | 587 | return input, logdet 588 | 589 | 590 | # class Flow(nn.Module): 591 | # def __init__(self, ndim, affine=True, seqfrac=4): 592 | # super().__init__() 593 | 594 | # self.actnorm = ActNorm() 595 | # self.actnorm2 = ActNorm() 596 | 597 | # self.coupling = AffineCoupling(ndim, seqfrac=seqfrac, affine=affine) 598 | # self.coupling2 = AffineCoupling(ndim, seqfrac=seqfrac, affine=affine) 599 | 600 | # self.ndim = ndim 601 | 602 | # def forward(self, input): 603 | # logdet = 0 604 | # out, det2 = self.coupling(input) 605 | # out = out[:, np.arange(self.ndim-1, -1, -1)] 606 | # out, det4 = self.coupling2(out) 607 | # out = out[:, np.arange(self.ndim-1, -1, -1)] 608 | 609 | # if det2 is not None: 610 | # logdet = logdet + det2 611 | # if det4 is not None: 612 | # logdet = logdet + det4 613 | 614 | # return out, logdet 615 | 616 | # def reverse(self, output): 617 | # logdet = 0 618 | # input = output[:, np.arange(self.ndim-1, -1, -1)] 619 | # input, det1 = self.coupling2.reverse(input) 620 | # input = input[:, np.arange(self.ndim-1, -1, -1)] 621 | # input, det3 = self.coupling.reverse(input) 622 | 623 | 624 | # if det1 is not None: 625 | # logdet = logdet + det1 626 | # if det3 is not None: 627 | # logdet = logdet + det3 628 | 629 | # return input, logdet 630 | 631 | 632 | def Order_inverse(order): 633 | order_inv = [] 634 | for k in range(len(order)): 635 | for i in range(len(order)): 636 | if order[i] == k: 637 | order_inv.append(i) 638 | return np.array(order_inv) 639 | 640 | 641 | class RealNVP(nn.Module): 642 | def __init__( 643 | self, ndim, ndim_cond, n_flow, affine=True, seqfrac=4, permute='random', batch_norm=True 644 | ): 645 | super().__init__() 646 | self.blocks = nn.ModuleList() 647 | self.orders = [] 648 | for i in range(n_flow): 649 | self.blocks.append(Flow(ndim, ndim_cond, affine=affine, seqfrac=seqfrac, batch_norm=batch_norm)) 650 | if permute == 'random': 651 | self.orders.append(np.random.RandomState(seed=i).permutation(ndim)) 652 | elif permute == 'reverse': 653 | self.orders.append(np.arange(ndim-1, -1, -1)) 654 | else: 655 | print('We can only do no permutation, random permutation or reverse permutation in affine coupling layer. 
Using no permutation by default!') 656 | self.orders.append(np.arange(ndim)) 657 | 658 | self.inverse_orders = [] 659 | for i in range(n_flow): 660 | self.inverse_orders.append(Order_inverse(self.orders[i])) 661 | 662 | def forward(self, input, cond_input): 663 | logdet = 0 664 | out = input 665 | 666 | for i in range(len(self.blocks)): 667 | out, det = self.blocks[i](out, cond_input) 668 | logdet = logdet + det 669 | out = out[:, self.orders[i]] 670 | 671 | return out, logdet 672 | 673 | def reverse(self, out, cond_input): 674 | logdet = 0 675 | input = out 676 | 677 | for i in range(len(self.blocks)-1, -1, -1): 678 | input = input[:, self.inverse_orders[i]] 679 | input, det = self.blocks[i].reverse(input, cond_input) 680 | logdet = logdet + det 681 | 682 | return input, logdet -------------------------------------------------------------------------------- /DPItorch/orbit_helpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.init as init 8 | import torch.nn.functional as functional 9 | 10 | torch.set_default_dtype(torch.float32) 11 | import torch.optim as optim 12 | import pickle 13 | import math 14 | 15 | from torchkbnufft import KbNufft, AdjKbNufft 16 | from torchkbnufft.mri.dcomp_calc import calculate_radial_dcomp_pytorch 17 | from torchkbnufft.math import absolute 18 | 19 | from generative_model import realnvpfc_model 20 | 21 | import astropy.units as u 22 | import astropy.constants as consts 23 | import warnings 24 | 25 | import corner 26 | plt.ion() 27 | 28 | def calc_orbit_torch(epochs, sma, ecc, inc, aop, pan, tau, plx, mtot, mass_for_Kamp=None, tau_ref_epoch=58849, tolerance=1e-9, max_iter=10): 29 | 30 | n_orbs = sma.shape[0] # num sets of input orbital parameters 31 | n_dates = epochs.shape[0] # number of dates to compute offsets and vz 32 | 33 | # return planetary RV if `mass_for_Kamp` is not defined 34 | if mass_for_Kamp is None: 35 | mass_for_Kamp = 1.0 * mtot 36 | 37 | # Necessary for _calc_ecc_anom, for now 38 | ecc_arr = torch.matmul(torch.ones_like(epochs).unsqueeze(-1), ecc.unsqueeze(0)) 39 | 40 | # Compute period (from Kepler's third law) and mean motion 41 | period_const = np.sqrt(4*np.pi**2.0*(u.AU)**3/(consts.G*(u.Msun))) 42 | period_const = period_const.to(u.day).value 43 | period = torch.sqrt(sma**3/mtot) * period_const 44 | mean_motion = 2*np.pi/(period) 45 | 46 | # # compute mean anomaly (size: n_orbs x n_dates) 47 | manom = (mean_motion*(epochs.unsqueeze(-1) - tau_ref_epoch) - 2*np.pi*tau) % (2.0*np.pi) 48 | 49 | 50 | # compute eccentric anomalies (size: n_orbs x n_dates) 51 | # eanom = _calc_ecc_anom_torch(manom.numpy(), ecc_arr.numpy(), tolerance=tolerance, max_iter=max_iter) 52 | eanom = _calc_ecc_anom_torch(manom, ecc_arr, tolerance=tolerance, max_iter=max_iter) 53 | # compute the true anomalies (size: n_orbs x n_dates) 54 | # Note: matrix multiplication makes the shapes work out here and below 55 | tanom = 2.*torch.atan(torch.sqrt((1.0 + ecc)/(1.0 - ecc))*torch.tan(0.5*eanom)) 56 | # compute 3-D orbital radius of second body (size: n_orbs x n_dates) 57 | 58 | radius = sma * (1.0 - ecc * torch.cos(eanom)) 59 | 60 | # compute ra/dec offsets (size: n_orbs x n_dates) 61 | # math from James Graham. 
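# A minimal sketch for the conditional RealNVP defined in cond_realnvpfc_model.py above.
# Every flow block receives the same conditioning vector, so a single network can
# represent a family of distributions indexed by cond_input (e.g. a posterior conditioned
# on an observation); how DPI actually trains this is defined in the DPI_*.py scripts.
# The import below mirrors the `from generative_model import realnvpfc_model` import used
# in this file and is an assumption. Kept commented out:
# import torch
# from generative_model import cond_realnvpfc_model
#
# torch.manual_seed(0)
# ndim, ndim_cond, n_flow = 8, 4, 4
# flow = cond_realnvpfc_model.RealNVP(ndim, ndim_cond, n_flow, affine=True, seqfrac=1,
#                                     permute='random', batch_norm=False)
#
# x = torch.randn(16, ndim)
# c = torch.randn(16, ndim_cond)           # one conditioning vector per sample
# z, logdet = flow(x, c)
# x_rec, logdet_rev = flow.reverse(z, c)   # invert with the same conditioning
# assert torch.allclose(x, x_rec, atol=1e-4)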
Lots of trig 62 | c2i2 = torch.cos(0.5*inc)**2 63 | s2i2 = torch.sin(0.5*inc)**2 64 | arg1 = tanom + aop + pan 65 | arg2 = tanom + aop - pan 66 | c1 = torch.cos(arg1) 67 | c2 = torch.cos(arg2) 68 | s1 = torch.sin(arg1) 69 | s2 = torch.sin(arg2) 70 | 71 | # updated sign convention for Green Eq. 19.4-19.7 72 | raoff = radius * (c2i2*s1 - s2i2*s2) * plx 73 | deoff = radius * (c2i2*c1 + s2i2*c2) * plx 74 | 75 | # compute the radial velocity (vz) of the body (size: n_orbs x n_dates) 76 | # first comptue the RV semi-amplitude (size: n_orbs x n_dates) 77 | Kv_const = np.sqrt(consts.G) * (1.0 * u.Msun) / np.sqrt(1.0 * u.Msun * u.au) 78 | Kv_const = Kv_const.to(u.km/u.s).value 79 | Kv = mass_for_Kamp * torch.sqrt(1.0 / ((1.0 - ecc**2) * mtot * sma)) * torch.sin(inc) * Kv_const 80 | 81 | # compute the vz 82 | vz = Kv * (ecc*torch.cos(aop) + torch.cos(aop + tanom)) 83 | return raoff, deoff, vz 84 | 85 | def _calc_ecc_anom_torch(manom, ecc, tolerance=1e-9, max_iter=100): 86 | # Initialize eanom array 87 | eanom = np.full(np.shape(manom), np.nan) 88 | 89 | # Save some boolean arrays 90 | ecc_zero = ecc == 0.0 91 | ecc_high = ecc >= 0.95 92 | # ecc_high = ecc >= 0.9 93 | 94 | # solve kepler equation using newton solver 95 | eanom, abs_diff = _newton_solver_torch( 96 | manom, ecc, tolerance=tolerance, max_iter=max_iter) 97 | 98 | # solve kepler equation using analytical method 99 | eanom_mikkola = _mikkola_solver_wrapper_torch(manom, ecc) 100 | 101 | # # solve kepler equation using newton solver 102 | # eanom, abs_diff = _newton_solver_torch( 103 | # manom, ecc, tolerance=tolerance, max_iter=max_iter, eanom0=eanom_mikkola) 104 | 105 | # use analytical solution when the newton solver is not accurate enough 106 | eanom = torch.where(abs_diff > tolerance, eanom_mikkola, eanom) 107 | # use analytical solver when the ecc is high 108 | eanom = torch.where(ecc_high, eanom_mikkola, eanom) 109 | # use manom when ecc is zero 110 | eanom = torch.where(ecc_zero, manom, eanom) 111 | 112 | # eanom = torch.where(ecc_zero, manom, eanom_mikkola) 113 | 114 | return eanom 115 | 116 | def _newton_solver_torch(manom, ecc, tolerance=1e-9, max_iter=10, eanom0=None): 117 | """ 118 | Newton-Raphson solver for eccentric anomaly. 
119 | Args: 120 | manom (np.array): array of mean anomalies 121 | ecc (np.array): array of eccentricities 122 | eanom0 (np.array): array of first guess for eccentric anomaly, same shape as manom (optional) 123 | Return: 124 | eanom (np.array): array of eccentric anomalies 125 | Written: Rob De Rosa, 2018 126 | """ 127 | 128 | # Initialize at E=M, E=pi is better at very high eccentricities 129 | if eanom0 is None: 130 | eanom = 1.0 * manom#torch.clone(manom) 131 | else: 132 | eanom = 1.0 * eanom0#torch.clone(eanom0) 133 | 134 | # Let's do one iteration to start with 135 | eanom = eanom - (eanom - (ecc * torch.sin(eanom)) - manom) / (1.0 - (ecc * torch.cos(eanom))) 136 | 137 | diff = (eanom - (ecc * torch.sin(eanom)) - manom) / (1.0 - (ecc * torch.cos(eanom))) 138 | abs_diff = torch.abs(diff) 139 | # ind = torch.where(abs_diff > tolerance) 140 | niter = 0 141 | while niter <= max_iter: 142 | diff = (eanom - (ecc * torch.sin(eanom)) - manom) / (1.0 - (ecc * torch.cos(eanom))) 143 | eanom = eanom - diff 144 | niter += 1 145 | 146 | diff = (eanom - (ecc * torch.sin(eanom)) - manom) / (1.0 - (ecc * torch.cos(eanom))) 147 | abs_diff = torch.abs(diff) 148 | 149 | # eanom_mikkola = _mikkola_solver_wrapper_torch(manom, ecc) 150 | # eanom = torch.where(abs_diff > tolerance, eanom_mikkola, eanom) 151 | 152 | return eanom, abs_diff 153 | 154 | 155 | 156 | 157 | def _mikkola_solver_wrapper_torch(manom, ecc): 158 | """ 159 | Analytical Mikkola solver (S. Mikkola. 1987. Celestial Mechanics, 40, 329-334.) for the eccentric anomaly. 160 | Wrapper for the Python implementation of the IDL version. From Rob De Rosa. 161 | Args: 162 | manom (np.array): array of mean anomalies between 0 and 2pi 163 | ecc (np.array): eccentricity 164 | Return: 165 | eanom (np.array): array of eccentric anomalies 166 | Written: Jason Wang, 2018 167 | """ 168 | 169 | eanom1 = _mikkola_solver_torch(manom, ecc) 170 | manom2 = (2.0 * np.pi) - manom 171 | eanom2 = _mikkola_solver_torch(manom2, ecc) 172 | eanom2 = (2.0 * np.pi) - eanom2 173 | 174 | eanom = torch.where(manom > np.pi, eanom2, eanom1) 175 | 176 | return eanom 177 | 178 | 179 | def _mikkola_solver_torch(manom, ecc): 180 | """ 181 | Analytical Mikkola solver for the eccentric anomaly. 182 | Adapted from IDL routine keplereq.pro by Rob De Rosa http://www.lpl.arizona.edu/~bjackson/idl_code/keplereq.pro 183 | Args: 184 | manom (float or np.array): mean anomaly, must be between 0 and pi.
185 | ecc (float or np.array): eccentricity 186 | Return: 187 | eanom (np.array): array of eccentric anomalies 188 | Written: Jason Wang, 2018 189 | """ 190 | 191 | alpha = (1.0 - ecc) / ((4.0 * ecc) + 0.5) 192 | beta = (0.5 * manom) / ((4.0 * ecc) + 0.5) 193 | 194 | aux = torch.sqrt(beta**2.0 + alpha**3.0) 195 | z = torch.abs(beta + aux)**(1.0/3.0) 196 | 197 | s0 = z - (alpha/z) 198 | s1 = s0 - (0.078*(s0**5.0)) / (1.0 + ecc) 199 | e0 = manom + (ecc * (3.0*s1 - 4.0*(s1**3.0))) 200 | 201 | se0 = torch.sin(e0) 202 | ce0 = torch.cos(e0) 203 | 204 | f = e0-ecc*se0-manom 205 | f1 = 1.0-ecc*ce0 206 | f2 = ecc*se0 207 | f3 = ecc*ce0 208 | f4 = -f2 209 | u1 = -f/f1 210 | u2 = -f/(f1+0.5*f2*u1) 211 | u3 = -f/(f1+0.5*f2*u2+(1.0/6.0)*f3*u2*u2) 212 | u4 = -f/(f1+0.5*f2*u3+(1.0/6.0)*f3*u3*u3+(1.0/24.0)*f4*(u3**3.0)) 213 | 214 | return (e0 + u4) 215 | 216 | 217 | 218 | 219 | # ############################################################################ 220 | # ## original orbitize 221 | # ############################################################################ 222 | # """ 223 | # This module solves for the orbit of the planet given Keplerian parameters. 224 | # """ 225 | # import numpy as np 226 | # import astropy.units as u 227 | # import astropy.constants as consts 228 | # import warnings # to be removed after tau_ref_epoch warning is removed. 229 | 230 | # try: 231 | # from . import _kepler 232 | # cext = True 233 | # except ImportError: 234 | # print("WARNING: KEPLER: Unable to import C-based Kepler's \ 235 | # equation solver. Falling back to the slower NumPy implementation.") 236 | # cext = False 237 | 238 | 239 | # def calc_orbit(epochs, sma, ecc, inc, aop, pan, tau, plx, mtot, mass_for_Kamp=None, tau_ref_epoch=58849, tolerance=1e-9, max_iter=100, tau_warning=True): 240 | # """ 241 | # Returns the separation and radial velocity of the body given array of 242 | # orbital parameters (size n_orbs) at given epochs (array of size n_dates) 243 | # Based on orbit solvers from James Graham and Rob De Rosa. Adapted by Jason Wang and Henry Ngo. 244 | # Args: 245 | # epochs (np.array): MJD times for which we want the positions of the planet 246 | # sma (np.array): semi-major axis of orbit [au] 247 | # ecc (np.array): eccentricity of the orbit [0,1] 248 | # inc (np.array): inclination [radians] 249 | # aop (np.array): argument of periastron [radians] 250 | # pan (np.array): longitude of the ascending node [radians] 251 | # tau (np.array): epoch of periastron passage in fraction of orbital period past MJD=0 [0,1] 252 | # plx (np.array): parallax [mas] 253 | # mtot (np.array): total mass of the two-body orbit (M_* + M_planet) [Solar masses] 254 | # mass_for_Kamp (np.array, optional): mass of the body that causes the RV signal. 255 | # For example, if you want to return the stellar RV, this is the planet mass. 256 | # If you want to return the planetary RV, this is the stellar mass. [Solar masses]. 257 | # For planet mass ~ 0, mass_for_Kamp ~ M_tot, and function returns planetary RV (default). 258 | # tau_ref_epoch (float, optional): reference date that tau is defined with respect to (i.e., tau=0) 259 | # tolerance (float, optional): absolute tolerance of iterative computation. Defaults to 1e-9. 260 | # max_iter (int, optional): maximum number of iterations before switching. Defaults to 100. 261 | # tau_warning (bool, optional, depricating): temporary argument to warn users about tau_ref_epoch default value change. 
262 | # Users that are calling this function themsleves should receive a warning since default is True. 263 | # To be removed when tau_ref_epoch change is fully propogated to users. Users can turn it off to stop getting the warning. 264 | # Return: 265 | # 3-tuple: 266 | # raoff (np.array): array-like (n_dates x n_orbs) of RA offsets between the bodies 267 | # (origin is at the other body) [mas] 268 | # deoff (np.array): array-like (n_dates x n_orbs) of Dec offsets between the bodies [mas] 269 | # vz (np.array): array-like (n_dates x n_orbs) of radial velocity of one of the bodies 270 | # (see `mass_for_Kamp` description) [km/s] 271 | # Written: Jason Wang, Henry Ngo, 2018 272 | # """ 273 | # if tau_warning: 274 | # warnings.warn("tau_ref_epoch default for kepler.calc_orbit is 58849 now instead of 0 MJD. " 275 | # "Please check that this does not break your code. You can turn off this warning by setting " 276 | # "tau_warning=False when you call kepler.calc_orbit.") 277 | 278 | # n_orbs = np.size(sma) # num sets of input orbital parameters 279 | # n_dates = np.size(epochs) # number of dates to compute offsets and vz 280 | 281 | # # return planetary RV if `mass_for_Kamp` is not defined 282 | # if mass_for_Kamp is None: 283 | # mass_for_Kamp = mtot 284 | 285 | # # Necessary for _calc_ecc_anom, for now 286 | # if np.isscalar(epochs): # just in case epochs is given as a scalar 287 | # epochs = np.array([epochs]) 288 | # ecc_arr = np.tile(ecc, (n_dates, 1)) 289 | 290 | # # Compute period (from Kepler's third law) and mean motion 291 | # period = np.sqrt(4*np.pi**2.0*(sma*u.AU)**3/(consts.G*(mtot*u.Msun))) 292 | # period = period.to(u.day).value 293 | # mean_motion = 2*np.pi/(period) # in rad/day 294 | 295 | # # # compute mean anomaly (size: n_orbs x n_dates) 296 | # manom = (mean_motion*(epochs[:, None] - tau_ref_epoch) - 2*np.pi*tau) % (2.0*np.pi) 297 | # # compute eccentric anomalies (size: n_orbs x n_dates) 298 | # eanom = _calc_ecc_anom(manom, ecc_arr, tolerance=tolerance, max_iter=max_iter) 299 | # # compute the true anomalies (size: n_orbs x n_dates) 300 | # # Note: matrix multiplication makes the shapes work out here and below 301 | # tanom = 2.*np.arctan(np.sqrt((1.0 + ecc)/(1.0 - ecc))*np.tan(0.5*eanom)) 302 | # # compute 3-D orbital radius of second body (size: n_orbs x n_dates) 303 | # radius = sma * (1.0 - ecc * np.cos(eanom)) 304 | 305 | # # compute ra/dec offsets (size: n_orbs x n_dates) 306 | # # math from James Graham. Lots of trig 307 | # c2i2 = np.cos(0.5*inc)**2 308 | # s2i2 = np.sin(0.5*inc)**2 309 | # arg1 = tanom + aop + pan 310 | # arg2 = tanom + aop - pan 311 | # c1 = np.cos(arg1) 312 | # c2 = np.cos(arg2) 313 | # s1 = np.sin(arg1) 314 | # s2 = np.sin(arg2) 315 | 316 | # # updated sign convention for Green Eq. 
19.4-19.7 317 | # raoff = radius * (c2i2*s1 - s2i2*s2) * plx 318 | # deoff = radius * (c2i2*c1 + s2i2*c2) * plx 319 | 320 | # # compute the radial velocity (vz) of the body (size: n_orbs x n_dates) 321 | # # first comptue the RV semi-amplitude (size: n_orbs x n_dates) 322 | # Kv = np.sqrt(consts.G / (1.0 - ecc**2)) * (mass_for_Kamp * u.Msun * 323 | # np.sin(inc)) / np.sqrt(mtot * u.Msun) / np.sqrt(sma * u.au) 324 | # # Convert to km/s 325 | # Kv = Kv.to(u.km/u.s) 326 | 327 | # # compute the vz 328 | # vz = Kv.value * (ecc*np.cos(aop) + np.cos(aop + tanom)) 329 | # # Squeeze out extra dimension (useful if n_orbs = 1, does nothing if n_orbs > 1) 330 | # vz = np.squeeze(vz)[()] 331 | # return raoff, deoff, vz 332 | 333 | # def _calc_ecc_anom(manom, ecc, tolerance=1e-9, max_iter=100, use_c=False): 334 | # """ 335 | # Computes the eccentric anomaly from the mean anomlay. 336 | # Code from Rob De Rosa's orbit solver (e < 0.95 use Newton, e >= 0.95 use Mikkola) 337 | # Args: 338 | # manom (float/np.array): mean anomaly, either a scalar or np.array of any shape 339 | # ecc (float/np.array): eccentricity, either a scalar or np.array of the same shape as manom 340 | # tolerance (float, optional): absolute tolerance of iterative computation. Defaults to 1e-9. 341 | # max_iter (int, optional): maximum number of iterations before switching. Defaults to 100. 342 | # Return: 343 | # eanom (float/np.array): eccentric anomalies, same shape as manom 344 | # Written: Jason Wang, 2018 345 | # """ 346 | 347 | # if np.isscalar(ecc) or (np.shape(manom) == np.shape(ecc)): 348 | # pass 349 | # else: 350 | # raise ValueError("ecc must be a scalar, or ecc.shape == manom.shape") 351 | 352 | # # If manom is a scalar, make it into a one-element array 353 | # if np.isscalar(manom): 354 | # manom = np.array((manom, )) 355 | 356 | # # If ecc is a scalar, make it the same shape as manom 357 | # if np.isscalar(ecc): 358 | # ecc = np.full(np.shape(manom), ecc) 359 | 360 | # # Initialize eanom array 361 | # eanom = np.full(np.shape(manom), np.nan) 362 | 363 | # # Save some boolean arrays 364 | # ecc_zero = ecc == 0.0 365 | # ecc_low = ecc < 0.95 366 | 367 | # # First deal with e == 0 elements 368 | # ind_zero = np.where(ecc_zero) 369 | # if len(ind_zero[0]) > 0: 370 | # eanom[ind_zero] = manom[ind_zero] 371 | 372 | # # Now low eccentricities 373 | # ind_low = np.where(~ecc_zero & ecc_low) 374 | # if cext and use_c: 375 | # if len(ind_low[0]) > 0: eanom[ind_low] = _kepler._c_newton_solver(manom[ind_low], ecc[ind_low], tolerance=tolerance, max_iter=max_iter) 376 | 377 | # # the C solver returns eanom = -1 if it doesnt converge after max_iter iterations 378 | # m_one = eanom == -1 379 | # ind_high = np.where(~ecc_zero & ~ecc_low | m_one) 380 | # else: 381 | # if len(ind_low[0]) > 0: 382 | # eanom[ind_low] = _newton_solver( 383 | # manom[ind_low], ecc[ind_low], tolerance=tolerance, max_iter=max_iter) 384 | # ind_high = np.where(~ecc_zero & ~ecc_low) 385 | 386 | # # Now high eccentricities 387 | # if len(ind_high[0]) > 0: 388 | # eanom[ind_high] = _mikkola_solver_wrapper(manom[ind_high], ecc[ind_high], use_c) 389 | 390 | # return np.squeeze(eanom)[()] 391 | 392 | 393 | # def _newton_solver(manom, ecc, tolerance=1e-9, max_iter=100, eanom0=None, use_c=False): 394 | # """ 395 | # Newton-Raphson solver for eccentric anomaly. 
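# A minimal sketch of calling the PyTorch orbit solver calc_orbit_torch defined at the top
# of this file (np and torch are already imported above). The epochs here are hypothetical
# and the orbital elements loosely mirror the commented beta Pic test block further below.
# Kept commented out, following the convention of that test block:
# epochs = torch.linspace(50000.0, 59000.0, 20)        # MJD
# sma  = torch.tensor([9.2]);   ecc = torch.tensor([0.05])
# inc  = torch.tensor([88.9 * np.pi / 180.0])
# aop  = torch.tensor([220.0 * np.pi / 180.0])
# pan  = torch.tensor([31.85 * np.pi / 180.0])
# tau  = torch.tensor([0.2]);   plx = torch.tensor([51.5])
# mtot = torch.tensor([1.8])
#
# raoff, deoff, vz = calc_orbit_torch(epochs, sma, ecc, inc, aop, pan, tau,
#                                     plx, mtot, tau_ref_epoch=50000, max_iter=10)
# # raoff, deoff: (n_dates, n_orbs) RA/Dec offsets [mas]; vz: radial velocity [km/s]
# sep = torch.sqrt(raoff**2 + deoff**2)
# pa  = torch.atan2(raoff, deoff)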
396 | # Args: 397 | # manom (np.array): array of mean anomalies 398 | # ecc (np.array): array of eccentricities 399 | # eanom0 (np.array): array of first guess for eccentric anomaly, same shape as manom (optional) 400 | # Return: 401 | # eanom (np.array): array of eccentric anomalies 402 | # Written: Rob De Rosa, 2018 403 | # """ 404 | # # Ensure manom and ecc are np.array (might get passed as astropy.Table Columns instead) 405 | # manom = np.array(manom) 406 | # ecc = np.array(ecc) 407 | 408 | # # Initialize at E=M, E=pi is better at very high eccentricities 409 | # if eanom0 is None: 410 | # eanom = np.copy(manom) 411 | # else: 412 | # eanom = np.copy(eanom0) 413 | 414 | # # Let's do one iteration to start with 415 | # eanom -= (eanom - (ecc * np.sin(eanom)) - manom) / (1.0 - (ecc * np.cos(eanom))) 416 | 417 | # diff = (eanom - (ecc * np.sin(eanom)) - manom) / (1.0 - (ecc * np.cos(eanom))) 418 | # abs_diff = np.abs(diff) 419 | # ind = np.where(abs_diff > tolerance) 420 | # niter = 0 421 | # while ((ind[0].size > 0) and (niter <= max_iter)): 422 | # eanom[ind] -= diff[ind] 423 | # # If it hasn't converged after half the iterations are done, try starting from pi 424 | # if niter == (max_iter//2): 425 | # eanom[ind] = np.pi 426 | # diff[ind] = (eanom[ind] - (ecc[ind] * np.sin(eanom[ind])) - manom[ind]) / \ 427 | # (1.0 - (ecc[ind] * np.cos(eanom[ind]))) 428 | # abs_diff[ind] = np.abs(diff[ind]) 429 | # ind = np.where(abs_diff > tolerance) 430 | # niter += 1 431 | 432 | # if niter >= max_iter: 433 | # print(manom[ind], eanom[ind], diff[ind], ecc[ind], '> {} iter.'.format(max_iter)) 434 | # eanom[ind] = _mikkola_solver_wrapper(manom[ind], ecc[ind], use_c) # Send remaining orbits to the analytical version, this has not happened yet... 435 | 436 | # return eanom 437 | 438 | # def _mikkola_solver_wrapper(manom, ecc, use_c): 439 | # """ 440 | # Analtyical Mikkola solver (S. Mikkola. 1987. Celestial Mechanics, 40, 329-334.) for the eccentric anomaly. 441 | # Wrapper for the python implemenation of the IDL version. From Rob De Rosa. 442 | # Args: 443 | # manom (np.array): array of mean anomalies between 0 and 2pi 444 | # ecc (np.array): eccentricity 445 | # Return: 446 | # eanom (np.array): array of eccentric anomalies 447 | # Written: Jason Wang, 2018 448 | # """ 449 | 450 | # ind_change = np.where(manom > np.pi) 451 | # manom[ind_change] = (2.0 * np.pi) - manom[ind_change] 452 | # if cext and use_c: 453 | # eanom = _kepler._c_mikkola_solver(manom, ecc) 454 | # else: 455 | # eanom = _mikkola_solver(manom, ecc) 456 | # eanom[ind_change] = (2.0 * np.pi) - eanom[ind_change] 457 | 458 | # return eanom 459 | 460 | 461 | # def _mikkola_solver(manom, ecc): 462 | # """ 463 | # Analtyical Mikkola solver for the eccentric anomaly. 464 | # Adapted from IDL routine keplereq.pro by Rob De Rosa http://www.lpl.arizona.edu/~bjackson/idl_code/keplereq.pro 465 | # Args: 466 | # manom (float or np.array): mean anomaly, must be between 0 and pi. 
467 | # ecc (float or np.array): eccentricity 468 | # Return: 469 | # eanom (np.array): array of eccentric anomalies 470 | # Written: Jason Wang, 2018 471 | # """ 472 | 473 | # alpha = (1.0 - ecc) / ((4.0 * ecc) + 0.5) 474 | # beta = (0.5 * manom) / ((4.0 * ecc) + 0.5) 475 | 476 | # aux = np.sqrt(beta**2.0 + alpha**3.0) 477 | # z = np.abs(beta + aux)**(1.0/3.0) 478 | 479 | # s0 = z - (alpha/z) 480 | # s1 = s0 - (0.078*(s0**5.0)) / (1.0 + ecc) 481 | # e0 = manom + (ecc * (3.0*s1 - 4.0*(s1**3.0))) 482 | 483 | # se0 = np.sin(e0) 484 | # ce0 = np.cos(e0) 485 | 486 | # f = e0-ecc*se0-manom 487 | # f1 = 1.0-ecc*ce0 488 | # f2 = ecc*se0 489 | # f3 = ecc*ce0 490 | # f4 = -f2 491 | # u1 = -f/f1 492 | # u2 = -f/(f1+0.5*f2*u1) 493 | # u3 = -f/(f1+0.5*f2*u2+(1.0/6.0)*f3*u2*u2) 494 | # u4 = -f/(f1+0.5*f2*u3+(1.0/6.0)*f3*u3*u3+(1.0/24.0)*f4*(u3**3.0)) 495 | 496 | # return (e0 + u4) 497 | 498 | 499 | # ################################################################################ 500 | # ## test the forward model implementation 501 | # ################################################################################ 502 | 503 | # import pandas as pd 504 | 505 | # astrometry_data = pd.read_csv('../dataset/orbital_fit/betapic_astrometry.csv') 506 | # # astrometry_data = astrometry_data[0:18] 507 | 508 | # raoff_true = np.array(astrometry_data['raoff'][0:18]) 509 | # raoff_err = np.array(astrometry_data['raoff_err'][0:18]) 510 | # decoff_true = np.array(astrometry_data['decoff'][0:18]) 511 | # decoff_err = np.array(astrometry_data['decoff_err'][0:18]) 512 | 513 | # sep_true = np.array(astrometry_data['sep']) 514 | # sep_err = np.array(astrometry_data['sep_err']) 515 | # pa_values = np.array(astrometry_data['pa']) 516 | # pa_values[pa_values>180] = pa_values[pa_values>180] - 360 517 | # pa_true = np.pi / 180 * pa_values 518 | # pa_err = np.pi / 180 * np.array(astrometry_data['pa_err']) 519 | 520 | 521 | 522 | # if torch.cuda.is_available(): 523 | # device = torch.device('cuda:{}'.format(0)) 524 | # epochs = torch.tensor(np.array(astrometry_data['epoch']), dtype=torch.float32).to(device) 525 | # epochs_np = epochs.cpu().numpy() 526 | 527 | 528 | # sma_np = np.array(9.2, dtype=np.float32).reshape((1, )) 529 | # ecc_np = np.array(0.05, dtype=np.float32).reshape((1, )) 530 | # inc_np = np.array(np.radians(88.9), dtype=np.float32).reshape((1, )) 531 | # aop_np = np.array(np.radians(220), dtype=np.float32).reshape((1, )) 532 | # pan_np = np.array(np.radians(31.85), dtype=np.float32).reshape((1, )) 533 | # tau_np = np.array(0.2, dtype=np.float32).reshape((1, )) 534 | # plx_np = np.array(51.5, dtype=np.float32).reshape((1, )) 535 | # mtot_np = np.array(1.8, dtype=np.float32).reshape((1, )) 536 | 537 | 538 | # sma = torch.tensor(sma_np).to(device) 539 | # ecc = torch.tensor(ecc_np).to(device) 540 | # inc = torch.tensor(inc_np).to(device) 541 | # aop = torch.tensor(aop_np).to(device) 542 | # pan = torch.tensor(pan_np).to(device) 543 | # tau = torch.tensor(tau_np).to(device) 544 | # plx = torch.tensor(plx_np).to(device) 545 | # mtot = torch.tensor(mtot_np).to(device) 546 | 547 | 548 | 549 | # raoff, deoff, vz = calc_orbit(epochs_np, sma_np, ecc_np, inc_np, aop_np, pan_np, tau_np, plx_np, mtot_np, max_iter=100, tau_ref_epoch=50000) 550 | # sep = np.sqrt(raoff**2 + deoff**2) 551 | # pa = np.arctan2(raoff, deoff) 552 | 553 | # raoff1 = np.array(raoff[0:18]) 554 | # deoff1 = np.array(deoff[0:18]) 555 | # pa1 = np.array(pa) 556 | # sep1 = np.array(sep) 557 | 558 | # # logl1 = (raoff1 - raoff_true)**2 / raoff_err**2 + 
(deoff1 - decoff_true)**2 / decoff_err**2 559 | # logl1 = - 0.5 * np.sum((raoff1 - raoff_true)**2 / raoff_err**2 + (deoff1 - decoff_true)**2 / decoff_err**2) - \ 560 | # np.sum(np.log(np.sqrt(2*np.pi)*raoff_err)) - np.sum(np.log(np.sqrt(2*np.pi)*decoff_err)) 561 | 562 | # # logl_all1 = - 0.5 * np.sum((raoff1 - raoff_true)**2 / raoff_err**2 + (deoff1 - decoff_true)**2 / decoff_err**2) - \ 563 | # # 0.5 * np.sum((np.arctan2(np.sin(pa1-pa_true), np.cos(pa1-pa_true)))**2 / pa_err**2 + (sep1 - sep_true)**2 / sep_err**2) - \ 564 | # # np.sum(np.log(np.sqrt(2*np.pi)*raoff_err)) - np.sum(np.log(np.sqrt(2*np.pi)*decoff_err)) - \ 565 | # # np.sum(np.log(np.sqrt(2*np.pi)*pa_err)) - np.sum(np.log(np.sqrt(2*np.pi)*sep_err)) 566 | 567 | 568 | # logl_all1 = - 0.5 * np.sum((raoff1 - raoff_true)**2 / raoff_err**2 + (deoff1 - decoff_true)**2 / decoff_err**2) - \ 569 | # 0.5 * np.sum((np.arctan2(np.sin(pa1[18::]-pa_true[18::]), np.cos(pa1[18::]-pa_true[18::])))**2 / pa_err[18::]**2 + (sep1[18::] - sep_true[18::])**2 / sep_err[18::]**2) - \ 570 | # np.sum(np.log(np.sqrt(2*np.pi)*raoff_err)) - np.sum(np.log(np.sqrt(2*np.pi)*decoff_err)) - \ 571 | # np.sum(np.log(np.sqrt(2*np.pi)*pa_err[18::]*180/np.pi)) - np.sum(np.log(np.sqrt(2*np.pi)*sep_err[18::])) 572 | 573 | # # logl_all1 = - 0.5 * np.sum((np.arctan2(np.sin(pa1-pa_true), np.cos(pa1-pa_true)))**2 / pa_err**2 + (sep1 - sep_true)**2 / sep_err**2) - \ 574 | # # np.sum(np.log(np.sqrt(2*np.pi)*pa_err*180/np.pi)) - np.sum(np.log(np.sqrt(2*np.pi)*sep_err)) 575 | 576 | # # logl_all1 = - 0.5 * np.sum((raoff1 - raoff_true)**2 / raoff_err**2 + (deoff1 - decoff_true)**2 / decoff_err**2) - \ 577 | # # 0.5 * np.sum((np.arctan(np.sin(pa1-pa_true)/np.cos(pa1-pa_true)))**2 / pa_err**2 + (sep1 - sep_true)**2 / sep_err**2) - \ 578 | # # np.sum(np.log(np.sqrt(2*np.pi)*raoff_err)) - np.sum(np.log(np.sqrt(2*np.pi)*decoff_err)) - \ 579 | # # np.sum(np.log(np.sqrt(2*np.pi)*pa_err)) - np.sum(np.log(np.sqrt(2*np.pi)*sep_err)) 580 | 581 | 582 | # raoff_torch, deoff_torch, vz_torch = calc_orbit_torch(epochs, sma, ecc, inc, aop, pan, tau, plx, mtot, max_iter=10, tau_ref_epoch=50000) 583 | 584 | # raoff1_torch = np.array(raoff_torch.cpu().numpy()).squeeze()[0:18] 585 | # deoff1_torch = np.array(deoff_torch.cpu().numpy()).squeeze()[0:18] 586 | # sep1_torch = np.array(torch.sqrt(raoff_torch**2 + deoff_torch**2).cpu().numpy()).squeeze() 587 | # pa1_torch = np.array(torch.atan2(raoff_torch, deoff_torch).cpu().numpy()).squeeze() 588 | 589 | 590 | # # logl1_torch = (raoff1_torch - raoff_true)**2 / raoff_err**2 + (deoff1_torch - decoff_true)**2 / decoff_err**2 591 | # logl1_torch = - 0.5 * np.sum((raoff1_torch - raoff_true)**2 / raoff_err**2 + (deoff1_torch - decoff_true)**2 / decoff_err**2) - \ 592 | # np.sum(np.log(np.sqrt(2*np.pi)*raoff_err)) - np.sum(np.log(np.sqrt(2*np.pi)*decoff_err)) 593 | 594 | # logl_all1_torch = - 0.5 * np.sum((raoff1_torch - raoff_true)**2 / raoff_err**2 + (deoff1_torch - decoff_true)**2 / decoff_err**2) - \ 595 | # 0.5 * np.sum((np.arctan2(np.sin(pa1_torch[18::]-pa_true[18::]), np.cos(pa1_torch[18::]-pa_true[18::])))**2 / pa_err[18::]**2 + (sep1_torch[18::] - sep_true[18::])**2 / sep_err[18::]**2) - \ 596 | # np.sum(np.log(np.sqrt(2*np.pi)*raoff_err)) - np.sum(np.log(np.sqrt(2*np.pi)*decoff_err)) - \ 597 | # np.sum(np.log(np.sqrt(2*np.pi)*pa_err[18::]*180/np.pi)) - np.sum(np.log(np.sqrt(2*np.pi)*sep_err[18::])) 598 | 599 | 600 | # sma_np = np.array(11., dtype=np.float32).reshape((1, )) 601 | # ecc_np = np.array(0.16, dtype=np.float32).reshape((1, )) 602 | # inc_np 
= np.array(np.radians(88.9), dtype=np.float32).reshape((1, )) 603 | # aop_np = np.array(np.radians(190), dtype=np.float32).reshape((1, )) 604 | # pan_np = np.array(np.radians(31.85), dtype=np.float32).reshape((1, )) 605 | # tau_np = np.array(0.6, dtype=np.float32).reshape((1, )) 606 | # plx_np = np.array(51.5, dtype=np.float32).reshape((1, )) 607 | # mtot_np = np.array(1.8, dtype=np.float32).reshape((1, )) 608 | 609 | 610 | # sma = torch.tensor(sma_np).to(device) 611 | # ecc = torch.tensor(ecc_np).to(device) 612 | # inc = torch.tensor(inc_np).to(device) 613 | # aop = torch.tensor(aop_np).to(device) 614 | # pan = torch.tensor(pan_np).to(device) 615 | # tau = torch.tensor(tau_np).to(device) 616 | # plx = torch.tensor(plx_np).to(device) 617 | # mtot = torch.tensor(mtot_np).to(device) 618 | 619 | 620 | 621 | # raoff, deoff, vz = calc_orbit(epochs_np, sma_np, ecc_np, inc_np, aop_np, pan_np, tau_np, plx_np, mtot_np, max_iter=100, tau_ref_epoch=50000) 622 | # sep = np.sqrt(raoff**2 + deoff**2) 623 | # pa = np.arctan2(raoff, deoff) 624 | 625 | # raoff2 = np.array(raoff[0:18]) 626 | # deoff2 = np.array(deoff[0:18]) 627 | # pa2 = np.array(pa) 628 | # sep2 = np.array(sep) 629 | 630 | # # logl2 = (raoff2 - raoff_true)**2 / raoff_err**2 + (deoff2 - decoff_true)**2 / decoff_err**2 631 | # logl2 = - 0.5 * np.sum((raoff2 - raoff_true)**2 / raoff_err**2 + (deoff2 - decoff_true)**2 / decoff_err**2) - \ 632 | # np.sum(np.log(np.sqrt(2*np.pi)*raoff_err)) - np.sum(np.log(np.sqrt(2*np.pi)*decoff_err)) 633 | 634 | 635 | # # logl_all2 = - 0.5 * np.sum((raoff2 - raoff_true)**2 / raoff_err**2 + (deoff2 - decoff_true)**2 / decoff_err**2) - \ 636 | # # 0.5 * np.sum((np.arctan2(np.sin(pa2-pa_true), np.cos(pa2-pa_true)))**2 / pa_err**2 + (sep2 - sep_true)**2 / sep_err**2) - \ 637 | # # np.sum(np.log(np.sqrt(2*np.pi)*raoff_err)) - np.sum(np.log(np.sqrt(2*np.pi)*decoff_err)) - \ 638 | # # np.sum(np.log(np.sqrt(2*np.pi)*pa_err)) - np.sum(np.log(np.sqrt(2*np.pi)*sep_err)) 639 | 640 | # logl_all2 = - 0.5 * np.sum((raoff2 - raoff_true)**2 / raoff_err**2 + (deoff2 - decoff_true)**2 / decoff_err**2) - \ 641 | # 0.5 * np.sum((np.arctan2(np.sin(pa2[18::]-pa_true[18::]), np.cos(pa2[18::]-pa_true[18::])))**2 / pa_err[18::]**2 + (sep2[18::] - sep_true[18::])**2 / sep_err[18::]**2) - \ 642 | # np.sum(np.log(np.sqrt(2*np.pi)*raoff_err)) - np.sum(np.log(np.sqrt(2*np.pi)*decoff_err)) - \ 643 | # np.sum(np.log(np.sqrt(2*np.pi)*pa_err[18::]*180/np.pi)) - np.sum(np.log(np.sqrt(2*np.pi)*sep_err[18::])) 644 | 645 | 646 | # # logl_all2 = - 0.5 * np.sum((raoff2 - raoff_true)**2 / raoff_err**2 + (deoff2 - decoff_true)**2 / decoff_err**2) - \ 647 | # # 0.5 * np.sum((np.arctan(np.sin(pa2-pa_true)/np.cos(pa2-pa_true)))**2 / pa_err**2 + (sep2 - sep_true)**2 / sep_err**2) - \ 648 | # # np.sum(np.log(np.sqrt(2*np.pi)*raoff_err)) - np.sum(np.log(np.sqrt(2*np.pi)*decoff_err)) - \ 649 | # # np.sum(np.log(np.sqrt(2*np.pi)*pa_err)) - np.sum(np.log(np.sqrt(2*np.pi)*sep_err)) 650 | 651 | # raoff_torch, deoff_torch, vz_torch = calc_orbit_torch(epochs, sma, ecc, inc, aop, pan, tau, plx, mtot, max_iter=10, tau_ref_epoch=50000) 652 | 653 | # raoff2_torch = np.array(raoff_torch.cpu().numpy()).squeeze()[0:18] 654 | # deoff2_torch = np.array(deoff_torch.cpu().numpy()).squeeze()[0:18] 655 | # sep2_torch = np.array(torch.sqrt(raoff_torch**2 + deoff_torch**2).cpu().numpy()).squeeze() 656 | # pa2_torch = np.array(torch.atan2(raoff_torch, deoff_torch).cpu().numpy()).squeeze() 657 | 658 | 659 | # # logl2_torch = (raoff2_torch - raoff_true)**2 / raoff_err**2 + 
(deoff2_torch - decoff_true)**2 / decoff_err**2 660 | # logl2_torch = - 0.5 * np.sum((raoff2_torch - raoff_true)**2 / raoff_err**2 + (deoff2_torch - decoff_true)**2 / decoff_err**2) - \ 661 | # np.sum(np.log(np.sqrt(2*np.pi)*raoff_err)) - np.sum(np.log(np.sqrt(2*np.pi)*decoff_err)) 662 | 663 | 664 | # logl_all2_torch = - 0.5 * np.sum((raoff2_torch - raoff_true)**2 / raoff_err**2 + (deoff2_torch - decoff_true)**2 / decoff_err**2) - \ 665 | # 0.5 * np.sum((np.arctan2(np.sin(pa2_torch[18::]-pa_true[18::]), np.cos(pa2_torch[18::]-pa_true[18::])))**2 / pa_err[18::]**2 + (sep2_torch[18::] - sep_true[18::])**2 / sep_err[18::]**2) - \ 666 | # np.sum(np.log(np.sqrt(2*np.pi)*raoff_err)) - np.sum(np.log(np.sqrt(2*np.pi)*decoff_err)) - \ 667 | # np.sum(np.log(np.sqrt(2*np.pi)*pa_err[18::]*180/np.pi)) - np.sum(np.log(np.sqrt(2*np.pi)*sep_err[18::])) 668 | 669 | -------------------------------------------------------------------------------- /DPItorch/DPIx_orbit.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.init as init 8 | import torch.nn.functional as functional 9 | 10 | torch.set_default_dtype(torch.float32) 11 | import torch.optim as optim 12 | import pickle 13 | import math 14 | 15 | from torchkbnufft import KbNufft, AdjKbNufft 16 | from torchkbnufft.mri.dcomp_calc import calculate_radial_dcomp_pytorch 17 | from torchkbnufft.math import absolute 18 | 19 | from generative_model import realnvpfc_model 20 | from orbit_helpers import * 21 | 22 | import astropy.units as u 23 | import astropy.constants as consts 24 | import warnings 25 | 26 | import corner 27 | import argparse 28 | 29 | plt.ion() 30 | 31 | 32 | import time 33 | import pandas as pd 34 | 35 | 36 | class Params2orbits(nn.Module): 37 | def __init__(self, sma_range=[10.0, 1000.0], ecc_range=[0.0, 1.0], 38 | inc_range=[0.0, np.pi], aop_range=[0.0, 2*np.pi], 39 | pan_range=[0.0, 2*np.pi], tau_range=[0.0, 1.0], 40 | plx_range=[56.95-3*0.26, 56.95+3*0.26], mtot_range=[1.22-3*0.08, 1.22+3*0.08]): 41 | super().__init__() 42 | self.sma_range = sma_range 43 | self.ecc_range = ecc_range 44 | self.inc_range = inc_range 45 | self.aop_range = aop_range 46 | self.pan_range = pan_range 47 | self.tau_range = tau_range 48 | self.plx_range = plx_range 49 | self.mtot_range = mtot_range 50 | 51 | 52 | 53 | def forward(self, params): 54 | log_sma = np.log(self.sma_range[0]) + params[:, 0] * (np.log(self.sma_range[1])-np.log(self.sma_range[0])) 55 | sma = torch.exp(log_sma) 56 | # sma = self.sma_range[0] + params[:, 0] * (self.sma_range[1]-self.sma_range[0]) 57 | # log_ecc = np.log(np.max([self.ecc_range[0], 1e-8])) + params[:, 1] * (np.log(self.ecc_range[1])-np.log(np.max([self.ecc_range[0], 1e-8]))) 58 | # ecc = torch.exp(log_ecc) 59 | ecc = self.ecc_range[0] + params[:, 1] * (self.ecc_range[1]-self.ecc_range[0]) 60 | inc = torch.acos(np.cos(self.inc_range[1]) + params[:, 2] * (np.cos(self.inc_range[0])-np.cos(self.inc_range[1]))) 61 | aop = self.aop_range[0] + params[:, 3] * (self.aop_range[1]-self.aop_range[0]) 62 | pan = self.pan_range[0] + params[:, 4] * (self.pan_range[1]-self.pan_range[0]) 63 | tau = self.tau_range[0] + params[:, 5] * (self.tau_range[1]-self.tau_range[0]) 64 | plx = self.plx_range[0] + params[:, 6] * (self.plx_range[1]-self.plx_range[0]) 65 | mtot = self.mtot_range[0] + params[:, 7] * (self.mtot_range[1]-self.mtot_range[0]) 66 | 67 | return sma, ecc, inc, 
aop%(2*np.pi), pan%(2*np.pi), tau%1, plx, mtot 68 | # return sma, ecc, inc, aop, pan, tau, plx, mtot 69 | def reverse(self, sma, ecc, inc, aop, pan, tau, plx, mtot): 70 | log_sma = torch.log(sma) 71 | params0 = (log_sma - np.log(self.sma_range[0])) / (np.log(self.sma_range[1])-np.log(self.sma_range[0])) 72 | # params0 = (sma - self.sma_range[0]) / (self.sma_range[1]-self.sma_range[0]) 73 | # log_ecc = torch.log(ecc) 74 | # params1 = (log_ecc - np.log(np.max([self.ecc_range[0], 1e-8]))) / (np.log(self.ecc_range[1])-np.log(np.max([self.ecc_range[0], 1e-8]))) 75 | params1 = (ecc - self.ecc_range[0]) / (self.ecc_range[1]-self.ecc_range[0]) 76 | params2 = (torch.cos(inc) - np.cos(self.inc_range[1]))/(np.cos(self.inc_range[0])-np.cos(self.inc_range[1])) 77 | params3 = (aop - self.aop_range[0]) / (self.aop_range[1]-self.aop_range[0]) 78 | params4 = (pan - self.pan_range[0]) / (self.pan_range[1]-self.pan_range[0]) 79 | params5 = (tau - self.tau_range[0]) / (self.tau_range[1]-self.tau_range[0]) 80 | params6 = (plx - self.plx_range[0]) / (self.plx_range[1]-self.plx_range[0]) 81 | params7 = (mtot - self.mtot_range[0]) / (self.mtot_range[1]-self.mtot_range[0]) 82 | 83 | return torch.cat([params0.unsqueeze(-1), params1.unsqueeze(-1), params2.unsqueeze(-1), params3.unsqueeze(-1), 84 | params4.unsqueeze(-1), params5.unsqueeze(-1), params6.unsqueeze(-1), params7.unsqueeze(-1)], -1) 85 | 86 | 87 | parser = argparse.ArgumentParser(description="Deep Probabilistic Imaging Trainer for orbit fitting") 88 | 89 | parser.add_argument("--divergence_type", default='alpha', type=str, help="KL or alpha, type of objective divergence used for variational inference") 90 | parser.add_argument("--alpha_divergence", default=1.0, type=float, help="hyperparameters for alpha divergence") 91 | parser.add_argument("--save_path", default='./checkpoint/orbit_beta_pic_b/cartesian/alpha1', type=str, help="path to save normalizing flow models") 92 | # parser.add_argument("--alpha_divergence", default=0.5, type=float, help="hyperparameters for alpha divergence") 93 | # parser.add_argument("--save_path", default='./checkpoint/orbit_beta_pic_b/all/randomtest', type=str, help="path to save normalizing flow models") 94 | 95 | parser.add_argument("--coordinate_type", default='cartesian', type=str, help="coordinate type") 96 | parser.add_argument("--target", default='betapic', type=str, help="target exoplanet") 97 | 98 | parser.add_argument("--data_weight", default=1.0, type=float, help="final data weight for training, between 0-1") 99 | parser.add_argument("--start_order", default=4, type=float, help="start order") 100 | parser.add_argument("--decay_rate", default=3000, type=float, help="decay rate") 101 | parser.add_argument("--n_epoch", default=24000, type=int, help="number of epochs for training RealNVP") 102 | 103 | parser.add_argument("--n_flow", default=16, type=int, help="number of affine coupling layers in RealNVP") 104 | 105 | 106 | if torch.cuda.is_available(): 107 | device = torch.device('cuda:{}'.format(0)) 108 | 109 | if __name__ == "__main__": 110 | args = parser.parse_args() 111 | 112 | save_path = args.save_path#'./checkpoint/GJ504' 113 | # save_path = './checkpoint/orbit_beta_pic_b/all_simulated_annealing_alphadiv'#'./checkpoint/orbit_beta_pic_b/all_simulated_annealing'#'./checkpoint/orbit_beta_pic_b/all_simulated_annealing'#'./checkpoint/orbit_beta_pic_b_cartesian'#'./checkpoint/orbit_GJ504'# 114 | if not os.path.exists(save_path): 115 | os.makedirs(save_path) 116 | 117 | 118 | 119 | n_flow = 
args.n_flow#16#4#8#16#32#16#4#32# 120 | affine = True 121 | nparams = 8 122 | 123 | base_distribution = 'gaussian'#'gmm'#'gmm_only'# 124 | 125 | 126 | params_generator = realnvpfc_model.RealNVP(nparams, n_flow, affine=affine, seqfrac=1/16, batch_norm=True).to(device) 127 | # params_generator = realnvpfc_model.RealNVP(nparams, n_flow, affine=affine, seqfrac=1/64, batch_norm=True).to(device) 128 | # params_generator = realnvpfc_model.RealNVP(nparams, n_flow, affine=affine, seqfrac=1/128, batch_norm=False).to(device) 129 | # params_generator = realnvpfc_model.RealNVP(nparams, n_flow, affine=affine, seqfrac=1/64, batch_norm=False).to(device) 130 | # params_generator = realnvpfc_model.RealNVP(nparams, n_flow, affine=affine, seqfrac=1/32, batch_norm=False).to(device) 131 | 132 | 133 | target = args.target#'GJ504'#'betapic'# 134 | if target == 'betapic': 135 | astrometry_data = pd.read_csv('../dataset/orbital_fit/betapic_astrometry.csv') 136 | # astrometry_data['raoff'] = astrometry_data['sep'] * np.sin(astrometry_data['pa'] * np.pi / 180) 137 | # astrometry_data['decoff'] = astrometry_data['sep'] * np.cos(astrometry_data['pa'] * np.pi / 180) 138 | # astrometry_data['raoff_err'] = astrometry_data['sep_err'] 139 | # astrometry_data['decoff_err'] = astrometry_data['sep_err'] 140 | # astrometry_data['sep'] = astrometry_data['sep'] - 0.5*astrometry_data['sep_err']**2/astrometry_data['sep'] 141 | # astrometry_data = astrometry_data[0:17] 142 | 143 | cartesian_indices = np.where(np.logical_not(np.isnan(np.array(astrometry_data['raoff']))))[0] 144 | polar_indices = np.where(np.logical_not(np.isnan(np.array(astrometry_data['pa']))))[0] 145 | # polar_indices = np.arange(18) 146 | 147 | 148 | 149 | polar_exclude_cartesian_indices = np.where(np.logical_and(np.isnan(np.array(astrometry_data['raoff'])), np.logical_not(np.isnan(np.array(astrometry_data['pa'])))))[0] 150 | all_indices = np.concatenate([cartesian_indices, polar_exclude_cartesian_indices]) 151 | 152 | 153 | 154 | epochs = torch.tensor(np.array(astrometry_data['epoch']), dtype=torch.float32).to(device) 155 | sep = torch.tensor(np.array(astrometry_data['sep'][polar_indices]), dtype=torch.float32).to(device) 156 | sep_err = torch.tensor(np.array(astrometry_data['sep_err'][polar_indices]), dtype=torch.float32).to(device) 157 | pa_values = np.array(astrometry_data['pa'][polar_indices]) 158 | pa_values[pa_values>180] = pa_values[pa_values>180] - 360 159 | pa = np.pi / 180 * torch.tensor(pa_values, dtype=torch.float32).to(device) 160 | pa_err = np.pi / 180 * torch.tensor(np.array(astrometry_data['pa_err'][polar_indices]), dtype=torch.float32).to(device) 161 | 162 | sep_err = sep_err * 3 163 | pa_err = pa_err * 3 164 | 165 | raoff = torch.tensor(np.array(astrometry_data['raoff'][cartesian_indices]), dtype=torch.float32).to(device) 166 | raoff_err = torch.tensor(np.array(astrometry_data['raoff_err'][cartesian_indices]), dtype=torch.float32).to(device) 167 | decoff = torch.tensor(np.array(astrometry_data['decoff'][cartesian_indices]), dtype=torch.float32).to(device) 168 | decoff_err = torch.tensor(np.array(astrometry_data['decoff_err'][cartesian_indices]), dtype=torch.float32).to(device) 169 | 170 | raoff_convert = sep * torch.sin(pa) 171 | decoff_convert = sep * torch.cos(pa) 172 | 173 | eps = 1e-3 174 | orbit_converter = Params2orbits(sma_range=[4, 40], ecc_range=[1e-5, 0.99], 175 | inc_range=[81/180*np.pi, 99/180*np.pi], aop_range=[0.0-eps, 2.0*np.pi+eps], 176 | pan_range=[25/180*np.pi, 85/180*np.pi], tau_range=[0.0-eps, 1.0+eps], 177 | 
plx_range=[51.44-5*0.12, 51.44+5*0.12], mtot_range=[1.75-5*0.05, 1.75+5*0.05]).to(device) 178 | 179 | 180 | # orbit_converter = Params2orbits(sma_range=[4, 40], ecc_range=[1e-8, 0.99], 181 | # inc_range=[81/180*np.pi, 99/180*np.pi], aop_range=[-2.0*np.pi, 2.0*np.pi], 182 | # pan_range=[25/180*np.pi, 85/180*np.pi], tau_range=[-1.0, 1.0], 183 | # plx_range=[51.44-5*0.12, 51.44+5*0.12], mtot_range=[1.75-5*0.05, 1.75+5*0.05]) 184 | 185 | 186 | 187 | 188 | coordinate_type = args.coordinate_type#'all'#'cartesian'#'polar'# 189 | if coordinate_type == 'cartesian': 190 | epochs = epochs[cartesian_indices] 191 | elif coordinate_type == 'polar': 192 | epochs = epochs[polar_indices] 193 | elif coordinate_type == 'all' or coordinate_type == 'all_cartesian': 194 | epochs = epochs[all_indices] 195 | 196 | 197 | if coordinate_type == 'cartesian': 198 | scale_factor = 1.0 / len(cartesian_indices) 199 | elif coordinate_type == 'polar': 200 | scale_factor = 1.0 / len(polar_indices) 201 | elif coordinate_type == 'all' or coordinate_type == 'all_cartesian': 202 | scale_factor = 1.0 / len(all_indices) 203 | 204 | 205 | 206 | 207 | elif target == 'GJ504': 208 | 209 | epochs = torch.tensor([55645.95, 55702.89, 55785.015, 55787.935, 55985.19400184, 56029.11400323, 56072.30200459], dtype=torch.float32).to(device) 210 | sep = torch.tensor([2479, 2483, 2481, 2448, 2483, 2487, 2499], dtype=torch.float32).to(device) 211 | sep_err = torch.tensor([16, 8, 33, 24, 15, 8, 26], dtype=torch.float32).to(device) 212 | pa_values = np.array([327.94, 327.45, 326.84, 325.82, 326.46, 326.54, 326.14]) 213 | pa_values[pa_values>180] = pa_values[pa_values>180] - 360 214 | pa = np.pi / 180 * torch.tensor(pa_values, dtype=torch.float32).to(device) 215 | pa_err = np.pi / 180 * torch.tensor([0.39, 0.19, 0.94, 0.66, 0.36, 0.18, 0.61], dtype=torch.float32).to(device) 216 | 217 | 218 | coordinate_type = 'polar' 219 | 220 | sep_weight = 1.0 221 | pa_weight = 1.0 222 | logdet_weight = 2.0 / len(epochs) 223 | prior_weight = 1.0 / len(epochs) 224 | 225 | eps = 1e-3 226 | orbit_converter = Params2orbits(sma_range=[1e1, 1e4], ecc_range=[1e-8, 0.99], 227 | inc_range=[0.0+eps, np.pi-eps], aop_range=[0.0-eps, 2.0*np.pi+eps], 228 | pan_range=[0.0-eps, 2.0*np.pi+eps], tau_range=[0.0-eps, 1.0+eps], 229 | plx_range=[56.95-5*0.26, 56.95+5*0.26], mtot_range=[1.22-5*0.08, 1.22+5*0.08]).to(device) 230 | 231 | # orbit_converter = Params2orbits(sma_range=[1e1, 1e4], ecc_range=[1e-8, 0.99], 232 | # inc_range=[0.0+3e-4, np.pi-3e-4], aop_range=[-2.0*np.pi, 2.0*np.pi], 233 | # pan_range=[-2.0*np.pi, 2.0*np.pi], tau_range=[-1.0, 1.0], 234 | # plx_range=[56.95-5*0.26, 56.95+5*0.26], mtot_range=[1.22-5*0.08, 1.22+5*0.08]) 235 | 236 | scale_factor = 1.0 / len(epochs) 237 | 238 | 239 | n_batch = 8192#256#2048#512#4096#256#128#64#32#8 240 | n_smooth = 10 241 | loss_best = 1e8 242 | 243 | 244 | loss_list = [] 245 | loss_prior_list = [] 246 | loss_sep_list = [] 247 | loss_pa_list = [] 248 | loss_raoff_list = [] 249 | loss_decoff_list = [] 250 | loss_raoff_convert_list = [] 251 | loss_decoff_convert_list = [] 252 | logdet_list = [] 253 | 254 | 255 | 256 | # optimize both scale and image generator 257 | lr = 2e-4#3e-4#1e-4#1e-4#3e-4#1e-4#3e-4#1e-3#3e-4#1e-3#3e-3#1e-2#1e-4#1e-5#2e-4#1e-4#1e-6#1e-5#1e-3#args.lr#1e-5# 258 | clip = 1e-4#1e-4#1e-5#1e-4#1e-4#1e-3#1e-4#1e-3#1e-5#1e-4#3e-4#1e-3#1e-4#1e-5#2e-4#1e-4#1e-6#1e-5#3e-5#1e-3#1#1e2#1e-1# 259 | 260 | 261 | optimizer = optim.Adam(params_generator.parameters(), lr = lr, amsgrad=True) 262 | # optimizer = 
optim.Adam(params_generator.parameters(), lr = lr) 263 | 264 | start_order = args.start_order#5#6#5#4# 265 | n_epoch = args.n_epoch#30000#30000#21000#100000#30000#3000# 266 | decay_rate = args.decay_rate#3000#10000#5000#3000 267 | start_time = time.time() 268 | alpha_divergence = args.alpha_divergence#1.0#0.99#0.6#0.8#0.95#0.999#0.9# 269 | 270 | divergence_type = args.divergence_type#'alpha'#'KL'# 271 | 272 | final_data_weight = args.data_weight 273 | 274 | for k in range(n_epoch): 275 | data_weight = min(10**(-start_order+k/decay_rate), final_data_weight) 276 | 277 | z_sample = torch.randn((n_batch, nparams)).to(device=device) 278 | 279 | 280 | # generate image samples 281 | params_samp, logdet = params_generator.reverse(z_sample) 282 | params = torch.sigmoid(params_samp) 283 | # compute log determinant 284 | det_sigmoid = torch.sum(-params_samp-2*torch.nn.Softplus()(-params_samp), -1) 285 | logdet = logdet + det_sigmoid 286 | 287 | 288 | # params_samp = params_generator.forward() 289 | # params = torch.sigmoid(params_samp) 290 | 291 | sma, ecc, inc, aop, pan, tau, plx, mtot = orbit_converter.forward(params) 292 | 293 | if target == 'betapic': 294 | raoff_torch, deoff_torch, vz_torch = calc_orbit_torch(epochs, sma, ecc, inc, aop, pan, tau, plx, mtot, max_iter=10, tolerance=1e-8, tau_ref_epoch=50000) 295 | else: 296 | raoff_torch, deoff_torch, vz_torch = calc_orbit_torch(epochs, sma, ecc, inc, aop, pan, tau, plx, mtot, max_iter=10, tolerance=1e-8) 297 | 298 | sep_torch = torch.transpose(torch.sqrt(raoff_torch**2 + deoff_torch**2), 0, 1) 299 | pa_torch = torch.transpose(torch.atan2(raoff_torch, deoff_torch), 0, 1) 300 | 301 | raoff_torch = torch.transpose(raoff_torch, 0, 1) 302 | deoff_torch = torch.transpose(deoff_torch, 0, 1) 303 | 304 | if coordinate_type == 'polar': 305 | # loss_sep = (sep_torch - sep)**2 / sep_err**2 306 | # loss_pa = (torch.atan2(torch.sin(pa-pa_torch), torch.cos(pa-pa_torch)))**2 / pa_err**2 307 | # loss_sep = (torch.log(sep_torch) - torch.log(sep))**2 / (sep_err/sep)**2 308 | loss_sep = (torch.log(sep_torch) - torch.log(sep) + 0.5*sep_err**2/sep**2)**2 / (sep_err/sep)**2 309 | loss_pa = 2.0 * (1 - torch.cos(pa-pa_torch)) / pa_err**2 310 | elif coordinate_type == 'cartesian': 311 | loss_raoff = (raoff_torch - raoff)**2 / raoff_err**2 312 | loss_decoff = (deoff_torch - decoff)**2 / decoff_err**2 313 | elif coordinate_type == 'all': 314 | # loss_sep = (sep_torch[:, polar_exclude_cartesian_indices] - sep[polar_exclude_cartesian_indices])**2 / sep_err[polar_exclude_cartesian_indices]**2 315 | loss_sep = (torch.log(sep_torch[:, polar_exclude_cartesian_indices]) - torch.log(sep[polar_exclude_cartesian_indices]))**2 / (sep_err[polar_exclude_cartesian_indices]/sep[polar_exclude_cartesian_indices])**2 316 | # loss_pa = (torch.atan2(torch.sin(pa_torch[:, polar_exclude_cartesian_indices]-pa[polar_exclude_cartesian_indices]), torch.cos(pa_torch[:, polar_exclude_cartesian_indices]-pa[polar_exclude_cartesian_indices])))**2 / pa_err[polar_exclude_cartesian_indices]**2 317 | loss_pa = 2.0 * (1 - torch.cos(pa_torch[:, polar_exclude_cartesian_indices] - pa[polar_exclude_cartesian_indices])) / pa_err[polar_exclude_cartesian_indices]**2 318 | 319 | loss_raoff = (raoff_torch[:, cartesian_indices] - raoff[cartesian_indices])**2 / raoff_err[cartesian_indices]**2 320 | loss_decoff = (deoff_torch[:, cartesian_indices] - decoff[cartesian_indices])**2 / decoff_err[cartesian_indices]**2 321 | elif coordinate_type == 'all_cartesian': 322 | loss_raoff_convert = (raoff_torch[:, 
polar_exclude_cartesian_indices] - raoff_convert[polar_exclude_cartesian_indices])**2 / sep_err[polar_exclude_cartesian_indices]**2 323 | loss_decoff_convert = (deoff_torch[:, polar_exclude_cartesian_indices] - decoff_convert[polar_exclude_cartesian_indices])**2 / sep_err[polar_exclude_cartesian_indices]**2 324 | 325 | loss_raoff = (raoff_torch[:, cartesian_indices] - raoff[cartesian_indices])**2 / raoff_err[cartesian_indices]**2 326 | loss_decoff = (deoff_torch[:, cartesian_indices] - decoff[cartesian_indices])**2 / decoff_err[cartesian_indices]**2 327 | 328 | if target == 'betapic': 329 | loss_prior = (plx - 51.44)**2 / 0.12**2 + (mtot - 1.75)**2 / 0.05**2 330 | elif target == 'GJ504': 331 | loss_prior = (plx - 56.95)**2 / 0.26**2 + (mtot - 1.22)**2 / 0.08**2 332 | 333 | 334 | logprob = -logdet - 0.5*torch.sum(z_sample**2, 1) 335 | 336 | # Define the divergence loss - annealed loss 337 | if coordinate_type == 'polar': 338 | loss = data_weight * (0.5* torch.sum(loss_sep, -1) + 0.5 * torch.sum(loss_pa, -1) + 0.5 * loss_prior) + logprob 339 | elif coordinate_type == 'cartesian': 340 | loss = data_weight * (0.5* torch.sum(loss_raoff, -1) + 0.5 * torch.sum(loss_decoff, -1) + 0.5 * loss_prior) + logprob 341 | elif coordinate_type == 'all': 342 | loss = data_weight * (0.5* torch.sum(loss_sep, -1) + 0.5 * torch.sum(loss_pa, -1) + \ 343 | 0.5* torch.sum(loss_raoff, -1) + 0.5 * torch.sum(loss_decoff, -1) + 0.5 * loss_prior) + logprob 344 | elif coordinate_type == 'all_cartesian': 345 | loss = data_weight * (0.5* torch.sum(loss_raoff_convert, -1) + 0.5 * torch.sum(loss_decoff_convert, -1) + \ 346 | 0.5* torch.sum(loss_raoff, -1) + 0.5 * torch.sum(loss_decoff, -1) + 0.5 * loss_prior) + logprob 347 | 348 | 349 | if divergence_type == 'KL' or alpha_divergence == 1: 350 | loss = torch.mean(scale_factor * loss) 351 | elif divergence_type == 'alpha': 352 | rej_weights = nn.Softmax(dim=0)(-(1-alpha_divergence)*loss).detach() 353 | loss = torch.sum(rej_weights * scale_factor * loss) 354 | 355 | 356 | # Define the divergence loss - original loss 357 | if coordinate_type == 'polar': 358 | loss_orig = (0.5* torch.sum(loss_sep, -1) + 0.5 * torch.sum(loss_pa, -1) + 0.5 * loss_prior) + logprob 359 | elif coordinate_type == 'cartesian': 360 | loss_orig = (0.5* torch.sum(loss_raoff, -1) + 0.5 * torch.sum(loss_decoff, -1) + 0.5 * loss_prior) + logprob 361 | elif coordinate_type == 'all': 362 | loss_orig = (0.5* torch.sum(loss_sep, -1) + 0.5 * torch.sum(loss_pa, -1) + \ 363 | 0.5* torch.sum(loss_raoff, -1) + 0.5 * torch.sum(loss_decoff, -1) + 0.5 * loss_prior) + logprob 364 | elif coordinate_type == 'all_cartesian': 365 | loss_orig = (0.5* torch.sum(loss_raoff_convert, -1) + 0.5 * torch.sum(loss_decoff_convert, -1) + \ 366 | 0.5* torch.sum(loss_raoff, -1) + 0.5 * torch.sum(loss_decoff, -1) + 0.5 * loss_prior) + logprob 367 | 368 | # Define the divergence loss - original loss 369 | if divergence_type == 'KL' or alpha_divergence == 1: 370 | loss_orig = torch.mean(scale_factor * loss_orig) 371 | elif divergence_type == 'alpha': 372 | loss_orig = scale_factor * torch.log(torch.mean(torch.exp(-(1-alpha_divergence)*loss_orig)))/(alpha_divergence-1) 373 | 374 | 375 | optimizer.zero_grad() 376 | loss.backward() 377 | nn.utils.clip_grad_norm_(params_generator.parameters(), clip) 378 | optimizer.step() 379 | 380 | loss_list.append(loss_orig.detach().cpu().numpy()) 381 | loss_prior_list.append(torch.mean(loss_prior).detach().cpu().numpy()) 382 | if coordinate_type == 'polar': 383 | 
loss_sep_list.append(torch.mean(loss_sep).detach().cpu().numpy()) 384 | loss_pa_list.append(torch.mean(loss_pa).detach().cpu().numpy()) 385 | elif coordinate_type == 'cartesian': 386 | loss_raoff_list.append(torch.mean(loss_raoff).detach().cpu().numpy()) 387 | loss_decoff_list.append(torch.mean(loss_decoff).detach().cpu().numpy()) 388 | elif coordinate_type == 'all': 389 | loss_sep_list.append(torch.mean(loss_sep).detach().cpu().numpy()) 390 | loss_pa_list.append(torch.mean(loss_pa).detach().cpu().numpy()) 391 | loss_raoff_list.append(torch.mean(loss_raoff).detach().cpu().numpy()) 392 | loss_decoff_list.append(torch.mean(loss_decoff).detach().cpu().numpy()) 393 | elif coordinate_type == 'all_cartesian': 394 | loss_raoff_convert_list.append(torch.mean(loss_raoff_convert).detach().cpu().numpy()) 395 | loss_decoff_convert_list.append(torch.mean(loss_decoff_convert).detach().cpu().numpy()) 396 | loss_raoff_list.append(torch.mean(loss_raoff).detach().cpu().numpy()) 397 | loss_decoff_list.append(torch.mean(loss_decoff).detach().cpu().numpy()) 398 | logdet_list.append(-torch.mean(logdet).detach().cpu().numpy()/nparams) 399 | 400 | 401 | if coordinate_type == 'polar': 402 | print(f"epoch: {(k):}, loss: {loss_list[-1]:.5f}, loss sep: {loss_sep_list[-1]:.5f}, loss pa: {loss_pa_list[-1]:.5f}, loss prior: {loss_prior_list[-1]:.5f}, logdet: {logdet_list[-1]:.5f}") 403 | elif coordinate_type == 'cartesian': 404 | print(f"epoch: {(k):}, loss: {loss_list[-1]:.5f}, loss raoff: {loss_raoff_list[-1]:.5f}, loss decoff: {loss_decoff_list[-1]:.5f}, loss prior: {loss_prior_list[-1]:.5f}, logdet: {logdet_list[-1]:.5f}") 405 | elif coordinate_type == 'all': 406 | print(f"epoch: {(k):}, loss: {loss_list[-1]:.5f}, loss sep: {loss_sep_list[-1]:.5f}, loss pa: {loss_pa_list[-1]:.5f}, loss raoff: {loss_raoff_list[-1]:.5f}, loss decoff: {loss_decoff_list[-1]:.5f}, loss prior: {loss_prior_list[-1]:.5f}, logdet: {logdet_list[-1]:.5f}") 407 | elif coordinate_type == 'all_cartesian': 408 | print(f"epoch: {(k):}, loss: {loss_list[-1]:.5f}, loss raoff convert: {loss_raoff_convert_list[-1]:.5f}, loss decoff convert: {loss_decoff_convert_list[-1]:.5f}, loss raoff: {loss_raoff_list[-1]:.5f}, loss decoff: {loss_decoff_list[-1]:.5f}, loss prior: {loss_prior_list[-1]:.5f}, logdet: {logdet_list[-1]:.5f}") 409 | # if k > n_smooth and data_weight==1: 410 | if k > n_smooth + 1: 411 | loss_now = np.mean(loss_list[-n_smooth::]) 412 | if loss_now <= loss_best: 413 | loss_best = loss_now 414 | print('################{}###############'.format(loss_best)) 415 | 416 | torch.save(params_generator.state_dict(), save_path+'/generativemodelbest_'+coordinate_type+'_'+'RealNVP'+'_flow{}'.format(n_flow)) 417 | 418 | if k == 0 or (k+1)%decay_rate == 0: 419 | torch.save(params_generator.state_dict(), save_path+'/generativemodel_loop{}_'.format((k+1)//decay_rate)+coordinate_type+'_'+'RealNVP'+'_flow{}'.format(n_flow)) 420 | end_time = time.time() 421 | 422 | 423 | torch.save(params_generator.state_dict(), save_path+'/generativemodel_'+coordinate_type+'_'+'RealNVP'+'_flow{}'.format(n_flow)) 424 | 425 | loss_all = {} 426 | loss_all['total'] = np.array(loss_list) 427 | loss_all['prior'] = np.array(loss_prior_list) 428 | if coordinate_type == 'polar': 429 | loss_all['sep'] = np.array(loss_sep_list) 430 | loss_all['pa'] = np.array(loss_pa_list) 431 | elif coordinate_type == 'cartesian': 432 | loss_all['raoff'] = np.array(loss_raoff_list) 433 | loss_all['decoff'] = np.array(loss_decoff_list) 434 | elif coordinate_type == 'all': 435 | loss_all['raoff'] = 
np.array(loss_raoff_list) 436 | loss_all['decoff'] = np.array(loss_decoff_list) 437 | loss_all['sep'] = np.array(loss_sep_list) 438 | loss_all['pa'] = np.array(loss_pa_list) 439 | elif coordinate_type == 'all_cartesian': 440 | loss_all['raoff'] = np.array(loss_raoff_list) 441 | loss_all['decoff'] = np.array(loss_decoff_list) 442 | loss_all['raoff_convert'] = np.array(loss_raoff_convert_list) 443 | loss_all['decoff_convert'] = np.array(loss_decoff_convert_list) 444 | 445 | loss_all['logdet'] = np.array(logdet_list) 446 | loss_all['time'] = end_time - start_time 447 | np.save(save_path+'/loss_'+coordinate_type+'_'+'RealNVP'+'_flow{}'.format(n_flow), loss_all) 448 | 449 | 450 | #####################################Visualization################################################## 451 | # params_generator.load_state_dict(torch.load(save_path+'/generativemodel_'+coordinate_type+'_'+'RealNVP'+'_flow{}'.format(n_flow))) 452 | 453 | # # save_path = './checkpoint/orbit_beta_pic_b/cartesian2/alpha09' 454 | # save_path = './checkpoint/orbit_beta_pic_b/all2/KL' 455 | 456 | # alpha_divergence = 0.99#0.8 457 | # coordinate_type = 'all'#'cartesian' 458 | # 459 | 460 | # def Gen_samples(params_generator, rejsamp_flag=True, n_concat=10, alpha_divergence=1.0, coordinate_type='cartesian'): 461 | 462 | # for k in range(n_concat): 463 | 464 | # z_sample = torch.randn((n_batch, nparams)).to(device=device) 465 | # # generate image samples 466 | # params_samp, logdet = params_generator.reverse(z_sample) 467 | # params = torch.sigmoid(params_samp) 468 | 469 | # sma, ecc, inc, aop, pan, tau, plx, mtot = orbit_converter.forward(params) 470 | 471 | # if rejsamp_flag: 472 | 473 | # det_sigmoid = torch.sum(-params_samp-2*torch.nn.Softplus()(-params_samp), -1) 474 | # logdet = logdet + det_sigmoid 475 | # logprob = -logdet - 0.5*torch.sum(z_sample**2, 1) 476 | 477 | # if target == 'betapic': 478 | # raoff_torch, deoff_torch, vz_torch = calc_orbit_torch(epochs, sma, ecc, inc, aop, pan, tau, plx, mtot, max_iter=10, tolerance=1e-8, tau_ref_epoch=50000) 479 | # else: 480 | # raoff_torch, deoff_torch, vz_torch = calc_orbit_torch(epochs, sma, ecc, inc, aop, pan, tau, plx, mtot, max_iter=10, tolerance=1e-8) 481 | 482 | # sep_torch = torch.transpose(torch.sqrt(raoff_torch**2 + deoff_torch**2), 0, 1) 483 | # pa_torch = torch.transpose(torch.atan2(raoff_torch, deoff_torch), 0, 1) 484 | 485 | # raoff_torch = torch.transpose(raoff_torch, 0, 1) 486 | # deoff_torch = torch.transpose(deoff_torch, 0, 1) 487 | 488 | # if coordinate_type == 'polar': 489 | # # loss_sep = (sep_torch - sep)**2 / sep_err**2 490 | # # loss_sep = (torch.log(sep_torch) - torch.log(sep))**2 / (sep_err/sep)**2 491 | # loss_sep = (torch.log(sep_torch) - torch.log(sep) + 0.5*sep_err**2/sep**2)**2 / (sep_err/sep)**2 492 | # loss_pa = (torch.atan2(torch.sin(pa-pa_torch), torch.cos(pa-pa_torch)))**2 / pa_err**2 493 | # elif coordinate_type == 'cartesian': 494 | # loss_raoff = (raoff_torch - raoff)**2 / raoff_err**2 495 | # loss_decoff = (deoff_torch - decoff)**2 / decoff_err**2 496 | # elif coordinate_type == 'all': 497 | # # loss_sep = (sep_torch[:, polar_exclude_cartesian_indices] - sep[polar_exclude_cartesian_indices])**2 / sep_err[polar_exclude_cartesian_indices]**2 498 | # loss_sep = (torch.log(sep_torch[:, polar_exclude_cartesian_indices]) - torch.log(sep[polar_exclude_cartesian_indices]))**2 / (sep_err[polar_exclude_cartesian_indices]/sep[polar_exclude_cartesian_indices])**2 499 | # loss_pa = (torch.atan2(torch.sin(pa_torch[:,
polar_exclude_cartesian_indices]-pa[polar_exclude_cartesian_indices]), torch.cos(pa_torch[:, polar_exclude_cartesian_indices]-pa[polar_exclude_cartesian_indices])))**2 / pa_err[polar_exclude_cartesian_indices]**2 500 | 501 | # loss_raoff = (raoff_torch[:, cartesian_indices] - raoff[cartesian_indices])**2 / raoff_err[cartesian_indices]**2 502 | # loss_decoff = (deoff_torch[:, cartesian_indices] - decoff[cartesian_indices])**2 / decoff_err[cartesian_indices]**2 503 | 504 | # if target == 'betapic': 505 | # loss_prior = (plx - 51.44)**2 / 0.12**2 + (mtot - 1.75)**2 / 0.05**2 506 | # elif target == 'GJ504': 507 | # loss_prior = (plx - 56.95)**2 / 0.26**2 + (mtot - 1.22)**2 / 0.08**2 508 | 509 | # # # Loss function for generative model 510 | # # if coordinate_type == 'polar': 511 | # # loss_data = sep_weight * torch.mean(loss_sep, 1) + pa_weight * torch.mean(loss_pa, 1) + prior_weight * torch.mean(loss_prior) 512 | # # elif coordinate_type == 'cartesian': 513 | # # loss_data = raoff_weight * torch.mean(loss_raoff, 1) + decoff_weight * torch.mean(loss_decoff, 1) + prior_weight * torch.mean(loss_prior) 514 | # # elif coordinate_type == 'all': 515 | # # loss_data = sep_weight * torch.mean(loss_sep, 1) + pa_weight * torch.mean(loss_pa, 1) + raoff_weight * torch.mean(loss_raoff, 1) + decoff_weight * torch.mean(loss_decoff, 1) + \ 516 | # # prior_weight * torch.mean(loss_prior) 517 | 518 | 519 | # # Loss function for generative model 520 | # if coordinate_type == 'polar': 521 | # loss_orig = (0.5* torch.sum(loss_sep, -1) + 0.5 * torch.sum(loss_pa, -1) + 0.5 * loss_prior) + logprob 522 | # elif coordinate_type == 'cartesian': 523 | # loss_orig = (0.5* torch.sum(loss_raoff, -1) + 0.5 * torch.sum(loss_decoff, -1) + 0.5 * loss_prior) + logprob 524 | # elif coordinate_type == 'all': 525 | # loss_orig = (0.5* torch.sum(loss_sep, -1) + 0.5 * torch.sum(loss_pa, -1) + \ 526 | # 0.5* torch.sum(loss_raoff, -1) + 0.5 * torch.sum(loss_decoff, -1) + 0.5 * loss_prior) + logprob 527 | # # rej_prob = n_batch * nn.Softmax(dim=0)(-(1-alpha_divergence)*loss_orig) 528 | # rej_prob = n_batch * nn.Softmax(dim=0)(-loss_orig) 529 | 530 | # # rej_prob = rej_prob / torch.max(rej_prob) 531 | 532 | # rej_M = torch.sort(rej_prob)[0][int(0.99*n_batch)] 533 | # rej_prob = rej_prob / rej_M 534 | 535 | # U = torch.rand((n_batch, )).to(device=device) 536 | 537 | # # logU = torch.log(U) 538 | # # ind = torch.where((-loss_data - logprob - logM) > logU)[0] 539 | 540 | # ind = torch.where(rej_prob > U)[0] 541 | 542 | # else: 543 | # ind = np.arange(n_batch) 544 | # # ind = torch.sort(sep_weight * torch.mean(loss_pa, -1) + pa_weight * torch.mean(loss_sep, -1) + prior_weight * loss_prior)[1][0:int(0.99*n_batch)].detach().cpu().numpy() 545 | 546 | # # ind = torch.sort(raoff_weight * torch.mean(loss_raoff, -1) + decoff_weight * torch.mean(loss_decoff, -1) + prior_weight * loss_prior)[1][0:int(0.99*n_batch)].detach().cpu().numpy() 547 | 548 | 549 | 550 | # # aop[aop>np.pi] = aop[aop>np.pi] - 2 * np.pi 551 | # # tau[tau>0.5] = tau[tau>0.5] - 1.0 552 | # orbit_params1 = np.concatenate([sma[ind].unsqueeze(-1).detach().cpu().numpy(), (tau[ind].unsqueeze(-1).detach().cpu().numpy()), 553 | # (180/np.pi*aop[ind].unsqueeze(-1).detach().cpu().numpy()), 554 | # 180/np.pi*pan[ind].unsqueeze(-1).detach().cpu().numpy(), 180/np.pi*inc[ind].unsqueeze(-1).detach().cpu().numpy(), 555 | # ecc[ind].unsqueeze(-1).detach().cpu().numpy(), mtot[ind].unsqueeze(-1).detach().cpu().numpy()], -1) 556 | 557 | 558 | 559 | # orbit_params2 = 
np.concatenate([sma[ind].unsqueeze(-1).detach().cpu().numpy(), ecc[ind].unsqueeze(-1).detach().cpu().numpy(), 560 | # 180/np.pi*inc[ind].unsqueeze(-1).detach().cpu().numpy(), 561 | # (180/np.pi*aop[ind].unsqueeze(-1).detach().cpu().numpy()), 562 | # 180/np.pi*pan[ind].unsqueeze(-1).detach().cpu().numpy(), (tau[ind].unsqueeze(-1).detach().cpu().numpy()), 563 | # plx[ind].unsqueeze(-1).detach().cpu().numpy(), mtot[ind].unsqueeze(-1).detach().cpu().numpy()], -1) 564 | 565 | # if k == 0: 566 | # orbit_params_all1 = np.array(orbit_params1) 567 | # orbit_params_all2 = np.array(orbit_params2) 568 | 569 | # else: 570 | # orbit_params_all1 = np.concatenate([orbit_params_all1, orbit_params1], 0) 571 | # orbit_params_all2 = np.concatenate([orbit_params_all2, orbit_params2], 0) 572 | 573 | # return orbit_params_all2 574 | 575 | 576 | def Gen_samples2(params_generator, rejsamp_flag=True, n_concat=10, alpha_divergence=1.0, coordinate_type='cartesian'): 577 | 578 | for k in range(n_concat): 579 | 580 | z_sample = torch.randn((n_batch, nparams)).to(device=device) 581 | # generate image samples 582 | params_samp, logdet = params_generator.reverse(z_sample) 583 | params = torch.sigmoid(params_samp) 584 | 585 | sma, ecc, inc, aop, pan, tau, plx, mtot = orbit_converter.forward(params) 586 | 587 | if rejsamp_flag: 588 | 589 | det_sigmoid = torch.sum(-params_samp-2*torch.nn.Softplus()(-params_samp), -1) 590 | logdet = logdet + det_sigmoid 591 | logprob = -logdet - 0.5*torch.sum(z_sample**2, 1) 592 | 593 | if target == 'betapic': 594 | raoff_torch, deoff_torch, vz_torch = calc_orbit_torch(epochs, sma, ecc, inc, aop, pan, tau, plx, mtot, max_iter=10, tolerance=1e-8, tau_ref_epoch=50000) 595 | else: 596 | raoff_torch, deoff_torch, vz_torch = calc_orbit_torch(epochs, sma, ecc, inc, aop, pan, tau, plx, mtot, max_iter=10, tolerance=1e-8) 597 | 598 | sep_torch = torch.transpose(torch.sqrt(raoff_torch**2 + deoff_torch**2), 0, 1) 599 | pa_torch = torch.transpose(torch.atan2(raoff_torch, deoff_torch), 0, 1) 600 | 601 | raoff_torch = torch.transpose(raoff_torch, 0, 1) 602 | deoff_torch = torch.transpose(deoff_torch, 0, 1) 603 | 604 | if coordinate_type == 'polar': 605 | loss_sep = (torch.log(sep_torch) - torch.log(sep) + 0.5*sep_err**2/sep**2)**2 / (sep_err/sep)**2 606 | loss_pa = (torch.atan2(torch.sin(pa-pa_torch), torch.cos(pa-pa_torch)))**2 / pa_err**2 607 | elif coordinate_type == 'cartesian': 608 | loss_raoff = (raoff_torch - raoff)**2 / raoff_err**2 609 | loss_decoff = (deoff_torch - decoff)**2 / decoff_err**2 610 | elif coordinate_type == 'all': 611 | loss_sep = (torch.log(sep_torch[:, polar_exclude_cartesian_indices]) - torch.log(sep[polar_exclude_cartesian_indices]))**2 / (sep_err[polar_exclude_cartesian_indices]/sep[polar_exclude_cartesian_indices])**2 612 | loss_pa = (torch.atan2(torch.sin(pa_torch[:, polar_exclude_cartesian_indices]-pa[polar_exclude_cartesian_indices]), torch.cos(pa_torch[:, polar_exclude_cartesian_indices]-pa[polar_exclude_cartesian_indices])))**2 / pa_err[polar_exclude_cartesian_indices]**2 613 | 614 | loss_raoff = (raoff_torch[:, cartesian_indices] - raoff[cartesian_indices])**2 / raoff_err[cartesian_indices]**2 615 | loss_decoff = (deoff_torch[:, cartesian_indices] - decoff[cartesian_indices])**2 / decoff_err[cartesian_indices]**2 616 | 617 | if target == 'betapic': 618 | loss_prior = (plx - 51.44)**2 / 0.12**2 + (mtot - 1.75)**2 / 0.05**2 619 | elif target == 'GJ504': 620 | loss_prior = (plx - 56.95)**2 / 0.26**2 + (mtot - 1.22)**2 / 0.08**2 621 | 622 | 623 | # Loss function for 
generative model 624 | if coordinate_type == 'polar': 625 | loss_orig = (0.5* torch.sum(loss_sep, -1) + 0.5 * torch.sum(loss_pa, -1) + 0.5 * loss_prior) + logprob 626 | elif coordinate_type == 'cartesian': 627 | loss_orig = (0.5* torch.sum(loss_raoff, -1) + 0.5 * torch.sum(loss_decoff, -1) + 0.5 * loss_prior) + logprob 628 | elif coordinate_type == 'all': 629 | loss_orig = (0.5* torch.sum(loss_sep, -1) + 0.5 * torch.sum(loss_pa, -1) + \ 630 | 0.5* torch.sum(loss_raoff, -1) + 0.5 * torch.sum(loss_decoff, -1) + 0.5 * loss_prior) + logprob 631 | # rej_prob = n_batch * nn.Softmax(dim=0)(-(1-alpha_divergence)*loss_orig) 632 | 633 | 634 | # importance_weights = nn.Softmax(dim=0)(-loss_orig) 635 | # importance_weights = torch.exp(-loss_orig) 636 | importance_weights = -loss_orig 637 | 638 | ind = np.arange(n_batch) 639 | 640 | else: 641 | ind = np.arange(n_batch) 642 | 643 | 644 | 645 | orbit_params2 = np.concatenate([sma[ind].unsqueeze(-1).detach().cpu().numpy(), ecc[ind].unsqueeze(-1).detach().cpu().numpy(), 646 | 180/np.pi*inc[ind].unsqueeze(-1).detach().cpu().numpy(), 647 | (180/np.pi*aop[ind].unsqueeze(-1).detach().cpu().numpy()), 648 | 180/np.pi*pan[ind].unsqueeze(-1).detach().cpu().numpy(), (tau[ind].unsqueeze(-1).detach().cpu().numpy()), 649 | plx[ind].unsqueeze(-1).detach().cpu().numpy(), mtot[ind].unsqueeze(-1).detach().cpu().numpy()], -1) 650 | 651 | if rejsamp_flag: 652 | orbit_params2 = np.concatenate([orbit_params2, importance_weights[ind].detach().cpu().numpy().reshape((-1, 1))], 1) 653 | if k == 0: 654 | orbit_params_all2 = np.array(orbit_params2) 655 | 656 | else: 657 | orbit_params_all2 = np.concatenate([orbit_params_all2, orbit_params2], 0) 658 | 659 | return orbit_params_all2 660 | # corner.corner(orbit_params_all1, labels=['sma', 'tau', 'aop', 'pan', 'inc', 'ecc', 'mtot'], bins=50, quantiles=[0.16, 0.84]) 661 | 662 | 663 | 664 | params_generator.load_state_dict(torch.load(save_path+'/generativemodelbest_'+coordinate_type+'_'+'RealNVP'+'_flow{}'.format(n_flow))) 665 | params_generator.eval() 666 | 667 | orbit_params_all2 = Gen_samples2(params_generator, rejsamp_flag=False, n_concat=100, alpha_divergence=alpha_divergence, coordinate_type=coordinate_type) 668 | # # corner.corner(orbit_params_all2, labels=['sma', 'ecc', 'inc', 'aop', 'pan', 'tau', 'plx', 'mtot'], bins=50, quantiles=[0.16, 0.84]) 669 | 670 | # corner.corner(orbit_params_all2, labels=['sma', 'ecc', 'inc', 'aop', 'pan', 'tau', 'plx', 'mtot'], bins=50, quantiles=[0.16, 0.84], 671 | # range=[[8, 40], [0.0, 0.9], [85.5, 93.0], [0, 360], [29.6, 33.0], [0.0, 1.0], [50.8, 52.1], [1.5, 2.0]]) 672 | 673 | np.save(save_path+'/'+target+'_postsamples_norej_alpha{}.npy'.format(alpha_divergence), orbit_params_all2) 674 | 675 | 676 | 677 | orbit_params_all2 = Gen_samples2(params_generator, rejsamp_flag=True, n_concat=100, alpha_divergence=alpha_divergence, coordinate_type=coordinate_type) 678 | # corner.corner(orbit_params_all2, labels=['sma', 'ecc', 'inc', 'aop', 'pan', 'tau', 'plx', 'mtot'], bins=50, quantiles=[0.16, 0.84], 679 | # range=[[8, 40], [0.0, 0.9], [85.5, 93.0], [0, 360], [29.6, 33.0], [0.0, 1.0], [50.8, 52.1], [1.5, 2.0]]) 680 | 681 | # np.save(save_path+'/'+target+'_postsamples_rej_alpha{}.npy'.format(alpha_divergence), orbit_params_all2) 682 | # np.save(save_path+'/'+target+'_postsamples_rej2_alpha{}.npy'.format(alpha_divergence), orbit_params_all2) 683 | np.save(save_path+'/'+target+'_postsamples_importance_alpha{}.npy'.format(alpha_divergence), orbit_params_all2) 684 | 685 | 686 | # corner.corner(samples, 
labels=['sma', 'ecc', 'inc', 'aop', 'pan', 'tau', 'plx', 'mtot'], bins=50, quantiles=[0.16, 0.84], 687 | # range=[[8, 40], [0.0, 0.9], [85.5, 93.0], [0, 360], [29.6, 33.0], [0.0, 1.0], [50.8, 52.1], [1.5, 2.0]]) 688 | 689 | # corner.corner(orbit_params_all2, labels=['sma', 'ecc', 'inc', 'aop', 'pan', 'tau', 'plx', 'mtot'], bins=50, quantiles=[0.16, 0.84]) 690 | 691 | # corner.corner(gmm_sample.detach().cpu().numpy(), quantiles=[0.16, 0.84]) 692 | # corner.corner(z_sample.detach().cpu().numpy(), quantiles=[0.16, 0.84]) 693 | 694 | 695 | for i in range(1, (n_epoch//decay_rate)+1): 696 | 697 | params_generator.load_state_dict(torch.load(save_path+'/generativemodel_loop{}_'.format(i)+coordinate_type+'_'+'RealNVP'+'_flow{}'.format(n_flow))) 698 | params_generator.eval() 699 | 700 | orbit_params_all2 = Gen_samples2(params_generator, rejsamp_flag=False, n_concat=100, alpha_divergence=alpha_divergence, coordinate_type=coordinate_type) 701 | np.save(save_path+'/'+target+'_postsamples_norej_alpha{}_loop{}.npy'.format(alpha_divergence, i), orbit_params_all2) 702 | 703 | orbit_params_all2 = Gen_samples2(params_generator, rejsamp_flag=True, n_concat=100, alpha_divergence=alpha_divergence, coordinate_type=coordinate_type) 704 | # np.save(save_path+'/'+target+'_postsamples_rej_alpha{}_loop{}.npy'.format(alpha_divergence, i), orbit_params_all2) 705 | # np.save(save_path+'/'+target+'_postsamples_rej2_alpha{}_loop{}.npy'.format(alpha_divergence, i), orbit_params_all2) 706 | np.save(save_path+'/'+target+'_postsamples_importance_alpha{}_loop{}.npy'.format(alpha_divergence, i), orbit_params_all2) 707 | --------------------------------------------------------------------------------
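
Usage note for the saved posterior samples: Gen_samples2 returns an (N, 8) array ordered [sma, ecc, inc, aop, pan, tau, plx, mtot] (inc/aop/pan converted to degrees), with one extra column of log importance weights (-loss_orig) appended when rejsamp_flag=True. Below is a minimal sketch of loading and corner-plotting such a file; it assumes the default --save_path ('./checkpoint/orbit_beta_pic_b/cartesian/alpha1') and --alpha_divergence 1.0 were used, so the exact .npy filename is only illustrative and should be adjusted to the run actually trained.

import numpy as np
import matplotlib.pyplot as plt
import corner

# Assumed run: default --save_path and --alpha_divergence 1.0 of DPIx_orbit.py (adjust to your run).
save_path = './checkpoint/orbit_beta_pic_b/cartesian/alpha1'
samples = np.load(save_path + '/betapic_postsamples_norej_alpha1.0.npy')

# Columns follow the ordering used in Gen_samples2; the angles are already in degrees.
labels = ['sma', 'ecc', 'inc', 'aop', 'pan', 'tau', 'plx', 'mtot']
corner.corner(samples[:, :8], labels=labels, bins=50, quantiles=[0.16, 0.84])
plt.show()

For the *_postsamples_importance_* files, the last column holds the unnormalized log importance weights; one option is to exponentiate them (after subtracting their maximum for numerical stability) and pass the result as corner's weights argument to obtain importance-reweighted marginals.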