├── .havenignore
├── src
│   ├── __init__.py
│   ├── datasets
│   │   ├── utils.py
│   │   ├── color_matching.py
│   │   ├── extract_baselines.py
│   │   ├── preprocess_Waymo.py
│   │   ├── __init__.py
│   │   └── waymo_od.py
│   ├── renderer
│   │   ├── metrics.py
│   │   └── losses.py
│   ├── pointLF
│   │   ├── feature_mapping.py
│   │   ├── pointLF_helper.py
│   │   ├── ptlf_vis.py
│   │   ├── light_field_renderer.py
│   │   ├── layer.py
│   │   ├── pointcloud_encoding
│   │   │   ├── simpleview.py
│   │   │   └── pointnet_features.py
│   │   ├── scene_point_lightfield.py
│   │   ├── attention_modules.py
│   │   └── icp
│   │       └── pts_registration.py
│   ├── scenes
│   │   ├── nodes.py
│   │   ├── raysampler
│   │   │   ├── frustum_helpers.py
│   │   │   └── rayintersection.py
│   │   └── init_detection.py
│   ├── utils_dist.py
│   └── utils.py
├── scripts
│   ├── vis.gif
│   └── tst_waymo.py
├── exp_configs
│   ├── __init__.py
│   └── pointLF_exps.py
├── LICENSE
├── .gitignore
├── README.md
└── trainval.py
/.havenignore:
--------------------------------------------------------------------------------
1 | .tmp
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/scripts/vis.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-computational-imaging/neural-point-light-fields/HEAD/scripts/vis.gif
--------------------------------------------------------------------------------
/exp_configs/__init__.py:
--------------------------------------------------------------------------------
1 | from . import pointLF_exps
2 |
3 | EXP_GROUPS = {}
4 | EXP_GROUPS.update(pointLF_exps.EXP_GROUPS)
5 |
--------------------------------------------------------------------------------
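The aggregated `EXP_GROUPS` dictionary maps experiment-group names to their configurations. A minimal lookup sketch, assuming the usual haven-ai convention that each group holds a list of experiment config dicts; the group name `pointLF_waymo_server` is taken from the README:

```
from exp_configs import EXP_GROUPS

# Assumed structure: group name -> list of experiment config dicts (haven-ai convention).
exp_list = EXP_GROUPS["pointLF_waymo_server"]
for exp_dict in exp_list:
    print(exp_dict)
```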
/src/datasets/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def invert_transformation(rot, t):
4 | t = np.matmul(-rot.T, t)
5 | inv_translation = np.concatenate([rot.T, t[:, None]], axis=1)
6 | return np.concatenate([inv_translation, np.array([[0., 0., 0., 1.]])])
7 |
8 | def roty_matrix(roty):
9 | c = np.cos(roty)
10 | s = np.sin(roty)
11 | return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
12 |
13 |
14 | def rotz_matrix(rotz):
15 |     c = np.cos(rotz)
16 |     s = np.sin(rotz)
17 |     return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])
--------------------------------------------------------------------------------
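A quick sanity check for the helpers above (a sketch, assuming the repository root is on `PYTHONPATH`): `invert_transformation` returns the 4x4 inverse of a rigid transform built from a rotation and a translation.

```
import numpy as np
from src.datasets.utils import invert_transformation, roty_matrix

R = roty_matrix(0.3)                      # 3x3 rotation about the y-axis
t = np.array([1.0, 2.0, 3.0])
T = np.concatenate([np.concatenate([R, t[:, None]], axis=1),
                    np.array([[0., 0., 0., 1.]])])
# Composing the transform with its inverse yields the identity.
assert np.allclose(invert_transformation(R, t) @ T, np.eye(4))
```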
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 princeton-computational-imaging
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/renderer/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from ..scenes import NeuralScene
5 | from pytorch3d.renderer.implicit.utils import RayBundle
6 |
7 |
8 | def calc_mse(x: torch.Tensor, y: torch.Tensor, **kwargs):
9 | return torch.mean((x - y) ** 2)
10 |
11 |
12 | def calc_psnr(x: torch.Tensor, y: torch.Tensor, **kwargs):
13 | mse = calc_mse(x, y)
14 | psnr = -10.0 * torch.log10(mse)
15 | return psnr
16 |
17 |
18 | def calc_latent_dist(xycfn, scene, reg, **kwargs):
19 | trainable_latent_nodes_id = scene.getSceneIdByTypeId(list(scene.nodes['scene_object'].keys()))
20 |
21 | # Remove "non"-nodes from the rays
22 | latent_nodess_id = xycfn[..., 4].unique().tolist()
23 | try:
24 | latent_nodess_id.remove(-1)
25 |     except ValueError:
26 | pass
27 |
28 | # Just include nodes that have latent arrays
29 |     # TODO: Do that for each class separately
30 | latent_nodess_id = set(trainable_latent_nodes_id) & set(latent_nodess_id)
31 |
32 | if len(latent_nodess_id) != 0:
33 | latent_codes = torch.stack([scene.getNodeBySceneId(i).latent for i in latent_nodess_id])
34 | latent_dist = torch.sum(reg * torch.norm(latent_codes, dim=-1))
35 | else:
36 | latent_dist = torch.tensor(0.)
37 |
38 | return latent_dist
39 |
40 |
41 | def get_rgb_gt(rgb: torch.Tensor, scene: NeuralScene, xycfn: torch.Tensor):
42 | xycf = xycfn[..., -1, :4].reshape(len(rgb), -1)
43 | rgb_gt = torch.zeros_like(rgb)
44 |
45 |     # TODO: Make more efficient by not retrieving the image for each pixel,
46 |     # but storing all gt_images in a single tensor on the CPU
47 |     # TODO: During test time just get all available images
48 | for f in xycf[:, 3].unique():
49 | if f == -1:
50 | continue
51 | frame = scene.frames[int(f)]
52 | for c in xycf[:, 2].unique():
53 | cf_mask = torch.all(xycf[:, 2:] == torch.tensor([c, f], device=xycf.device), dim=1)
54 | xy = xycf[cf_mask, :2].cpu()
55 |
56 | c_id = scene.getNodeBySceneId(int(c)).type_idx
57 | gt_img = frame.images[c_id]
58 | gt_px = torch.from_numpy(gt_img[xy[:, 1], xy[:, 0]]).to(device=rgb.device, dtype=rgb.dtype)
59 | rgb_gt[cf_mask] = gt_px
60 |
61 | return rgb_gt
--------------------------------------------------------------------------------
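A toy check of the two metric helpers above (a sketch; importing the module also pulls in the scene and pytorch3d dependencies):

```
import torch
from src.renderer.metrics import calc_mse, calc_psnr

pred = torch.rand(16, 16, 3)
noisy = pred + 0.01 * torch.randn_like(pred)
print(calc_mse(pred, noisy).item())    # mean squared error, on the order of 1e-4
print(calc_psnr(pred, noisy).item())   # -10 * log10(MSE), roughly 40 dB
```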
/src/datasets/color_matching.py:
--------------------------------------------------------------------------------
1 | # https://stackoverflow.com/questions/56918877/color-match-in-images
2 | import numpy as np
3 | import cv2
4 | from skimage.io import imread, imsave
5 | from skimage import exposure
6 | from skimage.exposure import match_histograms
7 |
8 | import matplotlib.pyplot as plt
9 |
10 |
11 | # https://www.pyimagesearch.com/2014/06/30/super-fast-color-transfer-images/
12 | def color_transfer(source, target):
13 | source = cv2.normalize(source, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
14 | target = cv2.normalize(target, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
15 | source = cv2.cvtColor(source, cv2.COLOR_RGB2BGR)
16 | target = cv2.cvtColor(target, cv2.COLOR_RGB2BGR)
17 |
18 |     # convert the images to the L*a*b* color space, being
19 |     # sure to use the floating point data type (note: OpenCV
20 |     # expects floats to be 32-bit, so use that instead of 64-bit)
21 | source = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype("float32")
22 | target = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype("float32")
23 |
24 | # compute color statistics for the source and target images
25 | (lMeanSrc, lStdSrc, aMeanSrc, aStdSrc, bMeanSrc, bStdSrc) = image_stats(source)
26 | (lMeanTar, lStdTar, aMeanTar, aStdTar, bMeanTar, bStdTar) = image_stats(target)
27 | # subtract the means from the target image
28 | (l, a, b) = cv2.split(target)
29 | l -= lMeanTar
30 | # a -= aMeanTar
31 | # b -= bMeanTar
32 | # scale by the standard deviations
33 | l = (lStdTar / lStdSrc) * l
34 | # a = (aStdTar / aStdSrc) * a
35 | # b = (bStdTar / bStdSrc) * b
36 | # add in the source mean
37 | l += lMeanSrc
38 | # a += aMeanSrc
39 | # b += bMeanSrc
40 | # clip the pixel intensities to [0, 255] if they fall outside
41 | # this range
42 | l = np.clip(l, 0, 255)
43 | a = np.clip(a, 0, 255)
44 | b = np.clip(b, 0, 255)
45 | # merge the channels together and convert back to the RGB color
46 | # space, being sure to utilize the 8-bit unsigned integer data
47 | # type
48 | transfer = cv2.merge([l, a, b])
49 | transfer = cv2.cvtColor(transfer.astype("uint8"), cv2.COLOR_LAB2BGR)
50 |
51 | # return the color transferred image
52 | transfer = cv2.cvtColor(transfer, cv2.COLOR_BGR2RGB).astype("float32") / 255.
53 | return transfer
54 |
55 |
56 | def image_stats(image):
57 | # compute the mean and standard deviation of each channel
58 | (l, a, b) = cv2.split(image)
59 | (lMean, lStd) = (l.mean(), l.std())
60 | (aMean, aStd) = (a.mean(), a.std())
61 | (bMean, bStd) = (b.mean(), b.std())
62 | # return the color statistics
63 | return (lMean, lStd, aMean, aStd, bMean, bStd)
64 |
65 |
66 | def histogram_matching(source, target):
67 | matched = match_histograms(target, source, multichannel=False)
68 | return matched
--------------------------------------------------------------------------------
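Usage sketch for the two matching helpers above; the file names are hypothetical placeholders. Note that `color_transfer` as written only transfers the L-channel statistics and returns a float32 RGB image in [0, 1].

```
from skimage.io import imread
from src.datasets.color_matching import color_transfer, histogram_matching

source = imread("cam_front_0001.jpg")        # hypothetical reference image (RGB)
target = imread("cam_front_left_0001.jpg")   # hypothetical image to adjust (RGB)

adjusted = color_transfer(source, target)    # luminance statistics of `source` applied to `target`
matched = histogram_matching(source, target) # histogram of `target` matched to `source`
```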
/src/pointLF/feature_mapping.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 |
7 | # Positional Encoding
8 | class PositionalEncoding(nn.Module):
9 | def __init__(self, multires, input_dims=3, include_input=True, log_sampling=True):
10 |
11 | super().__init__()
12 | self.embed_fns = []
13 | self.out_dims = 0
14 |
15 | if include_input:
16 | self.embed_fns.append(lambda x: x)
17 | self.out_dims += input_dims
18 |
19 | max_freq = multires - 1
20 | N_freqs = multires
21 |
22 | if log_sampling:
23 | freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs)
24 | else:
25 | freq_bands = torch.linspace(2.0 ** 0.0, 2.0 ** max_freq, steps=N_freqs)
26 |
27 | for freq in freq_bands:
28 | for periodic_fn in [torch.sin, torch.cos]:
29 | self.embed_fns.append(
30 | lambda x, periodic_fn=periodic_fn, freq=freq: periodic_fn(x * freq)
31 | )
32 | self.out_dims += input_dims
33 |
34 | def forward(self, x: torch.Tensor):
35 | return torch.cat([fn(x) for fn in self.embed_fns], -1)
36 |
37 |
38 | # Positional Encoding Old (section 5.1)
39 | class Embedder:
40 | def __init__(self, **kwargs):
41 | self.kwargs = kwargs
42 | self.create_embedding_fn()
43 |
44 | def create_embedding_fn(self):
45 | embed_fns = []
46 | d = self.kwargs["input_dims"]
47 | out_dim = 0
48 | if self.kwargs["include_input"]:
49 | embed_fns.append(lambda x: x)
50 | out_dim += d
51 |
52 | max_freq = self.kwargs["max_freq_log2"]
53 | N_freqs = self.kwargs["num_freqs"]
54 |
55 | if self.kwargs["log_sampling"]:
56 | freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs)
57 | else:
58 | freq_bands = torch.linspace(2.0 ** 0.0, 2.0 ** max_freq, steps=N_freqs)
59 |
60 | for freq in freq_bands:
61 | for p_fn in self.kwargs["periodic_fns"]:
62 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
63 | out_dim += d
64 |
65 | self.embed_fns = embed_fns
66 | self.out_dim = out_dim
67 |
68 | def embed(self, inputs):
69 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
70 |
71 |
72 | def get_embedder(multires, i=0, input_dims=3):
73 | if i == -1:
74 | return nn.Identity(), input_dims
75 |
76 | embed_kwargs = {
77 | "include_input": True,
78 | "input_dims": input_dims,
79 | "max_freq_log2": multires - 1,
80 | "num_freqs": multires,
81 | "log_sampling": True,
82 | "periodic_fns": [torch.sin, torch.cos],
83 | }
84 |
85 | embedder_obj = Embedder(**embed_kwargs)
86 | embed = lambda x, eo=embedder_obj: eo.embed(x)
87 | return embed, embedder_obj.out_dim
88 |
--------------------------------------------------------------------------------
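Both encoders above produce `input_dims + input_dims * 2 * multires` output channels when the identity is included; a small shape check (sketch):

```
import torch
from src.pointLF.feature_mapping import PositionalEncoding, get_embedder

x = torch.rand(5, 3)

enc = PositionalEncoding(multires=10, input_dims=3)
print(enc(x).shape, enc.out_dims)      # torch.Size([5, 63]) 63  (3 + 3 * 2 * 10)

embed_fn, out_dim = get_embedder(multires=10, input_dims=3)
print(embed_fn(x).shape, out_dim)      # torch.Size([5, 63]) 63
```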
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | .tmp/*
6 | .tmp
7 | data/*
8 | job_config.py
9 | *.png
10 | /pretrained_models/
11 | conda-spec-file-julian.txt
12 | example_weights/*
13 | # C extensions
14 | *.so
15 | results/*
16 | .results/
17 | # Distribution / packaging
18 | .Python
19 | build/
20 | develop-eggs/
21 | dist/
22 | downloads/
23 | eggs/
24 | .eggs/
25 | lib/
26 | lib64/
27 | parts/
28 | sdist/
29 | var/
30 | wheels/
31 | pip-wheel-metadata/
32 | share/python-wheels/
33 | *.egg-info/
34 | .installed.cfg
35 | *.egg
36 | MANIFEST
37 |
38 | # PyInstaller
39 | # Usually these files are written by a python script from a template
40 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
41 | *.manifest
42 | *.spec
43 |
44 | # Installer logs
45 | pip-log.txt
46 | pip-delete-this-directory.txt
47 |
48 | # Unit test / coverage reports
49 | htmlcov/
50 | .tox/
51 | .nox/
52 | .coverage
53 | .coverage.*
54 | .cache
55 | nosetests.xml
56 | coverage.xml
57 | *.cover
58 | *.py,cover
59 | .hypothesis/
60 | .pytest_cache/
61 |
62 | # Translations
63 | *.mo
64 | *.pot
65 |
66 | # Django stuff:
67 | *.log
68 | local_settings.py
69 | db.sqlite3
70 | db.sqlite3-journal
71 |
72 | # Flask stuff:
73 | instance/
74 | .webassets-cache
75 |
76 | # Scrapy stuff:
77 | .scrapy
78 |
79 | # Sphinx documentation
80 | docs/_build/
81 |
82 | # PyBuilder
83 | target/
84 |
85 | # Jupyter Notebook
86 | .ipynb_checkpoints
87 |
88 | # IPython
89 | profile_default/
90 | ipython_config.py
91 |
92 | # pyenv
93 | .python-version
94 |
95 | # pipenv
96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
99 | # install all needed dependencies.
100 | #Pipfile.lock
101 |
102 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
103 | __pypackages__/
104 |
105 | # Celery stuff
106 | celerybeat-schedule
107 | celerybeat.pid
108 |
109 | # SageMath parsed files
110 | *.sage.py
111 |
112 | # Environments
113 | .env
114 | .venv
115 | env/
116 | venv/
117 | ENV/
118 | env.bak/
119 | venv.bak/
120 |
121 | # Spyder project settings
122 | .spyderproject
123 | .spyproject
124 |
125 | # Rope project settings
126 | .ropeproject
127 |
128 | # mkdocs documentation
129 | /site
130 |
131 | # mypy
132 | .mypy_cache/
133 | .dmypy.json
134 | dmypy.json
135 |
136 | # Pyre type checker
137 | .pyre/
138 |
139 | *.pyc
140 | *#*
141 | *.pyx
142 | *.record
143 | *.tar.gz
144 | *.tar
145 | *.swp
146 | *.npy
147 | *.ckpt*
148 | *.idea*
149 | axcana.egg-info
150 | .vscode
151 | .idea
152 | *.blg
153 | *.gz
154 | *.avi
155 | main.log
156 | main.aux
157 | main.bbl
158 | main.brf
159 | main.pdf
160 | *.log
161 | *.aux
162 | *.bbl
163 | *.brf
164 | mainSup.pdf
165 | egrebuttal.pdf
166 | .tmp/*
167 | /results/
168 | /results/
169 |
170 | /model_library/
171 | old_model/
172 | /.tmp/
173 | /src/datasets/eos_dataset_src/
174 |
--------------------------------------------------------------------------------
/src/pointLF/pointLF_helper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | def pre_scale_MV(x, translation=-1.4):
6 | batchsize = x.size()[0]
7 | scaled_x = x
8 |
9 | for k in range(batchsize):
10 | for i in range(3):
11 |             ax_max = scaled_x[k, ..., i].max()
12 |             ax_min = scaled_x[k, ..., i].sort()[0][10]
13 |             ax_size = ax_max - ax_min
14 |             scaled_x[k, ..., i] -= ax_min
15 |             scaled_x[k, ..., i] *= 2 / ax_size
16 |             scaled_x[k, ..., i] -= 1.
17 |
18 |     scaled_x = torch.minimum(scaled_x, torch.tensor([1.], device=scaled_x.device))
19 |     scaled_x = torch.maximum(scaled_x, torch.tensor([-1.], device=scaled_x.device))
20 |
21 | # scaled_x[:, 0] = torch.tensor([-1.0, -1.0, -1.0], device=scaled_x.device)
22 | # scaled_x[:, 1] = torch.tensor([1.0, -1.0, -1.0], device=scaled_x.device)
23 | # scaled_x[:, 2] = torch.tensor([-1.0, 1.0, -1.0], device=scaled_x.device)
24 | # scaled_x[:, 3] = torch.tensor([1.0, 1.0, -1.0], device=scaled_x.device)
25 | # scaled_x[:, 4] = torch.tensor([-1.0, -1.0, 1.0], device=scaled_x.device)
26 | # scaled_x[:, 5] = torch.tensor([-1.0, 1.0, 1.0], device=scaled_x.device)
27 | # scaled_x[:, 6] = torch.tensor([1.0, -1.0, 1.0], device=scaled_x.device)
28 | # scaled_x[:, 7] = torch.tensor([1.0, 1.0, 1.0], device=scaled_x.device)
29 |
30 | scaled_x *= 1 / -translation
31 |
32 | # scaled_plt = scaled_x[0].cpu().detach().numpy()
33 | # fig3d = plt.figure()
34 | # ax3d = fig3d.gca(projection='3d')
35 | # ax3d.scatter(scaled_plt[:, 0], scaled_plt[:, 1], scaled_plt[:, 2], c='blue')
36 |
37 | return scaled_x
38 |
39 |
40 | def select_Mv_feat(feature_maps, scaled_pts, closest_mask, batchsize, k_closest, feature_extractor, img_resolution=128,
41 | feature_resolution=16):
42 | feat2img_f = img_resolution // feature_resolution
43 |
44 | # n_feat_maps, batchsize, n_features, feat_heigth, feat_width = feature_maps.shape
45 | n_batch, maps_per_batch, n_features, feat_heigth, feat_width = feature_maps.shape
46 | n_feat_maps = maps_per_batch * n_batch
47 | feature_maps = feature_maps.reshape(n_feat_maps, n_features, feat_heigth, feat_width)
48 |
49 |     # Only retrieve pts_feat for relevant points
50 | masked_scaled_pts = [sc_x[mask] for (sc_x, mask) in zip(scaled_pts, closest_mask)]
51 | masked_scaled_pts = torch.stack(masked_scaled_pts).view(n_batch, -1, 3)
52 |
53 |     # Get coordinates in the feature maps for each point
54 | coordinates, coord_x, coord_y, depth = feature_extractor._get_img_coord(masked_scaled_pts, resolution=img_resolution)
55 |
56 | # Adjust for downscaled feature maps
57 | coord_x = torch.round(coord_x.view(n_feat_maps, -1, k_closest) / feat2img_f).to(torch.long)
58 | coord_x = torch.minimum(coord_x, torch.tensor([feat_heigth - 1], device=coord_x.device))
59 | coord_y = torch.round(coord_y.view(n_feat_maps, -1, k_closest) / feat2img_f).to(torch.long)
60 | coord_y = torch.minimum(coord_y, torch.tensor([feat_width - 1], device=coord_x.device))
61 |
62 | # depth = depth.view(n_batch, maps_per_batch, -1, k_closest)
63 |
64 | # Extract features for each ray and k closest points
65 | feature_maps = feature_maps.permute(0, 2, 3, 1)
66 | pts_feat = torch.stack([feature_maps[i][tuple([coord_x[i], coord_y[i]])] for i in range(n_feat_maps)])
67 | pts_feat = pts_feat.reshape(n_batch, maps_per_batch, -1, k_closest, n_features)
68 | pts_feat = pts_feat.permute(0, 2, 3, 1, 4)
69 | # Sum all pts_feat from all feature maps
70 | # pts_feat = pts_feat.sum(dim=1)
71 | # pts_feat = torch.max(pts_feat, dim=1)[0]
72 |
73 | return pts_feat
74 |
75 |
76 | def lin_weighting(z, distance, projected, my=0.9):
77 |
78 |
79 | inv_pt_ray_dist = torch.div(1, distance)
80 | pt_ray_dist_weights = inv_pt_ray_dist / torch.norm(inv_pt_ray_dist, dim=-1)[..., None]
81 |
82 | inv_proj_dist = torch.div(1, projected)
83 | pt_proj_dist_weights = inv_proj_dist / torch.norm(inv_proj_dist, dim=-1)[..., None]
84 |
85 | z = z * (my * pt_ray_dist_weights + (1-my) * pt_proj_dist_weights)[..., None, None]
86 |
87 | return torch.sum(z, dim=2)
--------------------------------------------------------------------------------
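A shape sketch for `lin_weighting`: the feature tensor is assumed to follow the layout produced by `select_Mv_feat` above (batch x rays x closest points x feature maps x channels); the inverse-distance weights broadcast over the last two dimensions and the K closest points are summed out.

```
import torch
from src.pointLF.pointLF_helper import lin_weighting

B, R, K, M, F = 2, 8, 4, 3, 32            # batch, rays, closest points, feature maps, channels
z = torch.rand(B, R, K, M, F)             # per-point features (toy values)
distance = torch.rand(B, R, K) + 0.1      # point-to-ray distances
projected = torch.rand(B, R, K) + 0.1     # distances of the projections along the ray
out = lin_weighting(z, distance, projected, my=0.9)
print(out.shape)                          # torch.Size([2, 8, 3, 32])
```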
/scripts/tst_waymo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | import matplotlib.pyplot as plt
5 | import matplotlib.patches as patches
6 | from PIL import Image
7 | import tensorflow as tf
8 | from waymo_open_dataset.utils import frame_utils
9 | from waymo_open_dataset import dataset_pb2 as open_dataset
10 | from NeuralSceneGraph.data_loader.load_waymo_od import load_waymo_od_data
11 |
12 |
13 | tf.compat.v1.enable_eager_execution()
14 |
15 | # Plot every i_plt image
16 | i_plt = 1
17 |
18 | start = 50
19 | end = 60
20 |
21 | frames_path = '/home/julian/Desktop/waymo_open/segment-9985243312780923024_3049_720_3069_720_with_camera_labels.tfrecord'
22 | basedir = '/home/julian/Desktop/waymo_open'
23 | img_dir = '/home/julian/Desktop/waymo_open/tst_01'
24 |
25 | cam_ls = ['front', 'front_left', 'front_right']
26 |
27 | records = []
28 | dir_list = os.listdir(basedir)
29 | dir_list.sort()
30 | for f in dir_list:
31 | if 'record' in f:
32 | records.append(os.path.join(basedir, f))
33 |
34 |
35 |
36 | images = load_waymo_od_data(frames_path, selected_frames=[start, end])[0]
37 | for i_record, tf_record in enumerate(records):
38 | dataset = tf.data.TFRecordDataset(tf_record, compression_type='')
39 | print(tf_record)
40 |
41 | for i, data in enumerate(dataset):
42 | if not i % i_plt:
43 | frame = open_dataset.Frame()
44 | frame.ParseFromString(bytearray(data.numpy()))
45 |
46 | for index, camera_image in enumerate(frame.images):
47 | if camera_image.name in [1, 2, 3]:
48 | img_arr = np.array(tf.image.decode_jpeg(camera_image.image))
49 | # plt.imshow(img_arr, cmap=None)
50 |
51 | cam_dir = os.path.join(img_dir, cam_ls[camera_image.name-1])
52 | im_name = 'img_' + str(i_record) + '_' + str(i) + '.jpg'
53 | im = Image.fromarray(img_arr)
54 | im.save(os.path.join(cam_dir, im_name))
55 |
56 | # frames = []
57 | # max_frames=10
58 | # i_plt = 100
59 | #
60 | # for i, data in enumerate(dataset):
61 | # if not i % i_plt:
62 | # frame = open_dataset.Frame()
63 | # frame.ParseFromString(bytearray(data.numpy()))
64 | #
65 | # for index, camera_image in enumerate(frame.images):
66 | # if camera_image.name in [1, 2, 3]:
67 | # plt.imshow(tf.image.decode_jpeg(camera_image.image), cmap=None)
68 |
69 | # layout = [3, 3, index+1]
70 | # ax = plt.subplot(*layout)
71 | #
72 | # plt.imshow(tf.image.decode_jpeg(camera_image.image), cmap=None)
73 | # plt.title(open_dataset.CameraName.Name.Name(camera_image.name))
74 | # plt.grid(False)
75 | # plt.axis('off')
76 |
77 | # frames.append(frame)
78 | # if i >= max_frames-1:
79 | # break
80 |
81 | # frame = frames[0]
82 |
83 | (range_images, camera_projections, range_image_top_pose) = (
84 | frame_utils.parse_range_image_and_camera_projection(frame))
85 |
86 | print(frame.context)
87 |
88 | def show_camera_image(camera_image, camera_labels, layout, cmap=None):
89 | """Show a camera image and the given camera labels."""
90 |
91 | ax = plt.subplot(*layout)
92 |
93 | # Draw the camera labels.
94 | for camera_labels in frame.camera_labels:
95 | # Ignore camera labels that do not correspond to this camera.
96 | if camera_labels.name != camera_image.name:
97 | continue
98 |
99 | # Iterate over the individual labels.
100 | for label in camera_labels.labels:
101 | # Draw the object bounding box.
102 | ax.add_patch(patches.Rectangle(
103 | xy=(label.box.center_x - 0.5 * label.box.length,
104 | label.box.center_y - 0.5 * label.box.width),
105 | width=label.box.length,
106 | height=label.box.width,
107 | linewidth=1,
108 | edgecolor='red',
109 | facecolor='none'))
110 |
111 | # Show the camera image.
112 | plt.imshow(tf.image.decode_jpeg(camera_image.image), cmap=cmap)
113 | plt.title(open_dataset.CameraName.Name.Name(camera_image.name))
114 | plt.grid(False)
115 | plt.axis('off')
116 |
117 | plt.figure(figsize=(25, 20))
118 |
119 | for index, image in enumerate(frame.images):
120 | show_camera_image(image, frame.camera_labels, [3, 3, index+1])
121 |
122 |
123 | a = 0
--------------------------------------------------------------------------------
/src/pointLF/ptlf_vis.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import open3d as o3d
5 | from sklearn.manifold import TSNE
6 |
7 |
8 | def plt_pts_selected_2D(output_dict, axis):
9 | selected_points = [pts[mask] for pts, mask in
10 | zip(output_dict['points_in'], output_dict['closest_mask_in'])]
11 |
12 | for l, cf in enumerate(output_dict['samples']):
13 | plt_pts = selected_points[l].reshape(-1, 3)
14 | all_pts = output_dict['points_in'][l]
15 |
16 | plt.figure()
17 | plt.scatter(all_pts[:, axis[0]], all_pts[:, axis[1]])
18 | plt.scatter(plt_pts[:, axis[0]], plt_pts[:, axis[1]])
19 |
20 | if 'points_scaled' in output_dict:
21 | selected_points = [pts[mask] for pts, mask in
22 | zip(output_dict['points_scaled'], output_dict['closest_mask_in'])]
23 | for l, cf in enumerate(output_dict['samples']):
24 | plt_pts = selected_points[l].reshape(-1, 3)
25 | all_pts = output_dict['points_scaled'][l]
26 |
27 | plt.figure()
28 | plt.scatter(all_pts[:, axis[0]], all_pts[:, axis[1]])
29 | plt.scatter(plt_pts[:, axis[0]], plt_pts[:, axis[1]])
30 |
31 | def plt_BEV_pts_selected(output_dict):
32 | plt_pts_selected_2D(output_dict, (0,1))
33 |
34 | def plt_SIDE_pts_selected(output_dict):
35 | plt_pts_selected_2D(output_dict, (0, 2))
36 |
37 | def plt_FRONT_pts_selected(output_dict):
38 | plt_pts_selected_2D(output_dict, (1, 2))
39 |
40 | def visualize_output(output_dict, selected_only=False, scaled=False, n_plt_rays=None):
41 | if scaled:
42 | pts_in = output_dict['points_scaled']
43 | else:
44 | pts_in = output_dict['points_in']
45 |
46 | masks_in = output_dict['closest_mask_in']
47 |
48 | if 'sum_mv_point_features' in output_dict:
49 | feat_per_point = output_dict['sum_mv_point_features'].squeeze()
50 | if len(feat_per_point.shape) == 4:
51 | n_batch, n_rays, n_closest_pts, feat_dim = feat_per_point.shape
52 | else:
53 | n_batch = 1
54 | n_rays, n_closest_pts, feat_dim = feat_per_point.shape
55 |
56 | for i, (pts, mask, feat) in enumerate(zip(pts_in, masks_in, feat_per_point)):
57 |
58 | # Get feature embedding for visualization
59 | feat = feat.reshape(-1, feat_dim)
60 | feat_embedded = TSNE(n_components=3).fit_transform(feat)
61 |
62 | # Transform embedded space to RGB
63 | feat_embedded = feat_embedded - feat_embedded.min(axis=0)
64 | color = feat_embedded / feat_embedded.max(axis=0)
65 |
66 | if n_plt_rays is not None:
67 | ray_ids = np.random.choice(len(mask), n_plt_rays)
68 | mask = mask[ray_ids]
69 | color = color.reshape(n_rays, n_closest_pts, 3)
70 | color = color[ray_ids].reshape(-1, 3)
71 |
72 | pts_close = pts[mask]
73 | pcd_close = get_pcd_vis(pts_close, color_vector=color)
74 | if selected_only:
75 | o3d.visualization.draw_geometries([pcd_close])
76 | else:
77 | pcd = get_pcd_vis(pts)
78 | o3d.visualization.draw_geometries([pcd, pcd_close])
79 | else:
80 | for i, (pts, mask) in enumerate(zip(pts_in, masks_in)):
81 | pts_close = pts[mask]
82 |             pcd_close = get_pcd_vis(pts_close, uniform_color=[1., 0.7, 0.])
83 |
84 | if selected_only:
85 | o3d.visualization.draw_geometries([pcd_close])
86 | else:
87 | pcd = get_pcd_vis(pts)
88 | o3d.visualization.draw_geometries([pcd, pcd_close])
89 |
90 | def get_pcd_vis(pts, uniform_color=None, color_vector=None):
91 | pts = pts.reshape(-1,3)
92 | pcd = o3d.geometry.PointCloud()
93 | pcd.points = o3d.utility.Vector3dVector(pts)
94 | if uniform_color is not None:
95 | # pcd.paint_uniform_color([1., 0.7, 0.])
96 | pcd.paint_uniform_color(uniform_color)
97 | if color_vector is not None:
98 | assert len(pts) == len(color_vector)
99 | pcd.colors = o3d.utility.Vector3dVector(color_vector)
100 |
101 | return pcd
102 |
103 |
104 |
105 | # pts_in_name = ".tmp/points_in_{}.ply".format(i)
106 | # o3d.io.write_point_cloud(pts_in_name, pcd)
107 | # pcd = o3d.io.read_point_cloud(pts_in_name)
--------------------------------------------------------------------------------
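A minimal usage sketch for `get_pcd_vis` with a random placeholder point set; the viewer call is commented out since it opens an interactive window.

```
import numpy as np
from src.pointLF.ptlf_vis import get_pcd_vis
# import open3d as o3d

pts = np.random.rand(1000, 3)                         # placeholder point cloud
pcd = get_pcd_vis(pts, uniform_color=[1., 0.7, 0.])   # orange Open3D point cloud
# o3d.visualization.draw_geometries([pcd])            # interactive viewer
```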
/src/scenes/nodes.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from pytorch3d.renderer import PerspectiveCameras
5 | from src.pointLF.point_light_field import PointLightField
6 | from pytorch3d.transforms.rotation_conversions import euler_angles_to_matrix
7 |
8 | DEVICE = "cpu"
9 | TRAINCAM = False
10 | TRAINOBJSZ = False
11 | TRAINOBJ = False
12 |
13 |
14 | class Intrinsics:
15 | def __init__(
16 | self,
17 | H,
18 | W,
19 | focal,
20 | P=None
21 | ):
22 | self.H = int(H)
23 | self.W = int(W)
24 | if np.size(focal) == 1:
25 | self.f_x = nn.Parameter(
26 | torch.tensor(focal, device=DEVICE, requires_grad=TRAINCAM)
27 | )
28 | self.f_y = nn.Parameter(
29 | torch.tensor(focal, device=DEVICE, requires_grad=TRAINCAM)
30 | )
31 | else:
32 | self.f_x = nn.Parameter(
33 | torch.tensor(focal[0], device=DEVICE, requires_grad=TRAINCAM)
34 | )
35 | self.f_y = nn.Parameter(
36 | torch.tensor(focal[1], device=DEVICE, requires_grad=TRAINCAM)
37 | )
38 |
39 | self.P = P
40 |
41 |
42 | class NeuralCamera(PerspectiveCameras):
43 | def __init__(self, h, w, f, intrinsics=None, name=None, type=None, P=None):
44 | # TODO: Cleanup intrinsics and h, w, f
45 | # TODO: Add P matrix for projection
46 | self.H = h
47 | self.W = w
48 | if intrinsics is None:
49 | self.intrinsics = Intrinsics(h, w, f, P)
50 | else:
51 | self.intrinsics = intrinsics
52 |
53 | # Add opengl2cam rotation
54 | opengl2cam = euler_angles_to_matrix(
55 | torch.tensor([np.pi, np.pi, 0.0], device=DEVICE), "ZYX"
56 | )
57 | if type == 'waymo':
58 | waymo_rot = euler_angles_to_matrix(torch.tensor([np.pi, 0., 0.],
59 | device=DEVICE), 'ZYX')
60 |
61 | opengl2cam = torch.matmul(waymo_rot, opengl2cam)
62 |
63 | # Simplified version under the assumption of square pixels
64 | # [self.intrinsics.f_x, self.intrinsics.f_y]
65 | PerspectiveCameras.__init__(
66 | self,
67 | focal_length=torch.tensor([[self.intrinsics.f_x, self.intrinsics.f_y]]),
68 | principal_point=torch.tensor(
69 | [[self.intrinsics.W / 2, self.intrinsics.H / 2]]
70 | ),
71 | R=opengl2cam,
72 | image_size=torch.tensor([[self.intrinsics.W, self.intrinsics.H]]),
73 | )
74 | self.name = name
75 |
76 |
77 | class Lidar:
78 | def __init__(self, Tr_li2cam=None, name=None):
79 | # TODO: Add all relevant params to scene init
80 | self.sensor_type = 'lidar'
81 | self.li2cam = Tr_li2cam
82 |
83 | self.name = name
84 |
85 |
86 | class ObjectClass:
87 | def __init__(self, name):
88 | self.static = False
89 | self.name = name
90 |
91 |
92 | class SceneObject:
93 | def __init__(self, length, height, width, object_class_node):
94 | self.static = False
95 | self.object_class_type_idx = object_class_node.type_idx
96 | self.object_class_name = object_class_node.name
97 | self.length = length
98 | self.height = height
99 | self.width = width
100 | self.box_size = torch.tensor([self.length, self.height, self.width])
101 |
102 |
103 | class Background:
104 | def __init__(self, transformation=None, near=0.5, far=100.0, lightfield_config={}):
105 | self.static = True
106 | global_transformation = np.eye(4)
107 |
108 | if transformation is not None:
109 | transformation = np.squeeze(transformation)
110 | if transformation.shape == (3, 3):
111 | global_transformation[:3, :3] = transformation
112 |             elif transformation.shape == (3,):
113 | global_transformation[:3, 3] = transformation
114 | elif transformation.shape == (3, 4):
115 | global_transformation[:3, :] = transformation
116 | elif transformation.shape == (4, 4):
117 | global_transformation = transformation
118 | else:
119 | print(
120 | "Ignoring wolrd transformation, not of shape [3, 3], [3, 4], [3, 1] or [4, 4], but",
121 | transformation.shape,
122 | )
123 |
124 | self.transformation = torch.from_numpy(global_transformation)
125 | self.near = torch.tensor(near)
126 | self.far = torch.tensor(far)
--------------------------------------------------------------------------------
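A small sketch of the `Background` node above: it accepts a 3x3 rotation, a 3-vector translation, a 3x4 or a 4x4 world transformation and stores it as a 4x4 homogeneous matrix (importing `nodes` also pulls in pytorch3d and the point light field module).

```
import numpy as np
from src.scenes.nodes import Background

bg = Background(transformation=np.eye(4), near=0.5, far=200.0)
print(bg.transformation.shape)   # torch.Size([4, 4])
print(bg.near, bg.far)           # scalar tensors
```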
/src/utils_dist.py:
--------------------------------------------------------------------------------
1 | import torch.distributed as dist
2 | import torch
3 | from torch.utils.data import DataLoader, DistributedSampler
4 | from operator import itemgetter
5 |
6 |
7 | class DatasetFromSampler(torch.utils.data.Dataset):
8 | """Dataset to create indexes from `Sampler`.
9 | Args:
10 | sampler: PyTorch sampler
11 | """
12 |
13 | def __init__(self, sampler):
14 | """Initialisation for DatasetFromSampler."""
15 | self.sampler = sampler
16 | self.sampler_list = None
17 |
18 | def __getitem__(self, index: int):
19 | """Gets element of the dataset.
20 | Args:
21 | index: index of the element in the dataset
22 | Returns:
23 | Single element by index
24 | """
25 | if self.sampler_list is None:
26 | self.sampler_list = list(self.sampler)
27 | return self.sampler_list[index]
28 |
29 | def __len__(self) -> int:
30 | """
31 | Returns:
32 | int: length of the dataset
33 | """
34 | return len(self.sampler)
35 |
36 |
37 | class DistributedSamplerWrapper(DistributedSampler):
38 | """
39 | Wrapper over `Sampler` for distributed training.
40 | Allows you to use any sampler in distributed mode.
41 | It is especially useful in conjunction with
42 | `torch.nn.parallel.DistributedDataParallel`. In such case, each
43 | process can pass a DistributedSamplerWrapper instance as a DataLoader
44 | sampler, and load a subset of subsampled data of the original dataset
45 | that is exclusive to it.
46 | .. note::
47 | Sampler is assumed to be of constant size.
48 | """
49 |
50 | def __init__(
51 | self,
52 | sampler,
53 | num_replicas=None,
54 | rank=None,
55 | shuffle: bool = True,
56 | ):
57 | """
58 | Args:
59 | sampler: Sampler used for subsampling
60 | num_replicas (int, optional): Number of processes participating in
61 | distributed training
62 | rank (int, optional): Rank of the current process
63 | within ``num_replicas``
64 | shuffle (bool, optional): If true (default),
65 | sampler will shuffle the indices
66 | """
67 | super(DistributedSamplerWrapper, self).__init__(
68 | DatasetFromSampler(sampler),
69 | num_replicas=num_replicas,
70 | rank=rank,
71 | shuffle=shuffle,
72 | )
73 | self.sampler = sampler
74 |
75 | def __iter__(self):
76 | """@TODO: Docs. Contribution is welcome."""
77 | self.dataset = DatasetFromSampler(self.sampler)
78 | indexes_of_indexes = super().__iter__()
79 | subsampler_indexes = self.dataset
80 | return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))
81 |
82 |
83 | def setup_for_distributed(is_master):
84 | """
85 | This function disables printing when not in master process
86 | """
87 | import builtins as __builtin__
88 |
89 | builtin_print = __builtin__.print
90 |
91 | def print(*args, **kwargs):
92 | force = kwargs.pop("force", False)
93 | if is_master or force:
94 | builtin_print(*args, **kwargs)
95 |
96 | __builtin__.print = print
97 |
98 |
99 | def is_dist_avail_and_initialized():
100 | if not dist.is_available():
101 | return False
102 | if not dist.is_initialized():
103 | return False
104 | return True
105 |
106 |
107 | def get_world_size():
108 | if not is_dist_avail_and_initialized():
109 | return 1
110 | return dist.get_world_size()
111 |
112 |
113 | def get_rank():
114 | if not is_dist_avail_and_initialized():
115 | return 0
116 | return dist.get_rank()
117 |
118 |
119 | def is_main_process():
120 | return get_rank() == 0
121 |
122 |
123 | def save_on_master(*args, **kwargs):
124 | if is_main_process():
125 | torch.save(*args, **kwargs)
126 |
127 |
128 | import os
129 |
130 |
131 | def init_distributed_mode(args):
132 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
133 | args.rank = int(os.environ["RANK"])
134 | args.world_size = int(os.environ["WORLD_SIZE"])
135 | args.gpu = int(os.environ["LOCAL_RANK"])
136 | elif "SLURM_PROCID" in os.environ:
137 | args.rank = int(os.environ["SLURM_PROCID"])
138 | args.gpu = args.rank % torch.cuda.device_count()
139 | else:
140 | print("Not using distributed mode")
141 | args.distributed = False
142 | return
143 |
144 | args.distributed = True
145 |
146 | torch.cuda.set_device(args.gpu)
147 | args.dist_backend = "nccl"
148 | print(
149 | "| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True
150 | )
151 | torch.distributed.init_process_group(
152 | backend=args.dist_backend,
153 | init_method=args.dist_url,
154 | world_size=args.world_size,
155 | rank=args.rank,
156 | )
157 | torch.distributed.barrier()
158 | setup_for_distributed(args.rank == 0)
159 |
--------------------------------------------------------------------------------
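A usage sketch for `DistributedSamplerWrapper`, mirroring its docstring: wrap any custom sampler so that each process only sees its own shard. The weighted sampler and the two-replica setup are illustrative only; in real training `num_replicas`/`rank` would come from the initialized process group.

```
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from src.utils_dist import DistributedSamplerWrapper

dataset = TensorDataset(torch.arange(100).float())
weights = torch.rand(100)
base_sampler = WeightedRandomSampler(weights, num_samples=len(dataset))

# Shard the (shuffled) weighted sampler across two hypothetical processes.
sampler = DistributedSamplerWrapper(base_sampler, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=8, sampler=sampler)
for (batch,) in loader:
    pass   # each replica iterates over 50 of the 100 sampled indices
```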
/src/pointLF/light_field_renderer.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Tuple, List
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import numpy as np
7 | from pytorch3d.renderer import RayBundle
8 | from ..scenes import NeuralScene
9 |
10 |
11 | class LightFieldRenderer(nn.Module):
12 |
13 | def __init__(self, light_field_module, chunksize: int, cam_centered: bool = False):
14 | super(LightFieldRenderer, self).__init__()
15 |
16 | self.light_field_module = light_field_module
17 | self.chunksize = chunksize
18 | if cam_centered:
19 | self.rotate2cam = True
20 | else:
21 | self.rotate2cam = False
22 |
23 | # def forward(self, frame_idx: int, camera_idx: int, scene: NeuralScene, volumetric_function: Callable, chunk_idx: int, **kwargs):
24 | def forward(self, input_dict, scene, **kwargs):
25 |
26 | ##############
27 | # TODO: Clean Up and move somewhere else
28 | ray_bundle = input_dict['ray_bundle']
29 | device = ray_bundle.origins.device
30 | pts = input_dict['pts']
31 | ray_dirs_select = input_dict['ray_dirs_select']
32 | closest_point_mask = input_dict['closest_point_mask']
33 |
34 | cf_ls = [list(pts_k[0].keys()) for pts_k in pts]
35 | import numpy as np
36 | unique_cf = np.unique(np.concatenate([np.array(cf) for cf in cf_ls]), axis=0)
37 |
38 | pts_to_unpack = {
39 | 0: 'pt_cloud_select',
40 | 1: 'closest_point_dist',
41 | 2: 'closest_point_azimuth',
42 | 3: 'closest_point_pitch',
43 |
44 | }
45 |
46 | if len(unique_cf) != len(cf_ls):
47 | new_cf = [tuple(list(cf[0]) + [j]) for j, cf in enumerate(cf_ls)]
48 |
49 | output_dict = {
50 | new_cf[j]:
51 | v
52 | for j, pt in enumerate(pts) for k, v in pt[4].items()
53 | }
54 |
55 | closest_point_mask = {new_cf[j]: v for j, pt in enumerate(closest_point_mask) for k, v in pt.items()}
56 | ray_dirs_select = {new_cf[j]: v.to(device) for j, pt in enumerate(ray_dirs_select) for k, v in pt.items()}
57 | pts = {
58 | n: {
59 | new_cf[j]:
60 | v.to(device)
61 | for j, pt in enumerate(pts) for k, v in pt[i].items()
62 | }
63 | for i, n in pts_to_unpack.items()
64 | }
65 | pts.update({'output_dict': output_dict})
66 |
67 | a = 0
68 |
69 | else:
70 | output_dict = {
71 | k:
72 | v
73 | for pt in pts for k, v in pt[4].items()
74 | }
75 |
76 | closest_point_mask = {k: v for pt in closest_point_mask for k, v in pt.items()}
77 | ray_dirs_select = {k: v.to(device) for pt in ray_dirs_select for k, v in pt.items()}
78 | pts = {
79 | n: {
80 | k:
81 | v.to(device)
82 | for pt in pts for k, v in pt[i].items()
83 | }
84 | for i, n in pts_to_unpack.items()
85 | }
86 | pts.update({'output_dict': output_dict})
87 | ##################
88 |
89 | images, output_dict = self.light_field_module(
90 | ray_bundle=input_dict['ray_bundle'],
91 | scene=scene,
92 | closest_point_mask=closest_point_mask,
93 | pt_cloud_select=pts['pt_cloud_select'],
94 | closest_point_dist=pts['closest_point_dist'],
95 | closest_point_azimuth=pts['closest_point_azimuth'],
96 | closest_point_pitch=pts['closest_point_pitch'],
97 | output_dict=pts['output_dict'],
98 | ray_dirs_select=ray_dirs_select,
99 | rotate2cam=self.rotate2cam,
100 | **kwargs
101 | )
102 |
103 | if scene.tonemapping:
104 | rgb = torch.zeros_like(images)
105 | tone_mapping_ls = [scene.frames[cf[0][1]].load_tone_mapping(cf[0][0]) for cf in cf_ls]
106 | for i in range(len(images)):
107 | rgb[i] = self.tone_map(images[i], tone_mapping_ls[i])
108 |
109 | else:
110 | rgb = images
111 |
112 | output_dict.update(
113 | {
114 | 'rgb': rgb.view(-1, 3),
115 | 'ray_bundle': ray_bundle._replace(xys=ray_bundle.xys[..., None, :])
116 | }
117 | )
118 |
119 | return output_dict
120 |
121 | def tone_map(self, x, tone_mapping_params):
122 |
123 | x = (tone_mapping_params['contrast'] * (x - 0.5) + \
124 | 0.5 + \
125 | tone_mapping_params['brightness']) * \
126 | torch.cat(list(tone_mapping_params['wht_pt'].values()))
127 | x = self.leaky_clamping(x, gamma=tone_mapping_params['gamma'])
128 |
129 | return x
130 |
131 | def leaky_clamping(self, x, gamma, alpha=0.01):
132 | x[x < 0] = x[x < 0] * alpha
133 | x[x > 1] = (-alpha / x[x > 1]) + alpha + 1.
134 | x[(x > 0.) & (x < 1.)] = x[(x > 0.) & (x < 1.)] ** gamma
135 | return x
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Neural Point Light Fields (CVPR 2022)
2 |
3 |
4 |
5 | ### [Project Page](https://light.princeton.edu/publication/neural-point-light-fields)
6 | #### Julian Ost, Issam Laradji, Alejandro Newell, Yuval Bahat, Felix Heide
7 |
8 |
9 | Neural Point Light Fields represent scenes with a light field living on a sparse point cloud. As neural volumetric
10 | rendering methods require dense sampling of the underlying functional scene representation, at hundreds of samples
11 | along with a ray cast through the volume, they are fundamentally limited to small scenes with the same objects
12 | projected to hundreds of training views. Promoting sparse point clouds to neural implicit light fields allows us to
13 | represent large scenes effectively with only a single implicit sampling operation per ray.
14 |
15 | These point light fields are a function of the ray direction and the local point feature neighborhood, allowing us to
16 | interpolate the light field conditioned on the training images without dense object coverage and parallax. We assess the
17 | proposed method for novel view synthesis on large driving scenarios, where we synthesize realistic unseen views that
18 | existing implicit approaches fail to represent. We validate that Neural Point Light Fields make it possible to predict
19 | videos along unseen trajectories previously only feasible to generate by explicitly modeling the scene.
20 |
21 | ---
22 |
23 | ### Data Preparation
24 | #### Waymo
25 |
26 | 1. Download the compressed data of the Waymo Open Dataset:
27 | [Waymo Validation Tar Files](https://console.cloud.google.com/storage/browser/waymo_open_dataset_v_1_3_1/validation?pageState=(%22StorageObjectListTable%22:(%22f%22:%22%255B%255D%22))&prefix=&forceOnObjectsSortingFiltering=false)
28 |
29 | 2. To run the code as used in the paper, store the downloaded segments as follows: `./data/validation/validation_xxxx/segment-xxxxxx`
30 |
31 | 3. Neural Point Light Fields has been tested on the segments listed in the [Supplementary](https://light.princeton.edu/wp-content/uploads/2022/04/NeuralPointLightFields-Supplementary.pdf), which are used in the experiment group `pointLF_waymo_server`.
32 |
33 |
34 | 4. Preprocess the data by running: `./src/datasets/preprocess_Waymo.py -d "./data/validation/validation_xxxx/segment-xxxxxx" --no_data`
35 |
36 | ---
37 |
38 | ### Requirements
39 |
40 | Environment setup
41 | ```
42 | conda create -n NeuralPointLF python=3.7
43 | conda activate NeuralPointLF
44 | ```
45 | Install required packages
46 | ```
47 | conda install -c pytorch -c conda-forge pytorch=1.7.1 torchvision=0.8.2 cudatoolkit=11.0
48 | conda install -c fvcore -c iopath -c conda-forge fvcore iopath
49 | conda install -c bottler nvidiacub
50 | conda install jupyterlab
51 | pip install scikit-image matplotlib imageio plotly opencv-python
52 | conda install pytorch3d -c pytorch3d
53 | conda install -c open3d-admin -c conda-forge open3d
54 | ```
55 | Add large-scale ML toolkit
56 | ```
57 | pip install --upgrade git+https://github.com/haven-ai/haven-ai
58 | ```
59 |
60 | ---
61 | ### Training and Validation
62 | In the first iteration of a scene, the point clouds will be preprocessed and stored, which might take some time.
63 | If you want to train on unmerged point cloud data, set `merge_pcd=False` in the config file.
64 |
65 | Train one specific scene from the Waymo Open data set:
66 | ```
67 | python trainval.py -e pointLF_waymo_server -sb -d ./data/waymo/validation \
68 |     --epoch_size 100 --num_workers=
69 | ```
70 |
71 | Reconstruct the training path (`--render_only=True`)
72 | ```
73 | python trainval.py -e pointLF_waymo_server -sb -d ./data/waymo/validation \
74 |     --epoch_size 100 --num_workers= --render_only=True
75 | ```
76 |
77 | Argument Descriptions:
78 | ```
79 | -e [Experiment group to run, e.g. 'pointLF_waymo_server' (the experiment groups are defined in exp_configs/pointLF_exps.py)]
80 | -sb [Directory where the experiments are saved]
81 | -r [Flag for whether to reset the experiments]
82 | -d [Directory where the datasets are saved]
83 | ```
84 |
85 | **_Disclaimer_**: The codebase is optimized to run on larger GPU servers with a lot of free CPU memory.
86 | To test locally with limited memory, choose `pointLF_waymo_local` instead of `pointLF_waymo_server`.
87 | Adjusting the batch size, chunk size and number of rays affects the required resources.
88 |
89 | ---
90 | ### Visualization of Results
91 |
92 | Follow these steps to visualize plots. Open `results.ipynb`, run the first cell to get a dashboard like in the gif below, click on the "plots" tab, then click on "Display plots". Parameters of the plots can be adjusted in the dashboard for custom visualizations.
93 |
94 |
95 | ![Results dashboard](scripts/vis.gif)
96 |
97 |
98 | ---
99 | #### Citation
100 | ```
101 | @InProceedings{ost2022pointlightfields,
102 | title = {Neural Point Light Fields},
103 | author = {Ost, Julian and Laradji, Issam and Newell, Alejandro and Bahat, Yuval and Heide, Felix},
104 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
105 | year = {2022}
106 | }
107 | ```
108 |
109 |
110 |
--------------------------------------------------------------------------------
/src/datasets/extract_baselines.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import open3d as o3d
4 | import matplotlib.pyplot as plt
5 | from PIL import Image
6 |
7 |
8 | def extract_waymo_poses(dataset, scene_list, i_train, i_test, i_all):
9 | cam_names = ['FRONT', 'FRONT_LEFT', 'FRONT_RIGHT', 'SIDE_LEFT', 'SIDE_RIGHT']
10 | assert len(scene_list) == 1
11 | first_fr = scene_list[0]['first_frame']
12 | last_fr = scene_list[0]['last_frame']
13 |
14 | pose_file_n = "poses_{}_{}.npy".format(str(first_fr).zfill(4), str(last_fr).zfill(4))
15 | img_file_n = "imgs_{}_{}.npy".format(str(first_fr).zfill(4), str(last_fr).zfill(4))
16 | index_test_n= "index_test_{}_{}.npy".format(str(first_fr).zfill(4), str(last_fr).zfill(4))
17 | index_train_n = "index_train_{}_{}.npy".format(str(first_fr).zfill(4), str(last_fr).zfill(4))
18 | pt_depth_file_n = "pt_depth_{}_{}.npy".format(str(first_fr).zfill(4), str(last_fr).zfill(4))
19 | depth_img_file_n = "pt_depth_{}_{}.npy".format(str(first_fr).zfill(4), str(last_fr).zfill(4))
20 |
21 | segmemt_pth = ''
22 | for s in dataset.images[0].split('/')[1:-2]:
23 | segmemt_pth += '/' + s
24 |
25 | remove_side_views = True
26 |
27 | n_frames = len(dataset.poses_world) // 5
28 |
29 | if remove_side_views:
30 | n_imgs = n_frames * 3
31 | cam_names = cam_names[:3]
32 |
33 | else:
34 | n_imgs = n_frames * 5
35 |
36 | cam_pose = dataset.poses_world[:n_imgs]
37 | img_path = dataset.images[:n_imgs]
38 |
39 | cam_pose_openGL = cam_pose.dot(np.array([[-1., 0., 0., 0., ], [0., 1., 0., 0., ], [0., 0., -1., 0., ], [0., 0., 0., 1., ], ]))
40 |
41 | # Add H, W, focal
42 | hwf1 = [np.array([[dataset.H[c_name], dataset.W[c_name], dataset.focal[c_name], 1.]]).repeat(n_frames, axis=0) for c_name in cam_names]
43 | hwf1 = np.concatenate(hwf1)[:, :, None]
44 | cam_pose_openGL = np.concatenate([cam_pose_openGL, hwf1], axis=2)
45 |
46 | np.save(os.path.join(segmemt_pth, pose_file_n), cam_pose_openGL)
47 | np.save(os.path.join(segmemt_pth, img_file_n), img_path)
48 | np.save(os.path.join(segmemt_pth, index_test_n), i_test)
49 | np.save(os.path.join(segmemt_pth, index_train_n), i_train)
50 |
51 | xyz_pts = []
52 |
53 | # Extract depth points (and images)
54 | for i in range(n_imgs):
55 | fr = i % n_frames
56 | pts_i_veh = np.asarray(o3d.io.read_point_cloud(dataset.point_cloud_pth[fr]).points)
57 | pts_i_veh = np.concatenate([pts_i_veh, np.ones([len(pts_i_veh), 1])], axis=-1)
58 |
59 | cam2veh_i = dataset.poses[i]
60 | veh2cam_i = np.concatenate([cam2veh_i[:3, :3].T, np.matmul(cam2veh_i[:3, :3].T, -cam2veh_i[:3, 3])[:, None]], axis=-1)
61 |
62 | pts_i_cam = np.matmul(veh2cam_i, pts_i_veh.T).T
63 |
64 | focal_i = hwf1[i, 2]
65 | h_i = hwf1[i, 0]
66 | w_i = hwf1[i, 1]
67 |
68 | # x - W
69 | x = -focal_i * (pts_i_cam[:, 0] / pts_i_cam[:, 2])
70 | # y - H
71 | y = -focal_i * (pts_i_cam[:, 1] / pts_i_cam[:, 2])
72 |
73 | xyz = np.stack([x, y, pts_i_cam[:, 2]]).T
74 |
75 | visible_pts_map = (xyz[:, 2] > 0) & (np.abs(xyz[:, 0]) < w_i // 2) & (np.abs(xyz[:, 1]) < h_i // 2)
76 |
77 | xyz_visible = xyz[visible_pts_map]
78 |
79 | # xxx['coord'][:, 0] == W
80 | # xxx['coord'][:, 1] == H
81 | xyz_visible[:, 0] = np.maximum(np.minimum(xyz_visible[:, 0] + w_i // 2, w_i), 0)
82 | xyz_visible[:, 1] = np.maximum(np.minimum(xyz_visible[:, 1] + h_i // 2, h_i), 0)
83 |
84 | xyz_pts.append(
85 | {
86 | 'depth': xyz_visible[:, 2],
87 | 'coord': xyz_visible[:, :2],
88 | 'weight': np.ones_like(xyz_visible[:, 2])
89 | }
90 | )
91 |
92 | # ######### Debug Depth Outputs
93 | # if i == 102:
94 | # scale = 8
95 | #
96 | # h_scaled = h_i // scale
97 | # w_scaled = w_i // scale
98 | #
99 | # xyz_vis_scaled = xyz_visible / 8
100 | #
101 | # depth_img = np.zeros([int(h_scaled), int(w_scaled), 1])
102 | # depth_img[np.floor(xyz_vis_scaled[:, 1]).astype("int"), np.floor(xyz_vis_scaled[:, 0]).astype("int")] = xyz_visible[:, 2][:, None]
103 | # plt.figure()
104 | # plt.imshow(depth_img[..., 0], cmap="plasma")
105 | #
106 | # plt.figure()
107 | # img_i = Image.open(img_path[i])
108 | # img_i = img_i.resize((w_scaled, h_scaled))
109 | # plt.imshow(np.asarray(img_i))
110 | #
111 | # # pcd = o3d.geometry.PointCloud()
112 | # # pcd.points = o3d.utility.Vector3dVector(pts_i_cam)
113 |
114 | np.save(os.path.join(segmemt_pth, pt_depth_file_n), xyz_pts)
115 |
116 |
117 | # plt.scatter(cam_pose_openGL[:, 0, 3], cam_pose_openGL[:, 1, 3])
118 | # plt.axis('equal')
119 | #
120 | # for i in range(len(pos)):
121 | # p = cam_pose_openGL[i]
122 | # # p = pos[i].dot(np.array([[-1., 0., 0., 0., ],
123 | # # [0., 1., 0., 0., ],
124 | # # [0., 0., -1., 0., ],
125 | # # [0., 0., 0., 1., ], ]))
126 | # plt.arrow(p[0, 3], p[1, 3], p[0, 0], p[1, 0], color="red")
127 | # plt.arrow(p[0, 3], p[1, 3], p[0, 2], p[1, 2], color="black")
128 |
129 |
130 |
131 | def extract_depth_information(dataset, exp_dict):
132 | pass
133 |
134 |
135 | def extract_depth_image(dataset, exp_dict):
136 | pass
--------------------------------------------------------------------------------
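A worked example of the projection step inside `extract_waymo_poses` (a sketch; the intrinsics are illustrative placeholders, not Waymo calibration values): a point in front of the camera is projected with `x = -f * X / Z`, checked against the image bounds, and shifted to pixel coordinates.

```
import numpy as np

focal, h, w = 2000.0, 1280, 1920           # illustrative intrinsics (assumption)
pt_cam = np.array([0.5, -0.2, 10.0])       # point in the camera frame, 10 m ahead
x = -focal * (pt_cam[0] / pt_cam[2])       # offset from the principal point (width axis)
y = -focal * (pt_cam[1] / pt_cam[2])       # offset from the principal point (height axis)
visible = (pt_cam[2] > 0) and (abs(x) < w // 2) and (abs(y) < h // 2)
u = np.clip(x + w // 2, 0, w)              # pixel column, as stored in xyz_pts['coord'][:, 0]
v = np.clip(y + h // 2, 0, h)              # pixel row, as stored in xyz_pts['coord'][:, 1]
```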
/src/pointLF/layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch import Tensor
5 | import math
6 |
7 | class DenseLayer(nn.Linear):
8 | def __init__(self, in_dim: int, out_dim: int, activate: str = "relu", *args, **kwargs) -> None:
9 | if activate:
10 | self.activation = activate
11 | else:
12 | self.activation = 'linear'
13 |         super().__init__(in_dim, out_dim, *args, **kwargs)
14 |
15 | def reset_parameters(self) -> None:
16 | torch.nn.init.xavier_uniform_(self.weight, gain=torch.nn.init.calculate_gain(self.activation))
17 | if self.bias is not None:
18 | torch.nn.init.zeros_(self.bias)
19 |
20 | def forward(self, input: Tensor, **kwargs) -> Tensor:
21 | out = super(DenseLayer, self).forward(input)
22 | if self.activation == 'relu':
23 | out = F.relu(out)
24 |
25 | return out
26 |
27 |
28 | class EqualLinear(nn.Module):
29 | """Linear layer with equalized learning rate.
30 |
31 | During the forward pass the weights are scaled by the inverse of the He constant (i.e. sqrt(in_dim)) to
32 | prevent vanishing gradients and accelerate training. This constant only works for ReLU or LeakyReLU
33 | activation functions.
34 |
35 | Args:
36 | ----
37 | in_channel: int
38 | Input channels.
39 | out_channel: int
40 | Output channels.
41 | bias: bool
42 | Use bias term.
43 | bias_init: float
44 | Initial value for the bias.
45 | lr_mul: float
46 | Learning rate multiplier. By scaling weights and the bias we can proportionally scale the magnitude of
47 | the gradients, effectively increasing/decreasing the learning rate for this layer.
48 | activate: bool
49 | Apply leakyReLU activation.
50 |
51 | """
52 |
53 | def __init__(self, in_channel, out_channel, bias=True, bias_init=0, lr_mul=1, activate=False):
54 | super().__init__()
55 |
56 | self.weight = nn.Parameter(torch.randn(out_channel, in_channel).div_(lr_mul))
57 |
58 | if bias:
59 | self.bias = nn.Parameter(torch.zeros(out_channel).fill_(bias_init))
60 | else:
61 | self.bias = None
62 |
63 | self.activate = activate
64 | self.scale = (1 / math.sqrt(in_channel)) * lr_mul
65 | self.lr_mul = lr_mul
66 |
67 | def forward(self, input):
68 | if self.activate:
69 | out = F.linear(input, self.weight * self.scale)
70 | out = fused_leaky_relu(out, self.bias * self.lr_mul)
71 | else:
72 | out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
73 | return out
74 |
75 | def __repr__(self):
76 | return f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
77 |
78 |
79 | class FusedLeakyReLU(nn.Module):
80 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
81 | super().__init__()
82 |
83 | if bias:
84 | self.bias = nn.Parameter(torch.zeros(channel))
85 |
86 | else:
87 | self.bias = None
88 |
89 | self.negative_slope = negative_slope
90 | self.scale = scale
91 |
92 | def forward(self, input):
93 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
94 |
95 |
96 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
97 | if input.dtype == torch.float16:
98 | bias = bias.half()
99 |
100 | if bias is not None:
101 | rest_dim = [1] * (input.ndim - bias.ndim - 1)
102 | return F.leaky_relu(input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2) * scale
103 |
104 | else:
105 | return F.leaky_relu(input, negative_slope=0.2) * scale
106 |
107 |
108 | class ModulationLayer(nn.Module):
109 | def __init__(self, in_ch, out_ch, z_dim, demodulate=True, activate=True, bias=True, **kwargs):
110 | super(ModulationLayer, self).__init__()
111 | self.eps = 1e-8
112 |
113 | self.in_ch = in_ch
114 | self.out_ch = out_ch
115 | self.z_dim = z_dim
116 | self.demodulate = demodulate
117 |
118 | self.scale = 1 / math.sqrt(in_ch)
119 | self.weight = nn.Parameter(torch.randn(out_ch, in_ch))
120 | self.modulation = EqualLinear(z_dim, in_ch, bias_init=1, activate=False)
121 |
122 | if activate:
123 | # FusedLeakyReLU includes a bias term
124 | self.activate = FusedLeakyReLU(out_ch, bias=bias)
125 | elif bias:
126 | self.bias = nn.Parameter(torch.zeros(1, out_ch))
127 |
128 | def __repr__(self):
129 |         return f'{self.__class__.__name__}({self.in_ch}, {self.out_ch}, z_dim={self.z_dim})'
130 |
131 |
132 | def forward(self, input, z):
133 | # feature modulation
134 | gamma = self.modulation(z) # B, in_ch
135 | input = input * gamma
136 |
137 | weight = self.weight * self.scale
138 |
139 | if self.demodulate:
140 | # weight is out_ch x in_ch
141 | # here we calculate the standard deviation per input channel
142 | demod = torch.rsqrt(weight.pow(2).sum([1]) + self.eps)
143 | weight = weight * demod.view(-1, 1)
144 |
145 | # also normalize inputs
146 | input_demod = torch.rsqrt(input.pow(2).sum([1]) + self.eps)
147 | input = input * input_demod.view(-1, 1)
148 |
149 | out = F.linear(input, weight)
150 |
151 | if hasattr(self, 'activate'):
152 | out = self.activate(out)
153 |
154 | if hasattr(self, 'bias'):
155 | out = out + self.bias
156 |
157 | return out
--------------------------------------------------------------------------------
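A toy forward pass through `ModulationLayer` (sketch): 16-channel features are modulated by an 8-D latent code and mapped to 32 channels; the fused leaky ReLU supplies the bias.

```
import torch
from src.pointLF.layer import ModulationLayer

layer = ModulationLayer(in_ch=16, out_ch=32, z_dim=8)
x = torch.randn(4, 16)      # per-ray / per-point features (toy values)
z = torch.randn(4, 8)       # conditioning latent code (toy values)
print(layer(x, z).shape)    # torch.Size([4, 32])
```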
/src/scenes/raysampler/frustum_helpers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from pytorch3d.transforms import Rotate, Transform3d
4 | from src.scenes.frames import GraphEdge
5 |
6 |
7 | def get_frustum_world(camera, edge2camera, device='cpu'):
8 | cam_trafo = edge2camera.get_transformation_c2p().to(device=device)
9 | H = camera.intrinsics.H
10 | W = camera.intrinsics.W
11 | focal = camera.intrinsics.f_x
12 | sensor_corner = torch.tensor([[0., 0.], [0., H - 1], [W - 1, H - 1], [W - 1, 0.]], device=device)
13 | frustum_edges = torch.stack([(sensor_corner[:, 0] - W * .5) / focal, -(sensor_corner[:, 1] - H * .5) / focal,
14 | -torch.ones(size=(4,), device=device)], -1)
15 | frustum_edges /= torch.norm(frustum_edges, dim=-1)[:, None]
16 | frustum_edges = \
17 | Rotate(camera.R, device=device).compose(cam_trafo.translate(-edge2camera.translation)).transform_points(
18 | frustum_edges.reshape(1, -1, 3))[0]
19 | frustum_edges /= torch.norm(frustum_edges, dim=-1)[:, None]
20 |
21 | frustum_normals = torch.cross(frustum_edges, frustum_edges[[1, 2, 3, 0], :])
22 |
23 | frustum_normals /= torch.norm(frustum_normals, dim=-1)[:, None]
24 |
25 | return frustum_edges, frustum_normals
26 |
27 |
28 | def get_frustum(camera, edge2camera, edge2reference, device='cpu'):
29 |     # Gets the edges and normals of a camera's frustum with respect to a reference system
30 | openGL2dataset = Rotate(camera.R, device=device)
31 |
32 | if type(edge2reference) == GraphEdge:
33 | ref2cam = torch.eye(4, device=device)[None]
34 |
35 | ref2cam[0, 3, :3] = edge2reference.translation - edge2camera.translation
36 | ref2cam[0, :3, :3] = edge2camera.getRotation_c2p().compose(edge2reference.getRotation_p2c()).get_matrix()[:, :3, :3]
37 | cam_trafo = Transform3d(matrix=ref2cam)
38 |
39 | else:
40 | wo2veh = edge2reference
41 | cam2wo = edge2camera.get_transformation_c2p().cpu().get_matrix()[0].T.detach().numpy()
42 |
43 | cam2veh = wo2veh.dot(cam2wo)
44 | cam2veh = Transform3d(matrix=torch.tensor(cam2veh.T, device=device, dtype=torch.float32))
45 | cam_trafo = cam2veh
46 | ref2cam = cam2veh.get_matrix()
47 |
48 | H = camera.intrinsics.H
49 | W = camera.intrinsics.W
50 | focal = camera.intrinsics.f_x.detach()
51 | sensor_corner = torch.tensor([[0., 0.], [0., H - 1], [W - 1, H - 1], [W - 1, 0.]], device=device)
52 | frustum_edges = torch.stack([(sensor_corner[:, 0] - W * .5) / focal, -(sensor_corner[:, 1] - H * .5) / focal,
53 | -torch.ones(size=(4,), device=device)], -1)
54 | frustum_edges /= torch.norm(frustum_edges, dim=-1)[:, None]
55 | frustum_edges = openGL2dataset.compose(cam_trafo.translate(-ref2cam[:, 3, :3])).transform_points(
56 | frustum_edges.reshape(1, -1, 3))[0]
57 | frustum_edges /= torch.norm(frustum_edges, dim=-1)[:, None]
58 | frustum_normals = torch.cross(frustum_edges, frustum_edges[[1, 2, 3, 0], :])
59 | frustum_normals /= torch.norm(frustum_normals, dim=-1)[:, None]
60 | return frustum_edges, frustum_normals
61 |
62 |
63 | def get_frustum_torch(camera, edge2camera, edge2reference, device='cpu'):
64 |     # Gets the edges and normals of a camera's frustum with respect to a reference system
65 | openGL2dataset = Rotate(camera.R, device=device)
66 |
67 | if type(edge2reference) == GraphEdge:
68 | ref2cam = torch.eye(4, device=device)[None]
69 |
70 | ref2cam[0, 3, :3] = edge2reference.translation - edge2camera.translation
71 | ref2cam[0, :3, :3] = edge2camera.getRotation_c2p().compose(edge2reference.getRotation_p2c()).get_matrix()[:, :3,
72 | :3]
73 | cam_trafo = Transform3d(matrix=ref2cam)
74 |
75 | else:
76 | wo2veh = edge2reference
77 | cam2wo = edge2camera.get_transformation_c2p(device='cpu', requires_grad=False).get_matrix()[0].T
78 |
79 | cam2veh = torch.matmul(wo2veh, cam2wo)
80 | cam2veh = Transform3d(matrix=cam2veh.T)
81 | cam_trafo = cam2veh
82 | ref2cam = cam2veh.get_matrix()
83 |
84 | H = camera.intrinsics.H
85 | W = camera.intrinsics.W
86 | focal = camera.intrinsics.f_x.detach()
87 | sensor_corner = torch.tensor([[0., 0.], [0., H - 1], [W - 1, H - 1], [W - 1, 0.]], device=device)
88 | frustum_edges = torch.stack([(sensor_corner[:, 0] - W * .5) / focal, -(sensor_corner[:, 1] - H * .5) / focal,
89 | -torch.ones(size=(4,), device=device)], -1)
90 | frustum_edges /= torch.norm(frustum_edges, dim=-1)[:, None]
91 | frustum_edges = openGL2dataset.compose(cam_trafo.translate(-ref2cam[:, 3, :3])).transform_points(
92 | frustum_edges.reshape(1, -1, 3))[0]
93 | frustum_edges /= torch.norm(frustum_edges, dim=-1)[:, None]
94 | frustum_normals = torch.cross(frustum_edges, frustum_edges[[1, 2, 3, 0], :])
95 | frustum_normals /= torch.norm(frustum_normals, dim=-1)[:, None]
96 | return frustum_edges, frustum_normals
97 |
98 |
99 | # TODO: Convert to full numpy version
100 | def get_frustum_np(camera, edge2camera, edge2reference, device='cpu'):
101 |     # Gets the edges and normals of a camera's frustum with respect to a reference system
102 |
103 | openGL2dataset = Rotate(camera.R, device=device)
104 |
105 | ref2cam = torch.eye(4, device=device)[None]
106 |
107 | ref2cam[0, 3, :3] = edge2reference.translation - edge2camera.translation
108 | ref2cam[0, :3, :3] = edge2camera.getRotation_c2p().compose(edge2reference.getRotation_p2c()).get_matrix()[:, :3, :3]
109 | cam_trafo = Transform3d(matrix=ref2cam)
110 |
111 | H = camera.intrinsics.H
112 | W = camera.intrinsics.W
113 | focal = camera.intrinsics.f_x
114 | sensor_corner = torch.tensor([[0., 0.], [0., H - 1], [W - 1, H - 1], [W - 1, 0.]], device=device)
115 | frustum_edges = torch.stack([(sensor_corner[:, 0] - W * .5) / focal, -(sensor_corner[:, 1] - H * .5) / focal,
116 | -torch.ones(size=(4,), device=device)], -1)
117 | frustum_edges /= torch.norm(frustum_edges, dim=-1)[:, None]
118 | frustum_edges = openGL2dataset.compose(cam_trafo.translate(-ref2cam[:, 3, :3])).transform_points(
119 | frustum_edges.reshape(1, -1, 3))[0]
120 | frustum_edges /= torch.norm(frustum_edges, dim=-1)[:, None]
121 | frustum_normals = torch.cross(frustum_edges, frustum_edges[[1, 2, 3, 0], :])
122 | frustum_normals /= torch.norm(frustum_normals, dim=-1)[:, None]
123 | return frustum_edges, frustum_normals
--------------------------------------------------------------------------------
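The frustum_normals returned above define the four side planes of the viewing frustum, so a point expressed in the same reference frame (relative to the camera origin) can be tested with a dot product against each plane. A minimal sketch, assuming all normals point to the same side; depending on the corner ordering the sign of the test may need to be flipped.

    import torch

    def points_in_frustum(points, frustum_normals, eps=0.0):
        # points: (N, 3) relative to the camera origin, frustum_normals: (4, 3)
        signed = points @ frustum_normals.T   # signed distance to each side plane, (N, 4)
        return (signed >= -eps).all(dim=-1)   # True where a point lies on the inner side of all planes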
/src/pointLF/pointcloud_encoding/simpleview.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | # from all_utils import DATASET_NUM_CLASS
4 | from .simpleview_utils import PCViews, Squeeze, BatchNormPoint
5 | import matplotlib.pyplot as plt
6 |
7 | from .resnet import _resnet, BasicBlock
8 |
9 |
10 | class MVModel(nn.Module):
11 | def __init__(self, task,
12 | # dataset,
13 | backbone,
14 | feat_size,
15 | resolution=128,
16 | upscale_feats = True,
17 | ):
18 | super().__init__()
19 | assert task == 'cls'
20 | self.task = task
21 | self.num_class = 10 # DATASET_NUM_CLASS[dataset]
22 | self.dropout_p = 0.5
23 | self.feat_size = feat_size
24 |
25 | self.resolution = resolution
26 | self.upscale_feats = False
27 |
28 | pc_views = PCViews()
29 | self.num_views = pc_views.num_views
30 | self._get_img = pc_views.get_img
31 | self._get_img_coord = pc_views.get_img_coord
32 |
33 | img_layers, in_features = self.get_img_layers(
34 | backbone, feat_size=feat_size)
35 | self.img_model = nn.Sequential(*img_layers)
36 |
37 | # self.final_fc = MVFC(
38 | # num_views=self.num_views,
39 | # in_features=in_features,
40 | # out_features=self.num_class,
41 | # dropout_p=self.dropout_p)
42 |
43 | # Upscale resnet outputs to img resolution
44 | if upscale_feats:
45 | self.upscale_feats = True
46 |
47 | self.upscaleLin = nn.ModuleList(
48 | [
49 | nn.Sequential(
50 | nn.Linear(in_features=self.feat_size * (2 ** i), out_features=self.resolution), nn.ReLU()
51 | ) for i in range(4)
52 | ]
53 | )
54 | self.upsampleLayer = nn.Upsample(size=(resolution, resolution), mode='nearest')
55 |
56 | self.out_idx = [3, 4, 5, 7]
57 |
58 | def forward(self, pc, **kwargs):
59 | """
60 |         :param pc: batch of point clouds used to render the multi-view depth images
61 |         :return: per-view feature maps, per-view depth images, and None
62 | """
63 |
64 |         # Results differ between single-image and multi-image batches because of batch norm statistics
65 | pc = pc.cuda()
66 | img = self.get_img(pc)
67 |
68 | # feat = self.img_model(img)
69 | outs = []
70 | h = img
71 | for layer in self.img_model:
72 | h = layer(h)
73 | outs.append(h)
74 |
75 | if self.upscale_feats:
76 | feat_ls = []
77 | for i, upLayer in zip(self.out_idx, self.upscaleLin):
78 | h = outs[i].transpose(1,-1)
79 | h = upLayer(h).transpose(-1,1)
80 | feat_ls.append(self.upsampleLayer(h))
81 | feat = torch.sum(torch.stack(feat_ls), dim=0) / len(self.upscaleLin)
82 |
83 | # plt.imshow(torch.sum(outs[7].transpose(1,-1), dim=-1).cpu().detach().numpy()[0])
84 | # plt.imshow(torch.sum(feat.transpose(1, -1), dim=-1).cpu().detach().numpy()[0])
85 | else:
86 | feat = outs[-1]
87 |
88 | # if len(pc) >= 1:
89 | # i_sh = [6, len(pc)] + list(img.shape[1:])
90 | # f_sh = [6, len(pc)] + list(feat.shape[1:])
91 | # img = img.reshape(i_sh)
92 | # feat = feat.reshape(f_sh)
93 | #
94 | # else:
95 | # img = img.unsqueeze(1)
96 | # feat = feat.unsqueeze(1)
97 |
98 | n_img, in_ch, w_in, h_in = img.shape
99 | n_feat_maps, out_ch, w_out, h_out = feat.shape
100 | n_batch = len(pc)
101 | feat = feat.reshape(n_batch, n_feat_maps//n_batch, out_ch, w_out, h_out)
102 | img = img.reshape(n_batch, n_img // n_batch, in_ch, w_in, h_in)
103 |
104 | return feat, img, None
105 |
106 |
107 | def get_img(self, pc):
108 | img = self._get_img(pc, self.resolution)
109 | img = torch.tensor(img).float()
110 | img = img.to(next(self.parameters()).device)
111 | assert len(img.shape) == 3
112 | img = img.unsqueeze(3)
113 | # [num_pc * num_views, 1, RESOLUTION, RESOLUTION]
114 | img = img.permute(0, 3, 1, 2)
115 |
116 | return img
117 |
118 | @staticmethod
119 | def get_img_layers(backbone, feat_size):
120 | """
121 | Return layers for the image model
122 | """
123 | assert backbone == 'resnet18'
124 | layers = [2, 2, 2, 2]
125 | block = BasicBlock
126 | backbone_mod = _resnet(
127 | arch=None,
128 | block=block,
129 | layers=layers,
130 | pretrained=False,
131 | progress=False,
132 | feature_size=feat_size,
133 | zero_init_residual=True)
134 |
135 | all_layers = [x for x in backbone_mod.children()]
136 | in_features = all_layers[-1].in_features
137 |
138 |         # all layers except the final fc layer and the initial conv layers
139 | # WARNING: this is checked only for resnet models
140 |
141 | # main_layers = all_layers[4:-1]
142 | main_layers = all_layers[4:-2]
143 | img_layers = [
144 | nn.Conv2d(1, feat_size, kernel_size=(3, 3), stride=(1, 1),
145 | padding=(1, 1), bias=False),
146 | nn.BatchNorm2d(feat_size, eps=1e-05, momentum=0.1,
147 | affine=True, track_running_stats=True),
148 | nn.ReLU(inplace=True),
149 | *main_layers,
150 | Squeeze()
151 | ]
152 |
153 | return img_layers, in_features
154 |
155 |
156 | class MVFC(nn.Module):
157 | """
158 | Final FC layers for the MV model
159 | """
160 |
161 | def __init__(self, num_views, in_features, out_features, dropout_p):
162 | super().__init__()
163 | self.num_views = num_views
164 | self.in_features = in_features
165 | self.model = nn.Sequential(
166 | BatchNormPoint(in_features),
167 | # dropout before concatenation so that each view drops features independently
168 | nn.Dropout(dropout_p),
169 | nn.Flatten(),
170 | nn.Linear(in_features=in_features * self.num_views,
171 | out_features=in_features),
172 | nn.BatchNorm1d(in_features),
173 | nn.ReLU(),
174 | nn.Dropout(dropout_p),
175 | nn.Linear(in_features=in_features, out_features=out_features,
176 | bias=True))
177 |
178 | def forward(self, feat):
179 | feat = feat.view((-1, self.num_views, self.in_features))
180 | out = self.model(feat)
181 | return out
182 |
--------------------------------------------------------------------------------
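The reshape at the end of MVModel.forward regroups features that the backbone produced stacked as (batch * views, C, H, W) into per-batch view stacks. A self-contained sketch of that bookkeeping; the view count of 6 and the channel/resolution values are assumptions.

    import torch

    n_batch, n_views, out_ch, res = 2, 6, 16, 128
    feat = torch.randn(n_batch * n_views, out_ch, res, res)  # backbone output, views stacked along dim 0
    feat = feat.reshape(n_batch, feat.shape[0] // n_batch, out_ch, res, res)
    print(feat.shape)                                        # torch.Size([2, 6, 16, 128, 128])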
/src/scenes/init_detection.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from . import createSceneObject
4 | import random
5 |
6 |
7 | def init_frames_on_BEV_anchors(scene,
8 | image_ls,
9 | exp_dict_pretrained,
10 | tgt_frames,
11 | n_anchors_depth=3,
12 | n_anchors_angle=3,
13 | ):
14 | # Get camera
15 | cam = list(scene.nodes['camera'].values())[0]
16 | far = exp_dict_pretrained['scenes'][0]['far_plane']
17 | near = exp_dict_pretrained['scenes'][0]['near_plane']
18 |     box_scale = exp_dict_pretrained['scenes'][0]['box_scale']
19 | cam_above_ground_kitti = 1.5
20 |
21 | param_dict = {}
22 |
23 | # Initialize frame with camera and background objects
24 | frame_idx = scene.init_blank_frames(image_ls,
25 | camera_node=cam,
26 | far=far,
27 | near=near,
28 | box_scale=box_scale,)
29 |
30 |     # Create anchors at the midpoints of each cell in a BEV grid inside the viewing frustum
31 | anchor_angle, anchor_depth, anchor_height = _create_cyl_mid_point_anchors(cam,
32 | cam_above_ground_kitti,
33 | n_anchors_angle,
34 | n_anchors_depth,
35 | near,
36 | far)
37 |
38 |     # Get known object parameters to sample from
39 | known_objs = list(scene.nodes['scene_object'].keys())
40 |
41 | # Get Car object
42 | for v in scene.nodes['object_class'].values():
43 | if v.name == 'Car':
44 | obj_class = v
45 |
46 | for fr_id in frame_idx:
47 | # Add random deviations
48 | anchor_angle_fr, anchor_depth_fr = _add_random_offset_anchors(cam, anchor_angle, anchor_depth, n_anchors_angle,
49 | n_anchors_depth, near, far)
50 |
51 | # Combine angle and depth values to xyz
52 | anchor_x = anchor_depth * torch.tan(anchor_angle_fr)
53 | anchors = torch.cat([anchor_x[..., None], anchor_height, anchor_depth_fr[..., None]], dim=2)
54 |
55 | # Loop over all anchors
56 | for i, anchor in enumerate(anchors.view(-1, 3)):
57 | # Get params and add a new obj at this anchor
58 | size, latent = _get_anchor_obj_params(scene, known_objs)
59 |
60 | new_obj_dict = createSceneObject(length=size[0],
61 | height=size[1],
62 | width=size[2],
63 | object_class_node=obj_class,
64 | latent=latent,
65 | type_idx=i)
66 |
67 | nodes = scene.updateNodes(new_obj_dict)
68 | new_obj = list(nodes['scene_object'].values())[0]
69 |
70 | # Get rotation
71 | rotation = _get_anchor_box_rotation(anchor)
72 |
73 | # Add new edge for the anchor and new object to the scene
74 | scene.add_new_obj_edges(frame_id=fr_id,
75 | object_node=new_obj,
76 | translation=anchor,
77 | rotation=rotation,
78 | box_size_scaling=box_scale,)
79 |
80 | frame = scene.frames[fr_id]
81 | for obj_id in frame.get_objects_ids():
82 | for v_name, v in frame.get_object_parameters(obj_id).items():
83 | if v_name != 'scaling':
84 | param_dict['{}_{}'.format(v_name, obj_id)] = v
85 |
86 | tgt_dict = {}
87 | for tgt_id in tgt_frames:
88 | tgt_frame = scene.frames[tgt_id]
89 | for obj_id in tgt_frame.get_objects_ids():
90 | for v_name, v in tgt_frame.get_object_parameters(obj_id).items():
91 | if v_name != 'scaling':
92 | tgt_dict['{}_{}'.format(v_name, obj_id)] = v
93 |
94 |
95 | camera_idx = [cam.scene_idx] * len(frame_idx)
96 | return param_dict, frame_idx, camera_idx, tgt_dict
97 |
98 |
99 | def _create_cyl_mid_point_anchors(camera,
100 | cam_above_ground,
101 | n_anchors_angle,
102 | n_anchors_depth,
103 | near,
104 | far):
105 | # Basic BEV Anchors
106 | cam_param = camera.intrinsics
107 | fov_y = 2 * torch.arctan(cam_param.H / (2 * cam_param.f_y))
108 | fov_x = 2 * torch.arctan(cam_param.W / (2 * cam_param.f_x))
109 |
110 | device = fov_y.device
111 |
112 | # Sample anchors on BEV Plane inside camera viewing frustum
113 | # Sample anchors from left to right along the angle inside the FOV
114 | percent_fov_x = torch.linspace(0, n_anchors_angle - 1, n_anchors_angle, device=device) / n_anchors_angle
115 | angle_mid_point = 1 / (2 * n_anchors_angle)
116 | anchor_angle = fov_x * (percent_fov_x + angle_mid_point) - fov_x / 2
117 | anchor_angle = anchor_angle[None, :].repeat(n_anchors_depth, 1)
118 |
119 | # Sample along depth from near to far
120 | depth_mid_point = (far - near) / (2 * n_anchors_depth)
121 | anchor_depth = (far - near) * (
122 | torch.linspace(0, n_anchors_depth - 1, n_anchors_depth, device=device) / n_anchors_depth)
123 | anchor_depth += near + depth_mid_point
124 | anchor_depth = anchor_depth[:, None].repeat(1, n_anchors_angle)
125 |
126 | anchor_height = torch.ones(size=(n_anchors_depth, n_anchors_angle, 1), device=device) * cam_above_ground
127 |
128 | return anchor_angle, anchor_depth, anchor_height
129 |
130 |
131 | def _add_random_offset_anchors(camera, anchor_angle, anchor_depth, n_anchors_angle, n_anchors_depth, near, far):
132 | # Basic BEV Anchors
133 | cam_param = camera.intrinsics
134 | fov_y = 2 * torch.arctan(cam_param.H / (2 * cam_param.f_y))
135 | fov_x = 2 * torch.arctan(cam_param.W / (2 * cam_param.f_x))
136 |
137 | device = anchor_angle.device
138 |
139 | # Add random deviations
140 | rand_angle = (2 * torch.rand(size=anchor_angle.shape, device=device) - 1) * (
141 | fov_x / n_anchors_angle)
142 | rand_depth = (2 * torch.rand(size=anchor_angle.shape, device=device) - 1) * (
143 | (far - near) / n_anchors_depth)
144 |
145 | anchor_angle_fr = anchor_angle + rand_angle
146 | anchor_depth_fr = anchor_depth + rand_depth
147 | return anchor_angle_fr, anchor_depth_fr
148 |
149 |
150 | def _get_anchor_obj_params(scene, known_objs):
151 | random.shuffle(known_objs)
152 | known_obj_id = known_objs[0]
153 | known_obj = scene.nodes['scene_object'][known_obj_id]
154 | size = known_obj.box_size
155 | latent = known_obj.latent
156 |
157 | return size, latent
158 |
159 |
160 | def _get_anchor_box_rotation(anchor):
161 | yaw = torch.rand((1,), device=anchor.device) * - np.pi / 2
162 | if yaw < 1e-4:
163 | yaw += 1e-4
164 | rotation = torch.tensor([0., yaw, 0.])
165 | return rotation
--------------------------------------------------------------------------------
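The anchor grid above places one anchor at the centre of each (angle, depth) cell inside the camera frustum and converts the viewing angle to a lateral offset via x = depth * tan(angle). A stand-alone numerical sketch; the field of view and near/far planes are assumed values.

    import torch

    fov_x = torch.tensor(1.0)              # assumed horizontal field of view in radians
    near, far = 0.5, 150.0                 # assumed near/far planes
    n_angle, n_depth = 3, 3

    angles = fov_x * (torch.arange(n_angle) / n_angle + 1 / (2 * n_angle)) - fov_x / 2
    depths = (far - near) * (torch.arange(n_depth) / n_depth) + near + (far - near) / (2 * n_depth)

    x = depths[:, None] * torch.tan(angles)[None, :]  # lateral offset per (depth, angle) cell
    z = depths[:, None].expand(-1, n_angle)           # forward distance per cell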
/exp_configs/pointLF_exps.py:
--------------------------------------------------------------------------------
1 | from haven import haven_utils as hu
2 |
3 | EXP_GROUPS = {}
4 |
5 |
6 | def get_scenes(scene_ids, selected_frames='default', dataset='waymo', object_types=None):
7 | scene_list = []
8 | for scene_id in scene_ids:
9 | first_frame = None
10 | last_frame = None
11 | if isinstance(selected_frames, list):
12 | first_frame = selected_frames[0]
13 | last_frame = selected_frames[1]
14 |
15 | if type(scene_id) != list and dataset == 'waymo':
16 | for record_id in range(25):
17 | scene_list += [{'scene_id': [scene_id, record_id],
18 | 'type': dataset,
19 | 'first_frame': None,
20 | 'last_frame': None,
21 | 'far_plane': 150,
22 | 'near_plane': .5,
23 | "new_world": True,
24 | 'box_scale': 1.5,
25 | 'object_types': object_types,
26 | 'fix': True,
27 | 'pt_cloud_fix': True}]
28 |
29 | else:
30 | scene_dict = {'scene_id': scene_id,
31 | 'type': dataset,
32 | 'first_frame': first_frame,
33 | 'last_frame': last_frame,
34 | 'far_plane': 150,
35 | 'near_plane': .5,
36 | "new_world": True,
37 | 'box_scale': 1.5,
38 | 'object_types': object_types,
39 | 'fix': True,
40 | 'pt_cloud_fix': True}
41 | if object_types is not None:
42 | scene_dict['object_types'] = object_types
43 | scene_list += [scene_dict]
44 | return scene_list
45 |
46 |
47 | EXP_GROUPS['pointLF_waymo_local'] = hu.cartesian_exp_group({
48 | "n_rays": [512],
49 | 'image_batch_size': [2, ],
50 | "chunk": [512],
51 | "scenes": [
52 | # get_scenes(scene_ids=[[0, 8]], dataset='waymo', selected_frames=[135, 197], object_types=['TYPE_VEHICLE']),
53 | # get_scenes(scene_ids=[[0, 11]], dataset='waymo', selected_frames=[0, 20], object_types=['TYPE_VEHICLE']),
54 | get_scenes(scene_ids=[[0, 2]], dataset='waymo', selected_frames=[0, 80], object_types=['TYPE_VEHICLE']),
55 | # get_scenes(scene_ids=[[0, 2]], dataset='waymo', selected_frames=[0, 197], object_types=['TYPE_VEHICLE']),
56 | # get_scenes(scene_ids=[[2, 0]], dataset='waymo', selected_frames=[0, 10], object_types=['TYPE_VEHICLE']),
57 | # get_scenes(scene_ids=[[0, 19]], dataset='waymo', selected_frames=[0, 198], object_types=['TYPE_VEHICLE']),
58 | # get_scenes(scene_ids=[[0, 19]], dataset='waymo', selected_frames=[96, 198], object_types=['TYPE_VEHICLE']),
59 | # get_scenes(scene_ids=[[1, 15]], dataset='waymo', selected_frames=[0, 197], object_types=['TYPE_VEHICLE']),
60 | # get_scenes(scene_ids=[[2, 0]], dataset='waymo', selected_frames=[0, 40], object_types=['TYPE_VEHICLE']),
61 | ],
62 | "precache": [False],
63 | "lrate": [0.001, ],
64 | "lrate_decay": 250,
65 | "netchunk": 65536,
66 | 'lightfield': {'k_closest': 4,
67 | 'n_features': 128,
68 | # 'n_sample_pts': 20000,
69 | 'n_sample_pts': 5000,
70 | # 'pointfeat_encoder': 'pointnet_lf_global_weighted',
71 | # 'pointfeat_encoder': 'multiview_attention_modulation',
72 | 'pointfeat_encoder': 'multiview_attention',
73 | # 'pointfeat_encoder': 'multiview_attention_up',
74 | # 'pointfeat_encoder': 'naive_ablation',
75 | # 'pointfeat_encoder': 'one_point_ablation',
76 | # 'pointfeat_encoder': 'pointnet_ablation',
77 | # 'pointfeat_encoder': 'encoding_attention_only',
78 | # 'pointfeat_encoder': 'multiview_attention_big',
79 | # 'pointfeat_encoder': 'multiview_encoded',
80 | # 'pointfeat_encoder': 'multiview_distance_attention',
81 | # 'pointfeat_encoder': 'multiview_encoded_modulation',
82 | # 'pointfeat_encoder': 'multiview_encoded_weighted_modulation',
83 | 'merge_pcd': False,
84 | 'all_cams': True,
85 | 'D_lf': 8,
86 | 'skips_lf': [4],
87 | 'camera_centered': False,
88 | 'augment_frame_order': True,
89 | 'new_enc': False,
90 | 'torch_sampler': True,
91 | 'sky_dome': True,
92 | 'num_merged_frames': 1,
93 | },
94 | # "overfit": "frame",
95 | "scale": 0.0625, # 0.125,
96 | "point_chunk": 1e7,
97 | 'version': [0],
98 | 'tonemapping': False,
99 | 'pose_refinement': False,
100 | 'pt_cache': True,
101 | },
102 | remove_none=True
103 | )
104 |
105 |
106 | EXP_GROUPS['pointLF_waymo_server'] = hu.cartesian_exp_group({
107 | "n_rays": [8192],
108 | 'image_batch_size': [2, ],
109 | "chunk": [64000],
110 | "scenes": [
111 | # get_scenes(scene_ids=[[0, 2]], dataset='waymo', selected_frames=[0, 20], object_types=['TYPE_VEHICLE']),
112 | # get_scenes(scene_ids=[[0, 2]], dataset='waymo', selected_frames=[0, 40], object_types=['TYPE_VEHICLE']),
113 | get_scenes(scene_ids=[[0, 2]], dataset='waymo', selected_frames=[0, 80], object_types=['TYPE_VEHICLE']),
114 | # get_scenes(scene_ids=[[0, 2]], dataset='waymo', selected_frames=[0, 120], object_types=['TYPE_VEHICLE']),
115 | # get_scenes(scene_ids=[[0, 2]], dataset='waymo', selected_frames=[0, 197], object_types=['TYPE_VEHICLE']),
116 | # get_scenes(scene_ids=[[0, 8]], dataset='waymo', selected_frames=[171, 190], object_types=['TYPE_VEHICLE']),
117 | # get_scenes(scene_ids=[[0, 8]], dataset='waymo', selected_frames=[135, 197], object_types=['TYPE_VEHICLE']),
118 | # get_scenes(scene_ids=[[0, 11]], dataset='waymo', selected_frames=[0, 164], object_types=['TYPE_VEHICLE']),
119 | # get_scenes(scene_ids=[[1, 15]], dataset='waymo', selected_frames=[0, 197], object_types=['TYPE_VEHICLE']),
120 | # get_scenes(scene_ids=[[2, 3]], dataset='waymo', selected_frames=[0, 198], object_types=['TYPE_VEHICLE']),
121 | # get_scenes(scene_ids=[[2, 9]], dataset='waymo', selected_frames=[0, 197], object_types=['TYPE_VEHICLE']),
122 | ],
123 | "precache": [False],
124 | "latent_balance": 0.0001,
125 | "lrate_decay": 250,
126 | "netchunk": 65536,
127 | 'lightfield': {'k_closest': 8,
128 | 'n_features': 128,
129 | 'n_sample_pts': 500,
130 | 'pointfeat_encoder': 'multiview_attention',
131 | 'merge_pcd': True,
132 | 'all_cams': False,
133 | 'D_lf': 8,
134 | 'W_lf': 256,
135 | 'skips_lf': [4],
136 | 'layer_modulation': False,
137 | 'camera_centered': False,
138 | 'augment_frame_order': True,
139 | 'new_enc': False,
140 | 'torch_sampler': True,
141 | 'sky_dome': True,
142 | 'num_merged_frames': 20,
143 | },
144 | # "overfit": "frame",
145 | "scale": .125,
146 | 'version': 0,
147 | "point_chunk": 1e7,
148 | 'tonemapping': False,
149 | 'pose_refinement': False,
150 | 'pt_cache': True,
151 | },
152 | remove_none=True
153 | )
154 |
--------------------------------------------------------------------------------
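For context, hu.cartesian_exp_group expands every list-valued entry of the dictionaries above into one experiment per combination. A rough sketch of that expansion (an illustration of the idea, not the haven-ai implementation):

    from itertools import product

    def cartesian_sketch(group):
        keys = list(group.keys())
        values = [v if isinstance(v, list) else [v] for v in group.values()]
        return [dict(zip(keys, combo)) for combo in product(*values)]

    # e.g. {"n_rays": [512], "lrate": [0.001, 0.0005]} expands to two experiment dicts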
/src/datasets/preprocess_Waymo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import glob
4 | import numpy as np
5 | from tqdm import tqdm
6 | from waymo_open_dataset import dataset_pb2 as open_dataset
7 | from waymo_open_dataset.utils import frame_utils
8 | # from waymo_open_dataset import label_pb2 as open_label
9 | import imageio
10 | import tensorflow.compat.v1 as tf
11 | tf.enable_eager_execution()
12 | import pickle
13 | from collections import defaultdict
14 | from copy import deepcopy
15 | # import waymo
16 | import cv2
17 | import open3d as o3d
18 |
19 |
20 | SAVE_INTRINSIC = True
21 | SINGLE_TRACK_INFO_FILE = True
22 | DEBUG = False # If True, process only the first tfrecord file and save with a "_debug" suffix.
23 | MULTIPLE_DIRS = False
24 | # DATADIRS = '/media/ybahat/data/Datasets/Waymo/val/0001'
25 |
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument('-d', '--datadir')
28 | parser.add_argument('-nd', '--no_data', action='store_true')
29 | args,_ = parser.parse_known_args()
30 | datadirs = args.datadir
31 | export_data = not args.no_data
32 | # datadirs = DATADIRS
33 | saving_dir = '/'.join(datadirs.split('/')[:-1])
34 | if '.tfrecord' not in datadirs:
35 | saving_dir = 1*datadirs
36 | datadirs = glob.glob(datadirs+'/*.tfrecord',recursive=True)
37 | datadirs = sorted([f for f in datadirs if '.tfrecord' in f])
38 | MULTIPLE_DIRS = True
39 |
40 | if not isinstance(datadirs,list): datadirs = [datadirs]
41 | if not os.path.isdir(saving_dir): os.mkdir(saving_dir)
42 |
43 | def extract_label_fields(l,dims):
44 | assert dims in [2,3]
45 | label_dict = {'c_x':l.box.center_x,'c_y':l.box.center_y,'width':l.box.width,'length':l.box.length,'type':l.type}
46 | if dims==3:
47 | label_dict['c_z'] = l.box.center_z
48 | label_dict['height'] = l.box.height
49 | label_dict['heading'] = l.box.heading
50 | return label_dict
51 |
52 | def read_intrinsic(intrinsic_params_vector):
53 | return dict(zip(['f_u', 'f_v', 'c_u', 'c_v', 'k_1', 'k_2', 'p_1', 'p_2', 'k_3'], intrinsic_params_vector))
54 |
55 | isotropic_focal = lambda intrinsic_dict: intrinsic_dict['f_u']==intrinsic_dict['f_v']
56 |
57 | # datadirs = [os.path.join(args.datadir,
58 | # 'segment-10203656353524179475_7625_000_7645_000_with_camera_labels.tfrecord')]
59 | for file_num,file in enumerate(datadirs):
60 | if SINGLE_TRACK_INFO_FILE:
61 | tracking_info = {}
62 | if file_num > 0 and DEBUG: break
63 | file_name = file.split('/')[-1].split('.')[0]
64 | print('Processing file ',file_name)
65 | if not os.path.isdir(os.path.join(saving_dir, file_name)): os.mkdir(os.path.join(saving_dir, file_name))
66 | if not os.path.isdir(os.path.join(saving_dir,file_name, 'images')): os.mkdir(os.path.join(saving_dir,file_name, 'images'))
67 | if not os.path.isdir(os.path.join(saving_dir, file_name, 'point_cloud')): os.mkdir(os.path.join(saving_dir, file_name, 'point_cloud'))
68 | if not SINGLE_TRACK_INFO_FILE:
69 | if not os.path.isdir(os.path.join(saving_dir,file_name, 'tracking')): os.mkdir(os.path.join(saving_dir,file_name, 'tracking'))
70 | dataset = tf.data.TFRecordDataset(file, compression_type='')
71 | for f_num, data in enumerate(tqdm(dataset)):
72 | frame = open_dataset.Frame()
73 | frame.ParseFromString(bytearray(data.numpy()))
74 | pose = np.zeros([len(frame.images), 4, 4])
75 | im_paths = {}
76 | pcd_paths = {}
77 | if SAVE_INTRINSIC:
78 | intrinsic = np.zeros([len(frame.images),9])
79 | extrinsic = np.zeros_like(pose)
80 | width,height,camera_labels = np.zeros([len(frame.images)]),np.zeros([len(frame.images)]),defaultdict(dict)
81 | for im in frame.images:
82 | saving_name = os.path.join(saving_dir,file_name, 'images','%03d_%s.png'%(f_num,open_dataset.CameraName.Name.Name(im.name)))
83 | if not DEBUG and export_data:
84 | im_array = tf.image.decode_jpeg(im.image).numpy()
85 | # No compression imageio
86 | # imageio.imwrite(saving_name, im_array, compress_level=0)
87 | # Less compression imageio
88 | imageio.imwrite(saving_name, im_array, compress_level=3)
89 | # Original:
90 | # imageio.imwrite(saving_name, im_array,)
91 | # OpenCV Alternative (needs debugging for right colors):
92 | # cv2.imwrite(saving_name, im_array)
93 | pose[im.name-1, :, :] = np.reshape(im.pose.transform, [4, 4])
94 | im_paths[im.name] = saving_name
95 | extrinsic[im.name-1, :, :] = np.reshape(frame.context.camera_calibrations[im.name-1].extrinsic.transform, [4, 4])
96 | if SAVE_INTRINSIC:
97 | intrinsic[im.name-1, :] = frame.context.camera_calibrations[im.name-1].intrinsic
98 | assert isotropic_focal(read_intrinsic(intrinsic[im.name-1, :])),'Unexpected difference between f_u and f_v.'
99 | width[im.name-1] = frame.context.camera_calibrations[im.name-1].width
100 | height[im.name-1] = frame.context.camera_calibrations[im.name-1].height
101 | for obj_label in frame.projected_lidar_labels[im.name-1].labels:
102 | camera_labels[im.name][obj_label.id.replace('_'+open_dataset.CameraName.Name.Name(im.name),'')] = extract_label_fields(obj_label,2)
103 | # Extract point cloud data from stored range images
104 | laser_calib = np.zeros([len(frame.lasers), 4,4])
105 | if export_data:
106 | (range_images, camera_projections, range_image_top_pose) = \
107 | frame_utils.parse_range_image_and_camera_projection(frame)
108 | points, cp_points = frame_utils.convert_range_image_to_point_cloud(frame,
109 | range_images,
110 | camera_projections,
111 | range_image_top_pose)
112 | else:
113 | points =np.empty([len(frame.lasers), 1])
114 |
115 | laser_mapping = {}
116 | for (laser, pts) in zip(frame.lasers, points):
117 | saving_name = os.path.join(saving_dir, file_name, 'point_cloud', '%03d_%s.ply' % (f_num, open_dataset.LaserName.Name.Name(laser.name)))
118 | if export_data:
119 | pcd = o3d.geometry.PointCloud()
120 | pcd.points = o3d.utility.Vector3dVector(pts)
121 | o3d.io.write_point_cloud(saving_name, pcd)
122 | calib_id = int(np.where(np.array([cali.name for cali in frame.context.laser_calibrations[:5]]) == laser.name)[0])
123 | laser_calib[laser.name-1, :, :] = np.reshape(frame.context.laser_calibrations[calib_id].extrinsic.transform, [4, 4])
124 | pcd_paths[laser.name] = saving_name
125 | laser_mapping.update({open_dataset.LaserName.Name.Name(laser.name): calib_id})
126 |
127 | if 'intrinsic' in tracking_info:
128 | assert np.all(tracking_info['intrinsic']==intrinsic) and np.all(tracking_info['width']==width) and np.all(tracking_info['height']==height)
129 | else:
130 | tracking_info['intrinsic'],tracking_info['width'],tracking_info['height'] = intrinsic,width,height
131 | dict_2_save = {'per_cam_veh_pose':pose,'cam2veh':extrinsic,'im_paths':im_paths,'width':width,'height':height,
132 | 'veh2laser':laser_calib, 'pcd_paths': pcd_paths}
133 | if SAVE_INTRINSIC and not SINGLE_TRACK_INFO_FILE:
134 | dict_2_save['intrinsic'] = intrinsic
135 | lidar_labels = {}
136 | for obj_label in frame.laser_labels:
137 | lidar_labels[obj_label.id] = extract_label_fields(obj_label,3)
138 | dict_2_save['lidar_labels'] = lidar_labels
139 | dict_2_save['camera_labels'] = camera_labels
140 | dict_2_save['veh_pose'] = np.reshape(frame.pose.transform,[4,4])
141 | # dict_2_save['lidar2veh'] = np.reshape(frame.context.laser_calibrations['extrinsic'].transform,[4,4])
142 | dict_2_save['timestamp'] = frame.timestamp_micros
143 | if SINGLE_TRACK_INFO_FILE:
144 | tracking_info[(file_num,f_num)] = deepcopy(dict_2_save)
145 | else:
146 | with open(os.path.join(saving_dir,file_name, 'tracking','%03d.pkl'%(f_num)),'wb') as f:
147 | pickle.dump(dict_2_save,f)
148 | if SINGLE_TRACK_INFO_FILE:
149 | with open(os.path.join(saving_dir, file_name, 'tracking_info%s.pkl'%('_debug' if DEBUG else '')), 'wb') as f:
150 | pickle.dump(tracking_info, f)
--------------------------------------------------------------------------------
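A hedged sketch of reading back the tracking file written above when SINGLE_TRACK_INFO_FILE is enabled; the path is a placeholder, and the keys follow the dictionaries assembled in this script.

    import pickle

    with open('/path/to/<segment_name>/tracking_info.pkl', 'rb') as f:  # placeholder path
        tracking_info = pickle.load(f)

    intrinsic = tracking_info['intrinsic']  # (n_cams, 9): f_u, f_v, c_u, c_v, k_1, k_2, p_1, p_2, k_3
    frame_0 = tracking_info[(0, 0)]         # per-frame entries are keyed by (file_num, frame_num)
    print(frame_0['veh_pose'].shape)        # (4, 4) vehicle pose of that frame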
/src/pointLF/pointcloud_encoding/pointnet_features.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 |
7 | # https://github.com/fxia22/pointnet.pytorch
8 | class STN3d(nn.Module):
9 | def __init__(self):
10 | super(STN3d, self).__init__()
11 | self.conv1 = torch.nn.Conv1d(3, 64, 1)
12 | self.conv2 = torch.nn.Conv1d(64, 128, 1)
13 | self.conv3 = torch.nn.Conv1d(128, 1024, 1)
14 | self.fc1 = nn.Linear(1024, 512)
15 | self.fc2 = nn.Linear(512, 256)
16 | self.fc3 = nn.Linear(256, 9)
17 | self.relu = nn.ReLU()
18 |
19 | self.bn1 = nn.BatchNorm1d(64)
20 | self.bn2 = nn.BatchNorm1d(128)
21 | self.bn3 = nn.BatchNorm1d(1024)
22 | self.bn4 = nn.BatchNorm1d(512)
23 | self.bn5 = nn.BatchNorm1d(256)
24 |
25 | def forward(self, x):
26 | batchsize = x.size()[0]
27 | x = F.relu(self.bn1(self.conv1(x)))
28 | x = F.relu(self.bn2(self.conv2(x)))
29 | x = F.relu(self.bn3(self.conv3(x)))
30 | x = torch.max(x, 2, keepdim=True)[0]
31 | x = x.view(-1, 1024)
32 |
33 | if batchsize == 1:
34 | x = F.relu(self.fc1(x))
35 | x = F.relu(self.fc2(x))
36 | else:
37 | x = F.relu(self.bn4(self.fc1(x)))
38 | x = F.relu(self.bn5(self.fc2(x)))
39 | x = self.fc3(x)
40 |
41 | iden = Variable(torch.from_numpy(np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(np.float32))).view(1, 9).repeat(
42 | batchsize, 1)
43 | if x.is_cuda:
44 | iden = iden.cuda()
45 | x = x + iden
46 | x = x.view(-1, 3, 3)
47 | return x
48 |
49 |
50 | class STNkd(nn.Module):
51 | def __init__(self, k=64):
52 | super(STNkd, self).__init__()
53 | self.conv1 = torch.nn.Conv1d(k, 64, 1)
54 | self.conv2 = torch.nn.Conv1d(64, 128, 1)
55 | self.conv3 = torch.nn.Conv1d(128, 1024, 1)
56 | self.fc1 = nn.Linear(1024, 512)
57 | self.fc2 = nn.Linear(512, 256)
58 | self.fc3 = nn.Linear(256, k * k)
59 | self.relu = nn.ReLU()
60 |
61 | self.bn1 = nn.BatchNorm1d(64)
62 | self.bn2 = nn.BatchNorm1d(128)
63 | self.bn3 = nn.BatchNorm1d(1024)
64 | self.bn4 = nn.BatchNorm1d(512)
65 | self.bn5 = nn.BatchNorm1d(256)
66 |
67 | self.k = k
68 |
69 | def forward(self, x):
70 | batchsize = x.size()[0]
71 | x = F.relu(self.bn1(self.conv1(x)))
72 | x = F.relu(self.bn2(self.conv2(x)))
73 | x = F.relu(self.bn3(self.conv3(x)))
74 | x = torch.max(x, 2, keepdim=True)[0]
75 | x = x.view(-1, 1024)
76 |
77 | if batchsize == 1:
78 | x = F.relu(self.fc1(x))
79 | x = F.relu(self.fc2(x))
80 | else:
81 | x = F.relu(self.bn4(self.fc1(x)))
82 | x = F.relu(self.bn5(self.fc2(x)))
83 | x = self.fc3(x)
84 |
85 | iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1, self.k * self.k).repeat(
86 | batchsize, 1)
87 | if x.is_cuda:
88 | iden = iden.cuda()
89 | x = x + iden
90 | x = x.view(-1, self.k, self.k)
91 | return x
92 |
93 |
94 | class PointNetfeat(nn.Module):
95 | def __init__(self, global_feat=True, feature_transform=False):
96 | super(PointNetfeat, self).__init__()
97 | self.stn = STN3d()
98 | self.conv1 = torch.nn.Conv1d(3, 64, 1)
99 | self.conv2 = torch.nn.Conv1d(64, 128, 1)
100 | self.conv3 = torch.nn.Conv1d(128, 1024, 1)
101 | self.bn1 = nn.BatchNorm1d(64)
102 | self.bn2 = nn.BatchNorm1d(128)
103 | self.bn3 = nn.BatchNorm1d(1024)
104 | self.global_feat = global_feat
105 | self.feature_transform = feature_transform
106 | if self.feature_transform:
107 | self.fstn = STNkd(k=64)
108 |
109 | def forward(self, x):
110 | n_pts = x.size()[2]
111 | # input transform
112 | trans = self.stn(x)
113 | x = x.transpose(2, 1)
114 | x = torch.bmm(x, trans)
115 | x = x.transpose(2, 1)
116 | # (64,64)
117 | x = F.relu(self.bn1(self.conv1(x)))
118 |
119 | if self.feature_transform:
120 | # feature transform
121 | trans_feat = self.fstn(x)
122 | x = x.transpose(2, 1)
123 | x = torch.bmm(x, trans_feat)
124 | x = x.transpose(2, 1)
125 | else:
126 | trans_feat = None
127 |
128 | # (64, 128, 1024)
129 | pointfeat = x
130 | x = F.relu(self.bn2(self.conv2(x)))
131 | x = self.bn3(self.conv3(x))
132 |
133 | # Max pool -> global feature
134 | x = torch.max(x, 2, keepdim=True)[0]
135 | x = x.view(-1, 1024)
136 | if self.global_feat:
137 | # Returns only the global features
138 | return x, trans, trans_feat
139 | else:
140 | # Concatenates local and global features
141 | x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
142 | return torch.cat([x, pointfeat], 1), trans, trans_feat
143 |
144 |
145 | class PointNetDenseCls(nn.Module):
146 | def __init__(self, k=2, feature_transform=False):
147 | super(PointNetDenseCls, self).__init__()
148 | self.k = k
149 | self.feature_transform = feature_transform
150 | self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)
151 | self.conv1 = torch.nn.Conv1d(1088, 512, 1)
152 | self.conv2 = torch.nn.Conv1d(512, 256, 1)
153 | self.conv3 = torch.nn.Conv1d(256, 128, 1)
154 | self.conv4 = torch.nn.Conv1d(128, self.k, 1)
155 | self.bn1 = nn.BatchNorm1d(512)
156 | self.bn2 = nn.BatchNorm1d(256)
157 | self.bn3 = nn.BatchNorm1d(128)
158 |
159 | def forward(self, x, **kwargs):
160 | """
161 | x: [batch_size, 3, n_points]
162 | """
163 | batchsize = x.size()[0]
164 | n_pts = x.size()[2]
165 | x, trans, trans_feat = self.feat(x)
166 | x = F.relu(self.bn1(self.conv1(x)))
167 | x = F.relu(self.bn2(self.conv2(x)))
168 | x = F.relu(self.bn3(self.conv3(x)))
169 | x = self.conv4(x)
170 | x = x.transpose(2, 1).contiguous()
171 | x = F.log_softmax(x.view(-1, self.k), dim=-1)
172 | x = x.view(batchsize, n_pts, self.k)
173 | return x, trans, trans_feat
174 |
175 |
176 | class PointNetLightFieldEncoder(nn.Module):
177 | def __init__(self, k=2, feature_transform=False, points_only=False, original=False):
178 | super(PointNetLightFieldEncoder, self).__init__()
179 | self.k = k
180 | self.feature_transform = feature_transform
181 | self.points_only = points_only
182 | self.original = original
183 | self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)
184 | if not points_only and not original:
185 | self.conv1 = torch.nn.Conv1d(1024, 256, 1)
186 | self.conv2 = torch.nn.Conv1d(256, 64, 1)
187 | self.conv3 = torch.nn.Conv1d(128, self.k, 1)
188 |
189 | self.bn1 = nn.BatchNorm1d(256)
190 | self.bn2 = nn.BatchNorm1d(64)
191 | self.bn3 = nn.BatchNorm1d(self.k)
192 | elif original:
193 | self.conv1 = torch.nn.Conv1d(1088, 512, 1)
194 | self.conv2 = torch.nn.Conv1d(512, 256, 1)
195 | self.conv3 = torch.nn.Conv1d(256, 128, 1)
196 | self.conv4 = torch.nn.Conv1d(128, self.k, 1)
197 | self.bn1 = nn.BatchNorm1d(512)
198 | self.bn2 = nn.BatchNorm1d(256)
199 | self.bn3 = nn.BatchNorm1d(128)
200 | else:
201 | self.conv4 = torch.nn.Conv1d(64, self.k, 1)
202 | self.bn4 = nn.BatchNorm1d(self.k)
203 |
204 | def forward(self, x, **kwargs):
205 | """
206 | x: [batch_size, 3, n_points]
207 | """
208 | batchsize = x.size()[0]
209 | n_pts = x.size()[2]
210 | x, trans, trans_feat = self.feat(x)
211 | point_feat = x[:, 1024:, :]
212 | glob_feat = x[:, :1024, :]
213 |
214 | if not self.points_only and not self.original:
215 | x = F.relu(self.bn1(self.conv1(glob_feat)))
216 | x = F.relu(self.bn2(self.conv2(x)))
217 | x = torch.cat([x, point_feat], dim=1)
218 | x = F.relu(self.bn3(self.conv3(x)))
219 | elif self.original:
220 | x = F.relu(self.bn1(self.conv1(x)))
221 | x = F.relu(self.bn2(self.conv2(x)))
222 | x = F.relu(self.bn3(self.conv3(x)))
223 | x = self.conv4(x)
224 | else:
225 | x = F.relu(self.bn4(self.conv4(point_feat)))
226 |
227 | x = x.transpose(2, 1).contiguous()
228 | # x = F.log_softmax(x.view(-1, self.k), dim=-1)
229 | x = x.view(batchsize, n_pts, self.k)
230 | return x, trans, trans_feat
--------------------------------------------------------------------------------
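A quick shape check for PointNetfeat above, assuming the classes in this module are importable; a batch larger than one is used so the BatchNorm branches are exercised.

    import torch

    net = PointNetfeat(global_feat=False)
    net.eval()                         # use running BatchNorm statistics
    pts = torch.randn(4, 3, 1024)      # (batch, xyz, n_points)
    feat, trans, trans_feat = net(pts)
    print(feat.shape)                  # torch.Size([4, 1088, 1024]): 1024 global + 64 per-point channels
    print(trans.shape)                 # torch.Size([4, 3, 3]): learned input transform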
/trainval.py:
--------------------------------------------------------------------------------
1 | import importlib; importlib.util.find_spec("waymo_open_dataset")
2 | # assert importlib.util.find_spec("waymo_open_dataset") is not None, "no waymo"
3 |
4 | import argparse
5 | import os
6 | import exp_configs
7 | import pandas as pd
8 | import numpy as np
9 | import torch
10 | import time
11 |
12 | from torch.utils.data import DataLoader
13 | from haven import haven_utils as hu
14 | from haven import haven_wizard as hw
15 | from src import models
16 | from src.scenes import NeuralScene
17 | from src import utils as ut
18 | from src import utils_dist as utd
19 |
20 | torch.backends.cudnn.benchmark = True
21 |
22 | ONLY_PRESENT_SCORES = True
23 |
24 | # 1. define the training and validation function
25 | def trainval(exp_dict, savedir, args):
26 | """
27 | exp_dict: dictionary defining the hyperparameters of the experiment
28 | savedir: the directory where the experiment will be saved
29 | args: arguments passed through the command line
30 | """
31 | # set seed
32 | seed = 42 + exp_dict.get("runs", 0)
33 | np.random.seed(seed)
34 | torch.manual_seed(seed)
35 | torch.cuda.manual_seed_all(seed)
36 |
37 | model_state_dict = None
38 |
39 | render_epi = False
40 | if render_epi:
41 | exp_dict_pth = '/home/julian/workspace/NeuralSceneGraphs/model_library/scene_0_2/exp_dict.json'
42 | model_pth = '/home/julian/workspace/NeuralSceneGraphs/model_library/scene_0_2/model.pth'
43 | epi_frame_idx = 6
44 | epi_row = 96
45 |
46 | # exp_dict_pth = '/home/julian/workspace/NeuralSceneGraphs/model_library/scene_2_3/exp_dict.json'
47 | # model_pth = '/home/julian/workspace/NeuralSceneGraphs/model_library/scene_2_3/model.pth'
48 | # epi_frame_idx = 79
49 | # epi_row = 118
50 | # epi_row = 101
51 |
52 | if args.render_only and os.path.exists(exp_dict_pth):
53 | exp_dict = hu.load_json(exp_dict_pth)
54 | model_state_dict = hu.torch_load(model_pth)
55 |
56 | exp_dict["scale"] = 0.0625
57 | exp_dict["scale"] = 0.125
58 |
59 | scene = NeuralScene(
60 | scene_list=exp_dict["scenes"],
61 | datadir=args.datadir,
62 | args=args,
63 | exp_dict=exp_dict,
64 | )
65 |
66 | batch_size = min(len(scene), exp_dict.get("image_batch_size", 1))
67 | rand_sampler = torch.utils.data.RandomSampler(
68 | scene, num_samples=args.epoch_size * batch_size, replacement=True
69 | )
70 |
71 | scene_loader = torch.utils.data.DataLoader(
72 | scene,
73 | sampler=rand_sampler,
74 | collate_fn=ut.collate_fn_dict_of_lists,
75 | batch_size=batch_size,
76 | num_workers=args.num_workers,
77 | drop_last=True,
78 | )
79 |
80 | if scene.refine_camera_pose:
81 | calib_sampler = torch.utils.data.RandomSampler(
82 | scene, num_samples=len(scene), replacement=True
83 | )
84 | scene_calib_loader = torch.utils.data.DataLoader(
85 | scene,
86 | sampler=calib_sampler,
87 | collate_fn=ut.collate_fn_dict_of_lists,
88 | batch_size=1,
89 | num_workers=0,
90 | drop_last=True,
91 | )
92 | # TODO: Find permanent fix https://discuss.pytorch.org/t/runtimeerror-received-0-items-of-ancdata/4999
93 | # torch.multiprocessing.set_sharing_strategy('file_system')
94 |
95 | model = models.Model(scene, exp_dict, precache=exp_dict.get("precache"), args=args)
96 |
97 | # 3. load checkpoint
98 | chk_dict = hw.get_checkpoint(savedir, return_model_state_dict=True)
99 |
100 | if len(chk_dict["model_state_dict"]):
101 | model.set_state_dict(chk_dict["model_state_dict"])
102 |
103 | if model_state_dict is not None:
104 | model.set_state_dict(model_state_dict)
105 |
106 | # val_dict = model.val_on_scene(scene, savedir_images=os.path.join(savedir, "images"), all_frames=True)
107 |
108 | if not args.render_only:
109 | for e in range(chk_dict["epoch"], 5000):
110 | # 0. init score dict
111 | score_dict = {"epoch": e, "n_objects": len(scene.nodes["scene_object"])}
112 |
113 | # (3. Optional Camera Calibration)
114 | if e % 25 == 0 and scene.refine_camera_pose and e > 0:
115 | for e_calib in range(5):
116 | s_time = time.time()
117 | scene.recalibrate = True
118 | calib_dict = model.train_on_scene(scene_calib_loader)
119 | scene.recalibrate = False
120 | score_dict["calib_time"] = time.time() - s_time
121 | score_dict.update(calib_dict)
122 |
123 | # 1. train on batch
124 | s_time = time.time()
125 | train_dict = model.train_on_scene(scene_loader)
126 | score_dict["train_time"] = time.time() - s_time
127 | score_dict.update(train_dict)
128 |
129 | s_time = time.time()
130 | val_dict = model.val_on_scene(
131 | scene,
132 | savedir_images=os.path.join(savedir, "images"),
133 | )
134 |
135 | # 2. val on batch
136 | if e % 100 == 0 and e > 0:
137 | val_dict = model.val_on_scene(
138 | scene,
139 | savedir_images=os.path.join(savedir, "images_all_frames_{}".format(e)),
140 | all_frames=True,
141 | )
142 |
143 | score_dict["val_time"] = time.time() - s_time
144 | score_dict.update(val_dict)
145 | # ONLY MASTER PROCESS?
146 | if utd.is_main_process():
147 | # 3. save checkpoint
148 | chk_dict["score_list"] += [score_dict]
149 | hw.save_checkpoint(
150 | savedir,
151 | model_state_dict=model.get_state_dict(),
152 | score_list=chk_dict["score_list"],
153 | verbose=not ONLY_PRESENT_SCORES,
154 | )
155 | if ONLY_PRESENT_SCORES:
156 | score_df = pd.DataFrame(chk_dict["score_list"])
157 | print("Save directory: %s" % savedir)
158 | print(score_df.tail(1).to_string(index=False), "\n")
159 | elif render_epi:
160 | val_dict = model.val_on_scene(scene, savedir_images=os.path.join(savedir, "images"), all_frames=False, EPI=True,
161 | epi_row=epi_row, epi_frame_idx=epi_frame_idx)
162 | else:
163 | val_dict = model.val_on_scene(scene, savedir_images=os.path.join(savedir, "images"), all_frames=True)
164 |
165 |
166 | # 7. create main
167 | if __name__ == "__main__":
168 | # 9. Launch experiments using magic command
169 | parser = argparse.ArgumentParser()
170 | parser.add_argument(
171 | "-e", "--exp_group_list", nargs="+", help="Define which exp groups to run."
172 | )
173 | parser.add_argument(
174 | "-sb",
175 | "--savedir_base",
176 | default=None,
177 | help="Define the base directory where the experiments will be saved.",
178 | )
179 | parser.add_argument("-d", "--datadir")
180 | parser.add_argument(
181 | "-r", "--reset", default=0, type=int, help="Reset or resume the experiment."
182 | )
183 | parser.add_argument(
184 | "-j", "--job_scheduler", default=None, help="Run jobs in cluster."
185 | )
186 | parser.add_argument(
187 | "-v",
188 | "--visualize",
189 | default="results/neural_scenes.ipynb",
190 | help="Run jobs in cluster.",
191 | )
192 | parser.add_argument("-p", "--python_binary_path", default="python")
193 | parser.add_argument("-db", "--debug", type=int, default=0)
194 | parser.add_argument("--epoch_size", type=int, default=100)
195 | parser.add_argument("--num_workers", type=int, default=0)
196 | parser.add_argument("--render_only", type=bool, default=False)
197 | # parser.add_argument(
198 | # "--dist_url", default="env://", help="url used to set up distributed training"
199 | # )
200 | parser.add_argument("--ngpus", type=int, default=1)
201 |
202 | args, others = parser.parse_known_args()
203 |
204 | # Load job config to run things on cluster
205 | python_binary_path = args.python_binary_path
206 | jc = None
207 | if os.path.exists("job_config.py"):
208 | import job_config
209 |
210 | jc = job_config.JOB_CONFIG
211 | if args.ngpus > 1:
212 | jc["resources"]["gpu"] = args.ngpus
213 | python_binary_path += (
214 | f" -m torch.distributed.launch --nproc_per_node={args.ngpus} --use_env "
215 | )
216 |
217 | # utd.init_distributed_mode(args)
218 | # if args.distributed and not utd.is_main_process():
219 | # args.reset = 0
220 |
221 | hw.run_wizard(
222 | func=trainval,
223 | exp_groups=exp_configs.EXP_GROUPS,
224 | savedir_base=args.savedir_base,
225 | reset=args.reset,
226 | python_binary_path=python_binary_path,
227 | job_config=jc,
228 | args=args,
229 | use_threads=True,
230 | results_fname="results/neural_scenes.ipynb",
231 | )
232 |
--------------------------------------------------------------------------------
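A typical local invocation of this script, using one of the experiment groups defined in exp_configs; the data and save directories are placeholders.

    python trainval.py -e pointLF_waymo_local -d /path/to/processed_waymo -sb /path/to/results --num_workers 4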
/src/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | import pylab as plt
2 |
3 | # import tensorflow as tf
4 | # from ..nerf_tf.prepare_input_helper import extract_object_information
5 | import numpy as np
6 | import torch, os, copy, glob
7 |
8 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9 |
10 |
11 | def get_dataset(datadir, scene_dict, args):
12 | scene_dict = copy.deepcopy(scene_dict)
13 | selected_frames = None
14 | if scene_dict["last_frame"] is not None:
15 | selected_frames = [scene_dict["first_frame"], scene_dict["last_frame"]]
16 |
17 | if scene_dict['type'] == 'waymo':
18 | from .waymo import Waymo
19 |         # The scene id in Waymo combines the dataset number (e.g. 4 for 0004)
20 |         # and the index of the tfrecord when sorted alphabetically (e.g. 0 for segment-18446264979321894359_3700_000_3720_000_with_camera_labels.tfrecord)
21 | datadir_full = datadir + ''
22 | scene_id = scene_dict['scene_id']
23 | record_id = scene_id[1]
24 | subdirs_list = [d.name for d in os.scandir(datadir) if d.is_dir() if d.name[-4:].isnumeric()]
25 |         assert scene_id[0]
# ... x --> -z, y --> x, z --> y
164 | x_c_0 = np.matmul(poses[0, :, :], np.array([5.0, 0.0, 0.0, 1.0]))[:3]
165 | y_c_0 = np.matmul(poses[0, :, :], np.array([0.0, 5.0, 0.0, 1.0]))[:3]
166 | z_c_0 = np.matmul(poses[0, :, :], np.array([0.0, 0.0, 5.0, 1.0]))[:3]
167 | coord_cam_0 = [x_c_0, y_c_0, z_c_0]
168 | c_origin_0 = poses[0, :3, 3]
169 |
170 | plt.sca(ax_lst[0, 0])
171 | plt.arrow(
172 | c_origin_0[ax_birdseye[0]],
173 | c_origin_0[ax_birdseye[1]],
174 | coord_cam_0[ax_birdseye[0]][ax_birdseye[0]] - c_origin_0[ax_birdseye[0]],
175 | coord_cam_0[ax_birdseye[0]][ax_birdseye[1]] - c_origin_0[ax_birdseye[1]],
176 | color="red",
177 | width=0.1,
178 | )
179 | plt.arrow(
180 | c_origin_0[ax_birdseye[0]],
181 | c_origin_0[ax_birdseye[1]],
182 | coord_cam_0[ax_birdseye[1]][ax_birdseye[0]] - c_origin_0[ax_birdseye[0]],
183 | coord_cam_0[ax_birdseye[1]][ax_birdseye[1]] - c_origin_0[ax_birdseye[1]],
184 | color="green",
185 | width=0.1,
186 | )
187 | plt.axis("equal")
188 | plt.sca(ax_lst[1, 0])
189 | plt.arrow(
190 | c_origin_0[ax_xy[0]],
191 | c_origin_0[ax_xy[1]],
192 | coord_cam_0[ax_xy[0]][ax_xy[0]] - c_origin_0[ax_xy[0]],
193 | coord_cam_0[ax_xy[0]][ax_xy[1]] - c_origin_0[ax_xy[1]],
194 | color="red",
195 | width=0.1,
196 | )
197 | plt.arrow(
198 | c_origin_0[ax_xy[0]],
199 | c_origin_0[ax_xy[1]],
200 | coord_cam_0[ax_xy[1]][ax_xy[0]] - c_origin_0[ax_xy[0]],
201 | coord_cam_0[ax_xy[1]][ax_xy[1]] - c_origin_0[ax_xy[1]],
202 | color="green",
203 | width=0.1,
204 | )
205 | plt.axis("equal")
206 | plt.sca(ax_lst[0, 1])
207 | plt.arrow(
208 | c_origin_0[ax_zy[0]],
209 | c_origin_0[ax_zy[1]],
210 | coord_cam_0[ax_zy[0]][ax_zy[0]] - c_origin_0[ax_zy[0]],
211 | coord_cam_0[ax_zy[0]][ax_zy[1]] - c_origin_0[ax_zy[1]],
212 | color="red",
213 | width=0.1,
214 | )
215 | plt.arrow(
216 | c_origin_0[ax_zy[0]],
217 | c_origin_0[ax_zy[1]],
218 | coord_cam_0[ax_zy[1]][ax_zy[0]] - c_origin_0[ax_zy[0]],
219 | coord_cam_0[ax_zy[1]][ax_zy[1]] - c_origin_0[ax_zy[1]],
220 | color="green",
221 | width=0.1,
222 | )
223 | plt.axis("equal")
224 |
225 | # Plot global coord axis
226 | plt.sca(ax_lst[0, 0])
227 | plt.arrow(0, 0, 5, 0, color="cyan", width=0.1)
228 | plt.arrow(0, 0, 0, 5, color="cyan", width=0.1)
229 | plt.savefig(fname)
230 |
--------------------------------------------------------------------------------
/src/pointLF/scene_point_lightfield.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from ..scenes import NeuralScene
7 | from pytorch3d.renderer import RayBundle, ray_bundle_to_ray_points
8 | from pytorch3d.transforms import Transform3d, Translate, Rotate, Scale
9 |
10 |
11 | class PointLightFieldComposition(nn.Module):
12 | def __init__(self,
13 | scene: NeuralScene):
14 |
15 | super(PointLightFieldComposition, self).__init__()
16 |
17 | self._models = {}
18 | self._static_models = {}
19 | self._static_trafos = {}
20 |
21 | for node_types_dict in scene.nodes.values():
22 | for node in node_types_dict.values():
23 | if hasattr(node, 'lightfield'):
24 | if not node.static:
25 | self._models[node.scene_idx] = getattr(node, 'lightfield')
26 | else:
27 | self._static_models[node.scene_idx] = getattr(node, 'lightfield')
28 | self._static_trafos[node.scene_idx] = node.transformation[:3, -1]
29 |
30 | def forward(self,
31 | ray_bundle: RayBundle,
32 | scene: NeuralScene,
33 | closest_point_mask,
34 | pt_cloud_select,
35 | closest_point_dist,
36 | closest_point_azimuth,
37 | closest_point_pitch,
38 | output_dict,
39 | ray_dirs_select,
40 | rotate2cam=False,
41 | **kwargs,
42 | ):
43 | device = ray_bundle.origins.device
44 |
45 | ray_bundle = ray_bundle._replace(
46 | lengths=ray_bundle.lengths.detach(),
47 | directions=ray_bundle.directions.detach(),
48 | origins=ray_bundle.origins.detach(),
49 | )
50 |
51 | # closest_point_mask, [pt_cloud_select, closest_point_dist, closest_point_azimuth, closest_point_pitch, output_dict], ray_dirs_select = pts
52 |
53 | xycfn = ray_bundle.xys
54 |
55 | c = torch.zeros_like(ray_bundle.origins)
56 |
57 | n_batch_rays = min([v.shape[0] for v in closest_point_dist.values()])
58 | # sample_mask = {
59 | # cf_id :
60 | # torch.stack(torch.where(torch.all(xycfn[..., 2:4] == torch.tensor(cf_id, device=device), dim=-1)))
61 | # for cf_id in list(closest_point_dist.keys())
62 | # }
63 | sample_mask = {
64 | cf_id:
65 | torch.stack([
66 | torch.ones(xycfn.shape[-2], dtype=torch.int64, device=device) * j,
67 | torch.linspace(0, xycfn.shape[-2] - 1, xycfn.shape[-2], dtype=torch.int64, device=device)
68 | ])
69 | for j, cf_id in enumerate(list(closest_point_dist.keys()))
70 | }
71 |
72 | pt_cloud_select, closest_point_dist, closest_point_azimuth, closest_point_pitch, ray_dirs_select, closest_point_mask, sample_mask = \
73 | self.split_for_uneven_batch_sz(pt_cloud_select, closest_point_dist, closest_point_azimuth, closest_point_pitch, ray_dirs_select, closest_point_mask, sample_mask, n_batch_rays,
74 | device)
75 |
76 | for node_idx, model in self._static_models.items():
77 | x = None
78 | # TODO: Only get frame specific here
79 | for cf_id, mask in closest_point_mask.items():
80 |
81 | # Check if respective background and frame match
82 | frame = scene.frames[int(cf_id[1])]
83 | if any(node_idx == scene.frames[int(cf_id[1])].scene_matrix[:, 0]):
84 | # TODO: Multiple frames at once
85 | # Get projected rgb
86 | # closest_rgb_fr = self._get_closest_rgb(scene, ray_bundle, point_cloud=pt_cloud_select[cf_id], c=int(cf_id[0]), f=int(cf_id[1]), device=device)
87 | closest_rgb_fr = torch.zeros_like(pt_cloud_select[cf_id])
88 |
89 | # TODO: rotation of the point cloud as part of the sampler or optional
90 | if rotate2cam:
91 | li_idx, li_node = list(scene.nodes['lidar'].items())[0]
92 | assert li_node.name == "TOP"
93 | cam_ed = frame.get_edge_by_child_idx([cf_id[0]])[0][0]
94 | li_ed = frame.get_edge_by_child_idx([li_idx])[0][0]
95 | # Transform x from li2world2cam
96 | # li2world
97 | li2wo = li_ed.get_transformation_c2p().to(device)
98 | # world2cam
99 | wo2cam = cam_ed.get_transformation_p2c().to(device)
100 |
101 | pt_cloud_select[cf_id] = wo2cam.transform_points(li2wo.transform_points(pt_cloud_select[cf_id]))
102 |
103 | # Check if intersections could be found inside frustum
104 | if mask is not None:
105 | if x is None:
106 | x = pt_cloud_select[cf_id][None]
107 | x_dist = closest_point_dist[cf_id][None]
108 | azimuth = closest_point_azimuth[cf_id][None]
109 | pitch = closest_point_pitch[cf_id][None]
110 | ray_dirs = ray_dirs_select[cf_id][None]
111 | closest_mask = closest_point_mask[cf_id][None]
112 | # sample_mask = torch.stack(torch.where(torch.all(xycfn[..., 2:4] == cf_id, dim=-1)))[None]
113 | closest_rgb = closest_rgb_fr[None]
114 | projected_dist = ray_bundle.lengths[tuple(sample_mask[cf_id])][None]
115 | else:
116 | # print(cf_id)
117 | # print(pt_cloud_select[cf_id][None].shape)
118 | # print(x.shape)
119 | x = torch.cat([x, pt_cloud_select[cf_id][None]])
120 | x_dist = torch.cat([x_dist, closest_point_dist[cf_id][None]])
121 | azimuth = torch.cat([azimuth, closest_point_azimuth[cf_id][None]])
122 | pitch = torch.cat([pitch, closest_point_pitch[cf_id][None]])
123 | ray_dirs = torch.cat([ray_dirs, ray_dirs_select[cf_id][None]])
124 | closest_mask = torch.cat([closest_mask, closest_point_mask[cf_id][None]])
125 | # new_fr_mask = torch.stack(torch.where(torch.all(xycfn[..., 2:4] == cf_id, dim=-1)))[None]
126 | # sample_mask = torch.cat([sample_mask, new_fr_mask])
127 | closest_rgb = torch.cat([closest_rgb, closest_rgb_fr[None]])
128 |
129 | projected_dist = torch.cat([projected_dist, ray_bundle.lengths[tuple(sample_mask[cf_id])][None]])
130 |
131 | sample_idx = list(closest_point_mask.keys())
132 | if x is not None:
133 | if self.training:
134 | # start = torch.cuda.Event(enable_timing=True)
135 | # end = torch.cuda.Event(enable_timing=True)
136 | # start.record()
137 | color, output_dict = model(x, ray_dirs, closest_mask, x_dist, x_proj=projected_dist, x_pitch=pitch,
138 | x_azimuth=azimuth, rgb=closest_rgb, sample_idx=sample_idx)
139 | # end.record()
140 | # torch.cuda.synchronize()
141 | # print('CUDA Time: {}'.format(start.elapsed_time(end)))
142 | else:
143 | color, output_dict = model(x, ray_dirs, closest_mask, x_dist, x_proj=projected_dist, x_pitch=pitch,
144 | x_azimuth=azimuth, rgb=closest_rgb, sample_idx=sample_idx)
145 | for (sample_color, mask) in zip(color, list(sample_mask.values())):
146 | c[tuple(mask)] = sample_color
147 |
148 | return c, output_dict
149 |
150 |
151 | def split_for_uneven_batch_sz(self, pt_cloud_select, closest_point_dist, closest_point_azimuth, closest_point_pitch, ray_dirs_select, closest_point_mask, sample_mask, n_batch_rays,
152 | device):
153 | for cf_id in list(closest_point_dist.keys()):
154 | v = closest_point_dist[cf_id]
155 | if v.shape[0] > n_batch_rays:
156 | factor = v.shape[0] // n_batch_rays
157 | for i in range(factor):
158 | cf_id_new = cf_id + tuple([i * 1])
159 | pt_cloud_select[cf_id_new] = pt_cloud_select[cf_id]
160 | closest_point_dist[cf_id_new] = closest_point_dist[cf_id][i * n_batch_rays: (i + 1) * n_batch_rays]
161 | closest_point_azimuth[cf_id_new] = closest_point_azimuth[cf_id][i * n_batch_rays: (i + 1) * n_batch_rays]
162 | closest_point_pitch[cf_id_new] = closest_point_pitch[cf_id][i * n_batch_rays: (i + 1) * n_batch_rays]
163 | ray_dirs_select[cf_id_new] = ray_dirs_select[cf_id][i * n_batch_rays: (i + 1) * n_batch_rays]
164 | closest_point_mask[cf_id_new] = closest_point_mask[cf_id][i * n_batch_rays: (i + 1) * n_batch_rays]
165 | sample_mask[cf_id_new] = sample_mask[cf_id][:, i * n_batch_rays: (i + 1) * n_batch_rays]
166 |
167 | del pt_cloud_select[cf_id]
168 | del closest_point_dist[cf_id]
169 | del closest_point_azimuth[cf_id]
170 | del closest_point_pitch[cf_id]
171 | del ray_dirs_select[cf_id]
172 | del closest_point_mask[cf_id]
173 | del sample_mask[cf_id]
174 |
175 | return pt_cloud_select, closest_point_dist, closest_point_azimuth, closest_point_pitch, ray_dirs_select, closest_point_mask, sample_mask
176 |
177 |
178 | def _get_closest_rgb(self, scene, ray_bundle, point_cloud, c, f, device):
179 | rgb = None
180 | img = torch.tensor(scene.frames[f].load_image(c), device=device)
181 |
182 | cam = scene.nodes['camera'][c]
183 | cam_ed = scene.frames[f].get_edge_by_child_idx([c])[0][0]
184 | # Get Camera Intrinsics and Extrinsics
185 | cam_focal = cam.intrinsics.f_x
186 | cam_H = cam.intrinsics.H
187 | cam_W = cam.intrinsics.W
188 | cam_rot = Rotate(cam.R[None, ...].to(device), device=device)
189 | cam_transform = cam_ed.getTransformation()
190 |
191 | cam_w_xyz = point_cloud[:, [1, 2, 0]]
192 | cam_xy = cam_focal * (cam_w_xyz[:, [0, 1]] / cam_w_xyz[:, 2, None])
193 | cam_uv = -cam_xy + torch.tensor([[cam_W / 2, cam_H / 2]], device=device)
194 | cam_uv = cam_uv[:, [1, 0]].to(dtype=torch.int64)
195 | cam_uv = torch.maximum(cam_uv, torch.tensor(0, device=device))
196 | cam_uv[:, 0] = torch.minimum(cam_uv[:, 0], torch.tensor(cam_H - 1, device=device))
197 | cam_uv[:, 1] = torch.minimum(cam_uv[:, 1], torch.tensor(cam_W - 1, device=device))
198 | rgb = img[tuple(cam_uv.T)].detach()
199 |
200 | # img[tuple(cam_uv.T)] = np.array([1.0, 0., 0.])
201 | # f3 = plt.figure()
202 | # plt.imshow(img)
203 |
204 | return rgb
--------------------------------------------------------------------------------
/src/renderer/losses.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from ..scenes import NeuralScene
7 | from pytorch3d.renderer.implicit.utils import RayBundle
8 | import imageio
9 |
10 |
11 | def mse_loss(x, y, colorspace="RGB"):
12 | if "HS" in colorspace:
13 | # x,y = rgb2hsv(x),rgb2hsv(y)
14 | x, y = rgb2NonModuloHSV(x), rgb2NonModuloHSV(y)
15 | if colorspace == "HS":
16 | x, y = x[:, :-1], y[:, :-1]
17 | return torch.nn.functional.mse_loss(x, y)
18 |
19 |
20 | def calc_uncertainty_weighted_loss(
21 | x: torch.Tensor, y: torch.Tensor, uncertainty: torch.Tensor, **kwargs
22 | ):
23 | uncertainty = uncertainty + 0.1
24 | uncertainty_variance = uncertainty ** 2
25 | # TODO: Check if mean or sum in loss function
26 | uncert_weighted_mse = torch.mean((x - y) ** 2 / (2 * uncertainty_variance))
27 | uncert_loss = (
28 | torch.mean(torch.log(uncertainty_variance) / 2) + 3
29 | ) # +3 to make it positive
30 | return uncert_weighted_mse + uncert_loss
31 |
32 |
33 | def calc_density_regularizer(densities: torch.Tensor, reg: torch.int, **kwargs):
34 | return reg * torch.mean(densities)
35 |
36 |
37 | def calc_weight_regularizer(
38 | weights: torch.Tensor,
39 | reg: torch.int,
40 | xycfn: torch.Tensor,
41 | scene: NeuralScene,
42 | **kwargs,
43 | ):
44 | background_weights = []
45 |
46 | n_rays, n_sampling_pts = weights.shape
47 | short_xycfn = xycfn[..., -n_sampling_pts:, :]
48 | short_xycfn = short_xycfn.reshape(n_rays, n_sampling_pts, -1)
49 |
50 | camera_ids_sampled = short_xycfn[:, 0, 2].unique()
51 | for cf in short_xycfn[:, 0, 2:4].unique(dim=0):
52 | c = int(cf[0])
53 | f = int(cf[1])
54 | frame = scene.frames[int(f)]
55 |
56 | cf_mask = torch.all(short_xycfn[..., 2:4] == cf, dim=-1)
57 | cf_mask = torch.where(cf_mask[:, 0])[0]
58 |
59 | relevant_xy = short_xycfn[cf_mask, 0, :2]
60 | xy_mask = tuple(relevant_xy[:, [1, 0]].T)
61 | segmentation_mask = frame.load_mask(c, return_xycfn=False)[xy_mask] == 0
62 | if xy_mask[0].shape[0] == 1:
63 | continue
64 | background_weights.append(weights[cf_mask[segmentation_mask]])
65 |
66 | background_weights = torch.cat(background_weights)
67 |
68 | return reg * torch.mean(background_weights)
69 |
70 |
71 | def xycfn2list(xycfn):
72 | xycfn_list = []
73 | xycfn = xycfn.squeeze(0)
74 | for cf in xycfn[:, 0, 2:4].unique(dim=0):
75 | c = int(cf[0])
76 | f = int(cf[1])
77 |
78 | cf_mask = torch.all(xycfn[..., 2:4] == cf, dim=-1)
79 | cf_mask = torch.where(cf_mask[:, 0])[0]
80 |
81 | yx = xycfn[cf_mask, 0, :2]
82 | xycfn_list += [{"f": f, "c": c, "yx": yx}]
83 |
84 | return xycfn_list
85 |
86 |
87 | def calc_psnr(x: torch.Tensor, y: torch.Tensor):
88 | mse = torch.nn.functional.mse_loss(x, y)
89 | psnr = -10.0 * torch.log10(mse)
90 | return psnr
91 |
92 |
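# A small worked example for calc_psnr above (illustrative sketch, not part of the
# original module): with images scaled to [0, 1], PSNR = -10 * log10(MSE), so a
# uniform per-pixel error of 0.1 gives MSE = 0.01 and therefore a PSNR of 20 dB.
#
# x = torch.full((4, 3), 0.5)
# y = torch.full((4, 3), 0.6)
# calc_psnr(x, y)  # tensor(20.) up to floating point error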
93 | # def calc_latent_dist(
94 | # xycfn, scene, reg, transient_frame_embbeding_reg=0, scene_function=None
95 | # ):
96 | # loss = 0
97 |
98 | # trainable_latent_nodes_id = list(scene.nodes["scene_object"].keys())
99 |
100 | # # Remove "non"-nodes from the rays
101 | # latent_nodess_id = xycfn[..., 4].unique().tolist()
102 | # try:
103 | # latent_nodess_id.remove(-1)
104 | # except:
105 | # pass
106 |
107 | # # Just include nodes that have latent arrays
108 | # latent_nodess_id = set(trainable_latent_nodes_id) & set(latent_nodess_id)
109 |
110 | # if len(latent_nodess_id) != 0:
111 | # latent_codes = torch.stack(
112 | # [
113 | # torch.cat(list(scene.nodes[i]["node"].latent.values()))
114 | # for i in latent_nodess_id
115 | # ]
116 | # )
117 | # latent_dist = torch.sum(reg * torch.norm(latent_codes, dim=-1))
118 | # else:
119 | # latent_dist = torch.tensor(0.0)
120 |
121 | # loss += latent_dist
122 |
123 | # if transient_frame_embbeding_reg:
124 | # image_latent = []
125 | # fn = xycfn[..., 3:5].view(-1, 2).unique(dim=0)
126 |
127 | # tranient_latents = []
128 | # for (f, n) in fn:
129 | # # Get Transient embeddings for each node and frame combination
130 | # key = f"{f}_{n}"
131 | # if key in scene_function.transient_object_embeddings:
132 | # fn_idx = torch.where(
133 | # torch.all(fn == torch.tensor([f, n], device=fn.device), dim=1)
134 | # )
135 | # transient_embedding_object_frame = (
136 | # scene_function.transient_object_embeddings[key]
137 | # )
138 | # tranient_latents.append(transient_embedding_object_frame)
139 |
140 | # # Find a loss on all those to be similar
141 | # image_latent = torch.stack(tranient_latents)
142 | # loss += transient_frame_embbeding_reg * torch.sum(
143 | # torch.std(image_latent, dim=0)
144 | # )
145 |
146 | # return loss
147 |
148 |
149 | def extract_objects_HS_stats(unique_obj_IDs, RGB, per_pix_obj_ID, Hue2Cartesian=True):
150 |     # It is actually not necessary to do this in PyTorch. Currently we are using GT colors, so we might as well work in Numpy since we don't need to backpropagate through it.
151 | if Hue2Cartesian:
152 | color_vals = rgb2NonModuloHSV(RGB)[:, :-1]
153 | else:
154 | color_vals = rgb2hsv(RGB)[:, :-1]
155 | per_obj_vals = dict(
156 | zip(
157 | unique_obj_IDs, [color_vals[per_pix_obj_ID == i, :] for i in unique_obj_IDs]
158 | )
159 | )
160 | # per_obj_color_stats = dict([(k,(np.mean(v.cpu().numpy(),0),np.cov(v.cpu().numpy().transpose()))) for k,v in per_obj_vals.items()])
161 | # color_means,color_covs = np.stack([per_obj_color_stats[k][0] for k in unique_obj_IDs]),np.stack([per_obj_color_stats[k][1] for k in unique_obj_IDs])
162 | per_obj_color_stats = dict(
163 | [(k, (torch.mean(v, 0), cov(v))) for k, v in per_obj_vals.items()]
164 | )
165 | color_means, color_covs = torch.stack(
166 | [per_obj_color_stats[k][0] for k in unique_obj_IDs]
167 | ), torch.stack([per_obj_color_stats[k][1] for k in unique_obj_IDs])
168 | return color_means, color_covs
169 |
170 |
171 | def calc_latent_color_loss(unique_obj_IDs, RGB, per_pix_obj_ID, extract_latent_fn):
172 | color_means, color_covs = extract_objects_HS_stats(
173 | unique_obj_IDs, RGB, per_pix_obj_ID
174 | )
175 | MEANS_WEIGHT = 0.5
176 | latent_dists, color_dists = [], []
177 | for ind1 in range(len(unique_obj_IDs)):
178 | for ind2 in range(ind1 + 1, len(unique_obj_IDs)):
179 | latent_dist = torch.norm(
180 | extract_latent_fn(ind1) - extract_latent_fn(ind2), p=2, dim=0
181 | ) # /np.sqrt(latent_size)
182 | latent_dists.append(latent_dist)
183 | # color_dist = MEANS_WEIGHT*np.linalg.norm(color_means[ind1,:]-color_means[ind2,:],ord=2)/np.sqrt(color_means.shape[1])
184 | # color_dist += (1-MEANS_WEIGHT)*np.linalg.norm(color_covs[ind1,...]-color_covs[ind2,...],ord='fro')/color_means.shape[1]
185 | color_dist = MEANS_WEIGHT * torch.norm(
186 | color_means[ind1, :] - color_means[ind2, :], p=2, dim=0
187 | ) # /np.sqrt(color_means.shape[1])
188 | color_dist += (1 - MEANS_WEIGHT) * torch.norm(
189 | color_covs[ind1, ...] - color_covs[ind2, ...], p="fro"
190 | ) # /color_means.shape[1]
191 | color_dists.append(color_dist)
192 | latent_dists = torch.stack(latent_dists)
193 | color_dists = torch.stack(color_dists)
194 | # color_dists = np.array(color_dists)/np.mean(color_dists)*torch.mean(latent_dists).item()
195 | # return torch.mean((torch.tensor(color_dists).type(latent_dists.type())-latent_dists)**2)
196 | color_dists = color_dists / torch.mean(color_dists) * torch.mean(latent_dists)
197 | return torch.mean((color_dists - latent_dists) ** 2)
198 |
199 |
200 | def prod_density_distribution(
201 | obj_weights: torch.Tensor, transient_weights: torch.Tensor
202 | ):
203 | obj_weights = obj_weights + 1e-5
204 | obj_sample_pdf = obj_weights / torch.sum(obj_weights, -1, keepdim=True)
205 |
206 | transient_weights = transient_weights + 1e-5
207 | transient_sample_pdf = transient_weights / torch.sum(
208 | transient_weights, -1, keepdim=True
209 | )
210 |
211 | return torch.mean(100 * obj_sample_pdf * transient_sample_pdf)
212 |
213 |
214 | def get_rgb_gt(
215 | rgb: torch.Tensor, scene: NeuralScene, xycfn: torch.Tensor, use_gt_mask=False
216 | ):
217 | xycf = xycfn[..., 0, :4].reshape(len(rgb), -1)
218 | rgb_gt = torch.ones_like(rgb) * -1
219 |
220 |     # TODO: Make more efficient by not retrieving the image for each pixel,
221 |     # but storing all gt_images in a single tensor on the cpu
222 |     # TODO: During test time just get all images available
223 | camera_ids_sampled = xycf[:, 2].unique()
224 | for f in xycf[:, 3].unique():
225 | frame = scene.frames[int(f)]
226 | for c in frame.camera_ids:
227 | if c not in camera_ids_sampled:
228 | continue
229 | cf_mask = torch.all(
230 | xycf[:, 2:] == torch.tensor([c, f], device=xycf.device), dim=1
231 | )
232 | xy = xycf[cf_mask, :2].cpu()
233 |
234 | gt_img = frame.load_image(int(c))
235 | if use_gt_mask == 2:
236 | gt_img[frame.load_mask(c, return_xycfn=False) == 0] = 0
237 | if use_gt_mask == 3:
238 | gt_img[frame.load_mask(c, return_xycfn=False) == 0] = 1
239 | gt_px = torch.from_numpy(gt_img[xy[:, 1], xy[:, 0]]).to(
240 | device=rgb.device, dtype=rgb.dtype
241 | )
242 | rgb_gt[cf_mask] = gt_px
243 |
244 | assert (rgb_gt == -1).sum() == 0
245 | return rgb_gt
246 |
247 |
248 | def rgb2hsv(rgb_vect):
249 | img = rgb_vect # * 0.5 + 0.5
250 | per_pix_max, per_pix_min = img.max(1)[0], img.min(1)[0]
251 | delta = per_pix_max - per_pix_min
252 | delta_is_0 = delta == 0
253 | hue = torch.zeros([img.shape[0]]).to(img.device)
254 | max_is_R = torch.logical_and(
255 | img[:, 0] == per_pix_max, torch.logical_not(delta_is_0)
256 | )
257 | max_is_G = torch.logical_and(
258 | img[:, 1] == per_pix_max, torch.logical_not(delta_is_0)
259 | )
260 | max_is_B = torch.logical_and(
261 | img[:, 2] == per_pix_max, torch.logical_not(delta_is_0)
262 | )
263 | hue[max_is_B] = 4.0 + ((img[max_is_B, 0] - img[max_is_B, 1]) / delta[max_is_B])
264 | hue[max_is_G] = 2.0 + ((img[max_is_G, 2] - img[max_is_G, 0]) / delta[max_is_G])
265 | hue[max_is_R] = (
266 | 0.0 + ((img[max_is_R, 1] - img[max_is_R, 2]) / delta[max_is_R])
267 | ) % 6
268 |
269 | hue[delta_is_0] = 0.0
270 | hue = hue / 6
271 |
272 | saturation = torch.zeros_like(hue)
273 | max_is_0 = per_pix_max == 0
274 | saturation[torch.logical_not(max_is_0)] = (
275 | delta[torch.logical_not(max_is_0)] / per_pix_max[torch.logical_not(max_is_0)]
276 | )
277 | saturation[max_is_0] = 0.0
278 |
279 | value = per_pix_max
280 | return torch.stack([hue, saturation, value], 1)
281 |
282 |
283 | def hue2cartesian(hue):
284 | return 0.5 * torch.stack(
285 | [torch.cos(2 * np.pi * hue), torch.sin(2 * np.pi * hue)], 1
286 | )
287 |
288 |
289 | def rgb2NonModuloHSV(rgb_vect):
290 | HSV = rgb2hsv(rgb_vect)
291 | return torch.cat([hue2cartesian(HSV[:, 0]), HSV[:, 1:]], 1)
292 |
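# Why hue2cartesian/rgb2NonModuloHSV exist (illustrative note, not part of the original
# module): hue is periodic, so hues 0.01 and 0.99 describe nearly the same color but
# differ by 0.98 in raw hue space. Mapping hue onto a circle of radius 0.5 removes the
# wrap-around, e.g.:
#
# h = torch.tensor([0.01, 0.99])
# p = hue2cartesian(h)                 # two points on a circle of radius 0.5
# torch.norm(p[0] - p[1])              # ~0.06 instead of |0.99 - 0.01| = 0.98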
293 |
294 | def cov(x, rowvar=False, bias=False, ddof=None, aweights=None):
295 | # From https://github.com/pytorch/pytorch/issues/19037
296 | """Estimates covariance matrix like numpy.cov"""
297 | # ensure at least 2D
298 | if x.dim() == 1:
299 | x = x.view(-1, 1)
300 |
301 | # treat each column as a data point, each row as a variable
302 | if rowvar and x.shape[0] != 1:
303 | x = x.t()
304 |
305 | if ddof is None:
306 | if bias == 0:
307 | ddof = 1
308 | else:
309 | ddof = 0
310 |
311 | w = aweights
312 | if w is not None:
313 | if not torch.is_tensor(w):
314 | w = torch.tensor(w, dtype=torch.float)
315 | w_sum = torch.sum(w)
316 | avg = torch.sum(x * (w / w_sum)[:, None], 0)
317 | else:
318 | avg = torch.mean(x, 0)
319 |
320 | # Determine the normalization
321 | if w is None:
322 | fact = x.shape[0] - ddof
323 | elif ddof == 0:
324 | fact = w_sum
325 | elif aweights is None:
326 | fact = w_sum - ddof
327 | else:
328 | fact = w_sum - ddof * torch.sum(w * w) / w_sum
329 |
330 | xm = x.sub(avg.expand_as(x))
331 |
332 | if w is None:
333 | X_T = xm.t()
334 | else:
335 | X_T = torch.mm(torch.diag(w), xm).t()
336 |
337 | c = torch.mm(X_T, xm)
338 | c = c / fact
339 |
340 | return c.squeeze()
341 |
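# Consistency check for cov() above (sketch, not part of the original module): with the
# default rowvar=False each row is an observation and each column a variable, which is
# the opposite of numpy's default, so the equivalent call is numpy.cov(x, rowvar=False).
#
# x = torch.randn(100, 3)
# ref = torch.from_numpy(np.cov(x.numpy(), rowvar=False)).to(torch.float32)
# torch.allclose(cov(x), ref, atol=1e-5)     # expected: True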
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import tqdm, argparse
2 | import os, exp_configs
3 | import copy
4 | import pandas as pd
5 | import numpy as np
6 | import torch
7 | from PIL import Image
8 | from haven import haven_utils as hu
9 | from haven import haven_examples as he
10 | from haven import haven_wizard as hw
11 | from haven import haven_results as hr
12 | from src import models
13 | from src.scenes import NeuralScene
14 | from .scenes import createCamera, createStereoCamera
15 | # from src.renderer import NeuralSceneRenderer
16 | from pytorch3d.transforms.rotation_conversions import axis_angle_to_quaternion, quaternion_to_axis_angle
17 | import time
18 | from pytorch3d.transforms.so3 import so3_exponential_map, so3_log_map
19 | from pytorch3d.transforms.rotation_conversions import (
20 | euler_angles_to_matrix,
21 | matrix_to_euler_angles,
22 | )
23 |
24 |
25 | def mask_to_xycfn(frame_idx, camera_idx, mask):
26 | xycfn_list = []
27 | for node_id in np.unique(mask):
28 | if node_id == 0:
29 | continue
30 | y, x = np.where(mask == node_id)
31 | xycfn_i = np.zeros((y.shape[0], 5))
32 | xycfn_i[:, 4] = node_id
33 | xycfn_i[:, 3] = frame_idx
34 | xycfn_i[:, 2] = camera_idx
35 | xycfn_i[:, 1] = y
36 | xycfn_i[:, 0] = x
37 | xycfn_list += [xycfn_i]
38 | xycfn = np.vstack(xycfn_list)
39 | return xycfn
40 |
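# Layout example for the xycfn array built above (illustrative, not part of the original
# module): each row is one foreground pixel with columns [x, y, camera_idx, frame_idx,
# node_id]; background pixels (node_id == 0) are skipped.
#
# mask = np.array([[0, 2],
#                  [2, 0]])
# mask_to_xycfn(frame_idx=1, camera_idx=0, mask=mask)
# # -> two rows, one per pixel of node 2: [1, 0, 0, 1, 2] and [0, 1, 0, 1, 2]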
41 |
42 | class to_args(object):
43 | def __init__(self, adict):
44 | self.__dict__.update(adict)
45 |
46 |
47 | def collate_fn_lists_of_dicts(batch):
48 | return batch
49 |
50 |
51 | def collate_fn_dict_of_lists(batch):
52 | import functools
53 |
54 | return {
55 | key: [item[key] for item in batch]
56 | for key in list(
57 | functools.reduce(
58 | lambda x, y: x.union(y), (set(dicts.keys()) for dicts in batch)
59 | )
60 | )
61 | }
62 |
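# Behaviour sketch for collate_fn_dict_of_lists above (illustrative only): a batch given
# as a list of per-sample dicts becomes a single dict of lists keyed by the union of all
# sample keys; every sample is expected to contain every key, otherwise item[key] raises
# a KeyError.
#
# batch = [{"rgb": 1, "mask": 2}, {"rgb": 3, "mask": 4}]
# collate_fn_dict_of_lists(batch)   # {"rgb": [1, 3], "mask": [2, 4]}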
63 |
64 | def display_points(gt_image, out, color):
65 | obj_mask = torch.where(out["xycfn"][..., -1].ge(1))
66 | xycf_relevant = out["xycfn"][obj_mask][..., :4]
67 | for f in xycf_relevant[..., 3].unique():
68 | # frame = scene.frames[int(f)]
69 | for c in xycf_relevant[:, 2].unique():
70 | cf_mask = torch.all(
71 | xycf_relevant[:, 2:]
72 | == torch.tensor([c, f], device=xycf_relevant.device),
73 | dim=1,
74 | )
75 | xy = xycf_relevant[cf_mask, :2].cpu()
76 | # c_id = scene.getNodeBySceneId(int(c)).type_idx
77 | gt_img = gt_image.copy()
78 | gt_img[xy[:, 1], xy[:, 0]] = np.array(color)
79 | return gt_img
80 |
81 |
82 | def add_camera_path_frame(scene, frame_idx=[0], cam_idx=None, n_steps=9, remove_old_cameras=True, offset=0.):
83 | edges_to_cameras_ls = []
84 | imgs = []
85 | frames = []
86 | cams = []
87 |
88 | if cam_idx is None:
89 | cam_idx = [None] * len(frame_idx)
90 |
91 | max_ind_list = len(scene.frames_cameras)
92 |
93 | for frame_id, c_id in zip(frame_idx, cam_idx):
94 | # Extract all information from the frame for which cameras should be interpolated
95 | selected_frame = scene.frames[frame_id]
96 | frames += [selected_frame]
97 | cam_nodes = scene.nodes['camera']
98 | if c_id is None:
99 | # KITTI (single camera)
100 | cams += [list(cam_nodes.values())[0]]
101 |
102 | edges_to_cameras = selected_frame.get_edge_by_child_idx(list(cam_nodes.keys()))
103 | edges_to_cameras_ls += [ed[0] for ed in edges_to_cameras]
104 |
105 | # edge_ids_to_cameras_ls = selected_frame.get_edge_idx_by_child_idx(list(cam_nodes.keys()))
106 | # edge_ids_to_cameras_ls = [ed_id[0] for ed_id in edge_ids_to_cameras_ls]
107 |
108 | imgs += [list(selected_frame.images.values())[0]]
109 | else:
110 | cams += [cam_nodes[c_id]]
111 |
112 | edges_to_cameras = selected_frame.get_edge_by_child_idx([c_id])
113 | edges_to_cameras_ls += [ed[0] for ed in edges_to_cameras]
114 |
115 | # edge_ids_to_cameras_ls = selected_frame.get_edge_idx_by_child_idx([c_id])
116 | # edge_ids_to_cameras_ls = [ed_id[0] for ed_id in edge_ids_to_cameras_ls]
117 |
118 | imgs += [selected_frame.images[c_id]]
119 |
120 | scene_dict = {}
121 | for fr_c in scene.frames_cameras:
122 | if fr_c[0] == frame_id:
123 | scene_dict = fr_c[2]
124 | scene_id = scene_dict['scene_id']
125 | break
126 |
127 | add_new_val_render_path(frames, cams, edges_to_cameras_ls, imgs, scene, scene_dict, n_steps,
128 | remove_old_cameras, offset)
129 |
130 | new_max_ind_list = len(scene.frames_cameras)
131 | ind_list = np.linspace(max_ind_list, new_max_ind_list - 1, new_max_ind_list - max_ind_list, dtype=int)
132 |
133 | return scene, ind_list
134 |
135 |
136 | def add_new_val_render_path(frames, cams, edges_to_cameras_ls, img_path, scene, scene_dict, n_steps=3, remove_old_cameras=False,
137 | offset=0.):
138 | selected_frame = frames[0]
139 | copied_cam = cams[0]
140 | cam_ids = scene.nodes['camera'].keys()
141 |
142 | fr_edge_ids_to_cameras_ls = [[fr.frame_idx, ed.index] for fr, ed in zip(frames, edges_to_cameras_ls)]
143 | # Create a new frame as a copy of the selected frame
144 | new_frame = copy.deepcopy(selected_frame)
145 |
146 | # Remove image paths from unused cameras
147 | if remove_old_cameras:
148 | for c_id in cam_ids:
149 | if c_id in new_frame.images:
150 | del new_frame.images[c_id]
151 | if len(new_frame.get_edge_by_child_idx([c_id])[0]) > 0:
152 | new_frame.camera_ids.remove(c_id)
153 |
154 | # Get Camera poses
155 | new_rotations, new_translations = interpolate_between_camera_edges(edges_to_cameras_ls, n_steps=n_steps, offset=offset)
156 | new_rotations_no, new_translations_no = interpolate_between_camera_edges(edges_to_cameras_ls, n_steps=n_steps,
157 | offset=0.)
158 |
159 | # Create edges and nodes from new cameras
160 | new_edges_ls = []
161 | for k, (rotation, translation) in enumerate(zip(new_rotations, new_translations)):
162 | # Create new virtual camera
163 | new_cam = createCamera(copied_cam.H, copied_cam.W, copied_cam.intrinsics.f_x.cpu().numpy(), type=scene_dict['type'])
164 | new_nodes = scene.updateNodes(new_cam)
165 | new_cam_id = list(new_nodes['camera'].keys())[0]
166 | new_frame.images[new_cam_id] = img_path[k // n_steps]
167 | new_frame.camera_ids.append(new_cam_id)
168 |
169 | # Create new edge to the camera
170 | new_edge = copy.deepcopy(edges_to_cameras_ls[0])
171 | new_edge.translation = translation
172 | new_edge.rotation = rotation
173 | new_edge.child = new_cam_id
174 |
175 | new_edges_ls.append(new_edge)
176 |
177 | # Remove old cameras if requested
178 | # if remove_old_cameras:
179 | # for fr_id, ed_id in fr_edge_ids_to_cameras_ls:
180 | # new_frame.removeEdge(ed_id)
181 |
182 | # Add edges to the new camera poses to the graph
183 | new_frame.add_edges(new_edges_ls)
184 |
185 | # Add frame to the scene
186 | frame_list = [new_frame]
187 |
188 | scene_id = 1e4
189 | if not len(scene_dict) == 0:
190 | scene_id = scene_dict['scene_id']
191 |
192 | scene.updateFrames(frame_list, scene_id, scene_dict)
193 |
194 |
195 | def add_new_val_render_poses(frame_to_copy_from, camera_to_copy_from, edges_to_cameras_ls, img_path, scene, scene_dict, n_steps=3, remove_old_cameras=True):
196 | selected_frame = frame_to_copy_from
197 | copied_cam = camera_to_copy_from
198 | cam_ids = scene.nodes['camera'].keys()
199 |
200 | edge_ids_to_cameras_ls = selected_frame.get_edge_idx_by_child_idx(list(cam_ids))
201 | edge_ids_to_cameras_ls = [ed_id[0] for ed_id in edge_ids_to_cameras_ls]
202 | # Create a new frame as a copy of the selected frame
203 | new_frame = copy.deepcopy(selected_frame)
204 |
205 | # Remove image paths from unused cameras
206 | if remove_old_cameras:
207 | for c_id in cam_ids:
208 | del new_frame.images[c_id]
209 | new_frame.camera_ids.remove(c_id)
210 |
211 | # Get Camera poses
212 | new_rotations, new_translations = interpolate_between_camera_edges(edges_to_cameras_ls, n_steps=n_steps, )
213 |
214 | # Create edges and nodes from new cameras
215 | new_edges_ls = []
216 | for k, (rotation, translation) in enumerate(zip(new_rotations, new_translations)):
217 | # Create new virtual camera
218 | new_cam = createCamera(copied_cam.H, copied_cam.W, copied_cam.intrinsics.f_x.cpu().numpy(), )
219 | new_nodes = scene.updateNodes(new_cam)
220 | new_cam_id = list(new_nodes['camera'].keys())[0]
221 | new_frame.images[new_cam_id] = img_path
222 | new_frame.camera_ids.append(new_cam_id)
223 |
224 | # Create new edge to the camera
225 | new_edge = copy.deepcopy(edges_to_cameras_ls[0])
226 | new_edge.translation = translation
227 | new_edge.rotation = rotation
228 | new_edge.child = new_cam_id
229 |
230 | new_edges_ls.append(new_edge)
231 |
232 | # Remove old cameras if requested
233 | if remove_old_cameras:
234 | for ed_id in edge_ids_to_cameras_ls:
235 | new_frame.removeEdge(ed_id)
236 |
237 | # Add edges to the new camera poses to the graph
238 | new_frame.add_edges(new_edges_ls)
239 |
240 | # Add frame to the scene
241 | frame_list = [new_frame]
242 |
243 | scene_id = 1e4
244 | if not len(scene_dict) == 0:
245 | scene_id = scene_dict['scene_id']
246 |
247 | scene.updateFrames(frame_list, scene_id, scene_dict)
248 |
249 |
250 | def interpolate_between_camera_edges(edges_to_cameras_ls, n_steps=5, offset=0.):
251 | rots = [ed.rotation for ed in edges_to_cameras_ls]
252 | translations = [ed.translation for ed in edges_to_cameras_ls]
253 |
254 | steps = torch.linspace(0, 1, n_steps, device=rots[0].device)
255 |
256 | new_quat_rots = []
257 | new_translations = []
258 | for cam_pair_i in range(len(rots) - 1):
259 | # Interpolation between translations
260 | translation_0 = translations[cam_pair_i] + torch.matmul(so3_exponential_map(rots[cam_pair_i]),
261 | torch.tensor([1., 0., 0.]) * offset)
262 | translation_1 = translations[cam_pair_i + 1] + torch.matmul(so3_exponential_map(rots[cam_pair_i + 1]),
263 | torch.tensor([1., 0., 0.]) * offset)
264 | mid_translations = translation_0 * (1 - steps[:, None]) + \
265 | translation_1 * (steps[:, None])
266 |
267 | for i in range(n_steps - 1):
268 | new_translations.append(mid_translations[i, None])
269 |
270 | # Interpolation between rotations with quaternion SLERP
271 | quat_rot_0 = axis_angle_to_quaternion(rots[cam_pair_i])
272 | quat_rot_1 = axis_angle_to_quaternion(rots[cam_pair_i + 1])
273 | cosHalfTheta = torch.sum(quat_rot_0 * quat_rot_1)
274 | halfTheta = torch.acos(cosHalfTheta)
275 | sinHalfTheta = torch.sqrt(1.0 - cosHalfTheta * cosHalfTheta)
276 |
277 | if (torch.abs(sinHalfTheta) < 0.001):
278 | # theta is 180 degree --> Rotation around different axis is possible
279 | for i in range(n_steps - 1):
280 | new_quat_rots.append(quat_rot_0)
281 | elif (torch.abs(cosHalfTheta) >= 1.0):
282 | # 0 degree difference --> No new rotation necessary
283 | mid_quat = quat_rot_0 * (1. - steps)[:, None] + quat_rot_1 * (steps)[:, None]
284 | for i in range(n_steps - 1):
285 | new_quat_rots.append(mid_quat[i, None])
286 | else:
287 | ratioA = torch.sin((1 - steps) * halfTheta) / sinHalfTheta
288 | ratioB = torch.sin(steps * halfTheta) / sinHalfTheta
289 | mid_quat = quat_rot_0 * ratioA[:, None] + quat_rot_1 * ratioB[:, None]
290 | for i in range(n_steps - 1):
291 | new_quat_rots.append(mid_quat[i, None])
292 |
293 | new_quat_rots.append(quat_rot_1)
294 | new_translations.append(translation_1)
295 |
296 | new_rotations = [quaternion_to_axis_angle(quaternion) for quaternion in new_quat_rots]
297 |
298 | return new_rotations, new_translations
299 |
300 |
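# A minimal, self-contained quaternion SLERP sketch mirroring the camera interpolation in
# interpolate_between_camera_edges above. `slerp` is not part of the original code base;
# it assumes unit quaternions in the (w, x, y, z) convention used by pytorch3d and falls
# back to linear interpolation when the two rotations are nearly identical.
def slerp(q0: torch.Tensor, q1: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    cos_half_theta = torch.sum(q0 * q1)
    if torch.abs(cos_half_theta) > 1.0 - 1e-6:
        # Rotations are (nearly) identical: a linear blend is sufficient.
        q = q0 * (1.0 - t)[:, None] + q1 * t[:, None]
    else:
        half_theta = torch.acos(cos_half_theta)
        sin_half_theta = torch.sin(half_theta)
        ratio_a = torch.sin((1.0 - t) * half_theta) / sin_half_theta
        ratio_b = torch.sin(t * half_theta) / sin_half_theta
        q = q0 * ratio_a[:, None] + q1 * ratio_b[:, None]
    return q / torch.norm(q, dim=-1, keepdim=True)

# Usage sketch: five poses between the identity and a 90 degree rotation about the z-axis.
# q0 = axis_angle_to_quaternion(torch.tensor([0.0, 0.0, 0.0]))
# q1 = axis_angle_to_quaternion(torch.tensor([0.0, 0.0, np.pi / 2]))
# slerp(q0, q1, torch.linspace(0.0, 1.0, 5))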
301 | def output_gif(scene, ind_list, savedir_images, tgt_fname):
302 | tgt_path = os.path.join(savedir_images, (tgt_fname + '.gif'))
303 |
304 | iname_ls = [f'frame_{scene.frames_cameras[i][0]}_camera_{scene.frames_cameras[i][1]}' for i in ind_list]
305 |
306 | image_path_ls = [os.path.join(savedir_images, f"{iname}.png") for iname in iname_ls]
307 |
308 | img, *imgs = [Image.open(f) for f in image_path_ls]
309 | img.save(
310 | fp=tgt_path,
311 | format="GIF",
312 | append_images=imgs,
313 | save_all=True,
314 | duration=3000 // len(image_path_ls),
315 | loop=0,
316 | )
--------------------------------------------------------------------------------
/src/scenes/raysampler/rayintersection.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from pytorch3d.renderer.implicit.utils import RayBundle
5 | from pytorch3d.transforms import Transform3d
6 | from pytorch3d.transforms.so3 import so3_log_map, so3_exponential_map
7 |
8 |
9 | def join_list(x):
10 | return [inner for outer in x for inner in outer]
11 |
12 |
13 | class RayIntersection(nn.Module):
14 | """
15 | Module to compute Ray intersections
16 | """
17 |
18 | def __init__(self,
19 | n_samples: int,
20 | chunk_size: int):
21 | super().__init__()
22 | self._n_samples = n_samples
23 | self._chunk_size = int(chunk_size)
24 |
25 | def forward(self,
26 | ray_bundle: RayBundle,
27 | scene,
28 | intersections: torch.Tensor = None,
29 | **kwargs
30 | ):
31 | """
32 |         Takes rays [ray_origin, ray_direction, frame_id], nodes to intersect and a scene as input.
33 |         Outputs intersection points in ray coordinates, [z_val, node_scene_id, frame_id].
34 | """
35 |         # TODO: Only compute the required intersections with planes/boxes/etc. at test time
36 |
37 | # return intersection_pts
38 | pass
39 |
40 |
41 | class RayBoxIntersection(RayIntersection):
42 | def __init__(self,
43 | box_nodes: dict,
44 | chunk_size: int,
45 | n_intersections_box: int=2):
46 |
47 | super().__init__(n_samples=n_intersections_box,
48 | chunk_size=chunk_size)
49 |
50 | self._box_nodes = [node.scene_idx for node in box_nodes.values()]
51 |
52 | def forward(self,
53 | origins: torch.Tensor,
54 | directions: torch.Tensor,
55 | box_bounds: torch.Tensor = torch.tensor([[-1, -1, -1], [1, 1, 1]]),
56 | transformation: Transform3d = None,
57 | output_points: bool = False,
58 | **kwargs,
59 | ):
60 |
61 | if transformation is not None:
62 | origins = transformation.transform_points(origins)
63 | directions = transformation.transform_points(directions)
64 |
65 | if directions.dim() == 2:
66 | n_batch = 1
67 | n_rays_batch = directions.shape[0]
68 | elif directions.dim() == 3:
69 | ray_d_sh = directions.shape
70 | n_batch = ray_d_sh[0]
71 | n_rays_batch = ray_d_sh[1]
72 | directions = directions.flatten(0, -2)
73 | origins = origins.flatten(0, -2)
74 | else:
75 |             raise ValueError("Ray directions are of dimension {}, but must be of dimension 2 or 3.".format(directions.dim()))
76 |
77 | if origins.dim() < directions.dim():
78 | origins = origins.expand(n_batch, n_rays_batch, 3)
79 | origins = origins.flatten(0, -2)
80 | else:
81 |             if origins.shape != directions.shape:
82 |                 raise ValueError("Ray origins and directions must have the same shape.")
83 |
84 | box_bounds = box_bounds.to(origins.device)
85 | # Perform ray-aabb intersection and return in and out without chunks
86 | inv_d = torch.reciprocal(directions)
87 |
88 | t_min = (box_bounds[0] - origins) * inv_d
89 | t_max = (box_bounds[1] - origins) * inv_d
90 |
91 | t0 = torch.minimum(t_min, t_max)
92 | t1 = torch.maximum(t_min, t_max)
93 | t_near = torch.maximum(torch.maximum(t0[..., 0], t0[..., 1]), t0[..., 2])
94 | t_far = torch.minimum(torch.minimum(t1[..., 0], t1[..., 1]), t1[..., 2])
95 | # Check if rays are inside boxes
96 | intersection_idx = torch.where(t_far > t_near)
97 | # Check that boxes are in front of the ray origin
98 | intersection_idx = intersection_idx[0][t_far[intersection_idx] > 0]
99 | if not len(intersection_idx) == 0:
100 | z_in = t_near[intersection_idx]
101 | z_out = t_far[intersection_idx]
102 |
103 | # Reindex for [n_batches, ....] and sort again
104 | batch_idx = torch.floor(intersection_idx / n_rays_batch).to(torch.int64)
105 | intersection_idx = intersection_idx % n_rays_batch
106 | # intersection_idx, new_sort = torch.sort(intersection_idx)
107 | intersection_mask = tuple([batch_idx, intersection_idx])
108 |
109 | if not output_points:
110 | return [z_in, z_out], intersection_mask
111 | else:
112 | pts_in = origins.view(-1, n_rays_batch, 3)[intersection_idx] + \
113 | directions.view(-1, n_rays_batch, 3)[intersection_idx] * z_in[:, None]
114 | pts_out = origins.view(-1, n_rays_batch, 3)[intersection_idx] + \
115 | directions.view(-1, n_rays_batch, 3)[intersection_idx] * z_out[:, None]
116 |
117 | return [z_in, z_out, pts_in, pts_out], intersection_mask
118 |
119 | else:
120 | if output_points:
121 | return [None, None, None, None], None
122 | else:
123 | return [None, None], None
124 |
125 |
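# A standalone version of the slab test used in RayBoxIntersection.forward above
# (illustrative sketch; `ray_aabb_entry_exit` is not part of the original module). It
# intersects N rays with a single axis-aligned box and returns entry/exit distances,
# with NaN where a ray misses the box or the box lies behind the ray origin.
def ray_aabb_entry_exit(origins: torch.Tensor,
                        directions: torch.Tensor,
                        box_min: torch.Tensor,
                        box_max: torch.Tensor):
    inv_d = torch.reciprocal(directions)            # [N, 3]
    t_lo = (box_min - origins) * inv_d
    t_hi = (box_max - origins) * inv_d
    t0 = torch.minimum(t_lo, t_hi)
    t1 = torch.maximum(t_lo, t_hi)
    t_near = t0.max(dim=-1).values                  # latest entry over the three slabs
    t_far = t1.min(dim=-1).values                   # earliest exit over the three slabs
    hit = (t_far > t_near) & (t_far > 0)            # box must be in front of the origin
    nan = torch.full_like(t_near, float("nan"))
    return torch.where(hit, t_near, nan), torch.where(hit, t_far, nan)

# Usage sketch: a ray along +x starting at x = -3 enters the unit box at t = 2, exits at t = 4.
# ray_aabb_entry_exit(torch.tensor([[-3.0, 0.0, 0.0]]), torch.tensor([[1.0, 0.0, 0.0]]),
#                     torch.tensor([-1.0, -1.0, -1.0]), torch.tensor([1.0, 1.0, 1.0]))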
126 | class RaySphereIntersection(RayIntersection):
127 | def __init__(self,
128 | n_samples_box: int,
129 | chunk_size: int,):
130 |         super().__init__(n_samples=n_samples_box, chunk_size=chunk_size)
131 |
132 | def forward(self,
133 | ray_bundle: RayBundle,
134 | scene,
135 | intersections: torch.Tensor = None,
136 | **kwargs):
137 | intersection_pts = []
138 | return intersection_pts
139 |
140 |
141 | class RayPlaneIntersection(RayIntersection):
142 | def __init__(self,
143 | n_planes: int,
144 | near: float,
145 | far: float,
146 | chunk_size: int,
147 | camera_poses: list,
148 | background_trafos: torch.Tensor,
149 | transient_background: bool=False):
150 | super().__init__(n_samples=n_planes,
151 | chunk_size=chunk_size)
152 |
153 | self._planes_n = {}
154 | self._planes_p = {}
155 | self._plane_delta = {}
156 | self._near = {}
157 | self._far = {}
158 | self._transient_background = False
159 |
160 | for key, val in background_trafos.items():
161 | if self._transient_background:
162 | self._near[key] = torch.as_tensor(near[key])
163 | self._far[key] = torch.as_tensor(far[key])
164 |
165 | else:
166 | global_trafo = val
167 | near_k = torch.as_tensor(near[key])
168 | far_k = torch.as_tensor(far[key])
169 |
170 | self._planes_n[key] = global_trafo[:3, 2]
171 |
172 | all_camera_poses = []
173 | for edge_dict in camera_poses.values():
174 | for edge in edge_dict.values():
175 | if len(edge) == 0:
176 | continue
177 | if edge[0].parent == key:
178 | all_camera_poses.append(edge[0].translation)
179 |
180 | all_camera_poses = torch.cat(all_camera_poses)
181 | n_cameras = 2
182 | # len(camera_poses)
183 | # assert n_cameras == 2
184 | n_cam_poses = len(all_camera_poses)
185 |
186 | max_pose_dist = torch.norm(all_camera_poses[-1] - all_camera_poses[0])
187 |
188 | # TODO: shouldn't this be int(n_cam_poses / n_cameras) +1
189 | end = int(n_cam_poses / n_cameras) + 1
190 | pose_dist = (all_camera_poses[1:end] - all_camera_poses[:end-1])
191 | pose_dist = torch.norm(pose_dist, dim=1).max()
192 | planes_p = global_trafo[:3, -1] + near_k * self._planes_n[key]
193 |
194 | self._plane_delta[key] = (far_k - near_k) / (self._n_samples - 1)
195 |
196 | poses_per_plane = int(((far_k - near_k) / self._n_samples) / pose_dist)
197 | if poses_per_plane != 0:
198 | add_planes = int(np.ceil((n_cam_poses/n_cameras) / poses_per_plane))
199 | else:
200 | add_planes = 1
201 | id_planes = torch.linspace(0, self._n_samples + add_planes - 1, self._n_samples+ add_planes, dtype=int)
202 |
203 | self._planes_p[key] = planes_p + (id_planes * self._plane_delta[key])[:, None] * self._planes_n[key]
204 | far_k = near_k + self._plane_delta[key] * (id_planes[-1] + add_planes)
205 |
206 | self._near[key] = torch.as_tensor(near_k)
207 | self._far[key] = torch.as_tensor(far_k)
208 |
209 | def forward(self,
210 | ray_bundle: RayBundle,
211 | scene,
212 | intersections: torch.Tensor = None,
213 | obj_obly: bool=False,
214 | **kwargs):
215 | """ Ray-Plane intersection for given planes in the scenes
216 |
217 | Args:
218 | rays: ray origin and directions
219 | planes: first plane position, plane normal and distance between planes
220 | id_planes: ids of used planes
221 | near: distance between camera pose and first intersecting plane
222 |
223 | Returns:
224 | pts: [N_rays, N_samples+N_importance] - intersection points of rays and selected planes
225 | z_vals: integration step along each ray for the respective points
226 | """
227 | if not obj_obly and not self._transient_background:
228 | node_id = 0
229 |
230 | # TODO: Compare with outputs from old method
231 | # Extract ray and plane definitions
232 | device = ray_bundle.origins.device
233 | N_rays = np.prod(ray_bundle.origins.shape[:-1])
234 |
235 | # Get amount of all planes
236 | rays_sh = list(ray_bundle.origins.shape)
237 | xys_sh = list(ray_bundle.xys.shape)
238 |
239 | # Flatten ray origins and directions
240 | all_origs = ray_bundle.origins.flatten(0, -2)
241 | all_dirs = ray_bundle.directions.flatten(0, -2)
242 | xycfn = ray_bundle.xys.flatten(0,-2)
243 |
244 | if len(xycfn) != len(all_origs):
245 |                 print("Warning: please check that the global sampler is executed before the local sampler!")
246 |
247 |             # TODO: Initialize planes with the right dtype (float32)
248 | # TODO: Check run time for multiple scene implementation
249 | d_origin_planes = torch.zeros([self._n_samples, len(xycfn)], device=device)
250 |
251 | for n_idx in xycfn[:,-1].unique():
252 | n_i_mask = torch.where(xycfn[:, -1] == n_idx)
253 | background_mask = [
254 | torch.linspace(0, self._n_samples - 1, self._n_samples, dtype=torch.int64)[:, None].repeat(1,len(n_i_mask[0])),
255 | n_i_mask[0][None].repeat(self._n_samples, 1)]
256 |
257 | n_idx = int(n_idx)
258 | origs = all_origs[n_i_mask]
259 | dirs = all_dirs[n_i_mask]
260 |
261 | p = self._planes_p[n_idx].to(device=origs.device, dtype=origs.dtype)
262 | n = self._planes_n[n_idx].to(device=origs.device, dtype=origs.dtype)
263 | near = self._near[n_idx].to(device=origs.device, dtype=origs.dtype)
264 | delta = self._plane_delta[n_idx]
265 |
266 | if len(p) > self._n_samples:
267 | # import matplotlib.pyplot as plt
268 | #
269 | # plt.scatter(origs[:, 0], origs[:, 2])
270 | # plt.scatter(p[:, 0], p[:, 2])
271 | # plt.axis('equal')
272 | # Just get the intersections with self._n_planes - 1 planes in front of the camera and the last plane
273 | d_p0_orig = torch.matmul(p[0] - origs - 1e-3, n)
274 | d_p0_orig = torch.maximum(-d_p0_orig, -near)
275 | start_idx = torch.ceil((d_p0_orig + near) / delta).to(dtype=torch.int64)
276 | plane_idx = start_idx + torch.linspace(0, self._n_samples - 2, self._n_samples - 1, dtype=int, device=near.device)[:, None]
277 | plane_idx = torch.cat([plane_idx, torch.full([1, len(origs)], len(p) - 1, device=near.device)])
278 | p = p[plane_idx]
279 | else:
280 | p = p[:, None, :]
281 |
282 | d_origin_planes_i = p - origs
283 | d_origin_planes_i = torch.matmul(d_origin_planes_i, n)
284 | d_origin_planes_i = d_origin_planes_i / torch.matmul(dirs, n)
285 |                 # TODO: Add a validity check here (all plane distances should be positive)
286 | d_origin_planes[background_mask] = d_origin_planes_i
287 |
288 | rays_sh.insert(-1, self._n_samples)
289 | lengths = d_origin_planes.transpose(1, 0).reshape(rays_sh[:-1])
290 | else:
291 | if not self._transient_background:
292 | device = ray_bundle.lengths.device
293 | near = min(list(self._near.values()))
294 | far = max(list(self._far.values()))
295 | z = torch.linspace(near, far, self._n_samples)
296 | lengths = torch.ones(ray_bundle.lengths.shape[:-1])[..., None].repeat(1, 1, self._n_samples).to(device)
297 | lengths *= z.to(device)
298 | else:
299 | device = ray_bundle.lengths.device
300 | far = max(list(self._far.values()))
301 | lengths = torch.ones(ray_bundle.lengths.shape[:-1])[..., None].repeat(1, 1, self._n_samples).to(
302 | device)
303 | lengths *= torch.tensor([far]).to(device)
304 |
305 | return lengths
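# A minimal standalone ray-plane intersection sketch matching the computation in
# RayPlaneIntersection.forward above (illustrative only; `ray_plane_distance` is not part
# of the original module). For a plane through point p with normal n, a ray o + t * d
# intersects it at t = ((p - o) . n) / (d . n); the forward pass evaluates this for every
# sampling plane and every ray of the background node.
def ray_plane_distance(origins: torch.Tensor,
                       directions: torch.Tensor,
                       plane_point: torch.Tensor,
                       plane_normal: torch.Tensor) -> torch.Tensor:
    # origins, directions: [N, 3]; plane_point, plane_normal: [3]
    numer = torch.matmul(plane_point - origins, plane_normal)   # [N]
    denom = torch.matmul(directions, plane_normal)              # [N]
    return numer / denom          # negative values mean the plane lies behind the ray

# Usage sketch: a ray along +z from the origin hits the plane z = 5 at t = 5.
# ray_plane_distance(torch.zeros(1, 3), torch.tensor([[0.0, 0.0, 1.0]]),
#                    torch.tensor([0.0, 0.0, 5.0]), torch.tensor([0.0, 0.0, 1.0]))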
--------------------------------------------------------------------------------
/src/datasets/waymo_od.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import imageio
4 | import numpy as np
5 | import random
6 | from matplotlib import pyplot as plt
7 | from PIL import Image
8 | import tensorflow as tf
9 |
10 | # TODO: Make the Waymo data import work with neural scene graphs
11 | # from waymo_open_dataset.utils import frame_utils
12 | # from waymo_open_dataset import dataset_pb2 as open_dataset
13 | # from waymo_open_dataset.utils.transform_utils import *
14 | camera_names = {1: 'FRONT',
15 | 2: 'FRONT LEFT',
16 | 3: 'FRONT RIGHT',
17 | 4: 'SIDE LEFT',
18 | 5: 'SIDE RIGHT',}
19 |
20 | cameras = [1,2,3] # [1, 2, 3]
21 |
22 |
23 | def get_scene_objects(dataset, speed_thresh):
24 | waymo_obj_meta = {}
25 | max_n_obj = 0
26 |
27 | for i, data in enumerate(dataset):
28 | frame = open_dataset.Frame()
29 | frame.ParseFromString(bytearray(data.numpy()))
30 | types = [frame.laser_labels[0].TYPE_PEDESTRIAN,
31 | frame.laser_labels[0].TYPE_VEHICLE,
32 | frame.laser_labels[0].TYPE_CYCLIST]
33 |
34 | n_obj_frame = 0
35 | for laser_label in frame.laser_labels:
36 | if laser_label.type in types:
37 | id = laser_label.id
38 | # length, height, width
39 | dim = np.array([laser_label.box.length, laser_label.box.height, laser_label.box.width])
40 | speed = np.sqrt(np.sum([laser_label.metadata.speed_x ** 2, laser_label.metadata.speed_y ** 2]))
41 | acc_x = laser_label.metadata.accel_x
42 |
43 | if speed > speed_thresh:
44 | n_obj_frame += 1
45 | if id not in waymo_obj_meta:
46 | internal_track_id = len(waymo_obj_meta) + 1
47 | meta_obj = [id, internal_track_id, laser_label.type, dim]
48 | waymo_obj_meta[id] = meta_obj
49 | else:
50 |                     if np.abs(waymo_obj_meta[id][3] - dim).sum() > 1e-10:
51 |                         print('Dimension mismatch for same object!')
52 |                         print(id)
53 |                         print(np.abs(waymo_obj_meta[id][3] - dim).sum())
54 |
55 | if n_obj_frame > max_n_obj:
56 | max_n_obj = n_obj_frame
57 |
58 | return waymo_obj_meta, max_n_obj
59 |
60 |
61 | def get_frame_objects(laser_labels, v2w_frame_i, waymo_obj_meta, speed_thresh):
62 | # Get all objects from a frame
63 | frame_obj_dict = {}
64 |
65 | # TODO: Add Cyclists and pedestrians (like for metadata)
66 | types = [laser_labels[0].TYPE_VEHICLE]
67 |
68 | for label in laser_labels:
69 | if label.type in types:
70 | id = label.id
71 |
72 | if id in waymo_obj_meta:
73 | # TODO: CHECK vkitti/nerf x, y, z
74 | waymo2vkitti_vehicle = np.array([[1., 0., 0., 0.],
75 | [0., 0., -1., 0.],
76 | [0., 1., 0., 0.],
77 | [0., 0., 0., 1.]])
78 |
79 | x_v = label.box.center_x
80 | y_v = label.box.center_y
81 | z_v = label.box.center_z
82 | yaw_obj = np.array(label.box.heading)
83 |
84 | R_obj2v = get_yaw_rotation(yaw_obj)
85 | t_obj_v = tf.constant([x_v, y_v, z_v])
86 | transform_obj2v = get_transform(tf.cast(R_obj2v, tf.double), tf.cast(t_obj_v, tf.double))
87 | transform_obj2w = np.matmul(v2w_frame_i, transform_obj2v)
88 | R = transform_obj2w[:3, :3]
89 |
90 | yaw_aprox = np.arctan2(-R[2, 0], R[0, 0]) # np.arctan2(R[2, 0], R[0, 0]) - np.arctan2(0, 1)
91 | if yaw_aprox > np.pi:
92 | yaw_aprox -= 2*np.pi
93 |                 elif yaw_aprox < -np.pi:
94 |                     yaw_aprox += 2*np.pi
95 |
96 | yaw_aprox_o = np.arccos(transform_obj2w[0, 0])
97 | if np.absolute(np.rad2deg(yaw_aprox - yaw_aprox_o)) > 1e-2:
98 |                     a = 0  # Debug placeholder: the two yaw estimates disagree by more than 0.01 degrees
99 |
100 | # yaw_aprox = yaw_aprox_o
101 |
102 | speed = np.sqrt(np.sum([label.metadata.speed_x ** 2, label.metadata.speed_y ** 2]))
103 | is_moving = 1. if speed > speed_thresh else 0.
104 |
105 | obj_prop = np.array(
106 | [transform_obj2w[0, 3], transform_obj2w[1, 3], transform_obj2w[2, 3], yaw_aprox, 0, 0, is_moving])
107 | frame_obj_dict[id] = obj_prop
108 |
109 | return frame_obj_dict
110 |
111 |
112 | def get_camera_pose(v2w_frame_i, calibration):
113 | # FROM Waymo OD documentation:
114 | # "Each sensor comes with an extrinsic transform that defines the transform from the
115 | # sensor frame to the vehicle frame.
116 | #
117 | # The camera frame is placed in the center of the camera lens.
118 | # The x-axis points down the lens barrel out of the lens.
119 | # The z-axis points up. The y/z plane is parallel to the camera plane.
120 | # The coordinate system is right handed."
121 |
122 |     # Match OpenGL camera axes to the Waymo camera frame: x --> -y, y --> z, z --> -x
123 | opengl2camera = np.array([[0., 0., -1., 0.],
124 | [-1., 0., 0., 0.],
125 | [0., 1., 0., 0.],
126 | [0., 0., 0., 1.]])
127 | extrinsic_transform_c2v = np.reshape(calibration.extrinsic.transform, [4, 4])
128 | extrinsic_transform_c2v = np.matmul(extrinsic_transform_c2v, opengl2camera)
129 | c2w_frame_i_cam_c = np.matmul(v2w_frame_i, extrinsic_transform_c2v)
130 |
131 |
132 | return c2w_frame_i_cam_c
133 |
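# Numeric sanity check for the axis remapping in get_camera_pose above (illustrative
# sketch, not part of the original module): the OpenGL-style forward direction (0, 0, -1)
# should land on the Waymo camera's forward axis +x, and OpenGL up (0, 1, 0) on its up
# axis +z.
#
# opengl2camera = np.array([[0., 0., -1.],
#                           [-1., 0., 0.],
#                           [0., 1., 0.]])
# opengl2camera @ np.array([0., 0., -1.])   # -> [1., 0., 0.]  (forward)
# opengl2camera @ np.array([0., 1., 0.])    # -> [0., 0., 1.]  (up)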
134 | def get_bbox_2d(label_2d):
135 | center = [label_2d.box.center_x, label_2d.box.center_y]
136 | dim_box = [label_2d.box.length, label_2d.box.width]
137 |
138 | left = np.ceil(center[0] - dim_box[0] * 0.5)
139 | right = np.ceil(center[0] + dim_box[0] * 0.5)
140 | top = np.ceil(center[1] + dim_box[1] * 0.5)
141 | bottom = np.ceil(center[1] - dim_box[1] * 0.5)
142 |
143 | return np.array([left, right, bottom, top])[None, :]
144 |
145 |
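# Worked example for get_bbox_2d above (illustrative, not part of the original module):
# a projected label centred at (100, 50) with length 40 (extent in image x) and width 20
# (extent in image y) yields [left, right, bottom, top] = [80., 120., 40., 60.].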
146 | def load_waymo_od_data(basedir, selected_frames, max_frames=5, use_obj=True, row_id=False):
147 | """
148 | :param basedir: Path to segment tfrecord
149 |     :param max_frames: maximum number of frames to load
150 |     :param use_obj: whether to return dynamic object annotations for the render poses
151 |     :param row_id:
152 |     :return: images, poses, render_poses, [H, W, focal], i_split and object annotations
153 | """
154 | if selected_frames == -1:
155 | start_frame = 0
156 | end_frame = 0
157 | else:
158 | start_frame = selected_frames[0]
159 | end_frame = selected_frames[1]
160 |
161 | print('Scene Representation from cameras:')
162 | for cam in cameras:
163 | print(camera_names[cam], ',')
164 |
165 | speed_thresh = 5.1
166 |
167 |
168 | dataset = tf.data.TFRecordDataset(basedir, compression_type='')
169 |
170 | frames = []
171 |
172 | print('Extracting all moving objects!')
173 | # Extract all moving objects
174 |     # waymo_obj_meta: waymo_id -> [waymo_id, internal_track_id, object_type, [length, height, width]]
175 | # max_n_obj: maximum number of objects in a single frame
176 | waymo_obj_meta, max_n_obj = get_scene_objects(dataset, speed_thresh)
177 |
178 | # All images from cameras specified at the beginning
179 | images = []
180 |
181 | # Pose of each images camera
182 | poses = []
183 |
184 | # 2D bounding boxes
185 | bboxes = []
186 |
187 | # Frame Number, Camera Name, object_id, xyz, angle, ismoving
188 | visible_objects = []
189 | max_frame_obj = 0
190 |
191 | count = []
192 | # Extract all frames from a single tf_record
193 | for i, data in enumerate(dataset):
194 | if start_frame <= i <= end_frame:
195 | frame = open_dataset.Frame()
196 | frame.ParseFromString(bytearray(data.numpy()))
197 |
198 | # FROM Waymo OD documentation:
199 | # Global Frame/ World Frame
200 | # The origin of this frame is set to the vehicle position when the vehicle starts.
201 | # It is an ‘East-North-Up’ coordinate frame. ‘Up(z)’ is aligned with the gravity vector,
202 | # positive upwards. ‘East(x)’ points directly east along the line of latitude. ‘North(y)’
203 | # points towards the north pole.
204 |
205 |             # Match vkitti: waymo global to vkitti global y --> z, z --> -y
206 | waymo2vkitti_world = np.array([[1., 0., 0., 0.],
207 | [0., 0., -1., 0.],
208 | [0., 1., 0., 0.],
209 | [0., 0., 0., 1.]])
210 |
211 | # Vehicle Frame
212 | # The x-axis is positive forwards, y-axis is positive to the left, z-axis is positive upwards.
213 | # A vehicle pose defines the transform from the vehicle frame to the global frame.
214 |
215 |
216 | # Vehicle to global frame transformation for this frame
217 | v2w_frame_i = np.reshape(frame.pose.transform, [4, 4])
218 | v2w_frame_i = np.matmul(waymo2vkitti_world, v2w_frame_i)
219 |
220 |             # Get all objects in this frame
221 | frame_obj_dict = get_frame_objects(frame.laser_labels, v2w_frame_i, waymo_obj_meta, speed_thresh)
222 |
223 | # Loop over all camera images and visible objects per camera
224 | for camera_image in frame.images:
225 | for projected_lidar_labels in frame.projected_lidar_labels:
226 | if projected_lidar_labels.name != camera_image.name or projected_lidar_labels.name not in cameras:
227 | continue
228 |
229 | for calibration in frame.context.camera_calibrations:
230 | if calibration.name != camera_image.name:
231 | continue
232 |
233 | cam_no = np.array(camera_image.name).astype(np.float32)[None]
234 | frame_no = np.array(i).astype(np.float32)[None]
235 | count.append(len(images))
236 |
237 | # Extract images and camera pose
238 | images.append(np.array(tf.image.decode_jpeg(camera_image.image)))
239 | extrinsic_transform_c2w = get_camera_pose(v2w_frame_i, calibration)
240 | poses.append(extrinsic_transform_c2w)
241 |
242 | # Extract dynamic objects for image
243 | image_objects = np.ones([max_n_obj, 14]) * -1.
244 | images_boxes = []
245 | i_obj = 0
246 | for label_2d in projected_lidar_labels.labels:
247 | track_id = label_2d.id[:22]
248 |
249 | # Only add objects with 3D information/dynamic objects
250 | if track_id in frame_obj_dict:
251 | pose_3d = frame_obj_dict[track_id]
252 | dim = np.array(waymo_obj_meta[track_id][3]).astype(np.float32)
253 | # Move vehicle reference point to bottom of the box like vkitti
254 | pose_3d[1] = pose_3d[1] + (dim[1] / 2)
255 |
256 | internal_track_id = np.array(waymo_obj_meta[track_id][1]).astype(np.float32)[None]
257 | obj_type = np.array(waymo_obj_meta[track_id][2]).astype(np.float32)[None]
258 |
259 | obj = np.concatenate([frame_no, cam_no, internal_track_id, obj_type, dim, pose_3d])
260 |
261 | image_objects[i_obj, :] = obj
262 | i_obj += 1
263 |
264 | # Extract 2D bounding box for training
265 | bbox_2d = get_bbox_2d(label_2d)
266 | images_boxes.append(bbox_2d)
267 |
268 | if i_obj > max_frame_obj:
269 | max_frame_obj = i_obj
270 |
271 | bboxes.append(images_boxes)
272 | visible_objects.append(np.array(image_objects))
273 |
274 | if len(frames) >= max_frames:
275 | break
276 |
277 | if max_frame_obj > 0:
278 | visible_objects = np.array(visible_objects)[:, :max_frame_obj, :]
279 | else:
280 | print(max_frame_obj)
281 | print(visible_objects)
282 | visible_objects = np.array(visible_objects)[:, None, :]
283 | poses = np.array(poses)
284 | bboxes = np.array(bboxes)
285 | images = (np.maximum(np.minimum(np.array(images), 255), 0) / 255.).astype(np.float32)
286 |
287 |
288 | focal = np.reshape(frame.context.camera_calibrations[0].intrinsic, [9])[0]
289 | H = frame.context.camera_calibrations[0].height
290 | W = frame.context.camera_calibrations[0].width
291 |
292 |
293 | i_split = [np.sort(count[:]),
294 | count[int(0.8 * len(count)):],
295 | count[int(0.8 * len(count)):]]
296 |
297 | novel_view = 'left'
298 | n_oneside = int(poses.shape[0]/2)
299 |
300 | render_poses = poses[:1]
301 | # Novel view middle between both cameras:
302 | if novel_view == 'mid':
303 | new_poses_o = ((poses[n_oneside:, :, -1] - poses[:n_oneside, :, -1]) / 2) + poses[:n_oneside, :, -1]
304 | new_poses = np.concatenate([poses[:n_oneside, :, :-1], new_poses_o[...,None]], axis=2)
305 | render_poses = new_poses
306 |
307 | elif novel_view == 'left':
308 | # Render at trained left camera pose
309 | render_poses = poses[:n_oneside, ...]
310 | elif novel_view == 'right':
311 |         # Render at trained right camera pose
312 | render_poses = poses[n_oneside:, ...]
313 |
314 | if use_obj:
315 | render_objects = visible_objects[:n_oneside, ...]
316 | else:
317 | render_objects = None
318 |
319 | # Create meta file matching vkitti2 meta data
320 | objects_meta = {}
321 | for meta_value in waymo_obj_meta.values():
322 | objects_meta[meta_value[1]] = np.concatenate([np.array(meta_value[1])[None],
323 | meta_value[3],
324 | np.array([meta_value[2]]) ])
325 |
326 | half_res = True
327 | if half_res:
328 | print('Using half resolution!!!')
329 | H = H // 2
330 | W = W // 2
331 | focal = focal / 2.
332 | images = tf.image.resize_area(images, [H, W]).numpy()
333 |
334 | for frame_boxes in bboxes:
335 | for i_box, box in enumerate(frame_boxes):
336 | frame_boxes[i_box] = box // 2
337 |
338 |
339 | return images, poses, render_poses, [H, W, focal], i_split, visible_objects, objects_meta, render_objects, bboxes, waymo_obj_meta
--------------------------------------------------------------------------------
/src/pointLF/attention_modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from src.pointLF.feature_mapping import PositionalEncoding
5 | from src.pointLF.layer import DenseLayer
6 |
7 |
8 | class FeatureDistanceEncoder(nn.Module):
9 | def __init__(self, feat_dim_in=128, W=256, D=1, n_frequ_pos_encoding=4, key_len=16, val_len=None, use_distance=True, use_projected=False, use_pitch=True, use_azimuth=True, azimuth_2d=False, no_feat=False):
10 | '''
11 |         feat_dim_in: dimension of the incoming point features
12 |         W: width of the hidden dense layers
13 |         D: number of dense layers
14 |         n_frequ_pos_encoding: number of frequencies for the positional encoding
15 |         key_len: dimension of the output key vectors
16 |         use_distance: encode the point-to-ray distance
17 |         use_projected: encode the projected distance along the ray
18 | '''
19 | super(FeatureDistanceEncoder, self).__init__()
20 | use_feat = not no_feat
21 |
22 | self.poseEncodingLen = PositionalEncoding(n_frequ_pos_encoding, 1)
23 | self.poseEncodingAng = PositionalEncoding(n_frequ_pos_encoding, 1, include_input=False)
24 |
25 | self.use_distance = use_distance
26 | self.use_projected = use_projected
27 | self.use_pitch = use_pitch
28 | self.use_azimuth = use_azimuth
29 | self.az_2d = azimuth_2d
30 | self.use_feat = use_feat
31 |
32 | self.used_distances = {
33 | 'distance': use_distance,
34 | 'projected_distance': use_projected,
35 | 'pitch': use_pitch,
36 | 'azimuth': use_azimuth,
37 | }
38 |
39 | self.enc_dist_dim = (n_frequ_pos_encoding * 2) * (use_distance + use_projected + use_pitch + use_azimuth + azimuth_2d) \
40 | + 1 * (use_distance + use_projected + 2 * azimuth_2d)
41 |
42 | self.input_ch = feat_dim_in * use_feat + self.enc_dist_dim
43 |
44 | self.dim_feat = feat_dim_in * use_feat
45 |
46 | self.linear = nn.ModuleList(
47 | [DenseLayer(self.input_ch, W)] +
48 | [DenseLayer(W, W,) for i in range(D - 1)])
49 |
50 |
51 | if val_len is None:
52 | self.val_out_ch = W - key_len
53 | else:
54 | self.val_out_ch = val_len
55 |
56 | self.value_linear = DenseLayer(W, self.val_out_ch)
57 | self.key_linear = DenseLayer(W, key_len)
58 |
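    # Worked dimension check for the defaults above (illustrative note): with
    # n_frequ_pos_encoding = 4, use_distance = use_pitch = use_azimuth = True,
    # use_projected = azimuth_2d = False and feat_dim_in = 128, the distance encoding
    # spans (2 * 4) * 3 + 1 = 25 channels (the raw distance is appended, the angles are
    # not), so self.input_ch = 128 + 25 = 153.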
59 | def forward(self, features, distance, projected_distance, pitch, azimuth):
60 | '''
61 |
62 | directions: [Batch_sz, N_rays, 3]
63 | features: [Batch_sz, N_rays, N_k_closest, N_feat_maps, dim_feat]
64 | distance: [Batch_sz, N_rays, N_k_closest]
65 | projected_distance: [Batch_sz, N_rays, N_k_closest]
66 | :return:
67 | :rtype:
68 | '''
69 | n_batch, n_rays, k_closest, n_feat_maps, feat_len = features.shape
70 |
71 | x = torch.empty(n_batch, n_rays, k_closest, 0, device=distance.device)
72 |
73 | # Normalize feature distances
74 | if self.used_distances['distance']:
75 | enc_dist = self.poseEncodingLen(distance[..., None])
76 | x = torch.cat([x, enc_dist], dim=-1)
77 |
78 | if self.used_distances['projected_distance']:
79 | x_proj_normalized = projected_distance / projected_distance.max(dim=-1).values[..., None]
80 | enc_proj = self.poseEncodingLen(x_proj_normalized[..., None])
81 | x = torch.cat([x, enc_proj], dim=-1)
82 |
83 | if self.used_distances['pitch']:
84 | enc_pitch = self.poseEncodingAng(pitch[..., None])
85 | x = torch.cat([x, enc_pitch], dim=-1)
86 |
87 | if self.used_distances['azimuth']:
88 | if not self.az_2d:
89 | enc_azimuth = self.poseEncodingAng(azimuth[..., None])
90 | else:
91 | enc_azimuth = torch.cat([
92 | self.poseEncodingLen(torch.sin(azimuth)[..., None]),
93 | self.poseEncodingLen(torch.cos(azimuth)[..., None])
94 | ], dim=-1)
95 |
96 | x = torch.cat([x, enc_azimuth], dim=-1)
97 |
98 | x = x[..., None, :].expand(n_batch, n_rays, k_closest, n_feat_maps, self.enc_dist_dim)
99 |
100 | if self.use_feat:
101 | x = torch.cat([features, x], dim=-1)
102 |
103 | x = x.reshape(n_batch*n_rays, k_closest * n_feat_maps, self.input_ch)
104 |
105 | for i, layer in enumerate(self.linear):
106 | x = layer(x)
107 | x = F.relu(x)
108 |
109 | values = self.value_linear(x)
110 | keys = self.key_linear(x)
111 |
112 | return values, keys
113 |
114 |
115 | class RayPointPoseEncoder(nn.Module):
116 | def __init__(self, W=64, D=1, n_frequ_pos_encoding=4, key_len=16, use_distance=True,
117 | use_projected=True, use_angle=True, feat_map_encoding=True):
118 | '''
119 | feat_len:
120 |         W: width of the hidden dense layers
121 |         D: number of dense layers
122 |         n_frequ_pos_encoding: number of frequencies for the positional encoding
123 |         key_len: dimension of the output key vectors
124 |         use_distance: encode the point-to-ray distance
125 |         use_projected: encode the projected distance along the ray
126 | '''
127 | super(RayPointPoseEncoder, self).__init__()
128 | self.poseEncoding = PositionalEncoding(n_frequ_pos_encoding, 1)
129 |
130 | self.use_distance = use_distance
131 | self.use_projected = use_projected
132 | self.use_angle = use_angle
133 | self.use_feat_map_enc = feat_map_encoding
134 |
135 | self.input_ch = (n_frequ_pos_encoding * 2 + 1) * (use_distance + use_projected + use_angle) + (n_frequ_pos_encoding * 2) * feat_map_encoding * 3
136 |
137 | self.linear = nn.ModuleList(
138 | [DenseLayer(self.input_ch, W)] +
139 | [DenseLayer(W, W, ) for i in range(D - 1)])
140 |
141 | self.key_linear = DenseLayer(W, key_len)
142 |
143 | self.feat_map_enc = torch.tensor([[-1., 0., 0.], [1., 0., 0.],
144 | [0., -1., 0.], [0., 1., 0.],
145 | [0., 0., -1.], [0., 0., 1.], ])
146 | self.feat_map_enc = PositionalEncoding(n_frequ_pos_encoding, 3, include_input=False,)(self.feat_map_enc)
147 |
148 | def forward(self, distance, projected_distance, angle, n_feat_maps=1):
149 | '''
150 | distance: [Batch_sz, N_rays, N_k_closest]
151 | projected_distance: [Batch_sz, N_rays, N_k_closest]
152 | angle: [Batch_sz, N_rays, N_k_closest]
153 | :return:
154 | :rtype:
155 | '''
156 | n_batch, n_rays, k_closest = distance.shape
157 |
158 | # Normalize feature distances
159 | x_proj_normalized = projected_distance / projected_distance.max(dim=-1).values[..., None]
160 | x_dist_normalized = distance / distance.max(dim=-1).values[..., None]
161 |
162 | # Positional encoding of ray distances and projected distance
163 | # TODO: Debug from here
164 | enc_proj = self.poseEncoding(x_proj_normalized[..., None])
165 | enc_dist = self.poseEncoding(x_dist_normalized[..., None])
166 | enc_angle = self.poseEncoding(angle[..., None])
167 |
168 | if self.use_feat_map_enc:
169 | if n_feat_maps == 6:
170 | feat_map_enc = self.feat_map_enc.to(enc_proj.device)
171 |
172 |
173 | x = torch.cat([enc_dist, enc_proj, enc_angle], dim=-1)
174 | x = x[..., None, :].repeat(1, 1, 1, n_feat_maps, 1)
175 |
176 | feat_map_enc = feat_map_enc[None, None, None].repeat(n_batch, n_rays, k_closest, 1, 1)
177 |
178 | x = torch.cat([x, feat_map_enc], dim=-1)
179 | x = x.reshape(n_batch * n_rays, k_closest * n_feat_maps, self.input_ch)
180 |
181 | else:
182 |                 raise NotImplementedError('Feature map encoding is only implemented for n_feat_maps == 6.')
183 | else:
184 | x = torch.cat([enc_dist, enc_proj, enc_angle], dim=-1)
185 | x = x.reshape(n_batch * n_rays, k_closest, self.input_ch)
186 |
187 | # enc_proj = enc_proj.unsqueeze(3).repeat(1, 1, 1, n_feat_maps, 1)
188 | # enc_dist = enc_dist.unsqueeze(3).repeat(1, 1, 1, n_feat_maps, 1)
189 | # enc_angle = enc_angle.unsqueeze(3).repeat(1, 1, 1, n_feat_maps, 1)
190 |
191 |
192 |
193 | for i, layer in enumerate(self.linear):
194 | x = layer(x)
195 | x = F.relu(x)
196 |
197 | keys = self.key_linear(x)
198 |
199 | return keys
200 |
201 |
202 | class RayEncoder(nn.Module):
203 | def __init__(self, W=64, D=1, q_len=16, n_frequ_pos_encoding=4,):
204 | super(RayEncoder, self).__init__()
205 | self.poseEncoding = PositionalEncoding(n_frequ_pos_encoding, 3)
206 |
207 | self.input_ch = (n_frequ_pos_encoding * 2 + 1) * 3
208 |
209 | self.linear = nn.ModuleList(
210 | [DenseLayer(self.input_ch, W)] +
211 | [DenseLayer(W, W, ) for i in range(D - 2)])
212 |
213 | self.query_linear = DenseLayer(W, q_len,)
214 |
215 |
216 | def forward(self, ray_direction):
217 | x = self.poseEncoding(ray_direction)
218 |
219 | for i, layer in enumerate(self.linear):
220 | x = layer(x)
221 |
222 | x = self.query_linear(x)
223 |
224 | return x
225 |
226 | class ScaledDotProductAttention(nn.Module):
227 | def __init__(self, temperature, attn_dropout=0.1):
228 | super(ScaledDotProductAttention, self).__init__()
229 | self.temperature = temperature
230 | self.dropout = nn.Dropout(attn_dropout)
231 |
232 | def forward(self, q, k, v, mask=None):
233 | attn = torch.matmul(q / self.temperature, k.transpose(1, 2))
234 |
235 | if mask is not None:
236 | attn = attn.masked_fill(mask == 0, -1e9)
237 |
238 |
239 | attn = self.dropout(F.softmax(attn, dim=-1))
240 | output = torch.matmul(attn, v)
241 |
242 | return output, attn
243 |
244 |
245 |
246 | class PointFeatureAttention(nn.Module):
247 | def __init__(self, feat_dim_in, feat_dim_out, embeded_dim=128, n_att_heads=8, kdim=128, vdim=128, Feat_D=2, Feat_W=256, Ray_D=2, Ray_W=64,
248 | new_encoding=False, no_feat=False):
249 | super(PointFeatureAttention, self).__init__()
250 |
251 | if new_encoding:
252 | n_frequ = 8
253 | else:
254 | n_frequ = 4
255 |
256 | self.dim_out = feat_dim_out
257 | self.FeatureEncoder = FeatureDistanceEncoder(feat_dim_in=feat_dim_in,
258 | W=Feat_W,
259 | D=Feat_D,
260 | n_frequ_pos_encoding=n_frequ,
261 | key_len=kdim,
262 | val_len=vdim,
263 | use_distance=True,
264 | use_projected=False,
265 | use_pitch=True,
266 | use_azimuth=True,
267 | azimuth_2d=new_encoding,
268 | no_feat=no_feat)
269 |
270 | self.RayEncoder = RayEncoder(W=Ray_W,
271 | D=Ray_D,
272 | q_len=embeded_dim,
273 | n_frequ_pos_encoding=4,)
274 |
275 | self.n_att_heads = n_att_heads
276 |
277 | if self.n_att_heads > 0:
278 | self.attention = nn.MultiheadAttention(embed_dim=embeded_dim, num_heads=n_att_heads, kdim=kdim, vdim=vdim)
279 | else:
280 | self.attention = ScaledDotProductAttention(1.)
281 |
282 | self.dim_reduction_out = nn.Linear(embeded_dim, feat_dim_out)
283 |
284 | def forward(self, directions, features, distance=None, projected_distance=None, pitch=None, azimuth=None, **kwargs):
285 | feat_dim = features.shape[-1]
286 | n_batch, n_rays, _ = directions.shape
287 |         # Generate keys and values from the projected point-cloud features and the positionally encoded ray-point distances with an MLP
288 | val, key = self.FeatureEncoder(features, distance, projected_distance, pitch, azimuth)
289 |
290 | # Use encoded ray-dirs in MLP to get query vector
291 | query = self.RayEncoder(directions.reshape(-1, 3))
292 | query = query[:, None]
293 |
294 | # Attention
295 |         # Combine the per-ray query with per-point keys and values in transformer-style attention and pool them into one conditioning vector per ray
296 | # Transpose for torch attention module
297 | if self.n_att_heads > 0:
298 | query = query.transpose(1,0)
299 | key = key.transpose(1, 0)
300 | val = val.transpose(1, 0)
301 | out, attn_weights = self.attention(query, key, val)
302 | out = self.dim_reduction_out(out)
303 | out = out.transpose(0, 1)
304 | else:
305 | out, attn_weights = self.attention(query, key, val)
306 | out = self.dim_reduction_out(out)
307 |
308 | out = out.reshape(n_batch, n_rays, self.dim_out)
309 | return out, attn_weights
310 |
311 |
312 | class PointDistanceAttention(nn.Module):
313 | def __init__(self, v_len=128, kq_len=16):
314 | super(PointDistanceAttention, self).__init__()
315 |
316 | self.RayPointPoseEncoder = RayPointPoseEncoder(W=256,
317 | D=2,
318 | n_frequ_pos_encoding=4,
319 | key_len=kq_len,
320 | use_distance=True,
321 | use_projected=True,
322 | use_angle=True,)
323 |
324 | self.RayEncoder = RayEncoder(W=64,
325 | D=1,
326 | q_len=kq_len,
327 | n_frequ_pos_encoding=4, )
328 |
329 | self.kq_len = kq_len
330 | self.val_len = v_len
331 |
332 | # TODO: Integrate and Decide between multi head and single attention module
333 | self.attention = ScaledDotProductAttention(1.)
334 |
335 | def forward(self, directions, features, distance, projected_distance, angle, **kwargs):
336 | n_feat_maps, feat_dim = features.shape[-2:]
337 | n_batch, n_rays, _ = directions.shape
338 |
339 |         # b) Take the features from the feature points directly and encode them together with distance, projection and angle
340 | key = self.RayPointPoseEncoder(distance, projected_distance, angle, n_feat_maps)
341 | val = features.reshape(n_batch * n_rays, -1, feat_dim)
342 |
343 | # Use encoded ray-dirs in MLP to get query vector
344 | query = self.RayEncoder(directions.reshape(-1, 3))
345 |
346 | # Attention
347 |         # Combine the per-ray query with per-point keys and values in transformer-style attention and pool them into one conditioning vector per ray
348 | out, attn = self.attention(query[:, None], key, val)
349 |
350 | out = out.reshape(n_batch, n_rays, self.val_len)
351 | return out
--------------------------------------------------------------------------------
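The attention modules above all follow the same query/key/value pattern: encoded ray directions act as queries, per-point features and ray-point geometry act as keys and values, and a single attention step pools the k closest points into one conditioning vector per ray. Below is a minimal, self-contained sketch of that pooling step using `torch.nn.MultiheadAttention` with the (sequence, batch, feature) layout also used in `PointFeatureAttention.forward`; the tensor sizes are illustrative assumptions, not values taken from this repository.

import torch
import torch.nn as nn

# Illustrative sizes (assumptions, not repository defaults).
n_rays, k_closest = 4, 8                     # rays treated as the batch, k closest points per ray
embed_dim, kdim, vdim = 128, 128, 128

attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=8, kdim=kdim, vdim=vdim)

# One query per ray, k_closest key/value pairs per ray.
# nn.MultiheadAttention defaults to (sequence, batch, feature) ordering.
query = torch.randn(1, n_rays, embed_dim)    # (tgt_len=1, batch=n_rays, embed_dim)
key = torch.randn(k_closest, n_rays, kdim)   # (src_len=k_closest, batch=n_rays, kdim)
value = torch.randn(k_closest, n_rays, vdim)

out, attn_weights = attention(query, key, value)
print(out.shape)                             # torch.Size([1, 4, 128]): one pooled vector per ray
print(attn_weights.shape)                    # torch.Size([4, 1, 8]): attention over the k closest points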
/src/pointLF/icp/pts_registration.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import open3d as o3d
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 |
8 | class ICP(nn.Module):
9 | def __init__(self, n_frames=5, step_width=2):
10 | super(ICP, self).__init__()
11 | self.n_frames = n_frames
12 | self.step_width = step_width
13 |
14 | self._merged_pcd_cache = {}
15 | # nppcd = np.load('./src/pointLF/icp/all_pts.npy') #id: 8,0,4,6,10,8,6,6
16 | # poses = np.load('./src/pointLF/icp/lidar_poses.npy') #id: 0,1,2,3,4,5,6,7,8,9,10
17 | #
18 | #
19 | # pts_4 = self.array2pcd(self.transformpts(nppcd[2], poses[4]))
20 | # pts_6 = self.array2pcd(self.transformpts(nppcd[3], poses[6]))
21 | # pts_10 = self.array2pcd(self.transformpts(nppcd[4], poses[10]))
22 | #
23 | # pts_4.paint_uniform_color([1, 0.706, 0])
24 | # pts_6.paint_uniform_color([0, 0.651, 0.929])
25 | # pts_10.paint_uniform_color([0.0, 1.0, 0.0])
26 | #
27 | # pts_4 = self.icp_core(pts_4,pts_10)
28 | # pts_6 = self.icp_core(pts_6,pts_10)
29 | # o3d.visualization.draw_geometries([pts_4,pts_6,pts_10])
30 |
31 | def forward(self, scene, cam_frame_id, caching=False, augment=False, pc_frame_id=None):
32 |         '''
33 |         Merge the LiDAR point clouds of adjacent frames into the vehicle frame of the requested frame.
34 |         :param scene: scene object holding the frames, lidar nodes and scene descriptor
35 |         :type scene:
36 |         :param cam_frame_id: index of the frame whose camera pose is requested
37 |         :type cam_frame_id:
38 |         :param caching: keep the merged point cloud in memory instead of writing it to disk
39 |         :type caching:
40 |         :param augment: do not take the point cloud of the current frame but of a nearby frame before or after it
41 |         :type augment:
42 |         :return: merged points, path of the merged point cloud, camera-to-point-cloud transformation and the point cloud frame id
43 |         :rtype:
44 |         '''
45 | li_sel_idx = np.array([v.scene_idx if v.name == 'TOP' else -1 for k, v in scene.nodes['lidar'].items()])
46 | top_lidar_id = li_sel_idx[np.where(li_sel_idx > 0)][0]
47 |
48 | if pc_frame_id is None:
49 | if augment:
50 | max_augment_distance = 5
51 | augment_distance = np.random.randint(0, max_augment_distance * 2 + 1) - max_augment_distance
52 | pc_frame_id = np.minimum(np.maximum(cam_frame_id + augment_distance, 0), len(scene.frames) - 1)
53 | else:
54 | pc_frame_id = cam_frame_id
55 |
56 | current_points_frame = scene.frames[pc_frame_id]
57 | current_camera_frame = scene.frames[cam_frame_id]
58 |
59 | pcd_path = scene.frames[pc_frame_id].point_cloud_pth[top_lidar_id]
60 |
61 | if not scene.scene_descriptor["type"] == "kitti":
62 | if scene.scene_descriptor.get('pt_cloud_fix', False):
63 | merged_pcd_dir = os.path.join(
64 | *(
65 | ["/"] +
66 | pcd_path.split("/")[1:-1] +
67 | ['merged_pcd_full_{}_scene_{}_{}_frames_{}_{}_n_fr_{}'.format(
68 | scene.scene_descriptor['type'],
69 | # str(scene.scene_descriptor['scene_id']).zfill(4),
70 | str(scene.scene_descriptor['scene_id'][0]).zfill(4),
71 | str(scene.scene_descriptor['scene_id'][1]).zfill(4),
72 | str(scene.scene_descriptor['first_frame']).zfill(4),
73 | str(scene.scene_descriptor['last_frame']).zfill(4),
74 | str(self.n_frames).zfill(4) if self.n_frames is not None else str('all'),)
75 | ]
76 | )
77 | )
78 | else:
79 | merged_pcd_dir = os.path.join(
80 | *(
81 | ["/"] +
82 | pcd_path.split("/")[1:-1] +
83 | ['merged_pcd_{}_scene_{}_{}_frames_{}_{}_n_fr_{}'.format(
84 | scene.scene_descriptor['type'],
85 | # str(scene.scene_descriptor['scene_id']).zfill(4),
86 | str(scene.scene_descriptor['scene_id'][0]).zfill(4),
87 | str(scene.scene_descriptor['scene_id'][1]).zfill(4),
88 | str(scene.scene_descriptor['first_frame']).zfill(4),
89 | str(scene.scene_descriptor['last_frame']).zfill(4),
90 | str(self.n_frames).zfill(4) if self.n_frames is not None else str('all'), )
91 | ]
92 | )
93 | )
94 | else:
95 | merged_pcd_dir = os.path.join(
96 | *(
97 | ["/"] +
98 | pcd_path.split("/")[1:-1] +
99 | ['merged_pcd_full_{}_scene_{}_frames_{}_{}_n_fr_{}'.format(
100 | scene.scene_descriptor['type'],
101 | str(scene.scene_descriptor['scene_id']).zfill(4),
102 | str(scene.scene_descriptor['first_frame']).zfill(4),
103 | str(scene.scene_descriptor['last_frame']).zfill(4),
104 | str(self.n_frames).zfill(4) if self.n_frames is not None else str('all'), )
105 | ]
106 | )
107 | )
108 |
109 | merged_pcd_path = os.path.join(
110 | merged_pcd_dir, '{}.pcd'.format(str(pc_frame_id).zfill(6))
111 | )
112 |
113 | os.umask(0)
114 | if not os.path.isdir(merged_pcd_dir):
115 | os.mkdir(merged_pcd_dir)
116 | else:
117 | # TODO: Add version check here
118 | if os.path.isfile(merged_pcd_path):
119 | current_points_frame.merged_pcd_pth = merged_pcd_path
120 |
121 | # Get the transformation from world coordinates to the vehicle coordinates of the requested frame
122 | veh2wo_0 = current_points_frame.global_transformation
123 | wo2veh_0 = np.concatenate(
124 | [veh2wo_0[:3, :3].T,
125 | veh2wo_0[:3, :3].T.dot(-veh2wo_0[:3, 3])[:, None]],
126 | axis=1
127 | )
128 | wo2veh_0 = np.concatenate([wo2veh_0, np.array([[0., 0., 0., 1.]])])
129 | pose_0 = np.eye(4)
130 |
131 |         # Get the transformation from the vehicle pose of the camera to the vehicle pose of the point cloud
132 | veh2wo_cam = current_camera_frame.global_transformation
133 | camera_trafo = wo2veh_0.dot(veh2wo_cam)
134 |
135 | if current_points_frame.merged_pcd_pth is None or (current_points_frame.merged_pcd is None and caching):
136 | all_points_post = None
137 | all_points_pre = None
138 |
139 | # Get points
140 | pts_0 = current_points_frame.load_point_cloud(top_lidar_id)
141 | pts_0 = self.array2pcd(self.transformpts(pts_0[:, :3], pose_0))
142 |
143 | all_points = pts_0
144 | all_points, _ = all_points.remove_statistical_outlier(nb_neighbors=30, std_ratio=2.0)
145 |
146 | first_fr_id = min(scene.frames.keys())
147 | if self.n_frames is not None:
148 | pre_current = np.linspace(cam_frame_id - 1, first_fr_id, cam_frame_id, dtype=int)[:self.n_frames]
149 | else:
150 | pre_current = np.linspace(cam_frame_id - 1, first_fr_id, cam_frame_id, dtype=int)
151 |
152 | last_fr_id = max(scene.frames.keys())
153 | if self.n_frames is not None:
154 | post_current = np.linspace(cam_frame_id + 1, last_fr_id, last_fr_id - cam_frame_id, dtype=int)[:self.n_frames]
155 | else:
156 | post_current = np.linspace(cam_frame_id + 1, last_fr_id, last_fr_id - cam_frame_id, dtype=int)
157 |
158 | # Loop over adjacent frames in the future
159 | # all_points_post = self.merge_adajcent_points(post_current, scene, wo2veh_0, pts_0, top_lidar_id)
160 |
161 | for fr_id in np.concatenate([post_current]):
162 | frame_i = scene.frames[fr_id]
163 | # Load point cloud
164 | pts_i = frame_i.load_point_cloud(top_lidar_id)
165 |
166 |                 # Do not keep dynamic scene parts behind the ego vehicle from future frames
167 | pts_front_idx = np.where(pts_i[:, 0] > 0.)
168 | pts_back_idx = np.where(np.all(np.stack([pts_i[:, 0] < 0., np.abs(pts_i[:, 1]) > 1.5]), axis=0))
169 | pts_idx = np.concatenate([pts_front_idx[0], pts_back_idx[0]])
170 | pts_i = pts_i[pts_idx]
171 |
172 | # Get Transformation from veh frame to world
173 | veh2wo_i = frame_i.global_transformation
174 |
175 | # Center all point clouds at the requested frame
176 | # Waymo
177 | pose_i = wo2veh_0.dot(veh2wo_i)
178 |
179 | # Transform point cloud into the vehicle frame of the current frame
180 | pts_i = self.array2pcd(self.transformpts(pts_i[:, :3], pose_i))
181 | # Remove noise from point cloud
182 | pts_i, _ = pts_i.remove_statistical_outlier(nb_neighbors=30, std_ratio=2.0)
183 |
184 | if all_points_post is None:
185 |                     # Match the first point cloud to the selected frame's lidar pose
186 | pts_i = self.icp_core(pts_i, pts_0)
187 | all_points_post = pts_i
188 | else:
189 | pts_i = self.icp_core(pts_i, all_points_post)
190 | all_points_post = all_points_post + pts_i
191 |
192 | if all_points_post is not None:
193 | all_points_post = self.icp_core(all_points_post, pts_0)
194 | all_points += all_points_post
195 |
196 | # Loop over adjacent frames in the past
197 | # all_points_pre = self.merge_adajcent_points(pre_current, scene, wo2veh_0, pts_0, top_lidar_id)
198 |
199 | for fr_id in np.concatenate([pre_current]):
200 | frame_i = scene.frames[fr_id]
201 | # Load point cloud
202 | pts_i = frame_i.load_point_cloud(top_lidar_id)
203 |
204 | # Get Transformation from veh frame to world
205 | veh2wo_i = frame_i.global_transformation
206 |
207 | # Center all point clouds at the requested frame
208 | # Waymo
209 | pose_i = wo2veh_0.dot(veh2wo_i)
210 |
211 | # Transform point cloud into the vehicle frame of the current frame
212 | pts_i = self.array2pcd(self.transformpts(pts_i[:, :3], pose_i))
213 | # Remove noise from point cloud
214 | pts_i, _ = pts_i.remove_statistical_outlier(nb_neighbors=30, std_ratio=2.0)
215 |
216 | if all_points_pre is None:
217 |                     # Match the first point cloud to the selected frame's lidar pose
218 | pts_i = self.icp_core(pts_i, pts_0)
219 | all_points_pre = pts_i
220 | else:
221 | pts_i = self.icp_core(pts_i, all_points_pre)
222 | all_points_pre = all_points_pre + pts_i
223 |
224 | if all_points_pre is not None:
225 | all_points_pre = self.icp_core(all_points_pre, pts_0)
226 | all_points += all_points_pre
227 |
228 | # Outlier and noise removal
229 | all_points, ind = all_points.remove_statistical_outlier(nb_neighbors=30, std_ratio=2.0)
230 |
231 | if caching:
232 | print("Caching pcd")
233 | current_points_frame.merged_pcd = all_points
234 | else:
235 | # Store merged point cloud for future readings
236 | o3d.io.write_point_cloud(merged_pcd_path, all_points)
237 | else:
238 | if caching:
239 |                 print("Retrieving pcd from cache")
240 | all_points = current_points_frame.merged_pcd
241 | else:
242 | # t0 = time.time()
243 | all_points = o3d.io.read_point_cloud(current_points_frame.merged_pcd_pth, format='pcd')
244 | # print("Read points {}".format(time.time()- t0))
245 |
246 | pts = np.asarray(all_points.points)
247 |
248 | return pts, merged_pcd_path, camera_trafo, pc_frame_id
249 |
250 |
251 |     def merge_adajcent_points(self, fr_id_ls, scene, wo2veh_0, pts_0, top_lidar_id):
252 |         all_points_adj = None
253 |         # Loop over the given adjacent frames and register each one onto the reference point cloud
254 |         for fr_id in np.concatenate([fr_id_ls]):
255 | frame_i = scene.frames[fr_id]
256 | # Load point cloud
257 | pts_i = frame_i.load_point_cloud(top_lidar_id)
258 |
259 | # Do not keep dynamic scene parts hiding behind the ego vehicle from adjacent frames
260 | pts_front_idx = np.where(pts_i[:, 0] > 0.)
261 | pts_back_idx = np.where(np.all(np.stack([pts_i[:, 0] < 0., np.abs(pts_i[:, 1]) > 1.5]), axis=0))
262 | pts_idx = np.concatenate([pts_front_idx[0], pts_back_idx[0]])
263 | pts_i = pts_i[pts_idx]
264 |
265 | # Get Transformation from veh frame to world
266 | veh2wo_i = frame_i.global_transformation
267 |
268 | # Center all point clouds at the requested frame
269 | # Waymo
270 | pose_i = wo2veh_0.dot(veh2wo_i)
271 |
272 | # Transform point cloud into the vehicle frame of the current frame
273 | pts_i = self.array2pcd(self.transformpts(pts_i[:, :3], pose_i))
274 | # Remove noise from point cloud
275 | pts_i, _ = pts_i.remove_statistical_outlier(nb_neighbors=30, std_ratio=2.0)
276 |
277 | if all_points_adj is None:
278 |                 # Match the first point cloud to the selected frame's lidar pose
279 | pts_i = self.icp_core(pts_i, pts_0)
280 | all_points_adj = pts_i
281 | else:
282 | pts_i = self.icp_core(pts_i, all_points_adj)
283 | all_points_adj = all_points_adj + pts_i
284 |
285 | return all_points_adj
286 |
287 |
288 | def transformpts(self,pts,pose):
289 | pts = np.concatenate((pts,np.ones((pts.shape[0],1))),1)
290 | pts = (pose.dot(pts.T)).T
291 | return pts[:,:3]
292 |
293 | def array2pcd(self, all_pts, color=None):
294 | pcd = o3d.geometry.PointCloud()
295 | pcd.points = o3d.utility.Vector3dVector(all_pts)
296 | if color is not None:
297 | pcd.paint_uniform_color(color)
298 | return pcd
299 |
300 | def icp_core(self,processed_source,processed_target):
301 | threshold = 1.0
302 |         trans_init = np.eye(4)  # identity initial guess; ICP expects a float 4x4 transform (np.int is deprecated)
303 |
304 | reg_p2p = o3d.pipelines.registration.registration_icp(
305 | source=processed_source,
306 | target=processed_target,
307 | max_correspondence_distance=threshold,
308 | init=trans_init,
309 | estimation_method=o3d.pipelines.registration.TransformationEstimationPointToPoint()
310 | )
311 |
312 | # print(reg_p2p.transformation)
313 | processed_source.transform(reg_p2p.transformation)
314 | return processed_source
315 |
--------------------------------------------------------------------------------
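As a quick reference for the registration step above, here is a minimal standalone sketch of point-to-point ICP with Open3D, mirroring `icp_core`; the two random point clouds stand in for LiDAR sweeps of adjacent frames and are purely illustrative.

import numpy as np
import open3d as o3d

# Placeholder clouds; in pts_registration.py these come from adjacent LiDAR frames
# that were already transformed into the vehicle frame of the requested frame.
source = o3d.geometry.PointCloud()
source.points = o3d.utility.Vector3dVector(np.random.rand(1000, 3))
target = o3d.geometry.PointCloud()
target.points = o3d.utility.Vector3dVector(np.random.rand(1000, 3))

# Point-to-point ICP with an identity initial guess, as in icp_core.
reg = o3d.pipelines.registration.registration_icp(
    source=source,
    target=target,
    max_correspondence_distance=1.0,
    init=np.eye(4),
    estimation_method=o3d.pipelines.registration.TransformationEstimationPointToPoint(),
)

# Apply the estimated rigid transformation to align the source onto the target in place.
source.transform(reg.transformation)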