├── core ├── __init__.py ├── utils.py ├── options.py ├── options_pm.py ├── attention.py ├── drag_embedding.py ├── models.py └── gs.py ├── preprocess ├── src │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ └── objaverse_zero123plus.py │ ├── models │ │ ├── __init__.py │ │ ├── decoder │ │ │ ├── __init__.py │ │ │ └── transformer.py │ │ ├── encoder │ │ │ ├── __init__.py │ │ │ └── dino_wrapper.py │ │ ├── geometry │ │ │ ├── render │ │ │ │ ├── __init__.py │ │ │ │ └── neural_render.py │ │ │ ├── __init__.py │ │ │ ├── camera │ │ │ │ ├── __init__.py │ │ │ │ └── perspective_camera.py │ │ │ └── rep_3d │ │ │ │ ├── __init__.py │ │ │ │ ├── dmtet_utils.py │ │ │ │ ├── extract_texture_map.py │ │ │ │ └── flexicubes_geometry.py │ │ ├── renderer │ │ │ ├── __init__.py │ │ │ ├── utils │ │ │ │ ├── __init__.py │ │ │ │ ├── ray_marcher.py │ │ │ │ ├── math_utils.py │ │ │ │ └── ray_sampler.py │ │ │ ├── synthesizer_mesh.py │ │ │ └── synthesizer.py │ │ └── lrm.py │ └── utils │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-310.pyc │ │ ├── train_util.cpython-310.pyc │ │ └── train_util.cpython-38.pyc │ │ ├── train_util.py │ │ ├── infer_util.py │ │ ├── camera_util.py │ │ └── mesh_util.py ├── gen_rgba.py ├── gen_mv_partdrag4d.py ├── gen_mv_objaverse_hq.py ├── gen_propagate_drags.py └── zero123plus │ └── model.py ├── images └── teaser.png ├── PartDrag4D ├── rendering │ ├── render.sh │ ├── gen_filelist.py │ └── distributed.py └── z_buffer_al.py ├── acc_configs ├── gpu1.yaml ├── gpu4.yaml ├── gpu6.yaml └── gpu8.yaml ├── environment.yaml ├── eval.py ├── compute_metrics.py ├── train.py ├── README.md └── filelist ├── val_filelist_objaverse_hq.txt └── zero123_val_filelist_objavser_hq.txt /core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /preprocess/src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /preprocess/src/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /preprocess/src/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /preprocess/src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /preprocess/src/models/decoder/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /preprocess/src/models/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /images/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GasaiYU/PartRM/HEAD/images/teaser.png -------------------------------------------------------------------------------- /preprocess/src/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/GasaiYU/PartRM/HEAD/preprocess/src/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /preprocess/src/utils/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GasaiYU/PartRM/HEAD/preprocess/src/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /preprocess/src/utils/__pycache__/train_util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GasaiYU/PartRM/HEAD/preprocess/src/utils/__pycache__/train_util.cpython-310.pyc -------------------------------------------------------------------------------- /preprocess/src/utils/__pycache__/train_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GasaiYU/PartRM/HEAD/preprocess/src/utils/__pycache__/train_util.cpython-38.pyc -------------------------------------------------------------------------------- /preprocess/src/models/geometry/render/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class Renderer(): 4 | def __init__(self): 5 | pass 6 | 7 | def forward(self): 8 | pass -------------------------------------------------------------------------------- /PartDrag4D/rendering/render.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 \ 2 | python distributed.py \ 3 | --num_gpus 1 \ 4 | --workers_per_gpu 4 \ 5 | --view_path_root ../data/render_PartDrag4D \ 6 | --blender_path ./blender/blender-3.5.0-linux-x64 \ 7 | --input_models_path ../filelist/rendering.txt \ 8 | --num_images 12 9 | -------------------------------------------------------------------------------- /acc_configs/gpu1.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: 'NO' 4 | downcast_bf16: 'no' 5 | machine_rank: 0 6 | main_training_function: main 7 | mixed_precision: bf16 8 | num_machines: 1 9 | num_processes: 1 10 | rdzv_backend: static 11 | same_network: true 12 | tpu_env: [] 13 | tpu_use_cluster: false 14 | tpu_use_sudo: false 15 | use_cpu: false 16 | -------------------------------------------------------------------------------- /acc_configs/gpu4.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | machine_rank: 0 6 | main_training_function: main 7 | mixed_precision: fp16 8 | num_machines: 1 9 | num_processes: 4 10 | rdzv_backend: static 11 | same_network: true 12 | tpu_env: [] 13 | tpu_use_cluster: false 14 | tpu_use_sudo: false 15 | use_cpu: false 16 | -------------------------------------------------------------------------------- /acc_configs/gpu6.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | machine_rank: 0 6 | main_training_function: main 7 | mixed_precision: fp16 8 | num_machines: 1 9 | num_processes: 6 10 | rdzv_backend: static 11 | same_network: true 12 | tpu_env: [] 13 | tpu_use_cluster: 
false 14 | tpu_use_sudo: false 15 | use_cpu: false 16 | -------------------------------------------------------------------------------- /acc_configs/gpu8.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | machine_rank: 0 6 | main_training_function: main 7 | mixed_precision: bf16 8 | num_machines: 1 9 | num_processes: 8 10 | rdzv_backend: static 11 | same_network: true 12 | tpu_env: [] 13 | tpu_use_cluster: false 14 | tpu_use_sudo: false 15 | use_cpu: false 16 | -------------------------------------------------------------------------------- /preprocess/src/models/geometry/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | -------------------------------------------------------------------------------- /preprocess/src/models/renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | -------------------------------------------------------------------------------- /preprocess/src/models/renderer/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | -------------------------------------------------------------------------------- /preprocess/src/models/geometry/camera/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | 9 | import torch 10 | from torch import nn 11 | 12 | 13 | class Camera(nn.Module): 14 | def __init__(self): 15 | super(Camera, self).__init__() 16 | pass 17 | -------------------------------------------------------------------------------- /preprocess/src/models/geometry/rep_3d/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | 9 | import torch 10 | import numpy as np 11 | 12 | 13 | class Geometry(): 14 | def __init__(self): 15 | pass 16 | 17 | def forward(self): 18 | pass 19 | -------------------------------------------------------------------------------- /PartDrag4D/rendering/gen_filelist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | base = '../data/processed_data_partdrag4d' 5 | 6 | 7 | if __name__ == '__main__': 8 | classes = ['Dishwasher', 'Laptop', 'Microwave', 'Oven', 9 | 'Refrigerator', 'StorageFurniture', 'WashingMachine', 'TrashCan'] 10 | 11 | with open('../filelist/rendering.txt', 'w') as f: 12 | for class_name in classes: 13 | path = os.path.join(base, class_name) 14 | for root, dirs, files in os.walk(path): 15 | for file in files: 16 | if 'motion' in root: 17 | item_name = os.path.normpath(root).split(os.sep)[-2] 18 | item_idx = re.search(r'\d+', item_name).group() + '_' + file[0] 19 | if file.endswith('.obj'): 20 | f.write(os.path.join(root, file) + '\n') 21 | print('done') -------------------------------------------------------------------------------- /preprocess/src/utils/train_util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | 4 | def count_params(model, verbose=False): 5 | total_params = sum(p.numel() for p in model.parameters()) 6 | if verbose: 7 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 8 | return total_params 9 | 10 | 11 | def instantiate_from_config(config): 12 | if not "target" in config: 13 | if config == '__is_first_stage__': 14 | return None 15 | elif config == "__is_unconditional__": 16 | return None 17 | raise KeyError("Expected key `target` to instantiate.") 18 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 19 | 20 | 21 | def get_obj_from_str(string, reload=False): 22 | module, cls = string.rsplit(".", 1) 23 | if reload: 24 | module_imp = importlib.import_module(module) 25 | importlib.reload(module_imp) 26 | return getattr(importlib.import_module(module, package=None), cls) 27 | -------------------------------------------------------------------------------- /preprocess/src/models/geometry/rep_3d/dmtet_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | 9 | import torch 10 | 11 | 12 | def get_center_boundary_index(verts): 13 | length_ = torch.sum(verts ** 2, dim=-1) 14 | center_idx = torch.argmin(length_) 15 | boundary_neg = verts == verts.max() 16 | boundary_pos = verts == verts.min() 17 | boundary = torch.bitwise_or(boundary_pos, boundary_neg) 18 | boundary = torch.sum(boundary.float(), dim=-1) 19 | boundary_idx = torch.nonzero(boundary) 20 | return center_idx, boundary_idx.squeeze(dim=-1) 21 | -------------------------------------------------------------------------------- /preprocess/src/models/geometry/camera/perspective_camera.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | 9 | import torch 10 | from . import Camera 11 | import numpy as np 12 | 13 | 14 | def projection(x=0.1, n=1.0, f=50.0, near_plane=None): 15 | if near_plane is None: 16 | near_plane = n 17 | return np.array( 18 | [[n / x, 0, 0, 0], 19 | [0, n / -x, 0, 0], 20 | [0, 0, -(f + near_plane) / (f - near_plane), -(2 * f * near_plane) / (f - near_plane)], 21 | [0, 0, -1, 0]]).astype(np.float32) 22 | 23 | 24 | class PerspectiveCamera(Camera): 25 | def __init__(self, fovy=49.0, device='cuda'): 26 | super(PerspectiveCamera, self).__init__() 27 | self.device = device 28 | focal = np.tan(fovy / 180.0 * np.pi * 0.5) 29 | self.proj_mtx = torch.from_numpy(projection(x=focal, f=1000.0, n=1.0, near_plane=0.1)).to(self.device).unsqueeze(dim=0) 30 | 31 | def project(self, points_bxnx4): 32 | out = torch.matmul( 33 | points_bxnx4, 34 | torch.transpose(self.proj_mtx, 1, 2)) 35 | return out 36 | -------------------------------------------------------------------------------- /preprocess/src/models/geometry/rep_3d/extract_texture_map.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 
8 | 9 | import torch 10 | import xatlas 11 | import numpy as np 12 | import nvdiffrast.torch as dr 13 | 14 | 15 | # ============================================================================================== 16 | def interpolate(attr, rast, attr_idx, rast_db=None): 17 | return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all') 18 | 19 | 20 | def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution): 21 | vmapping, indices, uvs = xatlas.parametrize(mesh_v.detach().cpu().numpy(), mesh_pos_idx.detach().cpu().numpy()) 22 | 23 | # Convert to tensors 24 | indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64) 25 | 26 | uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device) 27 | mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device) 28 | # mesh_v_tex. ture 29 | uv_clip = uvs[None, ...] * 2.0 - 1.0 30 | 31 | # pad to four component coordinate 32 | uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1) 33 | 34 | # rasterize 35 | rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution)) 36 | 37 | # Interpolate world space position 38 | gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int()) 39 | mask = rast[..., 3:4] > 0 40 | return uvs, mesh_tex_idx, gb_pos, mask 41 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: partrm 2 | channels: 3 | - defaults 4 | - conda-forge 5 | dependencies: 6 | - cuda-version=12.6 7 | - omegaconf=2.3.0 8 | - openssl=3.3.2 9 | - python=3.8.18 10 | - pyyaml=6.0.2 11 | - yaml=0.2.5 12 | - pip=20.3.3 13 | - pip: 14 | - absl-py==2.1.0 15 | - basicsr==1.4.2 16 | - coloredlogs==15.0.1 17 | - configargparse==1.7 18 | - contourpy==1.1.1 19 | - cycler==0.12.1 20 | - dash==2.18.1 21 | - decorator==5.1.1 22 | - deprecated==1.2.14 23 | - diffusers==0.30.0 24 | - dlib==19.24.6 25 | - einops==0.8.0 26 | - future==1.0.0 27 | - gradio==4.44.0 28 | - gradio-client==1.3.0 29 | - grpcio==1.66.1 30 | - huggingface-hub==0.24.7 31 | - imageio==2.35.1 32 | - imageio-ffmpeg==0.5.1 33 | - ipython==8.12.3 34 | - jinja2==3.1.4 35 | - joblib==1.4.2 36 | - kiui==0.2.13 37 | - lightning-utilities==0.11.8 38 | - llvmlite==0.41.1 39 | - lmdb==1.5.1 40 | - lpips==0.1.4 41 | - matplotlib==3.7.5 42 | - networkx==3.0 43 | - numba==0.58.1 44 | - numpy==1.24.4 45 | - onnxruntime==1.19.2 46 | - onnxruntime-gpu==1.19.2 47 | - open3d==0.18.0 48 | - opencv-python==4.10.0.84 49 | - packaging==24.1 50 | - pandas==2.0.3 51 | - pillow==10.2.0 52 | - pip==24.3.1 53 | - plyfile==1.0.3 54 | - protobuf==5.28.1 55 | - psutil==6.0.0 56 | - pytorch-lightning==2.4.0 57 | - pytorch-msssim==1.0.0 58 | - rembg==2.0.59 59 | - requests==2.32.3 60 | - safetensors==0.4.5 61 | - scikit-image==0.21.0 62 | - scikit-learn==1.3.2 63 | - scipy==1.10.1 64 | - segment-anything==1.0 65 | - semantic-version==2.10.0 66 | - setuptools==75.3.0 67 | - shellingham==1.5.4 68 | - spaces==0.30.4 69 | - starlette==0.38.5 70 | - sympy==1.12 71 | - tensorboard==2.12.0 72 | - tensorboardx==2.6.2.2 73 | - tokenizers==0.19.1 74 | - --extra-index-url https://download.pytorch.org/whl/cu121 75 | - torch==2.4.1 76 | - torchvision==0.19.1 77 | - torchmetrics==1.5.0 78 | - tqdm==4.66.5 79 | - transformers==4.44.2 80 | - trimesh==4.4.9 81 | - typing-extensions==4.9.0 82 | - xformers==0.0.28 
-------------------------------------------------------------------------------- /preprocess/gen_rgba.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | 5 | import argparse 6 | 7 | def gen_rgba(image_path): 8 | """ 9 | Generate rgba images from rgb images which are with white background. 10 | Args: 11 | image_path (str): Path to the rgb image. 12 | """ 13 | image = Image.open(image_path) 14 | rgba_image = image.convert('RGBA') 15 | pixels = rgba_image.load() 16 | 17 | for y in range(rgba_image.height): 18 | for x in range(rgba_image.width): 19 | r, g, b, a = pixels[x, y] 20 | 21 | if r >= 245 and g >= 245 and b >= 245: 22 | pixels[x, y] = (r, g, b, 0) 23 | 24 | rgba_image = rgba_image.resize((512, 512)) 25 | rgba_image.save(image_path.replace('.png', '_rgba.png')) 26 | 27 | def gen_rgba_from_filelist(file_list_path, dataset_name): 28 | """ 29 | Generate rgba images from rgb images which are with white background. 30 | Args: 31 | file_list_path (str): Path to the file list. 32 | """ 33 | with open(file_list_path, 'r') as f: 34 | file_list = f.readlines() 35 | 36 | if dataset_name == 'partdrag4d': 37 | for file_name in file_list: 38 | file_name = file_name.strip() 39 | for image_path in os.listdir(file_name): 40 | if image_path.endswith('.png') and not image_path.endswith('_rgba.png') and not image_path.startswith('000'): 41 | image_path = os.path.join(file_name, image_path) 42 | gen_rgba(image_path) 43 | 44 | elif dataset_name == 'objaverse_hq': 45 | for action_dir in file_list: 46 | action_dir = action_dir.strip() 47 | for frame_dir in os.listdir(action_dir): 48 | frame_dir = os.path.join(action_dir, frame_dir) 49 | for image_path in os.listdir(frame_dir): 50 | if image_path.endswith('.png') and not image_path.endswith('_rgba.png') and not image_path.startswith('000'): 51 | image_path = os.path.join(frame_dir, image_path) 52 | gen_rgba(image_path) 53 | 54 | else: 55 | raise ValueError('Unknown dataset name.') 56 | 57 | if __name__ == '__main__': 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument('--filelist', type=str, help='path/to/your/zero123/filelist') 60 | parser.add_argument('--dataset', choices=['partdrag4d', 'objaverse_hq'], help='dataset name') 61 | args = parser.parse_args() 62 | 63 | gen_rgba_from_filelist(args.filelist, args.dataset) -------------------------------------------------------------------------------- /preprocess/gen_mv_partdrag4d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import requests 3 | from PIL import Image 4 | from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler 5 | from zero123plus.model import MVDiffusion 6 | 7 | from torchvision.utils import make_grid, save_image 8 | 9 | import torchvision.transforms as TF 10 | 11 | import os 12 | import numpy as np 13 | 14 | import argparse 15 | 16 | def unscale_image(image): 17 | image = image / 0.5 * 0.8 18 | return image 19 | 20 | def main(src_filelist, save_dir): 21 | os.makedirs(save_dir, exist_ok=True) 22 | # Load the pipeline 23 | mv_diffusion = MVDiffusion({'pretrained_model_name_or_path':"sudo-ai/zero123plus-v1.2", 24 | 'custom_pipeline':"./zero123plus"}) 25 | state_dict = torch.load('./zero123_ckpt/partdrag4d_zero123.ckpt')['state_dict'] 26 | 27 | mv_diffusion.load_state_dict(state_dict, strict=True) 28 | pipeline = mv_diffusion.pipeline 29 | pipeline = pipeline.to('cuda') 30 | 31 | val_image_paths = [] 32 | with 
open(src_filelist, 'r') as f: 33 | for line in f.readlines(): 34 | line = line.strip() 35 | image_path = f"{line}/000.png" 36 | val_image_paths.append(image_path) 37 | print(f"The length of the val images: {len(val_image_paths)}") 38 | 39 | for val_image_path in val_image_paths: 40 | print(f'Render {val_image_path}') 41 | render_id = val_image_path.split('/')[-2] 42 | 43 | cond = Image.open(val_image_path) 44 | os.makedirs(os.path.join(save_dir, render_id)) 45 | cond.save(os.path.join(save_dir, render_id, '000.png')) 46 | 47 | # Run the pipeline! 48 | with torch.no_grad(): 49 | latents = pipeline(cond, num_inference_steps=75, output_type='latent').images 50 | images = unscale_image(pipeline.vae.decode(latents / pipeline.vae.config.scaling_factor, return_dict=False)[0]) # [-1, 1] 51 | images = (images * 0.5 + 0.5).clamp(0, 1) 52 | 53 | resize = TF.Resize((512, 512)) 54 | images = images.squeeze(0) 55 | for i in range(6): 56 | row_idx = i % 2 57 | col_idx = i // 2 58 | image = images[:, col_idx*320:col_idx*320+320, row_idx*320:row_idx*320+320] 59 | image = resize(image) 60 | save_image(image, os.path.join(save_dir, render_id, f'{i+1:03d}.png')) 61 | 62 | 63 | if __name__ == "__main__": 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument('--src_filelist', default='../filelist/val_filelist_partdrag4d.txt') 66 | parser.add_argument('--output_dir', default='./zero123_preprocessed_data/PartDrag4D') 67 | args = parser.parse_args() 68 | 69 | main(args.src_filelist, args.output_dir) 70 | -------------------------------------------------------------------------------- /preprocess/gen_mv_objaverse_hq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import requests 3 | from PIL import Image 4 | from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler 5 | from zero123plus.model import MVDiffusion 6 | 7 | from torchvision.utils import make_grid, save_image 8 | 9 | import torchvision.transforms as TF 10 | 11 | import os 12 | import numpy as np 13 | 14 | import argparse 15 | 16 | def unscale_image(image): 17 | image = image / 0.5 * 0.8 18 | return image 19 | 20 | def main(src_filelist, save_dir): 21 | os.makedirs(save_dir, exist_ok=True) 22 | # Load the pipeline 23 | mv_diffusion = MVDiffusion({'pretrained_model_name_or_path':"sudo-ai/zero123plus-v1.2", 24 | 'custom_pipeline':"./zero123plus"}) 25 | state_dict = torch.load('./zero123_ckpt/objaverse_hq_zero123.ckpt')['state_dict'] 26 | 27 | mv_diffusion.load_state_dict(state_dict, strict=True) 28 | pipeline = mv_diffusion.pipeline 29 | pipeline = pipeline.to('cuda') 30 | 31 | val_image_paths = [] 32 | frame_ids = [] 33 | with open(src_filelist, 'r') as f: 34 | for line in f.readlines(): 35 | line = line.strip() 36 | frame_id = 0 37 | for image_path in os.listdir(line): 38 | if image_path.endswith('.png') and image_path.startswith('000'): 39 | val_image_paths.append(os.path.join(line, image_path)) 40 | frame_ids.append(frame_id) 41 | frame_id += 1 42 | 43 | print(f"The length of the val images: {len(val_image_paths)}") 44 | 45 | for image_cnt, val_image_path in enumerate(val_image_paths): 46 | frame_id = frame_ids[image_cnt] 47 | print(f'Render {val_image_path}') 48 | action_id = val_image_path.split('/')[-2] 49 | objaverse_id = val_image_path.split('/')[-3] 50 | 51 | cond = Image.open(val_image_path) 52 | os.makedirs(os.path.join(save_dir, objaverse_id, action_id,str(frame_id)), exist_ok=True) 53 | cond.save(os.path.join(save_dir, objaverse_id, action_id, str(frame_id), 
'000.png')) 54 | 55 | # Run the pipeline! 56 | with torch.no_grad(): 57 | latents = pipeline(cond, num_inference_steps=75, output_type='latent').images 58 | images = unscale_image(pipeline.vae.decode(latents / pipeline.vae.config.scaling_factor, return_dict=False)[0]) # [-1, 1] 59 | images = (images * 0.5 + 0.5).clamp(0, 1) 60 | 61 | resize = TF.Resize((512, 512)) 62 | images = images.squeeze(0) 63 | for i in range(6): 64 | row_idx = i % 2 65 | col_idx = i // 2 66 | image = images[:, col_idx*320:col_idx*320+320, row_idx*320:row_idx*320+320] 67 | image = resize(image) 68 | save_image(image, os.path.join(save_dir, objaverse_id, action_id, str(frame_id), f'{i+1:03d}.png')) 69 | 70 | 71 | if __name__ == "__main__": 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument('--src_filelist', default='../filelist/val_filelist_objaverse_hq.txt') 74 | parser.add_argument('--output_dir', default='./zero123_preprocessed_data/Objaverse_HQ') 75 | args = parser.parse_args() 76 | 77 | main(args.src_filelist, args.output_dir) -------------------------------------------------------------------------------- /preprocess/src/utils/infer_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imageio 3 | import rembg 4 | import torch 5 | import numpy as np 6 | import PIL.Image 7 | from PIL import Image 8 | from typing import Any 9 | 10 | 11 | def remove_background(image: PIL.Image.Image, 12 | rembg_session: Any = None, 13 | force: bool = False, 14 | **rembg_kwargs, 15 | ) -> PIL.Image.Image: 16 | do_remove = True 17 | if image.mode == "RGBA" and image.getextrema()[3][0] < 255: 18 | do_remove = False 19 | do_remove = do_remove or force 20 | if do_remove: 21 | image = rembg.remove(image, session=rembg_session, **rembg_kwargs) 22 | return image 23 | 24 | 25 | def resize_foreground( 26 | image: PIL.Image.Image, 27 | ratio: float, 28 | ) -> PIL.Image.Image: 29 | image = np.array(image) 30 | assert image.shape[-1] == 4 31 | alpha = np.where(image[..., 3] > 0) 32 | y1, y2, x1, x2 = ( 33 | alpha[0].min(), 34 | alpha[0].max(), 35 | alpha[1].min(), 36 | alpha[1].max(), 37 | ) 38 | # crop the foreground 39 | fg = image[y1:y2, x1:x2] 40 | # pad to square 41 | size = max(fg.shape[0], fg.shape[1]) 42 | ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2 43 | ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0 44 | new_image = np.pad( 45 | fg, 46 | ((ph0, ph1), (pw0, pw1), (0, 0)), 47 | mode="constant", 48 | constant_values=((0, 0), (0, 0), (0, 0)), 49 | ) 50 | 51 | # compute padding according to the ratio 52 | new_size = int(new_image.shape[0] / ratio) 53 | # pad to size, double side 54 | ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2 55 | ph1, pw1 = new_size - size - ph0, new_size - size - pw0 56 | new_image = np.pad( 57 | new_image, 58 | ((ph0, ph1), (pw0, pw1), (0, 0)), 59 | mode="constant", 60 | constant_values=((0, 0), (0, 0), (0, 0)), 61 | ) 62 | new_image = PIL.Image.fromarray(new_image) 63 | return new_image 64 | 65 | 66 | def images_to_video( 67 | images: torch.Tensor, 68 | output_path: str, 69 | fps: int = 30, 70 | ) -> None: 71 | # images: (N, C, H, W) 72 | video_dir = os.path.dirname(output_path) 73 | video_name = os.path.basename(output_path) 74 | os.makedirs(video_dir, exist_ok=True) 75 | 76 | frames = [] 77 | for i in range(len(images)): 78 | frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) 79 | assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \ 80 | 
f"Frame shape mismatch: {frame.shape} vs {images.shape}" 81 | assert frame.min() >= 0 and frame.max() <= 255, \ 82 | f"Frame value out of range: {frame.min()} ~ {frame.max()}" 83 | frames.append(frame) 84 | imageio.mimwrite(output_path, np.stack(frames), fps=fps, quality=10) 85 | 86 | 87 | def save_video( 88 | frames: torch.Tensor, 89 | output_path: str, 90 | fps: int = 30, 91 | ) -> None: 92 | # images: (N, C, H, W) 93 | frames = [(frame.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) for frame in frames] 94 | writer = imageio.get_writer(output_path, fps=fps) 95 | for frame in frames: 96 | writer.append_data(frame) 97 | writer.close() -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import tyro 2 | import torch 3 | from core.models import LGM 4 | from accelerate import Accelerator, DistributedDataParallelKwargs 5 | from safetensors.torch import load_file 6 | 7 | import kiui 8 | import numpy as np 9 | 10 | import argparse 11 | 12 | 13 | def main(dataset_name): 14 | if dataset_name == 'partdrag4d': 15 | from core.options import AllConfigs 16 | else: 17 | from core.options_pm import AllConfigs 18 | 19 | opt = tyro.cli(AllConfigs) 20 | 21 | accelerator = Accelerator( 22 | mixed_precision=opt.mixed_precision, 23 | gradient_accumulation_steps=opt.gradient_accumulation_steps, 24 | ) 25 | 26 | model = LGM(opt) 27 | # resume 28 | if opt.resume is not None: 29 | if opt.resume.endswith('safetensors'): 30 | ckpt = load_file(opt.resume, device='cpu') 31 | else: 32 | ckpt = torch.load(opt.resume, map_location='cpu') 33 | 34 | # tolerant load (only load matching shapes) 35 | model.load_state_dict(ckpt, strict=False) 36 | 37 | # data 38 | if dataset_name == 'partdrag4d': 39 | from core.eval_dataset_partdrag4d import PartDrag4DEvalDatset as EvalDataset 40 | elif dataset_name == 'objaverse_hq': 41 | from core.eval_dataset_objaverse_hq import ObjaverseHQEvalDataset as EvalDataset 42 | 43 | test_dataset = EvalDataset(opt) 44 | test_dataloader = torch.utils.data.DataLoader( 45 | test_dataset, 46 | batch_size=1, 47 | shuffle=False, 48 | num_workers=4, 49 | pin_memory=True, 50 | drop_last=False, 51 | ) 52 | 53 | # accelerate 54 | model, test_dataloader = accelerator.prepare( 55 | model, test_dataloader 56 | ) 57 | 58 | # loop 59 | for epoch in range(1): 60 | # eval 61 | with torch.no_grad(): 62 | model.eval() 63 | for i, data in enumerate(test_dataloader): 64 | out, drag_start_2d, drag_move_2d = model(data) 65 | 66 | # save some images 67 | if accelerator.is_main_process: 68 | gt_images = data['images_output'].detach().cpu().numpy() # [B, V, 3, output_size, output_size] 69 | gt_images = gt_images.transpose(0, 3, 1, 4, 2).reshape(-1, gt_images.shape[1] * gt_images.shape[3], 3) # [B*output_size, V*output_size, 3] 70 | kiui.write_image(f'{opt.workspace}/eval_gt_images_{epoch}_{i}.jpg', gt_images) 71 | 72 | pred_images = out['images_pred'].detach().cpu().numpy() # [B, V, 3, output_size, output_size] 73 | pred_images = pred_images.transpose(0, 3, 1, 4, 2).reshape(-1, pred_images.shape[1] * pred_images.shape[3], 3) 74 | kiui.write_image(f'{opt.workspace}/eval_pred_images_{epoch}_{i}.jpg', pred_images) 75 | 76 | origin_images = data['images_input'].detach().cpu().numpy() # [B, V, 3, output_size, output_size] 77 | origin_images = origin_images.transpose(0, 3, 1, 4, 2).reshape(-1, origin_images.shape[1] * origin_images.shape[3], 3) 78 | 
kiui.write_image(f'{opt.workspace}/eval_origin_images_{epoch}_{i}.jpg', origin_images) 79 | 80 | torch.cuda.empty_cache() 81 | 82 | if __name__ == "__main__": 83 | dataset = 'objaverse_hq' 84 | main(dataset) -------------------------------------------------------------------------------- /preprocess/src/models/encoder/dino_wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import torch.nn as nn 17 | from transformers import ViTImageProcessor 18 | from einops import rearrange, repeat 19 | from .dino import ViTModel 20 | 21 | 22 | class DinoWrapper(nn.Module): 23 | """ 24 | Dino v1 wrapper using huggingface transformer implementation. 25 | """ 26 | def __init__(self, model_name: str, freeze: bool = True): 27 | super().__init__() 28 | self.model, self.processor = self._build_dino(model_name) 29 | self.camera_embedder = nn.Sequential( 30 | nn.Linear(16, self.model.config.hidden_size, bias=True), 31 | nn.SiLU(), 32 | nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size, bias=True) 33 | ) 34 | if freeze: 35 | self._freeze() 36 | 37 | def forward(self, image, camera): 38 | # image: [B, N, C, H, W] 39 | # camera: [B, N, D] 40 | # RGB image with [0,1] scale and properly sized 41 | if image.ndim == 5: 42 | image = rearrange(image, 'b n c h w -> (b n) c h w') 43 | dtype = image.dtype 44 | inputs = self.processor( 45 | images=image.float(), 46 | return_tensors="pt", 47 | do_rescale=False, 48 | do_resize=False, 49 | ).to(self.model.device).to(dtype) 50 | # embed camera 51 | N = camera.shape[1] 52 | camera_embeddings = self.camera_embedder(camera) 53 | camera_embeddings = rearrange(camera_embeddings, 'b n d -> (b n) d') 54 | embeddings = camera_embeddings 55 | # This resampling of positional embedding uses bicubic interpolation 56 | outputs = self.model(**inputs, adaln_input=embeddings, interpolate_pos_encoding=True) 57 | last_hidden_states = outputs.last_hidden_state 58 | return last_hidden_states 59 | 60 | def _freeze(self): 61 | print(f"======== Freezing DinoWrapper ========") 62 | self.model.eval() 63 | for name, param in self.model.named_parameters(): 64 | param.requires_grad = False 65 | 66 | @staticmethod 67 | def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5): 68 | import requests 69 | try: 70 | model = ViTModel.from_pretrained(model_name, add_pooling_layer=False) 71 | processor = ViTImageProcessor.from_pretrained(model_name) 72 | return model, processor 73 | except requests.exceptions.ProxyError as err: 74 | if proxy_error_retries > 0: 75 | print(f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds...") 76 | import time 77 | time.sleep(proxy_error_cooldown) 78 | return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown) 79 | else: 80 | raise err 81 | 
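if __name__ == "__main__":
    # Illustrative smoke test added for clarity -- not part of the original repository.
    # The "facebook/dino-vitb16" checkpoint name and the dummy shapes are assumptions
    # inferred from forward(), which expects images as [B, N, C, H, W] in [0, 1] and a
    # flattened 4x4 camera pose (16 values) per view. Because of the relative import of
    # ViTModel above, run it as a module, e.g. `python -m src.models.encoder.dino_wrapper`
    # from the `preprocess/` directory.
    import torch
    encoder = DinoWrapper(model_name="facebook/dino-vitb16", freeze=True)
    images = torch.rand(2, 6, 3, 320, 320)   # 2 objects x 6 views, RGB in [0, 1]
    cameras = torch.rand(2, 6, 16)           # one flattened 4x4 pose per view
    tokens = encoder(images, cameras)        # [(2*6), num_patches + 1, hidden_size]
    print(tokens.shape)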
-------------------------------------------------------------------------------- /preprocess/src/models/renderer/utils/ray_marcher.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | # 11 | # Modified by Jiale Xu 12 | # The modifications are subject to the same license as the original. 13 | 14 | 15 | """ 16 | The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths. 17 | Based off of the implementation in MipNeRF (this one doesn't do any cone tracing though!) 18 | """ 19 | 20 | import torch 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | 24 | 25 | class MipRayMarcher2(nn.Module): 26 | def __init__(self, activation_factory): 27 | super().__init__() 28 | self.activation_factory = activation_factory 29 | 30 | def run_forward(self, colors, densities, depths, rendering_options, normals=None): 31 | dtype = colors.dtype 32 | deltas = depths[:, :, 1:] - depths[:, :, :-1] 33 | colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2 34 | densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2 35 | depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2 36 | 37 | # using factory mode for better usability 38 | densities_mid = self.activation_factory(rendering_options)(densities_mid).to(dtype) 39 | 40 | density_delta = densities_mid * deltas 41 | 42 | alpha = 1 - torch.exp(-density_delta).to(dtype) 43 | 44 | alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2) 45 | weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1] 46 | weights = weights.to(dtype) 47 | 48 | composite_rgb = torch.sum(weights * colors_mid, -2) 49 | weight_total = weights.sum(2) 50 | # composite_depth = torch.sum(weights * depths_mid, -2) / weight_total 51 | composite_depth = torch.sum(weights * depths_mid, -2) 52 | 53 | # clip the composite to min/max range of depths 54 | composite_depth = torch.nan_to_num(composite_depth, float('inf')).to(dtype) 55 | composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths)) 56 | 57 | if rendering_options.get('white_back', False): 58 | composite_rgb = composite_rgb + 1 - weight_total 59 | 60 | # rendered value scale is 0-1, comment out original mipnerf scaling 61 | # composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1) 62 | 63 | return composite_rgb, composite_depth, weights 64 | 65 | 66 | def forward(self, colors, densities, depths, rendering_options, normals=None): 67 | if normals is not None: 68 | composite_rgb, composite_depth, composite_normals, weights = self.run_forward(colors, densities, depths, rendering_options, normals) 69 | return composite_rgb, composite_depth, composite_normals, weights 70 | 71 | composite_rgb, composite_depth, weights = self.run_forward(colors, densities, depths, rendering_options) 72 | return composite_rgb, composite_depth, weights 73 | 
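if __name__ == "__main__":
    # Illustrative smoke test added for clarity -- not part of the original repository.
    # The softplus activation factory and the sample counts are assumptions; the marcher
    # itself only needs colors/densities/depths shaped [batch, rays, samples, channels].
    marcher = MipRayMarcher2(lambda rendering_options: F.softplus)  # activation choice is an assumption
    B, R, S = 1, 1024, 48                                  # batch, rays, samples per ray
    colors = torch.rand(B, R, S, 3)
    densities = torch.rand(B, R, S, 1)
    depths = torch.linspace(0.5, 2.5, S).view(1, 1, S, 1).expand(B, R, S, 1)
    rgb, depth, weights = marcher(colors, densities, depths, {'white_back': True})
    print(rgb.shape, depth.shape, weights.shape)           # [B, R, 3], [B, R, 1], [B, R, S-1, 1]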
-------------------------------------------------------------------------------- /PartDrag4D/rendering/distributed.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import subprocess 3 | import time 4 | from dataclasses import dataclass 5 | from typing import Optional 6 | import boto3 7 | import tyro 8 | import wandb 9 | 10 | @dataclass 11 | class Args: 12 | workers_per_gpu: int 13 | """number of workers per gpu""" 14 | 15 | input_models_path: str 16 | """Path to a json file containing a list of 3D object files""" 17 | 18 | num_images: int = 12 19 | """Number of rendered images""" 20 | 21 | upload_to_s3: bool = False 22 | """Whether to upload the rendered images to S3""" 23 | 24 | log_to_wandb: bool = False 25 | """Whether to log the progress to wandb""" 26 | 27 | num_gpus: int = -1 28 | """number of gpus to use. -1 means all available gpus""" 29 | 30 | blender_path: str = './blender/blender-3.5.0-linux-x64' 31 | """blender path""" 32 | 33 | view_path_root: str = './render_PartDrag4D' 34 | """Render images path""" 35 | 36 | 37 | def worker( 38 | queue: multiprocessing.JoinableQueue, 39 | count: multiprocessing.Value, 40 | gpu: int, 41 | s3: Optional[boto3.client], 42 | blender_path: str, 43 | view_path_root: str, 44 | num_images: int 45 | ) -> None: 46 | while True: 47 | item = queue.get() 48 | if item is None: 49 | break 50 | 51 | # Perform some operation on the item 52 | print(item, gpu) 53 | command = ( 54 | f" CUDA_VISIBLE_DEVICES={gpu} " 55 | f" {blender_path}/blender -b -P blender_script.py --" 56 | f" --object_path {item} --output_dir {view_path_root} --num_images {num_images}" 57 | ) 58 | print('command=======') 59 | print(command) 60 | subprocess.run(command, shell=True) 61 | 62 | with count.get_lock(): 63 | count.value += 1 64 | 65 | queue.task_done() 66 | 67 | 68 | if __name__ == "__main__": 69 | args = tyro.cli(Args) 70 | 71 | s3 = boto3.client("s3") if args.upload_to_s3 else None 72 | queue = multiprocessing.JoinableQueue() 73 | count = multiprocessing.Value("i", 0) 74 | blender_path = args.blender_path 75 | view_path_root = args.view_path_root 76 | num_images = args.num_images 77 | 78 | if args.log_to_wandb: 79 | wandb.init(project="objaverse-rendering", entity="prior-ai2") 80 | 81 | # Start worker processes on each of the GPUs 82 | for gpu_i in range(args.num_gpus): 83 | for worker_i in range(args.workers_per_gpu): 84 | worker_i = gpu_i * args.workers_per_gpu + worker_i 85 | process = multiprocessing.Process( 86 | target=worker, args=(queue, count, gpu_i, s3, blender_path, view_path_root, num_images) 87 | ) 88 | process.daemon = True 89 | process.start() 90 | 91 | with open(args.input_models_path, 'r') as f: 92 | model_paths = f.readlines() 93 | 94 | for item in model_paths: 95 | path = item.strip() 96 | queue.put(path) 97 | 98 | # update the wandb count 99 | if args.log_to_wandb: 100 | while True: 101 | time.sleep(5) 102 | wandb.log( 103 | { 104 | "count": count.value, 105 | "total": len(model_paths), 106 | "progress": count.value / len(model_paths), 107 | } 108 | ) 109 | if count.value == len(model_paths): 110 | break 111 | 112 | # Wait for all tasks to be completed 113 | queue.join() 114 | 115 | # Add sentinels to the queue to stop the worker processes 116 | for i in range(args.num_gpus * args.workers_per_gpu): 117 | queue.put(None) 118 | -------------------------------------------------------------------------------- /compute_metrics.py: 
-------------------------------------------------------------------------------- 1 | from pytorch_msssim import ssim 2 | import lpips 3 | import torch 4 | 5 | import cv2 6 | 7 | from tqdm import tqdm 8 | 9 | import numpy as np 10 | 11 | VAL_FILELIST = '/path/to/your/filelist' 12 | 13 | # src_images: N*3*256*256 14 | # tgt_images: N*3*256*256 15 | # origin_images: N*3*256*256 16 | 17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 18 | 19 | def compute_ssim(src_images, tgt_images): 20 | total_ssim_loss = 0 21 | for i in range(len(src_images)): 22 | src_images[i] = src_images[i].unsqueeze(0).detach() 23 | tgt_images[i] = tgt_images[i].unsqueeze(0).detach() 24 | 25 | ssim_val = ssim(src_images, tgt_images, data_range=1.0, size_average=True) 26 | total_ssim_loss += ssim_val 27 | 28 | total_ssim_loss /= len(src_images) 29 | return total_ssim_loss 30 | 31 | def compute_lpips(src_images, tgt_images, lpips_loss_fn): 32 | total_lpips_loss = 0 33 | for i in range(len(src_images)): 34 | src_images[i] = src_images[i].unsqueeze(0).detach() 35 | tgt_images[i] = tgt_images[i].unsqueeze(0).detach() 36 | with torch.no_grad(): 37 | lpips_loss_val = lpips_loss_fn(src_images, tgt_images).mean() 38 | total_lpips_loss += lpips_loss_val 39 | del lpips_loss_val 40 | torch.cuda.empty_cache() 41 | 42 | total_lpips_loss /= len(src_images) 43 | return total_lpips_loss.mean() 44 | 45 | def compute_psnr(src_images, tgt_images): 46 | psnr_val = -10 * torch.log10(torch.mean((src_images - tgt_images) ** 2)) 47 | return psnr_val 48 | 49 | def process_image(image_path): 50 | image = cv2.imread(image_path) 51 | image = image[:, :4096, :] 52 | image = cv2.resize(image, (2048, 256)) 53 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 54 | image = image / 255.0 55 | image = image.transpose(2, 0, 1) 56 | image = torch.tensor(image, dtype=torch.float32, device=device) 57 | return image 58 | 59 | if __name__ == '__main__': 60 | count = 0 61 | 62 | total_lpips_loss = 0 63 | total_ssim_loss = 0 64 | total_psnr_loss = 0 65 | 66 | lpips_loss = lpips.LPIPS(net='vgg').to(device) 67 | lpips_loss.eval() 68 | lpips_loss.requires_grad = False 69 | 70 | 71 | with open(VAL_FILELIST, 'r') as f: 72 | for line in tqdm(f.readlines()): 73 | tgt_path, src_path, origin_path = line.strip().split(',') 74 | src_image = process_image(src_path) 75 | tgt_image = process_image(tgt_path) 76 | origin_image = process_image(origin_path) 77 | 78 | for j in range(8): 79 | count += 1 80 | crop_src_image = src_image[:, :, j*256:(j+1)*256] 81 | crop_tgt_image = tgt_image[:, :, j*256:(j+1)*256] 82 | crop_origin_image = origin_image[:, :, j*256:(j+1)*256] 83 | 84 | psnr_val = compute_psnr(crop_src_image.unsqueeze(0), crop_tgt_image.unsqueeze(0)) 85 | ssim_val = compute_ssim(crop_src_image.unsqueeze(0), crop_tgt_image.unsqueeze(0)) 86 | lpips_loss_val = compute_lpips(crop_src_image.unsqueeze(0), crop_tgt_image.unsqueeze(0), lpips_loss) 87 | 88 | if torch.isinf(psnr_val): 89 | count -= 1 90 | print(f"psnr_val is inf, skip") 91 | continue 92 | 93 | total_psnr_loss += psnr_val 94 | total_ssim_loss += ssim_val 95 | total_lpips_loss += lpips_loss_val 96 | 97 | del src_image, tgt_image, origin_image, psnr_val, ssim_val, lpips_loss_val 98 | 99 | 100 | psnr_val = total_psnr_loss / count 101 | ssim_val = total_ssim_loss / count 102 | lpips_loss_val = total_lpips_loss / count 103 | 104 | print("PSNR: ", psnr_val) 105 | print("SSIM: ", ssim_val) 106 | print("LPIPS: ", lpips_loss_val) 107 | 108 | 109 |
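# Note added for clarity (not part of the original repository): VAL_FILELIST above is a
# placeholder and should point to a text file with one comparison per line, written as
# three comma-separated image paths in the order unpacked by `line.strip().split(',')`:
# target strip, predicted strip, original (input) strip. Each strip is cropped to
# 4096 px wide, resized to 2048x256, and split into eight 256x256 views before PSNR,
# SSIM and LPIPS are averaged over all views. A hypothetical line (paths are
# illustrative placeholders, not real files):
#
#   /data/eval/gt_strip_000.png,/data/eval/pred_strip_000.png,/data/eval/input_strip_000.png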
-------------------------------------------------------------------------------- /core/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import roma 8 | from kiui.op import safe_normalize 9 | 10 | def get_rays(pose, h, w, fovy, opengl=True): 11 | 12 | x, y = torch.meshgrid( 13 | torch.arange(w, device=pose.device), 14 | torch.arange(h, device=pose.device), 15 | indexing="xy", 16 | ) 17 | x = x.flatten() 18 | y = y.flatten() 19 | 20 | cx = w * 0.5 21 | cy = h * 0.5 22 | 23 | focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy)) 24 | 25 | camera_dirs = F.pad( 26 | torch.stack( 27 | [ 28 | (x - cx + 0.5) / focal, 29 | (y - cy + 0.5) / focal * (-1.0 if opengl else 1.0), 30 | ], 31 | dim=-1, 32 | ), 33 | (0, 1), 34 | value=(-1.0 if opengl else 1.0), 35 | ) # [hw, 3] 36 | 37 | rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3] 38 | rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3] 39 | 40 | rays_o = rays_o.view(h, w, 3) 41 | rays_d = safe_normalize(rays_d).view(h, w, 3) 42 | 43 | return rays_o, rays_d 44 | 45 | def orbit_camera_jitter(poses, strength=0.1): 46 | # poses: [B, 4, 4], assume orbit camera in opengl format 47 | # random orbital rotate 48 | 49 | B = poses.shape[0] 50 | rotvec_x = poses[:, :3, 1] * strength * np.pi * (torch.rand(B, 1, device=poses.device) * 2 - 1) 51 | rotvec_y = poses[:, :3, 0] * strength * np.pi / 2 * (torch.rand(B, 1, device=poses.device) * 2 - 1) 52 | 53 | rot = roma.rotvec_to_rotmat(rotvec_x) @ roma.rotvec_to_rotmat(rotvec_y) 54 | R = rot @ poses[:, :3, :3] 55 | T = rot @ poses[:, :3, 3:] 56 | 57 | new_poses = poses.clone() 58 | new_poses[:, :3, :3] = R 59 | new_poses[:, :3, 3:] = T 60 | 61 | return new_poses 62 | 63 | def grid_distortion(images, strength=0.5): 64 | # images: [B, C, H, W] 65 | # num_steps: int, grid resolution for distortion 66 | # strength: float in [0, 1], strength of distortion 67 | 68 | B, C, H, W = images.shape 69 | 70 | num_steps = np.random.randint(8, 17) 71 | grid_steps = torch.linspace(-1, 1, num_steps) 72 | 73 | # have to loop batch... 
74 | grids = [] 75 | for b in range(B): 76 | # construct displacement 77 | x_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive 78 | x_steps = (x_steps + strength * (torch.rand_like(x_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb 79 | x_steps = (x_steps * W).long() # [num_steps] 80 | x_steps[0] = 0 81 | x_steps[-1] = W 82 | xs = [] 83 | for i in range(num_steps - 1): 84 | xs.append(torch.linspace(grid_steps[i], grid_steps[i + 1], x_steps[i + 1] - x_steps[i])) 85 | xs = torch.cat(xs, dim=0) # [W] 86 | 87 | y_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive 88 | y_steps = (y_steps + strength * (torch.rand_like(y_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb 89 | y_steps = (y_steps * H).long() # [num_steps] 90 | y_steps[0] = 0 91 | y_steps[-1] = H 92 | ys = [] 93 | for i in range(num_steps - 1): 94 | ys.append(torch.linspace(grid_steps[i], grid_steps[i + 1], y_steps[i + 1] - y_steps[i])) 95 | ys = torch.cat(ys, dim=0) # [H] 96 | 97 | # construct grid 98 | grid_x, grid_y = torch.meshgrid(xs, ys, indexing='xy') # [H, W] 99 | grid = torch.stack([grid_x, grid_y], dim=-1) # [H, W, 2] 100 | 101 | grids.append(grid) 102 | 103 | grids = torch.stack(grids, dim=0).to(images.device) # [B, H, W, 2] 104 | 105 | # grid sample 106 | images = F.grid_sample(images, grids, align_corners=False) 107 | 108 | return images 109 | 110 | 111 | def get_action_embeddings(action): 112 | # action: [B, 3], assume normalized 113 | # return: [B, 768], embeddings 114 | 115 | # assume 3 actions: walk, jump, turn 116 | action = action * 0.5 + 0.5 117 | action = action * 2 * np.pi 118 | 119 | embeddings = torch.stack([ 120 | torch.cos(action), 121 | torch.sin(action), 122 | ], dim=-1).reshape(-1, 3) 123 | 124 | return embeddings 125 | -------------------------------------------------------------------------------- /preprocess/src/data/objaverse_zero123plus.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import webdataset as wds 5 | import pytorch_lightning as pl 6 | import torch 7 | from torch.utils.data import Dataset 8 | from torch.utils.data.distributed import DistributedSampler 9 | from PIL import Image 10 | from pathlib import Path 11 | 12 | from src.utils.train_util import instantiate_from_config 13 | 14 | 15 | class DataModuleFromConfig(pl.LightningDataModule): 16 | def __init__( 17 | self, 18 | batch_size=8, 19 | num_workers=4, 20 | train=None, 21 | validation=None, 22 | test=None, 23 | **kwargs, 24 | ): 25 | super().__init__() 26 | 27 | self.batch_size = batch_size 28 | self.num_workers = num_workers 29 | 30 | self.dataset_configs = dict() 31 | if train is not None: 32 | self.dataset_configs['train'] = train 33 | if validation is not None: 34 | self.dataset_configs['validation'] = validation 35 | if test is not None: 36 | self.dataset_configs['test'] = test 37 | 38 | def setup(self, stage): 39 | 40 | if stage in ['fit']: 41 | self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs) 42 | else: 43 | raise NotImplementedError 44 | 45 | def train_dataloader(self): 46 | 47 | sampler = DistributedSampler(self.datasets['train']) 48 | return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler) 49 | 50 | def val_dataloader(self): 51 | 52 | sampler = DistributedSampler(self.datasets['validation']) 53 | return wds.WebLoader(self.datasets['validation'], batch_size=4, 
num_workers=self.num_workers, shuffle=False, sampler=sampler) 54 | 55 | def test_dataloader(self): 56 | 57 | return wds.WebLoader(self.datasets['test'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False) 58 | 59 | 60 | class ObjaverseData(Dataset): 61 | def __init__(self, 62 | root_dir='objaverse/', 63 | meta_fname='valid_paths.json', 64 | image_dir='rendering_zero123plus', 65 | validation=False, 66 | ): 67 | self.root_dir = Path(root_dir) 68 | self.image_dir = image_dir 69 | 70 | with open(os.path.join(root_dir, meta_fname)) as f: 71 | lvis_dict = json.load(f) 72 | paths = [] 73 | for k in lvis_dict.keys(): 74 | paths.extend(lvis_dict[k]) 75 | self.paths = paths 76 | 77 | total_objects = len(self.paths) 78 | if validation: 79 | self.paths = self.paths[-16:] # used last 16 as validation 80 | else: 81 | self.paths = self.paths[:-16] 82 | print('============= length of dataset %d =============' % len(self.paths)) 83 | 84 | def __len__(self): 85 | return len(self.paths) 86 | 87 | def load_im(self, path, color): 88 | pil_img = Image.open(path) 89 | 90 | image = np.asarray(pil_img, dtype=np.float32) / 255. 91 | alpha = image[:, :, 3:] 92 | image = image[:, :, :3] * alpha + color * (1 - alpha) 93 | 94 | image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float() 95 | alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float() 96 | return image, alpha 97 | 98 | def __getitem__(self, index): 99 | while True: 100 | image_path = os.path.join(self.root_dir, self.image_dir, self.paths[index]) 101 | 102 | '''background color, default: white''' 103 | bkg_color = [1., 1., 1.] 104 | 105 | img_list = [] 106 | try: 107 | for idx in range(7): 108 | img, alpha = self.load_im(os.path.join(image_path, '%03d.png' % idx), bkg_color) 109 | img_list.append(img) 110 | 111 | except Exception as e: 112 | print(e) 113 | index = np.random.randint(0, len(self.paths)) 114 | continue 115 | 116 | break 117 | 118 | imgs = torch.stack(img_list, dim=0).float() 119 | 120 | data = { 121 | 'cond_imgs': imgs[0], # (3, H, W) 122 | 'target_imgs': imgs[1:], # (6, 3, H, W) 123 | } 124 | return data 125 | -------------------------------------------------------------------------------- /core/options.py: -------------------------------------------------------------------------------- 1 | import tyro 2 | from dataclasses import dataclass 3 | from typing import Tuple, Literal, Dict, Optional 4 | 5 | @dataclass 6 | class Options: 7 | train_filelist: str = './filelist/train_filelist.txt' 8 | val_filelist: str = './filelist/val_filelist.txt' 9 | zero123_val_filelist: str = './filelist/zero123_val_filelist.txt' # Need to fix it 10 | propagated_drags_base : str = './preprocess/propagated_drags' # Need to fix it 11 | # Unet image input size 12 | input_size: int = 256 13 | # Unet definition 14 | down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024, 1024) 15 | down_attention: Tuple[bool, ...] = (False, False, False, True, True, True) 16 | mid_attention: bool = True 17 | up_channels: Tuple[int, ...] = (1024, 1024, 512, 256) 18 | up_attention: Tuple[bool, ...] = (True, True, True, False) 19 | # Unet output size, dependent on the input_size and U-Net structure! 
20 | splat_size: int = 64 21 | # gaussian render size 22 | output_size: int = 256 23 | 24 | ### dataset 25 | # data mode (only support s3 now) 26 | data_mode: Literal['s3'] = 's3' 27 | # fovy of the dataset 28 | fovy: float = 49.1 29 | # camera near plane 30 | znear: float = 0.5 31 | # camera far plane 32 | zfar: float = 2.5 33 | # number of all views (input + output) 34 | num_views: int = 12 35 | # number of views 36 | num_input_views: int = 4 37 | # camera radius 38 | cam_radius: float = 2.4 # to better use [-1, 1]^3 space 39 | # num workers 40 | num_workers: int = 8 41 | 42 | ### training 43 | # workspace 44 | workspace: str = './workspace' 45 | # resume 46 | resume : Optional[str] = 'pretrained/partdrag4d.ckpt' 47 | # batch size (per-GPU) 48 | batch_size: int = 4 49 | # gradient accumulation 50 | gradient_accumulation_steps: int = 1 51 | # training epochs 52 | num_epochs: int = 3000 53 | # lpips loss weight 54 | lambda_lpips: float = 1.0 55 | # gradient clip 56 | gradient_clip: float = 1.0 57 | # mixed precision 58 | mixed_precision: str = 'bf16' 59 | # learning rate 60 | lr: float = 5e-4 61 | # augmentation prob for grid distortion 62 | prob_grid_distortion: float = 0.5 63 | # augmentation prob for camera jitter 64 | prob_cam_jitter: float = 0.5 65 | # The base dir for mesh 66 | mesh_base: str = './PartDrag4D/data/processed_data_partdrag4d' 67 | 68 | ### testing 69 | # test image path 70 | test_path: Optional[str] = None 71 | 72 | # use drag 73 | use_drag_encoding: bool = True 74 | use_ms_drag_encoding: bool = True 75 | num_drags : int = 10 76 | stage1: bool = False 77 | 78 | # GS flow 79 | base_file_path: str = './gs_database/PartDrag4d' 80 | val_ratio: float = 0.15 81 | random_drag: bool = True 82 | lambda_flow: float = 1.0 83 | perturb_drags: bool = True 84 | 85 | 86 | # all the default settings 87 | config_defaults: Dict[str, Options] = {} 88 | config_doc: Dict[str, str] = {} 89 | 90 | config_doc['lrm'] = 'the default settings for LGM' 91 | config_defaults['lrm'] = Options() 92 | 93 | config_doc['small'] = 'small model with lower resolution Gaussians' 94 | config_defaults['small'] = Options( 95 | input_size=256, 96 | splat_size=64, 97 | output_size=256, 98 | batch_size=8, 99 | gradient_accumulation_steps=1, 100 | mixed_precision='bf16', 101 | ) 102 | 103 | config_doc['big'] = 'big model with higher resolution Gaussians' 104 | config_defaults['big'] = Options( 105 | input_size=256, 106 | up_channels=(1024, 1024, 512, 256, 128), # one more decoder 107 | up_attention=(True, True, True, False, False), 108 | splat_size=128, 109 | output_size=256, # render & supervise Gaussians at a higher resolution. 
110 | batch_size=8, 111 | num_views=12, 112 | gradient_accumulation_steps=1, 113 | mixed_precision='bf16', 114 | ) 115 | 116 | config_doc['tiny'] = 'tiny model for ablation' 117 | config_defaults['tiny'] = Options( 118 | input_size=256, 119 | down_channels=(32, 64, 128, 256, 512), 120 | down_attention=(False, False, False, False, True), 121 | up_channels=(512, 256, 128), 122 | up_attention=(True, False, False, False), 123 | splat_size=64, 124 | output_size=256, 125 | batch_size=16, 126 | num_views=8, 127 | gradient_accumulation_steps=1, 128 | mixed_precision='bf16', 129 | ) 130 | 131 | AllConfigs = tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc) 132 | -------------------------------------------------------------------------------- /core/options_pm.py: -------------------------------------------------------------------------------- 1 | import tyro 2 | from dataclasses import dataclass 3 | from typing import Tuple, Literal, Dict, Optional 4 | 5 | 6 | @dataclass 7 | class Options: 8 | ### model 9 | train_filelist: str = './filelist/train_objavser_hq.txt' 10 | val_filelist: str = './filelist/eval_objavser_hq.txt' 11 | zero123_val_filelist: str = './filelist/zero123_val_filelist_objavser_hq.txt' 12 | # Unet image input size 13 | input_size: int = 256 14 | # Unet definition 15 | down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024, 1024) 16 | down_attention: Tuple[bool, ...] = (False, False, False, True, True, True) 17 | mid_attention: bool = True 18 | up_channels: Tuple[int, ...] = (1024, 1024, 512, 256) 19 | up_attention: Tuple[bool, ...] = (True, True, True, False) 20 | # Unet output size, dependent on the input_size and U-Net structure! 21 | splat_size: int = 64 22 | # gaussian render size 23 | output_size: int = 256 24 | 25 | ### dataset 26 | # data mode (only support s3 now) 27 | data_mode: Literal['s3'] = 's3' 28 | # fovy of the dataset 29 | fovy: float = 49.1 30 | # camera near plane 31 | znear: float = 0.5 32 | # camera far plane 33 | zfar: float = 2.5 34 | # number of all views (input + output) 35 | num_views: int = 12 36 | # number of views 37 | num_input_views: int = 4 38 | # camera radius 39 | cam_radius: float = 2.5 # to better use [-1, 1]^3 space 40 | # num workers 41 | num_workers: int = 8 42 | 43 | ### training 44 | # workspace 45 | workspace: str = './workspace' 46 | # resume 47 | resume: Optional[str] = './pretrained/objaverse_hq.ckpt' 48 | # batch size (per-GPU) 49 | batch_size: int = 4 50 | # gradient accumulation 51 | gradient_accumulation_steps: int = 1 52 | # training epochs 53 | num_epochs: int = 3000 54 | # lpips loss weight 55 | lambda_lpips: float = 1.0 56 | # gradient clip 57 | gradient_clip: float = 1.0 58 | # mixed precision 59 | mixed_precision: str = 'bf16' 60 | # learning rate 61 | lr: float = 5e-4 62 | # augmentation prob for grid distortion 63 | prob_grid_distortion: float = 0.5 64 | # augmentation prob for camera jitter 65 | prob_cam_jitter: float = 0.5 66 | 67 | ### testing 68 | # test image path 69 | test_path: Optional[str] = None 70 | 71 | ### misc 72 | # nvdiffrast backend setting 73 | force_cuda_rast: bool = False 74 | # render fancy video with gaussian scaling effect 75 | fancy_video: bool = False 76 | 77 | # use drag 78 | use_drag_encoding: bool = True 79 | use_ms_drag_encoding: bool = True 80 | num_drags : int = 10 81 | stage1: bool = False 82 | 83 | # GS flow 84 | base_file_path: str = './gs_database/Objaverse_HQ' 85 | val_ratio: float = 0.1 86 | random_drag: bool = True 87 | lambda_flow: float = 1.0 88 | perturb_drags: 
bool = True 89 | 90 | # all the default settings 91 | config_defaults: Dict[str, Options] = {} 92 | config_doc: Dict[str, str] = {} 93 | 94 | config_doc['lrm'] = 'the default settings for LGM' 95 | config_defaults['lrm'] = Options() 96 | 97 | config_doc['small'] = 'small model with lower resolution Gaussians' 98 | config_defaults['small'] = Options( 99 | input_size=256, 100 | splat_size=64, 101 | output_size=256, 102 | batch_size=8, 103 | gradient_accumulation_steps=1, 104 | mixed_precision='bf16', 105 | ) 106 | 107 | config_doc['big'] = 'big model with higher resolution Gaussians' 108 | config_defaults['big'] = Options( 109 | input_size=256, 110 | up_channels=(1024, 1024, 512, 256, 128), # one more decoder 111 | up_attention=(True, True, True, False, False), 112 | splat_size=128, 113 | output_size=256, # render & supervise Gaussians at a higher resolution. 114 | batch_size=8, 115 | num_views=8, 116 | gradient_accumulation_steps=1, 117 | mixed_precision='bf16', 118 | ) 119 | 120 | config_doc['tiny'] = 'tiny model for ablation' 121 | config_defaults['tiny'] = Options( 122 | input_size=256, 123 | down_channels=(32, 64, 128, 256, 512), 124 | down_attention=(False, False, False, False, True), 125 | up_channels=(512, 256, 128), 126 | up_attention=(True, False, False, False), 127 | splat_size=64, 128 | output_size=256, 129 | batch_size=16, 130 | num_views=8, 131 | gradient_accumulation_steps=1, 132 | mixed_precision='bf16', 133 | ) 134 | 135 | AllConfigs = tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc) 136 | -------------------------------------------------------------------------------- /preprocess/src/utils/camera_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | 6 | def pad_camera_extrinsics_4x4(extrinsics): 7 | if extrinsics.shape[-2] == 4: 8 | return extrinsics 9 | padding = torch.tensor([[0, 0, 0, 1]]).to(extrinsics) 10 | if extrinsics.ndim == 3: 11 | padding = padding.unsqueeze(0).repeat(extrinsics.shape[0], 1, 1) 12 | extrinsics = torch.cat([extrinsics, padding], dim=-2) 13 | return extrinsics 14 | 15 | 16 | def center_looking_at_camera_pose(camera_position: torch.Tensor, look_at: torch.Tensor = None, up_world: torch.Tensor = None): 17 | """ 18 | Create OpenGL camera extrinsics from camera locations and look-at position. 
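The returned matrices are camera-to-world (c2w) transforms: the columns are the camera x/y/z axes expressed in world coordinates plus the camera position. look_at and up_world default to the origin and the +Z axis.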
19 | 20 | camera_position: (M, 3) or (3,) 21 | look_at: (3) 22 | up_world: (3) 23 | return: (M, 3, 4) or (3, 4) 24 | """ 25 | # by default, looking at the origin and world up is z-axis 26 | if look_at is None: 27 | look_at = torch.tensor([0, 0, 0], dtype=torch.float32) 28 | if up_world is None: 29 | up_world = torch.tensor([0, 0, 1], dtype=torch.float32) 30 | if camera_position.ndim == 2: 31 | look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1) 32 | up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1) 33 | 34 | # OpenGL camera: z-backward, x-right, y-up 35 | z_axis = camera_position - look_at 36 | z_axis = F.normalize(z_axis, dim=-1).float() 37 | x_axis = torch.linalg.cross(up_world, z_axis, dim=-1) 38 | x_axis = F.normalize(x_axis, dim=-1).float() 39 | y_axis = torch.linalg.cross(z_axis, x_axis, dim=-1) 40 | y_axis = F.normalize(y_axis, dim=-1).float() 41 | 42 | extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1) 43 | extrinsics = pad_camera_extrinsics_4x4(extrinsics) 44 | return extrinsics 45 | 46 | 47 | def spherical_camera_pose(azimuths: np.ndarray, elevations: np.ndarray, radius=2.5): 48 | azimuths = np.deg2rad(azimuths) 49 | elevations = np.deg2rad(elevations) 50 | 51 | xs = radius * np.cos(elevations) * np.cos(azimuths) 52 | ys = radius * np.cos(elevations) * np.sin(azimuths) 53 | zs = radius * np.sin(elevations) 54 | 55 | cam_locations = np.stack([xs, ys, zs], axis=-1) 56 | cam_locations = torch.from_numpy(cam_locations).float() 57 | 58 | c2ws = center_looking_at_camera_pose(cam_locations) 59 | return c2ws 60 | 61 | 62 | def get_circular_camera_poses(M=120, radius=2.5, elevation=30.0): 63 | # M: number of circular views 64 | # radius: camera dist to center 65 | # elevation: elevation degrees of the camera 66 | # return: (M, 4, 4) 67 | assert M > 0 and radius > 0 68 | 69 | elevation = np.deg2rad(elevation) 70 | 71 | camera_positions = [] 72 | for i in range(M): 73 | azimuth = 2 * np.pi * i / M 74 | x = radius * np.cos(elevation) * np.cos(azimuth) 75 | y = radius * np.cos(elevation) * np.sin(azimuth) 76 | z = radius * np.sin(elevation) 77 | camera_positions.append([x, y, z]) 78 | camera_positions = np.array(camera_positions) 79 | camera_positions = torch.from_numpy(camera_positions).float() 80 | extrinsics = center_looking_at_camera_pose(camera_positions) 81 | return extrinsics 82 | 83 | 84 | def FOV_to_intrinsics(fov, device='cpu'): 85 | """ 86 | Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees. 87 | Note the intrinsics are returned as normalized by image size, rather than in pixel units. 88 | Assumes principal point is at image center. 89 | """ 90 | focal_length = 0.5 / np.tan(np.deg2rad(fov) * 0.5) 91 | intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device) 92 | return intrinsics 93 | 94 | 95 | def get_zero123plus_input_cameras(batch_size=1, radius=4.0, fov=30.0): 96 | """ 97 | Get the input camera parameters. 
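Returns a tensor of shape (batch_size, 6, 16): for each of the six Zero123++ views, the flattened 3x4 camera-to-world extrinsics (12 values) concatenated with the normalized intrinsics [fx, fy, cx, cy] (4 values).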
98 | """ 99 | azimuths = np.array([30, 90, 150, 210, 270, 330]).astype(float) 100 | elevations = np.array([20, -10, 20, -10, 20, -10]).astype(float) 101 | 102 | c2ws = spherical_camera_pose(azimuths, elevations, radius) 103 | c2ws = c2ws.float().flatten(-2) 104 | 105 | Ks = FOV_to_intrinsics(fov).unsqueeze(0).repeat(6, 1, 1).float().flatten(-2) 106 | 107 | extrinsics = c2ws[:, :12] 108 | intrinsics = torch.stack([Ks[:, 0], Ks[:, 4], Ks[:, 2], Ks[:, 5]], dim=-1) 109 | cameras = torch.cat([extrinsics, intrinsics], dim=-1) 110 | 111 | return cameras.unsqueeze(0).repeat(batch_size, 1, 1) 112 | -------------------------------------------------------------------------------- /preprocess/src/models/decoder/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import torch 17 | import torch.nn as nn 18 | 19 | 20 | class BasicTransformerBlock(nn.Module): 21 | """ 22 | Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks. 23 | """ 24 | # use attention from torch.nn.MultiHeadAttention 25 | # Block contains a cross-attention layer, a self-attention layer, and a MLP 26 | def __init__( 27 | self, 28 | inner_dim: int, 29 | cond_dim: int, 30 | num_heads: int, 31 | eps: float, 32 | attn_drop: float = 0., 33 | attn_bias: bool = False, 34 | mlp_ratio: float = 4., 35 | mlp_drop: float = 0., 36 | ): 37 | super().__init__() 38 | 39 | self.norm1 = nn.LayerNorm(inner_dim) 40 | self.cross_attn = nn.MultiheadAttention( 41 | embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim, 42 | dropout=attn_drop, bias=attn_bias, batch_first=True) 43 | self.norm2 = nn.LayerNorm(inner_dim) 44 | self.self_attn = nn.MultiheadAttention( 45 | embed_dim=inner_dim, num_heads=num_heads, 46 | dropout=attn_drop, bias=attn_bias, batch_first=True) 47 | self.norm3 = nn.LayerNorm(inner_dim) 48 | self.mlp = nn.Sequential( 49 | nn.Linear(inner_dim, int(inner_dim * mlp_ratio)), 50 | nn.GELU(), 51 | nn.Dropout(mlp_drop), 52 | nn.Linear(int(inner_dim * mlp_ratio), inner_dim), 53 | nn.Dropout(mlp_drop), 54 | ) 55 | 56 | def forward(self, x, cond): 57 | # x: [N, L, D] 58 | # cond: [N, L_cond, D_cond] 59 | x = x + self.cross_attn(self.norm1(x), cond, cond)[0] 60 | before_sa = self.norm2(x) 61 | x = x + self.self_attn(before_sa, before_sa, before_sa)[0] 62 | x = x + self.mlp(self.norm3(x)) 63 | return x 64 | 65 | 66 | class TriplaneTransformer(nn.Module): 67 | """ 68 | Transformer with condition that generates a triplane representation. 
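A learned positional embedding of 3 * triplane_low_res**2 tokens is refined by a stack of cross-/self-attention blocks conditioned on the image features, after which each plane is upsampled 2x by a transposed convolution to triplane_dim channels.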
69 | 70 | Reference: 71 | Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486 72 | """ 73 | def __init__( 74 | self, 75 | inner_dim: int, 76 | image_feat_dim: int, 77 | triplane_low_res: int, 78 | triplane_high_res: int, 79 | triplane_dim: int, 80 | num_layers: int, 81 | num_heads: int, 82 | eps: float = 1e-6, 83 | ): 84 | super().__init__() 85 | 86 | # attributes 87 | self.triplane_low_res = triplane_low_res 88 | self.triplane_high_res = triplane_high_res 89 | self.triplane_dim = triplane_dim 90 | 91 | # modules 92 | # initialize pos_embed with 1/sqrt(dim) * N(0, 1) 93 | self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, inner_dim) * (1. / inner_dim) ** 0.5) 94 | self.layers = nn.ModuleList([ 95 | BasicTransformerBlock( 96 | inner_dim=inner_dim, cond_dim=image_feat_dim, num_heads=num_heads, eps=eps) 97 | for _ in range(num_layers) 98 | ]) 99 | self.norm = nn.LayerNorm(inner_dim, eps=eps) 100 | self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0) 101 | 102 | def forward(self, image_feats): 103 | # image_feats: [N, L_cond, D_cond] 104 | 105 | N = image_feats.shape[0] 106 | H = W = self.triplane_low_res 107 | L = 3 * H * W 108 | 109 | x = self.pos_embed.repeat(N, 1, 1) # [N, L, D] 110 | for layer in self.layers: 111 | x = layer(x, image_feats) 112 | x = self.norm(x) 113 | 114 | # separate each plane and apply deconv 115 | x = x.view(N, 3, H, W, -1) 116 | x = torch.einsum('nihwd->indhw', x) # [3, N, D, H, W] 117 | x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W] 118 | x = self.deconv(x) # [3*N, D', H', W'] 119 | x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W'] 120 | x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W'] 121 | x = x.contiguous() 122 | 123 | return x 124 | -------------------------------------------------------------------------------- /preprocess/src/models/renderer/utils/math_utils.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2022 Petr Kellnhofer 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | import torch 24 | 25 | def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor: 26 | """ 27 | Left-multiplies MxM @ NxM. Returns NxM. 
28 | """ 29 | res = torch.matmul(vectors4, matrix.T) 30 | return res 31 | 32 | 33 | def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor: 34 | """ 35 | Normalize vector lengths. 36 | """ 37 | return vectors / (torch.norm(vectors, dim=-1, keepdim=True)) 38 | 39 | def torch_dot(x: torch.Tensor, y: torch.Tensor): 40 | """ 41 | Dot product of two tensors. 42 | """ 43 | return (x * y).sum(-1) 44 | 45 | 46 | def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length): 47 | """ 48 | Author: Petr Kellnhofer 49 | Intersects rays with the [-1, 1] NDC volume. 50 | Returns min and max distance of entry. 51 | Returns -1 for no intersection. 52 | https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection 53 | """ 54 | o_shape = rays_o.shape 55 | rays_o = rays_o.detach().reshape(-1, 3) 56 | rays_d = rays_d.detach().reshape(-1, 3) 57 | 58 | 59 | bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)] 60 | bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)] 61 | bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device) 62 | is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device) 63 | 64 | # Precompute inverse for stability. 65 | invdir = 1 / rays_d 66 | sign = (invdir < 0).long() 67 | 68 | # Intersect with YZ plane. 69 | tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] 70 | tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] 71 | 72 | # Intersect with XZ plane. 73 | tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] 74 | tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] 75 | 76 | # Resolve parallel rays. 77 | is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False 78 | 79 | # Use the shortest intersection. 80 | tmin = torch.max(tmin, tymin) 81 | tmax = torch.min(tmax, tymax) 82 | 83 | # Intersect with XY plane. 84 | tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] 85 | tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] 86 | 87 | # Resolve parallel rays. 88 | is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False 89 | 90 | # Use the shortest intersection. 91 | tmin = torch.max(tmin, tzmin) 92 | tmax = torch.min(tmax, tzmax) 93 | 94 | # Mark invalid. 95 | tmin[torch.logical_not(is_valid)] = -1 96 | tmax[torch.logical_not(is_valid)] = -2 97 | 98 | return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1) 99 | 100 | 101 | def linspace(start: torch.Tensor, stop: torch.Tensor, num: int): 102 | """ 103 | Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive. 104 | Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch. 
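For example, linspace(torch.zeros(2), torch.ones(2), 5) returns a [5, 2] tensor whose rows step evenly through 0.0, 0.25, 0.5, 0.75, 1.0.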
105 | """ 106 | # create a tensor of 'num' steps from 0 to 1 107 | steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1) 108 | 109 | # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings 110 | # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript 111 | # "cannot statically infer the expected size of a list in this contex", hence the code below 112 | for i in range(start.ndim): 113 | steps = steps.unsqueeze(-1) 114 | 115 | # the output starts at 'start' and increments until 'stop' in each dimension 116 | out = start[None] + steps * (stop - start)[None] 117 | 118 | return out 119 | -------------------------------------------------------------------------------- /PartDrag4D/z_buffer_al.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | 3 | import numpy as np 4 | import os 5 | 6 | from tqdm import tqdm 7 | 8 | import argparse 9 | 10 | view_matrix = np.array([[-3.7966e-03, 7.0987e-01, -7.0432e-01, 1.8735e+00], 11 | [ 9.9999e-01, 2.6951e-03, -2.6740e-03, 0.0000e+00], 12 | [-1.1642e-10, -7.0432e-01, -7.0988e-01, 1.5000e+00], 13 | [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00]]) 14 | 15 | def local2world(points, scale, world_matrix): 16 | points = points.copy() 17 | scale = np.array(scale).copy() 18 | translation = np.array(world_matrix, dtype=np.float32) 19 | 20 | scaled_points = points * scale 21 | world_points = scaled_points + translation 22 | 23 | return world_points 24 | 25 | def ndc2Pix(v, S): 26 | return ((v + 1.0) * S - 1.0) * 0.5 27 | 28 | 29 | def judge(center, radius, z_buffer, cur_depth, epsilon=0.005): 30 | var_x = [max(center[0] - radius, 0), min(center[0] + radius, 255)] 31 | var_y = [max(center[1] - radius, 0), min(center[1] + radius, 255)] 32 | 33 | for i in range(var_x[0], var_x[1]): 34 | for j in range(var_y[0], var_y[1]): 35 | if z_buffer[i, j] == np.inf: 36 | continue 37 | if z_buffer[i, j] > cur_depth + epsilon: 38 | return False 39 | return True 40 | 41 | def refine_z_buffer_index(z_buffer, z_buffer_index, pcd_points): 42 | for i in range(z_buffer.shape[0]): 43 | for j in range(z_buffer.shape[1]): 44 | if z_buffer_index[i, j] == 0: 45 | continue 46 | if not judge([i, j], 10, z_buffer, z_buffer[i, j]): 47 | z_buffer_index[i, j] = 0 48 | z_buffer[i, j] = np.inf 49 | 50 | 51 | def project_points(points, cam_view): 52 | # local2world transformation using blender scale and world matrix 53 | scale = 0.42995866893870127 54 | world_matrix = (-0.0121, -0.0070, 0.0120) 55 | points = local2world(points, scale, world_matrix) 56 | 57 | # use blender fixed projection matrix 58 | proj_matrix = [ 59 | [2.777777671813965, 0.0000, 0.0000, 0.0000], 60 | [0.0000, 2.777777671813965, 0.0000, 0.0000], 61 | [0.0000, 0.0000, -1.0001999139785767, -0.20002000033855438], 62 | [0.0000, 0.0000, -1.0000, 0.0000] 63 | ] 64 | cam_proj = np.array(proj_matrix, dtype=np.float32) 65 | 66 | z_buffer = np.full((256, 256), fill_value=float('inf')) 67 | z_buffer_index = np.zeros([256, 256], dtype=np.int32) 68 | 69 | S = np.array([256, 256]) 70 | 71 | for i in range(points.shape[0]): 72 | view_matrix = cam_view 73 | view_matrix = np.linalg.inv(view_matrix) 74 | proj_matrix = cam_proj 75 | 76 | point_3D = points[i] 77 | point_3D_homogeneous = np.concatenate([point_3D, np.array([1.0])], axis=0) 78 | 79 | # Transfer to camera coordinates 80 | camera_coords = np.matmul(view_matrix, point_3D_homogeneous) 81 | 82 | # Transfer to clip coordinates 83 | clip_coords = 
np.matmul(proj_matrix, camera_coords) 84 | 85 | ndc_coords = clip_coords[:3] / clip_coords[3] 86 | 87 | ndc_coords_2d = ndc_coords[:2] 88 | ndc_depth = ndc_coords[2] 89 | 90 | # Get screen coordinates 91 | screen_coords = S - ndc2Pix(ndc_coords_2d, S) 92 | 93 | # Check if in the screen 94 | if int(screen_coords[0]) < 0 or int(screen_coords[0]) >= 256 or int(screen_coords[1]) < 0 or int(screen_coords[1]) >= 256: 95 | continue 96 | 97 | # z-buffer 98 | if z_buffer[int(screen_coords[0]), int(screen_coords[1])] > ndc_depth: 99 | z_buffer[int(screen_coords[0]), int(screen_coords[1])] = ndc_depth 100 | z_buffer_index[int(screen_coords[0]), int(screen_coords[1])] = i 101 | 102 | refine_z_buffer_index(z_buffer, z_buffer_index, points) 103 | return np.unique(z_buffer_index) 104 | 105 | def main(mesh_base): 106 | for categorial in os.listdir(mesh_base): 107 | if not os.path.isdir(os.path.join(mesh_base, categorial)): 108 | continue 109 | for mesh_id in tqdm(os.listdir(os.path.join(mesh_base, categorial))): 110 | if not os.path.exists(os.path.join(mesh_base, categorial, mesh_id, "motion")): 111 | continue 112 | for ply_file in os.listdir(os.path.join(mesh_base, categorial, mesh_id, "motion")): 113 | if ply_file.endswith(".ply"): 114 | print(f"Process {os.path.join(mesh_base, categorial, mesh_id, 'motion', ply_file)}") 115 | pcd = o3d.io.read_point_cloud(os.path.join(mesh_base, categorial, mesh_id, "motion", ply_file)) 116 | z_buffer_index = project_points(np.array(pcd.points), view_matrix) 117 | np.save(os.path.join(mesh_base, categorial, mesh_id, "motion", ply_file.replace(".ply", "_visible.npy")), z_buffer_index) 118 | 119 | 120 | if __name__ == "__main__": 121 | parser = argparse.ArgumentParser() 122 | parser.add_argument('--mesh_base', default='./data/processed_data_partdrag4d', help='The render images dir') 123 | args = parser.parse_args() 124 | 125 | main(args.mesh_base) 126 | 127 | 128 | 129 | 130 | -------------------------------------------------------------------------------- /preprocess/src/models/geometry/render/neural_render.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | import nvdiffrast.torch as dr 12 | from . import Renderer 13 | 14 | _FG_LUT = None 15 | 16 | 17 | def interpolate(attr, rast, attr_idx, rast_db=None): 18 | return dr.interpolate( 19 | attr.contiguous(), rast, attr_idx, rast_db=rast_db, 20 | diff_attrs=None if rast_db is None else 'all') 21 | 22 | 23 | def xfm_points(points, matrix, use_python=True): 24 | '''Transform points. 25 | Args: 26 | points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3] 27 | matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4] 28 | use_python: Use PyTorch's torch.matmul (for validation) 29 | Returns: 30 | Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4]. 
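Each point is homogenised by appending a 1 and right-multiplied by matrix^T, which is equivalent to applying the matrix to the corresponding column vector.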
31 | ''' 32 | out = torch.matmul(torch.nn.functional.pad(points, pad=(0, 1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2)) 33 | if torch.is_anomaly_enabled(): 34 | assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN" 35 | return out 36 | 37 | 38 | def dot(x, y): 39 | return torch.sum(x * y, -1, keepdim=True) 40 | 41 | 42 | def compute_vertex_normal(v_pos, t_pos_idx): 43 | i0 = t_pos_idx[:, 0] 44 | i1 = t_pos_idx[:, 1] 45 | i2 = t_pos_idx[:, 2] 46 | 47 | v0 = v_pos[i0, :] 48 | v1 = v_pos[i1, :] 49 | v2 = v_pos[i2, :] 50 | 51 | face_normals = torch.cross(v1 - v0, v2 - v0) 52 | 53 | # Splat face normals to vertices 54 | v_nrm = torch.zeros_like(v_pos) 55 | v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals) 56 | v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals) 57 | v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals) 58 | 59 | # Normalize, replace zero (degenerated) normals with some default value 60 | v_nrm = torch.where( 61 | dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm) 62 | ) 63 | v_nrm = F.normalize(v_nrm, dim=1) 64 | assert torch.all(torch.isfinite(v_nrm)) 65 | 66 | return v_nrm 67 | 68 | 69 | class NeuralRender(Renderer): 70 | def __init__(self, device='cuda', camera_model=None): 71 | super(NeuralRender, self).__init__() 72 | self.device = device 73 | self.ctx = dr.RasterizeCudaContext(device=device) 74 | self.projection_mtx = None 75 | self.camera = camera_model 76 | 77 | def render_mesh( 78 | self, 79 | mesh_v_pos_bxnx3, 80 | mesh_t_pos_idx_fx3, 81 | camera_mv_bx4x4, 82 | mesh_v_feat_bxnxd, 83 | resolution=256, 84 | spp=1, 85 | device='cuda', 86 | hierarchical_mask=False 87 | ): 88 | assert not hierarchical_mask 89 | 90 | mtx_in = torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device) if not torch.is_tensor(camera_mv_bx4x4) else camera_mv_bx4x4 91 | v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in) # Rotate it to camera coordinates 92 | v_pos_clip = self.camera.project(v_pos) # Projection in the camera 93 | 94 | v_nrm = compute_vertex_normal(mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long()) # vertex normals in world coordinates 95 | 96 | # Render the image, 97 | # Here we only return the feature (3D location) at each pixel, which will be used as the input for neural render 98 | num_layers = 1 99 | mask_pyramid = None 100 | assert mesh_t_pos_idx_fx3.shape[0] > 0 # Make sure we have shapes 101 | mesh_v_feat_bxnxd = torch.cat([mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1) # Concatenate the pos 102 | 103 | with dr.DepthPeeler(self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, [resolution * spp, resolution * spp]) as peeler: 104 | for _ in range(num_layers): 105 | rast, db = peeler.rasterize_next_layer() 106 | gb_feat, _ = interpolate(mesh_v_feat_bxnxd, rast, mesh_t_pos_idx_fx3) 107 | 108 | hard_mask = torch.clamp(rast[..., -1:], 0, 1) 109 | antialias_mask = dr.antialias( 110 | hard_mask.clone().contiguous(), rast, v_pos_clip, 111 | mesh_t_pos_idx_fx3) 112 | 113 | depth = gb_feat[..., -2:-1] 114 | ori_mesh_feature = gb_feat[..., :-4] 115 | 116 | normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3) 117 | normal = dr.antialias(normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3) 118 | normal = F.normalize(normal, dim=-1) 119 | normal = torch.lerp(torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float()) # black background 120 | 121 | return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal 122 | 
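# ---------------------------------------------------------------------------
# Annotation added for this document (not part of the original file): a
# minimal sketch of compute_vertex_normal on a toy mesh, assuming torch and
# the module's dependencies are importable. A single triangle lying in the
# XY plane should receive a +Z normal at every vertex:
#
#   import torch
#   v = torch.tensor([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]])
#   f = torch.tensor([[0, 1, 2]])
#   compute_vertex_normal(v, f)   # -> every row is approximately [0., 0., 1.]
# ---------------------------------------------------------------------------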
-------------------------------------------------------------------------------- /preprocess/src/models/geometry/rep_3d/flexicubes_geometry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | 9 | import torch 10 | import numpy as np 11 | import os 12 | from . import Geometry 13 | from .flexicubes import FlexiCubes # replace later 14 | from .dmtet import sdf_reg_loss_batch 15 | import torch.nn.functional as F 16 | 17 | def get_center_boundary_index(grid_res, device): 18 | v = torch.zeros((grid_res + 1, grid_res + 1, grid_res + 1), dtype=torch.bool, device=device) 19 | v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = True 20 | center_indices = torch.nonzero(v.reshape(-1)) 21 | 22 | v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = False 23 | v[:2, ...] = True 24 | v[-2:, ...] = True 25 | v[:, :2, ...] = True 26 | v[:, -2:, ...] = True 27 | v[:, :, :2] = True 28 | v[:, :, -2:] = True 29 | boundary_indices = torch.nonzero(v.reshape(-1)) 30 | return center_indices, boundary_indices 31 | 32 | ############################################################################### 33 | # Geometry interface 34 | ############################################################################### 35 | class FlexiCubesGeometry(Geometry): 36 | def __init__( 37 | self, grid_res=64, scale=2.0, device='cuda', renderer=None, 38 | render_type='neural_render', args=None): 39 | super(FlexiCubesGeometry, self).__init__() 40 | self.grid_res = grid_res 41 | self.device = device 42 | self.args = args 43 | self.fc = FlexiCubes(device, weight_scale=0.5) 44 | self.verts, self.indices = self.fc.construct_voxel_grid(grid_res) 45 | if isinstance(scale, list): 46 | self.verts[:, 0] = self.verts[:, 0] * scale[0] 47 | self.verts[:, 1] = self.verts[:, 1] * scale[1] 48 | self.verts[:, 2] = self.verts[:, 2] * scale[1] 49 | else: 50 | self.verts = self.verts * scale 51 | 52 | all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2) 53 | self.all_edges = torch.unique(all_edges, dim=0) 54 | 55 | # Parameters used for fix boundary sdf 56 | self.center_indices, self.boundary_indices = get_center_boundary_index(self.grid_res, device) 57 | self.renderer = renderer 58 | self.render_type = render_type 59 | 60 | def getAABB(self): 61 | return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values 62 | 63 | def get_mesh(self, v_deformed_nx3, sdf_n, weight_n=None, with_uv=False, indices=None, is_training=False): 64 | if indices is None: 65 | indices = self.indices 66 | 67 | verts, faces, v_reg_loss = self.fc(v_deformed_nx3, sdf_n, indices, self.grid_res, 68 | beta_fx12=weight_n[:, :12], alpha_fx8=weight_n[:, 12:20], 69 | gamma_f=weight_n[:, 20], training=is_training 70 | ) 71 | return verts, faces, v_reg_loss 72 | 73 | 74 | def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False): 75 | return_value = dict() 76 | if self.render_type == 'neural_render': 77 | tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal = 
self.renderer.render_mesh( 78 | mesh_v_nx3.unsqueeze(dim=0), 79 | mesh_f_fx3.int(), 80 | camera_mv_bx4x4, 81 | mesh_v_nx3.unsqueeze(dim=0), 82 | resolution=resolution, 83 | device=self.device, 84 | hierarchical_mask=hierarchical_mask 85 | ) 86 | 87 | return_value['tex_pos'] = tex_pos 88 | return_value['mask'] = mask 89 | return_value['hard_mask'] = hard_mask 90 | return_value['rast'] = rast 91 | return_value['v_pos_clip'] = v_pos_clip 92 | return_value['mask_pyramid'] = mask_pyramid 93 | return_value['depth'] = depth 94 | return_value['normal'] = normal 95 | else: 96 | raise NotImplementedError 97 | 98 | return return_value 99 | 100 | def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256): 101 | # Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1 102 | v_list = [] 103 | f_list = [] 104 | n_batch = v_deformed_bxnx3.shape[0] 105 | all_render_output = [] 106 | for i_batch in range(n_batch): 107 | verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch]) 108 | v_list.append(verts_nx3) 109 | f_list.append(faces_fx3) 110 | render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution) 111 | all_render_output.append(render_output) 112 | 113 | # Concatenate all render output 114 | return_keys = all_render_output[0].keys() 115 | return_value = dict() 116 | for k in return_keys: 117 | value = [v[k] for v in all_render_output] 118 | return_value[k] = value 119 | # We can do concatenation outside of the render 120 | return return_value 121 | -------------------------------------------------------------------------------- /preprocess/src/models/renderer/synthesizer_mesh.py: -------------------------------------------------------------------------------- 1 | # ORIGINAL LICENSE 2 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 4 | # 5 | # Modified by Jiale Xu 6 | # The modifications are subject to the same license as the original. 7 | 8 | import itertools 9 | import torch 10 | import torch.nn as nn 11 | 12 | from .utils.renderer import generate_planes, project_onto_planes, sample_from_planes 13 | 14 | 15 | class OSGDecoder(nn.Module): 16 | """ 17 | Triplane decoder that gives RGB and sigma values from sampled features. 18 | Using ReLU here instead of Softplus in the original implementation. 
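Features sampled from the three planes are concatenated (3 * n_features per query point) and fed to separate MLP heads predicting an SDF value (1 channel), an RGB colour (3), a vertex deformation (3) and, from the 8 corner features of each FlexiCubes cell, 21 cell weights (scaled by 0.1).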
19 | 20 | Reference: 21 | EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112 22 | """ 23 | def __init__(self, n_features: int, 24 | hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU): 25 | super().__init__() 26 | 27 | self.net_sdf = nn.Sequential( 28 | nn.Linear(3 * n_features, hidden_dim), 29 | activation(), 30 | *itertools.chain(*[[ 31 | nn.Linear(hidden_dim, hidden_dim), 32 | activation(), 33 | ] for _ in range(num_layers - 2)]), 34 | nn.Linear(hidden_dim, 1), 35 | ) 36 | self.net_rgb = nn.Sequential( 37 | nn.Linear(3 * n_features, hidden_dim), 38 | activation(), 39 | *itertools.chain(*[[ 40 | nn.Linear(hidden_dim, hidden_dim), 41 | activation(), 42 | ] for _ in range(num_layers - 2)]), 43 | nn.Linear(hidden_dim, 3), 44 | ) 45 | self.net_deformation = nn.Sequential( 46 | nn.Linear(3 * n_features, hidden_dim), 47 | activation(), 48 | *itertools.chain(*[[ 49 | nn.Linear(hidden_dim, hidden_dim), 50 | activation(), 51 | ] for _ in range(num_layers - 2)]), 52 | nn.Linear(hidden_dim, 3), 53 | ) 54 | self.net_weight = nn.Sequential( 55 | nn.Linear(8 * 3 * n_features, hidden_dim), 56 | activation(), 57 | *itertools.chain(*[[ 58 | nn.Linear(hidden_dim, hidden_dim), 59 | activation(), 60 | ] for _ in range(num_layers - 2)]), 61 | nn.Linear(hidden_dim, 21), 62 | ) 63 | 64 | # init all bias to zero 65 | for m in self.modules(): 66 | if isinstance(m, nn.Linear): 67 | nn.init.zeros_(m.bias) 68 | 69 | def get_geometry_prediction(self, sampled_features, flexicubes_indices): 70 | _N, n_planes, _M, _C = sampled_features.shape 71 | sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) 72 | 73 | sdf = self.net_sdf(sampled_features) 74 | deformation = self.net_deformation(sampled_features) 75 | 76 | grid_features = torch.index_select(input=sampled_features, index=flexicubes_indices.reshape(-1), dim=1) 77 | grid_features = grid_features.reshape( 78 | sampled_features.shape[0], flexicubes_indices.shape[0], flexicubes_indices.shape[1] * sampled_features.shape[-1]) 79 | weight = self.net_weight(grid_features) * 0.1 80 | 81 | return sdf, deformation, weight 82 | 83 | def get_texture_prediction(self, sampled_features): 84 | _N, n_planes, _M, _C = sampled_features.shape 85 | sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) 86 | 87 | rgb = self.net_rgb(sampled_features) 88 | rgb = torch.sigmoid(rgb)*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF 89 | 90 | return rgb 91 | 92 | 93 | class TriplaneSynthesizer(nn.Module): 94 | """ 95 | Synthesizer that renders a triplane volume with planes and a camera. 
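get_geometry_prediction and get_texture_prediction sample the planes at the query coordinates with sample_from_planes (using the box_warp rendering option) and decode the sampled features with the OSGDecoder above.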
96 | 97 | Reference: 98 | EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19 99 | """ 100 | 101 | DEFAULT_RENDERING_KWARGS = { 102 | 'ray_start': 'auto', 103 | 'ray_end': 'auto', 104 | 'box_warp': 2., 105 | 'white_back': True, 106 | 'disparity_space_sampling': False, 107 | 'clamp_mode': 'softplus', 108 | 'sampler_bbox_min': -1., 109 | 'sampler_bbox_max': 1., 110 | } 111 | 112 | def __init__(self, triplane_dim: int, samples_per_ray: int): 113 | super().__init__() 114 | 115 | # attributes 116 | self.triplane_dim = triplane_dim 117 | self.rendering_kwargs = { 118 | **self.DEFAULT_RENDERING_KWARGS, 119 | 'depth_resolution': samples_per_ray // 2, 120 | 'depth_resolution_importance': samples_per_ray // 2, 121 | } 122 | 123 | # modules 124 | self.plane_axes = generate_planes() 125 | self.decoder = OSGDecoder(n_features=triplane_dim) 126 | 127 | def get_geometry_prediction(self, planes, sample_coordinates, flexicubes_indices): 128 | plane_axes = self.plane_axes.to(planes.device) 129 | sampled_features = sample_from_planes( 130 | plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp']) 131 | 132 | sdf, deformation, weight = self.decoder.get_geometry_prediction(sampled_features, flexicubes_indices) 133 | return sdf, deformation, weight 134 | 135 | def get_texture_prediction(self, planes, sample_coordinates): 136 | plane_axes = self.plane_axes.to(planes.device) 137 | sampled_features = sample_from_planes( 138 | plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp']) 139 | 140 | rgb = self.decoder.get_texture_prediction(sampled_features) 141 | return rgb 142 | -------------------------------------------------------------------------------- /core/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
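# Annotation added for this document (not part of the original source): this
# file provides plain and xformers-backed attention modules (Attention,
# MemEffAttention, CrossAttention, MemEffCrossAttention). Rough shape sketch
# with hypothetical sizes:
#
#   import torch
#   attn = Attention(dim=64, num_heads=8)
#   y = attn(torch.randn(2, 16, 64))                   # [B, N, C] -> [B, N, C]
#   xattn = CrossAttention(dim=64, dim_q=32, dim_k=48, dim_v=48)
#   z = xattn(torch.randn(2, 16, 32),
#             torch.randn(2, 9, 48), torch.randn(2, 9, 48))   # -> [2, 16, 64]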
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 9 | 10 | import os 11 | import warnings 12 | 13 | from torch import Tensor 14 | from torch import nn 15 | 16 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 17 | try: 18 | if XFORMERS_ENABLED: 19 | from xformers.ops import memory_efficient_attention, unbind 20 | 21 | XFORMERS_AVAILABLE = True 22 | warnings.warn("xFormers is available (Attention)") 23 | else: 24 | warnings.warn("xFormers is disabled (Attention)") 25 | raise ImportError 26 | except ImportError: 27 | XFORMERS_AVAILABLE = False 28 | warnings.warn("xFormers is not available (Attention)") 29 | 30 | 31 | class Attention(nn.Module): 32 | def __init__( 33 | self, 34 | dim: int, 35 | num_heads: int = 8, 36 | qkv_bias: bool = False, 37 | proj_bias: bool = True, 38 | attn_drop: float = 0.0, 39 | proj_drop: float = 0.0, 40 | ) -> None: 41 | super().__init__() 42 | self.num_heads = num_heads 43 | head_dim = dim // num_heads 44 | self.scale = head_dim**-0.5 45 | 46 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 47 | self.attn_drop = nn.Dropout(attn_drop) 48 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 49 | self.proj_drop = nn.Dropout(proj_drop) 50 | 51 | def forward(self, x: Tensor) -> Tensor: 52 | B, N, C = x.shape 53 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 54 | 55 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 56 | attn = q @ k.transpose(-2, -1) 57 | 58 | attn = attn.softmax(dim=-1) 59 | attn = self.attn_drop(attn) 60 | 61 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 62 | x = self.proj(x) 63 | x = self.proj_drop(x) 64 | return x 65 | 66 | 67 | class MemEffAttention(Attention): 68 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 69 | if not XFORMERS_AVAILABLE: 70 | if attn_bias is not None: 71 | raise AssertionError("xFormers is required for using nested tensors") 72 | return super().forward(x) 73 | 74 | B, N, C = x.shape 75 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 76 | 77 | q, k, v = unbind(qkv, 2) 78 | 79 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 80 | x = x.reshape([B, N, C]) 81 | 82 | x = self.proj(x) 83 | x = self.proj_drop(x) 84 | return x 85 | 86 | 87 | class CrossAttention(nn.Module): 88 | def __init__( 89 | self, 90 | dim: int, 91 | dim_q: int, 92 | dim_k: int, 93 | dim_v: int, 94 | num_heads: int = 8, 95 | qkv_bias: bool = False, 96 | proj_bias: bool = True, 97 | attn_drop: float = 0.0, 98 | proj_drop: float = 0.0, 99 | ) -> None: 100 | super().__init__() 101 | self.dim = dim 102 | self.num_heads = num_heads 103 | head_dim = dim // num_heads 104 | self.scale = head_dim**-0.5 105 | 106 | self.to_q = nn.Linear(dim_q, dim, bias=qkv_bias) 107 | self.to_k = nn.Linear(dim_k, dim, bias=qkv_bias) 108 | self.to_v = nn.Linear(dim_v, dim, bias=qkv_bias) 109 | self.attn_drop = nn.Dropout(attn_drop) 110 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 111 | self.proj_drop = nn.Dropout(proj_drop) 112 | 113 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 114 | # q: [B, N, Cq] 115 | # k: [B, M, Ck] 116 | # v: [B, M, Cv] 117 | # return: [B, N, C] 118 | 119 | B, N, _ = q.shape 120 | M = k.shape[1] 121 | 122 | q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, N, C/nh] 123 | k = self.to_k(k).reshape(B, M, 
self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, M, C/nh] 124 | v = self.to_v(v).reshape(B, M, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, M, C/nh] 125 | 126 | attn = q @ k.transpose(-2, -1) # [B, nh, N, M] 127 | 128 | attn = attn.softmax(dim=-1) # [B, nh, N, M] 129 | attn = self.attn_drop(attn) 130 | 131 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) # [B, nh, N, M] @ [B, nh, M, C/nh] --> [B, nh, N, C/nh] --> [B, N, nh, C/nh] --> [B, N, C] 132 | x = self.proj(x) 133 | x = self.proj_drop(x) 134 | return x 135 | 136 | 137 | class MemEffCrossAttention(CrossAttention): 138 | def forward(self, q: Tensor, k: Tensor, v: Tensor, attn_bias=None) -> Tensor: 139 | if not XFORMERS_AVAILABLE: 140 | if attn_bias is not None: 141 | raise AssertionError("xFormers is required for using nested tensors") 142 | return super().forward(x) 143 | 144 | B, N, _ = q.shape 145 | M = k.shape[1] 146 | 147 | q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, self.dim // self.num_heads) # [B, N, nh, C/nh] 148 | k = self.to_k(k).reshape(B, M, self.num_heads, self.dim // self.num_heads) # [B, M, nh, C/nh] 149 | v = self.to_v(v).reshape(B, M, self.num_heads, self.dim // self.num_heads) # [B, M, nh, C/nh] 150 | 151 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 152 | x = x.reshape(B, N, -1) 153 | 154 | x = self.proj(x) 155 | x = self.proj_drop(x) 156 | return x 157 | -------------------------------------------------------------------------------- /preprocess/src/models/renderer/utils/ray_sampler.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 3 | # 4 | # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual 5 | # property and proprietary rights in and to this material, related 6 | # documentation and any modifications thereto. Any use, reproduction, 7 | # disclosure or distribution of this material and related documentation 8 | # without an express license agreement from NVIDIA CORPORATION or 9 | # its affiliates is strictly prohibited. 10 | # 11 | # Modified by Jiale Xu 12 | # The modifications are subject to the same license as the original. 13 | 14 | 15 | """ 16 | The ray sampler is a module that takes in camera matrices and resolution and batches of rays. 17 | Expects cam2world matrices that use the OpenCV camera coordinate system conventions. 18 | """ 19 | 20 | import torch 21 | 22 | class RaySampler(torch.nn.Module): 23 | def __init__(self): 24 | super().__init__() 25 | self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None 26 | 27 | 28 | def forward(self, cam2world_matrix, intrinsics, render_size): 29 | """ 30 | Create batches of rays and return origins and directions. 
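One ray is cast through each pixel centre of a render_size x render_size image, so M = render_size**2; the directions are unit 3-vectors in world space (shape (N, M, 3)).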
31 | 32 | cam2world_matrix: (N, 4, 4) 33 | intrinsics: (N, 3, 3) 34 | render_size: int 35 | 36 | ray_origins: (N, M, 3) 37 | ray_dirs: (N, M, 2) 38 | """ 39 | 40 | dtype = cam2world_matrix.dtype 41 | device = cam2world_matrix.device 42 | N, M = cam2world_matrix.shape[0], render_size**2 43 | cam_locs_world = cam2world_matrix[:, :3, 3] 44 | fx = intrinsics[:, 0, 0] 45 | fy = intrinsics[:, 1, 1] 46 | cx = intrinsics[:, 0, 2] 47 | cy = intrinsics[:, 1, 2] 48 | sk = intrinsics[:, 0, 1] 49 | 50 | uv = torch.stack(torch.meshgrid( 51 | torch.arange(render_size, dtype=dtype, device=device), 52 | torch.arange(render_size, dtype=dtype, device=device), 53 | indexing='ij', 54 | )) 55 | uv = uv.flip(0).reshape(2, -1).transpose(1, 0) 56 | uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1) 57 | 58 | x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size) 59 | y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size) 60 | z_cam = torch.ones((N, M), dtype=dtype, device=device) 61 | 62 | x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y_cam/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam 63 | y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam 64 | 65 | cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1).to(dtype) 66 | 67 | _opencv2blender = torch.tensor([ 68 | [1, 0, 0, 0], 69 | [0, -1, 0, 0], 70 | [0, 0, -1, 0], 71 | [0, 0, 0, 1], 72 | ], dtype=dtype, device=device).unsqueeze(0).repeat(N, 1, 1) 73 | 74 | cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender) 75 | 76 | world_rel_points = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3] 77 | 78 | ray_dirs = world_rel_points - cam_locs_world[:, None, :] 79 | ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2).to(dtype) 80 | 81 | ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1) 82 | 83 | return ray_origins, ray_dirs 84 | 85 | 86 | class OrthoRaySampler(torch.nn.Module): 87 | def __init__(self): 88 | super().__init__() 89 | self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None 90 | 91 | 92 | def forward(self, cam2world_matrix, ortho_scale, render_size): 93 | """ 94 | Create batches of rays and return origins and directions. 
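In the orthographic sampler every ray shares the camera viewing direction, while the origins are spread uniformly over an ortho_scale x ortho_scale plane centred on the camera.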
95 | 96 | cam2world_matrix: (N, 4, 4) 97 | ortho_scale: float 98 | render_size: int 99 | 100 | ray_origins: (N, M, 3) 101 | ray_dirs: (N, M, 3) 102 | """ 103 | 104 | N, M = cam2world_matrix.shape[0], render_size**2 105 | 106 | uv = torch.stack(torch.meshgrid( 107 | torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device), 108 | torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device), 109 | indexing='ij', 110 | )) 111 | uv = uv.flip(0).reshape(2, -1).transpose(1, 0) 112 | uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1) 113 | 114 | x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size) 115 | y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size) 116 | z_cam = torch.zeros((N, M), device=cam2world_matrix.device) 117 | 118 | x_lift = (x_cam - 0.5) * ortho_scale 119 | y_lift = (y_cam - 0.5) * ortho_scale 120 | 121 | cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1) 122 | 123 | _opencv2blender = torch.tensor([ 124 | [1, 0, 0, 0], 125 | [0, -1, 0, 0], 126 | [0, 0, -1, 0], 127 | [0, 0, 0, 1], 128 | ], dtype=torch.float32, device=cam2world_matrix.device).unsqueeze(0).repeat(N, 1, 1) 129 | 130 | cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender) 131 | 132 | ray_origins = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3] 133 | 134 | ray_dirs_cam = torch.stack([ 135 | torch.zeros((N, M), device=cam2world_matrix.device), 136 | torch.zeros((N, M), device=cam2world_matrix.device), 137 | torch.ones((N, M), device=cam2world_matrix.device), 138 | ], dim=-1) 139 | ray_dirs = torch.bmm(cam2world_matrix[:, :3, :3], ray_dirs_cam.permute(0, 2, 1)).permute(0, 2, 1) 140 | 141 | return ray_origins, ray_dirs 142 | -------------------------------------------------------------------------------- /preprocess/src/utils/mesh_util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 
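# Annotation added for this document (not part of the original source): this
# module collects mesh I/O helpers (save_obj, save_glb, save_obj_with_mtl,
# loadobj, loadobjtex) and an xatlas-based UV unwrapping utility
# (xatlas_uvmap). Hedged round-trip sketch; the file names are made up:
#
#   verts, faces = loadobj('chair.obj')     # float32 [P, 3], int64 [F, 3]
#   colors = np.ones_like(verts)            # per-vertex RGB in [0, 1]
#   save_obj(verts, faces, colors, 'chair_colored.obj')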
8 | 9 | import torch 10 | import xatlas 11 | import trimesh 12 | import cv2 13 | import numpy as np 14 | import nvdiffrast.torch as dr 15 | from PIL import Image 16 | 17 | 18 | def save_obj(pointnp_px3, facenp_fx3, colornp_px3, fpath): 19 | 20 | pointnp_px3 = pointnp_px3 @ np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]]) 21 | facenp_fx3 = facenp_fx3[:, [2, 1, 0]] 22 | 23 | mesh = trimesh.Trimesh( 24 | vertices=pointnp_px3, 25 | faces=facenp_fx3, 26 | vertex_colors=colornp_px3, 27 | ) 28 | mesh.export(fpath, 'obj') 29 | 30 | 31 | def save_glb(pointnp_px3, facenp_fx3, colornp_px3, fpath): 32 | 33 | pointnp_px3 = pointnp_px3 @ np.array([[-1, 0, 0], [0, 1, 0], [0, 0, -1]]) 34 | 35 | mesh = trimesh.Trimesh( 36 | vertices=pointnp_px3, 37 | faces=facenp_fx3, 38 | vertex_colors=colornp_px3, 39 | ) 40 | mesh.export(fpath, 'glb') 41 | 42 | 43 | def save_obj_with_mtl(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, texmap_hxwx3, fname): 44 | import os 45 | fol, na = os.path.split(fname) 46 | na, _ = os.path.splitext(na) 47 | 48 | matname = '%s/%s.mtl' % (fol, na) 49 | fid = open(matname, 'w') 50 | fid.write('newmtl material_0\n') 51 | fid.write('Kd 1 1 1\n') 52 | fid.write('Ka 0 0 0\n') 53 | fid.write('Ks 0.4 0.4 0.4\n') 54 | fid.write('Ns 10\n') 55 | fid.write('illum 2\n') 56 | fid.write('map_Kd %s.png\n' % na) 57 | fid.close() 58 | #### 59 | 60 | fid = open(fname, 'w') 61 | fid.write('mtllib %s.mtl\n' % na) 62 | 63 | for pidx, p in enumerate(pointnp_px3): 64 | pp = p 65 | fid.write('v %f %f %f\n' % (pp[0], pp[1], pp[2])) 66 | 67 | for pidx, p in enumerate(tcoords_px2): 68 | pp = p 69 | fid.write('vt %f %f\n' % (pp[0], pp[1])) 70 | 71 | fid.write('usemtl material_0\n') 72 | for i, f in enumerate(facenp_fx3): 73 | f1 = f + 1 74 | f2 = facetex_fx3[i] + 1 75 | fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2])) 76 | fid.close() 77 | 78 | # save texture map 79 | lo, hi = 0, 1 80 | img = np.asarray(texmap_hxwx3, dtype=np.float32) 81 | img = (img - lo) * (255 / (hi - lo)) 82 | img = img.clip(0, 255) 83 | mask = np.sum(img.astype(np.float32), axis=-1, keepdims=True) 84 | mask = (mask <= 3.0).astype(np.float32) 85 | kernel = np.ones((3, 3), 'uint8') 86 | dilate_img = cv2.dilate(img, kernel, iterations=1) 87 | img = img * (1 - mask) + dilate_img * mask 88 | img = img.clip(0, 255).astype(np.uint8) 89 | Image.fromarray(np.ascontiguousarray(img[::-1, :, :]), 'RGB').save(f'{fol}/{na}.png') 90 | 91 | 92 | def loadobj(meshfile): 93 | v = [] 94 | f = [] 95 | meshfp = open(meshfile, 'r') 96 | for line in meshfp.readlines(): 97 | data = line.strip().split(' ') 98 | data = [da for da in data if len(da) > 0] 99 | if len(data) != 4: 100 | continue 101 | if data[0] == 'v': 102 | v.append([float(d) for d in data[1:]]) 103 | if data[0] == 'f': 104 | data = [da.split('/')[0] for da in data] 105 | f.append([int(d) for d in data[1:]]) 106 | meshfp.close() 107 | 108 | # torch need int64 109 | facenp_fx3 = np.array(f, dtype=np.int64) - 1 110 | pointnp_px3 = np.array(v, dtype=np.float32) 111 | return pointnp_px3, facenp_fx3 112 | 113 | 114 | def loadobjtex(meshfile): 115 | v = [] 116 | vt = [] 117 | f = [] 118 | ft = [] 119 | meshfp = open(meshfile, 'r') 120 | for line in meshfp.readlines(): 121 | data = line.strip().split(' ') 122 | data = [da for da in data if len(da) > 0] 123 | if not ((len(data) == 3) or (len(data) == 4) or (len(data) == 5)): 124 | continue 125 | if data[0] == 'v': 126 | assert len(data) == 4 127 | 128 | v.append([float(d) for d in data[1:]]) 129 | if data[0] == 'vt': 130 | if 
len(data) == 3 or len(data) == 4: 131 | vt.append([float(d) for d in data[1:3]]) 132 | if data[0] == 'f': 133 | data = [da.split('/') for da in data] 134 | if len(data) == 4: 135 | f.append([int(d[0]) for d in data[1:]]) 136 | ft.append([int(d[1]) for d in data[1:]]) 137 | elif len(data) == 5: 138 | idx1 = [1, 2, 3] 139 | data1 = [data[i] for i in idx1] 140 | f.append([int(d[0]) for d in data1]) 141 | ft.append([int(d[1]) for d in data1]) 142 | idx2 = [1, 3, 4] 143 | data2 = [data[i] for i in idx2] 144 | f.append([int(d[0]) for d in data2]) 145 | ft.append([int(d[1]) for d in data2]) 146 | meshfp.close() 147 | 148 | # torch need int64 149 | facenp_fx3 = np.array(f, dtype=np.int64) - 1 150 | ftnp_fx3 = np.array(ft, dtype=np.int64) - 1 151 | pointnp_px3 = np.array(v, dtype=np.float32) 152 | uvs = np.array(vt, dtype=np.float32) 153 | return pointnp_px3, facenp_fx3, uvs, ftnp_fx3 154 | 155 | 156 | # ============================================================================================== 157 | def interpolate(attr, rast, attr_idx, rast_db=None): 158 | return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all') 159 | 160 | 161 | def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution): 162 | vmapping, indices, uvs = xatlas.parametrize(mesh_v.detach().cpu().numpy(), mesh_pos_idx.detach().cpu().numpy()) 163 | 164 | # Convert to tensors 165 | indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64) 166 | 167 | uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device) 168 | mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device) 169 | # mesh_v_tex. ture 170 | uv_clip = uvs[None, ...] * 2.0 - 1.0 171 | 172 | # pad to four component coordinate 173 | uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1) 174 | 175 | # rasterize 176 | rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution)) 177 | 178 | # Interpolate world space position 179 | gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int()) 180 | mask = rast[..., 3:4] > 0 181 | return uvs, mesh_tex_idx, gb_pos, mask 182 | -------------------------------------------------------------------------------- /core/drag_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from torchvision.transforms import GaussianBlur 5 | 6 | class FourierEmbedder(object): 7 | def __init__(self, num_freqs=64, temperature=100): 8 | 9 | self.num_freqs = num_freqs 10 | self.temperature = temperature 11 | self.freq_bands = temperature ** ( torch.arange(num_freqs) / num_freqs ) 12 | 13 | @torch.no_grad() 14 | def __call__(self, x, cat_dim=-1): 15 | "x: arbitrary shape of tensor. 
dim: cat dim" 16 | out = [] 17 | for freq in self.freq_bands: 18 | out.append( torch.sin( freq*x ) ) 19 | out.append( torch.cos( freq*x ) ) 20 | return torch.cat(out, cat_dim) 21 | 22 | 23 | class DragPositionNet(nn.Module): 24 | def __init__(self, num_drags, fourier_freqs=8, downsample_ratio=64): 25 | super().__init__() 26 | self.num_drags = num_drags 27 | 28 | self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs) 29 | self.position_dim = fourier_freqs*2*2 # 2 for sin and cos, 2 for 2 dims (x1, y1) or (x2, y2) 30 | 31 | # -------------------------------------------------------------- # 32 | self.linears_drag = nn.Sequential( 33 | nn.Linear(self.position_dim, 128), 34 | nn.SiLU(), 35 | nn.Linear(128, 256), 36 | nn.SiLU(), 37 | nn.Linear(256, 512), 38 | ) 39 | 40 | self.downsample_ratio = downsample_ratio 41 | 42 | 43 | def forward(self, drags_start, drags_end): 44 | # drags_start: [B, V, N, 2], start points of drags 45 | # drags_end: [B, V, N, 2], move vectors of drags 46 | B, V, N, _ = drags_start.shape 47 | drags_start = drags_start.view(B*V, N, -1) 48 | 49 | drags_start_embeddings = [] 50 | for i in range(N): 51 | drag_start_embedding = self.fourier_embedder(drags_start[:, i, :]) 52 | drags_start_embeddings.append(self.linears_drag(drag_start_embedding)) 53 | drags_start_embeddings = torch.stack(drags_start_embeddings, dim=1) 54 | 55 | drags_end = drags_end.view(B*V, N, -1) 56 | drags_end_embeddings = [] 57 | for i in range(N): 58 | drag_end_embedding = self.fourier_embedder(drags_end[:, i, :]) 59 | drags_end_embeddings.append(self.linears_drag(drag_end_embedding)) 60 | drags_end_embeddings = torch.stack(drags_end_embeddings, dim=1) 61 | 62 | merge_start_embeddings = torch.zeros((B*V, 512, 8, 8)).to(drag_start_embedding.device) # [B*V, 256, 8, 8] 63 | merge_end_embeddings = torch.zeros((B*V, 512, 8, 8)).to(drag_start_embedding.device) # [B*V, 256, 8, 8] 64 | 65 | for i in range(B*V): 66 | for j in range(N): 67 | merge_start_embeddings[i, :, int(drags_start[i, j, 0]) // self.downsample_ratio, 68 | int(drags_start[i, j, 1]) // self.downsample_ratio] += drags_start_embeddings[i,j,:] 69 | merge_end_embeddings[i, :, int(drags_end[i, j, 0]) // self.downsample_ratio, 70 | int(drags_end[i, j, 1]) // self.downsample_ratio] += drags_end_embeddings[i,j, :] 71 | 72 | merge_embeddings = torch.cat([merge_start_embeddings, merge_end_embeddings], dim=1) 73 | return merge_embeddings 74 | 75 | 76 | class DragPositionNetMultiScale(nn.Module): 77 | def __init__(self, fourier_freqs=8, scales=[256, 128, 64, 32, 16, 8], channels=[64, 64, 128, 256, 512, 1024], drag_layer_idx=None): 78 | super().__init__() 79 | self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs) 80 | self.position_dim = fourier_freqs*2*2 # 2 for sin and cos, 2 for 2 dims (x1, y1) or (x2, y2) 81 | 82 | # -------------------------------------------------------------- # 83 | self.linear_drags = [] 84 | for i in range(len(channels)): 85 | if drag_layer_idx is not None and i != drag_layer_idx: 86 | continue 87 | self.linear_drags.append(nn.Sequential( 88 | nn.Linear(self.position_dim, 128), 89 | nn.SiLU(), 90 | nn.Linear(128, 256), 91 | nn.SiLU(), 92 | nn.Linear(256, channels[i]//2), 93 | )) 94 | 95 | self.linears_drags = nn.ModuleList(self.linear_drags) 96 | self.gaussian_blur = GaussianBlur(kernel_size=5, sigma=1.0) 97 | 98 | self.scales = scales 99 | self.channels = channels 100 | 101 | if drag_layer_idx is not None: 102 | self.drag_layer_idx = drag_layer_idx 103 | self.scales = 
self.scales[self.drag_layer_idx:self.drag_layer_idx+1] 104 | self.channels = self.channels[self.drag_layer_idx:self.drag_layer_idx+1] 105 | 106 | def forward(self, drags_start, drags_end): 107 | scales = self.scales 108 | channels = self.channels 109 | 110 | B, V, N, _ = drags_start.shape 111 | 112 | drags_start = drags_start.view(B*V, N, -1) 113 | drags_end = drags_end.view(B*V, N, -1) 114 | 115 | 116 | multi_scale_merge_start_embeddings = [] 117 | multi_scale_merge_end_embeddings = [] 118 | 119 | for idx, scale in enumerate(scales): 120 | drags_start_embeddings = [] 121 | drags_end_embeddings = [] 122 | for i in range(N): 123 | drag_start_embedding = self.fourier_embedder(drags_start[:, i, :]) 124 | drags_start_embeddings.append(self.linears_drags[idx](drag_start_embedding)) 125 | drags_start_embeddings = torch.stack(drags_start_embeddings, dim=1) 126 | 127 | for i in range(N): 128 | drag_end_embedding = self.fourier_embedder(drags_end[:, i, :]) 129 | drags_end_embeddings.append(self.linears_drags[idx](drag_end_embedding)) 130 | drags_end_embeddings = torch.stack(drags_end_embeddings, dim=1) 131 | 132 | merge_start_embeddings = torch.zeros((B*V, channels[idx]//2, scale, scale)).to(drag_start_embedding.device) 133 | merge_end_embeddings = torch.zeros((B*V, channels[idx]//2, scale, scale)).to(drag_start_embedding.device) 134 | downsample_ratio = 512 // scale 135 | 136 | for i in range(B*V): 137 | for j in range(N): 138 | merge_start_embeddings[i, :, int(drags_start[i, j, 0]) // downsample_ratio, 139 | int(drags_start[i, j, 1]) // downsample_ratio] += drags_start_embeddings[i,j,:] 140 | merge_end_embeddings[i, :, int(drags_end[i, j, 0]) // downsample_ratio, 141 | int(drags_end[i, j, 1]) // downsample_ratio] += drags_end_embeddings[i,j, :] 142 | # merge_end_embeddings[i, :, int(drags_start[i, j, 0]) // downsample_ratio, 143 | # int(drags_start[i, j, 1]) // downsample_ratio] += drags_end_embeddings[i,j, :] 144 | # Add Gaussian Blur 145 | merge_start_embeddings = self.gaussian_blur(merge_start_embeddings) 146 | merge_end_embeddings = self.gaussian_blur(merge_end_embeddings) 147 | 148 | multi_scale_merge_start_embeddings.append(merge_start_embeddings) 149 | multi_scale_merge_end_embeddings.append(merge_end_embeddings) 150 | 151 | return multi_scale_merge_start_embeddings, multi_scale_merge_end_embeddings 152 | -------------------------------------------------------------------------------- /core/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from safetensors.torch import load_file 6 | 7 | import kiui 8 | from kiui.lpips import LPIPS 9 | 10 | from core.unet import UNet, UNetWithMSDrag 11 | from core.options import Options 12 | from core.gs import GaussianRenderer 13 | 14 | 15 | class LGM(nn.Module): 16 | def __init__( 17 | self, 18 | opt: Options, 19 | ): 20 | super().__init__() 21 | 22 | self.opt = opt 23 | 24 | # unet 25 | if opt.use_ms_drag_encoding: 26 | self.unet = UNetWithMSDrag( 27 | 9, 14, 28 | down_channels=opt.down_channels, 29 | down_attention=opt.down_attention, 30 | mid_attention=opt.mid_attention, 31 | up_channels=opt.up_channels, 32 | up_attention=opt.up_attention, 33 | use_drag_encoding=opt.use_drag_encoding 34 | ) 35 | else: 36 | self.unet = UNet( 37 | 9, 14, 38 | down_channels=self.opt.down_channels, 39 | down_attention=self.opt.down_attention, 40 | mid_attention=self.opt.mid_attention, 41 | up_channels=self.opt.up_channels, 42 | 
up_attention=self.opt.up_attention, 43 | use_drag_encoding=self.opt.use_drag_encoding 44 | ) 45 | 46 | # last conv 47 | self.conv = nn.Conv2d(14, 14, kernel_size=1) # NOTE: maybe remove it if train again 48 | 49 | # Gaussian Renderer 50 | self.gs = GaussianRenderer(opt) 51 | 52 | # activations... 53 | self.pos_act = lambda x: x.clamp(-1, 1) 54 | self.scale_act = lambda x: 0.1 * F.softplus(x) 55 | self.opacity_act = lambda x: torch.sigmoid(x) 56 | self.rot_act = lambda x: F.normalize(x, dim=-1) 57 | self.rgb_act = lambda x: 0.5 * torch.tanh(x) + 0.5 # NOTE: may use sigmoid if train again 58 | 59 | # LPIPS loss 60 | if self.opt.lambda_lpips > 0: 61 | self.lpips_loss = LPIPS(net='vgg') 62 | self.lpips_loss.requires_grad_(False) 63 | 64 | 65 | def state_dict(self, **kwargs): 66 | # remove lpips_loss 67 | state_dict = super().state_dict(**kwargs) 68 | for k in list(state_dict.keys()): 69 | if 'lpips_loss' in k: 70 | del state_dict[k] 71 | return state_dict 72 | 73 | 74 | def prepare_default_rays(self, device, elevation=0): 75 | 76 | from kiui.cam import orbit_camera 77 | from core.utils import get_rays 78 | 79 | cam_poses = np.stack([ 80 | orbit_camera(elevation, 0, radius=self.opt.cam_radius), 81 | orbit_camera(elevation, 90, radius=self.opt.cam_radius), 82 | orbit_camera(elevation, 180, radius=self.opt.cam_radius), 83 | orbit_camera(elevation, 270, radius=self.opt.cam_radius), 84 | ], axis=0) # [4, 4, 4] 85 | cam_poses = torch.from_numpy(cam_poses) 86 | 87 | rays_embeddings = [] 88 | for i in range(cam_poses.shape[0]): 89 | rays_o, rays_d = get_rays(cam_poses[i], self.opt.input_size, self.opt.input_size, self.opt.fovy) # [h, w, 3] 90 | rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6] 91 | rays_embeddings.append(rays_plucker) 92 | 93 | rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous().to(device) # [V, 6, h, w] 94 | 95 | return rays_embeddings 96 | 97 | 98 | def forward_gaussians(self, images, drags_start=None, drags_end=None): 99 | # images: [B, 4, 9, H, W] 100 | # return: Gaussians: [B, dim_t] 101 | B, V, C, H, W = images.shape 102 | images = images.view(B*V, C, H, W) 103 | 104 | x = self.unet(images, drags_start, drags_end) # [B*4, 14, h, w] 105 | x = self.conv(x) # [B*4, 14, h, w] 106 | 107 | x = x.reshape(B, 4, 14, self.opt.splat_size, self.opt.splat_size) 108 | x = x.permute(0, 1, 3, 4, 2).reshape(B, -1, 14) 109 | 110 | pos = self.pos_act(x[..., 0:3]) # [B, N, 3] 111 | opacity = self.opacity_act(x[..., 3:4]) 112 | scale = self.scale_act(x[..., 4:7]) 113 | rotation = self.rot_act(x[..., 7:11]) 114 | rgbs = self.rgb_act(x[..., 11:]) 115 | 116 | gaussians = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1) # [B, N, 14] 117 | 118 | return gaussians 119 | 120 | def get_gaussian_loss(self, gt_gaussians, pred_gaussians): 121 | gs_loss = 0 122 | 123 | for b in range(gt_gaussians.shape[0]): 124 | gs_loss += F.mse_loss(gt_gaussians[b], pred_gaussians[b], reduction='mean') 125 | 126 | return gs_loss / gt_gaussians.shape[0] 127 | 128 | def forward(self, data, step_ratio=1): 129 | # data: output of the dataloader 130 | # return: loss 131 | 132 | results = {} 133 | loss = 0 134 | 135 | images = data['input'] # [B, 4, 9, h, W], input features 136 | if self.opt.use_drag_encoding: 137 | drags_start = data['drags_start'] # [B, N, 3], start points of drags 138 | drags_end = data['drags_end'] # [B, N, 3], move vectors of drags 139 | else: 140 | drags_start = None 141 | drags_end = None 142 | 143 | # use the first view to 
predict gaussians 144 | gaussians = self.forward_gaussians(images, drags_start, drags_end) # [B, N, 14] 145 | 146 | # always use white bg 147 | bg_color = torch.ones(3, dtype=torch.float32, device=gaussians.device) 148 | 149 | # use the other views for rendering and supervision 150 | results = self.gs.render(gaussians, data['cam_view'], data['cam_view_proj'], data['cam_pos'], bg_color=bg_color) 151 | pred_images = results['image'] # [B, V, C, output_size, output_size] 152 | pred_alphas = results['alpha'] # [B, V, 1, output_size, output_size] 153 | 154 | results['gaussians'] = gaussians 155 | 156 | results['images_pred'] = pred_images 157 | results['alphas_pred'] = pred_alphas 158 | 159 | gt_images = data['images_output'] # [B, V, 3, output_size, output_size], ground-truth novel views 160 | gt_masks = data['masks_output'] # [B, V, 1, output_size, output_size], ground-truth masks 161 | 162 | gt_images = gt_images * gt_masks + bg_color.view(1, 1, 3, 1, 1) * (1 - gt_masks) 163 | 164 | loss_mse = F.mse_loss(pred_images, gt_images) + F.mse_loss(pred_alphas, gt_masks) 165 | loss = loss + loss_mse 166 | 167 | if self.opt.lambda_lpips > 0: 168 | loss_lpips = self.lpips_loss( 169 | # downsampled to at most 256 to reduce memory cost 170 | F.interpolate(gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False), 171 | F.interpolate(pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False), 172 | ).mean() 173 | results['loss_lpips'] = loss_lpips 174 | loss = loss + self.opt.lambda_lpips * loss_lpips 175 | 176 | if self.opt.lambda_flow > 0 and self.opt.stage1: 177 | loss_flow = self.get_gaussian_loss(data['gt_gaussians'], gaussians) 178 | results['loss_flow'] = loss_flow 179 | loss = self.opt.lambda_flow * loss_flow 180 | 181 | results['loss'] = loss 182 | 183 | # metric 184 | with torch.no_grad(): 185 | psnr = -10 * torch.log10(torch.mean((pred_images.detach() - gt_images) ** 2)) 186 | results['psnr'] = psnr 187 | 188 | 189 | return results, drags_start, drags_end 190 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import tyro 2 | import time 3 | import random 4 | 5 | import torch 6 | from core.models import LGM 7 | from accelerate import Accelerator, DistributedDataParallelKwargs 8 | from safetensors.torch import load_file 9 | 10 | import kiui 11 | 12 | from PIL import Image, ImageDraw 13 | import numpy as np 14 | 15 | from torch.utils.tensorboard import SummaryWriter 16 | 17 | import argparse 18 | 19 | def main(dataset_name): 20 | if dataset_name == 'partdrag4d': 21 | from core.options import AllConfigs 22 | else: 23 | from core.options_pm import AllConfigs 24 | 25 | opt = tyro.cli(AllConfigs) 26 | 27 | accelerator = Accelerator( 28 | mixed_precision=opt.mixed_precision, 29 | gradient_accumulation_steps=opt.gradient_accumulation_steps, 30 | ) 31 | 32 | model = LGM(opt) 33 | # resume 34 | if opt.resume is not None: 35 | if opt.resume.endswith('safetensors'): 36 | ckpt = load_file(opt.resume, device='cpu') 37 | else: 38 | ckpt = torch.load(opt.resume, map_location='cpu') 39 | model.load_state_dict(ckpt, strict=False) 40 | 41 | # data 42 | if dataset_name == 'partdrag4d': 43 | from core.train_dataset_partdrag4d import PartDrag4DTrainDatset as TrainDataset 44 | from core.eval_dataset_partdrag4d import PartDrag4DEvalDatset as EvalDataset 45 | 
elif dataset_name == 'objaverse_hq': 46 | from core.train_dataset_objaverse_hq import ObjaverseHQTrainDataset as TrainDataset 47 | from core.eval_dataset_objaverse_hq import ObjaverseHQEvalDataset as EvalDataset 48 | 49 | train_dataset = TrainDataset(opt) 50 | train_dataloader = torch.utils.data.DataLoader( 51 | train_dataset, 52 | batch_size=opt.batch_size, 53 | shuffle=True, 54 | num_workers=4, 55 | pin_memory=True, 56 | drop_last=True, 57 | ) 58 | 59 | test_dataset = EvalDataset(opt) 60 | test_dataloader = torch.utils.data.DataLoader( 61 | test_dataset, 62 | batch_size=1, 63 | shuffle=False, 64 | num_workers=4, 65 | pin_memory=True, 66 | drop_last=False, 67 | ) 68 | 69 | # optimizer 70 | optimizer = torch.optim.AdamW(model.parameters(), lr=opt.lr, weight_decay=0.05, betas=(0.9, 0.95)) 71 | 72 | total_steps = opt.num_epochs * len(train_dataloader) 73 | pct_start = 3000 / total_steps 74 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=opt.lr, total_steps=total_steps, pct_start=pct_start) 75 | 76 | # accelerate 77 | model, optimizer, train_dataloader, test_dataloader, scheduler = accelerator.prepare( 78 | model, optimizer, train_dataloader, test_dataloader, scheduler 79 | ) 80 | 81 | writer = SummaryWriter(opt.workspace) 82 | 83 | # loop 84 | for epoch in range(opt.num_epochs): 85 | # train 86 | model.train() 87 | total_loss = 0 88 | total_psnr = 0 89 | for i, data in enumerate(train_dataloader): 90 | with accelerator.accumulate(model): 91 | 92 | optimizer.zero_grad() 93 | 94 | step_ratio = (epoch + i / len(train_dataloader)) / opt.num_epochs 95 | 96 | out, drag_start_2d, drag_move_2d = model(data, step_ratio) 97 | loss = out['loss'] 98 | psnr = out['psnr'] 99 | 100 | writer.add_scalar('Loss/train_iter', loss.item(), epoch * len(train_dataloader) + i) 101 | 102 | accelerator.backward(loss) 103 | # gradient clipping 104 | if accelerator.sync_gradients: 105 | accelerator.clip_grad_norm_(model.parameters(), opt.gradient_clip) 106 | 107 | optimizer.step() 108 | scheduler.step() 109 | 110 | total_loss += loss.detach() 111 | total_psnr += psnr.detach() 112 | 113 | torch.cuda.empty_cache() 114 | 115 | if accelerator.is_main_process: 116 | # logging 117 | if i % 100 == 0: 118 | mem_free, mem_total = torch.cuda.mem_get_info() 119 | print(f"[INFO] {i}/{len(train_dataloader)} mem: {(mem_total-mem_free)/1024**3:.2f}/{mem_total/1024**3:.2f}G lr: {scheduler.get_last_lr()[0]:.7f} step_ratio: {step_ratio:.4f} loss: {loss.item():.6f}") 120 | 121 | # save log images 122 | if i % 200 == 0: 123 | gt_images = data['images_output'].detach().cpu().numpy() # [B, V, 3, output_size, output_size] 124 | gt_images = gt_images.transpose(0, 3, 1, 4, 2).reshape(-1, gt_images.shape[1] * gt_images.shape[3], 3) # [B*output_size, V*output_size, 3] 125 | kiui.write_image(f'{opt.workspace}/train_gt_images_{epoch}_{i}.jpg', gt_images) 126 | 127 | pred_images = out['images_pred'].detach().cpu().numpy() # [B, V, 3, output_size, output_size] 128 | pred_images = pred_images.transpose(0, 3, 1, 4, 2).reshape(-1, pred_images.shape[1] * pred_images.shape[3], 3) 129 | kiui.write_image(f'{opt.workspace}/train_pred_images_{epoch}_{i}.jpg', pred_images) 130 | 131 | 132 | total_loss = accelerator.gather_for_metrics(total_loss).mean() 133 | total_psnr = accelerator.gather_for_metrics(total_psnr).mean() 134 | if accelerator.is_main_process: 135 | total_loss /= len(train_dataloader) 136 | total_psnr /= len(train_dataloader) 137 | accelerator.print(f"[train] epoch: {epoch} loss: {total_loss.item():.6f} psnr: 
{total_psnr.item():.4f}") 138 | 139 | writer.add_scalar('Loss/train_epoch', total_loss.item(), epoch) 140 | writer.add_scalar('PSNR/train_epoch', total_psnr.item(), epoch) 141 | 142 | accelerator.wait_for_everyone() 143 | accelerator.save_model(model, opt.workspace) 144 | 145 | # eval 146 | with torch.no_grad(): 147 | model.eval() 148 | total_psnr = 0 149 | total_lpips = 0 150 | for i, data in enumerate(test_dataloader): 151 | 152 | out, drag_start_2d, drag_move_2d = model(data) 153 | 154 | psnr = out['psnr'] 155 | lpips = out['loss_lpips'] 156 | total_psnr += psnr.detach() 157 | total_lpips += lpips.detach() 158 | 159 | # save some images 160 | if accelerator.is_main_process and i % 20 == 0: 161 | gt_images = data['images_output'].detach().cpu().numpy() # [B, V, 3, output_size, output_size] 162 | gt_images = gt_images.transpose(0, 3, 1, 4, 2).reshape(-1, gt_images.shape[1] * gt_images.shape[3], 3) # [B*output_size, V*output_size, 3] 163 | kiui.write_image(f'{opt.workspace}/eval_gt_images_{epoch}_{i}.jpg', gt_images) 164 | 165 | pred_images = out['images_pred'].detach().cpu().numpy() # [B, V, 3, output_size, output_size] 166 | pred_images = pred_images.transpose(0, 3, 1, 4, 2).reshape(-1, pred_images.shape[1] * pred_images.shape[3], 3) 167 | kiui.write_image(f'{opt.workspace}/eval_pred_images_{epoch}_{i}.jpg', pred_images) 168 | 169 | 170 | torch.cuda.empty_cache() 171 | 172 | total_psnr = accelerator.gather_for_metrics(total_psnr).mean() 173 | total_lpips = accelerator.gather_for_metrics(total_lpips).mean() 174 | if accelerator.is_main_process: 175 | total_psnr /= len(test_dataloader) 176 | total_lpips /= len(test_dataloader) 177 | accelerator.print(f"[eval] epoch: {epoch} psnr: {total_psnr:.4f} lpips: {total_lpips:.4f}") 178 | 179 | writer.add_scalar('PSNR/eval_epoch', total_psnr.item(), epoch) 180 | 181 | 182 | if __name__ == "__main__": 183 | dataset = 'objaverse_hq' 184 | main(dataset) 185 | -------------------------------------------------------------------------------- /core/gs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from diff_gaussian_rasterization import ( 8 | GaussianRasterizationSettings, 9 | GaussianRasterizer, 10 | ) 11 | 12 | from core.options import Options 13 | 14 | import kiui 15 | 16 | class GaussianRenderer: 17 | def __init__(self, opt: Options): 18 | 19 | self.opt = opt 20 | self.bg_color = torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda") 21 | 22 | # intrinsics 23 | self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy)) 24 | self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32) 25 | self.proj_matrix[0, 0] = 1 / self.tan_half_fov 26 | self.proj_matrix[1, 1] = 1 / self.tan_half_fov 27 | self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear) 28 | self.proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear) 29 | self.proj_matrix[2, 3] = 1 30 | 31 | def render(self, gaussians, cam_view, cam_view_proj, cam_pos, bg_color=None, scale_modifier=1): 32 | # gaussians: [B, N, 14] 33 | # cam_view, cam_view_proj: [B, V, 4, 4] 34 | # cam_pos: [B, V, 3] 35 | 36 | device = gaussians.device 37 | B, V = cam_view.shape[:2] 38 | 39 | # loop of loop... 
40 | images = [] 41 | alphas = [] 42 | for b in range(B): 43 | 44 | # pos, opacity, scale, rotation, shs 45 | means3D = gaussians[b, :, 0:3].contiguous().float() 46 | opacity = gaussians[b, :, 3:4].contiguous().float() 47 | scales = gaussians[b, :, 4:7].contiguous().float() 48 | rotations = gaussians[b, :, 7:11].contiguous().float() 49 | rgbs = gaussians[b, :, 11:].contiguous().float() # [N, 3] 50 | 51 | for v in range(V): 52 | 53 | # render novel views 54 | view_matrix = cam_view[b, v].float() 55 | view_proj_matrix = cam_view_proj[b, v].float() 56 | campos = cam_pos[b, v].float() 57 | 58 | raster_settings = GaussianRasterizationSettings( 59 | image_height=self.opt.output_size, 60 | image_width=self.opt.output_size, 61 | tanfovx=self.tan_half_fov, 62 | tanfovy=self.tan_half_fov, 63 | bg=self.bg_color if bg_color is None else bg_color, 64 | scale_modifier=scale_modifier, 65 | viewmatrix=view_matrix, 66 | projmatrix=view_proj_matrix, 67 | sh_degree=0, 68 | campos=campos, 69 | prefiltered=False, 70 | debug=False, 71 | ) 72 | 73 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 74 | 75 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 76 | rendered_image, radii, rendered_depth, rendered_alpha = rasterizer( 77 | means3D=means3D, 78 | means2D=torch.zeros_like(means3D, dtype=torch.float32, device=device), 79 | shs=None, 80 | colors_precomp=rgbs, 81 | opacities=opacity, 82 | scales=scales, 83 | rotations=rotations, 84 | cov3D_precomp=None, 85 | ) 86 | 87 | rendered_image = rendered_image.clamp(0, 1) 88 | 89 | images.append(rendered_image) 90 | alphas.append(rendered_alpha) 91 | 92 | images = torch.stack(images, dim=0).view(B, V, 3, self.opt.output_size, self.opt.output_size) 93 | alphas = torch.stack(alphas, dim=0).view(B, V, 1, self.opt.output_size, self.opt.output_size) 94 | 95 | return { 96 | "image": images, # [B, V, 3, H, W] 97 | "alpha": alphas, # [B, V, 1, H, W] 98 | } 99 | 100 | 101 | def save_ply(self, gaussians, path, compatible=True): 102 | # gaussians: [B, N, 14] 103 | # compatible: save pre-activated gaussians as in the original paper 104 | 105 | assert gaussians.shape[0] == 1, 'only support batch size 1' 106 | 107 | from plyfile import PlyData, PlyElement 108 | 109 | means3D = gaussians[0, :, 0:3].contiguous().float() 110 | opacity = gaussians[0, :, 3:4].contiguous().float() 111 | scales = gaussians[0, :, 4:7].contiguous().float() 112 | rotations = gaussians[0, :, 7:11].contiguous().float() 113 | shs = gaussians[0, :, 11:].unsqueeze(1).contiguous().float() # [N, 1, 3] 114 | 115 | # prune by opacity 116 | mask = opacity.squeeze(-1) >= 0.005 117 | means3D = means3D[mask] 118 | opacity = opacity[mask] 119 | scales = scales[mask] 120 | rotations = rotations[mask] 121 | shs = shs[mask] 122 | 123 | # invert activation to make it compatible with the original ply format 124 | if compatible: 125 | opacity = kiui.op.inverse_sigmoid(opacity) 126 | scales = torch.log(scales + 1e-8) 127 | shs = (shs - 0.5) / 0.28209479177387814 128 | 129 | xyzs = means3D.detach().cpu().numpy() 130 | f_dc = shs.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 131 | opacities = opacity.detach().cpu().numpy() 132 | scales = scales.detach().cpu().numpy() 133 | rotations = rotations.detach().cpu().numpy() 134 | 135 | l = ['x', 'y', 'z'] 136 | # All channels except the 3 DC 137 | for i in range(f_dc.shape[1]): 138 | l.append('f_dc_{}'.format(i)) 139 | l.append('opacity') 140 | for i in range(scales.shape[1]): 141 | l.append('scale_{}'.format(i)) 142 | for 
i in range(rotations.shape[1]): 143 | l.append('rot_{}'.format(i)) 144 | 145 | dtype_full = [(attribute, 'f4') for attribute in l] 146 | 147 | elements = np.empty(xyzs.shape[0], dtype=dtype_full) 148 | attributes = np.concatenate((xyzs, f_dc, opacities, scales, rotations), axis=1) 149 | elements[:] = list(map(tuple, attributes)) 150 | el = PlyElement.describe(elements, 'vertex') 151 | 152 | PlyData([el]).write(path) 153 | 154 | def load_ply(self, path, compatible=True): 155 | 156 | from plyfile import PlyData, PlyElement 157 | 158 | plydata = PlyData.read(path) 159 | 160 | xyz = np.stack((np.asarray(plydata.elements[0]["x"]), 161 | np.asarray(plydata.elements[0]["y"]), 162 | np.asarray(plydata.elements[0]["z"])), axis=1) 163 | # print("Number of points at loading : ", xyz.shape[0]) 164 | 165 | opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] 166 | 167 | shs = np.zeros((xyz.shape[0], 3)) 168 | shs[:, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) 169 | shs[:, 1] = np.asarray(plydata.elements[0]["f_dc_1"]) 170 | shs[:, 2] = np.asarray(plydata.elements[0]["f_dc_2"]) 171 | 172 | scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] 173 | scales = np.zeros((xyz.shape[0], len(scale_names))) 174 | for idx, attr_name in enumerate(scale_names): 175 | scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) 176 | 177 | rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot_")] 178 | rots = np.zeros((xyz.shape[0], len(rot_names))) 179 | for idx, attr_name in enumerate(rot_names): 180 | rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) 181 | 182 | gaussians = np.concatenate([xyz, opacities, scales, rots, shs], axis=1) 183 | gaussians = torch.from_numpy(gaussians).float() # cpu 184 | 185 | if compatible: 186 | gaussians[..., 3:4] = torch.sigmoid(gaussians[..., 3:4]) 187 | gaussians[..., 4:7] = torch.exp(gaussians[..., 4:7]) 188 | gaussians[..., 11:] = 0.28209479177387814 * gaussians[..., 11:] + 0.5 189 | 190 | return gaussians -------------------------------------------------------------------------------- /preprocess/src/models/renderer/synthesizer.py: -------------------------------------------------------------------------------- 1 | # ORIGINAL LICENSE 2 | # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: LicenseRef-NvidiaProprietary 4 | # 5 | # Modified by Jiale Xu 6 | # The modifications are subject to the same license as the original. 7 | 8 | 9 | import itertools 10 | import torch 11 | import torch.nn as nn 12 | 13 | from .utils.renderer import ImportanceRenderer 14 | from .utils.ray_sampler import RaySampler 15 | 16 | 17 | class OSGDecoder(nn.Module): 18 | """ 19 | Triplane decoder that gives RGB and sigma values from sampled features. 20 | Using ReLU here instead of Softplus in the original implementation. 
21 | 22 | Reference: 23 | EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112 24 | """ 25 | def __init__(self, n_features: int, 26 | hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU): 27 | super().__init__() 28 | self.net = nn.Sequential( 29 | nn.Linear(3 * n_features, hidden_dim), 30 | activation(), 31 | *itertools.chain(*[[ 32 | nn.Linear(hidden_dim, hidden_dim), 33 | activation(), 34 | ] for _ in range(num_layers - 2)]), 35 | nn.Linear(hidden_dim, 1 + 3), 36 | ) 37 | # init all bias to zero 38 | for m in self.modules(): 39 | if isinstance(m, nn.Linear): 40 | nn.init.zeros_(m.bias) 41 | 42 | def forward(self, sampled_features, ray_directions): 43 | # Aggregate features by mean 44 | # sampled_features = sampled_features.mean(1) 45 | # Aggregate features by concatenation 46 | _N, n_planes, _M, _C = sampled_features.shape 47 | sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) 48 | x = sampled_features 49 | 50 | N, M, C = x.shape 51 | x = x.contiguous().view(N*M, C) 52 | 53 | x = self.net(x) 54 | x = x.view(N, M, -1) 55 | rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF 56 | sigma = x[..., 0:1] 57 | 58 | return {'rgb': rgb, 'sigma': sigma} 59 | 60 | 61 | class TriplaneSynthesizer(nn.Module): 62 | """ 63 | Synthesizer that renders a triplane volume with planes and a camera. 64 | 65 | Reference: 66 | EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19 67 | """ 68 | 69 | DEFAULT_RENDERING_KWARGS = { 70 | 'ray_start': 'auto', 71 | 'ray_end': 'auto', 72 | 'box_warp': 2., 73 | 'white_back': True, 74 | 'disparity_space_sampling': False, 75 | 'clamp_mode': 'softplus', 76 | 'sampler_bbox_min': -1., 77 | 'sampler_bbox_max': 1., 78 | } 79 | 80 | def __init__(self, triplane_dim: int, samples_per_ray: int): 81 | super().__init__() 82 | 83 | # attributes 84 | self.triplane_dim = triplane_dim 85 | self.rendering_kwargs = { 86 | **self.DEFAULT_RENDERING_KWARGS, 87 | 'depth_resolution': samples_per_ray // 2, 88 | 'depth_resolution_importance': samples_per_ray // 2, 89 | } 90 | 91 | # renderings 92 | self.renderer = ImportanceRenderer() 93 | self.ray_sampler = RaySampler() 94 | 95 | # modules 96 | self.decoder = OSGDecoder(n_features=triplane_dim) 97 | 98 | def forward(self, planes, cameras, render_size=128, crop_params=None): 99 | # planes: (N, 3, D', H', W') 100 | # cameras: (N, M, D_cam) 101 | # render_size: int 102 | assert planes.shape[0] == cameras.shape[0], "Batch size mismatch for planes and cameras" 103 | N, M = cameras.shape[:2] 104 | 105 | cam2world_matrix = cameras[..., :16].view(N, M, 4, 4) 106 | intrinsics = cameras[..., 16:25].view(N, M, 3, 3) 107 | 108 | # Create a batch of rays for volume rendering 109 | ray_origins, ray_directions = self.ray_sampler( 110 | cam2world_matrix=cam2world_matrix.reshape(-1, 4, 4), 111 | intrinsics=intrinsics.reshape(-1, 3, 3), 112 | render_size=render_size, 113 | ) 114 | assert N*M == ray_origins.shape[0], "Batch size mismatch for ray_origins" 115 | assert ray_origins.dim() == 3, "ray_origins should be 3-dimensional" 116 | 117 | # Crop rays if crop_params is available 118 | if crop_params is not None: 119 | ray_origins = ray_origins.reshape(N*M, render_size, render_size, 3) 120 | ray_directions = ray_directions.reshape(N*M, render_size, render_size, 3) 121 | i, j, h, w = crop_params 122 | ray_origins = ray_origins[:, i:i+h, j:j+w, :].reshape(N*M, -1, 3) 123 | ray_directions = ray_directions[:, i:i+h, j:j+w, 
:].reshape(N*M, -1, 3) 124 | 125 | # Perform volume rendering 126 | rgb_samples, depth_samples, weights_samples = self.renderer( 127 | planes.repeat_interleave(M, dim=0), self.decoder, ray_origins, ray_directions, self.rendering_kwargs, 128 | ) 129 | 130 | # Reshape into 'raw' neural-rendered image 131 | if crop_params is not None: 132 | Himg, Wimg = crop_params[2:] 133 | else: 134 | Himg = Wimg = render_size 135 | rgb_images = rgb_samples.permute(0, 2, 1).reshape(N, M, rgb_samples.shape[-1], Himg, Wimg).contiguous() 136 | depth_images = depth_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg) 137 | weight_images = weights_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg) 138 | 139 | out = { 140 | 'images_rgb': rgb_images, 141 | 'images_depth': depth_images, 142 | 'images_weight': weight_images, 143 | } 144 | return out 145 | 146 | def forward_grid(self, planes, grid_size: int, aabb: torch.Tensor = None): 147 | # planes: (N, 3, D', H', W') 148 | # grid_size: int 149 | # aabb: (N, 2, 3) 150 | if aabb is None: 151 | aabb = torch.tensor([ 152 | [self.rendering_kwargs['sampler_bbox_min']] * 3, 153 | [self.rendering_kwargs['sampler_bbox_max']] * 3, 154 | ], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat(planes.shape[0], 1, 1) 155 | assert planes.shape[0] == aabb.shape[0], "Batch size mismatch for planes and aabb" 156 | N = planes.shape[0] 157 | 158 | # create grid points for triplane query 159 | grid_points = [] 160 | for i in range(N): 161 | grid_points.append(torch.stack(torch.meshgrid( 162 | torch.linspace(aabb[i, 0, 0], aabb[i, 1, 0], grid_size, device=planes.device), 163 | torch.linspace(aabb[i, 0, 1], aabb[i, 1, 1], grid_size, device=planes.device), 164 | torch.linspace(aabb[i, 0, 2], aabb[i, 1, 2], grid_size, device=planes.device), 165 | indexing='ij', 166 | ), dim=-1).reshape(-1, 3)) 167 | cube_grid = torch.stack(grid_points, dim=0).to(planes.device) 168 | 169 | features = self.forward_points(planes, cube_grid) 170 | 171 | # reshape into grid 172 | features = { 173 | k: v.reshape(N, grid_size, grid_size, grid_size, -1) 174 | for k, v in features.items() 175 | } 176 | return features 177 | 178 | def forward_points(self, planes, points: torch.Tensor, chunk_size: int = 2**20): 179 | # planes: (N, 3, D', H', W') 180 | # points: (N, P, 3) 181 | N, P = points.shape[:2] 182 | 183 | # query triplane in chunks 184 | outs = [] 185 | for i in range(0, points.shape[1], chunk_size): 186 | chunk_points = points[:, i:i+chunk_size] 187 | 188 | # query triplane 189 | chunk_out = self.renderer.run_model_activated( 190 | planes=planes, 191 | decoder=self.decoder, 192 | sample_coordinates=chunk_points, 193 | sample_directions=torch.zeros_like(chunk_points), 194 | options=self.rendering_kwargs, 195 | ) 196 | outs.append(chunk_out) 197 | 198 | # concatenate the outputs 199 | point_features = { 200 | k: torch.cat([out[k] for out in outs], dim=1) 201 | for k in outs[0].keys() 202 | } 203 | return point_features 204 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PartRM: Modeling Part-Level Dynamics with Large Cross-State Reconstruction Model [CVPR 2025] 2 | 3 | This repository is an official implementation for: 4 | 5 | **PartRM: Modeling Part-Level Dynamics with Large Cross-State Reconstruction Model [CVPR 2025]** 6 | > Authors: Mingju Gao*, Yike Pan*, [Huan-ang Gao*](https://c7w.tech/about/), Zongzheng Zhang, Wenyi Li, [Hao 
Dong](https://zsdonghao.github.io/), [Hao Tang](https://ha0tang.github.io/), [Li Yi](https://ericyi.github.io/), [Hao Zhao](https://sites.google.com/view/fromandto) 7 | 8 | ![Teaser](./images/teaser.png) 9 | 10 | ## Introduction 11 | As interest grows in world models that predict future states from current observations and actions, accurately modeling part-level dynamics has become increasingly relevant for various applications. Existing approaches, such as Puppet-Master, rely on fine-tuning large-scale pre-trained video diffusion models, which are impractical for real-world use due to the limitations of 2D video representation and slow processing times. To overcome these challenges, we present PartRM, a novel 4D reconstruction framework that simultaneously models appearance, geometry, and part-level motion from multi-view images of a static object. PartRM builds upon large 3D Gaussian reconstruction models, leveraging their extensive knowledge of appearance and geometry in static objects. To address data scarcity in 4D, we introduce the PartDrag-4D dataset, providing multi-view observations of part-level dynamics across over 20,000 states. We enhance the model’s understanding of interaction conditions with a multi-scale drag embedding module that captures dynamics at varying granularities. To prevent catastrophic forgetting during fine-tuning, we implement a two-stage training process that focuses sequentially on motion and appearance learning. Experimental results show that PartRM establishes a new state-of-the-art in part-level motion learning and can be applied in manipulation tasks in robotics. 12 | Project page: https://partrm.c7w.tech/ 13 | 14 | ## Environment Setup 15 | Use `conda` to create a new virtual environment. We use `torch==2.1.0+cu121`. 16 | ```bash 17 | conda env create -f environment.yaml 18 | conda activate partrm 19 | ``` 20 | Also install the Gaussian splatting renderer: 21 | ```bash 22 | # a modified gaussian splatting (+ depth, alpha rendering) 23 | git clone --recursive https://github.com/ashawkey/diff-gaussian-rasterization 24 | pip install ./diff-gaussian-rasterization 25 | ``` 26 | 27 | ## PartDrag-4D Dataset 28 | You can download the PartDrag-4D dataset from [here](https://huggingface.co/GasaiYU/PartRM/tree/main). Unzip `pardrag_4d/partdrag_rendered.zip` to `PartDrag4D/data/render_PartDrag4D` and unzip `processed_data_partdrag4d.zip` to `./PartDrag4D/data/processed_data_partdrag4d`. 29 | 30 | Below is how to render the PartDrag-4D dataset from scratch. 31 | 32 | You first need to get the [PartNet-Mobility](https://sapien.ucsd.edu/browse) dataset and put it in the `PartDrag4D/data` directory of this repo. 33 | Then: 34 | ```bash 35 | cd PartDrag4D 36 | ``` 37 | For mesh preprocessing and animating: 38 | ```bash 39 | cd preprocess 40 | python process_data_textured_uv.py 41 | python animated_data.py 42 | ``` 43 | For rendering, first download [blender](https://download.blender.org/release/Blender3.5/blender-3.5.0-linux-x64.tar.xz) and unzip it in the `../rendering/blender` directory: 44 | ```bash 45 | cd ../rendering/blender 46 | wget https://download.blender.org/release/Blender3.5/blender-3.5.0-linux-x64.tar.xz 47 | tar -xf blender-3.5.0-linux-x64.tar.xz 48 | ``` 49 | 50 | Then generate the rendering filelist and render the generated meshes using [blender](https://download.blender.org/release/Blender3.5/blender-3.5.0-linux-x64.tar.xz): 51 | ```bash 52 | cd ..
53 | python gen_filelist.py 54 | bash render.sh 55 | ``` 56 | You can modify `num_gpus` and `CUDA_VISIBLE_DEVICES` in the bash script to adjust the degree of parallelism. 57 | 58 | For surface drags extraction: 59 | ```bash 60 | cd .. 61 | python z_buffer_al.py 62 | ``` 63 | 64 | The animated meshes and extracted surface drags are stored in `./PartDrag4D/data/processed_data_partdrag4d`. The rendering results are stored in `./PartDrag4D/data/render_PartDrag4D`. 65 | 66 | We split the PartDrag-4D dataset into training and evaluation sets. You can refer to `./filelist/train_filelist_partdrag4d.txt` and `./filelist/val_filelist_partdrag4d.txt` for details. 67 | 68 | ## Images and Drags Preprocessing 69 | You can get the Zero123++ and SAM checkpoints from [here](https://huggingface.co/GasaiYU/PartRM/tree/main/pretrained). Then put them into `preprocess/zero123_ckpt` and `preprocess/sam_ckpt` respectively. 70 | 71 | To generate multi-view images for **evaluation data**: 72 | ```bash 73 | cd ../preprocess 74 | python gen_mv_partdrag4d.py --src_filelist /path/to/src/rendering/filelist --output_dir /path/to/save/dir # For PartDrag-4D 75 | python gen_mv_objaverse_hq.py --src_filelist /path/to/src/rendering/filelist --output_dir /path/to/save/dir # For Objaverse-Animation-HQ 76 | ``` 77 | The `src_filelist` is the path to the rendering filelist. For example, you can refer to [this filelist](filelist/zero123_val_filelist_partdrag4d.txt) for PartDrag4D and [this filelist](filelist/zero123_val_filelist_objavser_hq.txt) for Objaverse-Animation-HQ. 78 | 79 | To generate RGBA format images for the input of **PartRM**: 80 | ```bash 81 | python gen_rgba.py --filelist /path/to/zero123/filelist --dataset [dataset_name] 82 | ``` 83 | For example, you can refer to [this filelist](filelist/zero123_val_filelist_partdrag4d.txt) for PartDrag4D and [this filelist](filelist/zero123_val_filelist_objavser_hq.txt) for Objaverse-Animation-HQ. 84 | 85 | To generate propagated drags for the **PartDrag-4D** dataset (you can download our preprocessed propagated drags from [here](https://huggingface.co/GasaiYU/PartRM/tree/main/propagated_drags)): 86 | ```bash 87 | python gen_propagate_drags.py --val_filelist /path/to/src/rendering/filelist --sample_num [The number of propagated drags] --save_dir /path/to/save/drags 88 | ``` 89 | The `val_filelist` is the same as the `src_filelist` used for multi-view image generation for PartDrag-4D above. 90 | 91 | ## Training 92 | 93 | 94 | 95 | 96 | We provide training scripts for the `PartDrag-4D` and `Objaverse-Animation-HQ` datasets. You can adjust the dataset used for training in `train.py` and `eval.py` (partdrag4d or objaverse_hq). Then run: 97 | ```bash 98 | CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --config_file acc_configs/gpu4.yaml train.py big --workspace [your workspace] 99 | ``` 100 | 101 | You should specify the `train_filelist`, `val_filelist`, `zero123_val_filelist`, `propagated_drags_base` and `mesh_base` in `core/options.py` and `core/options_pm.py`. 102 | 103 | - For `train_filelist`, you can refer to `filelist/train_filelist.txt` and `filelist/train_objavser_hq.txt`. 104 | 105 | - For `val_filelist`, you can refer to `filelist/val_filelist.txt` and `filelist/eval_objaverse_hq.txt`. 106 | 107 | - For `zero123_val_filelist`, you can refer to `filelist/zero123_val_filelist.txt` and `filelist/zero123_val_filelist_objavser_hq.txt`. 108 | 109 | For the two-stage training proposed in the paper, you should first set `stage1` in `core/options.py` and `core/options_pm.py` to **True**.
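As a rough illustration of what this flag toggles, here is a minimal sketch. The field names `stage1`, `lambda_flow` and `lambda_lpips` mirror how `core/models.py` reads the options; the dataclass and default values below are assumptions for illustration, not the actual contents of `core/options.py`.

```python
from dataclasses import dataclass


@dataclass
class StageOptions:
    # Hypothetical mirror of the relevant fields in core/options.py / core/options_pm.py.
    stage1: bool = True        # True -> stage 1 (motion learning): the Gaussian flow loss is enabled
    lambda_flow: float = 1.0   # weight of the Gaussian flow loss (only used when stage1 is True)
    lambda_lpips: float = 1.0  # weight of the LPIPS appearance loss


# Stage 1: motion learning.
opt_stage1 = StageOptions(stage1=True)

# Stage 2: appearance learning, run after stage 1 by flipping the flag.
opt_stage2 = StageOptions(stage1=False)

print(opt_stage1)
print(opt_stage2)
```

In the repository itself, this switch gates the Gaussian flow loss branch (weighted by `lambda_flow`) in `core/models.py` during stage 1.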
After the motion-learning training, set `stage1` to **False** to conduct the appearance-learning training. 110 | 111 | ## Evaluation 112 | 113 | 114 | For evaluation, you should first run: 115 | ```bash 116 | CUDA_VISIBLE_DEVICES=0 accelerate launch --config_file acc_configs/gpu1.yaml eval.py big --workspace [your workspace] 117 | ``` 118 | Note that you should set `stage1` in `core/options.py` and `core/options_pm.py` to **False**. 119 | 120 | Then you should generate your eval filelist, with every line formatted as: 121 | ``` 122 | gt_image_path,pred_image_path,source_image_path 123 | ``` 124 | Then specify `VAL_FILELIST` (the path of the generated eval filelist) in `compute_metrics.py` and run: 125 | ``` 126 | python compute_metrics.py 127 | ``` 128 | 129 | This gives you the PSNR, LPIPS and SSIM metrics. 130 | 131 | ## Acknowledgement 132 | We build our work on [LGM](https://arxiv.org/pdf/2402.05054), [Zero123++](https://arxiv.org/pdf/2310.15110) and [3D Gaussian Splatting](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/3d_gaussian_splatting_high.pdf). -------------------------------------------------------------------------------- /preprocess/src/models/lrm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Zexin He 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import torch 17 | import torch.nn as nn 18 | import mcubes 19 | import nvdiffrast.torch as dr 20 | from einops import rearrange, repeat 21 | 22 | from .encoder.dino_wrapper import DinoWrapper 23 | from .decoder.transformer import TriplaneTransformer 24 | from .renderer.synthesizer import TriplaneSynthesizer 25 | from ..utils.mesh_util import xatlas_uvmap 26 | 27 | 28 | class InstantNeRF(nn.Module): 29 | """ 30 | Full model of the large reconstruction model.
31 | """ 32 | def __init__( 33 | self, 34 | encoder_freeze: bool = False, 35 | encoder_model_name: str = 'facebook/dino-vitb16', 36 | encoder_feat_dim: int = 768, 37 | transformer_dim: int = 1024, 38 | transformer_layers: int = 16, 39 | transformer_heads: int = 16, 40 | triplane_low_res: int = 32, 41 | triplane_high_res: int = 64, 42 | triplane_dim: int = 80, 43 | rendering_samples_per_ray: int = 128, 44 | ): 45 | super().__init__() 46 | 47 | # modules 48 | self.encoder = DinoWrapper( 49 | model_name=encoder_model_name, 50 | freeze=encoder_freeze, 51 | ) 52 | 53 | self.transformer = TriplaneTransformer( 54 | inner_dim=transformer_dim, 55 | num_layers=transformer_layers, 56 | num_heads=transformer_heads, 57 | image_feat_dim=encoder_feat_dim, 58 | triplane_low_res=triplane_low_res, 59 | triplane_high_res=triplane_high_res, 60 | triplane_dim=triplane_dim, 61 | ) 62 | 63 | self.synthesizer = TriplaneSynthesizer( 64 | triplane_dim=triplane_dim, 65 | samples_per_ray=rendering_samples_per_ray, 66 | ) 67 | 68 | def forward_planes(self, images, cameras): 69 | # images: [B, V, C_img, H_img, W_img] 70 | # cameras: [B, V, 16] 71 | B = images.shape[0] 72 | 73 | # encode images 74 | image_feats = self.encoder(images, cameras) 75 | image_feats = rearrange(image_feats, '(b v) l d -> b (v l) d', b=B) 76 | 77 | # transformer generating planes 78 | planes = self.transformer(image_feats) 79 | 80 | return planes 81 | 82 | def forward_synthesizer(self, planes, render_cameras, render_size: int): 83 | render_results = self.synthesizer( 84 | planes, 85 | render_cameras, 86 | render_size, 87 | ) 88 | return render_results 89 | 90 | def forward(self, images, cameras, render_cameras, render_size: int): 91 | # images: [B, V, C_img, H_img, W_img] 92 | # cameras: [B, V, 16] 93 | # render_cameras: [B, M, D_cam_render] 94 | # render_size: int 95 | B, M = render_cameras.shape[:2] 96 | 97 | planes = self.forward_planes(images, cameras) 98 | 99 | # render target views 100 | render_results = self.synthesizer(planes, render_cameras, render_size) 101 | 102 | return { 103 | 'planes': planes, 104 | **render_results, 105 | } 106 | 107 | def get_texture_prediction(self, planes, tex_pos, hard_mask=None): 108 | ''' 109 | Predict Texture given triplanes 110 | :param planes: the triplane feature map 111 | :param tex_pos: Position we want to query the texture field 112 | :param hard_mask: 2D silhoueete of the rendered image 113 | ''' 114 | tex_pos = torch.cat(tex_pos, dim=0) 115 | if not hard_mask is None: 116 | tex_pos = tex_pos * hard_mask.float() 117 | batch_size = tex_pos.shape[0] 118 | tex_pos = tex_pos.reshape(batch_size, -1, 3) 119 | ################### 120 | # We use mask to get the texture location (to save the memory) 121 | if hard_mask is not None: 122 | n_point_list = torch.sum(hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1) 123 | sample_tex_pose_list = [] 124 | max_point = n_point_list.max() 125 | expanded_hard_mask = hard_mask.reshape(batch_size, -1, 1).expand(-1, -1, 3) > 0.5 126 | for i in range(tex_pos.shape[0]): 127 | tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3) 128 | if tex_pos_one_shape.shape[1] < max_point: 129 | tex_pos_one_shape = torch.cat( 130 | [tex_pos_one_shape, torch.zeros( 131 | 1, max_point - tex_pos_one_shape.shape[1], 3, 132 | device=tex_pos_one_shape.device, dtype=torch.float32)], dim=1) 133 | sample_tex_pose_list.append(tex_pos_one_shape) 134 | tex_pos = torch.cat(sample_tex_pose_list, dim=0) 135 | 136 | tex_feat = torch.utils.checkpoint.checkpoint( 137 | 
self.synthesizer.forward_points, 138 | planes, 139 | tex_pos, 140 | use_reentrant=False, 141 | )['rgb'] 142 | 143 | if hard_mask is not None: 144 | final_tex_feat = torch.zeros( 145 | planes.shape[0], hard_mask.shape[1] * hard_mask.shape[2], tex_feat.shape[-1], device=tex_feat.device) 146 | expanded_hard_mask = hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(-1, -1, final_tex_feat.shape[-1]) > 0.5 147 | for i in range(planes.shape[0]): 148 | final_tex_feat[i][expanded_hard_mask[i]] = tex_feat[i][:n_point_list[i]].reshape(-1) 149 | tex_feat = final_tex_feat 150 | 151 | return tex_feat.reshape(planes.shape[0], hard_mask.shape[1], hard_mask.shape[2], tex_feat.shape[-1]) 152 | 153 | def extract_mesh( 154 | self, 155 | planes: torch.Tensor, 156 | mesh_resolution: int = 256, 157 | mesh_threshold: int = 10.0, 158 | use_texture_map: bool = False, 159 | texture_resolution: int = 1024, 160 | **kwargs, 161 | ): 162 | ''' 163 | Extract a 3D mesh from triplane nerf. Only support batch_size 1. 164 | :param planes: triplane features 165 | :param mesh_resolution: marching cubes resolution 166 | :param mesh_threshold: iso-surface threshold 167 | :param use_texture_map: use texture map or vertex color 168 | :param texture_resolution: the resolution of texture map 169 | ''' 170 | assert planes.shape[0] == 1 171 | device = planes.device 172 | 173 | grid_out = self.synthesizer.forward_grid( 174 | planes=planes, 175 | grid_size=mesh_resolution, 176 | ) 177 | 178 | vertices, faces = mcubes.marching_cubes( 179 | grid_out['sigma'].squeeze(0).squeeze(-1).cpu().numpy(), 180 | mesh_threshold, 181 | ) 182 | vertices = vertices / (mesh_resolution - 1) * 2 - 1 183 | 184 | if not use_texture_map: 185 | # query vertex colors 186 | vertices_tensor = torch.tensor(vertices, dtype=torch.float32, device=device).unsqueeze(0) 187 | vertices_colors = self.synthesizer.forward_points( 188 | planes, vertices_tensor)['rgb'].squeeze(0).cpu().numpy() 189 | vertices_colors = (vertices_colors * 255).astype(np.uint8) 190 | 191 | return vertices, faces, vertices_colors 192 | 193 | # use x-atlas to get uv mapping for the mesh 194 | vertices = torch.tensor(vertices, dtype=torch.float32, device=device) 195 | faces = torch.tensor(faces.astype(int), dtype=torch.long, device=device) 196 | 197 | ctx = dr.RasterizeCudaContext(device=device) 198 | uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap( 199 | ctx, vertices, faces, resolution=texture_resolution) 200 | tex_hard_mask = tex_hard_mask.float() 201 | 202 | # query the texture field to get the RGB color for texture map 203 | tex_feat = self.get_texture_prediction( 204 | planes, [gb_pos], tex_hard_mask) 205 | background_feature = torch.zeros_like(tex_feat) 206 | img_feat = torch.lerp(background_feature, tex_feat, tex_hard_mask) 207 | texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0) 208 | 209 | return vertices, faces, uvs, mesh_tex_idx, texture_map -------------------------------------------------------------------------------- /preprocess/gen_propagate_drags.py: -------------------------------------------------------------------------------- 1 | # This is only for PartDrag4D dataset 2 | 3 | import torch 4 | import numpy as np 5 | 6 | from torch.utils.data import Dataset 7 | import torch.nn.functional as F 8 | 9 | import os 10 | import cv2 11 | 12 | import json 13 | import random 14 | import open3d as o3d 15 | import copy 16 | 17 | from tqdm import tqdm 18 | 19 | from PIL import Image, ImageDraw 20 | 21 | from segment_anything import SamPredictor, sam_model_registry 22 | import 
argparse 23 | 24 | def gen_rand_num(num_range, cur_num): 25 | while True: 26 | a = random.randint(0, num_range - 1) 27 | if a != cur_num: 28 | return a 29 | 30 | def sam_predict(predictor: SamPredictor, image, query_point): 31 | predictor.set_image(image=image) 32 | labels = np.array([1]) 33 | 34 | masks, scores, logits = predictor.predict( 35 | point_coords=query_point, 36 | point_labels=labels, 37 | multimask_output=True 38 | ) 39 | 40 | min_sum = 10000000 41 | for mask in masks: 42 | if mask.sum() < min_sum: 43 | selected_mask = mask 44 | min_sum = mask.sum() 45 | 46 | return selected_mask 47 | 48 | class DragPropagateDataset(Dataset): 49 | def __init__(self, filelist, pcd_base, render_base): 50 | self.pcd_base = pcd_base 51 | self.render_base = render_base 52 | 53 | self.images, self.cameras = [], [] 54 | 55 | with open(filelist, 'r') as f: 56 | for line in f.readlines(): 57 | if line.startswith('#'): 58 | continue 59 | 60 | render_dir = line.strip() 61 | 62 | self.images.append(os.path.join(render_dir, f'000.png')) 63 | self.cameras.append(os.path.join(render_dir, f'000_camera.json')) 64 | 65 | self.class_names = [name for name in os.listdir(pcd_base) if os.path.isdir(os.path.join(pcd_base, name))] 66 | self.cam_radius = 2.4 67 | self.h = self.w = 512 68 | self.fov = 49.1 69 | 70 | def __len__(self): 71 | return len(self.images) 72 | 73 | 74 | def ndc2Pix(self, v, S): 75 | return ((v + 1.0) * S - 1.0) * 0.5 76 | 77 | 78 | def local2world(self, drags, scale, world_matrix): 79 | drags = drags.clone().detach() 80 | scale = torch.tensor(scale, dtype=torch.float32).clone().detach() 81 | translation = torch.tensor(world_matrix, dtype=torch.float32) 82 | 83 | scaled_points = drags * scale 84 | world_points = scaled_points + translation 85 | 86 | return world_points 87 | 88 | def project_drag(self, drags, cam_view): 89 | # local2world transformation using blender scale and world matrix 90 | scale = 0.42995866893870127 91 | world_matrix = (-0.0121, -0.0070, 0.0120) 92 | drags = self.local2world(drags, scale, world_matrix) 93 | 94 | # use blender fixed projection matrix 95 | proj_matrix = [ 96 | [2.777777671813965, 0.0000, 0.0000, 0.0000], 97 | [0.0000, 2.777777671813965, 0.0000, 0.0000], 98 | [0.0000, 0.0000, -1.0001999139785767, -0.20002000033855438], 99 | [0.0000, 0.0000, -1.0000, 0.0000] 100 | ] 101 | cam_proj = torch.tensor(proj_matrix, dtype=torch.float32) 102 | 103 | ndc_coords_2d_list = [] 104 | 105 | for i in range(cam_view.size(0)): 106 | for j in range(drags.size(0)): 107 | view_matrix = cam_view[i] 108 | 109 | view_matrix = torch.inverse(view_matrix) 110 | proj_matrix = cam_proj 111 | 112 | point_3D = drags[j] 113 | point_3D_homogeneous = torch.cat([point_3D, torch.tensor([1.0])], dim=0) 114 | 115 | camera_coords = torch.matmul(view_matrix, point_3D_homogeneous) 116 | clip_coords = torch.matmul(proj_matrix, camera_coords) 117 | 118 | ndc_coords = clip_coords[:3] / clip_coords[3] 119 | ndc_coords_2d_list.append(ndc_coords[:2].tolist()) 120 | 121 | ndc_coords_2d_tensor = torch.tensor(ndc_coords_2d_list).view(cam_view.size(0), drags.size(0), 2) 122 | 123 | # drag ndc to pix 124 | S = torch.tensor([512, 512]) 125 | drags_2d = S - self.ndc2Pix(ndc_coords_2d_tensor, S) 126 | 127 | return drags_2d 128 | 129 | 130 | def __getitem__(self, idx): 131 | # ----- Load image 132 | image_path = self.images[idx] 133 | with open(image_path, 'rb') as f: 134 | image = cv2.imdecode(np.frombuffer(f.read(), np.uint8), cv2.IMREAD_UNCHANGED) 135 | image = torch.from_numpy(image).float() / 255.0 136 | 137 | 
image = image.permute(2, 0, 1) # [4, 512, 512] 138 | mask = image[3:4] # [1, 512, 512] 139 | image = image[:3] * mask + (1 - mask) # [3, 512, 512], to white bg 140 | image = image[[2,1,0]].contiguous() # bgr to rgb 141 | 142 | # ----- Load camera parameters 143 | camera_path = self.cameras[idx] 144 | with open(camera_path, 'r') as f: 145 | meta = json.load(f) 146 | 147 | camera_matrix = torch.eye(4) 148 | camera_matrix[:3, 0] = torch.tensor(meta["x"]) 149 | camera_matrix[:3, 1] = -torch.tensor(meta["y"]) 150 | camera_matrix[:3, 2] = -torch.tensor(meta["z"]) 151 | camera_matrix[:3, 3] = torch.tensor(meta["origin"]) 152 | c2w = camera_matrix 153 | c2w = c2w.clone().float().reshape(4, 4) 154 | 155 | # ----- Get 3D Drags 156 | mesh_id = image_path.split('/')[-2].split('_')[0] 157 | cur_class = None 158 | for class_name in self.class_names: 159 | item_base = os.path.join(self.pcd_base, class_name, mesh_id) 160 | if os.path.exists(item_base): 161 | cur_class = class_name 162 | break 163 | if cur_class is None: 164 | raise Exception(f"{mesh_id} drag base not found") 165 | 166 | # ----- Get 2D Drag 167 | pcd_path = os.path.join(self.pcd_base, class_name, mesh_id, 'motion') 168 | cur_parts = image_path.split('/')[-2].split('_') 169 | rand_parts = copy.deepcopy(cur_parts) 170 | 171 | cur_motion_id = int(cur_parts[2]) 172 | rand_motion_id = gen_rand_num(6, cur_motion_id) 173 | rand_parts[2] = str(rand_motion_id) 174 | cur_pcd_name = '_'.join(cur_parts[1:]) 175 | rand_pcd_name = '_'.join(rand_parts[1:]) 176 | 177 | pcd0 = np.asarray(o3d.io.read_point_cloud(os.path.join(pcd_path, f'{cur_pcd_name}.ply')).points) 178 | pcd1 = np.asarray(o3d.io.read_point_cloud(os.path.join(pcd_path, f'{rand_pcd_name}.ply')).points) 179 | pcd_rand_idx = np.where((pcd0 - pcd1).sum(1) != 0)[0] 180 | 181 | surface_2d_index = np.load(os.path.join(pcd_path, f'{cur_pcd_name}_visible.npy')) 182 | surface_2d_index = np.intersect1d(surface_2d_index, pcd_rand_idx) 183 | 184 | if surface_2d_index.shape[0] > 0: 185 | rand_2d_index = random.choices(surface_2d_index, k=1) 186 | else: 187 | rand_2d_index = random.choices(pcd_rand_idx, k=1) 188 | 189 | rand_2d_drag_3d_start = torch.from_numpy(pcd0[rand_2d_index]).float() 190 | rand_2d_drag_start = self.project_drag(rand_2d_drag_3d_start, camera_matrix.unsqueeze(0)) 191 | rand_2d_drag_start = torch.clamp(rand_2d_drag_start, 0, 511) 192 | 193 | return rand_2d_drag_start, image_path 194 | 195 | def main(val_filelist, mesh_base, render_base, save_dir, sample_num=10): 196 | drag_propagate_dataset = DragPropagateDataset(filelist=val_filelist, pcd_base=mesh_base, render_base=render_base) 197 | sam = sam_model_registry['vit_h'](checkpoint='sam_ckpt/sam_vit_h_4b8939.pth') 198 | sam.to('cuda') 199 | predictor = SamPredictor(sam) 200 | 201 | for i in tqdm(range(len(drag_propagate_dataset))): 202 | data, image_path = drag_propagate_dataset[i] 203 | query_point = np.array([[int(data[0, 0, 0].item()), int(data[0, 0, 1].item())]]) 204 | sam_image= cv2.imread(image_path) 205 | sam_image = cv2.cvtColor(sam_image, cv2.COLOR_BGR2RGB) 206 | mask = sam_predict(predictor, sam_image, query_point) 207 | 208 | true_indices = np.argwhere(mask) 209 | sampled_indices = true_indices[np.random.permutation(true_indices.shape[0])[:sample_num]] 210 | render_id = image_path.split('/')[-2] 211 | os.makedirs(os.path.join(save_dir, render_id), exist_ok=True) 212 | np.save(os.path.join(save_dir, render_id, 'propagated_indices.npy'), sampled_indices) 213 | 214 | if __name__ == "__main__": 215 | parser = 
argparse.ArgumentParser() 216 | parser.add_argument('--val_filelist', default='../filelist/val_filelist_partdrag4d.txt', help='The path of the val filelist') 217 | parser.add_argument('--mesh_base', default='../PartDrag4D/data/processed_data_partdrag4d', help='The dir of the processed data which contains the mesh') 218 | parser.add_argument('--render_base', default='../PartDrag4D/data/render_PartDrag4D', help='The dir of the rendered images') 219 | parser.add_argument('--sample_num', default=10, help="The number of the sample points") 220 | parser.add_argument('--save_dir', default='./propagated_drags', help="The path to save the propagated drags.") 221 | args = parser.parse_args() 222 | 223 | main(val_filelist=args.val_filelist, mesh_base=args.mesh_base, render_base=args.render_base, 224 | save_dir=args.save_dir, sample_num=args.sample_num) -------------------------------------------------------------------------------- /filelist/val_filelist_objaverse_hq.txt: -------------------------------------------------------------------------------- 1 | /path/to/your/objavser/base/1f60d11ef7904a20aedcfebaed750b6f/Armature|Anim 2 | /path/to/your/objavser/base/1f60d11ef7904a20aedcfebaed750b6f/Anim 3 | /path/to/your/objavser/base/2cf1808d3d20440299fb427d506ac371/walk_slow_events 4 | /path/to/your/objavser/base/2cf1808d3d20440299fb427d506ac371/vuelta_cocinar 5 | /path/to/your/objavser/base/2cf1808d3d20440299fb427d506ac371/acariciar_tanque 6 | /path/to/your/objavser/base/2cf1808d3d20440299fb427d506ac371/driving 7 | /path/to/your/objavser/base/2cf1808d3d20440299fb427d506ac371/idle_fin_cocinar 8 | /path/to/your/objavser/base/2cf1808d3d20440299fb427d506ac371/idle1 9 | /path/to/your/objavser/base/2cf1808d3d20440299fb427d506ac371/idle_cocinar 10 | /path/to/your/objavser/base/2cf1808d3d20440299fb427d506ac371/sujetar_cerdo 11 | /path/to/your/objavser/base/2cf1808d3d20440299fb427d506ac371/idle3 12 | /path/to/your/objavser/base/2cf1808d3d20440299fb427d506ac371/look_surprised 13 | /path/to/your/objavser/base/2cf1808d3d20440299fb427d506ac371/start_drive 14 | /path/to/your/objavser/base/2cf1808d3d20440299fb427d506ac371/attack 15 | /path/to/your/objavser/base/2cf1808d3d20440299fb427d506ac371/walk_events 16 | /path/to/your/objavser/base/2cf1808d3d20440299fb427d506ac371/open_events 17 | /path/to/your/objavser/base/2cf1808d3d20440299fb427d506ac371/drop_paper 18 | /path/to/your/objavser/base/2cf1808d3d20440299fb427d506ac371/idle2 19 | /path/to/your/objavser/base/2cf1808d3d20440299fb427d506ac371/scare_0 20 | /path/to/your/objavser/base/2cf1808d3d20440299fb427d506ac371/walk_fast_events 21 | /path/to/your/objavser/base/2cf1808d3d20440299fb427d506ac371/olla_loop 22 | /path/to/your/objavser/base/2e523d5dde7c44f891ae6dc8fbd84fe5/Armature|Armature|mixamo.com|Layer0 23 | /path/to/your/objavser/base/2f57441ebba047a3af604aed7906a7f4/Scene 24 | /path/to/your/objavser/base/37473ebbc16e4df887d4e29bda9de201/mixamo.com 25 | /path/to/your/objavser/base/3a6a30a99ec146c9a7f8fec4970d8e1d/hit_big 26 | /path/to/your/objavser/base/3a6a30a99ec146c9a7f8fec4970d8e1d/fly_loop 27 | /path/to/your/objavser/base/3a6a30a99ec146c9a7f8fec4970d8e1d/special_b_precast_loop 28 | /path/to/your/objavser/base/3a6a30a99ec146c9a7f8fec4970d8e1d/combat_idle 29 | /path/to/your/objavser/base/3a6a30a99ec146c9a7f8fec4970d8e1d/die 30 | /path/to/your/objavser/base/3a6a30a99ec146c9a7f8fec4970d8e1d/jump_back_end 31 | /path/to/your/objavser/base/3a6a30a99ec146c9a7f8fec4970d8e1d/run 32 | /path/to/your/objavser/base/3a6a30a99ec146c9a7f8fec4970d8e1d/basic 33 | 
/path/to/your/objavser/base/3a6a30a99ec146c9a7f8fec4970d8e1d/damaged 34 | /path/to/your/objavser/base/3a6a30a99ec146c9a7f8fec4970d8e1d/hit_small 35 | /path/to/your/objavser/base/3a6a30a99ec146c9a7f8fec4970d8e1d/special_b_precast 36 | /path/to/your/objavser/base/3a6a30a99ec146c9a7f8fec4970d8e1d/stun 37 | /path/to/your/objavser/base/3a6a30a99ec146c9a7f8fec4970d8e1d/special_b_attack 38 | /path/to/your/objavser/base/3a6a30a99ec146c9a7f8fec4970d8e1d/jump_back_idle 39 | /path/to/your/objavser/base/3a6a30a99ec146c9a7f8fec4970d8e1d/dodge_loop 40 | /path/to/your/objavser/base/3a6a30a99ec146c9a7f8fec4970d8e1d/walk 41 | /path/to/your/objavser/base/3a6a30a99ec146c9a7f8fec4970d8e1d/special_a 42 | /path/to/your/objavser/base/4418275786f44f03916251c49beb8799/IMG_8945 43 | /path/to/your/objavser/base/4418275786f44f03916251c49beb8799/IMG 44 | /path/to/your/objavser/base/497997f9f3b84377ade07fa4fbdea9e2/mixamo.com 45 | /path/to/your/objavser/base/5086b8340c714cf58ab59f00babf9cca/Open 46 | /path/to/your/objavser/base/5086b8340c714cf58ab59f00babf9cca/Close 47 | /path/to/your/objavser/base/5086b8340c714cf58ab59f00babf9cca/Open-Malfunction 48 | /path/to/your/objavser/base/5086b8340c714cf58ab59f00babf9cca/Close-Malfunction 49 | /path/to/your/objavser/base/52288037886a4cef8a92397bdb310434/Take 001 50 | /path/to/your/objavser/base/578a78d17dea40c3859931ea03bbf2b8/squat 51 | /path/to/your/objavser/base/578a78d17dea40c3859931ea03bbf2b8/improve floss 52 | /path/to/your/objavser/base/68979e426f9a44a7829ed6c5e89870bb/Salto_Caballo 53 | /path/to/your/objavser/base/68979e426f9a44a7829ed6c5e89870bb/Andando_Caballo 54 | /path/to/your/objavser/base/68979e426f9a44a7829ed6c5e89870bb/Idle_Caballo 55 | /path/to/your/objavser/base/68979e426f9a44a7829ed6c5e89870bb/Galopando_Caballo 56 | /path/to/your/objavser/base/68979e426f9a44a7829ed6c5e89870bb/Andando 57 | /path/to/your/objavser/base/6ab7fd20b17b4f289ea3e333acbf90eb/mixamo.com 58 | /path/to/your/objavser/base/6ad5fc0ea9544c1da3e9c541a55391cd/Take 001 59 | /path/to/your/objavser/base/6cb5fdce78f247719d5ebe102f3564cf/Take 001 60 | /path/to/your/objavser/base/8685d3d1c08d4ff8a6a5e1d04138c59a/Take 001 61 | /path/to/your/objavser/base/9266f1e6f2434bae8bf2c48195a3a1d1/Take 001 62 | /path/to/your/objavser/base/980f91d0783a44f8bef160e56803640f/unit_mot_039801_4040_Special 63 | /path/to/your/objavser/base/980f91d0783a44f8bef160e56803640f/unit_mot_039801_1022_Tired 64 | /path/to/your/objavser/base/980f91d0783a44f8bef160e56803640f/unit_mot_039801_1012_Idle 65 | /path/to/your/objavser/base/980f91d0783a44f8bef160e56803640f/unit_mot_039801_5040_Win 66 | /path/to/your/objavser/base/980f91d0783a44f8bef160e56803640f/unit_mot_039801_5030_Charge 67 | /path/to/your/objavser/base/980f91d0783a44f8bef160e56803640f/unit_mot_039801_4020_Special 68 | /path/to/your/objavser/base/980f91d0783a44f8bef160e56803640f/unit_mot_039801_5050_HomeTap 69 | /path/to/your/objavser/base/980f91d0783a44f8bef160e56803640f/unit_mot_039801_3010_Attack 70 | /path/to/your/objavser/base/980f91d0783a44f8bef160e56803640f/unit_mot_039801_4030_Special 71 | /path/to/your/objavser/base/980f91d0783a44f8bef160e56803640f/unit_mot_039801_6031_Down 72 | /path/to/your/objavser/base/980f91d0783a44f8bef160e56803640f/unit_mot_039801_6010_Damage 73 | /path/to/your/objavser/base/980f91d0783a44f8bef160e56803640f/unit_mot_039801_6032_Death 74 | /path/to/your/objavser/base/980f91d0783a44f8bef160e56803640f/unit_mot_039801_6033_Awoke 75 | /path/to/your/objavser/base/980f91d0783a44f8bef160e56803640f/unit_mot_039801_3820_SkillC 76 | 
/path/to/your/objavser/base/980f91d0783a44f8bef160e56803640f/unit_mot_039801_4010_Special 77 | /path/to/your/objavser/base/980f91d0783a44f8bef160e56803640f/unit_mot_039801_5021_DefenceS 78 | /path/to/your/objavser/base/980f91d0783a44f8bef160e56803640f/unit_mot_039801_6020_Critical 79 | /path/to/your/objavser/base/980f91d0783a44f8bef160e56803640f/unit_mot_039801_3020_Magic 80 | /path/to/your/objavser/base/980f91d0783a44f8bef160e56803640f/unit_mot_039801_5022_DefenceR 81 | /path/to/your/objavser/base/980f91d0783a44f8bef160e56803640f/unit_mot_039801_1032_HomeIdle 82 | /path/to/your/objavser/base/980f91d0783a44f8bef160e56803640f/unit_mot_039801_5023_DefenceF 83 | /path/to/your/objavser/base/980f91d0783a44f8bef160e56803640f/unit_mot_039801_2022_Move 84 | /path/to/your/objavser/base/980f91d0783a44f8bef160e56803640f/unit 85 | /path/to/your/objavser/base/9f684924a6654b7eb2f6aaddb9a95624/animation.model.new 86 | /path/to/your/objavser/base/a6704f7831b74920a6cceb483f5474da/Freddy--Idle 87 | /path/to/your/objavser/base/a6b6151036ef4f008be4fdc9670fd511/Animation 88 | /path/to/your/objavser/base/ac86acea72a34462a83aa8a380f663f2/Take 001 89 | /path/to/your/objavser/base/b70b5012638f43f2867c360e9c043697/att_me2 90 | /path/to/your/objavser/base/b70b5012638f43f2867c360e9c043697/att 91 | /path/to/your/objavser/base/b70b5012638f43f2867c360e9c043697/att_me1 92 | /path/to/your/objavser/base/be374fb87f3442b0bbe7548a756102e7/Armature.001Action 93 | /path/to/your/objavser/base/c7323e259bbf43da870d8f1cb422d65a/Take 001 94 | /path/to/your/objavser/base/c7ece725e8c24c4f9bd7ca7fe1f2c6e6/animation.titan_enderman.summon 95 | /path/to/your/objavser/base/c7ece725e8c24c4f9bd7ca7fe1f2c6e6/animation.titan_enderman.night 96 | /path/to/your/objavser/base/c7ece725e8c24c4f9bd7ca7fe1f2c6e6/animation.titan_enderman.spawn 97 | /path/to/your/objavser/base/c7ece725e8c24c4f9bd7ca7fe1f2c6e6/animation.titan_enderman.attack2 98 | /path/to/your/objavser/base/c7ece725e8c24c4f9bd7ca7fe1f2c6e6/animation.titan_enderman.attack 99 | /path/to/your/objavser/base/c7ece725e8c24c4f9bd7ca7fe1f2c6e6/animation.titan_enderman.death 100 | /path/to/your/objavser/base/c7ece725e8c24c4f9bd7ca7fe1f2c6e6/animation.titan_enderman.roaring 101 | /path/to/your/objavser/base/c7ece725e8c24c4f9bd7ca7fe1f2c6e6/animation.titan_enderman.jump 102 | /path/to/your/objavser/base/c7ece725e8c24c4f9bd7ca7fe1f2c6e6/animation.titan_enderman.attack4 103 | /path/to/your/objavser/base/c7ece725e8c24c4f9bd7ca7fe1f2c6e6/animation.titan 104 | /path/to/your/objavser/base/c7ece725e8c24c4f9bd7ca7fe1f2c6e6/animation.titan_enderman.attack3 105 | /path/to/your/objavser/base/d1e6256917b64175a7965f7cf4c8ad1a/Take 001 106 | /path/to/your/objavser/base/d4065c03ef10491ea8227cc44aad1094/Test 107 | /path/to/your/objavser/base/e1494e4b6b724f2fb9a8a80c17aa48ec/Tired stretch 108 | /path/to/your/objavser/base/eb9117a25a22437d818085c77f6dc5b2/Toy_Bonnie--CPU_Revive 109 | /path/to/your/objavser/base/eb9117a25a22437d818085c77f6dc5b2/Toy_Bonnie--CPU_Lurch 110 | /path/to/your/objavser/base/eb9117a25a22437d818085c77f6dc5b2/Toy_Bonnie--Shutdown 111 | /path/to/your/objavser/base/eb9117a25a22437d818085c77f6dc5b2/Toy_Bonnie--Idle 112 | /path/to/your/objavser/base/eb9117a25a22437d818085c77f6dc5b2/Toy 113 | /path/to/your/objavser/base/eb9117a25a22437d818085c77f6dc5b2/Toy_Bonnie--Jumpscare 114 | /path/to/your/objavser/base/eb9117a25a22437d818085c77f6dc5b2/Toy_Bonnie--Shocked 115 | /path/to/your/objavser/base/eb9117a25a22437d818085c77f6dc5b2/Toy_Bonnie--Haywire 116 | 
/path/to/your/objavser/base/eb9117a25a22437d818085c77f6dc5b2/Toy_Bonnie--CPU_Shutdown 117 | /path/to/your/objavser/base/eb9117a25a22437d818085c77f6dc5b2/Toy_Bonnie--Charge 118 | /path/to/your/objavser/base/f6830812ea034058b020ac5a09ef1e33/Armature|Armature|mixamo.com|Layer0 119 | /path/to/your/objavser/base/f6830812ea034058b020ac5a09ef1e33/Armature|Armature|Armature|Armature|mixamo.com|Layer0|Armature| 120 | /path/to/your/objavser/base/fdda29d31bb5476a9d6f09c1675475e7/Animation -------------------------------------------------------------------------------- /filelist/zero123_val_filelist_objavser_hq.txt: -------------------------------------------------------------------------------- 1 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/1f60d11ef7904a20aedcfebaed750b6f/Armature|Anim 2 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/1f60d11ef7904a20aedcfebaed750b6f/Anim 3 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/2cf1808d3d20440299fb427d506ac371/walk_slow_events 4 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/2cf1808d3d20440299fb427d506ac371/vuelta_cocinar 5 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/2cf1808d3d20440299fb427d506ac371/acariciar_tanque 6 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/2cf1808d3d20440299fb427d506ac371/driving 7 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/2cf1808d3d20440299fb427d506ac371/idle_fin_cocinar 8 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/2cf1808d3d20440299fb427d506ac371/idle1 9 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/2cf1808d3d20440299fb427d506ac371/idle_cocinar 10 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/2cf1808d3d20440299fb427d506ac371/sujetar_cerdo 11 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/2cf1808d3d20440299fb427d506ac371/idle3 12 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/2cf1808d3d20440299fb427d506ac371/look_surprised 13 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/2cf1808d3d20440299fb427d506ac371/start_drive 14 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/2cf1808d3d20440299fb427d506ac371/attack 15 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/2cf1808d3d20440299fb427d506ac371/walk_events 16 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/2cf1808d3d20440299fb427d506ac371/open_events 17 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/2cf1808d3d20440299fb427d506ac371/drop_paper 18 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/2cf1808d3d20440299fb427d506ac371/idle2 19 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/2cf1808d3d20440299fb427d506ac371/scare_0 20 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/2cf1808d3d20440299fb427d506ac371/walk_fast_events 21 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/2cf1808d3d20440299fb427d506ac371/olla_loop 22 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/2e523d5dde7c44f891ae6dc8fbd84fe5/Armature|Armature|mixamo.com|Layer0 23 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/2f57441ebba047a3af604aed7906a7f4/Scene 24 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/37473ebbc16e4df887d4e29bda9de201/mixamo.com 25 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/3a6a30a99ec146c9a7f8fec4970d8e1d/hit_big 26 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/3a6a30a99ec146c9a7f8fec4970d8e1d/fly_loop 27 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/3a6a30a99ec146c9a7f8fec4970d8e1d/special_b_precast_loop 28 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/3a6a30a99ec146c9a7f8fec4970d8e1d/combat_idle 29 
| ./preprocess/zero123_preprocessed_data/Objaverse_HQ/3a6a30a99ec146c9a7f8fec4970d8e1d/die 30 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/3a6a30a99ec146c9a7f8fec4970d8e1d/jump_back_end 31 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/3a6a30a99ec146c9a7f8fec4970d8e1d/run 32 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/3a6a30a99ec146c9a7f8fec4970d8e1d/basic 33 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/3a6a30a99ec146c9a7f8fec4970d8e1d/damaged 34 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/3a6a30a99ec146c9a7f8fec4970d8e1d/hit_small 35 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/3a6a30a99ec146c9a7f8fec4970d8e1d/special_b_precast 36 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/3a6a30a99ec146c9a7f8fec4970d8e1d/stun 37 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/3a6a30a99ec146c9a7f8fec4970d8e1d/special_b_attack 38 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/3a6a30a99ec146c9a7f8fec4970d8e1d/jump_back_idle 39 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/3a6a30a99ec146c9a7f8fec4970d8e1d/dodge_loop 40 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/3a6a30a99ec146c9a7f8fec4970d8e1d/walk 41 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/3a6a30a99ec146c9a7f8fec4970d8e1d/special_a 42 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/4418275786f44f03916251c49beb8799/IMG_8945 43 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/4418275786f44f03916251c49beb8799/IMG 44 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/497997f9f3b84377ade07fa4fbdea9e2/mixamo.com 45 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/5086b8340c714cf58ab59f00babf9cca/Open 46 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/5086b8340c714cf58ab59f00babf9cca/Close 47 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/5086b8340c714cf58ab59f00babf9cca/Open-Malfunction 48 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/5086b8340c714cf58ab59f00babf9cca/Close-Malfunction 49 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/52288037886a4cef8a92397bdb310434/Take 001 50 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/578a78d17dea40c3859931ea03bbf2b8/squat 51 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/578a78d17dea40c3859931ea03bbf2b8/improve floss 52 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/68979e426f9a44a7829ed6c5e89870bb/Salto_Caballo 53 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/68979e426f9a44a7829ed6c5e89870bb/Andando_Caballo 54 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/68979e426f9a44a7829ed6c5e89870bb/Idle_Caballo 55 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/68979e426f9a44a7829ed6c5e89870bb/Galopando_Caballo 56 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/68979e426f9a44a7829ed6c5e89870bb/Andando 57 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/6ab7fd20b17b4f289ea3e333acbf90eb/mixamo.com 58 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/6ad5fc0ea9544c1da3e9c541a55391cd/Take 001 59 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/6cb5fdce78f247719d5ebe102f3564cf/Take 001 60 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/8685d3d1c08d4ff8a6a5e1d04138c59a/Take 001 61 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/9266f1e6f2434bae8bf2c48195a3a1d1/Take 001 62 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/980f91d0783a44f8bef160e56803640f/unit_mot_039801_4040_Special 63 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/980f91d0783a44f8bef160e56803640f/unit_mot_039801_1022_Tired 64 | 
./preprocess/zero123_preprocessed_data/Objaverse_HQ/980f91d0783a44f8bef160e56803640f/unit_mot_039801_1012_Idle 65 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/980f91d0783a44f8bef160e56803640f/unit_mot_039801_5040_Win 66 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/980f91d0783a44f8bef160e56803640f/unit_mot_039801_5030_Charge 67 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/980f91d0783a44f8bef160e56803640f/unit_mot_039801_4020_Special 68 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/980f91d0783a44f8bef160e56803640f/unit_mot_039801_5050_HomeTap 69 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/980f91d0783a44f8bef160e56803640f/unit_mot_039801_3010_Attack 70 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/980f91d0783a44f8bef160e56803640f/unit_mot_039801_4030_Special 71 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/980f91d0783a44f8bef160e56803640f/unit_mot_039801_6031_Down 72 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/980f91d0783a44f8bef160e56803640f/unit_mot_039801_6010_Damage 73 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/980f91d0783a44f8bef160e56803640f/unit_mot_039801_6032_Death 74 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/980f91d0783a44f8bef160e56803640f/unit_mot_039801_6033_Awoke 75 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/980f91d0783a44f8bef160e56803640f/unit_mot_039801_3820_SkillC 76 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/980f91d0783a44f8bef160e56803640f/unit_mot_039801_4010_Special 77 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/980f91d0783a44f8bef160e56803640f/unit_mot_039801_5021_DefenceS 78 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/980f91d0783a44f8bef160e56803640f/unit_mot_039801_6020_Critical 79 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/980f91d0783a44f8bef160e56803640f/unit_mot_039801_3020_Magic 80 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/980f91d0783a44f8bef160e56803640f/unit_mot_039801_5022_DefenceR 81 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/980f91d0783a44f8bef160e56803640f/unit_mot_039801_1032_HomeIdle 82 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/980f91d0783a44f8bef160e56803640f/unit_mot_039801_5023_DefenceF 83 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/980f91d0783a44f8bef160e56803640f/unit_mot_039801_2022_Move 84 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/980f91d0783a44f8bef160e56803640f/unit 85 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/9f684924a6654b7eb2f6aaddb9a95624/animation.model.new 86 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/a6704f7831b74920a6cceb483f5474da/Freddy--Idle 87 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/a6b6151036ef4f008be4fdc9670fd511/Animation 88 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/ac86acea72a34462a83aa8a380f663f2/Take 001 89 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/b70b5012638f43f2867c360e9c043697/att_me2 90 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/b70b5012638f43f2867c360e9c043697/att 91 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/b70b5012638f43f2867c360e9c043697/att_me1 92 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/be374fb87f3442b0bbe7548a756102e7/Armature.001Action 93 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/c7323e259bbf43da870d8f1cb422d65a/Take 001 94 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/c7ece725e8c24c4f9bd7ca7fe1f2c6e6/animation.titan_enderman.summon 95 | 
./preprocess/zero123_preprocessed_data/Objaverse_HQ/c7ece725e8c24c4f9bd7ca7fe1f2c6e6/animation.titan_enderman.night 96 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/c7ece725e8c24c4f9bd7ca7fe1f2c6e6/animation.titan_enderman.spawn 97 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/c7ece725e8c24c4f9bd7ca7fe1f2c6e6/animation.titan_enderman.attack2 98 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/c7ece725e8c24c4f9bd7ca7fe1f2c6e6/animation.titan_enderman.attack 99 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/c7ece725e8c24c4f9bd7ca7fe1f2c6e6/animation.titan_enderman.death 100 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/c7ece725e8c24c4f9bd7ca7fe1f2c6e6/animation.titan_enderman.roaring 101 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/c7ece725e8c24c4f9bd7ca7fe1f2c6e6/animation.titan_enderman.jump 102 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/c7ece725e8c24c4f9bd7ca7fe1f2c6e6/animation.titan_enderman.attack4 103 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/c7ece725e8c24c4f9bd7ca7fe1f2c6e6/animation.titan 104 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/c7ece725e8c24c4f9bd7ca7fe1f2c6e6/animation.titan_enderman.attack3 105 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/d1e6256917b64175a7965f7cf4c8ad1a/Take 001 106 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/d4065c03ef10491ea8227cc44aad1094/Test 107 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/e1494e4b6b724f2fb9a8a80c17aa48ec/Tired stretch 108 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/eb9117a25a22437d818085c77f6dc5b2/Toy_Bonnie--CPU_Revive 109 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/eb9117a25a22437d818085c77f6dc5b2/Toy_Bonnie--CPU_Lurch 110 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/eb9117a25a22437d818085c77f6dc5b2/Toy_Bonnie--Shutdown 111 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/eb9117a25a22437d818085c77f6dc5b2/Toy_Bonnie--Idle 112 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/eb9117a25a22437d818085c77f6dc5b2/Toy 113 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/eb9117a25a22437d818085c77f6dc5b2/Toy_Bonnie--Jumpscare 114 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/eb9117a25a22437d818085c77f6dc5b2/Toy_Bonnie--Shocked 115 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/eb9117a25a22437d818085c77f6dc5b2/Toy_Bonnie--Haywire 116 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/eb9117a25a22437d818085c77f6dc5b2/Toy_Bonnie--CPU_Shutdown 117 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/eb9117a25a22437d818085c77f6dc5b2/Toy_Bonnie--Charge 118 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/f6830812ea034058b020ac5a09ef1e33/Armature|Armature|mixamo.com|Layer0 119 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/f6830812ea034058b020ac5a09ef1e33/Armature|Armature|Armature|Armature|mixamo.com|Layer0|Armature| 120 | ./preprocess/zero123_preprocessed_data/Objaverse_HQ/fdda29d31bb5476a9d6f09c1675475e7/Animation -------------------------------------------------------------------------------- /preprocess/zero123plus/model.py: -------------------------------------------------------------------------------- 1 | # This code is from InstantMesh https://github.com/TencentARC/InstantMesh 2 | import os 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import pytorch_lightning as pl 8 | from tqdm import tqdm 9 | from torchvision.transforms import v2 10 | from torchvision.utils import make_grid, save_image 11 | from 
einops import rearrange 12 | 13 | from src.utils.train_util import instantiate_from_config 14 | from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, DDPMScheduler, UNet2DConditionModel 15 | from .pipeline import RefOnlyNoisedUNet 16 | 17 | 18 | def scale_latents(latents): 19 | latents = (latents - 0.22) * 0.75 20 | return latents 21 | 22 | 23 | def unscale_latents(latents): 24 | latents = latents / 0.75 + 0.22 25 | return latents 26 | 27 | 28 | def scale_image(image): 29 | image = image * 0.5 / 0.8 30 | return image 31 | 32 | 33 | def unscale_image(image): 34 | image = image / 0.5 * 0.8 35 | return image 36 | 37 | 38 | def extract_into_tensor(a, t, x_shape): 39 | b, *_ = t.shape 40 | out = a.gather(-1, t) 41 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 42 | 43 | 44 | class MVDiffusion(pl.LightningModule): 45 | def __init__( 46 | self, 47 | stable_diffusion_config, 48 | drop_cond_prob=0.1, 49 | ): 50 | super(MVDiffusion, self).__init__() 51 | 52 | self.drop_cond_prob = drop_cond_prob 53 | 54 | self.register_schedule() 55 | 56 | # init modules 57 | pipeline = DiffusionPipeline.from_pretrained(**stable_diffusion_config) 58 | pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config( 59 | pipeline.scheduler.config, timestep_spacing='trailing' 60 | ) 61 | self.pipeline = pipeline 62 | 63 | train_sched = DDPMScheduler.from_config(self.pipeline.scheduler.config) 64 | 65 | if isinstance(self.pipeline.unet, UNet2DConditionModel): 66 | self.pipeline.unet = RefOnlyNoisedUNet(self.pipeline.unet, train_sched, self.pipeline.scheduler) 67 | 68 | self.train_scheduler = train_sched # use ddpm scheduler during training 69 | 70 | self.unet = pipeline.unet 71 | 72 | # validation output buffer 73 | self.validation_step_outputs = [] 74 | 75 | def register_schedule(self): 76 | self.num_timesteps = 1000 77 | 78 | # replace scaled_linear schedule with linear schedule as Zero123++ 79 | beta_start = 0.00085 80 | beta_end = 0.0120 81 | betas = torch.linspace(beta_start, beta_end, 1000, dtype=torch.float32) 82 | 83 | alphas = 1. - betas 84 | alphas_cumprod = torch.cumprod(alphas, dim=0) 85 | alphas_cumprod_prev = torch.cat([torch.ones(1, dtype=torch.float64), alphas_cumprod[:-1]], 0) 86 | 87 | self.register_buffer('betas', betas.float()) 88 | self.register_buffer('alphas_cumprod', alphas_cumprod.float()) 89 | self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev.float()) 90 | 91 | # calculations for diffusion q(x_t | x_{t-1}) and others 92 | self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod).float()) 93 | self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1 - alphas_cumprod).float()) 94 | 95 | self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod).float()) 96 | self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. 
/ alphas_cumprod - 1).float()) 97 | 98 | def on_fit_start(self): 99 | device = torch.device(f'cuda:{self.global_rank}') 100 | self.pipeline.to(device) 101 | if self.global_rank == 0: 102 | os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True) 103 | os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True) 104 | 105 | def prepare_batch_data(self, batch): 106 | # prepare stable diffusion input 107 | cond_imgs = batch['cond_imgs'] # (B, C, H, W) 108 | cond_imgs = cond_imgs.to(self.device) 109 | 110 | # random resize the condition image 111 | cond_size = np.random.randint(128, 513) 112 | cond_imgs = v2.functional.resize(cond_imgs, cond_size, interpolation=3, antialias=True).clamp(0, 1) 113 | 114 | target_imgs = batch['target_imgs'] # (B, 6, C, H, W) 115 | target_imgs = v2.functional.resize(target_imgs, 320, interpolation=3, antialias=True).clamp(0, 1) 116 | target_imgs = rearrange(target_imgs, 'b (x y) c h w -> b c (x h) (y w)', x=3, y=2) # (B, C, 3H, 2W) 117 | target_imgs = target_imgs.to(self.device) 118 | 119 | return cond_imgs, target_imgs 120 | 121 | @torch.no_grad() 122 | def forward_vision_encoder(self, images): 123 | dtype = next(self.pipeline.vision_encoder.parameters()).dtype 124 | image_pil = [v2.functional.to_pil_image(images[i]) for i in range(images.shape[0])] 125 | image_pt = self.pipeline.feature_extractor_clip(images=image_pil, return_tensors="pt").pixel_values 126 | image_pt = image_pt.to(device=self.device, dtype=dtype) 127 | global_embeds = self.pipeline.vision_encoder(image_pt, output_hidden_states=False).image_embeds 128 | global_embeds = global_embeds.unsqueeze(-2) 129 | 130 | encoder_hidden_states = self.pipeline.encode_prompt("", self.device, 1, False)[0] 131 | ramp = global_embeds.new_tensor(self.pipeline.config.ramping_coefficients).unsqueeze(-1) 132 | encoder_hidden_states = encoder_hidden_states + global_embeds * ramp 133 | 134 | return encoder_hidden_states 135 | 136 | @torch.no_grad() 137 | def encode_condition_image(self, images): 138 | dtype = next(self.pipeline.vae.parameters()).dtype 139 | image_pil = [v2.functional.to_pil_image(images[i]) for i in range(images.shape[0])] 140 | image_pt = self.pipeline.feature_extractor_vae(images=image_pil, return_tensors="pt").pixel_values 141 | image_pt = image_pt.to(device=self.device, dtype=dtype) 142 | latents = self.pipeline.vae.encode(image_pt).latent_dist.sample() 143 | return latents 144 | 145 | @torch.no_grad() 146 | def encode_target_images(self, images): 147 | dtype = next(self.pipeline.vae.parameters()).dtype 148 | # equals to scaling images to [-1, 1] first and then call scale_image 149 | images = (images - 0.5) / 0.8 # [-0.625, 0.625] 150 | posterior = self.pipeline.vae.encode(images.to(dtype)).latent_dist 151 | latents = posterior.sample() * self.pipeline.vae.config.scaling_factor 152 | latents = scale_latents(latents) 153 | return latents 154 | 155 | def forward_unet(self, latents, t, prompt_embeds, cond_latents): 156 | dtype = next(self.pipeline.unet.parameters()).dtype 157 | latents = latents.to(dtype) 158 | prompt_embeds = prompt_embeds.to(dtype) 159 | cond_latents = cond_latents.to(dtype) 160 | cross_attention_kwargs = dict(cond_lat=cond_latents) 161 | pred_noise = self.pipeline.unet( 162 | latents, 163 | t, 164 | encoder_hidden_states=prompt_embeds, 165 | cross_attention_kwargs=cross_attention_kwargs, 166 | return_dict=False, 167 | )[0] 168 | return pred_noise 169 | 170 | def predict_start_from_z_and_v(self, x_t, t, v): 171 | return ( 172 | 
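# Recover the clean latent x0 from the noisy latent x_t and the predicted v:
#   x0 = sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v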
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - 173 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v 174 | ) 175 | 176 | def get_v(self, x, noise, t): 177 | return ( 178 | extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - 179 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x 180 | ) 181 | 182 | def training_step(self, batch, batch_idx): 183 | # get input 184 | cond_imgs, target_imgs = self.prepare_batch_data(batch) 185 | 186 | # sample random timestep 187 | B = cond_imgs.shape[0] 188 | 189 | t = torch.randint(0, self.num_timesteps, size=(B,)).long().to(self.device) 190 | 191 | # classifier-free guidance 192 | # if np.random.rand() < self.drop_cond_prob: 193 | # prompt_embeds = self.pipeline._encode_prompt([""]*B, self.device, 1, False) 194 | # cond_latents = self.encode_condition_image(torch.zeros_like(cond_imgs)) 195 | # else: 196 | # prompt_embeds = self.forward_vision_encoder(cond_imgs) 197 | # cond_latents = self.encode_condition_image(cond_imgs) 198 | 199 | prompt_embeds = self.forward_vision_encoder(cond_imgs) 200 | cond_latents = self.encode_condition_image(cond_imgs) 201 | 202 | latents = self.encode_target_images(target_imgs) 203 | noise = torch.randn_like(latents) 204 | latents_noisy = self.train_scheduler.add_noise(latents, noise, t) 205 | 206 | if isinstance(prompt_embeds, tuple): 207 | print(prompt_embeds) 208 | v_pred = self.forward_unet(latents_noisy, t, prompt_embeds, cond_latents) 209 | v_target = self.get_v(latents, noise, t) 210 | 211 | loss, loss_dict = self.compute_loss(v_pred, v_target) 212 | 213 | # logging 214 | self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True) 215 | self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False) 216 | lr = self.optimizers().param_groups[0]['lr'] 217 | self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) 218 | 219 | if self.global_step % 500 == 0 and self.global_rank == 0: 220 | with torch.no_grad(): 221 | latents_pred = self.predict_start_from_z_and_v(latents_noisy, t, v_pred) 222 | 223 | latents = unscale_latents(latents_pred) 224 | images = unscale_image(self.pipeline.vae.decode(latents / self.pipeline.vae.config.scaling_factor, return_dict=False)[0]) # [-1, 1] 225 | images = (images * 0.5 + 0.5).clamp(0, 1) 226 | images = torch.cat([target_imgs, images], dim=-2) 227 | 228 | grid = make_grid(images, nrow=images.shape[0], normalize=True, value_range=(0, 1)) 229 | save_image(grid, os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png')) 230 | 231 | return loss 232 | 233 | def compute_loss(self, noise_pred, noise_gt): 234 | loss = F.mse_loss(noise_pred, noise_gt) 235 | 236 | prefix = 'train' 237 | loss_dict = {} 238 | loss_dict.update({f'{prefix}/loss': loss}) 239 | 240 | return loss, loss_dict 241 | 242 | @torch.no_grad() 243 | def validation_step(self, batch, batch_idx): 244 | # get input 245 | cond_imgs, target_imgs = self.prepare_batch_data(batch) 246 | 247 | images_pil = [v2.functional.to_pil_image(cond_imgs[i]) for i in range(cond_imgs.shape[0])] 248 | 249 | outputs = [] 250 | for cond_img in images_pil: 251 | latent = self.pipeline(cond_img, num_inference_steps=75, output_type='latent').images 252 | image = unscale_image(self.pipeline.vae.decode(latent / self.pipeline.vae.config.scaling_factor, return_dict=False)[0]) # [-1, 1] 253 | image = (image * 0.5 + 0.5).clamp(0, 1) 254 | outputs.append(image) 255 | outputs = 
torch.cat(outputs, dim=0).to(self.device) 256 | images = torch.cat([target_imgs, outputs], dim=-2) 257 | 258 | self.validation_step_outputs.append(images) 259 | 260 | @torch.no_grad() 261 | def on_validation_epoch_end(self): 262 | images = torch.cat(self.validation_step_outputs, dim=0) 263 | 264 | all_images = self.all_gather(images) 265 | all_images = rearrange(all_images, 'r b c h w -> (r b) c h w') 266 | 267 | if self.global_rank == 0: 268 | grid = make_grid(all_images, nrow=8, normalize=True, value_range=(0, 1)) 269 | save_image(grid, os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')) 270 | 271 | self.validation_step_outputs.clear() # free memory 272 | 273 | def configure_optimizers(self): 274 | lr = self.learning_rate 275 | 276 | optimizer = torch.optim.AdamW(self.unet.parameters(), lr=lr) 277 | scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/4) 278 | 279 | return {'optimizer': optimizer, 'lr_scheduler': scheduler} 280 | --------------------------------------------------------------------------------
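Note: the schedule registered in register_schedule, together with get_v and predict_start_from_z_and_v above, follows the standard v-prediction parameterization: the training target is v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x0, and the clean latent is recovered as x0 = sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v. The snippet below is a minimal standalone sanity check of these identities; it is not part of the repository and only assumes the same linear beta schedule (0.00085 to 0.0120 over 1000 steps) used above.

    import torch

    # Linear beta schedule, as in MVDiffusion.register_schedule
    betas = torch.linspace(0.00085, 0.0120, 1000)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    sqrt_ac = alphas_cumprod.sqrt()
    sqrt_1m_ac = (1.0 - alphas_cumprod).sqrt()

    # Toy "clean latents", noise, and per-sample timesteps
    x0 = torch.randn(4, 8)
    noise = torch.randn_like(x0)
    t = torch.randint(0, 1000, (4,))

    a = sqrt_ac[t].unsqueeze(-1)      # sqrt(alpha_bar_t), broadcast per sample
    s = sqrt_1m_ac[t].unsqueeze(-1)   # sqrt(1 - alpha_bar_t)

    x_t = a * x0 + s * noise          # forward diffusion q(x_t | x_0)
    v = a * noise - s * x0            # v-prediction target (get_v)
    x0_rec = a * x_t - s * v          # predict_start_from_z_and_v

    # Since a**2 + s**2 == 1, the reconstruction matches x0 up to float rounding
    assert torch.allclose(x0_rec, x0, atol=1e-5)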