├── app.py ├── assets └── examples │ ├── cup.png │ ├── 000102.jpg │ ├── chair1.png │ ├── monkey.png │ ├── sandie.png │ ├── teapot.png │ ├── bear_toy.png │ ├── toaster.png │ ├── baby_yoda.png │ ├── cartoon_boy.png │ ├── image (12).png │ ├── man_origin.png │ ├── purple_bear.png │ ├── wooden_bear.png │ ├── wooden_owl.png │ ├── fire_dragon.webp │ ├── house_teapot.png │ ├── mouse_statue.png │ ├── plush_cow_toy.png │ ├── shy_ghost_rgb.jpg │ ├── 20240120-181122.jpeg │ ├── chicken_nesting.png │ ├── green_parrot_rgb.jpg │ ├── lounger_armchair.png │ ├── medieval_flask.png │ ├── monster_hand_rgb.png │ ├── a_pikachu_with_smily_face.webp │ ├── a_cute_little_frog_comicbook_style.webp │ ├── retro_pc_photorealistic_high_detailed_rgb.png │ └── Mastral_A_2d_early_1600s_painting_of_simple_shield__rembrand_mi_af3afbe0-9a45-46be-a313-fc94a12689a6.png ├── meshgen ├── utils │ ├── hf_weights.py │ ├── optim.py │ ├── birefnet.py │ ├── captioner.py │ ├── images.py │ ├── misc.py │ ├── ema.py │ ├── io.py │ ├── remesh.py │ ├── math_utils.py │ ├── ray_utils.py │ ├── render_ops.py │ ├── render.py │ ├── briarmbg.py │ └── ops.py ├── modules │ ├── super_resolution.py │ ├── encoders │ │ └── dino.py │ ├── mlp.py │ ├── ema.py │ ├── timm.py │ ├── resnet.py │ ├── attention.py │ └── mesh │ │ └── mesh.py ├── model │ ├── base.py │ ├── diffusion │ │ └── rfunet.py │ └── triplane_autoencoder.py └── util.py ├── requirements_cuda12.txt ├── requirements.txt ├── configs ├── texgen.yaml └── shapegen.yaml ├── readme.md ├── .gitignore ├── texgen.py ├── zero123pp └── utils.py ├── jointgen.py └── shapegen.py /app.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/examples/cup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/cup.png -------------------------------------------------------------------------------- /assets/examples/000102.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/000102.jpg -------------------------------------------------------------------------------- /assets/examples/chair1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/chair1.png -------------------------------------------------------------------------------- /assets/examples/monkey.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/monkey.png -------------------------------------------------------------------------------- /assets/examples/sandie.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/sandie.png -------------------------------------------------------------------------------- /assets/examples/teapot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/teapot.png -------------------------------------------------------------------------------- /assets/examples/bear_toy.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/bear_toy.png -------------------------------------------------------------------------------- /assets/examples/toaster.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/toaster.png -------------------------------------------------------------------------------- /assets/examples/baby_yoda.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/baby_yoda.png -------------------------------------------------------------------------------- /assets/examples/cartoon_boy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/cartoon_boy.png -------------------------------------------------------------------------------- /assets/examples/image (12).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/image (12).png -------------------------------------------------------------------------------- /assets/examples/man_origin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/man_origin.png -------------------------------------------------------------------------------- /assets/examples/purple_bear.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/purple_bear.png -------------------------------------------------------------------------------- /assets/examples/wooden_bear.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/wooden_bear.png -------------------------------------------------------------------------------- /assets/examples/wooden_owl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/wooden_owl.png -------------------------------------------------------------------------------- /assets/examples/fire_dragon.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/fire_dragon.webp -------------------------------------------------------------------------------- /assets/examples/house_teapot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/house_teapot.png -------------------------------------------------------------------------------- /assets/examples/mouse_statue.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/mouse_statue.png -------------------------------------------------------------------------------- /assets/examples/plush_cow_toy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/plush_cow_toy.png -------------------------------------------------------------------------------- 
/assets/examples/shy_ghost_rgb.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/shy_ghost_rgb.jpg -------------------------------------------------------------------------------- /assets/examples/20240120-181122.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/20240120-181122.jpeg -------------------------------------------------------------------------------- /assets/examples/chicken_nesting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/chicken_nesting.png -------------------------------------------------------------------------------- /assets/examples/green_parrot_rgb.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/green_parrot_rgb.jpg -------------------------------------------------------------------------------- /assets/examples/lounger_armchair.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/lounger_armchair.png -------------------------------------------------------------------------------- /assets/examples/medieval_flask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/medieval_flask.png -------------------------------------------------------------------------------- /assets/examples/monster_hand_rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/monster_hand_rgb.png -------------------------------------------------------------------------------- /assets/examples/a_pikachu_with_smily_face.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/a_pikachu_with_smily_face.webp -------------------------------------------------------------------------------- /assets/examples/a_cute_little_frog_comicbook_style.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/a_cute_little_frog_comicbook_style.webp -------------------------------------------------------------------------------- /assets/examples/retro_pc_photorealistic_high_detailed_rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/retro_pc_photorealistic_high_detailed_rgb.png -------------------------------------------------------------------------------- /assets/examples/Mastral_A_2d_early_1600s_painting_of_simple_shield__rembrand_mi_af3afbe0-9a45-46be-a313-fc94a12689a6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heheyas/MeshGen/HEAD/assets/examples/Mastral_A_2d_early_1600s_painting_of_simple_shield__rembrand_mi_af3afbe0-9a45-46be-a313-fc94a12689a6.png -------------------------------------------------------------------------------- /meshgen/utils/hf_weights.py: 
-------------------------------------------------------------------------------- 1 | from huggingface_hub import hf_hub_download 2 | 3 | repo = "heheyas/MeshGen" 4 | 5 | 6 | texture_inpainter_path = hf_hub_download(repo, "texture_inpainter.pth") 7 | shape_generator_path = hf_hub_download(repo, "shape_generator.pth") 8 | pbr_decomposer_path = hf_hub_download(repo, "pbr_decomposer.pth") 9 | mv_generator_path = hf_hub_download(repo, "mv_generator.pth") 10 | -------------------------------------------------------------------------------- /meshgen/modules/super_resolution.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | import numpy as np 4 | from RealESRGAN import RealESRGAN 5 | 6 | 7 | class RealESRGANUpscaler(torch.nn.Module): 8 | def __init__(self, scale: int = 4): 9 | super().__init__() 10 | model_path = f"weights/RealESRGAN_x{scale}.pth" 11 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 12 | self.model = RealESRGAN(self.device, scale=scale) 13 | self.model.load_weights(model_path, download=True) 14 | 15 | def forward(self, image): 16 | return np.asarray(self.model.predict(image)) 17 | 18 | 19 | class ControlNetUpscaler(torch.nn.Module): 20 | def __init__(self): 21 | pass 22 | -------------------------------------------------------------------------------- /meshgen/utils/optim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def param_groups_weight_decay( 6 | model: nn.Module, weight_decay=1e-5, no_weight_decay_list=() 7 | ): 8 | no_weight_decay_list = set(no_weight_decay_list) 9 | decay = [] 10 | no_decay = [] 11 | for name, param in model.named_parameters(): 12 | if not param.requires_grad: 13 | continue 14 | 15 | if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list: 16 | no_decay.append(param) 17 | else: 18 | decay.append(param) 19 | 20 | return [ 21 | {"params": no_decay, "weight_decay": 0.0}, 22 | {"params": decay, "weight_decay": weight_decay}, 23 | ] 24 | -------------------------------------------------------------------------------- /requirements_cuda12.txt: -------------------------------------------------------------------------------- 1 | diffusers==0.30.0 2 | einops==0.8.0 3 | gradio==4.44.1 4 | httpx==0.27.2 5 | huggingface_hub==0.24.7 6 | jaxtyping==0.2.34 7 | kiui==0.2.10 8 | loguru==0.7.2 9 | lpips==0.1.4 10 | mediapy==1.2.2 11 | numpy<2 12 | omegaconf==2.1.1 13 | open3d==0.18.0 14 | openai==1.51.2 15 | opencv_python==4.10.0.82 16 | opencv_python_headless==4.10.0.82 17 | peft==0.12.0 18 | Pillow==10.4.0 19 | pyacvd==0.2.11 20 | pyfqmr==0.2.1 21 | pygltflib==1.16.2 22 | pyrender==0.1.45 23 | pytorch_lightning==1.5.9 24 | pyvista==0.44.1 25 | scikit-image==0.20.0 26 | torchmetrics==0.11.4 27 | tqdm==4.66.5 28 | transformers==4.41.2 29 | trimesh==3.23.5 30 | tyro==0.8.11 31 | vedo==2024.5.2 32 | xatlas==0.0.9 33 | kaolin==0.16.0 34 | rembg 35 | timm 36 | kornia 37 | -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.1.2_cu121.html 38 | git+https://github.com/sberbank-ai/Real-ESRGAN.git 39 | git+https://github.com/NVlabs/nvdiffrast 40 | git+https://github.com/facebookresearch/pytorch3d.git 41 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers==0.30.0 2 | einops==0.8.0 3 | gradio==4.44.1 4 | 
httpx==0.27.2 5 | huggingface_hub==0.24.7 6 | jaxtyping==0.2.34 7 | kiui==0.2.10 8 | loguru==0.7.2 9 | lpips==0.1.4 10 | mediapy==1.2.2 11 | numpy<2 12 | omegaconf==2.1.1 13 | open3d==0.18.0 14 | openai==1.51.2 15 | opencv_python==4.10.0.82 16 | opencv_python_headless==4.10.0.82 17 | peft==0.12.0 18 | Pillow==10.4.0 19 | pyacvd==0.2.11 20 | pyfqmr==0.2.1 21 | pygltflib==1.16.2 22 | pyrender==0.1.45 23 | pytorch_lightning==1.5.9 24 | pyvista==0.44.1 25 | scikit-image==0.20.0 26 | torchmetrics==0.11.4 27 | tqdm==4.66.5 28 | transformers==4.41.2 29 | trimesh==3.23.5 30 | tyro==0.8.11 31 | vedo==2024.5.2 32 | xatlas==0.0.9 33 | kaolin==0.16.0 34 | rembg 35 | timm 36 | kornia 37 | torch-cluster 38 | -f https://data.pyg.org/whl/torch-2.1.2+cu118.html 39 | -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.1.2_cu118.html 40 | git+https://github.com/sberbank-ai/Real-ESRGAN.git 41 | git+https://github.com/NVlabs/nvdiffrast 42 | git+https://github.com/facebookresearch/pytorch3d.git 43 | -------------------------------------------------------------------------------- /meshgen/modules/encoders/dino.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoImageProcessor, AutoModel 3 | 4 | 5 | class FrozenDINOv2Encoder(torch.nn.Module): 6 | def __init__(self, model_name="facebook/dinov2-giant", do_rescale=True): 7 | super().__init__() 8 | self.model_name = model_name 9 | self.processor = AutoImageProcessor.from_pretrained( 10 | model_name, do_rescale=do_rescale 11 | ) 12 | self.model = AutoModel.from_pretrained(model_name) 13 | 14 | self.output_dim = self.model.config.hidden_size 15 | 16 | self.freeze() 17 | self.eval() 18 | 19 | def freeze(self): 20 | print(f"======== Freezing DinoWrapper ========") 21 | self.model.eval() 22 | for name, param in self.model.named_parameters(): 23 | param.requires_grad = False 24 | 25 | @property 26 | def device(self): 27 | return next(self.parameters()).device 28 | 29 | def forward(self, x): 30 | input = self.processor(images=x, return_tensors="pt") 31 | input["pixel_values"] = input["pixel_values"].to(self.device) 32 | outputs = self.model(**input) 33 | 34 | return outputs.last_hidden_state 35 | -------------------------------------------------------------------------------- /meshgen/utils/birefnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | from transformers import AutoModelForImageSegmentation 4 | 5 | _model = None 6 | 7 | 8 | def run_model(image): 9 | 10 | model, transform = get_model() 11 | 12 | image_size = image.size 13 | input_images = transform(image).unsqueeze(0).to("cuda") 14 | # Prediction 15 | with torch.no_grad(): 16 | preds = model(input_images)[-1].sigmoid().cpu() 17 | pred = preds[0].squeeze() 18 | pred_pil = transforms.ToPILImage()(pred) 19 | mask = pred_pil.resize(image_size) 20 | image.putalpha(mask) 21 | 22 | return image 23 | 24 | 25 | def get_model(): 26 | global _model 27 | 28 | if _model is not None: 29 | return _model 30 | 31 | birefnet = AutoModelForImageSegmentation.from_pretrained( 32 | "ZhengPeng7/BiRefNet", trust_remote_code=True 33 | ) 34 | birefnet.to("cuda") 35 | transform_image = transforms.Compose( 36 | [ 37 | transforms.Resize((1024, 1024)), 38 | transforms.ToTensor(), 39 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 40 | ] 41 | ) 42 | 43 | _model = birefnet, transform_image 44 | 45 | return _model 46 | 
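# Example usage (a minimal sketch; "input.png" / "cutout.png" are placeholder paths, and a CUDA
# GPU is required because run_model() moves both the BiRefNet weights and the input to "cuda"):
#
#   from PIL import Image
#   rgba = run_model(Image.open("input.png").convert("RGB"))  # returns the image with a predicted alpha matte
#   rgba.save("cutout.png")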
-------------------------------------------------------------------------------- /meshgen/modules/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import itertools 4 | 5 | 6 | def nonlinearity(x): 7 | # swish 8 | return x * torch.sigmoid(x) 9 | 10 | 11 | class MlpDecoder(nn.Module): 12 | """ 13 | Triplane decoder that gives RGB and sigma values from sampled features. 14 | Using ReLU here instead of Softplus in the original implementation. 15 | 16 | Reference: 17 | EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112 18 | """ 19 | 20 | def __init__( 21 | self, 22 | n_features: int, 23 | hidden_dim: int = 64, 24 | num_layers: int = 4, 25 | ): 26 | super().__init__() 27 | self.net = nn.Sequential( 28 | nn.Linear(3 * n_features, hidden_dim), 29 | nn.SiLU(), 30 | *itertools.chain( 31 | *[ 32 | [ 33 | nn.Linear(hidden_dim, hidden_dim), 34 | nn.SiLU(), 35 | ] 36 | for _ in range(num_layers - 2) 37 | ] 38 | ), 39 | nn.Linear(hidden_dim, 1), 40 | ) 41 | # init all bias to zero 42 | for m in self.modules(): 43 | if isinstance(m, nn.Linear): 44 | nn.init.zeros_(m.bias) 45 | 46 | def forward(self, x): 47 | 48 | return self.net(x).squeeze(-1) 49 | -------------------------------------------------------------------------------- /configs/texgen.yaml: -------------------------------------------------------------------------------- 1 | target: meshgen.model.texturing.texture_pbr_painter.SparseViewPBRPainter 2 | params: 3 | control_scale: [1.0, 0.0] 4 | radius: 4 5 | albedo_prompt: "" 6 | start_timestep_idx: 0 7 | sync_latent_end: 1.0 8 | sync_exp_start: 0.0 9 | sync_exp_end: 5.0 10 | multiview_generator: 11 | target: zero123pp.controlnet_joint_model_mesh_denoising.MVJointControlNet 12 | params: 13 | stable_diffusion_config: 14 | pretrained_model_name_or_path: sudo-ai/zero123plus-v1.2 15 | custom_pipeline: ./zero123pp 16 | control_type: depth 17 | # ckpt_path: /pfs/mt-1oY5F7/chenzilong/Texturing/logs/2024-09-23T21-38-03_depth_mv_only/checkpoints/last.ckpt 18 | conditioning_scale: 0.8 19 | scheduler_type: ddim 20 | mesh_kwargs: 21 | rotate: True 22 | texture_resolution: [1024, 1024] 23 | mesh_scale: 0.9 24 | use_latent: True 25 | sr_model: 26 | target: meshgen.modules.super_resolution.RealESRGANUpscaler 27 | params: 28 | scale: 4 29 | pbr_decomposer: 30 | target: zero123pp.ip2p_model.MVIp2p 31 | params: 32 | stable_diffusion_config: 33 | pretrained_model_name_or_path: sudo-ai/zero123plus-v1.2 34 | custom_pipeline: ./zero123pp/pipeline_ip2p.py 35 | # ckpt_path: /pfs/mt-1oY5F7/chenzilong/Texturing/logs/2024-09-21T22-21-49_ip2p_joint/checkpoints/epoch=000015.ckpt 36 | texture_inpainter: 37 | target: meshgen.model.texturing.texture_inpainter.TextureInpainter 38 | params: 39 | base_model: benjamin-paine/stable-diffusion-v1-5 40 | # ckpt_path: /pfs/mt-1oY5F7/chenzilong/Texturing/logs/2024-09-13T15-52-18_texture_inpainting_cnet_original_bg/checkpoints/last.ckpt 41 | inpaint_elevations: [0, -10, 20, -10, 20, -10, 20] 42 | inpaint_azimuths: [0, 60, 120, 180, 240, 300, 0] 43 | zero123pp_view_idx: [0, 3, 5, 2, 4, 1] 44 | renderer_kwargs: 45 | render_angle_thres: 70 46 | grid_size: 2048 47 | -------------------------------------------------------------------------------- /configs/shapegen.yaml: -------------------------------------------------------------------------------- 1 | target: meshgen.model.diffusion.rfunet.RectifiedFlowUNet2D 2 | params: 3 | weight_decay: 0.0001 4 | vis_every: null 5 | scale_factor: 
0.2684278283548295 6 | shift_factor: 0.13850155472755432 7 | use_ema: true 8 | skip_validation: true 9 | rf_mu: 0.0 10 | rf_sigma: 1.0 11 | timestep_sample: logit_normal 12 | scheduler_config: 13 | target: meshgen.lr_scheduler.LambdaLinearScheduler 14 | params: 15 | warm_up_steps: 16 | - 1000 17 | cycle_lengths: 18 | - 5000000 19 | f_start: 20 | - 1.0e-06 21 | f_max: 22 | - 1.0 23 | f_min: 24 | - 0.1 25 | autoencoder: 26 | target: meshgen.model.triplane_autoencoder.TriplaneKLModel 27 | params: 28 | triplane_res: 32 29 | triplane_ch: 16 30 | box_warp: 1.1 31 | tv_loss_weight: 0.05 32 | weight_decay: 0.0 33 | encoder: 34 | target: meshgen.modules.shape2vecset.Encoder 35 | params: 36 | depth: 10 37 | dim: 768 38 | queries_dim: 768 39 | heads: 12 40 | dim_head: 64 41 | num_inputs: 65536 42 | num_latents: 3072 43 | output_dim: 32 44 | learnable_query: true 45 | deconv_decoder: 46 | target: meshgen.modules.resnet.DeConvDecoder 47 | params: 48 | z_channels: 16 49 | num_resos: 4 50 | num_res_blocks: 1 51 | ch: 64 52 | out_ch: 64 53 | dropout: 0.0 54 | mlp_decoder: 55 | target: meshgen.modules.mlp.MlpDecoder 56 | params: 57 | n_features: 64 58 | hidden_dim: 64 59 | num_layers: 6 60 | unet: 61 | target: meshgen.modules.diffusion_unet.UNetModel 62 | params: 63 | image_size: 32 64 | in_channels: 16 65 | out_channels: 16 66 | model_channels: 320 67 | attention_resolutions: 68 | - 4 69 | - 2 70 | - 1 71 | num_res_blocks: 2 72 | channel_mult: 73 | - 1 74 | - 2 75 | - 4 76 | - 4 77 | num_heads: 8 78 | use_spatial_transformer: true 79 | transformer_depth: 1 80 | context_dim: 1536 81 | use_checkpoint: false 82 | legacy: false 83 | dtype: bf16 84 | 85 | cond_encoder: 86 | target: meshgen.modules.encoders.dino.FrozenDINOv2Encoder 87 | params: 88 | model_name: facebook/dinov2-giant -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | ## MeshGen: Generating PBR Textured Mesh with Render-Enhanced Auto-Encoder and Generative Data Augmentation 2 | 3 | [Zilong Chen](https://heheyas.github.io), [Yikai Wang](), [Wenqiang Sun](), [Feng Wang](), [Yiwen Chen](), [Huaping Liu]() 4 | 5 | Tsinghua University, BNU, HKUST, NTU 6 | 7 | This repository contains the official implementation for MeshGen: Generating PBR Textured Mesh with Render-Enhanced Auto-Encoder and Generative Data Augmentation. 
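Pretrained checkpoints are fetched automatically from the `heheyas/MeshGen` Hugging Face repository by `meshgen/utils/hf_weights.py`. A minimal sketch for pre-downloading them manually (same repo and filenames as in that module; everything else is up to you):

```python
# Pre-fetch the four checkpoints that meshgen/utils/hf_weights.py downloads at import time.
from huggingface_hub import hf_hub_download

repo = "heheyas/MeshGen"
for filename in [
    "texture_inpainter.pth",
    "shape_generator.pth",
    "pbr_decomposer.pth",
    "mv_generator.pth",
]:
    hf_hub_download(repo, filename)  # cached under ~/.cache/huggingface by default
```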
8 | 9 | [Arxiv](http://arxiv.org/abs/2505.04656) | [Project page](https://heheyas.github.io/MeshGen) | [HF Demo]() 10 | 11 | ### Run locally 12 | #### Install 13 | First, use `pip<24.1`, since we are using an old version of lightning: 14 | ```bash 15 | pip install 'pip<24.1' 16 | ``` 17 | If you are using CUDA 11: 18 | Install `torch`: 19 | ```bash 20 | pip install torch==2.1.2+cu118 torchvision==0.16.2+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 21 | ``` 22 | Then install the other dependencies: 23 | ```bash 24 | pip install -r requirements.txt 25 | ``` 26 | 27 | Or, if you are using CUDA 12: 28 | Install `torch`: 29 | ```bash 30 | pip install torch==2.1.2+cu121 torchvision==0.16.2+cu121 --extra-index-url https://download.pytorch.org/whl/cu121 31 | ``` 32 | Then install the other dependencies: 33 | ```bash 34 | pip install -r requirements_cuda12.txt 35 | ``` 36 | 37 | #### Shape Generation: 38 | ```bash 39 | torchrun --nproc_per_node=<num_gpus> shapegen.py --images <path_to_images> --output <output_dir> 40 | ``` 41 | 42 | #### Texture Generation: 43 | ```bash 44 | torchrun --nproc_per_node=<num_gpus> texgen.py --meta <path_to_meta> --output <output_dir> 45 | ``` 46 | 47 | #### Textured Mesh Generation: 48 | ```bash 49 | torchrun --nproc_per_node=<num_gpus> jointgen.py --images <path_to_images> --output <output_dir> 50 | ``` 51 | 52 | #### Gradio demo: 53 | ```bash 54 | python app.py 55 | ``` 56 | 57 | ### Acknowledgement 58 | - [Stable Diffusion]() 59 | - [Paint3D]() 60 | - [Zero123++]() 61 | - [3DShape2Vecset]() 62 | 63 | ### Citation 64 | ```bibtex 65 | @inproceedings{chen2025meshgen, 66 | author = {Chen, Zilong and Wang, Yikai and Sun, Wenqiang and Wang, Feng and Chen, Yiwen and Liu, Huaping}, 67 | title = {MeshGen: Generating PBR Textured Mesh with Render-Enhanced Auto-Encoder and Generative Data Augmentation}, 68 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 69 | month = {June}, 70 | year = {2025} 71 | } 72 | ``` 73 | 74 | -------------------------------------------------------------------------------- /meshgen/utils/captioner.py: -------------------------------------------------------------------------------- 1 | import httpx 2 | import base64 3 | from io import BytesIO 4 | from PIL import Image 5 | from openai import OpenAI 6 | 7 | _client = None 8 | 9 | 10 | ## using qwen instead of chatGPT 11 | def get_openai_client(backend): 12 | global _client 13 | if _client is None: 14 | if backend == "qwen": 15 | _client = OpenAI( 16 | api_key="", 17 | base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", 18 | ) 19 | elif backend == "zhipu": 20 | _client = OpenAI( 21 | api_key="", 22 | base_url="https://open.bigmodel.cn/api/paas/v4/", 23 | http_client=httpx.Client(trust_env=False), 24 | ) 25 | elif backend == "openai": 26 | _client = OpenAI( 27 | api_key="", 28 | base_url="https://api.chatanywhere.tech/v1", 29 | http_client=httpx.Client(trust_env=False), 30 | ) 31 | return _client 32 | 33 | 34 | system_prompt = """ 35 | You are an agent specialized in describing images rendered from 3D objects in English. You will be shown an image of a 3D object and asked to describe it with a short sentence in English. Please start with A .... Please respond in English.'
36 | """ 37 | 38 | 39 | def get_img_url(image_base64, quality): 40 | return {"url": f"data:image/jpeg;base64,{image_base64}", "detail": f"{quality}"} 41 | # return f"data:image/jpeg;base64,{image_base64}" 42 | 43 | 44 | def analyze_image(img_url, backend="zhipu"): 45 | client = get_openai_client(backend) 46 | backend2model = { 47 | "qwen": "qwen-vl-max", 48 | "zhipu": "glm-4v", 49 | "openai": "gpt-4o-mini", 50 | } 51 | response = client.chat.completions.create( 52 | model=backend2model[backend], 53 | messages=[ 54 | {"role": "system", "content": system_prompt}, 55 | { 56 | "role": "user", 57 | "content": [ 58 | { 59 | "type": "image_url", 60 | "image_url": img_url, 61 | }, 62 | ], 63 | }, 64 | ], 65 | max_tokens=300, 66 | top_p=0.1, 67 | ) 68 | 69 | return response.choices[0].message.content 70 | 71 | 72 | def captioning(image_pil, quality="high", backend="openai"): 73 | buffered = BytesIO() 74 | image_pil.save(buffered, format="JPEG") 75 | img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") 76 | img_url = get_img_url(img_str, quality) 77 | 78 | return analyze_image(img_url, backend) 79 | -------------------------------------------------------------------------------- /meshgen/utils/images.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | from rembg import remove, new_session 4 | from kiui.op import recenter 5 | import numpy as np 6 | from transformers import pipeline 7 | from .birefnet import run_model as remove_bg_birefnet 8 | 9 | _rembg_session = None 10 | _bria_model = None 11 | 12 | 13 | def get_rembg_session(): 14 | global _rembg_session 15 | if _rembg_session is None: 16 | _rembg_session = new_session() 17 | 18 | return _rembg_session 19 | 20 | 21 | def get_bria_model(): 22 | global _bria_model 23 | if _bria_model is None: 24 | _bria_model = pipeline( 25 | "image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True 26 | ) 27 | 28 | return _bria_model 29 | 30 | 31 | def remove_bg_rembg(image, **kwargs): 32 | return remove(image, **kwargs) 33 | 34 | 35 | def remove_bg_briarmbg(image, **kwargs): 36 | model = get_bria_model() 37 | return model(image) 38 | 39 | 40 | def preprocess_image( 41 | image: Image.Image, 42 | size: int | tuple[int] = 512, 43 | border_ratio: None | float = None, 44 | remove_bg: bool = False, 45 | ignore_alpha: bool = False, 46 | alpha_matting: bool = True, 47 | backend="bria", 48 | ): 49 | rembg_session = get_rembg_session() 50 | 51 | if border_ratio > 0: 52 | if image.mode != "RGBA" or ignore_alpha: 53 | image = image.convert("RGB") 54 | if backend == "rembg": 55 | carved_image = remove_bg_rembg( 56 | image, alpha_matting=alpha_matting, session=rembg_session 57 | ) # [H, W, 4] 58 | elif backend == "bria": 59 | carved_image = remove_bg_briarmbg(image) 60 | elif backend == "birefnet": 61 | carved_image = remove_bg_birefnet(image) 62 | else: 63 | raise ValueError(f"Unknown backend: {backend}") 64 | carved_image = np.asarray(carved_image) 65 | else: 66 | image = np.asarray(image) 67 | carved_image = image 68 | mask = carved_image[..., -1] > 0 69 | image = recenter(carved_image, mask, border_ratio=border_ratio) 70 | image = image.astype(np.float32) / 255.0 71 | 72 | if remove_bg: 73 | if image.shape[-1] == 4: 74 | image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4]) 75 | else: 76 | image = image 77 | image = Image.fromarray((image * 255).astype(np.uint8)) 78 | # else: 79 | # # raise ValueError("border_ratio must be set currently") 80 | # pass 81 | 82 | 
if isinstance(size, int): 83 | size = (size, size) 84 | 85 | image = image.resize(size) 86 | 87 | return image 88 | -------------------------------------------------------------------------------- /meshgen/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | def checkpoint(func, inputs, params, flag): 6 | """ 7 | Evaluate a function without caching intermediate activations, allowing for 8 | reduced memory at the expense of extra compute in the backward pass. 9 | :param func: the function to evaluate. 10 | :param inputs: the argument sequence to pass to `func`. 11 | :param params: a sequence of parameters `func` depends on but does not 12 | explicitly take as arguments. 13 | :param flag: if False, disable gradient checkpointing. 14 | """ 15 | if flag: 16 | args = tuple(inputs) + tuple(params) 17 | return CheckpointFunction.apply(func, len(inputs), *args) 18 | else: 19 | return func(*inputs) 20 | 21 | 22 | class CheckpointFunction(torch.autograd.Function): 23 | @staticmethod 24 | def forward(ctx, run_function, length, *args): 25 | ctx.run_function = run_function 26 | ctx.input_tensors = list(args[:length]) 27 | ctx.input_params = list(args[length:]) 28 | 29 | with torch.no_grad(): 30 | output_tensors = ctx.run_function(*ctx.input_tensors) 31 | return output_tensors 32 | 33 | @staticmethod 34 | def backward(ctx, *output_grads): 35 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 36 | with torch.enable_grad(): 37 | # Fixes a bug where the first op in run_function modifies the 38 | # Tensor storage in place, which is not allowed for detach()'d 39 | # Tensors. 40 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 41 | output_tensors = ctx.run_function(*shallow_copies) 42 | input_grads = torch.autograd.grad( 43 | output_tensors, 44 | ctx.input_tensors + ctx.input_params, 45 | output_grads, 46 | allow_unused=True, 47 | ) 48 | del ctx.input_tensors 49 | del ctx.input_params 50 | del output_tensors 51 | return (None, None) + input_grads 52 | 53 | 54 | def get_rank(): 55 | # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, 56 | # therefore LOCAL_RANK needs to be checked first 57 | rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") 58 | for key in rank_keys: 59 | rank = os.environ.get(key) 60 | if rank is not None: 61 | return int(rank) 62 | return 0 63 | 64 | 65 | def get_device(): 66 | return torch.device(f"cuda:{get_rank()}") 67 | 68 | 69 | def set_if_none(dd, key, value, default_values=[None]): 70 | if hasattr(dd, key): 71 | ori_value = getattr(dd, key) 72 | if ori_value in default_values: 73 | setattr(dd, key, value) 74 | else: 75 | setattr(dd, key, value) 76 | -------------------------------------------------------------------------------- /meshgen/utils/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | 
#remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 
74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /meshgen/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError("Decay must be between 0 and 1") 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | "num_updates", 15 | ( 16 | torch.tensor(0, dtype=torch.int) 17 | if use_num_upates 18 | else torch.tensor(-1, dtype=torch.int) 19 | ), 20 | ) 21 | 22 | for name, p in model.named_parameters(): 23 | if p.requires_grad: 24 | # remove as '.'-character is not allowed in buffers 25 | s_name = name.replace(".", "") 26 | self.m_name2s_name.update({name: s_name}) 27 | self.register_buffer(s_name, p.clone().detach().data) 28 | 29 | self.collected_params = [] 30 | 31 | def forward(self, model): 32 | decay = self.decay 33 | 34 | if self.num_updates >= 0: 35 | self.num_updates += 1 36 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 37 | 38 | one_minus_decay = 1.0 - decay 39 | 40 | with torch.no_grad(): 41 | m_param = dict(model.named_parameters()) 42 | shadow_params = dict(self.named_buffers()) 43 | 44 | for key in m_param: 45 | if m_param[key].requires_grad: 46 | sname = self.m_name2s_name[key] 47 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 48 | shadow_params[sname].sub_( 49 | one_minus_decay * (shadow_params[sname] - m_param[key]) 50 | ) 51 | else: 52 | assert not key in self.m_name2s_name 53 | 54 | def copy_to(self, model): 55 | m_param = dict(model.named_parameters()) 56 | shadow_params = dict(self.named_buffers()) 57 | for key in m_param: 58 | if m_param[key].requires_grad: 59 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 60 | else: 61 | assert not key in self.m_name2s_name 62 | 63 | def store(self, parameters): 64 | """ 65 | Save the current parameters for restoring later. 66 | Args: 67 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 68 | temporarily stored. 69 | """ 70 | self.collected_params = [param.clone() for param in parameters] 71 | 72 | def restore(self, parameters): 73 | """ 74 | Restore the parameters stored with the `store` method. 75 | Useful to validate the model with EMA parameters without affecting the 76 | original optimization process. Store the parameters before the 77 | `copy_to` method. After validation (or model saving), use this to 78 | restore the former parameters. 79 | Args: 80 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 81 | updated with the stored parameters. 
82 | """ 83 | for c_param, param in zip(self.collected_params, parameters): 84 | param.data.copy_(c_param.data) 85 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | weights/ 158 | trash/ 159 | tmp/ 160 | logs/ 161 | -------------------------------------------------------------------------------- /meshgen/utils/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import trimesh 5 | import kaolin 6 | import mediapy 7 | from PIL import Image 8 | from einops import rearrange 9 | from pathlib import Path 10 | 11 | 12 | def load_mesh(filename, device="cpu"): 13 | try: 14 | filename = str(filename) 15 | if filename.endswith(".obj"): 16 | mesh = kaolin.io.obj.import_mesh(filename) 17 | elif filename.endswith(".glb"): 18 | mesh = kaolin.io.gltf.import_mesh(filename) 19 | elif filename.endswith(".off"): 20 | mesh = kaolin.io.off.import_mesh(filename) 21 | else: 22 | raise NotImplementedError(f"Unsupported file format: {filename}") 23 | vertices = mesh.vertices 24 | faces = mesh.faces 25 | except kaolin.io.utils.NonHomogeneousMeshError: 26 | mesh = trimesh.load(filename) 27 | if isinstance(mesh, trimesh.Scene): 28 | mesh = mesh.dump(concatenate=True) 29 | vertices = torch.from_numpy(mesh.vertices).float() 30 | faces = torch.from_numpy(mesh.faces).long() 31 | 32 | if device != "np": 33 | return vertices.to(device), faces.to(device) 34 | else: 35 | return vertices.cpu().numpy(), faces.cpu().numpy() 36 | 37 | 38 | # def normalize_mesh(vertices, faces, band=1 / 256): 39 | # input_np = False 40 | # if isinstance(vertices, np.ndarray): 41 | # input_np = True 42 | # vertices = torch.from_numpy(vertices) 43 | # faces = torch.from_numpy(faces) 44 | # tris = vertices[faces] 45 | # a = tris.min(0)[0].min(0)[0] 46 | # vertices -= a 47 | # tris -= a 48 | # vertices = (vertices / tris.max() + band) / (1 + band * 2) 49 | # vertices -= 0.5 50 | 51 | # if input_np: 52 | # vertices = vertices.numpy() 53 | # faces = faces.numpy() 54 | 55 | # return vertices, faces 56 | 57 | 58 | def normalize_mesh(vertices, faces, target_scale=0.55): 59 | offset = (vertices.min(dim=0)[0] + vertices.max(dim=0)[0]) / 2 60 | vertices -= offset 61 | scale_factor = 2.0 / (vertices.max(dim=0)[0] - vertices.min(dim=0)[0]).max() 62 | vertices *= scale_factor 63 | vertices *= target_scale 64 | 65 | return vertices, faces 66 | 67 | 68 | def export_mesh(vertices, faces, filename, centralize=True, normalize=True): 69 | if isinstance(vertices, torch.Tensor): 70 | vertices = vertices.detach().cpu().numpy() 71 | 72 | if isinstance(faces, torch.Tensor): 73 | faces = faces.detach().cpu().numpy() 74 | 75 | Path(filename).parent.mkdir(exist_ok=True, parents=True) 76 | 77 | mesh = trimesh.Trimesh(vertices, faces) 78 | if centralize: 79 | mesh.vertices -= mesh.centroid 80 | if normalize: 81 | mesh.vertices /= mesh.extents.max() 82 | mesh.export(filename) 83 | 84 | 85 | def write_video(filename, 
frames, **kwargs): 86 | if isinstance(frames, torch.Tensor): 87 | frames = frames.detach().cpu().numpy() 88 | 89 | if frames.dtype == np.float32: 90 | frames = (frames * 255).clip(0, 255).astype(np.uint8) 91 | 92 | Path(filename).parent.mkdir(exist_ok=True, parents=True) 93 | mediapy.write_video(filename, frames, **kwargs) 94 | 95 | 96 | def write_image(filename, image, format="hwc", **kwargs): 97 | if isinstance(image, torch.Tensor): 98 | image = image.detach().cpu().numpy() 99 | if format == "chw": 100 | image = rearrange(image, "c h w -> h w c") 101 | Path(filename).parent.mkdir(exist_ok=True, parents=True) 102 | mediapy.write_image(filename, image, **kwargs) 103 | 104 | 105 | def read_image(filename, normalize=False, chw=False, return_pt=False, device="cpu"): 106 | image = mediapy.read_image(filename) 107 | 108 | if normalize: 109 | image = image.astype(np.float32) / 255 110 | 111 | if chw: 112 | image = rearrange(image, "h w c -> c h w") 113 | 114 | if return_pt: 115 | image = torch.from_numpy(image).to(device) 116 | 117 | return image 118 | 119 | 120 | def save_tensor_image(tensor: torch.Tensor, save_path: str): 121 | if len(os.path.dirname(save_path)) > 0 and not os.path.exists( 122 | os.path.dirname(save_path) 123 | ): 124 | os.makedirs(os.path.dirname(save_path)) 125 | if len(tensor.shape) == 4: 126 | tensor = tensor.squeeze(0) # [1, c, h, w]-->[c, h, w] 127 | if tensor.shape[0] == 1: 128 | tensor = tensor.repeat(3, 1, 1) 129 | tensor = tensor.permute(1, 2, 0).detach().cpu().numpy() # [c, h, w]-->[h, w, c] 130 | Image.fromarray((tensor * 255).astype(np.uint8)).save(save_path) 131 | -------------------------------------------------------------------------------- /meshgen/utils/remesh.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import contextlib 4 | import io 5 | import sys 6 | import torch 7 | import tempfile 8 | import pyacvd 9 | import pyvista as pv 10 | 11 | from dataclasses import dataclass 12 | from vedo import Mesh 13 | import trimesh 14 | from meshgen.utils.io import load_mesh 15 | 16 | 17 | class Vertex: 18 | @dataclass 19 | class _V: 20 | x: float 21 | y: float 22 | z: float 23 | 24 | def __init__(self, x: float, y: float, z: float): 25 | self.co = self._V(float(x), float(y), float(z)) 26 | 27 | 28 | class Face: 29 | @dataclass 30 | class _F: 31 | vertices: list[int] 32 | 33 | def __len__(self): 34 | return len(self.vertices) 35 | 36 | def __getitem__(self, key): 37 | return self.vertices[key] 38 | 39 | def __init__(self, vertices) -> None: 40 | self.vertices = self._F(vertices) 41 | 42 | 43 | def vf_from_np(vertices, faces): 44 | blender_V = [] 45 | blender_F = [] 46 | for vv in vertices: 47 | blender_V.append(Vertex(vv[0], vv[1], vv[2])) 48 | 49 | for ff in faces: 50 | blender_F.append(Face([int(f_) for f_ in ff.tolist()])) 51 | 52 | return blender_V, blender_F 53 | 54 | 55 | @contextlib.contextmanager 56 | def nostdout(): 57 | save_stdout = sys.stdout 58 | sys.stdout = io.BytesIO() 59 | yield 60 | sys.stdout = save_stdout 61 | 62 | 63 | class suppress_stdout_stderr(object): 64 | """ 65 | A context manager for doing a "deep suppression" of stdout and stderr in 66 | Python, i.e. will suppress all print, even if the print originates in a 67 | compiled C/Fortran sub-function. 
68 | This will not suppress raised exceptions, since exceptions are printed 69 | to stderr just before a script exits, and after the context manager has 70 | exited (at least, I think that is why it lets exceptions through). 71 | 72 | """ 73 | 74 | def __init__(self): 75 | # Open a pair of null files 76 | self.null_fds = [os.open(os.devnull, os.O_RDWR) for x in range(2)] 77 | # Save the actual stdout (1) and stderr (2) file descriptors. 78 | self.save_fds = [os.dup(1), os.dup(2)] 79 | 80 | def __enter__(self): 81 | # Assign the null pointers to stdout and stderr. 82 | os.dup2(self.null_fds[0], 1) 83 | os.dup2(self.null_fds[1], 2) 84 | 85 | def __exit__(self, *_): 86 | # Re-assign the real stdout/stderr back to (1) and (2) 87 | os.dup2(self.save_fds[0], 1) 88 | os.dup2(self.save_fds[1], 2) 89 | # Close all file descriptors 90 | for fd in self.null_fds + self.save_fds: 91 | os.close(fd) 92 | 93 | 94 | def auto_remesh(vertices, faces, triangulate=True, density=0.0, scaling=2.0): 95 | if isinstance(vertices, torch.Tensor): 96 | vertices = vertices.cpu().numpy() 97 | faces = faces.cpu().numpy() 98 | 99 | vv, ff = vf_from_np(vertices, faces) 100 | new_v, new_f = generate_quad_mesh(vv, ff, density, scaling, 0) 101 | # with suppress_stdout_stderr(): 102 | # new_v, new_f = generate_quad_mesh(vv, ff, density, scaling, 0) 103 | vedo_mesh = Mesh([new_v, new_f]) 104 | if triangulate: 105 | vedo_mesh.triangulate() 106 | 107 | return vedo_mesh.vertices, vedo_mesh.cells 108 | 109 | 110 | def instantmesh_remesh(vertices, faces, target_num_faces=10000): 111 | """ 112 | remeshing using InstantMeshes 113 | """ 114 | if isinstance(vertices, torch.Tensor): 115 | vertices = vertices.cpu().numpy() 116 | faces = faces.cpu().numpy() 117 | mesh = trimesh.Trimesh(vertices, faces) 118 | mesh_file = tempfile.NamedTemporaryFile(suffix=f"_original.obj", delete=False).name 119 | mesh.export(mesh_file, include_normals=True) 120 | 121 | remeshed_file = tempfile.NamedTemporaryFile( 122 | suffix=f"_remeshed.obj", delete=False 123 | ).name 124 | 125 | command = f"tmp/InstantMeshes {mesh_file} -f {target_num_faces} -o {remeshed_file}" 126 | os.system(command) 127 | 128 | v, f = load_mesh(remeshed_file) 129 | 130 | del mesh_file, remeshed_file 131 | 132 | return v, f 133 | 134 | 135 | def pyacvd_remesh(vertices, faces, target_num_faces=50000): 136 | cells = np.zeros((faces.shape[0], 4), dtype=int) 137 | cells[:, 1:] = faces 138 | cells[:, 0] = 3 139 | mesh = pv.PolyData(vertices, cells) 140 | clus = pyacvd.Clustering(mesh) 141 | clus.cluster(target_num_faces) 142 | remesh = clus.create_mesh() 143 | 144 | vertices = remesh.points 145 | faces = remesh.faces.reshape(-1, 4)[:, 1:] 146 | return vertices, faces 147 | -------------------------------------------------------------------------------- /meshgen/modules/timm.py: -------------------------------------------------------------------------------- 1 | # code from timm 0.3.2 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | import warnings 6 | 7 | 8 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 9 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 10 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 11 | def norm_cdf(x): 12 | # Computes standard normal cumulative distribution function 13 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 14 | 15 | if (mean < a - 2 * std) or (mean > b + 2 * std): 16 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. 
" 17 | "The distribution of values may be incorrect.", 18 | stacklevel=2) 19 | 20 | with torch.no_grad(): 21 | # Values are generated by using a truncated uniform distribution and 22 | # then using the inverse CDF for the normal distribution. 23 | # Get upper and lower cdf values 24 | l = norm_cdf((a - mean) / std) 25 | u = norm_cdf((b - mean) / std) 26 | 27 | # Uniformly fill tensor with values from [l, u], then translate to 28 | # [2l-1, 2u-1]. 29 | tensor.uniform_(2 * l - 1, 2 * u - 1) 30 | 31 | # Use inverse cdf transform for normal distribution to get truncated 32 | # standard normal 33 | tensor.erfinv_() 34 | 35 | # Transform to proper mean, std 36 | tensor.mul_(std * math.sqrt(2.)) 37 | tensor.add_(mean) 38 | 39 | # Clamp to ensure it's in the proper range 40 | tensor.clamp_(min=a, max=b) 41 | return tensor 42 | 43 | 44 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 45 | # type: (Tensor, float, float, float, float) -> Tensor 46 | r"""Fills the input Tensor with values drawn from a truncated 47 | normal distribution. The values are effectively drawn from the 48 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 49 | with values outside :math:`[a, b]` redrawn until they are within 50 | the bounds. The method used for generating the random values works 51 | best when :math:`a \leq \text{mean} \leq b`. 52 | Args: 53 | tensor: an n-dimensional `torch.Tensor` 54 | mean: the mean of the normal distribution 55 | std: the standard deviation of the normal distribution 56 | a: the minimum cutoff value 57 | b: the maximum cutoff value 58 | Examples: 59 | >>> w = torch.empty(3, 5) 60 | >>> nn.init.trunc_normal_(w) 61 | """ 62 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 63 | 64 | 65 | def drop_path(x, drop_prob: float = 0., training: bool = False): 66 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 67 | 68 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 69 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 70 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 71 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 72 | 'survival rate' as the argument. 73 | 74 | """ 75 | if drop_prob == 0. or not training: 76 | return x 77 | keep_prob = 1 - drop_prob 78 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 79 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 80 | random_tensor.floor_() # binarize 81 | output = x.div(keep_prob) * random_tensor 82 | return output 83 | 84 | 85 | class DropPath(nn.Module): 86 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
87 | """ 88 | def __init__(self, drop_prob=None): 89 | super(DropPath, self).__init__() 90 | self.drop_prob = drop_prob 91 | 92 | def forward(self, x): 93 | return drop_path(x, self.drop_prob, self.training) 94 | 95 | 96 | class Mlp(nn.Module): 97 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 98 | super().__init__() 99 | out_features = out_features or in_features 100 | hidden_features = hidden_features or in_features 101 | self.fc1 = nn.Linear(in_features, hidden_features) 102 | self.act = act_layer() 103 | self.fc2 = nn.Linear(hidden_features, out_features) 104 | self.drop = nn.Dropout(drop) 105 | 106 | def forward(self, x): 107 | x = self.fc1(x) 108 | x = self.act(x) 109 | x = self.drop(x) 110 | x = self.fc2(x) 111 | x = self.drop(x) 112 | return x 113 | -------------------------------------------------------------------------------- /meshgen/model/base.py: -------------------------------------------------------------------------------- 1 | import math 2 | from PIL import Image 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import pytorch_lightning as pl 7 | import torch.nn.functional as F 8 | from torch.optim.lr_scheduler import LambdaLR 9 | from torch.utils.data import default_collate 10 | from pathlib import Path 11 | import tqdm 12 | from einops import rearrange 13 | 14 | from meshgen.util import instantiate_from_config 15 | from diffusers.training_utils import EMAModel 16 | from meshgen.utils.io import load_mesh, write_video 17 | from meshgen.utils.render import render_mesh_spiral_offscreen 18 | from meshgen.modules.timm import trunc_normal_, Mlp 19 | 20 | 21 | def disabled_train(self, mode=True): 22 | """Overwrite model.train with this function to make sure train/eval mode 23 | does not change anymore.""" 24 | return self 25 | 26 | 27 | class BaseModel(pl.LightningModule): 28 | def init_from_ckpt(self, path, ignore_keys=list()): 29 | sd = torch.load(path, map_location="cpu") 30 | if "state_dict" in sd: 31 | sd = sd["state_dict"] 32 | keys = list(sd.keys()) 33 | for k in keys: 34 | for ik in ignore_keys: 35 | if k.startswith(ik): 36 | print("Deleting key {} from state_dict.".format(k)) 37 | del sd[k] 38 | missing, unexpected = self.load_state_dict(sd, strict=False) 39 | print( 40 | f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" 41 | ) 42 | if len(missing) > 0 or len(unexpected) > 0: 43 | print(f"Missing Keys: {missing}") 44 | print(f"Unexpected Keys: {unexpected}") 45 | 46 | def log_prefix(self, loss_dict, prefix): 47 | for k, v in loss_dict.items(): 48 | self.log( 49 | f"{prefix}/{k}", 50 | v, 51 | prog_bar=True, 52 | logger=True, 53 | on_step=True, 54 | on_epoch=True, 55 | batch_size=self.trainer.train_dataloader.loaders.batch_size, 56 | ) 57 | 58 | def shared_step(self, batch, batch_idx): 59 | raise NotImplementedError("shared_step must be implemented in subclass") 60 | 61 | def training_step(self, batch, batch_idx): 62 | loss, loss_dict = self.shared_step(batch, batch_idx) 63 | self.log_prefix(loss_dict, "train") 64 | return loss 65 | 66 | def validation_step(self, batch, batch_idx): 67 | loss, loss_dict = self.shared_step(batch, batch_idx) 68 | self.log_prefix(loss_dict, "val") 69 | return loss 70 | 71 | def configure_optimizers(self): 72 | lr = self.learning_rate 73 | # TODO: remove weight decay for some layers 74 | param_groups = self.get_optim_groups() 75 | # opt = torch.optim.AdamW( 76 | # self.parameters(), lr=lr, weight_decay=self.weight_decay 77 | # 
) 78 | opt = torch.optim.AdamW(param_groups, lr=lr, weight_decay=self.weight_decay) 79 | if self.use_scheduler: 80 | assert "target" in self.scheduler_config 81 | scheduler = instantiate_from_config(self.scheduler_config) 82 | 83 | print("Setting up LambdaLR scheduler...") 84 | scheduler = [ 85 | { 86 | "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), 87 | "interval": "step", 88 | "frequency": 1, 89 | } 90 | ] 91 | return [opt], scheduler 92 | return opt 93 | 94 | def on_train_epoch_end(self, *args, **kwargs): 95 | self.eval() 96 | logdir = self.trainer.logdir 97 | this_logdir = Path(logdir) / f"spirals/epoch_{self.current_epoch}" 98 | this_logdir.mkdir(exist_ok=True, parents=True) 99 | self.test_in_the_wild(this_logdir) 100 | self.train() 101 | 102 | def on_train_start(self): 103 | return 104 | self.eval() 105 | logdir = self.trainer.logdir 106 | this_logdir = Path(logdir) / f"spirals/before_training" 107 | this_logdir.mkdir(exist_ok=True, parents=True) 108 | self.test_in_the_wild(this_logdir) 109 | self.train() 110 | 111 | def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): 112 | self.eval() 113 | if self.vis_every is None: 114 | return 115 | 116 | # if (self.global_step + 1) % self.vis_every != 0 and self.global_step != 0: 117 | # return 118 | if (self.global_step + 1) % self.vis_every != 0: 119 | return 120 | 121 | logdir = self.trainer.logdir ### 122 | this_logdir = Path(logdir) / f"spirals/step_{self.global_step}" 123 | this_logdir.mkdir(exist_ok=True, parents=True) 124 | self.test_in_the_wild(this_logdir) 125 | self.train() 126 | 127 | def test_in_the_wild(self, save_dir): 128 | raise NotImplementedError("test_in_the_wild must be implemented in subclass") 129 | 130 | def get_optim_groups(self): 131 | raise NotImplementedError("get_optim_groups must be implemented in subclass") 132 | -------------------------------------------------------------------------------- /meshgen/utils/math_utils.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2022 Petr Kellnhofer 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | import torch 24 | 25 | 26 | def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor: 27 | """ 28 | Left-multiplies MxM @ NxM. Returns NxM. 
29 | """ 30 | res = torch.matmul(vectors4, matrix.T) 31 | return res 32 | 33 | 34 | def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor: 35 | """ 36 | Normalize vector lengths. 37 | """ 38 | return vectors / (torch.norm(vectors, dim=-1, keepdim=True)) 39 | 40 | 41 | def torch_dot(x: torch.Tensor, y: torch.Tensor): 42 | """ 43 | Dot product of two tensors. 44 | """ 45 | return (x * y).sum(-1) 46 | 47 | 48 | def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length): 49 | """ 50 | Author: Petr Kellnhofer 51 | Intersects rays with the [-1, 1] NDC volume. 52 | Returns min and max distance of entry. 53 | Returns -1 for no intersection. 54 | https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection 55 | """ 56 | o_shape = rays_o.shape 57 | rays_o = rays_o.detach().reshape(-1, 3) 58 | rays_d = rays_d.detach().reshape(-1, 3) 59 | 60 | bb_min = [ 61 | -1 * (box_side_length / 2), 62 | -1 * (box_side_length / 2), 63 | -1 * (box_side_length / 2), 64 | ] 65 | bb_max = [ 66 | 1 * (box_side_length / 2), 67 | 1 * (box_side_length / 2), 68 | 1 * (box_side_length / 2), 69 | ] 70 | bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device) 71 | is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device) 72 | 73 | # Precompute inverse for stability. 74 | invdir = 1 / rays_d 75 | sign = (invdir < 0).long() 76 | 77 | # Intersect with YZ plane. 78 | tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[ 79 | ..., 0 80 | ] 81 | tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[ 82 | ..., 0 83 | ] 84 | 85 | # Intersect with XZ plane. 86 | tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[ 87 | ..., 1 88 | ] 89 | tymax = ( 90 | bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1] 91 | ) * invdir[..., 1] 92 | 93 | # Resolve parallel rays. 94 | is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False 95 | 96 | # Use the shortest intersection. 97 | tmin = torch.max(tmin, tymin) 98 | tmax = torch.min(tmax, tymax) 99 | 100 | # Intersect with XY plane. 101 | tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[ 102 | ..., 2 103 | ] 104 | tzmax = ( 105 | bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2] 106 | ) * invdir[..., 2] 107 | 108 | # Resolve parallel rays. 109 | is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False 110 | 111 | # Use the shortest intersection. 112 | tmin = torch.max(tmin, tzmin) 113 | tmax = torch.min(tmax, tzmax) 114 | 115 | # Mark invalid. 116 | tmin[torch.logical_not(is_valid)] = -1 117 | tmax[torch.logical_not(is_valid)] = -2 118 | 119 | return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1) 120 | 121 | 122 | def linspace(start: torch.Tensor, stop: torch.Tensor, num: int): 123 | """ 124 | Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive. 125 | Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch. 
126 | """ 127 | # create a tensor of 'num' steps from 0 to 1 128 | steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1) 129 | 130 | # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings 131 | # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript 132 | # "cannot statically infer the expected size of a list in this contex", hence the code below 133 | for i in range(start.ndim): 134 | steps = steps.unsqueeze(-1) 135 | 136 | # the output starts at 'start' and increments until 'stop' in each dimension 137 | out = start[None] + steps * (stop - start)[None] 138 | 139 | return out 140 | -------------------------------------------------------------------------------- /meshgen/modules/resnet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def nonlinearity(x): 7 | # swish 8 | return x * torch.sigmoid(x) 9 | 10 | 11 | def Normalize(in_channels, num_groups=32): 12 | return torch.nn.GroupNorm( 13 | num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True 14 | ) 15 | 16 | 17 | class Upsample(nn.Module): 18 | def __init__(self, in_channels, with_conv): 19 | super().__init__() 20 | self.with_conv = with_conv 21 | if self.with_conv: 22 | self.conv = torch.nn.Conv2d( 23 | in_channels, in_channels, kernel_size=3, stride=1, padding=1 24 | ) 25 | 26 | def forward(self, x): 27 | x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 28 | if self.with_conv: 29 | x = self.conv(x) 30 | return x 31 | 32 | 33 | class Downsample(nn.Module): 34 | def __init__(self, in_channels, with_conv): 35 | super().__init__() 36 | self.with_conv = with_conv 37 | if self.with_conv: 38 | # no asymmetric padding in torch conv, must do it ourselves 39 | self.conv = torch.nn.Conv2d( 40 | in_channels, in_channels, kernel_size=3, stride=2, padding=0 41 | ) 42 | 43 | def forward(self, x): 44 | if self.with_conv: 45 | pad = (0, 1, 0, 1) 46 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 47 | x = self.conv(x) 48 | else: 49 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 50 | return x 51 | 52 | 53 | class ResnetBlock(nn.Module): 54 | def __init__( 55 | self, 56 | *, 57 | in_channels, 58 | out_channels=None, 59 | conv_shortcut=False, 60 | dropout, 61 | temb_channels=512 62 | ): 63 | super().__init__() 64 | self.in_channels = in_channels 65 | out_channels = in_channels if out_channels is None else out_channels 66 | self.out_channels = out_channels 67 | self.use_conv_shortcut = conv_shortcut 68 | 69 | self.norm1 = Normalize(in_channels) 70 | self.conv1 = torch.nn.Conv2d( 71 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 72 | ) 73 | if temb_channels > 0: 74 | self.temb_proj = torch.nn.Linear(temb_channels, out_channels) 75 | self.norm2 = Normalize(out_channels) 76 | self.dropout = torch.nn.Dropout(dropout) 77 | self.conv2 = torch.nn.Conv2d( 78 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 79 | ) 80 | if self.in_channels != self.out_channels: 81 | if self.use_conv_shortcut: 82 | self.conv_shortcut = torch.nn.Conv2d( 83 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 84 | ) 85 | else: 86 | self.nin_shortcut = torch.nn.Conv2d( 87 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 88 | ) 89 | 90 | def forward(self, x, temb): 91 | h = x 92 | h = self.norm1(h) 93 | h = nonlinearity(h) 94 | h = self.conv1(h) 95 | 96 | if temb 
is not None: 97 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] 98 | 99 | h = self.norm2(h) 100 | h = nonlinearity(h) 101 | h = self.dropout(h) 102 | h = self.conv2(h) 103 | 104 | if self.in_channels != self.out_channels: 105 | if self.use_conv_shortcut: 106 | x = self.conv_shortcut(x) 107 | else: 108 | x = self.nin_shortcut(x) 109 | 110 | return x + h 111 | 112 | 113 | class DeConvDecoder(nn.Module): 114 | def __init__( 115 | self, 116 | z_channels, 117 | num_resos, 118 | num_res_blocks, 119 | ch, 120 | out_ch, 121 | ch_mult=(1, 1, 1, 1), 122 | resolution=256, 123 | resamp_with_conv=True, 124 | dropout=0.0, 125 | ): 126 | super().__init__() 127 | 128 | self.num_resos = num_resos 129 | self.num_res_blocks = num_res_blocks 130 | 131 | block_in = ch * ch_mult[self.num_resos - 1] 132 | 133 | curr_res = resolution // 2 ** (self.num_resos - 1) 134 | self.z_shape = (1, z_channels, curr_res, curr_res) 135 | print( 136 | "Working with z of shape {} = {} dimensions.".format( 137 | self.z_shape, np.prod(self.z_shape) 138 | ) 139 | ) 140 | 141 | self.conv_in = torch.nn.Conv2d( 142 | z_channels, block_in, kernel_size=3, stride=1, padding=1 143 | ) 144 | 145 | self.up = nn.ModuleList() 146 | for i_level in reversed(range(self.num_resos)): 147 | block = nn.ModuleList() 148 | block_out = ch * ch_mult[i_level] 149 | for i_block in range(self.num_res_blocks + 1): 150 | block.append( 151 | ResnetBlock( 152 | in_channels=block_in, 153 | out_channels=block_out, 154 | temb_channels=0, 155 | dropout=dropout, 156 | ) 157 | ) 158 | block_in = block_out 159 | up = nn.Module() 160 | up.block = block 161 | if i_level != 0: 162 | up.upsample = Upsample(block_in, resamp_with_conv) 163 | curr_res = curr_res * 2 164 | self.up.insert(0, up) # prepend to get consistent order 165 | 166 | self.norm_out = Normalize(block_in) 167 | self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) 168 | 169 | def forward(self, z): 170 | temb = None 171 | 172 | h = self.conv_in(z) 173 | 174 | # upsampling 175 | for i_level in reversed(range(self.num_resos)): 176 | for i_block in range(self.num_res_blocks + 1): 177 | h = self.up[i_level].block[i_block](h, temb) 178 | if i_level != 0: 179 | h = self.up[i_level].upsample(h) 180 | 181 | h = self.norm_out(h) 182 | h = nonlinearity(h) 183 | h = self.conv_out(h) 184 | 185 | return h 186 | -------------------------------------------------------------------------------- /meshgen/utils/ray_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.autograd.profiler as profiler 5 | import numpy as np 6 | from einops import rearrange, repeat, einsum 7 | 8 | from .math_utils import linspace, get_ray_limits_box 9 | 10 | 11 | def FOV_to_intrinsics(fov, device="cpu"): 12 | """ 13 | Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees. 14 | Note the intrinsics are returned as normalized by image size, rather than in pixel units. 15 | Assumes principal point is at image center. 
16 | """ 17 | focal_length = 0.5 / np.tan(np.deg2rad(fov) * 0.5) 18 | intrinsics = torch.tensor( 19 | [[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device 20 | ) 21 | return intrinsics 22 | 23 | 24 | class RayGenerator(torch.nn.Module): 25 | """ 26 | from camera pose and intrinsics to ray origins and directions 27 | """ 28 | 29 | def __init__(self): 30 | super().__init__() 31 | ( 32 | self.ray_origins_h, 33 | self.ray_directions, 34 | self.depths, 35 | self.image_coords, 36 | self.rendering_options, 37 | ) = (None, None, None, None, None) 38 | 39 | def forward(self, cam2world_matrix, fov, render_size): 40 | """ 41 | Create batches of rays and return origins and directions. 42 | 43 | cam2world_matrix: (N, 4, 4) 44 | intrinsics: (N, 3, 3) 45 | render_size: int 46 | 47 | ray_origins: (N, M, 3) 48 | ray_dirs: (N, M, 2) 49 | """ 50 | intrinsics = ( 51 | FOV_to_intrinsics(fov) 52 | .to(cam2world_matrix)[None] 53 | .repeat(cam2world_matrix.shape[0], 1, 1) 54 | ) 55 | 56 | N, M = cam2world_matrix.shape[0], render_size**2 57 | cam_locs_world = cam2world_matrix[:, :3, 3] 58 | fx = intrinsics[:, 0, 0] 59 | fy = intrinsics[:, 1, 1] 60 | cx = intrinsics[:, 0, 2] 61 | cy = intrinsics[:, 1, 2] 62 | sk = intrinsics[:, 0, 1] 63 | 64 | uv = torch.stack( 65 | torch.meshgrid( 66 | torch.arange( 67 | render_size, dtype=torch.float32, device=cam2world_matrix.device 68 | ), 69 | torch.arange( 70 | render_size, dtype=torch.float32, device=cam2world_matrix.device 71 | ), 72 | indexing="ij", 73 | ) 74 | ) 75 | uv = uv.flip(0).reshape(2, -1).transpose(1, 0) 76 | uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1) 77 | 78 | x_cam = uv[:, :, 0].view(N, -1) * (1.0 / render_size) + (0.5 / render_size) 79 | y_cam = uv[:, :, 1].view(N, -1) * (1.0 / render_size) + (0.5 / render_size) 80 | z_cam = torch.ones((N, M), device=cam2world_matrix.device) 81 | 82 | x_lift = ( 83 | ( 84 | x_cam 85 | - cx.unsqueeze(-1) 86 | + cy.unsqueeze(-1) * sk.unsqueeze(-1) / fy.unsqueeze(-1) 87 | - sk.unsqueeze(-1) * y_cam / fy.unsqueeze(-1) 88 | ) 89 | / fx.unsqueeze(-1) 90 | * z_cam 91 | ) 92 | y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam 93 | 94 | cam_rel_points = torch.stack( 95 | (x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1 96 | ) 97 | 98 | # NOTE: this should be named _blender2opencv 99 | _opencv2blender = ( 100 | torch.tensor( 101 | [ 102 | [1, 0, 0, 0], 103 | [0, -1, 0, 0], 104 | [0, 0, -1, 0], 105 | [0, 0, 0, 1], 106 | ], 107 | dtype=torch.float32, 108 | device=cam2world_matrix.device, 109 | ) 110 | .unsqueeze(0) 111 | .repeat(N, 1, 1) 112 | ) 113 | 114 | cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender) 115 | 116 | world_rel_points = torch.bmm( 117 | cam2world_matrix, cam_rel_points.permute(0, 2, 1) 118 | ).permute(0, 2, 1)[:, :, :3] 119 | 120 | ray_dirs = world_rel_points - cam_locs_world[:, None, :] 121 | ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2) 122 | 123 | ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1) 124 | 125 | return ray_origins, ray_dirs 126 | 127 | 128 | class RaySampler(torch.nn.Module): 129 | def __init__( 130 | self, 131 | num_samples_per_ray, 132 | bbox_length=1.0, 133 | near=0.1, 134 | far=10.0, 135 | drop_invalid=False, 136 | disparity=False, 137 | ): 138 | super().__init__() 139 | self.num_samples_per_ray = num_samples_per_ray 140 | self.bbox_length = bbox_length 141 | self.near = near 142 | self.far = far 143 | self.disparity = disparity 144 | self.drop_invalid = drop_invalid 145 | 146 | def 
forward(self, ray_origins, ray_directions): 147 | if not self.disparity: 148 | t_start, t_end = get_ray_limits_box( 149 | ray_origins, ray_directions, self.bbox_length 150 | ) 151 | else: 152 | t_start = torch.full_like(ray_origins, self.near) 153 | t_end = torch.full_like(ray_origins, self.far) 154 | is_ray_valid = t_end > t_start 155 | if not self.drop_invalid: 156 | if torch.any(is_ray_valid).item(): 157 | t_start[~is_ray_valid] = t_start[is_ray_valid].min() 158 | t_end[~is_ray_valid] = t_start[is_ray_valid].max() 159 | else: 160 | is_ray_valid = is_ray_valid[..., 0] 161 | ray_origins = ray_origins[is_ray_valid] 162 | ray_directions = ray_directions[is_ray_valid] 163 | t_start = t_start[is_ray_valid] 164 | t_end = t_end[is_ray_valid] 165 | 166 | if not self.disparity: 167 | depths = linspace(t_start, t_end, self.num_samples_per_ray) 168 | depths += ( 169 | torch.rand_like(depths) 170 | * (t_end - t_start) 171 | / (self.num_samples_per_ray - 1) 172 | ) 173 | else: 174 | step = 1.0 / self.num_samples_per_ray 175 | z_steps = torch.linspace( 176 | 0, 1 - step, self.num_samples_per_ray, device=ray_origins.device 177 | ) 178 | z_steps += torch.rand_like(z_steps) * step 179 | depths = 1 / (1 / self.near * (1 - z_steps) + 1 / self.far * z_steps) 180 | depths = depths[..., None, None, None] 181 | 182 | return ray_origins[None] + ray_directions[None] * depths 183 | -------------------------------------------------------------------------------- /meshgen/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | from collections import abc 6 | from einops import rearrange 7 | from functools import partial 8 | 9 | import multiprocessing as mp 10 | from threading import Thread 11 | from queue import Queue 12 | 13 | from inspect import isfunction 14 | from PIL import Image, ImageDraw, ImageFont 15 | 16 | 17 | def log_txt_as_img(wh, xc, size=10): 18 | # wh a tuple of (width, height) 19 | # xc a list of captions to plot 20 | b = len(xc) 21 | txts = list() 22 | for bi in range(b): 23 | txt = Image.new("RGB", wh, color="white") 24 | draw = ImageDraw.Draw(txt) 25 | font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) 26 | nc = int(40 * (wh[0] / 256)) 27 | lines = "\n".join( 28 | xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc) 29 | ) 30 | 31 | try: 32 | draw.text((0, 0), lines, fill="black", font=font) 33 | except UnicodeEncodeError: 34 | print("Cant encode string for logging. Skipping.") 35 | 36 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 37 | txts.append(txt) 38 | txts = np.stack(txts) 39 | txts = torch.tensor(txts) 40 | return txts 41 | 42 | 43 | def ismap(x): 44 | if not isinstance(x, torch.Tensor): 45 | return False 46 | return (len(x.shape) == 4) and (x.shape[1] > 3) 47 | 48 | 49 | def isimage(x): 50 | if not isinstance(x, torch.Tensor): 51 | return False 52 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 53 | 54 | 55 | def exists(x): 56 | return x is not None 57 | 58 | 59 | def default(val, d): 60 | if exists(val): 61 | return val 62 | return d() if isfunction(d) else d 63 | 64 | 65 | def mean_flat(tensor): 66 | """ 67 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 68 | Take the mean over all non-batch dimensions. 
69 | """ 70 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 71 | 72 | 73 | def count_params(model, verbose=False): 74 | total_params = sum(p.numel() for p in model.parameters()) 75 | if verbose: 76 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 77 | return total_params 78 | 79 | 80 | def instantiate_from_config(config): 81 | if not "target" in config: 82 | if config == "__is_first_stage__": 83 | return None 84 | elif config == "__is_unconditional__": 85 | return None 86 | raise KeyError("Expected key `target` to instantiate.") 87 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 88 | 89 | 90 | def get_obj_from_str(string, reload=False): 91 | module, cls = string.rsplit(".", 1) 92 | if reload: 93 | module_imp = importlib.import_module(module) 94 | importlib.reload(module_imp) 95 | return getattr(importlib.import_module(module, package=None), cls) 96 | 97 | 98 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 99 | # create dummy dataset instance 100 | 101 | # run prefetching 102 | if idx_to_fn: 103 | res = func(data, worker_id=idx) 104 | else: 105 | res = func(data) 106 | Q.put([idx, res]) 107 | Q.put("Done") 108 | 109 | 110 | def parallel_data_prefetch( 111 | func: callable, 112 | data, 113 | n_proc, 114 | target_data_type="ndarray", 115 | cpu_intensive=True, 116 | use_worker_id=False, 117 | ): 118 | # if target_data_type not in ["ndarray", "list"]: 119 | # raise ValueError( 120 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 121 | # ) 122 | if isinstance(data, np.ndarray) and target_data_type == "list": 123 | raise ValueError("list expected but function got ndarray.") 124 | elif isinstance(data, abc.Iterable): 125 | if isinstance(data, dict): 126 | print( 127 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 128 | ) 129 | data = list(data.values()) 130 | if target_data_type == "ndarray": 131 | data = np.asarray(data) 132 | else: 133 | data = list(data) 134 | else: 135 | raise TypeError( 136 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
137 | ) 138 | 139 | if cpu_intensive: 140 | Q = mp.Queue(1000) 141 | proc = mp.Process 142 | else: 143 | Q = Queue(1000) 144 | proc = Thread 145 | # spawn processes 146 | if target_data_type == "ndarray": 147 | arguments = [ 148 | [func, Q, part, i, use_worker_id] 149 | for i, part in enumerate(np.array_split(data, n_proc)) 150 | ] 151 | else: 152 | step = ( 153 | int(len(data) / n_proc + 1) 154 | if len(data) % n_proc != 0 155 | else int(len(data) / n_proc) 156 | ) 157 | arguments = [ 158 | [func, Q, part, i, use_worker_id] 159 | for i, part in enumerate( 160 | [data[i : i + step] for i in range(0, len(data), step)] 161 | ) 162 | ] 163 | processes = [] 164 | for i in range(n_proc): 165 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 166 | processes += [p] 167 | 168 | # start processes 169 | print(f"Start prefetching...") 170 | import time 171 | 172 | start = time.time() 173 | gather_res = [[] for _ in range(n_proc)] 174 | try: 175 | for p in processes: 176 | p.start() 177 | 178 | k = 0 179 | while k < n_proc: 180 | # get result 181 | res = Q.get() 182 | if res == "Done": 183 | k += 1 184 | else: 185 | gather_res[res[0]] = res[1] 186 | 187 | except Exception as e: 188 | print("Exception: ", e) 189 | for p in processes: 190 | p.terminate() 191 | 192 | raise e 193 | finally: 194 | for p in processes: 195 | p.join() 196 | print(f"Prefetching complete. [{time.time() - start} sec.]") 197 | 198 | if target_data_type == "ndarray": 199 | if not isinstance(gather_res[0], np.ndarray): 200 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 201 | 202 | # order outputs 203 | return np.concatenate(gather_res, axis=0) 204 | elif target_data_type == "list": 205 | out = [] 206 | for r in gather_res: 207 | out.extend(r) 208 | return out 209 | else: 210 | return gather_res 211 | -------------------------------------------------------------------------------- /meshgen/utils/render_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from pathlib import Path 4 | from trimesh.transformations import rotation_matrix 5 | from einops import repeat, rearrange 6 | 7 | from meshgen.modules.mesh.render import Renderer 8 | from meshgen.utils.io import load_mesh 9 | 10 | 11 | def make_zero123pp_grid(data): 12 | masks = torch.stack(data["masks"], dim=0) 13 | depths = torch.stack(data["depths"], dim=0) 14 | normals = torch.stack(data["normals"], dim=0) 15 | 16 | depths = repeat(depths, "b h w 1 -> b h w c", c=3) 17 | depths = torch.cat([depths, masks], dim=-1) 18 | normals = torch.cat([normals, masks], dim=-1) 19 | 20 | [depths, normals, masks] = [ 21 | rearrange(img, "(a b) h w c -> (a h) (b w) c", a=3, b=2) 22 | for img in [depths, normals, masks] 23 | ] 24 | 25 | ret = dict(depths=depths, normals=normals, masks=masks) 26 | 27 | if "rgbs" in data: 28 | rgbs = torch.stack(data["rgbs"], dim=0) 29 | rgbs = rearrange(rgbs, "(a b) h w c -> (a h) (b w) c", a=3, b=2) 30 | # rgbs = torch.cat([rgbs, masks], dim=-1) 31 | ret["rgbs"] = rgbs 32 | 33 | return ret 34 | 35 | 36 | def normalize_vertices_to_cube(vertices: torch.Tensor, scale: float = 1.0): # V,3 37 | """shift and resize mesh to fit into a unit cube""" 38 | offset = (vertices.min(dim=0)[0] + vertices.max(dim=0)[0]) / 2 39 | vertices -= offset 40 | scale_factor = 2.0 / (vertices.max(dim=0)[0] - vertices.min(dim=0)[0]).max() 41 | vertices *= scale_factor 42 | return vertices * scale 43 | 44 | 45 | def render_orthogonal_4views( 46 | mesh_filename, 47 | reso, 48 
| radius=2.5, 49 | rotate=False, 50 | depth_min_val=0.5, 51 | renderer=None, 52 | ): 53 | normals = [] 54 | depths = [] 55 | masks = [] 56 | device = "cuda" 57 | 58 | if isinstance(mesh_filename, str): 59 | mesh = load_mesh(mesh_filename, device) 60 | mesh.vertices = normalize_vertices_to_cube(mesh.vertices) 61 | if rotate: 62 | rotmat = rotation_matrix(-np.pi / 2, [1, 0, 0]) @ rotation_matrix( 63 | -np.pi / 2, [0, 0, 1] 64 | ) 65 | rotmat = torch.from_numpy(rotmat.astype(np.float32)).to(device)[:3, :3] 66 | mesh.vertices = mesh.vertices @ rotmat.T 67 | else: 68 | mesh = mesh_filename 69 | 70 | if renderer is None: 71 | renderer = Renderer(mesh.faces.shape[0], device) 72 | 73 | azimuths = [0, np.pi / 2, np.pi, np.pi * 1.5] 74 | for azi in azimuths: 75 | rendered = renderer.render_single_view( 76 | mesh.vertices, 77 | mesh.faces, 78 | np.pi / 2, 79 | azi, 80 | radius, 81 | dims=(reso, reso), 82 | depth_min_val=depth_min_val, 83 | ) 84 | this_normal = rendered["normal_map"][0] * 0.5 + 0.5 85 | normals.append(this_normal) 86 | depths.append(rendered["depth_map"][0]) 87 | masks.append(rendered["mask"][0]) 88 | 89 | data = dict() 90 | data["normals"] = normals 91 | data["depths"] = depths 92 | data["masks"] = masks 93 | 94 | return data 95 | 96 | 97 | def render_zero123pp_6views( 98 | mesh_filename, 99 | reso=320, 100 | radius=2.5, 101 | rotate=True, 102 | renderer=None, 103 | depth_min_val=0.0, 104 | ): 105 | normals = [] 106 | depths = [] 107 | masks = [] 108 | device = "cuda" 109 | if isinstance(mesh_filename, str): 110 | mesh = load_mesh(mesh_filename, device) 111 | mesh.vertices = normalize_vertices_to_cube(mesh.vertices) 112 | if rotate: 113 | rotmat = rotation_matrix(-np.pi / 2, [1, 0, 0]) @ rotation_matrix( 114 | -np.pi / 2, [0, 0, 1] 115 | ) 116 | rotmat = torch.from_numpy(rotmat.astype(np.float32)).to(device)[:3, :3] 117 | mesh.vertices = mesh.vertices @ rotmat.T 118 | else: 119 | mesh = mesh_filename 120 | 121 | if renderer is None: 122 | renderer = Renderer(mesh.faces.shape[0], device) 123 | azimuths = np.deg2rad(np.array([30, 90, 150, 210, 270, 330])).astype(np.float32) 124 | elevations = ( 125 | -np.deg2rad(np.array([20, -10, 20, -10, 20, -10])) + np.pi / 2 126 | ).astype(np.float32) 127 | 128 | for azi, ele in zip(azimuths, elevations): 129 | rendered = renderer.render_single_view( 130 | mesh.vertices, 131 | mesh.faces, 132 | ele, 133 | azi, 134 | radius, 135 | depth_min_val=depth_min_val, 136 | dims=(reso, reso), 137 | ) 138 | this_normal = rendered["normal_map"][0] * 0.5 + 0.5 139 | normals.append(this_normal) 140 | depths.append(rendered["depth_map"][0]) 141 | masks.append(rendered["mask"][0]) 142 | 143 | data = dict() 144 | data["normals"] = normals 145 | data["depths"] = depths 146 | data["masks"] = masks 147 | zero123pp_grid = make_zero123pp_grid(data) 148 | 149 | return zero123pp_grid 150 | 151 | 152 | @torch.no_grad() 153 | def render_zero123pp_6views_rgbs( 154 | mesh, 155 | reso=320, 156 | radius=2.5, 157 | bg="white", 158 | version="1.2", 159 | renderer=None, 160 | depth_min_val=0.0, 161 | flip_normals=False, 162 | ): 163 | normals = [] 164 | depths = [] 165 | masks = [] 166 | rgbs = [] 167 | 168 | if version == "1.2": 169 | azimuths = np.deg2rad(np.array([30, 90, 150, 210, 270, 330])).astype(np.float32) 170 | elevations = ( 171 | -np.deg2rad(np.array([20, -10, 20, -10, 20, -10])) + np.pi / 2 172 | ).astype(np.float32) 173 | else: 174 | azimuths = np.deg2rad(np.array([30, 90, 150, 210, 270, 330])).astype(np.float32) 175 | elevations = ( 176 | -np.deg2rad(np.array([30, 
-20, 30, -20, 30, -20])) + np.pi / 2 177 | ).astype(np.float32) 178 | 179 | move_axis = lambda x: rearrange(x, "c h w -> h w c") 180 | bg = 1.0 if bg == "white" else 0.5 181 | for azi, ele in zip(azimuths, elevations): 182 | rendered = mesh.render( 183 | ele, azi, radius, dims=(reso, reso), depth_min_val=depth_min_val 184 | ) 185 | # grey background 186 | rgbs.append( 187 | move_axis( 188 | rendered["mask"][0] * rendered["image"][0] 189 | + (1 - rendered["mask"][0]) * bg 190 | ) 191 | ) 192 | if flip_normals: 193 | rendered["normals"][0] = -rendered["normals"][0] 194 | this_normal = rendered["normals"][0] * 0.5 + 0.5 195 | normals.append(move_axis(this_normal)) 196 | depths.append(move_axis(rendered["depth"][0])) 197 | masks.append(move_axis(rendered["mask"][0])) 198 | 199 | data = dict() 200 | data["normals"] = normals 201 | data["depths"] = depths 202 | data["masks"] = masks 203 | data["rgbs"] = rgbs 204 | zero123pp_grid = make_zero123pp_grid(data) 205 | 206 | return zero123pp_grid 207 | -------------------------------------------------------------------------------- /texgen.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import datetime 4 | import numpy as np 5 | import torch 6 | import json 7 | from pathlib import Path 8 | from omegaconf import OmegaConf 9 | from pytorch_lightning import seed_everything 10 | 11 | import torch.distributed as dist 12 | 13 | from meshgen.util import instantiate_from_config 14 | from meshgen.utils.hf_weights import ( 15 | pbr_decomposer_path, 16 | texture_inpainter_path, 17 | mv_generator_path, 18 | ) 19 | from meshgen.utils.misc import set_if_none 20 | 21 | 22 | def get_parser(**parser_kwargs): 23 | def str2bool(v): 24 | if isinstance(v, bool): 25 | return v 26 | if v.lower() in ("yes", "true", "t", "y", "1"): 27 | return True 28 | elif v.lower() in ("no", "false", "f", "n", "0"): 29 | return False 30 | else: 31 | raise argparse.ArgumentTypeError("Boolean value expected.") 32 | 33 | parser = argparse.ArgumentParser(**parser_kwargs) 34 | parser.add_argument( 35 | "-n", 36 | "--name", 37 | type=str, 38 | const=True, 39 | default="", 40 | nargs="?", 41 | help="postfix for logdir", 42 | ) 43 | parser.add_argument( 44 | "-r", 45 | "--resume", 46 | type=str, 47 | const=True, 48 | default="", 49 | nargs="?", 50 | help="resume from logdir or checkpoint in logdir", 51 | ) 52 | parser.add_argument( 53 | "-b", 54 | "--base", 55 | nargs="*", 56 | metavar="base_config.yaml", 57 | help="paths to base configs. Loaded from left-to-right. 
" 58 | "Parameters can be overwritten or added with command-line options of the form `--key value`.", 59 | default=["configs/texgen.yaml"], 60 | ) 61 | parser.add_argument( 62 | "-p", "--project", help="name of new or path to existing project" 63 | ) 64 | parser.add_argument( 65 | "-d", 66 | "--debug", 67 | type=str2bool, 68 | nargs="?", 69 | const=True, 70 | default=False, 71 | help="enable post-mortem debugging", 72 | ) 73 | parser.add_argument( 74 | "--no_ignore_alpha", 75 | type=str2bool, 76 | nargs="?", 77 | const=True, 78 | default=False, 79 | help="enable post-mortem debugging", 80 | ) 81 | parser.add_argument( 82 | "-s", 83 | "--seed", 84 | type=int, 85 | default=23, 86 | help="seed for seed_everything", 87 | ) 88 | parser.add_argument( 89 | "-f", 90 | "--postfix", 91 | type=str, 92 | default="", 93 | help="post-postfix for default name", 94 | ) 95 | parser.add_argument( 96 | "-l", 97 | "--logdir", 98 | type=str, 99 | default="logs/texturing", 100 | help="directory for logging dat shit", 101 | ) 102 | parser.add_argument( 103 | "-m", 104 | "--meta", 105 | type=str, 106 | default=None, 107 | help="The meta data file for all meshes an their corresponding images", 108 | ) 109 | parser.add_argument( 110 | "--max_items", 111 | type=int, 112 | default=-1, 113 | help="The meta data file for all meshes an their corresponding images", 114 | ) 115 | parser.add_argument( 116 | "--scale_lr", 117 | type=str2bool, 118 | nargs="?", 119 | const=True, 120 | default=False, 121 | help="scale base-lr by ngpu * batch_size * n_accumulate", 122 | ) 123 | parser.add_argument( 124 | "--default_logger", 125 | type=str, 126 | help="The default logger to use", 127 | choices=["testtube", "wandb"], 128 | default="wandb", 129 | ) 130 | return parser 131 | 132 | 133 | def main(): 134 | local_rank = dist.get_rank() 135 | world_size = dist.get_world_size() 136 | torch.cuda.set_device(local_rank) 137 | parser = get_parser() 138 | opt, unknown = parser.parse_known_args() 139 | seed_everything(opt.seed) 140 | 141 | if opt.name: 142 | name = "_" + opt.name 143 | elif opt.base: 144 | cfg_fname = os.path.split(opt.base[0])[-1] 145 | cfg_name = os.path.splitext(cfg_fname)[0] 146 | name = "_" + cfg_name 147 | else: 148 | name = "" 149 | 150 | if local_rank == 0: 151 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 152 | else: 153 | now = None 154 | 155 | dist.barrier() 156 | outputs = [now] 157 | 158 | dist.broadcast_object_list(outputs, 0) 159 | now = outputs[0] 160 | 161 | nowname = now + name + opt.postfix 162 | 163 | include_ema = False 164 | 165 | if opt.base == []: 166 | opt.base = ["configs/texgen.yaml"] 167 | configs = [OmegaConf.load(cfg) for cfg in opt.base] 168 | cli = OmegaConf.from_dotlist(unknown) 169 | config = OmegaConf.merge(*configs, cli) 170 | logdir = Path(opt.logdir) / nowname 171 | config.params.exp_dir = str(logdir) 172 | if local_rank == 0: 173 | logdir.mkdir(parents=True, exist_ok=True) 174 | 175 | cfgdir = logdir / "configs" 176 | if local_rank == 0: 177 | cfgdir.mkdir(parents=True, exist_ok=True) 178 | OmegaConf.save(config, cfgdir / "config.yaml") 179 | 180 | meta = json.load(open(opt.meta, "r")) 181 | max_items = opt.max_items 182 | if max_items < 0: 183 | max_items = len(meta) 184 | 185 | if not include_ema: 186 | new_meta = [] 187 | for m in meta: 188 | if "ema" not in m["mesh"]: 189 | new_meta.append(m) 190 | meta = new_meta 191 | 192 | # meta = sorted(meta, key=lambda x: x["mesh"])[local_rank::world_size] 193 | meta = meta[local_rank::world_size] 194 | 195 | set_if_none( 196 | 
config.params.multiview_generator.params, "ckpt_path", mv_generator_path 197 | ) 198 | set_if_none( 199 | config.params.texture_inpainter.params, 200 | "ckpt_path", 201 | texture_inpainter_path, 202 | ) 203 | set_if_none(config.params.pbr_decomposer.params, "ckpt_path", pbr_decomposer_path) 204 | 205 | print(f"Using config file: {opt.base}") 206 | print(f"Using meta file: {opt.meta}") 207 | print(f"Using multi-view generator: {mv_generator_path}") 208 | print(f"Using texture inpainter: {texture_inpainter_path}") 209 | print(f"Using PBR decomposer: {pbr_decomposer_path}") 210 | 211 | model = instantiate_from_config(config) 212 | for idx, m in enumerate(meta): 213 | if idx >= max_items: 214 | break 215 | try: 216 | model( 217 | m["mesh"], 218 | m["image"], 219 | verbose=True, 220 | front_view_only=False, 221 | skip_front_view=True, 222 | debug=True, 223 | ignore_alpha=not opt.no_ignore_alpha, 224 | ) 225 | except: 226 | raise 227 | print(f"Error on {m['mesh']}") 228 | pass 229 | 230 | 231 | if __name__ == "__main__": 232 | dist.init_process_group(backend="nccl") 233 | main() 234 | -------------------------------------------------------------------------------- /zero123pp/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers.utils.torch_utils import randn_tensor 3 | 4 | from meshgen.modules.mesh.textured_mesh import TexturedMesh 5 | from meshgen.utils.io import write_image 6 | from meshgen.utils.ops import latent_preview 7 | 8 | """ 9 | 10 | Customized Step Function 11 | step on texture 12 | texture 13 | 14 | """ 15 | 16 | 17 | @torch.no_grad() 18 | def step_tex( 19 | scheduler, 20 | mesh: TexturedMesh, 21 | model_output: torch.FloatTensor, 22 | timestep: int, 23 | sample: torch.FloatTensor, 24 | texture: None, 25 | generator=None, 26 | return_dict: bool = True, 27 | guidance_scale=1, 28 | main_views=[], 29 | hires_original_views=True, 30 | exp=None, 31 | cos_weighted=True, 32 | radius=4.5, 33 | background=None, 34 | fusion_method="l6", 35 | interpolation_mode="nearest", 36 | render_reso=(40, 40), 37 | ): 38 | t = timestep 39 | # print( 40 | # t.item(), 41 | # sample.shape, 42 | # model_output.shape, 43 | # interpolation_mode, 44 | # render_reso, 45 | # fusion_method, 46 | # ) 47 | 48 | prev_t = scheduler.previous_timestep(t) 49 | 50 | if model_output.shape[1] == sample.shape[1] * 2 and scheduler.variance_type in [ 51 | "learned", 52 | "learned_range", 53 | ]: 54 | model_output, predicted_variance = torch.split( 55 | model_output, sample.shape[1], dim=1 56 | ) 57 | else: 58 | predicted_variance = None 59 | 60 | # 1. compute alphas, betas 61 | alpha_prod_t = scheduler.alphas_cumprod[t] 62 | alpha_prod_t_prev = ( 63 | scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else scheduler.one 64 | ) 65 | beta_prod_t = 1 - alpha_prod_t 66 | beta_prod_t_prev = 1 - alpha_prod_t_prev 67 | current_alpha_t = alpha_prod_t / alpha_prod_t_prev 68 | current_beta_t = 1 - current_alpha_t 69 | 70 | # 2. 
compute predicted original sample from predicted noise also called 71 | # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf 72 | if scheduler.config.prediction_type == "epsilon": 73 | pred_original_sample = ( 74 | sample - beta_prod_t ** (0.5) * model_output 75 | ) / alpha_prod_t ** (0.5) 76 | elif scheduler.config.prediction_type == "sample": 77 | pred_original_sample = model_output 78 | elif scheduler.config.prediction_type == "v_prediction": 79 | pred_original_sample = (alpha_prod_t**0.5) * sample - ( 80 | beta_prod_t**0.5 81 | ) * model_output 82 | else: 83 | raise ValueError( 84 | f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample` or" 85 | " `v_prediction` for the DDPMScheduler." 86 | ) 87 | 88 | # 3. Clip or threshold "predicted x_0" 89 | if scheduler.config.thresholding: 90 | pred_original_sample = scheduler._threshold_sample(pred_original_sample) 91 | elif scheduler.config.clip_sample: 92 | pred_original_sample = pred_original_sample.clamp( 93 | -scheduler.config.clip_sample_range, scheduler.config.clip_sample_range 94 | ) 95 | 96 | # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t 97 | # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf 98 | pred_original_sample_coeff = ( 99 | alpha_prod_t_prev ** (0.5) * current_beta_t 100 | ) / beta_prod_t 101 | current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t 102 | 103 | """ 104 | Add multidiffusion here 105 | """ 106 | 107 | if texture is None: 108 | # will not arrive 109 | sample_views = [view for view in sample] 110 | sample_views, texture, _ = uvp.bake_texture( 111 | views=sample_views, main_views=main_views, exp=exp 112 | ) 113 | sample_views = torch.stack(sample_views, axis=0) 114 | 115 | # original_views = [view for view in pred_original_sample] 116 | # original_views, original_tex, visibility_weights = mesh.bake_texture( 117 | # views=original_views, main_views=main_views, exp=exp 118 | # ) 119 | mesh.backproject_zero123pp_6views_latents( 120 | pred_original_sample, radius, fusion_method=fusion_method 121 | ) 122 | original_views, _ = mesh.render_zero123pp_6views_latents( 123 | radius, reso=render_reso, interpolation_mode=interpolation_mode 124 | ) 125 | preview = latent_preview(original_views[None])[0] 126 | # write_image(f"trash/test_mesh_denoising/preview_{t}.png", preview) 127 | 128 | # uvp.set_texture_map(original_tex) 129 | # original_views = uvp.render_textured_views() 130 | # original_views = torch.stack(original_views, axis=0) 131 | 132 | original_tex = torch.clone(mesh.latent_img.data) 133 | tex_preview = latent_preview(original_tex)[0] 134 | # write_image(f"trash/test_mesh_denoising/tex_preview_{t}.png", tex_preview) 135 | 136 | # 5. Compute predicted previous sample µ_t 137 | # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf 138 | # pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample 139 | prev_tex = ( 140 | pred_original_sample_coeff * original_tex + current_sample_coeff * texture 141 | ) 142 | 143 | # 6. Add noise 144 | variance = 0 145 | 146 | if predicted_variance is not None: 147 | variance_views = [view for view in predicted_variance] 148 | variance_views, variance_tex, visibility_weights = uvp.bake_texture( 149 | views=variance_views, 150 | main_views=main_views, 151 | cos_weighted=cos_weighted, 152 | exp=exp, 153 | ) 154 | variance_views = torch.stack(variance_views, axis=0)[:, :-1, ...] 
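        # NOTE: `uvp` is not imported or defined anywhere in this module, so this
        # learned-variance branch (like the `texture is None` branch above, which
        # the author marks with "will not arrive") appears to be unused leftover
        # from an earlier view-baking implementation rather than a live code path.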
155 | else: 156 | variance_tex = None 157 | 158 | if t > 0: 159 | device = texture.device 160 | variance_noise = randn_tensor( 161 | texture.shape, generator=generator, device=device, dtype=texture.dtype 162 | ) 163 | if scheduler.variance_type == "fixed_small_log": 164 | variance = ( 165 | scheduler._get_variance(t, predicted_variance=variance_tex) 166 | * variance_noise 167 | ) 168 | elif scheduler.variance_type == "learned_range": 169 | variance = scheduler._get_variance(t, predicted_variance=variance_tex) 170 | variance = torch.exp(0.5 * variance) * variance_noise 171 | else: 172 | variance = ( 173 | scheduler._get_variance(t, predicted_variance=variance_tex) ** 0.5 174 | ) * variance_noise 175 | 176 | prev_tex = prev_tex + variance 177 | 178 | # uvp.set_texture_map(prev_tex) 179 | # prev_views = uvp.render_textured_views() 180 | mesh.latent_img.data = prev_tex 181 | prev_views, mask = mesh.render_zero123pp_6views_latents( 182 | radius, reso=render_reso, interpolation_mode=interpolation_mode 183 | ) 184 | pred_prev_sample = torch.clone(sample) 185 | pred_prev_sample = prev_views[None] 186 | 187 | if background is not None: 188 | mask = mask[None] 189 | if t > 0: 190 | alphas_cumprod = scheduler.alphas_cumprod[t] 191 | noise = torch.normal(0, 1, background.shape, device=background.device) 192 | background = (1 - alphas_cumprod) * noise + alphas_cumprod * background 193 | pred_prev_sample = pred_prev_sample * mask + background * (1 - mask) 194 | 195 | # for i, view in enumerate(prev_views): 196 | # pred_prev_sample[i] = view 197 | # masks = [view[-1:] for view in prev_views] 198 | 199 | return { 200 | "prev_sample": pred_prev_sample, 201 | "pred_original_sample": pred_original_sample, 202 | "prev_tex": prev_tex, 203 | } 204 | 205 | if not return_dict: 206 | return pred_prev_sample, pred_original_sample 207 | pass 208 | -------------------------------------------------------------------------------- /meshgen/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from meshgen.utils.ops import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return {el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = ( 53 | nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) 54 | if not glu 55 | else GEGLU(dim, inner_dim) 56 | ) 57 | 58 | self.net = nn.Sequential( 59 | project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) 60 | ) 61 | 62 | def forward(self, x): 63 | return self.net(x) 64 | 65 | 66 | def 
zero_module(module): 67 | """ 68 | Zero out the parameters of a module and return it. 69 | """ 70 | for p in module.parameters(): 71 | p.detach().zero_() 72 | return module 73 | 74 | 75 | def Normalize(in_channels): 76 | return torch.nn.GroupNorm( 77 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True 78 | ) 79 | 80 | 81 | class LinearAttention(nn.Module): 82 | def __init__(self, dim, heads=4, dim_head=32): 83 | super().__init__() 84 | self.heads = heads 85 | hidden_dim = dim_head * heads 86 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 87 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 88 | 89 | def forward(self, x): 90 | b, c, h, w = x.shape 91 | qkv = self.to_qkv(x) 92 | q, k, v = rearrange( 93 | qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 94 | ) 95 | k = k.softmax(dim=-1) 96 | context = torch.einsum("bhdn,bhen->bhde", k, v) 97 | out = torch.einsum("bhde,bhdn->bhen", context, q) 98 | out = rearrange( 99 | out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w 100 | ) 101 | return self.to_out(out) 102 | 103 | 104 | class SpatialSelfAttention(nn.Module): 105 | def __init__(self, in_channels): 106 | super().__init__() 107 | self.in_channels = in_channels 108 | 109 | self.norm = Normalize(in_channels) 110 | self.q = torch.nn.Conv2d( 111 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 112 | ) 113 | self.k = torch.nn.Conv2d( 114 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 115 | ) 116 | self.v = torch.nn.Conv2d( 117 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 118 | ) 119 | self.proj_out = torch.nn.Conv2d( 120 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 121 | ) 122 | 123 | def forward(self, x): 124 | h_ = x 125 | h_ = self.norm(h_) 126 | q = self.q(h_) 127 | k = self.k(h_) 128 | v = self.v(h_) 129 | 130 | # compute attention 131 | b, c, h, w = q.shape 132 | q = rearrange(q, "b c h w -> b (h w) c") 133 | k = rearrange(k, "b c h w -> b c (h w)") 134 | w_ = torch.einsum("bij,bjk->bik", q, k) 135 | 136 | w_ = w_ * (int(c) ** (-0.5)) 137 | w_ = torch.nn.functional.softmax(w_, dim=2) 138 | 139 | # attend to values 140 | v = rearrange(v, "b c h w -> b c (h w)") 141 | w_ = rearrange(w_, "b i j -> b j i") 142 | h_ = torch.einsum("bij,bjk->bik", v, w_) 143 | h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) 144 | h_ = self.proj_out(h_) 145 | 146 | return x + h_ 147 | 148 | 149 | class CrossAttention(nn.Module): 150 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): 151 | super().__init__() 152 | inner_dim = dim_head * heads 153 | context_dim = default(context_dim, query_dim) 154 | 155 | self.scale = dim_head**-0.5 156 | self.heads = heads 157 | 158 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 159 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 160 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 161 | 162 | self.to_out = nn.Sequential( 163 | nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) 164 | ) 165 | 166 | def forward(self, x, context=None, mask=None): 167 | h = self.heads 168 | 169 | q = self.to_q(x) 170 | context = default(context, x) 171 | k = self.to_k(context) 172 | v = self.to_v(context) 173 | 174 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) 175 | 176 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale 177 | 178 | if exists(mask): 179 | mask = rearrange(mask, "b ... 
-> b (...)") 180 | max_neg_value = -torch.finfo(sim.dtype).max 181 | mask = repeat(mask, "b j -> (b h) () j", h=h) 182 | sim.masked_fill_(~mask, max_neg_value) 183 | 184 | # attention, what we cannot get enough of 185 | attn = sim.softmax(dim=-1) 186 | 187 | out = einsum("b i j, b j d -> b i d", attn, v) 188 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) 189 | return self.to_out(out) 190 | 191 | 192 | class BasicTransformerBlock(nn.Module): 193 | def __init__( 194 | self, 195 | dim, 196 | n_heads, 197 | d_head, 198 | dropout=0.0, 199 | context_dim=None, 200 | gated_ff=True, 201 | checkpoint=True, 202 | ): 203 | super().__init__() 204 | self.attn1 = CrossAttention( 205 | query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout 206 | ) # is a self-attention 207 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 208 | self.attn2 = CrossAttention( 209 | query_dim=dim, 210 | context_dim=context_dim, 211 | heads=n_heads, 212 | dim_head=d_head, 213 | dropout=dropout, 214 | ) # is self-attn if context is none 215 | self.norm1 = nn.LayerNorm(dim) 216 | self.norm2 = nn.LayerNorm(dim) 217 | self.norm3 = nn.LayerNorm(dim) 218 | self.checkpoint = checkpoint 219 | 220 | def forward(self, x, context=None): 221 | return checkpoint( 222 | self._forward, (x, context), self.parameters(), self.checkpoint 223 | ) 224 | 225 | def _forward(self, x, context=None): 226 | x = self.attn1(self.norm1(x)) + x 227 | x = self.attn2(self.norm2(x), context=context) + x 228 | x = self.ff(self.norm3(x)) + x 229 | return x 230 | 231 | 232 | class SpatialTransformer(nn.Module): 233 | """ 234 | Transformer block for image-like data. 235 | First, project the input (aka embedding) 236 | and reshape to b, t, d. 237 | Then apply standard transformer action. 238 | Finally, reshape to image 239 | """ 240 | 241 | def __init__( 242 | self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None 243 | ): 244 | super().__init__() 245 | self.in_channels = in_channels 246 | inner_dim = n_heads * d_head 247 | self.norm = Normalize(in_channels) 248 | 249 | self.proj_in = nn.Conv2d( 250 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0 251 | ) 252 | 253 | self.transformer_blocks = nn.ModuleList( 254 | [ 255 | BasicTransformerBlock( 256 | inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim 257 | ) 258 | for d in range(depth) 259 | ] 260 | ) 261 | 262 | self.proj_out = zero_module( 263 | nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 264 | ) 265 | 266 | def forward(self, x, context=None): 267 | # note: if no context is given, cross-attention defaults to self-attention 268 | b, c, h, w = x.shape 269 | x_in = x 270 | x = self.norm(x) 271 | x = self.proj_in(x) 272 | x = rearrange(x, "b c h w -> b (h w) c") 273 | for block in self.transformer_blocks: 274 | x = block(x, context=context) 275 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 276 | x = self.proj_out(x) 277 | return x + x_in 278 | -------------------------------------------------------------------------------- /meshgen/modules/mesh/mesh.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import json 4 | import torch 5 | import trimesh 6 | import kaolin as kal 7 | from loguru import logger 8 | import numpy as np 9 | from trimesh.transformations import rotation_matrix 10 | import torch.nn.functional as F 11 | 12 | from meshgen.utils.ops import mesh_simplification 13 | 14 | 15 | def dot(x, y, dim=-1): 16 | return torch.sum(x * y, dim, 
keepdim=True) 17 | 18 | 19 | class Mesh: 20 | def __init__( 21 | self, 22 | mesh_path, 23 | device, 24 | target_scale=1.0, 25 | mesh_dy=0.0, 26 | remove_mesh_part_names=None, 27 | remove_unsupported_buffers=None, 28 | intermediate_dir=None, 29 | rotate=False, 30 | simplify_if_necessary=True, 31 | ): 32 | # from https://github.com/threedle/text2mesh 33 | self.material_cvt, self.material_num, org_mesh_path, is_convert = ( 34 | None, 35 | 1, 36 | mesh_path, 37 | False, 38 | ) 39 | # if not mesh_path.endswith(".obj") and not mesh_path.endswith(".off"): 40 | # if mesh_path.endswith(".gltf"): 41 | # mesh_path = self.preprocess_gltf( 42 | # mesh_path, remove_mesh_part_names, remove_unsupported_buffers 43 | # ) 44 | # mesh_temp = trimesh.load( 45 | # mesh_path, force="mesh", process=True, maintain_order=True 46 | # ) 47 | # mesh_path = os.path.splitext(mesh_path)[0] + "_cvt.obj" 48 | # mesh_temp.export(mesh_path) 49 | # merge_texture_path = os.path.join( 50 | # os.path.dirname(mesh_path), "material_0.png" 51 | # ) 52 | # if os.path.exists(merge_texture_path): 53 | # self.material_cvt = cv2.imread(merge_texture_path) 54 | # self.material_num = ( 55 | # self.material_cvt.shape[1] // self.material_cvt.shape[0] 56 | # ) 57 | # logger.info( 58 | # "Converting current mesh model to obj file with {} material~".format( 59 | # self.material_num 60 | # ) 61 | # ) 62 | # is_convert = True 63 | 64 | if ".obj" in mesh_path: 65 | try: 66 | mesh = kal.io.obj.import_mesh( 67 | mesh_path, 68 | with_normals=True, 69 | with_materials=True, 70 | heterogeneous_mesh_handler=kal.io.utils.mesh_handler_naive_triangulate, 71 | ) 72 | except: 73 | mesh = kal.io.obj.import_mesh( 74 | mesh_path, 75 | with_normals=True, 76 | with_materials=False, 77 | heterogeneous_mesh_handler=kal.io.utils.mesh_handler_naive_triangulate, 78 | ) 79 | elif ".off" in mesh_path: 80 | mesh = kal.io.off.import_mesh(mesh_path) 81 | else: 82 | raise ValueError(f"{mesh_path} extension not implemented in mesh reader.") 83 | 84 | self.vertices = mesh.vertices.to(device) 85 | self.faces = mesh.faces.to(device) 86 | try: 87 | self.vt = mesh.uvs 88 | self.ft = mesh.face_uvs_idx 89 | except AttributeError: 90 | self.vt = None 91 | self.ft = None 92 | self.mesh_path = mesh_path 93 | self.normalize_mesh(target_scale=target_scale, mesh_dy=mesh_dy) 94 | 95 | if rotate: 96 | rotmat = rotation_matrix(-np.pi / 2, [1, 0, 0]) @ rotation_matrix( 97 | -np.pi / 2, [0, 0, 1] 98 | ) 99 | rotmat = torch.from_numpy(rotmat.astype(np.float32)).to(device)[:3, :3] 100 | self.vertices = self.vertices @ rotmat.T 101 | 102 | self.vn = self._compute_normals() 103 | 104 | if self.faces.shape[0] > 20000 and simplify_if_necessary: 105 | self.vertices, self.faces = mesh_simplification( 106 | self.vertices, self.faces, 20000 107 | ) 108 | self.vertices = self.vertices.to(torch.float32) 109 | self.faces = self.faces.to(torch.int64) 110 | self.vn = self._compute_normals() 111 | 112 | if is_convert and intermediate_dir is not None: 113 | if not os.path.exists(intermediate_dir): 114 | os.makedirs(intermediate_dir) 115 | if os.path.exists(os.path.splitext(org_mesh_path)[0] + "_removed.gltf"): 116 | os.system( 117 | "mv {} {}".format( 118 | os.path.splitext(org_mesh_path)[0] + "_removed.gltf", 119 | intermediate_dir, 120 | ) 121 | ) 122 | if mesh_path.endswith("_cvt.obj"): 123 | os.system("mv {} {}".format(mesh_path, intermediate_dir)) 124 | os.system( 125 | "mv {} {}".format( 126 | os.path.join(os.path.dirname(mesh_path), "material.mtl"), 127 | intermediate_dir, 128 | ) 129 | ) 130 | # if 
os.path.exists(merge_texture_path): 131 | # os.system( 132 | # "mv {} {}".format( 133 | # os.path.join(os.path.dirname(mesh_path), "material_0.png"), 134 | # intermediate_dir, 135 | # ) 136 | # ) 137 | 138 | def preprocess_gltf( 139 | self, mesh_path, remove_mesh_part_names, remove_unsupported_buffers 140 | ): 141 | with open(mesh_path, "r") as fr: 142 | gltf_json = json.load(fr) 143 | if remove_mesh_part_names is not None: 144 | temp_primitives = [] 145 | for primitive in gltf_json["meshes"][0]["primitives"]: 146 | if_append, material_id = True, primitive["material"] 147 | material_name = gltf_json["materials"][material_id]["name"] 148 | for remove_mesh_part_name in remove_mesh_part_names: 149 | if material_name.find(remove_mesh_part_name) >= 0: 150 | if_append = False 151 | break 152 | if if_append: 153 | temp_primitives.append(primitive) 154 | gltf_json["meshes"][0]["primitives"] = temp_primitives 155 | logger.info( 156 | "Deleting mesh with materials named '{}' from gltf model ~".format( 157 | remove_mesh_part_names 158 | ) 159 | ) 160 | 161 | if remove_unsupported_buffers is not None: 162 | temp_buffers = [] 163 | for buffer in gltf_json["buffers"]: 164 | if_append = True 165 | for unsupported_buffer in remove_unsupported_buffers: 166 | if buffer["uri"].find(unsupported_buffer) >= 0: 167 | if_append = False 168 | break 169 | if if_append: 170 | temp_buffers.append(buffer) 171 | gltf_json["buffers"] = temp_buffers 172 | logger.info( 173 | "Deleting unspported buffers within uri {} from gltf model ~".format( 174 | remove_unsupported_buffers 175 | ) 176 | ) 177 | updated_mesh_path = os.path.splitext(mesh_path)[0] + "_removed.gltf" 178 | with open(updated_mesh_path, "w") as fw: 179 | json.dump(gltf_json, fw, indent=4) 180 | return updated_mesh_path 181 | 182 | def normalize_mesh(self, target_scale=1.0, mesh_dy=0.0): 183 | # verts = self.vertices 184 | # center = verts.mean(dim=0) 185 | # verts = verts - center 186 | # scale = torch.max(torch.norm(verts, p=2, dim=1)) 187 | # verts = verts / scale 188 | # verts *= target_scale 189 | # verts[:, 1] += mesh_dy 190 | # self.vertices = verts 191 | vertices = self.vertices 192 | """shift and resize mesh to fit into a unit cube""" 193 | offset = (vertices.min(dim=0)[0] + vertices.max(dim=0)[0]) / 2 194 | vertices -= offset 195 | scale_factor = 2.0 / (vertices.max(dim=0)[0] - vertices.min(dim=0)[0]).max() 196 | vertices *= scale_factor 197 | 198 | self.vertices = vertices * target_scale 199 | 200 | def _compute_normals(self): 201 | i0 = self.faces[:, 0] 202 | i1 = self.faces[:, 1] 203 | i2 = self.faces[:, 2] 204 | 205 | v0 = self.vertices[i0, :] 206 | v1 = self.vertices[i1, :] 207 | v2 = self.vertices[i2, :] 208 | 209 | face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) 210 | 211 | # Splat face normals to vertices 212 | v_nrm = torch.zeros_like(self.vertices) 213 | v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals) 214 | v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals) 215 | v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals) 216 | 217 | # Normalize, replace zero (degenerated) normals with some default value 218 | v_nrm = torch.where( 219 | dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm) 220 | ) 221 | v_nrm = F.normalize(v_nrm, dim=1) 222 | 223 | if torch.is_anomaly_enabled(): 224 | assert torch.all(torch.isfinite(v_nrm)) 225 | 226 | return v_nrm 227 | -------------------------------------------------------------------------------- /jointgen.py: 
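# Usage sketch for the jointgen.py listing that follows (an illustration, not
# taken from the repository's own docs): the file defines a single
# `jointgen(...)` function and imports `tyro`, so it is presumably driven via
# `tyro.cli(jointgen)` and can also be called directly from Python:
#
#     from jointgen import jointgen
#     # `path/to/image.png` and `outputs/run1` are placeholder paths.
#     jointgen(images="path/to/image.png", output="outputs/run1", export=True)
#
# All other arguments keep the defaults visible in the signature below; a CUDA
# device is required since the script calls `torch.cuda.set_device`.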
-------------------------------------------------------------------------------- 1 | import datetime 2 | import numpy as np 3 | import torch 4 | import tqdm 5 | from pathlib import Path 6 | from omegaconf import OmegaConf 7 | import torch.distributed as dist 8 | from einops import repeat 9 | from PIL import Image 10 | import json 11 | 12 | from meshgen.util import instantiate_from_config 13 | from meshgen.utils.io import write_video, export_mesh, write_image 14 | from meshgen.utils.render import ( 15 | render_mesh_spiral_offscreen, 16 | ) 17 | from meshgen.utils.images import preprocess_image 18 | from meshgen.utils.remesh import auto_remesh, instantmesh_remesh, pyacvd_remesh 19 | 20 | import tyro 21 | 22 | import warnings 23 | 24 | warnings.filterwarnings("ignore") 25 | 26 | 27 | def jointgen( 28 | images: str, 29 | output: str, 30 | config: str = "configs/shapegen.yaml", 31 | ckpt: str = None, 32 | cfg: float = 3.0, 33 | n_steps: int = 50, 34 | n_samples: int = 1, 35 | rotate: bool = False, 36 | no_preprocess: bool = False, 37 | border_ratio: float = 0.35, 38 | remove_bg: bool = True, 39 | ignore_alpha: bool = False, 40 | alpha_matting: bool = False, 41 | export: bool = False, 42 | thresh: float = 0.5, 43 | R: int = 256, 44 | remesh: bool = False, 45 | snr_shifting: float = 1.0, 46 | seed: int = 1234, 47 | ema: bool = True, 48 | rembg_backend: str = "bria", 49 | use_diso: bool = False, 50 | bf16: bool = True, 51 | ): 52 | render_kwargs = { 53 | "num_frames": 90, 54 | "elevation": 0, 55 | "radius": 2.0, 56 | "rotate": rotate, 57 | # "color": np.array([20, 100, 246]) / 255, 58 | } 59 | 60 | if dist.is_initialized(): 61 | local_rank = dist.get_rank() 62 | world_size = dist.get_world_size() 63 | else: 64 | local_rank = 0 65 | world_size = 1 66 | 67 | torch.cuda.set_device(local_rank) 68 | device = "cuda" 69 | config_file = config 70 | print(f"Using config file: {config_file}") 71 | config = OmegaConf.load(config_file) 72 | if ckpt is None: 73 | config.params.ckpt_path = ckpt 74 | # config.model.params.force_reinit_ema = False 75 | # config.model.params.autoencoder.params.use_diso = use_diso 76 | model = instantiate_from_config(config).to(device) 77 | model.eval() 78 | 79 | cond_key = model.cond_key 80 | if cond_key == "images": 81 | images = Path(images) 82 | if not images.is_dir(): 83 | image_files = [images] 84 | else: 85 | image_files = ( 86 | list(images.glob("*.png")) 87 | + list(images.glob("*.jpg")) 88 | + list(images.glob("*.webp")) 89 | + list(images.glob("*.PNG")) 90 | + list(images.glob("*.JPG")) 91 | + list(images.glob("*.WEBP")) 92 | + list(images.glob("*.jpeg")) 93 | + list(images.glob("*.JPEG")) 94 | ) 95 | elif cond_key == "text": 96 | with open(images, "r") as f: 97 | image_files = f.read().strip().split("\n") 98 | else: 99 | raise ValueError(f"Unknown cond_key: {model.cond_key}") 100 | 101 | image_files = sorted(image_files)[local_rank::world_size] 102 | 103 | do_ema = ema and config.params.get("use_ema", False) 104 | do_refine = hasattr(model, "control_model") 105 | 106 | output_dir = Path(output) 107 | output_dir.mkdir(exist_ok=True, parents=True) 108 | if export: 109 | mesh_output_dir = output_dir / "meshes" 110 | mesh_output_dir.mkdir(exist_ok=True, parents=True) 111 | meta = [] 112 | 113 | timestep_callback = None 114 | if snr_shifting != 1.0: 115 | timestep_callback = lambda ts: [ 116 | t / (snr_shifting - snr_shifting * t + t) for t in ts 117 | ] 118 | 119 | # fmt: off 120 | # timestep_callback = lambda ts: [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 
0.765, 0.78, 0.795, 0.81, 0.825, 0.84, 0.855, 0.87, 0.885, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99, 1.0] 121 | # fmt: on 122 | 123 | @torch.no_grad() 124 | @torch.autocast("cuda", torch.bfloat16, enabled=bf16) 125 | def vis_one(image_file): 126 | if cond_key == "images": 127 | this_image = Image.open(image_file) 128 | 129 | if not no_preprocess: 130 | this_image = preprocess_image( 131 | this_image, 132 | size=512, 133 | border_ratio=border_ratio, 134 | remove_bg=remove_bg, 135 | ignore_alpha=ignore_alpha, 136 | alpha_matting=alpha_matting, 137 | backend=rembg_backend, 138 | ) 139 | 140 | cond = np.array(this_image) 141 | uncond = np.zeros_like(cond) 142 | save_name = image_file.stem 143 | else: 144 | cond = image_file 145 | uncond = "" 146 | save_name = image_file.replace(" ", "_")[:30] 147 | 148 | denoised = model.sample_one( 149 | cond, 150 | uncond, 151 | n_samples=n_samples, 152 | cfg=cfg, 153 | n_steps=n_steps, 154 | timestep_callback=timestep_callback, 155 | seed=seed, 156 | ) 157 | if do_refine: 158 | refined_denoised = model.sample_one( 159 | cond, 160 | uncond, 161 | coarse=denoised, 162 | n_samples=n_samples, 163 | cfg=cfg, 164 | n_steps=n_steps, 165 | timestep_callback=timestep_callback, 166 | seed=seed, 167 | ) 168 | for i in range(n_samples): 169 | v, f = model.decode_shape(denoised[i : i + 1], thresh=thresh, R=R) 170 | if do_refine: 171 | v_refined, f_refined = model.decode_shape( 172 | refined_denoised[i : i + 1], thresh=thresh, R=R 173 | ) 174 | if remesh: 175 | # v, f = instantmesh_remesh(v, f, target_num_faces=5000) 176 | v, f = pyacvd_remesh(v, f, target_num_faces=20000) 177 | v = np.array(v) 178 | f = np.array(f) 179 | if do_refine: 180 | v_ref_refinedined, f_refined = instantmesh_remesh( 181 | v_ref_refinedined, f_refined, target_num_f_refinedaces=3000 182 | ) 183 | v_ref_refinedined = np.array(v_ref_refinedined) 184 | f_refined = np.array(f_refined) 185 | frames = render_mesh_spiral_offscreen(v, f, **render_kwargs) 186 | if do_refine: 187 | refined_frames = render_mesh_spiral_offscreen( 188 | v_refined, f_refined, **render_kwargs 189 | ) 190 | 191 | if cond_key == "images": 192 | this_image_float = cond 193 | cond_frames = repeat( 194 | this_image_float, "h w c -> t h w c", t=frames.shape[0] 195 | ) 196 | frames = np.concatenate([cond_frames, frames], axis=-2) 197 | if do_refine: 198 | frames = np.concatenate([frames, refined_frames], axis=-2) 199 | 200 | write_video(output_dir / f"{save_name}_sample_{i}.mp4", frames) 201 | if export: 202 | export_mesh(v, f, mesh_output_dir / f"{save_name}_sample_{i}.obj") 203 | this_meta = { 204 | "mesh": (mesh_output_dir / f"{save_name}_sample_{i}.obj") 205 | .absolute() 206 | .as_posix(), 207 | "image": image_file.absolute().as_posix(), 208 | } 209 | meta.append(this_meta) 210 | if do_refine: 211 | export_mesh( 212 | v_refined, 213 | f_refined, 214 | mesh_output_dir / f"{save_name}_sample_{i}_refined.obj", 215 | ) 216 | this_meta = { 217 | "mesh": ( 218 | mesh_output_dir / f"{save_name}_sample_{i}_refined.obj" 219 | ) 220 | .absolute() 221 | .as_posix(), 222 | "image": image_file.absolute().as_posix(), 223 | } 224 | meta.append(this_meta) 225 | 226 | if do_ema: 227 | with model.ema_scope(): 228 | denoised = model.sample_one( 229 | cond, 230 | uncond, 231 | n_samples=n_samples, 232 | cfg=cfg, 233 | n_steps=n_steps, 234 | timestep_callback=timestep_callback, 235 | seed=seed, 236 | ) 237 | for i in range(n_samples): 238 | v, f = model.decode_shape(denoised[i : i + 1], thresh=thresh, R=R) 239 | if remesh: 240 | v, f 
= instantmesh_remesh(v, f, target_num_faces=3000) 241 | v = np.array(v) 242 | f = np.array(f) 243 | frames = render_mesh_spiral_offscreen(v, f, **render_kwargs) 244 | 245 | if cond_key == "images": 246 | this_image_float = cond 247 | cond_frames = repeat( 248 | this_image_float, "h w c -> t h w c", t=frames.shape[0] 249 | ) 250 | frames = np.concatenate([cond_frames, frames], axis=-2) 251 | 252 | write_video(output_dir / f"{save_name}_sample_ema_{i}.mp4", frames) 253 | if export: 254 | export_mesh( 255 | v, f, mesh_output_dir / f"{save_name}_sample_ema_{i}.obj" 256 | ) 257 | this_meta = { 258 | "mesh": (mesh_output_dir / f"{save_name}_sample_ema_{i}.obj") 259 | .absolute() 260 | .as_posix(), 261 | "image": image_file.absolute().as_posix(), 262 | } 263 | meta.append(this_meta) 264 | 265 | for img in tqdm.tqdm(image_files, disable=local_rank): 266 | vis_one(img) 267 | 268 | if export: 269 | meta_gathered = [None for _ in range(world_size)] 270 | dist.barrier() 271 | dist.all_gather_object(meta_gathered, meta) 272 | 273 | if local_rank == 0: 274 | merged = [] 275 | for d in meta_gathered: 276 | merged += d 277 | 278 | json.dump(merged, open(output_dir / "meta.json", "w"), indent=4) 279 | 280 | 281 | if __name__ == "__main__": 282 | dist.init_process_group("nccl") 283 | tyro.cli(jointgen) 284 | -------------------------------------------------------------------------------- /meshgen/model/diffusion/rfunet.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | from PIL import Image 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import pytorch_lightning as pl 8 | import torch.nn.functional as F 9 | from pathlib import Path 10 | import tqdm 11 | from copy import deepcopy 12 | from einops import rearrange, repeat 13 | from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity 14 | from torchvision.transforms import v2 15 | from contextlib import contextmanager 16 | from torchvision.utils import make_grid, save_image 17 | import lpips 18 | 19 | from meshgen.util import instantiate_from_config 20 | from meshgen.model.base import BaseModel 21 | from meshgen.utils.ema import LitEma 22 | from meshgen.utils.io import load_mesh, write_video 23 | from meshgen.utils.render import render_mesh_spiral_offscreen 24 | from meshgen.utils.ops import logit_normal 25 | 26 | 27 | def disabled_train(self, mode=True): 28 | """Overwrite model.train with this function to make sure train/eval mode 29 | does not change anymore.""" 30 | return self 31 | 32 | 33 | class RectifiedFlowUNet2D(BaseModel): 34 | def __init__( 35 | self, 36 | autoencoder, 37 | unet, 38 | cond_encoder, 39 | input_key="surface", 40 | cond_key="images", 41 | shift_factor=0.0, 42 | scale_factor=0.25, 43 | weight_decay=0.0, 44 | ckpt_path=None, 45 | ignore_keys=[], 46 | scheduler_config=None, 47 | sample_kwargs={}, 48 | render_kwargs={}, 49 | vis_every=None, 50 | rf_mu=1.0986, 51 | rf_sigma=1.0, 52 | timestep_sample="uniform", 53 | rescale_image=False, 54 | use_ema=False, 55 | skip_validation=False, 56 | force_reinit_ema=False, 57 | _no_load_ckpt=False, 58 | *args, 59 | **kwargs, 60 | ): 61 | # TODO: add ema model 62 | super().__init__(*args, **kwargs) 63 | 64 | self.input_key = input_key 65 | self.cond_key = cond_key 66 | 67 | self.autoencoder = instantiate_from_config(autoencoder) 68 | self.autoencoder = self.autoencoder.eval() 69 | self.autoencoder.train = disabled_train 70 | for n, p in self.autoencoder.named_parameters(): 71 | p.requires_grad = 
False 72 | 73 | self.unet = instantiate_from_config(unet) 74 | self.cond_encoder = instantiate_from_config(cond_encoder) 75 | 76 | self.force_reinit_ema = force_reinit_ema 77 | self.use_ema = use_ema 78 | if self.use_ema: 79 | self.model_ema = LitEma(self.unet) 80 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") 81 | 82 | if ckpt_path is not None and not _no_load_ckpt: 83 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 84 | 85 | self.skip_validation = skip_validation 86 | self.shift_factor = shift_factor 87 | self.scale_factor = scale_factor 88 | self.scheduler_config = scheduler_config 89 | self.use_scheduler = scheduler_config is not None 90 | self.weight_decay = weight_decay 91 | 92 | self.sample_kwargs = sample_kwargs 93 | self.render_kwargs = render_kwargs 94 | self.vis_every = vis_every 95 | 96 | self.rf_mu = rf_mu 97 | self.rf_sigma = rf_sigma 98 | self.timestep_sample = timestep_sample 99 | 100 | self.latent_shape = ( 101 | self.autoencoder.triplane_ch, 102 | self.autoencoder.triplane_res * 3, 103 | self.autoencoder.triplane_res, 104 | ) 105 | 106 | self.rescale_image = rescale_image 107 | 108 | self.train() 109 | 110 | def on_train_batch_end(self, *args, **kwargs): 111 | if self.use_ema: 112 | self.model_ema(self.unet) 113 | 114 | def on_load_checkpoint(self, checkpoint): 115 | if self.use_ema: 116 | contain_ema = False 117 | for k in checkpoint["state_dict"]: 118 | if "model_ema" in k: 119 | contain_ema = True 120 | break 121 | if not contain_ema or self.force_reinit_ema: 122 | ema_sd = {} 123 | for k, v in self.unet.state_dict().items(): 124 | ema_sd[f"model_ema.{k.replace('.', '')}"] = v 125 | ema_sd["model_ema.num_updates"] = torch.tensor(0, dtype=torch.int) 126 | ema_sd["model_ema.decay"] = torch.tensor(0.9999, dtype=torch.float32) 127 | checkpoint["state_dict"].update(ema_sd) 128 | 129 | @contextmanager 130 | def ema_scope(self, context=None): 131 | if self.use_ema: 132 | self.model_ema.store(self.unet.parameters()) 133 | self.model_ema.copy_to(self.unet) 134 | if context is not None: 135 | print(f"{context}: Switched to EMA weights") 136 | try: 137 | yield None 138 | finally: 139 | if self.use_ema: 140 | self.model_ema.restore(self.unet.parameters()) 141 | if context is not None: 142 | print(f"{context}: Restored training weights") 143 | 144 | @torch.no_grad() 145 | def encode_shape(self, pcd): 146 | z = self.autoencoder.encode(pcd) 147 | z = rearrange(z, "b n c h w -> b c (n h) w") 148 | 149 | return (z + self.shift_factor) * self.scale_factor 150 | 151 | @torch.no_grad() 152 | def decode_shape(self, z, **kwargs): 153 | z = z / self.scale_factor - self.shift_factor 154 | return self.autoencoder.decode_shape(z, **kwargs) 155 | 156 | def decode(self, z, queries=None): 157 | z = z / self.scale_factor - self.shift_factor 158 | return self.autoencoder.decode(z, queries, upsample=True) 159 | 160 | @torch.no_grad() 161 | def encode_cond(self, cond): 162 | return self.cond_encoder(cond) 163 | 164 | def forward(self, input, cond): 165 | X0 = self.encode_shape(input) 166 | cond_emb = self.encode_cond(cond) 167 | 168 | # t_input = logit_normal( 169 | # self.rf_mu, self.rf_sigma, (bs,), device=self.device, dtype=self.dtype 170 | # ) 171 | bs = X0.shape[0] 172 | if self.timestep_sample == "uniform": 173 | t_input = torch.rand((bs,), device=self.device, dtype=self.dtype) 174 | else: 175 | t_input = logit_normal( 176 | self.rf_mu, self.rf_sigma, (bs,), device=self.device, dtype=self.dtype 177 | ) 178 | 179 | t = t_input.view(bs, *((1,) * (len(X0.shape) - 1))) 
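        # Descriptive note on the rectified-flow step implemented just below: X0 is the
        # encoded shape latent and X1 is Gaussian noise; Xt = X1 * t + X0 * (1 - t)
        # interpolates linearly between them, and the UNet is trained to regress the
        # constant velocity (X1 - X0). At sampling time, sample_one() integrates this
        # velocity field with Euler steps, x <- x + pred * (s - t), from t = 1 down to
        # t = 0, using classifier-free guidance pred = cond_pred + (cond_pred - uncond_pred) * cfg.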
180 | 181 | X1 = torch.randn_like(X0) 182 | Xt = X1 * t + X0 * (1 - t) 183 | 184 | pred = self.unet(Xt, cond_emb, t_input) 185 | 186 | loss = F.mse_loss((X1 - X0), pred) 187 | 188 | return loss 189 | 190 | def shared_step(self, batch, batch_idx): 191 | input = batch[self.input_key] 192 | cond = batch[self.cond_key] 193 | 194 | loss = self(input, cond) 195 | loss_dict = {"loss": loss} 196 | 197 | return loss, loss_dict 198 | 199 | @torch.no_grad() 200 | def sample_one( 201 | self, 202 | cond, 203 | uncond, 204 | cfg, 205 | n_steps, 206 | n_samples=1, 207 | seed=1234, 208 | x_init=None, 209 | timestep_callback=None, 210 | ): 211 | generator = torch.Generator(self.device).manual_seed(seed + self.local_rank) 212 | if x_init is None: 213 | x_init = torch.randn( 214 | (n_samples,) + self.latent_shape, 215 | device=self.device, 216 | generator=generator, 217 | dtype=self.dtype, 218 | ) 219 | ts = [i / n_steps for i in range(n_steps + 1)] 220 | if timestep_callback is not None: 221 | print("Using timestep callback") 222 | # ts = [timestep_callback(t) for t in ts] 223 | ts = timestep_callback(ts) 224 | 225 | cond_emb = self.encode_cond(cond) 226 | uncond_emb = self.encode_cond(uncond) 227 | cond_emb = repeat(cond_emb, "1 ... -> n ...", n=n_samples) 228 | uncond_emb = repeat(uncond_emb, "1 ... -> n ...", n=n_samples) 229 | 230 | x = x_init 231 | for s, t in tqdm.tqdm(list(zip(ts, ts[1:]))[::-1], disable=True): 232 | # pred = nnet(x, t=torch.full((x.size(0),), t).to(x)) 233 | this_t = torch.full((x.size(0),), t).to(x) 234 | 235 | cond_pred = self.unet(x, cond_emb, this_t) 236 | uncond_pred = self.unet(x, uncond_emb, this_t) 237 | pred = cond_pred + (cond_pred - uncond_pred) * cfg 238 | 239 | x = x + pred * (s - t) 240 | 241 | return x 242 | 243 | @torch.no_grad() 244 | def test_in_the_wild(self, save_dir): 245 | if self.skip_validation: 246 | return 247 | torch.cuda.empty_cache() 248 | gc.collect() 249 | if self.cond_key == "images": 250 | image_dir = Path("./data/images/GSO") 251 | images = list(image_dir.glob("*.png")) + list(image_dir.glob("*.jpg")) 252 | images = sorted(images)[self.local_rank :: self.trainer.world_size] 253 | 254 | for image in images: 255 | cond = np.array(Image.open(image)) 256 | if self.rescale_image: 257 | cond = cond.astype(np.float32) / 255.0 258 | uncond = np.zeros_like(cond) 259 | 260 | denosied = self.sample_one( 261 | cond, uncond, n_samples=1, **self.sample_kwargs 262 | ) 263 | try: 264 | v, f = self.decode_shape(denosied) 265 | frames = render_mesh_spiral_offscreen(v, f, **self.render_kwargs) 266 | 267 | write_video(save_dir / f"{image.stem}.mp4", frames) 268 | except IndexError: 269 | pass 270 | elif self.cond_key == "text": 271 | with open("data/texts/benchmark_captions.txt") as f: 272 | text_prompts = f.read().strip().split("\n") 273 | text_prompts = text_prompts[:32] 274 | text_prompts = sorted(text_prompts)[ 275 | self.local_rank :: self.trainer.world_size 276 | ] 277 | for text in text_prompts: 278 | denosied = self.sample_one(text, "", n_samples=1, **self.sample_kwargs) 279 | try: 280 | v, f = self.decode_shape(denosied) 281 | frames = render_mesh_spiral_offscreen(v, f, **self.render_kwargs) 282 | 283 | caption = text.replace(" ", "_")[:30] 284 | write_video(save_dir / f"{caption}.mp4", frames) 285 | except IndexError: 286 | pass 287 | else: 288 | raise NotImplementedError 289 | 290 | def get_optim_groups(self): 291 | return self.unet.parameters() 292 | -------------------------------------------------------------------------------- /shapegen.py: 
-------------------------------------------------------------------------------- 1 | import datetime 2 | import numpy as np 3 | import torch 4 | import tqdm 5 | from pathlib import Path 6 | from omegaconf import OmegaConf 7 | import torch.distributed as dist 8 | from einops import repeat 9 | from PIL import Image 10 | import json 11 | 12 | from meshgen.util import instantiate_from_config 13 | from meshgen.utils.io import write_video, export_mesh, write_image 14 | from meshgen.utils.render import ( 15 | render_mesh_spiral_offscreen, 16 | ) 17 | from meshgen.utils.images import preprocess_image 18 | from meshgen.utils.remesh import auto_remesh, instantmesh_remesh, pyacvd_remesh 19 | from meshgen.utils.hf_weights import shape_generator_path 20 | 21 | import tyro 22 | 23 | import warnings 24 | 25 | warnings.filterwarnings("ignore") 26 | 27 | 28 | def vis_diffusion( 29 | images: str, 30 | output: str, 31 | mesh_output: str = None, 32 | config: str = "configs/shapegen.yaml", 33 | ckpt: str = None, 34 | cfg: float = 3.0, 35 | n_steps: int = 50, 36 | n_samples: int = 1, 37 | rotate: bool = False, 38 | no_preprocess: bool = False, 39 | border_ratio: float = 0.35, 40 | remove_bg: bool = True, 41 | ignore_alpha: bool = False, 42 | alpha_matting: bool = False, 43 | export: bool = True, 44 | thresh: float = 0.5, 45 | R: int = 256, 46 | remesh: bool = False, 47 | snr_shifting: float = 1.0, 48 | seed: int = 1234, 49 | ema: bool = True, 50 | rembg_backend: str = "birefnet", 51 | use_diso: bool = False, 52 | bf16: bool = True, 53 | ): 54 | render_kwargs = { 55 | "num_frames": 90, 56 | "elevation": 0, 57 | "radius": 2.0, 58 | "rotate": rotate, 59 | # "color": np.array([20, 100, 246]) / 255, 60 | } 61 | 62 | if dist.is_initialized(): 63 | local_rank = dist.get_rank() 64 | world_size = dist.get_world_size() 65 | else: 66 | local_rank = 0 67 | world_size = 1 68 | 69 | torch.cuda.set_device(local_rank) 70 | device = "cuda" 71 | config_file = config 72 | print(f"Using config file: {config_file}") 73 | config = OmegaConf.load(config_file) 74 | if ckpt is None: 75 | ckpt = shape_generator_path 76 | config.params.ckpt_path = ckpt 77 | # config.model.params.force_reinit_ema = False 78 | # config.model.params.autoencoder.params.use_diso = use_diso 79 | model = instantiate_from_config(config).to(device) 80 | model.eval() 81 | 82 | cond_key = model.cond_key 83 | if cond_key == "images": 84 | images = Path(images) 85 | if not images.is_dir(): 86 | image_files = [images] 87 | else: 88 | image_files = ( 89 | list(images.glob("*.png")) 90 | + list(images.glob("*.jpg")) 91 | + list(images.glob("*.webp")) 92 | + list(images.glob("*.PNG")) 93 | + list(images.glob("*.JPG")) 94 | + list(images.glob("*.WEBP")) 95 | + list(images.glob("*.jpeg")) 96 | + list(images.glob("*.JPEG")) 97 | ) 98 | elif cond_key == "text": 99 | with open(images, "r") as f: 100 | image_files = f.read().strip().split("\n") 101 | else: 102 | raise ValueError(f"Unknown cond_key: {model.cond_key}") 103 | 104 | image_files = sorted(image_files)[local_rank::world_size] 105 | 106 | do_ema = ema and config.params.get("use_ema", False) 107 | do_refine = hasattr(model, "control_model") 108 | 109 | output_dir = Path(output) 110 | output_dir.mkdir(exist_ok=True, parents=True) 111 | if export: 112 | if mesh_output is None: 113 | mesh_output_dir = output_dir / "meshes" 114 | else: 115 | mesh_output_dir = Path(mesh_output) 116 | mesh_output_dir.mkdir(exist_ok=True, parents=True) 117 | meta = [] 118 | 119 | timestep_callback = None 120 | if snr_shifting != 1.0: 121 | 
timestep_callback = lambda ts: [ 122 | t / (snr_shifting - snr_shifting * t + t) for t in ts 123 | ] 124 | 125 | # fmt: off 126 | # timestep_callback = lambda ts: [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.765, 0.78, 0.795, 0.81, 0.825, 0.84, 0.855, 0.87, 0.885, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99, 1.0] 127 | # fmt: on 128 | 129 | @torch.no_grad() 130 | @torch.autocast("cuda", torch.bfloat16, enabled=bf16) 131 | def vis_one(image_file): 132 | if cond_key == "images": 133 | this_image = Image.open(image_file) 134 | 135 | if not no_preprocess: 136 | this_image = preprocess_image( 137 | this_image, 138 | size=512, 139 | border_ratio=border_ratio, 140 | remove_bg=remove_bg, 141 | ignore_alpha=ignore_alpha, 142 | alpha_matting=alpha_matting, 143 | backend=rembg_backend, 144 | ) 145 | 146 | cond = np.array(this_image) 147 | uncond = np.zeros_like(cond) 148 | save_name = image_file.stem 149 | else: 150 | cond = image_file 151 | uncond = "" 152 | save_name = image_file.replace(" ", "_")[:30] 153 | 154 | denoised = model.sample_one( 155 | cond, 156 | uncond, 157 | n_samples=n_samples, 158 | cfg=cfg, 159 | n_steps=n_steps, 160 | timestep_callback=timestep_callback, 161 | seed=seed, 162 | ) 163 | if do_refine: 164 | refined_denoised = model.sample_one( 165 | cond, 166 | uncond, 167 | coarse=denoised, 168 | n_samples=n_samples, 169 | cfg=cfg, 170 | n_steps=n_steps, 171 | timestep_callback=timestep_callback, 172 | seed=seed, 173 | ) 174 | for i in range(n_samples): 175 | v, f = model.decode_shape(denoised[i : i + 1], thresh=thresh, R=R) 176 | if do_refine: 177 | v_refined, f_refined = model.decode_shape( 178 | refined_denoised[i : i + 1], thresh=thresh, R=R 179 | ) 180 | if remesh: 181 | # v, f = instantmesh_remesh(v, f, target_num_faces=5000) 182 | v, f = pyacvd_remesh(v, f, target_num_faces=20000) 183 | v = np.array(v) 184 | f = np.array(f) 185 | if do_refine: 186 | v_ref_refinedined, f_refined = instantmesh_remesh( 187 | v_ref_refinedined, f_refined, target_num_f_refinedaces=3000 188 | ) 189 | v_ref_refinedined = np.array(v_ref_refinedined) 190 | f_refined = np.array(f_refined) 191 | frames = render_mesh_spiral_offscreen(v, f, **render_kwargs) 192 | if do_refine: 193 | refined_frames = render_mesh_spiral_offscreen( 194 | v_refined, f_refined, **render_kwargs 195 | ) 196 | 197 | if cond_key == "images": 198 | this_image_float = cond 199 | cond_frames = repeat( 200 | this_image_float, "h w c -> t h w c", t=frames.shape[0] 201 | ) 202 | frames = np.concatenate([cond_frames, frames], axis=-2) 203 | if do_refine: 204 | frames = np.concatenate([frames, refined_frames], axis=-2) 205 | 206 | write_video(output_dir / f"{save_name}_sample_{i}.mp4", frames) 207 | if export: 208 | export_mesh(v, f, mesh_output_dir / f"{save_name}_sample_{i}.obj") 209 | this_meta = { 210 | "mesh": (mesh_output_dir / f"{save_name}_sample_{i}.obj") 211 | .absolute() 212 | .as_posix(), 213 | "image": image_file.absolute().as_posix(), 214 | } 215 | meta.append(this_meta) 216 | if do_refine: 217 | export_mesh( 218 | v_refined, 219 | f_refined, 220 | mesh_output_dir / f"{save_name}_sample_{i}_refined.obj", 221 | ) 222 | this_meta = { 223 | "mesh": ( 224 | mesh_output_dir / f"{save_name}_sample_{i}_refined.obj" 225 | ) 226 | .absolute() 227 | .as_posix(), 228 | "image": image_file.absolute().as_posix(), 229 | } 230 | meta.append(this_meta) 231 | 232 | if do_ema: 233 | with model.ema_scope(): 234 | denoised = model.sample_one( 235 | cond, 236 | uncond, 237 | n_samples=n_samples, 238 | cfg=cfg, 239 
| n_steps=n_steps, 240 | timestep_callback=timestep_callback, 241 | seed=seed, 242 | ) 243 | for i in range(n_samples): 244 | v, f = model.decode_shape(denoised[i : i + 1], thresh=thresh, R=R) 245 | if remesh: 246 | v, f = instantmesh_remesh(v, f, target_num_faces=3000) 247 | v = np.array(v) 248 | f = np.array(f) 249 | frames = render_mesh_spiral_offscreen(v, f, **render_kwargs) 250 | 251 | if cond_key == "images": 252 | this_image_float = cond 253 | cond_frames = repeat( 254 | this_image_float, "h w c -> t h w c", t=frames.shape[0] 255 | ) 256 | frames = np.concatenate([cond_frames, frames], axis=-2) 257 | 258 | write_video(output_dir / f"{save_name}_sample_ema_{i}.mp4", frames) 259 | if export: 260 | export_mesh( 261 | v, f, mesh_output_dir / f"{save_name}_sample_ema_{i}.obj" 262 | ) 263 | this_meta = { 264 | "mesh": (mesh_output_dir / f"{save_name}_sample_ema_{i}.obj") 265 | .absolute() 266 | .as_posix(), 267 | "image": image_file.absolute().as_posix(), 268 | } 269 | meta.append(this_meta) 270 | 271 | for img in tqdm.tqdm(image_files, disable=local_rank): 272 | vis_one(img) 273 | 274 | if export: 275 | meta_gathered = [None for _ in range(world_size)] 276 | dist.barrier() 277 | dist.all_gather_object(meta_gathered, meta) 278 | 279 | if local_rank == 0: 280 | merged = [] 281 | for d in meta_gathered: 282 | merged += d 283 | 284 | json.dump(merged, open(output_dir / "meta.json", "w"), indent=4) 285 | 286 | 287 | if __name__ == "__main__": 288 | dist.init_process_group("nccl") 289 | tyro.cli(vis_diffusion) 290 | -------------------------------------------------------------------------------- /meshgen/utils/render.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import trimesh 5 | from trimesh.transformations import rotation_matrix 6 | import pyrender 7 | import open3d as o3d 8 | from einops import rearrange 9 | from open3d.visualization import rendering 10 | from open3d.visualization.rendering import OffscreenRenderer, MaterialRecord 11 | 12 | o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Error) 13 | 14 | # from meshgen.utils.timer import tic, toc, enable_timing 15 | 16 | _render = None 17 | 18 | 19 | def get_uniform_campos(num_frames, radius, elevation): 20 | T = num_frames 21 | azimuths = np.deg2rad(np.linspace(0, 360, T + 1)[:T]) 22 | elevations = np.full_like(azimuths, np.deg2rad(elevation)) 23 | cam_dists = np.full_like(azimuths, radius) 24 | 25 | campos = np.stack( 26 | [ 27 | cam_dists * np.cos(elevations) * np.cos(azimuths), 28 | cam_dists * np.cos(elevations) * np.sin(azimuths), 29 | cam_dists * np.sin(elevations), 30 | ], 31 | axis=-1, 32 | ) 33 | 34 | return campos 35 | 36 | 37 | def init_renderer(reso=512, fov=50.0, force_new_render=False): 38 | if not force_new_render: 39 | global _render 40 | if _render is not None: 41 | return _render 42 | _render = OffscreenRenderer(reso, reso) 43 | _render.scene.set_background([1, 1, 1, 1]) 44 | bottom_plane = trimesh.creation.box([100, 100, 0.01]) 45 | bottom_plane.apply_translation([0, 0, -0.55]) 46 | bottom_plane = bottom_plane.as_open3d 47 | bottom_plane.compute_vertex_normals() 48 | bottom_plane.paint_uniform_color([1, 1, 1]) 49 | plane_mat = MaterialRecord() 50 | plane_mat.base_color = [1, 1, 1, 1] 51 | plane_mat.shader = "defaultLit" 52 | _render.scene.add_geometry("bottom_plane", bottom_plane, plane_mat) 53 | 54 | _render.scene.set_lighting( 55 | _render.scene.LightingProfile.MED_SHADOWS, (0.577, -0.577, -0.577) 56 | ) 57 | 
_render.scene.set_lighting( 58 | _render.scene.LightingProfile.MED_SHADOWS, (-0.577, -0.577, -0.577) 59 | ) 60 | old_camera = _render.scene.camera 61 | _render.scene.camera.set_projection( 62 | fov, 63 | 1, 64 | old_camera.get_near(), 65 | old_camera.get_far(), 66 | old_camera.get_field_of_view_type(), 67 | ) 68 | 69 | return _render 70 | 71 | 72 | @torch.no_grad() 73 | def render_mesh_spiral_offscreen( 74 | vertices, 75 | faces, 76 | reso=512, 77 | num_frames=90, 78 | elevation=0, 79 | radius=2.0, 80 | normalize=True, 81 | rotate=True, 82 | color=None, 83 | fov=50.0, 84 | force_new_render=False, 85 | ): 86 | if isinstance(vertices, torch.Tensor): 87 | vertices = vertices.cpu().numpy() 88 | faces = faces.cpu().numpy() 89 | 90 | if len(vertices) == 0: 91 | return np.ones((num_frames, reso, reso, 3), dtype=np.uint8) * 255 92 | 93 | mesh = trimesh.Trimesh(vertices, faces) 94 | if normalize: 95 | mesh.apply_translation(-mesh.centroid) 96 | scale = max(mesh.extents) 97 | scale_T = np.eye(4) 98 | scale_T[0, 0] = scale_T[1, 1] = scale_T[2, 2] = 1.0 / scale 99 | mesh.apply_transform(scale_T) 100 | 101 | if rotate: 102 | yz_rotation = rotation_matrix(np.pi / 2, [1, 0, 0]) 103 | mesh.apply_transform(yz_rotation) 104 | 105 | # init_renderer(reso, fov) 106 | render = init_renderer(reso, fov, force_new_render=force_new_render) 107 | 108 | # global _render 109 | # render = _render 110 | 111 | # color = np.array([58, 75, 101], dtype=np.float64) / 255 112 | if color is None: 113 | color = np.array([0.034, 0.294, 0.5], dtype=np.float64) * 1.2 114 | # color = np.array([144.0 / 255, 210.0 / 255, 236.0 / 255], dtype=np.float64) 115 | 116 | mesh = mesh.as_open3d 117 | mesh.compute_vertex_normals() 118 | mesh.paint_uniform_color(color) 119 | mat = MaterialRecord() 120 | mat.base_color = [1, 1, 1, 1] 121 | mat.base_metallic = 0.0 122 | mat.base_roughness = 1.0 123 | mat.shader = "defaultLit" 124 | render.scene.add_geometry("model", mesh, mat) 125 | 126 | campos = get_uniform_campos(num_frames, radius, elevation) 127 | 128 | # render = OffscreenRenderer(reso, reso) 129 | 130 | # render.scene.set_background([1, 1, 1, 1]) 131 | # render.scene.add_geometry("model", mesh, mat) 132 | 133 | # bottom_plane = trimesh.creation.box([100, 100, 0.01]) 134 | # bottom_plane.apply_translation([0, 0, -0.55]) 135 | # bottom_plane = bottom_plane.as_open3d 136 | # bottom_plane.compute_vertex_normals() 137 | # bottom_plane.paint_uniform_color([1, 1, 1]) 138 | # plane_mat = MaterialRecord() 139 | # plane_mat.base_color = [1, 1, 1, 1] 140 | # plane_mat.shader = "defaultLit" 141 | # render.scene.add_geometry("bottom_plane", bottom_plane, plane_mat) 142 | 143 | # # render.scene.scene.enable_sun_light(False) 144 | # light_dir = np.array([1, 1, 1]) 145 | # # render.scene.scene.add_spot_light( 146 | # # "light", [1, 1, 1], -3 * light_dir, light_dir, 1e8, 1e2, 0.1, 0.1, True 147 | # # ) 148 | 149 | # render.scene.set_lighting( 150 | # render.scene.LightingProfile.MED_SHADOWS, (0.577, -0.577, -0.577) 151 | # ) 152 | 153 | frames = [] 154 | render.scene.camera.look_at([0, 0, 0], campos[0], [0, 0, 1]) 155 | for i in range(num_frames): 156 | azimuth = i / num_frames * 2 * np.pi 157 | render.scene.set_geometry_transform( 158 | "model", rotation_matrix(azimuth, [0, 0, 1]) 159 | ) 160 | frame = np.asarray(render.render_to_image()) 161 | frames.append(frame) 162 | 163 | render.scene.remove_geometry("model") 164 | return np.stack(frames, axis=0) 165 | 166 | 167 | def render_point_cloud_spiral_offscreen( 168 | vertices, 169 | reso=512, 170 | 
num_frames=90, 171 | elevation=0, 172 | radius=2.0, 173 | point_radius=0.01, 174 | rotate=True, 175 | **kwargs, 176 | ): 177 | if isinstance(vertices, torch.Tensor): 178 | vertices = vertices.cpu().numpy() 179 | 180 | if rotate: 181 | yz_rotation = rotation_matrix(np.pi / 2, [1, 0, 0]) 182 | vertices = np.dot(vertices, yz_rotation[:3, :3].T) 183 | 184 | colors = np.array([125, 151, 250, 255]) / 255 185 | colors = np.tile(colors, (vertices.shape[0], 1)) 186 | 187 | sm = [] 188 | for v in vertices: 189 | this_point = trimesh.creation.uv_sphere(radius=point_radius) 190 | this_point.apply_translation(v) 191 | sm.append(this_point) 192 | 193 | sm = trimesh.util.concatenate(sm) 194 | 195 | return render_mesh_spiral_offscreen( 196 | sm.vertices, sm.faces, reso, num_frames, elevation, radius, **kwargs 197 | ) 198 | 199 | 200 | import torch 201 | import torch.nn.functional as F 202 | import numpy as np 203 | 204 | 205 | def pad_camera_extrinsics_4x4(extrinsics): 206 | if extrinsics.shape[-2] == 4: 207 | return extrinsics 208 | padding = torch.tensor([[0, 0, 0, 1]]).to(extrinsics) 209 | if extrinsics.ndim == 3: 210 | padding = padding.unsqueeze(0).repeat(extrinsics.shape[0], 1, 1) 211 | extrinsics = torch.cat([extrinsics, padding], dim=-2) 212 | return extrinsics 213 | 214 | 215 | def center_looking_at_camera_pose( 216 | camera_position: torch.Tensor, 217 | look_at: torch.Tensor = None, 218 | up_world: torch.Tensor = None, 219 | ): 220 | """ 221 | Create OpenGL camera extrinsics from camera locations and look-at position. 222 | 223 | camera_position: (M, 3) or (3,) 224 | look_at: (3) 225 | up_world: (3) 226 | return: (M, 3, 4) or (3, 4) 227 | """ 228 | # by default, looking at the origin and world up is z-axis 229 | if look_at is None: 230 | look_at = torch.tensor([0, 0, 0], dtype=torch.float32) 231 | if up_world is None: 232 | up_world = torch.tensor([0, 0, 1], dtype=torch.float32) 233 | if camera_position.ndim == 2: 234 | look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1) 235 | up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1) 236 | 237 | # OpenGL camera: z-backward, x-right, y-up 238 | z_axis = camera_position - look_at 239 | z_axis = F.normalize(z_axis, dim=-1).float() 240 | x_axis = torch.linalg.cross(up_world, z_axis, dim=-1) 241 | x_axis = F.normalize(x_axis, dim=-1).float() 242 | y_axis = torch.linalg.cross(z_axis, x_axis, dim=-1) 243 | y_axis = F.normalize(y_axis, dim=-1).float() 244 | 245 | extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1) 246 | extrinsics = pad_camera_extrinsics_4x4(extrinsics) 247 | return extrinsics 248 | 249 | 250 | def spherical_camera_pose(azimuths: np.ndarray, elevations: np.ndarray, radius=2.5): 251 | azimuths = np.deg2rad(azimuths) 252 | elevations = np.deg2rad(elevations) 253 | 254 | xs = radius * np.cos(elevations) * np.cos(azimuths) 255 | ys = radius * np.cos(elevations) * np.sin(azimuths) 256 | zs = radius * np.sin(elevations) 257 | 258 | cam_locations = np.stack([xs, ys, zs], axis=-1) 259 | cam_locations = torch.from_numpy(cam_locations).float() 260 | 261 | c2ws = center_looking_at_camera_pose(cam_locations) 262 | return c2ws 263 | 264 | 265 | def get_circular_camera_poses(M=120, radius=2.5, elevation=30.0): 266 | # M: number of circular views 267 | # radius: camera dist to center 268 | # elevation: elevation degrees of the camera 269 | # return: (M, 4, 4) 270 | assert M > 0 and radius > 0 271 | 272 | elevation = np.deg2rad(elevation) 273 | 274 | camera_positions = [] 275 | for i in range(M): 276 | 
azimuth = 2 * np.pi * i / M 277 | x = radius * np.cos(elevation) * np.cos(azimuth) 278 | y = radius * np.cos(elevation) * np.sin(azimuth) 279 | z = radius * np.sin(elevation) 280 | camera_positions.append([x, y, z]) 281 | camera_positions = np.array(camera_positions) 282 | camera_positions = torch.from_numpy(camera_positions).float() 283 | extrinsics = center_looking_at_camera_pose(camera_positions) 284 | return extrinsics 285 | 286 | 287 | def FOV_to_intrinsics(fov, device="cpu"): 288 | """ 289 | Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees. 290 | Note the intrinsics are returned as normalized by image size, rather than in pixel units. 291 | Assumes principal point is at image center. 292 | """ 293 | focal_length = 0.5 / np.tan(np.deg2rad(fov) * 0.5) 294 | intrinsics = torch.tensor( 295 | [[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device 296 | ) 297 | return intrinsics 298 | 299 | 300 | def get_render_cameras( 301 | batch_size=1, M=120, radius=2.0, elevation=20.0, is_flexicubes=False 302 | ): 303 | """ 304 | Get the rendering camera parameters. 305 | """ 306 | c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation) 307 | if is_flexicubes: 308 | cameras = torch.linalg.inv(c2ws) 309 | cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1) 310 | else: 311 | extrinsics = c2ws.flatten(-2) 312 | intrinsics = ( 313 | FOV_to_intrinsics(30.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2) 314 | ) 315 | cameras = torch.cat([extrinsics, intrinsics], dim=-1) 316 | cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1) 317 | return cameras 318 | 319 | 320 | def render_instant_frames( 321 | model, 322 | planes, 323 | reso=512, 324 | num_frames=90, 325 | elevation=0, 326 | radius=2.0, 327 | chunk_size=1, 328 | is_flexicubes=False, 329 | **kwargs, 330 | ): 331 | """ 332 | Render frames from triplanes. 
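    Cameras are placed on a circular path around the object (see get_render_cameras)
    and views are rendered in chunks of `chunk_size`; when `is_flexicubes` is set the
    frames come from model.forward_geometry, otherwise from model.forward_synthesizer.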
333 | """ 334 | render_size = reso 335 | render_cameras = get_render_cameras( 336 | 1, num_frames, radius, elevation, is_flexicubes 337 | ).to(planes) 338 | frames = [] 339 | for i in range(0, render_cameras.shape[1], chunk_size): 340 | if is_flexicubes: 341 | frame = model.forward_geometry( 342 | planes, 343 | render_cameras[:, i : i + chunk_size], 344 | render_size=render_size, 345 | )["img"] 346 | else: 347 | frame = model.forward_synthesizer( 348 | planes, 349 | render_cameras[:, i : i + chunk_size], 350 | render_size=render_size, 351 | )["images_rgb"] 352 | frames.append(frame) 353 | 354 | frames = torch.cat(frames, dim=1)[0] # we suppose batch size is always 1 355 | frames = frames.cpu().numpy() 356 | frames = rearrange(frames, "T C H W -> T H W C") 357 | return frames 358 | -------------------------------------------------------------------------------- /meshgen/model/triplane_autoencoder.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import numpy as np 3 | import torch 4 | import pytorch_lightning as pl 5 | import torch.nn.functional as F 6 | from torch.optim.lr_scheduler import LambdaLR 7 | from torch.utils.data import default_collate 8 | from pathlib import Path 9 | from einops import rearrange, repeat 10 | import tqdm 11 | import nvdiffrast.torch as dr 12 | # from diso import DiffDMC 13 | from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity 14 | from skimage import measure 15 | 16 | from meshgen.util import instantiate_from_config 17 | from diffusers.training_utils import EMAModel 18 | from meshgen.utils.io import load_mesh, write_video, write_image, export_mesh 19 | from meshgen.utils.render import render_mesh_spiral_offscreen 20 | from meshgen.utils.ops import ( 21 | sample_on_surface, 22 | sample_from_planes, 23 | generate_planes, 24 | calc_normal, 25 | get_projection_matrix, 26 | safe_normalize, 27 | compute_tv_loss, 28 | ) 29 | from meshgen.utils.ray_utils import RaySampler, RayGenerator 30 | from meshgen.utils.math_utils import get_ray_limits_box, linspace 31 | from meshgen.utils.misc import get_device 32 | from meshgen.modules.shape2vecset import DiagonalGaussianDistribution 33 | 34 | 35 | class TriplaneAEModel(pl.LightningModule): 36 | def __init__( 37 | self, 38 | encoder, 39 | deconv_decoder, 40 | mlp_decoder, 41 | triplane_res, 42 | triplane_ch, 43 | ckpt_path=None, 44 | ignore_keys=[], 45 | scheduler_config=None, 46 | use_ema=False, 47 | weight_decay=0.0, 48 | monitor=None, 49 | is_shapenet=False, 50 | box_warp=0.55 * 2, 51 | use_diso=False, 52 | ): 53 | super().__init__() 54 | self.encoder = instantiate_from_config(encoder) 55 | 56 | self.deconv_decoder = instantiate_from_config(deconv_decoder) 57 | self.mlp_decoder = instantiate_from_config(mlp_decoder) 58 | 59 | self.use_ema = use_ema 60 | if self.use_ema: 61 | self.model_ema = EMAModel(self.parameters(), decay=0.999) 62 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") 63 | 64 | self.scheduler_config = scheduler_config 65 | self.use_scheduler = scheduler_config is not None 66 | self.weight_decay = weight_decay 67 | 68 | if monitor is not None: 69 | self.monitor = monitor 70 | 71 | self.triplane_res = triplane_res 72 | self.triplane_ch = triplane_ch 73 | 74 | self.loss = torch.nn.BCEWithLogitsLoss() 75 | self.train() 76 | 77 | self.is_shapenet = is_shapenet 78 | 79 | # self.plane_axes = generate_planes() 80 | self.register_buffer("plane_axes", generate_planes()) 81 | self.box_warp = box_warp 82 | 83 | if ckpt_path is 
not None: 84 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 85 | 86 | self.use_diso = use_diso 87 | if self.use_diso: 88 | raise NotImplementedError("DiffDMC is not implemented yet.") 89 | # self.diffdmc = DiffDMC(torch.float32) 90 | 91 | def init_from_ckpt(self, path, ignore_keys=list()): 92 | sd = torch.load(path, map_location="cpu")["state_dict"] 93 | keys = list(sd.keys()) 94 | for k in keys: 95 | for ik in ignore_keys: 96 | if k.startswith(ik): 97 | print("Deleting key {} from state_dict.".format(k)) 98 | del sd[k] 99 | missing, unexpected = self.load_state_dict(sd, strict=False) 100 | print( 101 | f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" 102 | ) 103 | if len(missing) > 0 or len(unexpected) > 0: 104 | print(f"Missing Keys: {missing}") 105 | print(f"Unexpected Keys: {unexpected}") 106 | 107 | def encode(self, x): 108 | z = self.encoder(x) 109 | z = z.reshape(-1, 3, self.triplane_res, self.triplane_res, self.triplane_ch) 110 | z = rearrange(z, "b n h w c -> b n c h w") 111 | 112 | return z 113 | 114 | def upsample(self, z): 115 | if z.ndim == 5: 116 | z = rearrange(z, "b n c h w -> b c (n h) w") 117 | z = self.deconv_decoder(z) 118 | z = rearrange(z, "b c (n h) w -> b n c h w", n=3) 119 | 120 | return z 121 | 122 | def decode(self, z, queries, upsample=False): 123 | if upsample: 124 | z = self.upsample(z) 125 | features = sample_from_planes( 126 | self.plane_axes, z, queries, box_warp=self.box_warp 127 | ) 128 | features = rearrange(features, "b np q c -> b q (np c)") 129 | output = self.mlp_decoder(features) 130 | 131 | return output 132 | 133 | @torch.no_grad() 134 | def diff_decode_shape(self, z, R=256, skip_upsample=False): 135 | if not skip_upsample: 136 | z = self.upsample(z) 137 | extend = 0.55 if not self.is_shapenet else 1.05 138 | grid = ( 139 | torch.stack( 140 | torch.meshgrid( 141 | torch.arange(R), 142 | torch.arange(R), 143 | torch.arange(R), 144 | ), 145 | dim=-1, 146 | ) 147 | + 0.5 148 | ) 149 | grid = grid.reshape(-1, 3).to(self.device).float() / R * 2 * extend - extend 150 | # thresh = 0.5 151 | query = grid 152 | rec = self.decode(z, query[None]) 153 | occpancies = torch.sigmoid(rec).reshape(R, R, R) 154 | 155 | return occpancies 156 | 157 | @torch.no_grad() 158 | def decode_shape(self, z, thresh=0.5, R=256, skip_upsample=False): 159 | if not skip_upsample: 160 | z = self.upsample(z) 161 | mini_bs = 65536 * 16 162 | extend = 0.55 if not self.is_shapenet else 1.05 163 | grid = ( 164 | torch.stack( 165 | torch.meshgrid( 166 | torch.arange(R), 167 | torch.arange(R), 168 | torch.arange(R), 169 | ), 170 | dim=-1, 171 | ) 172 | + 0.5 173 | ) 174 | occs = [] 175 | grid = grid.reshape(-1, 3).to(self.device).float() / R * 2 * extend - extend 176 | # thresh = 0.5 177 | for start in tqdm.trange(0, grid.shape[0], mini_bs, disable=True): 178 | end = min(start + mini_bs, grid.shape[0]) 179 | query = grid[start:end] 180 | rec = self.decode(z, query[None]) 181 | occpancy = torch.sigmoid(rec) 182 | occs.append(occpancy.cpu()) 183 | 184 | occs = torch.cat(occs, dim=0).reshape(R, R, R).to(self.device) 185 | if not self.use_diso: 186 | vertices_pred, faces_pred, normals_pred, _ = measure.marching_cubes( 187 | occs.detach().cpu().float().numpy(), thresh, method="lewiner" 188 | ) 189 | else: 190 | vertices_pred, faces_pred = self.diffdmc( 191 | -occs.float(), isovalue=-thresh, normalize=True 192 | ) 193 | 194 | return vertices_pred, faces_pred 195 | 196 | def forward(self, x, queries): 197 | z = self.encode(x) 198 | z = 
self.upsample(z) 199 | x_hat = self.decode(z, queries) 200 | 201 | return x_hat 202 | 203 | def shared_step(self, batch, batch_idx): 204 | surface, queries, occupancies = ( 205 | batch["surface"], 206 | batch["queries"], 207 | batch["occupancies"], 208 | ) 209 | 210 | pred = self(surface, queries) 211 | target = occupancies 212 | 213 | loss_dict = {"loss": self.loss(pred, target).item()} 214 | 215 | return self.loss(pred, target), loss_dict 216 | 217 | def log_prefix(self, loss_dict, prefix): 218 | for k, v in loss_dict.items(): 219 | self.log( 220 | f"{prefix}/{k}", 221 | v, 222 | prog_bar=True, 223 | logger=True, 224 | on_step=True, 225 | on_epoch=True, 226 | ) 227 | 228 | def training_step(self, batch, batch_idx): 229 | loss, loss_dict = self.shared_step(batch, batch_idx) 230 | self.log_prefix(loss_dict, "train") 231 | return loss 232 | 233 | def validation_step(self, batch, batch_idx): 234 | loss, loss_dict = self.shared_step(batch, batch_idx) 235 | self.log_prefix(loss_dict, "val") 236 | return loss 237 | 238 | def configure_optimizers(self): 239 | lr = self.learning_rate 240 | param_groups = self.get_optim_groups() 241 | opt = torch.optim.AdamW(param_groups, lr=lr) 242 | if self.use_scheduler: 243 | assert "target" in self.scheduler_config 244 | scheduler = instantiate_from_config(self.scheduler_config) 245 | 246 | print("Setting up LambdaLR scheduler...") 247 | scheduler = [ 248 | { 249 | "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), 250 | "interval": "step", 251 | "frequency": 1, 252 | } 253 | ] 254 | return [opt], scheduler 255 | return opt 256 | 257 | def get_optim_groups(self): 258 | decay = set() 259 | no_decay = set() 260 | whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv2d) 261 | blacklist_weight_modules = ( 262 | torch.nn.LayerNorm, 263 | torch.nn.Embedding, 264 | torch.nn.Parameter, 265 | torch.nn.GroupNorm, 266 | ) 267 | for mn, m in self.named_modules(): 268 | for pn, p in m.named_parameters(): 269 | fpn = "%s.%s" % (mn, pn) if mn else pn # full param name 270 | 271 | if pn.endswith("bias"): 272 | # all biases will not be decayed 273 | no_decay.add(fpn) 274 | elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules): 275 | # weights of whitelist modules will be weight decayed 276 | decay.add(fpn) 277 | elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules): 278 | # weights of blacklist modules will NOT be weight decayed 279 | no_decay.add(fpn) 280 | elif pn.endswith("embed"): 281 | no_decay.add(fpn) 282 | elif pn.endswith("pos_emb"): 283 | no_decay.add(fpn) 284 | elif pn.endswith("query"): 285 | no_decay.add(fpn) 286 | 287 | # special case the position embedding parameter in the root GPT module as not decayed 288 | 289 | # validate that we considered every parameter 290 | param_dict = {pn: p for pn, p in self.named_parameters()} 291 | inter_params = decay & no_decay 292 | union_params = decay | no_decay 293 | assert ( 294 | len(inter_params) == 0 295 | ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),) 296 | assert ( 297 | len(param_dict.keys() - union_params) == 0 298 | ), "parameters %s were not separated into either decay/no_decay set!" 
% ( 299 | str(param_dict.keys() - union_params), 300 | ) 301 | 302 | optim_groups = [ 303 | { 304 | "params": [param_dict[pn] for pn in sorted(list(decay))], 305 | "weight_decay": self.weight_decay, 306 | }, 307 | { 308 | "params": [param_dict[pn] for pn in sorted(list(no_decay))], 309 | "weight_decay": 0.0, 310 | }, 311 | ] 312 | 313 | return optim_groups 314 | 315 | @torch.no_grad() 316 | def eval_one(self, surface, do_rendering=False): 317 | if isinstance(surface, (str, Path)): 318 | vertices, faces = load_mesh(surface, self.device) 319 | surface = sample_on_surface(vertices, faces, self.encoder.num_inputs) 320 | latents = self.encode(surface[None]) 321 | v, f = self.decode_shape(latents) 322 | 323 | if do_rendering: 324 | renders = render_mesh_spiral_offscreen( 325 | v, f, elevation=20, radius=3, num_frames=90 326 | ) 327 | return v, f, renders 328 | else: 329 | return v, f 330 | 331 | def on_train_epoch_end(self, *args, **kwargs): 332 | self.eval() 333 | logdir = self.trainer.logdir ### 334 | this_logdir = Path(logdir) / f"spirals/epoch_{self.current_epoch}" 335 | this_logdir.mkdir(exist_ok=True, parents=True) 336 | 337 | local_rank = self.local_rank 338 | world_size = self.trainer.world_size 339 | meshes = list(Path("./data/eval_meshes").glob("*.off"))[:8] 340 | meshes = meshes[local_rank::world_size] 341 | 342 | for mesh in tqdm.tqdm(meshes, disable=local_rank): 343 | try: 344 | _, _, rendered = self.eval_one(mesh, do_rendering=True) 345 | write_video(this_logdir / f"{mesh.stem}.mp4", rendered) 346 | except TypeError: 347 | pass 348 | 349 | self.train() 350 | 351 | def on_train_start(self): 352 | return 353 | self.eval() 354 | logdir = self.trainer.logdir ### 355 | this_logdir = Path(logdir) / f"spirals/before_training" 356 | this_logdir.mkdir(exist_ok=True, parents=True) 357 | 358 | local_rank = self.local_rank 359 | world_size = self.trainer.world_size 360 | meshes = list(Path("./data/eval_meshes").glob("*.off"))[:8] 361 | meshes = meshes[local_rank::world_size] 362 | 363 | for mesh in tqdm.tqdm(meshes, disable=local_rank): 364 | _, _, rendered = self.eval_one(mesh, do_rendering=True) 365 | write_video(this_logdir / f"{mesh.stem}.mp4", rendered) 366 | 367 | self.train() 368 | 369 | 370 | class TriplaneKLModel(TriplaneAEModel): 371 | def __init__(self, *args, kl_weight=1e-5, tv_loss_weight=0.0, **kwargs): 372 | super().__init__(*args, **kwargs) 373 | self.kl_weight = kl_weight 374 | self.tv_loss_weight = tv_loss_weight 375 | 376 | def encode(self, x, return_kl=False): 377 | z = self.encoder(x) 378 | z = z.reshape(-1, 3, self.triplane_res, self.triplane_res, self.triplane_ch * 2) 379 | mean, logvar = z.chunk(2, dim=-1) 380 | posterior = DiagonalGaussianDistribution(mean, logvar) 381 | z = posterior.sample() 382 | z = rearrange(z, "b n h w c -> b n c h w") 383 | 384 | if return_kl: 385 | kl = posterior.kl() 386 | return z, kl 387 | else: 388 | return z 389 | 390 | def forward(self, x, queries): 391 | z, kl = self.encode(x, return_kl=True) 392 | z = self.upsample(z) 393 | x_hat = self.decode(z, queries) 394 | 395 | return z, x_hat, kl 396 | 397 | def shared_step(self, batch, batch_idx): 398 | surface, queries, occupancies = ( 399 | batch["surface"], 400 | batch["queries"], 401 | batch["occupancies"], 402 | ) 403 | 404 | z, pred, kl = self(surface, queries) 405 | kl = torch.mean(kl) 406 | target = occupancies 407 | 408 | loss_dict = {"recon_loss": self.loss(pred, target).item(), "kl": kl.item()} 409 | 410 | loss = self.loss(pred, target) + self.kl_weight * kl 411 | 412 | if 
self.tv_loss_weight > 0: 413 | tv_loss = compute_tv_loss(z) 414 | loss += self.tv_loss_weight * tv_loss 415 | loss_dict["tv_loss"] = tv_loss 416 | 417 | return loss, loss_dict 418 | -------------------------------------------------------------------------------- /meshgen/utils/briarmbg.py: -------------------------------------------------------------------------------- 1 | # RMBG1.4 (diffusers implementation) 2 | # Found on huggingface space of several projects 3 | # Not sure which project is the source of this file 4 | 5 | from PIL import Image 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torchvision.transforms.functional import normalize 11 | from huggingface_hub import PyTorchModelHubMixin 12 | 13 | 14 | class REBNCONV(nn.Module): 15 | def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1): 16 | super(REBNCONV, self).__init__() 17 | 18 | self.conv_s1 = nn.Conv2d( 19 | in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride 20 | ) 21 | self.bn_s1 = nn.BatchNorm2d(out_ch) 22 | self.relu_s1 = nn.ReLU(inplace=True) 23 | 24 | def forward(self, x): 25 | hx = x 26 | xout = self.relu_s1(self.bn_s1(self.conv_s1(hx))) 27 | 28 | return xout 29 | 30 | 31 | def _upsample_like(src, tar): 32 | src = F.interpolate(src, size=tar.shape[2:], mode="bilinear") 33 | return src 34 | 35 | 36 | ### RSU-7 ### 37 | class RSU7(nn.Module): 38 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512): 39 | super(RSU7, self).__init__() 40 | 41 | self.in_ch = in_ch 42 | self.mid_ch = mid_ch 43 | self.out_ch = out_ch 44 | 45 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2 46 | 47 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 48 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 49 | 50 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) 51 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 52 | 53 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) 54 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 55 | 56 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) 57 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 58 | 59 | self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) 60 | self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 61 | 62 | self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1) 63 | 64 | self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2) 65 | 66 | self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 67 | self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 68 | self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 69 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 70 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 71 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 72 | 73 | def forward(self, x): 74 | b, c, h, w = x.shape 75 | 76 | hx = x 77 | hxin = self.rebnconvin(hx) 78 | 79 | hx1 = self.rebnconv1(hxin) 80 | hx = self.pool1(hx1) 81 | 82 | hx2 = self.rebnconv2(hx) 83 | hx = self.pool2(hx2) 84 | 85 | hx3 = self.rebnconv3(hx) 86 | hx = self.pool3(hx3) 87 | 88 | hx4 = self.rebnconv4(hx) 89 | hx = self.pool4(hx4) 90 | 91 | hx5 = self.rebnconv5(hx) 92 | hx = self.pool5(hx5) 93 | 94 | hx6 = self.rebnconv6(hx) 95 | 96 | hx7 = self.rebnconv7(hx6) 97 | 98 | hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1)) 99 | hx6dup = _upsample_like(hx6d, hx5) 100 | 101 | hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1)) 102 | hx5dup = _upsample_like(hx5d, hx4) 103 | 104 | hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) 105 | 
hx4dup = _upsample_like(hx4d, hx3) 106 | 107 | hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) 108 | hx3dup = _upsample_like(hx3d, hx2) 109 | 110 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) 111 | hx2dup = _upsample_like(hx2d, hx1) 112 | 113 | hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) 114 | 115 | return hx1d + hxin 116 | 117 | 118 | ### RSU-6 ### 119 | class RSU6(nn.Module): 120 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 121 | super(RSU6, self).__init__() 122 | 123 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 124 | 125 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 126 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 127 | 128 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) 129 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 130 | 131 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) 132 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 133 | 134 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) 135 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 136 | 137 | self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) 138 | 139 | self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2) 140 | 141 | self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 142 | self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 143 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 144 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 145 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 146 | 147 | def forward(self, x): 148 | hx = x 149 | 150 | hxin = self.rebnconvin(hx) 151 | 152 | hx1 = self.rebnconv1(hxin) 153 | hx = self.pool1(hx1) 154 | 155 | hx2 = self.rebnconv2(hx) 156 | hx = self.pool2(hx2) 157 | 158 | hx3 = self.rebnconv3(hx) 159 | hx = self.pool3(hx3) 160 | 161 | hx4 = self.rebnconv4(hx) 162 | hx = self.pool4(hx4) 163 | 164 | hx5 = self.rebnconv5(hx) 165 | 166 | hx6 = self.rebnconv6(hx5) 167 | 168 | hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1)) 169 | hx5dup = _upsample_like(hx5d, hx4) 170 | 171 | hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) 172 | hx4dup = _upsample_like(hx4d, hx3) 173 | 174 | hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) 175 | hx3dup = _upsample_like(hx3d, hx2) 176 | 177 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) 178 | hx2dup = _upsample_like(hx2d, hx1) 179 | 180 | hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) 181 | 182 | return hx1d + hxin 183 | 184 | 185 | ### RSU-5 ### 186 | class RSU5(nn.Module): 187 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 188 | super(RSU5, self).__init__() 189 | 190 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 191 | 192 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 193 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 194 | 195 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) 196 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 197 | 198 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) 199 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 200 | 201 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) 202 | 203 | self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2) 204 | 205 | self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 206 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 207 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 208 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 209 | 210 | def forward(self, x): 211 | hx = x 212 | 213 | hxin = self.rebnconvin(hx) 214 | 215 | hx1 = self.rebnconv1(hxin) 216 | hx = 
self.pool1(hx1) 217 | 218 | hx2 = self.rebnconv2(hx) 219 | hx = self.pool2(hx2) 220 | 221 | hx3 = self.rebnconv3(hx) 222 | hx = self.pool3(hx3) 223 | 224 | hx4 = self.rebnconv4(hx) 225 | 226 | hx5 = self.rebnconv5(hx4) 227 | 228 | hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1)) 229 | hx4dup = _upsample_like(hx4d, hx3) 230 | 231 | hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) 232 | hx3dup = _upsample_like(hx3d, hx2) 233 | 234 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) 235 | hx2dup = _upsample_like(hx2d, hx1) 236 | 237 | hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) 238 | 239 | return hx1d + hxin 240 | 241 | 242 | ### RSU-4 ### 243 | class RSU4(nn.Module): 244 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 245 | super(RSU4, self).__init__() 246 | 247 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 248 | 249 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 250 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 251 | 252 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) 253 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 254 | 255 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) 256 | 257 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2) 258 | 259 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 260 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 261 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 262 | 263 | def forward(self, x): 264 | hx = x 265 | 266 | hxin = self.rebnconvin(hx) 267 | 268 | hx1 = self.rebnconv1(hxin) 269 | hx = self.pool1(hx1) 270 | 271 | hx2 = self.rebnconv2(hx) 272 | hx = self.pool2(hx2) 273 | 274 | hx3 = self.rebnconv3(hx) 275 | 276 | hx4 = self.rebnconv4(hx3) 277 | 278 | hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) 279 | hx3dup = _upsample_like(hx3d, hx2) 280 | 281 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) 282 | hx2dup = _upsample_like(hx2d, hx1) 283 | 284 | hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) 285 | 286 | return hx1d + hxin 287 | 288 | 289 | ### RSU-4F ### 290 | class RSU4F(nn.Module): 291 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 292 | super(RSU4F, self).__init__() 293 | 294 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 295 | 296 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 297 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2) 298 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4) 299 | 300 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8) 301 | 302 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4) 303 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2) 304 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 305 | 306 | def forward(self, x): 307 | hx = x 308 | 309 | hxin = self.rebnconvin(hx) 310 | 311 | hx1 = self.rebnconv1(hxin) 312 | hx2 = self.rebnconv2(hx1) 313 | hx3 = self.rebnconv3(hx2) 314 | 315 | hx4 = self.rebnconv4(hx3) 316 | 317 | hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) 318 | hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1)) 319 | hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1)) 320 | 321 | return hx1d + hxin 322 | 323 | 324 | class myrebnconv(nn.Module): 325 | def __init__( 326 | self, 327 | in_ch=3, 328 | out_ch=1, 329 | kernel_size=3, 330 | stride=1, 331 | padding=1, 332 | dilation=1, 333 | groups=1, 334 | ): 335 | super(myrebnconv, self).__init__() 336 | 337 | self.conv = nn.Conv2d( 338 | in_ch, 339 | out_ch, 340 | kernel_size=kernel_size, 341 | stride=stride, 342 | padding=padding, 343 | dilation=dilation, 344 | groups=groups, 345 | ) 346 | self.bn = 
nn.BatchNorm2d(out_ch) 347 | self.rl = nn.ReLU(inplace=True) 348 | 349 | def forward(self, x): 350 | return self.rl(self.bn(self.conv(x))) 351 | 352 | 353 | class BriaRMBG(nn.Module, PyTorchModelHubMixin): 354 | def __init__(self, config: dict = {"in_ch": 3, "out_ch": 1}): 355 | super(BriaRMBG, self).__init__() 356 | in_ch = config["in_ch"] 357 | out_ch = config["out_ch"] 358 | self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1) 359 | self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True) 360 | 361 | self.stage1 = RSU7(64, 32, 64) 362 | self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 363 | 364 | self.stage2 = RSU6(64, 32, 128) 365 | self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 366 | 367 | self.stage3 = RSU5(128, 64, 256) 368 | self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 369 | 370 | self.stage4 = RSU4(256, 128, 512) 371 | self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 372 | 373 | self.stage5 = RSU4F(512, 256, 512) 374 | self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 375 | 376 | self.stage6 = RSU4F(512, 256, 512) 377 | 378 | # decoder 379 | self.stage5d = RSU4F(1024, 256, 512) 380 | self.stage4d = RSU4(1024, 128, 256) 381 | self.stage3d = RSU5(512, 64, 128) 382 | self.stage2d = RSU6(256, 32, 64) 383 | self.stage1d = RSU7(128, 16, 64) 384 | 385 | self.side1 = nn.Conv2d(64, out_ch, 3, padding=1) 386 | self.side2 = nn.Conv2d(64, out_ch, 3, padding=1) 387 | self.side3 = nn.Conv2d(128, out_ch, 3, padding=1) 388 | self.side4 = nn.Conv2d(256, out_ch, 3, padding=1) 389 | self.side5 = nn.Conv2d(512, out_ch, 3, padding=1) 390 | self.side6 = nn.Conv2d(512, out_ch, 3, padding=1) 391 | 392 | # self.outconv = nn.Conv2d(6*out_ch,out_ch,1) 393 | 394 | def forward(self, x): 395 | hx = x 396 | 397 | hxin = self.conv_in(hx) 398 | # hx = self.pool_in(hxin) 399 | 400 | # stage 1 401 | hx1 = self.stage1(hxin) 402 | hx = self.pool12(hx1) 403 | 404 | # stage 2 405 | hx2 = self.stage2(hx) 406 | hx = self.pool23(hx2) 407 | 408 | # stage 3 409 | hx3 = self.stage3(hx) 410 | hx = self.pool34(hx3) 411 | 412 | # stage 4 413 | hx4 = self.stage4(hx) 414 | hx = self.pool45(hx4) 415 | 416 | # stage 5 417 | hx5 = self.stage5(hx) 418 | hx = self.pool56(hx5) 419 | 420 | # stage 6 421 | hx6 = self.stage6(hx) 422 | hx6up = _upsample_like(hx6, hx5) 423 | 424 | # -------------------- decoder -------------------- 425 | hx5d = self.stage5d(torch.cat((hx6up, hx5), 1)) 426 | hx5dup = _upsample_like(hx5d, hx4) 427 | 428 | hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1)) 429 | hx4dup = _upsample_like(hx4d, hx3) 430 | 431 | hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1)) 432 | hx3dup = _upsample_like(hx3d, hx2) 433 | 434 | hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1)) 435 | hx2dup = _upsample_like(hx2d, hx1) 436 | 437 | hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1)) 438 | 439 | # side output 440 | d1 = self.side1(hx1d) 441 | d1 = _upsample_like(d1, x) 442 | 443 | d2 = self.side2(hx2d) 444 | d2 = _upsample_like(d2, x) 445 | 446 | d3 = self.side3(hx3d) 447 | d3 = _upsample_like(d3, x) 448 | 449 | d4 = self.side4(hx4d) 450 | d4 = _upsample_like(d4, x) 451 | 452 | d5 = self.side5(hx5d) 453 | d5 = _upsample_like(d5, x) 454 | 455 | d6 = self.side6(hx6) 456 | d6 = _upsample_like(d6, x) 457 | 458 | return [ 459 | F.sigmoid(d1), 460 | F.sigmoid(d2), 461 | F.sigmoid(d3), 462 | F.sigmoid(d4), 463 | F.sigmoid(d5), 464 | F.sigmoid(d6), 465 | ], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6] 466 | 467 | 468 | def preprocess_image(im, model_input_size: list) -> torch.Tensor: 469 | if len(im.shape) < 
3: 470 | im = im[:, :, np.newaxis] 471 | # orig_im_size=im.shape[0:2] 472 | im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1) 473 | im_tensor = F.interpolate( 474 | torch.unsqueeze(im_tensor, 0), size=model_input_size, mode="bilinear" 475 | ) 476 | image = torch.divide(im_tensor, 255.0) 477 | image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0]) 478 | return image 479 | 480 | 481 | def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray: 482 | result = torch.squeeze(F.interpolate(result, size=im_size, mode="bilinear"), 0) 483 | ma = torch.max(result) 484 | mi = torch.min(result) 485 | result = (result - mi) / (ma - mi) 486 | im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8) 487 | im_array = np.squeeze(im_array) 488 | return im_array 489 | -------------------------------------------------------------------------------- /meshgen/utils/ops.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from PIL import Image 3 | import math 4 | import numpy as np 5 | import torch 6 | import kaolin 7 | import pyrender 8 | from einops import rearrange 9 | import numbers 10 | import torch.nn.functional as F 11 | import pyvista as pv 12 | import pyacvd 13 | from einops import repeat 14 | import itertools 15 | from tqdm import tqdm 16 | 17 | import pyfqmr 18 | 19 | from meshgen.utils.birefnet import run_model as rembg_birefnet 20 | 21 | 22 | def sample_on_surface(vertices, faces, num_samples): 23 | """ 24 | sample on surface for a single mesh 25 | """ 26 | input_np = False 27 | if isinstance(vertices, np.ndarray): 28 | input_np = True 29 | vertices = torch.from_numpy(vertices) 30 | 31 | samples = kaolin.ops.mesh.sample_points(vertices[None], faces, num_samples)[0][0] 32 | 33 | if input_np: 34 | samples = samples.cpu().numpy() 35 | 36 | return samples 37 | 38 | 39 | # def fps_sample(vertices, num_samples): 40 | # return kaolin.ops.mesh.farthest_point_sample(vertices[None], num_samples)[0] 41 | 42 | 43 | def generate_planes(): 44 | """ 45 | Defines planes by the three vectors that form the "axes" of the 46 | plane. Should work with arbitrary number of planes and planes of 47 | arbitrary orientation. 48 | 49 | Bugfix reference: https://github.com/NVlabs/eg3d/issues/67 50 | """ 51 | return torch.tensor( 52 | [ 53 | [[1, 0, 0], [0, 1, 0], [0, 0, 1]], 54 | [[1, 0, 0], [0, 0, 1], [0, 1, 0]], 55 | [[0, 0, 1], [0, 1, 0], [1, 0, 0]], 56 | ], 57 | dtype=torch.float32, 58 | ) 59 | 60 | 61 | def project_onto_planes(planes, coordinates): 62 | """ 63 | Does a projection of a 3D point onto a batch of 2D planes, 64 | returning 2D plane coordinates.
65 | 66 | Takes plane axes of shape n_planes, 3, 3 67 | # Takes coordinates of shape N, M, 3 68 | # returns projections of shape N*n_planes, M, 2 69 | """ 70 | N, M, C = coordinates.shape 71 | n_planes, _, _ = planes.shape 72 | coordinates = ( 73 | coordinates.unsqueeze(1) 74 | .expand(-1, n_planes, -1, -1) 75 | .reshape(N * n_planes, M, 3) 76 | ) 77 | inv_planes = ( 78 | torch.linalg.inv(planes) 79 | .unsqueeze(0) 80 | .expand(N, -1, -1, -1) 81 | .reshape(N * n_planes, 3, 3) 82 | ) 83 | projections = torch.bmm(coordinates, inv_planes) 84 | return projections[..., :2] 85 | 86 | 87 | def sample_from_planes( 88 | plane_axes, 89 | plane_features, 90 | coordinates, 91 | mode="bilinear", 92 | padding_mode="zeros", 93 | box_warp=None, 94 | ): 95 | assert padding_mode == "zeros" 96 | N, n_planes, C, H, W = plane_features.shape 97 | _, M, _ = coordinates.shape 98 | plane_features = plane_features.reshape(N * n_planes, C, H, W) 99 | dtype = plane_features.dtype 100 | 101 | coordinates = (2 / box_warp) * coordinates # add specific box bounds 102 | 103 | projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1) 104 | with torch.autocast("cuda", enabled=False): 105 | output_features = ( 106 | torch.nn.functional.grid_sample( 107 | plane_features.float(), 108 | projected_coordinates.float(), 109 | mode=mode, 110 | padding_mode=padding_mode, 111 | align_corners=False, 112 | ) 113 | .permute(0, 3, 2, 1) 114 | .reshape(N, n_planes, M, C) 115 | ) 116 | return output_features 117 | 118 | 119 | def logit_normal(mu, sigma, shape, device, dtype): 120 | z = torch.randn(*shape, device=device, dtype=dtype) 121 | z = mu + sigma * z 122 | t = torch.sigmoid(z) 123 | return t 124 | 125 | 126 | def checkpoint(func, inputs, params, flag): 127 | """ 128 | Evaluate a function without caching intermediate activations, allowing for 129 | reduced memory at the expense of extra compute in the backward pass. 130 | :param func: the function to evaluate. 131 | :param inputs: the argument sequence to pass to `func`. 132 | :param params: a sequence of parameters `func` depends on but does not 133 | explicitly take as arguments. 134 | :param flag: if False, disable gradient checkpointing. 135 | """ 136 | if flag: 137 | args = tuple(inputs) + tuple(params) 138 | return CheckpointFunction.apply(func, len(inputs), *args) 139 | else: 140 | return func(*inputs) 141 | 142 | 143 | class CheckpointFunction(torch.autograd.Function): 144 | @staticmethod 145 | @torch.cuda.amp.custom_fwd 146 | def forward(ctx, run_function, length, *args): 147 | ctx.run_function = run_function 148 | ctx.input_tensors = list(args[:length]) 149 | ctx.input_params = list(args[length:]) 150 | 151 | with torch.no_grad(): 152 | output_tensors = ctx.run_function(*ctx.input_tensors) 153 | return output_tensors 154 | 155 | @staticmethod 156 | @torch.cuda.amp.custom_bwd 157 | def backward(ctx, *output_grads): 158 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 159 | with torch.enable_grad(): 160 | # Fixes a bug where the first op in run_function modifies the 161 | # Tensor storage in place, which is not allowed for detach()'d 162 | # Tensors. 
163 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 164 | output_tensors = ctx.run_function(*shallow_copies) 165 | input_grads = torch.autograd.grad( 166 | output_tensors, 167 | ctx.input_tensors + ctx.input_params, 168 | output_grads, 169 | allow_unused=True, 170 | ) 171 | del ctx.input_tensors 172 | del ctx.input_params 173 | del output_tensors 174 | return (None, None) + input_grads 175 | 176 | 177 | def dot(x, y): 178 | return torch.sum(x * y, -1, keepdim=True) 179 | 180 | 181 | def length(x, eps=1e-20): 182 | return torch.sqrt(torch.clamp(dot(x, x), min=eps)) 183 | 184 | 185 | def safe_normalize(x, eps=1e-20): 186 | return x / length(x, eps) 187 | 188 | 189 | def get_projection_matrix(fov, reso, flip_y=False): 190 | # flip_y is used in nvdiffrast 191 | fov = np.deg2rad(fov) 192 | cam = pyrender.PerspectiveCamera(yfov=fov) 193 | 194 | proj_mat = cam.get_projection_matrix(reso, reso) 195 | if flip_y: 196 | proj_mat[1, :] *= -1 197 | 198 | return proj_mat 199 | 200 | 201 | def calc_normal(v, f): 202 | i0, i1, i2 = ( 203 | f[:, 0].long(), 204 | f[:, 1].long(), 205 | f[:, 2].long(), 206 | ) 207 | v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :] 208 | 209 | face_normals = torch.cross(v1 - v0, v2 - v0) 210 | face_normals = safe_normalize(face_normals) 211 | 212 | vn = torch.zeros_like(v) 213 | vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals) 214 | vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals) 215 | vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals) 216 | 217 | vn = torch.where( 218 | torch.sum(vn * vn, -1, keepdim=True) > 1e-20, 219 | vn, 220 | torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device), 221 | ) 222 | 223 | return vn 224 | 225 | 226 | def compute_tv_loss(planes): 227 | # planes are with shape [b, 3, c, h, w] 228 | planes = rearrange(planes, "b n c h w -> (b n) c h w") 229 | batch_size, c, h, w = planes.shape 230 | count_h = batch_size * c * (h - 1) * w 231 | count_w = batch_size * c * h * (w - 1) 232 | h_tv = torch.square(planes[..., 1:, :] - planes[..., : h - 1, :]).sum() 233 | w_tv = torch.square(planes[..., :, 1:] - planes[..., :, : w - 1]).sum() 234 | return 2 * ( 235 | h_tv / count_h + w_tv / count_w 236 | ) # This is summing over batch and c instead of avg 237 | 238 | 239 | class GaussianSmoothing(torch.nn.Module): 240 | """ 241 | Apply gaussian smoothing on a 242 | 1d, 2d or 3d tensor. Filtering is performed seperately for each channel 243 | in the input using a depthwise convolution. 244 | Arguments: 245 | channels (int, sequence): 246 | Number of channels of the input tensors. 247 | Output will have this number of channels as well. 248 | kernel_size (int, sequence): 249 | Size of the gaussian kernel. 250 | sigma (float, sequence): 251 | Standard deviation of the gaussian kernel. 252 | dim (int, optional): 253 | The number of dimensions of the data. 254 | Default value is 2 (spatial). 255 | stride (int, sequence, optional): 256 | Stride for Conv module. 257 | padding (int, sequence, optional): 258 | Padding for Conv module. 259 | padding_mode (str, optional): 260 | Padding mode for Conv module. 
261 | """ 262 | 263 | def __init__(self, channels, kernel_size, sigma, dim=2, stride=1, padding=0): 264 | super(GaussianSmoothing, self).__init__() 265 | if isinstance(kernel_size, numbers.Number): 266 | kernel_size = [kernel_size] * dim 267 | if isinstance(sigma, numbers.Number): 268 | sigma = [sigma] * dim 269 | 270 | # Used for Conv module 271 | self.stride = stride 272 | self.padding = padding 273 | 274 | # The gaussian kernel is the product of the 275 | # gaussian function of each dimension. 276 | kernel = 1 277 | meshgrids = torch.meshgrid( 278 | [torch.arange(size, dtype=torch.float32) for size in kernel_size], 279 | indexing="ij", 280 | ) 281 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids): 282 | mean = (size - 1) / 2 283 | kernel *= ( 284 | 1 285 | / (std * math.sqrt(2 * math.pi)) 286 | * torch.exp(-(((mgrid - mean) / std) ** 2) / 2) 287 | ) 288 | 289 | # Make sure sum of values in gaussian kernel equals 1. 290 | kernel = kernel / torch.sum(kernel) 291 | 292 | # Reshape to depthwise convolutional weight 293 | kernel = kernel.view(1, 1, *kernel.size()) 294 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) 295 | 296 | self.register_buffer("weight", kernel) 297 | self.groups = channels 298 | 299 | if dim == 1: 300 | self.conv = F.conv1d 301 | elif dim == 2: 302 | self.conv = F.conv2d 303 | elif dim == 3: 304 | self.conv = F.conv3d 305 | else: 306 | raise RuntimeError( 307 | f"Only 1, 2 and 3 dimensions are supported. Received {dim}." 308 | ) 309 | 310 | def forward(self, input): 311 | """ 312 | Apply gaussian filter to input. 313 | Arguments: 314 | input (torch.Tensor): Input to apply gaussian filter on. 315 | Returns: 316 | filtered (torch.Tensor): Filtered output. 317 | """ 318 | return self.conv( 319 | input, 320 | weight=self.weight, 321 | groups=self.groups, 322 | stride=self.stride, 323 | padding=self.padding, 324 | ) 325 | 326 | 327 | def preprocess_ip(ip, ref_mask, ref_image, ignore_alpha=True): 328 | ref_h, ref_w = ref_mask.shape 329 | # if ip.mode != "RGBA": 330 | # ip = ip.convert("RGB") 331 | # ip = remove(ip, alpha_matting=False, session=get_rembg_session()) 332 | if ip.mode != "RGBA" or ignore_alpha: 333 | ip = ip.convert("RGB") 334 | ip = rembg_birefnet(ip) 335 | # ip = remove(ip, alpha_matting=False, session=get_rembg_session()) 336 | # blend with white background 337 | ip = np.asarray(ip) 338 | ip_mask = ip[:, :, 3].astype(np.float32) / 255 339 | ip_rgb = ip[:, :, :3].astype(np.float32) / 255 340 | ip_rgb = ip_rgb * ip_mask[:, :, None].astype(np.float32) + ( 341 | 1 - ip_mask[:, :, None].astype(np.float32) 342 | ) 343 | ip = Image.fromarray((ip_rgb * 255).astype(np.uint8)) 344 | 345 | coords = np.nonzero(ip_mask) 346 | y_min, y_max = coords[0].min(), coords[0].max() 347 | x_min, x_max = coords[1].min(), coords[1].max() 348 | len_x, len_y = x_max - x_min, y_max - y_min 349 | center = (x_min + x_max) // 2, (y_min + y_max) // 2 350 | 351 | ref_coords = np.nonzero(ref_mask) 352 | ref_y_min, ref_y_max = ref_coords[0].min(), ref_coords[0].max() 353 | ref_x_min, ref_x_max = ref_coords[1].min(), ref_coords[1].max() 354 | ref_len_x, ref_len_y = ref_x_max - ref_x_min, ref_y_max - ref_y_min 355 | ref_center = (ref_x_min + ref_x_max) // 2, (ref_y_min + ref_y_max) // 2 356 | 357 | ip = ip.crop((x_min, y_min, x_max, y_max)).resize((ref_len_x, ref_len_y)) 358 | new_ip = Image.new("RGB", (ref_h, ref_w), (255, 255, 255)) 359 | new_ip.paste(ip, (ref_x_min, ref_y_min, ref_x_max, ref_y_max)) 360 | 361 | return new_ip 362 | 363 | 364 | def mesh_simplification(v, 
f, target=50000, backend="pyacvd"): 365 | is_torch = False 366 | if isinstance(v, torch.Tensor): 367 | is_torch = True 368 | device = v.device 369 | dtype = v.dtype 370 | v = v.detach().cpu().numpy() 371 | f = f.detach().cpu().numpy() 372 | 373 | if backend == "pyfqmr": 374 | mesh_simplifier = pyfqmr.Simplify() 375 | mesh_simplifier.setMesh(v, f) 376 | mesh_simplifier.simplify_mesh( 377 | target_count=target, 378 | aggressiveness=7, 379 | preserve_border=True, 380 | ) 381 | ret = mesh_simplifier.getMesh()[:2] 382 | elif backend == "pyacvd": 383 | cells = np.zeros((f.shape[0], 4), dtype=int) 384 | cells[:, 1:] = f 385 | cells[:, 0] = 3 386 | mesh = pv.PolyData(v, cells) 387 | clus = pyacvd.Clustering(mesh) 388 | clus.cluster(target) 389 | remesh = clus.create_mesh() 390 | 391 | vertices = remesh.points 392 | faces = remesh.faces.reshape(-1, 4)[:, 1:] 393 | ret = [vertices, faces] 394 | else: 395 | raise ValueError(f"Backend {backend} not implemented in mesh simplification") 396 | 397 | if is_torch: 398 | ret = list(map(lambda x: torch.from_numpy(x).to(device=device), ret)) 399 | 400 | return ret 401 | 402 | 403 | def dilate_depth_outline(depth, iters=5, dilate_kernel=3): 404 | # ori_img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2GRAY) 405 | 406 | img = np.asarray(depth) 407 | for i in range(iters): 408 | _, mask = cv2.threshold(img, thresh=0, maxval=255, type=cv2.THRESH_BINARY) 409 | mask = cv2.GaussianBlur(mask, (3, 3), 0) 410 | mask = cv2.erode(mask, np.ones((3, 3), np.uint8)) 411 | mask = mask / 255 412 | 413 | img_dilate = cv2.dilate(img, np.ones((dilate_kernel, dilate_kernel), np.uint8)) 414 | 415 | img = (mask * img + (1 - mask) * img_dilate).astype(np.uint8) 416 | return Image.fromarray(img) 417 | 418 | 419 | def dilate_mask(mask, dilate_kernel=10): 420 | mask = np.asarray(mask) 421 | mask = cv2.dilate(mask, np.ones((dilate_kernel, dilate_kernel), np.uint8)) 422 | 423 | return Image.fromarray(mask) 424 | 425 | 426 | def extract_bg_mask(img, mask_color=[204, 25, 204], dilate_kernel=5): 427 | """ 428 | :param mask_color: BGR 429 | :return: 430 | """ 431 | img = np.asarray(img) 432 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 433 | 434 | mask = (img == mask_color).all(axis=2).astype(np.float32) 435 | mask = mask[:, :, np.newaxis] 436 | 437 | mask = cv2.dilate(mask, np.ones((dilate_kernel, dilate_kernel), np.uint8))[ 438 | :, :, np.newaxis 439 | ] 440 | mask = (mask * 255).astype(np.uint8) 441 | mask = repeat(mask, "h w 1 -> h w c", c=3) 442 | return Image.fromarray(mask) 443 | 444 | 445 | @torch.no_grad() 446 | def dilate_mask(mask, kernel_size=10, format="hwc"): 447 | is_torch = False 448 | is_pil = False 449 | if format == "chw": 450 | mask = rearrange(mask, "c h w -> h w c") 451 | if isinstance(mask, torch.Tensor): 452 | is_torch = True 453 | dtype = mask.dtype 454 | device = mask.device 455 | mask = mask.detach().cpu().numpy() 456 | if isinstance(mask, Image.Image): 457 | is_pil = True 458 | mask = np.asarray(mask) / 255 459 | mask = cv2.dilate(mask, np.ones((kernel_size, kernel_size))) 460 | 461 | if is_torch: 462 | mask = torch.from_numpy(mask).to(device=device, dtype=dtype) 463 | 464 | if is_pil: 465 | mask = Image.fromarray((mask * 255).astype(np.uint8)) 466 | 467 | if format == "chw": 468 | if mask.dim() == 2: 469 | mask = rearrange(mask, "h w -> 1 h w") 470 | else: 471 | mask = rearrange(mask, "h w c -> c h w") 472 | 473 | return mask 474 | 475 | 476 | @torch.no_grad() 477 | def latent_preview(x): 478 | # adapted from 
https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7 479 | v1_4_latent_rgb_factors = torch.tensor( 480 | [ 481 | # R G B 482 | [0.298, 0.207, 0.208], # L1 483 | [0.187, 0.286, 0.173], # L2 484 | [-0.158, 0.189, 0.264], # L3 485 | [-0.184, -0.271, -0.473], # L4 486 | ], 487 | dtype=x.dtype, 488 | device=x.device, 489 | ) 490 | image = x.permute(0, 2, 3, 1) @ v1_4_latent_rgb_factors 491 | image = (image / 2 + 0.5).clamp(0, 1) 492 | image = image.float() 493 | image = image.cpu() 494 | image = image.numpy() 495 | return image 496 | 497 | 498 | @torch.no_grad() 499 | def get_canny_edge(image, threshold1=100, threshold2=200): 500 | image = np.asarray(image) 501 | gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 502 | edges = cv2.Canny(gray, threshold1, threshold2) 503 | return Image.fromarray(edges) 504 | 505 | 506 | def uv_padding(image, hole_mask, padding=2, uv_padding_block=4): 507 | uv_padding_size = padding 508 | image1 = (image[0].detach().cpu().numpy() * 255).astype(np.uint8) 509 | hole_mask = (hole_mask[0].detach().cpu().numpy() * 255).astype(np.uint8) 510 | block = uv_padding_block 511 | res = image1.shape[0] 512 | chunk = res // block 513 | inpaint_image = np.zeros_like(image1) 514 | prods = list(itertools.product(range(block), range(block))) 515 | for i, j in tqdm(prods): 516 | patch = cv2.inpaint( 517 | image1[i * chunk : (i + 1) * chunk, j * chunk : (j + 1) * chunk], 518 | hole_mask[i * chunk : (i + 1) * chunk, j * chunk : (j + 1) * chunk], 519 | uv_padding_size, 520 | cv2.INPAINT_TELEA, 521 | ) 522 | inpaint_image[i * chunk : (i + 1) * chunk, j * chunk : (j + 1) * chunk] = patch 523 | inpaint_image = inpaint_image / 255.0 524 | return torch.from_numpy(inpaint_image).to(image) 525 | --------------------------------------------------------------------------------
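The triplane helpers in meshgen/utils/ops.py (generate_planes, project_onto_planes, sample_from_planes) follow the EG3D convention referenced in the code: each 3D query point is projected onto three axis-aligned feature planes, and the corresponding feature maps are sampled bilinearly. A minimal usage sketch follows; the tensor sizes and the box_warp=2.0 bound are illustrative assumptions, not values taken from this repository's configs.

from meshgen.utils.ops import generate_planes, sample_from_planes
import torch

# Illustrative shapes only: a batch of 2 triplanes with 32 feature channels,
# 64x64 planes, and 1024 query points per sample.
plane_axes = generate_planes()                      # (3, 3, 3) plane bases
plane_features = torch.randn(2, 3, 32, 64, 64)      # (N, n_planes, C, H, W)
coordinates = torch.rand(2, 1024, 3) * 2 - 1        # (N, M, 3) points in [-1, 1]

# Points are projected onto each plane and the feature maps are sampled
# bilinearly; box_warp rescales the query box (2.0 keeps [-1, 1] unchanged).
feats = sample_from_planes(plane_axes, plane_features, coordinates, box_warp=2.0)
print(feats.shape)                                  # torch.Size([2, 3, 1024, 32])

The result carries one feature vector per plane per query point; a downstream decoder would typically aggregate the three plane features (for example by summation or averaging) before passing them to an MLP.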