├── context_encoder
│   ├── __init__.py
│   └── encoders.py
├── train_model.sh
├── eval_model.sh
├── metrics.py
├── README.md
├── requirements.txt
├── hparams.yaml
├── loss.py
├── data_continuous_EHT.py
├── mlp.py
├── gICLEAN.py
├── main.py
└── data_ehtim_cont.py

/context_encoder/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/train_model.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 |     --exp_name Galaxy10-DEC-cont/transformer/mlp_8_layer/image_loss-NF_128 \
3 |     --ngpus 2 \
4 |     --yaml_file "Your path" \
5 |     --loss_type spectral \
6 |     --num_fourier 128 \
7 |     --input_size 256 \
8 |     --dataset Galaxy10_DECals \
9 |     --data_path_cont ../data/eht_cont_200im_Galaxy10_DECals_full.h5 \
10 |     --data_path_imgs ../data/Galaxy10_DECals.h5 \
11 |     --dataset_path ../data/eht_grid_128FC_200im_Galaxy10_DECals_full.h5
--------------------------------------------------------------------------------
/eval_model.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 |     --eval \
3 |     --exp_name Galaxy10-DEC-cont/transformer/mlp_8_layer/image_loss-NF_128 \
4 |     --yaml_file "Your path" \
5 |     --model_checkpoint "Your path" \
6 |     --loss_type spectral \
7 |     --num_fourier 128 \
8 |     --input_size 256 \
9 |     --dataset Galaxy10_DECals \
10 |     --data_path_cont ../data/eht_cont_200im_Galaxy10_DECals_full.h5 \
11 |     --data_path_imgs ../data/Galaxy10_DECals.h5 \
12 |     --dataset_path ../data/eht_grid_128FC_200im_Galaxy10_DECals_full.h5
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | from skimage.metrics import mean_squared_error as mse, peak_signal_noise_ratio as psnr, structural_similarity as ssim
5 | 
6 | 
7 | def compute_metrics(img1, img2):
8 |     mse_val = mse(img1, img2)
9 |     psnr_val = psnr(img1, img2)
10 |     ssim_val = ssim(img1, img2)  # images are loaded as single-channel grayscale below
11 | 
12 |     return mse_val, psnr_val, ssim_val
13 | 
14 | 
15 | mse_list = []
16 | psnr_list = []
17 | ssim_list = []
18 | 
19 | for i in range(0, 5000):
20 | 
21 |     rec_img_path = '../test_res1/Galaxy10-DEC-cont/transformer/mlp_8_layer/image_loss-NF_128/'+str(i)+'/recon_image.png'
22 |     image_img_path = '../test_res1/Galaxy10-DEC-cont/transformer/mlp_8_layer/image_loss-NF_128/'+str(i)+'/image.png'
23 | 
24 |     # read both images as grayscale (cv2.COLOR_RGB2GRAY is a color-conversion
25 |     # code, not an imread flag, so it was replaced by cv2.IMREAD_GRAYSCALE)
26 |     asamples_mean_img = cv2.imread(rec_img_path, cv2.IMREAD_GRAYSCALE)
27 |     image_img = cv2.imread(image_img_path, cv2.IMREAD_GRAYSCALE)
28 | 
29 |     mse_val, psnr_val, ssim_val = compute_metrics(asamples_mean_img, image_img)
30 | 
31 |     mse_list.append(mse_val)
32 |     psnr_list.append(psnr_val)
33 |     ssim_list.append(ssim_val)
34 | 
35 | mse_mean = np.mean(mse_list)
36 | mse_std = np.std(mse_list)
37 | psnr_mean = np.mean(psnr_list)
38 | psnr_std = np.std(psnr_list)
39 | ssim_mean = np.mean(ssim_list)
40 | ssim_std = np.std(ssim_list)
41 | 
42 | print("MSE", mse_mean, mse_std)
43 | print("PSNR", psnr_mean, psnr_std)
44 | print("SSIM", ssim_mean, ssim_std)
45 | 
46 | 
47 | def LFD(recon_freq, real_freq):
48 |     # log frequency distance between two visibility grids whose axis 1
49 |     # holds the real and imaginary parts
50 |     tmp = (recon_freq - real_freq) ** 2
51 |     freq_distance = tmp[:,0,:,:] + tmp[:,1,:,:]
52 |     LFD = np.log(freq_distance + 1)
53 |     return LFD
54 | 
55 | 
56 | data_list_1 = []
57 | data_list_2 = []
58 | 
59 | for i in range(0, 5000):
60 |     folder_name = "../test_res1/Galaxy10-DEC-cont/transformer/mlp_8_layer/image_loss-NF_128/"+str(i)
61 |     file_path = os.path.join(folder_name, "GT_vis.npy")
62 |     data = np.load(file_path)
63 |     data_list_1.append(data)
64 | 
65 |     file_path = os.path.join(folder_name, "recon_vis.npy")
66 |     data = np.load(file_path)
67 |     data_list_2.append(data)
68 | 
69 | result_1 = np.stack(data_list_1, axis=0)
70 | result_2 = np.stack(data_list_2, axis=0)
71 | 
72 | res = LFD(result_1, result_2)
73 | res_vector = np.mean(res, axis=(1, 2))
74 | 
75 | mean = np.mean(res_vector)
76 | std_dev = np.std(res_vector)
77 | 
78 | print("LFD_mean:", mean)
79 | print("LFD_std:", std_dev)
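As a quick sanity check of the `LFD` helper, the sketch below runs it on synthetic arrays; the `(N, 2, H, W)` layout (axis 1 = real/imaginary) is inferred from the indexing above, and the random data is purely illustrative. Note that `metrics.py` runs its evaluation loops at import time, so paste `LFD` into a session rather than importing the module:

```python
import numpy as np

# hypothetical stand-ins for the stacked GT_vis.npy / recon_vis.npy arrays
gt_vis = np.random.randn(5, 2, 64, 64)
recon_vis = gt_vis + 0.01 * np.random.randn(5, 2, 64, 64)

res = LFD(recon_vis, gt_vis)             # -> (5, 64, 64) per-pixel log distances
res_vector = np.mean(res, axis=(1, 2))   # -> one LFD value per image
print(res_vector.shape, res_vector.mean())
```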
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## PolarRec: Improving Radio Interferometric Data Reconstruction with Polar Coordinates
2 | 
3 | ### Abstract
4 | In radio astronomy, visibility data, which are measurements of wave signals from radio telescopes, are transformed into images for observation of distant celestial objects. However, these resultant images usually contain both real sources and artifacts, due to signal sparsity and other factors. One way to obtain cleaner images is to reconstruct samples into dense forms before imaging. Unfortunately, existing reconstruction methods often miss some components of visibility in the frequency domain, so blurred object edges and persistent artifacts remain in the images. Furthermore, the computation overhead is high on irregular visibility samples due to data skew. To address these problems, we propose PolarRec, a transformer-encoder-conditioned reconstruction pipeline with visibility samples converted into the polar coordinate representation. This representation matches the way in which radio telescopes observe a celestial area as the Earth rotates. As a result, visibility samples are distributed more uniformly in the polar system than in Cartesian space. We therefore use the radial distance in the loss function to help reconstruct the complete visibility effectively. We also group visibility samples by their polar angles and propose a group-based encoding scheme to improve efficiency. Our experiments demonstrate that PolarRec markedly improves imaging results by faithfully reconstructing all frequency components in the visibility domain while significantly reducing the computation cost of visibility data encoding.
5 | 
6 | 
7 | ### Run the demo
8 | 
9 | #### Setup the conda environment
10 | Set up the conda environment using the `requirements.txt` file.
11 | 
12 | 
13 | #### Datasets
14 | Please find the datasets at https://astronn.readthedocs.io and in the project of Wu et al. [1].
15 | 
16 | 
17 | #### Modify the model and data path parameters within the bash scripts
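For reference, these are the path flags already present in `train_model.sh` (`eval_model.sh` takes the same flags, plus `--model_checkpoint`); point them at your local copies of the data and the hparams YAML:

```
--yaml_file "Your path"
--data_path_cont ../data/eht_cont_200im_Galaxy10_DECals_full.h5
--data_path_imgs ../data/Galaxy10_DECals.h5
--dataset_path ../data/eht_grid_128FC_200im_Galaxy10_DECals_full.h5
```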
18 | 
19 | #### Train Model
20 | Run the `train_model.sh` script from the command line:
21 | ```
22 | sh ./train_model.sh
23 | ```
24 | 
25 | #### Inference using the trained model
26 | Modify the paths in the `eval_model.sh` script.
27 | 
28 | Run the `eval_model.sh` script from the command line:
29 | 
30 | ```
31 | sh ./eval_model.sh
32 | ```
33 | The results will be saved in the `../test_res1` folder, including the reconstructed visibilities and the resulting images.
34 | 
35 | Evaluate the results with SSIM, PSNR, and LFD:
36 | 
37 | ```
38 | python metrics.py
39 | ```
40 | 
41 | 
42 | 
43 | **Reference**
44 | 
45 | [1] Wu B., Liu C., Eckart B., et al. Neural interferometry: Image reconstruction from astronomical interferometers using transformer-conditioned neural fields. In Proceedings of the AAAI Conference on Artificial Intelligence, 36(3): 2685–2693, 2022.
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # This file may be used to create an environment using:
2 | # $ conda create --name <env> --file <this file>
3 | # platform: linux-64
4 | _libgcc_mutex=0.1=main
5 | absl-py=0.12.0=pypi_0
6 | aiohttp=3.7.4.post0=pypi_0
7 | appdirs=1.4.4=pypi_0
8 | astroid=2.6.6=py37h06a4308_0
9 | astropy=4.2=pypi_0
10 | async-timeout=3.0.1=pypi_0
11 | attrs=20.3.0=pypi_0
12 | autopep8=1.5.7=pyhd3eb1b0_0
13 | backcall=0.2.0=pypi_0
14 | brotlipy=0.7.0=py37h7b6447c_1000
15 | ca-certificates=2021.7.5=h06a4308_1
16 | cached-property=1.5.2=pypi_0
17 | cachetools=4.2.2=pypi_0
18 | certifi=2021.5.30=py37h06a4308_0
19 | cffi=1.14.3=py37he30daa8_0
20 | chardet=3.0.4=py37_1003
21 | click=8.0.1=pypi_0
22 | configparser=5.0.2=pypi_0
23 | cryptography=3.1.1=py37h1ba5d50_0
24 | cycler=0.10.0=pypi_0
25 | decorator=4.4.2=pypi_0
26 | docker-pycreds=0.4.0=pypi_0
27 | ehtim=1.2.2=pypi_0
28 | einops=0.3.0=pypi_0
29 | ephem=3.7.7.1=pypi_0
30 | fftw=3.3.8=nompi_hfc0cae8_1114
31 | freetype=2.10.4=h5ab3b9f_0
32 | fsspec=2021.7.0=pypi_0
33 | future=0.18.2=pypi_0
34 | gitdb=4.0.7=pypi_0
35 | gitpython=3.1.18=pypi_0
36 | google-auth=1.30.0=pypi_0
37 | google-auth-oauthlib=0.4.4=pypi_0
38 | grpcio=1.37.1=pypi_0
39 | h5py=3.1.0=pypi_0
40 | idna=2.10=py_0
41 | imageio=2.9.0=pypi_0
42 | importlib-metadata=4.0.1=pypi_0
43 | ipdb=0.13.4=pypi_0
44 | ipython=7.19.0=pypi_0
45 | ipython-genutils=0.2.0=pypi_0
46 | isort=5.9.3=pyhd3eb1b0_0
47 | jedi=0.18.0=pypi_0
48 | joblib=1.0.1=pypi_0
49 | jpeg=9b=habf39ab_1
50 | jsonargparse=3.19.4=pypi_0
51 | kiwisolver=1.3.1=pypi_0
52 | lazy-object-proxy=1.6.0=py37h27cfd23_0
53 | lcms2=2.11=h396b838_0
54 | ld_impl_linux-64=2.33.1=h53a641e_7
55 | libblas=3.9.0=7_openblas
56 | libcblas=3.9.0=7_openblas
57 | libedit=3.1.20191231=h14c3975_1
58 | libffi=3.3=he6710b0_2
59 | libgcc-ng=9.1.0=hdf63c60_0
60 | libgfortran-ng=7.5.0=hae1eefd_17
61 | libgfortran4=7.5.0=hae1eefd_17
62 | liblapack=3.9.0=7_openblas
63 | libopenblas=0.3.12=pthreads_hb3c22a3_1
64 | libpng=1.6.37=hbc83047_0
65 | libstdcxx-ng=9.1.0=hdf63c60_0
66 | libtiff=4.1.0=h2733197_1
67 | lz4-c=1.9.2=heb0550a_3
68 | mako=1.1.6=pypi_0
69 | markdown=3.3.4=pypi_0
70 | markupsafe=2.1.0=pypi_0
71 | matplotlib=3.3.3=pypi_0
72 | mccabe=0.6.1=py37_1
73 | multidict=5.1.0=pypi_0
74 | ncurses=6.2=he6710b0_1
75 | networkx=2.5=pypi_0
76 | nfft=3.2.4=hf8c457e_1000
77 | numpy=1.19.5=pypi_0
78 | oauthlib=3.1.0=pypi_0
79 | olefile=0.46=py37_0
80 | openssl=1.1.1l=h7f8727e_0
81 | packaging=20.9=pypi_0
82 | pandas=1.2.0=pypi_0
83 | parso=0.8.1=pypi_0
84 | pathtools=0.1.2=pypi_0
85 |
pexpect=4.8.0=pypi_0 86 | pickleshare=0.7.5=pypi_0 87 | pillow=8.1.0=pypi_0 88 | pip=20.3.3=py37h06a4308_0 89 | platformdirs=2.5.1=pypi_0 90 | promise=2.3=pypi_0 91 | prompt-toolkit=3.0.10=pypi_0 92 | protobuf=3.15.8=pypi_0 93 | psutil=5.8.0=pypi_0 94 | ptyprocess=0.7.0=pypi_0 95 | pyasn1=0.4.8=pypi_0 96 | pyasn1-modules=0.2.8=pypi_0 97 | pycodestyle=2.7.0=pyhd3eb1b0_0 98 | pycparser=2.20=py_2 99 | pycuda=2021.1=pypi_0 100 | pydeprecate=0.3.1=pypi_0 101 | pyerfa=1.7.1.1=pypi_0 102 | pyfits=3.5=pypi_0 103 | pygments=2.7.3=pypi_0 104 | pylint=2.9.6=py37h06a4308_1 105 | pynfft=1.3.2=py37h161383b_1003 106 | pyopenssl=19.1.0=py_1 107 | pyparsing=2.4.7=pypi_0 108 | pysocks=1.7.1=py37_1 109 | python=3.7.9=h7579374_0 110 | python-dateutil=2.8.1=pypi_0 111 | python_abi=3.7=1_cp37m 112 | pytools=2022.1=pypi_0 113 | pytorch-lightning=1.4.4=pypi_0 114 | pytz=2020.5=pypi_0 115 | pywavelets=1.1.1=pypi_0 116 | pyyaml=5.3.1=pypi_0 117 | readline=8.0=h7b6447c_0 118 | requests=2.24.0=py_0 119 | requests-oauthlib=1.3.0=pypi_0 120 | rsa=4.7.2=pypi_0 121 | scikit-image=0.18.1=pypi_0 122 | scikit-learn=0.24.2=pypi_0 123 | scipy=1.6.0=pypi_0 124 | sentry-sdk=1.3.1=pypi_0 125 | setuptools=51.0.0=py37h06a4308_2 126 | shortuuid=1.0.1=pypi_0 127 | six=1.15.0=py_0 128 | sklearn=0.0=pypi_0 129 | smmap=4.0.0=pypi_0 130 | sqlite=3.33.0=h62c20be_0 131 | subprocess32=3.5.4=pypi_0 132 | tensorboard=2.4.1=pypi_0 133 | tensorboard-plugin-wit=1.8.0=pypi_0 134 | threadpoolctl=2.1.0=pypi_0 135 | tifffile=2021.3.17=pypi_0 136 | tk=8.6.10=hbc83047_0 137 | toml=0.10.2=pyhd3eb1b0_0 138 | torch=1.8.1=pypi_0 139 | torchaudio=0.8.1=pypi_0 140 | torchkbnufft=1.2.0.post3=pypi_0 141 | torchmetrics=0.5.0=pypi_0 142 | torchvision=0.10.1=pypi_0 143 | tqdm=4.60.0=pypi_0 144 | traitlets=5.0.5=pypi_0 145 | typed-ast=1.4.3=py37h7f8727e_1 146 | typing-extensions=3.7.4.3=pypi_0 147 | typing_extensions=3.10.0.2=pyh06a4308_0 148 | urllib3=1.25.11=py_0 149 | wandb=0.12.0=pypi_0 150 | wcwidth=0.2.5=pypi_0 151 | werkzeug=1.0.1=pypi_0 152 | wheel=0.36.2=pyhd3eb1b0_0 153 | wrapt=1.12.1=py37h7b6447c_1 154 | xz=5.2.5=h7b6447c_0 155 | yarl=1.6.3=pypi_0 156 | zipp=3.4.1=pypi_0 157 | zlib=1.2.11=h7b6447c_3 158 | zstd=1.4.4=h0b5b093_3 159 | -------------------------------------------------------------------------------- /hparams.yaml: -------------------------------------------------------------------------------- 1 | L_embed: 128 2 | args: !!python/object:argparse.Namespace 3 | L_embed: 128 4 | accelerator: null 5 | accumulate_grad_batches: 1 6 | amp_backend: native 7 | amp_level: O2 8 | args: !!python/object:argparse.Namespace 9 | L_embed: 128 10 | accelerator: null 11 | accumulate_grad_batches: 1 12 | amp_backend: native 13 | amp_level: O2 14 | auto_lr_find: false 15 | auto_scale_batch_size: false 16 | auto_select_gpus: false 17 | batch_size: 4 18 | benchmark: false 19 | check_val_every_n_epoch: 1 20 | checkpoint_callback: true 21 | data_path_cont: /dataset/galaxy10/eht_cont_200im_Galaxy10_DECals_full.h5 22 | data_path_imgs: /dataset/galaxy10/Galaxy10_DECals.h5 23 | dataset: Galaxy10_DECals 24 | dataset_path: /dataset/galaxy10/eht_grid_128FC_200im_Galaxy10_DECals_full.h5 25 | default_root_dir: ./logs/Galaxy10-DEC-cont/transformer/mlp_8_layer/image_loss-NF_128 26 | deterministic: false 27 | devices: null 28 | distributed_backend: null 29 | exp_name: Galaxy10-DEC-cont/transformer/mlp_8_layer/image_loss-NF_128 30 | fast_dev_run: false 31 | flush_logs_every_n_steps: 100 32 | gpus: null 33 | gradient_clip_algorithm: norm 34 | gradient_clip_val: 0.0 35 | 
input_encoding: fourier 36 | input_size: 256 37 | ipus: null 38 | learning_rate: 0.0001 39 | limit_predict_batches: 1.0 40 | limit_test_batches: 1.0 41 | limit_train_batches: 1.0 42 | limit_val_batches: 1.0 43 | log_every_n_steps: 50 44 | log_gpu_memory: null 45 | logger: true 46 | loss_type: image 47 | m_epochs: 100 48 | max_epochs: null 49 | max_steps: null 50 | max_time: null 51 | min_epochs: null 52 | min_steps: null 53 | mlp_hidden_dim: 256 54 | mlp_layers: 8 55 | model_checkpoint: '' 56 | move_metrics_to_cpu: false 57 | multiple_trainloader_mode: max_size_cycle 58 | num_fourier: 128 59 | num_nodes: 1 60 | num_processes: 1 61 | num_sanity_val_steps: 2 62 | num_workers: 8 63 | overfit_batches: 0.0 64 | plugins: null 65 | precision: 32 66 | prepare_data_per_node: true 67 | process_position: 0 68 | profiler: null 69 | progress_bar_refresh_rate: null 70 | reload_dataloaders_every_epoch: false 71 | reload_dataloaders_every_n_epochs: 0 72 | replace_sampler_ddp: true 73 | resume_from_checkpoint: null 74 | scale_loss_image: 1.0 75 | sigma: 5.0 76 | stochastic_weight_avg: false 77 | sync_batchnorm: false 78 | terminate_on_nan: false 79 | tpu_cores: null 80 | track_grad_norm: -1 81 | truncated_bptt_steps: null 82 | val_check_interval: 1.0 83 | val_fldr: ./logs/Galaxy10-DEC-cont/transformer/mlp_8_layer/image_loss-NF_128 84 | weights_save_path: null 85 | weights_summary: top 86 | auto_lr_find: false 87 | auto_scale_batch_size: false 88 | auto_select_gpus: false 89 | batch_size: 4 90 | benchmark: false 91 | check_val_every_n_epoch: 1 92 | checkpoint_callback: true 93 | data_path_cont: ../data/eht_cont_200im_Galaxy10_DECals_full.h5 94 | data_path_imgs: ../data/Galaxy10_DECals.h5 95 | dataset: Galaxy10_DECals 96 | dataset_path: ../data/eht_grid_128FC_200im_Galaxy10_DECals_full.h5 97 | default_root_dir: null 98 | deterministic: false 99 | devices: null 100 | distributed_backend: null 101 | eval: false 102 | exp_name: Galaxy10-DEC-cont/transformer/mlp_8_layer/image_loss-NF_128 103 | fast_dev_run: false 104 | flush_logs_every_n_steps: 100 105 | gpus: null 106 | gradient_clip_algorithm: norm 107 | gradient_clip_val: 0.0 108 | hidden_dims: 109 | - 256 110 | - 256 111 | - 256 112 | - 256 113 | - 256 114 | - 256 115 | - 256 116 | input_encoding: fourier 117 | input_size: 256 118 | ipus: null 119 | kl_coeff: 100.0 120 | latent_dim: 1024 121 | learning_rate: 0.0001 122 | limit_predict_batches: 1.0 123 | limit_test_batches: 1.0 124 | limit_train_batches: 1.0 125 | limit_val_batches: 1.0 126 | log_every_n_steps: 50 127 | log_gpu_memory: null 128 | logger: true 129 | loss_type: spectral 130 | m_epochs: 200 131 | max_epochs: null 132 | max_steps: null 133 | max_time: null 134 | min_epochs: null 135 | min_steps: null 136 | mlp_hidden_dim: 256 137 | mlp_layers: 8 138 | model_checkpoint: '' 139 | move_metrics_to_cpu: false 140 | multiple_trainloader_mode: max_size_cycle 141 | ngpu: 8 142 | ngpus: 143 | - 0 144 | num_fourier: 128 145 | num_fourier_coeff: 128 146 | num_nodes: 1 147 | num_processes: 1 148 | num_sanity_val_steps: 2 149 | num_workers: 16 150 | overfit_batches: 0.0 151 | plugins: null 152 | precision: 32 153 | prepare_data_per_node: true 154 | process_position: 0 155 | profiler: null 156 | progress_bar_refresh_rate: null 157 | reload_dataloaders_every_epoch: false 158 | reload_dataloaders_every_n_epochs: 0 159 | replace_sampler_ddp: true 160 | resume_from_checkpoint: null 161 | scale_loss_image: 1.0 162 | sigma: 5.0 163 | stochastic_weight_avg: false 164 | sync_batchnorm: false 165 | terminate_on_nan: 
false
166 | tpu_cores: null
167 | track_grad_norm: -1
168 | truncated_bptt_steps: null
169 | val_check_interval: 1.0
170 | val_fldr: ./val_fldr-test
171 | weights_save_path: null
172 | weights_summary: top
173 | yaml_file: ../data/Galaxy10-DEC-cont/transformer/mlp_8_layer/image_loss-NF_128/hparams.yaml
174 | batch_size: 32
175 | hidden_dims:
176 | - 256
177 | - 256
178 | - 256
179 | - 256
180 | - 256
181 | - 256
182 | - 256
183 | input_encoding: fourier
184 | input_size: 256
185 | kl_coeff: 100.0
186 | latent_dim: 1024
187 | learning_rate: 0.0001
188 | model_checkpoint: ''
189 | ngpu: 2
190 | num_fourier_coeff: 128
191 | sigma: 5.0
192 | 
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | # version adaptation for PyTorch > 1.7.1
5 | IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.'))) > (1, 7, 1)
6 | if IS_HIGH_VERSION:
7 |     import torch.fft
8 | 
9 | 
10 | class FocalFrequencyLoss(nn.Module):
11 |     """Focal frequency loss, extended here with a radial-distance weighting term.
12 | 
13 |     Ref:
14 |     Focal Frequency Loss for Image Reconstruction and Synthesis. In ICCV 2021.
15 | 
16 | 
17 |     Args:
18 |         loss_weight (float): weight for focal frequency loss. Default: 1.0
19 |         alpha (float): the scaling factor alpha of the spectrum weight matrix for flexibility. Default: 1.0
20 |         beta (float): the exponent applied to the radial-distance weights in loss_formulation. Default: 1.0
21 |         patch_factor (int): the factor to crop image patches for patch-based focal frequency loss. Default: 1
22 |         ave_spectrum (bool): whether to use minibatch average spectrum. Default: False
23 |         log_matrix (bool): whether to adjust the spectrum weight matrix by logarithm. Default: False
24 |         batch_matrix (bool): whether to calculate the spectrum weight matrix using batch-based statistics. Default: False
25 |     """
26 | 
27 |     def __init__(self, loss_weight=1.0, alpha=1.0, beta=1.0, patch_factor=1, ave_spectrum=False, log_matrix=False, batch_matrix=False):
28 |         super(FocalFrequencyLoss, self).__init__()
29 |         self.loss_weight = loss_weight
30 |         self.alpha = alpha
31 |         self.beta = beta
32 |         self.patch_factor = patch_factor
33 |         self.ave_spectrum = ave_spectrum
34 |         self.log_matrix = log_matrix
35 |         self.batch_matrix = batch_matrix
36 | 
37 |     def tensor2freq(self, x):
38 |         # crop image patches
39 |         patch_factor = self.patch_factor
40 |         _, _, h, w = x.shape
41 |         assert h % patch_factor == 0 and w % patch_factor == 0, (
42 |             'Image height and width should be divisible by the patch factor')
43 |         patch_list = []
44 |         patch_h = h // patch_factor
45 |         patch_w = w // patch_factor
46 |         for i in range(patch_factor):
47 |             for j in range(patch_factor):
48 |                 patch_list.append(x[:, :, i * patch_h:(i + 1) * patch_h, j * patch_w:(j + 1) * patch_w])
49 | 
50 |         # stack to patch tensor
51 |         y = torch.stack(patch_list, 1)
52 | 
53 |         # perform 2D DFT (real-to-complex, orthonormalization)
54 |         if IS_HIGH_VERSION:
55 |             freq = torch.fft.fft2(y, norm='ortho')
56 |             freq = torch.stack([freq.real, freq.imag], -1)
57 |         else:
58 |             freq = torch.rfft(y, 2, onesided=False, normalized=True)
59 |         return freq
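A quick shape check for `tensor2freq` (illustrative only; run from the repo root so `loss.py` is importable): the patch dimension is stacked at axis 1 and the real/imaginary parts end up in the last axis.

```python
import torch
from loss import FocalFrequencyLoss

ffl = FocalFrequencyLoss(patch_factor=1)
x = torch.randn(2, 3, 64, 64)    # (N, C, H, W) image batch
freq = ffl.tensor2freq(x)        # 2D orthonormal FFT per patch
print(freq.shape)                # torch.Size([2, 1, 3, 64, 64, 2])
```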
60 | 
61 |     def loss_formulation(self, recon_freq, real_freq, matrix=None):
62 |         # spectrum weight matrix
63 |         if matrix is not None:
64 |             # if the matrix is predefined
65 |             weight_matrix = matrix.detach()
66 |         else:
67 |             # if the matrix is calculated online: continuous, dynamic, based on current Euclidean distance
68 |             matrix_tmp = (recon_freq - real_freq) ** 2
69 |             matrix_tmp = torch.sqrt(matrix_tmp[..., 0] + matrix_tmp[..., 1]) ** self.alpha
70 | 
71 |             # whether to adjust the spectrum weight matrix by logarithm
72 |             if self.log_matrix:
73 |                 matrix_tmp = torch.log(matrix_tmp + 1.0)
74 | 
75 |             # whether to calculate the spectrum weight matrix using batch-based statistics
76 |             if self.batch_matrix:
77 |                 matrix_tmp = matrix_tmp / matrix_tmp.max()
78 |             else:
79 |                 matrix_tmp = matrix_tmp / matrix_tmp.max(-1).values.max(-1).values[:, :, :, None, None]
80 | 
81 |             matrix_tmp[torch.isnan(matrix_tmp)] = 0.0
82 |             matrix_tmp = torch.clamp(matrix_tmp, min=0.0, max=1.0)
83 |             weight_matrix = matrix_tmp.clone().detach()
84 | 
85 |         assert weight_matrix.min().item() >= 0 and weight_matrix.max().item() <= 1, (
86 |             'The values of spectrum weight matrix should be in the range [0, 1], '
87 |             'but got Min: %.10f Max: %.10f' % (weight_matrix.min().item(), weight_matrix.max().item()))
88 | 
89 |         height, width = recon_freq.shape[3:5]
90 |         center_y, center_x = height / 2, width / 2
91 |         # create 2D arrays for the x and y coordinates
92 |         y = torch.arange(height)
93 |         x = torch.arange(width)
94 |         # create a grid of (x, y) coordinates
95 |         x_grid, y_grid = torch.meshgrid(x - center_x, y - center_y)
96 |         # calculate the distance of each point from the center
97 |         distance_from_center = torch.sqrt(x_grid**2 + y_grid**2)
98 |         # normalize distances to the range [0, 1]
99 |         normalized_distance = distance_from_center / torch.max(distance_from_center)
100 | 
101 |         weights = (normalized_distance + 1) ** self.beta
102 | 
103 |         batch_size = recon_freq.shape[0]
104 |         # reshape the radial weights for broadcasting
105 |         weights = weights.unsqueeze(0).unsqueeze(0)  # creates a tensor of size (1, 1, height, width)
106 | 
107 |         # repeat along the batch dimension
108 |         weights = weights.repeat(batch_size, 1, 1, 1, 1)  # results in a tensor of size (batch, 1, 1, height, width)
109 | 
110 |         # frequency distance using (squared) Euclidean distance
111 |         tmp = (recon_freq - real_freq) ** 2
112 |         # tmp = abs(recon_freq - real_freq)
113 |         freq_distance = tmp[..., 0] + tmp[..., 1]
114 | 
115 |         # dynamic spectrum weighting (Hadamard product)
116 |         loss = weight_matrix * freq_distance * weights.to(freq_distance.device)
117 |         return torch.mean(loss)
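To make the role of `beta` concrete, this illustrative snippet evaluates the radial weights on their own (mirroring the grid construction above): with `beta = 1`, the grid center (low spatial frequency) gets weight 1 while the farthest corner gets weight 2, so visibility components at a large radial distance are up-weighted, as described in the abstract.

```python
import torch

height = width = 8
y, x = torch.arange(height), torch.arange(width)
x_grid, y_grid = torch.meshgrid(x - width / 2, y - height / 2)
distance = torch.sqrt(x_grid**2 + y_grid**2)
weights = (distance / distance.max() + 1) ** 1.0   # beta = 1
print(weights.min().item(), weights.max().item())  # 1.0 ... 2.0
```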
118 | 
119 |     def forward(self, pred, target, matrix=None, **kwargs):
120 |         """Forward function to calculate focal frequency loss.
121 | 
122 |         Args:
123 |             pred (torch.Tensor): of shape (N, C, H, W, 2). Predicted tensor, already in the visibility/frequency domain (the tensor2freq calls below are bypassed).
124 |             target (torch.Tensor): of shape (N, C, H, W, 2). Target tensor.
125 |             matrix (torch.Tensor, optional): Element-wise spectrum weight matrix.
126 |                 Default: None (If set to None: calculated online, dynamic).
127 |         """
128 |         # pred_freq = self.tensor2freq(pred)
129 |         # target_freq = self.tensor2freq(target)
130 |         pred_freq = pred.unsqueeze(2)
131 |         target_freq = target.unsqueeze(2)
132 | 
133 | 
134 |         # whether to use minibatch average spectrum
135 |         if self.ave_spectrum:
136 |             pred_freq = torch.mean(pred_freq, 0, keepdim=True)
137 |             target_freq = torch.mean(target_freq, 0, keepdim=True)
138 | 
139 |         # calculate focal frequency loss
140 |         return self.loss_formulation(pred_freq, target_freq, matrix) * self.loss_weight
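A minimal usage sketch under the shape convention above, with synthetic 128×128 visibility grids whose last axis holds the real/imaginary parts (run from the repo root):

```python
import torch
from loss import FocalFrequencyLoss

criterion = FocalFrequencyLoss(loss_weight=1.0, alpha=1.0, beta=1.0)
pred = torch.randn(4, 1, 128, 128, 2, requires_grad=True)  # (N, C, H, W, re/im)
target = torch.randn(4, 1, 128, 128, 2)
loss = criterion(pred, target)  # scalar
loss.backward()
print(loss.item())
```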
--------------------------------------------------------------------------------
/data_continuous_EHT.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import h5py
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | from data_ehtim_cont import *
6 | import torch
7 | 
8 | 
9 | def load_h5_uvvis(fpath):
10 |     print('--loading h5 file for eht sparse and dense {u,v,vis_re,vis_im} dataset...')
11 |     with h5py.File(fpath, 'r') as F:
12 |         u_sparse = np.array(F['u_sparse'])
13 |         v_sparse = np.array(F['v_sparse'])
14 |         vis_re_sparse = np.array(F['vis_re_sparse'])
15 |         vis_im_sparse = np.array(F['vis_im_sparse'])
16 |         u_dense = np.array(F['u_dense'])
17 |         v_dense = np.array(F['v_dense'])
18 |         vis_re_dense = np.array(F['vis_re_dense'])
19 |         vis_im_dense = np.array(F['vis_im_dense'])
20 |     print('Done--')
21 |     return u_sparse, v_sparse, vis_re_sparse, vis_im_sparse, u_dense, v_dense, vis_re_dense, vis_im_dense
22 | 
23 | 
24 | def load_h5_uvvis_cont(fpath):
25 |     print('--loading h5 file for eht continuous {u,v,vis_re,vis_im} dataset...')
26 |     with h5py.File(fpath, 'r') as F:
27 |         u_cont = np.array(F['u_cont'])
28 |         v_cont = np.array(F['v_cont'])
29 |         vis_re_cont = np.array(F['vis_re_cont'])
30 |         vis_im_cont = np.array(F['vis_im_cont'])
31 |     print('Done--')
32 |     return u_cont, v_cont, vis_re_cont, vis_im_cont
33 | 
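The loaders only assume the HDF5 keys shown above. As an illustration, a toy file that satisfies `load_h5_uvvis_cont` can be built with synthetic values (the sizes here are hypothetical, chosen to mimic per-image visibility columns):

```python
import h5py
import numpy as np

with h5py.File('toy_cont.h5', 'w') as F:
    n_vis, n_img = 1660, 10          # samples per image, number of images
    F['u_cont'] = np.random.randn(n_vis, 1)
    F['v_cont'] = np.random.randn(n_vis, 1)
    F['vis_re_cont'] = np.random.randn(n_vis, n_img)
    F['vis_im_cont'] = np.random.randn(n_vis, n_img)

u, v, re, im = load_h5_uvvis_cont('toy_cont.h5')
print(u.shape, re.shape)  # (1660, 1) (1660, 10)
```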
34 | 
35 | class EHTIM_Dataset(Dataset):
36 |     '''
37 |     EHT-imaged dataset (load precomputed)
38 |     '''
39 |     def __init__(self,
40 |                  dset_name = 'Galaxy10', # 'MNIST'
41 |                  data_path = '../data/eht_grid_128FC_200im_Galaxy10_DECals_full.h5',
42 |                  data_path_imgs = '../data/Galaxy10_DECals.h5',
43 |                  data_path_cont = '../data/eht_cont_200im_Galaxy10_DECals_full.h5',
44 |                  img_res = 200,
45 |                  pre_normalize = False,
46 |                  ):
47 | 
48 |         # get spectral data
49 |         u_sparse, v_sparse, vis_re_sparse, vis_im_sparse, u_dense, v_dense, vis_re_dense, vis_im_dense = load_h5_uvvis(data_path)
50 |         print(u_sparse.shape, v_sparse.shape, vis_re_sparse.shape, vis_im_sparse.shape, u_dense.shape, v_dense.shape, vis_re_dense.shape, vis_im_dense.shape)
51 | 
52 |         uv_sparse = np.stack((u_sparse.flatten(), v_sparse.flatten()), axis=1)
53 |         uv_dense = np.stack((u_dense.flatten(), v_dense.flatten()), axis=1)
54 |         fourier_resolution = int(len(uv_dense)**(0.5))
55 |         self.fourier_res = fourier_resolution
56 | 
57 |         # rescale uv to (-0.5, 0.5)
58 |         max_base = np.max(uv_sparse)
59 |         uv_dense_scaled = np.rint((uv_dense+max_base) / max_base * (fourier_resolution-1)/2) / (fourier_resolution-1) - 0.5
60 |         self.uv_dense = uv_dense_scaled
61 |         self.vis_re_dense = vis_re_dense
62 |         self.vis_im_dense = vis_im_dense
63 |         # TODO: double check un-scaling if continuous (originally scaled with sparse)
64 |         # should be ok bc dataset generation was scaled to max baseline, so np.max(uv_sparse)=np.max(uv_cont)
65 | 
66 |         # use sparse continuous data
67 |         if data_path_cont:
68 |             print('using sparse continuous visibility data..')
69 |             u_cont, v_cont, vis_re_cont, vis_im_cont = load_h5_uvvis_cont(data_path_cont)
70 |             uv_cont = np.stack((u_cont.flatten(), v_cont.flatten()), axis=1)
71 |             uv_cont_scaled = np.rint((uv_cont+max_base) / max_base * (fourier_resolution-1)/2) / (fourier_resolution-1) - 0.5
72 |             self.uv_sparse = uv_cont_scaled
73 |             self.vis_re_sparse = vis_re_cont
74 |             self.vis_im_sparse = vis_im_cont
75 | 
76 |         # use sparse grid data
77 |         else:
78 |             print('using sparse grid visibility data..')
79 |             uv_sparse_scaled = np.rint((uv_sparse+max_base) / max_base * (fourier_resolution-1)/2) / (fourier_resolution-1) - 0.5
80 |             self.uv_sparse = uv_sparse_scaled
81 |             self.vis_re_sparse = vis_re_sparse
82 |             self.vis_im_sparse = vis_im_sparse
83 | 
84 |         # load GT images
85 |         self.img_res = img_res
86 | 
87 |         if dset_name == 'MNIST':
88 |             if data_path_imgs:
89 |                 from torchvision.datasets import MNIST
90 |                 from torchvision import transforms
91 | 
92 |                 transform = transforms.Compose([transforms.Resize((img_res, img_res)),
93 |                                                 transforms.ToTensor(),
94 |                                                 transforms.Normalize((0.1307,), (0.3081,)),
95 |                                                 ])
96 |                 self.img_dataset = MNIST('', train=True, download=True, transform=transform)
97 |             else: # if loading img data is not necessary
98 |                 self.img_dataset = None
99 | 
100 |         elif dset_name in ('Galaxy10', 'Galaxy10_DECals'): # the original `== 'Galaxy10' or 'Galaxy10_DECals'` was always truthy
101 |             if data_path_imgs:
102 |                 self.img_dataset = Galaxy10_Dataset(data_path_imgs, None)
103 |             else: # if loading img data is not necessary
104 |                 self.img_dataset = None
105 | 
106 |         else:
107 |             print('[ MNIST | Galaxy10 | Galaxy10_DECals ]')
108 |             raise NotImplementedError
109 | 
110 |         # pre-normalize data? (disable for phase loss)
111 |         self.pre_normalize = pre_normalize
112 | 
113 | 
114 |     def __getitem__(self, idx):
115 |         vis_dense = self.vis_re_dense[:,idx] + 1j*self.vis_im_dense[:,idx]
116 |         vis_real = self.vis_re_sparse[:,idx].astype(np.float32)
117 |         vis_imag = self.vis_im_sparse[:,idx].astype(np.float32)
118 |         if self.pre_normalize:
119 |             padding = 50 ## TODO make this an actual hyperparam
120 |             real_min, real_max = np.amin(vis_real)-padding, np.amax(vis_real)+padding
121 |             imag_min, imag_max = np.amin(vis_imag)-padding, np.amax(vis_imag)+padding
122 |             vis_real_normed = (vis_real - real_min) / (real_max - real_min)
123 |             vis_imag_normed = (vis_imag - imag_min) / (imag_max - imag_min)
124 |             vis_sparse = np.stack([vis_real_normed, vis_imag_normed], axis=1)
125 |         else:
126 |             vis_sparse = np.stack([vis_real, vis_imag], axis=1)
127 | 
128 |         if self.img_dataset:
129 |             img, label = self.img_dataset[idx]
130 |             img_res_initial = int(torch.numel(img)**(0.5))
131 |             img = img.reshape((img_res_initial,img_res_initial))
132 |             if img_res_initial != self.img_res:
133 |                 img = upscale_tensor(img, final_res=self.img_res, method='cubic')
134 |                 img = torch.from_numpy(img)  # upscale_tensor returns a numpy array
135 |         else:
136 |             img = torch.from_numpy(np.zeros((self.img_res,self.img_res)))
137 |             label = None
138 | 
139 |         return self.uv_sparse.astype(np.float32), self.uv_dense.astype(np.float32), vis_sparse.astype(np.float32), vis_dense, img, label
140 | 
141 |     def __len__(self):
142 |         return len(self.vis_re_sparse[0,:])
143 | 
144 | 
145 | if __name__ == "__main__":
146 | 
147 |     fourier_resolution = 64
148 |     dset_name = 'Galaxy10' #'MNIST'
149 |     idx = 123
150 | 
151 |     data_path = f'../data/eht_grid_128FC_200im_Galaxy10_DECals_full.h5'
152 | 
153 |     #spectral_dataset = EHTIM_Dataset(data_path)
154 |     #uv_sparse, uv_dense, vis_sparse, vis_dense = spectral_dataset[idx]
155 | 
156 |     im_data_path = '../data/Galaxy10_DECals.h5'
157 |     spectral_dataset = EHTIM_Dataset(dset_name = dset_name,
158 |                                      data_path = data_path,
159 |                                      data_path_imgs = im_data_path,
160 |                                      img_res = 200
161 |                                      )
162 |     uv_sparse, uv_dense, vis_sparse, vis_dense, img, label = spectral_dataset[idx]  # __getitem__ returns six values, including the label
163 |     print(uv_sparse.shape, uv_dense.shape, vis_sparse.shape, vis_dense.shape, img.shape)
164 | 
165 |     # plot data
166 |     vis_amp_sparse = np.linalg.norm(vis_sparse, axis=1)
167 |     vis_amp_dense = np.abs(vis_dense)
168 | 
169 |     print(uv_sparse.shape)
170 |     plt.scatter(uv_sparse[:,0], uv_sparse[:,1], c=vis_amp_sparse)
171 |     plt.savefig('ehtim_sparse.png')
172 |     print(uv_dense.shape)
173 |     print(uv_dense)
174 |     print(vis_amp_dense.shape)
175 |     print(vis_amp_dense)
176 |     plt.scatter(uv_dense[:,0], uv_dense[:,1], c=vis_amp_dense)
177 |     plt.savefig('ehtim_dense.png')
178 | 
179 |     plt.imshow(img)
180 |     plt.savefig('ehtim_gt_img.png')
181 | 
182 |     # obs_meta = spectral_dataset.get_metadata(idx, dset_name)
183 |     # plt.imshow(obs_meta['gt_img'])
184 |     # plt.savefig('ehtim_gt_img.png')
185 | 
--------------------------------------------------------------------------------
/context_encoder/encoders.py:
--------------------------------------------------------------------------------
1 | 
2 | import os
3 | import sys
4 | import copy
5 | import math
6 | import numpy as np
7 | from sklearn.cluster import AgglomerativeClustering
8 | from sklearn.metrics import pairwise_distances
9 | 
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | # from torch_geometric.nn import GATConv, GATv2Conv
14 | from torch import nn, einsum
15 | from einops import rearrange, repeat
16 | # from torch_geometric.utils
import dense_to_sparse, to_dense_adj, add_remaining_self_loops 17 | import pandas as pd 18 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 19 | 20 | class PreNorm(nn.Module): 21 | def __init__(self, dim, fn): 22 | super().__init__() 23 | self.norm = nn.LayerNorm(dim) 24 | self.fn = fn 25 | def forward(self, x, **kwargs): 26 | return self.fn(self.norm(x), **kwargs) 27 | 28 | class FeedForward(nn.Module): 29 | def __init__(self, dim, hidden_dim, dropout = 0.): 30 | super().__init__() 31 | self.net = nn.Sequential( 32 | nn.Linear(dim, hidden_dim), 33 | nn.GELU(), 34 | nn.Dropout(dropout), 35 | nn.Linear(hidden_dim, dim), 36 | nn.Dropout(dropout) 37 | ) 38 | def forward(self, x): 39 | return self.net(x) 40 | 41 | class Attention(nn.Module): 42 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., adj_dim = 83): 43 | super().__init__() 44 | inner_dim = dim_head * heads 45 | project_out = not (heads == 1 and dim_head == dim) 46 | 47 | # self.adj_dim = adj_dim 48 | 49 | self.heads = heads 50 | self.scale = dim_head ** -0.5 51 | 52 | self.attend = nn.Softmax(dim = -1) 53 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 54 | # self.adj_head = nn.Linear(adj_dim, adj_dim * heads) 55 | 56 | self.to_out = nn.Sequential( 57 | nn.Linear(inner_dim, dim), 58 | nn.Dropout(dropout) 59 | ) if project_out else nn.Identity() 60 | 61 | def forward(self, x): 62 | b, n, _, h = *x.shape, self.heads # n- # of tokens, h- # of heads, n- # of dimensions for each head 63 | qkv = self.to_qkv(x).chunk(3, dim = -1) 64 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 65 | 66 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 67 | # adj = self.adj_head(adj) 68 | # adj = rearrange(adj, 'b n (h d) -> b h n d', h = h) 69 | # dots = dots + adj 70 | 71 | attn = self.attend(dots) 72 | 73 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 74 | out = rearrange(out, 'b h n d -> b n (h d)') 75 | return self.to_out(out) 76 | 77 | class Transformer(nn.Module): 78 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., has_global_token=False): 79 | super().__init__() 80 | self.layers = nn.ModuleList() 81 | for _ in range(depth): 82 | self.layers.append(nn.ModuleList([ 83 | Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, adj_dim=83))), 84 | Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))) 85 | ])) 86 | def forward(self, x): 87 | # x, adj = data 88 | for attn, ff in self.layers: 89 | x = attn(x) 90 | x = ff(x) 91 | return x 92 | 93 | class PolarRec_Encoder(nn.Module): 94 | def __init__(self, *, 95 | input_dim, pe_dim, dim, depth, heads, mlp_dim, 96 | dim_head = 16, dropout = 0., emb_dropout = 0., 97 | output_dim = 1024, output_tokens = 3, 98 | has_global_token=False, group_size = 16, 99 | ): 100 | 101 | ''' 102 | dim_value_embedding = dim - pe_dim 103 | ''' 104 | 105 | super().__init__() 106 | 107 | assert output_tokens>0 108 | 109 | assert dim>pe_dim 110 | 111 | self.input_dim=input_dim #input value dimension (e.g. =2 visibility map - real and imag. ) 112 | self.pe_dim = pe_dim # positional-encoding dim. 
113 | self.dim =dim #feature embedding dimension 114 | 115 | self.depth =depth 116 | self.heads =heads # number of multi-heads 117 | self.mlp_dim =mlp_dim 118 | self.output_dim = output_dim 119 | self.output_tokens = output_tokens # number of output tokens 120 | 121 | self.global_token=None 122 | self.has_global_token= has_global_token # if use global token 123 | self.cen_dim = 0 124 | self.total_num = 1660 125 | self.mean_pool_size = 1660 // group_size 126 | 127 | if has_global_token: 128 | self.global_token = nn.Parameter(torch.randn(1, 1, dim)) 129 | 130 | self.feat_embedding = nn.Sequential( 131 | nn.Linear(self.input_dim, self.dim - self.pe_dim - self.cen_dim), 132 | ) 133 | 134 | self.before_mean = nn.Sequential( 135 | nn.Linear(self.dim, self.dim // 2), 136 | nn.LeakyReLU(), 137 | nn.Linear(self.dim // 2, self.dim) 138 | ) 139 | 140 | 141 | self.mean_pool = nn.AdaptiveAvgPool2d((self.mean_pool_size, self.dim)) 142 | 143 | self.emb_dropout = nn.Dropout(emb_dropout) 144 | 145 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 146 | 147 | self.output_token_heads = [ nn.Sequential( 148 | nn.LayerNorm(dim), 149 | nn.Linear(dim, self.output_dim) 150 | ) for _ in range(int(self.output_tokens)) ] 151 | self.output_token_heads = nn.ModuleList(self.output_token_heads) 152 | self.sorted_indices = None 153 | 154 | 155 | 156 | 157 | def adj_mat(self, pos): 158 | 159 | x = pos[:, :, 0] 160 | y = pos[:, :, 1] 161 | x = x.unsqueeze(2) 162 | y = y.unsqueeze(2) 163 | x_i_2 = torch.einsum('ijk,ijk->ijk',[x,x]) 164 | res = x_i_2.expand(pos.shape[0], pos.shape[1], pos.shape[1]) 165 | res = res + res.transpose(1,2) 166 | res = res - 2* torch.bmm(x, x.transpose(1,2)) 167 | y_i_2 = torch.einsum('ijk,ijk->ijk',[y,y]) 168 | y_2 = y_i_2.expand(pos.shape[0], pos.shape[1], pos.shape[1]) 169 | res = res + y_2 170 | res = res + y_2.transpose(1,2) 171 | res = res - 2* torch.bmm(y, y.transpose(1,2)) 172 | return F.normalize(res.sqrt(), p=2, dim=2) 173 | 174 | 175 | 176 | def forward(self, tokens, pos): 177 | ''' 178 | INPUTS 179 | tokens - B x N_token x Dim_token_feature(input_dim), input : [pose_embed, values] 180 | 181 | OUTPUTS: 182 | output_tokens - output tokens (B x N_out_tokens x dim_out_tokens 183 | ''' 184 | # adj = self.adj_mat(pos) 185 | # cen_emb = self.central_embedding(adj) 186 | 187 | emb_val = self.feat_embedding(tokens[..., self.pe_dim :]) # B xN_token x self.dim- self.pe_dim 188 | 189 | 190 | # emb_token = torch.cat([tokens[..., :self.pe_dim], emb_val, cen_emb], dim=-1) # B xN_token x self.dim 191 | emb_token = torch.cat([tokens[..., :self.pe_dim], emb_val], dim=-1) # B xN_token x self.dim 192 | 193 | # if self.sorted_indices == None: 194 | # batch_size = emb_token.shape[0] 195 | # num_tokens = emb_token.shape[1] 196 | 197 | # # We use only the first item's position in the batch for calculation 198 | # pos_np = pos[0].cpu().detach().numpy() 199 | 200 | # # Calculate pairwise distances 201 | # distances = pairwise_distances(pos_np, pos_np) 202 | # n = 20 203 | # # Apply clustering 204 | # clustering = AgglomerativeClustering(n_clusters=num_tokens//n, affinity='precomputed', linkage='average') 205 | # labels = clustering.fit_predict(distances) 206 | 207 | # # Now we sort our tokens and pos based on labels 208 | # self.sorted_indices = torch.argsort(torch.tensor(labels)) 209 | 210 | if self.sorted_indices == None: 211 | batch_size = emb_token.shape[0] 212 | num_tokens = emb_token.shape[1] 213 | 214 | # We use only the first item's position in the batch for calculation 215 
| pos_np = pos[0].cpu().detach().numpy() 216 | 217 | x_coords = pos_np[:, 0] 218 | y_coords = pos_np[:, 1] 219 | 220 | angles = np.arctan2(y_coords, x_coords) 221 | 222 | self.sorted_indices = torch.argsort(torch.tensor(angles)) 223 | 224 | np.save("vispoints.npy", pos_np) 225 | 226 | 227 | emb_token = emb_token[:, self.sorted_indices] 228 | 229 | 230 | emb_token = self.before_mean(emb_token) 231 | emb_token = self.mean_pool(emb_token) 232 | 233 | 234 | 235 | B, N_token, _ = emb_token.shape 236 | 237 | if self.has_global_token: 238 | emb_token = torch.cat([self.global_token.repeat(B, 1,1), emb_token], dim=1) 239 | 240 | emb_token = self.emb_dropout(emb_token) 241 | transformed_token = self.transformer(emb_token) 242 | 243 | #currently use the index reduction but there are other reduction 244 | #TODO: use matrix multiplication as the reduction method 245 | transformed_token_reduced = transformed_token[:, :self.output_tokens, ...] 246 | 247 | 248 | out_tokens=[] 249 | for idx_token in range(self.output_tokens): 250 | out_tokens.append(self.output_token_heads[idx_token]( 251 | transformed_token_reduced[:, idx_token,...].unsqueeze(1))) 252 | output_tokens = torch.cat(out_tokens, dim=1) 253 | 254 | 255 | return output_tokens 256 | 257 | 258 | class Residual(nn.Module): 259 | def __init__(self, fn): 260 | super().__init__() 261 | self.fn = fn 262 | 263 | def forward(self, x, **kwargs): 264 | return self.fn(x, **kwargs) + x 265 | 266 | 267 | class QuickFix(nn.Module): 268 | def __init__(self, dim, heads, fn): 269 | super().__init__() 270 | self.dim = dim 271 | self.heads = heads 272 | self.linear = nn.Linear(dim * heads, dim) 273 | self.fn = fn 274 | 275 | def forward(self, x, **kwargs): 276 | return self.linear(self.fn(x, **kwargs)) 277 | 278 | 279 | 280 | #ConvEncoder# 281 | class LinearEncoder(nn.Module): 282 | def __init__(self, x_dim=28*28, hidden_dims=[512, 256], latent_dim=2, in_channels=1): 283 | super().__init__() 284 | self.fc1 = nn.Linear(x_dim*in_channels, hidden_dims[0]) 285 | self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1]) 286 | self.fc31 = nn.Linear(hidden_dims[1], latent_dim) # mu 287 | self.fc32 = nn.Linear(hidden_dims[1], latent_dim) # log_var 288 | 289 | def forward(self, x): 290 | x = x.view(x.shape[0], -1) 291 | h = F.relu(self.fc1(x)) 292 | h = F.relu(self.fc2(h)) 293 | return self.fc31(h), self.fc32(h) # mu, log_var 294 | 295 | class ConvEncoder(nn.Module): 296 | def __init__(self, x_dim=28*28, hidden_dims=[32, 64], latent_dim=2, in_channels=1, activation=nn.ReLU): 297 | super().__init__() 298 | self.activation = activation 299 | modules = [] 300 | '''modules.append(nn.Sequential( 301 | nn.Conv2d(in_channels, out_channels=hidden_dims[0], 302 | kernel_size=5, stride=1, padding=2), 303 | nn.BatchNorm2d(hidden_dims[0]), 304 | activation())) 305 | modules.append(nn.Sequential( 306 | nn.Conv2d(hidden_dims[0], out_channels=2*hidden_dims[0], 307 | kernel_size=5, stride=1, padding=2), 308 | nn.BatchNorm2d(2*hidden_dims[0]), 309 | activation())) 310 | in_channels = 2*hidden_dims[0]''' 311 | for h_dim in hidden_dims: 312 | modules.append( 313 | nn.Sequential( 314 | nn.Conv2d(in_channels, out_channels=h_dim, 315 | kernel_size=5, stride=2, padding=2), 316 | nn.BatchNorm2d(h_dim), 317 | activation()) 318 | ) 319 | in_channels = h_dim 320 | #bottleneck_res = [28, 14, 7, 4, 2] + [1]*30 ## TODO only valid for 28^2 mnist digits 321 | bottleneck_res = [int(np.ceil(np.sqrt(x_dim) * 0.5**i)) for i in range(35)] # set res to decrease geometrically 322 | self.res_flattened = 
bottleneck_res[len(hidden_dims)]
323 |         self.encoder = nn.Sequential(*modules)
324 |         self.fc_mu = nn.Linear(hidden_dims[-1]*(self.res_flattened**2), latent_dim)
325 |         self.fc_var = nn.Linear(hidden_dims[-1]*(self.res_flattened**2), latent_dim)
326 | 
327 |     def forward(self, x):
328 |         x = self.encoder(x)
329 |         x = torch.flatten(x, start_dim=1)
330 |         mu = self.fc_mu(x)  # (the original wrapped this call in a bare try/except that dropped into ipdb)
331 |         log_var = self.fc_var(x)
332 | 
333 |         return mu, log_var
--------------------------------------------------------------------------------
/mlp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import pytorch_lightning as pl
4 | import numpy as np
5 | 
6 | def posenc(x, L_embed=4):
7 |     # NeRF-style positional encoding: (x, sin(2*pi*k*x), cos(2*pi*k*x), ...)
8 |     rets = [x]
9 |     for i in range(0, L_embed):
10 |         for fn in [torch.sin, torch.cos]:
11 |             rets.append(fn(2.*3.14159265*(i+1) * x))
12 |     return torch.cat(rets, dim=-1)
13 | 
14 | def calcB(m=1024, d=2, sigma=1.0):
15 |     B = torch.randn(m, d)*sigma
16 |     return B.cuda()  # note: assumes a CUDA device is available
17 | 
18 | def fourierfeat_enc(x, B):
19 |     # random Fourier features of x under the projection matrix B
20 |     feat = torch.cat([#torch.sum(x**2, -1, keepdims=True), ## new
21 |                       x, ## new
22 |                       torch.cos(2*3.14159265*(x @ B.T)),
23 |                       torch.sin(2*3.14159265*(x @ B.T))], -1)
24 |     return feat
25 | 
26 | class PE_Module(torch.nn.Module):
27 |     def __init__(self, type, embed_L):
28 |         super(PE_Module, self).__init__()
29 | 
30 |         self.embed_L = embed_L
31 |         self.type = type
32 | 
33 |     def forward(self, x):
34 |         if self.type == 'posenc':
35 |             return posenc(x, L_embed=self.embed_L)
36 | 
37 |         elif self.type == 'fourier':
38 |             return fourierfeat_enc(x, B=self.embed_L)
39 | 
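A quick dimensional check of the two encodings above (illustrative): `posenc` maps 2-D inputs to 2 + 2·2·L features, while `fourierfeat_enc` yields 2 + 2·m features for an m×2 matrix `B` (`calcB` is skipped here because it moves `B` to CUDA):

```python
import torch

x = torch.randn(4, 100, 2)                  # batch of 2-D (u, v) coordinates
print(posenc(x, L_embed=4).shape)           # torch.Size([4, 100, 18])

B = torch.randn(128, 2)                     # CPU stand-in for calcB(m=128)
print(fourierfeat_enc(x, B).shape)          # torch.Size([4, 100, 258])
```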
40 | class PosEncodedMLP(torch.nn.Module):
41 |     def __init__(self,
42 |                  input_size=2, output_size=2,
43 |                  hidden_dims=[256, 256], L_embed=5,
44 |                  embed_type='nerf', activation=nn.ReLU, sigma=0.1,
45 |                  ):
46 | 
47 |         super(PosEncodedMLP, self).__init__()
48 |         self.embed_type = embed_type
49 |         self.L_embed = L_embed
50 |         if self.L_embed > 0 and self.embed_type == 'nerf':
51 |             self.input_size = L_embed*2*input_size+input_size
52 |         elif self.L_embed > 0 and self.embed_type == 'fourier':
53 |             self.B = calcB(m=L_embed, d=2, sigma=sigma)
54 |             self.input_size = L_embed*2+2 # was L_embed*2+3; fourierfeat_enc no longer emits the squared-norm feature
55 |         else:
56 |             self.input_size = input_size
57 | 
58 |         #import ipdb; ipdb.set_trace()
59 | 
60 |         modules = []
61 |         dim_prev = self.input_size
62 |         for h_dim in hidden_dims:
63 |             modules.append(
64 |                 nn.Sequential(
65 |                     nn.Linear(dim_prev, h_dim),
66 |                     activation()))
67 |             dim_prev = h_dim
68 |         modules.append(nn.Sequential(nn.Linear(hidden_dims[-1], output_size),
69 |                                      ))#nn.Sigmoid()))
70 |         self.mlp = nn.Sequential(*modules)
71 | 
72 |     def _step(self, x):
73 | 
74 |         if self.L_embed > 0 and self.embed_type == 'nerf':
75 |             x = posenc(x, self.L_embed)
76 |         elif self.L_embed > 0 and self.embed_type == 'fourier':
77 |             x = fourierfeat_enc(x, self.B)
78 | 
79 |         x = self.mlp(x)
80 | 
81 |         return x
82 | 
83 |     def forward(self, x):
84 |         x = self._step(x)
85 |         return x
86 | 
87 | 
88 | 
89 | class PosEncodedMLP_FiLM(pl.LightningModule):
90 | 
91 |     def __init__(self, context_dim=64, input_size=2, output_size=2,
92 |                  hidden_dims=[256, 256], L_embed=10, embed_type='nerf',
93 |                  activation=nn.ReLU, sigma=5.0,
94 |                  context_type='VAE'):
95 |         '''
96 |         context_type = 'VAE'(default) | 'Transformer'
97 |         '''
98 |         super().__init__()
99 | 
100 |         self.context_type = context_type
101 |         if context_dim > 0:
102 |             layer = FiLMLinear
103 |         else:
104 |             layer = nn.Linear # note: a plain nn.Linear breaks when a context is passed; see the Linear wrapper below
105 | 
106 |         self.context_dim = context_dim
107 | 
108 |         self.embed_type = embed_type
109 |         self.L_embed = L_embed
110 | 
111 | 
112 |         if self.L_embed > 0 and self.embed_type == 'nerf':
113 |             self.input_size = L_embed*2*input_size+input_size
114 |         elif self.L_embed > 0 and self.embed_type == 'fourier':
115 |             self.B = nn.Parameter(calcB(m=L_embed, d=2, sigma=sigma), requires_grad=False)
116 |             # self.input_size = L_embed*2+3
117 |             self.input_size = L_embed*2+2 # change from +3 to +2 due to change in the fourierfeat_enc() function
118 |         else:
119 |             self.input_size = input_size
120 | 
121 |         #positional embedding function#
122 |         if self.L_embed > 0 and self.embed_type == 'nerf':
123 |             # self.embed_fun = lambda x_in: posenc(x_in, self.L_embed)
124 |             self.embed_fun = PE_Module(type='posenc', embed_L=self.L_embed) # was `self.embed_func`, which forward() never calls
125 | 
126 |         elif self.L_embed > 0 and self.embed_type == 'fourier':
127 |             # self.embed_fun = lambda x_in: fourierfeat_enc(x_in, self.B)
128 |             self.embed_fun = PE_Module(type='fourier', embed_L=self.B)
129 | 
130 |         self.layers = []
131 |         self.activations = []
132 |         dim_prev = self.input_size
133 |         for h_dim in hidden_dims:
134 |             self.layers.append(layer(dim_prev, h_dim, context_dim=self.context_dim))
135 |             self.activations.append(activation())
136 |             dim_prev = h_dim
137 | 
138 |         # self.layer1 = layer(self.input_size, hidden_dims[0], context_dim=self.context_dim)
139 |         # self.act1 = activation()
140 |         # self.layer2 = layer(hidden_dims[0], hidden_dims[1], context_dim=self.context_dim)
141 |         # self.act2 = activation()
142 | 
143 |         self.layers = nn.ModuleList(self.layers)
144 |         self.activations = nn.ModuleList(self.activations)
145 |         self.final_layer = layer(hidden_dims[-1], output_size, context_dim=self.context_dim)
146 |         ##self.final_activation = nn.Sigmoid() ## TODO removed this for unconstrained output
147 | 
148 |     def set_B(self, B):
149 |         self.B = B
150 | 
151 |     def forward(self, x_in, context):
152 |         '''
153 |         context -
154 |         B x 1 x ndim for VAE,
155 |         B x L x ndim for Transformer (assuming L layers in MLP)
156 |         '''
157 | 
158 |         # if self.L_embed > 0 and self.embed_type == 'nerf':
159 |         #     x_embed = posenc(x_in, self.L_embed)
160 |         # elif self.L_embed > 0 and self.embed_type == 'fourier':
161 |         #     x_embed = fourierfeat_enc(x_in, self.B)
162 | 
163 |         x_embed = self.embed_fun(x_in) # B x N x 2 -> B x N x dim_PE_dim
164 | 
165 | 
166 |         #for l, a in zip(self.layers, self.activations):
167 |         #    print(x.shape, x.device, context.shape, context.device); input()
168 |         #    x = l(x, context)
169 |         #    x = a(x)
170 | 
171 |         # if self.context_type=='VAE':
172 |         #     con1 = context
173 |         #     con2 = context
174 |         #     con3 = context
175 |         # elif self.context_type=='Transformer':
176 |         #     con1 = context[:, 0, :].unsqueeze(1)
177 |         #     con2 = context[:, 1, :].unsqueeze(1)
178 |         #     con3 = context[:, 2, :].unsqueeze(1)
179 |         # x = self.layer1(x_embed, con1)
180 |         # x = self.act1(x)
181 |         # x = self.layer2(x, con2)
182 |         # x = self.act2(x)
183 |         # x = self.final_layer(x, con3)
184 |         #x = self.final_activation(x)
185 | 
186 |         x_tmp = x_embed
187 |         for ilayer, layer in enumerate(self.layers):
188 |             x = layer( x_tmp, context if self.context_type=='VAE' else context[:,ilayer,:].unsqueeze(1) )
189 |             x = self.activations[ilayer](x)
190 |             x_tmp = x
191 | 
192 |         x = self.final_layer(x_tmp, context if self.context_type=='VAE' else context[:,-1,:].unsqueeze(1) )
193 | 
194 |         return x
195 | 
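A shape-level usage sketch of the FiLM-conditioned MLP (hypothetical sizes; with `context_type='Transformer'` the context must supply one token per hidden layer plus one for the output layer, since the final layer consumes `context[:, -1]`):

```python
import torch
import torch.nn as nn
from mlp import PosEncodedMLP_FiLM

mlp = PosEncodedMLP_FiLM(context_dim=64, input_size=2, output_size=2,
                         hidden_dims=[256, 256], L_embed=10,
                         embed_type='nerf', activation=nn.ReLU,
                         context_type='Transformer')

coords = torch.randn(4, 100, 2)   # (B, N, 2) query uv coordinates
context = torch.randn(4, 3, 64)   # (B, layers + 1, context_dim) encoder tokens
out = mlp(coords, context)        # -> (B, N, 2) predicted re/im visibilities
print(out.shape)
```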
196 | class FiLMLinear(pl.LightningModule):
197 |     def __init__(self, in_dim, out_dim, context_dim=64, residual=False):
198 |         super().__init__()
199 | 
200 |         self.linear = nn.Linear(in_dim, out_dim)
201 |         self.activation1 = nn.LeakyReLU()
202 |         self.activation2 = nn.LeakyReLU()
203 |         self.film1 = nn.Linear(context_dim, out_dim)
204 |         self.film2 = nn.Linear(context_dim, out_dim)
205 |         self.residual = residual
206 | 
207 |     def forward(self, x, shape_context):
208 |         if self.residual:
209 |             out = self.linear(x)
210 |             resid = self.activation1(out)
211 | 
212 |             gamma = self.film1(shape_context)
213 |             beta = self.film2(shape_context)
214 | 
215 |             out = gamma * out + beta
216 | 
217 |             out = self.activation2(out)
218 |             out = out + resid
219 |         else:
220 |             out = self.linear(x)
221 |             gamma = self.film1(shape_context)
222 |             beta = self.film2(shape_context)
223 |             out = gamma * out + beta
224 |             out = self.activation1(out)
225 |         return out
226 | 
227 | 
228 | class Linear(pl.LightningModule):
229 |     ''' dummy wrapper around nn.Linear that accepts (and ignores) the shape-context param '''
230 |     def __init__(self, in_dim, out_dim, context_dim=64, residual=False):
231 |         super().__init__()
232 |         self.linear = nn.Linear(in_dim, out_dim)
233 | 
234 |     def forward(self, x, shape_context=None): # ignore shape context
235 |         out = self.linear(x)
236 |         return out
237 | 
238 | 
239 | 
240 | 
241 | class NeRF_Embedding(nn.Module):
242 |     def __init__(self, in_channels, N_freqs, logscale=True):
243 |         """
244 |         Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...)
245 |         in_channels: number of input channels (2 for the uv coordinates used here)
246 |         """
247 |         super(NeRF_Embedding, self).__init__()
248 |         self.N_freqs = N_freqs
249 |         self.in_channels = in_channels
250 |         self.funcs = [torch.sin, torch.cos]
251 |         self.out_channels = in_channels*(len(self.funcs)*N_freqs+1)
252 | 
253 |         if logscale:
254 |             self.freq_bands = 2**torch.linspace(0, N_freqs-1, N_freqs)
255 |         else:
256 |             self.freq_bands = torch.linspace(1, 2**(N_freqs-1), N_freqs)
257 | 
258 |     def forward(self, x):
259 |         """
260 |         Embeds x to (x, sin(2^k x), cos(2^k x), ...)
261 | Different from the paper, "x" is also in the output 262 | See https://github.com/bmild/nerf/issues/12 263 | 264 | Inputs: 265 | x: (B, self.in_channels) 266 | 267 | Outputs: 268 | out: (B, self.out_channels) 269 | """ 270 | out = [x] 271 | for freq in self.freq_bands: 272 | for func in self.funcs: 273 | out += [func(freq*x)] 274 | 275 | return torch.cat(out, -1) 276 | 277 | class NeRF_Fourier(pl.LightningModule): 278 | def __init__(self, 279 | context_dim=64, 280 | input_size=2, 281 | output_size=5, 282 | D=8, W=256, 283 | L_embed=10, 284 | skips=[4], 285 | hidden_dims = None, # dummy 286 | embed_type = 'nerf', 287 | activation = nn.ReLU, 288 | sigma = 2.5, 289 | context_type='VAE'): 290 | """ 291 | D: number of layers for density (sigma) encoder 292 | W: number of hidden units in each layer 293 | skips: add skip connection in the Dth layer 294 | """ 295 | super(NeRF_Fourier, self).__init__() 296 | 297 | self.context_type = context_type 298 | if context_dim > 0: 299 | Layer = FiLMLinear 300 | else: 301 | Layer = Linear 302 | 303 | self.context_dim = context_dim 304 | 305 | self.embed_type = embed_type 306 | self.L_embed = L_embed 307 | 308 | self.D = D 309 | self.W = W 310 | 311 | self.skips = skips 312 | 313 | if embed_type == 'nerf': 314 | self.embedding_xyz = NeRF_Embedding(input_size, L_embed, logscale=True) # 10 is the default number 315 | self.in_channels_xyz = input_size * ( 316 | len(self.embedding_xyz.funcs) * self.embedding_xyz.N_freqs + 1) # in_channels_xyz 317 | else: 318 | self.B = calcB(m=L_embed, d=input_size, sigma=sigma) 319 | self.in_channels_xyz = L_embed*2 + input_size #+ 1 # 320 | 321 | # xyz encoding layers 322 | for i in range(D): 323 | if i == 0: 324 | layer = Layer(self.in_channels_xyz, W, context_dim=self.context_dim) 325 | elif i in skips: 326 | layer = Layer(W+self.in_channels_xyz, W, context_dim=self.context_dim) 327 | else: 328 | layer = Layer(W, W, context_dim=self.context_dim) 329 | layer = _Sequential(layer, activation(True)) 330 | setattr(self, f"xyz_encoding_{i+1}", layer) 331 | self.xyz_encoding_final = Layer(W, W, context_dim=self.context_dim) 332 | 333 | # output layers (real and imag) 334 | # or if using phase loss, out_dim may be 5 335 | self.fourier = Layer(W, output_size, context_dim=self.context_dim) 336 | 337 | def set_B(self, B): 338 | self.B = B 339 | 340 | def forward(self, x, context=None): 341 | """ 342 | Encodes input (xyz+dir) to rgb+sigma (not ready to render yet). 343 | For rendering this ray, please see rendering.py 344 | 345 | Inputs: 346 | x: (B, self.in_channels_xyz(+self.in_channels_dir)) 347 | the embedded vector of position and direction 348 | sigma_only: whether to infer sigma only. 
If True, 349 | x is of shape (B, self.in_channels_xyz) 350 | 351 | Outputs: 352 | if sigma_ony: 353 | sigma: (B, 1) sigma 354 | else: 355 | out: (B, 4), rgb and sigma 356 | """ 357 | if self.embed_type == 'nerf': 358 | embedded_x = self.embedding_xyz(x) 359 | else: 360 | embedded_x = fourierfeat_enc(x, self.B) 361 | input_xyz = embedded_x 362 | 363 | xyz_ = input_xyz 364 | for i in range(self.D): 365 | if i in self.skips: 366 | xyz_ = torch.cat([input_xyz, xyz_], -1) 367 | xyz_ = getattr(self, f"xyz_encoding_{i+1}")(xyz_, context) 368 | 369 | fourier = self.fourier(xyz_, context) 370 | 371 | return fourier 372 | 373 | class _Sequential(nn.Sequential): 374 | def forward(self, input, shape_context=None): 375 | for module in self._modules.values(): 376 | if type(module) == FiLMLinear or type(module) == _Sequential: 377 | input = module(input, shape_context=shape_context) 378 | else: 379 | input = module(input) 380 | return input 381 | 382 | class NeRF_Fourier_Two_Heads(nn.Module): 383 | def __init__(self, 384 | input_size=2, 385 | output_size=5, 386 | D=8, W=256, 387 | L_embed=10, 388 | skips=[4], 389 | embed_type = 'nerf', 390 | activation = nn.ReLU, 391 | sigma = 2.5, 392 | ): 393 | """ 394 | D: number of layers for density (sigma) encoder 395 | W: number of hidden units in each layer 396 | skips: add skip connection in the Dth layer 397 | """ 398 | super(NeRF_Fourier_Two_Heads, self).__init__() 399 | self.D = D 400 | self.W = W 401 | 402 | self.skips = skips 403 | 404 | self.embed_type = embed_type 405 | 406 | if embed_type == 'nerf': 407 | self.embedding_xyz = NeRF_Embedding(input_size, L_embed, logscale=True) # 10 is the default number 408 | self.in_channels_xyz = input_size * ( 409 | len(self.embedding_xyz.funcs) * self.embedding_xyz.N_freqs + 1) # in_channels_xyz 410 | else: 411 | self.B = calcB(m=L_embed, d=input_size, sigma=sigma) 412 | self.in_channels_xyz = L_embed*2 + input_size + 1 413 | 414 | # xyz encoding layers 415 | for i in range(D): 416 | if i == 0: 417 | layer = nn.Linear(self.in_channels_xyz, W) 418 | elif i in skips: 419 | layer = nn.Linear(W+self.in_channels_xyz, W) 420 | else: 421 | layer = nn.Linear(W, W) 422 | layer = nn.Sequential(layer, activation(True)) 423 | setattr(self, f"xyz_encoding_{i+1}", layer) 424 | #self.xyz_encoding_final = nn.Linear(W, W) 425 | 426 | # output layers (real and imag) 427 | # or if using phase loss, out_dim may be 5 428 | 429 | self.ampl = nn.Sequential( 430 | nn.Linear(W, W), 431 | activation(True), 432 | nn.Linear(W, 1)) 433 | self.phase = nn.Sequential( 434 | nn.Linear(W, W), 435 | activation(True), 436 | nn.Linear(W, output_size-1)) 437 | 438 | def forward(self, x): 439 | """ 440 | Encodes input (xyz+dir) to rgb+sigma (not ready to render yet). 441 | For rendering this ray, please see rendering.py 442 | 443 | Inputs: 444 | x: (B, self.in_channels_xyz(+self.in_channels_dir)) 445 | the embedded vector of position and direction 446 | sigma_only: whether to infer sigma only. 
If True, 447 | x is of shape (B, self.in_channels_xyz) 448 | 449 | Outputs: 450 | if sigma_ony: 451 | sigma: (B, 1) sigma 452 | else: 453 | out: (B, 4), rgb and sigma 454 | """ 455 | if self.embed_type == 'nerf': 456 | embedded_x = self.embedding_xyz(x) 457 | else: 458 | embedded_x = fourierfeat_enc(x, self.B) 459 | input_xyz = embedded_x 460 | 461 | xyz_ = input_xyz 462 | for i in range(self.D): 463 | if i in self.skips: 464 | xyz_ = torch.cat([input_xyz, xyz_], -1) 465 | xyz_ = getattr(self, f"xyz_encoding_{i+1}")(xyz_) 466 | 467 | amp = self.ampl(xyz_) 468 | phase = self.phase(xyz_) 469 | 470 | return torch.cat([amp, phase], -1) 471 | -------------------------------------------------------------------------------- /gICLEAN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time,pdb,sys 3 | import astropy.io.fits as pyfits 4 | import matplotlib 5 | #matplotlib.use('Agg') 6 | import matplotlib.image as img 7 | import matplotlib.pyplot as plt 8 | import matplotlib.cm as cm 9 | from scipy import ndimage 10 | from scipy.stats import multivariate_normal 11 | from scipy.signal import convolve2d 12 | 13 | # Import the PyCUDA modules 14 | import pycuda.compiler as nvcc 15 | import pycuda.gpuarray as gpu 16 | import pycuda.driver as cu 17 | import scikits.cuda.fft as fft 18 | # Initialize the CUDA device 19 | import pycuda.autoinit 20 | # Elementwise stuff 21 | from pycuda.elementwise import ElementwiseKernel 22 | from pycuda import cumath 23 | 24 | 25 | 26 | ###################### 27 | # CUDA kernels 28 | ###################### 29 | 30 | def cuda_compile(source_string, function_name): 31 | print("Compiling a CUDA kernel...") 32 | # Compile the CUDA Kernel at runtime 33 | source_module = nvcc.SourceModule(source_string) 34 | # Return a handle to the compiled CUDA kernel 35 | return source_module.get_function(function_name) 36 | 37 | GRID=lambda x,y,W: ((x)+((y)*W)) 38 | 39 | IGRIDX=lambda tid,W: tid%W 40 | IGRIDY=lambda tid,W: int(tid)/int(W) 41 | 42 | # ------------------- 43 | # Gridding kernels 44 | # ------------------- 45 | 46 | code = \ 47 | """ 48 | #define WIDTH 6 49 | #define NCGF 12 50 | #define HWIDTH 3 51 | #define STEP 4 52 | 53 | __device__ __constant__ float cgf[32]; 54 | 55 | // ********************* 56 | // MAP KERNELS 57 | // ********************* 58 | 59 | __global__ void gridVis_wBM_kernel(float2 *Grd, float2 *bm, int *cnt, float *d_u, float *d_v, float *d_re, 60 | float *d_im, int nu, float du, int gcount, int umax, int vmax){ 61 | int iu = blockDim.x*blockIdx.x + threadIdx.x; 62 | int iv = blockDim.y*blockIdx.y + threadIdx.y; 63 | int u0 = 0.5*nu; 64 | if(iu >= u0 && iu <= u0+umax && iv <= u0+vmax){ 65 | for (int ivis = 0; ivis < gcount; ivis++){ 66 | float mu = d_u[ivis]; 67 | float mv = d_v[ivis]; 68 | int hflag = 1; 69 | if (mu < 0){ 70 | hflag = -1; 71 | mu = -1*mu; 72 | mv = -1*mv; 73 | } 74 | float uu = mu/du+u0; 75 | float vv = mv/du+u0; 76 | int cnu=abs(iu-uu),cnv=abs(iv-vv); 77 | int ind = iv*nu+iu; 78 | if (cnu < HWIDTH && cnv < HWIDTH){ 79 | float wgt = cgf[int(round(4.6*cnu+NCGF-0.5))]*cgf[int(round(4.6*cnv+NCGF-0.5))]; 80 | Grd[ind].x += wgt*d_re[ivis]; 81 | Grd[ind].y += hflag*wgt*d_im[ivis]; 82 | cnt[ind] += 1; 83 | bm [ind].x += wgt; 84 | } 85 | // deal with points&pixels close to u=0 boundary 86 | if (iu-u0 < HWIDTH && mu/du < HWIDTH) { 87 | mu = -1*mu; 88 | mv = -1*mv; 89 | uu = mu/du+u0; 90 | vv = mv/du+u0; 91 | cnu=abs(iu-uu),cnv=abs(iv-vv); 92 | if (cnu < HWIDTH && cnv < HWIDTH){ 93 | 
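/* Same separable gridding-kernel weight lookup as above, now applied to
   the Hermitian mirror of the sample (note the extra sign flip on the
   imaginary part below). */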
float wgt = cgf[int(round(4.6*cnu+NCGF-0.5))]*cgf[int(round(4.6*cnv+NCGF-0.5))]; 94 | Grd[ind].x += wgt*d_re[ivis]; 95 | Grd[ind].y += -1*hflag*wgt*d_im[ivis]; 96 | cnt[ind] += 1; 97 | bm [ind].x += wgt; 98 | } 99 | } 100 | } 101 | } 102 | } 103 | 104 | __global__ void dblGrid_kernel(float2 *Grd, int nu, int hfac){ 105 | int iu = blockDim.x*blockIdx.x + threadIdx.x; 106 | int iv = blockDim.y*blockIdx.y + threadIdx.y; 107 | int u0 = 0.5*nu; 108 | if (iu > 0 && iu < u0 && iv < nu){ 109 | int niu = nu-iu; 110 | int niv = nu-iv; 111 | Grd[iv*nu+iu].x = Grd[niv*nu+niu].x; 112 | Grd[iv*nu+iu].y = hfac*Grd[niv*nu+niu].y; 113 | } 114 | } 115 | 116 | __global__ void wgtGrid_kernel(float2 *Grd, int *cnt, float briggs, int nu){ 117 | int iu = blockDim.x*blockIdx.x + threadIdx.x; 118 | int iv = blockDim.y*blockIdx.y + threadIdx.y; 119 | int u0 = 0.5*nu; 120 | if (iu >= u0 && iu < nu && iv < nu){ 121 | if (cnt[iv*nu+iu]!= 0){ 122 | int ind = iv*nu+iu; 123 | float foo = cnt[ind]; 124 | float wgt = 1./sqrt(1 + foo*foo/(briggs*briggs)); 125 | Grd[ind].x = Grd[ind].x*wgt; 126 | Grd[ind].y = Grd[ind].y*wgt; 127 | } 128 | } 129 | } 130 | 131 | __global__ void nrmGrid_kernel(float *Grd, float nrm, int nu){ 132 | int iu = blockDim.x*blockIdx.x + threadIdx.x; 133 | int iv = blockDim.y*blockIdx.y + threadIdx.y; 134 | if (iu < nu && iv < nu){ 135 | Grd[iv*nu + iu] = Grd[iv*nu+iu]*nrm; 136 | } 137 | } 138 | 139 | __global__ void corrGrid_kernel(float2 *Grd, float *corr, int nu){ 140 | int iu = blockDim.x*blockIdx.x + threadIdx.x; 141 | int iv = blockDim.y*blockIdx.y + threadIdx.y; 142 | if (iu < nu && iv < nu){ 143 | Grd[iv*nu + iu].x = Grd[iv*nu+iu].x*corr[nu/2]*corr[nu/2]/(corr[iu]*corr[iv]); 144 | Grd[iv*nu + iu].y = Grd[iv*nu+iu].y*corr[nu/2]*corr[nu/2]/(corr[iu]*corr[iv]); 145 | } 146 | } 147 | 148 | // ********************* 149 | // BEAM KERNELS 150 | // ********************* 151 | __global__ void nrmBeam_kernel(float *bmR, float nrm, int nu){ 152 | int iu = blockDim.x*blockIdx.x + threadIdx.x; 153 | int iv = blockDim.y*blockIdx.y + threadIdx.y; 154 | if(iu < nu && iv < nu){ 155 | bmR[iv*nu+iu] = nrm*bmR[iv*nu+iu]; 156 | } 157 | } 158 | 159 | // ********************* 160 | // MORE semi-USEFUL KERNELS 161 | // ********************* 162 | 163 | __global__ void shiftGrid_kernel(float2 *Grd, float2 *nGrd, int nu){ 164 | int iu = blockDim.x*blockIdx.x + threadIdx.x; 165 | int iv = blockDim.y*blockIdx.y + threadIdx.y; 166 | if(iu < nu && iv < nu){ 167 | int niu,niv,nud2 = 0.5*nu; 168 | if(iu < nud2) niu = nud2+iu; 169 | else niu = iu-nud2; 170 | if(iv < nud2) niv = nud2+iv; 171 | else niv = iv-nud2; 172 | nGrd[niv*nu + niu].x = Grd[iv*nu+iu].x; 173 | nGrd[niv*nu + niu].y = Grd[iv*nu+iu].y; 174 | } 175 | } 176 | 177 | __global__ void trimIm_kernel(float2 *im, float *nim, int noff, int nx, int nnx){ 178 | int ix = blockDim.x*blockIdx.x + threadIdx.x; 179 | int iy = blockDim.y*blockIdx.y + threadIdx.y; 180 | if(iy < nnx && ix < nnx){ 181 | nim[iy*nnx + ix] = im[(iy+noff)*nx+ix+noff].x; 182 | } 183 | } 184 | """ 185 | module = nvcc.SourceModule(code) 186 | gridVis_wBM_kernel = module.get_function("gridVis_wBM_kernel") 187 | shiftGrid_kernel = module.get_function("shiftGrid_kernel") 188 | nrmGrid_kernel = module.get_function("nrmGrid_kernel") 189 | wgtGrid_kernel = module.get_function("wgtGrid_kernel") 190 | dblGrid_kernel = module.get_function("dblGrid_kernel") 191 | corrGrid_kernel = module.get_function("corrGrid_kernel") 192 | nrmBeam_kernel = module.get_function("nrmBeam_kernel") 193 | trimIm_kernel = 
module.get_function("trimIm_kernel")
194 | 
195 | # -------------------
196 | # CLEAN kernels
197 | # -------------------
198 | 
199 | find_max_kernel_source = \
200 | """
201 | // Function to compute 1D array position
202 | #define GRID(x,y,W) ((x)+((y)*W))
203 | 
204 | __global__ void find_max_kernel(float* dimg, int* maxid, float maxval, int W, int H, float* model)
205 | {
206 |     // Identify place on grid
207 |     int idx = blockIdx.x * blockDim.x + threadIdx.x;
208 |     int idy = blockIdx.y * blockDim.y + threadIdx.y;
209 |     int id = GRID(idy,idx,H);
210 | 
211 |     // Ignore boundary pixels
212 |     if (idx > -1 && idx < W && idy > -1 && idy < H){
213 |         // NB: the body of this kernel was garbled in this copy of the
214 |         // file; it is re-sketched from its call site in cuda_hogbom
215 |         // below and may differ from the original in detail.
216 |         // If this pixel holds the current peak value, record its 1D
217 |         // index and add the peak into the point-source model.
218 |         if (dimg[id] == maxval){
219 |             maxid[0] = id;
220 |             model[id] += maxval;
221 |         }
222 |     }
223 | }
224 | """
225 | 
226 | sub_beam_kernel_source = \
227 | """
228 | // Function to compute 1D array position
229 | #define GRID(x,y,W) ((x)+((y)*W))
230 | 
231 | __global__ void sub_beam_kernel(float* dimg, float* dpsf, int* maxid,
232 |                                 float* cimg, float* cpsf,
233 |                                 float scaler, int W, int H)
234 | {
235 |     // NB: re-sketched from the call site in cuda_hogbom, as above.
236 | 
237 |     // Identify place on grid
238 |     int idx = blockIdx.x * blockDim.x + threadIdx.x;
239 |     int idy = blockIdx.y * blockDim.y + threadIdx.y;
240 |     int id = GRID(idy,idx,H);
241 | 
242 |     // Position of the current peak (consistent with id = GRID(idy,idx,H))
243 |     int midx = maxid[0]/H;
244 |     int midy = maxid[0]%H;
245 | 
246 |     // Corresponding pixel in the (centered) beam images
247 |     int bidy = (idy-midy) + H/2;
248 |     int bidx = (idx-midx) + W/2;
249 |     int bid = GRID(bidy,bidx,H);
250 | 
251 |     // Subtract the scaled dirty beam from the residual image and
252 |     // accumulate the scaled clean beam into the CLEANed image
253 |     if (idx > -1 && idx < W && idy > -1 && idy < H &&
254 |         bidx > -1 && bidx < W && bidy > -1 && bidy < H){
255 |         dimg[id] -= scaler*dpsf[bid];
256 |         cimg[id] += scaler*cpsf[bid];
257 |     }
258 | }
259 | """
260 | 
261 | # One Hogbom CLEAN iteration on the GPU (see cuda_hogbom below):
262 | #   1. gpu_getmax finds the peak value of the residual map with
263 | #      gpuarray reductions;
264 | #   2. find_max_kernel localizes that peak and grows the point-source
265 | #      model;
266 | #   3. sub_beam_kernel removes a gain-scaled dirty beam centered on the
267 | #      peak and adds the corresponding clean-beam contribution.
268 | # After the loop, add_noise_kernel restores a fraction of the residuals.
269 | # The loop repeats until the residual peak falls below the threshold.
270 | find_max_kernel = cuda_compile(find_max_kernel_source,"find_max_kernel")
271 | sub_beam_kernel = cuda_compile(sub_beam_kernel_source,"sub_beam_kernel")
272 | 
273 | # Elementwise kernel that adds a scaled copy of the residuals back into
274 | # the CLEANed image: b[i] += w*a[i]. Its original definition was also
275 | # garbled; this reconstruction is consistent with the call
276 | # add_noise_kernel(gpu_dirty, gpu_clean, np.int32(width*height),
277 | # np.float32(0.1)) below.
278 | add_noise_kernel = ElementwiseKernel(
279 |     "float *a, float *b, int n, float w",
280 |     "b[i] = b[i] + w*a[i]",
281 |     "add_noise_kernel")
282 | 
283 | ######################
284 | # Gridding helper functions
285 | ######################
286 | 
287 | def spheroid(eta,m,alpha):
288 |     """
289 |     Calculates spheroidal wave functions used to build the gridding
290 |     convolution kernel and its image-plane correction function.
291 |     Follows the rational (Horner-evaluated) approximations of Schwab
292 |     (1984), as implemented in MIRIAD's grid.for subroutine.
293 | 
294 |     eta   : normalized offset within the kernel support, in [-1,1]
295 |     m     : width of the gridding kernel in cells (4 <= m <= 8)
296 |     alpha : weighting exponent; the tables below are indexed by 2*alpha
297 | 
298 |     Returns the (unnormalized) kernel value as float32; the table is
299 |     sampled by gcf() and corrfun() below.
300 |     """
301 |     twoalp = 2*alpha
302 | 
303 |     # Range checks mirror MIRIAD's grid.for: out-of-range inputs are
304 |     # only reported, not rejected.
305 |     if np.abs(eta) > 1:
306 |         print('bad eta value!')
307 |     if (twoalp < 1 or twoalp > 4):
308 |         print('bad alpha value!')
309 |     if (m < 4 or m > 8):
310 |         print('bad width value!')
311 | 
312 |     etalim = np.float32([1., 1., 0.75, 0.775, 0.775])
313 |     nnum = np.int8([5, 7, 5, 5, 6])
314 |     ndenom = np.int8([3, 2, 3, 3, 3])
315 |     p = np.float32(
316 |         [
317 |             [ [5.613913E-2,-3.019847E-1, 6.256387E-1,
318 |                -6.324887E-1, 3.303194E-1, 0.0, 0.0],
319 |               [6.843713E-2,-3.342119E-1, 6.302307E-1,
320 |                -5.829747E-1, 2.765700E-1, 0.0, 0.0],
321 |               [8.203343E-2,-3.644705E-1, 6.278660E-1,
322 |                -5.335581E-1, 2.312756E-1, 0.0, 0.0],
323 |               [9.675562E-2,-3.922489E-1, 6.197133E-1,
324 |                -4.857470E-1, 1.934013E-1, 0.0, 0.0],
325 |               [1.124069E-1,-4.172349E-1, 6.069622E-1,
326 |                -4.405326E-1, 1.618978E-1, 0.0, 0.0]
327 |             ],
328 |             [ [8.531865E-4,-1.616105E-2, 6.888533E-2,
329 |                -1.109391E-1, 7.747182E-2, 0.0, 0.0],
330 |               [2.060760E-3,-2.558954E-2, 8.595213E-2,
331 |                -1.170228E-1, 7.094106E-2, 0.0, 0.0],
332 |               [4.028559E-3,-3.697768E-2, 1.021332E-1,
333 |                -1.201436E-1, 6.412774E-2, 0.0, 0.0],
334 |               [6.887946E-3,-4.994202E-2, 1.168451E-1,
335 |                -1.207733E-1, 5.744210E-2, 0.0, 0.0],
336 |               [1.071895E-2,-6.404749E-2, 1.297386E-1,
337 |                -1.194208E-1, 5.112822E-2, 0.0, 0.0]
338 |             ]
339 |         ])
340 |     q = np.float32(
341 |         [
342 |             [ [1., 9.077644E-1, 2.535284E-1],
343 |               [1., 8.626056E-1, 2.291400E-1],
344 |               [1., 8.212018E-1, 2.078043E-1],
345 |               [1., 7.831755E-1, 1.890848E-1],
346 |               [1., 7.481828E-1, 1.726085E-1]
347 |             ],
348 |             [ [1., 1.101270  , 3.858544E-1],
349 |               [1., 1.025431  , 3.337648E-1],
350 |               [1., 9.599102E-1, 2.918724E-1],
351 |               [1., 9.025276E-1, 2.575337E-1],
352 |               [1., 8.517470E-1, 2.289667E-1]
353 |             ]
354 |         ])
355 | 
356 |     i = int(m - 4)
357 |     if(np.abs(eta) > etalim[i]):
358 |         ip = 1
359 |         x = eta*eta - 1
360 |     else:
361 |         ip = 0
362 |         x = eta*eta - etalim[i]*etalim[i]
363 |     # numerator via Horner's rule
364 |     mnp = nnum[i]-1
365 |     num = p[int(ip),int(twoalp),int(mnp)]
366 |     for j in np.arange(mnp):
367 |         num = num*x + p[int(ip),int(twoalp),int(mnp-1-j)]
368 |     # denominator via Horner's rule
369 |     nq = ndenom[i]-1
370 |     denom = q[int(ip),int(twoalp),int(nq)]
371 |     for j in np.arange(nq):
372 |         denom = denom*x + q[int(ip),int(twoalp),int(nq-1-j)]
373 | 
374 |     return np.float32(num/denom)
375 | 
376 | def gcf(n,width):
377 |     """
378 |     Create table with spheroidal gridding function, C
379 |     This implementation follows MIRIAD's grid.for subroutine.
380 |     """
381 |     alpha = 1
382 |     j = 2*alpha
383 |     p = 0.5*j
384 |     phi = np.zeros(n,dtype=np.float32)
385 |     for i in np.arange(n):
386 |         x = np.float32(2*i-(n-1))/(n-1)
387 |         phi[i] = (np.sqrt(1-x*x)**j)*spheroid(x,width,p)
388 |     return phi
389 | 
390 | def corrfun(n,width):
391 |     """
392 |     Create gridding correction function, c
393 |     This implementation follows MIRIAD's grid.for subroutine.
    """
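    # c is the image-plane counterpart of the gridding kernel C: after the
    # FFT, corrGrid_kernel divides the map by corr[iu]*corr[iv] (normalized
    # to the grid centre) to undo the taper introduced by the gridding
    # convolution above.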
394 | """ 395 | alpha = 1 396 | dx = 2./n 397 | i0 = n/2+1 398 | phi = np.zeros(n,dtype=np.float32) 399 | for i in np.arange(n): 400 | x = (i-i0+1)*dx 401 | phi[i] = spheroid(x,width,alpha) 402 | return phi 403 | 404 | def cuda_gridvis(settings,plan): 405 | """ 406 | Grid the visibilities parallelized by pixel. 407 | References: 408 | - Chapter 10 in "Interferometry and Synthesis in Radio Astronomy" 409 | by Thompson, Moran, & Swenson 410 | - Daniel Brigg's PhD Thesis: http://www.aoc.nrao.edu/dissertations/dbriggs/ 411 | """ 412 | print("Gridding the visibilities") 413 | t_start=time.time() 414 | 415 | # unpack parameters 416 | vfile = settings['vfile'] 417 | briggs = settings['briggs'] 418 | imsize = settings['imsize'] 419 | cell = settings['cell'] 420 | nx = np.int32(2*imsize) 421 | noff = np.int32((nx-imsize)/2) 422 | 423 | ## constants 424 | arc2rad = np.float32(np.pi/180/3600.) 425 | du = np.float32(1./(arc2rad*cell*nx)) 426 | ## grab data 427 | f = pyfits.open(settings['vfile']) 428 | ## quickly figure out what data is not flagged 429 | freq = np.float32(f[0].header['CRVAL4']) 430 | good = np.where(f[0].data.data[:,0,0,0,0,0,0] != 0) 431 | h_u = np.float32(freq*f[0].data.par('uu')[good]) 432 | h_v = np.float32(freq*f[0].data.par('vv')[good]) 433 | gcount = np.int32(np.size(h_u)) 434 | ## assume data is unpolarized 435 | h_re = np.float32(0.5*(f[0].data.data[good,0,0,0,0,0,0]+f[0].data.data[good,0,0,0,0,1,0])) 436 | h_im = np.float32(0.5*(f[0].data.data[good,0,0,0,0,0,1]+f[0].data.data[good,0,0,0,0,1,1])) 437 | ## make GPU arrays 438 | h_grd = np.zeros((nx,nx),dtype=np.complex64) 439 | h_cnt = np.zeros((nx,nx),dtype=np.int32) 440 | d_u = gpu.to_gpu(h_u) 441 | d_v = gpu.to_gpu(h_v) 442 | d_re = gpu.to_gpu(h_re) 443 | d_im = gpu.to_gpu(h_im) 444 | d_cnt = gpu.zeros((np.int(nx),np.int(nx)),np.int32) 445 | d_grd = gpu.zeros((np.int(nx),np.int(nx)),np.complex64) 446 | d_ngrd = gpu.zeros_like(d_grd) 447 | d_bm = gpu.zeros_like(d_grd) 448 | d_nbm = gpu.zeros_like(d_grd) 449 | d_fim = gpu.zeros((np.int(imsize),np.int(imsize)),np.float32) 450 | ## define kernel parameters 451 | blocksize2D = (8,16,1) 452 | gridsize2D = (np.int(np.ceil(1.*nx/blocksize2D[0])),np.int(np.ceil(1.*nx/blocksize2D[1]))) 453 | blocksizeF2D = (16,16,1) 454 | gridsizeF2D = (np.int(np.ceil(1.*imsize/blocksizeF2D[0])),np.int(np.ceil(1.*imsize/blocksizeF2D[1]))) 455 | blocksize1D = (256,1,1) 456 | gridsize1D = (np.int(np.ceil(1.*gcount/blocksize1D[0])),1) 457 | 458 | # ------------------------ 459 | # make gridding kernels 460 | # ------------------------ 461 | ## make spheroidal convolution kernel (don't mess with these!) 462 | width = 6. 463 | ngcf = 24 464 | h_cgf = gcf(ngcf,width) 465 | ## make grid correction 466 | h_corr = corrfun(nx,width) 467 | d_cgf = module.get_global('cgf')[0] 468 | d_corr = gpu.to_gpu(h_corr) 469 | cu.memcpy_htod(d_cgf,h_cgf) 470 | 471 | # ------------------------ 472 | # grid it up 473 | # ------------------------ 474 | d_umax = gpu.max(cumath.fabs(d_u)) 475 | d_vmax = gpu.max(cumath.fabs(d_v)) 476 | umax = np.int32(np.ceil(d_umax.get()/du)) 477 | vmax = np.int32(np.ceil(d_vmax.get()/du)) 478 | 479 | ## grid ($$) 480 | # This should be improvable via: 481 | # - shared memory solution? I tried... 482 | # - better coalesced memory access? I tried... 483 | # - reorganzing and indexing UV data beforehand? 484 | # (i.e. http://www.nvidia.com/docs/IO/47905/ECE757_Project_Report_Gregerson.pdf) 485 | # - storing V(u,v) in texture memory? 
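    # Sketch of the pipeline below, matching the kernels defined above:
    # convolve each sample onto the grid with separable spheroidal weights,
    # apply the Briggs-style weight w = 1/sqrt(1 + cnt^2/briggs^2) (a large
    # 'briggs' approaches natural weighting, a small one approaches
    # uniform), mirror the conjugate half-plane (dblGrid), and FFT-shift.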
486 | gridVis_wBM_kernel(d_grd,d_bm,d_cnt,d_u,d_v,d_re,d_im,nx,du,gcount,umax,vmax,\ 487 | block=blocksize2D,grid=gridsize2D) 488 | ## apply weights 489 | wgtGrid_kernel(d_bm,d_cnt,briggs,nx,block=blocksize2D,grid=gridsize2D) 490 | hfac = np.int32(1) 491 | dblGrid_kernel(d_bm,nx,hfac,block=blocksize2D,grid=gridsize2D) 492 | shiftGrid_kernel(d_bm,d_nbm,nx,block=blocksize2D,grid=gridsize2D) 493 | ## normalize 494 | wgtGrid_kernel(d_grd,d_cnt,briggs,nx,block=blocksize2D,grid=gridsize2D) 495 | ## Reflect grid about v axis 496 | hfac = np.int32(-1) 497 | dblGrid_kernel(d_grd,nx,hfac,block=blocksize2D,grid=gridsize2D) 498 | ## Shift both 499 | shiftGrid_kernel(d_grd,d_ngrd,nx,block=blocksize2D,grid=gridsize2D) 500 | 501 | # ------------------------ 502 | # Make the beam 503 | # ------------------------ 504 | ## Transform to image plane 505 | fft.fft(d_nbm,d_bm,plan) 506 | ## Shift 507 | shiftGrid_kernel(d_bm,d_nbm,nx,block=blocksize2D,grid=gridsize2D) 508 | ## Correct for C 509 | corrGrid_kernel(d_nbm,d_corr,nx,block=blocksize2D,grid=gridsize2D) 510 | # Trim 511 | trimIm_kernel(d_nbm,d_fim,noff,nx,imsize,block=blocksizeF2D,grid=gridsizeF2D) 512 | ## Normalize 513 | d_bmax = gpu.max(d_fim) 514 | bmax = d_bmax.get() 515 | bmax = np.float32(1./bmax) 516 | nrmBeam_kernel(d_fim,bmax,imsize,block=blocksizeF2D,grid=gridsizeF2D) 517 | ## Pull onto CPU 518 | dpsf = d_fim.get() 519 | 520 | # ------------------------ 521 | # Make the map 522 | # ------------------------ 523 | ## Transform to image plane 524 | fft.fft(d_ngrd,d_grd,plan) 525 | ## Shift 526 | shiftGrid_kernel(d_grd,d_ngrd,nx,block=blocksize2D,grid=gridsize2D) 527 | ## Correct for C 528 | corrGrid_kernel(d_ngrd,d_corr,nx,block=blocksize2D,grid=gridsize2D) 529 | ## Trim 530 | trimIm_kernel(d_ngrd,d_fim,noff,nx,imsize,block=blocksizeF2D,grid=gridsizeF2D) 531 | ## Normalize (Jy/beam) 532 | nrmGrid_kernel(d_fim,bmax,imsize,block=blocksizeF2D,grid=gridsizeF2D) 533 | 534 | ## Finish timers 535 | t_end=time.time() 536 | t_full=t_end-t_start 537 | print("Gridding execution time %0.5f"%t_full+' s') 538 | print("\t%0.5f"%(t_full/gcount)+' s per visibility') 539 | 540 | ## Return dirty psf (CPU) and dirty image (GPU) 541 | return dpsf,d_fim 542 | 543 | ###################### 544 | # CLEAN functions 545 | ###################### 546 | 547 | def serial_clean_beam(dpsf,window=20, sigma=1.0): 548 | """ 549 | Clean a dirty beam on the CPU 550 | A very simple approach - just extract the central beam #improvable# 551 | Another solution would be fitting a 2D Gaussian, 552 | e.g. 
http://code.google.com/p/agpy/source/browse/trunk/agpy/gaussfitter.py 553 | """ 554 | print("Cleaning the dirty beam") 555 | h,w=np.shape(dpsf) 556 | h = int(h) 557 | w = int(w) 558 | window = int(window) 559 | 560 | gaussian_window = multivariate_normal([0,0],[[1.0,0],[0.0,1.0]]) 561 | x1 = np.linspace(-8 * sigma, 8 * sigma, window*2) # x-values for the normal-dstr 562 | X, Y = np.meshgrid(x1, x1) 563 | pos = np.dstack((X, Y)) 564 | gaussian_filter = gaussian_window.pdf(pos) 565 | 566 | cpsf=np.zeros([h,w]) 567 | cpsf[w//2-window:w//2+window,h//2-window:h//2+window]=gaussian_filter*dpsf[w//2-window:w//2+window,h//2-window:h//2+window] 568 | 569 | #import pylab as plt 570 | #plt.imshow(cpsf); plt.show() 571 | 572 | ##Normalize 573 | cpsf=cpsf/np.max(cpsf) 574 | return np.float32(cpsf) 575 | 576 | def gpu_getmax(map, polarity=False): 577 | """ 578 | Use pycuda to get the maximum absolute deviation of the residual map, 579 | with the correct sign 580 | """ 581 | if polarity: 582 | imax=gpu.max(cumath.fabs(map)).get() 583 | if gpu.max(map).get()!=imax: imax*=-1 584 | return np.float32(imax) 585 | else: 586 | imax = gpu.max(map).get() 587 | return np.float32(imax) 588 | 589 | def cuda_hogbom(gpu_dirty,gpu_dpsf,gpu_cpsf,thresh=0.2,damp=1,gain=0.1,prefix='test', maxIter=1e5, polarity=True, 590 | verbose=True, im_gt=None, dpsf_unnormed=None, plot_intermediate=False): 591 | """ 592 | Use CUDA to implement the Hogbom CLEAN algorithm 593 | 594 | A nice description of the algorithm is given by the NRAO, here: 595 | http://www.cv.nrao.edu/~abridle/deconvol/node8.html 596 | 597 | Parameters: 598 | * dirty: The dirty image (2D numpy array) 599 | * dpsf: The dirty beam psf (2D numpy array) 600 | * thresh: User-defined threshold to stop iteration, as a fraction of the max pixel intensity (float) 601 | * damp: The damping factor to scale the dirty beam by 602 | * prefix: prefix for output image file names 603 | """ 604 | height,width=np.shape(gpu_dirty) 605 | dirty_im = np.float32(gpu_dirty.get().copy()) 606 | height, width = np.int32(height), np.int32(width) 607 | print('Height=', height, 'Width=', width) 608 | ## Grid parameters - #improvable# 609 | tsize=1 610 | blocksize = (int(tsize),int(tsize),1) # The number of threads per block (x,y,z) 611 | gridsize = (int(width/tsize),int(height/tsize)) # The number of thread blocks (x,y) 612 | ## Setup cleam image and point source model 613 | gpu_pmodel = gpu.zeros([height,width],dtype=np.float32) 614 | gpu_clean = gpu.zeros([height,width],dtype=np.float32) 615 | ## Setup GPU constants 616 | gpu_max_id = gpu.to_gpu(np.array([width*height/2], dtype=np.int32)) 617 | imax=np.float32(gpu_getmax(gpu_dirty, polarity=polarity)) #gpu_dirty.get().max()# 618 | thresh_val=np.float32(thresh*imax) 619 | ## Steps 1-3 - Iterate until threshold has been reached 620 | t_start=time.time() 621 | i=0 622 | num_plot = 0 623 | while abs(imax)>(thresh_val) and i < maxIter: 624 | 625 | if (np.mod(i,1000)==0): 626 | print("Hogbom iteration",i) 627 | if plot_intermediate: 628 | gpu_dirty_final = gpu_dirty.copy() 629 | gpu_clean_final = gpu_clean.copy() 630 | #add_noise_kernel(gpu_dirty_final, gpu_clean_final, np.int32(width + height)) 631 | plot(dirty_im, gpu_dirty_final, gpu_clean_final, gpu_cpsf.get(), gpu_dpsf.get(), im_gt=im_gt, 632 | dpsf_unnormed=dpsf_unnormed, prefix='iter%06d_' % num_plot) 633 | num_plot += 1 634 | ## Step 1 - Find max 635 | find_max_kernel(gpu_dirty,gpu_max_id,imax,np.int32(width),np.int32(height),gpu_pmodel,\ 636 | block=blocksize, grid=gridsize) 637 | 638 | ## 
Step 2 - Subtract the beam (assume that it is normalized to have max 1) 639 | ## This kernel simultaneously reconstructs the CLEANed image. 640 | if verbose: print("Subtracting dirty beam "+str(i)+", maxval=%0.8f"%imax+' at x='+str(gpu_max_id.get()%width)+\ 641 | ', y='+str(gpu_max_id.get()//width), 'thresh=', thresh_val) 642 | sub_beam_kernel(gpu_dirty,gpu_dpsf,gpu_max_id,gpu_clean,gpu_cpsf,np.float32(gain*imax),np.int32(width),\ 643 | np.int32(height), block=blocksize, grid=gridsize) 644 | i+=1 645 | ## Step 3 - Find maximum value using gpuarray 646 | imax=gpu_getmax(gpu_dirty, polarity=polarity) #gpu_dirty.get().max() # 647 | t_end=time.time() 648 | t_full=t_end-t_start 649 | print("Hogbom execution time %0.5f"%t_full+' s') 650 | print("\t%0.5f"%(t_full/i)+' s per iteration') 651 | ## Step 4 - Add the residuals back in 652 | #relu_kernel(gpu_dirty, np.int32(width), np.int32(height), block=blocksize, grid=gridsize) 653 | add_noise_kernel(gpu_dirty,gpu_clean,np.int32(width*height),np.float32(0.1)) 654 | 655 | return gpu_dirty,gpu_pmodel,gpu_clean 656 | 657 | def clean_cuda(dirty_im, dirty_psf, thresh=0.2, gain=0.1, clean_beam_size=50.0, maxIter=1e5, prefix='test', 658 | im_gt=None, polarity=True, clean_psf=None, plot_intermediate=False): 659 | imsize = np.int32(dirty_im.shape[0]) 660 | 661 | dirty_psf_unnormed = np.float32(dirty_psf) 662 | 663 | gaussian_window = multivariate_normal([0, 0], [[1.0, 0], [0.0, 1.0]]) 664 | x1 = np.linspace(-2, 2, imsize) # x-values for the normal-dstr 665 | X, Y = np.meshgrid(x1, x1) 666 | pos = np.dstack((X, Y)) 667 | gaussian_filter = np.float32(gaussian_window.pdf(pos)) 668 | dirty_psf_unnormed = np.float32(dirty_psf_unnormed*gaussian_filter) 669 | #import pylab as plt 670 | #plt.imshow(dirty_psf); plt.show() 671 | 672 | dirty_psf_max = np.float32(dirty_psf_unnormed.max()) 673 | dirty_psf = dirty_psf_unnormed / dirty_psf_max 674 | 675 | gpu_dpsf = gpu.to_gpu(np.float32(dirty_psf)) 676 | gpu_im = gpu.to_gpu(np.float32(dirty_im)) 677 | 678 | ## Clean the PSF 679 | if clean_psf is not None: 680 | cpsf = clean_psf/clean_psf.max() 681 | else: 682 | cpsf = serial_clean_beam(dirty_psf, imsize/clean_beam_size) 683 | gpu_cpsf = gpu.to_gpu(np.float32(cpsf)) 684 | 685 | ## Run CLEAN 686 | gpu_dirty, gpu_pmodel, gpu_clean = cuda_hogbom(gpu_im, 687 | gpu_dpsf, 688 | gpu_cpsf, 689 | thresh=thresh, 690 | gain=gain, 691 | maxIter=maxIter, 692 | im_gt=im_gt, 693 | polarity=polarity, 694 | dpsf_unnormed=dirty_psf_unnormed, 695 | plot_intermediate=plot_intermediate) 696 | 697 | '''plot(dirty_im, gpu_dirty, gpu_clean, cpsf, dirty_psf, im_gt=im_gt, 698 | dpsf_unnormed=dirty_psf_unnormed, prefix=prefix)''' 699 | 700 | return gpu_clean.get() 701 | 702 | def plot(dirty_im, gpu_dirty, gpu_clean, clean_psf, dirty_psf, im_gt=None, 703 | dpsf_unnormed=None, prefix='test'): 704 | imsize = np.int32(dirty_im.shape[0]) 705 | 706 | if im_gt is not None: 707 | vra = [np.percentile(im_gt, 1), np.percentile(im_gt, 99)] 708 | else: 709 | vra = [np.percentile(dirty_im,1),np.percentile(dirty_im,99)] 710 | 711 | cmap = cm.hot 712 | 713 | print("Plotting dirty and cleaned beam") 714 | fig,axs=plt.subplots(3,2,sharex='all',sharey='all', figsize=(8,12));plt.subplots_adjust(wspace=0) 715 | axs[0,0].imshow(dirty_psf,vmin=np.percentile(dirty_psf,1),vmax=np.percentile(dirty_psf,99),cmap=cmap, origin='upper') 716 | axs[0,0].set_title('Dirty beam') 717 | axs[0,1].imshow(clean_psf,vmin=np.percentile(dirty_psf,1),vmax=np.percentile(dirty_psf,99),cmap=cmap, origin='upper') 718 | 
axs[0,1].set_title('Estimated clean beam') 719 | print("Plotting dirty image and dirty image after iterative source removal") 720 | axs[1,0].imshow(dirty_im,vmin=vra[0],vmax=vra[1],cmap=cmap,origin='upper') 721 | axs[1,0].set_title('Original dirty image') 722 | axs[1,1].imshow(gpu_dirty.get(),vmin=vra[0],vmax=vra[1],cmap=cmap,origin='upper') 723 | axs[1,1].set_title('Dirty image cleaned of sources') 724 | print("Plotting dirty image and final clean image") 725 | if im_gt is not None: 726 | axs[2,0].imshow(im_gt,vmin=vra[0],vmax=vra[1],cmap=cmap,origin='upper') 727 | axs[2,0].set_title('Original ground truth image') 728 | 729 | '''if dpsf_unnormed is not None: 730 | dirty_convolve = convolve2d(dpsf_unnormed, im_gt, mode='same') # [:img.shape[0]*2,:img.shape[1]*2] 731 | 732 | #print(dpsf_unnormed.min(), dpsf_unnormed.max(), im_gt.min(), im_gt.max()) 733 | #print(dirty_convolve.min(), dirty_convolve.max(), vra[0], vra[1]); input() 734 | 735 | axs[3,0].imshow(dirty_convolve,cmap=cmap,origin='upper') #,vmin=vra[0],vmax=vra[1],''' 736 | 737 | axs[2,1].imshow(gpu_clean.get(),vmin=vra[0],vmax=vra[1],cmap=cmap,origin='upper') 738 | axs[2,1].set_title('Final cleaned image') 739 | plt.savefig(prefix+'_clean_final.png') 740 | plt.close() 741 | 742 | 743 | 744 | if __name__ == '__main__': 745 | 746 | ## Load command line options 747 | 748 | # Which example? 749 | if len(sys.argv)>1: 750 | example=sys.argv[1] 751 | else: example = 'gaussian' 752 | if len(sys.argv)>2: 753 | ISIZE=float(sys.argv[2]) 754 | else: 755 | ISIZE=1024 756 | # Make plots? 757 | if len(sys.argv)>3: 758 | PLOTME=float(sys.argv[3]) 759 | else: 760 | PLOTME=1 761 | 762 | folder = './code/examples' 763 | 764 | # Load settings for each example 765 | settings = dict([]) 766 | if (example == 'gaussian'): 767 | # image a gaussian 768 | settings['vfile'] = f'{folder}/sim1.gauss.alma.out20.ms.fits' 769 | settings['imsize'] = np.int32(ISIZE) # number of image pixels 770 | settings['cell'] = np.float32(5.12/ISIZE) # pixel size in arcseconds 771 | settings['briggs'] = np.float32(1e7) # weight parameter 772 | elif (example == 'ring'): 773 | # image an inclined ring 774 | settings['vfile'] = f'{folder}/sim1.ring.alma.out20.ms.fits' 775 | settings['imsize']= np.int32(ISIZE) # number of image pixels 776 | settings['cell'] = np.float32(5.12/ISIZE) # pixel size in arcseconds 777 | settings['briggs']= np.float32(1e7) # weight parameter 778 | elif (example == 'mouse'): 779 | # image a non-astronomical source 780 | settings['vfile'] = f'{folder}/sim1.mickey.alma.out20.ms.fits' 781 | settings['imsize']= np.int32(ISIZE) # number of image pixels 782 | settings['cell'] = np.float32(5.12/ISIZE) # pixel size in arcseconds 783 | settings['briggs']= np.float32(1e3) # weight parameter 784 | elif (example == 'hd163296'): 785 | # image a single channel of the CO J=3-2 line from a protoplanetary disk 786 | # data from: https://almascience.nrao.edu/almadata/sciver/HD163296Band7/ 787 | settings['vfile'] = f'{folder}/HD163296.CO32.regridded.ms.constub.c21.fits' 788 | settings['imsize']= np.int32(ISIZE) # number of image pixels 789 | settings['cell'] = np.float32(25./ISIZE) # pixel size in arcseconds 790 | settings['briggs']= np.float32(1e7) # weight parameter 791 | vra = [-0.15,1.2] # intensity range for figure 792 | else: 793 | print('QUITTING: NO SUCH EXAMPLE.') 794 | sys.exit() 795 | 796 | ## make cuFFT plan #improvable# 797 | imsize = settings['imsize'] 798 | nx = np.int32(2*imsize) 799 | plan = fft.Plan((np.int(nx),np.int(nx)),np.complex64,np.complex64) 800 | 
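    # Example invocation (assuming the example FITS files are present under
    # ./code/examples):
    #
    #   python gICLEAN.py gaussian 1024 1
    #
    # which grids the Gaussian visibilities onto the doubled (2*imsize)
    # grid built above, forms the 1024^2 dirty image and beam, and runs
    # the CUDA Hogbom CLEAN.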
801 | ## Create the PSF & dirty image 802 | dpsf,gpu_im = cuda_gridvis(settings,plan) 803 | 804 | # # import pylab as plt 805 | 806 | # fig, ax = plt.subplots(nrows=1,ncols=2) 807 | # ax[0].imshow(gpu_im.get()) 808 | # ax[1].imshow(dpsf) 809 | # plt.show() 810 | 811 | clean_cuda(dirty_im=gpu_im.get(), dirty_psf=dpsf, thresh=0.2, gain=0.1, clean_beam_size=50.0, prefix='clean_cuda') 812 | 813 | input("done!") 814 | 815 | '''print(dpsf.min(), dpsf.max()); input() 816 | print(gpu_im.get().min(), gpu_im.get().max()); 817 | input()''' 818 | 819 | gpu_dpsf = gpu.to_gpu(dpsf) 820 | if PLOTME: 821 | dirty = np.roll(np.fliplr(gpu_im.get()),1,axis=1) 822 | 823 | ## Clean the PSF 824 | cpsf=serial_clean_beam(dpsf,imsize/50.) 825 | gpu_cpsf = gpu.to_gpu(cpsf) 826 | 827 | # if PLOTME: 828 | # print("Plotting dirty and cleaned beam") 829 | # fig,axs=plt.subplots(1,2,sharex='all',sharey='all');plt.subplots_adjust(wspace=0) 830 | # axs[0].imshow(dpsf,vmin=np.percentile(dpsf,1),vmax=np.percentile(dpsf,99),cmap=cm.gray) 831 | # axs[1].imshow(cpsf,vmin=np.percentile(dpsf,1),vmax=np.percentile(dpsf,99),cmap=cm.gray) 832 | # plt.savefig('test_cleanbeam.png') 833 | # plt.close() 834 | 835 | ## Run CLEAN 836 | gpu_dirty,gpu_pmodel,gpu_clean = cuda_hogbom(gpu_im,gpu_dpsf,gpu_cpsf,thresh=0.2,gain=0.1) 837 | 838 | if PLOTME: 839 | prefix=example 840 | try: 841 | vra 842 | except NameError: 843 | vra = [np.percentile(dirty,1),np.percentile(dirty,99)] 844 | 845 | print("Plotting dirty image and dirty image after iterative source removal") 846 | # fig,axs=plt.subplots(1,2,sharex='all',sharey='all',figsize=(12.2,6));plt.subplots_adjust(wspace=0) 847 | # axs[0].imshow(dirty,vmin=vra[0],vmax=vra[1],cmap=cm.gray,origin='lower') 848 | # axs[0].set_title('Original dirty image') 849 | # axs[1].imshow(np.roll(np.fliplr(gpu_dirty.get()),1,axis=1),vmin=vra[0],vmax=vra[1],cmap=cm.gray,origin='lower') 850 | # axs[1].set_title('Dirty image cleaned of sources') 851 | # plt.savefig(prefix+'_dirty_final.png') 852 | # plt.close() 853 | 854 | print("Plotting dirty image and final clean image") 855 | vra = [np.percentile(dirty,1),np.percentile(dirty,99)] 856 | # fig,axs=plt.subplots(1,2,sharex='all',sharey='all',figsize=(12.2,6));plt.subplots_adjust(wspace=0) 857 | # clean = np.roll(np.fliplr(gpu_clean.get()),1,axis=1) 858 | # axs[0].imshow(dirty,vmin=vra[0],vmax=vra[1],cmap=cm.gray,origin='lower') 859 | # axs[0].set_title('Original dirty image') 860 | # axs[1].imshow(clean,vmin=vra[0],vmax=vra[1],cmap=cm.gray,origin='lower') 861 | # axs[1].set_title('Final cleaned image') 862 | # plt.savefig(prefix+'_clean_final.png') 863 | # plt.close() 864 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on code from Wu, B.; Liu, C.; Eckart, B.; and Kautz, J. 2022. Neural interferometry: Image reconstruction from astronomical interferometers using transformer-conditioned neural fields. In Proceedings of the AAAI Conference on Artificial Intelligence. 
3 | """ 4 | 5 | from copy import Error 6 | import os 7 | import argparse 8 | from argparse import ArgumentParser 9 | import pytorch_lightning as pl 10 | import torch 11 | from torch import nn 12 | from torch.nn import functional as F 13 | import numpy as np 14 | from pytorch_lightning.callbacks import ModelCheckpoint 15 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping 16 | 17 | 18 | from torch.utils.data import DataLoader, random_split 19 | from torch.utils.data import Dataset, Subset 20 | 21 | from loss import FocalFrequencyLoss as FFL 22 | ffl = FFL(loss_weight=1, alpha=1, beta=1) 23 | 24 | import matplotlib 25 | import matplotlib.pyplot as plt 26 | 27 | from pytorch_lightning.plugins import DDPPlugin 28 | import context_encoder.encoders as m_encoder 29 | from mlp import PosEncodedMLP_FiLM 30 | from data_continuous_EHT import EHTIM_Dataset 31 | from data_ehtim_cont import make_dirtyim, make_im_torch 32 | 33 | 34 | import logging 35 | import sys 36 | logging.basicConfig(stream=sys.stderr, level=logging.DEBUG) 37 | 38 | from scipy import interpolate 39 | from numpy.fft import fft2, ifft2, fftshift, ifftshift 40 | 41 | import socket 42 | hostname= socket.gethostname() 43 | ti = 0 44 | if hostname!= 'NV': 45 | matplotlib.use('Agg') 46 | 47 | 48 | 49 | 50 | class PolarRec(pl.LightningModule): 51 | 52 | def __init__( 53 | self, args, 54 | learning_rate=1e-4, L_embed=5, 55 | input_encoding='nerf', sigma=2.5, 56 | hidden_dims=[256,256], 57 | latent_dim=64, kl_coeff=100.0, 58 | num_fourier_coeff=32, batch_size=32, 59 | input_size=28, model_checkpoint='', ngpu=None): 60 | 61 | super().__init__() 62 | 63 | self.save_hyperparameters() 64 | 65 | self.loss_func = nn.MSELoss(reduction='mean') 66 | self.loss_type = args.loss_type 67 | self.ngpu = ngpu 68 | self.use_unet = False 69 | self.use_GAT = False 70 | 71 | self.uv_dense_sparse_index=None 72 | self.num_fourier_coeff = num_fourier_coeff 73 | self.scale_loss_image= False #args.scale_loss_image 74 | 75 | 76 | 77 | if self.use_unet: 78 | 79 | if self.loss_type=='unet_direct': 80 | self.UNET=unet.UNet(2, 1) #input: sparse visibility map; output: image 81 | else: 82 | self.UNET=unet.UNet(2, 2) #input: sparse visibility map; output: dense visibility map 83 | else: 84 | if self.use_GAT: 85 | self.cond_mlp = PosEncodedMLP_FiLM( 86 | context_dim=latent_dim, 87 | input_size=2, output_size=2, 88 | hidden_dims=hidden_dims, 89 | L_embed=L_embed, embed_type=input_encoding, 90 | activation=nn.ReLU, 91 | sigma=sigma, 92 | context_type='Transformer') 93 | 94 | encoder = m_encoder.ViGAT( 95 | input_dim=2, #value dim 96 | # PE dim for MLP, we are going to use the same PE as the MLP 97 | pe_dim=self.cond_mlp.input_size, 98 | dim=512, depth=4, heads=16, 99 | output_dim=latent_dim, 100 | dropout=.1, emb_dropout=0., 101 | mlp_dim=512, 102 | output_tokens=args.mlp_layers, 103 | has_global_token=False) 104 | self.pe_encoder = self.cond_mlp.embed_fun 105 | self.context_encoder = encoder 106 | else: 107 | self.cond_mlp = PosEncodedMLP_FiLM( 108 | context_dim=latent_dim, 109 | input_size=2, output_size=2, 110 | hidden_dims=hidden_dims, 111 | L_embed=L_embed, embed_type=input_encoding, 112 | activation=nn.ReLU, 113 | sigma=sigma, 114 | context_type='Transformer') 115 | 116 | encoder = m_encoder.PolarRec_Encoder( 117 | input_dim=2, #value dim 118 | # PE dim for MLP, we are going to use the same PE as the MLP 119 | pe_dim=self.cond_mlp.input_size, 120 | dim=512, depth=4, heads=16, 121 | dim_head=512//16, 122 | output_dim=latent_dim, 123 | dropout=.1, 
emb_dropout=0., 124 | mlp_dim=512, 125 | output_tokens=args.mlp_layers, 126 | has_global_token=False) 127 | self.pe_encoder = self.cond_mlp.embed_fun 128 | self.context_encoder = encoder 129 | 130 | 131 | self.norm_fact=None 132 | 133 | self.numEpoch = 0 134 | 135 | self.uv_arr= None 136 | self.U, self.V= None, None 137 | self.uv_coords_grid_query= None 138 | 139 | #validation plots 140 | self.folder_val = f'{args.val_fldr}/imgs/' 141 | self.folder_anim = f'{args.val_fldr}/anims/' 142 | os.makedirs(self.folder_val, exist_ok=True) 143 | os.makedirs(self.folder_anim, exist_ok=True) 144 | self.numPlot = 0 145 | self.plotFreq = 10 146 | 147 | # testing 148 | self.test_iter = 0 149 | self.test_log_step = 50 150 | self.test_zs = [] 151 | self.test_imgs = [] 152 | self.test_fldr= f'../test_res1/{args.exp_name}' 153 | 154 | 155 | 156 | def load_pe_encoder(self, file_path): 157 | print("loading checkpoint...") 158 | self.pe_encoder.load_state_dict(torch.load(file_path)) 159 | print("finish loading!") 160 | 161 | 162 | def forward(self, x, z): 163 | pred_visibilities = self.cond_mlp(x, context=z) 164 | return pred_visibilities 165 | 166 | 167 | def _f(self, x): 168 | return ((x+0.5)%1)-0.5 169 | 170 | def inference_w_conjugate(self, uv_coords, z, nF=0, return_np=True): 171 | halfspace = self._get_halfspace(uv_coords) 172 | 173 | # does this modify visibilities in place? 174 | uv_coords_flipped = self._flip_uv(uv_coords, halfspace) 175 | 176 | pred_visibilities = self(uv_coords_flipped, z) 177 | # print("pred_vis", pred_visibilities.shape) 178 | pred_vis_real = pred_visibilities[:,:,0] 179 | pred_vis_imag = pred_visibilities[:,:,1] 180 | pred_vis_imag[halfspace] = -pred_vis_imag[halfspace] 181 | 182 | 183 | if nF == 0: nF = self.hparams.num_fourier_coeff 184 | 185 | pred_vis_imag = pred_vis_imag.reshape((-1, nF, nF)) 186 | pred_vis_real = pred_vis_real.reshape((-1, nF, nF)) 187 | 188 | global ti 189 | 190 | plt.imsave(fname="../test_res1/Galaxy10-DEC-cont/transformer/mlp_8_layer/image_loss-NF_128/"+str(ti)+"/recon_imag.png", arr=pred_vis_imag[0].cpu().detach().numpy(), cmap="viridis") 191 | plt.imsave(fname="../test_res1/Galaxy10-DEC-cont/transformer/mlp_8_layer/image_loss-NF_128/"+str(ti)+"/recon_real.png", arr=pred_vis_real[0].cpu().detach().numpy(), cmap="viridis") 192 | np.save("../test_res1/Galaxy10-DEC-cont/transformer/mlp_8_layer/image_loss-NF_128/"+str(ti)+"/recon_vis.npy", np.stack((pred_vis_real[0].cpu().detach().numpy(), pred_vis_imag[0].cpu().detach().numpy()), axis=0)) 193 | 194 | ti = ti + 1 195 | 196 | pred_vis_imag[:,0,0] = 0 197 | pred_vis_imag[:,0,nF//2] = 0 198 | pred_vis_imag[:,nF//2,0] = 0 199 | pred_vis_imag[:,nF//2,nF//2] = 0 200 | 201 | 202 | if return_np: 203 | pred_fft = pred_vis_real.detach().cpu().numpy() + 1j*pred_vis_imag.detach().cpu().numpy() 204 | else: 205 | pred_fft = pred_vis_real + 1j*pred_vis_imag 206 | 207 | # NEW: set border to zero to counteract weird border issues 208 | pred_fft[:,0,:] = 0.0 209 | pred_fft[:,:,0] = 0.0 210 | pred_fft[:,:,-1] = 0.0 211 | pred_fft[:,-1,:] = 0.0 212 | # print("pred_fft", pred_fft.shape) 213 | return pred_fft 214 | 215 | def _get_halfspace(self, uv_coords): 216 | #left_halfspace = torch.logical_and(uv_coords[:,0] > 0, uv_coords[:,1] > 0) 217 | left_halfspace = torch.logical_and(torch.logical_or( 218 | uv_coords[:,:,0] < 0, 219 | torch.logical_and(uv_coords[:,:,0] == 0, uv_coords[:,:,1] > 0)), 220 | ~torch.logical_and(uv_coords[:,:,0] == -.5, uv_coords[:,:,1] > 0)) 221 | 222 | return left_halfspace 223 | 224 | def 
_conjugate_vis(self, vis, halfspace): 225 | # take complex conjugate if flipped uv coords 226 | # so network doesn't receive confusing gradient information 227 | vis[halfspace] = torch.conj(vis[halfspace]) 228 | return vis 229 | 230 | def _flip_uv(self, uv_coords, halfspace): 231 | 232 | halfspace_2d = torch.stack((halfspace, halfspace), axis=-1) 233 | uv_coords_flipped = torch.where(halfspace_2d, self._f(-uv_coords), uv_coords) 234 | 235 | return uv_coords_flipped 236 | 237 | def _recon_image_rfft(self, uv_dense, z, imsize, max_base, eht_fov, ): 238 | #get the query uv's 239 | B= uv_dense.shape[0] 240 | img_res=imsize[0] 241 | uv_dense_per=uv_dense[0] 242 | u_dense, v_dense= uv_dense_per[:,0].unique(), uv_dense_per[:,1].unique() 243 | u_dense= torch.linspace( u_dense.min(), u_dense.max(), len(u_dense)//2 * 2 + 1).to(u_dense) 244 | v_dense= torch.linspace( v_dense.min(), v_dense.max(), len(v_dense)//2 * 2 + 1).to(u_dense) 245 | uv_arr= torch.cat([u_dense.unsqueeze(-1), v_dense.unsqueeze(-1)], dim=-1) 246 | scale_ux= max_base * eht_fov/ img_res 247 | uv_arr= ((uv_arr+.5) * 2 -1.) * scale_ux # scaled input 248 | U,V= torch.meshgrid(uv_arr[:,0], uv_arr[:,1]) 249 | uv_coords_grid_query= torch.cat((U.reshape(-1,1), V.reshape(-1,1)), dim=-1).unsqueeze(0).repeat(B,1,1) 250 | #get predicted visibilities 251 | pred_visibilities = self(uv_coords_grid_query, z) #Bx (HW) x 2 252 | pred_visibilities_map= torch.view_as_complex(pred_visibilities).reshape(B, U.shape[0], U.shape[1]) 253 | img_recon = make_im_torch(uv_arr, pred_visibilities_map, img_res, eht_fov, 254 | norm_fact=self.norm_fact if self.norm_fact is not None else 1., 255 | return_im=True) 256 | 257 | return img_recon 258 | 259 | 260 | def _step_image_loss(self, batch, batch_idx, num_zero_samples=0, loss_type='image',): 261 | ''' 262 | forward pass then calculate the loss in the image domain 263 | we will use rfft to ensure that the values in the image domain are real 264 | ''' 265 | 266 | uv_coords, uv_dense, vis_sparse, visibilities, img_0s, label = batch 267 | img_res= img_0s.shape[-1] 268 | 269 | eht_fov = 1.4108078120287498e-09 270 | max_base = 8368481300.0 271 | scale_ux= max_base * eht_fov/ img_res 272 | 273 | pos = uv_coords* scale_ux 274 | pe_uv = self.pe_encoder(pos) 275 | inputs_encoder = torch.cat([pe_uv, vis_sparse], dim=-1) 276 | z = self.context_encoder(inputs_encoder, pos) 277 | 278 | B= uv_dense.shape[0] 279 | nF= int( uv_dense.shape[1]**.5 ) 280 | 281 | 282 | #get the query uv's 283 | if self.uv_coords_grid_query is None: 284 | uv_dense_per=uv_dense[0] 285 | u_dense, v_dense= uv_dense_per[:,0].unique(), uv_dense_per[:,1].unique() 286 | u_dense= torch.linspace( u_dense.min(), u_dense.max(), len(u_dense)//2 * 2 ).to(u_dense) 287 | v_dense= torch.linspace( v_dense.min(), v_dense.max(), len(v_dense)//2 * 2 ).to(u_dense) 288 | uv_arr= torch.cat([u_dense.unsqueeze(-1), v_dense.unsqueeze(-1)], dim=-1) 289 | # print("uv_arr", uv_arr.shape) 290 | uv_arr= ((uv_arr+.5) * 2 -1.) 
* scale_ux # scaled input
291 |             # print("uv_arr", uv_arr.shape)
292 |             U,V= torch.meshgrid(uv_arr[:,0], uv_arr[:,1])
293 |             uv_coords_grid_query= torch.cat((U.reshape(-1,1), V.reshape(-1,1)), dim=-1).unsqueeze(0).repeat(B,1,1)
294 |             self.uv_arr= uv_arr
295 |             self.U, self.V= U,V
296 |             self.uv_coords_grid_query= uv_coords_grid_query
297 |             print('initialized self.uv_coords_grid_query')
298 | 
299 | 
300 |         #get predicted visibilities
301 |         pred_visibilities = self(self.uv_coords_grid_query, z) #Bx (HW) x 2
302 | 
303 |         #image recon
304 |         if self.norm_fact is None: # get the normalization factor, which is fixed given image/spectral domain dimensions
305 |             visibilities_map= visibilities.reshape(-1, self.num_fourier_coeff, self.num_fourier_coeff)
306 |             uv_dense_physical = (uv_dense.detach().cpu().numpy()[0,:,:] +0.5)*(2*max_base) - (max_base)
307 |             _, _, norm_fact = make_dirtyim(uv_dense_physical,
308 |                                            visibilities_map.detach().cpu().numpy()[ 0, :, :].reshape(-1),
309 |                                            img_res, eht_fov, return_im=True)
310 |             self.norm_fact= norm_fact
311 |             print('initialized the norm fact')
312 | 
313 |         #visibilities_map: B x len(u_dense) x len(v_dense)
314 | 
315 |         pred_visibilities_map= torch.view_as_complex(pred_visibilities).reshape(B, self.U.shape[0], self.U.shape[1])
316 |         img_recon = make_im_torch(self.uv_arr, pred_visibilities_map, img_res, eht_fov, norm_fact=self.norm_fact, return_im=True)
317 | 
318 |         vis_maps = visibilities.reshape(-1, self.num_fourier_coeff, self.num_fourier_coeff)
319 |         img_recon_gt= make_im_torch(self.uv_arr, vis_maps, img_res, eht_fov, norm_fact=self.norm_fact, return_im=True)
320 | 
321 |         #energy in the frequency space
322 |         freq_norms = torch.sqrt(torch.sum(self.uv_coords_grid_query**2, -1))
323 |         abs_pred = torch.sqrt(pred_visibilities[:,:,0]**2 + pred_visibilities[:,:,1]**2)
324 |         energy = torch.mean(freq_norms*abs_pred)
325 | 
326 |         ########################
327 |         halfspace = self._get_halfspace(uv_dense)
328 |         uv_coords_flipped = self._flip_uv(uv_dense, halfspace)
329 |         vis_conj = self._conjugate_vis(visibilities, halfspace)
330 | 
331 |         vis_real = vis_conj.real.float()
332 |         vis_imag = vis_conj.imag.float()
333 |         # (energy is recomputed below on the dense uv grid, overwriting the value above)
334 |         freq_norms = torch.sqrt(torch.sum(uv_dense**2, -1))
335 |         abs_pred = torch.sqrt(pred_visibilities[:,:,0]**2 + pred_visibilities[:,:,1]**2)
336 |         energy = torch.mean(freq_norms*abs_pred)
337 | 
338 |         pred_vis_real = pred_visibilities[:,:,0]
339 |         pred_vis_imag = pred_visibilities[:,:,1]
340 |         pred_vis_imag[halfspace] = -pred_vis_imag[halfspace]
341 |         vis_imag[halfspace] = -vis_imag[halfspace]
342 |         nF = 0
343 |         if nF == 0: nF = self.hparams.num_fourier_coeff
344 |         pred_vis_imag = pred_vis_imag.reshape((-1, nF, nF))
345 |         pred_vis_real = pred_vis_real.reshape((-1, nF, nF))
346 |         vis_imag = vis_imag.reshape((-1, nF, nF))
347 |         vis_real = vis_real.reshape((-1, nF, nF))
348 |         imaginary_loss = ffl(pred_vis_imag.unsqueeze(1), vis_imag.unsqueeze(1))
349 |         real_loss = ffl(pred_vis_real.unsqueeze(1), vis_real.unsqueeze(1))
350 | 
351 |         ##########################
352 |         # NB: only loss_type=='image' is handled here; 'image_spectral' falls through to the error below
353 |         if loss_type=='image':
354 |             # loss= (img_0s - img_recon.real ).abs().mean()
355 |             loss= (img_recon_gt.real - img_recon.real ).abs().mean() + real_loss + imaginary_loss
356 |             return 0., 0., loss, loss, energy
357 |             # loss= pred_visibilities.abs().mean()
358 |             # return 0., 0., loss, loss,0.
359 |         else:
360 |             raise Error('undefined loss_type')
361 | 
362 |     def _step_unet(self, batch, batch_idx, num_zero_samples=0):
363 |         # batch is a set of uv coords and complex visibilities
364 |         uv_coords, uv_dense, vis_sparse, visibilities, img, label = batch
365 |         B,img_res= img.shape[0], img.shape[-1]
366 | 
367 | 
368 |         ###
369 |         #UNET
370 |         if self.uv_dense_sparse_index is None:
371 |             print('getting uv_dense_sparse_index...')
372 |             uv_coords_per= uv_coords[0] #S,2
373 |             uv_dense_per= uv_dense[0]#N,2
374 |             uv_dense_sparse_index= []
375 |             for i_sparse in range(uv_coords_per.shape[0]):
376 |                 uv_coord= uv_coords_per[i_sparse]
377 |                 uv_dense_equal= torch.logical_and(uv_dense_per[:,0]==uv_coord[0], uv_dense_per[:,1]==uv_coord[1])
378 |                 uv_dense_sparse_index.append( uv_dense_equal.nonzero() )
379 |             # uv_dense_sparse_index= torch.LongTensor(uv_dense_sparse_index,).to(uv_coords.device)
380 |             uv_dense_sparse_index= torch.cat(uv_dense_sparse_index).long().to(uv_coords.device)
381 |             print('done')
382 |             self.uv_dense_sparse_index= uv_dense_sparse_index
383 | 
384 |         #get the sparse visibility image (input to the UNet)
385 |         uv_dense_sparse_map= torch.zeros((uv_coords.shape[0], self.num_fourier_coeff**2, 2), ).to(uv_coords.device)
386 |         uv_dense_sparse_map[:,self.uv_dense_sparse_index,: ]=vis_sparse
387 |         uv_dense_sparse_map= uv_dense_sparse_map.permute(0, 2, 1).contiguous().reshape(-1, 2, self.num_fourier_coeff, self.num_fourier_coeff)
388 |         uv_dense_unet_output= self.UNET(uv_dense_sparse_map)# B,2,H,W or B,1,H,W
389 |         ###
390 | 
391 |         if self.loss_type in ('image', 'image_spectral'):
392 |             eht_fov = 1.4108078120287498e-09
393 |             max_base = 8368481300.0
394 |             scale_ux= max_base * eht_fov/ img_res
395 |             #get the query uv's
396 |             if self.uv_coords_grid_query is None:
397 |                 uv_dense_per=uv_dense[0]
398 |                 u_dense, v_dense= uv_dense_per[:,0].unique(), uv_dense_per[:,1].unique()
399 |                 u_dense= torch.linspace( u_dense.min(), u_dense.max(), len(u_dense)//2 * 2 ).to(u_dense)
400 |                 v_dense= torch.linspace( v_dense.min(), v_dense.max(), len(v_dense)//2 * 2 ).to(u_dense)
401 |                 uv_arr= torch.cat([u_dense.unsqueeze(-1), v_dense.unsqueeze(-1)], dim=-1)
402 |                 uv_arr= ((uv_arr+.5) * 2 -1.) * scale_ux # scaled input
403 |                 U,V= torch.meshgrid(uv_arr[:,0], uv_arr[:,1])
404 |                 uv_coords_grid_query= torch.cat((U.reshape(-1,1), V.reshape(-1,1)), dim=-1).unsqueeze(0).repeat(uv_coords.shape[0],1,1)
405 |                 self.uv_arr= uv_arr
406 |                 self.U, self.V= U,V
407 |                 self.uv_coords_grid_query= uv_coords_grid_query
408 |                 print('initialized self.uv_coords_grid_query')
409 |             #image recon
410 |             if self.norm_fact is None: # get the normalization factor, which is fixed given image/spectral domain dimensions
411 |                 visibilities_map= visibilities.reshape(-1, self.num_fourier_coeff, self.num_fourier_coeff)
412 |                 uv_dense_physical = (uv_dense.detach().cpu().numpy()[0,:,:] +0.5)*(2*max_base) - (max_base)
413 |                 _, _, norm_fact = make_dirtyim(uv_dense_physical,
414 |                                                visibilities_map.detach().cpu().numpy()[ 0, :, :].reshape(-1),
415 |                                                img_res, eht_fov, return_im=True)
416 |                 self.norm_fact= norm_fact
417 |                 print('initialized the norm fact')
418 |             uv_dense_sparse_recon = torch.view_as_complex(uv_dense_unet_output.permute(0,2,3,1).contiguous()) # B,H,W
419 |             img_recon = make_im_torch(self.uv_arr, uv_dense_sparse_recon, img_res, eht_fov, norm_fact=self.norm_fact, return_im=True)
420 |             #image recon loss
421 |             loss= (img - img_recon.real ).abs().mean()
422 |             return 0., 0., loss, loss, 0.
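        # Each branch of _step_unet returns the 5-tuple
        # (real_loss, imaginary_loss, image_loss, total_loss, energy);
        # the MLP path's _step further below returns only 4 values, and
        # training_step unpacks each accordingly.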
423 | 
424 |         elif self.loss_type == 'spectral':
425 |             #spectral loss
426 |             vis_mat= torch.view_as_real(visibilities)
427 |             real_loss = self.loss_func(vis_mat[...,0], uv_dense_unet_output[:,0,...].reshape(B,-1) )
428 |             imaginary_loss = self.loss_func(vis_mat[...,1], uv_dense_unet_output[:,1,...].reshape(B,-1))
429 |             loss = real_loss + imaginary_loss
430 |             return real_loss, imaginary_loss, 0, loss, 0.
431 | 
432 |         elif self.loss_type == 'unet_direct':
433 |             loss= (img- uv_dense_unet_output.squeeze(1)).abs().mean()
434 |             return 0., 0., loss, loss, 0.
435 | 
436 |         else:
437 |             raise Error(f'undefined loss_type {self.loss_type}')
438 | 
439 | 
440 | 
441 | 
442 |     def _step(self, batch, batch_idx, num_zero_samples=0):
443 |         # batch is a set of uv coords and complex visibilities
444 |         uv_coords, uv_dense, vis_sparse, visibilities, img, label = batch
445 | 
446 |         pos = uv_coords
447 |         pe_uv = self.pe_encoder(uv_coords)
448 |         vis_sparse_cpu = vis_sparse.cpu()
449 |         vis_sparse_np = vis_sparse_cpu.detach().numpy()
450 |         # debug: dump the sparse visibilities for offline inspection
451 |         np.save('vis_sparse.npy', vis_sparse_np)
452 | 
453 | 
454 | 
455 |         inputs_encoder = torch.cat([pe_uv, vis_sparse], dim=-1)
456 |         z = self.context_encoder(inputs_encoder, pos)
457 | 
458 |         halfspace = self._get_halfspace(uv_dense)
459 |         uv_coords_flipped = self._flip_uv(uv_dense, halfspace)
460 |         vis_conj = self._conjugate_vis(visibilities, halfspace)
461 | 
462 |         # now condition MLP on z #
463 |         pred_visibilities = self(uv_coords_flipped, z) #Bx HW x2
464 |         vis_real = vis_conj.real.float()
465 |         vis_imag = vis_conj.imag.float()
466 | 
467 |         freq_norms = torch.sqrt(torch.sum(uv_dense**2, -1))
468 |         abs_pred = torch.sqrt(pred_visibilities[:,:,0]**2 + pred_visibilities[:,:,1]**2)
469 |         energy = torch.mean(freq_norms*abs_pred)
470 | 
471 |         real_loss = self.loss_func(vis_real, pred_visibilities[:,:,0])
472 |         imaginary_loss = self.loss_func(vis_imag, pred_visibilities[:,:,1])
473 | 
474 |         pred_vis_real = pred_visibilities[:,:,0]
475 |         pred_vis_imag = pred_visibilities[:,:,1]
476 |         pred_vis_imag[halfspace] = -pred_vis_imag[halfspace]
477 |         vis_imag[halfspace] = -vis_imag[halfspace]
478 |         nF = 0
479 |         if nF == 0: nF = self.hparams.num_fourier_coeff
480 |         pred_vis_imag = pred_vis_imag.reshape((-1, nF, nF))
481 |         pred_vis_real = pred_vis_real.reshape((-1, nF, nF))
482 |         vis_imag = vis_imag.reshape((-1, nF, nF))
483 |         vis_real = vis_real.reshape((-1, nF, nF))
484 | 
485 |         pred_vis_imag = pred_vis_imag.unsqueeze(1).unsqueeze(-1)
486 |         pred_vis_real = pred_vis_real.unsqueeze(1).unsqueeze(-1)
487 |         pred_vis = torch.cat((pred_vis_imag, pred_vis_real), dim=-1)
488 |         vis_imag = vis_imag.unsqueeze(1).unsqueeze(-1)
489 |         vis_real = vis_real.unsqueeze(1).unsqueeze(-1)
490 |         vis = torch.cat((vis_imag, vis_real), dim=-1)
491 | 
492 |         loss = ffl(pred_vis, vis)
493 | 
494 | 
495 |         return real_loss, imaginary_loss, loss, energy
496 | 
497 |     def training_step(self, batch, batch_idx, if_profile=False):
498 | 
499 |         if if_profile:
500 |             print('start: training step')
501 |             start = torch.cuda.Event(enable_timing=True)
502 |             end = torch.cuda.Event(enable_timing=True)
503 |             start.record()
504 | 
505 |         if self.use_unet:
506 |             real_loss, imaginary_loss, image_loss, loss, energy= self._step_unet(batch, batch_idx)
507 |         elif self.loss_type=='spectral':
508 |             real_loss, imaginary_loss, loss, energy= self._step(batch, batch_idx)
509 |         elif self.loss_type=='image' or self.loss_type=='image_spectral':
510 |             real_loss, imaginary_loss, image_loss, loss, energy= self._step_image_loss(batch, batch_idx, loss_type=self.loss_type)
511 | 
self.log('train/image_loss', image_loss, 512 | sync_dist=True if self.ngpu > 1 else False, 513 | rank_zero_only=True if self.ngpu > 1 else False,) 514 | 515 | log_vars = [real_loss, 516 | loss, 517 | imaginary_loss, 518 | energy] 519 | log_names = ['train/real_loss', 520 | 'train/total_loss', 521 | 'train/imaginary_loss', 522 | 'train_metadata/energy'] 523 | 524 | for name, var in zip(log_names, log_vars): 525 | self.log(name, var, 526 | sync_dist=True if self.ngpu > 1 else False, 527 | rank_zero_only=True if self.ngpu > 1 else False) 528 | 529 | return loss 530 | 531 | def test_step(self, batch, batch_idx): 532 | os.makedirs(self.test_fldr, exist_ok=True) 533 | os.makedirs(f'{self.test_fldr}/{batch_idx}', exist_ok=True) 534 | 535 | uv_coords, uv_dense, vis_sparse, visibilities, img, label = batch 536 | B= uv_dense.shape[0] 537 | vis_maps = visibilities.reshape(-1, self.num_fourier_coeff, self.num_fourier_coeff) 538 | eht_fov = 1.4108078120287498e-09 539 | max_base = 8368481300.0 540 | # img_res = self.hparams.input_size 541 | img_res = img.shape[-1] 542 | nF = self.hparams.num_fourier_coeff 543 | scale_ux= max_base * eht_fov/ img_res 544 | 545 | if not self.use_unet: 546 | if self.loss_type=='spectral': 547 | pos = uv_coords 548 | pe_uv = self.pe_encoder(uv_coords) 549 | else: 550 | pos = uv_coords*scale_ux 551 | pe_uv = self.pe_encoder(uv_coords*scale_ux) 552 | inputs_encoder = torch.cat([pe_uv, vis_sparse], dim=-1) 553 | z = self.context_encoder(inputs_encoder, pos) 554 | 555 | halfspace = self._get_halfspace(uv_dense) 556 | uv_coords_flipped = self._flip_uv(uv_dense, halfspace) 557 | vis_conj = self._conjugate_vis(visibilities, halfspace) 558 | vis_real = vis_conj.real.float() 559 | vis_imag = vis_conj.imag.float() 560 | freq_norms = torch.sqrt(torch.sum(uv_dense**2, -1)) 561 | vis_imag[halfspace] = -vis_imag[halfspace] 562 | nF = 0 563 | if nF == 0: nF = self.hparams.num_fourier_coeff 564 | vis_imag = vis_imag.reshape((-1, nF, nF)) 565 | vis_real = vis_real.reshape((-1, nF, nF)) 566 | plt.imsave(fname="../test_res1/Galaxy10-DEC-cont/transformer/mlp_8_layer/image_loss-NF_128/"+str(ti)+"/GT_imag.png", arr=vis_imag[0].cpu().detach().numpy(), cmap="viridis") 567 | plt.imsave(fname="../test_res1/Galaxy10-DEC-cont/transformer/mlp_8_layer/image_loss-NF_128/"+str(ti)+"/GT_real.png", arr=vis_real[0].cpu().detach().numpy(), cmap="viridis") 568 | np.save("../test_res1/Galaxy10-DEC-cont/transformer/mlp_8_layer/image_loss-NF_128/"+str(ti)+"/GT_vis.npy", np.stack((vis_real[0].cpu().detach().numpy(), vis_imag[0].cpu().detach().numpy()), axis=0)) 569 | 570 | 571 | #get the query uv's 572 | if self.uv_coords_grid_query is None: 573 | uv_dense_per=uv_dense[0] 574 | u_dense, v_dense= uv_dense_per[:,0].unique(), uv_dense_per[:,1].unique() 575 | u_dense= torch.linspace( u_dense.min(), u_dense.max(), len(u_dense)//2 * 2 ).to(u_dense) 576 | v_dense= torch.linspace( v_dense.min(), v_dense.max(), len(v_dense)//2 * 2 ).to(u_dense) 577 | uv_arr= torch.cat([u_dense.unsqueeze(-1), v_dense.unsqueeze(-1)], dim=-1) 578 | uv_arr= ((uv_arr+.5) * 2 -1.) 
* scale_ux # scaled input
579 |             U,V= torch.meshgrid(uv_arr[:,0], uv_arr[:,1])
580 |             uv_coords_grid_query= torch.cat((U.reshape(-1,1), V.reshape(-1,1)), dim=-1).unsqueeze(0).repeat(B,1,1)
581 |             self.uv_arr= uv_arr
582 |             self.U, self.V= U,V
583 |             self.uv_coords_grid_query= uv_coords_grid_query
584 |             print('initialized self.uv_coords_grid_query')
585 |         if self.norm_fact is None: # get the normalization factor, which is fixed given image/spectral domain dimensions
586 |             visibilities_map= visibilities.reshape(-1, self.num_fourier_coeff, self.num_fourier_coeff)
587 |             uv_dense_physical = (uv_dense.detach().cpu().numpy()[0,:,:] +0.5)*(2*max_base) - (max_base)
588 |             _, _, norm_fact = make_dirtyim(uv_dense_physical,
589 |                                            visibilities_map.detach().cpu().numpy()[ 0, :, :].reshape(-1),
590 |                                            img_res, eht_fov, return_im=True)
591 |             self.norm_fact= norm_fact
592 |             print('initialized the norm fact')
593 | 
594 |         # reconstruct dirty image via eht-im
595 |         # constants for our current datasets; TODO: get from metadata
596 | 
597 |         if self.use_unet:
598 |             if self.uv_dense_sparse_index is None:
599 |                 print('getting uv_dense_sparse_index...')
600 |                 uv_coords_per= uv_coords[0] #S,2
601 |                 uv_dense_per= uv_dense[0]#N,2
602 |                 uv_dense_sparse_index= []
603 |                 for i_sparse in range(uv_coords_per.shape[0]):
604 |                     uv_coord= uv_coords_per[i_sparse]
605 |                     uv_dense_equal= torch.logical_and(uv_dense_per[:,0]==uv_coord[0], uv_dense_per[:,1]==uv_coord[1])
606 |                     uv_dense_sparse_index.append( uv_dense_equal.nonzero() )
607 |                 uv_dense_sparse_index= torch.cat(uv_dense_sparse_index).long().to(uv_coords.device) # as in _step_unet
608 |                 print('done')
609 |                 self.uv_dense_sparse_index= uv_dense_sparse_index
610 |             #get the sparse visibility image (input to the UNet)
611 |             uv_dense_sparse_map= torch.zeros((uv_coords.shape[0], self.num_fourier_coeff**2, 2), ).to(uv_coords.device)
612 |             uv_dense_sparse_map[:,self.uv_dense_sparse_index,: ]=vis_sparse
613 |             uv_dense_sparse_map= uv_dense_sparse_map.permute(0, 2, 1).contiguous().reshape(-1, 2, self.num_fourier_coeff, self.num_fourier_coeff)
614 |             uv_dense_unet_output= self.UNET(uv_dense_sparse_map)# B,2,H,W
615 |             uv_dense_unet_output= torch.view_as_complex(uv_dense_unet_output.permute(0,2,3,1).contiguous()) # B,H,W
616 |             img_recon = make_im_torch(self.uv_arr, uv_dense_unet_output, img_res, eht_fov, norm_fact=self.norm_fact, return_im=True)
617 |             img_recon_gt = img
618 |             img_recon = (img_recon.real).float() / img_recon.abs().max()
619 |             img_recon_gt = (img_recon_gt).float() / img_recon_gt.abs().max()
620 | 
621 |         elif self.loss_type == 'spectral':
622 |             pred_fft = self.inference_w_conjugate(uv_dense, z, return_np=False)
623 |             uv_dense_physical = (uv_dense.detach().cpu().numpy()[0,:,:] +0.5)*(2*max_base) - (max_base)
624 |             img_recon = make_im_torch(self.uv_arr, pred_fft, img_res, eht_fov, norm_fact=self.norm_fact, return_im=True)
625 |             img_recon_gt= make_im_torch(self.uv_arr, vis_maps, img_res, eht_fov, norm_fact=self.norm_fact, return_im=True)
626 |             img_recon = (img_recon.real).float()
627 |             img_recon_gt = (img_recon_gt.real).float()
628 | 
629 |         elif self.loss_type in ('image', 'image_spectral'):
630 |             pred_visibilities = self(self.uv_coords_grid_query, z) #Bx (HW) x 2
631 |             pred_visibilities_map= torch.view_as_complex(pred_visibilities).reshape(B, self.U.shape[0], self.U.shape[1])
632 |             img_recon = make_im_torch(self.uv_arr, pred_visibilities_map, img_res, eht_fov, norm_fact=self.norm_fact, return_im=True)
633 |             img_recon_gt= make_im_torch(self.uv_arr, vis_maps, img_res, eht_fov, norm_fact=self.norm_fact,
return_im=True) 634 | img_recon = (img_recon.real).float() / img_recon.abs().max() 635 | img_recon_gt = (img_recon_gt.real).float() / img_recon_gt.abs().max() 636 | 637 | plt.imsave(f'{self.test_fldr}/{batch_idx}/image.png', img_recon_gt.reshape(-1, img.shape[-1]).cpu(), cmap='hot') 638 | plt.imsave(f'{self.test_fldr}/{batch_idx}/recon_image.png', img_recon.reshape(-1, img.shape[-1]).cpu(), cmap='hot') 639 | 640 | def validation_step(self, batch, batch_idx): 641 | pass 642 | 643 | 644 | def validation_epoch_end(self, outputs): 645 | pass 646 | 647 | 648 | def from_pretrained(self, checkpoint_name): 649 | return self.load_from_checkpoint(checkpoint_name, strict=False) 650 | 651 | def configure_optimizers(self): 652 | return torch.optim.Adam(self.parameters(), lr=1e-4) 653 | 654 | @staticmethod 655 | def add_model_specific_args(parent_parser): 656 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 657 | parser.add_argument('--L_embed', type=int, default=128) 658 | parser.add_argument('--input_encoding', type=str, choices=['fourier','nerf','none'], default='nerf') 659 | parser.add_argument('--learning_rate', type=float, default=1e-4) 660 | parser.add_argument('--sigma', type=float, default=5.0) #sigma=1 seems to underfit and 4 overfits/memorizes 661 | parser.add_argument('--model_checkpoint', type=str, default='') #, default='./vae_flow_e2e_kl0.1_epoch139.ckpt') 662 | parser.add_argument('--val_fldr', type=str, default=f'./val_fldr-test') 663 | 664 | return parser 665 | 666 | 667 | def parse_yaml(args,): 668 | ''' 669 | Parse the yaml file; settings from the yaml file override the argparse values, except for the command-line options restored explicitly below 670 | args: 671 | argparse.Namespace 672 | ''' 673 | import yaml 674 | 675 | opt=vars(args) 676 | opt_raw= vars(args).copy() 677 | args_yaml= yaml.unsafe_load(open(args.yaml_file)) 678 | opt.update(args_yaml,) 679 | 680 | opt['eval'] =opt_raw['eval'] 681 | opt['exp_name'] =opt_raw['exp_name'] 682 | opt['ngpus'] =opt_raw['ngpus'] 683 | opt['dataset']= opt_raw['dataset'] 684 | opt['model_checkpoint'] =opt_raw['model_checkpoint'] 685 | opt['dataset_path']= opt_raw['dataset_path'] 686 | opt['data_path_imgs']= opt_raw['data_path_imgs'] 687 | opt['data_path_cont']= opt_raw['data_path_cont'] 688 | opt['loss_type']= opt_raw['loss_type'] 689 | opt['num_fourier']= opt_raw['num_fourier'] 690 | opt['input_size']= opt_raw['input_size'] 691 | 692 | args= argparse.Namespace(**opt) 693 | return args 694 | 695 | 696 | def cli_main(): 697 | pl.seed_everything(42) 698 | 699 | # ------------ 700 | # args 701 | # ------------ 702 | parser = ArgumentParser() 703 | parser.add_argument('--exp_name', type=str, default='test') #default='Galaxy10_DECals_cont_mlp8') 704 | parser.add_argument('--ngpus', nargs='+', type=int, default=[0]) 705 | parser.add_argument('--eval', action='store_true', 706 | default=False, help='evaluation mode [False]') 707 | 708 | parser.add_argument('--batch_size', default=32, type=int) 709 | parser.add_argument('--num_workers', default=16, type=int) 710 | parser.add_argument('--dataset', type=str, 711 | # default='Galaxy10', 712 | default='Galaxy10_DECals', 713 | help='MNIST | Galaxy10 | Galaxy10_DECals') 714 | 715 | parser.add_argument('--dataset_path', type=str, 716 | #default=f'/astroim//data/eht_grid_256FC_200im_MNIST_full.h5', 717 | #default=f'/astroim//data/eht_grid_256FC_200im_Galaxy10_full.h5', 718 | #default=f'/astroim/data/eht_grid_256FC_200im_Galaxy10_DECals_full.h5', 719 | # default=f'/astroim/data/eht_grid_256FC_200im_Galaxy10_DECals_test100.h5', 720 |
default=f'../data/eht_grid_256FC_200im_Galaxy10_DECals_full.h5', 721 | # default=f'../data/eht_grid_128FC_200im_Galaxy10_full.h5', 722 | help='dataset path to precomputed spectral data (dense grid and sparse grid)') 723 | 724 | parser.add_argument('--data_path_cont', type=str, 725 | #default=f'/astroim/data/eht_cont_200im_MNIST_full.h5', 726 | # default=f'../data/eht_cont_200im_Galaxy10_full.h5', 727 | # default=f'../data/eht_cont_200im_Galaxy10_DECals_full.h5', 728 | # default=f'/astroim/data/eht_cont_200im_Galaxy10_DECals_full.h5', 729 | default=None, 730 | help='dataset path to precomputed spectral data (continuous)') 731 | 732 | parser.add_argument('--data_path_imgs', type=str, 733 | # default=None, 734 | # default='../data/Galaxy10.h5', 735 | default='../data/Galaxy10_DECals.h5', 736 | help='dataset path to Galaxy10 images; for MNIST, it is by default at ./MNIST; if None, sets to 0s (faster, imgs usually not needed)') 737 | 738 | parser.add_argument('--input_size', default=64, type=int) 739 | parser.add_argument('--num_fourier', default=256, type=int) 740 | parser.add_argument('--loss_type', type=str, default='spectral', help='spectral | image | image_spectral [spectral]') 741 | parser.add_argument('--scale_loss_image', type=float, default=1., help='only valid if loss_type is image_spectral' ) 742 | parser.add_argument('--mlp_layers', default=8, type=int, help=' # of layers in mlp, this will also decide the # of tokens [8]') 743 | parser.add_argument('--mlp_hidden_dim', default=256, type=int, help=' hidden dims in mlp [256]') 744 | parser.add_argument('--m_epochs', default=400, type=int, help= '# of max training epochs [400]') 745 | 746 | parser.add_argument('--yaml_file', default='', type=str, help='path to yaml file') 747 | 748 | parser = pl.Trainer.add_argparse_args(parser) # get lightning-specific commandline options 749 | parser = PolarRec.add_model_specific_args(parser) # get model-defined commandline options 750 | args = parser.parse_args() 751 | 752 | 753 | yaml_file= args.yaml_file 754 | if len(yaml_file)>0: 755 | args = parse_yaml(args) 756 | 757 | 758 | latent_dim = 1024 759 | 760 | # # ------------ 761 | # # data 762 | # # ------------ 763 | # # load up dataset of u, v vis and images 764 | 765 | dataset = EHTIM_Dataset(dset_name = args.dataset, 766 | data_path = args.dataset_path, 767 | data_path_cont = args.data_path_cont, 768 | data_path_imgs = args.data_path_imgs, 769 | img_res = args.input_size, 770 | pre_normalize = False, 771 | ) 772 | 773 | split_train, split_test = random_split(dataset, [len(dataset)-len(dataset)//5, len(dataset)//5]) 774 | split_train, split_val = random_split(split_train, [len(split_train)-len(dataset)//10, len(dataset)//10]) 775 | 776 | 777 | 778 | 779 | ngpu = torch.cuda.device_count() 780 | 781 | 782 | train_loader = DataLoader( 783 | split_train, 784 | batch_size=args.batch_size, 785 | num_workers=args.num_workers, 786 | shuffle=True, drop_last=True) 787 | 788 | val_loader = DataLoader( 789 | split_val, 790 | batch_size=args.batch_size, 791 | num_workers=args.num_workers, 792 | drop_last=True) 793 | 794 | test_loader = DataLoader( 795 | split_test, 796 | batch_size=1, 797 | num_workers=args.num_workers, 798 | drop_last=True) 799 | 800 | 801 | # ------------ 802 | # model 803 | # ------------ 804 | mlp_hiddens = [args.mlp_hidden_dim for i in range(args.mlp_layers-1)] 805 | implicitModel = PolarRec(args, 806 | learning_rate=1e-4, 807 | L_embed=args.L_embed, 808 | input_encoding=args.input_encoding, 809 | sigma=args.sigma, 810 | num_fourier_coeff=args.num_fourier, 811
| batch_size=args.batch_size, 812 | input_size=args.input_size, 813 | latent_dim=latent_dim, 814 | hidden_dims=mlp_hiddens, 815 | model_checkpoint=args.model_checkpoint, 816 | ngpu=ngpu) 817 | 818 | if len(args.model_checkpoint)>0: 819 | print(f'--- loading from {args.model_checkpoint}...') 820 | implicitModel = implicitModel.load_from_checkpoint(args.model_checkpoint) 821 | implicitModel.ngpu= ngpu 822 | 823 | checkpoint_callback = ModelCheckpoint(monitor='train/total_loss',dirpath='') # note: not passed to the trainer below 824 | # trainer = pl.Trainer(callbacks=[checkpoint_callback]) 825 | # trainer = pl.Trainer(callbacks=[EarlyStopping(monitor='val_loss')]) 826 | trainer = pl.Trainer.from_argparse_args(args, 827 | gpus=args.ngpus, 828 | plugins=DDPPlugin(find_unused_parameters=False), 829 | replace_sampler_ddp=True, 830 | accelerator='ddp', 831 | progress_bar_refresh_rate=20, 832 | max_epochs=args.m_epochs, 833 | val_check_interval=0.25, 834 | ) 835 | 836 | # ------------ 837 | # training 838 | # ------------ 839 | if not args.eval: 840 | print('==Training==') 841 | print(f'--- loading from {args.model_checkpoint}...') 842 | 843 | trainer.fit(implicitModel, train_loader, val_loader) 844 | print(implicitModel) 845 | else: 846 | print('==Testing==') 847 | trainer.test(implicitModel, test_loader, ) 848 | 849 | 850 | 851 | if __name__ == '__main__': 852 | cli_main() 853 | 854 | -------------------------------------------------------------------------------- /data_ehtim_cont.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | import matplotlib.gridspec as gridspec 4 | import matplotlib.colors as colors 5 | 6 | import numpy as np 7 | from numpy.fft import fft2, fftshift 8 | from scipy import interpolate 9 | 10 | from torchvision import transforms 11 | from PIL import Image 12 | 13 | import ehtim as eh 14 | import ehtim.const_def as ehc 15 | 16 | import h5py 17 | import torch 18 | from torch.utils.data import Dataset 19 | from skimage import color 20 | 21 | # CLEAN tests 22 | from ehtim.imaging.clean import * 23 | # import gICLEAN 24 | 25 | import socket 26 | hostname= socket.gethostname() 27 | 28 | from scipy.signal import convolve2d 29 | 30 | torch.manual_seed(0) 31 | np.random.seed(0) 32 | 33 | def rgb2gray(rgb): 34 | r, g, b = rgb[:,:,0], rgb[:,:,1], rgb[:,:,2] 35 | gray = 0.2989 * r + 0.5870 * g + 0.1140 * b 36 | return gray 37 | 38 | 39 | def preprocess_ehtim(img): 40 | # load the image (array or tensor) into eht obs format 41 | 42 | if (torch.is_tensor(img) or isinstance(img, np.ndarray)): # if img arr or tensor 43 | image = img.numpy() if torch.is_tensor(img) else img 44 | else: 45 | raise TypeError('img type not recognized') 46 | 47 | if image.ndim==3 and image.shape[-1]==3: 48 | image = rgb2gray( image ) 49 | 50 | return image 51 | 52 | 53 | def eht_createImg(image, normalize=False, pulse=ehc.PULSE_DEFAULT, obs_type='eht'): 54 | ''' 55 | image - np array 56 | ''' 57 | 58 | filename = './code/avery_m87_2_eofn.txt' # 200x200 59 | if obs_type == 'dense': # synthetic dense array 60 | meta_file =f'data/eht-imaging/array_test_dense.txt' 61 | elif obs_type == 'sparse': # synthetic sparse array 62 | meta_file =f'data/eht-imaging/array_test_sparse.txt' 63 | else: # EHT array from CHIRP Supplement 64 | meta_file ='./code/EHT2017.txt' 65 | 66 | assert image.shape[0]==200 67 | 68 | # Read the header 69 | file = open(filename) 70 | src = ' '.join(file.readline().split()[2:]) 71 | ra = file.readline().split() 72 | ra = float(ra[2]) + float(ra[4]) / 60.0 + float(ra[6]) / 3600.0 73 | dec = file.readline().split() 74 | dec
= np.sign(float(dec[2])) * (abs(float(dec[2])) + 75 | float(dec[4]) / 60.0 + float(dec[6]) / 3600.0) 76 | mjd_float = float(file.readline().split()[2]) 77 | mjd = int(mjd_float) 78 | time = (mjd_float - mjd) * 24 79 | rf = float(file.readline().split()[2]) * 1e9 80 | xdim = file.readline().split() 81 | xdim_p = int(xdim[2]) 82 | psize_x = float(xdim[4]) * ehc.RADPERAS / xdim_p 83 | ydim = file.readline().split() 84 | ydim_p = int(ydim[2]) 85 | psize_y = float(ydim[4]) * ehc.RADPERAS / ydim_p 86 | file.close() 87 | 88 | if normalize: 89 | img = image / np.sqrt((image**2).sum()) 90 | else: 91 | img = image 92 | 93 | # load the image 94 | eht_image= eh.image.Image( 95 | img, 96 | psize_x, ra, dec, 97 | rf=rf, source=src, mjd=mjd, time=time, pulse=pulse, 98 | polrep='stokes', pol_prim='I') 99 | 100 | # load meta 101 | eht_meta = eh.array.load_txt(meta_file) 102 | return eht_image, eht_meta 103 | 104 | 105 | def upscale_tensor(x, final_res=256, method='nearest'): 106 | init_res = x.shape[0] 107 | xy_idx_dense = np.mgrid[:init_res,:init_res] 108 | x_idx_dense = xy_idx_dense[0].flatten() 109 | y_idx_dense = xy_idx_dense[1].flatten() 110 | 111 | # meshgrid from 0..(final_res-1)/final_res with final_res number of entries 112 | U, V = torch.meshgrid(torch.arange(final_res), torch.arange(final_res)) 113 | U, V = U/float(final_res), V/float(final_res) 114 | 115 | # now it's a meshgrid from 0..(1-1/final_res)*init_res = init_res - init_res/final_res 116 | U, V = init_res*U, init_res*V 117 | 118 | upscaled = interpolate.griddata((x_idx_dense, y_idx_dense), x.flatten(), (U, V), method=method, fill_value=-0.5) 119 | 120 | return upscaled 121 | 122 | 123 | def obs_with_eht(img_path, obs_type='eht', eht_npix=200): 124 | 125 | image = preprocess_ehtim(img_path) 126 | eht_im, eht_meta = eht_createImg(image, normalize=True, obs_type=obs_type) 127 | 128 | # Observe the image 129 | # tint_sec is the integration time in seconds, and tadv_sec is the advance time between scans 130 | # tstart_hr is the GMST time of the start of the observation and tstop_hr is the GMST time of the end 131 | # bw_hz is the bandwidth in Hz 132 | # sgrscat=True blurs the visibilities with the Sgr A* scattering kernel for the appropriate image frequency 133 | # ampcal and phasecal determine if gain variations and phase errors are included 134 | if obs_type=='dense': 135 | tadv_sec = 600 136 | elif obs_type=='sparse': 137 | tadv_sec = 6000 138 | else: # default use EHT 139 | tadv_sec = 600 140 | tstart_hr = 0 141 | tstop_hr = 24 142 | tint_sec = 12 143 | bw_hz = 4.096e9 144 | eht_obs = eht_im.observe(eht_meta, tint_sec, tadv_sec, tstart_hr, tstop_hr, bw_hz, 145 | sgrscat=False, add_th_noise=False, ampcal=True, phasecal=True, ttype='direct') 146 | 147 | # FOV used in CHIRP (approx angular size of M87 SMBH) [200x200] 148 | eht_fov = np.radians(.000291/3600) 149 | 150 | # Resolution 151 | eht_res = eht_obs.res() # nominal array resolution, 1/longest baseline 152 | print("Nominal Resolution: " , eht_res) 153 | print("FoV: " , eht_fov) 154 | 155 | return eht_obs, eht_im, eht_res, eht_fov, eht_npix 156 | 157 | def make_im_torch( 158 | uv_arr, vis_arr, npix, fov, pulse=ehc.PULSE_DEFAULT, weighting='uniform', norm_fact=None, return_im=False, seperable_FFT=True, rescaled_pix=True): 159 | """Make the observation image using direct Fourier transform. 
160 | Assumes the visibilities lie on a regular grid in the continuous domain 161 | 162 | Args: 163 | uv_arr - U x 2 axis values of the regular grid (U==V) 164 | vis_arr - B x U x V complex visibility maps 165 | npix (int): The pixel size of the square output image. 166 | fov (float): The field of view of the square output image in radians. 167 | pulse (function): The function convolved with the pixel values for continuous image. 168 | weighting (str): 'uniform' or 'natural' 169 | Returns: 170 | (Tensor): B x npix x npix complex dirty image. 171 | """ 172 | import math 173 | 174 | if rescaled_pix: 175 | pdim = 1. #scaled input 176 | else: 177 | pdim = fov / npix 178 | 179 | u = uv_arr[:,0] 180 | v = uv_arr[:,1] 181 | 182 | B, U, V= vis_arr.shape[0], vis_arr.shape[1], vis_arr.shape[2] 183 | assert U==V 184 | 185 | #TODO: xlist as input to speed up 186 | #DONE: calculate the scale of u*x and v*x directly 187 | #DONE: scaled by normfac 188 | xlist = torch.arange(0, -npix, -1, device=uv_arr.device) * pdim + (pdim * npix) / 2.0 - pdim / 2.0 189 | 190 | 191 | # #--Separable 1D Inverse DFT--# 192 | if seperable_FFT: 193 | X_coord= xlist.reshape(1, npix, 1, 1, 1) 194 | Y_coord= xlist.reshape(1, 1, npix, 1, 1) 195 | U_coord= u.reshape(1,1,1, U,1) 196 | V_coord= v.reshape(1,1,1, 1,V) 197 | Vis= vis_arr.reshape(B, 1, 1, U, V) 198 | #the inner integration (over u) 199 | U_X= U_coord*X_coord 200 | 201 | # temp_a = Vis * torch.exp(-2.j* math.pi* U_X) 202 | # inner_integral= torch.sum(temp_a , dim=-2,keepdim=True)/temp_a.size(-2) #B X 1 1 V 203 | 204 | 205 | inner_integral= torch.mean(Vis * torch.exp(-2.j* math.pi* U_X) , dim=-2,keepdim=True) #B X 1 1 V 206 | #the outer integration (over v) 207 | V_Y= V_coord*Y_coord 208 | 209 | # temp_b=inner_integral * torch.exp(-2.j*math.pi* V_Y) 210 | # outer_integral= torch.sum(temp_b, dim=-1, keepdim=True )/temp_b.size(-1) # B X Y 1 1 211 | outer_integral= torch.mean(inner_integral * torch.exp(-2.j*math.pi* V_Y), dim=-1, keepdim=True ) # B X Y 1 1 212 | image_complex= outer_integral.squeeze(-1).squeeze(-1) # B X Y 213 | else: 214 | #--2D raw version IDFT--# 215 | X_coord= xlist.reshape(1, npix, 1, 1, 1).expand(B,npix,npix, U,V) 216 | Y_coord= xlist.reshape(1, 1, npix, 1, 1).expand_as(X_coord) 217 | U_coord= u.reshape(1,1,1, U,1).expand_as(X_coord) 218 | V_coord= v.reshape(1,1,1, 1,V).expand_as(X_coord) 219 | U_X= U_coord*X_coord 220 | V_Y= V_coord*Y_coord 221 | Vis= vis_arr.reshape(B, 1, 1, U, V).expand_as(X_coord) 222 | temp_c = Vis * torch.exp(-2.j*math.pi*(U_X + V_Y)) 223 | image_complex= torch.mean(temp_c, dim=-1).mean(dim=-1) 224 | # temp_d = torch.sum(temp_c, dim=-1)/temp_c.size(-1) 225 | # image_complex = torch.sum(temp_d)/temp_d.size(-1) 226 | 227 | 228 | if norm_fact is not None: 229 | image_complex= image_complex* norm_fact 230 | 231 | 232 | # import pdb; pdb.set_trace() 233 | return image_complex 234 | 235 | 236 | def make_dirtyim(uv_arr, vis_arr, npix, fov, pulse=ehc.PULSE_DEFAULT, weighting='uniform', return_im=False, cutoff_freq=0.03, sigma=1.0): 237 | """Make the observation dirty image (direct Fourier transform). 238 | 239 | Args: 240 | uv_arr (ndarray): N x 2 uv sample coordinates; vis_arr (ndarray): N complex visibilities. 241 | npix (int): The pixel size of the square output image. 242 | fov (float): The field of view of the square output image in radians. 243 | pulse (function): The function convolved with the pixel values for continuous image. 244 | weighting (str): 'uniform' or 'natural' 245 | Returns: 246 | (Image): an Image object with dirty image.
247 | """ 248 | 249 | pdim = fov / npix 250 | u = uv_arr[:,0] 251 | v = uv_arr[:,1] 252 | 253 | xlist = np.arange(0, -npix, -1) * pdim + (pdim * npix) / 2.0 - pdim / 2.0 254 | if weighting == 'natural': 255 | sigma = np.atleast_2d(sigma) 256 | print(u.shape, sigma.shape); input() 257 | weights = 1. / (sigma*sigma) 258 | else: 259 | weights = np.ones(u.shape) 260 | 261 | dim= np.array([[np.mean(weights * np.cos(-2 * np.pi * (i * u + j * v))) 262 | for i in xlist] 263 | for j in xlist]) 264 | normfac= 1. / np.sum(dim) 265 | 266 | vis = vis_arr 267 | 268 | # TODO -- use NFFT 269 | # TODO -- different beam weightings 270 | im = np.array([[np.mean(weights * (np.real(vis) * np.cos(-2 * np.pi * (i * u + j * v)) - 271 | np.imag(vis) * np.sin(-2 * np.pi * (i * u + j * v)))) 272 | for i in xlist] 273 | for j in xlist]) 274 | 275 | # Final normalization 276 | im = im * normfac 277 | im = im[0:npix, 0:npix] 278 | 279 | do_sinc = False 280 | if do_sinc: 281 | 282 | fc = cutoff_freq # Cutoff frequency as a fraction of the sampling rate (in (0, 0.5)). 283 | b = 2.0*fc/3.0 # Transition band, as a fraction of the sampling rate (in (0, 0.5)). 284 | N = int(np.ceil((4 / b))) 285 | if not N % 2: N += 1 # Make sure that N is odd. 286 | crop = int(N / 2) 287 | n = np.arange(N) 288 | 289 | # Compute sinc filter. 290 | h = np.sinc(2 * fc * (n - (N - 1) / 2)) 291 | 292 | # Compute Blackman window. 293 | #w = 0.42 - 0.5 * np.cos(2 * np.pi * n / (N - 1)) + \ 294 | # 0.08 * np.cos(4 * np.pi * n / (N - 1)) 295 | w = np.blackman(N) 296 | 297 | # Multiply sinc filter by window. 298 | h_windowed = h * w 299 | 300 | # Normalize to get unity gain. 301 | h_windowed = h_windowed / np.sum(h_windowed) 302 | 303 | do_plot = False 304 | if do_plot: 305 | import pylab as plt 306 | fig, axs = plt.subplots(nrows=2, ncols=3, constrained_layout=True) 307 | 308 | axs[0, 0].plot(h) 309 | axs[0, 0].set_title("sinc filter") 310 | axs[0, 1].plot(w) 311 | axs[0, 1].set_title("blackman window") 312 | axs[0, 2].plot(h_windowed) 313 | axs[0, 2].set_title("windowed sinc") 314 | axs[1, 0].plot(np.fft.fftshift(np.fft.fft(h)), 'o') 315 | axs[1, 0].set_title("sinc filter") 316 | axs[1, 1].plot(np.fft.fftshift(np.fft.fft(w)), 'o') 317 | axs[1, 1].set_title("blackman window") 318 | axs[1, 2].plot(np.fft.fftshift(np.fft.fft(h_windowed)), 'o') 319 | axs[1, 2].set_title("windowed sinc") 320 | plt.show() 321 | 322 | im_shape = im.shape 323 | im_x = np.stack([np.convolve(im[i,:], h) for i in range(im.shape[0])]) 324 | 325 | im_xy = np.stack([np.convolve(im_x[:,i], h) for i in range(im_x.shape[1])]) 326 | 327 | im = im_xy[crop:im_shape[0]+crop, crop:im_shape[1]+crop].T 328 | 329 | print(im.shape, N, crop, im_xy.shape, im_x.shape, w.shape, h.shape); 330 | 331 | out = eh.image.Image(im, pdim, 10, 20, pulse=pulse) # filler RA/Dec values 332 | #out = ehtim.image.Image(im, pdim, self.ra, self.dec, polrep=self.polrep, 333 | # rf=self.rf, source=self.source, mjd=self.mjd, pulse=pulse) 334 | if not return_im: 335 | return out 336 | 337 | else: 338 | return out, im, normfac 339 | 340 | 341 | def get_uvvis_data(img_path, obs_type='eht', eht_npix=200, num_fourier_coeff=64): 342 | """ obs an image with ehtim, return {u,v,vis} for grid dense, continuous sparse, grid sparse data 343 | """ 344 | # data dicts 345 | grid_dense = {} 346 | cont_sparse = {} 347 | grid_sparse = {} 348 | obs_meta = {} 349 | 350 | # eht-im observation (continuous sparse) 351 | eht_obs, eht_im, eht_res, eht_fov, eht_npix = obs_with_eht(img_path, obs_type=obs_type, eht_npix=eht_npix) 352 | u_eht 
= np.array(eht_obs.unpack(['u'], conj=True)).astype(float) 353 | v_eht = np.array(eht_obs.unpack(['v'], conj=True)).astype(float) 354 | vis_eht = np.array(eht_obs.unpack(['vis'], conj=True)).astype(complex) 355 | uv_dist_eht = np.array(eht_obs.unpack(['uvdist'], conj=True)).astype(float) 356 | 357 | # dataset: ground truth (scaled to eht_npix) 358 | obs_meta['gt_img'] = eht_im.imarr() 359 | obs_meta['res'] = eht_res 360 | obs_meta['fov'] = eht_fov 361 | obs_meta['npix'] = eht_npix 362 | obs_meta['n_FC'] = num_fourier_coeff 363 | obs_meta['sigma'] = eht_obs.unpack(['sigma']) 364 | 365 | # dataset: continuous sparse 366 | cont_sparse['uv'] = np.stack((u_eht, v_eht), axis=1) 367 | cont_sparse['vis'] = vis_eht 368 | #cont_sparse['dim'] = make_dirtyim(cont_sparse['uv'], cont_sparse['vis'], eht_npix, eht_fov) 369 | 370 | # dataset: grid dense 371 | '''max_base = np.max(uv_dist_eht) 372 | x = np.linspace(-max_base, max_base, num_fourier_coeff) 373 | y = np.linspace(-max_base, max_base, num_fourier_coeff) 374 | xv, yv = np.meshgrid(x, y) 375 | grid_dense['uv'] = np.stack((xv.ravel(), yv.ravel()), axis=1) 376 | grid_dense['vis'] = eht_im.sample_uv(grid_dense['uv'])[0] # ignore polarizations 377 | #grid_dense['dim'] = make_dirtyim(grid_dense['uv'], grid_dense['vis'], eht_npix, eht_fov) 378 | 379 | # dataset: grid sparse 380 | x_centers = (x[1:]+x[:-1])/2 381 | y_centers = (y[1:]+y[:-1])/2 382 | u_dig = np.digitize(u_eht, x_centers) 383 | v_dig = np.digitize(v_eht, y_centers) 384 | uv_dig = np.stack((x[u_dig], y[v_dig]), axis=1) 385 | grid_sparse['uv'] = np.unique(uv_dig , axis=0) # remove duplicates 386 | grid_sparse['vis'] = eht_im.sample_uv(grid_sparse['uv'])[0]''' 387 | #grid_sparse['dim'] = make_dirtyim(grid_sparse['uv'], grid_sparse['vis'], eht_npix, eht_fov) 388 | 389 | #return grid_dense, cont_sparse, grid_sparse, obs_meta 390 | return None, cont_sparse, None, obs_meta 391 | 392 | 393 | def plot_eht_compare(grid_dense, cont_sparse, grid_sparse, obs_meta, savefig=False, cutoff_freq=0.03): 394 | """ 3-row plot for eht_npix resolution obs """ 395 | 396 | # make dirty images: 397 | eht_npix, eht_fov, num_fourier_coeff, sigma = obs_meta['npix'], obs_meta['fov'], obs_meta['n_FC'], obs_meta['sigma'] 398 | dim_grid_dense, im1, norm_grid_dense = make_dirtyim(grid_dense['uv'], grid_dense['vis'], eht_npix, eht_fov, sigma=sigma, cutoff_freq=cutoff_freq, return_im=True) 399 | dim_cont_sparse, im2, norm_cont_sparse = make_dirtyim(cont_sparse['uv'], cont_sparse['vis'], eht_npix, eht_fov, sigma=sigma, cutoff_freq=cutoff_freq, return_im=True) 400 | dim_grid_sparse, im3, norm_grid_sparse = make_dirtyim(grid_sparse['uv'], grid_sparse['vis'], eht_npix, eht_fov, sigma=sigma, cutoff_freq=cutoff_freq, return_im=True) 401 | 402 | dirty_beam, im4, norm_dirty_beam = make_dirtyim(cont_sparse['uv'], np.ones_like(cont_sparse['vis']), eht_npix, eht_fov, sigma=sigma, cutoff_freq=cutoff_freq, return_im=True) 403 | 404 | dim_grid_dense = dim_grid_dense.imarr() 405 | dim_cont_sparse = dim_cont_sparse.imarr() 406 | dim_grid_sparse = dim_grid_sparse.imarr() 407 | dirty_beam = dirty_beam.imarr() 408 | 409 | '''import pylab as plt 410 | fig, ax = plt.subplots(nrows=4, ncols=2) 411 | ax[0,0].imshow(dim_grid_dense) 412 | ax[0,1].imshow(im1) 413 | ax[1,0].imshow(dim_cont_sparse) 414 | ax[1,1].imshow(im2) 415 | ax[2,0].imshow(dim_grid_sparse) 416 | ax[2,1].imshow(im3) 417 | ax[3,0].imshow(dirty_beam) 418 | ax[3,1].imshow(im4) 419 | print(norm_grid_dense, norm_cont_sparse, norm_grid_sparse, norm_dirty_beam) 420 | print(dim_grid_dense.max(), im1.max()) 421 |
print(dim_cont_sparse.max(), im2.max()) 422 | print(dim_grid_sparse.max(), im3.max()) 423 | print(dirty_beam.max(), im4.max()) 424 | print('---') 425 | print(dim_grid_dense.min(), im1.min()) 426 | print(dim_cont_sparse.min(), im2.min()) 427 | print(dim_grid_sparse.min(), im3.min()) 428 | print(dirty_beam.min(), im4.min()) 429 | plt.show()''' 430 | 431 | '''import pylab as plt 432 | fig, ax = plt.subplots(nrows=1, ncols=2) 433 | ax[0].imshow(dim_cont_sparse) 434 | ax[1].imshow(dirty_beam) 435 | plt.show()''' 436 | 437 | # import gICLEAN 438 | # gICLEAN.clean_cuda(dirty_im=dim_cont_sparse/norm_cont_sparse, dirty_psf=dirty_beam/norm_dirty_beam, thresh=0.001, gain=1.0, clean_beam_size=4.0, 439 | # maxIter=1e6, 440 | # prefix='test4', 441 | # im_gt=dim_grid_dense/norm_grid_dense, 442 | # polarity=False) 443 | # input("done!") 444 | 445 | # plot properties 446 | vmin, vmax = 1e-4, 1e2 # 1e-2, 1e3 # fft color range 447 | #vmin_img, vmax_img = 1.5*np.min(img), 1.5*np.max(img) 448 | uv_dist_eht = np.linalg.norm(cont_sparse['uv'], axis=1) 449 | max_base = np.max(uv_dist_eht) 450 | 451 | # make figure 452 | fig = plt.figure(figsize=(16, 12), dpi=300) 453 | gs = gridspec.GridSpec(3,4, hspace=0.3, wspace=0.25) 454 | 455 | # grid_dense 456 | ax = plt.subplot(gs[0,0]) 457 | ax.set_title("(%s x %s Dense grid) Dirty Image\n$I^{D}_{grid}(l,m) \equiv \mathscr{F}^{-1}_{NU}[\hat{\mathcal{V}}_{EHT}(u,v)]$" % (num_fourier_coeff, num_fourier_coeff), fontsize=10) 458 | ax.imshow(dim_grid_dense) #, vmin=vmin_img, vmax=vmax_img) 459 | 460 | ax = plt.subplot(gs[0,1]) 461 | ax.set_title("Visibility Phase\n$ \\angle{\mathcal{V}(u,v)}$", fontsize=10) 462 | ax.scatter(grid_dense['uv'][:,0], grid_dense['uv'][:,1], c=np.angle(grid_dense['vis']), 463 | s=1, cmap='twilight', vmin=-np.pi, vmax=np.pi, rasterized=True) 464 | ax.set_xlim([-1.1*max_base, 1.1*max_base]) 465 | ax.set_ylim([-1.1*max_base, 1.1*max_base]) 466 | 467 | ax = plt.subplot(gs[0,2]) 468 | ax.set_title("Visibility Amplitude\n$|\mathcal{V}(u,v)|$", fontsize=10) 469 | ax.scatter(grid_dense['uv'][:,0], grid_dense['uv'][:,1], c=np.abs(grid_dense['vis']), 470 | s=1, cmap='viridis', vmin=vmin, vmax=vmax, rasterized=True) 471 | ax.set_xlim([-1.1*max_base, 1.1*max_base]) 472 | ax.set_ylim([-1.1*max_base, 1.1*max_base]) 473 | 474 | ax = plt.subplot(gs[0,3]) 475 | ax.set_title("Visibility Amplitude \n vs. UV distance", fontsize=10) 476 | ax.scatter(np.linalg.norm(grid_dense['uv'], axis=1), np.abs(grid_dense['vis']), c=np.abs(grid_dense['vis']), 477 | s=1, marker='.', vmin=vmin, vmax=vmax, rasterized=True) 478 | ax.text(0.03, 0.97, f"n={len(grid_dense['uv'])}", fontsize=8, ha='left', va='top', transform=ax.transAxes) 479 | ax.set_yscale('log') 480 | ax.set_xlim([0,1.25e10]) 481 | ax.set_ylim([1e-1,3000]) 482 | 483 | # cont_sparse 484 | ax = plt.subplot(gs[1,0]) 485 | ax.set_title("(EHT) Dirty Image\n$I^{D}_{EHT}(l,m) \equiv \mathscr{F}^{-1}_{NU}[\hat{\mathcal{V}}_{EHT}(u,v)]$", fontsize=10) 486 | ax.imshow(dim_cont_sparse) #, vmin=vmin_img, vmax=vmax_img) 487 | 488 | ax = plt.subplot(gs[1,1]) 489 | ax.set_title("(EHT) Observed Visib.
Phase\n$ \\angle{\hat{\mathcal{V}}_{EHT}(u,v)}$", fontsize=10) 490 | ax.scatter(grid_dense['uv'][:,0], grid_dense['uv'][:,1], c='0.5', alpha=0.7, s=0.1, marker='.', rasterized=True) 491 | ax.scatter(cont_sparse['uv'][:,0], cont_sparse['uv'][:,1], c=np.angle(cont_sparse['vis']), 492 | s=1, marker='.', cmap='twilight', vmin=-np.pi, vmax=np.pi, rasterized=True) 493 | ax.set_xlim([-1.1*max_base, 1.1*max_base]) 494 | ax.set_ylim([-1.1*max_base, 1.1*max_base]) 495 | 496 | ax = plt.subplot(gs[1,2]) 497 | ax.set_title("(EHT) Observed Visib. Amp\n$|\hat{\mathcal{V}}_{EHT}(u,v)|$", fontsize=10) 498 | ax.scatter(grid_dense['uv'][:,0], grid_dense['uv'][:,1], c='0.5', alpha=0.7, s=0.1, marker='.', rasterized=True) 499 | ax.scatter(cont_sparse['uv'][:,0], cont_sparse['uv'][:,1], c=np.abs(cont_sparse['vis']), 500 | s=1, marker='.', cmap='viridis', vmin=vmin, vmax=vmax, rasterized=True) 501 | ax.set_xlim([-1.1*max_base, 1.1*max_base]) 502 | ax.set_ylim([-1.1*max_base, 1.1*max_base]) 503 | 504 | ax = plt.subplot(gs[1,3]) 505 | ax.set_title("(EHT) Visib. Amp. \n vs. UV distance", fontsize=10) 506 | ax.scatter(np.linalg.norm(grid_dense['uv'], axis=1), np.abs(grid_dense['vis']), c='0.5', alpha=0.7, s=0.1, marker='.', rasterized=True) 507 | ax.scatter(np.linalg.norm(cont_sparse['uv'], axis=1), np.abs(cont_sparse['vis']), c=np.abs(cont_sparse['vis']), 508 | s=1, vmin=vmin, vmax=vmax, rasterized=True) 509 | ax.text(0.03, 0.97, f"n={len(cont_sparse['uv'])}", fontsize=8, ha='left', va='top', transform=ax.transAxes) 510 | ax.set_yscale('log') 511 | ax.set_xlim([0,1.25e10]) 512 | ax.set_ylim([1e-1,3000]) 513 | 514 | # grid_sparse 515 | ax = plt.subplot(gs[2,0]) 516 | ax.set_title("(EHT grid) Dirty Image\n$I^{D}_{EHT,grid}(l,m) \equiv \mathscr{F}^{-1}_{NU}[\hat{\mathcal{V}}_{EHT,grid}(u,v)]$", fontsize=10) 517 | ax.imshow(dim_grid_sparse) #, vmin=vmin_img, vmax=vmax_img) 518 | 519 | ax = plt.subplot(gs[2,1]) 520 | ax.set_title("(EHT,grid) Visib. Phase\n$ \\angle{\hat{\mathcal{V}}_{EHT,grid}(u,v)}$", fontsize=10) 521 | ax.scatter(grid_dense['uv'][:,0], grid_dense['uv'][:,1], alpha=0.7, s=0.1, c='0.5', marker='.', rasterized=True) 522 | ax.scatter(grid_sparse['uv'][:,0], grid_sparse['uv'][:,1], c=np.angle(grid_sparse['vis']), 523 | s=1, marker='.', cmap='twilight', vmin=-np.pi, vmax=np.pi, rasterized=True) 524 | ax.set_xlim([-1.1*max_base, 1.1*max_base]) 525 | ax.set_ylim([-1.1*max_base, 1.1*max_base]) 526 | 527 | ax = plt.subplot(gs[2,2]) 528 | ax.set_title("(EHT,grid) Visib. Amp\n$|\hat{\mathcal{V}}_{EHT,grid}(u,v)|$", fontsize=10) 529 | ax.scatter(grid_dense['uv'][:,0], grid_dense['uv'][:,1], alpha=0.7, s=0.1, c='0.5', marker='.', rasterized=True) 530 | ax.scatter(grid_sparse['uv'][:,0], grid_sparse['uv'][:,1], c=np.abs(grid_sparse['vis']), 531 | s=1, marker='.', cmap='viridis', vmin=vmin, vmax=vmax, rasterized=True) 532 | ax.set_xlim([-1.1*max_base, 1.1*max_base]) 533 | ax.set_ylim([-1.1*max_base, 1.1*max_base]) 534 | 535 | ax = plt.subplot(gs[2,3]) 536 | ax.set_title("(EHT grid) Visib. Amp. \n vs. 
UV distance", fontsize=10) 537 | ax.scatter(np.linalg.norm(grid_dense['uv'], axis=1), np.abs(grid_dense['vis']), c='0.5', alpha=0.7, s=0.1, marker='.', rasterized=True) 538 | ax.scatter(np.linalg.norm(grid_sparse['uv'], axis=1), np.abs(grid_sparse['vis']), c=np.abs(grid_sparse['vis']), 539 | s=1, vmin=vmin, vmax=vmax, rasterized=True) 540 | ax.text(0.03, 0.97, f"n={len(grid_sparse['uv'])}", fontsize=8, ha='left', va='top', transform=ax.transAxes) 541 | ax.set_yscale('log') 542 | ax.set_xlim([0,1.25e10]) 543 | ax.set_ylim([1e-1,3000]) 544 | 545 | if savefig: 546 | plt.savefig(savefig, bbox_inches='tight') 547 | plt.close() 548 | 549 | 550 | def load_h5(fpath): 551 | print('--loading h5 file for Galaxy10 dataset...') 552 | with h5py.File(fpath, 'r') as F: 553 | x = np.array(F['images']) 554 | y = np.array(F['ans']) 555 | print('Done--') 556 | 557 | return x, y 558 | 559 | 560 | class Galaxy10_Dataset(Dataset): 561 | ''' 562 | loader for Galaxy10 version_1, lower resolution 563 | ''' 564 | def __init__(self, h5_path ='./dataset_ssd/astroImg/Galaxy10.h5', transform_in = None): 565 | if transform_in is None: 566 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 567 | else: 568 | transform = transform_in 569 | 570 | imgs, labels= load_h5(h5_path) 571 | self.imgs = imgs 572 | self.labels = labels 573 | self.transform = transform 574 | 575 | def __getitem__(self, idx): 576 | scale = 1/255. 577 | img_Lab = color.rgb2lab(self.imgs[idx]) 578 | img = self.transform(img_Lab[...,0] * scale) 579 | #tf2 = transforms.Compose([transforms.ToPILImage()]) 580 | #img_Lab = tf2(color.rgb2lab(self.imgs[idx])) 581 | #img = self.transform(img_Lab[...,0]) 582 | #img *= scale 583 | label = self.labels[idx] 584 | return img, label 585 | 586 | def __len__(self): 587 | #return len(img) 588 | return len(self.imgs) 589 | 590 | class Galaxy10_DECals_Dataset(Dataset): 591 | ''' 592 | loader for Galaxy10 DECals (version 2), 256x256 resolution 593 | ''' 594 | def __init__(self, h5_path ='/astroim/data/Galaxy10_DECals.h5', transform_in = None): 595 | if transform_in is None: 596 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 597 | else: 598 | transform = transform_in 599 | 600 | imgs, labels= load_h5(h5_path) 601 | self.imgs = imgs 602 | self.labels = labels 603 | self.transform = transform 604 | 605 | def __getitem__(self, idx): 606 | scale = 1/255. 
607 | img_Lab = color.rgb2lab(self.imgs[idx]) 608 | img = self.transform(img_Lab[...,0] * scale) 609 | #tf2 = transforms.Compose([transforms.ToPILImage()]) 610 | #img_Lab = tf2(color.rgb2lab(self.imgs[idx])) 611 | #img = self.transform(img_Lab[...,0]) 612 | #img *= scale 613 | label = self.labels[idx] 614 | return img, label 615 | 616 | def __len__(self): 617 | #return len(img) 618 | return len(self.imgs) 619 | 620 | class EHT_Continuous_Dataset(Dataset): 621 | ''' 622 | dataset for EHT imaging of MNIST or Galaxy10: 623 | returns {u,v,vis} for dense grid, sparse continuous, sparse grid 624 | 625 | dset_name = ['MNIST', 'Galaxy10'] 626 | obs_type = ['eht', 'sparse', 'dense'] # note: sparse/dense replace EHT array with artificial telescope array 627 | ''' 628 | 629 | def __init__(self, 630 | eht_npix = 200, 631 | num_FC = 64, 632 | dset_name = 'Galaxy10', 633 | h5_path_img = '../data/Galaxy10.h5', 634 | transform_in = None, 635 | obs_type='eht'): 636 | 637 | if dset_name == 'MNIST': 638 | from torchvision.datasets import MNIST 639 | from torchvision import transforms 640 | 641 | transform = transforms.Compose([transforms.Resize((200, 200)), 642 | transforms.ToTensor(), 643 | transforms.Normalize((0.1307,), (0.3081,)), 644 | ]) 645 | self.dataset = MNIST('', train=True, download=True, transform=transform) 646 | 647 | elif dset_name == 'Galaxy10': 648 | h5_path_img = '../data/Galaxy10.h5' 649 | self.dataset = Galaxy10_Dataset(h5_path_img, transform_in) 650 | 651 | elif dset_name == 'Galaxy10_DECals': 652 | h5_path_img = '../data/Galaxy10_DECals.h5' 653 | self.dataset = Galaxy10_DECals_Dataset(h5_path_img, transform_in) 654 | 655 | else: 656 | print("choose dset_name from ['MNIST', 'Galaxy10', 'Galaxy10_DECals']") 657 | raise NotImplementedError 658 | 659 | 660 | self.eht_npix = eht_npix 661 | self.num_FC = num_FC 662 | self.obs_type = obs_type 663 | 664 | def __getitem__(self, idx): 665 | 666 | # rescale to 200x200 for eht-im setup 667 | img_res_initial = int(torch.numel(self.dataset[idx][0])**(0.5)) 668 | img = self.dataset[idx][0].reshape((img_res_initial,img_res_initial)) 669 | 670 | if img_res_initial != 200: 671 | #print('scaling input to match requested size:', img_res_initial, 200) 672 | img = upscale_tensor(img, final_res=200, method='cubic') 673 | img = torch.from_numpy(img) 674 | 675 | grid_dense, cont_sparse, grid_sparse, obs_meta = get_uvvis_data(img, obs_type=self.obs_type, eht_npix=self.eht_npix, num_fourier_coeff=self.num_FC) 676 | 677 | #--DEBUG replace the vis member of grid_dense with the one genrated by DPI helper 678 | # if hostname=='NV': 679 | # import dpi_helper 680 | # vis_grid_dense = dpi_helper.get_uvvis_data_dpi( 681 | # img.reshape(1, img.shape[-2], img.shape[-1]).repeat(2, 1, 1), 682 | # uvfit_filepath='../data/gt.fits', 683 | # obs_path='../data/obs.uvfits', 684 | # fov= obs_meta['fov'], 685 | # pdim= obs_meta['fov']/ img.shape[-1], 686 | # npix= img.shape[-1], 687 | # num_fourier_coeff=self.num_FC, 688 | # uv_input= grid_dense['uv']) 689 | # vis_grid_dense= torch.view_as_complex(vis_grid_dense[0].T.contiguous()).cpu().numpy() 690 | # grid_dense['vis']= vis_grid_dense 691 | #---- END OF DEBUG ---# 692 | 693 | 694 | return grid_dense, cont_sparse, grid_sparse, obs_meta 695 | 696 | def __len__(self): 697 | return len(self.dataset) 698 | 699 | 700 | def plot_compare_dirtyim_ehtobs(grid_dense, cont_sparse, grid_sparse, obs_meta, gt_image): 701 | # make dirty images: 702 | cutoff_freq = 0.0 703 | weighting = 'uniform' 704 | eht_npix, eht_fov, num_fourier_coeff, sigma = 
obs_meta['npix'], obs_meta['fov'], obs_meta['n_FC'], obs_meta['sigma'] 705 | dim_grid_dense, im1, norm_grid_dense = make_dirtyim(grid_dense['uv'], grid_dense['vis'], eht_npix, eht_fov, sigma=sigma, 706 | cutoff_freq=cutoff_freq, return_im=True, 707 | weighting=weighting) 708 | dim_cont_sparse, im2, norm_cont_sparse = make_dirtyim(cont_sparse['uv'], cont_sparse['vis'], eht_npix, eht_fov, sigma=sigma, 709 | cutoff_freq=cutoff_freq, return_im=True, 710 | weighting=weighting) 711 | dim_grid_sparse, im3, norm_grid_sparse = make_dirtyim(grid_sparse['uv'], grid_sparse['vis'], eht_npix, eht_fov, sigma=sigma, 712 | cutoff_freq=cutoff_freq, return_im=True, 713 | weighting=weighting) 714 | 715 | dirty_beam, im4, norm_dirty_beam = make_dirtyim(cont_sparse['uv'], 10.0*np.ones_like(cont_sparse['vis']), eht_npix, eht_fov, sigma=sigma, 716 | cutoff_freq=cutoff_freq, return_im=True, 717 | weighting=weighting) 718 | 719 | dim_grid_dense = dim_grid_dense.imarr() 720 | dim_cont_sparse = dim_cont_sparse.imarr() 721 | dim_grid_sparse = dim_grid_sparse.imarr() 722 | dirty_beam = dirty_beam.imarr() 723 | 724 | 725 | fov=1.4108078120287498e-09 726 | npix=len(gt_image) 727 | pdim = fov/npix 728 | im = eh.image.Image(gt_image, pdim, 0, 0,) 729 | # fov2 = im.xdim * im.psize # same as fov 730 | #im.display() 731 | 732 | # observe the image the same way as data generator 733 | meta_file ='./code/EHT2017.txt' 734 | eht_meta = eh.array.load_txt(meta_file) 735 | 736 | tadv_sec = 600 737 | tstart_hr = 0 738 | tstop_hr = 24 739 | tint_sec = 12 740 | bw_hz = 4.096e9 741 | obs = im.observe(eht_meta, tint_sec, tadv_sec, tstart_hr, tstop_hr, bw_hz, 742 | sgrscat=False, add_th_noise=False, ampcal=True, phasecal=True, ttype='direct') 743 | 744 | #fov_expanded = fov * 1.1 745 | 746 | # Resolution 747 | beamparams = obs.fit_beam() # fitted beam parameters (fwhm_maj, fwhm_min, theta) in radians 748 | res = obs.res() # nominal array resolution, 1/longest baseline 749 | print("Clean beam parameters: ", beamparams) 750 | print("Nominal Resolution: ", res) 751 | 752 | #obs.save_uvfits('galaxy10_decals_obs.fits') # exports a UVFITS file modeled on template.UVP 753 | #obs.save_fits('galaxy10_decals_obs.fits') 754 | #print('saved file!') 755 | 756 | dim = obs.dirtyimage(npix, fov).imarr() 757 | dbeam = obs.dirtybeam(npix, fov).imarr() 758 | cbeam = obs.cleanbeam(npix, fov).imarr() 759 | 760 | clean_beam_size = 4.0 761 | imsize = np.int32(dirty_beam.shape[0]) 762 | dirty_psf_max = np.float32(dirty_beam.max()) 763 | dirty_psf = dirty_beam / dirty_psf_max 764 | # clean_psf = gICLEAN.serial_clean_beam(dirty_beam, imsize / clean_beam_size)*dirty_psf_max 765 | clean_psf = cbeam # fallback so the plot below runs while gICLEAN is commented out; rows [2,0] and [2,1] will then match 766 | cmap = 'afmhot' 767 | prefix = 'compare_beams' 768 | 769 | fig, axs = plt.subplots(5, 2, sharex='all', sharey='all', figsize=(7, 15)) 770 | plt.subplots_adjust(wspace=0) 771 | 772 | vra = [np.percentile(dirty_beam, 1), np.percentile(dirty_beam, 99)] 773 | axs[0,0].imshow(dirty_beam,vmin=vra[0],vmax=vra[1],cmap=cmap, origin='upper') 774 | axs[0,0].set_title('Dirty beam (dirtyim)') 775 | axs[0,1].imshow(dbeam,vmin=vra[0],vmax=vra[1],cmap=cmap, origin='upper') 776 | axs[0,1].set_title('Dirty beam (EHT)') 777 | 778 | vra = [np.percentile(dim_cont_sparse, 1), np.percentile(dim_cont_sparse, 99)] 779 | axs[1,0].imshow(dim_cont_sparse,vmin=vra[0],vmax=vra[1],cmap=cmap,origin='upper') 780 | axs[1,0].set_title('Dirty image (dirtyim)') 781 | axs[1,1].imshow(dim,vmin=vra[0],vmax=vra[1],cmap=cmap,origin='upper') 782 | axs[1,1].set_title('Dirty image (EHT)') 783 | 784 | vra =
[np.percentile(clean_psf, 1), np.percentile(clean_psf, 99)] 785 | axs[2,0].imshow(clean_psf,vmin=vra[0],vmax=vra[1],cmap=cmap,origin='upper') 786 | axs[2,0].set_title('Clean beam (clean-cuda)') 787 | axs[2,1].imshow(cbeam,vmin=vra[0],vmax=vra[1],cmap=cmap,origin='upper') 788 | axs[2,1].set_title('Clean beam (EHT)') 789 | 790 | vra = [np.percentile(gt_image, 1), np.percentile(gt_image, 99)] 791 | axs[3,0].imshow(dim_grid_dense, vmin=vra[0], vmax=vra[1], cmap=cmap, origin='upper') 792 | axs[3,0].set_title('Dense IFFT') 793 | axs[3,1].imshow(gt_image,vmin=vra[0],vmax=vra[1],cmap=cmap,origin='upper') 794 | axs[3,1].set_title('GT Image (original)') 795 | 796 | 797 | vra = [np.percentile(gt_image, 1), np.percentile(gt_image, 99)] 798 | dirty_convolve_mine = convolve2d(gt_image, dirty_beam, mode='same') 799 | dirty_convolve_eht = convolve2d(gt_image, dbeam, mode='same') 800 | axs[4,0].imshow(dirty_convolve_mine,vmin=vra[0],vmax=vra[1],cmap=cmap,origin='upper') 801 | axs[4, 0].set_title('Convolved (dirtyim)') 802 | axs[4,1].imshow(dirty_convolve_eht, vmin=vra[0], vmax=vra[1], cmap=cmap, origin='upper') 803 | axs[4, 1].set_title('Convolved (EHT)') 804 | 805 | plt.savefig(prefix+'_clean_final.png') 806 | plt.close() 807 | #dim.display() 808 | #dbeam.display() 809 | #cbeam.display() 810 | 811 | # gICLEAN.clean_cuda(dirty_im=dim, dirty_psf=dbeam, thresh=1e-10, gain=1e-1, clean_beam_size=4.0, 812 | # maxIter=1e6, 813 | # prefix='Galaxy10_decals_EHT_lessnoise_dirty10.0', 814 | # im_gt=gt_image, 815 | # clean_psf=cbeam, 816 | # polarity=False) 817 | 818 | # gICLEAN.clean_cuda(dirty_im=dim_cont_sparse, dirty_psf=dirty_beam, thresh=1e-10, gain=1e-1, clean_beam_size=4.0, 819 | # maxIter=1e6, 820 | # prefix='Galaxy10_decals_Mine_lessnoise_dirty10.0', 821 | # im_gt=gt_image, 822 | # clean_psf=None, 823 | # polarity=False) 824 | 825 | prior = eh.image.make_square(obs, npix, im.fovx()) 826 | outvis = dd_clean_vis(obs, prior, niter=500, loop_gain=0.1, 827 | method='max_delta', weighting='natural', 828 | show_updates=True) 829 | 830 | beamparams = obs.fit_beam() 831 | dirty_im_pred_CLEAN = outvis.blur_gauss(beamparams, 0.5).imarr() 832 | 833 | 834 | 835 | 836 | def dd_CLEAN(gt_image, niter, loop_gain): 837 | 838 | fov=1.4108078120287498e-09 839 | npix=len(gt_image) 840 | pdim = fov/npix 841 | im = eh.image.Image(gt_image, pdim, 0, 0,) 842 | #im.display() 843 | 844 | # observe the image the same way as data generator 845 | meta_file ='./code/EHT2017.txt' 846 | eht_meta = eh.array.load_txt(meta_file) 847 | 848 | tadv_sec = 600 849 | tstart_hr = 0 850 | tstop_hr = 24 851 | tint_sec = 12 852 | bw_hz = 4.096e9 853 | obs = im.observe(eht_meta, tint_sec, tadv_sec, tstart_hr, tstop_hr, bw_hz, 854 | sgrscat=False, add_th_noise=False, ampcal=True, phasecal=True, ttype='direct') 855 | 856 | #npix = 32 857 | fov2 = im.xdim * im.psize # same as fov 858 | 859 | # Resolution 860 | beamparams = obs.fit_beam() # fitted beam parameters (fwhm_maj, fwhm_min, theta) in radians 861 | res = obs.res() # nominal array resolution, 1/longest baseline 862 | print("Clean beam parameters: ", beamparams) 863 | print("Nominal Resolution: ", res) 864 | 865 | #prior = eh.image.make_square(obs, 128, 1.5*im.fovx()) 866 | #prior = eh.image.make_square(obs, 64, im.fovx()) 867 | prior = eh.image.make_square(obs, npix, im.fovx()) 868 | 869 | # data domain clean with visibilities 870 | #outvis = dd_clean_vis(obs, prior, niter=100, loop_gain=0.1, method='min_chisq', weighting='uniform', show_updates=True) # to see iterations 871 | #outvis = 
dd_clean_vis(obs, prior, niter=niter, loop_gain=loop_gain, method='min_chisq', weighting='uniform') 872 | #outvis = dd_clean_vis(obs, prior, niter=niter, loop_gain=loop_gain, method='min_chisq', weighting='natural') 873 | #outvis = dd_clean_vis(obs, prior, niter=niter, loop_gain=loop_gain, method='max_delta', weighting='uniform') 874 | outvis = dd_clean_vis(obs, prior, niter=niter, loop_gain=loop_gain, 875 | method='max_delta', weighting='natural', 876 | show_updates=False) 877 | 878 | beamparams = obs.fit_beam() 879 | dirty_im_pred_CLEAN = outvis.blur_gauss(beamparams, 0.5).imarr() 880 | 881 | return dirty_im_pred_CLEAN 882 | 883 | 884 | def do_test(compare_sparse_dense=False, do_clean=False, compare_dirty=False, do_clean_cuda=True): 885 | import pytorch_lightning as pl 886 | from torch.utils.data import DataLoader, random_split 887 | 888 | pl.seed_everything(42) 889 | numVal = 32 * 16 890 | 891 | num_fourier_coeff = 200 892 | eht_npix = 200 893 | dset_name = 'Galaxy10_DECals' #'Galaxy10' # 894 | 895 | eht_cont_dset = EHT_Continuous_Dataset(eht_npix=eht_npix, 896 | num_FC=num_fourier_coeff, 897 | dset_name=dset_name, 898 | obs_type='eht') 899 | 900 | split_train, split_val = random_split(eht_cont_dset, [len(eht_cont_dset) - numVal, numVal]) 901 | split_val, _ = random_split(split_val, [1, len(split_val)-1]) 902 | cutoff_freq = 0.03 # default used by plot_eht_compare; needed by the compare_sparse_dense branch 903 | # CLEAN figs 904 | cleaned_lst = [] 905 | for idx in range(len(split_val)): 906 | print(idx) 907 | print('-------') 908 | grid_dense, cont_sparse, grid_sparse, obs_meta = split_val[idx] 909 | # dim_grid_dense = make_dirtyim(grid_dense['uv'], grid_dense['vis'], eht_npix, fov).imarr() 910 | 911 | if compare_dirty: 912 | plot_compare_dirtyim_ehtobs(grid_dense, cont_sparse, grid_sparse, obs_meta, obs_meta['gt_img']) 913 | 914 | if do_clean: 915 | dirty_im_pred_CLEAN = dd_CLEAN(obs_meta['gt_img'], niter=500, loop_gain=0.05) 916 | plt.imshow(dirty_im_pred_CLEAN, cmap='afmhot') 917 | 918 | if compare_sparse_dense: 919 | savefig = f'ehtim_grid_{num_fourier_coeff}FC_{eht_npix}im_{dset_name}_{idx:05d}_{cutoff_freq}.png' 920 | plot_eht_compare(grid_dense, cont_sparse, grid_sparse, obs_meta, savefig=savefig, cutoff_freq=cutoff_freq) 921 | 922 | if do_clean_cuda: 923 | cutoff_freq = 0.0 924 | weighting = 'uniform' 925 | eht_npix, eht_fov, num_fourier_coeff, sigma = obs_meta['npix'], obs_meta['fov'], \ 926 | obs_meta['n_FC'], obs_meta['sigma'] 927 | 928 | dim_cont_sparse = make_dirtyim(cont_sparse['uv'], cont_sparse['vis'], eht_npix, 929 | eht_fov, sigma=sigma, 930 | cutoff_freq=cutoff_freq, return_im=False, 931 | weighting=weighting).imarr() 932 | 933 | plt.imsave('reconstructed_dirty.png', arr = abs(dim_cont_sparse), cmap='hot') 934 | dirty_beam = make_dirtyim(cont_sparse['uv'], np.ones_like(cont_sparse['vis']), 935 | eht_npix, eht_fov, sigma=sigma, 936 | cutoff_freq=cutoff_freq, return_im=False, 937 | weighting=weighting).imarr() 938 | np.save("beam.npy", dirty_beam) 939 | plt.imsave('reconstructed_beam.png', arr = dirty_beam, cmap='gray') 940 | # dirty_im = np.load('dirty.npy') 941 | print(dim_cont_sparse.shape) 942 | # cleaned = gICLEAN.clean_cuda(dirty_im=abs(dim_cont_sparse), dirty_psf=dirty_beam, thresh=1e-10, gain=0.1, 943 | # clean_beam_size=4, 944 | # maxIter=1e5, 945 | # prefix='../clean-cuda_val/Galaxy10_decals_clean-cuda_idx%05d' % idx, 946 | # im_gt=obs_meta['gt_img'], 947 | # clean_psf=None, 948 | # polarity=False) 949 | # print(cleaned.shape) # 'cleaned' depends on the gICLEAN call above, which is commented out 950 | # plt.imsave('reconstructed_image.png', arr = cleaned, cmap='hot') 951 | # cleaned_lst.append(cleaned) 952 | # cleaned_npy =
np.stack(cleaned_lst) 953 | # np.save('val_cleaned_idx%05d.npy' % idx, cleaned_npy) 954 | 955 | if __name__ == "__main__": 956 | do_test(compare_sparse_dense=False, do_clean=False, compare_dirty=False, do_clean_cuda=True) 957 | --------------------------------------------------------------------------------
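For reference, a minimal sketch of calling `make_im_torch` from `data_ehtim_cont.py` on a regular visibility grid. It reuses the field-of-view and longest-baseline constants hard-coded in `test_step`, and assumes the repo root is on `PYTHONPATH` with `ehtim` installed; the random visibility map is only a stand-in for a model prediction or a dataset sample:

```python
import torch
from data_ehtim_cont import make_im_torch

nF, npix = 128, 256                    # visibility grid size, output image resolution
eht_fov = 1.4108078120287498e-09       # field of view in radians (constant used in test_step)
max_base = 8368481300.0                # longest baseline (constant used in test_step)
scale_ux = max_base * eht_fov / npix   # same uv rescaling as test_step

# uv_arr holds the 1-D u and v axis values of the regular grid, shape (nF, 2)
axis = (torch.linspace(0.0, 1.0, nF) * 2.0 - 1.0) * scale_ux
uv_arr = torch.stack([axis, axis], dim=-1)

# stand-in dense visibility map of shape (B, U, V), complex-valued
vis = torch.randn(1, nF, nF, dtype=torch.complex64)

# separable inverse DFT to the image plane; returns a complex (1, npix, npix) tensor
img = make_im_torch(uv_arr, vis, npix, eht_fov)
print(img.shape, img.dtype)
```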