├── .gitignore ├── LICENSE ├── README.md ├── diffoptics ├── __init__.py ├── basics.py ├── optics.py ├── shapes.py ├── solvers.py └── version.py ├── examples ├── autodiff.py ├── backprop_compare.py ├── caustic_pyramid.py ├── data │ └── 20210304 │ │ └── ref2.tif ├── end2end_edof_backward_tracing.py ├── images │ ├── einstein.jpg │ ├── point.tif │ └── squirrel.jpg ├── lenses │ ├── DoubleGauss │ │ ├── US02532751-1.txt │ │ └── US02532751-1.zmx │ ├── README.md │ ├── ThorLabs │ │ ├── ACL5040U.txt │ │ └── LA1131.txt │ ├── Zemax_samples │ │ └── Nikon-z35-f1.8-JPA2019090949-example2.txt │ └── end2end │ │ └── end2end_edof.txt ├── misalignment_point.py ├── neural_networks │ └── DeblurGANv2 │ │ ├── .gitignore │ │ ├── LICENSE │ │ ├── README.md │ │ ├── adversarial_trainer.py │ │ ├── aug.py │ │ ├── config │ │ └── config.yaml │ │ ├── dataset.py │ │ ├── metric_counter.py │ │ ├── models │ │ ├── __init__.py │ │ ├── fpn_densenet.py │ │ ├── fpn_inception.py │ │ ├── fpn_inception_simple.py │ │ ├── fpn_mobilenet.py │ │ ├── losses.py │ │ ├── mobilenet_v2.py │ │ ├── models.py │ │ ├── networks.py │ │ ├── senet.py │ │ └── unet_seresnext.py │ │ ├── picture_to_video.py │ │ ├── predict.py │ │ ├── requirements.txt │ │ ├── schedulers.py │ │ ├── test.py │ │ ├── test.sh │ │ ├── test_aug.py │ │ ├── test_batchsize.py │ │ ├── test_dataset.py │ │ ├── test_metrics.py │ │ ├── train.py │ │ ├── train_end2end.py │ │ └── util │ │ ├── __init__.py │ │ ├── image_pool.py │ │ └── metrics.py ├── nikon.py ├── render_image.py ├── render_psf.py ├── sanity_check.py ├── spherical_aberration.py ├── training_dataset │ ├── 0008.png │ ├── 0010.png │ ├── 0023.png │ ├── 0030.png │ ├── 0031.png │ ├── 0032.png │ ├── 0115.png │ └── 0267.png └── utils_end2end.py ├── imgs ├── abp.jpg ├── applications.jpg ├── bp_abp_comp.jpg ├── examples │ ├── I.jpg │ ├── I0.jpg │ ├── I_final.png │ ├── I_psf_z=-1000.0.png │ ├── I_psf_z=-10000.0.png │ ├── I_psf_z=-1500.0.png │ ├── I_psf_z=-2000.0.png │ ├── I_psf_z=-3000.0.png │ ├── I_rendered.jpg │ ├── I_target.png │ ├── iter_1_z=6000.0mm_images.png │ ├── optimized.gif │ ├── optimized.mp4 │ ├── phase.png │ ├── sanity_check_dO.jpg │ └── sanity_check_zemax.jpg ├── memory_comp.jpg └── overview.jpg └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/.gitignore -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 vccimaging 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # dO: A differentiable engine for Deep Lens design of computational imaging systems 2 | This is the PyTorch implementation for our paper "dO: A differentiable engine for Deep Lens design of computational imaging systems". 3 | ### [Project Page](https://vccimaging.org/Publications/Wang2022DiffOptics/) | [Paper](https://vccimaging.org/Publications/Wang2022DiffOptics/Wang2022DiffOptics.pdf) | [Supplementary Material](https://vccimaging.org/Publications/Wang2022DiffOptics/Wang2022DiffOptics_supp.pdf) 4 | 5 | dO: A differentiable engine for Deep Lens design of computational imaging systems 6 | [Congli Wang](https://congliwang.github.io), 7 | [Ni Chen](https://ni-chen.github.io), and 8 | [Wolfgang Heidrich](https://vccimaging.org/People/heidriw)
9 | King Abdullah University of Science and Technology (KAUST)
10 | IEEE Transactions on Computational Imaging 2022 11 | 12 | 13 | Figure: Our engine dO models ray tracing in a lens system in a derivative-aware way, this enables ray tracing with back-propagation. To be derivative-aware, all modules must be differentiable so that gradients can be back-propagated from the error metric ϵ(p(θ)) to variable parameters θ. This is achieved by two stages of the reverse-mode AD: the forward and the backward passes. To ensure differentiability and efficiency, a custom ray-surface intersection solver is introduced. Instead of unrolling iterations for forward/backward, only the forward (no AD) is computed to obtain solutions at surfaces fi = 0, and gradients are amended afterwards. 14 | 15 | ## TL; DR 16 | 17 | We implemented in PyTorch a memory- and computation-efficient differentiable ray tracing system for optical designs, for design applications in freeform, Deep Lens, metrology, and more. 18 | 19 | ## Update list 20 | 21 | - [x] Initial code release. 22 | - [x] [`autodiff.py`](./examples/autodiff.py): Demo of the `dO` engine. 23 | - [x] [`backprop_compare.py`](./examples/backprop_compare.py): Example on comparison between back-propagation and adjoint back-propagation. 24 | - [x] [`caustic_pyramid.py`](./examples/caustic_pyramid.py): Example on freeform caustic design. 25 | 26 | | Target irradiance | Optimized irradiance | Optimized phase map | 27 | | :-------------------------------------: | :-----------------------------------: | :-------------------------------: | 28 | | ![I_target](imgs/examples/I_target.png) | ![I_final](imgs/examples/I_final.png) | ![phase](imgs/examples/phase.png) | 29 | 30 | - [x] [`misalignment_point.py`](./examples/misalignment_point.py): Example on misalignment back-engineering, using real measurements. 31 | 32 | | Model (initial) | Measurement | Model (optimized) | 33 | | :-------------------------: | :-----------------------: | :---------------------------------------: | 34 | | ![I0](imgs/examples/I0.jpg) | ![I](imgs/examples/I.jpg) | ![optimized](imgs/examples/optimized.gif) | 35 | 36 | - [x] [`nikon.py`](./examples/nikon.py): Example on optimizing a Nikon design. 37 | - [x] [`render_image.py`](./examples/render_image.py): Example on rendering a single image from a design. 38 | 39 | | ![I_rendered](imgs/examples/I_rendered.jpg) | 40 | | ------------------------------------------- | 41 | 42 | - [x] [`render_psf.py`](./examples/render_psf.py): Example on rendering PSFs of varying fields and depths for a design. 43 | 44 | | ![I_psf_z=-3000.0](imgs/examples/I_psf_z=-3000.0.png) | ![I_psf_z=-2000.0](imgs/examples/I_psf_z=-2000.0.png) | 45 | | ------------------------------------------------------------ | ------------------------------------------------------------ | 46 | | ![I_psf_z=-1500.0](imgs/examples/I_psf_z=-1500.0.png) | ![I_psf_z=-1000.0](imgs/examples/I_psf_z=-1000.0.png) | 47 | 48 | - [x] [`sanity_check.py`](./examples/sanity_check.py): Example on Zemax versus `dO` for sanity check. 49 | 50 | | `dO` | Zemax | 51 | | :---------------------------------------------------: | :---------------------------------------------------------: | 52 | | ![sanity_check_dO](imgs/examples/sanity_check_dO.jpg) | ![sanity_check_zemax](imgs/examples/sanity_check_zemax.jpg) | 53 | 54 | 55 | - [x] [`spherical_aberration.py`](./examples/spherical_aberration.py): Example on optimizing spherical aberration. 
56 | 57 | - [x] [`end2end_edof_backward_tracing.py`](./examples/end2end_edof_backward_tracing.py): Example on end-to-end learning of wavefront coding, for extended depth of field applications, using backward ray tracing. 58 | 59 | | ![iter_1_z=6000.0mm_images](imgs/examples/iter_1_z=6000.0mm_images.png) | 60 | | ------------------------------------------------------------ | 61 | 62 | - [ ] Code cleanups and add comments. 63 | 64 | - [ ] File I/O with Zemax. 65 | 66 | - [ ] Mini GUI for easy operations. 67 | 68 | ## Installation 69 | 70 | ### Prerequisite 71 | Though no GPU is required, for speed's sake it is better to run the engine on a GPU. 72 | 73 | Install the required Python packages: 74 | ```shell 75 | pip install -r requirements.txt 76 | ``` 77 | 78 | ### Running examples 79 | Examples are in the [`./examples`](./examples) folder, and running some of the examples may require installing additional Python packages. Just follow the terminal hints; for example, install the following: 80 | 81 | ```shell 82 | pip install imageio opencv-python scikit-image 83 | ``` 84 | 85 | In case Python cannot find the path to `dO`, run the example scripts from inside the [`./examples`](./examples) directory, for example: 86 | 87 | ```shell 88 | cd ./examples 89 | python3 misalignment_point.py 90 | ``` 91 | 92 | 93 | ## Summary 94 | 95 | ### Target problem 96 | 97 | - General optical design/metrology or Deep Lens designs are parameter-optimization problems, and learning-based methods (e.g. with back-propagation) can be employed as solvers. This requires the optical modeling to be numerically derivative-aware (i.e. differentiable). 98 | - However, straightforward differentiable ray tracing with auto-diff (AD) is not memory/computation-efficient. 99 | 100 | ### Our solutions 101 | 102 | - Differentiable ray-surface intersections require a differentiable root-finding solver, which is typically iterative, like Newton's method. A straightforward implementation is inefficient in both memory and computation. However, our paper makes the observation that the intermediate status of the solver's iterations is *irrelevant* to the final solution. That means a differentiable root-finding solver can be implemented smartly as: (1) find the solution without AD (e.g. inside a `with torch.no_grad()` block in PyTorch), and (2) re-engage AD at the solution found (see the code sketch at the end of this section). This greatly reduces memory consumption, scaling the system's differentiability up to a large number of parameters or rays. 103 | 104 | | ![](./imgs/memory_comp.jpg) | 105 | | ------------------------------------------------------------ | 106 | | Figure: Comparison between the straightforward and our proposed differentiable ray-surface intersection methods for freeform surface optimization. Our method reduces the required memory by about 6 times. | 107 | 108 | - When optimizing a custom merit function for image-based applications with an appended neural network, e.g. in Deep Lens designs, the training (i.e. back-propagation) can be split into two parts: 109 | - (Front-end) Optical design parameter optimization (training). 110 | - (Back-end) Neural network post-processing training. 111 | 112 | This de-coupling resembles the checkpointing technique in deep learning, and hence reduces memory pressure when tracing a large number of rays.
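Below is a minimal, self-contained sketch of the first point above: solve for the ray-surface intersection without AD, then re-attach gradients at the converged solution via the implicit function theorem. The class and variable names are made up for illustration and are not dO's actual solver API; the snippet only demonstrates the principle.

```python
import torch

class ImplicitRoot(torch.autograd.Function):
    """Root t of f(t, theta) = 0; gradient w.r.t. theta via implicit differentiation."""

    @staticmethod
    def forward(ctx, theta, f, dfdt, t0):
        with torch.no_grad():                  # (1) Newton iterations, no AD graph recorded
            t = t0.clone()
            for _ in range(20):
                t = t - f(t, theta) / dfdt(t, theta)
        ctx.save_for_backward(theta, t)
        ctx.f, ctx.dfdt = f, dfdt
        return t

    @staticmethod
    def backward(ctx, grad_out):
        theta, t = ctx.saved_tensors
        w = -grad_out / ctx.dfdt(t, theta)     # adjoint weight per root
        with torch.enable_grad():              # (2) re-engage AD only at the solution
            theta_ = theta.detach().requires_grad_(True)
            surrogate = (w.detach() * ctx.f(t, theta_)).sum()
            grad_theta, = torch.autograd.grad(surrogate, theta_)
        return grad_theta, None, None, None

# Toy usage: root of t^3 + theta*t - 1 = 0, differentiable in theta.
theta = torch.tensor(2.0, requires_grad=True)
f    = lambda t, th: t**3 + th * t - 1.0
dfdt = lambda t, th: 3 * t**2 + th
t = ImplicitRoot.apply(theta, f, dfdt, torch.tensor(1.0))
t.backward()    # theta.grad equals -t / (3*t**2 + theta), the implicit derivative
```

None of the Newton iterates are kept for the backward pass; only the converged solution participates in AD, which is where the memory savings come from.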
113 | 114 | | ![](./imgs/abp.jpg) | 115 | | ------------------------------------------------------------ | 116 | | ![](./imgs/bp_abp_comp.jpg) | 117 | | Figure: Adjoint back-propagation (Adjoint BP) and the corresponding comparison against back-propagation (BP). Our implementation scales up to many millions of rays, while the conventional approach cannot. | 118 | 119 | ### Applications | 120 | 121 | | ![](./imgs/applications.jpg) | 122 | | :----------------------------------------------------------: | 123 | | Figure: Using dO, our differentiable ray tracing system, we show the feasibility of advanced optical designs. | 124 | 125 | ## Relevant Project 126 | 127 | [Towards self-calibrated lens metrology by differentiable refractive deflectometry](https://vccimaging.org/Publications/Wang2021DiffDeflectometry/Wang2021DiffDeflectometry.pdf) 128 | [Congli Wang](https://congliwang.github.io), 129 | [Ni Chen](https://ni-chen.github.io), and 130 | [Wolfgang Heidrich](https://vccimaging.org/People/heidriw)
131 | King Abdullah University of Science and Technology (KAUST)
132 | OSA Optics Express 2021 133 | 134 | GitHub: https://github.com/vccimaging/DiffDeflectometry. 135 | 136 | ## Citation 137 | 138 | ```bibtex 139 | @article{wang2022dO, 140 | title={{dO: A differentiable engine for Deep Lens design of computational imaging systems}}, 141 | author={Wang, Congli and Chen, Ni and Heidrich, Wolfgang}, 142 | journal={IEEE Transactions on Computational Imaging}, 143 | year={2022}, 144 | volume={8}, 145 | number={}, 146 | pages={905-916}, 147 | doi={10.1109/TCI.2022.3212837}, 148 | publisher={IEEE} 149 | } 150 | ``` 151 | 152 | ## Contact 153 | Please either open an issue, or contact Congli Wang for questions. 154 | 155 | -------------------------------------------------------------------------------- /diffoptics/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ 2 | 3 | from .basics import * 4 | from .shapes import * 5 | from .optics import * 6 | from .solvers import * 7 | -------------------------------------------------------------------------------- /diffoptics/basics.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | import torch 3 | import numpy as np 4 | 5 | 6 | class PrettyPrinter(): 7 | def __str__(self): 8 | lines = [self.__class__.__name__ + ':'] 9 | for key, val in vars(self).items(): 10 | if val.__class__.__name__ in ('list', 'tuple'): 11 | for i, v in enumerate(val): 12 | lines += '{}[{}]: {}'.format(key, i, v).split('\n') 13 | 14 | elif val.__class__.__name__ in 'dict': 15 | pass 16 | elif key == key.upper() and len(key) > 5: 17 | pass 18 | else: 19 | lines += '{}: {}'.format(key, val).split('\n') 20 | return '\n '.join(lines) 21 | 22 | def to(self, device=torch.device('cpu')): 23 | for key, val in vars(self).items(): 24 | if torch.is_tensor(val): 25 | exec('self.{x} = self.{x}.to(device)'.format(x=key)) 26 | elif issubclass(type(val), PrettyPrinter): 27 | exec(f'self.{key}.to(device)') 28 | elif val.__class__.__name__ in ('list', 'tuple'): 29 | for i, v in enumerate(val): 30 | if torch.is_tensor(v): 31 | exec('self.{x}[{i}] = self.{x}[{i}].to(device)'.format(x=key, i=i)) 32 | elif issubclass(type(v), PrettyPrinter): 33 | exec('self.{}[{}].to(device)'.format(key, i)) 34 | 35 | 36 | class Ray(PrettyPrinter): 37 | """ 38 | Definition of a geometric ray. 39 | 40 | - o is the ray origin 41 | - d is the ray direction (normalized) 42 | """ 43 | def __init__(self, o, d, wavelength, device=torch.device('cpu')): 44 | self.o = o 45 | self.d = d 46 | 47 | # scalar-version 48 | self.wavelength = wavelength # [nm] 49 | self.mint = 1e-5 # [mm] 50 | self.maxt = 1e5 # [mm] 51 | self.to(device) 52 | 53 | def __call__(self, t): 54 | return self.o + t[..., None] * self.d 55 | 56 | 57 | class Transformation(PrettyPrinter): 58 | """ 59 | Rigid Transformation. 60 | 61 | - R is the rotation matrix. 62 | - t is the translational vector. 
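Applying the transform maps a point o to R @ o + t, and `inverse()` returns the rigid inverse (R^T, -R^T t). A quick illustrative example (not part of the library's test suite):

        >>> T = Transformation(torch.eye(3), torch.Tensor([0., 0., 5.]))
        >>> T.transform_point(torch.Tensor([1., 2., 3.]))
        tensor([1., 2., 8.])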
63 | """ 64 | def __init__(self, R, t): 65 | if torch.is_tensor(R): 66 | self.R = R 67 | else: 68 | self.R = torch.Tensor(R) 69 | if torch.is_tensor(t): 70 | self.t = t 71 | else: 72 | self.t = torch.Tensor(t) 73 | 74 | def transform_point(self, o): 75 | return torch.squeeze(self.R @ o[..., None]) + self.t 76 | 77 | def transform_vector(self, d): 78 | return torch.squeeze(self.R @ d[..., None]) 79 | 80 | def transform_ray(self, ray): 81 | o = self.transform_point(ray.o) 82 | d = self.transform_vector(ray.d) 83 | if o.is_cuda: 84 | return Ray(o, d, ray.wavelength, device=torch.device('cuda')) 85 | else: 86 | return Ray(o, d, ray.wavelength) 87 | 88 | def inverse(self): 89 | RT = self.R.T 90 | t = self.t 91 | return Transformation(RT, -RT @ t) 92 | 93 | 94 | class Spectrum(PrettyPrinter): 95 | """ 96 | Spectrum distribution of the rays. 97 | """ 98 | def __init__(self): 99 | self.WAVELENGTH_MIN = 400 # [nm] 100 | self.WAVELENGTH_MAX = 760 # [nm] 101 | self.to() 102 | 103 | def sample_wavelength(self, sample): 104 | return self.WAVELENGTH_MIN + (self.WAVELENGTH_MAX - self.WAVELENGTH_MIN) * sample 105 | 106 | 107 | class Sampler(PrettyPrinter): 108 | """ 109 | Sampler for generating random sample points. 110 | """ 111 | def __init__(self): 112 | self.to() 113 | self.pi_over_2 = np.pi / 2 114 | self.pi_over_4 = np.pi / 4 115 | 116 | def concentric_sample_disk(self, x, y): 117 | # https://pbr-book.org/3ed-2018/Monte_Carlo_Integration/2D_Sampling_with_Multidimensional_Transformations 118 | 119 | # map uniform random numbers to [-1,1]^2 120 | x = 2 * x - 1 121 | y = 2 * y - 1 122 | 123 | # handle degeneracy at the origin when xy == [0,0] 124 | 125 | # apply concentric mapping to point 126 | eps = np.finfo(float).eps 127 | 128 | if type(x) is torch.Tensor and type(y) is torch.Tensor: 129 | cond = torch.abs(x) > torch.abs(y) 130 | r = torch.where(cond, x, y) 131 | theta = torch.where(cond, 132 | self.pi_over_4 * (y / (x + eps)), 133 | self.pi_over_2 - self.pi_over_4 * (x / (y + eps)) 134 | ) 135 | return r * torch.cos(theta), r * torch.sin(theta) 136 | 137 | if type(x) is np.ndarray and type(y) is np.ndarray: 138 | cond = np.abs(x) > np.abs(y) 139 | r = np.where(cond, x, y) 140 | theta = np.where(cond, 141 | self.pi_over_4 * (y / (x + eps)), 142 | self.pi_over_2 - self.pi_over_4 * (x / (y + eps)) 143 | ) 144 | return r * np.cos(theta), r * np.sin(theta) 145 | 146 | 147 | class Filter(PrettyPrinter): 148 | def __init__(self, radius): 149 | self.radius = radius 150 | def eval(self, p): 151 | raise NotImplementedError() 152 | 153 | class Box(Filter): 154 | def __init__(self, radius=None): 155 | if radius is None: 156 | radius = [0.5, 0.5] 157 | Filter.__init__(self, radius) 158 | def eval(self, x): 159 | return torch.ones_like(x) 160 | 161 | class Triangle(Filter): 162 | def __init__(self, radius): 163 | if radius is None: 164 | radius = [2.0, 2.0] 165 | Filter.__init__(self, radius) 166 | def eval(self, p): 167 | x, y = p[...,0], p[...,1] 168 | return (torch.maximum(torch.zeros_like(x), self.radius[0] - x) * 169 | torch.maximum(torch.zeros_like(y), self.radius[1] - y)) 170 | 171 | # ---------------------------------------------------------------------------------------- 172 | 173 | class Material(PrettyPrinter): 174 | """ 175 | Optical materials for computing the refractive indices. 176 | 177 | The following follows the simple formula that 178 | 179 | n(\lambda) = A + B / \lambda^2 180 | 181 | where the two constants A and B can be computed from nD (index at 589.3 nm) and V (abbe number). 
182 | """ 183 | def __init__(self, name=None): 184 | self.name = 'vacuum' if name is None else name.lower() 185 | 186 | # This table is hard-coded. TODO: Import glass libraries from Zemax. 187 | self.MATERIAL_TABLE = { # [nD, Abbe number] 188 | "vacuum": [1., np.inf], 189 | "air": [1.000293, np.inf], 190 | "occluder": [1., np.inf], 191 | "f2": [1.620, 36.37], 192 | "f15": [1.60570, 37.831], 193 | "uvfs": [1.458, 67.82], 194 | 195 | # https://shop.schott.com/advanced_optics/ 196 | "bk10": [1.49780, 66.954], 197 | "n-baf10": [1.67003, 47.11], 198 | "n-bk7": [1.51680, 64.17], 199 | "n-sf1": [1.71736, 29.62], 200 | "n-sf2": [1.64769, 33.82], 201 | "n-sf4": [1.75513, 27.38], 202 | "n-sf5": [1.67271, 32.25], 203 | "n-sf6": [1.80518, 25.36], 204 | "n-sf6ht": [1.80518, 25.36], 205 | "n-sf8": [1.68894, 31.31], 206 | "n-sf10": [1.72828, 28.53], 207 | "n-sf11": [1.78472, 25.68], 208 | "sf1": [1.71736, 29.51], 209 | "sf2": [1.64769, 33.85], 210 | "sf4": [1.75520, 27.58], 211 | "sf5": [1.67270, 32.21], 212 | "sf6": [1.80518, 25.43], 213 | "sf18": [1.72150, 29.245], 214 | 215 | # HIKARI.AGF 216 | "baf10": [1.67, 47.05], 217 | 218 | # SUMITA.AGF 219 | "sk1": [1.61030, 56.712], 220 | "sk16": [1.62040, 60.306], 221 | "ssk4": [1.61770, 55.116], 222 | 223 | # https://www.pgo-online.com/intl/B270.html 224 | "b270": [1.52290, 58.50], 225 | 226 | # https://refractiveindex.info, nD at 589.3 nm 227 | "s-nph1": [1.8078, 22.76], 228 | "d-k59": [1.5175, 63.50], 229 | 230 | "flint": [1.6200, 36.37], 231 | "pmma": [1.491756, 58.00], 232 | "polycarb": [1.585470, 30.00] 233 | } 234 | self.A, self.B = self._lookup_material() 235 | 236 | def ior(self, wavelength): 237 | """Computes index of refraction of a given wavelength (in [nm])""" 238 | return self.A + self.B / wavelength**2 239 | 240 | @staticmethod 241 | def nV_to_AB(n, V): 242 | def ivs(a): return 1./a**2 243 | lambdas = [656.3, 589.3, 486.1] 244 | B = 0.0 if V == 0 else (n - 1) / V / ( ivs(lambdas[2]) - ivs(lambdas[0]) ) 245 | A = n - B * ivs(lambdas[1]) 246 | return A, B 247 | 248 | def _lookup_material(self): 249 | out = self.MATERIAL_TABLE.get(self.name) 250 | if isinstance(out, list): 251 | n, V = out 252 | elif out is None: 253 | # try parsing input as a n/V pair 254 | tmp = self.name.split('/') 255 | n, V = float(tmp[0]), float(tmp[1]) 256 | return self.nV_to_AB(n, V) 257 | 258 | def to_string(self): 259 | return f'{self.A} + {self.B}/lambda^2' 260 | 261 | 262 | class InterpolationMode(Enum): 263 | nearest = 1 264 | linear = 2 265 | 266 | class BoundaryMode(Enum): 267 | zero = 1 268 | replicate = 2 269 | symmetric = 3 270 | periodic = 4 271 | 272 | class SimulationMode(Enum): 273 | render = 1 274 | trace = 2 275 | 276 | 277 | """ 278 | Utility functions. 
279 | """ 280 | 281 | def init(): 282 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 283 | print("DiffOptics is using: {}".format(device)) 284 | torch.set_default_tensor_type('torch.FloatTensor') 285 | return device 286 | 287 | def length2(d): 288 | return torch.sum(d**2, axis=-1) 289 | 290 | def length(d): 291 | return torch.sqrt(length2(d)) 292 | 293 | def normalize(d): 294 | return d / length(d)[..., None] 295 | 296 | def set_zeros(x, valid=None): 297 | if valid == None: 298 | return torch.where(torch.isnan(x), torch.zeros_like(x), x) 299 | else: 300 | mask = valid[...,None] if len(x.shape) > len(valid.shape) else valid 301 | return torch.where(~mask, torch.zeros_like(x), x) 302 | 303 | def rodrigues_rotation_matrix(k, theta): # theta: [rad] 304 | """ 305 | This function implements the Rodrigues rotation matrix. 306 | """ 307 | # cross-product matrix 308 | kx, ky, kz = k[0], k[1], k[2] 309 | K = torch.Tensor([ 310 | [ 0, -kz, ky], 311 | [ kz, 0, -kx], 312 | [-ky, kx, 0] 313 | ]).to(k.device) 314 | if not torch.is_tensor(theta): 315 | theta = torch.Tensor(np.asarray(theta)).to(k.device) 316 | return torch.eye(3, device=k.device) + torch.sin(theta) * K + (1 - torch.cos(theta)) * K @ K 317 | 318 | def set_axes_equal(ax, scale=np.ones(3)): 319 | """ 320 | Make axes of 3D plot have equal scale (or scaled by `scale`). 321 | """ 322 | limits = np.array([ 323 | ax.get_xlim3d(), 324 | ax.get_ylim3d(), 325 | ax.get_zlim3d() 326 | ]) 327 | tmp = np.abs(limits[:,1]-limits[:,0]) 328 | ax.set_box_aspect(scale * tmp/np.min(tmp)) 329 | 330 | 331 | """ 332 | Test functions 333 | """ 334 | 335 | def generate_test_rays(): 336 | filmsize = np.array([4, 2]) 337 | 338 | o = np.array([3,4,-200]) 339 | o = np.tile(o[None, None, ...], [*filmsize, 1]) 340 | o = torch.Tensor(o) 341 | 342 | dx = 0.1 * torch.rand(*filmsize) 343 | dy = 0.1 * torch.rand(*filmsize) 344 | d = normalize(torch.stack((dx, dy, torch.ones_like(dx)), axis=-1)) 345 | 346 | wavelength = 500 # [nm] 347 | return Ray(o, d, wavelength) 348 | 349 | def generate_test_transformation(): 350 | k = np.random.rand(3) 351 | k = k / np.sqrt(np.sum(k**2)) 352 | k = torch.Tensor(k) 353 | theta = 1 # [rad] 354 | R = rodrigues_rotation_matrix(k, theta) 355 | t = np.random.rand(3) 356 | return Transformation(R, t) 357 | 358 | def generate_test_material(): 359 | return Material('N-BK7') 360 | 361 | 362 | if __name__ == "__main__": 363 | init() 364 | 365 | rays = generate_test_rays() 366 | to_world = generate_test_transformation() 367 | print(to_world) 368 | 369 | rays_new = to_world.transform_ray(rays) 370 | o_old = rays.o[2,1,...].numpy() 371 | o_new = rays_new.o[2,1,...].numpy() 372 | assert (to_world.transform_point(o_old) - o_new).abs().sum() < 1e-15 373 | 374 | material = generate_test_material() 375 | print(material) 376 | -------------------------------------------------------------------------------- /diffoptics/shapes.py: -------------------------------------------------------------------------------- 1 | from .basics import * 2 | 3 | 4 | class Endpoint(PrettyPrinter): 5 | """ 6 | Abstract class for objects. 
7 | """ 8 | def __init__(self, transformation, device=torch.device('cpu')): 9 | self.to_world = transformation 10 | self.to_object = transformation.inverse() 11 | self.to(device) 12 | self.device = device 13 | 14 | def intersect(self, ray): 15 | raise NotImplementedError() 16 | 17 | def sample_ray(self, position_sample=None): 18 | raise NotImplementedError() 19 | 20 | def draw_points(self, ax, options, seq=range(3)): 21 | raise NotImplementedError() 22 | 23 | def update_Rt(self, R, t): 24 | self.to_world = Transformation(R, t) 25 | self.to_object = self.to_world.inverse() 26 | self.to(self.device) 27 | 28 | 29 | class Screen(Endpoint): 30 | """ 31 | A screen obejct, useful for image rendering. 32 | 33 | Local frame centers at [-w, w]/2 x [-h, h]/2. 34 | """ 35 | def __init__(self, transformation, size, texture, device=torch.device('cpu')): 36 | self.size = torch.Tensor(np.float32(size)) # screen dimension [mm] 37 | self.halfsize = self.size/2 # screen half-dimension [mm] 38 | self.texture_shift = torch.zeros(2) # screen image shift [mm] 39 | self.update_texture(texture) 40 | Endpoint.__init__(self, transformation, device) 41 | self.to(device) 42 | 43 | def update_texture(self, texture: torch.Tensor): 44 | self.texture = texture # screen image 45 | self.texturesize = torch.Tensor(np.array(texture.shape[0:2])).long().to(texture.device) # screen image dimension [pixel] 46 | 47 | def intersect(self, ray): 48 | ray_in = self.to_object.transform_ray(ray) 49 | t = - ray_in.o[..., 2] / (1e-10 + ray_in.d[..., 2]) # (TODO: potential NaN grad) 50 | local = ray_in(t) 51 | 52 | # Is intersection within ray segment and rectangle? 53 | valid = ( 54 | (t >= ray_in.mint) & 55 | (t <= ray_in.maxt) & 56 | (torch.abs(local[..., 0] - self.texture_shift[0]) <= self.halfsize[0]) & 57 | (torch.abs(local[..., 1] - self.texture_shift[1]) <= self.halfsize[1]) 58 | ) 59 | 60 | # UV coordinate 61 | uv = (local[..., 0:2] + self.halfsize - self.texture_shift) / self.size 62 | 63 | # Force uv to be valid in [0,1]^2 (just a sanity check: uv should be in [0,1]^2) 64 | uv = torch.clamp(uv, min=0.0, max=1.0) 65 | 66 | return local, uv, valid 67 | 68 | def shading(self, uv, valid, bmode=BoundaryMode.replicate, lmode=InterpolationMode.linear): 69 | # p = uv * (self.texturesize[None, None, ...]-1) 70 | p = uv * (self.texturesize-1) 71 | p_floor = torch.floor(p).long() 72 | 73 | def tex(x, y): 74 | """ 75 | Texture indexing function, handle various boundary conditions. 76 | """ 77 | if bmode is BoundaryMode.zero: 78 | raise NotImplementedError() 79 | elif bmode is BoundaryMode.replicate: 80 | x = torch.clamp(x, min=0, max=self.texturesize[0].item()-1) 81 | y = torch.clamp(y, min=0, max=self.texturesize[1].item()-1) 82 | elif bmode is BoundaryMode.symmetric: 83 | raise NotImplementedError() 84 | elif bmode is BoundaryMode.periodic: 85 | raise NotImplementedError() 86 | img = self.texture[x.flatten(), y.flatten()] 87 | return img.reshape(x.shape) 88 | 89 | # Texture fetching, requires interpolation to compute fractional pixel values. 90 | if lmode is InterpolationMode.nearest: 91 | val = tex(p_floor[...,0], p_floor[...,1]) 92 | elif lmode is InterpolationMode.linear: 93 | x0, y0 = p_floor[...,0], p_floor[...,1] 94 | s00 = tex( x0, y0) 95 | s01 = tex( x0, 1+y0) 96 | s10 = tex(1+x0, y0) 97 | s11 = tex(1+x0, 1+y0) 98 | w1 = p - p_floor 99 | w0 = 1. 
- w1 100 | val = ( 101 | w0[...,0] * (w0[...,1] * s00 + w1[...,1] * s01) + 102 | w1[...,0] * (w0[...,1] * s10 + w1[...,1] * s11) 103 | ) 104 | 105 | # val = val * valid 106 | # val[torch.isnan(val)] = 0.0 107 | 108 | # TODO: should be added; 109 | # but might cause RuntimeError: Function 'MulBackward0' returned nan values in its 0th output. 110 | val[~valid] = 0.0 111 | return val 112 | 113 | def draw_points(self, ax, options, seq=range(3)): 114 | """ 115 | Visualization function. 116 | """ 117 | coeffs = np.array([ 118 | [ 1, 1, 1], 119 | [-1, 1, 1], 120 | [-1,-1, 1], 121 | [ 1,-1, 1], 122 | [ 1, 1, 1] 123 | ]) 124 | points_local = torch.Tensor(coeffs * np.append(self.halfsize.cpu().detach().numpy(), 0)).to(self.device) 125 | points_world = self.to_world.transform_point(points_local).T.cpu().detach().numpy() 126 | ax.plot(points_world[seq[0]], points_world[seq[1]], points_world[seq[2]], options) 127 | 128 | -------------------------------------------------------------------------------- /diffoptics/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.1.0' -------------------------------------------------------------------------------- /examples/autodiff.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import matplotlib.pyplot as plt 5 | 6 | import sys 7 | sys.path.append("../") 8 | import diffoptics as do 9 | 10 | # initialize a lens 11 | device = torch.device('cpu') 12 | lens = do.Lensgroup(device=device) 13 | 14 | save_dir = './autodiff_demo/' 15 | if not os.path.exists(save_dir): 16 | os.mkdir(save_dir) 17 | 18 | R = 12.7 19 | surfaces = [ 20 | do.Aspheric(R, 0.0, c=0.05, device=device), 21 | do.Aspheric(R, 6.5, c=0., device=device) 22 | ] 23 | materials = [ 24 | do.Material('air'), 25 | do.Material('N-BK7'), 26 | do.Material('air') 27 | ] 28 | lens.load(surfaces, materials) 29 | lens.d_sensor = 25.0 30 | lens.r_last = 12.7 31 | 32 | # generate array of rays 33 | wavelength = torch.Tensor([532.8]).to(device) # [nm] 34 | R = 10.0 # [mm] 35 | def render(): 36 | ray_init = lens.sample_ray(wavelength, M=9, R=R, sampling='grid') 37 | ps = lens.trace_to_sensor(ray_init) 38 | return ps[...,:2] 39 | 40 | def trace_all(): 41 | ray_init = lens.sample_ray_2D(R, wavelength, M=11) 42 | ps, oss = lens.trace_to_sensor_r(ray_init) 43 | return ps[...,:2], oss 44 | 45 | def compute_Jacobian(ps): 46 | Js = [] 47 | for i in range(1): 48 | J = torch.zeros(torch.numel(ps)) 49 | for j in range(torch.numel(ps)): 50 | mask = torch.zeros(torch.numel(ps)) 51 | mask[j] = 1 52 | ps.backward(mask.reshape(ps.shape), retain_graph=True) 53 | J[j] = lens.surfaces[i].c.grad.item() 54 | lens.surfaces[i].c.grad.data.zero_() 55 | J = J.reshape(ps.shape) 56 | 57 | # get data to numpy 58 | Js.append(J.cpu().detach().numpy()) 59 | return Js 60 | 61 | 62 | N = 20 63 | cs = np.linspace(0.045, 0.063, N) 64 | Iss = [] 65 | Jss = [] 66 | for index, c in enumerate(cs): 67 | index_string = str(index).zfill(3) 68 | # load optics 69 | lens.surfaces[0].c = torch.Tensor(np.array(c)) 70 | lens.surfaces[0].c.requires_grad = True 71 | 72 | # show trace figure 73 | ps, oss = trace_all() 74 | ax, fig = lens.plot_raytraces(oss, color='b-', show=False) 75 | ax.axis('off') 76 | ax.set_title("") 77 | fig.savefig(save_dir + "layout_trace_" + index_string + ".png", bbox_inches='tight') 78 | 79 | # show spot diagram 80 | RMS = lambda ps: torch.sqrt(torch.mean(torch.sum(torch.square(ps), 
axis=-1))) 81 | ps = render() 82 | rms_org = RMS(ps) 83 | print(f'RMS: {rms_org}') 84 | lens.spot_diagram(ps, xlims=[-4, 4], ylims=[-4, 4], savepath=save_dir + "spotdiagram_" + index_string + ".png", show=False) 85 | 86 | # compute Jacobian 87 | Js = compute_Jacobian(ps)[0] 88 | print(Js.max()) 89 | print(Js.min()) 90 | ps_ = ps.cpu().detach().numpy() 91 | fig = plt.figure() 92 | x, y = ps_[:,0], ps_[:,1] 93 | plt.plot(x, y, 'b.', zorder=0) 94 | plt.quiver(x, y, Js[:,0], Js[:,1], color='b', zorder=1) 95 | plt.xlim(-4, 4) 96 | plt.ylim(-4, 4) 97 | plt.gca().set_aspect('equal', adjustable='box') 98 | plt.xlabel('x [mm]') 99 | plt.ylabel('y [mm]') 100 | fig.savefig(save_dir + "flow_" + index_string + ".png", bbox_inches='tight') 101 | 102 | # compute images 103 | ray = lens.sample_ray(wavelength.item(), view=0.0, M=2049, sampling='grid') 104 | lens.film_size = [512, 512] 105 | lens.pixel_size = 50.0e-3/2 106 | I = lens.render(ray) 107 | I = I.cpu().detach().numpy() 108 | lm = do.LM(lens, ['surfaces[0].c'], 1e-2, option='diag') 109 | JI = lm.jacobian(lambda: lens.render(ray)).squeeze() 110 | J = JI.abs().cpu().detach().numpy() 111 | 112 | Iss.append(I) 113 | Jss.append(J) 114 | plt.close() 115 | 116 | Iss = np.array(Iss) 117 | Jss = np.array(Jss) 118 | for i in range(N): 119 | plt.imsave(save_dir + "I_" + str(i).zfill(3) + ".png", Iss[i], cmap='gray') 120 | plt.imsave(save_dir + "J_" + str(i).zfill(3) + ".png", Jss[i], cmap='gray') 121 | 122 | names = [ 123 | 'spotdiagram', 124 | 'layout_trace', 125 | 'I', 126 | 'J', 127 | 'flow' 128 | ] 129 | -------------------------------------------------------------------------------- /examples/backprop_compare.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import matplotlib.pyplot as plt 4 | 5 | import sys 6 | sys.path.append("../") 7 | import diffoptics as do 8 | 9 | 10 | device = torch.device('cuda') 11 | 12 | # initialize a lens 13 | def init(): 14 | lens = do.Lensgroup(device=device) 15 | 16 | R = 12.7 17 | surfaces = [ 18 | do.Aspheric(R, 0.0, c=0.05, device=device), 19 | do.Aspheric(R, 6.5, c=0., device=device) 20 | ] 21 | materials = [ 22 | do.Material('air'), 23 | do.Material('N-BK7'), 24 | do.Material('air') 25 | ] 26 | lens.load(surfaces, materials) 27 | lens.d_sensor = 25.0 28 | lens.r_last = 12.7 29 | lens.film_size = [256, 256] 30 | lens.pixel_size = 100.0e-3/2 31 | lens.surfaces[1].ai = torch.zeros(2, device=device) 32 | return lens 33 | 34 | def baseline(network_func, rays): 35 | lens = init() 36 | lens.surfaces[0].c.requires_grad = True 37 | lens.surfaces[1].ai.requires_grad = True 38 | 39 | I = 0.0 40 | for ray in rays: 41 | I = I + lens.render(ray) 42 | 43 | L = network_func(I) 44 | L.backward() 45 | print("Baseline:") 46 | print("primal: {}".format(L)) 47 | print("derivatives: {}".format([lens.surfaces[0].c.grad, lens.surfaces[1].ai.grad])) 48 | return float(torch.cuda.memory_allocated() / (1024 * 1024)) 49 | 50 | def ours_new(network_func, rays): 51 | lens = init() 52 | adj = do.Adjoint( 53 | lens, ['surfaces[0].c', 'surfaces[1].ai'], 54 | network_func, lens.render, rays 55 | ) 56 | 57 | L_item, grads = adj() 58 | 59 | print("Ours:") 60 | print("primal: {}".format(L_item)) 61 | print("derivatives: {}".format(grads)) 62 | torch.cuda.empty_cache() 63 | return float(torch.cuda.memory_allocated() / (1024 * 1024)) 64 | 65 | # Initialize a lens 66 | lens = init() 67 | 68 | # generate array of rays 69 | wavelength = torch.Tensor([532.8]).to(device) # [nm] 70 | 
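# For reference, a generic sketch of what adjoint back-propagation computes.
# This is a simplified illustration only (the names below are made up); the
# `do.Adjoint` class used in ours_new() is the actual implementation:
#   1) render the image with no AD graph,
#   2) back-propagate the network loss to the image only, and
#   3) pull that image gradient back through the renderer with a single
#      vector-Jacobian product, so the full render graph is never stored.
def adjoint_backprop_sketch(render_fn, loss_fn, theta):
    with torch.no_grad():
        I = render_fn(theta)                    # forward rendering, graph-free
    I = I.detach().requires_grad_(True)
    loss_fn(I).backward()                       # back-end: network loss -> dL/dI
    _, dL_dtheta = torch.autograd.functional.vjp(render_fn, theta, I.grad)
    return dL_dtheta                            # front-end: gradient w.r.t. optics parameters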
71 | def prepare_rays(view): 72 | ray = lens.sample_ray(wavelength.item(), view=view, M=2000+1, sampling='grid') 73 | return ray 74 | 75 | # define a network 76 | torch.manual_seed(0) 77 | I_ref = torch.rand(lens.film_size, device=device) 78 | def network_func(I): 79 | return ((I - I_ref)**2).mean() 80 | 81 | # timings 82 | start = torch.cuda.Event(enable_timing=True) 83 | end = torch.cuda.Event(enable_timing=True) 84 | 85 | # compares 86 | views = [1, 3, 5, 7, 9, 11, 13, 15] 87 | max_views = len(views) 88 | 89 | num_rayss = np.zeros(max_views) 90 | time = np.zeros((max_views, 2)) 91 | memory = np.zeros((max_views, 2)) 92 | for i, num_views in enumerate(views): 93 | print("view = {}".format(num_views)) 94 | views = np.linspace(0, 1, num_views) 95 | num_rays = num_views * 2001**2 / 1e6 96 | num_rayss[i] = num_rays 97 | 98 | # prepare rays 99 | rays = [prepare_rays(view) for view in views] 100 | 101 | # Baseline 102 | try: 103 | start.record() 104 | memory[i,0] = baseline(network_func, rays) 105 | end.record() 106 | torch.cuda.synchronize() 107 | print("Baseline time: {:.3f} s".format(start.elapsed_time(end)*1e-3)) 108 | time[i,0] = start.elapsed_time(end) 109 | except: 110 | print('Baseline: Memory insuffient! Stop running for this case!') 111 | time[i,0] = np.nan 112 | memory[i,0] = np.nan 113 | 114 | # Ours 115 | start.record() 116 | memory[i,1] = ours_new(network_func, rays) 117 | end.record() 118 | torch.cuda.synchronize() 119 | print("Ours (adjoint-based) time: {:.3f} s".format(start.elapsed_time(end)*1e-3)) 120 | time[i,1] = start.elapsed_time(end) 121 | 122 | 123 | # show results 124 | fig = plt.figure() 125 | plt.plot(num_rayss, time, '-o') 126 | plt.title("Time Comparison") 127 | plt.xlabel("Number of rays [Millions]") 128 | plt.ylabel("Computation time [Seconds]") 129 | plt.legend(["Baseline (backpropagation)", "Ours (adjoint-based)"]) 130 | plt.show() 131 | -------------------------------------------------------------------------------- /examples/caustic_pyramid.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import matplotlib.pyplot as plt 5 | from matplotlib.image import imread, imsave 6 | from skimage.transform import resize 7 | 8 | import sys 9 | sys.path.append("../") 10 | import diffoptics as do 11 | 12 | # initialize a lens 13 | device = do.init() 14 | # device = torch.device('cpu') 15 | lens = do.Lensgroup(device=device) 16 | 17 | # construct freeform optics 18 | R = 25.4 19 | ns = [256, 256] 20 | surfaces = [ 21 | do.Aspheric(R, 0.0, c=0., is_square=True, device=device), 22 | do.Mesh(R, 1.0, ns, is_square=True, device=device) 23 | ] 24 | materials = [ 25 | do.Material('air'), 26 | do.Material('N-BK7'), 27 | do.Material('air') 28 | ] 29 | lens.load(surfaces, materials) 30 | 31 | # set scene geometry 32 | D = torch.Tensor([50.0]).to(device) # [mm] 33 | wavelength = torch.Tensor([532.8]).to(device) # [nm] 34 | 35 | # example image 36 | filename = 'einstein' 37 | img_org = imread('./images/' + filename + '.jpg') # assume image is grayscale 38 | if img_org.mean() > 1.0: 39 | img_org = img_org / 255.0 40 | 41 | # downsample the image 42 | NN = 2 43 | img_org = img_org[::NN,::NN] 44 | N_max = 128 45 | img_org = img_org[:N_max,:N_max] 46 | 47 | # mark differentiable variables 48 | lens.surfaces[1].c.requires_grad = True 49 | 50 | # create save dir 51 | savepath = './einstein_pyramid/' 52 | if not os.path.exists(savepath): 53 | os.mkdir(savepath) 54 | 55 | def caustic(N, pyramid_i, 
lr=1e-3, maxit=100): 56 | img = resize(img_org, (N, N)) 57 | E = np.sum(img) # total energy 58 | print(f'image size = {img.shape}') 59 | 60 | N_pad = 0 61 | N_total = N + 2*N_pad 62 | img = np.pad(img, (N_pad,N_pad), 'constant', constant_values=np.inf) 63 | img[np.isinf(img)] = 0.0 # revert img back for visualization 64 | I_ref = torch.Tensor(img).to(device) # [mask] 65 | 66 | # max square length 67 | R_square = R * N_total/N 68 | 69 | # set image plane pixel grid 70 | R_image = R_square 71 | pixel_size = 2*R_image / N_total # [mm] 72 | 73 | def sample_ray(M=1, random=False): 74 | M = int(M*N) 75 | x, y = torch.meshgrid( 76 | torch.linspace(-R_square, R_square, M, device=device), 77 | torch.linspace(-R_square, R_square, M, device=device) 78 | ) 79 | p = 2*R_square / M 80 | if random: 81 | x = x + p * (torch.rand(M,M,device=device)-0.5) 82 | y = y + p * (torch.rand(M,M,device=device)-0.5) 83 | o = torch.stack((x,y,torch.zeros_like(x, device=device)), axis=2) 84 | d = torch.zeros_like(o) 85 | d[...,2] = torch.ones_like(x) 86 | return do.Ray(o, d, wavelength, device=device), E 87 | 88 | def render_single(I, ray_init, irr): 89 | ray, valid = lens.trace(ray_init)[:2] 90 | J = irr * valid * ray.d[...,2] 91 | p = ray(D) 92 | p = p[...,:2] 93 | del ray, valid 94 | 95 | # compute shifts and do linear interpolation 96 | uv = (p + R_square) / pixel_size 97 | index_l = torch.clamp(torch.floor(uv).long(), min=0, max=N_total-1) 98 | index_r = torch.clamp(index_l + 1, min=0, max=N_total-1) 99 | w_r = torch.clamp(uv - index_l, min=0, max=1) 100 | w_l = 1.0 - w_r 101 | del uv 102 | 103 | # compute image 104 | I = torch.index_put(I, (index_l[...,0],index_l[...,1]), w_l[...,0]*w_l[...,1]*J, accumulate=True) 105 | I = torch.index_put(I, (index_r[...,0],index_l[...,1]), w_r[...,0]*w_l[...,1]*J, accumulate=True) 106 | I = torch.index_put(I, (index_l[...,0],index_r[...,1]), w_l[...,0]*w_r[...,1]*J, accumulate=True) 107 | I = torch.index_put(I, (index_r[...,0],index_r[...,1]), w_r[...,0]*w_r[...,1]*J, accumulate=True) 108 | return I 109 | 110 | def render(spp=1): 111 | I = torch.zeros((N_total,N_total), device=device) 112 | ray_init, irr = sample_ray(M=24, random=True) # Reduce M if your GPU memory is low 113 | I = render_single(I, ray_init, irr) 114 | return I / spp 115 | 116 | # optimize 117 | ls = [] 118 | 119 | save_path = savepath + "/{}".format("pyramid_" + str(pyramid_i)) 120 | if not os.path.exists(save_path): 121 | os.makedirs(save_path) 122 | 123 | print('optimizing ...') 124 | optimizer = torch.optim.Adam([lens.surfaces[1].c], lr=lr, betas=(0.99,0.99), amsgrad=True) 125 | 126 | for it in range(maxit+1): 127 | I = render(spp=8) 128 | I = I / I.sum() * I_ref.sum() 129 | L = torch.mean((I - I_ref)**2) 130 | optimizer.zero_grad() 131 | L.backward(retain_graph=True) 132 | 133 | # record 134 | ls.append(L.cpu().detach().numpy()) 135 | if it % 10 == 0: 136 | print('iter = {}: loss = {:.4e}, grad_bar = {:.4e}'.format( 137 | it, L.item(), torch.sum(torch.abs(lens.surfaces[1].c.grad)) 138 | )) 139 | I_current = I.cpu().detach().numpy() 140 | imsave("{}/{:04d}.png".format(save_path, it), I_current, vmin=0.0, vmax=1.0, cmap='gray') 141 | 142 | # descent 143 | optimizer.step() 144 | 145 | if pyramid_i == 0: # last one, render final image 146 | lens.surfaces[1].c.requires_grad = False 147 | del L 148 | I_final = 0 149 | spp = 100 150 | for i in range(spp): 151 | if i % 10 == 0: 152 | print("=== rendering spp = {}".format(i)) 153 | I_final += render().cpu().detach().numpy() 154 | return I_final / spp, I_ref, ls 155 | 
else: 156 | return I.cpu().detach().numpy(), None, ls 157 | 158 | pyramid_levels = 2 159 | for i in range(pyramid_levels, -1, -1): 160 | N = int(N_max/(2**i)) 161 | print("=== N = {}".format(N)) 162 | I_final, I_ref, ls = caustic(N, i, lr=1e-3, maxit=int(1000/4**i)) 163 | if i == 0: 164 | I_ref = I_ref.cpu().numpy() 165 | I_final = I_final / I_final.sum() * I_ref.sum() 166 | 167 | imsave(savepath + "/I_target.png", I_ref, vmin=0.0, vmax=1.0, cmap='gray') 168 | imsave(savepath + "/I_final.png", I_final, vmin=0.0, vmax=1.0, cmap='gray') 169 | 170 | # final results 171 | plt.imshow(I_final, cmap='gray') 172 | plt.title('Final caustic image') 173 | plt.show() 174 | 175 | fig, ax = plt.subplots() 176 | ax.plot(ls, 'k-o', linewidth=2) 177 | ax.set_xlabel('iteration') 178 | ax.set_ylabel('loss') 179 | fig.savefig("ls.pdf", bbox_inches='tight') 180 | plt.title('Loss') 181 | 182 | S = lens.surfaces[1].mesh().cpu().detach().numpy() 183 | S = S - S.min() 184 | imsave(savepath + "/phase.png", S, vmin=0, vmax=S.max(), cmap='coolwarm') 185 | imsave(savepath + "/phase_mod.png", np.mod(S*1e3,100), cmap='coolwarm') 186 | print(S.max()) 187 | 188 | plt.figure() 189 | plt.imshow(S, cmap='jet') 190 | plt.colorbar() 191 | plt.title('Optimized phase plate height [mm]') 192 | plt.show() 193 | 194 | -------------------------------------------------------------------------------- /examples/data/20210304/ref2.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/data/20210304/ref2.tif -------------------------------------------------------------------------------- /examples/end2end_edof_backward_tracing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import matplotlib.pyplot as plt 4 | from pathlib import Path 5 | from tqdm import tqdm 6 | from datetime import datetime 7 | 8 | import sys 9 | sys.path.append("../") 10 | import diffoptics as do 11 | from utils_end2end import dict_to_tensor, tensor_to_dict, load_deblurganv2, ImageFolder 12 | torch.manual_seed(0) 13 | 14 | # Initialize a lens 15 | device = torch.device('cuda') 16 | lens = do.Lensgroup(device=device) 17 | 18 | # Load optics 19 | lens.load_file(Path('./lenses/end2end/end2end_edof.txt')) # norminal design 20 | lens.plot_setup2D() 21 | [surface.to(device) for surface in lens.surfaces] 22 | 23 | # set sensor pixel size and film size 24 | downsample_factor = 4 # downsampled for run 25 | pixel_size = downsample_factor * 3.45e-3 # [mm] 26 | film_size = [512 // downsample_factor, 512 // downsample_factor] 27 | lens.prepare_mts(pixel_size, film_size) 28 | print('Check your lens:') 29 | print(lens) 30 | 31 | # sample wavelengths in [nm] 32 | wavelengths = [656.2725, 587.5618, 486.1327] 33 | 34 | def create_screen(texture: torch.Tensor, z: float, pixelsize: float) -> do.Screen: 35 | texturesize = np.array(texture.shape[0:2]) 36 | screen = do.Screen( 37 | do.Transformation(np.eye(3), np.array([0, 0, z])), 38 | texturesize * pixelsize, texture, device=device 39 | ) 40 | return screen 41 | 42 | def render_single(wavelength: float, screen: do.Screen, sample_ray_function, images: list[torch.Tensor]): 43 | valid, ray_new = sample_ray_function(wavelength) 44 | uv, valid_screen = screen.intersect(ray_new)[1:] 45 | mask = valid & valid_screen 46 | 47 | # Render a batch of images 48 | I_batch = [] 49 | for image in images: 50 | 
screen.update_texture(image[..., wavelengths.index(wavelength)]) 51 | I_batch.append(screen.shading(uv, mask)) 52 | return torch.stack(I_batch, axis=0), mask 53 | 54 | def render(screen: do.Screen, images: list[torch.Tensor], ray_counts_per_pixel: int) -> torch.Tensor: 55 | Is = [] 56 | for wavelength in wavelengths: 57 | I = 0 58 | M = 0 59 | for i in range(ray_counts_per_pixel): 60 | I_current, mask = render_single(wavelength, screen, lambda x : lens.sample_ray_sensor(x), images) 61 | I = I + I_current 62 | M = M + mask 63 | I = I / (M[None, ...] + 1e-10) 64 | I = I.reshape((len(images), *np.flip(np.asarray(film_size)))).permute(0,2,1) 65 | Is.append(I) 66 | return torch.stack(Is, axis=-1) 67 | 68 | focal_length = 102 # [mm] 69 | def render_gt(screen: do.Screen, images: list[torch.Tensor]) -> torch.Tensor: 70 | Is = [] 71 | for wavelength in wavelengths: 72 | I, mask = render_single(wavelength, screen, lambda x : lens.sample_ray_sensor_pinhole(x, focal_length), images) 73 | I = I.reshape((len(images), *np.flip(np.asarray(film_size)))).permute(0,2,1) 74 | Is.append(I) 75 | return torch.stack(Is, axis=-1) 76 | 77 | 78 | ## Set differentiable optical parameters 79 | # XY_surface = ( 80 | # a[0] + 81 | # a[1] * x + a[2] * y + 82 | # a[3] * x**2 + a[4] * x*y + a[5] * y**2 + 83 | # a[6] * x**3 + a[7] * x**2*y + a[8] * x*y**2 + a[9] * y**3 84 | # ) 85 | # We optimize for a cubic profile (o.e. 3rd-order coefficients), as in the wavefront coding technology. 86 | diff_parameters = [ 87 | lens.surfaces[0].ai 88 | ] 89 | learning_rates = { 90 | 'surfaces[0].ai': 1e-15 * torch.Tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1]).to(device) 91 | } 92 | for diff_para, key in zip(diff_parameters, learning_rates.keys()): 93 | if len(diff_para) != len(learning_rates[key]): 94 | raise Exception('Learning rates of {} is not of equal length to the parameters!'.format(key)) 95 | diff_para.requires_grad = True 96 | diff_parameter_labels = learning_rates.keys() 97 | 98 | 99 | ## Create network 100 | net = load_deblurganv2() 101 | net.prepare() 102 | 103 | print('Initial:') 104 | current_parameters = [x.detach().cpu().numpy() for x in diff_parameters] 105 | print('Current optical parameters are:') 106 | for x, label in zip(current_parameters, diff_parameter_labels): 107 | print('-- lens.{}: {}'.format(label, x)) 108 | 109 | # Training dataset 110 | train_path = './training_dataset/' 111 | train_dataloader = torch.utils.data.DataLoader(ImageFolder(train_path), batch_size=1, shuffle=False) 112 | it = iter(train_dataloader) 113 | image = next(it).squeeze().to(device) 114 | 115 | # Training settings 116 | settings = { 117 | 'spp_forward': 100, # Rays per pixel for forward 118 | 'spp_backward': 20, # Rays per pixel for a single-pass backward 119 | 'num_passes': 5, # Number of accumulation passes for the backward 120 | 'image_batch_size': 5, # Images per batch 121 | 'network_training_iter': 200, # Training iterations for network update 122 | 'num_of_training': 10, # Training outer loop iteration 123 | 'savefig': True # Save intermediate results 124 | } 125 | 126 | if settings['savefig']: 127 | opath = Path('end2end_output') / str(datetime.now().strftime("%Y-%m-%d-%H-%M-%S")) 128 | opath.mkdir(parents=True, exist_ok=True) 129 | 130 | def wrapper_func(screen, images, squeezed_diff_parameters, diff_parameters, diff_parameter_labels): 131 | unpacked_diff_parameters = tensor_to_dict(squeezed_diff_parameters, diff_parameters) 132 | for idx, label in enumerate(diff_parameter_labels): 133 | exec('lens.{} = 
unpacked_diff_parameters[{}]'.format(label, idx)) 134 | return render(screen, images, settings['spp_forward']) 135 | 136 | 137 | # Physical parameters for the screen 138 | zs = [8e3, 6e3, 4.5e3] # [mm] 139 | pixelsizes = [0.1 * z/6e3 for z in zs] # [mm] 140 | 141 | print('Training starts ...') 142 | for iteration in range(settings['num_of_training']): 143 | for z_idx, z in enumerate(zs): 144 | 145 | # Print current status 146 | current_parameters = [x.detach().cpu().numpy() for x in diff_parameters] 147 | print('=========') 148 | print('Iteration = {}, z = {} [mm]:'.format(iteration, z)) 149 | print('Current optical parameters are:') 150 | for x, label in zip(current_parameters, diff_parameter_labels): 151 | print('-- lens.{}: {}'.format(label, x)) 152 | print('=========') 153 | 154 | # Put screen at a desired distance (and with a proper pixel size) 155 | screen = create_screen(image, z, pixelsizes[z_idx]) 156 | 157 | # Forward rendering 158 | tq = tqdm(range(settings['image_batch_size'])) 159 | tq.set_description('(1) Rendering batch images') 160 | 161 | # Load image batch (multiple images) 162 | images = [] 163 | for image_idx in tq: 164 | try: 165 | data = next(it) 166 | except StopIteration: 167 | it = iter(train_dataloader) 168 | data = next(it) 169 | image = data.squeeze().to(device) 170 | images.append(image.clone()) 171 | 172 | with torch.no_grad(): 173 | Is = render(screen, images, settings['spp_forward']) 174 | Is_gt = render_gt(screen, images) 175 | tq.close() 176 | 177 | # Save images for visualization 178 | Is_view = np.concatenate([I.cpu().numpy().astype(np.uint8) for I in Is], axis=1) 179 | Is_gt_view = np.concatenate([I.cpu().numpy().astype(np.uint8) for I in Is_gt], axis=1) 180 | 181 | # Reorder tensors to match neural network input format 182 | Is = 2 * torch.permute(Is, (0, 3, 1, 2)) / 255 - 1 183 | Is_gt = 2 * torch.permute(Is_gt, (0, 3, 1, 2)) / 255 - 1 184 | 185 | # Train network weights 186 | Is_output = net.run( 187 | Is, Is_gt, is_inference=False, 188 | num_iters=settings['network_training_iter'], desc='(2) Training network weights' 189 | ) 190 | Is_output_np = np.transpose(255/2 * (Is_output.detach().cpu().numpy() + 1), (0,2,3,1)).astype(np.uint8) 191 | Is_output_view = np.concatenate([I for I in Is_output_np], axis=1) 192 | del Is_output_np 193 | 194 | if settings['savefig']: 195 | fig, axs = plt.subplots(3, 1) 196 | for idx, I_view, label in zip( 197 | range(3), [Is_view, Is_gt_view, Is_output_view], ['Input', 'Ground truth', 'Network output'] 198 | ): 199 | axs[idx].imshow(I_view) 200 | axs[idx].set_title(label + ' image(s)') 201 | axs[idx].set_axis_off() 202 | fig.tight_layout() 203 | fig.savefig( 204 | str(opath / 'iter_{}_z={}mm_images.png'.format(iteration, z)), 205 | dpi=400, bbox_inches='tight', pad_inches=0.1 206 | ) 207 | fig.clear() 208 | plt.close(fig) 209 | 210 | # Back-propagate backend loss and obtain adjoint gradients 211 | Is.requires_grad = True 212 | Is_output = net.run(Is, Is_gt, is_inference=False, num_iters=1) 213 | 214 | # Get adjoint gradients of the image batch 215 | Is_grad = Is.grad.permute(0, 2, 3, 1) 216 | del Is, Is_gt, Is_output 217 | torch.cuda.empty_cache() 218 | 219 | # Back-propagate optical parameters with adjoint gradients, and accumulate 220 | tq = tqdm(range(settings['num_passes'])) 221 | tq.set_description('(3) Back-prop optical parameters') 222 | dthetas = torch.zeros_like(dict_to_tensor(diff_parameters)).detach() 223 | for inner_iteration in tq: 224 | dthetas += torch.autograd.functional.vjp( 225 | lambda x : 
wrapper_func(screen, images, x, diff_parameters, diff_parameter_labels), 226 | dict_to_tensor(diff_parameters), Is_grad 227 | )[1] 228 | tq.close() 229 | 230 | # Update optical parameters 231 | with torch.no_grad(): 232 | for label, diff_para, dtheta in zip( 233 | diff_parameter_labels, diff_parameters, tensor_to_dict(dthetas, diff_parameters) 234 | ): 235 | diff_para -= learning_rates[label] * dtheta.squeeze() / settings['num_passes'] 236 | diff_para.grad = None 237 | -------------------------------------------------------------------------------- /examples/images/einstein.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/images/einstein.jpg -------------------------------------------------------------------------------- /examples/images/point.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/images/point.tif -------------------------------------------------------------------------------- /examples/images/squirrel.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/images/squirrel.jpg -------------------------------------------------------------------------------- /examples/lenses/DoubleGauss/US02532751-1.txt: -------------------------------------------------------------------------------- 1 | Originate from US02532751-1 2 | type distance roc diameter material 3 | O 0.0 0.0 0.0 VACUUM 4 | S 0.0 13.354729316461547 17.4 SSK4 0.005 1e-6 1e-8 -3e-10 5 | S 2.2352 35.64148197667863 17.4 VACUUM 6 | S 0.0762 10.330017837998932 14.0 SK1 7 | S 3.1750 0.0 14.0 F15 8 | S 0.9652 6.494496063151893 9.0 VACUUM 9 | A 3.8608 0.0 4.886 OCCLUDER 10 | S 3.302 -7.026950339915501 9.0 F15 11 | S 0.9652 0.0 12.0 SK16 12 | S 2.7686 -9.746574604143909 12.0 VACUUM 13 | S 0.0762 69.81692521236866 14.0 SK16 14 | S 1.7526 -19.226275376106166 14.0 VACUUM 15 | I 17.42769142705 0.0 20.664 VACUUM 16 | -------------------------------------------------------------------------------- /examples/lenses/DoubleGauss/US02532751-1.zmx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/lenses/DoubleGauss/US02532751-1.zmx -------------------------------------------------------------------------------- /examples/lenses/README.md: -------------------------------------------------------------------------------- 1 | ## Materials 2 | 3 | We modeled all materials with Cauchy's coefficients. Predefined materials could be found in [`../../diffoptics/basics.py`](https://github.com/vccimaging/DiffOptics/blob/7c30e22967280c97fb2953db0bcf894611df37b7/diffoptics/basics.py#L188-L234). 4 | 5 | Custom materials could be specified as `
/ `, for example: 6 | 7 | ``` 8 | 1.66565/35.64 9 | ``` 10 | 11 | ## Format 12 | 13 | All surfaces follow the sequence of entries: 14 | 15 | ``` 16 | type distance roc diameter material 17 | ``` 18 | 19 | For each specific surface, the parameters are pended after the last column: 20 | 21 | ``` 22 | Thorlabs-AC508-1000-A 23 | type distance roc diameter material 24 | O 0 0 100 AIR 25 | S 0 757.9 50.8 N-BK7 26 | S 6.0 -364.7 50.8 N-SF2 27 | S 6.0 -954.2 50.8 AIR 28 | I 996.4 0 50.8 AIR 29 | ``` 30 | where the symbol means are: 31 | 32 | | Symbol | Surface type | Note | 33 | | :----: | :----------: | :------------: | 34 | | `O` | Asphere | Object plane | 35 | | `S` | Asphere | Lens surface | 36 | | `A` | Asphere | Aperture plane | 37 | | `I` | Asphere | Image plane | 38 | 39 | -------------------------------------------------------------------------------- /examples/lenses/ThorLabs/ACL5040U.txt: -------------------------------------------------------------------------------- 1 | Thorlabs-ACL5040U 2 | type distance roc diameter material 3 | O 20 0 100 AIR 4 | S 0 20.923 50.0 B270 -0.6405 2.0e-06 5 | S 21.0 0 50.0 AIR 6 | I 26.0 0 50.0 AIR 7 | -------------------------------------------------------------------------------- /examples/lenses/ThorLabs/LA1131.txt: -------------------------------------------------------------------------------- 1 | Thorlabs-LA1131 2 | type distance roc diameter material 3 | O 20 0 100 AIR 4 | S 0 25.75 25.4 N-BK7 5 | S 5.34 0 25.4 AIR 6 | I 4.42 0 25.4 AIR -------------------------------------------------------------------------------- /examples/lenses/Zemax_samples/Nikon-z35-f1.8-JPA2019090949-example2.txt: -------------------------------------------------------------------------------- 1 | Originate from Zemax, modified by Congli Wang 2 | type distance roc diameter material 3 | O 0 0 0 AIR 4 | S 0 5.267 1.694 1.5168/64.12 5 | S 0.102 0.961 1.392 AIR 6 | S 0.309 1.442 1.322 1.9027/35.72 7 | S 0.246 10.280 1.250 1.5955/39.21 8 | S 0.083 1.215 1.092 AIR 9 | S 0.411 -1.099 1.048 1.6990/30.05 10 | S 0.088 2.918 1.172 1.9108/35.25 11 | S 0.258 -1.669 1.202 AIR 12 | S 0.009 1.643 1.248 1.5928/68.62 13 | S 0.379 -1.412 1.226 1.7205/34.70 14 | S 0.069 -2.572 1.214 AIR 15 | A 0.118 0 1.110 OCCLUDER 16 | S 0.604 -0.973 0.952 1.5927/35.31 17 | S 0.051 -24.080 0.980 AIR 18 | S 0.009 2.376 1.086 1.5928/68.62 19 | S 0.282 -1.306 1.138 AIR 20 | S 0.239 -7.317 1.208 1.6935/53.20 0 -0.240 -0.4268 21 | S 0.122 -2.200 1.254 AIR 0 -0.05053 -0.3491 0.1459 0.07718 22 | S 0.154 -1.545 1.324 1.4875/70.44 23 | S 0.083 -7.257 1.424 AIR 24 | S 0.750 0 2.400 1.5168/64.12 25 | S 0.074 0 2.400 AIR 26 | I 0.043 0 1.912 AIR 27 | -------------------------------------------------------------------------------- /examples/lenses/end2end/end2end_edof.txt: -------------------------------------------------------------------------------- 1 | End2end EDOF exmaple 2 | type distance roc diameter material 3 | O 0.0 0.0 0.0 AIR 4 | S 0.0 62.8 12.7 n-bk7 5 | S 4.000 -45.7 12.7 sf5 6 | S 2.500 -128.2 12.7 AIR 7 | S 1.500 0.0 12.7 n-bk7 8 | X 2.500 0.0 12.7 AIR 0 0 0 0 0 0 0 0 0 0 0 9 | I 96.000 0.0 1.7664 AIR -------------------------------------------------------------------------------- /examples/misalignment_point.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import matplotlib.pyplot as plt 4 | from pathlib import Path 5 | from matplotlib.image import imread, imsave 6 | import cv2 7 | 8 | import imageio 9 | 10 | import sys 11 | 
sys.path.append("../") 12 | import diffoptics as do 13 | 14 | """ 15 | Experimental parameters: 16 | 17 | light source to sensor: about 675 [mm] 18 | sensor: GS3-U3-50S5M, pixel size 3.45 [um], resolution 2448 × 2048 19 | """ 20 | 21 | # initialize a lens 22 | device = do.init() 23 | # device = torch.device('cpu') 24 | lens = do.Lensgroup(device=device) 25 | 26 | # ==== Load lens file 27 | lens.load_file(Path('./lenses/Thorlabs/LA1131.txt')) 28 | lens.d_sensor = torch.Tensor([56.0]).to(device) # [mm] sensor distance 29 | lens.plot_setup2D(with_sensor=True) 30 | R = lens.surfaces[0].r 31 | 32 | # sensor information 33 | downsample_N = 4 34 | pixel_size = 3.45e-3 * downsample_N # [mm] 35 | N_total = int(2048 / downsample_N) 36 | R_sensor = N_total * pixel_size / 2 # [mm] 37 | 38 | # set scene geometry 39 | wavelength = torch.Tensor([622.5]).to(device) # [nm] 40 | 41 | # point light source position 42 | light_o = torch.Tensor([0.0, 0.0, -650]).to(device) 43 | lens.light_o = light_o # hook-up 44 | 45 | R_in = 1.42*R # must be >= sqrt(2) 46 | M = 1024 47 | def sample_ray(M, light_o): 48 | o_x, o_y = torch.meshgrid( 49 | torch.linspace(-R_in, R_in, M, device=device), 50 | torch.linspace(-R_in, R_in, M, device=device) 51 | ) 52 | valid = (o_x**2 + o_y**2) < (0.95*R)**2 53 | 54 | o = torch.stack((o_x, o_y, -torch.ones_like(o_x)), axis=-1) 55 | d = torch.stack((o_x, o_y, torch.zeros_like(o_x)), axis=-1) - light_o[None, None, ...] 56 | d = do.normalize(d) 57 | 58 | o = o[valid] 59 | d = d[valid] 60 | 61 | return do.Ray(o, d, wavelength, device=device) 62 | 63 | lens.pixel_size = pixel_size 64 | lens.film_size = [N_total,N_total] 65 | def render(): 66 | ray = sample_ray(M, lens.light_o) 67 | I = lens.render(ray) 68 | I = N_total**2 * I / I.sum() 69 | return I 70 | 71 | 72 | # centroid 73 | X, Y = torch.meshgrid( 74 | 1 + torch.arange(N_total, device=device), 75 | 1 + torch.arange(N_total, device=device) 76 | ) 77 | def centroid(I): 78 | return torch.stack(( 79 | torch.sum(X * I) / torch.sum(I), 80 | torch.sum(Y * I) / torch.sum(I) 81 | )) 82 | 83 | ### Optimization utilities 84 | def loss(I, I_mea): 85 | data_term = torch.mean((I - I_mea)**2) 86 | comp_centroid = True 87 | if comp_centroid: 88 | c_mea = centroid(I_mea) 89 | c = centroid(I) 90 | loss = data_term + 0.0005 * torch.mean((c - c_mea)**2) 91 | else: 92 | loss = data_term 93 | return loss 94 | 95 | 96 | # read image 97 | img = imread('./data/20210304/ref2.tif') # for now we use grayscale 98 | img = img.astype(float) 99 | I_mea = cv2.resize(img, dsize=(N_total, N_total), interpolation=cv2.INTER_AREA) 100 | I_mea = np.maximum(0.0, I_mea - np.median(I_mea)) 101 | I_mea = N_total**2 * I_mea / I_mea.sum() 102 | I_mea = torch.Tensor(I_mea).to(device) 103 | 104 | # AUTO DIFF 105 | diff_variables = ['d_sensor', 'theta_x', 'theta_y', 'light_o'] 106 | out = do.LM(lens, diff_variables, 1e-3, option='diag') \ 107 | .optimize(render, lambda y: I_mea - y, maxit=30, record=True) 108 | 109 | 110 | # crop images 111 | def crop(I): 112 | c = 200 113 | return I[c:I.shape[0]-c, c:I.shape[1]-c] 114 | 115 | opath = Path('misalignment_point') 116 | opath.mkdir(parents=True, exist_ok=True) 117 | def save(I_mea, Is): 118 | images = [] 119 | for I in Is: 120 | images.append(crop(I)) 121 | imageio.mimsave(str(opath / 'movie.mp4'), images) 122 | 123 | # show results 124 | plt.figure() 125 | plt.imshow(crop(Is[0]), cmap='gray') 126 | plt.title('Simulation (initial)') 127 | 128 | plt.figure() 129 | plt.imshow(crop(Is[-1]), cmap='gray') 130 | plt.title('Simulation 
(optimized)') 131 | 132 | I_mea = I_mea.cpu().detach().numpy() 133 | 134 | plt.figure() 135 | plt.imshow(crop(I_mea), cmap='gray') 136 | plt.title('Measurement') 137 | plt.show() 138 | 139 | plt.imsave(str(opath / 'I0.jpg'), crop(Is[0]), vmin=0, vmax=np.maximum(I_mea.max(), Is[-1].max()), cmap='gray') 140 | plt.imsave(str(opath / 'I.jpg'), crop(Is[-1]), vmin=0, vmax=np.maximum(I_mea.max(), Is[-1].max()), cmap='gray') 141 | plt.imsave(str(opath / 'I_mea.jpg'), crop(I_mea), vmin=0, vmax=np.maximum(I_mea.max(), Is[-1].max()), cmap='gray') 142 | 143 | 144 | save(I_mea, out['Is']) 145 | 146 | fig = plt.figure() 147 | plt.plot(out['ls'], 'k-o') 148 | plt.xlabel('iteration') 149 | plt.ylabel('loss') 150 | fig.savefig(str(opath / "ls.pdf"), bbox_inches='tight') 151 | -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *.pyc 3 | #fpn_inception.h5 4 | #fpn_mobilenet.h5 5 | *.h5 -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019, Orest Kupyn, Tetiana Martyniuk, Junru Wu and Zhangyang Wang 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | 25 | 26 | --------------------------- LICENSE FOR DeblurGANv2 -------------------------------- 27 | BSD License 28 | 29 | For DeblurGANv2 software 30 | Copyright (c) 2019, Orest Kupyn, Tetiana Martyniuk, Junru Wu and Zhangyang Wang 31 | All rights reserved. 32 | 33 | Redistribution and use in source and binary forms, with or without 34 | modification, are permitted provided that the following conditions are met: 35 | 36 | * Redistributions of source code must retain the above copyright notice, this 37 | list of conditions and the following disclaimer. 38 | 39 | * Redistributions in binary form must reproduce the above copyright notice, 40 | this list of conditions and the following disclaimer in the documentation 41 | and/or other materials provided with the distribution. 
42 | 43 | ----------------------------- LICENSE FOR DCGAN -------------------------------- 44 | BSD License 45 | 46 | For dcgan.torch software 47 | 48 | Copyright (c) 2015, Facebook, Inc. All rights reserved. 49 | 50 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 51 | 52 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 53 | 54 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 55 | 56 | Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 57 | 58 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 59 | -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/README.md: -------------------------------------------------------------------------------- 1 | # DeblurGAN-v2: Deblurring (Orders-of-Magnitude) Faster and Better 2 | 3 | Code for this paper [DeblurGAN-v2: Deblurring (Orders-of-Magnitude) Faster and Better](https://arxiv.org/abs/1908.03826) 4 | 5 | Orest Kupyn, Tetiana Martyniuk, Junru Wu, Zhangyang Wang 6 | 7 | In ICCV 2019 8 | 9 | ## Overview 10 | 11 | We present a new end-to-end generative adversarial network (GAN) for single image motion deblurring, named 12 | DeblurGAN-v2, which considerably boosts state-of-the-art deblurring efficiency, quality, and flexibility. DeblurGAN-v2 13 | is based on a relativistic conditional GAN with a double-scale discriminator. For the first time, we introduce the 14 | Feature Pyramid Network into deblurring, as a core building block in the generator of DeblurGAN-v2. It can flexibly 15 | work with a wide range of backbones, to navigate the balance between performance and efficiency. The plug-in of 16 | sophisticated backbones (e.g., Inception-ResNet-v2) can lead to solid state-of-the-art deblurring. Meanwhile, 17 | with light-weight backbones (e.g., MobileNet and its variants), DeblurGAN-v2 reaches 10-100 times faster than 18 | the nearest competitors, while maintaining close to state-of-the-art results, implying the option of real-time 19 | video deblurring. We demonstrate that DeblurGAN-v2 obtains very competitive performance on several popular 20 | benchmarks, in terms of deblurring quality (both objective and subjective), as well as efficiency. Besides, 21 | we show the architecture to be effective for general image restoration tasks too. 
22 | 23 | 27 | 28 | ![](./doc_images/kohler_visual.png) 29 | ![](./doc_images/restore_visual.png) 30 | ![](./doc_images/gopro_table.png) 31 | ![](./doc_images/lai_table.png) 32 | 33 | 34 | 35 | ## DeblurGAN-v2 Architecture 36 | 37 | ![](./doc_images/pipeline.jpg) 38 | 39 | 45 | 46 | 49 | 50 | ## Datasets 51 | 52 | The datasets for training can be downloaded via the links below: 53 | - [DVD](https://drive.google.com/file/d/1bpj9pCcZR_6-AHb5aNnev5lILQbH8GMZ/view) 54 | - [GoPro](https://drive.google.com/file/d/1KStHiZn5TNm2mo3OLZLjnRvd0vVFCI0W/view) 55 | - [NFS](https://drive.google.com/file/d/1Ut7qbQOrsTZCUJA_mJLptRMipD8sJzjy/view) 56 | 57 | ## Training 58 | 59 | #### Command 60 | 61 | ```python train.py``` 62 | 63 | training script will load config under config/config.yaml 64 | 65 | #### Tensorboard visualization 66 | 67 | ![](./doc_images/tensorboard2.png) 68 | 69 | ## Testing 70 | 71 | To test on a single image, 72 | 73 | ```python predict.py IMAGE_NAME.jpg``` 74 | 75 | By default, the name of the pretrained model used by Predictor is 'best_fpn.h5'. One can change it in the code ('weights_path' argument). It assumes that the fpn_inception backbone is used. If you want to try it with different backbone pretrain, please specify it also under ['model']['g_name'] in config/config.yaml. 76 | 77 | ## Pre-trained models 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 |
| Dataset | G Model | D Model | Loss Type | PSNR / SSIM | Link |
| :----------------: | :----------------: | :--------: | :------: | :-----------: | :--------------: |
| GoPro Test Dataset | InceptionResNet-v2 | double_gan | ragan-ls | 29.55 / 0.934 | fpn_inception.h5 |
| | MobileNet | double_gan | ragan-ls | 28.17 / 0.925 | fpn_mobilenet.h5 |
| | MobileNet-DSC | double_gan | ragan-ls | 28.03 / 0.922 | |
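
For example, to evaluate with the lighter MobileNet generator from the table, download `fpn_mobilenet.h5`, point the `weights_path` argument in `predict.py` at it, and switch the generator backbone in `config/config.yaml`. The excerpt below is a minimal sketch following the Testing notes above; the exact string `fpn_mobilenet` for `g_name` is assumed to match the backbone naming used by the model registry, and all other keys stay as shipped:

```
model:
  g_name: fpn_mobilenet   # assumed name; must match the backbone of the downloaded .h5 weights
```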
111 | 112 | ## Parent Repository 113 | 114 | The code was taken from https://github.com/KupynOrest/RestoreGAN . This repository contains flexible pipelines for different Image Restoration tasks. 115 | 116 | ## Citation 117 | 118 | If you use this code for your research, please cite our paper. 119 | 120 | ``` 121 | ​``` 122 | @InProceedings{Kupyn_2019_ICCV, 123 | author = {Orest Kupyn and Tetiana Martyniuk and Junru Wu and Zhangyang Wang}, 124 | title = {DeblurGAN-v2: Deblurring (Orders-of-Magnitude) Faster and Better}, 125 | booktitle = {The IEEE International Conference on Computer Vision (ICCV)}, 126 | month = {Oct}, 127 | year = {2019} 128 | } 129 | ​``` 130 | ``` 131 | 132 | -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/adversarial_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | 4 | 5 | class GANFactory: 6 | factories = {} 7 | 8 | def __init__(self): 9 | pass 10 | 11 | def add_factory(gan_id, model_factory): 12 | GANFactory.factories.put[gan_id] = model_factory 13 | 14 | add_factory = staticmethod(add_factory) 15 | 16 | # A Template Method: 17 | 18 | def create_model(gan_id, net_d=None, criterion=None): 19 | if gan_id not in GANFactory.factories: 20 | GANFactory.factories[gan_id] = \ 21 | eval(gan_id + '.Factory()') 22 | return GANFactory.factories[gan_id].create(net_d, criterion) 23 | 24 | create_model = staticmethod(create_model) 25 | 26 | 27 | class GANTrainer(object): 28 | def __init__(self, net_d, criterion): 29 | self.net_d = net_d 30 | self.criterion = criterion 31 | 32 | def loss_d(self, pred, gt): 33 | pass 34 | 35 | def loss_g(self, pred, gt): 36 | pass 37 | 38 | def get_params(self): 39 | pass 40 | 41 | 42 | class NoGAN(GANTrainer): 43 | def __init__(self, net_d, criterion): 44 | GANTrainer.__init__(self, net_d, criterion) 45 | 46 | def loss_d(self, pred, gt): 47 | return [0] 48 | 49 | def loss_g(self, pred, gt): 50 | return 0 51 | 52 | def get_params(self): 53 | return [torch.nn.Parameter(torch.Tensor(1))] 54 | 55 | class Factory: 56 | @staticmethod 57 | def create(net_d, criterion): return NoGAN(net_d, criterion) 58 | 59 | 60 | class SingleGAN(GANTrainer): 61 | def __init__(self, net_d, criterion): 62 | GANTrainer.__init__(self, net_d, criterion) 63 | self.net_d = self.net_d.cuda() 64 | 65 | def loss_d(self, pred, gt): 66 | return self.criterion(self.net_d, pred, gt) 67 | 68 | def loss_g(self, pred, gt): 69 | return self.criterion.get_g_loss(self.net_d, pred, gt) 70 | 71 | def get_params(self): 72 | return self.net_d.parameters() 73 | 74 | class Factory: 75 | @staticmethod 76 | def create(net_d, criterion): return SingleGAN(net_d, criterion) 77 | 78 | 79 | class DoubleGAN(GANTrainer): 80 | def __init__(self, net_d, criterion): 81 | GANTrainer.__init__(self, net_d, criterion) 82 | self.patch_d = net_d['patch'].cuda() 83 | self.full_d = net_d['full'].cuda() 84 | self.full_criterion = copy.deepcopy(criterion) 85 | 86 | def loss_d(self, pred, gt): 87 | return (self.criterion(self.patch_d, pred, gt) + self.full_criterion(self.full_d, pred, gt)) / 2 88 | 89 | def loss_g(self, pred, gt): 90 | return (self.criterion.get_g_loss(self.patch_d, pred, gt) + self.full_criterion.get_g_loss(self.full_d, pred, 91 | gt)) / 2 92 | 93 | def get_params(self): 94 | return list(self.patch_d.parameters()) + list(self.full_d.parameters()) 95 | 96 | class Factory: 97 | @staticmethod 98 | def create(net_d, criterion): return DoubleGAN(net_d, 
criterion) 99 | 100 | -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/aug.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import albumentations as albu 4 | 5 | 6 | def get_transforms(size: int, scope: str = 'geometric', crop='random'): 7 | augs = {'weak': albu.Compose([albu.HorizontalFlip(), 8 | ]), 9 | 'geometric': albu.OneOf([albu.HorizontalFlip(always_apply=True), 10 | albu.ShiftScaleRotate(always_apply=True), 11 | albu.Transpose(always_apply=True), 12 | albu.OpticalDistortion(always_apply=True), 13 | albu.ElasticTransform(always_apply=True), 14 | ]) 15 | } 16 | 17 | aug_fn = augs[scope] 18 | crop_fn = {'random': albu.RandomCrop(size, size, always_apply=True), 19 | 'center': albu.CenterCrop(size, size, always_apply=True)}[crop] 20 | pad = albu.PadIfNeeded(size, size) 21 | 22 | pipeline = albu.Compose([aug_fn, pad, crop_fn], additional_targets={'target': 'image'}) 23 | 24 | def process(a, b): 25 | r = pipeline(image=a, target=b) 26 | return r['image'], r['target'] 27 | 28 | return process 29 | 30 | 31 | def get_normalize(): 32 | normalize = albu.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 33 | normalize = albu.Compose([normalize], additional_targets={'target': 'image'}) 34 | 35 | def process(a, b): 36 | r = normalize(image=a, target=b) 37 | return r['image'], r['target'] 38 | 39 | return process 40 | 41 | 42 | def _resolve_aug_fn(name): 43 | d = { 44 | 'cutout': albu.Cutout, 45 | 'rgb_shift': albu.RGBShift, 46 | 'hsv_shift': albu.HueSaturationValue, 47 | 'motion_blur': albu.MotionBlur, 48 | 'median_blur': albu.MedianBlur, 49 | 'snow': albu.RandomSnow, 50 | 'shadow': albu.RandomShadow, 51 | 'fog': albu.RandomFog, 52 | 'brightness_contrast': albu.RandomBrightnessContrast, 53 | 'gamma': albu.RandomGamma, 54 | 'sun_flare': albu.RandomSunFlare, 55 | 'sharpen': albu.Sharpen, 56 | 'jpeg': albu.ImageCompression, 57 | 'gray': albu.ToGray, 58 | 'pixelize': albu.Downscale, 59 | # ToDo: partial gray 60 | } 61 | return d[name] 62 | 63 | 64 | def get_corrupt_function(config: List[dict]): 65 | augs = [] 66 | for aug_params in config: 67 | name = aug_params.pop('name') 68 | cls = _resolve_aug_fn(name) 69 | prob = aug_params.pop('prob') if 'prob' in aug_params else .5 70 | augs.append(cls(p=prob, **aug_params)) 71 | 72 | augs = albu.OneOf(augs) 73 | 74 | def process(x): 75 | return augs(image=x)['image'] 76 | 77 | return process 78 | -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/config/config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | project: deblur_gan 3 | experiment_desc: fpn 4 | 5 | train: 6 | files_a: &FILES_A ./training_dataset/*.png 7 | files_b: *FILES_A 8 | size: &SIZE 512 9 | crop: random 10 | preload: &PRELOAD false 11 | preload_size: &PRELOAD_SIZE 0 12 | bounds: [0, .9] 13 | scope: geometric 14 | corrupt: &CORRUPT 15 | - name: cutout 16 | prob: 0.5 17 | num_holes: 3 18 | max_h_size: 25 19 | max_w_size: 25 20 | - name: jpeg 21 | quality_lower: 70 22 | quality_upper: 90 23 | - name: motion_blur 24 | - name: median_blur 25 | - name: gamma 26 | - name: rgb_shift 27 | - name: hsv_shift 28 | - name: sharpen 29 | 30 | val: 31 | files_a: *FILES_A 32 | files_b: *FILES_A 33 | # files_a: &FILES_A 34 | # files_b: &FILES_B 35 | size: *SIZE 36 | scope: geometric 37 | crop: center 38 | preload: *PRELOAD 39 | preload_size: 
*PRELOAD_SIZE 40 | bounds: [.9, 1] 41 | corrupt: *CORRUPT 42 | 43 | phase: train 44 | warmup_num: 3 45 | model: 46 | g_name: fpn_inception 47 | blocks: 9 48 | d_name: double_gan # may be no_gan, patch_gan, double_gan, multi_scale 49 | d_layers: 3 50 | content_loss: perceptual 51 | adv_lambda: 0.001 52 | disc_loss: wgan-gp 53 | learn_residual: True 54 | norm_layer: instance 55 | dropout: True 56 | 57 | num_epochs: 10 58 | train_batches_per_epoch: 1000 59 | val_batches_per_epoch: 100 60 | batch_size: 1 61 | image_size: [512, 512] 62 | 63 | optimizer: 64 | name: adam 65 | lr: 0.01 66 | scheduler: 67 | name: linear 68 | start_epoch: 50 69 | min_lr: 0.0000001 70 | -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from copy import deepcopy 3 | from functools import partial 4 | from glob import glob 5 | from hashlib import sha1 6 | from typing import Callable, Iterable, Optional, Tuple 7 | 8 | import cv2 9 | import numpy as np 10 | from glog import logger 11 | from joblib import Parallel, cpu_count, delayed 12 | from skimage.io import imread 13 | from torch.utils.data import Dataset 14 | from tqdm import tqdm 15 | 16 | import aug 17 | 18 | 19 | def subsample(data: Iterable, bounds: Tuple[float, float], hash_fn: Callable, n_buckets=100, salt='', verbose=True): 20 | data = list(data) 21 | buckets = split_into_buckets(data, n_buckets=n_buckets, salt=salt, hash_fn=hash_fn) 22 | 23 | lower_bound, upper_bound = [x * n_buckets for x in bounds] 24 | msg = f'Subsampling buckets from {lower_bound} to {upper_bound}, total buckets number is {n_buckets}' 25 | if salt: 26 | msg += f'; salt is {salt}' 27 | if verbose: 28 | logger.info(msg) 29 | return np.array([sample for bucket, sample in zip(buckets, data) if lower_bound <= bucket < upper_bound]) 30 | 31 | 32 | def hash_from_paths(x: Tuple[str, str], salt: str = '') -> str: 33 | path_a, path_b = x 34 | names = ''.join(map(os.path.basename, (path_a, path_b))) 35 | return sha1(f'{names}_{salt}'.encode()).hexdigest() 36 | 37 | 38 | def split_into_buckets(data: Iterable, n_buckets: int, hash_fn: Callable, salt=''): 39 | hashes = map(partial(hash_fn, salt=salt), data) 40 | return np.array([int(x, 16) % n_buckets for x in hashes]) 41 | 42 | 43 | def _read_img(x: str): 44 | img = cv2.imread(x) 45 | if img is None: 46 | logger.warning(f'Can not read image {x} with OpenCV, switching to scikit-image') 47 | img = imread(x)[:, :, ::-1] 48 | return img 49 | 50 | 51 | class PairedDataset(Dataset): 52 | def __init__(self, 53 | files_a: Tuple[str], 54 | files_b: Tuple[str], 55 | transform_fn: Callable, 56 | normalize_fn: Callable, 57 | corrupt_fn: Optional[Callable] = None, 58 | preload: bool = True, 59 | preload_size: Optional[int] = 0, 60 | verbose=True): 61 | 62 | assert len(files_a) == len(files_b) 63 | 64 | self.preload = preload 65 | self.data_a = files_a 66 | self.data_b = files_b 67 | self.verbose = verbose 68 | self.corrupt_fn = corrupt_fn 69 | self.transform_fn = transform_fn 70 | self.normalize_fn = normalize_fn 71 | logger.info(f'Dataset has been created with {len(self.data_a)} samples') 72 | 73 | if preload: 74 | preload_fn = partial(self._bulk_preload, preload_size=preload_size) 75 | if files_a == files_b: 76 | self.data_a = self.data_b = preload_fn(self.data_a) 77 | else: 78 | self.data_a, self.data_b = map(preload_fn, (self.data_a, self.data_b)) 79 | self.preload = True 80 | 81 | 
def _bulk_preload(self, data: Iterable[str], preload_size: int): 82 | jobs = [delayed(self._preload)(x, preload_size=preload_size) for x in data] 83 | jobs = tqdm(jobs, desc='preloading images', disable=not self.verbose) 84 | return Parallel(n_jobs=cpu_count(), backend='threading')(jobs) 85 | 86 | @staticmethod 87 | def _preload(x: str, preload_size: int): 88 | img = _read_img(x) 89 | if preload_size: 90 | h, w, *_ = img.shape 91 | h_scale = preload_size / h 92 | w_scale = preload_size / w 93 | scale = max(h_scale, w_scale) 94 | img = cv2.resize(img, fx=scale, fy=scale, dsize=None) 95 | assert min(img.shape[:2]) >= preload_size, f'weird img shape: {img.shape}' 96 | return img 97 | 98 | def _preprocess(self, img, res): 99 | def transpose(x): 100 | return np.transpose(x, (2, 0, 1)) 101 | 102 | return map(transpose, self.normalize_fn(img, res)) 103 | 104 | def __len__(self): 105 | return len(self.data_a) 106 | 107 | def __getitem__(self, idx): 108 | a, b = self.data_a[idx], self.data_b[idx] 109 | if not self.preload: 110 | a, b = map(_read_img, (a, b)) 111 | a, b = self.transform_fn(a, b) 112 | if self.corrupt_fn is not None: 113 | a = self.corrupt_fn(a) 114 | a, b = self._preprocess(a, b) 115 | return {'a': a, 'b': b} 116 | 117 | @staticmethod 118 | def from_config(config): 119 | config = deepcopy(config) 120 | files_a, files_b = map(lambda x: sorted(glob(config[x], recursive=True)), ('files_a', 'files_b')) 121 | transform_fn = aug.get_transforms(size=config['size'], scope=config['scope'], crop=config['crop']) 122 | normalize_fn = aug.get_normalize() 123 | corrupt_fn = aug.get_corrupt_function(config['corrupt']) 124 | 125 | hash_fn = hash_from_paths 126 | # ToDo: add more hash functions 127 | verbose = config.get('verbose', True) 128 | data = subsample(data=zip(files_a, files_b), 129 | bounds=config.get('bounds', (0, 1)), 130 | hash_fn=hash_fn, 131 | verbose=verbose) 132 | 133 | files_a, files_b = map(list, zip(*data)) 134 | 135 | return PairedDataset(files_a=files_a, 136 | files_b=files_b, 137 | preload=config['preload'], 138 | preload_size=config['preload_size'], 139 | corrupt_fn=corrupt_fn, 140 | normalize_fn=normalize_fn, 141 | transform_fn=transform_fn, 142 | verbose=verbose) 143 | -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/metric_counter.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | from tensorboardX import SummaryWriter 6 | 7 | WINDOW_SIZE = 100 8 | 9 | 10 | class MetricCounter: 11 | def __init__(self, exp_name): 12 | self.writer = SummaryWriter(exp_name) 13 | logging.basicConfig(filename='{}.log'.format(exp_name), level=logging.DEBUG) 14 | self.metrics = defaultdict(list) 15 | self.images = defaultdict(list) 16 | self.best_metric = 0 17 | 18 | def add_image(self, x: np.ndarray, tag: str): 19 | self.images[tag].append(x) 20 | 21 | def clear(self): 22 | self.metrics = defaultdict(list) 23 | self.images = defaultdict(list) 24 | 25 | def add_losses(self, l_G, l_content, l_D=0): 26 | for name, value in zip(('G_loss', 'G_loss_content', 'G_loss_adv', 'D_loss'), 27 | (l_G, l_content, l_G - l_content, l_D)): 28 | self.metrics[name].append(value) 29 | 30 | def add_metrics(self, psnr, ssim): 31 | for name, value in zip(('PSNR', 'SSIM'), 32 | (psnr, ssim)): 33 | self.metrics[name].append(value) 34 | 35 | def loss_message(self): 36 | metrics = ((k, np.mean(self.metrics[k][-WINDOW_SIZE:])) for k 
in ('G_loss', 'PSNR', 'SSIM')) 37 | return '; '.join(map(lambda x: f'{x[0]}={x[1]:.4f}', metrics)) 38 | 39 | def write_to_tensorboard(self, epoch_num, validation=False): 40 | scalar_prefix = 'Validation' if validation else 'Train' 41 | for tag in ('G_loss', 'D_loss', 'G_loss_adv', 'G_loss_content', 'SSIM', 'PSNR'): 42 | self.writer.add_scalar(f'{scalar_prefix}_{tag}', np.mean(self.metrics[tag]), global_step=epoch_num) 43 | for tag in self.images: 44 | imgs = self.images[tag] 45 | if imgs: 46 | imgs = np.array(imgs) 47 | self.writer.add_images(tag, imgs[:, :, :, ::-1].astype('float32') / 255, dataformats='NHWC', 48 | global_step=epoch_num) 49 | self.images[tag] = [] 50 | 51 | def update_best_model(self): 52 | cur_metric = np.mean(self.metrics['PSNR']) 53 | if self.best_metric < cur_metric: 54 | self.best_metric = cur_metric 55 | return True 56 | return False 57 | -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/neural_networks/DeblurGANv2/models/__init__.py -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/models/fpn_densenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from torchvision.models import densenet121 5 | 6 | 7 | class FPNSegHead(nn.Module): 8 | def __init__(self, num_in, num_mid, num_out): 9 | super().__init__() 10 | 11 | self.block0 = nn.Conv2d(num_in, num_mid, kernel_size=3, padding=1, bias=False) 12 | self.block1 = nn.Conv2d(num_mid, num_out, kernel_size=3, padding=1, bias=False) 13 | 14 | def forward(self, x): 15 | x = nn.functional.relu(self.block0(x), inplace=True) 16 | x = nn.functional.relu(self.block1(x), inplace=True) 17 | return x 18 | 19 | 20 | class FPNDense(nn.Module): 21 | 22 | def __init__(self, output_ch=3, num_filters=128, num_filters_fpn=256, pretrained=True): 23 | super().__init__() 24 | 25 | # Feature Pyramid Network (FPN) with four feature maps of resolutions 26 | # 1/4, 1/8, 1/16, 1/32 and `num_filters` filters for all feature maps. 
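        # In forward(), these four pyramid maps are upsampled back to the common
        # 1/4-resolution grid, fused by `smooth`, combined with the half-resolution
        # `map0`, and upsampled to full resolution before the final convolution and tanh.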
27 | 28 | self.fpn = FPN(num_filters=num_filters_fpn, pretrained=pretrained) 29 | 30 | # The segmentation heads on top of the FPN 31 | 32 | self.head1 = FPNSegHead(num_filters_fpn, num_filters, num_filters) 33 | self.head2 = FPNSegHead(num_filters_fpn, num_filters, num_filters) 34 | self.head3 = FPNSegHead(num_filters_fpn, num_filters, num_filters) 35 | self.head4 = FPNSegHead(num_filters_fpn, num_filters, num_filters) 36 | 37 | self.smooth = nn.Sequential( 38 | nn.Conv2d(4 * num_filters, num_filters, kernel_size=3, padding=1), 39 | nn.BatchNorm2d(num_filters), 40 | nn.ReLU(), 41 | ) 42 | 43 | self.smooth2 = nn.Sequential( 44 | nn.Conv2d(num_filters, num_filters // 2, kernel_size=3, padding=1), 45 | nn.BatchNorm2d(num_filters // 2), 46 | nn.ReLU(), 47 | ) 48 | 49 | self.final = nn.Conv2d(num_filters // 2, output_ch, kernel_size=3, padding=1) 50 | 51 | def forward(self, x): 52 | map0, map1, map2, map3, map4 = self.fpn(x) 53 | 54 | map4 = nn.functional.upsample(self.head4(map4), scale_factor=8, mode="nearest") 55 | map3 = nn.functional.upsample(self.head3(map3), scale_factor=4, mode="nearest") 56 | map2 = nn.functional.upsample(self.head2(map2), scale_factor=2, mode="nearest") 57 | map1 = nn.functional.upsample(self.head1(map1), scale_factor=1, mode="nearest") 58 | 59 | smoothed = self.smooth(torch.cat([map4, map3, map2, map1], dim=1)) 60 | smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest") 61 | smoothed = self.smooth2(smoothed + map0) 62 | smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest") 63 | 64 | final = self.final(smoothed) 65 | return torch.tanh(final) 66 | 67 | def unfreeze(self): 68 | for param in self.fpn.parameters(): 69 | param.requires_grad = True 70 | 71 | 72 | class FPN(nn.Module): 73 | 74 | def __init__(self, num_filters=256, pretrained=True): 75 | """Creates an `FPN` instance for feature extraction. 
76 | Args: 77 | num_filters: the number of filters in each output pyramid level 78 | pretrained: use ImageNet pre-trained backbone feature extractor 79 | """ 80 | 81 | super().__init__() 82 | 83 | self.features = densenet121(pretrained=pretrained).features 84 | 85 | self.enc0 = nn.Sequential(self.features.conv0, 86 | self.features.norm0, 87 | self.features.relu0) 88 | self.pool0 = self.features.pool0 89 | self.enc1 = self.features.denseblock1 # 256 90 | self.enc2 = self.features.denseblock2 # 512 91 | self.enc3 = self.features.denseblock3 # 1024 92 | self.enc4 = self.features.denseblock4 # 2048 93 | self.norm = self.features.norm5 # 2048 94 | 95 | self.tr1 = self.features.transition1 # 256 96 | self.tr2 = self.features.transition2 # 512 97 | self.tr3 = self.features.transition3 # 1024 98 | 99 | self.lateral4 = nn.Conv2d(1024, num_filters, kernel_size=1, bias=False) 100 | self.lateral3 = nn.Conv2d(1024, num_filters, kernel_size=1, bias=False) 101 | self.lateral2 = nn.Conv2d(512, num_filters, kernel_size=1, bias=False) 102 | self.lateral1 = nn.Conv2d(256, num_filters, kernel_size=1, bias=False) 103 | self.lateral0 = nn.Conv2d(64, num_filters // 2, kernel_size=1, bias=False) 104 | 105 | def forward(self, x): 106 | # Bottom-up pathway, from ResNet 107 | enc0 = self.enc0(x) 108 | 109 | pooled = self.pool0(enc0) 110 | 111 | enc1 = self.enc1(pooled) # 256 112 | tr1 = self.tr1(enc1) 113 | 114 | enc2 = self.enc2(tr1) # 512 115 | tr2 = self.tr2(enc2) 116 | 117 | enc3 = self.enc3(tr2) # 1024 118 | tr3 = self.tr3(enc3) 119 | 120 | enc4 = self.enc4(tr3) # 2048 121 | enc4 = self.norm(enc4) 122 | 123 | # Lateral connections 124 | 125 | lateral4 = self.lateral4(enc4) 126 | lateral3 = self.lateral3(enc3) 127 | lateral2 = self.lateral2(enc2) 128 | lateral1 = self.lateral1(enc1) 129 | lateral0 = self.lateral0(enc0) 130 | 131 | # Top-down pathway 132 | 133 | map4 = lateral4 134 | map3 = lateral3 + nn.functional.upsample(map4, scale_factor=2, mode="nearest") 135 | map2 = lateral2 + nn.functional.upsample(map3, scale_factor=2, mode="nearest") 136 | map1 = lateral1 + nn.functional.upsample(map2, scale_factor=2, mode="nearest") 137 | 138 | return lateral0, map1, map2, map3, map4 139 | -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/models/fpn_inception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from pretrainedmodels import inceptionresnetv2 4 | from torchsummary import summary 5 | import torch.nn.functional as F 6 | 7 | class FPNHead(nn.Module): 8 | def __init__(self, num_in, num_mid, num_out): 9 | super().__init__() 10 | 11 | self.block0 = nn.Conv2d(num_in, num_mid, kernel_size=3, padding=1, bias=False) 12 | self.block1 = nn.Conv2d(num_mid, num_out, kernel_size=3, padding=1, bias=False) 13 | 14 | def forward(self, x): 15 | x = nn.functional.relu(self.block0(x), inplace=True) 16 | x = nn.functional.relu(self.block1(x), inplace=True) 17 | return x 18 | 19 | class ConvBlock(nn.Module): 20 | def __init__(self, num_in, num_out, norm_layer): 21 | super().__init__() 22 | 23 | self.block = nn.Sequential(nn.Conv2d(num_in, num_out, kernel_size=3, padding=1), 24 | norm_layer(num_out), 25 | nn.ReLU(inplace=True)) 26 | 27 | def forward(self, x): 28 | x = self.block(x) 29 | return x 30 | 31 | 32 | class FPNInception(nn.Module): 33 | 34 | def __init__(self, norm_layer, output_ch=3, num_filters=128, num_filters_fpn=256): 35 | super().__init__() 36 | 37 | # Feature 
Pyramid Network (FPN) with four feature maps of resolutions 38 | # 1/4, 1/8, 1/16, 1/32 and `num_filters` filters for all feature maps. 39 | self.fpn = FPN(num_filters=num_filters_fpn, norm_layer=norm_layer) 40 | 41 | # The segmentation heads on top of the FPN 42 | 43 | self.head1 = FPNHead(num_filters_fpn, num_filters, num_filters) 44 | self.head2 = FPNHead(num_filters_fpn, num_filters, num_filters) 45 | self.head3 = FPNHead(num_filters_fpn, num_filters, num_filters) 46 | self.head4 = FPNHead(num_filters_fpn, num_filters, num_filters) 47 | 48 | self.smooth = nn.Sequential( 49 | nn.Conv2d(4 * num_filters, num_filters, kernel_size=3, padding=1), 50 | norm_layer(num_filters), 51 | nn.ReLU(), 52 | ) 53 | 54 | self.smooth2 = nn.Sequential( 55 | nn.Conv2d(num_filters, num_filters // 2, kernel_size=3, padding=1), 56 | norm_layer(num_filters // 2), 57 | nn.ReLU(), 58 | ) 59 | 60 | self.final = nn.Conv2d(num_filters // 2, output_ch, kernel_size=3, padding=1) 61 | 62 | def unfreeze(self): 63 | self.fpn.unfreeze() 64 | 65 | def forward(self, x): 66 | map0, map1, map2, map3, map4 = self.fpn(x) 67 | 68 | map4 = nn.functional.upsample(self.head4(map4), scale_factor=8, mode="nearest") 69 | map3 = nn.functional.upsample(self.head3(map3), scale_factor=4, mode="nearest") 70 | map2 = nn.functional.upsample(self.head2(map2), scale_factor=2, mode="nearest") 71 | map1 = nn.functional.upsample(self.head1(map1), scale_factor=1, mode="nearest") 72 | 73 | smoothed = self.smooth(torch.cat([map4, map3, map2, map1], dim=1)) 74 | smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest") 75 | smoothed = self.smooth2(smoothed + map0) 76 | smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest") 77 | 78 | final = self.final(smoothed) 79 | res = torch.tanh(final) + x 80 | 81 | return torch.clamp(res, min = -1,max = 1) 82 | 83 | 84 | class FPN(nn.Module): 85 | 86 | def __init__(self, norm_layer, num_filters=256): 87 | """Creates an `FPN` instance for feature extraction. 
88 | Args: 89 | num_filters: the number of filters in each output pyramid level 90 | pretrained: use ImageNet pre-trained backbone feature extractor 91 | """ 92 | 93 | super().__init__() 94 | self.inception = inceptionresnetv2(num_classes=1000, pretrained='imagenet') 95 | 96 | self.enc0 = self.inception.conv2d_1a 97 | self.enc1 = nn.Sequential( 98 | self.inception.conv2d_2a, 99 | self.inception.conv2d_2b, 100 | self.inception.maxpool_3a, 101 | ) # 64 102 | self.enc2 = nn.Sequential( 103 | self.inception.conv2d_3b, 104 | self.inception.conv2d_4a, 105 | self.inception.maxpool_5a, 106 | ) # 192 107 | self.enc3 = nn.Sequential( 108 | self.inception.mixed_5b, 109 | self.inception.repeat, 110 | self.inception.mixed_6a, 111 | ) # 1088 112 | self.enc4 = nn.Sequential( 113 | self.inception.repeat_1, 114 | self.inception.mixed_7a, 115 | ) #2080 116 | self.td1 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1), 117 | norm_layer(num_filters), 118 | nn.ReLU(inplace=True)) 119 | self.td2 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1), 120 | norm_layer(num_filters), 121 | nn.ReLU(inplace=True)) 122 | self.td3 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1), 123 | norm_layer(num_filters), 124 | nn.ReLU(inplace=True)) 125 | self.pad = nn.ReflectionPad2d(1) 126 | self.lateral4 = nn.Conv2d(2080, num_filters, kernel_size=1, bias=False) 127 | self.lateral3 = nn.Conv2d(1088, num_filters, kernel_size=1, bias=False) 128 | self.lateral2 = nn.Conv2d(192, num_filters, kernel_size=1, bias=False) 129 | self.lateral1 = nn.Conv2d(64, num_filters, kernel_size=1, bias=False) 130 | self.lateral0 = nn.Conv2d(32, num_filters // 2, kernel_size=1, bias=False) 131 | 132 | for param in self.inception.parameters(): 133 | param.requires_grad = False 134 | 135 | def unfreeze(self): 136 | for param in self.inception.parameters(): 137 | param.requires_grad = True 138 | 139 | def forward(self, x): 140 | 141 | # Bottom-up pathway, from ResNet 142 | enc0 = self.enc0(x) 143 | 144 | enc1 = self.enc1(enc0) # 256 145 | 146 | enc2 = self.enc2(enc1) # 512 147 | 148 | enc3 = self.enc3(enc2) # 1024 149 | 150 | enc4 = self.enc4(enc3) # 2048 151 | 152 | # Lateral connections 153 | 154 | lateral4 = self.pad(self.lateral4(enc4)) 155 | lateral3 = self.pad(self.lateral3(enc3)) 156 | lateral2 = self.lateral2(enc2) 157 | lateral1 = self.pad(self.lateral1(enc1)) 158 | lateral0 = self.lateral0(enc0) 159 | 160 | # Top-down pathway 161 | pad = (1, 2, 1, 2) # pad last dim by 1 on each side 162 | pad1 = (0, 1, 0, 1) 163 | map4 = lateral4 164 | map3 = self.td1(lateral3 + nn.functional.upsample(map4, scale_factor=2, mode="nearest")) 165 | map2 = self.td2(F.pad(lateral2, pad, "reflect") + nn.functional.upsample(map3, scale_factor=2, mode="nearest")) 166 | map1 = self.td3(lateral1 + nn.functional.upsample(map2, scale_factor=2, mode="nearest")) 167 | return F.pad(lateral0, pad1, "reflect"), map1, map2, map3, map4 168 | -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/models/fpn_inception_simple.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from pretrainedmodels import inceptionresnetv2 4 | from torchsummary import summary 5 | import torch.nn.functional as F 6 | 7 | class FPNHead(nn.Module): 8 | def __init__(self, num_in, num_mid, num_out): 9 | super().__init__() 10 | 11 | self.block0 = nn.Conv2d(num_in, num_mid, 
kernel_size=3, padding=1, bias=False) 12 | self.block1 = nn.Conv2d(num_mid, num_out, kernel_size=3, padding=1, bias=False) 13 | 14 | def forward(self, x): 15 | x = nn.functional.relu(self.block0(x), inplace=True) 16 | x = nn.functional.relu(self.block1(x), inplace=True) 17 | return x 18 | 19 | class ConvBlock(nn.Module): 20 | def __init__(self, num_in, num_out, norm_layer): 21 | super().__init__() 22 | 23 | self.block = nn.Sequential(nn.Conv2d(num_in, num_out, kernel_size=3, padding=1), 24 | norm_layer(num_out), 25 | nn.ReLU(inplace=True)) 26 | 27 | def forward(self, x): 28 | x = self.block(x) 29 | return x 30 | 31 | 32 | class FPNInceptionSimple(nn.Module): 33 | 34 | def __init__(self, norm_layer, output_ch=3, num_filters=128, num_filters_fpn=256): 35 | super().__init__() 36 | 37 | # Feature Pyramid Network (FPN) with four feature maps of resolutions 38 | # 1/4, 1/8, 1/16, 1/32 and `num_filters` filters for all feature maps. 39 | self.fpn = FPN(num_filters=num_filters_fpn, norm_layer=norm_layer) 40 | 41 | # The segmentation heads on top of the FPN 42 | 43 | self.head1 = FPNHead(num_filters_fpn, num_filters, num_filters) 44 | self.head2 = FPNHead(num_filters_fpn, num_filters, num_filters) 45 | self.head3 = FPNHead(num_filters_fpn, num_filters, num_filters) 46 | self.head4 = FPNHead(num_filters_fpn, num_filters, num_filters) 47 | 48 | self.smooth = nn.Sequential( 49 | nn.Conv2d(4 * num_filters, num_filters, kernel_size=3, padding=1), 50 | norm_layer(num_filters), 51 | nn.ReLU(), 52 | ) 53 | 54 | self.smooth2 = nn.Sequential( 55 | nn.Conv2d(num_filters, num_filters // 2, kernel_size=3, padding=1), 56 | norm_layer(num_filters // 2), 57 | nn.ReLU(), 58 | ) 59 | 60 | self.final = nn.Conv2d(num_filters // 2, output_ch, kernel_size=3, padding=1) 61 | 62 | def unfreeze(self): 63 | self.fpn.unfreeze() 64 | 65 | def forward(self, x): 66 | 67 | map0, map1, map2, map3, map4 = self.fpn(x) 68 | 69 | map4 = nn.functional.upsample(self.head4(map4), scale_factor=8, mode="nearest") 70 | map3 = nn.functional.upsample(self.head3(map3), scale_factor=4, mode="nearest") 71 | map2 = nn.functional.upsample(self.head2(map2), scale_factor=2, mode="nearest") 72 | map1 = nn.functional.upsample(self.head1(map1), scale_factor=1, mode="nearest") 73 | 74 | smoothed = self.smooth(torch.cat([map4, map3, map2, map1], dim=1)) 75 | smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest") 76 | smoothed = self.smooth2(smoothed + map0) 77 | smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest") 78 | 79 | final = self.final(smoothed) 80 | res = torch.tanh(final) + x 81 | 82 | return torch.clamp(res, min = -1,max = 1) 83 | 84 | 85 | class FPN(nn.Module): 86 | 87 | def __init__(self, norm_layer, num_filters=256): 88 | """Creates an `FPN` instance for feature extraction. 
89 | Args: 90 | num_filters: the number of filters in each output pyramid level 91 | pretrained: use ImageNet pre-trained backbone feature extractor 92 | """ 93 | 94 | super().__init__() 95 | self.inception = inceptionresnetv2(num_classes=1000, pretrained='imagenet') 96 | 97 | self.enc0 = self.inception.conv2d_1a 98 | self.enc1 = nn.Sequential( 99 | self.inception.conv2d_2a, 100 | self.inception.conv2d_2b, 101 | self.inception.maxpool_3a, 102 | ) # 64 103 | self.enc2 = nn.Sequential( 104 | self.inception.conv2d_3b, 105 | self.inception.conv2d_4a, 106 | self.inception.maxpool_5a, 107 | ) # 192 108 | self.enc3 = nn.Sequential( 109 | self.inception.mixed_5b, 110 | self.inception.repeat, 111 | self.inception.mixed_6a, 112 | ) # 1088 113 | self.enc4 = nn.Sequential( 114 | self.inception.repeat_1, 115 | self.inception.mixed_7a, 116 | ) #2080 117 | 118 | self.pad = nn.ReflectionPad2d(1) 119 | self.lateral4 = nn.Conv2d(2080, num_filters, kernel_size=1, bias=False) 120 | self.lateral3 = nn.Conv2d(1088, num_filters, kernel_size=1, bias=False) 121 | self.lateral2 = nn.Conv2d(192, num_filters, kernel_size=1, bias=False) 122 | self.lateral1 = nn.Conv2d(64, num_filters, kernel_size=1, bias=False) 123 | self.lateral0 = nn.Conv2d(32, num_filters // 2, kernel_size=1, bias=False) 124 | 125 | for param in self.inception.parameters(): 126 | param.requires_grad = False 127 | 128 | def unfreeze(self): 129 | for param in self.inception.parameters(): 130 | param.requires_grad = True 131 | 132 | def forward(self, x): 133 | 134 | # Bottom-up pathway, from ResNet 135 | enc0 = self.enc0(x) 136 | 137 | enc1 = self.enc1(enc0) # 256 138 | 139 | enc2 = self.enc2(enc1) # 512 140 | 141 | enc3 = self.enc3(enc2) # 1024 142 | 143 | enc4 = self.enc4(enc3) # 2048 144 | 145 | # Lateral connections 146 | 147 | lateral4 = self.pad(self.lateral4(enc4)) 148 | lateral3 = self.pad(self.lateral3(enc3)) 149 | lateral2 = self.lateral2(enc2) 150 | lateral1 = self.pad(self.lateral1(enc1)) 151 | lateral0 = self.lateral0(enc0) 152 | 153 | # Top-down pathway 154 | pad = (1, 2, 1, 2) # pad last dim by 1 on each side 155 | pad1 = (0, 1, 0, 1) 156 | map4 = lateral4 157 | map3 = lateral3 + nn.functional.upsample(map4, scale_factor=2, mode="nearest") 158 | map2 = F.pad(lateral2, pad, "reflect") + nn.functional.upsample(map3, scale_factor=2, mode="nearest") 159 | map1 = lateral1 + nn.functional.upsample(map2, scale_factor=2, mode="nearest") 160 | return F.pad(lateral0, pad1, "reflect"), map1, map2, map3, map4 161 | -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/models/fpn_mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.mobilenet_v2 import MobileNetV2 4 | 5 | class FPNHead(nn.Module): 6 | def __init__(self, num_in, num_mid, num_out): 7 | super().__init__() 8 | 9 | self.block0 = nn.Conv2d(num_in, num_mid, kernel_size=3, padding=1, bias=False) 10 | self.block1 = nn.Conv2d(num_mid, num_out, kernel_size=3, padding=1, bias=False) 11 | 12 | def forward(self, x): 13 | x = nn.functional.relu(self.block0(x), inplace=True) 14 | x = nn.functional.relu(self.block1(x), inplace=True) 15 | return x 16 | 17 | 18 | class FPNMobileNet(nn.Module): 19 | 20 | def __init__(self, norm_layer, output_ch=3, num_filters=64, num_filters_fpn=128, pretrained=True): 21 | super().__init__() 22 | 23 | # Feature Pyramid Network (FPN) with four feature maps of resolutions 24 | # 1/4, 1/8, 1/16, 1/32 and 
`num_filters` filters for all feature maps. 25 | 26 | self.fpn = FPN(num_filters=num_filters_fpn, norm_layer = norm_layer, pretrained=pretrained) 27 | 28 | # The segmentation heads on top of the FPN 29 | 30 | self.head1 = FPNHead(num_filters_fpn, num_filters, num_filters) 31 | self.head2 = FPNHead(num_filters_fpn, num_filters, num_filters) 32 | self.head3 = FPNHead(num_filters_fpn, num_filters, num_filters) 33 | self.head4 = FPNHead(num_filters_fpn, num_filters, num_filters) 34 | 35 | self.smooth = nn.Sequential( 36 | nn.Conv2d(4 * num_filters, num_filters, kernel_size=3, padding=1), 37 | norm_layer(num_filters), 38 | nn.ReLU(), 39 | ) 40 | 41 | self.smooth2 = nn.Sequential( 42 | nn.Conv2d(num_filters, num_filters // 2, kernel_size=3, padding=1), 43 | norm_layer(num_filters // 2), 44 | nn.ReLU(), 45 | ) 46 | 47 | self.final = nn.Conv2d(num_filters // 2, output_ch, kernel_size=3, padding=1) 48 | 49 | def unfreeze(self): 50 | self.fpn.unfreeze() 51 | 52 | def forward(self, x): 53 | 54 | map0, map1, map2, map3, map4 = self.fpn(x) 55 | 56 | map4 = nn.functional.upsample(self.head4(map4), scale_factor=8, mode="nearest") 57 | map3 = nn.functional.upsample(self.head3(map3), scale_factor=4, mode="nearest") 58 | map2 = nn.functional.upsample(self.head2(map2), scale_factor=2, mode="nearest") 59 | map1 = nn.functional.upsample(self.head1(map1), scale_factor=1, mode="nearest") 60 | 61 | smoothed = self.smooth(torch.cat([map4, map3, map2, map1], dim=1)) 62 | smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest") 63 | smoothed = self.smooth2(smoothed + map0) 64 | smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest") 65 | 66 | final = self.final(smoothed) 67 | res = torch.tanh(final) + x 68 | 69 | return torch.clamp(res, min=-1, max=1) 70 | 71 | 72 | class FPN(nn.Module): 73 | 74 | def __init__(self, norm_layer, num_filters=128, pretrained=True): 75 | """Creates an `FPN` instance for feature extraction. 
76 | Args: 77 | num_filters: the number of filters in each output pyramid level 78 | pretrained: use ImageNet pre-trained backbone feature extractor 79 | """ 80 | 81 | super().__init__() 82 | net = MobileNetV2(n_class=1000) 83 | 84 | if pretrained: 85 | #Load weights into the project directory 86 | state_dict = torch.load('mobilenetv2.pth.tar') # add map_location='cpu' if no gpu 87 | net.load_state_dict(state_dict) 88 | self.features = net.features 89 | 90 | self.enc0 = nn.Sequential(*self.features[0:2]) 91 | self.enc1 = nn.Sequential(*self.features[2:4]) 92 | self.enc2 = nn.Sequential(*self.features[4:7]) 93 | self.enc3 = nn.Sequential(*self.features[7:11]) 94 | self.enc4 = nn.Sequential(*self.features[11:16]) 95 | 96 | self.td1 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1), 97 | norm_layer(num_filters), 98 | nn.ReLU(inplace=True)) 99 | self.td2 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1), 100 | norm_layer(num_filters), 101 | nn.ReLU(inplace=True)) 102 | self.td3 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1), 103 | norm_layer(num_filters), 104 | nn.ReLU(inplace=True)) 105 | 106 | self.lateral4 = nn.Conv2d(160, num_filters, kernel_size=1, bias=False) 107 | self.lateral3 = nn.Conv2d(64, num_filters, kernel_size=1, bias=False) 108 | self.lateral2 = nn.Conv2d(32, num_filters, kernel_size=1, bias=False) 109 | self.lateral1 = nn.Conv2d(24, num_filters, kernel_size=1, bias=False) 110 | self.lateral0 = nn.Conv2d(16, num_filters // 2, kernel_size=1, bias=False) 111 | 112 | for param in self.features.parameters(): 113 | param.requires_grad = False 114 | 115 | def unfreeze(self): 116 | for param in self.features.parameters(): 117 | param.requires_grad = True 118 | 119 | 120 | def forward(self, x): 121 | 122 | # Bottom-up pathway, from ResNet 123 | enc0 = self.enc0(x) 124 | 125 | enc1 = self.enc1(enc0) # 256 126 | 127 | enc2 = self.enc2(enc1) # 512 128 | 129 | enc3 = self.enc3(enc2) # 1024 130 | 131 | enc4 = self.enc4(enc3) # 2048 132 | 133 | # Lateral connections 134 | 135 | lateral4 = self.lateral4(enc4) 136 | lateral3 = self.lateral3(enc3) 137 | lateral2 = self.lateral2(enc2) 138 | lateral1 = self.lateral1(enc1) 139 | lateral0 = self.lateral0(enc0) 140 | 141 | # Top-down pathway 142 | map4 = lateral4 143 | map3 = self.td1(lateral3 + nn.functional.upsample(map4, scale_factor=2, mode="nearest")) 144 | map2 = self.td2(lateral2 + nn.functional.upsample(map3, scale_factor=2, mode="nearest")) 145 | map1 = self.td3(lateral1 + nn.functional.upsample(map2, scale_factor=2, mode="nearest")) 146 | return lateral0, map1, map2, map3, map4 147 | 148 | -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/models/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd as autograd 3 | import torch.nn as nn 4 | import torchvision.models as models 5 | import torchvision.transforms as transforms 6 | from torch.autograd import Variable 7 | 8 | from util.image_pool import ImagePool 9 | 10 | 11 | ############################################################################### 12 | # Functions 13 | ############################################################################### 14 | 15 | class ContentLoss(): 16 | def initialize(self, loss): 17 | self.criterion = loss 18 | 19 | def get_loss(self, fakeIm, realIm): 20 | return self.criterion(fakeIm, realIm) 21 | 22 | def __call__(self, 
fakeIm, realIm): 23 | return self.get_loss(fakeIm, realIm) 24 | 25 | 26 | class PerceptualLoss(): 27 | 28 | def contentFunc(self): 29 | conv_3_3_layer = 14 30 | cnn = models.vgg19(pretrained=True).features 31 | cnn = cnn.cuda() 32 | model = nn.Sequential() 33 | model = model.cuda() 34 | model = model.eval() 35 | for i, layer in enumerate(list(cnn)): 36 | model.add_module(str(i), layer) 37 | if i == conv_3_3_layer: 38 | break 39 | return model 40 | 41 | def initialize(self, loss): 42 | with torch.no_grad(): 43 | self.criterion = loss 44 | self.contentFunc = self.contentFunc() 45 | self.transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 46 | 47 | def get_loss(self, fakeIm, realIm): 48 | fakeIm = (fakeIm + 1) / 2.0 49 | realIm = (realIm + 1) / 2.0 50 | fakeIm[0, :, :, :] = self.transform(fakeIm[0, :, :, :]) 51 | realIm[0, :, :, :] = self.transform(realIm[0, :, :, :]) 52 | f_fake = self.contentFunc.forward(fakeIm) 53 | f_real = self.contentFunc.forward(realIm) 54 | f_real_no_grad = f_real.detach() 55 | loss = self.criterion(f_fake, f_real_no_grad) 56 | return 0.006 * torch.mean(loss) + 0.5 * nn.MSELoss()(fakeIm, realIm) 57 | 58 | def __call__(self, fakeIm, realIm): 59 | return self.get_loss(fakeIm, realIm) 60 | 61 | 62 | class GANLoss(nn.Module): 63 | def __init__(self, use_l1=True, target_real_label=1.0, target_fake_label=0.0, 64 | tensor=torch.FloatTensor): 65 | super(GANLoss, self).__init__() 66 | self.real_label = target_real_label 67 | self.fake_label = target_fake_label 68 | self.real_label_var = None 69 | self.fake_label_var = None 70 | self.Tensor = tensor 71 | if use_l1: 72 | self.loss = nn.L1Loss() 73 | else: 74 | self.loss = nn.BCEWithLogitsLoss() 75 | 76 | def get_target_tensor(self, input, target_is_real): 77 | if target_is_real: 78 | create_label = ((self.real_label_var is None) or 79 | (self.real_label_var.numel() != input.numel())) 80 | if create_label: 81 | real_tensor = self.Tensor(input.size()).fill_(self.real_label) 82 | self.real_label_var = Variable(real_tensor, requires_grad=False) 83 | target_tensor = self.real_label_var 84 | else: 85 | create_label = ((self.fake_label_var is None) or 86 | (self.fake_label_var.numel() != input.numel())) 87 | if create_label: 88 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label) 89 | self.fake_label_var = Variable(fake_tensor, requires_grad=False) 90 | target_tensor = self.fake_label_var 91 | return target_tensor.cuda() 92 | 93 | def __call__(self, input, target_is_real): 94 | target_tensor = self.get_target_tensor(input, target_is_real) 95 | return self.loss(input, target_tensor) 96 | 97 | 98 | class DiscLoss(nn.Module): 99 | def name(self): 100 | return 'DiscLoss' 101 | 102 | def __init__(self): 103 | super(DiscLoss, self).__init__() 104 | 105 | self.criterionGAN = GANLoss(use_l1=False) 106 | self.fake_AB_pool = ImagePool(50) 107 | 108 | def get_g_loss(self, net, fakeB, realB): 109 | # First, G(A) should fake the discriminator 110 | pred_fake = net.forward(fakeB) 111 | return self.criterionGAN(pred_fake, 1) 112 | 113 | def get_loss(self, net, fakeB, realB): 114 | # Fake 115 | # stop backprop to the generator by detaching fake_B 116 | # Generated Image Disc Output should be close to zero 117 | self.pred_fake = net.forward(fakeB.detach()) 118 | self.loss_D_fake = self.criterionGAN(self.pred_fake, 0) 119 | 120 | # Real 121 | self.pred_real = net.forward(realB) 122 | self.loss_D_real = self.criterionGAN(self.pred_real, 1) 123 | 124 | # Combined loss 125 | self.loss_D = (self.loss_D_fake + 
self.loss_D_real) * 0.5 126 | return self.loss_D 127 | 128 | def __call__(self, net, fakeB, realB): 129 | return self.get_loss(net, fakeB, realB) 130 | 131 | 132 | class RelativisticDiscLoss(nn.Module): 133 | def name(self): 134 | return 'RelativisticDiscLoss' 135 | 136 | def __init__(self): 137 | super(RelativisticDiscLoss, self).__init__() 138 | 139 | self.criterionGAN = GANLoss(use_l1=False) 140 | self.fake_pool = ImagePool(50) # create image buffer to store previously generated images 141 | self.real_pool = ImagePool(50) 142 | 143 | def get_g_loss(self, net, fakeB, realB): 144 | # First, G(A) should fake the discriminator 145 | self.pred_fake = net.forward(fakeB) 146 | 147 | # Real 148 | self.pred_real = net.forward(realB) 149 | errG = (self.criterionGAN(self.pred_real - torch.mean(self.fake_pool.query()), 0) + 150 | self.criterionGAN(self.pred_fake - torch.mean(self.real_pool.query()), 1)) / 2 151 | return errG 152 | 153 | def get_loss(self, net, fakeB, realB): 154 | # Fake 155 | # stop backprop to the generator by detaching fake_B 156 | # Generated Image Disc Output should be close to zero 157 | self.fake_B = fakeB.detach() 158 | self.real_B = realB 159 | self.pred_fake = net.forward(fakeB.detach()) 160 | self.fake_pool.add(self.pred_fake) 161 | 162 | # Real 163 | self.pred_real = net.forward(realB) 164 | self.real_pool.add(self.pred_real) 165 | 166 | # Combined loss 167 | self.loss_D = (self.criterionGAN(self.pred_real - torch.mean(self.fake_pool.query()), 1) + 168 | self.criterionGAN(self.pred_fake - torch.mean(self.real_pool.query()), 0)) / 2 169 | return self.loss_D 170 | 171 | def __call__(self, net, fakeB, realB): 172 | return self.get_loss(net, fakeB, realB) 173 | 174 | 175 | class RelativisticDiscLossLS(nn.Module): 176 | def name(self): 177 | return 'RelativisticDiscLossLS' 178 | 179 | def __init__(self): 180 | super(RelativisticDiscLossLS, self).__init__() 181 | 182 | self.criterionGAN = GANLoss(use_l1=True) 183 | self.fake_pool = ImagePool(50) # create image buffer to store previously generated images 184 | self.real_pool = ImagePool(50) 185 | 186 | def get_g_loss(self, net, fakeB, realB): 187 | # First, G(A) should fake the discriminator 188 | self.pred_fake = net.forward(fakeB) 189 | 190 | # Real 191 | self.pred_real = net.forward(realB) 192 | errG = (torch.mean((self.pred_real - torch.mean(self.fake_pool.query()) + 1) ** 2) + 193 | torch.mean((self.pred_fake - torch.mean(self.real_pool.query()) - 1) ** 2)) / 2 194 | return errG 195 | 196 | def get_loss(self, net, fakeB, realB): 197 | # Fake 198 | # stop backprop to the generator by detaching fake_B 199 | # Generated Image Disc Output should be close to zero 200 | self.fake_B = fakeB.detach() 201 | self.real_B = realB 202 | self.pred_fake = net.forward(fakeB.detach()) 203 | self.fake_pool.add(self.pred_fake) 204 | 205 | # Real 206 | self.pred_real = net.forward(realB) 207 | self.real_pool.add(self.pred_real) 208 | 209 | # Combined loss 210 | self.loss_D = (torch.mean((self.pred_real - torch.mean(self.fake_pool.query()) - 1) ** 2) + 211 | torch.mean((self.pred_fake - torch.mean(self.real_pool.query()) + 1) ** 2)) / 2 212 | return self.loss_D 213 | 214 | def __call__(self, net, fakeB, realB): 215 | return self.get_loss(net, fakeB, realB) 216 | 217 | 218 | class DiscLossLS(DiscLoss): 219 | def name(self): 220 | return 'DiscLossLS' 221 | 222 | def __init__(self): 223 | super(DiscLossLS, self).__init__() 224 | self.criterionGAN = GANLoss(use_l1=True) 225 | 226 | def get_g_loss(self, net, fakeB, realB): 227 | return 
DiscLoss.get_g_loss(self, net, fakeB, realB) 228 | 229 | def get_loss(self, net, fakeB, realB): 230 | return DiscLoss.get_loss(self, net, fakeB, realB) 231 | 232 | 233 | class DiscLossWGANGP(DiscLossLS): 234 | def name(self): 235 | return 'DiscLossWGAN-GP' 236 | 237 | def __init__(self): 238 | super(DiscLossWGANGP, self).__init__() 239 | self.LAMBDA = 10 240 | 241 | def get_g_loss(self, net, fakeB, realB): 242 | # First, G(A) should fake the discriminator 243 | self.D_fake = net.forward(fakeB) 244 | return -self.D_fake.mean() 245 | 246 | def calc_gradient_penalty(self, netD, real_data, fake_data): 247 | alpha = torch.rand(1, 1) 248 | alpha = alpha.expand(real_data.size()) 249 | alpha = alpha.cuda() 250 | 251 | interpolates = alpha * real_data + ((1 - alpha) * fake_data) 252 | 253 | interpolates = interpolates.cuda() 254 | interpolates = Variable(interpolates, requires_grad=True) 255 | 256 | disc_interpolates = netD.forward(interpolates) 257 | 258 | gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates, 259 | grad_outputs=torch.ones(disc_interpolates.size()).cuda(), 260 | create_graph=True, retain_graph=True, only_inputs=True)[0] 261 | 262 | gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.LAMBDA 263 | return gradient_penalty 264 | 265 | def get_loss(self, net, fakeB, realB): 266 | self.D_fake = net.forward(fakeB.detach()) 267 | self.D_fake = self.D_fake.mean() 268 | 269 | # Real 270 | self.D_real = net.forward(realB) 271 | self.D_real = self.D_real.mean() 272 | # Combined loss 273 | self.loss_D = self.D_fake - self.D_real 274 | gradient_penalty = self.calc_gradient_penalty(net, realB.data, fakeB.data) 275 | return self.loss_D + gradient_penalty 276 | 277 | 278 | def get_loss(model): 279 | if model['content_loss'] == 'perceptual': 280 | content_loss = PerceptualLoss() 281 | content_loss.initialize(nn.MSELoss()) 282 | elif model['content_loss'] == 'l1': 283 | content_loss = ContentLoss() 284 | content_loss.initialize(nn.L1Loss()) 285 | else: 286 | raise ValueError("ContentLoss [%s] not recognized." % model['content_loss']) 287 | 288 | if model['disc_loss'] == 'wgan-gp': 289 | disc_loss = DiscLossWGANGP() 290 | elif model['disc_loss'] == 'lsgan': 291 | disc_loss = DiscLossLS() 292 | elif model['disc_loss'] == 'gan': 293 | disc_loss = DiscLoss() 294 | elif model['disc_loss'] == 'ragan': 295 | disc_loss = RelativisticDiscLoss() 296 | elif model['disc_loss'] == 'ragan-ls': 297 | disc_loss = RelativisticDiscLossLS() 298 | else: 299 | raise ValueError("GAN Loss [%s] not recognized."
% model['disc_loss']) 300 | return content_loss, disc_loss 301 | -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/models/mobilenet_v2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | 4 | 5 | def conv_bn(inp, oup, stride): 6 | return nn.Sequential( 7 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 8 | nn.BatchNorm2d(oup), 9 | nn.ReLU6(inplace=True) 10 | ) 11 | 12 | 13 | def conv_1x1_bn(inp, oup): 14 | return nn.Sequential( 15 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 16 | nn.BatchNorm2d(oup), 17 | nn.ReLU6(inplace=True) 18 | ) 19 | 20 | 21 | class InvertedResidual(nn.Module): 22 | def __init__(self, inp, oup, stride, expand_ratio): 23 | super(InvertedResidual, self).__init__() 24 | self.stride = stride 25 | assert stride in [1, 2] 26 | 27 | hidden_dim = round(inp * expand_ratio) 28 | self.use_res_connect = self.stride == 1 and inp == oup 29 | 30 | if expand_ratio == 1: 31 | self.conv = nn.Sequential( 32 | # dw 33 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 34 | nn.BatchNorm2d(hidden_dim), 35 | nn.ReLU6(inplace=True), 36 | # pw-linear 37 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 38 | nn.BatchNorm2d(oup), 39 | ) 40 | else: 41 | self.conv = nn.Sequential( 42 | # pw 43 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 44 | nn.BatchNorm2d(hidden_dim), 45 | nn.ReLU6(inplace=True), 46 | # dw 47 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 48 | nn.BatchNorm2d(hidden_dim), 49 | nn.ReLU6(inplace=True), 50 | # pw-linear 51 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 52 | nn.BatchNorm2d(oup), 53 | ) 54 | 55 | def forward(self, x): 56 | if self.use_res_connect: 57 | return x + self.conv(x) 58 | else: 59 | return self.conv(x) 60 | 61 | 62 | class MobileNetV2(nn.Module): 63 | def __init__(self, n_class=1000, input_size=224, width_mult=1.): 64 | super(MobileNetV2, self).__init__() 65 | block = InvertedResidual 66 | input_channel = 32 67 | last_channel = 1280 68 | interverted_residual_setting = [ 69 | # t, c, n, s 70 | [1, 16, 1, 1], 71 | [6, 24, 2, 2], 72 | [6, 32, 3, 2], 73 | [6, 64, 4, 2], 74 | [6, 96, 3, 1], 75 | [6, 160, 3, 2], 76 | [6, 320, 1, 1], 77 | ] 78 | 79 | # building first layer 80 | assert input_size % 32 == 0 81 | input_channel = int(input_channel * width_mult) 82 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel 83 | self.features = [conv_bn(3, input_channel, 2)] 84 | # building inverted residual blocks 85 | for t, c, n, s in interverted_residual_setting: 86 | output_channel = int(c * width_mult) 87 | for i in range(n): 88 | if i == 0: 89 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 90 | else: 91 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 92 | input_channel = output_channel 93 | # building last several layers 94 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 95 | # make it nn.Sequential 96 | self.features = nn.Sequential(*self.features) 97 | 98 | # building classifier 99 | self.classifier = nn.Sequential( 100 | nn.Dropout(0.2), 101 | nn.Linear(self.last_channel, n_class), 102 | ) 103 | 104 | self._initialize_weights() 105 | 106 | def forward(self, x): 107 | x = self.features(x) 108 | x = x.mean(3).mean(2) 109 | x = self.classifier(x) 110 | return x 111 | 112 | def _initialize_weights(self): 113 | for m in 
self.modules(): 114 | if isinstance(m, nn.Conv2d): 115 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 116 | m.weight.data.normal_(0, math.sqrt(2. / n)) 117 | if m.bias is not None: 118 | m.bias.data.zero_() 119 | elif isinstance(m, nn.BatchNorm2d): 120 | m.weight.data.fill_(1) 121 | m.bias.data.zero_() 122 | elif isinstance(m, nn.Linear): 123 | n = m.weight.size(1) 124 | m.weight.data.normal_(0, 0.01) 125 | m.bias.data.zero_() 126 | 127 | -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/models/models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | from skimage.metrics import structural_similarity as SSIM 4 | from util.metrics import PSNR 5 | 6 | 7 | class DeblurModel(nn.Module): 8 | def __init__(self): 9 | super(DeblurModel, self).__init__() 10 | 11 | def get_input(self, data): 12 | img = data['a'] 13 | inputs = img 14 | targets = data['b'] 15 | inputs, targets = inputs.cuda(), targets.cuda() 16 | return inputs, targets 17 | 18 | def tensor2im(self, image_tensor, imtype=np.uint8): 19 | image_numpy = image_tensor[0].cpu().float().numpy() 20 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 21 | return image_numpy.astype(imtype) 22 | 23 | def get_images_and_metrics(self, inp, output, target) -> (float, float, np.ndarray): 24 | inp = self.tensor2im(inp) 25 | fake = self.tensor2im(output.data) 26 | real = self.tensor2im(target.data) 27 | psnr = PSNR(fake, real) 28 | ssim = SSIM(fake, real, multichannel=True) 29 | vis_img = np.hstack((inp, fake, real)) 30 | return psnr, ssim, vis_img 31 | 32 | 33 | def get_model(model_config): 34 | return DeblurModel() 35 | -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import functools 5 | from torch.autograd import Variable 6 | import numpy as np 7 | from models.fpn_mobilenet import FPNMobileNet 8 | from models.fpn_inception import FPNInception 9 | from models.fpn_inception_simple import FPNInceptionSimple 10 | from models.unet_seresnext import UNetSEResNext 11 | from models.fpn_densenet import FPNDense 12 | ############################################################################### 13 | # Functions 14 | ############################################################################### 15 | 16 | 17 | def get_norm_layer(norm_type='instance'): 18 | if norm_type == 'batch': 19 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True) 20 | elif norm_type == 'instance': 21 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True) 22 | else: 23 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 24 | return norm_layer 25 | 26 | ############################################################################## 27 | # Classes 28 | ############################################################################## 29 | 30 | 31 | # Defines the generator that consists of Resnet blocks between a few 32 | # downsampling/upsampling operations. 33 | # Code and idea originally from Justin Johnson's architecture. 
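# Architecture summary: a reflection-padded 7x7 stem, two stride-2 downsampling
# convolutions, n_blocks ResnetBlock units, a mirrored pair of transposed
# convolutions for upsampling, and a final 7x7 convolution with Tanh. With
# learn_residual=True the network predicts a residual that is added to the
# input and clamped to [-1, 1].
# Rough usage sketch (the arguments below are illustrative defaults, not taken
# from any config in this repository):
#   netG = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, n_blocks=6,
#                          norm_layer=get_norm_layer('instance'))
#   deblurred = netG(blurred)  # blurred: (N, 3, H, W) tensor scaled to [-1, 1]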
34 | # https://github.com/jcjohnson/fast-neural-style/ 35 | class ResnetGenerator(nn.Module): 36 | def __init__(self, input_nc=3, output_nc=3, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, use_parallel=True, learn_residual=True, padding_type='reflect'): 37 | assert(n_blocks >= 0) 38 | super(ResnetGenerator, self).__init__() 39 | self.input_nc = input_nc 40 | self.output_nc = output_nc 41 | self.ngf = ngf 42 | self.use_parallel = use_parallel 43 | self.learn_residual = learn_residual 44 | if type(norm_layer) == functools.partial: 45 | use_bias = norm_layer.func == nn.InstanceNorm2d 46 | else: 47 | use_bias = norm_layer == nn.InstanceNorm2d 48 | 49 | model = [nn.ReflectionPad2d(3), 50 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, 51 | bias=use_bias), 52 | norm_layer(ngf), 53 | nn.ReLU(True)] 54 | 55 | n_downsampling = 2 56 | for i in range(n_downsampling): 57 | mult = 2**i 58 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, 59 | stride=2, padding=1, bias=use_bias), 60 | norm_layer(ngf * mult * 2), 61 | nn.ReLU(True)] 62 | 63 | mult = 2**n_downsampling 64 | for i in range(n_blocks): 65 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] 66 | 67 | for i in range(n_downsampling): 68 | mult = 2**(n_downsampling - i) 69 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), 70 | kernel_size=3, stride=2, 71 | padding=1, output_padding=1, 72 | bias=use_bias), 73 | norm_layer(int(ngf * mult / 2)), 74 | nn.ReLU(True)] 75 | model += [nn.ReflectionPad2d(3)] 76 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] 77 | model += [nn.Tanh()] 78 | 79 | self.model = nn.Sequential(*model) 80 | 81 | def forward(self, input): 82 | output = self.model(input) 83 | if self.learn_residual: 84 | output = input + output 85 | output = torch.clamp(output,min = -1,max = 1) 86 | return output 87 | 88 | 89 | # Define a resnet block 90 | class ResnetBlock(nn.Module): 91 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): 92 | super(ResnetBlock, self).__init__() 93 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) 94 | 95 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): 96 | conv_block = [] 97 | p = 0 98 | if padding_type == 'reflect': 99 | conv_block += [nn.ReflectionPad2d(1)] 100 | elif padding_type == 'replicate': 101 | conv_block += [nn.ReplicationPad2d(1)] 102 | elif padding_type == 'zero': 103 | p = 1 104 | else: 105 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 106 | 107 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 108 | norm_layer(dim), 109 | nn.ReLU(True)] 110 | if use_dropout: 111 | conv_block += [nn.Dropout(0.5)] 112 | 113 | p = 0 114 | if padding_type == 'reflect': 115 | conv_block += [nn.ReflectionPad2d(1)] 116 | elif padding_type == 'replicate': 117 | conv_block += [nn.ReplicationPad2d(1)] 118 | elif padding_type == 'zero': 119 | p = 1 120 | else: 121 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 122 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 123 | norm_layer(dim)] 124 | 125 | return nn.Sequential(*conv_block) 126 | 127 | def forward(self, x): 128 | out = x + self.conv_block(x) 129 | return out 130 | 131 | 132 | class DicsriminatorTail(nn.Module): 133 | def __init__(self, nf_mult, n_layers, ndf=64, norm_layer=nn.BatchNorm2d, 
use_parallel=True): 134 | super(DicsriminatorTail, self).__init__() 135 | self.use_parallel = use_parallel 136 | if type(norm_layer) == functools.partial: 137 | use_bias = norm_layer.func == nn.InstanceNorm2d 138 | else: 139 | use_bias = norm_layer == nn.InstanceNorm2d 140 | 141 | kw = 4 142 | padw = int(np.ceil((kw-1)/2)) 143 | 144 | nf_mult_prev = nf_mult 145 | nf_mult = min(2**n_layers, 8) 146 | sequence = [ 147 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 148 | kernel_size=kw, stride=1, padding=padw, bias=use_bias), 149 | norm_layer(ndf * nf_mult), 150 | nn.LeakyReLU(0.2, True) 151 | ] 152 | 153 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] 154 | 155 | self.model = nn.Sequential(*sequence) 156 | 157 | def forward(self, input): 158 | return self.model(input) 159 | 160 | 161 | class MultiScaleDiscriminator(nn.Module): 162 | def __init__(self, input_nc=3, ndf=64, norm_layer=nn.BatchNorm2d, use_parallel=True): 163 | super(MultiScaleDiscriminator, self).__init__() 164 | self.use_parallel = use_parallel 165 | if type(norm_layer) == functools.partial: 166 | use_bias = norm_layer.func == nn.InstanceNorm2d 167 | else: 168 | use_bias = norm_layer == nn.InstanceNorm2d 169 | 170 | kw = 4 171 | padw = int(np.ceil((kw-1)/2)) 172 | sequence = [ 173 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 174 | nn.LeakyReLU(0.2, True) 175 | ] 176 | 177 | nf_mult = 1 178 | for n in range(1, 3): 179 | nf_mult_prev = nf_mult 180 | nf_mult = min(2**n, 8) 181 | sequence += [ 182 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 183 | kernel_size=kw, stride=2, padding=padw, bias=use_bias), 184 | norm_layer(ndf * nf_mult), 185 | nn.LeakyReLU(0.2, True) 186 | ] 187 | 188 | self.scale_one = nn.Sequential(*sequence) 189 | self.first_tail = DicsriminatorTail(nf_mult=nf_mult, n_layers=3) 190 | nf_mult_prev = 4 191 | nf_mult = 8 192 | 193 | self.scale_two = nn.Sequential( 194 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 195 | kernel_size=kw, stride=2, padding=padw, bias=use_bias), 196 | norm_layer(ndf * nf_mult), 197 | nn.LeakyReLU(0.2, True)) 198 | nf_mult_prev = nf_mult 199 | self.second_tail = DicsriminatorTail(nf_mult=nf_mult, n_layers=4) 200 | self.scale_three = nn.Sequential( 201 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 202 | norm_layer(ndf * nf_mult), 203 | nn.LeakyReLU(0.2, True)) 204 | self.third_tail = DicsriminatorTail(nf_mult=nf_mult, n_layers=5) 205 | 206 | def forward(self, input): 207 | x = self.scale_one(input) 208 | x_1 = self.first_tail(x) 209 | x = self.scale_two(x) 210 | x_2 = self.second_tail(x) 211 | x = self.scale_three(x) 212 | x = self.third_tail(x) 213 | return [x_1, x_2, x] 214 | 215 | 216 | # Defines the PatchGAN discriminator with the specified arguments. 
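# The "patch" in PatchGAN: the discriminator emits a 2D map of logits, one per
# receptive-field location, rather than a single real/fake score per image.
# Rough usage sketch (shapes are for a 256x256 RGB input; the arguments are
# illustrative, not taken from any config in this repository):
#   netD = NLayerDiscriminator(n_layers=3, use_sigmoid=False,
#                              norm_layer=get_norm_layer('instance'))
#   patch_logits = netD(torch.randn(1, 3, 256, 256))  # -> shape (1, 1, 30, 30)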
217 | class NLayerDiscriminator(nn.Module): 218 | def __init__(self, input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, use_parallel=True): 219 | super(NLayerDiscriminator, self).__init__() 220 | self.use_parallel = use_parallel 221 | if type(norm_layer) == functools.partial: 222 | use_bias = norm_layer.func == nn.InstanceNorm2d 223 | else: 224 | use_bias = norm_layer == nn.InstanceNorm2d 225 | 226 | kw = 4 227 | padw = int(np.ceil((kw-1)/2)) 228 | sequence = [ 229 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 230 | nn.LeakyReLU(0.2, True) 231 | ] 232 | 233 | nf_mult = 1 234 | for n in range(1, n_layers): 235 | nf_mult_prev = nf_mult 236 | nf_mult = min(2**n, 8) 237 | sequence += [ 238 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 239 | kernel_size=kw, stride=2, padding=padw, bias=use_bias), 240 | norm_layer(ndf * nf_mult), 241 | nn.LeakyReLU(0.2, True) 242 | ] 243 | 244 | nf_mult_prev = nf_mult 245 | nf_mult = min(2**n_layers, 8) 246 | sequence += [ 247 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 248 | kernel_size=kw, stride=1, padding=padw, bias=use_bias), 249 | norm_layer(ndf * nf_mult), 250 | nn.LeakyReLU(0.2, True) 251 | ] 252 | 253 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] 254 | 255 | if use_sigmoid: 256 | sequence += [nn.Sigmoid()] 257 | 258 | self.model = nn.Sequential(*sequence) 259 | 260 | def forward(self, input): 261 | return self.model(input) 262 | 263 | 264 | def get_fullD(model_config): 265 | model_d = NLayerDiscriminator(n_layers=5, 266 | norm_layer=get_norm_layer(norm_type=model_config['norm_layer']), 267 | use_sigmoid=False) 268 | return model_d 269 | 270 | 271 | def get_generator(model_config): 272 | generator_name = model_config['g_name'] 273 | if generator_name == 'resnet': 274 | model_g = ResnetGenerator(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']), 275 | use_dropout=model_config['dropout'], 276 | n_blocks=model_config['blocks'], 277 | learn_residual=model_config['learn_residual']) 278 | elif generator_name == 'fpn_mobilenet': 279 | model_g = FPNMobileNet(norm_layer=get_norm_layer(norm_type=model_config['norm_layer'])) 280 | elif generator_name == 'fpn_inception': 281 | model_g = FPNInception(norm_layer=get_norm_layer(norm_type=model_config['norm_layer'])) 282 | elif generator_name == 'fpn_inception_simple': 283 | model_g = FPNInceptionSimple(norm_layer=get_norm_layer(norm_type=model_config['norm_layer'])) 284 | elif generator_name == 'fpn_dense': 285 | model_g = FPNDense() 286 | elif generator_name == 'unet_seresnext': 287 | model_g = UNetSEResNext(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']), 288 | pretrained=model_config['pretrained']) 289 | else: 290 | raise ValueError("Generator Network [%s] not recognized." 
% generator_name) 291 | 292 | return nn.DataParallel(model_g) 293 | 294 | 295 | def get_discriminator(model_config): 296 | discriminator_name = model_config['d_name'] 297 | if discriminator_name == 'no_gan': 298 | model_d = None 299 | elif discriminator_name == 'patch_gan': 300 | model_d = NLayerDiscriminator(n_layers=model_config['d_layers'], 301 | norm_layer=get_norm_layer(norm_type=model_config['norm_layer']), 302 | use_sigmoid=False) 303 | model_d = nn.DataParallel(model_d) 304 | elif discriminator_name == 'double_gan': 305 | patch_gan = NLayerDiscriminator(n_layers=model_config['d_layers'], 306 | norm_layer=get_norm_layer(norm_type=model_config['norm_layer']), 307 | use_sigmoid=False) 308 | patch_gan = nn.DataParallel(patch_gan) 309 | full_gan = get_fullD(model_config) 310 | full_gan = nn.DataParallel(full_gan) 311 | model_d = {'patch': patch_gan, 312 | 'full': full_gan} 313 | elif discriminator_name == 'multi_scale': 314 | model_d = MultiScaleDiscriminator(norm_layer=get_norm_layer(norm_type=model_config['norm_layer'])) 315 | model_d = nn.DataParallel(model_d) 316 | else: 317 | raise ValueError("Discriminator Network [%s] not recognized." % discriminator_name) 318 | 319 | return model_d 320 | 321 | 322 | def get_nets(model_config): 323 | return get_generator(model_config), get_discriminator(model_config) 324 | -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/models/unet_seresnext.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.parallel 4 | import torch.optim 5 | import torch.utils.data 6 | from torch.nn import Sequential 7 | from collections import OrderedDict 8 | import torchvision 9 | from torch.nn import functional as F 10 | from models.senet import se_resnext50_32x4d 11 | 12 | 13 | def conv3x3(in_, out): 14 | return nn.Conv2d(in_, out, 3, padding=1) 15 | 16 | 17 | class ConvRelu(nn.Module): 18 | def __init__(self, in_, out): 19 | super(ConvRelu, self).__init__() 20 | self.conv = conv3x3(in_, out) 21 | self.activation = nn.ReLU(inplace=True) 22 | 23 | def forward(self, x): 24 | x = self.conv(x) 25 | x = self.activation(x) 26 | return x 27 | 28 | class UNetSEResNext(nn.Module): 29 | 30 | def __init__(self, num_classes=3, num_filters=32, 31 | pretrained=True, is_deconv=True): 32 | super().__init__() 33 | self.num_classes = num_classes 34 | pretrain = 'imagenet' if pretrained is True else None 35 | self.encoder = se_resnext50_32x4d(num_classes=1000, pretrained=pretrain) 36 | bottom_channel_nr = 2048 37 | 38 | self.conv1 = self.encoder.layer0 39 | #self.se_e1 = SCSEBlock(64) 40 | self.conv2 = self.encoder.layer1 41 | #self.se_e2 = SCSEBlock(64 * 4) 42 | self.conv3 = self.encoder.layer2 43 | #self.se_e3 = SCSEBlock(128 * 4) 44 | self.conv4 = self.encoder.layer3 45 | #self.se_e4 = SCSEBlock(256 * 4) 46 | self.conv5 = self.encoder.layer4 47 | #self.se_e5 = SCSEBlock(512 * 4) 48 | 49 | self.center = DecoderCenter(bottom_channel_nr, num_filters * 8 *2, num_filters * 8, False) 50 | 51 | self.dec5 = DecoderBlockV(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 2, is_deconv) 52 | #self.se_d5 = SCSEBlock(num_filters * 2) 53 | self.dec4 = DecoderBlockV(bottom_channel_nr // 2 + num_filters * 2, num_filters * 8, num_filters * 2, is_deconv) 54 | #self.se_d4 = SCSEBlock(num_filters * 2) 55 | self.dec3 = DecoderBlockV(bottom_channel_nr // 4 + num_filters * 2, num_filters * 4, num_filters * 2, is_deconv) 56 | 
#self.se_d3 = SCSEBlock(num_filters * 2) 57 | self.dec2 = DecoderBlockV(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2, num_filters * 2, is_deconv) 58 | #self.se_d2 = SCSEBlock(num_filters * 2) 59 | self.dec1 = DecoderBlockV(num_filters * 2, num_filters, num_filters * 2, is_deconv) 60 | #self.se_d1 = SCSEBlock(num_filters * 2) 61 | self.dec0 = ConvRelu(num_filters * 10, num_filters * 2) 62 | self.final = nn.Conv2d(num_filters * 2, num_classes, kernel_size=1) 63 | 64 | def forward(self, x): 65 | conv1 = self.conv1(x) 66 | #conv1 = self.se_e1(conv1) 67 | conv2 = self.conv2(conv1) 68 | #conv2 = self.se_e2(conv2) 69 | conv3 = self.conv3(conv2) 70 | #conv3 = self.se_e3(conv3) 71 | conv4 = self.conv4(conv3) 72 | #conv4 = self.se_e4(conv4) 73 | conv5 = self.conv5(conv4) 74 | #conv5 = self.se_e5(conv5) 75 | 76 | center = self.center(conv5) 77 | dec5 = self.dec5(torch.cat([center, conv5], 1)) 78 | #dec5 = self.se_d5(dec5) 79 | dec4 = self.dec4(torch.cat([dec5, conv4], 1)) 80 | #dec4 = self.se_d4(dec4) 81 | dec3 = self.dec3(torch.cat([dec4, conv3], 1)) 82 | #dec3 = self.se_d3(dec3) 83 | dec2 = self.dec2(torch.cat([dec3, conv2], 1)) 84 | #dec2 = self.se_d2(dec2) 85 | dec1 = self.dec1(dec2) 86 | #dec1 = self.se_d1(dec1) 87 | 88 | f = torch.cat(( 89 | dec1, 90 | F.upsample(dec2, scale_factor=2, mode='bilinear', align_corners=False), 91 | F.upsample(dec3, scale_factor=4, mode='bilinear', align_corners=False), 92 | F.upsample(dec4, scale_factor=8, mode='bilinear', align_corners=False), 93 | F.upsample(dec5, scale_factor=16, mode='bilinear', align_corners=False), 94 | ), 1) 95 | 96 | dec0 = self.dec0(f) 97 | 98 | return self.final(dec0) 99 | 100 | class DecoderBlockV(nn.Module): 101 | def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True): 102 | super(DecoderBlockV, self).__init__() 103 | self.in_channels = in_channels 104 | 105 | if is_deconv: 106 | self.block = nn.Sequential( 107 | ConvRelu(in_channels, middle_channels), 108 | nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2, 109 | padding=1), 110 | nn.InstanceNorm2d(out_channels, affine=False), 111 | nn.ReLU(inplace=True) 112 | 113 | ) 114 | else: 115 | self.block = nn.Sequential( 116 | nn.Upsample(scale_factor=2, mode='bilinear'), 117 | ConvRelu(in_channels, middle_channels), 118 | ConvRelu(middle_channels, out_channels), 119 | ) 120 | 121 | def forward(self, x): 122 | return self.block(x) 123 | 124 | 125 | 126 | class DecoderCenter(nn.Module): 127 | def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True): 128 | super(DecoderCenter, self).__init__() 129 | self.in_channels = in_channels 130 | 131 | 132 | if is_deconv: 133 | """ 134 | Paramaters for Deconvolution were chosen to avoid artifacts, following 135 | link https://distill.pub/2016/deconv-checkerboard/ 136 | """ 137 | 138 | self.block = nn.Sequential( 139 | ConvRelu(in_channels, middle_channels), 140 | nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2, 141 | padding=1), 142 | nn.InstanceNorm2d(out_channels, affine=False), 143 | nn.ReLU(inplace=True) 144 | ) 145 | else: 146 | self.block = nn.Sequential( 147 | ConvRelu(in_channels, middle_channels), 148 | ConvRelu(middle_channels, out_channels) 149 | 150 | ) 151 | 152 | def forward(self, x): 153 | return self.block(x) 154 | -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/picture_to_video.py: -------------------------------------------------------------------------------- 1 
| import os 2 | import cv2 3 | 4 | # # image path 5 | # im_dir = 'D:\github repo\DeblurGANv2\submit' 6 | # # output video path 7 | # video_dir = 'D:\github repo\DeblurGANv2' 8 | # if not os.path.exists(video_dir): 9 | # os.makedirs(video_dir) 10 | # # set saved fps 11 | # fps = 20 12 | # # get frames list 13 | # frames = sorted(os.listdir(im_dir)) 14 | # # w,h of image 15 | # img = cv2.imread(os.path.join(im_dir, frames[0])) 16 | # # 17 | # img_size = (img.shape[1], img.shape[0]) 18 | # # get seq name 19 | # seq_name = os.path.dirname(im_dir).split('/')[-1] 20 | # # splice video_dir 21 | # video_dir = os.path.join(video_dir, seq_name + '.avi') 22 | # fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G') 23 | # # also can write like:fourcc = cv2.VideoWriter_fourcc(*'MJPG') 24 | # # if want to write .mp4 file, use 'MP4V' 25 | # videowriter = cv2.VideoWriter(video_dir, fourcc, fps, img_size) 26 | # 27 | # for frame in frames: 28 | # f_path = os.path.join(im_dir, frame) 29 | # image = cv2.imread(f_path) 30 | # videowriter.write(image) 31 | # print(frame + " has been written!") 32 | # 33 | # videowriter.release() 34 | 35 | 36 | # def main(): 37 | # data_path = 'D:\github repo\DeblurGANv2\submit' 38 | # fps = 30 # video frame rate 39 | # size = (1280, 720) # size of the images to be turned into a video 40 | # video = cv2.VideoWriter("output.avi", cv2.VideoWriter_fourcc('I', '4', '2', '0'), fps, size) 41 | # 42 | # for i in range(1029): 43 | # image_path = data_path + "%010d_color_labels.png" % (i + 1) 44 | # print(image_path) 45 | # img = cv2.imread(image_path) 46 | # video.write(img) 47 | # video.release() 48 | # 49 | # 50 | # 51 | # 52 | # cv2.destroyAllWindows() 53 | # 54 | # 55 | # if __name__ == "__main__": 56 | # main() 57 | 58 | 59 | import cv2 60 | import os 61 | import numpy as np 62 | from PIL import Image 63 | 64 | 65 | def frame2video(im_dir, video_dir, fps): 66 | im_list = os.listdir(im_dir) 67 | im_list.sort(key=lambda x: int(x.replace("frame", "").split('.')[0])) # better double-check that the frames sort into the right order 68 | img = Image.open(os.path.join(im_dir, im_list[0])) 69 | img_size = img.size # image resolution; all images under im_dir must share the same resolution 70 | 71 | # fourcc = cv2.cv.CV_FOURCC('M','J','P','G') # for OpenCV 2 72 | fourcc = cv2.VideoWriter_fourcc(*'XVID') # for OpenCV 3 73 | videoWriter = cv2.VideoWriter(video_dir, fourcc, fps, img_size) 74 | # count = 1 75 | for i in im_list: 76 | im_name = os.path.join(im_dir + i) 77 | frame = cv2.imdecode(np.fromfile(im_name, dtype=np.uint8), -1) 78 | videoWriter.write(frame) 79 | # count+=1 80 | # if (count == 200): 81 | # print(im_name) 82 | # break 83 | videoWriter.release() 84 | print('finish') 85 | 86 | 87 | if __name__ == '__main__': 88 | # im_dir = 'D:\github repo\DeblurGANv2\submit\\' # directory where the frames are stored 89 | im_dir = 'D:\github repo\DeblurGANv2\dataset1\\blur\\' # directory where the frames are stored 90 | video_dir = 'D:\github repo\DeblurGANv2/test.avi' # path where the assembled video is saved 91 | fps = 30 # frame rate; more frames per second gives smoother motion 92 | frame2video(im_dir, video_dir, fps) -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from typing import Optional 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | import yaml 9 | from fire import Fire 10 | from tqdm import tqdm 11 | 12 | from aug import get_normalize 13 | from models.networks import get_generator 14 | 15 | 16 | class Predictor: 17 | def __init__(self, weights_path: str, model_name: str = ''): 18 | with
open('config/config.yaml',encoding='utf-8') as cfg: 19 | config = yaml.load(cfg, Loader=yaml.FullLoader) 20 | model = get_generator(model_name or config['model']) 21 | model.load_state_dict(torch.load(weights_path)['model']) 22 | self.model = model.cuda() 23 | self.model.train(True) 24 | # GAN inference should be in train mode to use actual stats in norm layers, 25 | # it's not a bug 26 | self.normalize_fn = get_normalize() 27 | 28 | @staticmethod 29 | def _array_to_batch(x): 30 | x = np.transpose(x, (2, 0, 1)) 31 | x = np.expand_dims(x, 0) 32 | return torch.from_numpy(x) 33 | 34 | def _preprocess(self, x: np.ndarray, mask: Optional[np.ndarray]): 35 | x, _ = self.normalize_fn(x, x) 36 | if mask is None: 37 | mask = np.ones_like(x, dtype=np.float32) 38 | else: 39 | mask = np.round(mask.astype('float32') / 255) 40 | 41 | h, w, _ = x.shape 42 | block_size = 32 43 | min_height = (h // block_size + 1) * block_size 44 | min_width = (w // block_size + 1) * block_size 45 | 46 | pad_params = {'mode': 'constant', 47 | 'constant_values': 0, 48 | 'pad_width': ((0, min_height - h), (0, min_width - w), (0, 0)) 49 | } 50 | x = np.pad(x, **pad_params) 51 | mask = np.pad(mask, **pad_params) 52 | 53 | return map(self._array_to_batch, (x, mask)), h, w 54 | 55 | @staticmethod 56 | def _postprocess(x: torch.Tensor) -> np.ndarray: 57 | x, = x 58 | x = x.detach().cpu().float().numpy() 59 | x = (np.transpose(x, (1, 2, 0)) + 1) / 2.0 * 255.0 60 | return x.astype('uint8') 61 | 62 | def __call__(self, img: np.ndarray, mask: Optional[np.ndarray], ignore_mask=True) -> np.ndarray: 63 | (img, mask), h, w = self._preprocess(img, mask) 64 | with torch.no_grad(): 65 | inputs = [img.cuda()] 66 | if not ignore_mask: 67 | inputs += [mask] 68 | pred = self.model(*inputs) 69 | return self._postprocess(pred)[:h, :w, :] 70 | 71 | def process_video(pairs, predictor, output_dir): 72 | for video_filepath, mask in tqdm(pairs): 73 | video_filename = os.path.basename(video_filepath) 74 | output_filepath = os.path.join(output_dir, os.path.splitext(video_filename)[0]+'_deblur.mp4') 75 | video_in = cv2.VideoCapture(video_filepath) 76 | fps = video_in.get(cv2.CAP_PROP_FPS) 77 | width = int(video_in.get(cv2.CAP_PROP_FRAME_WIDTH)) 78 | height = int(video_in.get(cv2.CAP_PROP_FRAME_HEIGHT)) 79 | total_frame_num = int(video_in.get(cv2.CAP_PROP_FRAME_COUNT)) 80 | video_out = cv2.VideoWriter(output_filepath, cv2.VideoWriter_fourcc(*'MP4V'), fps, (width, height)) 81 | tqdm.write(f'process {video_filepath} to {output_filepath}, {fps}fps, resolution: {width}x{height}') 82 | for frame_num in tqdm(range(total_frame_num), desc=video_filename): 83 | res, img = video_in.read() 84 | if not res: 85 | break 86 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 87 | pred = predictor(img, mask) 88 | pred = cv2.cvtColor(pred, cv2.COLOR_RGB2BGR) 89 | video_out.write(pred) 90 | 91 | def main(img_pattern: str, 92 | mask_pattern: Optional[str] = None, 93 | weights_path='fpn_inception.h5', 94 | out_dir='submit/', 95 | side_by_side: bool = False, 96 | video: bool = False): 97 | def sorted_glob(pattern): 98 | return sorted(glob(pattern)) 99 | 100 | imgs = sorted_glob(img_pattern) 101 | masks = sorted_glob(mask_pattern) if mask_pattern is not None else [None for _ in imgs] 102 | pairs = zip(imgs, masks) 103 | names = sorted([os.path.basename(x) for x in glob(img_pattern)]) 104 | predictor = Predictor(weights_path=weights_path) 105 | 106 | os.makedirs(out_dir, exist_ok=True) 107 | if not video: 108 | for name, pair in tqdm(zip(names, pairs), total=len(names)): 109 | f_img, 
f_mask = pair 110 | img, mask = map(cv2.imread, (f_img, f_mask)) 111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 112 | 113 | pred = predictor(img, mask) 114 | if side_by_side: 115 | pred = np.hstack((img, pred)) 116 | pred = cv2.cvtColor(pred, cv2.COLOR_RGB2BGR) 117 | cv2.imwrite(os.path.join(out_dir, name), 118 | pred) 119 | else: 120 | process_video(pairs, predictor, out_dir) 121 | 122 | # def getfiles(): 123 | # filenames = os.listdir(r'.\dataset1\blur') 124 | # print(filenames) 125 | def get_files(): 126 | list=[] 127 | for filepath,dirnames,filenames in os.walk(r'.\dataset1\blur'): 128 | for filename in filenames: 129 | list.append(os.path.join(filepath,filename)) 130 | return list 131 | 132 | 133 | 134 | 135 | 136 | if __name__ == '__main__': 137 | # Fire(main) 138 | #增加批量处理图片: 139 | img_path=get_files() 140 | for i in img_path: 141 | main(i) 142 | # main('test_img/tt.mp4') 143 | -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.0.1 2 | torchvision 3 | torchsummary 4 | pretrainedmodels 5 | numpy 6 | opencv-python-headless 7 | joblib 8 | albumentations>=1.0.0 9 | scikit-image==0.18.1 10 | tqdm 11 | glog 12 | tensorboardx 13 | fire 14 | # this file is not ready yet 15 | -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/schedulers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from torch.optim import lr_scheduler 4 | 5 | 6 | class WarmRestart(lr_scheduler.CosineAnnealingLR): 7 | """This class implements Stochastic Gradient Descent with Warm Restarts(SGDR): https://arxiv.org/abs/1608.03983. 8 | 9 | Set the learning rate of each parameter group using a cosine annealing schedule, When last_epoch=-1, sets initial lr as lr. 10 | This can't support scheduler.step(epoch). please keep epoch=None. 11 | """ 12 | 13 | def __init__(self, optimizer, T_max=30, T_mult=1, eta_min=0, last_epoch=-1): 14 | """implements SGDR 15 | 16 | Parameters: 17 | ---------- 18 | T_max : int 19 | Maximum number of epochs. 20 | T_mult : int 21 | Multiplicative factor of T_max. 22 | eta_min : int 23 | Minimum learning rate. Default: 0. 24 | last_epoch : int 25 | The index of last epoch. Default: -1. 
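        Example
        -------
        Illustrative only (the optimizer and the train_one_epoch call are
        placeholders, not part of this repository):

            scheduler = WarmRestart(optimizer, T_max=30, T_mult=1, eta_min=0)
            for epoch in range(90):
                train_one_epoch(...)
                scheduler.step()  # keep epoch=None, as noted in the class docstring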
26 | """ 27 | self.T_mult = T_mult 28 | super().__init__(optimizer, T_max, eta_min, last_epoch) 29 | 30 | def get_lr(self): 31 | if self.last_epoch == self.T_max: 32 | self.last_epoch = 0 33 | self.T_max *= self.T_mult 34 | return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2 for 35 | base_lr in self.base_lrs] 36 | 37 | 38 | class LinearDecay(lr_scheduler._LRScheduler): 39 | """This class implements LinearDecay 40 | 41 | """ 42 | 43 | def __init__(self, optimizer, num_epochs, start_epoch=0, min_lr=0, last_epoch=-1): 44 | """implements LinearDecay 45 | 46 | Parameters: 47 | ---------- 48 | 49 | """ 50 | self.num_epochs = num_epochs 51 | self.start_epoch = start_epoch 52 | self.min_lr = min_lr 53 | super().__init__(optimizer, last_epoch) 54 | 55 | def get_lr(self): 56 | if self.last_epoch < self.start_epoch: 57 | return self.base_lrs 58 | return [base_lr - ((base_lr - self.min_lr) / self.num_epochs) * (self.last_epoch - self.start_epoch) for 59 | base_lr in self.base_lrs] 60 | -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | print() -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python3 -m unittest discover $(pwd) 4 | -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/test_aug.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | from aug import get_transforms 6 | 7 | 8 | class AugTest(unittest.TestCase): 9 | @staticmethod 10 | def make_images(): 11 | img = (np.random.rand(100, 100, 3) * 255).astype('uint8') 12 | return img.copy(), img.copy() 13 | 14 | def test_aug(self): 15 | for scope in ('strong', 'weak'): 16 | for crop in ('random', 'center'): 17 | aug_pipeline = get_transforms(80, scope=scope, crop=crop) 18 | a, b = self.make_images() 19 | a, b = aug_pipeline(a, b) 20 | np.testing.assert_allclose(a, b) 21 | -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/test_batchsize.py: -------------------------------------------------------------------------------- 1 | # import os 2 | # def get_files(): 3 | # list=[] 4 | # for filepath,dirnames,filenames in os.walk(r'.\dataset1\blur'): 5 | # for filename in filenames: 6 | # list.append(os.path.join(filepath,filename)) 7 | # return list 8 | # 9 | # a=get_files() 10 | # print(len(a)) 11 | # 12 | # # for i in a: 13 | # # print(i) 14 | import cv2 15 | print(cv2.__version__) -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/test_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | from shutil import rmtree 4 | from tempfile import mkdtemp 5 | 6 | import cv2 7 | import numpy as np 8 | from torch.utils.data import DataLoader 9 | 10 | from dataset import PairedDataset 11 | 12 | 13 | def make_img(): 14 | return (np.random.rand(100, 100, 3) * 255).astype('uint8') 15 | 16 | 17 | class AugTest(unittest.TestCase): 18 | tmp_dir = mkdtemp() 19 | raw = os.path.join(tmp_dir, 
'raw') 20 | gt = os.path.join(tmp_dir, 'gt') 21 | 22 | def setUp(self): 23 | for d in (self.raw, self.gt): 24 | os.makedirs(d) 25 | 26 | for i in range(5): 27 | for d in (self.raw, self.gt): 28 | img = make_img() 29 | cv2.imwrite(os.path.join(d, f'{i}.png'), img) 30 | 31 | def tearDown(self): 32 | rmtree(self.tmp_dir) 33 | 34 | def dataset_gen(self, equal=True): 35 | base_config = {'files_a': os.path.join(self.raw, '*.png'), 36 | 'files_b': os.path.join(self.raw if equal else self.gt, '*.png'), 37 | 'size': 32, 38 | } 39 | for b in ([0, 1], [0, 0.9]): 40 | for scope in ('strong', 'weak'): 41 | for crop in ('random', 'center'): 42 | for preload in (0, 1): 43 | for preload_size in (0, 64): 44 | config = base_config.copy() 45 | config['bounds'] = b 46 | config['scope'] = scope 47 | config['crop'] = crop 48 | config['preload'] = preload 49 | config['preload_size'] = preload_size 50 | config['verbose'] = False 51 | dataset = PairedDataset.from_config(config) 52 | yield dataset 53 | 54 | def test_equal_datasets(self): 55 | for dataset in self.dataset_gen(equal=True): 56 | dataloader = DataLoader(dataset=dataset, 57 | batch_size=2, 58 | shuffle=True, 59 | drop_last=True) 60 | dataloader = iter(dataloader) 61 | batch = next(dataloader) 62 | a, b = map(lambda x: x.numpy(), map(batch.get, ('a', 'b'))) 63 | 64 | np.testing.assert_allclose(a, b) 65 | 66 | def test_datasets(self): 67 | for dataset in self.dataset_gen(equal=False): 68 | dataloader = DataLoader(dataset=dataset, 69 | batch_size=2, 70 | shuffle=True, 71 | drop_last=True) 72 | dataloader = iter(dataloader) 73 | batch = next(dataloader) 74 | a, b = map(lambda x: x.numpy(), map(batch.get, ('a', 'b'))) 75 | 76 | assert not np.all(a == b), 'images should not be the same' 77 | -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/test_metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import numpy as np 4 | import torch 5 | import cv2 6 | import yaml 7 | import os 8 | from torchvision import models, transforms 9 | from torch.autograd import Variable 10 | import shutil 11 | import glob 12 | import tqdm 13 | from util.metrics import PSNR 14 | from albumentations import Compose, CenterCrop, PadIfNeeded 15 | from PIL import Image 16 | from ssim.ssimlib import SSIM 17 | from models.networks import get_generator 18 | 19 | 20 | def get_args(): 21 | parser = argparse.ArgumentParser('Test an image') 22 | parser.add_argument('--img_folder', required=True, help='GoPRO Folder') 23 | parser.add_argument('--weights_path', required=True, help='Weights path') 24 | 25 | return parser.parse_args() 26 | 27 | 28 | def prepare_dirs(path): 29 | if os.path.exists(path): 30 | shutil.rmtree(path) 31 | os.makedirs(path) 32 | 33 | 34 | def get_gt_image(path): 35 | dir, filename = os.path.split(path) 36 | base, seq = os.path.split(dir) 37 | base, _ = os.path.split(base) 38 | img = cv2.cvtColor(cv2.imread(os.path.join(base, 'sharp', seq, filename)), cv2.COLOR_BGR2RGB) 39 | return img 40 | 41 | 42 | def test_image(model, image_path): 43 | img_transforms = transforms.Compose([ 44 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 45 | ]) 46 | size_transform = Compose([ 47 | PadIfNeeded(736, 1280) 48 | ]) 49 | crop = CenterCrop(720, 1280) 50 | img = cv2.imread(image_path) 51 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 52 | img_s = size_transform(image=img)['image'] 53 | img_tensor = 
torch.from_numpy(np.transpose(img_s / 255, (2, 0, 1)).astype('float32')) 54 | img_tensor = img_transforms(img_tensor) 55 | with torch.no_grad(): 56 | img_tensor = Variable(img_tensor.unsqueeze(0).cuda()) 57 | result_image = model(img_tensor) 58 | result_image = result_image[0].cpu().float().numpy() 59 | result_image = (np.transpose(result_image, (1, 2, 0)) + 1) / 2.0 * 255.0 60 | result_image = crop(image=result_image)['image'] 61 | result_image = result_image.astype('uint8') 62 | gt_image = get_gt_image(image_path) 63 | _, filename = os.path.split(image_path) 64 | psnr = PSNR(result_image, gt_image) 65 | pilFake = Image.fromarray(result_image) 66 | pilReal = Image.fromarray(gt_image) 67 | ssim = SSIM(pilFake).cw_ssim_value(pilReal) 68 | return psnr, ssim 69 | 70 | 71 | def test(model, files): 72 | psnr = 0 73 | ssim = 0 74 | for file in tqdm.tqdm(files): 75 | cur_psnr, cur_ssim = test_image(model, file) 76 | psnr += cur_psnr 77 | ssim += cur_ssim 78 | print("PSNR = {}".format(psnr / len(files))) 79 | print("SSIM = {}".format(ssim / len(files))) 80 | 81 | 82 | if __name__ == '__main__': 83 | args = get_args() 84 | with open('config/config.yaml') as cfg: 85 | config = yaml.load(cfg) 86 | model = get_generator(config['model']) 87 | model.load_state_dict(torch.load(args.weights_path)['model']) 88 | model = model.cuda() 89 | filenames = sorted(glob.glob(args.img_folder + '/test' + '/blur/**/*.png', recursive=True)) 90 | test(model, filenames) 91 | -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from functools import partial 4 | 5 | import cv2 6 | import torch 7 | import torch.optim as optim 8 | import tqdm 9 | import yaml 10 | from joblib import cpu_count 11 | from torch.utils.data import DataLoader 12 | 13 | from adversarial_trainer import GANFactory 14 | from dataset import PairedDataset 15 | from metric_counter import MetricCounter 16 | from models.losses import get_loss 17 | from models.models import get_model 18 | from models.networks import get_nets 19 | from schedulers import LinearDecay, WarmRestart 20 | from fire import Fire 21 | 22 | cv2.setNumThreads(0) 23 | 24 | 25 | class Trainer: 26 | def __init__(self, config, train: DataLoader, val: DataLoader): 27 | self.config = config 28 | self.train_dataset = train 29 | self.val_dataset = val 30 | self.adv_lambda = config['model']['adv_lambda'] 31 | self.metric_counter = MetricCounter(config['experiment_desc']) 32 | self.warmup_epochs = config['warmup_num'] 33 | 34 | def train(self): 35 | self._init_params() 36 | for epoch in range(0, self.config['num_epochs']): 37 | if (epoch == self.warmup_epochs) and not (self.warmup_epochs == 0): 38 | self.netG.module.unfreeze() 39 | self.optimizer_G = self._get_optim(self.netG.parameters()) 40 | self.scheduler_G = self._get_scheduler(self.optimizer_G) 41 | self._run_epoch(epoch) 42 | self._validate(epoch) 43 | self.scheduler_G.step() 44 | self.scheduler_D.step() 45 | 46 | if self.metric_counter.update_best_model(): 47 | torch.save({ 48 | 'model': self.netG.state_dict() 49 | }, 'best_{}.h5'.format(self.config['experiment_desc'])) 50 | torch.save({ 51 | 'model': self.netG.state_dict() 52 | }, 'last_{}.h5'.format(self.config['experiment_desc'])) 53 | print(self.metric_counter.loss_message()) 54 | logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s" % ( 55 | self.config['experiment_desc'], epoch, 
self.metric_counter.loss_message())) 56 | 57 | def _run_epoch(self, epoch): 58 | self.metric_counter.clear() 59 | for param_group in self.optimizer_G.param_groups: 60 | lr = param_group['lr'] 61 | 62 | epoch_size = self.config.get('train_batches_per_epoch') or len(self.train_dataset) 63 | tq = tqdm.tqdm(self.train_dataset, total=epoch_size) 64 | tq.set_description('Epoch {}, lr {}'.format(epoch, lr)) 65 | i = 0 66 | for data in tq: 67 | inputs, targets = self.model.get_input(data) 68 | outputs = self.netG(inputs) 69 | loss_D = self._update_d(outputs, targets) 70 | self.optimizer_G.zero_grad() 71 | loss_content = self.criterionG(outputs, targets) 72 | loss_adv = self.adv_trainer.loss_g(outputs, targets) 73 | loss_G = loss_content + self.adv_lambda * loss_adv 74 | loss_G.backward() 75 | self.optimizer_G.step() 76 | self.metric_counter.add_losses(loss_G.item(), loss_content.item(), loss_D) 77 | curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets) 78 | self.metric_counter.add_metrics(curr_psnr, curr_ssim) 79 | tq.set_postfix(loss=self.metric_counter.loss_message()) 80 | if not i: 81 | self.metric_counter.add_image(img_for_vis, tag='train') 82 | i += 1 83 | if i > epoch_size: 84 | break 85 | tq.close() 86 | self.metric_counter.write_to_tensorboard(epoch) 87 | 88 | def _validate(self, epoch): 89 | self.metric_counter.clear() 90 | epoch_size = self.config.get('val_batches_per_epoch') or len(self.val_dataset) 91 | tq = tqdm.tqdm(self.val_dataset, total=epoch_size) 92 | tq.set_description('Validation') 93 | i = 0 94 | for data in tq: 95 | inputs, targets = self.model.get_input(data) 96 | with torch.no_grad(): 97 | outputs = self.netG(inputs) 98 | loss_content = self.criterionG(outputs, targets) 99 | loss_adv = self.adv_trainer.loss_g(outputs, targets) 100 | loss_G = loss_content + self.adv_lambda * loss_adv 101 | self.metric_counter.add_losses(loss_G.item(), loss_content.item()) 102 | curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets) 103 | self.metric_counter.add_metrics(curr_psnr, curr_ssim) 104 | if not i: 105 | self.metric_counter.add_image(img_for_vis, tag='val') 106 | i += 1 107 | if i > epoch_size: 108 | break 109 | tq.close() 110 | self.metric_counter.write_to_tensorboard(epoch, validation=True) 111 | 112 | def _update_d(self, outputs, targets): 113 | if self.config['model']['d_name'] == 'no_gan': 114 | return 0 115 | self.optimizer_D.zero_grad() 116 | loss_D = self.adv_lambda * self.adv_trainer.loss_d(outputs, targets) 117 | loss_D.backward(retain_graph=True) 118 | self.optimizer_D.step() 119 | return loss_D.item() 120 | 121 | def _get_optim(self, params): 122 | if self.config['optimizer']['name'] == 'adam': 123 | optimizer = optim.Adam(params, lr=self.config['optimizer']['lr']) 124 | elif self.config['optimizer']['name'] == 'sgd': 125 | optimizer = optim.SGD(params, lr=self.config['optimizer']['lr']) 126 | elif self.config['optimizer']['name'] == 'adadelta': 127 | optimizer = optim.Adadelta(params, lr=self.config['optimizer']['lr']) 128 | else: 129 | raise ValueError("Optimizer [%s] not recognized." 
% self.config['optimizer']['name']) 130 | return optimizer 131 | 132 | def _get_scheduler(self, optimizer): 133 | if self.config['scheduler']['name'] == 'plateau': 134 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 135 | mode='min', 136 | patience=self.config['scheduler']['patience'], 137 | factor=self.config['scheduler']['factor'], 138 | min_lr=self.config['scheduler']['min_lr']) 139 | elif self.config['optimizer']['name'] == 'sgdr': 140 | scheduler = WarmRestart(optimizer) 141 | elif self.config['scheduler']['name'] == 'linear': 142 | scheduler = LinearDecay(optimizer, 143 | min_lr=self.config['scheduler']['min_lr'], 144 | num_epochs=self.config['num_epochs'], 145 | start_epoch=self.config['scheduler']['start_epoch']) 146 | else: 147 | raise ValueError("Scheduler [%s] not recognized." % self.config['scheduler']['name']) 148 | return scheduler 149 | 150 | @staticmethod 151 | def _get_adversarial_trainer(d_name, net_d, criterion_d): 152 | if d_name == 'no_gan': 153 | return GANFactory.create_model('NoGAN') 154 | elif d_name == 'patch_gan' or d_name == 'multi_scale': 155 | return GANFactory.create_model('SingleGAN', net_d, criterion_d) 156 | elif d_name == 'double_gan': 157 | return GANFactory.create_model('DoubleGAN', net_d, criterion_d) 158 | else: 159 | raise ValueError("Discriminator Network [%s] not recognized." % d_name) 160 | 161 | def _init_params(self): 162 | self.criterionG, criterionD = get_loss(self.config['model']) 163 | self.netG, netD = get_nets(self.config['model']) 164 | self.netG.cuda() 165 | self.adv_trainer = self._get_adversarial_trainer(self.config['model']['d_name'], netD, criterionD) 166 | self.model = get_model(self.config['model']) 167 | self.optimizer_G = self._get_optim(filter(lambda p: p.requires_grad, self.netG.parameters())) 168 | self.optimizer_D = self._get_optim(self.adv_trainer.get_params()) 169 | self.scheduler_G = self._get_scheduler(self.optimizer_G) 170 | self.scheduler_D = self._get_scheduler(self.optimizer_D) 171 | 172 | 173 | def main(config_path='config/config.yaml'): 174 | with open(config_path, 'r',encoding='utf-8') as f: 175 | config = yaml.load(f, Loader=yaml.SafeLoader) 176 | 177 | batch_size = config.pop('batch_size') 178 | # get_dataloader = partial(DataLoader, 179 | # batch_size=batch_size, 180 | # num_workers=0 if os.environ.get('DEBUG') else cpu_count(), 181 | # shuffle=True, drop_last=True) 182 | get_dataloader = partial(DataLoader, 183 | batch_size=batch_size, 184 | shuffle=True, drop_last=True) 185 | 186 | datasets = map(config.pop, ('train', 'val')) 187 | datasets = map(PairedDataset.from_config, datasets) 188 | train, val = map(get_dataloader, datasets) 189 | trainer = Trainer(config, train=train, val=val) 190 | trainer.train() 191 | 192 | 193 | if __name__ == '__main__': 194 | Fire(main) 195 | -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/train_end2end.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from functools import partial 3 | 4 | import cv2 5 | import torch 6 | import torch.optim as optim 7 | import tqdm 8 | import yaml 9 | from torch.utils.data import DataLoader 10 | 11 | from adversarial_trainer import GANFactory 12 | from dataset import PairedDataset 13 | from metric_counter import MetricCounter 14 | from models.losses import get_loss 15 | from models.models import get_model 16 | from models.networks import get_nets 17 | from schedulers import LinearDecay, WarmRestart 18 | from 
fire import Fire 19 | 20 | cv2.setNumThreads(0) 21 | 22 | 23 | class Trainer: 24 | def __init__(self, config, train: DataLoader, val: DataLoader): 25 | self.config = config 26 | self.train_dataset = train 27 | self.val_dataset = val 28 | self.adv_lambda = config['model']['adv_lambda'] 29 | self.metric_counter = MetricCounter(config['experiment_desc']) 30 | self.warmup_epochs = config['warmup_num'] 31 | self._init_params() 32 | 33 | def train(self): 34 | self._init_params() 35 | for epoch in range(0, self.config['num_epochs']): 36 | if (epoch == self.warmup_epochs) and not (self.warmup_epochs == 0): 37 | self.netG.module.unfreeze() 38 | self.optimizer_G = self._get_optim(self.netG.parameters()) 39 | self.scheduler_G = self._get_scheduler(self.optimizer_G) 40 | self._run_epoch(epoch) 41 | self._validate(epoch) 42 | self.scheduler_G.step() 43 | self.scheduler_D.step() 44 | 45 | if self.metric_counter.update_best_model(): 46 | torch.save({ 47 | 'model': self.netG.state_dict() 48 | }, 'best_{}.h5'.format(self.config['experiment_desc'])) 49 | torch.save({ 50 | 'model': self.netG.state_dict() 51 | }, 'last_{}.h5'.format(self.config['experiment_desc'])) 52 | print(self.metric_counter.loss_message()) 53 | logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s" % ( 54 | self.config['experiment_desc'], epoch, self.metric_counter.loss_message())) 55 | 56 | def prepare(self, epoch=0): 57 | if (epoch == self.warmup_epochs) and not (self.warmup_epochs == 0): 58 | self.netG.module.unfreeze() 59 | self.optimizer_G = self._get_optim(self.netG.parameters()) 60 | self.scheduler_G = self._get_scheduler(self.optimizer_G) 61 | 62 | def run(self, inputs, targets, is_inference=False, num_iters=1, desc=None): 63 | """ 64 | Runs the network given an input blur image and a ground truth image. 65 | 66 | Customized for end2end training. 67 | 68 | Args: 69 | inputs: Input image(s). 70 | targets: Ground truth image(s). 71 | is_inference: Is inference. 72 | num_iters: Number of iterations on this batch of inputs and targets. 73 | desc: Description of the progress bar. 74 | 75 | Returns: 76 | outputs: Output image(s). 
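        Example:
            Illustrative only; `blurred` and `sharp` stand for CUDA tensors
            already normalized to [-1, 1]:

                trainer = load_from_config('config/config.yaml')
                trainer.prepare(epoch=0)
                outputs = trainer.run(blurred, sharp, is_inference=False,
                                      num_iters=1, desc='deblur step')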
77 | """ 78 | 79 | if is_inference: 80 | with torch.no_grad(): 81 | outputs = self.netG(inputs) 82 | loss_content = self.criterionG(outputs, targets) 83 | loss_adv = self.adv_trainer.loss_g(outputs, targets) 84 | loss_G = loss_content + self.adv_lambda * loss_adv 85 | 86 | return outputs.detach() 87 | 88 | if desc is not None: 89 | tq = tqdm.tqdm(range(num_iters), total=num_iters) 90 | tq.set_description(desc) 91 | else: 92 | tq = range(num_iters) 93 | 94 | for it in tq: 95 | outputs = self.netG(inputs) 96 | loss_D = self._update_d(outputs, targets) 97 | self.optimizer_G.zero_grad() 98 | loss_content = self.criterionG(outputs, targets) 99 | loss_adv = self.adv_trainer.loss_g(outputs, targets) 100 | loss_G = loss_content + self.adv_lambda * loss_adv 101 | loss_G.backward() 102 | self.optimizer_G.step() 103 | self.metric_counter.add_losses(loss_G.item(), loss_content.item(), loss_D) 104 | curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs.detach(), outputs, targets) 105 | self.metric_counter.add_metrics(curr_psnr, curr_ssim) 106 | 107 | if desc is not None: 108 | tq.set_postfix(loss=self.metric_counter.loss_message()) 109 | 110 | if desc is not None: 111 | tq.close() 112 | 113 | return outputs.detach() 114 | 115 | def _run_epoch(self, epoch): 116 | self.metric_counter.clear() 117 | for param_group in self.optimizer_G.param_groups: 118 | lr = param_group['lr'] 119 | 120 | epoch_size = self.config.get('train_batches_per_epoch') or len(self.train_dataset) 121 | tq = tqdm.tqdm(self.train_dataset, total=epoch_size) 122 | tq.set_description('Epoch {}, lr {}'.format(epoch, lr)) 123 | i = 0 124 | for data in tq: 125 | inputs, targets = self.model.get_input(data) 126 | outputs = self.netG(inputs) 127 | loss_D = self._update_d(outputs, targets) 128 | self.optimizer_G.zero_grad() 129 | loss_content = self.criterionG(outputs, targets) 130 | loss_adv = self.adv_trainer.loss_g(outputs, targets) 131 | loss_G = loss_content + self.adv_lambda * loss_adv 132 | loss_G.backward() 133 | self.optimizer_G.step() 134 | self.metric_counter.add_losses(loss_G.item(), loss_content.item(), loss_D) 135 | curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets) 136 | self.metric_counter.add_metrics(curr_psnr, curr_ssim) 137 | tq.set_postfix(loss=self.metric_counter.loss_message()) 138 | if not i: 139 | self.metric_counter.add_image(img_for_vis, tag='train') 140 | i += 1 141 | if i > epoch_size: 142 | break 143 | tq.close() 144 | self.metric_counter.write_to_tensorboard(epoch) 145 | 146 | def _validate(self, epoch): 147 | self.metric_counter.clear() 148 | epoch_size = self.config.get('val_batches_per_epoch') or len(self.val_dataset) 149 | tq = tqdm.tqdm(self.val_dataset, total=epoch_size) 150 | tq.set_description('Validation') 151 | i = 0 152 | for data in tq: 153 | inputs, targets = self.model.get_input(data) 154 | with torch.no_grad(): 155 | outputs = self.netG(inputs) 156 | loss_content = self.criterionG(outputs, targets) 157 | loss_adv = self.adv_trainer.loss_g(outputs, targets) 158 | loss_G = loss_content + self.adv_lambda * loss_adv 159 | self.metric_counter.add_losses(loss_G.item(), loss_content.item()) 160 | curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets) 161 | self.metric_counter.add_metrics(curr_psnr, curr_ssim) 162 | if not i: 163 | self.metric_counter.add_image(img_for_vis, tag='val') 164 | i += 1 165 | if i > epoch_size: 166 | break 167 | tq.close() 168 | 
self.metric_counter.write_to_tensorboard(epoch, validation=True) 169 | 170 | def _update_d(self, outputs, targets): 171 | if self.config['model']['d_name'] == 'no_gan': 172 | return 0 173 | self.optimizer_D.zero_grad() 174 | loss_D = self.adv_lambda * self.adv_trainer.loss_d(outputs, targets) 175 | loss_D.backward(retain_graph=True) 176 | self.optimizer_D.step() 177 | return loss_D.item() 178 | 179 | def _get_optim(self, params): 180 | if self.config['optimizer']['name'] == 'adam': 181 | optimizer = optim.Adam(params, lr=self.config['optimizer']['lr']) 182 | elif self.config['optimizer']['name'] == 'sgd': 183 | optimizer = optim.SGD(params, lr=self.config['optimizer']['lr']) 184 | elif self.config['optimizer']['name'] == 'adadelta': 185 | optimizer = optim.Adadelta(params, lr=self.config['optimizer']['lr']) 186 | else: 187 | raise ValueError("Optimizer [%s] not recognized." % self.config['optimizer']['name']) 188 | return optimizer 189 | 190 | def _get_scheduler(self, optimizer): 191 | if self.config['scheduler']['name'] == 'plateau': 192 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 193 | mode='min', 194 | patience=self.config['scheduler']['patience'], 195 | factor=self.config['scheduler']['factor'], 196 | min_lr=self.config['scheduler']['min_lr']) 197 | elif self.config['scheduler']['name'] == 'sgdr': 198 | scheduler = WarmRestart(optimizer) 199 | elif self.config['scheduler']['name'] == 'linear': 200 | scheduler = LinearDecay(optimizer, 201 | min_lr=self.config['scheduler']['min_lr'], 202 | num_epochs=self.config['num_epochs'], 203 | start_epoch=self.config['scheduler']['start_epoch']) 204 | else: 205 | raise ValueError("Scheduler [%s] not recognized." % self.config['scheduler']['name']) 206 | return scheduler 207 | 208 | @staticmethod 209 | def _get_adversarial_trainer(d_name, net_d, criterion_d): 210 | if d_name == 'no_gan': 211 | return GANFactory.create_model('NoGAN') 212 | elif d_name == 'patch_gan' or d_name == 'multi_scale': 213 | return GANFactory.create_model('SingleGAN', net_d, criterion_d) 214 | elif d_name == 'double_gan': 215 | return GANFactory.create_model('DoubleGAN', net_d, criterion_d) 216 | else: 217 | raise ValueError("Discriminator Network [%s] not recognized." 
% d_name) 218 | 219 | def _init_params(self): 220 | self.criterionG, criterionD = get_loss(self.config['model']) 221 | self.netG, netD = get_nets(self.config['model']) 222 | self.netG.cuda() 223 | self.adv_trainer = self._get_adversarial_trainer(self.config['model']['d_name'], netD, criterionD) 224 | self.model = get_model(self.config['model']) 225 | self.optimizer_G = self._get_optim(filter(lambda p: p.requires_grad, self.netG.parameters())) 226 | self.optimizer_D = self._get_optim(self.adv_trainer.get_params()) 227 | self.scheduler_G = self._get_scheduler(self.optimizer_G) 228 | self.scheduler_D = self._get_scheduler(self.optimizer_D) 229 | 230 | 231 | def load_from_config(config_path='config/config.yaml'): 232 | with open(config_path, 'r', encoding='utf-8') as f: 233 | config = yaml.load(f, Loader=yaml.SafeLoader) 234 | 235 | batch_size = config.pop('batch_size') 236 | # get_dataloader = partial(DataLoader, 237 | # batch_size=batch_size, 238 | # num_workers=0 if os.environ.get('DEBUG') else cpu_count(), 239 | # shuffle=True, drop_last=True) 240 | get_dataloader = partial(DataLoader, 241 | batch_size=batch_size, 242 | shuffle=True, drop_last=True) 243 | 244 | datasets = map(config.pop, ('train', 'val')) 245 | datasets = map(PairedDataset.from_config, datasets) 246 | train, val = map(get_dataloader, datasets) 247 | trainer = Trainer(config, train=train, val=val) 248 | 249 | return trainer 250 | 251 | 252 | def main(): 253 | trainer = load_from_config() 254 | trainer.train() 255 | 256 | 257 | if __name__ == '__main__': 258 | Fire(main) 259 | -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/neural_networks/DeblurGANv2/util/__init__.py -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/util/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | from torch.autograd import Variable 5 | from collections import deque 6 | 7 | 8 | class ImagePool(): 9 | def __init__(self, pool_size): 10 | self.pool_size = pool_size 11 | self.sample_size = pool_size 12 | if self.pool_size > 0: 13 | self.num_imgs = 0 14 | self.images = deque() 15 | 16 | def add(self, images): 17 | if self.pool_size == 0: 18 | return images 19 | for image in images.data: 20 | image = torch.unsqueeze(image, 0) 21 | if self.num_imgs < self.pool_size: 22 | self.num_imgs = self.num_imgs + 1 23 | self.images.append(image) 24 | else: 25 | self.images.popleft() 26 | self.images.append(image) 27 | 28 | def query(self): 29 | if len(self.images) > self.sample_size: 30 | return_images = list(random.sample(self.images, self.sample_size)) 31 | else: 32 | return_images = list(self.images) 33 | return torch.cat(return_images, 0) 34 | -------------------------------------------------------------------------------- /examples/neural_networks/DeblurGANv2/util/metrics.py: -------------------------------------------------------------------------------- 1 | import math 2 | from math import exp 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | 9 | 10 | def gaussian(window_size, sigma): 11 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / 
float(2 * sigma ** 2)) for x in range(window_size)]) 12 | return gauss / gauss.sum() 13 | 14 | 15 | def create_window(window_size, channel): 16 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 17 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 18 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 19 | return window 20 | 21 | 22 | def SSIM(img1, img2): 23 | (_, channel, _, _) = img1.size() 24 | window_size = 11 25 | window = create_window(window_size, channel) 26 | 27 | if img1.is_cuda: 28 | window = window.cuda(img1.get_device()) 29 | window = window.type_as(img1) 30 | 31 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 32 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 33 | 34 | mu1_sq = mu1.pow(2) 35 | mu2_sq = mu2.pow(2) 36 | mu1_mu2 = mu1 * mu2 37 | 38 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 39 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 40 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 41 | 42 | C1 = 0.01 ** 2 43 | C2 = 0.03 ** 2 44 | 45 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 46 | return ssim_map.mean() 47 | 48 | 49 | def PSNR(img1, img2): 50 | mse = np.mean((img1 / 255. - img2 / 255.) ** 2) 51 | if mse == 0: 52 | return 100 53 | PIXEL_MAX = 1 54 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 55 | -------------------------------------------------------------------------------- /examples/nikon.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import matplotlib.pyplot as plt 4 | from pathlib import Path 5 | 6 | import sys 7 | sys.path.append("../") 8 | import diffoptics as do 9 | from datetime import datetime 10 | 11 | """ 12 | DISCLAIMER: 13 | 14 | This script was used to generate Figure 11 in the paper. However, the results produced in 15 | the paper mistakenly assumed the air refractive index to be 1, rather than 1.000293. This 16 | slight difference will produce a slightly different result from the paper. 17 | 18 | If you want to reproduce the result in the paper, change the following item in 19 | 20 | diffoptics.basics.Material.__init__.MATERIAL_TABLE 21 | 22 | from: 23 | 24 | "air": [1.000293, np.inf], 25 | 26 | to 27 | 28 | "air": [1.000000, np.inf], 29 | 30 | And re-run this script. 
31 | """ 32 | 33 | 34 | # initialize a lens 35 | device = torch.device('cpu') 36 | # device = torch.device('cuda') 37 | lens = do.Lensgroup(device=device) 38 | 39 | # load optics 40 | lens.load_file(Path('./lenses/Zemax_samples/Nikon-z35-f1.8-JPA2019090949-example2.txt')) 41 | # lens.plot_setup2D() 42 | 43 | # sample wavelengths in [nm] 44 | wavelengths = torch.Tensor([656.2725, 587.5618, 486.1327]).to(device) 45 | views = np.array([0, 10, 20, 32.45]) 46 | colors_list = 'bgry' 47 | 48 | def plot_layout(string): 49 | ax, fig = lens.plot_setup2D_with_trace(views, wavelengths[1], M=5, entrance_pupil=True) 50 | ax.axis('off') 51 | ax.set_title("") 52 | fig.savefig("layout_trace_" + string + "_" + datetime.now().strftime('%Y%m%d-%H%M%S-%f') + ".pdf", bbox_inches='tight') 53 | 54 | M = 31 55 | def render(verbose=False, entrance_pupil=False): 56 | def render_single(wavelength): 57 | pss = [] 58 | spot_rms = [] 59 | loss = 0.0 60 | for view in views: 61 | ray = lens.sample_ray(wavelength, view=view, M=M, sampling='grid', entrance_pupil=entrance_pupil) 62 | ps = lens.trace_to_sensor(ray, ignore_invalid=True) 63 | 64 | # calculate RMS 65 | tmp, ps = lens.rms(ps[...,:2], squared=True) 66 | loss = loss + tmp 67 | pss.append(ps) 68 | spot_rms.append(np.sqrt(tmp.item())) 69 | return pss, loss, np.array(spot_rms) 70 | 71 | pss_all = [] 72 | rms_all = [] 73 | loss = 0.0 74 | for wavelength in wavelengths: 75 | if verbose: 76 | print("Rendering wavelength = {} [nm] ...".format(wavelength.item())) 77 | pss, loss_single, rmss = render_single(wavelength) 78 | loss = loss + loss_single 79 | pss_all.append(pss) 80 | rms_all.append(rmss) 81 | return pss_all, loss, np.array(rms_all) 82 | 83 | def func(): 84 | ps = render()[0] 85 | return torch.vstack([torch.vstack(ps[i]) for i in range(len(ps))]) 86 | 87 | def loss_func(): 88 | return render()[1] 89 | 90 | def info(string): 91 | loss, rms = render()[1:] 92 | print("=== {} ===".format(string)) 93 | print("loss = {}".format(loss)) 94 | print("==========") 95 | plot_layout(string) 96 | return rms 97 | 98 | rms_org = info('original') 99 | print(rms_org.mean()) 100 | 101 | id_range = list(range(0, 19)) 102 | id_range.pop(lens.aperture_ind) 103 | id_asphere = [16, 17] 104 | for i in id_asphere: 105 | lens.surfaces[i].ai = torch.Tensor([0.0]).to(device) 106 | 107 | diff_names = [] 108 | diff_names += ['surfaces[{}].c'.format(str(i)) for i in id_range] 109 | diff_names += ['surfaces[{}].k'.format(str(i)) for i in id_asphere] 110 | diff_names += ['surfaces[{}].ai'.format(str(i)) for i in id_asphere] 111 | 112 | rms_init = info('initial') 113 | print(rms_init.mean()) 114 | 115 | # optimize 116 | start = torch.cuda.Event(enable_timing=True) 117 | end = torch.cuda.Event(enable_timing=True) 118 | 119 | start.record() 120 | out = do.LM(lens, diff_names, 1e-2, option='diag') \ 121 | .optimize(func, lambda y: 0.0 - y, maxit=100, record=True) 122 | end.record() 123 | torch.cuda.synchronize() 124 | print('Finished in {:.2f} mins'.format(start.elapsed_time(end)/1000/60)) 125 | 126 | rms_opt = info('optimized') 127 | print(rms_opt.mean()) 128 | 129 | # plot loss 130 | fig, ax = plt.subplots(figsize=(12,6)) 131 | ax.semilogy(out['ls'], 'k-o', linewidth=3) 132 | plt.xlabel('iteration') 133 | plt.ylabel('error function') 134 | plt.savefig("./ls_nikon.pdf", bbox_inches='tight') 135 | 136 | def save_fig(xs, string): 137 | fig, ax = plt.subplots(figsize=(3,1.5)) 138 | xs = xs.T 139 | for i, x in enumerate(xs): 140 | ax.semilogy(x, colors_list[i], marker='o', linewidth=1) 141 | 
plt.xlabel('wavelength [nm]') 142 | plt.ylabel('RMS spot size [um]') 143 | plt.ylim([0.08, 50]) 144 | plt.xticks([0,1,2], ['656.27', '587.56', '486.13']) 145 | plt.yticks([0.1,1,10,50], ['0.1', '1', '10', '50']) 146 | fig.savefig("./rms_" + string + "_nikon.pdf", bbox_inches='tight') 147 | 148 | save_fig(rms_init * 1e3, "init") 149 | save_fig(rms_org * 1e3, "org") 150 | save_fig(rms_opt * 1e3, "opt") 151 | -------------------------------------------------------------------------------- /examples/render_image.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | import matplotlib.pyplot as plt 5 | from pathlib import Path 6 | from tqdm import tqdm 7 | 8 | import sys 9 | sys.path.append("../") 10 | import diffoptics as do 11 | 12 | # initialize a lens 13 | device = torch.device('cuda') 14 | lens = do.Lensgroup(device=device) 15 | 16 | # load optics 17 | lens.load_file(Path('./lenses/DoubleGauss/US02532751-1.txt')) 18 | 19 | # set sensor pixel size and film size 20 | pixel_size = 6.45e-3 # [mm] 21 | film_size = [768, 1024] 22 | 23 | # set a rendering image sensor, and call prepare_mts to prepare the lensgroup for rendering 24 | lens.prepare_mts(pixel_size, film_size) 25 | # lens.plot_setup2D() 26 | 27 | # create a dummy screen 28 | z0 = 10e3 # [mm] 29 | pixelsize = 1.1 # [mm] 30 | texture = cv2.cvtColor(cv2.imread('./images/squirrel.jpg'), cv2.COLOR_BGR2RGB) 31 | texture = np.flip(texture.astype(np.float32), axis=(0,1)).copy() 32 | texture_torch = torch.Tensor(texture).to(device=device) 33 | texturesize = np.array(texture.shape[0:2]) 34 | screen = do.Screen( 35 | do.Transformation(np.eye(3), np.array([0, 0, z0])), 36 | texturesize * pixelsize, texture_torch, device=device 37 | ) 38 | 39 | # helper function 40 | def render_single(wavelength, screen): 41 | valid, ray_new = lens.sample_ray_sensor(wavelength) 42 | uv, valid_screen = screen.intersect(ray_new)[1:] 43 | mask = valid & valid_screen 44 | I = screen.shading(uv, mask) 45 | return I, mask 46 | 47 | # sample wavelengths in [nm] 48 | wavelengths = [656.2725, 587.5618, 486.1327] 49 | 50 | # render 51 | ray_counts_per_pixel = 100 52 | Is = [] 53 | for wavelength_id, wavelength in enumerate(wavelengths): 54 | screen.update_texture(texture_torch[..., wavelength_id]) 55 | 56 | # multi-pass rendering by sampling the aperture 57 | I = 0 58 | M = 0 59 | for i in tqdm(range(ray_counts_per_pixel)): 60 | I_current, mask = render_single(wavelength, screen) 61 | I = I + I_current 62 | M = M + mask 63 | I = I / (M + 1e-10) 64 | 65 | # reshape data to a 2D image 66 | I = I.reshape(*np.flip(np.asarray(film_size))).permute(1,0) 67 | Is.append(I.cpu()) 68 | 69 | # show image 70 | I_rendered = torch.stack(Is, axis=-1).numpy().astype(np.uint8) 71 | plt.imshow(I_rendered) 72 | plt.show() 73 | plt.imsave('I_rendered.png', I_rendered) 74 | -------------------------------------------------------------------------------- /examples/render_psf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import matplotlib.pyplot as plt 4 | from pathlib import Path 5 | from tqdm import tqdm 6 | 7 | import sys 8 | sys.path.append("../") 9 | import diffoptics as do 10 | 11 | # initialize a lens 12 | device = torch.device('cuda') 13 | lens = do.Lensgroup(device=device) 14 | 15 | # load optics 16 | lens.load_file(Path('./lenses/DoubleGauss/US02532751-1.txt')) 17 | 18 | # sensor area 19 | pixel_size = 6.45e-3 # [mm] 
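# The render_psf() helper defined below splats each ray-sensor intersection
# onto the pixel grid with bilinear (2x2) weights via
# torch.index_put(..., accumulate=True), accumulating a PSF image.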
20 | film_size = torch.tensor([1200, 1600], device=device) 21 | R_square = film_size * pixel_size 22 | 23 | # generate array of rays 24 | wavelength = 532.8 # [nm] 25 | R = 5.0 # [mm] 26 | # lens.plot_setup2D() 27 | 28 | def render_psf(I, p): 29 | # compute shifts and do linear interpolation 30 | uv = (p + R_square/2) / pixel_size 31 | index_l = torch.vstack(( 32 | torch.clamp(torch.floor(uv[:,0]).long(), min=0, max=film_size[0]), 33 | torch.clamp(torch.floor(uv[:,1]).long(), min=0, max=film_size[1])) 34 | ).T 35 | index_r = torch.vstack(( 36 | torch.clamp(index_l[:,0] + 1, min=0, max=film_size[0]), 37 | torch.clamp(index_l[:,1] + 1, min=0, max=film_size[1])) 38 | ).T 39 | w_r = torch.clamp(uv - index_l, min=0, max=1) 40 | w_l = 1.0 - w_r 41 | del uv 42 | 43 | # compute image 44 | I = torch.index_put(I, (index_l[...,0],index_l[...,1]), w_l[...,0]*w_l[...,1], accumulate=True) 45 | I = torch.index_put(I, (index_r[...,0],index_l[...,1]), w_r[...,0]*w_l[...,1], accumulate=True) 46 | I = torch.index_put(I, (index_l[...,0],index_r[...,1]), w_l[...,0]*w_r[...,1], accumulate=True) 47 | I = torch.index_put(I, (index_r[...,0],index_r[...,1]), w_r[...,0]*w_r[...,1], accumulate=True) 48 | return I 49 | 50 | def generate_surface_samples(M): 51 | Dx = np.random.rand(M,M) 52 | Dy = np.random.rand(M,M) 53 | [px, py] = do.Sampler().concentric_sample_disk(Dx, Dy) 54 | return np.stack((px.flatten(), py.flatten()), axis=1) 55 | 56 | def sample_ray(o_obj, M): 57 | p_aperture_2d = R * generate_surface_samples(M) 58 | N = p_aperture_2d.shape[0] 59 | p_aperture = np.hstack((p_aperture_2d, np.zeros((N,1)))).reshape((N,3)) 60 | o = np.ones(N)[:, None] * o_obj[None, :] 61 | 62 | o = o.astype(np.float32) 63 | p_aperture = p_aperture.astype(np.float32) 64 | 65 | d = do.normalize(torch.from_numpy(p_aperture - o)) 66 | 67 | o = torch.from_numpy(o).to(lens.device) 68 | d = d.to(lens.device) 69 | 70 | return do.Ray(o, d, wavelength, device=lens.device) 71 | 72 | def render(o_obj, M, rep_count): 73 | I = torch.zeros(*film_size, device=device) 74 | for i in range(rep_count): 75 | rays = sample_ray(o_obj, M) 76 | ps = lens.trace_to_sensor(rays, ignore_invalid=True) 77 | I = render_psf(I, ps[..., :2]) 78 | return I / rep_count 79 | 80 | # PSF rendering parameters 81 | x_max_halfangle = 10 # [deg] 82 | y_max_halfangle = 7.5 # [deg] 83 | Nx = 2 * 8 + 1 84 | Ny = 2 * 6 + 1 85 | 86 | # sampling parameters 87 | M = 1001 88 | rep_count = 1 89 | 90 | def render_at_depth(z): 91 | x_halfmax = np.abs(z) * np.tan(np.deg2rad(x_max_halfangle)) 92 | y_halfmax = np.abs(z) * np.tan(np.deg2rad(y_max_halfangle)) 93 | 94 | I_psf_all = torch.zeros(*film_size, device=device) 95 | for x in tqdm(np.linspace(-x_halfmax, x_halfmax, Nx)): 96 | for y in np.linspace(-y_halfmax, y_halfmax, Ny): 97 | o_obj = np.array([y, x, z]) 98 | I_psf = render(o_obj, M, rep_count) 99 | I_psf_all = I_psf_all + I_psf 100 | return I_psf_all 101 | 102 | # render PSF at different depths 103 | zs = [-1e4, -7e3, -5e3, -3e3, -2e3, -1.5e3, -1e3] 104 | savedir = Path('./rendered_psfs') 105 | savedir.mkdir(exist_ok=True, parents=True) 106 | I_psfs = [] 107 | for z in zs: 108 | I_psf = render_at_depth(z) 109 | I_psf = I_psf.cpu().numpy() 110 | plt.imsave(str(savedir / 'I_psf_z={}.png'.format(z)), np.uint8(255 * I_psf / I_psf.max()), cmap='hot') 111 | I_psfs.append(I_psf) 112 | -------------------------------------------------------------------------------- /examples/sanity_check.py: -------------------------------------------------------------------------------- 1 | import os 2 | 
import numpy as np 3 | import torch 4 | import matplotlib.pyplot as plt 5 | from pathlib import Path 6 | 7 | import sys 8 | sys.path.append("../") 9 | import diffoptics as do 10 | 11 | # initialize a lens 12 | device = torch.device('cuda') 13 | lens = do.Lensgroup(device=device) 14 | 15 | # load optics 16 | lens.load_file(Path('./lenses/DoubleGauss/US02532751-1.txt')) 17 | 18 | # sample wavelengths in [nm] 19 | wavelengths = torch.Tensor([656.2725, 587.5618, 486.1327]).to(device) 20 | 21 | colors_list = 'bgry' 22 | views = np.linspace(0, 21, 4, endpoint=True) 23 | ax, fig = lens.plot_setup2D_with_trace(views, wavelengths[1], M=4) 24 | ax.axis('off') 25 | ax.set_title('Sanity Check Setup 2D') 26 | fig.savefig('sanity_check_setup.pdf') 27 | 28 | # spot diagrams 29 | spot_rmss = [] 30 | valid_maps = [] 31 | for i, view in enumerate(views): 32 | ray = lens.sample_ray(wavelengths[1], view=view, M=31, sampling='grid', entrance_pupil=True) 33 | ps = lens.trace_to_sensor(ray, ignore_invalid=True) 34 | lim = 20e-3 35 | lens.spot_diagram( 36 | ps[...,:2], show=True, xlims=[-lim, lim], ylims=[-lim, lim], color=colors_list[i]+'.', 37 | savepath='sanity_check_field_view_{}.png'.format(int(view)) 38 | ) 39 | 40 | spot_rmss.append(lens.rms(ps)) 41 | 42 | plt.show() 43 | -------------------------------------------------------------------------------- /examples/spherical_aberration.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | from pathlib import Path 4 | 5 | import sys 6 | sys.path.append("../") 7 | import diffoptics as do 8 | 9 | # initialization 10 | # device = do.init() 11 | device = torch.device('cpu') 12 | 13 | # load target lens 14 | lens = do.Lensgroup(device=device) 15 | lens.load_file(Path('./lenses/Thorlabs/ACL5040U.txt')) 16 | print(lens.surfaces[0]) 17 | 18 | # generate array of rays 19 | wavelength = torch.Tensor([532.8]).to(device) # [nm] 20 | R = 15.0 # [mm] 21 | def render(): 22 | ray_init = lens.sample_ray(wavelength, M=31, R=R) 23 | ps = lens.trace_to_sensor(ray_init) 24 | return ps[...,:2] 25 | 26 | def trace_all(): 27 | ray_init = lens.sample_ray_2D(R, wavelength, M=15) 28 | ps, oss = lens.trace_to_sensor_r(ray_init) 29 | return ps[...,:2], oss 30 | ps, oss = trace_all() 31 | ax, fig = lens.plot_raytraces(oss) 32 | 33 | ax, fig = lens.plot_setup2D_with_trace([0.0], wavelength, M=5, R=R) 34 | ax.axis('off') 35 | ax.set_title("") 36 | fig.savefig("layout_trace_asphere.pdf", bbox_inches='tight') 37 | 38 | # show initial RMS 39 | ps_org = render() 40 | L_org = torch.mean(torch.sum(torch.square(ps_org), axis=-1)) 41 | print('original loss: {:.3e}'.format(L_org)) 42 | lens.spot_diagram(ps_org, xlims=[-50.0e-3, 50.0e-3], ylims=[-50.0e-3, 50.0e-3]) 43 | 44 | diff_names = [ 45 | 'surfaces[0].c', 46 | 'surfaces[0].k', 47 | 'surfaces[0].ai' 48 | ] 49 | 50 | # optimize 51 | out = do.LM(lens, diff_names, 1e-4, option='diag') \ 52 | .optimize(render, lambda y: 0.0 - y, maxit=300, record=True) 53 | 54 | # show loss 55 | plt.figure() 56 | plt.semilogy(out['ls'], '-o') 57 | plt.xlabel('Iteration') 58 | plt.ylabel('Loss') 59 | plt.show() 60 | 61 | # show spot diagram 62 | ps = render() 63 | L = torch.mean(torch.sum(torch.square(ps), axis=-1)) 64 | print('final loss: {:.3e}'.format(L)) 65 | lens.spot_diagram(ps, xlims=[-50.0e-3, 50.0e-3], ylims=[-50.0e-3, 50.0e-3]) 66 | print(lens.surfaces[0]) 67 | # lens.plot_setup2D() 68 | -------------------------------------------------------------------------------- 
/examples/training_dataset/0008.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/training_dataset/0008.png -------------------------------------------------------------------------------- /examples/training_dataset/0010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/training_dataset/0010.png -------------------------------------------------------------------------------- /examples/training_dataset/0023.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/training_dataset/0023.png -------------------------------------------------------------------------------- /examples/training_dataset/0030.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/training_dataset/0030.png -------------------------------------------------------------------------------- /examples/training_dataset/0031.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/training_dataset/0031.png -------------------------------------------------------------------------------- /examples/training_dataset/0032.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/training_dataset/0032.png -------------------------------------------------------------------------------- /examples/training_dataset/0115.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/training_dataset/0115.png -------------------------------------------------------------------------------- /examples/training_dataset/0267.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/training_dataset/0267.png -------------------------------------------------------------------------------- /examples/utils_end2end.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import torch 5 | import cv2 6 | import pathlib 7 | 8 | 9 | def dict_to_tensor(xs): 10 | """ 11 | Concatenates (or, packs) a dictionary of differentiable parameters into a Tensor array. 12 | """ 13 | return torch.cat([x.view(-1) for x in xs], 0) 14 | 15 | def tensor_to_dict(xs, diff_parameters): 16 | """ 17 | Unpacks a Tensor array into a dictionary of differentiable parameters. 
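    Example (illustrative): packing two differentiable parameters and
    unpacking them again; the pieces are returned flattened, in the same
    order they were packed.

        a, b = torch.zeros(2, 3), torch.zeros(4)
        flat = dict_to_tensor([a, b])          # shape (10,)
        parts = tensor_to_dict(flat, [a, b])   # [shape (6,), shape (4,)]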
18 | """ 19 | ys = [] 20 | idx = 0 21 | for diff_para in diff_parameters: 22 | y = xs[idx:idx+diff_para.numel()] 23 | idx += diff_para.numel() 24 | ys.append(y) 25 | return ys 26 | 27 | def load_image(image_name): 28 | img = cv2.cvtColor(cv2.imread(image_name), cv2.COLOR_BGR2RGB) 29 | img = np.flip(img.astype(np.float32), axis=(0,1)).copy() 30 | return img 31 | 32 | 33 | class ImageFolder(torch.utils.data.Dataset): 34 | def __init__(self, root_path): 35 | super(ImageFolder, self).__init__() 36 | self.file_names = sorted(os.listdir(root_path)) 37 | self.root_path = pathlib.Path(root_path) 38 | 39 | def __len__(self): 40 | return len(self.file_names) 41 | 42 | def __getitem__(self, index): 43 | return load_image(str(self.root_path / self.file_names[index])) 44 | 45 | 46 | def load_deblurganv2(): 47 | """ 48 | Loads the DeblurGANv2 as the neural network backend. 49 | """ 50 | neural_network_path = pathlib.Path('./neural_networks/DeblurGANv2') 51 | sys.path.append(str(neural_network_path)) 52 | 53 | import train_end2end 54 | 55 | trainer = train_end2end.load_from_config(str(neural_network_path / 'config/config.yaml')) 56 | 57 | return trainer 58 | -------------------------------------------------------------------------------- /imgs/abp.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/abp.jpg -------------------------------------------------------------------------------- /imgs/applications.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/applications.jpg -------------------------------------------------------------------------------- /imgs/bp_abp_comp.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/bp_abp_comp.jpg -------------------------------------------------------------------------------- /imgs/examples/I.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/I.jpg -------------------------------------------------------------------------------- /imgs/examples/I0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/I0.jpg -------------------------------------------------------------------------------- /imgs/examples/I_final.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/I_final.png -------------------------------------------------------------------------------- /imgs/examples/I_psf_z=-1000.0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/I_psf_z=-1000.0.png -------------------------------------------------------------------------------- /imgs/examples/I_psf_z=-10000.0.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/I_psf_z=-10000.0.png -------------------------------------------------------------------------------- /imgs/examples/I_psf_z=-1500.0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/I_psf_z=-1500.0.png -------------------------------------------------------------------------------- /imgs/examples/I_psf_z=-2000.0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/I_psf_z=-2000.0.png -------------------------------------------------------------------------------- /imgs/examples/I_psf_z=-3000.0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/I_psf_z=-3000.0.png -------------------------------------------------------------------------------- /imgs/examples/I_rendered.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/I_rendered.jpg -------------------------------------------------------------------------------- /imgs/examples/I_target.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/I_target.png -------------------------------------------------------------------------------- /imgs/examples/iter_1_z=6000.0mm_images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/iter_1_z=6000.0mm_images.png -------------------------------------------------------------------------------- /imgs/examples/optimized.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/optimized.gif -------------------------------------------------------------------------------- /imgs/examples/optimized.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/optimized.mp4 -------------------------------------------------------------------------------- /imgs/examples/phase.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/phase.png -------------------------------------------------------------------------------- /imgs/examples/sanity_check_dO.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/sanity_check_dO.jpg -------------------------------------------------------------------------------- /imgs/examples/sanity_check_zemax.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/sanity_check_zemax.jpg -------------------------------------------------------------------------------- /imgs/memory_comp.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/memory_comp.jpg -------------------------------------------------------------------------------- /imgs/overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/overview.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.5.3 2 | mitsuba==3.0.2 3 | numpy==1.21.6 4 | scipy==1.9.1 5 | torch==1.12.1+cu113 6 | --------------------------------------------------------------------------------