├── AudioNet_model.py ├── AudioNet_utils.py ├── LICENSE ├── OF_render.py ├── ObjectFolder1.0 ├── AudioNet_model.py ├── AudioNet_utils.py ├── ObjectFolder_teaser.png ├── README.md ├── TouchNet_model.py ├── TouchNet_utils.py ├── VisionNet_utils.py ├── box_utils.py ├── cam_utils.py ├── dataset_visualization.png ├── demo │ ├── audio_demo_forces.npy │ ├── audio_demo_vertices.npy │ ├── touch_demo_vertices.npy │ └── vision_demo.npy ├── evaluate.py ├── indirect_utils.py ├── intersect.py ├── load_osf.py ├── ray_utils.py ├── requirements.txt ├── run_osf_helpers.py ├── scatter.py └── shadow_utils.py ├── ObjectFolder2.0_teaser.png ├── README.md ├── TouchNet_model.py ├── TouchNet_utils.py ├── VisionNet_configs.py ├── VisionNet_utils.py ├── basics ├── CalibData.py ├── __init__.py └── sensorParams.py ├── box_utils.py ├── build_occupancy_tree.py ├── calibs ├── dataPack.npz ├── depth_bg.npy ├── polycalib.npz └── real_bg.npy ├── cam_utils.py ├── ddsp_torch.py ├── demo ├── ObjectFile.pth ├── audio_demo_forces.npy ├── audio_demo_vertices.npy ├── model.obj ├── touch_demo_gelinfo.npy ├── touch_demo_vertices.npy └── vision_demo.npy ├── environment.yml ├── fast_kilonerf_renderer.py ├── load_osf.py ├── local_distill.py ├── multi_modules.py ├── objects.csv ├── ray_utils.py ├── run_nerf_helpers.py ├── taxim_render.py ├── utils.py └── von_mises.py /AudioNet_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.autograd.set_detect_anomaly(True) 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from AudioNet_utils import * 7 | 8 | class DenseLayer(nn.Linear): 9 | def __init__(self, in_dim: int, out_dim: int, activation: str = 'relu', *args, **kwargs) -> None: 10 | self.activation = activation 11 | super().__init__(in_dim, out_dim, *args, **kwargs) 12 | 13 | def reset_parameters(self) -> None: 14 | torch.nn.init.xavier_uniform_(self.weight, gain=torch.nn.init.calculate_gain(self.activation)) 15 | if self.bias is not None: 16 | torch.nn.init.zeros_(self.bias) 17 | 18 | class Embedder: 19 | def __init__(self, **kwargs): 20 | self.kwargs = kwargs 21 | self.create_embedding_fn() 22 | 23 | def create_embedding_fn(self): 24 | embed_fns = [] 25 | d = self.kwargs['input_dims'] 26 | out_dim = 0 27 | if self.kwargs['include_input']: 28 | embed_fns.append(lambda x: x) 29 | out_dim += d 30 | 31 | max_freq = self.kwargs['max_freq_log2'] 32 | N_freqs = self.kwargs['num_freqs'] 33 | 34 | if self.kwargs['log_sampling']: 35 | freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs) 36 | else: 37 | freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs) 38 | 39 | for freq in freq_bands: 40 | for p_fn in self.kwargs['periodic_fns']: 41 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 42 | out_dim += d 43 | 44 | self.embed_fns = embed_fns 45 | self.out_dim = out_dim 46 | 47 | def embed(self, inputs): 48 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 49 | 50 | def get_embedder(multires, i=0): 51 | if i == -1: 52 | #5 is for x, y, z, f, t 53 | return nn.Identity(), 5 54 | 55 | embed_kwargs = { 56 | 'include_input': True, 57 | 'input_dims': 3, 58 | 'max_freq_log2': multires-1, 59 | 'num_freqs': multires, 60 | 'log_sampling': True, 61 | 'periodic_fns': [torch.sin, torch.cos], 62 | } 63 | 64 | embedder_obj = Embedder(**embed_kwargs) 65 | embed = lambda x, eo=embedder_obj: eo.embed(x) 66 | return embed, embedder_obj.out_dim 67 | 68 | class AudioNeRF(nn.Module): 69 | def __init__(self, D=8, 
input_ch=5, output_ch=2): 70 | super(AudioNeRF, self).__init__() 71 | self.model_x = NeRF(D = D, input_ch = input_ch, output_ch = output_ch) 72 | self.model_y = NeRF(D = D, input_ch = input_ch, output_ch = output_ch) 73 | self.model_z = NeRF(D = D, input_ch = input_ch, output_ch = output_ch) 74 | 75 | def forward(self, embedded_x, embedded_y, embedded_z): 76 | results_x = self.model_x(embedded_x) 77 | results_y = self.model_y(embedded_y) 78 | results_z = self.model_z(embedded_z) 79 | return results_x, results_y, results_z 80 | 81 | 82 | class NeRF(nn.Module): 83 | def __init__(self, D=8, W=256, input_ch=5, input_ch_views=0, output_ch=2, skips=[4], use_viewdirs=False): 84 | """ 85 | """ 86 | super(NeRF, self).__init__() 87 | self.D = D 88 | self.W = W 89 | self.input_ch = input_ch 90 | self.input_ch_views = input_ch_views 91 | self.skips = skips 92 | self.use_viewdirs = use_viewdirs 93 | 94 | self.pts_linears = nn.ModuleList( 95 | [DenseLayer(input_ch, W, activation='relu')] + [DenseLayer(W, W, activation='relu') if i not in self.skips else DenseLayer(W + input_ch, W, activation='relu') for i in range(D-1)]) 96 | 97 | self.views_linears = nn.ModuleList([DenseLayer(input_ch_views + W, W//2, activation='relu')]) 98 | 99 | if use_viewdirs: 100 | self.feature_linear = DenseLayer(W, W, activation='sigmoid') 101 | #self.alpha_linear = DenseLayer(W, 1, activation='linear') 102 | self.rgb_linear = DenseLayer(W//2, output_ch, activation='sigmoid') 103 | else: 104 | self.output_linear = DenseLayer(W, output_ch, activation='sigmoid') 105 | 106 | 107 | def forward(self, x): 108 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1) 109 | h = input_pts 110 | for i, l in enumerate(self.pts_linears): 111 | h = self.pts_linears[i](h) 112 | h = F.relu(h) 113 | if i in self.skips: 114 | h = torch.cat([input_pts, h], -1) 115 | 116 | if self.use_viewdirs: 117 | feature = self.feature_linear(h) 118 | h = torch.cat([feature, input_views], -1) 119 | 120 | for i, l in enumerate(self.views_linears): 121 | h = self.views_linears[i](h) 122 | h = F.relu(h) 123 | 124 | outputs = self.rgb_linear(h) 125 | else: 126 | outputs = self.output_linear(h) 127 | 128 | return outputs 129 | -------------------------------------------------------------------------------- /AudioNet_utils.py: -------------------------------------------------------------------------------- 1 | from scipy.io import wavfile 2 | import librosa 3 | import librosa.display 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from AudioNet_model import * 7 | import os 8 | from collections import OrderedDict 9 | 10 | def strip_prefix_if_present(state_dict, prefix): 11 | keys = sorted(state_dict.keys()) 12 | if not all(key.startswith(prefix) for key in keys): 13 | return state_dict 14 | stripped_state_dict = OrderedDict() 15 | for key, value in state_dict.items(): 16 | stripped_state_dict[key.replace(prefix, "")] = value 17 | return stripped_state_dict 18 | 19 | def mkdirs(path, remove=False): 20 | if os.path.isdir(path): 21 | if remove: 22 | shutil.rmtree(path) 23 | else: 24 | return 25 | os.makedirs(path) 26 | 27 | def generate_spectrogram_magphase(audio, stft_frame, stft_hop, n_fft, with_phase=False): 28 | spectro = librosa.core.stft(audio, hop_length=stft_hop, n_fft=n_fft, win_length=stft_frame, center=True) 29 | spectro_mag, spectro_phase = librosa.core.magphase(spectro) 30 | spectro_mag = np.expand_dims(spectro_mag, axis=0) 31 | if with_phase: 32 | spectro_phase = np.expand_dims(np.angle(spectro_phase), axis=0) 
33 | return spectro_mag, spectro_phase 34 | else: 35 | return spectro_mag 36 | 37 | def generate_spectrogram_complex(audio, stft_frame, stft_hop, n_fft): 38 | spectro = librosa.core.stft(audio, hop_length=stft_hop, n_fft=n_fft, win_length=stft_frame, center=True) 39 | real = np.expand_dims(np.real(spectro), axis=0) 40 | imag = np.expand_dims(np.imag(spectro), axis=0) 41 | spectro_two_channel = np.concatenate((real, imag), axis=0) 42 | return spectro_two_channel 43 | 44 | def batchify(fn, chunk): 45 | """ 46 | Constructs a version of 'fn' that applies to smaller batches 47 | """ 48 | if chunk is None: 49 | return fn 50 | def ret(inputs): 51 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0) 52 | return ret 53 | 54 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64): 55 | """ 56 | Prepares inputs and applies network 'fn'. 57 | """ 58 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]]) 59 | embedded = embed_fn(inputs_flat) 60 | 61 | if viewdirs is not None: 62 | input_dirs = viewdirs[:,None].expand(inputs.shape) 63 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]]) 64 | embedded_dirs = embeddirs_fn(input_dirs_flat) 65 | embedded = torch.cat([embedded, embedded_dirs], -1) 66 | 67 | outputs_flat = batchify(fn, netchunk)(embedded) 68 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]]) 69 | return outputs 70 | 71 | def create_nerf(args): 72 | """ 73 | Instantiate NeRF's MLP model. 74 | """ 75 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed) 76 | 77 | input_ch_views = 0 78 | embeddirs_fn = None 79 | if args.use_viewdirs: 80 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed) 81 | output_ch = 2 82 | skips = [4] 83 | model = NeRF(D=args.netdepth, W=args.netwidth, 84 | input_ch=input_ch, output_ch=output_ch, skips=skips, 85 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs) 86 | model = nn.DataParallel(model).to(device) 87 | grad_vars = list(model.parameters()) 88 | -------------------------------------------------------------------------------- /OF_render.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime 3 | import sys 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | import imageio 8 | import json 9 | import random 10 | import time 11 | from tqdm import tqdm, trange 12 | import scipy 13 | import librosa 14 | from scipy.io.wavfile import write 15 | from scipy.spatial import KDTree 16 | import ddsp_torch as ddsp 17 | import itertools 18 | from taxim_render import TaximRender 19 | from PIL import Image 20 | import argparse 21 | from load_osf import load_osf_data 22 | import AudioNet_utils 23 | import AudioNet_model 24 | import TouchNet_utils 25 | import TouchNet_model 26 | import VisionNet_utils 27 | from utils import * 28 | 29 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 30 | 31 | def config_parser(): 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument("--modality", type=str, default="vision,audio,touch") 34 | parser.add_argument("--object_file_path", type=str, default="demo/ObjectFile.pth", help='ObjectFile path') 35 | parser.add_argument("--KiloOSF", action="store_true") 36 | 37 | # VisionNet options 38 | parser.add_argument("--vision_test_file_path", default='data/vision_demo.npy', help='The path of the testing file for vision, which should be a npy file.') 39 | 
parser.add_argument("--vision_results_dir", type=str, default='./results/vision/', help='The path of the vision results directory to save rendered images.') 40 | 41 | # AudioNet options 42 | parser.add_argument('--audio_vertices_file_path', default='./data/audio_demo_vertices.npy', help='The path of the testing vertices file for audio, which should be a npy file.') 43 | parser.add_argument('--audio_forces_file_path', default='./data/forces.npy', help='The path of forces file for audio, which should be a npy file.') 44 | parser.add_argument("--audio_results_dir", type=str, default='./results/audio/', help='The path of the audio results directory to save impact sounds.') 45 | 46 | # TouchNet options 47 | parser.add_argument('--touch_vertices_file_path', default='./data/touch_demo_vertices.npy', help='The path of the testing vertices file for touch, which should be a npy file.') 48 | parser.add_argument('--touch_gelinfo_file_path', default='./data/touch_demo_gelinfo.npy', help='The path of the gel configurations for touch, which should be a npy file.') 49 | parser.add_argument('--touch_results_dir', type=str, default='./results/touch/', help='The path of the touch results directory to save rendered tactile RGB images.') 50 | 51 | return parser 52 | 53 | 54 | def VisionNet_eval(args): 55 | 56 | checkpoint = torch.load(args.object_file_path) 57 | cfg = checkpoint['VisionNet']['cfg'] 58 | 59 | metadata, render_metadata = None, None 60 | background_color = torch.ones(3, dtype=torch.float, device=device) 61 | 62 | poses, hwf, i_split, metadata = load_osf_data(args.vision_test_file_path) 63 | i_test = i_split[0] 64 | render_poses = np.array(poses[i_test]) 65 | 66 | # Create dummy metadata if not loaded from dataset. 67 | if metadata is None: 68 | metadata = torch.tensor([[0, 0, 1]] * len(images), dtype=torch.float) # [N, 3] 69 | if render_metadata is None: 70 | render_metadata = metadata 71 | 72 | # Cast intrinsics to right types 73 | H, W, focal = hwf 74 | H, W = int(H), int(W) 75 | hwf = [H, W, focal] 76 | intrinsics = CameraIntrinsics(int(H), int(W), focal, focal, W * .5, H * .5) 77 | 78 | render_kwargs_train = { 79 | 'perturb' : cfg['perturb'], 80 | 'N_samples' : cfg['num_samples_per_ray'], 81 | 'N_importance' : cfg['num_importance_samples_per_ray'], 82 | 'use_viewdirs': True, 83 | 'use_lightdirs': True, 84 | 'white_bkgd' : cfg['blender_white_background'], 85 | 'raw_noise_std' : cfg['raw_noise_std'], 86 | 'near' : cfg['near'], 87 | 'far' : cfg['far'], 88 | 'metadata': metadata, 89 | 'render_metadata': metadata, 90 | 'random_direction_probability': cfg['random_direction_probability'], 91 | 'von_mises_kappa': cfg['von_mises_kappa'], 92 | 'background_color': background_color, 93 | 'lightdirs_method': 'metadata', 94 | 'cfg': cfg 95 | } 96 | 97 | ConfigManager.init(cfg) 98 | 99 | if args.KiloOSF: 100 | import kilonerf_cuda 101 | from local_distill import create_multi_network_fourier_embedding, has_flag, create_multi_network 102 | kilonerf_cuda.init_stream_pool(16) 103 | kilonerf_cuda.init_magma() 104 | 105 | position_num_input_channels, position_fourier_embedding = create_multi_network_fourier_embedding(1, cfg['num_frequencies']) 106 | direction_num_input_channels, direction_fourier_embedding = create_multi_network_fourier_embedding(1, cfg['num_frequencies_direction']) 107 | light_num_input_channels, light_fourier_embedding = create_multi_network_fourier_embedding(1, cfg['num_frequencies_light']) 108 | 109 | root_nodes = occupancy_grid = None 110 | 111 | res = cfg['fixed_resolution'] 112 | 
network_resolution = torch.tensor(res, dtype=torch.long, device=torch.device('cpu')) 113 | num_networks = res[0] * res[1] * res[2] 114 | model = multi_network = create_multi_network(num_networks, position_num_input_channels, direction_num_input_channels, light_num_input_channels, 4, 'multimatmul_differentiable', cfg).to(device) 115 | 116 | global_domain_min, global_domain_max = ConfigManager.get_global_domain_min_and_max(torch.device('cpu')) 117 | global_domain_size = global_domain_max - global_domain_min 118 | network_voxel_size = global_domain_size / network_resolution 119 | 120 | # Determine bounding boxes (domains) of all networks. Required for global to local coordinate conversion. 121 | domain_mins = [] 122 | domain_maxs = [] 123 | for coord in itertools.product(*[range(r) for r in res]): 124 | coord = torch.tensor(coord, device=torch.device('cpu')) 125 | domain_min = global_domain_min + network_voxel_size * coord 126 | domain_max = domain_min + network_voxel_size 127 | domain_mins.append(domain_min.tolist()) 128 | domain_maxs.append(domain_max.tolist()) 129 | domain_mins = torch.tensor(domain_mins, device=device) 130 | domain_maxs = torch.tensor(domain_maxs, device=device) 131 | occupancy_grid = checkpoint['VisionNet']['occupancy_grid'] 132 | 133 | additional_kwargs = { 134 | 'root_nodes': root_nodes, 135 | 'position_fourier_embedding': position_fourier_embedding, 136 | 'direction_fourier_embedding': direction_fourier_embedding, 137 | 'light_fourier_embedding': light_fourier_embedding, 138 | 'multi_network': multi_network, 139 | 'domain_mins': domain_mins, 140 | 'domain_maxs': domain_maxs, 141 | 'occupancy_grid': occupancy_grid, 142 | 'debug_network_color_map': None 143 | } 144 | else: 145 | model, embed_fn, embeddirs_fn, embedlights_fn = create_nerf(cfg) 146 | model = model.to(device) 147 | network_query_fn = lambda inputs, viewdirs, lightdirs, network_fn : VisionNet_utils.run_network(inputs, viewdirs, lightdirs, network_fn, 148 | embed_fn=embed_fn, 149 | embeddirs_fn=embeddirs_fn, 150 | embedlights_fn=embedlights_fn, 151 | netchunk=cfg['network_chunk_size']) 152 | 153 | additional_kwargs = { 154 | 'network_query_fn' : network_query_fn, 155 | 'network_fn' : model 156 | } 157 | 158 | render_kwargs_train.update(additional_kwargs) 159 | render_kwargs_train['ndc'] = False 160 | render_kwargs_train['lindisp'] = cfg['llff_lindisp'] 161 | 162 | render_kwargs_test = render_kwargs_train.copy() 163 | render_kwargs_test['perturb'] = False 164 | render_kwargs_test['raw_noise_std'] = 0. 
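# At evaluation time, ray perturbation, raw noise, and randomized view/light directions are all disabled so that rendering is deterministic.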
165 | render_kwargs_test['random_direction_probability'] = -1 166 | render_kwargs_test['von_mises_kappa'] = -1 167 | 168 | model.load_state_dict(checkpoint['VisionNet']['model_state_dict']) 169 | 170 | # Move testing data to GPU 171 | render_poses = torch.Tensor(render_poses).to(device) 172 | 173 | model.eval() 174 | with torch.no_grad(): 175 | images = None 176 | testsavedir = args.vision_results_dir 177 | os.makedirs(testsavedir, exist_ok=True) 178 | rgbs, _ = VisionNet_utils.render_path(render_poses, intrinsics, cfg['chunk_size'], render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=cfg['render_factor']) 179 | 180 | def AudioNet_eval(args): 181 | 182 | checkpoint = torch.load(args.object_file_path) 183 | normalizer_dic = checkpoint['AudioNet']['normalizer'] 184 | gains_f1_min = normalizer_dic['f1_min'] 185 | gains_f1_max = normalizer_dic['f1_max'] 186 | gains_f2_min = normalizer_dic['f2_min'] 187 | gains_f2_max = normalizer_dic['f2_max'] 188 | gains_f3_min = normalizer_dic['f3_min'] 189 | gains_f3_max = normalizer_dic['f3_max'] 190 | xyz_min = normalizer_dic['xyz_min'] 191 | xyz_max = normalizer_dic['xyz_max'] 192 | freqs = checkpoint['AudioNet']['frequencies'] 193 | damps = checkpoint['AudioNet']['dampings'] 194 | 195 | forces = np.load(args.audio_forces_file_path) 196 | 197 | xyz = np.load(args.audio_vertices_file_path).reshape((-1, 3)) 198 | # normalize xyz to [-1, 1] 199 | xyz = (xyz - xyz_min) / (xyz_max - xyz_min) 200 | 201 | N = xyz.shape[0] 202 | G = freqs.shape[0] 203 | 204 | embed_fn, input_ch = AudioNet_model.get_embedder(10, 0) 205 | model = AudioNet_model.AudioNeRF(D=8, input_ch=input_ch, output_ch=G) 206 | state_dic = checkpoint['AudioNet']["model_state_dict"] 207 | state_dic = AudioNet_utils.strip_prefix_if_present(state_dic, 'module.') 208 | model.load_state_dict(state_dic) 209 | model = nn.DataParallel(model).to(device) 210 | model.eval() 211 | 212 | preds_gain_x = torch.zeros((N, G)).to(device) 213 | preds_gain_y = torch.zeros((N, G)).to(device) 214 | preds_gain_z = torch.zeros((N, G)).to(device) 215 | 216 | batch_size = 1024 217 | 218 | for i in trange(N // batch_size + 1): 219 | curr_x = torch.Tensor(xyz[i*batch_size:(i+1)*batch_size]).to(device) 220 | curr_y = torch.Tensor(xyz[i*batch_size:(i+1)*batch_size]).to(device) 221 | curr_z = torch.Tensor(xyz[i*batch_size:(i+1)*batch_size]).to(device) 222 | embedded_x = embed_fn(curr_x) 223 | embedded_y = embed_fn(curr_y) 224 | embedded_z = embed_fn(curr_z) 225 | results_x, results_y, results_z = model(embedded_x, embedded_y, embedded_z) 226 | 227 | preds_gain_x[i*batch_size:(i+1)*batch_size] = results_x 228 | preds_gain_y[i*batch_size:(i+1)*batch_size] = results_y 229 | preds_gain_z[i*batch_size:(i+1)*batch_size] = results_z 230 | 231 | preds_gain_x = preds_gain_x * (gains_f1_max - gains_f1_min) + gains_f1_min 232 | preds_gain_y = preds_gain_y * (gains_f2_max - gains_f2_min) + gains_f2_min 233 | preds_gain_z = preds_gain_z * (gains_f3_max - gains_f3_min) + gains_f3_min 234 | preds_gain = torch.cat((preds_gain_x[:, None, :], preds_gain_y[:, None, :], preds_gain_z[:, None, :]), 1) 235 | 236 | freqs = torch.Tensor(freqs).to(device) 237 | damps = torch.Tensor(damps).to(device) 238 | 239 | testsavedir = args.audio_results_dir 240 | os.makedirs(testsavedir, exist_ok=True) 241 | 242 | for i in trange(N): 243 | preds_gain_x_i = preds_gain[i, 0, :] 244 | preds_gain_y_i = preds_gain[i, 1, :] 245 | preds_gain_z_i = preds_gain[i, 2, :] 246 | force_x, force_y, force_z = forces[i] 247 | combined_preds_gain = force_x * 
preds_gain_x_i + force_y * preds_gain_y_i + force_z * preds_gain_z_i 248 | combined_preds_gain = combined_preds_gain.unsqueeze(0) 249 | modal_fir = torch.unsqueeze(ddsp.get_modal_fir(combined_preds_gain, freqs, damps), axis=1) 250 | impulse = torch.reshape(torch.Tensor(scipy.signal.unit_impulse(44100*3)).to(device), (1, -1)).repeat(modal_fir.shape[0], 1) 251 | result = ddsp.fft_convolve(impulse, modal_fir) 252 | signal = result[0, :].detach().cpu().numpy() 253 | signal = signal / np.abs(signal).max() 254 | # write wav file 255 | output_path = os.path.join(testsavedir, str(i+1) + '.wav') 256 | write(output_path, 44100, signal.astype(np.float32)) 257 | 258 | 259 | def TouchNet_eval(args): 260 | 261 | checkpoint = torch.load(args.object_file_path) 262 | 263 | rotation_max = 15 264 | depth_max = 0.04 265 | depth_min = 0.0339 266 | displacement_min = 0.0005 267 | displacement_max = 0.0020 268 | depth_max = 0.04 269 | depth_min = 0.0339 270 | rgb_width = 120 271 | rgb_height = 160 272 | network_depth = 8 273 | 274 | #TODO load object... 275 | vertex_min = checkpoint['TouchNet']['xyz_min'] 276 | vertex_max = checkpoint['TouchNet']['xyz_max'] 277 | 278 | vertex_coordinates = np.load(args.touch_vertices_file_path) 279 | N = vertex_coordinates.shape[0] 280 | gelinfo_data = np.load(args.touch_gelinfo_file_path) 281 | theta, phi, displacement = gelinfo_data[:, 0], gelinfo_data[:, 1], gelinfo_data[:, 2] 282 | phi_x = np.cos(phi) 283 | phi_y = np.sin(phi) 284 | 285 | # normalize theta to [-1, 1] 286 | theta = (theta - np.radians(0)) / (np.radians(rotation_max) - np.radians(0)) 287 | 288 | #normalize displacement to [-1,1] 289 | displacement_norm = (displacement - displacement_min) / (displacement_max - displacement_min) 290 | 291 | #normalize coordinates to [-1,1] 292 | vertex_coordinates = (vertex_coordinates - vertex_min) / (vertex_max - vertex_min) 293 | 294 | #initialize horizontal and vertical features 295 | w_feats = np.repeat(np.repeat(np.arange(rgb_width).reshape((rgb_width, 1)), rgb_height, axis=1).reshape((1, 1, rgb_width, rgb_height)), N, axis=0) 296 | h_feats = np.repeat(np.repeat(np.arange(rgb_height).reshape((1, rgb_height)), rgb_width, axis=0).reshape((1, 1, rgb_width, rgb_height)), N, axis=0) 297 | #normalize horizontal and vertical features to [-1, 1] 298 | w_feats_min = w_feats.min() 299 | w_feats_max = w_feats.max() 300 | h_feats_min = h_feats.min() 301 | h_feats_max = h_feats.max() 302 | w_feats = (w_feats - w_feats_min) / (w_feats_max - w_feats_min) 303 | h_feats = (h_feats - h_feats_min) / (h_feats_max - h_feats_min) 304 | w_feats = torch.FloatTensor(w_feats) 305 | h_feats = torch.FloatTensor(h_feats) 306 | 307 | theta = np.repeat(theta.reshape((N, 1, 1)), rgb_width * rgb_height, axis=1) 308 | phi_x = np.repeat(phi_x.reshape((N, 1, 1)), rgb_width * rgb_height, axis=1) 309 | phi_y = np.repeat(phi_y.reshape((N, 1, 1)), rgb_width * rgb_height, axis=1) 310 | displacement_norm = np.repeat(displacement_norm.reshape((N, 1, 1)), rgb_width * rgb_height, axis=1) 311 | vertex_coordinates = np.repeat(vertex_coordinates.reshape((N, 1, 3)), rgb_width * rgb_height, axis=1) 312 | 313 | data_wh = np.concatenate((w_feats, h_feats), axis=1) 314 | data_wh = np.transpose(data_wh.reshape((N, 2, -1)), axes=[0, 2, 1]) 315 | #Now get final feats matrix as [x, y, z, theta, phi_x, phi_y, displacement, w, h] 316 | data = np.concatenate((vertex_coordinates, theta, phi_x, phi_y, displacement_norm, data_wh), axis=2).reshape((-1, 9)) 317 | 318 | #checkpoint = torch.load(args.object_file_path) 319 | embed_fn, 
input_ch = TouchNet_model.get_embedder(10, 0) 320 | model = TouchNet_model.NeRF(D = network_depth, input_ch = input_ch, output_ch = 1) 321 | state_dic = checkpoint['TouchNet']['model_state_dict'] 322 | state_dic = TouchNet_utils.strip_prefix_if_present(state_dic, 'module.') 323 | model.load_state_dict(state_dic) 324 | model = nn.DataParallel(model).to(device) 325 | model.eval() 326 | 327 | preds = np.empty((data.shape[0], 1)) 328 | 329 | batch_size = 1024 330 | 331 | testsavedir = args.touch_results_dir 332 | os.makedirs(testsavedir, exist_ok=True) 333 | 334 | for i in trange(data.shape[0] // batch_size + 1): 335 | inputs = torch.Tensor(data[i*batch_size:(i+1)*batch_size]).to(device) 336 | embedded = embed_fn(inputs) 337 | results = model(embedded) 338 | preds[i*batch_size:(i+1)*batch_size, :] = results.detach().cpu().numpy() 339 | 340 | preds = preds * (depth_max - depth_min) + depth_min 341 | preds = np.transpose(preds.reshape((N, -1, 1)), axes = [0, 2, 1]).reshape((N, rgb_width, rgb_height)) 342 | taxim = TaximRender("./calibs/") 343 | for i in trange(N): 344 | height_map, contact_map, tactile_map = taxim.render(preds[i], displacement[i]) 345 | tactile_map = Image.fromarray(tactile_map.astype(np.uint8), 'RGB') 346 | filename = os.path.join(testsavedir, '{}.png'.format(i+1)) 347 | tactile_map.save(filename) 348 | 349 | if __name__ =='__main__': 350 | parser = config_parser() 351 | args = parser.parse_args() 352 | modalities = args.modality.strip().split(",") 353 | 354 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 355 | 356 | if "vision" in modalities: 357 | VisionNet_eval(args=args) 358 | if "audio" in modalities: 359 | AudioNet_eval(args=args) 360 | if "touch" in modalities: 361 | TouchNet_eval(args=args) 362 | -------------------------------------------------------------------------------- /ObjectFolder1.0/AudioNet_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.autograd.set_detect_anomaly(True) 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from AudioNet_utils import * 7 | 8 | class DenseLayer(nn.Linear): 9 | def __init__(self, in_dim: int, out_dim: int, activation: str = 'relu', *args, **kwargs) -> None: 10 | self.activation = activation 11 | super().__init__(in_dim, out_dim, *args, **kwargs) 12 | 13 | def reset_parameters(self) -> None: 14 | torch.nn.init.xavier_uniform_(self.weight, gain=torch.nn.init.calculate_gain(self.activation)) 15 | if self.bias is not None: 16 | torch.nn.init.zeros_(self.bias) 17 | 18 | class Embedder: 19 | def __init__(self, **kwargs): 20 | self.kwargs = kwargs 21 | self.create_embedding_fn() 22 | 23 | def create_embedding_fn(self): 24 | embed_fns = [] 25 | d = self.kwargs['input_dims'] 26 | out_dim = 0 27 | if self.kwargs['include_input']: 28 | embed_fns.append(lambda x: x) 29 | out_dim += d 30 | 31 | max_freq = self.kwargs['max_freq_log2'] 32 | N_freqs = self.kwargs['num_freqs'] 33 | 34 | if self.kwargs['log_sampling']: 35 | freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs) 36 | else: 37 | freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs) 38 | 39 | for freq in freq_bands: 40 | for p_fn in self.kwargs['periodic_fns']: 41 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 42 | out_dim += d 43 | 44 | self.embed_fns = embed_fns 45 | self.out_dim = out_dim 46 | 47 | def embed(self, inputs): 48 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 49 | 50 | def get_embedder(multires, i=0): 
51 | if i == -1: 52 | #5 is for x, y, z, f, t 53 | return nn.Identity(), 5 54 | 55 | embed_kwargs = { 56 | 'include_input': True, 57 | 'input_dims': 5, 58 | 'max_freq_log2': multires-1, 59 | 'num_freqs': multires, 60 | 'log_sampling': True, 61 | 'periodic_fns': [torch.sin, torch.cos], 62 | } 63 | 64 | embedder_obj = Embedder(**embed_kwargs) 65 | embed = lambda x, eo=embedder_obj: eo.embed(x) 66 | return embed, embedder_obj.out_dim 67 | 68 | class AudioNeRF(nn.Module): 69 | def __init__(self, D=8, input_ch=5): 70 | super(AudioNeRF, self).__init__() 71 | self.model_x = NeRF(D = D, input_ch = input_ch) 72 | self.model_y = NeRF(D = D, input_ch = input_ch) 73 | self.model_z = NeRF(D = D, input_ch = input_ch) 74 | 75 | def forward(self, embedded_x, embedded_y, embedded_z): 76 | results_x = self.model_x(embedded_x) 77 | results_y = self.model_y(embedded_y) 78 | results_z = self.model_z(embedded_z) 79 | return results_x, results_y, results_z 80 | 81 | 82 | class NeRF(nn.Module): 83 | def __init__(self, D=8, W=256, input_ch=5, input_ch_views=0, output_ch=2, skips=[4], use_viewdirs=False): 84 | """ 85 | """ 86 | super(NeRF, self).__init__() 87 | self.D = D 88 | self.W = W 89 | self.input_ch = input_ch 90 | self.input_ch_views = input_ch_views 91 | self.skips = skips 92 | self.use_viewdirs = use_viewdirs 93 | 94 | self.pts_linears = nn.ModuleList( 95 | [DenseLayer(input_ch, W, activation='relu')] + [DenseLayer(W, W, activation='relu') if i not in self.skips else DenseLayer(W + input_ch, W, activation='relu') for i in range(D-1)]) 96 | 97 | self.views_linears = nn.ModuleList([DenseLayer(input_ch_views + W, W//2, activation='relu')]) 98 | 99 | if use_viewdirs: 100 | self.feature_linear = DenseLayer(W, W, activation='sigmoid') 101 | #self.alpha_linear = DenseLayer(W, 1, activation='linear') 102 | self.rgb_linear = DenseLayer(W//2, output_ch, activation='sigmoid') 103 | else: 104 | self.output_linear = DenseLayer(W, output_ch, activation='sigmoid') 105 | 106 | 107 | def forward(self, x): 108 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1) 109 | h = input_pts 110 | for i, l in enumerate(self.pts_linears): 111 | h = self.pts_linears[i](h) 112 | h = F.relu(h) 113 | if i in self.skips: 114 | h = torch.cat([input_pts, h], -1) 115 | 116 | if self.use_viewdirs: 117 | feature = self.feature_linear(h) 118 | h = torch.cat([feature, input_views], -1) 119 | 120 | for i, l in enumerate(self.views_linears): 121 | h = self.views_linears[i](h) 122 | h = F.relu(h) 123 | 124 | outputs = self.rgb_linear(h) 125 | else: 126 | outputs = self.output_linear(h) 127 | 128 | return outputs 129 | -------------------------------------------------------------------------------- /ObjectFolder1.0/AudioNet_utils.py: -------------------------------------------------------------------------------- 1 | from scipy.io import wavfile 2 | import librosa 3 | import librosa.display 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from AudioNet_model import * 7 | import os 8 | from collections import OrderedDict 9 | 10 | def transform_mesh_collision_binvox(coordinates, translate, scale): 11 | for i in range(3): 12 | coordinates[i] = (coordinates[i] + translate[i]) * scale 13 | return coordinates 14 | 15 | def strip_prefix_if_present(state_dict, prefix): 16 | keys = sorted(state_dict.keys()) 17 | if not all(key.startswith(prefix) for key in keys): 18 | return state_dict 19 | stripped_state_dict = OrderedDict() 20 | for key, value in state_dict.items(): 21 | 
stripped_state_dict[key.replace(prefix, "")] = value 22 | return stripped_state_dict 23 | 24 | def mkdirs(path, remove=False): 25 | if os.path.isdir(path): 26 | if remove: 27 | shutil.rmtree(path) 28 | else: 29 | return 30 | os.makedirs(path) 31 | 32 | def generate_spectrogram_magphase(audio, stft_frame, stft_hop, n_fft, with_phase=False): 33 | spectro = librosa.core.stft(audio, hop_length=stft_hop, n_fft=n_fft, win_length=stft_frame, center=True) 34 | spectro_mag, spectro_phase = librosa.core.magphase(spectro) 35 | spectro_mag = np.expand_dims(spectro_mag, axis=0) 36 | if with_phase: 37 | spectro_phase = np.expand_dims(np.angle(spectro_phase), axis=0) 38 | return spectro_mag, spectro_phase 39 | else: 40 | return spectro_mag 41 | 42 | def generate_spectrogram_complex(audio, stft_frame, stft_hop, n_fft): 43 | spectro = librosa.core.stft(audio, hop_length=stft_hop, n_fft=n_fft, win_length=stft_frame, center=True) 44 | real = np.expand_dims(np.real(spectro), axis=0) 45 | imag = np.expand_dims(np.imag(spectro), axis=0) 46 | spectro_two_channel = np.concatenate((real, imag), axis=0) 47 | return spectro_two_channel 48 | 49 | def batchify(fn, chunk): 50 | """ 51 | Constructs a version of 'fn' that applies to smaller batches 52 | """ 53 | if chunk is None: 54 | return fn 55 | def ret(inputs): 56 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0) 57 | return ret 58 | 59 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64): 60 | """ 61 | Prepares inputs and applies network 'fn'. 62 | """ 63 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]]) 64 | embedded = embed_fn(inputs_flat) 65 | 66 | if viewdirs is not None: 67 | input_dirs = viewdirs[:,None].expand(inputs.shape) 68 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]]) 69 | embedded_dirs = embeddirs_fn(input_dirs_flat) 70 | embedded = torch.cat([embedded, embedded_dirs], -1) 71 | 72 | outputs_flat = batchify(fn, netchunk)(embedded) 73 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]]) 74 | return outputs 75 | 76 | """ 77 | def get_embedder(multires, i=0): 78 | if i == -1: 79 | #5 is for x, y, z, f, t 80 | return nn.Identity(), 5 81 | 82 | embed_kwargs = { 83 | 'include_input': True, 84 | 'input_dims': 5, 85 | 'max_freq_log2': multires-1, 86 | 'num_freqs': multires, 87 | 'log_sampling': True, 88 | 'periodic_fns': [torch.sin, torch.cos], 89 | } 90 | 91 | embedder_obj = Embedder(**embed_kwargs) 92 | embed = lambda x, eo=embedder_obj: eo.embed(x) 93 | return embed, embedder_obj.out_dim 94 | """ 95 | 96 | def create_nerf(args): 97 | """ 98 | Instantiate NeRF's MLP model. 
99 | """ 100 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed) 101 | 102 | input_ch_views = 0 103 | embeddirs_fn = None 104 | if args.use_viewdirs: 105 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed) 106 | output_ch = 2 107 | skips = [4] 108 | model = NeRF(D=args.netdepth, W=args.netwidth, 109 | input_ch=input_ch, output_ch=output_ch, skips=skips, 110 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs) 111 | model = nn.DataParallel(model).to(device) 112 | grad_vars = list(model.parameters()) 113 | -------------------------------------------------------------------------------- /ObjectFolder1.0/ObjectFolder_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/ObjectFolder1.0/ObjectFolder_teaser.png -------------------------------------------------------------------------------- /ObjectFolder1.0/README.md: -------------------------------------------------------------------------------- 1 | ## ObjectFolder: A Dataset of Objects with Implicit Visual, Auditory, and Tactile Representations (CoRL 2021) 2 | [[Project Page]](https://ai.stanford.edu/~rhgao/objectfolder/) [[arXiv]](https://arxiv.org/abs/2109.07991) 3 | 4 | 5 | 6 |
7 | 8 | [ObjectFolder: A Dataset of Objects with Implicit Visual, Auditory, and Tactile Representations](https://arxiv.org/abs/2109.07991) 9 | [Ruohan Gao](https://www.ai.stanford.edu/~rhgao/), [Yen-Yu Chang](https://yuyuchang.github.io/), [Shivani Mall](), [Li Fei-Fei](https://profiles.stanford.edu/fei-fei-li), [Jiajun Wu](https://jiajunwu.com/)
10 | Stanford University 11 | In Conference on Robot Learning (**CoRL**), 2021 12 | 13 |
14 | 15 | If you find our code or project useful in your research, please cite: 16 | 17 | @inproceedings{gao2021ObjectFolder, 18 | title = {ObjectFolder: A Dataset of Objects with Implicit Visual, Auditory, and Tactile Representations}, 19 | author = {Gao, Ruohan and Chang, Yen-Yu and Mall, Shivani and Fei-Fei, Li and Wu, Jiajun}, 20 | booktitle = {CoRL}, 21 | year = {2021} 22 | } 23 | 24 | ### About ObjectFolder Dataset 25 | 26 | 27 | 28 | ObjectFolder is a dataset of 100 objects in the form of implicit representations. It contains 100 Object Files, each encoding the complete multisensory profile of an object instance. Each Object File implicit neural representation network contains three sub-networks---VisionNet, AudioNet, and TouchNet---which, when queried with the corresponding extrinsic parameters, return the visual appearance of the object from different views, the impact sounds of the object at each position, and tactile sensing of the object at every surface location, respectively. The dataset contains common household objects of diverse categories such as bowl, mug, cabinet, television, shelf, fork, and spoon. See the paper for details. 29 | 30 |
31 | 32 | ### Dataset Download and Preparation 33 | ``` 34 | git clone https://github.com/rhgao/ObjectFolder.git 35 | cd ObjectFolder/ObjectFolder1.0 36 | wget https://download.cs.stanford.edu/viscam/ObjectFolder/ObjectFolder1.0.tar.gz 37 | tar -xvf ObjectFolder1.0.tar.gz 38 | pip install -r requirements.txt 39 | ``` 40 | 41 | ### Rendering visual, auditory, and tactile sensory data 42 | Run the following command to render images, impact sounds, and tactile RGB images: 43 | ``` 44 | $ python evaluate.py --object_file_path path_of_ObjectFile \ 45 | --vision_test_file_path path_of_vision_test_file \ 46 | --vision_results_dir path_of_vision_results_directory \ 47 | --audio_vertices_file_path path_of_audio_testing_vertices_file \ 48 | --audio_forces_file_path path_of_forces_file \ 49 | --audio_results_dir path_of_audio_results_directory \ 50 | --touch_vertices_file_path path_of_touch_testing_vertices_file \ 51 | --touch_results_dir path_of_touch_results_directory 52 | ``` 53 | This code can be run with the following command-line arguments: 54 | * `--object_file_path`: The path of the ObjectFile. 55 | * `--vision_test_file_path`: The path of the testing file for vision, which should be a npy file. 56 | * `--vision_results_dir`: The path of the vision results directory to save rendered images. 57 | * `--audio_vertices_file_path`: The path of the testing vertices file for audio, which should be a npy file. 58 | * `--audio_forces_file_path`: The path of the forces file for audio, which should be a npy file. 59 | * `--audio_results_dir`: The path of the audio results directory to save rendered impact sounds as .wav files. 60 | * `--touch_vertices_file_path`: The path of the testing vertices file for touch, which should be a npy file. 61 | * `--touch_results_dir`: The path of the touch results directory to save rendered tactile RGB images. 62 | 63 | ### Data format 64 | * `--vision_test_file_path`: A npy file with shape (N, 6), where N is the number of testing viewpoints. Each data point contains the coordinates of the camera and the light in the form of (camera_x, camera_y, camera_z, light_x, light_y, light_z). 65 | * `--audio_vertices_file_path`: A npy file with shape (N, 3), where N is the number of testing vertices. Each data point represents a coordinate on the object in the form of (x, y, z). 66 | * `--audio_forces_file_path`: A npy file with shape (N, 3), where N is the number of testing vertices. Each data point represents the force values for the corresponding impact in the form of (F_x, F_y, F_z). 67 | * `--touch_vertices_file_path`: A npy file with shape (N, 3), where N is the number of testing vertices. Each data point contains a coordinate on the object in the form of (x, y, z). 68 | 69 | ### Demo 70 | Below we show an example of rendering the visual, auditory, and tactile data from the ObjectFile implicit representation for one object: 71 | ``` 72 | $ python evaluate.py --object_file_path Objects/25/ObjectFile.pth \ 73 | --vision_test_file_path demo/vision_demo.npy \ 74 | --vision_results_dir demo/vision_results/ \ 75 | --audio_vertices_file_path demo/audio_demo_vertices.npy \ 76 | --audio_forces_file_path demo/audio_demo_forces.npy \ 77 | --audio_results_dir demo/audio_results/ \ 78 | --touch_vertices_file_path demo/touch_demo_vertices.npy \ 79 | --touch_results_dir demo/touch_results/ 80 | ``` 81 | 82 | The rendered images, impact sounds, and tactile images will be saved in `demo/vision_results/`, `demo/audio_results/`, and `demo/touch_results/`, respectively.
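To render from your own viewpoints, impact locations, or contact points, you can construct query files that follow the data format above. Below is a minimal sketch of how such npy files could be created; the file names and coordinate values are illustrative placeholders only (vertices should lie on the object's surface, and camera/light positions should fall within the range used for the object):
```
import numpy as np

# Vision: one row per rendering, (camera_x, camera_y, camera_z, light_x, light_y, light_z)
vision_queries = np.array([[0.0, -2.0, 1.0, 2.0, -2.0, 2.0],
                           [1.5, -1.5, 1.0, 2.0, -2.0, 2.0]], dtype=np.float32)  # shape (N, 6)
np.save('demo/my_vision_test.npy', vision_queries)

# Audio: impact locations (x, y, z) and the matching forces (F_x, F_y, F_z), same N
audio_vertices = np.array([[0.01, 0.02, 0.05]], dtype=np.float32)  # shape (N, 3)
audio_forces = np.array([[0.0, 0.0, 1.0]], dtype=np.float32)       # shape (N, 3)
np.save('demo/my_audio_vertices.npy', audio_vertices)
np.save('demo/my_audio_forces.npy', audio_forces)

# Touch: surface contact locations (x, y, z)
touch_vertices = np.array([[0.01, 0.02, 0.05]], dtype=np.float32)  # shape (N, 3)
np.save('demo/my_touch_vertices.npy', touch_vertices)
```
The resulting files can then be passed to `evaluate.py` via `--vision_test_file_path`, `--audio_vertices_file_path`, `--audio_forces_file_path`, and `--touch_vertices_file_path`, exactly as in the demo command above.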
83 | 84 | ### Acknowlegements 85 | The code for the neural implicit representation network is adapted from Yen-Chen Lin's [PyTorch implementation](https://github.com/yenchenlin/nerf-pytorch) of [NeRF](https://www.matthewtancik.com/nerf) and Michelle Guo's TensorFlow implementation of [OSF](https://www.shellguo.com/osf/). 86 | 87 | ### License 88 | ObjectFolder is CC BY 4.0 licensed, as found in the LICENSE file. The 100 high quality 3D objects originally come from online repositories including: 20 objects from [3D Model Haven](https://3dmodelhaven.com/), 28 objects from the [YCB dataset](http://ycb-benchmarks.s3-website-us-east-1.amazonaws.com/), and 52 objects from [Google Scanned Objects](https://app.ignitionrobotics.org/GoogleResearch/fuel/collections/Google\%20Scanned\%20Objects). Please also refer to their original lisence file. 89 | -------------------------------------------------------------------------------- /ObjectFolder1.0/TouchNet_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.autograd.set_detect_anomaly(True) 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from TouchNet_utils import * 7 | 8 | class DenseLayer(nn.Linear): 9 | def __init__(self, in_dim: int, out_dim: int, activation: str = 'relu', *args, **kwargs) -> None: 10 | self.activation = activation 11 | super().__init__(in_dim, out_dim, *args, **kwargs) 12 | 13 | def reset_parameters(self) -> None: 14 | torch.nn.init.xavier_uniform_(self.weight, gain=torch.nn.init.calculate_gain(self.activation)) 15 | if self.bias is not None: 16 | torch.nn.init.zeros_(self.bias) 17 | 18 | class Embedder: 19 | def __init__(self, **kwargs): 20 | self.kwargs = kwargs 21 | self.create_embedding_fn() 22 | 23 | def create_embedding_fn(self): 24 | embed_fns = [] 25 | d = self.kwargs['input_dims'] 26 | out_dim = 0 27 | if self.kwargs['include_input']: 28 | embed_fns.append(lambda x: x) 29 | out_dim += d 30 | 31 | max_freq = self.kwargs['max_freq_log2'] 32 | N_freqs = self.kwargs['num_freqs'] 33 | 34 | if self.kwargs['log_sampling']: 35 | freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs) 36 | else: 37 | freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs) 38 | 39 | for freq in freq_bands: 40 | for p_fn in self.kwargs['periodic_fns']: 41 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 42 | out_dim += d 43 | 44 | self.embed_fns = embed_fns 45 | self.out_dim = out_dim 46 | 47 | def embed(self, inputs): 48 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 49 | 50 | def get_embedder(multires, i=0): 51 | if i == -1: 52 | #5 is for x, y, z, r, g, b 53 | return nn.Identity(), 5 54 | 55 | embed_kwargs = { 56 | 'include_input': True, 57 | 'input_dims': 5, 58 | 'max_freq_log2': multires-1, 59 | 'num_freqs': multires, 60 | 'log_sampling': True, 61 | 'periodic_fns': [torch.sin, torch.cos], 62 | } 63 | 64 | embedder_obj = Embedder(**embed_kwargs) 65 | embed = lambda x, eo=embedder_obj: eo.embed(x) 66 | return embed, embedder_obj.out_dim 67 | 68 | class AudioNeRF(nn.Module): 69 | def __init__(self, D=8, input_ch=5): 70 | super(AudioNeRF, self).__init__() 71 | self.model_x = NeRF(D = D, input_ch = input_ch) 72 | self.model_y = NeRF(D = D, input_ch = input_ch) 73 | self.model_z = NeRF(D = D, input_ch = input_ch) 74 | 75 | def forward(self, embedded_x, embedded_y, embedded_z): 76 | results_x = self.model_x(embedded_x) 77 | results_y = self.model_y(embedded_y) 78 | results_z = 
self.model_z(embedded_z) 79 | return results_x, results_y, results_z 80 | 81 | 82 | class NeRF(nn.Module): 83 | def __init__(self, D=8, W=256, input_ch=5, input_ch_views=0, output_ch=2, skips=[4], use_viewdirs=False): 84 | """ 85 | """ 86 | super(NeRF, self).__init__() 87 | self.D = D 88 | self.W = W 89 | self.input_ch = input_ch 90 | self.input_ch_views = input_ch_views 91 | self.skips = skips 92 | self.use_viewdirs = use_viewdirs 93 | 94 | self.pts_linears = nn.ModuleList( 95 | [DenseLayer(input_ch, W, activation='relu')] + [DenseLayer(W, W, activation='relu') if i not in self.skips else DenseLayer(W + input_ch, W, activation='relu') for i in range(D-1)]) 96 | 97 | self.views_linears = nn.ModuleList([DenseLayer(input_ch_views + W, W//2, activation='relu')]) 98 | 99 | if use_viewdirs: 100 | self.feature_linear = DenseLayer(W, W, activation='sigmoid') 101 | #self.alpha_linear = DenseLayer(W, 1, activation='linear') 102 | self.rgb_linear = DenseLayer(W//2, output_ch, activation='sigmoid') 103 | else: 104 | self.output_linear = DenseLayer(W, output_ch, activation='sigmoid') 105 | 106 | 107 | def forward(self, x): 108 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1) 109 | h = input_pts 110 | for i, l in enumerate(self.pts_linears): 111 | h = self.pts_linears[i](h) 112 | h = F.relu(h) 113 | if i in self.skips: 114 | h = torch.cat([input_pts, h], -1) 115 | 116 | if self.use_viewdirs: 117 | feature = self.feature_linear(h) 118 | h = torch.cat([feature, input_views], -1) 119 | 120 | for i, l in enumerate(self.views_linears): 121 | h = self.views_linears[i](h) 122 | h = F.relu(h) 123 | 124 | outputs = self.rgb_linear(h) 125 | else: 126 | outputs = self.output_linear(h) 127 | 128 | return outputs 129 | -------------------------------------------------------------------------------- /ObjectFolder1.0/TouchNet_utils.py: -------------------------------------------------------------------------------- 1 | from scipy.io import wavfile 2 | import librosa 3 | import librosa.display 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from TouchNet_model import * 7 | import os 8 | from collections import OrderedDict 9 | 10 | def strip_prefix_if_present(state_dict, prefix): 11 | keys = sorted(state_dict.keys()) 12 | if not all(key.startswith(prefix) for key in keys): 13 | return state_dict 14 | stripped_state_dict = OrderedDict() 15 | for key, value in state_dict.items(): 16 | stripped_state_dict[key.replace(prefix, "")] = value 17 | return stripped_state_dict 18 | 19 | def mkdirs(path, remove=False): 20 | if os.path.isdir(path): 21 | if remove: 22 | shutil.rmtree(path) 23 | else: 24 | return 25 | os.makedirs(path) 26 | 27 | def generate_spectrogram_magphase(audio, stft_frame, stft_hop, n_fft, with_phase=False): 28 | spectro = librosa.core.stft(audio, hop_length=stft_hop, n_fft=n_fft, win_length=stft_frame, center=True) 29 | spectro_mag, spectro_phase = librosa.core.magphase(spectro) 30 | spectro_mag = np.expand_dims(spectro_mag, axis=0) 31 | if with_phase: 32 | spectro_phase = np.expand_dims(np.angle(spectro_phase), axis=0) 33 | return spectro_mag, spectro_phase 34 | else: 35 | return spectro_mag 36 | 37 | def generate_spectrogram_complex(audio, stft_frame, stft_hop, n_fft): 38 | spectro = librosa.core.stft(audio, hop_length=stft_hop, n_fft=n_fft, win_length=stft_frame, center=True) 39 | real = np.expand_dims(np.real(spectro), axis=0) 40 | imag = np.expand_dims(np.imag(spectro), axis=0) 41 | spectro_two_channel = np.concatenate((real, imag), axis=0) 42 | 
return spectro_two_channel 43 | 44 | def batchify(fn, chunk): 45 | """ 46 | Constructs a version of 'fn' that applies to smaller batches 47 | """ 48 | if chunk is None: 49 | return fn 50 | def ret(inputs): 51 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0) 52 | return ret 53 | 54 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64): 55 | """ 56 | Prepares inputs and applies network 'fn'. 57 | """ 58 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]]) 59 | embedded = embed_fn(inputs_flat) 60 | 61 | if viewdirs is not None: 62 | input_dirs = viewdirs[:,None].expand(inputs.shape) 63 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]]) 64 | embedded_dirs = embeddirs_fn(input_dirs_flat) 65 | embedded = torch.cat([embedded, embedded_dirs], -1) 66 | 67 | outputs_flat = batchify(fn, netchunk)(embedded) 68 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]]) 69 | return outputs 70 | 71 | """ 72 | def get_embedder(multires, i=0): 73 | if i == -1: 74 | #5 is for x, y, z, f, t 75 | return nn.Identity(), 5 76 | 77 | embed_kwargs = { 78 | 'include_input': True, 79 | 'input_dims': 5, 80 | 'max_freq_log2': multires-1, 81 | 'num_freqs': multires, 82 | 'log_sampling': True, 83 | 'periodic_fns': [torch.sin, torch.cos], 84 | } 85 | 86 | embedder_obj = Embedder(**embed_kwargs) 87 | embed = lambda x, eo=embedder_obj: eo.embed(x) 88 | return embed, embedder_obj.out_dim 89 | """ 90 | 91 | def create_nerf(args): 92 | """ 93 | Instantiate NeRF's MLP model. 94 | """ 95 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed) 96 | 97 | input_ch_views = 0 98 | embeddirs_fn = None 99 | if args.use_viewdirs: 100 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed) 101 | output_ch = 2 102 | skips = [4] 103 | model = NeRF(D=args.netdepth, W=args.netwidth, 104 | input_ch=input_ch, output_ch=output_ch, skips=skips, 105 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs) 106 | model = nn.DataParallel(model).to(device) 107 | grad_vars = list(model.parameters()) 108 | -------------------------------------------------------------------------------- /ObjectFolder1.0/box_utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for bounding box computation.""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import ray_utils 6 | 7 | def ray_to_box_coordinate_frame_pairwise(box_center, box_rotation_matrix, 8 | rays_start_point, rays_end_point): 9 | """Moves a set of rays into a box's coordinate frame. 10 | 11 | Args: 12 | box_center: A tensor of size [3] or [r, 3]. 13 | box_rotation_matrix: A tensor of size [3, 3] or [r, 3, 3]. 14 | rays_start_point: A tensor of size [r, 3] where r is the number of rays. 15 | rays_end_points: A tensor of size [r, 3] where r is the number of rays. 16 | 17 | Returns: 18 | rays_start_point_in_box_frame: A tensor of size [r, 3]. 19 | rays_end_point_in_box_frame: A tensor if size [r, 3]. 
20 | """ 21 | r = rays_start_point.size()[0] 22 | box_center = torch.broadcast_to(box_center, (r, 3)) 23 | box_rotation_matrix = torch.broadcast_to(box_rotation_matrix, (r, 3, 3)) 24 | rays_start_point_in_box_frame = torch.matmul( 25 | (rays_start_point - box_center).unsqueeze(1), 26 | box_rotation_matrix) 27 | rays_end_point_in_box_frame = torch.matmul( 28 | (rays_end_point - box_center).unsqueeze(1), 29 | box_rotation_matrix) 30 | return (rays_start_point_in_box_frame.view(-1, 3), 31 | rays_end_point_in_box_frame.view(-1, 3)) 32 | 33 | 34 | def ray_box_intersection_pairwise(box_center, 35 | box_rotation_matrix, 36 | box_length, 37 | box_width, 38 | box_height, 39 | rays_start_point, 40 | rays_end_point, 41 | exclude_negative_t=False, 42 | exclude_enlarged_t=True, 43 | epsilon=0.000001): 44 | """Intersects a set of rays with a box. 45 | 46 | Note: The intersection points are returned in the box coordinate frame. 47 | Note: Make sure the start and end point of the rays are not the same. 48 | Note: Even though a start and end point is passed for each ray, rays are 49 | never ending and can intersect a box beyond their start / end points. 50 | 51 | Args: 52 | box_center: A tensor of size [3] or [r, 3]. 53 | box_rotation_matrix: A tensor of size [3, 3] or [r, 3, 3]. 54 | box_length: A scalar tensor or of size [r]. 55 | box_width: A scalar tensor or of size [r]. 56 | box_height: A scalar tensor or of size [r]. 57 | rays_start_point: A tensor of size [r, 3] where r is the number of rays. 58 | rays_end_point: A tensor of size [r, 3] there r is the number of rays. 59 | exclude_negative_t: bool. 60 | exclude_enlarged_t: bool. 61 | epsilon: A very small number. 62 | 63 | Returns: 64 | intersection_points_in_box_frame: A tensor of size [r', 2, 3] 65 | that contains intersection points in box coordinate frame. 66 | indices_of_intersecting_rays: A tensor of size [r']. 67 | intersection_ts: A tensor of size [r']. 68 | """ 69 | r = rays_start_point.size()[0] 70 | box_length = box_length.expand(r) 71 | box_width = box_width.expand(r) 72 | box_height = box_height.expand(r) 73 | box_center = torch.broadcast_to(box_center, (r, 3)) 74 | box_rotation_matrix = torch.broadcast_to(box_rotation_matrix, (r, 3, 3)) 75 | rays_start_point_in_box_frame, rays_end_point_in_box_frame = ( 76 | ray_to_box_coordinate_frame_pairwise( 77 | box_center=box_center, 78 | box_rotation_matrix=box_rotation_matrix, 79 | rays_start_point=rays_start_point, 80 | rays_end_point=rays_end_point)) 81 | rays_a = rays_end_point_in_box_frame - rays_start_point_in_box_frame 82 | intersection_masks = [] 83 | intersection_points = [] 84 | intersection_ts = [] 85 | box_size = [box_length, box_width, box_height] 86 | for axis in range(3): 87 | plane_value = box_size[axis] / 2.0 88 | for _ in range(2): 89 | plane_value = -plane_value 90 | # Compute the scalar multiples of 'rays_a' to apply in order to intersect 91 | # with the plane. 92 | t = ((plane_value - rays_start_point_in_box_frame[:, axis]) / # [R,] 93 | rays_a[:, axis]) 94 | # The current axis only intersects with plane if the ray is not parallel 95 | # with the plane. Note that this will result in 't' being +/- infinity, becasue 96 | # the ray component in the axis is zero, resulting in rays_a[:, axis] = 0. 
97 | intersects_with_plane = torch.abs(rays_a[:, axis]) > epsilon 98 | if exclude_negative_t: # Only allow at most one negative t 99 | t = torch.maximum(t, torch.tensor(0.0)) # [R,] 100 | if exclude_enlarged_t: 101 | t = torch.maximum(t, torch.tensor(1.0)) # [R,] 102 | intersection_ts.append(t) # [R, 1] 103 | intersection_points_i = [] 104 | 105 | # Initialize a mask which represents whether each ray intersects with the 106 | # current plane. 107 | intersection_masks_i = torch.ones_like(t, dtype=torch.int32).bool() # [R,] 108 | for axis2 in range(3): 109 | # Compute the point of intersection for the current axis. 110 | intersection_points_i_axis2 = ( # [R,] 111 | rays_start_point_in_box_frame[:, axis2] + t * rays_a[:, axis2]) 112 | intersection_points_i.append(intersection_points_i_axis2) # 3x [R,] 113 | 114 | # Update the intersection mask depending on whether the intersection 115 | # point is within bounds. 116 | intersection_masks_i = torch.logical_and( # [R,] 117 | torch.logical_and(intersection_masks_i, intersects_with_plane), 118 | torch.logical_and( 119 | intersection_points_i_axis2 <= (box_size[axis2] / 2.0 + epsilon), 120 | intersection_points_i_axis2 >= (-box_size[axis2] / 2.0 - epsilon))) 121 | intersection_points_i = torch.stack(intersection_points_i, dim=1) # [R, 3] 122 | intersection_masks.append(intersection_masks_i) # List of [R,] 123 | intersection_points.append(intersection_points_i) # List of [R, 3] 124 | intersection_ts = torch.stack(intersection_ts, dim=1) # [R, 6] 125 | intersection_masks = torch.stack(intersection_masks, dim=1) # [R, 6] 126 | intersection_points = torch.stack(intersection_points, dim=1) # [R, 6, 3] 127 | 128 | # Compute a mask over rays with exactly two plane intersections out of the six 129 | # planes. More intersections are possible if the ray coincides with a box 130 | # edge or corner, but we'll ignore these cases for now. 131 | counts = torch.sum(intersection_masks.int(), dim=1) # [R,] 132 | intersection_masks_any = torch.eq(counts, 2) # [R,] 133 | indices = torch.arange(intersection_masks_any.size()[0]).int() # [R,] 134 | # Apply the intersection masks over tensors. 135 | indices = indices[intersection_masks_any] # [R',] 136 | intersection_masks = intersection_masks[intersection_masks_any] # [R', 6] 137 | intersection_points = intersection_points[intersection_masks_any] # [R', 6, 3] 138 | intersection_points = intersection_points[intersection_masks].view(-1, 2, 3) # [R', 2, 3] 139 | # Ensure one or more positive ts. 140 | intersection_ts = intersection_ts[intersection_masks_any] # [R', 6] 141 | intersection_ts = intersection_ts[intersection_masks] # [R'*2] 142 | intersection_ts = intersection_ts.view(indices.size()[0], 2) # [R', 2] 143 | positive_ts_mask = (intersection_ts >= 0) # [R', 2] 144 | positive_ts_count = torch.sum(positive_ts_mask.int(), dim=1) # [R'] 145 | positive_ts_mask = (positive_ts_count >= 1) # [R'] 146 | intersection_points = intersection_points[positive_ts_mask] # [R'', 2, 3] 147 | false_indices = indices[torch.logical_not(positive_ts_mask)] # [R',] 148 | indices = indices[positive_ts_mask] # [R'',] 149 | if len(false_indices) > 0: 150 | intersection_masks_any[false_indices[:, None]] = torch.zeros(false_indices.size(), dtype=torch.bool) 151 | return rays_start_point_in_box_frame, intersection_masks_any, intersection_points, indices 152 | 153 | 154 | def compute_bounds_from_intersect_points(rays_o, intersect_indices, 155 | intersect_points): 156 | """Computes bounds from intersection points. 
157 | 158 | Note: Make sure that inputs are in the same coordiante frame. 159 | 160 | Args: 161 | rays_o: [R, 3] float tensor 162 | intersect_indices: [R', 1] float tensor 163 | intersect_points: [R', 2, 3] float tensor 164 | 165 | Returns: 166 | intersect_bounds: [R', 2] float tensor 167 | 168 | where R is the number of rays and R' is the number of intersecting rays. 169 | """ 170 | intersect_rays_o = rays_o[intersect_indices] # [R', 1, 3] 171 | intersect_diff = intersect_points - intersect_rays_o # [R', 2, 3] 172 | intersect_bounds = torch.norm(intersect_diff, dim=2) # [R', 2] 173 | 174 | # Sort the bounds so that near comes before far for all rays. 175 | intersect_bounds, _ = torch.sort(intersect_bounds, dim=1) # [R', 2] 176 | 177 | # For some reason the sort function returns [R', ?] instead of [R', 2], so we 178 | # will explicitly reshape it. 179 | intersect_bounds = intersect_bounds.view(-1, 2) # [R', 2] 180 | return intersect_bounds 181 | 182 | 183 | def compute_ray_bbox_bounds_pairwise(rays_o, rays_d, box_length, 184 | box_width, box_height, box_center, 185 | box_rotation, far_limit=1e10): 186 | """Computes near and far bounds for rays intersecting with bounding boxes. 187 | 188 | Note: rays and boxes are defined in world coordinate frame. 189 | 190 | Args: 191 | rays_o: [R, 3] float tensor. A set of ray origins. 192 | rays_d: [R, 3] float tensor. A set of ray directions. 193 | box_length: scalar or [R,] float tensor. Bounding box length. 194 | box_width: scalar or [R,] float tensor. Bounding box width. 195 | box_height: scalar or [R,] float tensor. Bounding box height. 196 | box_center: [3,] or [R, 3] float tensor. The center of the box. 197 | box_rotation: [3, 3] or [R, 3, 3] float tensor. The box rotation matrix. 198 | far_limit: float. The maximum far value to use. 199 | 200 | Returns: 201 | intersect_bounds: [R', 2] float tensor. The bounds per-ray, sorted in 202 | ascending order. 203 | intersect_indices: [R', 1] float tensor. The intersection indices. 204 | intersect_mask: [R,] float tensor. The mask denoting intersections. 205 | """ 206 | # Compute ray destinations. 207 | normalized_rays_d = ray_utils.normalize_rays(rays=rays_d) 208 | rays_dst = rays_o + far_limit * normalized_rays_d 209 | 210 | # Transform the rays from world to box coordinate frame. 211 | rays_o_in_box_frame, intersect_mask, intersect_points_in_box_frame, intersect_indices = ( # [R,], [R', 2, 3], [R', 2] 212 | ray_box_intersection_pairwise( 213 | box_center=box_center, 214 | box_rotation_matrix=box_rotation, 215 | box_length=box_length, 216 | box_width=box_width, 217 | box_height=box_height, 218 | rays_start_point=rays_o, 219 | rays_end_point=rays_dst)) 220 | intersect_indices = intersect_indices.unsqueeze(1).long() # [R', 1] 221 | intersect_bounds = compute_bounds_from_intersect_points( 222 | rays_o=rays_o_in_box_frame, 223 | intersect_indices=intersect_indices, 224 | intersect_points=intersect_points_in_box_frame) 225 | return intersect_bounds, intersect_indices, intersect_mask 226 | -------------------------------------------------------------------------------- /ObjectFolder1.0/cam_utils.py: -------------------------------------------------------------------------------- 1 | """Various camera utility functions.""" 2 | 3 | import numpy as np 4 | 5 | 6 | def w2c_to_c2w(w2c): 7 | """ 8 | Args: 9 | w2c: [N, 4, 4] np.float32. World-to-camera extrinsics matrix. 10 | 11 | Returns: 12 | c2w: [N, 4, 4] np.float32. Camera-to-world extrinsics matrix. 
13 | """ 14 | R = w2c[:3, :3] 15 | T = w2c[:3, 3] 16 | 17 | c2w = np.eye(4, dtype=np.float32) 18 | c2w[:3, 3] = -1 * np.dot(R.transpose(), w2c[:3, 3]) 19 | c2w[:3, :3] = R.transpose() 20 | return c2w 21 | -------------------------------------------------------------------------------- /ObjectFolder1.0/dataset_visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/ObjectFolder1.0/dataset_visualization.png -------------------------------------------------------------------------------- /ObjectFolder1.0/demo/audio_demo_forces.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/ObjectFolder1.0/demo/audio_demo_forces.npy -------------------------------------------------------------------------------- /ObjectFolder1.0/demo/audio_demo_vertices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/ObjectFolder1.0/demo/audio_demo_vertices.npy -------------------------------------------------------------------------------- /ObjectFolder1.0/demo/touch_demo_vertices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/ObjectFolder1.0/demo/touch_demo_vertices.npy -------------------------------------------------------------------------------- /ObjectFolder1.0/demo/vision_demo.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/ObjectFolder1.0/demo/vision_demo.npy -------------------------------------------------------------------------------- /ObjectFolder1.0/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime 3 | import sys 4 | import torch 5 | import numpy as np 6 | import imageio 7 | import json 8 | import random 9 | import time 10 | from tqdm import tqdm, trange 11 | from scipy.spatial import KDTree 12 | 13 | import indirect_utils 14 | from load_osf import load_osf_data 15 | from intersect import compute_object_intersect_tensors 16 | from ray_utils import transform_rays 17 | from run_osf_helpers import * 18 | from scatter import scatter_coarse_and_fine 19 | import shadow_utils 20 | 21 | import VisionNet_utils 22 | import AudioNet_utils 23 | import AudioNet_model 24 | import TouchNet_utils 25 | import TouchNet_model 26 | 27 | from scipy.io.wavfile import write 28 | import librosa 29 | import imageio 30 | 31 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 32 | 33 | def config_parser(): 34 | import configargparse 35 | parser = configargparse.ArgParser( 36 | config_file_parser_class=configargparse.YAMLConfigFileParser 37 | ) 38 | #parser = configargparse.ArgumentParser() 39 | parser.add_argument("--object_file_path", type=str, required=True, help='ObjectFile path') 40 | parser.add_argument('--config', is_config_file=True, help='config file path') 41 | 42 | # VisionNet options 43 | parser.add_argument("--vision_test_file_path", default='data/vision_demo.npy', help='The path of the testing file for vision, which should be a npy file.') 44 | 
parser.add_argument("--vision_results_dir", type=str, default='./results/vision/', help='The path of the vision results directory to save rendered images.') 45 | parser.add_argument("--chunk", type=int, default=1024*32, 46 | help='number of rays processed in parallel, decrease if running out of memory') 47 | parser.add_argument("--netchunk", type=int, default=1024*64, 48 | help='number of pts sent through network in parallel, decrease if running out of memory') 49 | 50 | # AudioNet options 51 | parser.add_argument('--audio_vertices_file_path', default='./data/audio_demo_vertices.npy', help='The path of the testing vertices file for audio, which should be a npy file.') 52 | parser.add_argument('--audio_forces_file_path', default='./data/forces.npy', help='The path of forces file for audio, which should be a npy file.') 53 | parser.add_argument('--audio_batchSize', type=int, default=10000, help='input batch size') 54 | parser.add_argument('--audio_results_dir', type=str, default='./results/audio/', help='The path of audio results directory to save rendered impact sounds as .wav files.') 55 | 56 | # TouchNet options 57 | parser.add_argument('--touch_vertices_file_path', default='./data/touch_demo_vertices.npy', help='The path of the testing vertices file for touch, which should be a npy file.') 58 | parser.add_argument('--touch_batchSize', type=int, default=10000, help='input batch size') 59 | parser.add_argument('--touch_results_dir', type=str, default='./results/touch/', help='The path of the touch results directory to save rendered tactile RGB images.') 60 | 61 | return parser 62 | 63 | 64 | def VisionNet_eval(args): 65 | 66 | args.secondary_chunk = args.chunk 67 | 68 | metadata, render_metadata = None, None 69 | near = 0.01 70 | far = 4 71 | 72 | poses, hwf, i_split, metadata = load_osf_data(args.vision_test_file_path) 73 | i_test = i_split[0] 74 | 75 | render_poses = np.array(poses[i_test]) 76 | 77 | # Create dummy metadata if not loaded from dataset. 
78 | if metadata is None: 79 | metadata = torch.tensor([[0, 0, 1]] * len(poses), dtype=torch.float) # [N, 3] 80 | if render_metadata is None: 81 | render_metadata = metadata 82 | 83 | # Cast intrinsics to right types 84 | H, W, focal = hwf 85 | H, W = int(H), int(W) 86 | hwf = [H, W, focal] 87 | 88 | # Create nerf model 89 | render_kwargs_train, render_kwargs_test, start, grad_vars, models, optimizer = VisionNet_utils.create_nerf( 90 | args, metadata, render_metadata) 91 | global_step = start 92 | 93 | bds_dict = { 94 | 'near': torch.tensor(near).float(), 95 | 'far': torch.tensor(far).float(), 96 | } 97 | render_kwargs_train.update(bds_dict) 98 | render_kwargs_test.update(bds_dict) 99 | 100 | render_kwargs_test['render_metadata'] = metadata 101 | 102 | # Move testing data to GPU 103 | render_poses = torch.Tensor(render_poses).to(device) 104 | 105 | with torch.no_grad(): 106 | images = None 107 | 108 | testsavedir = args.vision_results_dir 109 | os.makedirs(testsavedir, exist_ok=True) 110 | print('Begin rendering images in ', testsavedir) 111 | rgbs, _ = VisionNet_utils.render_path(render_poses, hwf, args.chunk, render_kwargs_test, 112 | gt_imgs=images, savedir=testsavedir, 113 | c2w_staticcam=None, render_start=None, 114 | render_end=None) 115 | print('Done rendering images in ', testsavedir) 116 | 117 | 118 | def AudioNet_eval(args): 119 | dim = 32 120 | audio_sampling_rate = 16000 121 | audio_window_size = 400 122 | audio_hop_size = 160 123 | audio_stft_time_dim = 201 124 | audio_stft_freq_dim = 257 125 | audio_network_depth = 8 126 | xyz = np.load(args.audio_vertices_file_path) 127 | forces = np.load(args.audio_forces_file_path) 128 | xyz = xyz[0:1,:] 129 | forces = forces[0:1] 130 | N = xyz.shape[0] 131 | print(N) 132 | #N: number of data 133 | #D: number of dimension 134 | #C: number of channels (real and imaginary parts) 135 | #F: number of frequency 136 | #T: number of timestamps 137 | N, D, C, F, T = N, 3, 2, audio_stft_freq_dim, audio_stft_time_dim 138 | 139 | checkpoint = torch.load(args.object_file_path) 140 | normalizer_dic = checkpoint['AudioNet']['normalizer'] 141 | voxel_vertex = checkpoint['AudioNet']['voxel_vertex'] 142 | vert_tree = KDTree(voxel_vertex) 143 | translation = checkpoint['AudioNet']['translation'] 144 | scale = checkpoint['AudioNet']['scale'] 145 | 146 | k = 4 # Average over 4 nearest neighbors 147 | xyz_in_voxel = np.zeros((4, N, 3)) 148 | for i in range(N): 149 | obj_coordinates = xyz[i] 150 | binvox_coordinates = AudioNet_utils.transform_mesh_collision_binvox(obj_coordinates, translation, scale) 151 | coordinates_in_voxel = binvox_coordinates * dim 152 | voxel_verts_index = vert_tree.query(coordinates_in_voxel, k)[1] 153 | for j in range(k): 154 | xyz_in_voxel[j, i] = voxel_vertex[voxel_verts_index[j]] 155 | 156 | xyz_in_voxel = np.repeat(xyz_in_voxel.reshape((4, N, 1, 3)), F * T, axis=2) 157 | #normalize xyz_in_voxel to [-1, 1] 158 | xyz_in_voxel_min = xyz_in_voxel.min() 159 | xyz_in_voxel_max = xyz_in_voxel.max() 160 | xyz_in_voxel = (xyz_in_voxel - xyz_in_voxel_min) / (xyz_in_voxel_max - xyz_in_voxel_min) 161 | 162 | spec_comps_f1_min = normalizer_dic['f1_min'] 163 | spec_comps_f1_max = normalizer_dic['f1_max'] 164 | spec_comps_f2_min = normalizer_dic['f2_min'] 165 | spec_comps_f2_max = normalizer_dic['f2_max'] 166 | spec_comps_f3_min = normalizer_dic['f3_min'] 167 | spec_comps_f3_max = normalizer_dic['f3_max'] 168 | 169 | #initialize frequency and time features 170 | freq_feats = np.repeat(np.repeat(np.arange(F).reshape((F, 1)), T, axis=1).reshape((1, 1, F, T)),
N, axis=0) 171 | time_feats = np.repeat(np.repeat(np.arange(T).reshape((1, T)), F, axis=0).reshape((1, 1, F, T)), N, axis=0) 172 | 173 | #normalize frequency and time features to [-1, 1] 174 | freq_feats_min = freq_feats.min() 175 | freq_feats_max = freq_feats.max() 176 | time_feats_min = time_feats.min() 177 | time_feats_max = time_feats.max() 178 | freq_feats = (freq_feats - freq_feats_min) / (freq_feats_max - freq_feats_min) 179 | time_feats = (time_feats - time_feats_min) / (time_feats_max - time_feats_min) 180 | 181 | data_x = np.concatenate((freq_feats, time_feats), axis=1) 182 | data_y = np.concatenate((freq_feats, time_feats), axis=1) 183 | data_z = np.concatenate((freq_feats, time_feats), axis=1) 184 | data_x = np.transpose(data_x.reshape((N, 2, -1)), axes = [0, 2, 1]) 185 | data_y = np.transpose(data_y.reshape((N, 2, -1)), axes = [0, 2, 1]) 186 | data_z = np.transpose(data_z.reshape((N, 2, -1)), axes = [0, 2, 1]) 187 | data_x = np.repeat(data_x.reshape((1, N, -1, 2)), k, axis=0) 188 | data_y = np.repeat(data_y.reshape((1, N, -1, 2)), k, axis=0) 189 | data_z = np.repeat(data_z.reshape((1, N, -1, 2)), k, axis=0) 190 | 191 | #Now concatenate xyz and feats to get final feats matrix as [x, y, z, f, t, real, img] 192 | feats_x = np.concatenate((xyz_in_voxel, data_x), axis=3).reshape((-1, 5)) 193 | feats_y = np.concatenate((xyz_in_voxel, data_y), axis=3).reshape((-1, 5)) 194 | feats_z = np.concatenate((xyz_in_voxel, data_z), axis=3).reshape((-1, 5)) 195 | 196 | embed_fn, input_ch = AudioNet_model.get_embedder(10, 0) 197 | model = AudioNet_model.AudioNeRF(D = audio_network_depth, input_ch = input_ch) 198 | state_dic = checkpoint['AudioNet']["model_state_dict"] 199 | state_dic = AudioNet_utils.strip_prefix_if_present(state_dic, 'module.') 200 | model.load_state_dict(state_dic) 201 | model = nn.DataParallel(model).to(device) 202 | model.eval() 203 | loss_fn = torch.nn.MSELoss(reduction='mean') 204 | 205 | start_time = time.time() 206 | preds_x = np.zeros((feats_x.shape[0], 2)) 207 | preds_y = np.zeros((feats_y.shape[0], 2)) 208 | preds_z = np.zeros((feats_z.shape[0], 2)) 209 | N_rand = args.audio_batchSize 210 | 211 | print("Begin rendering impact sounds in ", args.audio_results_dir) 212 | for i in trange(feats_x.shape[0] // N_rand + 1): 213 | curr_feats_x = torch.Tensor(feats_x[i*N_rand:(i+1)*N_rand]).to(device) 214 | curr_feats_y = torch.Tensor(feats_y[i*N_rand:(i+1)*N_rand]).to(device) 215 | curr_feats_z = torch.Tensor(feats_z[i*N_rand:(i+1)*N_rand]).to(device) 216 | embedded_x = embed_fn(curr_feats_x) 217 | embedded_y = embed_fn(curr_feats_y) 218 | embedded_z = embed_fn(curr_feats_z) 219 | results_x, results_y, results_z = model(embedded_x, embedded_y, embedded_z) 220 | 221 | preds_x[i*N_rand:(i+1)*N_rand, :] = results_x.detach().cpu().numpy() 222 | preds_y[i*N_rand:(i+1)*N_rand, :] = results_y.detach().cpu().numpy() 223 | preds_z[i*N_rand:(i+1)*N_rand, :] = results_z.detach().cpu().numpy() 224 | 225 | preds_x = preds_x * (spec_comps_f1_max - spec_comps_f1_min) + spec_comps_f1_min 226 | preds_y = preds_y * (spec_comps_f2_max - spec_comps_f2_min) + spec_comps_f2_min 227 | preds_z = preds_z * (spec_comps_f3_max - spec_comps_f3_min) + spec_comps_f3_min 228 | preds_x = np.transpose(preds_x.reshape((k, N, -1, 2)), axes = [0, 1, 3, 2]).reshape((k, N, 1, C, F, T)) 229 | preds_y = np.transpose(preds_y.reshape((k, N, -1, 2)), axes = [0, 1, 3, 2]).reshape((k, N, 1, C, F, T)) 230 | preds_z = np.transpose(preds_z.reshape((k, N, -1, 2)), axes = [0, 1, 3, 2]).reshape((k, N, 1, C, F, T)) 231 | 232 | 
#save evaluation results 233 | os.makedirs(args.audio_results_dir, exist_ok=True) 234 | 235 | for i in trange(N): 236 | force_x, force_y, force_z = forces[i] 237 | signal = np.zeros(audio_sampling_rate*2) 238 | for j in range(k): 239 | spec_x = preds_x[j, i, 0, 0, :, :] + preds_x[j, i, 0, 1, :, :] * 1j 240 | signal_x = librosa.istft(spec_x, hop_length=audio_hop_size, win_length=audio_window_size, length=audio_sampling_rate*2) 241 | spec_y = preds_y[j, i, 0, 0, :, :] + preds_y[j, i, 0, 1, :, :] * 1j 242 | signal_y = librosa.istft(spec_y, hop_length=audio_hop_size, win_length=audio_window_size, length=audio_sampling_rate*2) 243 | spec_z = preds_z[j, i, 0, 0, :, :] + preds_z[j, i, 0, 1, :, :] * 1j 244 | signal_z = librosa.istft(spec_z, hop_length=audio_hop_size, win_length=audio_window_size, length=audio_sampling_rate*2) 245 | temp = signal_x * force_x + signal_y * force_y + signal_z * force_z 246 | signal += temp 247 | signal = signal / np.abs(signal).max() 248 | end_time = time.time() 249 | print(end_time - start_time) 250 | # Write WAV file 251 | output_path = os.path.join(args.audio_results_dir, str(i+1) + '.wav') 252 | write(output_path, audio_sampling_rate, signal.astype(np.float32)) 253 | print('Done rendering impact sounds in ', args.audio_results_dir) 254 | 255 | 256 | def TouchNet_eval(args): 257 | touch_network_depth = 8 258 | 259 | xyz = np.load(args.touch_vertices_file_path) 260 | 261 | #N: number of data 262 | #C: channels 263 | #W: Width dimension 264 | #H: Height dimension 265 | #N, C, W, H = touch_images.shape 266 | N, C, W, H = xyz.shape[0], 3, 160, 120 267 | 268 | #initialize frequency and time features 269 | w_feats = np.repeat(np.repeat(np.arange(W).reshape((W, 1)), H, axis=1).reshape((1, 1, W, H)), N, axis=0) 270 | h_feats = np.repeat(np.repeat(np.arange(H).reshape((1, H)), W, axis=0).reshape((1, 1, W, H)), N, axis=0) 271 | 272 | checkpoint = torch.load(args.object_file_path) 273 | 274 | #normalize frequency and time features to [-1, 1] 275 | w_feats_min = w_feats.min() 276 | w_feats_max = w_feats.max() 277 | h_feats_min = h_feats.min() 278 | h_feats_max = h_feats.max() 279 | w_feats = 2 * ((w_feats - w_feats_min) / w_feats_max) - 1 280 | h_feats = 2 * ((h_feats - h_feats_min) / h_feats_max) - 1 281 | 282 | data_x = np.concatenate((w_feats, h_feats), axis=1) 283 | data_x = np.transpose(data_x.reshape((N, 2, -1)), axes = [0, 2, 1]) 284 | 285 | xyz = np.repeat(xyz.reshape((N, 1, 3)), W * H, axis=1) 286 | 287 | #normalize xyz to [-1, 1] 288 | xyz_min = xyz.min() 289 | xyz_max = xyz.max() 290 | xyz = 2 * ((xyz - xyz_min) / xyz_max) - 1 291 | 292 | #Now concatenate xyz and feats to get final feats matrix as [x, y, z, w, h, r, g, b] 293 | data= np.concatenate((xyz, data_x), axis=2).reshape((-1, 5)) 294 | feats = data 295 | 296 | embed_fn, input_ch = TouchNet_model.get_embedder(10, 0) 297 | model = TouchNet_model.NeRF(D = touch_network_depth, input_ch = input_ch, output_ch = 3) 298 | state_dic = checkpoint['TouchNet']["model_state_dict"] 299 | state_dic = TouchNet_utils.strip_prefix_if_present(state_dic, 'module.') 300 | model.load_state_dict(state_dic) 301 | model = nn.DataParallel(model).to(device) 302 | model.eval() 303 | loss_fn = torch.nn.MSELoss(reduction='mean') 304 | 305 | preds = np.zeros((feats.shape[0], 3)) 306 | N_rand = args.touch_batchSize 307 | 308 | print("Begin rendering tactile images in ", args.touch_results_dir) 309 | start_time = time.time() 310 | for i in trange(feats.shape[0] // N_rand + 1): 311 | curr_feats = 
torch.Tensor(feats[i*N_rand:(i+1)*N_rand]).to(device) 312 | embedded = embed_fn(curr_feats) 313 | results = model(embedded) 314 | preds[i*N_rand:(i+1)*N_rand, :] = results.detach().cpu().numpy() 315 | end_time = time.time() 316 | print(end_time - start_time) 317 | 318 | preds = (((preds + 1) / 2) * 255) 319 | preds = np.transpose(preds.reshape((N, -1, 3)), axes = [0, 2, 1]).reshape((N, C, W, H)) 320 | preds = np.clip(np.rint(preds), 0, 255).astype(np.uint8) 321 | preds = preds.transpose(0,2,3,1) 322 | 323 | os.makedirs(args.touch_results_dir, exist_ok=True) 324 | #save evaluation results 325 | for i in trange(N): 326 | filename = os.path.join(args.touch_results_dir, '{}.png'.format(i+1)) 327 | imageio.imwrite(filename, preds[i]) 328 | print("Done rendering tactile images in ", args.touch_results_dir) 329 | 330 | if __name__ == '__main__': 331 | parser = config_parser() 332 | args = parser.parse_args() 333 | 334 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 335 | 336 | VisionNet_eval(args=args) 337 | AudioNet_eval(args=args) 338 | TouchNet_eval(args=args) 339 | -------------------------------------------------------------------------------- /ObjectFolder1.0/indirect_utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for indirect illumination.""" 2 | import math 3 | 4 | import torch 5 | 6 | import ray_utils 7 | 8 | 9 | def create_ray_batch(pts, near, far, rays_i, use_viewdirs): 10 | """Create batch for indirect rays. 11 | 12 | Args: 13 | pts: [R, S, 3] float tensor. Primary points. 14 | near: Near sampling bound. 15 | far: Far sampling bound. 16 | rays_i: [R, 1] float tensor. Ray image IDs. 17 | use_viewdirs: bool. Whether to use view directions. 18 | 19 | Returns: 20 | ray_batch: [RS, M] float tensor. Batch of secondary rays containing one secondary 21 | ray for each primary ray sample. Each secondary ray originates at a primary 22 | point and points in the direction towards the randomly sampled (indirect) light 23 | source. 24 | """ 25 | num_primary_rays = pts.size()[0] 26 | num_primary_samples = pts.size()[1] 27 | 28 | rays_dst = ray_utils.sample_random_lightdirs(num_primary_rays, num_primary_samples) # [R, S, 3] 29 | 30 | rays_o = pts.view(-1, 3) # [RS, 3] 31 | rays_dst = rays_dst.view(-1, 3) # [RS, 3] 32 | 33 | rays_near = torch.full((rays_o.size()[0], 1), near).float() # [RS, 1] 34 | rays_far = torch.full((rays_o.size()[0], 1), far).float() # [RS, 1] 35 | 36 | rays_i = torch.tile(rays_i[:, None, :], (1, num_primary_samples, 1)) # [R?, S, 1] 37 | rays_i = rays_i.view(-1, 1) # [RS, 1] 38 | 39 | ray_batch = ray_utils.create_ray_batch( 40 | rays_o=rays_o, rays_dst=rays_dst, rays_near=rays_near, rays_far=rays_far, 41 | rays_i=rays_i, use_viewdirs=use_viewdirs) 42 | 43 | return ray_batch 44 | -------------------------------------------------------------------------------- /ObjectFolder1.0/intersect.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import box_utils 4 | import ray_utils 5 | 6 | 7 | def apply_mask_to_tensors(mask, tensors): 8 | """Applies mask to a list of tensors. 9 | 10 | Args: 11 | mask: [R, ...]. Mask to apply. 12 | tensors: List of [R, ...]. List of tensors. 13 | 14 | Returns: 15 | intersect_tensors: List of shape [R?, ...]. Masked tensors. 16 | """ 17 | intersect_tensors = [] 18 | for t in tensors: 19 | intersect_t = t[mask] # [R?, ...]
20 | intersect_tensors.append(intersect_t) 21 | return intersect_tensors 22 | 23 | 24 | def get_full_intersection_tensors(ray_batch): 25 | """Test case that selects all rays.""" 26 | # mask = torch.rand(ray_batch.size())[:, 0] # [R?,] 27 | # mask = (mask > 0.5).bool() 28 | mask = torch.ones_like(ray_batch, dtype=torch.bool)[:, 0] 29 | 30 | n_intersect = mask.size()[0] # R? 31 | indices = torch.arange(n_intersect, dtype=torch.int).unsqueeze(1) # [R?, 1] 32 | indices = apply_mask_to_tensors( # [R?, M] 33 | mask=mask, # [R,] 34 | tensors=[indices])[0] # [R, M] 35 | 36 | bounds = ray_batch[:, 6:8] # [R?,] 37 | bounds = apply_mask_to_tensors( # [R?, M] 38 | mask=mask, # [R,] 39 | tensors=[bounds])[0] # [R, M] 40 | return mask, indices, bounds 41 | 42 | 43 | def compute_object_intersect_tensors(ray_batch, box_center, box_dims): 44 | """Compute rays that intersect with bounding boxes. 45 | 46 | Args: 47 | ray_batch: [R, M] float tensor. Batch of rays. 48 | box_center: List of 3 floats containing the (x, y, z) center of bbox. 49 | box_dims: List of 3 floats containing the x, y, z dimensions of the bbox. 50 | 51 | Returns: 52 | intersect_ray_batch: [R?, M] float tensor. Batch of intersecting rays. 53 | indices: [R?, 1] float tensor. Indices of intersecting rays. 54 | """ 55 | # Check that bbox params are properly formed. 56 | for lst in [box_center, box_dims]: 57 | assert type(lst) == list 58 | assert len(lst) == 3 59 | assert all((isinstance(x, int) or isinstance(x, float)) for x in lst) 60 | 61 | # For now, we assume bbox has no rotation. 62 | num_rays = ray_batch.size()[0] # R 63 | box_center = torch.tile(torch.tensor(box_center), (num_rays, 1)).float() # [R, 3] 64 | box_dims = torch.tile(torch.tensor(box_dims), (num_rays, 1)).float() # [R, 3] 65 | box_rotation = torch.tile(torch.eye(3).unsqueeze(0), (num_rays, 1, 1)).float() # [R, 3, 3] 66 | 67 | # Compute ray-bbox intersections. 68 | bounds, indices, mask = box_utils.compute_ray_bbox_bounds_pairwise( # [R', 2], [R',], [R,] 69 | rays_o=ray_batch[:, 0:3], # [R, 3] 70 | rays_d=ray_batch[:, 3:6], # [R, 3] 71 | box_length=box_dims[:, 0], # [R,] 72 | box_width=box_dims[:, 1], # [R,] 73 | box_height=box_dims[:, 2], # [R,] 74 | box_center=box_center, # [R, 3] 75 | box_rotation=box_rotation) # [R, 3, 3] 76 | 77 | # Apply the intersection mask to the ray batch. 78 | intersect_ray_batch = apply_mask_to_tensors( # [R?, M] 79 | mask=mask, # [R,] 80 | tensors=[ray_batch])[0] # [R, M] 81 | 82 | # Update the near and far bounds of the ray batch with the intersect bounds. 
83 | intersect_ray_batch = ray_utils.update_ray_batch_bounds( # [R?, M] 84 | ray_batch=intersect_ray_batch, # [R?, M] 85 | bounds=bounds) # [R?, 2] 86 | return intersect_ray_batch, indices, mask # [R?, M], [R?, 1] 87 | -------------------------------------------------------------------------------- /ObjectFolder1.0/load_osf.py: -------------------------------------------------------------------------------- 1 | """Data loader for OSF data.""" 2 | 3 | import os 4 | import torch 5 | import numpy as np 6 | import imageio 7 | import json 8 | 9 | import cam_utils 10 | 11 | 12 | trans_t = lambda t: torch.tensor([ 13 | [1,0,0,0], 14 | [0,1,0,0], 15 | [0,0,1,t], 16 | [0,0,0,1] 17 | ], dtype=torch.float) 18 | 19 | rot_phi = lambda phi: torch.tensor([ 20 | [1,0,0,0], 21 | [0,np.cos(phi),-np.sin(phi),0], 22 | [0,np.sin(phi), np.cos(phi),0], 23 | [0,0,0,1] 24 | ], dtype=torch.float) 25 | 26 | rot_theta = lambda th: torch.tensor([ 27 | [np.cos(th),0,-np.sin(th),0], 28 | [0,1,0,0], 29 | [np.sin(th),0, np.cos(th),0], 30 | [0,0,0,1] 31 | ], dtype=torch.float) 32 | 33 | 34 | def pose_spherical(theta, phi, radius): 35 | c2w = trans_t(radius) 36 | c2w = rot_phi(phi/180.*np.pi) @ c2w 37 | c2w = rot_theta(theta/180.*np.pi) @ c2w 38 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w 39 | return c2w 40 | 41 | def coordinates_to_c2w(x, y, z, r=2.5): 42 | theta = np.arccos(z / r) 43 | phi = np.arctan2(x, -y) 44 | Rx = np.array([[1, 0, 0], [0, np.cos(theta), -np.sin(theta)], [0, np.sin(theta), np.cos(theta)]]) 45 | Rz = np.array([[np.cos(phi), -np.sin(phi), 0], [np.sin(phi), np.cos(phi), 0], [0, 0, 1]]) 46 | R = Rz @ Rx 47 | c2w = R.tolist() 48 | c2w[0].append(x) 49 | c2w[1].append(y) 50 | c2w[2].append(z) 51 | c2w.append([0., 0., 0., 1.]) 52 | #c2w = np.array(c2w).astype(np.float32) 53 | return c2w 54 | 55 | def convert_cameras_to_nerf_format(anno): 56 | """ 57 | Args: 58 | anno: List of annotations for each example. Each annotation is represented by a 59 | dictionary that must contain the key `RT` which is the world-to-camera 60 | extrinsics matrix with shape [3, 4], in [right, down, forward] coordinates. 61 | 62 | Returns: 63 | c2w: [N, 4, 4] np.float32. Array of camera-to-world extrinsics matrices in 64 | [right, up, backwards] coordinates. 65 | """ 66 | c2w_list = [] 67 | for a in anno: 68 | # Convert from w2c to c2w. 
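# a['RT'] is the [3, 4] world-to-camera matrix; appending the homogeneous row [0, 0, 0, 1]
# turns it into a full 4x4 matrix before it is inverted with cam_utils.w2c_to_c2w.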
69 | w2c = np.array(a['RT'] + [[0.0, 0.0, 0.0, 1.0]]) 70 | c2w = cam_utils.w2c_to_c2w(w2c) 71 | 72 | # Convert from [right, down, forwards] to [right, up, backwards] 73 | c2w[:3, 1] *= -1 # down -> up 74 | c2w[:3, 2] *= -1 # forwards -> back 75 | c2w_list.append(c2w) 76 | c2w = np.array(c2w_list) 77 | print("c2w: ", c2w) 78 | return c2w 79 | 80 | 81 | def load_osf_data(test_file_path): 82 | 83 | all_poses = [] 84 | all_metadata = [] 85 | counts = [0] 86 | test_file = np.load(test_file_path) 87 | N = test_file.shape[0] 88 | for i in range(N): 89 | cx, cy, cz, lx, ly, lz = test_file[i] 90 | poses = coordinates_to_c2w(cx, cy, cz) 91 | metadata = np.array([[lx, ly, lz]]).astype(np.float32) 92 | all_poses.append(poses) 93 | all_metadata.append(metadata) 94 | 95 | poses = np.array(all_poses).astype(np.float32) 96 | 97 | metadata = np.concatenate(all_metadata, 0) 98 | counts.append(N) 99 | i_split = [np.arange(counts[0], counts[1])] 100 | 101 | H, W, focal = 256, 256, 355.5555419921875 102 | 103 | return poses, [H, W, focal], i_split, metadata 104 | -------------------------------------------------------------------------------- /ObjectFolder1.0/ray_utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for ray computation.""" 2 | import math 3 | import numpy as np 4 | from scipy.spatial.transform import Rotation as R 5 | import box_utils 6 | import torch 7 | 8 | 9 | def apply_batched_transformations(inputs, transformations): 10 | """Batched transformation of inputs. 11 | 12 | Args: 13 | inputs: List of [R, S, 3] 14 | transformations: [R, 4, 4] 15 | 16 | Returns: 17 | transformed_inputs: List of [R, S, 3] 18 | """ 19 | # if rotation_only: 20 | # transformations[:, :3, 3] = torch.zeros((3,), dtype=torch.float) 21 | 22 | transformed_inputs = [] 23 | for x in inputs: 24 | N_samples = x.size()[1] 25 | homog_transformations = transformations.unsqueeze(1) # [R, 1, 4, 4] 26 | homog_transformations = torch.tile(homog_transformations, (1, N_samples, 1, 1)) # [R, S, 4, 4] 27 | homog_component = torch.ones_like(x)[..., 0:1] # [R, S, 1] 28 | homog_x = torch.cat((x, homog_component), axis=-1) # [R, S, 4] 29 | homog_x = homog_x.unsqueeze(2) 30 | transformed_x = torch.matmul( 31 | homog_x, 32 | torch.transpose(homog_transformations, 2, 3)) # [R, S, 1, 4] 33 | transformed_x = transformed_x[..., 0, :3] # [R, S, 3] 34 | transformed_inputs.append(transformed_x) 35 | return transformed_inputs 36 | 37 | 38 | def get_transformation_from_params(params): 39 | translation, rotation = [0, 0, 0], [0, 0, 0] 40 | if 'translation' in params: 41 | translation = params['translation'] 42 | if 'rotation' in params: 43 | rotation = params['rotation'] 44 | translation = torch.tensor(translation, dtype=torch.float) 45 | rotmat = torch.tensor(R.from_euler('xyz', rotation, degrees=True).as_matrix(), dtype=torch.float) 46 | return translation, rotmat 47 | 48 | 49 | def rotate_dirs(dirs, rotmat): 50 | """ 51 | Args: 52 | dirs: [R, 3] float tensor. 
53 | rotmat: [3, 3] 54 | """ 55 | if type(dirs) == np.ndarray: 56 | dirs = torch.tensor(dirs).float() 57 | #rotmat = rotmat.unsqueeze(0) 58 | rotmat = torch.broadcast_to(rotmat, (dirs.shape[0], 3, 3)) # [R, 3, 3] 59 | dirs_obj = torch.matmul(dirs.unsqueeze(1), torch.transpose(rotmat, 1, 2)) # [R, 1, 3] 60 | dirs_obj = dirs_obj.squeeze(1) # [R, 3] 61 | return dirs_obj 62 | 63 | 64 | def transform_dirs(dirs, params, inverse=False): 65 | _, rotmat = get_transformation_from_params(params) # [3,], [3, 3] 66 | if inverse: 67 | rotmat = torch.transpose(rotmat, 0, 1) # [3, 3] 68 | dirs_transformed = rotate_dirs(dirs, rotmat) 69 | return dirs_transformed 70 | 71 | 72 | def transform_rays(ray_batch, params, use_viewdirs, inverse=False): 73 | """Transform rays into object coordinate frame given o2w transformation params. 74 | 75 | Note: do not assume viewdirs is always the normalized version of rays_d (e.g., in staticcam case). 76 | 77 | Args: 78 | ray_batch: [R, M] float tensor. Batch of rays. 79 | params: Dictionary containing transformation parameters: 80 | 'translation': List of 3 elements. xyz translation. 81 | 'rotation': List of 3 euler angles in xyz. 82 | use_viewdirs: bool. Whether to we are using viewdirs. 83 | inverse: bool. Whether to apply inverse of the transformations provided in 'params'. 84 | 85 | Returns: 86 | ray_batch_obj: [R, M] float tensor. The ray batch, in object coordinate frame. 87 | """ 88 | rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6] 89 | translation, rotmat = get_transformation_from_params(params) # [3,], [3, 3] 90 | 91 | if inverse: 92 | translation = -1 * translation # [3,] 93 | rotmat = torch.transpose(rotmat, 1, 0) # [3, 3] 94 | 95 | translation_inverse = -1 * translation 96 | rotmat_inverse = torch.transpose(rotmat, 1, 0) 97 | 98 | # Transform the ray origin. 99 | rays_o_obj, _ = box_utils.ray_to_box_coordinate_frame_pairwise( 100 | box_center=translation_inverse, 101 | box_rotation_matrix=rotmat_inverse, 102 | rays_start_point=rays_o, 103 | rays_end_point=rays_d) 104 | 105 | # Only apply rotation to rays_d. 106 | rays_d_obj = rotate_dirs(rays_d, rotmat) 107 | 108 | ray_batch_obj = update_ray_batch_slice(ray_batch, rays_o_obj, 0, 3) 109 | ray_batch_obj = update_ray_batch_slice(ray_batch_obj, rays_d_obj, 3, 6) 110 | if use_viewdirs: 111 | # Grab viewdirs from the ray batch itself. Because it may be different from rays_d 112 | # (as in the staticcam case). 113 | viewdirs = ray_batch[:, 8:11] 114 | viewdirs_obj = rotate_dirs(viewdirs, rotmat) 115 | ray_batch_obj = update_ray_batch_slice(ray_batch_obj, viewdirs_obj, 8, 11) 116 | return ray_batch_obj 117 | 118 | 119 | def transform_points_into_world_coordinate_frame(pts, params, check_numerics=False): 120 | translation, rotmat = get_transformation_from_params(params) # [3,], [3, 3] 121 | 122 | # pts_flat = pts.view(-1, 3) # [RS, 3] 123 | # num_examples = pts_flat.size()[0] # RS 124 | 125 | # translation = translation.unsqueeze(0) 126 | # translation = torch.tile(translation, (num_examples, 1)) # [RS, 3] 127 | # rotmat = rotmat.unsqueeze(0) 128 | # rotmat = torch.tile(rotmat, (num_examples, 1, 1)) 129 | 130 | # # pts_flat_transformed = torch.matmul(pts_flat[:, None, :], torch.transpose(rotmat, 2, 1)) # [RS, 1, 3] 131 | # pts_flat_transformed = pts_flat[:, None, :] # [RS, 1, 3] 132 | # pts_flat_transformed += translation[:, None, :] # [RS, 1, 3] 133 | # pts_transformed = pts_flat_transformed.view(pts.size()) # [R, S, 3] 134 | chunk = 256 135 | # Check batch transformations works without rotation. 
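# The check_numerics path computes the translation-only transform twice (batched 4x4 matmul
# vs. a plain broadcast add) so the two results can be checked against each other before the
# rotation is included.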
136 | if check_numerics: 137 | transformations = np.eye(4) 138 | transformations[:3, 3] = translation 139 | transformations = torch.tensor(transformations, dtype=torch.float) # [4, 4] 140 | transformations = torch.tile(transformations[None, ...], (pts.size()[0], 1, 1)) # [R, 4, 4] 141 | pts_transformed1 = [] 142 | for i in range(0, pts.size()[0], chunk): 143 | pts_transformed1_chunk = apply_batched_transformations( 144 | inputs=[pts[i:i+chunk]], transformations=transformations[i:i+chunk])[0] 145 | pts_transformed1.append(pts_transformed1_chunk) 146 | pts_transformed1 = torch.cat(pts_transformed1, dim=0) 147 | 148 | pts_transformed2 = pts + translation[None, None, :] 149 | 150 | # Now add rotation 151 | transformations = np.eye(4) 152 | transformations = torch.tensor(transformations, dtype=torch.float) 153 | transformations[:3, :3] = rotmat 154 | transformations[:3, 3] = translation 155 | #transformations = torch.tensor(transformations, dtype=torch.float) # [4, 4] 156 | transformations = torch.tile(transformations[None, ...], (pts.size()[0], 1, 1)) # [R, 4, 4] 157 | pts_transformed = [] 158 | for i in range(0, pts.size()[0], chunk): 159 | pts_transformed_chunk = apply_batched_transformations( 160 | inputs=[pts[i:i+chunk]], transformations=transformations[i:i+chunk])[0] 161 | pts_transformed.append(pts_transformed_chunk) 162 | pts_transformed = torch.cat(pts_transformed, dim=0) 163 | return pts_transformed 164 | 165 | 166 | # def transform_rays(ray_batch, translation, use_viewdirs): 167 | # """Apply transformation to rays. 168 | 169 | # Args: 170 | # ray_batch: [R, M] float tensor. All information necessary 171 | # for sampling along a ray, including: ray origin, ray direction, min 172 | # dist, max dist, and unit-magnitude viewing direction. 173 | # translation: [3,] float tensor. The (x, y, z) translation to apply. 174 | # use_viewdirs: Whether to use view directions. 175 | 176 | # Returns: 177 | # ray_batch: [R, M] float tensor. Transformed ray batch. 178 | # """ 179 | # assert translation.size()[0] == 3, "translation.size()[0] must be 3..." 180 | 181 | # # Since we are only supporting translation for now, only ray origins need to be 182 | # # modified. Ray directions do not need to change. 183 | # rays_o = ray_batch[:, 0:3] + translation 184 | # rays_remaining = ray_batch[:, 3:] 185 | # ray_batch = torch.cat((rays_o, rays_remaining), dim=1) 186 | # return ray_batch 187 | 188 | def compute_rays_length(rays_d): 189 | """Compute ray length. 190 | 191 | Args: 192 | rays_d: [R, 3] float tensor. Ray directions. 193 | 194 | Returns: 195 | rays_length: [R, 1] float tensor. Ray lengths. 196 | """ 197 | rays_length = torch.norm(rays_d, dim=-1, keepdim=True) # [N_rays, 1] 198 | return rays_length 199 | 200 | 201 | def normalize_rays(rays): 202 | """Normalize ray directions. 203 | 204 | Args: 205 | rays: [R, 3] float tensor. Ray directions. 206 | 207 | Returns: 208 | normalized_rays: [R, 3] float tensor. Normalized ray directions. 209 | """ 210 | normalized_rays = rays / compute_rays_length(rays_d=rays) 211 | return normalized_rays 212 | 213 | 214 | def compute_ray_dirs_and_length(rays_o, rays_dst): 215 | """Compute ray directions. 216 | 217 | Args: 218 | rays_o: [R, 3] float tensor. Ray origins. 219 | rays_dst: [R, 3] float tensor. Ray destinations. 220 | 221 | Returns: 222 | rays_d: [R, 3] float tensor. Normalized ray directions. 223 | """ 224 | # The ray directions are the difference between the ray destinations and the 225 | # ray origins. 
226 | rays_d = rays_dst - rays_o # [R, 3] # Direction out of light source 227 | 228 | # Compute the length of the rays. 229 | rays_length = compute_rays_length(rays_d=rays_d) 230 | 231 | # Normalized the ray directions. 232 | rays_d = rays_d / rays_length # [R, 3] # Normalize direction 233 | return rays_d, rays_length 234 | 235 | 236 | def update_ray_batch_slice(ray_batch, x, start, end): 237 | left = ray_batch[:, :start] # [R, ?] 238 | right = ray_batch[:, end:] # [R, ?] 239 | updated_ray_batch = torch.cat((left, x, right), dim=-1) 240 | return updated_ray_batch 241 | 242 | 243 | def update_ray_batch_bounds(ray_batch, bounds): 244 | updated_ray_batch = update_ray_batch_slice(ray_batch=ray_batch, x=bounds, 245 | start=6, end=8) 246 | return updated_ray_batch 247 | 248 | 249 | def create_ray_batch( 250 | rays_o, rays_dst, rays_i, use_viewdirs, rays_near=None, rays_far=None, epsilon=1e-10): 251 | # Compute the ray directions. 252 | rays_d = rays_dst - rays_o # [R,3] # Direction out of light source 253 | rays_length = compute_rays_length(rays_d=rays_d) # [R, 1] 254 | rays_d = rays_d / rays_length # [R, 3] # Normalize direction 255 | viewdirs = rays_d # [R, 3] 256 | 257 | # If bounds are not provided, set the beginning and end of ray as sampling bounds. 258 | if rays_near is None: 259 | rays_near = torch.zeros((rays_o.size()[0], 1), dtype=torch.float) + epsilon # [R, 1] 260 | if rays_far is None: 261 | rays_far = rays_length # [R, 1] 262 | 263 | ray_batch = torch.cat((rays_o, rays_d, rays_near, rays_far), dim=-1) 264 | if use_viewdirs: 265 | ray_batch = torch.cat((ray_batch, viewdirs), dim=-1) 266 | ray_batch = torch.cat((ray_batch, rays_i), dim=-1) 267 | return ray_batch 268 | 269 | 270 | def sample_random_lightdirs(num_rays, num_samples, upper_only=False): 271 | """Randomly sample directions in the unit sphere. 272 | 273 | Args: 274 | num_rays: int or tensor shape dimension. Number of rays. 275 | num_samples: int or tensor shape dimension. Number of samples per ray. 276 | upper_only: bool. Whether to sample only on the upper hemisphere. 277 | 278 | Returns: 279 | lightdirs: [R, S, 3] float tensor. Random light directions sampled from the unit 280 | sphere for each sampled point. 281 | """ 282 | if upper_only: 283 | min_z = 0 284 | else: 285 | min_z = -1 286 | 287 | phi = torch.rand(num_rays, num_samples) * (2 * math.pi) # [R, S] 288 | cos_theta = torch.rand(num_rays, num_samples) * (1 - min_z) + min_z # [R, S] 289 | theta = torch.acos(cos_theta) # [R, S] 290 | 291 | x = torch.sin(theta) * torch.cos(phi) 292 | y = torch.sin(theta) * torch.sin(phi) 293 | z = torch.cos(theta) 294 | 295 | lightdirs = torch.cat((x[..., None], y[..., None], z[..., None]), dim=-1) # [R, S, 3] 296 | return lightdirs 297 | 298 | 299 | def get_light_positions(rays_i, img_light_pos): 300 | """Extracts light positions given scene IDs. 301 | 302 | Args: 303 | rays_i: [R, 1] float tensor. Per-ray image IDs. 304 | img_light_pos: [N, 3] float tensor. Per-image light positions. 305 | 306 | Returns: 307 | rays_light_pos: [R, 3] float tensor. Per-ray light positions. 308 | """ 309 | rays_light_pos = img_light_pos[rays_i.long()].squeeze() # [R, 3] 310 | return rays_light_pos 311 | 312 | 313 | def get_lightdirs(lightdirs_method, num_rays=None, num_samples=None, rays_i=None, 314 | metadata=None, ray_batch=None, use_viewdirs=False, normalize=False): 315 | """Compute lightdirs. 316 | 317 | Args: 318 | lightdirs_method: str. Method to use for computing lightdirs. 319 | num_rays: int or tensor shape dimension. Number of rays. 
320 | num_samples: int or tensor shape dimension. Number of samples per ray. 321 | rays_i: [R, 1] float tensor. Ray image IDs. 322 | metadata: [N, 3] float tensor. Metadata about each image. Currently only light 323 | position is provided. 324 | ray_batch: [R, M] float tensor. Ray batch. 325 | use_viewdirs: bool. Whether to use viewdirs. 326 | normalize: bool. Whether to normalize lightdirs. 327 | 328 | Returns; 329 | lightdirs: [R, S, 3] float tensor. Light directions for each sample. 330 | """ 331 | if lightdirs_method == 'viewdirs': 332 | raise NotImplementedError 333 | assert use_viewdirs 334 | lightdirs = ray_batch[:, 8:11] # [R, 3] 335 | lightdirs *= 1.5 336 | lightdirs = torch.tile(lightdirs[:, None, :], (1, num_samples, 1)) 337 | elif lightdirs_method == 'metadata': 338 | lightdirs = get_light_positions(rays_i, metadata) # [R, 3] 339 | lightdirs = torch.tile(lightdirs[:, None, :], (1, num_samples, 1)) # [R, S, 3] 340 | elif lightdirs_method == 'random': 341 | lightdirs = sample_random_lightdirs(num_rays, num_samples) # [R, S, 3] 342 | elif lightdirs_method == 'random_upper': 343 | lightdirs = sample_random_lightdirs(num_rays, num_samples, upper_only=True) # [R, S, 3] 344 | else: 345 | raise ValueError(f'Invalid lightdirs_method: {lightdirs_method}.') 346 | if normalize: 347 | lightdirs_flat = lightdirs.view(-1, 3) # [RS, 3] 348 | lightdirs_flat = normalize_rays(lightdirs_flat) # [RS, 3] 349 | lightdirs = lightdirs_flat.view(lightdirs.size()) # [R, S, 3] 350 | return lightdirs 351 | -------------------------------------------------------------------------------- /ObjectFolder1.0/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.21 2 | torch==1.8.1 3 | scipy==1.6.3 4 | librosa==0.8.1 5 | matplotlib==3.4.2 6 | imageio==2.9.0 7 | tqdm==4.60.0 8 | torchvision==0.9.1 9 | imageio-ffmpeg==0.4.3 10 | configargparse 11 | -------------------------------------------------------------------------------- /ObjectFolder1.0/scatter.py: -------------------------------------------------------------------------------- 1 | """Utility functions for scattering rays.""" 2 | import torch 3 | 4 | 5 | def create_scatter_indices_for_dim(dim, shape, indices=None): 6 | """Create scatter indices for a given dimension.""" 7 | dim_size = shape[dim] 8 | N_dims = len(shape) 9 | reshape = [1] * N_dims 10 | reshape[dim] = -1 11 | 12 | if indices is None: 13 | indices = torch.arange(dim_size, dtype=torch.int) # [dim_size,] 14 | 15 | indices = indices.view(reshape) 16 | 17 | indices = torch.broadcast_to( 18 | indices, shape) # [Ro, S, 1] or [Ro, S, C, 1] [0,1,1,1] vs. 
[512,64,1,1] 19 | 20 | indices = indices.int() 21 | return indices 22 | 23 | 24 | def create_scatter_indices(updates, dim2known_indices): 25 | """Create scatter indices.""" 26 | updates_expanded = updates.unsqueeze(-1) 27 | target_shape = updates_expanded.size() 28 | n_dims = len(updates.size()) # 2 or 3 29 | 30 | dim_indices_list = [] 31 | for dim in range(n_dims): 32 | indices = None 33 | if dim in dim2known_indices: 34 | indices = dim2known_indices[dim] 35 | dim_indices = create_scatter_indices_for_dim( # [Ro, S, C, 1] 36 | dim=dim, 37 | shape=target_shape, # [Ro, S, 1] or [Ro, S, C, 1] 38 | indices=indices) # [Ro,] 39 | dim_indices_list.append(dim_indices) 40 | scatter_indices = torch.cat((dim_indices_list), dim=-1) # [Ro, S, C, 3] 41 | return scatter_indices 42 | 43 | 44 | def scatter_nd(tensor, updates, dim2known_indices): 45 | scatter_indices = create_scatter_indices( # [Ro, S, C, 3] 46 | updates=updates, # [Ro, S] 47 | dim2known_indices=dim2known_indices) # [Ro,] 48 | #scattered_tensor = (tensor[scatter_indices.view(-1, 3)[:, 0], 49 | # scatter_indices.view(-1, 3)[:, 1], 50 | # scatter_indices.view(-1, 3)[:, 2]] = updates.view(-1)) 51 | scattered_tensor = tensor[scatter_indices.view(-1, 3)[:, 0], 52 | scatter_indices.view(-1, 3)[:, 1], 53 | scatter_indices.view(-1, 3)[:, 2]] 54 | scattered_tensor = updates.view(-1) 55 | return scattered_tensor 56 | 57 | 58 | def scatter_results(intersect, indices, N_rays, keys, N_samples, N_importance=None): 59 | """Scatters intersecting ray results into the original set of rays. 60 | 61 | Args: 62 | intersect: Dict. Values are tensors of intersecting rays which can be any of the 63 | following. 64 | z_vals: [R?, S] 65 | pts: [R?, S, 3] 66 | rgb: [R?, S, 3] 67 | raw: [R?, S, 4] 68 | alpha: [R?, S, 1] 69 | indices: [R?, 1] int tensor. Intersecting ray indices to scatter back to the full set 70 | of rays. 71 | N_rays: int or int tensor. Total number of rays. 72 | keys: [str]. List of keys from the 'intersect' dictionary to scatter. 73 | N_samples: [int]. Number of samples. 74 | N_importance: [int]. Number of importance (fine) samples. 75 | 76 | Returns: 77 | scattered_results: Dict. Scattered results, where each value is of shape [R, ...]. 78 | The original intersecting ray results are padding with samples with zero density, 79 | so they won't have any contribution to the final render. 80 | """ 81 | # We use 'None' to indicate that the intersecting set of rays is equivalent to 82 | # the full set if rays, so we are done. 
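# Otherwise, full-size filler buffers are created below (zeros for rgb/alpha, constants for
# raw/pts) and the results for the intersecting rays are scattered into them.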
83 | if indices is None: 84 | return {k: intersect[k] for k in keys} 85 | 86 | scattered_results = {} 87 | # N_samples = intersect['z_vals'].shape[1] 88 | dim2known_indices = {0: indices} # [R', 1] 89 | for k in keys: 90 | if k == 'z_vals': 91 | # tensor = torch.rand((N_rays, N_samples), dtype=torch.float) 92 | tensor = torch.arange(N_samples) # [S,] 93 | tensor = tensor.float() 94 | tensor = torch.stack([tensor] * N_rays) # [R, S] 95 | elif k == 'z_samples': 96 | # tensor = torch.rand((N_rays, N_samples), dtype=torch.float) #[R, S] 97 | # N_importance = intersect[k].size()[1] 98 | tensor = torch.arange(N_importance) # [I,] 99 | tensor = tensor.float() 100 | tensor = torch.stack([tensor] * N_rays) # [R, I] 101 | elif k == 'raw': 102 | tensor = torch.full((N_rays, N_samples, 4), 1000.0, dtype=torch.float) # [R, S, 4] 103 | elif k == 'pts': 104 | tensor = torch.full((N_rays, N_samples, 3), 1000.0, dtype=torch.float) # [R, S, 3] 105 | elif 'rgb' in k: 106 | tensor = torch.zeros((N_rays, N_samples, 3), dtype=torch.float) # [R, S, 3] 107 | elif 'alpha' in k: 108 | tensor = torch.zeros((N_rays, N_samples, 1), dtype=torch.float) # [R, S, 1] 109 | else: 110 | raise ValueError(f'Invalid key: {k}') 111 | # No intersections to scatter. 112 | if len(indices) == 0: 113 | scattered_results[k] = tensor 114 | else: 115 | scattered_v = scatter_nd( # [R, S, K] 116 | tensor=tensor, 117 | updates=intersect[k], # [Ro, S] 118 | dim2known_indices=dim2known_indices) 119 | # Convert the batch dimension to a known dimension. 120 | # For some reason 'scattered_z_vals' becomes [R, ?]. We need to explicitly 121 | # reshape it with 'N_samples'. 122 | if k == 'z_samples': 123 | scattered_v = scattered_v.view(N_rays, N_importance) # [R, I] 124 | else: 125 | if k == 'z_vals': 126 | scattered_v = scattered_v.view(N_rays, N_samples) # [R, S] 127 | else: 128 | # scattered_v = scattered_v.view((N_rays,) + scattered_v.size()[1:]) # [R, S, K] 129 | # scattered_v = scattered_v.view((-1,) + scattered_v.size()[1:]) # [R, S, K] 130 | scattered_v = scattered_v.view(N_rays, N_samples, tensor.size()[2]) # [R, S, K] 131 | scattered_results[k] = scattered_v 132 | return scattered_results 133 | 134 | 135 | def scatter_coarse_and_fine( 136 | ret0, ret, indices, num_rays, N_samples, N_importance, **kwargs): 137 | # TODO: only process raw if retraw=True. 138 | ret0 = scatter_results( 139 | intersect=ret0, 140 | indices=indices, 141 | N_rays=num_rays, 142 | keys=['z_vals', 'rgb', 'alpha', 'raw'], 143 | N_samples=N_samples) 144 | ret = scatter_results( 145 | intersect=ret, 146 | indices=indices, 147 | N_rays=num_rays, 148 | keys=['z_vals', 'rgb', 'alpha', 'raw', 'z_samples'], 149 | N_samples=N_samples + N_importance, 150 | N_importance=N_importance) 151 | return ret0, ret 152 | -------------------------------------------------------------------------------- /ObjectFolder1.0/shadow_utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for shadows.""" 2 | import torch 3 | 4 | import ray_utils 5 | import run_osf_helpers 6 | 7 | 8 | def create_ray_batch(ray_batch, pts, metadata, use_viewdirs, lightdirs_method): 9 | """Create batch for shadow rays. 10 | 11 | Args: 12 | ray_batch: [R?, M] float tensor. Batch of primary rays. 13 | pts: [R?, S, 3] float tensor. Primary points. 14 | metadata: [N, 3] float tensor. Metadata about each image. Currently only light 15 | position is provided. 16 | use_viewdirs: bool. Whether to use view directions.
17 | 18 | Returns: 19 | shadow_ray_batch: [R?S, M] float tensor. Batch of shadow rays containing one shadow 20 | ray for each primary ray sample. Each shadow ray originates at a primary point 21 | and points in the direction towards the light source. 22 | """ 23 | num_primary_rays = pts.size()[0] # R 24 | num_primary_samples = pts.size()[1] # S 25 | 26 | # Samples are shadow ray origins. 27 | rays_o = pts.view(-1, 3) # [R?S, 3] 28 | 29 | num_shadow_rays = rays_o.size()[0] # R?S 30 | num_samples = pts.size()[1] # S 31 | 32 | rays_i = ray_batch[:, 11:12] # [R, 1] 33 | 34 | # Get light positions for each ray as the ray destinations. 35 | rays_dst = ray_utils.get_lightdirs( # [R?, S, 3] 36 | lightdirs_method=lightdirs_method, num_rays=num_primary_rays, 37 | num_samples=num_primary_samples, rays_i=rays_i, metadata=metadata, 38 | ray_batch=ray_batch, use_viewdirs=use_viewdirs) 39 | rays_dst = rays_dst.view(rays_o.size()) # [R?S, 3] 40 | 41 | rays_i = torch.tile(rays_i.unsqueeze(1), (1, num_primary_samples, 1)) # [R?, S, 1] 42 | rays_i = rays_i.view(-1, 1) # [R?S, 1] 43 | 44 | shadow_ray_batch = ray_utils.create_ray_batch(rays_o, rays_dst, rays_i, use_viewdirs) 45 | return shadow_ray_batch 46 | 47 | 48 | def compute_transmittance(alpha): 49 | """Applies shadows to outputs. 50 | Args: 51 | alpha: [R?S, S, 1] float tensor. Alpha predictions from the model. 52 | Returns: 53 | last_trans: [R?S,] float tensor. Shadow transmittance per ray. 54 | """ 55 | trans = run_osf_helpers.compute_transmittance(alpha=alpha[..., 0]) # [R?S, S] 56 | 57 | # Transmittance is computed in the direction origin -> end, so we grab the 58 | # last transmittance. 59 | last_trans = trans[:, -1] # [R?S,] 60 | return last_trans 61 | -------------------------------------------------------------------------------- /ObjectFolder2.0_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/ObjectFolder2.0_teaser.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## ObjectFolder 2.0 2 | ObjectFolder 2.0: A Multisensory Object Dataset for Sim2Real Transfer (CVPR 2022) 3 | [[Project Page]](https://ai.stanford.edu/~rhgao/objectfolder2.0/) [[arXiv]](https://arxiv.org/abs/2204.02389) 4 | 5 | 6 | 7 |
8 | 9 | [ObjectFolder 2.0: A Multisensory Object Dataset for Sim2Real Transfer](https://ai.stanford.edu/~rhgao/objectfolder2.0/) 10 | [Ruohan Gao*](https://www.ai.stanford.edu/~rhgao/), [Zilin Si*](https://si-lynnn.github.io/), [Yen-Yu Chang*](https://yuyuchang.github.io/), [Samuel Clarke](https://samuelpclarke.com/), [Jeannette Bohg](https://web.stanford.edu/~bohg/), [Li Fei-Fei](https://profiles.stanford.edu/fei-fei-li), [Wenzhen Yuan](http://robotouch.ri.cmu.edu/yuanwz/), [Jiajun Wu](https://jiajunwu.com/)
11 | Stanford University, Carnegie Mellon University
12 | In Conference on Computer Vision and Pattern Recognition (**CVPR**), 2022 13 | 14 |
15 | 16 | If you find our dataset, code or project useful in your research, we appreciate it if you could cite: 17 | 18 | @inproceedings{gao2022ObjectFolderV2, 19 | title = {ObjectFolder 2.0: A Multisensory Object Dataset for Sim2Real Transfer}, 20 | author = {Gao, Ruohan and Si, Zilin and Chang, Yen-Yu and Clarke, Samuel and Bohg, Jeannette and Fei-Fei, Li and Yuan, Wenzhen and Wu, Jiajun}, 21 | booktitle = {CVPR}, 22 | year = {2022} 23 | } 24 | 25 | @inproceedings{gao2021ObjectFolder, 26 | title = {ObjectFolder: A Dataset of Objects with Implicit Visual, Auditory, and Tactile Representations}, 27 | author = {Gao, Ruohan and Chang, Yen-Yu and Mall, Shivani and Fei-Fei, Li and Wu, Jiajun}, 28 | booktitle = {CoRL}, 29 | year = {2021} 30 | } 31 | 32 | ### About ObjectFolder Dataset 33 | 34 | ObjectFolder 2.0 is a dataset of 1,000 objects in the form of implicit representations. It contains 1,000 Object Files, each containing the complete multisensory profile for an object instance. Each Object File implicit neural representation network contains three sub-networks---VisionNet, AudioNet, and TouchNet---which, when queried with the corresponding extrinsic parameters, yield the visual appearance of the object from different views and lighting conditions, impact sounds of the object at each position for a specified force profile, and tactile sensing of the object at every surface location for varied rotation angles and pressing depths, respectively. The dataset contains common household objects of diverse categories such as wood desks, ceramic bowls, plastic toys, steel forks, glass mirrors, etc. The objects.csv file contains the metadata for these 1,000 objects. Note that the first 100 objects are the same as ObjectFolder 1.0, and we recommend using the new version for improved multisensory simulation and implicit representation for the objects. See the paper for details. 35 | 36 |
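Each Object File is distributed as a single PyTorch checkpoint (e.g. `demo/ObjectFile.pth`) that bundles these sub-networks. As a quick sanity check it can be loaded and inspected like any other checkpoint; the keys shown in the comments below are the ones the rendering code reads, and additional entries may be present:
```
import torch

obj = torch.load('demo/ObjectFile.pth', map_location='cpu')
print(obj.keys())              # includes 'AudioNet' and 'TouchNet' entries, among others
print(obj['AudioNet'].keys())  # e.g. 'model_state_dict', 'normalizer', 'voxel_vertex', 'translation', 'scale'
print(obj['TouchNet'].keys())  # e.g. 'model_state_dict'
```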
37 | 38 | ### Prerequisites 39 | * OS: Ubuntu 20.04.2 LTS 40 | * GPU: >= NVIDIA GTX 1080 Ti with >= 460.73.01 driver 41 | * Python package manager `conda` 42 | 43 | ### Setup 44 | ``` 45 | git clone https://github.com/rhgao/ObjectFolder.git 46 | cd ObjectFolder 47 | export OBJECTFOLDER_HOME=$PWD 48 | conda env create -f $OBJECTFOLDER_HOME/environment.yml 49 | source activate ObjectFolder-env 50 | ``` 51 | 52 | ### Dataset Download and Preparation 53 | Use the following command to download the first 100 objects: 54 | ``` 55 | wget https://download.cs.stanford.edu/viscam/ObjectFolder/ObjectFolder1-100.tar.gz 56 | tar -xvf ObjectFolder1-100.tar.gz 57 | ``` 58 | Use the following command to download ObjectFiles with KiloOSF that supports real-time visual rendering at the expense of larger model size: 59 | ``` 60 | wget https://download.cs.stanford.edu/viscam/ObjectFolder/ObjectFolder1-100KiloOSF.tar.gz 61 | tar -xvf ObjectFolder1-100KiloOSF.tar.gz 62 | ``` 63 | Similarly, use the following command to download objects from 101-1000: 64 | ``` 65 | wget https://download.cs.stanford.edu/viscam/ObjectFolder/ObjectFolder[X+1]-[X+100].tar.gz 66 | wget https://download.cs.stanford.edu/viscam/ObjectFolder/ObjectFolder[X+1]-[X+100]KiloOSF.tar.gz 67 | tar -xvf ObjectFolder[X+1]-[X+100].tar.gz 68 | tar -xvf ObjectFolder[X+1]-[X+100]KiloOSF.tar.gz 69 | # replace X with a value in [100,200,300,400,500,600,700,800,900] 70 | ``` 71 | 72 | ### Rendering Visual, Acoustic, and Tactile Sensory Data 73 | Run the following command to render visual appearance of the object from a specified camera viewpoint and lighting direction: 74 | ``` 75 | $ python OF_render.py --modality vision --object_file_path path_of_ObjectFile \ 76 | --vision_test_file_path path_of_vision_test_file \ 77 | --vision_results_dir path_of_vision_results_directory \ 78 | ``` 79 | Run the following command to render impact sounds of the object at a specified surface location and impact force: 80 | ``` 81 | $ python OF_render.py --modality audio --object_file_path path_of_ObjectFile \ 82 | --audio_vertices_file_path path_of_audio_testing_vertices_file \ 83 | --audio_forces_file_path path_of_forces_file \ 84 | --audio_results_dir path_of_audio_results_directory \ 85 | ``` 86 | Run the following command to render tactile RGB images of the object at a specified surface location, gel rotation angle, and deformation: 87 | ``` 88 | $ python OF_render.py --modality touch --object_file_path path_of_ObjectFile \ 89 | --touch_vertices_file_path path_of_touch_testing_vertices_file \ 90 | --touch_gelinfo_file_path path_of_gelinfo_file \ 91 | --touch_results_dir path_of_touch_results_directory 92 | ``` 93 | The command-line arguments are described as follows: 94 | * `--object_file_path`: The path of ObjectFile. 95 | * `--vision_test_file_path`: The path of the testing file for vision, which should be a npy file. 96 | * `--vision_results_dir`: The path of the vision results directory to save rendered images. 97 | * `--audio_vertices_file_path`: The path of the testing vertices file for audio, which should be a npy file. 98 | * `--audio_forces_file_path`: The path of forces file for audio, which should be a npy file. 99 | * `--audio_results_dir`: The path of audio results directory to save rendered impact sounds as .wav files. 100 | * `--touch_vertices_file_path`: The path of the testing vertices file for touch, which should be a npy file. 101 | * `--touch_gelinfo_file_path`: The path of the gelinfo file for touch that specifies the gel rotation angle and deformation depth, which should be a npy file. 102 | * `--touch_results_dir`: The path of the touch results directory to save rendered tactile RGB images. 103 |
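The query files passed to `OF_render.py` are ordinary NumPy arrays saved with `np.save`; their exact layouts are documented in the Data Format section below. A minimal sketch (the file names here are placeholders):
```
import numpy as np

# Vision: one row per view, (camera_x, camera_y, camera_z, light_x, light_y, light_z).
np.save('my_vision_test.npy', np.array([[2.0, 0.0, 1.5, 0.0, 0.0, 3.0]], dtype=np.float32))

# Audio: surface vertices and the force applied at each vertex.
np.save('my_audio_vertices.npy', np.array([[0.01, 0.02, 0.03]], dtype=np.float32))
np.save('my_audio_forces.npy', np.array([[0.0, 0.0, 1.0]], dtype=np.float32))

# Touch: surface vertices plus (theta, phi, depth) for the gel rotation and pressing depth.
np.save('my_touch_vertices.npy', np.array([[0.01, 0.02, 0.03]], dtype=np.float32))
np.save('my_touch_gelinfo.npy', np.array([[np.radians(10), np.pi / 2, 0.001]], dtype=np.float32))
```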
101 | * `--touch_gelinfo_file_path`: The path of the gelinfo file for touch that specifies the gel rotation angle and deformation depth, which should be an npy file. 102 | * `--touch_results_dir`: The path of the touch results directory to save rendered tactile RGB images. 103 | 104 | ### Data Format 105 | * `--vision_test_file_path`: An npy file with shape (N, 6), where N is the number of testing viewpoints. Each data point contains the coordinates of the camera and the light in the form of (camera_x, camera_y, camera_z, light_x, light_y, light_z). 106 | * `--audio_vertices_file_path`: An npy file with shape (N, 3), where N is the number of testing vertices. Each data point represents a coordinate on the object in the form of (x, y, z). 107 | * `--audio_forces_file_path`: An npy file with shape (N, 3), where N is the number of testing vertices. Each data point represents the force values for the corresponding impact in the form of (F_x, F_y, F_z). 108 | * `--touch_vertices_file_path`: An npy file with shape (N, 3), where N is the number of testing vertices. Each data point contains a coordinate on the object in the form of (x, y, z). 109 | * `--touch_gelinfo_file_path`: An npy file with shape (N, 3), where N is the number of testing vertices. Each data point contains the gel rotation angle and gel displacement in the form of (theta, phi, depth). theta is in the range [0, np.radians(15)], phi is in the range [0, 2pi], and depth is in the range [0.0005, 0.002]. 110 | 111 | ### ObjectFile with KiloOSF for Real-time Visual Rendering 112 | To use KiloOSF, please make a copy of [cuda](https://github.com/creiser/kilonerf/tree/master/cuda) in the root directory of this repo and follow the steps in [CUDA extension installation](https://github.com/creiser/kilonerf). Run the following command to render the visual appearance of the object from a specified camera viewpoint and lighting direction: 113 | ``` 114 | $ python OF_render.py --modality vision --KiloOSF --object_file_path path_of_KiloOSF_ObjectFile \ 115 | --vision_test_file_path path_of_vision_test_file \ 116 | --vision_results_dir path_of_vision_results_directory 117 | ``` 118 | 119 | ### Demo 120 | Below we show an example of rendering the visual, acoustic, and tactile data from the ObjectFile implicit representation for one object: 121 | ``` 122 | $ python OF_render.py --modality vision,audio,touch --object_file_path demo/ObjectFile.pth \ 123 | --vision_test_file_path demo/vision_demo.npy \ 124 | --vision_results_dir demo/vision_results/ \ 125 | --audio_vertices_file_path demo/audio_demo_vertices.npy \ 126 | --audio_forces_file_path demo/audio_demo_forces.npy \ 127 | --audio_results_dir demo/audio_results/ \ 128 | --touch_vertices_file_path demo/touch_demo_vertices.npy \ 129 | --touch_gelinfo_file_path demo/touch_demo_gelinfo.npy \ 130 | --touch_results_dir demo/touch_results/ 131 | ``` 132 | The rendered images, impact sounds, and tactile images will be saved in `demo/vision_results/`, `demo/audio_results/`, and `demo/touch_results/`, respectively. 133 | 134 | ### License 135 | ObjectFolder is CC BY 4.0 licensed, as found in the LICENSE file.
The mesh files for the 1,000 high-quality 3D objects originally come from online repositories including: [ABO dataset](https://amazon-berkeley-objects.s3.amazonaws.com/index.html), [3D Model Haven](https://3dmodelhaven.com/), [YCB dataset](http://ycb-benchmarks.s3-website-us-east-1.amazonaws.com/), and [Google Scanned Objects](https://app.ignitionrobotics.org/GoogleResearch/fuel/collections/Google\%20Scanned\%20Objects). Please also refer to their original license files. We appreciate their sharing of these great object assets. 136 | -------------------------------------------------------------------------------- /TouchNet_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.autograd.set_detect_anomaly(True) 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from TouchNet_utils import * 7 | 8 | class DenseLayer(nn.Linear): 9 | def __init__(self, in_dim: int, out_dim: int, activation: str = 'relu', *args, **kwargs) -> None: 10 | self.activation = activation 11 | super().__init__(in_dim, out_dim, *args, **kwargs) 12 | 13 | def reset_parameters(self) -> None: 14 | torch.nn.init.xavier_uniform_(self.weight, gain=torch.nn.init.calculate_gain(self.activation)) 15 | if self.bias is not None: 16 | torch.nn.init.zeros_(self.bias) 17 | 18 | class Embedder: 19 | def __init__(self, **kwargs): 20 | self.kwargs = kwargs 21 | self.create_embedding_fn() 22 | 23 | def create_embedding_fn(self): 24 | embed_fns = [] 25 | d = self.kwargs['input_dims'] 26 | out_dim = 0 27 | if self.kwargs['include_input']: 28 | embed_fns.append(lambda x: x) 29 | out_dim += d 30 | 31 | max_freq = self.kwargs['max_freq_log2'] 32 | N_freqs = self.kwargs['num_freqs'] 33 | 34 | if self.kwargs['log_sampling']: 35 | freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs) 36 | else: 37 | freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs) 38 | 39 | for freq in freq_bands: 40 | for p_fn in self.kwargs['periodic_fns']: 41 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 42 | out_dim += d 43 | 44 | self.embed_fns = embed_fns 45 | self.out_dim = out_dim 46 | 47 | def embed(self, inputs): 48 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 49 | 50 | def get_embedder(multires, i=0): 51 | if i == -1: 52 | #9 is for x, y, z, theta, phi_x, phi_y, displacement, w, h 53 | return nn.Identity(), 9 54 | 55 | embed_kwargs = { 56 | 'include_input': True, 57 | 'input_dims': 9, 58 | 'max_freq_log2': multires-1, 59 | 'num_freqs': multires, 60 | 'log_sampling': True, 61 | 'periodic_fns': [torch.sin, torch.cos], 62 | } 63 | 64 | embedder_obj = Embedder(**embed_kwargs) 65 | embed = lambda x, eo=embedder_obj: eo.embed(x) 66 | return embed, embedder_obj.out_dim 67 | 68 | class AudioNeRF(nn.Module): 69 | def __init__(self, D=8, input_ch=5): 70 | super(AudioNeRF, self).__init__() 71 | self.model_x = NeRF(D = D, input_ch = input_ch) 72 | self.model_y = NeRF(D = D, input_ch = input_ch) 73 | self.model_z = NeRF(D = D, input_ch = input_ch) 74 | 75 | def forward(self, embedded_x, embedded_y, embedded_z): 76 | results_x = self.model_x(embedded_x) 77 | results_y = self.model_y(embedded_y) 78 | results_z = self.model_z(embedded_z) 79 | return results_x, results_y, results_z 80 | 81 | class NeRF(nn.Module): 82 | def __init__(self, D=8, W=256, input_ch=5, input_ch_views=0, output_ch=2, skips=[4], use_viewdirs=False): 83 | """ 84 | """ 85 | super(NeRF, self).__init__() 86 | self.D = D 87 | self.W = W 88 | self.input_ch = 
input_ch 89 | self.input_ch_views = input_ch_views 90 | self.skips = skips 91 | self.use_viewdirs = use_viewdirs 92 | 93 | self.pts_linears = nn.ModuleList( 94 | [DenseLayer(input_ch, W, activation='relu')] + [DenseLayer(W, W, activation='relu') if i not in self.skips else DenseLayer(W + input_ch, W, activation='relu') for i in range(D-1)]) 95 | 96 | self.views_linears = nn.ModuleList([DenseLayer(input_ch_views + W, W//2, activation='relu')]) 97 | 98 | if use_viewdirs: 99 | self.feature_linear = DenseLayer(W, W, activation='sigmoid') 100 | #self.alpha_linear = DenseLayer(W, 1, activation='linear') 101 | self.rgb_linear = DenseLayer(W//2, output_ch, activation='sigmoid') 102 | else: 103 | self.output_linear = DenseLayer(W, output_ch, activation='sigmoid') 104 | 105 | def forward(self, x): 106 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1) 107 | h = input_pts 108 | for i, l in enumerate(self.pts_linears): 109 | h = self.pts_linears[i](h) 110 | h = F.relu(h) 111 | if i in self.skips: 112 | h = torch.cat([input_pts, h], -1) 113 | 114 | if self.use_viewdirs: 115 | feature = self.feature_linear(h) 116 | h = torch.cat([feature, input_views], -1) 117 | 118 | for i, l in enumerate(self.views_linears): 119 | h = self.views_linears[i](h) 120 | h = F.relu(h) 121 | 122 | outputs = self.rgb_linear(h) 123 | else: 124 | outputs = self.output_linear(h) 125 | 126 | return outputs 127 | -------------------------------------------------------------------------------- /TouchNet_utils.py: -------------------------------------------------------------------------------- 1 | from scipy.io import wavfile 2 | import librosa 3 | import librosa.display 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from TouchNet_model import * 7 | import os 8 | from collections import OrderedDict 9 | from torch._six import container_abcs, string_classes, int_classes 10 | 11 | def strip_prefix_if_present(state_dict, prefix): 12 | keys = sorted(state_dict.keys()) 13 | if not all(key.startswith(prefix) for key in keys): 14 | return state_dict 15 | stripped_state_dict = OrderedDict() 16 | for key, value in state_dict.items(): 17 | stripped_state_dict[key.replace(prefix, "")] = value 18 | return stripped_state_dict 19 | 20 | def mkdirs(path, remove=False): 21 | if os.path.isdir(path): 22 | if remove: 23 | shutil.rmtree(path) 24 | else: 25 | return 26 | os.makedirs(path) 27 | 28 | def generate_spectrogram_magphase(audio, stft_frame, stft_hop, n_fft, with_phase=False): 29 | spectro = librosa.core.stft(audio, hop_length=stft_hop, n_fft=n_fft, win_length=stft_frame, center=True) 30 | spectro_mag, spectro_phase = librosa.core.magphase(spectro) 31 | spectro_mag = np.expand_dims(spectro_mag, axis=0) 32 | if with_phase: 33 | spectro_phase = np.expand_dims(np.angle(spectro_phase), axis=0) 34 | return spectro_mag, spectro_phase 35 | else: 36 | return spectro_mag 37 | 38 | def generate_spectrogram_complex(audio, stft_frame, stft_hop, n_fft): 39 | spectro = librosa.core.stft(audio, hop_length=stft_hop, n_fft=n_fft, win_length=stft_frame, center=True) 40 | real = np.expand_dims(np.real(spectro), axis=0) 41 | imag = np.expand_dims(np.imag(spectro), axis=0) 42 | spectro_two_channel = np.concatenate((real, imag), axis=0) 43 | return spectro_two_channel 44 | 45 | def batchify(fn, chunk): 46 | """ 47 | Constructs a version of 'fn' that applies to smaller batches 48 | """ 49 | if chunk is None: 50 | return fn 51 | def ret(inputs): 52 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, 
inputs.shape[0], chunk)], 0) 53 | return ret 54 | 55 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64): 56 | """ 57 | Prepares inputs and applies network 'fn'. 58 | """ 59 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]]) 60 | embedded = embed_fn(inputs_flat) 61 | 62 | if viewdirs is not None: 63 | input_dirs = viewdirs[:,None].expand(inputs.shape) 64 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]]) 65 | embedded_dirs = embeddirs_fn(input_dirs_flat) 66 | embedded = torch.cat([embedded, embedded_dirs], -1) 67 | 68 | outputs_flat = batchify(fn, netchunk)(embedded) 69 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]]) 70 | return outputs 71 | 72 | def create_nerf(args): 73 | """ 74 | Instantiate NeRF's MLP model. 75 | """ 76 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed) 77 | 78 | input_ch_views = 0 79 | embeddirs_fn = None 80 | if args.use_viewdirs: 81 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed) 82 | output_ch = 2 83 | skips = [4] 84 | model = NeRF(D=args.netdepth, W=args.netwidth, 85 | input_ch=input_ch, output_ch=output_ch, skips=skips, 86 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs) 87 | model = nn.DataParallel(model).to(device) 88 | grad_vars = list(model.parameters()) 89 | 90 | def object_collate(batch): 91 | r"""Puts each data field into a tensor with outer dimension batch size""" 92 | #print batch 93 | elem_type = type(batch[0]) 94 | if isinstance(batch[0], torch.Tensor): 95 | out = None 96 | return torch.stack(batch, 0, out=out) 97 | elif elem_type.__module__ == 'numpy': 98 | elem = batch[0] 99 | if elem_type.__name__ == 'ndarray': 100 | return torch.cat([torch.from_numpy(b) for b in batch], 0) #concatenate even if dimension differs 101 | #return object_collate([torch.from_numpy(b) for b in batch]) 102 | if elem.shape == (): # scalars 103 | py_type = float if elem.dtype.name.startswith('float') else int 104 | return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) 105 | elif isinstance(batch[0], float): 106 | return torch.tensor(batch, dtype=torch.float64) 107 | elif isinstance(batch[0], int_classes): 108 | return torch.tensor(batch) 109 | elif isinstance(batch[0], string_classes): 110 | return batch 111 | elif isinstance(batch[0], container_abcs.Mapping): 112 | return {key: object_collate([d[key] for d in batch]) for key in batch[0]} 113 | elif isinstance(batch[0], container_abcs.Sequence): 114 | transposed = zip(*batch) 115 | return [object_collate(samples) for samples in transposed] 116 | 117 | raise TypeError((error_msg_fmt.format(type(batch[0])))) 118 | -------------------------------------------------------------------------------- /VisionNet_configs.py: -------------------------------------------------------------------------------- 1 | cfg = { 2 | # 'dataset_dir': '/viscam/u/yenyu/ObjectFolderV2/objects/23/VisionNet_osf_data/', 3 | # 'pretrained_cfg_path': '/viscam/u/yenyu/kiloosf2/cfgs/paper/pretrain/23.yaml', 4 | # 'pretrained_checkpoint_path': '/viscam/u/yenyu/kiloosf2/logs/paper/pretrain/23/checkpoint_0200000.pth', 5 | 'blender_half_res': False, 6 | 'dataset_type': 'osf', 7 | 'discovery': { 8 | 'alpha_distance': 0.0211, 9 | 'alpha_rgb_initalization': 'pass_actual_nonlinearity', 10 | 'bias_initialization_method': 'standard', 11 | 'convert_density_to_alpha': True, 12 | 'direction_layer_size': 64, 13 | 'equal_split_metric': 'mse', 14 | 'hidden_layer_size': 64, 15 | 'iterations': 
30000, 16 | 'late_feed_direction': True, 17 | 'max_num_networks': 512, 18 | 'network_rng_seed': 8078673, 19 | 'nonlinearity_initalization': 'pass_actual_nonlinearity', 20 | 'num_examples_per_network': 1020000, 21 | 'num_frequencies': 10, 22 | 'num_frequencies_direction': 4, 23 | 'num_frequencies_light': 4, 24 | 'num_hidden_layers': 2, 25 | 'num_train_examples_per_network': 1000000, 26 | 'outputs': 'color_and_density', 27 | 'quantile_se': 0.99, 28 | 'query_batch_size': 80000, 29 | 'refeed_position_index': 'None', 30 | 'test_batch_size': 512, 31 | 'test_error_metric': 'quantile_se', 32 | 'test_every': 500, 'train_batch_size': 128, 33 | 'use_same_initialization_for_all_networks': True, 34 | 'weight_initialization_method': 'kaiming_uniform' 35 | }, 36 | 'fixed_resolution': [16, 16, 16], 37 | 'max_error': 100000, 38 | 'performance_monitoring': False, 39 | 'render_only': True, 40 | 'restart_after_checkpoint': True, 41 | 'skip_final': True, 42 | 'tree_type': 'kdtree_longest', 43 | 'global_domain_min': [-0.6, -0.6, -0.6], 44 | 'global_domain_max': [0.6, 0.6, 0.6], 45 | 'alpha_distance': 0.0211, 46 | 'alpha_rgb_initalization': 'pass_actual_nonlinearity', 47 | 'bias_initialization_method': 'standard', 48 | 'convert_density_to_alpha': True, 49 | 'direction_layer_size': 64, 50 | 'equal_split_metric': 'mse', 51 | 'hidden_layer_size': 64, 52 | 'iterations': 1000000, 53 | 'late_feed_direction': True, 54 | 'max_num_networks': 512, 55 | 'network_rng_seed': 8078673, 56 | 'nonlinearity_initalization': 'pass_actual_nonlinearity', 57 | 'num_examples_per_network': 1020000, 58 | 'num_frequencies': 10, 59 | 'num_frequencies_direction': 4, 60 | 'num_frequencies_light': 4, 61 | 'num_hidden_layers': 2, 62 | 'num_train_examples_per_network': 1000000, 63 | 'outputs': 'color_and_density', 64 | 'quantile_se': 0.99, 65 | 'query_batch_size': 80000, 66 | 'refeed_position_index': 'None', 67 | 'test_batch_size': 512, 68 | 'test_error_metric': 'quantile_se', 69 | 'test_every': 500, 70 | 'train_batch_size': 128, 71 | 'use_same_initialization_for_all_networks': False, 72 | 'weight_initialization_method': 'kaiming_uniform', 73 | # 'distilled_cfg_path': '/viscam/u/yenyu/kiloosf2/cfgs/paper/distill/23.yaml', 74 | # 'distilled_checkpoint_path': '/viscam/u/yenyu/kiloosf2/logs/paper/distill/23/checkpoint.pth', 75 | # 'occupancy_cfg_path': '/viscam/u/yenyu/kiloosf2/cfgs/paper/pretrain_occupancy/23.yaml', 76 | # 'occupancy_log_path': '/viscam/u/yenyu/kiloosf2/logs/paper/pretrain_occupancy/23/occupancy.pth', 77 | 'checkpoint_interval': 50000, 78 | 'chunk_size': 40000, 79 | 'initial_learning_rate': 0.001, 80 | 'l2_regularization_lambda': 1e-06, 81 | 'learing_rate_decay_rate': 500, 82 | 'no_batching': True, 83 | 'num_rays_per_batch': 8192, 84 | 'num_samples_per_ray': 384, 85 | 'perturb': 1.0, 86 | 'precrop_fraction': 0.5, 87 | 'precrop_iterations': 0, 88 | 'raw_noise_std': 0.0, 89 | 'near': 0.01, 90 | 'far': 4, 91 | 'spiral_radius': None, 92 | 'half_res': False, 93 | 'n_render_spiral': 40, 94 | 'render_spiral_angles': None, 95 | 'num_importance_samples_per_ray': 0, 96 | 'render_test': True, 97 | 'no_color_sigmoid': False, 98 | 'render_factor': 0, 99 | 'testskip': 1, 100 | 'deepvoxels_shape': 'greek', 101 | 'blender_white_background': True, 102 | 'llff_factor': 8, 103 | 'llff_no_ndc': False, 104 | 'llff_lindisp': False, 105 | 'llff_spherify': False, 106 | 'llff_hold': False, 107 | 'print_interval': 100, 108 | 'render_testset_interval': 100000, 109 | 'render_video_interval': 100000000, 110 | 'network_chunk_size': 65536, 111 | 'rng_seed': 0, 
112 | 'use_initialization_fix': False, 113 | 'model_type': 'multi_network', 114 | 'random_direction_probability': -1, 115 | 'von_mises_kappa': -1, 116 | 'view_dependent_dropout_probability': -1, 117 | 'occupancy': { 118 | # 'dataset_dir': '/viscam/u/yenyu/ObjectFolderV2/objects/23/VisionNet_osf_data/', 119 | 'dataset_type': 'osf', 120 | # 'pretrained_cfg_path': '/viscam/u/yenyu/kiloosf2/cfgs/paper/pretrain/23.yaml', 121 | # 'pretrained_checkpoint_path': '/viscam/u/yenyu/kiloosf2/logs/paper/pretrain/23/checkpoint_0200000.pth', 122 | 'resolution': [256, 256, 256], 123 | 'subsample_resolution': [3, 3, 3], 124 | 'threshold': 10, 125 | 'voxel_batch_size': 16384, 126 | 'global_domain_min': [-0.6, -0.6, -0.6], 127 | 'global_domain_max': [0.6, 0.6, 0.6] 128 | } 129 | } 130 | -------------------------------------------------------------------------------- /basics/CalibData.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | class CalibData: 3 | def __init__(self, dataPath): 4 | self.dataPath = dataPath 5 | data = np.load(dataPath) 6 | 7 | self.numBins = data['bins'] 8 | self.grad_r = data['grad_r'] 9 | self.grad_g = data['grad_g'] 10 | self.grad_b = data['grad_b'] 11 | -------------------------------------------------------------------------------- /basics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/basics/__init__.py -------------------------------------------------------------------------------- /basics/sensorParams.py: -------------------------------------------------------------------------------- 1 | ball_radius = 4.00/2; 2 | pixmm = 0.1245 # 0.0295; # 0.0302 3 | numBins = 120; 4 | 5 | # sensor setting 6 | h = 120 7 | w = 160 8 | cam2gel = 0.04 9 | -------------------------------------------------------------------------------- /box_utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for bounding box computation.""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import ray_utils 6 | 7 | def ray_to_box_coordinate_frame_pairwise(box_center, box_rotation_matrix, 8 | rays_start_point, rays_end_point): 9 | """Moves a set of rays into a box's coordinate frame. 10 | 11 | Args: 12 | box_center: A tensor of size [3] or [r, 3]. 13 | box_rotation_matrix: A tensor of size [3, 3] or [r, 3, 3]. 14 | rays_start_point: A tensor of size [r, 3] where r is the number of rays. 15 | rays_end_points: A tensor of size [r, 3] where r is the number of rays. 16 | 17 | Returns: 18 | rays_start_point_in_box_frame: A tensor of size [r, 3]. 19 | rays_end_point_in_box_frame: A tensor if size [r, 3]. 
20 | """ 21 | r = rays_start_point.size()[0] 22 | box_center = torch.broadcast_to(box_center, (r, 3)) 23 | box_rotation_matrix = torch.broadcast_to(box_rotation_matrix, (r, 3, 3)) 24 | rays_start_point_in_box_frame = torch.matmul( 25 | (rays_start_point - box_center).unsqueeze(1), 26 | box_rotation_matrix) 27 | rays_end_point_in_box_frame = torch.matmul( 28 | (rays_end_point - box_center).unsqueeze(1), 29 | box_rotation_matrix) 30 | return (rays_start_point_in_box_frame.view(-1, 3), 31 | rays_end_point_in_box_frame.view(-1, 3)) 32 | 33 | 34 | def ray_box_intersection_pairwise(box_center, 35 | box_rotation_matrix, 36 | box_length, 37 | box_width, 38 | box_height, 39 | rays_start_point, 40 | rays_end_point, 41 | exclude_negative_t=False, 42 | exclude_enlarged_t=True, 43 | epsilon=0.000001): 44 | """Intersects a set of rays with a box. 45 | 46 | Note: The intersection points are returned in the box coordinate frame. 47 | Note: Make sure the start and end point of the rays are not the same. 48 | Note: Even though a start and end point is passed for each ray, rays are 49 | never ending and can intersect a box beyond their start / end points. 50 | 51 | Args: 52 | box_center: A tensor of size [3] or [r, 3]. 53 | box_rotation_matrix: A tensor of size [3, 3] or [r, 3, 3]. 54 | box_length: A scalar tensor or of size [r]. 55 | box_width: A scalar tensor or of size [r]. 56 | box_height: A scalar tensor or of size [r]. 57 | rays_start_point: A tensor of size [r, 3] where r is the number of rays. 58 | rays_end_point: A tensor of size [r, 3] there r is the number of rays. 59 | exclude_negative_t: bool. 60 | exclude_enlarged_t: bool. 61 | epsilon: A very small number. 62 | 63 | Returns: 64 | intersection_points_in_box_frame: A tensor of size [r', 2, 3] 65 | that contains intersection points in box coordinate frame. 66 | indices_of_intersecting_rays: A tensor of size [r']. 67 | intersection_ts: A tensor of size [r']. 68 | """ 69 | r = rays_start_point.size()[0] 70 | box_length = box_length.expand(r) 71 | box_width = box_width.expand(r) 72 | box_height = box_height.expand(r) 73 | box_center = torch.broadcast_to(box_center, (r, 3)) 74 | box_rotation_matrix = torch.broadcast_to(box_rotation_matrix, (r, 3, 3)) 75 | rays_start_point_in_box_frame, rays_end_point_in_box_frame = ( 76 | ray_to_box_coordinate_frame_pairwise( 77 | box_center=box_center, 78 | box_rotation_matrix=box_rotation_matrix, 79 | rays_start_point=rays_start_point, 80 | rays_end_point=rays_end_point)) 81 | rays_a = rays_end_point_in_box_frame - rays_start_point_in_box_frame 82 | intersection_masks = [] 83 | intersection_points = [] 84 | intersection_ts = [] 85 | box_size = [box_length, box_width, box_height] 86 | for axis in range(3): 87 | plane_value = box_size[axis] / 2.0 88 | for _ in range(2): 89 | plane_value = -plane_value 90 | # Compute the scalar multiples of 'rays_a' to apply in order to intersect 91 | # with the plane. 92 | t = ((plane_value - rays_start_point_in_box_frame[:, axis]) / # [R,] 93 | rays_a[:, axis]) 94 | # The current axis only intersects with plane if the ray is not parallel 95 | # with the plane. Note that this will result in 't' being +/- infinity, becasue 96 | # the ray component in the axis is zero, resulting in rays_a[:, axis] = 0. 
97 | intersects_with_plane = torch.abs(rays_a[:, axis]) > epsilon 98 | if exclude_negative_t: # Only allow at most one negative t 99 | t = torch.maximum(t, torch.tensor(0.0)) # [R,] 100 | if exclude_enlarged_t: 101 | t = torch.maximum(t, torch.tensor(1.0)) # [R,] 102 | intersection_ts.append(t) # [R, 1] 103 | intersection_points_i = [] 104 | 105 | # Initialize a mask which represents whether each ray intersects with the 106 | # current plane. 107 | intersection_masks_i = torch.ones_like(t, dtype=torch.int32).bool() # [R,] 108 | for axis2 in range(3): 109 | # Compute the point of intersection for the current axis. 110 | intersection_points_i_axis2 = ( # [R,] 111 | rays_start_point_in_box_frame[:, axis2] + t * rays_a[:, axis2]) 112 | intersection_points_i.append(intersection_points_i_axis2) # 3x [R,] 113 | 114 | # Update the intersection mask depending on whether the intersection 115 | # point is within bounds. 116 | intersection_masks_i = torch.logical_and( # [R,] 117 | torch.logical_and(intersection_masks_i, intersects_with_plane), 118 | torch.logical_and( 119 | intersection_points_i_axis2 <= (box_size[axis2] / 2.0 + epsilon), 120 | intersection_points_i_axis2 >= (-box_size[axis2] / 2.0 - epsilon))) 121 | intersection_points_i = torch.stack(intersection_points_i, dim=1) # [R, 3] 122 | intersection_masks.append(intersection_masks_i) # List of [R,] 123 | intersection_points.append(intersection_points_i) # List of [R, 3] 124 | intersection_ts = torch.stack(intersection_ts, dim=1) # [R, 6] 125 | intersection_masks = torch.stack(intersection_masks, dim=1) # [R, 6] 126 | intersection_points = torch.stack(intersection_points, dim=1) # [R, 6, 3] 127 | 128 | # Compute a mask over rays with exactly two plane intersections out of the six 129 | # planes. More intersections are possible if the ray coincides with a box 130 | # edge or corner, but we'll ignore these cases for now. 131 | counts = torch.sum(intersection_masks.int(), dim=1) # [R,] 132 | intersection_masks_any = torch.eq(counts, 2) # [R,] 133 | indices = torch.arange(intersection_masks_any.size()[0]).int() # [R,] 134 | # Apply the intersection masks over tensors. 135 | indices = indices[intersection_masks_any] # [R',] 136 | intersection_masks = intersection_masks[intersection_masks_any] # [R', 6] 137 | intersection_points = intersection_points[intersection_masks_any] # [R', 6, 3] 138 | intersection_points = intersection_points[intersection_masks].view(-1, 2, 3) # [R', 2, 3] 139 | # Ensure one or more positive ts. 140 | intersection_ts = intersection_ts[intersection_masks_any] # [R', 6] 141 | intersection_ts = intersection_ts[intersection_masks] # [R'*2] 142 | intersection_ts = intersection_ts.view(indices.size()[0], 2) # [R', 2] 143 | positive_ts_mask = (intersection_ts >= 0) # [R', 2] 144 | positive_ts_count = torch.sum(positive_ts_mask.int(), dim=1) # [R'] 145 | positive_ts_mask = (positive_ts_count >= 1) # [R'] 146 | intersection_points = intersection_points[positive_ts_mask] # [R'', 2, 3] 147 | false_indices = indices[torch.logical_not(positive_ts_mask)] # [R',] 148 | indices = indices[positive_ts_mask] # [R'',] 149 | if len(false_indices) > 0: 150 | intersection_masks_any[false_indices[:, None]] = torch.zeros(false_indices.size(), dtype=torch.bool) 151 | return rays_start_point_in_box_frame, intersection_masks_any, intersection_points, indices 152 | 153 | 154 | def compute_bounds_from_intersect_points(rays_o, intersect_indices, 155 | intersect_points): 156 | """Computes bounds from intersection points. 
157 | 158 | Note: Make sure that inputs are in the same coordiante frame. 159 | 160 | Args: 161 | rays_o: [R, 3] float tensor 162 | intersect_indices: [R', 1] float tensor 163 | intersect_points: [R', 2, 3] float tensor 164 | 165 | Returns: 166 | intersect_bounds: [R', 2] float tensor 167 | 168 | where R is the number of rays and R' is the number of intersecting rays. 169 | """ 170 | intersect_rays_o = rays_o[intersect_indices] # [R', 1, 3] 171 | intersect_diff = intersect_points - intersect_rays_o # [R', 2, 3] 172 | intersect_bounds = torch.norm(intersect_diff, dim=2) # [R', 2] 173 | 174 | # Sort the bounds so that near comes before far for all rays. 175 | intersect_bounds, _ = torch.sort(intersect_bounds, dim=1) # [R', 2] 176 | 177 | # For some reason the sort function returns [R', ?] instead of [R', 2], so we 178 | # will explicitly reshape it. 179 | intersect_bounds = intersect_bounds.view(-1, 2) # [R', 2] 180 | return intersect_bounds 181 | 182 | 183 | def compute_ray_bbox_bounds_pairwise(rays_o, rays_d, box_length, 184 | box_width, box_height, box_center, 185 | box_rotation, far_limit=1e10): 186 | """Computes near and far bounds for rays intersecting with bounding boxes. 187 | 188 | Note: rays and boxes are defined in world coordinate frame. 189 | 190 | Args: 191 | rays_o: [R, 3] float tensor. A set of ray origins. 192 | rays_d: [R, 3] float tensor. A set of ray directions. 193 | box_length: scalar or [R,] float tensor. Bounding box length. 194 | box_width: scalar or [R,] float tensor. Bounding box width. 195 | box_height: scalar or [R,] float tensor. Bounding box height. 196 | box_center: [3,] or [R, 3] float tensor. The center of the box. 197 | box_rotation: [3, 3] or [R, 3, 3] float tensor. The box rotation matrix. 198 | far_limit: float. The maximum far value to use. 199 | 200 | Returns: 201 | intersect_bounds: [R', 2] float tensor. The bounds per-ray, sorted in 202 | ascending order. 203 | intersect_indices: [R', 1] float tensor. The intersection indices. 204 | intersect_mask: [R,] float tensor. The mask denoting intersections. 205 | """ 206 | # Compute ray destinations. 207 | normalized_rays_d = ray_utils.normalize_rays(rays=rays_d) 208 | rays_dst = rays_o + far_limit * normalized_rays_d 209 | 210 | # Transform the rays from world to box coordinate frame. 
211 | rays_o_in_box_frame, intersect_mask, intersect_points_in_box_frame, intersect_indices = ( # [R,], [R', 2, 3], [R', 2] 212 | ray_box_intersection_pairwise( 213 | box_center=box_center, 214 | box_rotation_matrix=box_rotation, 215 | box_length=box_length, 216 | box_width=box_width, 217 | box_height=box_height, 218 | rays_start_point=rays_o, 219 | rays_end_point=rays_dst)) 220 | intersect_indices = intersect_indices.unsqueeze(1).long() # [R', 1] 221 | intersect_bounds = compute_bounds_from_intersect_points( 222 | rays_o=rays_o_in_box_frame, 223 | intersect_indices=intersect_indices, 224 | intersect_points=intersect_points_in_box_frame) 225 | return intersect_bounds, intersect_indices, intersect_mask 226 | -------------------------------------------------------------------------------- /build_occupancy_tree.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import grad 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | import math 9 | import os 10 | import scipy.integrate as integrate 11 | import math 12 | from collections import deque 13 | import time 14 | from mpl_toolkits.mplot3d import Axes3D 15 | import matplotlib.pyplot as plt 16 | import itertools 17 | 18 | from utils import * 19 | from local_distill import create_multi_network_fourier_embedding, has_flag, create_multi_network 20 | from multi_modules import build_multi_network_from_single_networks, query_multi_network 21 | import kilonerf_cuda 22 | 23 | # TODO: move this to utils.py 24 | class Node: 25 | def __init__(self): 26 | pass 27 | 28 | # This function actually builds an occupancy grid 29 | def build_occupancy_tree(cfg, log_path): 30 | dev = torch.device('cuda') 31 | kilonerf_cuda.init_stream_pool(16) # TODO: cleanup 32 | kilonerf_cuda.init_magma() 33 | 34 | ConfigManager.init(cfg) 35 | 36 | global_domain_min, global_domain_max = ConfigManager.get_global_domain_min_and_max(torch.device('cpu')) 37 | global_domain_size = global_domain_max - global_domain_min 38 | Logger.write('global_domain_min: {}, global_domain_max: {}'.format(global_domain_min, global_domain_max)) 39 | 40 | pretrained_cfg = load_yaml_as_dict(cfg['pretrained_cfg_path']) 41 | 42 | if 'distilled_cfg_path' in pretrained_cfg: 43 | pretrained_cfg = load_yaml_as_dict(pretrained_cfg['distilled_cfg_path']) 44 | 45 | if 'discovery' in pretrained_cfg: 46 | for key in pretrained_cfg['discovery']: 47 | pretrained_cfg[key] = pretrained_cfg['discovery'][key] 48 | """ 49 | else: 50 | # end2end from scratch case 51 | assert pretrained_cfg['model_type'] == 'multi_network', 'occupancy grid creation is only implemented for multi networks' 52 | """ 53 | 54 | cp = torch.load(cfg['pretrained_checkpoint_path']) 55 | use_multi_network = pretrained_cfg['model_type'] == 'multi_network' or not ('model_type' in pretrained_cfg) 56 | if use_multi_network: 57 | position_num_input_channels, position_fourier_embedding = create_multi_network_fourier_embedding(1, pretrained_cfg['num_frequencies']) 58 | direction_num_input_channels, direction_fourier_embedding = create_multi_network_fourier_embedding(1, pretrained_cfg['num_frequencies_direction']) 59 | 60 | if 'model_state_dict' in cp: 61 | res = pretrained_cfg['fixed_resolution'] 62 | network_resolution = torch.tensor(res, dtype=torch.long, device=torch.device('cpu')) 63 | num_networks = res[0] * res[1] * res[2] 64 | network_voxel_size = global_domain_size / 
network_resolution 65 | 66 | multi_network = create_multi_network(num_networks, position_num_input_channels, direction_num_input_channels, 4, 67 | 'multimatmul_differentiable', pretrained_cfg).to(dev) 68 | multi_network.load_state_dict(cp['model_state_dict']) 69 | 70 | # Determine bounding boxes (domains) of all networks. Required for global to local coordinate conversion. 71 | domain_mins = [] 72 | domain_maxs = [] 73 | for coord in itertools.product(*[range(r) for r in res]): 74 | coord = torch.tensor(coord, device=torch.device('cpu')) 75 | domain_min = global_domain_min + network_voxel_size * coord 76 | domain_max = domain_min + network_voxel_size 77 | domain_mins.append(domain_min.tolist()) 78 | domain_maxs.append(domain_max.tolist()) 79 | domain_mins = torch.tensor(domain_mins, device=dev) 80 | domain_maxs = torch.tensor(domain_maxs, device=dev) 81 | else: 82 | root_nodes = cp['root_nodes'] 83 | 84 | # Merging individual networks into multi network for efficient inference 85 | single_networks = [] 86 | domain_mins, domain_maxs = [], [] 87 | nodes_to_process = root_nodes.copy() 88 | for node in nodes_to_process: 89 | if hasattr(node, 'network'): 90 | node.network_index = len(single_networks) 91 | single_networks.append(node.network) 92 | domain_mins.append(node.domain_min) 93 | domain_maxs.append(node.domain_max) 94 | else: 95 | nodes_to_process.append(node.leq_child) 96 | nodes_to_process.append(node.gt_child) 97 | linear_implementation = 'multimatmul_differentiable' 98 | multi_network = build_multi_network_from_single_networks(single_networks, linear_implementation=linear_implementation).to(dev) 99 | domain_mins = torch.tensor(domain_mins, device=dev) 100 | domain_maxs = torch.tensor(domain_maxs, device=dev) 101 | else: 102 | # Load teacher NeRF model: 103 | print("Load teacher NeRF model...") 104 | pretrained_nerf = load_pretrained_nerf_model(dev, cfg) 105 | 106 | occupancy_res = cfg['resolution'] 107 | total_num_voxels = occupancy_res[0] * occupancy_res[1] * occupancy_res[2] 108 | occupancy_grid = torch.tensor(occupancy_res, device=dev, dtype=torch.bool) 109 | occupancy_resolution = torch.tensor(occupancy_res, dtype=torch.long, device=torch.device('cpu')) 110 | occupancy_voxel_size = global_domain_size / occupancy_resolution 111 | first_voxel_min = global_domain_min 112 | first_voxel_max = first_voxel_min + occupancy_voxel_size 113 | 114 | first_voxel_samples = [] 115 | for dim in range(3): 116 | first_voxel_samples.append(torch.linspace(first_voxel_min[dim], first_voxel_max[dim], cfg['subsample_resolution'][dim])) 117 | first_voxel_samples = torch.stack(torch.meshgrid(*first_voxel_samples), dim=3).view(-1, 3) 118 | 119 | ranges = [] 120 | for dim in range(3): 121 | ranges.append(torch.arange(0, occupancy_res[dim])) 122 | index_grid = torch.stack(torch.meshgrid(*ranges), dim=3) 123 | index_grid = (index_grid * occupancy_voxel_size).unsqueeze(3) 124 | 125 | points = first_voxel_samples.unsqueeze(0).unsqueeze(0).unsqueeze(0).expand(occupancy_res + list(first_voxel_samples.shape)) 126 | points = points + index_grid 127 | points = points.view(total_num_voxels, -1, 3) 128 | num_samples_per_voxel = points.size(1) 129 | 130 | mock_directions = torch.empty(min(cfg['voxel_batch_size'], total_num_voxels), 3).to(dev) 131 | 132 | # We query in a fixed grid at a higher resolution than the occupancy grid resolution to detect fine structures. 
133 | all_densities = torch.empty(total_num_voxels, num_samples_per_voxel) 134 | print("all_densities size: ", all_densities.shape) 135 | end = 0 136 | while end < total_num_voxels: 137 | print('sampling network: {}/{} ({:.4f}%)'.format(end, total_num_voxels, 100 * end / total_num_voxels)) 138 | start = end 139 | end = min(start + cfg['voxel_batch_size'], total_num_voxels) 140 | actual_batch_size = end - start 141 | points_subset = points[start:end].to(dev).contiguous() # voxel_batch_size x num_samples_per_voxel x 3 142 | mock_directions_subset = mock_directions[:actual_batch_size] 143 | density_dim = 3 144 | with torch.no_grad(): 145 | if use_multi_network: 146 | result = query_multi_network(multi_network, domain_mins, domain_maxs, points_subset, mock_directions_subset, 147 | position_fourier_embedding, direction_fourier_embedding, None, None, False, None, pretrained_cfg)[:, :, density_dim] 148 | else: 149 | print("Use teacher NeRF model...") 150 | mock_directions_subset = mock_directions_subset.unsqueeze(1).expand(points_subset.size()) 151 | points_and_dirs = torch.cat([points_subset.reshape(-1, 3), mock_directions_subset.reshape(-1, 3)], dim=-1) 152 | lights_subset = torch.zeros((points_and_dirs.shape[0], 3), device=dev) 153 | lights_subset[:, -1] = 1 154 | points_and_dirs_and_lights = torch.cat([points_and_dirs, lights_subset], dim=-1) 155 | result = pretrained_nerf(points_and_dirs_and_lights)[:, density_dim].view(actual_batch_size, -1) 156 | # result = F.relu(pretrained_nerf(points_and_dirs_and_lights)[:, density_dim].view(actual_batch_size, -1)) 157 | # print(result.max()) 158 | all_densities[start:end] = result.cpu() 159 | del points, points_subset, mock_directions 160 | 161 | occupancy_grid = all_densities.to(dev) > cfg['threshold'] 162 | del all_densities 163 | occupancy_grid = occupancy_grid.view(cfg['resolution'] + [-1]) 164 | 165 | occupancy_grid = occupancy_grid.any(dim=3) # checks if any point in the voxel is above the threshold 166 | 167 | 168 | Logger.write('{} out of {} voxels are occupied. 
{:.2f}%'.format(occupancy_grid.sum().item(), occupancy_grid.numel(), 100 * occupancy_grid.sum().item() / occupancy_grid.numel())) 169 | 170 | occupancy_filename = log_path + '/occupancy.pth' 171 | torch.save(occupancy_grid, occupancy_filename) 172 | Logger.write('Saved occupancy grid to {}'.format(occupancy_filename)) 173 | 174 | def main(): 175 | cfg, log_path = parse_args_and_init_logger() 176 | build_occupancy_tree(cfg, log_path) 177 | 178 | if __name__ == '__main__': 179 | main() 180 | 181 | -------------------------------------------------------------------------------- /calibs/dataPack.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/calibs/dataPack.npz -------------------------------------------------------------------------------- /calibs/depth_bg.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/calibs/depth_bg.npy -------------------------------------------------------------------------------- /calibs/polycalib.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/calibs/polycalib.npz -------------------------------------------------------------------------------- /calibs/real_bg.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/calibs/real_bg.npy -------------------------------------------------------------------------------- /cam_utils.py: -------------------------------------------------------------------------------- 1 | """Various camera utility functions.""" 2 | 3 | import numpy as np 4 | 5 | 6 | def w2c_to_c2w(w2c): 7 | """ 8 | Args: 9 | w2c: [N, 4, 4] np.float32. World-to-camera extrinsics matrix. 10 | 11 | Returns: 12 | c2w: [N, 4, 4] np.float32. Camera-to-world extrinsics matrix. 13 | """ 14 | R = w2c[:3, :3] 15 | T = w2c[:3, 3] 16 | 17 | c2w = np.eye(4, dtype=np.float32) 18 | c2w[:3, 3] = -1 * np.dot(R.transpose(), w2c[:3, 3]) 19 | c2w[:3, :3] = R.transpose() 20 | return c2w 21 | -------------------------------------------------------------------------------- /ddsp_torch.py: -------------------------------------------------------------------------------- 1 | from typing import Text 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from scipy import fftpack 6 | import torch 7 | 8 | def torch_float32(x): 9 | """Ensure array/tensor is a float32 tf.Tensor.""" 10 | if isinstance(x, torch.Tensor): 11 | return x.float() # This is a no-op if x is float32. 12 | elif isinstance(x, np.ndarray): 13 | return torch.from_numpy(x).cuda() # This is a no-op if x is float32. 14 | else: 15 | return torch.tensor(x, dtype=torch.float32).cuda() # This is a no-op if x is float32. 16 | 17 | def safe_log(x, eps=1e-7): 18 | """Avoid taking the log of a non-positive number.""" 19 | safe_x = torch.where(x <= eps, eps, x.double()) 20 | return torch.log(safe_x) 21 | 22 | def stft(audio, frame_size=2048, overlap=0.75): 23 | """Differentiable stft in PyTorch, computed in batch.""" 24 | assert frame_size * overlap % 2.0 == 0.0 25 | 26 | # Remove channel dim if present. 
27 | audio = torch_float32(audio) 28 | if len(audio.shape) == 3: 29 | audio = torch.squeeze(audio, axis=-1) 30 | 31 | s = torch.stft( 32 | audio, 33 | n_fft=int(frame_size), 34 | hop_length=int(frame_size * (1.0 - overlap)), 35 | win_length=int(frame_size), 36 | window=torch.hann_window(int(frame_size)).to(audio), 37 | pad_mode='reflect', 38 | return_complex=True, 39 | ) 40 | return s 41 | 42 | def compute_mag(audio, size=2048, overlap=0.75): 43 | mag = torch.abs(stft(audio, frame_size=size, overlap=overlap)) 44 | return torch_float32(mag) 45 | 46 | def compute_logmag(audio, size=2048, overlap=0.75): 47 | return safe_log(compute_mag(audio, size, overlap)) 48 | 49 | def specplot(audio, 50 | vmin=-5, 51 | vmax=1, 52 | rotate=True, 53 | size=512 + 256, 54 | **matshow_kwargs): 55 | """Plot the log magnitude spectrogram of audio.""" 56 | # If batched, take first element. 57 | if len(audio.shape) == 2: 58 | audio = audio[0] 59 | 60 | logmag = compute_logmag(torch_float32(audio), size=size) 61 | # logmag = spectral_ops.compute_logmel(core.tf_float32(audio), lo_hz=8.0, bins=80, fft_size=size) 62 | # logmag = spectral_ops.compute_mfcc(core.tf_float32(audio), mfcc_bins=40, fft_size=size) 63 | # if rotate: 64 | # logmag = torch.rot90(logmag) 65 | logmag = torch.flip(logmag, [0]) 66 | # Plotting. 67 | plt.matshow(logmag.detach().cpu(), 68 | vmin=vmin, 69 | vmax=vmax, 70 | cmap=plt.cm.magma, 71 | aspect='auto', 72 | **matshow_kwargs) 73 | plt.xticks([]) 74 | plt.yticks([]) 75 | plt.xlabel('Time') 76 | plt.ylabel('Frequency') 77 | 78 | # Time-varying convolution ----------------------------------------------------- 79 | def get_fft_size(frame_size: int, ir_size: int, power_of_2: bool = True) -> int: 80 | """Calculate final size for efficient FFT. 81 | 82 | Args: 83 | frame_size: Size of the audio frame. 84 | ir_size: Size of the convolving impulse response. 85 | power_of_2: Constrain to be a power of 2. If False, allow other 5-smooth 86 | numbers. TPU requires power of 2, while GPU is more flexible. 87 | 88 | Returns: 89 | fft_size: Size for efficient FFT. 90 | """ 91 | convolved_frame_size = ir_size + frame_size - 1 92 | if power_of_2: 93 | # Next power of 2. 94 | fft_size = int(2**np.ceil(np.log2(convolved_frame_size))) 95 | else: 96 | fft_size = int(fftpack.helper.next_fast_len(convolved_frame_size)) 97 | return fft_size 98 | 99 | def crop_and_compensate_delay(audio: torch.Tensor, audio_size: int, ir_size: int, 100 | padding: Text, 101 | delay_compensation: int) -> torch.Tensor: 102 | """Crop audio output from convolution to compensate for group delay. 103 | 104 | Args: 105 | audio: Audio after convolution. Tensor of shape [batch, time_steps]. 106 | audio_size: Initial size of the audio before convolution. 107 | ir_size: Size of the convolving impulse response. 108 | padding: Either 'valid' or 'same'. For 'same' the final output to be the 109 | same size as the input audio (audio_timesteps). For 'valid' the audio is 110 | extended to include the tail of the impulse response (audio_timesteps + 111 | ir_timesteps - 1). 112 | delay_compensation: Samples to crop from start of output audio to compensate 113 | for group delay of the impulse response. If delay_compensation < 0 it 114 | defaults to automatically calculating a constant group delay of the 115 | windowed linear phase filter from frequency_impulse_response(). 116 | 117 | Returns: 118 | Tensor of cropped and shifted audio. 119 | 120 | Raises: 121 | ValueError: If padding is not either 'valid' or 'same'. 122 | """ 123 | # Crop the output. 
124 | if padding == 'valid': 125 | crop_size = ir_size + audio_size - 1 126 | elif padding == 'same': 127 | crop_size = audio_size 128 | else: 129 | raise ValueError('Padding must be \'valid\' or \'same\', instead ' 130 | 'of {}.'.format(padding)) 131 | 132 | # Compensate for the group delay of the filter by trimming the front. 133 | # For an impulse response produced by frequency_impulse_response(), 134 | # the group delay is constant because the filter is linear phase. 135 | total_size = int(audio.shape[-1]) 136 | crop = total_size - crop_size 137 | start = ((ir_size - 1) // 2 - 138 | 1 if delay_compensation < 0 else delay_compensation) 139 | end = crop - start 140 | return audio[:, start:-end] 141 | 142 | def fft_convolve(audio: torch.Tensor, 143 | impulse_response: torch.Tensor, 144 | padding: Text = 'same', 145 | delay_compensation: int = -1) -> torch.Tensor: 146 | """Filter audio with frames of time-varying impulse responses. 147 | 148 | Time-varying filter. Given audio [batch, n_samples], and a series of impulse 149 | responses [batch, n_frames, n_impulse_response], splits the audio into frames, 150 | applies filters, and then overlap-and-adds audio back together. 151 | Applies non-windowed non-overlapping STFT/ISTFT to efficiently compute 152 | convolution for large impulse response sizes. 153 | 154 | Args: 155 | audio: Input audio. Tensor of shape [batch, audio_timesteps]. 156 | impulse_response: Finite impulse response to convolve. Can either be a 2-D 157 | Tensor of shape [batch, ir_size], or a 3-D Tensor of shape [batch, 158 | ir_frames, ir_size]. A 2-D tensor will apply a single linear 159 | time-invariant filter to the audio. A 3-D Tensor will apply a linear 160 | time-varying filter. Automatically chops the audio into equally shaped 161 | blocks to match ir_frames. 162 | padding: Either 'valid' or 'same'. For 'same' the final output to be the 163 | same size as the input audio (audio_timesteps). For 'valid' the audio is 164 | extended to include the tail of the impulse response (audio_timesteps + 165 | ir_timesteps - 1). 166 | delay_compensation: Samples to crop from start of output audio to compensate 167 | for group delay of the impulse response. If delay_compensation is less 168 | than 0 it defaults to automatically calculating a constant group delay of 169 | the windowed linear phase filter from frequency_impulse_response(). 170 | 171 | Returns: 172 | audio_out: Convolved audio. Tensor of shape 173 | [batch, audio_timesteps + ir_timesteps - 1] ('valid' padding) or shape 174 | [batch, audio_timesteps] ('same' padding). 175 | 176 | Raises: 177 | ValueError: If audio and impulse response have different batch size. 178 | ValueError: If audio cannot be split into evenly spaced frames. (i.e. the 179 | number of impulse response frames is on the order of the audio size and 180 | not a multiple of the audio size.) 181 | """ 182 | audio, impulse_response = torch_float32(audio), torch_float32(impulse_response) 183 | 184 | # Add a frame dimension to impulse response if it doesn't have one. 185 | ir_shape = list(impulse_response.shape) 186 | if len(ir_shape) == 2: 187 | impulse_response = torch.unsqueeze(impulse_response, axis = 2) 188 | ir_shape = list(impulse_response.shape) 189 | 190 | # Get shapes of audio and impulse response. 191 | batch_size_ir, n_ir_frames, ir_size = ir_shape 192 | batch_size, audio_size = list(audio.shape) 193 | 194 | # Validate that batch sizes match. 
195 | if batch_size != batch_size_ir: 196 | raise ValueError('Batch size of audio ({}) and impulse response ({}) must ' 197 | 'be the same.'.format(batch_size, batch_size_ir)) 198 | 199 | # Cut audio into frames. 200 | frame_size = int(np.ceil(audio_size / n_ir_frames)) 201 | hop_size = frame_size 202 | audio_frames = audio.unfold(1, frame_size, hop_size) 203 | 204 | # Check that number of frames match. 205 | n_audio_frames = int(audio_frames.shape[1]) 206 | if n_audio_frames != n_ir_frames: 207 | raise ValueError( 208 | 'Number of Audio frames ({}) and impulse response frames ({}) do not ' 209 | 'match. For small hop size = ceil(audio_size / n_ir_frames), ' 210 | 'number of impulse response frames must be a multiple of the audio ' 211 | 'size.'.format(n_audio_frames, n_ir_frames)) 212 | 213 | # Pad and FFT the audio and impulse responses. 214 | fft_size = get_fft_size(frame_size, ir_size, power_of_2=True) 215 | audio_fft = torch.fft.rfft(audio_frames, fft_size) 216 | ir_fft = torch.fft.rfft(impulse_response, fft_size) 217 | 218 | # Multiply the FFTs (same as convolution in time). 219 | audio_ir_fft = torch.multiply(audio_fft, ir_fft) 220 | 221 | # Take the IFFT to resynthesize audio. 222 | audio_frames_out = torch.fft.irfft(audio_ir_fft) 223 | # audio_out = tf.signal.overlap_and_add(audio_frames_out, hop_size) 224 | audio_out = torch.squeeze(audio_frames_out, axis=1) 225 | 226 | # Crop and shift the output audio. 227 | return crop_and_compensate_delay(audio_out, audio_size, ir_size, padding, 228 | delay_compensation) 229 | 230 | def get_modal_fir(gains, frequencies, dampings, n_samples=44100*2, sample_rate=44100): 231 | t = torch.reshape(torch.arange(n_samples)/sample_rate, (1, 1, -1)).cuda() 232 | g = torch.unsqueeze(gains, axis=2) 233 | f = torch.reshape(frequencies, (1, -1, 1)) 234 | d = torch.reshape(dampings, (1, -1, 1)) 235 | pure = torch.sin(2 * np.pi * f * t) 236 | damped = torch.exp(-1 * torch.abs(d) * t) * pure 237 | signal = torch.sum(g * damped, axis=1) 238 | return torch.cat((torch.zeros_like(signal), signal), axis=1) -------------------------------------------------------------------------------- /demo/ObjectFile.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/demo/ObjectFile.pth -------------------------------------------------------------------------------- /demo/audio_demo_forces.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/demo/audio_demo_forces.npy -------------------------------------------------------------------------------- /demo/audio_demo_vertices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/demo/audio_demo_vertices.npy -------------------------------------------------------------------------------- /demo/touch_demo_gelinfo.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/demo/touch_demo_gelinfo.npy -------------------------------------------------------------------------------- /demo/touch_demo_vertices.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/demo/touch_demo_vertices.npy -------------------------------------------------------------------------------- /demo/vision_demo.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/demo/vision_demo.npy -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: ObjectFolder-env 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | - nvidia 7 | dependencies: 8 | - pip=21.0.1 9 | - python=3.8 10 | - pytorch=1.8.1 11 | - torchvision 12 | - cudatoolkit=11.1.1 13 | - numpy 14 | - scikit-image 15 | - scipy 16 | - tqdm 17 | - imageio 18 | - pyyaml 19 | - pip: 20 | - imageio-ffmpeg 21 | - lpips 22 | - opencv-python 23 | - librosa 24 | -------------------------------------------------------------------------------- /fast_kilonerf_renderer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import kilonerf_cuda 3 | from utils import PerfMonitor, ConfigManager 4 | from run_nerf_helpers import get_rays, replace_transparency_by_background_color 5 | 6 | class FastKiloNeRFRenderer(): 7 | def __init__(self, c2w, intrinsics, background_color, occupancy_grid, multi_network, domain_mins, domain_maxs, 8 | white_bkgd, max_depth_index, min_distance, max_distance, performance_monitoring, occupancy_resolution, max_samples_per_ray, transmittance_threshold): 9 | 10 | self.set_camera_pose(c2w) 11 | self.intrinsics = intrinsics 12 | self.background_color = background_color 13 | self.occupancy_grid = occupancy_grid 14 | self.multi_network = multi_network 15 | self.domain_mins = domain_mins 16 | self.domain_maxs = domain_maxs 17 | self.white_bkgd = white_bkgd 18 | self.max_depth_index = max_depth_index 19 | self.min_distance = min_distance 20 | self.max_distance = max_distance 21 | self.performance_monitoring = performance_monitoring 22 | self.occupancy_resolution = occupancy_resolution 23 | self.max_samples_per_ray = max_samples_per_ray 24 | self.transmittance_threshold = transmittance_threshold 25 | 26 | # Get ray directions for abitrary render pose 27 | # Precompute distances between sampling points, which vary along the pixel dimension 28 | _, rays_d = get_rays(intrinsics, self.c2w) # H x W x 3 29 | direction_norms = torch.norm(rays_d, dim=-1) # H x W 30 | self.distance_between_samples = (1 / (self.max_depth_index - 1)) * (self.max_distance - self.min_distance) 31 | self.constant_dists = (self.distance_between_samples * direction_norms).view(-1).unsqueeze(1) # H * W x 1 32 | 33 | self.rgb_map = torch.empty([self.intrinsics.H, self.intrinsics.W, 3], dtype=torch.float, device=occupancy_grid.device) 34 | self.rgb_map_pointer = self.rgb_map.data_ptr() 35 | 36 | def set_rgb_map_pointer(self, rgb_map_pointer): 37 | self.rgb_map = None 38 | self.rgb_map_pointer = rgb_map_pointer 39 | 40 | def set_camera_pose(self, c2w): 41 | self.c2w = c2w[:3, :4] 42 | 43 | def render(self): 44 | PerfMonitor.add('start') 45 | PerfMonitor.is_active = self.performance_monitoring 46 | 47 | rays_o, rays_d = get_rays(self.intrinsics, self.c2w, expand_origin=False) 48 | PerfMonitor.add('ray directions', ['preprocessing']) 49 | 50 | origin = rays_o 51 | directions = rays_d.reshape(-1, 3) # directions are *not* normalized. 
52 | res = self.occupancy_resolution 53 | global_domain_min, global_domain_max = ConfigManager.get_global_domain_min_and_max(directions.device) 54 | global_domain_size = global_domain_max - global_domain_min 55 | occupancy_resolution = torch.tensor(res, dtype=torch.long, device=directions.device) 56 | strides = torch.tensor([res[2] * res[1], res[2], 1], dtype=torch.int, device=directions.device) # assumes row major ordering 57 | voxel_size = global_domain_size / occupancy_resolution 58 | num_rays = directions.size(0) 59 | 60 | active_ray_mask = torch.empty(num_rays, dtype=torch.bool, device=directions.device) 61 | depth_indices = torch.empty(num_rays, dtype=torch.short, device=directions.device) 62 | acc_map = torch.empty([self.intrinsics.H, self.intrinsics.W], dtype=torch.float, device=directions.device) 63 | # the final transmittance of a pass will be the initial transmittance of the next 64 | transmittance = torch.empty([self.intrinsics.H, self.intrinsics.W], dtype=torch.float, device=directions.device) 65 | 66 | PerfMonitor.add('prep', ['preprocessing']) 67 | 68 | is_initial_query = True 69 | is_final_pass = False 70 | 71 | pass_idx = 0 72 | integrate_num_blocks = 40 73 | integrate_num_threads = 512 74 | while not is_final_pass: 75 | 76 | if type(self.max_samples_per_ray) is list: 77 | # choose max samples per ray depending on the pass 78 | # in the later passes we can sample more per ray to avoid too much overhead from too many passes 79 | current_max_samples_per_ray = self.max_samples_per_ray[min(pass_idx, len(self.max_samples_per_ray) - 1)] 80 | else: 81 | # just use the same number of samples for all passes 82 | current_max_samples_per_ray = self.max_samples_per_ray 83 | 84 | # Compute query indices along the rays and determine assignment of query location to networks 85 | # Tunable CUDA hyperparameters 86 | kernel_max_num_blocks = 40 87 | kernel_max_num_threads = 512 88 | version = 0 89 | query_indices, assigned_networks = kilonerf_cuda.generate_query_indices_on_ray(origin, directions, self.occupancy_grid, active_ray_mask, depth_indices, voxel_size, 90 | global_domain_min, global_domain_max, strides, self.distance_between_samples, current_max_samples_per_ray, self.max_depth_index, self.min_distance, is_initial_query, 91 | kernel_max_num_blocks, kernel_max_num_threads, version) 92 | 93 | PerfMonitor.add('sample query points', ['preprocessing']) 94 | 95 | with_explicit_mask = True 96 | query_indices = query_indices.view(-1) 97 | assigned_networks = assigned_networks.view(-1) 98 | if with_explicit_mask: 99 | active_samples_mask = assigned_networks != -1 100 | assigned_networks = assigned_networks[active_samples_mask] 101 | # when with_expclit_mask = False: Sort all points, including those with assigned_network == -1 102 | # Points with assigned_network == -1 will be placed in the beginning and can then be filtered by moving the start of the array (zero cost) 103 | 104 | #assigned_networks, reorder_indices = torch.sort(assigned_networks) # sorting via PyTorch is significantly slower 105 | #reorder_indices = torch.arange(assigned_networks.size(0), dtype=torch.int32, device=assigned_networks.device) 106 | #kilonerf_cuda.sort_by_key_int16_int32(assigned_networks, reorder_indices) # stable sort does not seem to be slower/faster 107 | reorder_indices = torch.arange(assigned_networks.size(0), dtype=torch.int64, device=assigned_networks.device) 108 | kilonerf_cuda.sort_by_key_int16_int64(assigned_networks, reorder_indices) 109 | PerfMonitor.add('sort', ['reorder and backorder']) 110 | 111 | 
# make sure that also batch sizes are given for networks which are queried 0 points 112 | contained_nets, batch_size_per_network_incomplete = torch.unique_consecutive(assigned_networks, return_counts=True) 113 | if not with_explicit_mask: 114 | num_removable_points = batch_size_per_network_incomplete[0] 115 | contained_nets = contained_nets[1:].to(torch.long) 116 | batch_size_per_network_incomplete = batch_size_per_network_incomplete[1:] 117 | else: 118 | contained_nets = contained_nets.to(torch.long) 119 | batch_size_per_network = torch.zeros(self.multi_network.num_networks, device=query_indices.device, dtype=torch.long) 120 | batch_size_per_network[contained_nets] = batch_size_per_network_incomplete 121 | ends = batch_size_per_network.cumsum(0).to(torch.int32) 122 | starts = ends - batch_size_per_network.to(torch.int32) 123 | PerfMonitor.add('batch_size_per_network', ['reorder and backorder']) 124 | 125 | 126 | # Remove all points which are assigned to no network (those points are in empty space or outside the global domain) 127 | if with_explicit_mask: 128 | query_indices = query_indices[active_samples_mask] 129 | else: 130 | reorder_indices = reorder_indices[num_removable_points:] # just moving a pointer 131 | PerfMonitor.add('remove points', ['reorder and backorder']) 132 | 133 | # Reorder the query indices 134 | query_indices = query_indices[reorder_indices] 135 | #query_indices = kilonerf_cuda.gather_int32(reorder_indices, query_indices) 136 | query_indices = query_indices 137 | PerfMonitor.add('reorder', ['reorder and backorder']) 138 | 139 | num_points_to_process = query_indices.size(0) if query_indices.ndim > 0 else 0 140 | #print("#points to process:", num_points_to_process, flush=True) 141 | if num_points_to_process == 0: 142 | break 143 | 144 | # Evaluate the network 145 | network_eval_num_blocks = -1 # ignored currently 146 | compute_capability = torch.cuda.get_device_capability(query_indices.device) 147 | if compute_capability == (7, 5): 148 | network_eval_num_threads = 512 # for some reason the compiler uses more than 96 registers for this CC, so we cannot launch 640 threads 149 | else: 150 | network_eval_num_threads = 640 151 | version = 0 152 | raw_outputs = kilonerf_cuda.network_eval_query_index(query_indices, self.multi_network.serialized_params, self.domain_mins, self.domain_maxs, starts, ends, origin, 153 | self.c2w[:3, :3].contiguous(), self.multi_network.num_networks, self.multi_network.hidden_layer_size, 154 | self.intrinsics.H, self.intrinsics.W, self.intrinsics.cx, self.intrinsics.cy, self.intrinsics.fx, self.intrinsics.fy, self.max_depth_index, self.min_distance, self.distance_between_samples, 155 | network_eval_num_blocks, network_eval_num_threads, version) 156 | PerfMonitor.add('fused network eval', ['network query']) 157 | 158 | # Backorder outputs 159 | if with_explicit_mask: 160 | raw_outputs_backordered = torch.empty_like(raw_outputs) 161 | raw_outputs_backordered[reorder_indices] = raw_outputs 162 | #raw_outputs_backordered = kilonerf_cuda.scatter_int32_float4(reorder_indices, raw_outputs) 163 | del raw_outputs 164 | raw_outputs_full = torch.zeros(num_rays * current_max_samples_per_ray, 4, dtype=torch.float, device=raw_outputs_backordered.device) 165 | raw_outputs_full[active_samples_mask] = raw_outputs_backordered 166 | else: 167 | raw_outputs_full = torch.zeros(num_rays * current_max_samples_per_ray, 4, dtype=torch.float, device=raw_outputs.device) 168 | raw_outputs_full[reorder_indices] = raw_outputs 169 | PerfMonitor.add('backorder', ['reorder and 
backorder']) 170 | 171 | # Integrate sampled densities and colors along each ray to render the final image 172 | version = 0 173 | kilonerf_cuda.integrate(raw_outputs_full, self.constant_dists, self.rgb_map_pointer, acc_map, transmittance, active_ray_mask, num_rays, current_max_samples_per_ray, 174 | self.transmittance_threshold, is_initial_query, integrate_num_blocks, integrate_num_threads, version) 175 | is_final_pass = not active_ray_mask.any().item() 176 | 177 | is_initial_query = False 178 | if not is_final_pass: 179 | PerfMonitor.add('integration', ['integration']) 180 | pass_idx += 1 181 | 182 | if self.white_bkgd: 183 | kilonerf_cuda.replace_transparency_by_background_color(self.rgb_map_pointer, acc_map, self.background_color, integrate_num_blocks, integrate_num_threads) 184 | 185 | PerfMonitor.is_active = True 186 | PerfMonitor.add('integration', ['integration']) 187 | elapsed_time = PerfMonitor.log_and_reset(self.performance_monitoring) 188 | self.rgb_map = self.rgb_map.view(self.intrinsics.H, self.intrinsics.W, 3) if self.rgb_map is not None else None 189 | return self.rgb_map, elapsed_time 190 | -------------------------------------------------------------------------------- /load_osf.py: -------------------------------------------------------------------------------- 1 | """Data loader for OSF data.""" 2 | 3 | import os 4 | import torch 5 | import numpy as np 6 | import imageio 7 | import json 8 | 9 | import cam_utils 10 | 11 | 12 | trans_t = lambda t: torch.tensor([ 13 | [1,0,0,0], 14 | [0,1,0,0], 15 | [0,0,1,t], 16 | [0,0,0,1] 17 | ], dtype=torch.float) 18 | 19 | rot_phi = lambda phi: torch.tensor([ 20 | [1,0,0,0], 21 | [0,np.cos(phi),-np.sin(phi),0], 22 | [0,np.sin(phi), np.cos(phi),0], 23 | [0,0,0,1] 24 | ], dtype=torch.float) 25 | 26 | rot_theta = lambda th: torch.tensor([ 27 | [np.cos(th),0,-np.sin(th),0], 28 | [0,1,0,0], 29 | [np.sin(th),0, np.cos(th),0], 30 | [0,0,0,1] 31 | ], dtype=torch.float) 32 | 33 | 34 | def pose_spherical(theta, phi, radius): 35 | c2w = trans_t(radius) 36 | c2w = rot_phi(phi/180.*np.pi) @ c2w 37 | c2w = rot_theta(theta/180.*np.pi) @ c2w 38 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w 39 | return c2w 40 | 41 | def coordinates_to_c2w(x, y, z, r=2.5): 42 | theta = np.arccos(z / r) 43 | phi = np.arctan2(x, -y) 44 | Rx = np.array([[1, 0, 0], [0, np.cos(theta), -np.sin(theta)], [0, np.sin(theta), np.cos(theta)]]) 45 | Rz = np.array([[np.cos(phi), -np.sin(phi), 0], [np.sin(phi), np.cos(phi), 0], [0, 0, 1]]) 46 | R = Rz @ Rx 47 | c2w = R.tolist() 48 | c2w[0].append(x) 49 | c2w[1].append(y) 50 | c2w[2].append(z) 51 | c2w.append([0., 0., 0., 1.]) 52 | #c2w = np.array(c2w).astype(np.float32) 53 | return c2w 54 | 55 | def convert_cameras_to_nerf_format(anno): 56 | """ 57 | Args: 58 | anno: List of annotations for each example. Each annotation is represented by a 59 | dictionary that must contain the key `RT` which is the world-to-camera 60 | extrinsics matrix with shape [3, 4], in [right, down, forward] coordinates. 61 | 62 | Returns: 63 | c2w: [N, 4, 4] np.float32. Array of camera-to-world extrinsics matrices in 64 | [right, up, backwards] coordinates. 65 | """ 66 | c2w_list = [] 67 | for a in anno: 68 | # Convert from w2c to c2w. 
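# The 3x4 'RT' world-to-camera matrix is padded with a [0, 0, 0, 1] row so that cam_utils.w2c_to_c2w can invert it; the two sign flips below then switch the resulting c2w from [right, down, forward] axes to the [right, up, backward] convention described in the docstring above.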
69 | w2c = np.array(a['RT'] + [[0.0, 0.0, 0.0, 1.0]]) 70 | c2w = cam_utils.w2c_to_c2w(w2c) 71 | 72 | # Convert from [right, down, forwards] to [right, up, backwards] 73 | c2w[:3, 1] *= -1 # down -> up 74 | c2w[:3, 2] *= -1 # forwards -> back 75 | c2w_list.append(c2w) 76 | c2w = np.array(c2w_list) 77 | print("c2w: ", c2w) 78 | return c2w 79 | 80 | 81 | def load_osf_data(test_file_path): 82 | 83 | all_poses = [] 84 | all_metadata = [] 85 | counts = [0] 86 | test_file = np.load(test_file_path) 87 | N = test_file.shape[0] 88 | for i in range(N): 89 | cx, cy, cz, lx, ly, lz = test_file[i] 90 | poses = coordinates_to_c2w(cx, cy, cz) 91 | metadata = np.array([[lx, ly, lz]]).astype(np.float32) 92 | all_poses.append(poses) 93 | all_metadata.append(metadata) 94 | 95 | poses = np.array(all_poses).astype(np.float32) 96 | 97 | metadata = np.concatenate(all_metadata, 0) 98 | counts.append(N) 99 | i_split = [np.arange(counts[0], counts[1])] 100 | 101 | H, W, focal = 256, 256, 355.5555419921875 102 | 103 | return poses, [H, W, focal], i_split, metadata 104 | -------------------------------------------------------------------------------- /ray_utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for ray computation.""" 2 | import math 3 | import numpy as np 4 | from scipy.spatial.transform import Rotation as R 5 | import box_utils 6 | import torch 7 | 8 | 9 | def apply_batched_transformations(inputs, transformations): 10 | """Batched transformation of inputs. 11 | 12 | Args: 13 | inputs: List of [R, S, 3] 14 | transformations: [R, 4, 4] 15 | 16 | Returns: 17 | transformed_inputs: List of [R, S, 3] 18 | """ 19 | # if rotation_only: 20 | # transformations[:, :3, 3] = torch.zeros((3,), dtype=torch.float) 21 | 22 | transformed_inputs = [] 23 | for x in inputs: 24 | N_samples = x.size()[1] 25 | homog_transformations = transformations.unsqueeze(1) # [R, 1, 4, 4] 26 | homog_transformations = torch.tile(homog_transformations, (1, N_samples, 1, 1)) # [R, S, 4, 4] 27 | homog_component = torch.ones_like(x)[..., 0:1] # [R, S, 1] 28 | homog_x = torch.cat((x, homog_component), axis=-1) # [R, S, 4] 29 | homog_x = homog_x.unsqueeze(2) 30 | transformed_x = torch.matmul( 31 | homog_x, 32 | torch.transpose(homog_transformations, 2, 3)) # [R, S, 1, 4] 33 | transformed_x = transformed_x[..., 0, :3] # [R, S, 3] 34 | transformed_inputs.append(transformed_x) 35 | return transformed_inputs 36 | 37 | 38 | def get_transformation_from_params(params): 39 | translation, rotation = [0, 0, 0], [0, 0, 0] 40 | if 'translation' in params: 41 | translation = params['translation'] 42 | if 'rotation' in params: 43 | rotation = params['rotation'] 44 | translation = torch.tensor(translation, dtype=torch.float) 45 | rotmat = torch.tensor(R.from_euler('xyz', rotation, degrees=True).as_matrix(), dtype=torch.float) 46 | return translation, rotmat 47 | 48 | 49 | def rotate_dirs(dirs, rotmat): 50 | """ 51 | Args: 52 | dirs: [R, 3] float tensor. 
53 | rotmat: [3, 3] 54 | """ 55 | if type(dirs) == np.ndarray: 56 | dirs = torch.tensor(dirs).float() 57 | #rotmat = rotmat.unsqueeze(0) 58 | rotmat = torch.broadcast_to(rotmat, (dirs.shape[0], 3, 3)) # [R, 3, 3] 59 | dirs_obj = torch.matmul(dirs.unsqueeze(1), torch.transpose(rotmat, 1, 2)) # [R, 1, 3] 60 | dirs_obj = dirs_obj.squeeze(1) # [R, 3] 61 | return dirs_obj 62 | 63 | 64 | def transform_dirs(dirs, params, inverse=False): 65 | _, rotmat = get_transformation_from_params(params) # [3,], [3, 3] 66 | if inverse: 67 | rotmat = torch.transpose(rotmat, 0, 1) # [3, 3] 68 | dirs_transformed = rotate_dirs(dirs, rotmat) 69 | return dirs_transformed 70 | 71 | 72 | def transform_rays(ray_batch, params, use_viewdirs, inverse=False): 73 | """Transform rays into object coordinate frame given o2w transformation params. 74 | 75 | Note: do not assume viewdirs is always the normalized version of rays_d (e.g., in staticcam case). 76 | 77 | Args: 78 | ray_batch: [R, M] float tensor. Batch of rays. 79 | params: Dictionary containing transformation parameters: 80 | 'translation': List of 3 elements. xyz translation. 81 | 'rotation': List of 3 euler angles in xyz. 82 | use_viewdirs: bool. Whether to we are using viewdirs. 83 | inverse: bool. Whether to apply inverse of the transformations provided in 'params'. 84 | 85 | Returns: 86 | ray_batch_obj: [R, M] float tensor. The ray batch, in object coordinate frame. 87 | """ 88 | rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6] 89 | translation, rotmat = get_transformation_from_params(params) # [3,], [3, 3] 90 | 91 | if inverse: 92 | translation = -1 * translation # [3,] 93 | rotmat = torch.transpose(rotmat, 1, 0) # [3, 3] 94 | 95 | translation_inverse = -1 * translation 96 | rotmat_inverse = torch.transpose(rotmat, 1, 0) 97 | 98 | # Transform the ray origin. 99 | rays_o_obj, _ = box_utils.ray_to_box_coordinate_frame_pairwise( 100 | box_center=translation_inverse, 101 | box_rotation_matrix=rotmat_inverse, 102 | rays_start_point=rays_o, 103 | rays_end_point=rays_d) 104 | 105 | # Only apply rotation to rays_d. 106 | rays_d_obj = rotate_dirs(rays_d, rotmat) 107 | 108 | ray_batch_obj = update_ray_batch_slice(ray_batch, rays_o_obj, 0, 3) 109 | ray_batch_obj = update_ray_batch_slice(ray_batch_obj, rays_d_obj, 3, 6) 110 | if use_viewdirs: 111 | # Grab viewdirs from the ray batch itself. Because it may be different from rays_d 112 | # (as in the staticcam case). 113 | viewdirs = ray_batch[:, 8:11] 114 | viewdirs_obj = rotate_dirs(viewdirs, rotmat) 115 | ray_batch_obj = update_ray_batch_slice(ray_batch_obj, viewdirs_obj, 8, 11) 116 | return ray_batch_obj 117 | 118 | 119 | def transform_points_into_world_coordinate_frame(pts, params, check_numerics=False): 120 | translation, rotmat = get_transformation_from_params(params) # [3,], [3, 3] 121 | 122 | # pts_flat = pts.view(-1, 3) # [RS, 3] 123 | # num_examples = pts_flat.size()[0] # RS 124 | 125 | # translation = translation.unsqueeze(0) 126 | # translation = torch.tile(translation, (num_examples, 1)) # [RS, 3] 127 | # rotmat = rotmat.unsqueeze(0) 128 | # rotmat = torch.tile(rotmat, (num_examples, 1, 1)) 129 | 130 | # # pts_flat_transformed = torch.matmul(pts_flat[:, None, :], torch.transpose(rotmat, 2, 1)) # [RS, 1, 3] 131 | # pts_flat_transformed = pts_flat[:, None, :] # [RS, 1, 3] 132 | # pts_flat_transformed += translation[:, None, :] # [RS, 1, 3] 133 | # pts_transformed = pts_flat_transformed.view(pts.size()) # [R, S, 3] 134 | chunk = 256 135 | # Check batch transformations works without rotation. 
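# The check_numerics block below computes the same translation-only transform two ways (via apply_batched_transformations and via plain broadcasting of pts + translation); the value that is actually returned comes from the full rotation+translation path further down. Processing `chunk` rays at a time bounds the size of the tiled [chunk, S, 4, 4] intermediate built inside apply_batched_transformations.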
136 | if check_numerics: 137 | transformations = np.eye(4) 138 | transformations[:3, 3] = translation 139 | transformations = torch.tensor(transformations, dtype=torch.float) # [4, 4] 140 | transformations = torch.tile(transformations[None, ...], (pts.size()[0], 1, 1)) # [R, 4, 4] 141 | pts_transformed1 = [] 142 | for i in range(0, pts.size()[0], chunk): 143 | pts_transformed1_chunk = apply_batched_transformations( 144 | inputs=[pts[i:i+chunk]], transformations=transformations[i:i+chunk])[0] 145 | pts_transformed1.append(pts_transformed1_chunk) 146 | pts_transformed1 = torch.cat(pts_transformed1, dim=0) 147 | 148 | pts_transformed2 = pts + translation[None, None, :] 149 | 150 | # Now add rotation 151 | transformations = np.eye(4) 152 | transformations = torch.tensor(transformations, dtype=torch.float) 153 | transformations[:3, :3] = rotmat 154 | transformations[:3, 3] = translation 155 | #transformations = torch.tensor(transformations, dtype=torch.float) # [4, 4] 156 | transformations = torch.tile(transformations[None, ...], (pts.size()[0], 1, 1)) # [R, 4, 4] 157 | pts_transformed = [] 158 | for i in range(0, pts.size()[0], chunk): 159 | pts_transformed_chunk = apply_batched_transformations( 160 | inputs=[pts[i:i+chunk]], transformations=transformations[i:i+chunk])[0] 161 | pts_transformed.append(pts_transformed_chunk) 162 | pts_transformed = torch.cat(pts_transformed, dim=0) 163 | return pts_transformed 164 | 165 | 166 | # def transform_rays(ray_batch, translation, use_viewdirs): 167 | # """Apply transformation to rays. 168 | 169 | # Args: 170 | # ray_batch: [R, M] float tensor. All information necessary 171 | # for sampling along a ray, including: ray origin, ray direction, min 172 | # dist, max dist, and unit-magnitude viewing direction. 173 | # translation: [3,] float tensor. The (x, y, z) translation to apply. 174 | # use_viewdirs: Whether to use view directions. 175 | 176 | # Returns: 177 | # ray_batch: [R, M] float tensor. Transformed ray batch. 178 | # """ 179 | # assert translation.size()[0] == 3, "translation.size()[0] must be 3..." 180 | 181 | # # Since we are only supporting translation for now, only ray origins need to be 182 | # # modified. Ray directions do not need to change. 183 | # rays_o = ray_batch[:, 0:3] + translation 184 | # rays_remaining = ray_batch[:, 3:] 185 | # ray_batch = torch.cat((rays_o, rays_remaining), dim=1) 186 | # return ray_batch 187 | 188 | def compute_rays_length(rays_d): 189 | """Compute ray length. 190 | 191 | Args: 192 | rays_d: [R, 3] float tensor. Ray directions. 193 | 194 | Returns: 195 | rays_length: [R, 1] float tensor. Ray lengths. 196 | """ 197 | rays_length = torch.norm(rays_d, dim=-1, keepdim=True) # [N_rays, 1] 198 | return rays_length 199 | 200 | 201 | def normalize_rays(rays): 202 | """Normalize ray directions. 203 | 204 | Args: 205 | rays: [R, 3] float tensor. Ray directions. 206 | 207 | Returns: 208 | normalized_rays: [R, 3] float tensor. Normalized ray directions. 209 | """ 210 | normalized_rays = rays / compute_rays_length(rays_d=rays) 211 | return normalized_rays 212 | 213 | 214 | def compute_ray_dirs_and_length(rays_o, rays_dst): 215 | """Compute ray directions. 216 | 217 | Args: 218 | rays_o: [R, 3] float tensor. Ray origins. 219 | rays_dst: [R, 3] float tensor. Ray destinations. 220 | 221 | Returns: 222 | rays_d: [R, 3] float tensor. Normalized ray directions. 223 | """ 224 | # The ray directions are the difference between the ray destinations and the 225 | # ray origins. 
225 | # ray origins.
226 | rays_d = rays_dst - rays_o # [R, 3] # Direction out of light source 227 | 228 | # Compute the length of the rays. 229 | rays_length = compute_rays_length(rays_d=rays_d) 230 | 231 | # Normalized the ray directions. 232 | rays_d = rays_d / rays_length # [R, 3] # Normalize direction 233 | return rays_d, rays_length 234 | 235 | 236 | def update_ray_batch_slice(ray_batch, x, start, end): 237 | left = ray_batch[:, :start] # [R, ?] 238 | right = ray_batch[:, end:] # [R, ?] 239 | updated_ray_batch = torch.cat((left, x, right), dim=-1) 240 | return updated_ray_batch 241 | 242 | 243 | def update_ray_batch_bounds(ray_batch, bounds): 244 | updated_ray_batch = update_ray_batch_slice(ray_batch=ray_batch, x=bounds, 245 | start=6, end=8) 246 | return updated_ray_batch 247 | 248 | 249 | def create_ray_batch( 250 | rays_o, rays_dst, rays_i, use_viewdirs, rays_near=None, rays_far=None, epsilon=1e-10): 251 | # Compute the ray directions. 252 | rays_d = rays_dst - rays_o # [R,3] # Direction out of light source 253 | rays_length = compute_rays_length(rays_d=rays_d) # [R, 1] 254 | rays_d = rays_d / rays_length # [R, 3] # Normalize direction 255 | viewdirs = rays_d # [R, 3] 256 | 257 | # If bounds are not provided, set the beginning and end of ray as sampling bounds. 258 | if rays_near is None: 259 | rays_near = torch.zeros((rays_o.size()[0], 1), dtype=torch.float) + epsilon # [R, 1] 260 | if rays_far is None: 261 | rays_far = rays_length # [R, 1] 262 | 263 | ray_batch = torch.cat((rays_o, rays_d, rays_near, rays_far), dim=-1) 264 | if use_viewdirs: 265 | ray_batch = torch.cat((ray_batch, viewdirs), dim=-1) 266 | ray_batch = torch.cat((ray_batch, rays_i), dim=-1) 267 | return ray_batch 268 | 269 | 270 | def sample_random_lightdirs(num_rays, num_samples, upper_only=False): 271 | """Randomly sample directions in the unit sphere. 272 | 273 | Args: 274 | num_rays: int or tensor shape dimension. Number of rays. 275 | num_samples: int or tensor shape dimension. Number of samples per ray. 276 | upper_only: bool. Whether to sample only on the upper hemisphere. 277 | 278 | Returns: 279 | lightdirs: [R, S, 3] float tensor. Random light directions sampled from the unit 280 | sphere for each sampled point. 281 | """ 282 | if upper_only: 283 | min_z = 0 284 | else: 285 | min_z = -1 286 | 287 | phi = torch.rand(num_rays, num_samples) * (2 * math.pi) # [R, S] 288 | cos_theta = torch.rand(num_rays, num_samples) * (1 - min_z) + min_z # [R, S] 289 | theta = torch.acos(cos_theta) # [R, S] 290 | 291 | x = torch.sin(theta) * torch.cos(phi) 292 | y = torch.sin(theta) * torch.sin(phi) 293 | z = torch.cos(theta) 294 | 295 | lightdirs = torch.cat((x[..., None], y[..., None], z[..., None]), dim=-1) # [R, S, 3] 296 | return lightdirs 297 | 298 | 299 | def get_light_positions(rays_i, img_light_pos): 300 | """Extracts light positions given scene IDs. 301 | 302 | Args: 303 | rays_i: [R, 1] float tensor. Per-ray image IDs. 304 | img_light_pos: [N, 3] float tensor. Per-image light positions. 305 | 306 | Returns: 307 | rays_light_pos: [R, 3] float tensor. Per-ray light positions. 308 | """ 309 | #print("img_light_pos shape: ", img_light_pos.shape) 310 | rays_light_pos = img_light_pos[rays_i.long()].squeeze() # [R, 3] 311 | return rays_light_pos 312 | 313 | 314 | def get_lightdirs(lightdirs_method, num_rays=None, num_samples=None, rays_i=None, 315 | metadata=None, ray_batch=None, use_viewdirs=False, normalize=False): 316 | """Compute lightdirs. 317 | 318 | Args: 319 | lightdirs_method: str. Method to use for computing lightdirs. 
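One of 'viewdirs', 'metadata', 'random', or 'random_upper'; the 'viewdirs' branch currently raises NotImplementedError.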
320 | num_rays: int or tensor shape dimension. Number of rays. 321 | num_samples: int or tensor shape dimension. Number of samples per ray. 322 | rays_i: [R, 1] float tensor. Ray image IDs. 323 | metadata: [N, 3] float tensor. Metadata about each image. Currently only light 324 | position is provided. 325 | ray_batch: [R, M] float tensor. Ray batch. 326 | use_viewdirs: bool. Whether to use viewdirs. 327 | normalize: bool. Whether to normalize lightdirs. 328 | 329 | Returns; 330 | lightdirs: [R, S, 3] float tensor. Light directions for each sample. 331 | """ 332 | if lightdirs_method == 'viewdirs': 333 | raise NotImplementedError 334 | assert use_viewdirs 335 | lightdirs = ray_batch[:, 8:11] # [R, 3] 336 | lightdirs *= 1.5 337 | lightdirs = torch.tile(lightdirs[:, None, :], (1, num_samples, 1)) 338 | elif lightdirs_method == 'metadata': 339 | lightdirs = get_light_positions(rays_i, metadata) # [R, 3] 340 | lightdirs = torch.tile(lightdirs[:, None, :], (1, num_samples, 1)) # [R, S, 3] 341 | elif lightdirs_method == 'random': 342 | lightdirs = sample_random_lightdirs(num_rays, num_samples) # [R, S, 3] 343 | elif lightdirs_method == 'random_upper': 344 | lightdirs = sample_random_lightdirs(num_rays, num_samples, upper_only=True) # [R, S, 3] 345 | else: 346 | raise ValueError(f'Invalid lightdirs_method: {lightdirs_method}.') 347 | if normalize: 348 | lightdirs_flat = lightdirs.view(-1, 3) # [RS, 3] 349 | lightdirs_flat = normalize_rays(lightdirs_flat) # [RS, 3] 350 | lightdirs = lightdirs_flat.view(lightdirs.size()) # [R, S, 3] 351 | return lightdirs 352 | -------------------------------------------------------------------------------- /taxim_render.py: -------------------------------------------------------------------------------- 1 | ''' 2 | GelSight tactile render with taxim 3 | 4 | Zilin Si (zsi@andrew.cmu.edu) 5 | Last revision: March 2022 6 | ''' 7 | 8 | import os 9 | from os import path as osp 10 | import numpy as np 11 | import cv2 12 | 13 | from basics import sensorParams as psp 14 | from basics.CalibData import CalibData 15 | 16 | class TaximRender: 17 | 18 | def __init__(self, calib_path): 19 | # taxim calibration files 20 | # polytable 21 | calib_data = osp.join(calib_path, "polycalib.npz") 22 | self.calib_data = CalibData(calib_data) 23 | # raw calibration data 24 | rawData = osp.join(calib_path, "dataPack.npz") 25 | data_file = np.load(rawData, allow_pickle=True) 26 | self.f0 = data_file['f0'] 27 | ## tactile image config 28 | bins = psp.numBins 29 | [xx, yy] = np.meshgrid(range(psp.w), range(psp.h)) 30 | xf = xx.flatten() 31 | yf = yy.flatten() 32 | self.A = np.array([xf*xf,yf*yf,xf*yf,xf,yf,np.ones(psp.h*psp.w)]).T 33 | binm = bins - 1 34 | self.x_binr = 0.5*np.pi/binm # x [0,pi/2] 35 | self.y_binr = 2*np.pi/binm # y [-pi, pi] 36 | 37 | # load depth bg 38 | self.bg_depth = np.load(osp.join(calib_path,"depth_bg.npy"), allow_pickle=True) 39 | # load tactile bg 40 | self.real_bg = np.load(osp.join(calib_path,"real_bg.npy"), allow_pickle=True) 41 | 42 | def correct_height_map(self, height_map): 43 | # move the center of depth to the origin 44 | height_map = (height_map-psp.cam2gel) * -1000 / psp.pixmm 45 | return height_map 46 | 47 | def padding(self, img): 48 | # pad one row & one col on each side 49 | if len(img.shape) == 2: 50 | return np.pad(img, ((1, 1), (1, 1)), 'edge') 51 | elif len(img.shape) == 3: 52 | return np.pad(img, ((1, 1), (1, 1), (0, 0)), 'edge') 53 | 54 | def generate_normals(self, height_map): 55 | # from height map to gradient magnitude & directions 56 | 57 | 
[h,w] = height_map.shape 58 | center = height_map[1:h-1,1:w-1] # z(x,y) 59 | top = height_map[0:h-2,1:w-1] # z(x-1,y) 60 | bot = height_map[2:h,1:w-1] # z(x+1,y) 61 | left = height_map[1:h-1,0:w-2] # z(x,y-1) 62 | right = height_map[1:h-1,2:w] # z(x,y+1) 63 | dzdx = (bot-top)/2.0 64 | dzdy = (right-left)/2.0 65 | 66 | mag_tan = np.sqrt(dzdx**2 + dzdy**2) 67 | grad_mag = np.arctan(mag_tan) 68 | invalid_mask = mag_tan == 0 69 | valid_mask = ~invalid_mask 70 | grad_dir = np.zeros((h-2,w-2)) 71 | grad_dir[valid_mask] = np.arctan2(dzdx[valid_mask]/mag_tan[valid_mask], dzdy[valid_mask]/mag_tan[valid_mask]) 72 | 73 | grad_mag = self.padding(grad_mag) 74 | grad_dir = self.padding(grad_dir) 75 | return grad_mag, grad_dir 76 | 77 | def render(self, depth, press_depth): 78 | 79 | depth = self.correct_height_map(depth) 80 | height_map = depth.copy() 81 | 82 | ## generate contact mask 83 | pressing_height_pix = press_depth * 1000 / psp.pixmm 84 | contact_mask = (height_map-(self.bg_depth)) > pressing_height_pix * 0.2 85 | 86 | # smooth out the soft contact 87 | zq_back = height_map.copy() 88 | kernel_size = [11,5] 89 | for k in range(len(kernel_size)): 90 | height_map = cv2.GaussianBlur(height_map.astype(np.float32),(kernel_size[k],kernel_size[k]),0) 91 | height_map[contact_mask] = zq_back[contact_mask] 92 | # height_map = cv2.GaussianBlur(height_map.astype(np.float32),(5,5),0) 93 | 94 | # generate gradients 95 | grad_mag, grad_dir = self.generate_normals(height_map) 96 | 97 | # simulate raw image 98 | sim_img_r = np.zeros((psp.h,psp.w,3)) 99 | idx_x = np.floor(grad_mag/self.x_binr).astype('int') 100 | idx_y = np.floor((grad_dir+np.pi)/self.y_binr).astype('int') 101 | 102 | params_r = self.calib_data.grad_r[idx_x,idx_y,:] 103 | params_r = params_r.reshape((psp.h*psp.w), params_r.shape[2]) 104 | params_g = self.calib_data.grad_g[idx_x,idx_y,:] 105 | params_g = params_g.reshape((psp.h*psp.w), params_g.shape[2]) 106 | params_b = self.calib_data.grad_b[idx_x,idx_y,:] 107 | params_b = params_b.reshape((psp.h*psp.w), params_b.shape[2]) 108 | 109 | est_r = np.sum(self.A * params_r,axis = 1) 110 | est_g = np.sum(self.A * params_g,axis = 1) 111 | est_b = np.sum(self.A * params_b,axis = 1) 112 | sim_img_r[:,:,0] = est_r.reshape((psp.h,psp.w)) 113 | sim_img_r[:,:,1] = est_g.reshape((psp.h,psp.w)) 114 | sim_img_r[:,:,2] = est_b.reshape((psp.h,psp.w)) 115 | 116 | # add back ground 117 | tactile_img = sim_img_r + self.real_bg 118 | tactile_img = np.clip(tactile_img, 0, 255) 119 | 120 | return height_map, contact_mask, tactile_img 121 | 122 | #if __name__ == "__main__": 123 | 124 | # define the press depth, and get the depth map from touch net. 
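# A possible way to set up those two inputs before the calls below (a sketch only; the calibration folder, placeholder depth map, and 0.001 m press depth are illustrative assumptions, not values fixed by this file):
# calib_path = "calibs" # folder containing polycalib.npz, dataPack.npz, depth_bg.npy, real_bg.npy
# press_depth = 0.001 # pressing depth in meters
# depth = np.ones((psp.h, psp.w)) * psp.cam2gel # a flat, non-contacting gel surface as a placeholder depth map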
125 | # taxim = TaximRender(calib_path) 126 | # height_map, contact_mask, tactile_img = taxim.render(depth, press_depth) 127 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import signal 3 | import time 4 | from collections import deque, defaultdict 5 | from itertools import product 6 | import numpy as np 7 | import argparse 8 | import yaml 9 | import run_nerf_helpers 10 | import torch 11 | from tqdm import tqdm 12 | import lpips 13 | 14 | class CameraIntrinsics: 15 | def __init__(self, H, W, fx, fy, cx, cy): 16 | self.H = H 17 | self.W = W 18 | self.fx = fx 19 | self.fy = fy 20 | self.cx = cx 21 | self.cy = cy 22 | 23 | def get_random_points_inside_domain(num_points, domain_min, domain_max): 24 | x = np.random.uniform(domain_min[0], domain_max[0], size=(num_points,)) 25 | y = np.random.uniform(domain_min[1], domain_max[1], size=(num_points,)) 26 | z = np.random.uniform(domain_min[2], domain_max[2], size=(num_points,)) 27 | return np.column_stack((x, y, z)) 28 | 29 | def get_random_directions(num_samples): 30 | random_directions = np.random.randn(num_samples, 3) 31 | random_directions /= np.linalg.norm(random_directions, axis=1).reshape(-1, 1) 32 | return random_directions 33 | # return 2.5 * random_directions 34 | 35 | def get_random_lights(num_samples): 36 | random_lights = np.random.randn(num_samples, 3) 37 | random_lights /= np.linalg.norm(random_lights, axis=1).reshape(-1, 1) 38 | return random_lights 39 | 40 | def load_pretrained_nerf_model(dev, cfg): 41 | pretrained_cfg = load_yaml_as_dict(cfg['pretrained_cfg_path']) 42 | if 'use_initialization_fix' not in pretrained_cfg: 43 | pretrained_cfg['use_initialization_fix'] = False 44 | if 'num_importance_samples_per_ray' not in pretrained_cfg: 45 | pretrained_cfg['num_importance_samples_per_ray'] = 0 46 | pretrained_nerf, embed_fn, embeddirs_fn, embedlights_fn = create_nerf(pretrained_cfg) 47 | pretrained_nerf = pretrained_nerf.to(dev) 48 | checkpoint = torch.load(cfg['pretrained_checkpoint_path']) 49 | pretrained_nerf.load_state_dict(checkpoint['model_state_dict']) 50 | pretrained_nerf = run_nerf_helpers.ChainEmbeddingAndModel(pretrained_nerf.model_coarse, embed_fn, embeddirs_fn, embedlights_fn) # pos. 
encoding 51 | return pretrained_nerf 52 | 53 | def create_nerf(cfg): 54 | embed_fn, input_ch = run_nerf_helpers.get_embedder(cfg['num_frequencies'], 0) 55 | embeddirs_fn, input_ch_views = run_nerf_helpers.get_embedder(cfg['num_frequencies_direction'], 0) 56 | embedlights_fn, input_ch_lights = run_nerf_helpers.get_embedder(cfg['num_frequencies_direction'], 0) 57 | output_ch = 4 58 | skips = [cfg['refeed_position_index']] 59 | model = run_nerf_helpers.NeRF2(D=cfg['num_hidden_layers'], W=cfg['hidden_layer_size'], 60 | input_ch=input_ch, output_ch=output_ch, skips=skips, 61 | input_ch_views=input_ch_views, input_ch_lights=input_ch_lights, 62 | use_viewdirs=True, use_lightdirs=True, 63 | direction_layer_size=cfg['direction_layer_size'], use_initialization_fix=cfg['use_initialization_fix']) 64 | 65 | if cfg['num_importance_samples_per_ray'] > 0: 66 | model_fine = run_nerf_helpers.NeRF2(D=cfg['num_hidden_layers'], W=cfg['hidden_layer_size'], 67 | input_ch=input_ch, output_ch=output_ch, skips=skips, 68 | input_ch_views=input_ch_views, input_ch_lights=input_ch_lights, 69 | use_viewdirs=True, use_lightdirs=True, 70 | direction_layer_size=cfg['direction_layer_size'], use_initialization_fix=cfg['use_initialization_fix']) 71 | model = run_nerf_helpers.CoarseAndFine(model, model_fine) 72 | 73 | return model, embed_fn, embeddirs_fn, embedlights_fn 74 | 75 | def query_densities(points, pretrained_nerf, cfg, dev): 76 | mock_directions = torch.zeros_like(points) # density does not depend on direction 77 | points_and_dirs = torch.cat([points, mock_directions], dim=1) 78 | num_points_and_dirs = points_and_dirs.size(0) 79 | densities = torch.empty(num_points_and_dirs) 80 | if 'query_batch_size' in cfg: 81 | query_batch_size = cfg['query_batch_size'] 82 | else: 83 | query_batch_size = num_points_and_dirs 84 | with torch.no_grad(): 85 | start = 0 86 | while start < num_points_and_dirs: 87 | end = min(start + query_batch_size, num_points_and_dirs) 88 | densities[start:end] = torch.relu(pretrained_nerf(points_and_dirs[start:end].to(dev))[:, -1]).cpu() # Only select the densities (A) from NeRF's RGBA output 89 | start = end 90 | return densities 91 | 92 | def has_flag(cfg, name): 93 | return name in cfg and cfg[name] 94 | 95 | def load_yaml_as_dict(path): 96 | with open(path) as yaml_file: 97 | yaml_as_dict = yaml.load(yaml_file, Loader=yaml.FullLoader) 98 | return yaml_as_dict 99 | 100 | def parse_args_and_init_logger(default_cfg_path=None, parse_render_cfg_path=False): 101 | parser = argparse.ArgumentParser(description='NeRF distillation') 102 | parser.add_argument('cfg_path', type=str) 103 | parser.add_argument('log_path', type=str, nargs='?') 104 | if parse_render_cfg_path: 105 | parser.add_argument('-rcfg', '--render_cfg_path', type=str) 106 | args = parser.parse_args() 107 | if args.log_path is None: 108 | start = args.cfg_path.find('/') 109 | end = args.cfg_path.rfind('.') 110 | args.log_path = 'logs' + args.cfg_path[start:end] 111 | print('auto log path:', args.log_path) 112 | 113 | create_directory_if_not_exists(args.log_path) 114 | Logger.filename = args.log_path + '/log.txt' 115 | 116 | cfg = load_yaml_as_dict(args.cfg_path) 117 | if default_cfg_path is not None: 118 | default_cfg = load_yaml_as_dict(default_cfg_path) 119 | for key in default_cfg: 120 | if key not in cfg: 121 | cfg[key] = default_cfg[key] 122 | print(cfg) 123 | 124 | ret_val = (cfg, args.log_path) 125 | if parse_render_cfg_path: 126 | ret_val += (args.render_cfg_path,) 127 | 128 | return ret_val 129 | 130 | class IterativeMean: 131 | def
__init__(self): 132 | self.value = None 133 | self.num_old_values = 0 134 | 135 | def add_values(self, new_values): 136 | if self.value is not None: 137 | self.value = (self.num_old_values * self.value + new_values.size(0) * new_values.mean()) / (self.num_old_values + new_values.size(0)) 138 | else: 139 | self.value = new_values.mean() 140 | self.num_old_values += new_values.size(0) 141 | 142 | def get_mean(self): 143 | return self.value.item() 144 | 145 | 146 | def create_directory_if_not_exists(directory): 147 | if not os.path.isdir(directory): 148 | os.makedirs(directory) 149 | 150 | class Logger: 151 | filename = None 152 | 153 | @staticmethod 154 | def write(text): 155 | with open(Logger.filename, 'a') as log_file: 156 | print(text, flush=True) 157 | log_file.write(text + '\n') 158 | 159 | class GracefulKiller: 160 | kill_now = False 161 | def __init__(self): 162 | signal.signal(signal.SIGUSR1, self.exit_gracefully) 163 | 164 | def exit_gracefully(self, signum, frame): 165 | self.kill_now = True 166 | 167 | def extract_domain_boxes_from_tree(root_node): 168 | nodes_to_process = deque([root_node]) 169 | boxes = [] 170 | while nodes_to_process: 171 | node = nodes_to_process.popleft() 172 | if hasattr(node, 'leq_child'): 173 | nodes_to_process.append(node.leq_child) 174 | nodes_to_process.append(node.gt_child) 175 | else: 176 | boxes.append([node.domain_min, node.domain_max]) 177 | 178 | return boxes 179 | 180 | def write_boxes_to_obj(boxes, obj_filename): 181 | txt = '' 182 | i = 0 183 | for box in tqdm(boxes): 184 | for min_or_max in product(range(2), repeat=3): 185 | txt += 'v {} {} {}\n'.format(box[min_or_max[0]][0], box[min_or_max[1]][1], box[min_or_max[2]][2]) 186 | for x, y, z in [(0b000, 0b100, 0b010), (0b100, 0b010, 0b110), 187 | (0b001, 0b101, 0b011), (0b101, 0b011, 0b111), 188 | (0b000, 0b010, 0b001), (0b001, 0b011, 0b010), 189 | (0b100, 0b110, 0b101), (0b101, 0b111, 0b110), 190 | (0b000, 0b100, 0b001), (0b100, 0b101, 0b001), 191 | (0b010, 0b110, 0b011), (0b110, 0b111, 0b011)]: 192 | txt += 'f {} {} {}\n'.format(1 + i * 8 + x, 1 + i * 8 + y, 1 + i * 8 + z) 193 | i += 1 194 | 195 | with open(obj_filename, 'a') as obj_file: 196 | obj_file.write(txt) 197 | 198 | class PerfMonitor: 199 | events = [] 200 | is_active = True 201 | 202 | @staticmethod 203 | def add(name, groups=[]): 204 | if PerfMonitor.is_active: 205 | torch.cuda.synchronize() 206 | t = time.perf_counter() 207 | PerfMonitor.events.append((name, t, groups)) 208 | 209 | @staticmethod 210 | def log_and_reset(write_detailed_log): 211 | previous_t = PerfMonitor.events[0][1] 212 | group_map = defaultdict(float) 213 | elapsed_times = [] 214 | for name, t, groups in PerfMonitor.events[1:]: 215 | elapsed_time = t - previous_t 216 | elapsed_times.append(elapsed_time) 217 | for group in groups: 218 | group_map[group] += elapsed_time 219 | group_map['total'] += elapsed_time 220 | previous_t = t 221 | max_length = max([len(name) for name, _, _ in PerfMonitor.events] + [len(group) for group in group_map]) 222 | 223 | if write_detailed_log: 224 | for event, elapsed_time in zip(PerfMonitor.events[1:], elapsed_times): 225 | name = event[0] 226 | extra_whitespace = ' ' * (max_length - len(name)) 227 | Logger.write('{}:{} {:7.2f} ms'.format(name, extra_whitespace, 1000 * (elapsed_time))) 228 | Logger.write('') 229 | for group in group_map: 230 | extra_whitespace = ' ' * (max_length - len(group)) 231 | Logger.write('{}:{} {:7.2f} ms'.format(group, extra_whitespace, 1000 * (group_map[group]))) 232 | 233 | # Reset 234 | PerfMonitor.events = [] 235 |
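# group_map['total'] is the wall-clock time in seconds between the first and last recorded event; the renderer above returns it as the elapsed render time.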
236 | return group_map['total'] 237 | 238 | class LPIPS: 239 | loss_fn_alex = None 240 | 241 | @staticmethod 242 | def calculate(img_a, img_b): 243 | img_a, img_b = [img.permute([2, 1, 0]).unsqueeze(0) for img in [img_a, img_b]] 244 | if LPIPS.loss_fn_alex == None: # lazy init 245 | LPIPS.loss_fn_alex = lpips.LPIPS(net='alex', version='0.1') 246 | return LPIPS.loss_fn_alex(img_a, img_b) 247 | 248 | 249 | def get_distance_to_closest_point_in_box(point, domain_min, domain_max): 250 | closest_point = np.array([0., 0., 0.]) 251 | for dim in range(3): 252 | if point[dim] < domain_min[dim]: 253 | closest_point[dim] = domain_min[dim] 254 | elif domain_max[dim] < point[dim]: 255 | closest_point[dim] = domain_max[dim] 256 | else: # in between domain_min and domain_max 257 | closest_point[dim] = point[dim] 258 | return np.linalg.norm(point - closest_point) 259 | 260 | def get_distance_to_furthest_point_in_box(point, domain_min, domain_max): 261 | furthest_point = np.array([0., 0., 0.]) 262 | for dim in range(3): 263 | mid = (domain_min[dim] + domain_max[dim]) / 2 264 | if point[dim] > mid: 265 | furthest_point[dim] = domain_min[dim] 266 | else: 267 | furthest_point[dim] = domain_max[dim] 268 | return np.linalg.norm(point - furthest_point) 269 | 270 | def load_matrix(path): 271 | return np.array([[float(w) for w in line.strip().split()] for line in open(path)]).astype(np.float32) 272 | 273 | class ConfigManager: 274 | global_domain_min = None 275 | global_domain_max = None 276 | 277 | @staticmethod 278 | def init(cfg): 279 | if 'global_domain_min' in cfg and 'global_domain_max' in cfg: 280 | ConfigManager.global_domain_min = cfg['global_domain_min'] 281 | ConfigManager.global_domain_max = cfg['global_domain_max'] 282 | elif 'dataset_dir' in cfg and cfg['dataset_type'] == 'nsvf': 283 | bbox_path = os.path.join(cfg['dataset_dir'], 'bbox.txt') 284 | bounding_box = load_matrix(bbox_path)[0, :-1] 285 | ConfigManager.global_domain_min = bounding_box[:3] 286 | ConfigManager.global_domain_max = bounding_box[3:] 287 | 288 | @staticmethod 289 | def get_global_domain_min_and_max(device=None): 290 | result = ConfigManager.global_domain_min, ConfigManager.global_domain_max 291 | if device: 292 | result = [torch.tensor(x, dtype=torch.float, device=device) for x in result] 293 | return result 294 | 295 | 296 | def main(): 297 | if False: 298 | boxes = [ 299 | [[-0.078125, 0.390625, 0.546875], [-0.0625, 0.40625, 0.5625]], 300 | [[-0.625, -0.375, -0.375], [-0.5, -0.25, -0.25]], 301 | [[-0.625, -0.25, -0.375], [-0.5, -0.125, -0.25]], 302 | [[-0.625, -0.125, -0.375], [-0.5, 0.0, -0.25]], 303 | [[-0.5, 0.5, 0.0], [-0.375, 0.625, 0.25]], 304 | [[-0.125, 0.125, -0.5], [0.0, 0.25, -0.25]], 305 | [[-0.375, -0.25, 0.0], [-0.25, -0.125, 0.25]], 306 | [[-0.625, -0.5, -0.5], [-0.5, -0.375, -0.25]], 307 | [[-0.125, 0.0, 0.5], [0.0, 0.25, 0.75]] 308 | ] 309 | boxes = [[[0.15625, -0.3125, 0.8125], [0.1875, -0.25, 0.875]]] 310 | print(boxes) 311 | write_boxes_to_obj(boxes, 'hard_domains_2.obj') 312 | 313 | if __name__ == '__main__': 314 | main() 315 | -------------------------------------------------------------------------------- /von_mises.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | ''' 4 | Pick a point uniformly from the unit circle 5 | ''' 6 | def circle_uniform_pick(size, out = None): 7 | if out is None: 8 | out = np.empty((size, 2)) 9 | 10 | angle = 2 * np.pi * np.random.random(size) 11 | out[:,0], out[:,1] = np.cos(angle), np.sin(angle) 12 | 13 | return 
out 14 | 15 | def cross_product_matrix_batched(U): 16 | batch_size = U.shape[0] 17 | result = np.zeros(shape=(batch_size, 3, 3)) 18 | result[:, 0, 1] = -U[:, 2] 19 | result[:, 0, 2] = U[:, 1] 20 | result[:, 1, 0] = U[:, 2] 21 | result[:, 1, 2] = -U[:, 0] 22 | result[:, 2, 0] = -U[:, 1] 23 | result[:, 2, 1] = U[:, 0] 24 | return result 25 | 26 | ''' 27 | Von Mises-Fisher distribution, ie. isotropic Gaussian distribution defined over 28 | a sphere. 29 | mus => mean directions 30 | kappa => concentration 31 | 32 | Uses numerical tricks described in "Numerically stable sampling of the von 33 | Mises Fisher distribution on S2 (and other tricks)" by Wenzel Jakob 34 | ''' 35 | 36 | def sample_von_mises_3d(mus, kappa, out=None): 37 | size = mus.shape[0] 38 | 39 | # Generate the samples for mu=(0, 0, 1) 40 | eta = np.random.random(size) 41 | tmp = 1. - (((eta - 1.) / eta) * np.exp(-2. * kappa)) 42 | W = 1. + (np.log(eta) + np.log(tmp)) / kappa 43 | 44 | V = np.empty((size, 2)) 45 | circle_uniform_pick(size, out = V) 46 | V *= np.sqrt(1. - W ** 2)[:, None] 47 | 48 | if out is None: 49 | out = np.empty((size, 3)) 50 | 51 | out[:, 0], out[:, 1], out[:, 2] = V[:, 0], V[:, 1], W 52 | 53 | angles = np.arccos(mus[:, 2]) 54 | mask = angles != 0. 55 | angles = angles[mask] 56 | mus = mus[mask] 57 | 58 | axis = np.zeros(shape=mus.shape) 59 | axis[:, 0] = -mus[:, 1] 60 | axis[:, 1] = mus[:, 0] 61 | 62 | axis /= np.sqrt(np.sum(axis ** 2, axis=1))[:, None] 63 | rot = np.cos(angles)[:, None, None] * np.identity(3)[None, :, :] 64 | rot += np.sin(angles)[:, None, None] * cross_product_matrix_batched(axis) 65 | rot += (1. - np.cos(angles))[:, None, None] * np.matmul(axis[:, :, None], axis[:, None, :]) 66 | 67 | out[mask] = (rot @ out[mask, :, None])[:, :, 0] 68 | return out 69 | 70 | 71 | if __name__ == '__main__': 72 | from math import sqrt 73 | mus = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], 74 | [1 / sqrt(2), -1 / sqrt(2), 0], 75 | [-1 / sqrt(2), -1 / sqrt(2), 0], 76 | [-1 / sqrt(2), 1 / sqrt(2), 0], 77 | [-1 / sqrt(2), 0., -1 / sqrt(2)], 78 | [1 / sqrt(2), 0., 1 / sqrt(2)]]) 79 | print(mus) 80 | 81 | sampled = sample_von_mises_3d(mus, 100000) 82 | print(sampled) 83 | 84 | --------------------------------------------------------------------------------
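A quick sanity check for the von Mises-Fisher sampler in von_mises.py: draw many samples around a single mean direction and confirm that the average cosine to that direction approaches 1 - 1/kappa for large concentration. The snippet below is a minimal sketch; the sample count, kappa value, and tolerance are illustrative assumptions.

import numpy as np
from von_mises import sample_von_mises_3d

mus = np.tile(np.array([[0.0, 0.0, 1.0]]), (20000, 1))    # same mean direction (+z) for every sample
samples = sample_von_mises_3d(mus, kappa=100.0)           # [20000, 3] unit vectors clustered around +z
mean_cos = float(np.mean(np.sum(samples * mus, axis=1)))  # average cosine between samples and the mean
print(mean_cos)                                           # expected to be roughly 1 - 1/kappa = 0.99
assert abs(mean_cos - 0.99) < 0.01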