├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── ckpts
│   ├── DSLR_2800_phase.pt
│   ├── DSLR_3400_phase.pt
│   ├── MV_1600_phase.pt
│   ├── MV_2400_phase.pt
│   └── MV_recon
│       ├── G_040.pt
│       ├── args.json
│       └── param_MV_2400.py
├── config
│   ├── param_DSLR_2800.py
│   └── param_MV_1600.py
├── env.yml
├── inference.ipynb
├── models
│   ├── Defence.py
│   ├── Dirt.py
│   ├── PerlinBlob.py
│   ├── ROLE.py
│   ├── forwards.py
│   └── recon.py
├── pado-main
│   ├── LICENSE
│   ├── README.md
│   ├── docs
│   │   └── images
│   │       ├── logo.pdf
│   │       ├── logo.png
│   │       └── logo.svg
│   ├── example
│   │   └── tutorial.ipynb
│   └── pado
│       ├── .gitignore
│       ├── cmap_phase.txt
│       ├── complex.py
│       ├── conv.py
│       ├── fourier.py
│       ├── light.py
│       ├── material.py
│       ├── optical_element.py
│       └── propagator.py
├── pado
│   ├── .gitignore
│   ├── cmap_phase.txt
│   ├── complex.py
│   ├── conv.py
│   ├── fourier.py
│   ├── light.py
│   ├── material.py
│   ├── optical_element.py
│   └── propagator.py
├── train.py
├── train.sh
└── utils
    ├── ._save.py
    ├── save.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/SeeThroughObstructions/f94a297cc87d72dcb17ffdab3bed64cff186ab9f/.gitmodules
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Boost Software License - Version 1.0 - August 17th, 2003
2 |
3 | Permission is hereby granted, free of charge, to any person or organization
4 | obtaining a copy of the software and accompanying documentation covered by
5 | this license (the "Software") to use, reproduce, display, distribute,
6 | execute, and transmit the Software, and to prepare derivative works of the
7 | Software, and to permit third-parties to whom the Software is furnished to
8 | do so, all subject to the following:
9 |
10 | The copyright notices in the Software and this entire statement, including
11 | the above license grant, this restriction and the following disclaimer,
12 | must be included in all copies of the Software, in whole or in part, and
13 | all derivative works of the Software, unless such copies or derivative
14 | works are solely in the form of machine-executable object code generated by
15 | a source language processor.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
20 | SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
21 | FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
22 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
23 | DEALINGS IN THE SOFTWARE.
24 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Seeing Through Obstructions with Diffractive Cloaking
2 | ### [Project Page](https://light.princeton.edu/publication/seeing-through-obstructions/) | [Paper](https://dl.acm.org/doi/abs/10.1145/3528223.3530185)
3 |
4 | [Zheng Shi](https://zheng-shi.github.io/), [Yuval Bahat](https://sites.google.com/view/yuval-bahat/home), [Seung-Hwan Baek](https://www.shbaek.com/), [Qiang Fu](https://cemse.kaust.edu.sa/vcc/people/person/qiang-fu), [Hadi Amata](https://cemse.kaust.edu.sa/people/person/hadi-amata), Xiao Li, [Praneeth Chakravarthula](https://www.cs.unc.edu/~cpk/), [Wolfgang Heidrich](https://vccimaging.org/People/heidriw/), [Felix Heide](https://www.cs.princeton.edu/~fheide/)
5 |
6 | If you find our work useful in your research, please cite:
7 | ```
8 | @article{Shi2022SeeThroughObstructions,
9 | author = {Shi, Zheng and Bahat, Yuval and Baek, Seung-Hwan and Fu, Qiang and Amata, Hadi and Li, Xiao and Chakravarthula, Praneeth and Heidrich, Wolfgang and Heide, Felix},
10 | title = {Seeing through Obstructions with Diffractive Cloaking},
11 | year = {2022},
12 | issue_date = {July 2022},
13 | publisher = {Association for Computing Machinery},
14 | address = {New York, NY, USA},
15 | volume = {41},
16 | number = {4},
17 | issn = {0730-0301},
18 | url = {https://doi.org/10.1145/3528223.3530185},
19 | doi = {10.1145/3528223.3530185}}
20 | ```
21 |
22 | ## Requirements
23 | This code is developed using PyTorch on a Linux machine. The full frozen environment can be found in 'env.yml'; note that some of these libraries are not necessary to run this code. In addition to the packages installed in the environment, our image formation model uses the [pado](https://github.com/shwbaek/pado) package to simulate wave optics.
24 |
25 | ## Data
26 | In the paper we use [Places365](http://places2.csail.mit.edu/index.html) and [Cityscapes](https://www.cityscapes-dataset.com/) as the obstruction-free background scenes; they can easily be switched to any other dataset of your choice. See 'train.py' for more details on the data augmentation we applied, and refer to 'models/' for more details on the depth-aware obstruction simulation. A minimal background-loading sketch is shown below.
27 |
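This sketch is illustrative only: the actual augmentation lives in 'train.py', and the torchvision `Places365` wrapper and transform choices here are assumptions.

```python
import torchvision
import torchvision.transforms as transforms

img_res = 512  # matches img_res in the config/param_*.py files

# hypothetical stand-in for the augmentation in train.py
transform_bg = transforms.Compose([
    transforms.RandomResizedCrop(img_res),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

background_set = torchvision.datasets.Places365(
    root='data/places365', split='train-standard', small=True,
    download=True, transform=transform_bg)
image_far = background_set[0][0][None]  # 1 x 3 x 512 x 512 background batch
```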
28 | ## Pre-trained Models and Optimized DOE Designs
29 | Optimized DOE designs and pre-trained models are available under the 'ckpts/' folder. Please refer to the supplemental document for fabrication details.
30 |
31 | ## Sensor Capture Simulation and Reconstruction
32 | We include a sample script that demonstrates our entire image formation and reconstruction process. Run the 'inference.ipynb' notebook in Jupyter Notebook; it loads the checkpoints and runs the entire pipeline. The simulated depth-dependent PSFs, the simulated sensor capture, and the reconstructed image are all displayed within the notebook. A condensed sketch of this flow is shown below.
33 |
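Paths follow 'ckpts/MV_recon/args.json'; the random background is a stand-in for a real capture, the `SimpleNamespace` stands in for the actual argument parser, and we assume 'G_040.pt' stores a `state_dict`.

```python
import importlib.util
from types import SimpleNamespace
import torch

from models.forwards import image_formation
from models.ROLE import compute_raindrop
from models.recon import Arch

# load the optics/sensor spec module referenced by args.json
spec = importlib.util.spec_from_file_location('param', 'ckpts/MV_recon/param_MV_2400.py')
param = importlib.util.module_from_spec(spec)
spec.loader.exec_module(param)
args = SimpleNamespace(param=param, device='cuda', sensor_noise=0.008)

# optimized DOE phase and pre-trained reconstruction network
DOE_phase = torch.load('ckpts/MV_2400_phase.pt', map_location=args.device)
G = Arch(args).to(args.device)
G.load_state_dict(torch.load('ckpts/MV_recon/G_040.pt', map_location=args.device))

# stand-in obstruction-free scene; use a real image in practice
image_far = torch.rand(1, 3, param.img_res, param.img_res, device=args.device)

(image_near, mask, img_doe, img_near_doe, img_far_doe,
 psf_near, psf_far, mask_doe, height_map) = image_formation(
    image_far, DOE_phase, compute_raindrop, args)

recon = G(img_doe, psf_near, psf_far)  # obstruction-free reconstruction
```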
34 | ## Training
35 | We include 'train.sh' for training purposes. Please refer to 'config/' for the optics and sensor specs. A hypothetical single optimization step is sketched below.
36 |
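The training loop itself lives in 'train.py', but 'ckpts/MV_recon/args.json' pins down its key hyper-parameters. Continuing the sketch above, one step could look as follows; the masked-loss term is our guess at how `masked_loss_weight` is applied.

```python
import torch
import torch.nn.functional as F
import lpips  # perceptual loss; lpips==0.1.4 is pinned in env.yml

percep = lpips.LPIPS(net='alex').to(args.device)

# jointly optimize the DOE phase and the reconstruction network
DOE_phase = param.DOE_phase_init.clone().to(args.device).requires_grad_(True)
opt_optics = torch.optim.Adam([DOE_phase], lr=0.1)  # optics_lr in args.json
opt_G = torch.optim.Adam(G.parameters(), lr=1e-4)   # G_lr in args.json

outs = image_formation(image_far, DOE_phase, compute_raindrop, args)
mask, img_doe, psf_near, psf_far = outs[1], outs[2], outs[5], outs[6]
recon = G(img_doe, psf_near, psf_far)

# l1 / masked / perceptual terms, each weighted 1 per args.json
loss = (F.l1_loss(recon, image_far)
        + F.l1_loss(recon * mask, image_far * mask)
        + percep(recon, image_far, normalize=True).mean())

opt_optics.zero_grad(); opt_G.zero_grad()
loss.backward()
opt_optics.step(); opt_G.step()
```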
37 | ## License
38 | Our code is licensed under BSL-1.0. By downloading the software, you agree to the terms of this license.
39 |
40 | ## Questions
41 | If there is anything unclear, please feel free to reach out to me at zhengshi[at]princeton[dot]edu.
42 |
--------------------------------------------------------------------------------
/ckpts/DSLR_2800_phase.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/SeeThroughObstructions/f94a297cc87d72dcb17ffdab3bed64cff186ab9f/ckpts/DSLR_2800_phase.pt
--------------------------------------------------------------------------------
/ckpts/DSLR_3400_phase.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/SeeThroughObstructions/f94a297cc87d72dcb17ffdab3bed64cff186ab9f/ckpts/DSLR_3400_phase.pt
--------------------------------------------------------------------------------
/ckpts/MV_1600_phase.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/SeeThroughObstructions/f94a297cc87d72dcb17ffdab3bed64cff186ab9f/ckpts/MV_1600_phase.pt
--------------------------------------------------------------------------------
/ckpts/MV_2400_phase.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/SeeThroughObstructions/f94a297cc87d72dcb17ffdab3bed64cff186ab9f/ckpts/MV_2400_phase.pt
--------------------------------------------------------------------------------
/ckpts/MV_recon/G_040.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/SeeThroughObstructions/f94a297cc87d72dcb17ffdab3bed64cff186ab9f/ckpts/MV_recon/G_040.pt
--------------------------------------------------------------------------------
/ckpts/MV_recon/args.json:
--------------------------------------------------------------------------------
1 | {
2 | "debug": false,
3 | "train_optics": false,
4 | "pretrained_DOE": "ckpts/MV_2400_phase.pt",
5 | "pretrained_G": null,
6 | "result_path": "/projects/FHEIDE/obstruction_free_doe/ckpt/MV_retrain",
7 | "param_file": "config/param_MV_2400.py",
8 | "obstruction": "dirt_raindrop",
9 | "sensor_noise": 0.008,
10 | "n_epochs": 20,
11 | "optics_lr": 0.1,
12 | "G_lr": 0.0001,
13 | "l1_loss_weight": 1,
14 | "masked_loss_weight": 1,
15 | "perceptual_loss_weight": 1,
16 | "log_freq": 400,
17 | "save_freq": 2000,
18 | "seed": 1234,
19 | "device": "cuda",
20 | "DOE_phase_init_ckpt": "ckpts/MV_2400_phase.pt",
21 | "G_init_ckpt": null
22 | }
--------------------------------------------------------------------------------
/ckpts/MV_recon/param_MV_2400.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pado.material import *
3 | from utils.utils import *
4 |
5 | # https://www.edmundoptics.com/p/8mm-UC-Series-Fixed-Focal-Length-Lens/41864?gclid=CjwKCAjwh5qLBhALEiwAioods3Z90aJenLK11rrj2R5E7SpKF0gvF8a9vZrsd0H5aY72nIgcYq42QRoC4hAQAvD_BwE
6 | # https://www.edmundoptics.com/p/bfs-u3-120s4c-cs-usb3-blackflyreg-s-color-camera/40172/
7 |
8 | camera_resolution = [4000,3000]
9 | camera_pitch = 1.85e-6
10 | background_pitch = camera_pitch * 2
11 | sensor_dist = focal_length = 8e-3
12 | aperture_shape='circle'
13 | wvls = [656e-9, 589e-9, 486e-9] # camera RGB wavelength
14 | DOE_wvl = 550e-9 # wavelength used to set DOE
15 |
16 | # DOE specs
17 | R = C = 2400 # Resolution of the simulated wavefront
18 | DOE_material = 'FUSED_SILICA'
19 | material = Material(DOE_material)
20 | DOE_pitch = camera_pitch
21 | aperture_diamter = DOE_pitch * R
22 | DOE_sample_ratio = 2
23 | image_sample_ratio = 2
24 | equiv_camera_pitch = camera_pitch * image_sample_ratio
25 | img_res = 512
26 | assert DOE_pitch * DOE_sample_ratio == equiv_camera_pitch
27 | DOE_max_height = 1.2e-6
28 | DOE_height_noise_scale = 4*10e-9
29 | DOE_phase_noise_scale = 0.1
30 |
31 | DOE_phase_init = torch.zeros((1,1, R, C))
32 |
33 | # depth
34 | depth_near_min = 0.05
35 | depth_near_max = 0.12
36 | depth_far_min = 5
37 | depth_far_max = 10
38 | plot_depth = [5,3,1,0.12,0.08,0.05]
39 |
40 | # raindrop
41 | drop_Nmin = 5
42 | drop_Nmax = 8
43 | drop_Rmin = 1e-3
44 | drop_Rmax = 3e-3
45 |
46 | # dirt
47 | perlin_res = 8
48 | perlin_cutoff = 0.55
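49 | 
50 | # For reference (derived, not used by the code): the simulated aperture is
51 | # aperture_diamter = DOE_pitch * R = 1.85e-6 * 2400 ≈ 4.44 mm, so with the
52 | # 8 mm focal length this system is roughly f/1.8.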
--------------------------------------------------------------------------------
/config/param_DSLR_2800.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pado.material import *
3 |
4 | # https://www.usa.canon.com/internet/portal/us/home/products/details/cameras/eos-dslr-and-mirrorless-cameras/dslr/eos-rebel-t5-ef-s-18-55-is-ii-kit/eos-rebel-t5-18-55-is-ii-kit
5 | camera_resolution = [5196,3464] # 18e6
6 | camera_pitch = 4.3e-6
7 | background_pitch = camera_pitch * 4
8 | sensor_dist = focal_length = 50e-3
9 | aperture_shape='circle'
10 | wvls = [656e-9, 589e-9, 486e-9] # camera RGB wavelength
11 | DOE_wvl = 550e-9 # wavelength used to set DOE
12 |
13 | # DOE specs
14 | R = C = 2800 # Resolution of the simulated wavefront
15 | DOE_material = 'FUSED_SILICA'
16 | material = Material(DOE_material)
17 | DOE_pitch = camera_pitch * 1.5
18 | aperture_diamter = DOE_pitch * R
19 | DOE_sample_ratio = 2
20 | image_sample_ratio = 3
21 | equiv_camera_pitch = camera_pitch * image_sample_ratio
22 | img_res = 512
23 | assert DOE_pitch * DOE_sample_ratio == equiv_camera_pitch
24 | DOE_max_height = 1.2e-6
25 | DOE_height_noise_scale = 4*10e-9
26 | DOE_phase_noise_scale = 0.05
27 |
28 | DOE_phase_init = torch.zeros((1,1, R, C))
29 |
30 | # depth
31 | depth_near_min = 0.4
32 | depth_near_max = 0.8
33 | depth_far_min = 5
34 | depth_far_max = 10
35 | plot_depth = [10,8,5,0.8,0.6,0.4]
36 |
37 | # fence
38 | fence_min = 5e-3
39 | fence_max = 15e-3
40 |
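41 | # For reference (derived, not used by the code): aperture_diamter =
42 | # DOE_pitch * R = 6.45e-6 * 2800 ≈ 18.1 mm, so with the 50 mm focal
43 | # length this system is roughly f/2.8.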
--------------------------------------------------------------------------------
/config/param_MV_1600.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pado.material import *
3 | from utils.utils import *
4 |
5 | # https://www.edmundoptics.com/p/8mm-UC-Series-Fixed-Focal-Length-Lens/41864?gclid=CjwKCAjwh5qLBhALEiwAioods3Z90aJenLK11rrj2R5E7SpKF0gvF8a9vZrsd0H5aY72nIgcYq42QRoC4hAQAvD_BwE
6 | # https://www.edmundoptics.com/p/bfs-u3-120s4c-cs-usb3-blackflyreg-s-color-camera/40172/
7 |
8 | camera_resolution = [4000,3000]
9 | camera_pitch = 1.85e-6
10 | background_pitch = camera_pitch * 2
11 | sensor_dist = focal_length = 8e-3
12 | aperture_shape='circle'
13 | wvls = [656e-9, 589e-9, 486e-9] # camera RGB wavelength
14 | DOE_wvl = 550e-9 # wavelength used to set DOE
15 |
16 | # DOE specs
17 | R = C = 1600 # Resolution of the simulated wavefront
18 | DOE_material = 'FUSED_SILICA'
19 | material = Material(DOE_material)
20 | DOE_pitch = camera_pitch * 1.5
21 | aperture_diamter = DOE_pitch * R
22 | DOE_sample_ratio = 2
23 | image_sample_ratio = 3
24 | equiv_camera_pitch = camera_pitch * image_sample_ratio
25 | img_res = 512
26 | assert DOE_pitch * DOE_sample_ratio == equiv_camera_pitch
27 | DOE_max_height = 1.2e-6
28 | DOE_height_noise_scale = 4*10e-9
29 | DOE_phase_noise_scale = 0.1
30 |
31 | DOE_phase_init = torch.zeros((1,1, R, C))
32 |
33 | # depth
34 | depth_near_min = 0.05
35 | depth_near_max = 0.12
36 | depth_far_min = 5
37 | depth_far_max = 10
38 | plot_depth = [5,3,1,0.12,0.08,0.05]
39 |
40 | # raindrop
41 | drop_Nmin = 5
42 | drop_Nmax = 8
43 | drop_Rmin = 1e-3
44 | drop_Rmax = 3e-3
45 |
46 | # dirt
47 | perlin_res = 8
48 | perlin_cutoff = 0.55
--------------------------------------------------------------------------------
/env.yml:
--------------------------------------------------------------------------------
1 | name: SeeThroughObstruction
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - _openmp_mutex=4.5=1_gnu
8 | - absl-py=0.13.0=py39h06a4308_0
9 | - aiohttp=3.7.4=py39h27cfd23_1
10 | - async-timeout=3.0.1=py39h06a4308_0
11 | - attrs=21.2.0=pyhd3eb1b0_0
12 | - blas=1.0=mkl
13 | - blinker=1.4=py39h06a4308_0
14 | - bottleneck=1.3.2=py39hdd57654_1
15 | - brotlipy=0.7.0=py39h27cfd23_1003
16 | - bzip2=1.0.8=h7b6447c_0
17 | - c-ares=1.17.1=h27cfd23_0
18 | - ca-certificates=2021.7.5=h06a4308_1
19 | - cachetools=4.2.2=pyhd3eb1b0_0
20 | - certifi=2021.5.30=py39h06a4308_0
21 | - cffi=1.14.6=py39h400218f_0
22 | - chardet=3.0.4=py39h06a4308_1003
23 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
24 | - click=8.0.1=pyhd3eb1b0_0
25 | - coverage=5.5=py39h27cfd23_2
26 | - cryptography=3.4.7=py39hd23ed53_0
27 | - cudatoolkit=10.2.89=hfd86e86_1
28 | - cython=0.29.24=py39h295c915_0
29 | - ffmpeg=4.3=hf484d3e_0
30 | - freetype=2.10.4=h5ab3b9f_0
31 | - gmp=6.2.1=h2531618_2
32 | - gnutls=3.6.15=he1e5248_0
33 | - google-auth=1.33.0=pyhd3eb1b0_0
34 | - google-auth-oauthlib=0.4.1=py_2
35 | - grpcio=1.36.1=py39h2157cd5_1
36 | - idna=3.2=pyhd3eb1b0_0
37 | - importlib-metadata=3.10.0=py39h06a4308_0
38 | - intel-openmp=2021.3.0=h06a4308_3350
39 | - jpeg=9b=h024ee3a_2
40 | - lame=3.100=h7b6447c_0
41 | - lcms2=2.12=h3be6417_0
42 | - ld_impl_linux-64=2.35.1=h7274673_9
43 | - libffi=3.3=he6710b0_2
44 | - libgcc-ng=9.3.0=h5101ec6_17
45 | - libgomp=9.3.0=h5101ec6_17
46 | - libiconv=1.15=h63c8f33_5
47 | - libidn2=2.3.2=h7f8727e_0
48 | - libpng=1.6.37=hbc83047_0
49 | - libprotobuf=3.17.2=h4ff587b_1
50 | - libstdcxx-ng=9.3.0=hd4cf53a_17
51 | - libtasn1=4.16.0=h27cfd23_0
52 | - libtiff=4.2.0=h85742a9_0
53 | - libunistring=0.9.10=h27cfd23_0
54 | - libuv=1.40.0=h7b6447c_0
55 | - libwebp-base=1.2.0=h27cfd23_0
56 | - lz4-c=1.9.3=h295c915_1
57 | - markdown=3.3.4=py39h06a4308_0
58 | - mkl=2021.3.0=h06a4308_520
59 | - mkl-service=2.4.0=py39h7f8727e_0
60 | - mkl_fft=1.3.0=py39h42c9631_2
61 | - mkl_random=1.2.2=py39h51133e4_0
62 | - multidict=5.1.0=py39h27cfd23_2
63 | - ncurses=6.2=he6710b0_1
64 | - nettle=3.7.3=hbbd107a_1
65 | - ninja=1.10.2=hff7bd54_1
66 | - numexpr=2.7.3=py39h22e1b3c_1
67 | - numpy=1.20.3=py39hf144106_0
68 | - numpy-base=1.20.3=py39h74d4b33_0
69 | - oauthlib=3.1.1=pyhd3eb1b0_0
70 | - olefile=0.46=py_0
71 | - openh264=2.1.0=hd408876_0
72 | - openjpeg=2.4.0=h3ad879b_0
73 | - openssl=1.1.1k=h27cfd23_0
74 | - pandas=1.3.2=py39h8c16a72_0
75 | - pillow=8.3.1=py39h2c7a002_0
76 | - pip=21.2.4=py37h06a4308_0
77 | - protobuf=3.17.2=py39h295c915_0
78 | - pyasn1=0.4.8=py_0
79 | - pyasn1-modules=0.2.8=py_0
80 | - pycparser=2.20=py_2
81 | - pyjwt=2.1.0=py39h06a4308_0
82 | - pyopenssl=20.0.1=pyhd3eb1b0_1
83 | - pysocks=1.7.1=py39h06a4308_0
84 | - python=3.9.6=h12debd9_1
85 | - python-dateutil=2.8.2=pyhd3eb1b0_0
86 | - pytorch=1.9.0=py3.9_cuda10.2_cudnn7.6.5_0
87 | - pytz=2021.1=pyhd3eb1b0_0
88 | - readline=8.1=h27cfd23_0
89 | - requests=2.26.0=pyhd3eb1b0_0
90 | - requests-oauthlib=1.3.0=py_0
91 | - rsa=4.7.2=pyhd3eb1b0_1
92 | - setuptools=52.0.0=py39h06a4308_0
93 | - six=1.16.0=pyhd3eb1b0_0
94 | - sqlite=3.36.0=hc218d9a_0
95 | - tensorboard=2.5.0=py_0
96 | - tensorboard-plugin-wit=1.6.0=py_0
97 | - tk=8.6.10=hbc83047_0
98 | - torchaudio=0.9.0=py39
99 | - torchvision=0.10.0=py39_cu102
100 | - typing-extensions=3.10.0.0=hd3eb1b0_0
101 | - typing_extensions=3.10.0.0=pyh06a4308_0
102 | - tzdata=2021a=h5d7bf9c_0
103 | - urllib3=1.26.6=pyhd3eb1b0_1
104 | - werkzeug=1.0.1=pyhd3eb1b0_0
105 | - wheel=0.37.0=pyhd3eb1b0_0
106 | - xz=5.2.5=h7b6447c_0
107 | - yarl=1.6.3=py39h27cfd23_0
108 | - zipp=3.5.0=pyhd3eb1b0_0
109 | - zlib=1.2.11=h7b6447c_3
110 | - zstd=1.4.9=haebb681_0
111 | - pip:
112 | - anyio==3.3.0
113 | - argon2-cffi==21.1.0
114 | - astropy==4.3.1
115 | - babel==2.9.1
116 | - backcall==0.2.0
117 | - bleach==4.1.0
118 | - cycler==0.10.0
119 | - debugpy==1.4.1
120 | - decorator==5.0.9
121 | - defusedxml==0.7.1
122 | - entrypoints==0.3
123 | - imageio==2.9.0
124 | - ipykernel==6.3.1
125 | - ipython==7.27.0
126 | - ipython-genutils==0.2.0
127 | - jedi==0.18.0
128 | - jinja2==3.0.1
129 | - json5==0.9.6
130 | - jsonschema==3.2.0
131 | - jupyter-client==7.0.2
132 | - jupyter-core==4.7.1
133 | - jupyter-server==1.10.2
134 | - jupyterlab==3.1.9
135 | - jupyterlab-pygments==0.1.2
136 | - jupyterlab-server==2.7.2
137 | - kiwisolver==1.3.2
138 | - lightpipes==2.1.1
139 | - lpips==0.1.4
140 | - markupsafe==2.0.1
141 | - matplotlib==3.4.3
142 | - matplotlib-inline==0.1.2
143 | - mistune==0.8.4
144 | - nbclassic==0.3.1
145 | - nbclient==0.5.4
146 | - nbconvert==6.1.0
147 | - nbformat==5.1.3
148 | - nest-asyncio==1.5.1
149 | - networkx==2.6.3
150 | - noise==1.2.2
151 | - notebook==6.4.3
152 | - opencv-python==4.5.3.56
153 | - packaging==21.0
154 | - pandocfilters==1.4.3
155 | - parso==0.8.2
156 | - pexpect==4.8.0
157 | - pickleshare==0.7.5
158 | - poppy==0.9.2
159 | - prometheus-client==0.11.0
160 | - prompt-toolkit==3.0.20
161 | - prysm==0.20
162 | - ptyprocess==0.7.0
163 | - pyblur==0.2.3
164 | - pyerfa==2.0.0
165 | - pygments==2.10.0
166 | - pyparsing==2.4.7
167 | - pyrsistent==0.18.0
168 | - pywavelets==1.1.1
169 | - pyzmq==22.2.1
170 | - requests-unixsocket==0.2.0
171 | - scikit-image==0.18.3
172 | - scipy==1.7.1
173 | - send2trash==1.8.0
174 | - sniffio==1.2.0
175 | - terminado==0.11.1
176 | - testpath==0.5.0
177 | - tifffile==2021.8.30
178 | - tornado==6.1
179 | - tqdm==4.62.2
180 | - traitlets==5.1.0
181 | - wand==0.6.7
182 | - wcwidth==0.2.5
183 | - webencodings==0.5.1
184 | - websocket-client==1.2.1
185 |
186 |
--------------------------------------------------------------------------------
/models/Defence.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | from torch.utils.data import Dataset
5 | import torchvision.transforms as transforms
6 | from PIL import Image
7 | import os
8 | from utils.utils import *
9 |
10 | class DefencingDataset(Dataset):
11 |
12 | def __init__(self, root_dir, transform=None):
13 | self.root_dir = root_dir
14 | self.transform = transform
15 | self.image_list = [f for f in os.listdir(root_dir) if not f.startswith('.')]
16 |
17 | def __len__(self):
18 | return len(self.image_list)
19 |
20 | def __getitem__(self, idx):
21 | if torch.is_tensor(idx):
22 | idx = idx.tolist()
23 |
24 | img_name = os.path.join(self.root_dir,
25 | self.image_list[idx])
26 | image = Image.open(img_name).convert('RGB')
27 | pixel_width = int(self.image_list[idx].split('-')[-1].split('.')[0])
28 | if self.transform:
29 | image = self.transform(image)
30 | sample = {'image': image , 'pixel_width': pixel_width}
31 | return sample
32 |
33 | def _largest_rotated_rect(w, h, angle):
34 | """
35 | Given a rectangle of size wxh that has been rotated by 'angle' (in
36 | radians), computes the width and height of the largest possible
37 | axis-aligned rectangle within the rotated rectangle.
38 | Original JS code by 'Andri' and Magnus Hoff from Stack Overflow
39 | Converted to Python by Aaron Snoswell
40 | Source: http://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders
41 | """
42 |
43 | quadrant = int(math.floor(angle / (math.pi / 2))) & 3
44 | sign_alpha = angle if ((quadrant & 1) == 0) else math.pi - angle
45 | alpha = (sign_alpha % math.pi + math.pi) % math.pi
46 |
47 | bb_w = w * math.cos(alpha) + h * math.sin(alpha)
48 | bb_h = w * math.sin(alpha) + h * math.cos(alpha)
49 |
50 |     gamma = math.atan2(bb_w, bb_h) if (w < h) else math.atan2(bb_h, bb_w)
51 |
52 | delta = math.pi - alpha - gamma
53 |
54 | length = h if (w < h) else w
55 |
56 | d = length * math.cos(alpha)
57 | a = d * math.sin(alpha) / math.sin(delta)
58 |
59 | y = a * math.cos(gamma)
60 | x = y * math.tan(gamma)
61 |
62 | return (
63 | bb_w - 2 * x,
64 | bb_h - 2 * y
65 | )
66 |
67 | transform_fence = transforms.Compose([
68 | transforms.ColorJitter(contrast=0.2),
69 | transforms.RandomRotation(90),
70 | transforms.CenterCrop(1300),
71 | transforms.RandomHorizontalFlip(),
72 | transforms.ToTensor(),
73 | ])
74 |
75 | fenceset = DefencingDataset('/projects/FHEIDE/obstruction_free_doe/De-fencing/dataset_parsed/', transform=transform_fence)
76 |
77 | def compute_fence(image_far, depth, args):
78 | '''generate fence obstruction'''
79 | param = args.param
80 | fence_width = metric2pixel(randuni(param.fence_min, param.fence_max, 1)[0] , depth, args)
81 |
82 | fence = fenceset[np.random.randint(len(fenceset))]
83 |
84 | fence_image = fence['image']
85 | T = transforms.Compose([
86 | transforms.Resize(size=(int(fence_width/fence['pixel_width']*fence_image.shape[-2]),int(fence_width/fence['pixel_width']*fence_image.shape[-1]))),
87 | transforms.CenterCrop([param.img_res, param.img_res])
88 | ])
89 |
90 | image_near = T(fence_image)[None, ...]
91 | mask = (image_near > torch.median(image_near))*1.0
92 | image_near = image_near.to(args.device) * mask.to(args.device) + image_far * (1-mask.to(args.device))
93 | # if torch.mean(mask) > 0.3:
94 | # print(torch.mean(mask))
95 | # print(torch.median(image_near))
96 | return image_near.to(args.device), mask.to(args.device)
--------------------------------------------------------------------------------
/models/Dirt.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torchvision.transforms as transforms
4 | from models.PerlinBlob import *
5 |
6 | def circle_grad(img_res):
7 | center_x, center_y = img_res // 2, img_res // 2
8 | circle_grad = np.zeros([img_res,img_res])
9 |
10 | for y in range(img_res):
11 | for x in range(img_res):
12 | distx = abs(x - center_x)
13 | disty = abs(y - center_y)
14 | dist = np.sqrt(distx*distx + disty*disty)
15 | circle_grad[y][x] = dist
16 | max_grad = np.max(circle_grad)
17 | circle_grad = circle_grad / max_grad
18 | circle_grad -= 0.5
19 | circle_grad *= 2.0
20 | circle_grad = -circle_grad
21 |
22 | circle_grad -= np.min(circle_grad)
23 | max_grad = np.max(circle_grad)
24 | circle_grad = circle_grad / max_grad
25 | return circle_grad
26 |
27 | def compute_dirt(image_far, depth, args):
28 | param = args.param
29 | brown = (np.array([30, 20, 10]) * np.random.uniform(1,2.5) + np.random.uniform(-5,5,3) )* np.ones([param.img_res, param.img_res, 3]) / 255
30 | perlin_noise = generate_fractal_noise_2d([param.img_res, param.img_res], [param.perlin_res,param.perlin_res], tileable=(True,True), interpolant=interpolant)
31 | depth_adj = depth / param.depth_near_max
32 | T = transforms.Compose([transforms.ToTensor(),
33 | transforms.RandomCrop(int(param.img_res * depth_adj)),
34 | transforms.Resize([param.img_res, param.img_res])])
35 | perlin_noise = T(perlin_noise).squeeze().numpy()
36 | alpha_map = perlin_noise * (perlin_noise > param.perlin_cutoff) * circle_grad(param.img_res)
37 | alpha_map /= np.max(alpha_map)
38 | image_near = torch.tensor(brown* alpha_map[...,None]).permute(2,0,1)[None,...]
39 | mask = torch.tile(torch.tensor(1.0*(alpha_map[...,None] > 0.3)).permute(2,0,1)[None,...],(1,3,1,1))
40 | image_near = image_near.to(args.device) * mask.to(args.device) + image_far * (1-mask.to(args.device))
41 | return image_near.to(args.device), mask.to(args.device)
42 |
43 |
--------------------------------------------------------------------------------
/models/PerlinBlob.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def interpolant(t):
4 | return t*t*t*(t*(t*6 - 15) + 10)
5 |
6 | def interpolant1(t):
7 | return t*t*(t*(t*6 - 15) + 10)
8 |
9 | def interpolant2(t):
10 | return t*(t*(t*6 - 15) + 10)
11 |
12 | def generate_perlin_noise_2d(
13 | shape, res, tileable=(False, False), interpolant=interpolant
14 | ):
15 | """Generate a 2D numpy array of perlin noise.
16 | Args:
17 | shape: The shape of the generated array (tuple of two ints).
18 |         This must be a multiple of res.
19 | res: The number of periods of noise to generate along each
20 | axis (tuple of two ints). Note shape must be a multiple of
21 | res.
22 | tileable: If the noise should be tileable along each axis
23 | (tuple of two bools). Defaults to (False, False).
24 | interpolant: The interpolation function, defaults to
25 | t*t*t*(t*(t*6 - 15) + 10).
26 | Returns:
27 | A numpy array of shape shape with the generated noise.
28 | Raises:
29 | ValueError: If shape is not a multiple of res.
30 | """
31 | delta = (res[0] / shape[0], res[1] / shape[1])
32 | d = (shape[0] // res[0], shape[1] // res[1])
33 | grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]]\
34 | .transpose(1, 2, 0) % 1
35 | # Gradients
36 | angles = 2*np.pi*np.random.rand(res[0]+1, res[1]+1)
37 | gradients = np.dstack((np.cos(angles), np.sin(angles)))
38 | if tileable[0]:
39 | gradients[-1,:] = gradients[0,:]
40 | if tileable[1]:
41 | gradients[:,-1] = gradients[:,0]
42 | gradients = gradients.repeat(d[0], 0).repeat(d[1], 1)
43 | g00 = gradients[ :-d[0], :-d[1]]
44 | g10 = gradients[d[0]: , :-d[1]]
45 | g01 = gradients[ :-d[0],d[1]: ]
46 | g11 = gradients[d[0]: ,d[1]: ]
47 | # Ramps
48 | n00 = np.sum(np.dstack((grid[:,:,0] , grid[:,:,1] )) * g00, 2)
49 | n10 = np.sum(np.dstack((grid[:,:,0]-1, grid[:,:,1] )) * g10, 2)
50 | n01 = np.sum(np.dstack((grid[:,:,0] , grid[:,:,1]-1)) * g01, 2)
51 | n11 = np.sum(np.dstack((grid[:,:,0]-1, grid[:,:,1]-1)) * g11, 2)
52 | # Interpolation
53 | t = interpolant(grid)
54 | n0 = n00*(1-t[:,:,0]) + t[:,:,0]*n10
55 | n1 = n01*(1-t[:,:,0]) + t[:,:,0]*n11
56 |
57 | ret = np.sqrt(2)*((1-t[:,:,1])*n0 + t[:,:,1]*n1)
58 | ret = ret - ret.min()
59 | ret = ret/ret.max()
60 | return ret
61 |
62 |
63 | def generate_fractal_noise_2d(
64 | shape, res, octaves=1, persistence=0.5,
65 | lacunarity=2, tileable=(False, False),
66 | interpolant=interpolant
67 | ):
68 | """Generate a 2D numpy array of fractal noise.
69 | Args:
70 | shape: The shape of the generated array (tuple of two ints).
71 | This must be a multiple of lacunarity**(octaves-1)*res.
72 | res: The number of periods of noise to generate along each
73 | axis (tuple of two ints). Note shape must be a multiple of
74 | (lacunarity**(octaves-1)*res).
75 | octaves: The number of octaves in the noise. Defaults to 1.
76 | persistence: The scaling factor between two octaves.
77 | lacunarity: The frequency factor between two octaves.
78 | tileable: If the noise should be tileable along each axis
79 | (tuple of two bools). Defaults to (False, False).
80 |         interpolant: The interpolation function, defaults to
81 | t*t*t*(t*(t*6 - 15) + 10).
82 | Returns:
83 | A numpy array of fractal noise and of shape shape generated by
84 | combining several octaves of perlin noise.
85 | Raises:
86 | ValueError: If shape is not a multiple of
87 | (lacunarity**(octaves-1)*res).
88 | """
89 | noise = np.zeros(shape)
90 | frequency = 1
91 | amplitude = 1
92 | for _ in range(octaves):
93 | noise += amplitude * generate_perlin_noise_2d(
94 | shape, (frequency*res[0], frequency*res[1]), tileable, interpolant
95 | )
96 | frequency *= lacunarity
97 | amplitude *= persistence
98 | return noise
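99 | 
100 | # Usage sketch (mirrors the call in models/Dirt.py): a tileable dirt mask at
101 | # the configs' img_res / perlin_res, thresholded at perlin_cutoff:
102 | #   noise = generate_fractal_noise_2d([512, 512], [8, 8], tileable=(True, True))
103 | #   dirt_alpha = noise * (noise > 0.55)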
--------------------------------------------------------------------------------
/models/ROLE.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import math
3 | import numpy as np
4 | import torch
5 |
6 | from utils.utils import *
7 |
8 | class raindrop():
9 | def __init__(self, centerxy, radius):
10 | self.ifcol = False
11 | self.col_with = []
12 | self.center = centerxy
13 | self.radius = radius
14 | self.alphamap = np.zeros((self.radius * 5, self.radius*4,1))
15 | self.texture = None
16 | self._createDefaultDrop()
17 |
18 | def updateTexture(self, fg):
19 |
20 |         # add a fisheye effect to simulate how the drop refracts the background
21 | K = np.array([[30*self.radius, 0, 2*self.radius],
22 | [0., 20*self.radius, 3*self.radius],
23 | [0., 0., 1]])
24 | D = np.array([0.0, 0.0, 0.0, 0.0])
25 | Knew = K.copy()
26 | Knew[(0,1), (0,1)] = math.pow(self.radius, 1/3) * 1.2 * Knew[(0,1), (0,1)]
27 | fisheye = cv2.fisheye.undistortImage(fg, K, D=D, Knew=Knew)
28 | tmp = np.array(fisheye)
29 |
30 | self.texture = np.clip((0.6 + 1.5*(tmp-0.5)),0,1)
31 |
32 | def _createDefaultDrop(self):
33 | cv2.circle(self.alphamap, (self.radius * 2, self.radius * 3), self.radius, 128, -1)
34 | cv2.ellipse(self.alphamap, (self.radius * 2, self.radius * 3), (self.radius, int(np.random.uniform(0.8,1.5) * self.radius)), 0, 180, 360, 128, -1)
35 | # set alpha map for png
36 | # self.alphamap = cv2.GaussianBlur(self.alphamap,(1,1),0)
37 | self.alphamap = np.asarray(self.alphamap)
38 | self.alphamap = (self.alphamap/np.max(self.alphamap))
39 |
40 | def compute_raindrop(image_far, depth, args):
41 | '''generate raindrop obstruction'''
42 | param = args.param
43 | image = image_far[0].permute(1,2,0).cpu().numpy()
44 |
45 | drop_num = int(np.random.randint(param.drop_Nmin, param.drop_Nmax) * depth / param.depth_near_max)
46 |
47 | alpha_map = np.zeros_like(image)[:,:,0:1]
48 | imgh, imgw, _ = image.shape
49 | edge_gap = metric2pixel(param.drop_Rmin,depth, args)
50 | ran_pos = [(edge_gap + int(np.random.rand() * (imgw - 2*edge_gap)), 2*edge_gap + int(np.random.rand() * (imgh - 3*edge_gap))) for _ in range(drop_num)]
51 | listRainDrops = []
52 | #########################
53 | # Create Raindrop
54 | #########################
55 | # create raindrop by default
56 |
57 | for pos in ran_pos:
58 | radius = np.minimum(metric2pixel(np.random.uniform(param.drop_Rmin, param.drop_Rmax), depth, args), int(np.floor(param.img_res/20)))
59 | drop = raindrop(pos, radius)
60 | listRainDrops.append(drop)
61 | # add texture
62 | for drop in listRainDrops:
63 | (ix, iy) = drop.center
64 | radius = drop.radius
65 | ROI_WL = 2*radius
66 | ROI_WR = 2*radius
67 | ROI_HU = 3*radius
68 | ROI_HD = 2*radius
69 | if (iy-3*radius) <0 :
70 | ROI_HU = iy
71 | if (iy+2*radius)>imgh:
72 | ROI_HD = imgh - iy
73 | if (ix-2*radius)<0:
74 | ROI_WL = ix
75 | if (ix+2*radius) > imgw:
76 | ROI_WR = imgw - ix
77 |
78 | drop_alpha = drop.alphamap
79 |
80 | alpha_map[iy - ROI_HU:iy + ROI_HD, ix - ROI_WL: ix+ROI_WR,:] += drop_alpha[3*radius - ROI_HU:3*radius + ROI_HD, 2*radius - ROI_WL: 2*radius+ROI_WR,:]
81 | drop.image_coor = np.array([iy - ROI_HU, iy + ROI_HD, ix - ROI_WL, ix+ROI_WR])
82 | drop.alpha_coor = np.array([3*radius - ROI_HU, 3*radius + ROI_HD, 2*radius - ROI_WL, 2*radius+ROI_WR])
83 |
84 | upshift = int(0.1 * (iy - ROI_HU))
85 | drop.updateTexture(image[iy - ROI_HU - upshift: iy + ROI_HD - upshift, ix - ROI_WL: ix+ROI_WR])
86 |
87 | image_near = np.asarray(cv2.GaussianBlur(image,(5,5),0))
88 | for drop in listRainDrops:
89 | img_hl,img_hr, img_wl, img_wr = drop.image_coor
90 | drop_hl, drop_hr, drop_wl, drop_wr = drop.alpha_coor
91 | texture_blend = drop.texture*(drop.alphamap[drop_hl:drop_hr, drop_wl:drop_wr])
92 | update_alpha = drop.alphamap[drop_hl:drop_hr, drop_wl:drop_wr] > 0
93 | image_near[img_hl:img_hr, img_wl: img_wr] = texture_blend * update_alpha + image_near[img_hl:img_hr, img_wl: img_wr] * (1-update_alpha)
94 |
95 | image_near = torch.tensor(image_near * (alpha_map > 0)).permute(2,0,1)[None,...]
96 | # image_near = torch.tensor(image_near).permute(2,0,1)[None,...]
97 | mask = torch.tile(torch.tensor(1.0*(alpha_map > 0)).permute(2,0,1)[None,...],(1,3,1,1))
98 | return image_near.to(args.device), mask.to(args.device)
--------------------------------------------------------------------------------
/models/forwards.py:
--------------------------------------------------------------------------------
1 | from pado.light import *
2 | from pado.optical_element import *
3 | from pado.propagator import *
4 |
5 | from utils.utils import *
6 |
7 | def compute_psf(wvl, depth, doe, args):
8 | '''simulate depth based psf'''
9 | param = args.param
10 | prop = Propagator('Fresnel')
11 |
12 | light = Light(param.R, param.C, param.DOE_pitch, wvl, args.device,B=1)
13 |     light.set_spherical_light(depth.numpy())
14 |
15 | lens = RefractiveLens(param.R, param.C, param.DOE_pitch, param.focal_length, wvl,args.device)
16 | light = lens.forward(light)
17 |
18 | doe.change_wvl(wvl)
19 | light = doe.forward(light)
20 |
21 | aperture = Aperture(param.R, param.C, param.DOE_pitch, param.aperture_diamter, param.aperture_shape, wvl, args.device)
22 | light = aperture.forward(light)
23 |
24 | light_prop = prop.forward(light, param.sensor_dist)
25 | psf = light_prop.get_intensity()
26 | psf = sample_psf(psf, param.DOE_sample_ratio)
27 | psf /= torch.sum(psf)
28 |
29 | psf_size = psf.shape
30 | if psf_size[-2]*psf_size[-1] < param.img_res**2:
31 | wl, wr = compute_pad_size(psf_size[-1], param.img_res)
32 | hl, hr = compute_pad_size(psf_size[-2], param.img_res)
33 | psf = F.pad(psf, (wl, wr, hl, hr), "constant", 0)
34 | elif psf_size[-2]*psf_size[-1] > param.img_res**2:
35 | wl, wr = compute_pad_size(param.img_res, psf_size[-1])
36 | hl, hr = compute_pad_size(param.img_res, psf_size[-2])
37 | psf = psf[:,:,hl:-hr, wl:-wr]
38 |
39 | cutoff = np.tan(np.sinh(wvl/(3*param.DOE_pitch)))*param.focal_length / param.equiv_camera_pitch
40 | DOE_mask = edge_mask(int(param.img_res / 2),cutoff, args.device)
41 | psf *= DOE_mask
42 | psf /= torch.sum(psf)
43 | return psf
44 |
45 | def compute_psf_Fraunhofer(wvl, depth, doe, args):
46 | '''simulate depth based psf'''
47 | param = args.param
48 | prop = Propagator('Fraunhofer')
49 |
50 | light = Light(param.R, param.C, param.DOE_pitch, wvl, args.device,B=1)
51 | light.set_spherical_light(depth.numpy())
52 |
53 | doe.change_wvl(wvl)
54 | light = doe.forward(light)
55 |
56 | aperture = Aperture(param.R, param.C, param.DOE_pitch, param.aperture_diamter, param.aperture_shape, wvl, args.device)
57 | light = aperture.forward(light)
58 |
59 | light_prop = prop.forward(light, param.sensor_dist)
60 | psf = light_prop.get_intensity()
61 |
62 | # resize
63 | psf = F.interpolate(psf, int(param.R * light_prop.pitch / param.DOE_pitch))
64 |
65 | psf = sample_psf(psf, param.DOE_sample_ratio)
66 |
67 | psf_size = psf.shape
68 | if psf_size[-2]*psf_size[-1] < param.img_res**2:
69 | wl, wr = compute_pad_size(psf_size[-1], param.img_res)
70 | hl, hr = compute_pad_size(psf_size[-2], param.img_res)
71 | psf = F.pad(psf, (wl, wr, hl, hr), "constant", 0)
72 | elif psf_size[-2]*psf_size[-1] > param.img_res**2:
73 | wl, wr = compute_pad_size(param.img_res, psf_size[-1])
74 | hl, hr = compute_pad_size(param.img_res, psf_size[-2])
75 | psf = psf[:,:,hl:-hr, wl:-wr]
76 |
77 | cutoff = np.tan(np.sinh(wvl/(3*param.DOE_pitch)))*param.focal_length / param.equiv_camera_pitch
78 | DOE_mask = edge_mask(int(param.img_res / 2),cutoff, args.device)
79 | psf *= DOE_mask
80 | psf /= torch.sum(psf)
81 | return psf
82 |
83 | def image_formation(image_far, DOE_phase, compute_obstruction, args, z_near = None):
84 | param = args.param
85 | doe = DOE(param.R, param.C, param.DOE_pitch, param.material, param.DOE_wvl,args.device, phase = DOE_phase)
86 | height_map = doe.get_height() * edge_mask(int(param.R/ 2),int(param.R/ 2), args.device)
87 |
88 | if z_near is None:
89 | z_near = randuni(param.depth_near_min, param.depth_near_max, 1)[0] # randomly sample the near-point depth from a range
90 |
91 | z_far = randuni(param.depth_far_min, param.depth_far_max, 1)[0] # randomly sample the far-point depth from a range
92 | image_near, mask = compute_obstruction(image_far, z_near, args)
93 |
94 | img_doe = []
95 | img_near_doe = []
96 | img_far_doe = []
97 | psf_far_doe = []
98 | psf_near_doe = []
99 | mask_doe = []
100 |
101 | for i in range(len(param.wvls)):
102 |
103 | wvl = param.wvls[i]
104 |
105 | psf_near = compute_psf_Fraunhofer(wvl, z_near, doe, args)
106 | psf_far = compute_psf_Fraunhofer(wvl, z_far, doe, args)
107 |
108 |         img_near_conv = conv_fft(real2complex(image_near[:,i,:,:]), real2complex(psf_near), (int(param.R/2),int(param.R/2),int(param.R/2),int(param.R/2))).get_mag()
109 | img_far_conv = conv_fft(real2complex(image_far[:,i,:,:]), real2complex(psf_far)).get_mag()
110 | mask_conv = conv_fft(real2complex(mask[:,i,:,:]), real2complex(psf_near), (int(param.R/2),int(param.R/2),int(param.R/2),int(param.R/2))).get_mag()
111 | mask_conv = torch.clamp(1.5*mask_conv,0,1)
112 | img_conv = img_near_conv * mask_conv + img_far_conv * (1 - mask_conv)
113 |
114 | img_doe.append(img_conv)
115 | img_near_doe.append(img_near_conv)
116 | img_far_doe.append(img_far_conv)
117 | psf_near_doe.append(psf_near)
118 | psf_far_doe.append(psf_far)
119 | mask_doe.append(mask_conv)
120 | img_doe = torch.cat(img_doe, dim = 1)
121 | img_near_doe = torch.cat(img_near_doe, dim = 1)
122 | img_far_doe = torch.cat(img_far_doe, dim = 1)
123 | psf_near_doe = torch.cat(psf_near_doe, dim = 1)
124 | psf_far_doe = torch.cat(psf_far_doe, dim = 1)
125 | mask_doe = torch.cat(mask_doe, dim = 1)
126 |
127 | if args.sensor_noise > 0:
128 | noise = torch.rand(img_doe.shape) * 2 * args.sensor_noise - args.sensor_noise
129 | img_doe = torch.clamp(img_doe + noise.type_as(img_doe), 0, 1)
130 |
131 |     return image_near.type_as(image_far), mask.type_as(image_far), \
132 | img_doe.type_as(image_far), img_near_doe.type_as(image_far), img_far_doe.type_as(image_far), \
133 | psf_near_doe.type_as(image_far), psf_far_doe.type_as(image_far), mask_doe.type_as(image_far), height_map.type_as(image_far)
--------------------------------------------------------------------------------
/models/recon.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from pado.fourier import fft, ifft
7 | from utils.utils import *
8 |
9 | class Arch(nn.Module):
10 | def __init__(self, args, nf = 12):
11 | super().__init__()
12 | self.nf = nf
13 |
14 | self.down0_0 = ConvBlock(6, nf, 7, 1, 3)
15 | self.down0_1 = ConvBlock(nf,nf, 3,1,1)
16 |
17 | self.down1_0 = ConvBlock(nf, 2*nf, 3,2,1)
18 | self.down1_1 = ConvBlock(2*nf, 2*nf, 3,1,1)
19 | self.down1_2 = ConvBlock(2*nf, 2*nf, 3,1,1)
20 |
21 | self.down2_0 = ConvBlock(2*nf, 4*nf, 3,2,1)
22 | self.down2_1 = ConvBlock(4*nf, 4*nf, 3,1,1)
23 | self.down2_2 = ConvBlock(4*nf, 4*nf, 3,1,1)
24 |
25 | self.down3_0 = ConvBlock(4*nf, 8*nf, 3,2,1)
26 | self.down3_1 = ConvBlock(8*nf, 8*nf, 3,1,1)
27 | self.down3_2 = ConvBlock(8*nf, 8*nf, 3,1,1)
28 |
29 | self.down4_0 = ConvBlock(8*nf, 12*nf, 3,2,1)
30 | self.down4_1 = ConvBlock(12*nf, 12*nf, 3,1,1)
31 | self.down4_2 = ConvBlock(12*nf, 12*nf, 3,1,1)
32 |
33 | self.bottleneck_0 = ConvBlock(24*nf, 24*nf, 3,1,1)
34 | self.bottleneck_1 = ConvBlock(24*nf, 12*nf, 3,1,1)
35 |
36 | self.up4_0 = ConvTraspBlock(12*nf, 8*nf, 2,2,0)
37 | self.up4_1 = ConvBlock(24*nf, 8*nf, 3,1,1)
38 |
39 | self.up3_0 = ConvTraspBlock(8*nf, 4*nf, 2,2,0)
40 | self.up3_1 = ConvBlock(12*nf, 4*nf, 3,1,1)
41 |
42 | self.up2_0 = ConvTraspBlock(4*nf, 2*nf, 2,2,0)
43 | self.up2_1 = ConvBlock(6*nf, 2*nf, 3,1,1)
44 |
45 | self.up1_0 = ConvTraspBlock(2*nf, nf, 2,2,0)
46 | self.up1_1 = ConvBlock(3*nf, nf, 3,1,1)
47 |
48 | self.up0_0 = ConvBlock(nf, 3, 3,1,1)
49 | self.up0_1 = nn.Sequential(
50 | nn.ReflectionPad2d(2),
51 | nn.Conv2d(3, 3, kernel_size=5, stride=1),
52 | nn.Tanh()
53 | )
54 |
55 | # init to identity mapping
56 | self.res0 = BasicBlock(3, 3, 3, 1, 1).to(args.device)
57 | self.res1 = BasicBlock(3, 3, 3, 1, 1).to(args.device)
58 | self.res2 = BasicBlock(3, 3, 3, 1, 1).to(args.device)
59 |
60 | self.out = nn.Sequential(
61 | nn.ReflectionPad2d(2),
62 | nn.Conv2d(3, 3, kernel_size=5, stride=1),
63 | nn.ReLU()
64 | )
65 | self.out.apply(init_id_weights)
66 |
67 |
68 | def forward(self, image, psf_near, psf_far):
69 | psf0 = torch.tile(psf_far, (1,int(self.nf/3),1,1))
70 | psf1 = torch.tile(sample_psf(psf_far, 2), (1,int(2*self.nf/3),1,1))
71 | psf2 = torch.tile(sample_psf(psf_far, 4), (1,int(4*self.nf/3),1,1))
72 | psf3 = torch.tile(sample_psf(psf_far, 8), (1,int(8*self.nf/3),1,1))
73 | psf4 = torch.tile(sample_psf(psf_far, 16), (1,int(12*self.nf/3),1,1))
74 |
75 | images = [image, Wiener_deconv(image, psf_near)]
76 | images = torch.cat(images, dim = 1)
77 |
78 | down0 = self.down0_0(images)
79 | down0 = self.down0_1(down0)
80 | deconv0 = Wiener_deconv(down0, psf0)
81 |
82 | down1 = self.down1_0(down0)
83 | down1 = self.down1_1(down1)
84 | down1 = self.down1_2(down1)
85 | deconv1 = Wiener_deconv(down1, psf1)
86 |
87 | down2 = self.down2_0(down1)
88 | down2 = self.down2_1(down2)
89 | down2 = self.down2_2(down2)
90 | deconv2 = Wiener_deconv(down2, psf2)
91 |
92 | down3 = self.down3_0(down2)
93 | down3 = self.down3_1(down3)
94 | down3 = self.down3_2(down3)
95 | deconv3 = Wiener_deconv(down3, psf3)
96 |
97 | down4 = self.down4_0(down3)
98 | down4 = self.down4_1(down4)
99 | down4 = self.down4_2(down4)
100 | deconv4 = Wiener_deconv(down4, psf4)
101 |
102 | bottleneck = self.bottleneck_0(torch.cat([deconv4,down4], 1))
103 | bottleneck = self.bottleneck_1(bottleneck)
104 |
105 | up4 = self.up4_0(bottleneck)
106 | up4 = self.up4_1(torch.cat([up4,down3, deconv3], 1))
107 |
108 | up3 = self.up3_0(up4)
109 | up3 = self.up3_1(torch.cat([up3,down2, deconv2], 1))
110 |
111 | up2 = self.up2_0(up3)
112 | up2 = self.up2_1(torch.cat([up2,down1, deconv1], 1))
113 |
114 | up1 = self.up1_0(up2)
115 | up1 = self.up1_1(torch.cat([up1,down0, deconv0], 1))
116 |
117 | up0 = self.up0_0(up1)
118 | up0 = self.up0_1(up0)
119 | up0 = up0 + image
120 |
121 | res = self.res0(up0)
122 | res = self.res1(res)
123 | res = self.res2(res)
124 |
125 | out = self.out(res)
126 |
127 | return out
128 |
129 | def ConvBlock(in_channels, out_channels, kernel_size, stride, padding):
130 | block = nn.Sequential(
131 | nn.ReflectionPad2d(padding),
132 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride), #padding=padding),
133 | nn.InstanceNorm2d(out_channels),
134 | nn.LeakyReLU()
135 | )
136 | return block
137 |
138 | def ConvTraspBlock(in_channels, out_channels, kernel_size, stride, padding):
139 | block = nn.Sequential(
140 | nn.ReflectionPad2d(padding),
141 | nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride), #padding=padding),
142 | nn.InstanceNorm2d(out_channels),
143 | nn.LeakyReLU()
144 | )
145 | return block
146 |
147 | def Wiener_deconv(image, psf):
148 |     otf = fft(real2complex(psf))  # PSF -> OTF
149 |     wiener = otf.conj() / real2complex(otf.get_mag() **2 + 1e-6)  # H* / (|H|^2 + eps), regularized inverse
150 |     image_deconv = ifft(wiener * fft(real2complex(image))).get_mag()  # apply the filter in the Fourier domain
151 |     return torch.clamp(image_deconv.type_as(image), min = 0, max = 1)
152 |
153 | def init_zero_weights(module):
154 | if isinstance(module, torch.nn.Conv2d):
155 | # module.weight.data.zero_()
156 | module.weight.data = torch.nn.init.xavier_uniform_(module.weight.data,gain=1e-3)
157 | if module.bias is not None:
158 | module.bias.data.zero_()
159 |
160 | def init_id_weights(module):
161 | if isinstance(module, nn.Conv2d):
162 | module.weight.data = torch.nn.init.dirac_(module.weight.data)
163 | if module.bias is not None:
164 | module.bias.data.zero_()
165 |
166 | class BasicBlock(nn.Module):
167 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
168 | super().__init__()
169 | self.conv1 = ConvBlock(in_channels, out_channels, kernel_size, stride, padding)
170 | self.conv1.apply(init_zero_weights)
171 | self.conv2 = nn.Sequential(
172 | nn.ReflectionPad2d(padding),
173 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride),
174 | nn.InstanceNorm2d(out_channels),
175 | )
176 | self.conv2.apply(init_zero_weights)
177 | self.leakyrelu = nn.LeakyReLU()
178 |
179 | def forward(self, x):
180 | identity = x
181 |
182 | out = self.conv1(x)
183 | out = self.conv2(out)
184 | out += identity
185 | out = self.leakyrelu(out)
186 |
187 | return out
188 |
189 |
190 |
--------------------------------------------------------------------------------
/pado-main/LICENSE:
--------------------------------------------------------------------------------
1 | Pado
2 | Copyright (c) 2022 by Seung-Hwan Baek
3 |
4 | Pado is licensed under a
5 | Creative Commons Attribution-NonCommercial-ShareAlike 3.0 Unported License.
6 |
7 | You should have received a copy of the license along with this
8 | work. If not, see <https://creativecommons.org/licenses/by-nc-sa/3.0/>.
9 |
10 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
11 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
12 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
14 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
15 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
16 | THE SOFTWARE.
17 |
18 |
19 | License
20 |
21 | THE WORK (AS DEFINED BELOW) IS PROVIDED UNDER THE TERMS OF THIS CREATIVE COMMONS PUBLIC LICENSE ("CCPL" OR "LICENSE"). THE WORK IS PROTECTED BY COPYRIGHT AND/OR OTHER APPLICABLE LAW. ANY USE OF THE WORK OTHER THAN AS AUTHORIZED UNDER THIS LICENSE OR COPYRIGHT LAW IS PROHIBITED.
22 |
23 | BY EXERCISING ANY RIGHTS TO THE WORK PROVIDED HERE, YOU ACCEPT AND AGREE TO BE BOUND BY THE TERMS OF THIS LICENSE. TO THE EXTENT THIS LICENSE MAY BE CONSIDERED TO BE A CONTRACT, THE LICENSOR GRANTS YOU THE RIGHTS CONTAINED HERE IN CONSIDERATION OF YOUR ACCEPTANCE OF SUCH TERMS AND CONDITIONS.
24 |
25 | 1. Definitions
26 |
27 | "Adaptation" means a work based upon the Work, or upon the Work and other pre-existing works, such as a translation, adaptation, derivative work, arrangement of music or other alterations of a literary or artistic work, or phonogram or performance and includes cinematographic adaptations or any other form in which the Work may be recast, transformed, or adapted including in any form recognizably derived from the original, except that a work that constitutes a Collection will not be considered an Adaptation for the purpose of this License. For the avoidance of doubt, where the Work is a musical work, performance or phonogram, the synchronization of the Work in timed-relation with a moving image ("synching") will be considered an Adaptation for the purpose of this License.
28 | "Collection" means a collection of literary or artistic works, such as encyclopedias and anthologies, or performances, phonograms or broadcasts, or other works or subject matter other than works listed in Section 1(g) below, which, by reason of the selection and arrangement of their contents, constitute intellectual creations, in which the Work is included in its entirety in unmodified form along with one or more other contributions, each constituting separate and independent works in themselves, which together are assembled into a collective whole. A work that constitutes a Collection will not be considered an Adaptation (as defined above) for the purposes of this License.
29 | "Distribute" means to make available to the public the original and copies of the Work or Adaptation, as appropriate, through sale or other transfer of ownership.
30 | "License Elements" means the following high-level license attributes as selected by Licensor and indicated in the title of this License: Attribution, Noncommercial, ShareAlike.
31 | "Licensor" means the individual, individuals, entity or entities that offer(s) the Work under the terms of this License.
32 | "Original Author" means, in the case of a literary or artistic work, the individual, individuals, entity or entities who created the Work or if no individual or entity can be identified, the publisher; and in addition (i) in the case of a performance the actors, singers, musicians, dancers, and other persons who act, sing, deliver, declaim, play in, interpret or otherwise perform literary or artistic works or expressions of folklore; (ii) in the case of a phonogram the producer being the person or legal entity who first fixes the sounds of a performance or other sounds; and, (iii) in the case of broadcasts, the organization that transmits the broadcast.
33 | "Work" means the literary and/or artistic work offered under the terms of this License including without limitation any production in the literary, scientific and artistic domain, whatever may be the mode or form of its expression including digital form, such as a book, pamphlet and other writing; a lecture, address, sermon or other work of the same nature; a dramatic or dramatico-musical work; a choreographic work or entertainment in dumb show; a musical composition with or without words; a cinematographic work to which are assimilated works expressed by a process analogous to cinematography; a work of drawing, painting, architecture, sculpture, engraving or lithography; a photographic work to which are assimilated works expressed by a process analogous to photography; a work of applied art; an illustration, map, plan, sketch or three-dimensional work relative to geography, topography, architecture or science; a performance; a broadcast; a phonogram; a compilation of data to the extent it is protected as a copyrightable work; or a work performed by a variety or circus performer to the extent it is not otherwise considered a literary or artistic work.
34 | "You" means an individual or entity exercising rights under this License who has not previously violated the terms of this License with respect to the Work, or who has received express permission from the Licensor to exercise rights under this License despite a previous violation.
35 | "Publicly Perform" means to perform public recitations of the Work and to communicate to the public those public recitations, by any means or process, including by wire or wireless means or public digital performances; to make available to the public Works in such a way that members of the public may access these Works from a place and at a place individually chosen by them; to perform the Work to the public by any means or process and the communication to the public of the performances of the Work, including by public digital performance; to broadcast and rebroadcast the Work by any means including signs, sounds or images.
36 | "Reproduce" means to make copies of the Work by any means including without limitation by sound or visual recordings and the right of fixation and reproducing fixations of the Work, including storage of a protected performance or phonogram in digital form or other electronic medium.
37 | 2. Fair Dealing Rights. Nothing in this License is intended to reduce, limit, or restrict any uses free from copyright or rights arising from limitations or exceptions that are provided for in connection with the copyright protection under copyright law or other applicable laws.
38 |
39 | 3. License Grant. Subject to the terms and conditions of this License, Licensor hereby grants You a worldwide, royalty-free, non-exclusive, perpetual (for the duration of the applicable copyright) license to exercise the rights in the Work as stated below:
40 |
41 | to Reproduce the Work, to incorporate the Work into one or more Collections, and to Reproduce the Work as incorporated in the Collections;
42 | to create and Reproduce Adaptations provided that any such Adaptation, including any translation in any medium, takes reasonable steps to clearly label, demarcate or otherwise identify that changes were made to the original Work. For example, a translation could be marked "The original work was translated from English to Spanish," or a modification could indicate "The original work has been modified.";
43 | to Distribute and Publicly Perform the Work including as incorporated in Collections; and,
44 | to Distribute and Publicly Perform Adaptations.
45 | The above rights may be exercised in all media and formats whether now known or hereafter devised. The above rights include the right to make such modifications as are technically necessary to exercise the rights in other media and formats. Subject to Section 8(f), all rights not expressly granted by Licensor are hereby reserved, including but not limited to the rights described in Section 4(e).
46 |
47 | 4. Restrictions. The license granted in Section 3 above is expressly made subject to and limited by the following restrictions:
48 |
49 | You may Distribute or Publicly Perform the Work only under the terms of this License. You must include a copy of, or the Uniform Resource Identifier (URI) for, this License with every copy of the Work You Distribute or Publicly Perform. You may not offer or impose any terms on the Work that restrict the terms of this License or the ability of the recipient of the Work to exercise the rights granted to that recipient under the terms of the License. You may not sublicense the Work. You must keep intact all notices that refer to this License and to the disclaimer of warranties with every copy of the Work You Distribute or Publicly Perform. When You Distribute or Publicly Perform the Work, You may not impose any effective technological measures on the Work that restrict the ability of a recipient of the Work from You to exercise the rights granted to that recipient under the terms of the License. This Section 4(a) applies to the Work as incorporated in a Collection, but this does not require the Collection apart from the Work itself to be made subject to the terms of this License. If You create a Collection, upon notice from any Licensor You must, to the extent practicable, remove from the Collection any credit as required by Section 4(d), as requested. If You create an Adaptation, upon notice from any Licensor You must, to the extent practicable, remove from the Adaptation any credit as required by Section 4(d), as requested.
50 | You may Distribute or Publicly Perform an Adaptation only under: (i) the terms of this License; (ii) a later version of this License with the same License Elements as this License; (iii) a Creative Commons jurisdiction license (either this or a later license version) that contains the same License Elements as this License (e.g., Attribution-NonCommercial-ShareAlike 3.0 US) ("Applicable License"). You must include a copy of, or the URI, for Applicable License with every copy of each Adaptation You Distribute or Publicly Perform. You may not offer or impose any terms on the Adaptation that restrict the terms of the Applicable License or the ability of the recipient of the Adaptation to exercise the rights granted to that recipient under the terms of the Applicable License. You must keep intact all notices that refer to the Applicable License and to the disclaimer of warranties with every copy of the Work as included in the Adaptation You Distribute or Publicly Perform. When You Distribute or Publicly Perform the Adaptation, You may not impose any effective technological measures on the Adaptation that restrict the ability of a recipient of the Adaptation from You to exercise the rights granted to that recipient under the terms of the Applicable License. This Section 4(b) applies to the Adaptation as incorporated in a Collection, but this does not require the Collection apart from the Adaptation itself to be made subject to the terms of the Applicable License.
51 | You may not exercise any of the rights granted to You in Section 3 above in any manner that is primarily intended for or directed toward commercial advantage or private monetary compensation. The exchange of the Work for other copyrighted works by means of digital file-sharing or otherwise shall not be considered to be intended for or directed toward commercial advantage or private monetary compensation, provided there is no payment of any monetary compensation in connection with the exchange of copyrighted works.
52 | If You Distribute, or Publicly Perform the Work or any Adaptations or Collections, You must, unless a request has been made pursuant to Section 4(a), keep intact all copyright notices for the Work and provide, reasonable to the medium or means You are utilizing: (i) the name of the Original Author (or pseudonym, if applicable) if supplied, and/or if the Original Author and/or Licensor designate another party or parties (e.g., a sponsor institute, publishing entity, journal) for attribution ("Attribution Parties") in Licensor's copyright notice, terms of service or by other reasonable means, the name of such party or parties; (ii) the title of the Work if supplied; (iii) to the extent reasonably practicable, the URI, if any, that Licensor specifies to be associated with the Work, unless such URI does not refer to the copyright notice or licensing information for the Work; and, (iv) consistent with Section 3(b), in the case of an Adaptation, a credit identifying the use of the Work in the Adaptation (e.g., "French translation of the Work by Original Author," or "Screenplay based on original Work by Original Author"). The credit required by this Section 4(d) may be implemented in any reasonable manner; provided, however, that in the case of a Adaptation or Collection, at a minimum such credit will appear, if a credit for all contributing authors of the Adaptation or Collection appears, then as part of these credits and in a manner at least as prominent as the credits for the other contributing authors. For the avoidance of doubt, You may only use the credit required by this Section for the purpose of attribution in the manner set out above and, by exercising Your rights under this License, You may not implicitly or explicitly assert or imply any connection with, sponsorship or endorsement by the Original Author, Licensor and/or Attribution Parties, as appropriate, of You or Your use of the Work, without the separate, express prior written permission of the Original Author, Licensor and/or Attribution Parties.
53 | For the avoidance of doubt:
54 |
55 | Non-waivable Compulsory License Schemes. In those jurisdictions in which the right to collect royalties through any statutory or compulsory licensing scheme cannot be waived, the Licensor reserves the exclusive right to collect such royalties for any exercise by You of the rights granted under this License;
56 | Waivable Compulsory License Schemes. In those jurisdictions in which the right to collect royalties through any statutory or compulsory licensing scheme can be waived, the Licensor reserves the exclusive right to collect such royalties for any exercise by You of the rights granted under this License if Your exercise of such rights is for a purpose or use which is otherwise than noncommercial as permitted under Section 4(c) and otherwise waives the right to collect royalties through any statutory or compulsory licensing scheme; and,
57 | Voluntary License Schemes. The Licensor reserves the right to collect royalties, whether individually or, in the event that the Licensor is a member of a collecting society that administers voluntary licensing schemes, via that society, from any exercise by You of the rights granted under this License that is for a purpose or use which is otherwise than noncommercial as permitted under Section 4(c).
58 | Except as otherwise agreed in writing by the Licensor or as may be otherwise permitted by applicable law, if You Reproduce, Distribute or Publicly Perform the Work either by itself or as part of any Adaptations or Collections, You must not distort, mutilate, modify or take other derogatory action in relation to the Work which would be prejudicial to the Original Author's honor or reputation. Licensor agrees that in those jurisdictions (e.g. Japan), in which any exercise of the right granted in Section 3(b) of this License (the right to make Adaptations) would be deemed to be a distortion, mutilation, modification or other derogatory action prejudicial to the Original Author's honor and reputation, the Licensor will waive or not assert, as appropriate, this Section, to the fullest extent permitted by the applicable national law, to enable You to reasonably exercise Your right under Section 3(b) of this License (right to make Adaptations) but not otherwise.
59 | 5. Representations, Warranties and Disclaimer
60 |
61 | UNLESS OTHERWISE MUTUALLY AGREED TO BY THE PARTIES IN WRITING AND TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, LICENSOR OFFERS THE WORK AS-IS AND MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE WORK, EXPRESS, IMPLIED, STATUTORY OR OTHERWISE, INCLUDING, WITHOUT LIMITATION, WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NONINFRINGEMENT, OR THE ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT DISCOVERABLE. SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OF IMPLIED WARRANTIES, SO THIS EXCLUSION MAY NOT APPLY TO YOU.
62 |
63 | 6. Limitation on Liability. EXCEPT TO THE EXTENT REQUIRED BY APPLICABLE LAW, IN NO EVENT WILL LICENSOR BE LIABLE TO YOU ON ANY LEGAL THEORY FOR ANY SPECIAL, INCIDENTAL, CONSEQUENTIAL, PUNITIVE OR EXEMPLARY DAMAGES ARISING OUT OF THIS LICENSE OR THE USE OF THE WORK, EVEN IF LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
64 |
65 | 7. Termination
66 |
67 | This License and the rights granted hereunder will terminate automatically upon any breach by You of the terms of this License. Individuals or entities who have received Adaptations or Collections from You under this License, however, will not have their licenses terminated provided such individuals or entities remain in full compliance with those licenses. Sections 1, 2, 5, 6, 7, and 8 will survive any termination of this License.
68 | Subject to the above terms and conditions, the license granted here is perpetual (for the duration of the applicable copyright in the Work). Notwithstanding the above, Licensor reserves the right to release the Work under different license terms or to stop distributing the Work at any time; provided, however that any such election will not serve to withdraw this License (or any other license that has been, or is required to be, granted under the terms of this License), and this License will continue in full force and effect unless terminated as stated above.
69 | 8. Miscellaneous
70 |
71 | Each time You Distribute or Publicly Perform the Work or a Collection, the Licensor offers to the recipient a license to the Work on the same terms and conditions as the license granted to You under this License.
72 | Each time You Distribute or Publicly Perform an Adaptation, Licensor offers to the recipient a license to the original Work on the same terms and conditions as the license granted to You under this License.
73 | If any provision of this License is invalid or unenforceable under applicable law, it shall not affect the validity or enforceability of the remainder of the terms of this License, and without further action by the parties to this agreement, such provision shall be reformed to the minimum extent necessary to make such provision valid and enforceable.
74 | No term or provision of this License shall be deemed waived and no breach consented to unless such waiver or consent shall be in writing and signed by the party to be charged with such waiver or consent.
75 | This License constitutes the entire agreement between the parties with respect to the Work licensed here. There are no understandings, agreements or representations with respect to the Work not specified here. Licensor shall not be bound by any additional provisions that may appear in any communication from You. This License may not be modified without the mutual written agreement of the Licensor and You.
76 | The rights granted under, and the subject matter referenced, in this License were drafted utilizing the terminology of the Berne Convention for the Protection of Literary and Artistic Works (as amended on September 28, 1979), the Rome Convention of 1961, the WIPO Copyright Treaty of 1996, the WIPO Performances and Phonograms Treaty of 1996 and the Universal Copyright Convention (as revised on July 24, 1971). These rights and subject matter take effect in the relevant jurisdiction in which the License terms are sought to be enforced according to the corresponding provisions of the implementation of those treaty provisions in the applicable national law. If the standard suite of rights granted under applicable copyright law includes additional rights not granted under this License, such additional rights are deemed to be included in the License; this License is not intended to restrict the license of any rights under applicable law.
77 |
--------------------------------------------------------------------------------
/pado-main/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # Pado: Differentiable Light-wave Simulation
6 |
7 | Pado is a differentiable wave-optics library written in PyTorch: objects and operations defined in Pado are differentiable via PyTorch's automatic differentiation. Pado therefore lets you differentiate through a light-wave simulation and integrate it into any PyTorch-based network, which makes it particularly useful for research in learning-based computational imaging and display.
8 |
9 | Pado provides high-level abstractions of light-wave simulation, which is useful even for users with little background in wave optics. Pado achieves this with three main objects: light, optical element, and propagator. Building your own light-wave simulator from these three objects lets you focus on the core problem you want to tackle instead of studying and implementing wave optics from scratch (a minimal sketch follows this README).
10 |
11 | Pado is a Korean word for wave.
12 |
13 | # How to use Pado
14 | We provide a Jupyter notebook (`./example/tutorial.ipynb`). More examples will be added later.
15 |
16 | # Prerequisites
17 | - Python
18 | - PyTorch
19 | - NumPy
20 | - Matplotlib
21 | - SciPy
22 |
23 | # About
24 | Pado is maintained and developed by [Seung-Hwan Baek](http://www.shbaek.com) at [POSTECH Computer Graphics Lab](http://cg.postech.ac.kr/).
25 | If you use Pado in your research, please cite Pado using the following BibTeX template:
26 |
27 | ```bib
28 | @misc{Pado,
29 |     Author = {Seung-Hwan Baek},
30 | Year = {2022},
31 | Note = {https://github.com/shwbaek/pado},
32 | Title = {Pado: Differentiable Light-wave Simulation}
33 | }
34 | ```
35 |
36 | # License
37 | Seung-Hwan Baek has developed this software and related documentation (the "Software"); confidential use in source form of the Software, without modification, is permitted provided that the following conditions are met:
38 |
39 | Neither the name of the copyright holder nor the names of any contributors may be used to endorse or promote products derived from the Software without specific prior written permission.
40 |
41 | The use of the software is for Non-Commercial Purposes only. As used in this Agreement, “Non-Commercial Purpose” means for the purpose of education or research in a non-commercial organization only. “Non-Commercial Purpose” excludes, without limitation, any use of the Software for, as part of, or in any way in connection with a product (including software) or service which is sold, offered for sale, licensed, leased, published, loaned or rented. If you require a license for a use excluded by this agreement, please email [shwbaek@postech.ac.kr].
42 |
43 | Warranty: POSTECH-CGLAB MAKES NO REPRESENTATIONS OR WARRANTIES ABOUT THE SUITABILITY OF THE SOFTWARE, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, OR NON-INFRINGEMENT. POSTECH-CGLAB SHALL NOT BE LIABLE FOR ANY DAMAGES SUFFERED BY LICENSEE AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THIS SOFTWARE OR ITS DERIVATIVES.
44 |
45 | Please refer to [here](./LICENSE) for more details.
46 |
--------------------------------------------------------------------------------
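
The three-object design described in the README above maps directly onto the classes dumped later in this listing (`pado/light.py`, `pado/optical_element.py`). Below is a minimal, hedged sketch of that workflow; the propagator lives in `pado/propagator.py`, which is not included in this section, so only the light/element interaction is shown, and the import paths assume the `pado` package is on `PYTHONPATH`:

```python
import torch
from pado.light import Light
from pado.optical_element import RefractiveLens

# Simulation grid: 512x512 samples, 2 um pixel pitch, 660 nm light.
R, C = 512, 512
pitch = 2e-6
wvl = 660e-9
device = 'cpu'

# The default Light is a unit-amplitude, zero-phase plane wave.
light = Light(R, C, pitch, wvl, device)

# A thin refractive lens (10 mm focal length) modulates the wavefront.
lens = RefractiveLens(R, C, pitch, focal_length=10e-3, wvl=wvl, device=device)
light = lens.forward(light)

light.visualize()  # amplitude / phase / intensity panels
```
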
/pado-main/docs/images/logo.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/SeeThroughObstructions/f94a297cc87d72dcb17ffdab3bed64cff186ab9f/pado-main/docs/images/logo.pdf
--------------------------------------------------------------------------------
/pado-main/docs/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/SeeThroughObstructions/f94a297cc87d72dcb17ffdab3bed64cff186ab9f/pado-main/docs/images/logo.png
--------------------------------------------------------------------------------
/pado-main/docs/images/logo.svg:
--------------------------------------------------------------------------------
1 |
2 |
121 |
--------------------------------------------------------------------------------
/pado-main/pado/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
--------------------------------------------------------------------------------
/pado-main/pado/cmap_phase.txt:
--------------------------------------------------------------------------------
1 | 0.180392 0.129412 0.917647
2 | 0.188235 0.121569 0.921569
3 | 0.200000 0.117647 0.925490
4 | 0.215686 0.109804 0.933333
5 | 0.227451 0.105882 0.937255
6 | 0.243137 0.101961 0.937255
7 | 0.254902 0.101961 0.941176
8 | 0.270588 0.101961 0.945098
9 | 0.286275 0.101961 0.949020
10 | 0.298039 0.105882 0.949020
11 | 0.313725 0.109804 0.952941
12 | 0.325490 0.113725 0.952941
13 | 0.341176 0.117647 0.956863
14 | 0.352941 0.121569 0.956863
15 | 0.364706 0.129412 0.960784
16 | 0.376471 0.133333 0.960784
17 | 0.388235 0.141176 0.960784
18 | 0.400000 0.145098 0.964706
19 | 0.411765 0.152941 0.964706
20 | 0.423529 0.156863 0.964706
21 | 0.435294 0.164706 0.968627
22 | 0.447059 0.168627 0.968627
23 | 0.458824 0.176471 0.972549
24 | 0.470588 0.180392 0.972549
25 | 0.478431 0.184314 0.972549
26 | 0.490196 0.192157 0.972549
27 | 0.501961 0.196078 0.976471
28 | 0.513725 0.200000 0.976471
29 | 0.521569 0.207843 0.976471
30 | 0.533333 0.211765 0.980392
31 | 0.545098 0.215686 0.980392
32 | 0.556863 0.219608 0.980392
33 | 0.568627 0.223529 0.980392
34 | 0.580392 0.227451 0.980392
35 | 0.592157 0.231373 0.980392
36 | 0.603922 0.235294 0.984314
37 | 0.615686 0.239216 0.984314
38 | 0.627451 0.239216 0.984314
39 | 0.643137 0.243137 0.984314
40 | 0.654902 0.247059 0.984314
41 | 0.666667 0.247059 0.984314
42 | 0.682353 0.250980 0.984314
43 | 0.694118 0.250980 0.984314
44 | 0.705882 0.254902 0.984314
45 | 0.717647 0.254902 0.984314
46 | 0.733333 0.258824 0.980392
47 | 0.745098 0.258824 0.980392
48 | 0.756863 0.262745 0.980392
49 | 0.772549 0.262745 0.980392
50 | 0.784314 0.266667 0.980392
51 | 0.796078 0.266667 0.980392
52 | 0.807843 0.270588 0.980392
53 | 0.819608 0.270588 0.976471
54 | 0.831373 0.274510 0.976471
55 | 0.843137 0.274510 0.976471
56 | 0.854902 0.278431 0.976471
57 | 0.866667 0.282353 0.972549
58 | 0.878431 0.286275 0.972549
59 | 0.886275 0.290196 0.968627
60 | 0.898039 0.294118 0.964706
61 | 0.905882 0.301961 0.964706
62 | 0.913725 0.309804 0.960784
63 | 0.921569 0.313725 0.956863
64 | 0.929412 0.321569 0.952941
65 | 0.937255 0.333333 0.945098
66 | 0.941176 0.341176 0.941176
67 | 0.949020 0.349020 0.933333
68 | 0.952941 0.360784 0.925490
69 | 0.956863 0.372549 0.921569
70 | 0.960784 0.384314 0.913725
71 | 0.964706 0.396078 0.901961
72 | 0.968627 0.407843 0.894118
73 | 0.968627 0.419608 0.886275
74 | 0.972549 0.431373 0.878431
75 | 0.976471 0.443137 0.866667
76 | 0.976471 0.454902 0.858824
77 | 0.976471 0.466667 0.847059
78 | 0.980392 0.478431 0.839216
79 | 0.980392 0.494118 0.827451
80 | 0.980392 0.505882 0.815686
81 | 0.984314 0.517647 0.807843
82 | 0.984314 0.529412 0.796078
83 | 0.984314 0.537255 0.784314
84 | 0.984314 0.549020 0.772549
85 | 0.984314 0.560784 0.764706
86 | 0.984314 0.572549 0.752941
87 | 0.984314 0.584314 0.741176
88 | 0.984314 0.596078 0.729412
89 | 0.984314 0.603922 0.721569
90 | 0.984314 0.615686 0.709804
91 | 0.984314 0.627451 0.698039
92 | 0.984314 0.635294 0.686275
93 | 0.984314 0.647059 0.678431
94 | 0.984314 0.658824 0.666667
95 | 0.984314 0.666667 0.654902
96 | 0.984314 0.678431 0.643137
97 | 0.984314 0.686275 0.631373
98 | 0.984314 0.698039 0.619608
99 | 0.984314 0.705882 0.611765
100 | 0.984314 0.717647 0.600000
101 | 0.984314 0.725490 0.588235
102 | 0.984314 0.733333 0.576471
103 | 0.984314 0.745098 0.564706
104 | 0.984314 0.752941 0.552941
105 | 0.984314 0.760784 0.541176
106 | 0.984314 0.768627 0.529412
107 | 0.984314 0.780392 0.517647
108 | 0.984314 0.788235 0.505882
109 | 0.984314 0.796078 0.494118
110 | 0.988235 0.803922 0.482353
111 | 0.988235 0.811765 0.466667
112 | 0.988235 0.823529 0.454902
113 | 0.988235 0.831373 0.443137
114 | 0.988235 0.839216 0.431373
115 | 0.988235 0.847059 0.415686
116 | 0.984314 0.854902 0.403922
117 | 0.984314 0.862745 0.388235
118 | 0.984314 0.870588 0.372549
119 | 0.984314 0.878431 0.360784
120 | 0.980392 0.886275 0.345098
121 | 0.980392 0.894118 0.329412
122 | 0.976471 0.898039 0.313725
123 | 0.972549 0.905882 0.298039
124 | 0.968627 0.909804 0.282353
125 | 0.964706 0.917647 0.266667
126 | 0.960784 0.921569 0.250980
127 | 0.952941 0.925490 0.235294
128 | 0.949020 0.929412 0.223529
129 | 0.941176 0.929412 0.207843
130 | 0.933333 0.933333 0.192157
131 | 0.925490 0.933333 0.180392
132 | 0.917647 0.933333 0.168627
133 | 0.909804 0.933333 0.156863
134 | 0.898039 0.933333 0.145098
135 | 0.890196 0.929412 0.133333
136 | 0.878431 0.929412 0.125490
137 | 0.866667 0.925490 0.117647
138 | 0.858824 0.921569 0.109804
139 | 0.847059 0.921569 0.105882
140 | 0.835294 0.917647 0.101961
141 | 0.823529 0.913725 0.098039
142 | 0.811765 0.909804 0.094118
143 | 0.800000 0.905882 0.090196
144 | 0.788235 0.901961 0.086275
145 | 0.776471 0.898039 0.086275
146 | 0.768627 0.894118 0.082353
147 | 0.756863 0.890196 0.082353
148 | 0.745098 0.886275 0.078431
149 | 0.733333 0.878431 0.078431
150 | 0.721569 0.874510 0.078431
151 | 0.709804 0.870588 0.074510
152 | 0.698039 0.866667 0.074510
153 | 0.686275 0.862745 0.074510
154 | 0.674510 0.858824 0.070588
155 | 0.662745 0.854902 0.070588
156 | 0.650980 0.850980 0.070588
157 | 0.639216 0.847059 0.066667
158 | 0.627451 0.843137 0.066667
159 | 0.611765 0.835294 0.066667
160 | 0.600000 0.831373 0.062745
161 | 0.588235 0.827451 0.062745
162 | 0.576471 0.823529 0.062745
163 | 0.564706 0.819608 0.058824
164 | 0.552941 0.815686 0.058824
165 | 0.541176 0.811765 0.058824
166 | 0.525490 0.807843 0.054902
167 | 0.513725 0.803922 0.054902
168 | 0.501961 0.796078 0.054902
169 | 0.490196 0.792157 0.050980
170 | 0.474510 0.788235 0.050980
171 | 0.462745 0.784314 0.050980
172 | 0.450980 0.780392 0.047059
173 | 0.435294 0.776471 0.047059
174 | 0.423529 0.772549 0.047059
175 | 0.407843 0.764706 0.047059
176 | 0.396078 0.760784 0.043137
177 | 0.380392 0.756863 0.043137
178 | 0.364706 0.752941 0.043137
179 | 0.352941 0.749020 0.047059
180 | 0.337255 0.745098 0.047059
181 | 0.321569 0.737255 0.047059
182 | 0.309804 0.733333 0.050980
183 | 0.294118 0.729412 0.054902
184 | 0.278431 0.725490 0.062745
185 | 0.266667 0.717647 0.066667
186 | 0.250980 0.713725 0.074510
187 | 0.239216 0.709804 0.086275
188 | 0.227451 0.701961 0.094118
189 | 0.219608 0.698039 0.105882
190 | 0.207843 0.694118 0.117647
191 | 0.203922 0.686275 0.129412
192 | 0.196078 0.682353 0.141176
193 | 0.192157 0.674510 0.156863
194 | 0.192157 0.670588 0.168627
195 | 0.192157 0.662745 0.184314
196 | 0.196078 0.654902 0.200000
197 | 0.200000 0.650980 0.211765
198 | 0.203922 0.643137 0.227451
199 | 0.211765 0.635294 0.243137
200 | 0.215686 0.627451 0.258824
201 | 0.223529 0.619608 0.274510
202 | 0.227451 0.615686 0.290196
203 | 0.235294 0.607843 0.301961
204 | 0.239216 0.600000 0.317647
205 | 0.243137 0.592157 0.333333
206 | 0.250980 0.584314 0.349020
207 | 0.254902 0.576471 0.360784
208 | 0.258824 0.568627 0.376471
209 | 0.262745 0.564706 0.392157
210 | 0.262745 0.556863 0.403922
211 | 0.266667 0.549020 0.419608
212 | 0.266667 0.541176 0.431373
213 | 0.270588 0.533333 0.447059
214 | 0.270588 0.525490 0.458824
215 | 0.270588 0.517647 0.474510
216 | 0.270588 0.509804 0.486275
217 | 0.266667 0.501961 0.501961
218 | 0.266667 0.498039 0.513725
219 | 0.262745 0.490196 0.525490
220 | 0.262745 0.482353 0.541176
221 | 0.258824 0.474510 0.552941
222 | 0.254902 0.466667 0.564706
223 | 0.250980 0.458824 0.580392
224 | 0.243137 0.450980 0.592157
225 | 0.239216 0.443137 0.603922
226 | 0.235294 0.435294 0.615686
227 | 0.227451 0.427451 0.627451
228 | 0.223529 0.419608 0.643137
229 | 0.215686 0.411765 0.654902
230 | 0.207843 0.403922 0.666667
231 | 0.203922 0.396078 0.678431
232 | 0.196078 0.388235 0.690196
233 | 0.192157 0.380392 0.701961
234 | 0.184314 0.372549 0.709804
235 | 0.180392 0.364706 0.721569
236 | 0.176471 0.352941 0.733333
237 | 0.172549 0.345098 0.745098
238 | 0.168627 0.333333 0.752941
239 | 0.164706 0.325490 0.764706
240 | 0.160784 0.313725 0.772549
241 | 0.156863 0.305882 0.784314
242 | 0.156863 0.294118 0.792157
243 | 0.152941 0.282353 0.803922
244 | 0.149020 0.274510 0.811765
245 | 0.149020 0.262745 0.823529
246 | 0.145098 0.250980 0.831373
247 | 0.145098 0.239216 0.839216
248 | 0.141176 0.227451 0.850980
249 | 0.141176 0.215686 0.858824
250 | 0.141176 0.203922 0.866667
251 | 0.141176 0.192157 0.874510
252 | 0.145098 0.180392 0.882353
253 | 0.149020 0.168627 0.890196
254 | 0.152941 0.160784 0.898039
255 | 0.160784 0.149020 0.905882
256 | 0.168627 0.141176 0.909804
257 |
--------------------------------------------------------------------------------
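
`cmap_phase.txt` above is a 256-entry RGB lookup table with values in [0, 1], presumably used to render wrapped phase. As a hedged sketch (the relative path is an assumption), it can be loaded into a matplotlib colormap like so:

```python
import numpy as np
from matplotlib.colors import ListedColormap

# 256 rows of whitespace-separated "R G B" floats in [0, 1].
rgb = np.loadtxt('pado-main/pado/cmap_phase.txt')
cmap_phase = ListedColormap(rgb, name='pado_phase')
# e.g. plt.imshow(phase, cmap=cmap_phase, vmin=-np.pi, vmax=np.pi)
```
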
/pado-main/pado/conv.py:
--------------------------------------------------------------------------------
1 | from .fourier import fft, ifft
2 |
3 | def conv_fft(img_c, kernel_c, pad_width=None):
4 | """
5 | Compute the convolution of an image with a convolution kernel using FFT
6 | Args:
7 | img_c: [B,Ch,R,C] image as a complex tensor
8 | kernel_c: [B,Ch,R,C] convolution kernel as a complex tensor
9 |         pad_width: (tuple) pad width for the last two spatial dimensions. use None (or all zeros) for circular convolution; for linear convolution, zero-pad by the size of the original image
10 |     Returns:
11 |         im_conv: [B,Ch,R,C] convolved image
12 | """
13 |
14 | img_fft = fft(img_c, pad_width=pad_width)
15 | kernel_fft = fft(kernel_c, pad_width=pad_width)
16 | return ifft( img_fft * kernel_fft, pad_width=pad_width)
17 |
18 |
--------------------------------------------------------------------------------
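
The `pad_width` argument is what switches `conv_fft` between circular and linear convolution: multiplying unpadded spectra wraps the result around the image borders, while zero-padding before the FFT recovers the linear result. A standalone PyTorch illustration of the same identity, independent of pado's `Complex` wrapper:

```python
import torch

img = torch.tensor([1., 2., 3., 4.])
ker = torch.tensor([1., 1., 0., 0.])

# Unpadded spectra multiply -> circular convolution (index 0 wraps: 1*1 + 4*1).
circular = torch.fft.ifft(torch.fft.fft(img) * torch.fft.fft(ker)).real
# tensor([5., 3., 5., 7.])

# Zero-padding to len(signal) + len(kernel support) - 1 -> linear convolution.
N = img.numel() + 2 - 1
linear = torch.fft.ifft(torch.fft.fft(img, n=N) * torch.fft.fft(ker[:2], n=N)).real
# tensor([1., 3., 5., 7., 4.])
```
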
/pado-main/pado/fourier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .complex import Complex
3 | import matplotlib.pyplot as plt
4 |
5 |
6 | def fft(arr_c, normalized=False, pad_width=None, padval=0, shift=True):
7 | """
8 | Compute the Fast Fourier transform of a complex tensor
9 | Args:
10 | arr_c: [B,Ch,R,C] complex tensor
11 | normalized: Normalize the FFT output. default: False
12 | pad_width: (tensor) pad width for the last spatial dimensions.
13 |         shift: flag for shifting the input data so that the zero-frequency component is located at the center of the array
14 | Returns:
15 | arr_c_fft: [B,Ch,R,C] FFT of the input complex tensor
16 | """
17 |
18 | arr_c = Complex(mag=arr_c.get_mag().clone(), ang=arr_c.get_ang().clone())
19 | if pad_width is not None:
20 | if padval == 0:
21 | arr_c.pad_zero(pad_width)
22 | else:
23 |             raise NotImplementedError('only zero padding is implemented for now')
24 | if shift:
25 | arr_c_shifted = ifftshift(arr_c)
26 | else:
27 | arr_c_shifted = arr_c
28 | arr_c_shifted.to_rect()
29 |
30 | #arr_c_shifted_stack = arr_c_shifted.get_stack()
31 | if normalized is False:
32 | normalized = "backward"
33 | else:
34 | normalized = "forward"
35 | arr_c_shifted_fft = torch.fft.fft2(arr_c_shifted.get_native() , norm=normalized)
36 | arr_c_shifted_fft_c = Complex(real=arr_c_shifted_fft.real, imag=arr_c_shifted_fft.imag)
37 | if shift:
38 | arr_c_fft = fftshift(arr_c_shifted_fft_c)
39 | else:
40 | arr_c_fft = arr_c_shifted_fft_c
41 |
42 | return arr_c_fft
43 |
44 |
45 |
46 | def ifft(arr_c, normalized=False, pad_width=None, shift=True):
47 | """
48 | Compute the inverse Fast Fourier transform of a complex tensor
49 | Args:
50 | arr_c: [B,Ch,R,C] complex tensor
51 | normalized: Normalize the FFT output. default: False
52 | pad_width: (tensor) pad width for the last spatial dimensions.
53 | shift: flag for inversely shifting the input data
54 | Returns:
55 | arr_c_fft: [B,Ch,R,C] inverse FFT of the input complex tensor
56 | """
57 |
58 | arr_c = Complex(mag=arr_c.get_mag().clone(), ang=arr_c.get_ang().clone())
59 | if shift:
60 | arr_c_shifted = ifftshift(arr_c)
61 | else:
62 | arr_c_shifted = arr_c
63 |
64 | arr_c_shifted.to_rect()
65 | if normalized is False:
66 | normalized = "backward"
67 | else:
68 | normalized = "forward"
69 | arr_c_shifted_fft = torch.fft.ifft2(arr_c_shifted.get_native(), norm=normalized)
70 | arr_c_shifted_fft_c = Complex(real=arr_c_shifted_fft.real, imag=arr_c_shifted_fft.imag)
71 | if shift:
72 | arr_c_fft = fftshift(arr_c_shifted_fft_c)
73 | else:
74 | arr_c_fft = arr_c_shifted_fft_c
75 |
76 | if pad_width is not None:
77 | arr_c_fft.crop(pad_width)
78 |
79 | return arr_c_fft
80 |
81 | def fftshift(arr_c, invert=False):
82 | """
83 |     Shift the complex tensor so that the zero-frequency signal is located at the center of the input
84 | Args:
85 | arr_c: [B,Ch,R,C] complex tensor
86 | invert: flag for inversely shifting the input data
87 | Returns:
88 | arr_c: [B,Ch,R,C] shifted tensor
89 | """
90 |
91 | arr_c.to_rect()
92 | shift_adjust = 0 if invert else 1
93 |
94 | arr_c_shape = arr_c.shape()
95 | C = arr_c_shape[-1]
96 | R = arr_c_shape[-2]
97 |
98 | shift_len = (C + shift_adjust) // 2
99 | arr_c = arr_c[...,shift_len:].cat(arr_c[...,:shift_len], -1)
100 |
101 | shift_len = (R + shift_adjust) // 2
102 | arr_c = arr_c[...,shift_len:,:].cat(arr_c[...,:shift_len,:], -2)
103 |
104 | return arr_c
105 |
106 |
107 | def ifftshift(arr_c):
108 | """
109 | Inversely shift the complex tensor
110 | Args:
111 | arr_c: [B,Ch,R,C] complex tensor
112 | Returns:
113 | arr_c: [B,Ch,R,C] shifted tensor
114 | """
115 |
116 | return fftshift(arr_c, invert=True)
117 |
--------------------------------------------------------------------------------
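
The shift sandwich in `fft` above (ifftshift → fft2 → fftshift) keeps both the spatial origin and the zero-frequency component at the array center, which is the convention the rest of the wave-optics code expects. A standalone check of that convention with plain torch tensors:

```python
import torch

x = torch.zeros(8, 8)
x[4, 4] = 1.0  # impulse at the array center (the optical axis)

# Centered-FFT convention used by pado's fft():
X = torch.fft.fftshift(torch.fft.fft2(torch.fft.ifftshift(x)))

# A centered impulse transforms to a flat (all-ones) spectrum.
assert torch.allclose(X, torch.ones(8, 8, dtype=X.dtype))
```
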
/pado-main/pado/light.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from .complex import Complex
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | from scipy.io import savemat
7 |
8 | class Light:
9 | def __init__(self, R, C, pitch, wvl, device, amplitude=None, phase=None, real=None, imag=None, B=1):
10 | """
11 | Light wave that has a complex field (B,Ch,R,C) as a wavefront
12 | It takes the input wavefront in one of the following types
13 | 1. amplitude and phase
14 | 2. real and imaginary
15 | 3. None ==> we initialize light with amplitude of one and phase of zero
16 |
17 | Args:
18 | R: row
19 | C: column
20 | pitch: pixel pitch in meter
21 | wvl: wavelength of light in meter
22 | device: device to store the wavefront of light. 'cpu', 'cuda:0', ...
23 | amplitude: [batch_size, # of channels, row, column] tensor of wavefront amplitude, default is None
24 |         phase: [batch_size, # of channels, row, column] tensor of wavefront phase, default is None
25 | real: [batch_size, # of channels, row, column] tensor of wavefront real part, default is None
26 | imag: [batch_size, # of channels, row, column] tensor of wavefront imaginary part, default is None
27 | B: batch size
28 | """
29 |
30 | self.B = B
31 | self.R = R
32 | self.C = C
33 | self.pitch = pitch
34 | self.device = device
35 | self.wvl = wvl
36 |
37 |         if (amplitude is None) and (phase is None) and (real is not None) and (imag is not None):
38 |             self.field = Complex(real=real, imag=imag)
39 | 
40 |         elif (amplitude is not None) and (phase is not None) and (real is None) and (imag is None):
41 |             self.field = Complex(mag=amplitude, ang=phase)
42 | 
43 |         elif (amplitude is None) and (phase is None) and (real is None) and (imag is None):
44 |             amplitude = torch.ones((B, 1, self.R, self.C), device=self.device)
45 |             phase = torch.zeros((B, 1, self.R, self.C), device=self.device)
46 |             self.field = Complex(mag=amplitude, ang=phase)
47 | 
48 |         else:
49 |             raise NotImplementedError('provide either amplitude/phase, real/imag, or nothing')
50 |
51 |
52 | def crop(self, crop_width):
53 | """
54 | Crop the light wavefront by crop_width
55 | Args:
56 | crop_width: (tuple) crop width of the tensor following torch functional pad
57 | """
58 |
59 | self.field.crop(crop_width)
60 | self.R = self.field.size(2)
61 | self.C = self.field.size(3)
62 |
63 | def clone(self):
64 | """
65 | Clone the light and return it
66 | """
67 |
68 | return Light(self.R, self.C,
69 | self.pitch, self.wvl, self.device,
70 | amplitude=self.field.get_mag().clone(), phase=self.field.get_ang().clone(),
71 | B=self.B)
72 |
73 |
74 | def pad(self, pad_width, padval=0):
75 | """
76 | Pad the light wavefront with a constant value by pad_width
77 | Args:
78 | pad_width: (tuple) pad width of the tensor following torch functional pad
79 | padval: value to pad. default is zero
80 | """
81 |
82 | if padval == 0:
83 | self.set_amplitude(torch.nn.functional.pad(self.get_amplitude(), pad_width))
84 | self.set_phase(torch.nn.functional.pad(self.get_phase(), pad_width))
85 | else:
86 |             raise NotImplementedError('only zero padding supported')
87 |
88 |         self.R += pad_width[2] + pad_width[3]  # F.pad packs the last (column) dim first: (left, right, top, bottom)
89 |         self.C += pad_width[0] + pad_width[1]
90 |
91 | def set_real(self, real):
92 | """
93 | Set the real part of the light wavefront
94 | Args:
95 | real: real part in the rect representation of the complex number
96 | """
97 |
98 | self.field.set_real(real)
99 |
100 | def set_imag(self, imag):
101 | """
102 | Set the imaginary part of the light wavefront
103 | Args:
104 | imag: imaginary part in the rect representation of the complex number
105 | """
106 |
107 | self.field.set_imag(imag)
108 |
109 | def set_amplitude(self, amplitude):
110 | """
111 | Set the amplitude of the light wavefront
112 | Args:
113 | amplitude: amplitude in the polar representation of the complex number
114 | """
115 | self.field.set_mag(amplitude)
116 |
117 | def set_phase(self, phase):
118 | """
119 | Set the phase of the complex tensor
120 | Args:
121 | phase: phase in the polar representation of the complex number
122 | """
123 | self.field.set_ang(phase)
124 |
125 |
126 | def set_field(self, field):
127 | """
128 | Set the wavefront modulation of the complex tensor
129 | Args:
130 | field: wavefront as a complex number
131 | """
132 | self.field = field
133 |
134 | def set_pitch(self, pitch):
135 | """
136 | Set the pixel pitch of the complex tensor
137 | Args:
138 | pitch: pixel pitch in meter
139 | """
140 | self.pitch = pitch
141 |
142 |
143 | def get_amplitude(self):
144 | """
145 | Return the amplitude of the wavefront
146 | Returns:
147 | mag: magnitude in the polar representation of the complex number
148 | """
149 |
150 | return self.field.get_mag()
151 |
152 | def get_phase(self):
153 | """
154 | Return the phase of the wavefront
155 | Returns:
156 | ang: angle in the polar representation of the complex number
157 | """
158 |
159 | return self.field.get_ang()
160 |
161 | def get_field(self):
162 | """
163 | Return the complex wavefront
164 | Returns:
165 | field: complex wavefront
166 | """
167 |
168 | return self.field
169 |
170 | def get_intensity(self):
171 | """
172 | Return the intensity of light wavefront
173 | Returns:
174 | intensity: intensity of light
175 | """
176 | return self.field.get_intensity()
177 |
178 | def get_bandwidth(self):
179 | """
180 | Return the bandwidth of light wavefront
181 | Returns:
182 | R_m: spatial height of the wavefront
183 | C_m: spatial width of the wavefront
184 | """
185 |
186 | return self.pitch*self.R, self.pitch*self.C
187 |
188 | def magnify(self, scale_factor, interp_mode='nearest'):
189 | '''
190 | Change the wavefront resolution without changing the pixel pitch
191 | Args:
192 |             scale_factor: scale factor for interpolation used in torch.nn.functional.interpolate
193 | interp_mode: interpolation method used in torch.nn.functional.interpolate 'bilinear', 'nearest'
194 | '''
195 | self.field.resize(scale_factor, interp_mode)
196 | self.R = self.field.mag.shape[-2]
197 | self.C = self.field.mag.shape[-1]
198 |
199 |
200 | def resize(self, target_pitch, interp_mode='nearest'):
201 | '''
202 | Resize the wavefront by changing the pixel pitch.
203 | Args:
204 | target_pitch: new pixel pitch to use
205 | interp_mode: interpolation method used in torch.nn.functional.interpolate 'bilinear', 'nearest'
206 | '''
207 | scale_factor = self.pitch / target_pitch
208 | self.magnify(scale_factor, interp_mode)
209 | self.set_pitch(target_pitch)
210 |
211 | def set_spherical_light(self, z, dx=0, dy=0):
212 | '''
213 | Set the wavefront as spherical one coming from the position of (dx,dy,z).
214 | Args:
215 | z: z distance of the spherical light source from the current light position
216 | dx: x distance of the spherical light source from the current light position
217 | dy: y distance of the spherical light source from the current light position
218 | '''
219 |
220 |         [y, x] = np.mgrid[-self.R // 2:self.R // 2, -self.C // 2:self.C // 2].astype(np.float64)  # (R,C) grids; rows first
221 | x = x * self.pitch
222 | y = y * self.pitch
223 | r = np.sqrt((x - dx) ** 2 + (y - dy) ** 2 + z ** 2) # this is computed in double precision
224 | theta = 2 * np.pi * r / self.wvl
225 | theta = np.expand_dims(np.expand_dims(theta, axis=0), axis=0)%(2*np.pi)
226 | theta = theta.astype(np.float32)
227 |
228 | theta = torch.tensor(theta, device=self.device)
229 | mag = torch.ones_like(theta)
230 |
231 | self.set_phase(theta)
232 | self.set_amplitude(mag)
233 |
234 | def set_plane_light(self):
235 | '''
236 |         Set the wavefront as a plane wave with zero phase and amplitude of one
237 | '''
238 | amplitude = torch.ones((1, 1, self.R, self.C), device=self.device)
239 | phase = torch.zeros((1, 1, self.R, self.C), device=self.device)
240 | self.set_amplitude(amplitude)
241 | self.set_phase(phase)
242 |
243 | def save(self, fn):
244 | '''
245 | Save the amplitude and phase of the light wavefront as a file
246 | Args:
247 | fn: filename to save. the format should be either "npy" or "mat"
248 |
249 | '''
250 |
251 | amp = self.get_amplitude().data.cpu().numpy()
252 | phase = self.get_phase().data.cpu().numpy()
253 |
254 |         if fn[-3:] == 'npy':
255 |             np.save(fn, np.stack([amp, phase], axis=0))  # np.save stores a single array; stack amplitude and phase
256 | elif fn[-3:] == 'mat':
257 | savemat(fn, {'amplitude':amp, 'phase':phase})
258 |         else:
259 |             raise ValueError('unknown extension in %s; use .npy or .mat' % fn)
260 |         print('light saved to %s\n'%fn)
261 |
262 | def visualize(self,b=0,c=0):
263 | """
264 | Visualize the light wave
265 | Args:
266 | b: batch index to visualize default is 0
267 | c: channel index to visualize. default is 0
268 | """
269 |
270 | bw = self.get_bandwidth()
271 |
272 | plt.figure(figsize=(20,5))
273 | plt.subplot(131)
274 | amplitude_b = self.get_amplitude().data.cpu()[b,c,...].squeeze()
275 |         plt.imshow(amplitude_b,
276 |                    extent=[0, bw[1]*1e3, 0, bw[0]*1e3], cmap='inferno')
277 | plt.title('amplitude')
278 | plt.xlabel('mm')
279 | plt.ylabel('mm')
280 | plt.colorbar()
281 |
282 | plt.subplot(132)
283 |         phase_b = self.get_phase().data.cpu()[b,c,...].squeeze()
284 |         plt.imshow(phase_b,
285 |                    extent=[0, bw[1]*1e3, 0, bw[0]*1e3], cmap='hsv', vmin=-np.pi, vmax=np.pi) # cyclic colormap
286 | plt.title('phase')
287 | plt.xlabel('mm')
288 | plt.ylabel('mm')
289 | plt.colorbar()
290 |
291 | plt.subplot(133)
292 | intensity_b = self.get_intensity().data.cpu()[b,c,...].squeeze()
293 |         plt.imshow(intensity_b,
294 |                    extent=[0, bw[1]*1e3, 0, bw[0]*1e3], cmap='inferno')
295 | plt.title('intensity')
296 | plt.xlabel('mm')
297 | plt.ylabel('mm')
298 | plt.colorbar()
299 |
300 | plt.suptitle('(%d,%d), pitch:%.2f[um], wvl:%.2f[nm], device:%s'%(self.R, self.C,
301 | self.pitch/1e-6, self.wvl/1e-9, self.device))
302 | plt.show()
303 |
304 | def shape(self):
305 | """
306 | Returns the shape of light wavefront
307 | Returns:
308 | shape
309 | """
310 | return self.field.shape()
311 |
--------------------------------------------------------------------------------
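
A short, hedged usage sketch of the `Light` class above (hypothetical values; the import path assumes the package layout shown in this listing). The spherical-wave phase set by `set_spherical_light` is just the optical path length 2πr/λ wrapped to [0, 2π):

```python
from pado.light import Light

# 1024x1024 wavefront, 2 um pitch, 532 nm light.
light = Light(R=1024, C=1024, pitch=2e-6, wvl=532e-9, device='cpu')

# Diverging spherical wave from an on-axis point source 50 mm away.
light.set_spherical_light(z=50e-3)

light.save('spherical.mat')  # writes amplitude and phase via scipy.io.savemat
light.visualize()
```
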
/pado-main/pado/material.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class Material:
5 | def __init__(self, material_name):
6 | """
7 |         Material of an optical element, used to query its refractive index.
8 |         Args:
9 |             material_name: name of the material. So far, we provide PDMS, FUSED_SILICA, VACUUM.
10 | """
11 | self.material_name = material_name
12 |
13 | def get_RI(self, wvl):
14 | """
15 | Return the refractive index of the current material for a wavelength
16 | Args:
17 | wvl: wavelength in meter
18 | Returns:
19 | RI: Refractive index at wvl
20 | """
21 |
22 | wvl_nm = wvl / 1e-9
23 | if self.material_name == 'PDMS':
24 | RI = np.sqrt(1 + (1.0057 * (wvl_nm**2))/(wvl_nm**2 - 0.013217))
25 | elif self.material_name == 'FUSED_SILICA':
26 | wvl_um = wvl_nm*1e-3
27 | RI = (1 + 0.6961663 / (1 - (0.0684043 / wvl_um) ** 2) + 0.4079426 / (1 - (0.1162414 / wvl_um) ** 2) + 0.8974794 / (1 - (9.896161 / wvl_um) ** 2)) ** .5
28 | # https://refractiveindex.info/?shelf=glass&book=fused_silica&page=Malitson
29 | elif self.material_name == 'VACUUM':
30 | RI = 1.0
31 | else:
32 |             raise NotImplementedError('%s is not in the RI list'%self.material_name)
33 |
34 | return RI
35 |
36 |
--------------------------------------------------------------------------------
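
The dispersion fits in `get_RI` above can be sanity-checked against tabulated values; for the Malitson fused-silica fit referenced in the code, the index at 532 nm should come out near 1.46:

```python
from pado.material import Material  # import path assumed from this listing

silica = Material('FUSED_SILICA')
print(silica.get_RI(532e-9))  # ~1.4607 for the Malitson Sellmeier fit
```
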
/pado-main/pado/optical_element.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | import matplotlib.pyplot as plt
5 |
6 |
7 | class OpticalElement:
8 | def __init__(self, R, C, pitch, wvl, device, name="not defined",B=1):
9 | """
10 |         Base class for optical elements. Any optical element changes the wavefront of incident light.
11 |         The change of the wavefront is stored as amplitude and phase tensors.
12 |         Note that the number of channels is one for the wavefront modulation.
13 | Args:
14 | R: row
15 | C: column
16 | pitch: pixel pitch in meter
17 | wvl: wavelength of light in meter
18 | device: device to store the wavefront of light. 'cpu', 'cuda:0', ...
19 | name: name of the current optical element
20 | B: batch size
21 | """
22 |
23 | self.name = name
24 | self.B = B
25 | self.R = R
26 | self.C = C
27 | self.pitch = pitch
28 | self.device = device
29 | self.amplitude_change = torch.ones((B, 1, R, C), device=self.device)
30 | self.phase_change = torch.zeros((B, 1, R, C), device=self.device)
31 | self.wvl = wvl
32 |
33 | def shape(self):
34 | """
35 |         Returns the shape of light-wavefront modulation. The number of channels is one
36 | Returns:
37 | shape
38 | """
39 | return (self.B,1,self.R,self.C)
40 |
41 | def set_pitch(self, pitch):
42 | """
43 | Set the pixel pitch of the complex tensor
44 | Args:
45 | pitch: pixel pitch in meter
46 | """
47 | self.pitch = pitch
48 |
49 | def resize(self, target_pitch, interp_mode='nearest'):
50 | '''
51 | Resize the wavefront change by changing the pixel pitch.
52 | Args:
53 | target_pitch: new pixel pitch to use
54 | interp_mode: interpolation method used in torch.nn.functional.interpolate 'bilinear', 'nearest'
55 | '''
56 |
57 | scale_factor = self.pitch / target_pitch
58 | self.amplitude_change = F.interpolate(self.amplitude_change, scale_factor=scale_factor,
59 | mode=interp_mode)
60 | self.phase_change = F.interpolate(self.phase_change, scale_factor=scale_factor,
61 | mode=interp_mode)
62 | self.set_pitch(target_pitch)
63 | self.R = self.amplitude_change.shape[-2]
64 | self.C = self.amplitude_change.shape[-1]
65 |
66 | def get_amplitude_change(self):
67 | '''
68 | Return the amplitude change of the wavefront
69 | Returns:
70 |             amplitude_change: amplitude change
71 | '''
72 |
73 | return self.amplitude_change
74 |
75 | def get_phase_change(self):
76 | '''
77 | Return the phase change of the wavefront
78 | Returns:
79 | phase change: phase change
80 | '''
81 |
82 | return self.phase_change
83 |
84 | def set_amplitude_change(self, amplitude):
85 | """
86 | Set the amplitude change
87 | Args:
88 |             amplitude: amplitude change in the polar representation of the complex number
89 | """
90 |
91 | assert amplitude.shape[2] == self.R and amplitude.shape[3] == self.C
92 | self.amplitude_change = amplitude
93 |
94 | def set_phase_change(self, phase):
95 | """
96 | Set the phase change
97 | Args:
98 |             phase: phase change in the polar representation of the complex number
99 | """
100 |
101 | assert phase.shape[2] == self.R and phase.shape[3] == self.C
102 | self.phase_change = phase
103 |
104 | def pad(self, pad_width, padval=0):
105 | """
106 | Pad the wavefront change with a constant value by pad_width
107 | Args:
108 | pad_width: (tuple) pad width of the tensor following torch functional pad
109 | padval: value to pad. default is zero
110 | """
111 | if padval == 0:
112 | self.amplitude_change = torch.nn.functional.pad(self.get_amplitude_change(), pad_width)
113 | self.phase_change = torch.nn.functional.pad(self.get_phase_change(), pad_width)
114 | else:
115 |             raise NotImplementedError('only zero padding supported')
116 |
117 |         self.R += pad_width[2] + pad_width[3]  # F.pad packs the last (column) dim first
118 |         self.C += pad_width[0] + pad_width[1]
119 |
120 | def forward(self, light, interp_mode='nearest'):
121 | """
122 | Forward the incident light with the optical element.
123 | Args:
124 | light: incident light
125 | interp_mode: interpolation method used in torch.nn.functional.interpolate 'bilinear', 'nearest'
126 | Returns:
127 | light after interaction with the optical element
128 | """
129 |
130 | if light.pitch > self.pitch:
131 | light.resize(self.pitch, interp_mode)
132 | light.set_pitch(self.pitch)
133 | elif light.pitch < self.pitch:
134 | self.resize(light.pitch, interp_mode)
135 | self.set_pitch(light.pitch)
136 |
137 | if light.wvl != self.wvl:
138 |             raise NotImplementedError('wavelength should be the same for the light and the optical element')
139 |
140 | r1 = np.abs((light.R - self.R)//2)
141 | r2 = np.abs(light.R - self.R) - r1
142 |         pad_width = (0, 0, r1, r2)  # rows: F.pad packs the last (column) dim first
143 | if light.R > self.R:
144 | self.pad(pad_width)
145 | elif light.R < self.R:
146 | light.pad(pad_width)
147 |
148 | c1 = np.abs((light.C - self.C)//2)
149 | c2 = np.abs(light.C - self.C) - c1
150 |         pad_width = (c1, c2, 0, 0)  # columns
151 | if light.C > self.C:
152 | self.pad(pad_width)
153 | elif light.C < self.C:
154 | light.pad(pad_width)
155 |
156 | light.set_phase(light.get_phase() + self.get_phase_change())
157 | light.set_amplitude(light.get_amplitude() * self.get_amplitude_change())
158 |
159 | return light
160 |
161 | def visualize(self, b=0):
162 | """
163 | Visualize the wavefront modulation of the optical element
164 | Args:
165 | b: batch index to visualize default is 0
166 | """
167 |
168 | plt.figure(figsize=(13,6))
169 |
170 | plt.subplot(121)
171 | plt.imshow(self.get_amplitude_change().data.cpu()[b,...].squeeze())
172 | plt.title('amplitude')
173 | plt.colorbar()
174 |
175 | plt.subplot(122)
176 | plt.imshow(self.get_phase_change().data.cpu()[b,...].squeeze())
177 | plt.title('phase')
178 | plt.colorbar()
179 |
180 | plt.suptitle('%s, (%d,%d), pitch:%.2f[um], wvl:%.2f[nm], device:%s'
181 | %(self.name, self.R, self.C, self.pitch/1e-6, self.wvl/1e-9, self.device))
182 | plt.show()
183 |
184 |
185 | class RefractiveLens(OpticalElement):
186 | def __init__(self, R, C, pitch, focal_length, wvl, device):
187 | """
188 | Thin refractive lens
189 | Args:
190 | R: row
191 | C: column
192 | pitch: pixel pitch in meter
193 | focal_length: focal length of the lens in meter
194 | wvl: wavelength of light in meter
195 | device: device to store the wavefront of light. 'cpu', 'cuda:0', ...
196 | """
197 |
198 | super().__init__(R, C, pitch, wvl, device, name="refractive_lens")
199 |
200 | self.set_focal_length(focal_length)
201 | self.set_phase_change( self.compute_phase(self.wvl, shift_x=0, shift_y=0) )
202 |
203 | def set_focal_length(self, focal_length):
204 | """
205 | Set the focal length of the lens
206 | Args:
207 | focal_length: focal length in meter
208 | """
209 |
210 | self.focal_length = focal_length
211 |
212 | def compute_phase(self, wvl, shift_x=0, shift_y=0):
213 | """
214 |         Compute the phase profile of a thin lens
215 | Args:
216 | wvl: wavelength of light in meter
217 | shift_x: x displacement of the lens w.r.t. incident light
218 | shift_y: y displacement of the lens w.r.t. incident light
219 | """
220 |
221 | bw_R = self.R*self.pitch
222 | bw_C = self.C*self.pitch
223 |
224 |         x = np.arange(-bw_C/2, bw_C/2, self.pitch)
225 |         x = x[:self.C]
226 |         y = np.arange(-bw_R/2, bw_R/2, self.pitch)
227 |         y = y[:self.R]
228 | xx,yy = np.meshgrid(x,y)
229 |
230 | theta_change = torch.tensor((-2*np.pi / wvl)*((xx-shift_x)**2 + (yy-shift_y)**2), device=self.device) / (2*self.focal_length)
231 | theta_change = torch.unsqueeze(torch.unsqueeze(theta_change, axis=0), axis=0)
232 | theta_change %= 2*np.pi
233 | theta_change -= np.pi
234 |
235 | return theta_change
236 |
237 | def height2phase(height, wvl, RI, wrap=True):
238 | """
239 | Convert the height of a material to the corresponding phase shift
240 | Args:
241 | height: height of the material in meter
242 | wvl: wavelength of light in meter
243 | RI: refractive index of the material at the wavelength
244 | wrap: return the wrapped phase [0,2pi]
245 | """
246 | dRI = RI - 1
247 | wv_n = 2. * np.pi / wvl
248 | phi = wv_n * dRI * height
249 | if wrap:
250 | phi %= 2 * np.pi
251 | return phi
252 |
253 | def phase2height(phase, wvl, RI):
254 | """
255 | Convert the phase change to the height of a material
256 | Args:
257 | phase: phase change of light
258 | wvl: wavelength of light in meter
259 | RI: refractive index of the material at the wavelength
260 | """
261 | dRI = RI - 1
262 | return wvl * phase / (2 * np.pi) / dRI
263 |
264 | def radius2phase(r, f, wvl):
265 | return (2 * np.pi * (np.sqrt(r * r + f * f) - f) / wvl) % (2 * np.pi)
266 |
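# A hedged round-trip check of the conversions above (values hypothetical,
# assuming RI = 1.46, roughly fused silica in the visible range):
#   phi = height2phase(600e-9, 532e-9, 1.46)   # ~3.26 rad (no wrap needed)
#   h   = phase2height(phi, 532e-9, 1.46)      # recovers ~600e-9 m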
267 | class DOE(OpticalElement):
268 | def __init__(self, R, C, pitch, material, wvl, device, height=None, phase=None, amplitude=None):
269 | """
270 | Diffractive optical element (DOE)
271 | Args:
272 | R: row
273 | C: column
274 | pitch: pixel pitch in meter
275 | material: material of the DOE
276 | wvl: wavelength of light in meter
277 | device: device to store the wavefront of light. 'cpu', 'cuda:0', ...
278 | height: height map of the material in meter
279 | phase: phase change of light
280 | amplitude: amplitude change of light
281 | """
282 |
283 | super().__init__(R, C, pitch, wvl, device, name="doe")
284 |
285 | self.material = material
286 | self.height = None
287 |
288 | if amplitude is None:
289 | amplitude = torch.ones((1, 1, self.R, self.C), device=self.device)
290 |
291 | if height is None and phase is not None:
292 | self.mode = 'phase'
293 | self.set_phase_change(phase, wvl)
294 | self.set_amplitude_change(amplitude)
295 | elif height is not None and phase is None:
296 | self.mode = 'height'
297 | self.set_height(height)
298 | self.set_amplitude_change(amplitude)
299 |         elif (height is None) and (phase is None):  # amplitude is already defaulted above, so don't test it
300 | self.mode = 'phase'
301 | phase = torch.zeros((1, 1, self.R, self.C), device=self.device)
302 | self.set_amplitude_change(amplitude)
303 | self.set_phase_change(phase, wvl)
304 |
305 |
306 | def change_wvl(self, wvl):
307 | """
308 | Change the wavelength of phase change
309 | Args:
310 | wvl: wavelength of phase change
311 | """
312 | height = self.get_height()
313 | self.wvl = wvl
314 | phase = height2phase(height, self.wvl, self.material.get_RI(self.wvl))
315 | self.set_phase_change(phase, self.wvl)
316 |
317 | def set_diffraction_grating_1d(self, slit_width, minh, maxh):
318 | """
319 | Set the wavefront modulation as 1D diffraction grating
320 | Args:
321 | slit_width: width of slit in meter
322 | minh: minimum height in meter
323 | maxh: maximum height in meter
324 | """
325 |
326 | slit_width_px = np.round(slit_width / self.pitch)
327 | slit_space_px = slit_width_px
328 |
329 | dg = np.zeros((self.R, self.C))
330 | slit_num_r = self.R // (2 * slit_width_px)
331 | slit_num_c = self.C // (2 * slit_width_px)
332 |
333 | dg[:] = minh
334 |
335 | for i in range(int(slit_num_c)):
336 | minc = int((slit_width_px + slit_space_px) * i)
337 | maxc = int(minc + slit_width_px)
338 |
339 | dg[:, minc:maxc] = maxh
340 | pc = torch.tensor(dg.astype(np.float32), device=self.device).unsqueeze(0).unsqueeze(0)
341 | self.set_phase_change(pc, self.wvl)
342 |
343 | def set_diffraction_grating_2d(self, slit_width, minh, maxh):
344 | """
345 | Set the wavefront modulation as 2D diffraction grating
346 | Args:
347 | slit_width: width of slit in meter
348 | minh: minimum height in meter
349 | maxh: maximum height in meter
350 | """
351 |
352 | slit_width_px = np.round(slit_width / self.pitch)
353 | slit_space_px = slit_width_px
354 |
355 | dg = np.zeros((self.R, self.C))
356 | slit_num_r = self.R // (2 * slit_width_px)
357 | slit_num_c = self.C // (2 * slit_width_px)
358 |
359 | dg[:] = minh
360 |
361 | for i in range(int(slit_num_r)):
362 | for j in range(int(slit_num_c)):
363 | minc = int((slit_width_px + slit_space_px) * j)
364 | maxc = int(minc + slit_width_px)
365 | minr = int((slit_width_px + slit_space_px) * i)
366 | maxr = int(minr + slit_width_px)
367 |
368 | dg[minr:maxr, minc:maxc] = maxh
369 |
370 | pc = torch.tensor(dg.astype(np.float32), device=self.device).unsqueeze(0).unsqueeze(0)
371 | self.set_phase_change(pc, self.wvl)
372 |
373 | def set_Fresnel_lens(self, focal_length, shift_x=0, shift_y=0):
374 | """
375 | Set the wavefront modulation as a fresnel lens
376 | Args:
377 | focal_length: focal length in meter
378 | shift_x: x displacement of the lens w.r.t. incident light
379 | shift_y: y displacement of the lens w.r.t. incident light
380 | """
381 |
382 | x = np.arange(-self.C*self.pitch/2, self.C*self.pitch/2, self.pitch)
383 | y = np.arange(-self.R*self.pitch/2, self.R*self.pitch/2, self.pitch)
384 | xx,yy = np.meshgrid(x,y)
385 | xx = torch.tensor(xx, device=self.device)
386 | yy = torch.tensor(yy, device=self.device)
387 |
388 | phase = (-2*np.pi / self.wvl) * (torch.sqrt((xx-shift_x)**2 + (yy-shift_y)**2 + focal_length**2) - focal_length)
389 | phase = phase % (2*np.pi)
390 | phase -= np.pi
391 | phase = phase.unsqueeze(0).unsqueeze(0)
392 |
393 | self.set_phase_change(phase, self.wvl)
394 |
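#   Example (hypothetical values; Material comes from pado.material): a Fresnel
#   phase plate with a 4 mm focal length on a 512x512, 2 um-pitch grid:
#     doe = DOE(512, 512, 2e-6, Material('FUSED_SILICA'), 532e-9, 'cpu',
#               phase=torch.zeros(1, 1, 512, 512))
#     doe.set_Fresnel_lens(focal_length=4e-3)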
395 | def resize(self, target_pitch):
396 | '''
397 | Resize the wavefront by changing the pixel pitch.
398 | Args:
399 | target_pitch: new pixel pitch to use
400 | '''
401 |         scale_factor = self.pitch / target_pitch
402 | 
403 |         # resize the amplitude/phase buffers; the height map (if any) is interpolated separately
404 |         super().resize(target_pitch)
405 | 
406 |         if self.mode == 'height':
407 |             self.set_height(F.interpolate(self.height, scale_factor=scale_factor, mode='bilinear', align_corners=False))
408 |         elif self.mode != 'phase':
409 |             raise NotImplementedError('Mode is not set.')
410 |
411 | def get_height(self):
412 | """
413 | Return the height map of the DOE
414 | Returns:
415 | height map: height map in meter
416 | """
417 |
418 | if self.mode == 'height':
419 | return self.height
420 | elif self.mode == 'phase':
421 | height = phase2height(self.phase_change, self.wvl, self.material.get_RI(self.wvl))
422 | return height
423 | else:
424 |             raise NotImplementedError('Mode is not set.')
425 |
426 | def get_phase_change(self):
427 | """
428 | Return the phase change induced by the DOE
429 | Returns:
430 | phase change: phase change
431 | """
432 | if self.mode == 'height':
433 | self.to_phase_mode()
434 | return self.phase_change
435 |
436 | def set_height(self, height):
437 | """
438 | Set the height map of the DOE
439 | Args:
440 | height map: height map in meter
441 | """
442 |
443 | if self.mode == 'height':
444 | self.height = height
445 | elif self.mode == 'phase':
446 | self.set_phase_change(height2phase(height, self.wvl, self.material.get_RI(self.wvl)), self.wvl)
447 |
448 | def set_phase_change(self, phase_change, wvl):
449 | """
450 | Set the phase change induced by the DOE
451 | Args:
452 | phase change: phase change
453 | """
454 |
455 | if self.mode == 'height':
456 | self.set_height(phase2height(phase_change, wvl, self.material.get_RI(wvl)))
457 | if self.mode == 'phase':
458 | self.wvl = wvl
459 | self.phase_change = phase_change
460 |
461 | def to_phase_mode(self):
462 | """
463 | Change the mode to phase change
464 | """
465 | if self.mode == 'height':
466 | self.phase_change = height2phase(self.height, self.wvl, self.material.get_RI(self.wvl))
467 | self.mode = 'phase'
468 | self.height = None
469 |
470 | def to_height_mode(self):
471 | """
472 | Change the mode to height
473 | """
474 | if self.mode == 'phase':
475 | self.height = phase2height(self.phase_change, self.wvl, self.material.get_RI(self.wvl))
476 | self.mode = 'height'
477 |
478 |
479 | class SLM(OpticalElement):
480 | def __init__(self, R, C, pitch, wvl, device, B=1):
481 | """
482 | Spatial light modulator (SLM)
483 | Args:
484 | R: row
485 | C: column
486 | pitch: pixel pitch in meter
487 | wvl: wavelength of light in meter
488 | device: device to store the wavefront of light. 'cpu', 'cuda:0', ...
489 | B: batch size
490 | """
491 |
492 | super().__init__(R, C, pitch, wvl, device, name="SLM", B=B)
493 |
494 | def set_lens(self, focal_length, shift_x=0, shift_y=0):
495 | """
496 | Set the phase of a thin lens
497 | Args:
498 |             focal_length: focal length of the lens in meter
499 | shift_x: x displacement of the lens w.r.t. incident light
500 | shift_y: y displacement of the lens w.r.t. incident light
501 | """
502 |
503 | x = np.arange(-self.C*self.pitch/2, self.C*self.pitch/2, self.pitch)
504 | y = np.arange(-self.R*self.pitch/2, self.R*self.pitch/2, self.pitch)
505 | xx,yy = np.meshgrid(x,y)
506 |
507 | phase = (2*np.pi / self.wvl)*((xx-shift_x)**2 + (yy-shift_y)**2) / (2*focal_length)
508 | phase = torch.tensor(phase.astype(np.float32), device=self.device).unsqueeze(0).unsqueeze(0)
509 | phase = phase % (2*np.pi)
510 | phase -= np.pi
511 |
512 | self.set_phase_change(phase, self.wvl)
513 |
514 | def set_amplitude_change(self, amplitude, wvl):
515 | """
516 | Set the amplitude change
517 | Args:
518 |             amplitude: amplitude change in the polar representation of the complex number
519 | wvl: wavelength of light in meter
520 |
521 | """
522 | self.wvl = wvl
523 | super().set_amplitude_change(amplitude)
524 |
525 | def set_phase_change(self, phase_change, wvl):
526 | """
527 | Set the phase change
528 | Args:
529 |             phase_change: phase change in the polar representation of the complex number
530 | wvl: wavelength of light in meter
531 | """
532 | self.wvl = wvl
533 | super().set_phase_change(phase_change)
534 |
535 |
536 | class Aperture(OpticalElement):
537 | def __init__(self, R, C, pitch, aperture_diameter, aperture_shape, wvl, device='cpu'):
538 | """
539 | Aperture
540 | Args:
541 | R: row
542 | C: column
543 | pitch: pixel pitch in meter
544 |             aperture_diameter: diameter of the aperture in meter
545 | aperture_shape: shape of the aperture. {'square', 'circle'}
546 | wvl: wavelength of light in meter
547 | device: device to store the wavefront of light. 'cpu', 'cuda:0', ...
548 | """
549 |
550 | super().__init__(R, C, pitch, wvl, device, name="aperture")
551 |
552 | self.aperture_diameter = aperture_diameter
553 | self.aperture_shape = aperture_shape
554 | self.amplitude_change = torch.zeros((self.R, self.C), device=device)
555 | if self.aperture_shape == 'square':
556 | self.set_square()
557 | elif self.aperture_shape == 'circle':
558 | self.set_circle()
559 | else:
560 |             raise NotImplementedError('unsupported aperture shape: %s' % aperture_shape)
561 |
562 | def set_square(self):
563 | """
564 | Set the amplitude modulation of the aperture as square
565 | """
566 |
567 | self.aperture_shape = 'square'
568 |
569 | [x, y] = np.mgrid[-self.R // 2:self.R // 2, -self.C // 2:self.C // 2].astype(np.float32)
570 | r = self.pitch * np.asarray([abs(x), abs(y)]).max(axis=0)
571 | r = np.expand_dims(np.expand_dims(r, axis=0), axis=0)
572 |
573 | max_val = self.aperture_diameter / 2
574 | amp = (r <= max_val).astype(np.float32)
575 | amp[amp == 0] = 1e-20 # to enable stable learning
576 | self.amplitude_change = torch.tensor(amp, device=self.device)
577 |
578 | def set_circle(self, cx=0, cy=0, dia=None):
579 | """
580 | Set the amplitude modulation of the aperture as circle
581 | Args:
582 | cx, cy: relative center position of the circle with respect to the center of the light wavefront
583 | dia: circle diameter
584 | """
585 |         [x, y] = np.mgrid[-self.R // 2:self.R // 2, -self.C // 2:self.C // 2].astype(np.float32)
586 | r2 = (x-cx) ** 2 + (y-cy) ** 2
587 | r2[r2 < 0] = 1e-20
588 | r = self.pitch * np.sqrt(r2)
589 | r = np.expand_dims(np.expand_dims(r, axis=0), axis=0)
590 |
591 | if dia is not None:
592 | self.aperture_diameter = dia
593 | self.aperture_shape = 'circle'
594 | max_val = self.aperture_diameter / 2
595 | amp = (r <= max_val).astype(np.float32)
596 | amp[amp == 0] = 1e-20
597 | self.amplitude_change = torch.tensor(amp, device=self.device)
598 |
599 | def quantize(x, levels, vmin=None, vmax=None, include_vmax=True):
600 | """
601 | Quantize the floating array
602 | Args:
603 | levels: number of quantization levels
604 | vmin: minimum value for quantization
605 | vmax: maximum value for quantization
606 | include_vmax: include vmax for the quantized levels
607 |                 False: quantize x with a level spacing of 1/(levels-1)
608 |                 True: quantize x with a level spacing of 1/levels
609 | """
610 |
611 | if include_vmax is False:
612 | if levels == 0:
613 | return x
614 |
615 | if vmin is None:
616 | vmin = x.min()
617 | if vmax is None:
618 | vmax = x.max()
619 |
620 | #assert(vmin <= vmax)
621 |
622 | normalized = (x - vmin) / (vmax - vmin + 1e-16)
623 | if type(x) is np.ndarray:
624 | levelized = np.floor(normalized * levels) / (levels - 1)
625 |         elif isinstance(x, torch.Tensor):
626 | levelized = (normalized * levels).floor() / (levels - 1)
627 | result = levelized * (vmax - vmin) + vmin
628 | result[result < vmin] = vmin
629 | result[result > vmax] = vmax
630 |
631 | elif include_vmax is True:
632 | space = (x.max()-x.min())/levels
633 | vmin = x.min()
634 | vmax = vmin + space*(levels-1)
635 | if type(x) is np.ndarray:
636 | result = (np.floor((x-vmin)/space))*space + vmin
637 |         elif isinstance(x, torch.Tensor):
638 | result = (((x-vmin)/space).floor())*space + vmin
639 |         result[result < vmin] = vmin
640 |         result[result > vmax] = vmax
641 |
642 | return result
--------------------------------------------------------------------------------
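Usage sketch for the quantize helper defined at the bottom of optical_element.py; it assumes the pado package directory is on the Python path, and the height values are illustrative:

    import numpy as np
    from pado.optical_element import quantize

    heights = np.linspace(0.0, 1.2e-6, 11)   # continuous DOE height map in meters
    q4 = quantize(heights, levels=4)         # snap to 4 fabrication levels
    print(np.unique(q4))                     # at most 4 distinct height values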
/pado-main/pado/propagator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from .fourier import fft
4 | from .complex import Complex
5 | from .conv import conv_fft
6 |
7 |
8 | def compute_pad_width(field, linear):
9 | """
10 | Compute the pad width of an array for FFT-based convolution
11 | Args:
12 | field: (B,Ch,R,C) complex tensor
13 | linear: True or False, flag for linear convolution (zero padding) or circular convolution (no padding)
14 | Returns:
15 | pad_width: pad-width tensor
16 | """
17 |
18 | if linear:
19 | R,C = field.shape()[-2:]
20 | pad_width = (C//2, C//2, R//2, R//2)
21 | else:
22 | pad_width = (0,0,0,0)
23 | return pad_width
24 |
25 | def unpad(field_padded, pad_width):
26 | """
27 | Unpad the already-padded complex tensor
28 | Args:
29 | field_padded: (B,Ch,R,C) padded complex tensor
30 | pad_width: pad-width tensor
31 | Returns:
32 | field: unpadded complex tensor
33 | """
34 |
35 |     field = field_padded[...,pad_width[2]:field_padded.shape()[-2]-pad_width[3],pad_width[0]:field_padded.shape()[-1]-pad_width[1]]
36 | return field
37 |
38 | class Propagator:
39 | def __init__(self, mode):
40 | """
41 | Free-space propagator of light waves
42 |         One can simulate the propagation of light waves in free space (no medium change at all).
43 | Args:
44 | mode: type of propagator. currently, we support "Fraunhofer" propagation or "Fresnel" propagation. Use Fraunhofer for far-field propagation and Fresnel for near-field propagation.
45 | """
46 | self.mode = mode
47 |
48 | def forward(self, light, z, linear=True):
49 | """
50 | Forward the incident light with the propagator.
51 | Args:
52 | light: incident light
53 | z: propagation distance in meter
54 | linear: True or False, flag for linear convolution (zero padding) or circular convolution (no padding)
55 | Returns:
56 | light: light after propagation
57 | """
58 |
59 | if self.mode == 'Fraunhofer':
60 | return self.forward_Fraunhofer(light, z, linear)
61 |         elif self.mode == 'Fresnel':
62 | return self.forward_Fresnel(light, z, linear)
63 | else:
64 |             raise NotImplementedError('%s propagator is not implemented'%self.mode)
65 |
66 |
67 | def forward_Fraunhofer(self, light, z, linear=True):
68 | """
69 | Forward the incident light with the Fraunhofer propagator.
70 | Args:
71 | light: incident light
72 | z: propagation distance in meter.
73 |                The propagated wavefront is independent of the travel distance z.
74 | The distance z only affects the size of the "pixel", effectively adjusting the entire image size.
75 | linear: True or False, flag for linear convolution (zero padding) or circular convolution (no padding)
76 | Returns:
77 | light: light after propagation
78 | """
79 |
80 | pad_width = compute_pad_width(light.field, linear)
81 | field_propagated = fft(light.field, pad_width=pad_width)
82 | field_propagated = unpad(field_propagated, pad_width)
83 |
84 |         # based on the Fraunhofer reparametrization (u = x/(wvl*z)) and the Fourier frequency sampling (1/bandwidth)
85 | bw_r = light.get_bandwidth()[0]
86 | bw_c = light.get_bandwidth()[1]
87 | pitch_r_after_propagation = light.wvl*z/bw_r
88 | pitch_c_after_propagation = light.wvl*z/bw_c
89 |
90 | light_propagated = light.clone()
91 |
92 | # match the x-y pixel pitch using resampling
93 | if pitch_r_after_propagation >= pitch_c_after_propagation:
94 | scale_c = 1
95 | scale_r = pitch_r_after_propagation/pitch_c_after_propagation
96 | pitch_after_propagation = pitch_c_after_propagation
97 | elif pitch_r_after_propagation < pitch_c_after_propagation:
98 | scale_r = 1
99 | scale_c = pitch_c_after_propagation/pitch_r_after_propagation
100 | pitch_after_propagation = pitch_r_after_propagation
101 |
102 | light_propagated.set_field(field_propagated)
103 | light_propagated.magnify((scale_r,scale_c))
104 | light_propagated.set_pitch(pitch_after_propagation)
105 |
106 | return light_propagated
107 |
108 | def forward_Fresnel(self, light, z, linear):
109 | """
110 | Forward the incident light with the Fresnel propagator.
111 | Args:
112 | light: incident light
113 | z: propagation distance in meter.
114 | linear: True or False, flag for linear convolution (zero padding) or circular convolution (no padding)
115 | Returns:
116 | light: light after propagation
117 | """
118 | field_input = light.field
119 |
120 | # compute the convolutional kernel
121 | sx = light.C / 2
122 | sy = light.R / 2
123 | x = np.arange(-sx, sx, 1)
124 | y = np.arange(-sy, sy, 1)
125 | xx, yy = np.meshgrid(x,y)
126 | xx = torch.from_numpy(xx*light.pitch).to(light.device)
127 | yy = torch.from_numpy(yy*light.pitch).to(light.device)
128 | k = 2*np.pi/light.wvl # wavenumber
129 | phase = (k*(xx**2 + yy**2)/(2*z))
130 | amplitude = torch.ones_like(phase) / z / light.wvl
131 | conv_kernel = Complex(mag=amplitude, ang=phase)
132 |
133 | # Propagation with the convolution kernel
134 | pad_width = compute_pad_width(field_input, linear)
135 |
136 | field_propagated = conv_fft(field_input, conv_kernel, pad_width)
137 |
138 | # return the propagated light
139 | light_propagated = light.clone()
140 | light_propagated.set_field(field_propagated)
141 |
142 | return light_propagated
143 |
144 |
--------------------------------------------------------------------------------
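Usage sketch for the propagator: a plane wave clipped by a circular aperture and propagated in the near field. Import paths and parameter values are illustrative and assume the pado package is importable:

    from pado.light import Light
    from pado.optical_element import Aperture
    from pado.propagator import Propagator

    R = C = 512
    pitch, wvl = 2e-6, 532e-9                        # 2 um pixels, 532 nm light

    light = Light(R, C, pitch, wvl, device='cpu')    # unit-amplitude plane wave
    aperture = Aperture(R, C, pitch, aperture_diameter=200e-6, aperture_shape='circle', wvl=wvl)
    light = aperture.forward(light)                  # clip the beam to a circular pupil
    light = Propagator('Fresnel').forward(light, z=5e-3)  # near-field diffraction pattern
    intensity = light.get_intensity()                # (1,1,R,C) intensity tensor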
/pado/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
--------------------------------------------------------------------------------
/pado/cmap_phase.txt:
--------------------------------------------------------------------------------
1 | 0.180392 0.129412 0.917647
2 | 0.188235 0.121569 0.921569
3 | 0.200000 0.117647 0.925490
4 | 0.215686 0.109804 0.933333
5 | 0.227451 0.105882 0.937255
6 | 0.243137 0.101961 0.937255
7 | 0.254902 0.101961 0.941176
8 | 0.270588 0.101961 0.945098
9 | 0.286275 0.101961 0.949020
10 | 0.298039 0.105882 0.949020
11 | 0.313725 0.109804 0.952941
12 | 0.325490 0.113725 0.952941
13 | 0.341176 0.117647 0.956863
14 | 0.352941 0.121569 0.956863
15 | 0.364706 0.129412 0.960784
16 | 0.376471 0.133333 0.960784
17 | 0.388235 0.141176 0.960784
18 | 0.400000 0.145098 0.964706
19 | 0.411765 0.152941 0.964706
20 | 0.423529 0.156863 0.964706
21 | 0.435294 0.164706 0.968627
22 | 0.447059 0.168627 0.968627
23 | 0.458824 0.176471 0.972549
24 | 0.470588 0.180392 0.972549
25 | 0.478431 0.184314 0.972549
26 | 0.490196 0.192157 0.972549
27 | 0.501961 0.196078 0.976471
28 | 0.513725 0.200000 0.976471
29 | 0.521569 0.207843 0.976471
30 | 0.533333 0.211765 0.980392
31 | 0.545098 0.215686 0.980392
32 | 0.556863 0.219608 0.980392
33 | 0.568627 0.223529 0.980392
34 | 0.580392 0.227451 0.980392
35 | 0.592157 0.231373 0.980392
36 | 0.603922 0.235294 0.984314
37 | 0.615686 0.239216 0.984314
38 | 0.627451 0.239216 0.984314
39 | 0.643137 0.243137 0.984314
40 | 0.654902 0.247059 0.984314
41 | 0.666667 0.247059 0.984314
42 | 0.682353 0.250980 0.984314
43 | 0.694118 0.250980 0.984314
44 | 0.705882 0.254902 0.984314
45 | 0.717647 0.254902 0.984314
46 | 0.733333 0.258824 0.980392
47 | 0.745098 0.258824 0.980392
48 | 0.756863 0.262745 0.980392
49 | 0.772549 0.262745 0.980392
50 | 0.784314 0.266667 0.980392
51 | 0.796078 0.266667 0.980392
52 | 0.807843 0.270588 0.980392
53 | 0.819608 0.270588 0.976471
54 | 0.831373 0.274510 0.976471
55 | 0.843137 0.274510 0.976471
56 | 0.854902 0.278431 0.976471
57 | 0.866667 0.282353 0.972549
58 | 0.878431 0.286275 0.972549
59 | 0.886275 0.290196 0.968627
60 | 0.898039 0.294118 0.964706
61 | 0.905882 0.301961 0.964706
62 | 0.913725 0.309804 0.960784
63 | 0.921569 0.313725 0.956863
64 | 0.929412 0.321569 0.952941
65 | 0.937255 0.333333 0.945098
66 | 0.941176 0.341176 0.941176
67 | 0.949020 0.349020 0.933333
68 | 0.952941 0.360784 0.925490
69 | 0.956863 0.372549 0.921569
70 | 0.960784 0.384314 0.913725
71 | 0.964706 0.396078 0.901961
72 | 0.968627 0.407843 0.894118
73 | 0.968627 0.419608 0.886275
74 | 0.972549 0.431373 0.878431
75 | 0.976471 0.443137 0.866667
76 | 0.976471 0.454902 0.858824
77 | 0.976471 0.466667 0.847059
78 | 0.980392 0.478431 0.839216
79 | 0.980392 0.494118 0.827451
80 | 0.980392 0.505882 0.815686
81 | 0.984314 0.517647 0.807843
82 | 0.984314 0.529412 0.796078
83 | 0.984314 0.537255 0.784314
84 | 0.984314 0.549020 0.772549
85 | 0.984314 0.560784 0.764706
86 | 0.984314 0.572549 0.752941
87 | 0.984314 0.584314 0.741176
88 | 0.984314 0.596078 0.729412
89 | 0.984314 0.603922 0.721569
90 | 0.984314 0.615686 0.709804
91 | 0.984314 0.627451 0.698039
92 | 0.984314 0.635294 0.686275
93 | 0.984314 0.647059 0.678431
94 | 0.984314 0.658824 0.666667
95 | 0.984314 0.666667 0.654902
96 | 0.984314 0.678431 0.643137
97 | 0.984314 0.686275 0.631373
98 | 0.984314 0.698039 0.619608
99 | 0.984314 0.705882 0.611765
100 | 0.984314 0.717647 0.600000
101 | 0.984314 0.725490 0.588235
102 | 0.984314 0.733333 0.576471
103 | 0.984314 0.745098 0.564706
104 | 0.984314 0.752941 0.552941
105 | 0.984314 0.760784 0.541176
106 | 0.984314 0.768627 0.529412
107 | 0.984314 0.780392 0.517647
108 | 0.984314 0.788235 0.505882
109 | 0.984314 0.796078 0.494118
110 | 0.988235 0.803922 0.482353
111 | 0.988235 0.811765 0.466667
112 | 0.988235 0.823529 0.454902
113 | 0.988235 0.831373 0.443137
114 | 0.988235 0.839216 0.431373
115 | 0.988235 0.847059 0.415686
116 | 0.984314 0.854902 0.403922
117 | 0.984314 0.862745 0.388235
118 | 0.984314 0.870588 0.372549
119 | 0.984314 0.878431 0.360784
120 | 0.980392 0.886275 0.345098
121 | 0.980392 0.894118 0.329412
122 | 0.976471 0.898039 0.313725
123 | 0.972549 0.905882 0.298039
124 | 0.968627 0.909804 0.282353
125 | 0.964706 0.917647 0.266667
126 | 0.960784 0.921569 0.250980
127 | 0.952941 0.925490 0.235294
128 | 0.949020 0.929412 0.223529
129 | 0.941176 0.929412 0.207843
130 | 0.933333 0.933333 0.192157
131 | 0.925490 0.933333 0.180392
132 | 0.917647 0.933333 0.168627
133 | 0.909804 0.933333 0.156863
134 | 0.898039 0.933333 0.145098
135 | 0.890196 0.929412 0.133333
136 | 0.878431 0.929412 0.125490
137 | 0.866667 0.925490 0.117647
138 | 0.858824 0.921569 0.109804
139 | 0.847059 0.921569 0.105882
140 | 0.835294 0.917647 0.101961
141 | 0.823529 0.913725 0.098039
142 | 0.811765 0.909804 0.094118
143 | 0.800000 0.905882 0.090196
144 | 0.788235 0.901961 0.086275
145 | 0.776471 0.898039 0.086275
146 | 0.768627 0.894118 0.082353
147 | 0.756863 0.890196 0.082353
148 | 0.745098 0.886275 0.078431
149 | 0.733333 0.878431 0.078431
150 | 0.721569 0.874510 0.078431
151 | 0.709804 0.870588 0.074510
152 | 0.698039 0.866667 0.074510
153 | 0.686275 0.862745 0.074510
154 | 0.674510 0.858824 0.070588
155 | 0.662745 0.854902 0.070588
156 | 0.650980 0.850980 0.070588
157 | 0.639216 0.847059 0.066667
158 | 0.627451 0.843137 0.066667
159 | 0.611765 0.835294 0.066667
160 | 0.600000 0.831373 0.062745
161 | 0.588235 0.827451 0.062745
162 | 0.576471 0.823529 0.062745
163 | 0.564706 0.819608 0.058824
164 | 0.552941 0.815686 0.058824
165 | 0.541176 0.811765 0.058824
166 | 0.525490 0.807843 0.054902
167 | 0.513725 0.803922 0.054902
168 | 0.501961 0.796078 0.054902
169 | 0.490196 0.792157 0.050980
170 | 0.474510 0.788235 0.050980
171 | 0.462745 0.784314 0.050980
172 | 0.450980 0.780392 0.047059
173 | 0.435294 0.776471 0.047059
174 | 0.423529 0.772549 0.047059
175 | 0.407843 0.764706 0.047059
176 | 0.396078 0.760784 0.043137
177 | 0.380392 0.756863 0.043137
178 | 0.364706 0.752941 0.043137
179 | 0.352941 0.749020 0.047059
180 | 0.337255 0.745098 0.047059
181 | 0.321569 0.737255 0.047059
182 | 0.309804 0.733333 0.050980
183 | 0.294118 0.729412 0.054902
184 | 0.278431 0.725490 0.062745
185 | 0.266667 0.717647 0.066667
186 | 0.250980 0.713725 0.074510
187 | 0.239216 0.709804 0.086275
188 | 0.227451 0.701961 0.094118
189 | 0.219608 0.698039 0.105882
190 | 0.207843 0.694118 0.117647
191 | 0.203922 0.686275 0.129412
192 | 0.196078 0.682353 0.141176
193 | 0.192157 0.674510 0.156863
194 | 0.192157 0.670588 0.168627
195 | 0.192157 0.662745 0.184314
196 | 0.196078 0.654902 0.200000
197 | 0.200000 0.650980 0.211765
198 | 0.203922 0.643137 0.227451
199 | 0.211765 0.635294 0.243137
200 | 0.215686 0.627451 0.258824
201 | 0.223529 0.619608 0.274510
202 | 0.227451 0.615686 0.290196
203 | 0.235294 0.607843 0.301961
204 | 0.239216 0.600000 0.317647
205 | 0.243137 0.592157 0.333333
206 | 0.250980 0.584314 0.349020
207 | 0.254902 0.576471 0.360784
208 | 0.258824 0.568627 0.376471
209 | 0.262745 0.564706 0.392157
210 | 0.262745 0.556863 0.403922
211 | 0.266667 0.549020 0.419608
212 | 0.266667 0.541176 0.431373
213 | 0.270588 0.533333 0.447059
214 | 0.270588 0.525490 0.458824
215 | 0.270588 0.517647 0.474510
216 | 0.270588 0.509804 0.486275
217 | 0.266667 0.501961 0.501961
218 | 0.266667 0.498039 0.513725
219 | 0.262745 0.490196 0.525490
220 | 0.262745 0.482353 0.541176
221 | 0.258824 0.474510 0.552941
222 | 0.254902 0.466667 0.564706
223 | 0.250980 0.458824 0.580392
224 | 0.243137 0.450980 0.592157
225 | 0.239216 0.443137 0.603922
226 | 0.235294 0.435294 0.615686
227 | 0.227451 0.427451 0.627451
228 | 0.223529 0.419608 0.643137
229 | 0.215686 0.411765 0.654902
230 | 0.207843 0.403922 0.666667
231 | 0.203922 0.396078 0.678431
232 | 0.196078 0.388235 0.690196
233 | 0.192157 0.380392 0.701961
234 | 0.184314 0.372549 0.709804
235 | 0.180392 0.364706 0.721569
236 | 0.176471 0.352941 0.733333
237 | 0.172549 0.345098 0.745098
238 | 0.168627 0.333333 0.752941
239 | 0.164706 0.325490 0.764706
240 | 0.160784 0.313725 0.772549
241 | 0.156863 0.305882 0.784314
242 | 0.156863 0.294118 0.792157
243 | 0.152941 0.282353 0.803922
244 | 0.149020 0.274510 0.811765
245 | 0.149020 0.262745 0.823529
246 | 0.145098 0.250980 0.831373
247 | 0.145098 0.239216 0.839216
248 | 0.141176 0.227451 0.850980
249 | 0.141176 0.215686 0.858824
250 | 0.141176 0.203922 0.866667
251 | 0.141176 0.192157 0.874510
252 | 0.145098 0.180392 0.882353
253 | 0.149020 0.168627 0.890196
254 | 0.152941 0.160784 0.898039
255 | 0.160784 0.149020 0.905882
256 | 0.168627 0.141176 0.909804
257 |
--------------------------------------------------------------------------------
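The file above is a 256-entry RGB lookup table for visualizing phase. A minimal sketch for loading it as a matplotlib colormap (the path is relative to the repository root):

    import numpy as np
    from matplotlib.colors import ListedColormap

    rgb = np.loadtxt('pado/cmap_phase.txt')             # shape (256, 3), values in [0, 1]
    cmap_phase = ListedColormap(rgb, name='pado_phase')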
/pado/conv.py:
--------------------------------------------------------------------------------
1 | from .fourier import fft, ifft
2 |
3 | def conv_fft(img_c, kernel_c, pad_width=None):
4 | """
5 | Compute the convolution of an image with a convolution kernel using FFT
6 | Args:
7 | img_c: [B,Ch,R,C] image as a complex tensor
8 | kernel_c: [B,Ch,R,C] convolution kernel as a complex tensor
9 | pad_width: (tensor) pad width for the last spatial dimensions. should be (0,0,0,0) for circular convolution. for linear convolution, pad zero by the size of the original image
10 | Returns:
11 | im_conv: [B,Ch,R,C] blurred image
12 | """
13 |
14 | img_fft = fft(img_c, pad_width=pad_width)
15 | kernel_fft = fft(kernel_c, pad_width=pad_width)
16 | return ifft( img_fft * kernel_fft, pad_width=pad_width)
17 |
18 |
--------------------------------------------------------------------------------
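Usage sketch for conv_fft, assuming the pado package is importable. Passing pad_width=None keeps the wrap-around (circular) behavior of the FFT, while padding by half the image size on each side approximates a linear convolution:

    import torch
    from pado.complex import Complex
    from pado.conv import conv_fft

    B, Ch, R, C = 1, 1, 64, 64
    img = Complex(mag=torch.rand(B, Ch, R, C), ang=torch.zeros(B, Ch, R, C))
    ker = Complex(mag=torch.ones(B, Ch, R, C) / (R * C), ang=torch.zeros(B, Ch, R, C))

    out_circular = conv_fft(img, ker, pad_width=None)                    # circular convolution
    out_linear = conv_fft(img, ker, pad_width=(C//2, C//2, R//2, R//2))  # zero-padded (linear) convolution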
/pado/fourier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .complex import Complex
3 | import matplotlib.pyplot as plt
4 |
5 |
6 | def fft(arr_c, normalized=False, pad_width=None, padval=0, shift=True):
7 | """
8 | Compute the Fast Fourier transform of a complex tensor
9 | Args:
10 | arr_c: [B,Ch,R,C] complex tensor
11 | normalized: Normalize the FFT output. default: False
12 | pad_width: (tensor) pad width for the last spatial dimensions.
13 |         shift: flag for shifting the input data so that the zero-frequency component is located at the center of the array
14 | Returns:
15 | arr_c_fft: [B,Ch,R,C] FFT of the input complex tensor
16 | """
17 |
18 | arr_c = Complex(mag=arr_c.get_mag().clone(), ang=arr_c.get_ang().clone())
19 | if pad_width is not None:
20 | if padval == 0:
21 | arr_c.pad_zero(pad_width)
22 | else:
23 |             raise NotImplementedError('only zero padding is implemented for now')
24 | if shift:
25 | arr_c_shifted = ifftshift(arr_c)
26 | else:
27 | arr_c_shifted = arr_c
28 | arr_c_shifted.to_rect()
29 |
30 | #arr_c_shifted_stack = arr_c_shifted.get_stack()
31 | if normalized is False:
32 | normalized = "backward"
33 | else:
34 | normalized = "forward"
35 | arr_c_shifted_fft = torch.fft.fft2(arr_c_shifted.get_native() , norm=normalized)
36 | arr_c_shifted_fft_c = Complex(real=arr_c_shifted_fft.real, imag=arr_c_shifted_fft.imag)
37 | if shift:
38 | arr_c_fft = fftshift(arr_c_shifted_fft_c)
39 | else:
40 | arr_c_fft = arr_c_shifted_fft_c
41 |
42 | return arr_c_fft
43 |
44 |
45 |
46 | def ifft(arr_c, normalized=False, pad_width=None, shift=True):
47 | """
48 | Compute the inverse Fast Fourier transform of a complex tensor
49 | Args:
50 | arr_c: [B,Ch,R,C] complex tensor
51 | normalized: Normalize the FFT output. default: False
52 | pad_width: (tensor) pad width for the last spatial dimensions.
53 | shift: flag for inversely shifting the input data
54 | Returns:
55 | arr_c_fft: [B,Ch,R,C] inverse FFT of the input complex tensor
56 | """
57 |
58 | arr_c = Complex(mag=arr_c.get_mag().clone(), ang=arr_c.get_ang().clone())
59 | if shift:
60 | arr_c_shifted = ifftshift(arr_c)
61 | else:
62 | arr_c_shifted = arr_c
63 |
64 | arr_c_shifted.to_rect()
65 | if normalized is False:
66 | normalized = "backward"
67 | else:
68 | normalized = "forward"
69 | arr_c_shifted_fft = torch.fft.ifft2(arr_c_shifted.get_native(), norm=normalized)
70 | arr_c_shifted_fft_c = Complex(real=arr_c_shifted_fft.real, imag=arr_c_shifted_fft.imag)
71 | if shift:
72 | arr_c_fft = fftshift(arr_c_shifted_fft_c)
73 | else:
74 | arr_c_fft = arr_c_shifted_fft_c
75 |
76 | if pad_width is not None:
77 | arr_c_fft.crop(pad_width)
78 |
79 | return arr_c_fft
80 |
81 | def fftshift(arr_c, invert=False):
82 | """
83 |     Shift the complex tensor so that the zero-frequency signal is located at the center of the input
84 | Args:
85 | arr_c: [B,Ch,R,C] complex tensor
86 | invert: flag for inversely shifting the input data
87 | Returns:
88 | arr_c: [B,Ch,R,C] shifted tensor
89 | """
90 |
91 | arr_c.to_rect()
92 | shift_adjust = 0 if invert else 1
93 |
94 | arr_c_shape = arr_c.shape()
95 | C = arr_c_shape[-1]
96 | R = arr_c_shape[-2]
97 |
98 | shift_len = (C + shift_adjust) // 2
99 | arr_c = arr_c[...,shift_len:].cat(arr_c[...,:shift_len], -1)
100 |
101 | shift_len = (R + shift_adjust) // 2
102 | arr_c = arr_c[...,shift_len:,:].cat(arr_c[...,:shift_len,:], -2)
103 |
104 | return arr_c
105 |
106 |
107 | def ifftshift(arr_c):
108 | """
109 | Inversely shift the complex tensor
110 | Args:
111 | arr_c: [B,Ch,R,C] complex tensor
112 | Returns:
113 | arr_c: [B,Ch,R,C] shifted tensor
114 | """
115 |
116 | return fftshift(arr_c, invert=True)
117 |
--------------------------------------------------------------------------------
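A quick self-check for the transforms above: ifft(fft(x)) should reproduce x up to floating-point roundoff (assumes the pado package is importable):

    import torch
    from pado.complex import Complex
    from pado.fourier import fft, ifft

    x = Complex(mag=torch.rand(1, 1, 32, 32), ang=torch.zeros(1, 1, 32, 32))
    y = ifft(fft(x))                                 # forward then inverse transform
    err = (y.get_mag() - x.get_mag()).abs().max()
    print(float(err))                                # expected to be near float32 roundoff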
/pado/light.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from .complex import Complex
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | from scipy.io import savemat
7 |
8 | class Light:
9 | def __init__(self, R, C, pitch, wvl, device, amplitude=None, phase=None, real=None, imag=None, B=1):
10 | """
11 | Light wave that has a complex field (B,Ch,R,C) as a wavefront
12 | It takes the input wavefront in one of the following types
13 | 1. amplitude and phase
14 | 2. real and imaginary
15 | 3. None ==> we initialize light with amplitude of one and phase of zero
16 |
17 | Args:
18 | R: row
19 | C: column
20 | pitch: pixel pitch in meter
21 | wvl: wavelength of light in meter
22 | device: device to store the wavefront of light. 'cpu', 'cuda:0', ...
23 | amplitude: [batch_size, # of channels, row, column] tensor of wavefront amplitude, default is None
24 |             phase: [batch_size, # of channels, row, column] tensor of wavefront phase, default is None
25 | real: [batch_size, # of channels, row, column] tensor of wavefront real part, default is None
26 | imag: [batch_size, # of channels, row, column] tensor of wavefront imaginary part, default is None
27 | B: batch size
28 | """
29 |
30 | self.B = B
31 | self.R = R
32 | self.C = C
33 | self.pitch = pitch
34 | self.device = device
35 | self.wvl = wvl
36 |
37 |         if (amplitude is None) and (phase is None) and (real is not None) and (imag is not None):
38 | self.field = Complex(real=real, imag=imag)
39 |
40 |         elif (amplitude is not None) and (phase is not None) and (real is None) and (imag is None):
41 | self.field = Complex(mag=amplitude, ang=phase)
42 |
43 |         elif (amplitude is None) and (phase is None) and (real is None) and (imag is None):
44 | amplitude = torch.ones((B, 1, self.R, self.C), device=self.device)
45 | phase = torch.zeros((B, 1, self.R, self.C), device=self.device)
46 | self.field = Complex(mag=amplitude, ang=phase)
47 |
48 | else:
49 |             raise NotImplementedError('provide either amplitude/phase or real/imag, not a mix')
50 |
51 |
52 | def crop(self, crop_width):
53 | """
54 | Crop the light wavefront by crop_width
55 | Args:
56 | crop_width: (tuple) crop width of the tensor following torch functional pad
57 | """
58 |
59 | self.field.crop(crop_width)
60 | self.R = self.field.size(2)
61 | self.C = self.field.size(3)
62 |
63 | def clone(self):
64 | """
65 | Clone the light and return it
66 | """
67 |
68 | return Light(self.R, self.C,
69 | self.pitch, self.wvl, self.device,
70 | amplitude=self.field.get_mag().clone(), phase=self.field.get_ang().clone(),
71 | B=self.B)
72 |
73 |
74 | def pad(self, pad_width, padval=0):
75 | """
76 | Pad the light wavefront with a constant value by pad_width
77 | Args:
78 | pad_width: (tuple) pad width of the tensor following torch functional pad
79 | padval: value to pad. default is zero
80 | """
81 |
82 | if padval == 0:
83 | self.set_amplitude(torch.nn.functional.pad(self.get_amplitude(), pad_width))
84 | self.set_phase(torch.nn.functional.pad(self.get_phase(), pad_width))
85 | else:
86 |             raise NotImplementedError('only zero padding supported')
87 |
88 |         self.C += pad_width[0] + pad_width[1]  # torch F.pad pads the last (column) dimension first
89 |         self.R += pad_width[2] + pad_width[3]
90 |
91 | def set_real(self, real):
92 | """
93 | Set the real part of the light wavefront
94 | Args:
95 | real: real part in the rect representation of the complex number
96 | """
97 |
98 | self.field.set_real(real)
99 |
100 | def set_imag(self, imag):
101 | """
102 | Set the imaginary part of the light wavefront
103 | Args:
104 | imag: imaginary part in the rect representation of the complex number
105 | """
106 |
107 | self.field.set_imag(imag)
108 |
109 | def set_amplitude(self, amplitude):
110 | """
111 | Set the amplitude of the light wavefront
112 | Args:
113 | amplitude: amplitude in the polar representation of the complex number
114 | """
115 | self.field.set_mag(amplitude)
116 |
117 | def set_phase(self, phase):
118 | """
119 | Set the phase of the complex tensor
120 | Args:
121 | phase: phase in the polar representation of the complex number
122 | """
123 | self.field.set_ang(phase)
124 |
125 |
126 | def set_field(self, field):
127 | """
128 | Set the wavefront modulation of the complex tensor
129 | Args:
130 | field: wavefront as a complex number
131 | """
132 | self.field = field
133 |
134 | def set_pitch(self, pitch):
135 | """
136 | Set the pixel pitch of the complex tensor
137 | Args:
138 | pitch: pixel pitch in meter
139 | """
140 | self.pitch = pitch
141 |
142 |
143 | def get_amplitude(self):
144 | """
145 | Return the amplitude of the wavefront
146 | Returns:
147 | mag: magnitude in the polar representation of the complex number
148 | """
149 |
150 | return self.field.get_mag()
151 |
152 | def get_phase(self):
153 | """
154 | Return the phase of the wavefront
155 | Returns:
156 | ang: angle in the polar representation of the complex number
157 | """
158 |
159 | return self.field.get_ang()
160 |
161 | def get_field(self):
162 | """
163 | Return the complex wavefront
164 | Returns:
165 | field: complex wavefront
166 | """
167 |
168 | return self.field
169 |
170 | def get_intensity(self):
171 | """
172 | Return the intensity of light wavefront
173 | Returns:
174 | intensity: intensity of light
175 | """
176 | return self.field.get_intensity()
177 |
178 | def get_bandwidth(self):
179 | """
180 | Return the bandwidth of light wavefront
181 | Returns:
182 | R_m: spatial height of the wavefront
183 | C_m: spatial width of the wavefront
184 | """
185 |
186 | return self.pitch*self.R, self.pitch*self.C
187 |
188 | def magnify(self, scale_factor, interp_mode='nearest'):
189 | '''
190 | Change the wavefront resolution without changing the pixel pitch
191 | Args:
192 |             scale_factor: scale factor for interpolation used in torch.nn.functional.interpolate
193 | interp_mode: interpolation method used in torch.nn.functional.interpolate 'bilinear', 'nearest'
194 | '''
195 | self.field.resize(scale_factor, interp_mode)
196 | if self.field.mode == 'polar':
197 | self.R = self.field.mag.shape[-2]
198 | self.C = self.field.mag.shape[-1]
199 | elif self.field.mode == 'rect':
200 | self.R = self.field.real.shape[-2]
201 | self.C = self.field.real.shape[-1]
202 |
203 |
204 | def resize(self, target_pitch, interp_mode='nearest'):
205 | '''
206 | Resize the wavefront by changing the pixel pitch.
207 | Args:
208 | target_pitch: new pixel pitch to use
209 | interp_mode: interpolation method used in torch.nn.functional.interpolate 'bilinear', 'nearest'
210 | '''
211 | scale_factor = self.pitch / target_pitch
212 | self.magnify(scale_factor, interp_mode)
213 | self.set_pitch(target_pitch)
214 |
215 | def set_spherical_light(self, z, dx=0, dy=0):
216 | '''
217 | Set the wavefront as spherical one coming from the position of (dx,dy,z).
218 | Args:
219 | z: z distance of the spherical light source from the current light position
220 | dx: x distance of the spherical light source from the current light position
221 | dy: y distance of the spherical light source from the current light position
222 | '''
223 |
224 |         [y, x] = np.mgrid[-self.R // 2:self.R // 2, -self.C // 2:self.C // 2].astype(np.float64)  # rows first so theta has shape (R, C)
225 | x = x * self.pitch
226 | y = y * self.pitch
227 | r = np.sqrt((x - dx) ** 2 + (y - dy) ** 2 + z ** 2) # this is computed in double precision
228 | theta = 2 * np.pi * r / self.wvl
229 | theta = np.expand_dims(np.expand_dims(theta, axis=0), axis=0)%(2*np.pi)
230 | theta = theta.astype(np.float32)
231 |
232 | theta = torch.tensor(theta, device=self.device)
233 | mag = torch.ones_like(theta)
234 |
235 | self.set_phase(theta)
236 | self.set_amplitude(mag)
237 |
238 | def set_plane_light(self):
239 | '''
240 |         Set the wavefront as a plane wave with zero phase and amplitude of one
241 | '''
242 | amplitude = torch.ones((1, 1, self.R, self.C), device=self.device)
243 | phase = torch.zeros((1, 1, self.R, self.C), device=self.device)
244 | self.set_amplitude(amplitude)
245 | self.set_phase(phase)
246 |
247 | def save(self, fn):
248 | '''
249 | Save the amplitude and phase of the light wavefront as a file
250 | Args:
251 | fn: filename to save. the format should be either "npy" or "mat"
252 |
253 | '''
254 |
255 | amp = self.get_amplitude().data.cpu().numpy()
256 | phase = self.get_phase().data.cpu().numpy()
257 |
258 | if fn[-3:] == 'npy':
259 |             np.save(fn, np.stack((amp, phase)))  # np.save takes a single array; stack amplitude and phase
260 | elif fn[-3:] == 'mat':
261 | savemat(fn, {'amplitude':amp, 'phase':phase})
262 | else:
263 | print('extension in %s is unknown'%fn)
264 | print('light saved to %s\n'%fn)
265 |
266 | def visualize(self,b=0,c=0):
267 | """
268 | Visualize the light wave
269 | Args:
270 | b: batch index to visualize default is 0
271 | c: channel index to visualize. default is 0
272 | """
273 |
274 | bw = self.get_bandwidth()
275 |
276 | plt.figure(figsize=(20,5))
277 | plt.subplot(131)
278 | amplitude_b = self.get_amplitude().data.cpu()[b,c,...].squeeze()
279 | plt.imshow(amplitude_b,
280 |                    extent=[0, bw[1]*1e3, 0, bw[0]*1e3], cmap='inferno')
281 | plt.title('amplitude')
282 | plt.xlabel('mm')
283 | plt.ylabel('mm')
284 | plt.colorbar()
285 |
286 | plt.subplot(132)
287 |         phase_b = self.get_phase().data.cpu()[b,c,...].squeeze()
288 |         plt.imshow(phase_b,
289 |                    extent=[0, bw[1]*1e3, 0, bw[0]*1e3], cmap='hsv', vmin=-np.pi, vmax=np.pi) # cyclic colormap
290 | plt.title('phase')
291 | plt.xlabel('mm')
292 | plt.ylabel('mm')
293 | plt.colorbar()
294 |
295 | plt.subplot(133)
296 | intensity_b = self.get_intensity().data.cpu()[b,c,...].squeeze()
297 | plt.imshow(intensity_b,
298 |                    extent=[0, bw[1]*1e3, 0, bw[0]*1e3], cmap='inferno')
299 | plt.title('intensity')
300 | plt.xlabel('mm')
301 | plt.ylabel('mm')
302 | plt.colorbar()
303 |
304 | plt.suptitle('(%d,%d), pitch:%.2f[um], wvl:%.2f[nm], device:%s'%(self.R, self.C,
305 | self.pitch/1e-6, self.wvl/1e-9, self.device))
306 | plt.show()
307 |
308 | def shape(self):
309 | """
310 | Returns the shape of light wavefront
311 | Returns:
312 | shape
313 | """
314 | return self.field.shape()
315 |
--------------------------------------------------------------------------------
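Usage sketch for Light, assuming the pado package is importable; the numbers are illustrative:

    from pado.light import Light

    light = Light(R=256, C=256, pitch=2e-6, wvl=633e-9, device='cpu')
    light.set_spherical_light(z=50e-3)     # diverging wave from a point source 50 mm away
    print(light.shape())                   # (1, 1, 256, 256)
    print(light.get_bandwidth())           # physical extent in meters: (0.000512, 0.000512)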
/pado/material.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class Material:
5 | def __init__(self, material_name):
6 | """
7 | Material of optical elements and its refractive index.
8 | Args:
9 |             material_name: name of the material. So far, we provide PDMS, FUSED_SILICA, VACUUM.
10 | """
11 | self.material_name = material_name
12 |
13 | def get_RI(self, wvl):
14 | """
15 | Return the refractive index of the current material for a wavelength
16 | Args:
17 | wvl: wavelength in meter
18 | Returns:
19 | RI: Refractive index at wvl
20 | """
21 |
22 | wvl_nm = wvl / 1e-9
23 | if self.material_name == 'PDMS':
24 | RI = np.sqrt(1 + (1.0057 * (wvl_nm**2))/(wvl_nm**2 - 0.013217))
25 | elif self.material_name == 'FUSED_SILICA':
26 | wvl_um = wvl_nm*1e-3
27 | RI = (1 + 0.6961663 / (1 - (0.0684043 / wvl_um) ** 2) + 0.4079426 / (1 - (0.1162414 / wvl_um) ** 2) + 0.8974794 / (1 - (9.896161 / wvl_um) ** 2)) ** .5
28 | # https://refractiveindex.info/?shelf=glass&book=fused_silica&page=Malitson
29 | elif self.material_name == 'VACUUM':
30 | RI = 1.0
31 | else:
32 |             raise NotImplementedError('%s is not in the RI list'%self.material_name)
33 |
34 | return RI
35 |
36 |
--------------------------------------------------------------------------------
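The refractive index feeds the height-to-phase conversion phi = (2*pi/wvl) * (RI - 1) * height used by the DOE code. A worked example for fused silica at 633 nm (assumes the pado package is importable):

    from pado.material import Material
    from pado.optical_element import height2phase

    wvl = 633e-9
    n = Material('FUSED_SILICA').get_RI(wvl)   # ~1.457 at 633 nm
    h_2pi = wvl / (n - 1)                      # height giving a full 2*pi phase delay, ~1.39 um
    print(height2phase(h_2pi / 2, wvl, n))     # half that height gives a phase of ~pi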
/pado/optical_element.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | import matplotlib.pyplot as plt
5 |
6 |
7 | class OpticalElement:
8 | def __init__(self, R, C, pitch, wvl, device, name="not defined",B=1):
9 | """
10 |         Base class for optical elements. Any optical element changes the wavefront of incident light.
11 |         The change of the wavefront is stored as amplitude and phase tensors.
12 |         Note that the number of channels is one for the wavefront modulation.
13 | Args:
14 | R: row
15 | C: column
16 | pitch: pixel pitch in meter
17 | wvl: wavelength of light in meter
18 | device: device to store the wavefront of light. 'cpu', 'cuda:0', ...
19 | name: name of the current optical element
20 | B: batch size
21 | """
22 |
23 | self.name = name
24 | self.B = B
25 | self.R = R
26 | self.C = C
27 | self.pitch = pitch
28 | self.device = device
29 | self.amplitude_change = torch.ones((B, 1, R, C), device=self.device)
30 | self.phase_change = torch.zeros((B, 1, R, C), device=self.device)
31 | self.wvl = wvl
32 |
33 | def shape(self):
34 | """
35 |         Returns the shape of light-wavefront modulation. The number of channels is one
36 | Returns:
37 | shape
38 | """
39 | return (self.B,1,self.R,self.C)
40 |
41 | def set_pitch(self, pitch):
42 | """
43 | Set the pixel pitch of the complex tensor
44 | Args:
45 | pitch: pixel pitch in meter
46 | """
47 | self.pitch = pitch
48 |
49 | def resize(self, target_pitch, interp_mode='nearest'):
50 | '''
51 | Resize the wavefront change by changing the pixel pitch.
52 | Args:
53 | target_pitch: new pixel pitch to use
54 | interp_mode: interpolation method used in torch.nn.functional.interpolate 'bilinear', 'nearest'
55 | '''
56 |
57 | scale_factor = self.pitch / target_pitch
58 | self.amplitude_change = F.interpolate(self.amplitude_change, scale_factor=scale_factor,
59 | mode=interp_mode)
60 | self.phase_change = F.interpolate(self.phase_change, scale_factor=scale_factor,
61 | mode=interp_mode)
62 | self.set_pitch(target_pitch)
63 | self.R = self.amplitude_change.shape[-2]
64 | self.C = self.amplitude_change.shape[-1]
65 |
66 | def get_amplitude_change(self):
67 | '''
68 | Return the amplitude change of the wavefront
69 | Returns:
70 |             amplitude change: amplitude change
71 | '''
72 |
73 | return self.amplitude_change
74 |
75 | def get_phase_change(self):
76 | '''
77 | Return the phase change of the wavefront
78 | Returns:
79 | phase change: phase change
80 | '''
81 |
82 | return self.phase_change
83 |
84 | def set_amplitude_change(self, amplitude):
85 | """
86 | Set the amplitude change
87 | Args:
88 | amplitude change: amplitude change in the polar representation of the complex number
89 | """
90 |
91 | assert amplitude.shape[2] == self.R and amplitude.shape[3] == self.C
92 | self.amplitude_change = amplitude
93 |
94 | def set_phase_change(self, phase):
95 | """
96 | Set the phase change
97 | Args:
98 | phase change: phase change in the polar representation of the complex number
99 | """
100 |
101 | assert phase.shape[2] == self.R and phase.shape[3] == self.C
102 | self.phase_change = phase
103 |
104 | def pad(self, pad_width, padval=0):
105 | """
106 | Pad the wavefront change with a constant value by pad_width
107 | Args:
108 | pad_width: (tuple) pad width of the tensor following torch functional pad
109 | padval: value to pad. default is zero
110 | """
111 | if padval == 0:
112 | self.amplitude_change = torch.nn.functional.pad(self.get_amplitude_change(), pad_width)
113 | self.phase_change = torch.nn.functional.pad(self.get_phase_change(), pad_width)
114 | else:
115 |             raise NotImplementedError('only zero padding supported')
116 |
117 |         self.C += pad_width[0] + pad_width[1]  # torch F.pad pads the last (column) dimension first
118 |         self.R += pad_width[2] + pad_width[3]
119 |
120 | def forward(self, light, interp_mode='nearest'):
121 | """
122 | Forward the incident light with the optical element.
123 | Args:
124 | light: incident light
125 | interp_mode: interpolation method used in torch.nn.functional.interpolate 'bilinear', 'nearest'
126 | Returns:
127 | light after interaction with the optical element
128 | """
129 |
130 | if light.pitch > self.pitch:
131 | light.resize(self.pitch, interp_mode)
132 | light.set_pitch(self.pitch)
133 | elif light.pitch < self.pitch:
134 | self.resize(light.pitch, interp_mode)
135 | self.set_pitch(light.pitch)
136 |
137 | if light.wvl != self.wvl:
138 |             raise NotImplementedError('wavelength should be the same for light and optical elements')
139 |
140 | r1 = np.abs((light.R - self.R)//2)
141 | r2 = np.abs(light.R - self.R) - r1
142 |         pad_width = (0, 0, r1, r2)  # rows map to the last two entries of F.pad
143 | if light.R > self.R:
144 | self.pad(pad_width)
145 | elif light.R < self.R:
146 | light.pad(pad_width)
147 |
148 | c1 = np.abs((light.C - self.C)//2)
149 | c2 = np.abs(light.C - self.C) - c1
150 |         pad_width = (c1, c2, 0, 0)  # columns map to the first two entries of F.pad
151 | if light.C > self.C:
152 | self.pad(pad_width)
153 | elif light.C < self.C:
154 | light.pad(pad_width)
155 |
156 | light.set_phase(light.get_phase() + self.get_phase_change())
157 | light.set_amplitude(light.get_amplitude() * self.get_amplitude_change())
158 |
159 | return light
160 |
161 | def visualize(self, b=0):
162 | """
163 | Visualize the wavefront modulation of the optical element
164 | Args:
165 | b: batch index to visualize default is 0
166 | """
167 |
168 | plt.figure(figsize=(13,6))
169 |
170 | plt.subplot(121)
171 | plt.imshow(self.get_amplitude_change().data.cpu()[b,...].squeeze())
172 | plt.title('amplitude')
173 | plt.colorbar()
174 |
175 | plt.subplot(122)
176 | plt.imshow(self.get_phase_change().data.cpu()[b,...].squeeze())
177 | plt.title('phase')
178 | plt.colorbar()
179 |
180 | plt.suptitle('%s, (%d,%d), pitch:%.2f[um], wvl:%.2f[nm], device:%s'
181 | %(self.name, self.R, self.C, self.pitch/1e-6, self.wvl/1e-9, self.device))
182 | plt.show()
183 |
184 |
185 | class RefractiveLens(OpticalElement):
186 | def __init__(self, R, C, pitch, focal_length, wvl, device):
187 | """
188 | Thin refractive lens
189 | Args:
190 | R: row
191 | C: column
192 | pitch: pixel pitch in meter
193 | focal_length: focal length of the lens in meter
194 | wvl: wavelength of light in meter
195 | device: device to store the wavefront of light. 'cpu', 'cuda:0', ...
196 | """
197 |
198 | super().__init__(R, C, pitch, wvl, device, name="refractive_lens")
199 |
200 | self.set_focal_length(focal_length)
201 | self.set_phase_change( self.compute_phase(self.wvl, shift_x=0, shift_y=0) )
202 |
203 | def set_focal_length(self, focal_length):
204 | """
205 | Set the focal length of the lens
206 | Args:
207 | focal_length: focal length in meter
208 | """
209 |
210 | self.focal_length = focal_length
211 |
212 | def compute_phase(self, wvl, shift_x=0, shift_y=0):
213 | """
214 | Set the phase of a thin lens
215 | Args:
216 | wvl: wavelength of light in meter
217 | shift_x: x displacement of the lens w.r.t. incident light
218 | shift_y: y displacement of the lens w.r.t. incident light
219 | """
220 |
221 | bw_R = self.R*self.pitch
222 | bw_C = self.C*self.pitch
223 |
224 | x = np.arange(-bw_C/2, bw_C/2, self.pitch)
225 |         x = x[:self.C]
226 | y = np.arange(-bw_R/2, bw_R/2, self.pitch)
227 |         y = y[:self.R]
228 | xx,yy = np.meshgrid(x,y)
229 |
230 | theta_change = torch.tensor((-2*np.pi / wvl)*((xx-shift_x)**2 + (yy-shift_y)**2), device=self.device) / (2*self.focal_length)
231 | theta_change = torch.unsqueeze(torch.unsqueeze(theta_change, axis=0), axis=0)
232 | theta_change %= 2*np.pi
233 | theta_change -= np.pi
234 |
235 | return theta_change
236 |
237 | def height2phase(height, wvl, RI, wrap=True):
238 | """
239 | Convert the height of a material to the corresponding phase shift
240 | Args:
241 | height: height of the material in meter
242 | wvl: wavelength of light in meter
243 | RI: refractive index of the material at the wavelength
244 | wrap: return the wrapped phase [0,2pi]
245 | """
246 | dRI = RI - 1
247 | wv_n = 2. * np.pi / wvl
248 | phi = wv_n * dRI * height
249 | if wrap:
250 | phi %= 2 * np.pi
251 | return phi
252 |
253 | def phase2height(phase, wvl, RI):
254 | """
255 | Convert the phase change to the height of a material
256 | Args:
257 | phase: phase change of light
258 | wvl: wavelength of light in meter
259 | RI: refractive index of the material at the wavelength
260 | """
261 | dRI = RI - 1
262 | return wvl * phase / (2 * np.pi) / dRI
263 |
264 | def radius2phase(r, f, wvl):
265 | return (2 * np.pi * (np.sqrt(r * r + f * f) - f) / wvl) % (2 * np.pi)
266 |
267 | class DOE(OpticalElement):
268 | def __init__(self, R, C, pitch, material, wvl, device, height=None, phase=None, amplitude=None):
269 | """
270 | Diffractive optical element (DOE)
271 | Args:
272 | R: row
273 | C: column
274 | pitch: pixel pitch in meter
275 | material: material of the DOE
276 | wvl: wavelength of light in meter
277 | device: device to store the wavefront of light. 'cpu', 'cuda:0', ...
278 | height: height map of the material in meter
279 | phase: phase change of light
280 | amplitude: amplitude change of light
281 | """
282 |
283 | super().__init__(R, C, pitch, wvl, device, name="doe")
284 |
285 | self.material = material
286 | self.height = None
287 |
288 | if amplitude is None:
289 | amplitude = torch.ones((1, 1, self.R, self.C), device=self.device)
290 |
291 | if height is None and phase is not None:
292 | self.mode = 'phase'
293 | self.set_phase_change(phase, wvl)
294 | self.set_amplitude_change(amplitude)
295 | elif height is not None and phase is None:
296 | self.mode = 'height'
297 | self.set_height(height)
298 | self.set_amplitude_change(amplitude)
299 |         elif (height is None) and (phase is None):  # amplitude-only: default to a zero-phase DOE
300 | self.mode = 'phase'
301 | phase = torch.zeros((1, 1, self.R, self.C), device=self.device)
302 | self.set_amplitude_change(amplitude)
303 | self.set_phase_change(phase, wvl)
304 |
305 |
306 | def change_wvl(self, wvl):
307 | """
308 | Change the wavelength of phase change
309 | Args:
310 | wvl: wavelength of phase change
311 | """
312 | height = self.get_height()
313 | self.wvl = wvl
314 | phase = height2phase(height, self.wvl, self.material.get_RI(self.wvl))
315 | self.set_phase_change(phase, self.wvl)
316 |
317 | def set_diffraction_grating_1d(self, slit_width, minh, maxh):
318 | """
319 | Set the wavefront modulation as 1D diffraction grating
320 | Args:
321 | slit_width: width of slit in meter
322 | minh: minimum height in meter
323 | maxh: maximum height in meter
324 | """
325 |
326 | slit_width_px = np.round(slit_width / self.pitch)
327 | slit_space_px = slit_width_px
328 |
329 | dg = np.zeros((self.R, self.C))
330 | slit_num_r = self.R // (2 * slit_width_px)
331 | slit_num_c = self.C // (2 * slit_width_px)
332 |
333 | dg[:] = minh
334 |
335 | for i in range(int(slit_num_c)):
336 | minc = int((slit_width_px + slit_space_px) * i)
337 | maxc = int(minc + slit_width_px)
338 |
339 | dg[:, minc:maxc] = maxh
340 |         hm = torch.tensor(dg.astype(np.float32), device=self.device).unsqueeze(0).unsqueeze(0)
341 |         self.set_height(hm)  # dg stores heights in meter; set_height converts to phase if needed
342 |
343 | def set_diffraction_grating_2d(self, slit_width, minh, maxh):
344 | """
345 | Set the wavefront modulation as 2D diffraction grating
346 | Args:
347 | slit_width: width of slit in meter
348 | minh: minimum height in meter
349 | maxh: maximum height in meter
350 | """
351 |
352 | slit_width_px = np.round(slit_width / self.pitch)
353 | slit_space_px = slit_width_px
354 |
355 | dg = np.zeros((self.R, self.C))
356 | slit_num_r = self.R // (2 * slit_width_px)
357 | slit_num_c = self.C // (2 * slit_width_px)
358 |
359 | dg[:] = minh
360 |
361 | for i in range(int(slit_num_r)):
362 | for j in range(int(slit_num_c)):
363 | minc = int((slit_width_px + slit_space_px) * j)
364 | maxc = int(minc + slit_width_px)
365 | minr = int((slit_width_px + slit_space_px) * i)
366 | maxr = int(minr + slit_width_px)
367 |
368 | dg[minr:maxr, minc:maxc] = maxh
369 |
370 |         hm = torch.tensor(dg.astype(np.float32), device=self.device).unsqueeze(0).unsqueeze(0)
371 |         self.set_height(hm)  # dg stores heights in meter; set_height converts to phase if needed
372 |
373 | def set_Fresnel_lens(self, focal_length, shift_x=0, shift_y=0):
374 | """
375 | Set the wavefront modulation as a fresnel lens
376 | Args:
377 | focal_length: focal length in meter
378 | shift_x: x displacement of the lens w.r.t. incident light
379 | shift_y: y displacement of the lens w.r.t. incident light
380 | """
381 |
382 | x = np.arange(-self.C*self.pitch/2, self.C*self.pitch/2, self.pitch)
383 | y = np.arange(-self.R*self.pitch/2, self.R*self.pitch/2, self.pitch)
384 | xx,yy = np.meshgrid(x,y)
385 | xx = torch.tensor(xx, device=self.device)
386 | yy = torch.tensor(yy, device=self.device)
387 |
388 | phase = (-2*np.pi / self.wvl) * (torch.sqrt((xx-shift_x)**2 + (yy-shift_y)**2 + focal_length**2) - focal_length)
389 | phase = phase % (2*np.pi)
390 | phase -= np.pi
391 | phase = phase.unsqueeze(0).unsqueeze(0)
392 |
393 | self.set_phase_change(phase, self.wvl)
394 |
395 | def resize(self, target_pitch):
396 | '''
397 | Resize the wavefront by changing the pixel pitch.
398 | Args:
399 | target_pitch: new pixel pitch to use
400 | '''
401 | scale_factor = self.pitch / target_pitch
402 | super().resize(target_pitch)
403 |
404 |         if self.mode == 'phase':
405 |             pass  # amplitude and phase maps were already resized by super().resize above
406 |         elif self.mode == 'height':
407 |             self.set_height(F.interpolate(self.height, scale_factor=scale_factor, mode='bilinear', align_corners=False))
408 |         else:
409 |             raise NotImplementedError('Mode is not set.')
410 |
411 | def get_height(self):
412 | """
413 | Return the height map of the DOE
414 | Returns:
415 | height map: height map in meter
416 | """
417 |
418 | if self.mode == 'height':
419 | return self.height
420 | elif self.mode == 'phase':
421 | height = phase2height(self.phase_change, self.wvl, self.material.get_RI(self.wvl))
422 | return height
423 | else:
424 |             raise NotImplementedError('Mode is not set.')
425 |
426 | def get_phase_change(self):
427 | """
428 | Return the phase change induced by the DOE
429 | Returns:
430 | phase change: phase change
431 | """
432 | if self.mode == 'height':
433 | self.to_phase_mode()
434 | return self.phase_change
435 |
436 | def set_height(self, height):
437 | """
438 | Set the height map of the DOE
439 | Args:
440 |             height: height map in meter
441 | """
442 |
443 | if self.mode == 'height':
444 | self.height = height
445 | elif self.mode == 'phase':
446 | self.set_phase_change(height2phase(height, self.wvl, self.material.get_RI(self.wvl)), self.wvl)
447 |
448 | def set_phase_change(self, phase_change, wvl):
449 | """
450 | Set the phase change induced by the DOE
451 | Args:
452 |             phase_change: phase change induced at wavelength wvl (in meter)
453 | """
454 |
455 | if self.mode == 'height':
456 | self.set_height(phase2height(phase_change, wvl, self.material.get_RI(wvl)))
457 |         elif self.mode == 'phase':
458 | self.wvl = wvl
459 | self.phase_change = phase_change
460 |
461 | def to_phase_mode(self):
462 | """
463 | Change the mode to phase change
464 | """
465 | if self.mode == 'height':
466 | self.phase_change = height2phase(self.height, self.wvl, self.material.get_RI(self.wvl))
467 | self.mode = 'phase'
468 | self.height = None
469 |
470 | def to_height_mode(self):
471 | """
472 | Change the mode to height
473 | """
474 | if self.mode == 'phase':
475 | self.height = phase2height(self.phase_change, self.wvl, self.material.get_RI(self.wvl))
476 | self.mode = 'height'
477 |
478 |
479 | class SLM(OpticalElement):
480 | def __init__(self, R, C, pitch, wvl, device, B=1):
481 | """
482 | Spatial light modulator (SLM)
483 | Args:
484 | R: row
485 | C: column
486 | pitch: pixel pitch in meter
487 | wvl: wavelength of light in meter
488 | device: device to store the wavefront of light. 'cpu', 'cuda:0', ...
489 | B: batch size
490 | """
491 |
492 | super().__init__(R, C, pitch, wvl, device, name="SLM", B=B)
493 |
494 | def set_lens(self, focal_length, shift_x=0, shift_y=0):
495 | """
496 | Set the phase of a thin lens
497 | Args:
498 |             focal_length: focal length of the lens in meter
499 | shift_x: x displacement of the lens w.r.t. incident light
500 | shift_y: y displacement of the lens w.r.t. incident light
501 | """
502 |
503 | x = np.arange(-self.C*self.pitch/2, self.C*self.pitch/2, self.pitch)
504 | y = np.arange(-self.R*self.pitch/2, self.R*self.pitch/2, self.pitch)
505 | xx,yy = np.meshgrid(x,y)
506 |
507 | phase = (2*np.pi / self.wvl)*((xx-shift_x)**2 + (yy-shift_y)**2) / (2*focal_length)
508 | phase = torch.tensor(phase.astype(np.float32), device=self.device).unsqueeze(0).unsqueeze(0)
509 | phase = phase % (2*np.pi)
510 | phase -= np.pi
511 |
512 | self.set_phase_change(phase, self.wvl)
513 |
514 | def set_amplitude_change(self, amplitude, wvl):
515 | """
516 | Set the amplitude change
517 | Args:
518 | amplitude change: amplitude change in the polar representation of the complex number
519 | wvl: wavelength of light in meter
520 |
521 | """
522 | self.wvl = wvl
523 | super().set_amplitude_change(amplitude)
524 |
525 | def set_phase_change(self, phase_change, wvl):
526 | """
527 | Set the phase change
528 | Args:
529 | phase change: phase change in the polar representation of the complex number
530 | wvl: wavelength of light in meter
531 | """
532 | self.wvl = wvl
533 | super().set_phase_change(phase_change)
534 |
535 |
536 | class Aperture(OpticalElement):
537 | def __init__(self, R, C, pitch, aperture_diameter, aperture_shape, wvl, device='cpu'):
538 | """
539 | Aperture
540 | Args:
541 | R: row
542 | C: column
543 | pitch: pixel pitch in meter
544 |             aperture_diameter: diameter of the aperture in meter
545 | aperture_shape: shape of the aperture. {'square', 'circle'}
546 | wvl: wavelength of light in meter
547 | device: device to store the wavefront of light. 'cpu', 'cuda:0', ...
548 | """
549 |
550 | super().__init__(R, C, pitch, wvl, device, name="aperture")
551 |
552 | self.aperture_diameter = aperture_diameter
553 | self.aperture_shape = aperture_shape
554 | self.amplitude_change = torch.zeros((self.R, self.C), device=device)
555 | if self.aperture_shape == 'square':
556 | self.set_square()
557 | elif self.aperture_shape == 'circle':
558 | self.set_circle()
559 | else:
560 |             raise NotImplementedError('unsupported aperture shape: %s' % aperture_shape)
561 |
562 | def set_square(self):
563 | """
564 | Set the amplitude modulation of the aperture as square
565 | """
566 |
567 | self.aperture_shape = 'square'
568 |
569 | [x, y] = np.mgrid[-self.R // 2:self.R // 2, -self.C // 2:self.C // 2].astype(np.float32)
570 | r = self.pitch * np.asarray([abs(x), abs(y)]).max(axis=0)
571 | r = np.expand_dims(np.expand_dims(r, axis=0), axis=0)
572 |
573 | max_val = self.aperture_diameter / 2
574 | amp = (r <= max_val).astype(np.float32)
575 | amp[amp == 0] = 1e-20 # to enable stable learning
576 | self.amplitude_change = torch.tensor(amp, device=self.device)
577 |
578 | def set_circle(self, cx=0, cy=0, dia=None):
579 | """
580 | Set the amplitude modulation of the aperture as circle
581 | Args:
582 | cx, cy: relative center position of the circle with respect to the center of the light wavefront
583 | dia: circle diameter
584 | """
585 |         [x, y] = np.mgrid[-self.R // 2:self.R // 2, -self.C // 2:self.C // 2].astype(np.float32)
586 | r2 = (x-cx) ** 2 + (y-cy) ** 2
587 | r2[r2 < 0] = 1e-20
588 | r = self.pitch * np.sqrt(r2)
589 | r = np.expand_dims(np.expand_dims(r, axis=0), axis=0)
590 |
591 | if dia is not None:
592 | self.aperture_diameter = dia
593 | self.aperture_shape = 'circle'
594 | max_val = self.aperture_diameter / 2
595 | amp = (r <= max_val).astype(np.float32)
596 | amp[amp == 0] = 1e-20
597 | self.amplitude_change = torch.tensor(amp, device=self.device)
598 |
599 | def quantize(x, levels, vmin=None, vmax=None, include_vmax=True):
600 | """
601 | Quantize the floating array
602 | Args:
603 | levels: number of quantization levels
604 | vmin: minimum value for quantization
605 | vmax: maximum value for quantization
606 | include_vmax: include vmax for the quantized levels
607 |                 False: quantize x with a level spacing of 1/(levels-1)
608 |                 True: quantize x with a level spacing of 1/levels
609 | """
610 |
611 | if include_vmax is False:
612 | if levels == 0:
613 | return x
614 |
615 | if vmin is None:
616 | vmin = x.min()
617 | if vmax is None:
618 | vmax = x.max()
619 |
620 | #assert(vmin <= vmax)
621 |
622 | normalized = (x - vmin) / (vmax - vmin + 1e-16)
623 | if type(x) is np.ndarray:
624 | levelized = np.floor(normalized * levels) / (levels - 1)
625 |         elif isinstance(x, torch.Tensor):
626 | levelized = (normalized * levels).floor() / (levels - 1)
627 | result = levelized * (vmax - vmin) + vmin
628 | result[result < vmin] = vmin
629 | result[result > vmax] = vmax
630 |
631 | elif include_vmax is True:
632 | space = (x.max()-x.min())/levels
633 | vmin = x.min()
634 | vmax = vmin + space*(levels-1)
635 | if type(x) is np.ndarray:
636 | result = (np.floor((x-vmin)/space))*space + vmin
637 |         elif isinstance(x, torch.Tensor):
638 | result = (((x-vmin)/space).floor())*space + vmin
639 |         result[result < vmin] = vmin
640 |         result[result > vmax] = vmax
641 |
642 | return result
--------------------------------------------------------------------------------
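Usage sketch for the DOE class: a fused-silica Fresnel lens modulating a plane wave (assumes the pado package is importable; parameter values are illustrative):

    from pado.light import Light
    from pado.material import Material
    from pado.optical_element import DOE

    R = C = 512
    pitch, wvl = 2e-6, 532e-9
    doe = DOE(R, C, pitch, Material('FUSED_SILICA'), wvl, device='cpu')
    doe.set_Fresnel_lens(focal_length=25e-3)      # 25 mm focal length
    light = Light(R, C, pitch, wvl, device='cpu')
    light = doe.forward(light)                    # apply the amplitude/phase modulation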
/pado/propagator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from .fourier import fft
4 | from .complex import Complex
5 | from .conv import conv_fft
6 |
7 |
8 | def compute_pad_width(field, linear):
9 | """
10 | Compute the pad width of an array for FFT-based convolution
11 | Args:
12 | field: (B,Ch,R,C) complex tensor
13 | linear: True or False, flag for linear convolution (zero padding) or circular convolution (no padding)
14 | Returns:
15 | pad_width: pad-width tensor
16 | """
17 |
18 | if linear:
19 | R,C = field.shape()[-2:]
20 | pad_width = (C//2, C//2, R//2, R//2)
21 | else:
22 | pad_width = (0,0,0,0)
23 | return pad_width
24 |
25 | def unpad(field_padded, pad_width):
26 | """
27 | Unpad the already-padded complex tensor
28 | Args:
29 | field_padded: (B,Ch,R,C) padded complex tensor
30 | pad_width: pad-width tensor
31 | Returns:
32 | field: unpadded complex tensor
33 | """
34 |
35 |     field = field_padded[...,pad_width[2]:field_padded.shape()[-2]-pad_width[3],pad_width[0]:field_padded.shape()[-1]-pad_width[1]]
36 | return field
37 |
38 | class Propagator:
39 | def __init__(self, mode):
40 | """
41 | Free-space propagator of light waves
42 |         One can simulate the propagation of light waves in free space (no medium change at all).
43 | Args:
44 | mode: type of propagator. currently, we support "Fraunhofer" propagation or "Fresnel" propagation. Use Fraunhofer for far-field propagation and Fresnel for near-field propagation.
45 | """
46 | self.mode = mode
47 |
48 | def forward(self, light, z, linear=True):
49 | """
50 | Forward the incident light with the propagator.
51 | Args:
52 | light: incident light
53 | z: propagation distance in meter
54 | linear: True or False, flag for linear convolution (zero padding) or circular convolution (no padding)
55 | Returns:
56 | light: light after propagation
57 | """
58 |
59 | if self.mode == 'Fraunhofer':
60 | return self.forward_Fraunhofer(light, z, linear)
61 |         elif self.mode == 'Fresnel':
62 |             return self.forward_Fresnel(light, z, linear)
63 |         else:
64 |             raise NotImplementedError('%s propagator is not implemented'%self.mode)
65 |
66 |
67 | def forward_Fraunhofer(self, light, z, linear=True):
68 | """
69 | Forward the incident light with the Fraunhofer propagator.
70 | Args:
71 | light: incident light
72 |             z: propagation distance in meters.
73 |                 The propagated wavefront is independent of the travel distance z.
74 |                 The distance z only affects the pixel pitch, effectively scaling the entire image.
75 | linear: True or False, flag for linear convolution (zero padding) or circular convolution (no padding)
76 | Returns:
77 | light: light after propagation
78 | """
79 |
80 | # pad_width = compute_pad_width(light.field, linear)
81 | field_propagated = fft(light.field)#, pad_width=pad_width)
82 | # field_propagated = unpad(field_propagated, pad_width)
83 |
84 |         # based on the Fraunhofer reparametrization (u = x/(wvl*z)) and the Fourier frequency sampling (1/bandwidth)
85 | bw_r = light.get_bandwidth()[0]
86 | bw_c = light.get_bandwidth()[1]
87 | pitch_r_after_propagation = light.wvl*z/bw_r
88 | pitch_c_after_propagation = light.wvl*z/bw_c
89 |
90 | light_propagated = light.clone()
91 |
92 | # match the x-y pixel pitch using resampling
93 | if pitch_r_after_propagation >= pitch_c_after_propagation:
94 | scale_c = 1
95 | scale_r = pitch_r_after_propagation/pitch_c_after_propagation
96 | pitch_after_propagation = pitch_c_after_propagation
97 | elif pitch_r_after_propagation < pitch_c_after_propagation:
98 | scale_r = 1
99 | scale_c = pitch_c_after_propagation/pitch_r_after_propagation
100 | pitch_after_propagation = pitch_r_after_propagation
101 |
102 | light_propagated.set_field(field_propagated)
103 | light_propagated.magnify((scale_r,scale_c))
104 | light_propagated.set_pitch(pitch_after_propagation)
105 |
106 | return light_propagated
107 |
108 | def forward_Fresnel(self, light, z, linear):
109 | """
110 | Forward the incident light with the Fresnel propagator.
111 | Args:
112 | light: incident light
113 |             z: propagation distance in meters.
114 | linear: True or False, flag for linear convolution (zero padding) or circular convolution (no padding)
115 | Returns:
116 | light: light after propagation
117 | """
118 | field_input = light.field
119 |
120 | # compute the convolutional kernel
121 | sx = light.C / 2
122 | sy = light.R / 2
123 | x = np.arange(-sx, sx, 1)
124 | y = np.arange(-sy, sy, 1)
125 | xx, yy = np.meshgrid(x,y)
126 | xx = torch.from_numpy(xx*light.pitch).to(light.device)
127 | yy = torch.from_numpy(yy*light.pitch).to(light.device)
128 | k = 2*np.pi/light.wvl # wavenumber
129 | phase = (k*(xx**2 + yy**2)/(2*z))
130 | amplitude = torch.ones_like(phase) / z / light.wvl
131 | conv_kernel = Complex(mag=amplitude, ang=phase)
132 |
133 | # Propagation with the convolution kernel
134 | pad_width = compute_pad_width(field_input, linear)
135 |
136 | field_propagated = conv_fft(field_input, conv_kernel, pad_width)
137 |
138 | # return the propagated light
139 | light_propagated = light.clone()
140 | light_propagated.set_field(field_propagated)
141 |
142 | return light_propagated
143 |
144 |
--------------------------------------------------------------------------------
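The Fresnel branch above builds its convolution kernel from the quadratic phase k*(x^2 + y^2)/(2z) with amplitude 1/(wvl*z). The following is a self-contained sketch of the same kernel construction using a plain torch complex tensor instead of pado's Complex wrapper; all sampling parameters are hypothetical placeholders.

import numpy as np
import torch

R = C = 256      # field resolution (hypothetical)
pitch = 2e-6     # pixel pitch in meters (hypothetical)
wvl = 532e-9     # wavelength in meters (hypothetical)
z = 5e-3         # propagation distance in meters (hypothetical)

x = np.arange(-C / 2, C / 2) * pitch
y = np.arange(-R / 2, R / 2) * pitch
xx, yy = np.meshgrid(x, y)

k = 2 * np.pi / wvl                          # wavenumber
phase = k * (xx ** 2 + yy ** 2) / (2 * z)    # quadratic Fresnel phase
amplitude = np.ones_like(phase) / (z * wvl)  # 1/(wvl*z) scaling

# complex kernel amplitude * exp(i * phase)
kernel = torch.polar(torch.from_numpy(amplitude), torch.from_numpy(phase))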
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torchvision.transforms as transforms
4 | from torch.utils.tensorboard import SummaryWriter
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | import torch.nn.functional as F
8 | import torch.backends.cudnn as cudnn
9 | from torch.autograd import Variable
10 |
11 | from importlib.machinery import SourceFileLoader
12 | import numpy as np
13 | import os
14 | import argparse
15 | from tqdm import trange
16 | import lpips
17 |
18 | from utils.save import *
19 | from utils.utils import *
20 |
21 | from models.forwards import *
22 | from models.ROLE import compute_raindrop
23 | from models.Defence import compute_fence
24 | from models.Dirt import compute_dirt
25 |
26 | def compute_dirt_raindrop(image_far, depth, args):
27 | if np.random.uniform() < 0.5:
28 | return compute_dirt(image_far, depth, args)
29 | else:
30 | return compute_raindrop(image_far, depth, args)
31 |
32 | def log(DOE_phase, G, batch_data, step,args):
33 | image_far, _ = batch_data
34 | image_far = image_far.to(args.device)
35 | image_near, mask, image_DOE, image_near_DOE, image_far_DOE, psf_near, psf_far, mask_doe, height_map = image_formation(image_far,DOE_phase, args.compute_obstruction, args)
36 | image = image_near * mask + image_far * (1 - mask)
37 |
38 | image_recon = G(image_DOE, psf_near, psf_far)
39 | image_recon = torch.clamp(image_recon, min=0, max=1)
40 | G_l1_loss = args.l1_loss_weight * args.l1_criterion(image_recon, image_far)
41 | G_perc_loss = torch.mean(args.perceptual_loss_weight * args.perceptual_criterion(2 * image_recon - 1, 2 * image_far - 1))
42 | G_masked_loss = args.masked_loss_weight * args.l1_criterion(image_recon*mask, image_far*mask)
43 | loss = G_l1_loss + G_perc_loss + G_masked_loss
44 |
45 | psfs, log_psfs = plot_depth_based_psf(DOE_phase, args, args.param.plot_depth)
46 | DOE_phase_wrapped = DOE_phase % (2*np.pi)
47 | DOE_phase_wrapped = DOE_phase_wrapped - torch.min(DOE_phase_wrapped)
48 | DOE_phase_wrapped = DOE_phase_wrapped / torch.max(DOE_phase_wrapped)
49 | DOE_phase = DOE_phase - torch.min(DOE_phase)
50 | DOE_phase = DOE_phase / torch.max(DOE_phase)
51 |
52 | args.writer.add_scalar('val_loss/L1_loss',G_l1_loss, step)
53 | args.writer.add_scalar('val_loss/perc_loss',G_perc_loss, step)
54 | args.writer.add_scalar('val_loss/masked_loss',G_masked_loss, step)
55 | args.writer.add_scalar('val_loss/loss',loss, step)
56 | args.writer.add_image('result', torch.cat([torch.cat([image[0],image_far[0]],1), torch.cat([image_DOE[0],image_recon[0]],1)], -1), step)
57 | args.writer.add_image('image_sensor_component', torch.cat([image_far_DOE[0], image_near_DOE[0], mask_doe[0]],-1), step)
58 | args.writer.add_image('RGB_psf', psfs[0], step)
59 | args.writer.add_image('RGB_logpsf', log_psfs[0], step)
60 | args.writer.add_image('height_map',(height_map/torch.max(height_map))[0], step)
61 | args.writer.add_image('phase_map', DOE_phase[0], step)
62 | args.writer.add_image('phase_map_wrapped', DOE_phase_wrapped[0], step)
63 |
64 |
65 | def train_step(batch_data, DOE_phase, optics_optimizer, G, G_optimizer, step, args):
66 | image_far, _ = batch_data
67 | image_far = image_far.to(args.device)
68 | image_near, mask, image_DOE, image_near_DOE, image_far_DOE, psf_near, psf_far, mask_doe, height_map = image_formation(image_far,DOE_phase, args.compute_obstruction, args)
69 | image_recon = G(image_DOE, psf_near, psf_far)
70 | G_l1_loss = args.l1_loss_weight * args.l1_criterion(image_recon, image_far)
71 | G_perc_loss = torch.mean(args.perceptual_loss_weight * args.perceptual_criterion(2 * image_recon - 1, 2 * image_far - 1))
72 | G_masked_loss = args.masked_loss_weight * args.l1_criterion(image_recon*mask, image_far*mask)
73 | loss = G_l1_loss + G_perc_loss + G_masked_loss
74 | loss.backward()
75 |
76 | G_optimizer.step()
77 | G_optimizer.zero_grad()
78 | if args.train_optics:
79 | optics_optimizer.step()
80 | optics_optimizer.zero_grad()
81 | return loss.detach()
82 |
83 | def train(args):
84 |
85 | if args.debug:
86 | np.random.seed(args.seed)
87 | torch.manual_seed(args.seed)
88 | if args.device == 'cuda':
89 | torch.cuda.manual_seed(args.seed)
90 | cudnn.benchmark = True
91 | cudnn.enabled=True
92 | param = args.param
93 |
94 | transform_train = transforms.Compose([
95 | transforms.RandomCrop(param.data_resolution,pad_if_needed=True), # Places365 image size varies
96 | transforms.RandomCrop([param.equiv_crop_size, param.equiv_crop_size],pad_if_needed=True),
97 | transforms.Resize([param.img_res, param.img_res]),
98 | transforms.RandomHorizontalFlip(),
99 | transforms.ToTensor(),
100 | ])
101 |
102 | transform_test = transforms.Compose([
103 | transforms.RandomCrop(param.data_resolution,pad_if_needed=True), # Places365 image size varies
104 | transforms.CenterCrop(param.equiv_crop_size),
105 | transforms.Resize([param.img_res, param.img_res]),
106 | transforms.ToTensor(),
107 | ])
108 | if args.obstruction == 'fence':
109 | trainset = torchvision.datasets.Places365(
110 | root=param.dataset_dir, split="train-standard", transform=transform_train)
111 | testset = torchvision.datasets.Places365(
112 | root=param.dataset_dir, split="val", transform=transform_test)
113 | args.compute_obstruction = compute_fence
114 | elif args.obstruction == 'raindrop':
115 | trainset = torchvision.datasets.ImageFolder(param.training_dir, transform=transform_train)
116 | testset = torchvision.datasets.ImageFolder(param.val_dir, transform=transform_test)
117 | args.compute_obstruction = compute_raindrop
118 | elif args.obstruction == 'dirt':
119 | trainset = torchvision.datasets.ImageFolder(param.training_dir, transform=transform_train)
120 | testset = torchvision.datasets.ImageFolder(param.val_dir, transform=transform_test)
121 | args.compute_obstruction = compute_dirt
122 | elif args.obstruction == 'dirt_raindrop':
123 | trainset = torchvision.datasets.ImageFolder(param.training_dir, transform=transform_train)
124 | testset = torchvision.datasets.ImageFolder(param.val_dir, transform=transform_test)
125 | args.compute_obstruction = compute_dirt_raindrop
126 |
127 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=True)
128 | testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False)
129 |
130 | if args.debug:
131 | trainloader = testloader
132 |
133 | # build model and loss
134 | args.perceptual_criterion = lpips.LPIPS(net='vgg').to(args.device)
135 | args.l1_criterion = nn.L1Loss().to(args.device)
136 | args.writer = SummaryWriter(args.result_path)
137 |
138 | param = args.param
139 | if args.train_optics:
140 | DOE_phase = Variable(param.DOE_phase_init.to(args.device), requires_grad=True)
141 | optics_optimizer = optim.Adam([DOE_phase], lr=args.optics_lr)
142 | else:
143 | DOE_phase = Variable(param.DOE_phase_init.to(args.device), requires_grad=False)
144 | optics_optimizer = None
145 |
146 | from models.recon import Arch
147 | G = Arch(args).to(args.device)
148 | if args.pretrained_G is not None:
149 | G.load_state_dict(torch.load(args.G_init_ckpt, map_location='cpu'))
150 | G.to(args.device)
151 |
152 | G_optimizer = optim.Adam(params=G.parameters(), lr=args.G_lr)
153 |
154 | for _, batch_data in enumerate(testloader):
155 | test_data = batch_data
156 | break
157 |
158 | total_step = 0
159 | train_loss = 0
160 | log(DOE_phase, G, test_data, total_step, args)
161 | for epoch_cnt in trange(args.n_epochs, desc="Epoch"):
162 | for _, batch_data in enumerate(trainloader):
163 | step_loss = train_step(batch_data, DOE_phase, optics_optimizer, G, G_optimizer, total_step, args)
164 | total_step += 1
165 | train_loss += step_loss
166 | if total_step % args.log_freq == 0:
167 | log(DOE_phase, G, test_data, total_step, args)
168 | args.writer.add_scalar('train_loss/loss',train_loss/args.log_freq, total_step)
169 | train_loss = 0
170 | if total_step % args.save_freq == 0:
171 | torch.save(G.state_dict(), os.path.join(args.result_path,'G_%03d.pt' % (total_step//args.save_freq)))
172 | if args.train_optics:
173 | torch.save(DOE_phase, os.path.join(args.result_path,'DOE_phase_%03d.pt' % (total_step//args.save_freq)))
174 |
175 |
176 | def main():
177 | parser = argparse.ArgumentParser(
178 | description='Obstruction-Free DOE',
179 | formatter_class=argparse.RawDescriptionHelpFormatter
180 | )
181 | def str2bool(v):
182 | assert(v == 'True' or v == 'False')
183 |         return v.lower() == 'true'
184 |
185 | def none_or_str(value):
186 | if value.lower() == 'none':
187 | return None
188 | return value
189 |
190 | parser.add_argument('--debug', action="store_true", help='debug mode, train on validation data to speed up the process')
191 | parser.add_argument('--train_optics', action="store_true", help='optimize optical element design')
192 | parser.add_argument('--pretrained_DOE', default = None, type =none_or_str, help = 'use a pretrained DOE')
193 | parser.add_argument('--pretrained_G', default = None, type =none_or_str, help = 'use a pretrained G')
194 | parser.add_argument('--result_path', default = './ckpt/opt', type=str, help='dir to save models and checkpoints')
195 | parser.add_argument('--param_file', default= 'config/param_MV_1600.py', type=str, help='path to param file')
196 |
197 |     parser.add_argument('--obstruction', default = 'dirt_raindrop', type = str, help = 'obstruction type')
198 | parser.add_argument('--sensor_noise', default = 0.008, type=float, help='sensor noise level')
199 | parser.add_argument('--n_epochs', default = 40, type = int, help = 'max num of training epoch')
200 | parser.add_argument('--optics_lr', default=0.1, type=float, help='optical element learning rate')
201 | parser.add_argument('--G_lr', default=1e-4, type=float, help='network learning rate')
202 |
203 | parser.add_argument('--l1_loss_weight', default = 1, type = float, help = 'weight for L1 loss')
204 | parser.add_argument('--masked_loss_weight', default = 1, type = float, help = 'weight for masked loss (focus on obstructed scene)')
205 | parser.add_argument('--perceptual_loss_weight', default = 1, type = float, help = 'weight for perceptual loss')
206 |
207 | parser.add_argument('--log_freq', default=400, type=int, help = 'frequency (num_steps) of logging')
208 | parser.add_argument('--save_freq', default=2000, type=int, help = 'frequency (num_steps) of saving checkpoint and visual performance')
209 | parser.add_argument('--seed', type=int, default=1234, help='random seed')
210 |
211 | args = parser.parse_args()
212 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
213 |
214 | param = SourceFileLoader("param", args.param_file).load_module()
215 | param = convert_resolution(param,args)
216 |
217 | if args.pretrained_DOE is not None:
218 | if args.pretrained_DOE.endswith('.pt'):
219 | args.DOE_phase_init_ckpt = args.pretrained_DOE
220 | else:
221 | args.DOE_phase_init_ckpt = last_save(args.pretrained_DOE, 'DOE_phase_*')
222 | param.DOE_phase_init = torch.load(args.DOE_phase_init_ckpt, map_location='cpu').detach()
223 |
224 | if args.pretrained_G is not None:
225 | args.G_init_ckpt = last_save(args.pretrained_G, 'G_*')
226 |
227 | save_settings(args, param)
228 | train(args)
229 |
230 | if __name__ == '__main__':
231 |
232 | main()
233 |
234 |
--------------------------------------------------------------------------------
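The reconstruction objective in train_step and log is the sum of an L1 term, an LPIPS perceptual term (inputs rescaled from [0, 1] to [-1, 1]), and an L1 term restricted to the obstruction mask. A standalone sketch of that composition with hypothetical tensors, keeping all three weights at their argparse defaults of 1:

import torch
import torch.nn as nn
import lpips

l1_criterion = nn.L1Loss()
perceptual_criterion = lpips.LPIPS(net='vgg')

image_recon = torch.rand(1, 3, 256, 256)            # hypothetical network output in [0, 1]
image_far = torch.rand(1, 3, 256, 256)              # hypothetical obstruction-free target
mask = (torch.rand(1, 1, 256, 256) > 0.5).float()   # hypothetical obstruction mask

l1_loss = l1_criterion(image_recon, image_far)
# LPIPS expects inputs in [-1, 1], hence the 2*x - 1 rescaling
perc_loss = torch.mean(perceptual_criterion(2 * image_recon - 1, 2 * image_far - 1))
masked_loss = l1_criterion(image_recon * mask, image_far * mask)
loss = l1_loss + perc_loss + masked_loss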
/train.sh:
--------------------------------------------------------------------------------
1 | PRETRAINED_G=None
2 | PRETRAINED_DOE=None
3 | RESULT_DIR=ckpt/E2E_MV2400
4 | PARAM=config/param_MV_2400.py
5 | OBSTRUCTION=dirt_raindrop
6 |
7 | conda activate SeeThroughObstruction
8 |
9 | python train.py --train_optics --result_path $RESULT_DIR --param_file $PARAM --obstruction $OBSTRUCTION --pretrained_DOE $PRETRAINED_DOE --pretrained_G $PRETRAINED_G
--------------------------------------------------------------------------------
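train.sh passes the literal string None for both pretrained checkpoints; the none_or_str converter in train.py maps that string back to Python None so the defaults apply. A quick round-trip check (converter copied from train.py above):

def none_or_str(value):
    if value.lower() == 'none':
        return None
    return value

assert none_or_str('None') is None                          # PRETRAINED_DOE=None in train.sh
assert none_or_str('ckpt/E2E_MV2400') == 'ckpt/E2E_MV2400'  # real paths pass through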
/utils/save.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from matplotlib import pyplot
4 | import cv2
5 | import shutil
6 | import json
7 | from glob import glob
8 |
9 | from models.forwards import *
10 |
11 | def save_settings(args, param):
12 |     # create the result dir together with the 'init' subdir used for the checkpoint copies below
13 |     os.makedirs(os.path.join(args.result_path, 'init'), exist_ok=True)
14 | args_dict = vars(args)
15 | with open(os.path.join(args.result_path,'args.json'), "w") as f:
16 | json.dump(args_dict, f, indent=4, sort_keys=False)
17 | shutil.copy(args.param_file, args.result_path)
18 | if args.pretrained_DOE is not None:
19 | shutil.copy(args.DOE_phase_init_ckpt, os.path.join(args.result_path, 'init'))
20 | if args.pretrained_G is not None:
21 | shutil.copy(args.G_init_ckpt, os.path.join(args.result_path, 'init'))
22 | args.param = param
23 |
24 | def plot_depth_based_psf(DOE_phase, args, depths = [0.05, 0.1, 0.2, 0.4, 0.8, 1, 3, 5], wvls = 'RGB', normalize = True, merge_channel = False):
25 | param = args.param
26 | doe = DOE(param.R, param.C, param.DOE_pitch, param.material, param.DOE_wvl,args.device, phase = DOE_phase)
27 |
28 | psfs = []
29 | if wvls == 'RGB':
30 | for i in range(len(param.wvls)):
31 | wvl = param.wvls[i]
32 | psf_depth = []
33 | for z in depths:
34 | psf = compute_psf_Fraunhofer(wvl, torch.tensor(z), doe, args)
35 | if normalize:
36 | psf_depth.append(psf/torch.max(psf))
37 | else:
38 | psf_depth.append(psf)
39 | psfs.append(torch.cat(psf_depth, -1))
40 | if merge_channel:
41 | psfs = torch.cat(psfs, 1)
42 | else:
43 | psfs = torch.cat(psfs, -2)
44 | elif wvls == 'design':
45 | wvl = param.DOE_wvl
46 | psf_depth = []
47 | for z in depths:
48 | psf = compute_psf_Fraunhofer(wvl, torch.tensor(z), doe, args)
49 | if normalize:
50 | psf_depth.append(psf/torch.max(psf))
51 | else:
52 | psf_depth.append(psf)
53 | psfs.append(torch.cat(psf_depth, -1))
54 | psfs = torch.cat(psfs, -2)
55 | else:
56 | assert False, "%s not defined" %wvls
57 |
58 | log_psfs = torch.log(psfs + 1e-9)
59 | log_psfs -= torch.min(log_psfs)
60 | log_psfs /= torch.max(log_psfs)
61 | return psfs, log_psfs
62 |
63 | def last_save(ckpt_path, file_format):
64 | return sorted(glob(os.path.join(ckpt_path, file_format)))[-1]
65 |
66 | def trapez(y,y0,w):
67 | return np.clip(np.minimum(y+1+w/2-y0, -y+1+w/2+y0),0,1)
68 |
69 | def weighted_line(r0, c0, r1, c1, w, rmin=0, rmax=np.inf):
70 | # The algorithm below works fine if c1 >= c0 and c1-c0 >= abs(r1-r0).
71 | # If either of these cases are violated, do some switches.
72 | if abs(c1-c0) < abs(r1-r0):
73 | # Switch x and y, and switch again when returning.
74 | xx, yy, val = weighted_line(c0, r0, c1, r1, w, rmin=rmin, rmax=rmax)
75 | return (yy, xx, val)
76 |
77 | # At this point we know that the distance in columns (x) is greater
78 | # than that in rows (y). Possibly one more switch if c0 > c1.
79 | if c0 > c1:
80 | return weighted_line(r1, c1, r0, c0, w, rmin=rmin, rmax=rmax)
81 |
82 | # The following is now always < 1 in abs
83 | slope = (r1-r0) / (c1-c0)
84 |
85 | # Adjust weight by the slope
86 | w *= np.sqrt(1+np.abs(slope)) / 2
87 |
88 | # We write y as a function of x, because the slope is always <= 1
89 | # (in absolute value)
90 | x = np.arange(c0, c1+1, dtype=float)
91 | y = x * slope + (c1*r0-c0*r1) / (c1-c0)
92 |
93 | # Now instead of 2 values for y, we have 2*np.ceil(w/2).
94 | # All values are 1 except the upmost and bottommost.
95 | thickness = np.ceil(w/2)
96 | yy = (np.floor(y).reshape(-1,1) + np.arange(-thickness-1,thickness+2).reshape(1,-1))
97 | xx = np.repeat(x, yy.shape[1])
98 | vals = trapez(yy, y.reshape(-1,1), w).flatten()
99 |
100 | yy = yy.flatten()
101 |
102 | # Exclude useless parts and those outside of the interval
103 | # to avoid parts outside of the picture
104 | mask = np.logical_and.reduce((yy >= rmin, yy < rmax, vals > 0))
105 |
106 | return (yy[mask].astype(int), xx[mask].astype(int), vals[mask])
107 |
108 | def plot_line(x1,x2,y1,y2,psf):
109 | rr, cc, val = weighted_line(x1,x2, y1, y2,2)
110 | psf[rr,cc] = val[:,None]
111 | return psf
112 |
113 | def viz_psf(psf, param, center_R = None, g = 2.2, weight = 4):
114 | psf = (psf / np.max(psf))**(1/g)
115 | if center_R is not None:
116 | size = 2*center_R
117 | w = h = int(param.img_res/2-center_R)
118 | psf_center = psf[h:h+size,w:w+size]
119 | psf_center = cv2.resize(psf_center, (center_R * 15,center_R * 15),interpolation = cv2.INTER_NEAREST)
120 | psf[:center_R * 15, :center_R * 15] = psf_center
121 |
122 | if len(psf.shape) == 2:
123 | psf = pyplot.cm.hot(psf)[...,:-1]
124 |
125 | # psf[:int(param.img_res/3),0:weight,:] = 1
126 | psf[:center_R * 15+ weight,center_R * 15: center_R * 15 + weight,:] = 1
127 | # psf[0:weight,:int(param.img_res/3),:] = 1
128 | psf[center_R * 15:center_R * 15 + weight,:center_R * 15+ weight,:] = 1
129 | psf[h:h+size+ weight,w-weight:w,:] = 1
130 | psf[h:h+size+ weight,w + size : w+size+weight,:] = 1
131 | psf[h-weight:h,w-weight:w+size+ weight,:] = 1
132 | psf[h+size:h+size+weight,w:w+size+ weight,:] = 1
133 | rr, cc, val = weighted_line(center_R * 15,0, h+size + weight, w,weight)
134 | psf[rr,cc] = val[:,None]
135 | rr, cc, val = weighted_line(0, center_R * 15,h, w+size+ weight,weight)
136 | psf[rr,cc] = val[:,None]
137 | else:
138 | if len(psf.shape) == 2:
139 | psf = pyplot.cm.hot(psf)[...,:-1]
140 | return psf
141 |
142 | def plot_psf_array(psfs, param, center_R = 10, gap = 10, g = 1.5):
143 | # if psfs.shape[1] == 1:
144 | # psfs = torch.tile(psfs,(1,3,1,1))
145 | psfs_singles = torch.split(psfs, param.img_res, dim=-1)
146 | cnt = len(psfs_singles)
147 | canvas = np.ones([param.img_res, param.img_res * cnt + gap * (cnt-1), 3])
148 | for i in range(cnt):
149 | if psfs.shape[1] == 1:
150 | psf = psfs_singles[i][0,0].cpu().numpy()
151 | else:
152 | psf = psfs_singles[i][0].permute(1,2,0).cpu().numpy()
153 | if i < 3:
154 | canvas[:, i * (param.img_res+gap):i * (param.img_res+gap) + param.img_res, :] = viz_psf(psf, param, center_R = center_R, g = g)
155 |
156 | else:
157 | canvas[:, i * (param.img_res+gap):i * (param.img_res+gap) + param.img_res, :] = viz_psf(psf, param, center_R = None, g = g)
158 | pyplot.figure(figsize = (30,10))
159 | pyplot.imshow(canvas)
160 | return canvas
161 |
162 |
163 |
164 |
--------------------------------------------------------------------------------
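weighted_line above returns row indices, column indices, and anti-aliasing weights for a thick line segment; viz_psf uses it to draw the zoom-in guide lines. A minimal sketch rasterizing one segment onto an empty canvas, assuming the repository is on the Python path so the helper imports cleanly:

import numpy as np
from utils.save import weighted_line

canvas = np.zeros((64, 64))

# a 2-pixel-wide anti-aliased segment from (5, 5) to (40, 60)
rr, cc, val = weighted_line(5, 5, 40, 60, w=2, rmin=0, rmax=canvas.shape[0])
canvas[rr, cc] = val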
/utils/utils.py:
--------------------------------------------------------------------------------
1 | from pado.light import *
2 | from pado.optical_element import *
3 | from pado.propagator import *
4 | import torch
5 | import numpy as np
6 |
7 | def convert_resolution(param, args):
8 | # dataset
9 | if args.obstruction == 'fence':
10 | param.dataset_dir = '/projects/FHEIDE/obstruction_free_doe/Places365'
11 | param.data_resolution = [512,768]
12 |     elif args.obstruction in ('raindrop', 'dirt', 'dirt_raindrop'):
13 | param.training_dir = '/projects/FHEIDE/Bad2ClearWeather/Cityscapes/leftImg8bit_trainvaltest/leftImg8bit/train'
14 | param.val_dir = '/projects/FHEIDE/Bad2ClearWeather/Cityscapes/leftImg8bit_trainvaltest/leftImg8bit/val'
15 | param.data_resolution = [1024, 2048]
16 | else:
17 | assert False, "undefined obstruction"
18 |
19 | # convert resolution and pitch size
20 | param.equiv_image_size = param.img_res * param.image_sample_ratio # image resolution before downsampling in camera pixel pitch
21 | param.equiv_crop_size = int(param.equiv_image_size * param.camera_pitch / param.background_pitch) # convert to background pixel pitch
22 | return param
23 |
24 | def randuni(low, high, size):
25 | '''uniformly sample from [low, high)'''
26 | return (torch.rand(size)*(high - low) + low)
27 |
28 | def real2complex(real):
29 | return Complex(mag=real, ang=torch.zeros_like(real))
30 |
31 | def compute_pad_size(current_size, target_size):
32 | assert current_size < target_size
33 | gap = target_size - current_size
34 | left = int(gap/2)
35 | right = gap - left
36 | return int(left), int(right)
37 |
38 | def sample_psf(psf, sample_ratio):
39 | if sample_ratio == 1:
40 | return psf
41 | else:
42 | return torch.nn.AvgPool2d(sample_ratio, stride=sample_ratio)(psf)
43 |
44 | def metric2pixel(metric, depth, args):
45 | return int(metric * args.param.focal_length / (depth * args.param.equiv_camera_pitch))
46 |
47 | def edge_mask(R,cutoff, device):
48 | [x, y] = np.mgrid[-int(R):int(R),-int(R):int(R)]
49 | dist = np.sqrt(x**2 +y**2).astype(np.int32)
50 | mask = torch.tensor(1.0*(dist < cutoff))[None, None, ...]
51 | return mask.to(device)
52 |
53 | class AttributeDict(dict):
54 | def __getattr__(self, attr):
55 | return self[attr]
56 | def __setattr__(self, attr, value):
57 | self[attr] = value
--------------------------------------------------------------------------------
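convert_resolution above chains two unit conversions: the network-resolution image is first scaled by image_sample_ratio to sensor pixels, then by camera_pitch/background_pitch to background-texture pixels. A small sketch of the same arithmetic using AttributeDict as a stand-in for the loaded param module; every numeric value is a hypothetical placeholder, not a configuration from this repository.

from utils.utils import AttributeDict, randuni

param = AttributeDict()
param.img_res = 400              # hypothetical network-side resolution
param.image_sample_ratio = 4     # hypothetical sensor downsampling factor
param.camera_pitch = 2e-6        # hypothetical camera pixel pitch [m]
param.background_pitch = 4e-6    # hypothetical background pixel pitch [m]

# mirrors the arithmetic in convert_resolution
equiv_image_size = param.img_res * param.image_sample_ratio                            # 1600 sensor px
equiv_crop_size = int(equiv_image_size * param.camera_pitch / param.background_pitch)  # 800 background px

samples = randuni(0.1, 1.0, (8,))  # uniform samples from [0.1, 1.0), per the helper above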