├── .ipynb_checkpoints
│   └── NerfppVisualizationDash-checkpoint.ipynb
├── NerfppVisualizationDash.ipynb
├── README.md
├── __pycache__
│   ├── drawing_tools.cpython-38.pyc
│   └── nerfplusplus_tools.cpython-38.pyc
├── app.py
├── drawing_tools.py
├── example.gif
└── nerfplusplus_tools.py
/.ipynb_checkpoints/NerfppVisualizationDash-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "a6fca250",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import plotly.graph_objects as go\n",
11 | "import numpy as np\n",
12 | "import torch\n",
13 | "import plotly.express as px"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 2,
19 | "id": "de77ce6c",
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "HUGE_NUMBER = 1e10\n",
24 | "TINY_NUMBER = 1e-6 "
25 | ]
26 | },
27 | {
28 | "cell_type": "markdown",
29 | "id": "541c3686",
30 | "metadata": {},
31 | "source": [
32 | "# Ray Casting"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 3,
38 | "id": "27ee36ff",
39 | "metadata": {},
40 | "outputs": [],
41 | "source": [
42 | "# from autonomousvision/giraffe\n",
43 | "\n",
44 | "# 0.\n",
45 | "def to_pytorch(tensor, return_type=False):\n",
46 | " ''' Converts input tensor to pytorch.\n",
47 | " Args:\n",
48 | " tensor (tensor): Numpy or Pytorch tensor\n",
49 | " return_type (bool): whether to return input type\n",
50 | " '''\n",
51 | " is_numpy = False\n",
52 | " if type(tensor) == np.ndarray:\n",
53 | " tensor = torch.from_numpy(tensor).float()\n",
54 | " is_numpy = True\n",
55 | " tensor = tensor.clone()\n",
56 | " if return_type:\n",
57 | " return tensor, is_numpy\n",
58 | " return tensor\n",
59 | "\n",
60 | "# 1. get camera intrinsic\n",
61 | "def get_camera_mat(fov=49.13, invert=True):\n",
62 | " # fov = 2 * arctan( sensor / (2 * focal))\n",
63 | " # focal = (sensor / 2) * 1 / (tan(0.5 * fov))\n",
64 | " # in our case, sensor = 2 as pixels are in [-1, 1]\n",
65 | " focal = 1. / np.tan(0.5 * fov * np.pi/180.)\n",
66 | " focal = focal.astype(np.float32)\n",
67 | " mat = torch.tensor([\n",
68 | " [focal, 0., 0., 0.],\n",
69 | " [0., focal, 0., 0.],\n",
70 | " [0., 0., 1, 0.],\n",
71 | " [0., 0., 0., 1.]\n",
72 | " ]).reshape(1, 4, 4)\n",
73 | "\n",
74 | " if invert:\n",
75 | " mat = torch.inverse(mat)\n",
76 | " return mat\n",
77 | "\n",
78 | "# 2. get camera position with camera pose (theta & phi)\n",
79 | "def to_sphere(u, v):\n",
80 | " theta = 2 * np.pi * u\n",
81 | " phi = np.arccos(1 - 2 * v)\n",
82 | " cx = np.sin(phi) * np.cos(theta)\n",
83 | " cy = np.sin(phi) * np.sin(theta)\n",
84 | " cz = np.cos(phi)\n",
85 | " return np.stack([cx, cy, cz], axis=-1)\n",
86 | "\n",
87 | "# 3. get camera coordinate system assuming it points to the center of the sphere\n",
88 | "def look_at(eye, at=np.array([0, 0, 0]), up=np.array([0, 0, 1]), eps=1e-5,\n",
89 | " to_pytorch=True):\n",
90 | " at = at.astype(float).reshape(1, 3)\n",
91 | " up = up.astype(float).reshape(1, 3)\n",
92 | " eye = eye.reshape(-1, 3)\n",
93 | " up = up.repeat(eye.shape[0] // up.shape[0], axis=0)\n",
94 | " eps = np.array([eps]).reshape(1, 1).repeat(up.shape[0], axis=0)\n",
95 | "\n",
96 | " z_axis = eye - at\n",
97 | " z_axis /= np.max(np.stack([np.linalg.norm(z_axis,\n",
98 | " axis=1, keepdims=True), eps]))\n",
99 | "\n",
100 | " x_axis = np.cross(up, z_axis)\n",
101 | " x_axis /= np.max(np.stack([np.linalg.norm(x_axis,\n",
102 | " axis=1, keepdims=True), eps]))\n",
103 | "\n",
104 | " y_axis = np.cross(z_axis, x_axis)\n",
105 | " y_axis /= np.max(np.stack([np.linalg.norm(y_axis,\n",
106 | " axis=1, keepdims=True), eps]))\n",
107 | "\n",
108 | " r_mat = np.concatenate(\n",
109 | " (x_axis.reshape(-1, 3, 1), y_axis.reshape(-1, 3, 1), z_axis.reshape(\n",
110 | " -1, 3, 1)), axis=2)\n",
111 | "\n",
112 | " if to_pytorch:\n",
113 | " r_mat = torch.tensor(r_mat).float()\n",
114 | "\n",
115 | " return r_mat\n",
116 | "\n",
117 | "# 5. arange 2d array of pixel coordinate and give depth of 1\n",
118 | "def arange_pixels(resolution=(128, 128), batch_size=1, image_range=(-1., 1.),\n",
119 | " subsample_to=None, invert_y_axis=False):\n",
120 | " ''' Arranges pixels for given resolution in range image_range.\n",
121 | " The function returns the unscaled pixel locations as integers and the\n",
122 | " scaled float values.\n",
123 | " Args:\n",
124 | " resolution (tuple): image resolution\n",
125 | " batch_size (int): batch size\n",
126 | " image_range (tuple): range of output points (default [-1, 1])\n",
127 | " subsample_to (int): if integer and > 0, the points are randomly\n",
128 | " subsampled to this value\n",
129 | " '''\n",
130 | " h, w = resolution\n",
131 | " n_points = resolution[0] * resolution[1]\n",
132 | "\n",
133 | " # Arrange pixel location in scale resolution\n",
134 | " pixel_locations = torch.meshgrid(torch.arange(0, w), torch.arange(0, h))\n",
135 | " pixel_locations = torch.stack(\n",
136 | " [pixel_locations[0], pixel_locations[1]],\n",
137 | " dim=-1).long().view(1, -1, 2).repeat(batch_size, 1, 1)\n",
138 | " pixel_scaled = pixel_locations.clone().float()\n",
139 | "\n",
140 | " # Shift and scale points to match image_range\n",
141 | " scale = (image_range[1] - image_range[0])\n",
142 | " loc = scale / 2\n",
143 | " pixel_scaled[:, :, 0] = scale * pixel_scaled[:, :, 0] / (w - 1) - loc\n",
144 | " pixel_scaled[:, :, 1] = scale * pixel_scaled[:, :, 1] / (h - 1) - loc\n",
145 | "\n",
146 | " # Subsample points if subsample_to is not None and > 0\n",
147 | " if (subsample_to is not None and subsample_to > 0 and\n",
148 | " subsample_to < n_points):\n",
149 | " idx = np.random.choice(pixel_scaled.shape[1], size=(subsample_to,),\n",
150 | " replace=False)\n",
151 | " pixel_scaled = pixel_scaled[:, idx]\n",
152 | " pixel_locations = pixel_locations[:, idx]\n",
153 | "\n",
154 | " if invert_y_axis:\n",
155 | " assert(image_range == (-1, 1))\n",
156 | " pixel_scaled[..., -1] *= -1.\n",
157 | " pixel_locations[..., -1] = (h - 1) - pixel_locations[..., -1]\n",
158 | "\n",
159 | " return pixel_locations, pixel_scaled\n",
160 | "\n",
161 | "# 6. mat_mul with intrinsic and then extrinsic gives you p_world (pixels in world) \n",
162 | "def image_points_to_world(image_points, camera_mat, world_mat, scale_mat=None,\n",
163 | " invert=False, negative_depth=True):\n",
164 | " ''' Transforms points on image plane to world coordinates.\n",
165 | " In contrast to transform_to_world, no depth value is needed as points on\n",
166 | " the image plane have a fixed depth of 1.\n",
167 | " Args:\n",
168 | " image_points (tensor): image points tensor of size B x N x 2\n",
169 | " camera_mat (tensor): camera matrix\n",
170 | " world_mat (tensor): world matrix\n",
171 | " scale_mat (tensor): scale matrix\n",
172 | " invert (bool): whether to invert matrices (default: False)\n",
173 | " '''\n",
174 | " batch_size, n_pts, dim = image_points.shape\n",
175 | " assert(dim == 2)\n",
176 | " d_image = torch.ones(batch_size, n_pts, 1)\n",
177 | " if negative_depth:\n",
178 | " d_image *= -1.\n",
179 | " return transform_to_world(image_points, d_image, camera_mat, world_mat,\n",
180 | " scale_mat, invert=invert)\n",
181 | "\n",
182 | "def transform_to_world(pixels, depth, camera_mat, world_mat, scale_mat=None,\n",
183 | " invert=True, use_absolute_depth=True):\n",
184 | " ''' Transforms pixel positions p with given depth value d to world coordinates.\n",
185 | " Args:\n",
186 | " pixels (tensor): pixel tensor of size B x N x 2\n",
187 | " depth (tensor): depth tensor of size B x N x 1\n",
188 | " camera_mat (tensor): camera matrix\n",
189 | " world_mat (tensor): world matrix\n",
190 | " scale_mat (tensor): scale matrix\n",
191 | " invert (bool): whether to invert matrices (default: true)\n",
192 | " '''\n",
193 | " assert(pixels.shape[-1] == 2)\n",
194 | "\n",
195 | " if scale_mat is None:\n",
196 | " scale_mat = torch.eye(4).unsqueeze(0).repeat(\n",
197 | " camera_mat.shape[0], 1, 1)\n",
198 | "\n",
199 | " # Convert to pytorch\n",
200 | " pixels, is_numpy = to_pytorch(pixels, True)\n",
201 | " depth = to_pytorch(depth)\n",
202 | " camera_mat = to_pytorch(camera_mat)\n",
203 | " world_mat = to_pytorch(world_mat)\n",
204 | " scale_mat = to_pytorch(scale_mat)\n",
205 | "\n",
206 | " # Invert camera matrices\n",
207 | " if invert:\n",
208 | " camera_mat = torch.inverse(camera_mat)\n",
209 | " world_mat = torch.inverse(world_mat)\n",
210 | " scale_mat = torch.inverse(scale_mat)\n",
211 | "\n",
212 | " # Transform pixels to homogen coordinates\n",
213 | " pixels = pixels.permute(0, 2, 1)\n",
214 | " pixels = torch.cat([pixels, torch.ones_like(pixels)], dim=1)\n",
215 | "\n",
216 | " # Project pixels into camera space\n",
217 | " if use_absolute_depth:\n",
218 | " pixels[:, :2] = pixels[:, :2] * depth.permute(0, 2, 1).abs()\n",
219 | " pixels[:, 2:3] = pixels[:, 2:3] * depth.permute(0, 2, 1)\n",
220 | " else:\n",
221 | " pixels[:, :3] = pixels[:, :3] * depth.permute(0, 2, 1)\n",
222 | " \n",
223 | " # Transform pixels to world space\n",
224 | " p_world = scale_mat @ world_mat @ camera_mat @ pixels\n",
225 | "\n",
226 | " # Transform p_world back to 3D coordinates\n",
227 | " p_world = p_world[:, :3].permute(0, 2, 1)\n",
228 | "\n",
229 | " if is_numpy:\n",
230 | " p_world = p_world.numpy()\n",
231 | " return p_world\n",
232 | "\n",
233 | "\n",
234 | "# 7. mat_mul zeros with intrinsic&extrinsic for camera pos (which we alread obtained as loc)\n",
235 | "def origin_to_world(n_points, camera_mat, world_mat, scale_mat=None,\n",
236 | " invert=False):\n",
237 | " ''' Transforms origin (camera location) to world coordinates.\n",
238 | " Args:\n",
239 | " n_points (int): how often the transformed origin is repeated in the\n",
240 | " form (batch_size, n_points, 3)\n",
241 | " camera_mat (tensor): camera matrix\n",
242 | " world_mat (tensor): world matrix\n",
243 | " scale_mat (tensor): scale matrix\n",
244 | " invert (bool): whether to invert the matrices (default: false)\n",
245 | " '''\n",
246 | " \n",
247 | " batch_size = camera_mat.shape[0]\n",
248 | " device = camera_mat.device\n",
249 | " # Create origin in homogen coordinates\n",
250 | " p = torch.zeros(batch_size, 4, n_points).to(device)\n",
251 | " p[:, -1] = 1.\n",
252 | "\n",
253 | " if scale_mat is None:\n",
254 | " scale_mat = torch.eye(4).unsqueeze(\n",
255 | " 0).repeat(batch_size, 1, 1).to(device)\n",
256 | "\n",
257 | " # Invert matrices\n",
258 | " if invert:\n",
259 | " camera_mat = torch.inverse(camera_mat)\n",
260 | " world_mat = torch.inverse(world_mat)\n",
261 | " scale_mat = torch.inverse(scale_mat)\n",
262 | " \n",
263 | " camera_mat = to_pytorch(camera_mat)\n",
264 | " world_mat = to_pytorch(world_mat)\n",
265 | " scale_mat = to_pytorch(scale_mat)\n",
266 | " \n",
267 | " # Apply transformation\n",
268 | " p_world = scale_mat @ world_mat @ camera_mat @ p\n",
269 | "\n",
270 | " # Transform points back to 3D coordinates\n",
271 | " p_world = p_world[:, :3].permute(0, 2, 1)\n",
272 | " return p_world"
273 | ]
274 | },
275 | {
276 | "cell_type": "markdown",
277 | "id": "414ea1ff",
278 | "metadata": {},
279 | "source": [
280 | "# Sampling"
281 | ]
282 | },
283 | {
284 | "cell_type": "code",
285 | "execution_count": 4,
286 | "id": "d62aa105",
287 | "metadata": {},
288 | "outputs": [],
289 | "source": [
290 | "# from Kai-46/nerfplusplus\n",
291 | "\n",
292 | "# 8. intersect sphere for distinguishing fg and bg\n",
293 | "def intersect_sphere(ray_o, ray_d):\n",
294 | " '''\n",
295 | " ray_o, ray_d: [..., 3]\n",
296 | " compute the depth of the intersection point between this ray and unit sphere\n",
297 | " '''\n",
298 | " # note: d1 becomes negative if this mid point is behind camera\n",
299 | " d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)\n",
300 | " p = ray_o + d1.unsqueeze(-1) * ray_d\n",
301 | " # consider the case where the ray does not intersect the sphere\n",
302 | " ray_d_cos = 1. / torch.norm(ray_d, dim=-1)\n",
303 | " p_norm_sq = torch.sum(p * p, dim=-1)\n",
304 | " if (p_norm_sq >= 1.).any():\n",
305 | " raise Exception('Not all your cameras are bounded by the unit sphere; please make sure the cameras are normalized properly!')\n",
306 | " d2 = torch.sqrt(1. - p_norm_sq) * ray_d_cos\n",
307 | "\n",
308 | " return d1 + d2\n",
309 | "\n",
310 | "# 9. inverse sphere sampling for bg\n",
311 | "def depth2pts_outside(ray_o, ray_d, depth):\n",
312 | " '''\n",
313 | " ray_o, ray_d: [..., 3]\n",
314 | " depth: [...]; inverse of distance to sphere origin\n",
315 | " '''\n",
316 | " # note: d1 becomes negative if this mid point is behind camera\n",
317 | " d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)\n",
318 | " p_mid = ray_o + d1.unsqueeze(-1) * ray_d\n",
319 | " p_mid_norm = torch.norm(p_mid, dim=-1)\n",
320 | " ray_d_cos = 1. / torch.norm(ray_d, dim=-1)\n",
321 | " d2 = torch.sqrt(1. - p_mid_norm * p_mid_norm) * ray_d_cos\n",
322 | " \n",
323 | " p_sphere = ray_o + (d1 + d2).unsqueeze(-1) * ray_d\n",
324 | "\n",
325 | " rot_axis = torch.cross(ray_o, p_sphere, dim=-1)\n",
326 | " rot_axis = rot_axis / torch.norm(rot_axis, dim=-1, keepdim=True)\n",
327 | " phi = torch.asin(p_mid_norm)\n",
328 | " theta = torch.asin(p_mid_norm * depth) # depth is inside [0, 1]\n",
329 | " rot_angle = (phi - theta).unsqueeze(-1) # [..., 1]\n",
330 | "\n",
331 | " # now rotate p_sphere\n",
332 | " # Rodrigues formula: https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula\n",
333 | " p_sphere_new = p_sphere * torch.cos(rot_angle) + \\\n",
334 | " torch.cross(rot_axis, p_sphere, dim=-1) * torch.sin(rot_angle) + \\\n",
335 | " rot_axis * torch.sum(rot_axis*p_sphere, dim=-1, keepdim=True) * (1.-torch.cos(rot_angle))\n",
336 | " p_sphere_new = p_sphere_new / torch.norm(p_sphere_new, dim=-1, keepdim=True)\n",
337 | " pts = torch.cat((p_sphere_new.squeeze(), depth.squeeze().unsqueeze(-1)), dim=-1) # (modified) added .squeeze()\n",
338 | "\n",
339 | " # now calculate conventional depth\n",
340 | " depth_real = 1. / (depth + TINY_NUMBER) * torch.cos(theta) * ray_d_cos + d1\n",
341 | " return pts, depth_real\n",
342 | "\n",
343 | "# 10. perturb sample z values (depth) for some randomness (stratified sampling)\n",
344 | "def perturb_samples(z_vals):\n",
345 | " # get intervals between samples\n",
346 | " mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])\n",
347 | " upper = torch.cat([mids, z_vals[..., -1:]], dim=-1)\n",
348 | " lower = torch.cat([z_vals[..., 0:1], mids], dim=-1)\n",
349 | " # uniform samples in those intervals\n",
350 | " t_rand = torch.rand_like(z_vals)\n",
351 | " z_vals = lower + (upper - lower) * t_rand # [N_rays, N_samples]\n",
352 | "\n",
353 | " return z_vals\n",
354 | "\n",
355 | "# 11. return fg (just uniform) and bg (uniform inverse sphere) samples\n",
356 | "def uniformsampling(ray_o, ray_d, min_depth, N_samples, device):\n",
357 | " \n",
358 | " dots_sh = list(ray_d.shape[:-1])\n",
359 | "\n",
360 | " # foreground depth\n",
361 | " fg_far_depth = intersect_sphere(ray_o, ray_d) # [...,]\n",
362 | " fg_near_depth = min_depth # [..., ]\n",
363 | " step = (fg_far_depth - fg_near_depth) / (N_samples - 1)\n",
364 | " fg_depth = torch.stack([fg_near_depth + i * step for i in range(N_samples)], dim=-1) # [..., N_samples]\n",
365 | " fg_depth = perturb_samples(fg_depth) # random perturbation during training\n",
366 | "\n",
367 | " # background depth\n",
368 | " bg_depth = torch.linspace(0., 1., N_samples).view(\n",
369 | " [1, ] * len(dots_sh) + [N_samples,]).expand(dots_sh + [N_samples,]).to(device)\n",
370 | " bg_depth = perturb_samples(bg_depth) # random perturbation during training\n",
371 | "\n",
372 | "\n",
373 | " fg_ray_o = ray_o.unsqueeze(-2).expand(dots_sh + [N_samples, 3])\n",
374 | " fg_ray_d = ray_d.unsqueeze(-2).expand(dots_sh + [N_samples, 3])\n",
375 | "\n",
376 | " bg_ray_o = ray_o.unsqueeze(-2).expand(dots_sh + [N_samples, 3])\n",
377 | " bg_ray_d = ray_d.unsqueeze(-2).expand(dots_sh + [N_samples, 3])\n",
378 | " \n",
379 | " # sampling foreground\n",
380 | " fg_pts = fg_ray_o + fg_depth.unsqueeze(-1) * fg_ray_d\n",
381 | "\n",
382 | " # sampling background\n",
383 | " bg_pts, bg_depth_real = depth2pts_outside(bg_ray_o, bg_ray_d, bg_depth)\n",
384 | "\n",
385 | " return fg_pts, bg_pts, bg_depth_real\n",
386 | "\n",
387 | "# 12 convert bg pts (x', y', z' 1/r) to real bg pts (x, y, z)\n",
388 | "def pts2realpts(bg_pts, bg_depth_real):\n",
389 | " return bg_pts[:, :, :3] * bg_depth_real.squeeze().unsqueeze(-1)"
390 | ]
391 | },
392 | {
393 | "cell_type": "code",
394 | "execution_count": 5,
395 | "id": "856d052e",
396 | "metadata": {},
397 | "outputs": [],
398 | "source": [
399 | "def giraffe(u = 1,\n",
400 | " v = 0.5,\n",
401 | " r=2.713,\n",
402 | " depth_range=[0.5, 6.],\n",
403 | " n_ray_samples=16,\n",
404 | " resolution_vol = 4,\n",
405 | " batch_size = 1\n",
406 | " ):\n",
407 | "\n",
408 | " range_radius=[r, r]\n",
409 | " \n",
410 | " res = resolution_vol\n",
411 | " n_points = res * res\n",
412 | "\n",
413 | " # 1. get camera intrinsic \n",
414 | " camera_mat = get_camera_mat()\n",
415 | "\n",
416 | " # 2. get camera position with camera pose (theta & phi)\n",
417 | " loc = to_sphere(u, v)\n",
418 | " loc = torch.tensor(loc).float()\n",
419 | " radius = range_radius[0] + \\\n",
420 | " torch.rand(batch_size) * (range_radius[1] - range_radius[0])\n",
421 | " loc = loc * radius.unsqueeze(-1)\n",
422 | "\n",
423 | " # 3. get camera coordinate system assuming it points to the center of the sphere\n",
424 | " R = look_at(loc)\n",
425 | "\n",
426 | " # 4. The carmera coordinate is the rotational matrix and with camera loc, it is camera extrinsic\n",
427 | " RT = np.eye(4).reshape(1, 4, 4)\n",
428 | " RT[:, :3, :3] = R\n",
429 | " RT[:, :3, -1] = loc\n",
430 | " world_mat = RT\n",
431 | "\n",
432 | " # 5. arange 2d array of pixel coordinate and give depth of 1\n",
433 | " pixels = arange_pixels((res, res), 1, invert_y_axis=False)[1]\n",
434 | " pixels[..., -1] *= -1. # still dunno why this is here\n",
435 | "\n",
436 | " # 6. mat_mul with intrinsic and then extrinsic gives you p_world (pixels in world) \n",
437 | " pixels_world = image_points_to_world(pixels, camera_mat, world_mat)\n",
438 | " \n",
439 | " # 7. mat_mul zeros with intrinsic&extrinsic for camera pos (which we alread obtained as loc)\n",
440 | " camera_world = origin_to_world(n_points, camera_mat, world_mat)\n",
441 | "\n",
442 | " # 8. ray = pixel - camera origin (in world)\n",
443 | " ray_vector = pixels_world - camera_world\n",
444 | "\n",
445 | " # 9. depths from closest to furthest (0.5 ~ 6.0)\n",
446 | " di = depth_range[0] + \\\n",
447 | " torch.linspace(0., 1., steps=n_ray_samples).reshape(1, 1, -1) * (\n",
448 | " depth_range[1] - depth_range[0])\n",
449 | " di = di.repeat(batch_size, n_points, 1)\n",
450 | "\n",
451 | " # 10. calculate points\n",
452 | " p_i = camera_world.unsqueeze(-2).contiguous() + \\\n",
453 | " di.unsqueeze(-1).contiguous() * ray_vector.unsqueeze(-2).contiguous()\n",
454 | " \n",
455 | " return pixels_world, camera_world, world_mat, p_i"
456 | ]
457 | },
458 | {
459 | "cell_type": "code",
460 | "execution_count": 6,
461 | "id": "1a0c73e1",
462 | "metadata": {},
463 | "outputs": [],
464 | "source": [
465 | "def nerfpp(u = 1,\n",
466 | " v = 0.5,\n",
467 | " fov = 49.13,\n",
468 | " depth_range=[0.5, 6.],\n",
469 | " n_ray_samples=16,\n",
470 | " resolution_vol = 4,\n",
471 | " batch_size = 1,\n",
472 | " device = torch.device('cpu')\n",
473 | " ):\n",
474 | " \n",
475 | " r = 1\n",
476 | " range_radius=[r, r]\n",
477 | " \n",
478 | " res = resolution_vol\n",
479 | " n_points = res * res\n",
480 | "\n",
481 | " # 1. get camera intrinsic - fiddle around with fov this time\n",
482 | " camera_mat = get_camera_mat(fov=fov)\n",
483 | "\n",
484 | " # 2. get camera position with camera pose (theta & phi)\n",
485 | " loc = to_sphere(u, v)\n",
486 | " loc = torch.tensor(loc).float()\n",
487 | " \n",
488 | " radius = range_radius[0] + \\\n",
489 | " torch.rand(batch_size) * (range_radius[1] - range_radius[0])\n",
490 | " \n",
491 | " loc = loc * radius.unsqueeze(-1)\n",
492 | "\n",
493 | " # 3. get camera coordinate system assuming it points to the center of the sphere\n",
494 | " R = look_at(loc)\n",
495 | "\n",
496 | " # 4. The carmera coordinate is the rotational matrix and with camera loc, it is camera extrinsic\n",
497 | " RT = np.eye(4).reshape(1, 4, 4)\n",
498 | " RT[:, :3, :3] = R\n",
499 | " RT[:, :3, -1] = loc\n",
500 | " world_mat = RT\n",
501 | "\n",
502 | " # 5. arange 2d array of pixel coordinate and give depth of 1\n",
503 | " pixels = arange_pixels((res, res), 1, invert_y_axis=False)[1]\n",
504 | " pixels[..., -1] *= -1. # still dunno why this is here\n",
505 | "\n",
506 | " # 6. mat_mul with intrinsic and then extrinsic gives you p_world (pixels in world) \n",
507 | " pixels_world = image_points_to_world(pixels, camera_mat, world_mat)\n",
508 | "\n",
509 | " # 7. mat_mul zeros with intrinsic&extrinsic for camera pos (which we alread obtained as loc)\n",
510 | " camera_world = origin_to_world(n_points, camera_mat, world_mat)\n",
511 | "\n",
512 | " # 8. ray = pixel - camera origin (in world)\n",
513 | " ray_vector = pixels_world - camera_world\n",
514 | "\n",
515 | " # 9. sample fg and bg points according to nerfpp (uniform and inverse sphere)\n",
516 | " fg_pts, bg_pts, bg_depth_real = uniformsampling(ray_o=camera_world, ray_d=ray_vector, min_depth=depth_range[0], N_samples=n_ray_samples, device=device)\n",
517 | " \n",
518 | " #10. convert bg pts (x', y', z' 1/r) to real bg pts (x, y, z)\n",
519 | " bg_pts_real = pts2realpts(bg_pts, bg_depth_real)\n",
520 | " \n",
521 | " return pixels_world, camera_world, world_mat, fg_pts, bg_pts_real"
522 | ]
523 | },
524 | {
525 | "cell_type": "markdown",
526 | "id": "2fd19406",
527 | "metadata": {},
528 | "source": [
529 | "# Visualization"
530 | ]
531 | },
532 | {
533 | "cell_type": "code",
534 | "execution_count": 7,
535 | "id": "b9dd6b14",
536 | "metadata": {},
537 | "outputs": [],
538 | "source": [
539 | "# draw sphere with radius r \n",
540 | "# also draw contours and vertical lines\n",
541 | "def draw_sphere(r, sphere_colorscale, sphere_opacity):\n",
542 | " # sphere\n",
543 | " u = np.linspace(0, 2 * np.pi, 100)\n",
544 | " v = np.linspace(0, np.pi, 100)\n",
545 | " x = r * np.outer(np.cos(u), np.sin(v))\n",
546 | " y = r * np.outer(np.sin(u), np.sin(v))\n",
547 | " z = r * np.outer(np.ones(np.size(u)), np.cos(v))\n",
548 | " \n",
549 | " # vertical lines on sphere\n",
550 | " u2 = np.linspace(0, 2 * np.pi, 20)\n",
551 | " x2 = r * np.outer(np.cos(u2), np.sin(v))\n",
552 | " y2 = r * np.outer(np.sin(u2), np.sin(v))\n",
553 | " z2 = r * np.outer(np.ones(np.size(u2)), np.cos(v))\n",
554 | " \n",
555 | " # create sphere and draw sphere with contours\n",
556 | " fig = go.Figure(data=[go.Surface(x=x, y=y, z=z, \n",
557 | " colorscale=sphere_colorscale, opacity=sphere_opacity,\n",
558 | " contours = {\n",
559 | " 'z' : {'show' : True, 'start' : -r,\n",
560 | " 'end' : r, 'size' : r/10,\n",
561 | " 'color' : 'white',\n",
562 | " 'width' : 1}\n",
563 | " }\n",
564 | " , showscale=False)])\n",
565 | " \n",
566 | " # vertical lines on sphere\n",
567 | " for i in range(len(u2)):\n",
568 | " fig.add_scatter3d(x=x2[i], y=y2[i], z=z2[i], \n",
569 | " line=dict(\n",
570 | " color='white',\n",
571 | " width=1\n",
572 | " ),\n",
573 | " mode='lines',\n",
574 | " showlegend=False)\n",
575 | " \n",
576 | " return fig\n",
577 | "\n",
578 | "# draw xyplane\n",
579 | "def draw_XYplane(fig, xy_plane_colorscale, xy_plane_opacity, x_range = [-2, 2], y_range = [-2, 2]):\n",
580 | " x3 = np.linspace(x_range[0], x_range[1], 100)\n",
581 | " y3 = np.linspace(y_range[0], y_range[1], 100)\n",
582 | " z3 = np.zeros(shape=(100,100))\n",
583 | " \n",
584 | " fig.add_surface(x=x3, y=y3, z=z3,\n",
585 | " colorscale =xy_plane_colorscale, opacity=xy_plane_opacity,\n",
586 | " showscale=False\n",
587 | " )\n",
588 | " \n",
589 | " return fig\n",
590 | " \n",
591 | "\n",
592 | "def draw_XYZworld(fig, world_axis_size):\n",
593 | " # x, y, z positive direction (world)\n",
594 | " X_axis = [0, world_axis_size]\n",
595 | " X_text = [None, \"X\"]\n",
596 | " X0 = [0, 0]\n",
597 | " Y_axis = [0, world_axis_size]\n",
598 | " Y_text = [None, \"Y\"]\n",
599 | " Y0 = [0, 0]\n",
600 | " Z_axis = [0, world_axis_size]\n",
601 | " Z_text = [None, \"Z\"]\n",
602 | " Z0 = [0, 0]\n",
603 | " \n",
604 | " fig.add_scatter3d(x=X_axis, y=Y0, z=Z0, \n",
605 | " line=dict(\n",
606 | " color='red',\n",
607 | " width=10\n",
608 | " ),\n",
609 | " mode='lines+text',\n",
610 | " text=X_text,\n",
611 | " textposition='top center',\n",
612 | " textfont=dict(\n",
613 | " color=\"red\",\n",
614 | " size=18\n",
615 | " ),\n",
616 | " showlegend=False)\n",
617 | "\n",
618 | " fig.add_scatter3d(x=X0, y=Y_axis, z=Z0, \n",
619 | " line=dict(\n",
620 | " color='green',\n",
621 | " width=10\n",
622 | " ),\n",
623 | " mode='lines+text',\n",
624 | " text=Y_text,\n",
625 | " textposition='top center',\n",
626 | " textfont=dict(\n",
627 | " color=\"green\",\n",
628 | " size=18\n",
629 | " ),\n",
630 | " showlegend=False)\n",
631 | "\n",
632 | " fig.add_scatter3d(x=X0, y=Y0, z=Z_axis, \n",
633 | " line=dict(\n",
634 | " color='blue',\n",
635 | " width=10\n",
636 | " ),\n",
637 | " mode='lines+text',\n",
638 | " text=Z_text,\n",
639 | " textposition='top center',\n",
640 | " textfont=dict(\n",
641 | " color=\"blue\",\n",
642 | " size=18\n",
643 | " ),\n",
644 | " showlegend=False)\n",
645 | " \n",
646 | " return fig\n",
647 | "\n",
648 | "# draw cam and cam coordinate system\n",
649 | "def draw_cam_init(fig, world_mat, camera_axis_size, camera_color):\n",
650 | " # camera at init\n",
651 | "\n",
652 | " Xc = [world_mat[0, : ,3][0]]\n",
653 | " Yc = [world_mat[0, : ,3][1]]\n",
654 | " Zc = [world_mat[0, : ,3][2]]\n",
655 | " text_c = [\"Camera\"]\n",
656 | "\n",
657 | " # camera axis\n",
658 | " Xc_Xaxis = Xc + [world_mat[0, : ,0][0]*camera_axis_size+Xc[0]]\n",
659 | " Yc_Xaxis = Yc + [world_mat[0, : ,0][1]*camera_axis_size+Yc[0]]\n",
660 | " Zc_Xaxis = Zc + [world_mat[0, : ,0][2]*camera_axis_size+Zc[0]]\n",
661 | " text_Xaxis = [None, \"Xc\"]\n",
662 | " \n",
663 | " # -z in world perspective\n",
664 | " Xc_Yaxis = Xc + [world_mat[0, : ,1][0]*camera_axis_size+Xc[0]]\n",
665 | " Yc_Yaxis = Yc + [world_mat[0, : ,1][1]*camera_axis_size+Yc[0]]\n",
666 | " Zc_Yaxis = Zc + [world_mat[0, : ,1][2]*camera_axis_size+Zc[0]]\n",
667 | " text_Yaxis = [None, \"Yc\"]\n",
668 | "\n",
669 | " # y in world perspective\n",
670 | " Xc_Zaxis = Xc + [world_mat[0, : ,2][0]*camera_axis_size+Xc[0]]\n",
671 | " Yc_Zaxis = Yc + [world_mat[0, : ,2][1]*camera_axis_size+Yc[0]]\n",
672 | " Zc_Zaxis = Zc + [world_mat[0, : ,2][2]*camera_axis_size+Zc[0]]\n",
673 | " text_Zaxis = [None, \"Zc\"]\n",
674 | " \n",
675 | " # cam pos\n",
676 | " fig.add_scatter3d(x=Xc, y=Yc, z=Zc, \n",
677 | " mode='markers',\n",
678 | " marker=dict(\n",
679 | " color=camera_color,\n",
680 | " size=4,\n",
681 | " sizemode='diameter'\n",
682 | " ),\n",
683 | " showlegend=False)\n",
684 | "\n",
685 | " # camera axis\n",
686 | " fig.add_scatter3d(x=Xc_Xaxis, y=Yc_Xaxis, z=Zc_Xaxis, \n",
687 | " line=dict(\n",
688 | " color='red',\n",
689 | " width=10\n",
690 | " ),\n",
691 | " mode='lines+text',\n",
692 | " text=text_Xaxis,\n",
693 | " textposition='top center',\n",
694 | " textfont=dict(\n",
695 | " color=\"red\",\n",
696 | " size=18\n",
697 | " ),\n",
698 | " showlegend=False)\n",
699 | "\n",
700 | " fig.add_scatter3d(x=Xc_Yaxis, y=Yc_Yaxis, z=Zc_Yaxis, \n",
701 | " line=dict(\n",
702 | " color='green',\n",
703 | " width=10\n",
704 | " ),\n",
705 | " mode='lines+text',\n",
706 | " text=text_Yaxis,\n",
707 | " textposition='top center',\n",
708 | " textfont=dict(\n",
709 | " color=\"green\",\n",
710 | " size=18\n",
711 | " ),\n",
712 | " showlegend=False)\n",
713 | "\n",
714 | " fig.add_scatter3d(x=Xc_Zaxis, y=Yc_Zaxis, z=Zc_Zaxis, \n",
715 | " line=dict(\n",
716 | " color='blue',\n",
717 | " width=10\n",
718 | " ),\n",
719 | " mode='lines+text',\n",
720 | " text=text_Zaxis,\n",
721 | " textposition='top center',\n",
722 | " textfont=dict(\n",
723 | " color=\"blue\",\n",
724 | " size=18\n",
725 | " ),\n",
726 | " showlegend=False)\n",
727 | " \n",
728 | " return fig\n",
729 | "\n",
730 | "# draw all rays\n",
731 | "def draw_all_rays(fig, p_i, ray_color):\n",
732 | " for i in range(p_i.shape[1]):\n",
733 | " Xray = p_i[0, i, :, 0]\n",
734 | " Yray = p_i[0, i, :, 1]\n",
735 | " Zray = p_i[0, i, :, 2]\n",
736 | " \n",
737 | " fig.add_scatter3d(x=Xray, y=Yray, z=Zray, \n",
738 | " line=dict(\n",
739 | " color=ray_color,\n",
740 | " width=5\n",
741 | " ),\n",
742 | " mode='lines',\n",
743 | " showlegend=False)\n",
744 | " \n",
745 | " return fig\n",
746 | "\n",
747 | "# draw all rays\n",
748 | "def draw_all_rays_with_marker(fig, p_i, marker_size, ray_color):\n",
749 | " \n",
750 | " # convert colorscale string to px.colors.seqeuntial\n",
751 | " # default color is set to Viridis in case of mismatch\n",
752 | " c = px.colors.sequential.Viridis\n",
753 | "\n",
754 | " for c_name in [ray_color, ray_color.capitalize()]:\n",
755 | " try:\n",
756 | " c = getattr(px.colors.sequential, c_name)\n",
757 | " except:\n",
758 | " continue\n",
759 | " \n",
760 | " for i in range(p_i.shape[1]):\n",
761 | " Xray = p_i[0, i, :, 0]\n",
762 | " Yray = p_i[0, i, :, 1]\n",
763 | " Zray = p_i[0, i, :, 2]\n",
764 | " \n",
765 | " fig.add_scatter3d(x=Xray, y=Yray, z=Zray, \n",
766 | " \n",
767 | " marker=dict(\n",
768 | "# color=np.arange(len(Xray)),\n",
769 | " color=c,\n",
770 | "# colorscale='Viridis',\n",
771 | " size=marker_size\n",
772 | " ),\n",
773 | " \n",
774 | " line=dict(\n",
775 | "# color=np.arange(len(Xray)),\n",
776 | " color=c,\n",
777 | "# colorscale='Viridis',\n",
778 | " width=3\n",
779 | " ),\n",
780 | " mode=\"lines+markers\",\n",
781 | " showlegend=False)\n",
782 | " \n",
783 | " return fig\n",
784 | "\n",
785 | "# draw near&far frustrum with rays connecting the corners (changed for nerfpp)\n",
786 | "def draw_ray_frus(fig, p_i, frustrum_color, frustrum_opacity, at=[0, -1]):\n",
787 | " \n",
788 | " for i in at:\n",
789 | "# Xfrus = p_i[0, :, i, 0][[0,1,2,3,7,11,15,14,13,12,8,4,0]]\n",
790 | "# Yfrus = p_i[0, :, i, 1][[0,1,2,3,7,11,15,14,13,12,8,4,0]]\n",
791 | "# Zfrus = p_i[0, :, i, 2][[0,1,2,3,7,11,15,14,13,12,8,4,0]]\n",
792 | "\n",
793 | " Xfrus = p_i[0, :, i, 0]\n",
794 | " Yfrus = p_i[0, :, i, 1]\n",
795 | " Zfrus = p_i[0, :, i, 2]\n",
796 | " \n",
797 | " fig.add_scatter3d(x=Xfrus, y=Yfrus, z=Zfrus, \n",
798 | " line=dict(\n",
799 | " color=frustrum_color,\n",
800 | " width=5\n",
801 | " ),\n",
802 | " mode='lines',\n",
803 | " surfaceaxis=0,\n",
804 | " surfacecolor=frustrum_color,\n",
805 | " opacity=frustrum_opacity,\n",
806 | " showlegend=False)\n",
807 | " \n",
808 | " return fig\n",
809 | "\n",
810 | "# draw foreground sample points, ray and frustrum\n",
811 | "def draw_foreground(fig, fg_pts, fg_color, marker_size, at=[0, -1]):\n",
812 | " fig = draw_all_rays_with_marker(fig, fg_pts, marker_size, fg_color)\n",
813 | " \n",
814 | " return fig\n",
815 | "\n",
816 | "# draw background sample points, ray and frustrum\n",
817 | "def draw_background(fig, bg_pts, bg_color, marker_size, at=[0, -1]):\n",
818 | " fig = draw_all_rays_with_marker(fig, bg_pts, marker_size, bg_color)\n",
819 | " \n",
820 | " return fig"
821 | ]
822 | },
823 | {
824 | "cell_type": "code",
825 | "execution_count": 8,
826 | "id": "f12fecdd",
827 | "metadata": {},
828 | "outputs": [
829 | {
830 | "name": "stderr",
831 | "output_type": "stream",
832 | "text": [
833 | "C:\\Users\\laphi\\AppData\\Local\\Temp/ipykernel_24504/3891967274.py:2: UserWarning: \n",
834 | "The dash_core_components package is deprecated. Please replace\n",
835 | "`import dash_core_components as dcc` with `from dash import dcc`\n",
836 | " import dash_core_components as dcc\n",
837 | "C:\\Users\\laphi\\AppData\\Local\\Temp/ipykernel_24504/3891967274.py:3: UserWarning: \n",
838 | "The dash_html_components package is deprecated. Please replace\n",
839 | "`import dash_html_components as html` with `from dash import html`\n",
840 | " import dash_html_components as html\n"
841 | ]
842 | }
843 | ],
844 | "source": [
845 | "from jupyter_dash import JupyterDash\n",
846 | "import dash_core_components as dcc\n",
847 | "import dash_html_components as html\n",
848 | "from dash.dependencies import Input, Output\n",
849 | "\n",
850 | "# for colors\n",
851 | "import matplotlib.colors as mcolors"
852 | ]
853 | },
854 | {
855 | "cell_type": "code",
856 | "execution_count": 10,
857 | "id": "e5e28428",
858 | "metadata": {
859 | "scrolled": false
860 | },
861 | "outputs": [
862 | {
863 | "name": "stderr",
864 | "output_type": "stream",
865 | "text": [
866 | "c:\\users\\laphi\\appdata\\local\\programs\\python\\python38\\lib\\site-packages\\jupyter_dash\\jupyter_app.py:139: UserWarning:\n",
867 | "\n",
868 | "The 'environ['werkzeug.server.shutdown']' function is deprecated and will be removed in Werkzeug 2.1.\n",
869 | "\n"
870 | ]
871 | },
872 | {
873 | "data": {
874 | "text/html": [
875 | "\n",
876 | " \n",
883 | " "
884 | ],
885 | "text/plain": [
886 | ""
887 | ]
888 | },
889 | "metadata": {},
890 | "output_type": "display_data"
891 | }
892 | ],
893 | "source": [
894 | "app = JupyterDash(__name__)\n",
895 | "\n",
896 | "app.layout = html.Div([\n",
897 | " html.H1(\"Nerfplusplus ray sampling visualization\"),\n",
898 | " dcc.Graph(id='graph'),\n",
899 | " \n",
900 | " html.Div([\n",
901 | " html.Div([\n",
902 | " \n",
903 | " # changes to setting and ray casted \n",
904 | " html.Label([ \"u (theta)\",\n",
905 | " dcc.Slider(\n",
906 | " id='u-slider', \n",
907 | " min=0, max=1,\n",
908 | " value=0.00,\n",
909 | " marks={str(val) : str(val) for val in [0.00, 0.25, 0.50, 0.75]},\n",
910 | " step=0.01, tooltip = { 'always_visible': True }\n",
911 | " ), ]),\n",
912 | " html.Label([ \"v (phi)\",\n",
913 | " dcc.Slider(\n",
914 | " id='v-slider', \n",
915 | " min=0, max=1,\n",
916 | " value=0.25,\n",
917 | " marks={str(val) : str(val) for val in [0.00, 0.25, 0.50, 0.75]},\n",
918 | " step=0.01, tooltip = { 'always_visible': True }\n",
919 | " ), ]),\n",
920 | " html.Label([ \"fov (field-of-view))\",\n",
921 | " dcc.Slider(\n",
922 | " id='fov-slider', \n",
923 | " min=0, max=100,\n",
924 | " value=50,\n",
925 | " marks={str(val) : str(val) for val in [0, 20, 40, 60, 80, 100]},\n",
926 | " step=5, tooltip = { 'always_visible': True }\n",
927 | " ), ]),\n",
928 | " html.Label([ \"foreground near depth\",\n",
929 | " dcc.Slider(\n",
930 | " id='foreground-near-depth-slider', \n",
931 | " min=0, max=2,\n",
932 | " value=0.5,\n",
933 | " marks={f\"{val:.1f}\" : f\"{val:.1f}\" for val in [0.1 * i for i in range(21)]},\n",
934 | " step=0.1, tooltip = { 'always_visible': True }\n",
935 | " ), ])\n",
936 | " ], style = {'width' : '48%', 'display' : 'inline-block'}),\n",
937 | " \n",
938 | " html.Div([\n",
939 | " # changes to visual appearance\n",
940 | " \n",
941 | " # axis scale\n",
942 | " html.Div([\n",
943 | " html.Label([ \"world axis size\",\n",
944 | " html.Div([\n",
945 | " dcc.Input(id='world-axis-size-input',\n",
946 | " value=1.5,\n",
947 | " type='number', style={'width': '50%'}\n",
948 | " )\n",
949 | " ]),\n",
950 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n",
951 | " html.Label([ \"camera axis size\",\n",
952 | " html.Div([\n",
953 | " dcc.Input(id='camera-axis-size-input',\n",
954 | " value=0.3,\n",
955 | " type='number', style={'width': '50%'}\n",
956 | " )\n",
957 | " ]),\n",
958 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n",
959 | " html.Label([ \"sample marker size\",\n",
960 | " html.Div([\n",
961 | " dcc.Input(id='sample-marker-size-input',\n",
962 | " value=2,\n",
963 | " type='number', style={'width': '50%'}\n",
964 | " )\n",
965 | " ]),\n",
966 | " ], style = {'width' : '32%', 'float' : 'left', 'display' : 'inline-block'}),\n",
967 | " ]),\n",
968 | " \n",
969 | " # opacity \n",
970 | " html.Div([\n",
971 | " html.Label([ \"sphere opacity\",\n",
972 | " html.Div([\n",
973 | " dcc.Input(id='sphere-opacity-input',\n",
974 | " value=0.2,\n",
975 | " type='number', style={'width': '50%'}\n",
976 | " )\n",
977 | " ])\n",
978 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n",
979 | " \n",
980 | " html.Label([ \"xy-plane opacity\", \n",
981 | " html.Div([\n",
982 | " dcc.Input(id='xy-plane-opacity-input',\n",
983 | " value=0.8,\n",
984 | " type='number', style={'width': '50%'}\n",
985 | " )\n",
986 | " ])\n",
987 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n",
988 | " \n",
989 | " ]),\n",
990 | " \n",
991 | " # color\n",
992 | " html.Div([\n",
993 | " html.Label([ \"camera color\",\n",
994 | " html.Div([\n",
995 | " dcc.Dropdown(id='camera-color-input',\n",
996 | " clearable=False,\n",
997 | " value='yellow',\n",
998 | " options=[\n",
999 | " {'label': c, 'value': c}\n",
1000 | " for (c, _) in mcolors.CSS4_COLORS.items()\n",
1001 | " ], style={'width': '80%'}\n",
1002 | " )\n",
1003 | " ])\n",
1004 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n",
1005 | " \n",
1006 | " html.Label([ \"foreground color\",\n",
1007 | " html.Div([\n",
1008 | " dcc.Dropdown(id='fg-color-input',\n",
1009 | " clearable=False,\n",
1010 | " value='plotly3',\n",
1011 | " options=[\n",
1012 | " {'label': c, 'value': c}\n",
1013 | " for c in px.colors.named_colorscales()\n",
1014 | " ], style={'width': '80%'}\n",
1015 | " )\n",
1016 | " ])\n",
1017 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n",
1018 | " \n",
1019 | " html.Label([ \"background color\",\n",
1020 | " html.Div([\n",
1021 | " dcc.Dropdown(id='bg-color-input',\n",
1022 | " clearable=False,\n",
1023 | " value='plotly3',\n",
1024 | " options=[\n",
1025 | " {'label': c, 'value': c}\n",
1026 | " for c in px.colors.named_colorscales()\n",
1027 | " ], style={'width': '80%'}\n",
1028 | " )\n",
1029 | " ])\n",
1030 | " ], style = {'width' : '32%', 'float' : 'left', 'display' : 'inline-block'}),\n",
1031 | " ]),\n",
1032 | " \n",
1033 | " # colorscale\n",
1034 | " html.Div([\n",
1035 | " html.Label([ \"sphere colorscale\",\n",
1036 | " html.Div([\n",
1037 | " dcc.Dropdown(id='sphere-colorscale-input',\n",
1038 | " clearable=False,\n",
1039 | " value='greys',\n",
1040 | " options=[\n",
1041 | " {'label': c, 'value': c}\n",
1042 | " for c in px.colors.named_colorscales()\n",
1043 | " ], style={'width': '80%'}\n",
1044 | " )\n",
1045 | " ])\n",
1046 | " ], style = {'width' : '32%', 'float' : 'left', 'display' : 'inline-block'}),\n",
1047 | " \n",
1048 | " html.Label([ \"xy-plane colorscale\",\n",
1049 | " html.Div([\n",
1050 | " dcc.Dropdown(id='xy-plane-colorscale-input',\n",
1051 | " clearable=False,\n",
1052 | " value='greys',\n",
1053 | " options=[\n",
1054 | " {'label': c, 'value': c}\n",
1055 | " for c in px.colors.named_colorscales()\n",
1056 | " ], style={'width': '80%'}\n",
1057 | " )\n",
1058 | " ])\n",
1059 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n",
1060 | " \n",
1061 | " html.Label([ \" \",\n",
1062 | " html.Div([\n",
1063 | " dcc.Checklist(id='show-background-checklist',\n",
1064 | " \n",
1065 | " options=[\n",
1066 | " {'label': 'show background', 'value': 'show_background'}\n",
1067 | " ], style={'width': '80%'},\n",
1068 | " value=[],\n",
1069 | " )\n",
1070 | " ])\n",
1071 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n",
1072 | " \n",
1073 | " ]),\n",
1074 | " \n",
1075 | " \n",
1076 | " \n",
1077 | " ], style = {'width' : '48%', 'float' : 'right', 'display' : 'inline-block'}),\n",
1078 | " \n",
1079 | " ]),\n",
1080 | " \n",
1081 | "])\n",
1082 | "\n",
1083 | "@app.callback(\n",
1084 | " Output('graph', 'figure'),\n",
1085 | " Input(\"u-slider\", \"value\"),\n",
1086 | " Input(\"v-slider\", \"value\"),\n",
1087 | " \n",
1088 | " Input(\"fov-slider\", \"value\"),\n",
1089 | " Input(\"foreground-near-depth-slider\", \"value\"),\n",
1090 | " \n",
1091 | " Input(\"world-axis-size-input\", \"value\"),\n",
1092 | " Input(\"camera-axis-size-input\", \"value\"),\n",
1093 | " Input(\"sample-marker-size-input\", \"value\"),\n",
1094 | " \n",
1095 | " Input(\"camera-color-input\", \"value\"),\n",
1096 | " Input(\"fg-color-input\", \"value\"),\n",
1097 | " Input(\"bg-color-input\", \"value\"),\n",
1098 | " \n",
1099 | " Input('sphere-colorscale-input', \"value\"),\n",
1100 | " Input('xy-plane-colorscale-input', \"value\"),\n",
1101 | " Input('show-background-checklist', \"value\"),\n",
1102 | " \n",
1103 | " Input(\"sphere-opacity-input\", \"value\"),\n",
1104 | " Input(\"xy-plane-opacity-input\", \"value\"),\n",
1105 | ")\n",
1106 | "\n",
1107 | "def update_figure(u, v, \n",
1108 | " fov, foreground_near_depth,\n",
1109 | " world_axis_size, camera_axis_size, sample_marker_size,\n",
1110 | " camera_color, fg_color, bg_color,\n",
1111 | " sphere_colorscale, xy_plane_colorscale, show_background,\n",
1112 | " sphere_opacity, xy_plane_opacity \n",
1113 | " ):\n",
1114 | " \n",
1115 | " depth_range = [foreground_near_depth, 2]\n",
1116 | " \n",
1117 | " # sphere\n",
1118 | " fig = draw_sphere(r=1, sphere_colorscale=sphere_colorscale, sphere_opacity=sphere_opacity)\n",
1119 | "\n",
1120 | " # change figure size\n",
1121 | "# fig.update_layout(autosize=False, width = 500, height=500)\n",
1122 | "\n",
1123 | " # draw axes in proportion to the proportion of their ranges\n",
1124 | " fig.update_layout(scene_aspectmode='data')\n",
1125 | "\n",
1126 | " # xy plane\n",
1127 | " fig = draw_XYplane(fig, xy_plane_colorscale, xy_plane_opacity,\n",
1128 | " x_range=[-depth_range[1], depth_range[1]], y_range=[-depth_range[1], depth_range[1]])\n",
1129 | "\n",
1130 | " # show world coordinate system (X, Y, Z positive direction)\n",
1131 | " fig = draw_XYZworld(fig, world_axis_size=world_axis_size)\n",
1132 | "\n",
1133 | " pixels_world, camera_world, world_mat, fg_pts, bg_pts = nerfpp(u=u, v=v, fov=fov, depth_range=depth_range)\n",
1134 | "\n",
1135 | " # draw camera at init (with its cooridnate system)\n",
1136 | " fig = draw_cam_init(fig, world_mat, \n",
1137 | " camera_axis_size=camera_axis_size, camera_color=camera_color)\n",
1138 | "\n",
1139 | " # draw foreground and background sample point, ray, and frustrum\n",
1140 | " fig = draw_foreground(fig, fg_pts, fg_color, sample_marker_size, at=[0, -1])\n",
1141 | " \n",
1142 | " if show_background:\n",
1143 | " fig = draw_background(fig, bg_pts.unsqueeze(0), bg_color, sample_marker_size, at=[0, -1])\n",
1144 | " \n",
1145 | " return fig\n",
1146 | "\n",
1147 | "app.run_server(mode='inline')"
1148 | ]
1149 | }
1150 | ],
1151 | "metadata": {
1152 | "kernelspec": {
1153 | "display_name": "Python 3 (ipykernel)",
1154 | "language": "python",
1155 | "name": "python3"
1156 | },
1157 | "language_info": {
1158 | "codemirror_mode": {
1159 | "name": "ipython",
1160 | "version": 3
1161 | },
1162 | "file_extension": ".py",
1163 | "mimetype": "text/x-python",
1164 | "name": "python",
1165 | "nbconvert_exporter": "python",
1166 | "pygments_lexer": "ipython3",
1167 | "version": "3.8.8"
1168 | }
1169 | },
1170 | "nbformat": 4,
1171 | "nbformat_minor": 5
1172 | }
1173 |
--------------------------------------------------------------------------------
/NerfppVisualizationDash.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "a6fca250",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import plotly.graph_objects as go\n",
11 | "import numpy as np\n",
12 | "import torch\n",
13 | "import plotly.express as px"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 2,
19 | "id": "de77ce6c",
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "HUGE_NUMBER = 1e10\n",
24 | "TINY_NUMBER = 1e-6 "
25 | ]
26 | },
27 | {
28 | "cell_type": "markdown",
29 | "id": "541c3686",
30 | "metadata": {},
31 | "source": [
32 | "# Ray Casting"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 3,
38 | "id": "27ee36ff",
39 | "metadata": {},
40 | "outputs": [],
41 | "source": [
42 | "# from autonomousvision/giraffe\n",
43 | "\n",
44 | "# 0.\n",
45 | "def to_pytorch(tensor, return_type=False):\n",
46 | " ''' Converts input tensor to pytorch.\n",
47 | " Args:\n",
48 | " tensor (tensor): Numpy or Pytorch tensor\n",
49 | " return_type (bool): whether to return input type\n",
50 | " '''\n",
51 | " is_numpy = False\n",
52 | " if type(tensor) == np.ndarray:\n",
53 | " tensor = torch.from_numpy(tensor).float()\n",
54 | " is_numpy = True\n",
55 | " tensor = tensor.clone()\n",
56 | " if return_type:\n",
57 | " return tensor, is_numpy\n",
58 | " return tensor\n",
59 | "\n",
60 | "# 1. get camera intrinsic\n",
61 | "def get_camera_mat(fov=49.13, invert=True):\n",
62 | " # fov = 2 * arctan( sensor / (2 * focal))\n",
63 | " # focal = (sensor / 2) * 1 / (tan(0.5 * fov))\n",
64 | " # in our case, sensor = 2 as pixels are in [-1, 1]\n",
65 | " focal = 1. / np.tan(0.5 * fov * np.pi/180.)\n",
66 | " focal = focal.astype(np.float32)\n",
67 | " mat = torch.tensor([\n",
68 | " [focal, 0., 0., 0.],\n",
69 | " [0., focal, 0., 0.],\n",
70 | " [0., 0., 1, 0.],\n",
71 | " [0., 0., 0., 1.]\n",
72 | " ]).reshape(1, 4, 4)\n",
73 | "\n",
74 | " if invert:\n",
75 | " mat = torch.inverse(mat)\n",
76 | " return mat\n",
77 | "\n",
78 | "# 2. get camera position with camera pose (theta & phi)\n",
79 | "def to_sphere(u, v):\n",
80 | " theta = 2 * np.pi * u\n",
81 | " phi = np.arccos(1 - 2 * v)\n",
82 | " cx = np.sin(phi) * np.cos(theta)\n",
83 | " cy = np.sin(phi) * np.sin(theta)\n",
84 | " cz = np.cos(phi)\n",
85 | " return np.stack([cx, cy, cz], axis=-1)\n",
86 | "\n",
87 | "# 3. get camera coordinate system assuming it points to the center of the sphere\n",
88 | "def look_at(eye, at=np.array([0, 0, 0]), up=np.array([0, 0, 1]), eps=1e-5,\n",
89 | " to_pytorch=True):\n",
90 | " at = at.astype(float).reshape(1, 3)\n",
91 | " up = up.astype(float).reshape(1, 3)\n",
92 | " eye = eye.reshape(-1, 3)\n",
93 | " up = up.repeat(eye.shape[0] // up.shape[0], axis=0)\n",
94 | " eps = np.array([eps]).reshape(1, 1).repeat(up.shape[0], axis=0)\n",
95 | "\n",
96 | " z_axis = eye - at\n",
97 | " z_axis /= np.max(np.stack([np.linalg.norm(z_axis,\n",
98 | " axis=1, keepdims=True), eps]))\n",
99 | "\n",
100 | " x_axis = np.cross(up, z_axis)\n",
101 | " x_axis /= np.max(np.stack([np.linalg.norm(x_axis,\n",
102 | " axis=1, keepdims=True), eps]))\n",
103 | "\n",
104 | " y_axis = np.cross(z_axis, x_axis)\n",
105 | " y_axis /= np.max(np.stack([np.linalg.norm(y_axis,\n",
106 | " axis=1, keepdims=True), eps]))\n",
107 | "\n",
108 | " r_mat = np.concatenate(\n",
109 | " (x_axis.reshape(-1, 3, 1), y_axis.reshape(-1, 3, 1), z_axis.reshape(\n",
110 | " -1, 3, 1)), axis=2)\n",
111 | "\n",
112 | " if to_pytorch:\n",
113 | " r_mat = torch.tensor(r_mat).float()\n",
114 | "\n",
115 | " return r_mat\n",
116 | "\n",
117 | "# 5. arange 2d array of pixel coordinate and give depth of 1\n",
118 | "def arange_pixels(resolution=(128, 128), batch_size=1, image_range=(-1., 1.),\n",
119 | " subsample_to=None, invert_y_axis=False):\n",
120 | " ''' Arranges pixels for given resolution in range image_range.\n",
121 | " The function returns the unscaled pixel locations as integers and the\n",
122 | " scaled float values.\n",
123 | " Args:\n",
124 | " resolution (tuple): image resolution\n",
125 | " batch_size (int): batch size\n",
126 | " image_range (tuple): range of output points (default [-1, 1])\n",
127 | " subsample_to (int): if integer and > 0, the points are randomly\n",
128 | " subsampled to this value\n",
129 | " '''\n",
130 | " h, w = resolution\n",
131 | " n_points = resolution[0] * resolution[1]\n",
132 | "\n",
133 | " # Arrange pixel location in scale resolution\n",
134 | " pixel_locations = torch.meshgrid(torch.arange(0, w), torch.arange(0, h))\n",
135 | " pixel_locations = torch.stack(\n",
136 | " [pixel_locations[0], pixel_locations[1]],\n",
137 | " dim=-1).long().view(1, -1, 2).repeat(batch_size, 1, 1)\n",
138 | " pixel_scaled = pixel_locations.clone().float()\n",
139 | "\n",
140 | " # Shift and scale points to match image_range\n",
141 | " scale = (image_range[1] - image_range[0])\n",
142 | " loc = scale / 2\n",
143 | " pixel_scaled[:, :, 0] = scale * pixel_scaled[:, :, 0] / (w - 1) - loc\n",
144 | " pixel_scaled[:, :, 1] = scale * pixel_scaled[:, :, 1] / (h - 1) - loc\n",
145 | "\n",
146 | " # Subsample points if subsample_to is not None and > 0\n",
147 | " if (subsample_to is not None and subsample_to > 0 and\n",
148 | " subsample_to < n_points):\n",
149 | " idx = np.random.choice(pixel_scaled.shape[1], size=(subsample_to,),\n",
150 | " replace=False)\n",
151 | " pixel_scaled = pixel_scaled[:, idx]\n",
152 | " pixel_locations = pixel_locations[:, idx]\n",
153 | "\n",
154 | " if invert_y_axis:\n",
155 | " assert(image_range == (-1, 1))\n",
156 | " pixel_scaled[..., -1] *= -1.\n",
157 | " pixel_locations[..., -1] = (h - 1) - pixel_locations[..., -1]\n",
158 | "\n",
159 | " return pixel_locations, pixel_scaled\n",
160 | "\n",
161 | "# 6. mat_mul with intrinsic and then extrinsic gives you p_world (pixels in world) \n",
162 | "def image_points_to_world(image_points, camera_mat, world_mat, scale_mat=None,\n",
163 | " invert=False, negative_depth=True):\n",
164 | " ''' Transforms points on image plane to world coordinates.\n",
165 | " In contrast to transform_to_world, no depth value is needed as points on\n",
166 | " the image plane have a fixed depth of 1.\n",
167 | " Args:\n",
168 | " image_points (tensor): image points tensor of size B x N x 2\n",
169 | " camera_mat (tensor): camera matrix\n",
170 | " world_mat (tensor): world matrix\n",
171 | " scale_mat (tensor): scale matrix\n",
172 | " invert (bool): whether to invert matrices (default: False)\n",
173 | " '''\n",
174 | " batch_size, n_pts, dim = image_points.shape\n",
175 | " assert(dim == 2)\n",
176 | " d_image = torch.ones(batch_size, n_pts, 1)\n",
177 | " if negative_depth:\n",
178 | " d_image *= -1.\n",
179 | " return transform_to_world(image_points, d_image, camera_mat, world_mat,\n",
180 | " scale_mat, invert=invert)\n",
181 | "\n",
182 | "def transform_to_world(pixels, depth, camera_mat, world_mat, scale_mat=None,\n",
183 | " invert=True, use_absolute_depth=True):\n",
184 | " ''' Transforms pixel positions p with given depth value d to world coordinates.\n",
185 | " Args:\n",
186 | " pixels (tensor): pixel tensor of size B x N x 2\n",
187 | " depth (tensor): depth tensor of size B x N x 1\n",
188 | " camera_mat (tensor): camera matrix\n",
189 | " world_mat (tensor): world matrix\n",
190 | " scale_mat (tensor): scale matrix\n",
191 | " invert (bool): whether to invert matrices (default: true)\n",
192 | " '''\n",
193 | " assert(pixels.shape[-1] == 2)\n",
194 | "\n",
195 | " if scale_mat is None:\n",
196 | " scale_mat = torch.eye(4).unsqueeze(0).repeat(\n",
197 | " camera_mat.shape[0], 1, 1)\n",
198 | "\n",
199 | " # Convert to pytorch\n",
200 | " pixels, is_numpy = to_pytorch(pixels, True)\n",
201 | " depth = to_pytorch(depth)\n",
202 | " camera_mat = to_pytorch(camera_mat)\n",
203 | " world_mat = to_pytorch(world_mat)\n",
204 | " scale_mat = to_pytorch(scale_mat)\n",
205 | "\n",
206 | " # Invert camera matrices\n",
207 | " if invert:\n",
208 | " camera_mat = torch.inverse(camera_mat)\n",
209 | " world_mat = torch.inverse(world_mat)\n",
210 | " scale_mat = torch.inverse(scale_mat)\n",
211 | "\n",
212 | " # Transform pixels to homogen coordinates\n",
213 | " pixels = pixels.permute(0, 2, 1)\n",
214 | " pixels = torch.cat([pixels, torch.ones_like(pixels)], dim=1)\n",
215 | "\n",
216 | " # Project pixels into camera space\n",
217 | " if use_absolute_depth:\n",
218 | " pixels[:, :2] = pixels[:, :2] * depth.permute(0, 2, 1).abs()\n",
219 | " pixels[:, 2:3] = pixels[:, 2:3] * depth.permute(0, 2, 1)\n",
220 | " else:\n",
221 | " pixels[:, :3] = pixels[:, :3] * depth.permute(0, 2, 1)\n",
222 | " \n",
223 | " # Transform pixels to world space\n",
224 | " p_world = scale_mat @ world_mat @ camera_mat @ pixels\n",
225 | "\n",
226 | " # Transform p_world back to 3D coordinates\n",
227 | " p_world = p_world[:, :3].permute(0, 2, 1)\n",
228 | "\n",
229 | " if is_numpy:\n",
230 | " p_world = p_world.numpy()\n",
231 | " return p_world\n",
232 | "\n",
233 | "\n",
234 | "# 7. mat_mul zeros with intrinsic&extrinsic for camera pos (which we alread obtained as loc)\n",
235 | "def origin_to_world(n_points, camera_mat, world_mat, scale_mat=None,\n",
236 | " invert=False):\n",
237 | " ''' Transforms origin (camera location) to world coordinates.\n",
238 | " Args:\n",
239 | " n_points (int): how often the transformed origin is repeated in the\n",
240 | " form (batch_size, n_points, 3)\n",
241 | " camera_mat (tensor): camera matrix\n",
242 | " world_mat (tensor): world matrix\n",
243 | " scale_mat (tensor): scale matrix\n",
244 | " invert (bool): whether to invert the matrices (default: false)\n",
245 | " '''\n",
246 | " \n",
247 | " batch_size = camera_mat.shape[0]\n",
248 | " device = camera_mat.device\n",
249 | " # Create origin in homogen coordinates\n",
250 | " p = torch.zeros(batch_size, 4, n_points).to(device)\n",
251 | " p[:, -1] = 1.\n",
252 | "\n",
253 | " if scale_mat is None:\n",
254 | " scale_mat = torch.eye(4).unsqueeze(\n",
255 | " 0).repeat(batch_size, 1, 1).to(device)\n",
256 | "\n",
257 | " # Invert matrices\n",
258 | " if invert:\n",
259 | " camera_mat = torch.inverse(camera_mat)\n",
260 | " world_mat = torch.inverse(world_mat)\n",
261 | " scale_mat = torch.inverse(scale_mat)\n",
262 | " \n",
263 | " camera_mat = to_pytorch(camera_mat)\n",
264 | " world_mat = to_pytorch(world_mat)\n",
265 | " scale_mat = to_pytorch(scale_mat)\n",
266 | " \n",
267 | " # Apply transformation\n",
268 | " p_world = scale_mat @ world_mat @ camera_mat @ p\n",
269 | "\n",
270 | " # Transform points back to 3D coordinates\n",
271 | " p_world = p_world[:, :3].permute(0, 2, 1)\n",
272 | " return p_world"
273 | ]
274 | },
275 | {
276 | "cell_type": "markdown",
277 | "id": "414ea1ff",
278 | "metadata": {},
279 | "source": [
280 | "# Sampling"
281 | ]
282 | },
283 | {
284 | "cell_type": "code",
285 | "execution_count": 4,
286 | "id": "d62aa105",
287 | "metadata": {},
288 | "outputs": [],
289 | "source": [
290 | "# from Kai-46/nerfplusplus\n",
291 | "\n",
292 | "# 8. intersect sphere for distinguishing fg and bg\n",
293 | "def intersect_sphere(ray_o, ray_d):\n",
294 | " '''\n",
295 | " ray_o, ray_d: [..., 3]\n",
296 | " compute the depth of the intersection point between this ray and unit sphere\n",
297 | " '''\n",
298 | " # note: d1 becomes negative if this mid point is behind camera\n",
299 | " d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)\n",
300 | " p = ray_o + d1.unsqueeze(-1) * ray_d\n",
301 | " # consider the case where the ray does not intersect the sphere\n",
302 | " ray_d_cos = 1. / torch.norm(ray_d, dim=-1)\n",
303 | " p_norm_sq = torch.sum(p * p, dim=-1)\n",
304 | " if (p_norm_sq >= 1.).any():\n",
305 | " raise Exception('Not all your cameras are bounded by the unit sphere; please make sure the cameras are normalized properly!')\n",
306 | " d2 = torch.sqrt(1. - p_norm_sq) * ray_d_cos\n",
307 | "\n",
308 | " return d1 + d2\n",
309 | "\n",
310 | "# 9. inverse sphere sampling for bg\n",
311 | "def depth2pts_outside(ray_o, ray_d, depth):\n",
312 | " '''\n",
313 | " ray_o, ray_d: [..., 3]\n",
314 | " depth: [...]; inverse of distance to sphere origin\n",
315 | " '''\n",
316 | " # note: d1 becomes negative if this mid point is behind camera\n",
317 | " d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)\n",
318 | " p_mid = ray_o + d1.unsqueeze(-1) * ray_d\n",
319 | " p_mid_norm = torch.norm(p_mid, dim=-1)\n",
320 | " ray_d_cos = 1. / torch.norm(ray_d, dim=-1)\n",
321 | " d2 = torch.sqrt(1. - p_mid_norm * p_mid_norm) * ray_d_cos\n",
322 | " \n",
323 | " p_sphere = ray_o + (d1 + d2).unsqueeze(-1) * ray_d\n",
324 | "\n",
325 | " rot_axis = torch.cross(ray_o, p_sphere, dim=-1)\n",
326 | " rot_axis = rot_axis / torch.norm(rot_axis, dim=-1, keepdim=True)\n",
327 | " phi = torch.asin(p_mid_norm)\n",
328 | " theta = torch.asin(p_mid_norm * depth) # depth is inside [0, 1]\n",
329 | " rot_angle = (phi - theta).unsqueeze(-1) # [..., 1]\n",
330 | "\n",
331 | " # now rotate p_sphere\n",
332 | " # Rodrigues formula: https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula\n",
333 | " p_sphere_new = p_sphere * torch.cos(rot_angle) + \\\n",
334 | " torch.cross(rot_axis, p_sphere, dim=-1) * torch.sin(rot_angle) + \\\n",
335 | " rot_axis * torch.sum(rot_axis*p_sphere, dim=-1, keepdim=True) * (1.-torch.cos(rot_angle))\n",
336 | " p_sphere_new = p_sphere_new / torch.norm(p_sphere_new, dim=-1, keepdim=True)\n",
337 | " pts = torch.cat((p_sphere_new.squeeze(), depth.squeeze().unsqueeze(-1)), dim=-1) # (modified) added .squeeze()\n",
338 | "\n",
339 | " # now calculate conventional depth\n",
340 | " depth_real = 1. / (depth + TINY_NUMBER) * torch.cos(theta) * ray_d_cos + d1\n",
341 | " return pts, depth_real\n",
342 | "\n",
343 | "# 10. perturb sample z values (depth) for some randomness (stratified sampling)\n",
344 | "def perturb_samples(z_vals):\n",
345 | " # get intervals between samples\n",
346 | " mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])\n",
347 | " upper = torch.cat([mids, z_vals[..., -1:]], dim=-1)\n",
348 | " lower = torch.cat([z_vals[..., 0:1], mids], dim=-1)\n",
349 | " # uniform samples in those intervals\n",
350 | " t_rand = torch.rand_like(z_vals)\n",
351 | " z_vals = lower + (upper - lower) * t_rand # [N_rays, N_samples]\n",
352 | "\n",
353 | " return z_vals\n",
354 | "\n",
355 | "# 11. return fg (just uniform) and bg (uniform inverse sphere) samples\n",
356 | "def uniformsampling(ray_o, ray_d, min_depth, N_samples, device):\n",
357 | " \n",
358 | " dots_sh = list(ray_d.shape[:-1])\n",
359 | "\n",
360 | " # foreground depth\n",
361 | " fg_far_depth = intersect_sphere(ray_o, ray_d) # [...,]\n",
362 | " fg_near_depth = min_depth # [..., ]\n",
363 | " step = (fg_far_depth - fg_near_depth) / (N_samples - 1)\n",
364 | " fg_depth = torch.stack([fg_near_depth + i * step for i in range(N_samples)], dim=-1) # [..., N_samples]\n",
365 | " fg_depth = perturb_samples(fg_depth) # random perturbation during training\n",
366 | "\n",
367 | " # background depth\n",
368 | " bg_depth = torch.linspace(0., 1., N_samples).view(\n",
369 | " [1, ] * len(dots_sh) + [N_samples,]).expand(dots_sh + [N_samples,]).to(device)\n",
370 | " bg_depth = perturb_samples(bg_depth) # random perturbation during training\n",
371 | "\n",
372 | "\n",
373 | " fg_ray_o = ray_o.unsqueeze(-2).expand(dots_sh + [N_samples, 3])\n",
374 | " fg_ray_d = ray_d.unsqueeze(-2).expand(dots_sh + [N_samples, 3])\n",
375 | "\n",
376 | " bg_ray_o = ray_o.unsqueeze(-2).expand(dots_sh + [N_samples, 3])\n",
377 | " bg_ray_d = ray_d.unsqueeze(-2).expand(dots_sh + [N_samples, 3])\n",
378 | " \n",
379 | " # sampling foreground\n",
380 | " fg_pts = fg_ray_o + fg_depth.unsqueeze(-1) * fg_ray_d\n",
381 | "\n",
382 | " # sampling background\n",
383 | " bg_pts, bg_depth_real = depth2pts_outside(bg_ray_o, bg_ray_d, bg_depth)\n",
384 | "\n",
385 | " return fg_pts, bg_pts, bg_depth_real\n",
386 | "\n",
387 | "# 12. convert bg pts (x', y', z', 1/r) to real bg pts (x, y, z)\n",
388 | "def pts2realpts(bg_pts, bg_depth_real):\n",
389 | " return bg_pts[:, :, :3] * bg_depth_real.squeeze().unsqueeze(-1)"
390 | ]
391 | },
392 | {
393 | "cell_type": "code",
394 | "execution_count": 5,
395 | "id": "856d052e",
396 | "metadata": {},
397 | "outputs": [],
398 | "source": [
399 | "def giraffe(u = 1,\n",
400 | " v = 0.5,\n",
401 | " r=2.713,\n",
402 | " depth_range=[0.5, 6.],\n",
403 | " n_ray_samples=16,\n",
404 | " resolution_vol = 4,\n",
405 | " batch_size = 1\n",
406 | " ):\n",
407 | "\n",
408 | " range_radius=[r, r]\n",
409 | " \n",
410 | " res = resolution_vol\n",
411 | " n_points = res * res\n",
412 | "\n",
413 | " # 1. get camera intrinsic \n",
414 | " camera_mat = get_camera_mat()\n",
415 | "\n",
416 | " # 2. get camera position with camera pose (theta & phi)\n",
417 | " loc = to_sphere(u, v)\n",
418 | " loc = torch.tensor(loc).float()\n",
419 | " radius = range_radius[0] + \\\n",
420 | " torch.rand(batch_size) * (range_radius[1] - range_radius[0])\n",
421 | " loc = loc * radius.unsqueeze(-1)\n",
422 | "\n",
423 | " # 3. get camera coordinate system assuming it points to the center of the sphere\n",
424 | " R = look_at(loc)\n",
425 | "\n",
426 | "    # 4. The camera coordinate system is the rotation matrix; combined with the camera loc, it forms the camera extrinsic\n",
427 | " RT = np.eye(4).reshape(1, 4, 4)\n",
428 | " RT[:, :3, :3] = R\n",
429 | " RT[:, :3, -1] = loc\n",
430 | " world_mat = RT\n",
431 | "\n",
432 | "    # 5. arrange a 2D array of pixel coordinates and give them a depth of 1\n",
433 | " pixels = arange_pixels((res, res), 1, invert_y_axis=False)[1]\n",
434 | "    pixels[..., -1] *= -1. # flip the y-coordinate (carried over from the giraffe code; reason unclear)\n",
435 | "\n",
436 | " # 6. mat_mul with intrinsic and then extrinsic gives you p_world (pixels in world) \n",
437 | " pixels_world = image_points_to_world(pixels, camera_mat, world_mat)\n",
438 | " \n",
439 | "    # 7. mat_mul zeros with intrinsic&extrinsic for camera pos (which we already obtained as loc)\n",
440 | " camera_world = origin_to_world(n_points, camera_mat, world_mat)\n",
441 | "\n",
442 | " # 8. ray = pixel - camera origin (in world)\n",
443 | " ray_vector = pixels_world - camera_world\n",
444 | "\n",
445 | " # 9. depths from closest to furthest (0.5 ~ 6.0)\n",
446 | " di = depth_range[0] + \\\n",
447 | " torch.linspace(0., 1., steps=n_ray_samples).reshape(1, 1, -1) * (\n",
448 | " depth_range[1] - depth_range[0])\n",
449 | " di = di.repeat(batch_size, n_points, 1)\n",
450 | "\n",
451 | " # 10. calculate points\n",
452 | " p_i = camera_world.unsqueeze(-2).contiguous() + \\\n",
453 | " di.unsqueeze(-1).contiguous() * ray_vector.unsqueeze(-2).contiguous()\n",
454 | " \n",
455 | " return pixels_world, camera_world, world_mat, p_i"
456 | ]
457 | },
458 | {
459 | "cell_type": "code",
460 | "execution_count": 6,
461 | "id": "1a0c73e1",
462 | "metadata": {},
463 | "outputs": [],
464 | "source": [
465 | "def nerfpp(u = 1,\n",
466 | " v = 0.5,\n",
467 | " fov = 49.13,\n",
468 | " depth_range=[0.5, 6.],\n",
469 | " n_ray_samples=16,\n",
470 | " resolution_vol = 4,\n",
471 | " batch_size = 1,\n",
472 | " device = torch.device('cpu')\n",
473 | " ):\n",
474 | " \n",
475 | " r = 1\n",
476 | " range_radius=[r, r]\n",
477 | " \n",
478 | " res = resolution_vol\n",
479 | " n_points = res * res\n",
480 | "\n",
481 | " # 1. get camera intrinsic - fiddle around with fov this time\n",
482 | " camera_mat = get_camera_mat(fov=fov)\n",
483 | "\n",
484 | " # 2. get camera position with camera pose (theta & phi)\n",
485 | " loc = to_sphere(u, v)\n",
486 | " loc = torch.tensor(loc).float()\n",
487 | " \n",
488 | " radius = range_radius[0] + \\\n",
489 | " torch.rand(batch_size) * (range_radius[1] - range_radius[0])\n",
490 | " \n",
491 | " loc = loc * radius.unsqueeze(-1)\n",
492 | "\n",
493 | " # 3. get camera coordinate system assuming it points to the center of the sphere\n",
494 | " R = look_at(loc)\n",
495 | "\n",
496 | "    # 4. The camera coordinate system is the rotation matrix; combined with the camera loc, it forms the camera extrinsic\n",
497 | " RT = np.eye(4).reshape(1, 4, 4)\n",
498 | " RT[:, :3, :3] = R\n",
499 | " RT[:, :3, -1] = loc\n",
500 | " world_mat = RT\n",
501 | "\n",
502 | "    # 5. arrange a 2D array of pixel coordinates and give them a depth of 1\n",
503 | " pixels = arange_pixels((res, res), 1, invert_y_axis=False)[1]\n",
504 | "    pixels[..., -1] *= -1. # flip the y-coordinate (carried over from the giraffe code; reason unclear)\n",
505 | "\n",
506 | " # 6. mat_mul with intrinsic and then extrinsic gives you p_world (pixels in world) \n",
507 | " pixels_world = image_points_to_world(pixels, camera_mat, world_mat)\n",
508 | "\n",
509 | "    # 7. mat_mul zeros with intrinsic&extrinsic for camera pos (which we already obtained as loc)\n",
510 | " camera_world = origin_to_world(n_points, camera_mat, world_mat)\n",
511 | "\n",
512 | " # 8. ray = pixel - camera origin (in world)\n",
513 | " ray_vector = pixels_world - camera_world\n",
514 | "\n",
515 | " # 9. sample fg and bg points according to nerfpp (uniform and inverse sphere)\n",
516 | " fg_pts, bg_pts, bg_depth_real = uniformsampling(ray_o=camera_world, ray_d=ray_vector, min_depth=depth_range[0], N_samples=n_ray_samples, device=device)\n",
517 | " \n",
518 | " #10. convert bg pts (x', y', z' 1/r) to real bg pts (x, y, z)\n",
519 | " bg_pts_real = pts2realpts(bg_pts, bg_depth_real)\n",
520 | " \n",
521 | " return pixels_world, camera_world, world_mat, fg_pts, bg_pts_real"
522 | ]
523 | },
524 | {
525 | "cell_type": "markdown",
526 | "id": "2fd19406",
527 | "metadata": {},
528 | "source": [
529 | "# Visualization"
530 | ]
531 | },
532 | {
533 | "cell_type": "code",
534 | "execution_count": 7,
535 | "id": "b9dd6b14",
536 | "metadata": {},
537 | "outputs": [],
538 | "source": [
539 | "# draw sphere with radius r \n",
540 | "# also draw contours and vertical lines\n",
541 | "def draw_sphere(r, sphere_colorscale, sphere_opacity):\n",
542 | " # sphere\n",
543 | " u = np.linspace(0, 2 * np.pi, 100)\n",
544 | " v = np.linspace(0, np.pi, 100)\n",
545 | " x = r * np.outer(np.cos(u), np.sin(v))\n",
546 | " y = r * np.outer(np.sin(u), np.sin(v))\n",
547 | " z = r * np.outer(np.ones(np.size(u)), np.cos(v))\n",
548 | " \n",
549 | " # vertical lines on sphere\n",
550 | " u2 = np.linspace(0, 2 * np.pi, 20)\n",
551 | " x2 = r * np.outer(np.cos(u2), np.sin(v))\n",
552 | " y2 = r * np.outer(np.sin(u2), np.sin(v))\n",
553 | " z2 = r * np.outer(np.ones(np.size(u2)), np.cos(v))\n",
554 | " \n",
555 | " # create sphere and draw sphere with contours\n",
556 | " fig = go.Figure(data=[go.Surface(x=x, y=y, z=z, \n",
557 | " colorscale=sphere_colorscale, opacity=sphere_opacity,\n",
558 | " contours = {\n",
559 | " 'z' : {'show' : True, 'start' : -r,\n",
560 | " 'end' : r, 'size' : r/10,\n",
561 | " 'color' : 'white',\n",
562 | " 'width' : 1}\n",
563 | " }\n",
564 | " , showscale=False)])\n",
565 | " \n",
566 | " # vertical lines on sphere\n",
567 | " for i in range(len(u2)):\n",
568 | " fig.add_scatter3d(x=x2[i], y=y2[i], z=z2[i], \n",
569 | " line=dict(\n",
570 | " color='white',\n",
571 | " width=1\n",
572 | " ),\n",
573 | " mode='lines',\n",
574 | " showlegend=False)\n",
575 | " \n",
576 | " return fig\n",
577 | "\n",
578 | "# draw xyplane\n",
579 | "def draw_XYplane(fig, xy_plane_colorscale, xy_plane_opacity, x_range = [-2, 2], y_range = [-2, 2]):\n",
580 | " x3 = np.linspace(x_range[0], x_range[1], 100)\n",
581 | " y3 = np.linspace(y_range[0], y_range[1], 100)\n",
582 | " z3 = np.zeros(shape=(100,100))\n",
583 | " \n",
584 | " fig.add_surface(x=x3, y=y3, z=z3,\n",
585 | " colorscale =xy_plane_colorscale, opacity=xy_plane_opacity,\n",
586 | " showscale=False\n",
587 | " )\n",
588 | " \n",
589 | " return fig\n",
590 | " \n",
591 | "\n",
592 | "def draw_XYZworld(fig, world_axis_size):\n",
593 | " # x, y, z positive direction (world)\n",
594 | " X_axis = [0, world_axis_size]\n",
595 | " X_text = [None, \"X\"]\n",
596 | " X0 = [0, 0]\n",
597 | " Y_axis = [0, world_axis_size]\n",
598 | " Y_text = [None, \"Y\"]\n",
599 | " Y0 = [0, 0]\n",
600 | " Z_axis = [0, world_axis_size]\n",
601 | " Z_text = [None, \"Z\"]\n",
602 | " Z0 = [0, 0]\n",
603 | " \n",
604 | " fig.add_scatter3d(x=X_axis, y=Y0, z=Z0, \n",
605 | " line=dict(\n",
606 | " color='red',\n",
607 | " width=10\n",
608 | " ),\n",
609 | " mode='lines+text',\n",
610 | " text=X_text,\n",
611 | " textposition='top center',\n",
612 | " textfont=dict(\n",
613 | " color=\"red\",\n",
614 | " size=18\n",
615 | " ),\n",
616 | " showlegend=False)\n",
617 | "\n",
618 | " fig.add_scatter3d(x=X0, y=Y_axis, z=Z0, \n",
619 | " line=dict(\n",
620 | " color='green',\n",
621 | " width=10\n",
622 | " ),\n",
623 | " mode='lines+text',\n",
624 | " text=Y_text,\n",
625 | " textposition='top center',\n",
626 | " textfont=dict(\n",
627 | " color=\"green\",\n",
628 | " size=18\n",
629 | " ),\n",
630 | " showlegend=False)\n",
631 | "\n",
632 | " fig.add_scatter3d(x=X0, y=Y0, z=Z_axis, \n",
633 | " line=dict(\n",
634 | " color='blue',\n",
635 | " width=10\n",
636 | " ),\n",
637 | " mode='lines+text',\n",
638 | " text=Z_text,\n",
639 | " textposition='top center',\n",
640 | " textfont=dict(\n",
641 | " color=\"blue\",\n",
642 | " size=18\n",
643 | " ),\n",
644 | " showlegend=False)\n",
645 | " \n",
646 | " return fig\n",
647 | "\n",
648 | "# draw cam and cam coordinate system\n",
649 | "def draw_cam_init(fig, world_mat, camera_axis_size, camera_color):\n",
650 | " # camera at init\n",
651 | "\n",
652 | " Xc = [world_mat[0, : ,3][0]]\n",
653 | " Yc = [world_mat[0, : ,3][1]]\n",
654 | " Zc = [world_mat[0, : ,3][2]]\n",
655 | " text_c = [\"Camera\"]\n",
656 | "\n",
657 | " # camera axis\n",
658 | " Xc_Xaxis = Xc + [world_mat[0, : ,0][0]*camera_axis_size+Xc[0]]\n",
659 | " Yc_Xaxis = Yc + [world_mat[0, : ,0][1]*camera_axis_size+Yc[0]]\n",
660 | " Zc_Xaxis = Zc + [world_mat[0, : ,0][2]*camera_axis_size+Zc[0]]\n",
661 | " text_Xaxis = [None, \"Xc\"]\n",
662 | " \n",
663 | " # -z in world perspective\n",
664 | " Xc_Yaxis = Xc + [world_mat[0, : ,1][0]*camera_axis_size+Xc[0]]\n",
665 | " Yc_Yaxis = Yc + [world_mat[0, : ,1][1]*camera_axis_size+Yc[0]]\n",
666 | " Zc_Yaxis = Zc + [world_mat[0, : ,1][2]*camera_axis_size+Zc[0]]\n",
667 | " text_Yaxis = [None, \"Yc\"]\n",
668 | "\n",
669 | " # y in world perspective\n",
670 | " Xc_Zaxis = Xc + [world_mat[0, : ,2][0]*camera_axis_size+Xc[0]]\n",
671 | " Yc_Zaxis = Yc + [world_mat[0, : ,2][1]*camera_axis_size+Yc[0]]\n",
672 | " Zc_Zaxis = Zc + [world_mat[0, : ,2][2]*camera_axis_size+Zc[0]]\n",
673 | " text_Zaxis = [None, \"Zc\"]\n",
674 | " \n",
675 | " # cam pos\n",
676 | " fig.add_scatter3d(x=Xc, y=Yc, z=Zc, \n",
677 | " mode='markers',\n",
678 | " marker=dict(\n",
679 | " color=camera_color,\n",
680 | " size=4,\n",
681 | " sizemode='diameter'\n",
682 | " ),\n",
683 | " showlegend=False)\n",
684 | "\n",
685 | " # camera axis\n",
686 | " fig.add_scatter3d(x=Xc_Xaxis, y=Yc_Xaxis, z=Zc_Xaxis, \n",
687 | " line=dict(\n",
688 | " color='red',\n",
689 | " width=10\n",
690 | " ),\n",
691 | " mode='lines+text',\n",
692 | " text=text_Xaxis,\n",
693 | " textposition='top center',\n",
694 | " textfont=dict(\n",
695 | " color=\"red\",\n",
696 | " size=18\n",
697 | " ),\n",
698 | " showlegend=False)\n",
699 | "\n",
700 | " fig.add_scatter3d(x=Xc_Yaxis, y=Yc_Yaxis, z=Zc_Yaxis, \n",
701 | " line=dict(\n",
702 | " color='green',\n",
703 | " width=10\n",
704 | " ),\n",
705 | " mode='lines+text',\n",
706 | " text=text_Yaxis,\n",
707 | " textposition='top center',\n",
708 | " textfont=dict(\n",
709 | " color=\"green\",\n",
710 | " size=18\n",
711 | " ),\n",
712 | " showlegend=False)\n",
713 | "\n",
714 | " fig.add_scatter3d(x=Xc_Zaxis, y=Yc_Zaxis, z=Zc_Zaxis, \n",
715 | " line=dict(\n",
716 | " color='blue',\n",
717 | " width=10\n",
718 | " ),\n",
719 | " mode='lines+text',\n",
720 | " text=text_Zaxis,\n",
721 | " textposition='top center',\n",
722 | " textfont=dict(\n",
723 | " color=\"blue\",\n",
724 | " size=18\n",
725 | " ),\n",
726 | " showlegend=False)\n",
727 | " \n",
728 | " return fig\n",
729 | "\n",
730 | "# draw all rays\n",
731 | "def draw_all_rays(fig, p_i, ray_color):\n",
732 | " for i in range(p_i.shape[1]):\n",
733 | " Xray = p_i[0, i, :, 0]\n",
734 | " Yray = p_i[0, i, :, 1]\n",
735 | " Zray = p_i[0, i, :, 2]\n",
736 | " \n",
737 | " fig.add_scatter3d(x=Xray, y=Yray, z=Zray, \n",
738 | " line=dict(\n",
739 | " color=ray_color,\n",
740 | " width=5\n",
741 | " ),\n",
742 | " mode='lines',\n",
743 | " showlegend=False)\n",
744 | " \n",
745 | " return fig\n",
746 | "\n",
747 | "# draw all rays with sample markers\n",
748 | "def draw_all_rays_with_marker(fig, p_i, marker_size, ray_color):\n",
749 | " \n",
750 | "    # convert colorscale string to px.colors.sequential\n",
751 | " # default color is set to Viridis in case of mismatch\n",
752 | " c = px.colors.sequential.Viridis\n",
753 | "\n",
754 | " for c_name in [ray_color, ray_color.capitalize()]:\n",
755 | " try:\n",
756 | " c = getattr(px.colors.sequential, c_name)\n",
757 | " except:\n",
758 | " continue\n",
759 | " \n",
760 | " for i in range(p_i.shape[1]):\n",
761 | " Xray = p_i[0, i, :, 0]\n",
762 | " Yray = p_i[0, i, :, 1]\n",
763 | " Zray = p_i[0, i, :, 2]\n",
764 | " \n",
765 | " fig.add_scatter3d(x=Xray, y=Yray, z=Zray, \n",
766 | " \n",
767 | " marker=dict(\n",
768 | "# color=np.arange(len(Xray)),\n",
769 | " color=c,\n",
770 | "# colorscale='Viridis',\n",
771 | " size=marker_size\n",
772 | " ),\n",
773 | " \n",
774 | " line=dict(\n",
775 | "# color=np.arange(len(Xray)),\n",
776 | " color=c,\n",
777 | "# colorscale='Viridis',\n",
778 | " width=3\n",
779 | " ),\n",
780 | " mode=\"lines+markers\",\n",
781 | " showlegend=False)\n",
782 | " \n",
783 | " return fig\n",
784 | "\n",
785 | "# draw near & far frustum with rays connecting the corners (changed for nerfpp)\n",
786 | "def draw_ray_frus(fig, p_i, frustrum_color, frustrum_opacity, at=[0, -1]):\n",
787 | " \n",
788 | " for i in at:\n",
789 | "# Xfrus = p_i[0, :, i, 0][[0,1,2,3,7,11,15,14,13,12,8,4,0]]\n",
790 | "# Yfrus = p_i[0, :, i, 1][[0,1,2,3,7,11,15,14,13,12,8,4,0]]\n",
791 | "# Zfrus = p_i[0, :, i, 2][[0,1,2,3,7,11,15,14,13,12,8,4,0]]\n",
792 | "\n",
793 | " Xfrus = p_i[0, :, i, 0]\n",
794 | " Yfrus = p_i[0, :, i, 1]\n",
795 | " Zfrus = p_i[0, :, i, 2]\n",
796 | " \n",
797 | " fig.add_scatter3d(x=Xfrus, y=Yfrus, z=Zfrus, \n",
798 | " line=dict(\n",
799 | " color=frustrum_color,\n",
800 | " width=5\n",
801 | " ),\n",
802 | " mode='lines',\n",
803 | " surfaceaxis=0,\n",
804 | " surfacecolor=frustrum_color,\n",
805 | " opacity=frustrum_opacity,\n",
806 | " showlegend=False)\n",
807 | " \n",
808 | " return fig\n",
809 | "\n",
810 | "# draw foreground sample points, rays, and frustum\n",
811 | "def draw_foreground(fig, fg_pts, fg_color, marker_size, at=[0, -1]):\n",
812 | " fig = draw_all_rays_with_marker(fig, fg_pts, marker_size, fg_color)\n",
813 | " \n",
814 | " return fig\n",
815 | "\n",
816 | "# draw background sample points, rays, and frustum\n",
817 | "def draw_background(fig, bg_pts, bg_color, marker_size, at=[0, -1]):\n",
818 | " fig = draw_all_rays_with_marker(fig, bg_pts, marker_size, bg_color)\n",
819 | " \n",
820 | " return fig"
821 | ]
822 | },
823 | {
824 | "cell_type": "code",
825 | "execution_count": 8,
826 | "id": "f12fecdd",
827 | "metadata": {},
828 | "outputs": [
829 | {
830 | "name": "stderr",
831 | "output_type": "stream",
832 | "text": [
833 | "C:\\Users\\laphi\\AppData\\Local\\Temp/ipykernel_24504/3891967274.py:2: UserWarning: \n",
834 | "The dash_core_components package is deprecated. Please replace\n",
835 | "`import dash_core_components as dcc` with `from dash import dcc`\n",
836 | " import dash_core_components as dcc\n",
837 | "C:\\Users\\laphi\\AppData\\Local\\Temp/ipykernel_24504/3891967274.py:3: UserWarning: \n",
838 | "The dash_html_components package is deprecated. Please replace\n",
839 | "`import dash_html_components as html` with `from dash import html`\n",
840 | " import dash_html_components as html\n"
841 | ]
842 | }
843 | ],
844 | "source": [
845 | "from jupyter_dash import JupyterDash\n",
846 | "import dash_core_components as dcc\n",
847 | "import dash_html_components as html\n",
848 | "from dash.dependencies import Input, Output\n",
849 | "\n",
850 | "# for colors\n",
851 | "import matplotlib.colors as mcolors"
852 | ]
853 | },
854 | {
855 | "cell_type": "code",
856 | "execution_count": 10,
857 | "id": "e5e28428",
858 | "metadata": {
859 | "scrolled": false
860 | },
861 | "outputs": [
862 | {
863 | "name": "stderr",
864 | "output_type": "stream",
865 | "text": [
866 | "c:\\users\\laphi\\appdata\\local\\programs\\python\\python38\\lib\\site-packages\\jupyter_dash\\jupyter_app.py:139: UserWarning:\n",
867 | "\n",
868 | "The 'environ['werkzeug.server.shutdown']' function is deprecated and will be removed in Werkzeug 2.1.\n",
869 | "\n"
870 | ]
871 | },
872 | {
873 | "data": {
874 | "text/html": [
875 | "\n",
876 | " \n",
883 | " "
884 | ],
885 | "text/plain": [
886 | ""
887 | ]
888 | },
889 | "metadata": {},
890 | "output_type": "display_data"
891 | }
892 | ],
893 | "source": [
894 | "app = JupyterDash(__name__)\n",
895 | "\n",
896 | "app.layout = html.Div([\n",
897 | " html.H1(\"Nerfplusplus ray sampling visualization\"),\n",
898 | " dcc.Graph(id='graph'),\n",
899 | " \n",
900 | " html.Div([\n",
901 | " html.Div([\n",
902 | " \n",
903 | " # changes to setting and ray casted \n",
904 | " html.Label([ \"u (theta)\",\n",
905 | " dcc.Slider(\n",
906 | " id='u-slider', \n",
907 | " min=0, max=1,\n",
908 | " value=0.00,\n",
909 | " marks={str(val) : str(val) for val in [0.00, 0.25, 0.50, 0.75]},\n",
910 | " step=0.01, tooltip = { 'always_visible': True }\n",
911 | " ), ]),\n",
912 | " html.Label([ \"v (phi)\",\n",
913 | " dcc.Slider(\n",
914 | " id='v-slider', \n",
915 | " min=0, max=1,\n",
916 | " value=0.25,\n",
917 | " marks={str(val) : str(val) for val in [0.00, 0.25, 0.50, 0.75]},\n",
918 | " step=0.01, tooltip = { 'always_visible': True }\n",
919 | " ), ]),\n",
920 | "        html.Label([ \"fov (field-of-view)\",\n",
921 | " dcc.Slider(\n",
922 | " id='fov-slider', \n",
923 | " min=0, max=100,\n",
924 | " value=50,\n",
925 | " marks={str(val) : str(val) for val in [0, 20, 40, 60, 80, 100]},\n",
926 | " step=5, tooltip = { 'always_visible': True }\n",
927 | " ), ]),\n",
928 | " html.Label([ \"foreground near depth\",\n",
929 | " dcc.Slider(\n",
930 | " id='foreground-near-depth-slider', \n",
931 | " min=0, max=2,\n",
932 | " value=0.5,\n",
933 | " marks={f\"{val:.1f}\" : f\"{val:.1f}\" for val in [0.1 * i for i in range(21)]},\n",
934 | " step=0.1, tooltip = { 'always_visible': True }\n",
935 | " ), ])\n",
936 | " ], style = {'width' : '48%', 'display' : 'inline-block'}),\n",
937 | " \n",
938 | " html.Div([\n",
939 | " # changes to visual appearance\n",
940 | " \n",
941 | " # axis scale\n",
942 | " html.Div([\n",
943 | " html.Label([ \"world axis size\",\n",
944 | " html.Div([\n",
945 | " dcc.Input(id='world-axis-size-input',\n",
946 | " value=1.5,\n",
947 | " type='number', style={'width': '50%'}\n",
948 | " )\n",
949 | " ]),\n",
950 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n",
951 | " html.Label([ \"camera axis size\",\n",
952 | " html.Div([\n",
953 | " dcc.Input(id='camera-axis-size-input',\n",
954 | " value=0.3,\n",
955 | " type='number', style={'width': '50%'}\n",
956 | " )\n",
957 | " ]),\n",
958 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n",
959 | " html.Label([ \"sample marker size\",\n",
960 | " html.Div([\n",
961 | " dcc.Input(id='sample-marker-size-input',\n",
962 | " value=2,\n",
963 | " type='number', style={'width': '50%'}\n",
964 | " )\n",
965 | " ]),\n",
966 | " ], style = {'width' : '32%', 'float' : 'left', 'display' : 'inline-block'}),\n",
967 | " ]),\n",
968 | " \n",
969 | " # opacity \n",
970 | " html.Div([\n",
971 | " html.Label([ \"sphere opacity\",\n",
972 | " html.Div([\n",
973 | " dcc.Input(id='sphere-opacity-input',\n",
974 | " value=0.2,\n",
975 | " type='number', style={'width': '50%'}\n",
976 | " )\n",
977 | " ])\n",
978 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n",
979 | " \n",
980 | " html.Label([ \"xy-plane opacity\", \n",
981 | " html.Div([\n",
982 | " dcc.Input(id='xy-plane-opacity-input',\n",
983 | " value=0.8,\n",
984 | " type='number', style={'width': '50%'}\n",
985 | " )\n",
986 | " ])\n",
987 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n",
988 | " \n",
989 | " ]),\n",
990 | " \n",
991 | " # color\n",
992 | " html.Div([\n",
993 | " html.Label([ \"camera color\",\n",
994 | " html.Div([\n",
995 | " dcc.Dropdown(id='camera-color-input',\n",
996 | " clearable=False,\n",
997 | " value='yellow',\n",
998 | " options=[\n",
999 | " {'label': c, 'value': c}\n",
1000 | " for (c, _) in mcolors.CSS4_COLORS.items()\n",
1001 | " ], style={'width': '80%'}\n",
1002 | " )\n",
1003 | " ])\n",
1004 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n",
1005 | " \n",
1006 | " html.Label([ \"foreground color\",\n",
1007 | " html.Div([\n",
1008 | " dcc.Dropdown(id='fg-color-input',\n",
1009 | " clearable=False,\n",
1010 | " value='plotly3',\n",
1011 | " options=[\n",
1012 | " {'label': c, 'value': c}\n",
1013 | " for c in px.colors.named_colorscales()\n",
1014 | " ], style={'width': '80%'}\n",
1015 | " )\n",
1016 | " ])\n",
1017 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n",
1018 | " \n",
1019 | " html.Label([ \"background color\",\n",
1020 | " html.Div([\n",
1021 | " dcc.Dropdown(id='bg-color-input',\n",
1022 | " clearable=False,\n",
1023 | " value='plotly3',\n",
1024 | " options=[\n",
1025 | " {'label': c, 'value': c}\n",
1026 | " for c in px.colors.named_colorscales()\n",
1027 | " ], style={'width': '80%'}\n",
1028 | " )\n",
1029 | " ])\n",
1030 | " ], style = {'width' : '32%', 'float' : 'left', 'display' : 'inline-block'}),\n",
1031 | " ]),\n",
1032 | " \n",
1033 | " # colorscale\n",
1034 | " html.Div([\n",
1035 | " html.Label([ \"sphere colorscale\",\n",
1036 | " html.Div([\n",
1037 | " dcc.Dropdown(id='sphere-colorscale-input',\n",
1038 | " clearable=False,\n",
1039 | " value='greys',\n",
1040 | " options=[\n",
1041 | " {'label': c, 'value': c}\n",
1042 | " for c in px.colors.named_colorscales()\n",
1043 | " ], style={'width': '80%'}\n",
1044 | " )\n",
1045 | " ])\n",
1046 | " ], style = {'width' : '32%', 'float' : 'left', 'display' : 'inline-block'}),\n",
1047 | " \n",
1048 | " html.Label([ \"xy-plane colorscale\",\n",
1049 | " html.Div([\n",
1050 | " dcc.Dropdown(id='xy-plane-colorscale-input',\n",
1051 | " clearable=False,\n",
1052 | " value='greys',\n",
1053 | " options=[\n",
1054 | " {'label': c, 'value': c}\n",
1055 | " for c in px.colors.named_colorscales()\n",
1056 | " ], style={'width': '80%'}\n",
1057 | " )\n",
1058 | " ])\n",
1059 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n",
1060 | " \n",
1061 | " html.Label([ \" \",\n",
1062 | " html.Div([\n",
1063 | " dcc.Checklist(id='show-background-checklist',\n",
1064 | " \n",
1065 | " options=[\n",
1066 | " {'label': 'show background', 'value': 'show_background'}\n",
1067 | " ], style={'width': '80%'},\n",
1068 | " value=[],\n",
1069 | " )\n",
1070 | " ])\n",
1071 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n",
1072 | " \n",
1073 | " ]),\n",
1074 | " \n",
1075 | " \n",
1076 | " \n",
1077 | " ], style = {'width' : '48%', 'float' : 'right', 'display' : 'inline-block'}),\n",
1078 | " \n",
1079 | " ]),\n",
1080 | " \n",
1081 | "])\n",
1082 | "\n",
1083 | "@app.callback(\n",
1084 | " Output('graph', 'figure'),\n",
1085 | " Input(\"u-slider\", \"value\"),\n",
1086 | " Input(\"v-slider\", \"value\"),\n",
1087 | " \n",
1088 | " Input(\"fov-slider\", \"value\"),\n",
1089 | " Input(\"foreground-near-depth-slider\", \"value\"),\n",
1090 | " \n",
1091 | " Input(\"world-axis-size-input\", \"value\"),\n",
1092 | " Input(\"camera-axis-size-input\", \"value\"),\n",
1093 | " Input(\"sample-marker-size-input\", \"value\"),\n",
1094 | " \n",
1095 | " Input(\"camera-color-input\", \"value\"),\n",
1096 | " Input(\"fg-color-input\", \"value\"),\n",
1097 | " Input(\"bg-color-input\", \"value\"),\n",
1098 | " \n",
1099 | " Input('sphere-colorscale-input', \"value\"),\n",
1100 | " Input('xy-plane-colorscale-input', \"value\"),\n",
1101 | " Input('show-background-checklist', \"value\"),\n",
1102 | " \n",
1103 | " Input(\"sphere-opacity-input\", \"value\"),\n",
1104 | " Input(\"xy-plane-opacity-input\", \"value\"),\n",
1105 | ")\n",
1106 | "\n",
1107 | "def update_figure(u, v, \n",
1108 | " fov, foreground_near_depth,\n",
1109 | " world_axis_size, camera_axis_size, sample_marker_size,\n",
1110 | " camera_color, fg_color, bg_color,\n",
1111 | " sphere_colorscale, xy_plane_colorscale, show_background,\n",
1112 | " sphere_opacity, xy_plane_opacity \n",
1113 | " ):\n",
1114 | " \n",
1115 | " depth_range = [foreground_near_depth, 2]\n",
1116 | " \n",
1117 | " # sphere\n",
1118 | " fig = draw_sphere(r=1, sphere_colorscale=sphere_colorscale, sphere_opacity=sphere_opacity)\n",
1119 | "\n",
1120 | " # change figure size\n",
1121 | "# fig.update_layout(autosize=False, width = 500, height=500)\n",
1122 | "\n",
1123 | "    # scale the axes in proportion to their data ranges\n",
1124 | " fig.update_layout(scene_aspectmode='data')\n",
1125 | "\n",
1126 | " # xy plane\n",
1127 | " fig = draw_XYplane(fig, xy_plane_colorscale, xy_plane_opacity,\n",
1128 | " x_range=[-depth_range[1], depth_range[1]], y_range=[-depth_range[1], depth_range[1]])\n",
1129 | "\n",
1130 | " # show world coordinate system (X, Y, Z positive direction)\n",
1131 | " fig = draw_XYZworld(fig, world_axis_size=world_axis_size)\n",
1132 | "\n",
1133 | " pixels_world, camera_world, world_mat, fg_pts, bg_pts = nerfpp(u=u, v=v, fov=fov, depth_range=depth_range)\n",
1134 | "\n",
1135 | "    # draw camera at init (with its coordinate system)\n",
1136 | " fig = draw_cam_init(fig, world_mat, \n",
1137 | " camera_axis_size=camera_axis_size, camera_color=camera_color)\n",
1138 | "\n",
1139 | "    # draw foreground and background sample points, rays, and frustum\n",
1140 | " fig = draw_foreground(fig, fg_pts, fg_color, sample_marker_size, at=[0, -1])\n",
1141 | " \n",
1142 | " if show_background:\n",
1143 | " fig = draw_background(fig, bg_pts.unsqueeze(0), bg_color, sample_marker_size, at=[0, -1])\n",
1144 | " \n",
1145 | " return fig\n",
1146 | "\n",
1147 | "app.run_server(mode='inline')"
1148 | ]
1149 | }
1150 | ],
1151 | "metadata": {
1152 | "kernelspec": {
1153 | "display_name": "Python 3 (ipykernel)",
1154 | "language": "python",
1155 | "name": "python3"
1156 | },
1157 | "language_info": {
1158 | "codemirror_mode": {
1159 | "name": "ipython",
1160 | "version": 3
1161 | },
1162 | "file_extension": ".py",
1163 | "mimetype": "text/x-python",
1164 | "name": "python",
1165 | "nbconvert_exporter": "python",
1166 | "pygments_lexer": "ipython3",
1167 | "version": "3.8.8"
1168 | }
1169 | },
1170 | "nbformat": 4,
1171 | "nbformat_minor": 5
1172 | }
1173 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VisualizingNerfplusplus
2 |
3 | Interactive visualization of NeRF++ (nerfplusplus) ray sampling, initialized to the nerfplusplus setting.
4 |
5 | Run the app and have fun~
6 |
7 | ```
8 | python app.py
9 | ```
10 |
11 |
12 |
13 |
14 |
15 |
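16 | The sampling and plotting helpers can also be used on their own. A rough sketch (assuming you run it from the repo root so the modules import, and reusing the app's default slider values) might look like this:
17 |
18 | ```
19 | # rough sketch: plot the unit sphere and NeRF++ foreground samples without the Dash UI
20 | from nerfplusplus_tools import nerfpp
21 | from drawing_tools import draw_sphere, draw_foreground
22 |
23 | # camera placed at (u, v) on the unit sphere; values match the app's default sliders
24 | pixels_world, camera_world, world_mat, fg_pts, bg_pts = nerfpp(u=0.0, v=0.25, fov=50)
25 |
26 | # unit sphere plus foreground sample points, using the app's default colorscales
27 | fig = draw_sphere(r=1, sphere_colorscale='greys', sphere_opacity=0.2)
28 | fig = draw_foreground(fig, fg_pts, 'plotly3', marker_size=2)
29 | fig.show()
30 | ```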
--------------------------------------------------------------------------------
/__pycache__/drawing_tools.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laphisboy/VisualizingNerfplusplus/4b86e2d15215b0baa29af6cc7e794c8521fa466a/__pycache__/drawing_tools.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/nerfplusplus_tools.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laphisboy/VisualizingNerfplusplus/4b86e2d15215b0baa29af6cc7e794c8521fa466a/__pycache__/nerfplusplus_tools.cpython-38.pyc
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import plotly.express as px
2 |
3 | import dash
4 | import dash_core_components as dcc
5 | import dash_html_components as html
6 | from dash.dependencies import Input, Output
7 |
8 | # for colors
9 | import matplotlib.colors as mcolors
10 |
11 | from drawing_tools import *
12 | from nerfplusplus_tools import *
13 |
14 | app = dash.Dash(__name__)
15 |
16 | app.layout = html.Div([
17 | html.H1("Nerfplusplus ray sampling visualization"),
18 | dcc.Graph(id='graph'),
19 |
20 | html.Div([
21 | html.Div([
22 |
23 | # changes to setting and ray casted
24 | html.Label([ "u (theta)",
25 | dcc.Slider(
26 | id='u-slider',
27 | min=0, max=1,
28 | value=0.00,
29 | marks={str(val) : str(val) for val in [0.00, 0.25, 0.50, 0.75]},
30 | step=0.01, tooltip = { 'always_visible': True }
31 | ), ]),
32 | html.Label([ "v (phi)",
33 | dcc.Slider(
34 | id='v-slider',
35 | min=0, max=1,
36 | value=0.25,
37 | marks={str(val) : str(val) for val in [0.00, 0.25, 0.50, 0.75]},
38 | step=0.01, tooltip = { 'always_visible': True }
39 | ), ]),
40 |         html.Label([ "fov (field-of-view)",
41 | dcc.Slider(
42 | id='fov-slider',
43 | min=0, max=100,
44 | value=50,
45 | marks={str(val) : str(val) for val in [0, 20, 40, 60, 80, 100]},
46 | step=5, tooltip = { 'always_visible': True }
47 | ), ]),
48 | html.Label([ "foreground near depth",
49 | dcc.Slider(
50 | id='foreground-near-depth-slider',
51 | min=0, max=2,
52 | value=0.5,
53 | marks={f"{val:.1f}" : f"{val:.1f}" for val in [0.1 * i for i in range(21)]},
54 | step=0.1, tooltip = { 'always_visible': True }
55 | ), ])
56 | ], style = {'width' : '48%', 'display' : 'inline-block'}),
57 |
58 | html.Div([
59 | # changes to visual appearance
60 |
61 | # axis scale
62 | html.Div([
63 | html.Label([ "world axis size",
64 | html.Div([
65 | dcc.Input(id='world-axis-size-input',
66 | value=1.5,
67 | type='number', style={'width': '50%'}
68 | )
69 | ]),
70 | ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),
71 | html.Label([ "camera axis size",
72 | html.Div([
73 | dcc.Input(id='camera-axis-size-input',
74 | value=0.3,
75 | type='number', style={'width': '50%'}
76 | )
77 | ]),
78 | ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),
79 | html.Label([ "sample marker size",
80 | html.Div([
81 | dcc.Input(id='sample-marker-size-input',
82 | value=2,
83 | type='number', style={'width': '50%'}
84 | )
85 | ]),
86 | ], style = {'width' : '32%', 'float' : 'left', 'display' : 'inline-block'}),
87 | ]),
88 |
89 | # opacity
90 | html.Div([
91 | html.Label([ "sphere opacity",
92 | html.Div([
93 | dcc.Input(id='sphere-opacity-input',
94 | value=0.2,
95 | type='number', style={'width': '50%'}
96 | )
97 | ])
98 | ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),
99 |
100 | html.Label([ "xy-plane opacity",
101 | html.Div([
102 | dcc.Input(id='xy-plane-opacity-input',
103 | value=0.8,
104 | type='number', style={'width': '50%'}
105 | )
106 | ])
107 | ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),
108 |
109 | ]),
110 |
111 | # color
112 | html.Div([
113 | html.Label([ "camera color",
114 | html.Div([
115 | dcc.Dropdown(id='camera-color-input',
116 | clearable=False,
117 | value='yellow',
118 | options=[
119 | {'label': c, 'value': c}
120 | for (c, _) in mcolors.CSS4_COLORS.items()
121 | ], style={'width': '80%'}
122 | )
123 | ])
124 | ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),
125 |
126 | html.Label([ "foreground color",
127 | html.Div([
128 | dcc.Dropdown(id='fg-color-input',
129 | clearable=False,
130 | value='plotly3',
131 | options=[
132 | {'label': c, 'value': c}
133 | for c in px.colors.named_colorscales()
134 | ], style={'width': '80%'}
135 | )
136 | ])
137 | ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),
138 |
139 | html.Label([ "background color",
140 | html.Div([
141 | dcc.Dropdown(id='bg-color-input',
142 | clearable=False,
143 | value='plotly3',
144 | options=[
145 | {'label': c, 'value': c}
146 | for c in px.colors.named_colorscales()
147 | ], style={'width': '80%'}
148 | )
149 | ])
150 | ], style = {'width' : '32%', 'float' : 'left', 'display' : 'inline-block'}),
151 | ]),
152 |
153 | # colorscale
154 | html.Div([
155 | html.Label([ "sphere colorscale",
156 | html.Div([
157 | dcc.Dropdown(id='sphere-colorscale-input',
158 | clearable=False,
159 | value='greys',
160 | options=[
161 | {'label': c, 'value': c}
162 | for c in px.colors.named_colorscales()
163 | ], style={'width': '80%'}
164 | )
165 | ])
166 | ], style = {'width' : '32%', 'float' : 'left', 'display' : 'inline-block'}),
167 |
168 | html.Label([ "xy-plane colorscale",
169 | html.Div([
170 | dcc.Dropdown(id='xy-plane-colorscale-input',
171 | clearable=False,
172 | value='greys',
173 | options=[
174 | {'label': c, 'value': c}
175 | for c in px.colors.named_colorscales()
176 | ], style={'width': '80%'}
177 | )
178 | ])
179 | ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),
180 |
181 | html.Label([ " ",
182 | html.Div([
183 | dcc.Checklist(id='show-background-checklist',
184 |
185 | options=[
186 | {'label': 'show background', 'value': 'show_background'}
187 | ], style={'width': '80%'},
188 | value=[ ],
189 | )
190 | ])
191 | ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),
192 |
193 | ]),
194 |
195 |
196 |
197 | ], style = {'width' : '48%', 'float' : 'right', 'display' : 'inline-block'}),
198 |
199 | ]),
200 |
201 | ])
202 |
203 | @app.callback(
204 | Output('graph', 'figure'),
205 | Input("u-slider", "value"),
206 | Input("v-slider", "value"),
207 |
208 | Input("fov-slider", "value"),
209 | Input("foreground-near-depth-slider", "value"),
210 |
211 | Input("world-axis-size-input", "value"),
212 | Input("camera-axis-size-input", "value"),
213 | Input("sample-marker-size-input", "value"),
214 |
215 | Input("camera-color-input", "value"),
216 | Input("fg-color-input", "value"),
217 | Input("bg-color-input", "value"),
218 |
219 | Input('sphere-colorscale-input', "value"),
220 | Input('xy-plane-colorscale-input', "value"),
221 | Input('show-background-checklist', "value"),
222 |
223 | Input("sphere-opacity-input", "value"),
224 | Input("xy-plane-opacity-input", "value"),
225 | )
226 |
227 | def update_figure(u, v,
228 | fov, foreground_near_depth,
229 | world_axis_size, camera_axis_size, sample_marker_size,
230 | camera_color, fg_color, bg_color,
231 | sphere_colorscale, xy_plane_colorscale, show_background,
232 | sphere_opacity, xy_plane_opacity
233 | ):
234 |
235 | depth_range = [foreground_near_depth, 2]
236 |
237 | # sphere
238 | fig = draw_sphere(r=1, sphere_colorscale=sphere_colorscale, sphere_opacity=sphere_opacity)
239 |
240 | # change figure size
241 | # fig.update_layout(autosize=False, width = 500, height=500)
242 |
243 |     # scale the axes in proportion to their data ranges
244 | fig.update_layout(scene_aspectmode='data')
245 |
246 | # xy plane
247 | fig = draw_XYplane(fig, xy_plane_colorscale, xy_plane_opacity,
248 | x_range=[-depth_range[1], depth_range[1]], y_range=[-depth_range[1], depth_range[1]])
249 |
250 | # show world coordinate system (X, Y, Z positive direction)
251 | fig = draw_XYZworld(fig, world_axis_size=world_axis_size)
252 |
253 | pixels_world, camera_world, world_mat, fg_pts, bg_pts = nerfpp(u=u, v=v, fov=fov, depth_range=depth_range)
254 |
255 |     # draw camera at init (with its coordinate system)
256 | fig = draw_cam_init(fig, world_mat,
257 | camera_axis_size=camera_axis_size, camera_color=camera_color)
258 |
259 |     # draw foreground and background sample points, rays, and frustum
260 | fig = draw_foreground(fig, fg_pts, fg_color, sample_marker_size, at=[0, -1])
261 |
262 | if show_background:
263 | fig = draw_background(fig, bg_pts.unsqueeze(0), bg_color, sample_marker_size, at=[0, -1])
264 |
265 | return fig
266 |
267 | if __name__ == '__main__':
268 | app.run_server(debug=True)
--------------------------------------------------------------------------------
/drawing_tools.py:
--------------------------------------------------------------------------------
1 | import plotly.graph_objects as go
2 | import plotly.express as px
3 | import numpy as np
4 |
5 | # draw sphere with radius r
6 | # also draw contours and vertical lines
7 | def draw_sphere(r, sphere_colorscale, sphere_opacity):
8 | # sphere
9 | u = np.linspace(0, 2 * np.pi, 100)
10 | v = np.linspace(0, np.pi, 100)
11 | x = r * np.outer(np.cos(u), np.sin(v))
12 | y = r * np.outer(np.sin(u), np.sin(v))
13 | z = r * np.outer(np.ones(np.size(u)), np.cos(v))
14 |
15 | # vertical lines on sphere
16 | u2 = np.linspace(0, 2 * np.pi, 20)
17 | x2 = r * np.outer(np.cos(u2), np.sin(v))
18 | y2 = r * np.outer(np.sin(u2), np.sin(v))
19 | z2 = r * np.outer(np.ones(np.size(u2)), np.cos(v))
20 |
21 | # create sphere and draw sphere with contours
22 | fig = go.Figure(data=[go.Surface(x=x, y=y, z=z,
23 | colorscale=sphere_colorscale, opacity=sphere_opacity,
24 | contours = {
25 | 'z' : {'show' : True, 'start' : -r,
26 | 'end' : r, 'size' : r/10,
27 | 'color' : 'white',
28 | 'width' : 1}
29 | }
30 | , showscale=False)])
31 |
32 | # vertical lines on sphere
33 | for i in range(len(u2)):
34 | fig.add_scatter3d(x=x2[i], y=y2[i], z=z2[i],
35 | line=dict(
36 | color='white',
37 | width=1
38 | ),
39 | mode='lines',
40 | showlegend=False)
41 |
42 | return fig
43 |
44 | # draw xyplane
45 | def draw_XYplane(fig, xy_plane_colorscale, xy_plane_opacity, x_range = [-2, 2], y_range = [-2, 2]):
46 | x3 = np.linspace(x_range[0], x_range[1], 100)
47 | y3 = np.linspace(y_range[0], y_range[1], 100)
48 | z3 = np.zeros(shape=(100,100))
49 |
50 | fig.add_surface(x=x3, y=y3, z=z3,
51 | colorscale =xy_plane_colorscale, opacity=xy_plane_opacity,
52 | showscale=False
53 | )
54 |
55 | return fig
56 |
57 |
58 | def draw_XYZworld(fig, world_axis_size):
59 | # x, y, z positive direction (world)
60 | X_axis = [0, world_axis_size]
61 | X_text = [None, "X"]
62 | X0 = [0, 0]
63 | Y_axis = [0, world_axis_size]
64 | Y_text = [None, "Y"]
65 | Y0 = [0, 0]
66 | Z_axis = [0, world_axis_size]
67 | Z_text = [None, "Z"]
68 | Z0 = [0, 0]
69 |
70 | fig.add_scatter3d(x=X_axis, y=Y0, z=Z0,
71 | line=dict(
72 | color='red',
73 | width=10
74 | ),
75 | mode='lines+text',
76 | text=X_text,
77 | textposition='top center',
78 | textfont=dict(
79 | color="red",
80 | size=18
81 | ),
82 | showlegend=False)
83 |
84 | fig.add_scatter3d(x=X0, y=Y_axis, z=Z0,
85 | line=dict(
86 | color='green',
87 | width=10
88 | ),
89 | mode='lines+text',
90 | text=Y_text,
91 | textposition='top center',
92 | textfont=dict(
93 | color="green",
94 | size=18
95 | ),
96 | showlegend=False)
97 |
98 | fig.add_scatter3d(x=X0, y=Y0, z=Z_axis,
99 | line=dict(
100 | color='blue',
101 | width=10
102 | ),
103 | mode='lines+text',
104 | text=Z_text,
105 | textposition='top center',
106 | textfont=dict(
107 | color="blue",
108 | size=18
109 | ),
110 | showlegend=False)
111 |
112 | return fig
113 |
114 | # draw cam and cam coordinate system
115 | def draw_cam_init(fig, world_mat, camera_axis_size, camera_color):
116 | # camera at init
117 |
118 | Xc = [world_mat[0, : ,3][0]]
119 | Yc = [world_mat[0, : ,3][1]]
120 | Zc = [world_mat[0, : ,3][2]]
121 | text_c = ["Camera"]
122 |
123 | # camera axis
124 | Xc_Xaxis = Xc + [world_mat[0, : ,0][0]*camera_axis_size+Xc[0]]
125 | Yc_Xaxis = Yc + [world_mat[0, : ,0][1]*camera_axis_size+Yc[0]]
126 | Zc_Xaxis = Zc + [world_mat[0, : ,0][2]*camera_axis_size+Zc[0]]
127 | text_Xaxis = [None, "Xc"]
128 |
129 | # -z in world perspective
130 | Xc_Yaxis = Xc + [world_mat[0, : ,1][0]*camera_axis_size+Xc[0]]
131 | Yc_Yaxis = Yc + [world_mat[0, : ,1][1]*camera_axis_size+Yc[0]]
132 | Zc_Yaxis = Zc + [world_mat[0, : ,1][2]*camera_axis_size+Zc[0]]
133 | text_Yaxis = [None, "Yc"]
134 |
135 | # y in world perspective
136 | Xc_Zaxis = Xc + [world_mat[0, : ,2][0]*camera_axis_size+Xc[0]]
137 | Yc_Zaxis = Yc + [world_mat[0, : ,2][1]*camera_axis_size+Yc[0]]
138 | Zc_Zaxis = Zc + [world_mat[0, : ,2][2]*camera_axis_size+Zc[0]]
139 | text_Zaxis = [None, "Zc"]
140 |
141 | # cam pos
142 | fig.add_scatter3d(x=Xc, y=Yc, z=Zc,
143 | mode='markers',
144 | marker=dict(
145 | color=camera_color,
146 | size=4,
147 | sizemode='diameter'
148 | ),
149 | showlegend=False)
150 |
151 | # camera axis
152 | fig.add_scatter3d(x=Xc_Xaxis, y=Yc_Xaxis, z=Zc_Xaxis,
153 | line=dict(
154 | color='red',
155 | width=10
156 | ),
157 | mode='lines+text',
158 | text=text_Xaxis,
159 | textposition='top center',
160 | textfont=dict(
161 | color="red",
162 | size=18
163 | ),
164 | showlegend=False)
165 |
166 | fig.add_scatter3d(x=Xc_Yaxis, y=Yc_Yaxis, z=Zc_Yaxis,
167 | line=dict(
168 | color='green',
169 | width=10
170 | ),
171 | mode='lines+text',
172 | text=text_Yaxis,
173 | textposition='top center',
174 | textfont=dict(
175 | color="green",
176 | size=18
177 | ),
178 | showlegend=False)
179 |
180 | fig.add_scatter3d(x=Xc_Zaxis, y=Yc_Zaxis, z=Zc_Zaxis,
181 | line=dict(
182 | color='blue',
183 | width=10
184 | ),
185 | mode='lines+text',
186 | text=text_Zaxis,
187 | textposition='top center',
188 | textfont=dict(
189 | color="blue",
190 | size=18
191 | ),
192 | showlegend=False)
193 |
194 | return fig
195 |
196 | # draw all rays
197 | def draw_all_rays(fig, p_i, ray_color):
198 | for i in range(p_i.shape[1]):
199 | Xray = p_i[0, i, :, 0]
200 | Yray = p_i[0, i, :, 1]
201 | Zray = p_i[0, i, :, 2]
202 |
203 | fig.add_scatter3d(x=Xray, y=Yray, z=Zray,
204 | line=dict(
205 | color=ray_color,
206 | width=5
207 | ),
208 | mode='lines',
209 | showlegend=False)
210 |
211 | return fig
212 |
213 | # draw all rays with sample markers
214 | def draw_all_rays_with_marker(fig, p_i, marker_size, ray_color):
215 |
216 |     # convert colorscale string to px.colors.sequential
217 | # default color is set to Viridis in case of mismatch
218 | c = px.colors.sequential.Viridis
219 |
220 | for c_name in [ray_color, ray_color.capitalize()]:
221 | try:
222 | c = getattr(px.colors.sequential, c_name)
223 | except:
224 | continue
225 |
226 | for i in range(p_i.shape[1]):
227 | Xray = p_i[0, i, :, 0]
228 | Yray = p_i[0, i, :, 1]
229 | Zray = p_i[0, i, :, 2]
230 |
231 | fig.add_scatter3d(x=Xray, y=Yray, z=Zray,
232 |
233 | marker=dict(
234 | # color=np.arange(len(Xray)),
235 | color=c,
236 | # colorscale='Viridis',
237 | size=marker_size
238 | ),
239 |
240 | line=dict(
241 | # color=np.arange(len(Xray)),
242 | color=c,
243 | # colorscale='Viridis',
244 | width=3
245 | ),
246 | mode="lines+markers",
247 | showlegend=False)
248 |
249 | return fig
250 |
251 | # draw near & far frustum with rays connecting the corners (changed for nerfpp)
252 | def draw_ray_frus(fig, p_i, frustrum_color, frustrum_opacity, at=[0, -1]):
253 |
254 | for i in at:
255 | # Xfrus = p_i[0, :, i, 0][[0,1,2,3,7,11,15,14,13,12,8,4,0]]
256 | # Yfrus = p_i[0, :, i, 1][[0,1,2,3,7,11,15,14,13,12,8,4,0]]
257 | # Zfrus = p_i[0, :, i, 2][[0,1,2,3,7,11,15,14,13,12,8,4,0]]
258 |
259 | Xfrus = p_i[0, :, i, 0]
260 | Yfrus = p_i[0, :, i, 1]
261 | Zfrus = p_i[0, :, i, 2]
262 |
263 | fig.add_scatter3d(x=Xfrus, y=Yfrus, z=Zfrus,
264 | line=dict(
265 | color=frustrum_color,
266 | width=5
267 | ),
268 | mode='lines',
269 | surfaceaxis=0,
270 | surfacecolor=frustrum_color,
271 | opacity=frustrum_opacity,
272 | showlegend=False)
273 |
274 | return fig
275 |
276 | # draw foreground sample points, rays, and frustum
277 | def draw_foreground(fig, fg_pts, fg_color, marker_size, at=[0, -1]):
278 | fig = draw_all_rays_with_marker(fig, fg_pts, marker_size, fg_color)
279 |
280 | return fig
281 |
282 | # draw background sample points, rays, and frustum
283 | def draw_background(fig, bg_pts, bg_color, marker_size, at=[0, -1]):
284 | fig = draw_all_rays_with_marker(fig, bg_pts, marker_size, bg_color)
285 |
286 | return fig
--------------------------------------------------------------------------------
/example.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laphisboy/VisualizingNerfplusplus/4b86e2d15215b0baa29af6cc7e794c8521fa466a/example.gif
--------------------------------------------------------------------------------
/nerfplusplus_tools.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | HUGE_NUMBER = 1e10
5 | TINY_NUMBER = 1e-6
6 |
7 | # from autonomousvision/giraffe
8 |
9 | # 0.
10 | def to_pytorch(tensor, return_type=False):
11 | ''' Converts input tensor to pytorch.
12 | Args:
13 | tensor (tensor): Numpy or Pytorch tensor
14 | return_type (bool): whether to return input type
15 | '''
16 | is_numpy = False
17 | if type(tensor) == np.ndarray:
18 | tensor = torch.from_numpy(tensor).float()
19 | is_numpy = True
20 | tensor = tensor.clone()
21 | if return_type:
22 | return tensor, is_numpy
23 | return tensor
24 |
25 | # 1. get camera intrinsic
26 | def get_camera_mat(fov=49.13, invert=True):
27 | # fov = 2 * arctan( sensor / (2 * focal))
28 | # focal = (sensor / 2) * 1 / (tan(0.5 * fov))
29 | # in our case, sensor = 2 as pixels are in [-1, 1]
30 | focal = 1. / np.tan(0.5 * fov * np.pi/180.)
31 | focal = focal.astype(np.float32)
32 | mat = torch.tensor([
33 | [focal, 0., 0., 0.],
34 | [0., focal, 0., 0.],
35 | [0., 0., 1, 0.],
36 | [0., 0., 0., 1.]
37 | ]).reshape(1, 4, 4)
38 |
39 | if invert:
40 | mat = torch.inverse(mat)
41 | return mat
42 |
43 | # 2. get camera position with camera pose (theta & phi)
44 | def to_sphere(u, v):
45 | theta = 2 * np.pi * u
46 | phi = np.arccos(1 - 2 * v)
47 | cx = np.sin(phi) * np.cos(theta)
48 | cy = np.sin(phi) * np.sin(theta)
49 | cz = np.cos(phi)
50 | return np.stack([cx, cy, cz], axis=-1)
51 |
52 | # 3. get camera coordinate system assuming it points to the center of the sphere
53 | def look_at(eye, at=np.array([0, 0, 0]), up=np.array([0, 0, 1]), eps=1e-5,
54 | to_pytorch=True):
55 | at = at.astype(float).reshape(1, 3)
56 | up = up.astype(float).reshape(1, 3)
57 | eye = eye.reshape(-1, 3)
58 | up = up.repeat(eye.shape[0] // up.shape[0], axis=0)
59 | eps = np.array([eps]).reshape(1, 1).repeat(up.shape[0], axis=0)
60 |
61 | z_axis = eye - at
62 | z_axis /= np.max(np.stack([np.linalg.norm(z_axis,
63 | axis=1, keepdims=True), eps]))
64 |
65 | x_axis = np.cross(up, z_axis)
66 | x_axis /= np.max(np.stack([np.linalg.norm(x_axis,
67 | axis=1, keepdims=True), eps]))
68 |
69 | y_axis = np.cross(z_axis, x_axis)
70 | y_axis /= np.max(np.stack([np.linalg.norm(y_axis,
71 | axis=1, keepdims=True), eps]))
72 |
73 | r_mat = np.concatenate(
74 | (x_axis.reshape(-1, 3, 1), y_axis.reshape(-1, 3, 1), z_axis.reshape(
75 | -1, 3, 1)), axis=2)
76 |
77 | if to_pytorch:
78 | r_mat = torch.tensor(r_mat).float()
79 |
80 | return r_mat
81 |
82 | # 5. arrange a 2D array of pixel coordinates and give them a depth of 1
83 | def arange_pixels(resolution=(128, 128), batch_size=1, image_range=(-1., 1.),
84 | subsample_to=None, invert_y_axis=False):
85 | ''' Arranges pixels for given resolution in range image_range.
86 | The function returns the unscaled pixel locations as integers and the
87 | scaled float values.
88 | Args:
89 | resolution (tuple): image resolution
90 | batch_size (int): batch size
91 | image_range (tuple): range of output points (default [-1, 1])
92 | subsample_to (int): if integer and > 0, the points are randomly
93 | subsampled to this value
94 | '''
95 | h, w = resolution
96 | n_points = resolution[0] * resolution[1]
97 |
98 |     # Arrange pixel locations at the given resolution
99 | pixel_locations = torch.meshgrid(torch.arange(0, w), torch.arange(0, h))
100 | pixel_locations = torch.stack(
101 | [pixel_locations[0], pixel_locations[1]],
102 | dim=-1).long().view(1, -1, 2).repeat(batch_size, 1, 1)
103 | pixel_scaled = pixel_locations.clone().float()
104 |
105 | # Shift and scale points to match image_range
106 | scale = (image_range[1] - image_range[0])
107 | loc = scale / 2
108 | pixel_scaled[:, :, 0] = scale * pixel_scaled[:, :, 0] / (w - 1) - loc
109 | pixel_scaled[:, :, 1] = scale * pixel_scaled[:, :, 1] / (h - 1) - loc
110 |
111 | # Subsample points if subsample_to is not None and > 0
112 | if (subsample_to is not None and subsample_to > 0 and
113 | subsample_to < n_points):
114 | idx = np.random.choice(pixel_scaled.shape[1], size=(subsample_to,),
115 | replace=False)
116 | pixel_scaled = pixel_scaled[:, idx]
117 | pixel_locations = pixel_locations[:, idx]
118 |
119 | if invert_y_axis:
120 | assert(image_range == (-1, 1))
121 | pixel_scaled[..., -1] *= -1.
122 | pixel_locations[..., -1] = (h - 1) - pixel_locations[..., -1]
123 |
124 | return pixel_locations, pixel_scaled
125 |
126 | # 6. mat_mul with intrinsic and then extrinsic gives you p_world (pixels in world)
127 | def image_points_to_world(image_points, camera_mat, world_mat, scale_mat=None,
128 | invert=False, negative_depth=True):
129 | ''' Transforms points on image plane to world coordinates.
130 | In contrast to transform_to_world, no depth value is needed as points on
131 | the image plane have a fixed depth of 1.
132 | Args:
133 | image_points (tensor): image points tensor of size B x N x 2
134 | camera_mat (tensor): camera matrix
135 | world_mat (tensor): world matrix
136 | scale_mat (tensor): scale matrix
137 | invert (bool): whether to invert matrices (default: False)
138 | '''
139 | batch_size, n_pts, dim = image_points.shape
140 | assert(dim == 2)
141 | d_image = torch.ones(batch_size, n_pts, 1)
142 | if negative_depth:
143 | d_image *= -1.
144 | return transform_to_world(image_points, d_image, camera_mat, world_mat,
145 | scale_mat, invert=invert)
146 |
147 | def transform_to_world(pixels, depth, camera_mat, world_mat, scale_mat=None,
148 | invert=True, use_absolute_depth=True):
149 | ''' Transforms pixel positions p with given depth value d to world coordinates.
150 | Args:
151 | pixels (tensor): pixel tensor of size B x N x 2
152 | depth (tensor): depth tensor of size B x N x 1
153 | camera_mat (tensor): camera matrix
154 | world_mat (tensor): world matrix
155 | scale_mat (tensor): scale matrix
156 | invert (bool): whether to invert matrices (default: true)
157 | '''
158 | assert(pixels.shape[-1] == 2)
159 |
160 | if scale_mat is None:
161 | scale_mat = torch.eye(4).unsqueeze(0).repeat(
162 | camera_mat.shape[0], 1, 1)
163 |
164 | # Convert to pytorch
165 | pixels, is_numpy = to_pytorch(pixels, True)
166 | depth = to_pytorch(depth)
167 | camera_mat = to_pytorch(camera_mat)
168 | world_mat = to_pytorch(world_mat)
169 | scale_mat = to_pytorch(scale_mat)
170 |
171 | # Invert camera matrices
172 | if invert:
173 | camera_mat = torch.inverse(camera_mat)
174 | world_mat = torch.inverse(world_mat)
175 | scale_mat = torch.inverse(scale_mat)
176 |
177 |     # Transform pixels to homogeneous coordinates
178 | pixels = pixels.permute(0, 2, 1)
179 | pixels = torch.cat([pixels, torch.ones_like(pixels)], dim=1)
180 |
181 | # Project pixels into camera space
182 | if use_absolute_depth:
183 | pixels[:, :2] = pixels[:, :2] * depth.permute(0, 2, 1).abs()
184 | pixels[:, 2:3] = pixels[:, 2:3] * depth.permute(0, 2, 1)
185 | else:
186 | pixels[:, :3] = pixels[:, :3] * depth.permute(0, 2, 1)
187 |
188 | # Transform pixels to world space
189 | p_world = scale_mat @ world_mat @ camera_mat @ pixels
190 |
191 | # Transform p_world back to 3D coordinates
192 | p_world = p_world[:, :3].permute(0, 2, 1)
193 |
194 | if is_numpy:
195 | p_world = p_world.numpy()
196 | return p_world
197 |
198 |
199 | # 7. mat_mul zeros with intrinsic&extrinsic for camera pos (which we already obtained as loc)
200 | def origin_to_world(n_points, camera_mat, world_mat, scale_mat=None,
201 | invert=False):
202 | ''' Transforms origin (camera location) to world coordinates.
203 | Args:
204 | n_points (int): how often the transformed origin is repeated in the
205 | form (batch_size, n_points, 3)
206 | camera_mat (tensor): camera matrix
207 | world_mat (tensor): world matrix
208 | scale_mat (tensor): scale matrix
209 | invert (bool): whether to invert the matrices (default: false)
210 | '''
211 |
212 | batch_size = camera_mat.shape[0]
213 | device = camera_mat.device
214 |     # Create origin in homogeneous coordinates
215 | p = torch.zeros(batch_size, 4, n_points).to(device)
216 | p[:, -1] = 1.
217 |
218 | if scale_mat is None:
219 | scale_mat = torch.eye(4).unsqueeze(
220 | 0).repeat(batch_size, 1, 1).to(device)
221 |
222 | # Invert matrices
223 | if invert:
224 | camera_mat = torch.inverse(camera_mat)
225 | world_mat = torch.inverse(world_mat)
226 | scale_mat = torch.inverse(scale_mat)
227 |
228 | camera_mat = to_pytorch(camera_mat)
229 | world_mat = to_pytorch(world_mat)
230 | scale_mat = to_pytorch(scale_mat)
231 |
232 | # Apply transformation
233 | p_world = scale_mat @ world_mat @ camera_mat @ p
234 |
235 | # Transform points back to 3D coordinates
236 | p_world = p_world[:, :3].permute(0, 2, 1)
237 | return p_world
238 |
239 |
240 | # from Kai-46/nerfplusplus
241 |
242 | # 8. intersect sphere for distinguishing fg and bg
243 | def intersect_sphere(ray_o, ray_d):
244 | '''
245 | ray_o, ray_d: [..., 3]
246 | compute the depth of the intersection point between this ray and unit sphere
247 | '''
248 | # note: d1 becomes negative if this mid point is behind camera
249 | d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)
250 | p = ray_o + d1.unsqueeze(-1) * ray_d
251 | # consider the case where the ray does not intersect the sphere
252 | ray_d_cos = 1. / torch.norm(ray_d, dim=-1)
253 | p_norm_sq = torch.sum(p * p, dim=-1)
254 | if (p_norm_sq >= 1.).any():
255 | raise Exception('Not all your cameras are bounded by the unit sphere; please make sure the cameras are normalized properly!')
256 | d2 = torch.sqrt(1. - p_norm_sq) * ray_d_cos
257 |
258 | return d1 + d2
259 |
260 | # 9. inverse sphere sampling for bg
261 | def depth2pts_outside(ray_o, ray_d, depth):
262 | '''
263 | ray_o, ray_d: [..., 3]
264 | depth: [...]; inverse of distance to sphere origin
265 | '''
266 | # note: d1 becomes negative if this mid point is behind camera
267 | d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)
268 | p_mid = ray_o + d1.unsqueeze(-1) * ray_d
269 | p_mid_norm = torch.norm(p_mid, dim=-1)
270 | ray_d_cos = 1. / torch.norm(ray_d, dim=-1)
271 | d2 = torch.sqrt(1. - p_mid_norm * p_mid_norm) * ray_d_cos
272 |
273 | p_sphere = ray_o + (d1 + d2).unsqueeze(-1) * ray_d
274 |
275 | rot_axis = torch.cross(ray_o, p_sphere, dim=-1)
276 | rot_axis = rot_axis / torch.norm(rot_axis, dim=-1, keepdim=True)
277 | phi = torch.asin(p_mid_norm)
278 | theta = torch.asin(p_mid_norm * depth) # depth is inside [0, 1]
279 | rot_angle = (phi - theta).unsqueeze(-1) # [..., 1]
280 |
281 | # now rotate p_sphere
282 | # Rodrigues formula: https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
283 | p_sphere_new = p_sphere * torch.cos(rot_angle) + \
284 | torch.cross(rot_axis, p_sphere, dim=-1) * torch.sin(rot_angle) + \
285 | rot_axis * torch.sum(rot_axis*p_sphere, dim=-1, keepdim=True) * (1.-torch.cos(rot_angle))
286 | p_sphere_new = p_sphere_new / torch.norm(p_sphere_new, dim=-1, keepdim=True)
287 | pts = torch.cat((p_sphere_new.squeeze(), depth.squeeze().unsqueeze(-1)), dim=-1) # (modified) added .squeeze()
288 |
289 | # now calculate conventional depth
290 | depth_real = 1. / (depth + TINY_NUMBER) * torch.cos(theta) * ray_d_cos + d1
291 | return pts, depth_real
292 |
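# NOTE (editorial, follows the NeRF++ inverted-sphere parameterization): a
# background point at radius r > 1 is stored as (x', y', z', 1/r), where
# (x', y', z') is the unit direction of the point as seen from the sphere
# origin.  Given depth = 1/r, the code rotates p_sphere (where the ray exits
# the unit sphere) about rot_axis by (phi - theta) so that it points at the
# sample lying at radius r, and depth_real is that sample's conventional depth
# along the ray, in the same units as the foreground depths.
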
293 | # 10. perturb sample z values (depth) for some randomness (stratified sampling)
294 | def perturb_samples(z_vals):
295 | # get intervals between samples
296 | mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
297 | upper = torch.cat([mids, z_vals[..., -1:]], dim=-1)
298 | lower = torch.cat([z_vals[..., 0:1], mids], dim=-1)
299 | # uniform samples in those intervals
300 | t_rand = torch.rand_like(z_vals)
301 | z_vals = lower + (upper - lower) * t_rand # [N_rays, N_samples]
302 |
303 | return z_vals
304 |
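# NOTE (illustrative example, not in the original file): perturb_samples jitters
# each depth uniformly inside its own stratum, so samples stay ordered, e.g.
#
#     z = torch.linspace(0., 1., 5).unsqueeze(0)  # [[0.00, 0.25, 0.50, 0.75, 1.00]]
#     z = perturb_samples(z)                      # e.g. [[0.07, 0.31, 0.44, 0.81, 0.93]]
#
# the i-th sample remains inside [mid(i-1, i), mid(i, i+1)], so bins never swap.
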
305 | # 11. return fg samples (uniform in depth) and bg samples (uniform in inverse-sphere depth)
306 | def uniformsampling(ray_o, ray_d, min_depth, N_samples, device):
307 |     '''Sample fg points uniformly in depth and bg points uniformly in inverse depth (NeRF++ style).'''
308 | dots_sh = list(ray_d.shape[:-1])
309 |
310 | # foreground depth
311 | fg_far_depth = intersect_sphere(ray_o, ray_d) # [...,]
312 | fg_near_depth = min_depth # [..., ]
313 | step = (fg_far_depth - fg_near_depth) / (N_samples - 1)
314 | fg_depth = torch.stack([fg_near_depth + i * step for i in range(N_samples)], dim=-1) # [..., N_samples]
315 | fg_depth = perturb_samples(fg_depth) # random perturbation during training
316 |
317 | # background depth
318 | bg_depth = torch.linspace(0., 1., N_samples).view(
319 | [1, ] * len(dots_sh) + [N_samples,]).expand(dots_sh + [N_samples,]).to(device)
320 | bg_depth = perturb_samples(bg_depth) # random perturbation during training
321 |
322 |
323 | fg_ray_o = ray_o.unsqueeze(-2).expand(dots_sh + [N_samples, 3])
324 | fg_ray_d = ray_d.unsqueeze(-2).expand(dots_sh + [N_samples, 3])
325 |
326 | bg_ray_o = ray_o.unsqueeze(-2).expand(dots_sh + [N_samples, 3])
327 | bg_ray_d = ray_d.unsqueeze(-2).expand(dots_sh + [N_samples, 3])
328 |
329 | # sampling foreground
330 | fg_pts = fg_ray_o + fg_depth.unsqueeze(-1) * fg_ray_d
331 |
332 | # sampling background
333 | bg_pts, bg_depth_real = depth2pts_outside(bg_ray_o, bg_ray_d, bg_depth)
334 |
335 | return fg_pts, bg_pts, bg_depth_real
336 |
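# NOTE (editorial, shapes traced from the code above): with ray_o and ray_d of
# shape [batch, n_rays, 3], fg_pts has shape [batch, n_rays, N_samples, 3] and
# holds uniform samples between min_depth and the unit-sphere intersection,
# while bg_pts holds N_samples inverted-sphere coordinates (x', y', z', 1/r)
# per ray (the .squeeze() in depth2pts_outside drops the batch dimension when
# batch_size == 1).
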
337 | # 12. convert bg pts (x', y', z', 1/r) to real bg pts (x, y, z)
338 | def pts2realpts(bg_pts, bg_depth_real):
339 | return bg_pts[:, :, :3] * bg_depth_real.squeeze().unsqueeze(-1)
340 |
341 | def nerfpp(u=1,
342 |            v=0.5,
343 |            fov=49.13,
344 |            depth_range=[0.5, 6.],
345 |            n_ray_samples=16,
346 |            resolution_vol=4,
347 |            batch_size=1,
348 |            device=torch.device('cpu')
349 |            ):
350 |
351 | r = 1
352 |     range_radius = [r, r]
353 |
354 | res = resolution_vol
355 | n_points = res * res
356 |
357 | # 1. get camera intrinsic - fiddle around with fov this time
358 | camera_mat = get_camera_mat(fov=fov)
359 |
360 | # 2. get camera position with camera pose (theta & phi)
361 | loc = to_sphere(u, v)
362 | loc = torch.tensor(loc).float()
363 |
364 | radius = range_radius[0] + \
365 | torch.rand(batch_size) * (range_radius[1] - range_radius[0])
366 |
367 | loc = loc * radius.unsqueeze(-1)
368 |
369 | # 3. get camera coordinate system assuming it points to the center of the sphere
370 | R = look_at(loc)
371 |
372 |     # 4. The camera coordinate system gives the rotation matrix; together with the camera location it forms the camera extrinsic
373 | RT = np.eye(4).reshape(1, 4, 4)
374 | RT[:, :3, :3] = R
375 | RT[:, :3, -1] = loc
376 | world_mat = RT
377 |
378 |     # 5. arrange a 2D array of pixel coordinates and assign a depth of 1
379 | pixels = arange_pixels((res, res), 1, invert_y_axis=False)[1]
380 |     pixels[..., -1] *= -1.  # sign flip of the last pixel coordinate, kept from the original code (reason unclear)
381 |
382 | # 6. mat_mul with intrinsic and then extrinsic gives you p_world (pixels in world)
383 | pixels_world = image_points_to_world(pixels, camera_mat, world_mat)
384 |
385 |     # 7. mat_mul zeros with intrinsic & extrinsic for camera pos (which we already obtained as loc)
386 | camera_world = origin_to_world(n_points, camera_mat, world_mat)
387 |
388 | # 8. ray = pixel - camera origin (in world)
389 | ray_vector = pixels_world - camera_world
390 |
391 | # 9. sample fg and bg points according to nerfpp (uniform and inverse sphere)
392 | fg_pts, bg_pts, bg_depth_real = uniformsampling(ray_o=camera_world, ray_d=ray_vector, min_depth=depth_range[0], N_samples=n_ray_samples, device=device)
393 |
394 |     # 10. convert bg pts (x', y', z', 1/r) to real bg pts (x, y, z)
395 | bg_pts_real = pts2realpts(bg_pts, bg_depth_real)
396 |
397 | return pixels_world, camera_world, world_mat, fg_pts, bg_pts_real
398 |
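
if __name__ == '__main__':
    # Minimal demo (illustrative sketch, not part of the original pipeline):
    # place one camera on the unit sphere and print the shapes of the tensors
    # the Dash app visualizes.  With the defaults (batch_size=1,
    # resolution_vol=4, n_ray_samples=16) the shapes should come out as:
    #   pixels_world, camera_world : (1, 16, 3)
    #   fg_pts                     : (1, 16, 16, 3)
    #   bg_pts_real                : (16, 16, 3)
    pixels_world, camera_world, world_mat, fg_pts, bg_pts_real = nerfpp(u=0.25, v=0.5)
    print('pixels_world:', tuple(pixels_world.shape))
    print('camera_world:', tuple(camera_world.shape))
    print('fg_pts:', tuple(fg_pts.shape))
    print('bg_pts_real:', tuple(bg_pts_real.shape))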
--------------------------------------------------------------------------------