├── README.md
├── multi_view_code
│   ├── bunny.sdf
│   └── code
│       ├── .DS_Store
│       ├── main.py
│       ├── renderer.cpp
│       ├── renderer_kernel.cu
│       └── setup.py
├── single_view_code
│   ├── differentiable_rendering.py
│   ├── main.py
│   ├── models.py
│   ├── renderer.cpp
│   ├── renderer_kernel.cu
│   └── setup.py
└── virtual_env
    ├── install_conda.sh
    └── install_pip.sh

/README.md:
--------------------------------------------------------------------------------
1 | # SDFDiff: Differentiable Rendering of Signed Distance Fields for 3D Shape Optimization
2 | 
3 | **IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2020 (Oral)**
4 | 
5 | Authors: **Yue Jiang, Dantong Ji, Zhizhong Han, Matthias Zwicker**
6 | 
7 | **Project page:** https://yuejiang-nj.github.io/papers/CVPR2020_SDFDiff/project_page.html
8 | 
9 | **Paper:** http://www.cs.umd.edu/~yuejiang/papers/SDFDiff.pdf
10 | 
11 | **Video:** https://www.youtube.com/watch?v=l3h9JZHAOqI&t=13s
12 | 
13 | **Talk:** https://youtu.be/0A83pElG5gk
14 | 
15 | 
16 | ## Prerequisite Installation
17 | 
18 | 1. Python 3
19 | 2. CUDA 10
20 | 3. PyTorch
21 | 
22 | 
23 | ## To Get Started
24 | 
25 | SDFDiff has been implemented and tested on Ubuntu 18.04 with Python >= 3.7.
26 | 
27 | Clone the repo:
28 | ``` bash
29 | git clone https://github.com/YueJiang-nj/CVPR2020-SDFDiff.git
30 | ```
31 | 
32 | Install the requirements using `virtualenv` or `conda`:
33 | ``` bash
34 | # pip
35 | source virtual_env/install_pip.sh
36 | 
37 | # conda
38 | source virtual_env/install_conda.sh
39 | ```
40 | 
41 | ## Introduction
42 | 
43 | The project has the following file layout:
44 | 
45 |     README.md
46 |     multi_view_code/
47 |         bunny.sdf
48 |         dragon.sdf
49 |         code/
50 |             main.py
51 |             renderer.cpp
52 |             renderer_kernel.cu
53 |             setup.py
54 |     single_view_code/
55 |         differentiable_rendering.py
56 |         main.py
57 |         models.py
58 |         renderer.cpp
59 |         renderer_kernel.cu
60 |         setup.py
61 | 
62 | 
63 | **multi_view_code** contains the source code for multi-view 3D reconstruction using our SDFDiff.
64 | 
65 | **single_view_code** contains the source code for single-view 3D reconstruction using our SDFDiff and deep learning models.
66 | 
67 | ## Running the Demo
68 | 
69 | We have prepared a demo to run SDFDiff on a bunny object.
70 | 
71 | To run the multi-view 3D reconstruction on the bunny, follow these steps in the folder `multi_view_code/code`:
72 | 
73 | ``` bash
74 | python setup.py install   # 1. compile our SDF differentiable renderer
75 | 
76 | python main.py            # 2. once built, run the bunny reconstruction example
77 | ```
78 | 
79 | ## Parameter Tuning
80 | 
81 | There are two kinds of parameters you can modify to get better results:
82 | 
83 | 1. **Weighted loss.** In `main.py`, the total loss is computed as `loss = image_loss[cam] + sdf_loss[cam] + Lp_loss`.
84 |    You can weight the terms instead, `loss = a * image_loss[cam] + b * sdf_loss[cam] + c * Lp_loss`, and try different values of `a`, `b`, and `c`.
85 |    For example, the reconstructed surface becomes smoother if you increase `c`.
86 | 
87 | 2. **Intermediate resolutions.** The resolution schedule is set by the line `voxel_res_list = [8,16,24,32,40,48,56,64]`.
88 |    You can add more intermediate resolutions to this list.
89 |    Using more intermediate resolutions generally produces better results.
90 | 
91 | 
92 | 
93 | ## Generating SDF from Mesh
94 | 
95 | If you have a mesh file xxx.obj, you need to generate an SDF from it before you can run our SDFDiff code.
96 | 
97 | First, you need to clone the following two tools (the `git clone` commands are shown after the short format note below).
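The `.sdf` files used by this code (for example the provided `multi_view_code/bunny.sdf`) are plain text: a header line with the grid resolution, a line with the bounding-box minimum corner, a line with the voxel size, and then one SDF value per line, with x varying fastest. This is the layout that `read_sdf` in `multi_view_code/code/main.py` parses. As a small sketch for sanity-checking such a file (the `load_sdf` helper and the NumPy dependency are introduced here for illustration and are not part of the repo):

``` python
import numpy as np

def load_sdf(path):
    """Load a plain-text .sdf file: header (resolution, bbox min, voxel size), then one value per line."""
    with open(path) as f:
        res = [int(v) for v in f.readline().split()]          # nx ny nz
        bbox_min = [float(v) for v in f.readline().split()]   # minimum corner of the bounding box
        voxel_size = float(f.readline())                      # spacing between neighbouring grid points
        values = np.array([float(line) for line in f], dtype=np.float32)
    # Values are listed with x varying fastest and z slowest (see read_sdf in main.py),
    # so reshape to [z, y, x] and transpose so the array is indexed as sdf[x, y, z].
    sdf = values.reshape(res[2], res[1], res[0]).transpose(2, 1, 0)
    return np.array(res), np.array(bbox_min), voxel_size, sdf

# Example: res, bbox_min, h, sdf = load_sdf("multi_view_code/bunny.sdf")
```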
98 | 99 | ``` bash 100 | # a tool to generate watertight meshes from arbitrary meshes 101 | git clone https://github.com/hjwdzh/Manifold.git 102 | 103 | # A tool to generate SDF from watertight meshes 104 | git clone https://github.com/christopherbatty/SDFGen.git 105 | ``` 106 | 107 | Then you can run the following to get SDF from your mesh file xxx.obj. 108 | 109 | ``` bash 110 | # Generate watertight meshes from arbitrary meshes 111 | ./Manifold/build/manifold ./obj_files/xxx.obj ./watertight_meshes_and_sdfs/xxx.obj 112 | 113 | # Generate SDF from watertight meshes 114 | ./SDFGen/build/bin/SDFGen ./watertight_meshes_and_sdfs/xxx.obj 0.002 0 115 | ``` 116 | 117 | ## Citation 118 | ```bibtex 119 | @InProceedings{jiang2020sdfdiff, 120 | author = {Jiang, Yue and Ji, Dantong and Han, Zhizhong and Zwicker, Matthias}, 121 | title = {SDFDiff: Differentiable Rendering of Signed Distance Fields for 3D Shape Optimization}, 122 | booktitle = {The IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 123 | month = {June}, 124 | year = {2020} 125 | } 126 | ``` 127 | -------------------------------------------------------------------------------- /multi_view_code/code/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YueJiang-nj/CVPR2020-SDFDiff/d44883497cc8dd8bdf106fed408ea0a50b27af76/multi_view_code/code/.DS_Store -------------------------------------------------------------------------------- /multi_view_code/code/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import math 4 | import torchvision 5 | from torch.autograd import Variable 6 | import renderer 7 | import time 8 | import sys, os 9 | import torch.nn.functional as F 10 | from torchvision.utils import save_image, make_grid 11 | 12 | def read_txt(file_path, grid_res_x, grid_res_y, grid_res_z): 13 | with open(file_path) as file: 14 | grid = Tensor(grid_res_x, grid_res_y, grid_res_z) 15 | for i in range(grid_res_x): 16 | for j in range(grid_res_y): 17 | for k in range(grid_res_z): 18 | grid[i][j][k] = float(file.readline()) 19 | print (grid) 20 | 21 | return grid 22 | 23 | # Read a file and create a sdf grid with target_grid_res 24 | def read_sdf(file_path, target_grid_res, target_bounding_box_min, target_bounding_box_max, target_voxel_size): 25 | 26 | with open(file_path) as file: 27 | line = file.readline() 28 | 29 | # Get grid resolutions 30 | grid_res = line.split() 31 | grid_res_x = int(grid_res[0]) 32 | grid_res_y = int(grid_res[1]) 33 | grid_res_z = int(grid_res[2]) 34 | 35 | # Get bounding box min 36 | line = file.readline() 37 | bounding_box_min = line.split() 38 | bounding_box_min_x = float(bounding_box_min[0]) 39 | bounding_box_min_y = float(bounding_box_min[1]) 40 | bounding_box_min_z = float(bounding_box_min[2]) 41 | 42 | line = file.readline() 43 | voxel_size = float(line) 44 | 45 | # max bounding box (we need to plus 0.0001 to avoid round error) 46 | bounding_box_max_x = bounding_box_min_x + voxel_size * (grid_res_x - 1) 47 | bounding_box_max_y = bounding_box_min_y + voxel_size * (grid_res_y - 1) 48 | bounding_box_max_z = bounding_box_min_z + voxel_size * (grid_res_z - 1) 49 | 50 | min_bounding_box_min = min(bounding_box_min_x, bounding_box_min_y, bounding_box_min_z) 51 | # print(bounding_box_min_x, bounding_box_min_y, bounding_box_min_z) 52 | max_bounding_box_max = max(bounding_box_max_x, bounding_box_max_y, bounding_box_max_z) 53 | # 
print(bounding_box_max_x, bounding_box_max_y, bounding_box_max_z) 54 | max_dist = max(bounding_box_max_x - bounding_box_min_x, bounding_box_max_y - bounding_box_min_y, bounding_box_max_z - bounding_box_min_z) 55 | 56 | # max_dist += 0.1 57 | max_grid_res = max(grid_res_x, grid_res_y, grid_res_z) 58 | 59 | grid = [] 60 | for i in range(grid_res_x): 61 | grid.append([]) 62 | for j in range(grid_res_y): 63 | grid[i].append([]) 64 | for k in range(grid_res_z): 65 | # grid_value = float(file.readline()) 66 | grid[i][j].append(2) 67 | # lst.append(grid_value) 68 | 69 | for i in range(grid_res_z): 70 | for j in range(grid_res_y): 71 | for k in range(grid_res_x): 72 | grid_value = float(file.readline()) 73 | grid[k][j][i] = grid_value 74 | 75 | grid = Tensor(grid) 76 | target_grid = Tensor(target_grid_res, target_grid_res, target_grid_res) 77 | 78 | linear_space_x = torch.linspace(0, target_grid_res-1, target_grid_res) 79 | linear_space_y = torch.linspace(0, target_grid_res-1, target_grid_res) 80 | linear_space_z = torch.linspace(0, target_grid_res-1, target_grid_res) 81 | first_loop = linear_space_x.repeat(target_grid_res * target_grid_res, 1).t().contiguous().view(-1).unsqueeze_(1) 82 | second_loop = linear_space_y.repeat(target_grid_res, target_grid_res).t().contiguous().view(-1).unsqueeze_(1) 83 | third_loop = linear_space_z.repeat(target_grid_res * target_grid_res).unsqueeze_(1) 84 | loop = torch.cat((first_loop, second_loop, third_loop), 1).cuda() 85 | 86 | min_x = Tensor([bounding_box_min_x]).repeat(target_grid_res*target_grid_res*target_grid_res, 1) 87 | min_y = Tensor([bounding_box_min_y]).repeat(target_grid_res*target_grid_res*target_grid_res, 1) 88 | min_z = Tensor([bounding_box_min_z]).repeat(target_grid_res*target_grid_res*target_grid_res, 1) 89 | bounding_min_matrix = torch.cat((min_x, min_y, min_z), 1) 90 | 91 | move_to_center_x = Tensor([(max_dist - (bounding_box_max_x - bounding_box_min_x)) / 2]).repeat(target_grid_res*target_grid_res*target_grid_res, 1) 92 | move_to_center_y = Tensor([(max_dist - (bounding_box_max_y - bounding_box_min_y)) / 2]).repeat(target_grid_res*target_grid_res*target_grid_res, 1) 93 | move_to_center_z = Tensor([(max_dist - (bounding_box_max_z - bounding_box_min_z)) / 2]).repeat(target_grid_res*target_grid_res*target_grid_res, 1) 94 | move_to_center_matrix = torch.cat((move_to_center_x, move_to_center_y, move_to_center_z), 1) 95 | 96 | # Get the position of the grid points in the refined grid 97 | points = bounding_min_matrix + target_voxel_size * max_dist / (target_bounding_box_max - target_bounding_box_min) * loop - move_to_center_matrix 98 | if points[(points[:, 0] < bounding_box_min_x)].shape[0] != 0: 99 | points[(points[:, 0] < bounding_box_min_x)] = Tensor([bounding_box_max_x, bounding_box_max_y, bounding_box_max_z]).view(1,3) 100 | if points[(points[:, 1] < bounding_box_min_y)].shape[0] != 0: 101 | points[(points[:, 1] < bounding_box_min_y)] = Tensor([bounding_box_max_x, bounding_box_min_y, bounding_box_min_z]).view(1,3) 102 | if points[(points[:, 2] < bounding_box_min_z)].shape[0] != 0: 103 | points[(points[:, 2] < bounding_box_min_z)] = Tensor([bounding_box_max_x, bounding_box_min_y, bounding_box_min_z]).view(1,3) 104 | if points[(points[:, 0] > bounding_box_max_x)].shape[0] != 0: 105 | points[(points[:, 0] > bounding_box_max_x)] = Tensor([bounding_box_max_x, bounding_box_min_y, bounding_box_min_z]).view(1,3) 106 | if points[(points[:, 1] > bounding_box_max_y)].shape[0] != 0: 107 | points[(points[:, 1] > bounding_box_max_y)] = 
Tensor([bounding_box_max_x, bounding_box_min_y, bounding_box_min_z]).view(1,3) 108 | if points[(points[:, 2] > bounding_box_max_z)].shape[0] != 0: 109 | points[(points[:, 2] > bounding_box_max_z)] = Tensor([bounding_box_max_x, bounding_box_min_y, bounding_box_min_z]).view(1,3) 110 | voxel_min_point_index_x = torch.floor((points[:,0].unsqueeze_(1) - min_x) / voxel_size).clamp(max=grid_res_x-2) 111 | voxel_min_point_index_y = torch.floor((points[:,1].unsqueeze_(1) - min_y) / voxel_size).clamp(max=grid_res_y-2) 112 | voxel_min_point_index_z = torch.floor((points[:,2].unsqueeze_(1) - min_z) / voxel_size).clamp(max=grid_res_z-2) 113 | voxel_min_point_index = torch.cat((voxel_min_point_index_x, voxel_min_point_index_y, voxel_min_point_index_z), 1) 114 | voxel_min_point = bounding_min_matrix + voxel_min_point_index * voxel_size 115 | 116 | # Compute the sdf value of the grid points in the refined grid 117 | target_grid = calculate_sdf_value(grid, points, voxel_min_point, voxel_min_point_index, voxel_size, grid_res_x, grid_res_y, grid_res_z).view(target_grid_res, target_grid_res, target_grid_res) 118 | 119 | # "shortest path" algorithm to fill the values (for changing from "cuboid" SDF to "cube" SDF) 120 | # min of the SDF values of the closest points + the distance to these points 121 | # calculate the max resolution get which areas we need to compute the shortest path 122 | max_res = max(grid_res_x, grid_res_y, grid_res_z) 123 | if grid_res_x == max_res: 124 | min_x = 0 125 | max_x = target_grid_res - 1 126 | min_y = math.ceil((target_grid_res - target_grid_res / float(grid_res_x) * grid_res_y) / 2) 127 | max_y = target_grid_res - min_y - 1 128 | min_z = math.ceil((target_grid_res - target_grid_res / float(grid_res_x) * grid_res_z) / 2) 129 | max_z = target_grid_res - min_z - 1 130 | if grid_res_y == max_res: 131 | min_x = math.ceil((target_grid_res - target_grid_res / float(grid_res_y) * grid_res_x) / 2) 132 | max_x = target_grid_res - min_x - 1 133 | min_y = 0 134 | max_y = target_grid_res - 1 135 | min_z = math.ceil((target_grid_res - target_grid_res / float(grid_res_y) * grid_res_z) / 2) 136 | max_z = target_grid_res - min_z - 1 137 | if grid_res_z == max_res: 138 | min_x = math.ceil((target_grid_res - target_grid_res / float(grid_res_z) * grid_res_x) / 2) 139 | max_x = target_grid_res - min_x - 1 140 | min_y = math.ceil((target_grid_res - target_grid_res / float(grid_res_z) * grid_res_y) / 2) 141 | max_y = target_grid_res - min_y - 1 142 | min_z = 0 143 | max_z = target_grid_res - 1 144 | min_x = int(min_x) 145 | max_x = int(max_x) 146 | min_y = int(min_y) 147 | max_y = int(max_y) 148 | min_z = int(min_z) 149 | max_z = int(max_z) 150 | 151 | # fill the values 152 | res = target_grid.shape[0] 153 | for i in range(res): 154 | for j in range(res): 155 | for k in range(res): 156 | 157 | # fill the values outside both x-axis and y-axis 158 | if k < min_x and j < min_y: 159 | target_grid[k][j][i] = target_grid[min_x][min_y][i] + math.sqrt((min_x - k) ** 2 + (min_y - j) ** 2) * voxel_size 160 | elif k < min_x and j > max_y: 161 | target_grid[k][j][i] = target_grid[min_x][max_y][i] + math.sqrt((min_x - k) ** 2 + (max_y - j) ** 2) * voxel_size 162 | elif k > max_x and j < min_y: 163 | target_grid[k][j][i] = target_grid[max_x][min_y][i] + math.sqrt((max_x - k) ** 2 + (min_y - j) ** 2) * voxel_size 164 | elif k > max_x and j > max_y: 165 | target_grid[k][j][i] = target_grid[max_x][max_y][i] + math.sqrt((max_x - k) ** 2 + (max_y - j) ** 2) * voxel_size 166 | 167 | # fill the values outside both x-axis 
and z-axis 168 | elif k < min_x and i < min_z: 169 | target_grid[k][j][i] = target_grid[min_x][j][min_z] + math.sqrt((min_x - k) ** 2 + (min_z - i) ** 2) * voxel_size 170 | elif k < min_x and i > max_z: 171 | target_grid[k][j][i] = target_grid[min_x][j][max_z] + math.sqrt((min_x - k) ** 2 + (max_z - i) ** 2) * voxel_size 172 | elif k > max_x and i < min_z: 173 | target_grid[k][j][i] = target_grid[max_x][j][min_z] + math.sqrt((max_x - k) ** 2 + (min_z - i) ** 2) * voxel_size 174 | elif k > max_x and i > max_z: 175 | target_grid[k][j][i] = target_grid[max_x][j][max_z] + math.sqrt((max_x - k) ** 2 + (max_z - i) ** 2) * voxel_size 176 | 177 | # fill the values outside both y-axis and z-axis 178 | elif j < min_y and i < min_z: 179 | target_grid[k][j][i] = target_grid[k][min_y][min_z] + math.sqrt((min_y - j) ** 2 + (min_z - i) ** 2) * voxel_size 180 | elif j < min_y and i > max_z: 181 | target_grid[k][j][i] = target_grid[k][min_y][max_z] + math.sqrt((min_y - j) ** 2 + (max_z - i) ** 2) * voxel_size 182 | elif j > max_y and i < min_z: 183 | target_grid[k][j][i] = target_grid[k][max_y][min_z] + math.sqrt((max_y - j) ** 2 + (min_z - i) ** 2) * voxel_size 184 | elif j > max_y and i > max_z: 185 | target_grid[k][j][i] = target_grid[k][max_y][max_z] + math.sqrt((max_y - j) ** 2 + (max_z - i) ** 2) * voxel_size 186 | 187 | # fill the values outside x-axis 188 | elif k < min_x: 189 | target_grid[k][j][i] = target_grid[min_x][j][i] + math.sqrt((min_x - k) ** 2) * voxel_size 190 | elif k > max_x: 191 | target_grid[k][j][i] = target_grid[max_x][j][i] + math.sqrt((max_x - k) ** 2) * voxel_size 192 | 193 | # fill the values outside y-axis 194 | elif j < min_y: 195 | target_grid[k][j][i] = target_grid[k][min_y][i] + math.sqrt((min_y - j) ** 2) * voxel_size 196 | elif j > max_y: 197 | target_grid[k][j][i] = target_grid[k][max_y][i] + math.sqrt((max_y - j) ** 2) * voxel_size 198 | 199 | # fill the values outside z-axis 200 | elif i < min_z: 201 | target_grid[k][j][i] = target_grid[k][j][min_z] + math.sqrt((min_z - i) ** 2) * voxel_size 202 | elif i > max_z: 203 | target_grid[k][j][i] = target_grid[k][j][max_z] + math.sqrt((max_z - i) ** 2) * voxel_size 204 | 205 | return target_grid 206 | 207 | 208 | def grid_construction_cube(grid_res, bounding_box_min, bounding_box_max): 209 | 210 | # Construct the sdf grid for a cube with size 2 211 | voxel_size = (bounding_box_max - bounding_box_min) / (grid_res - 1) 212 | cube_left_bound_index = float(grid_res - 1) / 4; 213 | cube_right_bound_index = float(grid_res - 1) / 4 * 3; 214 | cube_center = float(grid_res - 1) / 2; 215 | 216 | grid = Tensor(grid_res, grid_res, grid_res) 217 | for i in range(grid_res): 218 | for j in range(grid_res): 219 | for k in range(grid_res): 220 | if (i >= cube_left_bound_index and i <= cube_right_bound_index and 221 | j >= cube_left_bound_index and j <= cube_right_bound_index and 222 | k >= cube_left_bound_index and k <= cube_right_bound_index): 223 | grid[i,j,k] = voxel_size * max(abs(i - cube_center), abs(j - cube_center), abs(k - cube_center)) - 1; 224 | else: 225 | grid[i,j,k] = math.sqrt(pow(voxel_size * (max(i - cube_right_bound_index, cube_left_bound_index - i, 0)), 2) + 226 | pow(voxel_size * (max(j - cube_right_bound_index, cube_left_bound_index - j, 0)), 2) + 227 | pow(voxel_size * (max(k - cube_right_bound_index, cube_left_bound_index - k, 0)), 2)); 228 | return grid 229 | 230 | def grid_construction_torus(grid_res, bounding_box_min, bounding_box_max): 231 | 232 | # radius of the circle between the two circles 233 | radius_big = 
1.5 234 | 235 | # radius of the small circle 236 | radius_small = 0.5 237 | 238 | voxel_size = (bounding_box_max - bounding_box_min) / (grid_res - 1) 239 | grid = Tensor(grid_res, grid_res, grid_res) 240 | for i in range(grid_res): 241 | for j in range(grid_res): 242 | for k in range(grid_res): 243 | x = bounding_box_min + voxel_size * i 244 | y = bounding_box_min + voxel_size * j 245 | z = bounding_box_min + voxel_size * k 246 | 247 | grid[i,j,k] = math.sqrt(math.pow((math.sqrt(math.pow(y, 2) + math.pow(z, 2)) - radius_big), 2) 248 | + math.pow(x, 2)) - radius_small; 249 | 250 | return grid 251 | 252 | 253 | 254 | def grid_construction_sphere_big(grid_res, bounding_box_min, bounding_box_max): 255 | 256 | # Construct the sdf grid for a sphere with radius 1 257 | linear_space = torch.linspace(bounding_box_min, bounding_box_max, grid_res) 258 | x_dim = linear_space.view(-1, 1).repeat(grid_res, 1, grid_res) 259 | y_dim = linear_space.view(1, -1).repeat(grid_res, grid_res, 1) 260 | z_dim = linear_space.view(-1, 1, 1).repeat(1, grid_res, grid_res) 261 | grid = torch.sqrt(x_dim * x_dim + y_dim * y_dim + z_dim * z_dim) - 1.6 262 | if cuda: 263 | return grid.cuda() 264 | else: 265 | return grid 266 | 267 | def grid_construction_sphere_small(grid_res, bounding_box_min, bounding_box_max): 268 | 269 | # Construct the sdf grid for a sphere with radius 1 270 | linear_space = torch.linspace(bounding_box_min, bounding_box_max, grid_res) 271 | x_dim = linear_space.view(-1, 1).repeat(grid_res, 1, grid_res) 272 | y_dim = linear_space.view(1, -1).repeat(grid_res, grid_res, 1) 273 | z_dim = linear_space.view(-1, 1, 1).repeat(1, grid_res, grid_res) 274 | grid = torch.sqrt(x_dim * x_dim + y_dim * y_dim + z_dim * z_dim) - 1 275 | if cuda: 276 | return grid.cuda() 277 | else: 278 | return grid 279 | 280 | 281 | def get_grid_normal(grid, voxel_size, grid_res_x, grid_res_y, grid_res_z): 282 | 283 | # largest index 284 | n_x = grid_res_x - 1 285 | n_y = grid_res_y - 1 286 | n_z = grid_res_z - 1 287 | 288 | # x-axis normal vectors 289 | X_1 = torch.cat((grid[1:,:,:], (3 * grid[n_x,:,:] - 3 * grid[n_x-1,:,:] + grid[n_x-2,:,:]).unsqueeze_(0)), 0) 290 | X_2 = torch.cat(((-3 * grid[1,:,:] + 3 * grid[0,:,:] + grid[2,:,:]).unsqueeze_(0), grid[:n_x,:,:]), 0) 291 | grid_normal_x = (X_1 - X_2) / (2 * voxel_size) 292 | 293 | # y-axis normal vectors 294 | Y_1 = torch.cat((grid[:,1:,:], (3 * grid[:,n_y,:] - 3 * grid[:,n_y-1,:] + grid[:,n_y-2,:]).unsqueeze_(1)), 1) 295 | Y_2 = torch.cat(((-3 * grid[:,1,:] + 3 * grid[:,0,:] + grid[:,2,:]).unsqueeze_(1), grid[:,:n_y,:]), 1) 296 | grid_normal_y = (Y_1 - Y_2) / (2 * voxel_size) 297 | 298 | # z-axis normal vectors 299 | Z_1 = torch.cat((grid[:,:,1:], (3 * grid[:,:,n_z] - 3 * grid[:,:,n_z-1] + grid[:,:,n_z-2]).unsqueeze_(2)), 2) 300 | Z_2 = torch.cat(((-3 * grid[:,:,1] + 3 * grid[:,:,0] + grid[:,:,2]).unsqueeze_(2), grid[:,:,:n_z]), 2) 301 | grid_normal_z = (Z_1 - Z_2) / (2 * voxel_size) 302 | 303 | 304 | return [grid_normal_x, grid_normal_y, grid_normal_z] 305 | 306 | 307 | def get_intersection_normal(intersection_grid_normal, intersection_pos, voxel_min_point, voxel_size): 308 | 309 | # Compute parameters 310 | tx = (intersection_pos[:,:,0] - voxel_min_point[:,:,0]) / voxel_size 311 | ty = (intersection_pos[:,:,1] - voxel_min_point[:,:,1]) / voxel_size 312 | tz = (intersection_pos[:,:,2] - voxel_min_point[:,:,2]) / voxel_size 313 | 314 | intersection_normal = (1 - tz) * (1 - ty) * (1 - tx) * intersection_grid_normal[:,:,0] \ 315 | + tz * (1 - ty) * (1 - tx) * 
intersection_grid_normal[:,:,1] \ 316 | + (1 - tz) * ty * (1 - tx) * intersection_grid_normal[:,:,2] \ 317 | + tz * ty * (1 - tx) * intersection_grid_normal[:,:,3] \ 318 | + (1 - tz) * (1 - ty) * tx * intersection_grid_normal[:,:,4] \ 319 | + tz * (1 - ty) * tx * intersection_grid_normal[:,:,5] \ 320 | + (1 - tz) * ty * tx * intersection_grid_normal[:,:,6] \ 321 | + tz * ty * tx * intersection_grid_normal[:,:,7] 322 | 323 | return intersection_normal 324 | 325 | 326 | # Do one more step for ray matching 327 | def calculate_sdf_value(grid, points, voxel_min_point, voxel_min_point_index, voxel_size, grid_res_x, grid_res_y, grid_res_z): 328 | 329 | string = "" 330 | 331 | # Linear interpolate along x axis the eight values 332 | tx = (points[:,0] - voxel_min_point[:,0]) / voxel_size; 333 | string = string + "\n\nvoxel_size: \n" + str(voxel_size) 334 | string = string + "\n\ntx: \n" + str(tx) 335 | print(grid.shape) 336 | 337 | if cuda: 338 | tx = tx.cuda() 339 | x = voxel_min_point_index.long()[:,0] 340 | y = voxel_min_point_index.long()[:,1] 341 | z = voxel_min_point_index.long()[:,2] 342 | 343 | string = string + "\n\nx: \n" + str(x) 344 | string = string + "\n\ny: \n" + str(y) 345 | string = string + "\n\nz: \n" + str(z) 346 | 347 | c01 = (1 - tx) * grid[x,y,z] + tx * grid[x+1,y,z]; 348 | c23 = (1 - tx) * grid[x,y+1,z] + tx * grid[x+1,y+1,z]; 349 | c45 = (1 - tx) * grid[x,y,z+1] + tx * grid[x+1,y,z+1]; 350 | c67 = (1 - tx) * grid[x,y+1,z+1] + tx * grid[x+1,y+1,z+1]; 351 | 352 | string = string + "\n\n(1 - tx): \n" + str((1 - tx)) 353 | string = string + "\n\ngrid[x,y,z]: \n" + str(grid[x,y,z]) 354 | string = string + "\n\ngrid[x+1,y,z]: \n" + str(grid[x+1,y,z]) 355 | string = string + "\n\nc01: \n" + str(c01) 356 | string = string + "\n\nc23: \n" + str(c23) 357 | string = string + "\n\nc45: \n" + str(c45) 358 | string = string + "\n\nc67: \n" + str(c67) 359 | 360 | # Linear interpolate along the y axis 361 | ty = (points[:,1] - voxel_min_point[:,1]) / voxel_size; 362 | ty = ty.cuda() 363 | c0 = (1 - ty) * c01 + ty * c23; 364 | c1 = (1 - ty) * c45 + ty * c67; 365 | 366 | string = string + "\n\nty: \n" + str(ty) 367 | 368 | string = string + "\n\nc0: \n" + str(c0) 369 | string = string + "\n\nc1: \n" + str(c1) 370 | 371 | # Return final value interpolated along z 372 | tz = (points[:,2] - voxel_min_point[:,2]) / voxel_size; 373 | tz = tz.cuda() 374 | string = string + "\n\ntz: \n" + str(tz) 375 | 376 | else: 377 | x = voxel_min_point_index.numpy()[:,0] 378 | y = voxel_min_point_index.numpy()[:,1] 379 | z = voxel_min_point_index.numpy()[:,2] 380 | 381 | c01 = (1 - tx) * grid[x,y,z] + tx * grid[x+1,y,z]; 382 | c23 = (1 - tx) * grid[x,y+1,z] + tx * grid[x+1,y+1,z]; 383 | c45 = (1 - tx) * grid[x,y,z+1] + tx * grid[x+1,y,z+1]; 384 | c67 = (1 - tx) * grid[x,y+1,z+1] + tx * grid[x+1,y+1,z+1]; 385 | 386 | # Linear interpolate along the y axis 387 | ty = (points[:,1] - voxel_min_point[:,1]) / voxel_size; 388 | c0 = (1 - ty) * c01 + ty * c23; 389 | c1 = (1 - ty) * c45 + ty * c67; 390 | 391 | # Return final value interpolated along z 392 | tz = (points[:,2] - voxel_min_point[:,2]) / voxel_size; 393 | 394 | result = (1 - tz) * c0 + tz * c1; 395 | 396 | return result 397 | 398 | 399 | def compute_intersection_pos(grid, intersection_pos_rough, voxel_min_point, voxel_min_point_index, ray_direction, voxel_size, mask): 400 | 401 | # Linear interpolate along x axis the eight values 402 | tx = (intersection_pos_rough[:,:,0] - voxel_min_point[:,:,0]) / voxel_size; 403 | 404 | if cuda: 405 | 406 | x = 
voxel_min_point_index.long()[:,:,0] 407 | y = voxel_min_point_index.long()[:,:,1] 408 | z = voxel_min_point_index.long()[:,:,2] 409 | 410 | c01 = (1 - tx) * grid[x,y,z].cuda() + tx * grid[x+1,y,z].cuda(); 411 | c23 = (1 - tx) * grid[x,y+1,z].cuda() + tx * grid[x+1,y+1,z].cuda(); 412 | c45 = (1 - tx) * grid[x,y,z+1].cuda() + tx * grid[x+1,y,z+1].cuda(); 413 | c67 = (1 - tx) * grid[x,y+1,z+1].cuda() + tx * grid[x+1,y+1,z+1].cuda(); 414 | 415 | else: 416 | x = voxel_min_point_index.numpy()[:,:,0] 417 | y = voxel_min_point_index.numpy()[:,:,1] 418 | z = voxel_min_point_index.numpy()[:,:,2] 419 | 420 | c01 = (1 - tx) * grid[x,y,z] + tx * grid[x+1,y,z]; 421 | c23 = (1 - tx) * grid[x,y+1,z] + tx * grid[x+1,y+1,z]; 422 | c45 = (1 - tx) * grid[x,y,z+1] + tx * grid[x+1,y,z+1]; 423 | c67 = (1 - tx) * grid[x,y+1,z+1] + tx * grid[x+1,y+1,z+1]; 424 | 425 | # Linear interpolate along the y axis 426 | ty = (intersection_pos_rough[:,:,1] - voxel_min_point[:,:,1]) / voxel_size; 427 | c0 = (1 - ty) * c01 + ty * c23; 428 | c1 = (1 - ty) * c45 + ty * c67; 429 | 430 | # Return final value interpolated along z 431 | tz = (intersection_pos_rough[:,:,2] - voxel_min_point[:,:,2]) / voxel_size; 432 | 433 | sdf_value = (1 - tz) * c0 + tz * c1; 434 | 435 | return (intersection_pos_rough + ray_direction * sdf_value.view(width,height,1).repeat(1,1,3))\ 436 | + (1 - mask.view(width,height,1).repeat(1,1,3)) 437 | 438 | def generate_image(bounding_box_min_x, bounding_box_min_y, bounding_box_min_z, \ 439 | bounding_box_max_x, bounding_box_max_y, bounding_box_max_z, \ 440 | voxel_size, grid_res_x, grid_res_y, grid_res_z, width, height, grid, camera, back, camera_list): 441 | 442 | # Get normal vectors for points on the grid 443 | [grid_normal_x, grid_normal_y, grid_normal_z] = get_grid_normal(grid, voxel_size, grid_res_x, grid_res_y, grid_res_z) 444 | 445 | # Generate rays 446 | e = camera 447 | 448 | w_h_3 = torch.zeros(width, height, 3).cuda() 449 | w_h = torch.zeros(width, height).cuda() 450 | eye_x = e[0] 451 | eye_y = e[1] 452 | eye_z = e[2] 453 | 454 | # Do ray tracing in cpp 455 | outputs = renderer.ray_matching(w_h_3, w_h, grid, width, height, bounding_box_min_x, bounding_box_min_y, bounding_box_min_z, \ 456 | bounding_box_max_x, bounding_box_max_y, bounding_box_max_z, \ 457 | grid_res_x, grid_res_y, grid_res_z, \ 458 | eye_x, \ 459 | eye_y, \ 460 | eye_z 461 | ) 462 | 463 | # {intersection_pos, voxel_position, directions} 464 | intersection_pos_rough = outputs[0] 465 | voxel_min_point_index = outputs[1] 466 | ray_direction = outputs[2] 467 | 468 | # Initialize grid values and normals for intersection voxels 469 | intersection_grid_normal_x = Tensor(width, height, 8) 470 | intersection_grid_normal_y = Tensor(width, height, 8) 471 | intersection_grid_normal_z = Tensor(width, height, 8) 472 | intersection_grid = Tensor(width, height, 8) 473 | 474 | # Make the pixels with no intersections with rays be 0 475 | mask = (voxel_min_point_index[:,:,0] != -1).type(Tensor) 476 | 477 | # Get the indices of the minimum point of the intersecting voxels 478 | x = voxel_min_point_index[:,:,0].type(torch.cuda.LongTensor) 479 | y = voxel_min_point_index[:,:,1].type(torch.cuda.LongTensor) 480 | z = voxel_min_point_index[:,:,2].type(torch.cuda.LongTensor) 481 | x[x == -1] = 0 482 | y[y == -1] = 0 483 | z[z == -1] = 0 484 | 485 | # Get the x-axis of normal vectors for the 8 points of the intersecting voxel 486 | # This line is equivalent to grid_normal_x[x,y,z] 487 | x1 = torch.index_select(grid_normal_x.view(-1), 0, z.view(-1) + 
grid_res_x * y.view(-1) + grid_res_x * grid_res_x * x.view(-1)).view(x.shape).unsqueeze_(2) 488 | x2 = torch.index_select(grid_normal_x.view(-1), 0, (z+1).view(-1) + grid_res_x * y.view(-1) + grid_res_x * grid_res_x * x.view(-1)).view(x.shape).unsqueeze_(2) 489 | x3 = torch.index_select(grid_normal_x.view(-1), 0, z.view(-1) + grid_res_x * (y+1).view(-1) + grid_res_x * grid_res_x * x.view(-1)).view(x.shape).unsqueeze_(2) 490 | x4 = torch.index_select(grid_normal_x.view(-1), 0, (z+1).view(-1) + grid_res_x * (y+1).view(-1) + grid_res_x * grid_res_x * x.view(-1)).view(x.shape).unsqueeze_(2) 491 | x5 = torch.index_select(grid_normal_x.view(-1), 0, z.view(-1) + grid_res_x * y.view(-1) + grid_res_x * grid_res_x * (x+1).view(-1)).view(x.shape).unsqueeze_(2) 492 | x6 = torch.index_select(grid_normal_x.view(-1), 0, (z+1).view(-1) + grid_res_x * y.view(-1) + grid_res_x * grid_res_x * (x+1).view(-1)).view(x.shape).unsqueeze_(2) 493 | x7 = torch.index_select(grid_normal_x.view(-1), 0, z.view(-1) + grid_res_x * (y+1).view(-1) + grid_res_x * grid_res_x * (x+1).view(-1)).view(x.shape).unsqueeze_(2) 494 | x8 = torch.index_select(grid_normal_x.view(-1), 0, (z+1).view(-1) + grid_res_x * (y+1).view(-1) + grid_res_x * grid_res_x * (x+1).view(-1)).view(x.shape).unsqueeze_(2) 495 | intersection_grid_normal_x = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), 2) + (1 - mask.view(width, height, 1).repeat(1,1,8)) 496 | 497 | # Get the y-axis of normal vectors for the 8 points of the intersecting voxel 498 | y1 = torch.index_select(grid_normal_y.view(-1), 0, z.view(-1) + grid_res_x * y.view(-1) + grid_res_x * grid_res_x * x.view(-1)).view(x.shape).unsqueeze_(2) 499 | y2 = torch.index_select(grid_normal_y.view(-1), 0, (z+1).view(-1) + grid_res_x * y.view(-1) + grid_res_x * grid_res_x * x.view(-1)).view(x.shape).unsqueeze_(2) 500 | y3 = torch.index_select(grid_normal_y.view(-1), 0, z.view(-1) + grid_res_x * (y+1).view(-1) + grid_res_x * grid_res_x * x.view(-1)).view(x.shape).unsqueeze_(2) 501 | y4 = torch.index_select(grid_normal_y.view(-1), 0, (z+1).view(-1) + grid_res_x * (y+1).view(-1) + grid_res_x * grid_res_x * x.view(-1)).view(x.shape).unsqueeze_(2) 502 | y5 = torch.index_select(grid_normal_y.view(-1), 0, z.view(-1) + grid_res_x * y.view(-1) + grid_res_x * grid_res_x * (x+1).view(-1)).view(x.shape).unsqueeze_(2) 503 | y6 = torch.index_select(grid_normal_y.view(-1), 0, (z+1).view(-1) + grid_res_x * y.view(-1) + grid_res_x * grid_res_x * (x+1).view(-1)).view(x.shape).unsqueeze_(2) 504 | y7 = torch.index_select(grid_normal_y.view(-1), 0, z.view(-1) + grid_res_x * (y+1).view(-1) + grid_res_x * grid_res_x * (x+1).view(-1)).view(x.shape).unsqueeze_(2) 505 | y8 = torch.index_select(grid_normal_y.view(-1), 0, (z+1).view(-1) + grid_res_x * (y+1).view(-1) + grid_res_x * grid_res_x * (x+1).view(-1)).view(x.shape).unsqueeze_(2) 506 | intersection_grid_normal_y = torch.cat((y1, y2, y3, y4, y5, y6, y7, y8), 2) + (1 - mask.view(width, height, 1).repeat(1,1,8)) 507 | 508 | # Get the z-axis of normal vectors for the 8 points of the intersecting voxel 509 | z1 = torch.index_select(grid_normal_z.view(-1), 0, z.view(-1) + grid_res_x * y.view(-1) + grid_res_x * grid_res_x * x.view(-1)).view(x.shape).unsqueeze_(2) 510 | z2 = torch.index_select(grid_normal_z.view(-1), 0, (z+1).view(-1) + grid_res_x * y.view(-1) + grid_res_x * grid_res_x * x.view(-1)).view(x.shape).unsqueeze_(2) 511 | z3 = torch.index_select(grid_normal_z.view(-1), 0, z.view(-1) + grid_res_x * (y+1).view(-1) + grid_res_x * grid_res_x * 
x.view(-1)).view(x.shape).unsqueeze_(2) 512 | z4 = torch.index_select(grid_normal_z.view(-1), 0, (z+1).view(-1) + grid_res_x * (y+1).view(-1) + grid_res_x * grid_res_x * x.view(-1)).view(x.shape).unsqueeze_(2) 513 | z5 = torch.index_select(grid_normal_z.view(-1), 0, z.view(-1) + grid_res_x * y.view(-1) + grid_res_x * grid_res_x * (x+1).view(-1)).view(x.shape).unsqueeze_(2) 514 | z6 = torch.index_select(grid_normal_z.view(-1), 0, (z+1).view(-1) + grid_res_x * y.view(-1) + grid_res_x * grid_res_x * (x+1).view(-1)).view(x.shape).unsqueeze_(2) 515 | z7 = torch.index_select(grid_normal_z.view(-1), 0, z.view(-1) + grid_res_x * (y+1).view(-1) + grid_res_x * grid_res_x * (x+1).view(-1)).view(x.shape).unsqueeze_(2) 516 | z8 = torch.index_select(grid_normal_z.view(-1), 0, (z+1).view(-1) + grid_res_x * (y+1).view(-1) + grid_res_x * grid_res_x * (x+1).view(-1)).view(x.shape).unsqueeze_(2) 517 | intersection_grid_normal_z = torch.cat((z1, z2, z3, z4, z5, z6, z7, z8), 2) + (1 - mask.view(width, height, 1).repeat(1,1,8)) 518 | 519 | # Change from grid coordinates to world coordinates 520 | voxel_min_point = Tensor([bounding_box_min_x, bounding_box_min_y, bounding_box_min_z]) + voxel_min_point_index * voxel_size 521 | 522 | intersection_pos = compute_intersection_pos(grid, intersection_pos_rough,\ 523 | voxel_min_point, voxel_min_point_index,\ 524 | ray_direction, voxel_size, mask) 525 | 526 | intersection_pos = intersection_pos * mask.repeat(3,1,1).permute(1,2,0) 527 | shading = Tensor(width, height).fill_(0) 528 | 529 | # Compute the normal vectors for the intersecting points 530 | intersection_normal_x = get_intersection_normal(intersection_grid_normal_x, intersection_pos, voxel_min_point, voxel_size) 531 | intersection_normal_y = get_intersection_normal(intersection_grid_normal_y, intersection_pos, voxel_min_point, voxel_size) 532 | intersection_normal_z = get_intersection_normal(intersection_grid_normal_z, intersection_pos, voxel_min_point, voxel_size) 533 | 534 | # Put all the xyz-axis of the normal vectors into a single matrix 535 | intersection_normal_x_resize = intersection_normal_x.unsqueeze_(2) 536 | intersection_normal_y_resize = intersection_normal_y.unsqueeze_(2) 537 | intersection_normal_z_resize = intersection_normal_z.unsqueeze_(2) 538 | intersection_normal = torch.cat((intersection_normal_x_resize, intersection_normal_y_resize, intersection_normal_z_resize), 2) 539 | intersection_normal = intersection_normal / torch.unsqueeze(torch.norm(intersection_normal, p=2, dim=2), 2).repeat(1, 1, 3) 540 | 541 | # Create the point light 542 | light_position = camera.repeat(width, height, 1) 543 | light_norm = torch.unsqueeze(torch.norm(light_position - intersection_pos, p=2, dim=2), 2).repeat(1, 1, 3) 544 | light_direction_point = (light_position - intersection_pos) / light_norm 545 | 546 | # Create the directional light 547 | shading = 0 548 | light_direction = (camera / torch.norm(camera, p=2)).repeat(width, height, 1) 549 | l_dot_n = torch.sum(light_direction * intersection_normal, 2).unsqueeze_(2) 550 | shading += 10 * torch.max(l_dot_n, Tensor(width, height, 1).fill_(0))[:,:,0] / torch.pow(torch.sum((light_position - intersection_pos) * light_direction_point, dim=2), 2) 551 | 552 | # Get the final image 553 | image = shading * mask 554 | image[mask == 0] = 0 555 | 556 | return image 557 | 558 | # The energy E captures the difference between a rendered image and 559 | # a desired target image, and the rendered image is a function of the 560 | # SDF values. 
You could write E(SDF) = ||rendering(SDF)-target_image||^2. 561 | # In addition, there is a second term in the energy as you observed that 562 | # constrains the length of the normal of the SDF to 1. This is a regularization 563 | # term to make sure the output is still a valid SDF. 564 | def loss_fn(output, target, grid, voxel_size, grid_res_x, grid_res_y, grid_res_z, width, height): 565 | 566 | image_loss = torch.sum(torch.abs(target - output)) #/ (width * height) 567 | 568 | [grid_normal_x, grid_normal_y, grid_normal_z] = get_grid_normal(grid, voxel_size, grid_res_x, grid_res_y, grid_res_z) 569 | sdf_loss = torch.sum(torch.abs(torch.pow(grid_normal_x[1:grid_res_x-1, 1:grid_res_y-1, 1:grid_res_z-1], 2)\ 570 | + torch.pow(grid_normal_y[1:grid_res_x-1, 1:grid_res_y-1, 1:grid_res_z-1], 2)\ 571 | + torch.pow(grid_normal_z[1:grid_res_x-1, 1:grid_res_y-1, 1:grid_res_z-1], 2) - 1)) #/ ((grid_res-1) * (grid_res-1) * (grid_res-1)) 572 | 573 | 574 | print("\n\nimage loss: ", image_loss) 575 | print("sdf loss: ", sdf_loss) 576 | 577 | return image_loss, sdf_loss 578 | 579 | def sdf_diff(sdf1, sdf2): 580 | return torch.sum(torch.abs(sdf1 - sdf2)).item() 581 | 582 | 583 | if __name__ == "__main__": 584 | 585 | # define the folder name for results 586 | dir_name = "results/" 587 | os.mkdir("./" + dir_name) 588 | 589 | # Speed up 590 | torch.backends.cudnn.benchmark = True 591 | 592 | cuda = True if torch.cuda.is_available() else False 593 | print(cuda) 594 | 595 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 596 | 597 | width = 256 598 | height = 256 599 | 600 | camera_list = [Tensor([0,0,5]), # 0 601 | Tensor([0.1,5,0]), 602 | Tensor([5,0,0]), 603 | Tensor([0,0,-5]), 604 | Tensor([0.1,-5,0]), 605 | Tensor([-5,0,0]), # 5 606 | 607 | Tensor([5/math.sqrt(2),0,5/math.sqrt(2)]), 608 | Tensor([5/math.sqrt(2),5/math.sqrt(2),0]), 609 | Tensor([0,5/math.sqrt(2),5/math.sqrt(2)]), 610 | 611 | Tensor([-5/math.sqrt(2),0,-5/math.sqrt(2)]), 612 | Tensor([-5/math.sqrt(2),-5/math.sqrt(2),0]), #10 613 | Tensor([0,-5/math.sqrt(2),-5/math.sqrt(2)]), 614 | 615 | Tensor([-5/math.sqrt(2),0,5/math.sqrt(2)]), 616 | Tensor([-5/math.sqrt(2),5/math.sqrt(2),0]), 617 | Tensor([0,-5/math.sqrt(2),5/math.sqrt(2)]), 618 | 619 | Tensor([5/math.sqrt(2),0,-5/math.sqrt(2)]), 620 | Tensor([5/math.sqrt(2),-5/math.sqrt(2),0]), 621 | Tensor([0,5/math.sqrt(2),-5/math.sqrt(2)]), 622 | 623 | Tensor([5/math.sqrt(3),5/math.sqrt(3),5/math.sqrt(3)]), 624 | Tensor([5/math.sqrt(3),5/math.sqrt(3),-5/math.sqrt(3)]), 625 | Tensor([5/math.sqrt(3),-5/math.sqrt(3),5/math.sqrt(3)]), 626 | Tensor([-5/math.sqrt(3),5/math.sqrt(3),5/math.sqrt(3)]), 627 | Tensor([-5/math.sqrt(3),-5/math.sqrt(3),5/math.sqrt(3)]), 628 | Tensor([-5/math.sqrt(3),5/math.sqrt(3),-5/math.sqrt(3)]), 629 | Tensor([5/math.sqrt(3),-5/math.sqrt(3),-5/math.sqrt(3)]), 630 | Tensor([-5/math.sqrt(3),-5/math.sqrt(3),-5/math.sqrt(3)])] 631 | 632 | # bounding box 633 | bounding_box_min_x = -2. 634 | bounding_box_min_y = -2. 635 | bounding_box_min_z = -2. 636 | bounding_box_max_x = 2. 637 | bounding_box_max_y = 2. 638 | bounding_box_max_z = 2. 
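# Overview of the objective optimized in the loop below: for each camera view,
#
#     loss = image_loss[cam] + sdf_loss[cam] + Lp_loss
#
# where
#   * image_loss -- L1 difference between the rendered image and the target view (see loss_fn above),
#   * sdf_loss   -- regularizer keeping the gradient norm of the SDF close to 1, so the grid
#                   remains a valid signed distance field,
#   * Lp_loss    -- Laplacian smoothness term, computed with a 3x3x3 finite-difference stencil
#                   through F.conv3d.
#
# As suggested in the README ("Parameter Tuning"), the three terms can be reweighted,
#
#     loss = a * image_loss[cam] + b * sdf_loss[cam] + c * Lp_loss
#
# with a larger c producing smoother reconstructed surfaces.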
639 | 640 | 641 | # size of the image 642 | width = 64 643 | height = 64 644 | 645 | loss = 500 646 | 647 | image_loss_list = [] 648 | sdf_loss_list = [] 649 | e = camera_list[0] 650 | 651 | # Find proper grid resolution 652 | pixel_distance = torch.tan(Tensor([math.pi/6])) * 2 / height 653 | 654 | # Compute largest distance between the grid and the camera 655 | largest_distance_camera_grid = torch.sqrt(torch.pow(max(torch.abs(e[0] - bounding_box_max_x), torch.abs(e[0] - bounding_box_min_x)), 2) 656 | + torch.pow(max(torch.abs(e[1] - bounding_box_max_y), torch.abs(e[1] - bounding_box_min_y)), 2) 657 | + torch.pow(max(torch.abs(e[2] - bounding_box_max_z), torch.abs(e[2] - bounding_box_min_z)), 2)) 658 | grid_res_x = 8 659 | grid_res_y = 8 660 | grid_res_z = 8 661 | 662 | # define the resolutions of the multi-resolution part 663 | voxel_res_list = [8,16,24,32,40,48,56,64] 664 | grid_res_x = grid_res_y = grid_res_z = voxel_res_list.pop(0) 665 | voxel_size = Tensor([4. / (grid_res_x-1)]) 666 | 667 | # Construct the sdf grid 668 | grid_initial = grid_construction_sphere_big(grid_res_x, bounding_box_min_x, bounding_box_max_x) #### 669 | 670 | # set parameters 671 | sdf_diff_list = [] 672 | time_list = [] 673 | image_loss = [1000] * len(camera_list) 674 | sdf_loss = [1000] * len(camera_list) 675 | iterations = 0 676 | scale = 1 677 | start_time = time.time() 678 | learning_rate = 0.01 679 | tolerance = 8 / 10 680 | 681 | # image size 682 | width = 256 683 | height = 256 684 | 685 | start_time = time.time() 686 | while (grid_res_x <= 64): 687 | tolerance *= 1.05 688 | image_target = [] 689 | 690 | # load sdf file 691 | grid_target = read_sdf("../bunny.sdf", grid_res_x, bounding_box_min_x, bounding_box_max_x, 4. / (grid_res_x-1)) 692 | grid_initial.requires_grad = True 693 | optimizer = torch.optim.Adam([grid_initial], lr = learning_rate, eps=1e-2) 694 | 695 | # output images 696 | for cam in range(len(camera_list)): 697 | image_initial = generate_image(bounding_box_min_x, bounding_box_min_y, bounding_box_min_z, \ 698 | bounding_box_max_x, bounding_box_max_y, bounding_box_max_z, \ 699 | voxel_size, grid_res_x, grid_res_y, grid_res_z, width, height, grid_initial, camera_list[cam], 1, camera_list) 700 | torchvision.utils.save_image(image_initial, "./" + dir_name + "grid_res_" + str(grid_res_x) + "_start_" + str(cam) + ".png", nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0) 701 | image = generate_image(bounding_box_min_x, bounding_box_min_y, bounding_box_min_z, \ 702 | bounding_box_max_x, bounding_box_max_y, bounding_box_max_z, \ 703 | 4. 
/ (grid_res_x-1), grid_res_x, grid_res_y, grid_res_z, width, height, grid_target, camera_list[cam]+ torch.randn_like(camera_list[0]) * 0.015, 1, camera_list) 704 | image_target.append(image) 705 | torchvision.utils.save_image(image, "./" + dir_name + "grid_res_" + str(grid_res_x) + "_target_" + str(cam) + ".png", nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0) 706 | 707 | # deform initial SDf to target SDF 708 | i = 0 709 | loss_camera = [1000] * len(camera_list) 710 | average = 100000 711 | while sum(loss_camera) < average - tolerance / 2: 712 | average = sum(loss_camera) 713 | for cam in range(len(camera_list)): 714 | loss = 100000 715 | prev_loss = loss + 1 716 | num = 0 717 | while((num < 5) and loss < prev_loss): 718 | num += 1; 719 | prev_loss = loss 720 | iterations += 1 721 | 722 | optimizer.zero_grad() 723 | 724 | # Generate images 725 | image_initial = generate_image(bounding_box_min_x, bounding_box_min_y, bounding_box_min_z, \ 726 | bounding_box_max_x, bounding_box_max_y, bounding_box_max_z, \ 727 | voxel_size, grid_res_x, grid_res_y, grid_res_z, width, height, grid_initial, camera_list[cam], 1, camera_list) 728 | 729 | # Perform backprobagation 730 | # compute image loss and sdf loss 731 | image_loss[cam], sdf_loss[cam] = loss_fn(image_initial, image_target[cam], grid_initial, voxel_size, grid_res_x, grid_res_y, grid_res_z, width, height) 732 | 733 | # compute laplacian loss 734 | conv_input = (grid_initial).unsqueeze(0).unsqueeze(0) 735 | conv_filter = torch.cuda.FloatTensor([[[[[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 1, 0], [1, -6, 1], [0, 1, 0]], [[0, 0, 0], [0, 1, 0], [0, 0, 0]]]]]) 736 | Lp_loss = torch.sum(F.conv3d(conv_input, conv_filter) ** 2) 737 | 738 | # get total loss 739 | loss = image_loss[cam] + sdf_loss[cam] + Lp_loss 740 | image_loss[cam] = image_loss[cam] / len(camera_list) 741 | sdf_loss[cam] = sdf_loss[cam] / len(camera_list) 742 | loss_camera[cam] = image_loss[cam] + sdf_loss[cam] 743 | 744 | # print out loss messages 745 | print("grid res:", grid_res_x, "iteration:", i, "num:", num, "loss:", loss, "\ncamera:", camera_list[cam]) 746 | loss.backward() 747 | optimizer.step() 748 | 749 | i += 1 750 | 751 | # genetate result images 752 | for cam in range(len(camera_list)): 753 | image_initial = generate_image(bounding_box_min_x, bounding_box_min_y, bounding_box_min_z, \ 754 | bounding_box_max_x, bounding_box_max_y, bounding_box_max_z, \ 755 | voxel_size, grid_res_x, grid_res_y, grid_res_z, width, height, grid_initial, camera_list[cam],0, camera_list) 756 | 757 | torchvision.utils.save_image(image_initial, "./" + dir_name + "final_cam_" + str(grid_res_x) + "_" + str(cam) + ".png", nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0) 758 | 759 | # Save the final SDF result 760 | with open("./" + dir_name + str(grid_res_x) + "_best_sdf_bunny.pt", 'wb') as f: 761 | torch.save(grid_initial, f) 762 | 763 | # moves on to the next resolution stage 764 | if grid_res_x < 64: 765 | grid_res_update_x = grid_res_update_y = grid_res_update_z = voxel_res_list.pop(0) 766 | voxel_size_update = (bounding_box_max_x - bounding_box_min_x) / (grid_res_update_x - 1) 767 | grid_initial_update = Tensor(grid_res_update_x, grid_res_update_y, grid_res_update_z) 768 | linear_space_x = torch.linspace(0, grid_res_update_x-1, grid_res_update_x) 769 | linear_space_y = torch.linspace(0, grid_res_update_y-1, grid_res_update_y) 770 | linear_space_z = torch.linspace(0, grid_res_update_z-1, grid_res_update_z) 771 | first_loop = 
linear_space_x.repeat(grid_res_update_y * grid_res_update_z, 1).t().contiguous().view(-1).unsqueeze_(1) 772 | second_loop = linear_space_y.repeat(grid_res_update_z, grid_res_update_x).t().contiguous().view(-1).unsqueeze_(1) 773 | third_loop = linear_space_z.repeat(grid_res_update_x * grid_res_update_y).unsqueeze_(1) 774 | loop = torch.cat((first_loop, second_loop, third_loop), 1).cuda() 775 | min_x = Tensor([bounding_box_min_x]).repeat(grid_res_update_x*grid_res_update_y*grid_res_update_z, 1) 776 | min_y = Tensor([bounding_box_min_y]).repeat(grid_res_update_x*grid_res_update_y*grid_res_update_z, 1) 777 | min_z = Tensor([bounding_box_min_z]).repeat(grid_res_update_x*grid_res_update_y*grid_res_update_z, 1) 778 | bounding_min_matrix = torch.cat((min_x, min_y, min_z), 1) 779 | 780 | # Get the position of the grid points in the refined grid 781 | points = bounding_min_matrix + voxel_size_update * loop 782 | voxel_min_point_index_x = torch.floor((points[:,0].unsqueeze_(1) - min_x) / voxel_size).clamp(max=grid_res_x-2) 783 | voxel_min_point_index_y = torch.floor((points[:,1].unsqueeze_(1) - min_y) / voxel_size).clamp(max=grid_res_y-2) 784 | voxel_min_point_index_z = torch.floor((points[:,2].unsqueeze_(1) - min_z) / voxel_size).clamp(max=grid_res_z-2) 785 | voxel_min_point_index = torch.cat((voxel_min_point_index_x, voxel_min_point_index_y, voxel_min_point_index_z), 1) 786 | voxel_min_point = bounding_min_matrix + voxel_min_point_index * voxel_size 787 | 788 | # Compute the sdf value of the grid points in the refined grid 789 | grid_initial_update = calculate_sdf_value(grid_initial, points, voxel_min_point, voxel_min_point_index, voxel_size, grid_res_x, grid_res_y, grid_res_z).view(grid_res_update_x, grid_res_update_y, grid_res_update_z) 790 | 791 | # Update the grid resolution for the refined sdf grid 792 | grid_res_x = grid_res_update_x 793 | grid_res_y = grid_res_update_y 794 | grid_res_z = grid_res_update_z 795 | 796 | # Update the voxel size for the refined sdf grid 797 | voxel_size = voxel_size_update 798 | 799 | # Update the sdf grid 800 | grid_initial = grid_initial_update.data 801 | 802 | # Double the size of the image 803 | if width < 256: 804 | width = int(width * 2) 805 | height = int(height * 2) 806 | learning_rate /= 1.03 807 | 808 | print("Time:", time.time() - start_time) 809 | 810 | print("----- END -----") 811 | 812 | -------------------------------------------------------------------------------- /multi_view_code/code/renderer.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | std::vector ray_matching_cuda( 5 | const at::Tensor w_h_3, 6 | const at::Tensor w_h, 7 | const at::Tensor grid, 8 | const int width, 9 | const int height, 10 | const float bounding_box_min_x, 11 | const float bounding_box_min_y, 12 | const float bounding_box_min_z, 13 | const float bounding_box_max_x, 14 | const float bounding_box_max_y, 15 | const float bounding_box_max_z, 16 | const int grid_res_x, 17 | const int grid_res_y, 18 | const int grid_res_z, 19 | const float eye_x, 20 | const float eye_y, 21 | const float eye_z); 22 | 23 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 24 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 25 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 26 | 27 | std::vector ray_matching( 28 | const at::Tensor w_h_3, 29 | const at::Tensor w_h, 30 | const at::Tensor grid, 31 | const int width, 32 | const int height, 33 | const 
float bounding_box_min_x, 34 | const float bounding_box_min_y, 35 | const float bounding_box_min_z, 36 | const float bounding_box_max_x, 37 | const float bounding_box_max_y, 38 | const float bounding_box_max_z, 39 | const int grid_res_x, 40 | const int grid_res_y, 41 | const int grid_res_z, 42 | const float eye_x, 43 | const float eye_y, 44 | const float eye_z) { 45 | CHECK_INPUT(w_h_3); 46 | CHECK_INPUT(w_h); 47 | CHECK_INPUT(grid); 48 | 49 | return ray_matching_cuda(w_h_3, w_h, grid, width, height, 50 | bounding_box_min_x, 51 | bounding_box_min_y, 52 | bounding_box_min_z, 53 | bounding_box_max_x, 54 | bounding_box_max_y, 55 | bounding_box_max_z, 56 | grid_res_x, 57 | grid_res_y, 58 | grid_res_z, 59 | eye_x, eye_y, eye_z); 60 | } 61 | 62 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 63 | m.def("ray_matching", &ray_matching, "Ray Matching"); 64 | } -------------------------------------------------------------------------------- /multi_view_code/code/renderer_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #define PI 3.14159265358979323846f 7 | namespace { 8 | 9 | __device__ __forceinline__ float DegToRad(const float °) { return (deg * (PI / 180.f)); } 10 | 11 | 12 | __device__ __forceinline__ float length( 13 | const float x, 14 | const float y, 15 | const float z) { 16 | return sqrtf(powf(x, 2) + powf(y, 2) + powf(z, 2)); 17 | } 18 | 19 | // Cross product 20 | __device__ __forceinline__ float cross_x( 21 | const float a_x, 22 | const float a_y, 23 | const float a_z, 24 | const float b_x, 25 | const float b_y, 26 | const float b_z) { 27 | return a_y * b_z - a_z * b_y; 28 | } 29 | 30 | 31 | __device__ __forceinline__ float cross_y( 32 | const float a_x, 33 | const float a_y, 34 | const float a_z, 35 | const float b_x, 36 | const float b_y, 37 | const float b_z) { 38 | return a_z * b_x - a_x * b_z; 39 | } 40 | 41 | 42 | __device__ __forceinline__ float cross_z( 43 | const float a_x, 44 | const float a_y, 45 | const float a_z, 46 | const float b_x, 47 | const float b_y, 48 | const float b_z) { 49 | return a_x * b_y - a_y * b_x; 50 | } 51 | 52 | __global__ void GenerateRay( 53 | float* origins, 54 | float* directions, 55 | float* origin_image_distances, 56 | float* pixel_distances, 57 | const int width, 58 | const int height, 59 | const float eye_x, 60 | const float eye_y, 61 | const float eye_z) { 62 | 63 | const float at_x = 0; 64 | const float at_y = 0; 65 | const float at_z = 0; 66 | const float up_x = 0; 67 | const float up_y = 1; 68 | const float up_z = 0; 69 | 70 | // Compute camera view volume 71 | const float top = tan(DegToRad(30)); 72 | const float bottom = -top; 73 | const float right = (__int2float_rd(width) / __int2float_rd(height)) * top; 74 | const float left = -right; 75 | 76 | // Compute local base 77 | const float w_x = (eye_x - at_x) / length(eye_x - at_x, eye_y - at_y, eye_z - at_z); 78 | const float w_y = (eye_y - at_y) / length(eye_x - at_x, eye_y - at_y, eye_z - at_z); 79 | const float w_z = (eye_z - at_z) / length(eye_x - at_x, eye_y - at_y, eye_z - at_z); 80 | const float cross_up_w_x = cross_x(up_x, up_y, up_z, w_x, w_y, w_z); 81 | const float cross_up_w_y = cross_y(up_x, up_y, up_z, w_x, w_y, w_z); 82 | const float cross_up_w_z = cross_z(up_x, up_y, up_z, w_x, w_y, w_z); 83 | const float u_x = (cross_up_w_x) / length(cross_up_w_x, cross_up_w_y, cross_up_w_z); 84 | const float u_y = (cross_up_w_y) / length(cross_up_w_x, cross_up_w_y, cross_up_w_z); 85 | const float 
u_z = (cross_up_w_z) / length(cross_up_w_x, cross_up_w_y, cross_up_w_z); 86 | const float v_x = cross_x(w_x, w_y, w_z, u_x, u_y, u_z); 87 | const float v_y = cross_y(w_x, w_y, w_z, u_x, u_y, u_z); 88 | const float v_z = cross_z(w_x, w_y, w_z, u_x, u_y, u_z); 89 | 90 | 91 | const int pixel_index = blockIdx.x * blockDim.x + threadIdx.x; 92 | 93 | if (pixel_index < width * height) { 94 | const int x = pixel_index % width; 95 | const int y = pixel_index / width; 96 | const int i = 3 * pixel_index; 97 | 98 | // Compute point on view plane 99 | // Ray passes through the center of the pixel 100 | const float view_plane_x = left + (right - left) * (__int2float_rd(x) + 0.5) / __int2float_rd(width); 101 | const float view_plane_y = top - (top - bottom) * (__int2float_rd(y) + 0.5) / __int2float_rd(height); 102 | const float s_x = view_plane_x * u_x + view_plane_y * v_x - w_x; 103 | const float s_y = view_plane_x * u_y + view_plane_y * v_y - w_y; 104 | const float s_z = view_plane_x * u_z + view_plane_y * v_z - w_z; 105 | origins[i] = eye_x; 106 | origins[i+1] = eye_y; 107 | origins[i+2] = eye_z; 108 | 109 | 110 | directions[i] = s_x / length(s_x, s_y, s_z); 111 | directions[i+1] = s_y / length(s_x, s_y, s_z); 112 | directions[i+2] = s_z / length(s_x, s_y, s_z); 113 | 114 | origin_image_distances[pixel_index] = length(s_x, s_y, s_z); 115 | pixel_distances[pixel_index] = (right - left) / __int2float_rd(width); 116 | 117 | } 118 | } 119 | 120 | // Check if a point is inside 121 | __device__ __forceinline__ bool InsideBoundingBox( 122 | const float p_x, 123 | const float p_y, 124 | const float p_z, 125 | const float bounding_box_min_x, 126 | const float bounding_box_min_y, 127 | const float bounding_box_min_z, 128 | const float bounding_box_max_x, 129 | const float bounding_box_max_y, 130 | const float bounding_box_max_z) { 131 | 132 | return (p_x >= bounding_box_min_x) && (p_x <= bounding_box_max_x) && 133 | (p_y >= bounding_box_min_y) && (p_y <= bounding_box_max_y) && 134 | (p_z >= bounding_box_min_z) && (p_z <= bounding_box_max_z); 135 | } 136 | 137 | // Compute the distance along the ray between the point and the bounding box 138 | __device__ float Distance( 139 | const float reached_point_x, 140 | const float reached_point_y, 141 | const float reached_point_z, 142 | float direction_x, 143 | float direction_y, 144 | float direction_z, 145 | const float bounding_box_min_x, 146 | const float bounding_box_min_y, 147 | const float bounding_box_min_z, 148 | const float bounding_box_max_x, 149 | const float bounding_box_max_y, 150 | const float bounding_box_max_z) { 151 | 152 | float dist = -1.f; 153 | direction_x = direction_x / length(direction_x, direction_y, direction_z); 154 | direction_y = direction_y / length(direction_x, direction_y, direction_z); 155 | direction_z = direction_z / length(direction_x, direction_y, direction_z); 156 | 157 | // For each axis count any excess distance outside box extents 158 | float v = reached_point_x; 159 | float d = direction_x; 160 | if (dist == -1) { 161 | if ((v < bounding_box_min_x) && (d > 0)) { dist = (bounding_box_min_x - v) / d; } 162 | if ((v > bounding_box_max_x) && (d < 0)) { dist = (bounding_box_max_x - v) / d; } 163 | } else { 164 | if ((v < bounding_box_min_x) && (d > 0)) { dist = fmaxf(dist, (bounding_box_min_x - v) / d); } 165 | if ((v > bounding_box_max_x) && (d < 0)) { dist = fmaxf(dist, (bounding_box_max_x - v) / d); } 166 | } 167 | 168 | v = reached_point_y; 169 | d = direction_y; 170 | if (dist == -1) { 171 | if ((v < bounding_box_min_y) && (d > 
0)) { dist = (bounding_box_min_y - v) / d; } 172 | if ((v > bounding_box_max_y) && (d < 0)) { dist = (bounding_box_max_y - v) / d; } 173 | } else { 174 | if ((v < bounding_box_min_y) && (d > 0)) { dist = fmaxf(dist, (bounding_box_min_y - v) / d); } 175 | if ((v > bounding_box_max_y) && (d < 0)) { dist = fmaxf(dist, (bounding_box_max_y - v) / d); } 176 | } 177 | 178 | v = reached_point_z; 179 | d = direction_z; 180 | if (dist == -1) { 181 | if ((v < bounding_box_min_z) && (d > 0)) { dist = (bounding_box_min_z - v) / d; } 182 | if ((v > bounding_box_max_z) && (d < 0)) { dist = (bounding_box_max_z - v) / d; } 183 | } else { 184 | if ((v < bounding_box_min_z) && (d > 0)) { dist = fmaxf(dist, (bounding_box_min_z - v) / d); } 185 | if ((v > bounding_box_max_z) && (d < 0)) { dist = fmaxf(dist, (bounding_box_max_z - v) / d); } 186 | } 187 | 188 | return dist; 189 | } 190 | 191 | __device__ __forceinline__ int flat(float const x, float const y, float const z, 192 | int const grid_res_x, int const grid_res_y, int const grid_res_z) { 193 | return __int2float_rd(z) + __int2float_rd(y) * grid_res_z + __int2float_rd(x) * grid_res_z * grid_res_y; 194 | } 195 | 196 | // Get the signed distance value at the specific point 197 | __device__ float ValueAt( 198 | const float* grid, 199 | const float reached_point_x, 200 | const float reached_point_y, 201 | const float reached_point_z, 202 | const float direction_x, 203 | const float direction_y, 204 | const float direction_z, 205 | const float bounding_box_min_x, 206 | const float bounding_box_min_y, 207 | const float bounding_box_min_z, 208 | const float bounding_box_max_x, 209 | const float bounding_box_max_y, 210 | const float bounding_box_max_z, 211 | const int grid_res_x, 212 | const int grid_res_y, 213 | const int grid_res_z, 214 | const bool first_time) { 215 | 216 | // Check if we are outside the BBOX 217 | if (!InsideBoundingBox(reached_point_x, reached_point_y, reached_point_z, 218 | bounding_box_min_x, 219 | bounding_box_min_y, 220 | bounding_box_min_z, 221 | bounding_box_max_x, 222 | bounding_box_max_y, 223 | bounding_box_max_z)) { 224 | 225 | // If it is the first time, then the ray has not entered the grid 226 | if (first_time) { 227 | 228 | return Distance(reached_point_x, reached_point_y, reached_point_z, 229 | direction_x, direction_y, direction_z, 230 | bounding_box_min_x, 231 | bounding_box_min_y, 232 | bounding_box_min_z, 233 | bounding_box_max_x, 234 | bounding_box_max_y, 235 | bounding_box_max_z) + 0.00001f; 236 | } 237 | 238 | // Otherwise, the ray has left the grid 239 | else { 240 | return -1; 241 | } 242 | } 243 | 244 | // Compute voxel size 245 | float voxel_size = (bounding_box_max_x - bounding_box_min_x) / (grid_res_x - 1); 246 | 247 | // Compute the the minimum point of the intersecting voxel 248 | float min_index_x = floorf((reached_point_x - bounding_box_min_x) / voxel_size); 249 | float min_index_y = floorf((reached_point_y - bounding_box_min_y) / voxel_size); 250 | float min_index_z = floorf((reached_point_z - bounding_box_min_z) / voxel_size); 251 | 252 | // Check whether the ray intersects the vertex with the last index of the axis 253 | // If so, we should record the previous index 254 | if (min_index_x == (bounding_box_max_x - bounding_box_min_x) / voxel_size) { 255 | min_index_x = (bounding_box_max_x - bounding_box_min_x) / voxel_size - 1; 256 | } 257 | if (min_index_y == (bounding_box_max_y - bounding_box_min_y) / voxel_size) { 258 | min_index_y = (bounding_box_max_y - bounding_box_min_y) / voxel_size - 1; 259 | } 260 | 
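// Note on ValueAt: for a point inside the bounding box it returns the trilinearly
// interpolated SDF value from the surrounding voxel; for a point outside the box it
// returns the distance to the box on the first ray-marching step (so the ray can jump
// to the box) and -1 once the ray has left it. The clamps here (x and y above, z just
// below) keep min_index + 1 inside the grid so the interpolation never reads out of
// bounds. flat(x, y, z) linearizes the indices as z + y * grid_res_z + x * grid_res_z * grid_res_y,
// matching the contiguous layout of the grid tensor passed from Python
// (indexed as grid[x][y][z] with z varying fastest).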
if (min_index_z == (bounding_box_max_z - bounding_box_min_z) / voxel_size) { 261 | min_index_z = (bounding_box_max_z - bounding_box_min_z) / voxel_size - 1; 262 | } 263 | 264 | // Linear interpolate along x axis the eight values 265 | const float tx = (reached_point_x - (bounding_box_min_x + min_index_x * voxel_size)) / voxel_size; 266 | const float c01 = (1.f - tx) * grid[flat(min_index_x, min_index_y, min_index_z, grid_res_x, grid_res_y, grid_res_z)] 267 | + tx * grid[flat(min_index_x+1, min_index_y, min_index_z, grid_res_x, grid_res_y, grid_res_z)]; 268 | const float c23 = (1.f - tx) * grid[flat(min_index_x, min_index_y+1, min_index_z, grid_res_x, grid_res_y, grid_res_z)] 269 | + tx * grid[flat(min_index_x+1, min_index_y+1, min_index_z, grid_res_x, grid_res_y, grid_res_z)]; 270 | const float c45 = (1.f - tx) * grid[flat(min_index_x, min_index_y, min_index_z+1, grid_res_x, grid_res_y, grid_res_z)] 271 | + tx * grid[flat(min_index_x+1, min_index_y, min_index_z+1, grid_res_x, grid_res_y, grid_res_z)]; 272 | const float c67 = (1.f - tx) * grid[flat(min_index_x, min_index_y+1, min_index_z+1, grid_res_x, grid_res_y, grid_res_z)] 273 | + tx * grid[flat(min_index_x+1, min_index_y+1, min_index_z+1, grid_res_x, grid_res_y, grid_res_z)]; 274 | 275 | // Linear interpolate along the y axis 276 | const float ty = (reached_point_y - (bounding_box_min_y + min_index_y * voxel_size)) / voxel_size; 277 | const float c0 = (1.f - ty) * c01 + ty * c23; 278 | const float c1 = (1.f - ty) * c45 + ty * c67; 279 | 280 | // Return final value interpolated along z 281 | const float tz = (reached_point_z - (bounding_box_min_z + min_index_z * voxel_size)) / voxel_size; 282 | 283 | return (1.f - tz) * c0 + tz * c1; 284 | } 285 | 286 | // Compute the intersection of the ray and the grid 287 | // The intersection procedure uses ray marching to check if we have an interaction with the stored surface 288 | __global__ void Intersect( 289 | const float* grid, 290 | const float* origins, 291 | const float* directions, 292 | const float* origin_image_distances, 293 | const float* pixel_distances, 294 | const float bounding_box_min_x, 295 | const float bounding_box_min_y, 296 | const float bounding_box_min_z, 297 | const float bounding_box_max_x, 298 | const float bounding_box_max_y, 299 | const float bounding_box_max_z, 300 | const int grid_res_x, 301 | const int grid_res_y, 302 | const int grid_res_z, 303 | float* voxel_position, 304 | float* intersection_pos, 305 | const int width, 306 | const int height) { 307 | 308 | // Compute voxel size 309 | const float voxel_size = (bounding_box_max_x - bounding_box_min_x) / (grid_res_x - 1); 310 | 311 | // Define constant values 312 | const int max_steps = 1000; 313 | bool first_time = true; 314 | float depth = 0; 315 | int gotten_result = 0; 316 | 317 | const int pixel_index = blockIdx.x * blockDim.x + threadIdx.x; 318 | 319 | if (pixel_index < width * height) { 320 | 321 | const int i = 3 * pixel_index; 322 | 323 | for (int steps = 0; steps < max_steps; steps++) { 324 | 325 | float reached_point_x = origins[i] + depth * directions[i]; 326 | float reached_point_y = origins[i+1] + depth * directions[i+1]; 327 | float reached_point_z = origins[i+2] + depth * directions[i+2]; 328 | 329 | // Get the signed distance value for the point the ray reaches 330 | const float distance = ValueAt(grid, reached_point_x, reached_point_y, reached_point_z, 331 | directions[i], directions[i+1], directions[i+2], 332 | bounding_box_min_x, 333 | bounding_box_min_y, 334 | bounding_box_min_z, 335 | 
bounding_box_max_x, 336 | bounding_box_max_y, 337 | bounding_box_max_z, 338 | grid_res_x, 339 | grid_res_y, 340 | grid_res_z, first_time); 341 | first_time = false; 342 | 343 | // Check if the ray is going ourside the bounding box 344 | if (distance == -1) { 345 | voxel_position[i] = -1; 346 | voxel_position[i+1] = -1; 347 | voxel_position[i+2] = -1; 348 | intersection_pos[i] = -1; 349 | intersection_pos[i+1] = -1; 350 | intersection_pos[i+2] = -1; 351 | gotten_result = 1; 352 | break; 353 | } 354 | 355 | // Check if we are close enough to the surface 356 | if (distance < pixel_distances[pixel_index] / origin_image_distances[pixel_index] * depth && distance) { 357 | 358 | // Compute the the minimum point of the intersecting voxel 359 | voxel_position[i] = floorf((reached_point_x - bounding_box_min_x) / voxel_size); 360 | voxel_position[i+1] = floorf((reached_point_y - bounding_box_min_y) / voxel_size); 361 | voxel_position[i+2] = floorf((reached_point_z - bounding_box_min_z) / voxel_size); 362 | if (voxel_position[i] == grid_res_x - 1) { 363 | voxel_position[i] = voxel_position[i] - 1; 364 | } 365 | if (voxel_position[i+1] == grid_res_x - 1) { 366 | voxel_position[i+1] = voxel_position[i+1] - 1; 367 | } 368 | if (voxel_position[i+2] == grid_res_x - 1) { 369 | voxel_position[i+2] = voxel_position[i+2] - 1; 370 | } 371 | intersection_pos[i] = reached_point_x; 372 | intersection_pos[i+1] = reached_point_y; 373 | intersection_pos[i+2] = reached_point_z; 374 | gotten_result = 1; 375 | break; 376 | } 377 | 378 | // Increase distance 379 | depth += distance; 380 | 381 | } 382 | 383 | if (gotten_result == 0) { 384 | 385 | // No intersections 386 | voxel_position[i] = -1; 387 | voxel_position[i+1] = -1; 388 | voxel_position[i+2] = -1; 389 | intersection_pos[i] = -1; 390 | intersection_pos[i+1] = -1; 391 | intersection_pos[i+2] = -1; 392 | } 393 | } 394 | } 395 | } // namespace 396 | 397 | // Ray marching to get the first corner position of the voxel the ray intersects 398 | std::vector ray_matching_cuda( 399 | const at::Tensor w_h_3, 400 | const at::Tensor w_h, 401 | const at::Tensor grid, 402 | const int width, 403 | const int height, 404 | const float bounding_box_min_x, 405 | const float bounding_box_min_y, 406 | const float bounding_box_min_z, 407 | const float bounding_box_max_x, 408 | const float bounding_box_max_y, 409 | const float bounding_box_max_z, 410 | const int grid_res_x, 411 | const int grid_res_y, 412 | const int grid_res_z, 413 | const float eye_x, 414 | const float eye_y, 415 | const float eye_z) { 416 | 417 | const int thread = 512; 418 | 419 | at::Tensor origins = at::zeros_like(w_h_3); 420 | at::Tensor directions = at::zeros_like(w_h_3); 421 | at::Tensor origin_image_distances = at::zeros_like(w_h); 422 | at::Tensor pixel_distances = at::zeros_like(w_h); 423 | 424 | GenerateRay<<<(width * height + thread - 1) / thread, thread>>>( 425 | origins.data(), 426 | directions.data(), 427 | origin_image_distances.data(), 428 | pixel_distances.data(), 429 | width, 430 | height, 431 | eye_x, 432 | eye_y, 433 | eye_z); 434 | 435 | at::Tensor voxel_position = at::zeros_like(w_h_3); 436 | at::Tensor intersection_pos = at::zeros_like(w_h_3); 437 | 438 | Intersect<<<(width * height + thread - 1) / thread, thread>>>( 439 | grid.data(), 440 | origins.data(), 441 | directions.data(), 442 | origin_image_distances.data(), 443 | pixel_distances.data(), 444 | bounding_box_min_x, 445 | bounding_box_min_y, 446 | bounding_box_min_z, 447 | bounding_box_max_x, 448 | bounding_box_max_y, 449 | 
bounding_box_max_z, 450 | grid_res_x, 451 | grid_res_y, 452 | grid_res_z, 453 | voxel_position.data(), 454 | intersection_pos.data(), 455 | width, 456 | height); 457 | 458 | return {intersection_pos, voxel_position, directions}; 459 | } 460 | 461 | 462 | 463 | -------------------------------------------------------------------------------- /multi_view_code/code/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='renderer', 6 | ext_modules=[ 7 | CUDAExtension('renderer', [ 8 | 'renderer.cpp', 9 | 'renderer_kernel.cu', 10 | ]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) 15 | -------------------------------------------------------------------------------- /single_view_code/differentiable_rendering.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import math 4 | import torchvision 5 | from torch.autograd import Variable 6 | from torchvision import transforms, datasets 7 | from torchvision.utils import save_image, make_grid 8 | import renderer 9 | import time 10 | import sys 11 | 12 | cuda = True if torch.cuda.is_available() else False 13 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 14 | 15 | 16 | def read_txt(file_path, grid_res_x, grid_res_y, grid_res_z): 17 | with open(file_path) as file: 18 | grid = Tensor(grid_res_x, grid_res_y, grid_res_z) 19 | for i in range(grid_res_x): 20 | for j in range(grid_res_y): 21 | for k in range(grid_res_z): 22 | grid[i][j][k] = float(file.readline()) 23 | print (grid) 24 | 25 | return grid 26 | 27 | # Read a file and create a sdf grid with target_grid_res 28 | def read_sdf(file_path, target_grid_res, target_bounding_box_min, target_bounding_box_max, target_voxel_size): 29 | 30 | with open(file_path) as file: 31 | line = file.readline() 32 | 33 | # Get grid resolutions 34 | grid_res = line.split() 35 | grid_res_x = int(grid_res[0]) 36 | grid_res_y = int(grid_res[1]) 37 | grid_res_z = int(grid_res[2]) 38 | 39 | # Get bounding box min 40 | line = file.readline() 41 | bounding_box_min = line.split() 42 | bounding_box_min_x = float(bounding_box_min[0]) 43 | bounding_box_min_y = float(bounding_box_min[1]) 44 | bounding_box_min_z = float(bounding_box_min[2]) 45 | 46 | line = file.readline() 47 | voxel_size = float(line) 48 | 49 | # max bounding box (we need to plus 0.0001 to avoid round error) 50 | bounding_box_max_x = bounding_box_min_x + voxel_size * (grid_res_x - 1)# + 0.0001 51 | bounding_box_max_y = bounding_box_min_y + voxel_size * (grid_res_y - 1) #+ 0.0001 52 | bounding_box_max_z = bounding_box_min_z + voxel_size * (grid_res_z - 1) #+ 0.0001 53 | 54 | min_bounding_box_min = min(bounding_box_min_x, bounding_box_min_y, bounding_box_min_z) 55 | print(bounding_box_min_x, bounding_box_min_y, bounding_box_min_z) 56 | max_bounding_box_max = max(bounding_box_max_x, bounding_box_max_y, bounding_box_max_z) 57 | print(bounding_box_max_x, bounding_box_max_y, bounding_box_max_z) 58 | max_dist = max(bounding_box_max_x - bounding_box_min_x, bounding_box_max_y - bounding_box_min_y, bounding_box_max_z - bounding_box_min_z) 59 | 60 | max_grid_res = max(grid_res_x, grid_res_y, grid_res_z) 61 | 62 | grid = [] 63 | for i in range(grid_res_x): 64 | grid.append([]) 65 | for j in range(grid_res_y): 66 | grid[i].append([]) 67 | for k in range(grid_res_z): 68 | 
grid[i][j].append(2) 69 | 70 | for i in range(grid_res_z): 71 | for j in range(grid_res_y): 72 | for k in range(grid_res_x): 73 | grid_value = float(file.readline()) 74 | grid[k][j][i] = grid_value 75 | 76 | grid = Tensor(grid) 77 | 78 | target_grid = Tensor(target_grid_res, target_grid_res, target_grid_res) 79 | 80 | linear_space_x = torch.linspace(0, target_grid_res-1, target_grid_res) 81 | linear_space_y = torch.linspace(0, target_grid_res-1, target_grid_res) 82 | linear_space_z = torch.linspace(0, target_grid_res-1, target_grid_res) 83 | first_loop = linear_space_x.repeat(target_grid_res * target_grid_res, 1).t().contiguous().view(-1).unsqueeze_(1) 84 | second_loop = linear_space_y.repeat(target_grid_res, target_grid_res).t().contiguous().view(-1).unsqueeze_(1) 85 | third_loop = linear_space_z.repeat(target_grid_res * target_grid_res).unsqueeze_(1) 86 | loop = torch.cat((first_loop, second_loop, third_loop), 1).cuda() 87 | 88 | min_x = Tensor([bounding_box_min_x]).repeat(target_grid_res*target_grid_res*target_grid_res, 1) 89 | min_y = Tensor([bounding_box_min_y]).repeat(target_grid_res*target_grid_res*target_grid_res, 1) 90 | min_z = Tensor([bounding_box_min_z]).repeat(target_grid_res*target_grid_res*target_grid_res, 1) 91 | bounding_min_matrix = torch.cat((min_x, min_y, min_z), 1) 92 | 93 | move_to_center_x = Tensor([(max_dist - (bounding_box_max_x - bounding_box_min_x)) / 2]).repeat(target_grid_res*target_grid_res*target_grid_res, 1) 94 | move_to_center_y = Tensor([(max_dist - (bounding_box_max_y - bounding_box_min_y)) / 2]).repeat(target_grid_res*target_grid_res*target_grid_res, 1) 95 | move_to_center_z = Tensor([(max_dist - (bounding_box_max_z - bounding_box_min_z)) / 2]).repeat(target_grid_res*target_grid_res*target_grid_res, 1) 96 | move_to_center_matrix = torch.cat((move_to_center_x, move_to_center_y, move_to_center_z), 1) 97 | 98 | # Get the position of the grid points in the refined grid 99 | points = bounding_min_matrix + target_voxel_size * max_dist / (target_bounding_box_max - target_bounding_box_min) * loop - move_to_center_matrix 100 | if points[(points[:, 0] < bounding_box_min_x)].shape[0] != 0: 101 | points[(points[:, 0] < bounding_box_min_x)] = Tensor([bounding_box_max_x, bounding_box_max_y, bounding_box_max_z]).view(1,3) 102 | if points[(points[:, 1] < bounding_box_min_y)].shape[0] != 0: 103 | points[(points[:, 1] < bounding_box_min_y)] = Tensor([bounding_box_max_x, bounding_box_min_y, bounding_box_min_z]).view(1,3) 104 | if points[(points[:, 2] < bounding_box_min_z)].shape[0] != 0: 105 | points[(points[:, 2] < bounding_box_min_z)] = Tensor([bounding_box_max_x, bounding_box_min_y, bounding_box_min_z]).view(1,3) 106 | if points[(points[:, 0] > bounding_box_max_x)].shape[0] != 0: 107 | points[(points[:, 0] > bounding_box_max_x)] = Tensor([bounding_box_max_x, bounding_box_min_y, bounding_box_min_z]).view(1,3) 108 | if points[(points[:, 1] > bounding_box_max_y)].shape[0] != 0: 109 | points[(points[:, 1] > bounding_box_max_y)] = Tensor([bounding_box_max_x, bounding_box_min_y, bounding_box_min_z]).view(1,3) 110 | if points[(points[:, 2] > bounding_box_max_z)].shape[0] != 0: 111 | points[(points[:, 2] > bounding_box_max_z)] = Tensor([bounding_box_max_x, bounding_box_min_y, bounding_box_min_z]).view(1,3) 112 | voxel_min_point_index_x = torch.floor((points[:,0].unsqueeze_(1) - min_x) / voxel_size).clamp(max=grid_res_x-2) 113 | voxel_min_point_index_y = torch.floor((points[:,1].unsqueeze_(1) - min_y) / voxel_size).clamp(max=grid_res_y-2) 114 | voxel_min_point_index_z = 
torch.floor((points[:,2].unsqueeze_(1) - min_z) / voxel_size).clamp(max=grid_res_z-2) 115 | voxel_min_point_index = torch.cat((voxel_min_point_index_x, voxel_min_point_index_y, voxel_min_point_index_z), 1) 116 | voxel_min_point = bounding_min_matrix + voxel_min_point_index * voxel_size 117 | 118 | # Compute the sdf value of the grid points in the refined grid 119 | target_grid = calculate_sdf_value(grid, points, voxel_min_point, voxel_min_point_index, voxel_size, grid_res_x, grid_res_y, grid_res_z).view(target_grid_res, target_grid_res, target_grid_res) 120 | return target_grid 121 | 122 | 123 | def grid_construction_cube(grid_res, bounding_box_min, bounding_box_max): 124 | 125 | # Construct the sdf grid for a cube with size 2 126 | voxel_size = (bounding_box_max - bounding_box_min) / (grid_res - 1) 127 | cube_left_bound_index = float(grid_res - 1) / 4; 128 | cube_right_bound_index = float(grid_res - 1) / 4 * 3; 129 | cube_center = float(grid_res - 1) / 2; 130 | 131 | grid = Tensor(grid_res, grid_res, grid_res) 132 | for i in range(grid_res): 133 | for j in range(grid_res): 134 | for k in range(grid_res): 135 | if (i >= cube_left_bound_index and i <= cube_right_bound_index and 136 | j >= cube_left_bound_index and j <= cube_right_bound_index and 137 | k >= cube_left_bound_index and k <= cube_right_bound_index): 138 | grid[i,j,k] = voxel_size * max(abs(i - cube_center), abs(j - cube_center), abs(k - cube_center)) - 1; 139 | else: 140 | grid[i,j,k] = math.sqrt(pow(voxel_size * (max(i - cube_right_bound_index, cube_left_bound_index - i, 0)), 2) + 141 | pow(voxel_size * (max(j - cube_right_bound_index, cube_left_bound_index - j, 0)), 2) + 142 | pow(voxel_size * (max(k - cube_right_bound_index, cube_left_bound_index - k, 0)), 2)); 143 | return grid 144 | 145 | def grid_construction_torus(grid_res, bounding_box_min, bounding_box_max): 146 | 147 | # radius of the circle between the two circles 148 | radius_big = 1.5 149 | 150 | # radius of the small circle 151 | radius_small = 0.5 152 | 153 | voxel_size = (bounding_box_max - bounding_box_min) / (grid_res - 1) 154 | grid = Tensor(grid_res, grid_res, grid_res) 155 | for i in range(grid_res): 156 | for j in range(grid_res): 157 | for k in range(grid_res): 158 | x = bounding_box_min + voxel_size * i 159 | y = bounding_box_min + voxel_size * j 160 | z = bounding_box_min + voxel_size * k 161 | 162 | grid[i,j,k] = math.sqrt(math.pow((math.sqrt(math.pow(y, 2) + math.pow(z, 2)) - radius_big), 2) 163 | + math.pow(x, 2)) - radius_small; 164 | 165 | return grid 166 | 167 | 168 | 169 | def grid_construction_sphere_big(grid_res, bounding_box_min, bounding_box_max): 170 | 171 | # Construct the sdf grid for a sphere with radius 1 172 | linear_space = torch.linspace(bounding_box_min, bounding_box_max, grid_res) 173 | x_dim = linear_space.view(-1, 1).repeat(grid_res, 1, grid_res) 174 | y_dim = linear_space.view(1, -1).repeat(grid_res, grid_res, 1) 175 | z_dim = linear_space.view(-1, 1, 1).repeat(1, grid_res, grid_res) 176 | grid = torch.sqrt(x_dim * x_dim + y_dim * y_dim + z_dim * z_dim) - 1.6 177 | if cuda: 178 | return grid.cuda() 179 | else: 180 | return grid 181 | 182 | def grid_construction_sphere_small(grid_res, bounding_box_min, bounding_box_max): 183 | 184 | # Construct the sdf grid for a sphere with radius 1 185 | linear_space = torch.linspace(bounding_box_min, bounding_box_max, grid_res) 186 | x_dim = linear_space.view(-1, 1).repeat(grid_res, 1, grid_res) 187 | y_dim = linear_space.view(1, -1).repeat(grid_res, grid_res, 1) 188 | z_dim = 
linear_space.view(-1, 1, 1).repeat(1, grid_res, grid_res) 189 | grid = torch.sqrt(x_dim * x_dim + y_dim * y_dim + z_dim * z_dim) - 1 190 | if cuda: 191 | return grid.cuda() 192 | else: 193 | return grid 194 | 195 | 196 | def get_grid_normal(grid, voxel_size, grid_res_x, grid_res_y, grid_res_z): 197 | 198 | # largest index 199 | n_x = grid_res_x - 1 200 | n_y = grid_res_y - 1 201 | n_z = grid_res_z - 1 202 | 203 | # x-axis normal vectors 204 | X_1 = torch.cat((grid[1:,:,:], (3 * grid[n_x,:,:] - 3 * grid[n_x-1,:,:] + grid[n_x-2,:,:]).unsqueeze_(0)), 0) 205 | X_2 = torch.cat(((-3 * grid[1,:,:] + 3 * grid[0,:,:] + grid[2,:,:]).unsqueeze_(0), grid[:n_x,:,:]), 0) 206 | grid_normal_x = (X_1 - X_2) / (2 * voxel_size) 207 | 208 | # y-axis normal vectors 209 | Y_1 = torch.cat((grid[:,1:,:], (3 * grid[:,n_y,:] - 3 * grid[:,n_y-1,:] + grid[:,n_y-2,:]).unsqueeze_(1)), 1) 210 | Y_2 = torch.cat(((-3 * grid[:,1,:] + 3 * grid[:,0,:] + grid[:,2,:]).unsqueeze_(1), grid[:,:n_y,:]), 1) 211 | grid_normal_y = (Y_1 - Y_2) / (2 * voxel_size) 212 | 213 | # z-axis normal vectors 214 | Z_1 = torch.cat((grid[:,:,1:], (3 * grid[:,:,n_z] - 3 * grid[:,:,n_z-1] + grid[:,:,n_z-2]).unsqueeze_(2)), 2) 215 | Z_2 = torch.cat(((-3 * grid[:,:,1] + 3 * grid[:,:,0] + grid[:,:,2]).unsqueeze_(2), grid[:,:,:n_z]), 2) 216 | grid_normal_z = (Z_1 - Z_2) / (2 * voxel_size) 217 | 218 | 219 | return [grid_normal_x, grid_normal_y, grid_normal_z] 220 | 221 | 222 | def get_intersection_normal(intersection_grid_normal, intersection_pos, voxel_min_point, voxel_size): 223 | 224 | # Compute parameters 225 | tx = (intersection_pos[:,:,0] - voxel_min_point[:,:,0]) / voxel_size 226 | ty = (intersection_pos[:,:,1] - voxel_min_point[:,:,1]) / voxel_size 227 | tz = (intersection_pos[:,:,2] - voxel_min_point[:,:,2]) / voxel_size 228 | 229 | intersection_normal = (1 - tz) * (1 - ty) * (1 - tx) * intersection_grid_normal[:,:,0] \ 230 | + tz * (1 - ty) * (1 - tx) * intersection_grid_normal[:,:,1] \ 231 | + (1 - tz) * ty * (1 - tx) * intersection_grid_normal[:,:,2] \ 232 | + tz * ty * (1 - tx) * intersection_grid_normal[:,:,3] \ 233 | + (1 - tz) * (1 - ty) * tx * intersection_grid_normal[:,:,4] \ 234 | + tz * (1 - ty) * tx * intersection_grid_normal[:,:,5] \ 235 | + (1 - tz) * ty * tx * intersection_grid_normal[:,:,6] \ 236 | + tz * ty * tx * intersection_grid_normal[:,:,7] 237 | 238 | return intersection_normal 239 | 240 | 241 | # Do one more step for ray matching 242 | def calculate_sdf_value(grid, points, voxel_min_point, voxel_min_point_index, voxel_size, grid_res_x, grid_res_y, grid_res_z): 243 | 244 | string = "" 245 | # Linear interpolate along x axis the eight values 246 | tx = (points[:,0] - voxel_min_point[:,0]) / voxel_size; 247 | string = string + "\n\nvoxel_size: \n" + str(voxel_size) 248 | string = string + "\n\ntx: \n" + str(tx) 249 | # print(grid.shape) 250 | 251 | if cuda: 252 | tx = tx.cuda() 253 | x = voxel_min_point_index.long()[:,0] 254 | y = voxel_min_point_index.long()[:,1] 255 | z = voxel_min_point_index.long()[:,2] 256 | 257 | string = string + "\n\nx: \n" + str(x) 258 | string = string + "\n\ny: \n" + str(y) 259 | string = string + "\n\nz: \n" + str(z) 260 | 261 | c01 = (1 - tx) * grid[x,y,z] + tx * grid[x+1,y,z]; 262 | c23 = (1 - tx) * grid[x,y+1,z] + tx * grid[x+1,y+1,z]; 263 | c45 = (1 - tx) * grid[x,y,z+1] + tx * grid[x+1,y,z+1]; 264 | c67 = (1 - tx) * grid[x,y+1,z+1] + tx * grid[x+1,y+1,z+1]; 265 | 266 | string = string + "\n\n(1 - tx): \n" + str((1 - tx)) 267 | string = string + "\n\ngrid[x,y,z]: \n" + str(grid[x,y,z]) 268 | 
string = string + "\n\ngrid[x+1,y,z]: \n" + str(grid[x+1,y,z]) 269 | string = string + "\n\nc01: \n" + str(c01) 270 | string = string + "\n\nc23: \n" + str(c23) 271 | string = string + "\n\nc45: \n" + str(c45) 272 | string = string + "\n\nc67: \n" + str(c67) 273 | 274 | # Linear interpolate along the y axis 275 | ty = (points[:,1] - voxel_min_point[:,1]) / voxel_size; 276 | ty = ty.cuda() 277 | c0 = (1 - ty) * c01 + ty * c23; 278 | c1 = (1 - ty) * c45 + ty * c67; 279 | 280 | string = string + "\n\nty: \n" + str(ty) 281 | 282 | string = string + "\n\nc0: \n" + str(c0) 283 | string = string + "\n\nc1: \n" + str(c1) 284 | 285 | # Return final value interpolated along z 286 | tz = (points[:,2] - voxel_min_point[:,2]) / voxel_size; 287 | tz = tz.cuda() 288 | string = string + "\n\ntz: \n" + str(tz) 289 | 290 | else: 291 | x = voxel_min_point_index.numpy()[:,0] 292 | y = voxel_min_point_index.numpy()[:,1] 293 | z = voxel_min_point_index.numpy()[:,2] 294 | 295 | c01 = (1 - tx) * grid[x,y,z] + tx * grid[x+1,y,z]; 296 | c23 = (1 - tx) * grid[x,y+1,z] + tx * grid[x+1,y+1,z]; 297 | c45 = (1 - tx) * grid[x,y,z+1] + tx * grid[x+1,y,z+1]; 298 | c67 = (1 - tx) * grid[x,y+1,z+1] + tx * grid[x+1,y+1,z+1]; 299 | 300 | # Linear interpolate along the y axis 301 | ty = (points[:,1] - voxel_min_point[:,1]) / voxel_size; 302 | c0 = (1 - ty) * c01 + ty * c23; 303 | c1 = (1 - ty) * c45 + ty * c67; 304 | 305 | # Return final value interpolated along z 306 | tz = (points[:,2] - voxel_min_point[:,2]) / voxel_size; 307 | 308 | result = (1 - tz) * c0 + tz * c1; 309 | 310 | return result 311 | 312 | 313 | def compute_intersection_pos(grid, intersection_pos_rough, voxel_min_point, voxel_min_point_index, ray_direction, voxel_size, mask): 314 | 315 | # Linear interpolate along x axis the eight values 316 | tx = (intersection_pos_rough[:,:,0] - voxel_min_point[:,:,0]) / voxel_size; 317 | 318 | if cuda: 319 | 320 | x = voxel_min_point_index.long()[:,:,0] 321 | y = voxel_min_point_index.long()[:,:,1] 322 | z = voxel_min_point_index.long()[:,:,2] 323 | 324 | c01 = (1 - tx) * grid[x,y,z].cuda() + tx * grid[x+1,y,z].cuda(); 325 | c23 = (1 - tx) * grid[x,y+1,z].cuda() + tx * grid[x+1,y+1,z].cuda(); 326 | c45 = (1 - tx) * grid[x,y,z+1].cuda() + tx * grid[x+1,y,z+1].cuda(); 327 | c67 = (1 - tx) * grid[x,y+1,z+1].cuda() + tx * grid[x+1,y+1,z+1].cuda(); 328 | 329 | else: 330 | x = voxel_min_point_index.numpy()[:,:,0] 331 | y = voxel_min_point_index.numpy()[:,:,1] 332 | z = voxel_min_point_index.numpy()[:,:,2] 333 | 334 | c01 = (1 - tx) * grid[x,y,z] + tx * grid[x+1,y,z]; 335 | c23 = (1 - tx) * grid[x,y+1,z] + tx * grid[x+1,y+1,z]; 336 | c45 = (1 - tx) * grid[x,y,z+1] + tx * grid[x+1,y,z+1]; 337 | c67 = (1 - tx) * grid[x,y+1,z+1] + tx * grid[x+1,y+1,z+1]; 338 | 339 | # Linear interpolate along the y axis 340 | ty = (intersection_pos_rough[:,:,1] - voxel_min_point[:,:,1]) / voxel_size; 341 | c0 = (1 - ty) * c01 + ty * c23; 342 | c1 = (1 - ty) * c45 + ty * c67; 343 | 344 | # Return final value interpolated along z 345 | tz = (intersection_pos_rough[:,:,2] - voxel_min_point[:,:,2]) / voxel_size; 346 | 347 | sdf_value = (1 - tz) * c0 + tz * c1; 348 | 349 | return (intersection_pos_rough + ray_direction * sdf_value.view(width,height,1).repeat(1,1,3))\ 350 | + (1 - mask.view(width,height,1).repeat(1,1,3)) 351 | 352 | 353 | def differentiable_rendering(grid, grid_res, image_res, camera): 354 | # print(grid_res, image_res, camera) 355 | # return grid 356 | global width, height 357 | width = image_res 358 | height = image_res 359 | 360 | 
return generate_image(-2, -2, -2, 2, 2, 2, \ 361 | 4./(grid_res-1), grid_res, grid_res, grid_res, image_res, image_res, grid, camera, False, []) 362 | 363 | def differentiable_rendering_silhouette(grid, grid_res, image_res, camera): 364 | # print(grid_res, image_res, camera) 365 | return generate_image(-2, -2, -2, 2, 2, 2, \ 366 | 4./(grid_res-1), grid_res, grid_res, grid_res, image_res, image_res, grid, camera, True) 367 | 368 | def generate_image(bounding_box_min_x, bounding_box_min_y, bounding_box_min_z, \ 369 | bounding_box_max_x, bounding_box_max_y, bounding_box_max_z, \ 370 | voxel_size, grid_res_x, grid_res_y, grid_res_z, width, height, grid, camera, back, camera_list): 371 | 372 | # Get normal vectors for points on the grid 373 | [grid_normal_x, grid_normal_y, grid_normal_z] = get_grid_normal(grid, voxel_size, grid_res_x, grid_res_y, grid_res_z) 374 | 375 | # Generate rays 376 | e = camera 377 | 378 | w_h_3 = torch.zeros(width, height, 3).cuda() 379 | w_h = torch.zeros(width, height).cuda() 380 | eye_x = e[0] 381 | eye_y = e[1] 382 | eye_z = e[2] 383 | 384 | # Do ray tracing in cpp 385 | outputs = renderer.ray_matching(w_h_3, w_h, grid, width, height, bounding_box_min_x, bounding_box_min_y, bounding_box_min_z, \ 386 | bounding_box_max_x, bounding_box_max_y, bounding_box_max_z, \ 387 | grid_res_x, grid_res_y, grid_res_z, \ 388 | eye_x, \ 389 | eye_y, \ 390 | eye_z 391 | ) 392 | 393 | # {intersection_pos, voxel_position, directions} 394 | intersection_pos_rough = outputs[0] 395 | voxel_min_point_index = outputs[1] 396 | ray_direction = outputs[2] 397 | 398 | # Initialize grid values and normals for intersection voxels 399 | intersection_grid_normal_x = Tensor(width, height, 8) 400 | intersection_grid_normal_y = Tensor(width, height, 8) 401 | intersection_grid_normal_z = Tensor(width, height, 8) 402 | intersection_grid = Tensor(width, height, 8) 403 | 404 | # Make the pixels with no intersections with rays be 0 405 | mask = (voxel_min_point_index[:,:,0] != -1).type(Tensor) 406 | 407 | # Get the indices of the minimum point of the intersecting voxels 408 | x = voxel_min_point_index[:,:,0].type(torch.cuda.LongTensor) 409 | y = voxel_min_point_index[:,:,1].type(torch.cuda.LongTensor) 410 | z = voxel_min_point_index[:,:,2].type(torch.cuda.LongTensor) 411 | x[x == -1] = 0 412 | y[y == -1] = 0 413 | z[z == -1] = 0 414 | 415 | # Get the x-axis of normal vectors for the 8 points of the intersecting voxel 416 | # This line is equivalent to grid_normal_x[x,y,z] 417 | x1 = torch.index_select(grid_normal_x.view(-1), 0, z.view(-1) + grid_res_x * y.view(-1) + grid_res_x * grid_res_x * x.view(-1)).view(x.shape).unsqueeze_(2) 418 | x2 = torch.index_select(grid_normal_x.view(-1), 0, (z+1).view(-1) + grid_res_x * y.view(-1) + grid_res_x * grid_res_x * x.view(-1)).view(x.shape).unsqueeze_(2) 419 | x3 = torch.index_select(grid_normal_x.view(-1), 0, z.view(-1) + grid_res_x * (y+1).view(-1) + grid_res_x * grid_res_x * x.view(-1)).view(x.shape).unsqueeze_(2) 420 | x4 = torch.index_select(grid_normal_x.view(-1), 0, (z+1).view(-1) + grid_res_x * (y+1).view(-1) + grid_res_x * grid_res_x * x.view(-1)).view(x.shape).unsqueeze_(2) 421 | x5 = torch.index_select(grid_normal_x.view(-1), 0, z.view(-1) + grid_res_x * y.view(-1) + grid_res_x * grid_res_x * (x+1).view(-1)).view(x.shape).unsqueeze_(2) 422 | x6 = torch.index_select(grid_normal_x.view(-1), 0, (z+1).view(-1) + grid_res_x * y.view(-1) + grid_res_x * grid_res_x * (x+1).view(-1)).view(x.shape).unsqueeze_(2) 423 | x7 = torch.index_select(grid_normal_x.view(-1), 
0, z.view(-1) + grid_res_x * (y+1).view(-1) + grid_res_x * grid_res_x * (x+1).view(-1)).view(x.shape).unsqueeze_(2) 424 | x8 = torch.index_select(grid_normal_x.view(-1), 0, (z+1).view(-1) + grid_res_x * (y+1).view(-1) + grid_res_x * grid_res_x * (x+1).view(-1)).view(x.shape).unsqueeze_(2) 425 | intersection_grid_normal_x = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), 2) + (1 - mask.view(width, height, 1).repeat(1,1,8)) 426 | 427 | # Get the y-axis of normal vectors for the 8 points of the intersecting voxel 428 | y1 = torch.index_select(grid_normal_y.view(-1), 0, z.view(-1) + grid_res_x * y.view(-1) + grid_res_x * grid_res_x * x.view(-1)).view(x.shape).unsqueeze_(2) 429 | y2 = torch.index_select(grid_normal_y.view(-1), 0, (z+1).view(-1) + grid_res_x * y.view(-1) + grid_res_x * grid_res_x * x.view(-1)).view(x.shape).unsqueeze_(2) 430 | y3 = torch.index_select(grid_normal_y.view(-1), 0, z.view(-1) + grid_res_x * (y+1).view(-1) + grid_res_x * grid_res_x * x.view(-1)).view(x.shape).unsqueeze_(2) 431 | y4 = torch.index_select(grid_normal_y.view(-1), 0, (z+1).view(-1) + grid_res_x * (y+1).view(-1) + grid_res_x * grid_res_x * x.view(-1)).view(x.shape).unsqueeze_(2) 432 | y5 = torch.index_select(grid_normal_y.view(-1), 0, z.view(-1) + grid_res_x * y.view(-1) + grid_res_x * grid_res_x * (x+1).view(-1)).view(x.shape).unsqueeze_(2) 433 | y6 = torch.index_select(grid_normal_y.view(-1), 0, (z+1).view(-1) + grid_res_x * y.view(-1) + grid_res_x * grid_res_x * (x+1).view(-1)).view(x.shape).unsqueeze_(2) 434 | y7 = torch.index_select(grid_normal_y.view(-1), 0, z.view(-1) + grid_res_x * (y+1).view(-1) + grid_res_x * grid_res_x * (x+1).view(-1)).view(x.shape).unsqueeze_(2) 435 | y8 = torch.index_select(grid_normal_y.view(-1), 0, (z+1).view(-1) + grid_res_x * (y+1).view(-1) + grid_res_x * grid_res_x * (x+1).view(-1)).view(x.shape).unsqueeze_(2) 436 | intersection_grid_normal_y = torch.cat((y1, y2, y3, y4, y5, y6, y7, y8), 2) + (1 - mask.view(width, height, 1).repeat(1,1,8)) 437 | 438 | # Get the z-axis of normal vectors for the 8 points of the intersecting voxel 439 | z1 = torch.index_select(grid_normal_z.view(-1), 0, z.view(-1) + grid_res_x * y.view(-1) + grid_res_x * grid_res_x * x.view(-1)).view(x.shape).unsqueeze_(2) 440 | z2 = torch.index_select(grid_normal_z.view(-1), 0, (z+1).view(-1) + grid_res_x * y.view(-1) + grid_res_x * grid_res_x * x.view(-1)).view(x.shape).unsqueeze_(2) 441 | z3 = torch.index_select(grid_normal_z.view(-1), 0, z.view(-1) + grid_res_x * (y+1).view(-1) + grid_res_x * grid_res_x * x.view(-1)).view(x.shape).unsqueeze_(2) 442 | z4 = torch.index_select(grid_normal_z.view(-1), 0, (z+1).view(-1) + grid_res_x * (y+1).view(-1) + grid_res_x * grid_res_x * x.view(-1)).view(x.shape).unsqueeze_(2) 443 | z5 = torch.index_select(grid_normal_z.view(-1), 0, z.view(-1) + grid_res_x * y.view(-1) + grid_res_x * grid_res_x * (x+1).view(-1)).view(x.shape).unsqueeze_(2) 444 | z6 = torch.index_select(grid_normal_z.view(-1), 0, (z+1).view(-1) + grid_res_x * y.view(-1) + grid_res_x * grid_res_x * (x+1).view(-1)).view(x.shape).unsqueeze_(2) 445 | z7 = torch.index_select(grid_normal_z.view(-1), 0, z.view(-1) + grid_res_x * (y+1).view(-1) + grid_res_x * grid_res_x * (x+1).view(-1)).view(x.shape).unsqueeze_(2) 446 | z8 = torch.index_select(grid_normal_z.view(-1), 0, (z+1).view(-1) + grid_res_x * (y+1).view(-1) + grid_res_x * grid_res_x * (x+1).view(-1)).view(x.shape).unsqueeze_(2) 447 | intersection_grid_normal_z = torch.cat((z1, z2, z3, z4, z5, z6, z7, z8), 2) + (1 - mask.view(width, height, 1).repeat(1,1,8)) 
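The eight `index_select` calls above are a flattened form of 3D advanced indexing on the normal grids: for a cubic resolution `R`, element `(x, y, z)` of a contiguous `R×R×R` tensor lives at flat offset `z + R*y + R*R*x`, which is exactly the index expression used here (and why the code notes the first gather is equivalent to `grid_normal_x[x,y,z]`). A small, self-contained sketch of that equivalence (tensor names and sizes are illustrative only, not taken from the repo):

```python
import torch

R = 8                                  # cubic grid resolution (illustrative)
grid_normal_x = torch.randn(R, R, R)   # stand-in for one normal-component grid
x = torch.randint(0, R - 1, (4, 4))    # per-pixel voxel indices from ray marching
y = torch.randint(0, R - 1, (4, 4))
z = torch.randint(0, R - 1, (4, 4))

# Flattened gather, as in generate_image: offset = z + R*y + R*R*x (row-major layout)
flat_idx = z.view(-1) + R * y.view(-1) + R * R * x.view(-1)
gathered = torch.index_select(grid_normal_x.view(-1), 0, flat_idx).view(x.shape)

# Direct advanced indexing yields the same values
assert torch.equal(gathered, grid_normal_x[x, y, z])
```

The renderer repeats this gather for all eight corners of the intersected voxel (`x`/`x+1`, `y`/`y+1`, `z`/`z+1`) and later blends the gathered normals trilinearly in `get_intersection_normal`.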
448 | 449 | # Change from grid coordinates to world coordinates 450 | voxel_min_point = Tensor([bounding_box_min_x, bounding_box_min_y, bounding_box_min_z]) + voxel_min_point_index * voxel_size 451 | 452 | intersection_pos = compute_intersection_pos(grid, intersection_pos_rough,\ 453 | voxel_min_point, voxel_min_point_index,\ 454 | ray_direction, voxel_size, mask) 455 | 456 | intersection_pos = intersection_pos * mask.repeat(3,1,1).permute(1,2,0) 457 | shading = Tensor(width, height).fill_(0) 458 | 459 | # Compute the normal vectors for the intersecting points 460 | intersection_normal_x = get_intersection_normal(intersection_grid_normal_x, intersection_pos, voxel_min_point, voxel_size) 461 | intersection_normal_y = get_intersection_normal(intersection_grid_normal_y, intersection_pos, voxel_min_point, voxel_size) 462 | intersection_normal_z = get_intersection_normal(intersection_grid_normal_z, intersection_pos, voxel_min_point, voxel_size) 463 | 464 | # Put all the xyz-axis of the normal vectors into a single matrix 465 | intersection_normal_x_resize = intersection_normal_x.unsqueeze_(2) 466 | intersection_normal_y_resize = intersection_normal_y.unsqueeze_(2) 467 | intersection_normal_z_resize = intersection_normal_z.unsqueeze_(2) 468 | intersection_normal = torch.cat((intersection_normal_x_resize, intersection_normal_y_resize, intersection_normal_z_resize), 2) 469 | intersection_normal = intersection_normal / torch.unsqueeze(torch.norm(intersection_normal, p=2, dim=2), 2).repeat(1, 1, 3) 470 | 471 | # Create the point light 472 | shading = 0 473 | light_position = camera.repeat(width, height, 1) 474 | light_norm = torch.unsqueeze(torch.norm(light_position - intersection_pos, p=2, dim=2), 2).repeat(1, 1, 3) 475 | light_direction_point = (light_position - intersection_pos) / light_norm 476 | light_direction = camera.repeat(width, height, 1) 477 | l_dot_n = torch.sum(light_direction * intersection_normal, 2).unsqueeze_(2) 478 | shading += 2 * torch.max(l_dot_n, Tensor(width, height, 1).fill_(0))[:,:,0] / torch.pow(torch.sum((light_position - intersection_pos) * light_direction_point, dim=2), 2) 479 | 480 | # Get the final image 481 | image = shading * mask 482 | image[mask == 0] = 1 483 | mask = torch.clamp(image * 10000, 0, 1) 484 | 485 | return image, mask 486 | 487 | # The energy E captures the difference between a rendered image and 488 | # a desired target image, and the rendered image is a function of the 489 | # SDF values. You could write E(SDF) = ||rendering(SDF)-target_image||^2. 490 | # In addition, there is a second term in the energy as you observed that 491 | # constrains the length of the normal of the SDF to 1. This is a regularization 492 | # term to make sure the output is still a valid SDF. 
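The energy described in the comment above is what `loss_fn` (defined next) evaluates: an image term that compares the rendering of the current SDF against the target image, plus an Eikonal-style regularizer that keeps the squared gradient norm of the SDF close to 1 so the optimized grid remains a valid signed distance field. A minimal sketch of how a caller could weight and combine the two returned terms (the weights `a` and `b` are illustrative placeholders, not the values used by this repo's training code):

```python
# Illustrative caller-side combination of the two terms returned by loss_fn
# below; a and b are placeholder weights.
def sdf_objective(rendered, target, grid, voxel_size, grid_res, width, height,
                  a=1.0, b=0.1):
    image_loss, sdf_loss = loss_fn(rendered, target, grid, voxel_size,
                                   grid_res, grid_res, grid_res, width, height)
    # E(SDF) ~ a * ||rendering(SDF) - target||_1 + b * sum | ||grad SDF||^2 - 1 |
    return a * image_loss + b * sdf_loss
```

The single-view training loop additionally adds a Laplacian smoothness term and restricts the SDF regularizer to a narrow band around the surface.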
493 | def loss_fn(output, target, grid, voxel_size, grid_res_x, grid_res_y, grid_res_z, width, height): 494 | 495 | image_loss = torch.sum(torch.abs(target - output)) #/ (width * height) 496 | 497 | [grid_normal_x, grid_normal_y, grid_normal_z] = get_grid_normal(grid, voxel_size, grid_res_x, grid_res_y, grid_res_z) 498 | sdf_loss = torch.sum(torch.abs(torch.pow(grid_normal_x[1:grid_res_x-1, 1:grid_res_y-1, 1:grid_res_z-1], 2)\ 499 | + torch.pow(grid_normal_y[1:grid_res_x-1, 1:grid_res_y-1, 1:grid_res_z-1], 2)\ 500 | + torch.pow(grid_normal_z[1:grid_res_x-1, 1:grid_res_y-1, 1:grid_res_z-1], 2) - 1)) #/ ((grid_res-1) * (grid_res-1) * (grid_res-1)) 501 | 502 | return image_loss, sdf_loss 503 | 504 | def sdf_diff(sdf1, sdf2): 505 | return torch.sum(torch.abs(sdf1 - sdf2)).item() 506 | 507 | -------------------------------------------------------------------------------- /single_view_code/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torchvision import transforms, datasets 4 | from torchvision.utils import save_image, make_grid 5 | import torchvision 6 | 7 | from models import Encoder, Decoder, Refiner 8 | from dataset import Dataset 9 | 10 | import torch.nn as nn 11 | import matplotlib.pyplot as plt 12 | from differentiable_rendering import differentiable_rendering, grid_construction_sphere_big, loss_fn, calculate_sdf_value 13 | import math 14 | from random import * 15 | 16 | i = 0 17 | num_rec = 0 18 | num_epochs = 0 19 | sample = 0 20 | factor = 15 21 | avg = 1000000 22 | directory = "../result/" 23 | overall_loss = [] 24 | latent_loss = [] 25 | rec_loss = [] 26 | num_rec_list = [] 27 | num_print = 0 28 | 29 | cuda = True if torch.cuda.is_available() else False 30 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 31 | 32 | camera_list = [] 33 | angle = 0 34 | h = 3 35 | for i in range(24): 36 | camera_list.append(Tensor([math.cos(angle) * math.sqrt(25-h**2), h, math.sin(angle) * math.sqrt(25-h**2)])) 37 | angle += math.pi / 12 38 | 39 | 40 | def generate_samples(test_loader, shape_encoder, shape_decoder, sketch_encoder, args): 41 | with torch.no_grad(): 42 | 43 | sketch, shape = next(iter(test_loader)) 44 | 45 | torchvision.utils.save_image(sketch[0], "./" + directory + "/sketch" + str(num_rec) + ".png", nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0) 46 | sketch_feature, _ = sketch_encoder(sketch) 47 | sketch_out = shape_decoder(sketch_feature) 48 | shape_feature = shape_encoder(shape) 49 | shape_out = shape_decoder(shape_feature) 50 | images = torch.cuda.FloatTensor(24, 1, args.image_res, args.image_res) #### 51 | print("----", F.mse_loss(shape, sketch_out), shape.shape[0]) 52 | # print(out.shape[0]) 53 | for i in range(18, 26): 54 | # self.fake_image[i,0,:,:] = differentiable_rendering(fake_sdf[i,0,:,:,:], self.fake_sdf.shape[-1], self.opt.crop_size, camera_list[self.index[i]]) 55 | images[i-18,0,:,:] = differentiable_rendering(sketch_out[0,0,:,:,:], sketch_out.shape[-1], args.image_res, camera_list[i]) 56 | images[i-10,0,:,:] = differentiable_rendering(shape_out[0,0,:,:,:], shape_out.shape[-1], args.image_res, camera_list[i]) 57 | images[i-2,0,:,:] = differentiable_rendering(shape[0,0,:,:,:], shape.shape[-1], args.image_res, camera_list[i]) 58 | return images 59 | 60 | 61 | def train(train_loader, encoder, decoder, refiner, optimizer, args): 62 | global i 63 | global num_rec 64 | global overall_loss 65 | global latent_loss 66 | global 
rec_loss 67 | global factor 68 | global avg 69 | global num_print 70 | j=0 71 | loss_curr = 0 72 | latent_loss_curr = 0 73 | rec_loss_curr = 0 74 | avg_curr = 0 75 | num_img = 20 76 | 77 | for image, shape in train_loader: 78 | if image.shape[0] < num_img: 79 | break 80 | 81 | for iteration in range(1): 82 | j += 1 83 | 84 | shape = shape.cuda() 85 | image = image.cuda() 86 | 87 | optimizer.zero_grad() 88 | 89 | latent = encoder(image) 90 | result = decoder(latent) 91 | result = refiner(result) 92 | 93 | img_res = 256 94 | 95 | images = torch.cuda.FloatTensor(num_img * 2, 1, img_res, img_res) #### 96 | show_images = torch.cuda.FloatTensor(num_img, 1, img_res, img_res) 97 | 98 | loss = 0 99 | 100 | rand = torch.randint(0, 24, (num_img,)) 101 | if j % 100 == 1: 102 | for i in range(num_img): 103 | cam = rand[i] 104 | images[i,0,:,:], _ = differentiable_rendering(result[i,0,:,:,:], result.shape[-1], img_res, camera_list[cam]) 105 | images[i+num_img,0,:,:], _ = differentiable_rendering(shape[i,0,:,:,:], shape.shape[-1], img_res, camera_list[cam]) 106 | if j % 100 == 1: 107 | if i % 2 == 0: 108 | show_images[int(i/2),0,:,:] = images[i,0,:,:] 109 | show_images[int(i/2) + int(num_img / 2),0,:,:] = images[i+num_img,0,:,:] 110 | if j % 100 == 1: 111 | grid = make_grid(show_images, nrow=int(num_img/2)) 112 | torchvision.utils.save_image(grid, "../result/" + args.category + "/train_" + str(num_rec) + "_" + str(j) + ".png", nrow=6, padding=2, normalize=False, range=None, scale_each=False, pad_value=0) 113 | 114 | obj_loss = 0 115 | 116 | # narrow band 117 | mask = torch.abs(result[0,0]) < 0.1 118 | mask = mask.float() 119 | 120 | # sdf loss 121 | image_loss, sdf_loss = loss_fn(images[:num_img][0,0], images[num_img:][0,0], result[0,0] * mask, 4/64., 64, 64, 64, 256, 256) 122 | obj_loss += sdf_loss / (64**3) * 0.02 123 | 124 | # laplancian loss 125 | conv_input = (result[0,0] * mask).unsqueeze(0).unsqueeze(0) 126 | conv_filter = torch.cuda.FloatTensor([[[[[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 1, 0], [1, -6, 1], [0, 1, 0]], [[0, 0, 0], [0, 1, 0], [0, 0, 0]]]]]) 127 | Lp_loss = torch.sum(F.conv3d(conv_input, conv_filter) ** 2) / (64**3) 128 | obj_loss += Lp_loss * 0.02 129 | 130 | # image loss 131 | obj_loss += F.mse_loss(images[:num_img], images[num_img:]) * 15 * (256 * 256 / img_res / img_res) 132 | 133 | # back probagate 134 | loss = obj_loss 135 | loss.backward() 136 | optimizer.step() 137 | 138 | 139 | def Train(args): 140 | save_filename = './models/{0}'.format(args.output_folder) 141 | 142 | train_dataset = Dataset(True, args.category) 143 | 144 | # Define the data loaders 145 | train_loader = torch.utils.data.DataLoader(train_dataset, 146 | batch_size=int(6), shuffle=True) 147 | 148 | encoder = EncoderSimple(args).to(args.device) 149 | decoder = DecoderSimple(args).to(args.device) 150 | refiner = Refiner(args).to(args.device) 151 | params = list(encoder.parameters()) + list(decoder.parameters()) + list(refiner.parameters()) 152 | optimizer = torch.optim.Adam(params, lr=args.lr) 153 | 154 | for epoch in range(args.num_epochs): 155 | global i 156 | global num_rec 157 | train(train_loader, encoder, decoder, refiner, optimizer, args) 158 | num_rec += 1 159 | 160 | print("======= Finished Epoch " + str(num_rec) + " =======") 161 | 162 | if num_rec % 1 == 0: 163 | with open("../models/" + args.category + '/ccencoder{0}.pt'.format(epoch + 1), 'wb') as f: 164 | torch.save(encoder.state_dict(), f) 165 | with open('../models/' + args.category + '/ccdecoder{0}.pt'.format(epoch + 1), 'wb') as f: 166 | 
torch.save(decoder.state_dict(), f) 167 | with open('../models/' + args.category + '/ccrefiner{0}.pt'.format(epoch + 1), 'wb') as f: 168 | torch.save(refiner.state_dict(), f) 169 | 170 | 171 | if __name__ == '__main__': 172 | 173 | import argparse 174 | import os 175 | import multiprocessing as mp 176 | torch.backends.cudnn.benchmark = True 177 | 178 | parser = argparse.ArgumentParser(description='SDFDiff') 179 | 180 | # Optimization 181 | parser.add_argument('--batch-size', type=int, default=300, #500,#int(6678 / 14), #int(6678/14), 182 | help='batch size (default: 128)') 183 | parser.add_argument('--num-epochs', type=int, default=50000, 184 | help='number of epochs (default: 500)') 185 | parser.add_argument('--lr', type=float, default=3e-4, 186 | help='learning rate for Adam optimizer (default: 2e-4)') 187 | parser.add_argument('--beta', type=float, default=0.1, 188 | help='contribution of commitment loss, between 0.1 and 2.0 (default: 1.0)') 189 | 190 | # Miscellaneous 191 | parser.add_argument('--output-folder', type=str, default='vqvae', 192 | help='name of the output folder (default: vqvae)') 193 | parser.add_argument('--num-workers', type=int, default=mp.cpu_count() - 1, 194 | help='number of workers for trajectories sampling (default: {0})'.format(mp.cpu_count() - 1)) 195 | parser.add_argument('--device', type=str, default='cuda', 196 | help='set the device (cpu or cuda, default: cpu)') 197 | parser.add_argument('--sdf-res', type=str, default=32, 198 | help='SDF resolution') 199 | parser.add_argument('--image-res', type=str, default=64, 200 | help='image resolution') 201 | parser.add_argument('--dataset-size', type=str, default=6678*10, 202 | help='the size of the dataset') 203 | parser.add_argument('--category', type=str, default='vessel', 204 | help='the category of the dataset') 205 | 206 | 207 | args = parser.parse_args() 208 | 209 | # Create logs and models folder if they don't exist 210 | if not os.path.exists('./logs'): 211 | os.makedirs('./logs') 212 | if not os.path.exists('./models'): 213 | os.makedirs('./models') 214 | # Device 215 | args.device = torch.device(args.device 216 | if torch.cuda.is_available() else 'cpu') 217 | # Slurm 218 | if 'SLURM_JOB_ID' in os.environ: 219 | args.output_folder += '-{0}'.format(os.environ['SLURM_JOB_ID']) 220 | if not os.path.exists('./models/{0}'.format(args.output_folder)): 221 | os.makedirs('./models/{0}'.format(args.output_folder)) 222 | args.steps = 0 223 | 224 | 225 | Train(args) 226 | -------------------------------------------------------------------------------- /single_view_code/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import torchvision.models as models 6 | 7 | 8 | class EncoderSimple(nn.Module): 9 | def __init__(self, args): 10 | super().__init__() 11 | 12 | self.elayer1 = torch.nn.Sequential( 13 | torch.nn.Conv2d(1, 64, 5, 2, 2), 14 | torch.nn.BatchNorm2d(64), 15 | torch.nn.ReLU(True) 16 | ) 17 | 18 | self.elayer2 = torch.nn.Sequential( 19 | torch.nn.Conv2d(64, 128, 5, 2, 2), 20 | torch.nn.BatchNorm2d(128), 21 | torch.nn.ReLU(True) 22 | ) 23 | 24 | self.elayer3 = torch.nn.Sequential( 25 | torch.nn.Conv2d(128, 256, 5, 2, 2), 26 | torch.nn.BatchNorm2d(256), 27 | torch.nn.ReLU(True) 28 | ) 29 | 30 | self.elayer4 = torch.nn.Sequential( 31 | torch.nn.Linear(256 * 64, 1024), 32 | torch.nn.ReLU(True) 33 | ) 34 | 35 | self.elayer5 = torch.nn.Sequential( 36 | torch.nn.Linear(1024, 
1024), 37 | torch.nn.ReLU(True) 38 | ) 39 | 40 | self.elayer6 = torch.nn.Sequential( 41 | torch.nn.Linear(1024, 512), 42 | torch.nn.ReLU(True) 43 | ) 44 | 45 | 46 | def forward(self, x): 47 | 48 | out = self.elayer1(x) 49 | out = self.elayer2(out) 50 | out = self.elayer3(out) 51 | out = out.view(out.size(0), -1) 52 | out = self.elayer4(out) 53 | out = self.elayer5(out) 54 | out = self.elayer6(out) 55 | 56 | return out 57 | 58 | 59 | class DecoderSimple(nn.Module): 60 | def __init__(self, args): 61 | super().__init__() 62 | 63 | self.dlayer1 = torch.nn.Sequential( 64 | torch.nn.Linear(512, 1024), 65 | torch.nn.ReLU(True) 66 | ) 67 | 68 | self.dlayer2 = torch.nn.Sequential( 69 | torch.nn.Linear(1024, 1024), 70 | torch.nn.ReLU(True) 71 | ) 72 | 73 | self.dlayer3 = torch.nn.Sequential( 74 | torch.nn.Linear(1024, 32 * 4 * 4 * 4), 75 | torch.nn.ReLU(True) 76 | ) 77 | 78 | self.dlayer4 = torch.nn.Sequential( 79 | torch.nn.Linear(32 * 4 * 4 * 4, 64 * 8 * 8 * 8), 80 | torch.nn.ReLU(True) 81 | ) 82 | 83 | self.dlayer5_1 = torch.nn.Sequential( 84 | torch.nn.ConvTranspose3d(64, 32, kernel_size=3, stride=1, padding=1), 85 | torch.nn.BatchNorm3d(32), 86 | torch.nn.ReLU() 87 | 88 | ) 89 | 90 | self.dlayer5_2 = torch.nn.Sequential( 91 | torch.nn.ConvTranspose3d(64 + 32, 32, kernel_size=3, stride=1, padding=1), 92 | torch.nn.BatchNorm3d(32), 93 | torch.nn.ReLU() 94 | ) 95 | 96 | self.dlayer5_3 = torch.nn.Sequential( 97 | torch.nn.ConvTranspose3d(64 + 32 * 2, 32, kernel_size=3, stride=1, padding=1), 98 | torch.nn.BatchNorm3d(32), 99 | torch.nn.ReLU() 100 | ) 101 | 102 | self.dlayer5_4 = torch.nn.Sequential( 103 | torch.nn.ConvTranspose3d(64 + 32 * 3, 64, kernel_size=1, stride=1, padding=0), 104 | torch.nn.BatchNorm3d(64), 105 | torch.nn.ReLU() 106 | ) 107 | 108 | self.dlayer5 = torch.nn.Sequential( 109 | torch.nn.ConvTranspose3d(64, 32, kernel_size=4, stride=2, padding=(1, 1, 1)), 110 | torch.nn.BatchNorm3d(32), 111 | torch.nn.ReLU(), 112 | ) 113 | 114 | self.dlayer6_1 = torch.nn.Sequential( 115 | torch.nn.ConvTranspose3d(32, 16, kernel_size=3, stride=1, padding=1), 116 | torch.nn.BatchNorm3d(16), 117 | torch.nn.ReLU() 118 | ) 119 | 120 | self.dlayer6_2 = torch.nn.Sequential( 121 | torch.nn.ConvTranspose3d(32 + 16, 16, kernel_size=3, stride=1, padding=1), 122 | torch.nn.BatchNorm3d(16), 123 | torch.nn.ReLU() 124 | ) 125 | 126 | self.dlayer6_3 = torch.nn.Sequential( 127 | torch.nn.ConvTranspose3d(32 + 16 * 2, 16, kernel_size=3, stride=1, padding=1), 128 | torch.nn.BatchNorm3d(16), 129 | torch.nn.ReLU() 130 | ) 131 | 132 | self.dlayer6_4 = torch.nn.Sequential( 133 | torch.nn.ConvTranspose3d(32 + 16 * 3, 32, kernel_size=1, stride=1, padding=0), 134 | torch.nn.BatchNorm3d(32), 135 | torch.nn.ReLU() 136 | ) 137 | 138 | self.dlayer6 = torch.nn.Sequential( 139 | torch.nn.ConvTranspose3d(32, 16, kernel_size=4, stride=2, padding=(1, 1, 1)), 140 | torch.nn.BatchNorm3d(16), 141 | torch.nn.ReLU() 142 | ) 143 | 144 | self.dlayer7_1 = torch.nn.Sequential( 145 | torch.nn.ConvTranspose3d(16, 8, kernel_size=3, stride=1, padding=1), 146 | torch.nn.BatchNorm3d(8), 147 | torch.nn.ReLU() 148 | ) 149 | 150 | self.dlayer7_2 = torch.nn.Sequential( 151 | torch.nn.ConvTranspose3d(16 + 8, 8, kernel_size=3, stride=1, padding=1), 152 | torch.nn.BatchNorm3d(8), 153 | torch.nn.ReLU() 154 | ) 155 | 156 | self.dlayer7_3 = torch.nn.Sequential( 157 | torch.nn.ConvTranspose3d(16 + 8 * 2, 8, kernel_size=3, stride=1, padding=1), 158 | torch.nn.BatchNorm3d(8), 159 | torch.nn.ReLU() 160 | ) 161 | 162 | self.dlayer7_4 = torch.nn.Sequential( 163 | 
torch.nn.ConvTranspose3d(16 + 8 * 3, 16, kernel_size=1, stride=1, padding=0), 164 | torch.nn.BatchNorm3d(16), 165 | torch.nn.ReLU() 166 | ) 167 | 168 | self.dlayer7 = torch.nn.Sequential( 169 | torch.nn.ConvTranspose3d(16, 8, kernel_size=3, stride=1, padding=1), 170 | torch.nn.BatchNorm3d(8), 171 | torch.nn.ReLU() 172 | ) 173 | 174 | self.dlayer8 = torch.nn.Sequential( 175 | torch.nn.ConvTranspose3d(8, 1, kernel_size=4, stride=2, padding=(1, 1, 1)), 176 | torch.nn.Sigmoid() 177 | ) 178 | 179 | 180 | def forward(self, x): 181 | 182 | out = self.dlayer1(x) 183 | out = self.dlayer2(out) 184 | out = self.dlayer3(out) 185 | out = self.dlayer4(out) 186 | out = out.view(out.size(0), 64, 8, 8, 8) 187 | out1 = self.dlayer5_1(out) 188 | out2 = self.dlayer5_2(torch.cat((out, out1), 1)) 189 | out3 = self.dlayer5_3(torch.cat((out, out1, out2), 1)) 190 | out = self.dlayer5_4(torch.cat((out, out1, out2, out3), 1)) 191 | out = self.dlayer5(out) 192 | out1 = self.dlayer6_1(out) 193 | out2 = self.dlayer6_2(torch.cat((out, out1), 1)) 194 | out3 = self.dlayer6_3(torch.cat((out, out1, out2), 1)) 195 | out = self.dlayer6_4(torch.cat((out, out1, out2, out3), 1)) 196 | out = self.dlayer6(out) 197 | out1 = self.dlayer7_1(out) 198 | out2 = self.dlayer7_2(torch.cat((out, out1), 1)) 199 | out3 = self.dlayer7_3(torch.cat((out, out1, out2), 1)) 200 | out = self.dlayer7_4(torch.cat((out, out1, out2, out3), 1)) 201 | out = self.dlayer7(out) 202 | out = self.dlayer8(out) 203 | 204 | return out * 4 - 2 205 | 206 | 207 | class Refiner(nn.Module): 208 | def __init__(self, args): 209 | super().__init__() 210 | 211 | self.rlayer1 = torch.nn.Sequential( 212 | torch.nn.Conv3d(1, 32, kernel_size=4, padding=(2, 2, 2)), 213 | torch.nn.BatchNorm3d(32), 214 | torch.nn.LeakyReLU(0.2), 215 | torch.nn.MaxPool3d(kernel_size=2) 216 | ) 217 | 218 | self.rlayer2 = torch.nn.Sequential( 219 | torch.nn.Conv3d(32, 64, kernel_size=4, padding=(2, 2, 2)), 220 | torch.nn.BatchNorm3d(64), 221 | torch.nn.LeakyReLU(0.2), 222 | torch.nn.MaxPool3d(kernel_size=2) 223 | ) 224 | 225 | self.rlayer3 = torch.nn.Sequential( 226 | torch.nn.Conv3d(64, 128, kernel_size=4, padding=(2, 2, 2)), 227 | torch.nn.BatchNorm3d(128), 228 | torch.nn.LeakyReLU(0.2), 229 | torch.nn.MaxPool3d(kernel_size=2) 230 | ) 231 | 232 | self.rlayer4 = torch.nn.Sequential( 233 | torch.nn.Conv3d(128, 256, kernel_size=4, padding=(2, 2, 2)), 234 | torch.nn.BatchNorm3d(256), 235 | torch.nn.LeakyReLU(0.2), 236 | torch.nn.MaxPool3d(kernel_size=2) 237 | ) 238 | 239 | self.rlayer5 = torch.nn.Sequential( 240 | torch.nn.Linear(8192*2, 2048), 241 | torch.nn.ReLU(True) 242 | ) 243 | 244 | self.rlayer6 = torch.nn.Sequential( 245 | torch.nn.Linear(2048, 8192*2), 246 | torch.nn.ReLU(True) 247 | ) 248 | 249 | self.rlayer7 = torch.nn.Sequential( 250 | torch.nn.ConvTranspose3d(256, 128, kernel_size=4, stride=2, padding=(1, 1, 1)), 251 | torch.nn.BatchNorm3d(128), 252 | torch.nn.ReLU() 253 | ) 254 | 255 | self.rlayer8 = torch.nn.Sequential( 256 | torch.nn.ConvTranspose3d(128, 64, kernel_size=4, stride=2, padding=(1, 1, 1)), 257 | torch.nn.BatchNorm3d(64), 258 | torch.nn.ReLU() 259 | ) 260 | 261 | self.rlayer9 = torch.nn.Sequential( 262 | torch.nn.ConvTranspose3d(64, 32, kernel_size=4, stride=2, padding=(1, 1, 1)), 263 | torch.nn.BatchNorm3d(32), 264 | torch.nn.ReLU() 265 | ) 266 | 267 | self.rlayer10 = torch.nn.Sequential( 268 | torch.nn.ConvTranspose3d(32, 1, kernel_size=4, stride=2, padding=(1, 1, 1)), 269 | torch.nn.Sigmoid() 270 | ) 271 | 272 | def forward(self, coarse_volumes): 273 | # bx1x64x64x64 
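        # Hourglass refinement with additive skip connections (shapes noted per
        # stage below): rlayer1-4 downsample the 1x64^3 input to 256x4^3,
        # rlayer5/6 pass the flattened features through a fully connected
        # bottleneck, and rlayer7-10 upsample back to 64^3 while adding the
        # matching encoder activation at each scale. The sigmoid output is
        # averaged with the coarse input and rescaled by * 4 - 2, matching the
        # scaling used by DecoderSimple.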
274 | volumes_32_l = self.rlayer1(coarse_volumes) 275 | # bx32x32x32x32 276 | volumes_16_l = self.rlayer2(volumes_32_l) 277 | # bx64x16x16x16 278 | volumes_8_l = self.rlayer3(volumes_16_l) 279 | # bx128x8x8x8 280 | volumes_4_l = self.rlayer4(volumes_8_l) 281 | # bx256x4x4x4 282 | flatten_features = self.rlayer5(volumes_4_l.view(-1, 8192 * 2)) 283 | flatten_features = self.rlayer6(flatten_features) 284 | volumes_4_r = volumes_4_l + flatten_features.view(-1, 256, 4, 4, 4) 285 | # bx256x4x4x4 286 | volumes_8_r = volumes_8_l + self.rlayer7(volumes_4_r) 287 | # bx128x8x8x8 288 | volumes_16_r = volumes_16_l + self.rlayer8(volumes_8_r) 289 | # bx64x16x16x16 290 | volumes_32_r = volumes_32_l + self.rlayer9(volumes_16_r) 291 | # bx32x32x32x32 292 | volumes_64_r = (coarse_volumes + self.rlayer10(volumes_32_r)) * 0.5 293 | # bx1x64x64x64 294 | 295 | return volumes_64_r * 4 - 2 296 | 297 | -------------------------------------------------------------------------------- /single_view_code/renderer.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | std::vector ray_matching_cuda( 5 | const at::Tensor w_h_3, 6 | const at::Tensor w_h, 7 | const at::Tensor grid, 8 | const int width, 9 | const int height, 10 | const float bounding_box_min_x, 11 | const float bounding_box_min_y, 12 | const float bounding_box_min_z, 13 | const float bounding_box_max_x, 14 | const float bounding_box_max_y, 15 | const float bounding_box_max_z, 16 | const int grid_res_x, 17 | const int grid_res_y, 18 | const int grid_res_z, 19 | const float eye_x, 20 | const float eye_y, 21 | const float eye_z); 22 | 23 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 24 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 25 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 26 | 27 | std::vector ray_matching( 28 | const at::Tensor w_h_3, 29 | const at::Tensor w_h, 30 | const at::Tensor grid, 31 | const int width, 32 | const int height, 33 | const float bounding_box_min_x, 34 | const float bounding_box_min_y, 35 | const float bounding_box_min_z, 36 | const float bounding_box_max_x, 37 | const float bounding_box_max_y, 38 | const float bounding_box_max_z, 39 | const int grid_res_x, 40 | const int grid_res_y, 41 | const int grid_res_z, 42 | const float eye_x, 43 | const float eye_y, 44 | const float eye_z) { 45 | CHECK_INPUT(w_h_3); 46 | CHECK_INPUT(w_h); 47 | CHECK_INPUT(grid); 48 | 49 | return ray_matching_cuda(w_h_3, w_h, grid, width, height, 50 | bounding_box_min_x, 51 | bounding_box_min_y, 52 | bounding_box_min_z, 53 | bounding_box_max_x, 54 | bounding_box_max_y, 55 | bounding_box_max_z, 56 | grid_res_x, 57 | grid_res_y, 58 | grid_res_z, 59 | eye_x, eye_y, eye_z); 60 | } 61 | 62 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 63 | m.def("ray_matching", &ray_matching, "Ray Matching"); 64 | } -------------------------------------------------------------------------------- /single_view_code/renderer_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #define PI 3.14159265358979323846f 7 | namespace { 8 | 9 | __device__ __forceinline__ float DegToRad(const float °) { return (deg * (PI / 180.f)); } 10 | 11 | 12 | __device__ __forceinline__ float length( 13 | const float x, 14 | const float y, 15 | const float z) { 16 | return sqrtf(powf(x, 2) + powf(y, 2) + powf(z, 2)); 17 | } 18 | 19 | // Cross product 20 | 
__device__ __forceinline__ float cross_x( 21 | const float a_x, 22 | const float a_y, 23 | const float a_z, 24 | const float b_x, 25 | const float b_y, 26 | const float b_z) { 27 | return a_y * b_z - a_z * b_y; 28 | } 29 | 30 | 31 | __device__ __forceinline__ float cross_y( 32 | const float a_x, 33 | const float a_y, 34 | const float a_z, 35 | const float b_x, 36 | const float b_y, 37 | const float b_z) { 38 | return a_z * b_x - a_x * b_z; 39 | } 40 | 41 | 42 | __device__ __forceinline__ float cross_z( 43 | const float a_x, 44 | const float a_y, 45 | const float a_z, 46 | const float b_x, 47 | const float b_y, 48 | const float b_z) { 49 | return a_x * b_y - a_y * b_x; 50 | } 51 | 52 | __global__ void GenerateRay( 53 | float* origins, 54 | float* directions, 55 | float* origin_image_distances, 56 | float* pixel_distances, 57 | const int width, 58 | const int height, 59 | const float eye_x, 60 | const float eye_y, 61 | const float eye_z) { 62 | 63 | const float at_x = 0; 64 | const float at_y = 0; 65 | const float at_z = 0; 66 | const float up_x = 0; 67 | const float up_y = 1; 68 | const float up_z = 0; 69 | 70 | // Compute camera view volume 71 | const float top = tan(DegToRad(30)); 72 | const float bottom = -top; 73 | const float right = (__int2float_rd(width) / __int2float_rd(height)) * top; 74 | const float left = -right; 75 | 76 | // Compute local base 77 | const float w_x = (eye_x - at_x) / length(eye_x - at_x, eye_y - at_y, eye_z - at_z); 78 | const float w_y = (eye_y - at_y) / length(eye_x - at_x, eye_y - at_y, eye_z - at_z); 79 | const float w_z = (eye_z - at_z) / length(eye_x - at_x, eye_y - at_y, eye_z - at_z); 80 | const float cross_up_w_x = cross_x(up_x, up_y, up_z, w_x, w_y, w_z); 81 | const float cross_up_w_y = cross_y(up_x, up_y, up_z, w_x, w_y, w_z); 82 | const float cross_up_w_z = cross_z(up_x, up_y, up_z, w_x, w_y, w_z); 83 | const float u_x = (cross_up_w_x) / length(cross_up_w_x, cross_up_w_y, cross_up_w_z); 84 | const float u_y = (cross_up_w_y) / length(cross_up_w_x, cross_up_w_y, cross_up_w_z); 85 | const float u_z = (cross_up_w_z) / length(cross_up_w_x, cross_up_w_y, cross_up_w_z); 86 | const float v_x = cross_x(w_x, w_y, w_z, u_x, u_y, u_z); 87 | const float v_y = cross_y(w_x, w_y, w_z, u_x, u_y, u_z); 88 | const float v_z = cross_z(w_x, w_y, w_z, u_x, u_y, u_z); 89 | 90 | 91 | const int pixel_index = blockIdx.x * blockDim.x + threadIdx.x; 92 | 93 | if (pixel_index < width * height) { 94 | const int x = pixel_index % width; 95 | const int y = pixel_index / width; 96 | const int i = 3 * pixel_index; 97 | 98 | // Compute point on view plane 99 | // Ray passes through the center of the pixel 100 | const float view_plane_x = left + (right - left) * (__int2float_rd(x) + 0.5) / __int2float_rd(width); 101 | const float view_plane_y = top - (top - bottom) * (__int2float_rd(y) + 0.5) / __int2float_rd(height); 102 | const float s_x = view_plane_x * u_x + view_plane_y * v_x - w_x; 103 | const float s_y = view_plane_x * u_y + view_plane_y * v_y - w_y; 104 | const float s_z = view_plane_x * u_z + view_plane_y * v_z - w_z; 105 | origins[i] = eye_x; 106 | origins[i+1] = eye_y; 107 | origins[i+2] = eye_z; 108 | 109 | 110 | directions[i] = s_x / length(s_x, s_y, s_z); 111 | directions[i+1] = s_y / length(s_x, s_y, s_z); 112 | directions[i+2] = s_z / length(s_x, s_y, s_z); 113 | 114 | origin_image_distances[pixel_index] = length(s_x, s_y, s_z); 115 | pixel_distances[pixel_index] = (right - left) / __int2float_rd(width); 116 | 117 | } 118 | } 119 | 120 | // Check if a point is inside 
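For readers more comfortable with Python, here is a hedged PyTorch sketch of what the `GenerateRay` kernel above computes per pixel: the camera sits at `eye`, looks at the origin with up = +y, and shoots one ray through the centre of every pixel of a view plane with a 30° vertical half-angle. The helper name and the batched torch formulation are mine, not part of the repo:

```python
import torch

def generate_rays(width, height, eye, fov_deg=30.0):
    # Camera basis: w points from the scene towards the eye, u and v span the view plane.
    at = torch.zeros(3)
    up = torch.tensor([0.0, 1.0, 0.0])
    w = (eye - at) / torch.norm(eye - at)
    u = torch.linalg.cross(up, w)
    u = u / torch.norm(u)
    v = torch.linalg.cross(w, u)

    # View-plane extents for a vertical half-angle of fov_deg.
    top = torch.tan(torch.deg2rad(torch.tensor(fov_deg)))
    right = top * width / height

    ys, xs = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
    plane_x = -right + 2 * right * (xs + 0.5) / width   # left .. right
    plane_y = top - 2 * top * (ys + 0.5) / height       # top .. bottom

    # Ray through the centre of each pixel: s = x*u + y*v - w, then normalize.
    s = plane_x[..., None] * u + plane_y[..., None] * v - w
    directions = s / torch.norm(s, dim=-1, keepdim=True)
    origins = eye.expand(height, width, 3)
    return origins, directions
```

Besides the origins and directions, the kernel also records the eye-to-pixel distance `length(s)` and the pixel footprint `(right - left) / width`; the `Intersect` kernel shown earlier uses their ratio, scaled by the marched depth, as its stopping tolerance.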
121 | __device__ __forceinline__ bool InsideBoundingBox( 122 | const float p_x, 123 | const float p_y, 124 | const float p_z, 125 | const float bounding_box_min_x, 126 | const float bounding_box_min_y, 127 | const float bounding_box_min_z, 128 | const float bounding_box_max_x, 129 | const float bounding_box_max_y, 130 | const float bounding_box_max_z) { 131 | 132 | return (p_x >= bounding_box_min_x) && (p_x <= bounding_box_max_x) && 133 | (p_y >= bounding_box_min_y) && (p_y <= bounding_box_max_y) && 134 | (p_z >= bounding_box_min_z) && (p_z <= bounding_box_max_z); 135 | } 136 | 137 | // Compute the distance along the ray between the point and the bounding box 138 | __device__ float Distance( 139 | const float reached_point_x, 140 | const float reached_point_y, 141 | const float reached_point_z, 142 | float direction_x, 143 | float direction_y, 144 | float direction_z, 145 | const float bounding_box_min_x, 146 | const float bounding_box_min_y, 147 | const float bounding_box_min_z, 148 | const float bounding_box_max_x, 149 | const float bounding_box_max_y, 150 | const float bounding_box_max_z) { 151 | 152 | float dist = -1.f; 153 | direction_x = direction_x / length(direction_x, direction_y, direction_z); 154 | direction_y = direction_y / length(direction_x, direction_y, direction_z); 155 | direction_z = direction_z / length(direction_x, direction_y, direction_z); 156 | 157 | // For each axis count any excess distance outside box extents 158 | float v = reached_point_x; 159 | float d = direction_x; 160 | if (dist == -1) { 161 | if ((v < bounding_box_min_x) && (d > 0)) { dist = (bounding_box_min_x - v) / d; } 162 | if ((v > bounding_box_max_x) && (d < 0)) { dist = (bounding_box_max_x - v) / d; } 163 | } else { 164 | if ((v < bounding_box_min_x) && (d > 0)) { dist = fmaxf(dist, (bounding_box_min_x - v) / d); } 165 | if ((v > bounding_box_max_x) && (d < 0)) { dist = fmaxf(dist, (bounding_box_max_x - v) / d); } 166 | } 167 | 168 | v = reached_point_y; 169 | d = direction_y; 170 | if (dist == -1) { 171 | if ((v < bounding_box_min_y) && (d > 0)) { dist = (bounding_box_min_y - v) / d; } 172 | if ((v > bounding_box_max_y) && (d < 0)) { dist = (bounding_box_max_y - v) / d; } 173 | } else { 174 | if ((v < bounding_box_min_y) && (d > 0)) { dist = fmaxf(dist, (bounding_box_min_y - v) / d); } 175 | if ((v > bounding_box_max_y) && (d < 0)) { dist = fmaxf(dist, (bounding_box_max_y - v) / d); } 176 | } 177 | 178 | v = reached_point_z; 179 | d = direction_z; 180 | if (dist == -1) { 181 | if ((v < bounding_box_min_z) && (d > 0)) { dist = (bounding_box_min_z - v) / d; } 182 | if ((v > bounding_box_max_z) && (d < 0)) { dist = (bounding_box_max_z - v) / d; } 183 | } else { 184 | if ((v < bounding_box_min_z) && (d > 0)) { dist = fmaxf(dist, (bounding_box_min_z - v) / d); } 185 | if ((v > bounding_box_max_z) && (d < 0)) { dist = fmaxf(dist, (bounding_box_max_z - v) / d); } 186 | } 187 | 188 | return dist; 189 | } 190 | 191 | __device__ __forceinline__ int flat(float const x, float const y, float const z, 192 | int const grid_res_x, int const grid_res_y, int const grid_res_z) { 193 | return __int2float_rd(z) + __int2float_rd(y) * grid_res_z + __int2float_rd(x) * grid_res_z * grid_res_y; 194 | } 195 | 196 | // Get the signed distance value at the specific point 197 | __device__ float ValueAt( 198 | const float* grid, 199 | const float reached_point_x, 200 | const float reached_point_y, 201 | const float reached_point_z, 202 | const float direction_x, 203 | const float direction_y, 204 | const float direction_z, 205 
| const float bounding_box_min_x,
206 | const float bounding_box_min_y,
207 | const float bounding_box_min_z,
208 | const float bounding_box_max_x,
209 | const float bounding_box_max_y,
210 | const float bounding_box_max_z,
211 | const int grid_res_x,
212 | const int grid_res_y,
213 | const int grid_res_z,
214 | const bool first_time) {
215 | 
216 | // Check if we are outside the bounding box
217 | if (!InsideBoundingBox(reached_point_x, reached_point_y, reached_point_z,
218 | bounding_box_min_x,
219 | bounding_box_min_y,
220 | bounding_box_min_z,
221 | bounding_box_max_x,
222 | bounding_box_max_y,
223 | bounding_box_max_z)) {
224 | 
225 | // If it is the first time, then the ray has not entered the grid
226 | if (first_time) {
227 | 
228 | return Distance(reached_point_x, reached_point_y, reached_point_z,
229 | direction_x, direction_y, direction_z,
230 | bounding_box_min_x,
231 | bounding_box_min_y,
232 | bounding_box_min_z,
233 | bounding_box_max_x,
234 | bounding_box_max_y,
235 | bounding_box_max_z) + 0.00001f;
236 | }
237 | 
238 | // Otherwise, the ray has left the grid
239 | else {
240 | return -1;
241 | }
242 | }
243 | 
244 | // Compute voxel size
245 | float voxel_size = (bounding_box_max_x - bounding_box_min_x) / (grid_res_x - 1);
246 | 
247 | // Compute the minimum point of the intersecting voxel
248 | float min_index_x = floorf((reached_point_x - bounding_box_min_x) / voxel_size);
249 | float min_index_y = floorf((reached_point_y - bounding_box_min_y) / voxel_size);
250 | float min_index_z = floorf((reached_point_z - bounding_box_min_z) / voxel_size);
251 | 
252 | // Check whether the ray intersects the vertex with the last index of the axis
253 | // If so, we should record the previous index
254 | if (min_index_x == (bounding_box_max_x - bounding_box_min_x) / voxel_size) {
255 | min_index_x = (bounding_box_max_x - bounding_box_min_x) / voxel_size - 1;
256 | }
257 | if (min_index_y == (bounding_box_max_y - bounding_box_min_y) / voxel_size) {
258 | min_index_y = (bounding_box_max_y - bounding_box_min_y) / voxel_size - 1;
259 | }
260 | if (min_index_z == (bounding_box_max_z - bounding_box_min_z) / voxel_size) {
261 | min_index_z = (bounding_box_max_z - bounding_box_min_z) / voxel_size - 1;
262 | }
263 | 
264 | // Linearly interpolate the eight corner values along the x axis
265 | const float tx = (reached_point_x - (bounding_box_min_x + min_index_x * voxel_size)) / voxel_size;
266 | const float c01 = (1.f - tx) * grid[flat(min_index_x, min_index_y, min_index_z, grid_res_x, grid_res_y, grid_res_z)]
267 | + tx * grid[flat(min_index_x+1, min_index_y, min_index_z, grid_res_x, grid_res_y, grid_res_z)];
268 | const float c23 = (1.f - tx) * grid[flat(min_index_x, min_index_y+1, min_index_z, grid_res_x, grid_res_y, grid_res_z)]
269 | + tx * grid[flat(min_index_x+1, min_index_y+1, min_index_z, grid_res_x, grid_res_y, grid_res_z)];
270 | const float c45 = (1.f - tx) * grid[flat(min_index_x, min_index_y, min_index_z+1, grid_res_x, grid_res_y, grid_res_z)]
271 | + tx * grid[flat(min_index_x+1, min_index_y, min_index_z+1, grid_res_x, grid_res_y, grid_res_z)];
272 | const float c67 = (1.f - tx) * grid[flat(min_index_x, min_index_y+1, min_index_z+1, grid_res_x, grid_res_y, grid_res_z)]
273 | + tx * grid[flat(min_index_x+1, min_index_y+1, min_index_z+1, grid_res_x, grid_res_y, grid_res_z)];
274 | 
275 | // Linearly interpolate along the y axis
276 | const float ty = (reached_point_y - (bounding_box_min_y + min_index_y * voxel_size)) / voxel_size;
277 | const float c0 = (1.f - ty) * c01 + ty * c23;
278 | const float c1 = (1.f - ty) * c45 + ty * c67;
279 | 
280 | // Return the final value interpolated along z
281 | const float tz = (reached_point_z - (bounding_box_min_z + min_index_z * voxel_size)) / voxel_size;
282 | 
283 | return (1.f - tz) * c0 + tz * c1;
284 | }
285 | 
286 | // Compute the intersection of the ray and the grid
287 | // The intersection procedure uses ray marching to check if we have an intersection with the stored surface
288 | __global__ void Intersect(
289 | const float* grid,
290 | const float* origins,
291 | const float* directions,
292 | const float* origin_image_distances,
293 | const float* pixel_distances,
294 | const float bounding_box_min_x,
295 | const float bounding_box_min_y,
296 | const float bounding_box_min_z,
297 | const float bounding_box_max_x,
298 | const float bounding_box_max_y,
299 | const float bounding_box_max_z,
300 | const int grid_res_x,
301 | const int grid_res_y,
302 | const int grid_res_z,
303 | float* voxel_position,
304 | float* intersection_pos,
305 | const int width,
306 | const int height) {
307 | 
308 | // Compute voxel size
309 | const float voxel_size = (bounding_box_max_x - bounding_box_min_x) / (grid_res_x - 1);
310 | 
311 | // Define constant values
312 | const int max_steps = 1000;
313 | bool first_time = true;
314 | float depth = 0;
315 | int gotten_result = 0;
316 | 
317 | const int pixel_index = blockIdx.x * blockDim.x + threadIdx.x;
318 | 
319 | if (pixel_index < width * height) {
320 | 
321 | const int i = 3 * pixel_index;
322 | 
323 | for (int steps = 0; steps < max_steps; steps++) {
324 | 
325 | float reached_point_x = origins[i] + depth * directions[i];
326 | float reached_point_y = origins[i+1] + depth * directions[i+1];
327 | float reached_point_z = origins[i+2] + depth * directions[i+2];
328 | 
329 | // Get the signed distance value for the point the ray reaches
330 | const float distance = ValueAt(grid, reached_point_x, reached_point_y, reached_point_z,
331 | directions[i], directions[i+1], directions[i+2],
332 | bounding_box_min_x,
333 | bounding_box_min_y,
334 | bounding_box_min_z,
335 | bounding_box_max_x,
336 | bounding_box_max_y,
337 | bounding_box_max_z,
338 | grid_res_x,
339 | grid_res_y,
340 | grid_res_z, first_time);
341 | first_time = false;
342 | 
343 | // Check if the ray is going outside the bounding box
344 | if (distance == -1) {
345 | voxel_position[i] = -1;
346 | voxel_position[i+1] = -1;
347 | voxel_position[i+2] = -1;
348 | intersection_pos[i] = -1;
349 | intersection_pos[i+1] = -1;
350 | intersection_pos[i+2] = -1;
351 | gotten_result = 1;
352 | break;
353 | }
354 | 
355 | // Check if we are close enough to the surface
356 | if (distance < pixel_distances[pixel_index] / origin_image_distances[pixel_index] * depth / 2) {
357 | //if (distance < 0.1) {
358 | 
359 | // Compute the minimum point of the intersecting voxel
360 | voxel_position[i] = floorf((reached_point_x - bounding_box_min_x) / voxel_size);
361 | voxel_position[i+1] = floorf((reached_point_y - bounding_box_min_y) / voxel_size);
362 | voxel_position[i+2] = floorf((reached_point_z - bounding_box_min_z) / voxel_size);
363 | if (voxel_position[i] == grid_res_x - 1) {
364 | voxel_position[i] = voxel_position[i] - 1;
365 | }
366 | if (voxel_position[i+1] == grid_res_y - 1) {
367 | voxel_position[i+1] = voxel_position[i+1] - 1;
368 | }
369 | if (voxel_position[i+2] == grid_res_z - 1) {
370 | voxel_position[i+2] = voxel_position[i+2] - 1;
371 | }
372 | intersection_pos[i] = reached_point_x;
373 | intersection_pos[i+1] = reached_point_y;
374 | intersection_pos[i+2] = reached_point_z;
375 | gotten_result = 1;
376 | break;
377 | }
378 | 
379 | // Increase distance
380 | depth += distance;
381 | 
382 | }
383 | 
384 | if (gotten_result == 0) {
385 | 
386 | // No intersections
387 | voxel_position[i] = -1;
388 | voxel_position[i+1] = -1;
389 | voxel_position[i+2] = -1;
390 | intersection_pos[i] = -1;
391 | intersection_pos[i+1] = -1;
392 | intersection_pos[i+2] = -1;
393 | }
394 | }
395 | }
396 | } // namespace
397 | 
398 | // Ray marching to get the first corner position of the voxel the ray intersects
399 | std::vector<at::Tensor> ray_matching_cuda(
400 | const at::Tensor w_h_3,
401 | const at::Tensor w_h,
402 | const at::Tensor grid,
403 | const int width,
404 | const int height,
405 | const float bounding_box_min_x,
406 | const float bounding_box_min_y,
407 | const float bounding_box_min_z,
408 | const float bounding_box_max_x,
409 | const float bounding_box_max_y,
410 | const float bounding_box_max_z,
411 | const int grid_res_x,
412 | const int grid_res_y,
413 | const int grid_res_z,
414 | const float eye_x,
415 | const float eye_y,
416 | const float eye_z) {
417 | 
418 | const int thread = 512;
419 | 
420 | at::Tensor origins = at::zeros_like(w_h_3);
421 | at::Tensor directions = at::zeros_like(w_h_3);
422 | at::Tensor origin_image_distances = at::zeros_like(w_h);
423 | at::Tensor pixel_distances = at::zeros_like(w_h);
424 | 
425 | GenerateRay<<<(width * height + thread - 1) / thread, thread>>>(
426 | origins.data<float>(),
427 | directions.data<float>(),
428 | origin_image_distances.data<float>(),
429 | pixel_distances.data<float>(),
430 | width,
431 | height,
432 | eye_x,
433 | eye_y,
434 | eye_z);
435 | 
436 | at::Tensor voxel_position = at::zeros_like(w_h_3);
437 | at::Tensor intersection_pos = at::zeros_like(w_h_3);
438 | 
439 | Intersect<<<(width * height + thread - 1) / thread, thread>>>(
440 | grid.data<float>(),
441 | origins.data<float>(),
442 | directions.data<float>(),
443 | origin_image_distances.data<float>(),
444 | pixel_distances.data<float>(),
445 | bounding_box_min_x,
446 | bounding_box_min_y,
447 | bounding_box_min_z,
448 | bounding_box_max_x,
449 | bounding_box_max_y,
450 | bounding_box_max_z,
451 | grid_res_x,
452 | grid_res_y,
453 | grid_res_z,
454 | voxel_position.data<float>(),
455 | intersection_pos.data<float>(),
456 | width,
457 | height);
458 | 
459 | return {intersection_pos, voxel_position, directions};
460 | }
461 | 
462 | 
463 | 
464 | 
-------------------------------------------------------------------------------- /single_view_code/setup.py: --------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 | 
4 | setup(
5 | name='renderer',
6 | ext_modules=[
7 | CUDAExtension('renderer', [
8 | 'renderer.cpp',
9 | 'renderer_kernel.cu',
10 | ]),
11 | ],
12 | cmdclass={
13 | 'build_ext': BuildExtension
14 | })
-------------------------------------------------------------------------------- /virtual_env/install_conda.sh: --------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | export CONDA_ENV_NAME=sdfdiff
4 | echo $CONDA_ENV_NAME
5 | 
6 | conda create -n $CONDA_ENV_NAME python=3.7
7 | 
8 | eval "$(conda shell.bash hook)"
9 | conda activate $CONDA_ENV_NAME
10 | 
11 | which python
12 | which pip
13 | 
14 | pip install numpy==1.17.5 torch==1.4.0 torchvision==0.5.0
15 | 
-------------------------------------------------------------------------------- /virtual_env/install_pip.sh: --------------------------------------------------------------------------------
1 | #!/usr/bin/env bash 2 | 3 | echo "Creating virtual environment" 4 | python3.7 -m venv sdfdiff 5 | echo "Activating virtual environment" 6 | 7 | source $PWD/sdfdiff/bin/activate 8 | 9 | $PWD/sdfdiff/bin/pip install numpy==1.17.5 torch==1.4.0 torchvision==0.5.0 10 | --------------------------------------------------------------------------------
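The single_view_code files listed above (renderer.cpp, renderer_kernel.cu, setup.py) compile into a Python extension named `renderer` that exposes one function, `ray_matching`. Below is a minimal sketch, not part of the repository, of how that function can be invoked once the extension has been built with `python setup.py install`; the buffer shapes, grid contents, bounding box, and camera position are illustrative assumptions, and in the actual pipeline main.py derives them from the loaded SDF grid (e.g. bunny.sdf).

``` python
# Minimal sketch of calling the compiled SDFDiff ray marcher.
# All concrete values below are assumptions chosen for illustration only.
import torch
import renderer  # CUDA extension built by single_view_code/setup.py

width, height = 256, 256   # image resolution (assumed)
grid_res = 64              # SDF grid resolution (assumed)

# Shape templates: the extension allocates its output tensors with zeros_like() on these
w_h_3 = torch.zeros(width, height, 3).cuda()
w_h = torch.zeros(width, height).cuda()

# Placeholder SDF grid; in practice this comes from an .sdf file or the network output
grid = torch.ones(grid_res, grid_res, grid_res).cuda()

intersection_pos, voxel_position, ray_directions = renderer.ray_matching(
    w_h_3, w_h, grid, width, height,
    -1.0, -1.0, -1.0,              # bounding_box_min_{x,y,z} (assumed)
    1.0, 1.0, 1.0,                 # bounding_box_max_{x,y,z} (assumed)
    grid_res, grid_res, grid_res,
    2.0, 2.0, 2.0)                 # eye_{x,y,z}, the camera position (assumed)

# intersection_pos holds the per-pixel surface hit point (-1 where the ray misses),
# voxel_position the lower corner index of the intersected voxel, and
# ray_directions the normalized viewing rays used by the downstream shading code.
```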