├── AudioNet_model.py
├── AudioNet_utils.py
├── LICENSE
├── OF_render.py
├── ObjectFolder1.0
│   ├── AudioNet_model.py
│   ├── AudioNet_utils.py
│   ├── ObjectFolder_teaser.png
│   ├── README.md
│   ├── TouchNet_model.py
│   ├── TouchNet_utils.py
│   ├── VisionNet_utils.py
│   ├── box_utils.py
│   ├── cam_utils.py
│   ├── dataset_visualization.png
│   ├── demo
│   │   ├── audio_demo_forces.npy
│   │   ├── audio_demo_vertices.npy
│   │   ├── touch_demo_vertices.npy
│   │   └── vision_demo.npy
│   ├── evaluate.py
│   ├── indirect_utils.py
│   ├── intersect.py
│   ├── load_osf.py
│   ├── ray_utils.py
│   ├── requirements.txt
│   ├── run_osf_helpers.py
│   ├── scatter.py
│   └── shadow_utils.py
├── ObjectFolder2.0_teaser.png
├── README.md
├── TouchNet_model.py
├── TouchNet_utils.py
├── VisionNet_configs.py
├── VisionNet_utils.py
├── basics
│   ├── CalibData.py
│   ├── __init__.py
│   └── sensorParams.py
├── box_utils.py
├── build_occupancy_tree.py
├── calibs
│   ├── dataPack.npz
│   ├── depth_bg.npy
│   ├── polycalib.npz
│   └── real_bg.npy
├── cam_utils.py
├── ddsp_torch.py
├── demo
│   ├── ObjectFile.pth
│   ├── audio_demo_forces.npy
│   ├── audio_demo_vertices.npy
│   ├── model.obj
│   ├── touch_demo_gelinfo.npy
│   ├── touch_demo_vertices.npy
│   └── vision_demo.npy
├── environment.yml
├── fast_kilonerf_renderer.py
├── load_osf.py
├── local_distill.py
├── multi_modules.py
├── objects.csv
├── ray_utils.py
├── run_nerf_helpers.py
├── taxim_render.py
├── utils.py
└── von_mises.py
/AudioNet_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | torch.autograd.set_detect_anomaly(True)
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 | from AudioNet_utils import *
7 |
8 | class DenseLayer(nn.Linear):
9 | def __init__(self, in_dim: int, out_dim: int, activation: str = 'relu', *args, **kwargs) -> None:
10 | self.activation = activation
11 | super().__init__(in_dim, out_dim, *args, **kwargs)
12 |
13 | def reset_parameters(self) -> None:
14 | torch.nn.init.xavier_uniform_(self.weight, gain=torch.nn.init.calculate_gain(self.activation))
15 | if self.bias is not None:
16 | torch.nn.init.zeros_(self.bias)
17 |
18 | class Embedder:
19 | def __init__(self, **kwargs):
20 | self.kwargs = kwargs
21 | self.create_embedding_fn()
22 |
23 | def create_embedding_fn(self):
24 | embed_fns = []
25 | d = self.kwargs['input_dims']
26 | out_dim = 0
27 | if self.kwargs['include_input']:
28 | embed_fns.append(lambda x: x)
29 | out_dim += d
30 |
31 | max_freq = self.kwargs['max_freq_log2']
32 | N_freqs = self.kwargs['num_freqs']
33 |
34 | if self.kwargs['log_sampling']:
35 | freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
36 | else:
37 | freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
38 |
39 | for freq in freq_bands:
40 | for p_fn in self.kwargs['periodic_fns']:
41 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
42 | out_dim += d
43 |
44 | self.embed_fns = embed_fns
45 | self.out_dim = out_dim
46 |
47 | def embed(self, inputs):
48 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
49 |
50 | def get_embedder(multires, i=0):
51 | if i == -1:
52 | #5 is for x, y, z, f, t
53 | return nn.Identity(), 5
54 |
55 | embed_kwargs = {
56 | 'include_input': True,
57 | 'input_dims': 3,
58 | 'max_freq_log2': multires-1,
59 | 'num_freqs': multires,
60 | 'log_sampling': True,
61 | 'periodic_fns': [torch.sin, torch.cos],
62 | }
63 |
64 | embedder_obj = Embedder(**embed_kwargs)
65 | embed = lambda x, eo=embedder_obj: eo.embed(x)
66 | return embed, embedder_obj.out_dim
67 |
68 | class AudioNeRF(nn.Module):
69 | def __init__(self, D=8, input_ch=5, output_ch=2):
70 | super(AudioNeRF, self).__init__()
71 | self.model_x = NeRF(D = D, input_ch = input_ch, output_ch = output_ch)
72 | self.model_y = NeRF(D = D, input_ch = input_ch, output_ch = output_ch)
73 | self.model_z = NeRF(D = D, input_ch = input_ch, output_ch = output_ch)
74 |
75 | def forward(self, embedded_x, embedded_y, embedded_z):
76 | results_x = self.model_x(embedded_x)
77 | results_y = self.model_y(embedded_y)
78 | results_z = self.model_z(embedded_z)
79 | return results_x, results_y, results_z
80 |
81 |
82 | class NeRF(nn.Module):
83 | def __init__(self, D=8, W=256, input_ch=5, input_ch_views=0, output_ch=2, skips=[4], use_viewdirs=False):
84 | """NeRF-style MLP: D fully-connected layers with skip connections that map
85 | positionally-encoded inputs to the output channels."""
86 | super(NeRF, self).__init__()
87 | self.D = D
88 | self.W = W
89 | self.input_ch = input_ch
90 | self.input_ch_views = input_ch_views
91 | self.skips = skips
92 | self.use_viewdirs = use_viewdirs
93 |
94 | self.pts_linears = nn.ModuleList(
95 | [DenseLayer(input_ch, W, activation='relu')] + [DenseLayer(W, W, activation='relu') if i not in self.skips else DenseLayer(W + input_ch, W, activation='relu') for i in range(D-1)])
96 |
97 | self.views_linears = nn.ModuleList([DenseLayer(input_ch_views + W, W//2, activation='relu')])
98 |
99 | if use_viewdirs:
100 | self.feature_linear = DenseLayer(W, W, activation='sigmoid')
101 | #self.alpha_linear = DenseLayer(W, 1, activation='linear')
102 | self.rgb_linear = DenseLayer(W//2, output_ch, activation='sigmoid')
103 | else:
104 | self.output_linear = DenseLayer(W, output_ch, activation='sigmoid')
105 |
106 |
107 | def forward(self, x):
108 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
109 | h = input_pts
110 | for i, l in enumerate(self.pts_linears):
111 | h = self.pts_linears[i](h)
112 | h = F.relu(h)
113 | if i in self.skips:
114 | h = torch.cat([input_pts, h], -1)
115 |
116 | if self.use_viewdirs:
117 | feature = self.feature_linear(h)
118 | h = torch.cat([feature, input_views], -1)
119 |
120 | for i, l in enumerate(self.views_linears):
121 | h = self.views_linears[i](h)
122 | h = F.relu(h)
123 |
124 | outputs = self.rgb_linear(h)
125 | else:
126 | outputs = self.output_linear(h)
127 |
128 | return outputs
129 |
--------------------------------------------------------------------------------
/AudioNet_utils.py:
--------------------------------------------------------------------------------
1 | from scipy.io import wavfile
2 | import librosa
3 | import librosa.display
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | from AudioNet_model import *
7 | import os, shutil
8 | from collections import OrderedDict
9 |
10 | def strip_prefix_if_present(state_dict, prefix):
11 | keys = sorted(state_dict.keys())
12 | if not all(key.startswith(prefix) for key in keys):
13 | return state_dict
14 | stripped_state_dict = OrderedDict()
15 | for key, value in state_dict.items():
16 | stripped_state_dict[key.replace(prefix, "")] = value
17 | return stripped_state_dict
18 |
19 | def mkdirs(path, remove=False):
20 | if os.path.isdir(path):
21 | if remove:
22 | shutil.rmtree(path)
23 | else:
24 | return
25 | os.makedirs(path)
26 |
27 | def generate_spectrogram_magphase(audio, stft_frame, stft_hop, n_fft, with_phase=False):
28 | spectro = librosa.core.stft(audio, hop_length=stft_hop, n_fft=n_fft, win_length=stft_frame, center=True)
29 | spectro_mag, spectro_phase = librosa.core.magphase(spectro)
30 | spectro_mag = np.expand_dims(spectro_mag, axis=0)
31 | if with_phase:
32 | spectro_phase = np.expand_dims(np.angle(spectro_phase), axis=0)
33 | return spectro_mag, spectro_phase
34 | else:
35 | return spectro_mag
36 |
37 | def generate_spectrogram_complex(audio, stft_frame, stft_hop, n_fft):
38 | spectro = librosa.core.stft(audio, hop_length=stft_hop, n_fft=n_fft, win_length=stft_frame, center=True)
39 | real = np.expand_dims(np.real(spectro), axis=0)
40 | imag = np.expand_dims(np.imag(spectro), axis=0)
41 | spectro_two_channel = np.concatenate((real, imag), axis=0)
42 | return spectro_two_channel
43 |
44 | def batchify(fn, chunk):
45 | """
46 | Constructs a version of 'fn' that applies to smaller batches
47 | """
48 | if chunk is None:
49 | return fn
50 | def ret(inputs):
51 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
52 | return ret
53 |
54 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
55 | """
56 | Prepares inputs and applies network 'fn'.
57 | """
58 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
59 | embedded = embed_fn(inputs_flat)
60 |
61 | if viewdirs is not None:
62 | input_dirs = viewdirs[:,None].expand(inputs.shape)
63 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
64 | embedded_dirs = embeddirs_fn(input_dirs_flat)
65 | embedded = torch.cat([embedded, embedded_dirs], -1)
66 |
67 | outputs_flat = batchify(fn, netchunk)(embedded)
68 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
69 | return outputs
70 |
71 | def create_nerf(args):
72 | """
73 | Instantiate NeRF's MLP model.
74 | """
75 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed)
76 |
77 | input_ch_views = 0
78 | embeddirs_fn = None
79 | if args.use_viewdirs:
80 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)
81 | output_ch = 2
82 | skips = [4]
83 | model = NeRF(D=args.netdepth, W=args.netwidth,
84 | input_ch=input_ch, output_ch=output_ch, skips=skips,
85 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs)
86 | model = nn.DataParallel(model).to(device)
87 | grad_vars = list(model.parameters())
88 |
--------------------------------------------------------------------------------
/OF_render.py:
--------------------------------------------------------------------------------
1 | import os
2 | import datetime
3 | import sys
4 | import torch
5 | import torch.nn as nn
6 | import numpy as np
7 | import imageio
8 | import json
9 | import random
10 | import time
11 | from tqdm import tqdm, trange
12 | import scipy.signal
13 | import librosa
14 | from scipy.io.wavfile import write
15 | from scipy.spatial import KDTree
16 | import ddsp_torch as ddsp
17 | import itertools
18 | from taxim_render import TaximRender
19 | from PIL import Image
20 | import argparse
21 | from load_osf import load_osf_data
22 | import AudioNet_utils
23 | import AudioNet_model
24 | import TouchNet_utils
25 | import TouchNet_model
26 | import VisionNet_utils
27 | from utils import *
28 |
29 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30 |
31 | def config_parser():
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument("--modality", type=str, default="vision,audio,touch")
34 | parser.add_argument("--object_file_path", type=str, default="demo/ObjectFile.pth", help='ObjectFile path')
35 | parser.add_argument("--KiloOSF", action="store_true")
36 |
37 | # VisionNet options
38 | parser.add_argument("--vision_test_file_path", default='data/vision_demo.npy', help='The path of the testing file for vision, which should be a npy file.')
39 | parser.add_argument("--vision_results_dir", type=str, default='./results/vision/', help='The path of the vision results directory to save rendered images.')
40 |
41 | # AudioNet options
42 | parser.add_argument('--audio_vertices_file_path', default='./data/audio_demo_vertices.npy', help='The path of the testing vertices file for audio, which should be a npy file.')
43 | parser.add_argument('--audio_forces_file_path', default='./data/forces.npy', help='The path of forces file for audio, which should be a npy file.')
44 | parser.add_argument("--audio_results_dir", type=str, default='./results/audio/', help='The path of the audio results directory to save impact sounds.')
45 |
46 | # TouchNet options
47 | parser.add_argument('--touch_vertices_file_path', default='./data/touch_demo_vertices.npy', help='The path of the testing vertices file for touch, which should be a npy file.')
48 | parser.add_argument('--touch_gelinfo_file_path', default='./data/touch_demo_gelinfo.npy', help='The path of the gel configurations for touch, which should be a npy file.')
49 | parser.add_argument('--touch_results_dir', type=str, default='./results/touch/', help='The path of the touch results directory to save rendered tactile RGB images.')
50 |
51 | return parser
52 |
53 |
54 | def VisionNet_eval(args):
55 |
56 | checkpoint = torch.load(args.object_file_path)
57 | cfg = checkpoint['VisionNet']['cfg']
58 |
59 | metadata, render_metadata = None, None
60 | background_color = torch.ones(3, dtype=torch.float, device=device)
61 |
62 | poses, hwf, i_split, metadata = load_osf_data(args.vision_test_file_path)
63 | i_test = i_split[0]
64 | render_poses = np.array(poses[i_test])
65 |
66 | # Create dummy metadata if not loaded from dataset.
67 | if metadata is None:
68 | metadata = torch.tensor([[0, 0, 1]] * len(poses), dtype=torch.float) # [N, 3]
69 | if render_metadata is None:
70 | render_metadata = metadata
71 |
72 | # Cast intrinsics to right types
73 | H, W, focal = hwf
74 | H, W = int(H), int(W)
75 | hwf = [H, W, focal]
76 | intrinsics = CameraIntrinsics(int(H), int(W), focal, focal, W * .5, H * .5)
77 |
78 | render_kwargs_train = {
79 | 'perturb' : cfg['perturb'],
80 | 'N_samples' : cfg['num_samples_per_ray'],
81 | 'N_importance' : cfg['num_importance_samples_per_ray'],
82 | 'use_viewdirs': True,
83 | 'use_lightdirs': True,
84 | 'white_bkgd' : cfg['blender_white_background'],
85 | 'raw_noise_std' : cfg['raw_noise_std'],
86 | 'near' : cfg['near'],
87 | 'far' : cfg['far'],
88 | 'metadata': metadata,
89 | 'render_metadata': metadata,
90 | 'random_direction_probability': cfg['random_direction_probability'],
91 | 'von_mises_kappa': cfg['von_mises_kappa'],
92 | 'background_color': background_color,
93 | 'lightdirs_method': 'metadata',
94 | 'cfg': cfg
95 | }
96 |
97 | ConfigManager.init(cfg)
98 |
99 | if args.KiloOSF:
100 | import kilonerf_cuda
101 | from local_distill import create_multi_network_fourier_embedding, has_flag, create_multi_network
102 | kilonerf_cuda.init_stream_pool(16)
103 | kilonerf_cuda.init_magma()
104 |
105 | position_num_input_channels, position_fourier_embedding = create_multi_network_fourier_embedding(1, cfg['num_frequencies'])
106 | direction_num_input_channels, direction_fourier_embedding = create_multi_network_fourier_embedding(1, cfg['num_frequencies_direction'])
107 | light_num_input_channels, light_fourier_embedding = create_multi_network_fourier_embedding(1, cfg['num_frequencies_light'])
108 |
109 | root_nodes = occupancy_grid = None
110 |
111 | res = cfg['fixed_resolution']
112 | network_resolution = torch.tensor(res, dtype=torch.long, device=torch.device('cpu'))
113 | num_networks = res[0] * res[1] * res[2]
114 | model = multi_network = create_multi_network(num_networks, position_num_input_channels, direction_num_input_channels, light_num_input_channels, 4, 'multimatmul_differentiable', cfg).to(device)
115 |
116 | global_domain_min, global_domain_max = ConfigManager.get_global_domain_min_and_max(torch.device('cpu'))
117 | global_domain_size = global_domain_max - global_domain_min
118 | network_voxel_size = global_domain_size / network_resolution
119 |
120 | # Determine bounding boxes (domains) of all networks. Required for global to local coordinate conversion.
121 | domain_mins = []
122 | domain_maxs = []
123 | for coord in itertools.product(*[range(r) for r in res]):
124 | coord = torch.tensor(coord, device=torch.device('cpu'))
125 | domain_min = global_domain_min + network_voxel_size * coord
126 | domain_max = domain_min + network_voxel_size
127 | domain_mins.append(domain_min.tolist())
128 | domain_maxs.append(domain_max.tolist())
129 | domain_mins = torch.tensor(domain_mins, device=device)
130 | domain_maxs = torch.tensor(domain_maxs, device=device)
131 | occupancy_grid = checkpoint['VisionNet']['occupancy_grid']
132 |
133 | additional_kwargs = {
134 | 'root_nodes': root_nodes,
135 | 'position_fourier_embedding': position_fourier_embedding,
136 | 'direction_fourier_embedding': direction_fourier_embedding,
137 | 'light_fourier_embedding': light_fourier_embedding,
138 | 'multi_network': multi_network,
139 | 'domain_mins': domain_mins,
140 | 'domain_maxs': domain_maxs,
141 | 'occupancy_grid': occupancy_grid,
142 | 'debug_network_color_map': None
143 | }
144 | else:
145 | model, embed_fn, embeddirs_fn, embedlights_fn = create_nerf(cfg)
146 | model = model.to(device)
147 | network_query_fn = lambda inputs, viewdirs, lightdirs, network_fn : VisionNet_utils.run_network(inputs, viewdirs, lightdirs, network_fn,
148 | embed_fn=embed_fn,
149 | embeddirs_fn=embeddirs_fn,
150 | embedlights_fn=embedlights_fn,
151 | netchunk=cfg['network_chunk_size'])
152 |
153 | additional_kwargs = {
154 | 'network_query_fn' : network_query_fn,
155 | 'network_fn' : model
156 | }
157 |
158 | render_kwargs_train.update(additional_kwargs)
159 | render_kwargs_train['ndc'] = False
160 | render_kwargs_train['lindisp'] = cfg['llff_lindisp']
161 |
162 | render_kwargs_test = render_kwargs_train.copy()
163 | render_kwargs_test['perturb'] = False
164 | render_kwargs_test['raw_noise_std'] = 0.
165 | render_kwargs_test['random_direction_probability'] = -1
166 | render_kwargs_test['von_mises_kappa'] = -1
167 |
168 | model.load_state_dict(checkpoint['VisionNet']['model_state_dict'])
169 |
170 | # Move testing data to GPU
171 | render_poses = torch.Tensor(render_poses).to(device)
172 |
173 | model.eval()
174 | with torch.no_grad():
175 | images = None
176 | testsavedir = args.vision_results_dir
177 | os.makedirs(testsavedir, exist_ok=True)
178 | rgbs, _ = VisionNet_utils.render_path(render_poses, intrinsics, cfg['chunk_size'], render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=cfg['render_factor'])
179 |
180 | def AudioNet_eval(args):
181 |
182 | checkpoint = torch.load(args.object_file_path)
183 | normalizer_dic = checkpoint['AudioNet']['normalizer']
184 | gains_f1_min = normalizer_dic['f1_min']
185 | gains_f1_max = normalizer_dic['f1_max']
186 | gains_f2_min = normalizer_dic['f2_min']
187 | gains_f2_max = normalizer_dic['f2_max']
188 | gains_f3_min = normalizer_dic['f3_min']
189 | gains_f3_max = normalizer_dic['f3_max']
190 | xyz_min = normalizer_dic['xyz_min']
191 | xyz_max = normalizer_dic['xyz_max']
192 | freqs = checkpoint['AudioNet']['frequencies']
193 | damps = checkpoint['AudioNet']['dampings']
194 |
195 | forces = np.load(args.audio_forces_file_path)
196 |
197 | xyz = np.load(args.audio_vertices_file_path).reshape((-1, 3))
198 | # normalize xyz to [0, 1]
199 | xyz = (xyz - xyz_min) / (xyz_max - xyz_min)
200 |
201 | N = xyz.shape[0]
202 | G = freqs.shape[0]
203 |
204 | embed_fn, input_ch = AudioNet_model.get_embedder(10, 0)
205 | model = AudioNet_model.AudioNeRF(D=8, input_ch=input_ch, output_ch=G)
206 | state_dic = checkpoint['AudioNet']["model_state_dict"]
207 | state_dic = AudioNet_utils.strip_prefix_if_present(state_dic, 'module.')
208 | model.load_state_dict(state_dic)
209 | model = nn.DataParallel(model).to(device)
210 | model.eval()
211 |
212 | preds_gain_x = torch.zeros((N, G)).to(device)
213 | preds_gain_y = torch.zeros((N, G)).to(device)
214 | preds_gain_z = torch.zeros((N, G)).to(device)
215 |
216 | batch_size = 1024
217 |
218 | for i in trange(N // batch_size + 1):
219 | curr_x = torch.Tensor(xyz[i*batch_size:(i+1)*batch_size]).to(device)
220 | curr_y = torch.Tensor(xyz[i*batch_size:(i+1)*batch_size]).to(device)
221 | curr_z = torch.Tensor(xyz[i*batch_size:(i+1)*batch_size]).to(device)
222 | embedded_x = embed_fn(curr_x)
223 | embedded_y = embed_fn(curr_y)
224 | embedded_z = embed_fn(curr_z)
225 | results_x, results_y, results_z = model(embedded_x, embedded_y, embedded_z)
226 |
227 | preds_gain_x[i*batch_size:(i+1)*batch_size] = results_x
228 | preds_gain_y[i*batch_size:(i+1)*batch_size] = results_y
229 | preds_gain_z[i*batch_size:(i+1)*batch_size] = results_z
230 |
231 | preds_gain_x = preds_gain_x * (gains_f1_max - gains_f1_min) + gains_f1_min
232 | preds_gain_y = preds_gain_y * (gains_f2_max - gains_f2_min) + gains_f2_min
233 | preds_gain_z = preds_gain_z * (gains_f3_max - gains_f3_min) + gains_f3_min
234 | preds_gain = torch.cat((preds_gain_x[:, None, :], preds_gain_y[:, None, :], preds_gain_z[:, None, :]), 1)
235 |
236 | freqs = torch.Tensor(freqs).to(device)
237 | damps = torch.Tensor(damps).to(device)
238 |
239 | testsavedir = args.audio_results_dir
240 | os.makedirs(testsavedir, exist_ok=True)
241 |
242 | for i in trange(N):
243 | preds_gain_x_i = preds_gain[i, 0, :]
244 | preds_gain_y_i = preds_gain[i, 1, :]
245 | preds_gain_z_i = preds_gain[i, 2, :]
246 | force_x, force_y, force_z = forces[i]
247 | combined_preds_gain = force_x * preds_gain_x_i + force_y * preds_gain_y_i + force_z * preds_gain_z_i
248 | combined_preds_gain = combined_preds_gain.unsqueeze(0)
249 | modal_fir = torch.unsqueeze(ddsp.get_modal_fir(combined_preds_gain, freqs, damps), dim=1)
250 | impulse = torch.reshape(torch.Tensor(scipy.signal.unit_impulse(44100*3)).to(device), (1, -1)).repeat(modal_fir.shape[0], 1)
251 | result = ddsp.fft_convolve(impulse, modal_fir)
252 | signal = result[0, :].detach().cpu().numpy()
253 | signal = signal / np.abs(signal).max()
254 | # write wav file
255 | output_path = os.path.join(testsavedir, str(i+1) + '.wav')
256 | write(output_path, 44100, signal.astype(np.float32))
257 |
258 |
259 | def TouchNet_eval(args):
260 |
261 | checkpoint = torch.load(args.object_file_path)
262 |
263 | rotation_max = 15
264 | depth_max = 0.04
265 | depth_min = 0.0339
266 | displacement_min = 0.0005
267 | displacement_max = 0.0020
268 | depth_max = 0.04
269 | depth_min = 0.0339
270 | rgb_width = 120
271 | rgb_height = 160
272 | network_depth = 8
273 |
274 | #TODO load object...
275 | vertex_min = checkpoint['TouchNet']['xyz_min']
276 | vertex_max = checkpoint['TouchNet']['xyz_max']
277 |
278 | vertex_coordinates = np.load(args.touch_vertices_file_path)
279 | N = vertex_coordinates.shape[0]
280 | gelinfo_data = np.load(args.touch_gelinfo_file_path)
281 | theta, phi, displacement = gelinfo_data[:, 0], gelinfo_data[:, 1], gelinfo_data[:, 2]
282 | phi_x = np.cos(phi)
283 | phi_y = np.sin(phi)
284 |
285 | # normalize theta to [0, 1]
286 | theta = (theta - np.radians(0)) / (np.radians(rotation_max) - np.radians(0))
287 |
288 | #normalize displacement to [0, 1]
289 | displacement_norm = (displacement - displacement_min) / (displacement_max - displacement_min)
290 |
291 | #normalize coordinates to [0, 1]
292 | vertex_coordinates = (vertex_coordinates - vertex_min) / (vertex_max - vertex_min)
293 |
294 | #initialize horizontal and vertical features
295 | w_feats = np.repeat(np.repeat(np.arange(rgb_width).reshape((rgb_width, 1)), rgb_height, axis=1).reshape((1, 1, rgb_width, rgb_height)), N, axis=0)
296 | h_feats = np.repeat(np.repeat(np.arange(rgb_height).reshape((1, rgb_height)), rgb_width, axis=0).reshape((1, 1, rgb_width, rgb_height)), N, axis=0)
297 | #normalize horizontal and vertical features to [0, 1]
298 | w_feats_min = w_feats.min()
299 | w_feats_max = w_feats.max()
300 | h_feats_min = h_feats.min()
301 | h_feats_max = h_feats.max()
302 | w_feats = (w_feats - w_feats_min) / (w_feats_max - w_feats_min)
303 | h_feats = (h_feats - h_feats_min) / (h_feats_max - h_feats_min)
304 | w_feats = torch.FloatTensor(w_feats)
305 | h_feats = torch.FloatTensor(h_feats)
306 |
307 | theta = np.repeat(theta.reshape((N, 1, 1)), rgb_width * rgb_height, axis=1)
308 | phi_x = np.repeat(phi_x.reshape((N, 1, 1)), rgb_width * rgb_height, axis=1)
309 | phi_y = np.repeat(phi_y.reshape((N, 1, 1)), rgb_width * rgb_height, axis=1)
310 | displacement_norm = np.repeat(displacement_norm.reshape((N, 1, 1)), rgb_width * rgb_height, axis=1)
311 | vertex_coordinates = np.repeat(vertex_coordinates.reshape((N, 1, 3)), rgb_width * rgb_height, axis=1)
312 |
313 | data_wh = np.concatenate((w_feats, h_feats), axis=1)
314 | data_wh = np.transpose(data_wh.reshape((N, 2, -1)), axes=[0, 2, 1])
315 | #Now get final feats matrix as [x, y, z, theta, phi_x, phi_y, displacement, w, h]
316 | data = np.concatenate((vertex_coordinates, theta, phi_x, phi_y, displacement_norm, data_wh), axis=2).reshape((-1, 9))
317 |
318 | #checkpoint = torch.load(args.object_file_path)
319 | embed_fn, input_ch = TouchNet_model.get_embedder(10, 0)
320 | model = TouchNet_model.NeRF(D = network_depth, input_ch = input_ch, output_ch = 1)
321 | state_dic = checkpoint['TouchNet']['model_state_dict']
322 | state_dic = TouchNet_utils.strip_prefix_if_present(state_dic, 'module.')
323 | model.load_state_dict(state_dic)
324 | model = nn.DataParallel(model).to(device)
325 | model.eval()
326 |
327 | preds = np.empty((data.shape[0], 1))
328 |
329 | batch_size = 1024
330 |
331 | testsavedir = args.touch_results_dir
332 | os.makedirs(testsavedir, exist_ok=True)
333 |
334 | for i in trange(data.shape[0] // batch_size + 1):
335 | inputs = torch.Tensor(data[i*batch_size:(i+1)*batch_size]).to(device)
336 | embedded = embed_fn(inputs)
337 | results = model(embedded)
338 | preds[i*batch_size:(i+1)*batch_size, :] = results.detach().cpu().numpy()
339 |
340 | preds = preds * (depth_max - depth_min) + depth_min
341 | preds = np.transpose(preds.reshape((N, -1, 1)), axes = [0, 2, 1]).reshape((N, rgb_width, rgb_height))
342 | taxim = TaximRender("./calibs/")
343 | for i in trange(N):
344 | height_map, contact_map, tactile_map = taxim.render(preds[i], displacement[i])
345 | tactile_map = Image.fromarray(tactile_map.astype(np.uint8), 'RGB')
346 | filename = os.path.join(testsavedir, '{}.png'.format(i+1))
347 | tactile_map.save(filename)
348 |
349 | if __name__ =='__main__':
350 | parser = config_parser()
351 | args = parser.parse_args()
352 | modalities = args.modality.strip().split(",")
353 |
354 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
355 |
356 | if "vision" in modalities:
357 | VisionNet_eval(args=args)
358 | if "audio" in modalities:
359 | AudioNet_eval(args=args)
360 | if "touch" in modalities:
361 | TouchNet_eval(args=args)
362 |
--------------------------------------------------------------------------------
/ObjectFolder1.0/AudioNet_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | torch.autograd.set_detect_anomaly(True)
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 | from AudioNet_utils import *
7 |
8 | class DenseLayer(nn.Linear):
9 | def __init__(self, in_dim: int, out_dim: int, activation: str = 'relu', *args, **kwargs) -> None:
10 | self.activation = activation
11 | super().__init__(in_dim, out_dim, *args, **kwargs)
12 |
13 | def reset_parameters(self) -> None:
14 | torch.nn.init.xavier_uniform_(self.weight, gain=torch.nn.init.calculate_gain(self.activation))
15 | if self.bias is not None:
16 | torch.nn.init.zeros_(self.bias)
17 |
18 | class Embedder:
19 | def __init__(self, **kwargs):
20 | self.kwargs = kwargs
21 | self.create_embedding_fn()
22 |
23 | def create_embedding_fn(self):
24 | embed_fns = []
25 | d = self.kwargs['input_dims']
26 | out_dim = 0
27 | if self.kwargs['include_input']:
28 | embed_fns.append(lambda x: x)
29 | out_dim += d
30 |
31 | max_freq = self.kwargs['max_freq_log2']
32 | N_freqs = self.kwargs['num_freqs']
33 |
34 | if self.kwargs['log_sampling']:
35 | freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
36 | else:
37 | freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
38 |
39 | for freq in freq_bands:
40 | for p_fn in self.kwargs['periodic_fns']:
41 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
42 | out_dim += d
43 |
44 | self.embed_fns = embed_fns
45 | self.out_dim = out_dim
46 |
47 | def embed(self, inputs):
48 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
49 |
50 | def get_embedder(multires, i=0):
51 | if i == -1:
52 | #5 is for x, y, z, f, t
53 | return nn.Identity(), 5
54 |
55 | embed_kwargs = {
56 | 'include_input': True,
57 | 'input_dims': 5,
58 | 'max_freq_log2': multires-1,
59 | 'num_freqs': multires,
60 | 'log_sampling': True,
61 | 'periodic_fns': [torch.sin, torch.cos],
62 | }
63 |
64 | embedder_obj = Embedder(**embed_kwargs)
65 | embed = lambda x, eo=embedder_obj: eo.embed(x)
66 | return embed, embedder_obj.out_dim
67 |
68 | class AudioNeRF(nn.Module):
69 | def __init__(self, D=8, input_ch=5):
70 | super(AudioNeRF, self).__init__()
71 | self.model_x = NeRF(D = D, input_ch = input_ch)
72 | self.model_y = NeRF(D = D, input_ch = input_ch)
73 | self.model_z = NeRF(D = D, input_ch = input_ch)
74 |
75 | def forward(self, embedded_x, embedded_y, embedded_z):
76 | results_x = self.model_x(embedded_x)
77 | results_y = self.model_y(embedded_y)
78 | results_z = self.model_z(embedded_z)
79 | return results_x, results_y, results_z
80 |
81 |
82 | class NeRF(nn.Module):
83 | def __init__(self, D=8, W=256, input_ch=5, input_ch_views=0, output_ch=2, skips=[4], use_viewdirs=False):
84 | """NeRF-style MLP: D fully-connected layers with skip connections that map
85 | positionally-encoded inputs to the output channels."""
86 | super(NeRF, self).__init__()
87 | self.D = D
88 | self.W = W
89 | self.input_ch = input_ch
90 | self.input_ch_views = input_ch_views
91 | self.skips = skips
92 | self.use_viewdirs = use_viewdirs
93 |
94 | self.pts_linears = nn.ModuleList(
95 | [DenseLayer(input_ch, W, activation='relu')] + [DenseLayer(W, W, activation='relu') if i not in self.skips else DenseLayer(W + input_ch, W, activation='relu') for i in range(D-1)])
96 |
97 | self.views_linears = nn.ModuleList([DenseLayer(input_ch_views + W, W//2, activation='relu')])
98 |
99 | if use_viewdirs:
100 | self.feature_linear = DenseLayer(W, W, activation='sigmoid')
101 | #self.alpha_linear = DenseLayer(W, 1, activation='linear')
102 | self.rgb_linear = DenseLayer(W//2, output_ch, activation='sigmoid')
103 | else:
104 | self.output_linear = DenseLayer(W, output_ch, activation='sigmoid')
105 |
106 |
107 | def forward(self, x):
108 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
109 | h = input_pts
110 | for i, l in enumerate(self.pts_linears):
111 | h = self.pts_linears[i](h)
112 | h = F.relu(h)
113 | if i in self.skips:
114 | h = torch.cat([input_pts, h], -1)
115 |
116 | if self.use_viewdirs:
117 | feature = self.feature_linear(h)
118 | h = torch.cat([feature, input_views], -1)
119 |
120 | for i, l in enumerate(self.views_linears):
121 | h = self.views_linears[i](h)
122 | h = F.relu(h)
123 |
124 | outputs = self.rgb_linear(h)
125 | else:
126 | outputs = self.output_linear(h)
127 |
128 | return outputs
129 |
--------------------------------------------------------------------------------
/ObjectFolder1.0/AudioNet_utils.py:
--------------------------------------------------------------------------------
1 | from scipy.io import wavfile
2 | import librosa
3 | import librosa.display
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | from AudioNet_model import *
7 | import os, shutil
8 | from collections import OrderedDict
9 |
10 | def transform_mesh_collision_binvox(coordinates, translate, scale):
11 | for i in range(3):
12 | coordinates[i] = (coordinates[i] + translate[i]) * scale
13 | return coordinates
14 |
15 | def strip_prefix_if_present(state_dict, prefix):
16 | keys = sorted(state_dict.keys())
17 | if not all(key.startswith(prefix) for key in keys):
18 | return state_dict
19 | stripped_state_dict = OrderedDict()
20 | for key, value in state_dict.items():
21 | stripped_state_dict[key.replace(prefix, "")] = value
22 | return stripped_state_dict
23 |
24 | def mkdirs(path, remove=False):
25 | if os.path.isdir(path):
26 | if remove:
27 | shutil.rmtree(path)
28 | else:
29 | return
30 | os.makedirs(path)
31 |
32 | def generate_spectrogram_magphase(audio, stft_frame, stft_hop, n_fft, with_phase=False):
33 | spectro = librosa.core.stft(audio, hop_length=stft_hop, n_fft=n_fft, win_length=stft_frame, center=True)
34 | spectro_mag, spectro_phase = librosa.core.magphase(spectro)
35 | spectro_mag = np.expand_dims(spectro_mag, axis=0)
36 | if with_phase:
37 | spectro_phase = np.expand_dims(np.angle(spectro_phase), axis=0)
38 | return spectro_mag, spectro_phase
39 | else:
40 | return spectro_mag
41 |
42 | def generate_spectrogram_complex(audio, stft_frame, stft_hop, n_fft):
43 | spectro = librosa.core.stft(audio, hop_length=stft_hop, n_fft=n_fft, win_length=stft_frame, center=True)
44 | real = np.expand_dims(np.real(spectro), axis=0)
45 | imag = np.expand_dims(np.imag(spectro), axis=0)
46 | spectro_two_channel = np.concatenate((real, imag), axis=0)
47 | return spectro_two_channel
48 |
49 | def batchify(fn, chunk):
50 | """
51 | Constructs a version of 'fn' that applies to smaller batches
52 | """
53 | if chunk is None:
54 | return fn
55 | def ret(inputs):
56 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
57 | return ret
58 |
59 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
60 | """
61 | Prepares inputs and applies network 'fn'.
62 | """
63 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
64 | embedded = embed_fn(inputs_flat)
65 |
66 | if viewdirs is not None:
67 | input_dirs = viewdirs[:,None].expand(inputs.shape)
68 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
69 | embedded_dirs = embeddirs_fn(input_dirs_flat)
70 | embedded = torch.cat([embedded, embedded_dirs], -1)
71 |
72 | outputs_flat = batchify(fn, netchunk)(embedded)
73 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
74 | return outputs
75 |
76 | """
77 | def get_embedder(multires, i=0):
78 | if i == -1:
79 | #5 is for x, y, z, f, t
80 | return nn.Identity(), 5
81 |
82 | embed_kwargs = {
83 | 'include_input': True,
84 | 'input_dims': 5,
85 | 'max_freq_log2': multires-1,
86 | 'num_freqs': multires,
87 | 'log_sampling': True,
88 | 'periodic_fns': [torch.sin, torch.cos],
89 | }
90 |
91 | embedder_obj = Embedder(**embed_kwargs)
92 | embed = lambda x, eo=embedder_obj: eo.embed(x)
93 | return embed, embedder_obj.out_dim
94 | """
95 |
96 | def create_nerf(args):
97 | """
98 | Instantiate NeRF's MLP model.
99 | """
100 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed)
101 |
102 | input_ch_views = 0
103 | embeddirs_fn = None
104 | if args.use_viewdirs:
105 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)
106 | output_ch = 2
107 | skips = [4]
108 | model = NeRF(D=args.netdepth, W=args.netwidth,
109 | input_ch=input_ch, output_ch=output_ch, skips=skips,
110 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs)
111 | model = nn.DataParallel(model).to(device)
112 | grad_vars = list(model.parameters())
113 |
--------------------------------------------------------------------------------
/ObjectFolder1.0/ObjectFolder_teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/ObjectFolder1.0/ObjectFolder_teaser.png
--------------------------------------------------------------------------------
/ObjectFolder1.0/README.md:
--------------------------------------------------------------------------------
1 | ## ObjectFolder: A Dataset of Objects with Implicit Visual, Auditory, and Tactile Representations (CoRL 2021)
2 | [[Project Page]](https://ai.stanford.edu/~rhgao/objectfolder/) [[arXiv]](https://arxiv.org/abs/2109.07991)
3 |
4 |
5 |
6 |
7 |
8 | [ObjectFolder: A Dataset of Objects with Implicit Visual, Auditory, and Tactile Representations](https://arxiv.org/abs/2109.07991)
9 | [Ruohan Gao](https://www.ai.stanford.edu/~rhgao/), [Yen-Yu Chang](https://yuyuchang.github.io/), [Shivani Mall](), [Li Fei-Fei](https://profiles.stanford.edu/fei-fei-li), [Jiajun Wu](https://jiajunwu.com/)
10 | Stanford University
11 | In Conference on Robot Learning (**CoRL**), 2021
12 |
13 |
14 |
15 | If you find our code or project useful in your research, please cite:
16 |
17 | @inproceedings{gao2021ObjectFolder,
18 | title = {ObjectFolder: A Dataset of Objects with Implicit Visual, Auditory, and Tactile Representations},
19 | author = {Gao, Ruohan and Chang, Yen-Yu and Mall, Shivani and Fei-Fei, Li and Wu, Jiajun},
20 | booktitle = {CoRL},
21 | year = {2021}
22 | }
23 |
24 | ### About ObjectFolder Dataset
25 |
26 |
27 |
28 | ObjectFolder is a dataset of 100 objects in the form of implicit representations. It contains 100 Object Files, each of which stores the complete multisensory profile of an object instance. Each Object File is an implicit neural representation network composed of three sub-networks---VisionNet, AudioNet, and TouchNet---which can be queried with the corresponding extrinsic parameters to obtain the visual appearance of the object from different viewpoints, its impact sound at each surface position, and its tactile reading at every surface location, respectively. The dataset contains common household objects of diverse categories such as bowl, mug, cabinet, television, shelf, fork, and spoon. See the paper for details.
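
For a quick sanity check of a downloaded Object File, the sketch below loads it as an ordinary PyTorch checkpoint and lists its sub-network entries. This is a minimal sketch that assumes the top-level `VisionNet`/`AudioNet`/`TouchNet` layout used by `OF_render.py` at the repository root; the exact keys stored in an ObjectFolder 1.0 `ObjectFile.pth` may differ.

```
# Minimal sketch, assuming the Object File is a regular PyTorch checkpoint with
# top-level 'VisionNet' / 'AudioNet' / 'TouchNet' entries (as in OF_render.py at
# the repository root); the exact keys in a 1.0 ObjectFile.pth may differ.
import torch

checkpoint = torch.load("Objects/25/ObjectFile.pth", map_location="cpu")
for name in ("VisionNet", "AudioNet", "TouchNet"):
    if name in checkpoint:
        print(name, "->", sorted(checkpoint[name].keys()))
```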
29 |
30 |
31 |
32 | ### Dataset Download and Preparation
33 | ```
34 | git clone https://github.com/rhgao/ObjectFolder.git
35 | cd ObjectFolder/ObjectFolder1.0
36 | wget https://download.cs.stanford.edu/viscam/ObjectFolder/ObjectFolder1.0.tar.gz
37 | tar -xvf ObjectFolder1.0.tar.gz
38 | pip install -r requirements.txt
39 | ```
40 |
41 | ### Rendering visual, auditory, and tactile sensory data
42 | Run the following command to render images, impact sounds, and tactile RGB images:
43 | ```
44 | $ python evaluate.py --object_file_path path_of_ObjectFile \
45 | --vision_test_file_path path_of_vision_test_file \
46 | --vision_results_dir path_of_vision_results_directory \
47 | --audio_vertices_file_path path_of_audio_testing_vertices_file \
48 | --audio_forces_file_path path_of_forces_file \
49 | --audio_results_dir path_of_audio_results_directory \
50 | --touch_vertices_file_path path_of_touch_testing_vertices_file \
51 | --touch_results_dir path_of_touch_results_directory
52 | ```
53 | This code can be run with the following command-line arguments:
54 | * `--object_file_path`: The path of ObjectFile.
55 | * `--vision_test_file_path`: The path of the testing file for vision, which should be a npy file.
56 | * `--vision_results_dir`: The path of the vision results directory to save rendered images.
57 | * `--audio_vertices_file_path`: The path of the testing vertices file for audio, which should be a npy file.
58 | * `--audio_forces_file_path`: The path of forces file for audio, which should be a npy file.
59 | * `--audio_results_dir`: The path of audio results directory to save rendered impact sounds as .wav files.
60 | * `--touch_vertices_file_path`: The path of the testing vertices file for touch, which should be a npy file.
61 | * `--touch_results_dir`: The path of the touch results directory to save rendered tactile RGB images.
62 |
63 | ### Data format
64 | * `--vision_test_file_path`: An npy file of shape (N, 6), where N is the number of testing viewpoints. Each data point contains the coordinates of the camera and the light in the form of (camera_x, camera_y, camera_z, light_x, light_y, light_z).
65 | * `--audio_vertices_file_path`: An npy file of shape (N, 3), where N is the number of testing vertices. Each data point represents a coordinate on the object in the form of (x, y, z).
66 | * `--audio_forces_file_path`: An npy file of shape (N, 3), where N is the number of testing vertices. Each data point represents the force values for the corresponding impact in the form of (F_x, F_y, F_z).
67 | * `--touch_vertices_file_path`: An npy file of shape (N, 3), where N is the number of testing vertices. Each data point contains a coordinate on the object in the form of (x, y, z).
68 |
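For reference, the sketch below writes test files with the shapes described above. The file names are placeholders and the random values only illustrate the expected shapes; real vertex coordinates must lie on the object's surface.

```
# Placeholder inputs that only illustrate the expected shapes; real vertex
# coordinates must be sampled from the object surface.
import numpy as np

N = 4  # number of test viewpoints / vertices
np.save("my_vision_test.npy", np.random.rand(N, 6).astype(np.float32))    # (camera_x, camera_y, camera_z, light_x, light_y, light_z)
np.save("my_audio_vertices.npy", np.random.rand(N, 3).astype(np.float32)) # (x, y, z)
np.save("my_audio_forces.npy", np.random.rand(N, 3).astype(np.float32))   # (F_x, F_y, F_z)
np.save("my_touch_vertices.npy", np.random.rand(N, 3).astype(np.float32)) # (x, y, z)
```
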
69 | ### Demo
70 | Below we show an example of rendering the visual, auditory, and tactile data from the ObjectFile implicit representation for one object:
71 | ```
72 | $ python evaluate.py --object_file_path Objects/25/ObjectFile.pth \
73 | --vision_test_file_path demo/vision_demo.npy \
74 | --vision_results_dir demo/vision_results/ \
75 | --audio_vertices_file_path demo/audio_demo_vertices.npy \
76 | --audio_forces_file_path demo/audio_demo_forces.npy \
77 | --audio_results_dir demo/audio_results/ \
78 | --touch_vertices_file_path demo/touch_demo_vertices.npy \
79 | --touch_results_dir demo/touch_results/
80 | ```
81 |
82 | The rendered images, impact sounds, and tactile images will be saved in `demo/vision_results/`, `demo/audio_results/`, and `demo/touch_results/`, respectively.
83 |
84 | ### Acknowledgements
85 | The code for the neural implicit representation network is adapted from Yen-Chen Lin's [PyTorch implementation](https://github.com/yenchenlin/nerf-pytorch) of [NeRF](https://www.matthewtancik.com/nerf) and Michelle Guo's TensorFlow implementation of [OSF](https://www.shellguo.com/osf/).
86 |
87 | ### License
88 | ObjectFolder is CC BY 4.0 licensed, as found in the LICENSE file. The 100 high-quality 3D objects originally come from online repositories including: 20 objects from [3D Model Haven](https://3dmodelhaven.com/), 28 objects from the [YCB dataset](http://ycb-benchmarks.s3-website-us-east-1.amazonaws.com/), and 52 objects from [Google Scanned Objects](https://app.ignitionrobotics.org/GoogleResearch/fuel/collections/Google%20Scanned%20Objects). Please also refer to their original license files.
89 |
--------------------------------------------------------------------------------
/ObjectFolder1.0/TouchNet_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | torch.autograd.set_detect_anomaly(True)
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 | from TouchNet_utils import *
7 |
8 | class DenseLayer(nn.Linear):
9 | def __init__(self, in_dim: int, out_dim: int, activation: str = 'relu', *args, **kwargs) -> None:
10 | self.activation = activation
11 | super().__init__(in_dim, out_dim, *args, **kwargs)
12 |
13 | def reset_parameters(self) -> None:
14 | torch.nn.init.xavier_uniform_(self.weight, gain=torch.nn.init.calculate_gain(self.activation))
15 | if self.bias is not None:
16 | torch.nn.init.zeros_(self.bias)
17 |
18 | class Embedder:
19 | def __init__(self, **kwargs):
20 | self.kwargs = kwargs
21 | self.create_embedding_fn()
22 |
23 | def create_embedding_fn(self):
24 | embed_fns = []
25 | d = self.kwargs['input_dims']
26 | out_dim = 0
27 | if self.kwargs['include_input']:
28 | embed_fns.append(lambda x: x)
29 | out_dim += d
30 |
31 | max_freq = self.kwargs['max_freq_log2']
32 | N_freqs = self.kwargs['num_freqs']
33 |
34 | if self.kwargs['log_sampling']:
35 | freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
36 | else:
37 | freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
38 |
39 | for freq in freq_bands:
40 | for p_fn in self.kwargs['periodic_fns']:
41 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
42 | out_dim += d
43 |
44 | self.embed_fns = embed_fns
45 | self.out_dim = out_dim
46 |
47 | def embed(self, inputs):
48 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
49 |
50 | def get_embedder(multires, i=0):
51 | if i == -1:
52 | #5 is for x, y, z, r, g, b
53 | return nn.Identity(), 5
54 |
55 | embed_kwargs = {
56 | 'include_input': True,
57 | 'input_dims': 5,
58 | 'max_freq_log2': multires-1,
59 | 'num_freqs': multires,
60 | 'log_sampling': True,
61 | 'periodic_fns': [torch.sin, torch.cos],
62 | }
63 |
64 | embedder_obj = Embedder(**embed_kwargs)
65 | embed = lambda x, eo=embedder_obj: eo.embed(x)
66 | return embed, embedder_obj.out_dim
67 |
68 | class AudioNeRF(nn.Module):
69 | def __init__(self, D=8, input_ch=5):
70 | super(AudioNeRF, self).__init__()
71 | self.model_x = NeRF(D = D, input_ch = input_ch)
72 | self.model_y = NeRF(D = D, input_ch = input_ch)
73 | self.model_z = NeRF(D = D, input_ch = input_ch)
74 |
75 | def forward(self, embedded_x, embedded_y, embedded_z):
76 | results_x = self.model_x(embedded_x)
77 | results_y = self.model_y(embedded_y)
78 | results_z = self.model_z(embedded_z)
79 | return results_x, results_y, results_z
80 |
81 |
82 | class NeRF(nn.Module):
83 | def __init__(self, D=8, W=256, input_ch=5, input_ch_views=0, output_ch=2, skips=[4], use_viewdirs=False):
84 | """NeRF-style MLP: D fully-connected layers with skip connections that map
85 | positionally-encoded inputs to the output channels."""
86 | super(NeRF, self).__init__()
87 | self.D = D
88 | self.W = W
89 | self.input_ch = input_ch
90 | self.input_ch_views = input_ch_views
91 | self.skips = skips
92 | self.use_viewdirs = use_viewdirs
93 |
94 | self.pts_linears = nn.ModuleList(
95 | [DenseLayer(input_ch, W, activation='relu')] + [DenseLayer(W, W, activation='relu') if i not in self.skips else DenseLayer(W + input_ch, W, activation='relu') for i in range(D-1)])
96 |
97 | self.views_linears = nn.ModuleList([DenseLayer(input_ch_views + W, W//2, activation='relu')])
98 |
99 | if use_viewdirs:
100 | self.feature_linear = DenseLayer(W, W, activation='sigmoid')
101 | #self.alpha_linear = DenseLayer(W, 1, activation='linear')
102 | self.rgb_linear = DenseLayer(W//2, output_ch, activation='sigmoid')
103 | else:
104 | self.output_linear = DenseLayer(W, output_ch, activation='sigmoid')
105 |
106 |
107 | def forward(self, x):
108 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
109 | h = input_pts
110 | for i, l in enumerate(self.pts_linears):
111 | h = self.pts_linears[i](h)
112 | h = F.relu(h)
113 | if i in self.skips:
114 | h = torch.cat([input_pts, h], -1)
115 |
116 | if self.use_viewdirs:
117 | feature = self.feature_linear(h)
118 | h = torch.cat([feature, input_views], -1)
119 |
120 | for i, l in enumerate(self.views_linears):
121 | h = self.views_linears[i](h)
122 | h = F.relu(h)
123 |
124 | outputs = self.rgb_linear(h)
125 | else:
126 | outputs = self.output_linear(h)
127 |
128 | return outputs
129 |
--------------------------------------------------------------------------------
/ObjectFolder1.0/TouchNet_utils.py:
--------------------------------------------------------------------------------
1 | from scipy.io import wavfile
2 | import librosa
3 | import librosa.display
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | from TouchNet_model import *
7 | import os, shutil
8 | from collections import OrderedDict
9 |
10 | def strip_prefix_if_present(state_dict, prefix):
11 | keys = sorted(state_dict.keys())
12 | if not all(key.startswith(prefix) for key in keys):
13 | return state_dict
14 | stripped_state_dict = OrderedDict()
15 | for key, value in state_dict.items():
16 | stripped_state_dict[key.replace(prefix, "")] = value
17 | return stripped_state_dict
18 |
19 | def mkdirs(path, remove=False):
20 | if os.path.isdir(path):
21 | if remove:
22 | shutil.rmtree(path)
23 | else:
24 | return
25 | os.makedirs(path)
26 |
27 | def generate_spectrogram_magphase(audio, stft_frame, stft_hop, n_fft, with_phase=False):
28 | spectro = librosa.core.stft(audio, hop_length=stft_hop, n_fft=n_fft, win_length=stft_frame, center=True)
29 | spectro_mag, spectro_phase = librosa.core.magphase(spectro)
30 | spectro_mag = np.expand_dims(spectro_mag, axis=0)
31 | if with_phase:
32 | spectro_phase = np.expand_dims(np.angle(spectro_phase), axis=0)
33 | return spectro_mag, spectro_phase
34 | else:
35 | return spectro_mag
36 |
37 | def generate_spectrogram_complex(audio, stft_frame, stft_hop, n_fft):
38 | spectro = librosa.core.stft(audio, hop_length=stft_hop, n_fft=n_fft, win_length=stft_frame, center=True)
39 | real = np.expand_dims(np.real(spectro), axis=0)
40 | imag = np.expand_dims(np.imag(spectro), axis=0)
41 | spectro_two_channel = np.concatenate((real, imag), axis=0)
42 | return spectro_two_channel
43 |
44 | def batchify(fn, chunk):
45 | """
46 | Constructs a version of 'fn' that applies to smaller batches
47 | """
48 | if chunk is None:
49 | return fn
50 | def ret(inputs):
51 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
52 | return ret
53 |
54 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
55 | """
56 | Prepares inputs and applies network 'fn'.
57 | """
58 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
59 | embedded = embed_fn(inputs_flat)
60 |
61 | if viewdirs is not None:
62 | input_dirs = viewdirs[:,None].expand(inputs.shape)
63 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
64 | embedded_dirs = embeddirs_fn(input_dirs_flat)
65 | embedded = torch.cat([embedded, embedded_dirs], -1)
66 |
67 | outputs_flat = batchify(fn, netchunk)(embedded)
68 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
69 | return outputs
70 |
71 | """
72 | def get_embedder(multires, i=0):
73 | if i == -1:
74 | #5 is for x, y, z, f, t
75 | return nn.Identity(), 5
76 |
77 | embed_kwargs = {
78 | 'include_input': True,
79 | 'input_dims': 5,
80 | 'max_freq_log2': multires-1,
81 | 'num_freqs': multires,
82 | 'log_sampling': True,
83 | 'periodic_fns': [torch.sin, torch.cos],
84 | }
85 |
86 | embedder_obj = Embedder(**embed_kwargs)
87 | embed = lambda x, eo=embedder_obj: eo.embed(x)
88 | return embed, embedder_obj.out_dim
89 | """
90 |
91 | def create_nerf(args):
92 | """
93 | Instantiate NeRF's MLP model.
94 | """
95 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed)
96 |
97 | input_ch_views = 0
98 | embeddirs_fn = None
99 | if args.use_viewdirs:
100 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)
101 | output_ch = 2
102 | skips = [4]
103 | model = NeRF(D=args.netdepth, W=args.netwidth,
104 | input_ch=input_ch, output_ch=output_ch, skips=skips,
105 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs)
106 | model = nn.DataParallel(model).to(device)
107 | grad_vars = list(model.parameters())
108 |
--------------------------------------------------------------------------------
/ObjectFolder1.0/box_utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions for bounding box computation."""
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import ray_utils
6 |
7 | def ray_to_box_coordinate_frame_pairwise(box_center, box_rotation_matrix,
8 | rays_start_point, rays_end_point):
9 | """Moves a set of rays into a box's coordinate frame.
10 |
11 | Args:
12 | box_center: A tensor of size [3] or [r, 3].
13 | box_rotation_matrix: A tensor of size [3, 3] or [r, 3, 3].
14 | rays_start_point: A tensor of size [r, 3] where r is the number of rays.
15 | rays_end_point: A tensor of size [r, 3] where r is the number of rays.
16 |
17 | Returns:
18 | rays_start_point_in_box_frame: A tensor of size [r, 3].
19 | rays_end_point_in_box_frame: A tensor of size [r, 3].
20 | """
21 | r = rays_start_point.size()[0]
22 | box_center = torch.broadcast_to(box_center, (r, 3))
23 | box_rotation_matrix = torch.broadcast_to(box_rotation_matrix, (r, 3, 3))
24 | rays_start_point_in_box_frame = torch.matmul(
25 | (rays_start_point - box_center).unsqueeze(1),
26 | box_rotation_matrix)
27 | rays_end_point_in_box_frame = torch.matmul(
28 | (rays_end_point - box_center).unsqueeze(1),
29 | box_rotation_matrix)
30 | return (rays_start_point_in_box_frame.view(-1, 3),
31 | rays_end_point_in_box_frame.view(-1, 3))
32 |
33 |
34 | def ray_box_intersection_pairwise(box_center,
35 | box_rotation_matrix,
36 | box_length,
37 | box_width,
38 | box_height,
39 | rays_start_point,
40 | rays_end_point,
41 | exclude_negative_t=False,
42 | exclude_enlarged_t=True,
43 | epsilon=0.000001):
44 | """Intersects a set of rays with a box.
45 |
46 | Note: The intersection points are returned in the box coordinate frame.
47 | Note: Make sure the start and end point of the rays are not the same.
48 | Note: Even though a start and end point is passed for each ray, rays are
49 | never ending and can intersect a box beyond their start / end points.
50 |
51 | Args:
52 | box_center: A tensor of size [3] or [r, 3].
53 | box_rotation_matrix: A tensor of size [3, 3] or [r, 3, 3].
54 | box_length: A scalar tensor or of size [r].
55 | box_width: A scalar tensor or of size [r].
56 | box_height: A scalar tensor or of size [r].
57 | rays_start_point: A tensor of size [r, 3] where r is the number of rays.
58 | rays_end_point: A tensor of size [r, 3] where r is the number of rays.
59 | exclude_negative_t: bool.
60 | exclude_enlarged_t: bool.
61 | epsilon: A very small number.
62 |
63 | Returns:
64 | intersection_points_in_box_frame: A tensor of size [r', 2, 3]
65 | that contains intersection points in box coordinate frame.
66 | indices_of_intersecting_rays: A tensor of size [r'].
67 | intersection_ts: A tensor of size [r'].
68 | """
69 | r = rays_start_point.size()[0]
70 | box_length = box_length.expand(r)
71 | box_width = box_width.expand(r)
72 | box_height = box_height.expand(r)
73 | box_center = torch.broadcast_to(box_center, (r, 3))
74 | box_rotation_matrix = torch.broadcast_to(box_rotation_matrix, (r, 3, 3))
75 | rays_start_point_in_box_frame, rays_end_point_in_box_frame = (
76 | ray_to_box_coordinate_frame_pairwise(
77 | box_center=box_center,
78 | box_rotation_matrix=box_rotation_matrix,
79 | rays_start_point=rays_start_point,
80 | rays_end_point=rays_end_point))
81 | rays_a = rays_end_point_in_box_frame - rays_start_point_in_box_frame
82 | intersection_masks = []
83 | intersection_points = []
84 | intersection_ts = []
85 | box_size = [box_length, box_width, box_height]
86 | for axis in range(3):
87 | plane_value = box_size[axis] / 2.0
88 | for _ in range(2):
89 | plane_value = -plane_value
90 | # Compute the scalar multiples of 'rays_a' to apply in order to intersect
91 | # with the plane.
92 | t = ((plane_value - rays_start_point_in_box_frame[:, axis]) / # [R,]
93 | rays_a[:, axis])
94 | # The ray only intersects the current plane if it is not parallel to it.
95 | # If the ray is parallel, rays_a[:, axis] is zero, so 't' above becomes
96 | # +/- infinity and the ray is masked out below.
97 | intersects_with_plane = torch.abs(rays_a[:, axis]) > epsilon
98 | if exclude_negative_t: # Only allow at most one negative t
99 | t = torch.maximum(t, torch.tensor(0.0)) # [R,]
100 | if exclude_enlarged_t:
101 | t = torch.maximum(t, torch.tensor(1.0)) # [R,]
102 | intersection_ts.append(t) # [R, 1]
103 | intersection_points_i = []
104 |
105 | # Initialize a mask which represents whether each ray intersects with the
106 | # current plane.
107 | intersection_masks_i = torch.ones_like(t, dtype=torch.int32).bool() # [R,]
108 | for axis2 in range(3):
109 | # Compute the point of intersection for the current axis.
110 | intersection_points_i_axis2 = ( # [R,]
111 | rays_start_point_in_box_frame[:, axis2] + t * rays_a[:, axis2])
112 | intersection_points_i.append(intersection_points_i_axis2) # 3x [R,]
113 |
114 | # Update the intersection mask depending on whether the intersection
115 | # point is within bounds.
116 | intersection_masks_i = torch.logical_and( # [R,]
117 | torch.logical_and(intersection_masks_i, intersects_with_plane),
118 | torch.logical_and(
119 | intersection_points_i_axis2 <= (box_size[axis2] / 2.0 + epsilon),
120 | intersection_points_i_axis2 >= (-box_size[axis2] / 2.0 - epsilon)))
121 | intersection_points_i = torch.stack(intersection_points_i, dim=1) # [R, 3]
122 | intersection_masks.append(intersection_masks_i) # List of [R,]
123 | intersection_points.append(intersection_points_i) # List of [R, 3]
124 | intersection_ts = torch.stack(intersection_ts, dim=1) # [R, 6]
125 | intersection_masks = torch.stack(intersection_masks, dim=1) # [R, 6]
126 | intersection_points = torch.stack(intersection_points, dim=1) # [R, 6, 3]
127 |
128 | # Compute a mask over rays with exactly two plane intersections out of the six
129 | # planes. More intersections are possible if the ray coincides with a box
130 | # edge or corner, but we'll ignore these cases for now.
131 | counts = torch.sum(intersection_masks.int(), dim=1) # [R,]
132 | intersection_masks_any = torch.eq(counts, 2) # [R,]
133 | indices = torch.arange(intersection_masks_any.size()[0]).int() # [R,]
134 | # Apply the intersection masks over tensors.
135 | indices = indices[intersection_masks_any] # [R',]
136 | intersection_masks = intersection_masks[intersection_masks_any] # [R', 6]
137 | intersection_points = intersection_points[intersection_masks_any] # [R', 6, 3]
138 | intersection_points = intersection_points[intersection_masks].view(-1, 2, 3) # [R', 2, 3]
139 | # Ensure one or more positive ts.
140 | intersection_ts = intersection_ts[intersection_masks_any] # [R', 6]
141 | intersection_ts = intersection_ts[intersection_masks] # [R'*2]
142 | intersection_ts = intersection_ts.view(indices.size()[0], 2) # [R', 2]
143 | positive_ts_mask = (intersection_ts >= 0) # [R', 2]
144 | positive_ts_count = torch.sum(positive_ts_mask.int(), dim=1) # [R']
145 | positive_ts_mask = (positive_ts_count >= 1) # [R']
146 | intersection_points = intersection_points[positive_ts_mask] # [R'', 2, 3]
147 | false_indices = indices[torch.logical_not(positive_ts_mask)] # [R',]
148 | indices = indices[positive_ts_mask] # [R'',]
149 | if len(false_indices) > 0:
150 | intersection_masks_any[false_indices[:, None]] = torch.zeros(false_indices.size(), dtype=torch.bool)
151 | return rays_start_point_in_box_frame, intersection_masks_any, intersection_points, indices
152 |
153 |
154 | def compute_bounds_from_intersect_points(rays_o, intersect_indices,
155 | intersect_points):
156 | """Computes bounds from intersection points.
157 |
158 |   Note: Make sure that inputs are in the same coordinate frame.
159 |
160 | Args:
161 | rays_o: [R, 3] float tensor
162 |     intersect_indices: [R', 1] int tensor
163 | intersect_points: [R', 2, 3] float tensor
164 |
165 | Returns:
166 | intersect_bounds: [R', 2] float tensor
167 |
168 | where R is the number of rays and R' is the number of intersecting rays.
169 | """
170 | intersect_rays_o = rays_o[intersect_indices] # [R', 1, 3]
171 | intersect_diff = intersect_points - intersect_rays_o # [R', 2, 3]
172 | intersect_bounds = torch.norm(intersect_diff, dim=2) # [R', 2]
173 |
174 | # Sort the bounds so that near comes before far for all rays.
175 | intersect_bounds, _ = torch.sort(intersect_bounds, dim=1) # [R', 2]
176 |
177 | # For some reason the sort function returns [R', ?] instead of [R', 2], so we
178 | # will explicitly reshape it.
179 | intersect_bounds = intersect_bounds.view(-1, 2) # [R', 2]
180 | return intersect_bounds
181 |
182 |
183 | def compute_ray_bbox_bounds_pairwise(rays_o, rays_d, box_length,
184 | box_width, box_height, box_center,
185 | box_rotation, far_limit=1e10):
186 | """Computes near and far bounds for rays intersecting with bounding boxes.
187 |
188 | Note: rays and boxes are defined in world coordinate frame.
189 |
190 | Args:
191 | rays_o: [R, 3] float tensor. A set of ray origins.
192 | rays_d: [R, 3] float tensor. A set of ray directions.
193 | box_length: scalar or [R,] float tensor. Bounding box length.
194 | box_width: scalar or [R,] float tensor. Bounding box width.
195 | box_height: scalar or [R,] float tensor. Bounding box height.
196 | box_center: [3,] or [R, 3] float tensor. The center of the box.
197 | box_rotation: [3, 3] or [R, 3, 3] float tensor. The box rotation matrix.
198 | far_limit: float. The maximum far value to use.
199 |
200 | Returns:
201 | intersect_bounds: [R', 2] float tensor. The bounds per-ray, sorted in
202 | ascending order.
203 |     intersect_indices: [R', 1] int tensor. The intersection indices.
204 |     intersect_mask: [R,] bool tensor. The mask denoting intersections.
205 | """
206 | # Compute ray destinations.
207 | normalized_rays_d = ray_utils.normalize_rays(rays=rays_d)
208 | rays_dst = rays_o + far_limit * normalized_rays_d
209 |
210 | # Transform the rays from world to box coordinate frame.
211 | rays_o_in_box_frame, intersect_mask, intersect_points_in_box_frame, intersect_indices = ( # [R,], [R', 2, 3], [R', 2]
212 | ray_box_intersection_pairwise(
213 | box_center=box_center,
214 | box_rotation_matrix=box_rotation,
215 | box_length=box_length,
216 | box_width=box_width,
217 | box_height=box_height,
218 | rays_start_point=rays_o,
219 | rays_end_point=rays_dst))
220 | intersect_indices = intersect_indices.unsqueeze(1).long() # [R', 1]
221 | intersect_bounds = compute_bounds_from_intersect_points(
222 | rays_o=rays_o_in_box_frame,
223 | intersect_indices=intersect_indices,
224 | intersect_points=intersect_points_in_box_frame)
225 | return intersect_bounds, intersect_indices, intersect_mask
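
# A hypothetical usage sketch (names and values are illustrative, not part of the
# original file): intersect a small batch of rays with an axis-aligned unit box at
# the origin, mirroring how intersect.py calls this function.
#
#   R = 8
#   rays_o = torch.tensor([[0.0, 0.0, 3.0]]).repeat(R, 1)    # [R, 3] origins outside the box
#   rays_d = torch.tensor([[0.0, 0.0, -1.0]]).repeat(R, 1)   # [R, 3] directions toward the box
#   bounds, indices, mask = compute_ray_bbox_bounds_pairwise(
#       rays_o=rays_o, rays_d=rays_d,
#       box_length=torch.ones(R), box_width=torch.ones(R), box_height=torch.ones(R),
#       box_center=torch.zeros(R, 3), box_rotation=torch.eye(3).repeat(R, 1, 1))
#   # bounds[i] holds the sorted (near, far) distances for intersecting ray indices[i];
#   # here each ray should enter the box at ~2.5 and exit at ~3.5.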
226 |
--------------------------------------------------------------------------------
/ObjectFolder1.0/cam_utils.py:
--------------------------------------------------------------------------------
1 | """Various camera utility functions."""
2 |
3 | import numpy as np
4 |
5 |
6 | def w2c_to_c2w(w2c):
7 | """
8 | Args:
9 |         w2c: [4, 4] np.float32. World-to-camera extrinsics matrix.
10 |
11 | Returns:
12 |         c2w: [4, 4] np.float32. Camera-to-world extrinsics matrix.
13 | """
14 | R = w2c[:3, :3]
15 | T = w2c[:3, 3]
16 |
17 | c2w = np.eye(4, dtype=np.float32)
18 |     c2w[:3, 3] = -1 * np.dot(R.transpose(), T)
19 | c2w[:3, :3] = R.transpose()
20 | return c2w
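
# Illustrative sanity check (not part of the original file): since c2w is the
# inverse of w2c for a rigid transform, composing the two should give identity:
#   np.allclose(w2c_to_c2w(w2c) @ w2c, np.eye(4), atol=1e-6)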
21 |
--------------------------------------------------------------------------------
/ObjectFolder1.0/dataset_visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/ObjectFolder1.0/dataset_visualization.png
--------------------------------------------------------------------------------
/ObjectFolder1.0/demo/audio_demo_forces.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/ObjectFolder1.0/demo/audio_demo_forces.npy
--------------------------------------------------------------------------------
/ObjectFolder1.0/demo/audio_demo_vertices.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/ObjectFolder1.0/demo/audio_demo_vertices.npy
--------------------------------------------------------------------------------
/ObjectFolder1.0/demo/touch_demo_vertices.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/ObjectFolder1.0/demo/touch_demo_vertices.npy
--------------------------------------------------------------------------------
/ObjectFolder1.0/demo/vision_demo.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/ObjectFolder1.0/demo/vision_demo.npy
--------------------------------------------------------------------------------
/ObjectFolder1.0/evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import datetime
3 | import sys
4 | import torch
5 | import numpy as np
6 | import imageio
7 | import json
8 | import random
9 | import time
10 | from tqdm import tqdm, trange
11 | from scipy.spatial import KDTree
12 |
13 | import indirect_utils
14 | from load_osf import load_osf_data
15 | from intersect import compute_object_intersect_tensors
16 | from ray_utils import transform_rays
17 | from run_osf_helpers import *
18 | from scatter import scatter_coarse_and_fine
19 | import shadow_utils
20 |
21 | import VisionNet_utils
22 | import AudioNet_utils
23 | import AudioNet_model
24 | import TouchNet_utils
25 | import TouchNet_model
26 |
27 | from scipy.io.wavfile import write
28 | import librosa
29 | import imageio
30 |
31 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32 |
33 | def config_parser():
34 | import configargparse
35 | parser = configargparse.ArgParser(
36 | config_file_parser_class=configargparse.YAMLConfigFileParser
37 | )
38 | #parser = configargparse.ArgumentParser()
39 | parser.add_argument("--object_file_path", type=str, required=True, help='ObjectFile path')
40 | parser.add_argument('--config', is_config_file=True, help='config file path')
41 |
42 | # VisionNet options
43 | parser.add_argument("--vision_test_file_path", default='data/vision_demo.npy', help='The path of the testing file for vision, which should be a npy file.')
44 | parser.add_argument("--vision_results_dir", type=str, default='./results/vision/', help='The path of the vision results directory to save rendered images.')
45 | parser.add_argument("--chunk", type=int, default=1024*32,
46 | help='number of rays processed in parallel, decrease if running out of memory')
47 | parser.add_argument("--netchunk", type=int, default=1024*64,
48 | help='number of pts sent through network in parallel, decrease if running out of memory')
49 |
50 | # AudioNet options
51 | parser.add_argument('--audio_vertices_file_path', default='./data/audio_demo_vertices.npy', help='The path of the testing vertices file for audio, which should be a npy file.')
52 | parser.add_argument('--audio_forces_file_path', default='./data/forces.npy', help='The path of forces file for audio, which should be a npy file.')
53 | parser.add_argument('--audio_batchSize', type=int, default=10000, help='input batch size')
54 | parser.add_argument('--audio_results_dir', type=str, default='./results/audio/', help='The path of audio results directory to save rendered impact sounds as .wav files.')
55 |
56 | # TouchNet options
57 | parser.add_argument('--touch_vertices_file_path', default='./data/touch_demo_vertices.npy', help='The path of the testing vertices file for touch, which should be a npy file.')
58 | parser.add_argument('--touch_batchSize', type=int, default=10000, help='input batch size')
59 | parser.add_argument('--touch_results_dir', type=str, default='./results/touch/', help='The path of the touch results directory to save rendered tactile RGB images.')
60 |
61 | return parser
62 |
63 |
64 | def VisionNet_eval(args):
65 |
66 | args.secondary_chunk = args.chunk
67 |
68 | metadata, render_metadata = None, None
69 | near = 0.01
70 | far = 4
71 |
72 | poses, hwf, i_split, metadata = load_osf_data(args.vision_test_file_path)
73 | i_test = i_split[0]
74 |
75 | render_poses = np.array(poses[i_test])
76 |
77 | # Create dummy metadata if not loaded from dataset.
78 | if metadata is None:
79 |         metadata = torch.tensor([[0, 0, 1]] * len(poses), dtype=torch.float) # [N, 3]
80 | if render_metadata is None:
81 | render_metadata = metadata
82 |
83 | # Cast intrinsics to right types
84 | H, W, focal = hwf
85 | H, W = int(H), int(W)
86 | hwf = [H, W, focal]
87 |
88 | # Create nerf model
89 | render_kwargs_train, render_kwargs_test, start, grad_vars, models, optimizer = VisionNet_utils.create_nerf(
90 | args, metadata, render_metadata)
91 | global_step = start
92 |
93 | bds_dict = {
94 | 'near': torch.tensor(near).float(),
95 | 'far': torch.tensor(far).float(),
96 | }
97 | render_kwargs_train.update(bds_dict)
98 | render_kwargs_test.update(bds_dict)
99 |
100 | render_kwargs_test['render_metadata'] = metadata
101 |
102 | # Move testing data to GPU
103 | render_poses = torch.Tensor(render_poses).to(device)
104 |
105 | with torch.no_grad():
106 | images = None
107 |
108 | testsavedir = args.vision_results_dir
109 | os.makedirs(testsavedir, exist_ok=True)
110 | print('Begin rendering images in ', testsavedir)
111 | rgbs, _ = VisionNet_utils.render_path(render_poses, hwf, args.chunk, render_kwargs_test,
112 | gt_imgs=images, savedir=testsavedir,
113 | c2w_staticcam=None, render_start=None,
114 | render_end=None)
115 | print('Done rendering images in ', testsavedir)
116 |
117 |
118 | def AudioNet_eval(args):
119 | dim = 32
120 | audio_sampling_rate = 16000
121 | audio_window_size = 400
122 | audio_hop_size = 160
123 | audio_stft_time_dim = 201
124 | audio_stft_freq_dim = 257
125 | audio_network_depth = 8
126 | xyz = np.load(args.audio_vertices_file_path)
127 | forces = np.load(args.audio_forces_file_path)
128 | xyz = xyz[0:1,:]
129 | forces = forces[0:1]
130 | N = xyz.shape[0]
131 | print(N)
132 | #N: number of data
133 | #D: number of dimension
134 | #C: number of channels, real and img
135 | #F: number of frequency
136 | #T: number of timestamps
137 | N, D, C, F, T = N, 3, 2, audio_stft_freq_dim, audio_stft_time_dim
138 |
139 | checkpoint = torch.load(args.object_file_path)
140 | normalizer_dic = checkpoint['AudioNet']['normalizer']
141 | voxel_vertex = checkpoint['AudioNet']['voxel_vertex']
142 | vert_tree = KDTree(voxel_vertex)
143 | translation = checkpoint['AudioNet']['translation']
144 | scale = checkpoint['AudioNet']['scale']
145 |
146 | k = 4 # Average over 4 nearest neighbors
147 | xyz_in_voxel = np.zeros((4, N, 3))
148 | for i in range(N):
149 | obj_coordinates = xyz[i]
150 | binvox_coordinates = AudioNet_utils.transform_mesh_collision_binvox(obj_coordinates, translation, scale)
151 | coordinates_in_voxel = binvox_coordinates * dim
152 | voxel_verts_index = vert_tree.query(coordinates_in_voxel, k)[1]
153 | for j in range(k):
154 | xyz_in_voxel[j, i] = voxel_vertex[voxel_verts_index[j]]
155 |
156 | xyz_in_voxel = np.repeat(xyz_in_voxel.reshape((4, N, 1, 3)), F * T, axis=2)
157 | #normalize xyz_in_voxel to [-1, 1]
158 | xyz_in_voxel_min = xyz_in_voxel.min()
159 | xyz_in_voxel_max = xyz_in_voxel.max()
160 | xyz_in_voxel = (xyz_in_voxel - xyz_in_voxel_min) / (xyz_in_voxel_max - xyz_in_voxel_min)
161 |
162 | spec_comps_f1_min = normalizer_dic['f1_min']
163 | spec_comps_f1_max = normalizer_dic['f1_max']
164 | spec_comps_f2_min = normalizer_dic['f2_min']
165 | spec_comps_f2_max = normalizer_dic['f2_max']
166 | spec_comps_f3_min = normalizer_dic['f3_min']
167 | spec_comps_f3_max = normalizer_dic['f3_max']
168 |
169 | #initialize frequency and time features
170 | freq_feats = np.repeat(np.repeat(np.arange(F).reshape((F, 1)), T, axis=1).reshape((1, 1, F, T)), N, axis=0)
171 | time_feats = np.repeat(np.repeat(np.arange(T).reshape((1, T)), F, axis=0).reshape((1, 1, F, T)), N, axis=0)
172 |
173 | #normalize frequency and time features to [-1, 1]
174 | freq_feats_min = freq_feats.min()
175 | freq_feats_max = freq_feats.max()
176 | time_feats_min = time_feats.min()
177 | time_feats_max = time_feats.max()
178 | freq_feats = (freq_feats - freq_feats_min) / (freq_feats_max - freq_feats_min)
179 | time_feats = (time_feats - time_feats_min) / (time_feats_max - time_feats_min)
180 |
181 | data_x = np.concatenate((freq_feats, time_feats), axis=1)
182 | data_y = np.concatenate((freq_feats, time_feats), axis=1)
183 | data_z = np.concatenate((freq_feats, time_feats), axis=1)
184 | data_x = np.transpose(data_x.reshape((N, 2, -1)), axes = [0, 2, 1])
185 | data_y = np.transpose(data_y.reshape((N, 2, -1)), axes = [0, 2, 1])
186 | data_z = np.transpose(data_z.reshape((N, 2, -1)), axes = [0, 2, 1])
187 | data_x = np.repeat(data_x.reshape((1, N, -1, 2)), k, axis=0)
188 | data_y = np.repeat(data_y.reshape((1, N, -1, 2)), k, axis=0)
189 | data_z = np.repeat(data_z.reshape((1, N, -1, 2)), k, axis=0)
190 |
191 |     #Now concatenate xyz_in_voxel and the frequency/time features to get the final feats matrix as [x, y, z, f, t]
192 | feats_x = np.concatenate((xyz_in_voxel, data_x), axis=3).reshape((-1, 5))
193 | feats_y = np.concatenate((xyz_in_voxel, data_y), axis=3).reshape((-1, 5))
194 | feats_z = np.concatenate((xyz_in_voxel, data_z), axis=3).reshape((-1, 5))
195 |
196 | embed_fn, input_ch = AudioNet_model.get_embedder(10, 0)
197 | model = AudioNet_model.AudioNeRF(D = audio_network_depth, input_ch = input_ch)
198 | state_dic = checkpoint['AudioNet']["model_state_dict"]
199 | state_dic = AudioNet_utils.strip_prefix_if_present(state_dic, 'module.')
200 | model.load_state_dict(state_dic)
201 | model = nn.DataParallel(model).to(device)
202 | model.eval()
203 | loss_fn = torch.nn.MSELoss(reduction='mean')
204 |
205 | start_time = time.time()
206 | preds_x = np.zeros((feats_x.shape[0], 2))
207 | preds_y = np.zeros((feats_y.shape[0], 2))
208 | preds_z = np.zeros((feats_z.shape[0], 2))
209 | N_rand = args.audio_batchSize
210 |
211 | print("Begin rendering impact sounds in ", args.audio_results_dir)
212 | for i in trange(feats_x.shape[0] // N_rand + 1):
213 | curr_feats_x = torch.Tensor(feats_x[i*N_rand:(i+1)*N_rand]).to(device)
214 | curr_feats_y = torch.Tensor(feats_y[i*N_rand:(i+1)*N_rand]).to(device)
215 | curr_feats_z = torch.Tensor(feats_z[i*N_rand:(i+1)*N_rand]).to(device)
216 | embedded_x = embed_fn(curr_feats_x)
217 | embedded_y = embed_fn(curr_feats_y)
218 | embedded_z = embed_fn(curr_feats_z)
219 | results_x, results_y, results_z = model(embedded_x, embedded_y, embedded_z)
220 |
221 | preds_x[i*N_rand:(i+1)*N_rand, :] = results_x.detach().cpu().numpy()
222 | preds_y[i*N_rand:(i+1)*N_rand, :] = results_y.detach().cpu().numpy()
223 | preds_z[i*N_rand:(i+1)*N_rand, :] = results_z.detach().cpu().numpy()
224 |
225 | preds_x = preds_x * (spec_comps_f1_max - spec_comps_f1_min) + spec_comps_f1_min
226 | preds_y = preds_y * (spec_comps_f2_max - spec_comps_f2_min) + spec_comps_f2_min
227 | preds_z = preds_z * (spec_comps_f3_max - spec_comps_f3_min) + spec_comps_f3_min
228 | preds_x = np.transpose(preds_x.reshape((k, N, -1, 2)), axes = [0, 1, 3, 2]).reshape((k, N, 1, C, F, T))
229 | preds_y = np.transpose(preds_y.reshape((k, N, -1, 2)), axes = [0, 1, 3, 2]).reshape((k, N, 1, C, F, T))
230 | preds_z = np.transpose(preds_z.reshape((k, N, -1, 2)), axes = [0, 1, 3, 2]).reshape((k, N, 1, C, F, T))
231 |
232 | #save evaluation results
233 | os.makedirs(args.audio_results_dir, exist_ok=True)
234 |
235 | for i in trange(N):
236 | force_x, force_y, force_z = forces[i]
237 | signal = np.zeros(audio_sampling_rate*2)
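        # The k nearest voxel vertices each contribute predicted x/y/z spectrograms;
        # convert them back to waveforms via ISTFT, weight them by the force
        # components, and accumulate into a single two-second impact signal.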
238 | for j in range(k):
239 | spec_x = preds_x[j, i, 0, 0, :, :] + preds_x[j, i, 0, 1, :, :] * 1j
240 | signal_x = librosa.istft(spec_x, hop_length=audio_hop_size, win_length=audio_window_size, length=audio_sampling_rate*2)
241 | spec_y = preds_y[j, i, 0, 0, :, :] + preds_y[j, i, 0, 1, :, :] * 1j
242 | signal_y = librosa.istft(spec_y, hop_length=audio_hop_size, win_length=audio_window_size, length=audio_sampling_rate*2)
243 | spec_z = preds_z[j, i, 0, 0, :, :] + preds_z[j, i, 0, 1, :, :] * 1j
244 | signal_z = librosa.istft(spec_z, hop_length=audio_hop_size, win_length=audio_window_size, length=audio_sampling_rate*2)
245 | temp = signal_x * force_x + signal_y * force_y + signal_z * force_z
246 | signal += temp
247 | signal = signal / np.abs(signal).max()
248 | end_time = time.time()
249 | print(end_time - start_time)
250 | # Write WAV file
251 | output_path = os.path.join(args.audio_results_dir, str(i+1) + '.wav')
252 | write(output_path, audio_sampling_rate, signal.astype(np.float32))
253 | print('Done rendering impact sounds in ', args.audio_results_dir)
254 |
255 |
256 | def TouchNet_eval(args):
257 | touch_network_depth = 8
258 |
259 | xyz = np.load(args.touch_vertices_file_path)
260 |
261 | #N: number of data
262 | #C: channels
263 | #W: Width dimension
264 | #H: Height dimension
265 | #N, C, W, H = touch_images.shape
266 | N, C, W, H = xyz.shape[0], 3, 160, 120
267 |
268 |     #initialize width and height features
269 | w_feats = np.repeat(np.repeat(np.arange(W).reshape((W, 1)), H, axis=1).reshape((1, 1, W, H)), N, axis=0)
270 | h_feats = np.repeat(np.repeat(np.arange(H).reshape((1, H)), W, axis=0).reshape((1, 1, W, H)), N, axis=0)
271 |
272 | checkpoint = torch.load(args.object_file_path)
273 |
274 |     #normalize width and height features to [-1, 1]
275 | w_feats_min = w_feats.min()
276 | w_feats_max = w_feats.max()
277 | h_feats_min = h_feats.min()
278 | h_feats_max = h_feats.max()
279 | w_feats = 2 * ((w_feats - w_feats_min) / w_feats_max) - 1
280 | h_feats = 2 * ((h_feats - h_feats_min) / h_feats_max) - 1
281 |
282 | data_x = np.concatenate((w_feats, h_feats), axis=1)
283 | data_x = np.transpose(data_x.reshape((N, 2, -1)), axes = [0, 2, 1])
284 |
285 | xyz = np.repeat(xyz.reshape((N, 1, 3)), W * H, axis=1)
286 |
287 | #normalize xyz to [-1, 1]
288 | xyz_min = xyz.min()
289 | xyz_max = xyz.max()
290 | xyz = 2 * ((xyz - xyz_min) / xyz_max) - 1
291 |
292 |     #Now concatenate xyz and the width/height features to get the final feats matrix as [x, y, z, w, h]
293 |     data = np.concatenate((xyz, data_x), axis=2).reshape((-1, 5))
294 | feats = data
295 |
296 | embed_fn, input_ch = TouchNet_model.get_embedder(10, 0)
297 | model = TouchNet_model.NeRF(D = touch_network_depth, input_ch = input_ch, output_ch = 3)
298 | state_dic = checkpoint['TouchNet']["model_state_dict"]
299 | state_dic = TouchNet_utils.strip_prefix_if_present(state_dic, 'module.')
300 | model.load_state_dict(state_dic)
301 | model = nn.DataParallel(model).to(device)
302 | model.eval()
303 | loss_fn = torch.nn.MSELoss(reduction='mean')
304 |
305 | preds = np.zeros((feats.shape[0], 3))
306 | N_rand = args.touch_batchSize
307 |
308 | print("Begin rendering tactile images in ", args.touch_results_dir)
309 | start_time = time.time()
310 | for i in trange(feats.shape[0] // N_rand + 1):
311 | curr_feats = torch.Tensor(feats[i*N_rand:(i+1)*N_rand]).to(device)
312 | embedded = embed_fn(curr_feats)
313 | results = model(embedded)
314 | preds[i*N_rand:(i+1)*N_rand, :] = results.detach().cpu().numpy()
315 | end_time = time.time()
316 | print(end_time - start_time)
317 |
318 | preds = (((preds + 1) / 2) * 255)
319 | preds = np.transpose(preds.reshape((N, -1, 3)), axes = [0, 2, 1]).reshape((N, C, W, H))
320 | preds = np.clip(np.rint(preds), 0, 255).astype(np.uint8)
321 | preds = preds.transpose(0,2,3,1)
322 |
323 | os.makedirs(args.touch_results_dir, exist_ok=True)
324 | #save evaluation results
325 | for i in trange(N):
326 | filename = os.path.join(args.touch_results_dir, '{}.png'.format(i+1))
327 | imageio.imwrite(filename, preds[i])
328 | print("Done rendering tactile images in ", args.touch_results_dir)
329 |
330 | if __name__ =='__main__':
331 | parser = config_parser()
332 | args = parser.parse_args()
333 |
334 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
335 |
336 | VisionNet_eval(args=args)
337 | AudioNet_eval(args=args)
338 | TouchNet_eval(args=args)
339 |
--------------------------------------------------------------------------------
/ObjectFolder1.0/indirect_utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions for indirect illumination."""
2 | import math
3 |
4 | import torch
5 |
6 | import ray_utils
7 |
8 |
9 | def create_ray_batch(pts, near, far, rays_i, use_viewdirs):
10 | """Create batch for indirect rays.
11 |
12 | Args:
13 | pts: [R, S, 3] float tensor. Primary points.
14 | near: Near sampling bound.
15 | far: Far sampling bound.
16 | rays_i: [R, 1] float tensor. Ray image IDs.
17 | use_viewdirs: bool. Whether to use view directions.
18 |
19 | Returns:
20 | ray_batch: [RS, M] float tensor. Batch of secondary rays containing one secondary
21 | ray for each primary ray sample. Each secondary ray originates at a primary
22 | point and points in the direction towards the randomly sampled (indirect) light
23 | source.
24 | """
25 | num_primary_rays = pts.size()[0]
26 | num_primary_samples = pts.size()[1]
27 |
28 | rays_dst = ray_utils.sample_random_lightdirs(num_primary_rays, num_primary_samples) # [R, S, 3]
29 |
30 | rays_o = pts.view(-1, 3) # [RS, 3]
31 |     rays_dst = rays_dst.view(-1, 3) # [RS, 3]
32 |
33 | rays_near = torch.full((rays_o.size()[0], 1), near).float() # [RS, 1]
34 | rays_far = torch.full((rays_o.size()[0], 1), far).float() # [RS, 1]
35 |
36 | rays_i = torch.tile(rays_i[:, None, :], (1, num_primary_samples, 1)) # [R?, S, 1]
37 | rays_i = rays_i.view(-1, 1) # [RS, 1]
38 |
39 | ray_batch = ray_utils.create_ray_batch(
40 | rays_o=rays_o, rays_dst=rays_dst, rays_near=rays_near, rays_far=rays_far,
41 | rays_i=rays_i, use_viewdirs=use_viewdirs)
42 |
43 | return ray_batch
44 |
--------------------------------------------------------------------------------
/ObjectFolder1.0/intersect.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import box_utils
4 | import ray_utils
5 |
6 |
7 | def apply_mask_to_tensors(mask, tensors):
8 | """Applies mask to a list of tensors.
9 |
10 | Args:
11 | mask: [R, ...]. Mask to apply.
12 | tensors: List of [R, ...]. List of tensors.
13 |
14 | Returns:
15 | intersect_tensors: List of shape [R?, ...]. Masked tensors.
16 | """
17 | intersect_tensors = []
18 | for t in tensors:
19 | intersect_t = t[mask] # [R?, ...]
20 | intersect_tensors.append(intersect_t)
21 | return intersect_tensors
22 |
23 |
24 | def get_full_intersection_tensors(ray_batch):
25 | """Test case that selects all rays."""
26 | # mask = torch.rand(ray_batch.size())[:, 0] # [R?,]
27 | # mask = (mask > 0.5).bool()
28 | mask = torch.ones_like(ray_batch, dtype=torch.bool)[:, 0]
29 |
30 | n_intersect = mask.size()[0] # R?
31 | indices = torch.arange(n_intersect, dtype=torch.int).unsqueeze(1) # [R?, 1]
32 | indices = apply_mask_to_tensors( # [R?, M]
33 | mask=mask, # [R,]
34 | tensors=[indices])[0] # [R, M]
35 |
36 | bounds = ray_batch[:, 6:8] # [R?,]
37 | bounds = apply_mask_to_tensors( # [R?, M]
38 | mask=mask, # [R,]
39 | tensors=[bounds])[0] # [R, M]
40 | return mask, indices, bounds
41 |
42 |
43 | def compute_object_intersect_tensors(ray_batch, box_center, box_dims):
44 | """Compute rays that intersect with bounding boxes.
45 |
46 | Args:
47 | ray_batch: [R, M] float tensor. Batch of rays.
48 | box_center: List of 3 floats containing the (x, y, z) center of bbox.
49 | box_dims: List of 3 floats containing the x, y, z dimensions of the bbox.
50 |
51 | Returns:
52 | intersect_ray_batch: [R?, M] float tensor. Batch of intersecting rays.
53 | indices: [R?, 1] float tensor. Indices of intersecting rays.
54 | """
55 | # Check that bbox params are properly formed.
56 | for lst in [box_center, box_dims]:
57 | assert type(lst) == list
58 | assert len(lst) == 3
59 | assert all((isinstance(x, int) or isinstance(x, float)) for x in lst)
60 |
61 | # For now, we assume bbox has no rotation.
62 | num_rays = ray_batch.size()[0] # R
63 | box_center = torch.tile(torch.tensor(box_center), (num_rays, 1)).float() # [R, 3]
64 | box_dims = torch.tile(torch.tensor(box_dims), (num_rays, 1)).float() # [R, 3]
65 | box_rotation = torch.tile(torch.eye(3).unsqueeze(0), (num_rays, 1, 1)).float() # [R, 3, 3]
66 |
67 | # Compute ray-bbox intersections.
68 | bounds, indices, mask = box_utils.compute_ray_bbox_bounds_pairwise( # [R', 2], [R',], [R,]
69 | rays_o=ray_batch[:, 0:3], # [R, 3]
70 | rays_d=ray_batch[:, 3:6], # [R, 3]
71 | box_length=box_dims[:, 0], # [R,]
72 | box_width=box_dims[:, 1], # [R,]
73 | box_height=box_dims[:, 2], # [R,]
74 | box_center=box_center, # [R, 3]
75 | box_rotation=box_rotation) # [R, 3, 3]
76 |
77 | # Apply the intersection mask to the ray batch.
78 | intersect_ray_batch = apply_mask_to_tensors( # [R?, M]
79 | mask=mask, # [R,]
80 | tensors=[ray_batch])[0] # [R, M]
81 |
82 | # Update the near and far bounds of the ray batch with the intersect bounds.
83 | intersect_ray_batch = ray_utils.update_ray_batch_bounds( # [R?, M]
84 | ray_batch=intersect_ray_batch, # [R?, M]
85 | bounds=bounds) # [R?, 2]
86 | return intersect_ray_batch, indices, mask # [R?, M], [R?, 1]
87 |
--------------------------------------------------------------------------------
/ObjectFolder1.0/load_osf.py:
--------------------------------------------------------------------------------
1 | """Data loader for OSF data."""
2 |
3 | import os
4 | import torch
5 | import numpy as np
6 | import imageio
7 | import json
8 |
9 | import cam_utils
10 |
11 |
12 | trans_t = lambda t: torch.tensor([
13 | [1,0,0,0],
14 | [0,1,0,0],
15 | [0,0,1,t],
16 | [0,0,0,1]
17 | ], dtype=torch.float)
18 |
19 | rot_phi = lambda phi: torch.tensor([
20 | [1,0,0,0],
21 | [0,np.cos(phi),-np.sin(phi),0],
22 | [0,np.sin(phi), np.cos(phi),0],
23 | [0,0,0,1]
24 | ], dtype=torch.float)
25 |
26 | rot_theta = lambda th: torch.tensor([
27 | [np.cos(th),0,-np.sin(th),0],
28 | [0,1,0,0],
29 | [np.sin(th),0, np.cos(th),0],
30 | [0,0,0,1]
31 | ], dtype=torch.float)
32 |
33 |
34 | def pose_spherical(theta, phi, radius):
35 | c2w = trans_t(radius)
36 | c2w = rot_phi(phi/180.*np.pi) @ c2w
37 | c2w = rot_theta(theta/180.*np.pi) @ c2w
38 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w
39 | return c2w
40 |
41 | def coordinates_to_c2w(x, y, z, r=2.5):
42 | theta = np.arccos(z / r)
43 | phi = np.arctan2(x, -y)
44 | Rx = np.array([[1, 0, 0], [0, np.cos(theta), -np.sin(theta)], [0, np.sin(theta), np.cos(theta)]])
45 | Rz = np.array([[np.cos(phi), -np.sin(phi), 0], [np.sin(phi), np.cos(phi), 0], [0, 0, 1]])
46 | R = Rz @ Rx
47 | c2w = R.tolist()
48 | c2w[0].append(x)
49 | c2w[1].append(y)
50 | c2w[2].append(z)
51 | c2w.append([0., 0., 0., 1.])
52 | #c2w = np.array(c2w).astype(np.float32)
53 | return c2w
54 |
55 | def convert_cameras_to_nerf_format(anno):
56 | """
57 | Args:
58 | anno: List of annotations for each example. Each annotation is represented by a
59 | dictionary that must contain the key `RT` which is the world-to-camera
60 | extrinsics matrix with shape [3, 4], in [right, down, forward] coordinates.
61 |
62 | Returns:
63 | c2w: [N, 4, 4] np.float32. Array of camera-to-world extrinsics matrices in
64 | [right, up, backwards] coordinates.
65 | """
66 | c2w_list = []
67 | for a in anno:
68 | # Convert from w2c to c2w.
69 | w2c = np.array(a['RT'] + [[0.0, 0.0, 0.0, 1.0]])
70 | c2w = cam_utils.w2c_to_c2w(w2c)
71 |
72 | # Convert from [right, down, forwards] to [right, up, backwards]
73 | c2w[:3, 1] *= -1 # down -> up
74 | c2w[:3, 2] *= -1 # forwards -> back
75 | c2w_list.append(c2w)
76 | c2w = np.array(c2w_list)
77 | print("c2w: ", c2w)
78 | return c2w
79 |
80 |
81 | def load_osf_data(test_file_path):
82 |
83 | all_poses = []
84 | all_metadata = []
85 | counts = [0]
86 | test_file = np.load(test_file_path)
87 | N = test_file.shape[0]
88 | for i in range(N):
89 | cx, cy, cz, lx, ly, lz = test_file[i]
90 | poses = coordinates_to_c2w(cx, cy, cz)
91 | metadata = np.array([[lx, ly, lz]]).astype(np.float32)
92 | all_poses.append(poses)
93 | all_metadata.append(metadata)
94 |
95 | poses = np.array(all_poses).astype(np.float32)
96 |
97 | metadata = np.concatenate(all_metadata, 0)
98 | counts.append(N)
99 | i_split = [np.arange(counts[0], counts[1])]
100 |
101 | H, W, focal = 256, 256, 355.5555419921875
102 |
103 | return poses, [H, W, focal], i_split, metadata
104 |
--------------------------------------------------------------------------------
/ObjectFolder1.0/ray_utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions for ray computation."""
2 | import math
3 | import numpy as np
4 | from scipy.spatial.transform import Rotation as R
5 | import box_utils
6 | import torch
7 |
8 |
9 | def apply_batched_transformations(inputs, transformations):
10 | """Batched transformation of inputs.
11 |
12 | Args:
13 | inputs: List of [R, S, 3]
14 | transformations: [R, 4, 4]
15 |
16 | Returns:
17 | transformed_inputs: List of [R, S, 3]
18 | """
19 | # if rotation_only:
20 | # transformations[:, :3, 3] = torch.zeros((3,), dtype=torch.float)
21 |
22 | transformed_inputs = []
23 | for x in inputs:
24 | N_samples = x.size()[1]
25 | homog_transformations = transformations.unsqueeze(1) # [R, 1, 4, 4]
26 | homog_transformations = torch.tile(homog_transformations, (1, N_samples, 1, 1)) # [R, S, 4, 4]
27 | homog_component = torch.ones_like(x)[..., 0:1] # [R, S, 1]
28 | homog_x = torch.cat((x, homog_component), axis=-1) # [R, S, 4]
29 | homog_x = homog_x.unsqueeze(2)
30 | transformed_x = torch.matmul(
31 | homog_x,
32 | torch.transpose(homog_transformations, 2, 3)) # [R, S, 1, 4]
33 | transformed_x = transformed_x[..., 0, :3] # [R, S, 3]
34 | transformed_inputs.append(transformed_x)
35 | return transformed_inputs
36 |
37 |
38 | def get_transformation_from_params(params):
39 | translation, rotation = [0, 0, 0], [0, 0, 0]
40 | if 'translation' in params:
41 | translation = params['translation']
42 | if 'rotation' in params:
43 | rotation = params['rotation']
44 | translation = torch.tensor(translation, dtype=torch.float)
45 | rotmat = torch.tensor(R.from_euler('xyz', rotation, degrees=True).as_matrix(), dtype=torch.float)
46 | return translation, rotmat
47 |
48 |
49 | def rotate_dirs(dirs, rotmat):
50 | """
51 | Args:
52 | dirs: [R, 3] float tensor.
53 | rotmat: [3, 3]
54 | """
55 | if type(dirs) == np.ndarray:
56 | dirs = torch.tensor(dirs).float()
57 | #rotmat = rotmat.unsqueeze(0)
58 | rotmat = torch.broadcast_to(rotmat, (dirs.shape[0], 3, 3)) # [R, 3, 3]
59 | dirs_obj = torch.matmul(dirs.unsqueeze(1), torch.transpose(rotmat, 1, 2)) # [R, 1, 3]
60 | dirs_obj = dirs_obj.squeeze(1) # [R, 3]
61 | return dirs_obj
62 |
63 |
64 | def transform_dirs(dirs, params, inverse=False):
65 | _, rotmat = get_transformation_from_params(params) # [3,], [3, 3]
66 | if inverse:
67 | rotmat = torch.transpose(rotmat, 0, 1) # [3, 3]
68 | dirs_transformed = rotate_dirs(dirs, rotmat)
69 | return dirs_transformed
70 |
71 |
72 | def transform_rays(ray_batch, params, use_viewdirs, inverse=False):
73 | """Transform rays into object coordinate frame given o2w transformation params.
74 |
75 | Note: do not assume viewdirs is always the normalized version of rays_d (e.g., in staticcam case).
76 |
77 | Args:
78 | ray_batch: [R, M] float tensor. Batch of rays.
79 | params: Dictionary containing transformation parameters:
80 | 'translation': List of 3 elements. xyz translation.
81 | 'rotation': List of 3 euler angles in xyz.
82 |     use_viewdirs: bool. Whether we are using viewdirs.
83 | inverse: bool. Whether to apply inverse of the transformations provided in 'params'.
84 |
85 | Returns:
86 | ray_batch_obj: [R, M] float tensor. The ray batch, in object coordinate frame.
87 | """
88 | rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6]
89 | translation, rotmat = get_transformation_from_params(params) # [3,], [3, 3]
90 |
91 | if inverse:
92 | translation = -1 * translation # [3,]
93 | rotmat = torch.transpose(rotmat, 1, 0) # [3, 3]
94 |
95 | translation_inverse = -1 * translation
96 | rotmat_inverse = torch.transpose(rotmat, 1, 0)
97 |
98 | # Transform the ray origin.
99 | rays_o_obj, _ = box_utils.ray_to_box_coordinate_frame_pairwise(
100 | box_center=translation_inverse,
101 | box_rotation_matrix=rotmat_inverse,
102 | rays_start_point=rays_o,
103 | rays_end_point=rays_d)
104 |
105 | # Only apply rotation to rays_d.
106 | rays_d_obj = rotate_dirs(rays_d, rotmat)
107 |
108 | ray_batch_obj = update_ray_batch_slice(ray_batch, rays_o_obj, 0, 3)
109 | ray_batch_obj = update_ray_batch_slice(ray_batch_obj, rays_d_obj, 3, 6)
110 | if use_viewdirs:
111 | # Grab viewdirs from the ray batch itself. Because it may be different from rays_d
112 | # (as in the staticcam case).
113 | viewdirs = ray_batch[:, 8:11]
114 | viewdirs_obj = rotate_dirs(viewdirs, rotmat)
115 | ray_batch_obj = update_ray_batch_slice(ray_batch_obj, viewdirs_obj, 8, 11)
116 | return ray_batch_obj
117 |
118 |
119 | def transform_points_into_world_coordinate_frame(pts, params, check_numerics=False):
120 | translation, rotmat = get_transformation_from_params(params) # [3,], [3, 3]
121 |
122 | # pts_flat = pts.view(-1, 3) # [RS, 3]
123 | # num_examples = pts_flat.size()[0] # RS
124 |
125 | # translation = translation.unsqueeze(0)
126 | # translation = torch.tile(translation, (num_examples, 1)) # [RS, 3]
127 | # rotmat = rotmat.unsqueeze(0)
128 | # rotmat = torch.tile(rotmat, (num_examples, 1, 1))
129 |
130 | # # pts_flat_transformed = torch.matmul(pts_flat[:, None, :], torch.transpose(rotmat, 2, 1)) # [RS, 1, 3]
131 | # pts_flat_transformed = pts_flat[:, None, :] # [RS, 1, 3]
132 | # pts_flat_transformed += translation[:, None, :] # [RS, 1, 3]
133 | # pts_transformed = pts_flat_transformed.view(pts.size()) # [R, S, 3]
134 | chunk = 256
135 | # Check batch transformations works without rotation.
136 | if check_numerics:
137 | transformations = np.eye(4)
138 | transformations[:3, 3] = translation
139 | transformations = torch.tensor(transformations, dtype=torch.float) # [4, 4]
140 | transformations = torch.tile(transformations[None, ...], (pts.size()[0], 1, 1)) # [R, 4, 4]
141 | pts_transformed1 = []
142 | for i in range(0, pts.size()[0], chunk):
143 | pts_transformed1_chunk = apply_batched_transformations(
144 | inputs=[pts[i:i+chunk]], transformations=transformations[i:i+chunk])[0]
145 | pts_transformed1.append(pts_transformed1_chunk)
146 | pts_transformed1 = torch.cat(pts_transformed1, dim=0)
147 |
148 | pts_transformed2 = pts + translation[None, None, :]
149 |
150 | # Now add rotation
151 | transformations = np.eye(4)
152 | transformations = torch.tensor(transformations, dtype=torch.float)
153 | transformations[:3, :3] = rotmat
154 | transformations[:3, 3] = translation
155 | #transformations = torch.tensor(transformations, dtype=torch.float) # [4, 4]
156 | transformations = torch.tile(transformations[None, ...], (pts.size()[0], 1, 1)) # [R, 4, 4]
157 | pts_transformed = []
158 | for i in range(0, pts.size()[0], chunk):
159 | pts_transformed_chunk = apply_batched_transformations(
160 | inputs=[pts[i:i+chunk]], transformations=transformations[i:i+chunk])[0]
161 | pts_transformed.append(pts_transformed_chunk)
162 | pts_transformed = torch.cat(pts_transformed, dim=0)
163 | return pts_transformed
164 |
165 |
166 | # def transform_rays(ray_batch, translation, use_viewdirs):
167 | # """Apply transformation to rays.
168 |
169 | # Args:
170 | # ray_batch: [R, M] float tensor. All information necessary
171 | # for sampling along a ray, including: ray origin, ray direction, min
172 | # dist, max dist, and unit-magnitude viewing direction.
173 | # translation: [3,] float tensor. The (x, y, z) translation to apply.
174 | # use_viewdirs: Whether to use view directions.
175 |
176 | # Returns:
177 | # ray_batch: [R, M] float tensor. Transformed ray batch.
178 | # """
179 | # assert translation.size()[0] == 3, "translation.size()[0] must be 3..."
180 |
181 | # # Since we are only supporting translation for now, only ray origins need to be
182 | # # modified. Ray directions do not need to change.
183 | # rays_o = ray_batch[:, 0:3] + translation
184 | # rays_remaining = ray_batch[:, 3:]
185 | # ray_batch = torch.cat((rays_o, rays_remaining), dim=1)
186 | # return ray_batch
187 |
188 | def compute_rays_length(rays_d):
189 | """Compute ray length.
190 |
191 | Args:
192 | rays_d: [R, 3] float tensor. Ray directions.
193 |
194 | Returns:
195 | rays_length: [R, 1] float tensor. Ray lengths.
196 | """
197 | rays_length = torch.norm(rays_d, dim=-1, keepdim=True) # [N_rays, 1]
198 | return rays_length
199 |
200 |
201 | def normalize_rays(rays):
202 | """Normalize ray directions.
203 |
204 | Args:
205 | rays: [R, 3] float tensor. Ray directions.
206 |
207 | Returns:
208 | normalized_rays: [R, 3] float tensor. Normalized ray directions.
209 | """
210 | normalized_rays = rays / compute_rays_length(rays_d=rays)
211 | return normalized_rays
212 |
213 |
214 | def compute_ray_dirs_and_length(rays_o, rays_dst):
215 | """Compute ray directions.
216 |
217 | Args:
218 | rays_o: [R, 3] float tensor. Ray origins.
219 | rays_dst: [R, 3] float tensor. Ray destinations.
220 |
221 | Returns:
222 | rays_d: [R, 3] float tensor. Normalized ray directions.
223 | """
224 | # The ray directions are the difference between the ray destinations and the
225 | # ray origins.
226 | rays_d = rays_dst - rays_o # [R, 3] # Direction out of light source
227 |
228 | # Compute the length of the rays.
229 | rays_length = compute_rays_length(rays_d=rays_d)
230 |
231 | # Normalized the ray directions.
232 | rays_d = rays_d / rays_length # [R, 3] # Normalize direction
233 | return rays_d, rays_length
234 |
235 |
236 | def update_ray_batch_slice(ray_batch, x, start, end):
237 | left = ray_batch[:, :start] # [R, ?]
238 | right = ray_batch[:, end:] # [R, ?]
239 | updated_ray_batch = torch.cat((left, x, right), dim=-1)
240 | return updated_ray_batch
241 |
242 |
243 | def update_ray_batch_bounds(ray_batch, bounds):
244 | updated_ray_batch = update_ray_batch_slice(ray_batch=ray_batch, x=bounds,
245 | start=6, end=8)
246 | return updated_ray_batch
247 |
248 |
249 | def create_ray_batch(
250 | rays_o, rays_dst, rays_i, use_viewdirs, rays_near=None, rays_far=None, epsilon=1e-10):
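  """Assembles a flat ray batch tensor.

  Assumed column layout (matching the slices used elsewhere in this codebase):
  0:3 rays_o, 3:6 normalized rays_d, 6:7 near, 7:8 far, 8:11 viewdirs (only when
  use_viewdirs is True), and the final column is rays_i. If near/far are not
  given, the segment from each origin to its destination is used as the sampling
  interval, with a small epsilon at the near end.
  """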
251 | # Compute the ray directions.
252 | rays_d = rays_dst - rays_o # [R,3] # Direction out of light source
253 | rays_length = compute_rays_length(rays_d=rays_d) # [R, 1]
254 | rays_d = rays_d / rays_length # [R, 3] # Normalize direction
255 | viewdirs = rays_d # [R, 3]
256 |
257 | # If bounds are not provided, set the beginning and end of ray as sampling bounds.
258 | if rays_near is None:
259 | rays_near = torch.zeros((rays_o.size()[0], 1), dtype=torch.float) + epsilon # [R, 1]
260 | if rays_far is None:
261 | rays_far = rays_length # [R, 1]
262 |
263 | ray_batch = torch.cat((rays_o, rays_d, rays_near, rays_far), dim=-1)
264 | if use_viewdirs:
265 | ray_batch = torch.cat((ray_batch, viewdirs), dim=-1)
266 | ray_batch = torch.cat((ray_batch, rays_i), dim=-1)
267 | return ray_batch
268 |
269 |
270 | def sample_random_lightdirs(num_rays, num_samples, upper_only=False):
271 | """Randomly sample directions in the unit sphere.
272 |
273 | Args:
274 | num_rays: int or tensor shape dimension. Number of rays.
275 | num_samples: int or tensor shape dimension. Number of samples per ray.
276 | upper_only: bool. Whether to sample only on the upper hemisphere.
277 |
278 | Returns:
279 | lightdirs: [R, S, 3] float tensor. Random light directions sampled from the unit
280 | sphere for each sampled point.
281 | """
282 | if upper_only:
283 | min_z = 0
284 | else:
285 | min_z = -1
286 |
287 | phi = torch.rand(num_rays, num_samples) * (2 * math.pi) # [R, S]
288 | cos_theta = torch.rand(num_rays, num_samples) * (1 - min_z) + min_z # [R, S]
289 | theta = torch.acos(cos_theta) # [R, S]
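  # Note: sampling cos(theta) uniformly (instead of theta) yields directions that
  # are uniformly distributed over the sphere's surface area.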
290 |
291 | x = torch.sin(theta) * torch.cos(phi)
292 | y = torch.sin(theta) * torch.sin(phi)
293 | z = torch.cos(theta)
294 |
295 | lightdirs = torch.cat((x[..., None], y[..., None], z[..., None]), dim=-1) # [R, S, 3]
296 | return lightdirs
297 |
298 |
299 | def get_light_positions(rays_i, img_light_pos):
300 | """Extracts light positions given scene IDs.
301 |
302 | Args:
303 | rays_i: [R, 1] float tensor. Per-ray image IDs.
304 | img_light_pos: [N, 3] float tensor. Per-image light positions.
305 |
306 | Returns:
307 | rays_light_pos: [R, 3] float tensor. Per-ray light positions.
308 | """
309 | rays_light_pos = img_light_pos[rays_i.long()].squeeze() # [R, 3]
310 | return rays_light_pos
311 |
312 |
313 | def get_lightdirs(lightdirs_method, num_rays=None, num_samples=None, rays_i=None,
314 | metadata=None, ray_batch=None, use_viewdirs=False, normalize=False):
315 | """Compute lightdirs.
316 |
317 | Args:
318 | lightdirs_method: str. Method to use for computing lightdirs.
319 | num_rays: int or tensor shape dimension. Number of rays.
320 | num_samples: int or tensor shape dimension. Number of samples per ray.
321 | rays_i: [R, 1] float tensor. Ray image IDs.
322 | metadata: [N, 3] float tensor. Metadata about each image. Currently only light
323 | position is provided.
324 | ray_batch: [R, M] float tensor. Ray batch.
325 | use_viewdirs: bool. Whether to use viewdirs.
326 | normalize: bool. Whether to normalize lightdirs.
327 |
328 |   Returns:
329 | lightdirs: [R, S, 3] float tensor. Light directions for each sample.
330 | """
331 | if lightdirs_method == 'viewdirs':
332 | raise NotImplementedError
333 | assert use_viewdirs
334 | lightdirs = ray_batch[:, 8:11] # [R, 3]
335 | lightdirs *= 1.5
336 | lightdirs = torch.tile(lightdirs[:, None, :], (1, num_samples, 1))
337 | elif lightdirs_method == 'metadata':
338 | lightdirs = get_light_positions(rays_i, metadata) # [R, 3]
339 | lightdirs = torch.tile(lightdirs[:, None, :], (1, num_samples, 1)) # [R, S, 3]
340 | elif lightdirs_method == 'random':
341 | lightdirs = sample_random_lightdirs(num_rays, num_samples) # [R, S, 3]
342 | elif lightdirs_method == 'random_upper':
343 | lightdirs = sample_random_lightdirs(num_rays, num_samples, upper_only=True) # [R, S, 3]
344 | else:
345 | raise ValueError(f'Invalid lightdirs_method: {lightdirs_method}.')
346 | if normalize:
347 | lightdirs_flat = lightdirs.view(-1, 3) # [RS, 3]
348 | lightdirs_flat = normalize_rays(lightdirs_flat) # [RS, 3]
349 | lightdirs = lightdirs_flat.view(lightdirs.size()) # [R, S, 3]
350 | return lightdirs
351 |
--------------------------------------------------------------------------------
/ObjectFolder1.0/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.21
2 | torch==1.8.1
3 | scipy==1.6.3
4 | librosa==0.8.1
5 | matplotlib==3.4.2
6 | imageio==2.9.0
7 | tqdm==4.60.0
8 | torchvision==0.9.1
9 | imageio-ffmpeg==0.4.3
10 | configargparse
11 |
--------------------------------------------------------------------------------
/ObjectFolder1.0/scatter.py:
--------------------------------------------------------------------------------
1 | """Utility functions for scattering rays."""
2 | import torch
3 |
4 |
5 | def create_scatter_indices_for_dim(dim, shape, indices=None):
6 | """Create scatter indices for a given dimension."""
7 | dim_size = shape[dim]
8 | N_dims = len(shape)
9 | reshape = [1] * N_dims
10 | reshape[dim] = -1
11 |
12 | if indices is None:
13 | indices = torch.arange(dim_size, dtype=torch.int) # [dim_size,]
14 |
15 | indices = indices.view(reshape)
16 |
17 | indices = torch.broadcast_to(
18 | indices, shape) # [Ro, S, 1] or [Ro, S, C, 1] [0,1,1,1] vs. [512,64,1,1]
19 |
20 | indices = indices.int()
21 | return indices
22 |
23 |
24 | def create_scatter_indices(updates, dim2known_indices):
25 | """Create scatter indices."""
26 | updates_expanded = updates.unsqueeze(-1)
27 | target_shape = updates_expanded.size()
28 | n_dims = len(updates.size()) # 2 or 3
29 |
30 | dim_indices_list = []
31 | for dim in range(n_dims):
32 | indices = None
33 | if dim in dim2known_indices:
34 | indices = dim2known_indices[dim]
35 | dim_indices = create_scatter_indices_for_dim( # [Ro, S, C, 1]
36 | dim=dim,
37 | shape=target_shape, # [Ro, S, 1] or [Ro, S, C, 1]
38 | indices=indices) # [Ro,]
39 | dim_indices_list.append(dim_indices)
40 | scatter_indices = torch.cat((dim_indices_list), dim=-1) # [Ro, S, C, 3]
41 | return scatter_indices
42 |
43 |
44 | def scatter_nd(tensor, updates, dim2known_indices):
45 | scatter_indices = create_scatter_indices( # [Ro, S, C, 3]
46 | updates=updates, # [Ro, S]
47 | dim2known_indices=dim2known_indices) # [Ro,]
48 |     # Scatter the updates into the target tensor at the computed indices. The
49 |     # index tensor has one column per dimension of 'tensor', so flatten it and
50 |     # use advanced indexing to write the updates in place.
51 |     n_index_dims = scatter_indices.size()[-1]
52 |     flat_indices = scatter_indices.view(-1, n_index_dims).long()
53 |     index_columns = tuple(flat_indices[:, d] for d in range(n_index_dims))
54 |     tensor[index_columns] = updates.view(-1)
55 |     return tensor
56 |
57 |
58 | def scatter_results(intersect, indices, N_rays, keys, N_samples, N_importance=None):
59 | """Scatters intersecting ray results into the original set of rays.
60 |
61 | Args:
62 | intersect: Dict. Values are tensors of intersecting rays which can be any of the
63 | following.
64 | z_vals: [R?, S]
65 | pts: [R?, S, 3]
66 | rgb: [R?, S, 3]
67 | raw: [R?, S, 4]
68 | alpha: [R?, S, 1]
69 | indices: [R?, 1] int tensor. Intersecting ray indices to scatter back to the full set
70 | of rays.
71 | N_rays: int or int tensor. Total number of rays.
72 | keys: [str]. List of keys from the 'intersect' dictionary to scatter.
73 | N_samples: [int]. Number of samples.
74 | N_importance: [int]. Number of importance (fine) samples.
75 |
76 | Returns:
77 | scattered_results: Dict. Scattered results, where each value is of shape [R, ...].
78 | The original intersecting ray results are padding with samples with zero density,
79 | so they won't have any contribution to the final render.
80 | """
81 | # We use 'None' to indicate that the intersecting set of rays is equivalent to
82 |   # the full set of rays, so we are done.
83 | if indices is None:
84 | return {k: intersect[k] for k in keys}
85 |
86 | scattered_results = {}
87 | # N_samples = intersect['z_vals'].shape[1]
88 | dim2known_indices = {0: indices} # [R', 1]
89 | for k in keys:
90 | if k == 'z_vals':
91 | # tensor = torch.rand((N_rays, N_samples), dtype=torch.float)
92 | tensor = torch.arange(N_samples) # [S,]
93 | tensor = tensor.float()
94 | tensor = torch.stack([tensor] * N_rays) # [R, S]
95 | elif k == 'z_samples':
96 | # tensor = torch.rand((N_rays, N_samples), dtype=torch.float) #[R, S]
97 | # N_importance = intersect[k].size()[1]
98 | tensor = torch.arange(N_importance) # [I,]
99 | tensor = tensor.float()
100 | tensor = torch.stack([tensor] * N_rays) # [R, I]
101 | elif k == 'raw':
102 | tensor = torch.full((N_rays, N_samples, 4), 1000.0, dtype=torch.float) # [R, S, 4]
103 | elif k == 'pts':
104 | tensor = torch.full((N_rays, N_samples, 3), 1000.0, dtype=torch.float) # [R, S, 3]
105 | elif 'rgb' in k:
106 |       tensor = torch.zeros((N_rays, N_samples, 3), dtype=torch.float) # [R, S, 3]
107 | elif 'alpha' in k:
108 | tensor = torch.zeros((N_rays, N_samples, 1), dtype=torch.float) # [R, S, 1]
109 | else:
110 | raise ValueError(f'Invalid key: {k}')
111 | # No intersections to scatter.
112 | if len(indices) == 0:
113 | scattered_results[k] = tensor
114 | else:
115 | scattered_v = scatter_nd( # [R, S, K]
116 | tensor=tensor,
117 | updates=intersect[k], # [Ro, S]
118 | dim2known_indices=dim2known_indices)
119 | # Convert the batch dimension to a known dimension.
120 | # For some reason 'scattered_z_vals' becomes [R, ?]. We need to explicitly
121 |       # reshape it with 'N_samples'.
122 | if k == 'z_samples':
123 | scattered_v = scattered_v.view(N_rays, N_importance) # [R, I]
124 | else:
125 | if k == 'z_vals':
126 | scattered_v = scattered_v.view(N_rays, N_samples) # [R, S]
127 | else:
128 | # scattered_v = scattered_v.view((N_rays,) + scattered_v.size()[1:]) # [R, S, K]
129 | # scattered_v = scattered_v.view((-1,) + scattered_v.size()[1:]) # [R, S, K]
130 | scattered_v = scattered_v.view(N_rays, N_samples, tensor.size()[2]) # [R, S, K]
131 | scattered_results[k] = scattered_v
132 | return scattered_results
133 |
134 |
135 | def scatter_coarse_and_fine(
136 | ret0, ret, indices, num_rays, N_samples, N_importance, **kwargs):
137 | # TODO: only process raw if retraw=True.
138 | ret0 = scatter_results(
139 | intersect=ret0,
140 | indices=indices,
141 | N_rays=num_rays,
142 | keys=['z_vals', 'rgb', 'alpha', 'raw'],
143 | N_samples=N_samples)
144 | ret = scatter_results(
145 | intersect=ret,
146 | indices=indices,
147 | N_rays=num_rays,
148 |         keys=['z_vals', 'rgb', 'alpha', 'raw', 'z_samples'],
149 | N_samples=N_samples + N_importance,
150 | N_importance=N_importance)
151 | return ret0, ret
152 |
--------------------------------------------------------------------------------
/ObjectFolder1.0/shadow_utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions for shadows."""
2 | import torch
3 |
4 | import ray_utils
5 | import run_osf_helpers
6 |
7 |
8 | def create_ray_batch(ray_batch, pts, metadata, use_viewdirs, lightdirs_method):
9 | """Create batch for shadow rays.
10 |
11 | Args:
12 | ray_batch: [R?, M] float tensor. Batch of primary rays
13 | pts: [R?, S, 3] float tensor. Primary points.
14 | metadata: [N, 3] float tensor. Metadata about each image. Currently only light
15 | position is provided.
16 | use_viewdirs: bool. Whether to use view directions.
17 |
18 | Returns:
19 | shadow_ray_batch: [R?S, M] float tensor. Batch of shadow rays containing one shadow
20 | ray for each primary ray sample. Each shadow ray originates at a primary point
21 | and points in the direction towards the light source.
22 | """
23 | num_primary_rays = pts.size()[0] # R
24 | num_primary_samples = pts.size()[1] # S
25 |
26 | # Samples are shadow ray origins.
27 | rays_o = pts.view(-1, 3) # [R?S, 3]
28 |
29 | num_shadow_rays = rays_o.size()[0] # R?S
30 | num_samples = pts.size()[1] # S
31 |
32 | rays_i = ray_batch[:, 11:12] # [R, 1]
33 |
34 | # Get light positions for each ray as the ray destinations.
35 | rays_dst = ray_utils.get_lightdirs( # [R?, S, 3]
36 | lightdirs_method=lightdirs_method, num_rays=num_primary_rays,
37 | num_samples=num_primary_samples, rays_i=rays_i, metadata=metadata,
38 | ray_batch=ray_batch, use_viewdirs=use_viewdirs)
39 | rays_dst = rays_dst.view(rays_o.size()) # [R?S, 3]
40 |
41 | rays_i = torch.tile(rays_i.unsqueeze(1), (1, num_primary_samples, 1)) # [R?, S, 1]
42 | rays_i = rays_i.view(-1, 1) # [R?S, 1]
43 |
44 | shadow_ray_batch = ray_utils.create_ray_batch(rays_o, rays_dst, rays_i, use_viewdirs)
45 | return shadow_ray_batch
46 |
47 |
48 | def compute_transmittance(alpha):
49 | """Applies shadows to outputs.
50 | Args:
51 |         alpha: [R?S, S, 1] float tensor. Alpha predictions from the model.
52 |     Returns:
53 |         last_trans: [R?S,] float tensor. Shadow transmittance per ray.
54 | """
55 |     trans = run_osf_helpers.compute_transmittance(alpha=alpha[..., 0]) # [R?S, S]
56 |
57 | # Transmittance is computed in the direction origin -> end, so we grab the
58 | # last transmittance.
59 | last_trans = trans[:, -1] # [R?S,]
60 | return last_trans
61 |
--------------------------------------------------------------------------------
/ObjectFolder2.0_teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/ObjectFolder2.0_teaser.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## ObjectFolder 2.0
2 | ObjectFolder 2.0: A Multisensory Object Dataset for Sim2Real Transfer (CVPR 2022)
3 | [[Project Page]](https://ai.stanford.edu/~rhgao/objectfolder2.0/) [[arXiv]](https://arxiv.org/abs/2204.02389)
4 |
5 |
6 |
7 |
8 |
9 | [ObjectFolder 2.0: A Multisensory Object Dataset for Sim2Real Transfer](https://ai.stanford.edu/~rhgao/objectfolder2.0/)
10 | [Ruohan Gao*](https://www.ai.stanford.edu/~rhgao/), [Zilin Si*](https://si-lynnn.github.io/), [Yen-Yu Chang*](https://yuyuchang.github.io/), [Samuel Clarke](https://samuelpclarke.com/), [Jeannette Bohg](https://web.stanford.edu/~bohg/), [Li Fei-Fei](https://profiles.stanford.edu/fei-fei-li), [Wenzhen Yuan](http://robotouch.ri.cmu.edu/yuanwz/), [Jiajun Wu](https://jiajunwu.com/)
11 | Stanford University, Carnegie Mellon University
12 | In Conference on Computer Vision and Pattern Recognition (**CVPR**), 2022
13 |
14 |
15 |
16 | If you find our dataset, code or project useful in your research, we appreciate it if you could cite:
17 |
18 | @inproceedings{gao2022ObjectFolderV2,
19 | title = {ObjectFolder 2.0: A Multisensory Object Dataset for Sim2Real Transfer},
20 | author = {Gao, Ruohan and Si, Zilin and Chang, Yen-Yu and Clarke, Samuel and Bohg, Jeannette and Fei-Fei, Li and Yuan, Wenzhen and Wu, Jiajun},
21 | booktitle = {CVPR},
22 | year = {2022}
23 | }
24 |
25 | @inproceedings{gao2021ObjectFolder,
26 | title = {ObjectFolder: A Dataset of Objects with Implicit Visual, Auditory, and Tactile Representations},
27 | author = {Gao, Ruohan and Chang, Yen-Yu and Mall, Shivani and Fei-Fei, Li and Wu, Jiajun},
28 | booktitle = {CoRL},
29 | year = {2021}
30 | }
31 |
32 | ### About ObjectFolder Dataset
33 |
34 | ObjectFolder 2.0 is a dataset of 1,000 objects in the form of implicit representations. It contains 1,000 Object Files, each containing the complete multisensory profile for an object instance. Each Object File implicit neural representation network contains three sub-networks (VisionNet, AudioNet, and TouchNet), which can be queried with the corresponding extrinsic parameters to obtain the visual appearance of the object from different views and lighting conditions, its impact sounds at any surface position for a specified force profile, and its tactile readings at every surface location for varied gel rotation angles and pressing depths, respectively. The dataset contains common household objects of diverse categories such as wood desks, ceramic bowls, plastic toys, steel forks, glass mirrors, etc. The objects.csv file contains the metadata for these 1,000 objects. Note that the first 100 objects are the same as in ObjectFolder 1.0, and we recommend using the new version for improved multisensory simulation and implicit representations of the objects. See the paper for details.
35 |
36 |
37 |
38 | ### Prerequisites
39 | * OS: Ubuntu 20.04.2 LTS
40 | * GPU: >= NVIDIA GTX 1080 Ti with >= 460.73.01 driver
41 | * Python package manager `conda`
42 |
43 | ### Setup
44 | ```
45 | git clone https://github.com/rhgao/ObjectFolder.git
46 | cd ObjectFolder
47 | export OBJECTFOLDER_HOME=$PWD
48 | conda env create -f $OBJECTFOLDER_HOME/environment.yml
49 | source activate ObjectFolder-env
50 | ```
51 |
52 | ### Dataset Download and Preparation
53 | Use the following command to download the first 100 objects:
54 | ```
55 | wget https://download.cs.stanford.edu/viscam/ObjectFolder/ObjectFolder1-100.tar.gz
56 | tar -xvf ObjectFolder1-100.tar.gz
57 | ```
58 | Use the following command to download ObjectFiles with KiloOSF, which support real-time visual rendering at the expense of a larger model size:
59 | ```
60 | wget https://download.cs.stanford.edu/viscam/ObjectFolder/ObjectFolder1-100KiloOSF.tar.gz
61 | tar -xvf ObjectFolder1-100KiloOSF.tar.gz
62 | ```
63 | Similarly, use the following command to download objects from 101-1000:
64 | ```
65 | wget https://download.cs.stanford.edu/viscam/ObjectFolder/ObjectFolder[X+1]-[X+100].tar.gz
66 | wget https://download.cs.stanford.edu/viscam/ObjectFolder/ObjectFolder[X+1]-[X+100]KiloOSF.tar.gz
67 | tar -xvf ObjectFolder[X+1]-[X+100].tar.gz
68 | tar -xvf ObjectFolder[X+1]-[X+100]KiloOSF.tar.gz
69 | # replace X with a value in [100,200,300,400,500,600,700,800,900]
70 | ```
71 |
72 | ### Rendering Visual, Acoustic, and Tactile Sensory Data
73 | Run the following command to render the visual appearance of the object from a specified camera viewpoint and lighting direction:
74 | ```
75 | $ python OF_render.py --modality vision --object_file_path path_of_ObjectFile \
76 | --vision_test_file_path path_of_vision_test_file \
77 |                        --vision_results_dir path_of_vision_results_directory
78 | ```
79 | Run the following command to render impact sounds of the object at a specified surface location and impact force:
80 | ```
81 | $ python OF_render.py --modality audio --object_file_path path_of_ObjectFile \
82 | --audio_vertices_file_path path_of_audio_testing_vertices_file \
83 | --audio_forces_file_path path_of_forces_file \
84 |                        --audio_results_dir path_of_audio_results_directory
85 | ```
86 | Run the following command to render tactile RGB images of the object at a specified surface location, gel rotation angle, and deformation depth:
87 | ```
88 | $ python OF_render.py --modality touch --object_file_path path_of_ObjectFile \
89 | --touch_vertices_file_path path_of_touch_testing_vertices_file \
90 |                        --touch_gelinfo_file_path path_of_gelinfo_file \
91 |                        --touch_results_dir path_of_touch_results_directory
92 | ```
93 | The command-line arguments are described as follows:
94 | * `--object_file_path`: The path of the ObjectFile.
95 | * `--vision_test_file_path`: The path of the testing file for vision, which should be an .npy file.
96 | * `--vision_results_dir`: The path of the vision results directory to save rendered images.
97 | * `--audio_vertices_file_path`: The path of the testing vertices file for audio, which should be an .npy file.
98 | * `--audio_forces_file_path`: The path of the forces file for audio, which should be an .npy file.
99 | * `--audio_results_dir`: The path of the audio results directory to save rendered impact sounds as .wav files.
100 | * `--touch_vertices_file_path`: The path of the testing vertices file for touch, which should be an .npy file.
101 | * `--touch_gelinfo_file_path`: The path of the gelinfo file for touch that specifies the gel rotation angle and deformation depth, which should be an .npy file.
102 | * `--touch_results_dir`: The path of the touch results directory to save rendered tactile RGB images.
103 |
104 | ### Data Format
105 | * `--vision_test_file_path`: An .npy file of shape (N, 6), where N is the number of testing viewpoints. Each data point contains the coordinates of the camera and the light in the form of (camera_x, camera_y, camera_z, light_x, light_y, light_z).
106 | * `--audio_vertices_file_path`: An .npy file of shape (N, 3), where N is the number of testing vertices. Each data point represents a coordinate on the object in the form of (x, y, z).
107 | * `--audio_forces_file_path`: An .npy file of shape (N, 3), where N is the number of testing vertices. Each data point represents the force values for the corresponding impact in the form of (F_x, F_y, F_z).
108 | * `--touch_vertices_file_path`: An .npy file of shape (N, 3), where N is the number of testing vertices. Each data point contains a coordinate on the object in the form of (x, y, z).
109 | * `--touch_gelinfo_file_path`: An .npy file of shape (N, 3), where N is the number of testing vertices. Each data point contains the gel rotation angles and gel displacement in the form of (theta, phi, depth), where theta is in the range [0, np.radians(15)], phi is in the range [0, 2pi], and depth is in the range [0.0005, 0.002]. A minimal sketch for generating such files is shown after this list.
110 |
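As a minimal sketch, test inputs with the shapes documented above can be created with NumPy. The file names and query values below are arbitrary placeholders; in practice the vertices should lie on the object surface, e.g. sampled from the object mesh:
```
import numpy as np

N = 4  # number of test queries

# vision: (camera_x, camera_y, camera_z, light_x, light_y, light_z) per row
np.save("my_vision_test.npy",
        np.random.uniform(-1.0, 1.0, size=(N, 6)).astype(np.float32))

# audio: surface coordinates (x, y, z) and impact forces (F_x, F_y, F_z)
np.save("my_audio_vertices.npy",
        np.random.uniform(-0.1, 0.1, size=(N, 3)).astype(np.float32))
np.save("my_audio_forces.npy",
        np.random.uniform(0.0, 1.0, size=(N, 3)).astype(np.float32))

# touch: surface coordinates plus (theta, phi, depth) gel parameters
np.save("my_touch_vertices.npy",
        np.random.uniform(-0.1, 0.1, size=(N, 3)).astype(np.float32))
theta = np.random.uniform(0.0, np.radians(15), size=(N, 1))
phi = np.random.uniform(0.0, 2 * np.pi, size=(N, 1))
depth = np.random.uniform(0.0005, 0.002, size=(N, 1))
np.save("my_touch_gelinfo.npy",
        np.concatenate([theta, phi, depth], axis=1).astype(np.float32))
```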
111 | ### ObjectFile with KiloOSF for Real-time Visual Rendering
112 | To use KiloOSF, copy the [cuda](https://github.com/creiser/kilonerf/tree/master/cuda) directory into the root directory of this repo and follow the steps in [CUDA extension installation](https://github.com/creiser/kilonerf). Run the following command to render the visual appearance of the object from a specified camera viewpoint and lighting direction:
113 | ```
114 | $ python OF_render.py --modality vision --KiloOSF --object_file_path path_of_KiloOSF_ObjectFile \
115 | --vision_test_file_path path_of_vision_test_file \
116 |                        --vision_results_dir path_of_vision_results_directory
117 | ```
118 |
119 | ### Demo
120 | Below we show an example of rendering the visual, acoustic, and tactile data from the ObjectFile implicit representation for one object:
121 | ```
122 | $ python OF_render.py --modality vision,audio,touch --object_file_path demo/ObjectFile.pth \
123 | --vision_test_file_path demo/vision_demo.npy \
124 | --vision_results_dir demo/vision_results/ \
125 | --audio_vertices_file_path demo/audio_demo_vertices.npy \
126 | --audio_forces_file_path demo/audio_demo_forces.npy \
127 | --audio_results_dir demo/audio_results/ \
128 | --touch_vertices_file_path demo/touch_demo_vertices.npy \
129 | --touch_gelinfo_file_path demo/touch_demo_gelinfo.npy \
130 | --touch_results_dir demo/touch_results/
131 | ```
132 | The rendered images, impact sounds, and tactile images will be saved in `demo/vision_results/`, `demo/audio_results/`, and `demo/touch_results/`, respectively.
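A quick way to inspect the demo outputs is sketched below. This is only a hedged example: the exact file names and image format are whatever `OF_render.py` writes, so we simply glob the result directories (imageio and librosa are already installed via `environment.yml`):
```
import glob
import imageio
import librosa

for img_path in sorted(glob.glob("demo/vision_results/*")):
    try:
        print(img_path, imageio.imread(img_path).shape)
    except Exception:
        pass  # skip anything that is not an image

for wav_path in sorted(glob.glob("demo/audio_results/*.wav")):
    y, sr = librosa.load(wav_path, sr=None)
    print(wav_path, "samples:", y.shape[0], "sample rate:", sr)
```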
133 |
134 | ### License
135 | ObjectFolder is CC BY 4.0 licensed, as found in the LICENSE file. The mesh files for the 1,000 high-quality 3D objects originally come from online repositories including the [ABO dataset](https://amazon-berkeley-objects.s3.amazonaws.com/index.html), [3D Model Haven](https://3dmodelhaven.com/), the [YCB dataset](http://ycb-benchmarks.s3-website-us-east-1.amazonaws.com/), and [Google Scanned Objects](https://app.ignitionrobotics.org/GoogleResearch/fuel/collections/Google\%20Scanned\%20Objects). Please also refer to their original license files. We appreciate their sharing of these great object assets.
136 |
--------------------------------------------------------------------------------
/TouchNet_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | torch.autograd.set_detect_anomaly(True)
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 | from TouchNet_utils import *
7 |
8 | class DenseLayer(nn.Linear):
9 | def __init__(self, in_dim: int, out_dim: int, activation: str = 'relu', *args, **kwargs) -> None:
10 | self.activation = activation
11 | super().__init__(in_dim, out_dim, *args, **kwargs)
12 |
13 | def reset_parameters(self) -> None:
14 | torch.nn.init.xavier_uniform_(self.weight, gain=torch.nn.init.calculate_gain(self.activation))
15 | if self.bias is not None:
16 | torch.nn.init.zeros_(self.bias)
17 |
18 | class Embedder:
19 | def __init__(self, **kwargs):
20 | self.kwargs = kwargs
21 | self.create_embedding_fn()
22 |
23 | def create_embedding_fn(self):
24 | embed_fns = []
25 | d = self.kwargs['input_dims']
26 | out_dim = 0
27 | if self.kwargs['include_input']:
28 | embed_fns.append(lambda x: x)
29 | out_dim += d
30 |
31 | max_freq = self.kwargs['max_freq_log2']
32 | N_freqs = self.kwargs['num_freqs']
33 |
34 | if self.kwargs['log_sampling']:
35 | freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
36 | else:
37 | freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
38 |
39 | for freq in freq_bands:
40 | for p_fn in self.kwargs['periodic_fns']:
41 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
42 | out_dim += d
43 |
44 | self.embed_fns = embed_fns
45 | self.out_dim = out_dim
46 |
47 | def embed(self, inputs):
48 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
49 |
50 | def get_embedder(multires, i=0):
51 | if i == -1:
52 | #9 is for x, y, z, theta, phi_x, phi_y, displacement, w, h
53 | return nn.Identity(), 9
54 |
55 | embed_kwargs = {
56 | 'include_input': True,
57 | 'input_dims': 9,
58 | 'max_freq_log2': multires-1,
59 | 'num_freqs': multires,
60 | 'log_sampling': True,
61 | 'periodic_fns': [torch.sin, torch.cos],
62 | }
63 |
64 | embedder_obj = Embedder(**embed_kwargs)
65 | embed = lambda x, eo=embedder_obj: eo.embed(x)
66 | return embed, embedder_obj.out_dim
67 |
68 | class AudioNeRF(nn.Module):
69 | def __init__(self, D=8, input_ch=5):
70 | super(AudioNeRF, self).__init__()
71 | self.model_x = NeRF(D = D, input_ch = input_ch)
72 | self.model_y = NeRF(D = D, input_ch = input_ch)
73 | self.model_z = NeRF(D = D, input_ch = input_ch)
74 |
75 | def forward(self, embedded_x, embedded_y, embedded_z):
76 | results_x = self.model_x(embedded_x)
77 | results_y = self.model_y(embedded_y)
78 | results_z = self.model_z(embedded_z)
79 | return results_x, results_y, results_z
80 |
81 | class NeRF(nn.Module):
82 | def __init__(self, D=8, W=256, input_ch=5, input_ch_views=0, output_ch=2, skips=[4], use_viewdirs=False):
83 |         """NeRF-style MLP backbone."""
84 |
85 | super(NeRF, self).__init__()
86 | self.D = D
87 | self.W = W
88 | self.input_ch = input_ch
89 | self.input_ch_views = input_ch_views
90 | self.skips = skips
91 | self.use_viewdirs = use_viewdirs
92 |
93 | self.pts_linears = nn.ModuleList(
94 | [DenseLayer(input_ch, W, activation='relu')] + [DenseLayer(W, W, activation='relu') if i not in self.skips else DenseLayer(W + input_ch, W, activation='relu') for i in range(D-1)])
95 |
96 | self.views_linears = nn.ModuleList([DenseLayer(input_ch_views + W, W//2, activation='relu')])
97 |
98 | if use_viewdirs:
99 | self.feature_linear = DenseLayer(W, W, activation='sigmoid')
100 | #self.alpha_linear = DenseLayer(W, 1, activation='linear')
101 | self.rgb_linear = DenseLayer(W//2, output_ch, activation='sigmoid')
102 | else:
103 | self.output_linear = DenseLayer(W, output_ch, activation='sigmoid')
104 |
105 | def forward(self, x):
106 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
107 | h = input_pts
108 | for i, l in enumerate(self.pts_linears):
109 | h = self.pts_linears[i](h)
110 | h = F.relu(h)
111 | if i in self.skips:
112 | h = torch.cat([input_pts, h], -1)
113 |
114 | if self.use_viewdirs:
115 | feature = self.feature_linear(h)
116 | h = torch.cat([feature, input_views], -1)
117 |
118 | for i, l in enumerate(self.views_linears):
119 | h = self.views_linears[i](h)
120 | h = F.relu(h)
121 |
122 | outputs = self.rgb_linear(h)
123 | else:
124 | outputs = self.output_linear(h)
125 |
126 | return outputs
127 |
--------------------------------------------------------------------------------
/TouchNet_utils.py:
--------------------------------------------------------------------------------
1 | from scipy.io import wavfile
2 | import librosa
3 | import librosa.display
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | from TouchNet_model import *
7 | import os, shutil  # shutil is used by mkdirs() below
8 | from collections import OrderedDict
9 | from torch._six import container_abcs, string_classes, int_classes
10 |
11 | def strip_prefix_if_present(state_dict, prefix):
12 | keys = sorted(state_dict.keys())
13 | if not all(key.startswith(prefix) for key in keys):
14 | return state_dict
15 | stripped_state_dict = OrderedDict()
16 | for key, value in state_dict.items():
17 | stripped_state_dict[key.replace(prefix, "")] = value
18 | return stripped_state_dict
19 |
20 | def mkdirs(path, remove=False):
21 | if os.path.isdir(path):
22 | if remove:
23 | shutil.rmtree(path)
24 | else:
25 | return
26 | os.makedirs(path)
27 |
28 | def generate_spectrogram_magphase(audio, stft_frame, stft_hop, n_fft, with_phase=False):
29 | spectro = librosa.core.stft(audio, hop_length=stft_hop, n_fft=n_fft, win_length=stft_frame, center=True)
30 | spectro_mag, spectro_phase = librosa.core.magphase(spectro)
31 | spectro_mag = np.expand_dims(spectro_mag, axis=0)
32 | if with_phase:
33 | spectro_phase = np.expand_dims(np.angle(spectro_phase), axis=0)
34 | return spectro_mag, spectro_phase
35 | else:
36 | return spectro_mag
37 |
38 | def generate_spectrogram_complex(audio, stft_frame, stft_hop, n_fft):
39 | spectro = librosa.core.stft(audio, hop_length=stft_hop, n_fft=n_fft, win_length=stft_frame, center=True)
40 | real = np.expand_dims(np.real(spectro), axis=0)
41 | imag = np.expand_dims(np.imag(spectro), axis=0)
42 | spectro_two_channel = np.concatenate((real, imag), axis=0)
43 | return spectro_two_channel
44 |
45 | def batchify(fn, chunk):
46 | """
47 | Constructs a version of 'fn' that applies to smaller batches
48 | """
49 | if chunk is None:
50 | return fn
51 | def ret(inputs):
52 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
53 | return ret
54 |
55 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
56 | """
57 | Prepares inputs and applies network 'fn'.
58 | """
59 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
60 | embedded = embed_fn(inputs_flat)
61 |
62 | if viewdirs is not None:
63 | input_dirs = viewdirs[:,None].expand(inputs.shape)
64 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
65 | embedded_dirs = embeddirs_fn(input_dirs_flat)
66 | embedded = torch.cat([embedded, embedded_dirs], -1)
67 |
68 | outputs_flat = batchify(fn, netchunk)(embedded)
69 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
70 | return outputs
71 |
72 | def create_nerf(args):
73 | """
74 | Instantiate NeRF's MLP model.
75 | """
76 | embed_fn, input_ch = get_embedder(args.multires, args.i_embed)
77 |
78 | input_ch_views = 0
79 | embeddirs_fn = None
80 | if args.use_viewdirs:
81 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)
82 | output_ch = 2
83 | skips = [4]
84 | model = NeRF(D=args.netdepth, W=args.netwidth,
85 | input_ch=input_ch, output_ch=output_ch, skips=skips,
86 | input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs)
87 | model = nn.DataParallel(model).to(device)
88 | grad_vars = list(model.parameters())
89 |
90 | def object_collate(batch):
91 | r"""Puts each data field into a tensor with outer dimension batch size"""
92 | #print batch
93 | elem_type = type(batch[0])
94 | if isinstance(batch[0], torch.Tensor):
95 | out = None
96 | return torch.stack(batch, 0, out=out)
97 | elif elem_type.__module__ == 'numpy':
98 | elem = batch[0]
99 | if elem_type.__name__ == 'ndarray':
100 | return torch.cat([torch.from_numpy(b) for b in batch], 0) #concatenate even if dimension differs
101 | #return object_collate([torch.from_numpy(b) for b in batch])
102 | if elem.shape == (): # scalars
103 | py_type = float if elem.dtype.name.startswith('float') else int
104 | return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
105 | elif isinstance(batch[0], float):
106 | return torch.tensor(batch, dtype=torch.float64)
107 | elif isinstance(batch[0], int_classes):
108 | return torch.tensor(batch)
109 | elif isinstance(batch[0], string_classes):
110 | return batch
111 | elif isinstance(batch[0], container_abcs.Mapping):
112 | return {key: object_collate([d[key] for d in batch]) for key in batch[0]}
113 | elif isinstance(batch[0], container_abcs.Sequence):
114 | transposed = zip(*batch)
115 | return [object_collate(samples) for samples in transposed]
116 |
117 | raise TypeError((error_msg_fmt.format(type(batch[0]))))
118 |
--------------------------------------------------------------------------------
/VisionNet_configs.py:
--------------------------------------------------------------------------------
1 | cfg = {
2 | # 'dataset_dir': '/viscam/u/yenyu/ObjectFolderV2/objects/23/VisionNet_osf_data/',
3 | # 'pretrained_cfg_path': '/viscam/u/yenyu/kiloosf2/cfgs/paper/pretrain/23.yaml',
4 | # 'pretrained_checkpoint_path': '/viscam/u/yenyu/kiloosf2/logs/paper/pretrain/23/checkpoint_0200000.pth',
5 | 'blender_half_res': False,
6 | 'dataset_type': 'osf',
7 | 'discovery': {
8 | 'alpha_distance': 0.0211,
9 | 'alpha_rgb_initalization': 'pass_actual_nonlinearity',
10 | 'bias_initialization_method': 'standard',
11 | 'convert_density_to_alpha': True,
12 | 'direction_layer_size': 64,
13 | 'equal_split_metric': 'mse',
14 | 'hidden_layer_size': 64,
15 | 'iterations': 30000,
16 | 'late_feed_direction': True,
17 | 'max_num_networks': 512,
18 | 'network_rng_seed': 8078673,
19 | 'nonlinearity_initalization': 'pass_actual_nonlinearity',
20 | 'num_examples_per_network': 1020000,
21 | 'num_frequencies': 10,
22 | 'num_frequencies_direction': 4,
23 | 'num_frequencies_light': 4,
24 | 'num_hidden_layers': 2,
25 | 'num_train_examples_per_network': 1000000,
26 | 'outputs': 'color_and_density',
27 | 'quantile_se': 0.99,
28 | 'query_batch_size': 80000,
29 | 'refeed_position_index': 'None',
30 | 'test_batch_size': 512,
31 | 'test_error_metric': 'quantile_se',
32 | 'test_every': 500, 'train_batch_size': 128,
33 | 'use_same_initialization_for_all_networks': True,
34 | 'weight_initialization_method': 'kaiming_uniform'
35 | },
36 | 'fixed_resolution': [16, 16, 16],
37 | 'max_error': 100000,
38 | 'performance_monitoring': False,
39 | 'render_only': True,
40 | 'restart_after_checkpoint': True,
41 | 'skip_final': True,
42 | 'tree_type': 'kdtree_longest',
43 | 'global_domain_min': [-0.6, -0.6, -0.6],
44 | 'global_domain_max': [0.6, 0.6, 0.6],
45 | 'alpha_distance': 0.0211,
46 | 'alpha_rgb_initalization': 'pass_actual_nonlinearity',
47 | 'bias_initialization_method': 'standard',
48 | 'convert_density_to_alpha': True,
49 | 'direction_layer_size': 64,
50 | 'equal_split_metric': 'mse',
51 | 'hidden_layer_size': 64,
52 | 'iterations': 1000000,
53 | 'late_feed_direction': True,
54 | 'max_num_networks': 512,
55 | 'network_rng_seed': 8078673,
56 | 'nonlinearity_initalization': 'pass_actual_nonlinearity',
57 | 'num_examples_per_network': 1020000,
58 | 'num_frequencies': 10,
59 | 'num_frequencies_direction': 4,
60 | 'num_frequencies_light': 4,
61 | 'num_hidden_layers': 2,
62 | 'num_train_examples_per_network': 1000000,
63 | 'outputs': 'color_and_density',
64 | 'quantile_se': 0.99,
65 | 'query_batch_size': 80000,
66 | 'refeed_position_index': 'None',
67 | 'test_batch_size': 512,
68 | 'test_error_metric': 'quantile_se',
69 | 'test_every': 500,
70 | 'train_batch_size': 128,
71 | 'use_same_initialization_for_all_networks': False,
72 | 'weight_initialization_method': 'kaiming_uniform',
73 | # 'distilled_cfg_path': '/viscam/u/yenyu/kiloosf2/cfgs/paper/distill/23.yaml',
74 | # 'distilled_checkpoint_path': '/viscam/u/yenyu/kiloosf2/logs/paper/distill/23/checkpoint.pth',
75 | # 'occupancy_cfg_path': '/viscam/u/yenyu/kiloosf2/cfgs/paper/pretrain_occupancy/23.yaml',
76 | # 'occupancy_log_path': '/viscam/u/yenyu/kiloosf2/logs/paper/pretrain_occupancy/23/occupancy.pth',
77 | 'checkpoint_interval': 50000,
78 | 'chunk_size': 40000,
79 | 'initial_learning_rate': 0.001,
80 | 'l2_regularization_lambda': 1e-06,
81 | 'learing_rate_decay_rate': 500,
82 | 'no_batching': True,
83 | 'num_rays_per_batch': 8192,
84 | 'num_samples_per_ray': 384,
85 | 'perturb': 1.0,
86 | 'precrop_fraction': 0.5,
87 | 'precrop_iterations': 0,
88 | 'raw_noise_std': 0.0,
89 | 'near': 0.01,
90 | 'far': 4,
91 | 'spiral_radius': None,
92 | 'half_res': False,
93 | 'n_render_spiral': 40,
94 | 'render_spiral_angles': None,
95 | 'num_importance_samples_per_ray': 0,
96 | 'render_test': True,
97 | 'no_color_sigmoid': False,
98 | 'render_factor': 0,
99 | 'testskip': 1,
100 | 'deepvoxels_shape': 'greek',
101 | 'blender_white_background': True,
102 | 'llff_factor': 8,
103 | 'llff_no_ndc': False,
104 | 'llff_lindisp': False,
105 | 'llff_spherify': False,
106 | 'llff_hold': False,
107 | 'print_interval': 100,
108 | 'render_testset_interval': 100000,
109 | 'render_video_interval': 100000000,
110 | 'network_chunk_size': 65536,
111 | 'rng_seed': 0,
112 | 'use_initialization_fix': False,
113 | 'model_type': 'multi_network',
114 | 'random_direction_probability': -1,
115 | 'von_mises_kappa': -1,
116 | 'view_dependent_dropout_probability': -1,
117 | 'occupancy': {
118 | # 'dataset_dir': '/viscam/u/yenyu/ObjectFolderV2/objects/23/VisionNet_osf_data/',
119 | 'dataset_type': 'osf',
120 | # 'pretrained_cfg_path': '/viscam/u/yenyu/kiloosf2/cfgs/paper/pretrain/23.yaml',
121 | # 'pretrained_checkpoint_path': '/viscam/u/yenyu/kiloosf2/logs/paper/pretrain/23/checkpoint_0200000.pth',
122 | 'resolution': [256, 256, 256],
123 | 'subsample_resolution': [3, 3, 3],
124 | 'threshold': 10,
125 | 'voxel_batch_size': 16384,
126 | 'global_domain_min': [-0.6, -0.6, -0.6],
127 | 'global_domain_max': [0.6, 0.6, 0.6]
128 | }
129 | }
130 |
--------------------------------------------------------------------------------
/basics/CalibData.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | class CalibData:
3 | def __init__(self, dataPath):
4 | self.dataPath = dataPath
5 | data = np.load(dataPath)
6 |
7 | self.numBins = data['bins']
8 | self.grad_r = data['grad_r']
9 | self.grad_g = data['grad_g']
10 | self.grad_b = data['grad_b']
11 |
--------------------------------------------------------------------------------
/basics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/basics/__init__.py
--------------------------------------------------------------------------------
/basics/sensorParams.py:
--------------------------------------------------------------------------------
1 | ball_radius = 4.00 / 2
2 | pixmm = 0.1245  # 0.0295; # 0.0302
3 | numBins = 120
4 |
5 | # sensor setting
6 | h = 120
7 | w = 160
8 | cam2gel = 0.04
9 |
--------------------------------------------------------------------------------
/box_utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions for bounding box computation."""
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import ray_utils
6 |
7 | def ray_to_box_coordinate_frame_pairwise(box_center, box_rotation_matrix,
8 | rays_start_point, rays_end_point):
9 | """Moves a set of rays into a box's coordinate frame.
10 |
11 | Args:
12 | box_center: A tensor of size [3] or [r, 3].
13 | box_rotation_matrix: A tensor of size [3, 3] or [r, 3, 3].
14 | rays_start_point: A tensor of size [r, 3] where r is the number of rays.
15 |     rays_end_point: A tensor of size [r, 3] where r is the number of rays.
16 | 
17 |   Returns:
18 |     rays_start_point_in_box_frame: A tensor of size [r, 3].
19 |     rays_end_point_in_box_frame: A tensor of size [r, 3].
20 | """
21 | r = rays_start_point.size()[0]
22 | box_center = torch.broadcast_to(box_center, (r, 3))
23 | box_rotation_matrix = torch.broadcast_to(box_rotation_matrix, (r, 3, 3))
24 | rays_start_point_in_box_frame = torch.matmul(
25 | (rays_start_point - box_center).unsqueeze(1),
26 | box_rotation_matrix)
27 | rays_end_point_in_box_frame = torch.matmul(
28 | (rays_end_point - box_center).unsqueeze(1),
29 | box_rotation_matrix)
30 | return (rays_start_point_in_box_frame.view(-1, 3),
31 | rays_end_point_in_box_frame.view(-1, 3))
32 |
33 |
34 | def ray_box_intersection_pairwise(box_center,
35 | box_rotation_matrix,
36 | box_length,
37 | box_width,
38 | box_height,
39 | rays_start_point,
40 | rays_end_point,
41 | exclude_negative_t=False,
42 | exclude_enlarged_t=True,
43 | epsilon=0.000001):
44 | """Intersects a set of rays with a box.
45 |
46 | Note: The intersection points are returned in the box coordinate frame.
47 | Note: Make sure the start and end point of the rays are not the same.
48 | Note: Even though a start and end point is passed for each ray, rays are
49 | never ending and can intersect a box beyond their start / end points.
50 |
51 | Args:
52 | box_center: A tensor of size [3] or [r, 3].
53 | box_rotation_matrix: A tensor of size [3, 3] or [r, 3, 3].
54 | box_length: A scalar tensor or of size [r].
55 | box_width: A scalar tensor or of size [r].
56 | box_height: A scalar tensor or of size [r].
57 | rays_start_point: A tensor of size [r, 3] where r is the number of rays.
58 |     rays_end_point: A tensor of size [r, 3] where r is the number of rays.
59 | exclude_negative_t: bool.
60 | exclude_enlarged_t: bool.
61 | epsilon: A very small number.
62 |
63 | Returns:
64 | intersection_points_in_box_frame: A tensor of size [r', 2, 3]
65 | that contains intersection points in box coordinate frame.
66 | indices_of_intersecting_rays: A tensor of size [r'].
67 | intersection_ts: A tensor of size [r'].
68 | """
69 | r = rays_start_point.size()[0]
70 | box_length = box_length.expand(r)
71 | box_width = box_width.expand(r)
72 | box_height = box_height.expand(r)
73 | box_center = torch.broadcast_to(box_center, (r, 3))
74 | box_rotation_matrix = torch.broadcast_to(box_rotation_matrix, (r, 3, 3))
75 | rays_start_point_in_box_frame, rays_end_point_in_box_frame = (
76 | ray_to_box_coordinate_frame_pairwise(
77 | box_center=box_center,
78 | box_rotation_matrix=box_rotation_matrix,
79 | rays_start_point=rays_start_point,
80 | rays_end_point=rays_end_point))
81 | rays_a = rays_end_point_in_box_frame - rays_start_point_in_box_frame
82 | intersection_masks = []
83 | intersection_points = []
84 | intersection_ts = []
85 | box_size = [box_length, box_width, box_height]
86 | for axis in range(3):
87 | plane_value = box_size[axis] / 2.0
88 | for _ in range(2):
89 | plane_value = -plane_value
90 | # Compute the scalar multiples of 'rays_a' to apply in order to intersect
91 | # with the plane.
92 | t = ((plane_value - rays_start_point_in_box_frame[:, axis]) / # [R,]
93 | rays_a[:, axis])
94 | # The current axis only intersects with plane if the ray is not parallel
95 |       # with the plane. For parallel rays 't' becomes +/- infinity, because the ray
96 |       # component along this axis is zero, i.e. rays_a[:, axis] = 0.
97 | intersects_with_plane = torch.abs(rays_a[:, axis]) > epsilon
98 | if exclude_negative_t: # Only allow at most one negative t
99 | t = torch.maximum(t, torch.tensor(0.0)) # [R,]
100 | if exclude_enlarged_t:
101 | t = torch.maximum(t, torch.tensor(1.0)) # [R,]
102 | intersection_ts.append(t) # [R, 1]
103 | intersection_points_i = []
104 |
105 | # Initialize a mask which represents whether each ray intersects with the
106 | # current plane.
107 | intersection_masks_i = torch.ones_like(t, dtype=torch.int32).bool() # [R,]
108 | for axis2 in range(3):
109 | # Compute the point of intersection for the current axis.
110 | intersection_points_i_axis2 = ( # [R,]
111 | rays_start_point_in_box_frame[:, axis2] + t * rays_a[:, axis2])
112 | intersection_points_i.append(intersection_points_i_axis2) # 3x [R,]
113 |
114 | # Update the intersection mask depending on whether the intersection
115 | # point is within bounds.
116 | intersection_masks_i = torch.logical_and( # [R,]
117 | torch.logical_and(intersection_masks_i, intersects_with_plane),
118 | torch.logical_and(
119 | intersection_points_i_axis2 <= (box_size[axis2] / 2.0 + epsilon),
120 | intersection_points_i_axis2 >= (-box_size[axis2] / 2.0 - epsilon)))
121 | intersection_points_i = torch.stack(intersection_points_i, dim=1) # [R, 3]
122 | intersection_masks.append(intersection_masks_i) # List of [R,]
123 | intersection_points.append(intersection_points_i) # List of [R, 3]
124 | intersection_ts = torch.stack(intersection_ts, dim=1) # [R, 6]
125 | intersection_masks = torch.stack(intersection_masks, dim=1) # [R, 6]
126 | intersection_points = torch.stack(intersection_points, dim=1) # [R, 6, 3]
127 |
128 | # Compute a mask over rays with exactly two plane intersections out of the six
129 | # planes. More intersections are possible if the ray coincides with a box
130 | # edge or corner, but we'll ignore these cases for now.
131 | counts = torch.sum(intersection_masks.int(), dim=1) # [R,]
132 | intersection_masks_any = torch.eq(counts, 2) # [R,]
133 | indices = torch.arange(intersection_masks_any.size()[0]).int() # [R,]
134 | # Apply the intersection masks over tensors.
135 | indices = indices[intersection_masks_any] # [R',]
136 | intersection_masks = intersection_masks[intersection_masks_any] # [R', 6]
137 | intersection_points = intersection_points[intersection_masks_any] # [R', 6, 3]
138 | intersection_points = intersection_points[intersection_masks].view(-1, 2, 3) # [R', 2, 3]
139 | # Ensure one or more positive ts.
140 | intersection_ts = intersection_ts[intersection_masks_any] # [R', 6]
141 | intersection_ts = intersection_ts[intersection_masks] # [R'*2]
142 | intersection_ts = intersection_ts.view(indices.size()[0], 2) # [R', 2]
143 | positive_ts_mask = (intersection_ts >= 0) # [R', 2]
144 | positive_ts_count = torch.sum(positive_ts_mask.int(), dim=1) # [R']
145 | positive_ts_mask = (positive_ts_count >= 1) # [R']
146 | intersection_points = intersection_points[positive_ts_mask] # [R'', 2, 3]
147 | false_indices = indices[torch.logical_not(positive_ts_mask)] # [R',]
148 | indices = indices[positive_ts_mask] # [R'',]
149 | if len(false_indices) > 0:
150 | intersection_masks_any[false_indices[:, None]] = torch.zeros(false_indices.size(), dtype=torch.bool)
151 | return rays_start_point_in_box_frame, intersection_masks_any, intersection_points, indices
152 |
153 |
154 | def compute_bounds_from_intersect_points(rays_o, intersect_indices,
155 | intersect_points):
156 | """Computes bounds from intersection points.
157 |
158 |   Note: Make sure that inputs are in the same coordinate frame.
159 |
160 | Args:
161 | rays_o: [R, 3] float tensor
162 | intersect_indices: [R', 1] float tensor
163 | intersect_points: [R', 2, 3] float tensor
164 |
165 | Returns:
166 | intersect_bounds: [R', 2] float tensor
167 |
168 | where R is the number of rays and R' is the number of intersecting rays.
169 | """
170 | intersect_rays_o = rays_o[intersect_indices] # [R', 1, 3]
171 | intersect_diff = intersect_points - intersect_rays_o # [R', 2, 3]
172 | intersect_bounds = torch.norm(intersect_diff, dim=2) # [R', 2]
173 |
174 | # Sort the bounds so that near comes before far for all rays.
175 | intersect_bounds, _ = torch.sort(intersect_bounds, dim=1) # [R', 2]
176 |
177 | # For some reason the sort function returns [R', ?] instead of [R', 2], so we
178 | # will explicitly reshape it.
179 | intersect_bounds = intersect_bounds.view(-1, 2) # [R', 2]
180 | return intersect_bounds
181 |
182 |
183 | def compute_ray_bbox_bounds_pairwise(rays_o, rays_d, box_length,
184 | box_width, box_height, box_center,
185 | box_rotation, far_limit=1e10):
186 | """Computes near and far bounds for rays intersecting with bounding boxes.
187 |
188 | Note: rays and boxes are defined in world coordinate frame.
189 |
190 | Args:
191 | rays_o: [R, 3] float tensor. A set of ray origins.
192 | rays_d: [R, 3] float tensor. A set of ray directions.
193 | box_length: scalar or [R,] float tensor. Bounding box length.
194 | box_width: scalar or [R,] float tensor. Bounding box width.
195 | box_height: scalar or [R,] float tensor. Bounding box height.
196 | box_center: [3,] or [R, 3] float tensor. The center of the box.
197 | box_rotation: [3, 3] or [R, 3, 3] float tensor. The box rotation matrix.
198 | far_limit: float. The maximum far value to use.
199 |
200 | Returns:
201 | intersect_bounds: [R', 2] float tensor. The bounds per-ray, sorted in
202 | ascending order.
203 | intersect_indices: [R', 1] float tensor. The intersection indices.
204 | intersect_mask: [R,] float tensor. The mask denoting intersections.
205 | """
206 | # Compute ray destinations.
207 | normalized_rays_d = ray_utils.normalize_rays(rays=rays_d)
208 | rays_dst = rays_o + far_limit * normalized_rays_d
209 |
210 | # Transform the rays from world to box coordinate frame.
211 | rays_o_in_box_frame, intersect_mask, intersect_points_in_box_frame, intersect_indices = ( # [R,], [R', 2, 3], [R', 2]
212 | ray_box_intersection_pairwise(
213 | box_center=box_center,
214 | box_rotation_matrix=box_rotation,
215 | box_length=box_length,
216 | box_width=box_width,
217 | box_height=box_height,
218 | rays_start_point=rays_o,
219 | rays_end_point=rays_dst))
220 | intersect_indices = intersect_indices.unsqueeze(1).long() # [R', 1]
221 | intersect_bounds = compute_bounds_from_intersect_points(
222 | rays_o=rays_o_in_box_frame,
223 | intersect_indices=intersect_indices,
224 | intersect_points=intersect_points_in_box_frame)
225 | return intersect_bounds, intersect_indices, intersect_mask
226 |
--------------------------------------------------------------------------------
/build_occupancy_tree.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import grad
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | import math
9 | import os
10 | import scipy.integrate as integrate
11 | import math
12 | from collections import deque
13 | import time
14 | from mpl_toolkits.mplot3d import Axes3D
15 | import matplotlib.pyplot as plt
16 | import itertools
17 |
18 | from utils import *
19 | from local_distill import create_multi_network_fourier_embedding, has_flag, create_multi_network
20 | from multi_modules import build_multi_network_from_single_networks, query_multi_network
21 | import kilonerf_cuda
22 |
23 | # TODO: move this to utils.py
24 | class Node:
25 | def __init__(self):
26 | pass
27 |
28 | # This function actually builds an occupancy grid
29 | def build_occupancy_tree(cfg, log_path):
30 | dev = torch.device('cuda')
31 | kilonerf_cuda.init_stream_pool(16) # TODO: cleanup
32 | kilonerf_cuda.init_magma()
33 |
34 | ConfigManager.init(cfg)
35 |
36 | global_domain_min, global_domain_max = ConfigManager.get_global_domain_min_and_max(torch.device('cpu'))
37 | global_domain_size = global_domain_max - global_domain_min
38 | Logger.write('global_domain_min: {}, global_domain_max: {}'.format(global_domain_min, global_domain_max))
39 |
40 | pretrained_cfg = load_yaml_as_dict(cfg['pretrained_cfg_path'])
41 |
42 | if 'distilled_cfg_path' in pretrained_cfg:
43 | pretrained_cfg = load_yaml_as_dict(pretrained_cfg['distilled_cfg_path'])
44 |
45 | if 'discovery' in pretrained_cfg:
46 | for key in pretrained_cfg['discovery']:
47 | pretrained_cfg[key] = pretrained_cfg['discovery'][key]
48 | """
49 | else:
50 | # end2end from scratch case
51 | assert pretrained_cfg['model_type'] == 'multi_network', 'occupancy grid creation is only implemented for multi networks'
52 | """
53 |
54 | cp = torch.load(cfg['pretrained_checkpoint_path'])
55 | use_multi_network = pretrained_cfg['model_type'] == 'multi_network' or not ('model_type' in pretrained_cfg)
56 | if use_multi_network:
57 | position_num_input_channels, position_fourier_embedding = create_multi_network_fourier_embedding(1, pretrained_cfg['num_frequencies'])
58 | direction_num_input_channels, direction_fourier_embedding = create_multi_network_fourier_embedding(1, pretrained_cfg['num_frequencies_direction'])
59 |
60 | if 'model_state_dict' in cp:
61 | res = pretrained_cfg['fixed_resolution']
62 | network_resolution = torch.tensor(res, dtype=torch.long, device=torch.device('cpu'))
63 | num_networks = res[0] * res[1] * res[2]
64 | network_voxel_size = global_domain_size / network_resolution
65 |
66 | multi_network = create_multi_network(num_networks, position_num_input_channels, direction_num_input_channels, 4,
67 | 'multimatmul_differentiable', pretrained_cfg).to(dev)
68 | multi_network.load_state_dict(cp['model_state_dict'])
69 |
70 | # Determine bounding boxes (domains) of all networks. Required for global to local coordinate conversion.
71 | domain_mins = []
72 | domain_maxs = []
73 | for coord in itertools.product(*[range(r) for r in res]):
74 | coord = torch.tensor(coord, device=torch.device('cpu'))
75 | domain_min = global_domain_min + network_voxel_size * coord
76 | domain_max = domain_min + network_voxel_size
77 | domain_mins.append(domain_min.tolist())
78 | domain_maxs.append(domain_max.tolist())
79 | domain_mins = torch.tensor(domain_mins, device=dev)
80 | domain_maxs = torch.tensor(domain_maxs, device=dev)
81 | else:
82 | root_nodes = cp['root_nodes']
83 |
84 | # Merging individual networks into multi network for efficient inference
85 | single_networks = []
86 | domain_mins, domain_maxs = [], []
87 | nodes_to_process = root_nodes.copy()
88 | for node in nodes_to_process:
89 | if hasattr(node, 'network'):
90 | node.network_index = len(single_networks)
91 | single_networks.append(node.network)
92 | domain_mins.append(node.domain_min)
93 | domain_maxs.append(node.domain_max)
94 | else:
95 | nodes_to_process.append(node.leq_child)
96 | nodes_to_process.append(node.gt_child)
97 | linear_implementation = 'multimatmul_differentiable'
98 | multi_network = build_multi_network_from_single_networks(single_networks, linear_implementation=linear_implementation).to(dev)
99 | domain_mins = torch.tensor(domain_mins, device=dev)
100 | domain_maxs = torch.tensor(domain_maxs, device=dev)
101 | else:
102 | # Load teacher NeRF model:
103 | print("Load teacher NeRF model...")
104 | pretrained_nerf = load_pretrained_nerf_model(dev, cfg)
105 |
106 | occupancy_res = cfg['resolution']
107 | total_num_voxels = occupancy_res[0] * occupancy_res[1] * occupancy_res[2]
108 | occupancy_grid = torch.tensor(occupancy_res, device=dev, dtype=torch.bool)
109 | occupancy_resolution = torch.tensor(occupancy_res, dtype=torch.long, device=torch.device('cpu'))
110 | occupancy_voxel_size = global_domain_size / occupancy_resolution
111 | first_voxel_min = global_domain_min
112 | first_voxel_max = first_voxel_min + occupancy_voxel_size
113 |
114 | first_voxel_samples = []
115 | for dim in range(3):
116 | first_voxel_samples.append(torch.linspace(first_voxel_min[dim], first_voxel_max[dim], cfg['subsample_resolution'][dim]))
117 | first_voxel_samples = torch.stack(torch.meshgrid(*first_voxel_samples), dim=3).view(-1, 3)
118 |
119 | ranges = []
120 | for dim in range(3):
121 | ranges.append(torch.arange(0, occupancy_res[dim]))
122 | index_grid = torch.stack(torch.meshgrid(*ranges), dim=3)
123 | index_grid = (index_grid * occupancy_voxel_size).unsqueeze(3)
124 |
125 | points = first_voxel_samples.unsqueeze(0).unsqueeze(0).unsqueeze(0).expand(occupancy_res + list(first_voxel_samples.shape))
126 | points = points + index_grid
127 | points = points.view(total_num_voxels, -1, 3)
128 | num_samples_per_voxel = points.size(1)
129 |
130 | mock_directions = torch.empty(min(cfg['voxel_batch_size'], total_num_voxels), 3).to(dev)
131 |
132 | # We query in a fixed grid at a higher resolution than the occupancy grid resolution to detect fine structures.
133 | all_densities = torch.empty(total_num_voxels, num_samples_per_voxel)
134 | print("all_densities size: ", all_densities.shape)
135 | end = 0
136 | while end < total_num_voxels:
137 | print('sampling network: {}/{} ({:.4f}%)'.format(end, total_num_voxels, 100 * end / total_num_voxels))
138 | start = end
139 | end = min(start + cfg['voxel_batch_size'], total_num_voxels)
140 | actual_batch_size = end - start
141 | points_subset = points[start:end].to(dev).contiguous() # voxel_batch_size x num_samples_per_voxel x 3
142 | mock_directions_subset = mock_directions[:actual_batch_size]
143 | density_dim = 3
144 | with torch.no_grad():
145 | if use_multi_network:
146 | result = query_multi_network(multi_network, domain_mins, domain_maxs, points_subset, mock_directions_subset,
147 | position_fourier_embedding, direction_fourier_embedding, None, None, False, None, pretrained_cfg)[:, :, density_dim]
148 | else:
149 | print("Use teacher NeRF model...")
150 | mock_directions_subset = mock_directions_subset.unsqueeze(1).expand(points_subset.size())
151 | points_and_dirs = torch.cat([points_subset.reshape(-1, 3), mock_directions_subset.reshape(-1, 3)], dim=-1)
152 | lights_subset = torch.zeros((points_and_dirs.shape[0], 3), device=dev)
153 | lights_subset[:, -1] = 1
154 | points_and_dirs_and_lights = torch.cat([points_and_dirs, lights_subset], dim=-1)
155 | result = pretrained_nerf(points_and_dirs_and_lights)[:, density_dim].view(actual_batch_size, -1)
156 | # result = F.relu(pretrained_nerf(points_and_dirs_and_lights)[:, density_dim].view(actual_batch_size, -1))
157 | # print(result.max())
158 | all_densities[start:end] = result.cpu()
159 | del points, points_subset, mock_directions
160 |
161 | occupancy_grid = all_densities.to(dev) > cfg['threshold']
162 | del all_densities
163 | occupancy_grid = occupancy_grid.view(cfg['resolution'] + [-1])
164 |
165 | occupancy_grid = occupancy_grid.any(dim=3) # checks if any point in the voxel is above the threshold
166 |
167 |
168 | Logger.write('{} out of {} voxels are occupied. {:.2f}%'.format(occupancy_grid.sum().item(), occupancy_grid.numel(), 100 * occupancy_grid.sum().item() / occupancy_grid.numel()))
169 |
170 | occupancy_filename = log_path + '/occupancy.pth'
171 | torch.save(occupancy_grid, occupancy_filename)
172 | Logger.write('Saved occupancy grid to {}'.format(occupancy_filename))
173 |
174 | def main():
175 | cfg, log_path = parse_args_and_init_logger()
176 | build_occupancy_tree(cfg, log_path)
177 |
178 | if __name__ == '__main__':
179 | main()
180 |
181 |
--------------------------------------------------------------------------------
/calibs/dataPack.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/calibs/dataPack.npz
--------------------------------------------------------------------------------
/calibs/depth_bg.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/calibs/depth_bg.npy
--------------------------------------------------------------------------------
/calibs/polycalib.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/calibs/polycalib.npz
--------------------------------------------------------------------------------
/calibs/real_bg.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/calibs/real_bg.npy
--------------------------------------------------------------------------------
/cam_utils.py:
--------------------------------------------------------------------------------
1 | """Various camera utility functions."""
2 |
3 | import numpy as np
4 |
5 |
6 | def w2c_to_c2w(w2c):
7 | """
8 | Args:
9 |         w2c: [4, 4] np.float32. World-to-camera extrinsics matrix.
10 | 
11 |     Returns:
12 |         c2w: [4, 4] np.float32. Camera-to-world extrinsics matrix.
13 | """
14 | R = w2c[:3, :3]
15 | T = w2c[:3, 3]
16 |
17 | c2w = np.eye(4, dtype=np.float32)
18 | c2w[:3, 3] = -1 * np.dot(R.transpose(), w2c[:3, 3])
19 | c2w[:3, :3] = R.transpose()
20 | return c2w
21 |
--------------------------------------------------------------------------------
/ddsp_torch.py:
--------------------------------------------------------------------------------
1 | from typing import Text
2 |
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | from scipy import fftpack
6 | import torch
7 |
8 | def torch_float32(x):
9 |   """Ensure array/tensor is a float32 torch.Tensor."""
10 |   if isinstance(x, torch.Tensor):
11 |     return x.float() # This is a no-op if x is already float32.
12 |   elif isinstance(x, np.ndarray):
13 |     return torch.from_numpy(x).cuda() # Convert the numpy array to a CUDA tensor.
14 |   else:
15 |     return torch.tensor(x, dtype=torch.float32).cuda() # Wrap scalars/lists as a CUDA tensor.
16 |
17 | def safe_log(x, eps=1e-7):
18 | """Avoid taking the log of a non-positive number."""
19 | safe_x = torch.where(x <= eps, eps, x.double())
20 | return torch.log(safe_x)
21 |
22 | def stft(audio, frame_size=2048, overlap=0.75):
23 | """Differentiable stft in PyTorch, computed in batch."""
24 | assert frame_size * overlap % 2.0 == 0.0
25 |
26 | # Remove channel dim if present.
27 | audio = torch_float32(audio)
28 | if len(audio.shape) == 3:
29 | audio = torch.squeeze(audio, axis=-1)
30 |
31 | s = torch.stft(
32 | audio,
33 | n_fft=int(frame_size),
34 | hop_length=int(frame_size * (1.0 - overlap)),
35 | win_length=int(frame_size),
36 | window=torch.hann_window(int(frame_size)).to(audio),
37 | pad_mode='reflect',
38 | return_complex=True,
39 | )
40 | return s
41 |
42 | def compute_mag(audio, size=2048, overlap=0.75):
43 | mag = torch.abs(stft(audio, frame_size=size, overlap=overlap))
44 | return torch_float32(mag)
45 |
46 | def compute_logmag(audio, size=2048, overlap=0.75):
47 | return safe_log(compute_mag(audio, size, overlap))
48 |
49 | def specplot(audio,
50 | vmin=-5,
51 | vmax=1,
52 | rotate=True,
53 | size=512 + 256,
54 | **matshow_kwargs):
55 | """Plot the log magnitude spectrogram of audio."""
56 | # If batched, take first element.
57 | if len(audio.shape) == 2:
58 | audio = audio[0]
59 |
60 | logmag = compute_logmag(torch_float32(audio), size=size)
61 | # logmag = spectral_ops.compute_logmel(core.tf_float32(audio), lo_hz=8.0, bins=80, fft_size=size)
62 | # logmag = spectral_ops.compute_mfcc(core.tf_float32(audio), mfcc_bins=40, fft_size=size)
63 | # if rotate:
64 | # logmag = torch.rot90(logmag)
65 | logmag = torch.flip(logmag, [0])
66 | # Plotting.
67 | plt.matshow(logmag.detach().cpu(),
68 | vmin=vmin,
69 | vmax=vmax,
70 | cmap=plt.cm.magma,
71 | aspect='auto',
72 | **matshow_kwargs)
73 | plt.xticks([])
74 | plt.yticks([])
75 | plt.xlabel('Time')
76 | plt.ylabel('Frequency')
77 |
78 | # Time-varying convolution -----------------------------------------------------
79 | def get_fft_size(frame_size: int, ir_size: int, power_of_2: bool = True) -> int:
80 | """Calculate final size for efficient FFT.
81 |
82 | Args:
83 | frame_size: Size of the audio frame.
84 | ir_size: Size of the convolving impulse response.
85 | power_of_2: Constrain to be a power of 2. If False, allow other 5-smooth
86 | numbers. TPU requires power of 2, while GPU is more flexible.
87 |
88 | Returns:
89 | fft_size: Size for efficient FFT.
90 | """
91 | convolved_frame_size = ir_size + frame_size - 1
92 | if power_of_2:
93 | # Next power of 2.
94 | fft_size = int(2**np.ceil(np.log2(convolved_frame_size)))
95 | else:
96 | fft_size = int(fftpack.helper.next_fast_len(convolved_frame_size))
97 | return fft_size
98 |
99 | def crop_and_compensate_delay(audio: torch.Tensor, audio_size: int, ir_size: int,
100 | padding: Text,
101 | delay_compensation: int) -> torch.Tensor:
102 | """Crop audio output from convolution to compensate for group delay.
103 |
104 | Args:
105 | audio: Audio after convolution. Tensor of shape [batch, time_steps].
106 | audio_size: Initial size of the audio before convolution.
107 | ir_size: Size of the convolving impulse response.
108 | padding: Either 'valid' or 'same'. For 'same' the final output to be the
109 | same size as the input audio (audio_timesteps). For 'valid' the audio is
110 | extended to include the tail of the impulse response (audio_timesteps +
111 | ir_timesteps - 1).
112 | delay_compensation: Samples to crop from start of output audio to compensate
113 | for group delay of the impulse response. If delay_compensation < 0 it
114 | defaults to automatically calculating a constant group delay of the
115 | windowed linear phase filter from frequency_impulse_response().
116 |
117 | Returns:
118 | Tensor of cropped and shifted audio.
119 |
120 | Raises:
121 | ValueError: If padding is not either 'valid' or 'same'.
122 | """
123 | # Crop the output.
124 | if padding == 'valid':
125 | crop_size = ir_size + audio_size - 1
126 | elif padding == 'same':
127 | crop_size = audio_size
128 | else:
129 | raise ValueError('Padding must be \'valid\' or \'same\', instead '
130 | 'of {}.'.format(padding))
131 |
132 | # Compensate for the group delay of the filter by trimming the front.
133 | # For an impulse response produced by frequency_impulse_response(),
134 | # the group delay is constant because the filter is linear phase.
135 | total_size = int(audio.shape[-1])
136 | crop = total_size - crop_size
137 | start = ((ir_size - 1) // 2 -
138 | 1 if delay_compensation < 0 else delay_compensation)
139 | end = crop - start
140 | return audio[:, start:-end]
141 |
142 | def fft_convolve(audio: torch.Tensor,
143 | impulse_response: torch.Tensor,
144 | padding: Text = 'same',
145 | delay_compensation: int = -1) -> torch.Tensor:
146 | """Filter audio with frames of time-varying impulse responses.
147 |
148 | Time-varying filter. Given audio [batch, n_samples], and a series of impulse
149 | responses [batch, n_frames, n_impulse_response], splits the audio into frames,
150 | applies filters, and then overlap-and-adds audio back together.
151 | Applies non-windowed non-overlapping STFT/ISTFT to efficiently compute
152 | convolution for large impulse response sizes.
153 |
154 | Args:
155 | audio: Input audio. Tensor of shape [batch, audio_timesteps].
156 | impulse_response: Finite impulse response to convolve. Can either be a 2-D
157 | Tensor of shape [batch, ir_size], or a 3-D Tensor of shape [batch,
158 | ir_frames, ir_size]. A 2-D tensor will apply a single linear
159 | time-invariant filter to the audio. A 3-D Tensor will apply a linear
160 | time-varying filter. Automatically chops the audio into equally shaped
161 | blocks to match ir_frames.
162 | padding: Either 'valid' or 'same'. For 'same' the final output to be the
163 | same size as the input audio (audio_timesteps). For 'valid' the audio is
164 | extended to include the tail of the impulse response (audio_timesteps +
165 | ir_timesteps - 1).
166 | delay_compensation: Samples to crop from start of output audio to compensate
167 | for group delay of the impulse response. If delay_compensation is less
168 | than 0 it defaults to automatically calculating a constant group delay of
169 | the windowed linear phase filter from frequency_impulse_response().
170 |
171 | Returns:
172 | audio_out: Convolved audio. Tensor of shape
173 | [batch, audio_timesteps + ir_timesteps - 1] ('valid' padding) or shape
174 | [batch, audio_timesteps] ('same' padding).
175 |
176 | Raises:
177 | ValueError: If audio and impulse response have different batch size.
178 | ValueError: If audio cannot be split into evenly spaced frames. (i.e. the
179 | number of impulse response frames is on the order of the audio size and
180 | not a multiple of the audio size.)
181 | """
182 | audio, impulse_response = torch_float32(audio), torch_float32(impulse_response)
183 |
184 | # Add a frame dimension to impulse response if it doesn't have one.
185 | ir_shape = list(impulse_response.shape)
186 | if len(ir_shape) == 2:
187 | impulse_response = torch.unsqueeze(impulse_response, axis = 2)
188 | ir_shape = list(impulse_response.shape)
189 |
190 | # Get shapes of audio and impulse response.
191 | batch_size_ir, n_ir_frames, ir_size = ir_shape
192 | batch_size, audio_size = list(audio.shape)
193 |
194 | # Validate that batch sizes match.
195 | if batch_size != batch_size_ir:
196 | raise ValueError('Batch size of audio ({}) and impulse response ({}) must '
197 | 'be the same.'.format(batch_size, batch_size_ir))
198 |
199 | # Cut audio into frames.
200 | frame_size = int(np.ceil(audio_size / n_ir_frames))
201 | hop_size = frame_size
202 | audio_frames = audio.unfold(1, frame_size, hop_size)
203 |
204 | # Check that number of frames match.
205 | n_audio_frames = int(audio_frames.shape[1])
206 | if n_audio_frames != n_ir_frames:
207 | raise ValueError(
208 | 'Number of Audio frames ({}) and impulse response frames ({}) do not '
209 | 'match. For small hop size = ceil(audio_size / n_ir_frames), '
210 | 'number of impulse response frames must be a multiple of the audio '
211 | 'size.'.format(n_audio_frames, n_ir_frames))
212 |
213 | # Pad and FFT the audio and impulse responses.
214 | fft_size = get_fft_size(frame_size, ir_size, power_of_2=True)
215 | audio_fft = torch.fft.rfft(audio_frames, fft_size)
216 | ir_fft = torch.fft.rfft(impulse_response, fft_size)
217 |
218 | # Multiply the FFTs (same as convolution in time).
219 | audio_ir_fft = torch.multiply(audio_fft, ir_fft)
220 |
221 | # Take the IFFT to resynthesize audio.
222 | audio_frames_out = torch.fft.irfft(audio_ir_fft)
223 | # audio_out = tf.signal.overlap_and_add(audio_frames_out, hop_size)
224 | audio_out = torch.squeeze(audio_frames_out, axis=1)
225 |
226 | # Crop and shift the output audio.
227 | return crop_and_compensate_delay(audio_out, audio_size, ir_size, padding,
228 | delay_compensation)
229 |
230 | def get_modal_fir(gains, frequencies, dampings, n_samples=44100*2, sample_rate=44100):
231 | t = torch.reshape(torch.arange(n_samples)/sample_rate, (1, 1, -1)).cuda()
232 | g = torch.unsqueeze(gains, axis=2)
233 | f = torch.reshape(frequencies, (1, -1, 1))
234 | d = torch.reshape(dampings, (1, -1, 1))
235 | pure = torch.sin(2 * np.pi * f * t)
236 | damped = torch.exp(-1 * torch.abs(d) * t) * pure
237 | signal = torch.sum(g * damped, axis=1)
238 | return torch.cat((torch.zeros_like(signal), signal), axis=1)
--------------------------------------------------------------------------------
/demo/ObjectFile.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/demo/ObjectFile.pth
--------------------------------------------------------------------------------
/demo/audio_demo_forces.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/demo/audio_demo_forces.npy
--------------------------------------------------------------------------------
/demo/audio_demo_vertices.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/demo/audio_demo_vertices.npy
--------------------------------------------------------------------------------
/demo/touch_demo_gelinfo.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/demo/touch_demo_gelinfo.npy
--------------------------------------------------------------------------------
/demo/touch_demo_vertices.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/demo/touch_demo_vertices.npy
--------------------------------------------------------------------------------
/demo/vision_demo.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rhgao/ObjectFolder/3c6cd8930b2dcbadb6d94dadf2745c956bdcd236/demo/vision_demo.npy
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: ObjectFolder-env
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | - nvidia
7 | dependencies:
8 | - pip=21.0.1
9 | - python=3.8
10 | - pytorch=1.8.1
11 | - torchvision
12 | - cudatoolkit=11.1.1
13 | - numpy
14 | - scikit-image
15 | - scipy
16 | - tqdm
17 | - imageio
18 | - pyyaml
19 | - pip:
20 | - imageio-ffmpeg
21 | - lpips
22 | - opencv-python
23 | - librosa
24 |
--------------------------------------------------------------------------------
/fast_kilonerf_renderer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import kilonerf_cuda
3 | from utils import PerfMonitor, ConfigManager
4 | from run_nerf_helpers import get_rays, replace_transparency_by_background_color
5 |
6 | class FastKiloNeRFRenderer():
7 | def __init__(self, c2w, intrinsics, background_color, occupancy_grid, multi_network, domain_mins, domain_maxs,
8 | white_bkgd, max_depth_index, min_distance, max_distance, performance_monitoring, occupancy_resolution, max_samples_per_ray, transmittance_threshold):
9 |
10 | self.set_camera_pose(c2w)
11 | self.intrinsics = intrinsics
12 | self.background_color = background_color
13 | self.occupancy_grid = occupancy_grid
14 | self.multi_network = multi_network
15 | self.domain_mins = domain_mins
16 | self.domain_maxs = domain_maxs
17 | self.white_bkgd = white_bkgd
18 | self.max_depth_index = max_depth_index
19 | self.min_distance = min_distance
20 | self.max_distance = max_distance
21 | self.performance_monitoring = performance_monitoring
22 | self.occupancy_resolution = occupancy_resolution
23 | self.max_samples_per_ray = max_samples_per_ray
24 | self.transmittance_threshold = transmittance_threshold
25 |
26 |         # Get ray directions for arbitrary render pose
27 | # Precompute distances between sampling points, which vary along the pixel dimension
28 | _, rays_d = get_rays(intrinsics, self.c2w) # H x W x 3
29 | direction_norms = torch.norm(rays_d, dim=-1) # H x W
30 | self.distance_between_samples = (1 / (self.max_depth_index - 1)) * (self.max_distance - self.min_distance)
31 | self.constant_dists = (self.distance_between_samples * direction_norms).view(-1).unsqueeze(1) # H * W x 1
32 |
33 | self.rgb_map = torch.empty([self.intrinsics.H, self.intrinsics.W, 3], dtype=torch.float, device=occupancy_grid.device)
34 | self.rgb_map_pointer = self.rgb_map.data_ptr()
35 |
36 | def set_rgb_map_pointer(self, rgb_map_pointer):
37 | self.rgb_map = None
38 | self.rgb_map_pointer = rgb_map_pointer
39 |
40 | def set_camera_pose(self, c2w):
41 | self.c2w = c2w[:3, :4]
42 |
43 | def render(self):
44 | PerfMonitor.add('start')
45 | PerfMonitor.is_active = self.performance_monitoring
46 |
47 | rays_o, rays_d = get_rays(self.intrinsics, self.c2w, expand_origin=False)
48 | PerfMonitor.add('ray directions', ['preprocessing'])
49 |
50 | origin = rays_o
51 | directions = rays_d.reshape(-1, 3) # directions are *not* normalized.
52 | res = self.occupancy_resolution
53 | global_domain_min, global_domain_max = ConfigManager.get_global_domain_min_and_max(directions.device)
54 | global_domain_size = global_domain_max - global_domain_min
55 | occupancy_resolution = torch.tensor(res, dtype=torch.long, device=directions.device)
56 | strides = torch.tensor([res[2] * res[1], res[2], 1], dtype=torch.int, device=directions.device) # assumes row major ordering
57 | voxel_size = global_domain_size / occupancy_resolution
58 | num_rays = directions.size(0)
59 |
60 | active_ray_mask = torch.empty(num_rays, dtype=torch.bool, device=directions.device)
61 | depth_indices = torch.empty(num_rays, dtype=torch.short, device=directions.device)
62 | acc_map = torch.empty([self.intrinsics.H, self.intrinsics.W], dtype=torch.float, device=directions.device)
63 | # the final transmittance of a pass will be the initial transmittance of the next
64 | transmittance = torch.empty([self.intrinsics.H, self.intrinsics.W], dtype=torch.float, device=directions.device)
65 |
66 | PerfMonitor.add('prep', ['preprocessing'])
67 |
68 | is_initial_query = True
69 | is_final_pass = False
70 |
71 | pass_idx = 0
72 | integrate_num_blocks = 40
73 | integrate_num_threads = 512
74 | while not is_final_pass:
75 |
76 | if type(self.max_samples_per_ray) is list:
77 | # choose max samples per ray depending on the pass
78 | # in the later passes we can sample more per ray to avoid too much overhead from too many passes
79 | current_max_samples_per_ray = self.max_samples_per_ray[min(pass_idx, len(self.max_samples_per_ray) - 1)]
80 | else:
81 | # just use the same number of samples for all passes
82 | current_max_samples_per_ray = self.max_samples_per_ray
83 |
84 | # Compute query indices along the rays and determine assignment of query location to networks
85 | # Tunable CUDA hyperparameters
86 | kernel_max_num_blocks = 40
87 | kernel_max_num_threads = 512
88 | version = 0
89 | query_indices, assigned_networks = kilonerf_cuda.generate_query_indices_on_ray(origin, directions, self.occupancy_grid, active_ray_mask, depth_indices, voxel_size,
90 | global_domain_min, global_domain_max, strides, self.distance_between_samples, current_max_samples_per_ray, self.max_depth_index, self.min_distance, is_initial_query,
91 | kernel_max_num_blocks, kernel_max_num_threads, version)
92 |
93 | PerfMonitor.add('sample query points', ['preprocessing'])
94 |
95 | with_explicit_mask = True
96 | query_indices = query_indices.view(-1)
97 | assigned_networks = assigned_networks.view(-1)
98 | if with_explicit_mask:
99 | active_samples_mask = assigned_networks != -1
100 | assigned_networks = assigned_networks[active_samples_mask]
101 | # when with_explicit_mask = False: sort all points, including those with assigned_network == -1
102 | # Points with assigned_network == -1 will be placed at the beginning and can then be filtered by moving the start of the array (zero cost)
103 |
104 | #assigned_networks, reorder_indices = torch.sort(assigned_networks) # sorting via PyTorch is significantly slower
105 | #reorder_indices = torch.arange(assigned_networks.size(0), dtype=torch.int32, device=assigned_networks.device)
106 | #kilonerf_cuda.sort_by_key_int16_int32(assigned_networks, reorder_indices) # stable sort does not seem to be slower/faster
107 | reorder_indices = torch.arange(assigned_networks.size(0), dtype=torch.int64, device=assigned_networks.device)
108 | kilonerf_cuda.sort_by_key_int16_int64(assigned_networks, reorder_indices)
109 | PerfMonitor.add('sort', ['reorder and backorder'])
110 |
111 | # make sure that batch sizes are also given for networks for which 0 points are queried
112 | contained_nets, batch_size_per_network_incomplete = torch.unique_consecutive(assigned_networks, return_counts=True)
113 | if not with_explicit_mask:
114 | num_removable_points = batch_size_per_network_incomplete[0]
115 | contained_nets = contained_nets[1:].to(torch.long)
116 | batch_size_per_network_incomplete = batch_size_per_network_incomplete[1:]
117 | else:
118 | contained_nets = contained_nets.to(torch.long)
119 | batch_size_per_network = torch.zeros(self.multi_network.num_networks, device=query_indices.device, dtype=torch.long)
120 | batch_size_per_network[contained_nets] = batch_size_per_network_incomplete
121 | ends = batch_size_per_network.cumsum(0).to(torch.int32)
122 | starts = ends - batch_size_per_network.to(torch.int32)
123 | PerfMonitor.add('batch_size_per_network', ['reorder and backorder'])
124 |
125 |
126 | # Remove all points which are assigned to no network (those points are in empty space or outside the global domain)
127 | if with_explicit_mask:
128 | query_indices = query_indices[active_samples_mask]
129 | else:
130 | reorder_indices = reorder_indices[num_removable_points:] # just moving a pointer
131 | PerfMonitor.add('remove points', ['reorder and backorder'])
132 |
133 | # Reorder the query indices
134 | query_indices = query_indices[reorder_indices]
135 | #query_indices = kilonerf_cuda.gather_int32(reorder_indices, query_indices)
136 | query_indices = query_indices
137 | PerfMonitor.add('reorder', ['reorder and backorder'])
138 |
139 | num_points_to_process = query_indices.size(0) if query_indices.ndim > 0 else 0
140 | #print("#points to process:", num_points_to_process, flush=True)
141 | if num_points_to_process == 0:
142 | break
143 |
144 | # Evaluate the network
145 | network_eval_num_blocks = -1 # ignored currently
146 | compute_capability = torch.cuda.get_device_capability(query_indices.device)
147 | if compute_capability == (7, 5):
148 | network_eval_num_threads = 512 # for some reason the compiler uses more than 96 registers for this CC, so we cannot launch 640 threads
149 | else:
150 | network_eval_num_threads = 640
151 | version = 0
152 | raw_outputs = kilonerf_cuda.network_eval_query_index(query_indices, self.multi_network.serialized_params, self.domain_mins, self.domain_maxs, starts, ends, origin,
153 | self.c2w[:3, :3].contiguous(), self.multi_network.num_networks, self.multi_network.hidden_layer_size,
154 | self.intrinsics.H, self.intrinsics.W, self.intrinsics.cx, self.intrinsics.cy, self.intrinsics.fx, self.intrinsics.fy, self.max_depth_index, self.min_distance, self.distance_between_samples,
155 | network_eval_num_blocks, network_eval_num_threads, version)
156 | PerfMonitor.add('fused network eval', ['network query'])
157 |
158 | # Backorder outputs
159 | if with_explicit_mask:
160 | raw_outputs_backordered = torch.empty_like(raw_outputs)
161 | raw_outputs_backordered[reorder_indices] = raw_outputs
162 | #raw_outputs_backordered = kilonerf_cuda.scatter_int32_float4(reorder_indices, raw_outputs)
163 | del raw_outputs
164 | raw_outputs_full = torch.zeros(num_rays * current_max_samples_per_ray, 4, dtype=torch.float, device=raw_outputs_backordered.device)
165 | raw_outputs_full[active_samples_mask] = raw_outputs_backordered
166 | else:
167 | raw_outputs_full = torch.zeros(num_rays * current_max_samples_per_ray, 4, dtype=torch.float, device=raw_outputs.device)
168 | raw_outputs_full[reorder_indices] = raw_outputs
169 | PerfMonitor.add('backorder', ['reorder and backorder'])
170 |
171 | # Integrate sampled densities and colors along each ray to render the final image
172 | version = 0
173 | kilonerf_cuda.integrate(raw_outputs_full, self.constant_dists, self.rgb_map_pointer, acc_map, transmittance, active_ray_mask, num_rays, current_max_samples_per_ray,
174 | self.transmittance_threshold, is_initial_query, integrate_num_blocks, integrate_num_threads, version)
175 | is_final_pass = not active_ray_mask.any().item()
176 |
177 | is_initial_query = False
178 | if not is_final_pass:
179 | PerfMonitor.add('integration', ['integration'])
180 | pass_idx += 1
181 |
182 | if self.white_bkgd:
183 | kilonerf_cuda.replace_transparency_by_background_color(self.rgb_map_pointer, acc_map, self.background_color, integrate_num_blocks, integrate_num_threads)
184 |
185 | PerfMonitor.is_active = True
186 | PerfMonitor.add('integration', ['integration'])
187 | elapsed_time = PerfMonitor.log_and_reset(self.performance_monitoring)
188 | self.rgb_map = self.rgb_map.view(self.intrinsics.H, self.intrinsics.W, 3) if self.rgb_map is not None else None
189 | return self.rgb_map, elapsed_time
190 |
--------------------------------------------------------------------------------
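Note on the render loop above: kilonerf_cuda.integrate fuses standard NeRF alpha compositing with the multi-pass bookkeeping (active_ray_mask, carried transmittance). As an illustration of just the compositing math, here is a minimal PyTorch sketch; the activation choices (sigmoid for color, relu for density) follow the usual NeRF convention and are an assumption about the fused kernel, not a transcription of it.

import torch

def composite_rays(raw_outputs, dists, transmittance_threshold=1e-4):
    """Alpha-composite per-sample (r, g, b, sigma) values along each ray.

    raw_outputs: [R, S, 4] color (first 3 channels) and density (last channel).
    dists:       [R, S] distances between consecutive samples.
    Returns rgb_map [R, 3] and acc_map [R] (accumulated opacity).
    """
    rgb = torch.sigmoid(raw_outputs[..., :3])            # [R, S, 3] (assumed activation)
    sigma = torch.relu(raw_outputs[..., 3])               # [R, S]   (assumed activation)
    alpha = 1.0 - torch.exp(-sigma * dists)               # [R, S]
    # Transmittance before each sample: product of (1 - alpha) over earlier samples.
    trans = torch.cumprod(1.0 - alpha + 1e-10, dim=-1)
    trans = torch.cat([torch.ones_like(trans[:, :1]), trans[:, :-1]], dim=-1)
    weights = alpha * trans                                # [R, S]
    # Once transmittance falls below the threshold, further samples contribute (almost) nothing;
    # this is the condition the CUDA kernel uses to mark rays inactive for later passes.
    weights = torch.where(trans > transmittance_threshold, weights, torch.zeros_like(weights))
    rgb_map = (weights[..., None] * rgb).sum(dim=-2)       # [R, 3]
    acc_map = weights.sum(dim=-1)                          # [R]
    return rgb_map, acc_map

# Toy usage with random data.
rgb_map, acc_map = composite_rays(torch.randn(8, 16, 4), torch.full((8, 16), 0.05))
print(rgb_map.shape, acc_map.shape)  # torch.Size([8, 3]) torch.Size([8])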
/load_osf.py:
--------------------------------------------------------------------------------
1 | """Data loader for OSF data."""
2 |
3 | import os
4 | import torch
5 | import numpy as np
6 | import imageio
7 | import json
8 |
9 | import cam_utils
10 |
11 |
12 | trans_t = lambda t: torch.tensor([
13 | [1,0,0,0],
14 | [0,1,0,0],
15 | [0,0,1,t],
16 | [0,0,0,1]
17 | ], dtype=torch.float)
18 |
19 | rot_phi = lambda phi: torch.tensor([
20 | [1,0,0,0],
21 | [0,np.cos(phi),-np.sin(phi),0],
22 | [0,np.sin(phi), np.cos(phi),0],
23 | [0,0,0,1]
24 | ], dtype=torch.float)
25 |
26 | rot_theta = lambda th: torch.tensor([
27 | [np.cos(th),0,-np.sin(th),0],
28 | [0,1,0,0],
29 | [np.sin(th),0, np.cos(th),0],
30 | [0,0,0,1]
31 | ], dtype=torch.float)
32 |
33 |
34 | def pose_spherical(theta, phi, radius):
35 | c2w = trans_t(radius)
36 | c2w = rot_phi(phi/180.*np.pi) @ c2w
37 | c2w = rot_theta(theta/180.*np.pi) @ c2w
38 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w
39 | return c2w
40 |
41 | def coordinates_to_c2w(x, y, z, r=2.5):
42 | theta = np.arccos(z / r)
43 | phi = np.arctan2(x, -y)
44 | Rx = np.array([[1, 0, 0], [0, np.cos(theta), -np.sin(theta)], [0, np.sin(theta), np.cos(theta)]])
45 | Rz = np.array([[np.cos(phi), -np.sin(phi), 0], [np.sin(phi), np.cos(phi), 0], [0, 0, 1]])
46 | R = Rz @ Rx
47 | c2w = R.tolist()
48 | c2w[0].append(x)
49 | c2w[1].append(y)
50 | c2w[2].append(z)
51 | c2w.append([0., 0., 0., 1.])
52 | #c2w = np.array(c2w).astype(np.float32)
53 | return c2w
54 |
55 | def convert_cameras_to_nerf_format(anno):
56 | """
57 | Args:
58 | anno: List of annotations for each example. Each annotation is represented by a
59 | dictionary that must contain the key `RT` which is the world-to-camera
60 | extrinsics matrix with shape [3, 4], in [right, down, forward] coordinates.
61 |
62 | Returns:
63 | c2w: [N, 4, 4] np.float32. Array of camera-to-world extrinsics matrices in
64 | [right, up, backwards] coordinates.
65 | """
66 | c2w_list = []
67 | for a in anno:
68 | # Convert from w2c to c2w.
69 | w2c = np.array(a['RT'] + [[0.0, 0.0, 0.0, 1.0]])
70 | c2w = cam_utils.w2c_to_c2w(w2c)
71 |
72 | # Convert from [right, down, forwards] to [right, up, backwards]
73 | c2w[:3, 1] *= -1 # down -> up
74 | c2w[:3, 2] *= -1 # forwards -> back
75 | c2w_list.append(c2w)
76 | c2w = np.array(c2w_list)
77 | print("c2w: ", c2w)
78 | return c2w
79 |
80 |
81 | def load_osf_data(test_file_path):
82 |
83 | all_poses = []
84 | all_metadata = []
85 | counts = [0]
86 | test_file = np.load(test_file_path)
87 | N = test_file.shape[0]
88 | for i in range(N):
89 | cx, cy, cz, lx, ly, lz = test_file[i]
90 | poses = coordinates_to_c2w(cx, cy, cz)
91 | metadata = np.array([[lx, ly, lz]]).astype(np.float32)
92 | all_poses.append(poses)
93 | all_metadata.append(metadata)
94 |
95 | poses = np.array(all_poses).astype(np.float32)
96 |
97 | metadata = np.concatenate(all_metadata, 0)
98 | counts.append(N)
99 | i_split = [np.arange(counts[0], counts[1])]
100 |
101 | H, W, focal = 256, 256, 355.5555419921875
102 |
103 | return poses, [H, W, focal], i_split, metadata
104 |
--------------------------------------------------------------------------------
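A quick standalone check of the camera convention used by coordinates_to_c2w above: the rotation is Rz(phi) @ Rx(theta), the translation is the camera position itself, and for a position on the r=2.5 sphere the camera's +z column points radially outward, matching the [right, up, backwards] convention described in convert_cameras_to_nerf_format. A small NumPy sketch that re-derives the matrix rather than importing this module:

import numpy as np

def coords_to_c2w(x, y, z, r=2.5):
    # Same construction as coordinates_to_c2w in load_osf.py, written standalone.
    theta = np.arccos(z / r)
    phi = np.arctan2(x, -y)
    Rx = np.array([[1, 0, 0],
                   [0, np.cos(theta), -np.sin(theta)],
                   [0, np.sin(theta),  np.cos(theta)]])
    Rz = np.array([[np.cos(phi), -np.sin(phi), 0],
                   [np.sin(phi),  np.cos(phi), 0],
                   [0, 0, 1]])
    c2w = np.eye(4, dtype=np.float32)
    c2w[:3, :3] = Rz @ Rx
    c2w[:3, 3] = [x, y, z]
    return c2w

# Camera placed on the r=2.5 sphere; its +z (backward) axis should point away from the origin.
pos = np.array([1.5, -1.5, 1.0])
pos = 2.5 * pos / np.linalg.norm(pos)
c2w = coords_to_c2w(*pos)
print(np.allclose(c2w[:3, 2], pos / 2.5, atol=1e-6))  # True: backward axis == outward radial direction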
/ray_utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions for ray computation."""
2 | import math
3 | import numpy as np
4 | from scipy.spatial.transform import Rotation as R
5 | import box_utils
6 | import torch
7 |
8 |
9 | def apply_batched_transformations(inputs, transformations):
10 | """Batched transformation of inputs.
11 |
12 | Args:
13 | inputs: List of [R, S, 3]
14 | transformations: [R, 4, 4]
15 |
16 | Returns:
17 | transformed_inputs: List of [R, S, 3]
18 | """
19 | # if rotation_only:
20 | # transformations[:, :3, 3] = torch.zeros((3,), dtype=torch.float)
21 |
22 | transformed_inputs = []
23 | for x in inputs:
24 | N_samples = x.size()[1]
25 | homog_transformations = transformations.unsqueeze(1) # [R, 1, 4, 4]
26 | homog_transformations = torch.tile(homog_transformations, (1, N_samples, 1, 1)) # [R, S, 4, 4]
27 | homog_component = torch.ones_like(x)[..., 0:1] # [R, S, 1]
28 | homog_x = torch.cat((x, homog_component), axis=-1) # [R, S, 4]
29 | homog_x = homog_x.unsqueeze(2)
30 | transformed_x = torch.matmul(
31 | homog_x,
32 | torch.transpose(homog_transformations, 2, 3)) # [R, S, 1, 4]
33 | transformed_x = transformed_x[..., 0, :3] # [R, S, 3]
34 | transformed_inputs.append(transformed_x)
35 | return transformed_inputs
36 |
37 |
38 | def get_transformation_from_params(params):
39 | translation, rotation = [0, 0, 0], [0, 0, 0]
40 | if 'translation' in params:
41 | translation = params['translation']
42 | if 'rotation' in params:
43 | rotation = params['rotation']
44 | translation = torch.tensor(translation, dtype=torch.float)
45 | rotmat = torch.tensor(R.from_euler('xyz', rotation, degrees=True).as_matrix(), dtype=torch.float)
46 | return translation, rotmat
47 |
48 |
49 | def rotate_dirs(dirs, rotmat):
50 | """
51 | Args:
52 | dirs: [R, 3] float tensor.
53 | rotmat: [3, 3]
54 | """
55 | if type(dirs) == np.ndarray:
56 | dirs = torch.tensor(dirs).float()
57 | #rotmat = rotmat.unsqueeze(0)
58 | rotmat = torch.broadcast_to(rotmat, (dirs.shape[0], 3, 3)) # [R, 3, 3]
59 | dirs_obj = torch.matmul(dirs.unsqueeze(1), torch.transpose(rotmat, 1, 2)) # [R, 1, 3]
60 | dirs_obj = dirs_obj.squeeze(1) # [R, 3]
61 | return dirs_obj
62 |
63 |
64 | def transform_dirs(dirs, params, inverse=False):
65 | _, rotmat = get_transformation_from_params(params) # [3,], [3, 3]
66 | if inverse:
67 | rotmat = torch.transpose(rotmat, 0, 1) # [3, 3]
68 | dirs_transformed = rotate_dirs(dirs, rotmat)
69 | return dirs_transformed
70 |
71 |
72 | def transform_rays(ray_batch, params, use_viewdirs, inverse=False):
73 | """Transform rays into object coordinate frame given o2w transformation params.
74 |
75 | Note: do not assume viewdirs is always the normalized version of rays_d (e.g., in staticcam case).
76 |
77 | Args:
78 | ray_batch: [R, M] float tensor. Batch of rays.
79 | params: Dictionary containing transformation parameters:
80 | 'translation': List of 3 elements. xyz translation.
81 | 'rotation': List of 3 euler angles in xyz.
82 | use_viewdirs: bool. Whether viewdirs are being used.
83 | inverse: bool. Whether to apply inverse of the transformations provided in 'params'.
84 |
85 | Returns:
86 | ray_batch_obj: [R, M] float tensor. The ray batch, in object coordinate frame.
87 | """
88 | rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6]
89 | translation, rotmat = get_transformation_from_params(params) # [3,], [3, 3]
90 |
91 | if inverse:
92 | translation = -1 * translation # [3,]
93 | rotmat = torch.transpose(rotmat, 1, 0) # [3, 3]
94 |
95 | translation_inverse = -1 * translation
96 | rotmat_inverse = torch.transpose(rotmat, 1, 0)
97 |
98 | # Transform the ray origin.
99 | rays_o_obj, _ = box_utils.ray_to_box_coordinate_frame_pairwise(
100 | box_center=translation_inverse,
101 | box_rotation_matrix=rotmat_inverse,
102 | rays_start_point=rays_o,
103 | rays_end_point=rays_d)
104 |
105 | # Only apply rotation to rays_d.
106 | rays_d_obj = rotate_dirs(rays_d, rotmat)
107 |
108 | ray_batch_obj = update_ray_batch_slice(ray_batch, rays_o_obj, 0, 3)
109 | ray_batch_obj = update_ray_batch_slice(ray_batch_obj, rays_d_obj, 3, 6)
110 | if use_viewdirs:
111 | # Grab viewdirs from the ray batch itself, because it may differ from rays_d
112 | # (as in the staticcam case).
113 | viewdirs = ray_batch[:, 8:11]
114 | viewdirs_obj = rotate_dirs(viewdirs, rotmat)
115 | ray_batch_obj = update_ray_batch_slice(ray_batch_obj, viewdirs_obj, 8, 11)
116 | return ray_batch_obj
117 |
118 |
119 | def transform_points_into_world_coordinate_frame(pts, params, check_numerics=False):
120 | translation, rotmat = get_transformation_from_params(params) # [3,], [3, 3]
121 |
122 | # pts_flat = pts.view(-1, 3) # [RS, 3]
123 | # num_examples = pts_flat.size()[0] # RS
124 |
125 | # translation = translation.unsqueeze(0)
126 | # translation = torch.tile(translation, (num_examples, 1)) # [RS, 3]
127 | # rotmat = rotmat.unsqueeze(0)
128 | # rotmat = torch.tile(rotmat, (num_examples, 1, 1))
129 |
130 | # # pts_flat_transformed = torch.matmul(pts_flat[:, None, :], torch.transpose(rotmat, 2, 1)) # [RS, 1, 3]
131 | # pts_flat_transformed = pts_flat[:, None, :] # [RS, 1, 3]
132 | # pts_flat_transformed += translation[:, None, :] # [RS, 1, 3]
133 | # pts_transformed = pts_flat_transformed.view(pts.size()) # [R, S, 3]
134 | chunk = 256
135 | # Check batch transformations works without rotation.
136 | if check_numerics:
137 | transformations = np.eye(4)
138 | transformations[:3, 3] = translation
139 | transformations = torch.tensor(transformations, dtype=torch.float) # [4, 4]
140 | transformations = torch.tile(transformations[None, ...], (pts.size()[0], 1, 1)) # [R, 4, 4]
141 | pts_transformed1 = []
142 | for i in range(0, pts.size()[0], chunk):
143 | pts_transformed1_chunk = apply_batched_transformations(
144 | inputs=[pts[i:i+chunk]], transformations=transformations[i:i+chunk])[0]
145 | pts_transformed1.append(pts_transformed1_chunk)
146 | pts_transformed1 = torch.cat(pts_transformed1, dim=0)
147 |
148 | pts_transformed2 = pts + translation[None, None, :]
149 |
150 | # Now add rotation
151 | transformations = np.eye(4)
152 | transformations = torch.tensor(transformations, dtype=torch.float)
153 | transformations[:3, :3] = rotmat
154 | transformations[:3, 3] = translation
155 | #transformations = torch.tensor(transformations, dtype=torch.float) # [4, 4]
156 | transformations = torch.tile(transformations[None, ...], (pts.size()[0], 1, 1)) # [R, 4, 4]
157 | pts_transformed = []
158 | for i in range(0, pts.size()[0], chunk):
159 | pts_transformed_chunk = apply_batched_transformations(
160 | inputs=[pts[i:i+chunk]], transformations=transformations[i:i+chunk])[0]
161 | pts_transformed.append(pts_transformed_chunk)
162 | pts_transformed = torch.cat(pts_transformed, dim=0)
163 | return pts_transformed
164 |
165 |
166 | # def transform_rays(ray_batch, translation, use_viewdirs):
167 | # """Apply transformation to rays.
168 |
169 | # Args:
170 | # ray_batch: [R, M] float tensor. All information necessary
171 | # for sampling along a ray, including: ray origin, ray direction, min
172 | # dist, max dist, and unit-magnitude viewing direction.
173 | # translation: [3,] float tensor. The (x, y, z) translation to apply.
174 | # use_viewdirs: Whether to use view directions.
175 |
176 | # Returns:
177 | # ray_batch: [R, M] float tensor. Transformed ray batch.
178 | # """
179 | # assert translation.size()[0] == 3, "translation.size()[0] must be 3..."
180 |
181 | # # Since we are only supporting translation for now, only ray origins need to be
182 | # # modified. Ray directions do not need to change.
183 | # rays_o = ray_batch[:, 0:3] + translation
184 | # rays_remaining = ray_batch[:, 3:]
185 | # ray_batch = torch.cat((rays_o, rays_remaining), dim=1)
186 | # return ray_batch
187 |
188 | def compute_rays_length(rays_d):
189 | """Compute ray length.
190 |
191 | Args:
192 | rays_d: [R, 3] float tensor. Ray directions.
193 |
194 | Returns:
195 | rays_length: [R, 1] float tensor. Ray lengths.
196 | """
197 | rays_length = torch.norm(rays_d, dim=-1, keepdim=True) # [N_rays, 1]
198 | return rays_length
199 |
200 |
201 | def normalize_rays(rays):
202 | """Normalize ray directions.
203 |
204 | Args:
205 | rays: [R, 3] float tensor. Ray directions.
206 |
207 | Returns:
208 | normalized_rays: [R, 3] float tensor. Normalized ray directions.
209 | """
210 | normalized_rays = rays / compute_rays_length(rays_d=rays)
211 | return normalized_rays
212 |
213 |
214 | def compute_ray_dirs_and_length(rays_o, rays_dst):
215 | """Compute ray directions.
216 |
217 | Args:
218 | rays_o: [R, 3] float tensor. Ray origins.
219 | rays_dst: [R, 3] float tensor. Ray destinations.
220 |
221 | Returns:
222 | rays_d: [R, 3] float tensor. Normalized ray directions.
223 | """
224 | # The ray directions are the difference between the ray destinations and the
225 | # ray origins.
226 | rays_d = rays_dst - rays_o # [R, 3] # Direction out of light source
227 |
228 | # Compute the length of the rays.
229 | rays_length = compute_rays_length(rays_d=rays_d)
230 |
231 | # Normalize the ray directions.
232 | rays_d = rays_d / rays_length # [R, 3] # Normalize direction
233 | return rays_d, rays_length
234 |
235 |
236 | def update_ray_batch_slice(ray_batch, x, start, end):
237 | left = ray_batch[:, :start] # [R, ?]
238 | right = ray_batch[:, end:] # [R, ?]
239 | updated_ray_batch = torch.cat((left, x, right), dim=-1)
240 | return updated_ray_batch
241 |
242 |
243 | def update_ray_batch_bounds(ray_batch, bounds):
244 | updated_ray_batch = update_ray_batch_slice(ray_batch=ray_batch, x=bounds,
245 | start=6, end=8)
246 | return updated_ray_batch
247 |
248 |
249 | def create_ray_batch(
250 | rays_o, rays_dst, rays_i, use_viewdirs, rays_near=None, rays_far=None, epsilon=1e-10):
251 | # Compute the ray directions.
252 | rays_d = rays_dst - rays_o # [R,3] # Direction out of light source
253 | rays_length = compute_rays_length(rays_d=rays_d) # [R, 1]
254 | rays_d = rays_d / rays_length # [R, 3] # Normalize direction
255 | viewdirs = rays_d # [R, 3]
256 |
257 | # If bounds are not provided, set the beginning and end of ray as sampling bounds.
258 | if rays_near is None:
259 | rays_near = torch.zeros((rays_o.size()[0], 1), dtype=torch.float) + epsilon # [R, 1]
260 | if rays_far is None:
261 | rays_far = rays_length # [R, 1]
262 |
263 | ray_batch = torch.cat((rays_o, rays_d, rays_near, rays_far), dim=-1)
264 | if use_viewdirs:
265 | ray_batch = torch.cat((ray_batch, viewdirs), dim=-1)
266 | ray_batch = torch.cat((ray_batch, rays_i), dim=-1)
267 | return ray_batch
268 |
269 |
270 | def sample_random_lightdirs(num_rays, num_samples, upper_only=False):
271 | """Randomly sample directions in the unit sphere.
272 |
273 | Args:
274 | num_rays: int or tensor shape dimension. Number of rays.
275 | num_samples: int or tensor shape dimension. Number of samples per ray.
276 | upper_only: bool. Whether to sample only on the upper hemisphere.
277 |
278 | Returns:
279 | lightdirs: [R, S, 3] float tensor. Random light directions sampled from the unit
280 | sphere for each sampled point.
281 | """
282 | if upper_only:
283 | min_z = 0
284 | else:
285 | min_z = -1
286 |
287 | phi = torch.rand(num_rays, num_samples) * (2 * math.pi) # [R, S]
288 | cos_theta = torch.rand(num_rays, num_samples) * (1 - min_z) + min_z # [R, S]
289 | theta = torch.acos(cos_theta) # [R, S]
290 |
291 | x = torch.sin(theta) * torch.cos(phi)
292 | y = torch.sin(theta) * torch.sin(phi)
293 | z = torch.cos(theta)
294 |
295 | lightdirs = torch.cat((x[..., None], y[..., None], z[..., None]), dim=-1) # [R, S, 3]
296 | return lightdirs
297 |
298 |
299 | def get_light_positions(rays_i, img_light_pos):
300 | """Extracts light positions given scene IDs.
301 |
302 | Args:
303 | rays_i: [R, 1] float tensor. Per-ray image IDs.
304 | img_light_pos: [N, 3] float tensor. Per-image light positions.
305 |
306 | Returns:
307 | rays_light_pos: [R, 3] float tensor. Per-ray light positions.
308 | """
309 | #print("img_light_pos shape: ", img_light_pos.shape)
310 | rays_light_pos = img_light_pos[rays_i.long()].squeeze() # [R, 3]
311 | return rays_light_pos
312 |
313 |
314 | def get_lightdirs(lightdirs_method, num_rays=None, num_samples=None, rays_i=None,
315 | metadata=None, ray_batch=None, use_viewdirs=False, normalize=False):
316 | """Compute lightdirs.
317 |
318 | Args:
319 | lightdirs_method: str. Method to use for computing lightdirs.
320 | num_rays: int or tensor shape dimension. Number of rays.
321 | num_samples: int or tensor shape dimension. Number of samples per ray.
322 | rays_i: [R, 1] float tensor. Ray image IDs.
323 | metadata: [N, 3] float tensor. Metadata about each image. Currently only light
324 | position is provided.
325 | ray_batch: [R, M] float tensor. Ray batch.
326 | use_viewdirs: bool. Whether to use viewdirs.
327 | normalize: bool. Whether to normalize lightdirs.
328 |
329 | Returns:
330 | lightdirs: [R, S, 3] float tensor. Light directions for each sample.
331 | """
332 | if lightdirs_method == 'viewdirs':
333 | raise NotImplementedError
334 | assert use_viewdirs
335 | lightdirs = ray_batch[:, 8:11] # [R, 3]
336 | lightdirs *= 1.5
337 | lightdirs = torch.tile(lightdirs[:, None, :], (1, num_samples, 1))
338 | elif lightdirs_method == 'metadata':
339 | lightdirs = get_light_positions(rays_i, metadata) # [R, 3]
340 | lightdirs = torch.tile(lightdirs[:, None, :], (1, num_samples, 1)) # [R, S, 3]
341 | elif lightdirs_method == 'random':
342 | lightdirs = sample_random_lightdirs(num_rays, num_samples) # [R, S, 3]
343 | elif lightdirs_method == 'random_upper':
344 | lightdirs = sample_random_lightdirs(num_rays, num_samples, upper_only=True) # [R, S, 3]
345 | else:
346 | raise ValueError(f'Invalid lightdirs_method: {lightdirs_method}.')
347 | if normalize:
348 | lightdirs_flat = lightdirs.view(-1, 3) # [RS, 3]
349 | lightdirs_flat = normalize_rays(lightdirs_flat) # [RS, 3]
350 | lightdirs = lightdirs_flat.view(lightdirs.size()) # [R, S, 3]
351 | return lightdirs
352 |
--------------------------------------------------------------------------------
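The slicing used throughout this file (ray_batch[:, 0:3] origins, 3:6 directions, 6:8 near/far bounds, 8:11 viewdirs, last column image id) implies a [R, 12] layout when use_viewdirs is set. A tiny standalone PyTorch sketch mirroring the concatenation order of create_ray_batch, as a reference when reading transform_rays:

import torch

# Hypothetical tiny batch: R = 4 rays from the origin towards random destinations,
# built in the same order as create_ray_batch with use_viewdirs=True.
R = 4
rays_o = torch.zeros(R, 3)
rays_dst = torch.randn(R, 3)
rays_i = torch.arange(R, dtype=torch.float).unsqueeze(1)   # per-ray image id

rays_d = rays_dst - rays_o
rays_length = torch.norm(rays_d, dim=-1, keepdim=True)
rays_d = rays_d / rays_length
rays_near = torch.full((R, 1), 1e-10)
rays_far = rays_length
viewdirs = rays_d

ray_batch = torch.cat((rays_o, rays_d, rays_near, rays_far, viewdirs, rays_i), dim=-1)
print(ray_batch.shape)           # torch.Size([4, 12])
print(ray_batch[:, 0:3].shape)   # origins, the slice transform_rays updates first
print(ray_batch[:, 8:11].shape)  # viewdirs, the slice rotated into the object frame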
/taxim_render.py:
--------------------------------------------------------------------------------
1 | '''
2 | GelSight tactile render with taxim
3 |
4 | Zilin Si (zsi@andrew.cmu.edu)
5 | Last revision: March 2022
6 | '''
7 |
8 | import os
9 | from os import path as osp
10 | import numpy as np
11 | import cv2
12 |
13 | from basics import sensorParams as psp
14 | from basics.CalibData import CalibData
15 |
16 | class TaximRender:
17 |
18 | def __init__(self, calib_path):
19 | # taxim calibration files
20 | # polytable
21 | calib_data = osp.join(calib_path, "polycalib.npz")
22 | self.calib_data = CalibData(calib_data)
23 | # raw calibration data
24 | rawData = osp.join(calib_path, "dataPack.npz")
25 | data_file = np.load(rawData, allow_pickle=True)
26 | self.f0 = data_file['f0']
27 | ## tactile image config
28 | bins = psp.numBins
29 | [xx, yy] = np.meshgrid(range(psp.w), range(psp.h))
30 | xf = xx.flatten()
31 | yf = yy.flatten()
32 | self.A = np.array([xf*xf,yf*yf,xf*yf,xf,yf,np.ones(psp.h*psp.w)]).T
33 | binm = bins - 1
34 | self.x_binr = 0.5*np.pi/binm # x [0,pi/2]
35 | self.y_binr = 2*np.pi/binm # y [-pi, pi]
36 |
37 | # load depth bg
38 | self.bg_depth = np.load(osp.join(calib_path,"depth_bg.npy"), allow_pickle=True)
39 | # load tactile bg
40 | self.real_bg = np.load(osp.join(calib_path,"real_bg.npy"), allow_pickle=True)
41 |
42 | def correct_height_map(self, height_map):
43 | # move the center of depth to the origin
44 | height_map = (height_map-psp.cam2gel) * -1000 / psp.pixmm
45 | return height_map
46 |
47 | def padding(self, img):
48 | # pad one row & one col on each side
49 | if len(img.shape) == 2:
50 | return np.pad(img, ((1, 1), (1, 1)), 'edge')
51 | elif len(img.shape) == 3:
52 | return np.pad(img, ((1, 1), (1, 1), (0, 0)), 'edge')
53 |
54 | def generate_normals(self, height_map):
55 | # from height map to gradient magnitude & directions
56 |
57 | [h,w] = height_map.shape
58 | center = height_map[1:h-1,1:w-1] # z(x,y)
59 | top = height_map[0:h-2,1:w-1] # z(x-1,y)
60 | bot = height_map[2:h,1:w-1] # z(x+1,y)
61 | left = height_map[1:h-1,0:w-2] # z(x,y-1)
62 | right = height_map[1:h-1,2:w] # z(x,y+1)
63 | dzdx = (bot-top)/2.0
64 | dzdy = (right-left)/2.0
65 |
66 | mag_tan = np.sqrt(dzdx**2 + dzdy**2)
67 | grad_mag = np.arctan(mag_tan)
68 | invalid_mask = mag_tan == 0
69 | valid_mask = ~invalid_mask
70 | grad_dir = np.zeros((h-2,w-2))
71 | grad_dir[valid_mask] = np.arctan2(dzdx[valid_mask]/mag_tan[valid_mask], dzdy[valid_mask]/mag_tan[valid_mask])
72 |
73 | grad_mag = self.padding(grad_mag)
74 | grad_dir = self.padding(grad_dir)
75 | return grad_mag, grad_dir
76 |
77 | def render(self, depth, press_depth):
78 |
79 | depth = self.correct_height_map(depth)
80 | height_map = depth.copy()
81 |
82 | ## generate contact mask
83 | pressing_height_pix = press_depth * 1000 / psp.pixmm
84 | contact_mask = (height_map-(self.bg_depth)) > pressing_height_pix * 0.2
85 |
86 | # smooth out the soft contact
87 | zq_back = height_map.copy()
88 | kernel_size = [11,5]
89 | for k in range(len(kernel_size)):
90 | height_map = cv2.GaussianBlur(height_map.astype(np.float32),(kernel_size[k],kernel_size[k]),0)
91 | height_map[contact_mask] = zq_back[contact_mask]
92 | # height_map = cv2.GaussianBlur(height_map.astype(np.float32),(5,5),0)
93 |
94 | # generate gradients
95 | grad_mag, grad_dir = self.generate_normals(height_map)
96 |
97 | # simulate raw image
98 | sim_img_r = np.zeros((psp.h,psp.w,3))
99 | idx_x = np.floor(grad_mag/self.x_binr).astype('int')
100 | idx_y = np.floor((grad_dir+np.pi)/self.y_binr).astype('int')
101 |
102 | params_r = self.calib_data.grad_r[idx_x,idx_y,:]
103 | params_r = params_r.reshape((psp.h*psp.w), params_r.shape[2])
104 | params_g = self.calib_data.grad_g[idx_x,idx_y,:]
105 | params_g = params_g.reshape((psp.h*psp.w), params_g.shape[2])
106 | params_b = self.calib_data.grad_b[idx_x,idx_y,:]
107 | params_b = params_b.reshape((psp.h*psp.w), params_b.shape[2])
108 |
109 | est_r = np.sum(self.A * params_r,axis = 1)
110 | est_g = np.sum(self.A * params_g,axis = 1)
111 | est_b = np.sum(self.A * params_b,axis = 1)
112 | sim_img_r[:,:,0] = est_r.reshape((psp.h,psp.w))
113 | sim_img_r[:,:,1] = est_g.reshape((psp.h,psp.w))
114 | sim_img_r[:,:,2] = est_b.reshape((psp.h,psp.w))
115 |
116 | # add background
117 | tactile_img = sim_img_r + self.real_bg
118 | tactile_img = np.clip(tactile_img, 0, 255)
119 |
120 | return height_map, contact_mask, tactile_img
121 |
122 | #if __name__ == "__main__":
123 |
124 | # define the press depth, and get the depth map from touch net.
125 | # taxim = TaximRender(calib_path)
126 | # height_map, contact_mask, tactile_img = taxim.render(depth, press_depth)
127 |
--------------------------------------------------------------------------------
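The render method above is essentially a per-pixel polynomial lookup: each pixel's (gradient magnitude, gradient direction) pair indexes a bin of the calibration table, and the stored coefficients are dotted with the quadratic pixel basis [x^2, y^2, x*y, x, y, 1] built in __init__. A toy NumPy sketch of that shading step; the sensor size, bin count, and coefficient table below are made up, whereas the real ones come from basics/sensorParams.py and calibs/polycalib.npz:

import numpy as np

# Toy sensor size and bin count standing in for basics.sensorParams and polycalib.npz.
h, w, bins, n_coeff = 32, 48, 30, 6

xx, yy = np.meshgrid(range(w), range(h))
xf, yf = xx.flatten(), yy.flatten()
A = np.array([xf * xf, yf * yf, xf * yf, xf, yf, np.ones(h * w)]).T   # [h*w, 6] pixel basis

grad_r = np.random.rand(bins, bins, n_coeff)         # fake per-bin polynomial coefficients (one channel)
grad_mag = np.random.rand(h, w) * (np.pi / 2)        # fake gradient magnitudes in [0, pi/2]
grad_dir = np.random.rand(h, w) * 2 * np.pi - np.pi  # fake gradient directions in [-pi, pi]

x_binr = 0.5 * np.pi / (bins - 1)
y_binr = 2 * np.pi / (bins - 1)
idx_x = np.floor(grad_mag / x_binr).astype(int)
idx_y = np.floor((grad_dir + np.pi) / y_binr).astype(int)

params_r = grad_r[idx_x, idx_y, :].reshape(h * w, n_coeff)   # look up coefficients per pixel
est_r = np.sum(A * params_r, axis=1).reshape(h, w)           # evaluate the polynomial per pixel
print(est_r.shape)  # (32, 48)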
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import signal
3 | import time
4 | from collections import deque, defaultdict
5 | from itertools import product
6 | import numpy as np
7 | import argparse
8 | import yaml
9 | import run_nerf_helpers
10 | import torch
11 | from tqdm import tqdm
12 | import lpips
13 |
14 | class CameraIntrinsics:
15 | def __init__(self, H, W, fx, fy, cx, cy):
16 | self.H = H
17 | self.W = W
18 | self.fx = fx
19 | self.fy = fy
20 | self.cx = cx
21 | self.cy = cy
22 |
23 | def get_random_points_inside_domain(num_points, domain_min, domain_max):
24 | x = np.random.uniform(domain_min[0], domain_max[0], size=(num_points,))
25 | y = np.random.uniform(domain_min[1], domain_max[1], size=(num_points,))
26 | z = np.random.uniform(domain_min[2], domain_max[2], size=(num_points,))
27 | return np.column_stack((x, y, z))
28 |
29 | def get_random_directions(num_samples):
30 | random_directions = np.random.randn(num_samples, 3)
31 | random_directions /= np.linalg.norm(random_directions, axis=1).reshape(-1, 1)
32 | return random_directions
33 | # return 2.5 * random_directions
34 |
35 | def get_random_lights(num_samples):
36 | random_lights = np.random.randn(num_samples, 3)
37 | random_lights /= np.linalg.norm(random_lights, axis=1).reshape(-1, 1)
38 | return random_lights
39 |
40 | def load_pretrained_nerf_model(dev, cfg):
41 | pretrained_cfg = load_yaml_as_dict(cfg['pretrained_cfg_path'])
42 | if 'use_initialization_fix' not in pretrained_cfg:
43 | pretrained_cfg['use_initialization_fix'] = False
44 | if 'num_importance_samples_per_ray' not in pretrained_cfg:
45 | pretrained_cfg['num_importance_samples_per_ray'] = 0
46 | pretrained_nerf, embed_fn, embeddirs_fn, embedlights_fn = create_nerf(pretrained_cfg)
47 | pretrained_nerf = pretrained_nerf.to(dev)
48 | checkpoint = torch.load(cfg['pretrained_checkpoint_path'])
49 | pretrained_nerf.load_state_dict(checkpoint['model_state_dict'])
50 | pretrained_nerf = run_nerf_helpers.ChainEmbeddingAndModel(pretrained_nerf.model_coarse, embed_fn, embeddirs_fn, embedlights_fn) # pos. encoding
51 | return pretrained_nerf
52 |
53 | def create_nerf(cfg):
54 | embed_fn, input_ch = run_nerf_helpers.get_embedder(cfg['num_frequencies'], 0)
55 | embeddirs_fn, input_ch_views = run_nerf_helpers.get_embedder(cfg['num_frequencies_direction'], 0)
56 | embedlights_fn, input_ch_lights = run_nerf_helpers.get_embedder(cfg['num_frequencies_direction'], 0)
57 | output_ch = 4
58 | skips = [cfg['refeed_position_index']]
59 | model = run_nerf_helpers.NeRF2(D=cfg['num_hidden_layers'], W=cfg['hidden_layer_size'],
60 | input_ch=input_ch, output_ch=output_ch, skips=skips,
61 | input_ch_views=input_ch_views, input_ch_lights=input_ch_lights,
62 | use_viewdirs=True, use_lightdirs=True,
63 | direction_layer_size=cfg['direction_layer_size'], use_initialization_fix=cfg['use_initialization_fix'])
64 |
65 | if cfg['num_importance_samples_per_ray'] > 0:
66 | model_fine = run_nerf_helpers.NeRF2(D=cfg['num_hidden_layers'], W=cfg['hidden_layer_size'],
67 | input_ch=input_ch, output_ch=output_ch, skips=skips,
68 | input_ch_views=input_ch_views, input_ch_lights=input_ch_lights,
69 | use_viewdirs=True, use_lightdirs=True,
70 | direction_layer_size=cfg['direction_layer_size'], use_initialization_fix=cfg['use_initialization_fix'])
71 | model = run_nerf_helpers.CoarseAndFine(model, model_fine)
72 |
73 | return model, embed_fn, embeddirs_fn, embedlights_fn
74 |
75 | def query_densities(points, pretrained_nerf, cfg, dev):
76 | mock_directions = torch.zeros_like(points) # density does not depend on direction
77 | points_and_dirs = torch.cat([points, mock_directions], dim=1)
78 | num_points_and_dirs = points_and_dirs.size(0)
79 | densities = torch.empty(num_points_and_dirs)
80 | if 'query_batch_size' in cfg:
81 | query_batch_size = cfg['query_batch_size']
82 | else:
83 | query_batch_size = num_points_and_dirs
84 | with torch.no_grad():
85 | start = 0
86 | while start < num_points_and_dirs:
87 | end = min(start + query_batch_size, num_points_and_dirs)
88 | densities[start:end] = torch.relu(pretrained_nerf(points_and_dirs[start:end].to(dev))[:, -1]).cpu() # Only select the densities (A) from NeRF's RGBA output
89 | start = end
90 | return densities
91 |
92 | def has_flag(cfg, name):
93 | return name in cfg and cfg[name]
94 |
95 | def load_yaml_as_dict(path):
96 | with open(path) as yaml_file:
97 | yaml_as_dict = yaml.load(yaml_file, Loader=yaml.FullLoader)
98 | return yaml_as_dict
99 |
100 | def parse_args_and_init_logger(default_cfg_path=None, parse_render_cfg_path=False):
101 | parser = argparse.ArgumentParser(description='NeRF distillation')
102 | parser.add_argument('cfg_path', type=str)
103 | parser.add_argument('log_path', type=str, nargs='?')
104 | if parse_render_cfg_path:
105 | parser.add_argument('-rcfg', '--render_cfg_path', type=str)
106 | args = parser.parse_args()
107 | if args.log_path is None:
108 | start = args.cfg_path.find('/')
109 | end = args.cfg_path.rfind('.')
110 | args.log_path = 'logs' + args.cfg_path[start:end]
111 | print('auto log path:', args.log_path)
112 |
113 | create_directory_if_not_exists(args.log_path)
114 | Logger.filename = args.log_path + '/log.txt'
115 |
116 | cfg = load_yaml_as_dict(args.cfg_path)
117 | if default_cfg_path is not None:
118 | default_cfg = load_yaml_as_dict(default_cfg_path)
119 | for key in default_cfg:
120 | if key not in cfg:
121 | cfg[key] = default_cfg[key]
122 | print(cfg)
123 |
124 | ret_val = (cfg, args.log_path)
125 | if parse_render_cfg_path:
126 | ret_val += (args.render_cfg_path,)
127 |
128 | return ret_val
129 |
130 | class IterativeMean:
131 | def __init__(self):
132 | self.value = None
133 | self.num_old_values = 0
134 |
135 | def add_values(self, new_values):
136 | if self.value is not None:  # a running mean of exactly 0.0 is valid and must not reset the accumulator
137 | self.value = (self.num_old_values * self.value + new_values.size(0) * new_values.mean()) / (self.num_old_values + new_values.size(0))
138 | else:
139 | self.value = new_values.mean()
140 | self.num_old_values += new_values.size(0)
141 |
142 | def get_mean(self):
143 | return self.value.item()
144 |
145 |
146 | def create_directory_if_not_exists(directory):
147 | if not os.path.isdir(directory):
148 | os.makedirs(directory)
149 |
150 | class Logger:
151 | filename = None
152 |
153 | @staticmethod
154 | def write(text):
155 | with open(Logger.filename, 'a') as log_file:
156 | print(text, flush=True)
157 | log_file.write(text + '\n')
158 |
159 | class GracefulKiller:
160 | kill_now = False
161 | def __init__(self):
162 | signal.signal(signal.SIGUSR1, self.exit_gracefully)
163 |
164 | def exit_gracefully(self, signum, frame):
165 | self.kill_now = True
166 |
167 | def extract_domain_boxes_from_tree(root_node):
168 | nodes_to_process = deque([root_node])
169 | boxes = []
170 | while nodes_to_process:
171 | node = nodes_to_process.popleft()
172 | if hasattr(node, 'leq_child'):
173 | nodes_to_process.append(node.leq_child)
174 | nodes_to_process.append(node.gt_child)
175 | else:
176 | boxes.append([node.domain_min, node.domain_max])
177 |
178 | return boxes
179 |
180 | def write_boxes_to_obj(boxes, obj_filename):
181 | txt = ''
182 | i = 0
183 | for box in tqdm(boxes):
184 | for min_or_max in product(range(2), repeat=3):
185 | txt += 'v {} {} {}\n'.format(box[min_or_max[0]][0], box[min_or_max[1]][1], box[min_or_max[2]][2])
186 | for x, y, z in [(0b000, 0b100, 0b010), (0b100, 0b010, 0b110),
187 | (0b001, 0b101, 0b011), (0b101, 0b011, 0b111),
188 | (0b000, 0b010, 0b001), (0b001, 0b011, 0b010),
189 | (0b100, 0b110, 0b101), (0b101, 0b111, 0b110),
190 | (0b000, 0b100, 0b001), (0b100, 0b101, 0b001),
191 | (0b010, 0b110, 0b011), (0b110, 0b111, 0b011)]:
192 | txt += 'f {} {} {}\n'.format(1 + i * 8 + x, 1 + i * 8 + y, 1 + i * 8 + z)
193 | i += 1
194 |
195 | with open(obj_filename, 'a') as obj_file:
196 | obj_file.write(txt)
197 |
198 | class PerfMonitor:
199 | events = []
200 | is_active = True
201 |
202 | @staticmethod
203 | def add(name, groups=[]):
204 | if PerfMonitor.is_active:
205 | torch.cuda.synchronize()
206 | t = time.perf_counter()
207 | PerfMonitor.events.append((name, t, groups))
208 |
209 | @staticmethod
210 | def log_and_reset(write_detailed_log):
211 | previous_t = PerfMonitor.events[0][1]
212 | group_map = defaultdict(float)
213 | elapsed_times = []
214 | for name, t, groups in PerfMonitor.events[1:]:
215 | elapsed_time = t - previous_t
216 | elapsed_times.append(elapsed_time)
217 | for group in groups:
218 | group_map[group] += elapsed_time
219 | group_map['total'] += elapsed_time
220 | previous_t = t
221 | max_length = max([len(name) for name, _, _ in PerfMonitor.events] + [len(group) for group in group_map])
222 |
223 | if write_detailed_log:
224 | for event, elapsed_time in zip(PerfMonitor.events[1:], elapsed_times):
225 | name = event[0]
226 | extra_whitespace = ' ' * (max_length - len(name))
227 | Logger.write('{}:{} {:7.2f} ms'.format(name, extra_whitespace, 1000 * (elapsed_time)))
228 | Logger.write('')
229 | for group in group_map:
230 | extra_whitespace = ' ' * (max_length - len(group))
231 | Logger.write('{}:{} {:7.2f} ms'.format(group, extra_whitespace, 1000 * (group_map[group])))
232 |
233 | # Reset
234 | PerfMonitor.events = []
235 |
236 | return group_map['total']
237 |
238 | class LPIPS:
239 | loss_fn_alex = None
240 |
241 | @staticmethod
242 | def calculate(img_a, img_b):
243 | img_a, img_b = [img.permute([2, 1, 0]).unsqueeze(0) for img in [img_a, img_b]]
244 | if LPIPS.loss_fn_alex is None: # lazy init
245 | LPIPS.loss_fn_alex = lpips.LPIPS(net='alex', version='0.1')
246 | return LPIPS.loss_fn_alex(img_a, img_b)
247 |
248 |
249 | def get_distance_to_closest_point_in_box(point, domain_min, domain_max):
250 | closest_point = np.array([0., 0., 0.])
251 | for dim in range(3):
252 | if point[dim] < domain_min[dim]:
253 | closest_point[dim] = domain_min[dim]
254 | elif domain_max[dim] < point[dim]:
255 | closest_point[dim] = domain_max[dim]
256 | else: # in between domain_min and domain_max
257 | closest_point[dim] = point[dim]
258 | return np.linalg.norm(point - closest_point)
259 |
260 | def get_distance_to_furthest_point_in_box(point, domain_min, domain_max):
261 | furthest_point = np.array([0., 0., 0.])
262 | for dim in range(3):
263 | mid = (domain_min[dim] + domain_max[dim]) / 2
264 | if point[dim] > mid:
265 | furthest_point[dim] = domain_min[dim]
266 | else:
267 | furthest_point[dim] = domain_max[dim]
268 | return np.linalg.norm(point - furthest_point)
269 |
270 | def load_matrix(path):
271 | return np.array([[float(w) for w in line.strip().split()] for line in open(path)]).astype(np.float32)
272 |
273 | class ConfigManager:
274 | global_domain_min = None
275 | global_domain_max = None
276 |
277 | @staticmethod
278 | def init(cfg):
279 | if 'global_domain_min' in cfg and 'global_domain_max' in cfg:
280 | ConfigManager.global_domain_min = cfg['global_domain_min']
281 | ConfigManager.global_domain_max = cfg['global_domain_max']
282 | elif 'dataset_dir' in cfg and cfg['dataset_type'] == 'nsvf':
283 | bbox_path = os.path.join(cfg['dataset_dir'], 'bbox.txt')
284 | bounding_box = load_matrix(bbox_path)[0, :-1]
285 | ConfigManager.global_domain_min = bounding_box[:3]
286 | ConfigManager.global_domain_max = bounding_box[3:]
287 |
288 | @staticmethod
289 | def get_global_domain_min_and_max(device=None):
290 | result = ConfigManager.global_domain_min, ConfigManager.global_domain_max
291 | if device:
292 | result = [torch.tensor(x, dtype=torch.float, device=device) for x in result]
293 | return result
294 |
295 |
296 | def main():
297 | if False:
298 | boxes = [
299 | [[-0.078125, 0.390625, 0.546875], [-0.0625, 0.40625, 0.5625]],
300 | [[-0.625, -0.375, -0.375], [-0.5, -0.25, -0.25]],
301 | [[-0.625, -0.25, -0.375], [-0.5, -0.125, -0.25]],
302 | [[-0.625, -0.125, -0.375], [-0.5, 0.0, -0.25]],
303 | [[-0.5, 0.5, 0.0], [-0.375, 0.625, 0.25]],
304 | [[-0.125, 0.125, -0.5], [0.0, 0.25, -0.25]],
305 | [[-0.375, -0.25, 0.0], [-0.25, -0.125, 0.25]],
306 | [[-0.625, -0.5, -0.5], [-0.5, -0.375, -0.25]],
307 | [[-0.125, 0.0, 0.5], [0.0, 0.25, 0.75]]
308 | ]
309 | boxes = [[[0.15625, -0.3125, 0.8125], [0.1875, -0.25, 0.875]]]
310 | print(boxes)
311 | write_boxes_to_obj(boxes, 'hard_domains_2.obj')
312 |
313 | if __name__ == '__main__':
314 | main()
315 |
--------------------------------------------------------------------------------
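IterativeMean above folds successive batches into a running mean weighted by batch size. A standalone check of that update rule against a direct mean over the concatenated batches (written without the class so it runs without this module's heavier imports):

import torch

# Running mean over batches, weighted by batch size (the IterativeMean update rule).
value, num_old = None, 0
batches = [torch.randn(5), torch.randn(17), torch.randn(3)]
for b in batches:
    if value is None:
        value = b.mean()
    else:
        value = (num_old * value + b.numel() * b.mean()) / (num_old + b.numel())
    num_old += b.numel()

direct = torch.cat(batches).mean()
print(torch.allclose(value, direct))  # True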
/von_mises.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | '''
4 | Pick a point uniformly from the unit circle
5 | '''
6 | def circle_uniform_pick(size, out = None):
7 | if out is None:
8 | out = np.empty((size, 2))
9 |
10 | angle = 2 * np.pi * np.random.random(size)
11 | out[:,0], out[:,1] = np.cos(angle), np.sin(angle)
12 |
13 | return out
14 |
15 | def cross_product_matrix_batched(U):
16 | batch_size = U.shape[0]
17 | result = np.zeros(shape=(batch_size, 3, 3))
18 | result[:, 0, 1] = -U[:, 2]
19 | result[:, 0, 2] = U[:, 1]
20 | result[:, 1, 0] = U[:, 2]
21 | result[:, 1, 2] = -U[:, 0]
22 | result[:, 2, 0] = -U[:, 1]
23 | result[:, 2, 1] = U[:, 0]
24 | return result
25 |
26 | '''
27 | Von Mises-Fisher distribution, i.e. isotropic Gaussian distribution defined over
28 | a sphere.
29 | mus => mean directions
30 | kappa => concentration
31 |
32 | Uses numerical tricks described in "Numerically stable sampling of the von
33 | Mises Fisher distribution on S2 (and other tricks)" by Wenzel Jakob
34 | '''
35 |
36 | def sample_von_mises_3d(mus, kappa, out=None):
37 | size = mus.shape[0]
38 |
39 | # Generate the samples for mu=(0, 0, 1)
40 | eta = np.random.random(size)
41 | tmp = 1. - (((eta - 1.) / eta) * np.exp(-2. * kappa))
42 | W = 1. + (np.log(eta) + np.log(tmp)) / kappa
43 |
44 | V = np.empty((size, 2))
45 | circle_uniform_pick(size, out = V)
46 | V *= np.sqrt(1. - W ** 2)[:, None]
47 |
48 | if out is None:
49 | out = np.empty((size, 3))
50 |
51 | out[:, 0], out[:, 1], out[:, 2] = V[:, 0], V[:, 1], W
52 |
53 | angles = np.arccos(mus[:, 2])
54 | mask = angles != 0.
55 | angles = angles[mask]
56 | mus = mus[mask]
57 |
58 | axis = np.zeros(shape=mus.shape)
59 | axis[:, 0] = -mus[:, 1]
60 | axis[:, 1] = mus[:, 0]
61 |
62 | axis /= np.sqrt(np.sum(axis ** 2, axis=1))[:, None]
63 | rot = np.cos(angles)[:, None, None] * np.identity(3)[None, :, :]
64 | rot += np.sin(angles)[:, None, None] * cross_product_matrix_batched(axis)
65 | rot += (1. - np.cos(angles))[:, None, None] * np.matmul(axis[:, :, None], axis[:, None, :])
66 |
67 | out[mask] = (rot @ out[mask, :, None])[:, :, 0]
68 | return out
69 |
70 |
71 | if __name__ == '__main__':
72 | from math import sqrt
73 | mus = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1],
74 | [1 / sqrt(2), -1 / sqrt(2), 0],
75 | [-1 / sqrt(2), -1 / sqrt(2), 0],
76 | [-1 / sqrt(2), 1 / sqrt(2), 0],
77 | [-1 / sqrt(2), 0., -1 / sqrt(2)],
78 | [1 / sqrt(2), 0., 1 / sqrt(2)]])
79 | print(mus)
80 |
81 | sampled = sample_von_mises_3d(mus, 100000)
82 | print(sampled)
83 |
84 |
--------------------------------------------------------------------------------
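A sanity check for sample_von_mises_3d, assuming von_mises.py is on the import path: for a von Mises-Fisher distribution on the sphere, the expected dot product between a sample and its mean direction is coth(kappa) - 1/kappa, so the empirical mean cosine should match that value closely.

import numpy as np
from von_mises import sample_von_mises_3d  # assumes the repository root is on sys.path

kappa = 50.0
mus = np.tile(np.array([[0.0, 0.0, 1.0]]), (20000, 1))    # 20k samples around +z
samples = sample_von_mises_3d(mus.copy(), kappa)

mean_cos = np.mean(np.sum(samples * np.array([0.0, 0.0, 1.0]), axis=1))
expected = 1.0 / np.tanh(kappa) - 1.0 / kappa              # coth(kappa) - 1/kappa
print(mean_cos, expected)                                   # both ~0.98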