├── .gitignore
├── LICENSE
├── README.md
├── diffoptics
│   ├── __init__.py
│   ├── basics.py
│   ├── optics.py
│   ├── shapes.py
│   ├── solvers.py
│   └── version.py
├── examples
│   ├── autodiff.py
│   ├── backprop_compare.py
│   ├── caustic_pyramid.py
│   ├── data
│   │   └── 20210304
│   │       └── ref2.tif
│   ├── end2end_edof_backward_tracing.py
│   ├── images
│   │   ├── einstein.jpg
│   │   ├── point.tif
│   │   └── squirrel.jpg
│   ├── lenses
│   │   ├── DoubleGauss
│   │   │   ├── US02532751-1.txt
│   │   │   └── US02532751-1.zmx
│   │   ├── README.md
│   │   ├── ThorLabs
│   │   │   ├── ACL5040U.txt
│   │   │   └── LA1131.txt
│   │   ├── Zemax_samples
│   │   │   └── Nikon-z35-f1.8-JPA2019090949-example2.txt
│   │   └── end2end
│   │       └── end2end_edof.txt
│   ├── misalignment_point.py
│   ├── neural_networks
│   │   └── DeblurGANv2
│   │       ├── .gitignore
│   │       ├── LICENSE
│   │       ├── README.md
│   │       ├── adversarial_trainer.py
│   │       ├── aug.py
│   │       ├── config
│   │       │   └── config.yaml
│   │       ├── dataset.py
│   │       ├── metric_counter.py
│   │       ├── models
│   │       │   ├── __init__.py
│   │       │   ├── fpn_densenet.py
│   │       │   ├── fpn_inception.py
│   │       │   ├── fpn_inception_simple.py
│   │       │   ├── fpn_mobilenet.py
│   │       │   ├── losses.py
│   │       │   ├── mobilenet_v2.py
│   │       │   ├── models.py
│   │       │   ├── networks.py
│   │       │   ├── senet.py
│   │       │   └── unet_seresnext.py
│   │       ├── picture_to_video.py
│   │       ├── predict.py
│   │       ├── requirements.txt
│   │       ├── schedulers.py
│   │       ├── test.py
│   │       ├── test.sh
│   │       ├── test_aug.py
│   │       ├── test_batchsize.py
│   │       ├── test_dataset.py
│   │       ├── test_metrics.py
│   │       ├── train.py
│   │       ├── train_end2end.py
│   │       └── util
│   │           ├── __init__.py
│   │           ├── image_pool.py
│   │           └── metrics.py
│   ├── nikon.py
│   ├── render_image.py
│   ├── render_psf.py
│   ├── sanity_check.py
│   ├── spherical_aberration.py
│   ├── training_dataset
│   │   ├── 0008.png
│   │   ├── 0010.png
│   │   ├── 0023.png
│   │   ├── 0030.png
│   │   ├── 0031.png
│   │   ├── 0032.png
│   │   ├── 0115.png
│   │   └── 0267.png
│   └── utils_end2end.py
├── imgs
│   ├── abp.jpg
│   ├── applications.jpg
│   ├── bp_abp_comp.jpg
│   ├── examples
│   │   ├── I.jpg
│   │   ├── I0.jpg
│   │   ├── I_final.png
│   │   ├── I_psf_z=-1000.0.png
│   │   ├── I_psf_z=-10000.0.png
│   │   ├── I_psf_z=-1500.0.png
│   │   ├── I_psf_z=-2000.0.png
│   │   ├── I_psf_z=-3000.0.png
│   │   ├── I_rendered.jpg
│   │   ├── I_target.png
│   │   ├── iter_1_z=6000.0mm_images.png
│   │   ├── optimized.gif
│   │   ├── optimized.mp4
│   │   ├── phase.png
│   │   ├── sanity_check_dO.jpg
│   │   └── sanity_check_zemax.jpg
│   ├── memory_comp.jpg
│   └── overview.jpg
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/.gitignore
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 vccimaging
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # dO: A differentiable engine for Deep Lens design of computational imaging systems
2 | This is the PyTorch implementation for our paper "dO: A differentiable engine for Deep Lens design of computational imaging systems".
3 | ### [Project Page](https://vccimaging.org/Publications/Wang2022DiffOptics/) | [Paper](https://vccimaging.org/Publications/Wang2022DiffOptics/Wang2022DiffOptics.pdf) | [Supplementary Material](https://vccimaging.org/Publications/Wang2022DiffOptics/Wang2022DiffOptics_supp.pdf)
4 |
5 | dO: A differentiable engine for Deep Lens design of computational imaging systems
6 | [Congli Wang](https://congliwang.github.io),
7 | [Ni Chen](https://ni-chen.github.io), and
8 | [Wolfgang Heidrich](https://vccimaging.org/People/heidriw)
9 | King Abdullah University of Science and Technology (KAUST)
10 | IEEE Transactions on Computational Imaging 2022
11 |
12 |
13 | Figure: Our engine dO models ray tracing in a lens system in a derivative-aware way, which enables ray tracing with back-propagation. To be derivative-aware, all modules must be differentiable so that gradients can be back-propagated from the error metric ϵ(p(θ)) to the variable parameters θ. This is achieved by the two stages of reverse-mode AD: the forward and the backward passes. To ensure differentiability and efficiency, a custom ray-surface intersection solver is introduced. Instead of unrolling the solver iterations for the forward/backward passes, only the forward pass (without AD) is computed to obtain the intersection solutions at the surfaces f_i = 0, and gradients are amended afterwards.
14 |
15 | ## TL; DR
16 |
17 | We implement in PyTorch a memory- and computation-efficient differentiable ray-tracing engine for optical design, with applications in freeform optics, Deep Lens design, metrology, and more.
18 |
19 | ## Update list
20 |
21 | - [x] Initial code release.
22 | - [x] [`autodiff.py`](./examples/autodiff.py): Demo of the `dO` engine.
23 | - [x] [`backprop_compare.py`](./examples/backprop_compare.py): Example on comparison between back-propagation and adjoint back-propagation.
24 | - [x] [`caustic_pyramid.py`](./examples/caustic_pyramid.py): Example on freeform caustic design.
25 |
26 | | Target irradiance | Optimized irradiance | Optimized phase map |
27 | | :-------------------------------------: | :-----------------------------------: | :-------------------------------: |
28 | |  |  |  |
29 |
30 | - [x] [`misalignment_point.py`](./examples/misalignment_point.py): Example on misalignment back-engineering, using real measurements.
31 |
32 | | Model (initial) | Measurement | Model (optimized) |
33 | | :-------------------------: | :-----------------------: | :---------------------------------------: |
34 | |  |  |  |
35 |
36 | - [x] [`nikon.py`](./examples/nikon.py): Example on optimizing a Nikon design.
37 | - [x] [`render_image.py`](./examples/render_image.py): Example on rendering a single image from a design.
38 |
39 | |  |
40 | | ------------------------------------------- |
41 |
42 | - [x] [`render_psf.py`](./examples/render_psf.py): Example on rendering PSFs of varying fields and depths for a design.
43 |
44 | |  |  |
45 | | ------------------------------------------------------------ | ------------------------------------------------------------ |
46 | |  |  |
47 |
48 | - [x] [`sanity_check.py`](./examples/sanity_check.py): Example on Zemax versus `dO` for sanity check.
49 |
50 | | `dO` | Zemax |
51 | | :---------------------------------------------------: | :---------------------------------------------------------: |
52 | |  |  |
53 |
54 |
55 | - [x] [`spherical_aberration.py`](./examples/spherical_aberration.py): Example on optimizing spherical aberration.
56 |
57 | - [x] [`end2end_edof_backward_tracing.py`](./examples/end2end_edof_backward_tracing.py): Example on end-to-end learning of wavefront coding, for extended depth of field applications, using backward ray tracing.
58 |
59 | |  |
60 | | ------------------------------------------------------------ |
61 |
62 | - [ ] Code cleanups and add comments.
63 |
64 | - [ ] File I/O with Zemax.
65 |
66 | - [ ] Mini GUI for easy operations.
67 |
68 | ## Installation
69 |
70 | ### Prerequisite
71 | Though no GPU is required, it is better to run the engine on a GPU for speed.
72 |
73 | Install the required Python packages:
74 | ```shell
75 | pip install -r requirements.txt
76 | ```
77 |
78 | ### Running examples
79 | Examples are in the [`./examples`](./examples) folder; running some of them may require additional Python packages. Just follow the terminal hints, for example:
80 |
81 | ```shell
82 | pip install imageio opencv-python scikit-image
83 | ```
84 |
85 | In case Python cannot find the path to `dO`, run the example scripts from within the [`./examples`](./examples) directory, for example:
86 |
87 | ```shell
88 | cd examples
89 | python3 misalignment_point.py
90 | ```
91 |
92 |
93 | ## Summary
94 |
95 | ### Target problem
96 |
97 | - General optical design/metrology or Deep Lens designs are parameter-optimization problems, and learning-based methods (e.g. with back-propagation) can be employed as solvers. This requires the optical modeling to be numerically derivative-aware (i.e. differentiable).
98 | - However, straightforward differentiable ray tracing with auto-diff (AD) is not memory/computation-efficient.
99 |
100 | ### Our solutions
101 |
102 | - Differentiable ray-surface intersection requires a differentiable root-finding solver, which is typically iterative (e.g. Newton's method). A straightforward implementation is inefficient in both memory and computation. However, our paper makes the observation that the intermediate iterates of the solver are *irrelevant* to the final solution. That means a differentiable root-finding solver can be implemented smartly as: (1) find the solution without AD (e.g. inside a `with torch.no_grad()` block in PyTorch), and (2) re-engage AD at the solution found. This greatly reduces memory consumption, scaling the system's differentiability up to large numbers of parameters or rays (see the sketch below).
103 |
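For intuition, here is a minimal PyTorch sketch of this two-stage idea; the helper `f_and_dfdt` (returning the surface residual and its derivative along the ray) and the function name are hypothetical, not part of the dO API:

```python
import torch

def intersect_differentiable(f_and_dfdt, o, d, t0, num_iters=20):
    # Stage 1: Newton iterations with AD disabled -- the intermediate iterates
    # are irrelevant to the final solution, so no graph is recorded here.
    with torch.no_grad():
        t = t0.clone()
        for _ in range(num_iters):
            f, dfdt = f_and_dfdt(o + t[..., None] * d, d)
            t = t - f / dfdt
    # Stage 2: re-engage AD with a single differentiable Newton step at the
    # converged solution; since f is (numerically) zero there, gradients with
    # respect to the surface parameters flow through this one step only.
    f, dfdt = f_and_dfdt(o + t[..., None] * d, d)
    return t - f / dfdt
```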
104 | |  |
105 | | ------------------------------------------------------------ |
106 | | Figure: Comparison between the straightforward and our proposed differentiable ray-surface intersection methods for freeform surface optimization. Our method reduces the required memory by about a factor of six. |
107 |
108 | - When optimizing a custom merit function for image-based applications with an appended neural network, e.g. in Deep Lens designs, the training (i.e. back-propagation) can be split into two parts:
109 | - (Front-end) Optical design parameter optimization (training).
110 | - (Back-end) Neural network post-processing training.
111 |
112 | This de-coupling resembles the checkpointing technique in deep learning, and hence reduces the memory pressure when tracing a large number of rays (a minimal sketch follows below).
113 |
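A minimal PyTorch sketch of this adjoint back-propagation split is shown below; `render`, `network_loss`, and `thetas` are placeholders (see [`end2end_edof_backward_tracing.py`](./examples/end2end_edof_backward_tracing.py) for how the released code uses `torch.autograd.functional.vjp`):

```python
import torch

def adjoint_backprop_step(render, network_loss, thetas):
    # (1) Front-end: render the image with AD disabled (forward only, no graph).
    with torch.no_grad():
        I = render(thetas)

    # (2) Back-end: back-propagate the network loss with respect to the image
    #     only, yielding the adjoint image gradient dL/dI.
    I = I.detach().requires_grad_(True)
    L = network_loss(I)
    L.backward()
    dL_dI = I.grad

    # (3) Push dL/dI back onto the optical parameters via a vector-Jacobian
    #     product of the renderer, without storing a joint optics+network graph.
    _, dL_dtheta = torch.autograd.functional.vjp(render, thetas, dL_dI)
    return L.detach(), dL_dtheta
```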
114 | |  |
115 | | ------------------------------------------------------------ |
116 | |  |
117 | | Figure: Adjoint back-propagation (Adjoint BP) and the corresponding comparison against back-propagation (BP). Our implementation scales up to many millions of rays, while the conventional approach cannot. |
118 |
119 | ### Applications
120 |
121 | |  |
122 | | :----------------------------------------------------------: |
123 | | Figure: Using dO, the differentiable ray tracing system, we demonstrate the feasibility of advanced optical designs. |
124 |
125 | ## Relevant Project
126 |
127 | [Towards self-calibrated lens metrology by differentiable refractive deflectometry](https://vccimaging.org/Publications/Wang2021DiffDeflectometry/Wang2021DiffDeflectometry.pdf)
128 | [Congli Wang](https://congliwang.github.io),
129 | [Ni Chen](https://ni-chen.github.io), and
130 | [Wolfgang Heidrich](https://vccimaging.org/People/heidriw)
131 | King Abdullah University of Science and Technology (KAUST)
132 | OSA Optics Express 2021
133 |
134 | GitHub: https://github.com/vccimaging/DiffDeflectometry.
135 |
136 | ## Citation
137 |
138 | ```bibtex
139 | @article{wang2022dO,
140 | title={{dO: A differentiable engine for Deep Lens design of computational imaging systems}},
141 | author={Wang, Congli and Chen, Ni and Heidrich, Wolfgang},
142 | journal={IEEE Transactions on Computational Imaging},
143 | year={2022},
144 | volume={8},
145 | number={},
146 | pages={905-916},
147 | doi={10.1109/TCI.2022.3212837},
148 | publisher={IEEE}
149 | }
150 | ```
151 |
152 | ## Contact
153 | Please either open an issue, or contact Congli Wang for questions.
154 |
155 |
--------------------------------------------------------------------------------
/diffoptics/__init__.py:
--------------------------------------------------------------------------------
1 | from .version import __version__
2 |
3 | from .basics import *
4 | from .shapes import *
5 | from .optics import *
6 | from .solvers import *
7 |
--------------------------------------------------------------------------------
/diffoptics/basics.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | import torch
3 | import numpy as np
4 |
5 |
6 | class PrettyPrinter():
7 | def __str__(self):
8 | lines = [self.__class__.__name__ + ':']
9 | for key, val in vars(self).items():
10 | if val.__class__.__name__ in ('list', 'tuple'):
11 | for i, v in enumerate(val):
12 | lines += '{}[{}]: {}'.format(key, i, v).split('\n')
13 |
14 | elif val.__class__.__name__ in 'dict':
15 | pass
16 | elif key == key.upper() and len(key) > 5:
17 | pass
18 | else:
19 | lines += '{}: {}'.format(key, val).split('\n')
20 | return '\n '.join(lines)
21 |
22 | def to(self, device=torch.device('cpu')):
23 | for key, val in vars(self).items():
24 | if torch.is_tensor(val):
25 | exec('self.{x} = self.{x}.to(device)'.format(x=key))
26 | elif issubclass(type(val), PrettyPrinter):
27 | exec(f'self.{key}.to(device)')
28 | elif val.__class__.__name__ in ('list', 'tuple'):
29 | for i, v in enumerate(val):
30 | if torch.is_tensor(v):
31 | exec('self.{x}[{i}] = self.{x}[{i}].to(device)'.format(x=key, i=i))
32 | elif issubclass(type(v), PrettyPrinter):
33 | exec('self.{}[{}].to(device)'.format(key, i))
34 |
35 |
36 | class Ray(PrettyPrinter):
37 | """
38 | Definition of a geometric ray.
39 |
40 | - o is the ray origin
41 | - d is the ray direction (normalized)
42 | """
43 | def __init__(self, o, d, wavelength, device=torch.device('cpu')):
44 | self.o = o
45 | self.d = d
46 |
47 | # scalar-version
48 | self.wavelength = wavelength # [nm]
49 | self.mint = 1e-5 # [mm]
50 | self.maxt = 1e5 # [mm]
51 | self.to(device)
52 |
53 | def __call__(self, t):
54 | return self.o + t[..., None] * self.d
55 |
56 |
57 | class Transformation(PrettyPrinter):
58 | """
59 | Rigid Transformation.
60 |
61 | - R is the rotation matrix.
62 | - t is the translational vector.
63 | """
64 | def __init__(self, R, t):
65 | if torch.is_tensor(R):
66 | self.R = R
67 | else:
68 | self.R = torch.Tensor(R)
69 | if torch.is_tensor(t):
70 | self.t = t
71 | else:
72 | self.t = torch.Tensor(t)
73 |
74 | def transform_point(self, o):
75 | return torch.squeeze(self.R @ o[..., None]) + self.t
76 |
77 | def transform_vector(self, d):
78 | return torch.squeeze(self.R @ d[..., None])
79 |
80 | def transform_ray(self, ray):
81 | o = self.transform_point(ray.o)
82 | d = self.transform_vector(ray.d)
83 | if o.is_cuda:
84 | return Ray(o, d, ray.wavelength, device=torch.device('cuda'))
85 | else:
86 | return Ray(o, d, ray.wavelength)
87 |
88 | def inverse(self):
89 | RT = self.R.T
90 | t = self.t
91 | return Transformation(RT, -RT @ t)
92 |
93 |
94 | class Spectrum(PrettyPrinter):
95 | """
96 | Spectrum distribution of the rays.
97 | """
98 | def __init__(self):
99 | self.WAVELENGTH_MIN = 400 # [nm]
100 | self.WAVELENGTH_MAX = 760 # [nm]
101 | self.to()
102 |
103 | def sample_wavelength(self, sample):
104 | return self.WAVELENGTH_MIN + (self.WAVELENGTH_MAX - self.WAVELENGTH_MIN) * sample
105 |
106 |
107 | class Sampler(PrettyPrinter):
108 | """
109 | Sampler for generating random sample points.
110 | """
111 | def __init__(self):
112 | self.to()
113 | self.pi_over_2 = np.pi / 2
114 | self.pi_over_4 = np.pi / 4
115 |
116 | def concentric_sample_disk(self, x, y):
117 | # https://pbr-book.org/3ed-2018/Monte_Carlo_Integration/2D_Sampling_with_Multidimensional_Transformations
118 |
119 | # map uniform random numbers to [-1,1]^2
120 | x = 2 * x - 1
121 | y = 2 * y - 1
122 |
123 | # handle degeneracy at the origin when xy == [0,0]
124 |
125 | # apply concentric mapping to point
126 | eps = np.finfo(float).eps
127 |
128 | if type(x) is torch.Tensor and type(y) is torch.Tensor:
129 | cond = torch.abs(x) > torch.abs(y)
130 | r = torch.where(cond, x, y)
131 | theta = torch.where(cond,
132 | self.pi_over_4 * (y / (x + eps)),
133 | self.pi_over_2 - self.pi_over_4 * (x / (y + eps))
134 | )
135 | return r * torch.cos(theta), r * torch.sin(theta)
136 |
137 | if type(x) is np.ndarray and type(y) is np.ndarray:
138 | cond = np.abs(x) > np.abs(y)
139 | r = np.where(cond, x, y)
140 | theta = np.where(cond,
141 | self.pi_over_4 * (y / (x + eps)),
142 | self.pi_over_2 - self.pi_over_4 * (x / (y + eps))
143 | )
144 | return r * np.cos(theta), r * np.sin(theta)
145 |
146 |
147 | class Filter(PrettyPrinter):
148 | def __init__(self, radius):
149 | self.radius = radius
150 | def eval(self, p):
151 | raise NotImplementedError()
152 |
153 | class Box(Filter):
154 | def __init__(self, radius=None):
155 | if radius is None:
156 | radius = [0.5, 0.5]
157 | Filter.__init__(self, radius)
158 | def eval(self, x):
159 | return torch.ones_like(x)
160 |
161 | class Triangle(Filter):
162 | def __init__(self, radius):
163 | if radius is None:
164 | radius = [2.0, 2.0]
165 | Filter.__init__(self, radius)
166 | def eval(self, p):
167 | x, y = p[...,0], p[...,1]
168 | return (torch.maximum(torch.zeros_like(x), self.radius[0] - x) *
169 | torch.maximum(torch.zeros_like(y), self.radius[1] - y))
170 |
171 | # ----------------------------------------------------------------------------------------
172 |
173 | class Material(PrettyPrinter):
174 | """
175 | Optical materials for computing the refractive indices.
176 |
177 | The following follows the simple formula that
178 |
179 | n(\lambda) = A + B / \lambda^2
180 |
181 | where the two constants A and B can be computed from nD (index at 589.3 nm) and V (abbe number).
182 | """
183 | def __init__(self, name=None):
184 | self.name = 'vacuum' if name is None else name.lower()
185 |
186 | # This table is hard-coded. TODO: Import glass libraries from Zemax.
187 | self.MATERIAL_TABLE = { # [nD, Abbe number]
188 | "vacuum": [1., np.inf],
189 | "air": [1.000293, np.inf],
190 | "occluder": [1., np.inf],
191 | "f2": [1.620, 36.37],
192 | "f15": [1.60570, 37.831],
193 | "uvfs": [1.458, 67.82],
194 |
195 | # https://shop.schott.com/advanced_optics/
196 | "bk10": [1.49780, 66.954],
197 | "n-baf10": [1.67003, 47.11],
198 | "n-bk7": [1.51680, 64.17],
199 | "n-sf1": [1.71736, 29.62],
200 | "n-sf2": [1.64769, 33.82],
201 | "n-sf4": [1.75513, 27.38],
202 | "n-sf5": [1.67271, 32.25],
203 | "n-sf6": [1.80518, 25.36],
204 | "n-sf6ht": [1.80518, 25.36],
205 | "n-sf8": [1.68894, 31.31],
206 | "n-sf10": [1.72828, 28.53],
207 | "n-sf11": [1.78472, 25.68],
208 | "sf1": [1.71736, 29.51],
209 | "sf2": [1.64769, 33.85],
210 | "sf4": [1.75520, 27.58],
211 | "sf5": [1.67270, 32.21],
212 | "sf6": [1.80518, 25.43],
213 | "sf18": [1.72150, 29.245],
214 |
215 | # HIKARI.AGF
216 | "baf10": [1.67, 47.05],
217 |
218 | # SUMITA.AGF
219 | "sk1": [1.61030, 56.712],
220 | "sk16": [1.62040, 60.306],
221 | "ssk4": [1.61770, 55.116],
222 |
223 | # https://www.pgo-online.com/intl/B270.html
224 | "b270": [1.52290, 58.50],
225 |
226 | # https://refractiveindex.info, nD at 589.3 nm
227 | "s-nph1": [1.8078, 22.76],
228 | "d-k59": [1.5175, 63.50],
229 |
230 | "flint": [1.6200, 36.37],
231 | "pmma": [1.491756, 58.00],
232 | "polycarb": [1.585470, 30.00]
233 | }
234 | self.A, self.B = self._lookup_material()
235 |
236 | def ior(self, wavelength):
237 | """Computes index of refraction of a given wavelength (in [nm])"""
238 | return self.A + self.B / wavelength**2
239 |
240 | @staticmethod
241 | def nV_to_AB(n, V):
242 | def ivs(a): return 1./a**2
243 | lambdas = [656.3, 589.3, 486.1]
244 | B = 0.0 if V == 0 else (n - 1) / V / ( ivs(lambdas[2]) - ivs(lambdas[0]) )
245 | A = n - B * ivs(lambdas[1])
246 | return A, B
247 |
248 | def _lookup_material(self):
249 | out = self.MATERIAL_TABLE.get(self.name)
250 | if isinstance(out, list):
251 | n, V = out
252 | elif out is None:
253 | # try parsing input as a n/V pair
254 | tmp = self.name.split('/')
255 | n, V = float(tmp[0]), float(tmp[1])
256 | return self.nV_to_AB(n, V)
257 |
258 | def to_string(self):
259 | return f'{self.A} + {self.B}/lambda^2'
260 |
261 |
262 | class InterpolationMode(Enum):
263 | nearest = 1
264 | linear = 2
265 |
266 | class BoundaryMode(Enum):
267 | zero = 1
268 | replicate = 2
269 | symmetric = 3
270 | periodic = 4
271 |
272 | class SimulationMode(Enum):
273 | render = 1
274 | trace = 2
275 |
276 |
277 | """
278 | Utility functions.
279 | """
280 |
281 | def init():
282 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
283 | print("DiffOptics is using: {}".format(device))
284 | torch.set_default_tensor_type('torch.FloatTensor')
285 | return device
286 |
287 | def length2(d):
288 | return torch.sum(d**2, axis=-1)
289 |
290 | def length(d):
291 | return torch.sqrt(length2(d))
292 |
293 | def normalize(d):
294 | return d / length(d)[..., None]
295 |
296 | def set_zeros(x, valid=None):
297 | if valid == None:
298 | return torch.where(torch.isnan(x), torch.zeros_like(x), x)
299 | else:
300 | mask = valid[...,None] if len(x.shape) > len(valid.shape) else valid
301 | return torch.where(~mask, torch.zeros_like(x), x)
302 |
303 | def rodrigues_rotation_matrix(k, theta): # theta: [rad]
304 | """
305 | This function implements the Rodrigues rotation matrix.
306 | """
307 | # cross-product matrix
308 | kx, ky, kz = k[0], k[1], k[2]
309 | K = torch.Tensor([
310 | [ 0, -kz, ky],
311 | [ kz, 0, -kx],
312 | [-ky, kx, 0]
313 | ]).to(k.device)
314 | if not torch.is_tensor(theta):
315 | theta = torch.Tensor(np.asarray(theta)).to(k.device)
316 | return torch.eye(3, device=k.device) + torch.sin(theta) * K + (1 - torch.cos(theta)) * K @ K
317 |
318 | def set_axes_equal(ax, scale=np.ones(3)):
319 | """
320 | Make axes of 3D plot have equal scale (or scaled by `scale`).
321 | """
322 | limits = np.array([
323 | ax.get_xlim3d(),
324 | ax.get_ylim3d(),
325 | ax.get_zlim3d()
326 | ])
327 | tmp = np.abs(limits[:,1]-limits[:,0])
328 | ax.set_box_aspect(scale * tmp/np.min(tmp))
329 |
330 |
331 | """
332 | Test functions
333 | """
334 |
335 | def generate_test_rays():
336 | filmsize = np.array([4, 2])
337 |
338 | o = np.array([3,4,-200])
339 | o = np.tile(o[None, None, ...], [*filmsize, 1])
340 | o = torch.Tensor(o)
341 |
342 | dx = 0.1 * torch.rand(*filmsize)
343 | dy = 0.1 * torch.rand(*filmsize)
344 | d = normalize(torch.stack((dx, dy, torch.ones_like(dx)), axis=-1))
345 |
346 | wavelength = 500 # [nm]
347 | return Ray(o, d, wavelength)
348 |
349 | def generate_test_transformation():
350 | k = np.random.rand(3)
351 | k = k / np.sqrt(np.sum(k**2))
352 | k = torch.Tensor(k)
353 | theta = 1 # [rad]
354 | R = rodrigues_rotation_matrix(k, theta)
355 | t = np.random.rand(3)
356 | return Transformation(R, t)
357 |
358 | def generate_test_material():
359 | return Material('N-BK7')
360 |
361 |
362 | if __name__ == "__main__":
363 | init()
364 |
365 | rays = generate_test_rays()
366 | to_world = generate_test_transformation()
367 | print(to_world)
368 |
369 | rays_new = to_world.transform_ray(rays)
370 | o_old = rays.o[2,1,...].numpy()
371 | o_new = rays_new.o[2,1,...].numpy()
372 | assert (to_world.transform_point(o_old) - o_new).abs().sum() < 1e-15
373 |
374 | material = generate_test_material()
375 | print(material)
376 |
--------------------------------------------------------------------------------
/diffoptics/shapes.py:
--------------------------------------------------------------------------------
1 | from .basics import *
2 |
3 |
4 | class Endpoint(PrettyPrinter):
5 | """
6 | Abstract class for objects.
7 | """
8 | def __init__(self, transformation, device=torch.device('cpu')):
9 | self.to_world = transformation
10 | self.to_object = transformation.inverse()
11 | self.to(device)
12 | self.device = device
13 |
14 | def intersect(self, ray):
15 | raise NotImplementedError()
16 |
17 | def sample_ray(self, position_sample=None):
18 | raise NotImplementedError()
19 |
20 | def draw_points(self, ax, options, seq=range(3)):
21 | raise NotImplementedError()
22 |
23 | def update_Rt(self, R, t):
24 | self.to_world = Transformation(R, t)
25 | self.to_object = self.to_world.inverse()
26 | self.to(self.device)
27 |
28 |
29 | class Screen(Endpoint):
30 | """
31 | A screen obejct, useful for image rendering.
32 |
33 | Local frame centers at [-w, w]/2 x [-h, h]/2.
34 | """
35 | def __init__(self, transformation, size, texture, device=torch.device('cpu')):
36 | self.size = torch.Tensor(np.float32(size)) # screen dimension [mm]
37 | self.halfsize = self.size/2 # screen half-dimension [mm]
38 | self.texture_shift = torch.zeros(2) # screen image shift [mm]
39 | self.update_texture(texture)
40 | Endpoint.__init__(self, transformation, device)
41 | self.to(device)
42 |
43 | def update_texture(self, texture: torch.Tensor):
44 | self.texture = texture # screen image
45 | self.texturesize = torch.Tensor(np.array(texture.shape[0:2])).long().to(texture.device) # screen image dimension [pixel]
46 |
47 | def intersect(self, ray):
48 | ray_in = self.to_object.transform_ray(ray)
49 | t = - ray_in.o[..., 2] / (1e-10 + ray_in.d[..., 2]) # (TODO: potential NaN grad)
50 | local = ray_in(t)
51 |
52 | # Is intersection within ray segment and rectangle?
53 | valid = (
54 | (t >= ray_in.mint) &
55 | (t <= ray_in.maxt) &
56 | (torch.abs(local[..., 0] - self.texture_shift[0]) <= self.halfsize[0]) &
57 | (torch.abs(local[..., 1] - self.texture_shift[1]) <= self.halfsize[1])
58 | )
59 |
60 | # UV coordinate
61 | uv = (local[..., 0:2] + self.halfsize - self.texture_shift) / self.size
62 |
63 | # Force uv to be valid in [0,1]^2 (just a sanity check: uv should be in [0,1]^2)
64 | uv = torch.clamp(uv, min=0.0, max=1.0)
65 |
66 | return local, uv, valid
67 |
68 | def shading(self, uv, valid, bmode=BoundaryMode.replicate, lmode=InterpolationMode.linear):
69 | # p = uv * (self.texturesize[None, None, ...]-1)
70 | p = uv * (self.texturesize-1)
71 | p_floor = torch.floor(p).long()
72 |
73 | def tex(x, y):
74 | """
75 | Texture indexing function, handle various boundary conditions.
76 | """
77 | if bmode is BoundaryMode.zero:
78 | raise NotImplementedError()
79 | elif bmode is BoundaryMode.replicate:
80 | x = torch.clamp(x, min=0, max=self.texturesize[0].item()-1)
81 | y = torch.clamp(y, min=0, max=self.texturesize[1].item()-1)
82 | elif bmode is BoundaryMode.symmetric:
83 | raise NotImplementedError()
84 | elif bmode is BoundaryMode.periodic:
85 | raise NotImplementedError()
86 | img = self.texture[x.flatten(), y.flatten()]
87 | return img.reshape(x.shape)
88 |
89 | # Texture fetching, requires interpolation to compute fractional pixel values.
90 | if lmode is InterpolationMode.nearest:
91 | val = tex(p_floor[...,0], p_floor[...,1])
92 | elif lmode is InterpolationMode.linear:
93 | x0, y0 = p_floor[...,0], p_floor[...,1]
94 | s00 = tex( x0, y0)
95 | s01 = tex( x0, 1+y0)
96 | s10 = tex(1+x0, y0)
97 | s11 = tex(1+x0, 1+y0)
98 | w1 = p - p_floor
99 | w0 = 1. - w1
100 | val = (
101 | w0[...,0] * (w0[...,1] * s00 + w1[...,1] * s01) +
102 | w1[...,0] * (w0[...,1] * s10 + w1[...,1] * s11)
103 | )
104 |
105 | # val = val * valid
106 | # val[torch.isnan(val)] = 0.0
107 |
108 | # TODO: should be added;
109 | # but might cause RuntimeError: Function 'MulBackward0' returned nan values in its 0th output.
110 | val[~valid] = 0.0
111 | return val
112 |
113 | def draw_points(self, ax, options, seq=range(3)):
114 | """
115 | Visualization function.
116 | """
117 | coeffs = np.array([
118 | [ 1, 1, 1],
119 | [-1, 1, 1],
120 | [-1,-1, 1],
121 | [ 1,-1, 1],
122 | [ 1, 1, 1]
123 | ])
124 | points_local = torch.Tensor(coeffs * np.append(self.halfsize.cpu().detach().numpy(), 0)).to(self.device)
125 | points_world = self.to_world.transform_point(points_local).T.cpu().detach().numpy()
126 | ax.plot(points_world[seq[0]], points_world[seq[1]], points_world[seq[2]], options)
127 |
128 |
--------------------------------------------------------------------------------
/diffoptics/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.1.0'
--------------------------------------------------------------------------------
/examples/autodiff.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import matplotlib.pyplot as plt
5 |
6 | import sys
7 | sys.path.append("../")
8 | import diffoptics as do
9 |
10 | # initialize a lens
11 | device = torch.device('cpu')
12 | lens = do.Lensgroup(device=device)
13 |
14 | save_dir = './autodiff_demo/'
15 | if not os.path.exists(save_dir):
16 | os.mkdir(save_dir)
17 |
18 | R = 12.7
19 | surfaces = [
20 | do.Aspheric(R, 0.0, c=0.05, device=device),
21 | do.Aspheric(R, 6.5, c=0., device=device)
22 | ]
23 | materials = [
24 | do.Material('air'),
25 | do.Material('N-BK7'),
26 | do.Material('air')
27 | ]
28 | lens.load(surfaces, materials)
29 | lens.d_sensor = 25.0
30 | lens.r_last = 12.7
31 |
32 | # generate array of rays
33 | wavelength = torch.Tensor([532.8]).to(device) # [nm]
34 | R = 10.0 # [mm]
35 | def render():
36 | ray_init = lens.sample_ray(wavelength, M=9, R=R, sampling='grid')
37 | ps = lens.trace_to_sensor(ray_init)
38 | return ps[...,:2]
39 |
40 | def trace_all():
41 | ray_init = lens.sample_ray_2D(R, wavelength, M=11)
42 | ps, oss = lens.trace_to_sensor_r(ray_init)
43 | return ps[...,:2], oss
44 |
45 | def compute_Jacobian(ps):
46 | Js = []
47 | for i in range(1):
48 | J = torch.zeros(torch.numel(ps))
49 | for j in range(torch.numel(ps)):
50 | mask = torch.zeros(torch.numel(ps))
51 | mask[j] = 1
52 | ps.backward(mask.reshape(ps.shape), retain_graph=True)
53 | J[j] = lens.surfaces[i].c.grad.item()
54 | lens.surfaces[i].c.grad.data.zero_()
55 | J = J.reshape(ps.shape)
56 |
57 | # get data to numpy
58 | Js.append(J.cpu().detach().numpy())
59 | return Js
60 |
61 |
62 | N = 20
63 | cs = np.linspace(0.045, 0.063, N)
64 | Iss = []
65 | Jss = []
66 | for index, c in enumerate(cs):
67 | index_string = str(index).zfill(3)
68 | # load optics
69 | lens.surfaces[0].c = torch.Tensor(np.array(c))
70 | lens.surfaces[0].c.requires_grad = True
71 |
72 | # show trace figure
73 | ps, oss = trace_all()
74 | ax, fig = lens.plot_raytraces(oss, color='b-', show=False)
75 | ax.axis('off')
76 | ax.set_title("")
77 | fig.savefig(save_dir + "layout_trace_" + index_string + ".png", bbox_inches='tight')
78 |
79 | # show spot diagram
80 | RMS = lambda ps: torch.sqrt(torch.mean(torch.sum(torch.square(ps), axis=-1)))
81 | ps = render()
82 | rms_org = RMS(ps)
83 | print(f'RMS: {rms_org}')
84 | lens.spot_diagram(ps, xlims=[-4, 4], ylims=[-4, 4], savepath=save_dir + "spotdiagram_" + index_string + ".png", show=False)
85 |
86 | # compute Jacobian
87 | Js = compute_Jacobian(ps)[0]
88 | print(Js.max())
89 | print(Js.min())
90 | ps_ = ps.cpu().detach().numpy()
91 | fig = plt.figure()
92 | x, y = ps_[:,0], ps_[:,1]
93 | plt.plot(x, y, 'b.', zorder=0)
94 | plt.quiver(x, y, Js[:,0], Js[:,1], color='b', zorder=1)
95 | plt.xlim(-4, 4)
96 | plt.ylim(-4, 4)
97 | plt.gca().set_aspect('equal', adjustable='box')
98 | plt.xlabel('x [mm]')
99 | plt.ylabel('y [mm]')
100 | fig.savefig(save_dir + "flow_" + index_string + ".png", bbox_inches='tight')
101 |
102 | # compute images
103 | ray = lens.sample_ray(wavelength.item(), view=0.0, M=2049, sampling='grid')
104 | lens.film_size = [512, 512]
105 | lens.pixel_size = 50.0e-3/2
106 | I = lens.render(ray)
107 | I = I.cpu().detach().numpy()
108 | lm = do.LM(lens, ['surfaces[0].c'], 1e-2, option='diag')
109 | JI = lm.jacobian(lambda: lens.render(ray)).squeeze()
110 | J = JI.abs().cpu().detach().numpy()
111 |
112 | Iss.append(I)
113 | Jss.append(J)
114 | plt.close()
115 |
116 | Iss = np.array(Iss)
117 | Jss = np.array(Jss)
118 | for i in range(N):
119 | plt.imsave(save_dir + "I_" + str(i).zfill(3) + ".png", Iss[i], cmap='gray')
120 | plt.imsave(save_dir + "J_" + str(i).zfill(3) + ".png", Jss[i], cmap='gray')
121 |
122 | names = [
123 | 'spotdiagram',
124 | 'layout_trace',
125 | 'I',
126 | 'J',
127 | 'flow'
128 | ]
129 |
--------------------------------------------------------------------------------
/examples/backprop_compare.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import matplotlib.pyplot as plt
4 |
5 | import sys
6 | sys.path.append("../")
7 | import diffoptics as do
8 |
9 |
10 | device = torch.device('cuda')
11 |
12 | # initialize a lens
13 | def init():
14 | lens = do.Lensgroup(device=device)
15 |
16 | R = 12.7
17 | surfaces = [
18 | do.Aspheric(R, 0.0, c=0.05, device=device),
19 | do.Aspheric(R, 6.5, c=0., device=device)
20 | ]
21 | materials = [
22 | do.Material('air'),
23 | do.Material('N-BK7'),
24 | do.Material('air')
25 | ]
26 | lens.load(surfaces, materials)
27 | lens.d_sensor = 25.0
28 | lens.r_last = 12.7
29 | lens.film_size = [256, 256]
30 | lens.pixel_size = 100.0e-3/2
31 | lens.surfaces[1].ai = torch.zeros(2, device=device)
32 | return lens
33 |
34 | def baseline(network_func, rays):
35 | lens = init()
36 | lens.surfaces[0].c.requires_grad = True
37 | lens.surfaces[1].ai.requires_grad = True
38 |
39 | I = 0.0
40 | for ray in rays:
41 | I = I + lens.render(ray)
42 |
43 | L = network_func(I)
44 | L.backward()
45 | print("Baseline:")
46 | print("primal: {}".format(L))
47 | print("derivatives: {}".format([lens.surfaces[0].c.grad, lens.surfaces[1].ai.grad]))
48 | return float(torch.cuda.memory_allocated() / (1024 * 1024))
49 |
50 | def ours_new(network_func, rays):
51 | lens = init()
52 | adj = do.Adjoint(
53 | lens, ['surfaces[0].c', 'surfaces[1].ai'],
54 | network_func, lens.render, rays
55 | )
56 |
57 | L_item, grads = adj()
58 |
59 | print("Ours:")
60 | print("primal: {}".format(L_item))
61 | print("derivatives: {}".format(grads))
62 | torch.cuda.empty_cache()
63 | return float(torch.cuda.memory_allocated() / (1024 * 1024))
64 |
65 | # Initialize a lens
66 | lens = init()
67 |
68 | # generate array of rays
69 | wavelength = torch.Tensor([532.8]).to(device) # [nm]
70 |
71 | def prepare_rays(view):
72 | ray = lens.sample_ray(wavelength.item(), view=view, M=2000+1, sampling='grid')
73 | return ray
74 |
75 | # define a network
76 | torch.manual_seed(0)
77 | I_ref = torch.rand(lens.film_size, device=device)
78 | def network_func(I):
79 | return ((I - I_ref)**2).mean()
80 |
81 | # timings
82 | start = torch.cuda.Event(enable_timing=True)
83 | end = torch.cuda.Event(enable_timing=True)
84 |
85 | # compares
86 | views = [1, 3, 5, 7, 9, 11, 13, 15]
87 | max_views = len(views)
88 |
89 | num_rayss = np.zeros(max_views)
90 | time = np.zeros((max_views, 2))
91 | memory = np.zeros((max_views, 2))
92 | for i, num_views in enumerate(views):
93 | print("view = {}".format(num_views))
94 | views = np.linspace(0, 1, num_views)
95 | num_rays = num_views * 2001**2 / 1e6
96 | num_rayss[i] = num_rays
97 |
98 | # prepare rays
99 | rays = [prepare_rays(view) for view in views]
100 |
101 | # Baseline
102 | try:
103 | start.record()
104 | memory[i,0] = baseline(network_func, rays)
105 | end.record()
106 | torch.cuda.synchronize()
107 | print("Baseline time: {:.3f} s".format(start.elapsed_time(end)*1e-3))
108 | time[i,0] = start.elapsed_time(end)
109 | except:
110 |         print('Baseline: Memory insufficient! Stop running for this case!')
111 | time[i,0] = np.nan
112 | memory[i,0] = np.nan
113 |
114 | # Ours
115 | start.record()
116 | memory[i,1] = ours_new(network_func, rays)
117 | end.record()
118 | torch.cuda.synchronize()
119 | print("Ours (adjoint-based) time: {:.3f} s".format(start.elapsed_time(end)*1e-3))
120 | time[i,1] = start.elapsed_time(end)
121 |
122 |
123 | # show results
124 | fig = plt.figure()
125 | plt.plot(num_rayss, time, '-o')
126 | plt.title("Time Comparison")
127 | plt.xlabel("Number of rays [Millions]")
128 | plt.ylabel("Computation time [Seconds]")
129 | plt.legend(["Baseline (backpropagation)", "Ours (adjoint-based)"])
130 | plt.show()
131 |
--------------------------------------------------------------------------------
/examples/caustic_pyramid.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import matplotlib.pyplot as plt
5 | from matplotlib.image import imread, imsave
6 | from skimage.transform import resize
7 |
8 | import sys
9 | sys.path.append("../")
10 | import diffoptics as do
11 |
12 | # initialize a lens
13 | device = do.init()
14 | # device = torch.device('cpu')
15 | lens = do.Lensgroup(device=device)
16 |
17 | # construct freeform optics
18 | R = 25.4
19 | ns = [256, 256]
20 | surfaces = [
21 | do.Aspheric(R, 0.0, c=0., is_square=True, device=device),
22 | do.Mesh(R, 1.0, ns, is_square=True, device=device)
23 | ]
24 | materials = [
25 | do.Material('air'),
26 | do.Material('N-BK7'),
27 | do.Material('air')
28 | ]
29 | lens.load(surfaces, materials)
30 |
31 | # set scene geometry
32 | D = torch.Tensor([50.0]).to(device) # [mm]
33 | wavelength = torch.Tensor([532.8]).to(device) # [nm]
34 |
35 | # example image
36 | filename = 'einstein'
37 | img_org = imread('./images/' + filename + '.jpg') # assume image is grayscale
38 | if img_org.mean() > 1.0:
39 | img_org = img_org / 255.0
40 |
41 | # downsample the image
42 | NN = 2
43 | img_org = img_org[::NN,::NN]
44 | N_max = 128
45 | img_org = img_org[:N_max,:N_max]
46 |
47 | # mark differentiable variables
48 | lens.surfaces[1].c.requires_grad = True
49 |
50 | # create save dir
51 | savepath = './einstein_pyramid/'
52 | if not os.path.exists(savepath):
53 | os.mkdir(savepath)
54 |
55 | def caustic(N, pyramid_i, lr=1e-3, maxit=100):
56 | img = resize(img_org, (N, N))
57 | E = np.sum(img) # total energy
58 | print(f'image size = {img.shape}')
59 |
60 | N_pad = 0
61 | N_total = N + 2*N_pad
62 | img = np.pad(img, (N_pad,N_pad), 'constant', constant_values=np.inf)
63 | img[np.isinf(img)] = 0.0 # revert img back for visualization
64 | I_ref = torch.Tensor(img).to(device) # [mask]
65 |
66 | # max square length
67 | R_square = R * N_total/N
68 |
69 | # set image plane pixel grid
70 | R_image = R_square
71 | pixel_size = 2*R_image / N_total # [mm]
72 |
73 | def sample_ray(M=1, random=False):
74 | M = int(M*N)
75 | x, y = torch.meshgrid(
76 | torch.linspace(-R_square, R_square, M, device=device),
77 | torch.linspace(-R_square, R_square, M, device=device)
78 | )
79 | p = 2*R_square / M
80 | if random:
81 | x = x + p * (torch.rand(M,M,device=device)-0.5)
82 | y = y + p * (torch.rand(M,M,device=device)-0.5)
83 | o = torch.stack((x,y,torch.zeros_like(x, device=device)), axis=2)
84 | d = torch.zeros_like(o)
85 | d[...,2] = torch.ones_like(x)
86 | return do.Ray(o, d, wavelength, device=device), E
87 |
88 | def render_single(I, ray_init, irr):
89 | ray, valid = lens.trace(ray_init)[:2]
90 | J = irr * valid * ray.d[...,2]
91 | p = ray(D)
92 | p = p[...,:2]
93 | del ray, valid
94 |
95 | # compute shifts and do linear interpolation
96 | uv = (p + R_square) / pixel_size
97 | index_l = torch.clamp(torch.floor(uv).long(), min=0, max=N_total-1)
98 | index_r = torch.clamp(index_l + 1, min=0, max=N_total-1)
99 | w_r = torch.clamp(uv - index_l, min=0, max=1)
100 | w_l = 1.0 - w_r
101 | del uv
102 |
103 | # compute image
104 | I = torch.index_put(I, (index_l[...,0],index_l[...,1]), w_l[...,0]*w_l[...,1]*J, accumulate=True)
105 | I = torch.index_put(I, (index_r[...,0],index_l[...,1]), w_r[...,0]*w_l[...,1]*J, accumulate=True)
106 | I = torch.index_put(I, (index_l[...,0],index_r[...,1]), w_l[...,0]*w_r[...,1]*J, accumulate=True)
107 | I = torch.index_put(I, (index_r[...,0],index_r[...,1]), w_r[...,0]*w_r[...,1]*J, accumulate=True)
108 | return I
109 |
110 | def render(spp=1):
111 | I = torch.zeros((N_total,N_total), device=device)
112 | ray_init, irr = sample_ray(M=24, random=True) # Reduce M if your GPU memory is low
113 | I = render_single(I, ray_init, irr)
114 | return I / spp
115 |
116 | # optimize
117 | ls = []
118 |
119 | save_path = savepath + "/{}".format("pyramid_" + str(pyramid_i))
120 | if not os.path.exists(save_path):
121 | os.makedirs(save_path)
122 |
123 | print('optimizing ...')
124 | optimizer = torch.optim.Adam([lens.surfaces[1].c], lr=lr, betas=(0.99,0.99), amsgrad=True)
125 |
126 | for it in range(maxit+1):
127 | I = render(spp=8)
128 | I = I / I.sum() * I_ref.sum()
129 | L = torch.mean((I - I_ref)**2)
130 | optimizer.zero_grad()
131 | L.backward(retain_graph=True)
132 |
133 | # record
134 | ls.append(L.cpu().detach().numpy())
135 | if it % 10 == 0:
136 | print('iter = {}: loss = {:.4e}, grad_bar = {:.4e}'.format(
137 | it, L.item(), torch.sum(torch.abs(lens.surfaces[1].c.grad))
138 | ))
139 | I_current = I.cpu().detach().numpy()
140 | imsave("{}/{:04d}.png".format(save_path, it), I_current, vmin=0.0, vmax=1.0, cmap='gray')
141 |
142 | # descent
143 | optimizer.step()
144 |
145 | if pyramid_i == 0: # last one, render final image
146 | lens.surfaces[1].c.requires_grad = False
147 | del L
148 | I_final = 0
149 | spp = 100
150 | for i in range(spp):
151 | if i % 10 == 0:
152 | print("=== rendering spp = {}".format(i))
153 | I_final += render().cpu().detach().numpy()
154 | return I_final / spp, I_ref, ls
155 | else:
156 | return I.cpu().detach().numpy(), None, ls
157 |
158 | pyramid_levels = 2
159 | for i in range(pyramid_levels, -1, -1):
160 | N = int(N_max/(2**i))
161 | print("=== N = {}".format(N))
162 | I_final, I_ref, ls = caustic(N, i, lr=1e-3, maxit=int(1000/4**i))
163 | if i == 0:
164 | I_ref = I_ref.cpu().numpy()
165 | I_final = I_final / I_final.sum() * I_ref.sum()
166 |
167 | imsave(savepath + "/I_target.png", I_ref, vmin=0.0, vmax=1.0, cmap='gray')
168 | imsave(savepath + "/I_final.png", I_final, vmin=0.0, vmax=1.0, cmap='gray')
169 |
170 | # final results
171 | plt.imshow(I_final, cmap='gray')
172 | plt.title('Final caustic image')
173 | plt.show()
174 |
175 | fig, ax = plt.subplots()
176 | ax.plot(ls, 'k-o', linewidth=2)
177 | ax.set_xlabel('iteration')
178 | ax.set_ylabel('loss')
179 | fig.savefig("ls.pdf", bbox_inches='tight')
180 | plt.title('Loss')
181 |
182 | S = lens.surfaces[1].mesh().cpu().detach().numpy()
183 | S = S - S.min()
184 | imsave(savepath + "/phase.png", S, vmin=0, vmax=S.max(), cmap='coolwarm')
185 | imsave(savepath + "/phase_mod.png", np.mod(S*1e3,100), cmap='coolwarm')
186 | print(S.max())
187 |
188 | plt.figure()
189 | plt.imshow(S, cmap='jet')
190 | plt.colorbar()
191 | plt.title('Optimized phase plate height [mm]')
192 | plt.show()
193 |
194 |
--------------------------------------------------------------------------------
/examples/data/20210304/ref2.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/data/20210304/ref2.tif
--------------------------------------------------------------------------------
/examples/end2end_edof_backward_tracing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import matplotlib.pyplot as plt
4 | from pathlib import Path
5 | from tqdm import tqdm
6 | from datetime import datetime
7 |
8 | import sys
9 | sys.path.append("../")
10 | import diffoptics as do
11 | from utils_end2end import dict_to_tensor, tensor_to_dict, load_deblurganv2, ImageFolder
12 | torch.manual_seed(0)
13 |
14 | # Initialize a lens
15 | device = torch.device('cuda')
16 | lens = do.Lensgroup(device=device)
17 |
18 | # Load optics
19 | lens.load_file(Path('./lenses/end2end/end2end_edof.txt')) # nominal design
20 | lens.plot_setup2D()
21 | [surface.to(device) for surface in lens.surfaces]
22 |
23 | # set sensor pixel size and film size
24 | downsample_factor = 4 # downsampled for run
25 | pixel_size = downsample_factor * 3.45e-3 # [mm]
26 | film_size = [512 // downsample_factor, 512 // downsample_factor]
27 | lens.prepare_mts(pixel_size, film_size)
28 | print('Check your lens:')
29 | print(lens)
30 |
31 | # sample wavelengths in [nm]
32 | wavelengths = [656.2725, 587.5618, 486.1327]
33 |
34 | def create_screen(texture: torch.Tensor, z: float, pixelsize: float) -> do.Screen:
35 | texturesize = np.array(texture.shape[0:2])
36 | screen = do.Screen(
37 | do.Transformation(np.eye(3), np.array([0, 0, z])),
38 | texturesize * pixelsize, texture, device=device
39 | )
40 | return screen
41 |
42 | def render_single(wavelength: float, screen: do.Screen, sample_ray_function, images: list[torch.Tensor]):
43 | valid, ray_new = sample_ray_function(wavelength)
44 | uv, valid_screen = screen.intersect(ray_new)[1:]
45 | mask = valid & valid_screen
46 |
47 | # Render a batch of images
48 | I_batch = []
49 | for image in images:
50 | screen.update_texture(image[..., wavelengths.index(wavelength)])
51 | I_batch.append(screen.shading(uv, mask))
52 | return torch.stack(I_batch, axis=0), mask
53 |
54 | def render(screen: do.Screen, images: list[torch.Tensor], ray_counts_per_pixel: int) -> torch.Tensor:
55 | Is = []
56 | for wavelength in wavelengths:
57 | I = 0
58 | M = 0
59 | for i in range(ray_counts_per_pixel):
60 | I_current, mask = render_single(wavelength, screen, lambda x : lens.sample_ray_sensor(x), images)
61 | I = I + I_current
62 | M = M + mask
63 | I = I / (M[None, ...] + 1e-10)
64 | I = I.reshape((len(images), *np.flip(np.asarray(film_size)))).permute(0,2,1)
65 | Is.append(I)
66 | return torch.stack(Is, axis=-1)
67 |
68 | focal_length = 102 # [mm]
69 | def render_gt(screen: do.Screen, images: list[torch.Tensor]) -> torch.Tensor:
70 | Is = []
71 | for wavelength in wavelengths:
72 | I, mask = render_single(wavelength, screen, lambda x : lens.sample_ray_sensor_pinhole(x, focal_length), images)
73 | I = I.reshape((len(images), *np.flip(np.asarray(film_size)))).permute(0,2,1)
74 | Is.append(I)
75 | return torch.stack(Is, axis=-1)
76 |
77 |
78 | ## Set differentiable optical parameters
79 | # XY_surface = (
80 | # a[0] +
81 | # a[1] * x + a[2] * y +
82 | # a[3] * x**2 + a[4] * x*y + a[5] * y**2 +
83 | # a[6] * x**3 + a[7] * x**2*y + a[8] * x*y**2 + a[9] * y**3
84 | # )
85 | # We optimize for a cubic profile (i.e. 3rd-order coefficients), as in the wavefront coding technique.
86 | diff_parameters = [
87 | lens.surfaces[0].ai
88 | ]
89 | learning_rates = {
90 | 'surfaces[0].ai': 1e-15 * torch.Tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1]).to(device)
91 | }
92 | for diff_para, key in zip(diff_parameters, learning_rates.keys()):
93 | if len(diff_para) != len(learning_rates[key]):
94 | raise Exception('Learning rates of {} is not of equal length to the parameters!'.format(key))
95 | diff_para.requires_grad = True
96 | diff_parameter_labels = learning_rates.keys()
97 |
98 |
99 | ## Create network
100 | net = load_deblurganv2()
101 | net.prepare()
102 |
103 | print('Initial:')
104 | current_parameters = [x.detach().cpu().numpy() for x in diff_parameters]
105 | print('Current optical parameters are:')
106 | for x, label in zip(current_parameters, diff_parameter_labels):
107 | print('-- lens.{}: {}'.format(label, x))
108 |
109 | # Training dataset
110 | train_path = './training_dataset/'
111 | train_dataloader = torch.utils.data.DataLoader(ImageFolder(train_path), batch_size=1, shuffle=False)
112 | it = iter(train_dataloader)
113 | image = next(it).squeeze().to(device)
114 |
115 | # Training settings
116 | settings = {
117 | 'spp_forward': 100, # Rays per pixel for forward
118 | 'spp_backward': 20, # Rays per pixel for a single-pass backward
119 | 'num_passes': 5, # Number of accumulation passes for the backward
120 | 'image_batch_size': 5, # Images per batch
121 | 'network_training_iter': 200, # Training iterations for network update
122 | 'num_of_training': 10, # Training outer loop iteration
123 | 'savefig': True # Save intermediate results
124 | }
125 |
126 | if settings['savefig']:
127 | opath = Path('end2end_output') / str(datetime.now().strftime("%Y-%m-%d-%H-%M-%S"))
128 | opath.mkdir(parents=True, exist_ok=True)
129 |
130 | def wrapper_func(screen, images, squeezed_diff_parameters, diff_parameters, diff_parameter_labels):
131 | unpacked_diff_parameters = tensor_to_dict(squeezed_diff_parameters, diff_parameters)
132 | for idx, label in enumerate(diff_parameter_labels):
133 | exec('lens.{} = unpacked_diff_parameters[{}]'.format(label, idx))
134 | return render(screen, images, settings['spp_forward'])
135 |
136 |
137 | # Physical parameters for the screen
138 | zs = [8e3, 6e3, 4.5e3] # [mm]
139 | pixelsizes = [0.1 * z/6e3 for z in zs] # [mm]
140 |
141 | print('Training starts ...')
142 | for iteration in range(settings['num_of_training']):
143 | for z_idx, z in enumerate(zs):
144 |
145 | # Print current status
146 | current_parameters = [x.detach().cpu().numpy() for x in diff_parameters]
147 | print('=========')
148 | print('Iteration = {}, z = {} [mm]:'.format(iteration, z))
149 | print('Current optical parameters are:')
150 | for x, label in zip(current_parameters, diff_parameter_labels):
151 | print('-- lens.{}: {}'.format(label, x))
152 | print('=========')
153 |
154 | # Put screen at a desired distance (and with a proper pixel size)
155 | screen = create_screen(image, z, pixelsizes[z_idx])
156 |
157 | # Forward rendering
158 | tq = tqdm(range(settings['image_batch_size']))
159 | tq.set_description('(1) Rendering batch images')
160 |
161 | # Load image batch (multiple images)
162 | images = []
163 | for image_idx in tq:
164 | try:
165 | data = next(it)
166 | except StopIteration:
167 | it = iter(train_dataloader)
168 | data = next(it)
169 | image = data.squeeze().to(device)
170 | images.append(image.clone())
171 |
172 | with torch.no_grad():
173 | Is = render(screen, images, settings['spp_forward'])
174 | Is_gt = render_gt(screen, images)
175 | tq.close()
176 |
177 | # Save images for visualization
178 | Is_view = np.concatenate([I.cpu().numpy().astype(np.uint8) for I in Is], axis=1)
179 | Is_gt_view = np.concatenate([I.cpu().numpy().astype(np.uint8) for I in Is_gt], axis=1)
180 |
181 | # Reorder tensors to match neural network input format
182 | Is = 2 * torch.permute(Is, (0, 3, 1, 2)) / 255 - 1
183 | Is_gt = 2 * torch.permute(Is_gt, (0, 3, 1, 2)) / 255 - 1
184 |
185 | # Train network weights
186 | Is_output = net.run(
187 | Is, Is_gt, is_inference=False,
188 | num_iters=settings['network_training_iter'], desc='(2) Training network weights'
189 | )
190 | Is_output_np = np.transpose(255/2 * (Is_output.detach().cpu().numpy() + 1), (0,2,3,1)).astype(np.uint8)
191 | Is_output_view = np.concatenate([I for I in Is_output_np], axis=1)
192 | del Is_output_np
193 |
194 | if settings['savefig']:
195 | fig, axs = plt.subplots(3, 1)
196 | for idx, I_view, label in zip(
197 | range(3), [Is_view, Is_gt_view, Is_output_view], ['Input', 'Ground truth', 'Network output']
198 | ):
199 | axs[idx].imshow(I_view)
200 | axs[idx].set_title(label + ' image(s)')
201 | axs[idx].set_axis_off()
202 | fig.tight_layout()
203 | fig.savefig(
204 | str(opath / 'iter_{}_z={}mm_images.png'.format(iteration, z)),
205 | dpi=400, bbox_inches='tight', pad_inches=0.1
206 | )
207 | fig.clear()
208 | plt.close(fig)
209 |
210 | # Back-propagate backend loss and obtain adjoint gradients
211 | Is.requires_grad = True
212 | Is_output = net.run(Is, Is_gt, is_inference=False, num_iters=1)
213 |
214 | # Get adjoint gradients of the image batch
215 | Is_grad = Is.grad.permute(0, 2, 3, 1)
216 | del Is, Is_gt, Is_output
217 | torch.cuda.empty_cache()
218 |
219 | # Back-propagate optical parameters with adjoint gradients, and accumulate
220 | tq = tqdm(range(settings['num_passes']))
221 | tq.set_description('(3) Back-prop optical parameters')
222 | dthetas = torch.zeros_like(dict_to_tensor(diff_parameters)).detach()
223 | for inner_iteration in tq:
224 | dthetas += torch.autograd.functional.vjp(
225 | lambda x : wrapper_func(screen, images, x, diff_parameters, diff_parameter_labels),
226 | dict_to_tensor(diff_parameters), Is_grad
227 | )[1]
228 | tq.close()
229 |
230 | # Update optical parameters
231 | with torch.no_grad():
232 | for label, diff_para, dtheta in zip(
233 | diff_parameter_labels, diff_parameters, tensor_to_dict(dthetas, diff_parameters)
234 | ):
235 | diff_para -= learning_rates[label] * dtheta.squeeze() / settings['num_passes']
236 | diff_para.grad = None
237 |
--------------------------------------------------------------------------------
/examples/images/einstein.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/images/einstein.jpg
--------------------------------------------------------------------------------
/examples/images/point.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/images/point.tif
--------------------------------------------------------------------------------
/examples/images/squirrel.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/images/squirrel.jpg
--------------------------------------------------------------------------------
/examples/lenses/DoubleGauss/US02532751-1.txt:
--------------------------------------------------------------------------------
1 | Originate from US02532751-1
2 | type distance roc diameter material
3 | O 0.0 0.0 0.0 VACUUM
4 | S 0.0 13.354729316461547 17.4 SSK4 0.005 1e-6 1e-8 -3e-10
5 | S 2.2352 35.64148197667863 17.4 VACUUM
6 | S 0.0762 10.330017837998932 14.0 SK1
7 | S 3.1750 0.0 14.0 F15
8 | S 0.9652 6.494496063151893 9.0 VACUUM
9 | A 3.8608 0.0 4.886 OCCLUDER
10 | S 3.302 -7.026950339915501 9.0 F15
11 | S 0.9652 0.0 12.0 SK16
12 | S 2.7686 -9.746574604143909 12.0 VACUUM
13 | S 0.0762 69.81692521236866 14.0 SK16
14 | S 1.7526 -19.226275376106166 14.0 VACUUM
15 | I 17.42769142705 0.0 20.664 VACUUM
16 |
--------------------------------------------------------------------------------
/examples/lenses/DoubleGauss/US02532751-1.zmx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/lenses/DoubleGauss/US02532751-1.zmx
--------------------------------------------------------------------------------
/examples/lenses/README.md:
--------------------------------------------------------------------------------
1 | ## Materials
2 |
3 | We model all materials with Cauchy's coefficients. Predefined materials can be found in [`../../diffoptics/basics.py`](https://github.com/vccimaging/DiffOptics/blob/7c30e22967280c97fb2953db0bcf894611df37b7/diffoptics/basics.py#L188-L234).
4 |
5 | Custom materials can be specified as `nD/V`, i.e. the refractive index at the d-line (589.3 nm) followed by the Abbe number, for example:
6 |
7 | ```
8 | 1.66565/35.64
9 | ```
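Such a string is parsed into an (nD, V) pair and converted to the Cauchy coefficients. A quick check (a usage sketch, assuming the repository root is on the Python path):

```python
import diffoptics as do

m = do.Material('1.66565/35.64')  # custom glass given as nD/V
print(m.to_string())              # the A + B/lambda^2 form
print(m.ior(589.3))               # ~1.66565 at the d-line
```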
10 |
11 | ## Format
12 |
13 | All surfaces follow the sequence of entries:
14 |
15 | ```
16 | type distance roc diameter material
17 | ```
18 |
19 | For each specific surface type, extra parameters are appended after the last column:
20 |
21 | ```
22 | Thorlabs-AC508-1000-A
23 | type distance roc diameter material
24 | O 0 0 100 AIR
25 | S 0 757.9 50.8 N-BK7
26 | S 6.0 -364.7 50.8 N-SF2
27 | S 6.0 -954.2 50.8 AIR
28 | I 996.4 0 50.8 AIR
29 | ```
30 | where the symbols mean:
31 |
32 | | Symbol | Surface type | Note |
33 | | :----: | :----------: | :------------: |
34 | | `O` | Asphere | Object plane |
35 | | `S` | Asphere | Lens surface |
36 | | `A` | Asphere | Aperture plane |
37 | | `I` | Asphere | Image plane |
38 |
39 |
--------------------------------------------------------------------------------
/examples/lenses/ThorLabs/ACL5040U.txt:
--------------------------------------------------------------------------------
1 | Thorlabs-ACL5040U
2 | type distance roc diameter material
3 | O 20 0 100 AIR
4 | S 0 20.923 50.0 B270 -0.6405 2.0e-06
5 | S 21.0 0 50.0 AIR
6 | I 26.0 0 50.0 AIR
7 |
--------------------------------------------------------------------------------
/examples/lenses/ThorLabs/LA1131.txt:
--------------------------------------------------------------------------------
1 | Thorlabs-LA1131
2 | type distance roc diameter material
3 | O 20 0 100 AIR
4 | S 0 25.75 25.4 N-BK7
5 | S 5.34 0 25.4 AIR
6 | I 4.42 0 25.4 AIR
--------------------------------------------------------------------------------
/examples/lenses/Zemax_samples/Nikon-z35-f1.8-JPA2019090949-example2.txt:
--------------------------------------------------------------------------------
1 | Originated from Zemax, modified by Congli Wang
2 | type distance roc diameter material
3 | O 0 0 0 AIR
4 | S 0 5.267 1.694 1.5168/64.12
5 | S 0.102 0.961 1.392 AIR
6 | S 0.309 1.442 1.322 1.9027/35.72
7 | S 0.246 10.280 1.250 1.5955/39.21
8 | S 0.083 1.215 1.092 AIR
9 | S 0.411 -1.099 1.048 1.6990/30.05
10 | S 0.088 2.918 1.172 1.9108/35.25
11 | S 0.258 -1.669 1.202 AIR
12 | S 0.009 1.643 1.248 1.5928/68.62
13 | S 0.379 -1.412 1.226 1.7205/34.70
14 | S 0.069 -2.572 1.214 AIR
15 | A 0.118 0 1.110 OCCLUDER
16 | S 0.604 -0.973 0.952 1.5927/35.31
17 | S 0.051 -24.080 0.980 AIR
18 | S 0.009 2.376 1.086 1.5928/68.62
19 | S 0.282 -1.306 1.138 AIR
20 | S 0.239 -7.317 1.208 1.6935/53.20 0 -0.240 -0.4268
21 | S 0.122 -2.200 1.254 AIR 0 -0.05053 -0.3491 0.1459 0.07718
22 | S 0.154 -1.545 1.324 1.4875/70.44
23 | S 0.083 -7.257 1.424 AIR
24 | S 0.750 0 2.400 1.5168/64.12
25 | S 0.074 0 2.400 AIR
26 | I 0.043 0 1.912 AIR
27 |
--------------------------------------------------------------------------------
/examples/lenses/end2end/end2end_edof.txt:
--------------------------------------------------------------------------------
1 | End2end EDOF example
2 | type distance roc diameter material
3 | O 0.0 0.0 0.0 AIR
4 | S 0.0 62.8 12.7 n-bk7
5 | S 4.000 -45.7 12.7 sf5
6 | S 2.500 -128.2 12.7 AIR
7 | S 1.500 0.0 12.7 n-bk7
8 | X 2.500 0.0 12.7 AIR 0 0 0 0 0 0 0 0 0 0 0
9 | I 96.000 0.0 1.7664 AIR
--------------------------------------------------------------------------------
/examples/misalignment_point.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import matplotlib.pyplot as plt
4 | from pathlib import Path
5 | from matplotlib.image import imread, imsave
6 | import cv2
7 |
8 | import imageio
9 |
10 | import sys
11 | sys.path.append("../")
12 | import diffoptics as do
13 |
14 | """
15 | Experimental parameters:
16 |
17 | light source to sensor: about 675 [mm]
18 | sensor: GS3-U3-50S5M, pixel size 3.45 [um], resolution 2448 × 2048
19 | """
20 |
21 | # initialize a lens
22 | device = do.init()
23 | # device = torch.device('cpu')
24 | lens = do.Lensgroup(device=device)
25 |
26 | # ==== Load lens file
27 | lens.load_file(Path('./lenses/ThorLabs/LA1131.txt'))
28 | lens.d_sensor = torch.Tensor([56.0]).to(device) # [mm] sensor distance
29 | lens.plot_setup2D(with_sensor=True)
30 | R = lens.surfaces[0].r
31 |
32 | # sensor information
33 | downsample_N = 4
34 | pixel_size = 3.45e-3 * downsample_N # [mm]
35 | N_total = int(2048 / downsample_N)
36 | R_sensor = N_total * pixel_size / 2 # [mm]
37 |
38 | # set scene geometry
39 | wavelength = torch.Tensor([622.5]).to(device) # [nm]
40 |
41 | # point light source position
42 | light_o = torch.Tensor([0.0, 0.0, -650]).to(device)
43 | lens.light_o = light_o # hook-up
44 |
45 | R_in = 1.42*R # oversampling factor; must be >= sqrt(2)
46 | M = 1024
47 | def sample_ray(M, light_o):
48 | o_x, o_y = torch.meshgrid(
49 | torch.linspace(-R_in, R_in, M, device=device),
50 | torch.linspace(-R_in, R_in, M, device=device)
51 | )
52 | valid = (o_x**2 + o_y**2) < (0.95*R)**2
53 |
54 | o = torch.stack((o_x, o_y, -torch.ones_like(o_x)), axis=-1)
55 | d = torch.stack((o_x, o_y, torch.zeros_like(o_x)), axis=-1) - light_o[None, None, ...]
56 | d = do.normalize(d)
57 |
58 | o = o[valid]
59 | d = d[valid]
60 |
61 | return do.Ray(o, d, wavelength, device=device)
62 |
63 | lens.pixel_size = pixel_size
64 | lens.film_size = [N_total,N_total]
65 | def render():
66 | ray = sample_ray(M, lens.light_o)
67 | I = lens.render(ray)
68 | I = N_total**2 * I / I.sum()
69 | return I
70 |
71 |
72 | # centroid
73 | X, Y = torch.meshgrid(
74 | 1 + torch.arange(N_total, device=device),
75 | 1 + torch.arange(N_total, device=device)
76 | )
77 | def centroid(I):
78 | return torch.stack((
79 | torch.sum(X * I) / torch.sum(I),
80 | torch.sum(Y * I) / torch.sum(I)
81 | ))
82 |
83 | ### Optimization utilities
84 | def loss(I, I_mea):
85 | data_term = torch.mean((I - I_mea)**2)
86 | comp_centroid = True
87 | if comp_centroid:
88 | c_mea = centroid(I_mea)
89 | c = centroid(I)
90 | loss = data_term + 0.0005 * torch.mean((c - c_mea)**2)
91 | else:
92 | loss = data_term
93 | return loss
94 |
95 |
96 | # read image
97 | img = imread('./data/20210304/ref2.tif') # for now we use grayscale
98 | img = img.astype(float)
99 | I_mea = cv2.resize(img, dsize=(N_total, N_total), interpolation=cv2.INTER_AREA)
100 | I_mea = np.maximum(0.0, I_mea - np.median(I_mea))
101 | I_mea = N_total**2 * I_mea / I_mea.sum()
102 | I_mea = torch.Tensor(I_mea).to(device)
103 |
104 | # AUTO DIFF
105 | diff_variables = ['d_sensor', 'theta_x', 'theta_y', 'light_o']
106 | out = do.LM(lens, diff_variables, 1e-3, option='diag') \
107 | .optimize(render, lambda y: I_mea - y, maxit=30, record=True)
108 | Is = out['Is']  # rendered images recorded at each iteration
109 |
110 | # crop images
111 | def crop(I):
112 | c = 200
113 | return I[c:I.shape[0]-c, c:I.shape[1]-c]
114 |
115 | opath = Path('misalignment_point')
116 | opath.mkdir(parents=True, exist_ok=True)
117 | def save(I_mea, Is):
118 | images = []
119 | for I in Is:
120 | images.append(crop(I))
121 | imageio.mimsave(str(opath / 'movie.mp4'), images)
122 |
123 | # show results
124 | plt.figure()
125 | plt.imshow(crop(Is[0]), cmap='gray')
126 | plt.title('Simulation (initial)')
127 |
128 | plt.figure()
129 | plt.imshow(crop(Is[-1]), cmap='gray')
130 | plt.title('Simulation (optimized)')
131 |
132 | I_mea = I_mea.cpu().detach().numpy()
133 |
134 | plt.figure()
135 | plt.imshow(crop(I_mea), cmap='gray')
136 | plt.title('Measurement')
137 | plt.show()
138 |
139 | plt.imsave(str(opath / 'I0.jpg'), crop(Is[0]), vmin=0, vmax=np.maximum(I_mea.max(), Is[-1].max()), cmap='gray')
140 | plt.imsave(str(opath / 'I.jpg'), crop(Is[-1]), vmin=0, vmax=np.maximum(I_mea.max(), Is[-1].max()), cmap='gray')
141 | plt.imsave(str(opath / 'I_mea.jpg'), crop(I_mea), vmin=0, vmax=np.maximum(I_mea.max(), Is[-1].max()), cmap='gray')
142 |
143 |
144 | save(I_mea, out['Is'])
145 |
146 | fig = plt.figure()
147 | plt.plot(out['ls'], 'k-o')
148 | plt.xlabel('iteration')
149 | plt.ylabel('loss')
150 | fig.savefig(str(opath / "ls.pdf"), bbox_inches='tight')
151 |
--------------------------------------------------------------------------------
/examples/neural_networks/DeblurGANv2/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | *.pyc
3 | #fpn_inception.h5
4 | #fpn_mobilenet.h5
5 | *.h5
--------------------------------------------------------------------------------
/examples/neural_networks/DeblurGANv2/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2019, Orest Kupyn, Tetiana Martyniuk, Junru Wu and Zhangyang Wang
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 |
25 |
26 | --------------------------- LICENSE FOR DeblurGANv2 --------------------------------
27 | BSD License
28 |
29 | For DeblurGANv2 software
30 | Copyright (c) 2019, Orest Kupyn, Tetiana Martyniuk, Junru Wu and Zhangyang Wang
31 | All rights reserved.
32 |
33 | Redistribution and use in source and binary forms, with or without
34 | modification, are permitted provided that the following conditions are met:
35 |
36 | * Redistributions of source code must retain the above copyright notice, this
37 | list of conditions and the following disclaimer.
38 |
39 | * Redistributions in binary form must reproduce the above copyright notice,
40 | this list of conditions and the following disclaimer in the documentation
41 | and/or other materials provided with the distribution.
42 |
43 | ----------------------------- LICENSE FOR DCGAN --------------------------------
44 | BSD License
45 |
46 | For dcgan.torch software
47 |
48 | Copyright (c) 2015, Facebook, Inc. All rights reserved.
49 |
50 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
51 |
52 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
53 |
54 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
55 |
56 | Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
57 |
58 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
59 |
--------------------------------------------------------------------------------
/examples/neural_networks/DeblurGANv2/README.md:
--------------------------------------------------------------------------------
1 | # DeblurGAN-v2: Deblurring (Orders-of-Magnitude) Faster and Better
2 |
3 | Code for this paper [DeblurGAN-v2: Deblurring (Orders-of-Magnitude) Faster and Better](https://arxiv.org/abs/1908.03826)
4 |
5 | Orest Kupyn, Tetiana Martyniuk, Junru Wu, Zhangyang Wang
6 |
7 | In ICCV 2019
8 |
9 | ## Overview
10 |
11 | We present a new end-to-end generative adversarial network (GAN) for single image motion deblurring, named
12 | DeblurGAN-v2, which considerably boosts state-of-the-art deblurring efficiency, quality, and flexibility. DeblurGAN-v2
13 | is based on a relativistic conditional GAN with a double-scale discriminator. For the first time, we introduce the
14 | Feature Pyramid Network into deblurring, as a core building block in the generator of DeblurGAN-v2. It can flexibly
15 | work with a wide range of backbones, to navigate the balance between performance and efficiency. The plug-in of
16 | sophisticated backbones (e.g., Inception-ResNet-v2) can lead to solid state-of-the-art deblurring. Meanwhile,
17 | with light-weight backbones (e.g., MobileNet and its variants), DeblurGAN-v2 reaches 10-100 times faster than
18 | the nearest competitors, while maintaining close to state-of-the-art results, implying the option of real-time
19 | video deblurring. We demonstrate that DeblurGAN-v2 obtains very competitive performance on several popular
20 | benchmarks, in terms of deblurring quality (both objective and subjective), as well as efficiency. Besides,
21 | we show the architecture to be effective for general image restoration tasks too.
22 |
23 |
27 |
28 | 
29 | 
30 | 
31 | 
32 |
33 |
34 |
35 | ## DeblurGAN-v2 Architecture
36 |
37 | 
38 |
39 |
45 |
46 |
49 |
50 | ## Datasets
51 |
52 | The datasets for training can be downloaded via the links below:
53 | - [DVD](https://drive.google.com/file/d/1bpj9pCcZR_6-AHb5aNnev5lILQbH8GMZ/view)
54 | - [GoPro](https://drive.google.com/file/d/1KStHiZn5TNm2mo3OLZLjnRvd0vVFCI0W/view)
55 | - [NFS](https://drive.google.com/file/d/1Ut7qbQOrsTZCUJA_mJLptRMipD8sJzjy/view)
56 |
57 | ## Training
58 |
59 | #### Command
60 |
61 | ```python train.py```
62 |
63 | The training script loads its configuration from `config/config.yaml`.
64 |
65 | #### Tensorboard visualization
66 |
67 | 
68 |
69 | ## Testing
70 |
71 | To test on a single image,
72 |
73 | ```python predict.py IMAGE_NAME.jpg```
74 |
75 | By default, the Predictor uses the pretrained model 'best_fpn.h5'; this can be changed in the code via the 'weights_path' argument. It assumes the fpn_inception backbone. To try a different pretrained backbone, also specify it under ['model']['g_name'] in config/config.yaml (see the sketch below).
76 |
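For example, a small sketch (assuming PyYAML is installed; the key layout follows `config/config.yaml` in this folder) that switches the generator backbone before running `train.py` or `predict.py`:

```python
# Sketch: point the config at a different generator backbone before training/prediction.
# Note: re-dumping normalizes the original YAML anchors/aliases.
import yaml

with open('config/config.yaml') as f:
    config = yaml.safe_load(f)

config['model']['g_name'] = 'fpn_mobilenet'  # was 'fpn_inception'

with open('config/config.yaml', 'w') as f:
    yaml.safe_dump(config, f)
```
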
77 | ## Pre-trained models
78 |
79 | | Dataset | G Model | D Model | Loss Type | PSNR / SSIM | Link |
80 | | :----: | :----: | :----: | :----: | :----: | :----: |
81 | | GoPro Test Dataset | InceptionResNet-v2 | double_gan | ragan-ls | 29.55 / 0.934 | fpn_inception.h5 |
82 | |  | MobileNet | double_gan | ragan-ls | 28.17 / 0.925 | fpn_mobilenet.h5 |
83 | |  | MobileNet-DSC | double_gan | ragan-ls | 28.03 / 0.922 |  |
84 |
85 |
111 |
112 | ## Parent Repository
113 |
114 | The code was taken from https://github.com/KupynOrest/RestoreGAN, which contains flexible pipelines for different image restoration tasks.
115 |
116 | ## Citation
117 |
118 | If you use this code for your research, please cite our paper.
119 |
120 | ```
121 | @InProceedings{Kupyn_2019_ICCV,
122 | author = {Orest Kupyn and Tetiana Martyniuk and Junru Wu and Zhangyang Wang},
123 | title = {DeblurGAN-v2: Deblurring (Orders-of-Magnitude) Faster and Better},
124 | booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
125 | month = {Oct},
126 | year = {2019}
127 | }
128 | ```
131 |
132 |
--------------------------------------------------------------------------------
/examples/neural_networks/DeblurGANv2/adversarial_trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import copy
3 |
4 |
5 | class GANFactory:
6 | factories = {}
7 |
8 | def __init__(self):
9 | pass
10 |
11 | def add_factory(gan_id, model_factory):
12 | GANFactory.factories[gan_id] = model_factory
13 |
14 | add_factory = staticmethod(add_factory)
15 |
16 | # A template method: gan_id names a GANTrainer subclass defined below; its nested Factory is created on first use (via eval) and cached in GANFactory.factories.
17 |
18 | def create_model(gan_id, net_d=None, criterion=None):
19 | if gan_id not in GANFactory.factories:
20 | GANFactory.factories[gan_id] = \
21 | eval(gan_id + '.Factory()')
22 | return GANFactory.factories[gan_id].create(net_d, criterion)
23 |
24 | create_model = staticmethod(create_model)
25 |
26 |
27 | class GANTrainer(object):
28 | def __init__(self, net_d, criterion):
29 | self.net_d = net_d
30 | self.criterion = criterion
31 |
32 | def loss_d(self, pred, gt):
33 | pass
34 |
35 | def loss_g(self, pred, gt):
36 | pass
37 |
38 | def get_params(self):
39 | pass
40 |
41 |
42 | class NoGAN(GANTrainer):
43 | def __init__(self, net_d, criterion):
44 | GANTrainer.__init__(self, net_d, criterion)
45 |
46 | def loss_d(self, pred, gt):
47 | return [0]
48 |
49 | def loss_g(self, pred, gt):
50 | return 0
51 |
52 | def get_params(self):
53 | return [torch.nn.Parameter(torch.Tensor(1))]
54 |
55 | class Factory:
56 | @staticmethod
57 | def create(net_d, criterion): return NoGAN(net_d, criterion)
58 |
59 |
60 | class SingleGAN(GANTrainer):
61 | def __init__(self, net_d, criterion):
62 | GANTrainer.__init__(self, net_d, criterion)
63 | self.net_d = self.net_d.cuda()
64 |
65 | def loss_d(self, pred, gt):
66 | return self.criterion(self.net_d, pred, gt)
67 |
68 | def loss_g(self, pred, gt):
69 | return self.criterion.get_g_loss(self.net_d, pred, gt)
70 |
71 | def get_params(self):
72 | return self.net_d.parameters()
73 |
74 | class Factory:
75 | @staticmethod
76 | def create(net_d, criterion): return SingleGAN(net_d, criterion)
77 |
78 |
79 | class DoubleGAN(GANTrainer):
80 | def __init__(self, net_d, criterion):
81 | GANTrainer.__init__(self, net_d, criterion)
82 | self.patch_d = net_d['patch'].cuda()
83 | self.full_d = net_d['full'].cuda()
84 | self.full_criterion = copy.deepcopy(criterion)
85 |
86 | def loss_d(self, pred, gt):
87 | return (self.criterion(self.patch_d, pred, gt) + self.full_criterion(self.full_d, pred, gt)) / 2
88 |
89 | def loss_g(self, pred, gt):
90 | return (self.criterion.get_g_loss(self.patch_d, pred, gt) + self.full_criterion.get_g_loss(self.full_d, pred,
91 | gt)) / 2
92 |
93 | def get_params(self):
94 | return list(self.patch_d.parameters()) + list(self.full_d.parameters())
95 |
96 | class Factory:
97 | @staticmethod
98 | def create(net_d, criterion): return DoubleGAN(net_d, criterion)
99 |
100 |
--------------------------------------------------------------------------------
/examples/neural_networks/DeblurGANv2/aug.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import albumentations as albu
4 |
5 |
6 | def get_transforms(size: int, scope: str = 'geometric', crop='random'):
7 | augs = {'weak': albu.Compose([albu.HorizontalFlip(),
8 | ]),
9 | 'geometric': albu.OneOf([albu.HorizontalFlip(always_apply=True),
10 | albu.ShiftScaleRotate(always_apply=True),
11 | albu.Transpose(always_apply=True),
12 | albu.OpticalDistortion(always_apply=True),
13 | albu.ElasticTransform(always_apply=True),
14 | ])
15 | }
16 |
17 | aug_fn = augs[scope]
18 | crop_fn = {'random': albu.RandomCrop(size, size, always_apply=True),
19 | 'center': albu.CenterCrop(size, size, always_apply=True)}[crop]
20 | pad = albu.PadIfNeeded(size, size)
21 |
22 | pipeline = albu.Compose([aug_fn, pad, crop_fn], additional_targets={'target': 'image'})
23 |
24 | def process(a, b):
25 | r = pipeline(image=a, target=b)
26 | return r['image'], r['target']
27 |
28 | return process
29 |
30 |
31 | def get_normalize():
32 | normalize = albu.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
33 | normalize = albu.Compose([normalize], additional_targets={'target': 'image'})
34 |
35 | def process(a, b):
36 | r = normalize(image=a, target=b)
37 | return r['image'], r['target']
38 |
39 | return process
40 |
41 |
42 | def _resolve_aug_fn(name):
43 | d = {
44 | 'cutout': albu.Cutout,
45 | 'rgb_shift': albu.RGBShift,
46 | 'hsv_shift': albu.HueSaturationValue,
47 | 'motion_blur': albu.MotionBlur,
48 | 'median_blur': albu.MedianBlur,
49 | 'snow': albu.RandomSnow,
50 | 'shadow': albu.RandomShadow,
51 | 'fog': albu.RandomFog,
52 | 'brightness_contrast': albu.RandomBrightnessContrast,
53 | 'gamma': albu.RandomGamma,
54 | 'sun_flare': albu.RandomSunFlare,
55 | 'sharpen': albu.Sharpen,
56 | 'jpeg': albu.ImageCompression,
57 | 'gray': albu.ToGray,
58 | 'pixelize': albu.Downscale,
59 | # ToDo: partial gray
60 | }
61 | return d[name]
62 |
63 |
64 | def get_corrupt_function(config: List[dict]):
65 | augs = []
66 | for aug_params in config:
67 | name = aug_params.pop('name')
68 | cls = _resolve_aug_fn(name)
69 | prob = aug_params.pop('prob') if 'prob' in aug_params else .5
70 | augs.append(cls(p=prob, **aug_params))
71 |
72 | augs = albu.OneOf(augs)
73 |
74 | def process(x):
75 | return augs(image=x)['image']
76 |
77 | return process
78 |
--------------------------------------------------------------------------------
/examples/neural_networks/DeblurGANv2/config/config.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | project: deblur_gan
3 | experiment_desc: fpn
4 |
5 | train:
6 | files_a: &FILES_A ./training_dataset/*.png
7 | files_b: *FILES_A
8 | size: &SIZE 512
9 | crop: random
10 | preload: &PRELOAD false
11 | preload_size: &PRELOAD_SIZE 0
12 | bounds: [0, .9]
13 | scope: geometric
14 | corrupt: &CORRUPT
15 | - name: cutout
16 | prob: 0.5
17 | num_holes: 3
18 | max_h_size: 25
19 | max_w_size: 25
20 | - name: jpeg
21 | quality_lower: 70
22 | quality_upper: 90
23 | - name: motion_blur
24 | - name: median_blur
25 | - name: gamma
26 | - name: rgb_shift
27 | - name: hsv_shift
28 | - name: sharpen
29 |
30 | val:
31 | files_a: *FILES_A
32 | files_b: *FILES_A
33 | # files_a: &FILES_A
34 | # files_b: &FILES_B
35 | size: *SIZE
36 | scope: geometric
37 | crop: center
38 | preload: *PRELOAD
39 | preload_size: *PRELOAD_SIZE
40 | bounds: [.9, 1]
41 | corrupt: *CORRUPT
42 |
43 | phase: train
44 | warmup_num: 3
45 | model:
46 | g_name: fpn_inception
47 | blocks: 9
48 | d_name: double_gan # may be no_gan, patch_gan, double_gan, multi_scale
49 | d_layers: 3
50 | content_loss: perceptual
51 | adv_lambda: 0.001
52 | disc_loss: wgan-gp
53 | learn_residual: True
54 | norm_layer: instance
55 | dropout: True
56 |
57 | num_epochs: 10
58 | train_batches_per_epoch: 1000
59 | val_batches_per_epoch: 100
60 | batch_size: 1
61 | image_size: [512, 512]
62 |
63 | optimizer:
64 | name: adam
65 | lr: 0.01
66 | scheduler:
67 | name: linear
68 | start_epoch: 50
69 | min_lr: 0.0000001
70 |
--------------------------------------------------------------------------------
/examples/neural_networks/DeblurGANv2/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from copy import deepcopy
3 | from functools import partial
4 | from glob import glob
5 | from hashlib import sha1
6 | from typing import Callable, Iterable, Optional, Tuple
7 |
8 | import cv2
9 | import numpy as np
10 | from glog import logger
11 | from joblib import Parallel, cpu_count, delayed
12 | from skimage.io import imread
13 | from torch.utils.data import Dataset
14 | from tqdm import tqdm
15 |
16 | import aug
17 |
18 |
19 | def subsample(data: Iterable, bounds: Tuple[float, float], hash_fn: Callable, n_buckets=100, salt='', verbose=True):
20 | data = list(data)
21 | buckets = split_into_buckets(data, n_buckets=n_buckets, salt=salt, hash_fn=hash_fn)
22 |
23 | lower_bound, upper_bound = [x * n_buckets for x in bounds]
24 | msg = f'Subsampling buckets from {lower_bound} to {upper_bound}, total buckets number is {n_buckets}'
25 | if salt:
26 | msg += f'; salt is {salt}'
27 | if verbose:
28 | logger.info(msg)
29 | return np.array([sample for bucket, sample in zip(buckets, data) if lower_bound <= bucket < upper_bound])
30 |
31 |
32 | def hash_from_paths(x: Tuple[str, str], salt: str = '') -> str:
33 | path_a, path_b = x
34 | names = ''.join(map(os.path.basename, (path_a, path_b)))
35 | return sha1(f'{names}_{salt}'.encode()).hexdigest()
36 |
37 |
38 | def split_into_buckets(data: Iterable, n_buckets: int, hash_fn: Callable, salt=''):
39 | hashes = map(partial(hash_fn, salt=salt), data)
40 | return np.array([int(x, 16) % n_buckets for x in hashes])
41 |
42 |
43 | def _read_img(x: str):
44 | img = cv2.imread(x)
45 | if img is None:
46 | logger.warning(f'Can not read image {x} with OpenCV, switching to scikit-image')
47 | img = imread(x)[:, :, ::-1]
48 | return img
49 |
50 |
51 | class PairedDataset(Dataset):
52 | def __init__(self,
53 | files_a: Tuple[str],
54 | files_b: Tuple[str],
55 | transform_fn: Callable,
56 | normalize_fn: Callable,
57 | corrupt_fn: Optional[Callable] = None,
58 | preload: bool = True,
59 | preload_size: Optional[int] = 0,
60 | verbose=True):
61 |
62 | assert len(files_a) == len(files_b)
63 |
64 | self.preload = preload
65 | self.data_a = files_a
66 | self.data_b = files_b
67 | self.verbose = verbose
68 | self.corrupt_fn = corrupt_fn
69 | self.transform_fn = transform_fn
70 | self.normalize_fn = normalize_fn
71 | logger.info(f'Dataset has been created with {len(self.data_a)} samples')
72 |
73 | if preload:
74 | preload_fn = partial(self._bulk_preload, preload_size=preload_size)
75 | if files_a == files_b:
76 | self.data_a = self.data_b = preload_fn(self.data_a)
77 | else:
78 | self.data_a, self.data_b = map(preload_fn, (self.data_a, self.data_b))
79 | self.preload = True
80 |
81 | def _bulk_preload(self, data: Iterable[str], preload_size: int):
82 | jobs = [delayed(self._preload)(x, preload_size=preload_size) for x in data]
83 | jobs = tqdm(jobs, desc='preloading images', disable=not self.verbose)
84 | return Parallel(n_jobs=cpu_count(), backend='threading')(jobs)
85 |
86 | @staticmethod
87 | def _preload(x: str, preload_size: int):
88 | img = _read_img(x)
89 | if preload_size:
90 | h, w, *_ = img.shape
91 | h_scale = preload_size / h
92 | w_scale = preload_size / w
93 | scale = max(h_scale, w_scale)
94 | img = cv2.resize(img, fx=scale, fy=scale, dsize=None)
95 | assert min(img.shape[:2]) >= preload_size, f'weird img shape: {img.shape}'
96 | return img
97 |
98 | def _preprocess(self, img, res):
99 | def transpose(x):
100 | return np.transpose(x, (2, 0, 1))
101 |
102 | return map(transpose, self.normalize_fn(img, res))
103 |
104 | def __len__(self):
105 | return len(self.data_a)
106 |
107 | def __getitem__(self, idx):
108 | a, b = self.data_a[idx], self.data_b[idx]
109 | if not self.preload:
110 | a, b = map(_read_img, (a, b))
111 | a, b = self.transform_fn(a, b)
112 | if self.corrupt_fn is not None:
113 | a = self.corrupt_fn(a)
114 | a, b = self._preprocess(a, b)
115 | return {'a': a, 'b': b}
116 |
117 | @staticmethod
118 | def from_config(config):
119 | config = deepcopy(config)
120 | files_a, files_b = map(lambda x: sorted(glob(config[x], recursive=True)), ('files_a', 'files_b'))
121 | transform_fn = aug.get_transforms(size=config['size'], scope=config['scope'], crop=config['crop'])
122 | normalize_fn = aug.get_normalize()
123 | corrupt_fn = aug.get_corrupt_function(config['corrupt'])
124 |
125 | hash_fn = hash_from_paths
126 | # ToDo: add more hash functions
127 | verbose = config.get('verbose', True)
128 | data = subsample(data=zip(files_a, files_b),
129 | bounds=config.get('bounds', (0, 1)),
130 | hash_fn=hash_fn,
131 | verbose=verbose)
132 |
133 | files_a, files_b = map(list, zip(*data))
134 |
135 | return PairedDataset(files_a=files_a,
136 | files_b=files_b,
137 | preload=config['preload'],
138 | preload_size=config['preload_size'],
139 | corrupt_fn=corrupt_fn,
140 | normalize_fn=normalize_fn,
141 | transform_fn=transform_fn,
142 | verbose=verbose)
143 |
--------------------------------------------------------------------------------
/examples/neural_networks/DeblurGANv2/metric_counter.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections import defaultdict
3 |
4 | import numpy as np
5 | from tensorboardX import SummaryWriter
6 |
7 | WINDOW_SIZE = 100
8 |
9 |
10 | class MetricCounter:
11 | def __init__(self, exp_name):
12 | self.writer = SummaryWriter(exp_name)
13 | logging.basicConfig(filename='{}.log'.format(exp_name), level=logging.DEBUG)
14 | self.metrics = defaultdict(list)
15 | self.images = defaultdict(list)
16 | self.best_metric = 0
17 |
18 | def add_image(self, x: np.ndarray, tag: str):
19 | self.images[tag].append(x)
20 |
21 | def clear(self):
22 | self.metrics = defaultdict(list)
23 | self.images = defaultdict(list)
24 |
25 | def add_losses(self, l_G, l_content, l_D=0):
26 | for name, value in zip(('G_loss', 'G_loss_content', 'G_loss_adv', 'D_loss'),
27 | (l_G, l_content, l_G - l_content, l_D)):
28 | self.metrics[name].append(value)
29 |
30 | def add_metrics(self, psnr, ssim):
31 | for name, value in zip(('PSNR', 'SSIM'),
32 | (psnr, ssim)):
33 | self.metrics[name].append(value)
34 |
35 | def loss_message(self):
36 | metrics = ((k, np.mean(self.metrics[k][-WINDOW_SIZE:])) for k in ('G_loss', 'PSNR', 'SSIM'))
37 | return '; '.join(map(lambda x: f'{x[0]}={x[1]:.4f}', metrics))
38 |
39 | def write_to_tensorboard(self, epoch_num, validation=False):
40 | scalar_prefix = 'Validation' if validation else 'Train'
41 | for tag in ('G_loss', 'D_loss', 'G_loss_adv', 'G_loss_content', 'SSIM', 'PSNR'):
42 | self.writer.add_scalar(f'{scalar_prefix}_{tag}', np.mean(self.metrics[tag]), global_step=epoch_num)
43 | for tag in self.images:
44 | imgs = self.images[tag]
45 | if imgs:
46 | imgs = np.array(imgs)
47 | self.writer.add_images(tag, imgs[:, :, :, ::-1].astype('float32') / 255, dataformats='NHWC',
48 | global_step=epoch_num)
49 | self.images[tag] = []
50 |
51 | def update_best_model(self):
52 | cur_metric = np.mean(self.metrics['PSNR'])
53 | if self.best_metric < cur_metric:
54 | self.best_metric = cur_metric
55 | return True
56 | return False
57 |
--------------------------------------------------------------------------------
/examples/neural_networks/DeblurGANv2/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/neural_networks/DeblurGANv2/models/__init__.py
--------------------------------------------------------------------------------
/examples/neural_networks/DeblurGANv2/models/fpn_densenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from torchvision.models import densenet121
5 |
6 |
7 | class FPNSegHead(nn.Module):
8 | def __init__(self, num_in, num_mid, num_out):
9 | super().__init__()
10 |
11 | self.block0 = nn.Conv2d(num_in, num_mid, kernel_size=3, padding=1, bias=False)
12 | self.block1 = nn.Conv2d(num_mid, num_out, kernel_size=3, padding=1, bias=False)
13 |
14 | def forward(self, x):
15 | x = nn.functional.relu(self.block0(x), inplace=True)
16 | x = nn.functional.relu(self.block1(x), inplace=True)
17 | return x
18 |
19 |
20 | class FPNDense(nn.Module):
21 |
22 | def __init__(self, output_ch=3, num_filters=128, num_filters_fpn=256, pretrained=True):
23 | super().__init__()
24 |
25 | # Feature Pyramid Network (FPN) with four feature maps of resolutions
26 | # 1/4, 1/8, 1/16, 1/32 and `num_filters` filters for all feature maps.
27 |
28 | self.fpn = FPN(num_filters=num_filters_fpn, pretrained=pretrained)
29 |
30 | # The segmentation heads on top of the FPN
31 |
32 | self.head1 = FPNSegHead(num_filters_fpn, num_filters, num_filters)
33 | self.head2 = FPNSegHead(num_filters_fpn, num_filters, num_filters)
34 | self.head3 = FPNSegHead(num_filters_fpn, num_filters, num_filters)
35 | self.head4 = FPNSegHead(num_filters_fpn, num_filters, num_filters)
36 |
37 | self.smooth = nn.Sequential(
38 | nn.Conv2d(4 * num_filters, num_filters, kernel_size=3, padding=1),
39 | nn.BatchNorm2d(num_filters),
40 | nn.ReLU(),
41 | )
42 |
43 | self.smooth2 = nn.Sequential(
44 | nn.Conv2d(num_filters, num_filters // 2, kernel_size=3, padding=1),
45 | nn.BatchNorm2d(num_filters // 2),
46 | nn.ReLU(),
47 | )
48 |
49 | self.final = nn.Conv2d(num_filters // 2, output_ch, kernel_size=3, padding=1)
50 |
51 | def forward(self, x):
52 | map0, map1, map2, map3, map4 = self.fpn(x)
53 |
54 | map4 = nn.functional.upsample(self.head4(map4), scale_factor=8, mode="nearest")
55 | map3 = nn.functional.upsample(self.head3(map3), scale_factor=4, mode="nearest")
56 | map2 = nn.functional.upsample(self.head2(map2), scale_factor=2, mode="nearest")
57 | map1 = nn.functional.upsample(self.head1(map1), scale_factor=1, mode="nearest")
58 |
59 | smoothed = self.smooth(torch.cat([map4, map3, map2, map1], dim=1))
60 | smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
61 | smoothed = self.smooth2(smoothed + map0)
62 | smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
63 |
64 | final = self.final(smoothed)
65 | return torch.tanh(final)
66 |
67 | def unfreeze(self):
68 | for param in self.fpn.parameters():
69 | param.requires_grad = True
70 |
71 |
72 | class FPN(nn.Module):
73 |
74 | def __init__(self, num_filters=256, pretrained=True):
75 | """Creates an `FPN` instance for feature extraction.
76 | Args:
77 | num_filters: the number of filters in each output pyramid level
78 | pretrained: use ImageNet pre-trained backbone feature extractor
79 | """
80 |
81 | super().__init__()
82 |
83 | self.features = densenet121(pretrained=pretrained).features
84 |
85 | self.enc0 = nn.Sequential(self.features.conv0,
86 | self.features.norm0,
87 | self.features.relu0)
88 | self.pool0 = self.features.pool0
89 | self.enc1 = self.features.denseblock1 # 256
90 | self.enc2 = self.features.denseblock2 # 512
91 | self.enc3 = self.features.denseblock3 # 1024
92 | self.enc4 = self.features.denseblock4 # 2048
93 | self.norm = self.features.norm5 # 2048
94 |
95 | self.tr1 = self.features.transition1 # 256
96 | self.tr2 = self.features.transition2 # 512
97 | self.tr3 = self.features.transition3 # 1024
98 |
99 | self.lateral4 = nn.Conv2d(1024, num_filters, kernel_size=1, bias=False)
100 | self.lateral3 = nn.Conv2d(1024, num_filters, kernel_size=1, bias=False)
101 | self.lateral2 = nn.Conv2d(512, num_filters, kernel_size=1, bias=False)
102 | self.lateral1 = nn.Conv2d(256, num_filters, kernel_size=1, bias=False)
103 | self.lateral0 = nn.Conv2d(64, num_filters // 2, kernel_size=1, bias=False)
104 |
105 | def forward(self, x):
106 | # Bottom-up pathway through the DenseNet-121 backbone
107 | enc0 = self.enc0(x)
108 |
109 | pooled = self.pool0(enc0)
110 |
111 | enc1 = self.enc1(pooled) # 256
112 | tr1 = self.tr1(enc1)
113 |
114 | enc2 = self.enc2(tr1) # 512
115 | tr2 = self.tr2(enc2)
116 |
117 | enc3 = self.enc3(tr2) # 1024
118 | tr3 = self.tr3(enc3)
119 |
120 | enc4 = self.enc4(tr3) # 2048
121 | enc4 = self.norm(enc4)
122 |
123 | # Lateral connections
124 |
125 | lateral4 = self.lateral4(enc4)
126 | lateral3 = self.lateral3(enc3)
127 | lateral2 = self.lateral2(enc2)
128 | lateral1 = self.lateral1(enc1)
129 | lateral0 = self.lateral0(enc0)
130 |
131 | # Top-down pathway
132 |
133 | map4 = lateral4
134 | map3 = lateral3 + nn.functional.upsample(map4, scale_factor=2, mode="nearest")
135 | map2 = lateral2 + nn.functional.upsample(map3, scale_factor=2, mode="nearest")
136 | map1 = lateral1 + nn.functional.upsample(map2, scale_factor=2, mode="nearest")
137 |
138 | return lateral0, map1, map2, map3, map4
139 |
--------------------------------------------------------------------------------
/examples/neural_networks/DeblurGANv2/models/fpn_inception.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from pretrainedmodels import inceptionresnetv2
4 | from torchsummary import summary
5 | import torch.nn.functional as F
6 |
7 | class FPNHead(nn.Module):
8 | def __init__(self, num_in, num_mid, num_out):
9 | super().__init__()
10 |
11 | self.block0 = nn.Conv2d(num_in, num_mid, kernel_size=3, padding=1, bias=False)
12 | self.block1 = nn.Conv2d(num_mid, num_out, kernel_size=3, padding=1, bias=False)
13 |
14 | def forward(self, x):
15 | x = nn.functional.relu(self.block0(x), inplace=True)
16 | x = nn.functional.relu(self.block1(x), inplace=True)
17 | return x
18 |
19 | class ConvBlock(nn.Module):
20 | def __init__(self, num_in, num_out, norm_layer):
21 | super().__init__()
22 |
23 | self.block = nn.Sequential(nn.Conv2d(num_in, num_out, kernel_size=3, padding=1),
24 | norm_layer(num_out),
25 | nn.ReLU(inplace=True))
26 |
27 | def forward(self, x):
28 | x = self.block(x)
29 | return x
30 |
31 |
32 | class FPNInception(nn.Module):
33 |
34 | def __init__(self, norm_layer, output_ch=3, num_filters=128, num_filters_fpn=256):
35 | super().__init__()
36 |
37 | # Feature Pyramid Network (FPN) with four feature maps of resolutions
38 | # 1/4, 1/8, 1/16, 1/32 and `num_filters` filters for all feature maps.
39 | self.fpn = FPN(num_filters=num_filters_fpn, norm_layer=norm_layer)
40 |
41 | # The segmentation heads on top of the FPN
42 |
43 | self.head1 = FPNHead(num_filters_fpn, num_filters, num_filters)
44 | self.head2 = FPNHead(num_filters_fpn, num_filters, num_filters)
45 | self.head3 = FPNHead(num_filters_fpn, num_filters, num_filters)
46 | self.head4 = FPNHead(num_filters_fpn, num_filters, num_filters)
47 |
48 | self.smooth = nn.Sequential(
49 | nn.Conv2d(4 * num_filters, num_filters, kernel_size=3, padding=1),
50 | norm_layer(num_filters),
51 | nn.ReLU(),
52 | )
53 |
54 | self.smooth2 = nn.Sequential(
55 | nn.Conv2d(num_filters, num_filters // 2, kernel_size=3, padding=1),
56 | norm_layer(num_filters // 2),
57 | nn.ReLU(),
58 | )
59 |
60 | self.final = nn.Conv2d(num_filters // 2, output_ch, kernel_size=3, padding=1)
61 |
62 | def unfreeze(self):
63 | self.fpn.unfreeze()
64 |
65 | def forward(self, x):
66 | map0, map1, map2, map3, map4 = self.fpn(x)
67 |
68 | map4 = nn.functional.upsample(self.head4(map4), scale_factor=8, mode="nearest")
69 | map3 = nn.functional.upsample(self.head3(map3), scale_factor=4, mode="nearest")
70 | map2 = nn.functional.upsample(self.head2(map2), scale_factor=2, mode="nearest")
71 | map1 = nn.functional.upsample(self.head1(map1), scale_factor=1, mode="nearest")
72 |
73 | smoothed = self.smooth(torch.cat([map4, map3, map2, map1], dim=1))
74 | smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
75 | smoothed = self.smooth2(smoothed + map0)
76 | smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
77 |
78 | final = self.final(smoothed)
79 | res = torch.tanh(final) + x
80 |
81 | return torch.clamp(res, min = -1,max = 1)
82 |
83 |
84 | class FPN(nn.Module):
85 |
86 | def __init__(self, norm_layer, num_filters=256):
87 | """Creates an `FPN` instance for feature extraction.
88 | Args:
89 | num_filters: the number of filters in each output pyramid level
90 | pretrained: use ImageNet pre-trained backbone feature extractor
91 | """
92 |
93 | super().__init__()
94 | self.inception = inceptionresnetv2(num_classes=1000, pretrained='imagenet')
95 |
96 | self.enc0 = self.inception.conv2d_1a
97 | self.enc1 = nn.Sequential(
98 | self.inception.conv2d_2a,
99 | self.inception.conv2d_2b,
100 | self.inception.maxpool_3a,
101 | ) # 64
102 | self.enc2 = nn.Sequential(
103 | self.inception.conv2d_3b,
104 | self.inception.conv2d_4a,
105 | self.inception.maxpool_5a,
106 | ) # 192
107 | self.enc3 = nn.Sequential(
108 | self.inception.mixed_5b,
109 | self.inception.repeat,
110 | self.inception.mixed_6a,
111 | ) # 1088
112 | self.enc4 = nn.Sequential(
113 | self.inception.repeat_1,
114 | self.inception.mixed_7a,
115 | ) #2080
116 | self.td1 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
117 | norm_layer(num_filters),
118 | nn.ReLU(inplace=True))
119 | self.td2 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
120 | norm_layer(num_filters),
121 | nn.ReLU(inplace=True))
122 | self.td3 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
123 | norm_layer(num_filters),
124 | nn.ReLU(inplace=True))
125 | self.pad = nn.ReflectionPad2d(1)
126 | self.lateral4 = nn.Conv2d(2080, num_filters, kernel_size=1, bias=False)
127 | self.lateral3 = nn.Conv2d(1088, num_filters, kernel_size=1, bias=False)
128 | self.lateral2 = nn.Conv2d(192, num_filters, kernel_size=1, bias=False)
129 | self.lateral1 = nn.Conv2d(64, num_filters, kernel_size=1, bias=False)
130 | self.lateral0 = nn.Conv2d(32, num_filters // 2, kernel_size=1, bias=False)
131 |
132 | for param in self.inception.parameters():
133 | param.requires_grad = False
134 |
135 | def unfreeze(self):
136 | for param in self.inception.parameters():
137 | param.requires_grad = True
138 |
139 | def forward(self, x):
140 |
141 | # Bottom-up pathway through the Inception-ResNet-v2 backbone
142 | enc0 = self.enc0(x)
143 |
144 | enc1 = self.enc1(enc0) # 256
145 |
146 | enc2 = self.enc2(enc1) # 512
147 |
148 | enc3 = self.enc3(enc2) # 1024
149 |
150 | enc4 = self.enc4(enc3) # 2048
151 |
152 | # Lateral connections
153 |
154 | lateral4 = self.pad(self.lateral4(enc4))
155 | lateral3 = self.pad(self.lateral3(enc3))
156 | lateral2 = self.lateral2(enc2)
157 | lateral1 = self.pad(self.lateral1(enc1))
158 | lateral0 = self.lateral0(enc0)
159 |
160 | # Top-down pathway
161 | pad = (1, 2, 1, 2) # asymmetric (left, right, top, bottom) reflection padding to match feature-map sizes
162 | pad1 = (0, 1, 0, 1)
163 | map4 = lateral4
164 | map3 = self.td1(lateral3 + nn.functional.upsample(map4, scale_factor=2, mode="nearest"))
165 | map2 = self.td2(F.pad(lateral2, pad, "reflect") + nn.functional.upsample(map3, scale_factor=2, mode="nearest"))
166 | map1 = self.td3(lateral1 + nn.functional.upsample(map2, scale_factor=2, mode="nearest"))
167 | return F.pad(lateral0, pad1, "reflect"), map1, map2, map3, map4
168 |
--------------------------------------------------------------------------------
/examples/neural_networks/DeblurGANv2/models/fpn_inception_simple.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from pretrainedmodels import inceptionresnetv2
4 | from torchsummary import summary
5 | import torch.nn.functional as F
6 |
7 | class FPNHead(nn.Module):
8 | def __init__(self, num_in, num_mid, num_out):
9 | super().__init__()
10 |
11 | self.block0 = nn.Conv2d(num_in, num_mid, kernel_size=3, padding=1, bias=False)
12 | self.block1 = nn.Conv2d(num_mid, num_out, kernel_size=3, padding=1, bias=False)
13 |
14 | def forward(self, x):
15 | x = nn.functional.relu(self.block0(x), inplace=True)
16 | x = nn.functional.relu(self.block1(x), inplace=True)
17 | return x
18 |
19 | class ConvBlock(nn.Module):
20 | def __init__(self, num_in, num_out, norm_layer):
21 | super().__init__()
22 |
23 | self.block = nn.Sequential(nn.Conv2d(num_in, num_out, kernel_size=3, padding=1),
24 | norm_layer(num_out),
25 | nn.ReLU(inplace=True))
26 |
27 | def forward(self, x):
28 | x = self.block(x)
29 | return x
30 |
31 |
32 | class FPNInceptionSimple(nn.Module):
33 |
34 | def __init__(self, norm_layer, output_ch=3, num_filters=128, num_filters_fpn=256):
35 | super().__init__()
36 |
37 | # Feature Pyramid Network (FPN) with four feature maps of resolutions
38 | # 1/4, 1/8, 1/16, 1/32 and `num_filters` filters for all feature maps.
39 | self.fpn = FPN(num_filters=num_filters_fpn, norm_layer=norm_layer)
40 |
41 | # The segmentation heads on top of the FPN
42 |
43 | self.head1 = FPNHead(num_filters_fpn, num_filters, num_filters)
44 | self.head2 = FPNHead(num_filters_fpn, num_filters, num_filters)
45 | self.head3 = FPNHead(num_filters_fpn, num_filters, num_filters)
46 | self.head4 = FPNHead(num_filters_fpn, num_filters, num_filters)
47 |
48 | self.smooth = nn.Sequential(
49 | nn.Conv2d(4 * num_filters, num_filters, kernel_size=3, padding=1),
50 | norm_layer(num_filters),
51 | nn.ReLU(),
52 | )
53 |
54 | self.smooth2 = nn.Sequential(
55 | nn.Conv2d(num_filters, num_filters // 2, kernel_size=3, padding=1),
56 | norm_layer(num_filters // 2),
57 | nn.ReLU(),
58 | )
59 |
60 | self.final = nn.Conv2d(num_filters // 2, output_ch, kernel_size=3, padding=1)
61 |
62 | def unfreeze(self):
63 | self.fpn.unfreeze()
64 |
65 | def forward(self, x):
66 |
67 | map0, map1, map2, map3, map4 = self.fpn(x)
68 |
69 | map4 = nn.functional.upsample(self.head4(map4), scale_factor=8, mode="nearest")
70 | map3 = nn.functional.upsample(self.head3(map3), scale_factor=4, mode="nearest")
71 | map2 = nn.functional.upsample(self.head2(map2), scale_factor=2, mode="nearest")
72 | map1 = nn.functional.upsample(self.head1(map1), scale_factor=1, mode="nearest")
73 |
74 | smoothed = self.smooth(torch.cat([map4, map3, map2, map1], dim=1))
75 | smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
76 | smoothed = self.smooth2(smoothed + map0)
77 | smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
78 |
79 | final = self.final(smoothed)
80 | res = torch.tanh(final) + x
81 |
82 | return torch.clamp(res, min = -1,max = 1)
83 |
84 |
85 | class FPN(nn.Module):
86 |
87 | def __init__(self, norm_layer, num_filters=256):
88 | """Creates an `FPN` instance for feature extraction.
89 | Args:
90 | num_filters: the number of filters in each output pyramid level
91 | pretrained: use ImageNet pre-trained backbone feature extractor
92 | """
93 |
94 | super().__init__()
95 | self.inception = inceptionresnetv2(num_classes=1000, pretrained='imagenet')
96 |
97 | self.enc0 = self.inception.conv2d_1a
98 | self.enc1 = nn.Sequential(
99 | self.inception.conv2d_2a,
100 | self.inception.conv2d_2b,
101 | self.inception.maxpool_3a,
102 | ) # 64
103 | self.enc2 = nn.Sequential(
104 | self.inception.conv2d_3b,
105 | self.inception.conv2d_4a,
106 | self.inception.maxpool_5a,
107 | ) # 192
108 | self.enc3 = nn.Sequential(
109 | self.inception.mixed_5b,
110 | self.inception.repeat,
111 | self.inception.mixed_6a,
112 | ) # 1088
113 | self.enc4 = nn.Sequential(
114 | self.inception.repeat_1,
115 | self.inception.mixed_7a,
116 | ) #2080
117 |
118 | self.pad = nn.ReflectionPad2d(1)
119 | self.lateral4 = nn.Conv2d(2080, num_filters, kernel_size=1, bias=False)
120 | self.lateral3 = nn.Conv2d(1088, num_filters, kernel_size=1, bias=False)
121 | self.lateral2 = nn.Conv2d(192, num_filters, kernel_size=1, bias=False)
122 | self.lateral1 = nn.Conv2d(64, num_filters, kernel_size=1, bias=False)
123 | self.lateral0 = nn.Conv2d(32, num_filters // 2, kernel_size=1, bias=False)
124 |
125 | for param in self.inception.parameters():
126 | param.requires_grad = False
127 |
128 | def unfreeze(self):
129 | for param in self.inception.parameters():
130 | param.requires_grad = True
131 |
132 | def forward(self, x):
133 |
134 | # Bottom-up pathway through the Inception-ResNet-v2 backbone
135 | enc0 = self.enc0(x)
136 |
137 | enc1 = self.enc1(enc0) # 256
138 |
139 | enc2 = self.enc2(enc1) # 512
140 |
141 | enc3 = self.enc3(enc2) # 1024
142 |
143 | enc4 = self.enc4(enc3) # 2048
144 |
145 | # Lateral connections
146 |
147 | lateral4 = self.pad(self.lateral4(enc4))
148 | lateral3 = self.pad(self.lateral3(enc3))
149 | lateral2 = self.lateral2(enc2)
150 | lateral1 = self.pad(self.lateral1(enc1))
151 | lateral0 = self.lateral0(enc0)
152 |
153 | # Top-down pathway
154 | pad = (1, 2, 1, 2) # asymmetric (left, right, top, bottom) reflection padding to match feature-map sizes
155 | pad1 = (0, 1, 0, 1)
156 | map4 = lateral4
157 | map3 = lateral3 + nn.functional.upsample(map4, scale_factor=2, mode="nearest")
158 | map2 = F.pad(lateral2, pad, "reflect") + nn.functional.upsample(map3, scale_factor=2, mode="nearest")
159 | map1 = lateral1 + nn.functional.upsample(map2, scale_factor=2, mode="nearest")
160 | return F.pad(lateral0, pad1, "reflect"), map1, map2, map3, map4
161 |
--------------------------------------------------------------------------------
/examples/neural_networks/DeblurGANv2/models/fpn_mobilenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from models.mobilenet_v2 import MobileNetV2
4 |
5 | class FPNHead(nn.Module):
6 | def __init__(self, num_in, num_mid, num_out):
7 | super().__init__()
8 |
9 | self.block0 = nn.Conv2d(num_in, num_mid, kernel_size=3, padding=1, bias=False)
10 | self.block1 = nn.Conv2d(num_mid, num_out, kernel_size=3, padding=1, bias=False)
11 |
12 | def forward(self, x):
13 | x = nn.functional.relu(self.block0(x), inplace=True)
14 | x = nn.functional.relu(self.block1(x), inplace=True)
15 | return x
16 |
17 |
18 | class FPNMobileNet(nn.Module):
19 |
20 | def __init__(self, norm_layer, output_ch=3, num_filters=64, num_filters_fpn=128, pretrained=True):
21 | super().__init__()
22 |
23 | # Feature Pyramid Network (FPN) with four feature maps of resolutions
24 | # 1/4, 1/8, 1/16, 1/32 and `num_filters` filters for all feature maps.
25 |
26 | self.fpn = FPN(num_filters=num_filters_fpn, norm_layer = norm_layer, pretrained=pretrained)
27 |
28 | # The segmentation heads on top of the FPN
29 |
30 | self.head1 = FPNHead(num_filters_fpn, num_filters, num_filters)
31 | self.head2 = FPNHead(num_filters_fpn, num_filters, num_filters)
32 | self.head3 = FPNHead(num_filters_fpn, num_filters, num_filters)
33 | self.head4 = FPNHead(num_filters_fpn, num_filters, num_filters)
34 |
35 | self.smooth = nn.Sequential(
36 | nn.Conv2d(4 * num_filters, num_filters, kernel_size=3, padding=1),
37 | norm_layer(num_filters),
38 | nn.ReLU(),
39 | )
40 |
41 | self.smooth2 = nn.Sequential(
42 | nn.Conv2d(num_filters, num_filters // 2, kernel_size=3, padding=1),
43 | norm_layer(num_filters // 2),
44 | nn.ReLU(),
45 | )
46 |
47 | self.final = nn.Conv2d(num_filters // 2, output_ch, kernel_size=3, padding=1)
48 |
49 | def unfreeze(self):
50 | self.fpn.unfreeze()
51 |
52 | def forward(self, x):
53 |
54 | map0, map1, map2, map3, map4 = self.fpn(x)
55 |
56 | map4 = nn.functional.upsample(self.head4(map4), scale_factor=8, mode="nearest")
57 | map3 = nn.functional.upsample(self.head3(map3), scale_factor=4, mode="nearest")
58 | map2 = nn.functional.upsample(self.head2(map2), scale_factor=2, mode="nearest")
59 | map1 = nn.functional.upsample(self.head1(map1), scale_factor=1, mode="nearest")
60 |
61 | smoothed = self.smooth(torch.cat([map4, map3, map2, map1], dim=1))
62 | smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
63 | smoothed = self.smooth2(smoothed + map0)
64 | smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
65 |
66 | final = self.final(smoothed)
67 | res = torch.tanh(final) + x
68 |
69 | return torch.clamp(res, min=-1, max=1)
70 |
71 |
72 | class FPN(nn.Module):
73 |
74 | def __init__(self, norm_layer, num_filters=128, pretrained=True):
75 | """Creates an `FPN` instance for feature extraction.
76 | Args:
77 | num_filters: the number of filters in each output pyramid level
78 | pretrained: use ImageNet pre-trained backbone feature extractor
79 | """
80 |
81 | super().__init__()
82 | net = MobileNetV2(n_class=1000)
83 |
84 | if pretrained:
85 | #Load weights into the project directory
86 | state_dict = torch.load('mobilenetv2.pth.tar') # add map_location='cpu' if no gpu
87 | net.load_state_dict(state_dict)
88 | self.features = net.features
89 |
90 | self.enc0 = nn.Sequential(*self.features[0:2])
91 | self.enc1 = nn.Sequential(*self.features[2:4])
92 | self.enc2 = nn.Sequential(*self.features[4:7])
93 | self.enc3 = nn.Sequential(*self.features[7:11])
94 | self.enc4 = nn.Sequential(*self.features[11:16])
95 |
96 | self.td1 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
97 | norm_layer(num_filters),
98 | nn.ReLU(inplace=True))
99 | self.td2 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
100 | norm_layer(num_filters),
101 | nn.ReLU(inplace=True))
102 | self.td3 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
103 | norm_layer(num_filters),
104 | nn.ReLU(inplace=True))
105 |
106 | self.lateral4 = nn.Conv2d(160, num_filters, kernel_size=1, bias=False)
107 | self.lateral3 = nn.Conv2d(64, num_filters, kernel_size=1, bias=False)
108 | self.lateral2 = nn.Conv2d(32, num_filters, kernel_size=1, bias=False)
109 | self.lateral1 = nn.Conv2d(24, num_filters, kernel_size=1, bias=False)
110 | self.lateral0 = nn.Conv2d(16, num_filters // 2, kernel_size=1, bias=False)
111 |
112 | for param in self.features.parameters():
113 | param.requires_grad = False
114 |
115 | def unfreeze(self):
116 | for param in self.features.parameters():
117 | param.requires_grad = True
118 |
119 |
120 | def forward(self, x):
121 |
122 | # Bottom-up pathway through the MobileNetV2 backbone
123 | enc0 = self.enc0(x)
124 |
125 | enc1 = self.enc1(enc0) # 256
126 |
127 | enc2 = self.enc2(enc1) # 512
128 |
129 | enc3 = self.enc3(enc2) # 1024
130 |
131 | enc4 = self.enc4(enc3) # 2048
132 |
133 | # Lateral connections
134 |
135 | lateral4 = self.lateral4(enc4)
136 | lateral3 = self.lateral3(enc3)
137 | lateral2 = self.lateral2(enc2)
138 | lateral1 = self.lateral1(enc1)
139 | lateral0 = self.lateral0(enc0)
140 |
141 | # Top-down pathway
142 | map4 = lateral4
143 | map3 = self.td1(lateral3 + nn.functional.upsample(map4, scale_factor=2, mode="nearest"))
144 | map2 = self.td2(lateral2 + nn.functional.upsample(map3, scale_factor=2, mode="nearest"))
145 | map1 = self.td3(lateral1 + nn.functional.upsample(map2, scale_factor=2, mode="nearest"))
146 | return lateral0, map1, map2, map3, map4
147 |
148 |
--------------------------------------------------------------------------------
/examples/neural_networks/DeblurGANv2/models/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.autograd as autograd
3 | import torch.nn as nn
4 | import torchvision.models as models
5 | import torchvision.transforms as transforms
6 | from torch.autograd import Variable
7 |
8 | from util.image_pool import ImagePool
9 |
10 |
11 | ###############################################################################
12 | # Functions
13 | ###############################################################################
14 |
15 | class ContentLoss():
16 | def initialize(self, loss):
17 | self.criterion = loss
18 |
19 | def get_loss(self, fakeIm, realIm):
20 | return self.criterion(fakeIm, realIm)
21 |
22 | def __call__(self, fakeIm, realIm):
23 | return self.get_loss(fakeIm, realIm)
24 |
25 |
26 | class PerceptualLoss():
27 |
28 | def contentFunc(self):
29 | conv_3_3_layer = 14
30 | cnn = models.vgg19(pretrained=True).features
31 | cnn = cnn.cuda()
32 | model = nn.Sequential()
33 | model = model.cuda()
34 | model = model.eval()
35 | for i, layer in enumerate(list(cnn)):
36 | model.add_module(str(i), layer)
37 | if i == conv_3_3_layer:
38 | break
39 | return model
40 |
41 | def initialize(self, loss):
42 | with torch.no_grad():
43 | self.criterion = loss
44 | self.contentFunc = self.contentFunc()
45 | self.transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
46 |
47 | def get_loss(self, fakeIm, realIm):
48 | fakeIm = (fakeIm + 1) / 2.0
49 | realIm = (realIm + 1) / 2.0
50 | fakeIm[0, :, :, :] = self.transform(fakeIm[0, :, :, :])
51 | realIm[0, :, :, :] = self.transform(realIm[0, :, :, :])
52 | f_fake = self.contentFunc.forward(fakeIm)
53 | f_real = self.contentFunc.forward(realIm)
54 | f_real_no_grad = f_real.detach()
55 | loss = self.criterion(f_fake, f_real_no_grad)
56 | return 0.006 * torch.mean(loss) + 0.5 * nn.MSELoss()(fakeIm, realIm)
57 |
58 | def __call__(self, fakeIm, realIm):
59 | return self.get_loss(fakeIm, realIm)
60 |
61 |
62 | class GANLoss(nn.Module):
63 | def __init__(self, use_l1=True, target_real_label=1.0, target_fake_label=0.0,
64 | tensor=torch.FloatTensor):
65 | super(GANLoss, self).__init__()
66 | self.real_label = target_real_label
67 | self.fake_label = target_fake_label
68 | self.real_label_var = None
69 | self.fake_label_var = None
70 | self.Tensor = tensor
71 | if use_l1:
72 | self.loss = nn.L1Loss()
73 | else:
74 | self.loss = nn.BCEWithLogitsLoss()
75 |
76 | def get_target_tensor(self, input, target_is_real):
77 | if target_is_real:
78 | create_label = ((self.real_label_var is None) or
79 | (self.real_label_var.numel() != input.numel()))
80 | if create_label:
81 | real_tensor = self.Tensor(input.size()).fill_(self.real_label)
82 | self.real_label_var = Variable(real_tensor, requires_grad=False)
83 | target_tensor = self.real_label_var
84 | else:
85 | create_label = ((self.fake_label_var is None) or
86 | (self.fake_label_var.numel() != input.numel()))
87 | if create_label:
88 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
89 | self.fake_label_var = Variable(fake_tensor, requires_grad=False)
90 | target_tensor = self.fake_label_var
91 | return target_tensor.cuda()
92 |
93 | def __call__(self, input, target_is_real):
94 | target_tensor = self.get_target_tensor(input, target_is_real)
95 | return self.loss(input, target_tensor)
96 |
97 |
98 | class DiscLoss(nn.Module):
99 | def name(self):
100 | return 'DiscLoss'
101 |
102 | def __init__(self):
103 | super(DiscLoss, self).__init__()
104 |
105 | self.criterionGAN = GANLoss(use_l1=False)
106 | self.fake_AB_pool = ImagePool(50)
107 |
108 | def get_g_loss(self, net, fakeB, realB):
109 | # First, G(A) should fake the discriminator
110 | pred_fake = net.forward(fakeB)
111 | return self.criterionGAN(pred_fake, 1)
112 |
113 | def get_loss(self, net, fakeB, realB):
114 | # Fake
115 | # stop backprop to the generator by detaching fake_B
116 | # Generated Image Disc Output should be close to zero
117 | self.pred_fake = net.forward(fakeB.detach())
118 | self.loss_D_fake = self.criterionGAN(self.pred_fake, 0)
119 |
120 | # Real
121 | self.pred_real = net.forward(realB)
122 | self.loss_D_real = self.criterionGAN(self.pred_real, 1)
123 |
124 | # Combined loss
125 | self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
126 | return self.loss_D
127 |
128 | def __call__(self, net, fakeB, realB):
129 | return self.get_loss(net, fakeB, realB)
130 |
131 |
132 | class RelativisticDiscLoss(nn.Module):
133 | def name(self):
134 | return 'RelativisticDiscLoss'
135 |
136 | def __init__(self):
137 | super(RelativisticDiscLoss, self).__init__()
138 |
139 | self.criterionGAN = GANLoss(use_l1=False)
140 | self.fake_pool = ImagePool(50) # create image buffer to store previously generated images
141 | self.real_pool = ImagePool(50)
142 |
143 | def get_g_loss(self, net, fakeB, realB):
144 | # First, G(A) should fake the discriminator
145 | self.pred_fake = net.forward(fakeB)
146 |
147 | # Real
148 | self.pred_real = net.forward(realB)
149 | errG = (self.criterionGAN(self.pred_real - torch.mean(self.fake_pool.query()), 0) +
150 | self.criterionGAN(self.pred_fake - torch.mean(self.real_pool.query()), 1)) / 2
151 | return errG
152 |
153 | def get_loss(self, net, fakeB, realB):
154 | # Fake
155 | # stop backprop to the generator by detaching fake_B
156 | # Generated Image Disc Output should be close to zero
157 | self.fake_B = fakeB.detach()
158 | self.real_B = realB
159 | self.pred_fake = net.forward(fakeB.detach())
160 | self.fake_pool.add(self.pred_fake)
161 |
162 | # Real
163 | self.pred_real = net.forward(realB)
164 | self.real_pool.add(self.pred_real)
165 |
166 | # Combined loss
167 | self.loss_D = (self.criterionGAN(self.pred_real - torch.mean(self.fake_pool.query()), 1) +
168 | self.criterionGAN(self.pred_fake - torch.mean(self.real_pool.query()), 0)) / 2
169 | return self.loss_D
170 |
171 | def __call__(self, net, fakeB, realB):
172 | return self.get_loss(net, fakeB, realB)
173 |
174 |
175 | class RelativisticDiscLossLS(nn.Module):
176 | def name(self):
177 | return 'RelativisticDiscLossLS'
178 |
179 | def __init__(self):
180 | super(RelativisticDiscLossLS, self).__init__()
181 |
182 | self.criterionGAN = GANLoss(use_l1=True)
183 | self.fake_pool = ImagePool(50) # create image buffer to store previously generated images
184 | self.real_pool = ImagePool(50)
185 |
186 | def get_g_loss(self, net, fakeB, realB):
187 | # First, G(A) should fake the discriminator
188 | self.pred_fake = net.forward(fakeB)
189 |
190 | # Real
191 | self.pred_real = net.forward(realB)
192 | errG = (torch.mean((self.pred_real - torch.mean(self.fake_pool.query()) + 1) ** 2) +
193 | torch.mean((self.pred_fake - torch.mean(self.real_pool.query()) - 1) ** 2)) / 2
194 | return errG
195 |
196 | def get_loss(self, net, fakeB, realB):
197 | # Fake
198 | # stop backprop to the generator by detaching fake_B
199 | # Generated Image Disc Output should be close to zero
200 | self.fake_B = fakeB.detach()
201 | self.real_B = realB
202 | self.pred_fake = net.forward(fakeB.detach())
203 | self.fake_pool.add(self.pred_fake)
204 |
205 | # Real
206 | self.pred_real = net.forward(realB)
207 | self.real_pool.add(self.pred_real)
208 |
209 | # Combined loss
210 | self.loss_D = (torch.mean((self.pred_real - torch.mean(self.fake_pool.query()) - 1) ** 2) +
211 | torch.mean((self.pred_fake - torch.mean(self.real_pool.query()) + 1) ** 2)) / 2
212 | return self.loss_D
213 |
214 | def __call__(self, net, fakeB, realB):
215 | return self.get_loss(net, fakeB, realB)
216 |
217 |
218 | class DiscLossLS(DiscLoss):
219 | def name(self):
220 | return 'DiscLossLS'
221 |
222 | def __init__(self):
223 | super(DiscLossLS, self).__init__()
224 | self.criterionGAN = GANLoss(use_l1=True)
225 |
226 | def get_g_loss(self, net, fakeB, realB):
227 | return DiscLoss.get_g_loss(self, net, fakeB, realB)
228 |
229 | def get_loss(self, net, fakeB, realB):
230 | return DiscLoss.get_loss(self, net, fakeB, realB)
231 |
232 |
233 | class DiscLossWGANGP(DiscLossLS):
234 | def name(self):
235 | return 'DiscLossWGAN-GP'
236 |
237 | def __init__(self):
238 | super(DiscLossWGANGP, self).__init__()
239 | self.LAMBDA = 10
240 |
241 | def get_g_loss(self, net, fakeB, realB):
242 | # First, G(A) should fake the discriminator
243 | self.D_fake = net.forward(fakeB)
244 | return -self.D_fake.mean()
245 |
246 | def calc_gradient_penalty(self, netD, real_data, fake_data):
247 | alpha = torch.rand(1, 1)
248 | alpha = alpha.expand(real_data.size())
249 | alpha = alpha.cuda()
250 |
251 | interpolates = alpha * real_data + ((1 - alpha) * fake_data)
252 |
253 | interpolates = interpolates.cuda()
254 | interpolates = Variable(interpolates, requires_grad=True)
255 |
256 | disc_interpolates = netD.forward(interpolates)
257 |
258 | gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
259 | grad_outputs=torch.ones(disc_interpolates.size()).cuda(),
260 | create_graph=True, retain_graph=True, only_inputs=True)[0]
261 |
262 | gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.LAMBDA
263 | return gradient_penalty
264 |
265 | def get_loss(self, net, fakeB, realB):
266 | self.D_fake = net.forward(fakeB.detach())
267 | self.D_fake = self.D_fake.mean()
268 |
269 | # Real
270 | self.D_real = net.forward(realB)
271 | self.D_real = self.D_real.mean()
272 | # Combined loss
273 | self.loss_D = self.D_fake - self.D_real
274 | gradient_penalty = self.calc_gradient_penalty(net, realB.data, fakeB.data)
275 | return self.loss_D + gradient_penalty
276 |
277 |
278 | def get_loss(model):
279 | if model['content_loss'] == 'perceptual':
280 | content_loss = PerceptualLoss()
281 | content_loss.initialize(nn.MSELoss())
282 | elif model['content_loss'] == 'l1':
283 | content_loss = ContentLoss()
284 | content_loss.initialize(nn.L1Loss())
285 | else:
286 | raise ValueError("ContentLoss [%s] not recognized." % model['content_loss'])
287 |
288 | if model['disc_loss'] == 'wgan-gp':
289 | disc_loss = DiscLossWGANGP()
290 | elif model['disc_loss'] == 'lsgan':
291 | disc_loss = DiscLossLS()
292 | elif model['disc_loss'] == 'gan':
293 | disc_loss = DiscLoss()
294 | elif model['disc_loss'] == 'ragan':
295 | disc_loss = RelativisticDiscLoss()
296 | elif model['disc_loss'] == 'ragan-ls':
297 | disc_loss = RelativisticDiscLossLS()
298 | else:
299 | raise ValueError("GAN Loss [%s] not recognized." % model['disc_loss'])
300 | return content_loss, disc_loss
301 |
--------------------------------------------------------------------------------
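
A minimal usage sketch for the loss factory above. The 'content_loss' and 'disc_loss' keys mirror the 'model' section of config/config.yaml; the concrete values and names below are illustrative assumptions, and the perceptual branch downloads VGG19 weights and expects a CUDA device:

    from models.losses import get_loss

    # Hypothetical config fragment; in train.py this dict comes from config/config.yaml.
    model_config = {'content_loss': 'perceptual', 'disc_loss': 'ragan-ls'}
    content_loss, disc_loss = get_loss(model_config)

    # content_loss(fake, real) is the generator's content term,
    # disc_loss.get_g_loss(netD, fake, real) its adversarial term,
    # and disc_loss(netD, fake, real) the discriminator loss itself.
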
/examples/neural_networks/DeblurGANv2/models/mobilenet_v2.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 |
4 |
5 | def conv_bn(inp, oup, stride):
6 | return nn.Sequential(
7 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
8 | nn.BatchNorm2d(oup),
9 | nn.ReLU6(inplace=True)
10 | )
11 |
12 |
13 | def conv_1x1_bn(inp, oup):
14 | return nn.Sequential(
15 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
16 | nn.BatchNorm2d(oup),
17 | nn.ReLU6(inplace=True)
18 | )
19 |
20 |
21 | class InvertedResidual(nn.Module):
22 | def __init__(self, inp, oup, stride, expand_ratio):
23 | super(InvertedResidual, self).__init__()
24 | self.stride = stride
25 | assert stride in [1, 2]
26 |
27 | hidden_dim = round(inp * expand_ratio)
28 | self.use_res_connect = self.stride == 1 and inp == oup
29 |
30 | if expand_ratio == 1:
31 | self.conv = nn.Sequential(
32 | # dw
33 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
34 | nn.BatchNorm2d(hidden_dim),
35 | nn.ReLU6(inplace=True),
36 | # pw-linear
37 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
38 | nn.BatchNorm2d(oup),
39 | )
40 | else:
41 | self.conv = nn.Sequential(
42 | # pw
43 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
44 | nn.BatchNorm2d(hidden_dim),
45 | nn.ReLU6(inplace=True),
46 | # dw
47 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
48 | nn.BatchNorm2d(hidden_dim),
49 | nn.ReLU6(inplace=True),
50 | # pw-linear
51 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
52 | nn.BatchNorm2d(oup),
53 | )
54 |
55 | def forward(self, x):
56 | if self.use_res_connect:
57 | return x + self.conv(x)
58 | else:
59 | return self.conv(x)
60 |
61 |
62 | class MobileNetV2(nn.Module):
63 | def __init__(self, n_class=1000, input_size=224, width_mult=1.):
64 | super(MobileNetV2, self).__init__()
65 | block = InvertedResidual
66 | input_channel = 32
67 | last_channel = 1280
68 | interverted_residual_setting = [
69 | # t, c, n, s
70 | [1, 16, 1, 1],
71 | [6, 24, 2, 2],
72 | [6, 32, 3, 2],
73 | [6, 64, 4, 2],
74 | [6, 96, 3, 1],
75 | [6, 160, 3, 2],
76 | [6, 320, 1, 1],
77 | ]
78 |
79 | # building first layer
80 | assert input_size % 32 == 0
81 | input_channel = int(input_channel * width_mult)
82 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
83 | self.features = [conv_bn(3, input_channel, 2)]
84 | # building inverted residual blocks
85 | for t, c, n, s in interverted_residual_setting:
86 | output_channel = int(c * width_mult)
87 | for i in range(n):
88 | if i == 0:
89 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
90 | else:
91 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
92 | input_channel = output_channel
93 | # building last several layers
94 | self.features.append(conv_1x1_bn(input_channel, self.last_channel))
95 | # make it nn.Sequential
96 | self.features = nn.Sequential(*self.features)
97 |
98 | # building classifier
99 | self.classifier = nn.Sequential(
100 | nn.Dropout(0.2),
101 | nn.Linear(self.last_channel, n_class),
102 | )
103 |
104 | self._initialize_weights()
105 |
106 | def forward(self, x):
107 | x = self.features(x)
108 | x = x.mean(3).mean(2)
109 | x = self.classifier(x)
110 | return x
111 |
112 | def _initialize_weights(self):
113 | for m in self.modules():
114 | if isinstance(m, nn.Conv2d):
115 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
116 | m.weight.data.normal_(0, math.sqrt(2. / n))
117 | if m.bias is not None:
118 | m.bias.data.zero_()
119 | elif isinstance(m, nn.BatchNorm2d):
120 | m.weight.data.fill_(1)
121 | m.bias.data.zero_()
122 | elif isinstance(m, nn.Linear):
123 | n = m.weight.size(1)
124 | m.weight.data.normal_(0, 0.01)
125 | m.bias.data.zero_()
126 |
127 |
--------------------------------------------------------------------------------
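
A small smoke-test sketch for the backbone above, assuming it is imported as models.mobilenet_v2 the way the other example scripts import their models (the input_size argument is asserted to be a multiple of 32):

    import torch
    from models.mobilenet_v2 import MobileNetV2

    net = MobileNetV2(n_class=1000, input_size=224, width_mult=1.)
    net.eval()
    with torch.no_grad():
        # features -> global average pool over H and W -> classifier
        logits = net(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # torch.Size([1, 1000])
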
/examples/neural_networks/DeblurGANv2/models/models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.nn as nn
3 | from skimage.metrics import structural_similarity as SSIM
4 | from util.metrics import PSNR
5 |
6 |
7 | class DeblurModel(nn.Module):
8 | def __init__(self):
9 | super(DeblurModel, self).__init__()
10 |
11 | def get_input(self, data):
12 | img = data['a']
13 | inputs = img
14 | targets = data['b']
15 | inputs, targets = inputs.cuda(), targets.cuda()
16 | return inputs, targets
17 |
18 | def tensor2im(self, image_tensor, imtype=np.uint8):
19 | image_numpy = image_tensor[0].cpu().float().numpy()
20 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
21 | return image_numpy.astype(imtype)
22 |
23 | def get_images_and_metrics(self, inp, output, target) -> (float, float, np.ndarray):
24 | inp = self.tensor2im(inp)
25 | fake = self.tensor2im(output.data)
26 | real = self.tensor2im(target.data)
27 | psnr = PSNR(fake, real)
28 | ssim = SSIM(fake, real, multichannel=True)
29 | vis_img = np.hstack((inp, fake, real))
30 | return psnr, ssim, vis_img
31 |
32 |
33 | def get_model(model_config):
34 | return DeblurModel()
35 |
--------------------------------------------------------------------------------
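
For reference, tensor2im above rescales a [-1, 1] CHW tensor with (x + 1) / 2 * 255 and casts to uint8, so -1 maps to 0, 0 to 127 and 1 to 255. A tiny sketch, assuming the file is importable as models.models:

    import torch
    from models.models import get_model

    model = get_model({})  # the config argument is currently unused; a plain DeblurModel is returned
    img = model.tensor2im(torch.zeros(1, 3, 4, 4))  # batch of one all-zeros 3x4x4 tensor
    print(img.shape, img.max())  # (4, 4, 3) 127
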
/examples/neural_networks/DeblurGANv2/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import functools
5 | from torch.autograd import Variable
6 | import numpy as np
7 | from models.fpn_mobilenet import FPNMobileNet
8 | from models.fpn_inception import FPNInception
9 | from models.fpn_inception_simple import FPNInceptionSimple
10 | from models.unet_seresnext import UNetSEResNext
11 | from models.fpn_densenet import FPNDense
12 | ###############################################################################
13 | # Functions
14 | ###############################################################################
15 |
16 |
17 | def get_norm_layer(norm_type='instance'):
18 | if norm_type == 'batch':
19 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
20 | elif norm_type == 'instance':
21 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
22 | else:
23 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
24 | return norm_layer
25 |
26 | ##############################################################################
27 | # Classes
28 | ##############################################################################
29 |
30 |
31 | # Defines the generator that consists of Resnet blocks between a few
32 | # downsampling/upsampling operations.
33 | # Code and idea originally from Justin Johnson's architecture.
34 | # https://github.com/jcjohnson/fast-neural-style/
35 | class ResnetGenerator(nn.Module):
36 | def __init__(self, input_nc=3, output_nc=3, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, use_parallel=True, learn_residual=True, padding_type='reflect'):
37 | assert(n_blocks >= 0)
38 | super(ResnetGenerator, self).__init__()
39 | self.input_nc = input_nc
40 | self.output_nc = output_nc
41 | self.ngf = ngf
42 | self.use_parallel = use_parallel
43 | self.learn_residual = learn_residual
44 | if type(norm_layer) == functools.partial:
45 | use_bias = norm_layer.func == nn.InstanceNorm2d
46 | else:
47 | use_bias = norm_layer == nn.InstanceNorm2d
48 |
49 | model = [nn.ReflectionPad2d(3),
50 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
51 | bias=use_bias),
52 | norm_layer(ngf),
53 | nn.ReLU(True)]
54 |
55 | n_downsampling = 2
56 | for i in range(n_downsampling):
57 | mult = 2**i
58 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
59 | stride=2, padding=1, bias=use_bias),
60 | norm_layer(ngf * mult * 2),
61 | nn.ReLU(True)]
62 |
63 | mult = 2**n_downsampling
64 | for i in range(n_blocks):
65 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
66 |
67 | for i in range(n_downsampling):
68 | mult = 2**(n_downsampling - i)
69 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
70 | kernel_size=3, stride=2,
71 | padding=1, output_padding=1,
72 | bias=use_bias),
73 | norm_layer(int(ngf * mult / 2)),
74 | nn.ReLU(True)]
75 | model += [nn.ReflectionPad2d(3)]
76 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
77 | model += [nn.Tanh()]
78 |
79 | self.model = nn.Sequential(*model)
80 |
81 | def forward(self, input):
82 | output = self.model(input)
83 | if self.learn_residual:
84 | output = input + output
85 | output = torch.clamp(output, min=-1, max=1)
86 | return output
87 |
88 |
89 | # Define a resnet block
90 | class ResnetBlock(nn.Module):
91 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
92 | super(ResnetBlock, self).__init__()
93 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
94 |
95 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
96 | conv_block = []
97 | p = 0
98 | if padding_type == 'reflect':
99 | conv_block += [nn.ReflectionPad2d(1)]
100 | elif padding_type == 'replicate':
101 | conv_block += [nn.ReplicationPad2d(1)]
102 | elif padding_type == 'zero':
103 | p = 1
104 | else:
105 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
106 |
107 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
108 | norm_layer(dim),
109 | nn.ReLU(True)]
110 | if use_dropout:
111 | conv_block += [nn.Dropout(0.5)]
112 |
113 | p = 0
114 | if padding_type == 'reflect':
115 | conv_block += [nn.ReflectionPad2d(1)]
116 | elif padding_type == 'replicate':
117 | conv_block += [nn.ReplicationPad2d(1)]
118 | elif padding_type == 'zero':
119 | p = 1
120 | else:
121 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
122 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
123 | norm_layer(dim)]
124 |
125 | return nn.Sequential(*conv_block)
126 |
127 | def forward(self, x):
128 | out = x + self.conv_block(x)
129 | return out
130 |
131 |
132 | class DicsriminatorTail(nn.Module):
133 | def __init__(self, nf_mult, n_layers, ndf=64, norm_layer=nn.BatchNorm2d, use_parallel=True):
134 | super(DicsriminatorTail, self).__init__()
135 | self.use_parallel = use_parallel
136 | if type(norm_layer) == functools.partial:
137 | use_bias = norm_layer.func == nn.InstanceNorm2d
138 | else:
139 | use_bias = norm_layer == nn.InstanceNorm2d
140 |
141 | kw = 4
142 | padw = int(np.ceil((kw-1)/2))
143 |
144 | nf_mult_prev = nf_mult
145 | nf_mult = min(2**n_layers, 8)
146 | sequence = [
147 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
148 | kernel_size=kw, stride=1, padding=padw, bias=use_bias),
149 | norm_layer(ndf * nf_mult),
150 | nn.LeakyReLU(0.2, True)
151 | ]
152 |
153 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
154 |
155 | self.model = nn.Sequential(*sequence)
156 |
157 | def forward(self, input):
158 | return self.model(input)
159 |
160 |
161 | class MultiScaleDiscriminator(nn.Module):
162 | def __init__(self, input_nc=3, ndf=64, norm_layer=nn.BatchNorm2d, use_parallel=True):
163 | super(MultiScaleDiscriminator, self).__init__()
164 | self.use_parallel = use_parallel
165 | if type(norm_layer) == functools.partial:
166 | use_bias = norm_layer.func == nn.InstanceNorm2d
167 | else:
168 | use_bias = norm_layer == nn.InstanceNorm2d
169 |
170 | kw = 4
171 | padw = int(np.ceil((kw-1)/2))
172 | sequence = [
173 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
174 | nn.LeakyReLU(0.2, True)
175 | ]
176 |
177 | nf_mult = 1
178 | for n in range(1, 3):
179 | nf_mult_prev = nf_mult
180 | nf_mult = min(2**n, 8)
181 | sequence += [
182 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
183 | kernel_size=kw, stride=2, padding=padw, bias=use_bias),
184 | norm_layer(ndf * nf_mult),
185 | nn.LeakyReLU(0.2, True)
186 | ]
187 |
188 | self.scale_one = nn.Sequential(*sequence)
189 | self.first_tail = DicsriminatorTail(nf_mult=nf_mult, n_layers=3)
190 | nf_mult_prev = 4
191 | nf_mult = 8
192 |
193 | self.scale_two = nn.Sequential(
194 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
195 | kernel_size=kw, stride=2, padding=padw, bias=use_bias),
196 | norm_layer(ndf * nf_mult),
197 | nn.LeakyReLU(0.2, True))
198 | nf_mult_prev = nf_mult
199 | self.second_tail = DicsriminatorTail(nf_mult=nf_mult, n_layers=4)
200 | self.scale_three = nn.Sequential(
201 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
202 | norm_layer(ndf * nf_mult),
203 | nn.LeakyReLU(0.2, True))
204 | self.third_tail = DicsriminatorTail(nf_mult=nf_mult, n_layers=5)
205 |
206 | def forward(self, input):
207 | x = self.scale_one(input)
208 | x_1 = self.first_tail(x)
209 | x = self.scale_two(x)
210 | x_2 = self.second_tail(x)
211 | x = self.scale_three(x)
212 | x = self.third_tail(x)
213 | return [x_1, x_2, x]
214 |
215 |
216 | # Defines the PatchGAN discriminator with the specified arguments.
217 | class NLayerDiscriminator(nn.Module):
218 | def __init__(self, input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, use_parallel=True):
219 | super(NLayerDiscriminator, self).__init__()
220 | self.use_parallel = use_parallel
221 | if type(norm_layer) == functools.partial:
222 | use_bias = norm_layer.func == nn.InstanceNorm2d
223 | else:
224 | use_bias = norm_layer == nn.InstanceNorm2d
225 |
226 | kw = 4
227 | padw = int(np.ceil((kw-1)/2))
228 | sequence = [
229 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
230 | nn.LeakyReLU(0.2, True)
231 | ]
232 |
233 | nf_mult = 1
234 | for n in range(1, n_layers):
235 | nf_mult_prev = nf_mult
236 | nf_mult = min(2**n, 8)
237 | sequence += [
238 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
239 | kernel_size=kw, stride=2, padding=padw, bias=use_bias),
240 | norm_layer(ndf * nf_mult),
241 | nn.LeakyReLU(0.2, True)
242 | ]
243 |
244 | nf_mult_prev = nf_mult
245 | nf_mult = min(2**n_layers, 8)
246 | sequence += [
247 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
248 | kernel_size=kw, stride=1, padding=padw, bias=use_bias),
249 | norm_layer(ndf * nf_mult),
250 | nn.LeakyReLU(0.2, True)
251 | ]
252 |
253 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
254 |
255 | if use_sigmoid:
256 | sequence += [nn.Sigmoid()]
257 |
258 | self.model = nn.Sequential(*sequence)
259 |
260 | def forward(self, input):
261 | return self.model(input)
262 |
263 |
264 | def get_fullD(model_config):
265 | model_d = NLayerDiscriminator(n_layers=5,
266 | norm_layer=get_norm_layer(norm_type=model_config['norm_layer']),
267 | use_sigmoid=False)
268 | return model_d
269 |
270 |
271 | def get_generator(model_config):
272 | generator_name = model_config['g_name']
273 | if generator_name == 'resnet':
274 | model_g = ResnetGenerator(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']),
275 | use_dropout=model_config['dropout'],
276 | n_blocks=model_config['blocks'],
277 | learn_residual=model_config['learn_residual'])
278 | elif generator_name == 'fpn_mobilenet':
279 | model_g = FPNMobileNet(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']))
280 | elif generator_name == 'fpn_inception':
281 | model_g = FPNInception(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']))
282 | elif generator_name == 'fpn_inception_simple':
283 | model_g = FPNInceptionSimple(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']))
284 | elif generator_name == 'fpn_dense':
285 | model_g = FPNDense()
286 | elif generator_name == 'unet_seresnext':
287 | model_g = UNetSEResNext(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']),
288 | pretrained=model_config['pretrained'])
289 | else:
290 | raise ValueError("Generator Network [%s] not recognized." % generator_name)
291 |
292 | return nn.DataParallel(model_g)
293 |
294 |
295 | def get_discriminator(model_config):
296 | discriminator_name = model_config['d_name']
297 | if discriminator_name == 'no_gan':
298 | model_d = None
299 | elif discriminator_name == 'patch_gan':
300 | model_d = NLayerDiscriminator(n_layers=model_config['d_layers'],
301 | norm_layer=get_norm_layer(norm_type=model_config['norm_layer']),
302 | use_sigmoid=False)
303 | model_d = nn.DataParallel(model_d)
304 | elif discriminator_name == 'double_gan':
305 | patch_gan = NLayerDiscriminator(n_layers=model_config['d_layers'],
306 | norm_layer=get_norm_layer(norm_type=model_config['norm_layer']),
307 | use_sigmoid=False)
308 | patch_gan = nn.DataParallel(patch_gan)
309 | full_gan = get_fullD(model_config)
310 | full_gan = nn.DataParallel(full_gan)
311 | model_d = {'patch': patch_gan,
312 | 'full': full_gan}
313 | elif discriminator_name == 'multi_scale':
314 | model_d = MultiScaleDiscriminator(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']))
315 | model_d = nn.DataParallel(model_d)
316 | else:
317 | raise ValueError("Discriminator Network [%s] not recognized." % discriminator_name)
318 |
319 | return model_d
320 |
321 |
322 | def get_nets(model_config):
323 | return get_generator(model_config), get_discriminator(model_config)
324 |
--------------------------------------------------------------------------------
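
A sketch of driving the factories above with a hand-written config. The keys mirror the 'model' section of config/config.yaml; the exact values are assumptions, and the Inception-based generator downloads its pretrained backbone on first use:

    from models.networks import get_nets

    model_config = {
        'g_name': 'fpn_inception',  # other choices need extra keys, e.g. 'resnet' also reads dropout/blocks/learn_residual
        'norm_layer': 'instance',
        'd_name': 'double_gan',     # or 'no_gan', 'patch_gan', 'multi_scale'
        'd_layers': 3,
    }
    netG, netD = get_nets(model_config)
    # netG is an nn.DataParallel-wrapped generator; with 'double_gan',
    # netD is a dict holding a 'patch' and a 'full' discriminator.
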
/examples/neural_networks/DeblurGANv2/models/unet_seresnext.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.parallel
4 | import torch.optim
5 | import torch.utils.data
6 | from torch.nn import Sequential
7 | from collections import OrderedDict
8 | import torchvision
9 | from torch.nn import functional as F
10 | from models.senet import se_resnext50_32x4d
11 |
12 |
13 | def conv3x3(in_, out):
14 | return nn.Conv2d(in_, out, 3, padding=1)
15 |
16 |
17 | class ConvRelu(nn.Module):
18 | def __init__(self, in_, out):
19 | super(ConvRelu, self).__init__()
20 | self.conv = conv3x3(in_, out)
21 | self.activation = nn.ReLU(inplace=True)
22 |
23 | def forward(self, x):
24 | x = self.conv(x)
25 | x = self.activation(x)
26 | return x
27 |
28 | class UNetSEResNext(nn.Module):
29 |
30 | def __init__(self, num_classes=3, num_filters=32,
31 | pretrained=True, is_deconv=True):
32 | super().__init__()
33 | self.num_classes = num_classes
34 | pretrain = 'imagenet' if pretrained is True else None
35 | self.encoder = se_resnext50_32x4d(num_classes=1000, pretrained=pretrain)
36 | bottom_channel_nr = 2048
37 |
38 | self.conv1 = self.encoder.layer0
39 | #self.se_e1 = SCSEBlock(64)
40 | self.conv2 = self.encoder.layer1
41 | #self.se_e2 = SCSEBlock(64 * 4)
42 | self.conv3 = self.encoder.layer2
43 | #self.se_e3 = SCSEBlock(128 * 4)
44 | self.conv4 = self.encoder.layer3
45 | #self.se_e4 = SCSEBlock(256 * 4)
46 | self.conv5 = self.encoder.layer4
47 | #self.se_e5 = SCSEBlock(512 * 4)
48 |
49 | self.center = DecoderCenter(bottom_channel_nr, num_filters * 8 *2, num_filters * 8, False)
50 |
51 | self.dec5 = DecoderBlockV(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 2, is_deconv)
52 | #self.se_d5 = SCSEBlock(num_filters * 2)
53 | self.dec4 = DecoderBlockV(bottom_channel_nr // 2 + num_filters * 2, num_filters * 8, num_filters * 2, is_deconv)
54 | #self.se_d4 = SCSEBlock(num_filters * 2)
55 | self.dec3 = DecoderBlockV(bottom_channel_nr // 4 + num_filters * 2, num_filters * 4, num_filters * 2, is_deconv)
56 | #self.se_d3 = SCSEBlock(num_filters * 2)
57 | self.dec2 = DecoderBlockV(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2, num_filters * 2, is_deconv)
58 | #self.se_d2 = SCSEBlock(num_filters * 2)
59 | self.dec1 = DecoderBlockV(num_filters * 2, num_filters, num_filters * 2, is_deconv)
60 | #self.se_d1 = SCSEBlock(num_filters * 2)
61 | self.dec0 = ConvRelu(num_filters * 10, num_filters * 2)
62 | self.final = nn.Conv2d(num_filters * 2, num_classes, kernel_size=1)
63 |
64 | def forward(self, x):
65 | conv1 = self.conv1(x)
66 | #conv1 = self.se_e1(conv1)
67 | conv2 = self.conv2(conv1)
68 | #conv2 = self.se_e2(conv2)
69 | conv3 = self.conv3(conv2)
70 | #conv3 = self.se_e3(conv3)
71 | conv4 = self.conv4(conv3)
72 | #conv4 = self.se_e4(conv4)
73 | conv5 = self.conv5(conv4)
74 | #conv5 = self.se_e5(conv5)
75 |
76 | center = self.center(conv5)
77 | dec5 = self.dec5(torch.cat([center, conv5], 1))
78 | #dec5 = self.se_d5(dec5)
79 | dec4 = self.dec4(torch.cat([dec5, conv4], 1))
80 | #dec4 = self.se_d4(dec4)
81 | dec3 = self.dec3(torch.cat([dec4, conv3], 1))
82 | #dec3 = self.se_d3(dec3)
83 | dec2 = self.dec2(torch.cat([dec3, conv2], 1))
84 | #dec2 = self.se_d2(dec2)
85 | dec1 = self.dec1(dec2)
86 | #dec1 = self.se_d1(dec1)
87 |
88 | f = torch.cat((
89 | dec1,
90 | F.upsample(dec2, scale_factor=2, mode='bilinear', align_corners=False),
91 | F.upsample(dec3, scale_factor=4, mode='bilinear', align_corners=False),
92 | F.upsample(dec4, scale_factor=8, mode='bilinear', align_corners=False),
93 | F.upsample(dec5, scale_factor=16, mode='bilinear', align_corners=False),
94 | ), 1)
95 |
96 | dec0 = self.dec0(f)
97 |
98 | return self.final(dec0)
99 |
100 | class DecoderBlockV(nn.Module):
101 | def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
102 | super(DecoderBlockV, self).__init__()
103 | self.in_channels = in_channels
104 |
105 | if is_deconv:
106 | self.block = nn.Sequential(
107 | ConvRelu(in_channels, middle_channels),
108 | nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2,
109 | padding=1),
110 | nn.InstanceNorm2d(out_channels, affine=False),
111 | nn.ReLU(inplace=True)
112 |
113 | )
114 | else:
115 | self.block = nn.Sequential(
116 | nn.Upsample(scale_factor=2, mode='bilinear'),
117 | ConvRelu(in_channels, middle_channels),
118 | ConvRelu(middle_channels, out_channels),
119 | )
120 |
121 | def forward(self, x):
122 | return self.block(x)
123 |
124 |
125 |
126 | class DecoderCenter(nn.Module):
127 | def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
128 | super(DecoderCenter, self).__init__()
129 | self.in_channels = in_channels
130 |
131 |
132 | if is_deconv:
133 | """
134 | Parameters for the deconvolution were chosen to avoid checkerboard artifacts,
135 | following https://distill.pub/2016/deconv-checkerboard/
136 | """
137 |
138 | self.block = nn.Sequential(
139 | ConvRelu(in_channels, middle_channels),
140 | nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2,
141 | padding=1),
142 | nn.InstanceNorm2d(out_channels, affine=False),
143 | nn.ReLU(inplace=True)
144 | )
145 | else:
146 | self.block = nn.Sequential(
147 | ConvRelu(in_channels, middle_channels),
148 | ConvRelu(middle_channels, out_channels)
149 |
150 | )
151 |
152 | def forward(self, x):
153 | return self.block(x)
154 |
--------------------------------------------------------------------------------
/examples/neural_networks/DeblurGANv2/picture_to_video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 |
4 | # # image path
5 | # im_dir = 'D:\github repo\DeblurGANv2\submit'
6 | # # output video path
7 | # video_dir = 'D:\github repo\DeblurGANv2'
8 | # if not os.path.exists(video_dir):
9 | # os.makedirs(video_dir)
10 | # # set saved fps
11 | # fps = 20
12 | # # get frames list
13 | # frames = sorted(os.listdir(im_dir))
14 | # # w,h of image
15 | # img = cv2.imread(os.path.join(im_dir, frames[0]))
16 | # #
17 | # img_size = (img.shape[1], img.shape[0])
18 | # # get seq name
19 | # seq_name = os.path.dirname(im_dir).split('/')[-1]
20 | # # splice video_dir
21 | # video_dir = os.path.join(video_dir, seq_name + '.avi')
22 | # fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
23 | # # also can write like:fourcc = cv2.VideoWriter_fourcc(*'MJPG')
24 | # # if want to write .mp4 file, use 'MP4V'
25 | # videowriter = cv2.VideoWriter(video_dir, fourcc, fps, img_size)
26 | #
27 | # for frame in frames:
28 | # f_path = os.path.join(im_dir, frame)
29 | # image = cv2.imread(f_path)
30 | # videowriter.write(image)
31 | # print(frame + " has been written!")
32 | #
33 | # videowriter.release()
34 |
35 |
36 | # def main():
37 | # data_path = 'D:\github repo\DeblurGANv2\submit'
38 | # fps = 30  # video frame rate
39 | # size = (1280, 720)  # size of the images to be turned into video frames
40 | # video = cv2.VideoWriter("output.avi", cv2.VideoWriter_fourcc('I', '4', '2', '0'), fps, size)
41 | #
42 | # for i in range(1029):
43 | # image_path = data_path + "%010d_color_labels.png" % (i + 1)
44 | # print(image_path)
45 | # img = cv2.imread(image_path)
46 | # video.write(img)
47 | # video.release()
48 | #
49 | #
50 | #
51 | #
52 | # cv2.destroyAllWindows()
53 | #
54 | #
55 | # if __name__ == "__main__":
56 | # main()
57 |
58 |
59 | import cv2
60 | import os
61 | import numpy as np
62 | from PIL import Image
63 |
64 |
65 | def frame2video(im_dir, video_dir, fps):
66 | im_list = os.listdir(im_dir)
67 | im_list.sort(key=lambda x: int(x.replace("frame", "").split('.')[0]))  # double-check that this gives the correct frame order
68 | img = Image.open(os.path.join(im_dir, im_list[0]))
69 | img_size = img.size  # image resolution; all images under im_dir must share this resolution
70 |
71 | # fourcc = cv2.cv.CV_FOURCC('M','J','P','G')  # for OpenCV 2.x
72 | fourcc = cv2.VideoWriter_fourcc(*'XVID')  # for OpenCV 3.x and later
73 | videoWriter = cv2.VideoWriter(video_dir, fourcc, fps, img_size)
74 | # count = 1
75 | for i in im_list:
76 | im_name = os.path.join(im_dir, i)
77 | frame = cv2.imdecode(np.fromfile(im_name, dtype=np.uint8), -1)
78 | videoWriter.write(frame)
79 | # count+=1
80 | # if (count == 200):
81 | # print(im_name)
82 | # break
83 | videoWriter.release()
84 | print('finish')
85 |
86 |
87 | if __name__ == '__main__':
88 | # im_dir = 'D:\github repo\DeblurGANv2\submit\\'  # directory containing the frames
89 | im_dir = 'D:\github repo\DeblurGANv2\dataset1\\blur\\'  # directory containing the frames
90 | video_dir = 'D:\github repo\DeblurGANv2/test.avi'  # path of the assembled output video
91 | fps = 30  # frames per second; higher fps gives smoother motion
92 | frame2video(im_dir, video_dir, fps)
--------------------------------------------------------------------------------
/examples/neural_networks/DeblurGANv2/predict.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | from typing import Optional
4 |
5 | import cv2
6 | import numpy as np
7 | import torch
8 | import yaml
9 | from fire import Fire
10 | from tqdm import tqdm
11 |
12 | from aug import get_normalize
13 | from models.networks import get_generator
14 |
15 |
16 | class Predictor:
17 | def __init__(self, weights_path: str, model_name: str = ''):
18 | with open('config/config.yaml',encoding='utf-8') as cfg:
19 | config = yaml.load(cfg, Loader=yaml.FullLoader)
20 | model = get_generator(model_name or config['model'])
21 | model.load_state_dict(torch.load(weights_path)['model'])
22 | self.model = model.cuda()
23 | self.model.train(True)
24 | # GAN inference should be in train mode to use actual stats in norm layers,
25 | # it's not a bug
26 | self.normalize_fn = get_normalize()
27 |
28 | @staticmethod
29 | def _array_to_batch(x):
30 | x = np.transpose(x, (2, 0, 1))
31 | x = np.expand_dims(x, 0)
32 | return torch.from_numpy(x)
33 |
34 | def _preprocess(self, x: np.ndarray, mask: Optional[np.ndarray]):
35 | x, _ = self.normalize_fn(x, x)
36 | if mask is None:
37 | mask = np.ones_like(x, dtype=np.float32)
38 | else:
39 | mask = np.round(mask.astype('float32') / 255)
40 |
41 | h, w, _ = x.shape
42 | block_size = 32
43 | min_height = (h // block_size + 1) * block_size
44 | min_width = (w // block_size + 1) * block_size
45 |
46 | pad_params = {'mode': 'constant',
47 | 'constant_values': 0,
48 | 'pad_width': ((0, min_height - h), (0, min_width - w), (0, 0))
49 | }
50 | x = np.pad(x, **pad_params)
51 | mask = np.pad(mask, **pad_params)
52 |
53 | return map(self._array_to_batch, (x, mask)), h, w
54 |
55 | @staticmethod
56 | def _postprocess(x: torch.Tensor) -> np.ndarray:
57 | x, = x
58 | x = x.detach().cpu().float().numpy()
59 | x = (np.transpose(x, (1, 2, 0)) + 1) / 2.0 * 255.0
60 | return x.astype('uint8')
61 |
62 | def __call__(self, img: np.ndarray, mask: Optional[np.ndarray], ignore_mask=True) -> np.ndarray:
63 | (img, mask), h, w = self._preprocess(img, mask)
64 | with torch.no_grad():
65 | inputs = [img.cuda()]
66 | if not ignore_mask:
67 | inputs += [mask]
68 | pred = self.model(*inputs)
69 | return self._postprocess(pred)[:h, :w, :]
70 |
71 | def process_video(pairs, predictor, output_dir):
72 | for video_filepath, mask in tqdm(pairs):
73 | video_filename = os.path.basename(video_filepath)
74 | output_filepath = os.path.join(output_dir, os.path.splitext(video_filename)[0]+'_deblur.mp4')
75 | video_in = cv2.VideoCapture(video_filepath)
76 | fps = video_in.get(cv2.CAP_PROP_FPS)
77 | width = int(video_in.get(cv2.CAP_PROP_FRAME_WIDTH))
78 | height = int(video_in.get(cv2.CAP_PROP_FRAME_HEIGHT))
79 | total_frame_num = int(video_in.get(cv2.CAP_PROP_FRAME_COUNT))
80 | video_out = cv2.VideoWriter(output_filepath, cv2.VideoWriter_fourcc(*'MP4V'), fps, (width, height))
81 | tqdm.write(f'process {video_filepath} to {output_filepath}, {fps}fps, resolution: {width}x{height}')
82 | for frame_num in tqdm(range(total_frame_num), desc=video_filename):
83 | res, img = video_in.read()
84 | if not res:
85 | break
86 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
87 | pred = predictor(img, mask)
88 | pred = cv2.cvtColor(pred, cv2.COLOR_RGB2BGR)
89 | video_out.write(pred)
90 |
91 | def main(img_pattern: str,
92 | mask_pattern: Optional[str] = None,
93 | weights_path='fpn_inception.h5',
94 | out_dir='submit/',
95 | side_by_side: bool = False,
96 | video: bool = False):
97 | def sorted_glob(pattern):
98 | return sorted(glob(pattern))
99 |
100 | imgs = sorted_glob(img_pattern)
101 | masks = sorted_glob(mask_pattern) if mask_pattern is not None else [None for _ in imgs]
102 | pairs = zip(imgs, masks)
103 | names = sorted([os.path.basename(x) for x in glob(img_pattern)])
104 | predictor = Predictor(weights_path=weights_path)
105 |
106 | os.makedirs(out_dir, exist_ok=True)
107 | if not video:
108 | for name, pair in tqdm(zip(names, pairs), total=len(names)):
109 | f_img, f_mask = pair
110 | img, mask = map(cv2.imread, (f_img, f_mask))
111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
112 |
113 | pred = predictor(img, mask)
114 | if side_by_side:
115 | pred = np.hstack((img, pred))
116 | pred = cv2.cvtColor(pred, cv2.COLOR_RGB2BGR)
117 | cv2.imwrite(os.path.join(out_dir, name),
118 | pred)
119 | else:
120 | process_video(pairs, predictor, out_dir)
121 |
122 | # def getfiles():
123 | # filenames = os.listdir(r'.\dataset1\blur')
124 | # print(filenames)
125 | def get_files():
126 | file_list = []
127 | for filepath, dirnames, filenames in os.walk(r'.\dataset1\blur'):
128 | for filename in filenames:
129 | file_list.append(os.path.join(filepath, filename))
130 | return file_list
131 |
132 |
133 |
134 |
135 |
136 | if __name__ == '__main__':
137 | # Fire(main)
138 | # added: batch-process every image under .\dataset1\blur:
139 | img_path=get_files()
140 | for i in img_path:
141 | main(i)
142 | # main('test_img/tt.mp4')
143 |
--------------------------------------------------------------------------------
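
Note that _preprocess above pads the height and width up to a multiple of 32 (a 720x1280 frame becomes 736x1312, since (720 // 32 + 1) * 32 = 736 and (1280 // 32 + 1) * 32 = 1312), and __call__ crops the prediction back to the original size. A minimal single-image sketch, run from the DeblurGANv2 directory on a CUDA machine; the weights file and image path are assumptions:

    import cv2
    from predict import Predictor

    predictor = Predictor(weights_path='fpn_inception.h5')           # assumed pretrained weights
    img = cv2.cvtColor(cv2.imread('blurry.png'), cv2.COLOR_BGR2RGB)  # assumed input image
    deblurred = predictor(img, mask=None)                            # uint8 RGB, same size as img
    cv2.imwrite('deblurred.png', cv2.cvtColor(deblurred, cv2.COLOR_RGB2BGR))
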
/examples/neural_networks/DeblurGANv2/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.0.1
2 | torchvision
3 | torchsummary
4 | pretrainedmodels
5 | numpy
6 | opencv-python-headless
7 | joblib
8 | albumentations>=1.0.0
9 | scikit-image==0.18.1
10 | tqdm
11 | glog
12 | tensorboardx
13 | fire
14 | # this file is not ready yet
15 |
--------------------------------------------------------------------------------
/examples/neural_networks/DeblurGANv2/schedulers.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from torch.optim import lr_scheduler
4 |
5 |
6 | class WarmRestart(lr_scheduler.CosineAnnealingLR):
7 | """This class implements Stochastic Gradient Descent with Warm Restarts (SGDR): https://arxiv.org/abs/1608.03983.
8 |
9 | Sets the learning rate of each parameter group using a cosine annealing schedule. When last_epoch=-1, the initial lr is used.
10 | This scheduler does not support scheduler.step(epoch); please keep epoch=None.
11 | """
12 |
13 | def __init__(self, optimizer, T_max=30, T_mult=1, eta_min=0, last_epoch=-1):
14 | """implements SGDR
15 |
16 | Parameters:
17 | ----------
18 | T_max : int
19 | Maximum number of epochs.
20 | T_mult : int
21 | Multiplicative factor of T_max.
22 | eta_min : int
23 | Minimum learning rate. Default: 0.
24 | last_epoch : int
25 | The index of last epoch. Default: -1.
26 | """
27 | self.T_mult = T_mult
28 | super().__init__(optimizer, T_max, eta_min, last_epoch)
29 |
30 | def get_lr(self):
31 | if self.last_epoch == self.T_max:
32 | self.last_epoch = 0
33 | self.T_max *= self.T_mult
34 | return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2 for
35 | base_lr in self.base_lrs]
36 |
37 |
38 | class LinearDecay(lr_scheduler._LRScheduler):
39 | """This class implements LinearDecay
40 |
41 | """
42 |
43 | def __init__(self, optimizer, num_epochs, start_epoch=0, min_lr=0, last_epoch=-1):
44 | """Implements LinearDecay.
45 |
46 | Parameters:
47 | ----------
48 | num_epochs : total epochs of decay; start_epoch : epoch at which the decay begins; min_lr : final learning rate.
49 | """
50 | self.num_epochs = num_epochs
51 | self.start_epoch = start_epoch
52 | self.min_lr = min_lr
53 | super().__init__(optimizer, last_epoch)
54 |
55 | def get_lr(self):
56 | if self.last_epoch < self.start_epoch:
57 | return self.base_lrs
58 | return [base_lr - ((base_lr - self.min_lr) / self.num_epochs) * (self.last_epoch - self.start_epoch) for
59 | base_lr in self.base_lrs]
60 |
--------------------------------------------------------------------------------
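
To make the schedules above concrete: WarmRestart follows eta_min + (base_lr - eta_min) * (1 + cos(pi * last_epoch / T_max)) / 2, so with a base lr of 1e-4 and T_max=30 the lr is 1e-4 at epoch 0, 5e-5 at epoch 15, and jumps back to 1e-4 at epoch 30; LinearDecay instead ramps the lr linearly from base_lr down to min_lr over num_epochs, starting at start_epoch. A minimal sketch (the dummy parameter, optimizer and lr are assumptions):

    import torch
    from torch import nn, optim
    from schedulers import WarmRestart

    param = nn.Parameter(torch.zeros(1))
    optimizer = optim.Adam([param], lr=1e-4)
    scheduler = WarmRestart(optimizer, T_max=30)
    for epoch in range(60):
        optimizer.step()   # a real training step would go here
        scheduler.step()   # lr follows the cosine curve and restarts every T_max epochs
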
/examples/neural_networks/DeblurGANv2/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | print()
--------------------------------------------------------------------------------
/examples/neural_networks/DeblurGANv2/test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python3 -m unittest discover $(pwd)
4 |
--------------------------------------------------------------------------------
/examples/neural_networks/DeblurGANv2/test_aug.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy as np
4 |
5 | from aug import get_transforms
6 |
7 |
8 | class AugTest(unittest.TestCase):
9 | @staticmethod
10 | def make_images():
11 | img = (np.random.rand(100, 100, 3) * 255).astype('uint8')
12 | return img.copy(), img.copy()
13 |
14 | def test_aug(self):
15 | for scope in ('strong', 'weak'):
16 | for crop in ('random', 'center'):
17 | aug_pipeline = get_transforms(80, scope=scope, crop=crop)
18 | a, b = self.make_images()
19 | a, b = aug_pipeline(a, b)
20 | np.testing.assert_allclose(a, b)
21 |
--------------------------------------------------------------------------------
/examples/neural_networks/DeblurGANv2/test_batchsize.py:
--------------------------------------------------------------------------------
1 | # import os
2 | # def get_files():
3 | # list=[]
4 | # for filepath,dirnames,filenames in os.walk(r'.\dataset1\blur'):
5 | # for filename in filenames:
6 | # list.append(os.path.join(filepath,filename))
7 | # return list
8 | #
9 | # a=get_files()
10 | # print(len(a))
11 | #
12 | # # for i in a:
13 | # # print(i)
14 | import cv2
15 | print(cv2.__version__)
--------------------------------------------------------------------------------
/examples/neural_networks/DeblurGANv2/test_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import unittest
3 | from shutil import rmtree
4 | from tempfile import mkdtemp
5 |
6 | import cv2
7 | import numpy as np
8 | from torch.utils.data import DataLoader
9 |
10 | from dataset import PairedDataset
11 |
12 |
13 | def make_img():
14 | return (np.random.rand(100, 100, 3) * 255).astype('uint8')
15 |
16 |
17 | class AugTest(unittest.TestCase):
18 | tmp_dir = mkdtemp()
19 | raw = os.path.join(tmp_dir, 'raw')
20 | gt = os.path.join(tmp_dir, 'gt')
21 |
22 | def setUp(self):
23 | for d in (self.raw, self.gt):
24 | os.makedirs(d)
25 |
26 | for i in range(5):
27 | for d in (self.raw, self.gt):
28 | img = make_img()
29 | cv2.imwrite(os.path.join(d, f'{i}.png'), img)
30 |
31 | def tearDown(self):
32 | rmtree(self.tmp_dir)
33 |
34 | def dataset_gen(self, equal=True):
35 | base_config = {'files_a': os.path.join(self.raw, '*.png'),
36 | 'files_b': os.path.join(self.raw if equal else self.gt, '*.png'),
37 | 'size': 32,
38 | }
39 | for b in ([0, 1], [0, 0.9]):
40 | for scope in ('strong', 'weak'):
41 | for crop in ('random', 'center'):
42 | for preload in (0, 1):
43 | for preload_size in (0, 64):
44 | config = base_config.copy()
45 | config['bounds'] = b
46 | config['scope'] = scope
47 | config['crop'] = crop
48 | config['preload'] = preload
49 | config['preload_size'] = preload_size
50 | config['verbose'] = False
51 | dataset = PairedDataset.from_config(config)
52 | yield dataset
53 |
54 | def test_equal_datasets(self):
55 | for dataset in self.dataset_gen(equal=True):
56 | dataloader = DataLoader(dataset=dataset,
57 | batch_size=2,
58 | shuffle=True,
59 | drop_last=True)
60 | dataloader = iter(dataloader)
61 | batch = next(dataloader)
62 | a, b = map(lambda x: x.numpy(), map(batch.get, ('a', 'b')))
63 |
64 | np.testing.assert_allclose(a, b)
65 |
66 | def test_datasets(self):
67 | for dataset in self.dataset_gen(equal=False):
68 | dataloader = DataLoader(dataset=dataset,
69 | batch_size=2,
70 | shuffle=True,
71 | drop_last=True)
72 | dataloader = iter(dataloader)
73 | batch = next(dataloader)
74 | a, b = map(lambda x: x.numpy(), map(batch.get, ('a', 'b')))
75 |
76 | assert not np.all(a == b), 'images should not be the same'
77 |
--------------------------------------------------------------------------------
/examples/neural_networks/DeblurGANv2/test_metrics.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import numpy as np
4 | import torch
5 | import cv2
6 | import yaml
7 | import os
8 | from torchvision import models, transforms
9 | from torch.autograd import Variable
10 | import shutil
11 | import glob
12 | import tqdm
13 | from util.metrics import PSNR
14 | from albumentations import Compose, CenterCrop, PadIfNeeded
15 | from PIL import Image
16 | from ssim.ssimlib import SSIM
17 | from models.networks import get_generator
18 |
19 |
20 | def get_args():
21 | parser = argparse.ArgumentParser('Test an image')
22 | parser.add_argument('--img_folder', required=True, help='GoPRO Folder')
23 | parser.add_argument('--weights_path', required=True, help='Weights path')
24 |
25 | return parser.parse_args()
26 |
27 |
28 | def prepare_dirs(path):
29 | if os.path.exists(path):
30 | shutil.rmtree(path)
31 | os.makedirs(path)
32 |
33 |
34 | def get_gt_image(path):
35 | dir, filename = os.path.split(path)
36 | base, seq = os.path.split(dir)
37 | base, _ = os.path.split(base)
38 | img = cv2.cvtColor(cv2.imread(os.path.join(base, 'sharp', seq, filename)), cv2.COLOR_BGR2RGB)
39 | return img
40 |
41 |
42 | def test_image(model, image_path):
43 | img_transforms = transforms.Compose([
44 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
45 | ])
46 | size_transform = Compose([
47 | PadIfNeeded(736, 1280)
48 | ])
49 | crop = CenterCrop(720, 1280)
50 | img = cv2.imread(image_path)
51 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
52 | img_s = size_transform(image=img)['image']
53 | img_tensor = torch.from_numpy(np.transpose(img_s / 255, (2, 0, 1)).astype('float32'))
54 | img_tensor = img_transforms(img_tensor)
55 | with torch.no_grad():
56 | img_tensor = Variable(img_tensor.unsqueeze(0).cuda())
57 | result_image = model(img_tensor)
58 | result_image = result_image[0].cpu().float().numpy()
59 | result_image = (np.transpose(result_image, (1, 2, 0)) + 1) / 2.0 * 255.0
60 | result_image = crop(image=result_image)['image']
61 | result_image = result_image.astype('uint8')
62 | gt_image = get_gt_image(image_path)
63 | _, filename = os.path.split(image_path)
64 | psnr = PSNR(result_image, gt_image)
65 | pilFake = Image.fromarray(result_image)
66 | pilReal = Image.fromarray(gt_image)
67 | ssim = SSIM(pilFake).cw_ssim_value(pilReal)
68 | return psnr, ssim
69 |
70 |
71 | def test(model, files):
72 | psnr = 0
73 | ssim = 0
74 | for file in tqdm.tqdm(files):
75 | cur_psnr, cur_ssim = test_image(model, file)
76 | psnr += cur_psnr
77 | ssim += cur_ssim
78 | print("PSNR = {}".format(psnr / len(files)))
79 | print("SSIM = {}".format(ssim / len(files)))
80 |
81 |
82 | if __name__ == '__main__':
83 | args = get_args()
84 | with open('config/config.yaml') as cfg:
85 | config = yaml.load(cfg, Loader=yaml.SafeLoader)
86 | model = get_generator(config['model'])
87 | model.load_state_dict(torch.load(args.weights_path)['model'])
88 | model = model.cuda()
89 | filenames = sorted(glob.glob(args.img_folder + '/test' + '/blur/**/*.png', recursive=True))
90 | test(model, filenames)
91 |
--------------------------------------------------------------------------------
/examples/neural_networks/DeblurGANv2/train.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from functools import partial
4 |
5 | import cv2
6 | import torch
7 | import torch.optim as optim
8 | import tqdm
9 | import yaml
10 | from joblib import cpu_count
11 | from torch.utils.data import DataLoader
12 |
13 | from adversarial_trainer import GANFactory
14 | from dataset import PairedDataset
15 | from metric_counter import MetricCounter
16 | from models.losses import get_loss
17 | from models.models import get_model
18 | from models.networks import get_nets
19 | from schedulers import LinearDecay, WarmRestart
20 | from fire import Fire
21 |
22 | cv2.setNumThreads(0)
23 |
24 |
25 | class Trainer:
26 | def __init__(self, config, train: DataLoader, val: DataLoader):
27 | self.config = config
28 | self.train_dataset = train
29 | self.val_dataset = val
30 | self.adv_lambda = config['model']['adv_lambda']
31 | self.metric_counter = MetricCounter(config['experiment_desc'])
32 | self.warmup_epochs = config['warmup_num']
33 |
34 | def train(self):
35 | self._init_params()
36 | for epoch in range(0, self.config['num_epochs']):
37 | if (epoch == self.warmup_epochs) and not (self.warmup_epochs == 0):
38 | self.netG.module.unfreeze()
39 | self.optimizer_G = self._get_optim(self.netG.parameters())
40 | self.scheduler_G = self._get_scheduler(self.optimizer_G)
41 | self._run_epoch(epoch)
42 | self._validate(epoch)
43 | self.scheduler_G.step()
44 | self.scheduler_D.step()
45 |
46 | if self.metric_counter.update_best_model():
47 | torch.save({
48 | 'model': self.netG.state_dict()
49 | }, 'best_{}.h5'.format(self.config['experiment_desc']))
50 | torch.save({
51 | 'model': self.netG.state_dict()
52 | }, 'last_{}.h5'.format(self.config['experiment_desc']))
53 | print(self.metric_counter.loss_message())
54 | logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s" % (
55 | self.config['experiment_desc'], epoch, self.metric_counter.loss_message()))
56 |
57 | def _run_epoch(self, epoch):
58 | self.metric_counter.clear()
59 | for param_group in self.optimizer_G.param_groups:
60 | lr = param_group['lr']
61 |
62 | epoch_size = self.config.get('train_batches_per_epoch') or len(self.train_dataset)
63 | tq = tqdm.tqdm(self.train_dataset, total=epoch_size)
64 | tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
65 | i = 0
66 | for data in tq:
67 | inputs, targets = self.model.get_input(data)
68 | outputs = self.netG(inputs)
69 | loss_D = self._update_d(outputs, targets)
70 | self.optimizer_G.zero_grad()
71 | loss_content = self.criterionG(outputs, targets)
72 | loss_adv = self.adv_trainer.loss_g(outputs, targets)
73 | loss_G = loss_content + self.adv_lambda * loss_adv
74 | loss_G.backward()
75 | self.optimizer_G.step()
76 | self.metric_counter.add_losses(loss_G.item(), loss_content.item(), loss_D)
77 | curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
78 | self.metric_counter.add_metrics(curr_psnr, curr_ssim)
79 | tq.set_postfix(loss=self.metric_counter.loss_message())
80 | if not i:
81 | self.metric_counter.add_image(img_for_vis, tag='train')
82 | i += 1
83 | if i > epoch_size:
84 | break
85 | tq.close()
86 | self.metric_counter.write_to_tensorboard(epoch)
87 |
88 | def _validate(self, epoch):
89 | self.metric_counter.clear()
90 | epoch_size = self.config.get('val_batches_per_epoch') or len(self.val_dataset)
91 | tq = tqdm.tqdm(self.val_dataset, total=epoch_size)
92 | tq.set_description('Validation')
93 | i = 0
94 | for data in tq:
95 | inputs, targets = self.model.get_input(data)
96 | with torch.no_grad():
97 | outputs = self.netG(inputs)
98 | loss_content = self.criterionG(outputs, targets)
99 | loss_adv = self.adv_trainer.loss_g(outputs, targets)
100 | loss_G = loss_content + self.adv_lambda * loss_adv
101 | self.metric_counter.add_losses(loss_G.item(), loss_content.item())
102 | curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
103 | self.metric_counter.add_metrics(curr_psnr, curr_ssim)
104 | if not i:
105 | self.metric_counter.add_image(img_for_vis, tag='val')
106 | i += 1
107 | if i > epoch_size:
108 | break
109 | tq.close()
110 | self.metric_counter.write_to_tensorboard(epoch, validation=True)
111 |
112 | def _update_d(self, outputs, targets):
113 | if self.config['model']['d_name'] == 'no_gan':
114 | return 0
115 | self.optimizer_D.zero_grad()
116 | loss_D = self.adv_lambda * self.adv_trainer.loss_d(outputs, targets)
117 | loss_D.backward(retain_graph=True)
118 | self.optimizer_D.step()
119 | return loss_D.item()
120 |
121 | def _get_optim(self, params):
122 | if self.config['optimizer']['name'] == 'adam':
123 | optimizer = optim.Adam(params, lr=self.config['optimizer']['lr'])
124 | elif self.config['optimizer']['name'] == 'sgd':
125 | optimizer = optim.SGD(params, lr=self.config['optimizer']['lr'])
126 | elif self.config['optimizer']['name'] == 'adadelta':
127 | optimizer = optim.Adadelta(params, lr=self.config['optimizer']['lr'])
128 | else:
129 | raise ValueError("Optimizer [%s] not recognized." % self.config['optimizer']['name'])
130 | return optimizer
131 |
132 | def _get_scheduler(self, optimizer):
133 | if self.config['scheduler']['name'] == 'plateau':
134 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
135 | mode='min',
136 | patience=self.config['scheduler']['patience'],
137 | factor=self.config['scheduler']['factor'],
138 | min_lr=self.config['scheduler']['min_lr'])
139 | elif self.config['scheduler']['name'] == 'sgdr':
140 | scheduler = WarmRestart(optimizer)
141 | elif self.config['scheduler']['name'] == 'linear':
142 | scheduler = LinearDecay(optimizer,
143 | min_lr=self.config['scheduler']['min_lr'],
144 | num_epochs=self.config['num_epochs'],
145 | start_epoch=self.config['scheduler']['start_epoch'])
146 | else:
147 | raise ValueError("Scheduler [%s] not recognized." % self.config['scheduler']['name'])
148 | return scheduler
149 |
150 | @staticmethod
151 | def _get_adversarial_trainer(d_name, net_d, criterion_d):
152 | if d_name == 'no_gan':
153 | return GANFactory.create_model('NoGAN')
154 | elif d_name == 'patch_gan' or d_name == 'multi_scale':
155 | return GANFactory.create_model('SingleGAN', net_d, criterion_d)
156 | elif d_name == 'double_gan':
157 | return GANFactory.create_model('DoubleGAN', net_d, criterion_d)
158 | else:
159 | raise ValueError("Discriminator Network [%s] not recognized." % d_name)
160 |
161 | def _init_params(self):
162 | self.criterionG, criterionD = get_loss(self.config['model'])
163 | self.netG, netD = get_nets(self.config['model'])
164 | self.netG.cuda()
165 | self.adv_trainer = self._get_adversarial_trainer(self.config['model']['d_name'], netD, criterionD)
166 | self.model = get_model(self.config['model'])
167 | self.optimizer_G = self._get_optim(filter(lambda p: p.requires_grad, self.netG.parameters()))
168 | self.optimizer_D = self._get_optim(self.adv_trainer.get_params())
169 | self.scheduler_G = self._get_scheduler(self.optimizer_G)
170 | self.scheduler_D = self._get_scheduler(self.optimizer_D)
171 |
172 |
173 | def main(config_path='config/config.yaml'):
174 | with open(config_path, 'r',encoding='utf-8') as f:
175 | config = yaml.load(f, Loader=yaml.SafeLoader)
176 |
177 | batch_size = config.pop('batch_size')
178 | # get_dataloader = partial(DataLoader,
179 | # batch_size=batch_size,
180 | # num_workers=0 if os.environ.get('DEBUG') else cpu_count(),
181 | # shuffle=True, drop_last=True)
182 | get_dataloader = partial(DataLoader,
183 | batch_size=batch_size,
184 | shuffle=True, drop_last=True)
185 |
186 | datasets = map(config.pop, ('train', 'val'))
187 | datasets = map(PairedDataset.from_config, datasets)
188 | train, val = map(get_dataloader, datasets)
189 | trainer = Trainer(config, train=train, val=val)
190 | trainer.train()
191 |
192 |
193 | if __name__ == '__main__':
194 | Fire(main)
195 |
--------------------------------------------------------------------------------
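
Since main above is wrapped with python-fire, training is normally started from the command line as python train.py (optionally passing --config_path). The equivalent programmatic call is sketched below, assuming config/config.yaml and the datasets it references are in place:

    from train import main

    main(config_path='config/config.yaml')  # builds the dataloaders and Trainer, then runs Trainer.train()
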
/examples/neural_networks/DeblurGANv2/train_end2end.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from functools import partial
3 |
4 | import cv2
5 | import torch
6 | import torch.optim as optim
7 | import tqdm
8 | import yaml
9 | from torch.utils.data import DataLoader
10 |
11 | from adversarial_trainer import GANFactory
12 | from dataset import PairedDataset
13 | from metric_counter import MetricCounter
14 | from models.losses import get_loss
15 | from models.models import get_model
16 | from models.networks import get_nets
17 | from schedulers import LinearDecay, WarmRestart
18 | from fire import Fire
19 |
20 | cv2.setNumThreads(0)
21 |
22 |
23 | class Trainer:
24 | def __init__(self, config, train: DataLoader, val: DataLoader):
25 | self.config = config
26 | self.train_dataset = train
27 | self.val_dataset = val
28 | self.adv_lambda = config['model']['adv_lambda']
29 | self.metric_counter = MetricCounter(config['experiment_desc'])
30 | self.warmup_epochs = config['warmup_num']
31 | self._init_params()
32 |
33 | def train(self):
34 | self._init_params()
35 | for epoch in range(0, self.config['num_epochs']):
36 | if (epoch == self.warmup_epochs) and not (self.warmup_epochs == 0):
37 | self.netG.module.unfreeze()
38 | self.optimizer_G = self._get_optim(self.netG.parameters())
39 | self.scheduler_G = self._get_scheduler(self.optimizer_G)
40 | self._run_epoch(epoch)
41 | self._validate(epoch)
42 | self.scheduler_G.step()
43 | self.scheduler_D.step()
44 |
45 | if self.metric_counter.update_best_model():
46 | torch.save({
47 | 'model': self.netG.state_dict()
48 | }, 'best_{}.h5'.format(self.config['experiment_desc']))
49 | torch.save({
50 | 'model': self.netG.state_dict()
51 | }, 'last_{}.h5'.format(self.config['experiment_desc']))
52 | print(self.metric_counter.loss_message())
53 | logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s" % (
54 | self.config['experiment_desc'], epoch, self.metric_counter.loss_message()))
55 |
56 | def prepare(self, epoch=0):
57 |         if epoch == self.warmup_epochs and self.warmup_epochs != 0:
58 | self.netG.module.unfreeze()
59 | self.optimizer_G = self._get_optim(self.netG.parameters())
60 | self.scheduler_G = self._get_scheduler(self.optimizer_G)
61 |
62 | def run(self, inputs, targets, is_inference=False, num_iters=1, desc=None):
63 | """
64 | Runs the network given an input blur image and a ground truth image.
65 |
66 | Customized for end2end training.
67 |
68 | Args:
69 | inputs: Input image(s).
70 | targets: Ground truth image(s).
71 |             is_inference: If True, run in inference mode without gradient updates.
72 | num_iters: Number of iterations on this batch of inputs and targets.
73 | desc: Description of the progress bar.
74 |
75 | Returns:
76 | outputs: Output image(s).
77 | """
78 |
79 | if is_inference:
80 | with torch.no_grad():
81 | outputs = self.netG(inputs)
82 | loss_content = self.criterionG(outputs, targets)
83 | loss_adv = self.adv_trainer.loss_g(outputs, targets)
84 | loss_G = loss_content + self.adv_lambda * loss_adv
85 |
86 | return outputs.detach()
87 |
88 | if desc is not None:
89 | tq = tqdm.tqdm(range(num_iters), total=num_iters)
90 | tq.set_description(desc)
91 | else:
92 | tq = range(num_iters)
93 |
94 | for it in tq:
95 | outputs = self.netG(inputs)
96 | loss_D = self._update_d(outputs, targets)
97 | self.optimizer_G.zero_grad()
98 | loss_content = self.criterionG(outputs, targets)
99 | loss_adv = self.adv_trainer.loss_g(outputs, targets)
100 | loss_G = loss_content + self.adv_lambda * loss_adv
101 | loss_G.backward()
102 | self.optimizer_G.step()
103 | self.metric_counter.add_losses(loss_G.item(), loss_content.item(), loss_D)
104 | curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs.detach(), outputs, targets)
105 | self.metric_counter.add_metrics(curr_psnr, curr_ssim)
106 |
107 | if desc is not None:
108 | tq.set_postfix(loss=self.metric_counter.loss_message())
109 |
110 | if desc is not None:
111 | tq.close()
112 |
113 | return outputs.detach()
114 |
115 | def _run_epoch(self, epoch):
116 | self.metric_counter.clear()
117 | for param_group in self.optimizer_G.param_groups:
118 | lr = param_group['lr']
119 |
120 | epoch_size = self.config.get('train_batches_per_epoch') or len(self.train_dataset)
121 | tq = tqdm.tqdm(self.train_dataset, total=epoch_size)
122 | tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
123 | i = 0
124 | for data in tq:
125 | inputs, targets = self.model.get_input(data)
126 | outputs = self.netG(inputs)
127 | loss_D = self._update_d(outputs, targets)
128 | self.optimizer_G.zero_grad()
129 | loss_content = self.criterionG(outputs, targets)
130 | loss_adv = self.adv_trainer.loss_g(outputs, targets)
131 | loss_G = loss_content + self.adv_lambda * loss_adv
132 | loss_G.backward()
133 | self.optimizer_G.step()
134 | self.metric_counter.add_losses(loss_G.item(), loss_content.item(), loss_D)
135 | curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
136 | self.metric_counter.add_metrics(curr_psnr, curr_ssim)
137 | tq.set_postfix(loss=self.metric_counter.loss_message())
138 | if not i:
139 | self.metric_counter.add_image(img_for_vis, tag='train')
140 | i += 1
141 | if i > epoch_size:
142 | break
143 | tq.close()
144 | self.metric_counter.write_to_tensorboard(epoch)
145 |
146 | def _validate(self, epoch):
147 | self.metric_counter.clear()
148 | epoch_size = self.config.get('val_batches_per_epoch') or len(self.val_dataset)
149 | tq = tqdm.tqdm(self.val_dataset, total=epoch_size)
150 | tq.set_description('Validation')
151 | i = 0
152 | for data in tq:
153 | inputs, targets = self.model.get_input(data)
154 | with torch.no_grad():
155 | outputs = self.netG(inputs)
156 | loss_content = self.criterionG(outputs, targets)
157 | loss_adv = self.adv_trainer.loss_g(outputs, targets)
158 | loss_G = loss_content + self.adv_lambda * loss_adv
159 | self.metric_counter.add_losses(loss_G.item(), loss_content.item())
160 | curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
161 | self.metric_counter.add_metrics(curr_psnr, curr_ssim)
162 | if not i:
163 | self.metric_counter.add_image(img_for_vis, tag='val')
164 | i += 1
165 | if i > epoch_size:
166 | break
167 | tq.close()
168 | self.metric_counter.write_to_tensorboard(epoch, validation=True)
169 |
170 | def _update_d(self, outputs, targets):
171 | if self.config['model']['d_name'] == 'no_gan':
172 | return 0
173 | self.optimizer_D.zero_grad()
174 | loss_D = self.adv_lambda * self.adv_trainer.loss_d(outputs, targets)
175 | loss_D.backward(retain_graph=True)
176 | self.optimizer_D.step()
177 | return loss_D.item()
178 |
179 | def _get_optim(self, params):
180 | if self.config['optimizer']['name'] == 'adam':
181 | optimizer = optim.Adam(params, lr=self.config['optimizer']['lr'])
182 | elif self.config['optimizer']['name'] == 'sgd':
183 | optimizer = optim.SGD(params, lr=self.config['optimizer']['lr'])
184 | elif self.config['optimizer']['name'] == 'adadelta':
185 | optimizer = optim.Adadelta(params, lr=self.config['optimizer']['lr'])
186 | else:
187 | raise ValueError("Optimizer [%s] not recognized." % self.config['optimizer']['name'])
188 | return optimizer
189 |
190 | def _get_scheduler(self, optimizer):
191 | if self.config['scheduler']['name'] == 'plateau':
192 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
193 | mode='min',
194 | patience=self.config['scheduler']['patience'],
195 | factor=self.config['scheduler']['factor'],
196 | min_lr=self.config['scheduler']['min_lr'])
197 |         elif self.config['scheduler']['name'] == 'sgdr':
198 | scheduler = WarmRestart(optimizer)
199 | elif self.config['scheduler']['name'] == 'linear':
200 | scheduler = LinearDecay(optimizer,
201 | min_lr=self.config['scheduler']['min_lr'],
202 | num_epochs=self.config['num_epochs'],
203 | start_epoch=self.config['scheduler']['start_epoch'])
204 | else:
205 | raise ValueError("Scheduler [%s] not recognized." % self.config['scheduler']['name'])
206 | return scheduler
207 |
208 | @staticmethod
209 | def _get_adversarial_trainer(d_name, net_d, criterion_d):
210 | if d_name == 'no_gan':
211 | return GANFactory.create_model('NoGAN')
212 | elif d_name == 'patch_gan' or d_name == 'multi_scale':
213 | return GANFactory.create_model('SingleGAN', net_d, criterion_d)
214 | elif d_name == 'double_gan':
215 | return GANFactory.create_model('DoubleGAN', net_d, criterion_d)
216 | else:
217 | raise ValueError("Discriminator Network [%s] not recognized." % d_name)
218 |
219 | def _init_params(self):
220 | self.criterionG, criterionD = get_loss(self.config['model'])
221 | self.netG, netD = get_nets(self.config['model'])
222 | self.netG.cuda()
223 | self.adv_trainer = self._get_adversarial_trainer(self.config['model']['d_name'], netD, criterionD)
224 | self.model = get_model(self.config['model'])
225 | self.optimizer_G = self._get_optim(filter(lambda p: p.requires_grad, self.netG.parameters()))
226 | self.optimizer_D = self._get_optim(self.adv_trainer.get_params())
227 | self.scheduler_G = self._get_scheduler(self.optimizer_G)
228 | self.scheduler_D = self._get_scheduler(self.optimizer_D)
229 |
230 |
231 | def load_from_config(config_path='config/config.yaml'):
232 | with open(config_path, 'r', encoding='utf-8') as f:
233 | config = yaml.load(f, Loader=yaml.SafeLoader)
234 |
235 | batch_size = config.pop('batch_size')
236 | # get_dataloader = partial(DataLoader,
237 | # batch_size=batch_size,
238 | # num_workers=0 if os.environ.get('DEBUG') else cpu_count(),
239 | # shuffle=True, drop_last=True)
240 | get_dataloader = partial(DataLoader,
241 | batch_size=batch_size,
242 | shuffle=True, drop_last=True)
243 |
244 | datasets = map(config.pop, ('train', 'val'))
245 | datasets = map(PairedDataset.from_config, datasets)
246 | train, val = map(get_dataloader, datasets)
247 | trainer = Trainer(config, train=train, val=val)
248 |
249 | return trainer
250 |
251 |
252 | def main():
253 | trainer = load_from_config()
254 | trainer.train()
255 |
256 |
257 | if __name__ == '__main__':
258 | Fire(main)
259 |
--------------------------------------------------------------------------------
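Note on usage: the run method of this Trainer is meant to be driven per batch from an external end-to-end loop (see examples/end2end_edof_backward_tracing.py), rather than only through _run_epoch. A minimal, hypothetical sketch is shown below; the tensor shapes are placeholders and it assumes the DeblurGANv2 directory is the working directory with config/config.yaml and its 'train'/'val' datasets available.

import torch
import train_end2end  # assumes the DeblurGANv2 directory is on sys.path / is the cwd

# Build a Trainer (model, losses, optimizers, schedulers, dataloaders) from the config.
trainer = train_end2end.load_from_config('config/config.yaml')
trainer.prepare(epoch=0)  # per-epoch hook; unfreezes netG once epoch reaches warmup_num

# Placeholder blurred/sharp pair; in the end-to-end examples these come from the
# differentiable optics simulation instead.
blurred = torch.rand(1, 3, 256, 256, device='cuda')
sharp = torch.rand(1, 3, 256, 256, device='cuda')

# One optimization pass on this batch; returns the detached restored image.
restored = trainer.run(blurred, sharp, is_inference=False, num_iters=1, desc='end2end batch')

# Inference only: gradients disabled, no optimizer steps.
restored_eval = trainer.run(blurred, sharp, is_inference=True)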
/examples/neural_networks/DeblurGANv2/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/neural_networks/DeblurGANv2/util/__init__.py
--------------------------------------------------------------------------------
/examples/neural_networks/DeblurGANv2/util/image_pool.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | from torch.autograd import Variable
5 | from collections import deque
6 |
7 |
8 | class ImagePool():
9 | def __init__(self, pool_size):
10 | self.pool_size = pool_size
11 | self.sample_size = pool_size
12 | if self.pool_size > 0:
13 | self.num_imgs = 0
14 | self.images = deque()
15 |
16 | def add(self, images):
17 | if self.pool_size == 0:
18 | return images
19 | for image in images.data:
20 | image = torch.unsqueeze(image, 0)
21 | if self.num_imgs < self.pool_size:
22 | self.num_imgs = self.num_imgs + 1
23 | self.images.append(image)
24 | else:
25 | self.images.popleft()
26 | self.images.append(image)
27 |
28 | def query(self):
29 | if len(self.images) > self.sample_size:
30 | return_images = list(random.sample(self.images, self.sample_size))
31 | else:
32 | return_images = list(self.images)
33 | return torch.cat(return_images, 0)
34 |
--------------------------------------------------------------------------------
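ImagePool keeps a bounded FIFO of past generator outputs so a discriminator can be trained against a history of fakes instead of only the latest batch. A small, assumed usage sketch (the import path and tensor shapes are illustrative, not repository code):

import torch
from util.image_pool import ImagePool  # assumes the DeblurGANv2 directory is the cwd

pool = ImagePool(pool_size=50)
fake_batch = torch.rand(4, 3, 64, 64)  # e.g. generator outputs
pool.add(fake_batch)                   # stores each image; the oldest is dropped when full
history = pool.query()                 # concatenation of the stored images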
/examples/neural_networks/DeblurGANv2/util/metrics.py:
--------------------------------------------------------------------------------
1 | import math
2 | from math import exp
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 |
9 |
10 | def gaussian(window_size, sigma):
11 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
12 | return gauss / gauss.sum()
13 |
14 |
15 | def create_window(window_size, channel):
16 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
17 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
18 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
19 | return window
20 |
21 |
22 | def SSIM(img1, img2):
23 | (_, channel, _, _) = img1.size()
24 | window_size = 11
25 | window = create_window(window_size, channel)
26 |
27 | if img1.is_cuda:
28 | window = window.cuda(img1.get_device())
29 | window = window.type_as(img1)
30 |
31 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
32 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
33 |
34 | mu1_sq = mu1.pow(2)
35 | mu2_sq = mu2.pow(2)
36 | mu1_mu2 = mu1 * mu2
37 |
38 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
39 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
40 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
41 |
42 | C1 = 0.01 ** 2
43 | C2 = 0.03 ** 2
44 |
45 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
46 | return ssim_map.mean()
47 |
48 |
49 | def PSNR(img1, img2):
50 | mse = np.mean((img1 / 255. - img2 / 255.) ** 2)
51 | if mse == 0:
52 | return 100
53 | PIXEL_MAX = 1
54 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
55 |
--------------------------------------------------------------------------------
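For reference, the code above evaluates the standard SSIM and PSNR definitions, with C1 = 0.01^2 and C2 = 0.03^2 as in the code, Gaussian-windowed local means mu and (co)variances sigma, and MAX the peak signal value (1 after the /255 normalization used in PSNR):

\mathrm{SSIM}(x,y) = \frac{(2\mu_x\mu_y + C_1)(2\sigma_{xy} + C_2)}{(\mu_x^2 + \mu_y^2 + C_1)(\sigma_x^2 + \sigma_y^2 + C_2)},
\qquad
\mathrm{PSNR} = 20\,\log_{10}\!\left(\frac{\mathrm{MAX}}{\sqrt{\mathrm{MSE}}}\right)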
/examples/nikon.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import matplotlib.pyplot as plt
4 | from pathlib import Path
5 |
6 | import sys
7 | sys.path.append("../")
8 | import diffoptics as do
9 | from datetime import datetime
10 |
11 | """
12 | DISCLAIMER:
13 |
14 | This script was used to generate Figure 11 in the paper. However, the results in the paper
15 | mistakenly assumed an air refractive index of 1 rather than the correct 1.000293, so this
16 | script will produce slightly different results from those in the paper.
17 |
18 | If you want to reproduce the result in the paper, change the following item in
19 |
20 | diffoptics.basics.Material.__init__.MATERIAL_TABLE
21 |
22 | from:
23 |
24 | "air": [1.000293, np.inf],
25 |
26 | to:
27 |
28 | "air": [1.000000, np.inf],
29 |
30 | And re-run this script.
31 | """
32 |
33 |
34 | # initialize a lens
35 | device = torch.device('cpu')
36 | # device = torch.device('cuda')
37 | lens = do.Lensgroup(device=device)
38 |
39 | # load optics
40 | lens.load_file(Path('./lenses/Zemax_samples/Nikon-z35-f1.8-JPA2019090949-example2.txt'))
41 | # lens.plot_setup2D()
42 |
43 | # sample wavelengths in [nm]
44 | wavelengths = torch.Tensor([656.2725, 587.5618, 486.1327]).to(device)
45 | views = np.array([0, 10, 20, 32.45])
46 | colors_list = 'bgry'
47 |
48 | def plot_layout(string):
49 | ax, fig = lens.plot_setup2D_with_trace(views, wavelengths[1], M=5, entrance_pupil=True)
50 | ax.axis('off')
51 | ax.set_title("")
52 | fig.savefig("layout_trace_" + string + "_" + datetime.now().strftime('%Y%m%d-%H%M%S-%f') + ".pdf", bbox_inches='tight')
53 |
54 | M = 31
55 | def render(verbose=False, entrance_pupil=False):
56 | def render_single(wavelength):
57 | pss = []
58 | spot_rms = []
59 | loss = 0.0
60 | for view in views:
61 | ray = lens.sample_ray(wavelength, view=view, M=M, sampling='grid', entrance_pupil=entrance_pupil)
62 | ps = lens.trace_to_sensor(ray, ignore_invalid=True)
63 |
64 | # calculate RMS
65 | tmp, ps = lens.rms(ps[...,:2], squared=True)
66 | loss = loss + tmp
67 | pss.append(ps)
68 | spot_rms.append(np.sqrt(tmp.item()))
69 | return pss, loss, np.array(spot_rms)
70 |
71 | pss_all = []
72 | rms_all = []
73 | loss = 0.0
74 | for wavelength in wavelengths:
75 | if verbose:
76 | print("Rendering wavelength = {} [nm] ...".format(wavelength.item()))
77 | pss, loss_single, rmss = render_single(wavelength)
78 | loss = loss + loss_single
79 | pss_all.append(pss)
80 | rms_all.append(rmss)
81 | return pss_all, loss, np.array(rms_all)
82 |
83 | def func():
84 | ps = render()[0]
85 | return torch.vstack([torch.vstack(ps[i]) for i in range(len(ps))])
86 |
87 | def loss_func():
88 | return render()[1]
89 |
90 | def info(string):
91 | loss, rms = render()[1:]
92 | print("=== {} ===".format(string))
93 | print("loss = {}".format(loss))
94 | print("==========")
95 | plot_layout(string)
96 | return rms
97 |
98 | rms_org = info('original')
99 | print(rms_org.mean())
100 |
101 | id_range = list(range(0, 19))
102 | id_range.pop(lens.aperture_ind)
103 | id_asphere = [16, 17]
104 | for i in id_asphere:
105 | lens.surfaces[i].ai = torch.Tensor([0.0]).to(device)
106 |
107 | diff_names = []
108 | diff_names += ['surfaces[{}].c'.format(str(i)) for i in id_range]
109 | diff_names += ['surfaces[{}].k'.format(str(i)) for i in id_asphere]
110 | diff_names += ['surfaces[{}].ai'.format(str(i)) for i in id_asphere]
111 |
112 | rms_init = info('initial')
113 | print(rms_init.mean())
114 |
115 | # optimize
116 | start = torch.cuda.Event(enable_timing=True)
117 | end = torch.cuda.Event(enable_timing=True)
118 |
119 | start.record()
120 | out = do.LM(lens, diff_names, 1e-2, option='diag') \
121 | .optimize(func, lambda y: 0.0 - y, maxit=100, record=True)
122 | end.record()
123 | torch.cuda.synchronize()
124 | print('Finished in {:.2f} mins'.format(start.elapsed_time(end)/1000/60))
125 |
126 | rms_opt = info('optimized')
127 | print(rms_opt.mean())
128 |
129 | # plot loss
130 | fig, ax = plt.subplots(figsize=(12,6))
131 | ax.semilogy(out['ls'], 'k-o', linewidth=3)
132 | plt.xlabel('iteration')
133 | plt.ylabel('error function')
134 | plt.savefig("./ls_nikon.pdf", bbox_inches='tight')
135 |
136 | def save_fig(xs, string):
137 | fig, ax = plt.subplots(figsize=(3,1.5))
138 | xs = xs.T
139 | for i, x in enumerate(xs):
140 | ax.semilogy(x, colors_list[i], marker='o', linewidth=1)
141 | plt.xlabel('wavelength [nm]')
142 | plt.ylabel('RMS spot size [um]')
143 | plt.ylim([0.08, 50])
144 | plt.xticks([0,1,2], ['656.27', '587.56', '486.13'])
145 | plt.yticks([0.1,1,10,50], ['0.1', '1', '10', '50'])
146 | fig.savefig("./rms_" + string + "_nikon.pdf", bbox_inches='tight')
147 |
148 | save_fig(rms_init * 1e3, "init")
149 | save_fig(rms_org * 1e3, "org")
150 | save_fig(rms_opt * 1e3, "opt")
151 |
--------------------------------------------------------------------------------
/examples/render_image.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import torch
4 | import matplotlib.pyplot as plt
5 | from pathlib import Path
6 | from tqdm import tqdm
7 |
8 | import sys
9 | sys.path.append("../")
10 | import diffoptics as do
11 |
12 | # initialize a lens
13 | device = torch.device('cuda')
14 | lens = do.Lensgroup(device=device)
15 |
16 | # load optics
17 | lens.load_file(Path('./lenses/DoubleGauss/US02532751-1.txt'))
18 |
19 | # set sensor pixel size and film size
20 | pixel_size = 6.45e-3 # [mm]
21 | film_size = [768, 1024]
22 |
23 | # set a rendering image sensor, and call prepare_mts to prepare the lensgroup for rendering
24 | lens.prepare_mts(pixel_size, film_size)
25 | # lens.plot_setup2D()
26 |
27 | # create a dummy screen
28 | z0 = 10e3 # [mm]
29 | pixelsize = 1.1 # [mm]
30 | texture = cv2.cvtColor(cv2.imread('./images/squirrel.jpg'), cv2.COLOR_BGR2RGB)
31 | texture = np.flip(texture.astype(np.float32), axis=(0,1)).copy()
32 | texture_torch = torch.Tensor(texture).to(device=device)
33 | texturesize = np.array(texture.shape[0:2])
34 | screen = do.Screen(
35 | do.Transformation(np.eye(3), np.array([0, 0, z0])),
36 | texturesize * pixelsize, texture_torch, device=device
37 | )
38 |
39 | # helper function
40 | def render_single(wavelength, screen):
41 | valid, ray_new = lens.sample_ray_sensor(wavelength)
42 | uv, valid_screen = screen.intersect(ray_new)[1:]
43 | mask = valid & valid_screen
44 | I = screen.shading(uv, mask)
45 | return I, mask
46 |
47 | # sample wavelengths in [nm]
48 | wavelengths = [656.2725, 587.5618, 486.1327]
49 |
50 | # render
51 | ray_counts_per_pixel = 100
52 | Is = []
53 | for wavelength_id, wavelength in enumerate(wavelengths):
54 | screen.update_texture(texture_torch[..., wavelength_id])
55 |
56 | # multi-pass rendering by sampling the aperture
57 | I = 0
58 | M = 0
59 | for i in tqdm(range(ray_counts_per_pixel)):
60 | I_current, mask = render_single(wavelength, screen)
61 | I = I + I_current
62 | M = M + mask
63 | I = I / (M + 1e-10)
64 |
65 | # reshape data to a 2D image
66 | I = I.reshape(*np.flip(np.asarray(film_size))).permute(1,0)
67 | Is.append(I.cpu())
68 |
69 | # show image
70 | I_rendered = torch.stack(Is, axis=-1).numpy().astype(np.uint8)
71 | plt.imshow(I_rendered)
72 | plt.show()
73 | plt.imsave('I_rendered.png', I_rendered)
74 |
--------------------------------------------------------------------------------
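The multi-pass loop above forms each pixel value as the ratio of the accumulated shading to the per-pixel hit count, i.e. with K = ray_counts_per_pixel passes and epsilon = 1e-10 guarding pixels that receive no rays:

I_{\text{pixel}} = \frac{\sum_{k=1}^{K} I_k}{\sum_{k=1}^{K} M_k + \epsilon}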
/examples/render_psf.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import matplotlib.pyplot as plt
4 | from pathlib import Path
5 | from tqdm import tqdm
6 |
7 | import sys
8 | sys.path.append("../")
9 | import diffoptics as do
10 |
11 | # initialize a lens
12 | device = torch.device('cuda')
13 | lens = do.Lensgroup(device=device)
14 |
15 | # load optics
16 | lens.load_file(Path('./lenses/DoubleGauss/US02532751-1.txt'))
17 |
18 | # sensor area
19 | pixel_size = 6.45e-3 # [mm]
20 | film_size = torch.tensor([1200, 1600], device=device)
21 | R_square = film_size * pixel_size
22 |
23 | # generate array of rays
24 | wavelength = 532.8 # [nm]
25 | R = 5.0 # [mm]
26 | # lens.plot_setup2D()
27 |
28 | def render_psf(I, p):
29 | # compute shifts and do linear interpolation
30 | uv = (p + R_square/2) / pixel_size
31 | index_l = torch.vstack((
32 |         torch.clamp(torch.floor(uv[:,0]).long(), min=0, max=film_size[0]-1),
33 |         torch.clamp(torch.floor(uv[:,1]).long(), min=0, max=film_size[1]-1))
34 |     ).T
35 |     index_r = torch.vstack((
36 |         torch.clamp(index_l[:,0] + 1, min=0, max=film_size[0]-1),
37 |         torch.clamp(index_l[:,1] + 1, min=0, max=film_size[1]-1))
38 | ).T
39 | w_r = torch.clamp(uv - index_l, min=0, max=1)
40 | w_l = 1.0 - w_r
41 | del uv
42 |
43 | # compute image
44 | I = torch.index_put(I, (index_l[...,0],index_l[...,1]), w_l[...,0]*w_l[...,1], accumulate=True)
45 | I = torch.index_put(I, (index_r[...,0],index_l[...,1]), w_r[...,0]*w_l[...,1], accumulate=True)
46 | I = torch.index_put(I, (index_l[...,0],index_r[...,1]), w_l[...,0]*w_r[...,1], accumulate=True)
47 | I = torch.index_put(I, (index_r[...,0],index_r[...,1]), w_r[...,0]*w_r[...,1], accumulate=True)
48 | return I
49 |
50 | def generate_surface_samples(M):
51 | Dx = np.random.rand(M,M)
52 | Dy = np.random.rand(M,M)
53 | [px, py] = do.Sampler().concentric_sample_disk(Dx, Dy)
54 | return np.stack((px.flatten(), py.flatten()), axis=1)
55 |
56 | def sample_ray(o_obj, M):
57 | p_aperture_2d = R * generate_surface_samples(M)
58 | N = p_aperture_2d.shape[0]
59 | p_aperture = np.hstack((p_aperture_2d, np.zeros((N,1)))).reshape((N,3))
60 | o = np.ones(N)[:, None] * o_obj[None, :]
61 |
62 | o = o.astype(np.float32)
63 | p_aperture = p_aperture.astype(np.float32)
64 |
65 | d = do.normalize(torch.from_numpy(p_aperture - o))
66 |
67 | o = torch.from_numpy(o).to(lens.device)
68 | d = d.to(lens.device)
69 |
70 | return do.Ray(o, d, wavelength, device=lens.device)
71 |
72 | def render(o_obj, M, rep_count):
73 | I = torch.zeros(*film_size, device=device)
74 | for i in range(rep_count):
75 | rays = sample_ray(o_obj, M)
76 | ps = lens.trace_to_sensor(rays, ignore_invalid=True)
77 | I = render_psf(I, ps[..., :2])
78 | return I / rep_count
79 |
80 | # PSF rendering parameters
81 | x_max_halfangle = 10 # [deg]
82 | y_max_halfangle = 7.5 # [deg]
83 | Nx = 2 * 8 + 1
84 | Ny = 2 * 6 + 1
85 |
86 | # sampling parameters
87 | M = 1001
88 | rep_count = 1
89 |
90 | def render_at_depth(z):
91 | x_halfmax = np.abs(z) * np.tan(np.deg2rad(x_max_halfangle))
92 | y_halfmax = np.abs(z) * np.tan(np.deg2rad(y_max_halfangle))
93 |
94 | I_psf_all = torch.zeros(*film_size, device=device)
95 | for x in tqdm(np.linspace(-x_halfmax, x_halfmax, Nx)):
96 | for y in np.linspace(-y_halfmax, y_halfmax, Ny):
97 | o_obj = np.array([y, x, z])
98 | I_psf = render(o_obj, M, rep_count)
99 | I_psf_all = I_psf_all + I_psf
100 | return I_psf_all
101 |
102 | # render PSF at different depths
103 | zs = [-1e4, -7e3, -5e3, -3e3, -2e3, -1.5e3, -1e3]
104 | savedir = Path('./rendered_psfs')
105 | savedir.mkdir(exist_ok=True, parents=True)
106 | I_psfs = []
107 | for z in zs:
108 | I_psf = render_at_depth(z)
109 | I_psf = I_psf.cpu().numpy()
110 | plt.imsave(str(savedir / 'I_psf_z={}.png'.format(z)), np.uint8(255 * I_psf / I_psf.max()), cmap='hot')
111 | I_psfs.append(I_psf)
112 |
--------------------------------------------------------------------------------
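The render_psf routine above splats each traced sensor intersection onto its four neighbouring pixels with bilinear weights. For a continuous pixel coordinate (u, v) with integer floor (i, j) and fractional parts a = u - i, b = v - j (w_r and w_l in the code), the accumulated weights are:

w_{i,j} = (1-a)(1-b), \quad w_{i+1,j} = a(1-b), \quad w_{i,j+1} = (1-a)\,b, \quad w_{i+1,j+1} = a\,b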
/examples/sanity_check.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import matplotlib.pyplot as plt
5 | from pathlib import Path
6 |
7 | import sys
8 | sys.path.append("../")
9 | import diffoptics as do
10 |
11 | # initialize a lens
12 | device = torch.device('cuda')
13 | lens = do.Lensgroup(device=device)
14 |
15 | # load optics
16 | lens.load_file(Path('./lenses/DoubleGauss/US02532751-1.txt'))
17 |
18 | # sample wavelengths in [nm]
19 | wavelengths = torch.Tensor([656.2725, 587.5618, 486.1327]).to(device)
20 |
21 | colors_list = 'bgry'
22 | views = np.linspace(0, 21, 4, endpoint=True)
23 | ax, fig = lens.plot_setup2D_with_trace(views, wavelengths[1], M=4)
24 | ax.axis('off')
25 | ax.set_title('Sanity Check Setup 2D')
26 | fig.savefig('sanity_check_setup.pdf')
27 |
28 | # spot diagrams
29 | spot_rmss = []
30 | valid_maps = []
31 | for i, view in enumerate(views):
32 | ray = lens.sample_ray(wavelengths[1], view=view, M=31, sampling='grid', entrance_pupil=True)
33 | ps = lens.trace_to_sensor(ray, ignore_invalid=True)
34 | lim = 20e-3
35 | lens.spot_diagram(
36 | ps[...,:2], show=True, xlims=[-lim, lim], ylims=[-lim, lim], color=colors_list[i]+'.',
37 | savepath='sanity_check_field_view_{}.png'.format(int(view))
38 | )
39 |
40 | spot_rmss.append(lens.rms(ps))
41 |
42 | plt.show()
43 |
--------------------------------------------------------------------------------
/examples/spherical_aberration.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import matplotlib.pyplot as plt
3 | from pathlib import Path
4 |
5 | import sys
6 | sys.path.append("../")
7 | import diffoptics as do
8 |
9 | # initialization
10 | # device = do.init()
11 | device = torch.device('cpu')
12 |
13 | # load target lens
14 | lens = do.Lensgroup(device=device)
15 | lens.load_file(Path('./lenses/ThorLabs/ACL5040U.txt'))
16 | print(lens.surfaces[0])
17 |
18 | # generate array of rays
19 | wavelength = torch.Tensor([532.8]).to(device) # [nm]
20 | R = 15.0 # [mm]
21 | def render():
22 | ray_init = lens.sample_ray(wavelength, M=31, R=R)
23 | ps = lens.trace_to_sensor(ray_init)
24 | return ps[...,:2]
25 |
26 | def trace_all():
27 | ray_init = lens.sample_ray_2D(R, wavelength, M=15)
28 | ps, oss = lens.trace_to_sensor_r(ray_init)
29 | return ps[...,:2], oss
30 | ps, oss = trace_all()
31 | ax, fig = lens.plot_raytraces(oss)
32 |
33 | ax, fig = lens.plot_setup2D_with_trace([0.0], wavelength, M=5, R=R)
34 | ax.axis('off')
35 | ax.set_title("")
36 | fig.savefig("layout_trace_asphere.pdf", bbox_inches='tight')
37 |
38 | # show initial RMS
39 | ps_org = render()
40 | L_org = torch.mean(torch.sum(torch.square(ps_org), axis=-1))
41 | print('original loss: {:.3e}'.format(L_org))
42 | lens.spot_diagram(ps_org, xlims=[-50.0e-3, 50.0e-3], ylims=[-50.0e-3, 50.0e-3])
43 |
44 | diff_names = [
45 | 'surfaces[0].c',
46 | 'surfaces[0].k',
47 | 'surfaces[0].ai'
48 | ]
49 |
50 | # optimize
51 | out = do.LM(lens, diff_names, 1e-4, option='diag') \
52 | .optimize(render, lambda y: 0.0 - y, maxit=300, record=True)
53 |
54 | # show loss
55 | plt.figure()
56 | plt.semilogy(out['ls'], '-o')
57 | plt.xlabel('Iteration')
58 | plt.ylabel('Loss')
59 | plt.show()
60 |
61 | # show spot diagram
62 | ps = render()
63 | L = torch.mean(torch.sum(torch.square(ps), axis=-1))
64 | print('final loss: {:.3e}'.format(L))
65 | lens.spot_diagram(ps, xlims=[-50.0e-3, 50.0e-3], ylims=[-50.0e-3, 50.0e-3])
66 | print(lens.surfaces[0])
67 | # lens.plot_setup2D()
68 |
--------------------------------------------------------------------------------
/examples/training_dataset/0008.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/training_dataset/0008.png
--------------------------------------------------------------------------------
/examples/training_dataset/0010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/training_dataset/0010.png
--------------------------------------------------------------------------------
/examples/training_dataset/0023.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/training_dataset/0023.png
--------------------------------------------------------------------------------
/examples/training_dataset/0030.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/training_dataset/0030.png
--------------------------------------------------------------------------------
/examples/training_dataset/0031.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/training_dataset/0031.png
--------------------------------------------------------------------------------
/examples/training_dataset/0032.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/training_dataset/0032.png
--------------------------------------------------------------------------------
/examples/training_dataset/0115.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/training_dataset/0115.png
--------------------------------------------------------------------------------
/examples/training_dataset/0267.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/examples/training_dataset/0267.png
--------------------------------------------------------------------------------
/examples/utils_end2end.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 | import torch
5 | import cv2
6 | import pathlib
7 |
8 |
9 | def dict_to_tensor(xs):
10 | """
11 |     Packs a list of differentiable parameter tensors into a single flat Tensor.
12 | """
13 | return torch.cat([x.view(-1) for x in xs], 0)
14 |
15 | def tensor_to_dict(xs, diff_parameters):
16 | """
17 |     Unpacks a flat Tensor into a list of slices, one per differentiable parameter.
18 | """
19 | ys = []
20 | idx = 0
21 | for diff_para in diff_parameters:
22 | y = xs[idx:idx+diff_para.numel()]
23 | idx += diff_para.numel()
24 | ys.append(y)
25 | return ys
26 |
27 | def load_image(image_name):
28 | img = cv2.cvtColor(cv2.imread(image_name), cv2.COLOR_BGR2RGB)
29 | img = np.flip(img.astype(np.float32), axis=(0,1)).copy()
30 | return img
31 |
32 |
33 | class ImageFolder(torch.utils.data.Dataset):
34 | def __init__(self, root_path):
35 | super(ImageFolder, self).__init__()
36 | self.file_names = sorted(os.listdir(root_path))
37 | self.root_path = pathlib.Path(root_path)
38 |
39 | def __len__(self):
40 | return len(self.file_names)
41 |
42 | def __getitem__(self, index):
43 | return load_image(str(self.root_path / self.file_names[index]))
44 |
45 |
46 | def load_deblurganv2():
47 | """
48 | Loads the DeblurGANv2 as the neural network backend.
49 | """
50 | neural_network_path = pathlib.Path('./neural_networks/DeblurGANv2')
51 | sys.path.append(str(neural_network_path))
52 |
53 | import train_end2end
54 |
55 | trainer = train_end2end.load_from_config(str(neural_network_path / 'config/config.yaml'))
56 |
57 | return trainer
58 |
--------------------------------------------------------------------------------
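The two helpers above pack a list of parameter tensors into one flat vector and slice it back; note that tensor_to_dict returns flat slices, so the original shapes must be restored with view_as. A minimal round-trip sketch (assumed usage, not repository code; it presumes examples/ is the working directory):

import torch
from utils_end2end import dict_to_tensor, tensor_to_dict

params = [torch.zeros(2, 3), torch.ones(4)]    # differentiable parameters
flat = dict_to_tensor(params)                  # 1D tensor of length 10
slices = tensor_to_dict(flat, params)          # one flat slice per parameter
restored = [s.view_as(p) for s, p in zip(slices, params)]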
/imgs/abp.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/abp.jpg
--------------------------------------------------------------------------------
/imgs/applications.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/applications.jpg
--------------------------------------------------------------------------------
/imgs/bp_abp_comp.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/bp_abp_comp.jpg
--------------------------------------------------------------------------------
/imgs/examples/I.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/I.jpg
--------------------------------------------------------------------------------
/imgs/examples/I0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/I0.jpg
--------------------------------------------------------------------------------
/imgs/examples/I_final.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/I_final.png
--------------------------------------------------------------------------------
/imgs/examples/I_psf_z=-1000.0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/I_psf_z=-1000.0.png
--------------------------------------------------------------------------------
/imgs/examples/I_psf_z=-10000.0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/I_psf_z=-10000.0.png
--------------------------------------------------------------------------------
/imgs/examples/I_psf_z=-1500.0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/I_psf_z=-1500.0.png
--------------------------------------------------------------------------------
/imgs/examples/I_psf_z=-2000.0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/I_psf_z=-2000.0.png
--------------------------------------------------------------------------------
/imgs/examples/I_psf_z=-3000.0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/I_psf_z=-3000.0.png
--------------------------------------------------------------------------------
/imgs/examples/I_rendered.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/I_rendered.jpg
--------------------------------------------------------------------------------
/imgs/examples/I_target.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/I_target.png
--------------------------------------------------------------------------------
/imgs/examples/iter_1_z=6000.0mm_images.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/iter_1_z=6000.0mm_images.png
--------------------------------------------------------------------------------
/imgs/examples/optimized.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/optimized.gif
--------------------------------------------------------------------------------
/imgs/examples/optimized.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/optimized.mp4
--------------------------------------------------------------------------------
/imgs/examples/phase.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/phase.png
--------------------------------------------------------------------------------
/imgs/examples/sanity_check_dO.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/sanity_check_dO.jpg
--------------------------------------------------------------------------------
/imgs/examples/sanity_check_zemax.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/examples/sanity_check_zemax.jpg
--------------------------------------------------------------------------------
/imgs/memory_comp.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/memory_comp.jpg
--------------------------------------------------------------------------------
/imgs/overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vccimaging/DiffOptics/f62dc49aa45c5ea3f8165634622392fbe44e6448/imgs/overview.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.5.3
2 | mitsuba==3.0.2
3 | numpy==1.21.6
4 | scipy==1.9.1
5 | torch==1.12.1+cu113
6 |
--------------------------------------------------------------------------------