├── assets
    ├── demo.gif
    ├── example_ref.png
    └── example_input.png
├── requirements.txt
├── examples
    ├── gsplat
    │   ├── requirements.txt
    │   ├── datasets
    │   │   ├── normalize.py
    │   │   ├── traj.py
    │   │   └── colmap.py
    │   ├── utils.py
    │   └── lib_bilagrid.py
    ├── nerfstudio
    │   ├── pyproject.toml
    │   └── difix3d
    │   │   ├── difix3d_config.py
    │   │   ├── difix3d_datamanager.py
    │   │   ├── difix3d_trainer.py
    │   │   ├── difix3d.py
    │   │   ├── difix3d_pipeline.py
    │   │   └── difix3d_field.py
    └── utils.py
├── src
    ├── dataset.py
    ├── loss.py
    ├── inference_difix.py
    ├── model.py
    └── train_difix.py
├── README.md
└── LICENSE.txt

/assets/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nv-tlabs/Difix3D/HEAD/assets/demo.gif
--------------------------------------------------------------------------------
/assets/example_ref.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nv-tlabs/Difix3D/HEAD/assets/example_ref.png
--------------------------------------------------------------------------------
/assets/example_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nv-tlabs/Difix3D/HEAD/assets/example_input.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | torchmetrics
4 | wandb
5 | imageio[ffmpeg]
6 | einops
7 | lpips
8 | xformers
9 | peft==0.9.0
10 | diffusers==0.25.1
11 | huggingface-hub==0.25.1
12 | transformers==4.38.0
--------------------------------------------------------------------------------
/examples/gsplat/requirements.txt:
--------------------------------------------------------------------------------
1 | # assume torch is already installed
2 |
3 | # pycolmap for data parsing
4 | git+https://github.com/rmbrualla/pycolmap@cc7ea4b7301720ac29287dbe450952511b32125e
5 | # (optional) nerfacc for torch version rasterization
6 | # git+https://github.com/nerfstudio-project/nerfacc
7 |
8 | viser
9 | nerfview==0.0.2
10 | imageio[ffmpeg]
11 | numpy<2.0.0
12 | scikit-learn
13 | tqdm
14 | torchmetrics[image]
15 | opencv-python
16 | tyro>=0.8.8
17 | Pillow
18 | tensorboard
19 | tensorly
20 | pyyaml
21 | matplotlib
22 | git+https://github.com/rahul-goel/fused-ssim@1272e21a282342e89537159e4bad508b19b34157
--------------------------------------------------------------------------------
/examples/nerfstudio/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "difix3d"
3 | version = "0.1.0"
4 |
5 | dependencies=[
6 |     "nerfstudio>=0.3.0",
7 | ]
8 |
9 | # black
10 | [tool.black]
11 | line-length = 120
12 |
13 | # pylint
14 | [tool.pylint.messages_control]
15 | max-line-length = 120
16 | generated-members = ["numpy.*", "torch.*", "cv2.*", "cv.*"]
17 | good-names-rgxs = "^[_a-zA-Z][_a-z0-9]?$"
18 | ignore-paths = ["scripts/colmap2nerf.py"]
19 | jobs = 0
20 | ignored-classes = ["TensorDataclass"]
21 |
22 | disable = [
23 |     "duplicate-code",
24 |     "fixme",
25 |     "logging-fstring-interpolation",
26 |     "too-many-arguments",
27 |     "too-many-branches",
28 |     "too-many-instance-attributes",
29 |     "too-many-locals",
30 |     "unnecessary-ellipsis",
31 | ]
32 |
33 | [tool.setuptools.packages.find]
34 | include = ["difix3d"]
35 |
36 | [project.entry-points.'nerfstudio.method_configs']
37 | difix3d = 'difix3d.difix3d_config:difix3d_method'
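The `[project.entry-points.'nerfstudio.method_configs']` table above is what lets `ns-train difix3d` find the method once the package is installed (see the README's nerfstudio section). The snippet below is a minimal sketch of that discovery mechanism, not part of the repository itself; it assumes Python >= 3.10 and that the package has been installed.

```python
# Sketch: resolve a nerfstudio method registered via the entry-point group above.
from importlib.metadata import entry_points

method_specs = {ep.name: ep for ep in entry_points(group="nerfstudio.method_configs")}
difix3d_method = method_specs["difix3d"].load()  # MethodSpecification from difix3d/difix3d_config.py
print(difix3d_method.description)
```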
--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | from PIL import Image
4 | import torchvision.transforms.functional as F
5 |
6 |
7 | class PairedDataset(torch.utils.data.Dataset):
8 |     def __init__(self, dataset_path, split, height=576, width=1024, tokenizer=None):
9 |
10 |         super().__init__()
11 |         with open(dataset_path, "r") as f:
12 |             self.data = json.load(f)[split]
13 |         self.img_ids = list(self.data.keys())
14 |         self.image_size = (height, width)
15 |         self.tokenizer = tokenizer
16 |
17 |     def __len__(self):
18 |
19 |         return len(self.img_ids)
20 |
21 |     def __getitem__(self, idx):
22 |
23 |         img_id = self.img_ids[idx]
24 |
25 |         input_img = self.data[img_id]["image"]
26 |         output_img = self.data[img_id]["target_image"]
27 |         ref_img = self.data[img_id]["ref_image"] if "ref_image" in self.data[img_id] else None
28 |         caption = self.data[img_id]["prompt"]
29 |
30 |         try:
31 |             input_img = Image.open(input_img)
32 |             output_img = Image.open(output_img)
33 |         except Exception:
34 |             print("Error loading image:", input_img, output_img)
35 |             return self.__getitem__((idx + 1) % len(self))  # fall back to the next sample
36 |
37 |         img_t = F.to_tensor(input_img)
38 |         img_t = F.resize(img_t, self.image_size)
39 |         img_t = F.normalize(img_t, mean=[0.5], std=[0.5])
40 |
41 |         output_t = F.to_tensor(output_img)
42 |         output_t = F.resize(output_t, self.image_size)
43 |         output_t = F.normalize(output_t, mean=[0.5], std=[0.5])
44 |
45 |         if ref_img is not None:
46 |             ref_img = Image.open(ref_img)
47 |             ref_t = F.to_tensor(ref_img)
48 |             ref_t = F.resize(ref_t, self.image_size)
49 |             ref_t = F.normalize(ref_t, mean=[0.5], std=[0.5])
50 |
51 |             img_t = torch.stack([img_t, ref_t], dim=0)
52 |             output_t = torch.stack([output_t, ref_t], dim=0)
53 |         else:
54 |             img_t = img_t.unsqueeze(0)
55 |             output_t = output_t.unsqueeze(0)
56 |
57 |         out = {
58 |             "output_pixel_values": output_t,
59 |             "conditioning_pixel_values": img_t,
60 |             "caption": caption,
61 |         }
62 |
63 |         if self.tokenizer is not None:
64 |             input_ids = self.tokenizer(
65 |                 caption, max_length=self.tokenizer.model_max_length,
66 |                 padding="max_length", truncation=True, return_tensors="pt"
67 |             ).input_ids
68 |             out["input_ids"] = input_ids
69 |
70 |         return out
71 |
--------------------------------------------------------------------------------
/src/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import transforms, models
3 |
4 | # Define style weights for different layers
5 | STYLE_WEIGHTS = {
6 |     'relu1_2': 1.0 / 2.6,
7 |     'relu2_2': 1.0 / 4.8,
8 |     'relu3_3': 1.0 / 3.7,
9 |     'relu4_3': 1.0 / 5.6,
10 |     'relu5_3': 10.0 / 1.5
11 | }
12 |
13 | def get_features(image, model, layers=None):
14 |     """
15 |     Extract features from specific layers of a model for a given image.
16 |
17 |     Args:
18 |         image (torch.Tensor): Input image tensor.
19 |         model (torch.nn.Module): Pretrained model (e.g., VGG).
20 |         layers (dict): Mapping of layer indices to layer names.
21 |
22 |     Returns:
23 |         dict: A dictionary of features for the specified layers.
24 | """ 25 | if layers is None: 26 | layers = { 27 | '3': 'relu1_2', 28 | '8': 'relu2_2', 29 | '15': 'relu3_3', 30 | '22': 'relu4_3', 31 | '29': 'relu5_3' 32 | } 33 | 34 | features = {} 35 | x = image 36 | for name, layer in model._modules.items(): 37 | x = layer(x) 38 | if name in layers: 39 | features[layers[name]] = x 40 | return features 41 | 42 | def gram_matrix(tensor): 43 | """ 44 | Compute the Gram matrix for a given tensor. 45 | 46 | Args: 47 | tensor (torch.Tensor): Input tensor of shape (batch_size, depth, height, width). 48 | 49 | Returns: 50 | torch.Tensor: Gram matrix of the input tensor. 51 | """ 52 | b, d, h, w = tensor.size() 53 | tensor = tensor.view(b * d, h * w) # Reshape tensor for matrix multiplication 54 | gram = torch.mm(tensor, tensor.t()) # Compute Gram matrix 55 | return gram 56 | 57 | def gram_loss(style, target, model): 58 | """ 59 | Compute the Gram loss (style loss) between a style image and a target image. 60 | 61 | Args: 62 | style (torch.Tensor): Style image tensor. 63 | target (torch.Tensor): Target image tensor. 64 | model (torch.nn.Module): Pretrained model (e.g., VGG). 65 | 66 | Returns: 67 | torch.Tensor: The computed Gram loss. 68 | """ 69 | # Extract features for the style and target images 70 | style_features = get_features(style, model) 71 | target_features = get_features(target, model) 72 | 73 | # Compute Gram matrices for the style image 74 | style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features} 75 | 76 | # Initialize total loss 77 | total_loss = 0 78 | 79 | # Compute the weighted Gram loss for each layer 80 | for layer, weight in STYLE_WEIGHTS.items(): 81 | target_feature = target_features[layer] 82 | target_gram = gram_matrix(target_feature) 83 | style_gram = style_grams[layer] 84 | 85 | # Compute the layer-specific Gram loss 86 | _, d, h, w = target_feature.shape 87 | layer_loss = weight * torch.mean((target_gram - style_gram) ** 2) 88 | total_loss += layer_loss / (d * h * w) 89 | 90 | return total_loss -------------------------------------------------------------------------------- /examples/nerfstudio/difix3d/difix3d_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from nerfstudio.cameras.camera_optimizers import CameraOptimizerConfig 17 | from nerfstudio.configs.base_config import ViewerConfig 18 | from nerfstudio.data.dataparsers.nerfstudio_dataparser import NerfstudioDataParserConfig 19 | from nerfstudio.engine.optimizers import AdamOptimizerConfig 20 | from nerfstudio.engine.schedulers import ExponentialDecaySchedulerConfig 21 | from nerfstudio.plugins.types import MethodSpecification 22 | 23 | from difix3d.difix3d_datamanager import Difix3DDataManagerConfig 24 | from difix3d.difix3d import Difix3DModelConfig 25 | from difix3d.difix3d_pipeline import Difix3DPipelineConfig 26 | from difix3d.difix3d_trainer import Difix3DTrainerConfig 27 | 28 | difix3d_method = MethodSpecification( 29 | config=Difix3DTrainerConfig( 30 | method_name="difix3d", 31 | # steps_per_eval_batch=1000, 32 | # steps_per_eval_image=100, 33 | # steps_per_save=250, 34 | # max_num_iterations=15000, 35 | # save_only_latest_checkpoint=True, 36 | # mixed_precision=False, 37 | steps_per_eval_batch=500, 38 | steps_per_save=2000, 39 | max_num_iterations=30000, 40 | mixed_precision=True, 41 | pipeline=Difix3DPipelineConfig( 42 | datamanager=Difix3DDataManagerConfig( 43 | dataparser=NerfstudioDataParserConfig(), 44 | train_num_rays_per_batch=16384, 45 | eval_num_rays_per_batch=4096, 46 | ), 47 | model=Difix3DModelConfig( 48 | eval_num_rays_per_chunk=1 << 15, 49 | average_init_density=0.01, 50 | camera_optimizer=CameraOptimizerConfig(mode="SO3xR3"), 51 | ), 52 | ), 53 | optimizers={ 54 | "proposal_networks": { 55 | "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15), 56 | "scheduler": ExponentialDecaySchedulerConfig(lr_final=0.0001, max_steps=200000), 57 | }, 58 | "fields": { 59 | "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15), 60 | "scheduler": ExponentialDecaySchedulerConfig(lr_final=0.0001, max_steps=200000), 61 | }, 62 | "camera_opt": { 63 | "optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15), 64 | "scheduler": ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=5000), 65 | }, 66 | }, 67 | viewer=ViewerConfig(num_rays_per_chunk=1 << 15), 68 | vis="viewer", 69 | ), 70 | description="Difix3D", 71 | ) -------------------------------------------------------------------------------- /examples/nerfstudio/difix3d/difix3d_datamanager.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from __future__ import annotations 17 | 18 | from dataclasses import dataclass, field 19 | from typing import Dict, Tuple, Type 20 | 21 | from rich.progress import Console 22 | 23 | from nerfstudio.cameras.rays import RayBundle 24 | from nerfstudio.data.utils.dataloaders import CacheDataloader 25 | from nerfstudio.model_components.ray_generators import RayGenerator 26 | from nerfstudio.data.datamanagers.base_datamanager import ( 27 | VanillaDataManager, 28 | VanillaDataManagerConfig, 29 | ) 30 | 31 | CONSOLE = Console(width=120) 32 | 33 | @dataclass 34 | class Difix3DDataManagerConfig(VanillaDataManagerConfig): 35 | """Configuration for the Difix3DDataManager.""" 36 | 37 | _target: Type = field(default_factory=lambda: Difix3DDataManager) 38 | patch_size: int = 32 39 | """Size of patch to sample from. If >1, patch-based sampling will be used.""" 40 | 41 | 42 | class Difix3DDataManager(VanillaDataManager): 43 | """Data manager for Difix3D.""" 44 | 45 | config: Difix3DDataManagerConfig 46 | 47 | def setup_train(self): 48 | """Sets up the data loaders for training""" 49 | assert self.train_dataset is not None 50 | CONSOLE.print("Setting up training dataset...") 51 | self.train_image_dataloader = CacheDataloader( 52 | self.train_dataset, 53 | num_images_to_sample_from=self.config.train_num_images_to_sample_from, 54 | num_times_to_repeat_images=self.config.train_num_times_to_repeat_images, 55 | device=self.device, 56 | num_workers=self.world_size * 4, 57 | pin_memory=True, 58 | collate_fn=self.config.collate_fn, 59 | exclude_batch_keys_from_device=self.exclude_batch_keys_from_device, 60 | ) 61 | self.iter_train_image_dataloader = iter(self.train_image_dataloader) 62 | self.train_pixel_sampler = self._get_pixel_sampler(self.train_dataset, self.config.train_num_rays_per_batch) 63 | self.train_ray_generator = RayGenerator(self.train_dataset.cameras.to(self.device)) 64 | 65 | def next_train(self, step: int) -> Tuple[RayBundle, Dict]: 66 | """Returns the next batch of data from the train dataloader.""" 67 | self.train_count += 1 68 | image_batch = next(self.iter_train_image_dataloader) 69 | assert self.train_pixel_sampler is not None 70 | assert isinstance(image_batch, dict) 71 | batch = self.train_pixel_sampler.sample(image_batch) 72 | ray_indices = batch["indices"] 73 | ray_bundle = self.train_ray_generator(ray_indices) 74 | return ray_bundle, batch 75 | -------------------------------------------------------------------------------- /src/inference_difix.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imageio 3 | import argparse 4 | import numpy as np 5 | from PIL import Image 6 | from glob import glob 7 | from tqdm import tqdm 8 | from model import Difix 9 | 10 | 11 | if __name__ == "__main__": 12 | # Argument parser 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--input_image', type=str, required=True, help='Path to the input image or directory') 15 | parser.add_argument('--ref_image', type=str, default=None, help='Path to the reference image or directory') 16 | parser.add_argument('--height', type=int, default=576, help='Height of the input image') 17 | parser.add_argument('--width', type=int, default=1024, help='Width of the input image') 18 | parser.add_argument('--prompt', type=str, required=True, help='The prompt to be used') 19 | parser.add_argument('--model_name', type=str, default=None, help='Name of the pretrained model to be used') 20 | parser.add_argument('--model_path', type=str, 
default=None, help='Path to a model state dict to be used') 21 | parser.add_argument('--output_dir', type=str, default='output', help='Directory to save the output') 22 | parser.add_argument('--seed', type=int, default=42, help='Random seed to be used') 23 | parser.add_argument('--timestep', type=int, default=199, help='Diffusion timestep') 24 | parser.add_argument('--video', action='store_true', help='If the input is a video') 25 | args = parser.parse_args() 26 | 27 | # Create output directory 28 | os.makedirs(args.output_dir, exist_ok=True) 29 | 30 | # Initialize the model 31 | model = Difix( 32 | pretrained_name=args.model_name, 33 | pretrained_path=args.model_path, 34 | timestep=args.timestep, 35 | mv_unet=True if args.ref_image is not None else False, 36 | ) 37 | model.set_eval() 38 | 39 | # Load input images 40 | if os.path.isdir(args.input_image): 41 | input_images = sorted(glob(os.path.join(args.input_image, "*.png"))) 42 | else: 43 | input_images = [args.input_image] 44 | 45 | # Load reference images if provided 46 | if args.ref_image is not None: 47 | if os.path.isdir(args.ref_image): 48 | ref_images = sorted(glob(os.path.join(args.ref_image, "*"))) 49 | else: 50 | ref_images = [args.ref_image] 51 | 52 | assert len(input_images) == len(ref_images), "Number of input images and reference images should be the same" 53 | 54 | # Process images 55 | output_images = [] 56 | for i, input_image in enumerate(tqdm(input_images, desc="Processing images")): 57 | image = Image.open(input_image).convert('RGB') 58 | ref_image = Image.open(ref_images[i]).convert('RGB') if args.ref_image is not None else None 59 | output_image = model.sample( 60 | image, 61 | height=args.height, 62 | width=args.width, 63 | ref_image=ref_image, 64 | prompt=args.prompt 65 | ) 66 | output_images.append(output_image) 67 | 68 | # Save outputs 69 | if args.video: 70 | # Save as video 71 | video_path = os.path.join(args.output_dir, "output.mp4") 72 | writer = imageio.get_writer(video_path, fps=30) 73 | for output_image in tqdm(output_images, desc="Saving video"): 74 | writer.append_data(np.array(output_image)) 75 | writer.close() 76 | else: 77 | # Save as individual images 78 | for i, output_image in enumerate(tqdm(output_images, desc="Saving images")): 79 | output_image.save(os.path.join(args.output_dir, os.path.basename(input_images[i]))) -------------------------------------------------------------------------------- /examples/nerfstudio/difix3d/difix3d_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import dataclasses 17 | from dataclasses import dataclass, field 18 | from typing import Type, Literal 19 | from nerfstudio.engine.trainer import Trainer, TrainerConfig 20 | from nerfstudio.engine.callbacks import TrainingCallbackAttributes 21 | from nerfstudio.viewer.viewer import Viewer as ViewerState 22 | from nerfstudio.utils import profiler, writer 23 | 24 | 25 | @dataclass 26 | class Difix3DTrainerConfig(TrainerConfig): 27 | """Configuration for the Difix3DTrainer.""" 28 | _target: Type = field(default_factory=lambda: Difix3DTrainer) 29 | 30 | 31 | class Difix3DTrainer(Trainer): 32 | """Trainer for Difix3D""" 33 | 34 | def __init__(self, config: TrainerConfig, local_rank: int = 0, world_size: int = 1) -> None: 35 | 36 | super().__init__(config, local_rank, world_size) 37 | 38 | def setup(self, test_mode: Literal["test", "val", "inference"] = "val") -> None: 39 | """Setup the Trainer by calling other setup functions. 40 | 41 | Args: 42 | test_mode: 43 | 'val': loads train/val datasets into memory 44 | 'test': loads train/test datasets into memory 45 | 'inference': does not load any dataset into memory 46 | """ 47 | self.pipeline = self.config.pipeline.setup( 48 | device=self.device, 49 | test_mode=test_mode, 50 | world_size=self.world_size, 51 | local_rank=self.local_rank, 52 | grad_scaler=self.grad_scaler, 53 | render_dir=self.base_dir / "renders", 54 | ) 55 | self.optimizers = self.setup_optimizers() 56 | 57 | # set up viewer if enabled 58 | viewer_log_path = self.base_dir / self.config.viewer.relative_log_filename 59 | self.viewer_state, banner_messages = None, None 60 | if self.config.is_viewer_legacy_enabled() and self.local_rank == 0: 61 | datapath = self.config.data 62 | if datapath is None: 63 | datapath = self.base_dir 64 | self.viewer_state = ViewerLegacyState( 65 | self.config.viewer, 66 | log_filename=viewer_log_path, 67 | datapath=datapath, 68 | pipeline=self.pipeline, 69 | trainer=self, 70 | train_lock=self.train_lock, 71 | ) 72 | banner_messages = [f"Legacy viewer at: {self.viewer_state.viewer_url}"] 73 | if self.config.is_viewer_enabled() and self.local_rank == 0: 74 | datapath = self.config.data 75 | if datapath is None: 76 | datapath = self.base_dir 77 | self.viewer_state = ViewerState( 78 | self.config.viewer, 79 | log_filename=viewer_log_path, 80 | datapath=datapath, 81 | pipeline=self.pipeline, 82 | trainer=self, 83 | train_lock=self.train_lock, 84 | share=self.config.viewer.make_share_url, 85 | ) 86 | banner_messages = self.viewer_state.viewer_info 87 | self._check_viewer_warnings() 88 | 89 | self._load_checkpoint() 90 | 91 | self.callbacks = self.pipeline.get_training_callbacks( 92 | TrainingCallbackAttributes( 93 | optimizers=self.optimizers, grad_scaler=self.grad_scaler, pipeline=self.pipeline, trainer=self 94 | ) 95 | ) 96 | 97 | # set up writers/profilers if enabled 98 | writer_log_path = self.base_dir / self.config.logging.relative_log_dir 99 | writer.setup_event_writer( 100 | self.config.is_wandb_enabled(), 101 | self.config.is_tensorboard_enabled(), 102 | self.config.is_comet_enabled(), 103 | log_dir=writer_log_path, 104 | experiment_name=self.config.experiment_name, 105 | project_name=self.config.project_name, 106 | ) 107 | writer.setup_local_writer( 108 | self.config.logging, max_iter=self.config.max_num_iterations, banner_messages=banner_messages 109 | ) 110 | writer.put_config(name="config", config_dict=dataclasses.asdict(self.config), step=0) 111 | profiler.setup_profiler(self.config.logging, writer_log_path) 
-------------------------------------------------------------------------------- /examples/nerfstudio/difix3d/difix3d.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from __future__ import annotations 17 | 18 | from dataclasses import dataclass, field 19 | from typing import Type 20 | 21 | import torch 22 | from nerfstudio.models.nerfacto import NerfactoModel, NerfactoModelConfig 23 | from difix3d.difix3d_field import Difix3DField 24 | from nerfstudio.cameras.camera_optimizers import CameraOptimizer 25 | from nerfstudio.field_components.spatial_distortions import SceneContraction 26 | from nerfstudio.fields.density_fields import HashMLPDensityField 27 | 28 | 29 | @dataclass 30 | class Difix3DModelConfig(NerfactoModelConfig): 31 | """Configuration for the Difix3DModel.""" 32 | _target: Type = field(default_factory=lambda: Difix3DModel) 33 | 34 | disable_camera_optimizer: bool = False 35 | """Whether to disable the camera optimizer or not.""" 36 | freeze_appearance_embedding: bool = False 37 | 38 | class Difix3DModel(NerfactoModel): 39 | """Model for Difix3D.""" 40 | 41 | config: Difix3DModelConfig 42 | 43 | def populate_modules(self): 44 | """Set the fields and modules.""" 45 | super().populate_modules() 46 | 47 | if self.config.disable_scene_contraction: 48 | scene_contraction = None 49 | else: 50 | scene_contraction = SceneContraction(order=float("inf")) 51 | 52 | appearance_embedding_dim = self.config.appearance_embed_dim if self.config.use_appearance_embedding else 0 53 | 54 | # Fields 55 | self.field = Difix3DField( 56 | self.scene_box.aabb, 57 | hidden_dim=self.config.hidden_dim, 58 | num_levels=self.config.num_levels, 59 | max_res=self.config.max_res, 60 | base_res=self.config.base_res, 61 | features_per_level=self.config.features_per_level, 62 | log2_hashmap_size=self.config.log2_hashmap_size, 63 | hidden_dim_color=self.config.hidden_dim_color, 64 | hidden_dim_transient=self.config.hidden_dim_transient, 65 | spatial_distortion=scene_contraction, 66 | num_images=self.num_train_data, 67 | use_pred_normals=self.config.predict_normals, 68 | use_average_appearance_embedding=self.config.use_average_appearance_embedding, 69 | appearance_embedding_dim=appearance_embedding_dim, 70 | average_init_density=self.config.average_init_density, 71 | implementation=self.config.implementation, 72 | freeze_appearance_embedding=self.config.freeze_appearance_embedding, 73 | ) 74 | 75 | self.camera_optimizer: CameraOptimizer = self.config.camera_optimizer.setup( 76 | num_cameras=self.num_train_data, device="cpu" 77 | ) 78 | if self.config.disable_camera_optimizer: 79 | self.camera_optimizer.mode = "off" 80 | self.density_fns = [] 81 | num_prop_nets = self.config.num_proposal_iterations 82 | # Build the proposal network(s) 83 | self.proposal_networks = torch.nn.ModuleList() 84 | if 
self.config.use_same_proposal_network: 85 | assert len(self.config.proposal_net_args_list) == 1, "Only one proposal network is allowed." 86 | prop_net_args = self.config.proposal_net_args_list[0] 87 | network = HashMLPDensityField( 88 | self.scene_box.aabb, 89 | spatial_distortion=scene_contraction, 90 | **prop_net_args, 91 | average_init_density=self.config.average_init_density, 92 | implementation=self.config.implementation, 93 | ) 94 | self.proposal_networks.append(network) 95 | self.density_fns.extend([network.density_fn for _ in range(num_prop_nets)]) 96 | else: 97 | for i in range(num_prop_nets): 98 | prop_net_args = self.config.proposal_net_args_list[min(i, len(self.config.proposal_net_args_list) - 1)] 99 | network = HashMLPDensityField( 100 | self.scene_box.aabb, 101 | spatial_distortion=scene_contraction, 102 | **prop_net_args, 103 | average_init_density=self.config.average_init_density, 104 | implementation=self.config.implementation, 105 | ) 106 | self.proposal_networks.append(network) 107 | self.density_fns.extend([network.density_fn for network in self.proposal_networks]) 108 | -------------------------------------------------------------------------------- /examples/gsplat/datasets/normalize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def similarity_from_cameras(c2w, strict_scaling=False, center_method="focus"): 5 | """ 6 | reference: nerf-factory 7 | Get a similarity transform to normalize dataset 8 | from c2w (OpenCV convention) cameras 9 | :param c2w: (N, 4) 10 | :return T (4,4) , scale (float) 11 | """ 12 | t = c2w[:, :3, 3] 13 | R = c2w[:, :3, :3] 14 | 15 | # (1) Rotate the world so that z+ is the up axis 16 | # we estimate the up axis by averaging the camera up axes 17 | ups = np.sum(R * np.array([0, -1.0, 0]), axis=-1) 18 | world_up = np.mean(ups, axis=0) 19 | world_up /= np.linalg.norm(world_up) 20 | 21 | up_camspace = np.array([0.0, -1.0, 0.0]) 22 | c = (up_camspace * world_up).sum() 23 | cross = np.cross(world_up, up_camspace) 24 | skew = np.array( 25 | [ 26 | [0.0, -cross[2], cross[1]], 27 | [cross[2], 0.0, -cross[0]], 28 | [-cross[1], cross[0], 0.0], 29 | ] 30 | ) 31 | if c > -1: 32 | R_align = np.eye(3) + skew + (skew @ skew) * 1 / (1 + c) 33 | else: 34 | # In the unlikely case the original data has y+ up axis, 35 | # rotate 180-deg about x axis 36 | R_align = np.array([[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) 37 | 38 | # R_align = np.eye(3) # DEBUG 39 | R = R_align @ R 40 | fwds = np.sum(R * np.array([0, 0.0, 1.0]), axis=-1) 41 | t = (R_align @ t[..., None])[..., 0] 42 | 43 | # (2) Recenter the scene. 
44 | if center_method == "focus": 45 | # find the closest point to the origin for each camera's center ray 46 | nearest = t + (fwds * -t).sum(-1)[:, None] * fwds 47 | translate = -np.median(nearest, axis=0) 48 | elif center_method == "poses": 49 | # use center of the camera positions 50 | translate = -np.median(t, axis=0) 51 | else: 52 | raise ValueError(f"Unknown center_method {center_method}") 53 | 54 | transform = np.eye(4) 55 | transform[:3, 3] = translate 56 | transform[:3, :3] = R_align 57 | 58 | # (3) Rescale the scene using camera distances 59 | scale_fn = np.max if strict_scaling else np.median 60 | scale = 1.0 / scale_fn(np.linalg.norm(t + translate, axis=-1)) 61 | transform[:3, :] *= scale 62 | 63 | return transform 64 | 65 | 66 | def align_principle_axes(point_cloud): 67 | # Compute centroid 68 | centroid = np.median(point_cloud, axis=0) 69 | 70 | # Translate point cloud to centroid 71 | translated_point_cloud = point_cloud - centroid 72 | 73 | # Compute covariance matrix 74 | covariance_matrix = np.cov(translated_point_cloud, rowvar=False) 75 | 76 | # Compute eigenvectors and eigenvalues 77 | eigenvalues, eigenvectors = np.linalg.eigh(covariance_matrix) 78 | 79 | # Sort eigenvectors by eigenvalues (descending order) so that the z-axis 80 | # is the principal axis with the smallest eigenvalue. 81 | sort_indices = eigenvalues.argsort()[::-1] 82 | eigenvectors = eigenvectors[:, sort_indices] 83 | 84 | # Check orientation of eigenvectors. If the determinant of the eigenvectors is 85 | # negative, then we need to flip the sign of one of the eigenvectors. 86 | if np.linalg.det(eigenvectors) < 0: 87 | eigenvectors[:, 0] *= -1 88 | 89 | # Create rotation matrix 90 | rotation_matrix = eigenvectors.T 91 | 92 | # Create SE(3) matrix (4x4 transformation matrix) 93 | transform = np.eye(4) 94 | transform[:3, :3] = rotation_matrix 95 | transform[:3, 3] = -rotation_matrix @ centroid 96 | 97 | return transform 98 | 99 | 100 | def transform_points(matrix, points): 101 | """Transform points using an SE(3) matrix. 102 | 103 | Args: 104 | matrix: 4x4 SE(3) matrix 105 | points: Nx3 array of points 106 | 107 | Returns: 108 | Nx3 array of transformed points 109 | """ 110 | assert matrix.shape == (4, 4) 111 | assert len(points.shape) == 2 and points.shape[1] == 3 112 | return points @ matrix[:3, :3].T + matrix[:3, 3] 113 | 114 | 115 | def transform_cameras(matrix, camtoworlds): 116 | """Transform cameras using an SE(3) matrix. 
117 | 118 | Args: 119 | matrix: 4x4 SE(3) matrix 120 | camtoworlds: Nx4x4 array of camera-to-world matrices 121 | 122 | Returns: 123 | Nx4x4 array of transformed camera-to-world matrices 124 | """ 125 | assert matrix.shape == (4, 4) 126 | assert len(camtoworlds.shape) == 3 and camtoworlds.shape[1:] == (4, 4) 127 | camtoworlds = np.einsum("nij, ki -> nkj", camtoworlds, matrix) 128 | scaling = np.linalg.norm(camtoworlds[:, 0, :3], axis=1) 129 | camtoworlds[:, :3, :3] = camtoworlds[:, :3, :3] / scaling[:, None, None] 130 | return camtoworlds 131 | 132 | 133 | def normalize(camtoworlds, points=None): 134 | T1 = similarity_from_cameras(camtoworlds) 135 | camtoworlds = transform_cameras(T1, camtoworlds) 136 | if points is not None: 137 | points = transform_points(T1, points) 138 | T2 = align_principle_axes(points) 139 | camtoworlds = transform_cameras(T2, camtoworlds) 140 | points = transform_points(T2, points) 141 | return camtoworlds, points, T2 @ T1 142 | else: 143 | return camtoworlds, T1 144 | -------------------------------------------------------------------------------- /examples/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial.transform import Rotation 3 | 4 | 5 | class CameraPoseInterpolator: 6 | """ 7 | A system for interpolating between sets of camera poses with visualization capabilities. 8 | """ 9 | 10 | def __init__(self, rotation_weight=1.0, translation_weight=1.0): 11 | """ 12 | Initialize the interpolator with weights for pose distance computation. 13 | 14 | Args: 15 | rotation_weight: Weight for rotational distance in pose matching 16 | translation_weight: Weight for translational distance in pose matching 17 | """ 18 | self.rotation_weight = rotation_weight 19 | self.translation_weight = translation_weight 20 | 21 | def compute_pose_distance(self, pose1, pose2): 22 | """ 23 | Compute weighted distance between two camera poses. 24 | 25 | Args: 26 | pose1, pose2: 4x4 transformation matrices 27 | 28 | Returns: 29 | Combined weighted distance between poses 30 | """ 31 | # Translation distance (Euclidean) 32 | t1, t2 = pose1[:3, 3], pose2[:3, 3] 33 | translation_dist = np.linalg.norm(t1 - t2) 34 | 35 | # Rotation distance (angular distance between quaternions) 36 | R1 = Rotation.from_matrix(pose1[:3, :3]) 37 | R2 = Rotation.from_matrix(pose2[:3, :3]) 38 | q1 = R1.as_quat() 39 | q2 = R2.as_quat() 40 | 41 | # Ensure quaternions are in the same hemisphere 42 | if np.dot(q1, q2) < 0: 43 | q2 = -q2 44 | 45 | rotation_dist = np.arccos(2 * np.dot(q1, q2)**2 - 1) 46 | 47 | return (self.translation_weight * translation_dist + 48 | self.rotation_weight * rotation_dist) 49 | 50 | def find_nearest_assignments(self, training_poses, testing_poses): 51 | """ 52 | Find the nearest training camera pose for each testing camera pose. 
53 | 54 | Args: 55 | training_poses: [N, 4, 4] array of training camera poses 56 | testing_poses: [M, 4, 4] array of testing camera poses 57 | 58 | Returns: 59 | assignments: list of closest training pose indices for each testing pose 60 | """ 61 | M = len(testing_poses) 62 | assignments = [] 63 | 64 | for j in range(M): 65 | # Compute distance from each training pose to this testing pose 66 | distances = [self.compute_pose_distance(training_pose, testing_poses[j]) 67 | for training_pose in training_poses] 68 | # Find the index of the nearest training pose 69 | nearest_index = np.argmin(distances) 70 | assignments.append(nearest_index) 71 | 72 | return assignments 73 | 74 | def interpolate_rotation(self, R1, R2, t): 75 | """ 76 | Interpolate between two rotation matrices using SLERP. 77 | """ 78 | q1 = Rotation.from_matrix(R1).as_quat() 79 | q2 = Rotation.from_matrix(R2).as_quat() 80 | 81 | if np.dot(q1, q2) < 0: 82 | q2 = -q2 83 | 84 | # Clamp dot product to avoid invalid values in arccos 85 | dot_product = np.clip(np.dot(q1, q2), -1.0, 1.0) 86 | theta = np.arccos(dot_product) 87 | 88 | if np.abs(theta) < 1e-6: 89 | q_interp = (1 - t) * q1 + t * q2 90 | else: 91 | q_interp = (np.sin((1-t)*theta) * q1 + np.sin(t*theta) * q2) / np.sin(theta) 92 | 93 | q_interp = q_interp / np.linalg.norm(q_interp) 94 | return Rotation.from_quat(q_interp).as_matrix() 95 | 96 | def interpolate_poses(self, training_poses, testing_poses, num_steps=20): 97 | """ 98 | Interpolate between camera poses using nearest assignments. 99 | 100 | Args: 101 | training_poses: [N, 4, 4] array of training poses 102 | testing_poses: [M, 4, 4] array of testing poses 103 | num_steps: number of interpolation steps 104 | 105 | Returns: 106 | interpolated_sequences: list of lists of interpolated poses 107 | """ 108 | assignments = self.find_nearest_assignments(training_poses, testing_poses) 109 | interpolated_sequences = [] 110 | 111 | for test_idx, train_idx in enumerate(assignments): 112 | train_pose = training_poses[train_idx] 113 | test_pose = testing_poses[test_idx] 114 | sequence = [] 115 | 116 | for t in np.linspace(0, 1, num_steps): 117 | # Interpolate rotation 118 | R_interp = self.interpolate_rotation( 119 | train_pose[:3, :3], 120 | test_pose[:3, :3], 121 | t 122 | ) 123 | 124 | # Interpolate translation 125 | t_interp = (1-t) * train_pose[:3, 3] + t * test_pose[:3, 3] 126 | 127 | # Construct interpolated pose 128 | pose_interp = np.eye(4) 129 | pose_interp[:3, :3] = R_interp 130 | pose_interp[:3, 3] = t_interp 131 | 132 | sequence.append(pose_interp) 133 | 134 | interpolated_sequences.append(sequence) 135 | 136 | return interpolated_sequences 137 | 138 | 139 | def shift_poses(self, training_poses, testing_poses, distance=0.1, threshold=0.1): 140 | """ 141 | Shift nearest training poses toward testing poses by a specified distance. 
142 | 143 | Args: 144 | training_poses: [N, 4, 4] array of training camera poses 145 | testing_poses: [M, 4, 4] array of testing camera poses 146 | distance: float, the step size to move training pose toward testing pose 147 | 148 | Returns: 149 | novel_poses: [M, 4, 4] array of shifted poses 150 | """ 151 | assignments = self.find_nearest_assignments(training_poses, testing_poses) 152 | novel_poses = [] 153 | 154 | for test_idx, train_idx in enumerate(assignments): 155 | train_pose = training_poses[train_idx] 156 | test_pose = testing_poses[test_idx] 157 | 158 | if self.compute_pose_distance(train_pose, test_pose) <= distance: 159 | novel_poses.append(test_pose) 160 | continue 161 | 162 | # Calculate translation step if shifting is necessary 163 | t1, t2 = train_pose[:3, 3], test_pose[:3, 3] 164 | translation_direction = t2 - t1 165 | translation_norm = np.linalg.norm(translation_direction) 166 | 167 | if translation_norm > 1e-6: 168 | translation_step = (translation_direction / translation_norm) * distance 169 | new_translation = t1 + translation_step 170 | else: 171 | # If translation direction is too small, use testing pose translation directly 172 | new_translation = t2 173 | 174 | # Check if the new translation would overshoot the testing pose translation 175 | if np.dot(new_translation - t1, t2 - t1) <= 0 or np.linalg.norm(new_translation - t2) <= distance: 176 | new_translation = t2 177 | 178 | # Update rotation 179 | R1 = train_pose[:3, :3] 180 | R2 = test_pose[:3, :3] 181 | if translation_norm > 1e-6: 182 | R_interp = self.interpolate_rotation(R1, R2, min(distance / translation_norm, 1.0)) 183 | else: 184 | R_interp = R2 # Use testing rotation if too close 185 | 186 | # Construct shifted pose 187 | shifted_pose = np.eye(4) 188 | shifted_pose[:3, :3] = R_interp 189 | shifted_pose[:3, 3] = new_translation 190 | 191 | novel_poses.append(shifted_pose) 192 | 193 | return np.array(novel_poses) -------------------------------------------------------------------------------- /examples/gsplat/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | from sklearn.neighbors import NearestNeighbors 6 | from torch import Tensor 7 | import torch.nn.functional as F 8 | import matplotlib.pyplot as plt 9 | from matplotlib import colormaps 10 | 11 | 12 | class CameraOptModule(torch.nn.Module): 13 | """Camera pose optimization module.""" 14 | 15 | def __init__(self, n: int): 16 | super().__init__() 17 | # Delta positions (3D) + Delta rotations (6D) 18 | self.embeds = torch.nn.Embedding(n, 9) 19 | # Identity rotation in 6D representation 20 | self.register_buffer("identity", torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0])) 21 | 22 | def zero_init(self): 23 | torch.nn.init.zeros_(self.embeds.weight) 24 | 25 | def random_init(self, std: float): 26 | torch.nn.init.normal_(self.embeds.weight, std=std) 27 | 28 | def forward(self, camtoworlds: Tensor, embed_ids: Tensor) -> Tensor: 29 | """Adjust camera pose based on deltas. 
30 | 31 | Args: 32 | camtoworlds: (..., 4, 4) 33 | embed_ids: (...,) 34 | 35 | Returns: 36 | updated camtoworlds: (..., 4, 4) 37 | """ 38 | assert camtoworlds.shape[:-2] == embed_ids.shape 39 | batch_shape = camtoworlds.shape[:-2] 40 | pose_deltas = self.embeds(embed_ids) # (..., 9) 41 | dx, drot = pose_deltas[..., :3], pose_deltas[..., 3:] 42 | rot = rotation_6d_to_matrix( 43 | drot + self.identity.expand(*batch_shape, -1) 44 | ) # (..., 3, 3) 45 | transform = torch.eye(4, device=pose_deltas.device).repeat((*batch_shape, 1, 1)) 46 | transform[..., :3, :3] = rot 47 | transform[..., :3, 3] = dx 48 | return torch.matmul(camtoworlds, transform) 49 | 50 | 51 | class AppearanceOptModule(torch.nn.Module): 52 | """Appearance optimization module.""" 53 | 54 | def __init__( 55 | self, 56 | n: int, 57 | feature_dim: int, 58 | embed_dim: int = 16, 59 | sh_degree: int = 3, 60 | mlp_width: int = 64, 61 | mlp_depth: int = 2, 62 | ): 63 | super().__init__() 64 | self.embed_dim = embed_dim 65 | self.sh_degree = sh_degree 66 | self.embeds = torch.nn.Embedding(n, embed_dim) 67 | layers = [] 68 | layers.append( 69 | torch.nn.Linear(embed_dim + feature_dim + (sh_degree + 1) ** 2, mlp_width) 70 | ) 71 | layers.append(torch.nn.ReLU(inplace=True)) 72 | for _ in range(mlp_depth - 1): 73 | layers.append(torch.nn.Linear(mlp_width, mlp_width)) 74 | layers.append(torch.nn.ReLU(inplace=True)) 75 | layers.append(torch.nn.Linear(mlp_width, 3)) 76 | self.color_head = torch.nn.Sequential(*layers) 77 | 78 | def forward( 79 | self, features: Tensor, embed_ids: Tensor, dirs: Tensor, sh_degree: int 80 | ) -> Tensor: 81 | """Adjust appearance based on embeddings. 82 | 83 | Args: 84 | features: (N, feature_dim) 85 | embed_ids: (C,) 86 | dirs: (C, N, 3) 87 | 88 | Returns: 89 | colors: (C, N, 3) 90 | """ 91 | from gsplat.cuda._torch_impl import _eval_sh_bases_fast 92 | 93 | C, N = dirs.shape[:2] 94 | # Camera embeddings 95 | if embed_ids is None: 96 | embeds = torch.zeros(C, self.embed_dim, device=features.device) 97 | else: 98 | embeds = self.embeds(embed_ids) # [C, D2] 99 | embeds = embeds[:, None, :].expand(-1, N, -1) # [C, N, D2] 100 | # GS features 101 | features = features[None, :, :].expand(C, -1, -1) # [C, N, D1] 102 | # View directions 103 | dirs = F.normalize(dirs, dim=-1) # [C, N, 3] 104 | num_bases_to_use = (sh_degree + 1) ** 2 105 | num_bases = (self.sh_degree + 1) ** 2 106 | sh_bases = torch.zeros(C, N, num_bases, device=features.device) # [C, N, K] 107 | sh_bases[:, :, :num_bases_to_use] = _eval_sh_bases_fast(num_bases_to_use, dirs) 108 | # Get colors 109 | if self.embed_dim > 0: 110 | h = torch.cat([embeds, features, sh_bases], dim=-1) # [C, N, D1 + D2 + K] 111 | else: 112 | h = torch.cat([features, sh_bases], dim=-1) 113 | colors = self.color_head(h) 114 | return colors 115 | 116 | 117 | def rotation_6d_to_matrix(d6: Tensor) -> Tensor: 118 | """ 119 | Converts 6D rotation representation by Zhou et al. [1] to rotation matrix 120 | using Gram--Schmidt orthogonalization per Section B of [1]. Adapted from pytorch3d. 121 | Args: 122 | d6: 6D rotation representation, of size (*, 6) 123 | 124 | Returns: 125 | batch of rotation matrices of size (*, 3, 3) 126 | 127 | [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. 128 | On the Continuity of Rotation Representations in Neural Networks. 129 | IEEE Conference on Computer Vision and Pattern Recognition, 2019. 
130 | Retrieved from http://arxiv.org/abs/1812.07035 131 | """ 132 | 133 | a1, a2 = d6[..., :3], d6[..., 3:] 134 | b1 = F.normalize(a1, dim=-1) 135 | b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 136 | b2 = F.normalize(b2, dim=-1) 137 | b3 = torch.cross(b1, b2, dim=-1) 138 | return torch.stack((b1, b2, b3), dim=-2) 139 | 140 | 141 | def knn(x: Tensor, K: int = 4) -> Tensor: 142 | x_np = x.cpu().numpy() 143 | model = NearestNeighbors(n_neighbors=K, metric="euclidean").fit(x_np) 144 | distances, _ = model.kneighbors(x_np) 145 | return torch.from_numpy(distances).to(x) 146 | 147 | 148 | def rgb_to_sh(rgb: Tensor) -> Tensor: 149 | C0 = 0.28209479177387814 150 | return (rgb - 0.5) / C0 151 | 152 | 153 | def set_random_seed(seed: int): 154 | random.seed(seed) 155 | np.random.seed(seed) 156 | torch.manual_seed(seed) 157 | 158 | 159 | # ref: https://github.com/hbb1/2d-gaussian-splatting/blob/main/utils/general_utils.py#L163 160 | def colormap(img, cmap="jet"): 161 | W, H = img.shape[:2] 162 | dpi = 300 163 | fig, ax = plt.subplots(1, figsize=(H / dpi, W / dpi), dpi=dpi) 164 | im = ax.imshow(img, cmap=cmap) 165 | ax.set_axis_off() 166 | fig.colorbar(im, ax=ax) 167 | fig.tight_layout() 168 | fig.canvas.draw() 169 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 170 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 171 | img = torch.from_numpy(data).float().permute(2, 0, 1) 172 | plt.close() 173 | return img 174 | 175 | 176 | def apply_float_colormap(img: torch.Tensor, colormap: str = "turbo") -> torch.Tensor: 177 | """Convert single channel to a color img. 178 | 179 | Args: 180 | img (torch.Tensor): (..., 1) float32 single channel image. 181 | colormap (str): Colormap for img. 182 | 183 | Returns: 184 | (..., 3) colored img with colors in [0, 1]. 185 | """ 186 | img = torch.nan_to_num(img, 0) 187 | if colormap == "gray": 188 | return img.repeat(1, 1, 3) 189 | img_long = (img * 255).long() 190 | img_long_min = torch.min(img_long) 191 | img_long_max = torch.max(img_long) 192 | assert img_long_min >= 0, f"the min value is {img_long_min}" 193 | assert img_long_max <= 255, f"the max value is {img_long_max}" 194 | return torch.tensor( 195 | colormaps[colormap].colors, # type: ignore 196 | device=img.device, 197 | )[img_long[..., 0]] 198 | 199 | 200 | def apply_depth_colormap( 201 | depth: torch.Tensor, 202 | acc: torch.Tensor = None, 203 | near_plane: float = None, 204 | far_plane: float = None, 205 | ) -> torch.Tensor: 206 | """Converts a depth image to color for easier analysis. 207 | 208 | Args: 209 | depth (torch.Tensor): (..., 1) float32 depth. 210 | acc (torch.Tensor | None): (..., 1) optional accumulation mask. 211 | near_plane: Closest depth to consider. If None, use min image value. 212 | far_plane: Furthest depth to consider. If None, use max image value. 213 | 214 | Returns: 215 | (..., 3) colored depth image with colors in [0, 1]. 
216 | """ 217 | near_plane = near_plane or float(torch.min(depth)) 218 | far_plane = far_plane or float(torch.max(depth)) 219 | depth = (depth - near_plane) / (far_plane - near_plane + 1e-10) 220 | depth = torch.clip(depth, 0.0, 1.0) 221 | img = apply_float_colormap(depth, colormap="turbo") 222 | if acc is not None: 223 | img = img * acc + (1.0 - acc) 224 | return img 225 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Difix3D+ 2 | 3 | **Difix3D+: Improving 3D Reconstructions with Single-Step Diffusion Models** 4 | [Jay Zhangjie Wu*](https://zhangjiewu.github.io/), [Yuxuan Zhang*](https://scholar.google.com/citations?user=Jt5VvNgAAAAJ&hl=en), [Haithem Turki](https://haithemturki.com/), [Xuanchi Ren](https://xuanchiren.com/), [Jun Gao](https://www.cs.toronto.edu/~jungao/), 5 | [Mike Zheng Shou](https://sites.google.com/view/showlab/home?authuser=0), [Sanja Fidler](https://www.cs.utoronto.ca/~fidler/), [Zan Gojcic†](https://zgojcic.github.io/), [Huan Ling†](https://www.cs.toronto.edu/~linghuan/) _(*/† equal contribution/advising)_ 6 | CVPR 2025 (Oral) 7 | [Project Page](https://research.nvidia.com/labs/toronto-ai/difix3d/) | [Paper](https://arxiv.org/abs/2503.01774) | [Model](https://huggingface.co/nvidia/difix) | [Demo](https://huggingface.co/spaces/nvidia/difix) 8 | 9 |
10 | 11 |
12 |
13 |
14 | ## News
15 |
16 | * [11/06/2025] Code and models are now available! We will present our work at CVPR 2025 ([oral](https://cvpr.thecvf.com/virtual/2025/oral/35364), [poster](https://cvpr.thecvf.com/virtual/2025/poster/34172)). See you in Nashville🎵!
17 |
18 |
19 | ## Setup
20 |
21 | ```bash
22 | git clone https://github.com/nv-tlabs/Difix3D.git
23 | cd Difix3D
24 | pip install -r requirements.txt
25 | ```
26 |
27 | ## Quickstart (diffusers)
28 |
29 | ```python
30 | from pipeline_difix import DifixPipeline
31 | from diffusers.utils import load_image
32 |
33 | pipe = DifixPipeline.from_pretrained("nvidia/difix", trust_remote_code=True)
34 | pipe.to("cuda")
35 |
36 | input_image = load_image("assets/example_input.png")
37 | prompt = "remove degradation"
38 |
39 | output_image = pipe(prompt, image=input_image, num_inference_steps=1, timesteps=[199], guidance_scale=0.0).images[0]
40 | output_image.save("example_output.png")
41 | ```
42 |
43 | Optionally, you can use a reference image to guide the denoising process.
44 | ```python
45 | from pipeline_difix import DifixPipeline
46 | from diffusers.utils import load_image
47 |
48 | pipe = DifixPipeline.from_pretrained("nvidia/difix_ref", trust_remote_code=True)
49 | pipe.to("cuda")
50 |
51 | input_image = load_image("assets/example_input.png")
52 | ref_image = load_image("assets/example_ref.png")
53 | prompt = "remove degradation"
54 |
55 | output_image = pipe(prompt, image=input_image, ref_image=ref_image, num_inference_steps=1, timesteps=[199], guidance_scale=0.0).images[0]
56 | output_image.save("example_output.png")
57 | ```
58 |
59 | ## Difix: Single-step diffusion for 3D artifact removal
60 |
61 | ### Training
62 |
63 | #### Data Preparation
64 |
65 | Prepare your dataset in the following JSON format:
66 |
67 | ```json
68 | {
69 |     "train": {
70 |         "{data_id}": {
71 |             "image": "{PATH_TO_IMAGE}",
72 |             "target_image": "{PATH_TO_TARGET_IMAGE}",
73 |             "ref_image": "{PATH_TO_REF_IMAGE}",
74 |             "prompt": "remove degradation"
75 |         }
76 |     },
77 |     "test": {
78 |         "{data_id}": {
79 |             "image": "{PATH_TO_IMAGE}",
80 |             "target_image": "{PATH_TO_TARGET_IMAGE}",
81 |             "ref_image": "{PATH_TO_REF_IMAGE}",
82 |             "prompt": "remove degradation"
83 |         }
84 |     }
85 | }
86 | ```
87 |
88 | #### Single GPU
89 |
90 | ```bash
91 | accelerate launch --mixed_precision=bf16 src/train_difix.py \
92 |     --output_dir=./outputs/difix/train \
93 |     --dataset_path="data/data.json" \
94 |     --max_train_steps 10000 \
95 |     --resolution=512 --learning_rate 2e-5 \
96 |     --train_batch_size=1 --dataloader_num_workers 8 \
97 |     --enable_xformers_memory_efficient_attention \
98 |     --checkpointing_steps=1000 --eval_freq 1000 --viz_freq 100 \
99 |     --lambda_lpips 1.0 --lambda_l2 1.0 --lambda_gram 1.0 --gram_loss_warmup_steps 2000 \
100 |     --report_to "wandb" --tracker_project_name "difix" --tracker_run_name "train" --timestep 199
101 | ```
102 |
103 | #### Multiple GPUs
104 |
105 | ```bash
106 | export NUM_NODES=1
107 | export NUM_GPUS=8
108 | accelerate launch --mixed_precision=bf16 --main_process_port 29501 --multi_gpu --num_machines $NUM_NODES --num_processes $NUM_GPUS src/train_difix.py \
109 |     --output_dir=./outputs/difix/train \
110 |     --dataset_path="data/data.json" \
111 |     --max_train_steps 10000 \
112 |     --resolution=512 --learning_rate 2e-5 \
113 |     --train_batch_size=1 --dataloader_num_workers 8 \
114 |     --enable_xformers_memory_efficient_attention \
115 |     --checkpointing_steps=1000 --eval_freq 1000 --viz_freq 100 \
116 |     --lambda_lpips 1.0 --lambda_l2 1.0 --lambda_gram 1.0 --gram_loss_warmup_steps 2000 \
117 |     --report_to "wandb" --tracker_project_name "difix" --tracker_run_name "train" --timestep 199
118 | ```
119 |
120 | ### Inference
121 |
122 | Place the `model_*.pkl` checkpoint in the `checkpoints` directory, then run inference with the following command:
123 |
124 | ```bash
125 | python src/inference_difix.py \
126 |     --model_path "checkpoints/model.pkl" \
127 |     --input_image "assets/example_input.png" \
128 |     --prompt "remove degradation" \
129 |     --output_dir "outputs/difix" \
130 |     --timestep 199
131 | ```
132 |
133 |
134 | ## Difix3D: Progressive 3D update
135 |
136 | ### Data Format
137 |
138 | The data should be organized in the following structure:
139 |
140 | ```
141 | DATA_DIR/
142 | ├── {SCENE_ID}
143 | │   ├── colmap
144 | │   │   ├── sparse
145 | │   │   │   └── 0
146 | │   │   │       ├── cameras.bin
147 | │   │   │       ├── database.db
148 | │   │   │       └── ...
149 | │   ├── images
150 | │   │   ├── image_train_000001.png
151 | │   │   ├── image_train_000002.png
152 | │   │   ├── ...
153 | │   │   ├── image_eval_000200.png
154 | │   │   ├── image_eval_000201.png
155 | │   │   └── ...
156 | │   ├── images_2
157 | │   ├── images_4
158 | │   └── images_8
159 | ```
160 |
161 | ### nerfstudio
162 |
163 | Set up the nerfstudio environment.
164 | ```bash
165 | cd examples/nerfstudio
166 | pip install -e .
167 | cd ../..
168 | ```
169 |
170 | Run Difix3D finetuning with nerfstudio.
171 | ```bash
172 | SCENE_ID=032dee9fb0a8bc1b90871dc5fe950080d0bcd3caf166447f44e60ca50ac04ec7
173 | DATA=DATA_DIR/${SCENE_ID}
174 | DATA_FACTOR=4
175 | CKPT_PATH=CKPT_DIR/${SCENE_ID}/nerfacto/nerfstudio_models/step-000029999.ckpt # Path to the pretrained checkpoint file
176 | OUTPUT_DIR=outputs/difix3d/nerfacto/${SCENE_ID}
177 |
178 | CUDA_VISIBLE_DEVICES=0 ns-train difix3d \
179 |     --data ${DATA} --pipeline.model.appearance-embed-dim 0 --pipeline.model.camera-optimizer.mode off --save_only_latest_checkpoint False --vis viewer \
180 |     --output_dir ${OUTPUT_DIR} --experiment_name ${SCENE_ID} --timestamp '' --load-checkpoint ${CKPT_PATH} \
181 |     --max_num_iterations 30000 --steps_per_eval_all_images 0 --steps_per_eval_batch 0 --steps_per_eval_image 0 --steps_per_save 2000 --viewer.quit-on-train-completion True \
182 |     nerfstudio-data --orientation-method none --center_method none --auto-scale-poses False --downscale_factor ${DATA_FACTOR} --eval_mode filename
183 | ```
184 |
185 | ### gsplat
186 |
187 | Install gsplat following the instructions in the [gsplat repository](https://github.com/nerfstudio-project/gsplat?tab=readme-ov-file#installation).
188 |
189 | Run Difix3D finetuning with gsplat.
190 | ```bash
191 | SCENE_ID=032dee9fb0a8bc1b90871dc5fe950080d0bcd3caf166447f44e60ca50ac04ec7
192 | DATA=DATA_DIR/${SCENE_ID}/gaussian_splat
193 | DATA_FACTOR=4
194 | CKPT_PATH=CKPT_DIR/${SCENE_ID}/ckpts/ckpt_29999_rank0.pt # Path to the pretrained checkpoint file
195 | OUTPUT_DIR=outputs/difix3d/gsplat/${SCENE_ID}
196 |
197 | CUDA_VISIBLE_DEVICES=0 python examples/gsplat/simple_trainer_difix3d.py default \
198 |     --data_dir ${DATA} --data_factor ${DATA_FACTOR} \
199 |     --result_dir ${OUTPUT_DIR} --no-normalize-world-space --test_every 1 --ckpt ${CKPT_PATH}
200 | ```
201 |
202 |
203 | ## Difix3D+: With real-time post-rendering
204 |
205 | Due to the limited capacity of reconstruction methods to represent sharp details, some regions remain blurry. To further enhance the novel views, we use our Difix model as the final post-processing step at render time.
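The same one-step model can be applied to rendered frames directly with the diffusers pipeline from the Quickstart above. The snippet below is a minimal sketch of that post-rendering pass; the `renders/` input directory and the output path are illustrative placeholders.

```python
import os
from glob import glob

from diffusers.utils import load_image
from pipeline_difix import DifixPipeline

# Single-step Difix pipeline, exactly as in the Quickstart.
pipe = DifixPipeline.from_pretrained("nvidia/difix", trust_remote_code=True)
pipe.to("cuda")

os.makedirs("outputs/difix3d+", exist_ok=True)
for frame_path in sorted(glob("renders/*.png")):  # rendered novel views (illustrative path)
    frame = load_image(frame_path)
    fixed = pipe(
        "remove degradation", image=frame,
        num_inference_steps=1, timesteps=[199], guidance_scale=0.0,
    ).images[0]
    fixed.save(os.path.join("outputs/difix3d+", os.path.basename(frame_path)))
```

Equivalently, the provided inference script runs the same step over a directory of renders: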
206 | 207 | ```bash 208 | python src/inference_difix.py \ 209 | --model_path "checkpoints/model.pkl" \ 210 | --input_image "PATH_TO_IMAGES" \ 211 | --prompt "remove degradation" \ 212 | --output_dir "outputs/difix3d+" \ 213 | --timestep 199 214 | ``` 215 | 216 | ## Acknowledgements 217 | 218 | Our work is built upon the following projects: 219 | - [diffusers](https://github.com/huggingface/diffusers) 220 | - [img2img-turbo](https://github.com/GaParmar/img2img-turbo) 221 | - [nerfstudio](https://github.com/nerfstudio-project/nerfstudio) 222 | - [gsplat](https://github.com/nerfstudio-project/gsplat) 223 | - [DL3DV-10K](https://github.com/DL3DV-10K/Dataset) 224 | - [nerfbusters](https://github.com/ethanweber/nerfbusters) 225 | 226 | Shoutout to all the contributors of these projects for their invaluable work that made this research possible. 227 | 228 | ## License/Terms of Use: 229 | 230 | The use of the model and code is governed by the NVIDIA License. See [LICENSE.txt](LICENSE.txt) for details. 231 | Additional Information: [LICENSE.md · stabilityai/sd-turbo at main](https://huggingface.co/stabilityai/sd-turbo/blob/main/LICENSE.md) 232 | 233 | ## Citation 234 | 235 | ```bibtex 236 | @inproceedings{wu2025difix3d+, 237 | title={DIFIX3D+: Improving 3D Reconstructions with Single-Step Diffusion Models}, 238 | author={Wu, Jay Zhangjie and Zhang, Yuxuan and Turki, Haithem and Ren, Xuanchi and Gao, Jun and Shou, Mike Zheng and Fidler, Sanja and Gojcic, Zan and Ling, Huan}, 239 | booktitle={Proceedings of the Computer Vision and Pattern Recognition Conference}, 240 | pages={26024--26035}, 241 | year={2025} 242 | } 243 | ``` -------------------------------------------------------------------------------- /examples/nerfstudio/difix3d/difix3d_pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from dataclasses import dataclass, field 17 | from typing import Optional, Type 18 | from pathlib import Path 19 | from PIL import Image 20 | import os 21 | import tqdm 22 | import random 23 | import numpy as np 24 | import torch 25 | from torch.cuda.amp.grad_scaler import GradScaler 26 | from typing_extensions import Literal 27 | from nerfstudio.cameras.cameras import Cameras 28 | from nerfstudio.pipelines.base_pipeline import VanillaPipeline, VanillaPipelineConfig 29 | from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs 30 | 31 | 32 | from difix3d.difix3d_datamanager import ( 33 | Difix3DDataManagerConfig, 34 | ) 35 | from src.pipeline_difix import DifixPipeline 36 | from examples.utils import CameraPoseInterpolator 37 | 38 | 39 | @dataclass 40 | class Difix3DPipelineConfig(VanillaPipelineConfig): 41 | """Configuration for pipeline instantiation""" 42 | 43 | _target: Type = field(default_factory=lambda: Difix3DPipeline) 44 | """target class to instantiate""" 45 | datamanager: Difix3DDataManagerConfig = Difix3DDataManagerConfig() 46 | """specifies the datamanager config""" 47 | steps_per_fix: int = 2000 48 | """rate at which to fix artifacts""" 49 | steps_per_val: int = 5000 50 | """rate at which to evaluate the model""" 51 | 52 | class Difix3DPipeline(VanillaPipeline): 53 | """Difix3D pipeline""" 54 | 55 | config: Difix3DPipelineConfig 56 | 57 | def __init__( 58 | self, 59 | config: Difix3DPipelineConfig, 60 | device: str, 61 | test_mode: Literal["test", "val", "inference"] = "val", 62 | world_size: int = 1, 63 | local_rank: int = 0, 64 | grad_scaler: Optional[GradScaler] = None, 65 | render_dir: str = "renders", 66 | ): 67 | super().__init__(config, device, test_mode, world_size, local_rank) 68 | 69 | self.render_dir = render_dir 70 | 71 | self.difix = DifixPipeline.from_pretrained("nvidia/difix_ref", trust_remote_code=True) 72 | self.difix.set_progress_bar_config(disable=True) 73 | self.difix.to("cuda") 74 | 75 | self.training_poses = self.datamanager.train_dataparser_outputs.cameras.camera_to_worlds 76 | self.training_poses = torch.cat([self.training_poses, torch.tensor([0, 0, 0, 1]).reshape(1, 1, 4).repeat(self.training_poses.shape[0], 1, 1)], dim=1) 77 | self.testing_poses = self.datamanager.dataparser.get_dataparser_outputs(split=self.datamanager.test_split).cameras.camera_to_worlds 78 | self.testing_poses = torch.cat([self.testing_poses, torch.tensor([0, 0, 0, 1]).reshape(1, 1, 4).repeat(self.testing_poses.shape[0], 1, 1)], dim=1) 79 | self.current_novel_poses = self.training_poses 80 | self.current_novel_cameras = self.datamanager.train_dataparser_outputs.cameras 81 | 82 | self.interpolator = CameraPoseInterpolator(rotation_weight=1.0, translation_weight=1.0) 83 | self.novel_datamanagers = [] 84 | 85 | def get_train_loss_dict(self, step: int): 86 | """This function gets your training loss dict and performs image editing. 
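        When novel-view datamanagers exist (created by `fix()` from Difix-cleaned renders), roughly
        40% of batches are drawn from the most recent one, progressively distilling the fixed views
        into the 3D representation; the remaining batches come from the original training views.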
87 | Args: 88 | step: current iteration step to update sampler if using DDP (distributed) 89 | """ 90 | if len(self.novel_datamanagers) == 0 or random.random() < 0.6: 91 | ray_bundle, batch = self.datamanager.next_train(step) 92 | else: 93 | ray_bundle, batch = self.novel_datamanagers[-1].next_train(step) 94 | 95 | model_outputs = self.model(ray_bundle) 96 | metrics_dict = self.model.get_metrics_dict(model_outputs, batch) 97 | 98 | loss_dict = self.model.get_loss_dict(model_outputs, batch, metrics_dict) 99 | 100 | # run fixer 101 | if (step % self.config.steps_per_fix == 0): 102 | self.fix(step) 103 | 104 | # run evaluation 105 | if (step % self.config.steps_per_val == 0): 106 | self.val(step) 107 | 108 | return model_outputs, loss_dict, metrics_dict 109 | 110 | def forward(self): 111 | """Not implemented since we only want the parameter saving of the nn module, but not forward()""" 112 | raise NotImplementedError 113 | 114 | @torch.no_grad() 115 | def render_traj(self, step, cameras, tag="novel"): 116 | for i in tqdm.trange(0, len(cameras), desc="Rendering trajectory"): 117 | with torch.no_grad(): 118 | outputs = self.model.get_outputs_for_camera(cameras[i]) 119 | 120 | rgb_path = f"{self.render_dir}/{tag}/{step}/Pred/{i:04d}.png" 121 | os.makedirs(os.path.dirname(rgb_path), exist_ok=True) 122 | rgb_canvas = outputs['rgb'].cpu().numpy() 123 | rgb_canvas = (rgb_canvas * 255).astype(np.uint8) 124 | Image.fromarray(rgb_canvas).save(rgb_path) 125 | 126 | @torch.no_grad() 127 | def val(self, step): 128 | cameras = self.datamanager.dataparser.get_dataparser_outputs(split=self.datamanager.test_split).cameras 129 | for i in tqdm.trange(0, len(cameras), desc="Running evaluation"): 130 | with torch.no_grad(): 131 | outputs = self.model.get_outputs_for_camera(cameras[i]) 132 | 133 | rgb_path = f"{self.render_dir}/val/{step}/{i:04d}.png" 134 | os.makedirs(os.path.dirname(rgb_path), exist_ok=True) 135 | rgb_canvas = outputs['rgb'].cpu().numpy() 136 | rgb_canvas = (rgb_canvas * 255).astype(np.uint8) 137 | Image.fromarray(rgb_canvas).save(rgb_path) 138 | 139 | @torch.no_grad() 140 | def fix(self, step: int): 141 | 142 | novel_poses = self.interpolator.shift_poses(self.current_novel_poses.numpy(), self.testing_poses.numpy(), distance=0.5) 143 | novel_poses = torch.from_numpy(novel_poses).to(self.testing_poses.dtype) 144 | 145 | ref_image_indices = self.interpolator.find_nearest_assignments(self.training_poses.numpy(), novel_poses.numpy()) 146 | ref_image_filenames = np.array(self.datamanager.train_dataparser_outputs.image_filenames)[ref_image_indices].tolist() 147 | 148 | cameras = self.datamanager.train_dataparser_outputs.cameras 149 | cameras = Cameras( 150 | fx=cameras.fx[0].repeat(len(novel_poses), 1), 151 | fy=cameras.fy[0].repeat(len(novel_poses), 1), 152 | cx=cameras.cx[0].repeat(len(novel_poses), 1), 153 | cy=cameras.cy[0].repeat(len(novel_poses), 1), 154 | distortion_params=cameras.distortion_params[0].repeat(len(novel_poses), 1), 155 | height=cameras.height[0].repeat(len(novel_poses), 1), 156 | width=cameras.width[0].repeat(len(novel_poses), 1), 157 | camera_to_worlds=novel_poses[:, :3, :4], 158 | camera_type=cameras.camera_type[0].repeat(len(novel_poses), 1), 159 | metadata=cameras.metadata, 160 | ) 161 | 162 | self.render_traj(step, cameras) 163 | 164 | image_filenames = [] 165 | for i in tqdm.trange(0, len(novel_poses), desc="Fixing artifacts..."): 166 | image = Image.open(f"{self.render_dir}/novel/{step}/Pred/{i:04d}.png").convert("RGB") 167 | ref_image = 
Image.open(ref_image_filenames[i]).convert("RGB") 168 | output_image = self.difix(prompt="remove degradation", image=image, ref_image=ref_image, num_inference_steps=1, timesteps=[199], guidance_scale=0.0).images[0] 169 | output_image = output_image.resize(image.size, Image.LANCZOS) 170 | os.makedirs(f"{self.render_dir}/novel/{step}/Fixed", exist_ok=True) 171 | output_image.save(f"{self.render_dir}/novel/{step}/Fixed/{i:04d}.png") 172 | image_filenames.append(Path(f"{self.render_dir}/novel/{step}/Fixed/{i:04d}.png")) 173 | if ref_image is not None: 174 | os.makedirs(f"{self.render_dir}/novel/{step}/Ref", exist_ok=True) 175 | ref_image.save(f"{self.render_dir}/novel/{step}/Ref/{i:04d}.png") 176 | 177 | dataparser_outputs = self.datamanager.train_dataparser_outputs 178 | dataparser_outputs = DataparserOutputs( 179 | image_filenames=image_filenames, 180 | cameras=cameras, 181 | scene_box=dataparser_outputs.scene_box, 182 | mask_filenames=None, 183 | dataparser_scale=dataparser_outputs.dataparser_scale, 184 | dataparser_transform=dataparser_outputs.dataparser_transform, 185 | metadata=dataparser_outputs.metadata, 186 | ) 187 | 188 | datamanager_config = Difix3DDataManagerConfig( 189 | dataparser=self.config.datamanager.dataparser, 190 | train_num_rays_per_batch=16384, 191 | eval_num_rays_per_batch=4096, 192 | ) 193 | 194 | datamanager = datamanager_config.setup( 195 | device=self.datamanager.device, 196 | test_mode=self.datamanager.test_mode, 197 | world_size=self.datamanager.world_size, 198 | local_rank=self.datamanager.local_rank 199 | ) 200 | 201 | datamanager.train_dataparser_outputs = dataparser_outputs 202 | datamanager.train_dataset = datamanager.create_train_dataset() 203 | datamanager.setup_train() 204 | 205 | self.novel_datamanagers.append(datamanager) 206 | self.current_novel_poses = novel_poses 207 | self.current_novel_cameras = cameras -------------------------------------------------------------------------------- /examples/gsplat/datasets/traj.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code borrowed from 3 | 4 | https://github.com/google-research/multinerf/blob/5b4d4f64608ec8077222c52fdf814d40acc10bc1/internal/camera_utils.py 5 | """ 6 | 7 | import numpy as np 8 | import scipy 9 | 10 | 11 | def normalize(x: np.ndarray) -> np.ndarray: 12 | """Normalization helper function.""" 13 | return x / np.linalg.norm(x) 14 | 15 | 16 | def viewmatrix(lookdir: np.ndarray, up: np.ndarray, position: np.ndarray) -> np.ndarray: 17 | """Construct lookat view matrix.""" 18 | vec2 = normalize(lookdir) 19 | vec0 = normalize(np.cross(up, vec2)) 20 | vec1 = normalize(np.cross(vec2, vec0)) 21 | m = np.stack([vec0, vec1, vec2, position], axis=1) 22 | return m 23 | 24 | 25 | def focus_point_fn(poses: np.ndarray) -> np.ndarray: 26 | """Calculate nearest point to all focal axes in poses.""" 27 | directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4] 28 | m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1]) 29 | mt_m = np.transpose(m, [0, 2, 1]) @ m 30 | focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0] 31 | return focus_pt 32 | 33 | 34 | def average_pose(poses: np.ndarray) -> np.ndarray: 35 | """New pose using average position, z-axis, and up vector of input poses.""" 36 | position = poses[:, :3, 3].mean(0) 37 | z_axis = poses[:, :3, 2].mean(0) 38 | up = poses[:, :3, 1].mean(0) 39 | cam2world = viewmatrix(z_axis, up, position) 40 | return cam2world 41 | 42 | 43 | def generate_spiral_path( 44 | poses, 45 | bounds, 46 | 
n_frames=120, 47 | n_rots=2, 48 | zrate=0.5, 49 | spiral_scale_f=1.0, 50 | spiral_scale_r=1.0, 51 | focus_distance=0.75, 52 | ): 53 | """Calculates a forward facing spiral path for rendering.""" 54 | # Find a reasonable 'focus depth' for this dataset as a weighted average 55 | # of conservative near and far bounds in disparity space. 56 | near_bound = bounds.min() 57 | far_bound = bounds.max() 58 | # All cameras will point towards the world space point (0, 0, -focal). 59 | focal = 1 / (((1 - focus_distance) / near_bound + focus_distance / far_bound)) 60 | focal = focal * spiral_scale_f 61 | 62 | # Get radii for spiral path using 90th percentile of camera positions. 63 | positions = poses[:, :3, 3] 64 | radii = np.percentile(np.abs(positions), 90, 0) 65 | radii = radii * spiral_scale_r 66 | radii = np.concatenate([radii, [1.0]]) 67 | 68 | # Generate poses for spiral path. 69 | render_poses = [] 70 | cam2world = average_pose(poses) 71 | up = poses[:, :3, 1].mean(0) 72 | for theta in np.linspace(0.0, 2.0 * np.pi * n_rots, n_frames, endpoint=False): 73 | t = radii * [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0] 74 | position = cam2world @ t 75 | lookat = cam2world @ [0, 0, -focal, 1.0] 76 | z_axis = position - lookat 77 | render_poses.append(viewmatrix(z_axis, up, position)) 78 | render_poses = np.stack(render_poses, axis=0) 79 | return render_poses 80 | 81 | 82 | def generate_ellipse_path_z( 83 | poses: np.ndarray, 84 | n_frames: int = 120, 85 | # const_speed: bool = True, 86 | variation: float = 0.0, 87 | phase: float = 0.0, 88 | height: float = 0.0, 89 | ) -> np.ndarray: 90 | """Generate an elliptical render path based on the given poses.""" 91 | # Calculate the focal point for the path (cameras point toward this). 92 | center = focus_point_fn(poses) 93 | # Path height sits at z=height (in middle of zero-mean capture pattern). 94 | offset = np.array([center[0], center[1], height]) 95 | 96 | # Calculate scaling for ellipse axes based on input camera positions. 97 | sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0) 98 | # Use ellipse that is symmetric about the focal point in xy. 99 | low = -sc + offset 100 | high = sc + offset 101 | # Optional height variation need not be symmetric 102 | z_low = np.percentile((poses[:, :3, 3]), 10, axis=0) 103 | z_high = np.percentile((poses[:, :3, 3]), 90, axis=0) 104 | 105 | def get_positions(theta): 106 | # Interpolate between bounds with trig functions to get ellipse in x-y. 107 | # Optionally also interpolate in z to change camera height along path. 108 | return np.stack( 109 | [ 110 | low[0] + (high - low)[0] * (np.cos(theta) * 0.5 + 0.5), 111 | low[1] + (high - low)[1] * (np.sin(theta) * 0.5 + 0.5), 112 | variation 113 | * ( 114 | z_low[2] 115 | + (z_high - z_low)[2] 116 | * (np.cos(theta + 2 * np.pi * phase) * 0.5 + 0.5) 117 | ) 118 | + height, 119 | ], 120 | -1, 121 | ) 122 | 123 | theta = np.linspace(0, 2.0 * np.pi, n_frames + 1, endpoint=True) 124 | positions = get_positions(theta) 125 | 126 | # if const_speed: 127 | # # Resample theta angles so that the velocity is closer to constant. 128 | # lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1) 129 | # theta = stepfun.sample(None, theta, np.log(lengths), n_frames + 1) 130 | # positions = get_positions(theta) 131 | 132 | # Throw away duplicated last position. 133 | positions = positions[:-1] 134 | 135 | # Set path's up vector to axis closest to average of input pose up vectors. 
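    # (That is, snap "up" to the world axis most aligned with the mean camera up vector, with matching
    # sign, so the rendered path keeps a consistent, roll-free orientation.)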
136 | avg_up = poses[:, :3, 1].mean(0) 137 | avg_up = avg_up / np.linalg.norm(avg_up) 138 | ind_up = np.argmax(np.abs(avg_up)) 139 | up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up]) 140 | 141 | return np.stack([viewmatrix(center - p, up, p) for p in positions]) 142 | 143 | 144 | def generate_ellipse_path_y( 145 | poses: np.ndarray, 146 | n_frames: int = 120, 147 | # const_speed: bool = True, 148 | variation: float = 0.0, 149 | phase: float = 0.0, 150 | height: float = 0.0, 151 | ) -> np.ndarray: 152 | """Generate an elliptical render path based on the given poses.""" 153 | # Calculate the focal point for the path (cameras point toward this). 154 | center = focus_point_fn(poses) 155 | # Path height sits at y=height (in middle of zero-mean capture pattern). 156 | offset = np.array([center[0], height, center[2]]) 157 | 158 | # Calculate scaling for ellipse axes based on input camera positions. 159 | sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0) 160 | # Use ellipse that is symmetric about the focal point in xy. 161 | low = -sc + offset 162 | high = sc + offset 163 | # Optional height variation need not be symmetric 164 | y_low = np.percentile((poses[:, :3, 3]), 10, axis=0) 165 | y_high = np.percentile((poses[:, :3, 3]), 90, axis=0) 166 | 167 | def get_positions(theta): 168 | # Interpolate between bounds with trig functions to get ellipse in x-z. 169 | # Optionally also interpolate in y to change camera height along path. 170 | return np.stack( 171 | [ 172 | low[0] + (high - low)[0] * (np.cos(theta) * 0.5 + 0.5), 173 | variation 174 | * ( 175 | y_low[1] 176 | + (y_high - y_low)[1] 177 | * (np.cos(theta + 2 * np.pi * phase) * 0.5 + 0.5) 178 | ) 179 | + height, 180 | low[2] + (high - low)[2] * (np.sin(theta) * 0.5 + 0.5), 181 | ], 182 | -1, 183 | ) 184 | 185 | theta = np.linspace(0, 2.0 * np.pi, n_frames + 1, endpoint=True) 186 | positions = get_positions(theta) 187 | 188 | # if const_speed: 189 | # # Resample theta angles so that the velocity is closer to constant. 190 | # lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1) 191 | # theta = stepfun.sample(None, theta, np.log(lengths), n_frames + 1) 192 | # positions = get_positions(theta) 193 | 194 | # Throw away duplicated last position. 195 | positions = positions[:-1] 196 | 197 | # Set path's up vector to axis closest to average of input pose up vectors. 198 | avg_up = poses[:, :3, 1].mean(0) 199 | avg_up = avg_up / np.linalg.norm(avg_up) 200 | ind_up = np.argmax(np.abs(avg_up)) 201 | up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up]) 202 | 203 | return np.stack([viewmatrix(p - center, up, p) for p in positions]) 204 | 205 | 206 | def generate_interpolated_path( 207 | poses: np.ndarray, 208 | n_interp: int, 209 | spline_degree: int = 5, 210 | smoothness: float = 0.03, 211 | rot_weight: float = 0.1, 212 | ): 213 | """Creates a smooth spline path between input keyframe camera poses. 214 | 215 | Spline is calculated with poses in format (position, lookat-point, up-point). 216 | 217 | Args: 218 | poses: (n, 3, 4) array of input pose keyframes. 219 | n_interp: returned path will have n_interp * (n - 1) total poses. 220 | spline_degree: polynomial degree of B-spline. 221 | smoothness: parameter for spline smoothing, 0 forces exact interpolation. 222 | rot_weight: relative weighting of rotation/translation in spline solve. 223 | 224 | Returns: 225 | Array of new camera poses with shape (n_interp * (n - 1), 3, 4). 
226 | """ 227 | 228 | def poses_to_points(poses, dist): 229 | """Converts from pose matrices to (position, lookat, up) format.""" 230 | pos = poses[:, :3, -1] 231 | lookat = poses[:, :3, -1] - dist * poses[:, :3, 2] 232 | up = poses[:, :3, -1] + dist * poses[:, :3, 1] 233 | return np.stack([pos, lookat, up], 1) 234 | 235 | def points_to_poses(points): 236 | """Converts from (position, lookat, up) format to pose matrices.""" 237 | return np.array([viewmatrix(p - l, u - p, p) for p, l, u in points]) 238 | 239 | def interp(points, n, k, s): 240 | """Runs multidimensional B-spline interpolation on the input points.""" 241 | sh = points.shape 242 | pts = np.reshape(points, (sh[0], -1)) 243 | k = min(k, sh[0] - 1) 244 | tck, _ = scipy.interpolate.splprep(pts.T, k=k, s=s) 245 | u = np.linspace(0, 1, n, endpoint=False) 246 | new_points = np.array(scipy.interpolate.splev(u, tck)) 247 | new_points = np.reshape(new_points.T, (n, sh[1], sh[2])) 248 | return new_points 249 | 250 | points = poses_to_points(poses, dist=rot_weight) 251 | new_points = interp( 252 | points, n_interp * (points.shape[0] - 1), k=spline_degree, s=smoothness 253 | ) 254 | return points_to_poses(new_points) 255 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | import sys 4 | import numpy as np 5 | from PIL import Image 6 | from tqdm import tqdm 7 | import torch 8 | from torchvision import transforms 9 | from transformers import AutoTokenizer, CLIPTextModel 10 | from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler 11 | from peft import LoraConfig 12 | p = "src/" 13 | sys.path.append(p) 14 | from einops import rearrange, repeat 15 | 16 | 17 | def make_1step_sched(): 18 | noise_scheduler_1step = DDPMScheduler.from_pretrained("stabilityai/sd-turbo", subfolder="scheduler") 19 | noise_scheduler_1step.set_timesteps(1, device="cuda") 20 | noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda() 21 | return noise_scheduler_1step 22 | 23 | 24 | def my_vae_encoder_fwd(self, sample): 25 | sample = self.conv_in(sample) 26 | l_blocks = [] 27 | # down 28 | for down_block in self.down_blocks: 29 | l_blocks.append(sample) 30 | sample = down_block(sample) 31 | # middle 32 | sample = self.mid_block(sample) 33 | sample = self.conv_norm_out(sample) 34 | sample = self.conv_act(sample) 35 | sample = self.conv_out(sample) 36 | self.current_down_blocks = l_blocks 37 | return sample 38 | 39 | 40 | def my_vae_decoder_fwd(self, sample, latent_embeds=None): 41 | sample = self.conv_in(sample) 42 | upscale_dtype = next(iter(self.up_blocks.parameters())).dtype 43 | # middle 44 | sample = self.mid_block(sample, latent_embeds) 45 | sample = sample.to(upscale_dtype) 46 | if not self.ignore_skip: 47 | skip_convs = [self.skip_conv_1, self.skip_conv_2, self.skip_conv_3, self.skip_conv_4] 48 | # up 49 | for idx, up_block in enumerate(self.up_blocks): 50 | skip_in = skip_convs[idx](self.incoming_skip_acts[::-1][idx] * self.gamma) 51 | # add skip 52 | sample = sample + skip_in 53 | sample = up_block(sample, latent_embeds) 54 | else: 55 | for idx, up_block in enumerate(self.up_blocks): 56 | sample = up_block(sample, latent_embeds) 57 | # post-process 58 | if latent_embeds is None: 59 | sample = self.conv_norm_out(sample) 60 | else: 61 | sample = self.conv_norm_out(sample, latent_embeds) 62 | sample = self.conv_act(sample) 63 | sample = self.conv_out(sample) 64 | return 
sample 65 | 66 | 67 | def download_url(url, outf): 68 | if not os.path.exists(outf): 69 | print(f"Downloading checkpoint to {outf}") 70 | response = requests.get(url, stream=True) 71 | total_size_in_bytes = int(response.headers.get('content-length', 0)) 72 | block_size = 1024 # 1 Kibibyte 73 | progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) 74 | with open(outf, 'wb') as file: 75 | for data in response.iter_content(block_size): 76 | progress_bar.update(len(data)) 77 | file.write(data) 78 | progress_bar.close() 79 | if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: 80 | print("ERROR, something went wrong") 81 | print(f"Downloaded successfully to {outf}") 82 | else: 83 | print(f"Skipping download, {outf} already exists") 84 | 85 | 86 | def load_ckpt_from_state_dict(net_difix, optimizer, pretrained_path): 87 | sd = torch.load(pretrained_path, map_location="cpu") 88 | 89 | if "state_dict_vae" in sd: 90 | _sd_vae = net_difix.vae.state_dict() 91 | for k in sd["state_dict_vae"]: 92 | _sd_vae[k] = sd["state_dict_vae"][k] 93 | net_difix.vae.load_state_dict(_sd_vae) 94 | _sd_unet = net_difix.unet.state_dict() 95 | for k in sd["state_dict_unet"]: 96 | _sd_unet[k] = sd["state_dict_unet"][k] 97 | net_difix.unet.load_state_dict(_sd_unet) 98 | 99 | optimizer.load_state_dict(sd["optimizer"]) 100 | 101 | return net_difix, optimizer 102 | 103 | 104 | def save_ckpt(net_difix, optimizer, outf): 105 | sd = {} 106 | sd["vae_lora_target_modules"] = net_difix.target_modules_vae 107 | sd["rank_vae"] = net_difix.lora_rank_vae 108 | sd["state_dict_unet"] = net_difix.unet.state_dict() 109 | sd["state_dict_vae"] = {k: v for k, v in net_difix.vae.state_dict().items() if "lora" in k or "skip" in k} 110 | 111 | sd["optimizer"] = optimizer.state_dict() 112 | 113 | torch.save(sd, outf) 114 | 115 | 116 | class Difix(torch.nn.Module): 117 | def __init__(self, pretrained_name=None, pretrained_path=None, ckpt_folder="checkpoints", lora_rank_vae=4, mv_unet=False, timestep=999): 118 | super().__init__() 119 | self.tokenizer = AutoTokenizer.from_pretrained("stabilityai/sd-turbo", subfolder="tokenizer") 120 | self.text_encoder = CLIPTextModel.from_pretrained("stabilityai/sd-turbo", subfolder="text_encoder").cuda() 121 | self.sched = make_1step_sched() 122 | 123 | vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae") 124 | vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__) 125 | vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__) 126 | # add the skip connection convs 127 | vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda() 128 | vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda() 129 | vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda() 130 | vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda() 131 | vae.decoder.ignore_skip = False 132 | 133 | if mv_unet: 134 | from mv_unet import UNet2DConditionModel 135 | else: 136 | from diffusers import UNet2DConditionModel 137 | 138 | unet = UNet2DConditionModel.from_pretrained("stabilityai/sd-turbo", subfolder="unet") 139 | 140 | if pretrained_path is not None: 141 | sd = torch.load(pretrained_path, map_location="cpu") 142 | vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"]) 
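            # Re-create the VAE LoRA adapter with the rank and target modules recorded in the checkpoint,
            # then load the saved VAE (LoRA/skip) and UNet weights on top of the base sd-turbo weights.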
143 | vae.add_adapter(vae_lora_config, adapter_name="vae_skip") 144 | _sd_vae = vae.state_dict() 145 | for k in sd["state_dict_vae"]: 146 | _sd_vae[k] = sd["state_dict_vae"][k] 147 | vae.load_state_dict(_sd_vae) 148 | _sd_unet = unet.state_dict() 149 | for k in sd["state_dict_unet"]: 150 | _sd_unet[k] = sd["state_dict_unet"][k] 151 | unet.load_state_dict(_sd_unet) 152 | 153 | elif pretrained_name is None and pretrained_path is None: 154 | print("Initializing model with random weights") 155 | target_modules_vae = [] 156 | 157 | torch.nn.init.constant_(vae.decoder.skip_conv_1.weight, 1e-5) 158 | torch.nn.init.constant_(vae.decoder.skip_conv_2.weight, 1e-5) 159 | torch.nn.init.constant_(vae.decoder.skip_conv_3.weight, 1e-5) 160 | torch.nn.init.constant_(vae.decoder.skip_conv_4.weight, 1e-5) 161 | target_modules_vae = ["conv1", "conv2", "conv_in", "conv_shortcut", "conv", "conv_out", 162 | "skip_conv_1", "skip_conv_2", "skip_conv_3", "skip_conv_4", 163 | "to_k", "to_q", "to_v", "to_out.0", 164 | ] 165 | 166 | target_modules = [] 167 | for id, (name, param) in enumerate(vae.named_modules()): 168 | if 'decoder' in name and any(name.endswith(x) for x in target_modules_vae): 169 | target_modules.append(name) 170 | target_modules_vae = target_modules 171 | vae.encoder.requires_grad_(False) 172 | 173 | vae_lora_config = LoraConfig(r=lora_rank_vae, init_lora_weights="gaussian", 174 | target_modules=target_modules_vae) 175 | vae.add_adapter(vae_lora_config, adapter_name="vae_skip") 176 | 177 | self.lora_rank_vae = lora_rank_vae 178 | self.target_modules_vae = target_modules_vae 179 | 180 | # unet.enable_xformers_memory_efficient_attention() 181 | unet.to("cuda") 182 | vae.to("cuda") 183 | 184 | self.unet, self.vae = unet, vae 185 | self.vae.decoder.gamma = 1 186 | self.timesteps = torch.tensor([timestep], device="cuda").long() 187 | self.text_encoder.requires_grad_(False) 188 | 189 | # print number of trainable parameters 190 | print("="*50) 191 | print(f"Number of trainable parameters in UNet: {sum(p.numel() for p in unet.parameters() if p.requires_grad) / 1e6:.2f}M") 192 | print(f"Number of trainable parameters in VAE: {sum(p.numel() for p in vae.parameters() if p.requires_grad) / 1e6:.2f}M") 193 | print("="*50) 194 | 195 | def set_eval(self): 196 | self.unet.eval() 197 | self.vae.eval() 198 | self.unet.requires_grad_(False) 199 | self.vae.requires_grad_(False) 200 | 201 | def set_train(self): 202 | self.unet.train() 203 | self.vae.train() 204 | self.unet.requires_grad_(True) 205 | 206 | for n, _p in self.vae.named_parameters(): 207 | if "lora" in n: 208 | _p.requires_grad = True 209 | self.vae.decoder.skip_conv_1.requires_grad_(True) 210 | self.vae.decoder.skip_conv_2.requires_grad_(True) 211 | self.vae.decoder.skip_conv_3.requires_grad_(True) 212 | self.vae.decoder.skip_conv_4.requires_grad_(True) 213 | 214 | def forward(self, x, timesteps=None, prompt=None, prompt_tokens=None): 215 | # either the prompt or the prompt_tokens should be provided 216 | assert (prompt is None) != (prompt_tokens is None), "Either prompt or prompt_tokens should be provided" 217 | assert (timesteps is None) != (self.timesteps is None), "Either timesteps or self.timesteps should be provided" 218 | 219 | if prompt is not None: 220 | # encode the text prompt 221 | caption_tokens = self.tokenizer(prompt, max_length=self.tokenizer.model_max_length, 222 | padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda() 223 | caption_enc = self.text_encoder(caption_tokens)[0] 224 | else: 225 | caption_enc = 
self.text_encoder(prompt_tokens)[0] 226 | 227 | num_views = x.shape[1] 228 | x = rearrange(x, 'b v c h w -> (b v) c h w') 229 | z = self.vae.encode(x).latent_dist.sample() * self.vae.config.scaling_factor 230 | caption_enc = repeat(caption_enc, 'b n c -> (b v) n c', v=num_views) 231 | 232 | unet_input = z 233 | 234 | model_pred = self.unet(unet_input, self.timesteps, encoder_hidden_states=caption_enc,).sample 235 | z_denoised = self.sched.step(model_pred, self.timesteps, z, return_dict=True).prev_sample 236 | self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks 237 | output_image = (self.vae.decode(z_denoised / self.vae.config.scaling_factor).sample).clamp(-1, 1) 238 | output_image = rearrange(output_image, '(b v) c h w -> b v c h w', v=num_views) 239 | 240 | return output_image 241 | 242 | def sample(self, image, width, height, ref_image=None, timesteps=None, prompt=None, prompt_tokens=None): 243 | input_width, input_height = image.size 244 | new_width = image.width - image.width % 8 245 | new_height = image.height - image.height % 8 246 | image = image.resize((new_width, new_height), Image.LANCZOS) 247 | 248 | T = transforms.Compose([ 249 | transforms.Resize((height, width), interpolation=Image.LANCZOS), 250 | transforms.ToTensor(), 251 | transforms.Normalize([0.5], [0.5]), 252 | ]) 253 | if ref_image is None: 254 | x = T(image).unsqueeze(0).unsqueeze(0).cuda() 255 | else: 256 | ref_image = ref_image.resize((new_width, new_height), Image.LANCZOS) 257 | x = torch.stack([T(image), T(ref_image)], dim=0).unsqueeze(0).cuda() 258 | 259 | output_image = self.forward(x, timesteps, prompt, prompt_tokens)[:, 0] 260 | output_pil = transforms.ToPILImage()(output_image[0].cpu() * 0.5 + 0.5) 261 | output_pil = output_pil.resize((input_width, input_height), Image.LANCZOS) 262 | 263 | return output_pil 264 | 265 | def save_model(self, outf, optimizer): 266 | sd = {} 267 | sd["vae_lora_target_modules"] = self.target_modules_vae 268 | sd["rank_vae"] = self.lora_rank_vae 269 | sd["state_dict_unet"] = {k: v for k, v in self.unet.state_dict().items() if "lora" in k or "conv_in" in k} 270 | sd["state_dict_vae"] = {k: v for k, v in self.vae.state_dict().items() if "lora" in k or "skip" in k} 271 | 272 | sd["optimizer"] = optimizer.state_dict() 273 | 274 | torch.save(sd, outf) 275 | -------------------------------------------------------------------------------- /examples/nerfstudio/difix3d/difix3d_field.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
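# Difix3DField is a compound hash-grid field (density MLP plus colour head, with optional transient,
# semantic and predicted-normal heads). Compared to the standard nerfacto field it exposes a
# `freeze_appearance_embedding` option so the per-image appearance embeddings can be kept fixed
# during fine-tuning (see `get_outputs` below).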
14 | 15 | 16 | from typing import Dict, Literal, Optional, Tuple 17 | 18 | import torch 19 | from torch import Tensor, nn 20 | 21 | from nerfstudio.cameras.rays import RaySamples 22 | from nerfstudio.data.scene_box import SceneBox 23 | from nerfstudio.field_components.activations import trunc_exp 24 | from nerfstudio.field_components.embedding import Embedding 25 | from nerfstudio.field_components.encodings import NeRFEncoding, SHEncoding 26 | from nerfstudio.field_components.field_heads import ( 27 | FieldHeadNames, 28 | PredNormalsFieldHead, 29 | SemanticFieldHead, 30 | TransientDensityFieldHead, 31 | TransientRGBFieldHead, 32 | UncertaintyFieldHead, 33 | ) 34 | from nerfstudio.field_components.mlp import MLP, MLPWithHashEncoding 35 | from nerfstudio.field_components.spatial_distortions import SpatialDistortion 36 | from nerfstudio.fields.base_field import Field, get_normalized_directions 37 | 38 | 39 | class Difix3DField(Field): 40 | """Compound Field 41 | 42 | Args: 43 | aabb: parameters of scene aabb bounds 44 | num_images: number of images in the dataset 45 | num_layers: number of hidden layers 46 | hidden_dim: dimension of hidden layers 47 | geo_feat_dim: output geo feat dimensions 48 | num_levels: number of levels of the hashmap for the base mlp 49 | base_res: base resolution of the hashmap for the base mlp 50 | max_res: maximum resolution of the hashmap for the base mlp 51 | log2_hashmap_size: size of the hashmap for the base mlp 52 | num_layers_color: number of hidden layers for color network 53 | num_layers_transient: number of hidden layers for transient network 54 | features_per_level: number of features per level for the hashgrid 55 | hidden_dim_color: dimension of hidden layers for color network 56 | hidden_dim_transient: dimension of hidden layers for transient network 57 | appearance_embedding_dim: dimension of appearance embedding 58 | transient_embedding_dim: dimension of transient embedding 59 | use_transient_embedding: whether to use transient embedding 60 | use_semantics: whether to use semantic segmentation 61 | num_semantic_classes: number of semantic classes 62 | use_pred_normals: whether to use predicted normals 63 | use_average_appearance_embedding: whether to use average appearance embedding or zeros for inference 64 | spatial_distortion: spatial distortion to apply to the scene 65 | """ 66 | 67 | aabb: Tensor 68 | 69 | def __init__( 70 | self, 71 | aabb: Tensor, 72 | num_images: int, 73 | num_layers: int = 2, 74 | hidden_dim: int = 64, 75 | geo_feat_dim: int = 15, 76 | num_levels: int = 16, 77 | base_res: int = 16, 78 | max_res: int = 2048, 79 | log2_hashmap_size: int = 19, 80 | num_layers_color: int = 3, 81 | num_layers_transient: int = 2, 82 | features_per_level: int = 2, 83 | hidden_dim_color: int = 64, 84 | hidden_dim_transient: int = 64, 85 | appearance_embedding_dim: int = 32, 86 | transient_embedding_dim: int = 16, 87 | use_transient_embedding: bool = False, 88 | use_semantics: bool = False, 89 | num_semantic_classes: int = 100, 90 | pass_semantic_gradients: bool = False, 91 | use_pred_normals: bool = False, 92 | use_average_appearance_embedding: bool = False, 93 | spatial_distortion: Optional[SpatialDistortion] = None, 94 | average_init_density: float = 1.0, 95 | implementation: Literal["tcnn", "torch"] = "tcnn", 96 | freeze_appearance_embedding: bool = False, 97 | ) -> None: 98 | super().__init__() 99 | 100 | self.register_buffer("aabb", aabb) 101 | self.geo_feat_dim = geo_feat_dim 102 | 103 | self.register_buffer("max_res", torch.tensor(max_res)) 104 
| self.register_buffer("num_levels", torch.tensor(num_levels)) 105 | self.register_buffer("log2_hashmap_size", torch.tensor(log2_hashmap_size)) 106 | 107 | self.spatial_distortion = spatial_distortion 108 | self.num_images = num_images 109 | self.appearance_embedding_dim = appearance_embedding_dim 110 | if self.appearance_embedding_dim > 0: 111 | self.embedding_appearance = Embedding(self.num_images, self.appearance_embedding_dim) 112 | else: 113 | self.embedding_appearance = None 114 | self.use_average_appearance_embedding = use_average_appearance_embedding 115 | self.freeze_appearance_embedding = freeze_appearance_embedding 116 | self.use_transient_embedding = use_transient_embedding 117 | self.use_semantics = use_semantics 118 | self.use_pred_normals = use_pred_normals 119 | self.pass_semantic_gradients = pass_semantic_gradients 120 | self.base_res = base_res 121 | self.average_init_density = average_init_density 122 | self.step = 0 123 | 124 | self.direction_encoding = SHEncoding( 125 | levels=4, 126 | implementation=implementation, 127 | ) 128 | 129 | self.position_encoding = NeRFEncoding( 130 | in_dim=3, num_frequencies=2, min_freq_exp=0, max_freq_exp=2 - 1, implementation=implementation 131 | ) 132 | 133 | self.mlp_base = MLPWithHashEncoding( 134 | num_levels=num_levels, 135 | min_res=base_res, 136 | max_res=max_res, 137 | log2_hashmap_size=log2_hashmap_size, 138 | features_per_level=features_per_level, 139 | num_layers=num_layers, 140 | layer_width=hidden_dim, 141 | out_dim=1 + self.geo_feat_dim, 142 | activation=nn.ReLU(), 143 | out_activation=None, 144 | implementation=implementation, 145 | ) 146 | 147 | # transients 148 | if self.use_transient_embedding: 149 | self.transient_embedding_dim = transient_embedding_dim 150 | self.embedding_transient = Embedding(self.num_images, self.transient_embedding_dim) 151 | self.mlp_transient = MLP( 152 | in_dim=self.geo_feat_dim + self.transient_embedding_dim, 153 | num_layers=num_layers_transient, 154 | layer_width=hidden_dim_transient, 155 | out_dim=hidden_dim_transient, 156 | activation=nn.ReLU(), 157 | out_activation=None, 158 | implementation=implementation, 159 | ) 160 | self.field_head_transient_uncertainty = UncertaintyFieldHead(in_dim=self.mlp_transient.get_out_dim()) 161 | self.field_head_transient_rgb = TransientRGBFieldHead(in_dim=self.mlp_transient.get_out_dim()) 162 | self.field_head_transient_density = TransientDensityFieldHead(in_dim=self.mlp_transient.get_out_dim()) 163 | 164 | # semantics 165 | if self.use_semantics: 166 | self.mlp_semantics = MLP( 167 | in_dim=self.geo_feat_dim, 168 | num_layers=2, 169 | layer_width=64, 170 | out_dim=hidden_dim_transient, 171 | activation=nn.ReLU(), 172 | out_activation=None, 173 | implementation=implementation, 174 | ) 175 | self.field_head_semantics = SemanticFieldHead( 176 | in_dim=self.mlp_semantics.get_out_dim(), num_classes=num_semantic_classes 177 | ) 178 | 179 | # predicted normals 180 | if self.use_pred_normals: 181 | self.mlp_pred_normals = MLP( 182 | in_dim=self.geo_feat_dim + self.position_encoding.get_out_dim(), 183 | num_layers=3, 184 | layer_width=64, 185 | out_dim=hidden_dim_transient, 186 | activation=nn.ReLU(), 187 | out_activation=None, 188 | implementation=implementation, 189 | ) 190 | self.field_head_pred_normals = PredNormalsFieldHead(in_dim=self.mlp_pred_normals.get_out_dim()) 191 | 192 | self.mlp_head = MLP( 193 | in_dim=self.direction_encoding.get_out_dim() + self.geo_feat_dim + self.appearance_embedding_dim, 194 | num_layers=num_layers_color, 195 | 
layer_width=hidden_dim_color, 196 | out_dim=3, 197 | activation=nn.ReLU(), 198 | out_activation=nn.Sigmoid(), 199 | implementation=implementation, 200 | ) 201 | 202 | def get_density(self, ray_samples: RaySamples) -> Tuple[Tensor, Tensor]: 203 | """Computes and returns the densities.""" 204 | if self.spatial_distortion is not None: 205 | positions = ray_samples.frustums.get_positions() 206 | positions = self.spatial_distortion(positions) 207 | positions = (positions + 2.0) / 4.0 208 | else: 209 | positions = SceneBox.get_normalized_positions(ray_samples.frustums.get_positions(), self.aabb) 210 | # Make sure the tcnn gets inputs between 0 and 1. 211 | selector = ((positions > 0.0) & (positions < 1.0)).all(dim=-1) 212 | positions = positions * selector[..., None] 213 | 214 | assert positions.numel() > 0, "positions is empty." 215 | 216 | self._sample_locations = positions 217 | if not self._sample_locations.requires_grad: 218 | self._sample_locations.requires_grad = True 219 | positions_flat = positions.view(-1, 3) 220 | 221 | assert positions_flat.numel() > 0, "positions_flat is empty." 222 | h = self.mlp_base(positions_flat).view(*ray_samples.frustums.shape, -1) 223 | density_before_activation, base_mlp_out = torch.split(h, [1, self.geo_feat_dim], dim=-1) 224 | self._density_before_activation = density_before_activation 225 | 226 | # Rectifying the density with an exponential is much more stable than a ReLU or 227 | # softplus, because it enables high post-activation (float32) density outputs 228 | # from smaller internal (float16) parameters. 229 | density = self.average_init_density * trunc_exp(density_before_activation.to(positions)) 230 | density = density * selector[..., None] 231 | return density, base_mlp_out 232 | 233 | def get_outputs( 234 | self, ray_samples: RaySamples, density_embedding: Optional[Tensor] = None 235 | ) -> Dict[FieldHeadNames, Tensor]: 236 | assert density_embedding is not None 237 | outputs = {} 238 | if ray_samples.camera_indices is None: 239 | raise AttributeError("Camera indices are not provided.") 240 | camera_indices = ray_samples.camera_indices.squeeze() 241 | directions = get_normalized_directions(ray_samples.frustums.directions) 242 | directions_flat = directions.view(-1, 3) 243 | d = self.direction_encoding(directions_flat) 244 | 245 | outputs_shape = ray_samples.frustums.directions.shape[:-1] 246 | 247 | # appearance 248 | embedded_appearance = None 249 | if self.embedding_appearance is not None: 250 | if self.training and not self.freeze_appearance_embedding: 251 | embedded_appearance = self.embedding_appearance(camera_indices) 252 | else: 253 | if self.use_average_appearance_embedding: 254 | embedded_appearance = torch.ones( 255 | (*directions.shape[:-1], self.appearance_embedding_dim), device=directions.device 256 | ) * self.embedding_appearance.mean(dim=0) 257 | else: 258 | embedded_appearance = torch.zeros( 259 | (*directions.shape[:-1], self.appearance_embedding_dim), device=directions.device 260 | ) 261 | 262 | # transients 263 | if self.use_transient_embedding and self.training: 264 | embedded_transient = self.embedding_transient(camera_indices) 265 | transient_input = torch.cat( 266 | [ 267 | density_embedding.view(-1, self.geo_feat_dim), 268 | embedded_transient.view(-1, self.transient_embedding_dim), 269 | ], 270 | dim=-1, 271 | ) 272 | x = self.mlp_transient(transient_input).view(*outputs_shape, -1).to(directions) 273 | outputs[FieldHeadNames.UNCERTAINTY] = self.field_head_transient_uncertainty(x) 274 | 
outputs[FieldHeadNames.TRANSIENT_RGB] = self.field_head_transient_rgb(x) 275 | outputs[FieldHeadNames.TRANSIENT_DENSITY] = self.field_head_transient_density(x) 276 | 277 | # semantics 278 | if self.use_semantics: 279 | semantics_input = density_embedding.view(-1, self.geo_feat_dim) 280 | if not self.pass_semantic_gradients: 281 | semantics_input = semantics_input.detach() 282 | 283 | x = self.mlp_semantics(semantics_input).view(*outputs_shape, -1).to(directions) 284 | outputs[FieldHeadNames.SEMANTICS] = self.field_head_semantics(x) 285 | 286 | # predicted normals 287 | if self.use_pred_normals: 288 | positions = ray_samples.frustums.get_positions() 289 | 290 | positions_flat = self.position_encoding(positions.view(-1, 3)) 291 | pred_normals_inp = torch.cat([positions_flat, density_embedding.view(-1, self.geo_feat_dim)], dim=-1) 292 | 293 | x = self.mlp_pred_normals(pred_normals_inp).view(*outputs_shape, -1).to(directions) 294 | outputs[FieldHeadNames.PRED_NORMALS] = self.field_head_pred_normals(x) 295 | 296 | h = torch.cat( 297 | [ 298 | d, 299 | density_embedding.view(-1, self.geo_feat_dim), 300 | ] 301 | + ( 302 | [embedded_appearance.view(-1, self.appearance_embedding_dim)] if embedded_appearance is not None else [] 303 | ), 304 | dim=-1, 305 | ) 306 | rgb = self.mlp_head(h).view(*outputs_shape, -1).to(directions) 307 | outputs.update({FieldHeadNames.RGB: rgb}) 308 | 309 | return outputs 310 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | NVIDIA License 2 | 3 | 1. Definitions 4 | 5 | “Licensor” means any person or entity that distributes its Work. 6 | “Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license. 7 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. 8 | Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license. 9 | 10 | 2. License Grant 11 | 12 | 2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. 13 | 14 | 3. Limitations 15 | 16 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work. 17 | 18 | 3.2 Derivative Works. 
You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself. 19 | 20 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or evaluation purposes only. 21 | 22 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately. 23 | 24 | 3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this license. 25 | 26 | 3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately. 27 | 28 | 4. Disclaimer of Warranty. 29 | 30 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 31 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 32 | 33 | 5. Limitation of Liability. 34 | 35 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 36 | 37 | STABILITY AI COMMUNITY LICENSE AGREEMENT 38 | 39 | Last Updated: July 5, 2024 40 | 41 | INTRODUCTION 42 | This Agreement applies to any individual person or entity (“You”, “Your” or “Licensee”) that uses or distributes any portion or element of the Stability AI Materials or Derivative Works thereof for any Research & Non-Commercial or Commercial purpose. Capitalized terms not otherwise defined herein are defined in Section V below. 43 | 44 | This Agreement is intended to allow research, non-commercial, and limited commercial uses of the Models free of charge. In order to ensure that certain limited commercial uses of the Models continue to be allowed, this Agreement preserves free access to the Models for people or organizations generating annual revenue of less than US $1,000,000 (or local currency equivalent). 
45 | 46 | By clicking “I Accept” or by using or distributing or using any portion or element of the Stability Materials or Derivative Works, You agree that You have read, understood and are bound by the terms of this Agreement. If You are acting on behalf of a company, organization or other entity, then “You” includes you and that entity, and You agree that You: (i) are an authorized representative of such entity with the authority to bind such entity to this Agreement, and (ii) You agree to the terms of this Agreement on that entity’s behalf. 47 | 48 | RESEARCH & NON-COMMERCIAL USE LICENSE 49 | Subject to the terms of this Agreement, Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Research or Non-Commercial Purpose. “Research Purpose” means academic or scientific advancement, and in each case, is not primarily intended for commercial advantage or monetary compensation to You or others. “Non-Commercial Purpose” means any purpose other than a Research Purpose that is not primarily intended for commercial advantage or monetary compensation to You or others, such as personal use (i.e., hobbyist) or evaluation and testing. 50 | 51 | COMMERCIAL USE LICENSE 52 | Subject to the terms of this Agreement (including the remainder of this Section III), Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Commercial Purpose. “Commercial Purpose” means any purpose other than a Research Purpose or Non-Commercial Purpose that is primarily intended for commercial advantage or monetary compensation to You or others, including but not limited to, (i) creating, modifying, or distributing Your product or service, including via a hosted service or application programming interface, and (ii) for Your business’s or organization’s internal operations. If You are using or distributing the Stability AI Materials for a Commercial Purpose, You must register with Stability AI at (https://stability.ai/community-license). If at any time You or Your Affiliate(s), either individually or in aggregate, generate more than USD $1,000,000 in annual revenue (or the equivalent thereof in Your local currency), regardless of whether that revenue is generated directly or indirectly from the Stability AI Materials or Derivative Works, any licenses granted to You under this Agreement shall terminate as of such date. You must request a license from Stability AI at (https://stability.ai/enterprise) , which Stability AI may grant to You in its sole discretion. If you receive Stability AI Materials, or any Derivative Works thereof, from a Licensee as part of an integrated end user product, then Section III of this Agreement will not apply to you. 53 | 54 | GENERAL TERMS 55 | Your Research, Non-Commercial, and Commercial License(s) under this Agreement are subject to the following terms. a. Distribution & Attribution. 
If You distribute or make available the Stability AI Materials or a Derivative Work to a third party, or a product or service that uses any portion of them, You shall: (i) provide a copy of this Agreement to that third party, (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved”, and (iii) prominently display “Powered by Stability AI” on a related website, user interface, blogpost, about page, or product documentation. If You create a Derivative Work, You may add your own attribution notice(s) to the “Notice” text file included with that Derivative Work, provided that You clearly indicate which attributions apply to the Stability AI Materials and state in the “Notice” text file that You changed the Stability AI Materials and how it was modified. b. Use Restrictions. Your use of the Stability AI Materials and Derivative Works, including any output or results of the Stability AI Materials or Derivative Works, must comply with applicable laws and regulations (including Trade Control Laws and equivalent regulations) and adhere to the Documentation and Stability AI’s AUP, which is hereby incorporated by reference. Furthermore, You will not use the Stability AI Materials or Derivative Works, or any output or results of the Stability AI Materials or Derivative Works, to create or improve any foundational generative AI model (excluding the Models or Derivative Works). c. Intellectual Property. (i) Trademark License. No trademark licenses are granted under this Agreement, and in connection with the Stability AI Materials or Derivative Works, You may not use any name or mark owned by or associated with Stability AI or any of its Affiliates, except as required under Section IV(a) herein. (ii) Ownership of Derivative Works. As between You and Stability AI, You are the owner of Derivative Works You create, subject to Stability AI’s ownership of the Stability AI Materials and any Derivative Works made by or for Stability AI. (iii) Ownership of Outputs. As between You and Stability AI, You own any outputs generated from the Models or Derivative Works to the extent permitted by applicable law. (iv) Disputes. If You or Your Affiliate(s) institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Stability AI Materials, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to Your use or distribution of the Stability AI Materials or Derivative Works in violation of this Agreement. (v) Feedback. From time to time, You may provide Stability AI with verbal and/or written suggestions, comments or other feedback related to Stability AI’s existing or prospective technology, products or services (collectively, “Feedback”). 
You are not obligated to provide Stability AI with Feedback, but to the extent that You do, You hereby grant Stability AI a perpetual, irrevocable, royalty-free, fully-paid, sub-licensable, transferable, non-exclusive, worldwide right and license to exploit the Feedback in any manner without restriction. Your Feedback is provided “AS IS” and You make no warranties whatsoever about any Feedback. d. Disclaimer Of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE STABILITY AI MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OR LAWFULNESS OF USING OR REDISTRIBUTING THE STABILITY AI MATERIALS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE STABILITY AI MATERIALS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS. e. Limitation Of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. f. Term And Termination. The term of this Agreement will commence upon Your acceptance of this Agreement or access to the Stability AI Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, You shall delete and cease use of any Stability AI Materials or Derivative Works. Section IV(d), (e), and (g) shall survive the termination of this Agreement. g. Governing Law. This Agreement will be governed by and constructed in accordance with the laws of the United States and the State of California without regard to choice of law principles, and the UN Convention on Contracts for International Sale of Goods does not apply to this Agreement. 56 | 57 | DEFINITIONS 58 | “Affiliate(s)” means any entity that directly or indirectly controls, is controlled by, or is under common control with the subject entity; for purposes of this definition, “control” means direct or indirect ownership or control of more than 50% of the voting interests of the subject entity. 59 | 60 | "Agreement" means this Stability AI Community License Agreement. 61 | 62 | “AUP” means the Stability AI Acceptable Use Policy available at (https://stability.ai/use-policy), as may be updated from time to time. 63 | 64 | "Derivative Work(s)” means (a) any derivative work of the Stability AI Materials as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output, including “fine tune” and “low-rank adaptation” models derived from a Model or a Model’s output, but do not include the output of any Model. 65 | 66 | “Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software or Models. 
67 | 68 | “Model(s)" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing listed on Stability’s Core Models Webpage available at (https://stability.ai/core-models), as may be updated from time to time. 69 | 70 | "Stability AI" or "we" means Stability AI Ltd. and its Affiliates. 71 | 72 | "Software" means Stability AI’s proprietary software made available under this Agreement now or in the future. 73 | 74 | “Stability AI Materials” means, collectively, Stability’s proprietary Models, Software and Documentation (and any portion or combination thereof) made available under this Agreement. 75 | 76 | “Trade Control Laws” means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations. -------------------------------------------------------------------------------- /src/train_difix.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import lpips 4 | import random 5 | import argparse 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.utils.checkpoint 10 | import torchvision 11 | import transformers 12 | from torchvision.transforms.functional import crop 13 | from accelerate import Accelerator 14 | from accelerate.utils import set_seed 15 | from PIL import Image 16 | from torchvision import transforms 17 | from tqdm.auto import tqdm 18 | from glob import glob 19 | from einops import rearrange 20 | 21 | import diffusers 22 | from diffusers.utils.import_utils import is_xformers_available 23 | from diffusers.optimization import get_scheduler 24 | 25 | import wandb 26 | 27 | from model import Difix, load_ckpt_from_state_dict, save_ckpt 28 | from dataset import PairedDataset 29 | from loss import gram_loss 30 | 31 | 32 | def main(args): 33 | accelerator = Accelerator( 34 | gradient_accumulation_steps=args.gradient_accumulation_steps, 35 | mixed_precision=args.mixed_precision, 36 | log_with=args.report_to, 37 | ) 38 | 39 | if accelerator.is_local_main_process: 40 | transformers.utils.logging.set_verbosity_warning() 41 | diffusers.utils.logging.set_verbosity_info() 42 | else: 43 | transformers.utils.logging.set_verbosity_error() 44 | diffusers.utils.logging.set_verbosity_error() 45 | 46 | if args.seed is not None: 47 | set_seed(args.seed) 48 | 49 | if accelerator.is_main_process: 50 | os.makedirs(os.path.join(args.output_dir, "checkpoints"), exist_ok=True) 51 | os.makedirs(os.path.join(args.output_dir, "eval"), exist_ok=True) 52 | 53 | net_difix = Difix( 54 | lora_rank_vae=args.lora_rank_vae, 55 | timestep=args.timestep, 56 | mv_unet=args.mv_unet, 57 | ) 58 | net_difix.set_train() 59 | 60 | if args.enable_xformers_memory_efficient_attention: 61 | if is_xformers_available(): 62 | net_difix.unet.enable_xformers_memory_efficient_attention() 63 | else: 64 | raise ValueError("xformers is not available, please install it by running `pip install xformers`") 65 | 66 | if args.gradient_checkpointing: 67 | net_difix.unet.enable_gradient_checkpointing() 68 | 69 | if args.allow_tf32: 70 | torch.backends.cuda.matmul.allow_tf32 = True 71 | 72 | net_lpips = lpips.LPIPS(net='vgg').cuda() 73 | 74 | net_lpips.requires_grad_(False) 75 | 76 | net_vgg = torchvision.models.vgg16(pretrained=True).features 77 | for param in net_vgg.parameters(): 78 | param.requires_grad_(False) 79 | 80 | # make the optimizer 81 | layers_to_opt = [] 82 | layers_to_opt += 
list(net_difix.unet.parameters()) 83 | 84 | for n, _p in net_difix.vae.named_parameters(): 85 | if "lora" in n and "vae_skip" in n: 86 | assert _p.requires_grad 87 | layers_to_opt.append(_p) 88 | layers_to_opt = layers_to_opt + list(net_difix.vae.decoder.skip_conv_1.parameters()) + \ 89 | list(net_difix.vae.decoder.skip_conv_2.parameters()) + \ 90 | list(net_difix.vae.decoder.skip_conv_3.parameters()) + \ 91 | list(net_difix.vae.decoder.skip_conv_4.parameters()) 92 | 93 | optimizer = torch.optim.AdamW(layers_to_opt, lr=args.learning_rate, 94 | betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, 95 | eps=args.adam_epsilon,) 96 | lr_scheduler = get_scheduler(args.lr_scheduler, optimizer=optimizer, 97 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, 98 | num_training_steps=args.max_train_steps * accelerator.num_processes, 99 | num_cycles=args.lr_num_cycles, power=args.lr_power,) 100 | 101 | dataset_train = PairedDataset(dataset_path=args.dataset_path, split="train", tokenizer=net_difix.tokenizer) 102 | dl_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers) 103 | dataset_val = PairedDataset(dataset_path=args.dataset_path, split="test", tokenizer=net_difix.tokenizer) 104 | random.Random(42).shuffle(dataset_val.img_ids) 105 | dl_val = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=0) 106 | 107 | # Resume from checkpoint 108 | global_step = 0 109 | if args.resume is not None: 110 | if os.path.isdir(args.resume): 111 | # Resume from last ckpt 112 | ckpt_files = glob(os.path.join(args.resume, "*.pkl")) 113 | assert len(ckpt_files) > 0, f"No checkpoint files found: {args.resume}" 114 | ckpt_files = sorted(ckpt_files, key=lambda x: int(x.split("/")[-1].replace("model_", "").replace(".pkl", ""))) 115 | print("="*50); print(f"Loading checkpoint from {ckpt_files[-1]}"); print("="*50) 116 | global_step = int(ckpt_files[-1].split("/")[-1].replace("model_", "").replace(".pkl", "")) 117 | net_difix, optimizer = load_ckpt_from_state_dict( 118 | net_difix, optimizer, ckpt_files[-1] 119 | ) 120 | elif args.resume.endswith(".pkl"): 121 | print("="*50); print(f"Loading checkpoint from {args.resume}"); print("="*50) 122 | global_step = int(args.resume.split("/")[-1].replace("model_", "").replace(".pkl", "")) 123 | net_difix, optimizer = load_ckpt_from_state_dict( 124 | net_difix, optimizer, args.resume 125 | ) 126 | else: 127 | raise NotImplementedError(f"Invalid resume path: {args.resume}") 128 | else: 129 | print("="*50); print("Training from scratch"); print("="*50) 130 | 131 | weight_dtype = torch.float32 132 | if accelerator.mixed_precision == "fp16": 133 | weight_dtype = torch.float16 134 | elif accelerator.mixed_precision == "bf16": 135 | weight_dtype = torch.bfloat16 136 | 137 | # Move all networks to device and cast to weight_dtype 138 | net_difix.to(accelerator.device, dtype=weight_dtype) 139 | net_lpips.to(accelerator.device, dtype=weight_dtype) 140 | net_vgg.to(accelerator.device, dtype=weight_dtype) 141 | 142 | # Prepare everything with our `accelerator`. 
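# accelerator.prepare() wraps the model, optimizer, dataloader and LR scheduler in one call: it handles device placement, distributed (DDP) gradient synchronization, per-process dataloader sharding and mixed-precision casting.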
143 | net_difix, optimizer, dl_train, lr_scheduler = accelerator.prepare( 144 | net_difix, optimizer, dl_train, lr_scheduler 145 | ) 146 | net_lpips, net_vgg = accelerator.prepare(net_lpips, net_vgg) 147 | # renorm with image net statistics 148 | t_vgg_renorm = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 149 | 150 | # We need to initialize the trackers we use, and also store our configuration. 151 | # The trackers initializes automatically on the main process. 152 | if accelerator.is_main_process: 153 | init_kwargs = { 154 | "wandb": { 155 | "name": args.tracker_run_name, 156 | "dir": args.output_dir, 157 | }, 158 | } 159 | tracker_config = dict(vars(args)) 160 | accelerator.init_trackers(args.tracker_project_name, config=tracker_config, init_kwargs=init_kwargs) 161 | 162 | progress_bar = tqdm(range(0, args.max_train_steps), initial=global_step, desc="Steps", 163 | disable=not accelerator.is_local_main_process,) 164 | 165 | # start the training loop 166 | for epoch in range(0, args.num_training_epochs): 167 | for step, batch in enumerate(dl_train): 168 | l_acc = [net_difix] 169 | with accelerator.accumulate(*l_acc): 170 | x_src = batch["conditioning_pixel_values"] 171 | x_tgt = batch["output_pixel_values"] 172 | B, V, C, H, W = x_src.shape 173 | 174 | # forward pass 175 | x_tgt_pred = net_difix(x_src, prompt_tokens=batch["input_ids"]) 176 | 177 | x_tgt = rearrange(x_tgt, 'b v c h w -> (b v) c h w') 178 | x_tgt_pred = rearrange(x_tgt_pred, 'b v c h w -> (b v) c h w') 179 | 180 | # Reconstruction loss 181 | loss_l2 = F.mse_loss(x_tgt_pred.float(), x_tgt.float(), reduction="mean") * args.lambda_l2 182 | loss_lpips = net_lpips(x_tgt_pred.float(), x_tgt.float()).mean() * args.lambda_lpips 183 | loss = loss_l2 + loss_lpips 184 | 185 | # Gram matrix loss 186 | if args.lambda_gram > 0: 187 | if global_step > args.gram_loss_warmup_steps: 188 | x_tgt_pred_renorm = t_vgg_renorm(x_tgt_pred * 0.5 + 0.5) 189 | crop_h, crop_w = 400, 400 190 | top, left = random.randint(0, H - crop_h), random.randint(0, W - crop_w) 191 | x_tgt_pred_renorm = crop(x_tgt_pred_renorm, top, left, crop_h, crop_w) 192 | 193 | x_tgt_renorm = t_vgg_renorm(x_tgt * 0.5 + 0.5) 194 | x_tgt_renorm = crop(x_tgt_renorm, top, left, crop_h, crop_w) 195 | 196 | loss_gram = gram_loss(x_tgt_pred_renorm.to(weight_dtype), x_tgt_renorm.to(weight_dtype), net_vgg) * args.lambda_gram 197 | loss += loss_gram 198 | else: 199 | loss_gram = torch.tensor(0.0).to(weight_dtype) 200 | 201 | accelerator.backward(loss, retain_graph=False) 202 | if accelerator.sync_gradients: 203 | accelerator.clip_grad_norm_(layers_to_opt, args.max_grad_norm) 204 | optimizer.step() 205 | lr_scheduler.step() 206 | optimizer.zero_grad(set_to_none=args.set_grads_to_none) 207 | 208 | x_tgt = rearrange(x_tgt, '(b v) c h w -> b v c h w', v=V) 209 | x_tgt_pred = rearrange(x_tgt_pred, '(b v) c h w -> b v c h w', v=V) 210 | 211 | # Checks if the accelerator has performed an optimization step behind the scenes 212 | if accelerator.sync_gradients: 213 | progress_bar.update(1) 214 | global_step += 1 215 | 216 | if accelerator.is_main_process: 217 | logs = {} 218 | # log all the losses 219 | logs["loss_l2"] = loss_l2.detach().item() 220 | logs["loss_lpips"] = loss_lpips.detach().item() 221 | if args.lambda_gram > 0: 222 | logs["loss_gram"] = loss_gram.detach().item() 223 | progress_bar.set_postfix(**logs) 224 | 225 | # viz some images 226 | if global_step % args.viz_freq == 1: 227 | log_dict = { 228 | "train/source": [wandb.Image(rearrange(x_src, "b v c h w -> b c 
(v h) w")[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(B)], 229 | "train/target": [wandb.Image(rearrange(x_tgt, "b v c h w -> b c (v h) w")[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(B)], 230 | "train/model_output": [wandb.Image(rearrange(x_tgt_pred, "b v c h w -> b c (v h) w")[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(B)], 231 | } 232 | for k in log_dict: 233 | logs[k] = log_dict[k] 234 | 235 | # checkpoint the model 236 | if global_step % args.checkpointing_steps == 1: 237 | outf = os.path.join(args.output_dir, "checkpoints", f"model_{global_step}.pkl") 238 | # accelerator.unwrap_model(net_difix).save_model(outf) 239 | save_ckpt(accelerator.unwrap_model(net_difix), optimizer, outf) 240 | 241 | # compute validation set L2, LPIPS 242 | if args.eval_freq > 0 and global_step % args.eval_freq == 1: 243 | l_l2, l_lpips = [], [] 244 | log_dict = {"sample/source": [], "sample/target": [], "sample/model_output": []} 245 | for step, batch_val in enumerate(dl_val): 246 | if step >= args.num_samples_eval: 247 | break 248 | x_src = batch_val["conditioning_pixel_values"].to(accelerator.device, dtype=weight_dtype) 249 | x_tgt = batch_val["output_pixel_values"].to(accelerator.device, dtype=weight_dtype) 250 | B, V, C, H, W = x_src.shape 251 | assert B == 1, "Use batch size 1 for eval." 252 | with torch.no_grad(): 253 | # forward pass 254 | x_tgt_pred = accelerator.unwrap_model(net_difix)(x_src, prompt_tokens=batch_val["input_ids"].cuda()) 255 | 256 | if step % 10 == 0: 257 | log_dict["sample/source"].append(wandb.Image(rearrange(x_src, "b v c h w -> b c (v h) w")[0].float().detach().cpu(), caption=f"idx={len(log_dict['sample/source'])}")) 258 | log_dict["sample/target"].append(wandb.Image(rearrange(x_tgt, "b v c h w -> b c (v h) w")[0].float().detach().cpu(), caption=f"idx={len(log_dict['sample/source'])}")) 259 | log_dict["sample/model_output"].append(wandb.Image(rearrange(x_tgt_pred, "b v c h w -> b c (v h) w")[0].float().detach().cpu(), caption=f"idx={len(log_dict['sample/source'])}")) 260 | 261 | x_tgt = x_tgt[:, 0] # take the input view 262 | x_tgt_pred = x_tgt_pred[:, 0] # take the input view 263 | # compute the reconstruction losses 264 | loss_l2 = F.mse_loss(x_tgt_pred.float(), x_tgt.float(), reduction="mean") 265 | loss_lpips = net_lpips(x_tgt_pred.float(), x_tgt.float()).mean() 266 | 267 | l_l2.append(loss_l2.item()) 268 | l_lpips.append(loss_lpips.item()) 269 | 270 | logs["val/l2"] = np.mean(l_l2) 271 | logs["val/lpips"] = np.mean(l_lpips) 272 | for k in log_dict: 273 | logs[k] = log_dict[k] 274 | gc.collect() 275 | torch.cuda.empty_cache() 276 | accelerator.log(logs, step=global_step) 277 | 278 | 279 | if __name__ == "__main__": 280 | 281 | parser = argparse.ArgumentParser() 282 | # args for the loss function 283 | parser.add_argument("--lambda_lpips", default=1.0, type=float) 284 | parser.add_argument("--lambda_l2", default=1.0, type=float) 285 | parser.add_argument("--lambda_gram", default=1.0, type=float) 286 | parser.add_argument("--gram_loss_warmup_steps", default=2000, type=int) 287 | 288 | # dataset options 289 | parser.add_argument("--dataset_path", required=True, type=str) 290 | parser.add_argument("--train_image_prep", default="resized_crop_512", type=str) 291 | parser.add_argument("--test_image_prep", default="resized_crop_512", type=str) 292 | parser.add_argument("--prompt", default=None, type=str) 293 | 294 | # validation eval args 295 | parser.add_argument("--eval_freq", default=100, type=int) 296 | 
parser.add_argument("--num_samples_eval", type=int, default=100, help="Number of samples to use for all evaluation") 297 | 298 | parser.add_argument("--viz_freq", type=int, default=100, help="Frequency of visualizing the outputs.") 299 | parser.add_argument("--tracker_project_name", type=str, default="difix", help="The name of the wandb project to log to.") 300 | parser.add_argument("--tracker_run_name", type=str, required=True) 301 | 302 | # details about the model architecture 303 | parser.add_argument("--pretrained_model_name_or_path") 304 | parser.add_argument("--revision", type=str, default=None,) 305 | parser.add_argument("--variant", type=str, default=None,) 306 | parser.add_argument("--tokenizer_name", type=str, default=None) 307 | parser.add_argument("--lora_rank_vae", default=4, type=int) 308 | parser.add_argument("--timestep", default=199, type=int) 309 | parser.add_argument("--mv_unet", action="store_true") 310 | 311 | # training details 312 | parser.add_argument("--output_dir", required=True) 313 | parser.add_argument("--cache_dir", default=None,) 314 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 315 | parser.add_argument("--resolution", type=int, default=512,) 316 | parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader.") 317 | parser.add_argument("--num_training_epochs", type=int, default=10) 318 | parser.add_argument("--max_train_steps", type=int, default=10_000,) 319 | parser.add_argument("--checkpointing_steps", type=int, default=500,) 320 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.",) 321 | parser.add_argument("--gradient_checkpointing", action="store_true",) 322 | parser.add_argument("--learning_rate", type=float, default=5e-6) 323 | parser.add_argument("--lr_scheduler", type=str, default="constant", 324 | help=( 325 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 326 | ' "constant", "constant_with_warmup"]' 327 | ), 328 | ) 329 | parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.") 330 | parser.add_argument("--lr_num_cycles", type=int, default=1, 331 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.", 332 | ) 333 | parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") 334 | 335 | parser.add_argument("--dataloader_num_workers", type=int, default=0,) 336 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 337 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 338 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 339 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 340 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 341 | parser.add_argument("--allow_tf32", action="store_true", 342 | help=( 343 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see" 344 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 345 | ), 346 | ) 347 | parser.add_argument("--report_to", type=str, default="wandb", 348 | help=( 349 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 350 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 351 | ), 352 | ) 353 | parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"],) 354 | parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.") 355 | parser.add_argument("--set_grads_to_none", action="store_true",) 356 | 357 | # resume 358 | parser.add_argument("--resume", default=None, type=str) 359 | 360 | args = parser.parse_args() 361 | 362 | main(args) 363 | -------------------------------------------------------------------------------- /examples/gsplat/datasets/colmap.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from typing import Any, Dict, List, Optional 4 | from typing_extensions import assert_never 5 | 6 | import cv2 7 | import imageio.v2 as imageio 8 | import numpy as np 9 | import torch 10 | from pycolmap import SceneManager 11 | 12 | from .normalize import ( 13 | align_principle_axes, 14 | similarity_from_cameras, 15 | transform_cameras, 16 | transform_points, 17 | ) 18 | 19 | 20 | def _get_rel_paths(path_dir: str) -> List[str]: 21 | """Recursively get relative paths of files in a directory.""" 22 | paths = [] 23 | for dp, dn, fn in os.walk(path_dir): 24 | for f in fn: 25 | paths.append(os.path.relpath(os.path.join(dp, f), path_dir)) 26 | return paths 27 | 28 | 29 | class Parser: 30 | """COLMAP parser.""" 31 | 32 | def __init__( 33 | self, 34 | data_dir: str, 35 | factor: int = 1, 36 | normalize: bool = False, 37 | test_every: int = 8, 38 | ): 39 | self.data_dir = data_dir 40 | self.factor = factor 41 | self.normalize = normalize 42 | self.test_every = test_every 43 | 44 | colmap_dir = os.path.join(data_dir, "sparse/0/") 45 | if not os.path.exists(colmap_dir): 46 | # colmap_dir = os.path.join(data_dir, "sparse") 47 | colmap_dir = os.path.join(data_dir, "colmap/sparse/0") 48 | assert os.path.exists( 49 | colmap_dir 50 | ), f"COLMAP directory {colmap_dir} does not exist." 51 | 52 | manager = SceneManager(colmap_dir) 53 | manager.load_cameras() 54 | manager.load_images() 55 | manager.load_points3D() 56 | 57 | # Extract extrinsic matrices in world-to-camera format. 58 | imdata = manager.images 59 | w2c_mats = [] 60 | camera_ids = [] 61 | Ks_dict = dict() 62 | params_dict = dict() 63 | imsize_dict = dict() # width, height 64 | mask_dict = dict() 65 | bottom = np.array([0, 0, 0, 1]).reshape(1, 4) 66 | for k in imdata: 67 | im = imdata[k] 68 | rot = im.R() 69 | trans = im.tvec.reshape(3, 1) 70 | w2c = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0) 71 | w2c_mats.append(w2c) 72 | 73 | # support different camera intrinsics 74 | camera_id = im.camera_id 75 | camera_ids.append(camera_id) 76 | 77 | # camera intrinsics 78 | cam = manager.cameras[camera_id] 79 | fx, fy, cx, cy = cam.fx, cam.fy, cam.cx, cam.cy 80 | K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) 81 | K[:2, :] /= factor 82 | Ks_dict[camera_id] = K 83 | 84 | # Get distortion parameters. 
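# COLMAP stores the camera model as an integer code (0=SIMPLE_PINHOLE, 1=PINHOLE, 2=SIMPLE_RADIAL, 3=RADIAL, 4=OPENCV, 5=OPENCV_FISHEYE); the branches below collect each model's distortion coefficients into a flat `params` array, which stays empty for the pinhole models that need no undistortion.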
85 | type_ = cam.camera_type 86 | if type_ == 0 or type_ == "SIMPLE_PINHOLE": 87 | params = np.empty(0, dtype=np.float32) 88 | camtype = "perspective" 89 | elif type_ == 1 or type_ == "PINHOLE": 90 | params = np.empty(0, dtype=np.float32) 91 | camtype = "perspective" 92 | if type_ == 2 or type_ == "SIMPLE_RADIAL": 93 | params = np.array([cam.k1, 0.0, 0.0, 0.0], dtype=np.float32) 94 | camtype = "perspective" 95 | elif type_ == 3 or type_ == "RADIAL": 96 | params = np.array([cam.k1, cam.k2, 0.0, 0.0], dtype=np.float32) 97 | camtype = "perspective" 98 | elif type_ == 4 or type_ == "OPENCV": 99 | params = np.array([cam.k1, cam.k2, cam.p1, cam.p2], dtype=np.float32) 100 | camtype = "perspective" 101 | elif type_ == 5 or type_ == "OPENCV_FISHEYE": 102 | params = np.array([cam.k1, cam.k2, cam.k3, cam.k4], dtype=np.float32) 103 | camtype = "fisheye" 104 | assert ( 105 | camtype == "perspective" or camtype == "fisheye" 106 | ), f"Only perspective and fisheye cameras are supported, got {type_}" 107 | 108 | params_dict[camera_id] = params 109 | imsize_dict[camera_id] = (cam.width // factor, cam.height // factor) 110 | mask_dict[camera_id] = None 111 | print( 112 | f"[Parser] {len(imdata)} images, taken by {len(set(camera_ids))} cameras." 113 | ) 114 | 115 | if len(imdata) == 0: 116 | raise ValueError("No images found in COLMAP.") 117 | if not (type_ == 0 or type_ == 1): 118 | print("Warning: COLMAP Camera is not PINHOLE. Images have distortion.") 119 | 120 | w2c_mats = np.stack(w2c_mats, axis=0) 121 | 122 | # Convert extrinsics to camera-to-world. 123 | camtoworlds = np.linalg.inv(w2c_mats) 124 | 125 | # Image names from COLMAP. No need for permuting the poses according to 126 | # image names anymore. 127 | image_names = [imdata[k].name for k in imdata] 128 | 129 | # Previous Nerf results were generated with images sorted by filename, 130 | # ensure metrics are reported on the same test set. 131 | inds = np.argsort(image_names) 132 | image_names = [image_names[i] for i in inds] 133 | camtoworlds = camtoworlds[inds] 134 | camera_ids = [camera_ids[i] for i in inds] 135 | 136 | # Load extended metadata. Used by Bilarf dataset. 137 | self.extconf = { 138 | "spiral_radius_scale": 1.0, 139 | "no_factor_suffix": False, 140 | } 141 | extconf_file = os.path.join(data_dir, "ext_metadata.json") 142 | if os.path.exists(extconf_file): 143 | with open(extconf_file) as f: 144 | self.extconf.update(json.load(f)) 145 | 146 | # Load bounds if possible (only used in forward facing scenes). 147 | self.bounds = np.array([0.01, 1.0]) 148 | posefile = os.path.join(data_dir, "poses_bounds.npy") 149 | if os.path.exists(posefile): 150 | self.bounds = np.load(posefile)[:, -2:] 151 | 152 | # Load images. 153 | if factor > 1 and not self.extconf["no_factor_suffix"]: 154 | image_dir_suffix = f"_{factor}" 155 | else: 156 | image_dir_suffix = "" 157 | colmap_image_dir = os.path.join(data_dir, "images") 158 | image_dir = os.path.join(data_dir, "images" + image_dir_suffix) 159 | for d in [image_dir, colmap_image_dir]: 160 | if not os.path.exists(d): 161 | raise ValueError(f"Image folder {d} does not exist.") 162 | 163 | # Downsampled images may have different names vs images used for COLMAP, 164 | # so we need to map between the two sorted lists of files. 
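# Each branch below sorts the COLMAP file list and the on-disk file list with the same key, so every COLMAP image entry lines up index-by-index with the (possibly renamed or downsampled) image file that is actually loaded.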
165 | if "3dv-dataset-nerfstudio" in data_dir: 166 | colmap_files = sorted(_get_rel_paths(colmap_image_dir), key=lambda x: int(x.split(".")[0].split("_")[-1])) 167 | image_files = sorted(_get_rel_paths(image_dir), key=lambda x: int(x.split(".")[0].split("_")[-1])) 168 | colmap_to_image = dict(zip(colmap_files, image_files)) 169 | image_names = colmap_files 170 | image_paths = [os.path.join(image_dir, colmap_to_image[f]) for f in image_names] 171 | elif "DL3DV-Benchmark" in data_dir: 172 | colmap_files = sorted(_get_rel_paths(colmap_image_dir)) 173 | image_files = sorted(_get_rel_paths(image_dir)) 174 | colmap_to_image = dict(zip(colmap_files, image_files)) 175 | if len(colmap_files) != len(image_names): 176 | print(f"Warning: colmap_files: {len(colmap_files)}, image_names: {len(image_names)}") 177 | image_names = colmap_files 178 | image_paths = [os.path.join(image_dir, colmap_to_image[f]) for f in image_names] 179 | else: 180 | colmap_files = sorted(_get_rel_paths(colmap_image_dir)) 181 | image_files = sorted(_get_rel_paths(image_dir)) 182 | colmap_to_image = dict(zip(colmap_files, image_files)) 183 | image_paths = [os.path.join(image_dir, colmap_to_image[f]) for f in image_names] 184 | 185 | # 3D points and {image_name -> [point_idx]} 186 | points = manager.points3D.astype(np.float32) 187 | points_err = manager.point3D_errors.astype(np.float32) 188 | points_rgb = manager.point3D_colors.astype(np.uint8) 189 | point_indices = dict() 190 | 191 | image_id_to_name = {v: k for k, v in manager.name_to_image_id.items()} 192 | for point_id, data in manager.point3D_id_to_images.items(): 193 | for image_id, _ in data: 194 | image_name = image_id_to_name[image_id] 195 | point_idx = manager.point3D_id_to_point3D_idx[point_id] 196 | point_indices.setdefault(image_name, []).append(point_idx) 197 | point_indices = { 198 | k: np.array(v).astype(np.int32) for k, v in point_indices.items() 199 | } 200 | 201 | # Normalize the world space. 202 | if normalize: 203 | T1 = similarity_from_cameras(camtoworlds) 204 | camtoworlds = transform_cameras(T1, camtoworlds) 205 | points = transform_points(T1, points) 206 | 207 | T2 = align_principle_axes(points) 208 | camtoworlds = transform_cameras(T2, camtoworlds) 209 | points = transform_points(T2, points) 210 | 211 | transform = T2 @ T1 212 | else: 213 | transform = np.eye(4) 214 | 215 | self.image_names = image_names # List[str], (num_images,) 216 | self.image_paths = image_paths # List[str], (num_images,) 217 | self.alpha_mask_paths = None # List[str], (num_images,) 218 | self.camtoworlds = camtoworlds # np.ndarray, (num_images, 4, 4) 219 | self.camera_ids = camera_ids # List[int], (num_images,) 220 | self.Ks_dict = Ks_dict # Dict of camera_id -> K 221 | self.params_dict = params_dict # Dict of camera_id -> params 222 | self.imsize_dict = imsize_dict # Dict of camera_id -> (width, height) 223 | self.mask_dict = mask_dict # Dict of camera_id -> mask 224 | self.points = points # np.ndarray, (num_points, 3) 225 | self.points_err = points_err # np.ndarray, (num_points,) 226 | self.points_rgb = points_rgb # np.ndarray, (num_points, 3) 227 | self.point_indices = point_indices # Dict[str, np.ndarray], image_name -> [M,] 228 | self.transform = transform # np.ndarray, (4, 4) 229 | 230 | # load one image to check the size. In the case of tanksandtemples dataset, the 231 | # intrinsics stored in COLMAP corresponds to 2x upsampled images. 
232 | actual_image = imageio.imread(self.image_paths[0])[..., :3] 233 | actual_height, actual_width = actual_image.shape[:2] 234 | colmap_width, colmap_height = self.imsize_dict[self.camera_ids[0]] 235 | s_height, s_width = actual_height / colmap_height, actual_width / colmap_width 236 | for camera_id, K in self.Ks_dict.items(): 237 | K[0, :] *= s_width 238 | K[1, :] *= s_height 239 | self.Ks_dict[camera_id] = K 240 | width, height = self.imsize_dict[camera_id] 241 | self.imsize_dict[camera_id] = (actual_width, actual_height) 242 | 243 | # undistortion 244 | self.mapx_dict = dict() 245 | self.mapy_dict = dict() 246 | self.roi_undist_dict = dict() 247 | for camera_id in self.params_dict.keys(): 248 | params = self.params_dict[camera_id] 249 | if len(params) == 0: 250 | continue # no distortion 251 | assert camera_id in self.Ks_dict, f"Missing K for camera {camera_id}" 252 | assert ( 253 | camera_id in self.params_dict 254 | ), f"Missing params for camera {camera_id}" 255 | K = self.Ks_dict[camera_id] 256 | width, height = self.imsize_dict[camera_id] 257 | 258 | if camtype == "perspective": 259 | K_undist, roi_undist = cv2.getOptimalNewCameraMatrix( 260 | K, params, (width, height), 0 261 | ) 262 | mapx, mapy = cv2.initUndistortRectifyMap( 263 | K, params, None, K_undist, (width, height), cv2.CV_32FC1 264 | ) 265 | mask = None 266 | elif camtype == "fisheye": 267 | fx = K[0, 0] 268 | fy = K[1, 1] 269 | cx = K[0, 2] 270 | cy = K[1, 2] 271 | grid_x, grid_y = np.meshgrid( 272 | np.arange(width, dtype=np.float32), 273 | np.arange(height, dtype=np.float32), 274 | indexing="xy", 275 | ) 276 | x1 = (grid_x - cx) / fx 277 | y1 = (grid_y - cy) / fy 278 | theta = np.sqrt(x1**2 + y1**2) 279 | r = ( 280 | 1.0 281 | + params[0] * theta**2 282 | + params[1] * theta**4 283 | + params[2] * theta**6 284 | + params[3] * theta**8 285 | ) 286 | mapx = fx * x1 * r + width // 2 287 | mapy = fy * y1 * r + height // 2 288 | 289 | # Use mask to define ROI 290 | mask = np.logical_and( 291 | np.logical_and(mapx > 0, mapy > 0), 292 | np.logical_and(mapx < width - 1, mapy < height - 1), 293 | ) 294 | y_indices, x_indices = np.nonzero(mask) 295 | y_min, y_max = y_indices.min(), y_indices.max() + 1 296 | x_min, x_max = x_indices.min(), x_indices.max() + 1 297 | mask = mask[y_min:y_max, x_min:x_max] 298 | K_undist = K.copy() 299 | K_undist[0, 2] -= x_min 300 | K_undist[1, 2] -= y_min 301 | roi_undist = [x_min, y_min, x_max - x_min, y_max - y_min] 302 | else: 303 | assert_never(camtype) 304 | 305 | self.mapx_dict[camera_id] = mapx 306 | self.mapy_dict[camera_id] = mapy 307 | self.Ks_dict[camera_id] = K_undist 308 | self.roi_undist_dict[camera_id] = roi_undist 309 | self.imsize_dict[camera_id] = (roi_undist[2], roi_undist[3]) 310 | self.mask_dict[camera_id] = mask 311 | 312 | # size of the scene measured by cameras 313 | camera_locations = camtoworlds[:, :3, 3] 314 | scene_center = np.mean(camera_locations, axis=0) 315 | dists = np.linalg.norm(camera_locations - scene_center, axis=1) 316 | self.scene_scale = np.max(dists) 317 | 318 | 319 | class Dataset: 320 | """A simple dataset class.""" 321 | 322 | def __init__( 323 | self, 324 | parser: Parser, 325 | split: str = "train", 326 | patch_size: Optional[int] = None, 327 | load_depths: bool = False, 328 | ): 329 | self.parser = parser 330 | self.split = split 331 | self.patch_size = patch_size 332 | self.load_depths = load_depths 333 | 334 | indices = np.arange(len(self.parser.image_names)) 335 | if self.parser.test_every == 1: 336 | image_names = 
sorted(_get_rel_paths(f"{self.parser.data_dir}/images"), key=lambda x: int(x.split(".")[0].split("_")[-1])) 337 | assert len(image_names) == len(self.parser.image_names) 338 | if split == "train": 339 | self.indices = [ind for ind in indices if "_train_" in image_names[ind]] 340 | else: 341 | self.indices = [ind for ind in indices if "_eval_" in image_names[ind]] 342 | elif self.parser.test_every == 0: 343 | self.indices = indices 344 | else: 345 | if split == "train": 346 | self.indices = indices[indices % self.parser.test_every == 0] 347 | else: 348 | self.indices = indices[indices % self.parser.test_every != 0] 349 | 350 | def __len__(self): 351 | return len(self.indices) 352 | 353 | def __getitem__(self, item: int) -> Dict[str, Any]: 354 | index = self.indices[item] 355 | image = imageio.imread(self.parser.image_paths[index])[..., :3] 356 | camera_id = self.parser.camera_ids[index] 357 | K = self.parser.Ks_dict[camera_id].copy() # undistorted K 358 | params = self.parser.params_dict[camera_id] 359 | camtoworlds = self.parser.camtoworlds[index] 360 | mask = self.parser.mask_dict[camera_id] 361 | 362 | if len(params) > 0: 363 | # Images are distorted. Undistort them. 364 | mapx, mapy = ( 365 | self.parser.mapx_dict[camera_id], 366 | self.parser.mapy_dict[camera_id], 367 | ) 368 | image = cv2.remap(image, mapx, mapy, cv2.INTER_LINEAR) 369 | x, y, w, h = self.parser.roi_undist_dict[camera_id] 370 | image = image[y : y + h, x : x + w] 371 | 372 | if self.patch_size is not None: 373 | # Random crop. 374 | h, w = image.shape[:2] 375 | x = np.random.randint(0, max(w - self.patch_size, 1)) 376 | y = np.random.randint(0, max(h - self.patch_size, 1)) 377 | image = image[y : y + self.patch_size, x : x + self.patch_size] 378 | K[0, 2] -= x 379 | K[1, 2] -= y 380 | 381 | data = { 382 | "K": torch.from_numpy(K).float(), 383 | "camtoworld": torch.from_numpy(camtoworlds).float(), 384 | "image": torch.from_numpy(image).float(), 385 | "image_id": item, # the index of the image in the dataset 386 | } 387 | if mask is not None: 388 | data["mask"] = torch.from_numpy(mask).bool() 389 | 390 | if self.parser.alpha_mask_paths is not None: 391 | alpha_mask = imageio.imread(self.parser.alpha_mask_paths[index], mode="L")[:,:,None] / 255.0 392 | if self.patch_size is not None: 393 | alpha_mask = alpha_mask[y : y + self.patch_size, x : x + self.patch_size] 394 | data["alpha_mask"] = torch.from_numpy(alpha_mask).float() 395 | 396 | if self.load_depths: 397 | # projected points to image plane to get depths 398 | worldtocams = np.linalg.inv(camtoworlds) 399 | image_name = self.parser.image_names[index] 400 | point_indices = self.parser.point_indices[image_name] 401 | points_world = self.parser.points[point_indices] 402 | points_cam = (worldtocams[:3, :3] @ points_world.T + worldtocams[:3, 3:4]).T 403 | points_proj = (K @ points_cam.T).T 404 | points = points_proj[:, :2] / points_proj[:, 2:3] # (M, 2) 405 | depths = points_cam[:, 2] # (M,) 406 | # filter out points outside the image 407 | selector = ( 408 | (points[:, 0] >= 0) 409 | & (points[:, 0] < image.shape[1]) 410 | & (points[:, 1] >= 0) 411 | & (points[:, 1] < image.shape[0]) 412 | & (depths > 0) 413 | ) 414 | points = points[selector] 415 | depths = depths[selector] 416 | data["points"] = torch.from_numpy(points).float() 417 | data["depths"] = torch.from_numpy(depths).float() 418 | 419 | return data 420 | 421 | 422 | if __name__ == "__main__": 423 | import argparse 424 | 425 | import imageio.v2 as imageio 426 | import tqdm 427 | 428 | parser = 
argparse.ArgumentParser() 429 | parser.add_argument("--data_dir", type=str, default="data/360_v2/garden") 430 | parser.add_argument("--factor", type=int, default=4) 431 | args = parser.parse_args() 432 | 433 | # Parse COLMAP data. 434 | parser = Parser( 435 | data_dir=args.data_dir, factor=args.factor, normalize=True, test_every=8 436 | ) 437 | dataset = Dataset(parser, split="train", load_depths=True) 438 | print(f"Dataset: {len(dataset)} images.") 439 | 440 | writer = imageio.get_writer("results/points.mp4", fps=30) 441 | for data in tqdm.tqdm(dataset, desc="Plotting points"): 442 | image = data["image"].numpy().astype(np.uint8) 443 | points = data["points"].numpy() 444 | depths = data["depths"].numpy() 445 | for x, y in points: 446 | cv2.circle(image, (int(x), int(y)), 2, (255, 0, 0), -1) 447 | writer.append_data(image) 448 | writer.close() 449 | -------------------------------------------------------------------------------- /examples/gsplat/lib_bilagrid.py: -------------------------------------------------------------------------------- 1 | # # Copyright 2024 Yuehao Wang (https://github.com/yuehaowang). This part of code is borrowed form ["Bilateral Guided Radiance Field Processing"](https://bilarfpro.github.io/). 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | This is a standalone PyTorch implementation of 3D bilateral grid and CP-decomposed 4D bilateral grid. 17 | To use this module, you can download the "lib_bilagrid.py" file and simply put it in your project directory. 18 | 19 | For the details, please check our research project: ["Bilateral Guided Radiance Field Processing"](https://bilarfpro.github.io/). 20 | 21 | #### Dependencies 22 | 23 | In addition to PyTorch and Numpy, please install [tensorly](https://github.com/tensorly/tensorly). 24 | We have tested this module on Python 3.9.18, PyTorch 2.0.1 (CUDA 11), tensorly 0.8.1, and Numpy 1.25.2. 25 | 26 | #### Overview 27 | 28 | - For bilateral guided training, you need to construct a `BilateralGrid` instance, which can hold multiple bilateral grids 29 | for input views. Then, use `slice` function to obtain transformed RGB output and the corresponding affine transformations. 30 | 31 | - For bilateral guided finishing, you need to instantiate a `BilateralGridCP4D` object and use `slice4d`. 32 | 33 | #### Examples 34 | 35 | - Bilateral grid for approximating ISP: 36 | 37 | Open In Colab 38 | 39 | - Low-rank 4D bilateral grid for MR enhancement: 40 | 41 | Open In Colab 42 | 43 | 44 | Below is the API reference. 45 | 46 | """ 47 | 48 | import tensorly as tl 49 | import torch 50 | import torch.nn.functional as F 51 | from torch import nn 52 | 53 | tl.set_backend("pytorch") 54 | 55 | 56 | def color_correct( 57 | img: torch.Tensor, ref: torch.Tensor, num_iters: int = 5, eps: float = 0.5 / 255 58 | ) -> torch.Tensor: 59 | """ 60 | Warp `img` to match the colors in `ref_img` using iterative color matching. 
61 | 62 | This function performs color correction by warping the colors of the input image 63 | to match those of a reference image. It uses a least squares method to find a 64 | transformation that maps the input image's colors to the reference image's colors. 65 | 66 | The algorithm iteratively solves a system of linear equations, updating the set of 67 | unsaturated pixels in each iteration. This approach helps handle non-linear color 68 | transformations and reduces the impact of clipping. 69 | 70 | Args: 71 | img (torch.Tensor): Input image to be color corrected. Shape: [..., num_channels] 72 | ref (torch.Tensor): Reference image to match colors. Shape: [..., num_channels] 73 | num_iters (int, optional): Number of iterations for the color matching process. 74 | Default is 5. 75 | eps (float, optional): Small value to determine the range of unclipped pixels. 76 | Default is 0.5 / 255. 77 | 78 | Returns: 79 | torch.Tensor: Color corrected image with the same shape as the input image. 80 | 81 | Note: 82 | - Both input and reference images should be in the range [0, 1]. 83 | - The function works with any number of channels, but typically used with 3 (RGB). 84 | """ 85 | if img.shape[-1] != ref.shape[-1]: 86 | raise ValueError( 87 | f"img's {img.shape[-1]} and ref's {ref.shape[-1]} channels must match" 88 | ) 89 | num_channels = img.shape[-1] 90 | img_mat = img.reshape([-1, num_channels]) 91 | ref_mat = ref.reshape([-1, num_channels]) 92 | 93 | def is_unclipped(z): 94 | return (z >= eps) & (z <= 1 - eps) # z \in [eps, 1-eps]. 95 | 96 | mask0 = is_unclipped(img_mat) 97 | # Because the set of saturated pixels may change after solving for a 98 | # transformation, we repeatedly solve a system `num_iters` times and update 99 | # our estimate of which pixels are saturated. 100 | for _ in range(num_iters): 101 | # Construct the left hand side of a linear system that contains a quadratic 102 | # expansion of each pixel of `img`. 103 | a_mat = [] 104 | for c in range(num_channels): 105 | a_mat.append(img_mat[:, c : (c + 1)] * img_mat[:, c:]) # Quadratic term. 106 | a_mat.append(img_mat) # Linear term. 107 | a_mat.append(torch.ones_like(img_mat[:, :1])) # Bias term. 108 | a_mat = torch.cat(a_mat, dim=-1) 109 | warp = [] 110 | for c in range(num_channels): 111 | # Construct the right hand side of a linear system containing each color 112 | # of `ref`. 113 | b = ref_mat[:, c] 114 | # Ignore rows of the linear system that were saturated in the input or are 115 | # saturated in the current corrected color estimate. 116 | mask = mask0[:, c] & is_unclipped(img_mat[:, c]) & is_unclipped(b) 117 | ma_mat = torch.where(mask[:, None], a_mat, torch.zeros_like(a_mat)) 118 | mb = torch.where(mask, b, torch.zeros_like(b)) 119 | w = torch.linalg.lstsq(ma_mat, mb, rcond=-1)[0] 120 | assert torch.all(torch.isfinite(w)) 121 | warp.append(w) 122 | warp = torch.stack(warp, dim=-1) 123 | # Apply the warp to update img_mat. 124 | img_mat = torch.clip(torch.matmul(a_mat, warp), 0, 1) 125 | corrected_img = torch.reshape(img_mat, img.shape) 126 | return corrected_img 127 | 128 | 129 | def bilateral_grid_tv_loss(model, config): 130 | """Computes total variations of bilateral grids.""" 131 | total_loss = 0.0 132 | 133 | for bil_grids in model.bil_grids: 134 | total_loss += config.bilgrid_tv_loss_mult * total_variation_loss( 135 | bil_grids.grids 136 | ) 137 | 138 | return total_loss 139 | 140 | 141 | def color_affine_transform(affine_mats, rgb): 142 | """Applies color affine transformations. 
143 | 144 | Args: 145 | affine_mats (torch.Tensor): Affine transformation matrices. Supported shape: $(..., 3, 4)$. 146 | rgb (torch.Tensor): Input RGB values. Supported shape: $(..., 3)$. 147 | 148 | Returns: 149 | Output transformed colors of shape $(..., 3)$. 150 | """ 151 | return ( 152 | torch.matmul(affine_mats[..., :3], rgb.unsqueeze(-1)).squeeze(-1) 153 | + affine_mats[..., 3] 154 | ) 155 | 156 | 157 | def _num_tensor_elems(t): 158 | return max(torch.prod(torch.tensor(t.size()[1:]).float()).item(), 1.0) 159 | 160 | 161 | def total_variation_loss(x): # noqa: F811 162 | """Returns total variation on multi-dimensional tensors. 163 | 164 | Args: 165 | x (torch.Tensor): The input tensor with shape $(B, C, ...)$, where $B$ is the batch size and $C$ is the channel size. 166 | """ 167 | batch_size = x.shape[0] 168 | tv = 0 169 | for i in range(2, len(x.shape)): 170 | n_res = x.shape[i] 171 | idx1 = torch.arange(1, n_res, device=x.device) 172 | idx2 = torch.arange(0, n_res - 1, device=x.device) 173 | x1 = x.index_select(i, idx1) 174 | x2 = x.index_select(i, idx2) 175 | count = _num_tensor_elems(x1) 176 | tv += torch.pow((x1 - x2), 2).sum() / count 177 | return tv / batch_size 178 | 179 | 180 | def slice(bil_grids, xy, rgb, grid_idx): 181 | """Slices a batch of 3D bilateral grids by pixel coordinates `xy` and gray-scale guidances of pixel colors `rgb`. 182 | 183 | Supports 2-D, 3-D, and 4-D input shapes. The first dimension of the input is the batch size 184 | and the last dimension is 2 for `xy`, 3 for `rgb`, and 1 for `grid_idx`. 185 | 186 | The return value is a dictionary containing the affine transformations `affine_mats` sliced from bilateral grids and 187 | the output color `rgb` after applying the affine transformations. 188 | 189 | In the 2-D input case, `xy` is a $(N, 2)$ tensor, `rgb` is a $(N, 3)$ tensor, and `grid_idx` is a $(N, 1)$ tensor. 190 | Then `affine_mats[i]` can be obtained via slicing the bilateral grid indexed at `grid_idx[i]` by `xy[i, :]` and `rgb2gray(rgb[i, :])`. 191 | For 3-D and 4-D input cases, the behavior of indexing bilateral grids and coordinates is the same as in the 2-D case. 192 | 193 | .. note:: 194 | This function can be regarded as a wrapper of `color_affine_transform` and `BilateralGrid` with a slight performance improvement. 195 | When `grid_idx` contains a unique index, only a single bilateral grid will be used during the slicing. In this case, this function will not 196 | perform tensor indexing to avoid data copy and extra memory 197 | (see [this](https://discuss.pytorch.org/t/does-indexing-a-tensor-return-a-copy-of-it/164905)). 198 | 199 | Args: 200 | bil_grids (`BilateralGrid`): An instance of $N$ bilateral grids. 201 | xy (torch.Tensor): The x-y coordinates of shape $(..., 2)$ in the range of $[0,1]$. 202 | rgb (torch.Tensor): The RGB values of shape $(..., 3)$ for computing the guidance coordinates, ranging in $[0,1]$. 203 | grid_idx (torch.Tensor): The indices of bilateral grids for each slicing. Shape: $(..., 1)$. 204 | 205 | Returns: 206 | A dictionary with keys and values as follows: 207 | ``` 208 | { 209 | "rgb": Transformed RGB colors. Shape: (..., 3), 210 | "rgb_affine_mats": The sliced affine transformation matrices from bilateral grids. Shape: (..., 3, 4) 211 | } 212 | ``` 213 | """ 214 | 215 | sh_ = rgb.shape 216 | 217 | grid_idx_unique = torch.unique(grid_idx) 218 | if len(grid_idx_unique) == 1: 219 | # All pixels are from a single view. 
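# Using the single unique index here means only one grid is selected in the forward pass, rather than gathering a per-pixel copy of the grids tensor (see the note in the docstring above).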
220 | grid_idx = grid_idx_unique # (1,) 221 | xy = xy.unsqueeze(0) # (1, ..., 2) 222 | rgb = rgb.unsqueeze(0) # (1, ..., 3) 223 | else: 224 | # Pixels are randomly sampled from different views. 225 | if len(grid_idx.shape) == 4: 226 | grid_idx = grid_idx[:, 0, 0, 0] # (chunk_size,) 227 | elif len(grid_idx.shape) == 3: 228 | grid_idx = grid_idx[:, 0, 0] # (chunk_size,) 229 | elif len(grid_idx.shape) == 2: 230 | grid_idx = grid_idx[:, 0] # (chunk_size,) 231 | else: 232 | raise ValueError( 233 | "The input to bilateral grid slicing is not supported yet." 234 | ) 235 | 236 | affine_mats = bil_grids(xy, rgb, grid_idx) 237 | rgb = color_affine_transform(affine_mats, rgb) 238 | 239 | return { 240 | "rgb": rgb.reshape(*sh_), 241 | "rgb_affine_mats": affine_mats.reshape( 242 | *sh_[:-1], affine_mats.shape[-2], affine_mats.shape[-1] 243 | ), 244 | } 245 | 246 | 247 | class BilateralGrid(nn.Module): 248 | """Class for 3D bilateral grids. 249 | 250 | Holds one or more than one bilateral grids. 251 | """ 252 | 253 | def __init__(self, num, grid_X=16, grid_Y=16, grid_W=8): 254 | """ 255 | Args: 256 | num (int): The number of bilateral grids (i.e., the number of views). 257 | grid_X (int): Defines grid width $W$. 258 | grid_Y (int): Defines grid height $H$. 259 | grid_W (int): Defines grid guidance dimension $L$. 260 | """ 261 | super(BilateralGrid, self).__init__() 262 | 263 | self.grid_width = grid_X 264 | """Grid width. Type: int.""" 265 | self.grid_height = grid_Y 266 | """Grid height. Type: int.""" 267 | self.grid_guidance = grid_W 268 | """Grid guidance dimension. Type: int.""" 269 | 270 | # Initialize grids. 271 | grid = self._init_identity_grid() 272 | self.grids = nn.Parameter(grid.tile(num, 1, 1, 1, 1)) # (N, 12, L, H, W) 273 | """ A 5-D tensor of shape $(N, 12, L, H, W)$.""" 274 | 275 | # Weights of BT601 RGB-to-gray. 276 | self.register_buffer("rgb2gray_weight", torch.Tensor([[0.299, 0.587, 0.114]])) 277 | self.rgb2gray = lambda rgb: (rgb @ self.rgb2gray_weight.T) * 2.0 - 1.0 278 | """ A function that converts RGB to gray-scale guidance in $[-1, 1]$.""" 279 | 280 | def _init_identity_grid(self): 281 | grid = torch.tensor( 282 | [ 283 | 1.0, 284 | 0, 285 | 0, 286 | 0, 287 | 0, 288 | 1.0, 289 | 0, 290 | 0, 291 | 0, 292 | 0, 293 | 1.0, 294 | 0, 295 | ] 296 | ).float() 297 | grid = grid.repeat( 298 | [self.grid_guidance * self.grid_height * self.grid_width, 1] 299 | ) # (L * H * W, 12) 300 | grid = grid.reshape( 301 | 1, self.grid_guidance, self.grid_height, self.grid_width, -1 302 | ) # (1, L, H, W, 12) 303 | grid = grid.permute(0, 4, 1, 2, 3) # (1, 12, L, H, W) 304 | return grid 305 | 306 | def tv_loss(self): 307 | """Computes and returns total variation loss on the bilateral grids.""" 308 | return total_variation_loss(self.grids) 309 | 310 | def forward(self, grid_xy, rgb, idx=None): 311 | """Bilateral grid slicing. Supports 2-D, 3-D, 4-D, and 5-D input. 312 | For the 2-D, 3-D, and 4-D cases, please refer to `slice`. 313 | For the 5-D cases, `idx` will be unused and the first dimension of `xy` should be 314 | equal to the number of bilateral grids. Then this function becomes PyTorch's 315 | [`F.grid_sample`](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html). 316 | 317 | Args: 318 | grid_xy (torch.Tensor): The x-y coordinates in the range of $[0,1]$. 319 | rgb (torch.Tensor): The RGB values in the range of $[0,1]$. 320 | idx (torch.Tensor): The bilateral grid indices. 321 | 322 | Returns: 323 | Sliced affine matrices of shape $(..., 3, 4)$. 
324 | """ 325 | 326 | grids = self.grids 327 | input_ndims = len(grid_xy.shape) 328 | assert len(rgb.shape) == input_ndims 329 | 330 | if input_ndims > 1 and input_ndims < 5: 331 | # Convert input into 5D 332 | for i in range(5 - input_ndims): 333 | grid_xy = grid_xy.unsqueeze(1) 334 | rgb = rgb.unsqueeze(1) 335 | assert idx is not None 336 | elif input_ndims != 5: 337 | raise ValueError( 338 | "Bilateral grid slicing only takes either 2D, 3D, 4D and 5D inputs" 339 | ) 340 | 341 | grids = self.grids 342 | if idx is not None: 343 | grids = grids[idx] 344 | assert grids.shape[0] == grid_xy.shape[0] 345 | 346 | # Generate slicing coordinates. 347 | grid_xy = (grid_xy - 0.5) * 2 # Rescale to [-1, 1]. 348 | grid_z = self.rgb2gray(rgb) 349 | 350 | # print(grid_xy.shape, grid_z.shape) 351 | # exit() 352 | grid_xyz = torch.cat([grid_xy, grid_z], dim=-1) # (N, m, h, w, 3) 353 | 354 | affine_mats = F.grid_sample( 355 | grids, grid_xyz, mode="bilinear", align_corners=True, padding_mode="border" 356 | ) # (N, 12, m, h, w) 357 | affine_mats = affine_mats.permute(0, 2, 3, 4, 1) # (N, m, h, w, 12) 358 | affine_mats = affine_mats.reshape( 359 | *affine_mats.shape[:-1], 3, 4 360 | ) # (N, m, h, w, 3, 4) 361 | 362 | for _ in range(5 - input_ndims): 363 | affine_mats = affine_mats.squeeze(1) 364 | 365 | return affine_mats 366 | 367 | 368 | def slice4d(bil_grid4d, xyz, rgb): 369 | """Slices a 4D bilateral grid by point coordinates `xyz` and gray-scale guidances of radiance colors `rgb`. 370 | 371 | Args: 372 | bil_grid4d (`BilateralGridCP4D`): The input 4D bilateral grid. 373 | xyz (torch.Tensor): The xyz coordinates with shape $(..., 3)$. 374 | rgb (torch.Tensor): The RGB values with shape $(..., 3)$. 375 | 376 | Returns: 377 | A dictionary with keys and values as follows: 378 | ``` 379 | { 380 | "rgb": Transformed radiance RGB colors. Shape: (..., 3), 381 | "rgb_affine_mats": The sliced affine transformation matrices from the 4D bilateral grid. Shape: (..., 3, 4) 382 | } 383 | ``` 384 | """ 385 | 386 | affine_mats = bil_grid4d(xyz, rgb) 387 | rgb = color_affine_transform(affine_mats, rgb) 388 | 389 | return {"rgb": rgb, "rgb_affine_mats": affine_mats} 390 | 391 | 392 | class _ScaledTanh(nn.Module): 393 | def __init__(self, s=2.0): 394 | super().__init__() 395 | self.scaler = s 396 | 397 | def forward(self, x): 398 | return torch.tanh(self.scaler * x) 399 | 400 | 401 | class BilateralGridCP4D(nn.Module): 402 | """Class for low-rank 4D bilateral grids.""" 403 | 404 | def __init__( 405 | self, 406 | grid_X=16, 407 | grid_Y=16, 408 | grid_Z=16, 409 | grid_W=8, 410 | rank=5, 411 | learn_gray=True, 412 | gray_mlp_width=8, 413 | gray_mlp_depth=2, 414 | init_noise_scale=1e-6, 415 | bound=2.0, 416 | ): 417 | """ 418 | Args: 419 | grid_X (int): Defines grid width. 420 | grid_Y (int): Defines grid height. 421 | grid_Z (int): Defines grid depth. 422 | grid_W (int): Defines grid guidance dimension. 423 | rank (int): Rank of the 4D bilateral grid. 424 | learn_gray (bool): If True, an MLP will be learned to convert RGB colors to gray-scale guidances. 425 | gray_mlp_width (int): The MLP width for learnable guidance. 426 | gray_mlp_depth (int): The number of MLP layers for learnable guidance. 427 | init_noise_scale (float): The noise scale of the initialized factors. 428 | bound (float): The bound of the xyz coordinates. 429 | """ 430 | super(BilateralGridCP4D, self).__init__() 431 | 432 | self.grid_X = grid_X 433 | """Grid width. Type: int.""" 434 | self.grid_Y = grid_Y 435 | """Grid height. 
Type: int.""" 436 | self.grid_Z = grid_Z 437 | """Grid depth. Type: int.""" 438 | self.grid_W = grid_W 439 | """Grid guidance dimension. Type: int.""" 440 | self.rank = rank 441 | """Rank of the 4D bilateral grid. Type: int.""" 442 | self.learn_gray = learn_gray 443 | """Flags of learnable guidance is used. Type: bool.""" 444 | self.gray_mlp_width = gray_mlp_width 445 | """The MLP width for learnable guidance. Type: int.""" 446 | self.gray_mlp_depth = gray_mlp_depth 447 | """The MLP depth for learnable guidance. Type: int.""" 448 | self.init_noise_scale = init_noise_scale 449 | """The noise scale of the initialized factors. Type: float.""" 450 | self.bound = bound 451 | """The bound of the xyz coordinates. Type: float.""" 452 | 453 | self._init_cp_factors_parafac() 454 | 455 | self.rgb2gray = None 456 | """ A function that converts RGB to gray-scale guidances in $[-1, 1]$. 457 | If `learn_gray` is True, this will be an MLP network.""" 458 | 459 | if self.learn_gray: 460 | 461 | def rgb2gray_mlp_linear(layer): 462 | return nn.Linear( 463 | self.gray_mlp_width, 464 | self.gray_mlp_width if layer < self.gray_mlp_depth - 1 else 1, 465 | ) 466 | 467 | def rgb2gray_mlp_actfn(_): 468 | return nn.ReLU(inplace=True) 469 | 470 | self.rgb2gray = nn.Sequential( 471 | *( 472 | [nn.Linear(3, self.gray_mlp_width)] 473 | + [ 474 | nn_module(layer) 475 | for layer in range(1, self.gray_mlp_depth) 476 | for nn_module in [rgb2gray_mlp_actfn, rgb2gray_mlp_linear] 477 | ] 478 | + [_ScaledTanh(2.0)] 479 | ) 480 | ) 481 | else: 482 | # Weights of BT601/BT470 RGB-to-gray. 483 | self.register_buffer( 484 | "rgb2gray_weight", torch.Tensor([[0.299, 0.587, 0.114]]) 485 | ) 486 | self.rgb2gray = lambda rgb: (rgb @ self.rgb2gray_weight.T) * 2.0 - 1.0 487 | 488 | def _init_identity_grid(self): 489 | grid = torch.tensor( 490 | [ 491 | 1.0, 492 | 0, 493 | 0, 494 | 0, 495 | 0, 496 | 1.0, 497 | 0, 498 | 0, 499 | 0, 500 | 0, 501 | 1.0, 502 | 0, 503 | ] 504 | ).float() 505 | grid = grid.repeat([self.grid_W * self.grid_Z * self.grid_Y * self.grid_X, 1]) 506 | grid = grid.reshape(self.grid_W, self.grid_Z, self.grid_Y, self.grid_X, -1) 507 | grid = grid.permute(4, 0, 1, 2, 3) # (12, grid_W, grid_Z, grid_Y, grid_X) 508 | return grid 509 | 510 | def _init_cp_factors_parafac(self): 511 | # Initialize identity grids. 512 | init_grids = self._init_identity_grid() 513 | # Random noises are added to avoid singularity. 
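# The tiled identity grid is constant along its spatial and guidance axes (essentially rank-1), so the small perturbation helps keep the CP factors returned by parafac well-conditioned.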
514 | init_grids = torch.randn_like(init_grids) * self.init_noise_scale + init_grids 515 | from tensorly.decomposition import parafac 516 | 517 | # Initialize grid CP factors 518 | _, facs = parafac(init_grids.clone().detach(), rank=self.rank) 519 | 520 | self.num_facs = len(facs) 521 | 522 | self.fac_0 = nn.Linear(facs[0].shape[0], facs[0].shape[1], bias=False) 523 | self.fac_0.weight = nn.Parameter(facs[0]) # (12, rank) 524 | 525 | for i in range(1, self.num_facs): 526 | fac = facs[i].T # (rank, grid_size) 527 | fac = fac.view(1, fac.shape[0], fac.shape[1], 1) # (1, rank, grid_size, 1) 528 | self.register_buffer(f"fac_{i}_init", fac) 529 | 530 | fac_resid = torch.zeros_like(fac) 531 | self.register_parameter(f"fac_{i}", nn.Parameter(fac_resid)) 532 | 533 | def tv_loss(self): 534 | """Computes and returns total variation loss on the factors of the low-rank 4D bilateral grids.""" 535 | 536 | total_loss = 0 537 | for i in range(1, self.num_facs): 538 | fac = self.get_parameter(f"fac_{i}") 539 | total_loss += total_variation_loss(fac) 540 | 541 | return total_loss 542 | 543 | def forward(self, xyz, rgb): 544 | """Low-rank 4D bilateral grid slicing. 545 | 546 | Args: 547 | xyz (torch.Tensor): The xyz coordinates with shape $(..., 3)$. 548 | rgb (torch.Tensor): The corresponding RGB values with shape $(..., 3)$. 549 | 550 | Returns: 551 | Sliced affine matrices with shape $(..., 3, 4)$. 552 | """ 553 | sh_ = xyz.shape 554 | xyz = xyz.reshape(-1, 3) # flatten (N, 3) 555 | rgb = rgb.reshape(-1, 3) # flatten (N, 3) 556 | 557 | xyz = xyz / self.bound 558 | assert self.rgb2gray is not None 559 | gray = self.rgb2gray(rgb) 560 | xyzw = torch.cat([xyz, gray], dim=-1) # (N, 4) 561 | xyzw = xyzw.transpose(0, 1) # (4, N) 562 | coords = torch.stack([torch.zeros_like(xyzw), xyzw], dim=-1) # (4, N, 2) 563 | coords = coords.unsqueeze(1) # (4, 1, N, 2) 564 | 565 | coef = 1.0 566 | for i in range(1, self.num_facs): 567 | fac = self.get_parameter(f"fac_{i}") + self.get_buffer(f"fac_{i}_init") 568 | coef = coef * F.grid_sample( 569 | fac, coords[[i - 1]], align_corners=True, padding_mode="border" 570 | ) # [1, rank, 1, N] 571 | coef = coef.squeeze([0, 2]).transpose(0, 1) # (N, rank) #type: ignore 572 | mat = self.fac_0(coef) 573 | return mat.reshape(*sh_[:-1], 3, 4) 574 | --------------------------------------------------------------------------------
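As a usage illustration for the bilateral grid module above, the following minimal sketch shows the intended flow: construct a `BilateralGrid` with one grid per training view, slice it with normalized pixel coordinates and rendered colors via `lib_bilagrid.slice`, and add its total-variation loss to the training objective. The sizes `num_views` and `n_pixels`, the random inputs, the `target` tensor and the 0.01 TV weight are placeholders chosen for this sketch only, and it assumes the script is run from examples/gsplat so that `lib_bilagrid` is importable; none of these values come from the repository itself.

import torch
import torch.nn.functional as F

import lib_bilagrid  # the module listed above

num_views = 4    # one bilateral grid per training view (placeholder value)
n_pixels = 1024  # number of pixels sampled from a single view (placeholder value)

# One 16x16x8 grid of 3x4 affine matrices per view, initialized to the identity transform.
bil_grids = lib_bilagrid.BilateralGrid(num=num_views, grid_X=16, grid_Y=16, grid_W=8)

# Normalized pixel coordinates and rendered colors in [0, 1], plus the view index of each
# pixel; here every pixel comes from view 0, which is the 2-D input case described in `slice`.
xy = torch.rand(n_pixels, 2)
rgb = torch.rand(n_pixels, 3)
grid_idx = torch.zeros(n_pixels, 1, dtype=torch.long)

out = lib_bilagrid.slice(bil_grids, xy, rgb, grid_idx)
rgb_corrected = out["rgb"]            # (n_pixels, 3) color-transformed output
affine_mats = out["rgb_affine_mats"]  # (n_pixels, 3, 4) per-pixel affine matrices

# The grids are ordinary nn.Parameters, so they can be optimized jointly with a radiance
# field; the total-variation penalty keeps the per-view transforms spatially smooth.
target = torch.rand(n_pixels, 3)  # stand-in for ground-truth pixel colors
loss = F.mse_loss(rgb_corrected, target) + 0.01 * bil_grids.tv_loss()
loss.backward()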