├── models ├── __init__.py ├── networks │ ├── __init__.py │ ├── diffusion_networks │ │ ├── __init__.py │ │ ├── utils │ │ │ ├── __init__.py │ │ │ ├── spmm.py │ │ │ └── scatter.py │ │ ├── graph_unet_union.py │ │ ├── graph_unet_lr.py │ │ └── graph_unet_hr.py │ ├── dualoctree_networks │ │ ├── __init__.py │ │ ├── utils │ │ │ ├── __init__.py │ │ │ ├── spmm.py │ │ │ └── scatter.py │ │ ├── modules.py │ │ ├── distributions.py │ │ ├── quantizer.py │ │ ├── mpu.py │ │ └── loss.py │ ├── clip_networks │ │ └── network.py │ └── bert_networks │ │ └── network.py ├── losses.py ├── model_utils.py └── base_model.py ├── utils ├── __init__.py ├── render │ ├── .DS_Store │ ├── shades │ │ ├── mesh.frag │ │ └── mesh.vert │ ├── math.py │ ├── render.py │ └── render_utils.py ├── render_utils.py ├── visualizer.py └── distributed.py ├── metrics ├── __init__.py ├── pytorch_structural_losses │ ├── .gitignore │ ├── __init__.py │ ├── src │ │ ├── nndistance.cuh │ │ ├── approxmatch.cuh │ │ ├── utils.hpp │ │ ├── nndistance.cu │ │ └── structural_loss.cpp │ ├── pybind │ │ ├── bind.cpp │ │ └── extern.hpp │ ├── setup.py │ ├── nn_distance.py │ ├── match_cost.py │ └── Makefile ├── 1-NNA.py ├── generate_dataset_for_fid.py ├── cov_mmd.py ├── diversity.py ├── calc_fid.py ├── generate_pointclouds.py ├── generate_synth_image.py ├── generate_dataset_pointclouds.py ├── compute_metrics.py └── evaluation_metrics.py ├── options ├── __init__.py ├── train_options.py └── base_options.py ├── datasets ├── __init__.py ├── utils.py ├── dataloader.py ├── sampler.py ├── shapenet_utils.py └── dualoctree_snet.py ├── solver ├── __init__.py └── dataset.py ├── assets └── teaser.png ├── requirements.txt ├── .gitignore ├── configs ├── octfusion_snet_uncond.yaml ├── octfusion_snet_cond.yaml ├── octfusion_obja_uncond.yaml ├── vae_snet_eval.yaml ├── vae_snet_eval_depth984.yaml ├── vae_snet_train.yaml ├── vae_obja_eval.yaml ├── vae_obja_train.yaml └── vae_obja_eval_depth864.yaml ├── tools └── gen_split.py ├── scripts ├── run_snet_vae.sh ├── run_snet_cond.sh └── run_snet_uncond.sh ├── README.md └── train.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/networks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/networks/diffusion_networks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/networks/diffusion_networks/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/networks/dualoctree_networks/__init__.py: 
-------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/networks/dualoctree_networks/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .dualoctree_snet import get_shapenet_dataset 2 | -------------------------------------------------------------------------------- /solver/__init__.py: -------------------------------------------------------------------------------- 1 | from . import dataset 2 | from .dataset import Dataset 3 | -------------------------------------------------------------------------------- /metrics/pytorch_structural_losses/.gitignore: -------------------------------------------------------------------------------- 1 | PyTorchStructuralLosses.egg-info/ 2 | -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/octree-nn/octfusion/HEAD/assets/teaser.png -------------------------------------------------------------------------------- /utils/render/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/octree-nn/octfusion/HEAD/utils/render/.DS_Store -------------------------------------------------------------------------------- /metrics/pytorch_structural_losses/__init__.py: -------------------------------------------------------------------------------- 1 | #import torch 2 | 3 | #from MakePytorchBackend import AddGPU, Foo, ApproxMatch 4 | 5 | #from Add import add_gpu, approx_match 6 | 7 | -------------------------------------------------------------------------------- /utils/render/shades/mesh.frag: -------------------------------------------------------------------------------- 1 | #version 330 core 2 | 3 | in vec3 frag_position; 4 | in vec3 frag_normal; 5 | 6 | out vec4 frag_color; 7 | 8 | void main() 9 | { 10 | vec3 normal = normalize(frag_normal); 11 | 12 | frag_color = vec4(normal * 0.5 + 0.5, 1.0); 13 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ocnn==2.2.2 2 | h5py 3 | termcolor 4 | scipy 5 | einops 6 | tqdm 7 | matplotlib 8 | opencv-python 9 | PyMCubes 10 | imageio 11 | trimesh 12 | omegaconf 13 | tensorboard 14 | notebook 15 | numpy 16 | tqdm 17 | yacs 18 | scipy 19 | plyfile 20 | tensorboard 21 | scikit-image 22 | trimesh 23 | wget 24 | mesh2sdf 25 | matplotlib 26 | objaverse 27 | pandas 28 | kornia 29 | ftfy 30 | regex 31 | timm 32 | open3d -------------------------------------------------------------------------------- /metrics/pytorch_structural_losses/src/nndistance.cuh: -------------------------------------------------------------------------------- 1 | void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream); 2 | void nndistancegrad(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream); 3 | 
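Note: the two kernels declared in nndistance.cuh back the Chamfer-distance wrapper in metrics/pytorch_structural_losses/nn_distance.py and are reached from Python through the pybind11 bindings in the next file. Below is a minimal usage sketch, not part of the repository; it assumes the CUDA extension has already been built (via the provided setup.py/Makefile) and is importable under the metrics.StructuralLosses package name used by compute_metrics.py, and it follows the batch_size * num_points * 3 shape convention documented in nn_distance.py.

# Hedged sketch (not part of the repository): Chamfer distance via the
# nn_distance wrapper, assuming the built extension is importable as
# metrics.StructuralLosses, the path used by compute_metrics.py.
import torch
from metrics.StructuralLosses.nn_distance import nn_distance

sample = torch.rand(4, 2048, 3, device='cuda')  # batch_size * #dataset_points * 3
ref = torch.rand(4, 2048, 3, device='cuda')     # batch_size * #query_points * 3
dist_ab, dist_ba = nn_distance(sample, ref)     # per-point NN distances in both directions
chamfer = dist_ab.mean(dim=1) + dist_ba.mean(dim=1)  # one value per batch element, as in compute_metrics.py
print(chamfer.shape)  # torch.Size([4])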
-------------------------------------------------------------------------------- /metrics/pytorch_structural_losses/pybind/bind.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include "extern.hpp" 6 | 7 | namespace py = pybind11; 8 | 9 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ 10 | m.def("ApproxMatch", &ApproxMatch); 11 | m.def("MatchCost", &MatchCost); 12 | m.def("MatchCostGrad", &MatchCostGrad); 13 | m.def("NNDistance", &NNDistance); 14 | m.def("NNDistanceGrad", &NNDistanceGrad); 15 | } 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # extensions 2 | *.pyc 3 | *.tar.gz 4 | *.tar.xz 5 | *.tar 6 | *.zip 7 | *.gif 8 | *.rar 9 | LICENSE 10 | 11 | # cache 12 | __pycache__/ 13 | .ipynb*/ 14 | 15 | # checkpoints 16 | /saved_ckpt 17 | Tencent/ 18 | 19 | # logs 20 | /logs 21 | logs_home/ 22 | quant_logs*/ 23 | qual_logs*/ 24 | 25 | # data 26 | data 27 | data1 28 | demo_data/ 29 | *results*/ 30 | *airplane*/ 31 | *chair*/ 32 | *table*/ 33 | *car*/ 34 | *rifle*/ 35 | split_small/ 36 | *.pth 37 | *.pkl 38 | 39 | # bak 40 | bak*/ 41 | /.vscode 42 | /mytools -------------------------------------------------------------------------------- /metrics/pytorch_structural_losses/pybind/extern.hpp: -------------------------------------------------------------------------------- 1 | std::vector ApproxMatch(at::Tensor in_a, at::Tensor in_b); 2 | at::Tensor MatchCost(at::Tensor set_d, at::Tensor set_q, at::Tensor match); 3 | std::vector MatchCostGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor match); 4 | 5 | std::vector NNDistance(at::Tensor set_d, at::Tensor set_q); 6 | std::vector NNDistanceGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor idx1, at::Tensor idx2, at::Tensor grad_dist1, at::Tensor grad_dist2); 7 | -------------------------------------------------------------------------------- /metrics/pytorch_structural_losses/src/approxmatch.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | template 3 | void AddGPUKernel(Dtype *in_a, Dtype *in_b, Dtype *out_c, int N, 4 | cudaStream_t stream); 5 | */ 6 | void approxmatch(int b,int n,int m,const float * xyz1,const float * xyz2,float * match,float * temp, cudaStream_t stream); 7 | void matchcost(int b,int n,int m,const float * xyz1,const float * xyz2,float * match, float * out, cudaStream_t stream); 8 | void matchcostgrad(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * grad1,float * grad2, cudaStream_t stream); 9 | -------------------------------------------------------------------------------- /utils/render/shades/mesh.vert: -------------------------------------------------------------------------------- 1 | #version 330 core 2 | 3 | // Vertex Attributes 4 | layout(location = 0) in vec3 position; 5 | layout(location = NORMAL_LOC) in vec3 normal; 6 | layout(location = INST_M_LOC) in mat4 inst_m; 7 | 8 | // Uniforms 9 | uniform mat4 M; 10 | uniform mat4 V; 11 | uniform mat4 P; 12 | 13 | // Outputs 14 | out vec3 frag_position; 15 | out vec3 frag_normal; 16 | 17 | void main() 18 | { 19 | gl_Position = P * V * M * inst_m * vec4(position, 1); 20 | frag_position = vec3(M * inst_m * vec4(position, 1.0)); 21 | 22 | mat4 N = transpose(inverse(M * inst_m)); 23 | frag_normal = normalize(vec3(N * vec4(normal, 0.0))); 24 | } 
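Note: the ApproxMatch/MatchCost kernels declared in approxmatch.cuh and extern.hpp above are wrapped by metrics/pytorch_structural_losses/match_cost.py into a differentiable earth-mover-style matching cost. A hedged sketch of a direct call follows, not part of the repository; the import path again assumes the built extension is exposed under metrics.StructuralLosses (as in nn_distance.py and match_cost.py), and the shapes mirror the docstring in match_cost.py.

# Hedged sketch (not part of the repository): approximate-match / EMD-style
# cost, assuming the extension is importable as metrics.StructuralLosses.
import torch
from metrics.StructuralLosses.match_cost import match_cost

pred = torch.rand(8, 2048, 3, device='cuda', requires_grad=True)  # batch_size * #dataset_points * 3
target = torch.rand(8, 2048, 3, device='cuda')                    # batch_size * #query_points * 3
cost = match_cost(pred, target)  # one matching cost per batch element
loss = cost.mean()
loss.backward()                  # gradients flow back through MatchCostGrad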
-------------------------------------------------------------------------------- /metrics/pytorch_structural_losses/src/utils.hpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | class Formatter { 6 | public: 7 | Formatter() {} 8 | ~Formatter() {} 9 | 10 | template Formatter &operator<<(const Type &value) { 11 | stream_ << value; 12 | return *this; 13 | } 14 | 15 | std::string str() const { return stream_.str(); } 16 | operator std::string() const { return stream_.str(); } 17 | 18 | enum ConvertToString { to_str }; 19 | 20 | std::string operator>>(ConvertToString) { return stream_.str(); } 21 | 22 | private: 23 | std::stringstream stream_; 24 | Formatter(const Formatter &); 25 | Formatter &operator=(Formatter &); 26 | }; 27 | -------------------------------------------------------------------------------- /configs/octfusion_snet_uncond.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | linear_start: 0.00085 4 | linear_end: 0.012 5 | conditioning_key: None 6 | timesteps: 1000 7 | scale_factor: 0.18215 8 | 9 | unet: 10 | params: 11 | image_size: [16, 64] 12 | input_depth: [4, 6] 13 | unet_type: ["lr", "hr"] 14 | df_type: ["x0", "eps"] 15 | full_depth: 4 16 | input_channels: [8, 3] 17 | out_channels: [8, 3] 18 | model_channels: [64, 128] 19 | num_res_blocks: [[1, 1, 1], [1, 1, 0]] 20 | attention_resolutions: [2, 4] # 16, 8, 4 21 | channel_mult: [[1, 2, 4], [1, 2, 4]] 22 | # num_head_channels: 32 23 | num_heads: 4 24 | use_checkpoint: False 25 | 26 | # 3d 27 | dims: 3 28 | -------------------------------------------------------------------------------- /configs/octfusion_snet_cond.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | linear_start: 0.00085 4 | linear_end: 0.012 5 | conditioning_key: None 6 | timesteps: 1000 7 | scale_factor: 0.18215 8 | 9 | unet: 10 | params: 11 | image_size: [16, 64] 12 | input_depth: [4, 6] 13 | unet_type: ["lr", "hr"] 14 | df_type: ["x0", "eps"] 15 | full_depth: 4 16 | input_channels: [8, 3] 17 | out_channels: [8, 3] 18 | model_channels: [64, 128] 19 | num_res_blocks: [[1, 1, 1], [2, 2, 0]] 20 | attention_resolutions: [2, 4, 8] # 16, 8, 4 21 | channel_mult: [[1, 2, 4, 8], [1, 2, 4]] 22 | # num_head_channels: 32 23 | num_heads: 4 24 | use_checkpoint: False 25 | num_classes: 5 26 | # 3d 27 | dims: 3 28 | -------------------------------------------------------------------------------- /configs/octfusion_obja_uncond.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | linear_start: 0.00085 4 | linear_end: 0.012 5 | conditioning_key: None 6 | timesteps: 1000 7 | scale_factor: 0.18215 8 | 9 | unet: 10 | params: 11 | image_size: [16, 64, 256] 12 | input_depth: [4, 6, 8] 13 | unet_type: ["lr", "hr", "feature"] 14 | df_type: ["x0", "x0", "x0"] 15 | full_depth: 4 16 | input_channels: [8, 8, 3] 17 | out_channels: [8, 8, 3] 18 | model_channels: [64, 128, 64] 19 | num_res_blocks: [[1, 1, 1], [2, 2, 0], [1, 1, 1]] 20 | attention_resolutions: [2, 4] # 16, 8, 4 21 | channel_mult: [[1, 2, 4], [1, 2, 4], [1, 2, 4]] 22 | # num_head_channels: 32 23 | num_heads: 4 24 | use_checkpoint: True 25 | 26 | # 3d 27 | dims: 3 28 | -------------------------------------------------------------------------------- /metrics/1-NNA.py: -------------------------------------------------------------------------------- 1 | from pprint 
import pprint 2 | from metrics.evaluation_metrics import compute_cov_mmd, compute_1_nna 3 | import torch 4 | import time 5 | import numpy as np 6 | import os 7 | import argparse 8 | import sys 9 | 10 | 11 | sample_pcs = torch.load('chair_sample_pcs.pth') 12 | print(sample_pcs.shape) 13 | 14 | sample_pcs = sample_pcs.cuda().to(torch.float32) 15 | 16 | ref_pcs = torch.load('chair_ref_pcs.pth') 17 | print(ref_pcs.shape) 18 | 19 | ref_pcs = ref_pcs.cuda().to(torch.float32) 20 | 21 | print('##################################################################') 22 | 23 | results = compute_1_nna( 24 | sample_pcs[:ref_pcs.shape[0]], ref_pcs, batch_size = 256) 25 | results = {k: (v.cpu().detach().item() 26 | if not isinstance(v, float) else v) for k, v in results.items()} 27 | 28 | pprint(results) 29 | 30 | print('##################################################################') 31 | -------------------------------------------------------------------------------- /metrics/generate_dataset_for_fid.py: -------------------------------------------------------------------------------- 1 | from utils.render_utils import generate_image_for_fid 2 | import trimesh 3 | import os 4 | 5 | filelist = '/data/checkpoints/xiongbj/DualOctreeGNN-Pytorch-HR/data/ShapeNet/filelist/train_im_5.txt' 6 | mesh_dataset = '/data/public-datasets/ShapeNetCore.v1' 7 | image_path = '/data/checkpoints/xiongbj/DualOctreeGNN-Pytorch-HR/data/ShapeNet/fid_images' 8 | 9 | with open(filelist) as fid: 10 | lines = fid.readlines() 11 | 12 | for i, line in enumerate(lines): 13 | filename = line.split()[0] 14 | category = filename.split('/')[0] 15 | category_path = os.path.join(image_path, category) 16 | if not os.path.exists(category_path): os.makedirs(category_path) 17 | mesh_path = os.path.join(mesh_dataset, filename, 'model.obj') 18 | mesh = trimesh.load(mesh_path, force = 'mesh') 19 | generate_image_for_fid(mesh,category_path, i) 20 | print(f'The {i} th mesh finish rendering') 21 | -------------------------------------------------------------------------------- /models/losses.py: -------------------------------------------------------------------------------- 1 | # adapt from https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/vqperceptual.py 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class VQLoss(nn.Module): 7 | def __init__(self, codebook_weight=1.0): 8 | super().__init__() 9 | self.codebook_weight = codebook_weight 10 | 11 | def forward(self, codebook_loss, inputs, reconstructions, split="train"): 12 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 13 | 14 | nll_loss = rec_loss 15 | nll_loss = torch.mean(nll_loss) 16 | 17 | loss = nll_loss + self.codebook_weight * codebook_loss.mean() 18 | 19 | log = { 20 | "loss_total": loss.clone().detach().mean(), 21 | "loss_codebook": codebook_loss.detach().mean(), 22 | "loss_nll": nll_loss.detach().mean(), 23 | "loss_rec": rec_loss.detach().mean(), 24 | } 25 | 26 | return loss, log 27 | -------------------------------------------------------------------------------- /utils/render_utils.py: -------------------------------------------------------------------------------- 1 | from .render.render import render_mesh 2 | import torch 3 | import torchvision 4 | from .util import ensure_directory, scale_to_unit_sphere 5 | import os 6 | 7 | def render_one_mesh(mesh, i, j, mydir, render_resolution=1024): 8 | mesh = scale_to_unit_sphere(mesh) 9 | image = render_mesh(mesh, index=j, 
resolution=render_resolution)/255 10 | torchvision.utils.save_image(torch.from_numpy(image.copy()).permute( 11 | 2, 0, 1), f"{mydir}/{i}_{j}.png") 12 | 13 | # Inception v3 input size (299, 299) 14 | def generate_image_for_fid(mesh, mydir, i): 15 | render_resolution = 299 16 | mesh = scale_to_unit_sphere(mesh) 17 | for j in range(20): 18 | if os.path.exists(f"{mydir}/view_{j}/{i}.png"): 19 | continue 20 | ensure_directory(f"{mydir}/view_{j}") 21 | image = render_mesh(mesh, index=j, resolution=render_resolution)/255 22 | torchvision.utils.save_image(torch.from_numpy(image.copy()).permute( 23 | 2, 0, 1), f"{mydir}/view_{j}/{i}.png") -------------------------------------------------------------------------------- /metrics/cov_mmd.py: -------------------------------------------------------------------------------- 1 | from pprint import pprint 2 | from metrics.evaluation_metrics import compute_cov_mmd, compute_1_nna 3 | import torch 4 | import time 5 | import numpy as np 6 | import os 7 | import argparse 8 | import sys 9 | import pickle 10 | 11 | gpu_ids = 0 12 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_ids}" 13 | 14 | sample_pcs = torch.load('chair_sample_pcs.pth') 15 | print(sample_pcs.shape) 16 | 17 | sample_pcs = sample_pcs.cuda().to(torch.float32) 18 | 19 | ref_pcs = torch.load('chair_ref_pcs.pth') 20 | print(ref_pcs.shape) 21 | 22 | ref_pcs = ref_pcs.cuda().to(torch.float32) 23 | 24 | print('##################################################################') 25 | 26 | results = compute_cov_mmd(sample_pcs, ref_pcs, batch_size = 256) 27 | # results = compute_cov_mmd(sample_pcs[:256], ref_pcs[:256], batch_size = 256) 28 | results = {k: (v.cpu().detach().item() 29 | if not isinstance(v, float) else v) for k, v in results.items()} 30 | 31 | pprint(results) 32 | 33 | print('##################################################################') 34 | -------------------------------------------------------------------------------- /metrics/pytorch_structural_losses/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 3 | 4 | # Python interface 5 | setup( 6 | name='PyTorchStructuralLosses', 7 | version='0.1.0', 8 | install_requires=['torch'], 9 | packages=['StructuralLosses'], 10 | package_dir={'StructuralLosses': './'}, 11 | ext_modules=[ 12 | CUDAExtension( 13 | name='StructuralLossesBackend', 14 | include_dirs=['./'], 15 | sources=[ 16 | 'pybind/bind.cpp', 17 | ], 18 | libraries=['make_pytorch'], 19 | library_dirs=['objs'], 20 | # extra_compile_args=['-g'] 21 | ) 22 | ], 23 | cmdclass={'build_ext': BuildExtension}, 24 | author='Christopher B. 
Choy', 25 | author_email='chrischoy@ai.stanford.edu', 26 | description='Tutorial for Pytorch C++ Extension with a Makefile', 27 | keywords='Pytorch C++ Extension', 28 | url='https://github.com/chrischoy/MakePytorchPlusPlus', 29 | zip_safe=False, 30 | ) 31 | -------------------------------------------------------------------------------- /utils/render/math.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial.transform import Rotation 3 | 4 | PROJECTION_MATRIX = np.array( 5 | [[ 1.73205081, 0, 0, 0, ], 6 | [ 0, 1.73205081, 0, 0, ], 7 | [ 0, 0, -1.02020202, -0.2020202, ], 8 | [ 0, 0, -1, 0, ]], dtype=float) 9 | 10 | def get_rotation_matrix(angle, axis='y'): 11 | rotation = Rotation.from_euler(axis, angle, degrees=True) 12 | matrix = np.identity(4) 13 | matrix[:3, :3] = rotation.as_matrix() 14 | return matrix 15 | 16 | def get_camera_transform(camera_distance, rotation_y, rotation_x=0, project=False): 17 | camera_transform = np.identity(4) 18 | camera_transform[2, 3] = -camera_distance 19 | camera_transform = np.matmul(camera_transform, get_rotation_matrix(rotation_x, axis='x')) 20 | camera_transform = np.matmul(camera_transform, get_rotation_matrix(rotation_y, axis='y')) 21 | 22 | if project: 23 | camera_transform = np.matmul(PROJECTION_MATRIX, camera_transform) 24 | return camera_transform -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import ocnn 9 | import torch 10 | from ocnn.dataset import CollateBatch 11 | 12 | 13 | def collate_func(batch): 14 | 15 | collate_batch = CollateBatch(merge_points=False) 16 | output = collate_batch(batch) 17 | # output = ocnn.collate_octrees(batch) 18 | 19 | if 'pos' in output: 20 | batch_idx = torch.cat([torch.ones(pos.size(0), 1) * i 21 | for i, pos in enumerate(output['pos'])], dim=0) 22 | pos = torch.cat(output['pos'], dim=0) 23 | output['pos'] = torch.cat([pos, batch_idx], dim=1) 24 | 25 | for key in ['grad', 'sdf', 'occu', 'weight']: 26 | if key in output: 27 | output[key] = torch.cat(output[key], dim=0) 28 | 29 | if 'split_small' in output: 30 | output['split_small'] = torch.stack(output['split_small']) 31 | 32 | if 'split_large' in output: 33 | output['split_large'] = torch.cat(output['split_large'], dim = 0) 34 | 35 | return output 36 | -------------------------------------------------------------------------------- /models/model_utils.py: -------------------------------------------------------------------------------- 1 | from termcolor import colored 2 | import torch 3 | 4 | from models.networks.dualoctree_networks.graph_vae import GraphVAE 5 | 6 | def load_dualoctree(conf, ckpt, opt = None): 7 | flags = conf.model 8 | params = [flags.depth, flags.channel, flags.nout, 9 | flags.full_depth, flags.depth_stop, flags.depth_out, flags.use_checkpoint] 10 | if flags.name == 'graph_vae': 11 | params.append(flags.resblock_type) 12 | params.append(flags.bottleneck) 13 | params.append(flags.resblk_num) 14 | params.append(flags.code_channel) 15 | params.append(flags.embed_dim) 16 | dualoctree = GraphVAE(*params) 17 | 18 | if ckpt is not None: 19 | 
trained_dict = torch.load(ckpt, map_location='cuda') 20 | if ckpt.endswith('.solver.tar'): 21 | model_dict = trained_dict['model_dict'] 22 | else: 23 | model_dict = trained_dict 24 | 25 | if 'autoencoder' in model_dict: 26 | model_dict = model_dict['autoencoder'] 27 | 28 | dualoctree.load_state_dict(model_dict) 29 | 30 | print(colored('[*] DualOctree: weight successfully load from: %s' % ckpt, 'blue')) 31 | dualoctree.requires_grad = False 32 | 33 | dualoctree.to(opt.device) 34 | dualoctree.eval() 35 | return dualoctree -------------------------------------------------------------------------------- /metrics/diversity.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from compute_metrics import compute_metrics 4 | import trimesh 5 | import numpy as np 6 | import torch 7 | import pickle 8 | 9 | gpu_ids = 3 10 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_ids}" 11 | 12 | num_samples = 2048 13 | 14 | input_obj = 'chair_mesh_2t/4165.obj' 15 | 16 | def normalize_pc_to_unit_shpere(points): 17 | centroid = (np.max(points, axis=0) + np.min(points, axis=0))/2 18 | points -= centroid 19 | distances = np.linalg.norm(points, axis=1) 20 | points /= np.max(distances) 21 | return points 22 | 23 | mesh = trimesh.load(input_obj, force='mesh') 24 | points, idx = trimesh.sample.sample_surface(mesh, num_samples) 25 | points = points.astype(np.float32) 26 | 27 | points = normalize_pc_to_unit_shpere(points) 28 | 29 | points = torch.from_numpy(points) 30 | points = points.cuda().to(torch.float32) 31 | 32 | sample_pc = points 33 | 34 | ref_pcs = torch.load('chair_train_ref_pcs.pth') 35 | ref_pcs = ref_pcs.cuda().to(torch.float32) 36 | 37 | cd = compute_metrics(sample_pc, ref_pcs, batch_size = 256) 38 | 39 | with open('name.pkl', 'rb') as file: 40 | name = pickle.load(file) 41 | 42 | k = 3 43 | 44 | sorted_values, sorted_indices = torch.topk(cd.view(-1), k, largest=False) 45 | 46 | print(name[sorted_indices[0].item()]) 47 | print(name[sorted_indices[1].item()]) 48 | print(name[sorted_indices[2].item()]) 49 | -------------------------------------------------------------------------------- /metrics/calc_fid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from cleanfid import fid 3 | import os 4 | gpu_ids = 0 5 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_ids}" 6 | 7 | snc_synth_id_to_category_5 = { 8 | '02691156': 'airplane', '02958343': 'car', '03001627': 'chair', 9 | '04379243': 'table', 10 | '04090263': 'rifle' 11 | } 12 | 13 | category_to_snc_synth_id = {v:k for (k,v) in snc_synth_id_to_category_5.items()} 14 | 15 | category = 'airplane' 16 | cond = True 17 | synth_id = category_to_snc_synth_id[category] 18 | 19 | root_dir = "logs/airplane_union/uncond_1000epoch_lr2e-4" 20 | synthesis_path = f'{root_dir}/fid_images_{category}' 21 | 22 | dataset_path = f'data/ShapeNet/fid_images/{category}' 23 | 24 | views1 = os.listdir(synthesis_path) 25 | views2 = os.listdir(dataset_path) 26 | 27 | views1 = sorted(views1, key=lambda item:int(item[5:])) 28 | views2 = sorted(views2, key=lambda item:int(item[5:])) 29 | 30 | assert len(views1) == len(views2) 31 | num_views = len(views1) 32 | 33 | fid_sum = 0 34 | fid_dict = {} 35 | 36 | for i, (view1, view2) in enumerate(zip(views1, views2)): 37 | assert view1 == view2 38 | view1_path = os.path.join(synthesis_path, view1) 39 | view2_path = os.path.join(dataset_path, view2) 40 | fid_value = fid.compute_fid(view1_path, view2_path, batch_size = 128) 41 | 
fid_sum += fid_value 42 | fid_dict[view1] = fid_value 43 | print(f'Finish {i} th view') 44 | print(f'The FID of {i} th view is {fid_value}') 45 | 46 | fid_ave = fid_sum / num_views 47 | 48 | print('FID Value:', fid_ave) 49 | for k, v in fid_dict.items(): 50 | print(k, v) 51 | -------------------------------------------------------------------------------- /datasets/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.utils.data 3 | 4 | from datasets.sampler import InfSampler, DistributedInfSampler 5 | from builder import get_dataset 6 | from omegaconf import OmegaConf 7 | from termcolor import colored, cprint 8 | 9 | def get_data_generator(loader): 10 | while True: 11 | for data in loader: 12 | yield data 13 | 14 | def config_dataloader(opt): 15 | dualoctree_conf = OmegaConf.load(opt.vq_cfg) 16 | flags_train, flags_test = dualoctree_conf.data.train, dualoctree_conf.data.test 17 | flags_train.filelist = os.path.join(flags_train.filelist, f'train_{opt.category}.txt') 18 | flags_test.filelist = os.path.join(flags_test.filelist, f'test_{opt.category}.txt') 19 | 20 | train_loader = get_dataloader(opt,flags_train, drop_last = False) 21 | 22 | if not flags_test.disable: 23 | test_loader = get_dataloader(opt,flags_test, drop_last = False) 24 | 25 | train_ds, test_ds = train_loader.dataset, test_loader.dataset 26 | cprint('[*] # training images = %d' % len(train_ds), 'yellow') 27 | cprint('[*] # testing images = %d' % len(test_ds), 'yellow') 28 | return train_loader, test_loader 29 | 30 | def get_dataloader(opt, flags, drop_last = False): 31 | dataset, collate_fn = get_dataset(flags) 32 | 33 | if opt.distributed: 34 | sampler = DistributedInfSampler(dataset, shuffle=flags.shuffle) 35 | else: 36 | sampler = InfSampler(dataset, shuffle=flags.shuffle) 37 | 38 | data_loader = torch.utils.data.DataLoader( 39 | dataset, batch_size=flags.batch_size, num_workers=flags.num_workers, 40 | sampler=sampler, collate_fn=collate_fn, pin_memory=True, drop_last = drop_last) 41 | 42 | return data_loader 43 | -------------------------------------------------------------------------------- /metrics/generate_pointclouds.py: -------------------------------------------------------------------------------- 1 | import os 2 | import trimesh 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | def scale_to_unit_sphere(mesh): 7 | if isinstance(mesh, trimesh.Scene): 8 | mesh = mesh.dump().sum() 9 | vertices = mesh.vertices - mesh.bounding_box.centroid 10 | distances = np.linalg.norm(vertices, axis=1) 11 | vertices /= np.max(distances) 12 | return trimesh.Trimesh(vertices=vertices, faces=mesh.faces) 13 | 14 | def scale_to_unit_cube(mesh, padding=0.0): 15 | if isinstance(mesh, trimesh.Scene): 16 | mesh = mesh.dump().sum() 17 | 18 | vertices = mesh.vertices - mesh.bounding_box.centroid 19 | vertices *= 2 / np.max(mesh.bounding_box.extents) * (1 - padding) 20 | 21 | return trimesh.Trimesh(vertices=vertices, faces=mesh.faces) 22 | 23 | def sample_pts_from_mesh(mesh_folder, output_folder): 24 | 25 | if not os.path.exists(output_folder): 26 | os.makedirs(output_folder) 27 | 28 | print('-> Run sample_pts_from_mesh.') 29 | num_samples = 2048 30 | filenames = os.listdir(mesh_folder) 31 | for mesh_name in tqdm(filenames[:2000]): 32 | mesh_path = os.path.join(mesh_folder, mesh_name) 33 | mesh = trimesh.load(mesh_path, force='mesh') 34 | mesh = scale_to_unit_cube(mesh) 35 | points = mesh.sample(count = num_samples) 36 | filename_pts = os.path.join(output_folder, 
mesh_name[:-4]+'.npy') 37 | np.save(filename_pts, points.astype(np.float32)) 38 | 39 | if __name__ == '__main__': 40 | category = 'chair' 41 | mesh_folder = '/data/xiongbj/OctFusion-Cascade/chair_mesh_2t' 42 | output_folder = '/data/xiongbj/OctFusion-Cascade/chair_pointclouds' 43 | sample_pts_from_mesh(mesh_folder, output_folder) 44 | -------------------------------------------------------------------------------- /models/networks/clip_networks/network.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference: 3 | - https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/encoders/modules.py 4 | - https://github.com/openai/CLIP 5 | """ 6 | 7 | import kornia 8 | from einops import rearrange, repeat 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | from external.clip import clip 14 | 15 | class CLIPImageEncoder(nn.Module): 16 | def __init__( 17 | self, 18 | model="ViT-B/32", 19 | jit=False, 20 | device='cuda' if torch.cuda.is_available() else 'cpu', 21 | antialias=False, 22 | ): 23 | super().__init__() 24 | self.model, _ = clip.load(name=model, device=device, jit=jit) 25 | 26 | # self.model, self.preprocess = clip.load(name=model, device=device, jit=jit) 27 | self.model = self.model.float() # turns out this is important... 28 | 29 | self.antialias = antialias 30 | 31 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 32 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 33 | 34 | def preprocess(self, x): 35 | # normalize to [0,1] 36 | x = kornia.geometry.resize(x, (224, 224), 37 | interpolation='bicubic',align_corners=True, 38 | antialias=self.antialias) 39 | x = (x + 1.) / 2. 40 | # renormalize according to clip 41 | x = kornia.enhance.normalize(x, self.mean, self.std) 42 | return x 43 | 44 | def forward(self, x): 45 | # x is assumed to be in range [-1,1] 46 | return self.model.encode_image(self.preprocess(x)) 47 | -------------------------------------------------------------------------------- /metrics/pytorch_structural_losses/nn_distance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | # from extensions.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad 4 | from metrics.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad 5 | 6 | # Inherit from Function 7 | class NNDistanceFunction(Function): 8 | # Note that both forward and backward are @staticmethods 9 | @staticmethod 10 | # bias is an optional argument 11 | def forward(ctx, seta, setb): 12 | #print("Match Cost Forward") 13 | ctx.save_for_backward(seta, setb) 14 | ''' 15 | input: 16 | set1 : batch_size * #dataset_points * 3 17 | set2 : batch_size * #query_points * 3 18 | returns: 19 | dist1, idx1, dist2, idx2 20 | ''' 21 | dist1, idx1, dist2, idx2 = NNDistance(seta, setb) 22 | ctx.idx1 = idx1 23 | ctx.idx2 = idx2 24 | return dist1, dist2 25 | 26 | # This function has only a single output, so it gets only one gradient 27 | @staticmethod 28 | def backward(ctx, grad_dist1, grad_dist2): 29 | #print("Match Cost Backward") 30 | # This is a pattern that is very convenient - at the top of backward 31 | # unpack saved_tensors and initialize all gradients w.r.t. inputs to 32 | # None. Thanks to the fact that additional trailing Nones are 33 | # ignored, the return statement is simple even when the function has 34 | # optional inputs. 
35 | seta, setb = ctx.saved_tensors 36 | idx1 = ctx.idx1 37 | idx2 = ctx.idx2 38 | grada, gradb = NNDistanceGrad(seta, setb, idx1, idx2, grad_dist1, grad_dist2) 39 | return grada, gradb 40 | 41 | nn_distance = NNDistanceFunction.apply 42 | 43 | -------------------------------------------------------------------------------- /datasets/sampler.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | from torch.utils.data import Sampler, DistributedSampler, Dataset 10 | 11 | 12 | class InfSampler(Sampler): 13 | def __init__(self, dataset: Dataset, shuffle: bool = True) -> None: 14 | self.dataset = dataset 15 | self.shuffle = shuffle 16 | self.reset_sampler() 17 | 18 | def reset_sampler(self): 19 | num = len(self.dataset) 20 | indices = torch.randperm(num) if self.shuffle else torch.arange(num) 21 | self.indices = indices.tolist() 22 | self.iter_num = 0 23 | 24 | def __iter__(self): 25 | return self 26 | 27 | def __next__(self): 28 | value = self.indices[self.iter_num] 29 | self.iter_num = self.iter_num + 1 30 | 31 | if self.iter_num >= len(self.indices): 32 | self.reset_sampler() 33 | return value 34 | 35 | def __len__(self): 36 | return len(self.dataset) 37 | 38 | 39 | class DistributedInfSampler(DistributedSampler): 40 | def __init__(self, dataset: Dataset, shuffle: bool = True) -> None: 41 | super().__init__(dataset, shuffle=shuffle) 42 | self.reset_sampler() 43 | 44 | def reset_sampler(self): 45 | self.indices = list(super().__iter__()) 46 | self.iter_num = 0 47 | 48 | def __iter__(self): 49 | return self 50 | 51 | def __next__(self): 52 | value = self.indices[self.iter_num] 53 | self.iter_num = self.iter_num + 1 54 | 55 | if self.iter_num >= len(self.indices): 56 | self.reset_sampler() 57 | return value 58 | -------------------------------------------------------------------------------- /metrics/generate_synth_image.py: -------------------------------------------------------------------------------- 1 | import trimesh 2 | import os 3 | from multiprocessing import Pool, current_process 4 | import multiprocessing as mp 5 | from tqdm import tqdm 6 | import sys 7 | current_dir = os.path.dirname(os.path.abspath(__file__)) 8 | 9 | sys.path.append(os.path.dirname(current_dir)) 10 | from utils.render_utils import generate_image_for_fid 11 | 12 | os.environ['EGL_DEVICE_ID'] = '1' 13 | category = "airplane" 14 | snc_category_to_synth_id_13 = { 15 | 'airplane': '02691156', 16 | 'bench': '02828884', 17 | 'cabinet': '02933112', 18 | 'car': '02958343', 19 | 'chair': '03001627', 20 | 'monitor': '03211117', 21 | 'lamp': '03636649', 22 | 'loudspeaker': '03691459', 23 | 'rifle': '04090263', 24 | 'sofa': '04256520', 25 | 'table': '04379243', 26 | 'telephone': '04401088', 27 | 'vessel': '04530566', 28 | } 29 | 30 | cond = True 31 | root_dir = "logs/airplane_union/uncond_1000epoch_lr2e-4" 32 | fid_dir = f'{root_dir}/fid_images_{category}' 33 | mesh_dir = f'{root_dir}/results_{category}' 34 | 35 | 36 | os.makedirs(fid_dir, exist_ok=True) 37 | 38 | meshes = os.listdir(mesh_dir) 39 | 40 | def process_mesh(mesh): 41 | name = mesh[:-4] 42 | mesh_path = os.path.join(mesh_dir, mesh) 43 | mesh = trimesh.load(mesh_path, force="mesh") 44 | 45 | # Set the GPU for this process 46 | 
os.environ['EGL_DEVICE_ID'] = str(current_process()._identity[0] % 4) 47 | try: 48 | generate_image_for_fid(mesh, fid_dir, name) 49 | except: 50 | print(f'The mesh {name} occurs an error!') 51 | return 52 | print(f'The mesh {name} finish rendering') 53 | 54 | num_processes = 20 # mp.cpu_count() 55 | if num_processes > 1: 56 | with Pool(num_processes) as pool: # Create a pool with 4 processes 57 | list(tqdm(pool.imap(process_mesh, meshes), total=len(meshes))) 58 | else: 59 | for mesh in tqdm(meshes): 60 | process_mesh(mesh) 61 | -------------------------------------------------------------------------------- /metrics/generate_dataset_pointclouds.py: -------------------------------------------------------------------------------- 1 | import trimesh 2 | import os 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | filelist = '/data/xiongbj/ShapeNet/filelist/test_chair.txt' 7 | mesh_dataset = '/data/xiongbj/ShapeNet/mesh' 8 | pointcloud_path = '/data/xiongbj/ShapeNet/test_pointclouds' 9 | 10 | def scale_to_unit_sphere(mesh): 11 | if isinstance(mesh, trimesh.Scene): 12 | mesh = mesh.dump().sum() 13 | vertices = mesh.vertices - mesh.bounding_box.centroid 14 | distances = np.linalg.norm(vertices, axis=1) 15 | vertices /= np.max(distances) 16 | return trimesh.Trimesh(vertices=vertices, faces=mesh.faces) 17 | 18 | 19 | def scale_to_unit_cube(mesh, padding=0.0): 20 | if isinstance(mesh, trimesh.Scene): 21 | mesh = mesh.dump().sum() 22 | 23 | vertices = mesh.vertices - mesh.bounding_box.centroid 24 | vertices *= 2 / np.max(mesh.bounding_box.extents) * (1 - padding) 25 | 26 | return trimesh.Trimesh(vertices=vertices, faces=mesh.faces) 27 | 28 | 29 | def sample_pts_from_mesh(mesh_path, output_path): 30 | 31 | num_samples = 2048 32 | mesh = trimesh.load(mesh_path, force='mesh') 33 | mesh = scale_to_unit_cube(mesh) 34 | 35 | points = mesh.sample(count = num_samples) 36 | 37 | np.save(output_path, points.astype(np.float32)) 38 | 39 | if __name__ == '__main__': 40 | print('-> Run sample_pts_from_mesh.') 41 | 42 | with open(filelist) as fid: 43 | lines = fid.readlines() 44 | 45 | for i, line in tqdm(enumerate(lines)): 46 | filename = line.split()[0] 47 | category = filename.split('/')[0] 48 | category_path = os.path.join(pointcloud_path, category) 49 | if not os.path.exists(category_path): os.makedirs(category_path) 50 | mesh_path = os.path.join(mesh_dataset, filename, 'model.obj') 51 | output_path = os.path.join(pointcloud_path, filename + '.npy') 52 | sample_pts_from_mesh(mesh_path, output_path) 53 | -------------------------------------------------------------------------------- /metrics/pytorch_structural_losses/match_cost.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from metrics.StructuralLosses.StructuralLossesBackend import ApproxMatch, MatchCost, MatchCostGrad 4 | 5 | # Inherit from Function 6 | class MatchCostFunction(Function): 7 | # Note that both forward and backward are @staticmethods 8 | @staticmethod 9 | # bias is an optional argument 10 | def forward(ctx, seta, setb): 11 | #print("Match Cost Forward") 12 | ctx.save_for_backward(seta, setb) 13 | ''' 14 | input: 15 | set1 : batch_size * #dataset_points * 3 16 | set2 : batch_size * #query_points * 3 17 | returns: 18 | match : batch_size * #query_points * #dataset_points 19 | ''' 20 | match, temp = ApproxMatch(seta, setb) 21 | ctx.match = match 22 | cost = MatchCost(seta, setb, match) 23 | return cost 24 | 25 | """ 26 | 
grad_1,grad_2=approxmatch_module.match_cost_grad(xyz1,xyz2,match) 27 | return [grad_1*tf.expand_dims(tf.expand_dims(grad_cost,1),2),grad_2*tf.expand_dims(tf.expand_dims(grad_cost,1),2),None] 28 | """ 29 | # This function has only a single output, so it gets only one gradient 30 | @staticmethod 31 | def backward(ctx, grad_output): 32 | #print("Match Cost Backward") 33 | # This is a pattern that is very convenient - at the top of backward 34 | # unpack saved_tensors and initialize all gradients w.r.t. inputs to 35 | # None. Thanks to the fact that additional trailing Nones are 36 | # ignored, the return statement is simple even when the function has 37 | # optional inputs. 38 | seta, setb = ctx.saved_tensors 39 | #grad_input = grad_weight = grad_bias = None 40 | grada, gradb = MatchCostGrad(seta, setb, ctx.match) 41 | grad_output_expand = grad_output.unsqueeze(1).unsqueeze(2) 42 | return grada*grad_output_expand, gradb*grad_output_expand 43 | 44 | match_cost = MatchCostFunction.apply 45 | 46 | -------------------------------------------------------------------------------- /metrics/compute_metrics.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import argparse 10 | import trimesh.sample 11 | import numpy as np 12 | import torch 13 | from tqdm import tqdm 14 | from scipy.spatial import cKDTree 15 | 16 | def distChamfer(a, b): 17 | x, y = a, b 18 | bs, num_points, points_dim = x.size() 19 | xx = torch.bmm(x, x.transpose(2, 1)) 20 | yy = torch.bmm(y, y.transpose(2, 1)) 21 | zz = torch.bmm(x, y.transpose(2, 1)) 22 | diag_ind = torch.arange(0, num_points).to(a).long() 23 | rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx) 24 | ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy) 25 | P = (rx.transpose(2, 1) + ry - 2 * zz) 26 | return P.min(1)[0], P.min(2)[0] 27 | 28 | try: 29 | from metrics.StructuralLosses.nn_distance import nn_distance 30 | def distChamferCUDA(x, y): 31 | return nn_distance(x, y) 32 | except Exception as e: 33 | print(str(e)) 34 | print("distChamferCUDA not available; fall back to slower version.") 35 | def distChamferCUDA(x, y): 36 | return distChamfer(x, y) 37 | 38 | 39 | def compute_metrics(sample_pcs, ref_pcs, batch_size): 40 | 41 | N_ref = ref_pcs.shape[0] 42 | cd_lst = [] 43 | for ref_b_start in range(0, N_ref, batch_size): 44 | ref_b_end = min(N_ref, ref_b_start + batch_size) 45 | ref_batch = ref_pcs[ref_b_start:ref_b_end] 46 | 47 | batch_size_ref = ref_batch.size(0) 48 | sample_batch_exp = sample_pcs.view(1, -1, 3).expand(batch_size_ref, -1, -1) 49 | sample_batch_exp = sample_batch_exp.contiguous() 50 | dl, dr = distChamferCUDA(sample_batch_exp, ref_batch) 51 | cd = (dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1) 52 | cd_lst.append(cd) 53 | 54 | cd_lst = torch.cat(cd_lst, dim=1) 55 | 56 | return cd_lst -------------------------------------------------------------------------------- /utils/render/render.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from utils.render.render_utils import Render, create_pose 3 | import matplotlib 4 | import os 5 | os.environ['PYOPENGL_PLATFORM'] = 'egl' 6 | matplotlib.use("Agg") 7 | 8 | 9 | 10 | FrontVector = 
(np.array([[0.52573, 0.38197, 0.85065], 11 | [-0.20081, 0.61803, 0.85065], 12 | [-0.64984, 0.00000, 0.85065], 13 | [-0.20081, -0.61803, 0.85065], 14 | [0.52573, -0.38197, 0.85065], 15 | [0.85065, -0.61803, 0.20081], 16 | [1.0515, 0.00000, -0.20081], 17 | [0.85065, 0.61803, 0.20081], 18 | [0.32492, 1.00000, -0.20081], 19 | [-0.32492, 1.00000, 0.20081], 20 | [-0.85065, 0.61803, -0.20081], 21 | [-1.0515, 0.00000, 0.20081], 22 | [-0.85065, -0.61803, -0.20081], 23 | [-0.32492, -1.00000, 0.20081], 24 | [0.32492, -1.00000, -0.20081], 25 | [0.64984, 0.00000, -0.85065], 26 | [0.20081, 0.61803, -0.85065], 27 | [-0.52573, 0.38197, -0.85065], 28 | [-0.52573, -0.38197, -0.85065], 29 | [0.20081, -0.61803, -0.85065]]))*2 30 | 31 | def render_mesh(mesh, resolution=1024, index=5, background=None, scale=1, no_fix_normal=True): 32 | 33 | camera_pose = create_pose(FrontVector[index]*scale) 34 | 35 | render = Render(size=resolution, camera_pose=camera_pose, 36 | background=background) 37 | 38 | triangle_id, rendered_image, normal_map, depth_image, p_images = render.render(path=None, 39 | clean=True, 40 | mesh=mesh, 41 | only_render_images=no_fix_normal) 42 | return rendered_image -------------------------------------------------------------------------------- /models/networks/diffusion_networks/utils/spmm.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | from .scatter import scatter_add 10 | 11 | 12 | def spmm(index, value, m, n, matrix): 13 | """Matrix product of sparse matrix with dense matrix. 14 | 15 | Args: 16 | index (:class:`LongTensor`): The index tensor of sparse matrix. 17 | value (:class:`Tensor`): The value tensor of sparse matrix. 18 | m (int): The first dimension of corresponding dense matrix. 19 | n (int): The second dimension of corresponding dense matrix. 20 | matrix (:class:`Tensor`): The dense matrix. 21 | 22 | :rtype: :class:`Tensor` 23 | """ 24 | 25 | assert n == matrix.size(-2) 26 | 27 | row, col = index 28 | matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1) 29 | 30 | out = matrix.index_select(-2, col) 31 | out = out * value.unsqueeze(-1) 32 | out = scatter_add(out, row, dim=-2, dim_size=m) 33 | 34 | return out 35 | 36 | 37 | def modulated_spmm(index, value, m, n, matrix, xyzf): 38 | """Matrix product of sparse matrix with dense matrix. 39 | 40 | Args: 41 | index (:class:`LongTensor`): The index tensor of sparse matrix. 42 | value (:class:`Tensor`): The value tensor of sparse matrix. 43 | m (int): The first dimension of corresponding dense matrix. 44 | n (int): The second dimension of corresponding dense matrix. 45 | matrix (:class:`Tensor`): The dense matrix. 
46 | 47 | :rtype: :class:`Tensor` 48 | """ 49 | 50 | assert n == matrix.size(-2) 51 | 52 | row, col = index 53 | matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1) 54 | 55 | out = matrix.index_select(-2, col) 56 | ones = torch.ones((xyzf.shape[0], 1), device=xyzf.device) 57 | out = torch.sum(out * torch.cat([xyzf, ones], dim=1), dim=1, keepdim=True) 58 | out = out * value.unsqueeze(-1) 59 | out = scatter_add(out, row, dim=-2, dim_size=m) 60 | 61 | return out 62 | -------------------------------------------------------------------------------- /models/networks/dualoctree_networks/utils/spmm.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | from .scatter import scatter_add 10 | 11 | 12 | def spmm(index, value, m, n, matrix): 13 | """Matrix product of sparse matrix with dense matrix. 14 | 15 | Args: 16 | index (:class:`LongTensor`): The index tensor of sparse matrix. 17 | value (:class:`Tensor`): The value tensor of sparse matrix. 18 | m (int): The first dimension of corresponding dense matrix. 19 | n (int): The second dimension of corresponding dense matrix. 20 | matrix (:class:`Tensor`): The dense matrix. 21 | 22 | :rtype: :class:`Tensor` 23 | """ 24 | 25 | assert n == matrix.size(-2) 26 | 27 | row, col = index 28 | matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1) 29 | 30 | out = matrix.index_select(-2, col) 31 | out = out * value.unsqueeze(-1) 32 | out = scatter_add(out, row, dim=-2, dim_size=m) 33 | 34 | return out 35 | 36 | 37 | def modulated_spmm(index, value, m, n, matrix, xyzf): 38 | """Matrix product of sparse matrix with dense matrix. 39 | 40 | Args: 41 | index (:class:`LongTensor`): The index tensor of sparse matrix. 42 | value (:class:`Tensor`): The value tensor of sparse matrix. 43 | m (int): The first dimension of corresponding dense matrix. 44 | n (int): The second dimension of corresponding dense matrix. 45 | matrix (:class:`Tensor`): The dense matrix. 
46 | 47 | :rtype: :class:`Tensor` 48 | """ 49 | 50 | assert n == matrix.size(-2) 51 | 52 | row, col = index 53 | matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1) 54 | 55 | out = matrix.index_select(-2, col) 56 | ones = torch.ones((xyzf.shape[0], 1), device=xyzf.device) 57 | out = torch.sum(out * torch.cat([xyzf, ones], dim=1), dim=1, keepdim=True) 58 | out = out * value.unsqueeze(-1) 59 | out = scatter_add(out, row, dim=-2, dim_size=m) 60 | 61 | return out 62 | -------------------------------------------------------------------------------- /tools/gen_split.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import ocnn 4 | from ocnn.octree import Octree, Points 5 | from glob import glob 6 | import os 7 | import sys 8 | from tqdm import tqdm 9 | current_dir = os.path.dirname(os.path.realpath(__file__)) 10 | sys.path.append(os.path.dirname(current_dir)) 11 | from utils.util_dualoctree import split2octree_large, split2octree_small, octree2split_large, octree2split_small 12 | root_folder = "data/Objaverse" 13 | pointcloud_dir = "data/Objaverse/objaverse-select" 14 | split_dir = "data/Objaverse/objaverse-split" 15 | octree_dir = "data/Objaverse/objaverse-octree" 16 | 17 | def get_filenames(filelist): 18 | r''' Gets filenames from a filelist. 19 | ''' 20 | 21 | filelist = os.path.join(root_folder, 'filelist', filelist) 22 | with open(filelist, 'r') as fid: 23 | lines = fid.readlines() 24 | filenames = [line.split()[0] for line in lines] 25 | return filenames 26 | 27 | def points2octree(points): 28 | octree = ocnn.octree.Octree(depth = 10, full_depth = 4) 29 | octree.build_octree(points) 30 | return octree 31 | 32 | filenames = get_filenames('train_obja.txt') 33 | 34 | for filename in tqdm(filenames): 35 | print(filename) 36 | filename_pointcloud = os.path.join(pointcloud_dir, filename, "pointcloud.npz") 37 | filename_split = os.path.join(split_dir, filename) 38 | filename_octree = os.path.join(octree_dir, filename, "octree.pth") 39 | raw = np.load(filename_pointcloud) 40 | points, normals = raw['points'], raw['normals'] 41 | 42 | # transform points to octree 43 | points_gt = Points(points = torch.from_numpy(points).float(), normals = torch.from_numpy(normals).float()) 44 | octree_gt = points2octree(points_gt) 45 | os.makedirs(os.path.dirname(filename_octree), exist_ok = True) 46 | torch.save(octree_gt, filename_octree) 47 | 48 | split_small = octree2split_small(octree_gt, full_depth=4) 49 | split_large = octree2split_large(octree_gt, small_depth=6) 50 | os.makedirs(filename_split, exist_ok = True) 51 | torch.save(split_small.squeeze(0), os.path.join(filename_split, "split_small.pth")) 52 | torch.save(split_large, os.path.join(filename_split, "split_large.pth")) 53 | 54 | 55 | -------------------------------------------------------------------------------- /solver/dataset.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import torch 10 | import torch.utils.data 11 | import numpy as np 12 | from tqdm import tqdm 13 | from datasets.shapenet_utils import snc_synth_id_to_label_13, snc_synth_id_to_label_5 14 | 15 | 16 | def read_file(filename): 17 | points = 
np.fromfile(filename, dtype=np.uint8) 18 | return torch.from_numpy(points) # convert it to torch.tensor 19 | 20 | 21 | class Dataset(torch.utils.data.Dataset): 22 | 23 | def __init__(self, root, filelist, transform, read_file=read_file, 24 | in_memory=False, take: int = -1): 25 | super(Dataset, self).__init__() 26 | self.root = root 27 | self.filelist = filelist 28 | self.transform = transform 29 | self.in_memory = in_memory 30 | self.read_file = read_file 31 | self.take = take 32 | 33 | self.filenames, self.labels = self.load_filenames() 34 | if self.in_memory: 35 | print('Load files into memory from ' + self.filelist) 36 | self.samples = [self.read_file(os.path.join(self.root, f)) 37 | for f in tqdm(self.filenames, ncols=80, leave=False)] 38 | 39 | def __len__(self): 40 | return len(self.filenames) 41 | 42 | def __getitem__(self, idx): 43 | sample = self.samples[idx] if self.in_memory else \ 44 | self.read_file(os.path.join(self.root, self.filenames[idx])) # noqa 45 | output = self.transform(sample, idx) # data augmentation + build octree 46 | output['label'] = self.labels[idx] 47 | output['filename'] = self.filenames[idx] 48 | return output 49 | 50 | def load_filenames(self): 51 | filenames, labels = [], [] 52 | with open(self.filelist) as fid: 53 | lines = fid.readlines() 54 | for line in lines: 55 | filename = line.split()[0] 56 | label = filename.split('/')[0] 57 | if label in snc_synth_id_to_label_5: 58 | label = snc_synth_id_to_label_5[label] 59 | else: 60 | label = 0 61 | filenames.append(filename) 62 | labels.append(torch.tensor(label)) 63 | 64 | num = len(filenames) 65 | if self.take > num or self.take < 1: 66 | self.take = num 67 | 68 | return filenames[:self.take], labels[:self.take] 69 | -------------------------------------------------------------------------------- /models/networks/dualoctree_networks/utils/scatter.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | from typing import Optional 10 | 11 | 12 | def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): 13 | if dim < 0: 14 | dim = other.dim() + dim 15 | if src.dim() == 1: 16 | for _ in range(0, dim): 17 | src = src.unsqueeze(0) 18 | for _ in range(src.dim(), other.dim()): 19 | src = src.unsqueeze(-1) 20 | src = src.expand_as(other) 21 | return src 22 | 23 | 24 | def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1, 25 | out: Optional[torch.Tensor] = None, 26 | dim_size: Optional[int] = None) -> torch.Tensor: 27 | index = broadcast(index, src, dim) 28 | if out is None: 29 | size = list(src.size()) 30 | if dim_size is not None: 31 | size[dim] = dim_size 32 | elif index.numel() == 0: 33 | size[dim] = 0 34 | else: 35 | size[dim] = int(index.max()) + 1 36 | out = torch.zeros(size, dtype=src.dtype, device=src.device) 37 | return out.scatter_add_(dim, index, src) 38 | else: 39 | return out.scatter_add_(dim, index, src) 40 | 41 | 42 | def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1, 43 | weights: Optional[torch.Tensor] = None, 44 | out: Optional[torch.Tensor] = None, 45 | dim_size: Optional[int] = None) -> torch.Tensor: 46 | if weights is not None: 47 | src = src * broadcast(weights, src, dim) 48 | out = scatter_add(src, index, dim, 
out, dim_size) 49 | dim_size = out.size(dim) 50 | 51 | index_dim = dim 52 | if index_dim < 0: 53 | index_dim = index_dim + src.dim() 54 | if index.dim() <= index_dim: 55 | index_dim = index.dim() - 1 56 | 57 | if weights is None: 58 | weights = torch.ones(index.size(), dtype=src.dtype, device=src.device) 59 | count = scatter_add(weights, index, index_dim, None, dim_size) 60 | count[count < 1] = 1 61 | count = broadcast(count, out, dim) 62 | if out.is_floating_point(): 63 | out.true_divide_(count) 64 | else: 65 | out.div_(count, rounding_mode='floor') 66 | return out 67 | -------------------------------------------------------------------------------- /configs/vae_snet_eval.yaml: -------------------------------------------------------------------------------- 1 | solver: 2 | sdf_scale: 0.9 3 | resolution: 256 4 | save_sdf: False 5 | 6 | model: 7 | name: graph_vae 8 | 9 | channel: 4 10 | depth: &depth 8 11 | nout: 4 12 | depth_out: *depth 13 | full_depth: &full_depth 4 14 | depth_stop: 6 15 | bottleneck: 4 16 | resblock_type: basic 17 | code_channel: 16 18 | resblk_num: 2 19 | 20 | embed_dim: 3 21 | n_embed: 8192 22 | use_checkpoint: True 23 | 24 | data: 25 | train: 26 | name: shapenet 27 | 28 | # octree building 29 | depth: *depth 30 | offset: 0.0 31 | full_depth: *full_depth 32 | node_dis: True 33 | split_label: True 34 | 35 | # no data augmentation 36 | disable: False 37 | distort: False 38 | 39 | # data loading 40 | location: &location data/ShapeNet/dataset_256 41 | filelist: &filelist data/ShapeNet/filelist 42 | load_octree: &load_octree False 43 | load_pointcloud: &load_pointcloud True 44 | load_split_small: &load_split_small False 45 | load_split_large: &load_split_large False 46 | load_sdf: &load_sdf False 47 | load_occu: &load_occu False 48 | load_color: &load_color False 49 | batch_size: &batch_size 2 50 | shuffle: True 51 | 52 | point_scale: &point_scale 0.5 53 | point_sample_num: 10000 54 | sample_surf_points: False 55 | in_memory: False 56 | num_workers: 8 57 | 58 | 59 | test: 60 | name: shapenet 61 | 62 | # octree building 63 | depth: *depth 64 | offset: 0.0 65 | full_depth: *full_depth 66 | node_dis: True 67 | split_label: True 68 | 69 | # no data augmentation 70 | disable: False 71 | distort: False 72 | 73 | # data loading 74 | location: *location 75 | filelist: *filelist 76 | batch_size: 1 77 | load_octree: *load_octree 78 | load_pointcloud: *load_pointcloud 79 | load_split_small: *load_split_small 80 | load_split_large: *load_split_large 81 | load_sdf: *load_sdf 82 | load_occu: *load_occu 83 | load_color: *load_color 84 | shuffle: False 85 | 86 | point_scale: *point_scale 87 | point_sample_num: 10000 88 | sample_surf_points: False 89 | in_memory: False 90 | num_workers: 8 91 | -------------------------------------------------------------------------------- /models/networks/diffusion_networks/utils/scatter.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | from typing import Optional 10 | 11 | 12 | def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): 13 | if dim < 0: 14 | dim = other.dim() + dim 15 | if src.dim() == 1: 16 | for _ in range(0, dim): 17 | src = src.unsqueeze(0) 18 | for _ in range(src.dim(), other.dim()): 
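# pad src with trailing singleton dimensions until it matches other's rank, so that expand_as below can broadcast the (typically 1-D) index tensor to other's shape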
19 | src = src.unsqueeze(-1) 20 | src = src.expand_as(other) 21 | return src 22 | 23 | 24 | def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1, 25 | out: Optional[torch.Tensor] = None, 26 | dim_size: Optional[int] = None) -> torch.Tensor: 27 | index = broadcast(index, src, dim) 28 | if out is None: 29 | size = list(src.size()) 30 | if dim_size is not None: 31 | size[dim] = dim_size 32 | elif index.numel() == 0: 33 | size[dim] = 0 34 | else: 35 | size[dim] = int(index.max()) + 1 36 | out = torch.zeros(size, dtype=src.dtype, device=src.device) 37 | return out.scatter_add_(dim, index, src) 38 | else: 39 | return out.scatter_add_(dim, index, src) 40 | 41 | 42 | def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1, 43 | weights: Optional[torch.Tensor] = None, 44 | out: Optional[torch.Tensor] = None, 45 | dim_size: Optional[int] = None) -> torch.Tensor: 46 | if weights is not None: 47 | src = src * broadcast(weights, src, dim) 48 | out = scatter_add(src, index, dim, out, dim_size) 49 | dim_size = out.size(dim) 50 | 51 | index_dim = dim 52 | if index_dim < 0: 53 | index_dim = index_dim + src.dim() 54 | if index.dim() <= index_dim: 55 | index_dim = index.dim() - 1 56 | 57 | if weights is None: 58 | weights = torch.ones(index.size(), dtype=src.dtype, device=src.device) 59 | count = scatter_add(weights, index, index_dim, None, dim_size) 60 | count[count < 1] = 1 61 | count = broadcast(count, out, dim) 62 | if out.is_floating_point(): 63 | out.true_divide_(count) 64 | else: 65 | out.div_(count, rounding_mode='floor') 66 | return out 67 | -------------------------------------------------------------------------------- /configs/vae_snet_eval_depth984.yaml: -------------------------------------------------------------------------------- 1 | solver: 2 | sdf_scale: 0.9 3 | resolution: 256 4 | save_sdf: False 5 | 6 | model: 7 | name: graph_vae 8 | 9 | channel: 4 10 | depth: &depth 9 11 | nout: 4 12 | depth_out: *depth 13 | full_depth: &full_depth 4 14 | depth_stop: 8 15 | bottleneck: 4 16 | resblock_type: basic 17 | code_channel: 16 18 | resblk_num: 2 19 | 20 | embed_dim: 3 21 | n_embed: 8192 22 | use_checkpoint: True 23 | 24 | data: 25 | train: 26 | name: shapenet 27 | 28 | # octree building 29 | depth: *depth 30 | offset: 0.0 31 | full_depth: *full_depth 32 | node_dis: True 33 | split_label: True 34 | 35 | # no data augmentation 36 | disable: False 37 | distort: False 38 | 39 | # data loading 40 | location: &location data/ShapeNet/dataset_256 41 | filelist: &filelist data/ShapeNet/filelist 42 | load_octree: &load_octree False 43 | load_pointcloud: &load_pointcloud True 44 | load_split_small: &load_split_small False 45 | load_split_large: &load_split_large False 46 | load_sdf: &load_sdf False 47 | load_occu: &load_occu False 48 | load_color: &load_color False 49 | batch_size: &batch_size 2 50 | shuffle: True 51 | 52 | point_scale: &point_scale 0.5 53 | point_sample_num: 10000 54 | sample_surf_points: False 55 | in_memory: False 56 | num_workers: 8 57 | 58 | 59 | test: 60 | name: shapenet 61 | 62 | # octree building 63 | depth: *depth 64 | offset: 0.0 65 | full_depth: *full_depth 66 | node_dis: True 67 | split_label: True 68 | 69 | # no data augmentation 70 | disable: False 71 | distort: False 72 | 73 | # data loading 74 | location: *location 75 | filelist: *filelist 76 | batch_size: 1 77 | load_octree: *load_octree 78 | load_pointcloud: *load_pointcloud 79 | load_split_small: *load_split_small 80 | load_split_large: *load_split_large 81 | load_sdf: *load_sdf 82 | 
load_occu: *load_occu 83 | load_color: *load_color 84 | shuffle: False 85 | 86 | point_scale: *point_scale 87 | point_sample_num: 10000 88 | sample_surf_points: False 89 | in_memory: False 90 | num_workers: 8 91 | -------------------------------------------------------------------------------- /configs/vae_snet_train.yaml: -------------------------------------------------------------------------------- 1 | solver: 2 | sdf_scale: 0.9 3 | resolution: 256 4 | save_sdf: False 5 | 6 | model: 7 | name: graph_vae 8 | 9 | channel: 4 10 | depth: &depth 8 11 | nout: 4 12 | depth_out: *depth 13 | full_depth: &full_depth 4 14 | depth_stop: 6 15 | bottleneck: 4 16 | resblock_type: basic 17 | code_channel: 16 18 | resblk_num: 2 19 | 20 | embed_dim: 3 21 | n_embed: 8192 22 | use_checkpoint: True 23 | 24 | data: 25 | train: 26 | name: shapenet 27 | 28 | # octree building 29 | depth: *depth 30 | offset: 0.0 31 | full_depth: *full_depth 32 | node_dis: True 33 | split_label: True 34 | 35 | # no data augmentation 36 | disable: False 37 | distort: False 38 | 39 | # data loading 40 | location: &location data/ShapeNet/dataset_10w 41 | filelist: &filelist data/ShapeNet/filelist 42 | load_octree: &load_octree False 43 | load_pointcloud: &load_pointcloud True 44 | load_split_small: &load_split_small False 45 | load_split_large: &load_split_large False 46 | load_sdf: &load_sdf True 47 | load_occu: &load_occu False 48 | load_color: &load_color False 49 | batch_size: &batch_size 4 50 | shuffle: True 51 | 52 | point_scale: &point_scale 0.5 53 | point_sample_num: 50000 54 | sample_surf_points: False 55 | in_memory: False 56 | num_workers: 8 57 | 58 | 59 | test: 60 | name: shapenet 61 | 62 | # octree building 63 | depth: *depth 64 | offset: 0.0 65 | full_depth: *full_depth 66 | node_dis: True 67 | split_label: True 68 | 69 | # no data augmentation 70 | disable: False 71 | distort: False 72 | 73 | # data loading 74 | location: *location 75 | filelist: *filelist 76 | batch_size: 1 77 | load_octree: *load_octree 78 | load_pointcloud: *load_pointcloud 79 | load_split_small: *load_split_small 80 | load_split_large: *load_split_large 81 | load_sdf: *load_sdf 82 | load_occu: *load_occu 83 | load_color: *load_color 84 | shuffle: False 85 | 86 | point_scale: *point_scale 87 | point_sample_num: 10000 88 | sample_surf_points: False 89 | in_memory: False 90 | num_workers: 8 91 | 92 | loss: 93 | name: geometry 94 | loss_type: sdf_reg_loss 95 | kl_weight: 0.1 -------------------------------------------------------------------------------- /configs/vae_obja_eval.yaml: -------------------------------------------------------------------------------- 1 | solver: 2 | sdf_scale: 0.9 3 | resolution: 256 4 | save_sdf: False 5 | 6 | model: 7 | name: graph_vae 8 | 9 | channel: 4 10 | depth: &depth 10 11 | nout: 4 12 | depth_out: *depth 13 | full_depth: &full_depth 4 14 | depth_stop: 8 15 | bottleneck: 4 16 | resblock_type: basic 17 | code_channel: 16 18 | resblk_num: 2 19 | 20 | embed_dim: 3 21 | n_embed: 8192 22 | use_checkpoint: True 23 | 24 | data: 25 | train: 26 | name: shapenet 27 | 28 | # octree building 29 | depth: *depth 30 | offset: 0.0 31 | full_depth: *full_depth 32 | node_dis: True 33 | split_label: True 34 | 35 | # no data augmentation 36 | disable: False 37 | distort: False 38 | 39 | # data loading 40 | location: &location data/Objaverse/objaverse-select 41 | filelist: &filelist data/Objaverse/filelist 42 | load_octree: &load_octree False 43 | load_pointcloud: &load_pointcloud True 44 | load_split_small: &load_split_small False 45 
| load_split_large: &load_split_large False 46 | load_sdf: &load_sdf False 47 | load_occu: &load_occu False 48 | load_color: &load_color False 49 | batch_size: &batch_size 1 50 | shuffle: True 51 | 52 | point_scale: &point_scale 1.0 53 | point_sample_num: 10000 54 | sample_surf_points: False 55 | in_memory: False 56 | num_workers: 8 57 | 58 | 59 | test: 60 | name: shapenet 61 | 62 | # octree building 63 | depth: *depth 64 | offset: 0.0 65 | full_depth: *full_depth 66 | node_dis: True 67 | split_label: True 68 | 69 | # no data augmentation 70 | disable: False 71 | distort: False 72 | 73 | # data loading 74 | location: *location 75 | filelist: *filelist 76 | batch_size: 1 77 | load_octree: *load_octree 78 | load_pointcloud: *load_pointcloud 79 | load_split_small: *load_split_small 80 | load_split_large: *load_split_large 81 | load_sdf: *load_sdf 82 | load_occu: *load_occu 83 | load_color: *load_color 84 | shuffle: False 85 | 86 | point_scale: *point_scale 87 | point_sample_num: 10000 88 | sample_surf_points: False 89 | in_memory: False 90 | num_workers: 8 91 | 92 | loss: 93 | name: geometry 94 | loss_type: sdf_reg_loss 95 | kl_weight: 0.1 -------------------------------------------------------------------------------- /configs/vae_obja_train.yaml: -------------------------------------------------------------------------------- 1 | solver: 2 | sdf_scale: 0.9 3 | resolution: 256 4 | save_sdf: False 5 | 6 | model: 7 | name: graph_vae 8 | 9 | channel: 4 10 | depth: &depth 10 11 | nout: 4 12 | depth_out: *depth 13 | full_depth: &full_depth 4 14 | depth_stop: 8 15 | bottleneck: 4 16 | resblock_type: basic 17 | code_channel: 16 18 | resblk_num: 2 19 | 20 | embed_dim: 3 21 | n_embed: 8192 22 | use_checkpoint: True 23 | 24 | data: 25 | train: 26 | name: shapenet 27 | 28 | # octree building 29 | depth: *depth 30 | offset: 0.0 31 | full_depth: *full_depth 32 | node_dis: True 33 | split_label: True 34 | 35 | # no data augmentation 36 | disable: False 37 | distort: False 38 | 39 | # data loading 40 | location: &location data/Objaverse/objaverse-select 41 | filelist: &filelist data/Objaverse/filelist 42 | load_octree: &load_octree False 43 | load_pointcloud: &load_pointcloud True 44 | load_split_small: &load_split_small False 45 | load_split_large: &load_split_large False 46 | load_sdf: &load_sdf True 47 | load_occu: &load_occu False 48 | load_color: &load_color False 49 | batch_size: &batch_size 2 50 | shuffle: True 51 | 52 | point_scale: &point_scale 1.0 53 | point_sample_num: 10000 54 | sample_surf_points: False 55 | in_memory: False 56 | num_workers: 8 57 | 58 | 59 | test: 60 | name: shapenet 61 | 62 | # octree building 63 | depth: *depth 64 | offset: 0.0 65 | full_depth: *full_depth 66 | node_dis: True 67 | split_label: True 68 | 69 | # no data augmentation 70 | disable: False 71 | distort: False 72 | 73 | # data loading 74 | location: *location 75 | filelist: *filelist 76 | batch_size: 1 77 | load_octree: *load_octree 78 | load_pointcloud: *load_pointcloud 79 | load_split_small: *load_split_small 80 | load_split_large: *load_split_large 81 | load_sdf: *load_sdf 82 | load_occu: *load_occu 83 | load_color: *load_color 84 | shuffle: False 85 | 86 | point_scale: *point_scale 87 | point_sample_num: 10000 88 | sample_surf_points: False 89 | in_memory: False 90 | num_workers: 8 91 | 92 | loss: 93 | name: geometry 94 | loss_type: sdf_reg_loss 95 | kl_weight: 0.1 -------------------------------------------------------------------------------- /configs/vae_obja_eval_depth864.yaml: 
-------------------------------------------------------------------------------- 1 | solver: 2 | sdf_scale: 0.9 3 | resolution: 256 4 | save_sdf: False 5 | 6 | model: 7 | name: graph_vae 8 | 9 | channel: 4 10 | depth: &depth 8 11 | nout: 4 12 | depth_out: *depth 13 | full_depth: &full_depth 4 14 | depth_stop: 6 15 | bottleneck: 4 16 | resblock_type: basic 17 | code_channel: 16 18 | resblk_num: 2 19 | 20 | embed_dim: 3 21 | n_embed: 8192 22 | use_checkpoint: True 23 | 24 | data: 25 | train: 26 | name: shapenet 27 | 28 | # octree building 29 | depth: *depth 30 | offset: 0.0 31 | full_depth: *full_depth 32 | node_dis: True 33 | split_label: True 34 | 35 | # no data augmentation 36 | disable: False 37 | distort: False 38 | 39 | # data loading 40 | location: &location data/Objaverse/objaverse-select 41 | filelist: &filelist data/Objaverse/filelist 42 | load_octree: &load_octree False 43 | load_pointcloud: &load_pointcloud True 44 | load_split_small: &load_split_small False 45 | load_split_large: &load_split_large False 46 | load_sdf: &load_sdf False 47 | load_occu: &load_occu False 48 | load_color: &load_color False 49 | batch_size: &batch_size 2 50 | shuffle: True 51 | 52 | point_scale: &point_scale 1.0 53 | point_sample_num: 10000 54 | sample_surf_points: False 55 | in_memory: False 56 | num_workers: 8 57 | 58 | 59 | test: 60 | name: shapenet 61 | 62 | # octree building 63 | depth: *depth 64 | offset: 0.0 65 | full_depth: *full_depth 66 | node_dis: True 67 | split_label: True 68 | 69 | # no data augmentation 70 | disable: False 71 | distort: False 72 | 73 | # data loading 74 | location: *location 75 | filelist: *filelist 76 | batch_size: 1 77 | load_octree: *load_octree 78 | load_pointcloud: *load_pointcloud 79 | load_split_small: *load_split_small 80 | load_split_large: *load_split_large 81 | load_sdf: *load_sdf 82 | load_occu: *load_occu 83 | load_color: *load_color 84 | shuffle: False 85 | 86 | point_scale: *point_scale 87 | point_sample_num: 10000 88 | sample_surf_points: False 89 | in_memory: False 90 | num_workers: 8 91 | 92 | loss: 93 | name: geometry 94 | loss_type: sdf_reg_loss 95 | kl_weight: 0.1 -------------------------------------------------------------------------------- /utils/visualizer.py: -------------------------------------------------------------------------------- 1 | 2 | import pickle 3 | from collections import OrderedDict 4 | import os 5 | import ntpath 6 | import time 7 | 8 | from termcolor import colored 9 | from . import util 10 | 11 | import torch 12 | import imageio 13 | import numpy as np 14 | import cv2 15 | import matplotlib.pyplot as plt 16 | 17 | class Visualizer(): 18 | def __init__(self, opt): 19 | # self.opt = opt 20 | self.isTrain = opt.isTrain 21 | self.gif_fps = 4 22 | 23 | if self.isTrain: 24 | # self.log_dir = os.path.join(opt.checkpoints_dir, opt.name) 25 | self.log_dir = os.path.join(opt.logs_dir, opt.name) 26 | 27 | self.train_img_dir = os.path.join(self.log_dir, 'train_temp') 28 | self.test_img_dir = os.path.join(self.log_dir, 'test_temp') 29 | 30 | self.name = opt.name 31 | self.opt = opt 32 | 33 | def setup_io(self): 34 | 35 | if self.isTrain: 36 | print('[*] create image directory:\n%s...' % os.path.abspath(self.train_img_dir) ) 37 | print('[*] create image directory:\n%s...' 
% os.path.abspath(self.test_img_dir) ) 38 | util.mkdirs([self.train_img_dir, self.test_img_dir]) 39 | # self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 40 | 41 | self.log_name = os.path.join(self.log_dir, 'loss_log.txt') 42 | # with open(self.log_name, "a") as log_file: 43 | with open(self.log_name, "w") as log_file: 44 | now = time.strftime("%c") 45 | log_file.write('================ Training Loss (%s) ================\n' % now) 46 | 47 | def reset(self): 48 | self.saved = False 49 | 50 | def print_current_errors(self, current_iters, errors, t): 51 | # message = '(GPU: %s, epoch: %d, iters: %d, time: %.3f) ' % (self.opt.gpu_ids_str, t) 52 | # message = f"[{self.opt.exp_time}] (GPU: {self.opt.gpu_ids_str}, iters: {current_iters}, time: {t:.3f}) " 53 | message = f"[{self.opt.name}] (GPU: {self.opt.gpu_ids_str}, iters: {current_iters}, time: {t:.3f}) " 54 | for k, v in errors.items(): 55 | message += '%s: %.6f ' % (k, v) 56 | 57 | print(colored(message, 'magenta')) 58 | with open(self.log_name, "a") as log_file: 59 | log_file.write('%s\n' % message) 60 | 61 | self.log_tensorboard_errors(errors, current_iters) 62 | 63 | 64 | def log_tensorboard_errors(self, errors, cur_step): 65 | writer = self.opt.writer 66 | 67 | for label, error in errors.items(): 68 | writer.add_scalar('losses/%s' % label, error, cur_step) 69 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | class TrainOptions(BaseOptions): 4 | 5 | def initialize(self): 6 | BaseOptions.initialize(self) 7 | self.parser.add_argument('--lr', type=float, default=1e-4, help='initial learning rate for adam') 8 | self.parser.add_argument('--min_lr', type=float, default=1e-6, help='minimum learning rate') 9 | self.parser.add_argument('--update_learning_rate', type=int, default=0, help='whether to update learning rate') 10 | self.parser.add_argument('--warmup_epochs', type=float, default=0, help='number of warmup epochs') 11 | self.parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau') 12 | self.parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') 13 | # self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 14 | 15 | 16 | # display stuff 17 | self.parser.add_argument('--display_freq', type=int, default=3000, help='frequency of showing training results on screen') 18 | self.parser.add_argument('--print_freq', type=int, default=25, help='frequency of showing training results on console') 19 | self.parser.add_argument('--ckpt_num', type=int, default=5, help='the number of checkpoints to keep') 20 | 21 | self.parser.add_argument('--save_latest_freq', type=int, default=500, help='frequency of saving the latest results') 22 | self.parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') 23 | self.parser.add_argument('--save_steps_freq', type=int, default=1000, help='frequency of saving checkpoints') 24 | self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 25 | self.parser.add_argument('--ema_rate', type=float, default=0.999, help='the rate of Exponential Moving Average') 26 | 27 | self.parser.add_argument('--which_epoch', type=str, 
default='latest', help='which epoch to load? set to latest to use latest cached model') 28 | self.parser.add_argument('--total_iters', type=int, default=100000000, help='# of iter for training') 29 | self.parser.add_argument('--epochs', type=int, default=4000, help='# of epochs for training') 30 | self.parser.add_argument('--start_iter', type=int, default=0, help='starting iteration') 31 | 32 | self.parser.add_argument('--mode', type=str, default='train', help='running mode: train | generate', choices=["train", "generate"]) 33 | self.parser.add_argument('--isTrain', type=str, default='True', help='whether the model is in training mode') 34 | -------------------------------------------------------------------------------- /models/networks/bert_networks/network.py: -------------------------------------------------------------------------------- 1 | # Reference: https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/encoders/modules.py 2 | 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from models.networks.bert_networks.x_transformer import Encoder, TransformerWrapper 8 | 9 | # from models.networks.transformer_networks.x_transformer import Encoder, TransformerWrapper 10 | 11 | class AbstractEncoder(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def encode(self, *args, **kwargs): 16 | raise NotImplementedError 17 | 18 | class BERTTokenizer(AbstractEncoder): 19 | """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" 20 | def __init__(self, device="cuda", vq_interface=True, max_length=77): 21 | super().__init__() 22 | from transformers import BertTokenizerFast # TODO: add to requirements 23 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") 24 | self.device = device 25 | # self.vq_interface = vq_interface 26 | self.vq_interface = False # NOTE: currently set to false. 
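# when vq_interface is True, encode() mimics a VQ-style interface and returns (None, None, [None, None, tokens]); when False it returns the token ids directly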
27 | self.max_length = max_length 28 | 29 | def forward(self, text): 30 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 31 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 32 | tokens = batch_encoding["input_ids"].to(self.device) 33 | return tokens 34 | 35 | @torch.no_grad() 36 | def encode(self, text): 37 | tokens = self(text) 38 | if not self.vq_interface: 39 | return tokens 40 | return None, None, [None, None, tokens] 41 | 42 | def decode(self, text): 43 | return text 44 | 45 | 46 | class BERTTextEncoder(AbstractEncoder): 47 | """Uses the BERT tokenizr model and add some transformer encoder layers""" 48 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, 49 | device="cuda",use_tokenizer=True, embedding_dropout=0.0): 50 | super().__init__() 51 | self.use_tknz_fn = use_tokenizer 52 | if self.use_tknz_fn: 53 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) 54 | self.device = device 55 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 56 | attn_layers=Encoder(dim=n_embed, depth=n_layer), 57 | emb_dropout=embedding_dropout) 58 | 59 | def forward(self, text): 60 | if self.use_tknz_fn: 61 | tokens = self.tknz_fn(text)#.to(self.device) 62 | else: 63 | tokens = text 64 | z = self.transformer(tokens, return_embeddings=True) 65 | return z 66 | 67 | def encode(self, text): 68 | # output of length 77 69 | return self(text) -------------------------------------------------------------------------------- /models/networks/dualoctree_networks/modules.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import math 9 | import torch 10 | import torch.nn 11 | import torch.nn.init 12 | import torch.nn.functional as F 13 | import torch.utils.checkpoint 14 | # import torch_geometric.nn 15 | 16 | from .utils.scatter import scatter_mean 17 | from ocnn.octree import key2xyz, xyz2key 18 | 19 | from ocnn.octree import Octree 20 | from ocnn.utils import scatter_add 21 | from models.networks.modules import ( 22 | nonlinearity, 23 | ckpt_conv_wrapper, 24 | DualOctreeGroupNorm, 25 | Conv1x1, 26 | Conv1x1Gn, 27 | Conv1x1GnGelu, 28 | Conv1x1GnGeluSequential, 29 | Downsample, 30 | Upsample, 31 | GraphConv, 32 | GraphResBlock, 33 | GraphResBlocks, 34 | GraphDownsample, 35 | GraphUpsample, 36 | ) 37 | 38 | 39 | class GraphDownsample(torch.nn.Module): 40 | 41 | def __init__(self, channels_in, channels_out=None): 42 | super().__init__() 43 | self.channels_in = channels_in 44 | self.channels_out = channels_out or channels_in 45 | self.downsample = Downsample(channels_in) 46 | if self.channels_in != self.channels_out: 47 | self.conv1x1 = Conv1x1GnGelu(self.channels_in, self.channels_out) 48 | 49 | def forward(self, x, octree, d, leaf_mask, numd, lnumd): 50 | # downsample nodes at layer depth 51 | outd = x[-numd:] 52 | outd = self.downsample(outd) 53 | 54 | # get the nodes at layer (depth-1) 55 | out = torch.zeros(leaf_mask.shape[0], x.shape[1], device=x.device) 56 | out[leaf_mask] = x[-lnumd-numd:-numd] 57 | out[leaf_mask.logical_not()] = outd 58 | 59 | # construct the final output 60 | out = torch.cat([x[:-numd-lnumd], out], dim=0) 61 
| 62 | if self.channels_in != self.channels_out: 63 | out = self.conv1x1(out, octree, d) 64 | return out 65 | 66 | def extra_repr(self): 67 | return 'channels_in={}, channels_out={}'.format( 68 | self.channels_in, self.channels_out) 69 | 70 | 71 | class GraphUpsample(torch.nn.Module): 72 | 73 | def __init__(self, channels_in, channels_out=None): 74 | super().__init__() 75 | self.channels_in = channels_in 76 | self.channels_out = channels_out or channels_in 77 | self.upsample = Upsample(channels_in) 78 | if self.channels_in != self.channels_out: 79 | self.conv1x1 = Conv1x1GnGelu(self.channels_in, self.channels_out) 80 | 81 | def forward(self, x, octree, d, leaf_mask, numd): 82 | # upsample nodes at layer (depth-1) 83 | outd = x[-numd:] 84 | out1 = outd[leaf_mask.logical_not()] 85 | out1 = self.upsample(out1) 86 | 87 | # construct the final output 88 | out = torch.cat([x[:-numd], outd[leaf_mask], out1], dim=0) 89 | if self.channels_in != self.channels_out: 90 | out = self.conv1x1(out, octree, d) 91 | return out 92 | 93 | def extra_repr(self): 94 | return 'channels_in={}, channels_out={}'.format( 95 | self.channels_in, self.channels_out) 96 | -------------------------------------------------------------------------------- /scripts/run_snet_vae.sh: -------------------------------------------------------------------------------- 1 | RED='\033[0;31m' 2 | NC='\033[0m' # No Color 3 | DATE_WITH_TIME=`date "+%Y-%m-%dT%H-%M-%S"` 4 | 5 | logs_dir='logs' 6 | 7 | ### set gpus ### 8 | # gpu_ids=0 # single-gpu 9 | gpu_ids=4,5,6,7 # multi-gpu 10 | 11 | if [ ${#gpu_ids} -gt 1 ]; then 12 | # specify these two if multi-gpu 13 | # NGPU=2 14 | # NGPU=3 15 | NGPU=4 16 | HOST_NODE_ADDR="localhost:27000" 17 | echo "HERE" 18 | fi 19 | ################ 20 | 21 | ### hyper params ### 22 | lr=1e-3 23 | min_lr=1e-6 24 | update_learning_rate=0 25 | warmup_epochs=40 26 | epochs=300 27 | batch_size=2 28 | ema_rate=0.999 29 | ckpt_num=3 30 | seed=42 31 | #################### 32 | 33 | ### model stuff ### 34 | model='vae' 35 | mode="$1" 36 | stage_flag="$2" 37 | dataset_mode='snet' 38 | note="test" 39 | category="$3" 40 | 41 | df_yaml="octfusion_${dataset_mode}_uncond.yaml" 42 | df_cfg="configs/${df_yaml}" 43 | vq_model="GraphVAE" 44 | vq_yaml="vae_${dataset_mode}_train.yaml" 45 | vq_cfg="configs/${vq_yaml}" 46 | vq_ckpt="saved_ckpt/vae-ckpt/vae-shapenet-depth-8.pth" 47 | 48 | ##################### 49 | 50 | ### display & log stuff ### 51 | display_freq=3000 52 | print_freq=25 53 | save_steps_freq=3000 54 | save_latest_freq=500 55 | ########################### 56 | 57 | 58 | today=$(date '+%m%d') 59 | me=`basename "$0"` 60 | me=$(echo $me | cut -d'.' 
-f 1) 61 | 62 | name="${category}_union/${note}_${dataset_mode}_lr${lr}" 63 | 64 | debug=0 65 | if [ "$mode" = "generate" ] || [ "$mode" = "inference_vae" ]; then 66 | df_cfg="${logs_dir}/${name}/${df_yaml}" 67 | vq_cfg="${logs_dir}/${name}/${vq_yaml}" 68 | ckpt="${logs_dir}/${name}/ckpt/df_steps-latest.pth" 69 | fi 70 | 71 | cmd="train.py --name ${name} --logs_dir ${logs_dir} --gpu_ids ${gpu_ids} --mode ${mode} \ 72 | --lr ${lr} --epochs ${epochs} --min_lr ${min_lr} --warmup_epochs ${warmup_epochs} --update_learning_rate ${update_learning_rate} --ema_rate ${ema_rate} --seed ${seed} \ 73 | --model ${model} --df_cfg ${df_cfg} --ckpt_num ${ckpt_num} --category ${category} \ 74 | --vq_model ${vq_model} --vq_cfg ${vq_cfg} \ 75 | --display_freq ${display_freq} --print_freq ${print_freq} \ 76 | --save_steps_freq ${save_steps_freq} --save_latest_freq ${save_latest_freq} \ 77 | --debug ${debug}" 78 | 79 | if [ ! -z "$ckpt" ]; then 80 | cmd="${cmd} --ckpt ${ckpt}" 81 | echo "continue training with ckpt=${ckpt}" 82 | fi 83 | if [ ! -z "$pretrain_ckpt" ]; then 84 | cmd="${cmd} --pretrain_ckpt ${pretrain_ckpt}" 85 | echo "second stage training with pretrain_ckpt=${pretrain_ckpt}" 86 | fi 87 | if [ $mode = "generate" ]; then 88 | cmd="${cmd} --vq_ckpt ${vq_ckpt}" 89 | fi 90 | 91 | multi_gpu=0 92 | if [ ${#gpu_ids} -gt 1 ]; then 93 | multi_gpu=1 94 | fi 95 | 96 | echo "[*] Training is starting on `hostname`, GPU#: ${gpu_ids}, logs_dir: ${logs_dir}" 97 | 98 | echo "[*] Training with command: " 99 | 100 | if [ $multi_gpu = 1 ]; then 101 | 102 | cmd="--nnodes=1 --nproc_per_node=${NGPU} --rdzv-backend=c10d --rdzv-endpoint=${HOST_NODE_ADDR} ${cmd}" 103 | echo "CUDA_VISIBLE_DEVICES=${gpu_ids} torchrun ${cmd}" 104 | CUDA_VISIBLE_DEVICES=${gpu_ids} torchrun ${cmd} 105 | 106 | else 107 | 108 | echo "CUDA_VISIBLE_DEVICES=${gpu_ids} python3 ${cmd}" 109 | CUDA_VISIBLE_DEVICES=${gpu_ids} python3 ${cmd} 110 | 111 | fi 112 | -------------------------------------------------------------------------------- /models/networks/dualoctree_networks/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | # return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | # + self.var - 1.0 - self.logvar, 46 | # dim=[1]) 47 | return 0.5 * (torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar) 48 | else: 49 | return 0.5 * torch.sum( 
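# closed-form KL between two diagonal Gaussians: 0.5 * sum((mu1 - mu2)^2 / var2 + var1 / var2 - 1 - logvar1 + logvar2), summed over all non-batch dimensions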
50 | torch.pow(self.mean - other.mean, 2) / other.var 51 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 52 | dim=[1, 2, 3, 4]) 53 | 54 | def nll(self, sample, dims=[1,2,3,4]): 55 | if self.deterministic: 56 | return torch.Tensor([0.]) 57 | logtwopi = np.log(2.0 * np.pi) 58 | return 0.5 * torch.sum( 59 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 60 | dim=dims) 61 | 62 | def mode(self): 63 | return self.mean 64 | 65 | 66 | def normal_kl(mean1, logvar1, mean2, logvar2): 67 | """ 68 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 69 | Compute the KL divergence between two gaussians. 70 | Shapes are automatically broadcasted, so batches can be compared to 71 | scalars, among other use cases. 72 | """ 73 | tensor = None 74 | for obj in (mean1, logvar1, mean2, logvar2): 75 | if isinstance(obj, torch.Tensor): 76 | tensor = obj 77 | break 78 | assert tensor is not None, "at least one argument must be a Tensor" 79 | 80 | # Force variances to be Tensors. Broadcasting helps convert scalars to 81 | # Tensors, but it does not work for torch.exp(). 82 | logvar1, logvar2 = [ 83 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 84 | for x in (logvar1, logvar2) 85 | ] 86 | 87 | return 0.5 * ( 88 | -1.0 89 | + logvar2 90 | - logvar1 91 | + torch.exp(logvar1 - logvar2) 92 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 93 | ) 94 | -------------------------------------------------------------------------------- /scripts/run_snet_cond.sh: -------------------------------------------------------------------------------- 1 | RED='\033[0;31m' 2 | NC='\033[0m' # No Color 3 | DATE_WITH_TIME=`date "+%Y-%m-%dT%H-%M-%S"` 4 | 5 | logs_dir='logs' 6 | 7 | ### set gpus ### 8 | gpu_ids=0,1,2,3 # multi-gpu 9 | 10 | if [ ${#gpu_ids} -gt 1 ]; then 11 | # specify these two if multi-gpu 12 | NGPU=4 13 | HOST_NODE_ADDR="localhost:25000" 14 | echo "HERE" 15 | fi 16 | ################ 17 | 18 | ### model stuff ### 19 | model='union_2t' 20 | mode="$1" 21 | stage_flag="$2" 22 | dataset_mode='snet' 23 | note="test" 24 | category="$3" 25 | 26 | df_yaml="octfusion_${dataset_mode}_cond.yaml" 27 | df_cfg="configs/${df_yaml}" 28 | vq_model="GraphVAE" 29 | vq_yaml="vae_${dataset_mode}_eval.yaml" 30 | vq_cfg="configs/${vq_yaml}" 31 | vq_ckpt="saved_ckpt/vae-ckpt/vae-shapenet-depth-8.pth" 32 | 33 | ### hyper params ### 34 | lr=2e-4 35 | min_lr=1e-6 36 | update_learning_rate=0 37 | warmup_epochs=40 38 | ema_rate=0.999 39 | ckpt_num=3 40 | seed=42 41 | 42 | if [ $stage_flag = "lr" ]; then 43 | epochs=3000 44 | batch_size=16 45 | else 46 | epochs=500 47 | batch_size=2 48 | fi 49 | 50 | if [ $mode = "train" ]; then 51 | pretrain_ckpt="saved_ckpt/diffusion-ckpt/im_5/df_steps-split.pth" 52 | else 53 | ckpt="saved_ckpt/diffusion-ckpt/im_5/df_steps-union.pth" 54 | fi 55 | 56 | #################### 57 | 58 | ##################### 59 | 60 | ### display & log stuff ### 61 | display_freq=1000 62 | print_freq=25 63 | save_steps_freq=3000 64 | save_latest_freq=500 65 | ########################### 66 | 67 | 68 | today=$(date '+%m%d') 69 | me=`basename "$0"` 70 | me=$(echo $me | cut -d'.' 
-f 1) 71 | 72 | name="${category}_union/${model}_${note}_lr${lr}" 73 | 74 | debug=0 75 | 76 | cmd="train.py --name ${name} --logs_dir ${logs_dir} --gpu_ids ${gpu_ids} --mode ${mode} \ 77 | --lr ${lr} --epochs ${epochs} --min_lr ${min_lr} --warmup_epochs ${warmup_epochs} --update_learning_rate ${update_learning_rate} --ema_rate ${ema_rate} --seed ${seed} \ 78 | --model ${model} --stage_flag ${stage_flag} --df_cfg ${df_cfg} --ckpt_num ${ckpt_num} --category ${category} \ 79 | --vq_model ${vq_model} --vq_cfg ${vq_cfg} --vq_ckpt ${vq_ckpt} \ 80 | --display_freq ${display_freq} --print_freq ${print_freq} \ 81 | --save_steps_freq ${save_steps_freq} --save_latest_freq ${save_latest_freq} \ 82 | --debug ${debug}" 83 | 84 | if [ ! -z "$ckpt" ]; then 85 | cmd="${cmd} --ckpt ${ckpt}" 86 | echo "continue training with ckpt=${ckpt}" 87 | fi 88 | if [ ! -z "$pretrain_ckpt" ]; then 89 | cmd="${cmd} --pretrain_ckpt ${pretrain_ckpt}" 90 | echo "second stage training with pretrain_ckpt=${pretrain_ckpt}" 91 | fi 92 | if [ ! -z "$split_dir" ]; then 93 | cmd="${cmd} --split_dir ${split_dir}" 94 | echo "generate with split_dir=${split_dir}" 95 | fi 96 | 97 | multi_gpu=0 98 | if [ ${#gpu_ids} -gt 1 ]; then 99 | multi_gpu=1 100 | fi 101 | 102 | echo "[*] Training is starting on `hostname`, GPU#: ${gpu_ids}, logs_dir: ${logs_dir}" 103 | 104 | echo "[*] Training with command: " 105 | 106 | if [ $multi_gpu = 1 ]; then 107 | 108 | cmd="--nnodes=1 --nproc_per_node=${NGPU} --rdzv-backend=c10d --rdzv-endpoint=${HOST_NODE_ADDR} ${cmd}" 109 | echo "CUDA_VISIBLE_DEVICES=${gpu_ids} torchrun ${cmd}" 110 | CUDA_VISIBLE_DEVICES=${gpu_ids} torchrun ${cmd} 111 | 112 | else 113 | 114 | echo "CUDA_VISIBLE_DEVICES=${gpu_ids} python3 ${cmd}" 115 | CUDA_VISIBLE_DEVICES=${gpu_ids} python3 ${cmd} 116 | 117 | fi 118 | -------------------------------------------------------------------------------- /scripts/run_snet_uncond.sh: -------------------------------------------------------------------------------- 1 | RED='\033[0;31m' 2 | NC='\033[0m' # No Color 3 | DATE_WITH_TIME=`date "+%Y-%m-%dT%H-%M-%S"` 4 | 5 | logs_dir='logs' 6 | 7 | ### set gpus ### 8 | gpu_ids=0,1,2,3 # multi-gpu 9 | 10 | if [ ${#gpu_ids} -gt 1 ]; then 11 | # specify these two if multi-gpu 12 | NGPU=4 13 | HOST_NODE_ADDR="localhost:25000" 14 | echo "HERE" 15 | fi 16 | ################ 17 | 18 | ### model stuff ### 19 | model='union_2t' 20 | mode="$1" 21 | stage_flag="$2" 22 | dataset_mode='snet' 23 | note="test" 24 | category="$3" 25 | 26 | df_yaml="octfusion_${dataset_mode}_uncond.yaml" 27 | df_cfg="configs/${df_yaml}" 28 | vq_model="GraphVAE" 29 | vq_yaml="vae_${dataset_mode}_eval.yaml" 30 | vq_cfg="configs/${vq_yaml}" 31 | vq_ckpt="saved_ckpt/vae-ckpt/vae-shapenet-depth-8.pth" 32 | 33 | ### hyper params ### 34 | lr=2e-4 35 | min_lr=1e-6 36 | update_learning_rate=0 37 | warmup_epochs=40 38 | ema_rate=0.999 39 | ckpt_num=3 40 | seed=42 41 | 42 | if [ $stage_flag = "lr" ]; then 43 | epochs=3000 44 | batch_size=16 45 | else 46 | epochs=500 47 | batch_size=2 48 | fi 49 | 50 | if [ $mode = "train" ]; then 51 | pretrain_ckpt="saved_ckpt/diffusion-ckpt/${category}/df_steps-split.pth" 52 | else 53 | ckpt="saved_ckpt/diffusion-ckpt/${category}/df_steps-union.pth" 54 | fi 55 | 56 | #################### 57 | 58 | ##################### 59 | 60 | ### display & log stuff ### 61 | display_freq=1000 62 | print_freq=25 63 | save_steps_freq=3000 64 | save_latest_freq=500 65 | ########################### 66 | 67 | 68 | today=$(date '+%m%d') 69 | me=`basename "$0"` 70 | me=$(echo $me | 
cut -d'.' -f 1) 71 | 72 | name="${category}_union/${model}_${note}_lr${lr}" 73 | 74 | debug=0 75 | 76 | cmd="train.py --name ${name} --logs_dir ${logs_dir} --gpu_ids ${gpu_ids} --mode ${mode} \ 77 | --lr ${lr} --epochs ${epochs} --min_lr ${min_lr} --warmup_epochs ${warmup_epochs} --update_learning_rate ${update_learning_rate} --ema_rate ${ema_rate} --seed ${seed} \ 78 | --model ${model} --stage_flag ${stage_flag} --df_cfg ${df_cfg} --ckpt_num ${ckpt_num} --category ${category} \ 79 | --vq_model ${vq_model} --vq_cfg ${vq_cfg} --vq_ckpt ${vq_ckpt} \ 80 | --display_freq ${display_freq} --print_freq ${print_freq} \ 81 | --save_steps_freq ${save_steps_freq} --save_latest_freq ${save_latest_freq} \ 82 | --debug ${debug}" 83 | 84 | if [ ! -z "$ckpt" ]; then 85 | cmd="${cmd} --ckpt ${ckpt}" 86 | echo "continue training with ckpt=${ckpt}" 87 | fi 88 | if [ ! -z "$pretrain_ckpt" ]; then 89 | cmd="${cmd} --pretrain_ckpt ${pretrain_ckpt}" 90 | echo "second stage training with pretrain_ckpt=${pretrain_ckpt}" 91 | fi 92 | if [ ! -z "$split_dir" ]; then 93 | cmd="${cmd} --split_dir ${split_dir}" 94 | echo "generate with split_dir=${split_dir}" 95 | fi 96 | 97 | multi_gpu=0 98 | if [ ${#gpu_ids} -gt 1 ]; then 99 | multi_gpu=1 100 | fi 101 | 102 | echo "[*] Training is starting on `hostname`, GPU#: ${gpu_ids}, logs_dir: ${logs_dir}" 103 | 104 | echo "[*] Training with command: " 105 | 106 | if [ $multi_gpu = 1 ]; then 107 | 108 | cmd="--nnodes=1 --nproc_per_node=${NGPU} --rdzv-backend=c10d --rdzv-endpoint=${HOST_NODE_ADDR} ${cmd}" 109 | echo "CUDA_VISIBLE_DEVICES=${gpu_ids} torchrun ${cmd}" 110 | CUDA_VISIBLE_DEVICES=${gpu_ids} torchrun ${cmd} 111 | 112 | else 113 | 114 | echo "CUDA_VISIBLE_DEVICES=${gpu_ids} python3 ${cmd}" 115 | CUDA_VISIBLE_DEVICES=${gpu_ids} python3 ${cmd} 116 | 117 | fi 118 | -------------------------------------------------------------------------------- /models/networks/diffusion_networks/graph_unet_union.py: -------------------------------------------------------------------------------- 1 | ### adapted from: https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py 2 | 3 | from abc import abstractmethod 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from models.networks.diffusion_networks import graph_unet_lr, graph_unet_hr 9 | from random import random 10 | 11 | class UNet3DModel(nn.Module): 12 | 13 | 14 | # def __init__(self, config_dict): 15 | def __init__( 16 | self, 17 | stage_flag, 18 | image_size, 19 | input_depth, 20 | unet_type, 21 | full_depth, 22 | input_channels, 23 | out_channels, 24 | model_channels, 25 | num_res_blocks, 26 | attention_resolutions, 27 | channel_mult, 28 | num_heads, 29 | use_checkpoint, 30 | dims, 31 | num_classes=None, 32 | **kwargs, 33 | ): 34 | super().__init__() 35 | num_models = len(unet_type) 36 | self.unet_lr = None 37 | self.unet_hr = None 38 | self.unet_feature = None 39 | for i in range(num_models): 40 | if unet_type[i] == "lr": 41 | unet_model = graph_unet_lr.UNet3DModel( 42 | full_depth=full_depth, 43 | in_split_channels=input_channels[i], 44 | model_channels=model_channels[i], 45 | out_split_channels=out_channels[i], 46 | attention_resolutions=attention_resolutions, 47 | channel_mult=channel_mult[i], 48 | use_checkpoint=use_checkpoint, 49 | num_heads=num_heads, 50 | dims=dims, 51 | num_classes=num_classes, 52 | ) 53 | self.unet_lr = unet_model 54 | elif unet_type[i] == "hr" or unet_type[i] == "feature": 55 | unet_model = graph_unet_hr.UNet3DModel( 56 | 
image_size=image_size[i], 57 | input_depth=input_depth[i], 58 | full_depth=full_depth, 59 | in_channels=input_channels[i], 60 | model_channels=model_channels[i], 61 | lr_model_channels=model_channels[i - 1], 62 | out_channels=out_channels[i], 63 | num_res_blocks=num_res_blocks[i], 64 | channel_mult=channel_mult[i], 65 | dims=dims, 66 | use_checkpoint=use_checkpoint, 67 | num_heads=num_heads, 68 | num_classes=num_classes, 69 | ) 70 | if unet_type[i] == "hr": 71 | self.unet_hr = unet_model 72 | if unet_type[i] == "feature": 73 | self.unet_feature = unet_model 74 | else: 75 | raise ValueError 76 | if unet_type[i] == stage_flag: 77 | break 78 | 79 | 80 | def forward(self, unet_type=None, **input_data): 81 | if unet_type == "lr": 82 | if 'self_cond' not in input_data and random() < 0.5: 83 | with torch.no_grad(): 84 | self_cond = self.unet_lr(**input_data) 85 | input_data['self_cond'] = self_cond 86 | return self.unet_lr(**input_data) 87 | elif unet_type == "hr": 88 | return self.unet_hr(**input_data) 89 | elif unet_type == "feature": 90 | return self.unet_feature(**input_data) 91 | else: 92 | raise ValueError 93 | 94 | -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | """ 2 | source: https://github.com/rosinality/stylegan2-pytorch/blob/master/distributed.py 3 | """ 4 | 5 | import math 6 | import pickle 7 | 8 | import torch 9 | from torch import distributed as dist 10 | from torch.utils.data.sampler import Sampler 11 | 12 | 13 | def get_rank(): 14 | if not dist.is_available(): 15 | return 0 16 | 17 | if not dist.is_initialized(): 18 | return 0 19 | 20 | return dist.get_rank() 21 | 22 | 23 | def synchronize(local_rank=0): 24 | if not dist.is_available(): 25 | return 26 | 27 | if not dist.is_initialized(): 28 | return 29 | 30 | world_size = dist.get_world_size() 31 | 32 | if world_size == 1: 33 | return 34 | 35 | dist.barrier() 36 | # dist.barrier(device_ids=[local_rank]) 37 | 38 | 39 | def get_world_size(): 40 | if not dist.is_available(): 41 | return 1 42 | 43 | if not dist.is_initialized(): 44 | return 1 45 | 46 | return dist.get_world_size() 47 | 48 | 49 | def reduce_sum(tensor): 50 | if not dist.is_available(): 51 | return tensor 52 | 53 | if not dist.is_initialized(): 54 | return tensor 55 | 56 | tensor = tensor.clone() 57 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 58 | 59 | return tensor 60 | 61 | 62 | def gather_grad(params): 63 | world_size = get_world_size() 64 | 65 | if world_size == 1: 66 | return 67 | 68 | for param in params: 69 | if param.grad is not None: 70 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 71 | param.grad.data.div_(world_size) 72 | 73 | 74 | def all_gather(data): 75 | world_size = get_world_size() 76 | 77 | if world_size == 1: 78 | return [data] 79 | 80 | buffer = pickle.dumps(data) 81 | storage = torch.ByteStorage.from_buffer(buffer) 82 | tensor = torch.ByteTensor(storage).to('cuda') 83 | 84 | local_size = torch.IntTensor([tensor.numel()]).to('cuda') 85 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] 86 | dist.all_gather(size_list, local_size) 87 | size_list = [int(size.item()) for size in size_list] 88 | max_size = max(size_list) 89 | 90 | tensor_list = [] 91 | for _ in size_list: 92 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) 93 | 94 | if local_size != max_size: 95 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') 96 | tensor = 
torch.cat((tensor, padding), 0) 97 | 98 | dist.all_gather(tensor_list, tensor) 99 | 100 | data_list = [] 101 | 102 | for size, tensor in zip(size_list, tensor_list): 103 | buffer = tensor.cpu().numpy().tobytes()[:size] 104 | data_list.append(pickle.loads(buffer)) 105 | 106 | return data_list 107 | 108 | 109 | def reduce_loss_dict(loss_dict): 110 | world_size = get_world_size() 111 | # print(world_size) 112 | 113 | if world_size < 2: 114 | return loss_dict 115 | 116 | with torch.no_grad(): 117 | keys = [] 118 | losses = [] 119 | 120 | for k in sorted(loss_dict.keys()): 121 | keys.append(k) 122 | losses.append(loss_dict[k]) 123 | 124 | try: 125 | losses = torch.stack(losses, 0) 126 | except: 127 | print(losses) 128 | dist.reduce(losses, dst=0) 129 | 130 | if dist.get_rank() == 0: 131 | losses /= world_size 132 | 133 | reduced_losses = {k: v for k, v in zip(keys, losses)} 134 | 135 | return reduced_losses -------------------------------------------------------------------------------- /metrics/pytorch_structural_losses/Makefile: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Uncomment for debugging 3 | # DEBUG := 1 4 | # Pretty build 5 | # Q ?= @ 6 | 7 | CXX := g++ 8 | PYTHON := python 9 | NVCC := /usr/local/cuda/bin/nvcc 10 | 11 | # PYTHON Header path 12 | PYTHON_HEADER_DIR := $(shell $(PYTHON) -c 'from distutils.sysconfig import get_python_inc; print(get_python_inc())') 13 | PYTORCH_INCLUDES := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import include_paths; [print(p) for p in include_paths()]') 14 | PYTORCH_LIBRARIES := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import library_paths; [print(p) for p in library_paths()]') 15 | 16 | # CUDA ROOT DIR that contains bin/ lib64/ and include/ 17 | # CUDA_DIR := /usr/local/cuda 18 | CUDA_DIR := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import _find_cuda_home; print(_find_cuda_home())') 19 | 20 | INCLUDE_DIRS := ./ $(CUDA_DIR)/include 21 | 22 | INCLUDE_DIRS += $(PYTHON_HEADER_DIR) 23 | INCLUDE_DIRS += $(PYTORCH_INCLUDES) 24 | 25 | # Custom (MKL/ATLAS/OpenBLAS) include and lib directories. 26 | # Leave commented to accept the defaults for your choice of BLAS 27 | # (which should work)! 28 | # BLAS_INCLUDE := /path/to/your/blas 29 | # BLAS_LIB := /path/to/your/blas 30 | 31 | ############################################################################### 32 | SRC_DIR := ./src 33 | OBJ_DIR := ./objs 34 | CPP_SRCS := $(wildcard $(SRC_DIR)/*.cpp) 35 | CU_SRCS := $(wildcard $(SRC_DIR)/*.cu) 36 | OBJS := $(patsubst $(SRC_DIR)/%.cpp,$(OBJ_DIR)/%.o,$(CPP_SRCS)) 37 | CU_OBJS := $(patsubst $(SRC_DIR)/%.cu,$(OBJ_DIR)/cuda/%.o,$(CU_SRCS)) 38 | STATIC_LIB := $(OBJ_DIR)/libmake_pytorch.a 39 | 40 | # CUDA architecture setting: going with all of them. 41 | # For CUDA < 6.0, comment the *_50 through *_61 lines for compatibility. 42 | # For CUDA < 8.0, comment the *_60 and *_61 lines for compatibility. 43 | CUDA_ARCH := -gencode arch=compute_61,code=sm_61 \ 44 | -gencode arch=compute_61,code=compute_61 \ 45 | -gencode arch=compute_52,code=sm_52 46 | 47 | # We will also explicitly add stdc++ to the link target. 
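# NOTE: the caffe2 / caffe2_gpu libraries below are only shipped by older PyTorch builds (roughly pre-1.5); on newer PyTorch they may need to be removed from LIBRARIES for the link step to succeed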
48 | LIBRARIES += stdc++ cudart c10 caffe2 torch torch_python caffe2_gpu 49 | 50 | # Debugging 51 | ifeq ($(DEBUG), 1) 52 | COMMON_FLAGS += -DDEBUG -g -O0 53 | # https://gcoe-dresden.de/reaching-the-shore-with-a-fog-warning-my-eurohack-day-4-morning-session/ 54 | NVCCFLAGS += -g -G # -rdc true 55 | else 56 | COMMON_FLAGS += -DNDEBUG -O3 57 | endif 58 | 59 | WARNINGS := -Wall -Wno-sign-compare -Wcomment 60 | 61 | INCLUDE_DIRS += $(BLAS_INCLUDE) 62 | 63 | # Automatic dependency generation (nvcc is handled separately) 64 | CXXFLAGS += -MMD -MP 65 | 66 | # Complete build flags. 67 | COMMON_FLAGS += $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir)) \ 68 | -DTORCH_API_INCLUDE_EXTENSION_H -D_GLIBCXX_USE_CXX11_ABI=0 69 | CXXFLAGS += -pthread -fPIC -fwrapv -std=c++17 $(COMMON_FLAGS) $(WARNINGS) 70 | NVCCFLAGS += -std=c++17 -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS) 71 | 72 | all: $(STATIC_LIB) 73 | $(PYTHON) setup.py build 74 | @ mv build/lib.linux-x86_64-cpython-310/StructuralLosses .. 75 | @ mv build/lib.linux-x86_64-cpython-310/*.so ../StructuralLosses/ 76 | @- $(RM) -rf $(OBJ_DIR) build objs 77 | 78 | $(OBJ_DIR): 79 | @ mkdir -p $@ 80 | @ mkdir -p $@/cuda 81 | 82 | $(OBJ_DIR)/%.o: $(SRC_DIR)/%.cpp | $(OBJ_DIR) 83 | @ echo CXX $< 84 | $(Q)$(CXX) $< $(CXXFLAGS) -c -o $@ 85 | 86 | $(OBJ_DIR)/cuda/%.o: $(SRC_DIR)/%.cu | $(OBJ_DIR) 87 | @ echo NVCC $< 88 | $(Q)$(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -M $< -o ${@:.o=.d} \ 89 | -odir $(@D) 90 | $(Q)$(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -c $< -o $@ 91 | 92 | $(STATIC_LIB): $(OBJS) $(CU_OBJS) | $(OBJ_DIR) 93 | $(RM) -f $(STATIC_LIB) 94 | $(RM) -rf build dist 95 | @ echo LD -o $@ 96 | ar rc $(STATIC_LIB) $(OBJS) $(CU_OBJS) 97 | 98 | clean: 99 | @- $(RM) -rf $(OBJ_DIR) build dist ../StructuralLosses 100 | -------------------------------------------------------------------------------- /datasets/shapenet_utils.py: -------------------------------------------------------------------------------- 1 | TSDF_VALUE = 1/32 2 | SDF_CLIP_VALUE = 0.05 3 | 4 | snc_category_to_synth_id_13 = { 5 | 'airplane': '02691156', 6 | 'bench': '02828884', 7 | 'cabinet': '02933112', 8 | 'car': '02958343', 9 | 'chair': '03001627', 10 | 'monitor': '03211117', 11 | 'lamp': '03636649', 12 | 'loudspeaker': '03691459', 13 | 'rifle': '04090263', 14 | 'sofa': '04256520', 15 | 'table': '04379243', 16 | 'telephone': '04401088', 17 | 'vessel': '04530566', 18 | } 19 | 20 | snc_synth_id_to_label_13 = { 21 | '02691156':0, 22 | '02828884':1, 23 | '02933112':2, 24 | '02958343':3, 25 | '03001627':4, 26 | '03211117':5, 27 | '03636649':6, 28 | '03691459':7, 29 | '04090263':8, 30 | '04256520':9, 31 | '04379243':10, 32 | '04401088':11, 33 | '04530566':12, 34 | } 35 | 36 | snc_synth_id_to_label_5 = { 37 | '02691156':0, 38 | '02958343':1, 39 | '03001627':2, 40 | '04379243':3, 41 | '04090263':4, 42 | } 43 | 44 | snc_category_to_synth_id_5 = { 45 | 'airplane': '02691156', 'car': '02958343', 'chair': '03001627', 46 | 'table': '04379243', 'rifle': '04090263' 47 | } 48 | 49 | 50 | snc_synth_id_to_category_5 = { 51 | '02691156': 'airplane', '02958343': 'car', '03001627': 'chair', 52 | '04379243': 'table', 53 | '04090263': 'rifle' 54 | } 55 | 56 | 57 | snc_synth_id_to_category_all = { 58 | '02691156': 'airplane', '02773838': 'bag', '02801938': 'basket', 59 | '02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench', 60 | '02834778': 'bicycle', '02843684': 'birdhouse', '02871439': 'bookshelf', 61 | '02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus', 62 | '02933112': 'cabinet', '02747177': 'can', 
'02942699': 'camera', 63 | '02954340': 'cap', '02958343': 'car', '03001627': 'chair', 64 | '03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor', 65 | '04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can', 66 | '04460130': 'tower', '04468005': 'train', '03085013': 'keyboard', 67 | '03261776': 'earphone', '03325088': 'faucet', '03337140': 'file', 68 | '03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar', 69 | '03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop', 70 | '03691459': 'loudspeaker', '03710193': 'mailbox', '03759954': 'microphone', 71 | '03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug', 72 | '03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol', 73 | '03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control', 74 | '04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard', 75 | '04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel', 76 | '04554684': 'washer', '02858304': 'boat', '02992529': 'cellphone' 77 | } 78 | 79 | 80 | snc_category_to_synth_id_all = { 81 | 'airplane': '02691156', 'bag': '02773838', 'basket': '02801938', 82 | 'bathtub': '02808440', 'bed': '02818832', 'bench': '02828884', 83 | 'bicycle': '02834778', 'birdhouse': '02843684', 'bookshelf': '02871439', 84 | 'bottle': '02876657', 'bowl': '02880940', 'bus': '02924116', 85 | 'cabinet': '02933112', 'can': '02747177', 'camera': '02942699', 86 | 'cap': '02954340', 'car': '02958343', 'chair': '03001627', 87 | 'clock': '03046257', 'dishwasher': '03207941', 'monitor': '03211117', 88 | 'table': '04379243', 'telephone': '04401088', 'tin_can': '02946921', 89 | 'tower': '04460130', 'train': '04468005', 'keyboard': '03085013', 90 | 'earphone': '03261776', 'faucet': '03325088', 'file': '03337140', 91 | 'guitar': '03467517', 'helmet': '03513137', 'jar': '03593526', 92 | 'knife': '03624134', 'lamp': '03636649', 'laptop': '03642806', 93 | 'loudspeaker': '03691459', 'mailbox': '03710193', 'microphone': '03759954', 94 | 'microwave': '03761084', 'motorcycle': '03790512', 'mug': '03797390', 95 | 'piano': '03928116', 'pillow': '03938244', 'pistol': '03948459', 96 | 'pot': '03991062', 'printer': '04004475', 'remote_control': '04074963', 97 | 'rifle': '04090263', 'rocket': '04099429', 'skateboard': '04225987', 98 | 'sofa': '04256520', 'stove': '04330267', 'vessel': '04530566', 99 | 'washer': '04554684', 'boat': '02858304', 'cellphone': '02992529' 100 | } 101 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from termcolor import colored, cprint 3 | import torch 4 | import utils.util as util 5 | import math 6 | 7 | def create_model(opt): 8 | model = None 9 | 10 | if opt.model == "union_2t": 11 | from models.octfusion_model_union import OctFusionModel 12 | model = OctFusionModel() 13 | elif opt.model == "union_3t": 14 | from models.octfusion_model_union_3t import OctFusionModel 15 | model = OctFusionModel() 16 | elif opt.model == "vae": 17 | from models.octfusion_model_vae import OctFusionModel 18 | model = OctFusionModel() 19 | else: 20 | raise ValueError 21 | 22 | model.initialize(opt) 23 | cprint("[*] Model has been created: %s" % model.name(), 'blue') 24 | return model 25 | 26 | 27 | # modified from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix 28 | class BaseModel(): 29 | def name(self): 30 | return 'BaseModel' 31 | 32 | def initialize(self, opt): 33 | self.opt = opt 34 | 
self.gpu_ids = opt.gpu_ids 35 | self.isTrain = opt.isTrain 36 | self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor 37 | 38 | self.model_names = [] 39 | self.epoch_labels = [] 40 | self.optimizers = [] 41 | 42 | def set_input(self, input): 43 | self.input = input 44 | 45 | def forward(self): 46 | pass 47 | 48 | def get_image_paths(self): 49 | pass 50 | 51 | def optimize_parameters(self): 52 | pass 53 | 54 | def get_current_errors(self): 55 | return {} 56 | 57 | # define the optimizers 58 | def set_optimizers(self): 59 | pass 60 | 61 | def set_requires_grad(self, nets, requires_grad=False): 62 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 63 | Parameters: 64 | nets (network list) -- a list of networks 65 | requires_grad (bool) -- whether the networks require gradients or not 66 | """ 67 | if not isinstance(nets, list): 68 | nets = [nets] 69 | for net in nets: 70 | if net is not None: 71 | for param in net.parameters(): 72 | param.requires_grad = requires_grad 73 | 74 | # update learning rate (called once every epoch) 75 | def update_learning_rate(self): 76 | for scheduler in self.schedulers: 77 | scheduler.step() 78 | lr = self.optimizers[0].param_groups[0]['lr'] 79 | print('[*] learning rate = %.7f' % lr) 80 | 81 | def update_learning_rate_cos(self, epoch, opt): 82 | """Decay the learning rate with half-cycle cosine after warmup""" 83 | if epoch >= opt.warmup_epochs: 84 | lr = opt.min_lr + (opt.lr - opt.min_lr) * 0.5 * \ 85 | (1. + math.cos(math.pi * (epoch - opt.warmup_epochs) / (opt.epochs - opt.warmup_epochs))) 86 | for param_group in self.optimizer.param_groups: 87 | if "lr_scale" in param_group: 88 | param_group["lr"] = lr * param_group["lr_scale"] 89 | else: 90 | param_group["lr"] = lr 91 | print('[*] learning rate = %.7f' % lr) 92 | 93 | def eval(self): 94 | for name in self.model_names: 95 | if isinstance(name, str): 96 | net = getattr(self, 'net' + name) 97 | net.eval() 98 | 99 | def train(self): 100 | for name in self.model_names: 101 | if isinstance(name, str): 102 | net = getattr(self, 'net' + name) 103 | net.train() 104 | 105 | # print network information 106 | def print_networks(self, verbose=False): 107 | print('---------- Networks initialized -------------') 108 | for name in self.model_names: 109 | if isinstance(name, str): 110 | net = getattr(self, 'net' + name) 111 | num_params = 0 112 | for param in net.parameters(): 113 | num_params += param.numel() 114 | if verbose: 115 | print(net) 116 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) 117 | print('-----------------------------------------------') 118 | 119 | def tocuda(self, var_names): 120 | for name in var_names: 121 | if isinstance(name, str): 122 | var = getattr(self, name) 123 | # setattr(self, name, var.cuda(self.gpu_ids[0], non_blocking=True)) 124 | setattr(self, name, var.cuda(self.opt.device, non_blocking=True)) 125 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OctFusion: Octree-based Diffusion Models for 3D Shape Generation 2 | [[`arXiv`](https://arxiv.org/abs/2408.14732)] 3 | [[`BibTex`](#citation)] 4 | 5 | Code release for the paper "OctFusion: Octree-based Diffusion Models for 3D Shape Generation". Computer Graphics Forum (presented at SGP 2025) 6 | 7 | ![teaser](./assets/teaser.png) 8 | 9 | 10 | ## 1. Installation 11 | 1. 
Clone this repository 12 | ```bash 13 | git clone https://github.com/octree-nn/octfusion.git 14 | cd octfusion 15 | ``` 16 | 2. Create a `Conda` environment. 17 | ```bash 18 | conda create -n octfusion python=3.11 -y && conda activate octfusion 19 | ``` 20 | 21 | 3. Install PyTorch with Conda 22 | ```bash 23 | conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia 24 | ``` 25 | 26 | 4. Install other requirements. 27 | ```bash 28 | pip3 install -r requirements.txt 29 | ``` 30 | 31 | ## 2. Generation with pre-trained models 32 | 33 | ### 2.1 Download pre-trained models 34 | We provide the pretrained models for the category-conditioned generation and sketch-conditioned generation. Please download the pretrained models from [Google Drive](https://drive.google.com/drive/folders/140U_xzAy1MobUqurN67Fm2Y-3oWrZQ1m?usp=drive_link) or [Baidu Netdisk](https://pan.baidu.com/s/15-jp9Mwtw4soch8GAC7qgQ?pwd=rhui) and put them in `saved_ckpt/diffusion-ckpt` and `saved_ckpt/vae-ckpt`. 35 | 36 | ### 2.2 Generation 37 | 1. Unconditional generation in category `airplane`, `car`, `chair`, `rifle`, `table`. 38 | ``` 39 | sh scripts/run_snet_uncond.sh generate hr $category 40 | # Example 41 | sh scripts/run_snet_uncond.sh generate hr airplane 42 | 43 | ``` 44 | 45 | 2. Category-conditioned generation 46 | ``` 47 | sh scripts/run_snet_cond.sh generate hr $category 48 | # Example 49 | sh scripts/run_snet_cond.sh generate hr chair 50 | ``` 51 | 52 | ## 3. Train from scratch 53 | ### 3.1 Data Preparation 54 | 55 | 1. Download `ShapeNetCore.v1.zip` (31G) from [ShapeNet](https://shapenet.org/) and place it in `data/ShapeNet/ShapeNetCore.v1.zip`. Download `filelist` from [Google Drive](https://drive.google.com/drive/folders/140U_xzAy1MobUqurN67Fm2Y-3oWrZQ1m?usp=drive_link) or [Baidu Netdisk](https://pan.baidu.com/s/15-jp9Mwtw4soch8GAC7qgQ?pwd=rhui) and place it in `data/ShapeNet/filelist`. 56 | 57 | 2. Convert the meshes in `ShapeNetCore.v1` to signed distance fields (SDFs). 58 | We use the same data preparation as [DualOctreeGNN](https://github.com/microsoft/DualOctreeGNN.git). Note that this process is relatively slow, it may take several days to finish converting all the meshes from ShapeNet. 59 | ```bash 60 | python tools/repair_mesh.py --run convert_mesh_to_sdf 61 | python tools/repair_mesh.py --run generate_dataset 62 | ``` 63 | 64 | 65 | 66 | ### 3.2 Train OctFusion 67 | 1. VAE Training. We provide pretrained weights in `saved_ckpt/vae-ckpt/vae-shapenet-depth-8.pth`. 68 | ```bash 69 | sh scripts/run_snet_vae.sh train vae im_5 70 | ``` 71 | 2. Train the first stage model. We provide pretrained weights in `saved_ckpt/diffusion-ckpt/$category/df_steps-split.pth`. 72 | ```bash 73 | sh scripts/run_snet_uncond.sh train lr $category 74 | ``` 75 | 76 | 3. Load the pretrained first stage model and train the second stage. We provide pretrained weights in `saved_ckpt/diffusion-ckpt/$category/df_steps-union.pth`. 77 | ```bash 78 | sh scripts/run_snet_uncond.sh train hr $category 79 | ``` 80 | # Citation 81 | 82 | If you find this code helpful, please consider citing: 83 | 84 | 85 | 1. arxiv version 86 | ```BibTeX 87 | @article{xiong2024octfusion, 88 | title={Octfusion: Octree-based diffusion models for 3d shape generation}, 89 | author={Xiong, Bojun and Wei, Si-Tong and Zheng, Xin-Yang and Cao, Yan-Pei and Lian, Zhouhui and Wang, Peng-Shuai}, 90 | journal={arXiv preprint arXiv:2408.14732}, 91 | year={2024} 92 | } 93 | ``` 94 | 2. 
CGF version 95 | ```BibTex 96 | @article{Xiong_2025_SGP, 97 | journal = {Computer Graphics Forum}, 98 | title = {{OctFusion: Octree-based Diffusion Models for 3D Shape Generation}}, 99 | author = {Xiong, Bojun and Wei, Si-Tong and Zheng, Xin-Yang and Cao, Yan-Pei and Lian, Zhouhui and Wang, Peng-Shuai}, 100 | year = {2025}, 101 | publisher = {The Eurographics Association and John Wiley & Sons Ltd.}, 102 | ISSN = {1467-8659}, 103 | DOI = {10.1111/cgf.70198} 104 | } 105 | ``` 106 | 107 | # Issues and FAQ 108 | Coming soon! 109 | 110 | # Acknowledgement 111 | This code borrows heavely from [SDFusion](https://github.com/yccyenchicheng/SDFusion), [LAS-Diffusion](https://github.com/Zhengxinyang/LAS-Diffusion), [DualOctreeGNN](https://github.com/microsoft/DualOctreeGNN). We thank the authors for their great work. The followings packages are required to compute the SDF: [mesh2sdf](https://github.com/wang-ps/mesh2sdf). 112 | -------------------------------------------------------------------------------- /metrics/pytorch_structural_losses/src/nndistance.cu: -------------------------------------------------------------------------------- 1 | 2 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ 3 | const int batch=512; 4 | __shared__ float buf[batch*3]; 5 | for (int i=blockIdx.x;ibest){ 117 | result[(i*n+j)]=best; 118 | result_i[(i*n+j)]=best_i; 119 | } 120 | } 121 | __syncthreads(); 122 | } 123 | } 124 | } 125 | void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ 126 | NmDistanceKernel<<>>(b,n,xyz,m,xyz2,result,result_i); 127 | NmDistanceKernel<<>>(b,m,xyz2,n,xyz,result2,result2_i); 128 | } 129 | __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ 130 | for (int i=blockIdx.x;i>>(b,n,xyz1,m,xyz2,grad_dist1,idx1,grad_xyz1,grad_xyz2); 153 | NmDistanceGradKernel<<>>(b,m,xyz2,n,xyz1,grad_dist2,idx2,grad_xyz2,grad_xyz1); 154 | } 155 | 156 | -------------------------------------------------------------------------------- /models/networks/dualoctree_networks/quantizer.py: -------------------------------------------------------------------------------- 1 | """ adapted from: https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | from torch import einsum 8 | from einops import rearrange 9 | 10 | class VectorQuantizer(nn.Module): 11 | """ 12 | Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly 13 | avoids costly matrix multiplications and allows for post-hoc remapping of indices. 14 | """ 15 | # NOTE: due to a bug the beta term was applied to the wrong term. for 16 | # backwards compatibility we use the buggy version by default, but you can 17 | # specify legacy=False to fix it. 
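# A minimal usage sketch (illustrative only; the codebook size, embedding
# dimension and beta below are assumptions, not values taken from the configs
# in this repository):
#
#   quantizer = VectorQuantizer(n_e=8192, e_dim=3, beta=0.25, legacy=False)
#   z_q, emb_loss, (_, _, indices) = quantizer(z, is_octree=True)
#
# Here `z` is an (N, e_dim) feature tensor from the encoder (use is_voxel=True
# for a (B, C, D, H, W) grid instead); `z_q` is the quantized feature with
# straight-through gradients, `emb_loss` the codebook/commitment loss, and
# `indices` the selected codebook entries.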
18 | def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", 19 | sane_index_shape=False, legacy=True): 20 | super().__init__() 21 | self.n_e = n_e 22 | self.e_dim = e_dim 23 | self.beta = beta 24 | self.legacy = legacy 25 | 26 | self.embedding = nn.Embedding(self.n_e, self.e_dim) 27 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) 28 | 29 | self.remap = remap 30 | if self.remap is not None: 31 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 32 | self.re_embed = self.used.shape[0] 33 | self.unknown_index = unknown_index # "random" or "extra" or integer 34 | if self.unknown_index == "extra": 35 | self.unknown_index = self.re_embed 36 | self.re_embed = self.re_embed+1 37 | print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " 38 | f"Using {self.unknown_index} for unknown indices.") 39 | else: 40 | self.re_embed = n_e 41 | 42 | self.sane_index_shape = sane_index_shape 43 | 44 | def remap_to_used(self, inds): 45 | ishape = inds.shape 46 | assert len(ishape)>1 47 | inds = inds.reshape(ishape[0],-1) 48 | used = self.used.to(inds) 49 | match = (inds[:,:,None]==used[None,None,...]).long() 50 | new = match.argmax(-1) 51 | unknown = match.sum(2)<1 52 | if self.unknown_index == "random": 53 | new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) 54 | else: 55 | new[unknown] = self.unknown_index 56 | return new.reshape(ishape) 57 | 58 | def unmap_to_all(self, inds): 59 | ishape = inds.shape 60 | assert len(ishape)>1 61 | inds = inds.reshape(ishape[0],-1) 62 | used = self.used.to(inds) 63 | if self.re_embed > self.used.shape[0]: # extra token 64 | inds[inds>=self.used.shape[0]] = 0 # simply set to zero 65 | back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) 66 | return back.reshape(ishape) 67 | 68 | def forward(self, z, temp=None, rescale_logits=False, return_logits=False, is_voxel=False, is_octree = False): 69 | assert temp is None or temp==1.0, "Only for interface compatible with Gumbel" 70 | assert rescale_logits==False, "Only for interface compatible with Gumbel" 71 | assert return_logits==False, "Only for interface compatible with Gumbel" 72 | # reshape z -> (batch, height, width, channel) and flatten 73 | if is_voxel: 74 | z = rearrange(z, 'b c d h w -> b d h w c').contiguous() 75 | elif is_octree: 76 | pass 77 | 78 | # if not is_voxel: 79 | # z = rearrange(z, 'b c h w -> b h w c').contiguous() 80 | # else: 81 | # z = rearrange(z, 'b c d h w -> b d h w c').contiguous() 82 | 83 | z_flattened = z.view(-1, self.e_dim) 84 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 85 | 86 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ 87 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \ 88 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) 89 | 90 | min_encoding_indices = torch.argmin(d, dim=1) 91 | z_q = self.embedding(min_encoding_indices).view(z.shape) 92 | perplexity = None 93 | min_encodings = None 94 | 95 | # compute loss for embedding 96 | if not self.legacy: 97 | loss = self.beta * torch.mean((z_q.detach()-z)**2) + \ 98 | torch.mean((z_q - z.detach()) ** 2) 99 | else: 100 | loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ 101 | torch.mean((z_q - z.detach()) ** 2) 102 | 103 | # preserve gradients 104 | z_q = z + (z_q - z).detach() 105 | 106 | # reshape back to match original input shape 107 | if is_voxel: 108 | z_q = rearrange(z_q, 'b d h w c -> b c d h w').contiguous() 109 | elif is_octree: 110 | pass 111 | 
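# Note on the block above: `z_q = z + (z_q - z).detach()` is the
# straight-through estimator. Gradients from the decoder flow back to the
# encoder output `z` unchanged, while the codebook embeddings themselves are
# updated only through the two MSE terms accumulated in `loss` (the
# beta-weighted commitment term and the embedding term).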
112 | # if not is_voxel: 113 | # z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() 114 | # else: 115 | # z_q = rearrange(z_q, 'b d h w c -> b c d h w').contiguous() 116 | 117 | if self.remap is not None: 118 | min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis 119 | min_encoding_indices = self.remap_to_used(min_encoding_indices) 120 | min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten 121 | 122 | if self.sane_index_shape: 123 | if not is_voxel: 124 | min_encoding_indices = min_encoding_indices.reshape( 125 | z_q.shape[0], z_q.shape[2], z_q.shape[3]) 126 | else: 127 | min_encoding_indices = min_encoding_indices.reshape( 128 | z_q.shape[0], z_q.shape[2], z_q.shape[3], z_q.shape[4]) 129 | 130 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices) 131 | 132 | def get_codebook_entry(self, indices, shape): 133 | # shape specifying (batch, height, width, channel) 134 | if self.remap is not None: 135 | indices = indices.reshape(shape[0],-1) # add batch axis 136 | indices = self.unmap_to_all(indices) 137 | indices = indices.reshape(-1) # flatten again 138 | 139 | # get quantized latent vectors 140 | z_q = self.embedding(indices) 141 | 142 | if shape is not None: 143 | z_q = z_q.view(shape) 144 | # reshape back to match original input shape 145 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 146 | 147 | return z_q 148 | -------------------------------------------------------------------------------- /models/networks/dualoctree_networks/mpu.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | from ocnn.octree import xyz2key 9 | from ocnn.utils import cumsum 10 | import torch 11 | import torch.nn 12 | 13 | from .utils.spmm import spmm, modulated_spmm 14 | 15 | kNN = 8 16 | 17 | 18 | class ABS(torch.autograd.Function): 19 | '''The derivative of torch.abs on `0` is `0`, and in this implementation, we 20 | modified it to `1` 21 | ''' 22 | @staticmethod 23 | def forward(ctx, input): 24 | ctx.save_for_backward(input) 25 | return input.abs() 26 | 27 | @staticmethod 28 | def backward(ctx, grad_in): 29 | input, = ctx.saved_tensors 30 | sign = input < 0 31 | grad_out = grad_in * (-2.0 * sign.to(input.dtype) + 1.0) 32 | return grad_out 33 | 34 | 35 | def linear_basis(x): 36 | return 1.0 - ABS.apply(x) 37 | 38 | 39 | def get_linear_mask(dim=3): 40 | mask = torch.tensor([0, 1], dtype=torch.float32) 41 | mask = torch.meshgrid([mask]*dim) 42 | mask = torch.stack(mask, -1).view(-1, dim) 43 | return mask 44 | 45 | # mask = tensor([[0., 0., 0.], 46 | # [0., 0., 1.], 47 | # [0., 1., 0.], 48 | # [0., 1., 1.], 49 | # [1., 0., 0.], 50 | # [1., 0., 1.], 51 | # [1., 1., 0.], 52 | # [1., 1., 1.]]) 53 | 54 | 55 | def octree_linear_pts(octree, depth, pts): 56 | # get neigh coordinates 57 | scale = 2 ** depth ## 在特定深度depth下的scale为2**depth 58 | mask = get_linear_mask(dim=3).to(pts.device) 59 | xyzf, ids = torch.split(pts, [3, 1], 1) # 将pos[N, 4]分解为前三个(坐标)xyzf和最后一个(batch_idx)ids 60 | xyzf = (xyzf + 1.0) * (scale / 2.0) # [-1, 1] -> [0, scale] 将xyz坐标放缩到[0, scale] 61 | xyzf = xyzf - 0.5 # the code is defined on the center 62 | xyzi = torch.floor(xyzf).detach() # the integer part (N, 3), use floor 63 | corners = xyzi.unsqueeze(1) + mask # (N, 
8, 3), 得到这N个点的周围8个corner的坐标,也就是 [N,8,3],这里corner的坐标都是整数,在[0,scale]范围内 64 | coordsf = xyzf.unsqueeze(1) - corners # (N, 8, 3), in [-1.0, 1.0] 65 | 66 | # coorers -> key 67 | ids = ids.detach().repeat(1, kNN).unsqueeze(-1) # (N, 8, 1) 68 | key = torch.cat([corners, ids], dim=-1).view(-1, 4).short() # (N*8, 4) 69 | key = xyz2key(x = key[:,0], y = key[:,1], z = key[:,2], b = key[:,3]).long() # (N*8, ) 70 | idx = octree.search_key(key, depth) # (N*8, ) 71 | # key = ocnn.octree_encode_key(key).long() # (N*8, ) 72 | # idx = ocnn.octree_search_key(key, octree, depth, key_is_xyz=True) 73 | 74 | # corners -> flags 75 | valid = torch.logical_and(corners > -1, corners < scale) # out-of-bound 76 | valid = torch.all(valid, dim=-1).view(-1) 77 | flgs = torch.logical_and(idx > -1, valid) 78 | 79 | # remove invalid pts 80 | idx = idx[flgs].long() # (N*8, ) -> (N', ) 81 | coordsf = coordsf.view(-1, 3)[flgs] # (N, 8, 3) -> (N', 3) 82 | 83 | # bspline weights 84 | weights = linear_basis(coordsf) # (N', 3) 85 | weights = torch.prod(weights, axis=-1).view(-1) # (N', ) 86 | # Here, the scale factor `2**(depth - 6)` is used to emphasize high-resolution 87 | # basis functions. Tune this factor further if needed! !!! NOTE !!! 88 | # weights = weights * 2**(depth - 6) # used for shapenet 89 | weights = weights * (depth**2 / 50) # testing 90 | 91 | # rescale back the original scale 92 | # After recaling, the coordsf is in the same scale as pts 93 | coordsf = coordsf * (2.0 / scale) # [-1.0, 1.0] -> [-2.0/scale, 2.0/scale] 这一步相当于,把[0,scale]的坐标重新缩小到[-1,1]的尺度上 94 | return {'idx': idx, 'xyzf': coordsf, 'weights': weights, 'flgs': flgs} 95 | 96 | 97 | def get_linear_pred(pts, octree, shape_code, neighs, depth_start, depth_end): 98 | npt = pts.size(0) 99 | indices, weights, xyzfs = [], [], [] 100 | nnum = octree.nnum 101 | nnum_cum = cumsum(nnum, dim=0, exclusive=True) 102 | # nnum_cum = ocnn.octree_property(octree, 'node_num_cum') 103 | ids = torch.arange(npt, device=pts.device, dtype=torch.long) 104 | ids = ids.unsqueeze(-1).repeat(1, kNN).view(-1) 105 | for d in range(depth_start, depth_end+1): 106 | neighd = neighs[d] 107 | idxd = neighd['idx'] 108 | xyzfd = neighd['xyzf'] 109 | weightd = neighd['weights'] 110 | valid = neighd['flgs'] 111 | idsd = ids[valid] 112 | 113 | if d < depth_end: 114 | child = octree.children[d] 115 | leaf = child[idxd] < 0 # keep only leaf nodes 116 | idsd, idxd, weightd, xyzfd = idsd[leaf], idxd[leaf], weightd[leaf], xyzfd[leaf] 117 | 118 | idxd = idxd + (nnum_cum[d] - nnum_cum[depth_start]) 119 | indices.append(torch.stack([idsd, idxd], dim=1)) 120 | weights.append(weightd) 121 | xyzfs.append(xyzfd) 122 | 123 | indices = torch.cat(indices, dim=0).t() 124 | weights = torch.cat(weights, dim=0) 125 | xyzfs = torch.cat(xyzfs, dim=0) 126 | 127 | code_num = shape_code.size(0) 128 | output = modulated_spmm(indices, weights, npt, code_num, shape_code, xyzfs) 129 | norm = spmm(indices, weights, npt, code_num, torch.ones(code_num, 1).cuda()) # 这里norm的维度为[N,] 130 | output = torch.div(output, norm + 1e-8).squeeze() # 这里output的维度为[N, 1](也就是每个查询点的sdf值) 131 | 132 | # whether the point has affected by the octree node in depth layer 133 | mask = neighs[depth_end]['flgs'].view(-1, kNN).any(axis=-1) 134 | return output, mask 135 | 136 | 137 | class NeuralMPU: 138 | def __init__(self, full_depth, depth_stop, depth): 139 | self.full_depth = full_depth 140 | self.depth_stop = depth_stop 141 | self.depth = depth 142 | 143 | def __call__(self, pos, reg_voxs, octree_out): # reg_voxs就是dual octree的每个节点中存储的vector,其维度为4 
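# `pos` holds the query points (xyz coordinates plus a batch index), `reg_voxs`
# the per-node feature vectors regressed on the dual octree, and `octree_out`
# the decoded octree. For each target depth d in [depth_stop, depth], the MPU
# blends node predictions from levels full_depth..d into one SDF value per
# query point and returns a dict d -> (sdf values, mask of points covered by
# depth-d nodes).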
144 | mpus = dict() 145 | neighs = dict() 146 | for d in range(self.full_depth, self.depth+1): 147 | neighs[d] = octree_linear_pts(octree_out, d, pos) 148 | 149 | for d in range(self.depth_stop, self.depth+1): 150 | fval, flgs = get_linear_pred( 151 | pos, octree_out, reg_voxs[d], neighs, self.full_depth, d) 152 | mpus[d] = (fval, flgs) 153 | return mpus -------------------------------------------------------------------------------- /metrics/pytorch_structural_losses/src/structural_loss.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "src/approxmatch.cuh" 5 | #include "src/nndistance.cuh" 6 | 7 | #include 8 | #include 9 | 10 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 13 | 14 | /* 15 | input: 16 | set1 : batch_size * #dataset_points * 3 17 | set2 : batch_size * #query_points * 3 18 | returns: 19 | match : batch_size * #query_points * #dataset_points 20 | */ 21 | // temp: TensorShape{b,(n+m)*2} 22 | std::vector ApproxMatch(at::Tensor set_d, at::Tensor set_q) { 23 | //std::cout << "[ApproxMatch] Called." << std::endl; 24 | int64_t batch_size = set_d.size(0); 25 | int64_t n_dataset_points = set_d.size(1); // n 26 | int64_t n_query_points = set_q.size(1); // m 27 | //std::cout << "[ApproxMatch] batch_size:" << batch_size << std::endl; 28 | at::Tensor match = torch::empty({batch_size, n_query_points, n_dataset_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 29 | at::Tensor temp = torch::empty({batch_size, (n_query_points+n_dataset_points)*2}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 30 | CHECK_INPUT(set_d); 31 | CHECK_INPUT(set_q); 32 | CHECK_INPUT(match); 33 | CHECK_INPUT(temp); 34 | 35 | approxmatch(batch_size,n_dataset_points,n_query_points,set_d.data(),set_q.data(),match.data(),temp.data(), at::cuda::getCurrentCUDAStream()); 36 | return {match, temp}; 37 | } 38 | 39 | at::Tensor MatchCost(at::Tensor set_d, at::Tensor set_q, at::Tensor match) { 40 | //std::cout << "[MatchCost] Called." << std::endl; 41 | int64_t batch_size = set_d.size(0); 42 | int64_t n_dataset_points = set_d.size(1); // n 43 | int64_t n_query_points = set_q.size(1); // m 44 | //std::cout << "[MatchCost] batch_size:" << batch_size << std::endl; 45 | at::Tensor out = torch::empty({batch_size}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 46 | CHECK_INPUT(set_d); 47 | CHECK_INPUT(set_q); 48 | CHECK_INPUT(match); 49 | CHECK_INPUT(out); 50 | matchcost(batch_size,n_dataset_points,n_query_points,set_d.data(),set_q.data(),match.data(),out.data(),at::cuda::getCurrentCUDAStream()); 51 | return out; 52 | } 53 | 54 | std::vector MatchCostGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor match) { 55 | //std::cout << "[MatchCostGrad] Called." 
<< std::endl; 56 | int64_t batch_size = set_d.size(0); 57 | int64_t n_dataset_points = set_d.size(1); // n 58 | int64_t n_query_points = set_q.size(1); // m 59 | //std::cout << "[MatchCostGrad] batch_size:" << batch_size << std::endl; 60 | at::Tensor grad1 = torch::empty({batch_size,n_dataset_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 61 | at::Tensor grad2 = torch::empty({batch_size,n_query_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 62 | CHECK_INPUT(set_d); 63 | CHECK_INPUT(set_q); 64 | CHECK_INPUT(match); 65 | CHECK_INPUT(grad1); 66 | CHECK_INPUT(grad2); 67 | matchcostgrad(batch_size,n_dataset_points,n_query_points,set_d.data(),set_q.data(),match.data(),grad1.data(),grad2.data(),at::cuda::getCurrentCUDAStream()); 68 | return {grad1, grad2}; 69 | } 70 | 71 | 72 | /* 73 | input: 74 | set_d : batch_size * #dataset_points * 3 75 | set_q : batch_size * #query_points * 3 76 | returns: 77 | dist1, idx1 : batch_size * #dataset_points 78 | dist2, idx2 : batch_size * #query_points 79 | */ 80 | std::vector NNDistance(at::Tensor set_d, at::Tensor set_q) { 81 | //std::cout << "[NNDistance] Called." << std::endl; 82 | int64_t batch_size = set_d.size(0); 83 | int64_t n_dataset_points = set_d.size(1); // n 84 | int64_t n_query_points = set_q.size(1); // m 85 | //std::cout << "[NNDistance] batch_size:" << batch_size << std::endl; 86 | at::Tensor dist1 = torch::empty({batch_size, n_dataset_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 87 | at::Tensor idx1 = torch::empty({batch_size, n_dataset_points}, torch::TensorOptions().dtype(torch::kInt32).device(set_d.device())); 88 | at::Tensor dist2 = torch::empty({batch_size, n_query_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 89 | at::Tensor idx2 = torch::empty({batch_size, n_query_points}, torch::TensorOptions().dtype(torch::kInt32).device(set_d.device())); 90 | CHECK_INPUT(set_d); 91 | CHECK_INPUT(set_q); 92 | CHECK_INPUT(dist1); 93 | CHECK_INPUT(idx1); 94 | CHECK_INPUT(dist2); 95 | CHECK_INPUT(idx2); 96 | // void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream); 97 | nndistance(batch_size,n_dataset_points,set_d.data(),n_query_points,set_q.data(),dist1.data(),idx1.data(),dist2.data(),idx2.data(), at::cuda::getCurrentCUDAStream()); 98 | return {dist1, idx1, dist2, idx2}; 99 | } 100 | 101 | std::vector NNDistanceGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor idx1, at::Tensor idx2, at::Tensor grad_dist1, at::Tensor grad_dist2) { 102 | //std::cout << "[NNDistanceGrad] Called." 
<< std::endl; 103 | int64_t batch_size = set_d.size(0); 104 | int64_t n_dataset_points = set_d.size(1); // n 105 | int64_t n_query_points = set_q.size(1); // m 106 | //std::cout << "[NNDistanceGrad] batch_size:" << batch_size << std::endl; 107 | at::Tensor grad1 = torch::empty({batch_size,n_dataset_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 108 | at::Tensor grad2 = torch::empty({batch_size,n_query_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 109 | CHECK_INPUT(set_d); 110 | CHECK_INPUT(set_q); 111 | CHECK_INPUT(idx1); 112 | CHECK_INPUT(idx2); 113 | CHECK_INPUT(grad_dist1); 114 | CHECK_INPUT(grad_dist2); 115 | CHECK_INPUT(grad1); 116 | CHECK_INPUT(grad2); 117 | //void nndistancegrad(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream); 118 | nndistancegrad(batch_size,n_dataset_points,set_d.data(),n_query_points,set_q.data(), 119 | grad_dist1.data(),idx1.data(), 120 | grad_dist2.data(),idx2.data(), 121 | grad1.data(),grad2.data(), 122 | at::cuda::getCurrentCUDAStream()); 123 | return {grad1, grad2}; 124 | } 125 | 126 | -------------------------------------------------------------------------------- /utils/render/render_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['PYOPENGL_PLATFORM'] = 'egl' 3 | import matplotlib 4 | matplotlib.use("Agg") 5 | import pyglet 6 | pyglet.options['shadow_window'] = False 7 | 8 | import time 9 | from PIL import Image 10 | import numpy as np 11 | import pyrender 12 | import trimesh 13 | from pyrender import ( 14 | DirectionalLight, 15 | SpotLight, 16 | PointLight, 17 | ) 18 | import pyrr 19 | from matplotlib import pyplot as plt 20 | 21 | 22 | SIZE = None 23 | 24 | class Render: 25 | def __init__(self, size, camera_pose, background = None): 26 | self.size = size 27 | global SIZE 28 | SIZE = size 29 | self.camera_pose = camera_pose 30 | self.background = background 31 | 32 | def render(self, path, clean=True, intensity=3.0, mesh=None, only_render_images=False): 33 | if not isinstance(mesh, trimesh.Trimesh): 34 | mesh = prepare_mesh(path, color=False, clean=clean) 35 | try: 36 | if mesh.visual.defined: 37 | mesh.visual.material.kwargs["Ns"] = 1.0 38 | except: 39 | print("Error loading material!") 40 | 41 | 42 | triangle_id, normal_map, depth_image, p_image = None, None, None, None 43 | if not only_render_images: 44 | triangle_id, normal_map, depth_image, p_image = correct_normals( 45 | mesh, self.camera_pose, correct=True) 46 | mesh1 = pyrender.Mesh.from_trimesh(mesh, smooth=False) 47 | rendered_image, _ = pyrender_rendering( 48 | mesh1, viz=False, light=True, camera_pose=self.camera_pose, intensity=intensity, bg_color = self.background 49 | ) 50 | 51 | return triangle_id, rendered_image, normal_map, depth_image, p_image 52 | 53 | 54 | def render_normal(self, path, clean=True, intensity=6.0, mesh=None): 55 | try: 56 | if mesh.visual.defined: 57 | mesh.visual.material.kwargs["Ns"] = 1.0 58 | except: 59 | print("Error loading material!") 60 | 61 | triangle_id, normal_map, depth_image, p_image = correct_normals( 62 | mesh, self.camera_pose, correct=True) 63 | 64 | return normal_map, depth_image 65 | 66 | 67 | def correct_normals(mesh, camera_pose, correct=True): 68 | rayintersector = trimesh.ray.ray_pyembree.RayMeshIntersector(mesh) 69 | 70 | a, b, index_tri, sign, p_image = 
trimesh_ray_tracing( 71 | mesh, camera_pose, resolution=SIZE*2, rayintersector=rayintersector 72 | ) 73 | if correct: 74 | mesh.faces[index_tri[sign > 0]] = np.fliplr( 75 | mesh.faces[index_tri[sign > 0]]) 76 | 77 | normalmap = render_normal_map( 78 | pyrender.Mesh.from_trimesh(mesh, smooth=False), 79 | camera_pose, 80 | SIZE, 81 | viz=False, 82 | ) 83 | 84 | return b, normalmap, a, p_image 85 | 86 | 87 | 88 | def init_light(scene, camera_pose, intensity=6.0): 89 | direc_l = DirectionalLight(color=np.ones(3), intensity=intensity) 90 | spot_l = SpotLight( 91 | color=np.ones(3), 92 | intensity=intensity, 93 | innerConeAngle=np.pi / 16, 94 | outerConeAngle=np.pi / 6, 95 | ) 96 | point_l = PointLight(color=np.ones(3), intensity=2*intensity) 97 | direc_l_node = scene.add(direc_l, pose=camera_pose) 98 | point_l_node = scene.add(point_l, pose=camera_pose) 99 | spot_l_node = scene.add(spot_l, pose=camera_pose) 100 | 101 | 102 | class CustomShaderCache: 103 | def __init__(self): 104 | self.program = None 105 | 106 | def get_program( 107 | self, vertex_shader, fragment_shader, geometry_shader=None, defines=None 108 | ): 109 | if self.program is None: 110 | current_work_dir = os.path.dirname(__file__) 111 | print(current_work_dir) 112 | self.program = pyrender.shader_program.ShaderProgram( 113 | current_work_dir + "/shades/mesh.vert", current_work_dir + "/shades/mesh.frag", defines=defines 114 | ) 115 | return self.program 116 | 117 | 118 | def render_normal_map(mesh, camera_pose, size, viz=False): 119 | scene = pyrender.Scene(bg_color=(255,255,255)) 120 | scene.add(mesh) 121 | camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0) 122 | scene.add(camera, pose=camera_pose) 123 | 124 | renderer = pyrender.OffscreenRenderer(size, size) 125 | renderer._renderer._program_cache = CustomShaderCache() 126 | 127 | normals, depth = renderer.render( 128 | scene 129 | ) 130 | 131 | world_space_normals = normals / 255 * 2 - 1 132 | 133 | if viz: 134 | image = Image.fromarray(normals, "RGB") 135 | image.show() 136 | 137 | return world_space_normals 138 | 139 | 140 | def pyrender_rendering(mesh, camera_pose, viz=False, light=False, intensity=3.0, bg_color=None): 141 | # renderer 142 | r = pyrender.OffscreenRenderer(SIZE, SIZE) 143 | 144 | scene = pyrender.Scene(bg_color=bg_color) 145 | scene.add(mesh) 146 | 147 | camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=1.) 
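# Note: pyrender.Scene.add returns the scene-graph Node, so `camera` is rebound
# to the camera node on the next line; that node handle is what scene.set_pose
# expects a few lines further down.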
148 | camera = scene.add(camera, pose=camera_pose) 149 | # light 150 | if light: 151 | init_light(scene, camera_pose, intensity=intensity) 152 | 153 | scene.set_pose(camera, camera_pose) 154 | 155 | if light: 156 | color, depth = r.render( 157 | scene, flags=pyrender.constants.RenderFlags.ALL_SOLID | pyrender.constants.RenderFlags.FACE_NORMALS) 158 | else: 159 | color, depth = r.render( 160 | scene, flags=pyrender.constants.RenderFlags.FLAT 161 | ) 162 | 163 | return color, depth 164 | 165 | 166 | 167 | def create_pose(eye): 168 | target = np.zeros(3) 169 | camera_pose = np.array(pyrr.Matrix44.look_at(eye=eye, 170 | target=target, 171 | up=np.array([0.0, 1.0, 0])).T) 172 | return np.linalg.inv(np.array(camera_pose)) 173 | 174 | 175 | def trimesh_ray_tracing(mesh, M, resolution=225, fov=60, rayintersector=None): 176 | extra = np.eye(4) 177 | extra[0, 0] = 0 178 | extra[0, 1] = 1 179 | extra[1, 0] = -1 180 | extra[1, 1] = 0 181 | scene = mesh.scene() 182 | 183 | scene.camera_transform = M @ extra 184 | scene.camera.resolution = [resolution, resolution] 185 | scene.camera.fov = fov, fov 186 | origins, vectors, pixels = scene.camera_rays() 187 | 188 | index_tri, index_ray, points = rayintersector.intersects_id( 189 | origins, vectors, multiple_hits=False, return_locations=True 190 | ) 191 | depth = trimesh.util.diagonal_dot(points - origins[0], vectors[index_ray]) 192 | sign = trimesh.util.diagonal_dot( 193 | mesh.face_normals[index_tri], vectors[index_ray]) 194 | 195 | pixel_ray = pixels[index_ray] 196 | a = np.zeros(scene.camera.resolution, dtype=np.uint8) 197 | b = np.ones(scene.camera.resolution, dtype=np.int32) * -1 198 | p_image = np.ones([scene.camera.resolution[0], 199 | scene.camera.resolution[1], 3], dtype=np.float32) * -1 200 | 201 | a[pixel_ray[:, 0], pixel_ray[:, 1]] = depth 202 | b[pixel_ray[:, 0], pixel_ray[:, 1]] = index_tri 203 | p_image[pixel_ray[:, 0], pixel_ray[:, 1]] = points 204 | 205 | return a, b, index_tri, sign, p_image 206 | -------------------------------------------------------------------------------- /datasets/dualoctree_snet.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import ocnn 10 | import torch 11 | import numpy as np 12 | import copy 13 | 14 | from ocnn.octree import Octree, Points 15 | from solver import Dataset 16 | from .utils import collate_func 17 | 18 | 19 | class TransformShape: 20 | 21 | def __init__(self, flags): 22 | self.flags = flags 23 | 24 | self.depth = flags.depth 25 | self.full_depth = flags.full_depth 26 | 27 | self.point_sample_num = flags.point_sample_num 28 | self.point_scale = flags.point_scale 29 | self.noise_std = 0.005 30 | 31 | def points2octree(self, points: Points): 32 | octree = Octree(self.depth, self.full_depth) 33 | octree.build_octree(points) 34 | return octree 35 | 36 | def process_points_cloud(self, sample): 37 | # get the input 38 | points, normals = sample['points'], sample['normals'] 39 | points = points / self.point_scale # scale to [-1.0, 1.0] 40 | 41 | # transform points to octree 42 | points_gt = Points(points = torch.from_numpy(points).float(), normals = torch.from_numpy(normals).float()) 43 | if self.flags.load_color: 44 | points_gt.features = 
torch.from_numpy(sample['colors']).float() 45 | points_gt.clip(min=-1, max=1) 46 | 47 | return {'points': points_gt} 48 | 49 | def sample_sdf(self, sample): # 这里加载的sdf的坐标也都是在[-1,1]范围内的。 50 | sdf = sample['sdf'] 51 | grad = sample['grad'] 52 | points = sample['points'] / self.point_scale # to [-1, 1] 53 | 54 | rand_idx = np.random.choice(points.shape[0], size=self.point_sample_num) 55 | points = torch.from_numpy(points[rand_idx]).float() 56 | sdf = torch.from_numpy(sdf[rand_idx]).float() 57 | grad = torch.from_numpy(grad[rand_idx]).float() 58 | return {'pos': points, 'sdf': sdf, 'grad': grad} 59 | 60 | def sample_on_surface(self, points, normals): 61 | rand_idx = np.random.choice(points.shape[0], size=self.point_sample_num) 62 | xyz = torch.from_numpy(points[rand_idx]).float() 63 | grad = torch.from_numpy(normals[rand_idx]).float() 64 | sdf = torch.zeros(self.point_sample_num) 65 | return {'pos': xyz, 'sdf': sdf, 'grad': grad} 66 | 67 | def sample_off_surface(self, xyz): 68 | xyz = xyz / self.point_scale # to [-1, 1] 69 | 70 | rand_idx = np.random.choice(xyz.shape[0], size=self.point_sample_num) 71 | xyz = torch.from_numpy(xyz[rand_idx]).float() 72 | # grad = torch.zeros(self.sample_number, 3) # dummy grads 73 | grad = xyz / (xyz.norm(p=2, dim=1, keepdim=True) + 1.0e-6) 74 | sdf = -1 * torch.ones(self.point_sample_num) # dummy sdfs 75 | return {'pos': xyz, 'sdf': sdf, 'grad': grad} 76 | 77 | def __call__(self, sample, idx): 78 | output = {} 79 | 80 | if self.flags.load_octree: 81 | output['octree_in'] = sample['octree_in'] 82 | 83 | if self.flags.load_pointcloud: 84 | output = self.process_points_cloud(sample['point_cloud']) 85 | 86 | if self.flags.load_split_small: 87 | output['split_small'] = sample['split_small'] 88 | 89 | if self.flags.load_split_large: 90 | output['split_large'] = sample['split_large'] 91 | 92 | # sample ground truth sdfs 93 | if self.flags.load_sdf: 94 | sdf_samples = self.sample_sdf(sample['sdf']) 95 | output.update(sdf_samples) 96 | 97 | # sample on surface points and off surface points 98 | if self.flags.sample_surf_points: 99 | on_surf = self.sample_on_surface(sample['points'], sample['normals']) 100 | off_surf = self.sample_off_surface(sample['sdf']['points']) # TODO 101 | sdf_samples = { 102 | 'pos': torch.cat([on_surf['pos'], off_surf['pos']], dim=0), 103 | 'grad': torch.cat([on_surf['grad'], off_surf['grad']], dim=0), 104 | 'sdf': torch.cat([on_surf['sdf'], off_surf['sdf']], dim=0)} 105 | output.update(sdf_samples) 106 | 107 | return output 108 | 109 | 110 | class ReadFile: 111 | def __init__(self, flags): 112 | self.load_octree = flags.load_octree 113 | self.load_pointcloud = flags.load_pointcloud 114 | self.load_split_small = flags.load_split_small 115 | self.load_split_large = flags.load_split_large 116 | self.load_occu = flags.load_occu 117 | self.load_sdf = flags.load_sdf 118 | self.load_color = flags.load_color 119 | 120 | def __call__(self, filename): 121 | output = {} 122 | 123 | if self.load_octree: 124 | octree_path = os.path.join(filename, 'octree.pth') 125 | raw = torch.load(octree_path) 126 | octree_in = raw['octree_in'] 127 | output['octree_in'] = octree_in 128 | 129 | if self.load_pointcloud: 130 | filename_pc = os.path.join(filename, 'pointcloud.npz') 131 | raw = np.load(filename_pc) 132 | point_cloud = {'points': raw['points'], 'normals': raw['normals']} 133 | if self.load_color: 134 | filename_color = os.path.join(filename, 'color.npz') 135 | raw = np.load(filename_color) 136 | point_cloud['colors'] = raw['colors'] 137 | else: 138 | 
point_cloud['colors'] = None 139 | output['point_cloud'] = point_cloud 140 | 141 | 142 | if self.load_split_small: 143 | filename_split_small = os.path.join(filename, 'split_small.pth') 144 | raw = torch.load(filename_split_small, map_location = 'cpu') 145 | output['split_small'] = raw 146 | 147 | if self.load_split_large: 148 | filename_split_large = os.path.join(filename, 'split_large.pth') 149 | try: 150 | raw = torch.load(filename_split_large, map_location = 'cpu') 151 | except: 152 | print('Error!!') 153 | print(filename) 154 | output['split_large'] = raw 155 | 156 | if self.load_occu: 157 | filename_occu = os.path.join(filename, 'points.npz') 158 | raw = np.load(filename_occu) 159 | occu = {'points': raw['points'], 'occupancies': raw['occupancies']} 160 | output['occu'] = occu 161 | 162 | if self.load_sdf: 163 | filename_sdf = os.path.join(filename, 'sdf.npz') 164 | raw = np.load(filename_sdf) 165 | sdf = {'points': raw['points'], 'grad': raw['grad'], 'sdf': raw['sdf']} 166 | output['sdf'] = sdf 167 | 168 | return output 169 | 170 | 171 | def get_shapenet_dataset(flags): 172 | transform = TransformShape(flags) 173 | read_file = ReadFile(flags) 174 | dataset = Dataset(flags.location, flags.filelist, transform, 175 | read_file=read_file, in_memory=flags.in_memory) 176 | return dataset, collate_func 177 | -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | from termcolor import colored 5 | from omegaconf import OmegaConf 6 | 7 | import torch 8 | from torch.utils.tensorboard import SummaryWriter 9 | 10 | import utils 11 | 12 | from utils.distributed import ( 13 | get_rank, 14 | synchronize, 15 | ) 16 | 17 | def str2bool(v): 18 | if isinstance(v, bool): 19 | return v 20 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 21 | return True 22 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 23 | return False 24 | else: 25 | raise argparse.ArgumentTypeError('Boolean value expected.') 26 | 27 | class BaseOptions(): 28 | def __init__(self): 29 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 30 | self.initialized = False 31 | 32 | def initialize(self): 33 | # hyper parameters 34 | self.parser.add_argument('--batch_size', type=int, default=2, help='input batch size') 35 | self.parser.add_argument('--gpu_ids', type=str, default='1', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 36 | 37 | # log stuff 38 | self.parser.add_argument('--logs_dir', type=str, default='./logs', help='the root of the logs dir. All training logs are saved here') 39 | self.parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models') 40 | 41 | # dataset stuff 42 | self.parser.add_argument('--dataset_mode', type=str, default='snet', help='chooses how datasets are loaded. [snet, obja]') 43 | self.parser.add_argument('--num_times', type=int, default=2, help='dataset resolution') 44 | self.parser.add_argument('--category', type=str, default='chair', help='category for shapenet') 45 | self.parser.add_argument('--split_dir', type=str, default=None, help='split path') 46 | self.parser.add_argument('--trunc_thres', type=float, default=0.2, help='threshold for truncated sdf.') 47 | 48 | self.parser.add_argument('--ratio', type=float, default=1., help='ratio of the dataset to use. 
for debugging and overfitting') 49 | 50 | ############## START: model related options ################ 51 | self.parser.add_argument( 52 | '--model', type=str, default='union_2t', 53 | choices=['union_2t', 'union_3t', 'vae'], 54 | help='chooses which model to use.' 55 | ) 56 | self.parser.add_argument( 57 | '--stage_flag', type=str, default='lr', 58 | choices=['lr', 'hr', 'feature'], 59 | help='chooses which model to use.' 60 | ) 61 | self.parser.add_argument('--ckpt', type=str, default=None, help='ckpt to load.') 62 | self.parser.add_argument('--pretrain_ckpt', type=str, default=None, help='pretrain ckpt to load.') 63 | 64 | # diffusion stuff 65 | self.parser.add_argument('--df_cfg', type=str, default='configs/octfusion_snet.yaml', help="diffusion model's config file") 66 | self.parser.add_argument('--ddim_steps', type=int, default=100, help='steps for ddim sampler') 67 | self.parser.add_argument('--ddim_eta', type=float, default=0.0) 68 | self.parser.add_argument('--uc_scale', type=float, default=1.0, help='scale for un guidance') 69 | 70 | # vqvae stuff 71 | self.parser.add_argument('--vq_model', type=str, default='',choices=['vqvae', 'GraphAE', 'GraphVQVAE','GraphVAE'], help='for choosing the vqvae model to use.') 72 | 73 | self.parser.add_argument('--vq_cfg', type=str, default='configs/vqvae_snet.yaml', help='vqvae model config file') 74 | self.parser.add_argument('--vq_ckpt', type=str, default=None, help='vqvae ckpt to load.') 75 | 76 | # dualocnn stuff 77 | self.parser.add_argument('--sync_bn', type=bool, default=False, help='whether to use torch.nn.SyncBatchNorm.') 78 | ############## END: model related options ################ 79 | 80 | # misc 81 | self.parser.add_argument('--debug', default='0', type=str, choices=['0', '1'], help='if true, debug mode') 82 | self.parser.add_argument('--seed', default=0, type=int, help='seed') 83 | 84 | # multi-gpu stuff 85 | self.parser.add_argument("--backend", type=str, default="gloo", help="which backend to use") 86 | self.parser.add_argument("--local_rank", type=int, default=0, help="local rank for distributed training") 87 | 88 | self.initialized = True 89 | 90 | def parse_and_setup(self): 91 | import sys 92 | cmd = ' '.join(sys.argv) 93 | print(f'python {cmd}') 94 | 95 | if not self.initialized: 96 | self.initialize() 97 | 98 | self.opt = self.parser.parse_args() 99 | self.opt.isTrain = str2bool(self.opt.isTrain) 100 | 101 | if self.opt.isTrain: 102 | self.opt.phase = 'train' 103 | else: 104 | self.opt.phase = 'test' 105 | 106 | # setup multi-gpu stuffs here 107 | # basically from stylegan2-pytorch, train.py by rosinality 108 | self.opt.device = 'cuda' 109 | n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 110 | self.opt.distributed = n_gpu > 1 111 | 112 | self.opt.local_rank = int(os.environ["LOCAL_RANK"]) if "LOCAL_RANK" in os.environ else 0 113 | if self.opt.distributed: 114 | torch.cuda.set_device(self.opt.local_rank) 115 | torch.distributed.init_process_group(backend=self.opt.backend, init_method="env://") 116 | synchronize() 117 | 118 | name = self.opt.name 119 | 120 | self.opt.name = name 121 | 122 | self.opt.gpu_ids_str = self.opt.gpu_ids 123 | 124 | # NOTE: seed or not? 
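# The --seed argument is parsed above but, with the lines below left commented
# out, it is never applied in this method; the `seed_everything` helper in
# utils.util (imported directly in train.py) is the function that would
# consume it.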
125 | # seed = opt.seed 126 | # util.seed_everything(seed) 127 | 128 | self.opt.rank = get_rank() 129 | 130 | if get_rank() == 0: 131 | # print args 132 | args = vars(self.opt) 133 | 134 | print('------------ Options -------------') 135 | for k, v in sorted(args.items()): 136 | print('%s: %s' % (str(k), str(v))) 137 | print('-------------- End ----------------') 138 | 139 | # make experiment dir 140 | if self.opt.isTrain: 141 | expr_dir = os.path.join(self.opt.logs_dir, self.opt.name) 142 | utils.util.mkdirs(expr_dir) 143 | 144 | ckpt_dir = os.path.join(self.opt.logs_dir, self.opt.name, 'ckpt') 145 | if not os.path.exists(ckpt_dir): 146 | os.makedirs(ckpt_dir) 147 | self.opt.ckpt_dir = ckpt_dir 148 | 149 | file_name = os.path.join(expr_dir, 'opt.txt') 150 | with open(file_name, 'wt') as opt_file: 151 | opt_file.write('------------ Options -------------\n') 152 | for k, v in sorted(args.items()): 153 | opt_file.write('%s: %s\n' % (str(k), str(v))) 154 | opt_file.write('-------------- End ----------------\n') 155 | 156 | # tensorboard writer 157 | tb_dir = '%s/tboard' % expr_dir 158 | if not os.path.exists(tb_dir): 159 | os.makedirs(tb_dir) 160 | self.opt.tb_dir = tb_dir 161 | writer = SummaryWriter(log_dir=tb_dir) 162 | self.opt.writer = writer 163 | 164 | return self.opt 165 | -------------------------------------------------------------------------------- /models/networks/dualoctree_networks/loss.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import ocnn 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | 13 | def compute_gradient(y, x): 14 | if x.dtype is not torch.float32: 15 | x = x.to(torch.float32) 16 | if y.dtype is not torch.float32: 17 | y = y.to(torch.float32) 18 | grad_outputs = torch.ones_like(y) 19 | grad = torch.autograd.grad(y, [x], grad_outputs, create_graph=True)[0] 20 | return grad 21 | 22 | 23 | def sdf_reg_loss(sdf, grad, sdf_gt, grad_gt, name_suffix=''): 24 | wg, ws = 1.0, 200.0 25 | grad_loss = (grad - grad_gt).pow(2).mean() * wg 26 | sdf_loss = (sdf - sdf_gt).pow(2).mean() * ws 27 | loss_dict = {'grad_loss' + name_suffix: grad_loss, 28 | 'sdf_loss' + name_suffix: sdf_loss} 29 | return loss_dict 30 | 31 | 32 | def sdf_grad_loss(sdf, grad, sdf_gt, grad_gt, name_suffix=''): 33 | on_surf = sdf_gt != -1 34 | off_surf = on_surf.logical_not() 35 | 36 | sdf_loss = sdf[on_surf].pow(2).mean() * 200.0 37 | norm_loss = (grad[on_surf] - grad_gt[on_surf]).pow(2).mean() * 1.0 38 | intr_loss = torch.exp(-40 * torch.abs(sdf[off_surf])).mean() * 0.1 39 | grad_loss = (grad[off_surf].norm(2, dim=-1) - 1).abs().mean() * 0.1 40 | 41 | losses = [sdf_loss, intr_loss, norm_loss, grad_loss] 42 | names = ['sdf_loss', 'inter_loss', 'norm_loss', 'grad_loss'] 43 | names = [name + name_suffix for name in names] 44 | loss_dict = dict(zip(names, losses)) 45 | return loss_dict 46 | 47 | 48 | def sdf_grad_regularized_loss(sdf, grad, sdf_gt, grad_gt, name_suffix=''): 49 | on_surf = sdf_gt != -1 50 | off_surf = on_surf.logical_not() 51 | 52 | sdf_loss = sdf[on_surf].pow(2).mean() * 200.0 53 | norm_loss = (grad[on_surf] - grad_gt[on_surf]).pow(2).mean() * 1.0 54 | intr_loss = torch.exp(-40 * torch.abs(sdf[off_surf])).mean() * 0.1 55 | grad_loss = 
(grad[off_surf].norm(2, dim=-1) - 1).abs().mean() * 0.1 56 | grad_reg_loss = (grad[off_surf] - grad_gt[off_surf]).pow(2).mean() * 0.1 57 | 58 | losses = [sdf_loss, intr_loss, norm_loss, grad_loss, grad_reg_loss] 59 | names = ['sdf_loss', 'inter_loss', 'norm_loss', 'grad_loss', 'grad_reg_loss'] 60 | names = [name + name_suffix for name in names] 61 | loss_dict = dict(zip(names, losses)) 62 | return loss_dict 63 | 64 | 65 | def possion_grad_loss(sdf, grad, sdf_gt, grad_gt, name_suffix=''): 66 | on_surf = sdf_gt == 0 67 | out_of_bbox = sdf_gt == 1.0 68 | off_surf = on_surf.logical_not() 69 | 70 | sdf_loss = sdf[on_surf].pow(2).mean() * 200.0 71 | norm_loss = (grad[on_surf] - grad_gt[on_surf]).pow(2).mean() * 1.0 72 | intr_loss = torch.exp(-40 * torch.abs(sdf[off_surf])).mean() * 0.1 73 | grad_loss = grad[off_surf].pow(2).mean() * 0.1 # poisson loss 74 | bbox_loss = torch.mean(torch.relu(-sdf[out_of_bbox])) * 100.0 75 | 76 | losses = [sdf_loss, intr_loss, norm_loss, grad_loss, bbox_loss] 77 | names = ['sdf_loss', 'inter_loss', 'norm_loss', 'grad_loss', 'bbox_loss'] 78 | names = [name + name_suffix for name in names] 79 | loss_dict = dict(zip(names, losses)) 80 | return loss_dict 81 | 82 | def color_loss(color, gt_color, name_suffix=''): 83 | loss_dict = dict() 84 | # TODO: color loss type 85 | color_loss = (color - gt_color).pow(2).mean() * 200.0 86 | loss_dict['color_loss' + name_suffix] = color_loss 87 | 88 | return loss_dict 89 | 90 | def compute_mpu_gradients(mpus, pos, fval_transform=None): 91 | grads = dict() 92 | for d in mpus.keys(): 93 | fval, flags = mpus[d] 94 | if fval_transform is not None: 95 | fval = fval_transform(fval) 96 | grads[d] = compute_gradient(fval, pos)[:, :3] 97 | return grads 98 | 99 | 100 | def compute_octree_loss(logits, octree_out): 101 | weights = [1.0] * 16 102 | # weights = [1.0] * 4 + [0.8, 0.6, 0.4] + [0.2] * 16 103 | 104 | output = dict() 105 | for d in logits.keys(): 106 | logitd = logits[d] 107 | label_gt = octree_out.nempty_mask(d).long() 108 | # label_gt = ocnn.octree_property(octree_out, 'split', d).long() 109 | output['loss_%d' % d] = F.cross_entropy(logitd, label_gt) * weights[d] 110 | output['accu_%d' % d] = logitd.argmax(1).eq(label_gt).float().mean() 111 | return output 112 | 113 | 114 | def compute_sdf_loss(mpus, grads, sdf_gt, grad_gt, reg_loss_func): 115 | output = dict() 116 | for d in mpus.keys(): 117 | sdf, flgs = mpus[d] # TODO: tune the loss weights and `flgs` 118 | reg_loss = reg_loss_func(sdf, grads[d], sdf_gt, grad_gt, '_%d' % d) 119 | # if d < 3: # ignore depth 2 120 | # for key in reg_loss.keys(): 121 | # reg_loss[key] = reg_loss[key] * 0.0 122 | output.update(reg_loss) 123 | return output 124 | 125 | def compute_color_loss(colors, color_gt): 126 | output = dict() 127 | for d in colors.keys(): 128 | color, flgs = colors[d] # TODO: tune the loss weights and `flgs` 129 | reg_loss = color_loss(color, color_gt,'_%d' % d) 130 | output.update(reg_loss) 131 | return output 132 | 133 | def compute_occu_loss_v0(mpus, grads, occu_gt, grad_gt, weight): 134 | output = dict() 135 | for d in mpus.keys(): 136 | occu, flgs, grad = mpus[d] 137 | 138 | # pos_weight = torch.ones_like(occu_gt) * 10.0 139 | loss_o = F.binary_cross_entropy_with_logits(occu, occu_gt, weight=weight) 140 | # loss_g = torch.mean((grad - grad_gt) ** 2) 141 | 142 | occu = torch.sigmoid(occu) 143 | non_surface_points = occu_gt != 0.5 144 | accu = (occu > 0.5).eq(occu_gt).float()[non_surface_points].mean() 145 | 146 | output['occu_loss_%d' % d] = loss_o 147 | # 
output['grad_loss_%d' % d] = loss_g 148 | output['occu_accu_%d' % d] = accu 149 | return output 150 | 151 | def get_sdf_loss_function(loss_type=''): 152 | if loss_type == 'sdf_reg_loss': 153 | return sdf_reg_loss 154 | elif loss_type == 'sdf_grad_loss': 155 | return sdf_grad_loss 156 | elif loss_type == 'possion_grad_loss': 157 | return possion_grad_loss 158 | elif loss_type == 'sdf_grad_reg_loss': 159 | return sdf_grad_regularized_loss 160 | else: 161 | return None 162 | 163 | 164 | def geometry_loss(batch, model_out, reg_loss_type='', codebook_weight = 1.0, kl_weight = 1.0): 165 | # octree loss 166 | output = compute_octree_loss(model_out['logits'], model_out['octree_out']) 167 | 168 | # regression loss 169 | grads = compute_mpu_gradients(model_out['mpus'], batch['pos']) 170 | reg_loss_func = get_sdf_loss_function(reg_loss_type) 171 | sdf_loss = compute_sdf_loss( 172 | model_out['mpus'], grads, batch['sdf'], batch['grad'], reg_loss_func) 173 | output.update(sdf_loss) 174 | if 'emb_loss' in model_out.keys(): 175 | output['emb_loss'] = codebook_weight * model_out['emb_loss'] # 只用于graph_vqvae 176 | if 'kl_loss' in model_out.keys(): 177 | output['kl_loss'] = kl_weight * model_out['kl_loss'] # 只用于graph_vae 178 | return output 179 | 180 | def geometry_color_loss(batch, model_out, reg_loss_type='', codebook_weight = 1.0, kl_weight = 1.0): 181 | # octree loss 182 | output = compute_octree_loss(model_out['logits'], model_out['octree_out']) 183 | 184 | # regression loss 185 | grads = compute_mpu_gradients(model_out['mpus'], batch['pos']) 186 | reg_loss_func = get_sdf_loss_function(reg_loss_type) 187 | sdf_loss = compute_sdf_loss( 188 | model_out['mpus'], grads, batch['sdf'], batch['grad'], reg_loss_func) 189 | output.update(sdf_loss) 190 | color_loss = compute_color_loss(model_out['colors'], batch['color']) 191 | output.update(color_loss) 192 | if 'emb_loss' in model_out.keys(): 193 | output['emb_loss'] = codebook_weight * model_out['emb_loss'] # 只用于graph_vqvae 194 | if 'kl_loss' in model_out.keys(): 195 | output['kl_loss'] = kl_weight * model_out['kl_loss'] # 只用于graph_vae 196 | return output 197 | 198 | 199 | def dfaust_loss(batch, model_out, reg_loss_type=''): 200 | # there is no octree loss 201 | grads = compute_mpu_gradients(model_out['mpus'], batch['pos']) 202 | reg_loss_func = get_sdf_loss_function(reg_loss_type) 203 | output = compute_sdf_loss( 204 | model_out['mpus'], grads, batch['sdf'], batch['grad'], reg_loss_func) 205 | return output 206 | -------------------------------------------------------------------------------- /metrics/evaluation_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import warnings 4 | from scipy.stats import entropy 5 | from sklearn.neighbors import NearestNeighbors 6 | from numpy.linalg import norm 7 | from scipy.optimize import linear_sum_assignment 8 | from tqdm import tqdm 9 | 10 | # Borrow from https://github.com/ThibaultGROUEIX/AtlasNet 11 | def distChamfer(a, b): 12 | x, y = a, b 13 | bs, num_points, points_dim = x.size() 14 | xx = torch.bmm(x, x.transpose(2, 1)) 15 | yy = torch.bmm(y, y.transpose(2, 1)) 16 | zz = torch.bmm(x, y.transpose(2, 1)) 17 | diag_ind = torch.arange(0, num_points).to(a).long() 18 | rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx) 19 | ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy) 20 | P = (rx.transpose(2, 1) + ry - 2 * zz) 21 | return P.min(1)[0], P.min(2)[0] 22 | 23 | # Import CUDA version of approximate EMD, from 
https://github.com/zekunhao1995/pcgan-pytorch/ 24 | try: 25 | from .StructuralLosses.nn_distance import nn_distance 26 | def distChamferCUDA(x, y): 27 | return nn_distance(x, y) 28 | except Exception as e: 29 | print(str(e)) 30 | print("distChamferCUDA not available; fall back to slower version.") 31 | def distChamferCUDA(x, y): 32 | return distChamfer(x, y) 33 | 34 | 35 | def emd_approx(x, y): 36 | bs, npts, mpts, dim = x.size(0), x.size(1), y.size(1), x.size(2) 37 | assert npts == mpts, "EMD only works if two point clouds are equal size" 38 | dim = x.shape[-1] 39 | x = x.reshape(bs, npts, 1, dim) 40 | y = y.reshape(bs, 1, mpts, dim) 41 | dist = (x - y).norm(dim=-1, keepdim=False) # (bs, npts, mpts) 42 | 43 | emd_lst = [] 44 | dist_np = dist.cpu().detach().numpy() 45 | for i in range(bs): 46 | d_i = dist_np[i] 47 | r_idx, c_idx = linear_sum_assignment(d_i) 48 | emd_i = d_i[r_idx, c_idx].mean() 49 | emd_lst.append(emd_i) 50 | emd = np.stack(emd_lst).reshape(-1) 51 | emd_torch = torch.from_numpy(emd).to(x) 52 | return emd_torch 53 | 54 | 55 | try: 56 | from .StructuralLosses.match_cost import match_cost 57 | def emd_approx_cuda(sample, ref): 58 | B, N, N_ref = sample.size(0), sample.size(1), ref.size(1) 59 | assert N == N_ref, "Not sure what would EMD do in this case" 60 | emd = match_cost(sample, ref) # (B,) 61 | emd_norm = emd / float(N) # (B,) 62 | return emd_norm 63 | except Exception as e: 64 | print(str(e)) 65 | print("emd_approx_cuda not available. Fall back to slower version.") 66 | def emd_approx_cuda(sample, ref): 67 | return emd_approx(sample, ref) 68 | 69 | 70 | def EMD_CD(sample_pcs, ref_pcs, batch_size, accelerated_cd=False, reduced=True, 71 | accelerated_emd=False): 72 | N_sample = sample_pcs.shape[0] 73 | N_ref = ref_pcs.shape[0] 74 | assert N_sample == N_ref, "REF:%d SMP:%d" % (N_ref, N_sample) 75 | 76 | cd_lst = [] 77 | emd_lst = [] 78 | iterator = range(0, N_sample, batch_size) 79 | 80 | for b_start in iterator: 81 | b_end = min(N_sample, b_start + batch_size) 82 | sample_batch = sample_pcs[b_start:b_end] 83 | ref_batch = ref_pcs[b_start:b_end] 84 | 85 | if accelerated_cd: 86 | dl, dr = distChamferCUDA(sample_batch, ref_batch) 87 | else: 88 | dl, dr = distChamfer(sample_batch, ref_batch) 89 | cd_lst.append(dl.mean(dim=1) + dr.mean(dim=1)) 90 | 91 | if accelerated_emd: 92 | emd_batch = emd_approx_cuda(sample_batch, ref_batch) 93 | else: 94 | emd_batch = emd_approx(sample_batch, ref_batch) 95 | emd_lst.append(emd_batch) 96 | 97 | if reduced: 98 | cd = torch.cat(cd_lst).mean() 99 | emd = torch.cat(emd_lst).mean() 100 | else: 101 | cd = torch.cat(cd_lst) 102 | emd = torch.cat(emd_lst) 103 | 104 | results = { 105 | 'MMD-CD': cd, 106 | 'MMD-EMD': emd, 107 | } 108 | return results 109 | 110 | 111 | def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, accelerated_cd=True, 112 | accelerated_emd=True): 113 | accelerated_cd = True 114 | accelerated_emd = True 115 | N_sample = sample_pcs.shape[0] 116 | N_ref = ref_pcs.shape[0] 117 | all_cd = [] 118 | all_emd = [] 119 | iterator = range(N_sample) 120 | for sample_b_start in tqdm(iterator): 121 | sample_batch = sample_pcs[sample_b_start] 122 | 123 | cd_lst = [] 124 | emd_lst = [] 125 | for ref_b_start in range(0, N_ref, batch_size): 126 | ref_b_end = min(N_ref, ref_b_start + batch_size) 127 | ref_batch = ref_pcs[ref_b_start:ref_b_end] 128 | 129 | batch_size_ref = ref_batch.size(0) 130 | sample_batch_exp = sample_batch.view(1, -1, 3).expand(batch_size_ref, -1, -1) 131 | sample_batch_exp = sample_batch_exp.contiguous() 132 | 133 | if 
accelerated_cd and distChamferCUDA is not None: 134 | dl, dr = distChamferCUDA(sample_batch_exp, ref_batch) 135 | else: 136 | dl, dr = distChamfer(sample_batch_exp, ref_batch) 137 | cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1)) 138 | 139 | if accelerated_emd: 140 | emd_batch = emd_approx_cuda(sample_batch_exp, ref_batch) 141 | else: 142 | emd_batch = emd_approx(sample_batch_exp, ref_batch) 143 | emd_lst.append(emd_batch.view(1, -1)) 144 | 145 | cd_lst = torch.cat(cd_lst, dim=1) 146 | emd_lst = torch.cat(emd_lst, dim=1) 147 | all_cd.append(cd_lst) 148 | all_emd.append(emd_lst) 149 | 150 | all_cd = torch.cat(all_cd, dim=0) # N_sample, N_ref 151 | all_emd = torch.cat(all_emd, dim=0) # N_sample, N_ref 152 | 153 | return all_cd, all_emd 154 | 155 | 156 | # Adapted from https://github.com/xuqiantong/GAN-Metrics/blob/master/framework/metric.py 157 | def knn(Mxx, Mxy, Myy, k, sqrt=False): 158 | n0 = Mxx.size(0) 159 | n1 = Myy.size(0) 160 | label = torch.cat((torch.ones(n0), torch.zeros(n1))).to(Mxx) 161 | M = torch.cat((torch.cat((Mxx, Mxy), 1), torch.cat((Mxy.transpose(0, 1), Myy), 1)), 0) 162 | if sqrt: 163 | M = M.abs().sqrt() 164 | INFINITY = float('inf') 165 | val, idx = (M + torch.diag(INFINITY * torch.ones(n0 + n1).to(Mxx))).topk(k, 0, False) 166 | 167 | count = torch.zeros(n0 + n1).to(Mxx) 168 | for i in range(0, k): 169 | count = count + label.index_select(0, idx[i]) 170 | pred = torch.ge(count, (float(k) / 2) * torch.ones(n0 + n1).to(Mxx)).float() 171 | 172 | s = { 173 | 'tp': (pred * label).sum(), 174 | 'fp': (pred * (1 - label)).sum(), 175 | 'fn': ((1 - pred) * label).sum(), 176 | 'tn': ((1 - pred) * (1 - label)).sum(), 177 | } 178 | 179 | s.update({ 180 | 'precision': s['tp'] / (s['tp'] + s['fp'] + 1e-10), 181 | 'recall': s['tp'] / (s['tp'] + s['fn'] + 1e-10), 182 | 'acc_t': s['tp'] / (s['tp'] + s['fn'] + 1e-10), 183 | 'acc_f': s['tn'] / (s['tn'] + s['fp'] + 1e-10), 184 | 'acc': torch.eq(label, pred).float().mean(), 185 | }) 186 | return s 187 | 188 | 189 | def lgan_mmd_cov(all_dist): 190 | N_sample, N_ref = all_dist.size(0), all_dist.size(1) 191 | min_val_fromsmp, min_idx = torch.min(all_dist, dim=1) 192 | min_val, _ = torch.min(all_dist, dim=0) 193 | mmd = min_val.mean() 194 | mmd_smp = min_val_fromsmp.mean() 195 | cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref) 196 | cov = torch.tensor(cov).to(all_dist) 197 | return { 198 | 'lgan_mmd': mmd, 199 | 'lgan_cov': cov, 200 | 'lgan_mmd_smp': mmd_smp, 201 | } 202 | 203 | 204 | def compute_cov_mmd(sample_pcs, ref_pcs, batch_size): 205 | results = {} 206 | accelerated_cd = True 207 | M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size, accelerated_cd=accelerated_cd) 208 | 209 | res_cd = lgan_mmd_cov(M_rs_cd.t()) 210 | results.update({ 211 | "%s-CD" % k: v for k, v in res_cd.items() 212 | }) 213 | 214 | res_emd = lgan_mmd_cov(M_rs_emd.t()) 215 | results.update({ 216 | "%s-EMD" % k: v for k, v in res_emd.items() 217 | }) 218 | return results 219 | 220 | 221 | def compute_1_nna(sample_pcs, ref_pcs, batch_size): 222 | results = {} 223 | accelerated_cd = True 224 | M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size, accelerated_cd=accelerated_cd) 225 | M_rr_cd, M_rr_emd = _pairwise_EMD_CD_(ref_pcs, ref_pcs, batch_size, accelerated_cd=accelerated_cd) 226 | M_ss_cd, M_ss_emd = _pairwise_EMD_CD_(sample_pcs, sample_pcs, batch_size, accelerated_cd=accelerated_cd) 227 | 228 | # 1-NN results 229 | one_nn_cd_res = knn(M_rr_cd, M_rs_cd, M_ss_cd, 1, sqrt=False) 230 | results.update({ 231 | 
"1-NN-CD-%s" % k: v for k, v in one_nn_cd_res.items() if 'acc' in k 232 | }) 233 | one_nn_emd_res = knn(M_rr_emd, M_rs_emd, M_ss_emd, 1, sqrt=False) 234 | results.update({ 235 | "1-NN-EMD-%s" % k: v for k, v in one_nn_emd_res.items() if 'acc' in k 236 | }) 237 | 238 | return results 239 | 240 | 241 | if __name__ == "__main__": 242 | B, N = 2, 10 243 | x = torch.rand(B, N, 3) 244 | y = torch.rand(B, N, 3) 245 | 246 | distChamfer = distChamferCUDA 247 | min_l, min_r = distChamfer(x.cuda(), y.cuda()) 248 | print(min_l.shape) 249 | print(min_r.shape) 250 | 251 | l_dist = min_l.mean().cpu().detach().item() 252 | r_dist = min_r.mean().cpu().detach().item() 253 | print(l_dist, r_dist) 254 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import inspect 4 | import random 5 | 6 | from termcolor import colored, cprint 7 | from tqdm import tqdm 8 | 9 | import torch.backends.cudnn as cudnn 10 | # cudnn.benchmark = True 11 | 12 | from options.train_options import TrainOptions 13 | from datasets.dataloader import config_dataloader, get_data_generator 14 | from models.base_model import create_model 15 | 16 | import torch.multiprocessing 17 | torch.multiprocessing.set_sharing_strategy('file_descriptor') 18 | 19 | from utils.distributed import ( 20 | get_rank, 21 | synchronize, 22 | reduce_loss_dict, 23 | reduce_sum, 24 | get_world_size, 25 | ) 26 | 27 | from utils.util import seed_everything, category_5_to_label, category_5_to_num 28 | 29 | import torch 30 | from utils.visualizer import Visualizer 31 | 32 | 33 | def train_main_worker(opt, model, train_loader, test_loader, visualizer): 34 | 35 | if get_rank() == 0: 36 | cprint('[*] Start training. 
name: %s' % opt.name, 'blue') 37 | 38 | train_dg = get_data_generator(train_loader) 39 | test_dg = get_data_generator(test_loader) 40 | 41 | epoch_length = len(train_loader) 42 | print('The epoch length is', epoch_length) 43 | 44 | total_iters = epoch_length * opt.epochs 45 | start_iter = opt.start_iter 46 | 47 | epoch = start_iter // epoch_length 48 | 49 | # pbar = tqdm(total=total_iters) 50 | pbar = tqdm(range(start_iter, total_iters)) 51 | 52 | iter_start_time = time.time() 53 | for iter_i in range(start_iter, total_iters): 54 | 55 | opt.iter_i = iter_i 56 | iter_ip1 = iter_i + 1 57 | 58 | if get_rank() == 0: 59 | visualizer.reset() 60 | 61 | data = next(train_dg) 62 | data['iter_num'] = iter_i 63 | data['epoch'] = epoch 64 | model.set_input(data) 65 | model.optimize_parameters() 66 | 67 | # if torch.isnan(model.loss).any() == True: 68 | # break 69 | 70 | if get_rank() == 0: 71 | pbar.update(1) 72 | if iter_i % opt.print_freq == 0: 73 | errors = model.get_current_errors() 74 | 75 | t = (time.time() - iter_start_time) / opt.batch_size 76 | visualizer.print_current_errors(iter_i, errors, t) 77 | 78 | if iter_ip1 % opt.save_latest_freq == 0: 79 | cprint('saving the latest model (current_iter %d)' % (iter_i), 'blue') 80 | latest_name = f'steps-latest' 81 | model.save(latest_name, iter_ip1) 82 | 83 | # save every 3000 steps (batches) 84 | if iter_ip1 % opt.save_steps_freq == 0: 85 | cprint('saving the model at iters %d' % iter_ip1, 'blue') 86 | latest_name = f'steps-latest' 87 | model.save(latest_name, iter_ip1) 88 | cur_name = f'steps-{iter_ip1}' 89 | model.save(cur_name, iter_ip1) 90 | 91 | cprint(f'[*] End of steps %d \t Time Taken: %d sec \n%s' % 92 | ( 93 | iter_ip1, 94 | time.time() - iter_start_time, 95 | os.path.abspath(os.path.join(opt.logs_dir, opt.name)) 96 | ), 'blue', attrs=['bold'] 97 | ) 98 | 99 | if iter_i % epoch_length == epoch_length - 1: 100 | print('Finish One Epoch!') 101 | epoch += 1 102 | print('Now Epoch is:', epoch) 103 | 104 | # display every n batches 105 | if iter_i % opt.display_freq == 0: 106 | if iter_i == 0 and opt.debug == "0": 107 | pbar.update(1) 108 | continue 109 | 110 | # eval 111 | if opt.model == "vae": 112 | data = next(test_dg) 113 | data['iter_num'] = iter_i 114 | data['epoch'] = epoch 115 | model.set_input(data) 116 | model.inference(save_folder = f'temp/{iter_i}') 117 | else: 118 | if opt.category == "im_5": 119 | category = random.choice(list(category_5_to_num.keys())) 120 | else: 121 | category = opt.category 122 | 123 | model.sample(category = category, prefix = 'results', ema = True, ddim_steps = 200, save_index = iter_i) 124 | 125 | # torch.cuda.empty_cache() 126 | 127 | if opt.update_learning_rate: 128 | model.update_learning_rate_cos(epoch, opt) 129 | 130 | 131 | 132 | def generate_vae(opt, model, test_loader): 133 | if get_rank() == 0: 134 | cprint('[*] Start training. 
name: %s' % opt.name, 'blue') 135 | 136 | test_dg = get_data_generator(test_loader) 137 | 138 | epoch_length = len(test_loader)  # number of batches in the loader actually iterated below 139 | print('The epoch length is', epoch_length) 140 | 141 | total_iters = epoch_length 142 | start_iter = 0 143 | 144 | # pbar = tqdm(total=total_iters) 145 | pbar = tqdm(range(start_iter, total_iters)) 146 | 147 | for iter_i in range(start_iter, total_iters): 148 | 149 | data = next(test_dg) 150 | data['iter_num'] = iter_i 151 | data['epoch'] = 0 152 | model.set_input(data) 153 | seed_everything(opt.seed) 154 | model.inference() 155 | pbar.update(1) 156 | 157 | 158 | def generate(opt, model): 159 | 160 | # get n_epochs here 161 | total_iters = 100000000 162 | pbar = tqdm(total=total_iters) 163 | 164 | total_num = category_5_to_num[opt.category] 165 | 166 | for iter_i in range(total_iters): 167 | 168 | result_index = iter_i * get_world_size() + get_rank() 169 | if opt.split_dir is not None: 170 | split_path = os.path.join(opt.split_dir, f'{result_index}.pth') 171 | split_small = torch.load(split_path) 172 | split_small = split_small.to(model.device) 173 | else: 174 | split_small = None 175 | model.batch_size = 1 176 | 177 | if result_index >= total_num: 178 | break 179 | 180 | if opt.category == "im_5": 181 | category = random.choice(list(category_5_to_label.keys())) 182 | else: 183 | category = opt.category 184 | model.sample(split_small = split_small, category = category, prefix = 'results', ema = True, ddim_steps = 200, clean = False, save_index = result_index) 185 | pbar.update(1) 186 | 187 | if __name__ == "__main__": 188 | # this will parse args, setup log_dirs, multi-gpus 189 | opt = TrainOptions().parse_and_setup() 190 | device = opt.device 191 | rank = opt.rank 192 | 193 | # CUDA_VISIBLE_DEVICES = int(os.environ["LOCAL_RANK"]) 194 | # import pdb; pdb.set_trace() 195 | 196 | # get current time, print at terminal.
easier to track exp 197 | from datetime import datetime 198 | opt.exp_time = datetime.now().strftime('%Y-%m-%dT%H-%M') 199 | 200 | # main loop 201 | model = create_model(opt) 202 | opt.start_iter = model.start_iter 203 | cprint(f'[*] "{opt.model}" initialized.', 'cyan') 204 | 205 | # visualizer 206 | visualizer = Visualizer(opt) 207 | if get_rank() == 0: 208 | visualizer.setup_io() 209 | 210 | # save model and dataset files 211 | if get_rank() == 0: 212 | expr_dir = '%s/%s' % (opt.logs_dir, opt.name) 213 | model_f = inspect.getfile(model.__class__) 214 | modelf_out = os.path.join(expr_dir, os.path.basename(model_f)) 215 | os.system(f'cp {model_f} {modelf_out}') 216 | if opt.model != "vae": 217 | unet_f = inspect.getfile(model.df_module.__class__) 218 | unetf_out = os.path.join(expr_dir, os.path.basename(unet_f)) 219 | os.system(f'cp {unet_f} {unetf_out}') 220 | dset_f = "datasets/dualoctree_snet.py" 221 | dsetf_out = os.path.join(expr_dir, os.path.basename(dset_f)) 222 | os.system(f'cp {dset_f} {dsetf_out}') 223 | sh_f = 'scripts/run_snet_uncond.sh' 224 | sh_out = os.path.join(expr_dir, os.path.basename(sh_f)) 225 | os.system(f'cp {sh_f} {sh_out}') 226 | train_f = 'train.py' 227 | train_out = os.path.join(expr_dir, os.path.basename(train_f)) 228 | os.system(f'cp {train_f} {train_out}') 229 | 230 | if opt.vq_cfg is not None: 231 | vq_cfg = opt.vq_cfg 232 | cfg_out = os.path.join(expr_dir, os.path.basename(vq_cfg)) 233 | os.system(f'cp {vq_cfg} {cfg_out}') 234 | 235 | if opt.df_cfg is not None: 236 | df_cfg = opt.df_cfg 237 | cfg_out = os.path.join(expr_dir, os.path.basename(df_cfg)) 238 | os.system(f'cp {df_cfg} {cfg_out}') 239 | if opt.mode == 'train': 240 | # if opt.debug == "0": 241 | # # try: 242 | # # train_main_worker(opt, model, train_loader, test_loader, visualizer) 243 | # # except: 244 | # # import traceback 245 | # # print(traceback.format_exc(), flush=True) 246 | # # with open(os.path.join(opt.logs_dir, opt.name, "error.txt"), "a") as f: 247 | # # f.write(traceback.format_exc() + "\n") 248 | # # raise ValueError 249 | # else: 250 | train_loader, test_loader = config_dataloader(opt) 251 | train_main_worker(opt, model, train_loader, test_loader, visualizer) 252 | elif opt.mode == 'generate': 253 | if opt.model == "vae": 254 | train_loader, test_loader = config_dataloader(opt) 255 | generate_vae(opt, model, test_loader) 256 | else: 257 | generate(opt, model) 258 | else: 259 | raise ValueError 260 | 261 | 262 | -------------------------------------------------------------------------------- /models/networks/diffusion_networks/graph_unet_lr.py: -------------------------------------------------------------------------------- 1 | ### adapted from: https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py 2 | 3 | from abc import abstractmethod 4 | from functools import partial 5 | import math 6 | from typing import Iterable 7 | 8 | import numpy as np 9 | 10 | import torch 11 | import torch as th 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from ocnn.nn import octree2voxel 15 | from einops import rearrange 16 | 17 | # from ldm.modules.diffusionmodules.util import ( 18 | # from external.ldm.modules.diffusionmodules.util import ( 19 | from models.networks.diffusion_networks.ldm_diffusion_util import ( 20 | conv_nd, 21 | default, 22 | ) 23 | from models.networks.modules import ( 24 | ConvDownsample, 25 | ConvUpsample, 26 | ResnetBlock, 27 | AttentionBlock, 28 | LearnedSinusoidalPosEmb, 29 | activation_function, 30 | 
our_Identity, 31 | convnormalization, 32 | ) 33 | 34 | class UNet3DModel(nn.Module): 35 | """ 36 | The full UNet model with attention and timestep embedding. 37 | :param in_channels: channels in the input Tensor. 38 | :param model_channels: base channel count for the model. 39 | :param out_channels: channels in the output Tensor. 40 | :param num_res_blocks: number of residual blocks per downsample. 41 | :param attention_resolutions: a collection of downsample rates at which 42 | attention will take place. May be a set, list, or tuple. 43 | For example, if this contains 4, then at 4x downsampling, attention 44 | will be used. 45 | :param dropout: the dropout probability. 46 | :param channel_mult: channel multiplier for each level of the UNet. 47 | :param conv_resample: if True, use learned convolutions for upsampling and 48 | downsampling. 49 | :param dims: determines if the signal is 1D, 2D, or 3D. 50 | :param num_classes: if specified (as an int), then this model will be 51 | class-conditional with `num_classes` classes. 52 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 53 | :param num_heads: the number of attention heads in each attention layer. 54 | :param num_heads_channels: if specified, ignore num_heads and instead use 55 | a fixed channel width per attention head. 56 | :param num_heads_upsample: works with num_heads to set a different number 57 | of heads for upsampling. Deprecated. 58 | :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. 59 | :param resblock_updown: use residual blocks for up/downsampling. 60 | :param use_new_attention_order: use a different attention pattern for potentially 61 | increased efficiency. 62 | """ 63 | 64 | # def __init__(self, config_dict): 65 | def __init__( 66 | self, 67 | full_depth, 68 | in_split_channels, 69 | model_channels, 70 | out_split_channels, 71 | attention_resolutions, 72 | dropout=0, 73 | channel_mult=(1, 2, 4, 8), 74 | dims=2, 75 | num_classes=None, 76 | use_checkpoint=False, 77 | num_heads=-1, 78 | use_text_condition=False, 79 | context_dim=None, # custom transformer support 80 | n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model 81 | **kwargs, 82 | ): 83 | super().__init__() 84 | # import pdb; pdb.set_trace() 85 | self.full_depth = full_depth 86 | self.in_channels = in_split_channels 87 | self.model_channels = model_channels 88 | self.out_channels = out_split_channels 89 | self.attention_resolutions = attention_resolutions 90 | self.dropout = dropout 91 | self.channel_mult = channel_mult 92 | self.num_classes = num_classes 93 | self.use_checkpoint = use_checkpoint 94 | self.dtype = th.float32 95 | self.num_heads = num_heads 96 | self.predict_codebook_ids = n_embed is not None 97 | 98 | channels = [self.model_channels, * 99 | map(lambda m: self.model_channels * m, self.channel_mult)] 100 | in_out = list(zip(channels[:-1], channels[1:])) 101 | 102 | 103 | time_embed_dim = self.model_channels * 4 104 | 105 | self.time_pos_emb = LearnedSinusoidalPosEmb(self.model_channels) 106 | 107 | self.time_emb = nn.Sequential( 108 | nn.Linear(self.model_channels + 1, time_embed_dim), 109 | activation_function(), 110 | nn.Linear(time_embed_dim, time_embed_dim) 111 | ) 112 | 113 | if self.num_classes is not None: 114 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) 115 | 116 | self.input_emb = conv_nd(dims, 2 * self.in_channels, self.model_channels, 3, padding=1) 117 | 118 | self.downs = nn.ModuleList([]) 119 | self.ups = nn.ModuleList([]) 120 | 
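# Example channel schedule (illustrative values, not taken from any config in this repo):
# with model_channels = 64 and channel_mult = (1, 2, 4, 8), the lists computed above are
# channels == [64, 64, 128, 256, 512] and
# in_out == [(64, 64), (64, 128), (128, 256), (256, 512)],
# i.e. one encoder/decoder level per multiplier. The counter `ds` below tracks the
# cumulative downsampling factor and, via `ds in attention_resolutions`, decides at
# which levels attention blocks are inserted.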
num_resolutions = len(in_out) 121 | ds = 1 122 | 123 | for ind, (dim_in, dim_out) in enumerate(in_out): 124 | is_last = ind >= (num_resolutions - 1) 125 | self.downs.append(nn.ModuleList([ 126 | ResnetBlock(dims, dim_in, dim_out, 127 | emb_dim=time_embed_dim, dropout=dropout, use_text_condition=use_text_condition), 128 | nn.Sequential( 129 | convnormalization(dim_out), 130 | activation_function(), 131 | AttentionBlock( 132 | dim_out, num_heads=num_heads)) if ds in attention_resolutions else our_Identity(), 133 | ConvDownsample( 134 | dim_out, dims=dims) if not is_last else our_Identity() 135 | ])) 136 | if not is_last: 137 | ds *= 2 138 | 139 | mid_dim = channels[-1] 140 | self.mid_block1 = ResnetBlock( 141 | dims, mid_dim, mid_dim, emb_dim=time_embed_dim, dropout=dropout, use_text_condition=use_text_condition) 142 | 143 | self.mid_self_attn = nn.Sequential( 144 | convnormalization(mid_dim), 145 | activation_function(), 146 | AttentionBlock(mid_dim, num_heads=num_heads) 147 | ) if ds in attention_resolutions else our_Identity() 148 | 149 | self.mid_block2 = ResnetBlock( 150 | dims, mid_dim, mid_dim, emb_dim=time_embed_dim, dropout=dropout, use_text_condition=use_text_condition) 151 | 152 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): 153 | is_last = ind >= (num_resolutions - 1) 154 | self.ups.append(nn.ModuleList([ 155 | ResnetBlock(dims, dim_out * 2, dim_in, 156 | emb_dim=time_embed_dim, dropout=dropout, use_text_condition=use_text_condition), 157 | nn.Sequential( 158 | convnormalization(dim_in), 159 | activation_function(), 160 | AttentionBlock( 161 | dim_in, num_heads=num_heads)) if ds in attention_resolutions else our_Identity(), 162 | ConvUpsample( 163 | dim_in, dims=dims) if not is_last else our_Identity() 164 | ])) 165 | if not is_last: 166 | ds //= 2 167 | 168 | self.end = nn.Sequential( 169 | convnormalization(self.model_channels), 170 | activation_function() 171 | ) 172 | 173 | self.out = conv_nd(dims, self.model_channels, self.out_channels, 3, padding=1) 174 | 175 | def forward_as_middle(self, h, doctree, timesteps, label, context): 176 | h_lr = octree2voxel(data=h, octree=doctree.octree, depth=self.full_depth) 177 | h_lr = h_lr.permute(0, 4, 1, 2, 3).contiguous() 178 | h_lr = self.forward(x=h_lr, timesteps=timesteps, label=label, context=context, as_middle=True) 179 | x, y, z, b = doctree.octree.xyzb(self.full_depth) 180 | h_lr = h_lr.permute(0, 2, 3, 4, 1).contiguous() 181 | h_lr = h_lr[b, x, y, z, :] 182 | return h_lr 183 | 184 | def forward(self, x=None, timesteps=None, x_self_cond=None, label = None, context=None, as_middle=False, **kwargs): 185 | """ 186 | Apply the model to an input batch. 187 | :param x: an [N x C x ...] Tensor of inputs. 188 | :param timesteps: a 1-D batch of timesteps. 189 | :param context: conditioning plugged in via crossattn 190 | :param y: an [N] Tensor of labels, if class-conditional. 191 | :return: an [N x C x ...] Tensor of outputs. 
192 | """ 193 | assert (label is not None) == ( 194 | self.num_classes is not None 195 | ), "must specify label if and only if the model is class-conditional" 196 | 197 | if not as_middle: 198 | x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x)) 199 | x = torch.cat((x, x_self_cond), dim=1) 200 | x = self.input_emb(x) 201 | 202 | h = [] 203 | 204 | emb = self.time_emb(self.time_pos_emb(timesteps)) 205 | 206 | if self.num_classes is not None: 207 | assert label.shape == (x.shape[0],) 208 | emb = emb + self.label_emb(label) 209 | 210 | for resnet, self_attn, downsample in self.downs: 211 | x = resnet(x, emb) 212 | x = self_attn(x) 213 | h.append(x) 214 | x = downsample(x) 215 | 216 | x = self.mid_block1(x, emb) 217 | x = self.mid_self_attn(x) 218 | x = self.mid_block2(x, emb) 219 | 220 | for resnet, self_attn, upsample in self.ups: 221 | x = torch.cat((x, h.pop()), dim=1) 222 | x = resnet(x, emb) 223 | x = self_attn(x) 224 | x = upsample(x) 225 | 226 | x = self.end(x) 227 | if as_middle: 228 | return x 229 | else: 230 | return self.out(x) 231 | -------------------------------------------------------------------------------- /models/networks/diffusion_networks/graph_unet_hr.py: -------------------------------------------------------------------------------- 1 | ### adapted from: https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py 2 | 3 | from abc import abstractmethod 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from ocnn.nn import octree2voxel 10 | from ocnn.utils import scatter_add 11 | 12 | from models.networks.dualoctree_networks import dual_octree 13 | from models.networks.diffusion_networks.ldm_diffusion_util import create_full_octree 14 | 15 | # from ldm.modules.diffusionmodules.util import ( 16 | # from external.ldm.modules.diffusionmodules.util import ( 17 | from models.networks.diffusion_networks.ldm_diffusion_util import ( 18 | checkpoint, 19 | conv_nd, 20 | linear, 21 | avg_pool_nd, 22 | zero_module, 23 | voxelnormalization, 24 | timestep_embedding, 25 | ) 26 | 27 | from models.networks.modules import ( 28 | GraphConv, 29 | Conv1x1, 30 | graphnormalization, 31 | TimestepBlock, 32 | GraphDownsample, 33 | GraphUpsample, 34 | GraphResBlockEmbed, 35 | 36 | ) 37 | 38 | class UNet3DModel(nn.Module): 39 | """ 40 | The full UNet model with attention and timestep embedding. 41 | :param in_channels: channels in the input Tensor. 42 | :param model_channels: base channel count for the model. 43 | :param out_channels: channels in the output Tensor. 44 | :param num_res_blocks: number of residual blocks per downsample. 45 | :param attention_resolutions: a collection of downsample rates at which 46 | attention will take place. May be a set, list, or tuple. 47 | For example, if this contains 4, then at 4x downsampling, attention 48 | will be used. 49 | :param dropout: the dropout probability. 50 | :param channel_mult: channel multiplier for each level of the UNet. 51 | :param conv_resample: if True, use learned convolutions for upsampling and 52 | downsampling. 53 | :param dims: determines if the signal is 1D, 2D, or 3D. 54 | :param num_classes: if specified (as an int), then this model will be 55 | class-conditional with `num_classes` classes. 56 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 57 | :param num_heads: the number of attention heads in each attention layer. 
58 | :param num_heads_channels: if specified, ignore num_heads and instead use 59 | a fixed channel width per attention head. 60 | :param num_heads_upsample: works with num_heads to set a different number 61 | of heads for upsampling. Deprecated. 62 | :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. 63 | :param resblock_updown: use residual blocks for up/downsampling. 64 | :param use_new_attention_order: use a different attention pattern for potentially 65 | increased efficiency. 66 | """ 67 | 68 | # def __init__(self, config_dict): 69 | def __init__( 70 | self, 71 | image_size, 72 | input_depth, 73 | full_depth, 74 | in_channels, 75 | model_channels, 76 | lr_model_channels, 77 | out_channels, 78 | num_res_blocks, 79 | dropout=0, 80 | channel_mult=[1, 2, 4], 81 | dims=3, 82 | num_classes=None, 83 | use_checkpoint=False, 84 | num_heads=-1, 85 | use_scale_shift_norm=False, 86 | **kwargs, 87 | ): 88 | super().__init__() 89 | 90 | self.image_size = image_size 91 | self.input_depth = input_depth 92 | self.full_depth = full_depth 93 | self.in_channels = in_channels 94 | self.model_channels = model_channels 95 | self.out_channels = out_channels 96 | self.num_res_blocks = num_res_blocks 97 | self.dropout = dropout 98 | self.channel_mult = channel_mult 99 | self.num_classes = num_classes 100 | self.use_checkpoint = use_checkpoint 101 | self.dtype = torch.float32 102 | self.num_heads = num_heads 103 | n_edge_type, avg_degree = 7, 7 104 | 105 | time_embed_dim = model_channels * 4 106 | 107 | self.time_embed = nn.Sequential( 108 | linear(model_channels, time_embed_dim), 109 | nn.SiLU(), 110 | linear(time_embed_dim, time_embed_dim) 111 | ) 112 | 113 | if self.num_classes is not None: 114 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) 115 | 116 | d = self.input_depth 117 | 118 | self.input_blocks = nn.ModuleList([ 119 | GraphConv(self.in_channels, model_channels, n_edge_type, avg_degree, self.input_depth - 1) 120 | ]) 121 | 122 | self._feature_size = model_channels 123 | input_block_chans = [model_channels] 124 | ch = model_channels 125 | for level, mult in enumerate(channel_mult): 126 | for _ in range(self.num_res_blocks[level]): 127 | resblk = GraphResBlockEmbed( 128 | ch, 129 | time_embed_dim, 130 | dropout, 131 | out_channels=mult * model_channels, 132 | n_edge_type = n_edge_type, 133 | avg_degree = avg_degree, 134 | n_node_type = d - 1, 135 | dims=dims, 136 | use_checkpoint=use_checkpoint, 137 | use_scale_shift_norm=use_scale_shift_norm, 138 | ) 139 | ch = mult * model_channels 140 | self.input_blocks.append(resblk) 141 | self._feature_size += ch 142 | input_block_chans.append(ch) 143 | 144 | if level != len(channel_mult) - 1: 145 | out_ch = ch 146 | d -= 1 147 | self.input_blocks.append( 148 | GraphDownsample(ch, out_ch,n_edge_type, avg_degree, d-1) 149 | ) 150 | ch = out_ch 151 | input_block_chans.append(ch) 152 | self._feature_size += ch 153 | 154 | self.middle_block1 = GraphResBlockEmbed( 155 | ch, 156 | time_embed_dim, 157 | dropout, 158 | out_channels = lr_model_channels, 159 | n_edge_type = n_edge_type, 160 | avg_degree = avg_degree, 161 | n_node_type = d - 1, 162 | dims=dims, 163 | use_checkpoint=use_checkpoint, 164 | use_scale_shift_norm=use_scale_shift_norm, 165 | ) 166 | 167 | self.middle_block2 = GraphResBlockEmbed( 168 | lr_model_channels * 2, 169 | time_embed_dim, 170 | dropout, 171 | out_channels = ch, 172 | n_edge_type = n_edge_type, 173 | avg_degree = avg_degree, 174 | n_node_type = d - 1, 175 | dims=dims, 176 | use_checkpoint=use_checkpoint, 177 | 
use_scale_shift_norm=use_scale_shift_norm, 178 | ) 179 | 180 | self._feature_size += ch 181 | 182 | self.output_blocks = nn.ModuleList([]) 183 | for level, mult in list(enumerate(channel_mult))[::-1]: 184 | for i in range(self.num_res_blocks[level] + 1): 185 | ich = input_block_chans.pop() 186 | resblk = GraphResBlockEmbed( 187 | ch + ich, 188 | time_embed_dim, 189 | dropout, 190 | out_channels=model_channels * mult, 191 | n_edge_type = n_edge_type, 192 | avg_degree = avg_degree, 193 | n_node_type = d - 1, 194 | dims=dims, 195 | use_checkpoint=use_checkpoint, 196 | use_scale_shift_norm=use_scale_shift_norm, 197 | ) 198 | self.output_blocks.append(resblk) 199 | ch = model_channels * mult 200 | if level and i == self.num_res_blocks[level]: 201 | out_ch = ch 202 | d += 1 203 | upsample = GraphUpsample(ch, out_ch, n_edge_type, avg_degree, d-1) 204 | self.output_blocks.append(upsample) 205 | self._feature_size += ch 206 | 207 | self.end_norm = graphnormalization(ch) 208 | self.end = nn.SiLU() 209 | self.out = zero_module(GraphConv(ch, self.out_channels, n_edge_type, avg_degree, self.input_depth - 1)) 210 | 211 | def forward_as_middle(self, h, doctree, timesteps, label, context): 212 | return self.forward(x=h, doctree=doctree, timesteps=timesteps, label=label, context=context, as_middle=True) 213 | 214 | def forward(self, x = None, doctree = None, unet_lr = None, timesteps = None, label = None, context = None, as_middle=False, **kwargs): 215 | """ 216 | Apply the model to an input batch. 217 | :param x: an [N x C x ...] Tensor of inputs. 218 | :param timesteps: a 1-D batch of timesteps. 219 | :param context: conditioning plugged in via crossattn 220 | :param y: an [N] Tensor of labels, if class-conditional. 221 | :return: an [N x C x ...] Tensor of outputs. 222 | """ 223 | assert (label is not None) == ( 224 | self.num_classes is not None 225 | ), "must specify y if and only if the model is class-conditional" 226 | 227 | 228 | hs = [] 229 | t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) 230 | emb = self.time_embed(t_emb) 231 | 232 | if self.num_classes is not None: 233 | assert label.shape == (doctree.batch_size,) 234 | emb = emb + self.label_emb(label) 235 | 236 | d = self.input_depth 237 | 238 | if not as_middle: 239 | h = self.input_blocks[0](x, doctree, d) 240 | else: 241 | h = x 242 | hs.append(h) 243 | 244 | for module in self.input_blocks[1:]: 245 | if isinstance(module, GraphConv): 246 | h = module(h, doctree, d) 247 | elif isinstance(module, GraphResBlockEmbed): 248 | h = module(h, emb, doctree, d) 249 | elif isinstance(module, GraphDownsample): 250 | h = module(h, doctree, d) 251 | d -= 1 252 | 253 | hs.append(h) 254 | 255 | 256 | 257 | if unet_lr is not None: 258 | h = self.middle_block1(h, emb, doctree, d) 259 | h_lr = unet_lr.forward_as_middle(h, doctree, timesteps, label, context) 260 | h = torch.cat([h, h_lr], dim=1) 261 | 262 | h = self.middle_block2(h, emb, doctree, d) 263 | 264 | for module in self.output_blocks: 265 | if isinstance(module, GraphResBlockEmbed): 266 | h = torch.cat([h, hs.pop()], dim=1) 267 | h = module(h, emb, doctree, d) 268 | elif isinstance(module, GraphUpsample): 269 | h = module(h, doctree, d) 270 | d += 1 271 | 272 | h = self.end(self.end_norm(h, doctree, d)) 273 | 274 | if as_middle: 275 | return h 276 | 277 | out = self.out(h, doctree, d) 278 | 279 | assert out.shape[0] == x.shape[0] 280 | 281 | return out 282 | --------------------------------------------------------------------------------
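A minimal wiring sketch for how the two diffusion U-Nets above are meant to compose: graph_unet_hr delegates its bottleneck to graph_unet_lr through forward_as_middle. The hyper-parameter values below are illustrative only, and `x` / `doctree` are placeholders for the per-node features and the DualOctree produced by the repository's dual-octree pipeline, so this is a shape/wiring sketch rather than a runnable end-to-end example.

import torch
from models.networks.diffusion_networks.graph_unet_lr import UNet3DModel as UNetLR
from models.networks.diffusion_networks.graph_unet_hr import UNet3DModel as UNetHR

# Low-resolution voxel U-Net; its model_channels must match the HR net's
# lr_model_channels so the tensors exchanged in the middle block line up.
unet_lr = UNetLR(full_depth=4, in_split_channels=8, model_channels=64,
                 out_split_channels=8, attention_resolutions=[4], dims=3)

# High-resolution graph U-Net operating on dual-octree nodes from depth 6 down to depth 4.
unet_hr = UNetHR(image_size=64, input_depth=6, full_depth=4, in_channels=4,
                 model_channels=32, lr_model_channels=64, out_channels=4,
                 num_res_blocks=[2, 2, 2], channel_mult=[1, 2, 4])

# x: per-node features at input_depth; doctree: a DualOctree built from the input octree
# (see models/networks/dualoctree_networks). Both are placeholders in this sketch.
timesteps = torch.randint(0, 1000, (doctree.batch_size,))
eps = unet_hr(x=x, doctree=doctree, unet_lr=unet_lr, timesteps=timesteps)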
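Usage sketch for the point-cloud metrics defined earlier in this section (EMD_CD, compute_cov_mmd, compute_1_nna). The module path is an assumption (metrics/evaluation_metrics.py); on a machine without the compiled structural-losses extension the CPU fallbacks defined in that file are used, otherwise the tensors should be moved to CUDA first.

import torch
# Assumed module path; the metric functions are the ones defined in the metrics file above.
from metrics.evaluation_metrics import EMD_CD, compute_cov_mmd, compute_1_nna

# Toy inputs: 8 generated and 8 reference clouds, 128 points each (equal sizes are
# required by the EMD approximation and by EMD_CD's pairing of samples with references).
sample_pcs = torch.rand(8, 128, 3)
ref_pcs = torch.rand(8, 128, 3)

print(EMD_CD(sample_pcs, ref_pcs, batch_size=4))           # paired CD / EMD averages
print(compute_cov_mmd(sample_pcs, ref_pcs, batch_size=4))   # MMD and coverage (CD and EMD)
print(compute_1_nna(sample_pcs, ref_pcs, batch_size=4))     # 1-NN accuracy (CD and EMD)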