# ---------------------------------------------------------------------------
# Repository snapshot (originally a flattened dump of several files).
#
# Layout:
#   ├── .gitignore
#   ├── .gitmodules
#   ├── 01_simpleMeshCNN
#   │   └── main.py        <- the Python script below
#   └── README.md          <- reproduced at the end of this file
#
# /.gitignore:
#   datasets
#   old
#   jaxgptoolbox/__pycache__
#   jaxgptoolbox/.DS_Store
#   .DS_Store
#
# /.gitmodules:
#   [submodule "jaxgptoolbox"]
#       path = jaxgptoolbox
#       url = https://github.com/ml-for-gp/jaxgptoolbox
#
# /01_simpleMeshCNN/main.py:
# ---------------------------------------------------------------------------
"""Train a minimal MeshCNN-style binary classifier on meshMNIST with JAX.

Pipeline:
  1. For every training mesh, compute 5 input features per edge
     (dihedral angle / pi, plus 4 sorted edge-length ratios).
  2. Apply a stack of mesh convolutions over each edge's flap (the edge
     and its 4 neighbouring edges), each followed by a ReLU.
  3. Max-pool over all edges, then run a small MLP whose final sigmoid
     output is the predicted probability of label 1.
  4. Minimize binary cross-entropy with Adam and pickle the learned params.

Run from inside ``01_simpleMeshCNN/``: both the toolbox import path and the
dataset path below are relative to the current working directory.
"""
import sys
sys.path.append('..')  # repo root, so the jaxgptoolbox submodule is importable

import glob
import pickle
import time

import jax
import jax.numpy as np
from jax import jit, value_and_grad

try:
    # Newer JAX releases ship the reference optimizers here.
    from jax.example_libraries import optimizers
except ImportError:
    # Fallback for older JAX releases, where they lived under jax.experimental.
    from jax.experimental import optimizers

import matplotlib.pyplot as plt
import numpy as onp
import numpy.random as random

import jaxgptoolbox as jgp

random.seed(0)  # seed NumPy's RNG so the weight initialization is reproducible


def compute_input(V, F, E, flap):
    """Compute the per-edge input features fed to the first mesh convolution.

    Args:
        V: vertex positions; ``V[E[:, k], :]`` gathers edge endpoints.
        F: face indices, forwarded to ``jgp.dihedral_angles``.
        E: edge list; row i holds the two vertex indices of edge i.
        flap: per-edge indices of neighbouring edges; columns 0-3 are the
            edges a, b, c, d of the diagram below. Must have one row per edge.

    Returns:
        fE with shape (#edges, 5): column 0 is the dihedral angle divided by
        pi, columns 1-4 are the ratios |a|/|e|, |b|/|e|, |c|/|e|, |d|/|e|
        sorted in ascending order.
    """
    # dihedral angles
    dihedral_angles, _ = jgp.dihedral_angles(V, F)
    dihedral_angles = dihedral_angles / np.pi  # normalize to 0 ~ 1

    # edge length ratios
    #      / \
    #     b   a
    #    /     \
    #   - - e - -
    #    \     /
    #     c   d
    #      \ /
    Elen = np.sqrt(np.sum((V[E[:, 0], :] - V[E[:, 1], :]) ** 2, axis=1))
    ratio_features = np.stack(
        [Elen[flap[:, k]] / Elen for k in range(4)], axis=1)
    # Sorting makes the feature independent of how a, b, c, d are ordered.
    ratio_features = np.sort(ratio_features, axis=1)

    # concatenate
    dihedral_angles = np.expand_dims(dihedral_angles, 1)
    fE = np.concatenate((dihedral_angles, ratio_features), axis=1)
    return fE  # fE.shape = (#edges, #channels)


trainFolder = '../datasets/meshMNIST_100V/12_meshMNIST/'  # replace this with the path to the dataset folder
labels = np.asarray(onp.loadtxt(trainFolder + 'labels.txt', delimiter=','), dtype=np.int16)  # replace this with the path to the label file

