├── LICENSE
├── README.md
├── SCONE
    ├── Attention.py
    ├── CustomDataset.py
    ├── CustomGeometry.py
    ├── SconeOcc.py
    ├── SconeVis.py
    ├── idr_torch.py
    ├── pretrain_scone_occ.py
    ├── pretrain_scone_vis.py
    ├── scone_utils.py
    ├── spherical_harmonics.py
    ├── test_scone.py
    └── utils.py
└── docs
    └── gifs
        ├── colosseum.gif
        ├── fushimi.gif
        ├── museum.gif
        └── pantheon.gif


/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Antoine GUEDON

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SCONE

SCONE: Surface Coverage Optimization in Unknown Environments by Volumetric Integration

Antoine Guédon, Pascal Monasse, Vincent Lepetit

[Animations in `docs/gifs`: fushimi.gif, museum.gif, pantheon.gif, colosseum.gif]

Official PyTorch implementation of [**SCONE: Surface Coverage Optimization in Unknown Environments by Volumetric Integration**](https://arxiv.org/abs/2208.10449) (NeurIPS 2022, Spotlight).

:warning: :warning: :warning: :warning: :warning: :warning: :warning: :warning: :warning: :warning: :warning: :warning:

**We have released a new model, [MACARONS (CVPR 2023)](https://github.com/Anttwo/MACARONS), which directly improves on SCONE.
MACARONS adapts SCONE's approach to a fully self-supervised pipeline: it simultaneously learns to explore and reconstruct 3D scenes from RGB images only, with no need for 3D ground-truth data or depth sensors.
The MACARONS codebase includes a complete, updated version of SCONE's code, as well as detailed instructions for generating the ShapeNet training data we used for SCONE.
Please refer to the MACARONS repo for all the information you need.**

Thank you!

:warning: :warning: :warning: :warning: :warning: :warning: :warning: :warning: :warning: :warning: :warning: :warning:

This repository currently contains:

- scripts to initialize and train models
- evaluation pipelines to reproduce quantitative results

**Note**: We will add **installation guidelines**, **training data generation scripts** and **test notebooks** as soon as possible to allow for reproducibility.

If you find this code useful, don't forget to star the repo :star: and cite the paper :point_down:

```
@inproceedings{guedon2022scone,
  title={{SCONE: Surface Coverage Optimization in Unknown Environments by Volumetric Integration}},
  author={Gu\'edon, Antoine and Monasse, Pascal and Lepetit, Vincent},
  booktitle={{Advances in Neural Information Processing Systems}},
  year={2022},
}
```

Major code updates :clipboard:

- 11/22: first code release

## Installation :construction_worker:

We will add more details as soon as possible.

## Download Data

### 1. ShapeNetCore

We generate training data for both occupancy probability prediction and coverage gain estimation from [ShapeNetCore v1](https://shapenet.org/).
We will add the data generation scripts and corresponding instructions as soon as possible.

### 2. Custom Dataset of large 3D scenes

We conducted inference experiments in large environments using 3D meshes downloaded from Sketchfab under a CC license.
3D models courtesy of [Brian Trepanier](https://sketchfab.com/CMBC), [Andrea Spognetta](https://sketchfab.com/spogna), and [Vr Interiors](https://sketchfab.com/vrInteriors).
We will add more details as soon as possible.

## How to use :rocket:

We will add more details as soon as possible.

## Further information :books:

We adapted the code from [Phil Wang](https://github.com/lucidrains/se3-transformer-pytorch/blob/main/se3_transformer_pytorch/spherical_harmonics.py) to generate spherical harmonic features.
We thank him for this very useful harmonics computation script!
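
For readers unfamiliar with spherical harmonic features, here is a minimal sketch of what such features look like for a single direction. It is not the code in `spherical_harmonics.py` (which follows Phil Wang's implementation): the function name `real_sh_features`, the restriction to degrees 0 and 1, and the axis convention (z as polar axis, which may differ from the repository's) are all assumptions made for illustration.

```python
import math


def real_sh_features(direction):
    """Real spherical harmonics of degrees 0 and 1 for a 3D direction.

    Hypothetical helper for illustration only, not part of the SCONE codebase.
    The direction is normalized first, so any non-zero vector is accepted.
    Returns [Y_0^0, Y_1^-1, Y_1^0, Y_1^1] (sign conventions vary across libraries).
    """
    x, y, z = direction
    norm = math.sqrt(x * x + y * y + z * z)
    if norm == 0.0:
        raise ValueError("direction must be a non-zero vector")
    x, y, z = x / norm, y / norm, z / norm

    c0 = 0.5 * math.sqrt(1.0 / math.pi)    # constant degree-0 term
    c1 = math.sqrt(3.0 / (4.0 * math.pi))  # shared degree-1 coefficient
    return [c0, c1 * y, c1 * z, c1 * x]


if __name__ == "__main__":
    # Features of the direction pointing along +z
    print(real_sh_features((0.0, 0.0, 2.0)))
```

Higher degrees add more oscillating terms in the same spirit; the degree-1 terms are simply the normalized coordinates rescaled by a constant.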

We also thank [Tom Monnier](https://www.tmonnier.com/) for his Markdown template, which we took inspiration from.

--------------------------------------------------------------------------------
/SCONE/Attention.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
from pytorch3d.ops import knn_points, knn_gather


def attention(q, k, v, mask=None, dropout=None):
    """
    Main attention mechanism function.

    Takes queries, keys and values as inputs and returns the attention-weighted values.
    :param q: (Tensor) Queries tensor with shape (..., N, d)
    :param k: (Tensor) Keys tensor with shape (..., N, d)
    :param v: (Tensor) Values tensor with shape (..., N, d_v)
    :param mask: (Tensor) Mask tensor with shape (..., N, N). Positions where mask == 0 are masked out. Optional.
    :param dropout: Dropout module to apply on computed attention weights. Optional.
    :return: scores: (Tensor) Attention-weighted values tensor with shape (..., N, d_v)
    """
    # Query/Key matrix multiplication
    scores = q.matmul(k.transpose(-2, -1))

    # Apply mask: masked positions receive a large negative score
    scores = scores if mask is None else scores.masked_fill(mask == 0, -1e3)

    # Normalization: scale by sqrt(d), then softmax over the keys
    scores /= np.sqrt(q.shape[-1])
    scores = nn.functional.softmax(scores, dim=-1)

    # Dropout
    scores = scores if dropout is None else dropout(scores)

    # Value matrix multiplication
    scores = scores.matmul(v)

    return scores


class Embedding(nn.Module):
    def __init__(self, input_dim, output_dim,
                 dropout=None, gelu=True,
                 global_feature=False,
                 additional_feature_dim=0,
                 concatenate_input=True,
                 k_for_knn=0):
        """
        Class used to embed point clouds.

        :param input_dim: (int) Dimension of input points.
        :param output_dim: (int) Dimension of output features.
        :param dropout: (float) Dropout probability applied to computed embeddings. If None, no dropout is used.
        :param gelu: (bool) If True, module uses GELU non-linearity. If False, module uses ReLU.
        :param global_feature: (bool) If True, output features are computed as the concatenation of per-point features
            and a global feature.
        :param additional_feature_dim: (int) Dimension of an additional feature to provide.
        :param concatenate_input: (bool) If True, concatenates input to the output features.
        :param k_for_knn: (int) if > 0, output features are computed by max-pooling features from k nearest neighbors.
        """
        super(Embedding, self).__init__()

        self.use_knn = k_for_knn > 0
        self.k = k_for_knn

        self.input_dim = input_dim

        self.global_feature = global_feature
        self.additional_feature_dim = additional_feature_dim
        self.concatenate_input = concatenate_input

        self.inner_dim = output_dim // 2
        self.feature_dim = output_dim

        if additional_feature_dim > 0:
            self.feature_dim -= additional_feature_dim
            self.inner_dim = self.feature_dim

        if concatenate_input:
            self.feature_dim -= input_dim
            self.inner_dim = self.feature_dim

        if global_feature:
            self.feature_dim = self.feature_dim // 2
            self.inner_dim = self.feature_dim

        if global_feature or self.use_knn:
            self.max_pool = nn.functional.max_pool1d

        self.linear1 = nn.Linear(self.input_dim, self.inner_dim)
        self.linear2 = nn.Linear(self.inner_dim, self.feature_dim)

        self.dropout = nn.Dropout(dropout) if dropout is not None else None

        if gelu:
            self.nonlinear = nn.GELU()
        else:
            self.nonlinear = nn.ReLU(inplace=False)

    def forward(self, x, additional_feature=None):
        n_clouds, seq_len, x_dim = x.shape

        res = self.nonlinear(self.linear1(x))
        res = res if self.dropout is None else self.dropout(res)
        res = self.linear2(res)

        if self.use_knn:
            # Computing spatial kNN (uses only the first 3 coordinates of each point)
            _, knn_idx, _ = knn_points(p1=x[..., :3], p2=x[..., :3], K=self.k)
            res_knn = knn_gather(res, knn_idx)

            # Pooling among kNN features
            res_knn = res_knn.view(n_clouds * seq_len, self.k, self.feature_dim)
            res = self.max_pool(input=res_knn.transpose(-1, -2),
                                kernel_size=self.k
                                ).view(n_clouds, seq_len, self.feature_dim)

        if self.global_feature:
            global_feature = self.max_pool(input=res.transpose(-1, -2),
                                           kernel_size=seq_len).view(n_clouds, 1, self.feature_dim)
            global_feature = global_feature.expand(-1, seq_len, -1)
            res = torch.cat((res, global_feature), dim=-1)

        if self.additional_feature_dim > 0:
            res = torch.cat((res, additional_feature), dim=-1)

        if self.concatenate_input:
            res = torch.cat((res, x), dim=-1)

        return res


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, n_heads, in_dim, qk_dim, dropout=None):
        """
        Main class for Multi-Head Self Attention neural module.
        Credits to https://arxiv.org/pdf/2012.09688.pdf

        :param n_heads: (int) Number of heads in the attention module.
        :param in_dim: (int) Dimension of input.
        :param qk_dim: (int) Dimension of keys and queries.
        :param dropout: (float) Dropout probability applied to attention weights. If None, no dropout is used.
        """
        super(MultiHeadSelfAttention, self).__init__()
        v_dim = in_dim

        self.n_heads = n_heads
        self.in_dim = in_dim
        self.qk_dim = qk_dim
        self.v_dim = v_dim

        self.qk_dim_per_head = qk_dim // n_heads
        self.v_dim_per_head = v_dim // n_heads

        self.w_q = nn.Linear(in_dim, qk_dim)
        self.w_k = nn.Linear(in_dim, qk_dim)
        self.w_v = nn.Linear(in_dim, v_dim)

        self.dropout = nn.Dropout(dropout) if dropout is not None else None

        if n_heads > 1:
            self.out = nn.Linear(in_dim, in_dim)

    def split_heads(self, q, k, v):
        """
        Split all queries, keys and values between attention heads.

        :param q: (Tensor) Queries tensor with shape (batch_size, seq_len, qk_dim)
        :param k: (Tensor) Keys tensor with shape (batch_size, seq_len, qk_dim)
        :param v: (Tensor) Values tensor with shape (batch_size, seq_len, v_dim)
        :return: (3-tuple of Tensors) Queries, keys and values for each head.
            q_split and k_split tensors have shape (batch_size, seq_len, n_heads, qk_dim_per_head).
            v_split tensor has shape (batch_size, seq_len, n_heads, v_dim_per_head).
        """
        # Each have size n_screen_cameras * pts_len * n_heads * dim_per_head
        q_split = q.reshape(q.shape[0], -1, self.n_heads, self.qk_dim_per_head)
        k_split = k.reshape(k.shape[0], -1, self.n_heads, self.qk_dim_per_head)
        v_split = v.reshape(v.shape[0], -1, self.n_heads, self.v_dim_per_head)

        return q_split, k_split, v_split

    def forward(self, x, mask=None):
        """
        Forward pass.

        :param x: (Tensor) Input tensor with shape (batch_size, seq_len, in_dim)
        :param mask: (Tensor) Mask tensor with shape (batch_size, seq_len, seq_len). Optional.
        :return: scores (Tensor) Attention output tensor with shape (batch_size, seq_len, in_dim)
        """
        # pts should have size BS * SEQ_LEN * IN_DIM
        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)

        # Break into n_heads
        q, k, v = self.split_heads(q, k, v)
        q, k, v = [t.transpose(1, 2) for t in (q, k, v)]  # BS * HEAD * SEQ_LEN * DIM_PER_HEAD

        # A mask with the documented 3D shape (BS * SEQ_LEN * SEQ_LEN) needs a head axis
        # to broadcast against the BS * HEAD * SEQ_LEN * SEQ_LEN score tensor.
        if mask is not None and mask.dim() == 3:
            mask = mask.unsqueeze(1)

        scores = attention(q, k, v, mask, self.dropout)  # BS * HEAD * SEQ_LEN * DIM_PER_HEAD
        scores = scores.transpose(1, 2).contiguous().view(
            scores.shape[0], -1, self.v_dim)  # BS * SEQ_LEN * V_DIM

        if self.n_heads > 1:
            scores = self.out(scores)

        return scores


class FeedForward(nn.Module):
    def __init__(self, input_dim, inner_dim, gelu=True, dropout=None):
        """
        Feed Forward unit to use in attention encoder.

        :param input_dim: (int) Dimension of input tensor.
        :param inner_dim: (int) Dimension of inner tensor.
        :param gelu: (bool) If True, the unit uses GELU non-linearity. If False, it uses ReLU.
        :param dropout: (float) Dropout probability applied to inner features. If None, no dropout is used.
        """
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(input_dim, inner_dim)
        self.linear2 = nn.Linear(inner_dim, input_dim)
        self.dropout = nn.Dropout(dropout) if dropout is not None else None
        if gelu:
            self.nonlinear = nn.GELU()
        else:
            self.nonlinear = nn.ReLU(inplace=False)

    def forward(self, x):
        """
        Forward pass.

        :param x: (Tensor) Input tensor with shape (..., input_dim)
        :return: (Tensor) Output tensor with shape (..., input_dim)
        """
        res = self.nonlinear(self.linear1(x))
        res = res if self.dropout is None else self.dropout(res)

        return self.linear2(res)


class Encoder(nn.Module):
    def __init__(self, seq_len, qk_dim, embedding_dim=128, n_heads=1,
                 dropout=None, gelu=True, FF=True):
        """
        Transformer encoder based on a Multi-Head Self Attention mechanism.

        :param seq_len: (int) Length of input sequence.
        :param qk_dim: (int) Dimension of keys and queries.
        :param embedding_dim: (int) Dimension of input embeddings, values, and attention scores.
        :param n_heads: (int) Number of heads in the self-attention module.
        :param dropout: (float) Dropout probability applied to computed features. If None, no dropout is used.
        :param gelu: (bool) If True, the encoder uses GELU non-linearity. If False, it uses ReLU.
        :param FF: (bool) If True, the encoder applies an additional Feed Forward unit after
            the Multi-Head Self Attention unit.
        """
        super(Encoder, self).__init__()

        self.seq_len = seq_len
        self.embedding_dim = embedding_dim
        self.n_heads = n_heads
        self.qk_dim = qk_dim  # self.embedding_dim // 4
        self.dropout = dropout
        self.FF = FF

        self.norm1 = nn.LayerNorm(self.embedding_dim)
        self.mhsa = MultiHeadSelfAttention(n_heads=self.n_heads,
                                           in_dim=self.embedding_dim,
                                           qk_dim=self.qk_dim,
                                           dropout=self.dropout)
        self.dropout1 = None if dropout is None else nn.Dropout(dropout)

        if FF:
            self.norm2 = nn.LayerNorm(self.embedding_dim)
            self.ff = FeedForward(input_dim=embedding_dim,
                                  inner_dim=2*embedding_dim,
                                  gelu=gelu,
                                  dropout=self.dropout)
            self.dropout2 = None if dropout is None else nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Forward pass.

        :param x: (Tensor) Input tensor with shape (batch_size, seq_len, embedding_dim).
        :param mask: (Tensor) Mask tensor with shape (batch_size, seq_len, seq_len). Optional.
        :return: (Tensor) The features encoded with a self-attention mechanism.
            Has shape (batch_size, seq_len, embedding_dim)
        """
        res = self.norm1(x)
        res = self.mhsa(res, mask=mask)
        if self.dropout is not None:
            res = self.dropout1(res)
        res = x + res

        if self.FF:
            res2 = self.norm2(res)
            res2 = self.ff(res2)
            if self.dropout is not None:
                res2 = self.dropout2(res2)
            res = res + res2

        return res

--------------------------------------------------------------------------------
/SCONE/CustomDataset.py:
--------------------------------------------------------------------------------
import numpy as np
import json
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

from pytorch3d.datasets import (
    collate_batched_meshes,
    render_cubified_voxels,
)
from pytorch3d.renderer import (
    TexturesVertex,
    TexturesAtlas,
)
from pytorch3d.transforms import (
    matrix_to_quaternion,
    quaternion_apply,
)
from utils import *
import time


class CustomDataset(Dataset):

    def __init__(self, data_path, memory_threshold, rasterizer, screen_rasterizer, params, camera_axes, device,
                 save_to_json=False, load_from_json=False, json_name="models_list.json",
                 load_obj=True):
        self.data_path = data_path
        self.rasterizer = rasterizer
        self.screen_rasterizer = screen_rasterizer
        # self.renderer = renderer
        # self.embedder = embedder
        # self.preprocess = preprocess
        self.image_size = params.image_size
        self.camera_dist = params.camera_dist
        self.elevation = params.elevation
        self.azim_angle = params.azim_angle
        self.camera_axes = camera_axes
        self.side = params.side
        self.device = device

        self.load_obj = load_obj

        if not load_from_json:
            models = []
            for (dirpath, dirnames, filenames) in os.walk(data_path):
                for filename in filenames:
                    if filename.endswith(".obj"):
                        models.append(os.path.join(dirpath, filename))
            models = remove_heavy_files(models, memory_threshold)
        else:
            with open(json_name) as f:
                dir_to_load = json.load(f)
            models = [os.path.join(data_path, path) for path in dir_to_load['models']]

        if save_to_json:
            dir_to_save = {}
            dir_to_save['models'] = [path[1+len(data_path):] for path in models]
            with open(json_name, 'w') as outfile:
                json.dump(dir_to_save, outfile)
            print("Saved models list in", json_name, ".")

        print("Database loaded.")
        self.models = models

    def __len__(self):
        return len(self.models)

    def __getitem__(self, idx):  # -> Dict:

        model_path = self.models[idx]
        model = {}

        if self.load_obj:
            verts, faces, aux = load_obj(
                model_path,
                device=self.device,
                load_textures=False,
                create_texture_atlas=True,
                # texture_atlas_size=4,
                # texture_wrap="repeat",
            )

            verts = adjust_mesh(verts)

            # Create a textures object
            atlas = aux.texture_atlas

            model["verts"] = verts
            model["faces"] = faces.verts_idx
            model["textures"] = aux[4]
            model["atlas"] = aux.texture_atlas

        model["path"] = model_path
        return model


class CustomShapenetDataset(Dataset):

    def __init__(self, data_path, memory_threshold,
                 save_to_json=False, load_from_json=False, json_name="models_list.json",
                 official_split=False, adjust_diagonally=False,
                 load_obj=True):
        self.data_path = data_path
        self.official_split = official_split
        self.adjust_diagonally = adjust_diagonally

        self.load_obj = load_obj

        if not load_from_json:
            models = []
            for (dirpath, dirnames, filenames) in os.walk(data_path):
                for filename in filenames:
                    if filename.endswith(".obj"):
                        models.append(os.path.join(dirpath, filename))
            models = remove_heavy_files(models, memory_threshold)
        else:
            with open(json_name) as f:
                dir_to_load = json.load(f)
            models = [os.path.join(data_path, path) for path in dir_to_load['models']]

        if save_to_json:
            dir_to_save = {}
            dir_to_save['models'] = [path[1+len(data_path):] for path in models]
            with open(json_name, 'w') as outfile:
                json.dump(dir_to_save, outfile)
            print("Saved models list in", json_name, ".")

        print("Database loaded.")
        self.models = models

    def __len__(self):
        return len(self.models)

    def __getitem__(self, idx):  # -> Dict:

        model_path = self.models[idx]
        model = {}

        if self.load_obj:
            verts, faces, aux = load_obj(
                model_path,
                # device=self.device,
                load_textures=False,
                create_texture_atlas=True,
                # texture_atlas_size=4,
                # texture_wrap="repeat",
            )

            if self.adjust_diagonally:
                verts = adjust_mesh_diagonally(verts, diag_range=1.0)
            else:
                verts = adjust_mesh(verts)

            # Create a textures object
            atlas = aux.texture_atlas

            model["verts"] = verts
            model["faces"] = faces.verts_idx
            model["textures"] = aux[4]
            model["atlas"] = aux.texture_atlas

        model["path"] = model_path
        return model


class RGBDataset(Dataset):
    def __init__(self, data_path, alpha_max, use_future_images, scene_names=None,
                 frames_to_remove_json='frames_to_remove.pt'):
        self.data_path = data_path
        self.alpha_max = alpha_max
        self.use_future_images = use_future_images

        self.data = {}
        self.indices = {}
        current_idx = 0

        self.data['scenes'] = {}
        self.data['n_scenes'] = 0

        # If no scene name is provided, we just take all scenes in the folder
        if scene_names is None:
            # scene_names = os.listdir(self.data_path)
            scene_names = [scene_name for scene_name in os.listdir(self.data_path)
                           if os.path.isdir(os.path.join(self.data_path, scene_name))]

        if frames_to_remove_json in scene_names:
            scene_names.remove(frames_to_remove_json)

        self.frames_to_remove = torch.load(os.path.join(data_path, frames_to_remove_json))

        # For every scene...
        for scene_name in scene_names:
            self.data['n_scenes'] += 1
            self.data['scenes'][scene_name] = {}
            scene_path = os.path.join(self.data_path, scene_name)
            scene_path = os.path.join(scene_path, 'images')
            # print(scene_path)

            self.data['scenes'][scene_name]['trajectories'] = {}
            self.data['scenes'][scene_name]['n_trajectories'] = 0
            # ...And every trajectory...
            for trajectory_nb in os.listdir(scene_path):
                self.data['scenes'][scene_name]['n_trajectories'] += 1
                self.data['scenes'][scene_name]['trajectories'][trajectory_nb] = {}
                trajectory_path = os.path.join(scene_path, trajectory_nb)
                # print(trajectory_path)

                traj_length = len(os.listdir(trajectory_path))

                self.data['scenes'][scene_name]['trajectories'][trajectory_nb]['frames'] = {}
                self.data['scenes'][scene_name]['trajectories'][trajectory_nb]['n_frames'] = 0
                # ...We add all frames respecting conditions on the number of past (and, if required, future) frames
                for frame_name in os.listdir(trajectory_path):
                    frame_nb = frame_name[:-3]  # strip the ".pt" extension
                    short_path = scene_name + "/images/" + str(trajectory_nb) + "/" + str(frame_nb) + ".pt"

                    save_index = False
                    index_to_save = None

                    if int(frame_nb) >= self.alpha_max and (
                            (not self.use_future_images) or
                            int(frame_nb) < traj_length - self.alpha_max
                    ):
                        if not (short_path in self.frames_to_remove.keys()):
                            self.indices[str(current_idx)] = {'scene_name': scene_name,
                                                              'trajectory_nb': trajectory_nb,
                                                              'frame_nb': frame_nb}
                            save_index = True
                            index_to_save = current_idx
                            current_idx += 1

                    self.data['scenes'][scene_name]['trajectories'][trajectory_nb]['n_frames'] += 1
                    self.data['scenes'][scene_name][
                        'trajectories'][trajectory_nb][
                        'frames'][frame_nb] = {}
                    self.data['scenes'][scene_name][
                        'trajectories'][trajectory_nb][
                        'frames'][frame_nb][
                        'path'] = os.path.join(trajectory_path, str(frame_nb) + '.pt')
                    if save_index:
                        self.data['scenes'][scene_name][
                            'trajectories'][trajectory_nb][
                            'frames'][frame_nb][
                            'index'] = index_to_save

        print("Database loaded.")

    def __len__(self):
        # total_length = 0
        # for scene_name in self.data['scenes']:
        #     for trajectory_nb in self.data['scenes'][scene_name]['trajectories']:
        #         total_length += len(self.data['scenes'][scene_name]['trajectories'][trajectory_nb]['frames'].keys())
        # return total_length

        return len(self.indices.keys())

    def __getitem__(self, idx):  # -> Dict:

        scene_name = self.indices[str(idx)]['scene_name']
        trajectory_nb = self.indices[str(idx)]['trajectory_nb']
        frame_nb = self.indices[str(idx)]['frame_nb']

        frame_path = self.data['scenes'][scene_name]['trajectories'][trajectory_nb]['frames'][frame_nb]['path']
        frame = torch.load(frame_path, map_location='cpu')

        frame['path'] = frame_path
        frame['index'] = idx

        return frame

    def get_neighbor_frame(self, frame, alpha, device='cpu'):
        """
        Loads the frame located alpha steps away from the given frame in the same trajectory.

        :param frame: (dict) Frame returned by __getitem__. Must contain the key 'index'.
        :param alpha: (int) Offset, in number of frames, between the given frame and its neighbor.
        :param device: Device on which the neighbor frame is loaded.
        :return: (dict) The neighbor frame. Its 'index' entry is set to the index of the input frame.
        """
        idx = frame['index']
        scene_name = self.indices[str(idx)]['scene_name']
        trajectory_nb = self.indices[str(idx)]['trajectory_nb']
        frame_nb = str(int(self.indices[str(idx)]['frame_nb']) + alpha)

        neighbor_frame_path = self.data['scenes'][scene_name]['trajectories'][trajectory_nb]['frames'][frame_nb]['path']
        neighbor_frame = torch.load(neighbor_frame_path, map_location=device)

        neighbor_frame['path'] = neighbor_frame_path
        neighbor_frame['index'] = idx

        return neighbor_frame

    def get_neighbor_frame_from_idx(self, idx, alpha, device='cpu'):
        """
        Same as get_neighbor_frame, but takes the dataset index of the reference frame directly.

        :param idx: (int) Dataset index of the reference frame.
        :param alpha: (int) Offset, in number of frames, between the reference frame and its neighbor.
        :param device: Device on which the neighbor frame is loaded.
        :return: (dict) The neighbor frame. Its 'index' entry is set to idx.
        """
        scene_name = self.indices[str(idx)]['scene_name']
        trajectory_nb = self.indices[str(idx)]['trajectory_nb']
        frame_nb = str(int(self.indices[str(idx)]['frame_nb']) + alpha)

        neighbor_frame_path = self.data['scenes'][scene_name]['trajectories'][trajectory_nb]['frames'][frame_nb]['path']
        neighbor_frame = torch.load(neighbor_frame_path, map_location=device)

        neighbor_frame['path'] = neighbor_frame_path
        neighbor_frame['index'] = idx

        return neighbor_frame


class SceneDataset(Dataset):
    def __init__(self, data_path, scene_names=None):
        self.data_path = data_path

        # If no scene name is provided, we just take all scenes in the folder
        if scene_names is None:
            # scene_names = os.listdir(self.data_path)
            scene_names = [scene_name for scene_name in os.listdir(self.data_path)
                           if os.path.isdir(os.path.join(self.data_path, scene_name))]

        self.scene_names = scene_names

    def __len__(self):
        # total_length = 0
        # for scene_name in self.data['scenes']:
        #     for trajectory_nb in self.data['scenes'][scene_name]['trajectories']:
        #         total_length += len(self.data['scenes'][scene_name]['trajectories'][trajectory_nb]['frames'].keys())
        # return total_length

        return len(self.scene_names)

    def __getitem__(self, idx):  # -> Dict:

        scene_name = self.scene_names[idx]
        scene_path = os.path.join(self.data_path, scene_name)

        # Mesh info
        obj_name = scene_name + '.obj'

        # Settings info
        settings_file = os.path.join(scene_path, 'settings.json')
        with open(settings_file, "r") as read_content:
            settings = json.load(read_content)

        # Info about occupied camera poses
        occupied_pose = torch.load(os.path.join(scene_path, 'occupied_pose.pt'), map_location=torch.device('cpu'))

        scene = {}
        scene['scene_name'] = scene_name
        scene['obj_name'] = obj_name
        scene['settings'] = settings
        scene['occupied_pose'] = occupied_pose

        return scene

--------------------------------------------------------------------------------
/SCONE/CustomGeometry.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
# from utils import *

def get_cartesian_coords(r, elev, azim, in_degrees=False):
    """
    Returns the cartesian coordinates of 3D points written in spherical coordinates.
    :param r: (Tensor) Radius tensor of 3D points, with shape (N, 1).
    :param elev: (Tensor) Elevation tensor of 3D points, with shape (N, 1).
    :param azim: (Tensor) Azimuth tensor of 3D points, with shape (N, 1).
    :param in_degrees: (bool) If True, elevation and azimuth are written in degrees.
        Else, in radians.
    :return: (Tensor) Cartesian coordinates tensor with shape (N, 3).
    """
    factor = 1
    if in_degrees:
        factor *= np.pi / 180.
    # Stacking along dim=2 requires 2D inputs, hence the (N, 1) shapes documented above
    X = torch.stack((
        torch.cos(factor * elev) * torch.sin(factor * azim),
        torch.sin(factor * elev),
        torch.cos(factor * elev) * torch.cos(factor * azim)
    ), dim=2)

    return r * X.view(-1, 3)


def get_spherical_coords(X):
    """
    Returns the spherical coordinates of 3D points written in cartesian coordinates.
    :param X: (Tensor) Tensor with shape (N, 3) that represents 3D points in cartesian coordinates.
    :return: (3-tuple of Tensors) r_x, elev_x and azim_x are Tensors with shape (N) that correspond
        to the radius, elevation and azimuth of all 3D points.
    """
    r_x = torch.linalg.norm(X, dim=1)

    elev_x = torch.asin(X[:, 1] / r_x)  # between -pi/2 and pi/2
    elev_x[X[:, 1] / r_x <= -1] = -np.pi / 2
    elev_x[X[:, 1] / r_x >= 1] = np.pi / 2

    azim_x = torch.acos(X[:, 2] / (r_x * torch.cos(elev_x)))
    azim_x[X[:, 2] / (r_x * torch.cos(elev_x)) <= -1] = np.pi
    azim_x[X[:, 2] / (r_x * torch.cos(elev_x)) >= 1] = 0.
    azim_x[X[:, 0] < 0] *= -1

    return r_x, elev_x, azim_x

def sample_cameras_on_sphere(n_X, radius, device):
    """
    Deterministic sampling of camera positions on a sphere.

    :param n_X: (int) Number of positions to sample. Should be a perfect square.
    :param radius: (float) Radius of the sphere for sampling.
    :param device: Device on which the positions are created.
    :return: A tensor with shape (n_X, 3).
    """
    delta_theta = 0.9 * np.pi
    delta_phi = 0.9 * 2 * np.pi

    n_dim = int(np.sqrt(n_X))
    d_theta = 2 * delta_theta / (n_dim - 1)
    d_phi = 2 * delta_phi / (n_dim - 1)

    increments = torch.linspace(0, n_dim - 1, n_dim, device=device)

    thetas = -delta_theta + increments * d_theta
    phis = -delta_phi + increments * d_phi

    thetas = thetas.view(n_dim, 1).expand(-1, n_dim)
    phis = phis.view(1, n_dim).expand(n_dim, -1)

    X = torch.stack((
        torch.cos(thetas) * torch.sin(phis),
        torch.sin(thetas),
        torch.cos(thetas) * torch.cos(phis)
    ), dim=2)

    return radius * X.view(-1, 3)

--------------------------------------------------------------------------------
/SCONE/SconeOcc.py:
--------------------------------------------------------------------------------
import torch.nn.functional as F
from Attention import *
from utils import get_knn_points


class XEmbedding(nn.Module):
    def __init__(self, x_dim, x_embedding_dim, dropout=None, gelu=True):
        """
        Neural module for individual 3D point embedding using fully connected layers.

        :param x_dim: (int) Dimension of input 3D points. Usually, x_dim=3.
        :param x_embedding_dim: (int) Dimension of output embeddings.
        :param dropout: (float) Dropout probability applied to computed features. If None, no dropout is used.
        :param gelu: (bool) If True, the model uses GELU non-linearity. Else, it uses ReLU.
15 | """ 16 | super(XEmbedding, self).__init__() 17 | 18 | self.linear1 = nn.Linear(x_dim, x_embedding_dim // 4) 19 | # self.fc1 = nn.Linear(3*2*self.n_harmonic_functions, self.x_embd_size//2) 20 | self.linear2 = nn.Linear(x_embedding_dim // 4, x_embedding_dim // 2) 21 | self.linear3 = nn.Linear(x_embedding_dim // 2, x_embedding_dim) 22 | 23 | if gelu: 24 | self.non_linear1 = nn.GELU() 25 | self.non_linear2 = nn.GELU() 26 | self.non_linear3 = nn.GELU() 27 | else: 28 | self.non_linear1 = nn.ReLU(inplace=False) 29 | self.non_linear2 = nn.ReLU(inplace=False) 30 | self.non_linear3 = nn.ReLU(inplace=False) 31 | 32 | self.dropout = nn.Dropout(dropout) if dropout is not None else None 33 | 34 | def forward(self, x): 35 | res = self.non_linear1(self.linear1(x)) 36 | res = self.non_linear2(self.linear2(res)) 37 | res = self.non_linear3(self.linear3(res)) 38 | 39 | res = res if self.dropout is None else self.dropout(res) 40 | 41 | return res 42 | 43 | 44 | class PCTransformer(nn.Module): 45 | def __init__(self, seq_len, pts_dim=3, pts_embedding_dim=256, feature_dim=512, 46 | concatenate_input=True, 47 | n_code=2, n_heads=4, FF=True, gelu=True, 48 | dropout=None): 49 | """ 50 | Main class for Transformer units dedicated to point cloud global encoding. 51 | 52 | :param seq_len: (int) Length of input point cloud sequence. 53 | :param pts_dim: (int) Dimension of input points. Usually, pts_dim=3. 54 | :param pts_embedding_dim: (int) Dimension of point embeddings. 55 | :param feature_dim: (int) Dimension of output features. 56 | :param concatenate_input: (bool) If True, the model concatenates the raw input to their initial embedding. 57 | :param n_code: (int) Number of Multi-Head Self Attention units. 58 | :param n_heads: (int) Number of heads in Multi-Head Self Attention units. 59 | :param FF: (bool) If True, the Transformer encoder applies a Feed Forward unit after 60 | each Multi-Head Self-Attention unit. 61 | :param gelu: (bool) If True, the model uses GELU non-linearity. 
Else, it uses ReLU. 62 | :param dropout: Dropout module to apply on computed features. 63 | """ 64 | super(PCTransformer, self).__init__() 65 | 66 | # Parameters 67 | self.seq_len = seq_len 68 | self.pts_dim = pts_dim 69 | self.pts_embedding_dim = pts_embedding_dim 70 | 71 | self.n_code = n_code 72 | self.n_heads = n_heads 73 | self.FF = FF 74 | self.gelu = gelu 75 | 76 | self.feature_dim = feature_dim 77 | 78 | self.dropout = dropout 79 | 80 | # Point embedding 81 | self.embedding = Embedding(input_dim=pts_dim, output_dim=pts_embedding_dim, 82 | dropout=None, gelu=gelu, 83 | global_feature=False, additional_feature_dim=0, 84 | concatenate_input=concatenate_input, 85 | k_for_knn=0) 86 | 87 | # Point cloud encoder Backbone 88 | encoders = [] 89 | for i in range(n_code): 90 | encoders += [Encoder(seq_len=seq_len, 91 | embedding_dim=pts_embedding_dim, 92 | qk_dim=pts_embedding_dim // 4, 93 | n_heads=n_heads, 94 | dropout=dropout, 95 | gelu=gelu, 96 | FF=FF)] 97 | 98 | self.encoders = nn.ModuleList(encoders) 99 | self.norm = nn.LayerNorm(self.pts_embedding_dim) 100 | self.linear0 = nn.Linear(self.pts_embedding_dim, feature_dim // 2) 101 | 102 | self.max_pool = nn.functional.max_pool1d # kernel_size will be <= seq_len 103 | self.avg_pool = nn.functional.avg_pool1d # kernel_size will be <= seq_len 104 | 105 | def forward(self, pc, mask=None): 106 | """ 107 | Forward pass. 108 | 109 | :param pc: (Tensor) Point cloud tensor with shape (n_clouds, seq_len, pts_dim) 110 | :param mask: (Tensor) Mask tensor with shape (batch_size, seq_len, seq_len). Optional. 
111 | :return: (Tensor) Output features with shape (batch_size, features_dim) 112 | """ 113 | n_clouds, seq_len = pc.shape[0], pc.shape[1] 114 | 115 | pc_embd = self.embedding(pc) 116 | for encoder in self.encoders: 117 | pc_embd = encoder(pc_embd, mask=mask) 118 | features = self.norm(pc_embd) 119 | 120 | # Linear transformation and pooling operations for downstream tasks 121 | features = self.linear0(features) 122 | 123 | features = features.transpose(dim0=-1, dim1=-2) 124 | features = torch.cat((self.max_pool(input=features, kernel_size=seq_len), 125 | self.avg_pool(input=features, kernel_size=seq_len)), dim=-2) 126 | 127 | features = features.view(n_clouds, self.feature_dim) 128 | 129 | return features 130 | 131 | 132 | class SconeOcc(nn.Module): 133 | def __init__(self, seq_len=2048, pts_dim=3, pts_embedding_dim=128, 134 | concatenate_input=True, 135 | n_code=2, n_heads=4, FF=True, gelu=True, 136 | global_feature_dim=512, 137 | n_scale=3, local_feature_dim=256, k_for_knn=16, 138 | x_dim=3, x_embedding_dim=512, 139 | n_harmonics=64, 140 | output_dim=1, 141 | dropout=None, 142 | offset=True): 143 | """ 144 | Main class for SCONE's occupancy probability prediction module. 145 | A neural model that predicts a vector field as an implicit function, depending on an input point cloud 146 | and view state harmonic features representing the history of camera poses. 147 | A Transformer with Multi-Scale Neighborhood features (MSN features) is used to encode the point cloud, 148 | depending on the query point x. 149 | 150 | :param seq_len: (int) Number of points in the input point cloud. 151 | :param pts_dim: (int) Dimension of points in the input point cloud. 152 | :param pts_embedding_dim: (int) Dimension of embedded point cloud. 153 | :param concatenate_input: (bool) If True, concatenates raw input points to the point embeddings 154 | in the point cloud. 155 | :param n_code: (int) Number of Multi-Head Self Attention units. 
156 | :param n_heads: (int) Number of heads in Multi-Head Self Attention units. 157 | :param FF: (bool) If True, the Transformer encoder applies a Feed Forward unit after 158 | each Multi-Head Self-Attention unit. 159 | :param gelu: (bool) If True, the model uses GELU non-linearity. Else, it uses ReLU. 160 | :param global_feature_dim: (int) Dimension of the point cloud global feature. 161 | :param n_scale: (int) Number of scales to compute neighborhood features in the point cloud. 162 | :param local_feature_dim: (int) Dimension of point cloud neighborhood features. 163 | :param k_for_knn: (int) Number of neighbors to use when computing a neighborhood feature. 164 | :param x_dim: (int) Dimension of query point x. 165 | :param x_embedding_dim: (int) Dimension of x embedding. 166 | :param n_harmonics: (int) Number of harmonics used to compute view_state harmonic features. 167 | :param output_dim: (int) Dimension of the output vector field. 168 | :param dropout: Dropout module to apply on computed embeddings. 169 | :param offset: (bool) If True, the model uses the offset between x and its neighbors rather than 170 | the coordinates of the neighbors to compute neighborhood features. 171 | This parameter should always be True, since it leads to far better performances. 
172 | """ 173 | super(SconeOcc, self).__init__() 174 | 175 | # Parameters 176 | self.seq_len = seq_len 177 | self.pts_dim = pts_dim 178 | self.pts_embedding_dim = pts_embedding_dim 179 | 180 | self.n_code = n_code 181 | self.n_heads = n_heads 182 | self.FF = FF 183 | self.gelu = gelu 184 | 185 | self.n_scale = n_scale 186 | 187 | self.x_dim = x_dim 188 | self.x_embedding_dim = x_embedding_dim 189 | 190 | self.output_dim = output_dim 191 | 192 | self.dropout = dropout 193 | 194 | self.encoding_dim = pts_embedding_dim 195 | 196 | self.k_for_knn = k_for_knn 197 | self.offset = offset 198 | if self.offset: 199 | print("Offset set to True.") 200 | 201 | self.global_feature_dim = global_feature_dim 202 | self.local_feature_dim = local_feature_dim 203 | self.all_feature_size = self.x_embedding_dim \ 204 | + self.n_scale * self.local_feature_dim \ 205 | + self.global_feature_dim + n_harmonics 206 | 207 | # Point cloud transformers 208 | self.global_transformer = PCTransformer(seq_len=seq_len, pts_dim=pts_dim, 209 | pts_embedding_dim=pts_embedding_dim, feature_dim=global_feature_dim, 210 | concatenate_input=concatenate_input, 211 | n_code=n_code, n_heads=n_heads, FF=FF, gelu=gelu, 212 | dropout=dropout) 213 | 214 | local_transformers = [] 215 | for i in range(n_scale): 216 | local_transformers += [PCTransformer(seq_len=k_for_knn, pts_dim=pts_dim, 217 | pts_embedding_dim=pts_embedding_dim, feature_dim=local_feature_dim, 218 | concatenate_input=concatenate_input, 219 | n_code=n_code, n_heads=n_heads, FF=FF, gelu=gelu, 220 | dropout=dropout)] 221 | self.local_transformers = nn.ModuleList(local_transformers) 222 | 223 | # X embedding 224 | self.x_embedding = XEmbedding(x_dim=x_dim, x_embedding_dim=x_embedding_dim, 225 | dropout=dropout, gelu=gelu) 226 | 227 | # Point cloud feature extraction 228 | self.max_pool = nn.functional.max_pool1d # kernel_size will be <= seq_len 229 | self.avg_pool = nn.functional.avg_pool1d # kernel_size will be <= seq_len 230 | 231 | # MLP for 
occupancy probability prediction 232 | self.linear1 = nn.Linear(self.all_feature_size, 512) 233 | self.linear2 = nn.Linear(512, 256) 234 | self.linear3 = nn.Linear(256, output_dim) 235 | 236 | # self.non_linear1 = nn.ReLU(inplace=False) 237 | # self.non_linear2 = nn.ReLU(inplace=False) 238 | # self.non_linear3 = nn.ReLU(inplace=False) 239 | 240 | if gelu: 241 | self.non_linear1 = nn.GELU() 242 | self.non_linear2 = nn.GELU() 243 | self.non_linear3 = nn.GELU() 244 | else: 245 | self.non_linear1 = nn.ReLU(inplace=False) 246 | self.non_linear2 = nn.ReLU(inplace=False) 247 | self.non_linear3 = nn.ReLU(inplace=False) 248 | 249 | def forward(self, pc, x, view_harmonics, mask=None): 250 | """ 251 | Forward pass. 252 | :param pc: (Tensor) Input point cloud tensor with shape (n_clouds, seq_len, pts_dim) 253 | :param x: (Tensor) Input query points tensor with shape (n_clouds, n_sample, x_dim) 254 | :param view_harmonics: (Tensor) View state harmonic features. 255 | Tensor with shape (n_clouds, n_sample, n_harmonics). 256 | :param mask: (Tensor) Mask tensor with shape (batch_size, seq_len, seq_len). Optional. 257 | :return: (Tensor) Output vector field values for each query point in x. 258 | Has shape (n_clouds, n_sample, output_dim) 259 | """ 260 | n_clouds, full_seq_len = pc.shape[0], pc.shape[1] 261 | n_sample = x.shape[1] 262 | 263 | # -----Point cloud global encoding----- 264 | # Down sampling point cloud for global embedding 265 | global_down_sampled_pc = pc[:, torch.randperm(pc.shape[1])[:self.seq_len]] 266 | seq_len = global_down_sampled_pc.shape[1] 267 | 268 | global_features = self.global_transformer(global_down_sampled_pc) 269 | 270 | # -----Point cloud local encoding----- 271 | # Computing down sampling factor 272 | if self.n_scale > 1: 273 | ds_factor = int(np.power(full_seq_len / (self.k_for_knn * 8), 1./(self.n_scale - 1))) 274 | if ds_factor == 0: 275 | # print("Problem: ds_factor=0 encountered. 
Taking ds_factor=2 as a default value.") 276 | ds_factor = 2 277 | else: 278 | ds_factor = 1 279 | 280 | # kNN computation for local embedding 281 | down_sampled_pc = pc 282 | local_transformed = [] 283 | for n_transformer in range(self.n_scale): 284 | local_transformer = self.local_transformers[n_transformer] 285 | # Get kNN points in down sampled pc 286 | local_pc, _, _ = get_knn_points(x, down_sampled_pc, self.k_for_knn) 287 | if self.offset: 288 | local_pc = local_pc - x.view(n_clouds, n_sample, 1, 3) 289 | 290 | # Compute features 291 | local_transformed += [local_transformer(local_pc.view(-1, self.k_for_knn, 3), mask=mask)] 292 | 293 | # Down sample pc 294 | ds_seq_len = down_sampled_pc.shape[1] 295 | # print("Ds seq len:", ds_seq_len) 296 | 297 | if n_transformer < self.n_scale-1: 298 | down_sampled_pc = down_sampled_pc[:, torch.randperm(ds_seq_len)[:ds_seq_len // ds_factor]] 299 | # print("DS pc:", down_sampled_pc.shape) 300 | 301 | if self.n_scale > 0: 302 | local_features = torch.cat(local_transformed, dim=-1) 303 | else: 304 | local_features = torch.zeros(n_clouds, n_sample, 0, device=pc.device) # pc.device also works on CPU, unlike get_device() 305 | local_features = local_features.view(n_clouds, n_sample, self.n_scale * self.local_feature_dim) 306 | 307 | # -----X encoding----- 308 | x_features = self.x_embedding(x) 309 | 310 | # -----Occupancy prediction----- 311 | global_features = global_features.view(n_clouds, 1, self.global_feature_dim).expand(-1, n_sample, -1) 312 | x_features = x_features.view(n_clouds, n_sample, self.x_embedding_dim) 313 | 314 | res = torch.cat((global_features, local_features, x_features, view_harmonics), dim=-1) 315 | res = self.non_linear1(self.linear1(res)) 316 | res = self.non_linear2(self.linear2(res)) 317 | res = self.non_linear3(self.linear3(res)) 318 | 319 | return res.view(n_clouds, n_sample, self.output_dim) 320 | -------------------------------------------------------------------------------- /SCONE/SconeVis.py: 
-------------------------------------------------------------------------------- 1 | from Attention import * 2 | from CustomGeometry import get_spherical_coords 3 | from spherical_harmonics import clear_spherical_harmonics_cache, get_spherical_harmonics 4 | 5 | 6 | class SconeVis(nn.Module): 7 | def __init__(self, 8 | pts_dim=4, seq_len=2048, pts_embedding_dim=256, 9 | n_heads=4, n_code=3, 10 | n_harmonics=64, 11 | max_harmonic_rank=8, 12 | FF=True, 13 | gelu=True, 14 | dropout=None, 15 | use_view_state=True, 16 | use_global_feature=True, 17 | view_state_mode="end", 18 | concatenate_input=True, 19 | k_for_knn=0, 20 | alt=False, 21 | use_sigmoid=True): 22 | """ 23 | Main class for SCONE's visibility prediction module. 24 | 25 | :param pts_dim: (int) Input dimension. Since SCONE processes clouds of 3D-points concatenated with 26 | their occupancy probability, pts_dim should be equal to 4. 27 | :param seq_len: (int) Maximal number of points in the cloud. 28 | :param pts_embedding_dim: (int) Dimension of points' embeddings. 29 | :param n_heads: (int) Number of heads in Multi-Head Self Attention units. 30 | :param n_code: (int) Number of Multi-Head Self Attention units. 31 | :param n_harmonics: (int) Number of harmonics to use to encode visibility gain functions. 32 | :param max_harmonic_rank: (int) Maximal harmonic rank for harmonic functions. 33 | :param FF: (bool) If True, Transformer encoder(s) apply a Feed Forward unit after 34 | each Multi-Head Self-Attention unit. 35 | :param gelu: (bool) If True, the model uses GELU non-linearity. Else, it uses ReLU. 36 | :param dropout: Dropout module to apply on computed features. 37 | :param use_view_state: (bool) If True, model uses view_state harmonics as additional features. 38 | :param use_global_feature: (bool) If True, model computes an additional global feature concatenated to each 39 | point's embedding before applying the Transformer encoder. 
40 | :param view_state_mode: (str) If view_state_mode=='start', view_state features are concatenated 41 | to the points' embeddings before applying the Transformer encoder. 42 | If view_state_mode=='end', view_state features are concatenated to the points' embeddings 43 | after applying the Transformer encoder. 44 | :param concatenate_input: (bool) If True, the model concatenates the raw input points to their initial embeddings. 45 | :param k_for_knn: (int) If > 0, the model computes embeddings for points based on their k nearest neighbors. 46 | :param alt: (bool) If True, uses an alternate architecture for the end of the network. 47 | :param use_sigmoid: (bool) If True, uses a sigmoid function on predicted visibility scores. 48 | """ 49 | super(SconeVis, self).__init__() 50 | 51 | # self.harmonicEmbedder = HarmonicEmbedding(n_harmonic_functions=30, omega0=0.1) 52 | # self.n_harmonic_functions = 30 53 | # self.harmonicEmbedder = HarmonicEmbedding(n_harmonic_functions=self.n_harmonic_functions, omega0=1) 54 | 55 | self.n_harmonics = n_harmonics 56 | 57 | self.pts_dim = pts_dim 58 | self.seq_len = seq_len 59 | self.pts_embedding_dim = pts_embedding_dim 60 | self.n_heads = n_heads 61 | self.n_code = n_code 62 | self.n_harmonics = n_harmonics 63 | self.max_harmonic_rank = max_harmonic_rank 64 | 65 | self.use_view_state = use_view_state 66 | self.use_global_feature = use_global_feature 67 | self.view_state_mode = view_state_mode 68 | 69 | self.alt = alt 70 | 71 | self.use_sigmoid = use_sigmoid 72 | if use_sigmoid: 73 | print("Use sigmoid in model.") 74 | else: 75 | print("Use ReLU for output in model.") 76 | 77 | # Input embedding 78 | if use_view_state and view_state_mode == "start": 79 | additional_feature_dim = n_harmonics 80 | else: 81 | additional_feature_dim = 0 82 | self.embedding = Embedding(pts_dim, pts_embedding_dim, gelu=gelu, 83 | global_feature=use_global_feature, 84 | additional_feature_dim=additional_feature_dim, 85 | concatenate_input=concatenate_input, 86 | 
k_for_knn=k_for_knn, 87 | dropout=None) 88 | 89 | # Encoder Backbone 90 | encoders = [] 91 | for i in range(n_code): 92 | encoders += [Encoder(seq_len=seq_len, 93 | embedding_dim=pts_embedding_dim, 94 | qk_dim=pts_embedding_dim//4, 95 | n_heads=n_heads, 96 | dropout=dropout, 97 | gelu=gelu, 98 | FF=FF)] 99 | self.encoders = nn.ModuleList(encoders) 100 | 101 | self.norm = nn.LayerNorm(pts_embedding_dim) 102 | 103 | # MLP for Harmonics prediction 104 | if not alt: 105 | fc1_input_dim = pts_embedding_dim 106 | inner_feature_factor = 4 107 | if use_view_state and view_state_mode == "end": 108 | inner_feature_factor = 3 # view harmonics (n_harmonics) are concatenated before fc2, giving 4 * n_harmonics 109 | else: 110 | fc1_input_dim = pts_embedding_dim + n_harmonics 111 | inner_feature_factor = 4 112 | 113 | self.fc1 = nn.Linear(fc1_input_dim, inner_feature_factor * n_harmonics) 114 | self.nonlinear1 = nn.GELU() 115 | 116 | self.fc2 = nn.Linear(4 * n_harmonics, 2 * n_harmonics) 117 | self.nonlinear2 = nn.GELU() 118 | 119 | self.fc3 = nn.Linear(2 * n_harmonics, n_harmonics) 120 | 121 | def forward(self, pts, mask=None, view_harmonics=None): 122 | """ 123 | Forward pass. 124 | :param pts: (Tensor) Input point cloud. Tensor with shape (n_clouds, seq_len, pts_dim) 125 | :param mask: (Tensor) Mask tensor with shape (batch_size, seq_len, seq_len). Optional. 126 | :param view_harmonics: (Tensor) View state harmonic features. Tensor with shape (n_clouds, seq_len, n_harmonics) 127 | :return: (Tensor) Visibility gain functions of each point as coordinates in spherical harmonics. 
128 | Has shape (n_clouds, seq_len, n_harmonics) 129 | """ 130 | n_clouds = len(pts) 131 | seq_len = pts.shape[1] 132 | 133 | # Input embedding 134 | if self.use_view_state and self.view_state_mode == "start": 135 | x = self.embedding(pts, additional_feature=view_harmonics) 136 | else: 137 | x = self.embedding(pts) 138 | 139 | # Applying Encoders 140 | for encoder in self.encoders: 141 | x = encoder(x, mask=mask) 142 | 143 | # Final normalization, and linear layer for downstream task 144 | res = self.norm(x) 145 | 146 | if not self.alt: 147 | res = self.nonlinear1(self.fc1(res)) 148 | 149 | # Concatenating view harmonics if needed, and final prediction 150 | if self.use_view_state and self.view_state_mode == "end": 151 | res = torch.cat((res, view_harmonics), dim=-1) 152 | res = self.nonlinear2(self.fc2(res)) 153 | res = self.fc3(res) 154 | else: 155 | res = torch.cat((res, view_harmonics), dim=-1) 156 | res = self.nonlinear1(self.fc1(res)) 157 | res = self.nonlinear2(self.fc2(res)) 158 | res = self.fc3(res) 159 | 160 | res = res.view(n_clouds, seq_len, self.n_harmonics) 161 | 162 | return res 163 | 164 | def compute_visibilities(self, pts, harmonics, X_cam): 165 | """ 166 | Computes the visibility gain of each point in pts for each camera in X_cam. 167 | :param pts: (Tensor) Input point cloud. Tensor with shape (n_clouds, seq_len, pts_dim) 168 | :param harmonics: (Tensor) Predicted visibility gain functions as coordinates in spherical harmonics. 169 | Has shape (n_clouds, seq_len, n_harmonics). 170 | :param X_cam: (Tensor) Tensor of camera centers' positions, with shape (n_clouds, n_camera_candidates, 3) 171 | :return: (Tensor) The predicted per-point visibility gains of all points. 
172 | Has shape (n_clouds, n_camera_candidates, seq_len) 173 | """ 174 | clear_spherical_harmonics_cache() 175 | n_clouds = pts.shape[0] 176 | seq_len = pts.shape[1] 177 | n_harmonics = self.n_harmonics 178 | n_camera_candidates = X_cam.shape[1] 179 | 180 | device = pts.get_device() 181 | if device < 0: 182 | device = "cpu" 183 | 184 | X_pts = pts[..., :3] 185 | 186 | # tmp_pts = X_pts.view(n_clouds, 1, seq_len, 3).expand(-1, n_camera_candidates, -1, -1) 187 | # tmp_h = harmonics.view(n_clouds, 1, seq_len, 64).expand(-1, n_camera_candidates, -1, -1) 188 | # tmp_cam = X_cam.view(n_clouds, n_camera_candidates, 1, 3).expand(-1, -1, seq_len, -1) 189 | 190 | rays = (X_cam.view(n_clouds, n_camera_candidates, 1, 3).expand(-1, -1, seq_len, -1) 191 | - X_pts.view(n_clouds, 1, seq_len, 3).expand(-1, n_camera_candidates, -1, -1)).view(-1, 3) 192 | _, theta, phi = get_spherical_coords(rays) 193 | theta = -theta + np.pi / 2. 194 | 195 | z = torch.zeros([i for i in theta.shape] + [0], device=device) 196 | for i in range(self.max_harmonic_rank): 197 | y = get_spherical_harmonics(l=i, theta=theta, phi=phi) 198 | z = torch.cat((z, y), dim=-1) 199 | z = z.view(n_clouds, n_camera_candidates, seq_len, n_harmonics) 200 | 201 | z = torch.sum(z * harmonics.view(n_clouds, 1, seq_len, n_harmonics).expand(-1, n_camera_candidates, -1, -1), dim=-1) 202 | if self.use_sigmoid: 203 | z = torch.sigmoid(z) 204 | else: 205 | z = torch.relu(z) 206 | # z = torch.sum(z, dim=-1) / seq_len 207 | 208 | return z 209 | 210 | def compute_coverage_gain(self, pts, harmonics, X_cam): 211 | """ 212 | Computes global coverage gain for each camera candidate in X_cam. 
213 | :param pts: tensor with shape (n_clouds, seq_len, 3 or 4) 214 | :param harmonics: tensor with shape (n_clouds, seq_len, n_harmonics) 215 | :param X_cam: tensor with shape (n_clouds, n_camera_candidates, 3) 216 | :return: A tensor z with shape (n_clouds, n_camera_candidates) 217 | """ 218 | clear_spherical_harmonics_cache() 219 | n_clouds = pts.shape[0] 220 | seq_len = pts.shape[1] 221 | n_harmonics = self.n_harmonics 222 | n_camera_candidates = X_cam.shape[1] 223 | 224 | X_pts = pts[..., :3] 225 | 226 | # tmp_pts = X_pts.view(n_clouds, 1, seq_len, 3).expand(-1, n_camera_candidates, -1, -1) 227 | # tmp_h = harmonics.view(n_clouds, 1, seq_len, 64).expand(-1, n_camera_candidates, -1, -1) 228 | # tmp_cam = X_cam.view(n_clouds, n_camera_candidates, 1, 3).expand(-1, -1, seq_len, -1) 229 | 230 | rays = (X_cam.view(n_clouds, n_camera_candidates, 1, 3).expand(-1, -1, seq_len, -1) 231 | - X_pts.view(n_clouds, 1, seq_len, 3).expand(-1, n_camera_candidates, -1, -1)).view(-1, 3) 232 | _, theta, phi = get_spherical_coords(rays) 233 | theta = -theta + np.pi/2. 234 | 235 | z = torch.zeros([i for i in theta.shape] + [0], device=theta.device) 236 | for i in range(self.max_harmonic_rank): 237 | y = get_spherical_harmonics(l=i, theta=theta, phi=phi) 238 | z = torch.cat((z, y), dim=-1) 239 | z = z.view(n_clouds, n_camera_candidates, seq_len, n_harmonics) 240 | 241 | z = torch.sum(z * harmonics.view(n_clouds, 1, seq_len, n_harmonics).expand(-1, n_camera_candidates, -1, -1), dim=-1) 242 | if self.use_sigmoid: 243 | z = torch.sigmoid(z) 244 | else: 245 | z = torch.relu(z) 246 | 247 | # TO REMOVE 248 | # z[pts[..., 3].view(n_clouds, 1, seq_len).expand(-1, n_camera_candidates, -1) > 0.9] = 0 249 | 250 | z = torch.sum(z, dim=-1) / seq_len 251 | 252 | return z 253 | 254 | def compute_coverage_gain_multiple(self, pts, harmonics, X_cam, n_cam): 255 | """ 256 | Computes global coverage gains for each n_cam-subset of camera candidates in X_cam. 
257 | :param pts: tensor with shape (n_clouds, seq_len, 3 or 4) 258 | :param harmonics: tensor with shape (n_clouds, seq_len, n_harmonics) 259 | :param X_cam: tensor with shape (n_clouds, n_camera_candidates, 3) 260 | :param n_cam: number of simultaneous NBV to select. Must be 2 or 3. 261 | :return: A tuple (n_z, n_idx). n_z has shape (n_clouds, n_camera_candidates ** n_cam), and n_idx has shape (n_camera_candidates ** n_cam, n_cam). 262 | """ 263 | clear_spherical_harmonics_cache() 264 | n_clouds = pts.shape[0] 265 | seq_len = pts.shape[1] 266 | n_harmonics = self.n_harmonics 267 | n_camera_candidates = X_cam.shape[1] 268 | 269 | X_pts = pts[..., :3] 270 | 271 | # tmp_pts = X_pts.view(n_clouds, 1, seq_len, 3).expand(-1, n_camera_candidates, -1, -1) 272 | # tmp_h = harmonics.view(n_clouds, 1, seq_len, 64).expand(-1, n_camera_candidates, -1, -1) 273 | # tmp_cam = X_cam.view(n_clouds, n_camera_candidates, 1, 3).expand(-1, -1, seq_len, -1) 274 | 275 | rays = (X_cam.view(n_clouds, n_camera_candidates, 1, 3).expand(-1, -1, seq_len, -1) 276 | - X_pts.view(n_clouds, 1, seq_len, 3).expand(-1, n_camera_candidates, -1, -1)).view(-1, 3) 277 | _, theta, phi = get_spherical_coords(rays) 278 | theta = -theta + np.pi/2. 
279 | 280 | z = torch.zeros([i for i in theta.shape] + [0], device=theta.device) 281 | for i in range(self.max_harmonic_rank): 282 | y = get_spherical_harmonics(l=i, theta=theta, phi=phi) 283 | z = torch.cat((z, y), dim=-1) 284 | z = z.view(n_clouds, n_camera_candidates, seq_len, n_harmonics) 285 | 286 | z = torch.sum(z * harmonics.view(n_clouds, 1, seq_len, n_harmonics).expand(-1, n_camera_candidates, -1, -1), dim=-1) 287 | if self.use_sigmoid: 288 | z = torch.sigmoid(z) 289 | else: 290 | z = torch.relu(z) 291 | 292 | single_idx = torch.arange(0, n_camera_candidates) 293 | if n_cam == 2: 294 | n_idx = torch.cartesian_prod(single_idx, single_idx) 295 | elif n_cam == 3: 296 | n_idx = torch.cartesian_prod(single_idx, single_idx, single_idx) 297 | else: 298 | raise ValueError("n_cam must be 2 or 3.") 299 | 300 | n_z = z[:, n_idx] 301 | n_z = torch.sum(torch.max(n_z, dim=-2)[0], dim=-1) / seq_len 302 | 303 | return n_z, n_idx 304 | 305 | 306 | class KLDivCE(nn.Module): 307 | """ 308 | Layer to compute KL-divergence after applying Softmax 309 | """ 310 | 311 | def __init__(self): 312 | super(KLDivCE, self).__init__() 313 | 314 | self.kl_div = nn.KLDivLoss(reduction='batchmean') 315 | self.log_soft_max = nn.LogSoftmax(dim=1) 316 | 317 | def forward(self, x, y): 318 | loss = self.kl_div(self.log_soft_max(x), torch.softmax(y, dim=1)) 319 | return loss 320 | 321 | 322 | class L1_loss(nn.Module): 323 | """ 324 | Layer to compute L1 loss between normalized coverage distributions 325 | """ 326 | 327 | def __init__(self): 328 | super(L1_loss, self).__init__() 329 | self.epsilon = 1e-7 330 | 331 | def forward(self, x, y): 332 | """ 333 | 334 | :param x: (Tensor) Should have shape (batch_size, n_camera, 1) 335 | :param y: (Tensor) Should have shape (batch_size, n_camera, 1) 336 | :return: 337 | """ 338 | batch_size, n_camera = x.shape[0], x.shape[1] 339 | x_mean = x.mean(dim=1, keepdim=True).expand(-1, n_camera, -1) 340 | y_mean = y.mean(dim=1, keepdim=True).expand(-1, n_camera, -1) 
341 | 342 | x_std = x.std(dim=1, keepdim=True).expand(-1, n_camera, -1) 343 | y_std = y.std(dim=1, keepdim=True).expand(-1, n_camera, -1) 344 | 345 | norm_x = (x-x_mean) / (x_std + self.epsilon) 346 | norm_y = (y-y_mean) / (y_std + self.epsilon) 347 | 348 | loss = (norm_x - norm_y).abs().mean(dim=1) 349 | 350 | return loss.mean() 351 | 352 | 353 | class Uncentered_L1_loss(nn.Module): 354 | """ 355 | Layer to compute L1 loss between normalized coverage distributions 356 | """ 357 | 358 | def __init__(self): 359 | super(Uncentered_L1_loss, self).__init__() 360 | self.epsilon = 1e-7 361 | 362 | def forward(self, x, y): 363 | """ 364 | 365 | :param x: (Tensor) Should have shape (batch_size, n_camera, 1) 366 | :param y: (Tensor) Should have shape (batch_size, n_camera, 1) 367 | :return: 368 | """ 369 | batch_size, n_camera = x.shape[0], x.shape[1] 370 | x_mean = x.mean(dim=1, keepdim=True).expand(-1, n_camera, -1) 371 | y_mean = y.mean(dim=1, keepdim=True).expand(-1, n_camera, -1) 372 | 373 | norm_x = x / (x_mean + self.epsilon) 374 | norm_y = y / (y_mean + self.epsilon) 375 | 376 | loss = (norm_x - norm_y).abs().mean(dim=1) 377 | 378 | return loss.mean() 379 | -------------------------------------------------------------------------------- /SCONE/idr_torch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import os 5 | # import hostlist 6 | 7 | # Dummy script to avoid import errors 8 | rank = 0 9 | local_rank = 0 10 | size = 4 11 | cpus_per_task = 10 12 | hostnames = [] 13 | gpus_ids = [0] 14 | 15 | # # get SLURM variables 16 | # rank = int(os.environ['SLURM_PROCID']) 17 | # local_rank = int(os.environ['SLURM_LOCALID']) 18 | # size = int(os.environ['SLURM_NTASKS']) 19 | # cpus_per_task = int(os.environ['SLURM_CPUS_PER_TASK']) 20 | # 21 | # # get node list from slurm 22 | # hostnames = hostlist.expand_hostlist(os.environ['SLURM_JOB_NODELIST']) 23 | # 24 | # # get IDs of reserved GPU 
25 | # gpu_ids = os.environ['SLURM_STEP_GPUS'].split(",") 26 | # 27 | # # define MASTER_ADD & MASTER_PORT 28 | # os.environ['MASTER_ADDR'] = hostnames[0] 29 | # os.environ['MASTER_PORT'] = str(12345 + int(min(gpu_ids))) # to avoid port conflict on the same node -------------------------------------------------------------------------------- /SCONE/pretrain_scone_occ.py: -------------------------------------------------------------------------------- 1 | # Updated from former script train_proba_model_faster.py 2 | from scone_utils import * 3 | import sys 4 | 5 | save_parameters = True 6 | start_training = False 7 | 8 | debug = False 9 | 10 | def save_train_params(save=False): 11 | # TO DO: for ddp, save only if is_master but load for everyone, with a synchronization in-between. 12 | params = {} 13 | 14 | # -----General parameters----- 15 | params["ddp"] = False 16 | params["jz"] = True 17 | 18 | if params["ddp"]: 19 | params["CUDA_VISIBLE_DEVICES"] = "0, 1" 20 | params["WORLD_SIZE"] = 2 21 | 22 | elif params["jz"]: 23 | params["WORLD_SIZE"] = idr_torch.size 24 | 25 | else: 26 | params["numGPU"] = 0 27 | params["WORLD_SIZE"] = 1 28 | 29 | params["anomaly_detection"] = True 30 | params["empty_cache_every_n_batch"] = 10 31 | 32 | # -----Ground truth computation parameters----- 33 | params["compute_gt_online"] = False 34 | params["compute_partial_point_cloud_online"] = False 35 | 36 | params["gt_surface_resolution"] = 1.5 37 | 38 | params["gt_max_diagonal"] = 1. # Formerly known as x_range 39 | 40 | params["n_points_surface"] = 16384 # To remove? 
41 | 42 | # -----Model Parameters----- 43 | params["seq_len"] = 2048 44 | params["n_sample"] = 6000 # 12000, 100000 45 | 46 | params["view_state_n_elev"] = 7 47 | params["view_state_n_azim"] = 2 * 7 48 | params["harmonic_degree"] = 8 49 | 50 | params["min_occ"] = 0.01 51 | 52 | # Ablation study 53 | params["no_local_features"] = False # False 54 | params["no_view_harmonics"] = False # False 55 | # params["noise_std"] = 5 * np.sqrt(3) * params["side"] 56 | 57 | # -----General training parameters----- 58 | params["start_from_scratch"] = True 59 | params["pretrained_weights_name"] = None 60 | 61 | params["n_view_max"] = 5 62 | params["n_view_min"] = 1 63 | params["n_point_max_for_prediction"] = 300000 # 300000 64 | 65 | params["camera_dist"] = 1.5 66 | params["pole_cameras"] = True 67 | params["n_camera_elev"] = 5 68 | params["n_camera_azim"] = 2 * 5 69 | params["n_camera"] = params["n_camera_elev"] * params["n_camera_azim"] 70 | if params["pole_cameras"]: 71 | params["n_camera"] += 2 72 | 73 | params["prediction_in_random_camera_space"] = False 74 | 75 | params["total_batch_size"] = 12 # 12 76 | params["batch_size"] = params["total_batch_size"] // params["WORLD_SIZE"] 77 | params["total_batch_size"] = params["batch_size"] * params["WORLD_SIZE"] 78 | 79 | params["epochs"] = 1000 80 | params["learning_rate"] = 1e-4 81 | 82 | params["schedule_learning_rate"] = True 83 | if params["schedule_learning_rate"]: 84 | params["lr_epochs"] = [250] 85 | params["lr_factor"] = 0.1 86 | 87 | params["warmup"] = 1000 88 | params["warmup_rate"] = 1 / (params["warmup"] * params["learning_rate"] ** 2) 89 | 90 | params["noam_opt"] = False 91 | params["training_loss"] = "mse" # "mse" 92 | params["multiply_loss"] = False 93 | if params["multiply_loss"]: 94 | params["loss_multiplication_factor"] = 10. 
95 | 96 | params["random_seed"] = 42 97 | params["torch_seed"] = 5 98 | 99 | # -----Model name to save----- 100 | model_name = "model_scone_occ" 101 | 102 | if params["no_local_features"]: 103 | model_name += "_no_local_features" 104 | 105 | if params["no_view_harmonics"]: 106 | model_name += "_no_view_harmonics" 107 | 108 | if params["ddp"]: 109 | model_name = "ddp_" + model_name 110 | elif params["jz"]: 111 | model_name = "jz_" + model_name 112 | 113 | model_name += "_" + params["training_loss"] 114 | 115 | if params["noam_opt"]: 116 | model_name += "noam_" 117 | model_name += "_warmup_" + str(params["warmup"]) 118 | 119 | if params["schedule_learning_rate"]: 120 | model_name += "_schedule" 121 | model_name += "_lr_" + str(params["learning_rate"]) 122 | 123 | if debug: 124 | model_name = "debug_" + model_name 125 | 126 | if params["prediction_in_random_camera_space"]: 127 | model_name += "_random_rot" 128 | 129 | params["scone_occ_model_name"] = model_name 130 | 131 | # -----Json name to save params----- 132 | json_name = "train_params_" + params["scone_occ_model_name"] + ".json" 133 | 134 | if save: 135 | with open(json_name, 'w') as outfile: 136 | json.dump(params, outfile) 137 | 138 | print("Parameters saved in:") 139 | print(json_name) 140 | 141 | return json_name 142 | 143 | 144 | def loop(params, 145 | batch, mesh_dict, 146 | scone_occ, occ_loss_fn, 147 | device, is_master, 148 | n_views_list=None 149 | ): 150 | paths = mesh_dict['path'] 151 | 152 | pred_occs = torch.zeros(0, 1, device=device) 153 | truth_occs = torch.zeros(0, 1, device=device) 154 | 155 | base_harmonics, h_polar, h_azim = get_all_harmonics_under_degree(params.harmonic_degree, 156 | params.view_state_n_elev, 157 | params.view_state_n_azim, 158 | device) 159 | 160 | batch_size = len(paths) 161 | 162 | # Loading, if provided, view sequences (useful for consistent validation) 163 | if n_views_list is None: 164 | n_views = np.random.randint(params.n_view_min, params.n_view_max + 1, batch_size) 
165 | else: 166 | n_views = get_validation_n_view(params, n_views_list, batch, idr_torch.rank) 167 | 168 | for i in range(batch_size): 169 | # ----------Load input mesh and ground truth data--------------------------------------------------------------- 170 | 171 | path_i = paths[i] 172 | # Loading info about partial point clouds and coverages 173 | part_pc, _ = get_gt_partial_point_clouds(path=path_i, 174 | normalization_factor=1./params.gt_surface_resolution, 175 | device=device) 176 | # Loading info about ground truth occupancy field 177 | X_world, occs = get_gt_occupancy_field(path=path_i, device=device) 178 | 179 | # ----------Set camera positions associated to partial point clouds--------------------------------------------- 180 | 181 | # Positions are loaded in world coordinates 182 | X_cam_world, camera_dist, camera_elev, camera_azim = get_cameras_on_sphere(params, device, 183 | pole_cameras=params.pole_cameras) 184 | 185 | # ----------Select initial observations of the object----------------------------------------------------------- 186 | 187 | # Select a subset of n_view cameras to compute an initial point cloud 188 | n_view = n_views[i] 189 | view_idx = torch.randperm(len(camera_elev), device=device)[:n_view] 190 | 191 | # Select either first camera view space, or random camera view space as prediction view space 192 | if params.prediction_in_random_camera_space: 193 | prediction_cam_idx = np.random.randint(low=0, high=len(camera_elev)) 194 | else: 195 | prediction_cam_idx = view_idx[0] 196 | prediction_box_center = torch.Tensor([0., 0., params.camera_dist]).to(device) 197 | 198 | # Move camera coordinates from world space to prediction view space, and normalize them for prediction box 199 | prediction_R, prediction_T = look_at_view_transform(dist=camera_dist[prediction_cam_idx], 200 | elev=camera_elev[prediction_cam_idx], 201 | azim=camera_azim[prediction_cam_idx], 202 | device=device) 203 | prediction_camera = FoVPerspectiveCameras(device=device, 
R=prediction_R, T=prediction_T) 204 | prediction_view_transform = prediction_camera.get_world_to_view_transform() 205 | 206 | X_cam = prediction_view_transform.transform_points(X_cam_world) 207 | X_cam = normalize_points_in_prediction_box(points=X_cam, 208 | prediction_box_center=prediction_box_center, 209 | prediction_box_diag=params.gt_max_diagonal) 210 | _, elev_cam, azim_cam = get_spherical_coords(X_cam) 211 | 212 | X_view = X_cam[view_idx] 213 | 214 | # ----------Capture initial observations------------------------------------------------------------------------ 215 | 216 | # Points observed in initial views 217 | pc = torch.vstack([part_pc[pc_idx] for pc_idx in view_idx]) 218 | 219 | # Downsampling partial point cloud 220 | pc = pc[torch.randperm(len(pc))[:n_view * params.seq_len]] 221 | 222 | # Move partial point cloud from world space to prediction view space, and normalize them in prediction box 223 | pc = prediction_view_transform.transform_points(pc) 224 | pc = normalize_points_in_prediction_box(points=pc, 225 | prediction_box_center=prediction_box_center, 226 | prediction_box_diag=params.gt_max_diagonal).view(1, -1, 3) 227 | 228 | # ----------Compute inputs to SconeOcc-------------------------------------------------------------------------- 229 | 230 | # Sample random proxy points in space 231 | X_idx = torch.randperm(len(X_world))[:params.n_sample] 232 | X_world, occs = X_world[X_idx], occs[X_idx] 233 | 234 | # Move proxy points from world space to prediction view space, and normalize them in prediction box 235 | X = prediction_view_transform.transform_points(X_world) 236 | X = normalize_points_in_prediction_box(points=X, 237 | prediction_box_center=prediction_box_center, 238 | prediction_box_diag=params.gt_max_diagonal).view(1, params.n_sample, 3) 239 | 240 | # Compute view state vector and corresponding view harmonics 241 | view_state = compute_view_state(X, X_view, 242 | params.view_state_n_elev, params.view_state_n_azim) 243 | view_harmonics = 
compute_view_harmonics(view_state, 244 | base_harmonics, h_polar, h_azim, 245 | params.view_state_n_elev, params.view_state_n_azim) 246 | if params.no_view_harmonics: 247 | view_harmonics *= 0. 248 | 249 | # ----------Predict Occupancy Probability----------------------------------------------------------------------- 250 | pred_i = scone_occ(pc, X, view_harmonics).view(-1, 1) 251 | pred_occs = torch.vstack((pred_occs, pred_i)) 252 | 253 | # ----------GT Occupancy Probability---------------------------------------------------------------------------- 254 | truth_occs = torch.vstack((truth_occs, occs)) 255 | 256 | # ----------Compute Loss-------------------------------------------------------------------------------------------- 257 | loss = occ_loss_fn(pred_occs, truth_occs) 258 | if params.multiply_loss: 259 | loss *= params.loss_multiplication_factor 260 | 261 | if batch % params.empty_cache_every_n_batch == 0 and is_master: 262 | print("View state sum-mean:", torch.mean(torch.sum(view_state, dim=-1))) 263 | 264 | return loss, pred_occs, truth_occs, batch_size, n_view 265 | 266 | 267 | def train(params, 268 | dataloader, 269 | scone_occ, occ_loss_fn, 270 | optimizer, 271 | device, is_master, 272 | train_losses): 273 | 274 | num_batches = len(dataloader) 275 | size = num_batches * params.total_batch_size 276 | train_loss = 0. 
277 | 278 | # Set the occupancy model to training mode 279 | scone_occ.train() 280 | 281 | t0 = time.time() 282 | 283 | for batch, (mesh_dict) in enumerate(dataloader): 284 | 285 | loss, pred, truth, batch_size, n_screen_cameras = loop(params, 286 | batch, mesh_dict, 287 | scone_occ, occ_loss_fn, 288 | device, is_master) 289 | 290 | # Backpropagation 291 | optimizer.zero_grad() 292 | loss.backward() 293 | optimizer.step() 294 | 295 | batch_loss = loss.detach() # Undo the loss scaling per batch, not on the running total 296 | if params.multiply_loss: 297 | batch_loss = batch_loss / params.loss_multiplication_factor 298 | train_loss += batch_loss 299 | if batch % params.empty_cache_every_n_batch == 0: 300 | 301 | # loss = reduce_tensor(loss) 302 | if params.ddp or params.jz: 303 | loss = reduce_tensor(loss, world_size=params.WORLD_SIZE) 304 | loss = to_python_float(loss) 305 | 306 | current = batch * batch_size # * idr_torch.size 307 | if params.ddp or params.jz: 308 | current *= params.WORLD_SIZE 309 | 310 | truth_norm = to_python_float(torch.linalg.norm(truth.detach())) 311 | pred_norm = to_python_float(torch.linalg.norm(pred.detach())) 312 | 313 | # torch.cuda.synchronize() 314 | 315 | if is_master: 316 | print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]", 317 | "computed in", (time.time() - t0) / 60., "minutes.") 318 | print(">>>Prediction shape:", pred.shape, 319 | "\n>>>Truth norm:", truth_norm, ">>>Prediction norm:", pred_norm, 320 | "\nNumber of cameras:", n_screen_cameras) 321 | # print("Harmonics:\n", info_harmonics_i) 322 | # TO REMOVE 323 | torch.cuda.empty_cache() 324 | 325 | if batch % params.empty_cache_every_n_batch == 0: 326 | t0 = time.time() 327 | 328 | # train_loss = reduce_tensor(train_loss) 329 | if params.ddp or params.jz: 330 | train_loss = reduce_tensor(train_loss, world_size=params.WORLD_SIZE) 331 | train_loss = to_python_float(train_loss) 332 | train_loss /= num_batches 333 | train_losses.append(train_loss) 334 | 335 | 336 | def validation(params, 337 | dataloader, 338 | occ_model, occ_loss_fn, 339 | device, is_master, 340 | val_losses): 341 | 
342 | num_batches = len(dataloader) 343 | size = num_batches * params.total_batch_size 344 | val_loss = 0. 345 | 346 | # Set the occupancy model to evaluation mode 347 | occ_model.eval() 348 | 349 | t0 = time.time() 350 | 351 | n_views_list = get_validation_n_views_list(params, dataloader) 352 | 353 | for batch, (mesh_dict) in enumerate(dataloader): 354 | with torch.no_grad(): 355 | loss, pred, truth, batch_size, _ = loop(params, 356 | batch, mesh_dict, 357 | occ_model, occ_loss_fn, 358 | device, is_master, 359 | n_views_list=n_views_list) 360 | 361 | batch_loss = loss.detach() # Undo the loss scaling per batch, not on the running total 362 | if params.multiply_loss: 363 | batch_loss = batch_loss / params.loss_multiplication_factor 364 | val_loss += batch_loss 365 | if batch % params.empty_cache_every_n_batch == 0: 366 | torch.cuda.empty_cache() 367 | 368 | # val_loss = reduce_tensor(val_loss) 369 | if params.ddp or params.jz: 370 | val_loss = reduce_tensor(val_loss, world_size=params.WORLD_SIZE) 371 | 372 | val_loss = to_python_float(val_loss) 373 | val_loss /= num_batches 374 | val_losses.append(val_loss) 375 | 376 | if is_master: 377 | print(f"Validation Error: \n Avg loss: {val_loss:>8f} \n") 378 | 379 | 380 | def run(ddp_rank=None, params=None): 381 | # Set device 382 | device = setup_device(params, ddp_rank) 383 | 384 | batch_size = params.batch_size 385 | total_batch_size = params.total_batch_size 386 | 387 | if params.ddp: 388 | world_size = params.WORLD_SIZE 389 | rank = ddp_rank 390 | is_master = rank == 0 391 | elif params.jz: 392 | world_size = idr_torch.size 393 | rank = idr_torch.rank 394 | is_master = rank == 0 395 | else: 396 | world_size, rank = None, None 397 | is_master = True 398 | 399 | # Create dataloader 400 | train_dataloader, val_dataloader, _ = get_shapenet_dataloader(batch_size=params.batch_size, 401 | ddp=params.ddp, jz=params.jz, 402 | world_size=world_size, ddp_rank=rank, 403 | load_obj=False, 404 | data_path=None) 405 | 406 | 407 | # Initialize or Load models 408 | if params.no_local_features: 409 | raise NotImplementedError("no_local_features mode is not 
implemented yet.") 410 | else: 411 | scone_occ = SconeOcc().to(device) 412 | 413 | # Initialize information model (and DDP wrap if needed) 414 | scone_occ, optimizer, opt_name, start_epoch, best_loss = initialize_scone_occ(params=params, 415 | scone_occ=scone_occ, 416 | device=device, 417 | torch_seed=params.torch_seed, 418 | load_pretrained_weights=not params.start_from_scratch, 419 | pretrained_weights_name=params.pretrained_weights_name, 420 | ddp_rank=rank) 421 | 422 | best_train_loss = 1000 423 | epochs_without_improvement = 0 424 | learning_rate = params.learning_rate 425 | 426 | # Set loss function 427 | occ_loss_fn = get_occ_loss_fn(params) 428 | 429 | if is_master: 430 | print("Model name:", params.scone_occ_model_name, "\nArchitecture:\n") 431 | print(scone_occ) 432 | print("Model name:", params.scone_occ_model_name) 433 | print("Number of trainable parameters:", count_parameters(scone_occ)) 434 | print("Using", opt_name, "optimizer.") 435 | 436 | if params.training_loss == "cross_entropy": 437 | print("Using soft cross entropy loss.") 438 | elif params.training_loss == "mse": 439 | print("Using MSE loss.") 440 | 441 | print("Using", params.n_camera, "uniformly sampled camera positions per mesh.") 442 | 443 | print("Training data:", len(train_dataloader), "batches.") 444 | print("Validation data:", len(val_dataloader), "batches.") 445 | print("Batch size:", params.total_batch_size) 446 | print("Batch size per GPU:", params.batch_size) 447 | 448 | # Begin training process 449 | train_losses = [] 450 | val_losses = [] 451 | val_coverages = [] 452 | 453 | t0 = time.time() 454 | for t_e in range(params.epochs): 455 | t = start_epoch + t_e 456 | if is_master: 457 | print(f"Epoch {t + 1}\n-------------------------------") 458 | torch.cuda.empty_cache() 459 | 460 | if params.schedule_learning_rate: 461 | if t in params.lr_epochs: 462 | print("Multiplying learning rate by", params.lr_factor) 463 | learning_rate *= params.lr_factor 464 | 465 | 
update_learning_rate(params, optimizer, learning_rate) 466 | 467 | print("Max learning rate set to", learning_rate) 468 | print("Current learning rate set to", optimizer._rate) 469 | 470 | train_dataloader.sampler.set_epoch(t) 471 | 472 | scone_occ.train() 473 | train(params, 474 | train_dataloader, 475 | scone_occ, occ_loss_fn, 476 | optimizer, 477 | device, is_master, 478 | train_losses) 479 | 480 | current_loss = train_losses[-1] 481 | 482 | if is_master: 483 | print("Training done for epoch", t + 1, ".") 484 | # torch.save(model, "unvalidated_" + model_name + ".pth") 485 | torch.save({ 486 | 'epoch': t + 1, 487 | 'model_state_dict': scone_occ.state_dict(), 488 | 'optimizer_state_dict': optimizer.state_dict(), 489 | 'loss': current_loss, 490 | 'train_losses': train_losses, 491 | 'val_losses': val_losses, 492 | }, "unvalidated_" + params.scone_occ_model_name + ".pth") 493 | 494 | if current_loss < best_train_loss: 495 | torch.save({ 496 | 'epoch': t + 1, 497 | 'model_state_dict': scone_occ.state_dict(), 498 | 'optimizer_state_dict': optimizer.state_dict(), 499 | 'loss': current_loss, 500 | 'train_losses': train_losses, 501 | 'val_losses': val_losses, 502 | }, "best_unval_" + params.scone_occ_model_name + ".pth") 503 | best_train_loss = current_loss 504 | print("Best model on training set saved with loss " + str(current_loss) + " .\n") 505 | 506 | torch.cuda.empty_cache() 507 | 508 | if is_master: 509 | print("Beginning evaluation on validation dataset...") 510 | # val_dataloader.sampler.set_epoch(t) 511 | 512 | scone_occ.eval() 513 | validation(params, 514 | val_dataloader, 515 | scone_occ, occ_loss_fn, 516 | device, is_master, 517 | val_losses 518 | ) 519 | 520 | current_val_loss = val_losses[-1] 521 | if current_val_loss < best_loss: 522 | # torch.save(model, "validated_" + model_name + ".pth") 523 | if is_master: 524 | torch.save({ 525 | 'epoch': t + 1, 526 | 'model_state_dict': scone_occ.state_dict(), 527 | 'optimizer_state_dict': optimizer.state_dict(), 528 
| 'loss': current_val_loss, 529 | 'train_losses': train_losses, 530 | 'val_losses': val_losses 531 | }, "validated_" + params.scone_occ_model_name + ".pth") 532 | print("Model saved with loss " + str(current_val_loss) + " .\n") 533 | best_loss = val_losses[-1] 534 | epochs_without_improvement = 0 535 | else: 536 | epochs_without_improvement += 1 537 | 538 | # Save data about losses 539 | if is_master: 540 | losses_data = {} 541 | losses_data['train_loss'] = train_losses 542 | losses_data['val_loss'] = val_losses 543 | json_name = "losses_data_" + params.scone_occ_model_name + ".json" 544 | with open(json_name, 'w') as outfile: 545 | json.dump(losses_data, outfile) 546 | print("Saved data about losses in", json_name, ".") 547 | 548 | if is_master: 549 | print("Done in", (time.time() - t0) / 3600., "hours!") 550 | 551 | # Save data about losses 552 | losses_data = {} 553 | losses_data['train_loss'] = train_losses 554 | losses_data['val_loss'] = val_losses 555 | json_name = "losses_data_" + params.scone_occ_model_name + ".json" 556 | with open(json_name, 'w') as outfile: 557 | json.dump(losses_data, outfile) 558 | print("Saved data about losses in", json_name, ".") 559 | 560 | if params.ddp or params.jz: 561 | cleanup() 562 | 563 | 564 | if __name__ == "__main__": 565 | # Save and load parameters 566 | json_name = save_train_params(save=save_parameters) 567 | print("Loaded parameters stored in", json_name) 568 | 569 | if (not save_parameters) and (len(sys.argv) > 1): 570 | json_name = sys.argv[1] 571 | print("Using json name given in argument:") 572 | print(json_name) 573 | 574 | if start_training: 575 | params = load_params(json_name) 576 | 577 | if params.ddp: 578 | mp.spawn(run, 579 | args=(params, 580 | ), 581 | nprocs=params.WORLD_SIZE 582 | ) 583 | 584 | elif params.jz: 585 | run(params=params) 586 | 587 | else: 588 | run(params=params) 589 | -------------------------------------------------------------------------------- /SCONE/pretrain_scone_vis.py: 
-------------------------------------------------------------------------------- 1 | # Updated from former script train_discoveries_faster.py 2 | 3 | from scone_utils import * 4 | import sys 5 | 6 | save_parameters = True 7 | start_training = False 8 | 9 | debug = False 10 | 11 | def save_train_params(save=False): 12 | # TO DO: for ddp, save only if is_master but load for everyone, with a synchronization in-between. 13 | params = {} 14 | 15 | # -----General parameters----- 16 | params["ddp"] = False 17 | params["jz"] = True 18 | 19 | if params["ddp"]: 20 | params["CUDA_VISIBLE_DEVICES"] = "0, 1" 21 | params["WORLD_SIZE"] = 2 22 | 23 | elif params["jz"]: 24 | params["WORLD_SIZE"] = idr_torch.size 25 | 26 | else: 27 | params["numGPU"] = 0 28 | params["WORLD_SIZE"] = 1 29 | 30 | params["anomaly_detection"] = True 31 | params["empty_cache_every_n_batch"] = 10 32 | 33 | # -----Ground truth computation parameters----- 34 | params["compute_gt_online"] = False 35 | params["compute_partial_point_cloud_online"] = False 36 | 37 | params["gt_surface_resolution"] = 1.5 38 | 39 | params["gt_max_diagonal"] = 1. 
# Formerly known as x_range 40 | 41 | params["n_points_surface"] = 16384 # N points on GT surface 42 | 43 | params["surface_epsilon_is_constant"] = True 44 | if params["surface_epsilon_is_constant"]: 45 | params["surface_epsilon"] = 0.00707 46 | 47 | # -----SconeOcc Model Parameters----- 48 | # params["scone_occ_model_name"] = "best_unval_jz_model_scone_occ_mse_warmup_1000_schedule_lr_0.0001.pth" 49 | params["scone_occ_model_name"] = "best_unval_jz_model_scone_occ_mse_warmup_1000_schedule_lr_0.0001.pth" 50 | params["occ_no_view_harmonics"] = False 51 | 52 | params["n_view_max_for_scone_occ"] = 9 53 | params["max_points_per_scone_occ_pass"] = 300000 54 | 55 | # -----Model Parameters----- 56 | params["seq_len"] = 2048 57 | params["pts_dim"] = 4 58 | 59 | params["view_state_n_elev"] = 7 60 | params["view_state_n_azim"] = 2 * 7 61 | params["harmonic_degree"] = 8 62 | 63 | params["n_proxy_points"] = 100000 # 12000, 100000 64 | params["use_occ_to_sample_proxy_points"] = True # True 65 | params["min_occ_for_proxy_points"] = 0.1 66 | 67 | params["true_monte_carlo_sampling"] = True 68 | 69 | # -----Ablation study----- 70 | params["no_view_harmonics"] = False 71 | params["use_sigmoid"] = True 72 | 73 | # -----General training parameters----- 74 | params["start_from_scratch"] = True 75 | params["pretrained_weights_name"] = None 76 | 77 | params["n_view_max"] = 9 78 | params["n_view_min"] = 1 79 | params["filter_tol"] = 0.01 80 | 81 | params["camera_dist"] = 1.5 82 | params["pole_cameras"] = True 83 | params["n_camera_elev"] = 5 84 | params["n_camera_azim"] = 2 * 5 85 | params["n_camera"] = params["n_camera_elev"] * params["n_camera_azim"] 86 | if params["pole_cameras"]: 87 | params["n_camera"] += 2 88 | 89 | params["prediction_in_random_camera_space"] = False 90 | 91 | params["total_batch_size"] = 12 # 12 92 | params["batch_size"] = params["total_batch_size"] // params["WORLD_SIZE"] 93 | params["total_batch_size"] = params["batch_size"] * params["WORLD_SIZE"] 94 | 95 | 
params["epochs"] = 1000 96 | params["learning_rate"] = 1e-4 97 | 98 | params["schedule_learning_rate"] = True 99 | if params["schedule_learning_rate"]: 100 | params["lr_epochs"] = [179] # [179] 101 | params["lr_factor"] = 0.1 102 | 103 | params["warmup"] = 1000 104 | params["warmup_rate"] = 1 / (params["warmup"] * params["learning_rate"] ** 2) 105 | 106 | params["noam_opt"] = False 107 | params["training_metric"] = "surface_coverage_gain" 108 | # Training metric can be: "surface_coverage", "surface_coverage_gain", "absolute_coverage" 109 | params["training_loss"] = "uncentered_l1" # "kl_divergence", "l1", "uncentered_l1" 110 | params["multiply_loss"] = False 111 | if params["multiply_loss"]: 112 | params["loss_multiplication_factor"] = 10. 113 | 114 | params["nbv_validation"] = True 115 | 116 | params["random_seed"] = 42 117 | params["torch_seed"] = 5 118 | 119 | # -----Visibility Model name to save----- 120 | model_name = "model_scone_vis" 121 | model_name = model_name + "_" + params["training_metric"] 122 | 123 | if params["occ_no_view_harmonics"]: 124 | model_name += "_occ_no_vh" 125 | 126 | if params["no_view_harmonics"]: 127 | model_name += "_no_view_harmonics" 128 | 129 | if params["surface_epsilon_is_constant"]: 130 | model_name += "_constant_epsilon" 131 | else: 132 | model_name += "_adaptative_epsilon" 133 | 134 | if params["ddp"]: 135 | model_name = "ddp_" + model_name 136 | elif params["jz"]: 137 | model_name = "jz_" + model_name 138 | 139 | # model_name += "_seed" + str(params["random_seed"]) + "_" + str(params["torch_seed"]) 140 | 141 | model_name += "_" + params["training_loss"] 142 | 143 | if params["noam_opt"]: 144 | model_name += "_noam" 145 | model_name += "_warmup_" + str(params["warmup"]) 146 | 147 | if params["schedule_learning_rate"]: 148 | model_name += "_schedule" 149 | model_name += "_lr_" + str(params["learning_rate"]) 150 | 151 | if params["use_sigmoid"]: 152 | model_name += "_sigmoid" 153 | else: 154 | model_name += "_relu" 155 | 156 | 
if debug: 157 | model_name = "debug_" + model_name 158 | 159 | if params["prediction_in_random_camera_space"]: 160 | model_name += "_random_rot" 161 | 162 | if params["true_monte_carlo_sampling"]: 163 | model_name += "_tmcs" 164 | 165 | params["scone_vis_model_name"] = model_name 166 | 167 | # -----Json name to save params----- 168 | json_name = "train_params_" + params["scone_vis_model_name"] + ".json" 169 | 170 | if save: 171 | with open(json_name, 'w') as outfile: 172 | json.dump(params, outfile) 173 | 174 | print("Parameters saved in:") 175 | print(json_name) 176 | 177 | return json_name 178 | 179 | 180 | def loop(params, 181 | batch, mesh_dict, 182 | scone_occ, scone_vis, cov_loss_fn, 183 | device, is_master, 184 | n_views_list=None, 185 | optimal_sequences=None 186 | ): 187 | paths = mesh_dict['path'] 188 | 189 | cov_pred = torch.zeros(0, params.n_camera, 1, device=device) 190 | cov_truth = torch.zeros(0, params.n_camera, 1, device=device) 191 | 192 | info_harmonics_i = None 193 | 194 | base_harmonics, h_polar, h_azim = get_all_harmonics_under_degree(params.harmonic_degree, 195 | params.view_state_n_elev, 196 | params.view_state_n_azim, 197 | device) 198 | 199 | batch_size = len(paths) 200 | 201 | if n_views_list is None: 202 | n_views = np.random.randint(params.n_view_min, params.n_view_max + 1, batch_size) 203 | else: 204 | n_views = get_validation_n_view(params, n_views_list, batch, idr_torch.rank) 205 | 206 | if batch == 0 and is_master: 207 | print("First batch:", mesh_dict['path']) 208 | 209 | for i in range(batch_size): 210 | # ----------Load input mesh and ground truth data--------------------------------------------------------------- 211 | 212 | path_i = paths[i] 213 | 214 | # Loading info about partial point clouds and coverages 215 | part_pc, coverage = get_gt_partial_point_clouds(path=path_i, 216 | normalization_factor=1. 
/ params.gt_surface_resolution, 217 | device=device) 218 | 219 | # Loading info about ground truth surface 220 | # gt_surface, surface_epsilon = get_gt_surface(params=params, 221 | # path=path_i, 222 | # normalization_factor=1./params.gt_surface_resolution, 223 | # device=device) 224 | 225 | # Initial dense sampling 226 | X_world = sample_X_in_box(x_range=params.gt_max_diagonal, n_sample=params.n_proxy_points, device=device) 227 | 228 | # ----------Set camera candidates for coverage prediction------------------------------------------------------- 229 | X_cam_world, camera_dist, camera_elev, camera_azim = get_cameras_on_sphere(params, device, 230 | pole_cameras=params.pole_cameras) 231 | 232 | # ----------Select initial observations of the object----------------------------------------------------------- 233 | 234 | # Select a subset of n_view cameras to compute an initial point cloud 235 | n_view = n_views[i] 236 | if optimal_sequences is None: 237 | view_idx = torch.randperm(len(camera_elev), device=device)[:n_view] 238 | else: 239 | optimal_seq, _ = get_optimal_sequence(optimal_sequences, path_i, n_view) 240 | view_idx = optimal_seq.to(device) 241 | 242 | # Select either first camera view space, or random camera view space as prediction view space 243 | if params.prediction_in_random_camera_space: 244 | prediction_cam_idx = np.random.randint(low=0, high=len(camera_elev)) 245 | else: 246 | prediction_cam_idx = view_idx[0] 247 | prediction_box_center = torch.Tensor([0., 0., params.camera_dist]).to(device) 248 | 249 | # Move camera coordinates from world space to prediction view space, and normalize them for prediction box 250 | prediction_R, prediction_T = look_at_view_transform(dist=camera_dist[prediction_cam_idx], 251 | elev=camera_elev[prediction_cam_idx], 252 | azim=camera_azim[prediction_cam_idx], 253 | device=device) 254 | prediction_camera = FoVPerspectiveCameras(device=device, R=prediction_R, T=prediction_T) 255 | prediction_view_transform = 
prediction_camera.get_world_to_view_transform() 256 | 257 | X_cam = prediction_view_transform.transform_points(X_cam_world) 258 | X_cam = normalize_points_in_prediction_box(points=X_cam, 259 | prediction_box_center=prediction_box_center, 260 | prediction_box_diag=params.gt_max_diagonal) 261 | _, elev_cam, azim_cam = get_spherical_coords(X_cam) 262 | 263 | X_view = X_cam[view_idx] 264 | X_cam = X_cam.view(1, params.n_camera, 3) 265 | 266 | # ----------Capture initial observations------------------------------------------------------------------------ 267 | 268 | # Points observed in initial views 269 | pc = torch.vstack([part_pc[pc_idx] for pc_idx in view_idx]) 270 | 271 | # Downsampling partial point cloud 272 | pc = pc[torch.randperm(len(pc))[:n_view * params.seq_len]] 273 | 274 | # Move partial point cloud from world space to prediction view space, and normalize them in prediction box 275 | pc = prediction_view_transform.transform_points(pc) 276 | pc = normalize_points_in_prediction_box(points=pc, 277 | prediction_box_center=prediction_box_center, 278 | prediction_box_diag=params.gt_max_diagonal).view(1, -1, 3) 279 | 280 | # ----------Compute inputs to SconeVis----------------------------------------------- 281 | 282 | # Sample random proxy points in space 283 | X_idx = torch.randperm(len(X_world))[:params.n_proxy_points] 284 | X_world = X_world[X_idx] 285 | 286 | # Move proxy points from world space to prediction view space, and normalize them in prediction box 287 | X = prediction_view_transform.transform_points(X_world) 288 | X = normalize_points_in_prediction_box(points=X, 289 | prediction_box_center=prediction_box_center, 290 | prediction_box_diag=params.gt_max_diagonal 291 | ) 292 | 293 | # Filter Proxy Points using pc shape from view cameras 294 | R_view, T_view = look_at_view_transform(eye=X_view, 295 | at=torch.zeros_like(X_view), 296 | device=device) 297 | view_cameras = FoVPerspectiveCameras(R=R_view, T=T_view, zfar=1000, device=device) 298 | X, _ = 
filter_proxy_points(view_cameras, X, pc.view(-1, 3), filter_tol=params.filter_tol) 299 | X = X.view(1, X.shape[0], 3) 300 | 301 | # Compute view state vector and corresponding view harmonics 302 | view_state = compute_view_state(X, X_view, 303 | params.view_state_n_elev, params.view_state_n_azim) 304 | view_harmonics = compute_view_harmonics(view_state, 305 | base_harmonics, h_polar, h_azim, 306 | params.view_state_n_elev, params.view_state_n_azim) 307 | occ_view_harmonics = 0. + view_harmonics 308 | if params.occ_no_view_harmonics: 309 | occ_view_harmonics *= 0. 310 | if params.no_view_harmonics: 311 | view_harmonics *= 0. 312 | 313 | # Compute occupancy probabilities 314 | with torch.no_grad(): 315 | occ_prob_i = compute_occupancy_probability(scone_occ=scone_occ, 316 | pc=pc, 317 | X=X, 318 | view_harmonics=occ_view_harmonics, 319 | max_points_per_pass=params.max_points_per_scone_occ_pass 320 | ).view(-1, 1) 321 | 322 | proxy_points, view_harmonics, sample_idx = sample_proxy_points(X[0], occ_prob_i, view_harmonics.squeeze(dim=0), 323 | n_sample=params.seq_len, 324 | min_occ=params.min_occ_for_proxy_points, 325 | use_occ_to_sample=params.use_occ_to_sample_proxy_points, 326 | return_index=True) 327 | 328 | proxy_points = torch.unsqueeze(proxy_points, dim=0) 329 | view_harmonics = torch.unsqueeze(view_harmonics, dim=0) 330 | 331 | # ----------Predict Coverage Gains------------------------------------------------------------------------------ 332 | visibility_gain_harmonics = scone_vis(proxy_points, view_harmonics=view_harmonics) 333 | if params.true_monte_carlo_sampling: 334 | proxy_points = torch.unsqueeze(proxy_points[0][sample_idx], dim=0) 335 | visibility_gain_harmonics = torch.unsqueeze(visibility_gain_harmonics[0][sample_idx], dim=0) 336 | 337 | if params.ddp or params.jz: 338 | cov_pred_i = scone_vis.module.compute_coverage_gain(proxy_points, 339 | visibility_gain_harmonics, 340 | X_cam) 341 | else: 342 | cov_pred_i = 
scone_vis.compute_coverage_gain(proxy_points, 343 | visibility_gain_harmonics, 344 | X_cam) 345 | 346 | cov_pred = torch.vstack((cov_pred, cov_pred_i.view(1, -1, 1))) 347 | 348 | # ----------Compute ground truth information scores---------- 349 | cov_truth_i = compute_gt_coverage_gain_from_precomputed_matrices(coverage=coverage, 350 | initial_cam_idx=view_idx) 351 | cov_truth = torch.vstack((cov_truth, cov_truth_i.view(1, -1, 1))) 352 | 353 | # ----------Compute loss---------- 354 | # cov_pred = cov_pred.view(-1, params.n_camera, 1) 355 | # cov_truth = cov_truth.view(-1, params.n_camera, 1) 356 | loss = cov_loss_fn(cov_pred, cov_truth) 357 | 358 | if batch % params.empty_cache_every_n_batch == 0 and is_master: 359 | print("View state sum-mean:", torch.mean(torch.sum(view_state, dim=-1))) 360 | print("Point cloud features shape:", proxy_points.shape) 361 | # print("Surface epsilon:", params.surface_epsilon) 362 | 363 | return loss, cov_pred, cov_truth, batch_size, n_view 364 | 365 | 366 | def train(params, 367 | dataloader, 368 | scone_occ, 369 | scone_vis, cov_loss_fn, 370 | optimizer, 371 | device, is_master, 372 | train_losses): 373 | 374 | num_batches = len(dataloader) 375 | size = num_batches * params.total_batch_size 376 | train_loss = 0. 
377 | 378 | # Set the visibility model to training mode 379 | scone_vis.train() 380 | 381 | t0 = time.time() 382 | 383 | for batch, (mesh_dict) in enumerate(dataloader): 384 | 385 | loss, cov_pred, cov_truth, batch_size, n_view = loop(params, 386 | batch, mesh_dict, 387 | scone_occ, scone_vis, cov_loss_fn, 388 | device, is_master, 389 | n_views_list=None, 390 | optimal_sequences=None) 391 | 392 | # Backpropagation 393 | optimizer.zero_grad() 394 | loss.backward() 395 | optimizer.step() 396 | 397 | batch_loss = loss.detach() # Undo the loss scaling per batch, not on the running total 398 | if params.multiply_loss: 399 | batch_loss = batch_loss / params.loss_multiplication_factor 400 | train_loss += batch_loss 401 | if batch % params.empty_cache_every_n_batch == 0: 402 | 403 | # loss = reduce_tensor(loss) 404 | if params.ddp or params.jz: 405 | loss = reduce_tensor(loss, world_size=params.WORLD_SIZE) 406 | loss = to_python_float(loss) 407 | 408 | current = batch * batch_size # * idr_torch.size 409 | if params.ddp or params.jz: 410 | current *= params.WORLD_SIZE 411 | 412 | truth_norm = to_python_float(torch.linalg.norm(cov_truth.detach())) 413 | pred_norm = to_python_float(torch.linalg.norm(cov_pred.detach())) 414 | 415 | # torch.cuda.synchronize() 416 | 417 | if is_master: 418 | print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]", 419 | "computed in", (time.time() - t0) / 60., "minutes.") 420 | print(">>>Prediction shape:", cov_pred.shape, 421 | "\n>>>Truth norm:", truth_norm, ">>>Prediction norm:", pred_norm, 422 | "\nNumber of cameras:", n_view, "+ 1.") 423 | # print("Harmonics:\n", info_harmonics_i) 424 | # TO REMOVE 425 | torch.cuda.empty_cache() 426 | 427 | if batch % params.empty_cache_every_n_batch == 0: 428 | t0 = time.time() 429 | 430 | # train_loss = reduce_tensor(train_loss) 431 | if params.ddp or params.jz: 432 | train_loss = reduce_tensor(train_loss, world_size=params.WORLD_SIZE) 433 | train_loss = to_python_float(train_loss) 434 | train_loss /= num_batches 435 | train_losses.append(train_loss) 436 | 437 | 438 | def validation(params, 439 | dataloader, 440 | 
scone_occ, 441 | scone_vis, cov_loss_fn, 442 | device, is_master, 443 | val_losses, 444 | nbv_validation=False, 445 | val_coverages=None): 446 | 447 | num_batches = len(dataloader) 448 | size = num_batches * params.total_batch_size 449 | val_loss = 0. 450 | val_coverage = 0. 451 | 452 | # Set the visibility model to evaluation mode 453 | scone_vis.eval() 454 | 455 | t0 = time.time() 456 | 457 | n_views_list = get_validation_n_views_list(params, dataloader) 458 | optimal_sequences = get_validation_optimal_sequences(jz=params.jz, device=device) 459 | 460 | for batch, (mesh_dict) in enumerate(dataloader): 461 | with torch.no_grad(): 462 | loss, cov_pred, cov_truth, batch_size, _ = loop(params, 463 | batch, mesh_dict, 464 | scone_occ, scone_vis, cov_loss_fn, 465 | device, is_master, 466 | n_views_list=n_views_list, 467 | optimal_sequences=optimal_sequences) 468 | 469 | batch_loss = loss.detach() # Undo the loss scaling per batch, not on the running total 470 | if params.multiply_loss: 471 | batch_loss = batch_loss / params.loss_multiplication_factor 472 | val_loss += batch_loss 473 | if nbv_validation: 474 | with torch.no_grad(): 475 | info_pred_scores = cov_pred.view(batch_size, params.n_camera, 1) 476 | info_pred_scores = torch.squeeze(info_pred_scores, dim=-1) 477 | 478 | info_truth_scores = cov_truth.view(batch_size, params.n_camera, 1) 479 | info_truth_scores = torch.squeeze(info_truth_scores, dim=-1) 480 | 481 | max_info_preds, max_info_idx = torch.max(info_pred_scores, 482 | dim=1) 483 | 484 | max_idx = max_info_idx.view(batch_size, 1) 485 | true_max_coverages = torch.gather(info_truth_scores, 486 | dim=1, 487 | index=max_idx) 488 | 489 | val_coverage += torch.sum(true_max_coverages).detach() / batch_size 490 | 491 | if batch % params.empty_cache_every_n_batch == 0: 492 | torch.cuda.empty_cache() 493 | 494 | # val_loss = reduce_tensor(val_loss) 495 | if params.ddp or params.jz: 496 | val_loss = reduce_tensor(val_loss, world_size=params.WORLD_SIZE) 497 | if nbv_validation: 498 | val_coverage = reduce_tensor(val_coverage, world_size=params.WORLD_SIZE) 499 | 500 | val_loss = 
to_python_float(val_loss) 501 | val_loss /= num_batches 502 | val_losses.append(val_loss) 503 | 504 | if nbv_validation: 505 | val_coverage = to_python_float(val_coverage) 506 | val_coverage /= num_batches 507 | if val_coverages is None: 508 | raise NameError("Variable val_coverages is set to None.") 509 | else: 510 | val_coverages.append(val_coverage) 511 | 512 | if is_master: 513 | print(f"Validation Error: \n Avg loss: {val_loss:>8f} \n") 514 | if nbv_validation: 515 | print(f"Avg nbv coverage: {val_coverage:>8f} \n") 516 | 517 | 518 | def run(ddp_rank=None, params=None): 519 | # Set device 520 | device = setup_device(params, ddp_rank) 521 | 522 | batch_size = params.batch_size 523 | total_batch_size = params.total_batch_size 524 | 525 | if params.ddp: 526 | world_size = params.WORLD_SIZE 527 | rank = ddp_rank 528 | is_master = rank == 0 529 | elif params.jz: 530 | world_size = idr_torch.size 531 | rank = idr_torch.rank 532 | is_master = rank == 0 533 | else: 534 | world_size, rank = None, None 535 | is_master = True 536 | 537 | # Create dataloader 538 | train_dataloader, val_dataloader, _ = get_shapenet_dataloader(batch_size=params.batch_size, 539 | ddp=params.ddp, jz=params.jz, 540 | world_size=world_size, ddp_rank=rank, 541 | load_obj=False, 542 | data_path=None) 543 | 544 | # Initialize or Load models 545 | scone_occ = load_scone_occ(params, params.scone_occ_model_name, ddp_model=True, device=device) 546 | scone_occ.eval() 547 | 548 | scone_vis = SconeVis(use_sigmoid=params.use_sigmoid).to(device) 549 | scone_vis, optimizer, opt_name, start_epoch, best_loss, best_coverage = initialize_scone_vis(params=params, 550 | scone_vis=scone_vis, 551 | device=device, 552 | torch_seed=params.torch_seed, 553 | load_pretrained_weights=not params.start_from_scratch, 554 | pretrained_weights_name=params.pretrained_weights_name, 555 | ddp_rank=rank) 556 | 557 | best_train_loss = 1000 558 | epochs_without_improvement = 0 559 | learning_rate = params.learning_rate 560 | 561 | 
    # Set loss function
    cov_loss_fn = get_cov_loss_fn(params)

    if is_master:
        print("Model name:", params.scone_vis_model_name, "\nArchitecture:\n")
        print(scone_vis)
        print("Model name:", params.scone_vis_model_name)
        print("Numbers of trainable parameters:", count_parameters(scone_vis))
        print("Using", opt_name, "optimizer.")
        print("Using occupancy model", params.scone_occ_model_name)

        if params.training_loss == "kl_divergence":
            print("Using softmax + KL Divergence loss.")
        elif params.training_loss == "mse":
            print("Using MSE loss.")

        print("Using", params.n_camera, "uniformly sampled camera positions per mesh.")

        print("Training data:", len(train_dataloader), "batches.")
        print("Validation data:", len(val_dataloader), "batches.")
        print("Batch size:", params.total_batch_size)
        print("Batch size per GPU:", params.batch_size)

    # Begin training process
    train_losses = []
    val_losses = []
    val_coverages = []

    t0 = time.time()
    for t_e in range(params.epochs):
        t = start_epoch + t_e
        if is_master:
            print(f"Epoch {t + 1}\n-------------------------------")
        torch.cuda.empty_cache()

        # Update learning rate
        if params.schedule_learning_rate:
            if t in params.lr_epochs:
                print("Multiplying learning rate by", params.lr_factor)
                learning_rate *= params.lr_factor

            update_learning_rate(params, optimizer, learning_rate)
            print("Max learning rate set to", learning_rate)
            print("Current learning rate set to", optimizer._rate)

        # Only the DistributedSampler exposes set_epoch (it reshuffles per epoch);
        # the default sampler used in single-process mode does not.
        if params.ddp or params.jz:
            train_dataloader.sampler.set_epoch(t)

        scone_vis.train()
        train(params,
              train_dataloader,
              scone_occ,
              scone_vis, cov_loss_fn,
              optimizer,
              device, is_master,
              train_losses
              )

        current_loss = train_losses[-1]

        if is_master:
            print("Training done for epoch", t + 1, ".")
            # torch.save(model, "unvalidated_" + model_name + ".pth")
            torch.save({
                'epoch': t + 1,
                'model_state_dict': scone_vis.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': current_loss,
                # 'coverage': current_val_coverage,
                'train_losses': train_losses,
                'val_losses': val_losses,
            }, "unvalidated_" + params.scone_vis_model_name + ".pth")

            if current_loss < best_train_loss:
                torch.save({
                    'epoch': t + 1,
                    'model_state_dict': scone_vis.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': current_loss,
                    # 'coverage': current_val_coverage,
                    'train_losses': train_losses,
                    'val_losses': val_losses,
                }, "best_unval_" + params.scone_vis_model_name + ".pth")
                best_train_loss = current_loss
                print("Best model on training set saved with loss " + str(current_loss) + " .\n")

        torch.cuda.empty_cache()

        if is_master:
            print("Beginning evaluation on validation dataset...")
        # val_dataloader.sampler.set_epoch(t)

        scone_vis.eval()
        validation(params,
                   val_dataloader,
                   scone_occ,
                   scone_vis, cov_loss_fn,
                   device, is_master,
                   val_losses,
                   nbv_validation=params.nbv_validation,
                   val_coverages=val_coverages
                   )

        current_val_loss = val_losses[-1]
        # val_coverages stays empty when params.nbv_validation is False
        current_val_coverage = val_coverages[-1] if val_coverages else 0.
        if current_val_loss < best_loss:
            # torch.save(model, "validated_" + model_name + ".pth")
            if is_master:
                torch.save({
                    'epoch': t + 1,
                    'model_state_dict': scone_vis.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': current_val_loss,
                    'coverage': current_val_coverage,
                    'train_losses': train_losses,
                    'val_losses': val_losses
                }, "validated_" + params.scone_vis_model_name + ".pth")
                print("Model saved with loss " + str(current_val_loss) + " .\n")
            best_loss = val_losses[-1]
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if is_master and current_val_coverage > best_coverage:
            # torch.save(model, "validated_" + model_name + ".pth")
            torch.save({
                'epoch': t + 1,
                'model_state_dict': scone_vis.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': current_val_loss,
                'coverage': current_val_coverage,
                'train_losses': train_losses,
                'val_losses': val_losses
            }, "coverage_validated_" + params.scone_vis_model_name + ".pth")
            print("Model saved with coverage " + str(current_val_coverage) + " .\n")
            best_coverage = val_coverages[-1]

        # Save data about losses
        if is_master:
            losses_data = {}
            losses_data['train_loss'] = train_losses
            losses_data['val_loss'] = val_losses
            json_name = "losses_data_" + params.scone_vis_model_name + ".json"
            with open(json_name, 'w') as outfile:
                json.dump(losses_data, outfile)
            print("Saved data about losses in", json_name, ".")

    if is_master:
        print("Done in", (time.time() - t0) / 3600., "hours!")

        # Save data about losses
        losses_data = {}
        losses_data['train_loss'] = train_losses
        losses_data['val_loss'] = val_losses
        json_name = "losses_data_" + params.scone_vis_model_name + ".json"
        with open(json_name, 'w') as outfile:
            json.dump(losses_data, outfile)
        print("Saved data about losses in", json_name, ".")

    if params.ddp or params.jz:
        cleanup()


if __name__ == "__main__":
    # Save and load parameters
    json_name = save_train_params(save=save_parameters)
    print("Loaded parameters stored in", json_name)

    if (not save_parameters) and (len(sys.argv) > 1):
        json_name = sys.argv[1]
        print("Using json name given in argument:")
        print(json_name)

    if start_training:
        params = load_params(json_name)

        if params.ddp:
            mp.spawn(run,
                     args=(params,
                           ),
                     nprocs=params.WORLD_SIZE
                     )

        elif params.jz:
            run(params=params)

        else:
            run(params=params)

--------------------------------------------------------------------------------
/SCONE/scone_utils.py:
--------------------------------------------------------------------------------
from utils import *
from CustomGeometry import *
from CustomDataset import CustomShapenetDataset
from spherical_harmonics import clear_spherical_harmonics_cache

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
import idr_torch
from torch.nn.parallel import DistributedDataParallel as DDP
from spherical_harmonics import get_spherical_harmonics

from SconeOcc import SconeOcc
from SconeVis import SconeVis, KLDivCE, L1_loss, Uncentered_L1_loss


def setup_device(params, ddp_rank=None):

    if params.ddp:
        print("Setup device", str(ddp_rank), "for DDP training...")
        os.environ["CUDA_VISIBLE_DEVICES"] = params.CUDA_VISIBLE_DEVICES

        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        os.environ['WORLD_SIZE'] = str(params.WORLD_SIZE)
        os.environ['RANK'] = str(ddp_rank)

        dist.init_process_group("nccl", rank=ddp_rank, world_size=params.WORLD_SIZE)

        device = torch.device("cuda:" + str(ddp_rank))
        torch.cuda.set_device(device)

        torch.cuda.empty_cache()
        print("Setup done!")

        if ddp_rank == 0:
            print(torch.cuda.memory_summary())

    elif params.jz:
        print("Setup device", str(idr_torch.rank), " for Jean Zay training...")
        dist.init_process_group(backend='nccl',
                                init_method='env://',
                                rank=idr_torch.rank,
                                world_size=idr_torch.size)

        torch.cuda.set_device(idr_torch.local_rank)
        device = torch.device("cuda")

        torch.cuda.empty_cache()
        print("Setup done!")

        if idr_torch.rank == 0:
            print(torch.cuda.memory_summary())

    else:
        # Set our device:
        if torch.cuda.is_available():
            device = torch.device("cuda:" + str(params.numGPU))
            torch.cuda.set_device(device)
        else:
            device = torch.device("cpu")
        print(device)

        # Empty cache and report memory usage (CUDA only: the memory summary
        # is not available on a CPU-only machine)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            print(torch.cuda.memory_summary())

    return device


def load_params(json_name):
    return Params(json_name)


def reduce_tensor(tensor: torch.Tensor, world_size):
    """Reduce tensor across all nodes."""
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= world_size
    return rt


def to_python_float(t: torch.Tensor):
    if hasattr(t, 'item'):
        return t.item()
    else:
        return t[0]


def cleanup():
    dist.destroy_process_group()


def get_shapenet_dataloader(batch_size,
                            ddp=False, jz=False,
                            world_size=None, ddp_rank=None,
                            test_novel=False,
                            test_number=0,
                            load_obj=False,
                            data_path=None):
    # Database path
    # SHAPENET_PATH = "../../../../datasets/shapenet2/ShapeNetCore.v2"
    memory_threshold = 10e6
    if data_path is None:
        if jz:
            SHAPENET_PATH = "../../datasets/ShapeNetCore.v1"
        else:
            SHAPENET_PATH = "../../../../../../mnt/ssd/aguedon/ShapeNetCore.v1/ShapeNetCore.v1"
            SHAPENET_PATH = "../../../../datasets/ShapeNetCore.v1"
            # "../../datasets/ShapeNetCore.v1"
            # "../../../../../../mnt/ssd/aguedon/ShapeNetCore.v1/ShapeNetCore.v1"
    else:
        SHAPENET_PATH = data_path

    database_path = os.path.join(SHAPENET_PATH, "train_categories")
    train_json = os.path.join(SHAPENET_PATH, "train_list.json")
    val_json = os.path.join(SHAPENET_PATH, "val_list.json")
    if not test_novel:
        if test_number == 0:
            test_json = os.path.join(SHAPENET_PATH, "test_list.json")
        elif test_number == -1:
            test_json = os.path.join(SHAPENET_PATH, "all_test_list.json")
        else:
            test_json = os.path.join(SHAPENET_PATH, "test_list_" + str(test_number) + ".json")
            print("Using test split number " + str(test_number) + ".")
    else:
        database_path = os.path.join(SHAPENET_PATH, "test_categories")
        print("Using novel test split.")
        if test_number >= 0:
            test_json = os.path.join(SHAPENET_PATH, "test_novel_list.json")
        else:
            test_json = os.path.join(SHAPENET_PATH, "all_test_novel_list.json")
    # test_json = os.path.join(SHAPENET_PATH, "debug_list.json")

    train_dataset = CustomShapenetDataset(data_path=database_path,
                                          memory_threshold=memory_threshold,
                                          save_to_json=False,
                                          load_from_json=True,
                                          json_name=train_json,
                                          official_split=True,
                                          adjust_diagonally=True,
                                          load_obj=load_obj)
    val_dataset = CustomShapenetDataset(data_path=database_path,
                                        memory_threshold=memory_threshold,
                                        save_to_json=False,
                                        load_from_json=True,
                                        json_name=val_json,
                                        official_split=True,
                                        adjust_diagonally=True,
                                        load_obj=load_obj)
    test_dataset = CustomShapenetDataset(data_path=database_path,
                                         memory_threshold=memory_threshold,
                                         save_to_json=False,
                                         load_from_json=True,
                                         json_name=test_json,
                                         official_split=True,
                                         adjust_diagonally=True,
                                         load_obj=load_obj)

    if ddp or jz:
        if jz:
            rank = idr_torch.rank
        else:
            rank = ddp_rank
        train_sampler = DistributedSampler(train_dataset,
                                           num_replicas=world_size,
                                           rank=rank,
                                           drop_last=True)
        valid_sampler = DistributedSampler(val_dataset,
                                           num_replicas=world_size,
                                           rank=rank,
                                           shuffle=False,
                                           drop_last=True)
        test_sampler = DistributedSampler(test_dataset,
                                          num_replicas=world_size,
                                          rank=rank,
                                          shuffle=False,
                                          drop_last=True)

        train_dataloader = DataLoader(train_dataset,
                                      batch_size=batch_size,
                                      drop_last=True,
                                      collate_fn=collate_batched_meshes,
                                      sampler=train_sampler)
        validation_dataloader = DataLoader(val_dataset,
                                           batch_size=batch_size,
                                           drop_last=True,
                                           collate_fn=collate_batched_meshes,
                                           sampler=valid_sampler)
        test_dataloader = DataLoader(test_dataset,
                                     batch_size=batch_size,
                                     drop_last=True,
                                     collate_fn=collate_batched_meshes,
                                     sampler=test_sampler)
    else:
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=batch_size,
                                      collate_fn=collate_batched_meshes,
                                      shuffle=True)
        validation_dataloader = DataLoader(val_dataset,
                                           batch_size=batch_size,
                                           collate_fn=collate_batched_meshes,
                                           shuffle=False)
        test_dataloader = DataLoader(test_dataset,
                                     batch_size=batch_size,
                                     collate_fn=collate_batched_meshes,
                                     shuffle=False)

    return train_dataloader, validation_dataloader, test_dataloader


def get_optimizer(params, model):
    """
    Returns AdamW optimizer with linear warmup steps at beginning.

    :param params: (Params) Hyper parameters file.
    :param model: Model to be trained.
    :return: (Tuple) Optimizer and its name.
    """
    optimizer = WarmupConstantOpt(learning_rate=params.learning_rate,
                                  warmup=params.warmup,
                                  optimizer=torch.optim.AdamW(model.parameters(),
                                                              lr=0
                                                              )
                                  )
    opt_name = "WarmupAdamW"

    return optimizer, opt_name


def update_learning_rate(params, optimizer, learning_rate):

    if params.noam_opt:
        optimizer.model_size = 1 / (optimizer.warmup * learning_rate ** 2)
        if optimizer._step == 0:
            optimizer._rate = 0
        else:
            optimizer._rate = (optimizer.model_size ** (-0.5)
                               * min(optimizer._step ** (-0.5),
                                     optimizer._step * optimizer.warmup ** (-1.5)))

    else:
        optimizer.learning_rate = learning_rate
        if optimizer._step == 0:
            optimizer._rate = 0
        else:
            optimizer._rate = optimizer.learning_rate * min(1., optimizer._step / optimizer.warmup)


def initialize_scone_occ(params, scone_occ, device,
                         torch_seed=None,
                         load_pretrained_weights=False,
                         pretrained_weights_name=None,
                         ddp_rank=None):
    """
    Initializes SCONE's occupancy probability prediction module for training.
    Can be initialized from scratch, or from an already trained model to resume training.

    :param params: (Params) Hyper parameters file.
    :param scone_occ: (SconeOcc) Occupancy probability prediction model.
    :param device: Device.
    :param torch_seed: (int) Seed used to initialize the network.
    :param load_pretrained_weights: (bool) If True, pretrained weights are loaded for initialization.
    :param pretrained_weights_name: (str) Checkpoint file to load.
        If None, defaults to "unvalidated_" + model name + ".pth".
    :param ddp_rank: Rank for DDP training.
    :return: (Tuple) Initialized SconeOcc model, Optimizer, optimizer name, start epoch, best loss.
        If training from scratch, start_epoch=0 and best_loss=1000.
    """
    model_name = params.scone_occ_model_name
    start_epoch = 0
    best_loss = 1000.

    # Weight initialization process
    if load_pretrained_weights:
        # Load pretrained weights if needed
        if pretrained_weights_name is None:
            weights_file = "unvalidated_" + model_name + ".pth"
        else:
            weights_file = pretrained_weights_name
        checkpoint = torch.load(weights_file, map_location=device)
        start_epoch = checkpoint['epoch'] + 1
        best_loss = checkpoint['loss']

        # Models trained with DDP (name prefixed by "jz" or "ddp") need a dedicated loader
        ddp_model = False
        if (model_name[:2] == "jz") or (model_name[:3] == "ddp"):
            ddp_model = True

        if ddp_model:
            scone_occ = load_ddp_state_dict(scone_occ, checkpoint['model_state_dict'])
        else:
            scone_occ.load_state_dict(checkpoint['model_state_dict'])

    else:
        # Else, applies a basic initialization process
        if torch_seed is not None:
            torch.manual_seed(torch_seed)
            print("Seed", torch_seed, "chosen.")
        scone_occ.apply(init_weights)

    # DDP wrapping if needed
    if params.ddp:
        scone_occ = DDP(scone_occ,
                        device_ids=[ddp_rank])
    elif params.jz:
        scone_occ = DDP(scone_occ,
                        device_ids=[idr_torch.local_rank])

    # Creating optimizer
    optimizer, opt_name = get_optimizer(params, scone_occ)
    if load_pretrained_weights:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    return scone_occ, optimizer, opt_name, start_epoch, best_loss


def load_scone_occ(params, trained_model_name, ddp_model, device):
    """
    Loads an already trained occupancy probability prediction model for inference on a single GPU.

    :param params: (Params) Parameters file.
    :param trained_model_name: (str) Name of trained model's checkpoint.
    :param ddp_model: (bool) If True, the checkpoint is loaded with load_ddp_state_dict.
    :param device: Device.
    :return: (SconeOcc) Occupancy probability prediction module with trained weights.
    """
    scone_occ = SconeOcc().to(device)

    model_name = params.scone_occ_model_name

    # Loads checkpoint
    checkpoint = torch.load(trained_model_name, map_location=device)
    print("Model name:", trained_model_name)
    print("Trained for", checkpoint['epoch'], 'epochs.')
    print("Training finished with loss", checkpoint['loss'])

    # Loads trained weights
    if ddp_model:
        scone_occ = load_ddp_state_dict(scone_occ, checkpoint['model_state_dict'])
    else:
        scone_occ.load_state_dict(checkpoint['model_state_dict'])

    return scone_occ


def initialize_scone_vis(params, scone_vis, device,
                         torch_seed=None,
                         load_pretrained_weights=False,
                         pretrained_weights_name=None,
                         ddp_rank=None):
    """
    Initializes SCONE's visibility prediction module for training.
    Can be initialized from scratch, or from an already trained model to resume training.

    :param params: (Params) Hyper parameters file.
    :param scone_vis: (SconeVis) Visibility prediction model.
    :param device: Device.
    :param torch_seed: (int) Seed used to initialize the network.
    :param load_pretrained_weights: (bool) If True, pretrained weights are loaded for initialization.
    :param pretrained_weights_name: (str) Checkpoint file to load.
        If None, defaults to "unvalidated_" + model name + ".pth".
    :param ddp_rank: Rank for DDP training.
    :return: (Tuple) Initialized SconeVis model, Optimizer, optimizer name, start epoch, best loss, best coverage.
        If training from scratch, start_epoch=0, best_loss=1000. and best_coverage=0.
    """
    model_name = params.scone_vis_model_name
    start_epoch = 0
    best_loss = 1000.
    best_coverage = 0.

    # Weight initialization process
    if load_pretrained_weights:
        # Load pretrained weights if needed
        if pretrained_weights_name is None:
            weights_file = "unvalidated_" + model_name + ".pth"
        else:
            weights_file = pretrained_weights_name
        checkpoint = torch.load(weights_file, map_location=device)
        start_epoch = checkpoint['epoch'] + 1
        best_loss = checkpoint['loss']
        if 'coverage' in checkpoint:
            best_coverage = checkpoint['coverage']

        # Models trained with DDP (name prefixed by "jz" or "ddp") need a dedicated loader
        ddp_model = False
        if (model_name[:2] == "jz") or (model_name[:3] == "ddp"):
            ddp_model = True

        if ddp_model:
            scone_vis = load_ddp_state_dict(scone_vis, checkpoint['model_state_dict'])
        else:
            scone_vis.load_state_dict(checkpoint['model_state_dict'])

    else:
        # Else, applies a basic initialization process
        if torch_seed is not None:
            torch.manual_seed(torch_seed)
            print("Seed", torch_seed, "chosen.")
        scone_vis.apply(init_weights)

    # DDP wrapping if needed
    if params.ddp:
        scone_vis = DDP(scone_vis,
                        device_ids=[ddp_rank])
    elif params.jz:
        scone_vis = DDP(scone_vis,
                        device_ids=[idr_torch.local_rank])

    # Creating optimizer
    optimizer, opt_name = get_optimizer(params, scone_vis)
    if load_pretrained_weights:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    return scone_vis, optimizer, opt_name, start_epoch, best_loss, best_coverage


def load_scone_vis(params, trained_model_name, ddp_model, device):
    """
    Loads an already trained visibility prediction model for inference on a single GPU.

    :param params: (Params) Parameters file.
    :param trained_model_name: (str) Name of trained model's checkpoint.
    :param ddp_model: (bool) If True, the checkpoint is loaded with load_ddp_state_dict.
    :param device: Device.
    :return: (SconeVis) Visibility prediction module with trained weights.
    """
    scone_vis = SconeVis(use_sigmoid=params.use_sigmoid).to(device)

    model_name = params.scone_vis_model_name

    # Loads checkpoint
    checkpoint = torch.load(trained_model_name, map_location=device)
    print("Model name:", trained_model_name)
    print("Trained for", checkpoint['epoch'], 'epochs.')
    print("Training finished with loss", checkpoint['loss'])

    # Loads trained weights
    if ddp_model:
        scone_vis = load_ddp_state_dict(scone_vis, checkpoint['model_state_dict'])
    else:
        scone_vis.load_state_dict(checkpoint['model_state_dict'])

    return scone_vis


def get_cov_loss_fn(params):
    if params.training_loss == "kl_divergence":
        cov_loss_fn = KLDivCE()

    elif params.training_loss == "l1":
        cov_loss_fn = L1_loss()

    elif params.training_loss == "uncentered_l1":
        cov_loss_fn = Uncentered_L1_loss()

    else:
        raise NameError("Invalid training loss function. "
                        "Please choose a valid loss between 'kl_divergence', 'l1' or 'uncentered_l1'.")

    return cov_loss_fn


def get_occ_loss_fn(params):
    if params.training_loss == "mse":
        # size_average and reduce are deprecated; reduction='mean' alone is equivalent
        occ_loss_fn = nn.MSELoss(reduction='mean')
        return occ_loss_fn

    else:
        raise NameError("Invalid training loss function. "
                        "Please choose a valid loss like 'mse'.")


# -----Data Functions-----

# TO CHANGE! Maybe Load on CPU for large scenes.
def get_gt_partial_point_clouds(path, device, normalization_factor=None):
    """
    Loads ground truth partial point clouds for training.
    :param path: (str) Path to a file in the mesh directory.
        Tensors are read from the "tensors" folder located next to this file.
    :param device: Device on which the tensors are loaded.
    :param normalization_factor: factor to normalize the point cloud.
        if None, the point cloud is not normalized.
    :return: (Tuple) Partial point clouds and the vertically stacked coverage tensor.
    """
    parent_dir = os.path.dirname(path)
    load_directory = os.path.join(parent_dir, "tensors")

    file_name = "partial_point_clouds.pt"
    pc_dict = torch.load(os.path.join(load_directory, file_name),
                         map_location=device)

    part_pc = pc_dict['partial_point_cloud']
    coverage = torch.vstack(pc_dict['coverage'])

    if normalization_factor is not None:
        for i in range(len(part_pc)):
            part_pc[i] = normalization_factor * part_pc[i]

    return part_pc, coverage


def get_gt_occupancy_field(path, device):
    """
    Loads ground truth occupancy field for training.
    :param path: (str) Path to a file in the mesh directory.
        Tensors are read from the "tensors" folder located next to this file.
    :param device: Device on which the tensors are loaded.
    :return: (Tuple) Point coordinates X_world (first 3 channels of the field)
        and occupancy values occs (remaining channels).
    """
    parent_dir = os.path.dirname(path)
    load_directory = os.path.join(parent_dir, "tensors")

    file_name = "occupancy_field.pt"

    pc_dict = torch.load(os.path.join(load_directory, file_name),
                         map_location=device)

    X_world = pc_dict['occupancy_field'][..., :3]
    occs = pc_dict['occupancy_field'][..., 3:]

    return X_world, occs


def get_gt_surface(params, path, device, normalization_factor=None):
    parent_dir = os.path.dirname(path)
    load_directory = os.path.join(parent_dir, "tensors")
    file_name = "surface_points.pt"

    surface_dict = torch.load(os.path.join(load_directory, file_name),
                              map_location=device)
    gt_surface = surface_dict['surface_points']

    if params.surface_epsilon_is_constant:
        surface_epsilon = params.surface_epsilon
    else:
        surface_epsilon = surface_dict['epsilon']

    if normalization_factor is not None:
        gt_surface = gt_surface * normalization_factor
        surface_epsilon = surface_epsilon * normalization_factor

    return gt_surface, surface_epsilon


def get_optimal_sequence(optimal_sequences, mesh_path, n_views):
    key = os.path.basename(os.path.dirname(mesh_path))

    # optimal_seq = optimal_sequences[key]['idx']
    optimal_seq = torch.Tensor(optimal_sequences[key]['idx']).long()
    seq_coverage = optimal_sequences[key]['coverage']

    return optimal_seq[:n_views], seq_coverage[:n_views]


def compute_gt_coverage_gain_from_precomputed_matrices(coverage, initial_cam_idx):
    # Use .device rather than get_device(), which returns -1 for CPU tensors
    device = coverage.device

    n_camera_candidates, n_points_surface = coverage.shape[0], coverage.shape[1]

    # Compute coverage matrix of previous cameras, and the corresponding value
    coverage_matrix = torch.sum(coverage[initial_cam_idx], dim=0).view(1, n_points_surface).expand(n_camera_candidates, -1)
    previous_coverage = torch.mean(torch.heaviside(coverage_matrix,
                                                   values=torch.zeros_like(coverage_matrix,
                                                                           device=device)), dim=-1)
    # Compute coverage matrices of previous + new camera for every camera
    coverage_matrix = coverage_matrix + coverage
    coverage_matrix = torch.mean(torch.heaviside(coverage_matrix,
                                                 values=torch.zeros_like(coverage_matrix,
                                                                         device=device)), dim=-1)

    # Compute coverage gain value
    coverage_matrix = coverage_matrix - previous_coverage

    return coverage_matrix.view(-1, 1)


def compute_surface_coverage_from_cam_idx(coverage, cam_idx):
    # Use .device rather than get_device(), which returns -1 for CPU tensors
    device = coverage.device

    coverage_matrix = torch.sum(coverage[cam_idx], dim=0)

    coverage = torch.mean(torch.heaviside(coverage_matrix,
                                          values=torch.zeros_like(coverage_matrix,
                                                                  device=device)), dim=-1)

    return coverage.view(1)


def get_validation_n_views_list(params, dataloader):
    n_views = params.n_view_max - params.n_view_min + 1

    # Cycle through n_view_min..n_view_max often enough to cover the whole dataset
    n_views_list = np.repeat(np.arange(start=params.n_view_min,
                                       stop=params.n_view_max + 1).reshape(1, n_views),
                             len(dataloader.dataset) // n_views + 1, axis=0).reshape(-1)

    return n_views_list


def get_validation_n_view(params, n_views_list, batch, rank):
    idx = batch * params.total_batch_size + rank * params.batch_size

    return n_views_list[idx:idx + params.batch_size]


def get_validation_optimal_sequences(jz, device):
    if jz:
        SHAPENET_PATH = "../../datasets/ShapeNetCore.v1"
    else:
        SHAPENET_PATH = "../../../../../../mnt/ssd/aguedon/ShapeNetCore.v1/ShapeNetCore.v1"

    file_name = "validation_optimal_trajectories.pt"
    optimal_sequences = torch.load(os.path.join(SHAPENET_PATH, file_name),
                                   map_location=device)

    return optimal_sequences


def get_all_harmonics_under_degree(degree, n_elev, n_azim, device):
    """
    Gets values for all harmonics with l < degree.
    :param degree: (int) Harmonics are computed for every degree l in range(degree).
    :param n_elev: (int) Number of elevation values of the sampled directions.
    :param n_azim: (int) Number of azimuth values of the sampled directions.
    :param device: Device.
    :return: (Tuple) Harmonics values z, and the polar angles h_polar and
        azimuths h_azim (in radians) of the sampled directions.
    """
    h_elev = torch.Tensor(
        [-np.pi / 2 + (i + 1) / (n_elev + 1) * np.pi for i in range(n_elev) for j in range(n_azim)]).to(device)
    h_polar = -h_elev + np.pi / 2

    h_azim = torch.Tensor([2 * np.pi * j / n_azim for i in range(n_elev) for j in range(n_azim)]).to(device)

    # Use .device rather than get_device(), which returns -1 for CPU tensors
    z = torch.zeros([i for i in h_polar.shape] + [0], device=h_polar.device)

    clear_spherical_harmonics_cache()
    for l in range(degree):
        y = get_spherical_harmonics(l, h_polar, h_azim)
        z = torch.cat((z, y), dim=-1)

    z = z.transpose(dim0=0, dim1=1)

    return z, h_polar, h_azim


def get_cameras_on_sphere(params, device, pole_cameras=False, n_elev=None, n_azim=None):
    """
    Returns cameras candidate positions, sampled on a sphere.
    Made for SCONE pretraining on ShapeNet.
    :param params: (Params) The dictionary of parameters.
    :param device: Device.
    :param pole_cameras: (bool) If True, two cameras close to the poles
        (elevations -89.9 and 89.9 degrees) are added.
    :param n_elev: (int) Number of elevation values. If n_elev or n_azim is None,
        both values are read from params.
    :param n_azim: (int) Number of azimuth values.
    :return: A tuple of Tensors (X_cam, candidate_dist, candidate_elev, candidate_azim)
        X_cam has shape (n_camera_candidate, 3)
        All other tensors have shape (n_camera_candidate, )
    """
    if n_elev is None or n_azim is None:
        n_elev = params.n_camera_elev
        n_azim = params.n_camera_azim
        n_camera = params.n_camera
    else:
        n_camera = n_elev * n_azim
    if pole_cameras:
        n_camera += 2

    candidate_dist = torch.Tensor([params.camera_dist for i in range(n_camera)]).to(device)

    candidate_elev = [-90. + (i + 1) / (n_elev + 1) * 180.
                      for i in range(n_elev)
                      for j in range(n_azim)]

    candidate_azim = [360. * j / n_azim
                      for i in range(n_elev)
                      for j in range(n_azim)]

    if pole_cameras:
        candidate_elev = [-89.9] + candidate_elev + [89.9]
        candidate_azim = [0.] + candidate_azim + [0.]

    candidate_elev = torch.Tensor(candidate_elev).to(device)
    candidate_azim = torch.Tensor(candidate_azim).to(device)

    X_cam = get_cartesian_coords(r=candidate_dist.view(-1, 1),
                                 elev=candidate_elev.view(-1, 1),
                                 azim=candidate_azim.view(-1, 1),
                                 in_degrees=True)

    return X_cam, candidate_dist, candidate_elev, candidate_azim


def normalize_points_in_prediction_box(points, prediction_box_center, prediction_box_diag):
    """
    Centers points on the prediction box and scales them by the box diagonal.
    :param points: Points to normalize.
    :param prediction_box_center: Center of the prediction box.
    :param prediction_box_diag: Diagonal of the prediction box.
    :return: (points - prediction_box_center) / prediction_box_diag
    """
    return (points - prediction_box_center) / prediction_box_diag


def compute_view_state(pts, X_view, n_elev, n_azim):
    """
    Computes view_state vector for points pts and camera positions X_view.
    :param pts: Tensor with shape (n_cloud, seq_len, pts_dim) where pts_dim >= 3.
    :param X_view: Tensor with shape (n_screen_cameras, 3).
        Represents camera positions in prediction camera space coordinates.
697 | :param n_elev: Integer. Number of elevations values to discretize view states. 698 | :param n_azim: Integer. Number of azimuth values to discretize view states 699 | :return: A Tensor with shape (n_cloud, seq_len, n_elev*n_azim). 700 | """ 701 | # Initializing variables 702 | device = pts.get_device() 703 | n_view = len(X_view) 704 | n_clouds, seq_len, _ = pts.shape 705 | n_candidates = n_elev * n_azim 706 | 707 | elev_step = np.pi / (n_elev + 1) 708 | azim_step = 2 * np.pi / n_azim 709 | 710 | X_pts = pts[..., :3] 711 | 712 | # Computing camera elev and azim in every pts space coordinates 713 | rays = X_view.view(1, 1, n_view, 3).expand(n_clouds, seq_len, -1, -1) \ 714 | - X_pts.view(n_clouds, seq_len, 1, 3).expand(-1, -1, n_view, -1) 715 | 716 | _, ray_elev, ray_azim = get_spherical_coords(rays.view(-1, 3)) 717 | 718 | ray_elev = ray_elev.view(n_clouds, seq_len, n_view) 719 | ray_azim = ray_azim.view(n_clouds, seq_len, n_view) 720 | 721 | # Projecting elev and azim to the closest values in the discretized cameras 722 | idx_elev = floor_divide(ray_elev, elev_step) 723 | idx_azim = floor_divide(ray_azim, azim_step) 724 | 725 | # If closest to ceil than floor, we add 1 726 | idx_elev[ray_elev % elev_step > elev_step / 2.] += 1 727 | idx_azim[ray_azim % azim_step > azim_step / 2.] 
+= 1 728 | 729 | # Elevation can't be below minimal or above maximal values 730 | idx_elev[idx_elev >= n_elev] = n_elev - 1 731 | idx_elev[idx_elev < -n_elev // 2] = -n_elev // 2 732 | 733 | # If azimuth is greater than 180 degrees, we reset it back to -180 degrees 734 | idx_azim[idx_azim > n_azim // 2] = -n_azim // 2 735 | 736 | # Normalizing indices to retrieve camera positions in flattened view_state 737 | idx_elev += n_elev // 2 738 | idx_azim[idx_azim < 0] += n_azim 739 | 740 | indices = idx_elev.long() * n_azim + idx_azim.long() 741 | indices %= n_candidates 742 | q = torch.arange(start=0, end=n_clouds * seq_len, step=1, device=device).view(-1, 1).expand(-1, n_view) 743 | 744 | flat_indices = indices.view(-1, n_view) 745 | flat_indices = q * n_candidates + flat_indices 746 | flat_indices = flat_indices.view(-1) 747 | 748 | # Compute view_state and set visited camera values to 1 749 | view_state = torch.zeros(n_clouds, seq_len, n_candidates, device=device) 750 | view_state.view(-1)[flat_indices] = 1. 751 | 752 | return view_state 753 | 754 | 755 | def move_view_state_to_view_space(view_state, fov_camera, n_elev, n_azim): 756 | """ 757 | "Rotate" the view state vectors to the corresponding view space. 758 | 759 | :param view_state: (Tensor) View state tensor with shape (n_cloud, seq_len, n_elev * n_azim) 760 | :param fov_camera: (FoVPerspectiveCamera) 761 | :param n_elev: (int) 762 | :param n_azim: (int) 763 | :return: Rotated view state tensor with shape (n_cloud, seq_len, n_elev * n_azim) 764 | """ 765 | device = view_state.get_device() 766 | n_clouds = view_state.shape[0] 767 | seq_len = view_state.shape[1] 768 | 769 | n_view = n_elev * n_azim 770 | 771 | candidate_dist = torch.Tensor([1. for i in range(n_elev * n_azim)]).to(device) 772 | 773 | candidate_elev = [-90. + (i + 1) / (n_elev + 1) * 180. 774 | for i in range(n_elev) 775 | for j in range(n_azim)] 776 | 777 | candidate_azim = [360. 
* j / n_azim 778 | for i in range(n_elev) 779 | for j in range(n_azim)] 780 | 781 | candidate_elev = torch.Tensor(candidate_elev).to(device) 782 | candidate_azim = torch.Tensor(candidate_azim).to(device) 783 | 784 | X_cam_ref = get_cartesian_coords(r=candidate_dist.view(-1, 1), 785 | elev=candidate_elev.view(-1, 1), 786 | azim=candidate_azim.view(-1, 1), 787 | in_degrees=True) 788 | X_cam_inv = fov_camera.get_world_to_view_transform().inverse().transform_points( 789 | X_cam_ref) - fov_camera.get_camera_center() 790 | 791 | elev_step = np.pi / (n_elev + 1) 792 | azim_step = 2 * np.pi / n_azim 793 | 794 | _, ray_elev, ray_azim = get_spherical_coords(X_cam_inv.view(-1, 3)) 795 | 796 | ray_elev = ray_elev.view(n_view) 797 | ray_azim = ray_azim.view(n_view) 798 | 799 | # Projecting elev and azim to the closest values in the discretized cameras 800 | idx_elev = floor_divide(ray_elev, elev_step) 801 | idx_azim = floor_divide(ray_azim, azim_step) 802 | 803 | # If closest to ceil than floor, we add 1 804 | idx_elev[ray_elev % elev_step > elev_step / 2.] += 1 805 | idx_azim[ray_azim % azim_step > azim_step / 2.] 
+= 1 806 | 807 | # Elevation can't be below minimal or above maximal values 808 | idx_elev[idx_elev > n_elev // 2] = n_elev // 2 809 | idx_elev[idx_elev < -(n_elev // 2)] = -(n_elev // 2) 810 | 811 | # If azimuth is greater than 180 degrees, we reset it back to -180 degrees 812 | idx_azim[idx_azim > n_azim // 2] = -(n_azim // 2) 813 | 814 | # Normalizing indices to retrieve camera positions in flattened view_state 815 | idx_elev += n_elev // 2 816 | idx_azim[idx_azim < 0] += n_azim 817 | 818 | indices = idx_elev.long() * n_azim + idx_azim.long() 819 | 820 | rot_view_state = torch.gather(input=view_state, dim=2, index=indices.view(1, 1, -1).expand(n_clouds, seq_len, -1)) 821 | 822 | return rot_view_state 823 | 824 | 825 | 826 | def compute_view_harmonics(view_state, base_harmonics, h_polar, h_azim, n_elev, n_azim): 827 | """ 828 | Computes spherical harmonics corresponding to the view_state vector. 829 | :param view_state: Tensor with shape (n_cloud, seq_len, n_elev*n_azim). 830 | :param base_harmonics: Tensor with shape (n_harmonics, n_elev*n_azim). 
831 | :param h_polar: 832 | :param h_azim: 833 | :param n_elev: 834 | :param n_azim: 835 | :return: Tensor with shape (n_cloud, seq_len, n_harmonics) 836 | """ 837 | # Define parameters 838 | n_harmonics = base_harmonics.shape[0] 839 | n_clouds, seq_len, n_values = view_state.shape 840 | 841 | polar_step = np.pi / (n_elev + 1) 842 | azim_step = 2 * np.pi / n_azim 843 | 844 | # Expanding variables to parallelize computation 845 | all_values = view_state.view(n_clouds, seq_len, 1, n_values).expand(-1, -1, n_harmonics, -1) 846 | all_polar = h_polar.view(1, 1, 1, n_values).expand(n_clouds, seq_len, n_harmonics, -1) 847 | # all_harmonics = base_harmonics.view(1, 1, n_harmonics, n_values).expand(n_clouds, seq_len, -1, -1) 848 | 849 | # Computing spherical L2-dot product on last axis 850 | coordinates = torch.sum(all_values * base_harmonics * torch.sin(all_polar) * polar_step * azim_step, dim=-1) 851 | 852 | return coordinates 853 | 854 | 855 | # ----------Model functions---------- 856 | 857 | def compute_occupancy_probability(scone_occ, pc, X, view_harmonics, mask=None, 858 | max_points_per_pass=20000): 859 | """ 860 | 861 | :param scone_occ: (Scone_Occ) SCONE's Occupancy Probability prediction model. 862 | :param pc: (Tensor) Input point cloud tensor with shape (n_clouds, seq_len, pts_dim) 863 | :param X: (Tensor) Input query points tensor with shape (n_clouds, n_sample, x_dim) 864 | :param view_harmonics: (Tensor) View state harmonic features. Tensor with shape (n_clouds, seq_len, n_harmonics) 865 | :param max_points_per_pass: (int) Maximal number of points per forward pass. 
866 | :return: 867 | """ 868 | n_clouds, seq_len, pts_dim = pc.shape[0], pc.shape[1], pc.shape[2] 869 | n_sample, x_dim = X.shape[1], X.shape[2] 870 | n_harmonics = view_harmonics.shape[2] 871 | 872 | preds = torch.zeros(n_clouds, 0, 1).to(X.get_device()) 873 | 874 | p = max_points_per_pass // n_clouds 875 | q = n_sample // p 876 | r = n_sample % p 877 | n_loop = q 878 | if r != 0: 879 | n_loop += 1 880 | 881 | for i in range(n_loop): 882 | low_idx = i * p 883 | up_idx = (i + 1) * p 884 | if i == q: 885 | up_idx = q * p + r 886 | preds_i = scone_occ(pc, X[:, low_idx:up_idx], view_harmonics[:, low_idx:up_idx]) 887 | preds_i = preds_i.view(n_clouds, up_idx - low_idx, -1) 888 | preds = torch.cat((preds, preds_i), dim=1) 889 | 890 | return preds 891 | 892 | 893 | def filter_proxy_points(view_cameras, X, pc, filter_tol=0.01): 894 | """ 895 | Filter proxy points considering camera field of view and partial surface point cloud. 896 | WARNING: Works for a single scene! So X must have shape (n_proxy_points, 3)! 897 | :param view_cameras: 898 | :param X: (Tensor) Proxy points tensor with shape () 899 | :param pc: (Tensor) 900 | :param filter_tol: 901 | :return: 902 | """ 903 | 904 | n_view = view_cameras.R.shape[0] 905 | 906 | if (len(X.shape) != 2) or (len(pc.shape) != 2): 907 | raise NameError("Wrong shapes! 
X must have shape (n_proxy_points, 3) and pc must have shape (N, 3).") 908 | 909 | view_projection_transform = view_cameras.get_full_projection_transform() 910 | X_proj = view_projection_transform.transform_points(X)[..., :2].view(n_view, -1, 2) 911 | pc_proj = view_projection_transform.transform_points(pc)[..., :2].view(n_view, -1, 2) 912 | 913 | max_proj = torch.max(pc_proj, dim=-2, keepdim=True)[0].expand(-1, X_proj.shape[-2], -1) 914 | min_proj = torch.min(pc_proj, dim=-2, keepdim=True)[0].expand(-1, X_proj.shape[-2], -1) 915 | 916 | filter_mask = torch.prod((X_proj < max_proj + filter_tol) * (X_proj > min_proj - filter_tol), dim=0) 917 | filter_mask = torch.prod(filter_mask, dim=-1).bool() 918 | 919 | return X[filter_mask], filter_mask 920 | 921 | 922 | def sample_proxy_points(X_world, preds, view_harmonics, n_sample, min_occ, use_occ_to_sample=True, 923 | return_index=False): 924 | """ 925 | 926 | :param X: Tensor with shape (n_points, 3) 927 | :param preds: Tensor with shape (n_points, 1) 928 | :param view_harmonics: Tensor with shape (n_points, n_harmonics) 929 | :param n_sample: integer 930 | :return: 931 | """ 932 | mask = preds[..., 0] > min_occ 933 | res_X = X_world[mask] 934 | res_preds = preds[mask] 935 | res_harmonics = view_harmonics[mask] 936 | 937 | device = res_X.get_device() 938 | n_points = res_X.shape[0] 939 | 940 | if use_occ_to_sample: 941 | sample_probs = res_preds[..., 0] / torch.sum(res_preds) 942 | sample_probs = torch.cumsum(sample_probs, dim=-1) 943 | 944 | samples = torch.rand(n_sample, 1, device=device) 945 | 946 | res_idx = sample_probs.view(1, n_points).expand(n_sample, -1) - samples.expand(-1, n_points) 947 | res_idx[res_idx < 0] = 2 948 | res_idx = torch.argmin(res_idx, dim=-1) 949 | 950 | res_idx, inverse_idx = torch.unique(res_idx, dim=0, return_inverse=True) 951 | 952 | res = torch.cat((res_X[res_idx], res_preds[res_idx]), dim=-1) 953 | res_harmonics = res_harmonics[res_idx] 954 | # res = torch.unique(res, dim=0) 955 | 956 | 
else: 957 | if len(res_X) > n_sample: 958 | res_X = res_X[:n_sample] 959 | res_preds = res_preds[:n_sample] 960 | res_harmonics = res_harmonics[:n_sample] 961 | 962 | res = torch.cat((res_X, res_preds), dim=-1) 963 | inverse_idx = None 964 | 965 | if return_index: 966 | return res, res_harmonics, inverse_idx 967 | else: 968 | return res, res_harmonics 969 | -------------------------------------------------------------------------------- /SCONE/spherical_harmonics.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/lucidrains/se3-transformer-pytorch/blob/main/se3_transformer_pytorch/spherical_harmonics.py 2 | 3 | from math import pi, sqrt 4 | from operator import mul 5 | import torch 6 | from functools import reduce, wraps 12 | 13 | 14 | def cache(cache, key_fn): 15 | def cache_inner(fn): 16 | @wraps(fn) 17 | def inner(*args, **kwargs): 18 | key_name = key_fn(*args, **kwargs) 19 | if key_name in cache: 20 | return cache[key_name] 21 | res = fn(*args, **kwargs) 22 | cache[key_name] = res 23 | return res 24 | 25 | return inner 26 | 27 | return cache_inner 28 | 29 | 30 | # constants 31 | 32 | CACHE = {} 33 | 34 | 35 | def clear_spherical_harmonics_cache(): 36 | CACHE.clear() 37 | 38 | 39 | def lpmv_cache_key_fn(l, m, x): 40 | return (l, m) 41 | 42 | 43 | # spherical harmonics 44 | 45 | # def reduce(function, iterable, initializer=None): 46 | # it = iter(iterable) 47 | # if initializer is None: 48 | # value = next(it) 49 | # else: 50 | # value = initializer 51 | # for element in it: 52 | # value = function(value, element) 53 | # return value 54 | 55 | def semifactorial(x): 56 | return reduce(mul, range(x, 1, -2), 1.) 
57 | 58 | 59 | def pochhammer(x, k): 60 | return reduce(mul, range(x + 1, x + k), float(x)) 61 | 62 | 63 | def negative_lpmv(l, m, y): 64 | if m < 0: 65 | y *= ((-1) ** m / pochhammer(l + m + 1, -2 * m)) 66 | return y 67 | 68 | 69 | @cache(cache=CACHE, key_fn=lpmv_cache_key_fn) 70 | def lpmv(l, m, x): 71 | """Associated Legendre function including Condon-Shortley phase. 72 | Args: 73 | m: int order 74 | l: int degree 75 | x: float argument tensor 76 | Returns: 77 | tensor of x-shape 78 | """ 79 | # Check memoized versions 80 | m_abs = abs(m) 81 | 82 | if m_abs > l: 83 | return None 84 | 85 | if l == 0: 86 | x_device = x.get_device() 87 | if x_device < 0: 88 | x_device = "cpu" 89 | return torch.ones_like(x, device=x_device) 90 | 91 | # Check if on boundary else recurse solution down to boundary 92 | if m_abs == l: 93 | # Compute P_m^m 94 | y = (-1) ** m_abs * semifactorial(2 * m_abs - 1) 95 | y *= torch.pow(1 - x * x, m_abs / 2) 96 | return negative_lpmv(l, m, y) 97 | 98 | # Recursively precompute lower degree harmonics 99 | lpmv(l - 1, m, x) 100 | 101 | # Compute P_{l}^m from recursion in P_{l-1}^m and P_{l-2}^m 102 | # Inplace speedup 103 | y = ((2 * l - 1) / (l - m_abs)) * x * lpmv(l - 1, m_abs, x) 104 | 105 | if l - m_abs > 1: 106 | y -= ((l + m_abs - 1) / (l - m_abs)) * CACHE[(l - 2, m_abs)] 107 | 108 | if m < 0: 109 | y = negative_lpmv(l, m, y) 110 | return y 111 | 112 | 113 | def get_spherical_harmonics_element(l, m, theta, phi): 114 | """Tesseral spherical harmonic with Condon-Shortley phase. 115 | The Tesseral spherical harmonics are also known as the real spherical 116 | harmonics. 
117 | Args: 118 | l: int for degree 119 | m: int for order, where -l <= m <= l 120 | theta: colatitude or polar angle 121 | phi: longitude or azimuth 122 | Returns: 123 | tensor with the same shape as theta 124 | """ 125 | m_abs = abs(m) 126 | assert m_abs <= l, "absolute value of order m must be <= degree l" 127 | 128 | N = sqrt((2 * l + 1) / (4 * pi)) 129 | leg = lpmv(l, m_abs, torch.cos(theta)) 130 | 131 | if m == 0: 132 | return N * leg 133 | 134 | if m > 0: 135 | Y = torch.cos(m * phi) 136 | else: 137 | Y = torch.sin(m_abs * phi) 138 | 139 | Y *= leg 140 | N *= sqrt(2. / pochhammer(l - m_abs + 1, 2 * m_abs)) 141 | Y *= N 142 | return Y 143 | 144 | 145 | def get_spherical_harmonics(l, theta, phi): 146 | """ Tesseral harmonic with Condon-Shortley phase. 147 | The Tesseral spherical harmonics are also known as the real spherical 148 | harmonics. 149 | Args: 150 | l: int for degree 151 | theta: colatitude or polar angle 152 | phi: longitude or azimuth 153 | Returns: 154 | tensor of shape [*theta.shape, 2*l+1] 155 | """ 156 | return torch.stack([get_spherical_harmonics_element(l, m, theta, phi) \ 157 | for m in range(-l, l + 1)], 158 | dim=-1) 159 | 160 | 161 | def evaluate_from_harmonic_coordinates(coordinates, theta, phi, degree): 162 | """ 163 | We denote by N the number of harmonics with l < degree. 164 | WARNING! Theta should represent the polar angles, not elevation! 165 | :param coordinates: must have shape (N) 166 | :param theta: Tensor that represents polar angles. 167 | :param phi: Tensor that represents azimuth angles. 
168 | :param degree: int 169 | :return: A Tensor with shape theta 170 | """ 171 | z = torch.zeros([i for i in theta.shape] + [0], device=theta.get_device()) 172 | 173 | for l in range(degree): 174 | y = get_spherical_harmonics(l, theta, phi) 175 | z = torch.cat((z, y), dim=-1) 176 | 177 | return torch.sum(coordinates * z, dim=-1) 178 | -------------------------------------------------------------------------------- /SCONE/test_scone.py: -------------------------------------------------------------------------------- 1 | from scone_utils import * 2 | 3 | pc_size = 1024 4 | n_view_max = 10 5 | test_novel = True 6 | test_number = -1 7 | 8 | params_name = "train_params_jz_model_scone_vis_surface_coverage_gain_constant_epsilon_uncentered_l1_warmup_1000_schedule_lr_0.0001_sigmoid_tmcs.json" 9 | scone_occ_model_name = "best_unval_jz_model_scone_occ_mse_warmup_1000_schedule_lr_0.0001.pth" 10 | scone_vis_model_name = "best_unval_jz_model_scone_vis_surface_coverage_gain_constant_epsilon_uncentered_l1_warmup_1000_schedule_lr_0.0001_sigmoid_tmcs.pth" 11 | 12 | 13 | def save_test_params(save=False): 14 | # TO DO: for ddp, save only if is_master but load for everyone, with a synchronization in-between. 
15 | params = {} 16 | 17 | # -----General parameters----- 18 | params["ddp"] = False 19 | params["jz"] = False 20 | 21 | # TO ADAPT 22 | # params["coverage_validation"] = False 23 | params["test_number"] = test_number 24 | params["test_novel"] = test_novel 25 | # params["n_screen_cameras"] = 1 26 | params["pc_size"] = 1024 # 2048 27 | 28 | if params["ddp"]: 29 | params["CUDA_VISIBLE_DEVICES"] = "0, 1" 30 | params["WORLD_SIZE"] = 2 31 | 32 | elif params["jz"]: 33 | params["WORLD_SIZE"] = idr_torch.size 34 | 35 | else: 36 | params["numGPU"] = 1 37 | params["WORLD_SIZE"] = 1 38 | 39 | params["anomaly_detection"] = True 40 | params["empty_cache_every_n_batch"] = 10 41 | 42 | # -----Ground truth computation parameters----- 43 | params["compute_gt_online"] = False 44 | params["compute_partial_point_cloud_online"] = False 45 | 46 | params["gt_surface_resolution"] = 1.5 47 | 48 | params["gt_max_diagonal"] = 1. # Formerly known as x_range 49 | 50 | params["n_points_surface"] = 16384 # N points on GT surface 51 | 52 | params["surface_epsilon_is_constant"] = True 53 | if params["surface_epsilon_is_constant"]: 54 | params["surface_epsilon"] = 0.00707 55 | 56 | # -----SconeOcc Model Parameters----- 57 | # params["scone_occ_model_name"] = "best_unval_jz_model_scone_occ_mse_warmup_1000_schedule_lr_0.0001.pth" 58 | params["scone_occ_model_name"] = "best_unval_jz_model_scone_occ_mse_warmup_1000_schedule_lr_0.0001.pth" 59 | params["occ_no_view_harmonics"] = False 60 | 61 | params["n_view_max_for_scone_occ"] = 10 62 | params["max_points_per_scone_occ_pass"] = 300000 63 | 64 | # -----Model Parameters----- 65 | params["seq_len"] = 2048 66 | params["pts_dim"] = 4 67 | 68 | params["view_state_n_elev"] = 7 69 | params["view_state_n_azim"] = 2 * 7 70 | params["harmonic_degree"] = 8 71 | 72 | params["n_proxy_points"] = 100000 # 12000, 100000 73 | params["use_occ_to_sample_proxy_points"] = True # True 74 | params["min_occ_for_proxy_points"] = 0.1 75 | 76 | params["true_monte_carlo_sampling"] 
= True 77 | 78 | # -----Ablation study----- 79 | params["no_view_harmonics"] = False 80 | params["use_sigmoid"] = True 81 | 82 | # -----General training parameters----- 83 | params["start_from_scratch"] = True 84 | params["pretrained_weights_name"] = None 85 | 86 | params["n_view_max"] = 10 87 | params["n_view_min"] = 1 88 | params["filter_tol"] = 0.01 89 | 90 | params["camera_dist"] = 1.5 91 | params["pole_cameras"] = True 92 | params["n_camera_elev"] = 5 93 | params["n_camera_azim"] = 2 * 5 94 | params["n_camera"] = params["n_camera_elev"] * params["n_camera_azim"] 95 | if params["pole_cameras"]: 96 | params["n_camera"] += 2 97 | 98 | params["prediction_in_random_camera_space"] = False 99 | 100 | params["batch_size"] = 4 101 | 102 | params["noam_opt"] = False 103 | params["training_metric"] = "surface_coverage_gain" 104 | # Training metric can be: "surface_coverage", "surface_coverage_gain", "absolute_coverage" 105 | params["training_loss"] = "uncentered_l1" # "kl_divergence", "l1", "uncentered_l1" 106 | params["multiply_loss"] = False 107 | if params["multiply_loss"]: 108 | params["loss_multiplication_factor"] = 10. 
109 | 110 | params["nbv_validation"] = True 111 | 112 | params["random_seed"] = 42 113 | params["torch_seed"] = 5 114 | 115 | # -----Visibility Model name to save----- 116 | params["scone_vis_model_name"] = scone_vis_model_name 117 | 118 | # -----Json name to save params----- 119 | json_name = "test_params_" + scone_vis_model_name + ".json" 120 | 121 | if save: 122 | with open(json_name, 'w') as outfile: 123 | json.dump(params, outfile) 124 | 125 | print("Parameters saved in:") 126 | print(json_name) 127 | 128 | return json_name 129 | 130 | 131 | def test_loop(params, dataloader, 132 | scone_occ, scone_vis, 133 | device): 134 | 135 | # Begin test process 136 | coverage_dict = {} 137 | sum_coverages = torch.zeros(params.n_view_max, device=device) 138 | 139 | t0 = time.time() 140 | torch.cuda.empty_cache() 141 | 142 | print("Beginning evaluation on test dataset...") 143 | 144 | base_harmonics, h_polar, h_azim = get_all_harmonics_under_degree(params.harmonic_degree, 145 | params.view_state_n_elev, 146 | params.view_state_n_azim, 147 | device) 148 | 149 | size = len(dataloader.dataset) 150 | num_batches = len(dataloader) 151 | 152 | t0 = time.time() 153 | computation_time = 0. 154 | 155 | for batch, (mesh_dict) in enumerate(dataloader): 156 | paths = mesh_dict['path'] 157 | batch_size = len(paths) 158 | 159 | for i in range(batch_size): 160 | # ----------Load input mesh and ground truth data----------------------------------------------------------- 161 | 162 | path_i = paths[i] 163 | 164 | coverage_dict[path_i] = [] 165 | coverages = torch.zeros(params.n_view_max, device=device) 166 | 167 | # Loading info about partial point clouds and coverages 168 | part_pc, coverage_matrix = get_gt_partial_point_clouds(path=path_i, 169 | normalization_factor=1. 
/ params.gt_surface_resolution, 170 | device=device) 171 | 172 | # Initial dense sampling 173 | X_world = sample_X_in_box(x_range=params.gt_max_diagonal, n_sample=params.n_proxy_points, device=device) 174 | 175 | # ----------Set camera candidates for coverage prediction--------------------------------------------------- 176 | X_cam_world, camera_dist, camera_elev, camera_azim = get_cameras_on_sphere(params, device, 177 | pole_cameras=params.pole_cameras) 178 | n_view = 1 179 | view_idx = torch.randperm(len(camera_elev), device=device)[:n_view] 180 | 181 | prediction_cam_idx = view_idx[0] 182 | prediction_box_center = torch.Tensor([0., 0., params.camera_dist]).to(device) 183 | 184 | # Move camera coordinates from world space to prediction view space, and normalize them for prediction box 185 | prediction_R, prediction_T = look_at_view_transform(dist=camera_dist[prediction_cam_idx], 186 | elev=camera_elev[prediction_cam_idx], 187 | azim=camera_azim[prediction_cam_idx], 188 | device=device) 189 | prediction_camera = FoVPerspectiveCameras(device=device, R=prediction_R, T=prediction_T) 190 | prediction_view_transform = prediction_camera.get_world_to_view_transform() 191 | 192 | X_cam = prediction_view_transform.transform_points(X_cam_world) 193 | X_cam = normalize_points_in_prediction_box(points=X_cam, 194 | prediction_box_center=prediction_box_center, 195 | prediction_box_diag=params.gt_max_diagonal) 196 | _, elev_cam, azim_cam = get_spherical_coords(X_cam) 197 | 198 | X_view = X_cam[view_idx] 199 | # X_cam = X_cam.view(1, params.n_camera, 3) 200 | 201 | # Compute initial coverage 202 | coverage = compute_surface_coverage_from_cam_idx(coverage_matrix, view_idx).detach().item() 203 | 204 | coverage_dict[path_i].append(coverage) 205 | coverages[0] += coverage 206 | 207 | # Sample random proxy points in space 208 | X_idx = torch.randperm(len(X_world))[:params.n_proxy_points] 209 | X_world = X_world[X_idx] 210 | 211 | for j_view in range(1, params.n_view_max): 212 | 
features = None 213 | args = None 214 | computation_t0 = time.time() 215 | 216 | # ----------Capture initial observations---------------------------------------------------------------- 217 | 218 | # Points observed in initial views 219 | pc = torch.vstack([part_pc[pc_idx] for pc_idx in view_idx]) 220 | 221 | # Downsampling partial point cloud 222 | # pc = pc[torch.randperm(len(pc))[:n_view * params.seq_len]] 223 | pc = pc[torch.randperm(len(pc))[:n_view * pc_size]] 224 | 225 | # Move partial point cloud from world space to prediction view space, 226 | # and normalize them in prediction box 227 | pc = prediction_view_transform.transform_points(pc) 228 | pc = normalize_points_in_prediction_box(points=pc, 229 | prediction_box_center=prediction_box_center, 230 | prediction_box_diag=params.gt_max_diagonal).view(1, -1, 3) 231 | 232 | # Move proxy points from world space to prediction view space, and normalize them in prediction box 233 | X = prediction_view_transform.transform_points(X_world) 234 | X = normalize_points_in_prediction_box(points=X, 235 | prediction_box_center=prediction_box_center, 236 | prediction_box_diag=params.gt_max_diagonal 237 | ) 238 | 239 | # Filter Proxy Points using pc shape from view cameras 240 | R_view, T_view = look_at_view_transform(eye=X_view, 241 | at=torch.zeros_like(X_view), 242 | device=device) 243 | view_cameras = FoVPerspectiveCameras(R=R_view, T=T_view, zfar=1000, device=device) 244 | X, _ = filter_proxy_points(view_cameras, X, pc.view(-1, 3), filter_tol=params.filter_tol) 245 | X = X.view(1, X.shape[0], 3) 246 | 247 | # Compute view state vector and corresponding view harmonics 248 | view_state = compute_view_state(X, X_view, 249 | params.view_state_n_elev, params.view_state_n_azim) 250 | view_harmonics = compute_view_harmonics(view_state, 251 | base_harmonics, h_polar, h_azim, 252 | params.view_state_n_elev, params.view_state_n_azim) 253 | occ_view_harmonics = 0. 
+ view_harmonics 254 | if params.occ_no_view_harmonics: 255 | occ_view_harmonics *= 0. 256 | if params.no_view_harmonics: 257 | view_harmonics *= 0. 258 | 259 | # Compute occupancy probabilities 260 | with torch.no_grad(): 261 | occ_prob_i = compute_occupancy_probability(scone_occ=scone_occ, 262 | pc=pc, 263 | X=X, 264 | view_harmonics=occ_view_harmonics, 265 | max_points_per_pass=params.max_points_per_scone_occ_pass 266 | ).view(-1, 1) 267 | 268 | proxy_points, view_harmonics, sample_idx = sample_proxy_points(X[0], occ_prob_i, 269 | view_harmonics.squeeze(dim=0), 270 | n_sample=params.seq_len, 271 | min_occ=params.min_occ_for_proxy_points, 272 | use_occ_to_sample=params.use_occ_to_sample_proxy_points, 273 | return_index=True) 274 | 275 | proxy_points = torch.unsqueeze(proxy_points, dim=0) 276 | view_harmonics = torch.unsqueeze(view_harmonics, dim=0) 277 | 278 | # ----------Predict Coverage Gains------------------------------------------------------------------------------ 279 | visibility_gain_harmonics = scone_vis(proxy_points, view_harmonics=view_harmonics) 280 | if params.true_monte_carlo_sampling: 281 | proxy_points = torch.unsqueeze(proxy_points[0][sample_idx], dim=0) 282 | visibility_gain_harmonics = torch.unsqueeze(visibility_gain_harmonics[0][sample_idx], dim=0) 283 | 284 | if params.ddp or params.jz: 285 | cov_pred_i = scone_vis.module.compute_coverage_gain(proxy_points, 286 | visibility_gain_harmonics, 287 | X_cam) 288 | else: 289 | cov_pred_i = scone_vis.compute_coverage_gain(proxy_points, 290 | visibility_gain_harmonics, 291 | X_cam.view(1, -1, 3)).view(-1, 1) 292 | 293 | # Identify maximum gain to get NBV camera 294 | (max_gain, max_idx) = torch.max(cov_pred_i, dim=0) 295 | 296 | computation_time += time.time() - computation_t0 297 | 298 | # Set NBV camera parameters 299 | view_idx = torch.cat((view_idx, torch.Tensor([max_idx]).long().to(device)), dim=0) 300 | X_nbv = X_cam[max_idx:max_idx + 1] 301 | X_nbv_world = X_cam_world[max_idx:max_idx + 1] 302 
| r_nbv_world = camera_dist[max_idx:max_idx + 1] 303 | elev_nbv_world = camera_elev[max_idx:max_idx + 1] 304 | azim_nbv_world = camera_azim[max_idx:max_idx + 1] 305 | 306 | nbv_dist = torch.Tensor([params.camera_dist]).to(device) 307 | 308 | nbv_R, nbv_T = look_at_view_transform(dist=r_nbv_world, 309 | elev=elev_nbv_world, 310 | azim=azim_nbv_world, 311 | device=device) 312 | 313 | R_view = torch.vstack((R_view, nbv_R)) 314 | # screen_dist = torch.vstack((screen_dist, nbv_dist)) 315 | X_view = torch.vstack((X_view, X_nbv)) 316 | 317 | # Computing surface coverage 318 | coverage = compute_surface_coverage_from_cam_idx(coverage_matrix, view_idx).detach().item() 319 | # avg_coverage += coverage / batch_size 320 | coverage_dict[path_i].append(coverage) 321 | coverages[j_view] += coverage 322 | 323 | sum_coverages += coverages 324 | 325 | # ----------Metrics computation on batch---------- 326 | 327 | if batch % 10 == 0: 328 | torch.cuda.empty_cache() 329 | print("--- Batch", batch, "---") 330 | print("Batch size:", batch_size) 331 | # print("Coverage:", avg_coverage / (batch + 1)) 332 | print("Coverages:", sum_coverages / ((batch + 1) * params.batch_size)) 333 | print("Nb of meshes done:", (batch + 1) * params.batch_size) 334 | print("Computation time:", computation_time, '\n') 335 | 336 | results = {} 337 | # results["occ_threshold"] = occ_threshold 338 | # results["uncertainty_threshold"] = params.uncertainty_threshold 339 | # results["uncertainty_mode"] = params.uncertainty_mode 340 | # results["compute_cross_correction"] = params.compute_cross_correction 341 | # results["nbv_mode"] = nbv_mode 342 | results["coverages"] = coverage_dict 343 | 344 | # print("Results:", results) 345 | 346 | print("Avg coverages loss:", sum_coverages.detach().cpu() / len(dataloader.dataset)) 347 | print("Done in", (time.time() - t0) / 3600., "hours!") 348 | print("Computation time:", computation_time) 349 | print("Average computation time:", computation_time / len(dataloader.dataset)) 
350 | 351 | print("Terminated in", (time.time() - t0) / 60., "minutes.") 352 | return results 353 | 354 | if __name__ == '__main__': 355 | json_name = save_test_params(True) 356 | params = load_params(json_name) 357 | params.n_view_max = n_view_max 358 | 359 | # Set device 360 | device = setup_device(params, ddp_rank=None) 361 | 362 | # Load models 363 | print("Loading SconeOcc...") 364 | scone_occ = load_scone_occ(params, scone_occ_model_name, ddp_model=True, device=device) 365 | print("Model has", count_parameters(scone_occ) / 1e6, "M parameters.") 366 | scone_occ.eval() 367 | 368 | print("Loading SconeVis...") 369 | scone_vis = load_scone_vis(params, scone_vis_model_name, ddp_model=True, device=device) 370 | print("Model has", count_parameters(scone_vis) / 1e6, "M parameters.") 371 | scone_vis.eval() 372 | 373 | if test_novel and params.test_novel: 374 | print("Test on novel categories.") 375 | 376 | train_dataloader, val_dataloader, test_dataloader = get_shapenet_dataloader(batch_size=params.batch_size, 377 | ddp=params.ddp, jz=params.jz, 378 | world_size=None, ddp_rank=None, 379 | test_number=params.test_number, 380 | test_novel=test_novel, 381 | load_obj=False, 382 | data_path=None) 383 | 384 | # Main loop 385 | eval_results = [] 386 | eval_results.append(test_loop(params, test_dataloader, 387 | scone_occ, scone_vis, device)) 388 | 389 | if params.test_number == 0: 390 | json_name = "test_iterative_results_for_models_" + scone_vis_model_name + ".json" 391 | elif params.test_number == -1: 392 | json_name = "full_test_iterative_results_for_models_" + scone_vis_model_name + ".json" 393 | else: 394 | json_name = "test_" + str(params.test_number)\ 395 | + "_iterative_results_for_models_" + scone_vis_model_name + "_v2.json" 396 | 397 | if test_novel: 398 | json_name = "novel_" + json_name 399 | 400 | # json_name = "novel_test_" + str(params.test_number) +"_iterative_results_for_random_nbv_faster.json" 401 | 402 | for res in eval_results: 403 | for key in res: 404 | 
if type(res[key]) == torch.Tensor: 405 | res[key] = res[key].detach().item() 406 | 407 | if type(res[key]) == np.ndarray: 408 | res[key] = float(res[key]) 409 | 410 | with open(json_name, 'w') as outfile: 411 | json.dump(eval_results, outfile) 412 | print("Saved data about test losses in", json_name) -------------------------------------------------------------------------------- /docs/gifs/colosseum.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anttwo/SCONE/24f176518a794b69c5aeed81d22c0077ce4b9169/docs/gifs/colosseum.gif -------------------------------------------------------------------------------- /docs/gifs/fushimi.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anttwo/SCONE/24f176518a794b69c5aeed81d22c0077ce4b9169/docs/gifs/fushimi.gif -------------------------------------------------------------------------------- /docs/gifs/museum.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anttwo/SCONE/24f176518a794b69c5aeed81d22c0077ce4b9169/docs/gifs/museum.gif -------------------------------------------------------------------------------- /docs/gifs/pantheon.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anttwo/SCONE/24f176518a794b69c5aeed81d22c0077ce4b9169/docs/gifs/pantheon.gif --------------------------------------------------------------------------------
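
The candidate camera grid that `test_scone.py` obtains from `get_cameras_on_sphere` can be sketched without PyTorch. The block below is a minimal illustration under stated assumptions, not the repository's implementation: `candidate_camera_angles` and `to_cartesian` are hypothetical helper names, and the y-up spherical-to-Cartesian convention in `to_cartesian` is an assumption, since the body of `get_cartesian_coords` is not reproduced here. The angle formulas and the 89.9 degree pole cameras are taken from `get_cameras_on_sphere`.

```python
import math


def candidate_camera_angles(n_elev, n_azim, pole_cameras=False):
    """Return (elev, azim) lists in degrees for candidate cameras on a sphere.

    Same grid as get_cameras_on_sphere: elevations lie strictly inside
    (-90, 90), azimuths cover [0, 360), and the azimuth index varies fastest.
    """
    elev = [-90.0 + (i + 1) / (n_elev + 1) * 180.0
            for i in range(n_elev) for _ in range(n_azim)]
    azim = [360.0 * j / n_azim
            for _ in range(n_elev) for j in range(n_azim)]
    if pole_cameras:
        # Two extra cameras just short of the poles, prepended and appended.
        elev = [-89.9] + elev + [89.9]
        azim = [0.0] + azim + [0.0]
    return elev, azim


def to_cartesian(dist, elev_deg, azim_deg):
    """Hypothetical stand-in for get_cartesian_coords.

    ASSUMPTION: y-up convention, with elevation measured from the xz-plane.
    """
    elev, azim = math.radians(elev_deg), math.radians(azim_deg)
    return (dist * math.cos(elev) * math.sin(azim),
            dist * math.sin(elev),
            dist * math.cos(elev) * math.cos(azim))


# Values taken from save_test_params: n_camera_elev=5, n_camera_azim=10,
# pole_cameras=True, camera_dist=1.5.
elev, azim = candidate_camera_angles(n_elev=5, n_azim=10, pole_cameras=True)
positions = [to_cartesian(1.5, e, a) for e, a in zip(elev, azim)]
print(len(positions))  # 5 * 10 + 2 candidates
```

With these test parameters the grid has 5 * 10 + 2 = 52 candidates, which matches how `params["n_camera"]` is computed in `save_test_params`.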