├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── ckpts ├── DSLR_2800_phase.pt ├── DSLR_3400_phase.pt ├── MV_1600_phase.pt ├── MV_2400_phase.pt └── MV_recon │ ├── G_040.pt │ ├── args.json │ └── param_MV_2400.py ├── config ├── param_DSLR_2800.py └── param_MV_1600.py ├── env.yml ├── inference.ipynb ├── models ├── Defence.py ├── Dirt.py ├── PerlinBlob.py ├── ROLE.py ├── forwards.py └── recon.py ├── pado-main ├── LICENSE ├── README.md ├── docs │ └── images │ │ ├── logo.pdf │ │ ├── logo.png │ │ └── logo.svg ├── example │ └── tutorial.ipynb └── pado │ ├── .gitignore │ ├── cmap_phase.txt │ ├── complex.py │ ├── conv.py │ ├── fourier.py │ ├── light.py │ ├── material.py │ ├── optical_element.py │ └── propagator.py ├── pado ├── .gitignore ├── cmap_phase.txt ├── complex.py ├── conv.py ├── fourier.py ├── light.py ├── material.py ├── optical_element.py └── propagator.py ├── train.py ├── train.sh └── utils ├── ._save.py ├── save.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/SeeThroughObstructions/f94a297cc87d72dcb17ffdab3bed64cff186ab9f/.gitmodules -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Boost Software License - Version 1.0 - August 17th, 2003 2 | 3 | Permission is hereby granted, free of charge, to any person or organization 4 | obtaining a copy of the software and accompanying documentation covered by 5 | this license (the "Software") to use, reproduce, display, distribute, 6 | execute, and transmit the Software, and to prepare derivative works of the 7 | Software, and to permit third-parties to whom the Software is furnished to 8 | do so, all subject to the following: 9 | 10 | The copyright notices in the Software and this entire statement, including 11 | the above license grant, this restriction and the following disclaimer, 12 | must be included in all copies of the Software, in whole or in part, and 13 | all derivative works of the Software, unless such copies or derivative 14 | works are solely in the form of machine-executable object code generated by 15 | a source language 
processor. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT 20 | SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE 21 | FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, 22 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 23 | DEALINGS IN THE SOFTWARE. 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Seeing Through Obstructions with Diffractive Cloaking 2 | ### [Project Page](https://light.princeton.edu/publication/seeing-through-obstructions/) | [Paper](https://dl.acm.org/doi/abs/10.1145/3528223.3530185) 3 | 4 | [Zheng Shi](https://zheng-shi.github.io/), [Yuval Bahat](https://sites.google.com/view/yuval-bahat/home), [Seung-Hwan Baek](https://www.shbaek.com/), [Qiang Fu](https://cemse.kaust.edu.sa/vcc/people/person/qiang-fu), [Hadi Amata ](https://cemse.kaust.edu.sa/people/person/hadi-amata), [Xiao Li](), [Praneeth Chakravarthula](https://www.cs.unc.edu/~cpk/), [Wolfgang Heidrich](https://vccimaging.org/People/heidriw/), [Felix Heide](https://www.cs.princeton.edu/~fheide/) 5 | 6 | If you find our work useful in your research, please cite: 7 | ``` 8 | @article{Shi2022SeeThroughObstructions, 9 | author = {Shi, Zheng and Bahat, Yuval and Baek, Seung-Hwan and Fu, Qiang and Amata, Hadi and Li, Xiao and Chakravarthula, Praneeth and Heidrich, Wolfgang and Heide, Felix}, 10 | title = {Seeing through Obstructions with Diffractive Cloaking}, 11 | year = {2022}, 12 | issue_date = {July 2022}, 13 | publisher = {Association for Computing Machinery}, 14 | address = {New York, NY, USA}, 15 | volume = {41}, 16 | number = {4}, 17 | issn = {0730-0301}, 18 | url = 
{https://doi.org/10.1145/3528223.3530185}, 19 | doi = {10.1145/3528223.3530185}} 20 | ``` 21 | 22 | ## Requirements 23 | This code was developed using PyTorch on a Linux machine. The full frozen environment can be found in 'env.yml'; note that some of these libraries are not necessary to run this code. In addition to the packages installed in the environment, our image formation model uses the [pado](https://github.com/shwbaek/pado) package to simulate wave optics. 24 | 25 | ## Data 26 | In the paper we use [Places365](http://places2.csail.mit.edu/index.html) and [Cityscapes](https://www.cityscapes-dataset.com/) as the obstruction-free background scenes; they can easily be switched to any other dataset of your choice. See 'train.py' for more details on the data augmentation we apply. For more details on depth-aware obstruction simulation, please refer to 'models/'. 27 | 28 | ## Pre-trained Models and Optimized DOE Designs 29 | Optimized DOE designs and pre-trained models are available in the 'ckpts/' folder. Please refer to the supplemental documents for fabrication details. 30 | 31 | ## Sensor Capture Simulation and Reconstruction 32 | We include a sample notebook that demonstrates our entire image formation and reconstruction process. You can run 'inference.ipynb' in Jupyter Notebook. The notebook will load the checkpoint and run the entire process. The simulated depth-dependent PSFs, the simulated sensor capture, and the reconstructed image will be displayed within the notebook. 33 | 34 | ## Training 35 | We include 'train.sh' for training. Please refer to 'config/' for optics and sensor specs. 36 | 37 | ## License 38 | Our code is licensed under the Boost Software License 1.0 (BSL-1.0). By downloading the software, you agree to the terms of this License. 39 | 40 | ## Questions 41 | If there is anything unclear, please feel free to reach out to me at zhengshi[at]princeton[dot]edu.
42 | -------------------------------------------------------------------------------- /ckpts/DSLR_2800_phase.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/SeeThroughObstructions/f94a297cc87d72dcb17ffdab3bed64cff186ab9f/ckpts/DSLR_2800_phase.pt -------------------------------------------------------------------------------- /ckpts/DSLR_3400_phase.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/SeeThroughObstructions/f94a297cc87d72dcb17ffdab3bed64cff186ab9f/ckpts/DSLR_3400_phase.pt -------------------------------------------------------------------------------- /ckpts/MV_1600_phase.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/SeeThroughObstructions/f94a297cc87d72dcb17ffdab3bed64cff186ab9f/ckpts/MV_1600_phase.pt -------------------------------------------------------------------------------- /ckpts/MV_2400_phase.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/SeeThroughObstructions/f94a297cc87d72dcb17ffdab3bed64cff186ab9f/ckpts/MV_2400_phase.pt -------------------------------------------------------------------------------- /ckpts/MV_recon/G_040.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/SeeThroughObstructions/f94a297cc87d72dcb17ffdab3bed64cff186ab9f/ckpts/MV_recon/G_040.pt -------------------------------------------------------------------------------- /ckpts/MV_recon/args.json: -------------------------------------------------------------------------------- 1 | { 2 | "debug": false, 3 | "train_optics": false, 4 | "pretrained_DOE": 
"ckpts/MV_2400_phase.pt", 5 | "pretrained_G": null, 6 | "result_path": "/projects/FHEIDE/obstruction_free_doe/ckpt/MV_retrain", 7 | "param_file": "config/param_MV_2400.py", 8 | "obstruction": "dirt_raindrop", 9 | "sensor_noise": 0.008, 10 | "n_epochs": 20, 11 | "optics_lr": 0.1, 12 | "G_lr": 0.0001, 13 | "l1_loss_weight": 1, 14 | "masked_loss_weight": 1, 15 | "perceptual_loss_weight": 1, 16 | "log_freq": 400, 17 | "save_freq": 2000, 18 | "seed": 1234, 19 | "device": "cuda", 20 | "DOE_phase_init_ckpt": "ckpts/MV_2400_phase.pt", 21 | "G_init_ckpt": null 22 | } -------------------------------------------------------------------------------- /ckpts/MV_recon/param_MV_2400.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pado.material import * 3 | from utils.utils import * # NOTE(review): `torch` (used for DOE_phase_init below) is not imported explicitly in this file; presumably re-exported by this star import - confirm 4 | 5 | # https://www.edmundoptics.com/p/8mm-UC-Series-Fixed-Focal-Length-Lens/41864?gclid=CjwKCAjwh5qLBhALEiwAioods3Z90aJenLK11rrj2R5E7SpKF0gvF8a9vZrsd0H5aY72nIgcYq42QRoC4hAQAvD_BwE 6 | # https://www.edmundoptics.com/p/bfs-u3-120s4c-cs-usb3-blackflyreg-s-color-camera/40172/ 7 | 8 | camera_resolution = [4000,3000] 9 | camera_pitch = 1.85e-6 10 | background_pitch = camera_pitch * 2 11 | sensor_dist = focal_length = 8e-3 12 | aperture_shape='circle' 13 | wvls = [656e-9, 589e-9, 486e-9] # camera RGB wavelength 14 | DOE_wvl = 550e-9 # wavelength used to set DOE 15 | 16 | # DOE specs 17 | R = C = 2400 # Resolution of the simulated wavefront 18 | DOE_material = 'FUSED_SILICA' 19 | material = Material(DOE_material) 20 | DOE_pitch = camera_pitch 21 | aperture_diamter = DOE_pitch * R # (sic) misspelled name; presumably looked up by this name elsewhere, so rename only together with its users 22 | DOE_sample_ratio = 2 23 | image_sample_ratio = 2 24 | equiv_camera_pitch = camera_pitch * image_sample_ratio 25 | img_res = 512 26 | assert DOE_pitch * DOE_sample_ratio == equiv_camera_pitch 27 | DOE_max_height = 1.2e-6 28 | DOE_height_noise_scale = 4*10e-9 # NOTE(review): evaluates to 4e-8 (10e-9 is 1e-8); confirm that 4e-9 was not intended 29 | DOE_phase_noise_scale = 0.1 30 | 31 | DOE_phase_init = torch.zeros((1,1, R, C)) 32 | 33 | # depth
34 | depth_near_min = 0.05 35 | depth_near_max = 0.12 36 | depth_far_min = 5 37 | depth_far_max = 10 38 | plot_depth = [5,3,1,0.12,0.08,0.05] 39 | 40 | # raindrop 41 | drop_Nmin = 5 42 | drop_Nmax = 8 43 | drop_Rmin = 1e-3 44 | drop_Rmax = 3e-3 45 | 46 | # dirt 47 | perlin_res = 8 # noise periods per axis, passed as `res` to generate_fractal_noise_2d by models/Dirt.py 48 | perlin_cutoff = 0.55 -------------------------------------------------------------------------------- /config/param_DSLR_2800.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pado.material import * 3 | 4 | # https://www.usa.canon.com/internet/portal/us/home/products/details/cameras/eos-dslr-and-mirrorless-cameras/dslr/eos-rebel-t5-ef-s-18-55-is-ii-kit/eos-rebel-t5-18-55-is-ii-kit 5 | camera_resolution = [5196,3464] # 18e6 6 | camera_pitch = 4.3e-6 7 | background_pitch = camera_pitch * 4 8 | sensor_dist = focal_length = 50e-3 9 | aperture_shape='circle' 10 | wvls = [656e-9, 589e-9, 486e-9] # camera RGB wavelength 11 | DOE_wvl = 550e-9 # wavelength used to set DOE 12 | 13 | # DOE specs 14 | R = C = 2800 # Resolution of the simulated wavefront 15 | DOE_material = 'FUSED_SILICA' 16 | material = Material(DOE_material) 17 | DOE_pitch = camera_pitch * 1.5 18 | aperture_diamter = DOE_pitch * R # (sic) misspelled name; presumably looked up by this name elsewhere, so rename only together with its users 19 | DOE_sample_ratio = 2 20 | image_sample_ratio = 3 21 | equiv_camera_pitch = camera_pitch * image_sample_ratio 22 | img_res = 512 23 | assert DOE_pitch * DOE_sample_ratio == equiv_camera_pitch 24 | DOE_max_height = 1.2e-6 25 | DOE_height_noise_scale = 4*10e-9 # NOTE(review): evaluates to 4e-8 (10e-9 is 1e-8); confirm that 4e-9 was not intended 26 | DOE_phase_noise_scale = 0.05 27 | 28 | DOE_phase_init = torch.zeros((1,1, R, C)) 29 | 30 | # depth 31 | depth_near_min = 0.4 32 | depth_near_max = 0.8 33 | depth_far_min = 5 34 | depth_far_max = 10 35 | plot_depth = [10,8,5,0.8,0.6,0.4] 36 | 37 | # fence 38 | fence_min = 5e-3 39 | fence_max = 15e-3 40 | -------------------------------------------------------------------------------- /config/param_MV_1600.py:
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pado.material import * 3 | from utils.utils import * # NOTE(review): `torch` (used for DOE_phase_init below) is not imported explicitly in this file; presumably re-exported by this star import - confirm 4 | 5 | # https://www.edmundoptics.com/p/8mm-UC-Series-Fixed-Focal-Length-Lens/41864?gclid=CjwKCAjwh5qLBhALEiwAioods3Z90aJenLK11rrj2R5E7SpKF0gvF8a9vZrsd0H5aY72nIgcYq42QRoC4hAQAvD_BwE 6 | # https://www.edmundoptics.com/p/bfs-u3-120s4c-cs-usb3-blackflyreg-s-color-camera/40172/ 7 | 8 | camera_resolution = [4000,3000] 9 | camera_pitch = 1.85e-6 10 | background_pitch = camera_pitch * 2 11 | sensor_dist = focal_length = 8e-3 12 | aperture_shape='circle' 13 | wvls = [656e-9, 589e-9, 486e-9] # camera RGB wavelength 14 | DOE_wvl = 550e-9 # wavelength used to set DOE 15 | 16 | # DOE specs 17 | R = C = 1600 # Resolution of the simulated wavefront 18 | DOE_material = 'FUSED_SILICA' 19 | material = Material(DOE_material) 20 | DOE_pitch = camera_pitch * 1.5 21 | aperture_diamter = DOE_pitch * R # (sic) misspelled name; presumably looked up by this name elsewhere, so rename only together with its users 22 | DOE_sample_ratio = 2 23 | image_sample_ratio = 3 24 | equiv_camera_pitch = camera_pitch * image_sample_ratio 25 | img_res = 512 26 | assert DOE_pitch * DOE_sample_ratio == equiv_camera_pitch 27 | DOE_max_height = 1.2e-6 28 | DOE_height_noise_scale = 4*10e-9 # NOTE(review): evaluates to 4e-8 (10e-9 is 1e-8); confirm that 4e-9 was not intended 29 | DOE_phase_noise_scale = 0.1 30 | 31 | DOE_phase_init = torch.zeros((1,1, R, C)) 32 | 33 | # depth 34 | depth_near_min = 0.05 35 | depth_near_max = 0.12 36 | depth_far_min = 5 37 | depth_far_max = 10 38 | plot_depth = [5,3,1,0.12,0.08,0.05] 39 | 40 | # raindrop 41 | drop_Nmin = 5 42 | drop_Nmax = 8 43 | drop_Rmin = 1e-3 44 | drop_Rmax = 3e-3 45 | 46 | # dirt 47 | perlin_res = 8 # noise periods per axis, passed as `res` to generate_fractal_noise_2d by models/Dirt.py 48 | perlin_cutoff = 0.55 -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: SeeThroughObstruction 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=4.5=1_gnu 8 | - absl-py=0.13.0=py39h06a4308_0 9 | -
aiohttp=3.7.4=py39h27cfd23_1 10 | - async-timeout=3.0.1=py39h06a4308_0 11 | - attrs=21.2.0=pyhd3eb1b0_0 12 | - blas=1.0=mkl 13 | - blinker=1.4=py39h06a4308_0 14 | - bottleneck=1.3.2=py39hdd57654_1 15 | - brotlipy=0.7.0=py39h27cfd23_1003 16 | - bzip2=1.0.8=h7b6447c_0 17 | - c-ares=1.17.1=h27cfd23_0 18 | - ca-certificates=2021.7.5=h06a4308_1 19 | - cachetools=4.2.2=pyhd3eb1b0_0 20 | - certifi=2021.5.30=py39h06a4308_0 21 | - cffi=1.14.6=py39h400218f_0 22 | - chardet=3.0.4=py39h06a4308_1003 23 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 24 | - click=8.0.1=pyhd3eb1b0_0 25 | - coverage=5.5=py39h27cfd23_2 26 | - cryptography=3.4.7=py39hd23ed53_0 27 | - cudatoolkit=10.2.89=hfd86e86_1 28 | - cython=0.29.24=py39h295c915_0 29 | - ffmpeg=4.3=hf484d3e_0 30 | - freetype=2.10.4=h5ab3b9f_0 31 | - gmp=6.2.1=h2531618_2 32 | - gnutls=3.6.15=he1e5248_0 33 | - google-auth=1.33.0=pyhd3eb1b0_0 34 | - google-auth-oauthlib=0.4.1=py_2 35 | - grpcio=1.36.1=py39h2157cd5_1 36 | - idna=3.2=pyhd3eb1b0_0 37 | - importlib-metadata=3.10.0=py39h06a4308_0 38 | - intel-openmp=2021.3.0=h06a4308_3350 39 | - jpeg=9b=h024ee3a_2 40 | - lame=3.100=h7b6447c_0 41 | - lcms2=2.12=h3be6417_0 42 | - ld_impl_linux-64=2.35.1=h7274673_9 43 | - libffi=3.3=he6710b0_2 44 | - libgcc-ng=9.3.0=h5101ec6_17 45 | - libgomp=9.3.0=h5101ec6_17 46 | - libiconv=1.15=h63c8f33_5 47 | - libidn2=2.3.2=h7f8727e_0 48 | - libpng=1.6.37=hbc83047_0 49 | - libprotobuf=3.17.2=h4ff587b_1 50 | - libstdcxx-ng=9.3.0=hd4cf53a_17 51 | - libtasn1=4.16.0=h27cfd23_0 52 | - libtiff=4.2.0=h85742a9_0 53 | - libunistring=0.9.10=h27cfd23_0 54 | - libuv=1.40.0=h7b6447c_0 55 | - libwebp-base=1.2.0=h27cfd23_0 56 | - lz4-c=1.9.3=h295c915_1 57 | - markdown=3.3.4=py39h06a4308_0 58 | - mkl=2021.3.0=h06a4308_520 59 | - mkl-service=2.4.0=py39h7f8727e_0 60 | - mkl_fft=1.3.0=py39h42c9631_2 61 | - mkl_random=1.2.2=py39h51133e4_0 62 | - multidict=5.1.0=py39h27cfd23_2 63 | - ncurses=6.2=he6710b0_1 64 | - nettle=3.7.3=hbbd107a_1 65 | - ninja=1.10.2=hff7bd54_1 66 | - 
numexpr=2.7.3=py39h22e1b3c_1 67 | - numpy=1.20.3=py39hf144106_0 68 | - numpy-base=1.20.3=py39h74d4b33_0 69 | - oauthlib=3.1.1=pyhd3eb1b0_0 70 | - olefile=0.46=py_0 71 | - openh264=2.1.0=hd408876_0 72 | - openjpeg=2.4.0=h3ad879b_0 73 | - openssl=1.1.1k=h27cfd23_0 74 | - pandas=1.3.2=py39h8c16a72_0 75 | - pillow=8.3.1=py39h2c7a002_0 76 | - pip=21.2.4 # build pin dropped: the exported py37 build (py37h06a4308_0) cannot be installed alongside python=3.9.6 77 | - protobuf=3.17.2=py39h295c915_0 78 | - pyasn1=0.4.8=py_0 79 | - pyasn1-modules=0.2.8=py_0 80 | - pycparser=2.20=py_2 81 | - pyjwt=2.1.0=py39h06a4308_0 82 | - pyopenssl=20.0.1=pyhd3eb1b0_1 83 | - pysocks=1.7.1=py39h06a4308_0 84 | - python=3.9.6=h12debd9_1 85 | - python-dateutil=2.8.2=pyhd3eb1b0_0 86 | - pytorch=1.9.0=py3.9_cuda10.2_cudnn7.6.5_0 87 | - pytz=2021.1=pyhd3eb1b0_0 88 | - readline=8.1=h27cfd23_0 89 | - requests=2.26.0=pyhd3eb1b0_0 90 | - requests-oauthlib=1.3.0=py_0 91 | - rsa=4.7.2=pyhd3eb1b0_1 92 | - setuptools=52.0.0=py39h06a4308_0 93 | - six=1.16.0=pyhd3eb1b0_0 94 | - sqlite=3.36.0=hc218d9a_0 95 | - tensorboard=2.5.0=py_0 96 | - tensorboard-plugin-wit=1.6.0=py_0 97 | - tk=8.6.10=hbc83047_0 98 | - torchaudio=0.9.0=py39 99 | - torchvision=0.10.0=py39_cu102 100 | - typing-extensions=3.10.0.0=hd3eb1b0_0 101 | - typing_extensions=3.10.0.0=pyh06a4308_0 102 | - tzdata=2021a=h5d7bf9c_0 103 | - urllib3=1.26.6=pyhd3eb1b0_1 104 | - werkzeug=1.0.1=pyhd3eb1b0_0 105 | - wheel=0.37.0=pyhd3eb1b0_0 106 | - xz=5.2.5=h7b6447c_0 107 | - yarl=1.6.3=py39h27cfd23_0 108 | - zipp=3.5.0=pyhd3eb1b0_0 109 | - zlib=1.2.11=h7b6447c_3 110 | - zstd=1.4.9=haebb681_0 111 | - pip: 112 | - anyio==3.3.0 113 | - argon2-cffi==21.1.0 114 | - astropy==4.3.1 115 | - babel==2.9.1 116 | - backcall==0.2.0 117 | - bleach==4.1.0 118 | - cycler==0.10.0 119 | - debugpy==1.4.1 120 | - decorator==5.0.9 121 | - defusedxml==0.7.1 122 | - entrypoints==0.3 123 | - imageio==2.9.0 124 | - ipykernel==6.3.1 125 | - ipython==7.27.0 126 | - ipython-genutils==0.2.0 127 | - jedi==0.18.0 128 | - jinja2==3.0.1 129 | - json5==0.9.6 130 | -
jsonschema==3.2.0 131 | - jupyter-client==7.0.2 132 | - jupyter-core==4.7.1 133 | - jupyter-server==1.10.2 134 | - jupyterlab==3.1.9 135 | - jupyterlab-pygments==0.1.2 136 | - jupyterlab-server==2.7.2 137 | - kiwisolver==1.3.2 138 | - lightpipes==2.1.1 139 | - lpips==0.1.4 140 | - markupsafe==2.0.1 141 | - matplotlib==3.4.3 142 | - matplotlib-inline==0.1.2 143 | - mistune==0.8.4 144 | - nbclassic==0.3.1 145 | - nbclient==0.5.4 146 | - nbconvert==6.1.0 147 | - nbformat==5.1.3 148 | - nest-asyncio==1.5.1 149 | - networkx==2.6.3 150 | - noise==1.2.2 151 | - notebook==6.4.3 152 | - opencv-python==4.5.3.56 153 | - packaging==21.0 154 | - pandocfilters==1.4.3 155 | - parso==0.8.2 156 | - pexpect==4.8.0 157 | - pickleshare==0.7.5 158 | - poppy==0.9.2 159 | - prometheus-client==0.11.0 160 | - prompt-toolkit==3.0.20 161 | - prysm==0.20 162 | - ptyprocess==0.7.0 163 | - pyblur==0.2.3 164 | - pyerfa==2.0.0 165 | - pygments==2.10.0 166 | - pyparsing==2.4.7 167 | - pyrsistent==0.18.0 168 | - pywavelets==1.1.1 169 | - pyzmq==22.2.1 170 | - requests-unixsocket==0.2.0 171 | - scikit-image==0.18.3 172 | - scipy==1.7.1 173 | - send2trash==1.8.0 174 | - sniffio==1.2.0 175 | - terminado==0.11.1 176 | - testpath==0.5.0 177 | - tifffile==2021.8.30 178 | - tornado==6.1 179 | - tqdm==4.62.2 180 | - traitlets==5.1.0 181 | - wand==0.6.7 182 | - wcwidth==0.2.5 183 | - webencodings==0.5.1 184 | - websocket-client==1.2.1 185 | 186 | -------------------------------------------------------------------------------- /models/Defence.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | from PIL import Image 6 | import os 7 | 8 | from utils.utils import * # NOTE(review): `math` and `np` are used in this module without a direct import; presumably re-exported by this star import - confirm 9 | 10 | class DefencingDataset(Dataset): # Fence-image dataset over every non-hidden file in root_dir; items are dicts with keys 'image' and 'pixel_width' 11 | 12 | def __init__(self, root_dir, transform=None): # transform, if given, is applied to each RGB PIL image in __getitem__ 13 | self.root_dir = root_dir 14 | self.transform = transform 15 | self.image_list = [f
for f in os.listdir(root_dir) if not f.startswith('.')] 16 | 17 | def __len__(self): 18 | return len(self.image_list) 19 | 20 | def __getitem__(self, idx): 21 | if torch.is_tensor(idx): 22 | idx = idx.tolist() 23 | 24 | img_name = os.path.join(self.root_dir, 25 | self.image_list[idx]) 26 | image = Image.open(img_name).convert('RGB') 27 | pixel_width = int(self.image_list[idx].split('-')[-1].split('.')[0]) # parsed from the file name: the text after the last '-' and before the next '.', e.g. 'fence-12.png' -> 12 28 | if self.transform: 29 | image = self.transform(image) 30 | sample = {'image': image , 'pixel_width': pixel_width} 31 | return sample 32 | 33 | def _largest_rotated_rect(w, h, angle): # NOTE(review): not called anywhere in this module 34 | """ 35 | Given a rectangle of size wxh that has been rotated by 'angle' (in 36 | radians), computes the width and height of the largest possible 37 | axis-aligned rectangle within the rotated rectangle. Returns (width, height). 38 | Original JS code by 'Andri' and Magnus Hoff from Stack Overflow 39 | Converted to Python by Aaron Snoswell 40 | Source: http://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders 41 | """ 42 | 43 | quadrant = int(math.floor(angle / (math.pi / 2))) & 3 # NOTE(review): `math` is not imported directly in this module; presumably provided by `from utils.utils import *` 44 | sign_alpha = angle if ((quadrant & 1) == 0) else math.pi - angle 45 | alpha = (sign_alpha % math.pi + math.pi) % math.pi 46 | 47 | bb_w = w * math.cos(alpha) + h * math.sin(alpha) 48 | bb_h = w * math.sin(alpha) + h * math.cos(alpha) 49 | 50 | gamma = math.atan2(bb_w, bb_w) if (w < h) else math.atan2(bb_w, bb_w) # NOTE(review): both branches are identical, so the w < h test has no effect and gamma is atan2(bb_w, bb_w) == pi/4 whenever bb_w > 0; check against the cited JS source before relying on this function 51 | 52 | delta = math.pi - alpha - gamma 53 | 54 | length = h if (w < h) else w 55 | 56 | d = length * math.cos(alpha) 57 | a = d * math.sin(alpha) / math.sin(delta) 58 | 59 | y = a * math.cos(gamma) 60 | x = y * math.tan(gamma) 61 | 62 | return ( 63 | bb_w - 2 * x, 64 | bb_h - 2 * y 65 | ) 66 | 67 | transform_fence = transforms.Compose([ # fence augmentation: contrast jitter, random rotation in [-90, 90] degrees, 1300 px center crop, random horizontal flip, then to tensor 68 | transforms.ColorJitter(contrast=0.2), 69 | transforms.RandomRotation(90), 70 | transforms.CenterCrop(1300), 71 | transforms.RandomHorizontalFlip(), 72 | transforms.ToTensor(), 73 | ]) 74 | 75 | fenceset =
DefencingDataset('/projects/FHEIDE/obstruction_free_doe/De-fencing/dataset_parsed/', transform=transform_fence) # NOTE(review): hard-coded absolute path; DefencingDataset.__init__ calls os.listdir on it when this module is imported, so the import fails wherever the directory does not exist 76 | 77 | def compute_fence(image_far, depth, args): 78 | '''Generate a fence obstruction over image_far; returns (image_near, mask), both moved to args.device.''' 79 | param = args.param 80 | fence_width = metric2pixel(randuni(param.fence_min, param.fence_max, 1)[0] , depth, args) # presumably: a random metric width in [fence_min, fence_max] converted to pixels at this depth (helpers come from utils.utils) 81 | 82 | fence = fenceset[np.random.randint(len(fenceset))] 83 | 84 | fence_image = fence['image'] 85 | T = transforms.Compose([ 86 | transforms.Resize(size=(int(fence_width/fence['pixel_width']*fence_image.shape[-2]),int(fence_width/fence['pixel_width']*fence_image.shape[-1]))), # rescale so the image's stored pixel_width becomes fence_width pixels 87 | transforms.CenterCrop([param.img_res, param.img_res]) 88 | ]) 89 | 90 | image_near = T(fence_image)[None, ...] 91 | mask = (image_near > torch.median(image_near))*1.0 # NOTE(review): thresholded per channel, so R/G/B masks can disagree at a pixel (compute_dirt in models/Dirt.py tiles one mask over the channels) 92 | image_near = image_near.to(args.device) * mask.to(args.device) + image_far * (1-mask.to(args.device)) 93 | 94 | 95 | 96 | return image_near.to(args.device), mask.to(args.device) -------------------------------------------------------------------------------- /models/Dirt.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision.transforms as transforms 4 | from models.PerlinBlob import * 5 | 6 | def circle_grad(img_res): # img_res x img_res radial ramp: 1 at the center pixel, falling linearly with distance to 0 at the farthest pixel 7 | center_x, center_y = img_res // 2, img_res // 2 8 | # Distance of every pixel to the center via NumPy broadcasting (same values as the former per-pixel Python double loop). 9 | yy, xx = np.mgrid[0:img_res, 0:img_res] 10 | distx = np.abs(xx - center_x) 11 | disty = np.abs(yy - center_y) 12 | circle_grad = np.sqrt(distx*distx + disty*disty) 13 | 14 | # Rescale so the center maps to 1 and the farthest pixel to 0 (algebraically 1 - dist / max(dist)). 15 | 16 | max_grad = np.max(circle_grad) 17 | circle_grad = circle_grad / max_grad 18 | circle_grad -= 0.5 19 | circle_grad *= 2.0 20 | circle_grad = -circle_grad 21 | 22 | circle_grad -= np.min(circle_grad) 23 | max_grad = np.max(circle_grad) 24 | circle_grad = circle_grad / max_grad 25 | return circle_grad 26 | 27
| def compute_dirt(image_far, depth, args): 28 | param = args.param 29 | brown = (np.array([30, 20, 10]) * np.random.uniform(1,2.5) + np.random.uniform(-5,5,3) )* np.ones([param.img_res, param.img_res, 3]) / 255 30 | perlin_noise = generate_fractal_noise_2d([param.img_res, param.img_res], [param.perlin_res,param.perlin_res], tileable=(True,True), interpolant=interpolant) 31 | depth_adj = depth / param.depth_near_max 32 | T = transforms.Compose([transforms.ToTensor(), 33 | transforms.RandomCrop(int(param.img_res * depth_adj)), 34 | transforms.Resize([param.img_res, param.img_res])]) 35 | perlin_noise = T(perlin_noise).squeeze().numpy() 36 | alpha_map = perlin_noise * (perlin_noise > param.perlin_cutoff) * circle_grad(param.img_res) 37 | alpha_map /= np.max(alpha_map) 38 | image_near = torch.tensor(brown* alpha_map[...,None]).permute(2,0,1)[None,...] 39 | mask = torch.tile(torch.tensor(1.0*(alpha_map[...,None] > 0.3)).permute(2,0,1)[None,...],(1,3,1,1)) 40 | image_near = image_near.to(args.device) * mask.to(args.device) + image_far * (1-mask.to(args.device)) 41 | return image_near.to(args.device), mask.to(args.device) 42 | 43 | -------------------------------------------------------------------------------- /models/PerlinBlob.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def interpolant(t): 4 | return t*t*t*(t*(t*6 - 15) + 10) 5 | 6 | def interpolant1(t): 7 | return t*t*(t*(t*6 - 15) + 10) 8 | 9 | def interpolant2(t): 10 | return t*(t*(t*6 - 15) + 10) 11 | 12 | def generate_perlin_noise_2d( 13 | shape, res, tileable=(False, False), interpolant=interpolant 14 | ): 15 | """Generate a 2D numpy array of perlin noise. 16 | Args: 17 | shape: The shape of the generated array (tuple of two ints). 18 | This must be a multple of res. 19 | res: The number of periods of noise to generate along each 20 | axis (tuple of two ints). Note shape must be a multiple of 21 | res. 
22 | tileable: If the noise should be tileable along each axis 23 | (tuple of two bools). Defaults to (False, False). 24 | interpolant: The interpolation function, defaults to 25 | t*t*t*(t*(t*6 - 15) + 10). 26 | Returns: 27 | A numpy array of shape shape with the generated noise. 28 | Raises: 29 | ValueError: If shape is not a multiple of res (no explicit check; a mismatch presumably surfaces as a NumPy broadcasting ValueError). 30 | """ 31 | delta = (res[0] / shape[0], res[1] / shape[1]) # grid step, in noise periods per pixel 32 | d = (shape[0] // res[0], shape[1] // res[1]) # pixels per noise cell along each axis 33 | grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]]\ 34 | .transpose(1, 2, 0) % 1 # fractional position of each pixel inside its noise cell 35 | # Gradients: one random unit vector per lattice point 36 | angles = 2*np.pi*np.random.rand(res[0]+1, res[1]+1) 37 | gradients = np.dstack((np.cos(angles), np.sin(angles))) 38 | if tileable[0]: # last lattice row reuses the first row's gradients so the noise wraps along axis 0 39 | gradients[-1,:] = gradients[0,:] 40 | if tileable[1]: # same for the last lattice column along axis 1 41 | gradients[:,-1] = gradients[:,0] 42 | gradients = gradients.repeat(d[0], 0).repeat(d[1], 1) # expand each lattice gradient to a d[0] x d[1] block of pixels 43 | g00 = gradients[ :-d[0], :-d[1]] 44 | g10 = gradients[d[0]: , :-d[1]] 45 | g01 = gradients[ :-d[0],d[1]: ] 46 | g11 = gradients[d[0]: ,d[1]: ] 47 | # Ramps: dot product of each corner gradient with the offset from that corner 48 | n00 = np.sum(np.dstack((grid[:,:,0] , grid[:,:,1] )) * g00, 2) 49 | n10 = np.sum(np.dstack((grid[:,:,0]-1, grid[:,:,1] )) * g10, 2) 50 | n01 = np.sum(np.dstack((grid[:,:,0] , grid[:,:,1]-1)) * g01, 2) 51 | n11 = np.sum(np.dstack((grid[:,:,0]-1, grid[:,:,1]-1)) * g11, 2) 52 | # Interpolation 53 | t = interpolant(grid) 54 | n0 = n00*(1-t[:,:,0]) + t[:,:,0]*n10 55 | n1 = n01*(1-t[:,:,0]) + t[:,:,0]*n11 56 | 57 | ret = np.sqrt(2)*((1-t[:,:,1])*n0 + t[:,:,1]*n1) # NOTE(review): the sqrt(2) factor has no effect because ret is min/max-normalized right below 58 | ret = ret - ret.min() 59 | ret = ret/ret.max() # result normalized to [0, 1] 60 | return ret 61 | 62 | 63 | def generate_fractal_noise_2d( 64 | shape, res, octaves=1, persistence=0.5, 65 | lacunarity=2, tileable=(False, False), 66 | interpolant=interpolant 67 | ): 68 | """Generate a 2D numpy array of fractal noise. 69 | Args: 70 | shape: The shape of the generated array (tuple of two ints). 71 | This must be a multiple of lacunarity**(octaves-1)*res.
72 | res: The number of periods of noise to generate along each 73 | axis (tuple of two ints). Note shape must be a multiple of 74 | (lacunarity**(octaves-1)*res). 75 | octaves: The number of octaves in the noise. Defaults to 1. 76 | persistence: The scaling factor between two octaves. 77 | lacunarity: The frequency factor between two octaves. 78 | tileable: If the noise should be tileable along each axis 79 | (tuple of two bools). Defaults to (False, False). 80 | interpolant: The interpolation function, defaults to 81 | t*t*t*(t*(t*6 - 15) + 10). 82 | Returns: 83 | A numpy array of fractal noise and of shape shape generated by 84 | combining several octaves of perlin noise. 85 | Raises: 86 | ValueError: If shape is not a multiple of 87 | (lacunarity**(octaves-1)*res). 88 | """ 89 | noise = np.zeros(shape) # each octave from generate_perlin_noise_2d is normalized to [0, 1], so the sum lies in [0, sum of the amplitudes] 90 | frequency = 1 91 | amplitude = 1 92 | for _ in range(octaves): 93 | noise += amplitude * generate_perlin_noise_2d( 94 | shape, (frequency*res[0], frequency*res[1]), tileable, interpolant 95 | ) 96 | frequency *= lacunarity 97 | amplitude *= persistence 98 | return noise -------------------------------------------------------------------------------- /models/ROLE.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import torch 5 | 6 | from utils.utils import * 7 | 8 | class raindrop(): # A single raindrop sprite: center, radius, an alpha map drawn at construction by _createDefaultDrop, and a texture filled in by updateTexture 9 | def __init__(self, centerxy, radius): 10 | self.ifcol = False 11 | self.col_with = [] 12 | self.center = centerxy 13 | self.radius = radius 14 | self.alphamap = np.zeros((self.radius * 5, self.radius*4,1)) # canvas of 5*radius rows x 4*radius cols; radius is used as an array size, so presumably an int 15 | self.texture = None 16 | self._createDefaultDrop() 17 | 18 | def updateTexture(self, fg): 19 | 20 | # add fish eye effect to simulate the background 21 | K = np.array([[30*self.radius, 0, 2*self.radius], 22 | [0., 20*self.radius, 3*self.radius], 23 | [0., 0., 1]]) 24 | D = np.array([0.0, 0.0, 0.0, 0.0]) 25 | Knew = K.copy() 26 | Knew[(0,1), (0,1)] = math.pow(self.radius, 1/3)
* 1.2 * Knew[(0,1), (0,1)] 27 | fisheye = cv2.fisheye.undistortImage(fg, K, D=D, Knew=Knew) 28 | tmp = np.array(fisheye) 29 | 30 | self.texture = np.clip((0.6 + 1.5*(tmp-0.5)),0,1) 31 | 32 | def _createDefaultDrop(self): 33 | cv2.circle(self.alphamap, (self.radius * 2, self.radius * 3), self.radius, 128, -1) 34 | cv2.ellipse(self.alphamap, (self.radius * 2, self.radius * 3), (self.radius, int(np.random.uniform(0.8,1.5) * self.radius)), 0, 180, 360, 128, -1) 35 | # set alpha map for png 36 | # self.alphamap = cv2.GaussianBlur(self.alphamap,(1,1),0) 37 | self.alphamap = np.asarray(self.alphamap) 38 | self.alphamap = (self.alphamap/np.max(self.alphamap)) 39 | 40 | def compute_raindrop(image_far, depth, args): 41 | '''generate raindrop obstruction''' 42 | param = args.param 43 | image = image_far[0].permute(1,2,0).cpu().numpy() 44 | 45 | drop_num = int(np.random.randint(param.drop_Nmin, param.drop_Nmax) * depth / param.depth_near_max) 46 | 47 | alpha_map = np.zeros_like(image)[:,:,0:1] 48 | imgh, imgw, _ = image.shape 49 | edge_gap = metric2pixel(param.drop_Rmin,depth, args) 50 | ran_pos = [(edge_gap + int(np.random.rand() * (imgw - 2*edge_gap)), 2*edge_gap + int(np.random.rand() * (imgh - 3*edge_gap))) for _ in range(drop_num)] 51 | listRainDrops = [] 52 | ######################### 53 | # Create Raindrop 54 | ######################### 55 | # create raindrop by default 56 | 57 | for pos in ran_pos: 58 | radius = np.minimum(metric2pixel(np.random.uniform(param.drop_Rmin, param.drop_Rmax), depth, args), int(np.floor(param.img_res/20))) 59 | drop = raindrop(pos, radius) 60 | listRainDrops.append(drop) 61 | # add texture 62 | for drop in listRainDrops: 63 | (ix, iy) = drop.center 64 | radius = drop.radius 65 | ROI_WL = 2*radius 66 | ROI_WR = 2*radius 67 | ROI_HU = 3*radius 68 | ROI_HD = 2*radius 69 | if (iy-3*radius) <0 : 70 | ROI_HU = iy 71 | if (iy+2*radius)>imgh: 72 | ROI_HD = imgh - iy 73 | if (ix-2*radius)<0: 74 | ROI_WL = ix 75 | if (ix+2*radius) > imgw: 76 | ROI_WR 
= imgw - ix 77 | 78 | drop_alpha = drop.alphamap 79 | 80 | alpha_map[iy - ROI_HU:iy + ROI_HD, ix - ROI_WL: ix+ROI_WR,:] += drop_alpha[3*radius - ROI_HU:3*radius + ROI_HD, 2*radius - ROI_WL: 2*radius+ROI_WR,:] 81 | drop.image_coor = np.array([iy - ROI_HU, iy + ROI_HD, ix - ROI_WL, ix+ROI_WR]) 82 | drop.alpha_coor = np.array([3*radius - ROI_HU, 3*radius + ROI_HD, 2*radius - ROI_WL, 2*radius+ROI_WR]) 83 | 84 | upshift = int(0.1 * (iy - ROI_HU)) 85 | drop.updateTexture(image[iy - ROI_HU - upshift: iy + ROI_HD - upshift, ix - ROI_WL: ix+ROI_WR]) 86 | 87 | image_near = np.asarray(cv2.GaussianBlur(image,(5,5),0)) 88 | for drop in listRainDrops: 89 | img_hl,img_hr, img_wl, img_wr = drop.image_coor 90 | drop_hl, drop_hr, drop_wl, drop_wr = drop.alpha_coor 91 | texture_blend = drop.texture*(drop.alphamap[drop_hl:drop_hr, drop_wl:drop_wr]) 92 | update_alpha = drop.alphamap[drop_hl:drop_hr, drop_wl:drop_wr] > 0 93 | image_near[img_hl:img_hr, img_wl: img_wr] = texture_blend * update_alpha + image_near[img_hl:img_hr, img_wl: img_wr] * (1-update_alpha) 94 | 95 | image_near = torch.tensor(image_near * (alpha_map > 0)).permute(2,0,1)[None,...] 96 | # image_near = torch.tensor(image_near).permute(2,0,1)[None,...] 
97 | mask = torch.tile(torch.tensor(1.0*(alpha_map > 0)).permute(2,0,1)[None,...],(1,3,1,1)) 98 | return image_near.to(args.device), mask.to(args.device) -------------------------------------------------------------------------------- /models/forwards.py: -------------------------------------------------------------------------------- 1 | from pado.light import * 2 | from pado.optical_element import * 3 | from pado.propagator import * 4 | 5 | from utils.utils import * 6 | 7 | def compute_psf(wvl, depth, doe, args): 8 | '''simulate depth based psf''' 9 | param = args.param 10 | prop = Propagator('Fresnel') 11 | 12 | light = Light(param.R, param.C, param.DOE_pitch, wvl, args.device,B=1) 13 | light.set_spherical_light(depth .numpy()) 14 | 15 | lens = RefractiveLens(param.R, param.C, param.DOE_pitch, param.focal_length, wvl,args.device) 16 | light = lens.forward(light) 17 | 18 | doe.change_wvl(wvl) 19 | light = doe.forward(light) 20 | 21 | aperture = Aperture(param.R, param.C, param.DOE_pitch, param.aperture_diamter, param.aperture_shape, wvl, args.device) 22 | light = aperture.forward(light) 23 | 24 | light_prop = prop.forward(light, param.sensor_dist) 25 | psf = light_prop.get_intensity() 26 | psf = sample_psf(psf, param.DOE_sample_ratio) 27 | psf /= torch.sum(psf) 28 | 29 | psf_size = psf.shape 30 | if psf_size[-2]*psf_size[-1] < param.img_res**2: 31 | wl, wr = compute_pad_size(psf_size[-1], param.img_res) 32 | hl, hr = compute_pad_size(psf_size[-2], param.img_res) 33 | psf = F.pad(psf, (wl, wr, hl, hr), "constant", 0) 34 | elif psf_size[-2]*psf_size[-1] > param.img_res**2: 35 | wl, wr = compute_pad_size(param.img_res, psf_size[-1]) 36 | hl, hr = compute_pad_size(param.img_res, psf_size[-2]) 37 | psf = psf[:,:,hl:-hr, wl:-wr] 38 | 39 | cutoff = np.tan(np.sinh(wvl/(3*param.DOE_pitch)))*param.focal_length / param.equiv_camera_pitch 40 | DOE_mask = edge_mask(int(param.img_res / 2),cutoff, args.device) 41 | psf *= DOE_mask 42 | psf /= torch.sum(psf) 43 | return psf 44 | 
45 | def compute_psf_Fraunhofer(wvl, depth, doe, args): 46 | '''simulate depth based psf''' 47 | param = args.param 48 | prop = Propagator('Fraunhofer') 49 | 50 | light = Light(param.R, param.C, param.DOE_pitch, wvl, args.device,B=1) 51 | light.set_spherical_light(depth.numpy()) 52 | 53 | doe.change_wvl(wvl) 54 | light = doe.forward(light) 55 | 56 | aperture = Aperture(param.R, param.C, param.DOE_pitch, param.aperture_diamter, param.aperture_shape, wvl, args.device) 57 | light = aperture.forward(light) 58 | 59 | light_prop = prop.forward(light, param.sensor_dist) 60 | psf = light_prop.get_intensity() 61 | 62 | # resize 63 | psf = F.interpolate(psf, int(param.R * light_prop.pitch / param.DOE_pitch)) 64 | 65 | psf = sample_psf(psf, param.DOE_sample_ratio) 66 | 67 | psf_size = psf.shape 68 | if psf_size[-2]*psf_size[-1] < param.img_res**2: 69 | wl, wr = compute_pad_size(psf_size[-1], param.img_res) 70 | hl, hr = compute_pad_size(psf_size[-2], param.img_res) 71 | psf = F.pad(psf, (wl, wr, hl, hr), "constant", 0) 72 | elif psf_size[-2]*psf_size[-1] > param.img_res**2: 73 | wl, wr = compute_pad_size(param.img_res, psf_size[-1]) 74 | hl, hr = compute_pad_size(param.img_res, psf_size[-2]) 75 | psf = psf[:,:,hl:-hr, wl:-wr] 76 | 77 | cutoff = np.tan(np.sinh(wvl/(3*param.DOE_pitch)))*param.focal_length / param.equiv_camera_pitch 78 | DOE_mask = edge_mask(int(param.img_res / 2),cutoff, args.device) 79 | psf *= DOE_mask 80 | psf /= torch.sum(psf) 81 | return psf 82 | 83 | def image_formation(image_far, DOE_phase, compute_obstruction, args, z_near = None): 84 | param = args.param 85 | doe = DOE(param.R, param.C, param.DOE_pitch, param.material, param.DOE_wvl,args.device, phase = DOE_phase) 86 | height_map = doe.get_height() * edge_mask(int(param.R/ 2),int(param.R/ 2), args.device) 87 | 88 | if z_near is None: 89 | z_near = randuni(param.depth_near_min, param.depth_near_max, 1)[0] # randomly sample the near-point depth from a range 90 | 91 | z_far = randuni(param.depth_far_min, 
param.depth_far_max, 1)[0] # randomly sample the far-point depth from a range 92 | image_near, mask = compute_obstruction(image_far, z_near, args) 93 | 94 | img_doe = [] 95 | img_near_doe = [] 96 | img_far_doe = [] 97 | psf_far_doe = [] 98 | psf_near_doe = [] 99 | mask_doe = [] 100 | 101 | for i in range(len(param.wvls)): 102 | 103 | wvl = param.wvls[i] 104 | 105 | psf_near = compute_psf_Fraunhofer(wvl, z_near, doe, args) 106 | psf_far = compute_psf_Fraunhofer(wvl, z_far, doe, args) 107 | 108 | img_near_conv = conv_fft(real2complex(image_near[:,i,:,:]), real2complex(psf_near), (int(param.R/2),int(param.R/2),int(param.R/2),int(param.R/2))).get_mag() # 109 | img_far_conv = conv_fft(real2complex(image_far[:,i,:,:]), real2complex(psf_far)).get_mag() 110 | mask_conv = conv_fft(real2complex(mask[:,i,:,:]), real2complex(psf_near), (int(param.R/2),int(param.R/2),int(param.R/2),int(param.R/2))).get_mag() 111 | mask_conv = torch.clamp(1.5*mask_conv,0,1) 112 | img_conv = img_near_conv * mask_conv + img_far_conv * (1 - mask_conv) 113 | 114 | img_doe.append(img_conv) 115 | img_near_doe.append(img_near_conv) 116 | img_far_doe.append(img_far_conv) 117 | psf_near_doe.append(psf_near) 118 | psf_far_doe.append(psf_far) 119 | mask_doe.append(mask_conv) 120 | img_doe = torch.cat(img_doe, dim = 1) 121 | img_near_doe = torch.cat(img_near_doe, dim = 1) 122 | img_far_doe = torch.cat(img_far_doe, dim = 1) 123 | psf_near_doe = torch.cat(psf_near_doe, dim = 1) 124 | psf_far_doe = torch.cat(psf_far_doe, dim = 1) 125 | mask_doe = torch.cat(mask_doe, dim = 1) 126 | 127 | if args.sensor_noise > 0: 128 | noise = torch.rand(img_doe.shape) * 2 * args.sensor_noise - args.sensor_noise 129 | img_doe = torch.clamp(img_doe + noise.type_as(img_doe), 0, 1) 130 | 131 | return image_near.type_as(image_far), mask.type_as(image_far).type_as(image_far), \ 132 | img_doe.type_as(image_far), img_near_doe.type_as(image_far), img_far_doe.type_as(image_far), \ 133 | psf_near_doe.type_as(image_far), 
psf_far_doe.type_as(image_far), mask_doe.type_as(image_far), height_map.type_as(image_far) -------------------------------------------------------------------------------- /models/recon.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from pado.fourier import fft, ifft 7 | from utils.utils import * 8 | 9 | class Arch(nn.Module): 10 | def __init__(self, args, nf = 12): 11 | super().__init__() 12 | self.nf = nf 13 | 14 | self.down0_0 = ConvBlock(6, nf, 7, 1, 3) 15 | self.down0_1 = ConvBlock(nf,nf, 3,1,1) 16 | 17 | self.down1_0 = ConvBlock(nf, 2*nf, 3,2,1) 18 | self.down1_1 = ConvBlock(2*nf, 2*nf, 3,1,1) 19 | self.down1_2 = ConvBlock(2*nf, 2*nf, 3,1,1) 20 | 21 | self.down2_0 = ConvBlock(2*nf, 4*nf, 3,2,1) 22 | self.down2_1 = ConvBlock(4*nf, 4*nf, 3,1,1) 23 | self.down2_2 = ConvBlock(4*nf, 4*nf, 3,1,1) 24 | 25 | self.down3_0 = ConvBlock(4*nf, 8*nf, 3,2,1) 26 | self.down3_1 = ConvBlock(8*nf, 8*nf, 3,1,1) 27 | self.down3_2 = ConvBlock(8*nf, 8*nf, 3,1,1) 28 | 29 | self.down4_0 = ConvBlock(8*nf, 12*nf, 3,2,1) 30 | self.down4_1 = ConvBlock(12*nf, 12*nf, 3,1,1) 31 | self.down4_2 = ConvBlock(12*nf, 12*nf, 3,1,1) 32 | 33 | self.bottleneck_0 = ConvBlock(24*nf, 24*nf, 3,1,1) 34 | self.bottleneck_1 = ConvBlock(24*nf, 12*nf, 3,1,1) 35 | 36 | self.up4_0 = ConvTraspBlock(12*nf, 8*nf, 2,2,0) 37 | self.up4_1 = ConvBlock(24*nf, 8*nf, 3,1,1) 38 | 39 | self.up3_0 = ConvTraspBlock(8*nf, 4*nf, 2,2,0) 40 | self.up3_1 = ConvBlock(12*nf, 4*nf, 3,1,1) 41 | 42 | self.up2_0 = ConvTraspBlock(4*nf, 2*nf, 2,2,0) 43 | self.up2_1 = ConvBlock(6*nf, 2*nf, 3,1,1) 44 | 45 | self.up1_0 = ConvTraspBlock(2*nf, nf, 2,2,0) 46 | self.up1_1 = ConvBlock(3*nf, nf, 3,1,1) 47 | 48 | self.up0_0 = ConvBlock(nf, 3, 3,1,1) 49 | self.up0_1 = nn.Sequential( 50 | nn.ReflectionPad2d(2), 51 | nn.Conv2d(3, 3, kernel_size=5, stride=1), 52 | nn.Tanh() 53 | ) 54 | 55 | # init to 
identity mapping 56 | self.res0 = BasicBlock(3, 3, 3, 1, 1).to(args.device) 57 | self.res1 = BasicBlock(3, 3, 3, 1, 1).to(args.device) 58 | self.res2 = BasicBlock(3, 3, 3, 1, 1).to(args.device) 59 | 60 | self.out = nn.Sequential( 61 | nn.ReflectionPad2d(2), 62 | nn.Conv2d(3, 3, kernel_size=5, stride=1), 63 | nn.ReLU() 64 | ) 65 | self.out.apply(init_id_weights) 66 | 67 | 68 | def forward(self, image, psf_near, psf_far): 69 | psf0 = torch.tile(psf_far, (1,int(self.nf/3),1,1)) 70 | psf1 = torch.tile(sample_psf(psf_far, 2), (1,int(2*self.nf/3),1,1)) 71 | psf2 = torch.tile(sample_psf(psf_far, 4), (1,int(4*self.nf/3),1,1)) 72 | psf3 = torch.tile(sample_psf(psf_far, 8), (1,int(8*self.nf/3),1,1)) 73 | psf4 = torch.tile(sample_psf(psf_far, 16), (1,int(12*self.nf/3),1,1)) 74 | 75 | images = [image, Wiener_deconv(image, psf_near)] 76 | images = torch.cat(images, dim = 1) 77 | 78 | down0 = self.down0_0(images) 79 | down0 = self.down0_1(down0) 80 | deconv0 = Wiener_deconv(down0, psf0) 81 | 82 | down1 = self.down1_0(down0) 83 | down1 = self.down1_1(down1) 84 | down1 = self.down1_2(down1) 85 | deconv1 = Wiener_deconv(down1, psf1) 86 | 87 | down2 = self.down2_0(down1) 88 | down2 = self.down2_1(down2) 89 | down2 = self.down2_2(down2) 90 | deconv2 = Wiener_deconv(down2, psf2) 91 | 92 | down3 = self.down3_0(down2) 93 | down3 = self.down3_1(down3) 94 | down3 = self.down3_2(down3) 95 | deconv3 = Wiener_deconv(down3, psf3) 96 | 97 | down4 = self.down4_0(down3) 98 | down4 = self.down4_1(down4) 99 | down4 = self.down4_2(down4) 100 | deconv4 = Wiener_deconv(down4, psf4) 101 | 102 | bottleneck = self.bottleneck_0(torch.cat([deconv4,down4], 1)) 103 | bottleneck = self.bottleneck_1(bottleneck) 104 | 105 | up4 = self.up4_0(bottleneck) 106 | up4 = self.up4_1(torch.cat([up4,down3, deconv3], 1)) 107 | 108 | up3 = self.up3_0(up4) 109 | up3 = self.up3_1(torch.cat([up3,down2, deconv2], 1)) 110 | 111 | up2 = self.up2_0(up3) 112 | up2 = self.up2_1(torch.cat([up2,down1, deconv1], 1)) 113 | 114 | up1 = 
self.up1_0(up2) 115 | up1 = self.up1_1(torch.cat([up1,down0, deconv0], 1)) 116 | 117 | up0 = self.up0_0(up1) 118 | up0 = self.up0_1(up0) 119 | up0 = up0 + image 120 | 121 | res = self.res0(up0) 122 | res = self.res1(res) 123 | res = self.res2(res) 124 | 125 | out = self.out(res) 126 | 127 | return out 128 | 129 | def ConvBlock(in_channels, out_channels, kernel_size, stride, padding): 130 | block = nn.Sequential( 131 | nn.ReflectionPad2d(padding), 132 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride), #padding=padding), 133 | nn.InstanceNorm2d(out_channels), 134 | nn.LeakyReLU() 135 | ) 136 | return block 137 | 138 | def ConvTraspBlock(in_channels, out_channels, kernel_size, stride, padding): 139 | block = nn.Sequential( 140 | nn.ReflectionPad2d(padding), 141 | nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride), #padding=padding), 142 | nn.InstanceNorm2d(out_channels), 143 | nn.LeakyReLU() 144 | ) 145 | return block 146 | 147 | def Wiener_deconv(image, psf): 148 | otf = fft(real2complex(psf)) 149 | wiener = otf.conj() / real2complex(otf.get_mag() **2 + 1e-6) 150 | image_deconv = ifft(wiener * fft(real2complex(image))).get_mag() 151 | return torch.clamp(image_deconv.type_as(image), min = 0, max = 1) 152 | 153 | def init_zero_weights(module): 154 | if isinstance(module, torch.nn.Conv2d): 155 | # module.weight.data.zero_() 156 | module.weight.data = torch.nn.init.xavier_uniform_(module.weight.data,gain=1e-3) 157 | if module.bias is not None: 158 | module.bias.data.zero_() 159 | 160 | def init_id_weights(module): 161 | if isinstance(module, nn.Conv2d): 162 | module.weight.data = torch.nn.init.dirac_(module.weight.data) 163 | if module.bias is not None: 164 | module.bias.data.zero_() 165 | 166 | class BasicBlock(nn.Module): 167 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding): 168 | super().__init__() 169 | self.conv1 = ConvBlock(in_channels, out_channels, kernel_size, stride, 
padding) 170 | self.conv1.apply(init_zero_weights) 171 | self.conv2 = nn.Sequential( 172 | nn.ReflectionPad2d(padding), 173 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride), 174 | nn.InstanceNorm2d(out_channels), 175 | ) 176 | self.conv2.apply(init_zero_weights) 177 | self.leakyrelu = nn.LeakyReLU() 178 | 179 | def forward(self, x): 180 | identity = x 181 | 182 | out = self.conv1(x) 183 | out = self.conv2(out) 184 | out += identity 185 | out = self.leakyrelu(out) 186 | 187 | return out 188 | 189 | 190 | -------------------------------------------------------------------------------- /pado-main/LICENSE: -------------------------------------------------------------------------------- 1 | Pado 2 | Copyright (c) 2022 by Seung-Hwan Baek 3 | 4 | Pado is licensed under a 5 | Creative Commons Attribution-NonCommercial-ShareAlike 3.0 Unported License. 6 | 7 | You should have received a copy of the license along with this 8 | work. If not, see . 9 | 10 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 11 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 12 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 14 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 15 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 16 | THE SOFTWARE. 17 | 18 | 19 | License 20 | 21 | THE WORK (AS DEFINED BELOW) IS PROVIDED UNDER THE TERMS OF THIS CREATIVE COMMONS PUBLIC LICENSE ("CCPL" OR "LICENSE"). THE WORK IS PROTECTED BY COPYRIGHT AND/OR OTHER APPLICABLE LAW. ANY USE OF THE WORK OTHER THAN AS AUTHORIZED UNDER THIS LICENSE OR COPYRIGHT LAW IS PROHIBITED. 22 | 23 | BY EXERCISING ANY RIGHTS TO THE WORK PROVIDED HERE, YOU ACCEPT AND AGREE TO BE BOUND BY THE TERMS OF THIS LICENSE. 
TO THE EXTENT THIS LICENSE MAY BE CONSIDERED TO BE A CONTRACT, THE LICENSOR GRANTS YOU THE RIGHTS CONTAINED HERE IN CONSIDERATION OF YOUR ACCEPTANCE OF SUCH TERMS AND CONDITIONS. 24 | 25 | 1. Definitions 26 | 27 | "Adaptation" means a work based upon the Work, or upon the Work and other pre-existing works, such as a translation, adaptation, derivative work, arrangement of music or other alterations of a literary or artistic work, or phonogram or performance and includes cinematographic adaptations or any other form in which the Work may be recast, transformed, or adapted including in any form recognizably derived from the original, except that a work that constitutes a Collection will not be considered an Adaptation for the purpose of this License. For the avoidance of doubt, where the Work is a musical work, performance or phonogram, the synchronization of the Work in timed-relation with a moving image ("synching") will be considered an Adaptation for the purpose of this License. 28 | "Collection" means a collection of literary or artistic works, such as encyclopedias and anthologies, or performances, phonograms or broadcasts, or other works or subject matter other than works listed in Section 1(g) below, which, by reason of the selection and arrangement of their contents, constitute intellectual creations, in which the Work is included in its entirety in unmodified form along with one or more other contributions, each constituting separate and independent works in themselves, which together are assembled into a collective whole. A work that constitutes a Collection will not be considered an Adaptation (as defined above) for the purposes of this License. 29 | "Distribute" means to make available to the public the original and copies of the Work or Adaptation, as appropriate, through sale or other transfer of ownership. 
30 | "License Elements" means the following high-level license attributes as selected by Licensor and indicated in the title of this License: Attribution, Noncommercial, ShareAlike. 31 | "Licensor" means the individual, individuals, entity or entities that offer(s) the Work under the terms of this License. 32 | "Original Author" means, in the case of a literary or artistic work, the individual, individuals, entity or entities who created the Work or if no individual or entity can be identified, the publisher; and in addition (i) in the case of a performance the actors, singers, musicians, dancers, and other persons who act, sing, deliver, declaim, play in, interpret or otherwise perform literary or artistic works or expressions of folklore; (ii) in the case of a phonogram the producer being the person or legal entity who first fixes the sounds of a performance or other sounds; and, (iii) in the case of broadcasts, the organization that transmits the broadcast. 33 | "Work" means the literary and/or artistic work offered under the terms of this License including without limitation any production in the literary, scientific and artistic domain, whatever may be the mode or form of its expression including digital form, such as a book, pamphlet and other writing; a lecture, address, sermon or other work of the same nature; a dramatic or dramatico-musical work; a choreographic work or entertainment in dumb show; a musical composition with or without words; a cinematographic work to which are assimilated works expressed by a process analogous to cinematography; a work of drawing, painting, architecture, sculpture, engraving or lithography; a photographic work to which are assimilated works expressed by a process analogous to photography; a work of applied art; an illustration, map, plan, sketch or three-dimensional work relative to geography, topography, architecture or science; a performance; a broadcast; a phonogram; a compilation of data to the extent it is protected 
as a copyrightable work; or a work performed by a variety or circus performer to the extent it is not otherwise considered a literary or artistic work. 34 | "You" means an individual or entity exercising rights under this License who has not previously violated the terms of this License with respect to the Work, or who has received express permission from the Licensor to exercise rights under this License despite a previous violation. 35 | "Publicly Perform" means to perform public recitations of the Work and to communicate to the public those public recitations, by any means or process, including by wire or wireless means or public digital performances; to make available to the public Works in such a way that members of the public may access these Works from a place and at a place individually chosen by them; to perform the Work to the public by any means or process and the communication to the public of the performances of the Work, including by public digital performance; to broadcast and rebroadcast the Work by any means including signs, sounds or images. 36 | "Reproduce" means to make copies of the Work by any means including without limitation by sound or visual recordings and the right of fixation and reproducing fixations of the Work, including storage of a protected performance or phonogram in digital form or other electronic medium. 37 | 2. Fair Dealing Rights. Nothing in this License is intended to reduce, limit, or restrict any uses free from copyright or rights arising from limitations or exceptions that are provided for in connection with the copyright protection under copyright law or other applicable laws. 38 | 39 | 3. License Grant. 
Subject to the terms and conditions of this License, Licensor hereby grants You a worldwide, royalty-free, non-exclusive, perpetual (for the duration of the applicable copyright) license to exercise the rights in the Work as stated below: 40 | 41 | to Reproduce the Work, to incorporate the Work into one or more Collections, and to Reproduce the Work as incorporated in the Collections; 42 | to create and Reproduce Adaptations provided that any such Adaptation, including any translation in any medium, takes reasonable steps to clearly label, demarcate or otherwise identify that changes were made to the original Work. For example, a translation could be marked "The original work was translated from English to Spanish," or a modification could indicate "The original work has been modified."; 43 | to Distribute and Publicly Perform the Work including as incorporated in Collections; and, 44 | to Distribute and Publicly Perform Adaptations. 45 | The above rights may be exercised in all media and formats whether now known or hereafter devised. The above rights include the right to make such modifications as are technically necessary to exercise the rights in other media and formats. Subject to Section 8(f), all rights not expressly granted by Licensor are hereby reserved, including but not limited to the rights described in Section 4(e). 46 | 47 | 4. Restrictions. The license granted in Section 3 above is expressly made subject to and limited by the following restrictions: 48 | 49 | You may Distribute or Publicly Perform the Work only under the terms of this License. You must include a copy of, or the Uniform Resource Identifier (URI) for, this License with every copy of the Work You Distribute or Publicly Perform. You may not offer or impose any terms on the Work that restrict the terms of this License or the ability of the recipient of the Work to exercise the rights granted to that recipient under the terms of the License. You may not sublicense the Work. 
You must keep intact all notices that refer to this License and to the disclaimer of warranties with every copy of the Work You Distribute or Publicly Perform. When You Distribute or Publicly Perform the Work, You may not impose any effective technological measures on the Work that restrict the ability of a recipient of the Work from You to exercise the rights granted to that recipient under the terms of the License. This Section 4(a) applies to the Work as incorporated in a Collection, but this does not require the Collection apart from the Work itself to be made subject to the terms of this License. If You create a Collection, upon notice from any Licensor You must, to the extent practicable, remove from the Collection any credit as required by Section 4(d), as requested. If You create an Adaptation, upon notice from any Licensor You must, to the extent practicable, remove from the Adaptation any credit as required by Section 4(d), as requested. 50 | You may Distribute or Publicly Perform an Adaptation only under: (i) the terms of this License; (ii) a later version of this License with the same License Elements as this License; (iii) a Creative Commons jurisdiction license (either this or a later license version) that contains the same License Elements as this License (e.g., Attribution-NonCommercial-ShareAlike 3.0 US) ("Applicable License"). You must include a copy of, or the URI, for Applicable License with every copy of each Adaptation You Distribute or Publicly Perform. You may not offer or impose any terms on the Adaptation that restrict the terms of the Applicable License or the ability of the recipient of the Adaptation to exercise the rights granted to that recipient under the terms of the Applicable License. You must keep intact all notices that refer to the Applicable License and to the disclaimer of warranties with every copy of the Work as included in the Adaptation You Distribute or Publicly Perform. 
When You Distribute or Publicly Perform the Adaptation, You may not impose any effective technological measures on the Adaptation that restrict the ability of a recipient of the Adaptation from You to exercise the rights granted to that recipient under the terms of the Applicable License. This Section 4(b) applies to the Adaptation as incorporated in a Collection, but this does not require the Collection apart from the Adaptation itself to be made subject to the terms of the Applicable License. 51 | You may not exercise any of the rights granted to You in Section 3 above in any manner that is primarily intended for or directed toward commercial advantage or private monetary compensation. The exchange of the Work for other copyrighted works by means of digital file-sharing or otherwise shall not be considered to be intended for or directed toward commercial advantage or private monetary compensation, provided there is no payment of any monetary compensation in con-nection with the exchange of copyrighted works. 
52 | If You Distribute, or Publicly Perform the Work or any Adaptations or Collections, You must, unless a request has been made pursuant to Section 4(a), keep intact all copyright notices for the Work and provide, reasonable to the medium or means You are utilizing: (i) the name of the Original Author (or pseudonym, if applicable) if supplied, and/or if the Original Author and/or Licensor designate another party or parties (e.g., a sponsor institute, publishing entity, journal) for attribution ("Attribution Parties") in Licensor's copyright notice, terms of service or by other reasonable means, the name of such party or parties; (ii) the title of the Work if supplied; (iii) to the extent reasonably practicable, the URI, if any, that Licensor specifies to be associated with the Work, unless such URI does not refer to the copyright notice or licensing information for the Work; and, (iv) consistent with Section 3(b), in the case of an Adaptation, a credit identifying the use of the Work in the Adaptation (e.g., "French translation of the Work by Original Author," or "Screenplay based on original Work by Original Author"). The credit required by this Section 4(d) may be implemented in any reasonable manner; provided, however, that in the case of a Adaptation or Collection, at a minimum such credit will appear, if a credit for all contributing authors of the Adaptation or Collection appears, then as part of these credits and in a manner at least as prominent as the credits for the other contributing authors. 
For the avoidance of doubt, You may only use the credit required by this Section for the purpose of attribution in the manner set out above and, by exercising Your rights under this License, You may not implicitly or explicitly assert or imply any connection with, sponsorship or endorsement by the Original Author, Licensor and/or Attribution Parties, as appropriate, of You or Your use of the Work, without the separate, express prior written permission of the Original Author, Licensor and/or Attribution Parties. 53 | For the avoidance of doubt: 54 | 55 | Non-waivable Compulsory License Schemes. In those jurisdictions in which the right to collect royalties through any statutory or compulsory licensing scheme cannot be waived, the Licensor reserves the exclusive right to collect such royalties for any exercise by You of the rights granted under this License; 56 | Waivable Compulsory License Schemes. In those jurisdictions in which the right to collect royalties through any statutory or compulsory licensing scheme can be waived, the Licensor reserves the exclusive right to collect such royalties for any exercise by You of the rights granted under this License if Your exercise of such rights is for a purpose or use which is otherwise than noncommercial as permitted under Section 4(c) and otherwise waives the right to collect royalties through any statutory or compulsory licensing scheme; and, 57 | Voluntary License Schemes. The Licensor reserves the right to collect royalties, whether individually or, in the event that the Licensor is a member of a collecting society that administers voluntary licensing schemes, via that society, from any exercise by You of the rights granted under this License that is for a purpose or use which is otherwise than noncommercial as permitted under Section 4(c). 
58 | Except as otherwise agreed in writing by the Licensor or as may be otherwise permitted by applicable law, if You Reproduce, Distribute or Publicly Perform the Work either by itself or as part of any Adaptations or Collections, You must not distort, mutilate, modify or take other derogatory action in relation to the Work which would be prejudicial to the Original Author's honor or reputation. Licensor agrees that in those jurisdictions (e.g. Japan), in which any exercise of the right granted in Section 3(b) of this License (the right to make Adaptations) would be deemed to be a distortion, mutilation, modification or other derogatory action prejudicial to the Original Author's honor and reputation, the Licensor will waive or not assert, as appropriate, this Section, to the fullest extent permitted by the applicable national law, to enable You to reasonably exercise Your right under Section 3(b) of this License (right to make Adaptations) but not otherwise. 59 | 5. Representations, Warranties and Disclaimer 60 | 61 | UNLESS OTHERWISE MUTUALLY AGREED TO BY THE PARTIES IN WRITING AND TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, LICENSOR OFFERS THE WORK AS-IS AND MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE WORK, EXPRESS, IMPLIED, STATUTORY OR OTHERWISE, INCLUDING, WITHOUT LIMITATION, WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NONINFRINGEMENT, OR THE ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OF ABSENCE OF ERRORS, WHETHER OR NOT DISCOVERABLE. SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OF IMPLIED WARRANTIES, SO THIS EXCLUSION MAY NOT APPLY TO YOU. 62 | 63 | 6. Limitation on Liability. EXCEPT TO THE EXTENT REQUIRED BY APPLICABLE LAW, IN NO EVENT WILL LICENSOR BE LIABLE TO YOU ON ANY LEGAL THEORY FOR ANY SPECIAL, INCIDENTAL, CONSEQUENTIAL, PUNITIVE OR EXEMPLARY DAMAGES ARISING OUT OF THIS LICENSE OR THE USE OF THE WORK, EVEN IF LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 
64 | 65 | 7. Termination 66 | 67 | This License and the rights granted hereunder will terminate automatically upon any breach by You of the terms of this License. Individuals or entities who have received Adaptations or Collections from You under this License, however, will not have their licenses terminated provided such individuals or entities remain in full compliance with those licenses. Sections 1, 2, 5, 6, 7, and 8 will survive any termination of this License. 68 | Subject to the above terms and conditions, the license granted here is perpetual (for the duration of the applicable copyright in the Work). Notwithstanding the above, Licensor reserves the right to release the Work under different license terms or to stop distributing the Work at any time; provided, however that any such election will not serve to withdraw this License (or any other license that has been, or is required to be, granted under the terms of this License), and this License will continue in full force and effect unless terminated as stated above. 69 | 8. Miscellaneous 70 | 71 | Each time You Distribute or Publicly Perform the Work or a Collection, the Licensor offers to the recipient a license to the Work on the same terms and conditions as the license granted to You under this License. 72 | Each time You Distribute or Publicly Perform an Adaptation, Licensor offers to the recipient a license to the original Work on the same terms and conditions as the license granted to You under this License. 73 | If any provision of this License is invalid or unenforceable under applicable law, it shall not affect the validity or enforceability of the remainder of the terms of this License, and without further action by the parties to this agreement, such provision shall be reformed to the minimum extent necessary to make such provision valid and enforceable. 
74 | No term or provision of this License shall be deemed waived and no breach consented to unless such waiver or consent shall be in writing and signed by the party to be charged with such waiver or consent. 75 | This License constitutes the entire agreement between the parties with respect to the Work licensed here. There are no understandings, agreements or representations with respect to the Work not specified here. Licensor shall not be bound by any additional provisions that may appear in any communication from You. This License may not be modified without the mutual written agreement of the Licensor and You. 76 | The rights granted under, and the subject matter referenced, in this License were drafted utilizing the terminology of the Berne Convention for the Protection of Literary and Artistic Works (as amended on September 28, 1979), the Rome Convention of 1961, the WIPO Copyright Treaty of 1996, the WIPO Performances and Phonograms Treaty of 1996 and the Universal Copyright Convention (as revised on July 24, 1971). These rights and subject matter take effect in the relevant jurisdiction in which the License terms are sought to be enforced according to the corresponding provisions of the implementation of those treaty provisions in the applicable national law. If the standard suite of rights granted under applicable copyright law includes additional rights not granted under this License, such additional rights are deemed to be included in the License; this License is not intended to restrict the license of any rights under applicable law. 77 | -------------------------------------------------------------------------------- /pado-main/README.md: -------------------------------------------------------------------------------- 1 |

2 | pado logo 3 |

4 | 5 | # Pado: Differentiable Light-wave Simulation 6 | 7 | Pado is a differentiable wave-optics library written in PyTorch. Objects and operations defined in Pado are therefore differentiable via PyTorch's automatic differentiation. Because light-wave simulation in Pado is differentiable, it can be integrated into any PyTorch-based network system. This makes Pado particularly useful for research in learning-based computational imaging and display. 8 | 9 | Pado provides high-level abstractions of light-wave simulation, which are useful for users without a background in wave optics. Pado achieves this with three main objects: light, optical element, and propagator. Building your own light-wave simulator from these three objects lets you work more efficiently, because you can focus on the core problem you want to tackle instead of studying and implementing wave optics from scratch. 10 | 11 | Pado is the Korean word for wave. 12 | 13 | # How to use Pado 14 | We provide a Jupyter notebook (`./example/tutorial.ipynb`). More examples will be added later. 15 | 16 | # Prerequisites 17 | - Python 18 | - PyTorch 19 | - NumPy 20 | - Matplotlib 21 | - SciPy 22 | 23 | # About 24 | Pado is maintained and developed by [Seung-Hwan Baek](http://www.shbaek.com) at [POSTECH Computer Graphics Lab](http://cg.postech.ac.kr/). 
25 | If you use Pado in your research, please cite it using the following BibTeX entry: 26 | 27 | ```bib 28 | @misc{Pado, 29 | Author = {Seung-Hwan Baek}, 30 | Year = {2022}, 31 | Note = {https://github.com/shwbaek/pado}, 32 | Title = {Pado: Differentiable Light-wave Simulation} 33 | } 34 | ``` 35 | 36 | # License 37 | Seung-Hwan Baek has developed this software and related documentation (the "Software"); confidential use in source form of the Software, without modification, is permitted provided that the following conditions are met: 38 | 39 | Neither the name of the copyright holder nor the names of any contributors may be used to endorse or promote products derived from the Software without specific prior written permission. 40 | 41 | The use of the software is for Non-Commercial Purposes only. As used in this Agreement, “Non-Commercial Purpose” means for the purpose of education or research in a non-commercial organization only. “Non-Commercial Purpose” excludes, without limitation, any use of the Software for, as part of, or in any way in connection with a product (including software) or service which is sold, offered for sale, licensed, leased, published, loaned or rented. If you require a license for a use excluded by this agreement, please email [shwbaek@postech.ac.kr]. 42 | 43 | Warranty: POSTECH-CGLAB MAKES NO REPRESENTATIONS OR WARRANTIES ABOUT THE SUITABILITY OF THE SOFTWARE, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, OR NON-INFRINGEMENT. POSTECH-CGLAB SHALL NOT BE LIABLE FOR ANY DAMAGES SUFFERED BY LICENSEE AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THIS SOFTWARE OR ITS DERIVATIVES. 44 | 45 | Please refer to the [LICENSE](./LICENSE) file for more details.
46 | -------------------------------------------------------------------------------- /pado-main/docs/images/logo.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/SeeThroughObstructions/f94a297cc87d72dcb17ffdab3bed64cff186ab9f/pado-main/docs/images/logo.pdf -------------------------------------------------------------------------------- /pado-main/docs/images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/SeeThroughObstructions/f94a297cc87d72dcb17ffdab3bed64cff186ab9f/pado-main/docs/images/logo.png -------------------------------------------------------------------------------- /pado-main/docs/images/logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 20 | 22 | 24 | 27 | 31 | 32 | 35 | 39 | 40 | 41 | 42 | 63 | 65 | 66 | 68 | image/svg+xml 69 | 71 | 72 | 73 | 74 | 80 | 85 | 89 | 93 | 97 | 102 | 107 | 112 | 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /pado-main/pado/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc -------------------------------------------------------------------------------- /pado-main/pado/cmap_phase.txt: -------------------------------------------------------------------------------- 1 | 0.180392 0.129412 0.917647 2 | 0.188235 0.121569 0.921569 3 | 0.200000 0.117647 0.925490 4 | 0.215686 0.109804 0.933333 5 | 0.227451 0.105882 0.937255 6 | 0.243137 0.101961 0.937255 7 | 0.254902 0.101961 0.941176 8 | 0.270588 0.101961 0.945098 9 | 0.286275 0.101961 0.949020 10 | 0.298039 0.105882 0.949020 11 | 0.313725 0.109804 0.952941 12 | 0.325490 0.113725 0.952941 13 | 0.341176 0.117647 0.956863 14 | 0.352941 0.121569 0.956863 15 | 0.364706 0.129412 0.960784 16 | 0.376471 0.133333 0.960784 
17 | 0.388235 0.141176 0.960784 18 | 0.400000 0.145098 0.964706 19 | 0.411765 0.152941 0.964706 20 | 0.423529 0.156863 0.964706 21 | 0.435294 0.164706 0.968627 22 | 0.447059 0.168627 0.968627 23 | 0.458824 0.176471 0.972549 24 | 0.470588 0.180392 0.972549 25 | 0.478431 0.184314 0.972549 26 | 0.490196 0.192157 0.972549 27 | 0.501961 0.196078 0.976471 28 | 0.513725 0.200000 0.976471 29 | 0.521569 0.207843 0.976471 30 | 0.533333 0.211765 0.980392 31 | 0.545098 0.215686 0.980392 32 | 0.556863 0.219608 0.980392 33 | 0.568627 0.223529 0.980392 34 | 0.580392 0.227451 0.980392 35 | 0.592157 0.231373 0.980392 36 | 0.603922 0.235294 0.984314 37 | 0.615686 0.239216 0.984314 38 | 0.627451 0.239216 0.984314 39 | 0.643137 0.243137 0.984314 40 | 0.654902 0.247059 0.984314 41 | 0.666667 0.247059 0.984314 42 | 0.682353 0.250980 0.984314 43 | 0.694118 0.250980 0.984314 44 | 0.705882 0.254902 0.984314 45 | 0.717647 0.254902 0.984314 46 | 0.733333 0.258824 0.980392 47 | 0.745098 0.258824 0.980392 48 | 0.756863 0.262745 0.980392 49 | 0.772549 0.262745 0.980392 50 | 0.784314 0.266667 0.980392 51 | 0.796078 0.266667 0.980392 52 | 0.807843 0.270588 0.980392 53 | 0.819608 0.270588 0.976471 54 | 0.831373 0.274510 0.976471 55 | 0.843137 0.274510 0.976471 56 | 0.854902 0.278431 0.976471 57 | 0.866667 0.282353 0.972549 58 | 0.878431 0.286275 0.972549 59 | 0.886275 0.290196 0.968627 60 | 0.898039 0.294118 0.964706 61 | 0.905882 0.301961 0.964706 62 | 0.913725 0.309804 0.960784 63 | 0.921569 0.313725 0.956863 64 | 0.929412 0.321569 0.952941 65 | 0.937255 0.333333 0.945098 66 | 0.941176 0.341176 0.941176 67 | 0.949020 0.349020 0.933333 68 | 0.952941 0.360784 0.925490 69 | 0.956863 0.372549 0.921569 70 | 0.960784 0.384314 0.913725 71 | 0.964706 0.396078 0.901961 72 | 0.968627 0.407843 0.894118 73 | 0.968627 0.419608 0.886275 74 | 0.972549 0.431373 0.878431 75 | 0.976471 0.443137 0.866667 76 | 0.976471 0.454902 0.858824 77 | 0.976471 0.466667 0.847059 78 | 0.980392 0.478431 0.839216 79 | 0.980392 
0.494118 0.827451 80 | 0.980392 0.505882 0.815686 81 | 0.984314 0.517647 0.807843 82 | 0.984314 0.529412 0.796078 83 | 0.984314 0.537255 0.784314 84 | 0.984314 0.549020 0.772549 85 | 0.984314 0.560784 0.764706 86 | 0.984314 0.572549 0.752941 87 | 0.984314 0.584314 0.741176 88 | 0.984314 0.596078 0.729412 89 | 0.984314 0.603922 0.721569 90 | 0.984314 0.615686 0.709804 91 | 0.984314 0.627451 0.698039 92 | 0.984314 0.635294 0.686275 93 | 0.984314 0.647059 0.678431 94 | 0.984314 0.658824 0.666667 95 | 0.984314 0.666667 0.654902 96 | 0.984314 0.678431 0.643137 97 | 0.984314 0.686275 0.631373 98 | 0.984314 0.698039 0.619608 99 | 0.984314 0.705882 0.611765 100 | 0.984314 0.717647 0.600000 101 | 0.984314 0.725490 0.588235 102 | 0.984314 0.733333 0.576471 103 | 0.984314 0.745098 0.564706 104 | 0.984314 0.752941 0.552941 105 | 0.984314 0.760784 0.541176 106 | 0.984314 0.768627 0.529412 107 | 0.984314 0.780392 0.517647 108 | 0.984314 0.788235 0.505882 109 | 0.984314 0.796078 0.494118 110 | 0.988235 0.803922 0.482353 111 | 0.988235 0.811765 0.466667 112 | 0.988235 0.823529 0.454902 113 | 0.988235 0.831373 0.443137 114 | 0.988235 0.839216 0.431373 115 | 0.988235 0.847059 0.415686 116 | 0.984314 0.854902 0.403922 117 | 0.984314 0.862745 0.388235 118 | 0.984314 0.870588 0.372549 119 | 0.984314 0.878431 0.360784 120 | 0.980392 0.886275 0.345098 121 | 0.980392 0.894118 0.329412 122 | 0.976471 0.898039 0.313725 123 | 0.972549 0.905882 0.298039 124 | 0.968627 0.909804 0.282353 125 | 0.964706 0.917647 0.266667 126 | 0.960784 0.921569 0.250980 127 | 0.952941 0.925490 0.235294 128 | 0.949020 0.929412 0.223529 129 | 0.941176 0.929412 0.207843 130 | 0.933333 0.933333 0.192157 131 | 0.925490 0.933333 0.180392 132 | 0.917647 0.933333 0.168627 133 | 0.909804 0.933333 0.156863 134 | 0.898039 0.933333 0.145098 135 | 0.890196 0.929412 0.133333 136 | 0.878431 0.929412 0.125490 137 | 0.866667 0.925490 0.117647 138 | 0.858824 0.921569 0.109804 139 | 0.847059 0.921569 0.105882 140 | 0.835294 
0.917647 0.101961 141 | 0.823529 0.913725 0.098039 142 | 0.811765 0.909804 0.094118 143 | 0.800000 0.905882 0.090196 144 | 0.788235 0.901961 0.086275 145 | 0.776471 0.898039 0.086275 146 | 0.768627 0.894118 0.082353 147 | 0.756863 0.890196 0.082353 148 | 0.745098 0.886275 0.078431 149 | 0.733333 0.878431 0.078431 150 | 0.721569 0.874510 0.078431 151 | 0.709804 0.870588 0.074510 152 | 0.698039 0.866667 0.074510 153 | 0.686275 0.862745 0.074510 154 | 0.674510 0.858824 0.070588 155 | 0.662745 0.854902 0.070588 156 | 0.650980 0.850980 0.070588 157 | 0.639216 0.847059 0.066667 158 | 0.627451 0.843137 0.066667 159 | 0.611765 0.835294 0.066667 160 | 0.600000 0.831373 0.062745 161 | 0.588235 0.827451 0.062745 162 | 0.576471 0.823529 0.062745 163 | 0.564706 0.819608 0.058824 164 | 0.552941 0.815686 0.058824 165 | 0.541176 0.811765 0.058824 166 | 0.525490 0.807843 0.054902 167 | 0.513725 0.803922 0.054902 168 | 0.501961 0.796078 0.054902 169 | 0.490196 0.792157 0.050980 170 | 0.474510 0.788235 0.050980 171 | 0.462745 0.784314 0.050980 172 | 0.450980 0.780392 0.047059 173 | 0.435294 0.776471 0.047059 174 | 0.423529 0.772549 0.047059 175 | 0.407843 0.764706 0.047059 176 | 0.396078 0.760784 0.043137 177 | 0.380392 0.756863 0.043137 178 | 0.364706 0.752941 0.043137 179 | 0.352941 0.749020 0.047059 180 | 0.337255 0.745098 0.047059 181 | 0.321569 0.737255 0.047059 182 | 0.309804 0.733333 0.050980 183 | 0.294118 0.729412 0.054902 184 | 0.278431 0.725490 0.062745 185 | 0.266667 0.717647 0.066667 186 | 0.250980 0.713725 0.074510 187 | 0.239216 0.709804 0.086275 188 | 0.227451 0.701961 0.094118 189 | 0.219608 0.698039 0.105882 190 | 0.207843 0.694118 0.117647 191 | 0.203922 0.686275 0.129412 192 | 0.196078 0.682353 0.141176 193 | 0.192157 0.674510 0.156863 194 | 0.192157 0.670588 0.168627 195 | 0.192157 0.662745 0.184314 196 | 0.196078 0.654902 0.200000 197 | 0.200000 0.650980 0.211765 198 | 0.203922 0.643137 0.227451 199 | 0.211765 0.635294 0.243137 200 | 0.215686 0.627451 0.258824 
201 | 0.223529 0.619608 0.274510 202 | 0.227451 0.615686 0.290196 203 | 0.235294 0.607843 0.301961 204 | 0.239216 0.600000 0.317647 205 | 0.243137 0.592157 0.333333 206 | 0.250980 0.584314 0.349020 207 | 0.254902 0.576471 0.360784 208 | 0.258824 0.568627 0.376471 209 | 0.262745 0.564706 0.392157 210 | 0.262745 0.556863 0.403922 211 | 0.266667 0.549020 0.419608 212 | 0.266667 0.541176 0.431373 213 | 0.270588 0.533333 0.447059 214 | 0.270588 0.525490 0.458824 215 | 0.270588 0.517647 0.474510 216 | 0.270588 0.509804 0.486275 217 | 0.266667 0.501961 0.501961 218 | 0.266667 0.498039 0.513725 219 | 0.262745 0.490196 0.525490 220 | 0.262745 0.482353 0.541176 221 | 0.258824 0.474510 0.552941 222 | 0.254902 0.466667 0.564706 223 | 0.250980 0.458824 0.580392 224 | 0.243137 0.450980 0.592157 225 | 0.239216 0.443137 0.603922 226 | 0.235294 0.435294 0.615686 227 | 0.227451 0.427451 0.627451 228 | 0.223529 0.419608 0.643137 229 | 0.215686 0.411765 0.654902 230 | 0.207843 0.403922 0.666667 231 | 0.203922 0.396078 0.678431 232 | 0.196078 0.388235 0.690196 233 | 0.192157 0.380392 0.701961 234 | 0.184314 0.372549 0.709804 235 | 0.180392 0.364706 0.721569 236 | 0.176471 0.352941 0.733333 237 | 0.172549 0.345098 0.745098 238 | 0.168627 0.333333 0.752941 239 | 0.164706 0.325490 0.764706 240 | 0.160784 0.313725 0.772549 241 | 0.156863 0.305882 0.784314 242 | 0.156863 0.294118 0.792157 243 | 0.152941 0.282353 0.803922 244 | 0.149020 0.274510 0.811765 245 | 0.149020 0.262745 0.823529 246 | 0.145098 0.250980 0.831373 247 | 0.145098 0.239216 0.839216 248 | 0.141176 0.227451 0.850980 249 | 0.141176 0.215686 0.858824 250 | 0.141176 0.203922 0.866667 251 | 0.141176 0.192157 0.874510 252 | 0.145098 0.180392 0.882353 253 | 0.149020 0.168627 0.890196 254 | 0.152941 0.160784 0.898039 255 | 0.160784 0.149020 0.905882 256 | 0.168627 0.141176 0.909804 257 | -------------------------------------------------------------------------------- /pado-main/pado/conv.py: 
-------------------------------------------------------------------------------- 1 | from .fourier import fft, ifft 2 | 3 | def conv_fft(img_c, kernel_c, pad_width=None): 4 | """ 5 | Compute the convolution of an image with a convolution kernel using FFT 6 | Args: 7 | img_c: [B,Ch,R,C] image as a complex tensor 8 | kernel_c: [B,Ch,R,C] convolution kernel as a complex tensor 9 | pad_width: (tensor) pad width for the last spatial dimensions. should be (0,0,0,0) for circular convolution. for linear convolution, pad zero by the size of the original image 10 | Returns: 11 | im_conv: [B,Ch,R,C] blurred image 12 | """ 13 | 14 | img_fft = fft(img_c, pad_width=pad_width) 15 | kernel_fft = fft(kernel_c, pad_width=pad_width) 16 | return ifft( img_fft * kernel_fft, pad_width=pad_width) 17 | 18 | -------------------------------------------------------------------------------- /pado-main/pado/fourier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .complex import Complex 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | def fft(arr_c, normalized=False, pad_width=None, padval=0, shift=True): 7 | """ 8 | Compute the Fast Fourier transform of a complex tensor 9 | Args: 10 | arr_c: [B,Ch,R,C] complex tensor 11 | normalized: Normalize the FFT output. default: False 12 | pad_width: (tensor) pad width for the last spatial dimensions. 
13 | shift: flag for shifting the input data to make the zero-frequency located at the center of the arrc 14 | Returns: 15 | arr_c_fft: [B,Ch,R,C] FFT of the input complex tensor 16 | """ 17 | 18 | arr_c = Complex(mag=arr_c.get_mag().clone(), ang=arr_c.get_ang().clone()) 19 | if pad_width is not None: 20 | if padval == 0: 21 | arr_c.pad_zero(pad_width) 22 | else: 23 | return NotImplementedError('zero padding is only implemented for now') 24 | if shift: 25 | arr_c_shifted = ifftshift(arr_c) 26 | else: 27 | arr_c_shifted = arr_c 28 | arr_c_shifted.to_rect() 29 | 30 | #arr_c_shifted_stack = arr_c_shifted.get_stack() 31 | if normalized is False: 32 | normalized = "backward" 33 | else: 34 | normalized = "forward" 35 | arr_c_shifted_fft = torch.fft.fft2(arr_c_shifted.get_native() , norm=normalized) 36 | arr_c_shifted_fft_c = Complex(real=arr_c_shifted_fft.real, imag=arr_c_shifted_fft.imag) 37 | if shift: 38 | arr_c_fft = fftshift(arr_c_shifted_fft_c) 39 | else: 40 | arr_c_fft = arr_c_shifted_fft_c 41 | 42 | return arr_c_fft 43 | 44 | 45 | 46 | def ifft(arr_c, normalized=False, pad_width=None, shift=True): 47 | """ 48 | Compute the inverse Fast Fourier transform of a complex tensor 49 | Args: 50 | arr_c: [B,Ch,R,C] complex tensor 51 | normalized: Normalize the FFT output. default: False 52 | pad_width: (tensor) pad width for the last spatial dimensions. 
53 | shift: flag for inversely shifting the input data 54 | Returns: 55 | arr_c_fft: [B,Ch,R,C] inverse FFT of the input complex tensor 56 | """ 57 | 58 | arr_c = Complex(mag=arr_c.get_mag().clone(), ang=arr_c.get_ang().clone()) 59 | if shift: 60 | arr_c_shifted = ifftshift(arr_c) 61 | else: 62 | arr_c_shifted = arr_c 63 | 64 | arr_c_shifted.to_rect() 65 | if normalized is False: 66 | normalized = "backward" 67 | else: 68 | normalized = "forward" 69 | arr_c_shifted_fft = torch.fft.ifft2(arr_c_shifted.get_native(), norm=normalized) 70 | arr_c_shifted_fft_c = Complex(real=arr_c_shifted_fft.real, imag=arr_c_shifted_fft.imag) 71 | if shift: 72 | arr_c_fft = fftshift(arr_c_shifted_fft_c) 73 | else: 74 | arr_c_fft = arr_c_shifted_fft_c 75 | 76 | if pad_width is not None: 77 | arr_c_fft.crop(pad_width) 78 | 79 | return arr_c_fft 80 | 81 | def fftshift(arr_c, invert=False): 82 | """ 83 | Shift the complex tensor so that the zero-frequency signal located at the center of the input 84 | Args: 85 | arr_c: [B,Ch,R,C] complex tensor 86 | invert: flag for inversely shifting the input data 87 | Returns: 88 | arr_c: [B,Ch,R,C] shifted tensor 89 | """ 90 | 91 | arr_c.to_rect() 92 | shift_adjust = 0 if invert else 1 93 | 94 | arr_c_shape = arr_c.shape() 95 | C = arr_c_shape[-1] 96 | R = arr_c_shape[-2] 97 | 98 | shift_len = (C + shift_adjust) // 2 99 | arr_c = arr_c[...,shift_len:].cat(arr_c[...,:shift_len], -1) 100 | 101 | shift_len = (R + shift_adjust) // 2 102 | arr_c = arr_c[...,shift_len:,:].cat(arr_c[...,:shift_len,:], -2) 103 | 104 | return arr_c 105 | 106 | 107 | def ifftshift(arr_c): 108 | """ 109 | Inversely shift the complex tensor 110 | Args: 111 | arr_c: [B,Ch,R,C] complex tensor 112 | Returns: 113 | arr_c: [B,Ch,R,C] shifted tensor 114 | """ 115 | 116 | return fftshift(arr_c, invert=True) 117 | -------------------------------------------------------------------------------- /pado-main/pado/light.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from .complex import Complex 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from scipy.io import savemat 7 | 8 | class Light: 9 | def __init__(self, R, C, pitch, wvl, device, amplitude=None, phase=None, real=None, imag=None, B=1): 10 | """ 11 | Light wave that has a complex field (B,Ch,R,C) as a wavefront 12 | It takes the input wavefront in one of the following types 13 | 1. amplitude and phase 14 | 2. real and imaginary 15 | 3. None ==> we initialize light with amplitude of one and phase of zero 16 | 17 | Args: 18 | R: row 19 | C: column 20 | pitch: pixel pitch in meter 21 | wvl: wavelength of light in meter 22 | device: device to store the wavefront of light. 'cpu', 'cuda:0', ... 23 | amplitude: [batch_size, # of channels, row, column] tensor of wavefront amplitude, default is None 24 | phase: [batch_size, # of channels, row, column] tensor of wavefront amplitude, default is None 25 | real: [batch_size, # of channels, row, column] tensor of wavefront real part, default is None 26 | imag: [batch_size, # of channels, row, column] tensor of wavefront imaginary part, default is None 27 | B: batch size 28 | """ 29 | 30 | self.B = B 31 | self.R = R 32 | self.C = C 33 | self.pitch = pitch 34 | self.device = device 35 | self.wvl = wvl 36 | 37 | if (amplitude==None) and (phase == None) and (real != None) and (imag != None): 38 | self.field = Complex(real=real, imag=imag) 39 | 40 | elif (amplitude!=None) and (phase != None) and (real == None) and (imag == None): 41 | self.field = Complex(mag=amplitude, ang=phase) 42 | 43 | elif (amplitude==None) and (phase == None) and (real == None) and (imag == None): 44 | amplitude = torch.ones((B, 1, self.R, self.C), device=self.device) 45 | phase = torch.zeros((B, 1, self.R, self.C), device=self.device) 46 | self.field = Complex(mag=amplitude, ang=phase) 47 | 48 | else: 49 | 
NotImplementedError('nope!') 50 | 51 | 52 | def crop(self, crop_width): 53 | """ 54 | Crop the light wavefront by crop_width 55 | Args: 56 | crop_width: (tuple) crop width of the tensor following torch functional pad 57 | """ 58 | 59 | self.field.crop(crop_width) 60 | self.R = self.field.size(2) 61 | self.C = self.field.size(3) 62 | 63 | def clone(self): 64 | """ 65 | Clone the light and return it 66 | """ 67 | 68 | return Light(self.R, self.C, 69 | self.pitch, self.wvl, self.device, 70 | amplitude=self.field.get_mag().clone(), phase=self.field.get_ang().clone(), 71 | B=self.B) 72 | 73 | 74 | def pad(self, pad_width, padval=0): 75 | """ 76 | Pad the light wavefront with a constant value by pad_width 77 | Args: 78 | pad_width: (tuple) pad width of the tensor following torch functional pad 79 | padval: value to pad. default is zero 80 | """ 81 | 82 | if padval == 0: 83 | self.set_amplitude(torch.nn.functional.pad(self.get_amplitude(), pad_width)) 84 | self.set_phase(torch.nn.functional.pad(self.get_phase(), pad_width)) 85 | else: 86 | return NotImplementedError('only zero padding supported') 87 | 88 | self.R += pad_width[0] + pad_width[1] 89 | self.C += pad_width[2] + pad_width[3] 90 | 91 | def set_real(self, real): 92 | """ 93 | Set the real part of the light wavefront 94 | Args: 95 | real: real part in the rect representation of the complex number 96 | """ 97 | 98 | self.field.set_real(real) 99 | 100 | def set_imag(self, imag): 101 | """ 102 | Set the imaginary part of the light wavefront 103 | Args: 104 | imag: imaginary part in the rect representation of the complex number 105 | """ 106 | 107 | self.field.set_imag(imag) 108 | 109 | def set_amplitude(self, amplitude): 110 | """ 111 | Set the amplitude of the light wavefront 112 | Args: 113 | amplitude: amplitude in the polar representation of the complex number 114 | """ 115 | self.field.set_mag(amplitude) 116 | 117 | def set_phase(self, phase): 118 | """ 119 | Set the phase of the complex tensor 120 | Args: 121 | 
phase: phase in the polar representation of the complex number 122 | """ 123 | self.field.set_ang(phase) 124 | 125 | 126 | def set_field(self, field): 127 | """ 128 | Set the wavefront modulation of the complex tensor 129 | Args: 130 | field: wavefront as a complex number 131 | """ 132 | self.field = field 133 | 134 | def set_pitch(self, pitch): 135 | """ 136 | Set the pixel pitch of the complex tensor 137 | Args: 138 | pitch: pixel pitch in meter 139 | """ 140 | self.pitch = pitch 141 | 142 | 143 | def get_amplitude(self): 144 | """ 145 | Return the amplitude of the wavefront 146 | Returns: 147 | mag: magnitude in the polar representation of the complex number 148 | """ 149 | 150 | return self.field.get_mag() 151 | 152 | def get_phase(self): 153 | """ 154 | Return the phase of the wavefront 155 | Returns: 156 | ang: angle in the polar representation of the complex number 157 | """ 158 | 159 | return self.field.get_ang() 160 | 161 | def get_field(self): 162 | """ 163 | Return the complex wavefront 164 | Returns: 165 | field: complex wavefront 166 | """ 167 | 168 | return self.field 169 | 170 | def get_intensity(self): 171 | """ 172 | Return the intensity of light wavefront 173 | Returns: 174 | intensity: intensity of light 175 | """ 176 | return self.field.get_intensity() 177 | 178 | def get_bandwidth(self): 179 | """ 180 | Return the bandwidth of light wavefront 181 | Returns: 182 | R_m: spatial height of the wavefront 183 | C_m: spatial width of the wavefront 184 | """ 185 | 186 | return self.pitch*self.R, self.pitch*self.C 187 | 188 | def magnify(self, scale_factor, interp_mode='nearest'): 189 | ''' 190 | Change the wavefront resolution without changing the pixel pitch 191 | Args: 192 | scale_factor: scale factor for interpolation used in tensor.nn.functional.interpolate 193 | interp_mode: interpolation method used in torch.nn.functional.interpolate 'bilinear', 'nearest' 194 | ''' 195 | self.field.resize(scale_factor, interp_mode) 196 | self.R = 
self.field.mag.shape[-2] 197 | self.C = self.field.mag.shape[-1] 198 | 199 | 200 | def resize(self, target_pitch, interp_mode='nearest'): 201 | ''' 202 | Resize the wavefront by changing the pixel pitch. 203 | Args: 204 | target_pitch: new pixel pitch to use 205 | interp_mode: interpolation method used in torch.nn.functional.interpolate 'bilinear', 'nearest' 206 | ''' 207 | scale_factor = self.pitch / target_pitch 208 | self.magnify(scale_factor, interp_mode) 209 | self.set_pitch(target_pitch) 210 | 211 | def set_spherical_light(self, z, dx=0, dy=0): 212 | ''' 213 | Set the wavefront as spherical one coming from the position of (dx,dy,z). 214 | Args: 215 | z: z distance of the spherical light source from the current light position 216 | dx: x distance of the spherical light source from the current light position 217 | dy: y distance of the spherical light source from the current light position 218 | ''' 219 | 220 | [x, y] = np.mgrid[-self.C // 2:self.C // 2, -self.R // 2:self.R // 2].astype(np.float64) 221 | x = x * self.pitch 222 | y = y * self.pitch 223 | r = np.sqrt((x - dx) ** 2 + (y - dy) ** 2 + z ** 2) # this is computed in double precision 224 | theta = 2 * np.pi * r / self.wvl 225 | theta = np.expand_dims(np.expand_dims(theta, axis=0), axis=0)%(2*np.pi) 226 | theta = theta.astype(np.float32) 227 | 228 | theta = torch.tensor(theta, device=self.device) 229 | mag = torch.ones_like(theta) 230 | 231 | self.set_phase(theta) 232 | self.set_amplitude(mag) 233 | 234 | def set_plane_light(self): 235 | ''' 236 | Set the wavefront as a plane wave with zero phase and amptliude of one 237 | ''' 238 | amplitude = torch.ones((1, 1, self.R, self.C), device=self.device) 239 | phase = torch.zeros((1, 1, self.R, self.C), device=self.device) 240 | self.set_amplitude(amplitude) 241 | self.set_phase(phase) 242 | 243 | def save(self, fn): 244 | ''' 245 | Save the amplitude and phase of the light wavefront as a file 246 | Args: 247 | fn: filename to save. 
the format should be either "npy" or "mat" 248 | 249 | ''' 250 | 251 | amp = self.get_amplitude().data.cpu().numpy() 252 | phase = self.get_phase().data.cpu().numpy() 253 | 254 | if fn[-3:] == 'npy': 255 | np.save(fn, amp, phase) 256 | elif fn[-3:] == 'mat': 257 | savemat(fn, {'amplitude':amp, 'phase':phase}) 258 | else: 259 | print('extension in %s is unknown'%fn) 260 | print('light saved to %s\n'%fn) 261 | 262 | def visualize(self,b=0,c=0): 263 | """ 264 | Visualize the light wave 265 | Args: 266 | b: batch index to visualize default is 0 267 | c: channel index to visualize. default is 0 268 | """ 269 | 270 | bw = self.get_bandwidth() 271 | 272 | plt.figure(figsize=(20,5)) 273 | plt.subplot(131) 274 | amplitude_b = self.get_amplitude().data.cpu()[b,c,...].squeeze() 275 | plt.imshow(amplitude_b, 276 | extent=[0,bw[0]*1e3, 0, bw[1]*1e3], cmap='inferno') 277 | plt.title('amplitude') 278 | plt.xlabel('mm') 279 | plt.ylabel('mm') 280 | plt.colorbar() 281 | 282 | plt.subplot(132) 283 | phase = self.get_phase().data.cpu()[b,c,...].squeeze() 284 | plt.imshow(self.get_phase().data.cpu()[b,c,...].squeeze(), 285 | extent=[0,bw[0]*1e3, 0, bw[1]*1e3], cmap='hsv', vmin=-np.pi, vmax=np.pi) # cyclic colormap 286 | plt.title('phase') 287 | plt.xlabel('mm') 288 | plt.ylabel('mm') 289 | plt.colorbar() 290 | 291 | plt.subplot(133) 292 | intensity_b = self.get_intensity().data.cpu()[b,c,...].squeeze() 293 | plt.imshow(intensity_b, 294 | extent=[0,bw[0]*1e3, 0, bw[1]*1e3], cmap='inferno') 295 | plt.title('intensity') 296 | plt.xlabel('mm') 297 | plt.ylabel('mm') 298 | plt.colorbar() 299 | 300 | plt.suptitle('(%d,%d), pitch:%.2f[um], wvl:%.2f[nm], device:%s'%(self.R, self.C, 301 | self.pitch/1e-6, self.wvl/1e-9, self.device)) 302 | plt.show() 303 | 304 | def shape(self): 305 | """ 306 | Returns the shape of light wavefront 307 | Returns: 308 | shape 309 | """ 310 | return self.field.shape() 311 | -------------------------------------------------------------------------------- 
# --------------------------------------------------------------------------------
# /pado-main/pado/material.py:
# --------------------------------------------------------------------------------
import numpy as np


class Material:
    def __init__(self, material_name):
        """
        Material of optical elements and its refractive index.
        Args:
            material_name: name of the material. Supported names: 'PDMS', 'FUSED_SILICA', 'VACUUM'.
        """
        self.material_name = material_name

    def get_RI(self, wvl):
        """
        Return the refractive index of the current material for a wavelength
        Args:
            wvl: wavelength in meter
        Returns:
            RI: Refractive index at wvl
        Raises:
            NotImplementedError: if the material name is not supported
        """

        wvl_nm = wvl / 1e-9
        if self.material_name == 'PDMS':
            # NOTE(review): the constant 0.013217 looks like a Sellmeier term in um^2
            # (compare the FUSED_SILICA branch, which converts to um), but the wavelength
            # here is in nm, which makes the result almost wavelength-independent (~1.416).
            # Left unchanged because existing designs depend on these values -- confirm.
            RI = np.sqrt(1 + (1.0057 * (wvl_nm**2))/(wvl_nm**2 - 0.013217))
        elif self.material_name == 'FUSED_SILICA':
            wvl_um = wvl_nm*1e-3
            # https://refractiveindex.info/?shelf=glass&book=fused_silica&page=Malitson
            RI = (1 + 0.6961663 / (1 - (0.0684043 / wvl_um) ** 2) + 0.4079426 / (1 - (0.1162414 / wvl_um) ** 2) + 0.8974794 / (1 - (9.896161 / wvl_um) ** 2)) ** .5
        elif self.material_name == 'VACUUM':
            RI = 1.0
        else:
            raise NotImplementedError('%s is not in the RI list'%self.material_name)

        return RI

# --------------------------------------------------------------------------------
# /pado-main/pado/optical_element.py:
# --------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt


class OpticalElement:
    def __init__(self, R, C, pitch, wvl, device, name="not defined", B=1):
        """
        Base class for optical elements. Any optical element changes the wavefront of incident light.
        The change of the wavefront is stored as amplitude and phase tensors of shape (B, 1, R, C).
        Note that the number of channels is one for the wavefront modulation.
        Args:
            R: row
            C: column
            pitch: pixel pitch in meter
            wvl: wavelength of light in meter
            device: device to store the wavefront of light. 'cpu', 'cuda:0', ...
            name: name of the current optical element
            B: batch size
        """

        self.name = name
        self.B = B
        self.R = R
        self.C = C
        self.pitch = pitch
        self.device = device
        # identity modulation: unit amplitude, zero phase
        self.amplitude_change = torch.ones((B, 1, R, C), device=self.device)
        self.phase_change = torch.zeros((B, 1, R, C), device=self.device)
        self.wvl = wvl

    def shape(self):
        """
        Returns the shape of light-wavefront modulation. The number of channels is one
        Returns:
            shape: (B, 1, R, C)
        """
        return (self.B, 1, self.R, self.C)

    def set_pitch(self, pitch):
        """
        Set the pixel pitch of the complex tensor
        Args:
            pitch: pixel pitch in meter
        """
        self.pitch = pitch

    def resize(self, target_pitch, interp_mode='nearest'):
        '''
        Resize the wavefront change by changing the pixel pitch.
        Args:
            target_pitch: new pixel pitch to use
            interp_mode: interpolation method used in torch.nn.functional.interpolate 'bilinear', 'nearest'
        '''

        scale_factor = self.pitch / target_pitch
        self.amplitude_change = F.interpolate(self.amplitude_change, scale_factor=scale_factor,
                                              mode=interp_mode)
        self.phase_change = F.interpolate(self.phase_change, scale_factor=scale_factor,
                                          mode=interp_mode)
        self.set_pitch(target_pitch)
        # keep R/C in sync with the resampled tensors
        self.R = self.amplitude_change.shape[-2]
        self.C = self.amplitude_change.shape[-1]

    def get_amplitude_change(self):
        '''
        Return the amplitude change of the wavefront
        Returns:
            amplitude change: amplitude change
        '''

        return self.amplitude_change

    def get_phase_change(self):
        '''
        Return the phase change of the wavefront
        Returns:
            phase change: phase change
        '''

        return self.phase_change

    def set_amplitude_change(self, amplitude):
        """
        Set the amplitude change
        Args:
            amplitude change: amplitude change in the polar representation of the complex number.
                Its last two dimensions must be (R, C).
        """

        assert amplitude.shape[2] == self.R and amplitude.shape[3] == self.C
        self.amplitude_change = amplitude

    def set_phase_change(self, phase):
        """
        Set the phase change
        Args:
            phase change: phase change in the polar representation of the complex number.
                Its last two dimensions must be (R, C).
        """

        assert phase.shape[2] == self.R and phase.shape[3] == self.C
        self.phase_change = phase

    def pad(self, pad_width, padval=0):
        """
        Pad the wavefront change with a constant value by pad_width
        Args:
            pad_width: (tuple) pad width of the tensor following torch functional pad
            padval: value to pad. default is zero
        Raises:
            NotImplementedError: if padval is not zero
        """
        if padval != 0:
            raise NotImplementedError('only zero padding supported')

        self.amplitude_change = F.pad(self.get_amplitude_change(), pad_width)
        self.phase_change = F.pad(self.get_phase_change(), pad_width)

        # NOTE(review): torch.nn.functional.pad applies pad_width[0:2] to the LAST (column)
        # axis and pad_width[2:4] to the row axis, while R/C are updated the other way round.
        # forward() (and Light.pad, not visible here) rely on this pairing, so the
        # bookkeeping is kept as-is -- confirm before changing.
        self.R += pad_width[0] + pad_width[1]
        self.C += pad_width[2] + pad_width[3]

    def forward(self, light, interp_mode='nearest'):
        """
        Forward the incident light with the optical element.
        Both the light and this element may be resized/padded in place so that their
        pitch and spatial size match.
        Args:
            light: incident light
            interp_mode: interpolation method used in torch.nn.functional.interpolate 'bilinear', 'nearest'
        Returns:
            light after interaction with the optical element
        Raises:
            NotImplementedError: if the light and the element have different wavelengths
        """

        # resample the coarser of the two onto the finer pitch
        if light.pitch > self.pitch:
            light.resize(self.pitch, interp_mode)
            light.set_pitch(self.pitch)
        elif light.pitch < self.pitch:
            self.resize(light.pitch, interp_mode)
            self.set_pitch(light.pitch)

        if light.wvl != self.wvl:
            raise NotImplementedError('wavelength should be same for light and optical elements')

        # pad the smaller one so that both have the same number of rows ...
        r1 = np.abs((light.R - self.R)//2)
        r2 = np.abs(light.R - self.R) - r1
        pad_width = (r1, r2, 0, 0)
        if light.R > self.R:
            self.pad(pad_width)
        elif light.R < self.R:
            light.pad(pad_width)

        # ... and columns
        c1 = np.abs((light.C - self.C)//2)
        c2 = np.abs(light.C - self.C) - c1
        pad_width = (0, 0, c1, c2)
        if light.C > self.C:
            self.pad(pad_width)
        elif light.C < self.C:
            light.pad(pad_width)

        light.set_phase(light.get_phase() + self.get_phase_change())
        light.set_amplitude(light.get_amplitude() * self.get_amplitude_change())

        return light

    def visualize(self, b=0):
        """
        Visualize the wavefront modulation of the optical element
        Args:
            b: batch index to visualize default is 0
        """

        plt.figure(figsize=(13,6))

        plt.subplot(121)
        plt.imshow(self.get_amplitude_change().data.cpu()[b,...].squeeze())
        plt.title('amplitude')
        plt.colorbar()

        plt.subplot(122)
        plt.imshow(self.get_phase_change().data.cpu()[b,...].squeeze())
        plt.title('phase')
        plt.colorbar()

        plt.suptitle('%s, (%d,%d), pitch:%.2f[um], wvl:%.2f[nm], device:%s'
                     %(self.name, self.R, self.C, self.pitch/1e-6, self.wvl/1e-9, self.device))
        plt.show()


class RefractiveLens(OpticalElement):
    def __init__(self, R, C, pitch, focal_length, wvl, device):
        """
        Thin refractive lens
        Args:
            R: row
            C: column
            pitch: pixel pitch in meter
            focal_length: focal length of the lens in meter
            wvl: wavelength of light in meter
            device: device to store the wavefront of light. 'cpu', 'cuda:0', ...
        """

        super().__init__(R, C, pitch, wvl, device, name="refractive_lens")

        self.set_focal_length(focal_length)
        self.set_phase_change( self.compute_phase(self.wvl, shift_x=0, shift_y=0) )

    def set_focal_length(self, focal_length):
        """
        Set the focal length of the lens
        Args:
            focal_length: focal length in meter
        """

        self.focal_length = focal_length

    def compute_phase(self, wvl, shift_x=0, shift_y=0):
        """
        Compute the (wrapped) phase of a thin lens
        Args:
            wvl: wavelength of light in meter
            shift_x: x displacement of the lens w.r.t. incident light
            shift_y: y displacement of the lens w.r.t. incident light
        Returns:
            phase: (1, 1, R, C) tensor wrapped to [-pi, pi)
        """

        bw_R = self.R*self.pitch
        bw_C = self.C*self.pitch

        # x spans the columns and y the rows; truncate because float arange can overshoot by one sample
        x = np.arange(-bw_C/2, bw_C/2, self.pitch)
        x = x[:self.C]
        y = np.arange(-bw_R/2, bw_R/2, self.pitch)
        y = y[:self.R]
        xx, yy = np.meshgrid(x, y)

        theta_change = torch.tensor((-2*np.pi / wvl)*((xx-shift_x)**2 + (yy-shift_y)**2), device=self.device) / (2*self.focal_length)
        theta_change = theta_change.unsqueeze(0).unsqueeze(0)
        theta_change %= 2*np.pi
        theta_change -= np.pi

        return theta_change

def height2phase(height, wvl, RI, wrap=True):
    """
    Convert the height of a material to the corresponding phase shift
    Args:
        height: height of the material in meter
        wvl: wavelength of light in meter
        RI: refractive index of the material at the wavelength
        wrap: return the wrapped phase [0,2pi)
    Returns:
        phi: phase shift 2*pi/wvl * (RI-1) * height
    """
    dRI = RI - 1
    wv_n = 2. * np.pi / wvl
    phi = wv_n * dRI * height
    if wrap:
        phi %= 2 * np.pi
    return phi

def phase2height(phase, wvl, RI):
    """
    Convert the phase change to the height of a material (inverse of height2phase without wrapping)
    Args:
        phase: phase change of light
        wvl: wavelength of light in meter
        RI: refractive index of the material at the wavelength
    Returns:
        height in meter
    """
    dRI = RI - 1
    return wvl * phase / (2 * np.pi) / dRI

def radius2phase(r, f, wvl):
    """
    Phase of the optical-path difference sqrt(r^2 + f^2) - f, wrapped to [0, 2pi)
    Args:
        r: radial distance in meter
        f: focal length in meter
        wvl: wavelength of light in meter
    """
    return (2 * np.pi * (np.sqrt(r * r + f * f) - f) / wvl) % (2 * np.pi)

class DOE(OpticalElement):
    def __init__(self, R, C, pitch, material, wvl, device, height=None, phase=None, amplitude=None):
        """
        Diffractive optical element (DOE)
        Args:
            R: row
            C: column
            pitch: pixel pitch in meter
            material: material of the DOE
            wvl: wavelength of light in meter
            device: device to store the wavefront of light. 'cpu', 'cuda:0', ...
            height: height map of the material in meter
            phase: phase change of light
            amplitude: amplitude change of light (default: all ones)
            If neither height nor phase is given, the DOE starts with a zero phase change.
        Raises:
            ValueError: if both height and phase are given
        """

        super().__init__(R, C, pitch, wvl, device, name="doe")

        self.material = material
        self.height = None

        if amplitude is None:
            amplitude = torch.ones((1, 1, self.R, self.C), device=self.device)

        if height is not None and phase is not None:
            raise ValueError('specify either height or phase for a DOE, not both')

        if height is not None:
            self.mode = 'height'
            self.set_height(height)
            self.set_amplitude_change(amplitude)
        else:
            self.mode = 'phase'
            if phase is None:
                # flat DOE: no phase modulation
                phase = torch.zeros((1, 1, self.R, self.C), device=self.device)
            self.set_phase_change(phase, wvl)
            self.set_amplitude_change(amplitude)

    def change_wvl(self, wvl):
        """
        Change the wavelength of phase change (the height map is kept, the phase is recomputed)
        Args:
            wvl: wavelength of phase change
        """
        height = self.get_height()
        self.wvl = wvl
        phase = height2phase(height, self.wvl, self.material.get_RI(self.wvl))
        self.set_phase_change(phase, self.wvl)

    def set_diffraction_grating_1d(self, slit_width, minh, maxh):
        """
        Set the wavefront modulation as 1D diffraction grating
        Args:
            slit_width: width of slit in meter
            minh: minimum height in meter
            maxh: maximum height in meter
        """

        slit_width_px = np.round(slit_width / self.pitch)
        slit_space_px = slit_width_px

        dg = np.zeros((self.R, self.C))
        slit_num_c = self.C // (2 * slit_width_px)

        dg[:] = minh

        for i in range(int(slit_num_c)):
            minc = int((slit_width_px + slit_space_px) * i)
            maxc = int(minc + slit_width_px)

            dg[:, minc:maxc] = maxh
        pc = torch.tensor(dg.astype(np.float32), device=self.device).unsqueeze(0).unsqueeze(0)
        self.set_phase_change(pc, self.wvl)

    def set_diffraction_grating_2d(self, slit_width, minh, maxh):
        """
        Set the wavefront modulation as 2D diffraction grating
        Args:
            slit_width: width of slit in meter
            minh: minimum height in meter
            maxh: maximum height in meter
        """

        slit_width_px = np.round(slit_width / self.pitch)
        slit_space_px = slit_width_px

        dg = np.zeros((self.R, self.C))
        slit_num_r = self.R // (2 * slit_width_px)
        slit_num_c = self.C // (2 * slit_width_px)

        dg[:] = minh

        for i in range(int(slit_num_r)):
            for j in range(int(slit_num_c)):
                minc = int((slit_width_px + slit_space_px) * j)
                maxc = int(minc + slit_width_px)
                minr = int((slit_width_px + slit_space_px) * i)
                maxr = int(minr + slit_width_px)

                dg[minr:maxr, minc:maxc] = maxh

        pc = torch.tensor(dg.astype(np.float32), device=self.device).unsqueeze(0).unsqueeze(0)
        self.set_phase_change(pc, self.wvl)

    def set_Fresnel_lens(self, focal_length, shift_x=0, shift_y=0):
        """
        Set the wavefront modulation as a fresnel lens
        Args:
            focal_length: focal length in meter
            shift_x: x displacement of the lens w.r.t. incident light
            shift_y: y displacement of the lens w.r.t. incident light
        """

        # truncate because float arange can overshoot by one sample
        x = np.arange(-self.C*self.pitch/2, self.C*self.pitch/2, self.pitch)[:self.C]
        y = np.arange(-self.R*self.pitch/2, self.R*self.pitch/2, self.pitch)[:self.R]
        xx, yy = np.meshgrid(x, y)
        xx = torch.tensor(xx, device=self.device)
        yy = torch.tensor(yy, device=self.device)

        phase = (-2*np.pi / self.wvl) * (torch.sqrt((xx-shift_x)**2 + (yy-shift_y)**2 + focal_length**2) - focal_length)
        phase = phase % (2*np.pi)
        phase -= np.pi
        phase = phase.unsqueeze(0).unsqueeze(0)

        self.set_phase_change(phase, self.wvl)

    def resize(self, target_pitch, interp_mode='nearest'):
        '''
        Resize the wavefront by changing the pixel pitch.
        Args:
            target_pitch: new pixel pitch to use
            interp_mode: interpolation mode for the amplitude/phase tensors
                (the height map, in height mode, is always resampled bilinearly)
        Raises:
            NotImplementedError: if the mode is not set
        '''
        # computed before super().resize() updates self.pitch
        scale_factor = self.pitch / target_pitch

        if self.mode == 'phase':
            super().resize(target_pitch, interp_mode)
        elif self.mode == 'height':
            super().resize(target_pitch, interp_mode)
            self.set_height(F.interpolate(self.height, scale_factor=scale_factor, mode='bilinear', align_corners=False))
        else:
            raise NotImplementedError('Mode is not set.')

    def get_height(self):
        """
        Return the height map of the DOE
        Returns:
            height map: height map in meter
        Raises:
            NotImplementedError: if the mode is not set
        """

        if self.mode == 'height':
            return self.height
        elif self.mode == 'phase':
            height = phase2height(self.phase_change, self.wvl, self.material.get_RI(self.wvl))
            return height
        else:
            raise NotImplementedError('Mode is not set.')

    def get_phase_change(self):
        """
        Return the phase change induced by the DOE (switches to phase mode if needed)
        Returns:
            phase change: phase change
        """
        if self.mode == 'height':
            self.to_phase_mode()
        return self.phase_change

    def set_height(self, height):
        """
        Set the height map of the DOE
        Args:
            height map: height map in meter
        """

        if self.mode == 'height':
            self.height = height
        elif self.mode == 'phase':
            self.set_phase_change(height2phase(height, self.wvl, self.material.get_RI(self.wvl)), self.wvl)

    def set_phase_change(self, phase_change, wvl):
        """
        Set the phase change induced by the DOE
        Args:
            phase change: phase change
            wvl: wavelength of light in meter the phase change refers to
        """

        if self.mode == 'height':
            self.set_height(phase2height(phase_change, wvl, self.material.get_RI(wvl)))
        if self.mode == 'phase':
            self.wvl = wvl
            self.phase_change = phase_change

    def to_phase_mode(self):
        """
        Change the mode to phase change
        """
        if self.mode == 'height':
            self.phase_change = height2phase(self.height, self.wvl, self.material.get_RI(self.wvl))
            self.mode = 'phase'
            self.height = None

    def to_height_mode(self):
        """
        Change the mode to height
        """
        if self.mode == 'phase':
            self.height = phase2height(self.phase_change, self.wvl, self.material.get_RI(self.wvl))
            self.mode = 'height'


class SLM(OpticalElement):
    def __init__(self, R, C, pitch, wvl, device, B=1):
        """
        Spatial light modulator (SLM)
        Args:
            R: row
            C: column
            pitch: pixel pitch in meter
            wvl: wavelength of light in meter
            device: device to store the wavefront of light. 'cpu', 'cuda:0', ...
            B: batch size
        """

        super().__init__(R, C, pitch, wvl, device, name="SLM", B=B)

    def set_lens(self, focal_length, shift_x=0, shift_y=0):
        """
        Set the phase of a thin lens
        Args:
            focal_length: focal length of the lens in meter
            shift_x: x displacement of the lens w.r.t. incident light
            shift_y: y displacement of the lens w.r.t. incident light
        """

        # truncate because float arange can overshoot by one sample
        x = np.arange(-self.C*self.pitch/2, self.C*self.pitch/2, self.pitch)[:self.C]
        y = np.arange(-self.R*self.pitch/2, self.R*self.pitch/2, self.pitch)[:self.R]
        xx, yy = np.meshgrid(x, y)

        phase = (2*np.pi / self.wvl)*((xx-shift_x)**2 + (yy-shift_y)**2) / (2*focal_length)
        phase = torch.tensor(phase.astype(np.float32), device=self.device).unsqueeze(0).unsqueeze(0)
        phase = phase % (2*np.pi)
        phase -= np.pi

        self.set_phase_change(phase, self.wvl)

    def set_amplitude_change(self, amplitude, wvl):
        """
        Set the amplitude change
        Args:
            amplitude change: amplitude change in the polar representation of the complex number
            wvl: wavelength of light in meter
        """
        self.wvl = wvl
        super().set_amplitude_change(amplitude)

    def set_phase_change(self, phase_change, wvl):
        """
        Set the phase change
        Args:
            phase change: phase change in the polar representation of the complex number
            wvl: wavelength of light in meter
        """
        self.wvl = wvl
        super().set_phase_change(phase_change)


class Aperture(OpticalElement):
    def __init__(self, R, C, pitch, aperture_diameter, aperture_shape, wvl, device='cpu'):
        """
        Aperture
        Args:
            R: row
            C: column
            pitch: pixel pitch in meter
            aperture_diameter: diameter of the aperture in meter
            aperture_shape: shape of the aperture. {'square', 'circle'}
            wvl: wavelength of light in meter
            device: device to store the wavefront of light. 'cpu', 'cuda:0', ...
        Raises:
            NotImplementedError: for an unsupported aperture_shape
        """

        super().__init__(R, C, pitch, wvl, device, name="aperture")

        self.aperture_diameter = aperture_diameter
        self.aperture_shape = aperture_shape
        self.amplitude_change = torch.zeros((self.R, self.C), device=device)
        if self.aperture_shape == 'square':
            self.set_square()
        elif self.aperture_shape == 'circle':
            self.set_circle()
        else:
            raise NotImplementedError('aperture shape %s is not supported' % self.aperture_shape)

    def set_square(self):
        """
        Set the amplitude modulation of the aperture as square
        """

        self.aperture_shape = 'square'

        [x, y] = np.mgrid[-self.R // 2:self.R // 2, -self.C // 2:self.C // 2].astype(np.float32)
        r = self.pitch * np.asarray([abs(x), abs(y)]).max(axis=0)
        r = np.expand_dims(np.expand_dims(r, axis=0), axis=0)

        max_val = self.aperture_diameter / 2
        amp = (r <= max_val).astype(np.float32)
        amp[amp == 0] = 1e-20  # to enable stable learning
        self.amplitude_change = torch.tensor(amp, device=self.device)

    def set_circle(self, cx=0, cy=0, dia=None):
        """
        Set the amplitude modulation of the aperture as circle
        Args:
            cx, cy: relative center position of the circle (in pixels) with respect to the center of the light wavefront
            dia: circle diameter in meter (keeps the current diameter if None)
        """
        # rows span R and columns span C (same grid as set_square)
        [x, y] = np.mgrid[-self.R // 2:self.R // 2, -self.C // 2:self.C // 2].astype(np.float32)
        r2 = (x-cx) ** 2 + (y-cy) ** 2
        r = self.pitch * np.sqrt(r2)
        r = np.expand_dims(np.expand_dims(r, axis=0), axis=0)

        if dia is not None:
            self.aperture_diameter = dia
        self.aperture_shape = 'circle'
        max_val = self.aperture_diameter / 2
        amp = (r <= max_val).astype(np.float32)
        amp[amp == 0] = 1e-20  # to enable stable learning
        self.amplitude_change = torch.tensor(amp, device=self.device)

def quantize(x, levels, vmin=None, vmax=None, include_vmax=True):
    """
    Quantize the floating array
    Args:
        x: numpy array or torch tensor to quantize
        levels: number of quantization levels
        vmin: minimum value for quantization (only used when include_vmax is False)
        vmax: maximum value for quantization (only used when include_vmax is False)
        include_vmax: include vmax for the quantized levels
            False: quantize x with the space of 1/levels-1 over [vmin, vmax].
            True: quantize x with the space of 1/levels starting at x.min(); the top level is
                  x.min() + (levels-1)*space.
            NOTE(review): with these spacings, False actually reaches vmax while True does not;
            the flag name looks inverted -- kept as-is for compatibility.
    Returns:
        result: quantized array of the same type as x
    Raises:
        TypeError: if x is neither a numpy array nor a torch tensor
    """

    if not include_vmax:
        if levels == 0:
            return x

        if vmin is None:
            vmin = x.min()
        if vmax is None:
            vmax = x.max()

        normalized = (x - vmin) / (vmax - vmin + 1e-16)
        if isinstance(x, np.ndarray):
            levelized = np.floor(normalized * levels) / (levels - 1)
        elif isinstance(x, torch.Tensor):
            levelized = (normalized * levels).floor() / (levels - 1)
        else:
            raise TypeError('quantize expects a numpy array or a torch tensor, got %s' % type(x))
        result = levelized * (vmax - vmin) + vmin
        result[result < vmin] = vmin
        result[result > vmax] = vmax

    else:
        space = (x.max()-x.min())/levels
        vmin = x.min()
        vmax = vmin + space*(levels-1)
        if isinstance(x, np.ndarray):
            result = (np.floor((x-vmin)/space))*space + vmin
        elif isinstance(x, torch.Tensor):
            result = (((x-vmin)/space).floor())*space + vmin
        else:
            raise TypeError('quantize expects a numpy array or a torch tensor, got %s' % type(x))
        result[result < vmin] = vmin
        result[result > vmax] = vmax

    return result

# --------------------------------------------------------------------------------
# /pado-main/pado/propagator.py:
# --------------------------------------------------------------------------------
import torch
import numpy as np
from .fourier import fft
from .complex import Complex
from .conv import conv_fft


def compute_pad_width(field, linear):
    """
    Compute the pad width of an array for FFT-based convolution
    Args:
        field: (B,Ch,R,C) complex tensor
        linear: True or False, flag for linear convolution (zero padding) or circular convolution (no padding)
    Returns:
        pad_width: pad-width tuple (C//2, C//2, R//2, R//2) for linear, zeros otherwise
    """

    if linear:
        R,C = field.shape()[-2:]
        pad_width = (C//2, C//2, R//2, R//2)
    else:
        pad_width = (0,0,0,0)
    return pad_width

def unpad(field_padded, pad_width):
    """
    Unpad the already-padded complex tensor
    Args:
        field_padded: (B,Ch,R,C) padded complex tensor
        pad_width: pad-width tuple as returned by compute_pad_width
    Returns:
        field: unpadded complex tensor
    """

    # Use explicit end indices: a negative-slice form [p:-p] would yield an empty
    # result when the pad width is zero (circular convolution, linear=False).
    R, C = field_padded.shape()[-2:]
    field = field_padded[..., pad_width[2]:R - pad_width[3], pad_width[0]:C - pad_width[1]]
    return field

class Propagator:
    def __init__(self, mode):
        """
        Free-space propagator of light waves
        One can simulate the propagation of light waves on free space (no medium change at all).
        Args:
            mode: type of propagator. currently, we support "Fraunhofer" propagation or "Fresnel" propagation. Use Fraunhofer for far-field propagation and Fresnel for near-field propagation.
        """
        self.mode = mode

    def forward(self, light, z, linear=True):
        """
        Forward the incident light with the propagator.
        Args:
            light: incident light
            z: propagation distance in meter
            linear: True or False, flag for linear convolution (zero padding) or circular convolution (no padding)
        Returns:
            light: light after propagation
        Raises:
            NotImplementedError: for an unsupported propagator mode
        """

        if self.mode == 'Fraunhofer':
            return self.forward_Fraunhofer(light, z, linear)
        elif self.mode == 'Fresnel':
            return self.forward_Fresnel(light, z, linear)
        else:
            raise NotImplementedError('%s propagator is not implemented'%self.mode)


    def forward_Fraunhofer(self, light, z, linear=True):
        """
        Forward the incident light with the Fraunhofer propagator.
        Args:
            light: incident light
            z: propagation distance in meter.
               The propagated wavefront is independent w.r.t. the travel distance z.
               The distance z only affects the size of the "pixel", effectively adjusting the entire image size.
            linear: True or False, flag for linear convolution (zero padding) or circular convolution (no padding)
        Returns:
            light: light after propagation
        """

        pad_width = compute_pad_width(light.field, linear)
        field_propagated = fft(light.field, pad_width=pad_width)
        field_propagated = unpad(field_propagated, pad_width)

        # based on the Fraunhofer reparametrization (u=x/wvl*z) and the Fourier frequency sampling (1/bandwidth)
        bw_r = light.get_bandwidth()[0]
        bw_c = light.get_bandwidth()[1]
        pitch_r_after_propagation = light.wvl*z/bw_r
        pitch_c_after_propagation = light.wvl*z/bw_c

        light_propagated = light.clone()

        # match the x-y pixel pitch using resampling (magnify the axis with the coarser pitch)
        if pitch_r_after_propagation >= pitch_c_after_propagation:
            scale_c = 1
            scale_r = pitch_r_after_propagation/pitch_c_after_propagation
            pitch_after_propagation = pitch_c_after_propagation
        else:
            scale_r = 1
            scale_c = pitch_c_after_propagation/pitch_r_after_propagation
            pitch_after_propagation = pitch_r_after_propagation

        light_propagated.set_field(field_propagated)
        light_propagated.magnify((scale_r,scale_c))
        light_propagated.set_pitch(pitch_after_propagation)

        return light_propagated

    def forward_Fresnel(self, light, z, linear):
        """
        Forward the incident light with the Fresnel propagator.
        Args:
            light: incident light
            z: propagation distance in meter.
            linear: True or False, flag for linear convolution (zero padding) or circular convolution (no padding)
        Returns:
            light: light after propagation
        """
        field_input = light.field

        # compute the convolutional kernel
        sx = light.C / 2
        sy = light.R / 2
        x = np.arange(-sx, sx, 1)
        y = np.arange(-sy, sy, 1)
        xx, yy = np.meshgrid(x,y)
        xx = torch.from_numpy(xx*light.pitch).to(light.device)
        yy = torch.from_numpy(yy*light.pitch).to(light.device)
        k = 2*np.pi/light.wvl  # wavenumber
        phase = (k*(xx**2 + yy**2)/(2*z))
        amplitude = torch.ones_like(phase) / z / light.wvl
        conv_kernel = Complex(mag=amplitude, ang=phase)

        # Propagation with the convolution kernel
        pad_width = compute_pad_width(field_input, linear)

        field_propagated = conv_fft(field_input, conv_kernel, pad_width)

        # return the propagated light
        light_propagated = light.clone()
        light_propagated.set_field(field_propagated)

        return light_propagated

# -------------------------------------------------------------------------------- /pado/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc -------------------------------------------------------------------------------- /pado/cmap_phase.txt: -------------------------------------------------------------------------------- 1 | 0.180392 0.129412 0.917647 2 | 0.188235 0.121569 0.921569 3 | 0.200000 0.117647 0.925490 4 | 0.215686 0.109804 0.933333 5 | 0.227451 0.105882 0.937255 6 | 0.243137 0.101961 0.937255 7 | 0.254902 0.101961 0.941176 8 | 0.270588 0.101961 0.945098 9 | 0.286275 0.101961 0.949020 10 | 0.298039 0.105882 0.949020 11 | 0.313725 0.109804 0.952941 12 | 0.325490 0.113725 0.952941 13 | 0.341176 0.117647 0.956863 14 | 0.352941 0.121569 0.956863 15 | 0.364706 0.129412 0.960784 16 | 0.376471 0.133333 0.960784 17 | 0.388235 0.141176 0.960784 18 | 0.400000
0.145098 0.964706 19 | 0.411765 0.152941 0.964706 20 | 0.423529 0.156863 0.964706 21 | 0.435294 0.164706 0.968627 22 | 0.447059 0.168627 0.968627 23 | 0.458824 0.176471 0.972549 24 | 0.470588 0.180392 0.972549 25 | 0.478431 0.184314 0.972549 26 | 0.490196 0.192157 0.972549 27 | 0.501961 0.196078 0.976471 28 | 0.513725 0.200000 0.976471 29 | 0.521569 0.207843 0.976471 30 | 0.533333 0.211765 0.980392 31 | 0.545098 0.215686 0.980392 32 | 0.556863 0.219608 0.980392 33 | 0.568627 0.223529 0.980392 34 | 0.580392 0.227451 0.980392 35 | 0.592157 0.231373 0.980392 36 | 0.603922 0.235294 0.984314 37 | 0.615686 0.239216 0.984314 38 | 0.627451 0.239216 0.984314 39 | 0.643137 0.243137 0.984314 40 | 0.654902 0.247059 0.984314 41 | 0.666667 0.247059 0.984314 42 | 0.682353 0.250980 0.984314 43 | 0.694118 0.250980 0.984314 44 | 0.705882 0.254902 0.984314 45 | 0.717647 0.254902 0.984314 46 | 0.733333 0.258824 0.980392 47 | 0.745098 0.258824 0.980392 48 | 0.756863 0.262745 0.980392 49 | 0.772549 0.262745 0.980392 50 | 0.784314 0.266667 0.980392 51 | 0.796078 0.266667 0.980392 52 | 0.807843 0.270588 0.980392 53 | 0.819608 0.270588 0.976471 54 | 0.831373 0.274510 0.976471 55 | 0.843137 0.274510 0.976471 56 | 0.854902 0.278431 0.976471 57 | 0.866667 0.282353 0.972549 58 | 0.878431 0.286275 0.972549 59 | 0.886275 0.290196 0.968627 60 | 0.898039 0.294118 0.964706 61 | 0.905882 0.301961 0.964706 62 | 0.913725 0.309804 0.960784 63 | 0.921569 0.313725 0.956863 64 | 0.929412 0.321569 0.952941 65 | 0.937255 0.333333 0.945098 66 | 0.941176 0.341176 0.941176 67 | 0.949020 0.349020 0.933333 68 | 0.952941 0.360784 0.925490 69 | 0.956863 0.372549 0.921569 70 | 0.960784 0.384314 0.913725 71 | 0.964706 0.396078 0.901961 72 | 0.968627 0.407843 0.894118 73 | 0.968627 0.419608 0.886275 74 | 0.972549 0.431373 0.878431 75 | 0.976471 0.443137 0.866667 76 | 0.976471 0.454902 0.858824 77 | 0.976471 0.466667 0.847059 78 | 0.980392 0.478431 0.839216 79 | 0.980392 0.494118 0.827451 80 | 0.980392 0.505882 
0.815686 81 | 0.984314 0.517647 0.807843 82 | 0.984314 0.529412 0.796078 83 | 0.984314 0.537255 0.784314 84 | 0.984314 0.549020 0.772549 85 | 0.984314 0.560784 0.764706 86 | 0.984314 0.572549 0.752941 87 | 0.984314 0.584314 0.741176 88 | 0.984314 0.596078 0.729412 89 | 0.984314 0.603922 0.721569 90 | 0.984314 0.615686 0.709804 91 | 0.984314 0.627451 0.698039 92 | 0.984314 0.635294 0.686275 93 | 0.984314 0.647059 0.678431 94 | 0.984314 0.658824 0.666667 95 | 0.984314 0.666667 0.654902 96 | 0.984314 0.678431 0.643137 97 | 0.984314 0.686275 0.631373 98 | 0.984314 0.698039 0.619608 99 | 0.984314 0.705882 0.611765 100 | 0.984314 0.717647 0.600000 101 | 0.984314 0.725490 0.588235 102 | 0.984314 0.733333 0.576471 103 | 0.984314 0.745098 0.564706 104 | 0.984314 0.752941 0.552941 105 | 0.984314 0.760784 0.541176 106 | 0.984314 0.768627 0.529412 107 | 0.984314 0.780392 0.517647 108 | 0.984314 0.788235 0.505882 109 | 0.984314 0.796078 0.494118 110 | 0.988235 0.803922 0.482353 111 | 0.988235 0.811765 0.466667 112 | 0.988235 0.823529 0.454902 113 | 0.988235 0.831373 0.443137 114 | 0.988235 0.839216 0.431373 115 | 0.988235 0.847059 0.415686 116 | 0.984314 0.854902 0.403922 117 | 0.984314 0.862745 0.388235 118 | 0.984314 0.870588 0.372549 119 | 0.984314 0.878431 0.360784 120 | 0.980392 0.886275 0.345098 121 | 0.980392 0.894118 0.329412 122 | 0.976471 0.898039 0.313725 123 | 0.972549 0.905882 0.298039 124 | 0.968627 0.909804 0.282353 125 | 0.964706 0.917647 0.266667 126 | 0.960784 0.921569 0.250980 127 | 0.952941 0.925490 0.235294 128 | 0.949020 0.929412 0.223529 129 | 0.941176 0.929412 0.207843 130 | 0.933333 0.933333 0.192157 131 | 0.925490 0.933333 0.180392 132 | 0.917647 0.933333 0.168627 133 | 0.909804 0.933333 0.156863 134 | 0.898039 0.933333 0.145098 135 | 0.890196 0.929412 0.133333 136 | 0.878431 0.929412 0.125490 137 | 0.866667 0.925490 0.117647 138 | 0.858824 0.921569 0.109804 139 | 0.847059 0.921569 0.105882 140 | 0.835294 0.917647 0.101961 141 | 0.823529 0.913725 
0.098039 142 | 0.811765 0.909804 0.094118 143 | 0.800000 0.905882 0.090196 144 | 0.788235 0.901961 0.086275 145 | 0.776471 0.898039 0.086275 146 | 0.768627 0.894118 0.082353 147 | 0.756863 0.890196 0.082353 148 | 0.745098 0.886275 0.078431 149 | 0.733333 0.878431 0.078431 150 | 0.721569 0.874510 0.078431 151 | 0.709804 0.870588 0.074510 152 | 0.698039 0.866667 0.074510 153 | 0.686275 0.862745 0.074510 154 | 0.674510 0.858824 0.070588 155 | 0.662745 0.854902 0.070588 156 | 0.650980 0.850980 0.070588 157 | 0.639216 0.847059 0.066667 158 | 0.627451 0.843137 0.066667 159 | 0.611765 0.835294 0.066667 160 | 0.600000 0.831373 0.062745 161 | 0.588235 0.827451 0.062745 162 | 0.576471 0.823529 0.062745 163 | 0.564706 0.819608 0.058824 164 | 0.552941 0.815686 0.058824 165 | 0.541176 0.811765 0.058824 166 | 0.525490 0.807843 0.054902 167 | 0.513725 0.803922 0.054902 168 | 0.501961 0.796078 0.054902 169 | 0.490196 0.792157 0.050980 170 | 0.474510 0.788235 0.050980 171 | 0.462745 0.784314 0.050980 172 | 0.450980 0.780392 0.047059 173 | 0.435294 0.776471 0.047059 174 | 0.423529 0.772549 0.047059 175 | 0.407843 0.764706 0.047059 176 | 0.396078 0.760784 0.043137 177 | 0.380392 0.756863 0.043137 178 | 0.364706 0.752941 0.043137 179 | 0.352941 0.749020 0.047059 180 | 0.337255 0.745098 0.047059 181 | 0.321569 0.737255 0.047059 182 | 0.309804 0.733333 0.050980 183 | 0.294118 0.729412 0.054902 184 | 0.278431 0.725490 0.062745 185 | 0.266667 0.717647 0.066667 186 | 0.250980 0.713725 0.074510 187 | 0.239216 0.709804 0.086275 188 | 0.227451 0.701961 0.094118 189 | 0.219608 0.698039 0.105882 190 | 0.207843 0.694118 0.117647 191 | 0.203922 0.686275 0.129412 192 | 0.196078 0.682353 0.141176 193 | 0.192157 0.674510 0.156863 194 | 0.192157 0.670588 0.168627 195 | 0.192157 0.662745 0.184314 196 | 0.196078 0.654902 0.200000 197 | 0.200000 0.650980 0.211765 198 | 0.203922 0.643137 0.227451 199 | 0.211765 0.635294 0.243137 200 | 0.215686 0.627451 0.258824 201 | 0.223529 0.619608 0.274510 202 | 
0.227451 0.615686 0.290196 203 | 0.235294 0.607843 0.301961 204 | 0.239216 0.600000 0.317647 205 | 0.243137 0.592157 0.333333 206 | 0.250980 0.584314 0.349020 207 | 0.254902 0.576471 0.360784 208 | 0.258824 0.568627 0.376471 209 | 0.262745 0.564706 0.392157 210 | 0.262745 0.556863 0.403922 211 | 0.266667 0.549020 0.419608 212 | 0.266667 0.541176 0.431373 213 | 0.270588 0.533333 0.447059 214 | 0.270588 0.525490 0.458824 215 | 0.270588 0.517647 0.474510 216 | 0.270588 0.509804 0.486275 217 | 0.266667 0.501961 0.501961 218 | 0.266667 0.498039 0.513725 219 | 0.262745 0.490196 0.525490 220 | 0.262745 0.482353 0.541176 221 | 0.258824 0.474510 0.552941 222 | 0.254902 0.466667 0.564706 223 | 0.250980 0.458824 0.580392 224 | 0.243137 0.450980 0.592157 225 | 0.239216 0.443137 0.603922 226 | 0.235294 0.435294 0.615686 227 | 0.227451 0.427451 0.627451 228 | 0.223529 0.419608 0.643137 229 | 0.215686 0.411765 0.654902 230 | 0.207843 0.403922 0.666667 231 | 0.203922 0.396078 0.678431 232 | 0.196078 0.388235 0.690196 233 | 0.192157 0.380392 0.701961 234 | 0.184314 0.372549 0.709804 235 | 0.180392 0.364706 0.721569 236 | 0.176471 0.352941 0.733333 237 | 0.172549 0.345098 0.745098 238 | 0.168627 0.333333 0.752941 239 | 0.164706 0.325490 0.764706 240 | 0.160784 0.313725 0.772549 241 | 0.156863 0.305882 0.784314 242 | 0.156863 0.294118 0.792157 243 | 0.152941 0.282353 0.803922 244 | 0.149020 0.274510 0.811765 245 | 0.149020 0.262745 0.823529 246 | 0.145098 0.250980 0.831373 247 | 0.145098 0.239216 0.839216 248 | 0.141176 0.227451 0.850980 249 | 0.141176 0.215686 0.858824 250 | 0.141176 0.203922 0.866667 251 | 0.141176 0.192157 0.874510 252 | 0.145098 0.180392 0.882353 253 | 0.149020 0.168627 0.890196 254 | 0.152941 0.160784 0.898039 255 | 0.160784 0.149020 0.905882 256 | 0.168627 0.141176 0.909804 257 | -------------------------------------------------------------------------------- /pado/conv.py: -------------------------------------------------------------------------------- 1 | 
from .fourier import fft, ifft


def conv_fft(img_c, kernel_c, pad_width=None):
    """
    Compute the convolution of an image with a convolution kernel using the FFT.

    Both operands are transformed with ``fft`` (zero-padded by ``pad_width`` when
    given), multiplied element-wise in the frequency domain and transformed back
    with ``ifft``, which crops the same ``pad_width`` off again.

    Args:
        img_c: [B,Ch,R,C] image as a complex tensor
        kernel_c: [B,Ch,R,C] convolution kernel as a complex tensor
        pad_width: pad width for the last spatial dimensions. None or (0,0,0,0)
            gives circular convolution; for linear convolution pad zeros by the
            size of the original image.
    Returns:
        im_conv: [B,Ch,R,C] blurred image as a complex tensor
    """
    img_fft = fft(img_c, pad_width=pad_width)
    kernel_fft = fft(kernel_c, pad_width=pad_width)
    return ifft(img_fft * kernel_fft, pad_width=pad_width)

# --------------------------------------------------------------------------------
# /pado/fourier.py
# --------------------------------------------------------------------------------
import torch
from .complex import Complex
import matplotlib.pyplot as plt


def _norm_mode(normalized):
    """Translate the ``normalized`` flag into a ``torch.fft`` ``norm`` string."""
    # Only the literal ``False`` selects the unscaled-forward "backward" convention;
    # every other value selects "forward".
    return "backward" if normalized is False else "forward"


def fft(arr_c, normalized=False, pad_width=None, padval=0, shift=True):
    """
    Compute the 2D Fast Fourier transform of a complex tensor over its last two dimensions.

    Args:
        arr_c: [B,Ch,R,C] complex tensor (project ``Complex``); a copy is transformed,
            the caller's tensor is not padded or converted
        normalized: False -> torch norm "backward" (forward transform unscaled);
            anything else -> norm "forward" (forward transform scaled by 1/n). default: False
        pad_width: pad width for the last spatial dimensions, applied to the copy
            before the transform
        padval: padding value; only 0 is supported
        shift: if True, ifftshift the input and fftshift the output so that the
            zero frequency is located at the center of the array
    Returns:
        arr_c_fft: [B,Ch,R,C] FFT of the input complex tensor
    Raises:
        NotImplementedError: if ``pad_width`` is given together with a non-zero ``padval``
    """
    # Work on a copy so padding and representation changes never touch the input.
    arr_c = Complex(mag=arr_c.get_mag().clone(), ang=arr_c.get_ang().clone())
    if pad_width is not None:
        if padval != 0:
            # Raise (rather than return) so the error cannot be mistaken for a spectrum.
            raise NotImplementedError('zero padding is only implemented for now')
        arr_c.pad_zero(pad_width)

    arr_c_shifted = ifftshift(arr_c) if shift else arr_c
    arr_c_shifted.to_rect()

    spectrum = torch.fft.fft2(arr_c_shifted.get_native(), norm=_norm_mode(normalized))
    spectrum_c = Complex(real=spectrum.real, imag=spectrum.imag)
    return fftshift(spectrum_c) if shift else spectrum_c


def ifft(arr_c, normalized=False, pad_width=None, shift=True):
    """
    Compute the 2D inverse Fast Fourier transform of a complex tensor over its last two dimensions.

    Args:
        arr_c: [B,Ch,R,C] complex tensor (project ``Complex``); a copy is transformed
        normalized: False -> torch norm "backward" (inverse scaled by 1/n);
            anything else -> norm "forward" (inverse unscaled). default: False
        pad_width: pad width that was used in ``fft``; it is cropped off the result
        shift: flag for inversely shifting the input data before the transform and
            shifting the result back afterwards
    Returns:
        arr_c_ifft: [B,Ch,R,C] inverse FFT of the input complex tensor
    """
    arr_c = Complex(mag=arr_c.get_mag().clone(), ang=arr_c.get_ang().clone())

    arr_c_shifted = ifftshift(arr_c) if shift else arr_c
    arr_c_shifted.to_rect()

    field = torch.fft.ifft2(arr_c_shifted.get_native(), norm=_norm_mode(normalized))
    field_c = Complex(real=field.real, imag=field.imag)
    arr_c_ifft = fftshift(field_c) if shift else field_c

    if pad_width is not None:
        # Undo the zero padding that the matching forward transform applied.
        arr_c_ifft.crop(pad_width)

    return arr_c_ifft


def fftshift(arr_c, invert=False):
    """
    Circularly shift the last two dimensions so the zero-frequency sample sits at the center.

    This follows numpy's convention. The forward shift moves the first ceil(N/2)
    samples of each axis to the end, so the zero frequency lands at index N//2.
    With ``invert=True`` the first floor(N/2) samples are moved instead, which
    undoes the forward shift for odd sizes as well.

    Args:
        arr_c: [B,Ch,R,C] complex tensor; ``to_rect()`` is called on it before slicing
        invert: flag for inversely shifting the input data
    Returns:
        arr_c: [B,Ch,R,C] shifted tensor
    """
    arr_c.to_rect()
    shift_adjust = 0 if invert else 1

    arr_c_shape = arr_c.shape()
    C = arr_c_shape[-1]
    R = arr_c_shape[-2]

    shift_len = (C + shift_adjust) // 2
    arr_c = arr_c[..., shift_len:].cat(arr_c[..., :shift_len], -1)

    shift_len = (R + shift_adjust) // 2
    arr_c = arr_c[..., shift_len:, :].cat(arr_c[..., :shift_len, :], -2)

    return arr_c


def ifftshift(arr_c):
    """
    Inversely shift the complex tensor (the inverse of ``fftshift``)
    Args:
        arr_c: [B,Ch,R,C] complex tensor
    Returns:
        arr_c: [B,Ch,R,C] shifted tensor
    """
    return fftshift(arr_c, invert=True)

# --------------------------------------------------------------------------------
# /pado/light.py
# --------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from .complex import Complex
import matplotlib.pyplot as plt
import numpy as np
from scipy.io import savemat


class Light:
    """Light wave whose wavefront is a complex field (B,Ch,R,C) sampled with a given pixel pitch."""

    def __init__(self, R, C, pitch, wvl, device, amplitude=None, phase=None, real=None, imag=None, B=1):
        """
        Light wave that has a complex field (B,Ch,R,C) as a wavefront.
        The input wavefront is given in one of the following forms:
            1. amplitude and phase
            2. real and imaginary parts
            3. nothing ==> amplitude of one and phase of zero, shape (B,1,R,C)

        Args:
            R: row
            C: column
            pitch: pixel pitch in meter
            wvl: wavelength of light in meter
            device: device to store the wavefront of light. 'cpu', 'cuda:0', ...
            amplitude: [batch_size, # of channels, row, column] tensor of wavefront amplitude, default is None
            phase: [batch_size, # of channels, row, column] tensor of wavefront phase, default is None
            real: [batch_size, # of channels, row, column] tensor of wavefront real part, default is None
            imag: [batch_size, # of channels, row, column] tensor of wavefront imaginary part, default is None
            B: batch size
        Raises:
            ValueError: if the given components match none of the three forms above
        """
        self.B = B
        self.R = R
        self.C = C
        self.pitch = pitch
        self.device = device
        self.wvl = wvl

        # Use identity checks: ``tensor == None`` is not the idiomatic None test.
        no_polar = amplitude is None and phase is None
        has_polar = amplitude is not None and phase is not None
        no_rect = real is None and imag is None
        has_rect = real is not None and imag is not None

        if no_polar and has_rect:
            self.field = Complex(real=real, imag=imag)
        elif has_polar and no_rect:
            self.field = Complex(mag=amplitude, ang=phase)
        elif no_polar and no_rect:
            amplitude = torch.ones((B, 1, self.R, self.C), device=self.device)
            phase = torch.zeros((B, 1, self.R, self.C), device=self.device)
            self.field = Complex(mag=amplitude, ang=phase)
        else:
            # Fail loudly instead of leaving ``self.field`` undefined.
            raise ValueError('Light expects (amplitude, phase), (real, imag), or no field components')

    def crop(self, crop_width):
        """
        Crop the light wavefront by crop_width and update R and C from the result
        Args:
            crop_width: (tuple) crop width of the tensor following torch functional pad
        """
        self.field.crop(crop_width)
        self.R = self.field.size(2)
        self.C = self.field.size(3)

    def clone(self):
        """
        Clone the light (amplitude and phase tensors are cloned) and return it
        """
        return Light(self.R, self.C,
                     self.pitch, self.wvl, self.device,
                     amplitude=self.field.get_mag().clone(), phase=self.field.get_ang().clone(),
                     B=self.B)

    def pad(self, pad_width, padval=0):
        """
        Pad the light wavefront with a constant value by pad_width
        Args:
            pad_width: (tuple) pad width of the tensor following torch functional pad
            padval: value to pad. only zero is supported
        Raises:
            NotImplementedError: if padval is not zero
        """
        if padval != 0:
            raise NotImplementedError('only zero padding supported')

        self.set_amplitude(F.pad(self.get_amplitude(), pad_width))
        self.set_phase(F.pad(self.get_phase(), pad_width))

        # NOTE(review): torch's F.pad applies pad_width[0:2] to the LAST axis (columns)
        # and pad_width[2:4] to rows, yet they are added to R and C respectively here,
        # so R/C can drift from the tensor shape. OpticalElement.forward pads with the
        # same convention, so both sides must be changed together — confirm before fixing.
        self.R += pad_width[0] + pad_width[1]
        self.C += pad_width[2] + pad_width[3]

    def set_real(self, real):
        """
        Set the real part of the light wavefront
        Args:
            real: real part in the rect representation of the complex number
        """
        self.field.set_real(real)

    def set_imag(self, imag):
        """
        Set the imaginary part of the light wavefront
        Args:
            imag: imaginary part in the rect representation of the complex number
        """
        self.field.set_imag(imag)

    def set_amplitude(self, amplitude):
        """
        Set the amplitude of the light wavefront
        Args:
            amplitude: amplitude in the polar representation of the complex number
        """
        self.field.set_mag(amplitude)

    def set_phase(self, phase):
        """
        Set the phase of the light wavefront
        Args:
            phase: phase in the polar representation of the complex number
        """
        self.field.set_ang(phase)

    def set_field(self, field):
        """
        Replace the complex wavefront
        Args:
            field: wavefront as a complex number
        """
        self.field = field

    def set_pitch(self, pitch):
        """
        Set the pixel pitch of the wavefront
        Args:
            pitch: pixel pitch in meter
        """
        self.pitch = pitch

    def get_amplitude(self):
        """
        Return the amplitude of the wavefront
        Returns:
            mag: magnitude in the polar representation of the complex number
        """
        return self.field.get_mag()

    def get_phase(self):
        """
        Return the phase of the wavefront
        Returns:
            ang: angle in the polar representation of the complex number
        """
        return self.field.get_ang()

    def get_field(self):
        """
        Return the complex wavefront
        Returns:
            field: complex wavefront
        """
        return self.field

    def get_intensity(self):
        """
        Return the intensity of light wavefront
        Returns:
            intensity: intensity of light
        """
        return self.field.get_intensity()

    def get_bandwidth(self):
        """
        Return the spatial extent of the light wavefront
        Returns:
            R_m: spatial height of the wavefront in meter (pitch * R)
            C_m: spatial width of the wavefront in meter (pitch * C)
        """
        return self.pitch * self.R, self.pitch * self.C

    def magnify(self, scale_factor, interp_mode='nearest'):
        '''
        Change the wavefront resolution without changing the pixel pitch
        Args:
            scale_factor: scale factor for interpolation used in torch.nn.functional.interpolate
            interp_mode: interpolation method used in torch.nn.functional.interpolate 'bilinear', 'nearest'
        '''
        self.field.resize(scale_factor, interp_mode)
        # Read the new size from whichever representation the field currently holds.
        if self.field.mode == 'polar':
            self.R = self.field.mag.shape[-2]
            self.C = self.field.mag.shape[-1]
        elif self.field.mode == 'rect':
            self.R = self.field.real.shape[-2]
            self.C = self.field.real.shape[-1]

    def resize(self, target_pitch, interp_mode='nearest'):
        '''
        Resample the wavefront to a new pixel pitch, keeping its physical extent.
        Args:
            target_pitch: new pixel pitch to use
            interp_mode: interpolation method used in torch.nn.functional.interpolate 'bilinear', 'nearest'
        '''
        scale_factor = self.pitch / target_pitch
        self.magnify(scale_factor, interp_mode)
        self.set_pitch(target_pitch)

    def set_spherical_light(self, z, dx=0, dy=0):
        '''
        Set the wavefront as a spherical one with unit amplitude, coming from the position (dx,dy,z).
        Args:
            z: z distance of the spherical light source from the current light position
            dx: x (column direction) distance of the source from the current light position
            dy: y (row direction) distance of the source from the current light position
        '''
        # Rows index y and columns index x, giving an (R, C) grid that matches the
        # field; the original grid was (C, R) and broke non-square wavefronts.
        y, x = np.mgrid[-self.R // 2:self.R // 2, -self.C // 2:self.C // 2].astype(np.float64)
        x = x * self.pitch
        y = y * self.pitch
        r = np.sqrt((x - dx) ** 2 + (y - dy) ** 2 + z ** 2)  # this is computed in double precision
        theta = 2 * np.pi * r / self.wvl
        theta = (theta[np.newaxis, np.newaxis] % (2 * np.pi)).astype(np.float32)

        theta = torch.tensor(theta, device=self.device)
        mag = torch.ones_like(theta)

        self.set_phase(theta)
        self.set_amplitude(mag)

    def set_plane_light(self):
        '''
        Set the wavefront as a plane wave with zero phase and amplitude of one, shape (1,1,R,C)
        '''
        amplitude = torch.ones((1, 1, self.R, self.C), device=self.device)
        phase = torch.zeros((1, 1, self.R, self.C), device=self.device)
        self.set_amplitude(amplitude)
        self.set_phase(phase)

    def save(self, fn):
        '''
        Save the amplitude and phase of the light wavefront as a file
        Args:
            fn: filename to save. the extension should be either "npy" or "mat"
                - npy: a single array with amplitude and phase stacked along a new
                  leading axis (index 0 = amplitude, index 1 = phase)
                - mat: a MATLAB file with the variables 'amplitude' and 'phase'
        '''
        amp = self.get_amplitude().data.cpu().numpy()
        phase = self.get_phase().data.cpu().numpy()

        if fn.endswith('npy'):
            # np.save writes exactly one array (its third positional parameter is
            # allow_pickle), so both components are packed into one array.
            np.save(fn, np.stack(np.broadcast_arrays(amp, phase), axis=0))
        elif fn.endswith('mat'):
            savemat(fn, {'amplitude': amp, 'phase': phase})
        else:
            print('extension in %s is unknown' % fn)
            return
        print('light saved to %s\n' % fn)

    def visualize(self, b=0, c=0):
        """
        Visualize amplitude, phase and intensity of the light wave
        Args:
            b: batch index to visualize. default is 0
            c: channel index to visualize. default is 0
        """
        bw = self.get_bandwidth()
        # imshow's extent is (left, right, bottom, top): horizontal spans the
        # column extent bw[1], vertical spans the row extent bw[0] (both in mm).
        extent = [0, bw[1] * 1e3, 0, bw[0] * 1e3]

        panels = (
            (131, 'amplitude', self.get_amplitude, dict(cmap='inferno')),
            (132, 'phase', self.get_phase, dict(cmap='hsv', vmin=-np.pi, vmax=np.pi)),  # cyclic colormap
            (133, 'intensity', self.get_intensity, dict(cmap='inferno')),
        )

        plt.figure(figsize=(20, 5))
        for position, title, getter, style in panels:
            plt.subplot(position)
            plt.imshow(getter().data.cpu()[b, c, ...].squeeze(), extent=extent, **style)
            plt.title(title)
            plt.xlabel('mm')
            plt.ylabel('mm')
            plt.colorbar()

        plt.suptitle('(%d,%d), pitch:%.2f[um], wvl:%.2f[nm], device:%s' % (self.R, self.C,
                     self.pitch / 1e-6, self.wvl / 1e-9, self.device))
        plt.show()

    def shape(self):
        """
        Returns the shape of light wavefront
        Returns:
            shape
        """
        return self.field.shape()
-------------------------------------------------------------------------------- /pado/material.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Material: 5 | def __init__(self, material_name): 6 | """ 7 | Material of optical elements and its refractive index. 8 | Args: 9 | material name: name of the material. So far, we provide PDMS, FUSED_SILICA, VACCUM. 10 | """ 11 | self.material_name = material_name 12 | 13 | def get_RI(self, wvl): 14 | """ 15 | Return the refractive index of the current material for a wavelength 16 | Args: 17 | wvl: wavelength in meter 18 | Returns: 19 | RI: Refractive index at wvl 20 | """ 21 | 22 | wvl_nm = wvl / 1e-9 23 | if self.material_name == 'PDMS': 24 | RI = np.sqrt(1 + (1.0057 * (wvl_nm**2))/(wvl_nm**2 - 0.013217)) 25 | elif self.material_name == 'FUSED_SILICA': 26 | wvl_um = wvl_nm*1e-3 27 | RI = (1 + 0.6961663 / (1 - (0.0684043 / wvl_um) ** 2) + 0.4079426 / (1 - (0.1162414 / wvl_um) ** 2) + 0.8974794 / (1 - (9.896161 / wvl_um) ** 2)) ** .5 28 | # https://refractiveindex.info/?shelf=glass&book=fused_silica&page=Malitson 29 | elif self.material_name == 'VACUUM': 30 | RI = 1.0 31 | else: 32 | return NotImplementedError('%s is not in the RI list'%self.material_name) 33 | 34 | return RI 35 | 36 | -------------------------------------------------------------------------------- /pado/optical_element.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | class OpticalElement: 8 | def __init__(self, R, C, pitch, wvl, device, name="not defined",B=1): 9 | """ 10 | Base class for optical elements. Any optical element change the wavefront of incident light 11 | The change of the wavefront is stored as amplitude and phase tensors 12 | Note that he number of channels is one for the wavefront modulation. 
13 | Args: 14 | R: row 15 | C: column 16 | pitch: pixel pitch in meter 17 | wvl: wavelength of light in meter 18 | device: device to store the wavefront of light. 'cpu', 'cuda:0', ... 19 | name: name of the current optical element 20 | B: batch size 21 | """ 22 | 23 | self.name = name 24 | self.B = B 25 | self.R = R 26 | self.C = C 27 | self.pitch = pitch 28 | self.device = device 29 | self.amplitude_change = torch.ones((B, 1, R, C), device=self.device) 30 | self.phase_change = torch.zeros((B, 1, R, C), device=self.device) 31 | self.wvl = wvl 32 | 33 | def shape(self): 34 | """ 35 | Returns the shape of light-wavefront modulation. The nunmber of channels is one 36 | Returns: 37 | shape 38 | """ 39 | return (self.B,1,self.R,self.C) 40 | 41 | def set_pitch(self, pitch): 42 | """ 43 | Set the pixel pitch of the complex tensor 44 | Args: 45 | pitch: pixel pitch in meter 46 | """ 47 | self.pitch = pitch 48 | 49 | def resize(self, target_pitch, interp_mode='nearest'): 50 | ''' 51 | Resize the wavefront change by changing the pixel pitch. 
52 | Args: 53 | target_pitch: new pixel pitch to use 54 | interp_mode: interpolation method used in torch.nn.functional.interpolate 'bilinear', 'nearest' 55 | ''' 56 | 57 | scale_factor = self.pitch / target_pitch 58 | self.amplitude_change = F.interpolate(self.amplitude_change, scale_factor=scale_factor, 59 | mode=interp_mode) 60 | self.phase_change = F.interpolate(self.phase_change, scale_factor=scale_factor, 61 | mode=interp_mode) 62 | self.set_pitch(target_pitch) 63 | self.R = self.amplitude_change.shape[-2] 64 | self.C = self.amplitude_change.shape[-1] 65 | 66 | def get_amplitude_change(self): 67 | ''' 68 | Return the amplitude change of the wavefront 69 | Returns: 70 | amplitude change: ampiltude change 71 | ''' 72 | 73 | return self.amplitude_change 74 | 75 | def get_phase_change(self): 76 | ''' 77 | Return the phase change of the wavefront 78 | Returns: 79 | phase change: phase change 80 | ''' 81 | 82 | return self.phase_change 83 | 84 | def set_amplitude_change(self, amplitude): 85 | """ 86 | Set the amplitude change 87 | Args: 88 | amplitude change: amplitude change in the polar representation of the complex number 89 | """ 90 | 91 | assert amplitude.shape[2] == self.R and amplitude.shape[3] == self.C 92 | self.amplitude_change = amplitude 93 | 94 | def set_phase_change(self, phase): 95 | """ 96 | Set the phase change 97 | Args: 98 | phase change: phase change in the polar representation of the complex number 99 | """ 100 | 101 | assert phase.shape[2] == self.R and phase.shape[3] == self.C 102 | self.phase_change = phase 103 | 104 | def pad(self, pad_width, padval=0): 105 | """ 106 | Pad the wavefront change with a constant value by pad_width 107 | Args: 108 | pad_width: (tuple) pad width of the tensor following torch functional pad 109 | padval: value to pad. 
default is zero 110 | """ 111 | if padval == 0: 112 | self.amplitude_change = torch.nn.functional.pad(self.get_amplitude_change(), pad_width) 113 | self.phase_change = torch.nn.functional.pad(self.get_phase_change(), pad_width) 114 | else: 115 | return NotImplementedError('only zero padding supported') 116 | 117 | self.R += pad_width[0] + pad_width[1] 118 | self.C += pad_width[2] + pad_width[3] 119 | 120 | def forward(self, light, interp_mode='nearest'): 121 | """ 122 | Forward the incident light with the optical element. 123 | Args: 124 | light: incident light 125 | interp_mode: interpolation method used in torch.nn.functional.interpolate 'bilinear', 'nearest' 126 | Returns: 127 | light after interaction with the optical element 128 | """ 129 | 130 | if light.pitch > self.pitch: 131 | light.resize(self.pitch, interp_mode) 132 | light.set_pitch(self.pitch) 133 | elif light.pitch < self.pitch: 134 | self.resize(light.pitch, interp_mode) 135 | self.set_pitch(light.pitch) 136 | 137 | if light.wvl != self.wvl: 138 | return NotImplementedError('wavelength should be same for light and optical elements') 139 | 140 | r1 = np.abs((light.R - self.R)//2) 141 | r2 = np.abs(light.R - self.R) - r1 142 | pad_width = (r1, r2, 0, 0) 143 | if light.R > self.R: 144 | self.pad(pad_width) 145 | elif light.R < self.R: 146 | light.pad(pad_width) 147 | 148 | c1 = np.abs((light.C - self.C)//2) 149 | c2 = np.abs(light.C - self.C) - c1 150 | pad_width = (0, 0, c1, c2) 151 | if light.C > self.C: 152 | self.pad(pad_width) 153 | elif light.C < self.C: 154 | light.pad(pad_width) 155 | 156 | light.set_phase(light.get_phase() + self.get_phase_change()) 157 | light.set_amplitude(light.get_amplitude() * self.get_amplitude_change()) 158 | 159 | return light 160 | 161 | def visualize(self, b=0): 162 | """ 163 | Visualize the wavefront modulation of the optical element 164 | Args: 165 | b: batch index to visualize default is 0 166 | """ 167 | 168 | plt.figure(figsize=(13,6)) 169 | 170 | 
plt.subplot(121) 171 | plt.imshow(self.get_amplitude_change().data.cpu()[b,...].squeeze()) 172 | plt.title('amplitude') 173 | plt.colorbar() 174 | 175 | plt.subplot(122) 176 | plt.imshow(self.get_phase_change().data.cpu()[b,...].squeeze()) 177 | plt.title('phase') 178 | plt.colorbar() 179 | 180 | plt.suptitle('%s, (%d,%d), pitch:%.2f[um], wvl:%.2f[nm], device:%s' 181 | %(self.name, self.R, self.C, self.pitch/1e-6, self.wvl/1e-9, self.device)) 182 | plt.show() 183 | 184 | 185 | class RefractiveLens(OpticalElement): 186 | def __init__(self, R, C, pitch, focal_length, wvl, device): 187 | """ 188 | Thin refractive lens 189 | Args: 190 | R: row 191 | C: column 192 | pitch: pixel pitch in meter 193 | focal_length: focal length of the lens in meter 194 | wvl: wavelength of light in meter 195 | device: device to store the wavefront of light. 'cpu', 'cuda:0', ... 196 | """ 197 | 198 | super().__init__(R, C, pitch, wvl, device, name="refractive_lens") 199 | 200 | self.set_focal_length(focal_length) 201 | self.set_phase_change( self.compute_phase(self.wvl, shift_x=0, shift_y=0) ) 202 | 203 | def set_focal_length(self, focal_length): 204 | """ 205 | Set the focal length of the lens 206 | Args: 207 | focal_length: focal length in meter 208 | """ 209 | 210 | self.focal_length = focal_length 211 | 212 | def compute_phase(self, wvl, shift_x=0, shift_y=0): 213 | """ 214 | Set the phase of a thin lens 215 | Args: 216 | wvl: wavelength of light in meter 217 | shift_x: x displacement of the lens w.r.t. incident light 218 | shift_y: y displacement of the lens w.r.t. 
incident light 219 | """ 220 | 221 | bw_R = self.R*self.pitch 222 | bw_C = self.C*self.pitch 223 | 224 | x = np.arange(-bw_C/2, bw_C/2, self.pitch) 225 | x = x[:self.R] 226 | y = np.arange(-bw_R/2, bw_R/2, self.pitch) 227 | y = y[:self.C] 228 | xx,yy = np.meshgrid(x,y) 229 | 230 | theta_change = torch.tensor((-2*np.pi / wvl)*((xx-shift_x)**2 + (yy-shift_y)**2), device=self.device) / (2*self.focal_length) 231 | theta_change = torch.unsqueeze(torch.unsqueeze(theta_change, axis=0), axis=0) 232 | theta_change %= 2*np.pi 233 | theta_change -= np.pi 234 | 235 | return theta_change 236 | 237 | def height2phase(height, wvl, RI, wrap=True): 238 | """ 239 | Convert the height of a material to the corresponding phase shift 240 | Args: 241 | height: height of the material in meter 242 | wvl: wavelength of light in meter 243 | RI: refractive index of the material at the wavelength 244 | wrap: return the wrapped phase [0,2pi] 245 | """ 246 | dRI = RI - 1 247 | wv_n = 2. * np.pi / wvl 248 | phi = wv_n * dRI * height 249 | if wrap: 250 | phi %= 2 * np.pi 251 | return phi 252 | 253 | def phase2height(phase, wvl, RI): 254 | """ 255 | Convert the phase change to the height of a material 256 | Args: 257 | phase: phase change of light 258 | wvl: wavelength of light in meter 259 | RI: refractive index of the material at the wavelength 260 | """ 261 | dRI = RI - 1 262 | return wvl * phase / (2 * np.pi) / dRI 263 | 264 | def radius2phase(r, f, wvl): 265 | return (2 * np.pi * (np.sqrt(r * r + f * f) - f) / wvl) % (2 * np.pi) 266 | 267 | class DOE(OpticalElement): 268 | def __init__(self, R, C, pitch, material, wvl, device, height=None, phase=None, amplitude=None): 269 | """ 270 | Diffractive optical element (DOE) 271 | Args: 272 | R: row 273 | C: column 274 | pitch: pixel pitch in meter 275 | material: material of the DOE 276 | wvl: wavelength of light in meter 277 | device: device to store the wavefront of light. 'cpu', 'cuda:0', ... 
278 | height: height map of the material in meter 279 | phase: phase change of light 280 | amplitude: amplitude change of light 281 | """ 282 | 283 | super().__init__(R, C, pitch, wvl, device, name="doe") 284 | 285 | self.material = material 286 | self.height = None 287 | 288 | if amplitude is None: 289 | amplitude = torch.ones((1, 1, self.R, self.C), device=self.device) 290 | 291 | if height is None and phase is not None: 292 | self.mode = 'phase' 293 | self.set_phase_change(phase, wvl) 294 | self.set_amplitude_change(amplitude) 295 | elif height is not None and phase is None: 296 | self.mode = 'height' 297 | self.set_height(height) 298 | self.set_amplitude_change(amplitude) 299 | elif (height is None) and (phase is None) and (amplitude is None): 300 | self.mode = 'phase' 301 | phase = torch.zeros((1, 1, self.R, self.C), device=self.device) 302 | self.set_amplitude_change(amplitude) 303 | self.set_phase_change(phase, wvl) 304 | 305 | 306 | def change_wvl(self, wvl): 307 | """ 308 | Change the wavelength of phase change 309 | Args: 310 | wvl: wavelength of phase change 311 | """ 312 | height = self.get_height() 313 | self.wvl = wvl 314 | phase = height2phase(height, self.wvl, self.material.get_RI(self.wvl)) 315 | self.set_phase_change(phase, self.wvl) 316 | 317 | def set_diffraction_grating_1d(self, slit_width, minh, maxh): 318 | """ 319 | Set the wavefront modulation as 1D diffraction grating 320 | Args: 321 | slit_width: width of slit in meter 322 | minh: minimum height in meter 323 | maxh: maximum height in meter 324 | """ 325 | 326 | slit_width_px = np.round(slit_width / self.pitch) 327 | slit_space_px = slit_width_px 328 | 329 | dg = np.zeros((self.R, self.C)) 330 | slit_num_r = self.R // (2 * slit_width_px) 331 | slit_num_c = self.C // (2 * slit_width_px) 332 | 333 | dg[:] = minh 334 | 335 | for i in range(int(slit_num_c)): 336 | minc = int((slit_width_px + slit_space_px) * i) 337 | maxc = int(minc + slit_width_px) 338 | 339 | dg[:, minc:maxc] = maxh 340 | pc 
= torch.tensor(dg.astype(np.float32), device=self.device).unsqueeze(0).unsqueeze(0) 341 | self.set_phase_change(pc, self.wvl) 342 | 343 | def set_diffraction_grating_2d(self, slit_width, minh, maxh): 344 | """ 345 | Set the wavefront modulation as 2D diffraction grating 346 | Args: 347 | slit_width: width of slit in meter 348 | minh: minimum height in meter 349 | maxh: maximum height in meter 350 | """ 351 | 352 | slit_width_px = np.round(slit_width / self.pitch) 353 | slit_space_px = slit_width_px 354 | 355 | dg = np.zeros((self.R, self.C)) 356 | slit_num_r = self.R // (2 * slit_width_px) 357 | slit_num_c = self.C // (2 * slit_width_px) 358 | 359 | dg[:] = minh 360 | 361 | for i in range(int(slit_num_r)): 362 | for j in range(int(slit_num_c)): 363 | minc = int((slit_width_px + slit_space_px) * j) 364 | maxc = int(minc + slit_width_px) 365 | minr = int((slit_width_px + slit_space_px) * i) 366 | maxr = int(minr + slit_width_px) 367 | 368 | dg[minr:maxr, minc:maxc] = maxh 369 | 370 | pc = torch.tensor(dg.astype(np.float32), device=self.device).unsqueeze(0).unsqueeze(0) 371 | self.set_phase_change(pc, self.wvl) 372 | 373 | def set_Fresnel_lens(self, focal_length, shift_x=0, shift_y=0): 374 | """ 375 | Set the wavefront modulation as a fresnel lens 376 | Args: 377 | focal_length: focal length in meter 378 | shift_x: x displacement of the lens w.r.t. incident light 379 | shift_y: y displacement of the lens w.r.t. 
incident light 380 | """ 381 | 382 | x = np.arange(-self.C*self.pitch/2, self.C*self.pitch/2, self.pitch) 383 | y = np.arange(-self.R*self.pitch/2, self.R*self.pitch/2, self.pitch) 384 | xx,yy = np.meshgrid(x,y) 385 | xx = torch.tensor(xx, device=self.device) 386 | yy = torch.tensor(yy, device=self.device) 387 | 388 | phase = (-2*np.pi / self.wvl) * (torch.sqrt((xx-shift_x)**2 + (yy-shift_y)**2 + focal_length**2) - focal_length) 389 | phase = phase % (2*np.pi) 390 | phase -= np.pi 391 | phase = phase.unsqueeze(0).unsqueeze(0) 392 | 393 | self.set_phase_change(phase, self.wvl) 394 | 395 | def resize(self, target_pitch): 396 | ''' 397 | Resize the wavefront by changing the pixel pitch. 398 | Args: 399 | target_pitch: new pixel pitch to use 400 | ''' 401 | scale_factor = self.pitch / target_pitch 402 | super().resize(target_pitch) 403 | 404 | if self.mode == 'phase': 405 | super().resize(target_pitch) 406 | elif self.mode == 'height': 407 | self.set_height(F.interpolate(self.height, scale_factor=scale_factor, mode='bilinear', align_corners=False)) 408 | else: 409 | NotImplementedError('Mode is not set.') 410 | 411 | def get_height(self): 412 | """ 413 | Return the height map of the DOE 414 | Returns: 415 | height map: height map in meter 416 | """ 417 | 418 | if self.mode == 'height': 419 | return self.height 420 | elif self.mode == 'phase': 421 | height = phase2height(self.phase_change, self.wvl, self.material.get_RI(self.wvl)) 422 | return height 423 | else: 424 | NotImplementedError('Mode is not set.') 425 | 426 | def get_phase_change(self): 427 | """ 428 | Return the phase change induced by the DOE 429 | Returns: 430 | phase change: phase change 431 | """ 432 | if self.mode == 'height': 433 | self.to_phase_mode() 434 | return self.phase_change 435 | 436 | def set_height(self, height): 437 | """ 438 | Set the height map of the DOE 439 | Args: 440 | height map: height map in meter 441 | """ 442 | 443 | if self.mode == 'height': 444 | self.height = height 445 | elif 
self.mode == 'phase': 446 | self.set_phase_change(height2phase(height, self.wvl, self.material.get_RI(self.wvl)), self.wvl) 447 | 448 | def set_phase_change(self, phase_change, wvl): 449 | """ 450 | Set the phase change induced by the DOE 451 | Args: 452 | phase change: phase change 453 | """ 454 | 455 | if self.mode == 'height': 456 | self.set_height(phase2height(phase_change, wvl, self.material.get_RI(wvl))) 457 | if self.mode == 'phase': 458 | self.wvl = wvl 459 | self.phase_change = phase_change 460 | 461 | def to_phase_mode(self): 462 | """ 463 | Change the mode to phase change 464 | """ 465 | if self.mode == 'height': 466 | self.phase_change = height2phase(self.height, self.wvl, self.material.get_RI(self.wvl)) 467 | self.mode = 'phase' 468 | self.height = None 469 | 470 | def to_height_mode(self): 471 | """ 472 | Change the mode to height 473 | """ 474 | if self.mode == 'phase': 475 | self.height = phase2height(self.phase_change, self.wvl, self.material.get_RI(self.wvl)) 476 | self.mode = 'height' 477 | 478 | 479 | class SLM(OpticalElement): 480 | def __init__(self, R, C, pitch, wvl, device, B=1): 481 | """ 482 | Spatial light modulator (SLM) 483 | Args: 484 | R: row 485 | C: column 486 | pitch: pixel pitch in meter 487 | wvl: wavelength of light in meter 488 | device: device to store the wavefront of light. 'cpu', 'cuda:0', ... 489 | B: batch size 490 | """ 491 | 492 | super().__init__(R, C, pitch, wvl, device, name="SLM", B=B) 493 | 494 | def set_lens(self, focal_length, shift_x=0, shift_y=0): 495 | """ 496 | Set the phase of a thin lens 497 | Args: 498 | wvl: wavelength of light in meter 499 | shift_x: x displacement of the lens w.r.t. incident light 500 | shift_y: y displacement of the lens w.r.t. 
incident light 501 | """ 502 | 503 | x = np.arange(-self.C*self.pitch/2, self.C*self.pitch/2, self.pitch) 504 | y = np.arange(-self.R*self.pitch/2, self.R*self.pitch/2, self.pitch) 505 | xx,yy = np.meshgrid(x,y) 506 | 507 | phase = (2*np.pi / self.wvl)*((xx-shift_x)**2 + (yy-shift_y)**2) / (2*focal_length) 508 | phase = torch.tensor(phase.astype(np.float32), device=self.device).unsqueeze(0).unsqueeze(0) 509 | phase = phase % (2*np.pi) 510 | phase -= np.pi 511 | 512 | self.set_phase_change(phase, self.wvl) 513 | 514 | def set_amplitude_change(self, amplitude, wvl): 515 | """ 516 | Set the amplitude change 517 | Args: 518 | amplitude change: amplitude change in the polar representation of the complex number 519 | wvl: wavelength of light in meter 520 | 521 | """ 522 | self.wvl = wvl 523 | super().set_amplitude_change(amplitude) 524 | 525 | def set_phase_change(self, phase_change, wvl): 526 | """ 527 | Set the phase change 528 | Args: 529 | phase change: phase change in the polar representation of the complex number 530 | wvl: wavelength of light in meter 531 | """ 532 | self.wvl = wvl 533 | super().set_phase_change(phase_change) 534 | 535 | 536 | class Aperture(OpticalElement): 537 | def __init__(self, R, C, pitch, aperture_diameter, aperture_shape, wvl, device='cpu'): 538 | """ 539 | Aperture 540 | Args: 541 | R: row 542 | C: column 543 | pitch: pixel pitch in meter 544 | aperture_diameter: diamater of the aperture in meter 545 | aperture_shape: shape of the aperture. {'square', 'circle'} 546 | wvl: wavelength of light in meter 547 | device: device to store the wavefront of light. 'cpu', 'cuda:0', ... 
548 | """ 549 | 550 | super().__init__(R, C, pitch, wvl, device, name="aperture") 551 | 552 | self.aperture_diameter = aperture_diameter 553 | self.aperture_shape = aperture_shape 554 | self.amplitude_change = torch.zeros((self.R, self.C), device=device) 555 | if self.aperture_shape == 'square': 556 | self.set_square() 557 | elif self.aperture_shape == 'circle': 558 | self.set_circle() 559 | else: 560 | return NotImplementedError 561 | 562 | def set_square(self): 563 | """ 564 | Set the amplitude modulation of the aperture as square 565 | """ 566 | 567 | self.aperture_shape = 'square' 568 | 569 | [x, y] = np.mgrid[-self.R // 2:self.R // 2, -self.C // 2:self.C // 2].astype(np.float32) 570 | r = self.pitch * np.asarray([abs(x), abs(y)]).max(axis=0) 571 | r = np.expand_dims(np.expand_dims(r, axis=0), axis=0) 572 | 573 | max_val = self.aperture_diameter / 2 574 | amp = (r <= max_val).astype(np.float32) 575 | amp[amp == 0] = 1e-20 # to enable stable learning 576 | self.amplitude_change = torch.tensor(amp, device=self.device) 577 | 578 | def set_circle(self, cx=0, cy=0, dia=None): 579 | """ 580 | Set the amplitude modulation of the aperture as circle 581 | Args: 582 | cx, cy: relative center position of the circle with respect to the center of the light wavefront 583 | dia: circle diameter 584 | """ 585 | [x, y] = np.mgrid[-self.R // 2:self.C // 2, -self.R // 2:self.C // 2].astype(np.float32) 586 | r2 = (x-cx) ** 2 + (y-cy) ** 2 587 | r2[r2 < 0] = 1e-20 588 | r = self.pitch * np.sqrt(r2) 589 | r = np.expand_dims(np.expand_dims(r, axis=0), axis=0) 590 | 591 | if dia is not None: 592 | self.aperture_diameter = dia 593 | self.aperture_shape = 'circle' 594 | max_val = self.aperture_diameter / 2 595 | amp = (r <= max_val).astype(np.float32) 596 | amp[amp == 0] = 1e-20 597 | self.amplitude_change = torch.tensor(amp, device=self.device) 598 | 599 | def quantize(x, levels, vmin=None, vmax=None, include_vmax=True): 600 | """ 601 | Quantize the floating array 602 | Args: 603 | 
levels: number of quantization levels 604 | vmin: minimum value for quantization 605 | vmax: maximum value for quantization 606 | include_vmax: include vmax for the quantized levels 607 | False: quantize x with the space of 1/levels-1. 608 | True: quantize x with the space of 1/levels 609 | """ 610 | 611 | if include_vmax is False: 612 | if levels == 0: 613 | return x 614 | 615 | if vmin is None: 616 | vmin = x.min() 617 | if vmax is None: 618 | vmax = x.max() 619 | 620 | #assert(vmin <= vmax) 621 | 622 | normalized = (x - vmin) / (vmax - vmin + 1e-16) 623 | if type(x) is np.ndarray: 624 | levelized = np.floor(normalized * levels) / (levels - 1) 625 | elif type(x) is torch.tensor: 626 | levelized = (normalized * levels).floor() / (levels - 1) 627 | result = levelized * (vmax - vmin) + vmin 628 | result[result < vmin] = vmin 629 | result[result > vmax] = vmax 630 | 631 | elif include_vmax is True: 632 | space = (x.max()-x.min())/levels 633 | vmin = x.min() 634 | vmax = vmin + space*(levels-1) 635 | if type(x) is np.ndarray: 636 | result = (np.floor((x-vmin)/space))*space + vmin 637 | elif type(x) is torch.tensor: 638 | result = (((x-vmin)/space).floor())*space + vmin 639 | result[resultvmax] = vmax 641 | 642 | return result -------------------------------------------------------------------------------- /pado/propagator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from .fourier import fft 4 | from .complex import Complex 5 | from .conv import conv_fft 6 | 7 | 8 | def compute_pad_width(field, linear): 9 | """ 10 | Compute the pad width of an array for FFT-based convolution 11 | Args: 12 | field: (B,Ch,R,C) complex tensor 13 | linear: True or False, flag for linear convolution (zero padding) or circular convolution (no padding) 14 | Returns: 15 | pad_width: pad-width tensor 16 | """ 17 | 18 | if linear: 19 | R,C = field.shape()[-2:] 20 | pad_width = (C//2, C//2, R//2, R//2) 21 | else: 
22 | pad_width = (0,0,0,0) 23 | return pad_width 24 | 25 | def unpad(field_padded, pad_width): 26 | """ 27 | Unpad the already-padded complex tensor 28 | Args: 29 | field_padded: (B,Ch,R,C) padded complex tensor 30 | pad_width: pad-width tensor 31 | Returns: 32 | field: unpadded complex tensor 33 | """ 34 | 35 | field = field_padded[...,pad_width[2]:-pad_width[3],pad_width[0]:-pad_width[1]] 36 | return field 37 | 38 | class Propagator: 39 | def __init__(self, mode): 40 | """ 41 | Free-space propagator of light waves 42 | One can simulate the propagation of light waves on free space (no medium change at all). 43 | Args: 44 | mode: type of propagator. currently, we support "Fraunhofer" propagation or "Fresnel" propagation. Use Fraunhofer for far-field propagation and Fresnel for near-field propagation. 45 | """ 46 | self.mode = mode 47 | 48 | def forward(self, light, z, linear=True): 49 | """ 50 | Forward the incident light with the propagator. 51 | Args: 52 | light: incident light 53 | z: propagation distance in meter 54 | linear: True or False, flag for linear convolution (zero padding) or circular convolution (no padding) 55 | Returns: 56 | light: light after propagation 57 | """ 58 | 59 | if self.mode == 'Fraunhofer': 60 | return self.forward_Fraunhofer(light, z, linear) 61 | if self.mode == 'Fresnel': 62 | return self.forward_Fresnel(light, z, linear) 63 | else: 64 | return NotImplementedError('%s propagator is not implemented'%self.mode) 65 | 66 | 67 | def forward_Fraunhofer(self, light, z, linear=True): 68 | """ 69 | Forward the incident light with the Fraunhofer propagator. 70 | Args: 71 | light: incident light 72 | z: propagation distance in meter. 73 | The propagated wavefront is independent w.r.t. the travel distance z. 74 | The distance z only affects the size of the "pixel", effectively adjusting the entire image size. 
75 | linear: True or False, flag for linear convolution (zero padding) or circular convolution (no padding) 76 | Returns: 77 | light: light after propagation 78 | """ 79 | 80 | # pad_width = compute_pad_width(light.field, linear) 81 | field_propagated = fft(light.field)#, pad_width=pad_width) 82 | # field_propagated = unpad(field_propagated, pad_width) 83 | 84 | # based on the Fraunhofer reparametrization (u=x/wvl*z) and the Fourier frequency sampling (1/bandwidth) 85 | bw_r = light.get_bandwidth()[0] 86 | bw_c = light.get_bandwidth()[1] 87 | pitch_r_after_propagation = light.wvl*z/bw_r 88 | pitch_c_after_propagation = light.wvl*z/bw_c 89 | 90 | light_propagated = light.clone() 91 | 92 | # match the x-y pixel pitch using resampling 93 | if pitch_r_after_propagation >= pitch_c_after_propagation: 94 | scale_c = 1 95 | scale_r = pitch_r_after_propagation/pitch_c_after_propagation 96 | pitch_after_propagation = pitch_c_after_propagation 97 | elif pitch_r_after_propagation < pitch_c_after_propagation: 98 | scale_r = 1 99 | scale_c = pitch_c_after_propagation/pitch_r_after_propagation 100 | pitch_after_propagation = pitch_r_after_propagation 101 | 102 | light_propagated.set_field(field_propagated) 103 | light_propagated.magnify((scale_r,scale_c)) 104 | light_propagated.set_pitch(pitch_after_propagation) 105 | 106 | return light_propagated 107 | 108 | def forward_Fresnel(self, light, z, linear): 109 | """ 110 | Forward the incident light with the Fresnel propagator. 111 | Args: 112 | light: incident light 113 | z: propagation distance in meter. 
114 | linear: True or False, flag for linear convolution (zero padding) or circular convolution (no padding) 115 | Returns: 116 | light: light after propagation 117 | """ 118 | field_input = light.field 119 | 120 | # compute the convolutional kernel 121 | sx = light.C / 2 122 | sy = light.R / 2 123 | x = np.arange(-sx, sx, 1) 124 | y = np.arange(-sy, sy, 1) 125 | xx, yy = np.meshgrid(x,y) 126 | xx = torch.from_numpy(xx*light.pitch).to(light.device) 127 | yy = torch.from_numpy(yy*light.pitch).to(light.device) 128 | k = 2*np.pi/light.wvl # wavenumber 129 | phase = (k*(xx**2 + yy**2)/(2*z)) 130 | amplitude = torch.ones_like(phase) / z / light.wvl 131 | conv_kernel = Complex(mag=amplitude, ang=phase) 132 | 133 | # Propagation with the convolution kernel 134 | pad_width = compute_pad_width(field_input, linear) 135 | 136 | field_propagated = conv_fft(field_input, conv_kernel, pad_width) 137 | 138 | # return the propagated light 139 | light_propagated = light.clone() 140 | light_propagated.set_field(field_propagated) 141 | 142 | return light_propagated 143 | 144 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torchvision.transforms as transforms 4 | from torch.utils.tensorboard import SummaryWriter 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.nn.functional as F 8 | import torch.backends.cudnn as cudnn 9 | from torch.autograd import Variable 10 | 11 | from importlib.machinery import SourceFileLoader 12 | import numpy as np 13 | import os 14 | import argparse 15 | from tqdm import trange 16 | import lpips 17 | 18 | from utils.save import * 19 | from utils.utils import * 20 | 21 | from models.forwards import * 22 | from models.ROLE import compute_raindrop 23 | from models.Defence import compute_fence 24 | from models.Dirt import compute_dirt 25 | 26 | def 
compute_dirt_raindrop(image_far, depth, args): 27 | if np.random.uniform() < 0.5: 28 | return compute_dirt(image_far, depth, args) 29 | else: 30 | return compute_raindrop(image_far, depth, args) 31 | 32 | def log(DOE_phase, G, batch_data, step,args): 33 | image_far, _ = batch_data 34 | image_far = image_far.to(args.device) 35 | image_near, mask, image_DOE, image_near_DOE, image_far_DOE, psf_near, psf_far, mask_doe, height_map = image_formation(image_far,DOE_phase, args.compute_obstruction, args) 36 | image = image_near * mask + image_far * (1 - mask) 37 | 38 | image_recon = G(image_DOE, psf_near, psf_far) 39 | image_recon = torch.clamp(image_recon, min=0, max=1) 40 | G_l1_loss = args.l1_loss_weight * args.l1_criterion(image_recon, image_far) 41 | G_perc_loss = torch.mean(args.perceptual_loss_weight * args.perceptual_criterion(2 * image_recon - 1, 2 * image_far - 1)) 42 | G_masked_loss = args.masked_loss_weight * args.l1_criterion(image_recon*mask, image_far*mask) 43 | loss = G_l1_loss + G_perc_loss + G_masked_loss 44 | 45 | psfs, log_psfs = plot_depth_based_psf(DOE_phase, args, args.param.plot_depth) 46 | DOE_phase_wrapped = DOE_phase % (2*np.pi) 47 | DOE_phase_wrapped = DOE_phase_wrapped - torch.min(DOE_phase_wrapped) 48 | DOE_phase_wrapped = DOE_phase_wrapped / torch.max(DOE_phase_wrapped) 49 | DOE_phase = DOE_phase - torch.min(DOE_phase) 50 | DOE_phase = DOE_phase / torch.max(DOE_phase) 51 | 52 | args.writer.add_scalar('val_loss/L1_loss',G_l1_loss, step) 53 | args.writer.add_scalar('val_loss/perc_loss',G_perc_loss, step) 54 | args.writer.add_scalar('val_loss/masked_loss',G_masked_loss, step) 55 | args.writer.add_scalar('val_loss/loss',loss, step) 56 | args.writer.add_image('result', torch.cat([torch.cat([image[0],image_far[0]],1), torch.cat([image_DOE[0],image_recon[0]],1)], -1), step) 57 | args.writer.add_image('image_sensor_component', torch.cat([image_far_DOE[0], image_near_DOE[0], mask_doe[0]],-1), step) 58 | args.writer.add_image('RGB_psf', psfs[0], step) 
59 | args.writer.add_image('RGB_logpsf', log_psfs[0], step) 60 | args.writer.add_image('height_map',(height_map/torch.max(height_map))[0], step) 61 | args.writer.add_image('phase_map', DOE_phase[0], step) 62 | args.writer.add_image('phase_map_wrapped', DOE_phase_wrapped[0], step) 63 | 64 | 65 | def train_step(batch_data, DOE_phase, optics_optimizer, G, G_optimizer, step, args): 66 | image_far, _ = batch_data 67 | image_far = image_far.to(args.device) 68 | image_near, mask, image_DOE, image_near_DOE, image_far_DOE, psf_near, psf_far, mask_doe, height_map = image_formation(image_far,DOE_phase, args.compute_obstruction, args) 69 | image_recon = G(image_DOE, psf_near, psf_far) 70 | G_l1_loss = args.l1_loss_weight * args.l1_criterion(image_recon, image_far) 71 | G_perc_loss = torch.mean(args.perceptual_loss_weight * args.perceptual_criterion(2 * image_recon - 1, 2 * image_far - 1)) 72 | G_masked_loss = args.masked_loss_weight * args.l1_criterion(image_recon*mask, image_far*mask) 73 | loss = G_l1_loss + G_perc_loss + G_masked_loss 74 | loss.backward() 75 | 76 | G_optimizer.step() 77 | G_optimizer.zero_grad() 78 | if args.train_optics: 79 | optics_optimizer.step() 80 | optics_optimizer.zero_grad() 81 | return loss.detach() 82 | 83 | def train(args): 84 | 85 | if args.debug: 86 | np.random.seed(args.seed) 87 | torch.manual_seed(args.seed) 88 | if args.device == 'cuda': 89 | torch.cuda.manual_seed(args.seed) 90 | cudnn.benchmark = True 91 | cudnn.enabled=True 92 | param = args.param 93 | 94 | transform_train = transforms.Compose([ 95 | transforms.RandomCrop(param.data_resolution,pad_if_needed=True), # Places365 image size varies 96 | transforms.RandomCrop([param.equiv_crop_size, param.equiv_crop_size],pad_if_needed=True), 97 | transforms.Resize([param.img_res, param.img_res]), 98 | transforms.RandomHorizontalFlip(), 99 | transforms.ToTensor(), 100 | ]) 101 | 102 | transform_test = transforms.Compose([ 103 | transforms.RandomCrop(param.data_resolution,pad_if_needed=True), # 
Places365 image size varies 104 | transforms.CenterCrop(param.equiv_crop_size), 105 | transforms.Resize([param.img_res, param.img_res]), 106 | transforms.ToTensor(), 107 | ]) 108 | if args.obstruction == 'fence': 109 | trainset = torchvision.datasets.Places365( 110 | root=param.dataset_dir, split="train-standard", transform=transform_train) 111 | testset = torchvision.datasets.Places365( 112 | root=param.dataset_dir, split="val", transform=transform_test) 113 | args.compute_obstruction = compute_fence 114 | elif args.obstruction == 'raindrop': 115 | trainset = torchvision.datasets.ImageFolder(param.training_dir, transform=transform_train) 116 | testset = torchvision.datasets.ImageFolder(param.val_dir, transform=transform_test) 117 | args.compute_obstruction = compute_raindrop 118 | elif args.obstruction == 'dirt': 119 | trainset = torchvision.datasets.ImageFolder(param.training_dir, transform=transform_train) 120 | testset = torchvision.datasets.ImageFolder(param.val_dir, transform=transform_test) 121 | args.compute_obstruction = compute_dirt 122 | elif args.obstruction == 'dirt_raindrop': 123 | trainset = torchvision.datasets.ImageFolder(param.training_dir, transform=transform_train) 124 | testset = torchvision.datasets.ImageFolder(param.val_dir, transform=transform_test) 125 | args.compute_obstruction = compute_dirt_raindrop 126 | 127 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=True) 128 | testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False) 129 | 130 | if args.debug: 131 | trainloader = testloader 132 | 133 | # build model and loss 134 | args.perceptual_criterion = lpips.LPIPS(net='vgg').to(args.device) 135 | args.l1_criterion = nn.L1Loss().to(args.device) 136 | args.writer = SummaryWriter(args.result_path) 137 | 138 | param = args.param 139 | if args.train_optics: 140 | DOE_phase = Variable(param.DOE_phase_init.to(args.device), requires_grad=True) 141 | optics_optimizer = optim.Adam([DOE_phase], 
lr=args.optics_lr) 142 | else: 143 | DOE_phase = Variable(param.DOE_phase_init.to(args.device), requires_grad=False) 144 | optics_optimizer = None 145 | 146 | from models.recon import Arch 147 | G = Arch(args).to(args.device) 148 | if args.pretrained_G is not None: 149 | G.load_state_dict(torch.load(args.G_init_ckpt, map_location='cpu')) 150 | G.to(args.device) 151 | 152 | G_optimizer = optim.Adam(params=G.parameters(), lr=args.G_lr) 153 | 154 | for _, batch_data in enumerate(testloader): 155 | test_data = batch_data 156 | break 157 | 158 | total_step = 0 159 | train_loss = 0 160 | log(DOE_phase, G, test_data, total_step, args) 161 | for epoch_cnt in trange(args.n_epochs, desc="Epoch"): 162 | for _, batch_data in enumerate(trainloader): 163 | step_loss = train_step(batch_data, DOE_phase, optics_optimizer, G, G_optimizer, total_step, args) 164 | total_step += 1 165 | train_loss += step_loss 166 | if total_step % args.log_freq == 0: 167 | log(DOE_phase, G, test_data, total_step, args) 168 | args.writer.add_scalar('train_loss/loss',train_loss/args.log_freq, total_step) 169 | train_loss = 0 170 | if total_step % args.save_freq == 0: 171 | torch.save(G.state_dict(), os.path.join(args.result_path,'G_%03d.pt' % (total_step//args.save_freq))) 172 | if args.train_optics: 173 | torch.save(DOE_phase, os.path.join(args.result_path,'DOE_phase_%03d.pt' % (total_step//args.save_freq))) 174 | 175 | 176 | def main(): 177 | parser = argparse.ArgumentParser( 178 | description='Obstruction-Free DOE', 179 | formatter_class=argparse.RawDescriptionHelpFormatter 180 | ) 181 | def str2bool(v): 182 | assert(v == 'True' or v == 'False') 183 | return v.lower() in ('true') 184 | 185 | def none_or_str(value): 186 | if value.lower() == 'none': 187 | return None 188 | return value 189 | 190 | parser.add_argument('--debug', action="store_true", help='debug mode, train on validation data to speed up the process') 191 | parser.add_argument('--train_optics', action="store_true", help='optimize 
optical element design') 192 | parser.add_argument('--pretrained_DOE', default = None, type =none_or_str, help = 'use a pretrained DOE') 193 | parser.add_argument('--pretrained_G', default = None, type =none_or_str, help = 'use a pretrained G') 194 | parser.add_argument('--result_path', default = './ckpt/opt', type=str, help='dir to save models and checkpoints') 195 | parser.add_argument('--param_file', default= 'config/param_MV_1600.py', type=str, help='path to param file') 196 | 197 | parser.add_argument('--obstruction', default = 'dirt_raindrop', type = str, help = 'obsturction type') 198 | parser.add_argument('--sensor_noise', default = 0.008, type=float, help='sensor noise level') 199 | parser.add_argument('--n_epochs', default = 40, type = int, help = 'max num of training epoch') 200 | parser.add_argument('--optics_lr', default=0.1, type=float, help='optical element learning rate') 201 | parser.add_argument('--G_lr', default=1e-4, type=float, help='network learning rate') 202 | 203 | parser.add_argument('--l1_loss_weight', default = 1, type = float, help = 'weight for L1 loss') 204 | parser.add_argument('--masked_loss_weight', default = 1, type = float, help = 'weight for masked loss (focus on obstructed scene)') 205 | parser.add_argument('--perceptual_loss_weight', default = 1, type = float, help = 'weight for perceptual loss') 206 | 207 | parser.add_argument('--log_freq', default=400, type=int, help = 'frequency (num_steps) of logging') 208 | parser.add_argument('--save_freq', default=2000, type=int, help = 'frequency (num_steps) of saving checkpoint and visual performance') 209 | parser.add_argument('--seed', type=int, default=1234, help='random seed') 210 | 211 | args = parser.parse_args() 212 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 213 | 214 | param = SourceFileLoader("param", args.param_file).load_module() 215 | param = convert_resolution(param,args) 216 | 217 | if args.pretrained_DOE is not None: 218 | if 
args.pretrained_DOE.endswith('.pt'): 219 | args.DOE_phase_init_ckpt = args.pretrained_DOE 220 | else: 221 | args.DOE_phase_init_ckpt = last_save(args.pretrained_DOE, 'DOE_phase_*') 222 | param.DOE_phase_init = torch.load(args.DOE_phase_init_ckpt, map_location='cpu').detach() 223 | 224 | if args.pretrained_G is not None: 225 | args.G_init_ckpt = last_save(args.pretrained_G, 'G_*') 226 | 227 | save_settings(args, param) 228 | train(args) 229 | 230 | if __name__ == '__main__': 231 | 232 | main() 233 | 234 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | PRETRAINED_G=None 2 | PRETRAINED_DOE=None 3 | RESULT_DIR=ckpt/E2E_MV2400 4 | PARAM=config/param_MV_2400.py 5 | OBSTRUCTION=dirt_raindrop 6 | 7 | conda activate SeeThroughObstruction 8 | 9 | python train.py --train_optics --result_path $RESULT_DIR --param_file $PARAM --obstruction $OBSTRUCTION --pretrained_DOE $PRETRAINED_DOE --pretrained_G $PRETRAINED_G -------------------------------------------------------------------------------- /utils/._save.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-computational-imaging/SeeThroughObstructions/f94a297cc87d72dcb17ffdab3bed64cff186ab9f/utils/._save.py -------------------------------------------------------------------------------- /utils/save.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from matplotlib import pyplot 4 | import cv2 5 | import shutil 6 | import json 7 | from glob import glob 8 | 9 | from models.forwards import * 10 | 11 | def save_settings(args, param): 12 | if not os.path.exists(args.result_path): 13 | os.makedirs(args.result_path) 14 | args_dict = vars(args) 15 | with open(os.path.join(args.result_path,'args.json'), "w") as f: 16 | json.dump(args_dict, f, indent=4, 
sort_keys=False) 17 | shutil.copy(args.param_file, args.result_path) 18 | if args.pretrained_DOE is not None: 19 | shutil.copy(args.DOE_phase_init_ckpt, os.path.join(args.result_path, 'init')) 20 | if args.pretrained_G is not None: 21 | shutil.copy(args.G_init_ckpt, os.path.join(args.result_path, 'init')) 22 | args.param = param 23 | 24 | def plot_depth_based_psf(DOE_phase, args, depths = [0.05, 0.1, 0.2, 0.4, 0.8, 1, 3, 5], wvls = 'RGB', normalize = True, merge_channel = False): 25 | param = args.param 26 | doe = DOE(param.R, param.C, param.DOE_pitch, param.material, param.DOE_wvl,args.device, phase = DOE_phase) 27 | 28 | psfs = [] 29 | if wvls == 'RGB': 30 | for i in range(len(param.wvls)): 31 | wvl = param.wvls[i] 32 | psf_depth = [] 33 | for z in depths: 34 | psf = compute_psf_Fraunhofer(wvl, torch.tensor(z), doe, args) 35 | if normalize: 36 | psf_depth.append(psf/torch.max(psf)) 37 | else: 38 | psf_depth.append(psf) 39 | psfs.append(torch.cat(psf_depth, -1)) 40 | if merge_channel: 41 | psfs = torch.cat(psfs, 1) 42 | else: 43 | psfs = torch.cat(psfs, -2) 44 | elif wvls == 'design': 45 | wvl = param.DOE_wvl 46 | psf_depth = [] 47 | for z in depths: 48 | psf = compute_psf_Fraunhofer(wvl, torch.tensor(z), doe, args) 49 | if normalize: 50 | psf_depth.append(psf/torch.max(psf)) 51 | else: 52 | psf_depth.append(psf) 53 | psfs.append(torch.cat(psf_depth, -1)) 54 | psfs = torch.cat(psfs, -2) 55 | else: 56 | assert False, "%s not defined" %wvls 57 | 58 | log_psfs = torch.log(psfs + 1e-9) 59 | log_psfs -= torch.min(log_psfs) 60 | log_psfs /= torch.max(log_psfs) 61 | return psfs, log_psfs 62 | 63 | def last_save(ckpt_path, file_format): 64 | return sorted(glob(os.path.join(ckpt_path, file_format)))[-1] 65 | 66 | def trapez(y,y0,w): 67 | return np.clip(np.minimum(y+1+w/2-y0, -y+1+w/2+y0),0,1) 68 | 69 | def weighted_line(r0, c0, r1, c1, w, rmin=0, rmax=np.inf): 70 | # The algorithm below works fine if c1 >= c0 and c1-c0 >= abs(r1-r0). 
71 | # If either of these cases are violated, do some switches. 72 | if abs(c1-c0) < abs(r1-r0): 73 | # Switch x and y, and switch again when returning. 74 | xx, yy, val = weighted_line(c0, r0, c1, r1, w, rmin=rmin, rmax=rmax) 75 | return (yy, xx, val) 76 | 77 | # At this point we know that the distance in columns (x) is greater 78 | # than that in rows (y). Possibly one more switch if c0 > c1. 79 | if c0 > c1: 80 | return weighted_line(r1, c1, r0, c0, w, rmin=rmin, rmax=rmax) 81 | 82 | # The following is now always < 1 in abs 83 | slope = (r1-r0) / (c1-c0) 84 | 85 | # Adjust weight by the slope 86 | w *= np.sqrt(1+np.abs(slope)) / 2 87 | 88 | # We write y as a function of x, because the slope is always <= 1 89 | # (in absolute value) 90 | x = np.arange(c0, c1+1, dtype=float) 91 | y = x * slope + (c1*r0-c0*r1) / (c1-c0) 92 | 93 | # Now instead of 2 values for y, we have 2*np.ceil(w/2). 94 | # All values are 1 except the upmost and bottommost. 95 | thickness = np.ceil(w/2) 96 | yy = (np.floor(y).reshape(-1,1) + np.arange(-thickness-1,thickness+2).reshape(1,-1)) 97 | xx = np.repeat(x, yy.shape[1]) 98 | vals = trapez(yy, y.reshape(-1,1), w).flatten() 99 | 100 | yy = yy.flatten() 101 | 102 | # Exclude useless parts and those outside of the interval 103 | # to avoid parts outside of the picture 104 | mask = np.logical_and.reduce((yy >= rmin, yy < rmax, vals > 0)) 105 | 106 | return (yy[mask].astype(int), xx[mask].astype(int), vals[mask]) 107 | 108 | def plot_line(x1,x2,y1,y2,psf): 109 | rr, cc, val = weighted_line(x1,x2, y1, y2,2) 110 | psf[rr,cc] = val[:,None] 111 | return psf 112 | 113 | def viz_psf(psf, param, center_R = None, g = 2.2, weight = 4): 114 | psf = (psf / np.max(psf))**(1/g) 115 | if center_R is not None: 116 | size = 2*center_R 117 | w = h = int(param.img_res/2-center_R) 118 | psf_center = psf[h:h+size,w:w+size] 119 | psf_center = cv2.resize(psf_center, (center_R * 15,center_R * 15),interpolation = cv2.INTER_NEAREST) 120 | psf[:center_R * 15, :center_R * 
15] = psf_center 121 | 122 | if len(psf.shape) == 2: 123 | psf = pyplot.cm.hot(psf)[...,:-1] 124 | 125 | # psf[:int(param.img_res/3),0:weight,:] = 1 126 | psf[:center_R * 15+ weight,center_R * 15: center_R * 15 + weight,:] = 1 127 | # psf[0:weight,:int(param.img_res/3),:] = 1 128 | psf[center_R * 15:center_R * 15 + weight,:center_R * 15+ weight,:] = 1 129 | psf[h:h+size+ weight,w-weight:w,:] = 1 130 | psf[h:h+size+ weight,w + size : w+size+weight,:] = 1 131 | psf[h-weight:h,w-weight:w+size+ weight,:] = 1 132 | psf[h+size:h+size+weight,w:w+size+ weight,:] = 1 133 | rr, cc, val = weighted_line(center_R * 15,0, h+size + weight, w,weight) 134 | psf[rr,cc] = val[:,None] 135 | rr, cc, val = weighted_line(0, center_R * 15,h, w+size+ weight,weight) 136 | psf[rr,cc] = val[:,None] 137 | else: 138 | if len(psf.shape) == 2: 139 | psf = pyplot.cm.hot(psf)[...,:-1] 140 | return psf 141 | 142 | def plot_psf_array(psfs, param, center_R = 10, gap = 10, g = 1.5): 143 | # if psfs.shape[1] == 1: 144 | # psfs = torch.tile(psfs,(1,3,1,1)) 145 | psfs_singles = torch.split(psfs, param.img_res, dim=-1) 146 | cnt = len(psfs_singles) 147 | canvas = np.ones([param.img_res, param.img_res * cnt + gap * (cnt-1), 3]) 148 | for i in range(cnt): 149 | if psfs.shape[1] == 1: 150 | psf = psfs_singles[i][0,0].cpu().numpy() 151 | else: 152 | psf = psfs_singles[i][0].permute(1,2,0).cpu().numpy() 153 | if i < 3: 154 | canvas[:, i * (param.img_res+gap):i * (param.img_res+gap) + param.img_res, :] = viz_psf(psf, param, center_R = center_R, g = g) 155 | 156 | else: 157 | canvas[:, i * (param.img_res+gap):i * (param.img_res+gap) + param.img_res, :] = viz_psf(psf, param, center_R = None, g = g) 158 | pyplot.figure(figsize = (30,10)) 159 | pyplot.imshow(canvas) 160 | return canvas 161 | 162 | 163 | 164 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | from pado.light import * 2 
| from pado.optical_element import * 3 | from pado.propagator import * 4 | import torch 5 | import numpy as np 6 | 7 | def convert_resolution(param, args): 8 | # dataset 9 | if args.obstruction == 'fence': 10 | param.dataset_dir = '/projects/FHEIDE/obstruction_free_doe/Places365' 11 | param.data_resolution = [512,768] 12 | elif args.obstruction == 'raindrop' or 'dirt' or 'dirt_raindrop': 13 | param.training_dir = '/projects/FHEIDE/Bad2ClearWeather/Cityscapes/leftImg8bit_trainvaltest/leftImg8bit/train' 14 | param.val_dir = '/projects/FHEIDE/Bad2ClearWeather/Cityscapes/leftImg8bit_trainvaltest/leftImg8bit/val' 15 | param.data_resolution = [1024, 2048] 16 | else: 17 | assert False, "undefined obstruction" 18 | 19 | # convert resolution and pitch size 20 | param.equiv_image_size = param.img_res * param.image_sample_ratio # image resolution before downsampling in camera pixel pitch 21 | param.equiv_crop_size = int(param.equiv_image_size * param.camera_pitch / param.background_pitch) # convert to background pixel pitch 22 | return param 23 | 24 | def randuni(low, high, size): 25 | '''uniformly sample from [low, high)''' 26 | return (torch.rand(size)*(high - low) + low) 27 | 28 | def real2complex(real): 29 | return Complex(mag=real, ang=torch.zeros_like(real)) 30 | 31 | def compute_pad_size(current_size, target_size): 32 | assert current_size < target_size 33 | gap = target_size - current_size 34 | left = int(gap/2) 35 | right = gap - left 36 | return int(left), int(right) 37 | 38 | def sample_psf(psf, sample_ratio): 39 | if sample_ratio == 1: 40 | return psf 41 | else: 42 | return torch.nn.AvgPool2d(sample_ratio, stride=sample_ratio)(psf) 43 | 44 | def metric2pixel(metric, depth, args): 45 | return int(metric * args.param.focal_length / (depth * args.param.equiv_camera_pitch)) 46 | 47 | def edge_mask(R,cutoff, device): 48 | [x, y] = np.mgrid[-int(R):int(R),-int(R):int(R)] 49 | dist = np.sqrt(x**2 +y**2).astype(np.int32) 50 | mask = torch.tensor(1.0*(dist < cutoff))[None, 
None, ...] 51 | return mask.to(device) 52 | 53 | class AttributeDict(dict): 54 | def __getattr__(self, attr): 55 | return self[attr] 56 | def __setattr__(self, attr, value): 57 | self[attr] = value --------------------------------------------------------------------------------