├── model
├── __init__.py
├── diff_utils
│ ├── __init__.py
│ ├── distributed.py
│ ├── pix3d_util.py
│ ├── util.py
│ ├── visualizer.py
│ └── demo_util.py
├── networks
│ ├── __init__.py
│ ├── diffusion_shape
│ │ ├── __init__.py
│ │ ├── diff_utils
│ │ │ ├── __init__.py
│ │ │ ├── distributed.py
│ │ │ ├── pix3d_util.py
│ │ │ ├── util.py
│ │ │ ├── visualizer.py
│ │ │ └── demo_util.py
│ │ ├── network.py
│ │ ├── base_model.py
│ │ └── sg_diff.py
│ ├── vqvae_networks
│ │ ├── __init__.py
│ │ ├── network.py
│ │ ├── quantizer.py
│ │ └── vqvae_model.py
│ ├── clip_networks
│ │ └── network.py
│ └── diffusion_layout
│ │ ├── loss.py
│ │ └── mmg2layout.py
├── layers.py
├── losses.py
├── model_utils.py
└── graph.py
├── dataset
├── __init__.py
└── util.py
├── helpers
├── __init__.py
├── interrupt_handler.py
├── psutil.py
├── visualize_graph.py
├── lr_scheduler.py
└── viz_util.py
├── scripts
├── __init__.py
├── pytorch_structural_losses
│ ├── .gitignore
│ ├── __init__.py
│ ├── src
│ │ ├── nndistance.cuh
│ │ ├── approxmatch.cuh
│ │ ├── utils.hpp
│ │ ├── nndistance.cu
│ │ └── structural_loss.cpp
│ ├── pybind
│ │ ├── bind.cpp
│ │ └── extern.hpp
│ ├── setup.py
│ ├── nn_distance.py
│ ├── match_cost.py
│ └── Makefile
├── StructuralLosses
│ ├── __init__.py
│ ├── nn_distance.py
│ └── match_cost.py
├── collect_gt_sdf_images.py
└── compute_fid_scores_3dfront.py
├── assets
├── pku.png
├── tum.png
├── teaser.png
├── beihang.png
├── pipeline.png
├── shuyuan.png
├── thetaicon.png
├── thetacloud2.png
├── icon.svg
└── edgecloud-logo.svg
├── requirements.txt
├── scripts_sh
├── mmd_cov_1nn.sh
├── consistency_check.sh
├── eval_all_mask.sh
├── train_all_mask.sh
└── compute_fid_scores.sh
├── extension
└── old_chamfer
│ ├── setup.py
│ ├── test.py
│ ├── chamfer_cuda.cpp
│ ├── dist_chamfer.py
│ └── chamfer.cu
├── config
├── vqvae_snet.yaml
├── threedfront_objfeat_vqvae.yaml
├── sdfusion-txt2shape_mp.yaml
└── full_mp.yaml
├── LICENSE
└── .gitignore
/model/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/helpers/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/model/diff_utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/model/networks/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/model/networks/diffusion_shape/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/model/networks/vqvae_networks/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/model/networks/diffusion_shape/diff_utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/scripts/pytorch_structural_losses/.gitignore:
--------------------------------------------------------------------------------
1 | PyTorchStructuralLosses.egg-info/
2 | 
--------------------------------------------------------------------------------
/assets/pku.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangzhifeio/MMGDreamer/HEAD/assets/pku.png
--------------------------------------------------------------------------------
/assets/tum.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangzhifeio/MMGDreamer/HEAD/assets/tum.png
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangzhifeio/MMGDreamer/HEAD/assets/teaser.png
--------------------------------------------------------------------------------
/assets/beihang.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangzhifeio/MMGDreamer/HEAD/assets/beihang.png
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangzhifeio/MMGDreamer/HEAD/assets/pipeline.png
--------------------------------------------------------------------------------
/assets/shuyuan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangzhifeio/MMGDreamer/HEAD/assets/shuyuan.png
--------------------------------------------------------------------------------
/assets/thetaicon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangzhifeio/MMGDreamer/HEAD/assets/thetaicon.png
--------------------------------------------------------------------------------
/assets/thetacloud2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangzhifeio/MMGDreamer/HEAD/assets/thetacloud2.png
--------------------------------------------------------------------------------
/scripts/StructuralLosses/__init__.py:
--------------------------------------------------------------------------------
1 | #import torch
2 | 
3 | #from MakePytorchBackend import AddGPU, Foo, ApproxMatch
4 | 
5 | #from Add import add_gpu, approx_match
6 | 
7 | 
--------------------------------------------------------------------------------
/scripts/pytorch_structural_losses/__init__.py:
--------------------------------------------------------------------------------
1 | #import torch
2 | 
3 | #from MakePytorchBackend import AddGPU, Foo, ApproxMatch
4 | 
5 | #from Add import add_gpu, approx_match
6 | 
7 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.7.2
2 | numpy==1.23
3 | scipy==1.10
4 | tqdm==4.65.0
5 | wheel==0.38.4
6 | pyrender
7 | seaborn
8 | trimesh
9 | h5py
10 | PyMCubes
11 | imageio
12 | scikit-image
13 | fvcore
14 | opencv-python==4.8.0.76
15 | pyyaml==6.0
-------------------------------------------------------------------------------- /scripts_sh/mmd_cov_1nn.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=4 python scripts/compute_mmd_cov_1nn.py \ 2 | --path_to_gt_mesh /s2/yangzhifei/project/MMGDreamer/FRONT/gt_fov90_h8_obj_meshes_refine \ 3 | --path_to_synthesized_mesh /data/yangzhifei/project/MMGDreamer/experiments/xxxx/vis/2049/mmgscene/object_meshes \ 4 | --save_name mmd_cov_1nn 5 | -------------------------------------------------------------------------------- /scripts_sh/consistency_check.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=2 python scripts/consistency_check.py --path_to_test /s2/yangzhifei/project/MMGDreamer/experiments/train_all_image_mask/xxxx/2049 \ 2 | --room all \ 3 | --catfile /data/yangzhifei/project/MMGDreamer/FRONT/classes_all.txt \ 4 | --gt_consistency_file /data/yangzhifei/project/MMGDreamer/FRONT/consistencies_all_test.json \ 5 | 6 | -------------------------------------------------------------------------------- /extension/old_chamfer/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='chamfer', 6 | ext_modules=[ 7 | CUDAExtension('chamfer', [ 8 | 'chamfer_cuda.cpp', 9 | 'chamfer.cu', 10 | ]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) -------------------------------------------------------------------------------- /scripts/pytorch_structural_losses/src/nndistance.cuh: -------------------------------------------------------------------------------- 1 | void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream); 2 | void nndistancegrad(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream); 3 | -------------------------------------------------------------------------------- /scripts_sh/eval_all_mask.sh: -------------------------------------------------------------------------------- 1 | # mask type three 2 | CUDA_VISIBLE_DEVICES=4 python scripts/eval_3dfront_three.py --exp ./experiments/train_all_image_mask \ 3 | --dataset /s2/yangzhifei/project/MMGDreamer/FRONT \ 4 | --epoch 2049 \ 5 | --visualize True \ 6 | --room_type all \ 7 | --render_type mmgscene \ 8 | --gen_shape True \ 9 | --with_image True \ 10 | --mask_type three \ 11 | --name_render I_R 12 | 13 | -------------------------------------------------------------------------------- /scripts/pytorch_structural_losses/pybind/bind.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include "extern.hpp" 6 | 7 | namespace py = pybind11; 8 | 9 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ 10 | m.def("ApproxMatch", &ApproxMatch); 11 | m.def("MatchCost", &MatchCost); 12 | m.def("MatchCostGrad", &MatchCostGrad); 13 | m.def("NNDistance", &NNDistance); 14 | m.def("NNDistanceGrad", &NNDistanceGrad); 15 | } 16 | -------------------------------------------------------------------------------- /extension/old_chamfer/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import dist_chamfer 3 | dist = 
dist_chamfer.chamferDist() 4 | 5 | with torch.enable_grad(): 6 | p1 = torch.rand(10,1000,6) 7 | p2 = torch.rand(10,1500,6) 8 | p1.requires_grad = True 9 | p2.requires_grad = True 10 | points1 = p1.cuda() 11 | points2 = p2.cuda() 12 | cost, _ = dist(points1, points2) 13 | print(cost) 14 | loss = torch.sum(cost) 15 | print(loss) 16 | loss.backward() 17 | print(points1.grad, points2.grad) 18 | -------------------------------------------------------------------------------- /scripts/pytorch_structural_losses/pybind/extern.hpp: -------------------------------------------------------------------------------- 1 | std::vector ApproxMatch(at::Tensor in_a, at::Tensor in_b); 2 | at::Tensor MatchCost(at::Tensor set_d, at::Tensor set_q, at::Tensor match); 3 | std::vector MatchCostGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor match); 4 | 5 | std::vector NNDistance(at::Tensor set_d, at::Tensor set_q); 6 | std::vector NNDistanceGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor idx1, at::Tensor idx2, at::Tensor grad_dist1, at::Tensor grad_dist2); 7 | -------------------------------------------------------------------------------- /scripts/pytorch_structural_losses/src/approxmatch.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | template 3 | void AddGPUKernel(Dtype *in_a, Dtype *in_b, Dtype *out_c, int N, 4 | cudaStream_t stream); 5 | */ 6 | void approxmatch(int b,int n,int m,const float * xyz1,const float * xyz2,float * match,float * temp, cudaStream_t stream); 7 | void matchcost(int b,int n,int m,const float * xyz1,const float * xyz2,float * match, float * out, cudaStream_t stream); 8 | void matchcostgrad(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * grad1,float * grad2, cudaStream_t stream); 9 | -------------------------------------------------------------------------------- /config/vqvae_snet.yaml: -------------------------------------------------------------------------------- 1 | # ref: https://github.com/yccyenchicheng/AutoSDF, https://github.com/CompVis/latent-diffusion/ 2 | # code/configs/pvqvae_nembed-8192-z-3x16x16x16-snet.yaml 3 | 4 | model: 5 | params: 6 | embed_dim: 3 7 | n_embed: 8192 8 | ddconfig: 9 | double_z: False 10 | z_channels: 3 11 | resolution: 64 12 | in_channels: 1 13 | out_ch: 1 14 | ch: 64 15 | # ch_mult: [1,1,2,2,4] # num_down = len(ch_mult)-1 16 | ch_mult: [1,2,4] # num_down = len(ch_mult)-1 17 | num_res_blocks: 1 18 | attn_resolutions: [] 19 | dropout: 0.0 20 | 21 | lossconfig: 22 | params: 23 | codebook_weight: 1.0 -------------------------------------------------------------------------------- /scripts_sh/train_all_mask.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=6 python scripts/train_3dfront_mask.py --exp ./experiments/train_all_image_mask \ 2 | --room_type all \ 3 | --dataset /s2/yangzhifei/project/MMGDreamer/FRONT \ 4 | --residual True \ 5 | --network_type mmgdreamer \ 6 | --with_SDF True \ 7 | --with_CLIP True \ 8 | --with_image True \ 9 | --batchSize 128 \ 10 | --workers 8 \ 11 | --loadmodel False \ 12 | --nepoch 2050 \ 13 | --large False \ 14 | --diff_yaml /s2/yangzhifei/project/MMGDreamer/config/full_mp.yaml \ 15 | --use_scene_rels True \ 16 | --use_image_scene_rels True \ 17 | --shuffle_objs True \ 18 | --mask_random True 19 | -------------------------------------------------------------------------------- /scripts/pytorch_structural_losses/src/utils.hpp: 
-------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | class Formatter { 6 | public: 7 | Formatter() {} 8 | ~Formatter() {} 9 | 10 | template Formatter &operator<<(const Type &value) { 11 | stream_ << value; 12 | return *this; 13 | } 14 | 15 | std::string str() const { return stream_.str(); } 16 | operator std::string() const { return stream_.str(); } 17 | 18 | enum ConvertToString { to_str }; 19 | 20 | std::string operator>>(ConvertToString) { return stream_.str(); } 21 | 22 | private: 23 | std::stringstream stream_; 24 | Formatter(const Formatter &); 25 | Formatter &operator=(Formatter &); 26 | }; 27 | -------------------------------------------------------------------------------- /config/threedfront_objfeat_vqvae.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: "objfeatvqvae" 3 | objfeat_type: "" 4 | vq_type: "gumbel" 5 | 6 | training: 7 | splits: ["train", "val"] 8 | epochs: 2000 9 | steps_per_epoch: 500 10 | batch_size: 128 11 | save_frequency: 100 12 | log_frequency: 1 13 | optimizer: 14 | name: "adamw" 15 | lr: 0.0001 16 | weight_decay: 0.02 17 | loss_weights: 18 | qloss: 1. 19 | rec_mse: 1. 20 | ema: 21 | use_ema: true 22 | max_decay: 0.9999 23 | min_decay: 0. 24 | update_after_step: 0 25 | use_warmup: true 26 | inv_gamma: 1. 27 | power: 0.75 28 | 29 | validation: 30 | splits: ["test"] 31 | batch_size: 64 32 | frequency: 1 33 | -------------------------------------------------------------------------------- /helpers/interrupt_handler.py: -------------------------------------------------------------------------------- 1 | import signal 2 | 3 | 4 | class InterruptHandler(object): 5 | def __init__(self, sig=signal.SIGINT): 6 | self.sig = sig 7 | 8 | def __enter__(self): 9 | 10 | self.interrupted = False 11 | self.released = False 12 | 13 | self.original_handler = signal.getsignal(self.sig) 14 | 15 | def handler(signum, frame): 16 | self.release() 17 | self.interrupted = True 18 | 19 | signal.signal(self.sig, handler) 20 | 21 | return self 22 | 23 | def __exit__(self, type, value, tb): 24 | self.release() 25 | 26 | def release(self): 27 | 28 | if self.released: 29 | return False 30 | 31 | signal.signal(self.sig, self.original_handler) 32 | 33 | self.released = True 34 | 35 | return True 36 | -------------------------------------------------------------------------------- /config/sdfusion-txt2shape_mp.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | linear_start: 0.00085 4 | linear_end: 0.012 5 | conditioning_key: crossattn 6 | timesteps: 1000 7 | scale_factor: 0.18215 8 | 9 | bert: 10 | params: 11 | n_embed: 1280 12 | n_layer: 32 13 | 14 | unet: 15 | params: 16 | image_size: 16 17 | in_channels: 3 18 | out_channels: 3 19 | model_channels: 224 20 | num_res_blocks: 2 21 | attention_resolutions: [ 4, 2 ] # 16, 8, 4 22 | # note: this isn\t actually the resolution but 23 | # the downsampling factor, i.e. 
this corresponds to
24 |     # attention on spatial resolution 8,16,32, as the
25 |     # spatial resolution of the latents is 64 for f4
26 |     channel_mult: [ 1, 2, 3 ]
27 |     # num_head_channels: 32
28 |     num_heads: 8
29 | 
30 |     # 3d
31 |     dims: 3
32 | 
33 |     # cond_model params
34 |     use_spatial_transformer: true
35 |     transformer_depth: 1
36 |     context_dim: 1280
37 |     use_checkpoint: true
38 |     legacy: False
39 | 
40 |     messsage_passing: True
41 |     enable_t_emb: true
42 | 
--------------------------------------------------------------------------------
/scripts/pytorch_structural_losses/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension
3 | 
4 | # Python interface
5 | setup(
6 |     name='PyTorchStructuralLosses',
7 |     version='0.1.0',
8 |     install_requires=['torch'],
9 |     packages=['StructuralLosses'],
10 |     package_dir={'StructuralLosses': './'},
11 |     ext_modules=[
12 |         CUDAExtension(
13 |             name='StructuralLossesBackend',
14 |             include_dirs=['./'],
15 |             sources=[
16 |                 'pybind/bind.cpp',
17 |             ],
18 |             libraries=['make_pytorch'],
19 |             library_dirs=['objs'],
20 |             # extra_compile_args=['-g']
21 |         )
22 |     ],
23 |     cmdclass={'build_ext': BuildExtension},
24 |     author='Christopher B. Choy',
25 |     author_email='chrischoy@ai.stanford.edu',
26 |     description='Tutorial for Pytorch C++ Extension with a Makefile',
27 |     keywords='Pytorch C++ Extension',
28 |     url='https://github.com/chrischoy/MakePytorchPlusPlus',
29 |     zip_safe=False,
30 | )
31 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 yangzhifei
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /extension/old_chamfer/chamfer_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | ///TMP 5 | //#include "common.h" 6 | /// NOT TMP 7 | 8 | 9 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2); 10 | 11 | 12 | int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2); 13 | 14 | 15 | 16 | 17 | int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) { 18 | return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); 19 | } 20 | 21 | 22 | int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, 23 | at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) { 24 | 25 | return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2); 26 | } 27 | 28 | 29 | 30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 31 | m.def("forward", &chamfer_forward, "chamfer forward (CUDA)"); 32 | m.def("backward", &chamfer_backward, "chamfer backward (CUDA)"); 33 | } -------------------------------------------------------------------------------- /model/layers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # Modification copyright 2021 Helisa Dhamo, Fabian Manhardt 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | 18 | import torch.nn as nn 19 | 20 | 21 | def build_mlp(dim_list, activation='relu', batch_norm='none', 22 | dropout=0, final_nonlinearity=True): 23 | layers = [] 24 | for i in range(len(dim_list) - 1): 25 | dim_in, dim_out = dim_list[i], dim_list[i + 1] 26 | layers.append(nn.Linear(dim_in, dim_out)) 27 | final_layer = (i == len(dim_list) - 2) 28 | if not final_layer or final_nonlinearity: 29 | if batch_norm == 'batch': 30 | layers.append(nn.BatchNorm1d(dim_out)) 31 | if activation == 'relu': 32 | layers.append(nn.ReLU()) 33 | elif activation == 'leakyrelu': 34 | layers.append(nn.LeakyReLU()) 35 | if dropout > 0: 36 | layers.append(nn.Dropout(p=dropout)) 37 | 38 | return nn.Sequential(*layers) 39 | -------------------------------------------------------------------------------- /scripts_sh/compute_fid_scores.sh: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------------------------------------------------------------------------------------------- 2 | 3 | CUDA_VISIBLE_DEVICES=3 python scripts/compute_fid_scores_3dfront.py --room bedroom \ 4 | --path_to_real_renderings /s2/yangzhifei/project/MMGDreamer/FRONT/sdf_fov90_h8_wo_lamp_no_stool/small/test \ 5 | --path_to_synthesized_renderings xxxx_render_imgs_path \ 6 | --path_to_test /s2/yangzhifei/project/MMGDreamer/experiments/fid_kid_tmp/ 7 | 8 | CUDA_VISIBLE_DEVICES=3 python scripts/compute_fid_scores_3dfront.py --room livingroom \ 9 | --path_to_real_renderings /s2/yangzhifei/project/MMGDreamer/FRONT/sdf_fov90_h8_wo_lamp_no_stool/small/test \ 10 | --path_to_synthesized_renderings xxxx_render_imgs_path \ 11 | --path_to_test /s2/yangzhifei/project/MMGDreamer/experiments/fid_kid_tmp/ 12 | 13 | CUDA_VISIBLE_DEVICES=3 python scripts/compute_fid_scores_3dfront.py --room diningroom \ 14 | --path_to_real_renderings /s2/yangzhifei/project/MMGDreamer/FRONT/sdf_fov90_h8_wo_lamp_no_stool/small/test \ 15 | --path_to_synthesized_renderings xxxx_render_imgs_path \ 16 | --path_to_test /s2/yangzhifei/project/MMGDreamer/experiments/fid_kid_tmp/ 17 | 18 | CUDA_VISIBLE_DEVICES=3 python scripts/compute_fid_scores_3dfront.py --room all \ 19 | --path_to_real_renderings /s2/yangzhifei/project/MMGDreamer/FRONT/sdf_fov90_h8_wo_lamp_no_stool/small/test \ 20 | --path_to_synthesized_renderings xxxx_render_imgs_path \ 21 | --path_to_test /s2/yangzhifei/project/MMGDreamer/experiments/fid_kid_tmp/ 22 | 23 | -------------------------------------------------------------------------------- /model/networks/clip_networks/network.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference: 3 | - https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/encoders/modules.py 4 | - https://github.com/openai/CLIP 5 | """ 6 | 7 | import kornia 8 | from einops import rearrange, repeat 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | from external.clip import clip 14 | 15 | class CLIPImageEncoder(nn.Module): 16 | def __init__( 17 | self, 18 | model="ViT-B/32", 19 | jit=False, 20 | device='cuda' if torch.cuda.is_available() else 'cpu', 21 | antialias=False, 22 | ): 23 | super().__init__() 24 | self.model, _ = clip.load(name=model, device=device, jit=jit) 25 | 26 | # self.model, self.preprocess = clip.load(name=model, device=device, jit=jit) 27 | self.model = self.model.float() # turns out this is important... 
28 | 29 | self.antialias = antialias 30 | 31 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 32 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 33 | 34 | def preprocess(self, x): 35 | # normalize to [0,1] 36 | x = kornia.geometry.resize(x, (224, 224), 37 | interpolation='bicubic',align_corners=True, 38 | antialias=self.antialias) 39 | x = (x + 1.) / 2. 40 | # renormalize according to clip 41 | x = kornia.enhance.normalize(x, self.mean, self.std) 42 | return x 43 | 44 | def forward(self, x): 45 | # x is assumed to be in range [-1,1] 46 | return self.model.encode_image(self.preprocess(x)) 47 | -------------------------------------------------------------------------------- /extension/old_chamfer/dist_chamfer.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch import nn 3 | from torch.autograd import Function 4 | import torch 5 | import sys 6 | from numbers import Number 7 | from collections import Set, Mapping, deque 8 | import chamfer 9 | 10 | # Chamfer's distance module @thibaultgroueix 11 | # GPU tensors only 12 | class chamferFunction(Function): 13 | @staticmethod 14 | def forward(ctx, xyz1, xyz2): 15 | batchsize, n, _ = xyz1.size() 16 | _, m, _ = xyz2.size() 17 | 18 | dist1 = torch.zeros(batchsize, n) 19 | dist2 = torch.zeros(batchsize, m) 20 | 21 | idx1 = torch.zeros(batchsize, n).type(torch.IntTensor) 22 | idx2 = torch.zeros(batchsize, m).type(torch.IntTensor) 23 | 24 | dist1 = dist1.cuda() 25 | dist2 = dist2.cuda() 26 | idx1 = idx1.cuda() 27 | idx2 = idx2.cuda() 28 | 29 | chamfer.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) 30 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2) 31 | return dist1, dist2 32 | 33 | @staticmethod 34 | def backward(ctx, graddist1, graddist2): 35 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors 36 | graddist1 = graddist1.contiguous() 37 | graddist2 = graddist2.contiguous() 38 | 39 | gradxyz1 = torch.zeros(xyz1.size()) 40 | gradxyz2 = torch.zeros(xyz2.size()) 41 | 42 | gradxyz1 = gradxyz1.cuda() 43 | gradxyz2 = gradxyz2.cuda() 44 | chamfer.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2) 45 | return gradxyz1, gradxyz2 46 | 47 | class chamferDist(nn.Module): 48 | def __init__(self): 49 | super(chamferDist, self).__init__() 50 | 51 | def forward(self, input1, input2): 52 | return chamferFunction.apply(input1, input2) 53 | -------------------------------------------------------------------------------- /scripts/pytorch_structural_losses/nn_distance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | # from extensions.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad 4 | from scripts.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad 5 | 6 | # Inherit from Function 7 | class NNDistanceFunction(Function): 8 | # Note that both forward and backward are @staticmethods 9 | @staticmethod 10 | # bias is an optional argument 11 | def forward(ctx, seta, setb): 12 | #print("Match Cost Forward") 13 | ctx.save_for_backward(seta, setb) 14 | ''' 15 | input: 16 | set1 : batch_size * #dataset_points * 3 17 | set2 : batch_size * #query_points * 3 18 | returns: 19 | dist1, idx1, dist2, idx2 20 | ''' 21 | dist1, idx1, dist2, idx2 = NNDistance(seta, setb) 22 | ctx.idx1 = idx1 23 | ctx.idx2 = idx2 24 | return dist1, dist2 25 | 26 | # This function has only a 
single output, so it gets only one gradient 27 | @staticmethod 28 | def backward(ctx, grad_dist1, grad_dist2): 29 | #print("Match Cost Backward") 30 | # This is a pattern that is very convenient - at the top of backward 31 | # unpack saved_tensors and initialize all gradients w.r.t. inputs to 32 | # None. Thanks to the fact that additional trailing Nones are 33 | # ignored, the return statement is simple even when the function has 34 | # optional inputs. 35 | seta, setb = ctx.saved_tensors 36 | idx1 = ctx.idx1 37 | idx2 = ctx.idx2 38 | grada, gradb = NNDistanceGrad(seta, setb, idx1, idx2, grad_dist1, grad_dist2) 39 | return grada, gradb 40 | 41 | nn_distance = NNDistanceFunction.apply 42 | 43 | -------------------------------------------------------------------------------- /scripts/StructuralLosses/nn_distance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | # from extensions.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad 4 | # from scripts.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad 5 | from StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad 6 | # Inherit from Function 7 | class NNDistanceFunction(Function): 8 | # Note that both forward and backward are @staticmethods 9 | @staticmethod 10 | # bias is an optional argument 11 | def forward(ctx, seta, setb): 12 | #print("Match Cost Forward") 13 | ctx.save_for_backward(seta, setb) 14 | ''' 15 | input: 16 | set1 : batch_size * #dataset_points * 3 17 | set2 : batch_size * #query_points * 3 18 | returns: 19 | dist1, idx1, dist2, idx2 20 | ''' 21 | dist1, idx1, dist2, idx2 = NNDistance(seta, setb) 22 | ctx.idx1 = idx1 23 | ctx.idx2 = idx2 24 | return dist1, dist2 25 | 26 | # This function has only a single output, so it gets only one gradient 27 | @staticmethod 28 | def backward(ctx, grad_dist1, grad_dist2): 29 | #print("Match Cost Backward") 30 | # This is a pattern that is very convenient - at the top of backward 31 | # unpack saved_tensors and initialize all gradients w.r.t. inputs to 32 | # None. Thanks to the fact that additional trailing Nones are 33 | # ignored, the return statement is simple even when the function has 34 | # optional inputs. 35 | seta, setb = ctx.saved_tensors 36 | idx1 = ctx.idx1 37 | idx2 = ctx.idx2 38 | grada, gradb = NNDistanceGrad(seta, setb, idx1, idx2, grad_dist1, grad_dist2) 39 | return grada, gradb 40 | 41 | nn_distance = NNDistanceFunction.apply 42 | 43 | -------------------------------------------------------------------------------- /model/networks/diffusion_shape/network.py: -------------------------------------------------------------------------------- 1 | """ Reference: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/ddpm.py#L1395-L1421 """ 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | 7 | from einops import rearrange, repeat 8 | 9 | from .openai_model_3d import UNet3DModel 10 | 11 | class DiffusionUNet(nn.Module): 12 | def __init__(self, unet_params, vq_conf=None, conditioning_key=None): 13 | """ init method """ 14 | super().__init__() 15 | self.conditioning_key = conditioning_key # default for lsun_bedrooms 16 | unet_params.conditioning_key = self.conditioning_key 17 | self.diffusion_net = UNet3DModel(**unet_params) 18 | 19 | 20 | 21 | def forward(self, x, obj_embed, triples, t, c_concat: list = None, c_crossattn: list = None): 22 | # x: should be latent code. 
shape: (bs X z_dim X d X h X w) 23 | 24 | if self.conditioning_key is None: 25 | out = self.diffusion_net(x, obj_embed, triples, t) 26 | elif self.conditioning_key == 'concat': 27 | xc = torch.cat([x] + c_concat, dim=1) 28 | out = self.diffusion_net(xc, obj_embed, triples, t) 29 | elif self.conditioning_key == 'crossattn': 30 | cc = torch.cat(c_crossattn, 1) 31 | out = self.diffusion_net(x, obj_embed, triples, t, context=cc) 32 | elif self.conditioning_key == 'hybrid': 33 | xc = torch.cat([x] + c_concat, dim=1) 34 | cc = torch.cat(c_crossattn, 1) 35 | out = self.diffusion_net(xc, obj_embed, triples, t, context=cc) 36 | # import pdb; pdb.set_trace() 37 | elif self.conditioning_key == 'adm': 38 | cc = c_crossattn[0] 39 | out = self.diffusion_net(x, obj_embed, triples, t, y=cc) 40 | else: 41 | raise NotImplementedError() 42 | 43 | return out 44 | -------------------------------------------------------------------------------- /scripts/pytorch_structural_losses/match_cost.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from scripts.StructuralLosses.StructuralLossesBackend import ApproxMatch, MatchCost, MatchCostGrad 4 | 5 | # Inherit from Function 6 | class MatchCostFunction(Function): 7 | # Note that both forward and backward are @staticmethods 8 | @staticmethod 9 | # bias is an optional argument 10 | def forward(ctx, seta, setb): 11 | #print("Match Cost Forward") 12 | ctx.save_for_backward(seta, setb) 13 | ''' 14 | input: 15 | set1 : batch_size * #dataset_points * 3 16 | set2 : batch_size * #query_points * 3 17 | returns: 18 | match : batch_size * #query_points * #dataset_points 19 | ''' 20 | match, temp = ApproxMatch(seta, setb) 21 | ctx.match = match 22 | cost = MatchCost(seta, setb, match) 23 | return cost 24 | 25 | """ 26 | grad_1,grad_2=approxmatch_module.match_cost_grad(xyz1,xyz2,match) 27 | return [grad_1*tf.expand_dims(tf.expand_dims(grad_cost,1),2),grad_2*tf.expand_dims(tf.expand_dims(grad_cost,1),2),None] 28 | """ 29 | # This function has only a single output, so it gets only one gradient 30 | @staticmethod 31 | def backward(ctx, grad_output): 32 | #print("Match Cost Backward") 33 | # This is a pattern that is very convenient - at the top of backward 34 | # unpack saved_tensors and initialize all gradients w.r.t. inputs to 35 | # None. Thanks to the fact that additional trailing Nones are 36 | # ignored, the return statement is simple even when the function has 37 | # optional inputs. 
38 | seta, setb = ctx.saved_tensors 39 | #grad_input = grad_weight = grad_bias = None 40 | grada, gradb = MatchCostGrad(seta, setb, ctx.match) 41 | grad_output_expand = grad_output.unsqueeze(1).unsqueeze(2) 42 | return grada*grad_output_expand, gradb*grad_output_expand 43 | 44 | match_cost = MatchCostFunction.apply 45 | 46 | -------------------------------------------------------------------------------- /scripts/StructuralLosses/match_cost.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | # from scripts.StructuralLosses.StructuralLossesBackend import ApproxMatch, MatchCost, MatchCostGrad 4 | from StructuralLosses.StructuralLossesBackend import ApproxMatch, MatchCost, MatchCostGrad 5 | # Inherit from Function 6 | class MatchCostFunction(Function): 7 | # Note that both forward and backward are @staticmethods 8 | @staticmethod 9 | # bias is an optional argument 10 | def forward(ctx, seta, setb): 11 | #print("Match Cost Forward") 12 | ctx.save_for_backward(seta, setb) 13 | ''' 14 | input: 15 | set1 : batch_size * #dataset_points * 3 16 | set2 : batch_size * #query_points * 3 17 | returns: 18 | match : batch_size * #query_points * #dataset_points 19 | ''' 20 | match, temp = ApproxMatch(seta, setb) 21 | ctx.match = match 22 | cost = MatchCost(seta, setb, match) 23 | return cost 24 | 25 | """ 26 | grad_1,grad_2=approxmatch_module.match_cost_grad(xyz1,xyz2,match) 27 | return [grad_1*tf.expand_dims(tf.expand_dims(grad_cost,1),2),grad_2*tf.expand_dims(tf.expand_dims(grad_cost,1),2),None] 28 | """ 29 | # This function has only a single output, so it gets only one gradient 30 | @staticmethod 31 | def backward(ctx, grad_output): 32 | #print("Match Cost Backward") 33 | # This is a pattern that is very convenient - at the top of backward 34 | # unpack saved_tensors and initialize all gradients w.r.t. inputs to 35 | # None. Thanks to the fact that additional trailing Nones are 36 | # ignored, the return statement is simple even when the function has 37 | # optional inputs. 
38 | seta, setb = ctx.saved_tensors 39 | #grad_input = grad_weight = grad_bias = None 40 | grada, gradb = MatchCostGrad(seta, setb, ctx.match) 41 | grad_output_expand = grad_output.unsqueeze(1).unsqueeze(2) 42 | return grada*grad_output_expand, gradb*grad_output_expand 43 | 44 | match_cost = MatchCostFunction.apply 45 | 46 | -------------------------------------------------------------------------------- /config/full_mp.yaml: -------------------------------------------------------------------------------- 1 | hyper: 2 | batch_size: 128 # shape batchsize 3 | gpu_ids: 0 4 | logs_dir: /s2/yangzhifei/project/MMGDreamer/experiments/train_all/full_mp/diff_crossattn 5 | results_dir: /s2/yangzhifei/project/MMGDreamer/experiments/train_all/full_mp/diff_crossattn 6 | name: ./ 7 | isTrain: True 8 | device: 'cuda' 9 | distributed: 0 10 | lr_init: 1e-4 11 | lr_step: [35000, 70000, 140000] 12 | lr_evo: [5e-5, 1e-5, 5e-6] 13 | # dataset unused 14 | dataset: 15 | res: 128 16 | trunc_thres: 0.2 17 | ratio: 1 18 | 19 | layout_branch: 20 | model: diffusion_scene_layout_ddpm 21 | angle_dim: 2 22 | denoiser: unet1d 23 | relation_condition: true 24 | denoiser_kwargs: 25 | dims: 1 # 1D 26 | in_channels: 8 # size(3)+loc(3)+sincos(2) 27 | out_channels: 8 # same 28 | model_channels: 512 29 | channel_mult: [ 1,1,1,1] 30 | num_res_blocks: 2 31 | attention_resolutions: [ 4, 2 ] 32 | num_heads: 8 33 | # cond_model params 34 | use_spatial_transformer: true 35 | transformer_depth: 1 36 | conditioning_key: 'crossattn' 37 | concat_dim: 1280 38 | crossattn_dim: 1280 39 | use_checkpoint: true 40 | enable_t_emb: true 41 | 42 | diffusion_kwargs: 43 | schedule_type: 'linear' 44 | beta_start: 0.0001 45 | beta_end: 0.02 46 | time_num: 1000 47 | model_mean_type: 'eps' 48 | model_var_type: 'fixedsmall' 49 | loss_separate: true 50 | loss_iou: false 51 | iou_type: obb 52 | train_stats_file: null 53 | 54 | shape_branch: 55 | model: sdfusion-txt2shape_mp 56 | sampling: greedy 57 | ckpt: null 58 | df_cfg: /s2/yangzhifei/project/MMGDreamer/config/sdfusion-txt2shape_mp.yaml 59 | ddim_steps: 100 60 | ddim_eta: 0.0 61 | uc_scale: 3.0 62 | vq_model: vqvae 63 | vq_cfg: /s2/yangzhifei/project/MMGDreamer/config/vqvae_snet.yaml 64 | vq_dset: None 65 | vq_cat: None 66 | vq_ckpt: /s2/yangzhifei/project/MMGDreamer/checkpoint/vqvae_threedfront_best.pth 67 | misc: 68 | debug: 0 69 | seed: 111 70 | backend: gloo 71 | local_rank: 0 72 | 73 | training: 74 | lr: 1e-5 75 | lr_policy: lambda 76 | lr_decay_iters: 50 77 | lambda_L1: 10.0 78 | -------------------------------------------------------------------------------- /assets/icon.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # output and data files 31 | experiments/ 32 | GT/ 33 | scripts/experiments/ 34 | scripts/checkpoint/ 35 | model/pretrained_model 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # 
before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | .idea/* 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | -------------------------------------------------------------------------------- /helpers/psutil.py: -------------------------------------------------------------------------------- 1 | class FreeMemLinux(object): 2 | """ 3 | Non-cross platform way to get free memory on Linux. Note that this code 4 | uses the `with ... as`, which is conditionally Python 2.5 compatible! 
5 | If for some reason you still have Python 2.5 on your system add in the 6 | head of your code, before all imports: 7 | from __future__ import with_statement 8 | """ 9 | 10 | def __init__(self, unit='kB'): 11 | 12 | with open('/proc/meminfo', 'r') as mem: 13 | lines = mem.readlines() 14 | 15 | self._tot = int(lines[0].split()[1]) 16 | self._free = int(lines[1].split()[1]) 17 | self._buff = int(lines[2].split()[1]) 18 | self._cached = int(lines[3].split()[1]) 19 | self._shared = int(lines[20].split()[1]) 20 | self._swapt = int(lines[14].split()[1]) 21 | self._swapf = int(lines[15].split()[1]) 22 | self._swapu = self._swapt - self._swapf 23 | 24 | self.unit = unit 25 | self._convert = self._factor() 26 | 27 | def _factor(self): 28 | """determine the convertion factor""" 29 | if self.unit == 'kB': 30 | return 1 31 | if self.unit == 'k': 32 | return 1024.0 33 | if self.unit == 'MB': 34 | return 1/1024.0 35 | if self.unit == 'GB': 36 | return 1/1024.0/1024.0 37 | if self.unit == '%': 38 | return 1.0/self._tot 39 | else: 40 | raise Exception("Unit not understood") 41 | 42 | @property 43 | def total(self): 44 | return self._convert * self._tot 45 | 46 | @property 47 | def used(self): 48 | return self._convert * (self._tot - self._free) 49 | 50 | @property 51 | def used_real(self): 52 | """memory used which is not cache or buffers""" 53 | return self._convert * (self._tot - self._free - 54 | self._buff - self._cached) 55 | 56 | @property 57 | def shared(self): 58 | return self._convert * (self._tot - self._free) 59 | 60 | @property 61 | def buffers(self): 62 | return self._convert * (self._buff) 63 | 64 | @property 65 | def cached(self): 66 | return self._convert * self._cached 67 | 68 | @property 69 | def user_free(self): 70 | """This is the free memory available for the user""" 71 | return self._convert *(self._free + self._buff + self._cached) 72 | 73 | @property 74 | def swap(self): 75 | return self._convert * self._swapt 76 | 77 | @property 78 | def swap_free(self): 79 | return self._convert * self._swapf 80 | 81 | @property 82 | def swap_used(self): 83 | return self._convert * self._swapu -------------------------------------------------------------------------------- /model/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def bce_loss(input, target, reduce=True): 6 | """ 7 | Numerically stable version of the binary cross-entropy loss function. 8 | As per https://github.com/pytorch/pytorch/issues/751 9 | See the TensorFlow docs for a derivation of this formula: 10 | https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits 11 | Inputs: 12 | - input: PyTorch Tensor of shape (N, ) giving scores. 13 | - target: PyTorch Tensor of shape (N,) containing 0 and 1 giving targets. 14 | Returns: 15 | - A PyTorch Tensor containing the mean BCE loss over the minibatch of 16 | input data. 
17 | """ 18 | neg_abs = -input.abs() 19 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() 20 | if reduce: 21 | return loss.mean() 22 | else: 23 | return loss 24 | 25 | 26 | def calculate_model_losses(args, pred, target, name, angles=None, angles_pred=None, mu=None, logvar=None, 27 | KL_weight=None, writer=None, counter=None, withangles=False): 28 | total_loss = 0.0 29 | losses = {} 30 | rec_loss = F.l1_loss(pred, target) 31 | total_loss = add_loss(total_loss, rec_loss, losses, name, 1) 32 | if withangles: 33 | angle_loss = F.nll_loss(angles_pred, angles) 34 | total_loss = add_loss(total_loss, angle_loss, losses, 'angle_pred', 1) 35 | 36 | try: 37 | loss_gauss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / mu.size(0) 38 | 39 | except: 40 | print("blowup!!!") 41 | print("logvar", torch.sum(logvar.data), torch.sum(torch.abs(logvar.data)), torch.max(logvar.data), 42 | torch.min(logvar.data)) 43 | print("mu", torch.sum(mu.data), torch.sum(torch.abs(mu.data)), torch.max(mu.data), torch.min(mu.data)) 44 | return total_loss, losses 45 | total_loss = add_loss(total_loss, loss_gauss, losses, 'KLD_Gauss', KL_weight) 46 | 47 | writer.add_scalar('Train_Loss_KL_{}'.format(name), loss_gauss, counter) 48 | writer.add_scalar('Train_Loss_Rec_{}'.format(name), rec_loss, counter) 49 | if withangles: 50 | writer.add_scalar('Train_Loss_Angle_{}'.format(name), angle_loss, counter) 51 | return total_loss, losses 52 | 53 | 54 | def add_loss(total_loss, curr_loss, loss_dict, loss_name, weight=1): 55 | curr_loss_weighted = curr_loss * weight 56 | loss_dict[loss_name] = curr_loss_weighted.item() 57 | if total_loss is not None: 58 | return total_loss + curr_loss_weighted 59 | else: 60 | return curr_loss_weighted 61 | retur 62 | 63 | class VQLoss(nn.Module): 64 | def __init__(self, codebook_weight=1.0): 65 | super().__init__() 66 | self.codebook_weight = codebook_weight 67 | 68 | def forward(self, codebook_loss, inputs, reconstructions, split="train"): 69 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 70 | 71 | nll_loss = rec_loss 72 | nll_loss = torch.mean(nll_loss) 73 | 74 | loss = nll_loss + self.codebook_weight * codebook_loss.mean() 75 | 76 | log = { 77 | "loss_total": loss.clone().detach().mean(), 78 | "loss_codebook": codebook_loss.detach().mean(), 79 | "loss_nll": nll_loss.detach().mean(), 80 | "loss_rec": rec_loss.detach().mean(), 81 | } 82 | 83 | return loss, log -------------------------------------------------------------------------------- /model/model_utils.py: -------------------------------------------------------------------------------- 1 | from termcolor import colored 2 | import torch 3 | 4 | from model.networks.vqvae_networks.network import VQVAE 5 | from typing import * 6 | 7 | import sys 8 | import time 9 | 10 | def load_vqvae(vq_conf, vq_ckpt, opt=None): 11 | assert type(vq_ckpt) == str 12 | 13 | # init vqvae for decoding shapes 14 | mparam = vq_conf.model.params 15 | n_embed = mparam.n_embed 16 | embed_dim = mparam.embed_dim 17 | ddconfig = mparam.ddconfig 18 | 19 | n_down = len(ddconfig.ch_mult) - 1 20 | 21 | vqvae = VQVAE(ddconfig, n_embed, embed_dim) 22 | 23 | map_fn = lambda storage, loc: storage 24 | state_dict = torch.load(vq_ckpt, map_location=map_fn) 25 | if 'vqvae' in state_dict: 26 | vqvae.load_state_dict(state_dict['vqvae']) 27 | else: 28 | vqvae.load_state_dict(state_dict) 29 | 30 | print(colored('[*] VQVAE: weight successfully load from: %s' % vq_ckpt, 'blue')) 31 | vqvae.requires_grad = False 32 | 33 | 
vqvae.to(opt.hyper.device) 34 | vqvae.eval() 35 | return vqvae 36 | 37 | class AverageAggregator(object): 38 | def __init__(self): 39 | self._value = 0. 40 | self._count = 0 41 | 42 | @property 43 | def value(self): 44 | return self._value / self._count 45 | 46 | @value.setter 47 | def value(self, val: float): 48 | self._value += val 49 | self._count += 1 50 | 51 | def update(self, val: float, n=1): 52 | self._value += val 53 | self._count += n 54 | 55 | 56 | class StatsLogger(object): 57 | __INSTANCE = None 58 | 59 | def __init__(self): 60 | if StatsLogger.__INSTANCE is not None: 61 | raise RuntimeError("StatsLogger should not be directly created") 62 | 63 | self._values = dict() 64 | self._loss = AverageAggregator() 65 | self._output_files = [sys.stdout] 66 | 67 | def add_output_file(self, f): 68 | self._output_files.append(f) 69 | 70 | @property 71 | def loss(self): 72 | return self._loss.value 73 | 74 | @loss.setter 75 | def loss(self, val: float): 76 | self._loss.value = val 77 | 78 | def update_loss(self, val: float, n=1): 79 | self._loss.update(val, n) 80 | 81 | def __getitem__(self, key: str): 82 | if key not in self._values: 83 | self._values[key] = AverageAggregator() 84 | return self._values[key] 85 | 86 | def clear(self): 87 | self._values.clear() 88 | self._loss = AverageAggregator() 89 | for f in self._output_files: 90 | if f.isatty(): # if the file stream is interactive 91 | print(file=f, flush=True) 92 | 93 | def print_progress(self, epoch: Union[int, str], iter: int, precision="{:.5f}"): 94 | fmt = "[{}] [epoch {:4d} iter {:3d}] | loss: " + precision 95 | msg = fmt.format(time.strftime("%Y-%m-%d %H:%M:%S"), epoch, iter, self._loss.value) 96 | for k, v in self._values.items(): 97 | msg += " | " + k + ": " + precision.format(v.value) 98 | for f in self._output_files: 99 | if f.isatty(): # if the file stream is interactive 100 | print(msg + "\b"*len(msg), end="", flush=True, file=f) 101 | else: 102 | print(msg, flush=True, file=f) 103 | 104 | @classmethod 105 | def instance(cls): 106 | if StatsLogger.__INSTANCE is None: 107 | StatsLogger.__INSTANCE = cls() 108 | return StatsLogger.__INSTANCE 109 | -------------------------------------------------------------------------------- /model/diff_utils/distributed.py: -------------------------------------------------------------------------------- 1 | """ 2 | source: https://github.com/rosinality/stylegan2-pytorch/blob/master/distributed.py 3 | """ 4 | 5 | import math 6 | import pickle 7 | 8 | import torch 9 | from torch import distributed as dist 10 | from torch.utils.data.sampler import Sampler 11 | 12 | 13 | def get_rank(): 14 | if not dist.is_available(): 15 | return 0 16 | 17 | if not dist.is_initialized(): 18 | return 0 19 | 20 | return dist.get_rank() 21 | 22 | 23 | def synchronize(local_rank=0): 24 | if not dist.is_available(): 25 | return 26 | 27 | if not dist.is_initialized(): 28 | return 29 | 30 | world_size = dist.get_world_size() 31 | 32 | if world_size == 1: 33 | return 34 | 35 | dist.barrier() 36 | # dist.barrier(device_ids=[local_rank]) 37 | 38 | 39 | def get_world_size(): 40 | if not dist.is_available(): 41 | return 1 42 | 43 | if not dist.is_initialized(): 44 | return 1 45 | 46 | return dist.get_world_size() 47 | 48 | 49 | def reduce_sum(tensor): 50 | if not dist.is_available(): 51 | return tensor 52 | 53 | if not dist.is_initialized(): 54 | return tensor 55 | 56 | tensor = tensor.clone() 57 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 58 | 59 | return tensor 60 | 61 | 62 | def gather_grad(params): 63 | 
world_size = get_world_size() 64 | 65 | if world_size == 1: 66 | return 67 | 68 | for param in params: 69 | if param.grad is not None: 70 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 71 | param.grad.data.div_(world_size) 72 | 73 | 74 | def all_gather(data): 75 | world_size = get_world_size() 76 | 77 | if world_size == 1: 78 | return [data] 79 | 80 | buffer = pickle.dumps(data) 81 | storage = torch.ByteStorage.from_buffer(buffer) 82 | tensor = torch.ByteTensor(storage).to('cuda') 83 | 84 | local_size = torch.IntTensor([tensor.numel()]).to('cuda') 85 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] 86 | dist.all_gather(size_list, local_size) 87 | size_list = [int(size.item()) for size in size_list] 88 | max_size = max(size_list) 89 | 90 | tensor_list = [] 91 | for _ in size_list: 92 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) 93 | 94 | if local_size != max_size: 95 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') 96 | tensor = torch.cat((tensor, padding), 0) 97 | 98 | dist.all_gather(tensor_list, tensor) 99 | 100 | data_list = [] 101 | 102 | for size, tensor in zip(size_list, tensor_list): 103 | buffer = tensor.cpu().numpy().tobytes()[:size] 104 | data_list.append(pickle.loads(buffer)) 105 | 106 | return data_list 107 | 108 | 109 | def reduce_loss_dict(loss_dict): 110 | world_size = get_world_size() 111 | # print(world_size) 112 | 113 | if world_size < 2: 114 | return loss_dict 115 | 116 | with torch.no_grad(): 117 | keys = [] 118 | losses = [] 119 | 120 | for k in sorted(loss_dict.keys()): 121 | keys.append(k) 122 | losses.append(loss_dict[k]) 123 | 124 | try: 125 | losses = torch.stack(losses, 0) 126 | except: 127 | print(losses) 128 | dist.reduce(losses, dst=0) 129 | 130 | if dist.get_rank() == 0: 131 | losses /= world_size 132 | 133 | reduced_losses = {k: v for k, v in zip(keys, losses)} 134 | 135 | return reduced_losses -------------------------------------------------------------------------------- /model/networks/diffusion_shape/diff_utils/distributed.py: -------------------------------------------------------------------------------- 1 | """ 2 | source: https://github.com/rosinality/stylegan2-pytorch/blob/master/distributed.py 3 | """ 4 | 5 | import math 6 | import pickle 7 | 8 | import torch 9 | from torch import distributed as dist 10 | from torch.utils.data.sampler import Sampler 11 | 12 | 13 | def get_rank(): 14 | if not dist.is_available(): 15 | return 0 16 | 17 | if not dist.is_initialized(): 18 | return 0 19 | 20 | return dist.get_rank() 21 | 22 | 23 | def synchronize(local_rank=0): 24 | if not dist.is_available(): 25 | return 26 | 27 | if not dist.is_initialized(): 28 | return 29 | 30 | world_size = dist.get_world_size() 31 | 32 | if world_size == 1: 33 | return 34 | 35 | dist.barrier() 36 | # dist.barrier(device_ids=[local_rank]) 37 | 38 | 39 | def get_world_size(): 40 | if not dist.is_available(): 41 | return 1 42 | 43 | if not dist.is_initialized(): 44 | return 1 45 | 46 | return dist.get_world_size() 47 | 48 | 49 | def reduce_sum(tensor): 50 | if not dist.is_available(): 51 | return tensor 52 | 53 | if not dist.is_initialized(): 54 | return tensor 55 | 56 | tensor = tensor.clone() 57 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 58 | 59 | return tensor 60 | 61 | 62 | def gather_grad(params): 63 | world_size = get_world_size() 64 | 65 | if world_size == 1: 66 | return 67 | 68 | for param in params: 69 | if param.grad is not None: 70 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 
71 | param.grad.data.div_(world_size) 72 | 73 | 74 | def all_gather(data): 75 | world_size = get_world_size() 76 | 77 | if world_size == 1: 78 | return [data] 79 | 80 | buffer = pickle.dumps(data) 81 | storage = torch.ByteStorage.from_buffer(buffer) 82 | tensor = torch.ByteTensor(storage).to('cuda') 83 | 84 | local_size = torch.IntTensor([tensor.numel()]).to('cuda') 85 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] 86 | dist.all_gather(size_list, local_size) 87 | size_list = [int(size.item()) for size in size_list] 88 | max_size = max(size_list) 89 | 90 | tensor_list = [] 91 | for _ in size_list: 92 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) 93 | 94 | if local_size != max_size: 95 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') 96 | tensor = torch.cat((tensor, padding), 0) 97 | 98 | dist.all_gather(tensor_list, tensor) 99 | 100 | data_list = [] 101 | 102 | for size, tensor in zip(size_list, tensor_list): 103 | buffer = tensor.cpu().numpy().tobytes()[:size] 104 | data_list.append(pickle.loads(buffer)) 105 | 106 | return data_list 107 | 108 | 109 | def reduce_loss_dict(loss_dict): 110 | world_size = get_world_size() 111 | # print(world_size) 112 | 113 | if world_size < 2: 114 | return loss_dict 115 | 116 | with torch.no_grad(): 117 | keys = [] 118 | losses = [] 119 | 120 | for k in sorted(loss_dict.keys()): 121 | keys.append(k) 122 | losses.append(loss_dict[k]) 123 | 124 | try: 125 | losses = torch.stack(losses, 0) 126 | except: 127 | print(losses) 128 | dist.reduce(losses, dst=0) 129 | 130 | if dist.get_rank() == 0: 131 | losses /= world_size 132 | 133 | reduced_losses = {k: v for k, v in zip(keys, losses)} 134 | 135 | return reduced_losses -------------------------------------------------------------------------------- /model/diff_utils/pix3d_util.py: -------------------------------------------------------------------------------- 1 | 2 | # import numba 3 | import os 4 | import numpy as np 5 | from scipy.interpolate import RegularGridInterpolator as rgi 6 | from PIL import Image 7 | 8 | 9 | downsample_uneven_warned = False 10 | def downsample(vox_in, times, use_max=True): 11 | global downsample_uneven_warned 12 | if vox_in.shape[0] % times != 0 and not downsample_uneven_warned: 13 | print('WARNING: not dividing the space evenly.') 14 | downsample_uneven_warned = True 15 | return _downsample(vox_in, times, use_max=use_max) 16 | 17 | 18 | # @numba.jit(nopython=True, cache=True) 19 | def _downsample(vox_in, times, use_max=True): 20 | dim = vox_in.shape[0] // times 21 | vox_out = np.zeros((dim, dim, dim)) 22 | for x in range(dim): 23 | for y in range(dim): 24 | for z in range(dim): 25 | subx = x * times 26 | suby = y * times 27 | subz = z * times 28 | subvox = vox_in[subx:subx + times, 29 | suby:suby + times, subz:subz + times] 30 | if use_max: 31 | vox_out[x, y, z] = np.max(subvox) 32 | else: 33 | vox_out[x, y, z] = np.mean(subvox) 34 | return vox_out 35 | 36 | def thresholding(V, threshold): 37 | """ 38 | return the original voxel in its bounding box and bounding box coordinates. 
39 | """ 40 | if V.max() < threshold: 41 | return np.zeros((2,2,2)), 0, 1, 0, 1, 0, 1 42 | V_bin = (V >= threshold) 43 | x_sum = np.sum(np.sum(V_bin, axis=2), axis=1) 44 | y_sum = np.sum(np.sum(V_bin, axis=2), axis=0) 45 | z_sum = np.sum(np.sum(V_bin, axis=1), axis=0) 46 | 47 | x_min = x_sum.nonzero()[0].min() 48 | y_min = y_sum.nonzero()[0].min() 49 | z_min = z_sum.nonzero()[0].min() 50 | x_max = x_sum.nonzero()[0].max() 51 | y_max = y_sum.nonzero()[0].max() 52 | z_max = z_sum.nonzero()[0].max() 53 | return V[x_min:x_max+1, y_min:y_max+1, z_min:z_max+1], x_min, x_max, y_min, y_max, z_min, z_max 54 | 55 | def interp3(V, xi, yi, zi, fill_value=0): 56 | x = np.arange(V.shape[0]) 57 | y = np.arange(V.shape[1]) 58 | z = np.arange(V.shape[2]) 59 | interp_func = rgi((x, y, z), V, 'linear', False, fill_value) 60 | return interp_func(np.array([xi, yi, zi]).T) 61 | 62 | def mesh_grid(input_lr, output_size): 63 | x_min, x_max, y_min, y_max, z_min, z_max = input_lr 64 | length = max(max(x_max - x_min, y_max - y_min), z_max - z_min) 65 | center = np.array([x_max - x_min, y_max - y_min, z_max - z_min]) / 2. 66 | x = np.linspace(center[0] - length / 2, center[0] + length / 2, output_size[0]) 67 | y = np.linspace(center[1] - length / 2, center[1] + length / 2, output_size[1]) 68 | z = np.linspace(center[2] - length / 2, center[2] + length / 2, output_size[2]) 69 | return np.meshgrid(x, y, z) 70 | 71 | def downsample_voxel(voxel, threshold, output_size, resample=True): 72 | if voxel.shape[0] > 100: 73 | # assert output_size[0] in (32, 128) 74 | # downsample to 32 before finding bounding box 75 | if output_size[0] == 32: 76 | voxel = downsample(voxel, 4, use_max=True) 77 | if not resample: 78 | return voxel 79 | 80 | voxel, x_min, x_max, y_min, y_max, z_min, z_max = thresholding( 81 | voxel, threshold) 82 | x_mesh, y_mesh, z_mesh = mesh_grid( 83 | (x_min, x_max, y_min, y_max, z_min, z_max), output_size) 84 | x_mesh = np.reshape(np.transpose(x_mesh, (1, 0, 2)), (-1)) 85 | y_mesh = np.reshape(np.transpose(y_mesh, (1, 0, 2)), (-1)) 86 | z_mesh = np.reshape(z_mesh, (-1)) 87 | 88 | fill_value = 0 89 | voxel_d = np.reshape(interp3(voxel, x_mesh, y_mesh, z_mesh, fill_value), 90 | (output_size[0], output_size[1], output_size[2])) 91 | return voxel_d -------------------------------------------------------------------------------- /model/networks/diffusion_shape/diff_utils/pix3d_util.py: -------------------------------------------------------------------------------- 1 | 2 | # import numba 3 | import os 4 | import numpy as np 5 | from scipy.interpolate import RegularGridInterpolator as rgi 6 | from PIL import Image 7 | 8 | 9 | downsample_uneven_warned = False 10 | def downsample(vox_in, times, use_max=True): 11 | global downsample_uneven_warned 12 | if vox_in.shape[0] % times != 0 and not downsample_uneven_warned: 13 | print('WARNING: not dividing the space evenly.') 14 | downsample_uneven_warned = True 15 | return _downsample(vox_in, times, use_max=use_max) 16 | 17 | 18 | # @numba.jit(nopython=True, cache=True) 19 | def _downsample(vox_in, times, use_max=True): 20 | dim = vox_in.shape[0] // times 21 | vox_out = np.zeros((dim, dim, dim)) 22 | for x in range(dim): 23 | for y in range(dim): 24 | for z in range(dim): 25 | subx = x * times 26 | suby = y * times 27 | subz = z * times 28 | subvox = vox_in[subx:subx + times, 29 | suby:suby + times, subz:subz + times] 30 | if use_max: 31 | vox_out[x, y, z] = np.max(subvox) 32 | else: 33 | vox_out[x, y, z] = np.mean(subvox) 34 | return vox_out 35 | 36 | def thresholding(V, 
threshold): 37 | """ 38 | return the original voxel in its bounding box and bounding box coordinates. 39 | """ 40 | if V.max() < threshold: 41 | return np.zeros((2,2,2)), 0, 1, 0, 1, 0, 1 42 | V_bin = (V >= threshold) 43 | x_sum = np.sum(np.sum(V_bin, axis=2), axis=1) 44 | y_sum = np.sum(np.sum(V_bin, axis=2), axis=0) 45 | z_sum = np.sum(np.sum(V_bin, axis=1), axis=0) 46 | 47 | x_min = x_sum.nonzero()[0].min() 48 | y_min = y_sum.nonzero()[0].min() 49 | z_min = z_sum.nonzero()[0].min() 50 | x_max = x_sum.nonzero()[0].max() 51 | y_max = y_sum.nonzero()[0].max() 52 | z_max = z_sum.nonzero()[0].max() 53 | return V[x_min:x_max+1, y_min:y_max+1, z_min:z_max+1], x_min, x_max, y_min, y_max, z_min, z_max 54 | 55 | def interp3(V, xi, yi, zi, fill_value=0): 56 | x = np.arange(V.shape[0]) 57 | y = np.arange(V.shape[1]) 58 | z = np.arange(V.shape[2]) 59 | interp_func = rgi((x, y, z), V, 'linear', False, fill_value) 60 | return interp_func(np.array([xi, yi, zi]).T) 61 | 62 | def mesh_grid(input_lr, output_size): 63 | x_min, x_max, y_min, y_max, z_min, z_max = input_lr 64 | length = max(max(x_max - x_min, y_max - y_min), z_max - z_min) 65 | center = np.array([x_max - x_min, y_max - y_min, z_max - z_min]) / 2. 66 | x = np.linspace(center[0] - length / 2, center[0] + length / 2, output_size[0]) 67 | y = np.linspace(center[1] - length / 2, center[1] + length / 2, output_size[1]) 68 | z = np.linspace(center[2] - length / 2, center[2] + length / 2, output_size[2]) 69 | return np.meshgrid(x, y, z) 70 | 71 | def downsample_voxel(voxel, threshold, output_size, resample=True): 72 | if voxel.shape[0] > 100: 73 | # assert output_size[0] in (32, 128) 74 | # downsample to 32 before finding bounding box 75 | if output_size[0] == 32: 76 | voxel = downsample(voxel, 4, use_max=True) 77 | if not resample: 78 | return voxel 79 | 80 | voxel, x_min, x_max, y_min, y_max, z_min, z_max = thresholding( 81 | voxel, threshold) 82 | x_mesh, y_mesh, z_mesh = mesh_grid( 83 | (x_min, x_max, y_min, y_max, z_min, z_max), output_size) 84 | x_mesh = np.reshape(np.transpose(x_mesh, (1, 0, 2)), (-1)) 85 | y_mesh = np.reshape(np.transpose(y_mesh, (1, 0, 2)), (-1)) 86 | z_mesh = np.reshape(z_mesh, (-1)) 87 | 88 | fill_value = 0 89 | voxel_d = np.reshape(interp3(voxel, x_mesh, y_mesh, z_mesh, fill_value), 90 | (output_size[0], output_size[1], output_size[2])) 91 | return voxel_d -------------------------------------------------------------------------------- /assets/edgecloud-logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /dataset/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from plyfile import PlyData, PlyElement 3 | 4 | 5 | def read_all_ply(filename): 6 | """ Reads a PLY file from disk. 
7 | Args: 8 | filename: string 9 | 10 | Returns: np.array, np.array, np.array 11 | """ 12 | file = open(filename, 'rb') 13 | plydata = PlyData.read(file) 14 | points = np.stack((plydata['vertex']['x'], plydata['vertex']['y'], plydata['vertex']['z'])).transpose() 15 | colors = np.stack((plydata['vertex']['red'], plydata['vertex']['green'], plydata['vertex']['blue'])).transpose() 16 | try: 17 | labels = plydata['vertex']['objectId'] 18 | except: 19 | try: 20 | labels = plydata['vertex']['label'] 21 | except: 22 | labels = np.array([]) 23 | try: 24 | faces = np.array(plydata['face'].data['vertex_indices'].tolist()) 25 | except: 26 | faces = np.array([]) 27 | 28 | file.close() 29 | 30 | return points, labels, colors, faces 31 | 32 | 33 | def read_ply(filename, points_only=False): 34 | """ Reads a PLY file from disk. 35 | Args: 36 | filename: string 37 | 38 | Returns: np.array, np.array, np.array 39 | """ 40 | file = open(filename, 'rb') 41 | plydata = PlyData.read(file) 42 | points = np.stack((plydata['vertex']['x'], plydata['vertex']['y'], plydata['vertex']['z'])).transpose() 43 | 44 | if points_only: 45 | return points 46 | try: 47 | labels = plydata['vertex']['objectId'] 48 | except: 49 | try: 50 | labels = plydata['vertex']['label'] 51 | except: 52 | labels = np.array([]) 53 | try: 54 | faces = np.array(plydata['face'].data['vertex_indices'].tolist()) 55 | except: 56 | faces = np.array([]) 57 | 58 | try: 59 | masks = plydata['vertex']['mask'] 60 | except: 61 | masks = np.array([]) 62 | 63 | file.close() 64 | 65 | return points, labels, faces, masks 66 | 67 | 68 | def write_ply(filename, points, mask=None, faces=None): 69 | """ Writes a set of points, optionally with faces, labels and a colormap as a PLY file to disk. 70 | Args: 71 | filename: string 72 | points: np.array 73 | faces: np.array 74 | labels: np.array 75 | colormap: np.array 76 | """ 77 | colors = [[0, 0, 0], [0, 255, 0], [0, 128, 0], [0, 0, 255]] 78 | with open(filename, 'w') as file: 79 | 80 | file.write('ply\n') 81 | file.write('format ascii 1.0\n') 82 | file.write('element vertex %d\n' % points.shape[0]) 83 | file.write('property float x\n') 84 | file.write('property float y\n') 85 | file.write('property float z\n') 86 | 87 | if mask is not None: 88 | file.write('property ushort mask\n') 89 | file.write('property uchar red\n') 90 | file.write('property uchar green\n') 91 | file.write('property uchar blue\n') 92 | 93 | if faces is not None: 94 | file.write('element face %d\n' % faces.shape[0]) 95 | file.write('property list uchar int vertex_indices\n') 96 | 97 | file.write('end_header\n') 98 | 99 | if mask is None: 100 | for point_i in range(points.shape[0]): 101 | file.write('%f %f %f\n' % (points[point_i, 0], points[point_i, 1], points[point_i, 2])) 102 | else: 103 | for point_i in range(points.shape[0]): 104 | file.write('%f %f %f %i %i %i % i\n' % (points[point_i, 0], points[point_i, 1], points[point_i, 2], mask[point_i], colors[mask[point_i]][0], colors[mask[point_i]][1], colors[mask[point_i]][2])) 105 | 106 | if faces is not None: 107 | for face_i in range(faces.shape[0]): 108 | file.write('3 %d %d %d\n' % ( 109 | faces[face_i, 0], faces[face_i, 1], faces[face_i, 2])) 110 | -------------------------------------------------------------------------------- /scripts/pytorch_structural_losses/Makefile: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Uncomment for debugging 3 | # DEBUG := 1 4 | # Pretty 
build 5 | # Q ?= @ 6 | 7 | CXX := g++ 8 | PYTHON := python 9 | NVCC := /usr/local/cuda/bin/nvcc 10 | 11 | # PYTHON Header path 12 | PYTHON_HEADER_DIR := $(shell $(PYTHON) -c 'from distutils.sysconfig import get_python_inc; print(get_python_inc())') 13 | PYTORCH_INCLUDES := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import include_paths; [print(p) for p in include_paths()]') 14 | PYTORCH_LIBRARIES := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import library_paths; [print(p) for p in library_paths()]') 15 | 16 | # CUDA ROOT DIR that contains bin/ lib64/ and include/ 17 | # CUDA_DIR := /usr/local/cuda 18 | CUDA_DIR := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import _find_cuda_home; print(_find_cuda_home())') 19 | 20 | INCLUDE_DIRS := ./ $(CUDA_DIR)/include 21 | 22 | INCLUDE_DIRS += $(PYTHON_HEADER_DIR) 23 | INCLUDE_DIRS += $(PYTORCH_INCLUDES) 24 | INCLUDE_DIRS += ./pybind 25 | 26 | # Custom (MKL/ATLAS/OpenBLAS) include and lib directories. 27 | # Leave commented to accept the defaults for your choice of BLAS 28 | # (which should work)! 29 | # BLAS_INCLUDE := /path/to/your/blas 30 | # BLAS_LIB := /path/to/your/blas 31 | 32 | ############################################################################### 33 | SRC_DIR := ./src 34 | OBJ_DIR := ./objs 35 | CPP_SRCS := $(wildcard $(SRC_DIR)/*.cpp) 36 | CU_SRCS := $(wildcard $(SRC_DIR)/*.cu) 37 | OBJS := $(patsubst $(SRC_DIR)/%.cpp,$(OBJ_DIR)/%.o,$(CPP_SRCS)) 38 | CU_OBJS := $(patsubst $(SRC_DIR)/%.cu,$(OBJ_DIR)/cuda/%.o,$(CU_SRCS)) 39 | STATIC_LIB := $(OBJ_DIR)/libmake_pytorch.a 40 | 41 | # CUDA architecture setting: going with all of them. 42 | # For CUDA < 6.0, comment the *_50 through *_61 lines for compatibility. 43 | # For CUDA < 8.0, comment the *_60 and *_61 lines for compatibility. 44 | CUDA_ARCH := -gencode arch=compute_61,code=sm_61 \ 45 | -gencode arch=compute_61,code=compute_61 \ 46 | -gencode arch=compute_52,code=sm_52 47 | 48 | # We will also explicitly add stdc++ to the link target. 49 | LIBRARIES += stdc++ cudart c10 caffe2 torch torch_python caffe2_gpu 50 | 51 | # Debugging 52 | ifeq ($(DEBUG), 1) 53 | COMMON_FLAGS += -DDEBUG -g -O0 54 | # https://gcoe-dresden.de/reaching-the-shore-with-a-fog-warning-my-eurohack-day-4-morning-session/ 55 | NVCCFLAGS += -g -G # -rdc true 56 | else 57 | COMMON_FLAGS += -DNDEBUG -O3 58 | endif 59 | 60 | WARNINGS := -Wall -Wno-sign-compare -Wcomment 61 | 62 | INCLUDE_DIRS += $(BLAS_INCLUDE) 63 | 64 | # Automatic dependency generation (nvcc is handled separately) 65 | CXXFLAGS += -MMD -MP 66 | 67 | # Complete build flags. 68 | COMMON_FLAGS += $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir)) \ 69 | -DTORCH_API_INCLUDE_EXTENSION_H -D_GLIBCXX_USE_CXX11_ABI=0 70 | CXXFLAGS += -pthread -fPIC -fwrapv -std=c++14 $(COMMON_FLAGS) $(WARNINGS) 71 | NVCCFLAGS += -std=c++14 -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS) 72 | 73 | # hanqi modified 74 | all: $(STATIC_LIB) 75 | $(PYTHON) setup.py build 76 | @ mv build/lib.linux-x86_64-3.8/StructuralLosses .. 
77 | @ mv build/lib.linux-x86_64-3.8/*.so ../StructuralLosses/ 78 | @- $(RM) -rf $(OBJ_DIR) build objs 79 | 80 | $(OBJ_DIR): 81 | @ mkdir -p $@ 82 | @ mkdir -p $@/cuda 83 | 84 | $(OBJ_DIR)/%.o: $(SRC_DIR)/%.cpp | $(OBJ_DIR) 85 | @ echo CXX $< 86 | $(Q)$(CXX) $< $(CXXFLAGS) -c -o $@ 87 | 88 | $(OBJ_DIR)/cuda/%.o: $(SRC_DIR)/%.cu | $(OBJ_DIR) 89 | @ echo NVCC $< 90 | $(Q)$(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -M $< -o ${@:.o=.d} \ 91 | -odir $(@D) 92 | $(Q)$(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -c $< -o $@ 93 | 94 | $(STATIC_LIB): $(OBJS) $(CU_OBJS) | $(OBJ_DIR) 95 | $(RM) -f $(STATIC_LIB) 96 | $(RM) -rf build dist 97 | @ echo LD -o $@ 98 | ar rc $(STATIC_LIB) $(OBJS) $(CU_OBJS) 99 | 100 | clean: 101 | @- $(RM) -rf $(OBJ_DIR) build dist ../StructuralLosses 102 | 103 | -------------------------------------------------------------------------------- /helpers/visualize_graph.py: -------------------------------------------------------------------------------- 1 | from graphviz import Digraph 2 | import os 3 | from helpers import viz_util 4 | import json 5 | 6 | 7 | def visualize_scene_graph(graph, relationships, rel_filter_in = [], rel_filter_out = [], obj_ids = [], title ="", scan_id="", 8 | outfolder="./vis_graphs/"): 9 | g = Digraph(comment='Scene Graph' + title, format='png') 10 | 11 | for (i,obj) in enumerate(graph["objects"]): 12 | if (len(obj_ids) == 0) or (int(obj['id']) in obj_ids): 13 | if "node_mask" in graph.keys() and graph["node_mask"][i] == 0: 14 | g.node(str(obj['id']), obj["label"], fontname='helvetica', color=obj["ply_color"], fontcolor='red') 15 | else: 16 | g.node(str(obj['id']), obj["label"], fontname='helvetica', color=obj["ply_color"], style='filled') 17 | if "edge_mask" in graph.keys(): 18 | edge_mask = graph["edge_mask"] 19 | else: 20 | edge_mask = None 21 | draw_edges(g, graph["relationships"], relationships, rel_filter_in, rel_filter_out, obj_ids, edge_mask) 22 | g.render(outfolder + scan_id) 23 | 24 | 25 | def draw_edges(g, graph_relationships, relationships, rel_filter_in, rel_filter_out, obj_ids, edge_mask=None): 26 | edges = {} 27 | if edge_mask is not None: 28 | joined_edge_mask = {} 29 | for (i, rel) in enumerate(graph_relationships): 30 | rel_text = relationships[rel[2]] 31 | if (len(rel_filter_in) == 0 or (rel_text.rstrip() in rel_filter_in)) and not rel_text.rstrip() in rel_filter_out: 32 | if (len(obj_ids) == 0) or ((rel[1] in obj_ids) and (rel[0] in obj_ids)): 33 | index = str(rel[0]) + "_" + str(rel[1]) 34 | if index not in edges: 35 | edges[index] = [] 36 | if edge_mask is not None: 37 | joined_edge_mask[index] = [] 38 | edges[index].append(rel[3]) 39 | if edge_mask is not None: 40 | joined_edge_mask[index].append(edge_mask[i]) 41 | 42 | for (i,edge) in enumerate(edges): 43 | edge_obj_sub = edge.split("_") 44 | rels = ', '.join(edges[edge]) 45 | if edge_mask is not None and 0 in joined_edge_mask[edge]: 46 | g.edge(str(edge_obj_sub[0]), str(edge_obj_sub[1]), label=rels, color='red', style='dotted') 47 | else: 48 | g.edge(str(edge_obj_sub[0]), str(edge_obj_sub[1]), label=rels, color='grey') 49 | 50 | 51 | def run(use_sampled_graphs=True, scan_id="4d3d82b6-8cf4-2e04-830a-4303fa0e79c7", split=None, with_manipulation=False, 52 | data_path='./GT', outfolder="./vis_graphs/", graphfile='graphs_layout.yml'): 53 | 54 | if use_sampled_graphs: 55 | # use this option to customize your own graphs in the yaml format 56 | palette_json = os.path.join(data_path, "color_palette.json") 57 | color_palette = json.load(open(palette_json, 'r'))['hex'] 58 | graph_yaml = 
os.path.join(data_path, graphfile) 59 | else: 60 | # use this option to read scene graphs from the dataset 61 | relationships_json = os.path.join(data_path, 'relationships_validation_clean.json') #"relationships_train.json") 62 | objects_json = os.path.join(data_path, "objects.json") 63 | 64 | relationships = viz_util.read_relationships(os.path.join(data_path, "relationships.txt")) 65 | 66 | if use_sampled_graphs: 67 | rel_label_to_id = {} 68 | for (i,r) in enumerate(relationships): 69 | rel_label_to_id[r] = i 70 | graph = viz_util.load_semantic_scene_graphs_custom(graph_yaml, color_palette, rel_label_to_id, with_manipuation=False) 71 | if with_manipulation: 72 | graph_mani = viz_util.load_semantic_scene_graphs_custom(graph_yaml, color_palette, rel_label_to_id, with_manipuation=True) 73 | else: 74 | graph = viz_util.load_semantic_scene_graphs(relationships_json, objects_json) 75 | 76 | if split is not '': 77 | scan_id = scan_id + '_' + split 78 | 79 | filter_dict_in = [] 80 | filter_dict_out = [] # ["left", "right", "behind", "front", "same as", "same symmetry as", "bigger than", "lower than", "higher than", "close by"] 81 | for scan_id in [scan_id]: 82 | visualize_scene_graph(graph[scan_id], relationships, filter_dict_in, filter_dict_out, [], "v1", scan_id=scan_id, 83 | outfolder=outfolder) 84 | if with_manipulation and use_sampled_graphs: 85 | # manipulation only supported for custom graphs 86 | visualize_scene_graph(graph_mani[scan_id], relationships, filter_dict_in, filter_dict_out, [], "v1", scan_id=scan_id + "_mani", 87 | outfolder=outfolder) 88 | 89 | idx = [o['id'] for o in graph[scan_id]['objects']] 90 | color = [o['ply_color'] for o in graph[scan_id]['objects']] 91 | # return used colors so that they can be used for 3D model visualization 92 | return dict(zip(idx, color)) 93 | 94 | -------------------------------------------------------------------------------- /model/networks/diffusion_shape/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from termcolor import colored, cprint 3 | from .diff_utils import util 4 | import torch 5 | from torch import nn 6 | 7 | def create_model(opt): 8 | model = None 9 | 10 | if opt.model == 'vqvae': 11 | from ..vqvae_networks.vqvae_model import VQVAEModel 12 | model = VQVAEModel() 13 | 14 | elif opt.model == 'sdfusion': 15 | from .sdfusion_model import SDFusionModel 16 | model = SDFusionModel() 17 | 18 | elif opt.model == 'sdfusion-txt2shape': 19 | from .sdfusion_txt2shape_model import SDFusionText2ShapeModel 20 | model = SDFusionText2ShapeModel() 21 | 22 | else: 23 | raise ValueError("Model [%s] not recognized." 
% opt.model) 24 | 25 | model.initialize(opt) 26 | cprint("[*] Model has been created: %s" % model.name(), 'blue') 27 | return model 28 | 29 | 30 | # modified from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix 31 | class BaseModel(): 32 | def name(self): 33 | return 'BaseModel' 34 | 35 | def initialize(self, opt): 36 | self.opt = opt 37 | self.gpu_ids = opt.hyper.gpu_ids 38 | self.isTrain = opt.hyper.isTrain 39 | self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor 40 | 41 | self.model_names = [] 42 | self.epoch_labels = [] 43 | self.optimizers = [] 44 | 45 | def set_input(self, input): 46 | self.input = input 47 | 48 | def forward(self): 49 | pass 50 | 51 | def get_image_paths(self): 52 | pass 53 | 54 | def optimize_parameters(self): 55 | pass 56 | 57 | def get_current_visuals(self, vocab, obj_and_shape): 58 | return self.input 59 | 60 | def get_current_errors(self): 61 | return {} 62 | 63 | # define the optimizers 64 | def set_optimizers(self): 65 | pass 66 | 67 | def set_requires_grad(self, nets, requires_grad=False): 68 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 69 | Parameters: 70 | nets (network list) -- a list of networks 71 | requires_grad (bool) -- whether the networks require gradients or not 72 | """ 73 | if not isinstance(nets, list): 74 | nets = [nets] 75 | for net in nets: 76 | if net is not None: 77 | for param in net.parameters(): 78 | param.requires_grad = requires_grad 79 | 80 | # update learning rate (called once every epoch) 81 | def update_learning_rate(self): 82 | for scheduler in self.schedulers: 83 | scheduler.step() 84 | lr = self.optimizers[0].param_groups[0]['lr'] 85 | print('[*] learning rate = %.7f' % lr) 86 | 87 | def eval(self): 88 | for name in self.model_names: 89 | if isinstance(name, str): 90 | net = getattr(self, 'net' + name) 91 | net.eval() 92 | 93 | def train(self): 94 | for name in self.model_names: 95 | if isinstance(name, str): 96 | net = getattr(self, 'net' + name) 97 | net.train() 98 | 99 | # print network information 100 | def print_networks(self, verbose=False): 101 | print('---------- Networks initialized -------------') 102 | for name in self.model_names: 103 | if isinstance(name, str): 104 | net = getattr(self, 'net' + name) 105 | num_params = 0 106 | for param in net.parameters(): 107 | num_params += param.numel() 108 | if verbose: 109 | print(net) 110 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) 111 | print('-----------------------------------------------') 112 | 113 | def tocuda(self, var_names): 114 | for name in var_names: 115 | if isinstance(name, str): 116 | var = getattr(self, name) 117 | # setattr(self, name, var.cuda(self.gpu_ids[0], non_blocking=True)) 118 | setattr(self, name, var.cuda(self.opt.hyper.device, non_blocking=True)) 119 | 120 | 121 | def tnsrs2ims(self, tensor_names): 122 | ims = [] 123 | for name in tensor_names: 124 | if isinstance(name, str): 125 | var = getattr(self, name) 126 | ims.append(util.tensor2im(var.data)) 127 | return ims 128 | -------------------------------------------------------------------------------- /model/networks/diffusion_layout/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | ''' 4 | https://github.com/open-mmlab/mmdetection3d/blob/master/mmdet3d/core/bbox/iou_calculators/iou3d_calculator.py 5 | ''' 6 | 7 | def axis_aligned_bbox_overlaps_3d(bboxes1, 8 | bboxes2, 9 | mode='iou', 10 | is_aligned=False, 11 | eps=1e-6): 12 | 
"""Calculate overlap between two set of axis aligned 3D bboxes. If 13 | ``is_aligned`` is ``False``, then calculate the overlaps between each bbox 14 | of bboxes1 and bboxes2, otherwise the overlaps between each aligned pair of 15 | bboxes1 and bboxes2. 16 | Args: 17 | bboxes1 (Tensor): shape (B, m, 6) in 18 | format or empty. 19 | bboxes2 (Tensor): shape (B, n, 6) in 20 | format or empty. 21 | B indicates the batch dim, in shape (B1, B2, ..., Bn). 22 | If ``is_aligned`` is ``True``, then m and n must be equal. 23 | mode (str): "iou" (intersection over union) or "giou" (generalized 24 | intersection over union). 25 | is_aligned (bool, optional): If True, then m and n must be equal. 26 | Defaults to False. 27 | eps (float, optional): A value added to the denominator for numerical 28 | stability. Defaults to 1e-6. 29 | Returns: 30 | Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,) 31 | """ 32 | 33 | assert mode in ['iou', 'giou'], f'Unsupported mode {mode}' 34 | # Either the boxes are empty or the length of boxes's last dimension is 6 35 | assert (bboxes1.size(-1) == 6 or bboxes1.size(0) == 0) 36 | assert (bboxes2.size(-1) == 6 or bboxes2.size(0) == 0) 37 | 38 | # Batch dim must be the same 39 | # Batch dim: (B1, B2, ... Bn) 40 | assert bboxes1.shape[:-2] == bboxes2.shape[:-2] 41 | batch_shape = bboxes1.shape[:-2] 42 | 43 | rows = bboxes1.size(-2) 44 | cols = bboxes2.size(-2) 45 | if is_aligned: 46 | assert rows == cols 47 | 48 | if rows * cols == 0: 49 | if is_aligned: 50 | return bboxes1.new(batch_shape + (rows, )) 51 | else: 52 | return bboxes1.new(batch_shape + (rows, cols)) 53 | 54 | area1 = (bboxes1[..., 3] - 55 | bboxes1[..., 0]) * (bboxes1[..., 4] - bboxes1[..., 1]) * ( 56 | bboxes1[..., 5] - bboxes1[..., 2]) 57 | area2 = (bboxes2[..., 3] - 58 | bboxes2[..., 0]) * (bboxes2[..., 4] - bboxes2[..., 1]) * ( 59 | bboxes2[..., 5] - bboxes2[..., 2]) 60 | 61 | if is_aligned: 62 | lt = torch.max(bboxes1[..., :3], bboxes2[..., :3]) # [B, rows, 3] 63 | rb = torch.min(bboxes1[..., 3:], bboxes2[..., 3:]) # [B, rows, 3] 64 | 65 | wh = (rb - lt).clamp(min=0) # [B, rows, 2] 66 | overlap = wh[..., 0] * wh[..., 1] * wh[..., 2] 67 | 68 | if mode in ['iou', 'giou']: 69 | union = area1 + area2 - overlap 70 | else: 71 | union = area1 72 | if mode == 'giou': 73 | enclosed_lt = torch.min(bboxes1[..., :3], bboxes2[..., :3]) 74 | enclosed_rb = torch.max(bboxes1[..., 3:], bboxes2[..., 3:]) 75 | else: 76 | lt = torch.max(bboxes1[..., :, None, :3], 77 | bboxes2[..., None, :, :3]) # [B, rows, cols, 3] 78 | rb = torch.min(bboxes1[..., :, None, 3:], 79 | bboxes2[..., None, :, 3:]) # [B, rows, cols, 3] 80 | 81 | wh = (rb - lt).clamp(min=0) # [B, rows, cols, 3] 82 | overlap = wh[..., 0] * wh[..., 1] * wh[..., 2] 83 | 84 | if mode in ['iou', 'giou']: 85 | union = area1[..., None] + area2[..., None, :] - overlap 86 | if mode == 'giou': 87 | enclosed_lt = torch.min(bboxes1[..., :, None, :3], 88 | bboxes2[..., None, :, :3]) 89 | enclosed_rb = torch.max(bboxes1[..., :, None, 3:], 90 | bboxes2[..., None, :, 3:]) 91 | 92 | eps = union.new_tensor([eps]) 93 | union = torch.max(union, eps) 94 | ious = overlap / union 95 | if mode in ['iou']: 96 | return ious 97 | # calculate gious 98 | enclose_wh = (enclosed_rb - enclosed_lt).clamp(min=0) 99 | enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1] * enclose_wh[..., 2] 100 | enclose_area = torch.max(enclose_area, eps) 101 | gious = ious - (enclose_area - union) / enclose_area 102 | return gious 103 | 
-------------------------------------------------------------------------------- /model/diff_utils/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import random 5 | 6 | import numpy as np 7 | from PIL import Image 8 | from einops import rearrange 9 | 10 | import torch 11 | import torchvision.utils as vutils 12 | 13 | from torch.autograd import Variable 14 | from torch.optim.lr_scheduler import _LRScheduler 15 | 16 | 17 | ################# START: PyTorch Tensor functions ################# 18 | 19 | # Converts a Tensor into a Numpy array 20 | # |imtype|: the desired type of the converted numpy array 21 | def tensor2im(image_tensor, imtype=np.uint8): 22 | # image_numpy = image_tensor[0].cpu().float().numpy() 23 | # if image_numpy.shape[0] == 1: 24 | # image_numpy = np.tile(image_numpy, (3, 1, 1)) 25 | # image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 26 | # return image_numpy.astype(imtype) 27 | 28 | n_img = min(image_tensor.shape[0], 16) 29 | image_tensor = image_tensor[:n_img] 30 | 31 | if image_tensor.shape[1] == 1: 32 | image_tensor = image_tensor.repeat(1, 3, 1, 1) 33 | 34 | # if image_tensor.shape[1] == 4: 35 | # import pdb; pdb.set_trace() 36 | 37 | image_tensor = vutils.make_grid( image_tensor, nrow=4 ) 38 | 39 | image_numpy = image_tensor.cpu().float().numpy() 40 | image_numpy = ( np.transpose( image_numpy, (1, 2, 0) ) + 1) / 2.0 * 255. 41 | return image_numpy.astype(imtype) 42 | 43 | def tensor_to_pil(tensor): 44 | # """ assume shape: c h w """ 45 | if tensor.dim() == 4: 46 | tensor = vutils.make_grid(tensor) 47 | 48 | # assert tensor.dim() == 3 49 | return Image.fromarray( (rearrange(tensor, 'c h w -> h w c').cpu().numpy() * 255.).astype(np.uint8) ) 50 | 51 | ################# END: PyTorch Tensor functions ################# 52 | 53 | 54 | def to_variable(numpy_data, volatile=False): 55 | numpy_data = numpy_data.astype(np.float32) 56 | torch_data = torch.from_numpy(numpy_data).float() 57 | variable = Variable(torch_data, volatile=volatile) 58 | return variable 59 | 60 | def diagnose_network(net, name='network'): 61 | mean = 0.0 62 | count = 0 63 | for param in net.parameters(): 64 | if param.grad is not None: 65 | mean += torch.mean(torch.abs(param.grad.data)) 66 | count += 1 67 | if count > 0: 68 | mean = mean / count 69 | print(name) 70 | print(mean) 71 | 72 | 73 | def save_image(image_numpy, image_path): 74 | image_pil = Image.fromarray(image_numpy) 75 | image_pil.save(image_path) 76 | 77 | 78 | def print_numpy(x, val=True, shp=False): 79 | x = x.astype(np.float64) 80 | if shp: 81 | print('shape,', x.shape) 82 | if val: 83 | x = x.flatten() 84 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 85 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 86 | 87 | 88 | def mkdirs(paths): 89 | if isinstance(paths, list) and not isinstance(paths, str): 90 | for path in paths: 91 | mkdir(path) 92 | else: 93 | mkdir(paths) 94 | 95 | 96 | def mkdir(path): 97 | if not os.path.exists(path): 98 | os.makedirs(path) 99 | 100 | def seed_everything(seed): 101 | 102 | random.seed(seed) 103 | os.environ['PYTHONHASHSEED'] = str(seed) 104 | np.random.seed(seed) 105 | torch.manual_seed(seed) 106 | torch.cuda.manual_seed(seed) 107 | torch.backends.cudnn.deterministic = True 108 | torch.backends.cudnn.benchmark = True 109 | 110 | 111 | def iou(x_gt, x, thres): 112 | thres_gt = 0.0 113 | 114 | # compute iou 115 | # > 0 free space, < 0 
occupied 116 | x_gt_mask = x_gt.clone().detach() 117 | x_gt_mask[x_gt > thres_gt] = 0. 118 | x_gt_mask[x_gt <= thres_gt] = 1. 119 | 120 | x_mask = x.clone().detach() 121 | x_mask[x > thres] = 0. 122 | x_mask[x <= thres] = 1. 123 | 124 | inter = torch.logical_and(x_gt_mask, x_mask) 125 | union = torch.logical_or(x_gt_mask, x_mask) 126 | inter = rearrange(inter, 'b c d h w -> b (c d h w)') 127 | union = rearrange(union, 'b c d h w -> b (c d h w)') 128 | 129 | iou = inter.sum(1) / (union.sum(1) + 1e-12) 130 | return iou 131 | 132 | #################### START: MISCELLANEOUS #################### 133 | def count_params(model, verbose=False): 134 | total_params = sum(p.numel() for p in model.parameters()) 135 | if verbose: 136 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 137 | return total_params 138 | 139 | #################### END: MISCELLANEOUS #################### 140 | 141 | 142 | 143 | # Noam Learning rate schedule. 144 | # From https://github.com/tugstugi/pytorch-saltnet/blob/master/utils/lr_scheduler.py 145 | class NoamLR(_LRScheduler): 146 | 147 | def __init__(self, optimizer, warmup_steps): 148 | self.warmup_steps = warmup_steps 149 | super().__init__(optimizer) 150 | 151 | def get_lr(self): 152 | last_epoch = max(1, self.last_epoch) 153 | scale = self.warmup_steps ** 0.5 * min(last_epoch ** (-0.5), last_epoch * self.warmup_steps ** (-1.5)) 154 | return [base_lr * scale for base_lr in self.base_lrs] -------------------------------------------------------------------------------- /model/networks/diffusion_shape/diff_utils/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import random 5 | 6 | import numpy as np 7 | from PIL import Image 8 | from einops import rearrange 9 | 10 | import torch 11 | import torchvision.utils as vutils 12 | 13 | from torch.autograd import Variable 14 | from torch.optim.lr_scheduler import _LRScheduler 15 | 16 | 17 | ################# START: PyTorch Tensor functions ################# 18 | 19 | # Converts a Tensor into a Numpy array 20 | # |imtype|: the desired type of the converted numpy array 21 | def tensor2im(image_tensor, imtype=np.uint8): 22 | # image_numpy = image_tensor[0].cpu().float().numpy() 23 | # if image_numpy.shape[0] == 1: 24 | # image_numpy = np.tile(image_numpy, (3, 1, 1)) 25 | # image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 26 | # return image_numpy.astype(imtype) 27 | 28 | n_img = min(image_tensor.shape[0], 16) 29 | image_tensor = image_tensor[:n_img] 30 | 31 | if image_tensor.shape[1] == 1: 32 | image_tensor = image_tensor.repeat(1, 3, 1, 1) 33 | 34 | # if image_tensor.shape[1] == 4: 35 | # import pdb; pdb.set_trace() 36 | 37 | image_tensor = vutils.make_grid( image_tensor, nrow=4 ) 38 | 39 | image_numpy = image_tensor.cpu().float().numpy() 40 | image_numpy = ( np.transpose( image_numpy, (1, 2, 0) ) + 1) / 2.0 * 255. 
41 | return image_numpy.astype(imtype) 42 | 43 | def tensor_to_pil(tensor): 44 | # """ assume shape: c h w """ 45 | if tensor.dim() == 4: 46 | tensor = vutils.make_grid(tensor) 47 | 48 | # assert tensor.dim() == 3 49 | return Image.fromarray( (rearrange(tensor, 'c h w -> h w c').cpu().numpy() * 255.).astype(np.uint8) ) 50 | 51 | ################# END: PyTorch Tensor functions ################# 52 | 53 | 54 | def to_variable(numpy_data, volatile=False): 55 | numpy_data = numpy_data.astype(np.float32) 56 | torch_data = torch.from_numpy(numpy_data).float() 57 | variable = Variable(torch_data, volatile=volatile) 58 | return variable 59 | 60 | def diagnose_network(net, name='network'): 61 | mean = 0.0 62 | count = 0 63 | for param in net.parameters(): 64 | if param.grad is not None: 65 | mean += torch.mean(torch.abs(param.grad.data)) 66 | count += 1 67 | if count > 0: 68 | mean = mean / count 69 | print(name) 70 | print(mean) 71 | 72 | 73 | def save_image(image_numpy, image_path): 74 | image_pil = Image.fromarray(image_numpy) 75 | image_pil.save(image_path) 76 | 77 | 78 | def print_numpy(x, val=True, shp=False): 79 | x = x.astype(np.float64) 80 | if shp: 81 | print('shape,', x.shape) 82 | if val: 83 | x = x.flatten() 84 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 85 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 86 | 87 | 88 | def mkdirs(paths): 89 | if isinstance(paths, list) and not isinstance(paths, str): 90 | for path in paths: 91 | mkdir(path) 92 | else: 93 | mkdir(paths) 94 | 95 | 96 | def mkdir(path): 97 | if not os.path.exists(path): 98 | os.makedirs(path) 99 | 100 | def seed_everything(seed): 101 | 102 | random.seed(seed) 103 | os.environ['PYTHONHASHSEED'] = str(seed) 104 | np.random.seed(seed) 105 | torch.manual_seed(seed) 106 | torch.cuda.manual_seed(seed) 107 | torch.backends.cudnn.deterministic = True 108 | torch.backends.cudnn.benchmark = True 109 | 110 | 111 | def iou(x_gt, x, thres): 112 | thres_gt = 0.0 113 | 114 | # compute iou 115 | # > 0 free space, < 0 occupied 116 | x_gt_mask = x_gt.clone().detach() 117 | x_gt_mask[x_gt > thres_gt] = 0. 118 | x_gt_mask[x_gt <= thres_gt] = 1. 119 | 120 | x_mask = x.clone().detach() 121 | x_mask[x > thres] = 0. 122 | x_mask[x <= thres] = 1. 123 | 124 | inter = torch.logical_and(x_gt_mask, x_mask) 125 | union = torch.logical_or(x_gt_mask, x_mask) 126 | inter = rearrange(inter, 'b c d h w -> b (c d h w)') 127 | union = rearrange(union, 'b c d h w -> b (c d h w)') 128 | 129 | iou = inter.sum(1) / (union.sum(1) + 1e-12) 130 | return iou 131 | 132 | #################### START: MISCELLANEOUS #################### 133 | def count_params(model, verbose=False): 134 | total_params = sum(p.numel() for p in model.parameters()) 135 | if verbose: 136 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 137 | return total_params 138 | 139 | #################### END: MISCELLANEOUS #################### 140 | 141 | 142 | 143 | # Noam Learning rate schedule. 
144 | # From https://github.com/tugstugi/pytorch-saltnet/blob/master/utils/lr_scheduler.py 145 | class NoamLR(_LRScheduler): 146 | 147 | def __init__(self, optimizer, warmup_steps): 148 | self.warmup_steps = warmup_steps 149 | super().__init__(optimizer) 150 | 151 | def get_lr(self): 152 | last_epoch = max(1, self.last_epoch) 153 | scale = self.warmup_steps ** 0.5 * min(last_epoch ** (-0.5), last_epoch * self.warmup_steps ** (-1.5)) 154 | return [base_lr * scale for base_lr in self.base_lrs] -------------------------------------------------------------------------------- /helpers/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import logging 3 | import math 4 | from bisect import bisect_right 5 | from typing import List 6 | import torch 7 | from fvcore.common.param_scheduler import ( 8 | CompositeParamScheduler, 9 | ConstantParamScheduler, 10 | LinearParamScheduler, 11 | ParamScheduler, 12 | ) 13 | 14 | try: 15 | from torch.optim.lr_scheduler import LRScheduler 16 | except ImportError: 17 | from torch.optim.lr_scheduler import _LRScheduler as LRScheduler 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class WarmupParamScheduler(CompositeParamScheduler): 23 | """ 24 | Add an initial warmup stage to another scheduler. 25 | """ 26 | 27 | def __init__( 28 | self, 29 | scheduler: ParamScheduler, 30 | warmup_factor: float, 31 | warmup_length: float, 32 | warmup_method: str = "linear", 33 | rescale_interval: bool = False, 34 | ): 35 | """ 36 | Args: 37 | scheduler: warmup will be added at the beginning of this scheduler 38 | warmup_factor: the factor w.r.t the initial value of ``scheduler``, e.g. 0.001 39 | warmup_length: the relative length (in [0, 1]) of warmup steps w.r.t the entire 40 | training, e.g. 0.01 41 | warmup_method: one of "linear" or "constant" 42 | rescale_interval: whether we will rescale the interval of the scheduler after 43 | warmup 44 | """ 45 | end_value = scheduler(warmup_length) # the value to reach when warmup ends 46 | start_value = warmup_factor * scheduler(0.0) 47 | if warmup_method == "constant": 48 | warmup = ConstantParamScheduler(start_value) 49 | elif warmup_method == "linear": 50 | warmup = LinearParamScheduler(start_value, end_value) 51 | else: 52 | raise ValueError("Unknown warmup method: {}".format(warmup_method)) 53 | super().__init__( 54 | [warmup, scheduler], 55 | interval_scaling=["rescaled", "rescaled" if rescale_interval else "fixed"], 56 | lengths=[warmup_length, 1 - warmup_length], 57 | ) 58 | 59 | 60 | class LRMultiplier(LRScheduler): 61 | """ 62 | A LRScheduler which uses fvcore :class:`ParamScheduler` to multiply the 63 | learning rate of each param in the optimizer. 64 | Every step, the learning rate of each parameter becomes its initial value 65 | multiplied by the output of the given :class:`ParamScheduler`. 66 | 67 | The absolute learning rate value of each parameter can be different. 68 | This scheduler can be used as long as the relative scale among them do 69 | not change during training. 70 | 71 | Examples: 72 | :: 73 | LRMultiplier( 74 | opt, 75 | WarmupParamScheduler( 76 | MultiStepParamScheduler( 77 | [1, 0.1, 0.01], 78 | milestones=[60000, 80000], 79 | num_updates=90000, 80 | ), 0.001, 100 / 90000 81 | ), 82 | max_iter=90000 83 | ) 84 | """ 85 | 86 | # NOTES: in the most general case, every LR can use its own scheduler. 
87 | # Supporting this requires interaction with the optimizer when its parameter 88 | # group is initialized. For example, classyvision implements its own optimizer 89 | # that allows different schedulers for every parameter group. 90 | # To avoid this complexity, we use this class to support the most common cases 91 | # where the relative scale among all LRs stay unchanged during training. In this 92 | # case we only need a total of one scheduler that defines the relative LR multiplier. 93 | 94 | def __init__( 95 | self, 96 | optimizer: torch.optim.Optimizer, 97 | multiplier: ParamScheduler, 98 | max_iter: int, 99 | last_iter: int = -1, 100 | ): 101 | """ 102 | Args: 103 | optimizer, last_iter: See ``torch.optim.lr_scheduler.LRScheduler``. 104 | ``last_iter`` is the same as ``last_epoch``. 105 | multiplier: a fvcore ParamScheduler that defines the multiplier on 106 | every LR of the optimizer 107 | max_iter: the total number of training iterations 108 | """ 109 | if not isinstance(multiplier, ParamScheduler): 110 | raise ValueError( 111 | "_LRMultiplier(multiplier=) must be an instance of fvcore " 112 | f"ParamScheduler. Got {multiplier} instead." 113 | ) 114 | self._multiplier = multiplier 115 | self._max_iter = max_iter 116 | super().__init__(optimizer, last_epoch=last_iter) 117 | 118 | def state_dict(self): 119 | # fvcore schedulers are stateless. Only keep pytorch scheduler states 120 | return {"base_lrs": self.base_lrs, "last_epoch": self.last_epoch} 121 | 122 | def get_lr(self) -> List[float]: 123 | multiplier = self._multiplier(self.last_epoch / self._max_iter) 124 | return [base_lr * multiplier for base_lr in self.base_lrs] 125 | 126 | 127 | -------------------------------------------------------------------------------- /scripts/pytorch_structural_losses/src/nndistance.cu: -------------------------------------------------------------------------------- 1 | 2 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ 3 | const int batch=512; 4 | __shared__ float buf[batch*3]; 5 | for (int i=blockIdx.x;ibest){ 117 | result[(i*n+j)]=best; 118 | result_i[(i*n+j)]=best_i; 119 | } 120 | } 121 | __syncthreads(); 122 | } 123 | } 124 | } 125 | void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ 126 | NmDistanceKernel<<>>(b,n,xyz,m,xyz2,result,result_i); 127 | NmDistanceKernel<<>>(b,m,xyz2,n,xyz,result2,result2_i); 128 | } 129 | __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ 130 | for (int i=blockIdx.x;i>>(b,n,xyz1,m,xyz2,grad_dist1,idx1,grad_xyz1,grad_xyz2); 153 | NmDistanceGradKernel<<>>(b,m,xyz2,n,xyz1,grad_dist2,idx2,grad_xyz2,grad_xyz1); 154 | } 155 | 156 | -------------------------------------------------------------------------------- /model/networks/diffusion_shape/sg_diff.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from functools import partial 4 | from torch import nn 5 | from omegaconf import OmegaConf 6 | from model.networks.diffusion_networks.network import DiffusionUNet 7 | 8 | 9 | from model.networks.diffusion_networks.ldm_diffusion_util import ( 10 | make_beta_schedule, 11 | extract_into_tensor, 12 | noise_like, 13 | exists, 14 | default, 15 | ) 16 | 17 | class sg_diff(nn.Module): 18 | def 
__init__(self, df_cfg= 'configs/sdfusion-txt2shape.yaml', vq_cfg= 'configs/vqvae_snet.yaml'): 19 | super(sg_diff, self).__init__() 20 | self.df_cfg = df_cfg 21 | self.vq_cfg = vq_cfg 22 | df_conf = OmegaConf.load(self.df_cfg) 23 | vq_conf = OmegaConf.load(self.vq_cfg) 24 | df_model_params = df_conf.model.params 25 | unet_params = df_conf.unet.params 26 | self.df = DiffusionUNet(unet_params, vq_conf=vq_conf, conditioning_key=df_model_params.conditioning_key) 27 | self.df.to('cuda') 28 | self.device = 'cuda'  # device used when registering the diffusion schedule buffers below 29 | self.init_diffusion_params(uc_scale=3., df_cfg=self.df_cfg) 30 | 31 | def init_diffusion_params(self, uc_scale=3., df_cfg=None): 32 | 33 | df_conf = OmegaConf.load(df_cfg) 34 | df_model_params = df_conf.model.params 35 | 36 | # ref: ddpm.py, line 44 in __init__() 37 | self.parameterization = "eps" 38 | self.learn_logvar = False 39 | 40 | self.v_posterior = 0. 41 | self.original_elbo_weight = 0. 42 | self.l_simple_weight = 1. 43 | # ref: ddpm.py, register_schedule 44 | self.register_schedule( 45 | timesteps=df_model_params.timesteps, 46 | linear_start=df_model_params.linear_start, 47 | linear_end=df_model_params.linear_end, 48 | ) 49 | 50 | logvar_init = 0. 51 | self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) 52 | # for cls-free guidance 53 | self.uc_scale = uc_scale 54 | 55 | def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, 56 | linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 57 | if exists(given_betas): 58 | betas = given_betas 59 | else: 60 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, 61 | cosine_s=cosine_s) 62 | alphas = 1. - betas 63 | alphas_cumprod = np.cumprod(alphas, axis=0) 64 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 65 | 66 | timesteps, = betas.shape 67 | self.num_timesteps = int(timesteps) 68 | self.linear_start = linear_start 69 | self.linear_end = linear_end 70 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' 71 | 72 | to_torch = partial(torch.tensor, dtype=torch.float32) 73 | 74 | self.betas = to_torch(betas).to(self.device) 75 | self.alphas_cumprod = to_torch(alphas_cumprod).to(self.device) 76 | self.alphas_cumprod_prev = to_torch(alphas_cumprod_prev).to(self.device) 77 | 78 | # calculations for diffusion q(x_t | x_{t-1}) and others 79 | self.sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod)).to(self.device) 80 | self.sqrt_one_minus_alphas_cumprod = to_torch(np.sqrt(1. - alphas_cumprod)).to(self.device) 81 | self.log_one_minus_alphas_cumprod = to_torch(np.log(1. - alphas_cumprod)).to(self.device) 82 | self.sqrt_recip_alphas_cumprod = to_torch(np.sqrt(1. / alphas_cumprod)).to(self.device) 83 | self.sqrt_recipm1_alphas_cumprod = to_torch(np.sqrt(1. / alphas_cumprod - 1)).to(self.device) 84 | 85 | # calculations for posterior q(x_{t-1} | x_t, x_0) 86 | posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( 87 | 1. - alphas_cumprod) + self.v_posterior * betas 88 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) 89 | self.posterior_variance = to_torch(posterior_variance).to(self.device) 90 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 91 | self.posterior_log_variance_clipped = to_torch(np.log(np.maximum(posterior_variance, 1e-20))).to(self.device) 92 | self.posterior_mean_coef1 = to_torch( 93 | betas * np.sqrt(alphas_cumprod_prev) / (1.
- alphas_cumprod)).to(self.device) 94 | self.posterior_mean_coef2 = to_torch( 95 | (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)).to(self.device) 96 | 97 | if self.parameterization == "eps": 98 | lvlb_weights = self.betas ** 2 / ( 99 | 2 * self.posterior_variance * to_torch(alphas).to(self.device) * (1 - self.alphas_cumprod)) 100 | elif self.parameterization == "x0": 101 | lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) 102 | else: 103 | raise NotImplementedError("mu not supported") 104 | # TODO how to choose this term 105 | lvlb_weights[0] = lvlb_weights[1] 106 | self.lvlb_weights = lvlb_weights 107 | assert not torch.isnan(self.lvlb_weights).all() -------------------------------------------------------------------------------- /model/networks/diffusion_layout/mmg2layout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Module 3 | 4 | from .diffusion_ddpm import DiffusionPoint 5 | from .denoise_net import UNet1DModel 6 | 7 | class MMGToLayout(Module): 8 | 9 | def __init__(self, config, n_classes=None): 10 | super().__init__() 11 | self.device = config.hyper.device 12 | self.rel_condition = config.layout_branch.relation_condition 13 | # define the denoising network 14 | if config.layout_branch.denoiser == "unet1d": 15 | denoise_net = UNet1DModel(**config.layout_branch.denoiser_kwargs) 16 | else: 17 | raise NotImplementedError() 18 | 19 | # define the diffusion type 20 | self.df = DiffusionPoint( 21 | denoise_net = denoise_net, 22 | config = config.layout_branch, 23 | **config.layout_branch.diffusion_kwargs 24 | ) 25 | self.n_classes = n_classes # not used 26 | self.config = config 27 | 28 | # read object property dimension 29 | self.translation_dim = config.layout_branch.get("translation_dim", 3) 30 | self.size_dim = config.layout_branch.get("size_dim", 3) 31 | self.angle_dim = config.layout_branch.angle_dim 32 | self.bbox_dim = self.translation_dim + self.size_dim + self.angle_dim 33 | 34 | # param list 35 | trainable_models = [self.df] 36 | trainable_params = [] 37 | for m in trainable_models: 38 | trainable_params += [p for p in m.parameters() if p.requires_grad == True] 39 | self.trainable_params = trainable_params 40 | 41 | self.df.to(self.device) 42 | self.scene_ids=None 43 | 44 | def set_requires_grad(self, nets, requires_grad=False): 45 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 46 | Parameters: 47 | nets (network list) -- a list of networks 48 | requires_grad (bool) -- whether the networks require gradients or not 49 | """ 50 | if not isinstance(nets, list): 51 | nets = [nets] 52 | for net in nets: 53 | if net is not None: 54 | for param in net.parameters(): 55 | param.requires_grad = requires_grad 56 | 57 | def set_input(self, data_dict): 58 | vars_list = [] 59 | try: 60 | self.x = data_dict['box'] 61 | self.scene_ids = data_dict['obj_id_to_scene'] 62 | B, D = self.x.shape 63 | vars_list.append('x') 64 | except: 65 | print('inference mode, no gt boxes and scene ids') 66 | 67 | self.preds = data_dict['preds'] 68 | self.rel = data_dict['c_b'] 69 | self.uc_rel = data_dict['uc_b'] 70 | vars_list += ['preds', 'rel', 'uc_rel'] 71 | self.tocuda(var_names=vars_list) 72 | 73 | def tocuda(self, var_names): 74 | for name in var_names: 75 | if isinstance(name, str): 76 | var = getattr(self, name) 77 | setattr(self, name, var.cuda(self.device, non_blocking=True)) 78 | 79 | def forward(self): 80 | 
self.df.train() 81 | rel = self.rel 82 | obj_embed = self.uc_rel 83 | target_box = self.x 84 | triples = self.preds 85 | 86 | # Compute the loss 87 | self.loss, self.loss_dict = self.get_loss(obj_embed=obj_embed, obj_triples=triples, target_box=target_box, rel=rel) 88 | return self.loss, self.loss_dict 89 | 90 | def get_loss(self, obj_embed, obj_triples, target_box, rel): 91 | # Unpack the sample_params 92 | batch_size, D_params = target_box.shape 93 | if self.rel_condition: 94 | condition_cross = rel # use rel embed for cross attention 95 | else: 96 | raise NotImplementedError 97 | 98 | loss, loss_dict = self.df.get_loss_iter(obj_embed, obj_triples, target_box, scene_ids=self.scene_ids, condition_cross=condition_cross) 99 | 100 | return loss, loss_dict 101 | 102 | def sample(self, box_dim, batch_size, obj_embed=None, obj_triples=None, text=None, rel=None, ret_traj=False, ddim=False, clip_denoised=False, freq=40, batch_seeds=None): 103 | 104 | noise_shape = (batch_size, box_dim) 105 | condition = rel if self.rel_condition else None 106 | condition_cross = None 107 | # reverse sampling 108 | samples = self.df.gen_samples_sg(noise_shape, obj_embed.device, obj_embed, obj_triples, condition=condition, clip_denoised=clip_denoised) 109 | 110 | return samples 111 | 112 | @torch.no_grad() 113 | def generate_layout_sg(self, box_dim, text=None, ret_traj=False, ddim=False, clip_denoised=False, batch_seeds=None): 114 | 115 | rel = self.rel 116 | obj_embed = self.uc_rel 117 | triples = self.preds 118 | 119 | samples = self.sample(box_dim, batch_size=len(obj_embed), obj_embed=obj_embed, obj_triples=triples, text=text, rel=rel, ret_traj=ret_traj, ddim=ddim, clip_denoised=clip_denoised, batch_seeds=batch_seeds) 120 | samples_dict = { 121 | "sizes": samples[:, 0:self.size_dim].contiguous(), 122 | "translations": samples[:, self.size_dim:self.size_dim + self.translation_dim].contiguous(), 123 | "angles": samples[:, self.size_dim + self.translation_dim:self.bbox_dim].contiguous(), 124 | } 125 | 126 | return samples_dict 127 | 128 | -------------------------------------------------------------------------------- /model/diff_utils/visualizer.py: -------------------------------------------------------------------------------- 1 | 2 | import pickle 3 | from collections import OrderedDict 4 | import os 5 | import ntpath 6 | import time 7 | 8 | from termcolor import colored 9 | from . 
import util 10 | 11 | import torch 12 | import imageio 13 | import numpy as np 14 | import cv2 15 | import matplotlib.pyplot as plt 16 | 17 | def parse_line(line): 18 | info_d = {} 19 | 20 | l1, l2 = line.split(') ') 21 | l1 = l1.replace('(', '') 22 | l1 = l1.split(', ') 23 | 24 | l2 = l2.replace('(', '') 25 | l2 = l2.split(' ') 26 | 27 | info_d = {} 28 | for s in l1: 29 | 30 | k, v = s.split(': ') 31 | 32 | 33 | if k in ['epoch', 'iters']: 34 | info_d[k] = int(v) 35 | else: 36 | info_d[k] = float(v) 37 | 38 | l2_keys = l2[0::2] 39 | l2_vals = l2[1::2] 40 | 41 | for k, v in zip(l2_keys, l2_vals): 42 | k = k.replace(':','') 43 | info_d[k] = float(v) 44 | 45 | return info_d 46 | 47 | 48 | class Visualizer(): 49 | def __init__(self, opt): 50 | # self.opt = opt 51 | self.isTrain = opt.hyper.isTrain 52 | self.gif_fps = 4 53 | 54 | if self.isTrain: 55 | # self.log_dir = os.path.join(opt.checkpoints_dir, opt.name) 56 | self.log_dir = os.path.join(opt.hyper.logs_dir, opt.hyper.name) 57 | else: 58 | self.log_dir = os.path.join(opt.hyper.results_dir, opt.hyper.name) 59 | 60 | self.img_dir = os.path.join(self.log_dir, 'images') 61 | self.name = opt.hyper.name 62 | self.opt = opt 63 | 64 | def setup_io(self): 65 | 66 | print('[*] create image directory:\n%s...' % os.path.abspath(self.img_dir) ) 67 | util.mkdirs([self.img_dir]) 68 | # self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 69 | 70 | if self.isTrain: 71 | self.log_name = os.path.join(self.log_dir, 'loss_log.txt') 72 | # with open(self.log_name, "a") as log_file: 73 | with open(self.log_name, "w") as log_file: 74 | now = time.strftime("%c") 75 | log_file.write('================ Training Loss (%s) ================\n' % now) 76 | 77 | def reset(self): 78 | self.saved = False 79 | 80 | def print_current_errors(self, writer, current_iters, errors, t): 81 | message = f"[{self.opt.hyper.name}] (GPU: {self.opt.hyper.gpu_ids}, iters: {current_iters}, time: {t:.3f}) " 82 | for k, v in errors.items(): 83 | message += '%s: %.6f ' % (k, v) 84 | 85 | print(colored(message, 'magenta')) 86 | with open(self.log_name, "a") as log_file: 87 | log_file.write('%s\n' % message) 88 | 89 | self.log_tensorboard_errors(writer, errors, current_iters) 90 | 91 | def print_current_metrics(self, writer, current_iters, metrics, phase): 92 | # message = f'([{phase}] GPU: {}, steps: %d) ' % (phase, self.opt.gpu_ids_str, current_iters) 93 | # message = f'([{self.opt.exp_time}] [{phase}] GPU: {self.opt.gpu_ids_str}, steps: {current_iters}) ' 94 | message = f'([{self.opt.hyper.name}] [{phase}] GPU: {self.opt.hyper.gpu_ids}, steps: {current_iters}) ' 95 | for k, v in metrics.items(): 96 | message += '%s: %.3f ' % (k, v) 97 | 98 | print(colored(message, 'yellow')) 99 | with open(self.log_name, "a") as log_file: 100 | log_file.write('%s\n' % message) 101 | 102 | # self.log_tensorboard_metrics(metrics, epoch, phase) 103 | self.log_tensorboard_metrics(writer, metrics, current_iters, phase) 104 | 105 | def display_current_results(self, writer, visuals, current_iters, im_name='', phase='train'): 106 | 107 | # write images to disk 108 | for label, image_numpy in visuals.items(): 109 | img_path = os.path.join(self.img_dir, f'{phase}_step___{current_iters:05d}_{label}_{im_name}.png') 110 | util.save_image(image_numpy, img_path) 111 | 112 | # log to tensorboard 113 | self.log_tensorboard_visuals(writer, visuals, current_iters, phase=phase) 114 | 115 | def log_tensorboard_visuals(self, writer, visuals, cur_step, labels_while_list=None, phase='train'): 116 | 117 | if 
labels_while_list is None: 118 | labels_while_list = [] 119 | 120 | # NOTE: we have ('text', text_data) as visuals now 121 | for ix, (label, image_numpy) in enumerate(visuals.items()): 122 | if image_numpy.shape[2] == 4: 123 | image_numpy = image_numpy[:, :, :3] 124 | 125 | if label not in labels_while_list: 126 | # writer.add_image('vis/%d-%s' % (ix+1, label), image_numpy, global_step=cur_step, dataformats='HWC') 127 | writer.add_image('%s/%d-%s' % (phase, ix+1, label), image_numpy, global_step=cur_step, dataformats='HWC') 128 | else: 129 | pass 130 | # log the unwanted image just in case 131 | # writer.add_image('other/%s' % (label), image_numpy, global_step=cur_step, dataformats='HWC') 132 | 133 | def log_tensorboard_errors(self, writer, errors, cur_step): 134 | 135 | for label, error in errors.items(): 136 | writer.add_scalar('losses/%s' % label, error, cur_step) 137 | 138 | def log_tensorboard_metrics(self, writer, metrics, cur_step, phase): 139 | 140 | for label, value in metrics.items(): 141 | writer.add_scalar('metrics/%s-%s' % (phase, label), value, cur_step) -------------------------------------------------------------------------------- /model/networks/vqvae_networks/network.py: -------------------------------------------------------------------------------- 1 | # adopt from: 2 | # - VQVAE: https://github.com/nadavbh12/VQ-VAE 3 | # - Encoder: https://github.com/CompVis/taming-transformers/blob/master/taming/modules/diffusionmodules/model.py 4 | 5 | from __future__ import print_function 6 | 7 | import torch 8 | import torch.utils.data 9 | from torch import nn 10 | from torch.nn import init 11 | from torch.nn import functional as F 12 | 13 | from einops import rearrange 14 | 15 | from model.networks.vqvae_networks.vqvae_modules import Encoder3D, Decoder3D 16 | from model.networks.vqvae_networks.quantizer import VectorQuantizer 17 | 18 | def init_weights(net, init_type='normal', gain=0.01): 19 | def init_func(m): 20 | classname = m.__class__.__name__ 21 | if classname.find('BatchNorm2d') != -1: 22 | if hasattr(m, 'weight') and m.weight is not None: 23 | init.normal_(m.weight.data, 1.0, gain) 24 | if hasattr(m, 'bias') and m.bias is not None: 25 | init.constant_(m.bias.data, 0.0) 26 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 27 | if init_type == 'normal': 28 | init.normal_(m.weight.data, 0.0, gain) 29 | elif init_type == 'xavier': 30 | init.xavier_normal_(m.weight.data, gain=gain) 31 | elif init_type == 'xavier_uniform': 32 | init.xavier_uniform_(m.weight.data, gain=1.0) 33 | elif init_type == 'kaiming': 34 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 35 | elif init_type == 'orthogonal': 36 | init.orthogonal_(m.weight.data, gain=gain) 37 | elif init_type == 'none': # uses pytorch's default init method 38 | m.reset_parameters() 39 | else: 40 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 41 | if hasattr(m, 'bias') and m.bias is not None: 42 | init.constant_(m.bias.data, 0.0) 43 | 44 | net.apply(init_func) 45 | 46 | # propagate to children 47 | for m in net.children(): 48 | m.apply(init_func) 49 | 50 | 51 | class VQVAE(nn.Module): 52 | def __init__(self, 53 | ddconfig, 54 | n_embed, 55 | embed_dim, 56 | remap=None, 57 | sane_index_shape=False, # tell vector quantizer to return indices as bhw 58 | ): 59 | super(VQVAE, self).__init__() 60 | 61 | self.ddconfig = ddconfig 62 | self.n_embed = n_embed 63 | self.embed_dim = embed_dim 64 | 65 | self.encoder = Encoder3D(**ddconfig) 66 
| self.decoder = Decoder3D(**ddconfig) 67 | 68 | self.quantize = VectorQuantizer(n_embed, embed_dim, beta=1.0, 69 | remap=remap, sane_index_shape=sane_index_shape, legacy=False) 70 | self.quant_conv = torch.nn.Conv3d(ddconfig["z_channels"], embed_dim, 1) 71 | self.post_quant_conv = torch.nn.Conv3d(embed_dim, ddconfig["z_channels"], 1) 72 | 73 | init_weights(self.encoder, 'normal', 0.02) 74 | init_weights(self.decoder, 'normal', 0.02) 75 | init_weights(self.quant_conv, 'normal', 0.02) 76 | init_weights(self.post_quant_conv, 'normal', 0.02) 77 | 78 | def encode(self, x): 79 | h = self.encoder(x) 80 | h = self.quant_conv(h) 81 | quant, emb_loss, info = self.quantize(h, is_voxel=True) 82 | return quant, emb_loss, info 83 | 84 | def encode_no_quant(self, x): 85 | h = self.encoder(x) 86 | h = self.quant_conv(h) 87 | # quant, emb_loss, info = self.quantize(h, is_voxel=True) 88 | return h 89 | 90 | def decode(self, quant): 91 | quant = self.post_quant_conv(quant) 92 | dec = self.decoder(quant) 93 | return dec 94 | 95 | def decode_no_quant(self, h, force_not_quantize=False): 96 | # also go through quantization layer 97 | if not force_not_quantize: 98 | quant, emb_loss, info = self.quantize(h, is_voxel=True) 99 | else: 100 | quant = h 101 | quant = self.post_quant_conv(quant) 102 | dec = self.decoder(quant) 103 | return dec 104 | 105 | def decode_from_quant(self,quant_code): 106 | embed_from_code = self.quantize.embedding(quant_code) 107 | return embed_from_code 108 | 109 | def decode_enc_idices(self, enc_indices, z_spatial_dim=8): 110 | 111 | # for transformer 112 | enc_indices = rearrange(enc_indices, 't bs -> (bs t)') 113 | z_q = self.quantize.embedding(enc_indices) # (bs t) zd 114 | z_q = rearrange(z_q, '(bs d1 d2 d3) zd -> bs zd d1 d2 d3', d1=z_spatial_dim, d2=z_spatial_dim, d3=z_spatial_dim) 115 | dec = self.decode(z_q) 116 | return dec 117 | 118 | def decode_code(self, code_b): 119 | quant_b = self.quantize.embed_code(code_b) 120 | dec = self.decode(quant_b) 121 | return dec 122 | 123 | def forward(self, input, verbose=False, forward_no_quant=False, encode_only=False): 124 | 125 | if forward_no_quant: 126 | # for diffusion model's training 127 | z = self.encode_no_quant(input) 128 | if encode_only: 129 | return z 130 | 131 | dec = self.decode_no_quant(z) 132 | return dec, z 133 | 134 | quant, diff, info = self.encode(input) 135 | dec = self.decode(quant) 136 | 137 | if verbose: 138 | return dec, quant, diff, info 139 | else: 140 | return dec, diff 141 | -------------------------------------------------------------------------------- /model/networks/diffusion_shape/diff_utils/visualizer.py: -------------------------------------------------------------------------------- 1 | 2 | import pickle 3 | from collections import OrderedDict 4 | import os 5 | import ntpath 6 | import time 7 | 8 | from termcolor import colored 9 | from . 
import util 10 | 11 | import torch 12 | import imageio 13 | import numpy as np 14 | import cv2 15 | import matplotlib.pyplot as plt 16 | 17 | def parse_line(line): 18 | info_d = {} 19 | 20 | l1, l2 = line.split(') ') 21 | l1 = l1.replace('(', '') 22 | l1 = l1.split(', ') 23 | 24 | l2 = l2.replace('(', '') 25 | l2 = l2.split(' ') 26 | 27 | info_d = {} 28 | for s in l1: 29 | 30 | k, v = s.split(': ') 31 | 32 | 33 | if k in ['epoch', 'iters']: 34 | info_d[k] = int(v) 35 | else: 36 | info_d[k] = float(v) 37 | 38 | l2_keys = l2[0::2] 39 | l2_vals = l2[1::2] 40 | 41 | for k, v in zip(l2_keys, l2_vals): 42 | k = k.replace(':','') 43 | info_d[k] = float(v) 44 | 45 | return info_d 46 | 47 | 48 | class Visualizer(): 49 | def __init__(self, opt): 50 | # self.opt = opt 51 | self.isTrain = opt.hyper.isTrain 52 | self.gif_fps = 4 53 | 54 | if self.isTrain: 55 | # self.log_dir = os.path.join(opt.checkpoints_dir, opt.name) 56 | self.log_dir = os.path.join(opt.hyper.logs_dir, opt.hyper.name) 57 | else: 58 | self.log_dir = os.path.join(opt.hyper.results_dir, opt.hyper.name) 59 | 60 | self.img_dir = os.path.join(self.log_dir, 'images') 61 | self.name = opt.hyper.name 62 | self.opt = opt 63 | 64 | def setup_io(self): 65 | 66 | print('[*] create image directory:\n%s...' % os.path.abspath(self.img_dir) ) 67 | util.mkdirs([self.img_dir]) 68 | # self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 69 | 70 | if self.isTrain: 71 | self.log_name = os.path.join(self.log_dir, 'loss_log.txt') 72 | # with open(self.log_name, "a") as log_file: 73 | with open(self.log_name, "w") as log_file: 74 | now = time.strftime("%c") 75 | log_file.write('================ Training Loss (%s) ================\n' % now) 76 | 77 | def reset(self): 78 | self.saved = False 79 | 80 | def print_current_errors(self, writer, current_iters, errors, t): 81 | # message = '(GPU: %s, epoch: %d, iters: %d, time: %.3f) ' % (self.opt.gpu_ids_str, t) 82 | # message = f"[{self.opt.exp_time}] (GPU: {self.opt.gpu_ids_str}, iters: {current_iters}, time: {t:.3f}) " 83 | message = f"[{self.opt.hyper.name}] (GPU: {self.opt.hyper.gpu_ids}, iters: {current_iters}, time: {t:.3f}) " 84 | for k, v in errors.items(): 85 | message += '%s: %.6f ' % (k, v) 86 | 87 | print(colored(message, 'magenta')) 88 | with open(self.log_name, "a") as log_file: 89 | log_file.write('%s\n' % message) 90 | 91 | self.log_tensorboard_errors(writer, errors, current_iters) 92 | 93 | def print_current_metrics(self, writer, current_iters, metrics, phase): 94 | # message = f'([{phase}] GPU: {}, steps: %d) ' % (phase, self.opt.gpu_ids_str, current_iters) 95 | # message = f'([{self.opt.exp_time}] [{phase}] GPU: {self.opt.gpu_ids_str}, steps: {current_iters}) ' 96 | message = f'([{self.opt.hyper.name}] [{phase}] GPU: {self.opt.hyper.gpu_ids}, steps: {current_iters}) ' 97 | for k, v in metrics.items(): 98 | message += '%s: %.3f ' % (k, v) 99 | 100 | print(colored(message, 'yellow')) 101 | with open(self.log_name, "a") as log_file: 102 | log_file.write('%s\n' % message) 103 | 104 | # self.log_tensorboard_metrics(metrics, epoch, phase) 105 | self.log_tensorboard_metrics(writer, metrics, current_iters, phase) 106 | 107 | def display_current_results(self, writer, visuals, current_iters, im_name='', phase='train'): 108 | 109 | # write images to disk 110 | for label, image_numpy in visuals.items(): 111 | img_path = os.path.join(self.img_dir, f'{phase}_step___{current_iters:05d}_{label}_{im_name}.png') 112 | util.save_image(image_numpy, img_path) 113 | 114 | # log to tensorboard 
115 | self.log_tensorboard_visuals(writer, visuals, current_iters, phase=phase) 116 | 117 | def log_tensorboard_visuals(self, writer, visuals, cur_step, labels_while_list=None, phase='train'): 118 | 119 | if labels_while_list is None: 120 | labels_while_list = [] 121 | 122 | # NOTE: we have ('text', text_data) as visuals now 123 | for ix, (label, image_numpy) in enumerate(visuals.items()): 124 | if image_numpy.shape[2] == 4: 125 | image_numpy = image_numpy[:, :, :3] 126 | 127 | if label not in labels_while_list: 128 | # writer.add_image('vis/%d-%s' % (ix+1, label), image_numpy, global_step=cur_step, dataformats='HWC') 129 | writer.add_image('%s/%d-%s' % (phase, ix+1, label), image_numpy, global_step=cur_step, dataformats='HWC') 130 | else: 131 | pass 132 | # log the unwanted image just in case 133 | # writer.add_image('other/%s' % (label), image_numpy, global_step=cur_step, dataformats='HWC') 134 | 135 | def log_tensorboard_errors(self, writer, errors, cur_step): 136 | 137 | for label, error in errors.items(): 138 | writer.add_scalar('losses/%s' % label, error, cur_step) 139 | 140 | def log_tensorboard_metrics(self, writer, metrics, cur_step, phase): 141 | 142 | for label, value in metrics.items(): 143 | writer.add_scalar('metrics/%s-%s' % (phase, label), value, cur_step) -------------------------------------------------------------------------------- /helpers/viz_util.py: -------------------------------------------------------------------------------- 1 | import json 2 | import yaml 3 | 4 | 5 | def load_semantic_scene_graphs_custom(yml_relationships, color_palette, rel_label_to_id, with_manipuation=False): 6 | scene_graphs = {} 7 | 8 | graphs = yaml.load(open(yml_relationships, 'r')) 9 | for scene_id, scene in graphs['Scenes'].items(): 10 | 11 | scene_graphs[str(scene_id)] = {} 12 | scene_graphs[str(scene_id)]['objects'] = [] 13 | scene_graphs[str(scene_id)]['relationships'] = [] 14 | scene_graphs[str(scene_id)]['node_mask'] = [1] * len(scene['nodes']) 15 | scene_graphs[str(scene_id)]['edge_mask'] = [1] * len(scene['relships']) 16 | 17 | for (i, n) in enumerate(scene['nodes']): 18 | obj_item = {'ply_color': color_palette[i%len(color_palette)], 19 | 'id': str(i), 20 | 'label': n} 21 | scene_graphs[str(scene_id)]['objects'].append(obj_item) 22 | for r in scene['relships']: 23 | rel_4 = [r[0], r[1], rel_label_to_id[r[2]], r[2]] 24 | scene_graphs[str(scene_id)]['relationships'].append(rel_4) 25 | counter = len(scene['nodes']) 26 | if with_manipuation: 27 | for m in scene['manipulations']: 28 | if m[1] == 'add': 29 | # visualize an addition 30 | # ['chair', 'add', [[2, 'standing on'], [1, 'left']]] 31 | obj_item = {'ply_color': color_palette[counter%len(color_palette)], 32 | 'id': str(counter), 33 | 'label': m[0]} 34 | scene_graphs[str(scene_id)]['objects'].append(obj_item) 35 | 36 | scene_graphs[str(scene_id)]['node_mask'].append(0) 37 | for mani_rel in m[2]: 38 | rel_4 = [counter, mani_rel[0], rel_label_to_id[mani_rel[1]], mani_rel[1]] 39 | scene_graphs[str(scene_id)]['relationships'].append(rel_4) 40 | scene_graphs[str(scene_id)]['edge_mask'].append(0) 41 | counter += 1 42 | if m[1] == 'rel': 43 | # visualize changes in the relationship 44 | for (rid, r) in enumerate(scene_graphs[str(scene_id)]['relationships']): 45 | s, o, p, l = r 46 | if isinstance(m[2][3], list): 47 | # ['', 'rel', [0, 1, 'right', [0, 1, 'left']]] 48 | if s == m[2][0] and o == m[2][1] and l == m[2][2] and s == m[2][3][0] and o == m[2][3][1]: 49 | # a change on the SAME (s, o) pair, indicate the change 50 | 
scene_graphs[str(scene_id)]['edge_mask'][rid] = 0 51 | scene_graphs[str(scene_id)]['relationships'][rid][3] = m[2][2] + '->' + m[2][3][2] 52 | scene_graphs[str(scene_id)]['relationships'][rid][2] = rel_label_to_id[m[2][3][2]] 53 | break 54 | elif s == m[2][0] and o == m[2][1] and l == m[2][2]: 55 | # overwrite this edge with a new pair (s,o) 56 | del scene_graphs[str(scene_id)]['edge_mask'][rid] 57 | del scene_graphs[str(scene_id)]['relationships'][rid] 58 | scene_graphs[str(scene_id)]['edge_mask'].append(0) 59 | new_edge = [m[2][3][0], m[2][3][1], rel_label_to_id[m[2][3][2]], m[2][3][2]] 60 | scene_graphs[str(scene_id)]['relationships'].append(new_edge) 61 | else: 62 | # ['', 'rel', [0, 1, 'right', 'left']] 63 | if s == m[2][0] and o == m[2][1] and l == m[2][2]: 64 | scene_graphs[str(scene_id)]['edge_mask'][rid] = 0 65 | scene_graphs[str(scene_id)]['relationships'][rid][3] = m[2][2] + '->' + m[2][3] 66 | scene_graphs[str(scene_id)]['relationships'][rid][2] = rel_label_to_id[m[2][3]] 67 | break 68 | 69 | return scene_graphs 70 | 71 | 72 | def load_semantic_scene_graphs(json_relationships, json_objects): 73 | scene_graphs_obj = {} 74 | 75 | with open(json_objects, "r") as read_file: 76 | data = json.load(read_file) 77 | for s in data["scans"]: 78 | scan = s["scan"] 79 | objs = s['objects'] 80 | scene_graphs_obj[scan] = {} 81 | scene_graphs_obj[scan]['scan'] = scan 82 | scene_graphs_obj[scan]['objects'] = [] 83 | for obj in objs: 84 | scene_graphs_obj[scan]['objects'].append(obj) 85 | scene_graphs = {} 86 | with open(json_relationships, "r") as read_file: 87 | data = json.load(read_file) 88 | for s in data["scans"]: 89 | scan = s["scan"] 90 | split = str(s["split"]) 91 | if scan + "_" + split not in scene_graphs: 92 | scene_graphs[scan + "_" + split] = {} 93 | scene_graphs[scan + "_" + split]['objects'] = [] 94 | print("WARNING: no objects for this scene") 95 | scene_graphs[scan + "_" + split]['relationships'] = [] 96 | for k in s["objects"].keys(): 97 | ob = s['objects'][k] 98 | for i,o in enumerate(scene_graphs_obj[scan]['objects']): 99 | if o['id'] == k: 100 | inst = i 101 | break 102 | scene_graphs[scan + "_" + split]['objects'].append(scene_graphs_obj[scan]['objects'][inst]) 103 | for rel in s["relationships"]: 104 | scene_graphs[scan + "_" + split]['relationships'].append(rel) 105 | return scene_graphs 106 | 107 | 108 | def read_relationships(read_file): 109 | relationships = [] 110 | with open(read_file, 'r') as f: 111 | for line in f: 112 | relationship = line.rstrip().lower() 113 | relationships.append(relationship) 114 | return relationships 115 | -------------------------------------------------------------------------------- /model/networks/vqvae_networks/quantizer.py: -------------------------------------------------------------------------------- 1 | """ adapted from: https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | from torch import einsum 8 | from einops import rearrange 9 | 10 | class VectorQuantizer(nn.Module): 11 | """ 12 | Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly 13 | avoids costly matrix multiplications and allows for post-hoc remapping of indices. 14 | """ 15 | # NOTE: due to a bug the beta term was applied to the wrong term. for 16 | # backwards compatibility we use the buggy version by default, but you can 17 | # specify legacy=False to fix it. 
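# For reference, the two loss variants implemented in forward() below are:
#   legacy=True  (default): ||z - sg[z_q]||^2 + beta * ||z_q - sg[z]||^2   -> beta weights the codebook term
#   legacy=False           : beta * ||z - sg[z_q]||^2 + ||z_q - sg[z]||^2   -> beta weights the commitment term
# where sg[.] denotes stop-gradient. In this repo, the VQVAE wrapper in
# model/networks/vqvae_networks/network.py builds the quantizer with beta=1.0 and legacy=False.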
18 | def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", 19 | sane_index_shape=False, legacy=True): 20 | super().__init__() 21 | self.n_e = n_e 22 | self.e_dim = e_dim 23 | self.beta = beta 24 | self.legacy = legacy 25 | 26 | self.embedding = nn.Embedding(self.n_e, self.e_dim) 27 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) 28 | 29 | self.remap = remap 30 | if self.remap is not None: 31 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 32 | self.re_embed = self.used.shape[0] 33 | self.unknown_index = unknown_index # "random" or "extra" or integer 34 | if self.unknown_index == "extra": 35 | self.unknown_index = self.re_embed 36 | self.re_embed = self.re_embed+1 37 | print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " 38 | f"Using {self.unknown_index} for unknown indices.") 39 | else: 40 | self.re_embed = n_e 41 | 42 | self.sane_index_shape = sane_index_shape 43 | 44 | def remap_to_used(self, inds): 45 | ishape = inds.shape 46 | assert len(ishape)>1 47 | inds = inds.reshape(ishape[0],-1) 48 | used = self.used.to(inds) 49 | match = (inds[:,:,None]==used[None,None,...]).long() 50 | new = match.argmax(-1) 51 | unknown = match.sum(2)<1 52 | if self.unknown_index == "random": 53 | new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) 54 | else: 55 | new[unknown] = self.unknown_index 56 | return new.reshape(ishape) 57 | 58 | def unmap_to_all(self, inds): 59 | ishape = inds.shape 60 | assert len(ishape)>1 61 | inds = inds.reshape(ishape[0],-1) 62 | used = self.used.to(inds) 63 | if self.re_embed > self.used.shape[0]: # extra token 64 | inds[inds>=self.used.shape[0]] = 0 # simply set to zero 65 | back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) 66 | return back.reshape(ishape) 67 | 68 | def forward(self, z, temp=None, rescale_logits=False, return_logits=False, is_voxel=False): 69 | assert temp is None or temp==1.0, "Only for interface compatible with Gumbel" 70 | assert rescale_logits==False, "Only for interface compatible with Gumbel" 71 | assert return_logits==False, "Only for interface compatible with Gumbel" 72 | # reshape z -> (batch, height, width, channel) and flatten 73 | if not is_voxel: 74 | z = rearrange(z, 'b c h w -> b h w c').contiguous() 75 | else: 76 | z = rearrange(z, 'b c d h w -> b d h w c').contiguous() 77 | z_flattened = z.view(-1, self.e_dim) 78 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 79 | 80 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ 81 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \ 82 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) 83 | 84 | min_encoding_indices = torch.argmin(d, dim=1) 85 | z_q = self.embedding(min_encoding_indices).view(z.shape) 86 | perplexity = None 87 | min_encodings = None 88 | 89 | # compute loss for embedding 90 | if not self.legacy: 91 | loss = self.beta * torch.mean((z_q.detach()-z)**2) + \ 92 | torch.mean((z_q - z.detach()) ** 2) 93 | else: 94 | loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ 95 | torch.mean((z_q - z.detach()) ** 2) 96 | 97 | # preserve gradients 98 | z_q = z + (z_q - z).detach() 99 | 100 | # reshape back to match original input shape 101 | if not is_voxel: 102 | z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() 103 | else: 104 | z_q = rearrange(z_q, 'b d h w c -> b c d h w').contiguous() 105 | 106 | if self.remap is not None: 107 | min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # 
add batch axis 108 | min_encoding_indices = self.remap_to_used(min_encoding_indices) 109 | min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten 110 | 111 | if self.sane_index_shape: 112 | if not is_voxel: 113 | min_encoding_indices = min_encoding_indices.reshape( 114 | z_q.shape[0], z_q.shape[2], z_q.shape[3]) 115 | else: 116 | min_encoding_indices = min_encoding_indices.reshape( 117 | z_q.shape[0], z_q.shape[2], z_q.shape[3], z_q.shape[4]) 118 | 119 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices) 120 | 121 | def get_codebook_entry(self, indices, shape): 122 | # shape specifying (batch, height, width, channel) 123 | if self.remap is not None: 124 | indices = indices.reshape(shape[0],-1) # add batch axis 125 | indices = self.unmap_to_all(indices) 126 | indices = indices.reshape(-1) # flatten again 127 | 128 | # get quantized latent vectors 129 | z_q = self.embedding(indices) 130 | 131 | if shape is not None: 132 | z_q = z_q.view(shape) 133 | # reshape back to match original input shape 134 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 135 | 136 | return z_q 137 | -------------------------------------------------------------------------------- /extension/old_chamfer/chamfer.cu: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | 11 | 12 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ 13 | const int batch=512; 14 | __shared__ float buf[batch*3]; 15 | for (int i=blockIdx.x;ibest){ 127 | result[(i*n+j)]=best; 128 | result_i[(i*n+j)]=best_i; 129 | } 130 | } 131 | __syncthreads(); 132 | } 133 | } 134 | } 135 | // int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ 136 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){ 137 | 138 | const auto batch_size = xyz1.size(0); 139 | const auto n = xyz1.size(1); //num_points point cloud A 140 | const auto m = xyz2.size(1); //num_points point cloud B 141 | 142 | NmDistanceKernel<<>>(batch_size, n, xyz1.data(), m, xyz2.data(), dist1.data(), idx1.data()); 143 | NmDistanceKernel<<>>(batch_size, m, xyz2.data(), n, xyz1.data(), dist2.data(), idx2.data()); 144 | 145 | cudaError_t err = cudaGetLastError(); 146 | if (err != cudaSuccess) { 147 | printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err)); 148 | //THError("aborting"); 149 | return 0; 150 | } 151 | return 1; 152 | 153 | 154 | } 155 | __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ 156 | for (int i=blockIdx.x;i>>(batch_size,n,xyz1.data(),m,xyz2.data(),graddist1.data(),idx1.data(),gradxyz1.data(),gradxyz2.data()); 185 | NmDistanceGradKernel<<>>(batch_size,m,xyz2.data(),n,xyz1.data(),graddist2.data(),idx2.data(),gradxyz2.data(),gradxyz1.data()); 186 | 187 | cudaError_t err = cudaGetLastError(); 188 | if (err != cudaSuccess) { 189 | printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); 190 | //THError("aborting"); 191 | return 0; 192 | } 193 | return 1; 194 | 195 | } 196 | 197 | -------------------------------------------------------------------------------- /scripts/pytorch_structural_losses/src/structural_loss.cpp: 
-------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "src/approxmatch.cuh" 5 | #include "src/nndistance.cuh" 6 | 7 | #include 8 | #include 9 | 10 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 13 | 14 | /* 15 | input: 16 | set1 : batch_size * #dataset_points * 3 17 | set2 : batch_size * #query_points * 3 18 | returns: 19 | match : batch_size * #query_points * #dataset_points 20 | */ 21 | // temp: TensorShape{b,(n+m)*2} 22 | std::vector ApproxMatch(at::Tensor set_d, at::Tensor set_q) { 23 | //std::cout << "[ApproxMatch] Called." << std::endl; 24 | int64_t batch_size = set_d.size(0); 25 | int64_t n_dataset_points = set_d.size(1); // n 26 | int64_t n_query_points = set_q.size(1); // m 27 | //std::cout << "[ApproxMatch] batch_size:" << batch_size << std::endl; 28 | at::Tensor match = torch::empty({batch_size, n_query_points, n_dataset_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 29 | at::Tensor temp = torch::empty({batch_size, (n_query_points+n_dataset_points)*2}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 30 | CHECK_INPUT(set_d); 31 | CHECK_INPUT(set_q); 32 | CHECK_INPUT(match); 33 | CHECK_INPUT(temp); 34 | 35 | approxmatch(batch_size,n_dataset_points,n_query_points,set_d.data(),set_q.data(),match.data(),temp.data(), at::cuda::getCurrentCUDAStream()); 36 | return {match, temp}; 37 | } 38 | 39 | at::Tensor MatchCost(at::Tensor set_d, at::Tensor set_q, at::Tensor match) { 40 | //std::cout << "[MatchCost] Called." << std::endl; 41 | int64_t batch_size = set_d.size(0); 42 | int64_t n_dataset_points = set_d.size(1); // n 43 | int64_t n_query_points = set_q.size(1); // m 44 | //std::cout << "[MatchCost] batch_size:" << batch_size << std::endl; 45 | at::Tensor out = torch::empty({batch_size}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 46 | CHECK_INPUT(set_d); 47 | CHECK_INPUT(set_q); 48 | CHECK_INPUT(match); 49 | CHECK_INPUT(out); 50 | matchcost(batch_size,n_dataset_points,n_query_points,set_d.data(),set_q.data(),match.data(),out.data(),at::cuda::getCurrentCUDAStream()); 51 | return out; 52 | } 53 | 54 | std::vector MatchCostGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor match) { 55 | //std::cout << "[MatchCostGrad] Called." 
<< std::endl; 56 | int64_t batch_size = set_d.size(0); 57 | int64_t n_dataset_points = set_d.size(1); // n 58 | int64_t n_query_points = set_q.size(1); // m 59 | //std::cout << "[MatchCostGrad] batch_size:" << batch_size << std::endl; 60 | at::Tensor grad1 = torch::empty({batch_size,n_dataset_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 61 | at::Tensor grad2 = torch::empty({batch_size,n_query_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 62 | CHECK_INPUT(set_d); 63 | CHECK_INPUT(set_q); 64 | CHECK_INPUT(match); 65 | CHECK_INPUT(grad1); 66 | CHECK_INPUT(grad2); 67 | matchcostgrad(batch_size,n_dataset_points,n_query_points,set_d.data(),set_q.data(),match.data(),grad1.data(),grad2.data(),at::cuda::getCurrentCUDAStream()); 68 | return {grad1, grad2}; 69 | } 70 | 71 | 72 | /* 73 | input: 74 | set_d : batch_size * #dataset_points * 3 75 | set_q : batch_size * #query_points * 3 76 | returns: 77 | dist1, idx1 : batch_size * #dataset_points 78 | dist2, idx2 : batch_size * #query_points 79 | */ 80 | std::vector NNDistance(at::Tensor set_d, at::Tensor set_q) { 81 | //std::cout << "[NNDistance] Called." << std::endl; 82 | int64_t batch_size = set_d.size(0); 83 | int64_t n_dataset_points = set_d.size(1); // n 84 | int64_t n_query_points = set_q.size(1); // m 85 | //std::cout << "[NNDistance] batch_size:" << batch_size << std::endl; 86 | at::Tensor dist1 = torch::empty({batch_size, n_dataset_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 87 | at::Tensor idx1 = torch::empty({batch_size, n_dataset_points}, torch::TensorOptions().dtype(torch::kInt32).device(set_d.device())); 88 | at::Tensor dist2 = torch::empty({batch_size, n_query_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 89 | at::Tensor idx2 = torch::empty({batch_size, n_query_points}, torch::TensorOptions().dtype(torch::kInt32).device(set_d.device())); 90 | CHECK_INPUT(set_d); 91 | CHECK_INPUT(set_q); 92 | CHECK_INPUT(dist1); 93 | CHECK_INPUT(idx1); 94 | CHECK_INPUT(dist2); 95 | CHECK_INPUT(idx2); 96 | // void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream); 97 | nndistance(batch_size,n_dataset_points,set_d.data(),n_query_points,set_q.data(),dist1.data(),idx1.data(),dist2.data(),idx2.data(), at::cuda::getCurrentCUDAStream()); 98 | return {dist1, idx1, dist2, idx2}; 99 | } 100 | 101 | std::vector NNDistanceGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor idx1, at::Tensor idx2, at::Tensor grad_dist1, at::Tensor grad_dist2) { 102 | //std::cout << "[NNDistanceGrad] Called." 
<< std::endl; 103 | int64_t batch_size = set_d.size(0); 104 | int64_t n_dataset_points = set_d.size(1); // n 105 | int64_t n_query_points = set_q.size(1); // m 106 | //std::cout << "[NNDistanceGrad] batch_size:" << batch_size << std::endl; 107 | at::Tensor grad1 = torch::empty({batch_size,n_dataset_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 108 | at::Tensor grad2 = torch::empty({batch_size,n_query_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 109 | CHECK_INPUT(set_d); 110 | CHECK_INPUT(set_q); 111 | CHECK_INPUT(idx1); 112 | CHECK_INPUT(idx2); 113 | CHECK_INPUT(grad_dist1); 114 | CHECK_INPUT(grad_dist2); 115 | CHECK_INPUT(grad1); 116 | CHECK_INPUT(grad2); 117 | //void nndistancegrad(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream); 118 | nndistancegrad(batch_size,n_dataset_points,set_d.data(),n_query_points,set_q.data(), 119 | grad_dist1.data(),idx1.data(), 120 | grad_dist2.data(),idx2.data(), 121 | grad1.data(),grad2.data(), 122 | at::cuda::getCurrentCUDAStream()); 123 | return {grad1, grad2}; 124 | } 125 | 126 | -------------------------------------------------------------------------------- /scripts/collect_gt_sdf_images.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import numpy as np 3 | import os 4 | import json 5 | import trimesh 6 | import pyrender 7 | import cv2 8 | import h5py 9 | import torch 10 | import sys 11 | import seaborn as sns 12 | sys.path.append('/s2/yangzhifei/project/MMGDreamer/') 13 | from model.diff_utils.util_3d import render_sdf, render_mesh, sdf_to_mesh 14 | from helpers.util import pytorch3d_to_trimesh, fit_shapes_to_box_v2 15 | import platform 16 | import os 17 | if platform.system() == "Linux": 18 | os.environ['PYOPENGL_PLATFORM'] = 'egl' 19 | ## Need to download 3D-FUTURE-SDF 20 | # files from SG-FRONT 21 | obj_info_path_test = "/data/yangzhifei/project/MMGDreamer/FRONT/obj_boxes_all_test.json" 22 | obj_info_path_trainval = "/data/yangzhifei/project/MMGDreamer/FRONT/obj_boxes_all_trainval.json" 23 | rel_trainval_file = "/data/yangzhifei/project/MMGDreamer/FRONT/relationships_all_trainval.json" 24 | rel_test_file = "/data/yangzhifei/project/MMGDreamer/FRONT/relationships_all_test.json" 25 | class_file = "/data/yangzhifei/project/MMGDreamer/FRONT/classes_all.txt" 26 | 27 | bath = '/data/yangzhifei/project/MMGDreamer/FRONT/sdf_fov90_h8_all_test_new' 28 | cat = {} 29 | large = False 30 | without_lamp = True # lamps occlude the scene heavily; set this to False if lamps should be kept 31 | no_stool = True 32 | mapping_file = "/data/yangzhifei/project/MMGDreamer/FRONT/mapping.json" if not no_stool else "/data/yangzhifei/project/MMGDreamer/FRONT/mapping.json" 33 | if without_lamp: 34 | bath += '_wo_lamp' 35 | if no_stool: 36 | bath += '_no_stool' 37 | 38 | if large: 39 | img_path = os.path.join(bath, 'large') 40 | with open(class_file) as f: 41 | for line in f: 42 | category = line.rstrip() 43 | cat[category] = category 44 | classes = dict(zip(sorted(cat), range(len(cat)))) 45 | else: 46 | img_path = os.path.join(bath, 'small') 47 | mapping_full2simple = json.load(open(mapping_file, "r")) 48 | classes = dict(zip(sorted(list(set(mapping_full2simple.values()))), range(len(list(set(mapping_full2simple.values())))))) 49 | 50 | classes_r = dict(zip(classes.values(), classes.keys())) 51 | 52 | def render_img(trimesh_meshes): 53 | scene =
pyrender.Scene() 54 | renderer = pyrender.OffscreenRenderer(viewport_width=256, viewport_height=256) 55 | for tri_mesh in trimesh_meshes: 56 | pyrender_mesh = pyrender.Mesh.from_trimesh(tri_mesh, smooth=False) 57 | scene.add(pyrender_mesh) 58 | 59 | camera = pyrender.PerspectiveCamera(yfov=np.pi/2) 60 | 61 | # set up positions and the origin 62 | # camera_location = np.array([0.0, 8.0, 0.0]) # y axis 63 | camera_location = np.array([0.0, 4.0, 0.0]) # y axis 64 | look_at_point = np.array([0.0, 0.0, 0.0]) 65 | up_vector = np.array([0.0, 0.0, -1.0]) # -z axis 66 | 67 | camera_direction = (look_at_point - camera_location) / np.linalg.norm(look_at_point - camera_location) 68 | right_vector = np.cross(camera_direction, up_vector) 69 | up_vector = np.cross(right_vector, camera_direction) 70 | 71 | camera_pose = np.identity(4) 72 | camera_pose[:3, 0] = right_vector 73 | camera_pose[:3, 1] = up_vector 74 | camera_pose[:3, 2] = -camera_direction 75 | camera_pose[:3, 3] = camera_location 76 | scene.add(camera, pose=camera_pose) 77 | 78 | light = pyrender.DirectionalLight(color=np.ones(3), intensity=2.0) 79 | scene.add(light, pose=camera_pose) 80 | 81 | # add a point light and adjust its color and intensity 82 | point_light = pyrender.PointLight(color=np.ones(3), intensity=20.0) 83 | scene.add(point_light, pose=camera_pose) 84 | color, depth = renderer.render(scene) 85 | return color 86 | 87 | 88 | num_classes = len(classes_r.values()) 89 | color_palette = np.array(sns.color_palette('hls', num_classes)) 90 | 91 | ################################################################################## 92 | 93 | # with open(rel_trainval_file) as f: 94 | # rel = json.load(f) 95 | # img_path_trainval = os.path.join(img_path,'trainval') 96 | # if not os.path.exists(img_path_trainval): 97 | # os.makedirs(img_path_trainval) 98 | # 99 | # for root , direct, files in os.walk(img_path_trainval): 100 | # existed_files = files 101 | # 102 | # for info in rel['scans']: 103 | # obj_mesh_list = [] 104 | # scan_id = info['scan'] 105 | # cat_names = list(info['objects'].values()) 106 | # if scan_id+'.png' in existed_files: 107 | # continue 108 | # print(scan_id) 109 | # obj_list = sorted(glob.glob(os.path.join(image_path,scan_id,'*.obj'))) 110 | # for obj, cat_name in zip(obj_list[:-1], cat_names[:-1]): 111 | # obj_mesh = trimesh.load(obj) 112 | # obj_mesh = trimesh.Trimesh(vertices=obj_mesh.vertices,faces=obj_mesh.faces) 113 | # color = color_palette[classes[cat_name]] if large else color_palette[classes[mapping_full2simple[cat_name]]] 114 | # obj_mesh.visual.vertex_colors = color 115 | # obj_mesh.visual.face_colors = color 116 | # obj_mesh_list.append(obj_mesh) 117 | # color_img = render_img(obj_mesh_list) 118 | # color_bgr = cv2.cvtColor(color_img, cv2.COLOR_RGBA2BGR) 119 | # cv2.imwrite(os.path.join(img_path_trainval, '{}.png'.format(scan_id)), color_bgr) 120 | 121 | ################################################################################## 122 | 123 | with open(obj_info_path_test) as f: 124 | obj_info = json.load(f) 125 | with open(rel_test_file) as f: 126 | rel = json.load(f) 127 | for info in rel['scans']: 128 | obj_mesh_list = [] 129 | scan_id = info['scan'] 130 | for k, v in info["objects"].items(): 131 | # floor 132 | if obj_info[scan_id][k]['model_path'] is None: 133 | continue 134 | bbox = obj_info[scan_id][k]['param7'] 135 | bbox[3:6] -= np.array(obj_info[scan_id]['scene_center']) # centered in the scene 136 | class_ = mapping_full2simple[v] if not large else v 137 | if without_lamp and (class_ == 'lamp' or class_ == 'ceiling_lamp' or
class_ == 'pendant_lamp'): 138 | continue 139 | class_id = classes[class_] 140 | color = color_palette[class_id] 141 | 142 | # the base of the model path should be changed to your own path 143 | sdf_path = obj_info[scan_id][k]['model_path'].replace('3D-FUTURE-model','3D-FUTURE-SDF').replace('raw_model.obj', 'ori_sample_grid.h5') 144 | h5_f = h5py.File(sdf_path, 'r') 145 | obj_sdf = h5_f['pc_sdf_sample'][:].astype(np.float32) 146 | sdf = torch.Tensor(obj_sdf).view(1, 64, 64, 64) 147 | sdf = torch.clamp(sdf, min=-0.2, max=0.2) 148 | pyorch3d_mesh = sdf_to_mesh(sdf.view(1, 1, 64, 64, 64), render_all=True) 149 | trimesh_mesh = pytorch3d_to_trimesh(pyorch3d_mesh) 150 | trimesh_mesh.visual.vertex_colors = color 151 | trimesh_mesh.visual.face_colors = color 152 | box_points, obj = fit_shapes_to_box_v2(trimesh_mesh, bbox, degrees=False) 153 | obj_mesh_list.append(obj) 154 | 155 | # scene = trimesh.Scene(obj_mesh_list) 156 | # scene.show() 157 | 158 | img_path_test = os.path.join(img_path, 'test') 159 | if not os.path.exists(img_path_test): 160 | os.makedirs(img_path_test) 161 | # print(obj_mesh_list) 162 | color_img = render_img(obj_mesh_list) 163 | color_bgr = cv2.cvtColor(color_img, cv2.COLOR_RGBA2BGR) 164 | cv2.imwrite(os.path.join(img_path_test, '{}.png'.format(scan_id)), color_bgr) 165 | -------------------------------------------------------------------------------- /scripts/compute_fid_scores_3dfront.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2021 NVIDIA Corporation. All rights reserved. 3 | # Licensed under the NVIDIA Source Code License. 4 | # See LICENSE at https://github.com/nv-tlabs/ATISS. 5 | # Authors: Despoina Paschalidou, Amlan Kar, Maria Shugrina, Karsten Kreis, 6 | # Andreas Geiger, Sanja Fidler 7 | # 8 | 9 | """Script for computing the FID score between real and synthesized scenes. 
10 | """ 11 | import argparse 12 | import os 13 | import sys 14 | 15 | import torch 16 | 17 | import numpy as np 18 | from PIL import Image 19 | 20 | from cleanfid import fid 21 | 22 | import shutil 23 | 24 | parser = argparse.ArgumentParser( 25 | description=("Compute the FID scores between the real and the " 26 | "synthetic images")) 27 | parser.add_argument( 28 | "--path_to_real_renderings", 29 | default="/data/yangzhifei/project/MMGDreamer/FRONT/sdf_fov90_h8_no_stool/small/test", 30 | help="Path to the folder containing the real renderings" 31 | ) 32 | parser.add_argument( 33 | "--path_to_synthesized_renderings", 34 | default="/data/yangzhifei/project/MMGDreamer/experiments/train_diningroom/vis/200/render_imgs/mmgscene", 35 | help="Path to the folder containing the synthesized" 36 | ) 37 | # parser.add_argument( 38 | # "path_to_annotations", 39 | # help="Path to the folder containing the annotations" 40 | # ) 41 | parser.add_argument( 42 | "--compare_trainval", 43 | action="store_true", 44 | help="if compare trainval" 45 | ) 46 | 47 | parser.add_argument( 48 | "--room", 49 | default="bedroom", 50 | help="if compare trainval, [bedroom, livingroom, diningroom, all]" 51 | ) 52 | 53 | parser.add_argument( 54 | "--path_to_test", 55 | default="/data/yangzhifei/project/MMGDreamer/experiments/fid_kid_tmp/", 56 | help="path_to_test_real, temp" 57 | ) 58 | 59 | args = parser.parse_args() 60 | 61 | class ThreedFrontRenderDataset(object): 62 | def __init__(self, dataset): 63 | self.dataset = dataset 64 | 65 | def __len__(self): 66 | return len(self.dataset) 67 | 68 | def __getitem__(self, idx): 69 | image_path = self.dataset[idx].image_path 70 | img = Image.open(image_path) 71 | return img 72 | 73 | 74 | def main(): 75 | 76 | instruct_scene = False 77 | room = args.room 78 | print("testing {}...".format(room)) 79 | room_dict = {'bedroom': ["Bedroom", "MasterBedroom", "SecondBedroom"], 'livingroom': ['LivingDiningRoom','LivingRoom'], 80 | 'diningroom': ['LivingDiningRoom','DiningRoom'], 81 | 'all': ["Bedroom", "MasterBedroom", "SecondBedroom",'LivingDiningRoom','LivingRoom','DiningRoom']} 82 | 83 | print("Generating temporary a folder with test_real images...") 84 | path_to_test_real = os.path.join(args.path_to_test, "real")# /tmp/test_real 85 | if not os.path.exists(path_to_test_real): 86 | os.makedirs(path_to_test_real) 87 | real_images = [ 88 | os.path.join(args.path_to_real_renderings, oi) 89 | for oi in os.listdir(args.path_to_real_renderings) 90 | if oi.endswith(".png") and oi.split('-')[0] in room_dict[room] 91 | ] 92 | for i, fi in enumerate(real_images): 93 | shutil.copyfile(fi, "{}/{:05d}.png".format(path_to_test_real, i)) 94 | # Number of images to be copied 95 | N = len(real_images) 96 | print('number of real images :', len(real_images)) 97 | 98 | print("Generating temporary a folder with test_fake images...") 99 | path_to_test_fake = os.path.join(args.path_to_test, "fake") #/tmp/test_fake/ 100 | if not os.path.exists(path_to_test_fake): 101 | os.makedirs(path_to_test_fake) 102 | 103 | if not instruct_scene: 104 | synthesized_images = [ 105 | os.path.join(args.path_to_synthesized_renderings, oi) 106 | for oi in os.listdir(args.path_to_synthesized_renderings) 107 | if oi.endswith(".png") and oi.split('-')[0] in room_dict[room] 108 | ] 109 | else: 110 | synthesized_images = [ 111 | os.path.join(args.path_to_synthesized_renderings, oi) 112 | for oi in os.listdir(args.path_to_synthesized_renderings) 113 | if oi.endswith(".png") and oi.split('_')[1].split('-')[0] in room_dict[room] 114 | ] 
115 | print('number of synthesized images :', len(synthesized_images)) 116 | 117 | scores = [] 118 | scores2 = [] 119 | file_path_for_output = args.path_to_synthesized_renderings.split("/")[:-2] 120 | file_path_for_output = os.path.join("/".join(file_path_for_output), args.room + "_fid_kid_result.txt") 121 | if args.compare_trainval: 122 | if True: 123 | for i, fi in enumerate(synthesized_images): 124 | shutil.copyfile(fi, "{}/{:05d}.png".format(path_to_test_fake, i)) 125 | 126 | # Compute the FID score 127 | fid_score = fid.compute_fid(path_to_test_real, path_to_test_fake, device=torch.device("cuda")) 128 | print('fid score:', fid_score) 129 | kid_score = fid.compute_kid(path_to_test_real, path_to_test_fake, device=torch.device("cuda")) 130 | print('kid score:', kid_score) 131 | os.system('rm -r %s'%path_to_test_real) 132 | os.system('rm -r %s'%path_to_test_fake) 133 | with open(file_path_for_output, 'w') as file: 134 | file.write('fid score:{}'.format(fid_score)) 135 | file.write('kid score:{}'.format(kid_score)) 136 | else: 137 | for _ in range(1): 138 | # np.random.shuffle(synthesized_images) 139 | # synthesized_images_subset = np.random.choice(synthesized_images, N) 140 | synthesized_images_subset = synthesized_images 141 | for i, fi in enumerate(synthesized_images_subset): 142 | shutil.copyfile(fi, "{}/{:05d}.png".format(path_to_test_fake, i)) 143 | 144 | # Compute the FID score 145 | fid_score = fid.compute_fid(path_to_test_real, path_to_test_fake, device=torch.device("cuda")) 146 | 147 | scores.append(fid_score) 148 | print('iter: {:d}, fid :{:f}'.format(_, fid_score)) 149 | print('iter: {:d}, fid avg: {:f}'.format(_, sum(scores) / len(scores)) ) 150 | print('iter: {:d}, fid std: {:f}'.format(_, np.std(scores)) ) 151 | 152 | fid_score_clip = fid.compute_fid(path_to_test_real, path_to_test_fake, mode="clean", model_name="clip_vit_b_32") 153 | print('iter: {:d}, fid-clip :{:f}'.format(_, fid_score_clip)) 154 | 155 | kid_score = fid.compute_kid(path_to_test_real, path_to_test_fake, device=torch.device("cuda")) 156 | scores2.append(kid_score) 157 | print('iter: {:d}, kid: {:f}'.format(_, kid_score) ) 158 | print('iter: {:d}, kid avg: {:f}'.format(_, sum(scores2) / len(scores2)) ) 159 | print('iter: {:d}, kid std: {:f}'.format(_, np.std(scores2)) ) 160 | with open(file_path_for_output, 'w') as file: 161 | file.write('iter: {:d}, fid :{:f}'.format(_, fid_score)) 162 | file.write('iter: {:d}, fid avg: {:f}'.format(_, sum(scores) / len(scores))) 163 | file.write('iter: {:d}, fid std: {:f}'.format(_, np.std(scores))) 164 | file.write('iter: {:d}, fid-clip :{:f}'.format(_, fid_score_clip)) 165 | file.write('iter: {:d}, kid: {:f}'.format(_, kid_score)) 166 | file.write('iter: {:d}, kid avg: {:f}'.format(_, sum(scores2) / len(scores2))) 167 | file.write('iter: {:d}, kid std: {:f}'.format(_, np.std(scores2))) 168 | 169 | os.system('rm -r %s'%path_to_test_real) 170 | os.system('rm -r %s'%path_to_test_fake) 171 | 172 | 173 | if __name__ == "__main__": 174 | main() 175 | -------------------------------------------------------------------------------- /model/networks/vqvae_networks/vqvae_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | 4 | import numpy as np 5 | import mcubes 6 | import omegaconf 7 | from termcolor import colored 8 | from einops import rearrange 9 | from tqdm import tqdm 10 | 11 | import torch 12 | from torch import nn, optim 13 | from torch.profiler import record_function 14 | 15 | 
import torchvision.utils as vutils 16 | import torchvision.transforms as transforms 17 | 18 | from model.base_model import BaseModel 19 | from model.networks.vqvae_networks.network import VQVAE 20 | from model.losses import VQLoss 21 | 22 | import model.diff_utils.util 23 | from model.diff_utils.util_3d import init_mesh_renderer, render_sdf 24 | from model.diff_utils.distributed import reduce_loss_dict 25 | 26 | class VQVAEModel(BaseModel): 27 | def name(self): 28 | return 'VQVAE-Model' 29 | 30 | def initialize(self, opt): 31 | BaseModel.initialize(self, opt) 32 | self.isTrain = opt.hyper.isTrain 33 | self.model_name = self.name() 34 | self.device = opt.hyper.device 35 | 36 | # ------------------------------- 37 | # Define Networks 38 | # ------------------------------- 39 | 40 | # model 41 | assert opt.shape_branch.vq_cfg is not None 42 | configs = omegaconf.OmegaConf.load(opt.shape_branch.vq_cfg) 43 | mparam = configs.model.params 44 | n_embed = mparam.n_embed 45 | embed_dim = mparam.embed_dim 46 | ddconfig = mparam.ddconfig 47 | 48 | self.vqvae = VQVAE(ddconfig, n_embed, embed_dim) 49 | self.vqvae.to(self.device) 50 | 51 | if self.isTrain: 52 | # define loss functions 53 | codebook_weight = configs.lossconfig.params.codebook_weight 54 | self.loss_vq = VQLoss(codebook_weight=codebook_weight).to(self.device) 55 | 56 | # initialize optimizers 57 | self.optimizer = optim.Adam(self.vqvae.parameters(), lr=opt.training.lr, betas=(0.5, 0.9)) 58 | self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 1000, 0.9) 59 | 60 | self.optimizers = [self.optimizer] 61 | self.schedulers = [self.scheduler] 62 | 63 | self.print_networks(verbose=False) 64 | 65 | # continue training 66 | if opt.shape_branch.ckpt is not None: 67 | self.load_ckpt(opt.shape_branch.ckpt, load_opt=self.isTrain) 68 | 69 | # setup renderer 70 | dist, elev, azim = 1.7, 20, 20 71 | self.renderer = init_mesh_renderer(image_size=256, dist=dist, elev=elev, azim=azim, device=self.device) 72 | 73 | # for saving best ckpt 74 | self.best_iou = -1e12 75 | 76 | # for distributed training 77 | if self.opt.hyper.distributed: 78 | self.make_distributed(opt) 79 | self.vqvae_module = self.vqvae.module 80 | else: 81 | self.vqvae_module = self.vqvae 82 | 83 | def switch_eval(self): 84 | self.vqvae.eval() 85 | 86 | def switch_train(self): 87 | self.vqvae.train() 88 | 89 | def make_distributed(self, opt): 90 | self.vqvae = nn.parallel.DistributedDataParallel( 91 | self.vqvae, 92 | device_ids=[opt.local_rank], 93 | output_device=opt.local_rank, 94 | broadcast_buffers=False, 95 | ) 96 | 97 | def set_input(self, input): 98 | 99 | 100 | x = input['sdf'] 101 | self.x = x 102 | self.cur_bs = x.shape[0] # to handle last batch 103 | vars_list = ['x'] 104 | 105 | self.tocuda(var_names=vars_list) 106 | 107 | def forward(self): 108 | self.x_recon, self.qloss = self.vqvae(self.x, verbose=False) 109 | 110 | @torch.no_grad() 111 | def inference(self, data, should_render=False, verbose=False): 112 | self.switch_eval() 113 | # self.switch_train() 114 | self.set_input(data) 115 | 116 | with torch.no_grad(): 117 | self.z = self.vqvae(self.x, forward_no_quant=True, encode_only=True) 118 | self.x_recon = self.vqvae_module.decode_no_quant(self.z) 119 | 120 | if should_render: 121 | self.image = render_sdf(self.renderer, self.x) 122 | self.image_recon = render_sdf(self.renderer, self.x_recon) 123 | 124 | self.switch_train() 125 | 126 | def test_iou(self, data, thres=0.0): 127 | """ 128 | thres: threshold to consider a voxel to be free space or occupied space. 
129 | """ 130 | self.inference(data, should_render=False) 131 | 132 | x = self.x 133 | x_recon = self.x_recon 134 | iou = model.diff_utils.util.iou(x, x_recon, thres) 135 | 136 | return iou 137 | 138 | def eval_metrics(self, dataloader, thres=0.0, global_step=0): 139 | # self.eval() 140 | self.switch_eval() 141 | 142 | iou_list = [] 143 | with torch.no_grad(): 144 | for ix, test_data in tqdm(enumerate(dataloader), total=len(dataloader)): 145 | 146 | iou = self.test_iou(test_data, thres=thres) 147 | iou_list.append(iou.detach()) 148 | 149 | # DEBUG 150 | # self.image_recon = render_sdf(self.renderer, self.x_recon) 151 | # vutils.save_image(self.image_recon, f'tmp/{ix}-{global_step}-recon.png') 152 | 153 | iou = torch.cat(iou_list) 154 | iou_mean, iou_std = iou.mean(), iou.std() 155 | 156 | ret = OrderedDict([ 157 | ('iou', iou_mean.data), 158 | ('iou_std', iou_std.data), 159 | ]) 160 | 161 | # check whether to save best epoch 162 | if ret['iou'] > self.best_iou: 163 | self.best_iou = ret['iou'] 164 | save_name = f'epoch-best' 165 | self.save(save_name, global_step) # pass 0 just now 166 | 167 | self.switch_train() 168 | return ret 169 | 170 | 171 | def backward(self): 172 | '''backward pass for the generator in training the unsupervised model''' 173 | total_loss, loss_dict = self.loss_vq(self.qloss, self.x, self.x_recon) 174 | 175 | self.loss = total_loss 176 | 177 | self.loss_dict = reduce_loss_dict(loss_dict) 178 | 179 | self.loss_total = loss_dict['loss_total'] 180 | self.loss_codebook = loss_dict['loss_codebook'] 181 | self.loss_nll = loss_dict['loss_nll'] 182 | self.loss_rec = loss_dict['loss_rec'] 183 | 184 | self.loss.backward() 185 | 186 | def optimize_parameters(self, total_steps): 187 | 188 | self.forward() 189 | self.optimizer.zero_grad(set_to_none=True) 190 | self.backward() 191 | self.optimizer.step() 192 | 193 | def get_current_errors(self): 194 | 195 | ret = OrderedDict([ 196 | ('total', self.loss_total.mean().data), 197 | ('codebook', self.loss_codebook.mean().data), 198 | ('nll', self.loss_nll.mean().data), 199 | ('rec', self.loss_rec.mean().data), 200 | ]) 201 | 202 | return ret 203 | 204 | def get_current_visuals(self): 205 | 206 | with torch.no_grad(): 207 | self.image = render_sdf(self.renderer, self.x) 208 | self.image_recon = render_sdf(self.renderer, self.x_recon) 209 | 210 | vis_tensor_names = [ 211 | 'image', 212 | 'image_recon', 213 | ] 214 | 215 | vis_ims = self.tnsrs2ims(vis_tensor_names) 216 | visuals = zip(vis_tensor_names, vis_ims) 217 | 218 | return OrderedDict(visuals) 219 | 220 | def save(self, label, global_step=0, save_opt=False): 221 | 222 | state_dict = { 223 | 'vqvae': self.vqvae_module.state_dict(), 224 | # 'opt': self.optimizer.state_dict(), 225 | 'global_step': global_step, 226 | } 227 | 228 | if save_opt: 229 | state_dict['opt'] = self.optimizer.state_dict() 230 | 231 | save_filename = 'vqvae_%s.pth' % (label) 232 | save_path = os.path.join(self.opt.ckpt_dir, save_filename) 233 | 234 | torch.save(state_dict, save_path) 235 | 236 | def get_codebook_weight(self): 237 | ret = self.vqvae.quantize.embedding.cpu().state_dict() 238 | self.vqvae.quantize.embedding.cuda() 239 | return ret 240 | 241 | def load_ckpt(self, ckpt, load_opt=False): 242 | map_fn = lambda storage, loc: storage 243 | if type(ckpt) == str: 244 | state_dict = torch.load(ckpt, map_location=map_fn) 245 | else: 246 | state_dict = ckpt 247 | 248 | # NOTE: handle version difference... 
249 | if 'vqvae' not in state_dict: 250 | self.vqvae.load_state_dict(state_dict) 251 | else: 252 | self.vqvae.load_state_dict(state_dict['vqvae']) 253 | 254 | print(colored('[*] weight successfully load from: %s' % ckpt, 'blue')) 255 | if load_opt: 256 | self.optimizer.load_state_dict(state_dict['opt']) 257 | print(colored('[*] optimizer successfully restored from: %s' % ckpt, 'blue')) 258 | 259 | 260 | -------------------------------------------------------------------------------- /model/graph.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # Modification copyright 2021 Helisa Dhamo, Fabian Manhardt 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import torch 19 | import torch.nn as nn 20 | from model.layers import build_mlp 21 | 22 | """ 23 | PyTorch modules for dealing with scene graphs. 24 | """ 25 | 26 | 27 | def make_mlp(dim_list, activation='relu', batch_norm='none', dropout=0, norelu=False): 28 | return build_mlp(dim_list, activation, batch_norm, dropout, final_nonlinearity=(not norelu)) 29 | 30 | 31 | def _init_weights(module): 32 | if hasattr(module, 'weight'): 33 | if isinstance(module, nn.Linear): 34 | nn.init.kaiming_normal_(module.weight) 35 | 36 | 37 | class WeightNetGCN(nn.Module): 38 | """ predict a weight array for the subject and the objects """ 39 | def __init__(self, feat_dim_in1=256, feat_dim_in2=256, feat_dim=128, separate_s_o=True): 40 | super(WeightNetGCN, self).__init__() 41 | 42 | self.separate = separate_s_o 43 | 44 | if self.separate: 45 | self.Net_s = nn.Sequential( 46 | nn.Linear(3*feat_dim, 64), 47 | nn.ReLU(inplace=True), 48 | nn.Linear(64, 1), 49 | nn.Sigmoid() 50 | ) 51 | 52 | self.Net_o = nn.Sequential( 53 | nn.Linear(3*feat_dim, 64), 54 | nn.ReLU(inplace=True), 55 | nn.Linear(64, 1), 56 | nn.Sigmoid() 57 | ) 58 | else: 59 | self.Net = nn.Sequential( 60 | nn.Linear(3*feat_dim, 64), 61 | nn.ReLU(inplace=True), 62 | nn.Linear(64, 1), 63 | nn.Sigmoid() 64 | ) 65 | 66 | self.down_sample_obj = nn.Linear(feat_dim_in1, feat_dim) 67 | self.down_sample_pred = nn.Linear(feat_dim_in2, feat_dim) 68 | 69 | def forward(self, s, p, o): 70 | 71 | s = self.down_sample_obj(s) 72 | p = self.down_sample_pred(p) 73 | o = self.down_sample_obj(o) 74 | 75 | if self.separate: 76 | feat1 = torch.cat([s, o, p], 1) 77 | w_s = self.Net_s(feat1) 78 | 79 | feat2 = torch.cat([s, o, p], 1) 80 | w_o = self.Net_o(feat2) 81 | else: 82 | feat = torch.cat([s, o, p], 1) 83 | w_o = self.Net(feat) 84 | w_s = w_o 85 | 86 | return w_s, w_o 87 | 88 | 89 | class GraphTripleConv(nn.Module): 90 | """ 91 | A single layer of scene graph convolution. 
92 | """ 93 | def __init__(self, input_dim_obj, input_dim_pred, output_dim=None, hidden_dim=512, 94 | pooling='avg', mlp_normalization='none', residual=True): 95 | super(GraphTripleConv, self).__init__() 96 | if output_dim is None: 97 | output_dim = input_dim_obj 98 | self.input_dim_obj = input_dim_obj 99 | self.input_dim_pred = input_dim_pred 100 | self.output_dim = output_dim 101 | self.hidden_dim = hidden_dim 102 | 103 | self.residual = residual 104 | 105 | assert pooling in ['sum', 'avg', 'wAvg'], 'Invalid pooling "%s"' % pooling 106 | 107 | self.pooling = pooling 108 | net1_layers = [2 * input_dim_obj + input_dim_pred, hidden_dim, 2 * hidden_dim + input_dim_pred] 109 | net1_layers = [l for l in net1_layers if l is not None] 110 | self.net1 = build_mlp(net1_layers, batch_norm=mlp_normalization) 111 | self.net1.apply(_init_weights) 112 | 113 | net2_layers = [hidden_dim, hidden_dim, output_dim] 114 | self.net2 = build_mlp(net2_layers, batch_norm=mlp_normalization) 115 | self.net2.apply(_init_weights) 116 | 117 | if self.residual: 118 | self.linear_projection = nn.Linear(input_dim_obj, output_dim) 119 | self.linear_projection_pred = nn.Linear(input_dim_pred, input_dim_pred) # TODO is there any better option? 120 | 121 | if self.pooling == 'wAvg': 122 | self.weightNet = WeightNetGCN(hidden_dim, output_dim, 128) 123 | 124 | def forward(self, obj_vecs, pred_vecs, edges): 125 | """ 126 | Inputs: 127 | - obj_vecs: FloatTensor of shape (num_objs, D) giving vectors for all objects 128 | - pred_vecs: FloatTensor of shape (num_triples, D) giving vectors for all predicates 129 | - edges: LongTensor of shape (num_triples, 2) where edges[k] = [i, j] indicates the 130 | presence of a triple [obj_vecs[i], pred_vecs[k], obj_vecs[j]] 131 | 132 | Outputs: 133 | - new_obj_vecs: FloatTensor of shape (num_objs, D) giving new vectors for objects 134 | - new_pred_vecs: FloatTensor of shape (num_triples, D) giving new vectors for predicates 135 | """ 136 | 137 | dtype, device = obj_vecs.dtype, obj_vecs.device 138 | num_objs, num_triples = obj_vecs.size(0), pred_vecs.size(0) 139 | Din_obj, Din_pred, H, Dout = self.input_dim_obj, self.input_dim_pred, self.hidden_dim, self.output_dim 140 | 141 | # Break apart indices for subjects and objects; these have shape (num_triples,) 142 | s_idx = edges[:, 0].contiguous() 143 | o_idx = edges[:, 1].contiguous() 144 | 145 | # Get current vectors for subjects and objects; these have shape (num_triples, Din) 146 | cur_s_vecs = obj_vecs[s_idx] 147 | cur_o_vecs = obj_vecs[o_idx] 148 | 149 | # Get current vectors for triples; shape is (num_triples, 3 * Din) 150 | # Pass through net1 to get new triple vecs; shape is (num_triples, 2 * H + Dout) 151 | cur_t_vecs = torch.cat([cur_s_vecs, pred_vecs, cur_o_vecs], dim=1) 152 | new_t_vecs = self.net1(cur_t_vecs) 153 | 154 | # Break apart into new s, p, and o vecs; s and o vecs have shape (num_triples, H) and 155 | # p vecs have shape (num_triples, Dout) 156 | new_s_vecs = new_t_vecs[:, :H] 157 | new_p_vecs = new_t_vecs[:, H:(H+Din_pred)] 158 | new_o_vecs = new_t_vecs[:, (H+Din_pred):] 159 | 160 | # Allocate space for pooled object vectors of shape (num_objs, H) 161 | pooled_obj_vecs = torch.zeros(num_objs, H, dtype=dtype, device=device) 162 | 163 | if self.pooling == 'wAvg': 164 | 165 | s_weights, o_weights = self.weightNet(new_s_vecs.detach(), 166 | new_p_vecs.detach(), 167 | new_o_vecs.detach()) 168 | 169 | new_s_vecs = s_weights * new_s_vecs 170 | new_o_vecs = o_weights * new_o_vecs 171 | 172 | # Use scatter_add to sum vectors for 
objects that appear in multiple triples; 173 | # we first need to expand the indices to have shape (num_triples, D) 174 | s_idx_exp = s_idx.view(-1, 1).expand_as(new_s_vecs) 175 | o_idx_exp = o_idx.view(-1, 1).expand_as(new_o_vecs) 176 | pooled_obj_vecs = pooled_obj_vecs.scatter_add(0, s_idx_exp, new_s_vecs) 177 | pooled_obj_vecs = pooled_obj_vecs.scatter_add(0, o_idx_exp, new_o_vecs) 178 | 179 | if self.pooling == 'wAvg': 180 | pooled_weight_sums = torch.zeros(num_objs, 1, dtype=dtype, device=device) 181 | pooled_weight_sums = pooled_weight_sums.scatter_add(0, o_idx.view(-1, 1), o_weights) 182 | pooled_weight_sums = pooled_weight_sums.scatter_add(0, s_idx.view(-1, 1), s_weights) 183 | 184 | pooled_obj_vecs = pooled_obj_vecs / (pooled_weight_sums + 0.0001) 185 | 186 | if self.pooling == 'avg': 187 | # Figure out how many times each object has appeared, again using 188 | # some scatter_add trickery. 189 | obj_counts = torch.zeros(num_objs, dtype=dtype, device=device) 190 | ones = torch.ones(num_triples, dtype=dtype, device=device) 191 | obj_counts = obj_counts.scatter_add(0, s_idx, ones) 192 | obj_counts = obj_counts.scatter_add(0, o_idx, ones) 193 | 194 | # Divide the new object vectors by the number of times they 195 | # appeared, but first clamp at 1 to avoid dividing by zero; 196 | # objects that appear in no triples will have output vector 0 197 | # so this will not affect them. 198 | obj_counts = obj_counts.clamp(min=1) 199 | pooled_obj_vecs = pooled_obj_vecs / obj_counts.view(-1, 1) 200 | 201 | # Send pooled object vectors through net2 to get output object vectors, 202 | # of shape (num_objs, Dout) 203 | new_obj_vecs = self.net2(pooled_obj_vecs) 204 | 205 | if self.residual: 206 | projected_obj_vecs = self.linear_projection(obj_vecs) 207 | new_obj_vecs = new_obj_vecs + projected_obj_vecs 208 | # new 209 | new_p_vecs = new_p_vecs + self.linear_projection_pred(pred_vecs) 210 | 211 | return new_obj_vecs, new_p_vecs 212 | 213 | 214 | class GraphTripleConvNet(nn.Module): 215 | """ A sequence of scene graph convolution layers """ 216 | def __init__(self, input_dim_obj, input_dim_pred, num_layers=2, hidden_dim=512, 217 | residual=False, pooling='avg', 218 | mlp_normalization='none', output_dim=None): 219 | super(GraphTripleConvNet, self).__init__() 220 | 221 | self.num_layers = num_layers 222 | self.gconvs = nn.ModuleList() 223 | gconv_kwargs = { 224 | 'input_dim_obj': input_dim_obj, 225 | 'input_dim_pred': input_dim_pred, 226 | 'hidden_dim': hidden_dim, 227 | 'pooling': pooling, 228 | 'residual': residual, 229 | 'mlp_normalization': mlp_normalization, 230 | } 231 | gconv_kwargs_out = { 232 | 'input_dim_obj': input_dim_obj, 233 | 'input_dim_pred': input_dim_pred, 234 | 'hidden_dim': hidden_dim, 235 | 'pooling': pooling, 236 | 'residual': residual, 237 | 'mlp_normalization': mlp_normalization, 238 | 'output_dim': output_dim 239 | } 240 | for i in range(self.num_layers): 241 | if output_dim is not None and i >= self.num_layers - 1: 242 | self.gconvs.append(GraphTripleConv(**gconv_kwargs_out)) 243 | else: 244 | self.gconvs.append(GraphTripleConv(**gconv_kwargs)) 245 | 246 | def forward(self, obj_vecs, pred_vecs, edges): 247 | for i in range(self.num_layers): 248 | gconv = self.gconvs[i] 249 | obj_vecs, pred_vecs = gconv(obj_vecs, pred_vecs, edges) 250 | return obj_vecs, pred_vecs -------------------------------------------------------------------------------- /model/diff_utils/demo_util.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as 
np 3 | from einops import rearrange 4 | from omegaconf import OmegaConf 5 | from PIL import Image 6 | 7 | import torch 8 | import torchvision.utils as vutils 9 | 10 | from datasets.base_dataset import CreateDataset 11 | from datasets.dataloader import CreateDataLoader, get_data_generator 12 | 13 | from models.base_model import create_model 14 | 15 | from utils.util import seed_everything 16 | 17 | ############ START: all Opt classes ############ 18 | 19 | class BaseOpt(object): 20 | def __init__(self, gpu_ids=0, seed=None): 21 | # important args 22 | self.isTrain = False 23 | self.gpu_ids = [gpu_ids] 24 | # self.device = f'cuda:{gpu_ids}' 25 | self.device = 'cuda' 26 | self.debug = '0' 27 | 28 | # default args 29 | self.serial_batches = False 30 | self.nThreads = 4 31 | self.distributed = False 32 | 33 | # hyperparams 34 | self.batch_size = 1 35 | 36 | # dataset args 37 | self.max_dataset_size = 10000000 38 | self.trunc_thres = 0.2 39 | 40 | if seed is not None: 41 | seed_everything(seed) 42 | 43 | self.phase = 'test' 44 | 45 | def name(self): 46 | 47 | return 'BaseOpt' 48 | 49 | class VQVAEOpt(BaseOpt): 50 | def __init__(self, gpu_ids=0, seed=None): 51 | super().__init__(gpu_ids) 52 | 53 | # some other custom args here 54 | 55 | print(f'[*] {self.name()} initialized.') 56 | 57 | def name(self): 58 | return 'VQVAETestOpt' 59 | 60 | class SDFusionOpt(BaseOpt): 61 | def __init__(self, gpu_ids=0, seed=None): 62 | super().__init__(gpu_ids, seed=seed) 63 | 64 | # some other custom args here 65 | 66 | ## building net 67 | # opt.res = 128 68 | # opt.dataset_mode = 'buildingnet' 69 | # opt.cat = 'all' 70 | print(f'[*] {self.name()} initialized.') 71 | 72 | def init_dset_args(self, dataset_mode='snet', cat='all', res=64): 73 | # dataset - snet 74 | self.dataroot = None 75 | self.cached_dir = None 76 | self.ratio = 1.0 77 | self.res = res 78 | self.dataset_mode = dataset_mode 79 | self.cat = cat 80 | 81 | def init_model_args( 82 | self, 83 | ckpt_path='saved_ckpt/sdfusion-snet-all.pth', 84 | vq_ckpt_path='saved_ckpt/vqvae-snet-all.pth', 85 | ): 86 | self.model = 'sdfusion' 87 | self.df_cfg = 'configs/sdfusion_snet.yaml' 88 | self.ckpt = ckpt_path 89 | 90 | self.vq_model = 'vqvae' 91 | self.vq_cfg = 'configs/vqvae_snet.yaml' 92 | self.vq_ckpt = vq_ckpt_path 93 | self.vq_dset = 'snet' 94 | self.vq_cat = 'all' 95 | 96 | def name(self): 97 | return 'SDFusionTestOption' 98 | 99 | 100 | class SDFusionText2ShapeOpt(BaseOpt): 101 | def __init__(self, gpu_ids=0, seed=None): 102 | super().__init__(gpu_ids, seed=seed) 103 | 104 | # some other custom args here 105 | print(f'[*] {self.name()} initialized.') 106 | 107 | def init_dset_args(self, dataset_mode='text2shape', cat='all', res=64): 108 | # dataset - snet 109 | self.dataroot = None 110 | self.cached_dir = None 111 | self.ratio = 1.0 112 | self.res = res 113 | self.dataset_mode = dataset_mode 114 | self.cat = cat 115 | 116 | def init_model_args( 117 | self, 118 | ckpt_path='saved_ckpt/sdfusion-txt2shape.pth', 119 | vq_ckpt_path='saved_ckpt/vqvae-snet-all.pth', 120 | ): 121 | self.model = 'sdfusion-txt2shape' 122 | self.df_cfg = 'configs/sdfusion-txt2shape.yaml' 123 | self.ckpt = ckpt_path 124 | 125 | self.vq_model = 'vqvae' 126 | self.vq_cfg = 'configs/vqvae_snet.yaml' 127 | self.vq_ckpt = vq_ckpt_path 128 | self.vq_dset = 'snet' 129 | self.vq_cat = 'all' 130 | 131 | def name(self): 132 | return 'SDFusionText2ShapeOption' 133 | 134 | class SDFusionImage2ShapeOpt(BaseOpt): 135 | def __init__(self, gpu_ids=0, seed=None): 136 | super().__init__(gpu_ids, 
seed=seed) 137 | 138 | # some other custom args here 139 | print(f'[*] {self.name()} initialized.') 140 | 141 | def init_dset_args(self, dataset_mode='pix3d_img2shape', cat='all', res=64): 142 | # dataset - snet 143 | self.dataroot = None 144 | self.cached_dir = None 145 | self.ratio = 1.0 146 | self.res = res 147 | self.dataset_mode = dataset_mode 148 | self.cat = cat 149 | 150 | def init_model_args( 151 | self, 152 | ckpt_path='saved_ckpt/sdfusion-img2shape.pth', 153 | vq_ckpt_path='saved_ckpt/vqvae-snet-all.pth', 154 | ): 155 | self.model = 'sdfusion-img2shape' 156 | self.df_cfg = 'configs/sdfusion-img2shape.yaml' 157 | self.ckpt = ckpt_path 158 | 159 | self.vq_model = 'vqvae' 160 | self.vq_cfg = 'configs/vqvae_snet.yaml' 161 | self.vq_ckpt = vq_ckpt_path 162 | self.vq_dset = 'snet' 163 | self.vq_cat = 'all' 164 | 165 | def name(self): 166 | return 'SDFusionImage2ShapeOption' 167 | 168 | 169 | ############ END: all Opt classes ############ 170 | 171 | # get partial shape from range 172 | def get_partial_shape(shape, xyz_dict, z=None): 173 | """ 174 | args: 175 | shape: input sdf. (B, 1, H, W, D) 176 | xyz_dict: user-specified range. 177 | x: left to right 178 | y: bottom to top 179 | z: front to back 180 | """ 181 | x = shape 182 | device = x.device 183 | (x_min, x_max) = xyz_dict['x'] 184 | (y_min, y_max) = xyz_dict['y'] 185 | (z_min, z_max) = xyz_dict['z'] 186 | 187 | # clamp to [-1, 1] 188 | x_min, x_max = max(-1, x_min), min(1, x_max) 189 | y_min, y_max = max(-1, y_min), min(1, y_max) 190 | z_min, z_max = max(-1, z_min), min(1, z_max) 191 | 192 | B, _, H, W, D = x.shape # assume D = H = W 193 | 194 | x_st = int( (x_min - (-1))/2 * H ) 195 | x_ed = int( (x_max - (-1))/2 * H ) 196 | 197 | y_st = int( (y_min - (-1))/2 * W ) 198 | y_ed = int( (y_max - (-1))/2 * W ) 199 | 200 | z_st = int( (z_min - (-1))/2 * D ) 201 | z_ed = int( (z_max - (-1))/2 * D ) 202 | 203 | # print('x: ', xyz_dict['x'], x_st, x_ed) 204 | # print('y: ', xyz_dict['y'], y_st, y_ed) 205 | # print('z: ', xyz_dict['z'], z_st, z_ed) 206 | 207 | # where to keep 208 | x_mask = torch.ones(B, 1, H, W, D).bool().to(device) 209 | x_mask[:, :, :x_st, :, :] = False 210 | x_mask[:, :, x_ed:, :, :] = False 211 | 212 | x_mask[:, :, :, :y_st, :] = False 213 | x_mask[:, :, :, y_ed:, :] = False 214 | 215 | x_mask[:, :, :, :, :z_st] = False 216 | x_mask[:, :, :, :, z_ed:] = False 217 | 218 | shape_part = x.clone() 219 | shape_missing = x.clone() 220 | shape_part[~x_mask] = 0.2 # T-SDF 221 | shape_missing[x_mask] = 0.2 222 | 223 | ret = { 224 | 'shape_part': shape_part, 225 | 'shape_missing': shape_missing, 226 | 'shape_mask': x_mask, 227 | } 228 | 229 | if z is not None: 230 | B, _, zH, zW, zD = z.shape # assume D = H = W 231 | 232 | x_st = int( (x_min - (-1))/2 * zH ) 233 | x_ed = int( (x_max - (-1))/2 * zH ) 234 | 235 | y_st = int( (y_min - (-1))/2 * zW ) 236 | y_ed = int( (y_max - (-1))/2 * zW ) 237 | 238 | z_st = int( (z_min - (-1))/2 * zD ) 239 | z_ed = int( (z_max - (-1))/2 * zD ) 240 | 241 | # where to keep 242 | z_mask = torch.ones(B, 1, zH, zW, zD).to(device) 243 | z_mask[:, :, :x_st, :, :] = 0. 244 | z_mask[:, :, x_ed:, :, :] = 0. 245 | 246 | z_mask[:, :, :, :y_st, :] = 0. 247 | z_mask[:, :, :, y_ed:, :] = 0. 248 | 249 | z_mask[:, :, :, :, :z_st] = 0. 250 | z_mask[:, :, :, :, z_ed:] = 0. 
251 | 252 | ret['z_mask'] = z_mask 253 | 254 | return ret 255 | 256 | # for img2shape 257 | # https://stackoverflow.com/questions/31400769/bounding-box-of-numpy-array 258 | def mask2bbox(mask): 259 | # mask: w x h 260 | rows = np.any(mask, axis=1) 261 | cols = np.any(mask, axis=0) 262 | rmin, rmax = np.where(rows)[0][[0, -1]] 263 | cmin, cmax = np.where(cols)[0][[0, -1]] 264 | # return rmin, rmax, cmin, cmax 265 | return cmin, rmin, cmax, rmax 266 | 267 | # ref: pix2vox: https://github.com/hzxie/Pix2Vox/blob/f1b82823e79d4afeedddfadb3da0940bcf1c536d/utils/data_transforms.py 268 | def crop_square(img, bbox, img_size_h=256, img_size_w=256): 269 | # from pix2vox 270 | img_height, img_width, c = img.shape 271 | 272 | x0, y0, x1, y1 = bbox 273 | 274 | # Calculate the size of bounding boxes 275 | bbox_width = x1 - x0 276 | bbox_height = y1 - y0 277 | bbox_x_mid = (x0 + x1) * .5 278 | bbox_y_mid = (y0 + y1) * .5 279 | 280 | # Make the crop area as a square 281 | square_object_size = max(bbox_width, bbox_height) 282 | x_left = int(bbox_x_mid - square_object_size * .5) 283 | x_right = int(bbox_x_mid + square_object_size * .5) 284 | y_top = int(bbox_y_mid - square_object_size * .5) 285 | y_bottom = int(bbox_y_mid + square_object_size * .5) 286 | 287 | # If the crop position is out of the image, fix it with padding 288 | pad_x_left = 0 289 | if x_left < 0: 290 | pad_x_left = -x_left 291 | x_left = 0 292 | pad_x_right = 0 293 | if x_right >= img_width: 294 | pad_x_right = x_right - img_width + 1 295 | x_right = img_width - 1 296 | pad_y_top = 0 297 | if y_top < 0: 298 | pad_y_top = -y_top 299 | y_top = 0 300 | pad_y_bottom = 0 301 | if y_bottom >= img_height: 302 | pad_y_bottom = y_bottom - img_height + 1 303 | y_bottom = img_height - 1 304 | 305 | # Padding the image and resize the image 306 | processed_image = np.pad(img[y_top:y_bottom + 1, x_left:x_right + 1], 307 | ((pad_y_top, pad_y_bottom), (pad_x_left, pad_x_right), (0, 0)), 308 | mode='edge') 309 | 310 | pil_img = Image.fromarray(processed_image) 311 | pil_img = pil_img.resize((img_size_w, img_size_h)) 312 | # processed_image = cv2.resize(processed_image, (img_size_w, img_size_h)) 313 | 314 | return pil_img 315 | 316 | 317 | def preprocess_image(image, mask): 318 | if type(image) is str: 319 | img_np = np.array(Image.open(image).convert('RGB')) 320 | else: 321 | img_np = image 322 | if type(mask) is str: 323 | mask_np = np.array(Image.open(mask).convert('1')) 324 | else: 325 | mask_np = mask 326 | 327 | # get bbox from mask 328 | x0, y0, x1, y1 = mask2bbox(mask_np) 329 | bbox = [x0, y0, x1, y1] 330 | 331 | r = 0.7 332 | img_comp = img_np * mask_np[:, :, None] + (1 - mask_np[:, :, None]) * (r*255 + (1 - r) * img_np) 333 | img_comp = crop_square(img_comp.astype(np.uint8), bbox) 334 | 335 | img_clean = img_np * mask_np[:, :, None] 336 | img_clean = crop_square(img_clean.astype(np.uint8), bbox) 337 | 338 | return img_comp, img_clean 339 | 340 | -------------------------------------------------------------------------------- /model/networks/diffusion_shape/diff_utils/demo_util.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from einops import rearrange 4 | from omegaconf import OmegaConf 5 | from PIL import Image 6 | 7 | import torch 8 | import torchvision.utils as vutils 9 | 10 | from datasets.base_dataset import CreateDataset 11 | from datasets.dataloader import CreateDataLoader, get_data_generator 12 | 13 | from models.base_model import create_model 14 | 15 | from utils.util 
import seed_everything 16 | 17 | ############ START: all Opt classes ############ 18 | 19 | class BaseOpt(object): 20 | def __init__(self, gpu_ids=0, seed=None): 21 | # important args 22 | self.isTrain = False 23 | self.gpu_ids = [gpu_ids] 24 | # self.device = f'cuda:{gpu_ids}' 25 | self.device = 'cuda' 26 | self.debug = '0' 27 | 28 | # default args 29 | self.serial_batches = False 30 | self.nThreads = 4 31 | self.distributed = False 32 | 33 | # hyperparams 34 | self.batch_size = 1 35 | 36 | # dataset args 37 | self.max_dataset_size = 10000000 38 | self.trunc_thres = 0.2 39 | 40 | if seed is not None: 41 | seed_everything(seed) 42 | 43 | self.phase = 'test' 44 | 45 | def name(self): 46 | 47 | return 'BaseOpt' 48 | 49 | class VQVAEOpt(BaseOpt): 50 | def __init__(self, gpu_ids=0, seed=None): 51 | super().__init__(gpu_ids) 52 | 53 | # some other custom args here 54 | 55 | print(f'[*] {self.name()} initialized.') 56 | 57 | def name(self): 58 | return 'VQVAETestOpt' 59 | 60 | class SDFusionOpt(BaseOpt): 61 | def __init__(self, gpu_ids=0, seed=None): 62 | super().__init__(gpu_ids, seed=seed) 63 | 64 | # some other custom args here 65 | 66 | ## building net 67 | # opt.res = 128 68 | # opt.dataset_mode = 'buildingnet' 69 | # opt.cat = 'all' 70 | print(f'[*] {self.name()} initialized.') 71 | 72 | def init_dset_args(self, dataset_mode='snet', cat='all', res=64): 73 | # dataset - snet 74 | self.dataroot = None 75 | self.cached_dir = None 76 | self.ratio = 1.0 77 | self.res = res 78 | self.dataset_mode = dataset_mode 79 | self.cat = cat 80 | 81 | def init_model_args( 82 | self, 83 | ckpt_path='saved_ckpt/sdfusion-snet-all.pth', 84 | vq_ckpt_path='saved_ckpt/vqvae-snet-all.pth', 85 | ): 86 | self.model = 'sdfusion' 87 | self.df_cfg = 'configs/sdfusion_snet.yaml' 88 | self.ckpt = ckpt_path 89 | 90 | self.vq_model = 'vqvae' 91 | self.vq_cfg = 'configs/vqvae_snet.yaml' 92 | self.vq_ckpt = vq_ckpt_path 93 | self.vq_dset = 'snet' 94 | self.vq_cat = 'all' 95 | 96 | def name(self): 97 | return 'SDFusionTestOption' 98 | 99 | 100 | class SDFusionText2ShapeOpt(BaseOpt): 101 | def __init__(self, gpu_ids=0, seed=None): 102 | super().__init__(gpu_ids, seed=seed) 103 | 104 | # some other custom args here 105 | print(f'[*] {self.name()} initialized.') 106 | 107 | def init_dset_args(self, dataset_mode='text2shape', cat='all', res=64): 108 | # dataset - snet 109 | self.dataroot = None 110 | self.cached_dir = None 111 | self.ratio = 1.0 112 | self.res = res 113 | self.dataset_mode = dataset_mode 114 | self.cat = cat 115 | 116 | def init_model_args( 117 | self, 118 | ckpt_path='saved_ckpt/sdfusion-txt2shape.pth', 119 | vq_ckpt_path='saved_ckpt/vqvae-snet-all.pth', 120 | ): 121 | self.model = 'sdfusion-txt2shape' 122 | self.df_cfg = 'configs/sdfusion-txt2shape.yaml' 123 | self.ckpt = ckpt_path 124 | 125 | self.vq_model = 'vqvae' 126 | self.vq_cfg = 'configs/vqvae_snet.yaml' 127 | self.vq_ckpt = vq_ckpt_path 128 | self.vq_dset = 'snet' 129 | self.vq_cat = 'all' 130 | 131 | def name(self): 132 | return 'SDFusionText2ShapeOption' 133 | 134 | class SDFusionImage2ShapeOpt(BaseOpt): 135 | def __init__(self, gpu_ids=0, seed=None): 136 | super().__init__(gpu_ids, seed=seed) 137 | 138 | # some other custom args here 139 | print(f'[*] {self.name()} initialized.') 140 | 141 | def init_dset_args(self, dataset_mode='pix3d_img2shape', cat='all', res=64): 142 | # dataset - snet 143 | self.dataroot = None 144 | self.cached_dir = None 145 | self.ratio = 1.0 146 | self.res = res 147 | self.dataset_mode = dataset_mode 148 | self.cat = cat 
149 | 150 | def init_model_args( 151 | self, 152 | ckpt_path='saved_ckpt/sdfusion-img2shape.pth', 153 | vq_ckpt_path='saved_ckpt/vqvae-snet-all.pth', 154 | ): 155 | self.model = 'sdfusion-img2shape' 156 | self.df_cfg = 'configs/sdfusion-img2shape.yaml' 157 | self.ckpt = ckpt_path 158 | 159 | self.vq_model = 'vqvae' 160 | self.vq_cfg = 'configs/vqvae_snet.yaml' 161 | self.vq_ckpt = vq_ckpt_path 162 | self.vq_dset = 'snet' 163 | self.vq_cat = 'all' 164 | 165 | def name(self): 166 | return 'SDFusionImage2ShapeOption' 167 | 168 | 169 | ############ END: all Opt classes ############ 170 | 171 | # get partial shape from range 172 | def get_partial_shape(shape, xyz_dict, z=None): 173 | """ 174 | args: 175 | shape: input sdf. (B, 1, H, W, D) 176 | xyz_dict: user-specified range. 177 | x: left to right 178 | y: bottom to top 179 | z: front to back 180 | """ 181 | x = shape 182 | device = x.device 183 | (x_min, x_max) = xyz_dict['x'] 184 | (y_min, y_max) = xyz_dict['y'] 185 | (z_min, z_max) = xyz_dict['z'] 186 | 187 | # clamp to [-1, 1] 188 | x_min, x_max = max(-1, x_min), min(1, x_max) 189 | y_min, y_max = max(-1, y_min), min(1, y_max) 190 | z_min, z_max = max(-1, z_min), min(1, z_max) 191 | 192 | B, _, H, W, D = x.shape # assume D = H = W 193 | 194 | x_st = int( (x_min - (-1))/2 * H ) 195 | x_ed = int( (x_max - (-1))/2 * H ) 196 | 197 | y_st = int( (y_min - (-1))/2 * W ) 198 | y_ed = int( (y_max - (-1))/2 * W ) 199 | 200 | z_st = int( (z_min - (-1))/2 * D ) 201 | z_ed = int( (z_max - (-1))/2 * D ) 202 | 203 | # print('x: ', xyz_dict['x'], x_st, x_ed) 204 | # print('y: ', xyz_dict['y'], y_st, y_ed) 205 | # print('z: ', xyz_dict['z'], z_st, z_ed) 206 | 207 | # where to keep 208 | x_mask = torch.ones(B, 1, H, W, D).bool().to(device) 209 | x_mask[:, :, :x_st, :, :] = False 210 | x_mask[:, :, x_ed:, :, :] = False 211 | 212 | x_mask[:, :, :, :y_st, :] = False 213 | x_mask[:, :, :, y_ed:, :] = False 214 | 215 | x_mask[:, :, :, :, :z_st] = False 216 | x_mask[:, :, :, :, z_ed:] = False 217 | 218 | shape_part = x.clone() 219 | shape_missing = x.clone() 220 | shape_part[~x_mask] = 0.2 # T-SDF 221 | shape_missing[x_mask] = 0.2 222 | 223 | ret = { 224 | 'shape_part': shape_part, 225 | 'shape_missing': shape_missing, 226 | 'shape_mask': x_mask, 227 | } 228 | 229 | if z is not None: 230 | B, _, zH, zW, zD = z.shape # assume D = H = W 231 | 232 | x_st = int( (x_min - (-1))/2 * zH ) 233 | x_ed = int( (x_max - (-1))/2 * zH ) 234 | 235 | y_st = int( (y_min - (-1))/2 * zW ) 236 | y_ed = int( (y_max - (-1))/2 * zW ) 237 | 238 | z_st = int( (z_min - (-1))/2 * zD ) 239 | z_ed = int( (z_max - (-1))/2 * zD ) 240 | 241 | # where to keep 242 | z_mask = torch.ones(B, 1, zH, zW, zD).to(device) 243 | z_mask[:, :, :x_st, :, :] = 0. 244 | z_mask[:, :, x_ed:, :, :] = 0. 245 | 246 | z_mask[:, :, :, :y_st, :] = 0. 247 | z_mask[:, :, :, y_ed:, :] = 0. 248 | 249 | z_mask[:, :, :, :, :z_st] = 0. 250 | z_mask[:, :, :, :, z_ed:] = 0. 
251 | 252 | ret['z_mask'] = z_mask 253 | 254 | return ret 255 | 256 | # for img2shape 257 | # https://stackoverflow.com/questions/31400769/bounding-box-of-numpy-array 258 | def mask2bbox(mask): 259 | # mask: w x h 260 | rows = np.any(mask, axis=1) 261 | cols = np.any(mask, axis=0) 262 | rmin, rmax = np.where(rows)[0][[0, -1]] 263 | cmin, cmax = np.where(cols)[0][[0, -1]] 264 | # return rmin, rmax, cmin, cmax 265 | return cmin, rmin, cmax, rmax 266 | 267 | # ref: pix2vox: https://github.com/hzxie/Pix2Vox/blob/f1b82823e79d4afeedddfadb3da0940bcf1c536d/utils/data_transforms.py 268 | def crop_square(img, bbox, img_size_h=256, img_size_w=256): 269 | # from pix2vox 270 | img_height, img_width, c = img.shape 271 | 272 | x0, y0, x1, y1 = bbox 273 | 274 | # Calculate the size of bounding boxes 275 | bbox_width = x1 - x0 276 | bbox_height = y1 - y0 277 | bbox_x_mid = (x0 + x1) * .5 278 | bbox_y_mid = (y0 + y1) * .5 279 | 280 | # Make the crop area as a square 281 | square_object_size = max(bbox_width, bbox_height) 282 | x_left = int(bbox_x_mid - square_object_size * .5) 283 | x_right = int(bbox_x_mid + square_object_size * .5) 284 | y_top = int(bbox_y_mid - square_object_size * .5) 285 | y_bottom = int(bbox_y_mid + square_object_size * .5) 286 | 287 | # If the crop position is out of the image, fix it with padding 288 | pad_x_left = 0 289 | if x_left < 0: 290 | pad_x_left = -x_left 291 | x_left = 0 292 | pad_x_right = 0 293 | if x_right >= img_width: 294 | pad_x_right = x_right - img_width + 1 295 | x_right = img_width - 1 296 | pad_y_top = 0 297 | if y_top < 0: 298 | pad_y_top = -y_top 299 | y_top = 0 300 | pad_y_bottom = 0 301 | if y_bottom >= img_height: 302 | pad_y_bottom = y_bottom - img_height + 1 303 | y_bottom = img_height - 1 304 | 305 | # Padding the image and resize the image 306 | processed_image = np.pad(img[y_top:y_bottom + 1, x_left:x_right + 1], 307 | ((pad_y_top, pad_y_bottom), (pad_x_left, pad_x_right), (0, 0)), 308 | mode='edge') 309 | 310 | pil_img = Image.fromarray(processed_image) 311 | pil_img = pil_img.resize((img_size_w, img_size_h)) 312 | # processed_image = cv2.resize(processed_image, (img_size_w, img_size_h)) 313 | 314 | return pil_img 315 | 316 | 317 | def preprocess_image(image, mask): 318 | if type(image) is str: 319 | img_np = np.array(Image.open(image).convert('RGB')) 320 | else: 321 | img_np = image 322 | if type(mask) is str: 323 | mask_np = np.array(Image.open(mask).convert('1')) 324 | else: 325 | mask_np = mask 326 | 327 | # get bbox from mask 328 | x0, y0, x1, y1 = mask2bbox(mask_np) 329 | bbox = [x0, y0, x1, y1] 330 | 331 | r = 0.7 332 | img_comp = img_np * mask_np[:, :, None] + (1 - mask_np[:, :, None]) * (r*255 + (1 - r) * img_np) 333 | img_comp = crop_square(img_comp.astype(np.uint8), bbox) 334 | 335 | img_clean = img_np * mask_np[:, :, None] 336 | img_clean = crop_square(img_clean.astype(np.uint8), bbox) 337 | 338 | return img_comp, img_clean 339 | 340 | --------------------------------------------------------------------------------
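The scene-graph convolution in /model/graph.py is easiest to follow with concrete shapes in mind. Below is a minimal, hypothetical usage sketch: the toy graph, the dimensions, and the variable names are invented for illustration, and only the import path and the forward() signature come from the file above.

import torch
from model.graph import GraphTripleConvNet

# 4 object nodes, 3 triples; dimensions chosen arbitrarily for the example
num_objs, num_triples = 4, 3
obj_dim, pred_dim = 256, 128

obj_vecs = torch.randn(num_objs, obj_dim)       # one embedding per object node
pred_vecs = torch.randn(num_triples, pred_dim)  # one embedding per predicate (edge)
# edges[k] = [i, j] encodes the triple (obj_vecs[i], pred_vecs[k], obj_vecs[j])
edges = torch.tensor([[0, 1], [1, 2], [2, 3]], dtype=torch.long)

gconv_net = GraphTripleConvNet(input_dim_obj=obj_dim, input_dim_pred=pred_dim,
                               num_layers=2, hidden_dim=512,
                               pooling='avg', residual=True)
new_obj_vecs, new_pred_vecs = gconv_net(obj_vecs, pred_vecs, edges)
# with output_dim=None the object dimension is preserved:
# new_obj_vecs: (4, 256), new_pred_vecs: (3, 128)

With pooling='wAvg', the per-triple subject and object messages are additionally weighted by WeightNetGCN before being scatter-added into the per-object slots and normalized by the summed weights.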
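get_partial_shape in demo_util.py prepares a shape-completion input by overwriting the SDF outside a user-specified box with the truncation value. The following is a sketch under the assumption that the module's top-level imports (datasets.*, models.*, utils.*) resolve in your environment; the tensor and the ranges are made up.

import torch
from model.diff_utils.demo_util import get_partial_shape

sdf = torch.rand(1, 1, 64, 64, 64) * 0.4 - 0.2         # fake truncated SDF in [-0.2, 0.2]
xyz_dict = {'x': (-1, 0), 'y': (-1, 1), 'z': (-1, 1)}  # keep only the left half

ret = get_partial_shape(sdf, xyz_dict)
# ret['shape_part']   : voxels outside the box overwritten with the truncation value 0.2
# ret['shape_missing']: the complement (voxels inside the box overwritten with 0.2)
# ret['shape_mask']   : boolean mask of the kept voxels, shape (1, 1, 64, 64, 64)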
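The image-conditioning helpers (mask2bbox, crop_square, preprocess_image) turn a photo plus a binary object mask into two square 256x256 crops: the object composited over a brightened background and the object on a black background. A hypothetical call, with placeholder file paths:

from model.diff_utils.demo_util import preprocess_image

# 'photo.png' and 'mask.png' are placeholder paths; the mask must match the photo's size
img_comp, img_clean = preprocess_image('photo.png', 'mask.png')
# img_comp : object blended over a whitened background (r = 0.7), square-cropped to 256x256
# img_clean: object on a black background, square-cropped to 256x256
img_comp.save('photo_comp.png')   # both return values are PIL images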