├── LICENSE
├── README.md
├── SCONE
│   ├── Attention.py
│   ├── CustomDataset.py
│   ├── CustomGeometry.py
│   ├── SconeOcc.py
│   ├── SconeVis.py
│   ├── idr_torch.py
│   ├── pretrain_scone_occ.py
│   ├── pretrain_scone_vis.py
│   ├── scone_utils.py
│   ├── spherical_harmonics.py
│   ├── test_scone.py
│   └── utils.py
└── docs
    └── gifs
        ├── colosseum.gif
        ├── fushimi.gif
        ├── museum.gif
        └── pantheon.gif
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Antoine GUEDON
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SCONE
2 |
3 |
19 |
20 | Official PyTorch implementation of [**SCONE: Surface Coverage Optimization in Unknown Environments by Volumetric Integration**](https://arxiv.org/abs/2208.10449) (NeurIPS 2022, Spotlight).
21 |
22 |
23 | :warning: :warning: :warning: :warning: :warning: :warning: :warning: :warning: :warning: :warning: :warning: :warning:
24 |
25 |
26 |
27 | **We released a new model called [MACARONS (CVPR 2023)](https://github.com/Anttwo/MACARONS), which is a direct improvement over SCONE.
28 | MACARONS adapts the approach of SCONE to a fully self-supervised pipeline: it learns simultaneously to explore and reconstruct 3D scenes from RGB images only (no 3D ground truth data or depth sensors are needed).
29 | The MACARONS codebase includes a complete, updated version of SCONE's code, as well as detailed instructions for generating the ShapeNet training data we used for SCONE.
30 | Please refer to the MACARONS repo for all the information you need.**
31 |
32 |
33 | Thank you!
34 | :warning: :warning: :warning: :warning: :warning: :warning: :warning: :warning: :warning: :warning: :warning: :warning:
35 |
36 |
37 |
38 | This repository currently contains:
39 |
40 | - scripts to initialize and train models
41 | - evaluation pipelines to reproduce quantitative results
42 |
43 | **Note**: We will add **installation guidelines**, **training data generation scripts** and **test notebooks** as soon as possible to allow for reproducibility.
44 |
45 |
46 | If you find this code useful, don't forget to star the repo :star: and cite the paper :point_down:
47 |
48 | ```
49 | @inproceedings{guedon2022scone,
50 | title={{SCONE: Surface Coverage Optimization in Unknown Environments by Volumetric Integration}},
51 | author={Gu\'edon, Antoine and Monasse, Pascal and Lepetit, Vincent},
52 | booktitle={{Advances in Neural Information Processing Systems}},
53 | year={2022},
54 | }
55 | ```
56 |
57 |
58 |
59 |
60 | ## Major code updates :clipboard:
61 |
62 | - 11/22: first code release
63 |
64 |
65 |
66 | ## Installation :construction_worker:
67 |
68 | We will add more details as soon as possible.
69 |
70 | ## Download Data
71 |
72 | ### 1. ShapeNetCore
73 |
74 | We generate training data for both occupancy probability prediction and coverage gain estimation from [ShapeNetCore v1](https://shapenet.org/).
75 | We will add the data generation scripts and corresponding instructions as soon as possible.
76 |
77 | ### 2. Custom Dataset of large 3D scenes
78 |
79 | We conducted inference experiments in large environments, using 3D meshes downloaded from the Sketchfab website under the CC license.
80 | 3D models courtesy of [Brian Trepanier](https://sketchfab.com/CMBC), [Andrea Spognetta](https://sketchfab.com/spogna), and [Vr Interiors](https://sketchfab.com/vrInteriors).
81 | We will add more details as soon as possible.
82 |
83 | ## How to use :rocket:
84 |
85 | We will add more details as soon as possible.
86 |
87 | ## Further information :books:
88 |
89 | We adapted the code from [Phil Wang](https://github.com/lucidrains/se3-transformer-pytorch/blob/main/se3_transformer_pytorch/spherical_harmonics.py) to generate spherical harmonic features.
90 | We thank him for this very useful harmonics computation script!
91 |
92 | We also thank [Tom Monnier](https://www.tmonnier.com/) for his Markdown template, which we took inspiration from.
93 |
--------------------------------------------------------------------------------
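The repository's two core modules are defined in `SCONE/`: `SconeOcc` predicts an occupancy probability field from a partial point cloud, and `SconeVis` predicts per-point visibility gain functions (as spherical harmonics coefficients) that are integrated into coverage gains for candidate camera positions. Below is a minimal, hypothetical sketch of how the two pieces fit together, assuming the code is run from the `SCONE/` directory so the modules and their helpers (`utils.py`, `spherical_harmonics.py`) can be imported. Tensor sizes, the way occupancy predictions are concatenated to the query points, and the next-best-view selection are illustrative only and do not reproduce the exact pipeline of the training and test scripts (`pretrain_scone_occ.py`, `pretrain_scone_vis.py`, `test_scone.py`).

```python
import torch
from SconeOcc import SconeOcc
from SconeVis import SconeVis

scone_occ, scone_vis = SconeOcc(), SconeVis()   # constructor defaults

pc = torch.rand(1, 2048, 3)                # partial point cloud of the scene
x = torch.rand(1, 1024, 3)                 # 3D points sampled in the volume to cover
view_harmonics = torch.rand(1, 1024, 64)   # view state features of the camera history

# 1. Occupancy field prediction at the sampled points.
occ = scone_occ(pc, x, view_harmonics)                         # (1, 1024, 1)

# 2. Visibility gain functions of the points, concatenated with their occupancy.
pts = torch.cat((x, occ), dim=-1)                              # (1, 1024, 4)
harmonics = scone_vis(pts, view_harmonics=view_harmonics)      # (1, 1024, 64)

# 3. Coverage gain of candidate camera positions; the best one is the next best view.
X_cam = torch.rand(1, 10, 3)
coverage_gain = scone_vis.compute_coverage_gain(pts, harmonics, X_cam)   # (1, 10)
nbv = X_cam[0, coverage_gain[0].argmax()]
```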
/SCONE/Attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 | from pytorch3d.ops import knn_points, knn_gather
6 |
7 |
8 | def attention(q, k, v, mask=None, dropout=None):
9 | """
10 | Main attention mechanism function.
11 |
12 |     Takes queries, keys and values as inputs and computes the scaled dot-product attention output.
13 | :param q: (Tensor) Queries tensor with shape (..., N, d)
14 | :param k: (Tensor) Keys tensor with shape (..., N, d)
15 | :param v: (Tensor) Values tensor with shape (..., N, d)
16 | :param mask: (Tensor) Mask tensor with shape (..., N, N). Optional.
17 | :param dropout: Dropout module to apply on computed scores.
18 |     :return: (Tensor) Attention output tensor (attention-weighted values) with shape (..., N, d)
19 | """
20 | # Query/Key matrix multiplication
21 | scores = q.matmul(k.transpose(-2, -1))
22 |
23 | # Apply mask
24 | scores = scores if mask is None else scores.masked_fill(mask == 0, -1e3)
25 |
26 | # Normalization
27 | scores /= np.sqrt(q.shape[-1])
28 | scores = nn.functional.softmax(scores, dim=-1)
29 |
30 | # Dropout
31 | scores = scores if dropout is None else dropout(scores)
32 |
33 | # Value matrix multiplication
34 | scores = scores.matmul(v)
35 |
36 | return scores
37 |
38 |
39 | class Embedding(nn.Module):
40 | def __init__(self, input_dim, output_dim,
41 | dropout=None, gelu=True,
42 | global_feature=False,
43 | additional_feature_dim=0,
44 | concatenate_input=True,
45 | k_for_knn=0):
46 | """
47 | Class used to embed point clouds.
48 |
49 | :param input_dim: (int) Dimension of input points.
50 | :param output_dim: (int) Dimension of output features.
51 | :param dropout: Dropout module to apply on computed embeddings.
52 | :param gelu: (bool) If True, module uses GELU non-linearity. If False, module uses ReLU.
53 | :param global_feature: (bool) If True, output features are computed as the concatenation of per-point features
54 | and a global feature.
55 | :param additional_feature_dim: (int) Dimension of an additional feature to provide.
56 | :param concatenate_input: (bool) If True, concatenates input to the output features.
57 |         :param k_for_knn: (int) If > 0, output features are computed by max-pooling features from k nearest neighbors.
58 | """
59 | super(Embedding, self).__init__()
60 |
61 | self.use_knn = k_for_knn > 0
62 | self.k = k_for_knn
63 |
64 | self.input_dim = input_dim
65 |
66 | self.global_feature = global_feature
67 | self.additional_feature_dim = additional_feature_dim
68 | self.concatenate_input = concatenate_input
69 |
70 | self.inner_dim = output_dim // 2
71 | self.feature_dim = output_dim
72 |
73 | if additional_feature_dim > 0:
74 | self.feature_dim -= additional_feature_dim
75 | self.inner_dim = self.feature_dim
76 |
77 | if concatenate_input:
78 | self.feature_dim -= input_dim
79 | self.inner_dim = self.feature_dim
80 |
81 | if global_feature:
82 | self.feature_dim = self.feature_dim // 2
83 | self.inner_dim = self.feature_dim
84 |
85 | if global_feature or self.use_knn:
86 | self.max_pool = nn.functional.max_pool1d
87 |
88 | self.linear1 = nn.Linear(self.input_dim, self.inner_dim)
89 | self.linear2 = nn.Linear(self.inner_dim, self.feature_dim)
90 |
91 | self.dropout = nn.Dropout(dropout) if dropout is not None else None
92 |
93 | if gelu:
94 | self.nonlinear = nn.GELU()
95 | else:
96 | self.nonlinear = nn.ReLU(inplace=False)
97 |
98 | def forward(self, x, additional_feature=None):
99 | n_clouds, seq_len, x_dim = x.shape
100 |
101 | res = self.nonlinear(self.linear1(x))
102 | res = res if self.dropout is None else self.dropout(res)
103 | res = self.linear2(res)
104 |
105 | if self.use_knn:
106 | # Computing spatial kNN
107 | _, knn_idx, _ = knn_points(p1=x[..., :3], p2=x[..., :3], K=self.k)
108 | res_knn = knn_gather(res, knn_idx)
109 |
110 | # Pooling among kNN features
111 | res_knn = res_knn.view(n_clouds * seq_len, self.k, self.feature_dim)
112 | res = self.max_pool(input=res_knn.transpose(-1, -2),
113 | kernel_size=self.k
114 | ).view(n_clouds, seq_len, self.feature_dim)
115 |
116 | if self.global_feature:
117 | global_feature = self.max_pool(input=res.transpose(-1, -2),
118 | kernel_size=seq_len).view(n_clouds, 1, self.feature_dim)
119 | global_feature = global_feature.expand(-1, seq_len, -1)
120 | res = torch.cat((res, global_feature), dim=-1)
121 |
122 | if self.additional_feature_dim > 0:
123 | res = torch.cat((res, additional_feature), dim=-1)
124 |
125 | if self.concatenate_input:
126 | res = torch.cat((res, x), dim=-1)
127 |
128 | return res
129 |
130 |
131 | class MultiHeadSelfAttention(nn.Module):
132 | def __init__(self, n_heads, in_dim, qk_dim, dropout=None):
133 | """
134 | Main class for Multi-Head Self Attention neural module.
135 | Credits to https://arxiv.org/pdf/2012.09688.pdf
136 |
137 | :param n_heads: (int) Number of heads in the attention module.
138 | :param in_dim: (int) Dimension of input.
139 | :param qk_dim: (int) Dimension of keys and queries.
140 | :param dropout: Dropout module to apply on computed embeddings.
141 | """
142 | super(MultiHeadSelfAttention, self).__init__()
143 | v_dim = in_dim
144 |
145 | self.n_heads = n_heads
146 | self.in_dim = in_dim
147 | self.qk_dim = qk_dim
148 | self.v_dim = v_dim
149 |
150 | self.qk_dim_per_head = qk_dim // n_heads
151 | self.v_dim_per_head = v_dim // n_heads
152 |
153 | self.w_q = nn.Linear(in_dim, qk_dim)
154 | self.w_k = nn.Linear(in_dim, qk_dim)
155 | self.w_v = nn.Linear(in_dim, v_dim)
156 |
157 | self.dropout = nn.Dropout(dropout) if dropout is not None else None
158 |
159 | if n_heads > 1:
160 | self.out = nn.Linear(in_dim, in_dim)
161 |
162 | def split_heads(self, q, k, v):
163 | """
164 | Split all queries, keys and values between attention heads.
165 |
166 | :param q: (Tensor) Queries tensor with shape (batch_size, seq_len, qk_dim)
167 | :param k: (Tensor) Keys tensor with shape (batch_size, seq_len, qk_dim)
168 | :param v: (Tensor) Values tensor with shape (batch_size, seq_len, v_dim)
169 | :return: (3-tuple of Tensors) Queries, keys and values for each head.
170 | q_split and k_split tensors have shape (batch_size, seq_len, n_heads, qk_dim_per_head).
171 | v_split tensor has shape (batch_size, seq_len, n_heads, v_dim_per_head).
172 | """
173 |         # Each has shape (batch_size, seq_len, n_heads, dim_per_head)
174 | q_split = q.reshape(q.shape[0], -1, self.n_heads, self.qk_dim_per_head)
175 | k_split = k.reshape(k.shape[0], -1, self.n_heads, self.qk_dim_per_head)
176 | v_split = v.reshape(v.shape[0], -1, self.n_heads, self.v_dim_per_head)
177 |
178 | return q_split, k_split, v_split
179 |
180 | def forward(self, x, mask=None):
181 | """
182 | Forward pass.
183 |
184 | :param x: (Tensor) Input tensor with shape (batch_size, seq_len, in_dim)
185 | :param mask: (Tensor) Mask tensor with shape (batch_size, seq_len, seq_len). Optional.
186 | :return: scores (Tensor) Attention scores tensor with shape (batch_size, seq_len, in_dim)
187 | """
188 |         # x should have shape (batch_size, seq_len, in_dim)
189 | q = self.w_q(x)
190 | k = self.w_k(x)
191 | v = self.w_v(x)
192 |
193 | # Break into n_heads
194 | q, k, v = self.split_heads(q, k, v)
195 | q, k, v = [t.transpose(1, 2) for t in (q, k, v)] # BS * HEAD * SEQ_LEN * DIM_PER_HEAD
196 |
197 | scores = attention(q, k, v, mask, self.dropout) # BS * HEAD * SEQ_LEN * DIM_PER_HEAD
198 | scores = scores.transpose(1, 2).contiguous().view(
199 | scores.shape[0], -1, self.v_dim) # BS * SEQ_LEN * V_DIM
200 |
201 | if self.n_heads > 1:
202 | scores = self.out(scores)
203 |
204 | return scores
205 |
206 |
207 | class FeedForward(nn.Module):
208 | def __init__(self, input_dim, inner_dim, gelu=True, dropout=None):
209 | """
210 | Feed Forward unit to use in attention encoder.
211 |
212 | :param input_dim: (int) Dimension of input tensor.
213 | :param inner_dim: (int) Dimension of inner tensor.
214 |         :param gelu: (bool) If True, the unit uses GELU non-linearity. If False, it uses ReLU.
215 | :param dropout: Dropout module to apply on computed embeddings.
216 | """
217 | super(FeedForward, self).__init__()
218 | self.linear1 = nn.Linear(input_dim, inner_dim)
219 | self.linear2 = nn.Linear(inner_dim, input_dim)
220 | self.dropout = nn.Dropout(dropout) if dropout is not None else None
221 | if gelu:
222 | self.nonlinear = nn.GELU()
223 | else:
224 | self.nonlinear = nn.ReLU(inplace=False)
225 |
226 | def forward(self, x):
227 | """
228 | Forward pass.
229 |
230 | :param x: (Tensor) Input tensor with shape (..., input_dim)
231 | :return: (Tensor) Output tensor with shape (..., input_dim)
232 | """
233 | res = self.nonlinear(self.linear1(x))
234 | res = res if self.dropout is None else self.dropout(res)
235 |
236 | return self.linear2(res)
237 |
238 |
239 | class Encoder(nn.Module):
240 | def __init__(self, seq_len, qk_dim, embedding_dim=128, n_heads=1,
241 | dropout=None, gelu=True, FF=True):
242 | """
243 | Transformer encoder based on a Multi-Head Self Attention mechanism.
244 |
245 | :param seq_len: (int) Length of input sequence.
246 | :param qk_dim: (int) Dimension of keys and queries.
247 | :param embedding_dim: (int) Dimension of input embeddings, values, and attention scores.
248 | :param n_heads: (int) Number of heads in the self-attention module.
249 | :param dropout: Dropout module to apply on computed features.
250 |         :param gelu: (bool) If True, the encoder uses GELU non-linearity. If False, it uses ReLU.
251 | :param FF: (bool) If True, the encoder applies an additional Feed Forward unit after
252 | the Multi-Head Self Attention unit.
253 | """
254 | super(Encoder, self).__init__()
255 |
256 | self.seq_len = seq_len
257 | self.embedding_dim = embedding_dim
258 | self.n_heads = n_heads
259 | self.qk_dim = qk_dim # self.embedding_dim // 4
260 | self.dropout = dropout
261 | self.FF = FF
262 |
263 | self.norm1 = nn.LayerNorm(self.embedding_dim)
264 | self.mhsa = MultiHeadSelfAttention(n_heads=self.n_heads,
265 | in_dim=self.embedding_dim,
266 | qk_dim=self.qk_dim,
267 | dropout=self.dropout)
268 | self.dropout1 = None if dropout is None else nn.Dropout(dropout)
269 |
270 | if FF:
271 | self.norm2 = nn.LayerNorm(self.embedding_dim)
272 | self.ff = FeedForward(input_dim=embedding_dim,
273 | inner_dim=2*embedding_dim,
274 | gelu=gelu,
275 | dropout=self.dropout)
276 | self.dropout2 = None if dropout is None else nn.Dropout(dropout)
277 |
278 | def forward(self, x, mask=None):
279 | """
280 | Forward pass.
281 |
282 | :param x: (Tensor) Input tensor with shape (batch_size, seq_len, embedding_dim).
283 | :param mask: (Tensor) Mask tensor with shape (batch_size, seq_len, seq_len). Optional.
284 | :return: (Tensor) The features encoded with a self-attention mechanism.
285 | Has shape (batch_size, seq_len, embedding_dim)
286 | """
287 | res = self.norm1(x)
288 | res = self.mhsa(res, mask=mask)
289 | if self.dropout is not None:
290 | res = self.dropout1(res)
291 | res = x + res
292 |
293 | if self.FF:
294 | res2 = self.norm2(res)
295 | res2 = self.ff(res2)
296 | if self.dropout is not None:
297 | res2 = self.dropout2(res2)
298 | res = res + res2
299 |
300 | return res
--------------------------------------------------------------------------------
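A minimal usage sketch for the attention modules above, with hypothetical batch and sequence sizes (assumes `Attention.py` and its PyTorch3D dependency are importable):

```python
import torch
from Attention import Embedding, Encoder, attention

batch_size, seq_len, pts_dim, embedding_dim = 2, 128, 3, 256

# Embed a random point cloud; the raw input is concatenated to the features by default.
pc = torch.rand(batch_size, seq_len, pts_dim)
embedding = Embedding(input_dim=pts_dim, output_dim=embedding_dim)
features = embedding(pc)                        # (2, 128, 256)

# Encode the embedded cloud with a Multi-Head Self Attention block.
encoder = Encoder(seq_len=seq_len, qk_dim=embedding_dim // 4,
                  embedding_dim=embedding_dim, n_heads=4)
encoded = encoder(features)                     # (2, 128, 256)

# The low-level attention function can also be called on its own.
q = k = v = torch.rand(batch_size, seq_len, 64)
out = attention(q, k, v)                        # (2, 128, 64)
print(features.shape, encoded.shape, out.shape)
```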
/SCONE/CustomDataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import json
3 | import os
4 | import torch
5 | from torch.utils.data import Dataset, DataLoader
6 | from torch.utils.data.sampler import SubsetRandomSampler
7 |
8 | from pytorch3d.datasets import (
9 | collate_batched_meshes,
10 | render_cubified_voxels,
11 | )
12 | from pytorch3d.renderer import (
13 | TexturesVertex,
14 | TexturesAtlas,
15 | )
16 | from pytorch3d.transforms import (
17 | matrix_to_quaternion,
18 | quaternion_apply,
19 | )
20 | from utils import *
21 | import time
22 |
23 | class CustomDataset(Dataset):
24 |
25 | def __init__(self, data_path, memory_threshold, rasterizer, screen_rasterizer, params, camera_axes, device,
26 | save_to_json=False, load_from_json=False, json_name="models_list.json",
27 | load_obj=True):
28 | self.data_path = data_path
29 | self.rasterizer = rasterizer
30 | self.screen_rasterizer = screen_rasterizer
31 | # self.renderer = renderer
32 | # self.embedder = embedder
33 | # self.preprocess = preprocess
34 | self.image_size = params.image_size
35 | self.camera_dist = params.camera_dist
36 | self.elevation = params.elevation
37 | self.azim_angle = params.azim_angle
38 | self.camera_axes = camera_axes
39 | self.side = params.side
40 | self.device = device
41 |
42 | self.load_obj = load_obj
43 |
44 | if not load_from_json:
45 | models = []
46 | for (dirpath, dirnames, filenames) in os.walk(data_path):
47 | for filename in filenames:
48 | if filename[-4:] == ".obj":
49 | models.append(os.path.join(dirpath, filename))
50 | models = remove_heavy_files(models, memory_threshold)
51 | else:
52 | with open(json_name) as f:
53 | dir_to_load = json.load(f)
54 | models = [os.path.join(data_path, path) for path in dir_to_load['models']]
55 |
56 | if save_to_json:
57 | dir_to_save = {}
58 | dir_to_save['models'] = [path[1+len(data_path):] for path in models]
59 | with open(json_name, 'w') as outfile:
60 | json.dump(dir_to_save, outfile)
61 | print("Saved models list in", json_name, ".")
62 |
63 | print("Database loaded.")
64 | self.models = models
65 |
66 | def __len__(self):
67 | return len(self.models)
68 |
69 | def __getitem__(self, idx): # -> Dict:
70 |
71 | model_path = self.models[idx]
72 | model = {}
73 |
74 | if self.load_obj:
75 | verts, faces, aux = load_obj(
76 | model_path,
77 | device=self.device,
78 | load_textures=False,
79 | create_texture_atlas=True,
80 | # texture_atlas_size=4,
81 | # texture_wrap="repeat",
82 | )
83 |
84 | verts = adjust_mesh(verts)
85 |
86 | # Create a textures object
87 | atlas = aux.texture_atlas
88 |
89 | model["verts"] = verts
90 | model["faces"] = faces.verts_idx
91 | model["textures"] = aux[4]
92 | model["atlas"] = aux.texture_atlas
93 |
94 | model["path"] = model_path
95 | return model
96 |
97 | class CustomShapenetDataset(Dataset):
98 |
99 | def __init__(self, data_path, memory_threshold,
100 | save_to_json=False, load_from_json=False, json_name="models_list.json",
101 | official_split=False, adjust_diagonally=False,
102 | load_obj=True):
103 | self.data_path = data_path
104 | self.official_split = official_split
105 | self.adjust_diagonally = adjust_diagonally
106 |
107 | self.load_obj = load_obj
108 |
109 | if not load_from_json:
110 | models = []
111 | for (dirpath, dirnames, filenames) in os.walk(data_path):
112 | for filename in filenames:
113 | if filename[-4:] == ".obj":
114 | models.append(os.path.join(dirpath, filename))
115 | models = remove_heavy_files(models, memory_threshold)
116 | else:
117 | with open(json_name) as f:
118 | dir_to_load = json.load(f)
119 | models = [os.path.join(data_path, path) for path in dir_to_load['models']]
120 |
121 | if save_to_json:
122 | dir_to_save = {}
123 | dir_to_save['models'] = [path[1+len(data_path):] for path in models]
124 | with open(json_name, 'w') as outfile:
125 | json.dump(dir_to_save, outfile)
126 | print("Saved models list in", json_name, ".")
127 |
128 | print("Database loaded.")
129 | self.models = models
130 |
131 | def __len__(self):
132 | return len(self.models)
133 |
134 | def __getitem__(self, idx): # -> Dict:
135 |
136 | model_path = self.models[idx]
137 | model = {}
138 |
139 | if self.load_obj:
140 | verts, faces, aux = load_obj(
141 | model_path,
142 | # device=self.device,
143 | load_textures=False,
144 | create_texture_atlas=True,
145 | # texture_atlas_size=4,
146 | # texture_wrap="repeat",
147 | )
148 |
149 | if self.adjust_diagonally:
150 | verts = adjust_mesh_diagonally(verts, diag_range=1.0)
151 | else:
152 | verts = adjust_mesh(verts)
153 |
154 | # Create a textures object
155 | atlas = aux.texture_atlas
156 |
157 | model["verts"] = verts
158 | model["faces"] = faces.verts_idx
159 | model["textures"] = aux[4]
160 | model["atlas"] = aux.texture_atlas
161 |
162 | model["path"] = model_path
163 | return model
164 |
165 |
166 | class RGBDataset(Dataset):
167 | def __init__(self, data_path, alpha_max, use_future_images, scene_names=None,
168 | frames_to_remove_json='frames_to_remove.pt'):
169 | self.data_path = data_path
170 | self.alpha_max = alpha_max
171 | self.use_future_images = use_future_images
172 |
173 | self.data = {}
174 | self.indices = {}
175 | current_idx = 0
176 |
177 | self.data['scenes'] = {}
178 | self.data['n_scenes'] = 0
179 |
180 | # If no scene name is provided, we just take all scenes in the folder
181 | if scene_names is None:
182 | # scene_names = os.listdir(self.data_path)
183 | scene_names = [scene_name for scene_name in os.listdir(self.data_path)
184 | if os.path.isdir(os.path.join(self.data_path, scene_name))]
185 |
186 | if frames_to_remove_json in scene_names:
187 | scene_names.remove(frames_to_remove_json)
188 |
189 | self.frames_to_remove = torch.load(os.path.join(data_path, frames_to_remove_json))
190 |
191 | # For every scene...
192 | for scene_name in scene_names:
193 | self.data['n_scenes'] += 1
194 | self.data['scenes'][scene_name] = {}
195 | scene_path = os.path.join(self.data_path, scene_name)
196 | scene_path = os.path.join(scene_path, 'images')
197 | # print(scene_path)
198 |
199 | self.data['scenes'][scene_name]['trajectories'] = {}
200 | self.data['scenes'][scene_name]['n_trajectories'] = 0
201 | # ...And every trajectory...
202 | for trajectory_nb in os.listdir(scene_path):
203 | self.data['scenes'][scene_name]['n_trajectories'] += 1
204 | self.data['scenes'][scene_name]['trajectories'][trajectory_nb] = {}
205 | trajectory_path = os.path.join(scene_path, trajectory_nb)
206 | # print(trajectory_path)
207 |
208 | traj_length = len(os.listdir(trajectory_path))
209 |
210 | self.data['scenes'][scene_name]['trajectories'][trajectory_nb]['frames'] = {}
211 | self.data['scenes'][scene_name]['trajectories'][trajectory_nb]['n_frames'] = 0
212 | # ...We add all frames respecting conditions on the number of past (and, if required, future) frames
213 | for frame_name in os.listdir(trajectory_path):
214 | frame_nb = frame_name[:-3]
215 | short_path = scene_name + "/images/" + str(trajectory_nb) + "/" + str(frame_nb) + ".pt"
216 |
217 | save_index = False
218 | index_to_save = None
219 |
220 | if int(frame_nb) >= self.alpha_max and (
221 | (not self.use_future_images) or
222 | int(frame_nb) < traj_length - self.alpha_max
223 | ):
224 | if not (short_path in self.frames_to_remove.keys()):
225 | self.indices[str(current_idx)] = {'scene_name': scene_name,
226 | 'trajectory_nb': trajectory_nb,
227 | 'frame_nb': frame_nb}
228 | save_index = True
229 | index_to_save = current_idx
230 | current_idx += 1
231 |
232 | self.data['scenes'][scene_name]['trajectories'][trajectory_nb]['n_frames'] += 1
233 | self.data['scenes'][scene_name][
234 | 'trajectories'][trajectory_nb][
235 | 'frames'][frame_nb] = {}
236 | self.data['scenes'][scene_name][
237 | 'trajectories'][trajectory_nb][
238 | 'frames'][frame_nb][
239 | 'path'] = os.path.join(trajectory_path, str(frame_nb) + '.pt')
240 | if save_index:
241 | self.data['scenes'][scene_name][
242 | 'trajectories'][trajectory_nb][
243 | 'frames'][frame_nb][
244 | 'index'] = index_to_save
245 |
246 | print("Database loaded.")
247 |
248 | def __len__(self):
249 | # total_length = 0
250 | # for scene_name in self.data['scenes']:
251 | # for trajectory_nb in self.data['scenes'][scene_name]['trajectories']:
252 | # total_length += len(self.data['scenes'][scene_name]['trajectories'][trajectory_nb]['frames'].keys())
253 | # return total_length
254 |
255 | return len(self.indices.keys())
256 |
257 | def __getitem__(self, idx): # -> Dict:
258 |
259 | scene_name = self.indices[str(idx)]['scene_name']
260 | trajectory_nb = self.indices[str(idx)]['trajectory_nb']
261 | frame_nb = self.indices[str(idx)]['frame_nb']
262 |
263 | frame_path = self.data['scenes'][scene_name]['trajectories'][trajectory_nb]['frames'][frame_nb]['path']
264 | frame = torch.load(frame_path, map_location='cpu')
265 |
266 | frame['path'] = frame_path
267 | frame['index'] = idx
268 |
269 | return frame
270 |
271 | def get_neighbor_frame(self, frame, alpha, device='cpu'):
272 | """
273 |         Returns the frame located `alpha` steps away from `frame` along the same trajectory.
274 |         :param frame: (dict) Frame dictionary, as returned by __getitem__.
275 |         :param alpha: (int) Signed offset (in number of frames) from the given frame.
276 |         :param device: Device on which the neighbor frame is loaded.
277 |         :return: (dict) The neighbor frame dictionary.
278 | """
279 | idx = frame['index']
280 | scene_name = self.indices[str(idx)]['scene_name']
281 | trajectory_nb = self.indices[str(idx)]['trajectory_nb']
282 | frame_nb = str(int(self.indices[str(idx)]['frame_nb']) + alpha)
283 |
284 | neighbor_frame_path = self.data['scenes'][scene_name]['trajectories'][trajectory_nb]['frames'][frame_nb]['path']
285 | neighbor_frame = torch.load(neighbor_frame_path, map_location=device)
286 |
287 | neighbor_frame['path'] = neighbor_frame_path
288 | neighbor_frame['index'] = idx
289 |
290 | return neighbor_frame
291 |
292 | def get_neighbor_frame_from_idx(self, idx, alpha, device='cpu'):
293 | """
294 |         Returns the frame located `alpha` steps away from the frame of index `idx`, along the same trajectory.
295 |         :param idx: (int) Index of the reference frame in the dataset.
296 |         :param alpha: (int) Signed offset (in number of frames) from the reference frame.
297 |         :param device: Device on which the neighbor frame is loaded.
298 |         :return: (dict) The neighbor frame dictionary.
299 | """
300 | scene_name = self.indices[str(idx)]['scene_name']
301 | trajectory_nb = self.indices[str(idx)]['trajectory_nb']
302 | frame_nb = str(int(self.indices[str(idx)]['frame_nb']) + alpha)
303 |
304 | neighbor_frame_path = self.data['scenes'][scene_name]['trajectories'][trajectory_nb]['frames'][frame_nb]['path']
305 | neighbor_frame = torch.load(neighbor_frame_path, map_location=device)
306 |
307 | neighbor_frame['path'] = neighbor_frame_path
308 | neighbor_frame['index'] = idx
309 |
310 | return neighbor_frame
311 |
312 |
313 | class SceneDataset(Dataset):
314 | def __init__(self, data_path, scene_names=None):
315 | self.data_path = data_path
316 |
317 | # If no scene name is provided, we just take all scenes in the folder
318 | if scene_names is None:
319 | # scene_names = os.listdir(self.data_path)
320 | scene_names = [scene_name for scene_name in os.listdir(self.data_path)
321 | if os.path.isdir(os.path.join(self.data_path, scene_name))]
322 |
323 | self.scene_names = scene_names
324 |
325 | def __len__(self):
326 | # total_length = 0
327 | # for scene_name in self.data['scenes']:
328 | # for trajectory_nb in self.data['scenes'][scene_name]['trajectories']:
329 | # total_length += len(self.data['scenes'][scene_name]['trajectories'][trajectory_nb]['frames'].keys())
330 | # return total_length
331 |
332 | return len(self.scene_names)
333 |
334 | def __getitem__(self, idx): # -> Dict:
335 |
336 | scene_name = self.scene_names[idx]
337 | scene_path = os.path.join(self.data_path, scene_name)
338 |
339 | # Mesh info
340 | obj_name = scene_name + '.obj'
341 |
342 | # Settings info
343 | settings_file = os.path.join(scene_path, 'settings.json')
344 | with open(settings_file, "r") as read_content:
345 | settings = json.load(read_content)
346 |
347 | # Info about occupied camera poses
348 | occupied_pose = torch.load(os.path.join(scene_path, 'occupied_pose.pt'), map_location=torch.device('cpu'))
349 |
350 | scene = {}
351 | scene['scene_name'] = scene_name
352 | scene['obj_name'] = obj_name
353 | scene['settings'] = settings
354 | scene['occupied_pose'] = occupied_pose
355 |
356 | return scene
--------------------------------------------------------------------------------
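A minimal sketch of how `CustomShapenetDataset` might be instantiated; the `data_path` and `memory_threshold` values below are hypothetical, and the helpers it relies on (`remove_heavy_files`, `adjust_mesh_diagonally`, `load_obj`) must be provided by `utils.py`, which is star-imported above:

```python
from CustomDataset import CustomShapenetDataset

dataset = CustomShapenetDataset(
    data_path="./data/ShapeNetCore.v1",   # hypothetical location of the .obj files
    memory_threshold=10e6,                # upper bound used by remove_heavy_files to filter out heavy models
    save_to_json=True,                    # cache the list of model paths for faster reloads
    json_name="models_list.json",
    adjust_diagonally=True,               # normalize meshes with adjust_mesh_diagonally(diag_range=1.0)
)

print(len(dataset), "models found.")
model = dataset[0]   # dict with keys "verts", "faces", "textures", "atlas", "path"
print(model["path"], model["verts"].shape)
```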
/SCONE/CustomGeometry.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | # from utils import *
4 |
5 | def get_cartesian_coords(r, elev, azim, in_degrees=False):
6 | """
7 | Returns the cartesian coordinates of 3D points written in spherical coordinates.
8 |     :param r: (float or Tensor) Radius of the 3D points. Must be a scalar or broadcastable with a tensor of shape (N, 3).
9 | :param elev: (Tensor) Elevation tensor of 3D points, with shape (N).
10 | :param azim: (Tensor) Azimuth tensor of 3D points, with shape (N).
11 |     :param in_degrees: (bool) If True, elevation and azimuth are given in degrees.
12 | Else, in radians.
13 | :return: (Tensor) Cartesian coordinates tensor with shape (N, 3).
14 | """
15 | factor = 1
16 | if in_degrees:
17 | factor *= np.pi / 180.
18 | X = torch.stack((
19 | torch.cos(factor * elev) * torch.sin(factor * azim),
20 | torch.sin(factor * elev),
21 | torch.cos(factor * elev) * torch.cos(factor * azim)
22 |     ), dim=-1)  # stack along the last dim, so both 1D (N,) and 2D angle tensors are supported
23 |
24 | return r * X.view(-1, 3)
25 |
26 |
27 | def get_spherical_coords(X):
28 | """
29 | Returns the spherical coordinates of 3D points written in cartesian coordinates
30 | :param X: (Tensor) Tensor with shape (N, 3) that represents 3D points in cartesian coordinates.
31 |     :return: (3-tuple of Tensors) r_x, elev_x and azim_x are Tensors with shape (N) that correspond
32 |     to the radius, elevation and azimuth of all 3D points.
33 | """
34 | r_x = torch.linalg.norm(X, dim=1)
35 |
36 | elev_x = torch.asin(X[:, 1] / r_x) # between -pi/2 and pi/2
37 | elev_x[X[:, 1] / r_x <= -1] = -np.pi / 2
38 | elev_x[X[:, 1] / r_x >= 1] = np.pi / 2
39 |
40 | azim_x = torch.acos(X[:, 2] / (r_x * torch.cos(elev_x)))
41 | azim_x[X[:, 2] / (r_x * torch.cos(elev_x)) <= -1] = np.pi
42 | azim_x[X[:, 2] / (r_x * torch.cos(elev_x)) >= 1] = 0.
43 | azim_x[X[:, 0] < 0] *= -1
44 |
45 | return r_x, elev_x, azim_x
46 |
47 | def sample_cameras_on_sphere(n_X, radius, device):
48 | """
49 | Deterministic sampling of camera positions on a sphere.
50 |
51 |     :param n_X (int): number of positions to sample. Should be a perfect square.
52 | :param radius (float): radius of the sphere for sampling.
53 | :param device
54 | :return: A tensor with shape (n_X, 3).
55 | """
56 | delta_theta = 0.9 * np.pi
57 | delta_phi = 0.9 * 2 * np.pi
58 |
59 | n_dim = int(np.sqrt(n_X))
60 | d_theta = 2 * delta_theta / (n_dim - 1)
61 | d_phi = 2 * delta_phi / (n_dim - 1)
62 |
63 | increments = torch.linspace(0, n_dim - 1, n_dim, device=device)
64 |
65 | thetas = -delta_theta + increments * d_theta
66 | phis = -delta_phi + increments * d_phi
67 |
68 | thetas = thetas.view(n_dim, 1).expand(-1, n_dim)
69 | phis = phis.view(1, n_dim).expand(n_dim, -1)
70 |
71 | X = torch.stack((
72 | torch.cos(thetas) * torch.sin(phis),
73 | torch.sin(thetas),
74 | torch.cos(thetas) * torch.cos(phis)
75 | ), dim=2)
76 |
77 | return radius * X.view(-1, 3)
78 |
79 |
80 |
81 |
82 |
--------------------------------------------------------------------------------
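A small sketch exercising the helpers above; sizes are hypothetical, and the angles are reshaped to 2D before the cartesian reconstruction to match the stacking performed inside `get_cartesian_coords`:

```python
import torch
from CustomGeometry import (get_cartesian_coords, get_spherical_coords,
                            sample_cameras_on_sphere)

# Deterministic camera positions on a sphere of radius 2 (n_X must be a perfect square).
X_cam = sample_cameras_on_sphere(n_X=16, radius=2.0, device="cpu")   # (16, 3)

# Spherical coordinates: the radii should all be close to 2.
r, elev, azim = get_spherical_coords(X_cam)
print(r)

# Cartesian reconstruction (angles in radians); should be close to the original positions.
X_back = get_cartesian_coords(2.0, elev.view(4, 4), azim.view(4, 4), in_degrees=False)
print(torch.allclose(X_cam, X_back, atol=1e-4))
```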
/SCONE/SconeOcc.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from Attention import *
3 | from utils import get_knn_points
4 |
5 |
6 | class XEmbedding(nn.Module):
7 | def __init__(self, x_dim, x_embedding_dim, dropout=None, gelu=True):
8 | """
9 | Neural module for individual 3D point embedding using fully connected layers.
10 |
11 | :param x_dim: (int) Dimension of input 3D points. Usually, x_dim=3.
12 | :param x_embedding_dim: (int) Dimension of output embeddings.
13 | :param dropout: Dropout module to apply on computed features.
14 | :param gelu: (bool) If True, the model uses GELU non-linearity. Else, it uses ReLU.
15 | """
16 | super(XEmbedding, self).__init__()
17 |
18 | self.linear1 = nn.Linear(x_dim, x_embedding_dim // 4)
19 | # self.fc1 = nn.Linear(3*2*self.n_harmonic_functions, self.x_embd_size//2)
20 | self.linear2 = nn.Linear(x_embedding_dim // 4, x_embedding_dim // 2)
21 | self.linear3 = nn.Linear(x_embedding_dim // 2, x_embedding_dim)
22 |
23 | if gelu:
24 | self.non_linear1 = nn.GELU()
25 | self.non_linear2 = nn.GELU()
26 | self.non_linear3 = nn.GELU()
27 | else:
28 | self.non_linear1 = nn.ReLU(inplace=False)
29 | self.non_linear2 = nn.ReLU(inplace=False)
30 | self.non_linear3 = nn.ReLU(inplace=False)
31 |
32 | self.dropout = nn.Dropout(dropout) if dropout is not None else None
33 |
34 | def forward(self, x):
35 | res = self.non_linear1(self.linear1(x))
36 | res = self.non_linear2(self.linear2(res))
37 | res = self.non_linear3(self.linear3(res))
38 |
39 | res = res if self.dropout is None else self.dropout(res)
40 |
41 | return res
42 |
43 |
44 | class PCTransformer(nn.Module):
45 | def __init__(self, seq_len, pts_dim=3, pts_embedding_dim=256, feature_dim=512,
46 | concatenate_input=True,
47 | n_code=2, n_heads=4, FF=True, gelu=True,
48 | dropout=None):
49 | """
50 | Main class for Transformer units dedicated to point cloud global encoding.
51 |
52 | :param seq_len: (int) Length of input point cloud sequence.
53 | :param pts_dim: (int) Dimension of input points. Usually, pts_dim=3.
54 | :param pts_embedding_dim: (int) Dimension of point embeddings.
55 | :param feature_dim: (int) Dimension of output features.
56 |         :param concatenate_input: (bool) If True, the model concatenates the raw input points to their initial embeddings.
57 | :param n_code: (int) Number of Multi-Head Self Attention units.
58 | :param n_heads: (int) Number of heads in Multi-Head Self Attention units.
59 | :param FF: (bool) If True, the Transformer encoder applies a Feed Forward unit after
60 | each Multi-Head Self-Attention unit.
61 | :param gelu: (bool) If True, the model uses GELU non-linearity. Else, it uses ReLU.
62 | :param dropout: Dropout module to apply on computed features.
63 | """
64 | super(PCTransformer, self).__init__()
65 |
66 | # Parameters
67 | self.seq_len = seq_len
68 | self.pts_dim = pts_dim
69 | self.pts_embedding_dim = pts_embedding_dim
70 |
71 | self.n_code = n_code
72 | self.n_heads = n_heads
73 | self.FF = FF
74 | self.gelu = gelu
75 |
76 | self.feature_dim = feature_dim
77 |
78 | self.dropout = dropout
79 |
80 | # Point embedding
81 | self.embedding = Embedding(input_dim=pts_dim, output_dim=pts_embedding_dim,
82 | dropout=None, gelu=gelu,
83 | global_feature=False, additional_feature_dim=0,
84 | concatenate_input=concatenate_input,
85 | k_for_knn=0)
86 |
87 | # Point cloud encoder Backbone
88 | encoders = []
89 | for i in range(n_code):
90 | encoders += [Encoder(seq_len=seq_len,
91 | embedding_dim=pts_embedding_dim,
92 | qk_dim=pts_embedding_dim // 4,
93 | n_heads=n_heads,
94 | dropout=dropout,
95 | gelu=gelu,
96 | FF=FF)]
97 |
98 | self.encoders = nn.ModuleList(encoders)
99 | self.norm = nn.LayerNorm(self.pts_embedding_dim)
100 | self.linear0 = nn.Linear(self.pts_embedding_dim, feature_dim // 2)
101 |
102 | self.max_pool = nn.functional.max_pool1d # kernel_size will be <= seq_len
103 | self.avg_pool = nn.functional.avg_pool1d # kernel_size will be <= seq_len
104 |
105 | def forward(self, pc, mask=None):
106 | """
107 | Forward pass.
108 |
109 | :param pc: (Tensor) Point cloud tensor with shape (n_clouds, seq_len, pts_dim)
110 | :param mask: (Tensor) Mask tensor with shape (batch_size, seq_len, seq_len). Optional.
111 |         :return: (Tensor) Output features with shape (n_clouds, feature_dim)
112 | """
113 | n_clouds, seq_len = pc.shape[0], pc.shape[1]
114 |
115 | pc_embd = self.embedding(pc)
116 | for encoder in self.encoders:
117 | pc_embd = encoder(pc_embd, mask=mask)
118 | features = self.norm(pc_embd)
119 |
120 | # Linear transformation and pooling operations for downstream tasks
121 | features = self.linear0(features)
122 |
123 | features = features.transpose(dim0=-1, dim1=-2)
124 | features = torch.cat((self.max_pool(input=features, kernel_size=seq_len),
125 | self.avg_pool(input=features, kernel_size=seq_len)), dim=-2)
126 |
127 | features = features.view(n_clouds, self.feature_dim)
128 |
129 | return features
130 |
131 |
132 | class SconeOcc(nn.Module):
133 | def __init__(self, seq_len=2048, pts_dim=3, pts_embedding_dim=128,
134 | concatenate_input=True,
135 | n_code=2, n_heads=4, FF=True, gelu=True,
136 | global_feature_dim=512,
137 | n_scale=3, local_feature_dim=256, k_for_knn=16,
138 | x_dim=3, x_embedding_dim=512,
139 | n_harmonics=64,
140 | output_dim=1,
141 | dropout=None,
142 | offset=True):
143 | """
144 | Main class for SCONE's occupancy probability prediction module.
145 | A neural model that predicts a vector field as an implicit function, depending on an input point cloud
146 | and view state harmonic features representing the history of camera poses.
147 | A Transformer with Multi-Scale Neighborhood features (MSN features) is used to encode the point cloud,
148 | depending on the query point x.
149 |
150 | :param seq_len: (int) Number of points in the input point cloud.
151 | :param pts_dim: (int) Dimension of points in the input point cloud.
152 | :param pts_embedding_dim: (int) Dimension of embedded point cloud.
153 | :param concatenate_input: (bool) If True, concatenates raw input points to the point embeddings
154 | in the point cloud.
155 | :param n_code: (int) Number of Multi-Head Self Attention units.
156 | :param n_heads: (int) Number of heads in Multi-Head Self Attention units.
157 | :param FF: (bool) If True, the Transformer encoder applies a Feed Forward unit after
158 | each Multi-Head Self-Attention unit.
159 | :param gelu: (bool) If True, the model uses GELU non-linearity. Else, it uses ReLU.
160 | :param global_feature_dim: (int) Dimension of the point cloud global feature.
161 | :param n_scale: (int) Number of scales to compute neighborhood features in the point cloud.
162 | :param local_feature_dim: (int) Dimension of point cloud neighborhood features.
163 | :param k_for_knn: (int) Number of neighbors to use when computing a neighborhood feature.
164 | :param x_dim: (int) Dimension of query point x.
165 | :param x_embedding_dim: (int) Dimension of x embedding.
166 | :param n_harmonics: (int) Number of harmonics used to compute view_state harmonic features.
167 | :param output_dim: (int) Dimension of the output vector field.
168 | :param dropout: Dropout module to apply on computed embeddings.
169 | :param offset: (bool) If True, the model uses the offset between x and its neighbors rather than
170 | the coordinates of the neighbors to compute neighborhood features.
171 | This parameter should always be True, since it leads to far better performances.
172 | """
173 | super(SconeOcc, self).__init__()
174 |
175 | # Parameters
176 | self.seq_len = seq_len
177 | self.pts_dim = pts_dim
178 | self.pts_embedding_dim = pts_embedding_dim
179 |
180 | self.n_code = n_code
181 | self.n_heads = n_heads
182 | self.FF = FF
183 | self.gelu = gelu
184 |
185 | self.n_scale = n_scale
186 |
187 | self.x_dim = x_dim
188 | self.x_embedding_dim = x_embedding_dim
189 |
190 | self.output_dim = output_dim
191 |
192 | self.dropout = dropout
193 |
194 | self.encoding_dim = pts_embedding_dim
195 |
196 | self.k_for_knn = k_for_knn
197 | self.offset = offset
198 | if self.offset:
199 | print("Offset set to True.")
200 |
201 | self.global_feature_dim = global_feature_dim
202 | self.local_feature_dim = local_feature_dim
203 | self.all_feature_size = self.x_embedding_dim \
204 | + self.n_scale * self.local_feature_dim \
205 | + self.global_feature_dim + n_harmonics
206 |
207 | # Point cloud transformers
208 | self.global_transformer = PCTransformer(seq_len=seq_len, pts_dim=pts_dim,
209 | pts_embedding_dim=pts_embedding_dim, feature_dim=global_feature_dim,
210 | concatenate_input=concatenate_input,
211 | n_code=n_code, n_heads=n_heads, FF=FF, gelu=gelu,
212 | dropout=dropout)
213 |
214 | local_transformers = []
215 | for i in range(n_scale):
216 | local_transformers += [PCTransformer(seq_len=k_for_knn, pts_dim=pts_dim,
217 | pts_embedding_dim=pts_embedding_dim, feature_dim=local_feature_dim,
218 | concatenate_input=concatenate_input,
219 | n_code=n_code, n_heads=n_heads, FF=FF, gelu=gelu,
220 | dropout=dropout)]
221 | self.local_transformers = nn.ModuleList(local_transformers)
222 |
223 | # X embedding
224 | self.x_embedding = XEmbedding(x_dim=x_dim, x_embedding_dim=x_embedding_dim,
225 | dropout=dropout, gelu=gelu)
226 |
227 | # Point cloud feature extraction
228 | self.max_pool = nn.functional.max_pool1d # kernel_size will be <= seq_len
229 | self.avg_pool = nn.functional.avg_pool1d # kernel_size will be <= seq_len
230 |
231 | # MLP for occupancy probability prediction
232 | self.linear1 = nn.Linear(self.all_feature_size, 512)
233 | self.linear2 = nn.Linear(512, 256)
234 | self.linear3 = nn.Linear(256, output_dim)
235 |
236 | # self.non_linear1 = nn.ReLU(inplace=False)
237 | # self.non_linear2 = nn.ReLU(inplace=False)
238 | # self.non_linear3 = nn.ReLU(inplace=False)
239 |
240 | if gelu:
241 | self.non_linear1 = nn.GELU()
242 | self.non_linear2 = nn.GELU()
243 | self.non_linear3 = nn.GELU()
244 | else:
245 | self.non_linear1 = nn.ReLU(inplace=False)
246 | self.non_linear2 = nn.ReLU(inplace=False)
247 | self.non_linear3 = nn.ReLU(inplace=False)
248 |
249 | def forward(self, pc, x, view_harmonics, mask=None):
250 | """
251 | Forward pass.
252 | :param pc: (Tensor) Input point cloud tensor with shape (n_clouds, seq_len, pts_dim)
253 | :param x: (Tensor) Input query points tensor with shape (n_clouds, n_sample, x_dim)
254 | :param view_harmonics: (Tensor) View state harmonic features.
255 | Tensor with shape (n_clouds, n_sample, n_harmonics).
256 | :param mask: (Tensor) Mask tensor with shape (batch_size, seq_len, seq_len). Optional.
257 | :return: (Tensor) Output vector field values for each query point in x.
258 | Has shape (n_clouds, n_sample, output_dim)
259 | """
260 | n_clouds, full_seq_len = pc.shape[0], pc.shape[1]
261 | n_sample = x.shape[1]
262 |
263 | # -----Point cloud global encoding-----
264 | # Down sampling point cloud for global embedding
265 | global_down_sampled_pc = pc[:, torch.randperm(pc.shape[1])[:self.seq_len]]
266 | seq_len = global_down_sampled_pc.shape[1]
267 |
268 | global_features = self.global_transformer(global_down_sampled_pc)
269 |
270 | # -----Point cloud local encoding-----
271 | # Computing down sampling factor
272 | if self.n_scale > 1:
273 | ds_factor = int(np.power(full_seq_len / (self.k_for_knn * 8), 1./(self.n_scale - 1)))
274 | if ds_factor == 0:
275 | # print("Problem: ds_factor=0 encountered. Taking ds_factor=2 as a default value.")
276 | ds_factor = 2
277 | else:
278 | ds_factor = 1
279 |
280 | # kNN computation for local embedding
281 | down_sampled_pc = pc
282 | local_transformed = []
283 | for n_transformer in range(self.n_scale):
284 | local_transformer = self.local_transformers[n_transformer]
285 | # Get kNN points in down sampled pc
286 | local_pc, _, _ = get_knn_points(x, down_sampled_pc, self.k_for_knn)
287 | if self.offset:
288 | local_pc = local_pc - x.view(n_clouds, n_sample, 1, 3)
289 |
290 | # Compute features
291 | local_transformed += [local_transformer(local_pc.view(-1, self.k_for_knn, 3), mask=mask)]
292 |
293 | # Down sample pc
294 | ds_seq_len = down_sampled_pc.shape[1]
295 | # print("Ds seq len:", ds_seq_len)
296 |
297 | if n_transformer < self.n_scale-1:
298 | down_sampled_pc = down_sampled_pc[:, torch.randperm(ds_seq_len)[:ds_seq_len // ds_factor]]
299 | # print("DS pc:", down_sampled_pc.shape)
300 |
301 | if self.n_scale > 0:
302 | local_features = torch.cat(local_transformed, dim=-1)
303 | else:
304 | local_features = torch.zeros(n_clouds, n_sample, 0, device=pc.get_device())
305 | local_features = local_features.view(n_clouds, n_sample, self.n_scale * self.local_feature_dim)
306 |
307 | # -----X encoding-----
308 | x_features = self.x_embedding(x)
309 |
310 | # -----Occupancy prediction-----
311 | global_features = global_features.view(n_clouds, 1, self.global_feature_dim).expand(-1, n_sample, -1)
312 | x_features = x_features.view(n_clouds, n_sample, self.x_embedding_dim)
313 |
314 | res = torch.cat((global_features, local_features, x_features, view_harmonics), dim=-1)
315 | res = self.non_linear1(self.linear1(res))
316 | res = self.non_linear2(self.linear2(res))
317 | res = self.non_linear3(self.linear3(res))
318 |
319 | return res.view(n_clouds, n_sample, self.output_dim)
320 |
--------------------------------------------------------------------------------
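A minimal forward-pass sketch for `SconeOcc` with random tensors; sizes are hypothetical, and the model requires the `get_knn_points` helper from `utils.py`, which is not shown in this section:

```python
import torch
from SconeOcc import SconeOcc

n_clouds, n_pts, n_sample, n_harmonics = 2, 2048, 64, 64

scone_occ = SconeOcc()   # default hyperparameters from the constructor above

pc = torch.rand(n_clouds, n_pts, 3)                            # partial surface point cloud
x = torch.rand(n_clouds, n_sample, 3)                          # query points
view_harmonics = torch.rand(n_clouds, n_sample, n_harmonics)   # view state features

occ = scone_occ(pc, x, view_harmonics)   # predicted occupancy field, shape (2, 64, 1)
print(occ.shape)
```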
/SCONE/SconeVis.py:
--------------------------------------------------------------------------------
1 | from Attention import *
2 | from CustomGeometry import get_spherical_coords
3 | from spherical_harmonics import clear_spherical_harmonics_cache, get_spherical_harmonics
4 |
5 |
6 | class SconeVis(nn.Module):
7 | def __init__(self,
8 | pts_dim=4, seq_len=2048, pts_embedding_dim=256,
9 | n_heads=4, n_code=3,
10 | n_harmonics=64,
11 | max_harmonic_rank=8,
12 | FF=True,
13 | gelu=True,
14 | dropout=None,
15 | use_view_state=True,
16 | use_global_feature=True,
17 | view_state_mode="end",
18 | concatenate_input=True,
19 | k_for_knn=0,
20 | alt=False,
21 | use_sigmoid=True):
22 | """
23 | Main class for SCONE's visibility prediction module.
24 |
25 | :param pts_dim: (int) Input dimension. Since SCONE processes clouds of 3D-points concatenated with
26 | their occupancy probability, pts_dim should be equal to 4.
27 | :param seq_len: (int) Maximal number of points in the cloud.
28 | :param pts_embedding_dim: (int) Dimension of points' embeddings.
29 | :param n_heads: (int) Number of heads in Multi-Head Self Attention units.
30 | :param n_code: (int) Number of Multi-Head Self Attention units.
31 | :param n_harmonics: (int) Number of harmonics to use to encode visibility gain functions.
32 | :param max_harmonic_rank: (int) Maximal harmonic rank for harmonic functions.
33 | :param FF: (bool) If True, Transformer encoder(s) apply a Feed Forward unit after
34 | each Multi-Head Self-Attention unit.
35 | :param gelu: (bool) If True, the model uses GELU non-linearity. Else, it uses ReLU.
36 | :param dropout: Dropout module to apply on computed features.
37 | :param use_view_state: (bool) If True, model uses view_state harmonics as additional features.
38 | :param use_global_feature: (bool) If True, model computes an additional global feature concatenated to each
39 | point's embedding before applying the Transformer encoder.
40 | :param view_state_mode: (str) If view_state_mode=='start', view_state features are concatenated
41 | to the points' embeddings before applying the Transformer encoder.
42 | If view_state_mode=='end', view_state features are concatenated to the points' embeddings
43 | after applying the Transformer encoder.
44 |         :param concatenate_input: (bool) If True, the model concatenates the raw input points to their initial embeddings.
45 |         :param k_for_knn: (int) If > 0, the model computes embeddings for points based on their k nearest neighbors.
46 |         :param alt: (bool) If True, uses an alternate architecture for the end of the network.
47 | :param use_sigmoid: (bool) If True, uses a sigmoid function on predicted visibility scores.
48 | """
49 | super(SconeVis, self).__init__()
50 |
51 | # self.harmonicEmbedder = HarmonicEmbedding(n_harmonic_functions=30, omega0=0.1)
52 | # self.n_harmonic_functions = 30
53 | # self.harmonicEmbedder = HarmonicEmbedding(n_harmonic_functions=self.n_harmonic_functions, omega0=1)
54 |
55 | self.n_harmonics = n_harmonics
56 |
57 | self.pts_dim = pts_dim
58 | self.seq_len = seq_len
59 | self.pts_embedding_dim = pts_embedding_dim
60 | self.n_heads = n_heads
61 | self.n_code = n_code
62 | self.n_harmonics = n_harmonics
63 | self.max_harmonic_rank = max_harmonic_rank
64 |
65 | self.use_view_state = use_view_state
66 | self.use_global_feature = use_global_feature
67 | self.view_state_mode = view_state_mode
68 |
69 | self.alt = alt
70 |
71 | self.use_sigmoid = use_sigmoid
72 | if use_sigmoid:
73 | print("Use sigmoid in model.")
74 | else:
75 | print("Use ReLU for output in model.")
76 |
77 | # Input embedding
78 | if use_view_state and view_state_mode == "start":
79 | additional_feature_dim = n_harmonics
80 | else:
81 | additional_feature_dim = 0
82 | self.embedding = Embedding(pts_dim, pts_embedding_dim, gelu=gelu,
83 | global_feature=use_global_feature,
84 | additional_feature_dim=additional_feature_dim,
85 | concatenate_input=concatenate_input,
86 | k_for_knn=k_for_knn,
87 | dropout=None)
88 |
89 | # Encoder Backbone
90 | encoders = []
91 | for i in range(n_code):
92 | encoders += [Encoder(seq_len=seq_len,
93 | embedding_dim=pts_embedding_dim,
94 | qk_dim=pts_embedding_dim//4,
95 | n_heads=n_heads,
96 | dropout=dropout,
97 | gelu=gelu,
98 | FF=FF)]
99 | self.encoders = nn.ModuleList(encoders)
100 |
101 | self.norm = nn.LayerNorm(pts_embedding_dim)
102 |
103 | # MLP for Harmonics prediction
104 | if not alt:
105 | fc1_input_dim = pts_embedding_dim
106 | inner_feature_factor = 4
107 | if use_view_state and view_state_mode == "end":
108 | inner_feature_factor = 3
109 | else:
110 | fc1_input_dim = pts_embedding_dim + n_harmonics
111 | inner_feature_factor = 4
112 |
113 | self.fc1 = nn.Linear(fc1_input_dim, inner_feature_factor * n_harmonics)
114 | self.nonlinear1 = nn.GELU()
115 |
116 | self.fc2 = nn.Linear(4 * n_harmonics, 2 * n_harmonics)
117 | self.nonlinear2 = nn.GELU()
118 |
119 | self.fc3 = nn.Linear(2 * n_harmonics, n_harmonics)
120 |
121 | def forward(self, pts, mask=None, view_harmonics=None):
122 | """
123 | Forward pass.
124 | :param pts: (Tensor) Input point cloud. Tensor with shape (n_clouds, seq_len, pts_dim)
125 | :param mask: (Tensor) Mask tensor with shape (batch_size, seq_len, seq_len). Optional.
126 | :param view_harmonics: (Tensor) View state harmonic features. Tensor with shape (n_clouds, seq_len, n_harmonics)
127 |         :return: (Tensor) Visibility gain functions of each point, as coordinates in the spherical harmonics basis.
128 | Has shape (n_clouds, seq_len, n_harmonics)
129 | """
130 | n_clouds = len(pts)
131 | seq_len = pts.shape[1]
132 |
133 | # Input embedding
134 | if self.use_view_state and self.view_state_mode == "start":
135 | x = self.embedding(pts, additional_feature=view_harmonics)
136 | else:
137 | x = self.embedding(pts)
138 |
139 | # Applying Encoders
140 | for encoder in self.encoders:
141 | x = encoder(x, mask=mask)
142 |
143 | # Final normalization, and linear layer for downstream task
144 | res = self.norm(x)
145 |
146 | if not self.alt:
147 | res = self.nonlinear1(self.fc1(res))
148 |
149 | # Concatenating view harmonics if needed, and final prediction
150 | if self.use_view_state and self.view_state_mode == "end":
151 | res = torch.cat((res, view_harmonics), dim=-1)
152 | res = self.nonlinear2(self.fc2(res))
153 | res = self.fc3(res)
154 | else:
155 | res = torch.cat((res, view_harmonics), dim=-1)
156 | res = self.nonlinear1(self.fc1(res))
157 | res = self.nonlinear2(self.fc2(res))
158 | res = self.fc3(res)
159 |
160 | res = res.view(n_clouds, seq_len, self.n_harmonics)
161 |
162 | return res
163 |
164 | def compute_visibilities(self, pts, harmonics, X_cam):
165 | """
166 |         Computes the visibility gain of each point in pts for each camera in X_cam.
167 | :param pts: (Tensor) Input point cloud. Tensor with shape (n_clouds, seq_len, pts_dim)
168 | :param harmonics: (Tensor) Predicted visibility gain functions as coordinates in spherical harmonics.
169 | Has shape (n_clouds, seq_len, n_harmonics).
170 | :param X_cam: (Tensor) Tensor of camera centers' positions, with shape (n_clouds, n_camera_candidates, 3)
171 | :return: (Tensor) The predicted per-point visibility gains of all points.
172 | Has shape (n_clouds, n_camera_candidates, seq_len)
173 | """
174 | clear_spherical_harmonics_cache()
175 | n_clouds = pts.shape[0]
176 | seq_len = pts.shape[1]
177 | n_harmonics = self.n_harmonics
178 | n_camera_candidates = X_cam.shape[1]
179 |
180 | device = pts.get_device()
181 | if device < 0:
182 | device = "cpu"
183 |
184 | X_pts = pts[..., :3]
185 |
186 | # tmp_pts = X_pts.view(n_clouds, 1, seq_len, 3).expand(-1, n_camera_candidates, -1, -1)
187 | # tmp_h = harmonics.view(n_clouds, 1, seq_len, 64).expand(-1, n_camera_candidates, -1, -1)
188 | # tmp_cam = X_cam.view(n_clouds, n_camera_candidates, 1, 3).expand(-1, -1, seq_len, -1)
189 |
190 | rays = (X_cam.view(n_clouds, n_camera_candidates, 1, 3).expand(-1, -1, seq_len, -1)
191 | - X_pts.view(n_clouds, 1, seq_len, 3).expand(-1, n_camera_candidates, -1, -1)).view(-1, 3)
192 | _, theta, phi = get_spherical_coords(rays)
193 | theta = -theta + np.pi / 2.
194 |
195 | z = torch.zeros([i for i in theta.shape] + [0], device=device)
196 | for i in range(self.max_harmonic_rank):
197 | y = get_spherical_harmonics(l=i, theta=theta, phi=phi)
198 | z = torch.cat((z, y), dim=-1)
199 | z = z.view(n_clouds, n_camera_candidates, seq_len, n_harmonics)
200 |
201 |         z = torch.sum(z * harmonics.view(n_clouds, 1, seq_len, n_harmonics).expand(-1, n_camera_candidates, -1, -1), dim=-1)
202 | if self.use_sigmoid:
203 | z = torch.sigmoid(z)
204 | else:
205 | z = torch.relu(z)
206 | # z = torch.sum(z, dim=-1) / seq_len
207 |
208 | return z
209 |
210 | def compute_coverage_gain(self, pts, harmonics, X_cam):
211 | """
212 | Computes global coverage gain for each camera candidate in X_cam.
213 | :param pts: tensor with shape (n_clouds, seq_len, 3 or 4)
214 | :param harmonics: tensor with shape (n_clouds, seq_len, n_harmonics)
215 | :param X_cam: tensor with shape (n_clouds, n_camera_candidates, 3)
216 | :return: A tensor z with shape (n_clouds, n_camera_candidates)
217 | """
218 | clear_spherical_harmonics_cache()
219 | n_clouds = pts.shape[0]
220 | seq_len = pts.shape[1]
221 | n_harmonics = self.n_harmonics
222 | n_camera_candidates = X_cam.shape[1]
223 |
224 | X_pts = pts[..., :3]
225 |
226 | # tmp_pts = X_pts.view(n_clouds, 1, seq_len, 3).expand(-1, n_camera_candidates, -1, -1)
227 | # tmp_h = harmonics.view(n_clouds, 1, seq_len, 64).expand(-1, n_camera_candidates, -1, -1)
228 | # tmp_cam = X_cam.view(n_clouds, n_camera_candidates, 1, 3).expand(-1, -1, seq_len, -1)
229 |
230 | rays = (X_cam.view(n_clouds, n_camera_candidates, 1, 3).expand(-1, -1, seq_len, -1)
231 | - X_pts.view(n_clouds, 1, seq_len, 3).expand(-1, n_camera_candidates, -1, -1)).view(-1, 3)
232 | _, theta, phi = get_spherical_coords(rays)
233 | theta = -theta + np.pi/2.
234 |
235 |         z = torch.zeros([i for i in theta.shape] + [0], device=theta.device)
236 | for i in range(self.max_harmonic_rank):
237 | y = get_spherical_harmonics(l=i, theta=theta, phi=phi)
238 | z = torch.cat((z, y), dim=-1)
239 | z = z.view(n_clouds, n_camera_candidates, seq_len, n_harmonics)
240 |
241 |         z = torch.sum(z * harmonics.view(n_clouds, 1, seq_len, n_harmonics).expand(-1, n_camera_candidates, -1, -1), dim=-1)
242 | if self.use_sigmoid:
243 | z = torch.sigmoid(z)
244 | else:
245 | z = torch.relu(z)
246 |
247 | # TO REMOVE
248 | # z[pts[..., 3].view(n_clouds, 1, seq_len).expand(-1, n_camera_candidates, -1) > 0.9] = 0
249 |
250 | z = torch.sum(z, dim=-1) / seq_len
251 |
252 | return z
253 |
254 | def compute_coverage_gain_multiple(self, pts, harmonics, X_cam, n_cam):
255 | """
256 | Computes global coverage gains for each n_cam-subset of camera candidates in X_cam.
257 | :param pts: tensor with shape (n_clouds, seq_len, 3 or 4)
258 | :param harmonics: tensor with shape (n_clouds, seq_len, n_harmonics)
259 | :param X_cam: tensor with shape (n_clouds, n_camera_candidates, 3)
260 | :param n_cam: number of simultaneous NBV to select
261 | :return: A tensor z with shape (n_clouds, n_camera_candidates)
262 | """
263 | clear_spherical_harmonics_cache()
264 | n_clouds = pts.shape[0]
265 | seq_len = pts.shape[1]
266 | n_harmonics = self.n_harmonics
267 | n_camera_candidates = X_cam.shape[1]
268 |
269 | X_pts = pts[..., :3]
270 |
271 | # tmp_pts = X_pts.view(n_clouds, 1, seq_len, 3).expand(-1, n_camera_candidates, -1, -1)
272 | # tmp_h = harmonics.view(n_clouds, 1, seq_len, 64).expand(-1, n_camera_candidates, -1, -1)
273 | # tmp_cam = X_cam.view(n_clouds, n_camera_candidates, 1, 3).expand(-1, -1, seq_len, -1)
274 |
275 | rays = (X_cam.view(n_clouds, n_camera_candidates, 1, 3).expand(-1, -1, seq_len, -1)
276 | - X_pts.view(n_clouds, 1, seq_len, 3).expand(-1, n_camera_candidates, -1, -1)).view(-1, 3)
277 | _, theta, phi = get_spherical_coords(rays)
278 | theta = -theta + np.pi/2.
279 |
280 |         z = torch.zeros([i for i in theta.shape] + [0], device=theta.device)
281 | for i in range(self.max_harmonic_rank):
282 | y = get_spherical_harmonics(l=i, theta=theta, phi=phi)
283 | z = torch.cat((z, y), dim=-1)
284 | z = z.view(n_clouds, n_camera_candidates, seq_len, n_harmonics)
285 |
286 |         z = torch.sum(z * harmonics.view(n_clouds, 1, seq_len, n_harmonics).expand(-1, n_camera_candidates, -1, -1), dim=-1)
287 | if self.use_sigmoid:
288 | z = torch.sigmoid(z)
289 | else:
290 | z = torch.relu(z)
291 |
292 | single_idx = torch.arange(0, n_camera_candidates)
293 | if n_cam == 2:
294 | n_idx = torch.cartesian_prod(single_idx, single_idx)
295 | elif n_cam == 3:
296 | n_idx = torch.cartesian_prod(single_idx, single_idx, single_idx)
297 | else:
298 |             raise ValueError("n_cam must be 2 or 3.")
299 |
300 | n_z = z[:, n_idx]
301 | n_z = torch.sum(torch.max(n_z, dim=-2)[0], dim=-1) / seq_len
302 |
303 | return n_z, n_idx
304 |
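    # Illustrative sketch (hypothetical variable names): selecting the best pair of cameras from
    # the outputs of compute_coverage_gain_multiple. n_z has shape
    # (n_clouds, n_camera_candidates**n_cam) and n_idx maps each entry back to the corresponding
    # tuple of camera indices:
    #
    #   n_z, n_idx = scone_vis.compute_coverage_gain_multiple(pts, harmonics, X_cam, n_cam=2)
    #   best_pair = n_idx[n_z[0].argmax()]  # indices of the two selected camera candidates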
305 |
306 | class KLDivCE(nn.Module):
307 | """
308 |     Layer to compute KL-divergence between coverage distributions, applying LogSoftmax to predictions and Softmax to targets along the camera dimension.
309 | """
310 |
311 | def __init__(self):
312 | super(KLDivCE, self).__init__()
313 |
314 | self.kl_div = nn.KLDivLoss(reduction='batchmean')
315 | self.log_soft_max = nn.LogSoftmax(dim=1)
316 |
317 | def forward(self, x, y):
318 | loss = self.kl_div(self.log_soft_max(x), torch.softmax(y, dim=1))
319 | return loss
320 |
321 |
322 | class L1_loss(nn.Module):
323 | """
324 |     Layer to compute L1 loss between coverage distributions standardized by their mean and standard deviation.
325 | """
326 |
327 | def __init__(self):
328 | super(L1_loss, self).__init__()
329 | self.epsilon = 1e-7
330 |
331 | def forward(self, x, y):
332 | """
333 |
334 | :param x: (Tensor) Should have shape (batch_size, n_camera, 1)
335 | :param y: (Tensor) Should have shape (batch_size, n_camera, 1)
336 |         :return: (Tensor) Scalar L1 loss averaged over the batch.
337 | """
338 | batch_size, n_camera = x.shape[0], x.shape[1]
339 | x_mean = x.mean(dim=1, keepdim=True).expand(-1, n_camera, -1)
340 | y_mean = y.mean(dim=1, keepdim=True).expand(-1, n_camera, -1)
341 |
342 | x_std = x.std(dim=1, keepdim=True).expand(-1, n_camera, -1)
343 | y_std = y.std(dim=1, keepdim=True).expand(-1, n_camera, -1)
344 |
345 | norm_x = (x-x_mean) / (x_std + self.epsilon)
346 | norm_y = (y-y_mean) / (y_std + self.epsilon)
347 |
348 | loss = (norm_x - norm_y).abs().mean(dim=1)
349 |
350 | return loss.mean()
351 |
352 |
353 | class Uncentered_L1_loss(nn.Module):
354 | """
355 |     Layer to compute L1 loss between coverage distributions normalized by their mean only (uncentered).
356 | """
357 |
358 | def __init__(self):
359 | super(Uncentered_L1_loss, self).__init__()
360 | self.epsilon = 1e-7
361 |
362 | def forward(self, x, y):
363 | """
364 |
365 | :param x: (Tensor) Should have shape (batch_size, n_camera, 1)
366 | :param y: (Tensor) Should have shape (batch_size, n_camera, 1)
367 |         :return: (Tensor) Scalar L1 loss averaged over the batch.
368 | """
369 | batch_size, n_camera = x.shape[0], x.shape[1]
370 | x_mean = x.mean(dim=1, keepdim=True).expand(-1, n_camera, -1)
371 | y_mean = y.mean(dim=1, keepdim=True).expand(-1, n_camera, -1)
372 |
373 | norm_x = x / (x_mean + self.epsilon)
374 | norm_y = y / (y_mean + self.epsilon)
375 |
376 | loss = (norm_x - norm_y).abs().mean(dim=1)
377 |
378 | return loss.mean()
379 |
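# Illustrative sketch (hypothetical tensors): the loss layers above all compare predicted and
# ground-truth coverage gain distributions of shape (batch_size, n_camera, 1). For instance:
#
#   loss_fn = L1_loss()
#   cov_pred, cov_truth = torch.rand(4, 52, 1), torch.rand(4, 52, 1)  # 4 clouds, 52 candidates
#   loss = loss_fn(cov_pred, cov_truth)  # scalar tensor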
--------------------------------------------------------------------------------
/SCONE/idr_torch.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | import os
5 | # import hostlist
6 |
7 | # Dummy script to avoid import errors
8 | rank = 0
9 | local_rank = 0
10 | size = 4
11 | cpus_per_task = 10
12 | hostnames = []
13 | gpus_ids = [0]
14 |
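# These placeholder values stand in for the SLURM-derived values in the commented block below,
# so that runs outside a SLURM job can still import this module.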
15 | # # get SLURM variables
16 | # rank = int(os.environ['SLURM_PROCID'])
17 | # local_rank = int(os.environ['SLURM_LOCALID'])
18 | # size = int(os.environ['SLURM_NTASKS'])
19 | # cpus_per_task = int(os.environ['SLURM_CPUS_PER_TASK'])
20 | #
21 | # # get node list from slurm
22 | # hostnames = hostlist.expand_hostlist(os.environ['SLURM_JOB_NODELIST'])
23 | #
24 | # # get IDs of reserved GPU
25 | # gpu_ids = os.environ['SLURM_STEP_GPUS'].split(",")
26 | #
27 | # # define MASTER_ADD & MASTER_PORT
28 | # os.environ['MASTER_ADDR'] = hostnames[0]
29 | # os.environ['MASTER_PORT'] = str(12345 + int(min(gpu_ids))) # to avoid port conflict on the same node
--------------------------------------------------------------------------------
/SCONE/pretrain_scone_occ.py:
--------------------------------------------------------------------------------
1 | # Updated from former script train_proba_model_faster.py
2 | from scone_utils import *
3 | import sys
4 |
5 | save_parameters = True
6 | start_training = False
7 |
8 | debug = False
9 |
10 | def save_train_params(save=False):
11 | # TO DO: for ddp, save only if is_master but load for everyone, with a synchronization in-between.
12 | params = {}
13 |
14 | # -----General parameters-----
15 | params["ddp"] = False
16 | params["jz"] = True
17 |
18 | if params["ddp"]:
19 | params["CUDA_VISIBLE_DEVICES"] = "0, 1"
20 | params["WORLD_SIZE"] = 2
21 |
22 | elif params["jz"]:
23 | params["WORLD_SIZE"] = idr_torch.size
24 |
25 | else:
26 | params["numGPU"] = 0
27 | params["WORLD_SIZE"] = 1
28 |
29 | params["anomaly_detection"] = True
30 | params["empty_cache_every_n_batch"] = 10
31 |
32 | # -----Ground truth computation parameters-----
33 | params["compute_gt_online"] = False
34 | params["compute_partial_point_cloud_online"] = False
35 |
36 | params["gt_surface_resolution"] = 1.5
37 |
38 | params["gt_max_diagonal"] = 1. # Formerly known as x_range
39 |
40 | params["n_points_surface"] = 16384 # To remove?
41 |
42 | # -----Model Parameters-----
43 | params["seq_len"] = 2048
44 | params["n_sample"] = 6000 # 12000, 100000
45 |
46 | params["view_state_n_elev"] = 7
47 | params["view_state_n_azim"] = 2 * 7
48 | params["harmonic_degree"] = 8
49 |
50 | params["min_occ"] = 0.01
51 |
52 | # Ablation study
53 | params["no_local_features"] = False # False
54 | params["no_view_harmonics"] = False # False
55 | # params["noise_std"] = 5 * np.sqrt(3) * params["side"]
56 |
57 | # -----General training parameters-----
58 | params["start_from_scratch"] = True
59 | params["pretrained_weights_name"] = None
60 |
61 | params["n_view_max"] = 5
62 | params["n_view_min"] = 1
63 | params["n_point_max_for_prediction"] = 300000 # 300000
64 |
65 | params["camera_dist"] = 1.5
66 | params["pole_cameras"] = True
67 | params["n_camera_elev"] = 5
68 | params["n_camera_azim"] = 2 * 5
69 | params["n_camera"] = params["n_camera_elev"] * params["n_camera_azim"]
70 | if params["pole_cameras"]:
71 | params["n_camera"] += 2
72 |
73 | params["prediction_in_random_camera_space"] = False
74 |
75 | params["total_batch_size"] = 12 # 12
76 | params["batch_size"] = params["total_batch_size"] // params["WORLD_SIZE"]
77 | params["total_batch_size"] = params["batch_size"] * params["WORLD_SIZE"]
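    # e.g. a requested total batch size of 12 with WORLD_SIZE=5 gives 12 // 5 = 2 samples per
    # GPU and an effective total batch size of 10.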
78 |
79 | params["epochs"] = 1000
80 | params["learning_rate"] = 1e-4
81 |
82 | params["schedule_learning_rate"] = True
83 | if params["schedule_learning_rate"]:
84 | params["lr_epochs"] = [250]
85 | params["lr_factor"] = 0.1
86 |
87 | params["warmup"] = 1000
88 | params["warmup_rate"] = 1 / (params["warmup"] * params["learning_rate"] ** 2)
89 |
90 | params["noam_opt"] = False
91 | params["training_loss"] = "mse" # "mse"
92 | params["multiply_loss"] = False
93 | if params["multiply_loss"]:
94 | params["loss_multiplication_factor"] = 10.
95 |
96 | params["random_seed"] = 42
97 | params["torch_seed"] = 5
98 |
99 | # -----Model name to save-----
100 | model_name = "model_scone_occ"
101 |
102 | if params["no_local_features"]:
103 | model_name += "_no_local_features"
104 |
105 | if params["no_view_harmonics"]:
106 | model_name += "_no_view_harmonics"
107 |
108 | if params["ddp"]:
109 | model_name = "ddp_" + model_name
110 | elif params["jz"]:
111 | model_name = "jz_" + model_name
112 |
113 | model_name += "_" + params["training_loss"]
114 |
115 | if params["noam_opt"]:
116 |         model_name += "_noam"
117 | model_name += "_warmup_" + str(params["warmup"])
118 |
119 | if params["schedule_learning_rate"]:
120 | model_name += "_schedule"
121 | model_name += "_lr_" + str(params["learning_rate"])
122 |
123 | if debug:
124 | model_name = "debug_" + model_name
125 |
126 | if params["prediction_in_random_camera_space"]:
127 | model_name += "_random_rot"
128 |
129 | params["scone_occ_model_name"] = model_name
130 |
131 | # -----Json name to save params-----
132 | json_name = "train_params_" + params["scone_occ_model_name"] + ".json"
133 |
134 | if save:
135 | with open(json_name, 'w') as outfile:
136 | json.dump(params, outfile)
137 |
138 |         print("Parameters saved in:")
139 | print(json_name)
140 |
141 | return json_name
142 |
143 |
144 | def loop(params,
145 | batch, mesh_dict,
146 | scone_occ, occ_loss_fn,
147 | device, is_master,
148 | n_views_list=None
149 | ):
150 | paths = mesh_dict['path']
151 |
152 | pred_occs = torch.zeros(0, 1, device=device)
153 | truth_occs = torch.zeros(0, 1, device=device)
154 |
155 | base_harmonics, h_polar, h_azim = get_all_harmonics_under_degree(params.harmonic_degree,
156 | params.view_state_n_elev,
157 | params.view_state_n_azim,
158 | device)
159 |
160 | batch_size = len(paths)
161 |
162 | # Loading, if provided, view sequences (useful for consistent validation)
163 | if n_views_list is None:
164 | n_views = np.random.randint(params.n_view_min, params.n_view_max + 1, batch_size)
165 | else:
166 | n_views = get_validation_n_view(params, n_views_list, batch, idr_torch.rank)
167 |
168 | for i in range(batch_size):
169 | # ----------Load input mesh and ground truth data---------------------------------------------------------------
170 |
171 | path_i = paths[i]
172 | # Loading info about partial point clouds and coverages
173 | part_pc, _ = get_gt_partial_point_clouds(path=path_i,
174 | normalization_factor=1./params.gt_surface_resolution,
175 | device=device)
176 | # Loading info about ground truth occupancy field
177 | X_world, occs = get_gt_occupancy_field(path=path_i, device=device)
178 |
179 | # ----------Set camera positions associated to partial point clouds---------------------------------------------
180 |
181 | # Positions are loaded in world coordinates
182 | X_cam_world, camera_dist, camera_elev, camera_azim = get_cameras_on_sphere(params, device,
183 | pole_cameras=params.pole_cameras)
184 |
185 | # ----------Select initial observations of the object-----------------------------------------------------------
186 |
187 | # Select a subset of n_view cameras to compute an initial point cloud
188 | n_view = n_views[i]
189 | view_idx = torch.randperm(len(camera_elev), device=device)[:n_view]
190 |
191 | # Select either first camera view space, or random camera view space as prediction view space
192 | if params.prediction_in_random_camera_space:
193 | prediction_cam_idx = np.random.randint(low=0, high=len(camera_elev))
194 | else:
195 | prediction_cam_idx = view_idx[0]
196 | prediction_box_center = torch.Tensor([0., 0., params.camera_dist]).to(device)
197 |
198 | # Move camera coordinates from world space to prediction view space, and normalize them for prediction box
199 | prediction_R, prediction_T = look_at_view_transform(dist=camera_dist[prediction_cam_idx],
200 | elev=camera_elev[prediction_cam_idx],
201 | azim=camera_azim[prediction_cam_idx],
202 | device=device)
203 | prediction_camera = FoVPerspectiveCameras(device=device, R=prediction_R, T=prediction_T)
204 | prediction_view_transform = prediction_camera.get_world_to_view_transform()
205 |
206 | X_cam = prediction_view_transform.transform_points(X_cam_world)
207 | X_cam = normalize_points_in_prediction_box(points=X_cam,
208 | prediction_box_center=prediction_box_center,
209 | prediction_box_diag=params.gt_max_diagonal)
210 | _, elev_cam, azim_cam = get_spherical_coords(X_cam)
211 |
212 | X_view = X_cam[view_idx]
213 |
214 | # ----------Capture initial observations------------------------------------------------------------------------
215 |
216 | # Points observed in initial views
217 | pc = torch.vstack([part_pc[pc_idx] for pc_idx in view_idx])
218 |
219 | # Downsampling partial point cloud
220 | pc = pc[torch.randperm(len(pc))[:n_view * params.seq_len]]
221 |
222 | # Move partial point cloud from world space to prediction view space, and normalize them in prediction box
223 | pc = prediction_view_transform.transform_points(pc)
224 | pc = normalize_points_in_prediction_box(points=pc,
225 | prediction_box_center=prediction_box_center,
226 | prediction_box_diag=params.gt_max_diagonal).view(1, -1, 3)
227 |
228 | # ----------Compute inputs to SconeOcc--------------------------------------------------------------------------
229 |
230 | # Sample random proxy points in space
231 | X_idx = torch.randperm(len(X_world))[:params.n_sample]
232 | X_world, occs = X_world[X_idx], occs[X_idx]
233 |
234 | # Move proxy points from world space to prediction view space, and normalize them in prediction box
235 | X = prediction_view_transform.transform_points(X_world)
236 | X = normalize_points_in_prediction_box(points=X,
237 | prediction_box_center=prediction_box_center,
238 | prediction_box_diag=params.gt_max_diagonal).view(1, params.n_sample, 3)
239 |
240 | # Compute view state vector and corresponding view harmonics
241 | view_state = compute_view_state(X, X_view,
242 | params.view_state_n_elev, params.view_state_n_azim)
243 | view_harmonics = compute_view_harmonics(view_state,
244 | base_harmonics, h_polar, h_azim,
245 | params.view_state_n_elev, params.view_state_n_azim)
246 | if params.no_view_harmonics:
247 | view_harmonics *= 0.
248 |
249 | # ----------Predict Occupancy Probability-----------------------------------------------------------------------
250 | pred_i = scone_occ(pc, X, view_harmonics).view(-1, 1)
251 | pred_occs = torch.vstack((pred_occs, pred_i))
252 |
253 | # ----------GT Occupancy Probability----------------------------------------------------------------------------
254 | truth_occs = torch.vstack((truth_occs, occs))
255 |
256 | # ----------Compute Loss--------------------------------------------------------------------------------------------
257 | loss = occ_loss_fn(pred_occs, truth_occs)
258 | if params.multiply_loss:
259 | loss *= params.loss_multiplication_factor
260 |
261 | if batch % params.empty_cache_every_n_batch == 0 and is_master:
262 | print("View state sum-mean:", torch.mean(torch.sum(view_state, dim=-1)))
263 |
264 | return loss, pred_occs, truth_occs, batch_size, n_view
265 |
266 |
267 | def train(params,
268 | dataloader,
269 | scone_occ, occ_loss_fn,
270 | optimizer,
271 | device, is_master,
272 | train_losses):
273 |
274 | num_batches = len(dataloader)
275 | size = num_batches * params.total_batch_size
276 | train_loss = 0.
277 |
278 | # Preparing information model
279 | scone_occ.train()
280 |
281 | t0 = time.time()
282 |
283 | for batch, (mesh_dict) in enumerate(dataloader):
284 |
285 | loss, pred, truth, batch_size, n_screen_cameras = loop(params,
286 | batch, mesh_dict,
287 | scone_occ, occ_loss_fn,
288 | device, is_master)
289 |
290 | # Backpropagation
291 | optimizer.zero_grad()
292 | loss.backward()
293 | optimizer.step()
294 |
295 | train_loss += loss.detach()
296 | if params.multiply_loss:
297 | train_loss /= params.loss_multiplication_factor
298 |
299 | if batch % params.empty_cache_every_n_batch == 0:
300 |
301 | # loss = reduce_tensor(loss)
302 | if params.ddp or params.jz:
303 | loss = reduce_tensor(loss, world_size=params.WORLD_SIZE)
304 | loss = to_python_float(loss)
305 |
306 | current = batch * batch_size # * idr_torch.size
307 | if params.ddp or params.jz:
308 | current *= params.WORLD_SIZE
309 |
310 | truth_norm = to_python_float(torch.linalg.norm(truth.detach()))
311 | pred_norm = to_python_float(torch.linalg.norm(pred.detach()))
312 |
313 | # torch.cuda.synchronize()
314 |
315 | if is_master:
316 | print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]",
317 | "computed in", (time.time() - t0) / 60., "minutes.")
318 | print(">>>Prediction shape:", pred.shape,
319 | "\n>>>Truth norm:", truth_norm, ">>>Prediction norm:", pred_norm,
320 | "\nNumber of cameras:", n_screen_cameras)
321 | # print("Harmonics:\n", info_harmonics_i)
322 | # TO REMOVE
323 | torch.cuda.empty_cache()
324 |
325 | if batch % params.empty_cache_every_n_batch == 0:
326 | t0 = time.time()
327 |
328 | # train_loss = reduce_tensor(train_loss)
329 | if params.ddp or params.jz:
330 | train_loss = reduce_tensor(train_loss, world_size=params.WORLD_SIZE)
331 | train_loss = to_python_float(train_loss)
332 | train_loss /= num_batches
333 | train_losses.append(train_loss)
334 |
335 |
336 | def validation(params,
337 | dataloader,
338 | occ_model, occ_loss_fn,
339 | device, is_master,
340 | val_losses):
341 |
342 | num_batches = len(dataloader)
343 | size = num_batches * params.total_batch_size
344 | val_loss = 0.
345 |
346 | # Preparing information model
347 | occ_model.eval()
348 |
349 | t0 = time.time()
350 |
351 | n_views_list = get_validation_n_views_list(params, dataloader)
352 |
353 | for batch, (mesh_dict) in enumerate(dataloader):
354 | with torch.no_grad():
355 | loss, pred, truth, batch_size, _ = loop(params,
356 | batch, mesh_dict,
357 | occ_model, occ_loss_fn,
358 | device, is_master,
359 | n_views_list=n_views_list)
360 |
361 | val_loss += loss.detach()
362 | if params.multiply_loss:
363 | val_loss /= params.loss_multiplication_factor
364 |
365 | if batch % params.empty_cache_every_n_batch == 0:
366 | torch.cuda.empty_cache()
367 |
368 | # val_loss = reduce_tensor(val_loss)
369 | if params.ddp or params.jz:
370 | val_loss = reduce_tensor(val_loss, world_size=params.WORLD_SIZE)
371 |
372 | val_loss = to_python_float(val_loss)
373 | val_loss /= num_batches
374 | val_losses.append(val_loss)
375 |
376 | if is_master:
377 | print(f"Validation Error: \n Avg loss: {val_loss:>8f} \n")
378 |
379 |
380 | def run(ddp_rank=None, params=None):
381 | # Set device
382 | device = setup_device(params, ddp_rank)
383 |
384 | batch_size = params.batch_size
385 | total_batch_size = params.total_batch_size
386 |
387 | if params.ddp:
388 | world_size = params.WORLD_SIZE
389 | rank = ddp_rank
390 | is_master = rank == 0
391 | elif params.jz:
392 | world_size = idr_torch.size
393 | rank = idr_torch.rank
394 | is_master = rank == 0
395 | else:
396 | world_size, rank = None, None
397 | is_master = True
398 |
399 | # Create dataloader
400 | train_dataloader, val_dataloader, _ = get_shapenet_dataloader(batch_size=params.batch_size,
401 | ddp=params.ddp, jz=params.jz,
402 | world_size=world_size, ddp_rank=rank,
403 | load_obj=False,
404 | data_path=None)
405 |
406 |
407 | # Initialize or Load models
408 | if params.no_local_features:
409 |         raise NotImplementedError("no_local_features mode is not implemented yet.")
410 | else:
411 | scone_occ = SconeOcc().to(device)
412 |
413 | # Initialize information model (and DDP wrap if needed)
414 | scone_occ, optimizer, opt_name, start_epoch, best_loss = initialize_scone_occ(params=params,
415 | scone_occ=scone_occ,
416 | device=device,
417 | torch_seed=params.torch_seed,
418 | load_pretrained_weights=not params.start_from_scratch,
419 | pretrained_weights_name=params.pretrained_weights_name,
420 | ddp_rank=rank)
421 |
422 | best_train_loss = 1000
423 | epochs_without_improvement = 0
424 | learning_rate = params.learning_rate
425 |
426 | # Set loss function
427 | occ_loss_fn = get_occ_loss_fn(params)
428 |
429 | if is_master:
430 | print("Model name:", params.scone_occ_model_name, "\nArchitecture:\n")
431 | print(scone_occ)
432 | print("Model name:", params.scone_occ_model_name)
433 |         print("Number of trainable parameters:", count_parameters(scone_occ))
434 | print("Using", opt_name, "optimizer.")
435 |
436 | if params.training_loss == "cross_entropy":
437 | print("Using soft cross entropy loss.")
438 | elif params.training_loss == "mse":
439 | print("Using MSE loss.")
440 |
441 |         print("Using", params.n_camera, "uniformly sampled camera positions per mesh.")
442 |
443 | print("Training data:", len(train_dataloader), "batches.")
444 | print("Validation data:", len(val_dataloader), "batches.")
445 | print("Batch size:", params.total_batch_size)
446 | print("Batch size per GPU:", params.batch_size)
447 |
448 | # Begin training process
449 | train_losses = []
450 | val_losses = []
451 | val_coverages = []
452 |
453 | t0 = time.time()
454 | for t_e in range(params.epochs):
455 | t = start_epoch + t_e
456 | if is_master:
457 | print(f"Epoch {t + 1}\n-------------------------------")
458 | torch.cuda.empty_cache()
459 |
460 | if params.schedule_learning_rate:
461 | if t in params.lr_epochs:
462 | print("Multiplying learning rate by", params.lr_factor)
463 | learning_rate *= params.lr_factor
464 |
465 | update_learning_rate(params, optimizer, learning_rate)
466 |
467 | print("Max learning rate set to", learning_rate)
468 | print("Current learning rate set to", optimizer._rate)
469 |
470 | train_dataloader.sampler.set_epoch(t)
471 |
472 | scone_occ.train()
473 | train(params,
474 | train_dataloader,
475 | scone_occ, occ_loss_fn,
476 | optimizer,
477 | device, is_master,
478 | train_losses)
479 |
480 | current_loss = train_losses[-1]
481 |
482 | if is_master:
483 | print("Training done for epoch", t + 1, ".")
484 | # torch.save(model, "unvalidated_" + model_name + ".pth")
485 | torch.save({
486 | 'epoch': t + 1,
487 | 'model_state_dict': scone_occ.state_dict(),
488 | 'optimizer_state_dict': optimizer.state_dict(),
489 | 'loss': current_loss,
490 | 'train_losses': train_losses,
491 | 'val_losses': val_losses,
492 | }, "unvalidated_" + params.scone_occ_model_name + ".pth")
493 |
494 | if current_loss < best_train_loss:
495 | torch.save({
496 | 'epoch': t + 1,
497 | 'model_state_dict': scone_occ.state_dict(),
498 | 'optimizer_state_dict': optimizer.state_dict(),
499 | 'loss': current_loss,
500 | 'train_losses': train_losses,
501 | 'val_losses': val_losses,
502 | }, "best_unval_" + params.scone_occ_model_name + ".pth")
503 | best_train_loss = current_loss
504 | print("Best model on training set saved with loss " + str(current_loss) + " .\n")
505 |
506 | torch.cuda.empty_cache()
507 |
508 | if is_master:
509 | print("Beginning evaluation on validation dataset...")
510 | # val_dataloader.sampler.set_epoch(t)
511 |
512 | scone_occ.eval()
513 | validation(params,
514 | val_dataloader,
515 | scone_occ, occ_loss_fn,
516 | device, is_master,
517 | val_losses
518 | )
519 |
520 | current_val_loss = val_losses[-1]
521 | if current_val_loss < best_loss:
522 | # torch.save(model, "validated_" + model_name + ".pth")
523 | if is_master:
524 | torch.save({
525 | 'epoch': t + 1,
526 | 'model_state_dict': scone_occ.state_dict(),
527 | 'optimizer_state_dict': optimizer.state_dict(),
528 | 'loss': current_val_loss,
529 | 'train_losses': train_losses,
530 | 'val_losses': val_losses
531 | }, "validated_" + params.scone_occ_model_name + ".pth")
532 | print("Model saved with loss " + str(current_val_loss) + " .\n")
533 | best_loss = val_losses[-1]
534 | epochs_without_improvement = 0
535 | else:
536 | epochs_without_improvement += 1
537 |
538 | # Save data about losses
539 | if is_master:
540 | losses_data = {}
541 | losses_data['train_loss'] = train_losses
542 | losses_data['val_loss'] = val_losses
543 | json_name = "losses_data_" + params.scone_occ_model_name + ".json"
544 | with open(json_name, 'w') as outfile:
545 | json.dump(losses_data, outfile)
546 | print("Saved data about losses in", json_name, ".")
547 |
548 | if is_master:
549 | print("Done in", (time.time() - t0) / 3600., "hours!")
550 |
551 | # Save data about losses
552 | losses_data = {}
553 | losses_data['train_loss'] = train_losses
554 | losses_data['val_loss'] = val_losses
555 | json_name = "losses_data_" + params.scone_occ_model_name + ".json"
556 | with open(json_name, 'w') as outfile:
557 | json.dump(losses_data, outfile)
558 | print("Saved data about losses in", json_name, ".")
559 |
560 | if params.ddp or params.jz:
561 | cleanup()
562 |
563 |
564 | if __name__ == "__main__":
565 | # Save and load parameters
566 | json_name = save_train_params(save=save_parameters)
567 |     print("Parameters stored in", json_name)
568 |
569 | if (not save_parameters) and (len(sys.argv) > 1):
570 | json_name = sys.argv[1]
571 | print("Using json name given in argument:")
572 | print(json_name)
573 |
574 | if start_training:
575 | params = load_params(json_name)
576 |
577 | if params.ddp:
578 | mp.spawn(run,
579 | args=(params,
580 | ),
581 | nprocs=params.WORLD_SIZE
582 | )
583 |
584 | elif params.jz:
585 | run(params=params)
586 |
587 | else:
588 | run(params=params)
589 |
--------------------------------------------------------------------------------
/SCONE/pretrain_scone_vis.py:
--------------------------------------------------------------------------------
1 | # Updated from former script train_discoveries_faster.py
2 |
3 | from scone_utils import *
4 | import sys
5 |
6 | save_parameters = True
7 | start_training = False
8 |
9 | debug = False
10 |
11 | def save_train_params(save=False):
12 | # TO DO: for ddp, save only if is_master but load for everyone, with a synchronization in-between.
13 | params = {}
14 |
15 | # -----General parameters-----
16 | params["ddp"] = False
17 | params["jz"] = True
18 |
19 | if params["ddp"]:
20 | params["CUDA_VISIBLE_DEVICES"] = "0, 1"
21 | params["WORLD_SIZE"] = 2
22 |
23 | elif params["jz"]:
24 | params["WORLD_SIZE"] = idr_torch.size
25 |
26 | else:
27 | params["numGPU"] = 0
28 | params["WORLD_SIZE"] = 1
29 |
30 | params["anomaly_detection"] = True
31 | params["empty_cache_every_n_batch"] = 10
32 |
33 | # -----Ground truth computation parameters-----
34 | params["compute_gt_online"] = False
35 | params["compute_partial_point_cloud_online"] = False
36 |
37 | params["gt_surface_resolution"] = 1.5
38 |
39 | params["gt_max_diagonal"] = 1. # Formerly known as x_range
40 |
41 | params["n_points_surface"] = 16384 # N points on GT surface
42 |
43 | params["surface_epsilon_is_constant"] = True
44 | if params["surface_epsilon_is_constant"]:
45 | params["surface_epsilon"] = 0.00707
46 |
47 | # -----SconeOcc Model Parameters-----
48 | # params["scone_occ_model_name"] = "best_unval_jz_model_scone_occ_mse_warmup_1000_schedule_lr_0.0001.pth"
49 | params["scone_occ_model_name"] = "best_unval_jz_model_scone_occ_mse_warmup_1000_schedule_lr_0.0001.pth"
50 | params["occ_no_view_harmonics"] = False
51 |
52 | params["n_view_max_for_scone_occ"] = 9
53 | params["max_points_per_scone_occ_pass"] = 300000
54 |
55 | # -----Model Parameters-----
56 | params["seq_len"] = 2048
57 | params["pts_dim"] = 4
58 |
59 | params["view_state_n_elev"] = 7
60 | params["view_state_n_azim"] = 2 * 7
61 | params["harmonic_degree"] = 8
62 |
63 | params["n_proxy_points"] = 100000 # 12000, 100000
64 | params["use_occ_to_sample_proxy_points"] = True # True
65 | params["min_occ_for_proxy_points"] = 0.1
66 |
67 | params["true_monte_carlo_sampling"] = True
68 |
69 | # -----Ablation study-----
70 | params["no_view_harmonics"] = False
71 | params["use_sigmoid"] = True
72 |
73 | # -----General training parameters-----
74 | params["start_from_scratch"] = True
75 | params["pretrained_weights_name"] = None
76 |
77 | params["n_view_max"] = 9
78 | params["n_view_min"] = 1
79 | params["filter_tol"] = 0.01
80 |
81 | params["camera_dist"] = 1.5
82 | params["pole_cameras"] = True
83 | params["n_camera_elev"] = 5
84 | params["n_camera_azim"] = 2 * 5
85 | params["n_camera"] = params["n_camera_elev"] * params["n_camera_azim"]
86 | if params["pole_cameras"]:
87 | params["n_camera"] += 2
88 |
89 | params["prediction_in_random_camera_space"] = False
90 |
91 | params["total_batch_size"] = 12 # 12
92 | params["batch_size"] = params["total_batch_size"] // params["WORLD_SIZE"]
93 | params["total_batch_size"] = params["batch_size"] * params["WORLD_SIZE"]
94 |
95 | params["epochs"] = 1000
96 | params["learning_rate"] = 1e-4
97 |
98 | params["schedule_learning_rate"] = True
99 | if params["schedule_learning_rate"]:
100 | params["lr_epochs"] = [179] # [179]
101 | params["lr_factor"] = 0.1
102 |
103 | params["warmup"] = 1000
104 | params["warmup_rate"] = 1 / (params["warmup"] * params["learning_rate"] ** 2)
105 |
106 | params["noam_opt"] = False
107 | params["training_metric"] = "surface_coverage_gain"
108 | # Training metric can be: "surface_coverage", "surface_coverage_gain", "absolute_coverage"
109 | params["training_loss"] = "uncentered_l1" # "kl_divergence", "l1", "uncentered_l1"
110 | params["multiply_loss"] = False
111 | if params["multiply_loss"]:
112 | params["loss_multiplication_factor"] = 10.
113 |
114 | params["nbv_validation"] = True
115 |
116 | params["random_seed"] = 42
117 | params["torch_seed"] = 5
118 |
119 | # -----Visibility Model name to save-----
120 | model_name = "model_scone_vis"
121 | model_name = model_name + "_" + params["training_metric"]
122 |
123 | if params["occ_no_view_harmonics"]:
124 | model_name += "_occ_no_vh"
125 |
126 | if params["no_view_harmonics"]:
127 | model_name += "_no_view_harmonics"
128 |
129 | if params["surface_epsilon_is_constant"]:
130 | model_name += "_constant_epsilon"
131 | else:
132 | model_name += "_adaptative_epsilon"
133 |
134 | if params["ddp"]:
135 | model_name = "ddp_" + model_name
136 | elif params["jz"]:
137 | model_name = "jz_" + model_name
138 |
139 | # model_name += "_seed" + str(params["random_seed"]) + "_" + str(params["torch_seed"])
140 |
141 | model_name += "_" + params["training_loss"]
142 |
143 | if params["noam_opt"]:
144 | model_name += "_noam"
145 | model_name += "_warmup_" + str(params["warmup"])
146 |
147 | if params["schedule_learning_rate"]:
148 | model_name += "_schedule"
149 | model_name += "_lr_" + str(params["learning_rate"])
150 |
151 | if params["use_sigmoid"]:
152 | model_name += "_sigmoid"
153 | else:
154 | model_name += "_relu"
155 |
156 | if debug:
157 | model_name = "debug_" + model_name
158 |
159 | if params["prediction_in_random_camera_space"]:
160 | model_name += "_random_rot"
161 |
162 | if params["true_monte_carlo_sampling"]:
163 | model_name += "_tmcs"
164 |
165 | params["scone_vis_model_name"] = model_name
166 |
167 | # -----Json name to save params-----
168 | json_name = "train_params_" + params["scone_vis_model_name"] + ".json"
169 |
170 | if save:
171 | with open(json_name, 'w') as outfile:
172 | json.dump(params, outfile)
173 |
174 |         print("Parameters saved in:")
175 | print(json_name)
176 |
177 | return json_name
178 |
179 |
180 | def loop(params,
181 | batch, mesh_dict,
182 | scone_occ, scone_vis, cov_loss_fn,
183 | device, is_master,
184 | n_views_list=None,
185 | optimal_sequences=None
186 | ):
187 | paths = mesh_dict['path']
188 |
189 | cov_pred = torch.zeros(0, params.n_camera, 1, device=device)
190 | cov_truth = torch.zeros(0, params.n_camera, 1, device=device)
191 |
192 | info_harmonics_i = None
193 |
194 | base_harmonics, h_polar, h_azim = get_all_harmonics_under_degree(params.harmonic_degree,
195 | params.view_state_n_elev,
196 | params.view_state_n_azim,
197 | device)
198 |
199 | batch_size = len(paths)
200 |
201 | if n_views_list is None:
202 | n_views = np.random.randint(params.n_view_min, params.n_view_max + 1, batch_size)
203 | else:
204 | n_views = get_validation_n_view(params, n_views_list, batch, idr_torch.rank)
205 |
206 | if batch == 0 and is_master:
207 | print("First batch:", mesh_dict['path'])
208 |
209 | for i in range(batch_size):
210 | # ----------Load input mesh and ground truth data---------------------------------------------------------------
211 |
212 | path_i = paths[i]
213 |
214 | # Loading info about partial point clouds and coverages
215 | part_pc, coverage = get_gt_partial_point_clouds(path=path_i,
216 | normalization_factor=1. / params.gt_surface_resolution,
217 | device=device)
218 |
219 | # Loading info about ground truth surface
220 | # gt_surface, surface_epsilon = get_gt_surface(params=params,
221 | # path=path_i,
222 | # normalization_factor=1./params.gt_surface_resolution,
223 | # device=device)
224 |
225 | # Initial dense sampling
226 | X_world = sample_X_in_box(x_range=params.gt_max_diagonal, n_sample=params.n_proxy_points, device=device)
227 |
228 | # ----------Set camera candidates for coverage prediction-------------------------------------------------------
229 | X_cam_world, camera_dist, camera_elev, camera_azim = get_cameras_on_sphere(params, device,
230 | pole_cameras=params.pole_cameras)
231 |
232 | # ----------Select initial observations of the object-----------------------------------------------------------
233 |
234 | # Select a subset of n_view cameras to compute an initial point cloud
235 | n_view = n_views[i]
236 | if optimal_sequences is None:
237 | view_idx = torch.randperm(len(camera_elev), device=device)[:n_view]
238 | else:
239 | optimal_seq, _ = get_optimal_sequence(optimal_sequences, path_i, n_view)
240 | view_idx = optimal_seq.to(device)
241 |
242 | # Select either first camera view space, or random camera view space as prediction view space
243 | if params.prediction_in_random_camera_space:
244 | prediction_cam_idx = np.random.randint(low=0, high=len(camera_elev))
245 | else:
246 | prediction_cam_idx = view_idx[0]
247 | prediction_box_center = torch.Tensor([0., 0., params.camera_dist]).to(device)
248 |
249 | # Move camera coordinates from world space to prediction view space, and normalize them for prediction box
250 | prediction_R, prediction_T = look_at_view_transform(dist=camera_dist[prediction_cam_idx],
251 | elev=camera_elev[prediction_cam_idx],
252 | azim=camera_azim[prediction_cam_idx],
253 | device=device)
254 | prediction_camera = FoVPerspectiveCameras(device=device, R=prediction_R, T=prediction_T)
255 | prediction_view_transform = prediction_camera.get_world_to_view_transform()
256 |
257 | X_cam = prediction_view_transform.transform_points(X_cam_world)
258 | X_cam = normalize_points_in_prediction_box(points=X_cam,
259 | prediction_box_center=prediction_box_center,
260 | prediction_box_diag=params.gt_max_diagonal)
261 | _, elev_cam, azim_cam = get_spherical_coords(X_cam)
262 |
263 | X_view = X_cam[view_idx]
264 | X_cam = X_cam.view(1, params.n_camera, 3)
265 |
266 | # ----------Capture initial observations------------------------------------------------------------------------
267 |
268 | # Points observed in initial views
269 | pc = torch.vstack([part_pc[pc_idx] for pc_idx in view_idx])
270 |
271 | # Downsampling partial point cloud
272 | pc = pc[torch.randperm(len(pc))[:n_view * params.seq_len]]
273 |
274 | # Move partial point cloud from world space to prediction view space, and normalize them in prediction box
275 | pc = prediction_view_transform.transform_points(pc)
276 | pc = normalize_points_in_prediction_box(points=pc,
277 | prediction_box_center=prediction_box_center,
278 | prediction_box_diag=params.gt_max_diagonal).view(1, -1, 3)
279 |
280 | # ----------Compute inputs to SconeVis-----------------------------------------------
281 |
282 | # Sample random proxy points in space
283 | X_idx = torch.randperm(len(X_world))[:params.n_proxy_points]
284 | X_world = X_world[X_idx]
285 |
286 | # Move proxy points from world space to prediction view space, and normalize them in prediction box
287 | X = prediction_view_transform.transform_points(X_world)
288 | X = normalize_points_in_prediction_box(points=X,
289 | prediction_box_center=prediction_box_center,
290 | prediction_box_diag=params.gt_max_diagonal
291 | )
292 |
293 | # Filter Proxy Points using pc shape from view cameras
294 | R_view, T_view = look_at_view_transform(eye=X_view,
295 | at=torch.zeros_like(X_view),
296 | device=device)
297 | view_cameras = FoVPerspectiveCameras(R=R_view, T=T_view, zfar=1000, device=device)
298 | X, _ = filter_proxy_points(view_cameras, X, pc.view(-1, 3), filter_tol=params.filter_tol)
299 | X = X.view(1, X.shape[0], 3)
300 |
301 | # Compute view state vector and corresponding view harmonics
302 | view_state = compute_view_state(X, X_view,
303 | params.view_state_n_elev, params.view_state_n_azim)
304 | view_harmonics = compute_view_harmonics(view_state,
305 | base_harmonics, h_polar, h_azim,
306 | params.view_state_n_elev, params.view_state_n_azim)
307 | occ_view_harmonics = 0. + view_harmonics
308 | if params.occ_no_view_harmonics:
309 | occ_view_harmonics *= 0.
310 | if params.no_view_harmonics:
311 | view_harmonics *= 0.
312 |
313 | # Compute occupancy probabilities
314 | with torch.no_grad():
315 | occ_prob_i = compute_occupancy_probability(scone_occ=scone_occ,
316 | pc=pc,
317 | X=X,
318 | view_harmonics=occ_view_harmonics,
319 | max_points_per_pass=params.max_points_per_scone_occ_pass
320 | ).view(-1, 1)
321 |
322 | proxy_points, view_harmonics, sample_idx = sample_proxy_points(X[0], occ_prob_i, view_harmonics.squeeze(dim=0),
323 | n_sample=params.seq_len,
324 | min_occ=params.min_occ_for_proxy_points,
325 | use_occ_to_sample=params.use_occ_to_sample_proxy_points,
326 | return_index=True)
327 |
328 | proxy_points = torch.unsqueeze(proxy_points, dim=0)
329 | view_harmonics = torch.unsqueeze(view_harmonics, dim=0)
330 |
331 | # ----------Predict Coverage Gains------------------------------------------------------------------------------
332 | visibility_gain_harmonics = scone_vis(proxy_points, view_harmonics=view_harmonics)
333 | if params.true_monte_carlo_sampling:
334 | proxy_points = torch.unsqueeze(proxy_points[0][sample_idx], dim=0)
335 | visibility_gain_harmonics = torch.unsqueeze(visibility_gain_harmonics[0][sample_idx], dim=0)
336 |
337 | if params.ddp or params.jz:
338 | cov_pred_i = scone_vis.module.compute_coverage_gain(proxy_points,
339 | visibility_gain_harmonics,
340 | X_cam)
341 | else:
342 | cov_pred_i = scone_vis.compute_coverage_gain(proxy_points,
343 | visibility_gain_harmonics,
344 | X_cam)
345 |
346 | cov_pred = torch.vstack((cov_pred, cov_pred_i.view(1, -1, 1)))
347 |
348 | # ----------Compute ground truth information scores----------
349 | cov_truth_i = compute_gt_coverage_gain_from_precomputed_matrices(coverage=coverage,
350 | initial_cam_idx=view_idx)
351 | cov_truth = torch.vstack((cov_truth, cov_truth_i.view(1, -1, 1)))
352 |
353 | # ----------Compute loss----------
354 | # cov_pred = cov_pred.view(-1, params.n_camera, 1)
355 | # cov_truth = cov_truth.view(-1, params.n_camera, 1)
356 | loss = cov_loss_fn(cov_pred, cov_truth)
357 |
358 | if batch % params.empty_cache_every_n_batch == 0 and is_master:
359 | print("View state sum-mean:", torch.mean(torch.sum(view_state, dim=-1)))
360 | print("Point cloud features shape:", proxy_points.shape)
361 | # print("Surface epsilon:", params.surface_epsilon)
362 |
363 | return loss, cov_pred, cov_truth, batch_size, n_view
364 |
365 |
366 | def train(params,
367 | dataloader,
368 | scone_occ,
369 | scone_vis, cov_loss_fn,
370 | optimizer,
371 | device, is_master,
372 | train_losses):
373 |
374 | num_batches = len(dataloader)
375 | size = num_batches * params.total_batch_size
376 | train_loss = 0.
377 |
378 | # Preparing information model
379 | scone_vis.train()
380 |
381 | t0 = time.time()
382 |
383 | for batch, (mesh_dict) in enumerate(dataloader):
384 |
385 | loss, cov_pred, cov_truth, batch_size, n_view = loop(params,
386 | batch, mesh_dict,
387 | scone_occ, scone_vis, cov_loss_fn,
388 | device, is_master,
389 | n_views_list=None,
390 | optimal_sequences=None)
391 |
392 | # Backpropagation
393 | optimizer.zero_grad()
394 | loss.backward()
395 | optimizer.step()
396 |
397 | train_loss += loss.detach()
398 | if params.multiply_loss:
399 | train_loss /= params.loss_multiplication_factor
400 |
401 | if batch % params.empty_cache_every_n_batch == 0:
402 |
403 | # loss = reduce_tensor(loss)
404 | if params.ddp or params.jz:
405 | loss = reduce_tensor(loss, world_size=params.WORLD_SIZE)
406 | loss = to_python_float(loss)
407 |
408 | current = batch * batch_size # * idr_torch.size
409 | if params.ddp or params.jz:
410 | current *= params.WORLD_SIZE
411 |
412 | truth_norm = to_python_float(torch.linalg.norm(cov_truth.detach()))
413 | pred_norm = to_python_float(torch.linalg.norm(cov_pred.detach()))
414 |
415 | # torch.cuda.synchronize()
416 |
417 | if is_master:
418 | print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]",
419 | "computed in", (time.time() - t0) / 60., "minutes.")
420 | print(">>>Prediction shape:", cov_pred.shape,
421 | "\n>>>Truth norm:", truth_norm, ">>>Prediction norm:", pred_norm,
422 | "\nNumber of cameras:", n_view, "+ 1.")
423 | # print("Harmonics:\n", info_harmonics_i)
424 | # TO REMOVE
425 | torch.cuda.empty_cache()
426 |
427 | if batch % params.empty_cache_every_n_batch == 0:
428 | t0 = time.time()
429 |
430 | # train_loss = reduce_tensor(train_loss)
431 | if params.ddp or params.jz:
432 | train_loss = reduce_tensor(train_loss, world_size=params.WORLD_SIZE)
433 | train_loss = to_python_float(train_loss)
434 | train_loss /= num_batches
435 | train_losses.append(train_loss)
436 |
437 |
438 | def validation(params,
439 | dataloader,
440 | scone_occ,
441 | scone_vis, cov_loss_fn,
442 | device, is_master,
443 | val_losses,
444 | nbv_validation=False,
445 | val_coverages=None):
446 |
447 | num_batches = len(dataloader)
448 | size = num_batches * params.total_batch_size
449 | val_loss = 0.
450 | val_coverage = 0.
451 |
452 | # Preparing information model
453 | scone_vis.eval()
454 |
455 | t0 = time.time()
456 |
457 | n_views_list = get_validation_n_views_list(params, dataloader)
458 | optimal_sequences = get_validation_optimal_sequences(jz=params.jz, device=device)
459 |
460 | for batch, (mesh_dict) in enumerate(dataloader):
461 | with torch.no_grad():
462 | loss, cov_pred, cov_truth, batch_size, _ = loop(params,
463 | batch, mesh_dict,
464 | scone_occ, scone_vis, cov_loss_fn,
465 | device, is_master,
466 | n_views_list=n_views_list,
467 | optimal_sequences=optimal_sequences)
468 |
469 | val_loss += loss.detach()
470 | if params.multiply_loss:
471 | val_loss /= params.loss_multiplication_factor
472 |
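        # NBV-style validation: for each object, pick the camera candidate with the highest
        # predicted coverage gain and accumulate the ground-truth coverage gain actually
        # achieved by that candidate.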
473 | if nbv_validation:
474 | with torch.no_grad():
475 | info_pred_scores = cov_pred.view(batch_size, params.n_camera, 1)
476 | info_pred_scores = torch.squeeze(info_pred_scores, dim=-1)
477 |
478 | info_truth_scores = cov_truth.view(batch_size, params.n_camera, 1)
479 | info_truth_scores = torch.squeeze(info_truth_scores, dim=-1)
480 |
481 | max_info_preds, max_info_idx = torch.max(info_pred_scores,
482 | dim=1)
483 |
484 | max_idx = max_info_idx.view(batch_size, 1)
485 | true_max_coverages = torch.gather(info_truth_scores,
486 | dim=1,
487 | index=max_idx)
488 |
489 | val_coverage += torch.sum(true_max_coverages).detach() / batch_size
490 |
491 | if batch % params.empty_cache_every_n_batch == 0:
492 | torch.cuda.empty_cache()
493 |
494 | # val_loss = reduce_tensor(val_loss)
495 | if params.ddp or params.jz:
496 | val_loss = reduce_tensor(val_loss, world_size=params.WORLD_SIZE)
497 | if nbv_validation:
498 | val_coverage = reduce_tensor(val_coverage, world_size=params.WORLD_SIZE)
499 |
500 | val_loss = to_python_float(val_loss)
501 | val_loss /= num_batches
502 | val_losses.append(val_loss)
503 |
504 | if nbv_validation:
505 | val_coverage = to_python_float(val_coverage)
506 | val_coverage /= num_batches
507 | if val_coverages is None:
508 |             raise ValueError("Variable val_coverages is set to None.")
509 | else:
510 | val_coverages.append(val_coverage)
511 |
512 | if is_master:
513 | print(f"Validation Error: \n Avg loss: {val_loss:>8f} \n")
514 | if nbv_validation:
515 | print(f"Avg nbv coverage: {val_coverage:>8f} \n")
516 |
517 |
518 | def run(ddp_rank=None, params=None):
519 | # Set device
520 | device = setup_device(params, ddp_rank)
521 |
522 | batch_size = params.batch_size
523 | total_batch_size = params.total_batch_size
524 |
525 | if params.ddp:
526 | world_size = params.WORLD_SIZE
527 | rank = ddp_rank
528 | is_master = rank == 0
529 | elif params.jz:
530 | world_size = idr_torch.size
531 | rank = idr_torch.rank
532 | is_master = rank == 0
533 | else:
534 | world_size, rank = None, None
535 | is_master = True
536 |
537 | # Create dataloader
538 | train_dataloader, val_dataloader, _ = get_shapenet_dataloader(batch_size=params.batch_size,
539 | ddp=params.ddp, jz=params.jz,
540 | world_size=world_size, ddp_rank=rank,
541 | load_obj=False,
542 | data_path=None)
543 |
544 | # Initialize or Load models
545 | scone_occ = load_scone_occ(params, params.scone_occ_model_name, ddp_model=True, device=device)
546 | scone_occ.eval()
547 |
548 | scone_vis = SconeVis(use_sigmoid=params.use_sigmoid).to(device)
549 | scone_vis, optimizer, opt_name, start_epoch, best_loss, best_coverage = initialize_scone_vis(params=params,
550 | scone_vis=scone_vis,
551 | device=device,
552 | torch_seed=params.torch_seed,
553 | load_pretrained_weights=not params.start_from_scratch,
554 | pretrained_weights_name=params.pretrained_weights_name,
555 | ddp_rank=rank)
556 |
557 | best_train_loss = 1000
558 | epochs_without_improvement = 0
559 | learning_rate = params.learning_rate
560 |
561 | # Set loss function
562 | cov_loss_fn = get_cov_loss_fn(params)
563 |
564 | if is_master:
565 | print("Model name:", params.scone_vis_model_name, "\nArchitecture:\n")
566 | print(scone_vis)
567 | print("Model name:", params.scone_vis_model_name)
568 |         print("Number of trainable parameters:", count_parameters(scone_vis))
569 | print("Using", opt_name, "optimizer.")
570 | print("Using occupancy model", params.scone_occ_model_name)
571 |
572 | if params.training_loss == "kl_divergence":
573 | print("Using softmax + KL Divergence loss.")
574 | elif params.training_loss == "mse":
575 | print("Using MSE loss.")
576 |
577 |         print("Using", params.n_camera, "uniformly sampled camera positions per mesh.")
578 |
579 | print("Training data:", len(train_dataloader), "batches.")
580 | print("Validation data:", len(val_dataloader), "batches.")
581 | print("Batch size:", params.total_batch_size)
582 | print("Batch size per GPU:", params.batch_size)
583 |
584 | # Begin training process
585 | train_losses = []
586 | val_losses = []
587 | val_coverages = []
588 |
589 | t0 = time.time()
590 | for t_e in range(params.epochs):
591 | t = start_epoch + t_e
592 | if is_master:
593 | print(f"Epoch {t + 1}\n-------------------------------")
594 | torch.cuda.empty_cache()
595 |
596 | # Update learning rate
597 | if params.schedule_learning_rate:
598 | if t in params.lr_epochs:
599 | print("Multiplying learning rate by", params.lr_factor)
600 | learning_rate *= params.lr_factor
601 |
602 | update_learning_rate(params, optimizer, learning_rate)
603 | print("Max learning rate set to", learning_rate)
604 | print("Current learning rate set to", optimizer._rate)
605 |
606 | train_dataloader.sampler.set_epoch(t)
607 |
608 | scone_vis.train()
609 | train(params,
610 | train_dataloader,
611 | scone_occ,
612 | scone_vis, cov_loss_fn,
613 | optimizer,
614 | device, is_master,
615 | train_losses
616 | )
617 |
618 | current_loss = train_losses[-1]
619 |
620 | if is_master:
621 | print("Training done for epoch", t + 1, ".")
622 | # torch.save(model, "unvalidated_" + model_name + ".pth")
623 | torch.save({
624 | 'epoch': t + 1,
625 | 'model_state_dict': scone_vis.state_dict(),
626 | 'optimizer_state_dict': optimizer.state_dict(),
627 | 'loss': current_loss,
628 | # 'coverage': current_val_coverage,
629 | 'train_losses': train_losses,
630 | 'val_losses': val_losses,
631 | }, "unvalidated_" + params.scone_vis_model_name + ".pth")
632 |
633 | if current_loss < best_train_loss:
634 | torch.save({
635 | 'epoch': t + 1,
636 | 'model_state_dict': scone_vis.state_dict(),
637 | 'optimizer_state_dict': optimizer.state_dict(),
638 | 'loss': current_loss,
639 | # 'coverage': current_val_coverage,
640 | 'train_losses': train_losses,
641 | 'val_losses': val_losses,
642 | }, "best_unval_" + params.scone_vis_model_name + ".pth")
643 | best_train_loss = current_loss
644 | print("Best model on training set saved with loss " + str(current_loss) + " .\n")
645 |
646 | torch.cuda.empty_cache()
647 |
648 | if is_master:
649 | print("Beginning evaluation on validation dataset...")
650 | # val_dataloader.sampler.set_epoch(t)
651 |
652 | scone_vis.eval()
653 | validation(params,
654 | val_dataloader,
655 | scone_occ,
656 | scone_vis, cov_loss_fn,
657 | device, is_master,
658 | val_losses,
659 | nbv_validation=params.nbv_validation,
660 | val_coverages=val_coverages
661 | )
662 |
663 | current_val_loss = val_losses[-1]
664 | current_val_coverage = val_coverages[-1]
665 | if current_val_loss < best_loss:
666 | # torch.save(model, "validated_" + model_name + ".pth")
667 | if is_master:
668 | torch.save({
669 | 'epoch': t + 1,
670 | 'model_state_dict': scone_vis.state_dict(),
671 | 'optimizer_state_dict': optimizer.state_dict(),
672 | 'loss': current_val_loss,
673 | 'coverage': current_val_coverage,
674 | 'train_losses': train_losses,
675 | 'val_losses': val_losses
676 | }, "validated_" + params.scone_vis_model_name + ".pth")
677 | print("Model saved with loss " + str(current_val_loss) + " .\n")
678 | best_loss = val_losses[-1]
679 | epochs_without_improvement = 0
680 | else:
681 | epochs_without_improvement += 1
682 |
683 | if is_master and current_val_coverage > best_coverage:
684 | # torch.save(model, "validated_" + model_name + ".pth")
685 | torch.save({
686 | 'epoch': t + 1,
687 | 'model_state_dict': scone_vis.state_dict(),
688 | 'optimizer_state_dict': optimizer.state_dict(),
689 | 'loss': current_val_loss,
690 | 'coverage': current_val_coverage,
691 | 'train_losses': train_losses,
692 | 'val_losses': val_losses
693 | }, "coverage_validated_" + params.scone_vis_model_name + ".pth")
694 | print("Model saved with coverage " + str(current_val_coverage) + " .\n")
695 | best_coverage = val_coverages[-1]
696 |
697 | # Save data about losses
698 | if is_master:
699 | losses_data = {}
700 | losses_data['train_loss'] = train_losses
701 | losses_data['val_loss'] = val_losses
702 | json_name = "losses_data_" + params.scone_vis_model_name + ".json"
703 | with open(json_name, 'w') as outfile:
704 | json.dump(losses_data, outfile)
705 | print("Saved data about losses in", json_name, ".")
706 |
707 | if is_master:
708 | print("Done in", (time.time() - t0) / 3600., "hours!")
709 |
710 | # Save data about losses
711 | losses_data = {}
712 | losses_data['train_loss'] = train_losses
713 | losses_data['val_loss'] = val_losses
714 | json_name = "losses_data_" + params.scone_vis_model_name + ".json"
715 | with open(json_name, 'w') as outfile:
716 | json.dump(losses_data, outfile)
717 | print("Saved data about losses in", json_name, ".")
718 |
719 | if params.ddp or params.jz:
720 | cleanup()
721 |
722 |
723 | if __name__ == "__main__":
724 | # Save and load parameters
725 | json_name = save_train_params(save=save_parameters)
726 |     print("Parameters stored in", json_name)
727 |
728 | if (not save_parameters) and (len(sys.argv) > 1):
729 | json_name = sys.argv[1]
730 | print("Using json name given in argument:")
731 | print(json_name)
732 |
733 | if start_training:
734 | params = load_params(json_name)
735 |
736 | if params.ddp:
737 | mp.spawn(run,
738 | args=(params,
739 | ),
740 | nprocs=params.WORLD_SIZE
741 | )
742 |
743 | elif params.jz:
744 | run(params=params)
745 |
746 | else:
747 | run(params=params)
--------------------------------------------------------------------------------
/SCONE/scone_utils.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 | from CustomGeometry import *
3 | from CustomDataset import CustomShapenetDataset
4 | from spherical_harmonics import clear_spherical_harmonics_cache
5 |
6 | import torch.distributed as dist
7 | from torch.utils.data.distributed import DistributedSampler
8 | import idr_torch
9 | from torch.nn.parallel import DistributedDataParallel as DDP
10 | from spherical_harmonics import get_spherical_harmonics
11 |
12 | from SconeOcc import SconeOcc
13 | from SconeVis import SconeVis, KLDivCE, L1_loss, Uncentered_L1_loss
14 |
15 |
16 | def setup_device(params, ddp_rank=None):
17 |
18 | if params.ddp:
19 | print("Setup device", str(ddp_rank), "for DDP training...")
20 | os.environ["CUDA_VISIBLE_DEVICES"] = params.CUDA_VISIBLE_DEVICES
21 |
22 | os.environ['MASTER_ADDR'] = 'localhost'
23 | os.environ['MASTER_PORT'] = '12355'
24 | os.environ['WORLD_SIZE'] = str(params.WORLD_SIZE)
25 | os.environ['RANK'] = str(ddp_rank)
26 |
27 | dist.init_process_group("nccl", rank=ddp_rank, world_size=params.WORLD_SIZE)
28 |
29 | device = torch.device("cuda:" + str(ddp_rank))
30 | torch.cuda.set_device(device)
31 |
32 | torch.cuda.empty_cache()
33 | print("Setup done!")
34 |
35 | if ddp_rank == 0:
36 | print(torch.cuda.memory_summary())
37 |
38 | elif params.jz:
39 | print("Setup device", str(idr_torch.rank), " for Jean Zay training...")
40 | dist.init_process_group(backend='nccl',
41 | init_method='env://',
42 | rank=idr_torch.rank,
43 | world_size=idr_torch.size)
44 |
45 | torch.cuda.set_device(idr_torch.local_rank)
46 | device = torch.device("cuda")
47 |
48 | torch.cuda.empty_cache()
49 | print("Setup done!")
50 |
51 | if idr_torch.rank == 0:
52 | print(torch.cuda.memory_summary())
53 |
54 | else:
55 | # Set our device:
56 | if torch.cuda.is_available():
57 | device = torch.device("cuda:" + str(params.numGPU))
58 | torch.cuda.set_device(device)
59 | else:
60 | device = torch.device("cpu")
61 | print(device)
62 |
63 | # Empty cache
64 | torch.cuda.empty_cache()
65 | print(torch.cuda.memory_summary())
66 |
67 | return device
68 |
69 |
70 | def load_params(json_name):
71 | return Params(json_name)
72 |
73 |
74 | def reduce_tensor(tensor: torch.Tensor, world_size):
75 |     """Average a tensor across all distributed processes (all-reduce sum divided by world_size)."""
76 | rt = tensor.clone()
77 | dist.all_reduce(rt, op=dist.ReduceOp.SUM)
78 | rt /= world_size
79 | return rt
80 |
81 |
82 | def to_python_float(t: torch.Tensor):
83 | if hasattr(t, 'item'):
84 | return t.item()
85 | else:
86 | return t[0]
87 |
88 |
89 | def cleanup():
90 | dist.destroy_process_group()
91 |
92 |
93 | def get_shapenet_dataloader(batch_size,
94 | ddp=False, jz=False,
95 | world_size=None, ddp_rank=None,
96 | test_novel=False,
97 | test_number=0,
98 | load_obj=False,
99 | data_path=None):
100 | # Database path
101 | # SHAPENET_PATH = "../../../../datasets/shapenet2/ShapeNetCore.v2"
102 | memory_threshold = 10e6
103 | if data_path is None:
104 | if jz:
105 | SHAPENET_PATH = "../../datasets/ShapeNetCore.v1"
106 | else:
107 | SHAPENET_PATH = "../../../../../../mnt/ssd/aguedon/ShapeNetCore.v1/ShapeNetCore.v1"
108 | SHAPENET_PATH = "../../../../datasets/ShapeNetCore.v1"
109 | # "../../datasets/ShapeNetCore.v1"
110 | # "../../../../../../mnt/ssd/aguedon/ShapeNetCore.v1/ShapeNetCore.v1"
111 | else:
112 | SHAPENET_PATH = data_path
113 |
114 | database_path = os.path.join(SHAPENET_PATH, "train_categories")
115 | train_json = os.path.join(SHAPENET_PATH, "train_list.json")
116 | val_json = os.path.join(SHAPENET_PATH, "val_list.json")
117 | if not test_novel:
118 | if test_number == 0:
119 | test_json = os.path.join(SHAPENET_PATH, "test_list.json")
120 | elif test_number == -1:
121 | test_json = os.path.join(SHAPENET_PATH, "all_test_list.json")
122 | else:
123 | test_json = os.path.join(SHAPENET_PATH, "test_list_" + str(test_number) + ".json")
124 | print("Using test split number " + str(test_number) + ".")
125 | else:
126 | database_path = os.path.join(SHAPENET_PATH, "test_categories")
127 | print("Using novel test split.")
128 | if test_number >= 0:
129 | test_json = os.path.join(SHAPENET_PATH, "test_novel_list.json")
130 | else:
131 | test_json = os.path.join(SHAPENET_PATH, "all_test_novel_list.json")
132 | # test_json = os.path.join(SHAPENET_PATH, "debug_list.json")
133 |
134 | train_dataset = CustomShapenetDataset(data_path=database_path,
135 | memory_threshold=memory_threshold,
136 | save_to_json=False,
137 | load_from_json=True,
138 | json_name=train_json,
139 | official_split=True,
140 | adjust_diagonally=True,
141 | load_obj=load_obj)
142 | val_dataset = CustomShapenetDataset(data_path=database_path,
143 | memory_threshold=memory_threshold,
144 | save_to_json=False,
145 | load_from_json=True,
146 | json_name=val_json,
147 | official_split=True,
148 | adjust_diagonally=True,
149 | load_obj=load_obj)
150 | test_dataset = CustomShapenetDataset(data_path=database_path,
151 | memory_threshold=memory_threshold,
152 | save_to_json=False,
153 | load_from_json=True,
154 | json_name=test_json,
155 | official_split=True,
156 | adjust_diagonally=True,
157 | load_obj=load_obj)
158 |
159 | if ddp or jz:
160 | if jz:
161 | rank = idr_torch.rank
162 | else:
163 | rank = ddp_rank
164 | train_sampler = DistributedSampler(train_dataset,
165 | num_replicas=world_size,
166 | rank=rank,
167 | drop_last=True)
168 | valid_sampler = DistributedSampler(val_dataset,
169 | num_replicas=world_size,
170 | rank=rank,
171 | shuffle=False,
172 | drop_last=True)
173 | test_sampler = DistributedSampler(test_dataset,
174 | num_replicas=world_size,
175 | rank=rank,
176 | shuffle=False,
177 | drop_last=True)
178 |
179 | train_dataloader = DataLoader(train_dataset,
180 | batch_size=batch_size,
181 | drop_last=True,
182 | collate_fn=collate_batched_meshes,
183 | sampler=train_sampler)
184 | validation_dataloader = DataLoader(val_dataset,
185 | batch_size=batch_size,
186 | drop_last=True,
187 | collate_fn=collate_batched_meshes,
188 | sampler=valid_sampler)
189 | test_dataloader = DataLoader(test_dataset,
190 | batch_size=batch_size,
191 | drop_last=True,
192 | collate_fn=collate_batched_meshes,
193 | sampler=test_sampler)
194 | else:
195 | train_dataloader = DataLoader(train_dataset,
196 | batch_size=batch_size,
197 | collate_fn=collate_batched_meshes,
198 | shuffle=True)
199 | validation_dataloader = DataLoader(val_dataset,
200 | batch_size=batch_size,
201 | collate_fn=collate_batched_meshes,
202 | shuffle=False)
203 | test_dataloader = DataLoader(test_dataset,
204 | batch_size=batch_size,
205 | collate_fn=collate_batched_meshes,
206 | shuffle=False)
207 |
208 | return train_dataloader, validation_dataloader, test_dataloader
209 |
210 |
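# Usage sketch (illustrative, not from the original code): building single-GPU
# dataloaders from a local copy of ShapeNetCore.v1. The path argument is a placeholder;
# the split JSON files and the "train_categories" folder are expected to follow the
# layout used in get_shapenet_dataloader above.
def _example_build_dataloaders(shapenet_path, batch_size=4):
    train_dl, val_dl, test_dl = get_shapenet_dataloader(batch_size=batch_size,
                                                        ddp=False, jz=False,
                                                        test_novel=False,
                                                        test_number=0,
                                                        load_obj=False,
                                                        data_path=shapenet_path)
    print("Train/val/test batches:", len(train_dl), len(val_dl), len(test_dl))
    return train_dl, val_dl, test_dl
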
211 | def get_optimizer(params, model):
212 | """
213 |     Returns an AdamW optimizer with a linear warmup of the learning rate at the beginning of training.
214 |
215 | :param params: (Params) Hyper parameters file.
216 | :param model: Model to be trained.
217 | :return: (Tuple) Optimizer and its name.
218 | """
219 | optimizer = WarmupConstantOpt(learning_rate=params.learning_rate,
220 | warmup=params.warmup,
221 | optimizer=torch.optim.AdamW(model.parameters(),
222 | lr=0
223 | )
224 | )
225 | opt_name = "WarmupAdamW"
226 |
227 | return optimizer, opt_name
228 |
229 |
230 | def update_learning_rate(params, optimizer, learning_rate):
231 |
232 | if params.noam_opt:
233 | optimizer.model_size = 1 / (optimizer.warmup * learning_rate ** 2)
234 | if optimizer._step == 0:
235 | optimizer._rate = 0
236 | else:
237 | optimizer._rate = (optimizer.model_size ** (-0.5)
238 | * min(optimizer._step ** (-0.5),
239 | optimizer._step * optimizer.warmup ** (-1.5)))
240 |
241 | else:
242 | optimizer.learning_rate = learning_rate
243 | if optimizer._step == 0:
244 | optimizer._rate = 0
245 | else:
246 | optimizer._rate = optimizer.learning_rate * min(1., optimizer._step / optimizer.warmup)
247 |
248 |
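# The constant schedule with warmup used above boils down to
# rate(step) = learning_rate * min(1, step / warmup), with rate(0) = 0.
# Below is a small, self-contained sketch of that rule (illustrative only),
# mirroring the non-Noam branch of update_learning_rate.
def _warmup_constant_rate_example(learning_rate, warmup, step):
    # The learning rate grows linearly during warmup, then stays constant.
    if step == 0:
        return 0.
    return learning_rate * min(1., step / warmup)
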
249 | def initialize_scone_occ(params, scone_occ, device,
250 | torch_seed=None,
251 | load_pretrained_weights=False,
252 | pretrained_weights_name=None,
253 | ddp_rank=None):
254 | """
255 | Initializes SCONE's occupancy probability prediction module for training.
256 | Can be initialized from scratch, or from an already trained model to resume training.
257 |
258 | :param params: (Params) Hyper parameters file.
259 | :param scone_occ: (SconeOcc) Occupancy probability prediction model.
260 | :param device: Device.
261 | :param torch_seed: (int) Seed used to initialize the network.
262 | :param load_pretrained_weights: (bool) If True, pretrained weights are loaded for initialization.
263 |     :param ddp_rank: Rank for DDP training.
264 |     :return: (Tuple) Initialized SconeOcc model, Optimizer, optimizer name, start epoch, best loss.
265 |     If training from scratch, start_epoch=0 and best_loss is initialized to a large value (1000.).
266 | """
267 | model_name = params.scone_occ_model_name
268 | start_epoch = 0
269 | best_loss = 1000.
270 |
271 | # Weight initialization process
272 | if load_pretrained_weights:
273 | # Load pretrained weights if needed
274 |         if pretrained_weights_name is None:
275 | weights_file = "unvalidated_" + model_name + ".pth"
276 | else:
277 | weights_file = pretrained_weights_name
278 | checkpoint = torch.load(weights_file, map_location=device)
279 | start_epoch = checkpoint['epoch'] + 1
280 | best_loss = checkpoint['loss']
281 |
282 | ddp_model = False
283 | if (model_name[:2] == "jz") or (model_name[:3] == "ddp"):
284 | ddp_model = True
285 |
286 | if ddp_model:
287 | scone_occ = load_ddp_state_dict(scone_occ, checkpoint['model_state_dict'])
288 | else:
289 | scone_occ.load_state_dict(checkpoint['model_state_dict'])
290 |
291 | else:
292 | # Else, applies a basic initialization process
293 | if torch_seed is not None:
294 | torch.manual_seed(torch_seed)
295 | print("Seed", torch_seed, "chosen.")
296 | scone_occ.apply(init_weights)
297 |
298 | # DDP wrapping if needed
299 | if params.ddp:
300 | scone_occ = DDP(scone_occ,
301 | device_ids=[ddp_rank])
302 | elif params.jz:
303 | scone_occ = DDP(scone_occ,
304 | device_ids=[idr_torch.local_rank])
305 |
306 | # Creating optimizer
307 | optimizer, opt_name = get_optimizer(params, scone_occ)
308 | if load_pretrained_weights:
309 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
310 |
311 | return scone_occ, optimizer, opt_name, start_epoch, best_loss
312 |
313 |
314 | def load_scone_occ(params, trained_model_name, ddp_model, device):
315 | """
316 | Loads an already trained occupancy probability prediction model for inference on a single GPU.
317 |
318 | :param params: (Params) Parameters file.
319 | :param trained_model_name: (str) Name of trained model's checkpoint.
320 | :param device: Device.
321 | :return: (SconeOcc) Occupancy probability prediction module with trained weights.
322 | """
323 | scone_occ = SconeOcc().to(device)
324 |
325 | model_name = params.scone_occ_model_name
326 |
327 | # Loads checkpoint
328 | checkpoint = torch.load(trained_model_name, map_location=device)
329 | print("Model name:", trained_model_name)
330 | print("Trained for", checkpoint['epoch'], 'epochs.')
331 | print("Training finished with loss", checkpoint['loss'])
332 |
333 | # Loads trained weights
334 | if ddp_model:
335 | scone_occ = load_ddp_state_dict(scone_occ, checkpoint['model_state_dict'])
336 | else:
337 | scone_occ.load_state_dict(checkpoint['model_state_dict'])
338 |
339 | return scone_occ
340 |
341 |
342 | def initialize_scone_vis(params, scone_vis, device,
343 | torch_seed=None,
344 | load_pretrained_weights=False,
345 | pretrained_weights_name=None,
346 | ddp_rank=None):
347 | """
348 | Initializes SCONE's visibility prediction module for training.
349 | Can be initialized from scratch, or from an already trained model to resume training.
350 |
351 | :param params: (Params) Hyper parameters file.
352 | :param scone_vis: (SconeVis) Visibility prediction model.
353 | :param device: Device.
354 | :param torch_seed: (int) Seed used to initialize the network.
355 | :param load_pretrained_weights: (bool) If True, pretrained weights are loaded for initialization.
356 |     :param ddp_rank: Rank for DDP training.
357 |     :return: (Tuple) Initialized SconeVis model, Optimizer, optimizer name, start epoch, best loss, best coverage.
358 |     If training from scratch, start_epoch=0, best_coverage=0. and best_loss is initialized to a large value (1000.).
359 | """
360 | model_name = params.scone_vis_model_name
361 | start_epoch = 0
362 | best_loss = 1000.
363 | best_coverage = 0.
364 |
365 | # Weight initialization process
366 | if load_pretrained_weights:
367 | # Load pretrained weights if needed
368 |         if pretrained_weights_name is None:
369 | weights_file = "unvalidated_" + model_name + ".pth"
370 | else:
371 | weights_file = pretrained_weights_name
372 | checkpoint = torch.load(weights_file, map_location=device)
373 | start_epoch = checkpoint['epoch'] + 1
374 | best_loss = checkpoint['loss']
375 | if 'coverage' in checkpoint:
376 | best_coverage = checkpoint['coverage']
377 |
378 | ddp_model = False
379 | if (model_name[:2] == "jz") or (model_name[:3] == "ddp"):
380 | ddp_model = True
381 |
382 | if ddp_model:
383 | scone_vis = load_ddp_state_dict(scone_vis, checkpoint['model_state_dict'])
384 | else:
385 | scone_vis.load_state_dict(checkpoint['model_state_dict'])
386 |
387 | else:
388 | # Else, applies a basic initialization process
389 | if torch_seed is not None:
390 | torch.manual_seed(torch_seed)
391 | print("Seed", torch_seed, "chosen.")
392 | scone_vis.apply(init_weights)
393 |
394 | # DDP wrapping if needed
395 | if params.ddp:
396 | scone_vis = DDP(scone_vis,
397 | device_ids=[ddp_rank])
398 | elif params.jz:
399 | scone_vis = DDP(scone_vis,
400 | device_ids=[idr_torch.local_rank])
401 |
402 | # Creating optimizer
403 | optimizer, opt_name = get_optimizer(params, scone_vis)
404 | if load_pretrained_weights:
405 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
406 |
407 | return scone_vis, optimizer, opt_name, start_epoch, best_loss, best_coverage
408 |
409 |
410 | def load_scone_vis(params, trained_model_name, ddp_model, device):
411 | """
412 | Loads an already trained visibility prediction model for inference on a single GPU.
413 |
414 | :param params: (Params) Parameters file.
415 | :param trained_model_name: (str) Name of trained model's checkpoint.
416 | :param device: Device.
417 | :return: (SconeVis) Visibility prediction module with trained weights.
418 | """
419 | scone_vis = SconeVis(use_sigmoid=params.use_sigmoid).to(device)
420 |
421 | model_name = params.scone_vis_model_name
422 |
423 | # Loads checkpoint
424 | checkpoint = torch.load(trained_model_name, map_location=device)
425 | print("Model name:", trained_model_name)
426 | print("Trained for", checkpoint['epoch'], 'epochs.')
427 | print("Training finished with loss", checkpoint['loss'])
428 |
429 | # Loads trained weights
430 | if ddp_model:
431 | scone_vis = load_ddp_state_dict(scone_vis, checkpoint['model_state_dict'])
432 | else:
433 | scone_vis.load_state_dict(checkpoint['model_state_dict'])
434 |
435 | return scone_vis
436 |
437 |
438 | def get_cov_loss_fn(params):
439 | if params.training_loss == "kl_divergence":
440 | cov_loss_fn = KLDivCE()
441 |
442 | elif params.training_loss == "l1":
443 | cov_loss_fn = L1_loss()
444 |
445 | elif params.training_loss == "uncentered_l1":
446 | cov_loss_fn = Uncentered_L1_loss()
447 |
448 | else:
449 |         raise NameError("Invalid training loss function. "
450 |                         "Please choose a valid loss among 'kl_divergence', 'l1' or 'uncentered_l1'.")
451 |
452 | return cov_loss_fn
453 |
454 |
455 | def get_occ_loss_fn(params):
456 | if params.training_loss == "mse":
457 |         occ_loss_fn = nn.MSELoss(reduction='mean')
458 | return occ_loss_fn
459 |
460 | else:
461 |         raise NameError("Invalid training loss function. "
462 | "Please choose a valid loss like 'mse'.")
463 |
464 |
465 | # -----Data Functions-----
466 |
467 | # TO CHANGE! Maybe Load on CPU for large scenes.
468 | def get_gt_partial_point_clouds(path, device, normalization_factor=None):
469 | """
470 | Loads ground truth partial point clouds for training.
471 |     :param path: Path to the mesh file; the tensors are loaded from the sibling "tensors" directory.
472 |     :param device: Device on which the tensors are loaded.
473 |     :param normalization_factor: factor to normalize the point cloud.
474 |     if None, the point cloud is not normalized.
475 |     :return: The list of partial point clouds and the stacked coverage tensor.
476 | """
477 | parent_dir = os.path.dirname(path)
478 | load_directory = os.path.join(parent_dir, "tensors")
479 |
480 | file_name = "partial_point_clouds.pt"
481 | pc_dict = torch.load(os.path.join(load_directory, file_name),
482 | map_location=device)
483 |
484 | part_pc = pc_dict['partial_point_cloud']
485 | coverage = torch.vstack(pc_dict['coverage'])
486 |
487 | if normalization_factor is not None:
488 | for i in range(len(part_pc)):
489 | part_pc[i] = normalization_factor * part_pc[i]
490 |
491 | return part_pc, coverage
492 |
493 |
494 | def get_gt_occupancy_field(path, device):
495 | """
496 | Loads ground truth occupancy field for training.
497 |     :param path: Path to the mesh file; the tensors are loaded from the sibling "tensors" directory.
498 |     :param device: Device on which the tensors are loaded.
499 |     :return: Query points X_world with shape (N, 3) and their occupancy values.
500 | """
501 | parent_dir = os.path.dirname(path)
502 | load_directory = os.path.join(parent_dir, "tensors")
503 |
504 | file_name = "occupancy_field.pt"
505 |
506 | pc_dict = torch.load(os.path.join(load_directory, file_name),
507 | map_location=device)
508 |
509 | X_world = pc_dict['occupancy_field'][..., :3]
510 | occs = pc_dict['occupancy_field'][..., 3:]
511 |
512 | return X_world, occs
513 |
514 |
515 | def get_gt_surface(params, path, device, normalization_factor=None):
516 | parent_dir = os.path.dirname(path)
517 | load_directory = os.path.join(parent_dir, "tensors")
518 | file_name = "surface_points.pt"
519 |
520 | surface_dict = torch.load(os.path.join(load_directory, file_name),
521 | map_location=device)
522 | gt_surface = surface_dict['surface_points']
523 |
524 | if params.surface_epsilon_is_constant:
525 | surface_epsilon = params.surface_epsilon
526 | else:
527 | surface_epsilon = surface_dict['epsilon']
528 |
529 | if normalization_factor is not None:
530 | gt_surface = gt_surface * normalization_factor
531 | surface_epsilon = surface_epsilon * normalization_factor
532 |
533 | return gt_surface, surface_epsilon
534 |
535 |
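# Illustrative sketch (not in the original code): loading all precomputed ground
# truth tensors of a single ShapeNet mesh. mesh_path points to the mesh file; the
# helpers above look for a sibling "tensors" directory next to it.
def _example_load_gt_data(params, mesh_path, device, normalization_factor=1.):
    part_pc, coverage = get_gt_partial_point_clouds(mesh_path, device,
                                                    normalization_factor=normalization_factor)
    X_world, occs = get_gt_occupancy_field(mesh_path, device)
    gt_surface, surface_epsilon = get_gt_surface(params, mesh_path, device,
                                                 normalization_factor=normalization_factor)
    return part_pc, coverage, X_world, occs, gt_surface, surface_epsilon
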
536 | def get_optimal_sequence(optimal_sequences, mesh_path, n_views):
537 | key = os.path.basename(os.path.dirname(mesh_path))
538 |
539 | # optimal_seq = optimal_sequences[key]['idx']
540 | optimal_seq = torch.Tensor(optimal_sequences[key]['idx']).long()
541 | seq_coverage = optimal_sequences[key]['coverage']
542 |
543 | return optimal_seq[:n_views], seq_coverage[:n_views]
544 |
545 |
546 | def compute_gt_coverage_gain_from_precomputed_matrices(coverage, initial_cam_idx):
547 | device = coverage.get_device()
548 |
549 | n_camera_candidates, n_points_surface = coverage.shape[0], coverage.shape[1]
550 |
551 | # Compute coverage matrix of previous cameras, and the corresponding value
552 | coverage_matrix = torch.sum(coverage[initial_cam_idx], dim=0).view(1, n_points_surface).expand(n_camera_candidates, -1)
553 | previous_coverage = torch.mean(torch.heaviside(coverage_matrix,
554 | values=torch.zeros_like(coverage_matrix,
555 | device=device)), dim=-1)
556 | # Compute coverage matrices of previous + new camera for every camera
557 | coverage_matrix = coverage_matrix + coverage
558 | coverage_matrix = torch.mean(torch.heaviside(coverage_matrix,
559 | values=torch.zeros_like(coverage_matrix,
560 | device=device)), dim=-1)
561 |
562 | # Compute coverage gain value
563 | coverage_matrix = coverage_matrix - previous_coverage
564 |
565 | return coverage_matrix.view(-1, 1)
566 |
567 |
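# Illustrative sketch (not part of the original code): given the precomputed coverage
# matrix of a mesh (one row per candidate camera, one column per surface point, on a
# CUDA device since the function above calls get_device()), this shows how the ground
# truth coverage gain of every candidate is obtained once a few cameras have been chosen.
def _example_gt_coverage_gain(coverage_matrix, n_initial_views=2):
    # Pick a few initial cameras at random, then score every candidate by its coverage gain.
    initial_cam_idx = torch.randperm(coverage_matrix.shape[0])[:n_initial_views]
    gains = compute_gt_coverage_gain_from_precomputed_matrices(coverage_matrix, initial_cam_idx)
    best_gain, best_cam = torch.max(gains.view(-1), dim=0)
    return best_cam, best_gain
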
568 | def compute_surface_coverage_from_cam_idx(coverage, cam_idx):
569 | device = coverage.get_device()
570 |
571 | coverage_matrix = torch.sum(coverage[cam_idx], dim=0)
572 |
573 | coverage = torch.mean(torch.heaviside(coverage_matrix,
574 | values=torch.zeros_like(coverage_matrix,
575 | device=device)), dim=-1)
576 |
577 | return coverage.view(1)
578 |
579 |
580 | def get_validation_n_views_list(params, dataloader):
581 | n_views = params.n_view_max - params.n_view_min + 1
582 |
583 | n_views_list = np.repeat(np.arange(start=params.n_view_min,
584 | stop=params.n_view_max + 1).reshape(1, n_views),
585 | len(dataloader.dataset) // n_views + 1, axis=0).reshape(-1)
586 |
587 | return n_views_list
588 |
589 |
590 | def get_validation_n_view(params, n_views_list, batch, rank):
591 | idx = batch * params.total_batch_size + rank * params.batch_size
592 |
593 | return n_views_list[idx:idx + params.batch_size]
594 |
595 |
596 | def get_validation_optimal_sequences(jz, device):
597 | if jz:
598 | SHAPENET_PATH = "../../datasets/ShapeNetCore.v1"
599 | else:
600 | SHAPENET_PATH = "../../../../../../mnt/ssd/aguedon/ShapeNetCore.v1/ShapeNetCore.v1"
601 |
602 | file_name = "validation_optimal_trajectories.pt"
603 | optimal_sequences = torch.load(os.path.join(SHAPENET_PATH, file_name),
604 | map_location=device)
605 |
606 | return optimal_sequences
607 |
608 |
609 | def get_all_harmonics_under_degree(degree, n_elev, n_azim, device):
610 | """
611 | Gets values for all harmonics with l < degree.
612 |     :param degree: (int) All harmonics with degree l < degree are computed.
613 |     :param n_elev: (int) Number of elevation values used to discretize the sphere.
614 |     :param n_azim: (int) Number of azimuth values used to discretize the sphere.
615 |     :param device: Device on which the tensors are created.
616 |     :return: A tuple (harmonics values with shape (n_harmonics, n_elev * n_azim), polar angles, azimuth angles).
617 | """
618 | h_elev = torch.Tensor(
619 | [-np.pi / 2 + (i + 1) / (n_elev + 1) * np.pi for i in range(n_elev) for j in range(n_azim)]).to(device)
620 | h_polar = -h_elev + np.pi / 2
621 |
622 | h_azim = torch.Tensor([2 * np.pi * j / n_azim for i in range(n_elev) for j in range(n_azim)]).to(device)
623 |
624 | z = torch.zeros([i for i in h_polar.shape] + [0], device=h_polar.get_device())
625 |
626 | clear_spherical_harmonics_cache()
627 | for l in range(degree):
628 | y = get_spherical_harmonics(l, h_polar, h_azim)
629 | z = torch.cat((z, y), dim=-1)
630 |
631 | z = z.transpose(dim0=0, dim1=1)
632 |
633 | return z, h_polar, h_azim
634 |
635 |
636 | def get_cameras_on_sphere(params, device, pole_cameras=False, n_elev=None, n_azim=None):
637 | """
638 | Returns cameras candidate positions, sampled on a sphere.
639 | Made for SCONE pretraining on ShapeNet.
640 | :param params: (Params) The dictionary of parameters.
641 | :param device:
642 | :return: A tuple of Tensors (X_cam, candidate_dist, candidate_elev, candidate_azim)
643 |     X_cam has shape (n_camera_candidates, 3).
644 |     All other tensors have shape (n_camera_candidates,).
645 | """
646 | if n_elev is None or n_azim is None:
647 | n_elev = params.n_camera_elev
648 | n_azim = params.n_camera_azim
649 | n_camera = params.n_camera
650 | else:
651 | n_camera = n_elev * n_azim
652 | if pole_cameras:
653 | n_camera += 2
654 |
655 | candidate_dist = torch.Tensor([params.camera_dist for i in range(n_camera)]).to(device)
656 |
657 | candidate_elev = [-90. + (i + 1) / (n_elev + 1) * 180.
658 | for i in range(n_elev)
659 | for j in range(n_azim)]
660 |
661 | candidate_azim = [360. * j / n_azim
662 | for i in range(n_elev)
663 | for j in range(n_azim)]
664 |
665 | if pole_cameras:
666 | candidate_elev = [-89.9] + candidate_elev + [89.9]
667 | candidate_azim = [0.] + candidate_azim + [0.]
668 |
669 | candidate_elev = torch.Tensor(candidate_elev).to(device)
670 | candidate_azim = torch.Tensor(candidate_azim).to(device)
671 |
672 | X_cam = get_cartesian_coords(r=candidate_dist.view(-1, 1),
673 | elev=candidate_elev.view(-1, 1),
674 | azim=candidate_azim.view(-1, 1),
675 | in_degrees=True)
676 |
677 | return X_cam, candidate_dist, candidate_elev, candidate_azim
678 |
679 |
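# Illustrative sketch (not from the original code): sampling candidate camera positions
# on a sphere without a full Params object. Only the camera_dist field is needed when
# n_elev and n_azim are passed explicitly, so a SimpleNamespace stands in for it here.
def _example_candidate_cameras(device, camera_dist=1.5, n_elev=5, n_azim=10):
    from types import SimpleNamespace
    fake_params = SimpleNamespace(camera_dist=camera_dist)
    X_cam, dists, elevs, azims = get_cameras_on_sphere(fake_params, device,
                                                       pole_cameras=True,
                                                       n_elev=n_elev, n_azim=n_azim)
    # With pole cameras, n_elev * n_azim + 2 candidate positions are returned.
    return X_cam, dists, elevs, azims
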
680 | def normalize_points_in_prediction_box(points, prediction_box_center, prediction_box_diag):
681 | """
682 |     Normalizes points with respect to the prediction bounding box.
683 |     :param points: (Tensor) Points to normalize.
684 |     :param prediction_box_center: (Tensor) Center of the prediction box.
685 |     :param prediction_box_diag: (float) Length of the prediction box diagonal.
686 |     :return: Points centered on the box center and divided by the box diagonal.
687 | """
688 | return (points - prediction_box_center) / prediction_box_diag
689 |
690 |
691 | def compute_view_state(pts, X_view, n_elev, n_azim):
692 | """
693 | Computes view_state vector for points pts and camera positions X_view.
694 | :param pts: Tensor with shape (n_cloud, seq_len, pts_dim) where pts_dim >= 3.
695 | :param X_view: Tensor with shape (n_screen_cameras, 3).
696 | Represents camera positions in prediction camera space coordinates.
697 | :param n_elev: Integer. Number of elevations values to discretize view states.
698 | :param n_azim: Integer. Number of azimuth values to discretize view states
699 | :return: A Tensor with shape (n_cloud, seq_len, n_elev*n_azim).
700 | """
701 | # Initializing variables
702 | device = pts.get_device()
703 | n_view = len(X_view)
704 | n_clouds, seq_len, _ = pts.shape
705 | n_candidates = n_elev * n_azim
706 |
707 | elev_step = np.pi / (n_elev + 1)
708 | azim_step = 2 * np.pi / n_azim
709 |
710 | X_pts = pts[..., :3]
711 |
712 | # Computing camera elev and azim in every pts space coordinates
713 | rays = X_view.view(1, 1, n_view, 3).expand(n_clouds, seq_len, -1, -1) \
714 | - X_pts.view(n_clouds, seq_len, 1, 3).expand(-1, -1, n_view, -1)
715 |
716 | _, ray_elev, ray_azim = get_spherical_coords(rays.view(-1, 3))
717 |
718 | ray_elev = ray_elev.view(n_clouds, seq_len, n_view)
719 | ray_azim = ray_azim.view(n_clouds, seq_len, n_view)
720 |
721 | # Projecting elev and azim to the closest values in the discretized cameras
722 | idx_elev = floor_divide(ray_elev, elev_step)
723 | idx_azim = floor_divide(ray_azim, azim_step)
724 |
725 |     # If closer to the ceiling than to the floor, add 1
726 | idx_elev[ray_elev % elev_step > elev_step / 2.] += 1
727 | idx_azim[ray_azim % azim_step > azim_step / 2.] += 1
728 |
729 | # Elevation can't be below minimal or above maximal values
730 | idx_elev[idx_elev >= n_elev] = n_elev - 1
731 | idx_elev[idx_elev < -n_elev // 2] = -n_elev // 2
732 |
733 | # If azimuth is greater than 180 degrees, we reset it back to -180 degrees
734 | idx_azim[idx_azim > n_azim // 2] = -n_azim // 2
735 |
736 | # Normalizing indices to retrieve camera positions in flattened view_state
737 | idx_elev += n_elev // 2
738 | idx_azim[idx_azim < 0] += n_azim
739 |
740 | indices = idx_elev.long() * n_azim + idx_azim.long()
741 | indices %= n_candidates
742 | q = torch.arange(start=0, end=n_clouds * seq_len, step=1, device=device).view(-1, 1).expand(-1, n_view)
743 |
744 | flat_indices = indices.view(-1, n_view)
745 | flat_indices = q * n_candidates + flat_indices
746 | flat_indices = flat_indices.view(-1)
747 |
748 | # Compute view_state and set visited camera values to 1
749 | view_state = torch.zeros(n_clouds, seq_len, n_candidates, device=device)
750 | view_state.view(-1)[flat_indices] = 1.
751 |
752 | return view_state
753 |
754 |
755 | def move_view_state_to_view_space(view_state, fov_camera, n_elev, n_azim):
756 | """
757 | "Rotate" the view state vectors to the corresponding view space.
758 |
759 | :param view_state: (Tensor) View state tensor with shape (n_cloud, seq_len, n_elev * n_azim)
760 | :param fov_camera: (FoVPerspectiveCamera)
761 | :param n_elev: (int)
762 | :param n_azim: (int)
763 | :return: Rotated view state tensor with shape (n_cloud, seq_len, n_elev * n_azim)
764 | """
765 | device = view_state.get_device()
766 | n_clouds = view_state.shape[0]
767 | seq_len = view_state.shape[1]
768 |
769 | n_view = n_elev * n_azim
770 |
771 | candidate_dist = torch.Tensor([1. for i in range(n_elev * n_azim)]).to(device)
772 |
773 | candidate_elev = [-90. + (i + 1) / (n_elev + 1) * 180.
774 | for i in range(n_elev)
775 | for j in range(n_azim)]
776 |
777 | candidate_azim = [360. * j / n_azim
778 | for i in range(n_elev)
779 | for j in range(n_azim)]
780 |
781 | candidate_elev = torch.Tensor(candidate_elev).to(device)
782 | candidate_azim = torch.Tensor(candidate_azim).to(device)
783 |
784 | X_cam_ref = get_cartesian_coords(r=candidate_dist.view(-1, 1),
785 | elev=candidate_elev.view(-1, 1),
786 | azim=candidate_azim.view(-1, 1),
787 | in_degrees=True)
788 | X_cam_inv = fov_camera.get_world_to_view_transform().inverse().transform_points(
789 | X_cam_ref) - fov_camera.get_camera_center()
790 |
791 | elev_step = np.pi / (n_elev + 1)
792 | azim_step = 2 * np.pi / n_azim
793 |
794 | _, ray_elev, ray_azim = get_spherical_coords(X_cam_inv.view(-1, 3))
795 |
796 | ray_elev = ray_elev.view(n_view)
797 | ray_azim = ray_azim.view(n_view)
798 |
799 | # Projecting elev and azim to the closest values in the discretized cameras
800 | idx_elev = floor_divide(ray_elev, elev_step)
801 | idx_azim = floor_divide(ray_azim, azim_step)
802 |
803 |     # If closer to the ceiling than to the floor, add 1
804 | idx_elev[ray_elev % elev_step > elev_step / 2.] += 1
805 | idx_azim[ray_azim % azim_step > azim_step / 2.] += 1
806 |
807 | # Elevation can't be below minimal or above maximal values
808 | idx_elev[idx_elev > n_elev // 2] = n_elev // 2
809 | idx_elev[idx_elev < -(n_elev // 2)] = -(n_elev // 2)
810 |
811 | # If azimuth is greater than 180 degrees, we reset it back to -180 degrees
812 | idx_azim[idx_azim > n_azim // 2] = -(n_azim // 2)
813 |
814 | # Normalizing indices to retrieve camera positions in flattened view_state
815 | idx_elev += n_elev // 2
816 | idx_azim[idx_azim < 0] += n_azim
817 |
818 | indices = idx_elev.long() * n_azim + idx_azim.long()
819 |
820 | rot_view_state = torch.gather(input=view_state, dim=2, index=indices.view(1, 1, -1).expand(n_clouds, seq_len, -1))
821 |
822 | return rot_view_state
823 |
824 |
825 |
826 | def compute_view_harmonics(view_state, base_harmonics, h_polar, h_azim, n_elev, n_azim):
827 | """
828 | Computes spherical harmonics corresponding to the view_state vector.
829 | :param view_state: Tensor with shape (n_cloud, seq_len, n_elev*n_azim).
830 | :param base_harmonics: Tensor with shape (n_harmonics, n_elev*n_azim).
831 | :param h_polar:
832 | :param h_azim:
833 | :param n_elev:
834 | :param n_azim:
835 | :return: Tensor with shape (n_cloud, seq_len, n_harmonics)
836 | """
837 | # Define parameters
838 | n_harmonics = base_harmonics.shape[0]
839 | n_clouds, seq_len, n_values = view_state.shape
840 |
841 | polar_step = np.pi / (n_elev + 1)
842 | azim_step = 2 * np.pi / n_azim
843 |
844 | # Expanding variables to parallelize computation
845 | all_values = view_state.view(n_clouds, seq_len, 1, n_values).expand(-1, -1, n_harmonics, -1)
846 | all_polar = h_polar.view(1, 1, 1, n_values).expand(n_clouds, seq_len, n_harmonics, -1)
847 | # all_harmonics = base_harmonics.view(1, 1, n_harmonics, n_values).expand(n_clouds, seq_len, -1, -1)
848 |
849 | # Computing spherical L2-dot product on last axis
850 | coordinates = torch.sum(all_values * base_harmonics * torch.sin(all_polar) * polar_step * azim_step, dim=-1)
851 |
852 | return coordinates
853 |
854 |
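# Illustrative sketch (not part of the original code): computing the view-state vectors
# of a point cloud for a set of already-visited camera positions, then projecting them
# onto the spherical harmonics basis from get_all_harmonics_under_degree. pts is expected
# to be a CUDA tensor with shape (n_cloud, seq_len, 3) and X_view a tensor with shape
# (n_view, 3), following the docstrings above.
def _example_view_harmonics(pts, X_view, device, degree=8, n_elev=7, n_azim=14):
    base_harmonics, h_polar, h_azim = get_all_harmonics_under_degree(degree, n_elev, n_azim, device)
    view_state = compute_view_state(pts, X_view, n_elev, n_azim)
    view_harmonics = compute_view_harmonics(view_state, base_harmonics, h_polar, h_azim, n_elev, n_azim)
    return view_harmonics  # shape (n_cloud, seq_len, n_harmonics)
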
855 | # ----------Model functions----------
856 |
857 | def compute_occupancy_probability(scone_occ, pc, X, view_harmonics, mask=None,
858 | max_points_per_pass=20000):
859 | """
860 |     Computes occupancy probabilities for the query points X, processing them in chunks of at most max_points_per_pass points.
861 |     :param scone_occ: (SconeOcc) SCONE's occupancy probability prediction model.
862 | :param pc: (Tensor) Input point cloud tensor with shape (n_clouds, seq_len, pts_dim)
863 | :param X: (Tensor) Input query points tensor with shape (n_clouds, n_sample, x_dim)
864 |     :param view_harmonics: (Tensor) View state harmonic features of the query points. Tensor with shape (n_clouds, n_sample, n_harmonics)
865 |     :param max_points_per_pass: (int) Maximal number of query points processed in a single forward pass.
866 |     :return: (Tensor) Predicted occupancy probabilities with shape (n_clouds, n_sample, 1).
867 | """
868 | n_clouds, seq_len, pts_dim = pc.shape[0], pc.shape[1], pc.shape[2]
869 | n_sample, x_dim = X.shape[1], X.shape[2]
870 | n_harmonics = view_harmonics.shape[2]
871 |
872 | preds = torch.zeros(n_clouds, 0, 1).to(X.get_device())
873 |
874 | p = max_points_per_pass // n_clouds
875 | q = n_sample // p
876 | r = n_sample % p
877 | n_loop = q
878 | if r != 0:
879 | n_loop += 1
880 |
881 | for i in range(n_loop):
882 | low_idx = i * p
883 | up_idx = (i + 1) * p
884 | if i == q:
885 | up_idx = q * p + r
886 | preds_i = scone_occ(pc, X[:, low_idx:up_idx], view_harmonics[:, low_idx:up_idx])
887 | preds_i = preds_i.view(n_clouds, up_idx - low_idx, -1)
888 | preds = torch.cat((preds, preds_i), dim=1)
889 |
890 | return preds
891 |
892 |
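# Usage sketch (illustrative, not from the original code): running the occupancy module
# on a large set of query points. compute_occupancy_probability splits the queries into
# chunks of at most max_points_per_pass points, so the whole field never has to fit in a
# single forward pass.
def _example_occupancy_field(scone_occ, pc, X, view_harmonics):
    with torch.no_grad():
        occ_probs = compute_occupancy_probability(scone_occ, pc, X, view_harmonics,
                                                  max_points_per_pass=20000)
    return occ_probs  # shape (n_clouds, n_sample, 1)
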
893 | def filter_proxy_points(view_cameras, X, pc, filter_tol=0.01):
894 | """
895 | Filter proxy points considering camera field of view and partial surface point cloud.
896 | WARNING: Works for a single scene! So X must have shape (n_proxy_points, 3)!
897 |     :param view_cameras: Cameras corresponding to the views already captured.
898 |     :param X: (Tensor) Proxy points tensor with shape (n_proxy_points, 3).
899 |     :param pc: (Tensor) Partial surface point cloud with shape (N, 3).
900 |     :param filter_tol: (float) Tolerance added around the projected bounding box of pc.
901 |     :return: The filtered proxy points and the corresponding boolean mask.
902 | """
903 |
904 | n_view = view_cameras.R.shape[0]
905 |
906 | if (len(X.shape) != 2) or (len(pc.shape) != 2):
907 | raise NameError("Wrong shapes! X must have shape (n_proxy_points, 3) and pc must have shape (N, 3).")
908 |
909 | view_projection_transform = view_cameras.get_full_projection_transform()
910 | X_proj = view_projection_transform.transform_points(X)[..., :2].view(n_view, -1, 2)
911 | pc_proj = view_projection_transform.transform_points(pc)[..., :2].view(n_view, -1, 2)
912 |
913 | max_proj = torch.max(pc_proj, dim=-2, keepdim=True)[0].expand(-1, X_proj.shape[-2], -1)
914 | min_proj = torch.min(pc_proj, dim=-2, keepdim=True)[0].expand(-1, X_proj.shape[-2], -1)
915 |
916 | filter_mask = torch.prod((X_proj < max_proj + filter_tol) * (X_proj > min_proj - filter_tol), dim=0)
917 | filter_mask = torch.prod(filter_mask, dim=-1).bool()
918 |
919 | return X[filter_mask], filter_mask
920 |
921 |
922 | def sample_proxy_points(X_world, preds, view_harmonics, n_sample, min_occ, use_occ_to_sample=True,
923 | return_index=False):
924 | """
925 |     Samples proxy points among the candidates in X_world, keeping only points with predicted occupancy above min_occ.
926 |     :param X_world: Tensor with shape (n_points, 3)
927 |     :param preds: Tensor with shape (n_points, 1)
928 |     :param view_harmonics: Tensor with shape (n_points, n_harmonics)
929 |     :param n_sample: (int) Maximal number of proxy points to sample.
930 |     :return: Sampled points concatenated with their occupancy (shape (n_kept, 4)), their view harmonics, and, if return_index is True, the sampling indices.
931 | """
932 | mask = preds[..., 0] > min_occ
933 | res_X = X_world[mask]
934 | res_preds = preds[mask]
935 | res_harmonics = view_harmonics[mask]
936 |
937 | device = res_X.get_device()
938 | n_points = res_X.shape[0]
939 |
940 | if use_occ_to_sample:
941 | sample_probs = res_preds[..., 0] / torch.sum(res_preds)
942 | sample_probs = torch.cumsum(sample_probs, dim=-1)
943 |
944 | samples = torch.rand(n_sample, 1, device=device)
945 |
946 | res_idx = sample_probs.view(1, n_points).expand(n_sample, -1) - samples.expand(-1, n_points)
947 | res_idx[res_idx < 0] = 2
948 | res_idx = torch.argmin(res_idx, dim=-1)
949 |
950 | res_idx, inverse_idx = torch.unique(res_idx, dim=0, return_inverse=True)
951 |
952 | res = torch.cat((res_X[res_idx], res_preds[res_idx]), dim=-1)
953 | res_harmonics = res_harmonics[res_idx]
954 | # res = torch.unique(res, dim=0)
955 |
956 | else:
957 | if len(res_X) > n_sample:
958 | res_X = res_X[:n_sample]
959 | res_preds = res_preds[:n_sample]
960 | res_harmonics = res_harmonics[:n_sample]
961 |
962 | res = torch.cat((res_X, res_preds), dim=-1)
963 | inverse_idx = None
964 |
965 | if return_index:
966 | return res, res_harmonics, inverse_idx
967 | else:
968 | return res, res_harmonics
969 |
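# Illustrative sketch (not in the original code): after predicting occupancy probabilities
# for the filtered proxy points of a single scene, keep at most seq_len of them, sampled
# proportionally to their predicted occupancy (as done in test_scone.py).
def _example_sample_proxy_points(X_world, occ_preds, view_harmonics, seq_len=2048, min_occ=0.1):
    proxy_points, proxy_harmonics = sample_proxy_points(X_world, occ_preds, view_harmonics,
                                                        n_sample=seq_len,
                                                        min_occ=min_occ,
                                                        use_occ_to_sample=True,
                                                        return_index=False)
    # proxy_points concatenates positions and occupancies, giving shape (n_kept, 4).
    return proxy_points, proxy_harmonics
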
--------------------------------------------------------------------------------
/SCONE/spherical_harmonics.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/lucidrains/se3-transformer-pytorch/blob/main/se3_transformer_pytorch/spherical_harmonics.py
2 |
3 | from math import pi, sqrt
4 | from operator import mul
5 | import torch
6 | from functools import reduce, wraps
7 |
13 |
14 | def cache(cache, key_fn):
15 | def cache_inner(fn):
16 | @wraps(fn)
17 | def inner(*args, **kwargs):
18 | key_name = key_fn(*args, **kwargs)
19 | if key_name in cache:
20 | return cache[key_name]
21 | res = fn(*args, **kwargs)
22 | cache[key_name] = res
23 | return res
24 |
25 | return inner
26 |
27 | return cache_inner
28 |
29 |
30 | # constants
31 |
32 | CACHE = {}
33 |
34 |
35 | def clear_spherical_harmonics_cache():
36 | CACHE.clear()
37 |
38 |
39 | def lpmv_cache_key_fn(l, m, x):
40 | return (l, m)
41 |
42 |
43 | # spherical harmonics
44 |
45 | # def reduce(function, iterable, initializer=None):
46 | # it = iter(iterable)
47 | # if initializer is None:
48 | # value = next(it)
49 | # else:
50 | # value = initializer
51 | # for element in it:
52 | # value = function(value, element)
53 | # return value
54 |
55 | def semifactorial(x):
56 | return reduce(mul, range(x, 1, -2), 1.)
57 |
58 |
59 | def pochhammer(x, k):
60 | return reduce(mul, range(x + 1, x + k), float(x))
61 |
62 |
63 | def negative_lpmv(l, m, y):
64 | if m < 0:
65 | y *= ((-1) ** m / pochhammer(l + m + 1, -2 * m))
66 | return y
67 |
68 |
69 | @cache(cache=CACHE, key_fn=lpmv_cache_key_fn)
70 | def lpmv(l, m, x):
71 | """Associated Legendre function including Condon-Shortley phase.
72 | Args:
73 | m: int order
74 | l: int degree
75 | x: float argument tensor
76 | Returns:
77 | tensor of x-shape
78 | """
79 | # Check memoized versions
80 | m_abs = abs(m)
81 |
82 | if m_abs > l:
83 | return None
84 |
85 | if l == 0:
86 | x_device = x.get_device()
87 | if x_device < 0:
88 | x_device = "cpu"
89 | return torch.ones_like(x, device=x_device)
90 |
91 | # Check if on boundary else recurse solution down to boundary
92 | if m_abs == l:
93 | # Compute P_m^m
94 | y = (-1) ** m_abs * semifactorial(2 * m_abs - 1)
95 | y *= torch.pow(1 - x * x, m_abs / 2)
96 | return negative_lpmv(l, m, y)
97 |
98 | # Recursively precompute lower degree harmonics
99 | lpmv(l - 1, m, x)
100 |
101 | # Compute P_{l}^m from recursion in P_{l-1}^m and P_{l-2}^m
102 | # Inplace speedup
103 | y = ((2 * l - 1) / (l - m_abs)) * x * lpmv(l - 1, m_abs, x)
104 |
105 | if l - m_abs > 1:
106 | y -= ((l + m_abs - 1) / (l - m_abs)) * CACHE[(l - 2, m_abs)]
107 |
108 | if m < 0:
109 |         y = negative_lpmv(l, m, y)
110 | return y
111 |
112 |
113 | def get_spherical_harmonics_element(l, m, theta, phi):
114 | """Tesseral spherical harmonic with Condon-Shortley phase.
115 | The Tesseral spherical harmonics are also known as the real spherical
116 | harmonics.
117 | Args:
118 | l: int for degree
119 | m: int for order, where -l <= m < l
120 | theta: collatitude or polar angle
121 | phi: longitude or azimuth
122 | Returns:
123 | tensor of shape theta
124 | """
125 | m_abs = abs(m)
126 | assert m_abs <= l, "absolute value of order m must be <= degree l"
127 |
128 | N = sqrt((2 * l + 1) / (4 * pi))
129 | leg = lpmv(l, m_abs, torch.cos(theta))
130 |
131 | if m == 0:
132 | return N * leg
133 |
134 | if m > 0:
135 | Y = torch.cos(m * phi)
136 | else:
137 | Y = torch.sin(m_abs * phi)
138 |
139 | Y *= leg
140 | N *= sqrt(2. / pochhammer(l - m_abs + 1, 2 * m_abs))
141 | Y *= N
142 | return Y
143 |
144 |
145 | def get_spherical_harmonics(l, theta, phi):
146 | """ Tesseral harmonic with Condon-Shortley phase.
147 | The Tesseral spherical harmonics are also known as the real spherical
148 | harmonics.
149 | Args:
150 | l: int for degree
151 | theta: collatitude or polar angle
152 | phi: longitude or azimuth
153 | Returns:
154 | tensor of shape [*theta.shape, 2*l+1]
155 | """
156 | return torch.stack([get_spherical_harmonics_element(l, m, theta, phi) \
157 | for m in range(-l, l + 1)],
158 | dim=-1)
159 |
160 |
161 | def evaluate_from_harmonic_coordinates(coordinates, theta, phi, degree):
162 | """
163 |     We denote by N the number of harmonics with l < degree (i.e., N = degree ** 2).
164 | WARNING! Theta should represent the polar angles, not elevation!
165 | :param coordinates: must have shape (N)
166 | :param theta: Tensor that represents polar angles.
167 | :param phi: Tensor that represents azimuth angles.
168 | :param degree: int
169 |     :return: A Tensor with the same shape as theta.
170 | """
171 | z = torch.zeros([i for i in theta.shape] + [0], device=theta.get_device())
172 |
173 | for l in range(degree):
174 | y = get_spherical_harmonics(l, theta, phi)
175 | z = torch.cat((z, y), dim=-1)
176 |
177 | return torch.sum(coordinates * z, dim=-1)
178 |
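# Illustrative sketch (not part of the original code): evaluating the real spherical
# harmonics up to a given degree at a few directions, then reconstructing a function
# value from a coordinate vector with evaluate_from_harmonic_coordinates. theta is the
# polar angle (0 at the pole), not the elevation, and the helpers above call
# tensor.get_device(), so a CUDA device is expected.
def _example_spherical_harmonics(device, degree=3):
    theta = torch.tensor([0.3, 1.2, 2.0], device=device)  # polar angles
    phi = torch.tensor([0.0, 1.5, 3.0], device=device)    # azimuth angles
    # The lpmv cache is keyed on (l, m) only, so it must be cleared before evaluating
    # the harmonics at new input angles.
    clear_spherical_harmonics_cache()
    n_harmonics = degree * degree  # number of harmonics with l < degree
    coordinates = torch.ones(n_harmonics, device=device)
    return evaluate_from_harmonic_coordinates(coordinates, theta, phi, degree)
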
--------------------------------------------------------------------------------
/SCONE/test_scone.py:
--------------------------------------------------------------------------------
1 | from scone_utils import *
2 |
3 | pc_size = 1024
4 | n_view_max = 10
5 | test_novel = True
6 | test_number = -1
7 |
8 | params_name = "train_params_jz_model_scone_vis_surface_coverage_gain_constant_epsilon_uncentered_l1_warmup_1000_schedule_lr_0.0001_sigmoid_tmcs.json"
9 | scone_occ_model_name = "best_unval_jz_model_scone_occ_mse_warmup_1000_schedule_lr_0.0001.pth"
10 | scone_vis_model_name = "best_unval_jz_model_scone_vis_surface_coverage_gain_constant_epsilon_uncentered_l1_warmup_1000_schedule_lr_0.0001_sigmoid_tmcs.pth"
11 |
12 |
13 | def save_test_params(save=False):
14 | # TO DO: for ddp, save only if is_master but load for everyone, with a synchronization in-between.
15 | params = {}
16 |
17 | # -----General parameters-----
18 | params["ddp"] = False
19 | params["jz"] = False
20 |
21 | # TO ADAPT
22 | # params["coverage_validation"] = False
23 | params["test_number"] = test_number
24 | params["test_novel"] = test_novel
25 | # params["n_screen_cameras"] = 1
26 | params["pc_size"] = 1024 # 2048
27 |
28 | if params["ddp"]:
29 | params["CUDA_VISIBLE_DEVICES"] = "0, 1"
30 | params["WORLD_SIZE"] = 2
31 |
32 | elif params["jz"]:
33 | params["WORLD_SIZE"] = idr_torch.size
34 |
35 | else:
36 | params["numGPU"] = 1
37 | params["WORLD_SIZE"] = 1
38 |
39 | params["anomaly_detection"] = True
40 | params["empty_cache_every_n_batch"] = 10
41 |
42 | # -----Ground truth computation parameters-----
43 | params["compute_gt_online"] = False
44 | params["compute_partial_point_cloud_online"] = False
45 |
46 | params["gt_surface_resolution"] = 1.5
47 |
48 | params["gt_max_diagonal"] = 1. # Formerly known as x_range
49 |
50 | params["n_points_surface"] = 16384 # N points on GT surface
51 |
52 | params["surface_epsilon_is_constant"] = True
53 | if params["surface_epsilon_is_constant"]:
54 | params["surface_epsilon"] = 0.00707
55 |
56 | # -----SconeOcc Model Parameters-----
57 | # params["scone_occ_model_name"] = "best_unval_jz_model_scone_occ_mse_warmup_1000_schedule_lr_0.0001.pth"
58 | params["scone_occ_model_name"] = "best_unval_jz_model_scone_occ_mse_warmup_1000_schedule_lr_0.0001.pth"
59 | params["occ_no_view_harmonics"] = False
60 |
61 | params["n_view_max_for_scone_occ"] = 10
62 | params["max_points_per_scone_occ_pass"] = 300000
63 |
64 | # -----Model Parameters-----
65 | params["seq_len"] = 2048
66 | params["pts_dim"] = 4
67 |
68 | params["view_state_n_elev"] = 7
69 | params["view_state_n_azim"] = 2 * 7
70 | params["harmonic_degree"] = 8
71 |
72 | params["n_proxy_points"] = 100000 # 12000, 100000
73 | params["use_occ_to_sample_proxy_points"] = True # True
74 | params["min_occ_for_proxy_points"] = 0.1
75 |
76 | params["true_monte_carlo_sampling"] = True
77 |
78 | # -----Ablation study-----
79 | params["no_view_harmonics"] = False
80 | params["use_sigmoid"] = True
81 |
82 | # -----General training parameters-----
83 | params["start_from_scratch"] = True
84 | params["pretrained_weights_name"] = None
85 |
86 | params["n_view_max"] = 10
87 | params["n_view_min"] = 1
88 | params["filter_tol"] = 0.01
89 |
90 | params["camera_dist"] = 1.5
91 | params["pole_cameras"] = True
92 | params["n_camera_elev"] = 5
93 | params["n_camera_azim"] = 2 * 5
94 | params["n_camera"] = params["n_camera_elev"] * params["n_camera_azim"]
95 | if params["pole_cameras"]:
96 | params["n_camera"] += 2
97 |
98 | params["prediction_in_random_camera_space"] = False
99 |
100 | params["batch_size"] = 4
101 |
102 | params["noam_opt"] = False
103 | params["training_metric"] = "surface_coverage_gain"
104 | # Training metric can be: "surface_coverage", "surface_coverage_gain", "absolute_coverage"
105 | params["training_loss"] = "uncentered_l1" # "kl_divergence", "l1", "uncentered_l1"
106 | params["multiply_loss"] = False
107 | if params["multiply_loss"]:
108 | params["loss_multiplication_factor"] = 10.
109 |
110 | params["nbv_validation"] = True
111 |
112 | params["random_seed"] = 42
113 | params["torch_seed"] = 5
114 |
115 | # -----Visibility Model name to save-----
116 | params["scone_vis_model_name"] = scone_vis_model_name
117 |
118 | # -----Json name to save params-----
119 | json_name = "test_params_" + scone_vis_model_name + ".json"
120 |
121 | if save:
122 | with open(json_name, 'w') as outfile:
123 | json.dump(params, outfile)
124 |
125 |         print("Parameters saved in:")
126 | print(json_name)
127 |
128 | return json_name
129 |
130 |
131 | def test_loop(params, dataloader,
132 | scone_occ, scone_vis,
133 | device):
134 |
135 | # Begin test process
136 | coverage_dict = {}
137 | sum_coverages = torch.zeros(params.n_view_max, device=device)
138 |
139 | t0 = time.time()
140 | torch.cuda.empty_cache()
141 |
142 | print("Beginning evaluation on test dataset...")
143 |
144 | base_harmonics, h_polar, h_azim = get_all_harmonics_under_degree(params.harmonic_degree,
145 | params.view_state_n_elev,
146 | params.view_state_n_azim,
147 | device)
148 |
149 | size = len(dataloader.dataset)
150 | num_batches = len(dataloader)
151 |
152 | t0 = time.time()
153 | computation_time = 0.
154 |
155 | for batch, (mesh_dict) in enumerate(dataloader):
156 | paths = mesh_dict['path']
157 | batch_size = len(paths)
158 |
159 | for i in range(batch_size):
160 | # ----------Load input mesh and ground truth data-----------------------------------------------------------
161 |
162 | path_i = paths[i]
163 |
164 | coverage_dict[path_i] = []
165 | coverages = torch.zeros(params.n_view_max, device=device)
166 |
167 | # Loading info about partial point clouds and coverages
168 | part_pc, coverage_matrix = get_gt_partial_point_clouds(path=path_i,
169 | normalization_factor=1. / params.gt_surface_resolution,
170 | device=device)
171 |
172 | # Initial dense sampling
173 | X_world = sample_X_in_box(x_range=params.gt_max_diagonal, n_sample=params.n_proxy_points, device=device)
174 |
175 | # ----------Set camera candidates for coverage prediction---------------------------------------------------
176 | X_cam_world, camera_dist, camera_elev, camera_azim = get_cameras_on_sphere(params, device,
177 | pole_cameras=params.pole_cameras)
178 | n_view = 1
179 | view_idx = torch.randperm(len(camera_elev), device=device)[:n_view]
180 |
181 | prediction_cam_idx = view_idx[0]
182 | prediction_box_center = torch.Tensor([0., 0., params.camera_dist]).to(device)
183 |
184 | # Move camera coordinates from world space to prediction view space, and normalize them for prediction box
185 | prediction_R, prediction_T = look_at_view_transform(dist=camera_dist[prediction_cam_idx],
186 | elev=camera_elev[prediction_cam_idx],
187 | azim=camera_azim[prediction_cam_idx],
188 | device=device)
189 | prediction_camera = FoVPerspectiveCameras(device=device, R=prediction_R, T=prediction_T)
190 | prediction_view_transform = prediction_camera.get_world_to_view_transform()
191 |
192 | X_cam = prediction_view_transform.transform_points(X_cam_world)
193 | X_cam = normalize_points_in_prediction_box(points=X_cam,
194 | prediction_box_center=prediction_box_center,
195 | prediction_box_diag=params.gt_max_diagonal)
196 | _, elev_cam, azim_cam = get_spherical_coords(X_cam)
197 |
198 | X_view = X_cam[view_idx]
199 | # X_cam = X_cam.view(1, params.n_camera, 3)
200 |
201 | # Compute initial coverage
202 | coverage = compute_surface_coverage_from_cam_idx(coverage_matrix, view_idx).detach().item()
203 |
204 | coverage_dict[path_i].append(coverage)
205 | coverages[0] += coverage
206 |
207 | # Sample random proxy points in space
208 | X_idx = torch.randperm(len(X_world))[:params.n_proxy_points]
209 | X_world = X_world[X_idx]
210 |
211 | for j_view in range(1, params.n_view_max):
212 | features = None
213 | args = None
214 | computation_t0 = time.time()
215 |
216 | # ----------Capture initial observations----------------------------------------------------------------
217 |
218 | # Points observed in initial views
219 | pc = torch.vstack([part_pc[pc_idx] for pc_idx in view_idx])
220 |
221 | # Downsampling partial point cloud
222 | # pc = pc[torch.randperm(len(pc))[:n_view * params.seq_len]]
223 | pc = pc[torch.randperm(len(pc))[:n_view * pc_size]]
224 |
225 | # Move partial point cloud from world space to prediction view space,
226 | # and normalize them in prediction box
227 | pc = prediction_view_transform.transform_points(pc)
228 | pc = normalize_points_in_prediction_box(points=pc,
229 | prediction_box_center=prediction_box_center,
230 | prediction_box_diag=params.gt_max_diagonal).view(1, -1, 3)
231 |
232 | # Move proxy points from world space to prediction view space, and normalize them in prediction box
233 | X = prediction_view_transform.transform_points(X_world)
234 | X = normalize_points_in_prediction_box(points=X,
235 | prediction_box_center=prediction_box_center,
236 | prediction_box_diag=params.gt_max_diagonal
237 | )
238 |
239 | # Filter Proxy Points using pc shape from view cameras
240 | R_view, T_view = look_at_view_transform(eye=X_view,
241 | at=torch.zeros_like(X_view),
242 | device=device)
243 | view_cameras = FoVPerspectiveCameras(R=R_view, T=T_view, zfar=1000, device=device)
244 | X, _ = filter_proxy_points(view_cameras, X, pc.view(-1, 3), filter_tol=params.filter_tol)
245 | X = X.view(1, X.shape[0], 3)
246 |
247 | # Compute view state vector and corresponding view harmonics
248 | view_state = compute_view_state(X, X_view,
249 | params.view_state_n_elev, params.view_state_n_azim)
250 | view_harmonics = compute_view_harmonics(view_state,
251 | base_harmonics, h_polar, h_azim,
252 | params.view_state_n_elev, params.view_state_n_azim)
253 | occ_view_harmonics = 0. + view_harmonics
254 | if params.occ_no_view_harmonics:
255 | occ_view_harmonics *= 0.
256 | if params.no_view_harmonics:
257 | view_harmonics *= 0.
258 |
259 | # Compute occupancy probabilities
260 | with torch.no_grad():
261 | occ_prob_i = compute_occupancy_probability(scone_occ=scone_occ,
262 | pc=pc,
263 | X=X,
264 | view_harmonics=occ_view_harmonics,
265 | max_points_per_pass=params.max_points_per_scone_occ_pass
266 | ).view(-1, 1)
267 |
268 | proxy_points, view_harmonics, sample_idx = sample_proxy_points(X[0], occ_prob_i,
269 | view_harmonics.squeeze(dim=0),
270 | n_sample=params.seq_len,
271 | min_occ=params.min_occ_for_proxy_points,
272 | use_occ_to_sample=params.use_occ_to_sample_proxy_points,
273 | return_index=True)
274 |
275 | proxy_points = torch.unsqueeze(proxy_points, dim=0)
276 | view_harmonics = torch.unsqueeze(view_harmonics, dim=0)
277 |
278 | # ----------Predict Coverage Gains------------------------------------------------------------------------------
279 | visibility_gain_harmonics = scone_vis(proxy_points, view_harmonics=view_harmonics)
280 | if params.true_monte_carlo_sampling:
281 | proxy_points = torch.unsqueeze(proxy_points[0][sample_idx], dim=0)
282 | visibility_gain_harmonics = torch.unsqueeze(visibility_gain_harmonics[0][sample_idx], dim=0)
283 |
284 | if params.ddp or params.jz:
285 | cov_pred_i = scone_vis.module.compute_coverage_gain(proxy_points,
286 | visibility_gain_harmonics,
287 | X_cam)
288 | else:
289 | cov_pred_i = scone_vis.compute_coverage_gain(proxy_points,
290 | visibility_gain_harmonics,
291 | X_cam.view(1, -1, 3)).view(-1, 1)
292 |
293 | # Identify maximum gain to get NBV camera
294 | (max_gain, max_idx) = torch.max(cov_pred_i, dim=0)
295 |
296 | computation_time += time.time() - computation_t0
297 |
298 | # Set NBV camera parameters
299 | view_idx = torch.cat((view_idx, torch.Tensor([max_idx]).long().to(device)), dim=0)
300 | X_nbv = X_cam[max_idx:max_idx + 1]
301 | X_nbv_world = X_cam_world[max_idx:max_idx + 1]
302 | r_nbv_world = camera_dist[max_idx:max_idx + 1]
303 | elev_nbv_world = camera_elev[max_idx:max_idx + 1]
304 | azim_nbv_world = camera_azim[max_idx:max_idx + 1]
305 |
306 | nbv_dist = torch.Tensor([params.camera_dist]).to(device)
307 |
308 | nbv_R, nbv_T = look_at_view_transform(dist=r_nbv_world,
309 | elev=elev_nbv_world,
310 | azim=azim_nbv_world,
311 | device=device)
312 |
313 | R_view = torch.vstack((R_view, nbv_R))
314 | # screen_dist = torch.vstack((screen_dist, nbv_dist))
315 | X_view = torch.vstack((X_view, X_nbv))
316 |
317 | # Computing surface coverage
318 | coverage = compute_surface_coverage_from_cam_idx(coverage_matrix, view_idx).detach().item()
319 | # avg_coverage += coverage / batch_size
320 | coverage_dict[path_i].append(coverage)
321 | coverages[j_view] += coverage
322 |
323 | sum_coverages += coverages
324 |
325 | # ----------Metrics computation on batch----------
326 |
327 | if batch % 10 == 0:
328 | torch.cuda.empty_cache()
329 | print("--- Batch", batch, "---")
330 | print("Batch size:", batch_size)
331 | # print("Coverage:", avg_coverage / (batch + 1))
332 | print("Coverages:", sum_coverages / ((batch + 1) * params.batch_size))
333 | print("Nb of meshes done:", (batch + 1) * params.batch_size)
334 | print("Computation time:", computation_time, '\n')
335 |
336 | results = {}
337 | # results["occ_threshold"] = occ_threshold
338 | # results["uncertainty_threshold"] = params.uncertainty_threshold
339 | # results["uncertainty_mode"] = params.uncertainty_mode
340 | # results["compute_cross_correction"] = params.compute_cross_correction
341 | # results["nbv_mode"] = nbv_mode
342 | results["coverages"] = coverage_dict
343 |
344 | # print("Results:", results)
345 |
346 |     print("Average coverages:", sum_coverages.detach().cpu() / len(dataloader.dataset))
347 | print("Done in", (time.time() - t0) / 3600., "hours!")
348 | print("Computation time:", computation_time)
349 | print("Average computation time:", computation_time / len(dataloader.dataset))
350 |
351 | print("Terminated in", (time.time() - t0) / 60., "minutes.")
352 | return results
353 |
354 | if __name__ == '__main__':
355 | json_name = save_test_params(True)
356 | params = load_params(json_name)
357 | params.n_view_max = n_view_max
358 |
359 | # Set device
360 | device = setup_device(params, ddp_rank=None)
361 |
362 | # Load models
363 | print("Loading SconeOcc...")
364 | scone_occ = load_scone_occ(params, scone_occ_model_name, ddp_model=True, device=device)
365 | print("Model has", count_parameters(scone_occ) / 1e6, "M parameters.")
366 | scone_occ.eval()
367 |
368 | print("Loading SconeVis...")
369 | scone_vis = load_scone_vis(params, scone_vis_model_name, ddp_model=True, device=device)
370 | print("Model has", count_parameters(scone_vis) / 1e6, "M parameters.")
371 | scone_vis.eval()
372 |
373 | if test_novel and params.test_novel:
374 | print("Test on novel categories.")
375 |
376 | train_dataloader, val_dataloader, test_dataloader = get_shapenet_dataloader(batch_size=params.batch_size,
377 | ddp=params.ddp, jz=params.jz,
378 | world_size=None, ddp_rank=None,
379 | test_number=params.test_number,
380 | test_novel=test_novel,
381 | load_obj=False,
382 | data_path=None)
383 |
384 | # Main loop
385 | eval_results = []
386 | eval_results.append(test_loop(params, test_dataloader,
387 | scone_occ, scone_vis, device))
388 |
389 | if params.test_number == 0:
390 | json_name = "test_iterative_results_for_models_" + scone_vis_model_name + ".json"
391 | elif params.test_number == -1:
392 | json_name = "full_test_iterative_results_for_models_" + scone_vis_model_name + ".json"
393 | else:
394 | json_name = "test_" + str(params.test_number)\
395 | + "_iterative_results_for_models_" + scone_vis_model_name + "_v2.json"
396 |
397 | if test_novel:
398 | json_name = "novel_" + json_name
399 |
400 | # json_name = "novel_test_" + str(params.test_number) +"_iterative_results_for_random_nbv_faster.json"
401 |
402 | for res in eval_results:
403 | for key in res:
404 | if type(res[key]) == torch.Tensor:
405 | res[key] = res[key].detach().item()
406 |
407 | if type(res[key]) == np.ndarray:
408 | res[key] = float(res[key])
409 |
410 | with open(json_name, 'w') as outfile:
411 | json.dump(eval_results, outfile)
412 | print("Saved data about test losses in", json_name)
--------------------------------------------------------------------------------
/docs/gifs/colosseum.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anttwo/SCONE/24f176518a794b69c5aeed81d22c0077ce4b9169/docs/gifs/colosseum.gif
--------------------------------------------------------------------------------
/docs/gifs/fushimi.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anttwo/SCONE/24f176518a794b69c5aeed81d22c0077ce4b9169/docs/gifs/fushimi.gif
--------------------------------------------------------------------------------
/docs/gifs/museum.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anttwo/SCONE/24f176518a794b69c5aeed81d22c0077ce4b9169/docs/gifs/museum.gif
--------------------------------------------------------------------------------
/docs/gifs/pantheon.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anttwo/SCONE/24f176518a794b69c5aeed81d22c0077ce4b9169/docs/gifs/pantheon.gif
--------------------------------------------------------------------------------