├── .gitignore ├── slang ├── temp.slang ├── set-curved-inside-sign.slang ├── count-triangles-per-cell.slang ├── vertex-is-continuous.slang ├── count-adj-disc.slang ├── add-triangles-per-cell.slang ├── add-adj-disc.slang ├── point-in-triangle.slang ├── seco-link-radial-feats.slang ├── link-radial-feats.slang ├── feature-interpolation.slang ├── utils.slang └── d-feature-interpolation.slang ├── utils.py ├── configs.py ├── README.md ├── train.py ├── samplers.py └── model.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.msh 2 | *.obj 3 | *.eps 4 | *.svg 5 | *.json 6 | *.drc 7 | *.npz 8 | *.h5 9 | *.gz 10 | *.png 11 | *.lock 12 | *.zip 13 | .slangpy_cache/ 14 | data/ 15 | results/ 16 | slang/.slangpy_cache/ 17 | .slangpy_cache/ 18 | __pycache__/ -------------------------------------------------------------------------------- /slang/temp.slang: -------------------------------------------------------------------------------- 1 | // printf("[%d], vertices are %d, %d, %d \n", tri_idx, v0_idx, v1_idx, v2_idx); 2 | // printf("[%d], vertex locations are %f, %f \n", v0_idx, v0.x, v0.y); 3 | 4 | // Find bounding rectangle for triangle 5 | // Note: You can surely find a better/ tigher fit than a bounding rectangle, 6 | // I was just lazy... 7 | 8 | 9 | // Vertex locations are assumed to be normalized to [0,1]^2 10 | 11 | // printf("X: %f %f %d %d \n", min_xy.x, max_xy.x, min_xy_idx.x, max_xy_idx.x); 12 | // printf("Y: %f %f %d %d \n", min_xy.y, max_xy.y, min_xy_idx.y, max_xy_idx.y); 13 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | 4 | def make_comparison_plot(XC, YC, L, res, model, sampler, path): 5 | _X, _Y = torch.linspace(XC-L/2,XC+L/2,res).cuda(), torch.linspace(YC-L/2,YC+L/2,res).cuda() 6 | XX, YY = torch.meshgrid(_X, _Y) 7 | Q = torch.stack([XX.reshape(-1), YY.reshape(-1)], dim=-1) 8 | gt = sampler(Q=Q) 9 | plt.figure(figsize=(10,5)) 10 | plt.subplot(1,2,1) 11 | plt.imshow(gt.detach().cpu().reshape(res,res,3), origin="lower", vmin=0, vmax=1) 12 | plt.title('Reference') 13 | plt.axis('off') 14 | 15 | plt.subplot(1,2,2) 16 | pred = model.forward(Q=Q) 17 | plt.imshow(pred.detach().cpu().reshape(res,res,3), origin="lower", vmin=0, vmax=1) 18 | plt.title('Ours') 19 | plt.axis('off') 20 | 21 | plt.savefig(path) 22 | plt.close() -------------------------------------------------------------------------------- /configs.py: -------------------------------------------------------------------------------- 1 | CONFIGS = {} 2 | 3 | CONFIGS['rendering_flowerpot'] = { 4 | 'NUM_ITERS': 2000, 5 | 'BATCH_SIZE': 2**19, 6 | 'LR': 0.01, 7 | 'FNAME': 'data/rendering/flowerpot/img.msh_curved.npz', 8 | 'SAMPLING': 'triangle', 9 | 'SAMPLING_GRID_SIZE': -1 # set to -1 for triangle sampling 10 | } 11 | 12 | CONFIGS['vg_shapes'] = { 13 | 'NUM_ITERS': 2000, 14 | 'BATCH_SIZE': 2**19, 15 | 'LR': 0.01, 16 | 'FNAME': 'data/vg/shapes/img.msh_curved.npz', 17 | 'SAMPLING': 'grid', 18 | 'SAMPLING_GRID_SIZE': 10000 19 | } 20 | 21 | CONFIGS['wos_circles'] = { 22 | 'NUM_ITERS': 2000, 23 | 'BATCH_SIZE': 2**16, 24 | 'LR': 0.01, 25 | 'FNAME': 'data/wos/circles/img.msh_curved.npz', 26 | 'SAMPLING': 'grid', 27 | 'SAMPLING_GRID_SIZE': 2000 28 | } 29 | 30 | CONFIGS['wos_overview'] = { 31 | 'NUM_ITERS': 2000, 32 | 'BATCH_SIZE': 2**16, 33 | 'LR': 0.01, 34 | 'FNAME': 'data/wos/overview/img.msh_curved.npz', 35 | 'SAMPLING': 'grid', 36 | 'SAMPLING_GRID_SIZE': 2000 37 | } -------------------------------------------------------------------------------- /slang/set-curved-inside-sign.slang: -------------------------------------------------------------------------------- 1 | import utils; 2 | // This routine finds the nearest vertex index in the CCW and CW directions (the last two arguments) that are adjacent to a given vertex (second argument) for a given triangle (first argument) that the vertex is a part of 3 | [AutoPyBindCUDA] 4 | [CUDAKernel] 5 | void run( 6 | int T_NUM_CURVE, 7 | TensorView V, // Array of vertices 8 | TensorView T, 9 | TensorView T_bez_cp_idx, 10 | TensorView T_inside_cubic_sign, 11 | ) 12 | { 13 | uint3 dispatch_id = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx(); 14 | if (dispatch_id.x >= T_NUM_CURVE) return; 15 | int tri_idx = dispatch_id.x; 16 | int3 T_curr = T[tri_idx]; 17 | int2 T_bez_curr = T_bez_cp_idx[tri_idx]; 18 | 19 | float2 v0 = V[T_curr.x], v3 = V[T_curr.y], q = V[T_curr.z]; 20 | float2 v1 = l2b(v0, V[T_bez_curr.x], V[T_bez_curr.y], v3, 1), 21 | v2 = l2b(v0, V[T_bez_curr.x], V[T_bez_curr.y], v3, 2); 22 | T_inside_cubic_sign[tri_idx] = implicit_cubic(q, v0, v1, v2, v3) > 0.0; 23 | } -------------------------------------------------------------------------------- /slang/count-triangles-per-cell.slang: -------------------------------------------------------------------------------- 1 | // First step of the pre-processing, count the number of 2 | // triangles per cell for our acceleration structure 3 | 4 | // Loop over all triangles (in parallel) and add 1 to the count 5 | // of all the pixels they could potentially overlap 6 | import utils; 7 | 8 | [AutoPyBindCUDA] 9 | [CUDAKernel] 10 | void run( 11 | int X, // Acceleration structure length in X 12 | int Y, // Acceleration structure length in Y 13 | int T_NUM, // Number of triangles 14 | TensorView V, // Array of vertices 15 | TensorView T, // Array of triangles 16 | TensorView cell_triangle_count 17 | ) 18 | { 19 | uint3 dispatch_id = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx(); 20 | 21 | if (dispatch_id.x >= T_NUM) return; 22 | 23 | int tri_idx = dispatch_id.x; 24 | // Get triangle vertices and bounding box 25 | float2 v0 = V[T[tri_idx].x], v1 = V[T[tri_idx].y], v2 = V[T[tri_idx].z]; 26 | BBox bbox = get_bounding_box(v0, v1, v2, X, Y); 27 | 28 | int oldVal; // Slang just needs this temp variable for interlocked add 29 | 30 | for (int x = bbox.min_xy_idx.x - 1; x <= bbox.max_xy_idx.x + 1; x++) { 31 | for (int y = bbox.min_xy_idx.y - 1; y <= bbox.max_xy_idx.y + 1; y++) { 32 | if (x < 0 || x >= X || y < 0 || y >= Y) continue; 33 | cell_triangle_count.InterlockedAdd(x * Y + y, 1, oldVal); 34 | } 35 | } 36 | } -------------------------------------------------------------------------------- /slang/vertex-is-continuous.slang: -------------------------------------------------------------------------------- 1 | import utils; 2 | 3 | [AutoPyBindCUDA] 4 | [CUDAKernel] 5 | void run( 6 | int T_NUM_DISC, 7 | TensorView V, // Array of vertices 8 | TensorView T, // Array of triangles 9 | TensorView V_is_continuous 10 | ) 11 | { 12 | uint3 dispatch_id = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx(); 13 | 14 | // Continuous triangle, continue 15 | if (dispatch_id.x >= T_NUM_DISC) return; 16 | int tri_idx = dispatch_id.x; 17 | // The first edge is always the discontinuous one 18 | // We don't need to care about atomics since even in a race condition the value set will be false 19 | int oldVal; 20 | V_is_continuous.InterlockedAdd(T[tri_idx].x, -1, oldVal); 21 | V_is_continuous.InterlockedAdd(T[tri_idx].y, -1, oldVal); 22 | 23 | // Vertices on domain boundary are continuous, no discontinuities allowed here 24 | // if ( 25 | // V[T[tri_idx].z].x < BEPS || V[T[tri_idx].z].y < BEPS || 26 | // V[T[tri_idx].z].x > 1.0 - BEPS || V[T[tri_idx].z].y > 1.0 - BEPS 27 | // ) { 28 | // V_is_continuous.InterlockedAdd(T[tri_idx].z, -1, oldVal); 29 | // } 30 | 31 | if ( 32 | V[T[tri_idx].z].x < BEPS || V[T[tri_idx].z].y < BEPS || 33 | V[T[tri_idx].z].x > 1.0 - BEPS || V[T[tri_idx].z].y > 1.0 - BEPS 34 | ) { 35 | V_is_continuous.InterlockedAdd(T[tri_idx].z, -1, oldVal); 36 | } 37 | // V_is_continuous[T[tri_idx].x] = false; 38 | // V_is_continuous[T[tri_idx].y] = false; 39 | } -------------------------------------------------------------------------------- /slang/count-adj-disc.slang: -------------------------------------------------------------------------------- 1 | import utils; 2 | 3 | [AutoPyBindCUDA] 4 | [CUDAKernel] 5 | void run( 6 | int T_NUM, 7 | int T_NUM_DISC, 8 | TensorView V, // Array of vertices 9 | TensorView T, // Array of triangles 10 | TensorView V_num_adj_disc_V, 11 | TensorView V_is_continuous, 12 | ) 13 | { 14 | uint3 dispatch_id = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx(); 15 | 16 | // Continuous triangle, continue 17 | if (dispatch_id.x >= T_NUM) return; 18 | int tri_idx = dispatch_id.x; 19 | // The first edge [X,Y] is always the discontinuous one 20 | float2 vx = V[T[tri_idx].x], vy = V[T[tri_idx].y], vz = V[T[tri_idx].z]; 21 | int oldVal; 22 | bool v0 = !V_is_continuous[T[tri_idx].x]; 23 | bool v1 = !V_is_continuous[T[tri_idx].y]; 24 | bool v2 = !V_is_continuous[T[tri_idx].z]; 25 | if (v0 && v1) { 26 | if (tri_idx < T_NUM_DISC || (idb(vx) && idb(vy) && dbi(vx) == dbi(vy))) { 27 | V_num_adj_disc_V.InterlockedAdd(T[tri_idx].x, 1, oldVal); 28 | V_num_adj_disc_V.InterlockedAdd(T[tri_idx].y, 1, oldVal); 29 | } 30 | } 31 | if (v1 && v2) { 32 | if (idb(vz) && idb(vy) && dbi(vz) == dbi(vy)) { 33 | V_num_adj_disc_V.InterlockedAdd(T[tri_idx].y, 1, oldVal); 34 | V_num_adj_disc_V.InterlockedAdd(T[tri_idx].z, 1, oldVal); 35 | } 36 | } 37 | if (v0 && v2) { 38 | if (idb(vz) && idb(vx) && dbi(vz) == dbi(vx)) { 39 | V_num_adj_disc_V.InterlockedAdd(T[tri_idx].z, 1, oldVal); 40 | V_num_adj_disc_V.InterlockedAdd(T[tri_idx].x, 1, oldVal); 41 | } 42 | } 43 | 44 | } -------------------------------------------------------------------------------- /slang/add-triangles-per-cell.slang: -------------------------------------------------------------------------------- 1 | import utils; 2 | 3 | [AutoPyBindCUDA] 4 | [CUDAKernel] 5 | void run( 6 | int X, // Acceleration structure length in X 7 | int Y, // Acceleration structure length in Y 8 | int T_NUM, // Number of triangles 9 | TensorView V, // Array of vertices 10 | TensorView T, // Array of triangles 11 | TensorView cell_to_triangle_index_ptr, // Pointer to next array for each cell 12 | TensorView cell_to_triangle_index, // Triangle indinces within each cell 13 | TensorView temp_cell_triangle_count // Temp counter to avoid deadlock 14 | ) 15 | { 16 | uint3 dispatch_id = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx(); 17 | 18 | if (dispatch_id.x >= T_NUM) return; 19 | 20 | int tri_idx = dispatch_id.x; 21 | // Get triangle vertices and bounding box 22 | float2 v0 = V[T[tri_idx].x], v1 = V[T[tri_idx].y], v2 = V[T[tri_idx].z]; 23 | BBox bbox = get_bounding_box(v0, v1, v2, X, Y); 24 | 25 | 26 | for (int x = bbox.min_xy_idx.x - 1; x <= bbox.max_xy_idx.x + 1; x++) { 27 | for (int y = bbox.min_xy_idx.y - 1; y <= bbox.max_xy_idx.y + 1; y++) { 28 | if (x < 0 || x >= X || y < 0 || y >= Y) continue; 29 | int oldVal; // Slang just needs this temp variable for interlocked add 30 | temp_cell_triangle_count.InterlockedAdd(x * Y + y, 1, oldVal); 31 | int start = cell_to_triangle_index_ptr[x * Y + y]; 32 | // There is guaranteed to be no race condition here 33 | cell_to_triangle_index[start + oldVal] = tri_idx; 34 | } 35 | } 36 | } -------------------------------------------------------------------------------- /slang/add-adj-disc.slang: -------------------------------------------------------------------------------- 1 | import utils; 2 | 3 | [AutoPyBindCUDA] 4 | [CUDAKernel] 5 | void run( 6 | int T_NUM, 7 | int T_NUM_DISC, 8 | TensorView V, // Array of vertices 9 | TensorView T, // Array of triangles 10 | TensorView temp_V_num_adj_disc_V, 11 | TensorView V_adj_disc_idx_ptr, 12 | TensorView V_adj_disc_idx, 13 | TensorView V_is_continuous 14 | ) 15 | { 16 | uint3 dispatch_id = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx(); 17 | 18 | // Continuous triangle, continue 19 | if (dispatch_id.x >= T_NUM) return; 20 | int tri_idx = dispatch_id.x; 21 | // The first edge [X,Y] is always the discontinuous one 22 | int x = T[tri_idx].x, y = T[tri_idx].y, z = T[tri_idx].z; 23 | 24 | float2 vx = V[T[tri_idx].x], vy = V[T[tri_idx].y], vz = V[T[tri_idx].z]; 25 | int oldVal; 26 | bool v0 = !V_is_continuous[T[tri_idx].x]; 27 | bool v1 = !V_is_continuous[T[tri_idx].y]; 28 | bool v2 = !V_is_continuous[T[tri_idx].z]; 29 | if (v0 && v1) { 30 | if (tri_idx < T_NUM_DISC || (idb(vx) && idb(vy) && dbi(vx) == dbi(vy))) { 31 | temp_V_num_adj_disc_V.InterlockedAdd(x, 1, oldVal); 32 | V_adj_disc_idx[V_adj_disc_idx_ptr[x] + oldVal] = y; 33 | temp_V_num_adj_disc_V.InterlockedAdd(y, 1, oldVal); 34 | V_adj_disc_idx[V_adj_disc_idx_ptr[y] + oldVal] = x; 35 | } 36 | 37 | } 38 | if (v1 && v2) { 39 | if (idb(vz) && idb(vy) && dbi(vz) == dbi(vy)) { 40 | temp_V_num_adj_disc_V.InterlockedAdd(y, 1, oldVal); 41 | V_adj_disc_idx[V_adj_disc_idx_ptr[y] + oldVal] = z; 42 | temp_V_num_adj_disc_V.InterlockedAdd(z, 1, oldVal); 43 | V_adj_disc_idx[V_adj_disc_idx_ptr[z] + oldVal] = y; 44 | } 45 | } 46 | if (v0 && v2) { 47 | if (idb(vz) && idb(vx) && dbi(vz) == dbi(vx)) { 48 | temp_V_num_adj_disc_V.InterlockedAdd(z, 1, oldVal); 49 | V_adj_disc_idx[V_adj_disc_idx_ptr[z] + oldVal] = x; 50 | temp_V_num_adj_disc_V.InterlockedAdd(x, 1, oldVal); 51 | V_adj_disc_idx[V_adj_disc_idx_ptr[x] + oldVal] = z; 52 | } 53 | } 54 | } -------------------------------------------------------------------------------- /slang/point-in-triangle.slang: -------------------------------------------------------------------------------- 1 | import utils; 2 | 3 | [AutoPyBindCUDA] 4 | [CUDAKernel] 5 | void run( 6 | int X, // Acceleration structure length in X 7 | int Y, // Acceleration structure length in Y 8 | TensorView Q, // Array of query point 9 | TensorView V, // Array of vertices 10 | TensorView T, // Array of triangles 11 | TensorView cell_to_triangle_index, 12 | TensorView cell_to_triangle_index_ptr, 13 | TensorView QT_idx, // Triangle index for Query point 14 | TensorView QT_uv, // Triangle index for Query point 15 | ) 16 | { 17 | uint3 dispatch_id = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx(); 18 | int qi = dispatch_id.x; 19 | if (qi >= Q.size(0)) return; 20 | 21 | int NUM_CELLS = cell_to_triangle_index_ptr.size(0); 22 | 23 | float2 q = Q[qi]; // Current query point 24 | q.x = max(min(q.x, 1.0 - BEPS), BEPS); 25 | q.y = max(min(q.y, 1.0 - BEPS), BEPS); 26 | 27 | // // First find the cell that the query point belongs to 28 | int cell_idx = int(q.x * X) * Y + int(q.y * Y); 29 | cell_idx = min(max(cell_idx, 0), NUM_CELLS - 1); 30 | 31 | 32 | // Now linearly search through all the triangles that belong to this cell. For our use case, there should only be a single triangle per cell 33 | int s = cell_to_triangle_index_ptr[cell_idx]; 34 | int e = cell_to_triangle_index_ptr[cell_idx+1]; 35 | for (int i = s; i < e; i++) { 36 | int ti = cell_to_triangle_index[i]; 37 | // printf("Query idx %d, Cell idx %d Triangle idx %d\n", qi, cell_idx, ti); 38 | // if (qi == 0) 39 | // printf("Query %f %f, Triangle V0 %f %f, V1 %f %f, V2 %f %f\n", q.x, q.y, V[T[ti].x].x, V[T[ti].x].y, V[T[ti].y].x, V[T[ti].y].y, V[T[ti].z].x, V[T[ti].z].y); 40 | if (is_inside_triangle(q, V[T[ti].x], V[T[ti].y], V[T[ti].z])) { 41 | // printf("Inside: Query idx %d, Cell idx %d Triangle idx %d\n", qi, cell_idx, ti); 42 | QT_idx[qi] = ti; 43 | QT_uv[qi] = get_barycentrics(q, V[T[ti].x], V[T[ti].y], V[T[ti].z]); 44 | break; 45 | } 46 | } 47 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Implementation of code for Discontinuity-Aware 2D Neural Fields (SIGGRAPH Asia 2023, Transactions on Graphics) [website](https://yashbelhe.github.io/danf/index.html), [paper](https://yashbelhe.github.io/danf/DiscontinuityAwareNeuralFields_SigAsia2023.pdf), [slides PDF](https://yashbelhe.github.io/danf/index.html). 2 | ### Authors: Yash Belhe, Michael Gharbi, Matt Fisher, Iliyan Georgiev, Ravi Ramamoorthi, Tzu-Mao Li 3 | 4 | This code is a re-implementation of the paper using SLANG and PyTorch. As such, the results may not exactly match the original implementation. 5 | 6 | Using this code, you can (qualitatively) reproduce a few examples from the paper: 7 | 1. Flowerpot scene (Fig. 9) -- Rendering 8 | 2. Circles scene (Fig. 10) -- Walk on Spheres 9 | 3. Shapes scene (Fig. 2) -- Vector Graphics 10 | 11 | ## Setup 12 | Install pytorch and [diffvg](https://github.com/BachiLi/diffvg) (with python bindings). 13 | 14 | ```pip install scikit-image numpy matplotlib slangpy svgpathtools pillow``` 15 | 16 | ## Data 17 | Download the data from [here](https://drive.google.com/drive/folders/1IOt6_cjE67gquhSm41275n-wIbpgnyWx?usp=sharing) and place it in the root directory. 18 | 19 | ## Run 20 | To run the circles scene: `python train.py circles` and similarly for `shapes, flowerpot`. 21 | Results will be generated in the results directory. 22 | 23 | Note: the first time you run this, there might be some delay (2-3 mins) while SLANG compiles some kernels. 24 | 25 | ### Notable missing components: 26 | 1. Mesh compression using draco. 27 | 2. Data preparation for custom scenes, a> modified version of TriWild and b> edge extraction for rendering scenes. 28 | 29 | ``` 30 | @article{Belhe:2023:DiscontinuityAwareNeuralFields, 31 | author = {Yash Belhe and Micha\"{e}l Gharbi and Matthew Fisher and Iliyan Georgiev and Ravi Ramamoorthi and Tzu-Mao Li}, 32 | title = {Discontinuity-aware 2D neural fields}, 33 | journal = {ACM Transactions on Graphics (Proceedings of SIGGRAPH Asia)}, 34 | year = {2023}, 35 | volume = {41}, 36 | number = {6}, 37 | doi = {10.1145/3550454.3555484} 38 | } 39 | ``` -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | import model 6 | import samplers 7 | import os 8 | import utils 9 | import configs 10 | import argparse 11 | 12 | parser = argparse.ArgumentParser(description='Optional app description') 13 | parser.add_argument('expt_name', type=str, help='Name of the experiment') 14 | args = parser.parse_args() 15 | 16 | torch.manual_seed(30) 17 | 18 | if args.expt_name == "flowerpot": 19 | config = configs.CONFIGS['rendering_flowerpot'] 20 | elif args.expt_name == "shapes": 21 | config = configs.CONFIGS['vg_shapes'] 22 | elif args.expt_name == "overview": 23 | config = configs.CONFIGS['wos_overview'] 24 | elif args.expt_name == "circles": 25 | config = configs.CONFIGS['wos_circles'] 26 | else: 27 | assert False, "unsupported scene" 28 | 29 | SAVE_INT = 100 30 | FEATURE_DIM = 5 31 | ACCEL_GRID_DIMS = (2000, 2000) 32 | LR = 0.01 33 | BETAS = (0.9, 0.99) 34 | NUM_QUERIES_SQRT = config['SAMPLING_GRID_SIZE'] 35 | NUM_ITERS = config['NUM_ITERS'] 36 | BATCH_SIZE = config['BATCH_SIZE'] 37 | 38 | fname = config['FNAME'] 39 | results_dir = os.path.join('results', *fname.split('/')[1:3]) 40 | if not os.path.exists(results_dir): 41 | os.makedirs(results_dir) 42 | 43 | 44 | app = fname.split('/')[1] 45 | 46 | sampler = samplers.BaseSampler() 47 | if app == "vg": 48 | sampler = samplers.VGSampler(fname) 49 | elif app == "rendering": 50 | sampler = samplers.RenderingSampler(fname) 51 | elif app == "wos": 52 | sampler = samplers.WoSSampler(fname) 53 | 54 | mesh = np.load(fname) 55 | model = model.DANN(mesh=mesh, FEATURE_DIM=FEATURE_DIM, ACCEL_GRID_DIMS=ACCEL_GRID_DIMS) 56 | optim = torch.optim.Adam(model.parameters(), lr=LR, betas=BETAS) 57 | 58 | iter_count = 0 59 | for idx in range(NUM_ITERS): 60 | if config['SAMPLING'] == 'triangle': 61 | Q = samplers.get_stratified_in_triangles(model=model) 62 | elif config['SAMPLING'] == 'grid': 63 | Q = samplers.get_stratified_random(NUM_QUERIES_SQRT) 64 | else: 65 | assert False 66 | randperm = torch.randperm(Q.shape[0]) 67 | Q = Q[randperm] 68 | 69 | gt = sampler(Q) 70 | 71 | for batch_idx in range((Q.shape[0] + BATCH_SIZE - 1) // BATCH_SIZE): 72 | Qb = Q[batch_idx*BATCH_SIZE: (batch_idx+1)*BATCH_SIZE] 73 | gtb = gt[batch_idx*BATCH_SIZE: (batch_idx+1)*BATCH_SIZE] 74 | 75 | predb = model(Q=Qb) 76 | if app == "rendering": 77 | loss = ((predb-gtb).square()/(predb.detach().square() + 0.01)).mean() # rendering 78 | else: 79 | loss = (predb-gtb).square().mean() # vg, wos 80 | 81 | loss.backward() 82 | optim.step() 83 | optim.zero_grad() 84 | 85 | print(f"Iter {iter_count}, Loss: {loss.item()}") 86 | 87 | if iter_count % SAVE_INT == 0: 88 | torch.save({'model': model.state_dict()}, os.path.join(results_dir, f"checkpoint_{iter_count}.pth")) 89 | locs = [(0.5, 0.5, 1)] 90 | # locs = [(0.5, 0.5, 1), (0.57695, 0.34957, 0.006), (0.513, 0.4012, 0.005), (0.57195, 0.34757, 0.1)] 91 | for loc_idx, loc in enumerate(locs): 92 | utils.make_comparison_plot(loc[0], loc[1], loc[2], 500, model, sampler, os.path.join(results_dir, f"output_{loc_idx}_{iter_count}.png")) 93 | 94 | iter_count += 1 95 | if iter_count > NUM_ITERS: 96 | break 97 | -------------------------------------------------------------------------------- /slang/seco-link-radial-feats.slang: -------------------------------------------------------------------------------- 1 | import utils; 2 | 3 | int next_vertex(int3 tri, int v_idx) { 4 | if (v_idx == 0) return tri.y; 5 | else if (v_idx == 1) return tri.z; 6 | else return tri.x; 7 | } 8 | 9 | int previous_vertex(int3 tri, int v_idx) { 10 | if (v_idx == 0) return tri.z; 11 | else if (v_idx == 1) return tri.x; 12 | else return tri.y; 13 | } 14 | 15 | int get_kth_adj_for_nth_elem(TensorView data, TensorView data_ptr, int k, int n) { 16 | return data[data_ptr[n] + k]; 17 | } 18 | 19 | // This routine finds the nearest vertex index in the CCW and CW directions (the last two arguments) that are adjacent to a given vertex (second argument) for a given triangle (first argument) that the vertex is a part of 20 | [AutoPyBindCUDA] 21 | [CUDAKernel] 22 | void run( 23 | int T_NUM_CURVE, 24 | TensorView V, // Array of vertices 25 | TensorView T, // Array of triangles 26 | TensorView V_is_continuous, 27 | TensorView V_adj_disc_idx, 28 | TensorView V_adj_disc_idx_ptr, 29 | TensorView V_num_adj_disc_V, 30 | TensorView seco_T_adj_disc_feat_idx) 31 | { 32 | uint3 dispatch_id = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx(); 33 | 34 | if (dispatch_id.x >= T_NUM_CURVE) return; 35 | int tri_idx = dispatch_id.x; 36 | int v_idx = -1; 37 | 38 | // We only need to do this for the first and second vertex along which the curved edge lies 39 | for (int k = 0; k < 2; k++) { 40 | if (k == 0) v_idx = T[tri_idx].x; 41 | else if (k == 1) v_idx = T[tri_idx].y; 42 | else v_idx = T[tri_idx].z; 43 | 44 | // if (V_is_continuous[v_idx]) { 45 | // T_adj_disc_V[tri_idx*6 + k*2 + 0] = -2; // just to mark that it has been visited 46 | // T_adj_disc_V[tri_idx * 6 + k * 2 + 1] = -2; // just to mark that it has been visited 47 | // seco_T_adj_disc_feat_idx[tri_idx * 6 + k * 2 + 0] = -2; // just to mark that it has been visited 48 | // seco_T_adj_disc_feat_idx[tri_idx*6 + k*2 + 1] = -2; //just to mark that it has been visited 49 | // continue; 50 | // } 51 | 52 | if (k == 0) { 53 | // CW Direction 54 | // int vn_idx = next_vertex(T[tri_idx], k); 55 | int vn_idx = previous_vertex(T[tri_idx], k); 56 | float2 vn_dir = V[vn_idx] - V[v_idx]; 57 | float vn_min_angle = 10000.0; 58 | int vn_min_angle_idx = -100; 59 | int offset_min = -100; 60 | for (int i = 0; i < V_num_adj_disc_V[v_idx]; i++) { 61 | int vm_idx = get_kth_adj_for_nth_elem(V_adj_disc_idx, V_adj_disc_idx_ptr, i, v_idx); 62 | float2 vm_dir = V[vm_idx] - V[v_idx]; 63 | float curr_angle = get_angle_ccw(vn_dir, vm_dir); 64 | if (curr_angle < vn_min_angle) { 65 | vn_min_angle = curr_angle; 66 | vn_min_angle_idx = vm_idx; 67 | offset_min = i; 68 | } 69 | } 70 | seco_T_adj_disc_feat_idx[tri_idx*2 + 0] = V_adj_disc_idx_ptr[v_idx] + offset_min; // feature index for CW dir 71 | }; 72 | 73 | if (k == 1) { 74 | // CCW Direction 75 | int vp_idx = next_vertex(T[tri_idx], k); 76 | // int vp_idx = previous_vertex(T[tri_idx], v_idx); 77 | float2 vp_dir = V[vp_idx] - V[v_idx]; 78 | float vp_min_angle = 10000.0; 79 | int vp_min_angle_idx = -100; 80 | int offset_min = -100; 81 | for (int i = 0; i < V_num_adj_disc_V[v_idx]; i++) { 82 | int vm_idx = get_kth_adj_for_nth_elem(V_adj_disc_idx, V_adj_disc_idx_ptr, i, v_idx); 83 | float2 vm_dir = V[vm_idx] - V[v_idx]; 84 | float curr_angle = get_angle_ccw(vm_dir, vp_dir); 85 | if (curr_angle < vp_min_angle) { 86 | vp_min_angle = curr_angle; 87 | vp_min_angle_idx = vm_idx; 88 | offset_min = i; 89 | } 90 | } 91 | seco_T_adj_disc_feat_idx[tri_idx*2 + 1] = V_adj_disc_idx_ptr[v_idx] + offset_min; // feature index for CCW dir 92 | }; 93 | } 94 | } -------------------------------------------------------------------------------- /slang/link-radial-feats.slang: -------------------------------------------------------------------------------- 1 | import utils; 2 | 3 | int next_vertex(int3 tri, int v_idx) { 4 | if (v_idx == 0) return tri.y; 5 | else if (v_idx == 1) return tri.z; 6 | else return tri.x; 7 | } 8 | 9 | int previous_vertex(int3 tri, int v_idx) { 10 | if (v_idx == 0) return tri.z; 11 | else if (v_idx == 1) return tri.x; 12 | else return tri.y; 13 | } 14 | 15 | int get_kth_adj_for_nth_elem(TensorView data, TensorView data_ptr, int k, int n) { 16 | return data[data_ptr[n] + k]; 17 | } 18 | 19 | // This routine finds the nearest vertex index in the CCW and CW directions (the last two arguments) that are adjacent to a given vertex (second argument) for a given triangle (first argument) that the vertex is a part of 20 | [AutoPyBindCUDA] 21 | [CUDAKernel] 22 | void run( 23 | int T_NUM, 24 | TensorView V, // Array of vertices 25 | TensorView T, // Array of triangles 26 | TensorView V_is_continuous, 27 | TensorView V_adj_disc_idx, 28 | TensorView V_adj_disc_idx_ptr, 29 | TensorView V_num_adj_disc_V, 30 | TensorView T_adj_disc_V, 31 | TensorView T_adj_disc_feat_idx) 32 | { 33 | uint3 dispatch_id = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx(); 34 | 35 | if (dispatch_id.x >= T_NUM) return; 36 | int tri_idx = dispatch_id.x; 37 | int v_idx = -1; 38 | for (int k = 0; k < 3; k++) { 39 | if (k == 0) v_idx = T[tri_idx].x; 40 | else if (k == 1) v_idx = T[tri_idx].y; 41 | else v_idx = T[tri_idx].z; 42 | float2 v = V[v_idx]; 43 | 44 | if (V_is_continuous[v_idx]) { 45 | T_adj_disc_V[tri_idx*6 + k*2 + 0] = -2; // just to mark that it has been visited 46 | T_adj_disc_V[tri_idx*6 + k*2 + 1] = -2; // just to mark that it has been visited 47 | T_adj_disc_feat_idx[tri_idx*6 + k*2 + 0] = -2; //just to mark that it has been visited 48 | T_adj_disc_feat_idx[tri_idx*6 + k*2 + 1] = -2; //just to mark that it has been visited 49 | continue; 50 | } 51 | 52 | { 53 | // CW Direction 54 | int _v_idx = next_vertex(T[tri_idx], k); 55 | float2 v_dir = V[_v_idx] - v; 56 | float min_angle = 10000.0; 57 | int min_angle_idx = -100; 58 | int offset_min = -100; 59 | 60 | for (int i = 0; i < V_num_adj_disc_V[v_idx]; i++) { 61 | int vi_idx = get_kth_adj_for_nth_elem(V_adj_disc_idx, V_adj_disc_idx_ptr, i, v_idx); 62 | float2 vi_dir = V[vi_idx] - v; 63 | // Measure angle starting from v_i moving toward v in the CCW direction 64 | float curr_angle = get_angle_ccw(vi_dir, v_dir); 65 | if (curr_angle < min_angle) { 66 | min_angle = curr_angle; 67 | min_angle_idx = vi_idx; 68 | offset_min = i; 69 | } 70 | } 71 | if (v_idx == 4260 && (tri_idx == 77 || tri_idx == 7597 || tri_idx == 357)) { 72 | printf("T %d, Next V %d, CW V %d \n", tri_idx, _v_idx, min_angle_idx); 73 | } 74 | T_adj_disc_V[tri_idx*6 + k*2 + 0] = min_angle_idx; // nearest disc V idx in CW dir 75 | T_adj_disc_feat_idx[tri_idx*6 + k*2 + 0] = V_adj_disc_idx_ptr[v_idx] + offset_min; // feature index for CW dir 76 | }; 77 | 78 | { 79 | // CCW Direction 80 | int _v_idx = previous_vertex(T[tri_idx], k); 81 | float2 v_dir = V[_v_idx] - v; 82 | float min_angle = 10000.0; 83 | int min_angle_idx = -100; 84 | int offset_min = -100; 85 | for (int i = 0; i < V_num_adj_disc_V[v_idx]; i++) { 86 | int vi_idx = get_kth_adj_for_nth_elem(V_adj_disc_idx, V_adj_disc_idx_ptr, i, v_idx); 87 | float2 vi_dir = V[vi_idx] - v; 88 | // Measure angle starting from v moving toward v_i in the CCW direction 89 | float curr_angle = get_angle_ccw(v_dir, vi_dir); 90 | // if (tri_idx == 0) { 91 | // printf("Proc %d: %d, %f \n", v_idx, vi_idx, curr_angle); 92 | // } 93 | if (curr_angle < min_angle) { 94 | min_angle = curr_angle; 95 | min_angle_idx = vi_idx; 96 | offset_min = i; 97 | } 98 | } 99 | if (v_idx == 4260 && (tri_idx == 77 || tri_idx == 7597 || tri_idx == 357)) { 100 | printf("T %d, Prev V %d, CCW V %d \n", tri_idx, _v_idx, min_angle_idx); 101 | } 102 | // if (tri_idx == 0) { 103 | // printf("Final %d: %d, %f \n", v_idx, min_angle_idx, min_angle); 104 | // } 105 | T_adj_disc_V[tri_idx*6 + k*2 + 1] = min_angle_idx; // nearest disc V idx in CCW dir 106 | T_adj_disc_feat_idx[tri_idx*6 + k*2 + 1] = V_adj_disc_idx_ptr[v_idx] + offset_min; // feature index for CCW dir 107 | }; 108 | } 109 | } -------------------------------------------------------------------------------- /slang/feature-interpolation.slang: -------------------------------------------------------------------------------- 1 | import utils; 2 | #define EPS_DIV 0.000000001 3 | 4 | // This routine finds the nearest vertex index in the CCW and CW directions (the last two arguments) that are adjacent to a given vertex (second argument) for a given triangle (first argument) that the vertex is a part of 5 | [AutoPyBindCUDA] 6 | [CUDAKernel] 7 | void run( 8 | int T_NUM, 9 | int T_NUM_CURVE, 10 | int FEATURE_DIM, 11 | TensorView V, // Array of vertices 12 | TensorView T, // Array of triangles 13 | TensorView Q, // Array of query points 14 | TensorView QT_idx, // Array of triangles indices that contain the respective query points 15 | TensorView QT_uv, // Array of barycentrics for each query point 16 | TensorView V_is_continuous, 17 | TensorView T_bez_cp_idx, 18 | TensorView T_inside_cubic_sign, 19 | TensorView T_adj_disc_V, 20 | TensorView T_adj_disc_feat_idx, 21 | TensorView seco_T_adj_disc_feat_idx, 22 | TensorView V_continuous_feat, 23 | TensorView V_discontinuous_feat, 24 | TensorView V_curved_feat, 25 | TensorView interpolated_features 26 | ) 27 | { 28 | uint3 dispatch_id = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx(); 29 | 30 | if (dispatch_id.x >= Q.size(0)) return; 31 | int q_idx = dispatch_id.x; // Current query index 32 | int tri_idx = QT_idx[q_idx]; // Triangle index for the current query point 33 | int3 T_curr = T[tri_idx]; // Triangle 34 | int2 T_bez_curr = T_bez_cp_idx[tri_idx]; // Triangle 35 | float3 uv = QT_uv[q_idx]; // Barycentric 36 | 37 | int v_idx = -1; // Vertex index within the triangle 38 | float2 q = Q[q_idx]; // query location 39 | float2 v = float2(-1); // vertex location 40 | 41 | bool is_in_curved_region = false; 42 | 43 | // Is a curved triangle 44 | if (tri_idx < T_NUM_CURVE) { 45 | // Is inside the curved region inside the curved triangle 46 | float2 v0 = V[T_curr.x], v3 = V[T_curr.y]; 47 | float2 v1 = l2b(v0, V[T_bez_curr.x], V[T_bez_curr.y], v3, 1), 48 | v2 = l2b(v0, V[T_bez_curr.x], V[T_bez_curr.y], v3, 2); 49 | bool cubic_sign = implicit_cubic(q, v0, v1, v2, v3) > 0.0; 50 | if (cubic_sign != T_inside_cubic_sign[tri_idx]) { 51 | is_in_curved_region = true; 52 | } 53 | } 54 | 55 | for (int k = 0; k < 3; k++) { 56 | float uv_curr = 0.0; // barycentric coordinate for current vertex query point pair 57 | if (k == 0) { 58 | v_idx = T_curr.x; 59 | uv_curr = uv.x; 60 | } else if (k == 1) { 61 | v_idx = T_curr.y; 62 | uv_curr = uv.y; 63 | } 64 | else { 65 | v_idx = T_curr.z; 66 | uv_curr = uv.z; 67 | } 68 | v = V[v_idx]; 69 | 70 | if (is_in_curved_region) { 71 | int offset_v0 = seco_T_adj_disc_feat_idx[tri_idx * 2 + 0]; 72 | int offset_v1 = seco_T_adj_disc_feat_idx[tri_idx * 2 + 1]; 73 | 74 | for (int f = 0; f < FEATURE_DIM; f++) { 75 | if (k == 0) { 76 | interpolated_features[FEATURE_DIM*q_idx + f] += uv_curr*V_discontinuous_feat[2*FEATURE_DIM*offset_v0 + FEATURE_DIM + f]; 77 | } else if (k == 1) { 78 | interpolated_features[FEATURE_DIM*q_idx + f] += uv_curr*V_discontinuous_feat[2*FEATURE_DIM*offset_v1 + f]; 79 | } else if (k == 2) { 80 | interpolated_features[FEATURE_DIM*q_idx + f] += uv_curr*V_curved_feat[FEATURE_DIM*tri_idx + f]; 81 | } 82 | } 83 | } else { 84 | if (V_is_continuous[v_idx]) { 85 | // Continuous vertex, no interpolation needed 86 | for (int f = 0; f < FEATURE_DIM; f++) { 87 | interpolated_features[FEATURE_DIM*q_idx + f] += uv_curr * V_continuous_feat[FEATURE_DIM*v_idx + f]; 88 | } 89 | } else { 90 | // Discontinuous vertex 91 | int cw_idx = tri_idx*6 + k*2 + 0; 92 | int ccw_idx = tri_idx*6 + k*2 + 1; 93 | 94 | // Find the vector for the directed edge corresponding to the CW and CCW features 95 | float2 disc_cw_vec = V[T_adj_disc_V[cw_idx]] - v; 96 | float2 disc_ccw_vec = V[T_adj_disc_V[ccw_idx]] - v; 97 | 98 | // This angle is correct for BOTH CW AND CCW, not a bug. 99 | float disc_cw_angle = get_angle_ccw(disc_cw_vec, q - v); 100 | float disc_ccw_angle = get_angle_ccw(q - v, disc_ccw_vec); 101 | 102 | float w_cw = disc_ccw_angle / (disc_ccw_angle + disc_cw_angle + EPS_DIV); 103 | float w_ccw = 1.0 - w_cw; 104 | 105 | int offset_cw = T_adj_disc_feat_idx[cw_idx]; 106 | int offset_ccw = T_adj_disc_feat_idx[ccw_idx]; 107 | for (int f = 0; f < FEATURE_DIM; f++) { 108 | interpolated_features[FEATURE_DIM*q_idx + f] += uv_curr * ( 109 | w_cw * V_discontinuous_feat[2*FEATURE_DIM*offset_cw + f] + 110 | w_ccw * V_discontinuous_feat[2*FEATURE_DIM*offset_ccw + FEATURE_DIM + f] 111 | ); 112 | } 113 | } 114 | } 115 | } 116 | } -------------------------------------------------------------------------------- /slang/utils.slang: -------------------------------------------------------------------------------- 1 | static const float BEPS = 1e-5; 2 | // Note: Consider using double precision for positions to get better point in triangle tests. 3 | // The parameter below is quite sensitive, make sure it is set correctly for high zoom levels 4 | static const float TRI_TEST_EPS = 1e-10; 5 | // static const float TRI_TEST_EPS = 1e-8; 6 | static const float CUBIC_FACTOR = 1000.0; 7 | // static const float TRI_TEST_EPS = 1e-5; 8 | 9 | #define M_PI 3.14159265358979323846 10 | // #define ATAN_EPS 1e-3 11 | #define ATAN_EPS 1e-7 12 | // #define ATAN_EPS 1e-6 13 | // #define ATAN_EPS 1e-4 14 | 15 | double three_C_k(int k) { 16 | if (k == 0 || k == 3) return 1.0; 17 | else if (k == 1 || k == 2) return 3.0; 18 | else return 0.0 / 0.0; 19 | } 20 | 21 | float2 l2b(float2 v0, float2 v1, float2 v2, float2 v3, int idx) { 22 | if (idx == 1) { 23 | return -5.0 / 6.0 * v0 + 3.0 * v1 - 3.0 / 2.0 * v2 + 1.0 / 3.0 * v3; 24 | } else if (idx == 2) { 25 | return -5.0 / 6.0 * v3 + 3.0 * v2 - 3.0 / 2.0 * v1 + 1.0 / 3.0 * v0; 26 | } 27 | return float2(-1.0); 28 | } 29 | 30 | double l_ij(float2 p, float2 p_i, float2 p_j, int i, int j) { 31 | double x = p.x, y = p.y, x_i = p_i.x, y_i = p_i.y, x_j = p_j.x, y_j = p_j.y; 32 | return three_C_k(i) * three_C_k(j) * (x * (y_i - y_j) - y * (x_i - x_j) + x_i * y_j - x_j * y_i); 33 | } 34 | 35 | double implicit_cubic(float2 p, float2 p0, float2 p1, float2 p2, float2 p3) { 36 | double l_21 = l_ij(p, p2, p1, 2, 1); 37 | double l_10 = l_ij(p, p1, p0, 1, 0); 38 | double l_20 = l_ij(p, p2, p0, 2, 0); 39 | double l_30 = l_ij(p, p3, p0, 3, 0); 40 | double l_31 = l_ij(p, p3, p1, 3, 1); 41 | double l_32 = l_ij(p, p3, p2, 3, 2); 42 | double c = l_30 + l_21; 43 | return l_32*(c*l_10 - l_20*l_20) - l_31*(l_31*l_10 - l_20*l_30) + l_30*(l_31*l_20 - c*l_30); 44 | } 45 | 46 | 47 | // domain_boundary_idx 48 | int dbi(float2 x) { 49 | if (x.x < BEPS) return 0; 50 | else if (x.y < BEPS) return 1; 51 | else if (1.0 - x.x < BEPS) return 2; 52 | else if (1.0 - x.y < BEPS) return 4; 53 | return -1; 54 | } 55 | // is on domain boundary 56 | bool idb(float2 x) { 57 | if (x.x < BEPS) return true; 58 | else if (x.y < BEPS) return true; 59 | else if (1.0 - x.x < BEPS) return true; 60 | else if (1.0 - x.y < BEPS) return true; 61 | return false; 62 | } 63 | 64 | float get_angle_ccw(float2 v1, float2 v2) { 65 | // double2 _v1 = double2(v1); 66 | // double2 _v2 = double2(v2); 67 | // double val = atan2(_v1.x*_v2.y - _v1.y*_v2.x, dot(_v1, _v2)); 68 | float val = atan2(v1.y*v2.x - v1.x*v2.y, dot(v1, v2)); 69 | // float val = atan2(v1.x*v2.y - v1.y*v2.x, dot(v1, v2)); 70 | // float val = atan2(v1.x*v2.y - v1.y*v2.x, dot(v1, v2)); 71 | if (val >= -ATAN_EPS) { 72 | // if (val >= 0.0) { 73 | return val; 74 | } 75 | return 2.0 * M_PI + val; 76 | } 77 | 78 | 79 | struct BBox { 80 | float2 min_xy; 81 | float2 max_xy; 82 | int2 min_xy_idx; 83 | int2 max_xy_idx; 84 | } 85 | 86 | BBox get_bounding_box(float2 v0, float2 v1, float2 v2, int kX, int kY) { 87 | float2 min_xy = float2(min(min(v0.x, v1.x), v2.x), min(min(v0.y, v1.y), v2.y)); 88 | float2 max_xy = float2(max(max(v0.x, v1.x), v2.x), max(max(v0.y, v1.y), v2.y)); 89 | int2 min_xy_idx = int2(int(min_xy.x * kX), int(min_xy.y * kY)); 90 | int2 max_xy_idx = int2(int(max_xy.x * kX), int(max_xy.y * kY)); 91 | // int2 max_xy_idx = int2(int(max_xy.x * kX) + 1, int(max_xy.y * kY) + 1); 92 | BBox ret = { min_xy, max_xy, min_xy_idx, max_xy_idx }; 93 | return ret; 94 | } 95 | 96 | bool is_inside_triangle(float2 q, float2 v1, float2 v2, float2 v3) { 97 | // 2D triangle test (return 1.0 if the triangle contains the point (x,y), 0.0 otherwise). 98 | float2 e0 = v2 - v1; 99 | float2 e1 = v3 - v2; 100 | float2 e2 = v1 - v3; 101 | 102 | float2 k0 = float2(-e0.y, e0.x); 103 | float2 k1 = float2(-e1.y, e1.x); 104 | float2 k2 = float2(-e2.y, e2.x); 105 | 106 | float d0 = dot(k0, q - v1); 107 | float d1 = dot(k1, q - v2); 108 | float d2 = dot(k2, q - v3); 109 | 110 | // Check the three half-plane values. 111 | // if (d0 >= -TRI_TEST_EPS && d1 >= -TRI_TEST_EPS && d2 >= -TRI_TEST_EPS) 112 | if (d0 <= TRI_TEST_EPS && d1 <= TRI_TEST_EPS && d2 <= TRI_TEST_EPS) 113 | return true; 114 | else 115 | return false; 116 | } 117 | 118 | // Parallelogram law 119 | float area(float2 v1, float2 v2) { 120 | return v1[0] * v2[1] - v1[1] * v2[0]; 121 | } 122 | 123 | float3 clip(float3 x, float l, float u) { 124 | return float3(max(min(x.x, u), l), max(min(x.y, u), l), max(min(x.z, u), l)); 125 | } 126 | 127 | float3 sum_one(float3 x) { 128 | return x / (x.x + x.y + x.z); 129 | } 130 | 131 | // Taken from: https://github.com/postmalloc/barycuda/blob/master/src/bary.cu 132 | float3 get_barycentrics(float2 q, float2 a, float2 b, float2 c) { 133 | float2 aq = q - a; 134 | float2 bq = q - b; 135 | float2 cq = q - c; 136 | 137 | float2 ac = c - a; 138 | float2 ab = b - a; 139 | float2 ca = a - c; 140 | float2 bc = c - b; 141 | 142 | float nor = area(ab, ac); 143 | float nor0 = area(bc, bq); 144 | float nor1 = area(ca, cq); 145 | float nor2 = area(ab, aq); 146 | 147 | // return float3(nor0, nor1, nor2) / nor; 148 | return sum_one(clip(float3(nor0, nor1, nor2) / nor, 0.0, 1.0)); 149 | } -------------------------------------------------------------------------------- /slang/d-feature-interpolation.slang: -------------------------------------------------------------------------------- 1 | import utils; 2 | #define EPS_DIV 0.000000001 3 | 4 | // This routine finds the nearest vertex index in the CCW and CW directions (the last two arguments) that are adjacent to a given vertex (second argument) for a given triangle (first argument) that the vertex is a part of 5 | [AutoPyBindCUDA] 6 | [CUDAKernel] 7 | void run( 8 | int T_NUM, 9 | int T_NUM_CURVE, 10 | int FEATURE_DIM, 11 | TensorView V, // Array of vertices 12 | TensorView T, // Array of triangles 13 | TensorView Q, // Array of query points 14 | TensorView QT_idx, // Array of triangles indices that contain the respective query points 15 | TensorView QT_uv, // Array of barycentrics for each query point 16 | TensorView V_is_continuous, 17 | TensorView T_bez_cp_idx, 18 | TensorView T_inside_cubic_sign, 19 | TensorView T_adj_disc_V, 20 | TensorView T_adj_disc_feat_idx, 21 | TensorView seco_T_adj_disc_feat_idx, 22 | TensorView V_continuous_feat, 23 | TensorView V_discontinuous_feat, 24 | TensorView V_curved_feat, 25 | TensorView interpolated_features, 26 | TensorView d_V_continuous_feat, 27 | TensorView d_V_discontinuous_feat, 28 | TensorView d_V_curved_feat, 29 | TensorView d_interpolated_features 30 | ) 31 | { 32 | uint3 dispatch_id = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx(); 33 | 34 | if (dispatch_id.x >= Q.size(0)) return; 35 | int q_idx = dispatch_id.x; // Current query index 36 | int tri_idx = QT_idx[q_idx]; // Triangle index for the current query point 37 | int3 T_curr = T[tri_idx]; // Triangle 38 | int2 T_bez_curr = T_bez_cp_idx[tri_idx]; // Triangle 39 | float3 uv = QT_uv[q_idx]; // Barycentric 40 | 41 | int v_idx = -1; // Vertex index within the triangle 42 | float2 q = Q[q_idx]; // query location 43 | float2 v = float2(-1); // vertex location 44 | 45 | bool is_in_curved_region = false; 46 | 47 | // Is a curved triangle 48 | if (tri_idx < T_NUM_CURVE) { 49 | // Is inside the curved region inside the curved triangle 50 | float2 v0 = V[T_curr.x], v3 = V[T_curr.y]; 51 | float2 v1 = l2b(v0, V[T_bez_curr.x], V[T_bez_curr.y], v3, 1), 52 | v2 = l2b(v0, V[T_bez_curr.x], V[T_bez_curr.y], v3, 2); 53 | bool cubic_sign = implicit_cubic(q, v0, v1, v2, v3) > 0.0; 54 | if (cubic_sign != T_inside_cubic_sign[tri_idx]) { 55 | is_in_curved_region = true; 56 | } 57 | } 58 | 59 | for (int k = 0; k < 3; k++) { 60 | float uv_curr = 0.0; // barycentric coordinate for current vertex query point pair 61 | if (k == 0) { 62 | v_idx = T_curr.x; 63 | uv_curr = uv.x; 64 | } else if (k == 1) { 65 | v_idx = T_curr.y; 66 | uv_curr = uv.y; 67 | } 68 | else { 69 | v_idx = T_curr.z; 70 | uv_curr = uv.z; 71 | } 72 | v = V[v_idx]; 73 | 74 | if (is_in_curved_region) { 75 | int offset_v0 = seco_T_adj_disc_feat_idx[tri_idx * 2 + 0]; 76 | int offset_v1 = seco_T_adj_disc_feat_idx[tri_idx * 2 + 1]; 77 | float oldVal; 78 | 79 | for (int f = 0; f < FEATURE_DIM; f++) { 80 | if (k == 0) { 81 | d_V_discontinuous_feat.InterlockedAdd(2*FEATURE_DIM*offset_v0 + FEATURE_DIM + f, uv_curr*d_interpolated_features[FEATURE_DIM*q_idx + f], oldVal); 82 | } else if (k == 1) { 83 | d_V_discontinuous_feat.InterlockedAdd(2*FEATURE_DIM*offset_v1 + f, uv_curr*d_interpolated_features[FEATURE_DIM*q_idx + f], oldVal); 84 | } else if (k == 2) { 85 | d_V_discontinuous_feat.InterlockedAdd(FEATURE_DIM*tri_idx + f, uv_curr*d_interpolated_features[FEATURE_DIM*q_idx + f], oldVal); 86 | } 87 | } 88 | } else { 89 | if (V_is_continuous[v_idx]) { 90 | // Continuous vertex, no interpolation needed 91 | for (int f = 0; f < FEATURE_DIM; f++) { 92 | // FWD: interpolated_features[FEATURE_DIM*q_idx + f] += uv_curr * V_continuous_feat[FEATURE_DIM*v_idx + f]; 93 | float oldVal = 0.0; 94 | d_V_continuous_feat.InterlockedAdd(FEATURE_DIM*v_idx + f, uv_curr*d_interpolated_features[FEATURE_DIM*q_idx + f], oldVal); 95 | } 96 | } else { 97 | // Discontinuous vertex 98 | int cw_idx = tri_idx*6 + k*2 + 0; 99 | int ccw_idx = tri_idx*6 + k*2 + 1; 100 | 101 | // Find the vector for the directed edge corresponding to the CW and CCW features 102 | float2 disc_cw_vec = V[T_adj_disc_V[cw_idx]] - v; 103 | float2 disc_ccw_vec = V[T_adj_disc_V[ccw_idx]] - v; 104 | 105 | // This angle is correct for BOTH CW AND CCW, not a bug. 106 | float disc_cw_angle = get_angle_ccw(disc_cw_vec, q - v); 107 | float disc_ccw_angle = get_angle_ccw(q - v, disc_ccw_vec); 108 | 109 | float w_cw = disc_ccw_angle / (disc_ccw_angle + disc_cw_angle + EPS_DIV); 110 | float w_ccw = 1.0 - w_cw; 111 | 112 | int offset_cw = T_adj_disc_feat_idx[cw_idx]; 113 | int offset_ccw = T_adj_disc_feat_idx[ccw_idx]; 114 | for (int f = 0; f < FEATURE_DIM; f++) { 115 | // FWD: interpolated_features[FEATURE_DIM*q_idx + f] += uv_curr * ( 116 | // w_cw * V_discontinuous_feat[2*FEATURE_DIM*offset_cw + f] + 117 | // w_ccw * V_discontinuous_feat[2*FEATURE_DIM*offset_ccw + FEATURE_DIM + f] 118 | // ); 119 | float oldVal = 0.0; 120 | d_V_discontinuous_feat.InterlockedAdd( 121 | 2*FEATURE_DIM*offset_cw + f, 122 | uv_curr*w_cw*d_interpolated_features[FEATURE_DIM*q_idx + f], 123 | oldVal 124 | ); 125 | d_V_discontinuous_feat.InterlockedAdd( 126 | 2*FEATURE_DIM*offset_ccw + FEATURE_DIM + f, 127 | uv_curr*w_ccw*d_interpolated_features[FEATURE_DIM*q_idx + f], 128 | oldVal 129 | ); 130 | } 131 | } 132 | } 133 | } 134 | } -------------------------------------------------------------------------------- /samplers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pydiffvg 3 | import math 4 | import numpy as np 5 | from svgpathtools import svg2paths, Line, Arc 6 | from PIL import Image 7 | import os 8 | 9 | import time 10 | 11 | def get_locs_zoom(x, y, width): 12 | hw = width/2 13 | return (x - hw, x + hw, y - hw, y+ hw) 14 | 15 | def get_stratified_random(n_sqrt, device=torch.device("cuda")): 16 | x = torch.linspace(0, 1, n_sqrt, device=device) 17 | y = torch.linspace(0, 1, n_sqrt, device=device) 18 | xx, yy = torch.meshgrid(x, y) 19 | xx, yy = xx.reshape(-1), yy.reshape(-1) 20 | xx = xx + torch.rand_like(xx) / n_sqrt 21 | yy = yy + torch.rand_like(yy) / n_sqrt 22 | return torch.stack([xx, yy], axis=-1) 23 | 24 | def get_stratified_in_triangles(model, device=torch.device("cuda")): 25 | triangles = model.T 26 | vertices = model.V 27 | N = triangles.shape[0] 28 | u1 = torch.rand(N, device=device).unsqueeze(1) 29 | u2 = torch.rand(N, device=device).unsqueeze(1) 30 | sqrt_u1 = torch.sqrt(u1) 31 | tri_idxs = torch.arange(N, device=device) 32 | ABC = vertices[triangles[tri_idxs].long()] 33 | ABC = torch.tensor(ABC, device=device) 34 | res = ABC[:,0]*(1.0 - sqrt_u1) + ABC[:,1]*(1.0 - u2)*sqrt_u1 + ABC[:,2]*u2*sqrt_u1 35 | return res 36 | 37 | class BaseSampler: 38 | def __call__(self): 39 | raise NotImplementedError 40 | 41 | def plot_locs(self): 42 | return [ 43 | (0.0, 1.0, 0.0, 1.0), 44 | *[get_locs_zoom(0.4182068, 0.7723169, zoom) for zoom in [0.5, 0.2, 0.1, 0.01]] 45 | ] 46 | 47 | ################ Vector Graphics ############# 48 | 49 | class VGSampler(BaseSampler): 50 | def __init__(self, fname): 51 | pydiffvg.set_device(torch.device('cuda:0')) 52 | svg_fname = fname.split(".")[0] + '.svg' 53 | print(f"Loaded SVG file: {svg_fname}") 54 | self.w, self.h, self.sh, self.sh_grp = pydiffvg.svg_to_scene(svg_fname) 55 | 56 | def __call__(self, Q): 57 | x = Q[:,0]*self.w 58 | # x = (1.0 - Q[:,0])*self.w 59 | y = Q[:,1]*self.h 60 | xy = torch.stack([y,x], axis=-1) 61 | scene_args = pydiffvg.RenderFunction.serialize_scene(self.w, self.h, self.sh, self.sh_grp, eval_positions=xy) 62 | samples = pydiffvg.RenderFunction.apply(self.w, self.h, 0, 0, 0, None, *scene_args) 63 | return samples[:,:3] 64 | 65 | ############# Rendering ############## 66 | 67 | class RenderingSampler(BaseSampler): 68 | def __init__(self, fname): 69 | print("Start: Loading Rendering Sampler") 70 | self.S = 50000 71 | self.img = torch.load('data/rendering/flowerpot/img.pt').cuda() 72 | print("End: Loading Rendering Sampler") 73 | 74 | def __call__(self, Q): 75 | _x, _y = Q[:,0], Q[:,1] 76 | _x_idx = torch.clamp((_x * self.S).type(torch.long).cpu(), 0, self.S - 1) 77 | _y_idx = torch.clamp((_y * self.S).type(torch.long).cpu(), 0, self.S - 1) 78 | res = self.img[_x_idx, _y_idx] 79 | res = res.type(torch.float32)/255.0 80 | return res 81 | 82 | ######### Walk on Spheres ######### 83 | class WoSScene: 84 | lines = None 85 | is_line = None 86 | left_img = None 87 | right_img = None 88 | 89 | 90 | class WoSSampler(BaseSampler): 91 | def __init__(self, fname, device=torch.device('cuda')): 92 | svg_fname = fname.split(".")[0] + '.svg' 93 | base_dir = fname.split("/")[:-1] 94 | paths, attributes = svg2paths(svg_fname) 95 | segments = [] 96 | segment_line = [] 97 | self.max_walk_length = 20 98 | self.eps = 1e-3 99 | self.spp = 10 100 | # self.spp = 100 101 | W = 1000 102 | for path in paths: 103 | for subpath in path: 104 | if type(subpath) == Line: 105 | segments.append( 106 | (torch.tensor([subpath.start.real/W, subpath.start.imag/W], device=device, dtype=torch.float), 107 | torch.tensor([subpath.end.real/W, subpath.end.imag/W], device=device, dtype=torch.float))) 108 | segment_line.append(True) 109 | elif type(subpath) == Arc: 110 | segments.append((torch.tensor([subpath.center.real/W, subpath.center.imag/W], device=device, dtype=torch.float), 111 | torch.tensor(subpath.radius.real/W, device=device, dtype=torch.float))) 112 | segment_line.append(False) 113 | break 114 | else: 115 | print(type(subpath)) 116 | continue 117 | # assert False, "Walk on Spheres Loader, unsupported segment type" 118 | self.wos_scene = WoSScene() 119 | self.wos_scene.lines = segments 120 | self.wos_scene.is_line = segment_line 121 | self.wos_scene.right_img = torch.tensor(np.asarray(Image.open(os.path.join(*base_dir, 'left.png'))) 122 | [:, :, :3]/255.0, device=device,dtype=torch.float32).transpose(0, 1) 123 | self.wos_scene.left_img = torch.tensor(np.asarray(Image.open(os.path.join(*base_dir, 'right.png'))) 124 | [:, :, :3]/255.0, device=device, dtype=torch.float32).transpose(0, 1) 125 | 126 | def __call__(self, Q): 127 | XY = Q[:,[1,0]] 128 | return walk_on_spheres(XY, self.spp, self.max_walk_length, self.eps, self.wos_scene) 129 | 130 | def dot(x, y): 131 | return torch.sum(x * y, axis=1) 132 | 133 | def length(x): 134 | return torch.sqrt(torch.sum(torch.square(x), axis=1)) 135 | 136 | 137 | def line_segment(p, a, b): 138 | a = a.unsqueeze(0) 139 | b = b.unsqueeze(0) 140 | ba = b - a 141 | pa = p - a 142 | h = torch.clamp(dot(pa, ba) / (dot(ba, ba) + 1e-8), 0., 1.).unsqueeze(1) 143 | return length(pa - h * ba) 144 | 145 | def line_implicit(p, a, b): 146 | x1, y1 = a[0], a[1] 147 | x2, y2 = b[0], b[1] 148 | 149 | return (y2-y1)*p[:, 0] - (x2-x1)*p[:, 1] + x2*y1 - x1*y2 150 | 151 | 152 | def sdCircle(p, c, r): 153 | c = c.unsqueeze(0) 154 | return length(p - c) - r 155 | 156 | 157 | def walk_on_spheres(XY, spp, max_walk_length, eps, scene_obj, device=torch.device('cuda')): 158 | def scene(p, scene_obj): 159 | d = 1e-2 160 | lines = scene_obj.lines 161 | is_line = scene_obj.is_line 162 | img_w = scene_obj.left_img.shape[0] 163 | img_h = scene_obj.left_img.shape[1] 164 | 165 | def color(p, side): 166 | px = p[:, 0] 167 | py = p[:, 1] 168 | x = torch.clip(torch.round(px * img_w), 0, img_w - 1).type(torch.long) 169 | y = torch.clip(torch.round(py * img_h), 0, img_h - 1).type(torch.long) 170 | 171 | return torch.where(side.unsqueeze(1), scene_obj.left_img[x, y], scene_obj.right_img[x, y]) 172 | 173 | min_dist = 10000 + torch.zeros(p.shape[0], device=device) 174 | min_colors = torch.zeros((p.shape[0], 3), device=device) 175 | for line_idx, line in enumerate(lines): 176 | if is_line[line_idx]: 177 | curr_dist = line_segment(p, *line) 178 | assert not torch.any(torch.isnan(curr_dist)) 179 | side_line = line_implicit(p, *line) <= 0 180 | else: 181 | sd = sdCircle(p, *line) 182 | curr_dist = torch.abs(sd) 183 | assert not torch.any(torch.isnan(sd)) 184 | side_line = sd <= 0 185 | curr_color = color(p, side_line) 186 | # curr_closest_point = closest_point_on_line(p, *line) 187 | min_colors = torch.where( 188 | (curr_dist < min_dist).unsqueeze(1), curr_color, min_colors) 189 | min_dist = torch.minimum(min_dist, curr_dist) 190 | return min_dist, min_colors 191 | 192 | p = torch.repeat_interleave(XY, spp, dim=0) 193 | 194 | N = p.shape[0] 195 | 196 | final_color = torch.zeros((N, 3), dtype=torch.float32, device=device) 197 | temp_color = torch.zeros((N, 3), dtype=torch.float32, device=device) 198 | temp_dist = torch.zeros(N, dtype=torch.float32, device=device) 199 | temp_hitb = torch.zeros(N, dtype=torch.bool, device=device) 200 | active = torch.ones(N, dtype=bool, device=device) 201 | for walk_length in range(max_walk_length): 202 | if active.sum() == 0: 203 | break 204 | dist, color = scene(p[active], scene_obj) 205 | 206 | hit_b = dist < eps 207 | temp_color.zero_() 208 | temp_dist.zero_() 209 | temp_hitb.zero_() 210 | temp_color[active] = color 211 | temp_dist[active] = dist 212 | temp_hitb[active] = hit_b 213 | final_color = torch.where(torch.logical_and( 214 | active, temp_hitb).unsqueeze(1), temp_color, final_color) 215 | active[temp_hitb] = False 216 | 217 | theta = 2 * torch.pi * torch.rand(N, device=device) 218 | delta_p = temp_dist.unsqueeze(1) * torch.stack([torch.cos(theta), torch.sin(theta)], axis=1) 219 | p = p + delta_p 220 | return torch.mean(final_color.reshape(XY.shape[0], spp, 3), axis=1) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 3 | 4 | import torch 5 | import torch.nn as nn 6 | import slangpy 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | BLOCK_SIZE_1D=256 11 | def launch_1d(x, LEN): 12 | x.launchRaw( 13 | blockSize=(BLOCK_SIZE_1D, 1, 1), 14 | gridSize=(LEN // BLOCK_SIZE_1D + 1, 1, 1) 15 | ) 16 | 17 | CUDA_DEVICE = torch.device('cuda') 18 | 19 | ############# Load all slang modules ############# 20 | slang_vertex_is_continuous = slangpy.loadModule("slang/vertex-is-continuous.slang") 21 | slang_set_curved_inside_sign = slangpy.loadModule("slang/set-curved-inside-sign.slang") 22 | slang_count_adj_disc = slangpy.loadModule("slang/count-adj-disc.slang") 23 | slang_add_adj_disc = slangpy.loadModule("slang/add-adj-disc.slang") 24 | slang_link_radial_feats = slangpy.loadModule("slang/link-radial-feats.slang") 25 | slang_seco_link_radial_feats = slangpy.loadModule("slang/seco-link-radial-feats.slang") 26 | slang_count_triangles_per_cell = slangpy.loadModule("slang/count-triangles-per-cell.slang") 27 | slang_add_triangles_per_cell = slangpy.loadModule("slang/add-triangles-per-cell.slang") 28 | slang_point_in_triangle = slangpy.loadModule("slang/point-in-triangle.slang") 29 | slang_feature_interpolation = slangpy.loadModule("slang/feature-interpolation.slang") 30 | slang_d_feature_interpolation = slangpy.loadModule("slang/d-feature-interpolation.slang") 31 | 32 | class DiscontinuityAwareInterpolation(torch.autograd.Function): 33 | @staticmethod 34 | def forward(ctx, FEATURE_DIM, Q, V, T, QT_idx, QT_uv, V_is_continuous, T_adj_disc_V, T_adj_disc_feat_idx, V_continuous_feat, V_discontinuous_feat, V_curved_feat, T_bez_cp_idx, T_inside_cubic_sign, seco_T_adj_disc_feat_idx, T_NUM_CURVE): 35 | Q_NUM = Q.shape[0] 36 | T_NUM = T.shape[0] 37 | interpolated_features = torch.zeros((Q_NUM*FEATURE_DIM), dtype=torch.float, device=CUDA_DEVICE) 38 | launch_1d(slang_feature_interpolation.run( 39 | T_NUM=T_NUM, FEATURE_DIM=FEATURE_DIM, Q=Q, V=V, T=T, 40 | QT_idx=QT_idx, QT_uv=QT_uv, 41 | V_is_continuous=V_is_continuous, 42 | V_continuous_feat=V_continuous_feat, 43 | V_discontinuous_feat=V_discontinuous_feat, 44 | V_curved_feat=V_curved_feat, 45 | interpolated_features=interpolated_features, 46 | T_adj_disc_V=T_adj_disc_V, 47 | T_adj_disc_feat_idx=T_adj_disc_feat_idx, 48 | T_bez_cp_idx=T_bez_cp_idx, 49 | T_inside_cubic_sign=T_inside_cubic_sign, 50 | seco_T_adj_disc_feat_idx=seco_T_adj_disc_feat_idx, 51 | T_NUM_CURVE=T_NUM_CURVE 52 | ), LEN=Q_NUM) 53 | interpolated_features = interpolated_features.reshape(Q_NUM, FEATURE_DIM) 54 | ctx.save_for_backward(Q, V, T, QT_idx, QT_uv, V_is_continuous, V_continuous_feat.data, V_discontinuous_feat.data, V_curved_feat.data, T_adj_disc_V, T_adj_disc_feat_idx, interpolated_features, T_bez_cp_idx, T_inside_cubic_sign, seco_T_adj_disc_feat_idx, torch.tensor(T_NUM_CURVE)) 55 | return interpolated_features 56 | 57 | @staticmethod 58 | def backward(ctx, grad_output): 59 | (Q, V, T, QT_idx, QT_uv, V_is_continuous, V_continuous_feat, V_discontinuous_feat, V_curved_feat, T_adj_disc_V, T_adj_disc_feat_idx, interpolated_features, T_bez_cp_idx, T_inside_cubic_sign, seco_T_adj_disc_feat_idx, T_NUM_CURVE) = ctx.saved_tensors 60 | 61 | FEATURE_DIM = interpolated_features.shape[-1] 62 | Q_NUM = Q.shape[0] 63 | T_NUM = T.shape[0] 64 | 65 | d_V_continuous_feat = torch.zeros_like(V_continuous_feat) 66 | d_V_discontinuous_feat = torch.zeros_like(V_discontinuous_feat) 67 | d_V_curved_feat = torch.zeros_like(V_curved_feat) 68 | 69 | d_interpolated_features = grad_output.reshape(-1) 70 | interpolated_features = interpolated_features.reshape(-1) 71 | launch_1d(slang_d_feature_interpolation.run( 72 | T_NUM=T_NUM, FEATURE_DIM=FEATURE_DIM, Q=Q, V=V, T=T, 73 | QT_idx=QT_idx, QT_uv=QT_uv, 74 | V_is_continuous=V_is_continuous, 75 | V_continuous_feat=V_continuous_feat, 76 | V_discontinuous_feat=V_discontinuous_feat, 77 | V_curved_feat=V_curved_feat, 78 | interpolated_features=interpolated_features, 79 | T_adj_disc_V=T_adj_disc_V, 80 | T_adj_disc_feat_idx=T_adj_disc_feat_idx, 81 | T_bez_cp_idx=T_bez_cp_idx, 82 | T_inside_cubic_sign=T_inside_cubic_sign, 83 | seco_T_adj_disc_feat_idx=seco_T_adj_disc_feat_idx, 84 | T_NUM_CURVE=T_NUM_CURVE, 85 | d_V_continuous_feat=d_V_continuous_feat, 86 | d_V_discontinuous_feat=d_V_discontinuous_feat, 87 | d_V_curved_feat=d_V_curved_feat, 88 | d_interpolated_features=d_interpolated_features 89 | ), LEN=Q_NUM) 90 | 91 | return tuple([None for _ in range(9)]) + (d_V_continuous_feat, d_V_discontinuous_feat, d_V_curved_feat) + tuple([None for _ in range(4)]) 92 | 93 | class DANN(torch.nn.Module): 94 | def __init__(self, mesh, FEATURE_DIM=5, ACCEL_GRID_DIMS=(100, 100), USE_PE=False, OUT_DIM=3, INFERENCE=False): 95 | super(DANN, self).__init__() 96 | T_continuous = mesh['continuous_triangles'] 97 | T_linear = mesh['linear_triangles'] 98 | T_curve = mesh['curved_triangles'] 99 | 100 | self.T_NUM_CONTINUOUS = len(T_continuous) 101 | self.T_NUM_LINEAR = len(T_linear) 102 | self.T_NUM_CURVE = len(T_curve) 103 | 104 | self.T_NUM = self.T_NUM_CURVE + self.T_NUM_LINEAR + self.T_NUM_CONTINUOUS 105 | self.T_NUM_DISC = self.T_NUM_LINEAR + self.T_NUM_CURVE 106 | 107 | self.T = [] 108 | T_bez_cp_idx = [] 109 | if self.T_NUM_CURVE > 0: 110 | self.T.append(T_curve[:,:3]) 111 | T_bez_cp_idx = T_curve[:,3:5] 112 | if self.T_NUM_LINEAR > 0: 113 | self.T.append(T_linear[:,:3]) 114 | if self.T_NUM_CONTINUOUS > 0: 115 | self.T.append(T_continuous) 116 | 117 | self.T = torch.tensor(np.concatenate(self.T)).cuda().contiguous() 118 | self.V = torch.tensor(mesh['vertices']).cuda().contiguous() 119 | 120 | V_NUM = len(self.V) 121 | 122 | # if INFERENCE: 123 | # # In inference mode, all the other necessary tensors are already saved in the checkpoint 124 | # return 125 | 126 | # For the curved triangles, store the vertex indices for the bezier control points 127 | if self.T_NUM_CURVE > 0: 128 | self.register_buffer("T_bez_cp_idx", torch.tensor(T_bez_cp_idx, dtype=torch.int32).cuda().contiguous()) 129 | self.register_buffer("T_inside_cubic_sign", torch.zeros(self.T_NUM_CURVE, dtype=torch.bool).cuda()) 130 | else: 131 | self.register_buffer("T_bez_cp_idx", torch.tensor([[0, 0]], dtype=torch.int32).cuda()) 132 | self.register_buffer("T_inside_cubic_sign", torch.tensor([0.0], dtype=torch.bool).cuda()) 133 | launch_1d(slang_set_curved_inside_sign.run(T_NUM_CURVE=self.T_NUM_CURVE, V=self.V, T=self.T, T_bez_cp_idx=self.T_bez_cp_idx, T_inside_cubic_sign=self.T_inside_cubic_sign), LEN=self.T_NUM_CURVE) 134 | 135 | # For each vertex is it continuous or not 136 | # Need to use int because interlockedand is only for int/ float 137 | temp_V_is_continuous = torch.ones(V_NUM, dtype=torch.int).cuda() 138 | launch_1d(slang_vertex_is_continuous.run(T_NUM_DISC=self.T_NUM_DISC, V=self.V, T=self.T, V_is_continuous=temp_V_is_continuous), LEN=self.T_NUM_DISC) 139 | # self.V_is_continuous = temp_V_is_continuous > 0 140 | self.register_buffer("V_is_continuous", temp_V_is_continuous > 0) 141 | 142 | # Counter Clock Wise (CCW) feature definition: A vertex-edge feature X->Y (for the vertex X) is called counter clockwise if it lies along the directed edge X->Y (i.e the edge appears in a triangle [X,Y,Z] or [Z,X,Y] or [Y,Z,X]). 143 | # Clock Wise (CW) feature definition: A vertex-edge feature X->Y (for the vertex X) is called clockwise if it is not counter clock wise. 144 | 145 | # Total number of adjcent vertices that are discontinuous 146 | V_num_adj_disc_V = torch.zeros(V_NUM, dtype=torch.int).cuda() 147 | launch_1d(slang_count_adj_disc.run(T_NUM=self.T_NUM, T_NUM_DISC=self.T_NUM_DISC, V=self.V, T=self.T, V_num_adj_disc_V=V_num_adj_disc_V, V_is_continuous=self.V_is_continuous), LEN=self.T_NUM) 148 | self.V_num_adj_disc_V = V_num_adj_disc_V 149 | 150 | 151 | temp_V_num_adj_disc_V = torch.zeros(V_NUM, dtype=torch.int).cuda() 152 | 153 | V_NUM_DIRECTED_EDGES = torch.sum(V_num_adj_disc_V) 154 | 155 | V_adj_disc_idx_ptr = torch.cumsum(V_num_adj_disc_V, dim=0, dtype=torch.int) 156 | V_adj_disc_idx_ptr -= V_num_adj_disc_V 157 | 158 | V_adj_disc_idx = torch.zeros(torch.sum(V_num_adj_disc_V), dtype=torch.int).cuda() - 1 159 | launch_1d(slang_add_adj_disc.run( 160 | T_NUM=self.T_NUM, T_NUM_DISC=self.T_NUM_DISC, V=self.V, T=self.T, 161 | temp_V_num_adj_disc_V=temp_V_num_adj_disc_V, 162 | V_adj_disc_idx_ptr=V_adj_disc_idx_ptr, 163 | V_adj_disc_idx=V_adj_disc_idx, 164 | V_is_continuous=self.V_is_continuous 165 | ), LEN=self.T_NUM) 166 | 167 | self.register_buffer("T_adj_disc_V", torch.zeros((self.T_NUM * 6), dtype=torch.int).cuda() - 1) 168 | self.register_buffer("T_adj_disc_feat_idx", torch.zeros((self.T_NUM * 6), dtype=torch.int).cuda() - 1) 169 | 170 | 171 | launch_1d(slang_link_radial_feats.run( 172 | T_NUM=self.T_NUM, V=self.V, T=self.T, 173 | V_is_continuous=self.V_is_continuous, 174 | V_adj_disc_idx=V_adj_disc_idx, 175 | V_adj_disc_idx_ptr=V_adj_disc_idx_ptr, 176 | V_num_adj_disc_V=V_num_adj_disc_V, 177 | T_adj_disc_V=self.T_adj_disc_V, 178 | T_adj_disc_feat_idx=self.T_adj_disc_feat_idx, 179 | ), LEN=self.T_NUM) 180 | 181 | self.register_buffer("seco_T_adj_disc_feat_idx", torch.zeros(max((self.T_NUM_CURVE * 2), 1), dtype=torch.int).cuda() - 1) 182 | launch_1d(slang_seco_link_radial_feats.run( 183 | T_NUM_CURVE=self.T_NUM_CURVE, V=self.V, T=self.T, 184 | V_is_continuous=self.V_is_continuous, 185 | V_adj_disc_idx=V_adj_disc_idx, 186 | V_adj_disc_idx_ptr=V_adj_disc_idx_ptr, 187 | V_num_adj_disc_V=V_num_adj_disc_V, 188 | seco_T_adj_disc_feat_idx=self.seco_T_adj_disc_feat_idx, 189 | ), LEN=self.T_NUM_CURVE) 190 | 191 | ############### SAVE ALL THE BUFFERS NEEDED FOR INFERENCE ################## 192 | # self.register_buffer("V_is_continuous", V_is_continuous) 193 | # self.register_buffer("T_bez_cp_idx", T_bez_cp_idx) 194 | # self.register_buffer("T_inside_cubic_sign", T_inside_cubic_sign) 195 | # self.register_buffer("T_adj_disc_V", T_adj_disc_V) 196 | # self.register_buffer("T_adj_disc_feat_idx", T_adj_disc_feat_idx) 197 | # self.register_buffer("seco_T_adj_disc_feat_idx", seco_T_adj_disc_feat_idx) 198 | 199 | 200 | ############### INIITIALIZE ALL FEATURE BUFFERS ############## 201 | # feature at continuous vertex (memory here can be reduced if we only define these for cont vertex) 202 | # will require another pointer array to the offsets into this though, which is annoying and has 203 | # memory requirements too 204 | self.FEATURE_DIM = FEATURE_DIM 205 | UR = 0.002 206 | self.V_continuous_feat = torch.nn.Parameter(torch.randn((V_NUM * self.FEATURE_DIM), dtype=torch.float, device=CUDA_DEVICE)) 207 | nn.init.uniform_(self.V_continuous_feat.data, -UR, UR) 208 | # feature at discontinuous vertex. even index: CCW feature, odd index: CW feature 209 | self.V_discontinuous_feat = torch.nn.Parameter(torch.randn((V_NUM_DIRECTED_EDGES*2 * self.FEATURE_DIM), dtype=torch.float, device=CUDA_DEVICE)) 210 | nn.init.uniform_(self.V_discontinuous_feat.data, -UR, UR) 211 | # extra feature for single vertex on curved triangle 212 | self.V_curved_feat = torch.nn.Parameter(torch.randn((max(self.T_NUM_CURVE, 1) * self.FEATURE_DIM), dtype=torch.float, device=CUDA_DEVICE)) 213 | nn.init.uniform_(self.V_curved_feat.data, -UR, UR) 214 | 215 | MLP_INP_DIM = FEATURE_DIM 216 | 217 | self.mlp = nn.Sequential( 218 | nn.Linear(MLP_INP_DIM, 64), 219 | nn.ReLU(), 220 | nn.Linear(64, 64), 221 | nn.ReLU(), 222 | nn.Linear(64, OUT_DIM) 223 | ).cuda() 224 | 225 | 226 | ############# SETUP ACCELERATION STRUCTURE FOR POINT IN TRIANGLE QUERY ############# 227 | self.A_X, self.A_Y = ACCEL_GRID_DIMS 228 | self.A_NUM_CELLS = self.A_X * self.A_Y 229 | cell_triangle_count = torch.zeros((self.A_Y, self.A_X), dtype=torch.int, device=CUDA_DEVICE).reshape(-1) 230 | 231 | ######### COUNT THE NUMBER OF TRIANGLES IN EACH CELL ########## 232 | launch_1d(slang_count_triangles_per_cell.run(Y=self.A_Y, X=self.A_X, T_NUM=self.T_NUM, V=self.V, T=self.T, cell_triangle_count=cell_triangle_count), LEN=self.T_NUM) 233 | 234 | CELL_TOTAL_TRIANGLES = torch.sum(cell_triangle_count) 235 | 236 | # This 1D array has a list of triangles for each cell 237 | self.cell_to_triangle_index = torch.zeros(CELL_TOTAL_TRIANGLES, dtype=torch.int, device=CUDA_DEVICE) -1 238 | # This 1D array tells you the index of the first triangle for each cell in `cell_to_triangle_index` 239 | self.cell_to_triangle_index_ptr = torch.cumsum(cell_triangle_count, dim=0, dtype=torch.int) 240 | self.cell_to_triangle_index_ptr -= cell_triangle_count 241 | self.cell_to_triangle_index_ptr = torch.cat([self.cell_to_triangle_index_ptr, torch.tensor([CELL_TOTAL_TRIANGLES], dtype=torch.int, device=CUDA_DEVICE)]) 242 | temp_cell_triangle_count = torch.zeros((self.A_Y, self.A_X), dtype=torch.int, device=CUDA_DEVICE).reshape(-1) 243 | 244 | ########## ADD TRIANGLE IDS TO EACH CELL BUFFER ###### 245 | launch_1d(slang_add_triangles_per_cell.run(Y=self.A_Y, X=self.A_X, T_NUM=self.T_NUM, V=self.V, T=self.T, 246 | cell_to_triangle_index_ptr=self.cell_to_triangle_index_ptr, 247 | cell_to_triangle_index=self.cell_to_triangle_index, 248 | temp_cell_triangle_count=temp_cell_triangle_count, 249 | ), LEN=self.T_NUM) 250 | 251 | ''' 252 | Inputs: Q: [N,2] array of query points. Every point is in [0,1]^2 253 | Outputs: [N, out_dims] array of colors/ function values predicted for the query points 254 | ''' 255 | def forward(self, Q): 256 | # Step 1: Perform point in triangle query to get triangle indices and barycentric coordinates 257 | Q_NUM = Q.shape[0] 258 | # Initialize buffers to store output 259 | QT_idx = torch.zeros(Q_NUM, dtype=torch.int, device=CUDA_DEVICE) - 1 260 | QT_uv = torch.zeros((Q_NUM,3), dtype=torch.float, device=CUDA_DEVICE) - 1.0 261 | 262 | launch_1d(slang_point_in_triangle.run(Y=self.A_Y, X=self.A_X, Q=Q, V=self.V, T=self.T, 263 | cell_to_triangle_index=self.cell_to_triangle_index, 264 | cell_to_triangle_index_ptr=self.cell_to_triangle_index_ptr, 265 | QT_idx=QT_idx, QT_uv=QT_uv 266 | ), LEN=Q_NUM) 267 | assert torch.all(QT_idx >= 0), "PIT: didn't find a triangle" 268 | assert torch.all(QT_uv >= 0) and torch.all(QT_uv <= 1), "PIT: UV out of bounds" 269 | assert torch.allclose(QT_uv.sum(axis=-1), torch.ones_like(QT_uv.sum(axis=-1))), "PIT: UV sum not one" 270 | 271 | interpolated_features = DiscontinuityAwareInterpolation.apply(self.FEATURE_DIM, Q, self.V, self.T, QT_idx, QT_uv, self.V_is_continuous, self.T_adj_disc_V, self.T_adj_disc_feat_idx, self.V_continuous_feat, self.V_discontinuous_feat, self.V_curved_feat, self.T_bez_cp_idx, self.T_inside_cubic_sign, self.seco_T_adj_disc_feat_idx, self.T_NUM_CURVE) 272 | 273 | res = self.mlp(interpolated_features).type(torch.float32) 274 | return res 275 | --------------------------------------------------------------------------------