numMeshes = len(glob.glob(trainFolder + "*.obj"))
if numMeshes == 0:
    # Fail early with a clear message instead of an IndexError further down.
    raise FileNotFoundError("no .obj meshes found in " + trainFolder)

meshList = [{} for _ in range(numMeshes)]
print("this will take a while, so we should put it into preprocessing")
for meshIdx in range(numMeshes):
    if meshIdx % 100 == 0:
        print(str(meshIdx) + "/" + str(numMeshes))
    # Meshes are read as 0001.obj, 0002.obj, ...; mesh i pairs with labels row i.
    meshName = str(meshIdx + 1).zfill(4) + ".obj"
    V, F = jgp.readOBJ(trainFolder + meshName)
    E, flap = jgp.edge_flaps(F)
    fE = compute_input(V, F, E, flap)

    meshList[meshIdx]["V"] = V
    meshList[meshIdx]["F"] = F
    meshList[meshIdx]["flap"] = flap
    meshList[meshIdx]["fE"] = fE
    meshList[meshIdx]["label"] = labels[meshIdx, 1] - 1.0  # 0: "1", 1: "2"


def initialize_meshCNN_weights(conv_dims, mlp_dims):
    """Randomly initialize the MeshCNN parameters.

    Args:
        conv_dims: channel counts [C_in, C_1, ..., C_k]; conv layer i maps
            conv_dims[i] -> conv_dims[i+1] channels.
        mlp_dims: output widths of the fully connected layers; the first
            layer consumes the conv_dims[-1] max-pooled features.

    Returns:
        (params_conv, params_mlp): params_conv[i] has shape
        (conv_dims[i+1], conv_dims[i], 5); params_mlp is a list of [W, b]
        with W of shape (out, in) and b of shape (out,).
    """
    num_flap_edges = 5  # one weight per flap feature produced by _edge_to_flap
    scale = 1e-2

    # Conv filter parameters
    params_conv = []
    for c_in, c_out in zip(conv_dims[:-1], conv_dims[1:]):
        params_conv.append(scale * random.randn(c_out, c_in, num_flap_edges))

    # MLP parameters (W drawn before b per layer, as before, so seeded
    # weights are unchanged)
    mlp_in = [conv_dims[-1]] + list(mlp_dims[:-1])
    params_mlp = []
    for d_in, d_out in zip(mlp_in, mlp_dims):
        W = scale * random.randn(d_out, d_in)
        b = scale * random.randn(d_out)
        params_mlp.append([W, b])

    return (params_conv, params_mlp)


conv_dims = [5, 8, 8]
mlp_dims = [4, 1]
params = initialize_meshCNN_weights(conv_dims, mlp_dims)


def _edge_to_flap(fE, flap):
    """Turn edge functions fE into flap functions fP.

    fE.shape = (#edges, #channels); returns fP with shape
    (#edges, #channels, 5) stacking [e, |a-c|, a+c, |b-d|, b+d].
    """
    a = fE[flap[:, 0], :]
    b = fE[flap[:, 1], :]
    c = fE[flap[:, 2], :]
    d = fE[flap[:, 3], :]
    # Symmetric combinations make the result unchanged when a<->c and b<->d
    # are swapped together.
    return np.stack((fE, np.abs(a - c), a + c, np.abs(b - d), b + d), axis=2)


def _mesh_conv_single(W, fP_i):
    """Convolve the flap features of a single edge.

    fP_i.shape = (#input_channels, 5); W.shape = (#output_channels,
    #input_channels, 5); returns an array of shape (#output_channels,).
    """
    return np.sum(W * np.expand_dims(fP_i, 0), axis=(1, 2))


# Vectorize over edges: the filter W is shared (in_axes=None) while fP is
# mapped along its first (edge) axis.
_mesh_conv = jax.vmap(_mesh_conv_single, in_axes=(None, 0), out_axes=0)


