├── .ipynb_checkpoints └── NerfppVisualizationDash-checkpoint.ipynb ├── NerfppVisualizationDash.ipynb ├── README.md ├── __pycache__ ├── drawing_tools.cpython-38.pyc └── nerfplusplus_tools.cpython-38.pyc ├── app.py ├── drawing_tools.py ├── example.gif └── nerfplusplus_tools.py /.ipynb_checkpoints/NerfppVisualizationDash-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "a6fca250", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import plotly.graph_objects as go\n", 11 | "import numpy as np\n", 12 | "import torch\n", 13 | "import plotly.express as px" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "id": "de77ce6c", 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "HUGE_NUMBER = 1e10\n", 24 | "TINY_NUMBER = 1e-6 " 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "id": "541c3686", 30 | "metadata": {}, 31 | "source": [ 32 | "# Ray Casting" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 3, 38 | "id": "27ee36ff", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "# from autonomousvision/giraffe\n", 43 | "\n", 44 | "# 0.\n", 45 | "def to_pytorch(tensor, return_type=False):\n", 46 | " ''' Converts input tensor to pytorch.\n", 47 | " Args:\n", 48 | " tensor (tensor): Numpy or Pytorch tensor\n", 49 | " return_type (bool): whether to return input type\n", 50 | " '''\n", 51 | " is_numpy = False\n", 52 | " if type(tensor) == np.ndarray:\n", 53 | " tensor = torch.from_numpy(tensor).float()\n", 54 | " is_numpy = True\n", 55 | " tensor = tensor.clone()\n", 56 | " if return_type:\n", 57 | " return tensor, is_numpy\n", 58 | " return tensor\n", 59 | "\n", 60 | "# 1. get camera intrinsic\n", 61 | "def get_camera_mat(fov=49.13, invert=True):\n", 62 | " # fov = 2 * arctan( sensor / (2 * focal))\n", 63 | " # focal = (sensor / 2) * 1 / (tan(0.5 * fov))\n", 64 | " # in our case, sensor = 2 as pixels are in [-1, 1]\n", 65 | " focal = 1. / np.tan(0.5 * fov * np.pi/180.)\n", 66 | " focal = focal.astype(np.float32)\n", 67 | " mat = torch.tensor([\n", 68 | " [focal, 0., 0., 0.],\n", 69 | " [0., focal, 0., 0.],\n", 70 | " [0., 0., 1, 0.],\n", 71 | " [0., 0., 0., 1.]\n", 72 | " ]).reshape(1, 4, 4)\n", 73 | "\n", 74 | " if invert:\n", 75 | " mat = torch.inverse(mat)\n", 76 | " return mat\n", 77 | "\n", 78 | "# 2. get camera position with camera pose (theta & phi)\n", 79 | "def to_sphere(u, v):\n", 80 | " theta = 2 * np.pi * u\n", 81 | " phi = np.arccos(1 - 2 * v)\n", 82 | " cx = np.sin(phi) * np.cos(theta)\n", 83 | " cy = np.sin(phi) * np.sin(theta)\n", 84 | " cz = np.cos(phi)\n", 85 | " return np.stack([cx, cy, cz], axis=-1)\n", 86 | "\n", 87 | "# 3. 
get camera coordinate system assuming it points to the center of the sphere\n", 88 | "def look_at(eye, at=np.array([0, 0, 0]), up=np.array([0, 0, 1]), eps=1e-5,\n", 89 | " to_pytorch=True):\n", 90 | " at = at.astype(float).reshape(1, 3)\n", 91 | " up = up.astype(float).reshape(1, 3)\n", 92 | " eye = eye.reshape(-1, 3)\n", 93 | " up = up.repeat(eye.shape[0] // up.shape[0], axis=0)\n", 94 | " eps = np.array([eps]).reshape(1, 1).repeat(up.shape[0], axis=0)\n", 95 | "\n", 96 | " z_axis = eye - at\n", 97 | " z_axis /= np.max(np.stack([np.linalg.norm(z_axis,\n", 98 | " axis=1, keepdims=True), eps]))\n", 99 | "\n", 100 | " x_axis = np.cross(up, z_axis)\n", 101 | " x_axis /= np.max(np.stack([np.linalg.norm(x_axis,\n", 102 | " axis=1, keepdims=True), eps]))\n", 103 | "\n", 104 | " y_axis = np.cross(z_axis, x_axis)\n", 105 | " y_axis /= np.max(np.stack([np.linalg.norm(y_axis,\n", 106 | " axis=1, keepdims=True), eps]))\n", 107 | "\n", 108 | " r_mat = np.concatenate(\n", 109 | " (x_axis.reshape(-1, 3, 1), y_axis.reshape(-1, 3, 1), z_axis.reshape(\n", 110 | " -1, 3, 1)), axis=2)\n", 111 | "\n", 112 | " if to_pytorch:\n", 113 | " r_mat = torch.tensor(r_mat).float()\n", 114 | "\n", 115 | " return r_mat\n", 116 | "\n", 117 | "# 5. arange 2d array of pixel coordinate and give depth of 1\n", 118 | "def arange_pixels(resolution=(128, 128), batch_size=1, image_range=(-1., 1.),\n", 119 | " subsample_to=None, invert_y_axis=False):\n", 120 | " ''' Arranges pixels for given resolution in range image_range.\n", 121 | " The function returns the unscaled pixel locations as integers and the\n", 122 | " scaled float values.\n", 123 | " Args:\n", 124 | " resolution (tuple): image resolution\n", 125 | " batch_size (int): batch size\n", 126 | " image_range (tuple): range of output points (default [-1, 1])\n", 127 | " subsample_to (int): if integer and > 0, the points are randomly\n", 128 | " subsampled to this value\n", 129 | " '''\n", 130 | " h, w = resolution\n", 131 | " n_points = resolution[0] * resolution[1]\n", 132 | "\n", 133 | " # Arrange pixel location in scale resolution\n", 134 | " pixel_locations = torch.meshgrid(torch.arange(0, w), torch.arange(0, h))\n", 135 | " pixel_locations = torch.stack(\n", 136 | " [pixel_locations[0], pixel_locations[1]],\n", 137 | " dim=-1).long().view(1, -1, 2).repeat(batch_size, 1, 1)\n", 138 | " pixel_scaled = pixel_locations.clone().float()\n", 139 | "\n", 140 | " # Shift and scale points to match image_range\n", 141 | " scale = (image_range[1] - image_range[0])\n", 142 | " loc = scale / 2\n", 143 | " pixel_scaled[:, :, 0] = scale * pixel_scaled[:, :, 0] / (w - 1) - loc\n", 144 | " pixel_scaled[:, :, 1] = scale * pixel_scaled[:, :, 1] / (h - 1) - loc\n", 145 | "\n", 146 | " # Subsample points if subsample_to is not None and > 0\n", 147 | " if (subsample_to is not None and subsample_to > 0 and\n", 148 | " subsample_to < n_points):\n", 149 | " idx = np.random.choice(pixel_scaled.shape[1], size=(subsample_to,),\n", 150 | " replace=False)\n", 151 | " pixel_scaled = pixel_scaled[:, idx]\n", 152 | " pixel_locations = pixel_locations[:, idx]\n", 153 | "\n", 154 | " if invert_y_axis:\n", 155 | " assert(image_range == (-1, 1))\n", 156 | " pixel_scaled[..., -1] *= -1.\n", 157 | " pixel_locations[..., -1] = (h - 1) - pixel_locations[..., -1]\n", 158 | "\n", 159 | " return pixel_locations, pixel_scaled\n", 160 | "\n", 161 | "# 6. 
mat_mul with intrinsic and then extrinsic gives you p_world (pixels in world) \n", 162 | "def image_points_to_world(image_points, camera_mat, world_mat, scale_mat=None,\n", 163 | " invert=False, negative_depth=True):\n", 164 | " ''' Transforms points on image plane to world coordinates.\n", 165 | " In contrast to transform_to_world, no depth value is needed as points on\n", 166 | " the image plane have a fixed depth of 1.\n", 167 | " Args:\n", 168 | " image_points (tensor): image points tensor of size B x N x 2\n", 169 | " camera_mat (tensor): camera matrix\n", 170 | " world_mat (tensor): world matrix\n", 171 | " scale_mat (tensor): scale matrix\n", 172 | " invert (bool): whether to invert matrices (default: False)\n", 173 | " '''\n", 174 | " batch_size, n_pts, dim = image_points.shape\n", 175 | " assert(dim == 2)\n", 176 | " d_image = torch.ones(batch_size, n_pts, 1)\n", 177 | " if negative_depth:\n", 178 | " d_image *= -1.\n", 179 | " return transform_to_world(image_points, d_image, camera_mat, world_mat,\n", 180 | " scale_mat, invert=invert)\n", 181 | "\n", 182 | "def transform_to_world(pixels, depth, camera_mat, world_mat, scale_mat=None,\n", 183 | " invert=True, use_absolute_depth=True):\n", 184 | " ''' Transforms pixel positions p with given depth value d to world coordinates.\n", 185 | " Args:\n", 186 | " pixels (tensor): pixel tensor of size B x N x 2\n", 187 | " depth (tensor): depth tensor of size B x N x 1\n", 188 | " camera_mat (tensor): camera matrix\n", 189 | " world_mat (tensor): world matrix\n", 190 | " scale_mat (tensor): scale matrix\n", 191 | " invert (bool): whether to invert matrices (default: true)\n", 192 | " '''\n", 193 | " assert(pixels.shape[-1] == 2)\n", 194 | "\n", 195 | " if scale_mat is None:\n", 196 | " scale_mat = torch.eye(4).unsqueeze(0).repeat(\n", 197 | " camera_mat.shape[0], 1, 1)\n", 198 | "\n", 199 | " # Convert to pytorch\n", 200 | " pixels, is_numpy = to_pytorch(pixels, True)\n", 201 | " depth = to_pytorch(depth)\n", 202 | " camera_mat = to_pytorch(camera_mat)\n", 203 | " world_mat = to_pytorch(world_mat)\n", 204 | " scale_mat = to_pytorch(scale_mat)\n", 205 | "\n", 206 | " # Invert camera matrices\n", 207 | " if invert:\n", 208 | " camera_mat = torch.inverse(camera_mat)\n", 209 | " world_mat = torch.inverse(world_mat)\n", 210 | " scale_mat = torch.inverse(scale_mat)\n", 211 | "\n", 212 | " # Transform pixels to homogen coordinates\n", 213 | " pixels = pixels.permute(0, 2, 1)\n", 214 | " pixels = torch.cat([pixels, torch.ones_like(pixels)], dim=1)\n", 215 | "\n", 216 | " # Project pixels into camera space\n", 217 | " if use_absolute_depth:\n", 218 | " pixels[:, :2] = pixels[:, :2] * depth.permute(0, 2, 1).abs()\n", 219 | " pixels[:, 2:3] = pixels[:, 2:3] * depth.permute(0, 2, 1)\n", 220 | " else:\n", 221 | " pixels[:, :3] = pixels[:, :3] * depth.permute(0, 2, 1)\n", 222 | " \n", 223 | " # Transform pixels to world space\n", 224 | " p_world = scale_mat @ world_mat @ camera_mat @ pixels\n", 225 | "\n", 226 | " # Transform p_world back to 3D coordinates\n", 227 | " p_world = p_world[:, :3].permute(0, 2, 1)\n", 228 | "\n", 229 | " if is_numpy:\n", 230 | " p_world = p_world.numpy()\n", 231 | " return p_world\n", 232 | "\n", 233 | "\n", 234 | "# 7. 
mat_mul zeros with intrinsic&extrinsic for camera pos (which we alread obtained as loc)\n", 235 | "def origin_to_world(n_points, camera_mat, world_mat, scale_mat=None,\n", 236 | " invert=False):\n", 237 | " ''' Transforms origin (camera location) to world coordinates.\n", 238 | " Args:\n", 239 | " n_points (int): how often the transformed origin is repeated in the\n", 240 | " form (batch_size, n_points, 3)\n", 241 | " camera_mat (tensor): camera matrix\n", 242 | " world_mat (tensor): world matrix\n", 243 | " scale_mat (tensor): scale matrix\n", 244 | " invert (bool): whether to invert the matrices (default: false)\n", 245 | " '''\n", 246 | " \n", 247 | " batch_size = camera_mat.shape[0]\n", 248 | " device = camera_mat.device\n", 249 | " # Create origin in homogen coordinates\n", 250 | " p = torch.zeros(batch_size, 4, n_points).to(device)\n", 251 | " p[:, -1] = 1.\n", 252 | "\n", 253 | " if scale_mat is None:\n", 254 | " scale_mat = torch.eye(4).unsqueeze(\n", 255 | " 0).repeat(batch_size, 1, 1).to(device)\n", 256 | "\n", 257 | " # Invert matrices\n", 258 | " if invert:\n", 259 | " camera_mat = torch.inverse(camera_mat)\n", 260 | " world_mat = torch.inverse(world_mat)\n", 261 | " scale_mat = torch.inverse(scale_mat)\n", 262 | " \n", 263 | " camera_mat = to_pytorch(camera_mat)\n", 264 | " world_mat = to_pytorch(world_mat)\n", 265 | " scale_mat = to_pytorch(scale_mat)\n", 266 | " \n", 267 | " # Apply transformation\n", 268 | " p_world = scale_mat @ world_mat @ camera_mat @ p\n", 269 | "\n", 270 | " # Transform points back to 3D coordinates\n", 271 | " p_world = p_world[:, :3].permute(0, 2, 1)\n", 272 | " return p_world" 273 | ] 274 | }, 275 | { 276 | "cell_type": "markdown", 277 | "id": "414ea1ff", 278 | "metadata": {}, 279 | "source": [ 280 | "# Sampling" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": 4, 286 | "id": "d62aa105", 287 | "metadata": {}, 288 | "outputs": [], 289 | "source": [ 290 | "# from Kai-46/nerfplusplus\n", 291 | "\n", 292 | "# 8. intersect sphere for distinguishing fg and bg\n", 293 | "def intersect_sphere(ray_o, ray_d):\n", 294 | " '''\n", 295 | " ray_o, ray_d: [..., 3]\n", 296 | " compute the depth of the intersection point between this ray and unit sphere\n", 297 | " '''\n", 298 | " # note: d1 becomes negative if this mid point is behind camera\n", 299 | " d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)\n", 300 | " p = ray_o + d1.unsqueeze(-1) * ray_d\n", 301 | " # consider the case where the ray does not intersect the sphere\n", 302 | " ray_d_cos = 1. / torch.norm(ray_d, dim=-1)\n", 303 | " p_norm_sq = torch.sum(p * p, dim=-1)\n", 304 | " if (p_norm_sq >= 1.).any():\n", 305 | " raise Exception('Not all your cameras are bounded by the unit sphere; please make sure the cameras are normalized properly!')\n", 306 | " d2 = torch.sqrt(1. - p_norm_sq) * ray_d_cos\n", 307 | "\n", 308 | " return d1 + d2\n", 309 | "\n", 310 | "# 9. inverse sphere sampling for bg\n", 311 | "def depth2pts_outside(ray_o, ray_d, depth):\n", 312 | " '''\n", 313 | " ray_o, ray_d: [..., 3]\n", 314 | " depth: [...]; inverse of distance to sphere origin\n", 315 | " '''\n", 316 | " # note: d1 becomes negative if this mid point is behind camera\n", 317 | " d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)\n", 318 | " p_mid = ray_o + d1.unsqueeze(-1) * ray_d\n", 319 | " p_mid_norm = torch.norm(p_mid, dim=-1)\n", 320 | " ray_d_cos = 1. / torch.norm(ray_d, dim=-1)\n", 321 | " d2 = torch.sqrt(1. 
- p_mid_norm * p_mid_norm) * ray_d_cos\n", 322 | " \n", 323 | " p_sphere = ray_o + (d1 + d2).unsqueeze(-1) * ray_d\n", 324 | "\n", 325 | " rot_axis = torch.cross(ray_o, p_sphere, dim=-1)\n", 326 | " rot_axis = rot_axis / torch.norm(rot_axis, dim=-1, keepdim=True)\n", 327 | " phi = torch.asin(p_mid_norm)\n", 328 | " theta = torch.asin(p_mid_norm * depth) # depth is inside [0, 1]\n", 329 | " rot_angle = (phi - theta).unsqueeze(-1) # [..., 1]\n", 330 | "\n", 331 | " # now rotate p_sphere\n", 332 | " # Rodrigues formula: https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula\n", 333 | " p_sphere_new = p_sphere * torch.cos(rot_angle) + \\\n", 334 | " torch.cross(rot_axis, p_sphere, dim=-1) * torch.sin(rot_angle) + \\\n", 335 | " rot_axis * torch.sum(rot_axis*p_sphere, dim=-1, keepdim=True) * (1.-torch.cos(rot_angle))\n", 336 | " p_sphere_new = p_sphere_new / torch.norm(p_sphere_new, dim=-1, keepdim=True)\n", 337 | " pts = torch.cat((p_sphere_new.squeeze(), depth.squeeze().unsqueeze(-1)), dim=-1) # (modified) added .squeeze()\n", 338 | "\n", 339 | " # now calculate conventional depth\n", 340 | " depth_real = 1. / (depth + TINY_NUMBER) * torch.cos(theta) * ray_d_cos + d1\n", 341 | " return pts, depth_real\n", 342 | "\n", 343 | "# 10. perturb sample z values (depth) for some randomness (stratified sampling)\n", 344 | "def perturb_samples(z_vals):\n", 345 | " # get intervals between samples\n", 346 | " mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])\n", 347 | " upper = torch.cat([mids, z_vals[..., -1:]], dim=-1)\n", 348 | " lower = torch.cat([z_vals[..., 0:1], mids], dim=-1)\n", 349 | " # uniform samples in those intervals\n", 350 | " t_rand = torch.rand_like(z_vals)\n", 351 | " z_vals = lower + (upper - lower) * t_rand # [N_rays, N_samples]\n", 352 | "\n", 353 | " return z_vals\n", 354 | "\n", 355 | "# 11. 
return fg (just uniform) and bg (uniform inverse sphere) samples\n", 356 | "def uniformsampling(ray_o, ray_d, min_depth, N_samples, device):\n", 357 | " \n", 358 | " dots_sh = list(ray_d.shape[:-1])\n", 359 | "\n", 360 | " # foreground depth\n", 361 | " fg_far_depth = intersect_sphere(ray_o, ray_d) # [...,]\n", 362 | " fg_near_depth = min_depth # [..., ]\n", 363 | " step = (fg_far_depth - fg_near_depth) / (N_samples - 1)\n", 364 | " fg_depth = torch.stack([fg_near_depth + i * step for i in range(N_samples)], dim=-1) # [..., N_samples]\n", 365 | " fg_depth = perturb_samples(fg_depth) # random perturbation during training\n", 366 | "\n", 367 | " # background depth\n", 368 | " bg_depth = torch.linspace(0., 1., N_samples).view(\n", 369 | " [1, ] * len(dots_sh) + [N_samples,]).expand(dots_sh + [N_samples,]).to(device)\n", 370 | " bg_depth = perturb_samples(bg_depth) # random perturbation during training\n", 371 | "\n", 372 | "\n", 373 | " fg_ray_o = ray_o.unsqueeze(-2).expand(dots_sh + [N_samples, 3])\n", 374 | " fg_ray_d = ray_d.unsqueeze(-2).expand(dots_sh + [N_samples, 3])\n", 375 | "\n", 376 | " bg_ray_o = ray_o.unsqueeze(-2).expand(dots_sh + [N_samples, 3])\n", 377 | " bg_ray_d = ray_d.unsqueeze(-2).expand(dots_sh + [N_samples, 3])\n", 378 | " \n", 379 | " # sampling foreground\n", 380 | " fg_pts = fg_ray_o + fg_depth.unsqueeze(-1) * fg_ray_d\n", 381 | "\n", 382 | " # sampling background\n", 383 | " bg_pts, bg_depth_real = depth2pts_outside(bg_ray_o, bg_ray_d, bg_depth)\n", 384 | "\n", 385 | " return fg_pts, bg_pts, bg_depth_real\n", 386 | "\n", 387 | "# 12 convert bg pts (x', y', z' 1/r) to real bg pts (x, y, z)\n", 388 | "def pts2realpts(bg_pts, bg_depth_real):\n", 389 | " return bg_pts[:, :, :3] * bg_depth_real.squeeze().unsqueeze(-1)" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": 5, 395 | "id": "856d052e", 396 | "metadata": {}, 397 | "outputs": [], 398 | "source": [ 399 | "def giraffe(u = 1,\n", 400 | " v = 0.5,\n", 401 | " r=2.713,\n", 402 | " depth_range=[0.5, 6.],\n", 403 | " n_ray_samples=16,\n", 404 | " resolution_vol = 4,\n", 405 | " batch_size = 1\n", 406 | " ):\n", 407 | "\n", 408 | " range_radius=[r, r]\n", 409 | " \n", 410 | " res = resolution_vol\n", 411 | " n_points = res * res\n", 412 | "\n", 413 | " # 1. get camera intrinsic \n", 414 | " camera_mat = get_camera_mat()\n", 415 | "\n", 416 | " # 2. get camera position with camera pose (theta & phi)\n", 417 | " loc = to_sphere(u, v)\n", 418 | " loc = torch.tensor(loc).float()\n", 419 | " radius = range_radius[0] + \\\n", 420 | " torch.rand(batch_size) * (range_radius[1] - range_radius[0])\n", 421 | " loc = loc * radius.unsqueeze(-1)\n", 422 | "\n", 423 | " # 3. get camera coordinate system assuming it points to the center of the sphere\n", 424 | " R = look_at(loc)\n", 425 | "\n", 426 | " # 4. The carmera coordinate is the rotational matrix and with camera loc, it is camera extrinsic\n", 427 | " RT = np.eye(4).reshape(1, 4, 4)\n", 428 | " RT[:, :3, :3] = R\n", 429 | " RT[:, :3, -1] = loc\n", 430 | " world_mat = RT\n", 431 | "\n", 432 | " # 5. arange 2d array of pixel coordinate and give depth of 1\n", 433 | " pixels = arange_pixels((res, res), 1, invert_y_axis=False)[1]\n", 434 | " pixels[..., -1] *= -1. # still dunno why this is here\n", 435 | "\n", 436 | " # 6. mat_mul with intrinsic and then extrinsic gives you p_world (pixels in world) \n", 437 | " pixels_world = image_points_to_world(pixels, camera_mat, world_mat)\n", 438 | " \n", 439 | " # 7. 
mat_mul zeros with intrinsic&extrinsic for camera pos (which we alread obtained as loc)\n", 440 | " camera_world = origin_to_world(n_points, camera_mat, world_mat)\n", 441 | "\n", 442 | " # 8. ray = pixel - camera origin (in world)\n", 443 | " ray_vector = pixels_world - camera_world\n", 444 | "\n", 445 | " # 9. depths from closest to furthest (0.5 ~ 6.0)\n", 446 | " di = depth_range[0] + \\\n", 447 | " torch.linspace(0., 1., steps=n_ray_samples).reshape(1, 1, -1) * (\n", 448 | " depth_range[1] - depth_range[0])\n", 449 | " di = di.repeat(batch_size, n_points, 1)\n", 450 | "\n", 451 | " # 10. calculate points\n", 452 | " p_i = camera_world.unsqueeze(-2).contiguous() + \\\n", 453 | " di.unsqueeze(-1).contiguous() * ray_vector.unsqueeze(-2).contiguous()\n", 454 | " \n", 455 | " return pixels_world, camera_world, world_mat, p_i" 456 | ] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "execution_count": 6, 461 | "id": "1a0c73e1", 462 | "metadata": {}, 463 | "outputs": [], 464 | "source": [ 465 | "def nerfpp(u = 1,\n", 466 | " v = 0.5,\n", 467 | " fov = 49.13,\n", 468 | " depth_range=[0.5, 6.],\n", 469 | " n_ray_samples=16,\n", 470 | " resolution_vol = 4,\n", 471 | " batch_size = 1,\n", 472 | " device = torch.device('cpu')\n", 473 | " ):\n", 474 | " \n", 475 | " r = 1\n", 476 | " range_radius=[r, r]\n", 477 | " \n", 478 | " res = resolution_vol\n", 479 | " n_points = res * res\n", 480 | "\n", 481 | " # 1. get camera intrinsic - fiddle around with fov this time\n", 482 | " camera_mat = get_camera_mat(fov=fov)\n", 483 | "\n", 484 | " # 2. get camera position with camera pose (theta & phi)\n", 485 | " loc = to_sphere(u, v)\n", 486 | " loc = torch.tensor(loc).float()\n", 487 | " \n", 488 | " radius = range_radius[0] + \\\n", 489 | " torch.rand(batch_size) * (range_radius[1] - range_radius[0])\n", 490 | " \n", 491 | " loc = loc * radius.unsqueeze(-1)\n", 492 | "\n", 493 | " # 3. get camera coordinate system assuming it points to the center of the sphere\n", 494 | " R = look_at(loc)\n", 495 | "\n", 496 | " # 4. The carmera coordinate is the rotational matrix and with camera loc, it is camera extrinsic\n", 497 | " RT = np.eye(4).reshape(1, 4, 4)\n", 498 | " RT[:, :3, :3] = R\n", 499 | " RT[:, :3, -1] = loc\n", 500 | " world_mat = RT\n", 501 | "\n", 502 | " # 5. arange 2d array of pixel coordinate and give depth of 1\n", 503 | " pixels = arange_pixels((res, res), 1, invert_y_axis=False)[1]\n", 504 | " pixels[..., -1] *= -1. # still dunno why this is here\n", 505 | "\n", 506 | " # 6. mat_mul with intrinsic and then extrinsic gives you p_world (pixels in world) \n", 507 | " pixels_world = image_points_to_world(pixels, camera_mat, world_mat)\n", 508 | "\n", 509 | " # 7. mat_mul zeros with intrinsic&extrinsic for camera pos (which we alread obtained as loc)\n", 510 | " camera_world = origin_to_world(n_points, camera_mat, world_mat)\n", 511 | "\n", 512 | " # 8. ray = pixel - camera origin (in world)\n", 513 | " ray_vector = pixels_world - camera_world\n", 514 | "\n", 515 | " # 9. sample fg and bg points according to nerfpp (uniform and inverse sphere)\n", 516 | " fg_pts, bg_pts, bg_depth_real = uniformsampling(ray_o=camera_world, ray_d=ray_vector, min_depth=depth_range[0], N_samples=n_ray_samples, device=device)\n", 517 | " \n", 518 | " #10. 
convert bg pts (x', y', z' 1/r) to real bg pts (x, y, z)\n", 519 | " bg_pts_real = pts2realpts(bg_pts, bg_depth_real)\n", 520 | " \n", 521 | " return pixels_world, camera_world, world_mat, fg_pts, bg_pts_real" 522 | ] 523 | }, 524 | { 525 | "cell_type": "markdown", 526 | "id": "2fd19406", 527 | "metadata": {}, 528 | "source": [ 529 | "# Visualization" 530 | ] 531 | }, 532 | { 533 | "cell_type": "code", 534 | "execution_count": 7, 535 | "id": "b9dd6b14", 536 | "metadata": {}, 537 | "outputs": [], 538 | "source": [ 539 | "# draw sphere with radius r \n", 540 | "# also draw contours and vertical lines\n", 541 | "def draw_sphere(r, sphere_colorscale, sphere_opacity):\n", 542 | " # sphere\n", 543 | " u = np.linspace(0, 2 * np.pi, 100)\n", 544 | " v = np.linspace(0, np.pi, 100)\n", 545 | " x = r * np.outer(np.cos(u), np.sin(v))\n", 546 | " y = r * np.outer(np.sin(u), np.sin(v))\n", 547 | " z = r * np.outer(np.ones(np.size(u)), np.cos(v))\n", 548 | " \n", 549 | " # vertical lines on sphere\n", 550 | " u2 = np.linspace(0, 2 * np.pi, 20)\n", 551 | " x2 = r * np.outer(np.cos(u2), np.sin(v))\n", 552 | " y2 = r * np.outer(np.sin(u2), np.sin(v))\n", 553 | " z2 = r * np.outer(np.ones(np.size(u2)), np.cos(v))\n", 554 | " \n", 555 | " # create sphere and draw sphere with contours\n", 556 | " fig = go.Figure(data=[go.Surface(x=x, y=y, z=z, \n", 557 | " colorscale=sphere_colorscale, opacity=sphere_opacity,\n", 558 | " contours = {\n", 559 | " 'z' : {'show' : True, 'start' : -r,\n", 560 | " 'end' : r, 'size' : r/10,\n", 561 | " 'color' : 'white',\n", 562 | " 'width' : 1}\n", 563 | " }\n", 564 | " , showscale=False)])\n", 565 | " \n", 566 | " # vertical lines on sphere\n", 567 | " for i in range(len(u2)):\n", 568 | " fig.add_scatter3d(x=x2[i], y=y2[i], z=z2[i], \n", 569 | " line=dict(\n", 570 | " color='white',\n", 571 | " width=1\n", 572 | " ),\n", 573 | " mode='lines',\n", 574 | " showlegend=False)\n", 575 | " \n", 576 | " return fig\n", 577 | "\n", 578 | "# draw xyplane\n", 579 | "def draw_XYplane(fig, xy_plane_colorscale, xy_plane_opacity, x_range = [-2, 2], y_range = [-2, 2]):\n", 580 | " x3 = np.linspace(x_range[0], x_range[1], 100)\n", 581 | " y3 = np.linspace(y_range[0], y_range[1], 100)\n", 582 | " z3 = np.zeros(shape=(100,100))\n", 583 | " \n", 584 | " fig.add_surface(x=x3, y=y3, z=z3,\n", 585 | " colorscale =xy_plane_colorscale, opacity=xy_plane_opacity,\n", 586 | " showscale=False\n", 587 | " )\n", 588 | " \n", 589 | " return fig\n", 590 | " \n", 591 | "\n", 592 | "def draw_XYZworld(fig, world_axis_size):\n", 593 | " # x, y, z positive direction (world)\n", 594 | " X_axis = [0, world_axis_size]\n", 595 | " X_text = [None, \"X\"]\n", 596 | " X0 = [0, 0]\n", 597 | " Y_axis = [0, world_axis_size]\n", 598 | " Y_text = [None, \"Y\"]\n", 599 | " Y0 = [0, 0]\n", 600 | " Z_axis = [0, world_axis_size]\n", 601 | " Z_text = [None, \"Z\"]\n", 602 | " Z0 = [0, 0]\n", 603 | " \n", 604 | " fig.add_scatter3d(x=X_axis, y=Y0, z=Z0, \n", 605 | " line=dict(\n", 606 | " color='red',\n", 607 | " width=10\n", 608 | " ),\n", 609 | " mode='lines+text',\n", 610 | " text=X_text,\n", 611 | " textposition='top center',\n", 612 | " textfont=dict(\n", 613 | " color=\"red\",\n", 614 | " size=18\n", 615 | " ),\n", 616 | " showlegend=False)\n", 617 | "\n", 618 | " fig.add_scatter3d(x=X0, y=Y_axis, z=Z0, \n", 619 | " line=dict(\n", 620 | " color='green',\n", 621 | " width=10\n", 622 | " ),\n", 623 | " mode='lines+text',\n", 624 | " text=Y_text,\n", 625 | " textposition='top center',\n", 626 | " textfont=dict(\n", 627 | " 
color=\"green\",\n", 628 | " size=18\n", 629 | " ),\n", 630 | " showlegend=False)\n", 631 | "\n", 632 | " fig.add_scatter3d(x=X0, y=Y0, z=Z_axis, \n", 633 | " line=dict(\n", 634 | " color='blue',\n", 635 | " width=10\n", 636 | " ),\n", 637 | " mode='lines+text',\n", 638 | " text=Z_text,\n", 639 | " textposition='top center',\n", 640 | " textfont=dict(\n", 641 | " color=\"blue\",\n", 642 | " size=18\n", 643 | " ),\n", 644 | " showlegend=False)\n", 645 | " \n", 646 | " return fig\n", 647 | "\n", 648 | "# draw cam and cam coordinate system\n", 649 | "def draw_cam_init(fig, world_mat, camera_axis_size, camera_color):\n", 650 | " # camera at init\n", 651 | "\n", 652 | " Xc = [world_mat[0, : ,3][0]]\n", 653 | " Yc = [world_mat[0, : ,3][1]]\n", 654 | " Zc = [world_mat[0, : ,3][2]]\n", 655 | " text_c = [\"Camera\"]\n", 656 | "\n", 657 | " # camera axis\n", 658 | " Xc_Xaxis = Xc + [world_mat[0, : ,0][0]*camera_axis_size+Xc[0]]\n", 659 | " Yc_Xaxis = Yc + [world_mat[0, : ,0][1]*camera_axis_size+Yc[0]]\n", 660 | " Zc_Xaxis = Zc + [world_mat[0, : ,0][2]*camera_axis_size+Zc[0]]\n", 661 | " text_Xaxis = [None, \"Xc\"]\n", 662 | " \n", 663 | " # -z in world perspective\n", 664 | " Xc_Yaxis = Xc + [world_mat[0, : ,1][0]*camera_axis_size+Xc[0]]\n", 665 | " Yc_Yaxis = Yc + [world_mat[0, : ,1][1]*camera_axis_size+Yc[0]]\n", 666 | " Zc_Yaxis = Zc + [world_mat[0, : ,1][2]*camera_axis_size+Zc[0]]\n", 667 | " text_Yaxis = [None, \"Yc\"]\n", 668 | "\n", 669 | " # y in world perspective\n", 670 | " Xc_Zaxis = Xc + [world_mat[0, : ,2][0]*camera_axis_size+Xc[0]]\n", 671 | " Yc_Zaxis = Yc + [world_mat[0, : ,2][1]*camera_axis_size+Yc[0]]\n", 672 | " Zc_Zaxis = Zc + [world_mat[0, : ,2][2]*camera_axis_size+Zc[0]]\n", 673 | " text_Zaxis = [None, \"Zc\"]\n", 674 | " \n", 675 | " # cam pos\n", 676 | " fig.add_scatter3d(x=Xc, y=Yc, z=Zc, \n", 677 | " mode='markers',\n", 678 | " marker=dict(\n", 679 | " color=camera_color,\n", 680 | " size=4,\n", 681 | " sizemode='diameter'\n", 682 | " ),\n", 683 | " showlegend=False)\n", 684 | "\n", 685 | " # camera axis\n", 686 | " fig.add_scatter3d(x=Xc_Xaxis, y=Yc_Xaxis, z=Zc_Xaxis, \n", 687 | " line=dict(\n", 688 | " color='red',\n", 689 | " width=10\n", 690 | " ),\n", 691 | " mode='lines+text',\n", 692 | " text=text_Xaxis,\n", 693 | " textposition='top center',\n", 694 | " textfont=dict(\n", 695 | " color=\"red\",\n", 696 | " size=18\n", 697 | " ),\n", 698 | " showlegend=False)\n", 699 | "\n", 700 | " fig.add_scatter3d(x=Xc_Yaxis, y=Yc_Yaxis, z=Zc_Yaxis, \n", 701 | " line=dict(\n", 702 | " color='green',\n", 703 | " width=10\n", 704 | " ),\n", 705 | " mode='lines+text',\n", 706 | " text=text_Yaxis,\n", 707 | " textposition='top center',\n", 708 | " textfont=dict(\n", 709 | " color=\"green\",\n", 710 | " size=18\n", 711 | " ),\n", 712 | " showlegend=False)\n", 713 | "\n", 714 | " fig.add_scatter3d(x=Xc_Zaxis, y=Yc_Zaxis, z=Zc_Zaxis, \n", 715 | " line=dict(\n", 716 | " color='blue',\n", 717 | " width=10\n", 718 | " ),\n", 719 | " mode='lines+text',\n", 720 | " text=text_Zaxis,\n", 721 | " textposition='top center',\n", 722 | " textfont=dict(\n", 723 | " color=\"blue\",\n", 724 | " size=18\n", 725 | " ),\n", 726 | " showlegend=False)\n", 727 | " \n", 728 | " return fig\n", 729 | "\n", 730 | "# draw all rays\n", 731 | "def draw_all_rays(fig, p_i, ray_color):\n", 732 | " for i in range(p_i.shape[1]):\n", 733 | " Xray = p_i[0, i, :, 0]\n", 734 | " Yray = p_i[0, i, :, 1]\n", 735 | " Zray = p_i[0, i, :, 2]\n", 736 | " \n", 737 | " fig.add_scatter3d(x=Xray, y=Yray, z=Zray, \n", 738 | " 
line=dict(\n", 739 | " color=ray_color,\n", 740 | " width=5\n", 741 | " ),\n", 742 | " mode='lines',\n", 743 | " showlegend=False)\n", 744 | " \n", 745 | " return fig\n", 746 | "\n", 747 | "# draw all rays\n", 748 | "def draw_all_rays_with_marker(fig, p_i, marker_size, ray_color):\n", 749 | " \n", 750 | " # convert colorscale string to px.colors.seqeuntial\n", 751 | " # default color is set to Viridis in case of mismatch\n", 752 | " c = px.colors.sequential.Viridis\n", 753 | "\n", 754 | " for c_name in [ray_color, ray_color.capitalize()]:\n", 755 | " try:\n", 756 | " c = getattr(px.colors.sequential, c_name)\n", 757 | " except:\n", 758 | " continue\n", 759 | " \n", 760 | " for i in range(p_i.shape[1]):\n", 761 | " Xray = p_i[0, i, :, 0]\n", 762 | " Yray = p_i[0, i, :, 1]\n", 763 | " Zray = p_i[0, i, :, 2]\n", 764 | " \n", 765 | " fig.add_scatter3d(x=Xray, y=Yray, z=Zray, \n", 766 | " \n", 767 | " marker=dict(\n", 768 | "# color=np.arange(len(Xray)),\n", 769 | " color=c,\n", 770 | "# colorscale='Viridis',\n", 771 | " size=marker_size\n", 772 | " ),\n", 773 | " \n", 774 | " line=dict(\n", 775 | "# color=np.arange(len(Xray)),\n", 776 | " color=c,\n", 777 | "# colorscale='Viridis',\n", 778 | " width=3\n", 779 | " ),\n", 780 | " mode=\"lines+markers\",\n", 781 | " showlegend=False)\n", 782 | " \n", 783 | " return fig\n", 784 | "\n", 785 | "# draw near&far frustrum with rays connecting the corners (changed for nerfpp)\n", 786 | "def draw_ray_frus(fig, p_i, frustrum_color, frustrum_opacity, at=[0, -1]):\n", 787 | " \n", 788 | " for i in at:\n", 789 | "# Xfrus = p_i[0, :, i, 0][[0,1,2,3,7,11,15,14,13,12,8,4,0]]\n", 790 | "# Yfrus = p_i[0, :, i, 1][[0,1,2,3,7,11,15,14,13,12,8,4,0]]\n", 791 | "# Zfrus = p_i[0, :, i, 2][[0,1,2,3,7,11,15,14,13,12,8,4,0]]\n", 792 | "\n", 793 | " Xfrus = p_i[0, :, i, 0]\n", 794 | " Yfrus = p_i[0, :, i, 1]\n", 795 | " Zfrus = p_i[0, :, i, 2]\n", 796 | " \n", 797 | " fig.add_scatter3d(x=Xfrus, y=Yfrus, z=Zfrus, \n", 798 | " line=dict(\n", 799 | " color=frustrum_color,\n", 800 | " width=5\n", 801 | " ),\n", 802 | " mode='lines',\n", 803 | " surfaceaxis=0,\n", 804 | " surfacecolor=frustrum_color,\n", 805 | " opacity=frustrum_opacity,\n", 806 | " showlegend=False)\n", 807 | " \n", 808 | " return fig\n", 809 | "\n", 810 | "# draw foreground sample points, ray and frustrum\n", 811 | "def draw_foreground(fig, fg_pts, fg_color, marker_size, at=[0, -1]):\n", 812 | " fig = draw_all_rays_with_marker(fig, fg_pts, marker_size, fg_color)\n", 813 | " \n", 814 | " return fig\n", 815 | "\n", 816 | "# draw background sample points, ray and frustrum\n", 817 | "def draw_background(fig, bg_pts, bg_color, marker_size, at=[0, -1]):\n", 818 | " fig = draw_all_rays_with_marker(fig, bg_pts, marker_size, bg_color)\n", 819 | " \n", 820 | " return fig" 821 | ] 822 | }, 823 | { 824 | "cell_type": "code", 825 | "execution_count": 8, 826 | "id": "f12fecdd", 827 | "metadata": {}, 828 | "outputs": [ 829 | { 830 | "name": "stderr", 831 | "output_type": "stream", 832 | "text": [ 833 | "C:\\Users\\laphi\\AppData\\Local\\Temp/ipykernel_24504/3891967274.py:2: UserWarning: \n", 834 | "The dash_core_components package is deprecated. Please replace\n", 835 | "`import dash_core_components as dcc` with `from dash import dcc`\n", 836 | " import dash_core_components as dcc\n", 837 | "C:\\Users\\laphi\\AppData\\Local\\Temp/ipykernel_24504/3891967274.py:3: UserWarning: \n", 838 | "The dash_html_components package is deprecated. 
Please replace\n", 839 | "`import dash_html_components as html` with `from dash import html`\n", 840 | " import dash_html_components as html\n" 841 | ] 842 | } 843 | ], 844 | "source": [ 845 | "from jupyter_dash import JupyterDash\n", 846 | "import dash_core_components as dcc\n", 847 | "import dash_html_components as html\n", 848 | "from dash.dependencies import Input, Output\n", 849 | "\n", 850 | "# for colors\n", 851 | "import matplotlib.colors as mcolors" 852 | ] 853 | }, 854 | { 855 | "cell_type": "code", 856 | "execution_count": 10, 857 | "id": "e5e28428", 858 | "metadata": { 859 | "scrolled": false 860 | }, 861 | "outputs": [ 862 | { 863 | "name": "stderr", 864 | "output_type": "stream", 865 | "text": [ 866 | "c:\\users\\laphi\\appdata\\local\\programs\\python\\python38\\lib\\site-packages\\jupyter_dash\\jupyter_app.py:139: UserWarning:\n", 867 | "\n", 868 | "The 'environ['werkzeug.server.shutdown']' function is deprecated and will be removed in Werkzeug 2.1.\n", 869 | "\n" 870 | ] 871 | }, 872 | { 873 | "data": { 874 | "text/html": [ 875 | "\n", 876 | " \n", 883 | " " 884 | ], 885 | "text/plain": [ 886 | "" 887 | ] 888 | }, 889 | "metadata": {}, 890 | "output_type": "display_data" 891 | } 892 | ], 893 | "source": [ 894 | "app = JupyterDash(__name__)\n", 895 | "\n", 896 | "app.layout = html.Div([\n", 897 | " html.H1(\"Nerfplusplus ray sampling visualization\"),\n", 898 | " dcc.Graph(id='graph'),\n", 899 | " \n", 900 | " html.Div([\n", 901 | " html.Div([\n", 902 | " \n", 903 | " # changes to setting and ray casted \n", 904 | " html.Label([ \"u (theta)\",\n", 905 | " dcc.Slider(\n", 906 | " id='u-slider', \n", 907 | " min=0, max=1,\n", 908 | " value=0.00,\n", 909 | " marks={str(val) : str(val) for val in [0.00, 0.25, 0.50, 0.75]},\n", 910 | " step=0.01, tooltip = { 'always_visible': True }\n", 911 | " ), ]),\n", 912 | " html.Label([ \"v (phi)\",\n", 913 | " dcc.Slider(\n", 914 | " id='v-slider', \n", 915 | " min=0, max=1,\n", 916 | " value=0.25,\n", 917 | " marks={str(val) : str(val) for val in [0.00, 0.25, 0.50, 0.75]},\n", 918 | " step=0.01, tooltip = { 'always_visible': True }\n", 919 | " ), ]),\n", 920 | " html.Label([ \"fov (field-of-view))\",\n", 921 | " dcc.Slider(\n", 922 | " id='fov-slider', \n", 923 | " min=0, max=100,\n", 924 | " value=50,\n", 925 | " marks={str(val) : str(val) for val in [0, 20, 40, 60, 80, 100]},\n", 926 | " step=5, tooltip = { 'always_visible': True }\n", 927 | " ), ]),\n", 928 | " html.Label([ \"foreground near depth\",\n", 929 | " dcc.Slider(\n", 930 | " id='foreground-near-depth-slider', \n", 931 | " min=0, max=2,\n", 932 | " value=0.5,\n", 933 | " marks={f\"{val:.1f}\" : f\"{val:.1f}\" for val in [0.1 * i for i in range(21)]},\n", 934 | " step=0.1, tooltip = { 'always_visible': True }\n", 935 | " ), ])\n", 936 | " ], style = {'width' : '48%', 'display' : 'inline-block'}),\n", 937 | " \n", 938 | " html.Div([\n", 939 | " # changes to visual appearance\n", 940 | " \n", 941 | " # axis scale\n", 942 | " html.Div([\n", 943 | " html.Label([ \"world axis size\",\n", 944 | " html.Div([\n", 945 | " dcc.Input(id='world-axis-size-input',\n", 946 | " value=1.5,\n", 947 | " type='number', style={'width': '50%'}\n", 948 | " )\n", 949 | " ]),\n", 950 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n", 951 | " html.Label([ \"camera axis size\",\n", 952 | " html.Div([\n", 953 | " dcc.Input(id='camera-axis-size-input',\n", 954 | " value=0.3,\n", 955 | " type='number', style={'width': '50%'}\n", 956 | " )\n", 957 | " ]),\n", 958 | " ], 
style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n", 959 | " html.Label([ \"sample marker size\",\n", 960 | " html.Div([\n", 961 | " dcc.Input(id='sample-marker-size-input',\n", 962 | " value=2,\n", 963 | " type='number', style={'width': '50%'}\n", 964 | " )\n", 965 | " ]),\n", 966 | " ], style = {'width' : '32%', 'float' : 'left', 'display' : 'inline-block'}),\n", 967 | " ]),\n", 968 | " \n", 969 | " # opacity \n", 970 | " html.Div([\n", 971 | " html.Label([ \"sphere opacity\",\n", 972 | " html.Div([\n", 973 | " dcc.Input(id='sphere-opacity-input',\n", 974 | " value=0.2,\n", 975 | " type='number', style={'width': '50%'}\n", 976 | " )\n", 977 | " ])\n", 978 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n", 979 | " \n", 980 | " html.Label([ \"xy-plane opacity\", \n", 981 | " html.Div([\n", 982 | " dcc.Input(id='xy-plane-opacity-input',\n", 983 | " value=0.8,\n", 984 | " type='number', style={'width': '50%'}\n", 985 | " )\n", 986 | " ])\n", 987 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n", 988 | " \n", 989 | " ]),\n", 990 | " \n", 991 | " # color\n", 992 | " html.Div([\n", 993 | " html.Label([ \"camera color\",\n", 994 | " html.Div([\n", 995 | " dcc.Dropdown(id='camera-color-input',\n", 996 | " clearable=False,\n", 997 | " value='yellow',\n", 998 | " options=[\n", 999 | " {'label': c, 'value': c}\n", 1000 | " for (c, _) in mcolors.CSS4_COLORS.items()\n", 1001 | " ], style={'width': '80%'}\n", 1002 | " )\n", 1003 | " ])\n", 1004 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n", 1005 | " \n", 1006 | " html.Label([ \"foreground color\",\n", 1007 | " html.Div([\n", 1008 | " dcc.Dropdown(id='fg-color-input',\n", 1009 | " clearable=False,\n", 1010 | " value='plotly3',\n", 1011 | " options=[\n", 1012 | " {'label': c, 'value': c}\n", 1013 | " for c in px.colors.named_colorscales()\n", 1014 | " ], style={'width': '80%'}\n", 1015 | " )\n", 1016 | " ])\n", 1017 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n", 1018 | " \n", 1019 | " html.Label([ \"background color\",\n", 1020 | " html.Div([\n", 1021 | " dcc.Dropdown(id='bg-color-input',\n", 1022 | " clearable=False,\n", 1023 | " value='plotly3',\n", 1024 | " options=[\n", 1025 | " {'label': c, 'value': c}\n", 1026 | " for c in px.colors.named_colorscales()\n", 1027 | " ], style={'width': '80%'}\n", 1028 | " )\n", 1029 | " ])\n", 1030 | " ], style = {'width' : '32%', 'float' : 'left', 'display' : 'inline-block'}),\n", 1031 | " ]),\n", 1032 | " \n", 1033 | " # colorscale\n", 1034 | " html.Div([\n", 1035 | " html.Label([ \"sphere colorscale\",\n", 1036 | " html.Div([\n", 1037 | " dcc.Dropdown(id='sphere-colorscale-input',\n", 1038 | " clearable=False,\n", 1039 | " value='greys',\n", 1040 | " options=[\n", 1041 | " {'label': c, 'value': c}\n", 1042 | " for c in px.colors.named_colorscales()\n", 1043 | " ], style={'width': '80%'}\n", 1044 | " )\n", 1045 | " ])\n", 1046 | " ], style = {'width' : '32%', 'float' : 'left', 'display' : 'inline-block'}),\n", 1047 | " \n", 1048 | " html.Label([ \"xy-plane colorscale\",\n", 1049 | " html.Div([\n", 1050 | " dcc.Dropdown(id='xy-plane-colorscale-input',\n", 1051 | " clearable=False,\n", 1052 | " value='greys',\n", 1053 | " options=[\n", 1054 | " {'label': c, 'value': c}\n", 1055 | " for c in px.colors.named_colorscales()\n", 1056 | " ], style={'width': '80%'}\n", 1057 | " )\n", 1058 | " ])\n", 1059 | " ], style = {'width' : '34%', 'float' : 'left', 
'display' : 'inline-block'}),\n", 1060 | " \n", 1061 | " html.Label([ \" \",\n", 1062 | " html.Div([\n", 1063 | " dcc.Checklist(id='show-background-checklist',\n", 1064 | " \n", 1065 | " options=[\n", 1066 | " {'label': 'show background', 'value': 'show_background'}\n", 1067 | " ], style={'width': '80%'},\n", 1068 | " value=[],\n", 1069 | " )\n", 1070 | " ])\n", 1071 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n", 1072 | " \n", 1073 | " ]),\n", 1074 | " \n", 1075 | " \n", 1076 | " \n", 1077 | " ], style = {'width' : '48%', 'float' : 'right', 'display' : 'inline-block'}),\n", 1078 | " \n", 1079 | " ]),\n", 1080 | " \n", 1081 | "])\n", 1082 | "\n", 1083 | "@app.callback(\n", 1084 | " Output('graph', 'figure'),\n", 1085 | " Input(\"u-slider\", \"value\"),\n", 1086 | " Input(\"v-slider\", \"value\"),\n", 1087 | " \n", 1088 | " Input(\"fov-slider\", \"value\"),\n", 1089 | " Input(\"foreground-near-depth-slider\", \"value\"),\n", 1090 | " \n", 1091 | " Input(\"world-axis-size-input\", \"value\"),\n", 1092 | " Input(\"camera-axis-size-input\", \"value\"),\n", 1093 | " Input(\"sample-marker-size-input\", \"value\"),\n", 1094 | " \n", 1095 | " Input(\"camera-color-input\", \"value\"),\n", 1096 | " Input(\"fg-color-input\", \"value\"),\n", 1097 | " Input(\"bg-color-input\", \"value\"),\n", 1098 | " \n", 1099 | " Input('sphere-colorscale-input', \"value\"),\n", 1100 | " Input('xy-plane-colorscale-input', \"value\"),\n", 1101 | " Input('show-background-checklist', \"value\"),\n", 1102 | " \n", 1103 | " Input(\"sphere-opacity-input\", \"value\"),\n", 1104 | " Input(\"xy-plane-opacity-input\", \"value\"),\n", 1105 | ")\n", 1106 | "\n", 1107 | "def update_figure(u, v, \n", 1108 | " fov, foreground_near_depth,\n", 1109 | " world_axis_size, camera_axis_size, sample_marker_size,\n", 1110 | " camera_color, fg_color, bg_color,\n", 1111 | " sphere_colorscale, xy_plane_colorscale, show_background,\n", 1112 | " sphere_opacity, xy_plane_opacity \n", 1113 | " ):\n", 1114 | " \n", 1115 | " depth_range = [foreground_near_depth, 2]\n", 1116 | " \n", 1117 | " # sphere\n", 1118 | " fig = draw_sphere(r=1, sphere_colorscale=sphere_colorscale, sphere_opacity=sphere_opacity)\n", 1119 | "\n", 1120 | " # change figure size\n", 1121 | "# fig.update_layout(autosize=False, width = 500, height=500)\n", 1122 | "\n", 1123 | " # draw axes in proportion to the proportion of their ranges\n", 1124 | " fig.update_layout(scene_aspectmode='data')\n", 1125 | "\n", 1126 | " # xy plane\n", 1127 | " fig = draw_XYplane(fig, xy_plane_colorscale, xy_plane_opacity,\n", 1128 | " x_range=[-depth_range[1], depth_range[1]], y_range=[-depth_range[1], depth_range[1]])\n", 1129 | "\n", 1130 | " # show world coordinate system (X, Y, Z positive direction)\n", 1131 | " fig = draw_XYZworld(fig, world_axis_size=world_axis_size)\n", 1132 | "\n", 1133 | " pixels_world, camera_world, world_mat, fg_pts, bg_pts = nerfpp(u=u, v=v, fov=fov, depth_range=depth_range)\n", 1134 | "\n", 1135 | " # draw camera at init (with its cooridnate system)\n", 1136 | " fig = draw_cam_init(fig, world_mat, \n", 1137 | " camera_axis_size=camera_axis_size, camera_color=camera_color)\n", 1138 | "\n", 1139 | " # draw foreground and background sample point, ray, and frustrum\n", 1140 | " fig = draw_foreground(fig, fg_pts, fg_color, sample_marker_size, at=[0, -1])\n", 1141 | " \n", 1142 | " if show_background:\n", 1143 | " fig = draw_background(fig, bg_pts.unsqueeze(0), bg_color, sample_marker_size, at=[0, -1])\n", 1144 | " \n", 1145 | " return 
fig\n", 1146 | "\n", 1147 | "app.run_server(mode='inline')" 1148 | ] 1149 | } 1150 | ], 1151 | "metadata": { 1152 | "kernelspec": { 1153 | "display_name": "Python 3 (ipykernel)", 1154 | "language": "python", 1155 | "name": "python3" 1156 | }, 1157 | "language_info": { 1158 | "codemirror_mode": { 1159 | "name": "ipython", 1160 | "version": 3 1161 | }, 1162 | "file_extension": ".py", 1163 | "mimetype": "text/x-python", 1164 | "name": "python", 1165 | "nbconvert_exporter": "python", 1166 | "pygments_lexer": "ipython3", 1167 | "version": "3.8.8" 1168 | } 1169 | }, 1170 | "nbformat": 4, 1171 | "nbformat_minor": 5 1172 | } 1173 | -------------------------------------------------------------------------------- /NerfppVisualizationDash.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "a6fca250", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import plotly.graph_objects as go\n", 11 | "import numpy as np\n", 12 | "import torch\n", 13 | "import plotly.express as px" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "id": "de77ce6c", 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "HUGE_NUMBER = 1e10\n", 24 | "TINY_NUMBER = 1e-6 " 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "id": "541c3686", 30 | "metadata": {}, 31 | "source": [ 32 | "# Ray Casting" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 3, 38 | "id": "27ee36ff", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "# from autonomousvision/giraffe\n", 43 | "\n", 44 | "# 0.\n", 45 | "def to_pytorch(tensor, return_type=False):\n", 46 | " ''' Converts input tensor to pytorch.\n", 47 | " Args:\n", 48 | " tensor (tensor): Numpy or Pytorch tensor\n", 49 | " return_type (bool): whether to return input type\n", 50 | " '''\n", 51 | " is_numpy = False\n", 52 | " if type(tensor) == np.ndarray:\n", 53 | " tensor = torch.from_numpy(tensor).float()\n", 54 | " is_numpy = True\n", 55 | " tensor = tensor.clone()\n", 56 | " if return_type:\n", 57 | " return tensor, is_numpy\n", 58 | " return tensor\n", 59 | "\n", 60 | "# 1. get camera intrinsic\n", 61 | "def get_camera_mat(fov=49.13, invert=True):\n", 62 | " # fov = 2 * arctan( sensor / (2 * focal))\n", 63 | " # focal = (sensor / 2) * 1 / (tan(0.5 * fov))\n", 64 | " # in our case, sensor = 2 as pixels are in [-1, 1]\n", 65 | " focal = 1. / np.tan(0.5 * fov * np.pi/180.)\n", 66 | " focal = focal.astype(np.float32)\n", 67 | " mat = torch.tensor([\n", 68 | " [focal, 0., 0., 0.],\n", 69 | " [0., focal, 0., 0.],\n", 70 | " [0., 0., 1, 0.],\n", 71 | " [0., 0., 0., 1.]\n", 72 | " ]).reshape(1, 4, 4)\n", 73 | "\n", 74 | " if invert:\n", 75 | " mat = torch.inverse(mat)\n", 76 | " return mat\n", 77 | "\n", 78 | "# 2. get camera position with camera pose (theta & phi)\n", 79 | "def to_sphere(u, v):\n", 80 | " theta = 2 * np.pi * u\n", 81 | " phi = np.arccos(1 - 2 * v)\n", 82 | " cx = np.sin(phi) * np.cos(theta)\n", 83 | " cy = np.sin(phi) * np.sin(theta)\n", 84 | " cz = np.cos(phi)\n", 85 | " return np.stack([cx, cy, cz], axis=-1)\n", 86 | "\n", 87 | "# 3. 
get camera coordinate system assuming it points to the center of the sphere\n", 88 | "def look_at(eye, at=np.array([0, 0, 0]), up=np.array([0, 0, 1]), eps=1e-5,\n", 89 | " to_pytorch=True):\n", 90 | " at = at.astype(float).reshape(1, 3)\n", 91 | " up = up.astype(float).reshape(1, 3)\n", 92 | " eye = eye.reshape(-1, 3)\n", 93 | " up = up.repeat(eye.shape[0] // up.shape[0], axis=0)\n", 94 | " eps = np.array([eps]).reshape(1, 1).repeat(up.shape[0], axis=0)\n", 95 | "\n", 96 | " z_axis = eye - at\n", 97 | " z_axis /= np.max(np.stack([np.linalg.norm(z_axis,\n", 98 | " axis=1, keepdims=True), eps]))\n", 99 | "\n", 100 | " x_axis = np.cross(up, z_axis)\n", 101 | " x_axis /= np.max(np.stack([np.linalg.norm(x_axis,\n", 102 | " axis=1, keepdims=True), eps]))\n", 103 | "\n", 104 | " y_axis = np.cross(z_axis, x_axis)\n", 105 | " y_axis /= np.max(np.stack([np.linalg.norm(y_axis,\n", 106 | " axis=1, keepdims=True), eps]))\n", 107 | "\n", 108 | " r_mat = np.concatenate(\n", 109 | " (x_axis.reshape(-1, 3, 1), y_axis.reshape(-1, 3, 1), z_axis.reshape(\n", 110 | " -1, 3, 1)), axis=2)\n", 111 | "\n", 112 | " if to_pytorch:\n", 113 | " r_mat = torch.tensor(r_mat).float()\n", 114 | "\n", 115 | " return r_mat\n", 116 | "\n", 117 | "# 5. arange 2d array of pixel coordinate and give depth of 1\n", 118 | "def arange_pixels(resolution=(128, 128), batch_size=1, image_range=(-1., 1.),\n", 119 | " subsample_to=None, invert_y_axis=False):\n", 120 | " ''' Arranges pixels for given resolution in range image_range.\n", 121 | " The function returns the unscaled pixel locations as integers and the\n", 122 | " scaled float values.\n", 123 | " Args:\n", 124 | " resolution (tuple): image resolution\n", 125 | " batch_size (int): batch size\n", 126 | " image_range (tuple): range of output points (default [-1, 1])\n", 127 | " subsample_to (int): if integer and > 0, the points are randomly\n", 128 | " subsampled to this value\n", 129 | " '''\n", 130 | " h, w = resolution\n", 131 | " n_points = resolution[0] * resolution[1]\n", 132 | "\n", 133 | " # Arrange pixel location in scale resolution\n", 134 | " pixel_locations = torch.meshgrid(torch.arange(0, w), torch.arange(0, h))\n", 135 | " pixel_locations = torch.stack(\n", 136 | " [pixel_locations[0], pixel_locations[1]],\n", 137 | " dim=-1).long().view(1, -1, 2).repeat(batch_size, 1, 1)\n", 138 | " pixel_scaled = pixel_locations.clone().float()\n", 139 | "\n", 140 | " # Shift and scale points to match image_range\n", 141 | " scale = (image_range[1] - image_range[0])\n", 142 | " loc = scale / 2\n", 143 | " pixel_scaled[:, :, 0] = scale * pixel_scaled[:, :, 0] / (w - 1) - loc\n", 144 | " pixel_scaled[:, :, 1] = scale * pixel_scaled[:, :, 1] / (h - 1) - loc\n", 145 | "\n", 146 | " # Subsample points if subsample_to is not None and > 0\n", 147 | " if (subsample_to is not None and subsample_to > 0 and\n", 148 | " subsample_to < n_points):\n", 149 | " idx = np.random.choice(pixel_scaled.shape[1], size=(subsample_to,),\n", 150 | " replace=False)\n", 151 | " pixel_scaled = pixel_scaled[:, idx]\n", 152 | " pixel_locations = pixel_locations[:, idx]\n", 153 | "\n", 154 | " if invert_y_axis:\n", 155 | " assert(image_range == (-1, 1))\n", 156 | " pixel_scaled[..., -1] *= -1.\n", 157 | " pixel_locations[..., -1] = (h - 1) - pixel_locations[..., -1]\n", 158 | "\n", 159 | " return pixel_locations, pixel_scaled\n", 160 | "\n", 161 | "# 6. 
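
For intuition, here is the back-projection that `image_points_to_world` / `transform_to_world` below implement, written out for a single pixel (a standalone sketch with made-up matrices; the notebook versions do the same thing batched, with the optional scale matrix and the numpy/torch conversions):

```python
import numpy as np

# A pixel on the image plane in [-1, 1]^2, with the fixed depth of -1
# (the camera looks down its negative z axis).
pixel = np.array([0.5, -0.5])
depth = -1.0

# Inverted intrinsic (as returned by get_camera_mat(invert=True)) and a camera-to-world pose.
focal = 2.19                       # from the FOV example above
K_inv = np.diag([1 / focal, 1 / focal, 1., 1.])
cam2world = np.eye(4)              # pretend the camera sits at the world origin, unrotated

# Homogeneous coordinates: scale x, y by |depth|, set z to the signed depth, w = 1.
p = np.array([pixel[0] * abs(depth), pixel[1] * abs(depth), depth, 1.0])

p_world = cam2world @ K_inv @ p
print(p_world[:3])   # the pixel lifted to a 3D point one unit in front of the camera
```
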
mat_mul with intrinsic and then extrinsic gives you p_world (pixels in world) \n", 162 | "def image_points_to_world(image_points, camera_mat, world_mat, scale_mat=None,\n", 163 | " invert=False, negative_depth=True):\n", 164 | " ''' Transforms points on image plane to world coordinates.\n", 165 | " In contrast to transform_to_world, no depth value is needed as points on\n", 166 | " the image plane have a fixed depth of 1.\n", 167 | " Args:\n", 168 | " image_points (tensor): image points tensor of size B x N x 2\n", 169 | " camera_mat (tensor): camera matrix\n", 170 | " world_mat (tensor): world matrix\n", 171 | " scale_mat (tensor): scale matrix\n", 172 | " invert (bool): whether to invert matrices (default: False)\n", 173 | " '''\n", 174 | " batch_size, n_pts, dim = image_points.shape\n", 175 | " assert(dim == 2)\n", 176 | " d_image = torch.ones(batch_size, n_pts, 1)\n", 177 | " if negative_depth:\n", 178 | " d_image *= -1.\n", 179 | " return transform_to_world(image_points, d_image, camera_mat, world_mat,\n", 180 | " scale_mat, invert=invert)\n", 181 | "\n", 182 | "def transform_to_world(pixels, depth, camera_mat, world_mat, scale_mat=None,\n", 183 | " invert=True, use_absolute_depth=True):\n", 184 | " ''' Transforms pixel positions p with given depth value d to world coordinates.\n", 185 | " Args:\n", 186 | " pixels (tensor): pixel tensor of size B x N x 2\n", 187 | " depth (tensor): depth tensor of size B x N x 1\n", 188 | " camera_mat (tensor): camera matrix\n", 189 | " world_mat (tensor): world matrix\n", 190 | " scale_mat (tensor): scale matrix\n", 191 | " invert (bool): whether to invert matrices (default: true)\n", 192 | " '''\n", 193 | " assert(pixels.shape[-1] == 2)\n", 194 | "\n", 195 | " if scale_mat is None:\n", 196 | " scale_mat = torch.eye(4).unsqueeze(0).repeat(\n", 197 | " camera_mat.shape[0], 1, 1)\n", 198 | "\n", 199 | " # Convert to pytorch\n", 200 | " pixels, is_numpy = to_pytorch(pixels, True)\n", 201 | " depth = to_pytorch(depth)\n", 202 | " camera_mat = to_pytorch(camera_mat)\n", 203 | " world_mat = to_pytorch(world_mat)\n", 204 | " scale_mat = to_pytorch(scale_mat)\n", 205 | "\n", 206 | " # Invert camera matrices\n", 207 | " if invert:\n", 208 | " camera_mat = torch.inverse(camera_mat)\n", 209 | " world_mat = torch.inverse(world_mat)\n", 210 | " scale_mat = torch.inverse(scale_mat)\n", 211 | "\n", 212 | " # Transform pixels to homogen coordinates\n", 213 | " pixels = pixels.permute(0, 2, 1)\n", 214 | " pixels = torch.cat([pixels, torch.ones_like(pixels)], dim=1)\n", 215 | "\n", 216 | " # Project pixels into camera space\n", 217 | " if use_absolute_depth:\n", 218 | " pixels[:, :2] = pixels[:, :2] * depth.permute(0, 2, 1).abs()\n", 219 | " pixels[:, 2:3] = pixels[:, 2:3] * depth.permute(0, 2, 1)\n", 220 | " else:\n", 221 | " pixels[:, :3] = pixels[:, :3] * depth.permute(0, 2, 1)\n", 222 | " \n", 223 | " # Transform pixels to world space\n", 224 | " p_world = scale_mat @ world_mat @ camera_mat @ pixels\n", 225 | "\n", 226 | " # Transform p_world back to 3D coordinates\n", 227 | " p_world = p_world[:, :3].permute(0, 2, 1)\n", 228 | "\n", 229 | " if is_numpy:\n", 230 | " p_world = p_world.numpy()\n", 231 | " return p_world\n", 232 | "\n", 233 | "\n", 234 | "# 7. 
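
`origin_to_world` below pushes the homogeneous origin (0, 0, 0, 1) through the same chain of matrices; with an identity scale matrix that reduces to reading off the translation column of the camera-to-world matrix, i.e. the camera position `loc` again, which is why the step-7 comment notes we already have it. A two-line check (standalone, with an arbitrary pose):

```python
import numpy as np

cam2world = np.eye(4)
cam2world[:3, :3] = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]])  # arbitrary rotation
cam2world[:3, 3] = [2.0, -1.0, 0.5]                              # camera position in world

origin_h = np.array([0.0, 0.0, 0.0, 1.0])
print((cam2world @ origin_h)[:3])   # [ 2.  -1.   0.5], exactly the translation column
```
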
mat_mul zeros with intrinsic&extrinsic for camera pos (which we alread obtained as loc)\n", 235 | "def origin_to_world(n_points, camera_mat, world_mat, scale_mat=None,\n", 236 | " invert=False):\n", 237 | " ''' Transforms origin (camera location) to world coordinates.\n", 238 | " Args:\n", 239 | " n_points (int): how often the transformed origin is repeated in the\n", 240 | " form (batch_size, n_points, 3)\n", 241 | " camera_mat (tensor): camera matrix\n", 242 | " world_mat (tensor): world matrix\n", 243 | " scale_mat (tensor): scale matrix\n", 244 | " invert (bool): whether to invert the matrices (default: false)\n", 245 | " '''\n", 246 | " \n", 247 | " batch_size = camera_mat.shape[0]\n", 248 | " device = camera_mat.device\n", 249 | " # Create origin in homogen coordinates\n", 250 | " p = torch.zeros(batch_size, 4, n_points).to(device)\n", 251 | " p[:, -1] = 1.\n", 252 | "\n", 253 | " if scale_mat is None:\n", 254 | " scale_mat = torch.eye(4).unsqueeze(\n", 255 | " 0).repeat(batch_size, 1, 1).to(device)\n", 256 | "\n", 257 | " # Invert matrices\n", 258 | " if invert:\n", 259 | " camera_mat = torch.inverse(camera_mat)\n", 260 | " world_mat = torch.inverse(world_mat)\n", 261 | " scale_mat = torch.inverse(scale_mat)\n", 262 | " \n", 263 | " camera_mat = to_pytorch(camera_mat)\n", 264 | " world_mat = to_pytorch(world_mat)\n", 265 | " scale_mat = to_pytorch(scale_mat)\n", 266 | " \n", 267 | " # Apply transformation\n", 268 | " p_world = scale_mat @ world_mat @ camera_mat @ p\n", 269 | "\n", 270 | " # Transform points back to 3D coordinates\n", 271 | " p_world = p_world[:, :3].permute(0, 2, 1)\n", 272 | " return p_world" 273 | ] 274 | }, 275 | { 276 | "cell_type": "markdown", 277 | "id": "414ea1ff", 278 | "metadata": {}, 279 | "source": [ 280 | "# Sampling" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": 4, 286 | "id": "d62aa105", 287 | "metadata": {}, 288 | "outputs": [], 289 | "source": [ 290 | "# from Kai-46/nerfplusplus\n", 291 | "\n", 292 | "# 8. intersect sphere for distinguishing fg and bg\n", 293 | "def intersect_sphere(ray_o, ray_d):\n", 294 | " '''\n", 295 | " ray_o, ray_d: [..., 3]\n", 296 | " compute the depth of the intersection point between this ray and unit sphere\n", 297 | " '''\n", 298 | " # note: d1 becomes negative if this mid point is behind camera\n", 299 | " d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)\n", 300 | " p = ray_o + d1.unsqueeze(-1) * ray_d\n", 301 | " # consider the case where the ray does not intersect the sphere\n", 302 | " ray_d_cos = 1. / torch.norm(ray_d, dim=-1)\n", 303 | " p_norm_sq = torch.sum(p * p, dim=-1)\n", 304 | " if (p_norm_sq >= 1.).any():\n", 305 | " raise Exception('Not all your cameras are bounded by the unit sphere; please make sure the cameras are normalized properly!')\n", 306 | " d2 = torch.sqrt(1. - p_norm_sq) * ray_d_cos\n", 307 | "\n", 308 | " return d1 + d2\n", 309 | "\n", 310 | "# 9. inverse sphere sampling for bg\n", 311 | "def depth2pts_outside(ray_o, ray_d, depth):\n", 312 | " '''\n", 313 | " ray_o, ray_d: [..., 3]\n", 314 | " depth: [...]; inverse of distance to sphere origin\n", 315 | " '''\n", 316 | " # note: d1 becomes negative if this mid point is behind camera\n", 317 | " d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1)\n", 318 | " p_mid = ray_o + d1.unsqueeze(-1) * ray_d\n", 319 | " p_mid_norm = torch.norm(p_mid, dim=-1)\n", 320 | " ray_d_cos = 1. / torch.norm(ray_d, dim=-1)\n", 321 | " d2 = torch.sqrt(1. 
- p_mid_norm * p_mid_norm) * ray_d_cos\n", 322 | " \n", 323 | " p_sphere = ray_o + (d1 + d2).unsqueeze(-1) * ray_d\n", 324 | "\n", 325 | " rot_axis = torch.cross(ray_o, p_sphere, dim=-1)\n", 326 | " rot_axis = rot_axis / torch.norm(rot_axis, dim=-1, keepdim=True)\n", 327 | " phi = torch.asin(p_mid_norm)\n", 328 | " theta = torch.asin(p_mid_norm * depth) # depth is inside [0, 1]\n", 329 | " rot_angle = (phi - theta).unsqueeze(-1) # [..., 1]\n", 330 | "\n", 331 | " # now rotate p_sphere\n", 332 | " # Rodrigues formula: https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula\n", 333 | " p_sphere_new = p_sphere * torch.cos(rot_angle) + \\\n", 334 | " torch.cross(rot_axis, p_sphere, dim=-1) * torch.sin(rot_angle) + \\\n", 335 | " rot_axis * torch.sum(rot_axis*p_sphere, dim=-1, keepdim=True) * (1.-torch.cos(rot_angle))\n", 336 | " p_sphere_new = p_sphere_new / torch.norm(p_sphere_new, dim=-1, keepdim=True)\n", 337 | " pts = torch.cat((p_sphere_new.squeeze(), depth.squeeze().unsqueeze(-1)), dim=-1) # (modified) added .squeeze()\n", 338 | "\n", 339 | " # now calculate conventional depth\n", 340 | " depth_real = 1. / (depth + TINY_NUMBER) * torch.cos(theta) * ray_d_cos + d1\n", 341 | " return pts, depth_real\n", 342 | "\n", 343 | "# 10. perturb sample z values (depth) for some randomness (stratified sampling)\n", 344 | "def perturb_samples(z_vals):\n", 345 | " # get intervals between samples\n", 346 | " mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])\n", 347 | " upper = torch.cat([mids, z_vals[..., -1:]], dim=-1)\n", 348 | " lower = torch.cat([z_vals[..., 0:1], mids], dim=-1)\n", 349 | " # uniform samples in those intervals\n", 350 | " t_rand = torch.rand_like(z_vals)\n", 351 | " z_vals = lower + (upper - lower) * t_rand # [N_rays, N_samples]\n", 352 | "\n", 353 | " return z_vals\n", 354 | "\n", 355 | "# 11. 
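
The foreground/background split above hinges on the ray-vs-unit-sphere intersection in `intersect_sphere`: d1 is the ray parameter of the point closest to the origin, and d2 = sqrt(1 - |p|^2) / |ray_d| is the extra parameter needed to cover the remaining half-chord out to the sphere. `depth2pts_outside` then re-parameterizes background points past that intersection as a direction on the unit sphere plus an inverse distance 1/r, which `pts2realpts` later converts back to ordinary (x, y, z) for plotting. A standalone check of the intersection against the textbook quadratic solution (arbitrary ray, numpy only):

```python
import numpy as np

ray_o = np.array([0.2, -0.1, 0.3])   # origin inside the unit sphere
ray_d = np.array([0.5, 1.0, -0.2])   # arbitrary direction, not normalized

# intersect_sphere's closed form: closest point to the origin + half chord.
d1 = -np.dot(ray_d, ray_o) / np.dot(ray_d, ray_d)
p = ray_o + d1 * ray_d
d2 = np.sqrt(1. - np.dot(p, p)) / np.linalg.norm(ray_d)
t_sphere = d1 + d2

# Same thing via the quadratic |ray_o + t * ray_d|^2 = 1 (take the larger root).
a = np.dot(ray_d, ray_d)
b = 2. * np.dot(ray_o, ray_d)
c = np.dot(ray_o, ray_o) - 1.
t_quad = (-b + np.sqrt(b * b - 4 * a * c)) / (2 * a)

print(t_sphere, t_quad)              # identical up to floating point
assert np.isclose(t_sphere, t_quad)
```
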
return fg (just uniform) and bg (uniform inverse sphere) samples\n", 356 | "def uniformsampling(ray_o, ray_d, min_depth, N_samples, device):\n", 357 | " \n", 358 | " dots_sh = list(ray_d.shape[:-1])\n", 359 | "\n", 360 | " # foreground depth\n", 361 | " fg_far_depth = intersect_sphere(ray_o, ray_d) # [...,]\n", 362 | " fg_near_depth = min_depth # [..., ]\n", 363 | " step = (fg_far_depth - fg_near_depth) / (N_samples - 1)\n", 364 | " fg_depth = torch.stack([fg_near_depth + i * step for i in range(N_samples)], dim=-1) # [..., N_samples]\n", 365 | " fg_depth = perturb_samples(fg_depth) # random perturbation during training\n", 366 | "\n", 367 | " # background depth\n", 368 | " bg_depth = torch.linspace(0., 1., N_samples).view(\n", 369 | " [1, ] * len(dots_sh) + [N_samples,]).expand(dots_sh + [N_samples,]).to(device)\n", 370 | " bg_depth = perturb_samples(bg_depth) # random perturbation during training\n", 371 | "\n", 372 | "\n", 373 | " fg_ray_o = ray_o.unsqueeze(-2).expand(dots_sh + [N_samples, 3])\n", 374 | " fg_ray_d = ray_d.unsqueeze(-2).expand(dots_sh + [N_samples, 3])\n", 375 | "\n", 376 | " bg_ray_o = ray_o.unsqueeze(-2).expand(dots_sh + [N_samples, 3])\n", 377 | " bg_ray_d = ray_d.unsqueeze(-2).expand(dots_sh + [N_samples, 3])\n", 378 | " \n", 379 | " # sampling foreground\n", 380 | " fg_pts = fg_ray_o + fg_depth.unsqueeze(-1) * fg_ray_d\n", 381 | "\n", 382 | " # sampling background\n", 383 | " bg_pts, bg_depth_real = depth2pts_outside(bg_ray_o, bg_ray_d, bg_depth)\n", 384 | "\n", 385 | " return fg_pts, bg_pts, bg_depth_real\n", 386 | "\n", 387 | "# 12 convert bg pts (x', y', z' 1/r) to real bg pts (x, y, z)\n", 388 | "def pts2realpts(bg_pts, bg_depth_real):\n", 389 | " return bg_pts[:, :, :3] * bg_depth_real.squeeze().unsqueeze(-1)" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": 5, 395 | "id": "856d052e", 396 | "metadata": {}, 397 | "outputs": [], 398 | "source": [ 399 | "def giraffe(u = 1,\n", 400 | " v = 0.5,\n", 401 | " r=2.713,\n", 402 | " depth_range=[0.5, 6.],\n", 403 | " n_ray_samples=16,\n", 404 | " resolution_vol = 4,\n", 405 | " batch_size = 1\n", 406 | " ):\n", 407 | "\n", 408 | " range_radius=[r, r]\n", 409 | " \n", 410 | " res = resolution_vol\n", 411 | " n_points = res * res\n", 412 | "\n", 413 | " # 1. get camera intrinsic \n", 414 | " camera_mat = get_camera_mat()\n", 415 | "\n", 416 | " # 2. get camera position with camera pose (theta & phi)\n", 417 | " loc = to_sphere(u, v)\n", 418 | " loc = torch.tensor(loc).float()\n", 419 | " radius = range_radius[0] + \\\n", 420 | " torch.rand(batch_size) * (range_radius[1] - range_radius[0])\n", 421 | " loc = loc * radius.unsqueeze(-1)\n", 422 | "\n", 423 | " # 3. get camera coordinate system assuming it points to the center of the sphere\n", 424 | " R = look_at(loc)\n", 425 | "\n", 426 | " # 4. The carmera coordinate is the rotational matrix and with camera loc, it is camera extrinsic\n", 427 | " RT = np.eye(4).reshape(1, 4, 4)\n", 428 | " RT[:, :3, :3] = R\n", 429 | " RT[:, :3, -1] = loc\n", 430 | " world_mat = RT\n", 431 | "\n", 432 | " # 5. arange 2d array of pixel coordinate and give depth of 1\n", 433 | " pixels = arange_pixels((res, res), 1, invert_y_axis=False)[1]\n", 434 | " pixels[..., -1] *= -1. # still dunno why this is here\n", 435 | "\n", 436 | " # 6. mat_mul with intrinsic and then extrinsic gives you p_world (pixels in world) \n", 437 | " pixels_world = image_points_to_world(pixels, camera_mat, world_mat)\n", 438 | " \n", 439 | " # 7. 
mat_mul zeros with intrinsic&extrinsic for camera pos (which we alread obtained as loc)\n", 440 | " camera_world = origin_to_world(n_points, camera_mat, world_mat)\n", 441 | "\n", 442 | " # 8. ray = pixel - camera origin (in world)\n", 443 | " ray_vector = pixels_world - camera_world\n", 444 | "\n", 445 | " # 9. depths from closest to furthest (0.5 ~ 6.0)\n", 446 | " di = depth_range[0] + \\\n", 447 | " torch.linspace(0., 1., steps=n_ray_samples).reshape(1, 1, -1) * (\n", 448 | " depth_range[1] - depth_range[0])\n", 449 | " di = di.repeat(batch_size, n_points, 1)\n", 450 | "\n", 451 | " # 10. calculate points\n", 452 | " p_i = camera_world.unsqueeze(-2).contiguous() + \\\n", 453 | " di.unsqueeze(-1).contiguous() * ray_vector.unsqueeze(-2).contiguous()\n", 454 | " \n", 455 | " return pixels_world, camera_world, world_mat, p_i" 456 | ] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "execution_count": 6, 461 | "id": "1a0c73e1", 462 | "metadata": {}, 463 | "outputs": [], 464 | "source": [ 465 | "def nerfpp(u = 1,\n", 466 | " v = 0.5,\n", 467 | " fov = 49.13,\n", 468 | " depth_range=[0.5, 6.],\n", 469 | " n_ray_samples=16,\n", 470 | " resolution_vol = 4,\n", 471 | " batch_size = 1,\n", 472 | " device = torch.device('cpu')\n", 473 | " ):\n", 474 | " \n", 475 | " r = 1\n", 476 | " range_radius=[r, r]\n", 477 | " \n", 478 | " res = resolution_vol\n", 479 | " n_points = res * res\n", 480 | "\n", 481 | " # 1. get camera intrinsic - fiddle around with fov this time\n", 482 | " camera_mat = get_camera_mat(fov=fov)\n", 483 | "\n", 484 | " # 2. get camera position with camera pose (theta & phi)\n", 485 | " loc = to_sphere(u, v)\n", 486 | " loc = torch.tensor(loc).float()\n", 487 | " \n", 488 | " radius = range_radius[0] + \\\n", 489 | " torch.rand(batch_size) * (range_radius[1] - range_radius[0])\n", 490 | " \n", 491 | " loc = loc * radius.unsqueeze(-1)\n", 492 | "\n", 493 | " # 3. get camera coordinate system assuming it points to the center of the sphere\n", 494 | " R = look_at(loc)\n", 495 | "\n", 496 | " # 4. The carmera coordinate is the rotational matrix and with camera loc, it is camera extrinsic\n", 497 | " RT = np.eye(4).reshape(1, 4, 4)\n", 498 | " RT[:, :3, :3] = R\n", 499 | " RT[:, :3, -1] = loc\n", 500 | " world_mat = RT\n", 501 | "\n", 502 | " # 5. arange 2d array of pixel coordinate and give depth of 1\n", 503 | " pixels = arange_pixels((res, res), 1, invert_y_axis=False)[1]\n", 504 | " pixels[..., -1] *= -1. # still dunno why this is here\n", 505 | "\n", 506 | " # 6. mat_mul with intrinsic and then extrinsic gives you p_world (pixels in world) \n", 507 | " pixels_world = image_points_to_world(pixels, camera_mat, world_mat)\n", 508 | "\n", 509 | " # 7. mat_mul zeros with intrinsic&extrinsic for camera pos (which we alread obtained as loc)\n", 510 | " camera_world = origin_to_world(n_points, camera_mat, world_mat)\n", 511 | "\n", 512 | " # 8. ray = pixel - camera origin (in world)\n", 513 | " ray_vector = pixels_world - camera_world\n", 514 | "\n", 515 | " # 9. sample fg and bg points according to nerfpp (uniform and inverse sphere)\n", 516 | " fg_pts, bg_pts, bg_depth_real = uniformsampling(ray_o=camera_world, ray_d=ray_vector, min_depth=depth_range[0], N_samples=n_ray_samples, device=device)\n", 517 | " \n", 518 | " #10. 
convert bg pts (x', y', z' 1/r) to real bg pts (x, y, z)\n", 519 | " bg_pts_real = pts2realpts(bg_pts, bg_depth_real)\n", 520 | " \n", 521 | " return pixels_world, camera_world, world_mat, fg_pts, bg_pts_real" 522 | ] 523 | }, 524 | { 525 | "cell_type": "markdown", 526 | "id": "2fd19406", 527 | "metadata": {}, 528 | "source": [ 529 | "# Visualization" 530 | ] 531 | }, 532 | { 533 | "cell_type": "code", 534 | "execution_count": 7, 535 | "id": "b9dd6b14", 536 | "metadata": {}, 537 | "outputs": [], 538 | "source": [ 539 | "# draw sphere with radius r \n", 540 | "# also draw contours and vertical lines\n", 541 | "def draw_sphere(r, sphere_colorscale, sphere_opacity):\n", 542 | " # sphere\n", 543 | " u = np.linspace(0, 2 * np.pi, 100)\n", 544 | " v = np.linspace(0, np.pi, 100)\n", 545 | " x = r * np.outer(np.cos(u), np.sin(v))\n", 546 | " y = r * np.outer(np.sin(u), np.sin(v))\n", 547 | " z = r * np.outer(np.ones(np.size(u)), np.cos(v))\n", 548 | " \n", 549 | " # vertical lines on sphere\n", 550 | " u2 = np.linspace(0, 2 * np.pi, 20)\n", 551 | " x2 = r * np.outer(np.cos(u2), np.sin(v))\n", 552 | " y2 = r * np.outer(np.sin(u2), np.sin(v))\n", 553 | " z2 = r * np.outer(np.ones(np.size(u2)), np.cos(v))\n", 554 | " \n", 555 | " # create sphere and draw sphere with contours\n", 556 | " fig = go.Figure(data=[go.Surface(x=x, y=y, z=z, \n", 557 | " colorscale=sphere_colorscale, opacity=sphere_opacity,\n", 558 | " contours = {\n", 559 | " 'z' : {'show' : True, 'start' : -r,\n", 560 | " 'end' : r, 'size' : r/10,\n", 561 | " 'color' : 'white',\n", 562 | " 'width' : 1}\n", 563 | " }\n", 564 | " , showscale=False)])\n", 565 | " \n", 566 | " # vertical lines on sphere\n", 567 | " for i in range(len(u2)):\n", 568 | " fig.add_scatter3d(x=x2[i], y=y2[i], z=z2[i], \n", 569 | " line=dict(\n", 570 | " color='white',\n", 571 | " width=1\n", 572 | " ),\n", 573 | " mode='lines',\n", 574 | " showlegend=False)\n", 575 | " \n", 576 | " return fig\n", 577 | "\n", 578 | "# draw xyplane\n", 579 | "def draw_XYplane(fig, xy_plane_colorscale, xy_plane_opacity, x_range = [-2, 2], y_range = [-2, 2]):\n", 580 | " x3 = np.linspace(x_range[0], x_range[1], 100)\n", 581 | " y3 = np.linspace(y_range[0], y_range[1], 100)\n", 582 | " z3 = np.zeros(shape=(100,100))\n", 583 | " \n", 584 | " fig.add_surface(x=x3, y=y3, z=z3,\n", 585 | " colorscale =xy_plane_colorscale, opacity=xy_plane_opacity,\n", 586 | " showscale=False\n", 587 | " )\n", 588 | " \n", 589 | " return fig\n", 590 | " \n", 591 | "\n", 592 | "def draw_XYZworld(fig, world_axis_size):\n", 593 | " # x, y, z positive direction (world)\n", 594 | " X_axis = [0, world_axis_size]\n", 595 | " X_text = [None, \"X\"]\n", 596 | " X0 = [0, 0]\n", 597 | " Y_axis = [0, world_axis_size]\n", 598 | " Y_text = [None, \"Y\"]\n", 599 | " Y0 = [0, 0]\n", 600 | " Z_axis = [0, world_axis_size]\n", 601 | " Z_text = [None, \"Z\"]\n", 602 | " Z0 = [0, 0]\n", 603 | " \n", 604 | " fig.add_scatter3d(x=X_axis, y=Y0, z=Z0, \n", 605 | " line=dict(\n", 606 | " color='red',\n", 607 | " width=10\n", 608 | " ),\n", 609 | " mode='lines+text',\n", 610 | " text=X_text,\n", 611 | " textposition='top center',\n", 612 | " textfont=dict(\n", 613 | " color=\"red\",\n", 614 | " size=18\n", 615 | " ),\n", 616 | " showlegend=False)\n", 617 | "\n", 618 | " fig.add_scatter3d(x=X0, y=Y_axis, z=Z0, \n", 619 | " line=dict(\n", 620 | " color='green',\n", 621 | " width=10\n", 622 | " ),\n", 623 | " mode='lines+text',\n", 624 | " text=Y_text,\n", 625 | " textposition='top center',\n", 626 | " textfont=dict(\n", 627 | " 
color=\"green\",\n", 628 | " size=18\n", 629 | " ),\n", 630 | " showlegend=False)\n", 631 | "\n", 632 | " fig.add_scatter3d(x=X0, y=Y0, z=Z_axis, \n", 633 | " line=dict(\n", 634 | " color='blue',\n", 635 | " width=10\n", 636 | " ),\n", 637 | " mode='lines+text',\n", 638 | " text=Z_text,\n", 639 | " textposition='top center',\n", 640 | " textfont=dict(\n", 641 | " color=\"blue\",\n", 642 | " size=18\n", 643 | " ),\n", 644 | " showlegend=False)\n", 645 | " \n", 646 | " return fig\n", 647 | "\n", 648 | "# draw cam and cam coordinate system\n", 649 | "def draw_cam_init(fig, world_mat, camera_axis_size, camera_color):\n", 650 | " # camera at init\n", 651 | "\n", 652 | " Xc = [world_mat[0, : ,3][0]]\n", 653 | " Yc = [world_mat[0, : ,3][1]]\n", 654 | " Zc = [world_mat[0, : ,3][2]]\n", 655 | " text_c = [\"Camera\"]\n", 656 | "\n", 657 | " # camera axis\n", 658 | " Xc_Xaxis = Xc + [world_mat[0, : ,0][0]*camera_axis_size+Xc[0]]\n", 659 | " Yc_Xaxis = Yc + [world_mat[0, : ,0][1]*camera_axis_size+Yc[0]]\n", 660 | " Zc_Xaxis = Zc + [world_mat[0, : ,0][2]*camera_axis_size+Zc[0]]\n", 661 | " text_Xaxis = [None, \"Xc\"]\n", 662 | " \n", 663 | " # -z in world perspective\n", 664 | " Xc_Yaxis = Xc + [world_mat[0, : ,1][0]*camera_axis_size+Xc[0]]\n", 665 | " Yc_Yaxis = Yc + [world_mat[0, : ,1][1]*camera_axis_size+Yc[0]]\n", 666 | " Zc_Yaxis = Zc + [world_mat[0, : ,1][2]*camera_axis_size+Zc[0]]\n", 667 | " text_Yaxis = [None, \"Yc\"]\n", 668 | "\n", 669 | " # y in world perspective\n", 670 | " Xc_Zaxis = Xc + [world_mat[0, : ,2][0]*camera_axis_size+Xc[0]]\n", 671 | " Yc_Zaxis = Yc + [world_mat[0, : ,2][1]*camera_axis_size+Yc[0]]\n", 672 | " Zc_Zaxis = Zc + [world_mat[0, : ,2][2]*camera_axis_size+Zc[0]]\n", 673 | " text_Zaxis = [None, \"Zc\"]\n", 674 | " \n", 675 | " # cam pos\n", 676 | " fig.add_scatter3d(x=Xc, y=Yc, z=Zc, \n", 677 | " mode='markers',\n", 678 | " marker=dict(\n", 679 | " color=camera_color,\n", 680 | " size=4,\n", 681 | " sizemode='diameter'\n", 682 | " ),\n", 683 | " showlegend=False)\n", 684 | "\n", 685 | " # camera axis\n", 686 | " fig.add_scatter3d(x=Xc_Xaxis, y=Yc_Xaxis, z=Zc_Xaxis, \n", 687 | " line=dict(\n", 688 | " color='red',\n", 689 | " width=10\n", 690 | " ),\n", 691 | " mode='lines+text',\n", 692 | " text=text_Xaxis,\n", 693 | " textposition='top center',\n", 694 | " textfont=dict(\n", 695 | " color=\"red\",\n", 696 | " size=18\n", 697 | " ),\n", 698 | " showlegend=False)\n", 699 | "\n", 700 | " fig.add_scatter3d(x=Xc_Yaxis, y=Yc_Yaxis, z=Zc_Yaxis, \n", 701 | " line=dict(\n", 702 | " color='green',\n", 703 | " width=10\n", 704 | " ),\n", 705 | " mode='lines+text',\n", 706 | " text=text_Yaxis,\n", 707 | " textposition='top center',\n", 708 | " textfont=dict(\n", 709 | " color=\"green\",\n", 710 | " size=18\n", 711 | " ),\n", 712 | " showlegend=False)\n", 713 | "\n", 714 | " fig.add_scatter3d(x=Xc_Zaxis, y=Yc_Zaxis, z=Zc_Zaxis, \n", 715 | " line=dict(\n", 716 | " color='blue',\n", 717 | " width=10\n", 718 | " ),\n", 719 | " mode='lines+text',\n", 720 | " text=text_Zaxis,\n", 721 | " textposition='top center',\n", 722 | " textfont=dict(\n", 723 | " color=\"blue\",\n", 724 | " size=18\n", 725 | " ),\n", 726 | " showlegend=False)\n", 727 | " \n", 728 | " return fig\n", 729 | "\n", 730 | "# draw all rays\n", 731 | "def draw_all_rays(fig, p_i, ray_color):\n", 732 | " for i in range(p_i.shape[1]):\n", 733 | " Xray = p_i[0, i, :, 0]\n", 734 | " Yray = p_i[0, i, :, 1]\n", 735 | " Zray = p_i[0, i, :, 2]\n", 736 | " \n", 737 | " fig.add_scatter3d(x=Xray, y=Yray, z=Zray, \n", 738 | " 
line=dict(\n", 739 | " color=ray_color,\n", 740 | " width=5\n", 741 | " ),\n", 742 | " mode='lines',\n", 743 | " showlegend=False)\n", 744 | " \n", 745 | " return fig\n", 746 | "\n", 747 | "# draw all rays\n", 748 | "def draw_all_rays_with_marker(fig, p_i, marker_size, ray_color):\n", 749 | " \n", 750 | " # convert colorscale string to px.colors.seqeuntial\n", 751 | " # default color is set to Viridis in case of mismatch\n", 752 | " c = px.colors.sequential.Viridis\n", 753 | "\n", 754 | " for c_name in [ray_color, ray_color.capitalize()]:\n", 755 | " try:\n", 756 | " c = getattr(px.colors.sequential, c_name)\n", 757 | " except:\n", 758 | " continue\n", 759 | " \n", 760 | " for i in range(p_i.shape[1]):\n", 761 | " Xray = p_i[0, i, :, 0]\n", 762 | " Yray = p_i[0, i, :, 1]\n", 763 | " Zray = p_i[0, i, :, 2]\n", 764 | " \n", 765 | " fig.add_scatter3d(x=Xray, y=Yray, z=Zray, \n", 766 | " \n", 767 | " marker=dict(\n", 768 | "# color=np.arange(len(Xray)),\n", 769 | " color=c,\n", 770 | "# colorscale='Viridis',\n", 771 | " size=marker_size\n", 772 | " ),\n", 773 | " \n", 774 | " line=dict(\n", 775 | "# color=np.arange(len(Xray)),\n", 776 | " color=c,\n", 777 | "# colorscale='Viridis',\n", 778 | " width=3\n", 779 | " ),\n", 780 | " mode=\"lines+markers\",\n", 781 | " showlegend=False)\n", 782 | " \n", 783 | " return fig\n", 784 | "\n", 785 | "# draw near&far frustrum with rays connecting the corners (changed for nerfpp)\n", 786 | "def draw_ray_frus(fig, p_i, frustrum_color, frustrum_opacity, at=[0, -1]):\n", 787 | " \n", 788 | " for i in at:\n", 789 | "# Xfrus = p_i[0, :, i, 0][[0,1,2,3,7,11,15,14,13,12,8,4,0]]\n", 790 | "# Yfrus = p_i[0, :, i, 1][[0,1,2,3,7,11,15,14,13,12,8,4,0]]\n", 791 | "# Zfrus = p_i[0, :, i, 2][[0,1,2,3,7,11,15,14,13,12,8,4,0]]\n", 792 | "\n", 793 | " Xfrus = p_i[0, :, i, 0]\n", 794 | " Yfrus = p_i[0, :, i, 1]\n", 795 | " Zfrus = p_i[0, :, i, 2]\n", 796 | " \n", 797 | " fig.add_scatter3d(x=Xfrus, y=Yfrus, z=Zfrus, \n", 798 | " line=dict(\n", 799 | " color=frustrum_color,\n", 800 | " width=5\n", 801 | " ),\n", 802 | " mode='lines',\n", 803 | " surfaceaxis=0,\n", 804 | " surfacecolor=frustrum_color,\n", 805 | " opacity=frustrum_opacity,\n", 806 | " showlegend=False)\n", 807 | " \n", 808 | " return fig\n", 809 | "\n", 810 | "# draw foreground sample points, ray and frustrum\n", 811 | "def draw_foreground(fig, fg_pts, fg_color, marker_size, at=[0, -1]):\n", 812 | " fig = draw_all_rays_with_marker(fig, fg_pts, marker_size, fg_color)\n", 813 | " \n", 814 | " return fig\n", 815 | "\n", 816 | "# draw background sample points, ray and frustrum\n", 817 | "def draw_background(fig, bg_pts, bg_color, marker_size, at=[0, -1]):\n", 818 | " fig = draw_all_rays_with_marker(fig, bg_pts, marker_size, bg_color)\n", 819 | " \n", 820 | " return fig" 821 | ] 822 | }, 823 | { 824 | "cell_type": "code", 825 | "execution_count": 8, 826 | "id": "f12fecdd", 827 | "metadata": {}, 828 | "outputs": [ 829 | { 830 | "name": "stderr", 831 | "output_type": "stream", 832 | "text": [ 833 | "C:\\Users\\laphi\\AppData\\Local\\Temp/ipykernel_24504/3891967274.py:2: UserWarning: \n", 834 | "The dash_core_components package is deprecated. Please replace\n", 835 | "`import dash_core_components as dcc` with `from dash import dcc`\n", 836 | " import dash_core_components as dcc\n", 837 | "C:\\Users\\laphi\\AppData\\Local\\Temp/ipykernel_24504/3891967274.py:3: UserWarning: \n", 838 | "The dash_html_components package is deprecated. 
Please replace\n", 839 | "`import dash_html_components as html` with `from dash import html`\n", 840 | " import dash_html_components as html\n" 841 | ] 842 | } 843 | ], 844 | "source": [ 845 | "from jupyter_dash import JupyterDash\n", 846 | "import dash_core_components as dcc\n", 847 | "import dash_html_components as html\n", 848 | "from dash.dependencies import Input, Output\n", 849 | "\n", 850 | "# for colors\n", 851 | "import matplotlib.colors as mcolors" 852 | ] 853 | }, 854 | { 855 | "cell_type": "code", 856 | "execution_count": 10, 857 | "id": "e5e28428", 858 | "metadata": { 859 | "scrolled": false 860 | }, 861 | "outputs": [ 862 | { 863 | "name": "stderr", 864 | "output_type": "stream", 865 | "text": [ 866 | "c:\\users\\laphi\\appdata\\local\\programs\\python\\python38\\lib\\site-packages\\jupyter_dash\\jupyter_app.py:139: UserWarning:\n", 867 | "\n", 868 | "The 'environ['werkzeug.server.shutdown']' function is deprecated and will be removed in Werkzeug 2.1.\n", 869 | "\n" 870 | ] 871 | }, 872 | { 873 | "data": { 874 | "text/html": [ 875 | "\n", 876 | " \n", 883 | " " 884 | ], 885 | "text/plain": [ 886 | "" 887 | ] 888 | }, 889 | "metadata": {}, 890 | "output_type": "display_data" 891 | } 892 | ], 893 | "source": [ 894 | "app = JupyterDash(__name__)\n", 895 | "\n", 896 | "app.layout = html.Div([\n", 897 | " html.H1(\"Nerfplusplus ray sampling visualization\"),\n", 898 | " dcc.Graph(id='graph'),\n", 899 | " \n", 900 | " html.Div([\n", 901 | " html.Div([\n", 902 | " \n", 903 | " # changes to setting and ray casted \n", 904 | " html.Label([ \"u (theta)\",\n", 905 | " dcc.Slider(\n", 906 | " id='u-slider', \n", 907 | " min=0, max=1,\n", 908 | " value=0.00,\n", 909 | " marks={str(val) : str(val) for val in [0.00, 0.25, 0.50, 0.75]},\n", 910 | " step=0.01, tooltip = { 'always_visible': True }\n", 911 | " ), ]),\n", 912 | " html.Label([ \"v (phi)\",\n", 913 | " dcc.Slider(\n", 914 | " id='v-slider', \n", 915 | " min=0, max=1,\n", 916 | " value=0.25,\n", 917 | " marks={str(val) : str(val) for val in [0.00, 0.25, 0.50, 0.75]},\n", 918 | " step=0.01, tooltip = { 'always_visible': True }\n", 919 | " ), ]),\n", 920 | " html.Label([ \"fov (field-of-view))\",\n", 921 | " dcc.Slider(\n", 922 | " id='fov-slider', \n", 923 | " min=0, max=100,\n", 924 | " value=50,\n", 925 | " marks={str(val) : str(val) for val in [0, 20, 40, 60, 80, 100]},\n", 926 | " step=5, tooltip = { 'always_visible': True }\n", 927 | " ), ]),\n", 928 | " html.Label([ \"foreground near depth\",\n", 929 | " dcc.Slider(\n", 930 | " id='foreground-near-depth-slider', \n", 931 | " min=0, max=2,\n", 932 | " value=0.5,\n", 933 | " marks={f\"{val:.1f}\" : f\"{val:.1f}\" for val in [0.1 * i for i in range(21)]},\n", 934 | " step=0.1, tooltip = { 'always_visible': True }\n", 935 | " ), ])\n", 936 | " ], style = {'width' : '48%', 'display' : 'inline-block'}),\n", 937 | " \n", 938 | " html.Div([\n", 939 | " # changes to visual appearance\n", 940 | " \n", 941 | " # axis scale\n", 942 | " html.Div([\n", 943 | " html.Label([ \"world axis size\",\n", 944 | " html.Div([\n", 945 | " dcc.Input(id='world-axis-size-input',\n", 946 | " value=1.5,\n", 947 | " type='number', style={'width': '50%'}\n", 948 | " )\n", 949 | " ]),\n", 950 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n", 951 | " html.Label([ \"camera axis size\",\n", 952 | " html.Div([\n", 953 | " dcc.Input(id='camera-axis-size-input',\n", 954 | " value=0.3,\n", 955 | " type='number', style={'width': '50%'}\n", 956 | " )\n", 957 | " ]),\n", 958 | " ], 
style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n", 959 | " html.Label([ \"sample marker size\",\n", 960 | " html.Div([\n", 961 | " dcc.Input(id='sample-marker-size-input',\n", 962 | " value=2,\n", 963 | " type='number', style={'width': '50%'}\n", 964 | " )\n", 965 | " ]),\n", 966 | " ], style = {'width' : '32%', 'float' : 'left', 'display' : 'inline-block'}),\n", 967 | " ]),\n", 968 | " \n", 969 | " # opacity \n", 970 | " html.Div([\n", 971 | " html.Label([ \"sphere opacity\",\n", 972 | " html.Div([\n", 973 | " dcc.Input(id='sphere-opacity-input',\n", 974 | " value=0.2,\n", 975 | " type='number', style={'width': '50%'}\n", 976 | " )\n", 977 | " ])\n", 978 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n", 979 | " \n", 980 | " html.Label([ \"xy-plane opacity\", \n", 981 | " html.Div([\n", 982 | " dcc.Input(id='xy-plane-opacity-input',\n", 983 | " value=0.8,\n", 984 | " type='number', style={'width': '50%'}\n", 985 | " )\n", 986 | " ])\n", 987 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n", 988 | " \n", 989 | " ]),\n", 990 | " \n", 991 | " # color\n", 992 | " html.Div([\n", 993 | " html.Label([ \"camera color\",\n", 994 | " html.Div([\n", 995 | " dcc.Dropdown(id='camera-color-input',\n", 996 | " clearable=False,\n", 997 | " value='yellow',\n", 998 | " options=[\n", 999 | " {'label': c, 'value': c}\n", 1000 | " for (c, _) in mcolors.CSS4_COLORS.items()\n", 1001 | " ], style={'width': '80%'}\n", 1002 | " )\n", 1003 | " ])\n", 1004 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n", 1005 | " \n", 1006 | " html.Label([ \"foreground color\",\n", 1007 | " html.Div([\n", 1008 | " dcc.Dropdown(id='fg-color-input',\n", 1009 | " clearable=False,\n", 1010 | " value='plotly3',\n", 1011 | " options=[\n", 1012 | " {'label': c, 'value': c}\n", 1013 | " for c in px.colors.named_colorscales()\n", 1014 | " ], style={'width': '80%'}\n", 1015 | " )\n", 1016 | " ])\n", 1017 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n", 1018 | " \n", 1019 | " html.Label([ \"background color\",\n", 1020 | " html.Div([\n", 1021 | " dcc.Dropdown(id='bg-color-input',\n", 1022 | " clearable=False,\n", 1023 | " value='plotly3',\n", 1024 | " options=[\n", 1025 | " {'label': c, 'value': c}\n", 1026 | " for c in px.colors.named_colorscales()\n", 1027 | " ], style={'width': '80%'}\n", 1028 | " )\n", 1029 | " ])\n", 1030 | " ], style = {'width' : '32%', 'float' : 'left', 'display' : 'inline-block'}),\n", 1031 | " ]),\n", 1032 | " \n", 1033 | " # colorscale\n", 1034 | " html.Div([\n", 1035 | " html.Label([ \"sphere colorscale\",\n", 1036 | " html.Div([\n", 1037 | " dcc.Dropdown(id='sphere-colorscale-input',\n", 1038 | " clearable=False,\n", 1039 | " value='greys',\n", 1040 | " options=[\n", 1041 | " {'label': c, 'value': c}\n", 1042 | " for c in px.colors.named_colorscales()\n", 1043 | " ], style={'width': '80%'}\n", 1044 | " )\n", 1045 | " ])\n", 1046 | " ], style = {'width' : '32%', 'float' : 'left', 'display' : 'inline-block'}),\n", 1047 | " \n", 1048 | " html.Label([ \"xy-plane colorscale\",\n", 1049 | " html.Div([\n", 1050 | " dcc.Dropdown(id='xy-plane-colorscale-input',\n", 1051 | " clearable=False,\n", 1052 | " value='greys',\n", 1053 | " options=[\n", 1054 | " {'label': c, 'value': c}\n", 1055 | " for c in px.colors.named_colorscales()\n", 1056 | " ], style={'width': '80%'}\n", 1057 | " )\n", 1058 | " ])\n", 1059 | " ], style = {'width' : '34%', 'float' : 'left', 
'display' : 'inline-block'}),\n", 1060 | " \n", 1061 | " html.Label([ \" \",\n", 1062 | " html.Div([\n", 1063 | " dcc.Checklist(id='show-background-checklist',\n", 1064 | " \n", 1065 | " options=[\n", 1066 | " {'label': 'show background', 'value': 'show_background'}\n", 1067 | " ], style={'width': '80%'},\n", 1068 | " value=[],\n", 1069 | " )\n", 1070 | " ])\n", 1071 | " ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}),\n", 1072 | " \n", 1073 | " ]),\n", 1074 | " \n", 1075 | " \n", 1076 | " \n", 1077 | " ], style = {'width' : '48%', 'float' : 'right', 'display' : 'inline-block'}),\n", 1078 | " \n", 1079 | " ]),\n", 1080 | " \n", 1081 | "])\n", 1082 | "\n", 1083 | "@app.callback(\n", 1084 | " Output('graph', 'figure'),\n", 1085 | " Input(\"u-slider\", \"value\"),\n", 1086 | " Input(\"v-slider\", \"value\"),\n", 1087 | " \n", 1088 | " Input(\"fov-slider\", \"value\"),\n", 1089 | " Input(\"foreground-near-depth-slider\", \"value\"),\n", 1090 | " \n", 1091 | " Input(\"world-axis-size-input\", \"value\"),\n", 1092 | " Input(\"camera-axis-size-input\", \"value\"),\n", 1093 | " Input(\"sample-marker-size-input\", \"value\"),\n", 1094 | " \n", 1095 | " Input(\"camera-color-input\", \"value\"),\n", 1096 | " Input(\"fg-color-input\", \"value\"),\n", 1097 | " Input(\"bg-color-input\", \"value\"),\n", 1098 | " \n", 1099 | " Input('sphere-colorscale-input', \"value\"),\n", 1100 | " Input('xy-plane-colorscale-input', \"value\"),\n", 1101 | " Input('show-background-checklist', \"value\"),\n", 1102 | " \n", 1103 | " Input(\"sphere-opacity-input\", \"value\"),\n", 1104 | " Input(\"xy-plane-opacity-input\", \"value\"),\n", 1105 | ")\n", 1106 | "\n", 1107 | "def update_figure(u, v, \n", 1108 | " fov, foreground_near_depth,\n", 1109 | " world_axis_size, camera_axis_size, sample_marker_size,\n", 1110 | " camera_color, fg_color, bg_color,\n", 1111 | " sphere_colorscale, xy_plane_colorscale, show_background,\n", 1112 | " sphere_opacity, xy_plane_opacity \n", 1113 | " ):\n", 1114 | " \n", 1115 | " depth_range = [foreground_near_depth, 2]\n", 1116 | " \n", 1117 | " # sphere\n", 1118 | " fig = draw_sphere(r=1, sphere_colorscale=sphere_colorscale, sphere_opacity=sphere_opacity)\n", 1119 | "\n", 1120 | " # change figure size\n", 1121 | "# fig.update_layout(autosize=False, width = 500, height=500)\n", 1122 | "\n", 1123 | " # draw axes in proportion to the proportion of their ranges\n", 1124 | " fig.update_layout(scene_aspectmode='data')\n", 1125 | "\n", 1126 | " # xy plane\n", 1127 | " fig = draw_XYplane(fig, xy_plane_colorscale, xy_plane_opacity,\n", 1128 | " x_range=[-depth_range[1], depth_range[1]], y_range=[-depth_range[1], depth_range[1]])\n", 1129 | "\n", 1130 | " # show world coordinate system (X, Y, Z positive direction)\n", 1131 | " fig = draw_XYZworld(fig, world_axis_size=world_axis_size)\n", 1132 | "\n", 1133 | " pixels_world, camera_world, world_mat, fg_pts, bg_pts = nerfpp(u=u, v=v, fov=fov, depth_range=depth_range)\n", 1134 | "\n", 1135 | " # draw camera at init (with its cooridnate system)\n", 1136 | " fig = draw_cam_init(fig, world_mat, \n", 1137 | " camera_axis_size=camera_axis_size, camera_color=camera_color)\n", 1138 | "\n", 1139 | " # draw foreground and background sample point, ray, and frustrum\n", 1140 | " fig = draw_foreground(fig, fg_pts, fg_color, sample_marker_size, at=[0, -1])\n", 1141 | " \n", 1142 | " if show_background:\n", 1143 | " fig = draw_background(fig, bg_pts.unsqueeze(0), bg_color, sample_marker_size, at=[0, -1])\n", 1144 | " \n", 1145 | " return 
fig\n", 1146 | "\n", 1147 | "app.run_server(mode='inline')" 1148 | ] 1149 | } 1150 | ], 1151 | "metadata": { 1152 | "kernelspec": { 1153 | "display_name": "Python 3 (ipykernel)", 1154 | "language": "python", 1155 | "name": "python3" 1156 | }, 1157 | "language_info": { 1158 | "codemirror_mode": { 1159 | "name": "ipython", 1160 | "version": 3 1161 | }, 1162 | "file_extension": ".py", 1163 | "mimetype": "text/x-python", 1164 | "name": "python", 1165 | "nbconvert_exporter": "python", 1166 | "pygments_lexer": "ipython3", 1167 | "version": "3.8.8" 1168 | } 1169 | }, 1170 | "nbformat": 4, 1171 | "nbformat_minor": 5 1172 | } 1173 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VisualizingNerfplusplus 2 | 3 | An interactive Dash app for visualizing NeRF++ ray sampling (foreground and background sample points), initialized to the Nerfplusplus setting. 4 | 5 | Run the code and have fun~ 6 | 7 | ``` 8 | python app.py 9 | ``` 10 | 11 | 12 |

13 | 14 |
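The repo doesn't pin its dependencies. Judging from the imports in `app.py`, `drawing_tools.py`, `nerfplusplus_tools.py`, and the notebook, you will likely need `dash`, `plotly`, `torch`, `numpy`, and `matplotlib` (plus `jupyter-dash` if you want to run the notebook version); something like `pip install dash plotly torch numpy matplotlib jupyter-dash` should cover it, but exact versions aren't specified in the repo.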

15 | -------------------------------------------------------------------------------- /__pycache__/drawing_tools.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/laphisboy/VisualizingNerfplusplus/4b86e2d15215b0baa29af6cc7e794c8521fa466a/__pycache__/drawing_tools.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/nerfplusplus_tools.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/laphisboy/VisualizingNerfplusplus/4b86e2d15215b0baa29af6cc7e794c8521fa466a/__pycache__/nerfplusplus_tools.cpython-38.pyc -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import plotly.express as px 2 | 3 | import dash 4 | import dash_core_components as dcc 5 | import dash_html_components as html 6 | from dash.dependencies import Input, Output 7 | 8 | # for colors 9 | import matplotlib.colors as mcolors 10 | 11 | from drawing_tools import * 12 | from nerfplusplus_tools import * 13 | 14 | app = dash.Dash(__name__) 15 | 16 | app.layout = html.Div([ 17 | html.H1("Nerfplusplus ray sampling visualization"), 18 | dcc.Graph(id='graph'), 19 | 20 | html.Div([ 21 | html.Div([ 22 | 23 | # changes to setting and ray casted 24 | html.Label([ "u (theta)", 25 | dcc.Slider( 26 | id='u-slider', 27 | min=0, max=1, 28 | value=0.00, 29 | marks={str(val) : str(val) for val in [0.00, 0.25, 0.50, 0.75]}, 30 | step=0.01, tooltip = { 'always_visible': True } 31 | ), ]), 32 | html.Label([ "v (phi)", 33 | dcc.Slider( 34 | id='v-slider', 35 | min=0, max=1, 36 | value=0.25, 37 | marks={str(val) : str(val) for val in [0.00, 0.25, 0.50, 0.75]}, 38 | step=0.01, tooltip = { 'always_visible': True } 39 | ), ]), 40 | html.Label([ "fov (field-of-view))", 41 | dcc.Slider( 42 | id='fov-slider', 43 | min=0, max=100, 44 | value=50, 45 | marks={str(val) : str(val) for val in [0, 20, 40, 60, 80, 100]}, 46 | step=5, tooltip = { 'always_visible': True } 47 | ), ]), 48 | html.Label([ "foreground near depth", 49 | dcc.Slider( 50 | id='foreground-near-depth-slider', 51 | min=0, max=2, 52 | value=0.5, 53 | marks={f"{val:.1f}" : f"{val:.1f}" for val in [0.1 * i for i in range(21)]}, 54 | step=0.1, tooltip = { 'always_visible': True } 55 | ), ]) 56 | ], style = {'width' : '48%', 'display' : 'inline-block'}), 57 | 58 | html.Div([ 59 | # changes to visual appearance 60 | 61 | # axis scale 62 | html.Div([ 63 | html.Label([ "world axis size", 64 | html.Div([ 65 | dcc.Input(id='world-axis-size-input', 66 | value=1.5, 67 | type='number', style={'width': '50%'} 68 | ) 69 | ]), 70 | ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}), 71 | html.Label([ "camera axis size", 72 | html.Div([ 73 | dcc.Input(id='camera-axis-size-input', 74 | value=0.3, 75 | type='number', style={'width': '50%'} 76 | ) 77 | ]), 78 | ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}), 79 | html.Label([ "sample marker size", 80 | html.Div([ 81 | dcc.Input(id='sample-marker-size-input', 82 | value=2, 83 | type='number', style={'width': '50%'} 84 | ) 85 | ]), 86 | ], style = {'width' : '32%', 'float' : 'left', 'display' : 'inline-block'}), 87 | ]), 88 | 89 | # opacity 90 | html.Div([ 91 | html.Label([ "sphere opacity", 92 | html.Div([ 93 | dcc.Input(id='sphere-opacity-input', 94 | value=0.2, 95 | type='number', 
style={'width': '50%'} 96 | ) 97 | ]) 98 | ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}), 99 | 100 | html.Label([ "xy-plane opacity", 101 | html.Div([ 102 | dcc.Input(id='xy-plane-opacity-input', 103 | value=0.8, 104 | type='number', style={'width': '50%'} 105 | ) 106 | ]) 107 | ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}), 108 | 109 | ]), 110 | 111 | # color 112 | html.Div([ 113 | html.Label([ "camera color", 114 | html.Div([ 115 | dcc.Dropdown(id='camera-color-input', 116 | clearable=False, 117 | value='yellow', 118 | options=[ 119 | {'label': c, 'value': c} 120 | for (c, _) in mcolors.CSS4_COLORS.items() 121 | ], style={'width': '80%'} 122 | ) 123 | ]) 124 | ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}), 125 | 126 | html.Label([ "foreground color", 127 | html.Div([ 128 | dcc.Dropdown(id='fg-color-input', 129 | clearable=False, 130 | value='plotly3', 131 | options=[ 132 | {'label': c, 'value': c} 133 | for c in px.colors.named_colorscales() 134 | ], style={'width': '80%'} 135 | ) 136 | ]) 137 | ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}), 138 | 139 | html.Label([ "background color", 140 | html.Div([ 141 | dcc.Dropdown(id='bg-color-input', 142 | clearable=False, 143 | value='plotly3', 144 | options=[ 145 | {'label': c, 'value': c} 146 | for c in px.colors.named_colorscales() 147 | ], style={'width': '80%'} 148 | ) 149 | ]) 150 | ], style = {'width' : '32%', 'float' : 'left', 'display' : 'inline-block'}), 151 | ]), 152 | 153 | # colorscale 154 | html.Div([ 155 | html.Label([ "sphere colorscale", 156 | html.Div([ 157 | dcc.Dropdown(id='sphere-colorscale-input', 158 | clearable=False, 159 | value='greys', 160 | options=[ 161 | {'label': c, 'value': c} 162 | for c in px.colors.named_colorscales() 163 | ], style={'width': '80%'} 164 | ) 165 | ]) 166 | ], style = {'width' : '32%', 'float' : 'left', 'display' : 'inline-block'}), 167 | 168 | html.Label([ "xy-plane colorscale", 169 | html.Div([ 170 | dcc.Dropdown(id='xy-plane-colorscale-input', 171 | clearable=False, 172 | value='greys', 173 | options=[ 174 | {'label': c, 'value': c} 175 | for c in px.colors.named_colorscales() 176 | ], style={'width': '80%'} 177 | ) 178 | ]) 179 | ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}), 180 | 181 | html.Label([ " ", 182 | html.Div([ 183 | dcc.Checklist(id='show-background-checklist', 184 | 185 | options=[ 186 | {'label': 'show background', 'value': 'show_background'} 187 | ], style={'width': '80%'}, 188 | value=[ ], 189 | ) 190 | ]) 191 | ], style = {'width' : '34%', 'float' : 'left', 'display' : 'inline-block'}), 192 | 193 | ]), 194 | 195 | 196 | 197 | ], style = {'width' : '48%', 'float' : 'right', 'display' : 'inline-block'}), 198 | 199 | ]), 200 | 201 | ]) 202 | 203 | @app.callback( 204 | Output('graph', 'figure'), 205 | Input("u-slider", "value"), 206 | Input("v-slider", "value"), 207 | 208 | Input("fov-slider", "value"), 209 | Input("foreground-near-depth-slider", "value"), 210 | 211 | Input("world-axis-size-input", "value"), 212 | Input("camera-axis-size-input", "value"), 213 | Input("sample-marker-size-input", "value"), 214 | 215 | Input("camera-color-input", "value"), 216 | Input("fg-color-input", "value"), 217 | Input("bg-color-input", "value"), 218 | 219 | Input('sphere-colorscale-input', "value"), 220 | Input('xy-plane-colorscale-input', "value"), 221 | Input('show-background-checklist', "value"), 222 | 223 | Input("sphere-opacity-input", 
"value"), 224 | Input("xy-plane-opacity-input", "value"), 225 | ) 226 | 227 | def update_figure(u, v, 228 | fov, foreground_near_depth, 229 | world_axis_size, camera_axis_size, sample_marker_size, 230 | camera_color, fg_color, bg_color, 231 | sphere_colorscale, xy_plane_colorscale, show_background, 232 | sphere_opacity, xy_plane_opacity 233 | ): 234 | 235 | depth_range = [foreground_near_depth, 2] 236 | 237 | # sphere 238 | fig = draw_sphere(r=1, sphere_colorscale=sphere_colorscale, sphere_opacity=sphere_opacity) 239 | 240 | # change figure size 241 | # fig.update_layout(autosize=False, width = 500, height=500) 242 | 243 | # draw axes in proportion to the proportion of their ranges 244 | fig.update_layout(scene_aspectmode='data') 245 | 246 | # xy plane 247 | fig = draw_XYplane(fig, xy_plane_colorscale, xy_plane_opacity, 248 | x_range=[-depth_range[1], depth_range[1]], y_range=[-depth_range[1], depth_range[1]]) 249 | 250 | # show world coordinate system (X, Y, Z positive direction) 251 | fig = draw_XYZworld(fig, world_axis_size=world_axis_size) 252 | 253 | pixels_world, camera_world, world_mat, fg_pts, bg_pts = nerfpp(u=u, v=v, fov=fov, depth_range=depth_range) 254 | 255 | # draw camera at init (with its cooridnate system) 256 | fig = draw_cam_init(fig, world_mat, 257 | camera_axis_size=camera_axis_size, camera_color=camera_color) 258 | 259 | # draw foreground and background sample point, ray, and frustrum 260 | fig = draw_foreground(fig, fg_pts, fg_color, sample_marker_size, at=[0, -1]) 261 | 262 | if show_background: 263 | fig = draw_background(fig, bg_pts.unsqueeze(0), bg_color, sample_marker_size, at=[0, -1]) 264 | 265 | return fig 266 | 267 | if __name__ == '__main__': 268 | app.run_server(debug=True) -------------------------------------------------------------------------------- /drawing_tools.py: -------------------------------------------------------------------------------- 1 | import plotly.graph_objects as go 2 | import plotly.express as px 3 | import numpy as np 4 | 5 | # draw sphere with radius r 6 | # also draw contours and vertical lines 7 | def draw_sphere(r, sphere_colorscale, sphere_opacity): 8 | # sphere 9 | u = np.linspace(0, 2 * np.pi, 100) 10 | v = np.linspace(0, np.pi, 100) 11 | x = r * np.outer(np.cos(u), np.sin(v)) 12 | y = r * np.outer(np.sin(u), np.sin(v)) 13 | z = r * np.outer(np.ones(np.size(u)), np.cos(v)) 14 | 15 | # vertical lines on sphere 16 | u2 = np.linspace(0, 2 * np.pi, 20) 17 | x2 = r * np.outer(np.cos(u2), np.sin(v)) 18 | y2 = r * np.outer(np.sin(u2), np.sin(v)) 19 | z2 = r * np.outer(np.ones(np.size(u2)), np.cos(v)) 20 | 21 | # create sphere and draw sphere with contours 22 | fig = go.Figure(data=[go.Surface(x=x, y=y, z=z, 23 | colorscale=sphere_colorscale, opacity=sphere_opacity, 24 | contours = { 25 | 'z' : {'show' : True, 'start' : -r, 26 | 'end' : r, 'size' : r/10, 27 | 'color' : 'white', 28 | 'width' : 1} 29 | } 30 | , showscale=False)]) 31 | 32 | # vertical lines on sphere 33 | for i in range(len(u2)): 34 | fig.add_scatter3d(x=x2[i], y=y2[i], z=z2[i], 35 | line=dict( 36 | color='white', 37 | width=1 38 | ), 39 | mode='lines', 40 | showlegend=False) 41 | 42 | return fig 43 | 44 | # draw xyplane 45 | def draw_XYplane(fig, xy_plane_colorscale, xy_plane_opacity, x_range = [-2, 2], y_range = [-2, 2]): 46 | x3 = np.linspace(x_range[0], x_range[1], 100) 47 | y3 = np.linspace(y_range[0], y_range[1], 100) 48 | z3 = np.zeros(shape=(100,100)) 49 | 50 | fig.add_surface(x=x3, y=y3, z=z3, 51 | colorscale =xy_plane_colorscale, opacity=xy_plane_opacity, 52 | 
showscale=False 53 | ) 54 | 55 | return fig 56 | 57 | 58 | def draw_XYZworld(fig, world_axis_size): 59 | # x, y, z positive direction (world) 60 | X_axis = [0, world_axis_size] 61 | X_text = [None, "X"] 62 | X0 = [0, 0] 63 | Y_axis = [0, world_axis_size] 64 | Y_text = [None, "Y"] 65 | Y0 = [0, 0] 66 | Z_axis = [0, world_axis_size] 67 | Z_text = [None, "Z"] 68 | Z0 = [0, 0] 69 | 70 | fig.add_scatter3d(x=X_axis, y=Y0, z=Z0, 71 | line=dict( 72 | color='red', 73 | width=10 74 | ), 75 | mode='lines+text', 76 | text=X_text, 77 | textposition='top center', 78 | textfont=dict( 79 | color="red", 80 | size=18 81 | ), 82 | showlegend=False) 83 | 84 | fig.add_scatter3d(x=X0, y=Y_axis, z=Z0, 85 | line=dict( 86 | color='green', 87 | width=10 88 | ), 89 | mode='lines+text', 90 | text=Y_text, 91 | textposition='top center', 92 | textfont=dict( 93 | color="green", 94 | size=18 95 | ), 96 | showlegend=False) 97 | 98 | fig.add_scatter3d(x=X0, y=Y0, z=Z_axis, 99 | line=dict( 100 | color='blue', 101 | width=10 102 | ), 103 | mode='lines+text', 104 | text=Z_text, 105 | textposition='top center', 106 | textfont=dict( 107 | color="blue", 108 | size=18 109 | ), 110 | showlegend=False) 111 | 112 | return fig 113 | 114 | # draw cam and cam coordinate system 115 | def draw_cam_init(fig, world_mat, camera_axis_size, camera_color): 116 | # camera at init 117 | 118 | Xc = [world_mat[0, : ,3][0]] 119 | Yc = [world_mat[0, : ,3][1]] 120 | Zc = [world_mat[0, : ,3][2]] 121 | text_c = ["Camera"] 122 | 123 | # camera axis 124 | Xc_Xaxis = Xc + [world_mat[0, : ,0][0]*camera_axis_size+Xc[0]] 125 | Yc_Xaxis = Yc + [world_mat[0, : ,0][1]*camera_axis_size+Yc[0]] 126 | Zc_Xaxis = Zc + [world_mat[0, : ,0][2]*camera_axis_size+Zc[0]] 127 | text_Xaxis = [None, "Xc"] 128 | 129 | # -z in world perspective 130 | Xc_Yaxis = Xc + [world_mat[0, : ,1][0]*camera_axis_size+Xc[0]] 131 | Yc_Yaxis = Yc + [world_mat[0, : ,1][1]*camera_axis_size+Yc[0]] 132 | Zc_Yaxis = Zc + [world_mat[0, : ,1][2]*camera_axis_size+Zc[0]] 133 | text_Yaxis = [None, "Yc"] 134 | 135 | # y in world perspective 136 | Xc_Zaxis = Xc + [world_mat[0, : ,2][0]*camera_axis_size+Xc[0]] 137 | Yc_Zaxis = Yc + [world_mat[0, : ,2][1]*camera_axis_size+Yc[0]] 138 | Zc_Zaxis = Zc + [world_mat[0, : ,2][2]*camera_axis_size+Zc[0]] 139 | text_Zaxis = [None, "Zc"] 140 | 141 | # cam pos 142 | fig.add_scatter3d(x=Xc, y=Yc, z=Zc, 143 | mode='markers', 144 | marker=dict( 145 | color=camera_color, 146 | size=4, 147 | sizemode='diameter' 148 | ), 149 | showlegend=False) 150 | 151 | # camera axis 152 | fig.add_scatter3d(x=Xc_Xaxis, y=Yc_Xaxis, z=Zc_Xaxis, 153 | line=dict( 154 | color='red', 155 | width=10 156 | ), 157 | mode='lines+text', 158 | text=text_Xaxis, 159 | textposition='top center', 160 | textfont=dict( 161 | color="red", 162 | size=18 163 | ), 164 | showlegend=False) 165 | 166 | fig.add_scatter3d(x=Xc_Yaxis, y=Yc_Yaxis, z=Zc_Yaxis, 167 | line=dict( 168 | color='green', 169 | width=10 170 | ), 171 | mode='lines+text', 172 | text=text_Yaxis, 173 | textposition='top center', 174 | textfont=dict( 175 | color="green", 176 | size=18 177 | ), 178 | showlegend=False) 179 | 180 | fig.add_scatter3d(x=Xc_Zaxis, y=Yc_Zaxis, z=Zc_Zaxis, 181 | line=dict( 182 | color='blue', 183 | width=10 184 | ), 185 | mode='lines+text', 186 | text=text_Zaxis, 187 | textposition='top center', 188 | textfont=dict( 189 | color="blue", 190 | size=18 191 | ), 192 | showlegend=False) 193 | 194 | return fig 195 | 196 | # draw all rays 197 | def draw_all_rays(fig, p_i, ray_color): 198 | for i in range(p_i.shape[1]): 199 | 
Xray = p_i[0, i, :, 0] 200 | Yray = p_i[0, i, :, 1] 201 | Zray = p_i[0, i, :, 2] 202 | 203 | fig.add_scatter3d(x=Xray, y=Yray, z=Zray, 204 | line=dict( 205 | color=ray_color, 206 | width=5 207 | ), 208 | mode='lines', 209 | showlegend=False) 210 | 211 | return fig 212 | 213 | # draw all rays 214 | def draw_all_rays_with_marker(fig, p_i, marker_size, ray_color): 215 | 216 | # convert colorscale string to px.colors.seqeuntial 217 | # default color is set to Viridis in case of mismatch 218 | c = px.colors.sequential.Viridis 219 | 220 | for c_name in [ray_color, ray_color.capitalize()]: 221 | try: 222 | c = getattr(px.colors.sequential, c_name) 223 | except: 224 | continue 225 | 226 | for i in range(p_i.shape[1]): 227 | Xray = p_i[0, i, :, 0] 228 | Yray = p_i[0, i, :, 1] 229 | Zray = p_i[0, i, :, 2] 230 | 231 | fig.add_scatter3d(x=Xray, y=Yray, z=Zray, 232 | 233 | marker=dict( 234 | # color=np.arange(len(Xray)), 235 | color=c, 236 | # colorscale='Viridis', 237 | size=marker_size 238 | ), 239 | 240 | line=dict( 241 | # color=np.arange(len(Xray)), 242 | color=c, 243 | # colorscale='Viridis', 244 | width=3 245 | ), 246 | mode="lines+markers", 247 | showlegend=False) 248 | 249 | return fig 250 | 251 | # draw near&far frustrum with rays connecting the corners (changed for nerfpp) 252 | def draw_ray_frus(fig, p_i, frustrum_color, frustrum_opacity, at=[0, -1]): 253 | 254 | for i in at: 255 | # Xfrus = p_i[0, :, i, 0][[0,1,2,3,7,11,15,14,13,12,8,4,0]] 256 | # Yfrus = p_i[0, :, i, 1][[0,1,2,3,7,11,15,14,13,12,8,4,0]] 257 | # Zfrus = p_i[0, :, i, 2][[0,1,2,3,7,11,15,14,13,12,8,4,0]] 258 | 259 | Xfrus = p_i[0, :, i, 0] 260 | Yfrus = p_i[0, :, i, 1] 261 | Zfrus = p_i[0, :, i, 2] 262 | 263 | fig.add_scatter3d(x=Xfrus, y=Yfrus, z=Zfrus, 264 | line=dict( 265 | color=frustrum_color, 266 | width=5 267 | ), 268 | mode='lines', 269 | surfaceaxis=0, 270 | surfacecolor=frustrum_color, 271 | opacity=frustrum_opacity, 272 | showlegend=False) 273 | 274 | return fig 275 | 276 | # draw foreground sample points, ray and frustrum 277 | def draw_foreground(fig, fg_pts, fg_color, marker_size, at=[0, -1]): 278 | fig = draw_all_rays_with_marker(fig, fg_pts, marker_size, fg_color) 279 | 280 | return fig 281 | 282 | # draw background sample points, ray and frustrum 283 | def draw_background(fig, bg_pts, bg_color, marker_size, at=[0, -1]): 284 | fig = draw_all_rays_with_marker(fig, bg_pts, marker_size, bg_color) 285 | 286 | return fig -------------------------------------------------------------------------------- /example.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/laphisboy/VisualizingNerfplusplus/4b86e2d15215b0baa29af6cc7e794c8521fa466a/example.gif -------------------------------------------------------------------------------- /nerfplusplus_tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | HUGE_NUMBER = 1e10 5 | TINY_NUMBER = 1e-6 6 | 7 | # from autonomousvision/giraffe 8 | 9 | # 0. 10 | def to_pytorch(tensor, return_type=False): 11 | ''' Converts input tensor to pytorch. 12 | Args: 13 | tensor (tensor): Numpy or Pytorch tensor 14 | return_type (bool): whether to return input type 15 | ''' 16 | is_numpy = False 17 | if type(tensor) == np.ndarray: 18 | tensor = torch.from_numpy(tensor).float() 19 | is_numpy = True 20 | tensor = tensor.clone() 21 | if return_type: 22 | return tensor, is_numpy 23 | return tensor 24 | 25 | # 1. 
get camera intrinsic 26 | def get_camera_mat(fov=49.13, invert=True): 27 | # fov = 2 * arctan( sensor / (2 * focal)) 28 | # focal = (sensor / 2) * 1 / (tan(0.5 * fov)) 29 | # in our case, sensor = 2 as pixels are in [-1, 1] 30 | focal = 1. / np.tan(0.5 * fov * np.pi/180.) 31 | focal = focal.astype(np.float32) 32 | mat = torch.tensor([ 33 | [focal, 0., 0., 0.], 34 | [0., focal, 0., 0.], 35 | [0., 0., 1, 0.], 36 | [0., 0., 0., 1.] 37 | ]).reshape(1, 4, 4) 38 | 39 | if invert: 40 | mat = torch.inverse(mat) 41 | return mat 42 | 43 | # 2. get camera position with camera pose (theta & phi) 44 | def to_sphere(u, v): 45 | theta = 2 * np.pi * u 46 | phi = np.arccos(1 - 2 * v) 47 | cx = np.sin(phi) * np.cos(theta) 48 | cy = np.sin(phi) * np.sin(theta) 49 | cz = np.cos(phi) 50 | return np.stack([cx, cy, cz], axis=-1) 51 | 52 | # 3. get camera coordinate system assuming it points to the center of the sphere 53 | def look_at(eye, at=np.array([0, 0, 0]), up=np.array([0, 0, 1]), eps=1e-5, 54 | to_pytorch=True): 55 | at = at.astype(float).reshape(1, 3) 56 | up = up.astype(float).reshape(1, 3) 57 | eye = eye.reshape(-1, 3) 58 | up = up.repeat(eye.shape[0] // up.shape[0], axis=0) 59 | eps = np.array([eps]).reshape(1, 1).repeat(up.shape[0], axis=0) 60 | 61 | z_axis = eye - at 62 | z_axis /= np.max(np.stack([np.linalg.norm(z_axis, 63 | axis=1, keepdims=True), eps])) 64 | 65 | x_axis = np.cross(up, z_axis) 66 | x_axis /= np.max(np.stack([np.linalg.norm(x_axis, 67 | axis=1, keepdims=True), eps])) 68 | 69 | y_axis = np.cross(z_axis, x_axis) 70 | y_axis /= np.max(np.stack([np.linalg.norm(y_axis, 71 | axis=1, keepdims=True), eps])) 72 | 73 | r_mat = np.concatenate( 74 | (x_axis.reshape(-1, 3, 1), y_axis.reshape(-1, 3, 1), z_axis.reshape( 75 | -1, 3, 1)), axis=2) 76 | 77 | if to_pytorch: 78 | r_mat = torch.tensor(r_mat).float() 79 | 80 | return r_mat 81 | 82 | # 5. arange 2d array of pixel coordinate and give depth of 1 83 | def arange_pixels(resolution=(128, 128), batch_size=1, image_range=(-1., 1.), 84 | subsample_to=None, invert_y_axis=False): 85 | ''' Arranges pixels for given resolution in range image_range. 86 | The function returns the unscaled pixel locations as integers and the 87 | scaled float values. 
88 | Args: 89 | resolution (tuple): image resolution 90 | batch_size (int): batch size 91 | image_range (tuple): range of output points (default [-1, 1]) 92 | subsample_to (int): if integer and > 0, the points are randomly 93 | subsampled to this value 94 | ''' 95 | h, w = resolution 96 | n_points = resolution[0] * resolution[1] 97 | 98 | # Arrange pixel location in scale resolution 99 | pixel_locations = torch.meshgrid(torch.arange(0, w), torch.arange(0, h)) 100 | pixel_locations = torch.stack( 101 | [pixel_locations[0], pixel_locations[1]], 102 | dim=-1).long().view(1, -1, 2).repeat(batch_size, 1, 1) 103 | pixel_scaled = pixel_locations.clone().float() 104 | 105 | # Shift and scale points to match image_range 106 | scale = (image_range[1] - image_range[0]) 107 | loc = scale / 2 108 | pixel_scaled[:, :, 0] = scale * pixel_scaled[:, :, 0] / (w - 1) - loc 109 | pixel_scaled[:, :, 1] = scale * pixel_scaled[:, :, 1] / (h - 1) - loc 110 | 111 | # Subsample points if subsample_to is not None and > 0 112 | if (subsample_to is not None and subsample_to > 0 and 113 | subsample_to < n_points): 114 | idx = np.random.choice(pixel_scaled.shape[1], size=(subsample_to,), 115 | replace=False) 116 | pixel_scaled = pixel_scaled[:, idx] 117 | pixel_locations = pixel_locations[:, idx] 118 | 119 | if invert_y_axis: 120 | assert(image_range == (-1, 1)) 121 | pixel_scaled[..., -1] *= -1. 122 | pixel_locations[..., -1] = (h - 1) - pixel_locations[..., -1] 123 | 124 | return pixel_locations, pixel_scaled 125 | 126 | # 6. mat_mul with intrinsic and then extrinsic gives you p_world (pixels in world) 127 | def image_points_to_world(image_points, camera_mat, world_mat, scale_mat=None, 128 | invert=False, negative_depth=True): 129 | ''' Transforms points on image plane to world coordinates. 130 | In contrast to transform_to_world, no depth value is needed as points on 131 | the image plane have a fixed depth of 1. 132 | Args: 133 | image_points (tensor): image points tensor of size B x N x 2 134 | camera_mat (tensor): camera matrix 135 | world_mat (tensor): world matrix 136 | scale_mat (tensor): scale matrix 137 | invert (bool): whether to invert matrices (default: False) 138 | ''' 139 | batch_size, n_pts, dim = image_points.shape 140 | assert(dim == 2) 141 | d_image = torch.ones(batch_size, n_pts, 1) 142 | if negative_depth: 143 | d_image *= -1. 144 | return transform_to_world(image_points, d_image, camera_mat, world_mat, 145 | scale_mat, invert=invert) 146 | 147 | def transform_to_world(pixels, depth, camera_mat, world_mat, scale_mat=None, 148 | invert=True, use_absolute_depth=True): 149 | ''' Transforms pixel positions p with given depth value d to world coordinates. 
150 | Args: 151 | pixels (tensor): pixel tensor of size B x N x 2 152 | depth (tensor): depth tensor of size B x N x 1 153 | camera_mat (tensor): camera matrix 154 | world_mat (tensor): world matrix 155 | scale_mat (tensor): scale matrix 156 | invert (bool): whether to invert matrices (default: true) 157 | ''' 158 | assert(pixels.shape[-1] == 2) 159 | 160 | if scale_mat is None: 161 | scale_mat = torch.eye(4).unsqueeze(0).repeat( 162 | camera_mat.shape[0], 1, 1) 163 | 164 | # Convert to pytorch 165 | pixels, is_numpy = to_pytorch(pixels, True) 166 | depth = to_pytorch(depth) 167 | camera_mat = to_pytorch(camera_mat) 168 | world_mat = to_pytorch(world_mat) 169 | scale_mat = to_pytorch(scale_mat) 170 | 171 | # Invert camera matrices 172 | if invert: 173 | camera_mat = torch.inverse(camera_mat) 174 | world_mat = torch.inverse(world_mat) 175 | scale_mat = torch.inverse(scale_mat) 176 | 177 | # Transform pixels to homogen coordinates 178 | pixels = pixels.permute(0, 2, 1) 179 | pixels = torch.cat([pixels, torch.ones_like(pixels)], dim=1) 180 | 181 | # Project pixels into camera space 182 | if use_absolute_depth: 183 | pixels[:, :2] = pixels[:, :2] * depth.permute(0, 2, 1).abs() 184 | pixels[:, 2:3] = pixels[:, 2:3] * depth.permute(0, 2, 1) 185 | else: 186 | pixels[:, :3] = pixels[:, :3] * depth.permute(0, 2, 1) 187 | 188 | # Transform pixels to world space 189 | p_world = scale_mat @ world_mat @ camera_mat @ pixels 190 | 191 | # Transform p_world back to 3D coordinates 192 | p_world = p_world[:, :3].permute(0, 2, 1) 193 | 194 | if is_numpy: 195 | p_world = p_world.numpy() 196 | return p_world 197 | 198 | 199 | # 7. mat_mul zeros with intrinsic&extrinsic for camera pos (which we alread obtained as loc) 200 | def origin_to_world(n_points, camera_mat, world_mat, scale_mat=None, 201 | invert=False): 202 | ''' Transforms origin (camera location) to world coordinates. 203 | Args: 204 | n_points (int): how often the transformed origin is repeated in the 205 | form (batch_size, n_points, 3) 206 | camera_mat (tensor): camera matrix 207 | world_mat (tensor): world matrix 208 | scale_mat (tensor): scale matrix 209 | invert (bool): whether to invert the matrices (default: false) 210 | ''' 211 | 212 | batch_size = camera_mat.shape[0] 213 | device = camera_mat.device 214 | # Create origin in homogen coordinates 215 | p = torch.zeros(batch_size, 4, n_points).to(device) 216 | p[:, -1] = 1. 217 | 218 | if scale_mat is None: 219 | scale_mat = torch.eye(4).unsqueeze( 220 | 0).repeat(batch_size, 1, 1).to(device) 221 | 222 | # Invert matrices 223 | if invert: 224 | camera_mat = torch.inverse(camera_mat) 225 | world_mat = torch.inverse(world_mat) 226 | scale_mat = torch.inverse(scale_mat) 227 | 228 | camera_mat = to_pytorch(camera_mat) 229 | world_mat = to_pytorch(world_mat) 230 | scale_mat = to_pytorch(scale_mat) 231 | 232 | # Apply transformation 233 | p_world = scale_mat @ world_mat @ camera_mat @ p 234 | 235 | # Transform points back to 3D coordinates 236 | p_world = p_world[:, :3].permute(0, 2, 1) 237 | return p_world 238 | 239 | 240 | # from Kai-46/nerfplusplus 241 | 242 | # 8. 
intersect sphere for distinguishing fg and bg 243 | def intersect_sphere(ray_o, ray_d): 244 | ''' 245 | ray_o, ray_d: [..., 3] 246 | compute the depth of the intersection point between this ray and unit sphere 247 | ''' 248 | # note: d1 becomes negative if this mid point is behind camera 249 | d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1) 250 | p = ray_o + d1.unsqueeze(-1) * ray_d 251 | # consider the case where the ray does not intersect the sphere 252 | ray_d_cos = 1. / torch.norm(ray_d, dim=-1) 253 | p_norm_sq = torch.sum(p * p, dim=-1) 254 | if (p_norm_sq >= 1.).any(): 255 | raise Exception('Not all your cameras are bounded by the unit sphere; please make sure the cameras are normalized properly!') 256 | d2 = torch.sqrt(1. - p_norm_sq) * ray_d_cos 257 | 258 | return d1 + d2 259 | 260 | # 9. inverse sphere sampling for bg 261 | def depth2pts_outside(ray_o, ray_d, depth): 262 | ''' 263 | ray_o, ray_d: [..., 3] 264 | depth: [...]; inverse of distance to sphere origin 265 | ''' 266 | # note: d1 becomes negative if this mid point is behind camera 267 | d1 = -torch.sum(ray_d * ray_o, dim=-1) / torch.sum(ray_d * ray_d, dim=-1) 268 | p_mid = ray_o + d1.unsqueeze(-1) * ray_d 269 | p_mid_norm = torch.norm(p_mid, dim=-1) 270 | ray_d_cos = 1. / torch.norm(ray_d, dim=-1) 271 | d2 = torch.sqrt(1. - p_mid_norm * p_mid_norm) * ray_d_cos 272 | 273 | p_sphere = ray_o + (d1 + d2).unsqueeze(-1) * ray_d 274 | 275 | rot_axis = torch.cross(ray_o, p_sphere, dim=-1) 276 | rot_axis = rot_axis / torch.norm(rot_axis, dim=-1, keepdim=True) 277 | phi = torch.asin(p_mid_norm) 278 | theta = torch.asin(p_mid_norm * depth) # depth is inside [0, 1] 279 | rot_angle = (phi - theta).unsqueeze(-1) # [..., 1] 280 | 281 | # now rotate p_sphere 282 | # Rodrigues formula: https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula 283 | p_sphere_new = p_sphere * torch.cos(rot_angle) + \ 284 | torch.cross(rot_axis, p_sphere, dim=-1) * torch.sin(rot_angle) + \ 285 | rot_axis * torch.sum(rot_axis*p_sphere, dim=-1, keepdim=True) * (1.-torch.cos(rot_angle)) 286 | p_sphere_new = p_sphere_new / torch.norm(p_sphere_new, dim=-1, keepdim=True) 287 | pts = torch.cat((p_sphere_new.squeeze(), depth.squeeze().unsqueeze(-1)), dim=-1) # (modified) added .squeeze() 288 | 289 | # now calculate conventional depth 290 | depth_real = 1. / (depth + TINY_NUMBER) * torch.cos(theta) * ray_d_cos + d1 291 | return pts, depth_real 292 | 293 | # 10. perturb sample z values (depth) for some randomness (stratified sampling) 294 | def perturb_samples(z_vals): 295 | # get intervals between samples 296 | mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) 297 | upper = torch.cat([mids, z_vals[..., -1:]], dim=-1) 298 | lower = torch.cat([z_vals[..., 0:1], mids], dim=-1) 299 | # uniform samples in those intervals 300 | t_rand = torch.rand_like(z_vals) 301 | z_vals = lower + (upper - lower) * t_rand # [N_rays, N_samples] 302 | 303 | return z_vals 304 | 305 | # 11. 
304 | 
305 | # 11. return fg (just uniform) and bg (uniform inverse sphere) samples
306 | def uniformsampling(ray_o, ray_d, min_depth, N_samples, device):
307 | 
308 |     dots_sh = list(ray_d.shape[:-1])
309 | 
310 |     # foreground depth
311 |     fg_far_depth = intersect_sphere(ray_o, ray_d)  # [...,]
312 |     fg_near_depth = min_depth  # [...,]
313 |     step = (fg_far_depth - fg_near_depth) / (N_samples - 1)
314 |     fg_depth = torch.stack([fg_near_depth + i * step for i in range(N_samples)], dim=-1)  # [..., N_samples]
315 |     fg_depth = perturb_samples(fg_depth)  # random perturbation during training
316 | 
317 |     # background depth
318 |     bg_depth = torch.linspace(0., 1., N_samples).view(
319 |         [1, ] * len(dots_sh) + [N_samples, ]).expand(dots_sh + [N_samples, ]).to(device)
320 |     bg_depth = perturb_samples(bg_depth)  # random perturbation during training
321 | 
322 | 
323 |     fg_ray_o = ray_o.unsqueeze(-2).expand(dots_sh + [N_samples, 3])
324 |     fg_ray_d = ray_d.unsqueeze(-2).expand(dots_sh + [N_samples, 3])
325 | 
326 |     bg_ray_o = ray_o.unsqueeze(-2).expand(dots_sh + [N_samples, 3])
327 |     bg_ray_d = ray_d.unsqueeze(-2).expand(dots_sh + [N_samples, 3])
328 | 
329 |     # sampling foreground
330 |     fg_pts = fg_ray_o + fg_depth.unsqueeze(-1) * fg_ray_d
331 | 
332 |     # sampling background
333 |     bg_pts, bg_depth_real = depth2pts_outside(bg_ray_o, bg_ray_d, bg_depth)
334 | 
335 |     return fg_pts, bg_pts, bg_depth_real
336 | 
337 | # 12. convert bg pts (x', y', z', 1/r) to real bg pts (x, y, z)
338 | def pts2realpts(bg_pts, bg_depth_real):
339 |     return bg_pts[:, :, :3] * bg_depth_real.squeeze().unsqueeze(-1)
340 | 
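# A shape sketch for uniformsampling (assumed example, not from the original
# file): two rays whose shared origin lies inside the unit sphere. fg_pts are
# Euclidean points between min_depth and the sphere; bg_pts carry the NeRF++
# parameterization (x', y', z', 1/r); bg_depth_real is the matching depth
# along each ray.
def _check_uniformsampling_shapes():
    ray_o = torch.tensor([[0.2, 0.1, -0.3]]).repeat(2, 1)
    ray_d = torch.tensor([[0.5, -0.4, 1.0], [0.1, 0.9, -0.2]])
    fg_pts, bg_pts, bg_depth_real = uniformsampling(
        ray_o, ray_d, min_depth=0.5, N_samples=8, device=torch.device('cpu'))
    assert fg_pts.shape == (2, 8, 3)
    assert bg_pts.shape == (2, 8, 4)
    assert bg_depth_real.shape == (2, 8)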
341 | def nerfpp(u=1,
342 |            v=0.5,
343 |            fov=49.13,
344 |            depth_range=[0.5, 6.],
345 |            n_ray_samples=16,
346 |            resolution_vol=4,
347 |            batch_size=1,
348 |            device=torch.device('cpu')
349 |            ):
350 | 
351 |     r = 1
352 |     range_radius = [r, r]
353 | 
354 |     res = resolution_vol
355 |     n_points = res * res
356 | 
357 |     # 1. get camera intrinsic - fiddle around with fov this time
358 |     camera_mat = get_camera_mat(fov=fov)
359 | 
360 |     # 2. get camera position with camera pose (theta & phi)
361 |     loc = to_sphere(u, v)
362 |     loc = torch.tensor(loc).float()
363 | 
364 |     radius = range_radius[0] + \
365 |              torch.rand(batch_size) * (range_radius[1] - range_radius[0])
366 | 
367 |     loc = loc * radius.unsqueeze(-1)
368 | 
369 |     # 3. get camera coordinate system assuming it points to the center of the sphere
370 |     R = look_at(loc)
371 | 
372 |     # 4. the camera coordinate system is the rotation matrix; together with the camera loc it forms the camera extrinsic
373 |     RT = np.eye(4).reshape(1, 4, 4)
374 |     RT[:, :3, :3] = R
375 |     RT[:, :3, -1] = loc
376 |     world_mat = RT
377 | 
378 |     # 5. arrange 2d array of pixel coordinates and give depth of 1
379 |     pixels = arange_pixels((res, res), 1, invert_y_axis=False)[1]
380 |     pixels[..., -1] *= -1.  # flip the y-axis of the pixel coordinates
381 | 
382 |     # 6. mat_mul with intrinsic and then extrinsic gives you p_world (pixels in world)
383 |     pixels_world = image_points_to_world(pixels, camera_mat, world_mat)
384 | 
385 |     # 7. mat_mul zeros with intrinsic & extrinsic to get the camera position (which we already obtained as loc)
386 |     camera_world = origin_to_world(n_points, camera_mat, world_mat)
387 | 
388 |     # 8. ray = pixel - camera origin (in world)
389 |     ray_vector = pixels_world - camera_world
390 | 
391 |     # 9. sample fg and bg points according to nerfpp (uniform and inverse sphere)
392 |     fg_pts, bg_pts, bg_depth_real = uniformsampling(ray_o=camera_world, ray_d=ray_vector, min_depth=depth_range[0], N_samples=n_ray_samples, device=device)
393 | 
394 |     # 10. convert bg pts (x', y', z', 1/r) to real bg pts (x, y, z)
395 |     bg_pts_real = pts2realpts(bg_pts, bg_depth_real)
396 | 
397 |     return pixels_world, camera_world, world_mat, fg_pts, bg_pts_real
398 | 
--------------------------------------------------------------------------------
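A brief usage sketch (an assumed example, not part of the repository files): with the defaults above, resolution_vol=4 gives 16 rays and n_ray_samples=16 samples per ray, so the returned tensors should come out with the following shapes:

    pixels_world, camera_world, world_mat, fg_pts, bg_pts_real = nerfpp()
    # pixels_world: (1, 16, 3)     pixel grid lifted to world space at depth 1
    # camera_world: (1, 16, 3)     the camera origin repeated once per ray
    # world_mat:    (1, 4, 4)      numpy extrinsic built from look_at and loc
    # fg_pts:       (1, 16, 16, 3) foreground samples inside the unit sphere
    # bg_pts_real:  (16, 16, 3)    background samples mapped back to Euclidean space
    #                              (the batch dimension is squeezed away for batch_size=1)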