def _forward_logits(fE, params, flap):
    """Run the network and return the last layer's pre-sigmoid output."""
    params_conv, params_mlp = params

    # mesh convolutions
    for conv_W in params_conv:
        fE = jax.nn.relu(_mesh_conv(conv_W, _edge_to_flap(fE, flap)))

    # global pooling (max pooling over edges)
    f = np.max(fE, 0)

    # fully connected; ReLU on every layer except the last
    last = len(params_mlp) - 1
    for ii, (W, b) in enumerate(params_mlp):
        f = np.dot(W, f) + b
        if ii != last:
            f = jax.nn.relu(f)
    return f


def forward(fE, params, flap):
    """Predict the probability that a mesh has label 1.

    Args:
        fE: (#edges, conv_dims[0]) input edge features from compute_input.
        params: (params_conv, params_mlp) as built by initialize_meshCNN_weights.
        flap: per-edge neighbour indices from jgp.edge_flaps.

    Returns:
        Array of shape (mlp_dims[-1],) with sigmoid outputs in (0, 1).
    """
    return jax.nn.sigmoid(_forward_logits(fE, params, flap))


pred = forward(meshList[0]['fE'], params, meshList[0]['flap'])
print(pred)


def loss(params, flap, input, label):
    """Binary cross-entropy between one mesh's prediction and its label.

    Computed from the logit z with log-sigmoid. This equals
    -label*log(p) - (1-label)*log(1-p) for p = sigmoid(z), but stays finite
    (and keeps useful gradients) when p saturates at 0 or 1.
    """
    logit = _forward_logits(input, params, flap)[0]
    return (-label * jax.nn.log_sigmoid(logit)
            - (1.0 - label) * jax.nn.log_sigmoid(-logit))


stepSize = 1e-3
opt_init, opt_update, get_params = optimizers.adam(step_size=stepSize)
opt_state = opt_init(params)


@jit
def update(step, opt_state, input, label, flap):
    """Take one Adam step on a single mesh.

    ``step`` is the global optimizer step index. Adam uses it for its bias
    correction, so it must grow by one with every call (not once per epoch).

    Returns:
        (loss value before the step, updated optimizer state).
    """
    params = get_params(opt_state)
    value, grads = value_and_grad(loss, argnums=0)(params, flap, input, label)  # backpropagation
    opt_state = opt_update(step, grads, opt_state)
    return value, opt_state


numEpochs = 200
step = 0  # global optimizer step counter: one step per mesh
for epoch in range(numEpochs):
    ts = time.time()
    loss_total = 0.0
    for mesh in meshList:
        lossVal, opt_state = update(step, opt_state,
                                    mesh['fE'], mesh['label'], mesh['flap'])
        step += 1
        loss_total += lossVal
    loss_total /= numMeshes

    if epoch % 10 == 0:
        print("epoch %d, train loss %f, epoch time: %.7s sec" % (epoch, loss_total, time.time() - ts))

NETPARAM = "params.pickle"
params_final = get_params(opt_state)
with open(NETPARAM, 'wb') as handle:
    pickle.dump(params_final, handle, protocol=pickle.HIGHEST_PROTOCOL)


# testing -- a sanity check on the first *training* mesh, not a held-out set
pred = forward(meshList[0]['fE'], params_final, meshList[0]['flap'])
print("testing")
print(pred)

# ---------------------------------------------------------------------------
# /README.md:
#
# # Tutorials
#
# This repository contains a set of tutorial code for mesh-based neural
# networks. This tutorial also accompanies a SIGGRAPH 2021 course -- An
# Introduction to Deep Learning on Meshes. All the other materials can be
# found [here](https://anintroductiontodeeplearningonmeshes.github.io).
#
# If you have any questions, please contact
# [Hsueh-Ti Derek Liu](https://www.dgp.toronto.edu/~hsuehtil/) or
# [Rana Hanocka](https://people.cs.uchicago.edu/~ranahanocka/).
# ---------------------------------------------------------------------------