├── assets
│   ├── demo.gif
│   ├── example_ref.png
│   └── example_input.png
├── requirements.txt
├── examples
│   ├── gsplat
│   │   ├── requirements.txt
│   │   ├── datasets
│   │   │   ├── normalize.py
│   │   │   ├── traj.py
│   │   │   └── colmap.py
│   │   ├── utils.py
│   │   └── lib_bilagrid.py
│   ├── nerfstudio
│   │   ├── pyproject.toml
│   │   └── difix3d
│   │       ├── difix3d_config.py
│   │       ├── difix3d_datamanager.py
│   │       ├── difix3d_trainer.py
│   │       ├── difix3d.py
│   │       ├── difix3d_pipeline.py
│   │       └── difix3d_field.py
│   └── utils.py
├── src
│   ├── dataset.py
│   ├── loss.py
│   ├── inference_difix.py
│   ├── model.py
│   └── train_difix.py
├── README.md
└── LICENSE.txt
/assets/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nv-tlabs/Difix3D/HEAD/assets/demo.gif
--------------------------------------------------------------------------------
/assets/example_ref.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nv-tlabs/Difix3D/HEAD/assets/example_ref.png
--------------------------------------------------------------------------------
/assets/example_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nv-tlabs/Difix3D/HEAD/assets/example_input.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | torchmetrics
4 | wandb
5 | imageio[ffmpeg]
6 | einops
7 | lpips
8 | xformers
9 | peft==0.9.0
10 | diffusers==0.25.1
11 | huggingface-hub==0.25.1
12 | transformers==4.38.0
--------------------------------------------------------------------------------
/examples/gsplat/requirements.txt:
--------------------------------------------------------------------------------
1 | # assume torch is already installed
2 |
3 | # pycolmap for data parsing
4 | git+https://github.com/rmbrualla/pycolmap@cc7ea4b7301720ac29287dbe450952511b32125e
5 | # (optional) nerfacc for torch version rasterization
6 | # git+https://github.com/nerfstudio-project/nerfacc
7 |
8 | viser
9 | nerfview==0.0.2
10 | imageio[ffmpeg]
11 | numpy<2.0.0
12 | scikit-learn
13 | tqdm
14 | torchmetrics[image]
15 | opencv-python
16 | tyro>=0.8.8
17 | Pillow
18 | tensorboard
19 | tensorly
20 | pyyaml
21 | matplotlib
22 | git+https://github.com/rahul-goel/fused-ssim@1272e21a282342e89537159e4bad508b19b34157
23 |
--------------------------------------------------------------------------------
/examples/nerfstudio/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "difix3d"
3 | version = "0.1.0"
4 |
5 | dependencies=[
6 | "nerfstudio>=0.3.0",
7 | ]
8 |
9 | # black
10 | [tool.black]
11 | line-length = 120
12 |
13 | # pylint
14 | [tool.pylint.messages_control]
15 | max-line-length = 120
16 | generated-members = ["numpy.*", "torch.*", "cv2.*", "cv.*"]
17 | good-names-rgxs = "^[_a-zA-Z][_a-z0-9]?$"
18 | ignore-paths = ["scripts/colmap2nerf.py"]
19 | jobs = 0
20 | ignored-classes = ["TensorDataclass"]
21 |
22 | disable = [
23 | "duplicate-code",
24 | "fixme",
25 | "logging-fstring-interpolation",
26 | "too-many-arguments",
27 | "too-many-branches",
28 | "too-many-instance-attributes",
29 | "too-many-locals",
30 | "unnecessary-ellipsis",
31 | ]
32 |
33 | [tool.setuptools.packages.find]
34 | include = ["difix3d"]
35 |
36 | [project.entry-points.'nerfstudio.method_configs']
37 | difix3d = 'difix3d.difix3d_config:difix3d_method'
--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | from PIL import Image
4 | import torchvision.transforms.functional as F
5 |
6 |
7 | class PairedDataset(torch.utils.data.Dataset):
8 | def __init__(self, dataset_path, split, height=576, width=1024, tokenizer=None):
9 |
10 | super().__init__()
11 | with open(dataset_path, "r") as f:
12 | self.data = json.load(f)[split]
13 | self.img_ids = list(self.data.keys())
14 | self.image_size = (height, width)
15 | self.tokenizer = tokenizer
16 |
17 | def __len__(self):
18 |
19 | return len(self.img_ids)
20 |
21 | def __getitem__(self, idx):
22 |
23 | img_id = self.img_ids[idx]
24 |
25 | input_img = self.data[img_id]["image"]
26 | output_img = self.data[img_id]["target_image"]
27 | ref_img = self.data[img_id]["ref_image"] if "ref_image" in self.data[img_id] else None
28 | caption = self.data[img_id]["prompt"]
29 |
30 |         try:
31 |             input_img = Image.open(input_img).convert("RGB")
32 |             output_img = Image.open(output_img).convert("RGB")
33 |         except Exception:
34 |             print("Error loading image:", input_img, output_img)
35 |             return self.__getitem__((idx + 1) % len(self))
36 |
37 |         img_t = F.to_tensor(input_img)
38 |         img_t = F.resize(img_t, self.image_size)
39 |         img_t = F.normalize(img_t, mean=[0.5], std=[0.5])
40 | 
41 |         output_t = F.to_tensor(output_img)
42 |         output_t = F.resize(output_t, self.image_size)
43 |         output_t = F.normalize(output_t, mean=[0.5], std=[0.5])
44 | 
45 |         if ref_img is not None:
46 |             ref_img = Image.open(ref_img).convert("RGB")
47 |             ref_t = F.to_tensor(ref_img)
48 |             ref_t = F.resize(ref_t, self.image_size)
49 |             ref_t = F.normalize(ref_t, mean=[0.5], std=[0.5])
50 |
51 | img_t = torch.stack([img_t, ref_t], dim=0)
52 | output_t = torch.stack([output_t, ref_t], dim=0)
53 | else:
54 | img_t = img_t.unsqueeze(0)
55 | output_t = output_t.unsqueeze(0)
56 |
57 | out = {
58 | "output_pixel_values": output_t,
59 | "conditioning_pixel_values": img_t,
60 | "caption": caption,
61 | }
62 |
63 | if self.tokenizer is not None:
64 | input_ids = self.tokenizer(
65 | caption, max_length=self.tokenizer.model_max_length,
66 | padding="max_length", truncation=True, return_tensors="pt"
67 | ).input_ids
68 | out["input_ids"] = input_ids
69 |
70 | return out
71 |
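# Example usage (illustrative; "data/data.json" is a placeholder path following the JSON
# format described in the README):
#
#   dataset = PairedDataset("data/data.json", split="train", height=576, width=1024)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
#   batch = next(iter(loader))
#   # batch["conditioning_pixel_values"]: (B, 2, 3, 576, 1024) when a reference view is given,
#   # or (B, 1, 3, 576, 1024) without one; values are normalized to [-1, 1].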
--------------------------------------------------------------------------------
/src/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import transforms, models
3 |
4 | # Define style weights for different layers
5 | STYLE_WEIGHTS = {
6 | 'relu1_2': 1.0 / 2.6,
7 | 'relu2_2': 1.0 / 4.8,
8 | 'relu3_3': 1.0 / 3.7,
9 | 'relu4_3': 1.0 / 5.6,
10 | 'relu5_3': 10.0 / 1.5
11 | }
12 |
13 | def get_features(image, model, layers=None):
14 | """
15 | Extract features from specific layers of a model for a given image.
16 |
17 | Args:
18 | image (torch.Tensor): Input image tensor.
19 | model (torch.nn.Module): Pretrained model (e.g., VGG).
20 | layers (dict): Mapping of layer indices to layer names.
21 |
22 | Returns:
23 | dict: A dictionary of features for the specified layers.
24 | """
25 | if layers is None:
26 | layers = {
27 | '3': 'relu1_2',
28 | '8': 'relu2_2',
29 | '15': 'relu3_3',
30 | '22': 'relu4_3',
31 | '29': 'relu5_3'
32 | }
33 |
34 | features = {}
35 | x = image
36 | for name, layer in model._modules.items():
37 | x = layer(x)
38 | if name in layers:
39 | features[layers[name]] = x
40 | return features
41 |
42 | def gram_matrix(tensor):
43 | """
44 | Compute the Gram matrix for a given tensor.
45 |
46 | Args:
47 | tensor (torch.Tensor): Input tensor of shape (batch_size, depth, height, width).
48 |
49 | Returns:
50 | torch.Tensor: Gram matrix of the input tensor.
51 | """
52 | b, d, h, w = tensor.size()
53 | tensor = tensor.view(b * d, h * w) # Reshape tensor for matrix multiplication
54 | gram = torch.mm(tensor, tensor.t()) # Compute Gram matrix
55 | return gram
56 |
57 | def gram_loss(style, target, model):
58 | """
59 | Compute the Gram loss (style loss) between a style image and a target image.
60 |
61 | Args:
62 | style (torch.Tensor): Style image tensor.
63 | target (torch.Tensor): Target image tensor.
64 | model (torch.nn.Module): Pretrained model (e.g., VGG).
65 |
66 | Returns:
67 | torch.Tensor: The computed Gram loss.
68 | """
69 | # Extract features for the style and target images
70 | style_features = get_features(style, model)
71 | target_features = get_features(target, model)
72 |
73 | # Compute Gram matrices for the style image
74 | style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}
75 |
76 | # Initialize total loss
77 | total_loss = 0
78 |
79 | # Compute the weighted Gram loss for each layer
80 | for layer, weight in STYLE_WEIGHTS.items():
81 | target_feature = target_features[layer]
82 | target_gram = gram_matrix(target_feature)
83 | style_gram = style_grams[layer]
84 |
85 | # Compute the layer-specific Gram loss
86 | _, d, h, w = target_feature.shape
87 | layer_loss = weight * torch.mean((target_gram - style_gram) ** 2)
88 | total_loss += layer_loss / (d * h * w)
89 |
90 | return total_loss
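
# Example usage (illustrative sketch, not part of the training pipeline): the layer indices
# used in get_features ('3', '8', '15', '22', '29') correspond to the ReLU activations of
# torchvision's VGG16 `features` module, so a frozen VGG16 is a natural feature extractor here.
if __name__ == "__main__":
    vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features.eval()
    for p in vgg.parameters():
        p.requires_grad_(False)
    style = torch.rand(1, 3, 256, 256)   # stand-in for a ground-truth image batch
    target = torch.rand(1, 3, 256, 256)  # stand-in for a model output batch
    print("gram loss:", gram_loss(style, target, vgg).item())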
--------------------------------------------------------------------------------
/examples/nerfstudio/difix3d/difix3d_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The Nerfstudio Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from nerfstudio.cameras.camera_optimizers import CameraOptimizerConfig
17 | from nerfstudio.configs.base_config import ViewerConfig
18 | from nerfstudio.data.dataparsers.nerfstudio_dataparser import NerfstudioDataParserConfig
19 | from nerfstudio.engine.optimizers import AdamOptimizerConfig
20 | from nerfstudio.engine.schedulers import ExponentialDecaySchedulerConfig
21 | from nerfstudio.plugins.types import MethodSpecification
22 |
23 | from difix3d.difix3d_datamanager import Difix3DDataManagerConfig
24 | from difix3d.difix3d import Difix3DModelConfig
25 | from difix3d.difix3d_pipeline import Difix3DPipelineConfig
26 | from difix3d.difix3d_trainer import Difix3DTrainerConfig
27 |
28 | difix3d_method = MethodSpecification(
29 | config=Difix3DTrainerConfig(
30 | method_name="difix3d",
31 | # steps_per_eval_batch=1000,
32 | # steps_per_eval_image=100,
33 | # steps_per_save=250,
34 | # max_num_iterations=15000,
35 | # save_only_latest_checkpoint=True,
36 | # mixed_precision=False,
37 | steps_per_eval_batch=500,
38 | steps_per_save=2000,
39 | max_num_iterations=30000,
40 | mixed_precision=True,
41 | pipeline=Difix3DPipelineConfig(
42 | datamanager=Difix3DDataManagerConfig(
43 | dataparser=NerfstudioDataParserConfig(),
44 | train_num_rays_per_batch=16384,
45 | eval_num_rays_per_batch=4096,
46 | ),
47 | model=Difix3DModelConfig(
48 | eval_num_rays_per_chunk=1 << 15,
49 | average_init_density=0.01,
50 | camera_optimizer=CameraOptimizerConfig(mode="SO3xR3"),
51 | ),
52 | ),
53 | optimizers={
54 | "proposal_networks": {
55 | "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15),
56 | "scheduler": ExponentialDecaySchedulerConfig(lr_final=0.0001, max_steps=200000),
57 | },
58 | "fields": {
59 | "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15),
60 | "scheduler": ExponentialDecaySchedulerConfig(lr_final=0.0001, max_steps=200000),
61 | },
62 | "camera_opt": {
63 | "optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15),
64 | "scheduler": ExponentialDecaySchedulerConfig(lr_final=1e-4, max_steps=5000),
65 | },
66 | },
67 | viewer=ViewerConfig(num_rays_per_chunk=1 << 15),
68 | vis="viewer",
69 | ),
70 | description="Difix3D",
71 | )
--------------------------------------------------------------------------------
/examples/nerfstudio/difix3d/difix3d_datamanager.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The Nerfstudio Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from __future__ import annotations
17 |
18 | from dataclasses import dataclass, field
19 | from typing import Dict, Tuple, Type
20 |
21 | from rich.progress import Console
22 |
23 | from nerfstudio.cameras.rays import RayBundle
24 | from nerfstudio.data.utils.dataloaders import CacheDataloader
25 | from nerfstudio.model_components.ray_generators import RayGenerator
26 | from nerfstudio.data.datamanagers.base_datamanager import (
27 | VanillaDataManager,
28 | VanillaDataManagerConfig,
29 | )
30 |
31 | CONSOLE = Console(width=120)
32 |
33 | @dataclass
34 | class Difix3DDataManagerConfig(VanillaDataManagerConfig):
35 | """Configuration for the Difix3DDataManager."""
36 |
37 | _target: Type = field(default_factory=lambda: Difix3DDataManager)
38 | patch_size: int = 32
39 | """Size of patch to sample from. If >1, patch-based sampling will be used."""
40 |
41 |
42 | class Difix3DDataManager(VanillaDataManager):
43 | """Data manager for Difix3D."""
44 |
45 | config: Difix3DDataManagerConfig
46 |
47 | def setup_train(self):
48 | """Sets up the data loaders for training"""
49 | assert self.train_dataset is not None
50 | CONSOLE.print("Setting up training dataset...")
51 | self.train_image_dataloader = CacheDataloader(
52 | self.train_dataset,
53 | num_images_to_sample_from=self.config.train_num_images_to_sample_from,
54 | num_times_to_repeat_images=self.config.train_num_times_to_repeat_images,
55 | device=self.device,
56 | num_workers=self.world_size * 4,
57 | pin_memory=True,
58 | collate_fn=self.config.collate_fn,
59 | exclude_batch_keys_from_device=self.exclude_batch_keys_from_device,
60 | )
61 | self.iter_train_image_dataloader = iter(self.train_image_dataloader)
62 | self.train_pixel_sampler = self._get_pixel_sampler(self.train_dataset, self.config.train_num_rays_per_batch)
63 | self.train_ray_generator = RayGenerator(self.train_dataset.cameras.to(self.device))
64 |
65 | def next_train(self, step: int) -> Tuple[RayBundle, Dict]:
66 | """Returns the next batch of data from the train dataloader."""
67 | self.train_count += 1
68 | image_batch = next(self.iter_train_image_dataloader)
69 | assert self.train_pixel_sampler is not None
70 | assert isinstance(image_batch, dict)
71 | batch = self.train_pixel_sampler.sample(image_batch)
72 | ray_indices = batch["indices"]
73 | ray_bundle = self.train_ray_generator(ray_indices)
74 | return ray_bundle, batch
75 |
--------------------------------------------------------------------------------
/src/inference_difix.py:
--------------------------------------------------------------------------------
1 | import os
2 | import imageio
3 | import argparse
4 | import numpy as np
5 | from PIL import Image
6 | from glob import glob
7 | from tqdm import tqdm
8 | from model import Difix
9 |
10 |
11 | if __name__ == "__main__":
12 | # Argument parser
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--input_image', type=str, required=True, help='Path to the input image or directory')
15 | parser.add_argument('--ref_image', type=str, default=None, help='Path to the reference image or directory')
16 | parser.add_argument('--height', type=int, default=576, help='Height of the input image')
17 | parser.add_argument('--width', type=int, default=1024, help='Width of the input image')
18 | parser.add_argument('--prompt', type=str, required=True, help='The prompt to be used')
19 | parser.add_argument('--model_name', type=str, default=None, help='Name of the pretrained model to be used')
20 | parser.add_argument('--model_path', type=str, default=None, help='Path to a model state dict to be used')
21 | parser.add_argument('--output_dir', type=str, default='output', help='Directory to save the output')
22 | parser.add_argument('--seed', type=int, default=42, help='Random seed to be used')
23 | parser.add_argument('--timestep', type=int, default=199, help='Diffusion timestep')
24 | parser.add_argument('--video', action='store_true', help='If the input is a video')
25 | args = parser.parse_args()
26 |
27 | # Create output directory
28 | os.makedirs(args.output_dir, exist_ok=True)
29 |
30 | # Initialize the model
31 | model = Difix(
32 | pretrained_name=args.model_name,
33 | pretrained_path=args.model_path,
34 | timestep=args.timestep,
35 | mv_unet=True if args.ref_image is not None else False,
36 | )
37 | model.set_eval()
38 |
39 | # Load input images
40 | if os.path.isdir(args.input_image):
41 | input_images = sorted(glob(os.path.join(args.input_image, "*.png")))
42 | else:
43 | input_images = [args.input_image]
44 |
45 | # Load reference images if provided
46 | if args.ref_image is not None:
47 | if os.path.isdir(args.ref_image):
48 | ref_images = sorted(glob(os.path.join(args.ref_image, "*")))
49 | else:
50 | ref_images = [args.ref_image]
51 |
52 | assert len(input_images) == len(ref_images), "Number of input images and reference images should be the same"
53 |
54 | # Process images
55 | output_images = []
56 | for i, input_image in enumerate(tqdm(input_images, desc="Processing images")):
57 | image = Image.open(input_image).convert('RGB')
58 | ref_image = Image.open(ref_images[i]).convert('RGB') if args.ref_image is not None else None
59 | output_image = model.sample(
60 | image,
61 | height=args.height,
62 | width=args.width,
63 | ref_image=ref_image,
64 | prompt=args.prompt
65 | )
66 | output_images.append(output_image)
67 |
68 | # Save outputs
69 | if args.video:
70 | # Save as video
71 | video_path = os.path.join(args.output_dir, "output.mp4")
72 | writer = imageio.get_writer(video_path, fps=30)
73 | for output_image in tqdm(output_images, desc="Saving video"):
74 | writer.append_data(np.array(output_image))
75 | writer.close()
76 | else:
77 | # Save as individual images
78 | for i, output_image in enumerate(tqdm(output_images, desc="Saving images")):
79 | output_image.save(os.path.join(args.output_dir, os.path.basename(input_images[i])))
--------------------------------------------------------------------------------
/examples/nerfstudio/difix3d/difix3d_trainer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The Nerfstudio Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import dataclasses
17 | from dataclasses import dataclass, field
18 | from typing import Type, Literal
19 | from nerfstudio.engine.trainer import Trainer, TrainerConfig
20 | from nerfstudio.engine.callbacks import TrainingCallbackAttributes
21 | from nerfstudio.viewer.viewer import Viewer as ViewerState
21 | from nerfstudio.viewer_legacy.server.viewer_state import ViewerLegacyState  # referenced in setup() below
22 | from nerfstudio.utils import profiler, writer
23 |
24 |
25 | @dataclass
26 | class Difix3DTrainerConfig(TrainerConfig):
27 | """Configuration for the Difix3DTrainer."""
28 | _target: Type = field(default_factory=lambda: Difix3DTrainer)
29 |
30 |
31 | class Difix3DTrainer(Trainer):
32 | """Trainer for Difix3D"""
33 |
34 | def __init__(self, config: TrainerConfig, local_rank: int = 0, world_size: int = 1) -> None:
35 |
36 | super().__init__(config, local_rank, world_size)
37 |
38 | def setup(self, test_mode: Literal["test", "val", "inference"] = "val") -> None:
39 | """Setup the Trainer by calling other setup functions.
40 |
41 | Args:
42 | test_mode:
43 | 'val': loads train/val datasets into memory
44 | 'test': loads train/test datasets into memory
45 | 'inference': does not load any dataset into memory
46 | """
47 | self.pipeline = self.config.pipeline.setup(
48 | device=self.device,
49 | test_mode=test_mode,
50 | world_size=self.world_size,
51 | local_rank=self.local_rank,
52 | grad_scaler=self.grad_scaler,
53 | render_dir=self.base_dir / "renders",
54 | )
55 | self.optimizers = self.setup_optimizers()
56 |
57 | # set up viewer if enabled
58 | viewer_log_path = self.base_dir / self.config.viewer.relative_log_filename
59 | self.viewer_state, banner_messages = None, None
60 | if self.config.is_viewer_legacy_enabled() and self.local_rank == 0:
61 | datapath = self.config.data
62 | if datapath is None:
63 | datapath = self.base_dir
64 | self.viewer_state = ViewerLegacyState(
65 | self.config.viewer,
66 | log_filename=viewer_log_path,
67 | datapath=datapath,
68 | pipeline=self.pipeline,
69 | trainer=self,
70 | train_lock=self.train_lock,
71 | )
72 | banner_messages = [f"Legacy viewer at: {self.viewer_state.viewer_url}"]
73 | if self.config.is_viewer_enabled() and self.local_rank == 0:
74 | datapath = self.config.data
75 | if datapath is None:
76 | datapath = self.base_dir
77 | self.viewer_state = ViewerState(
78 | self.config.viewer,
79 | log_filename=viewer_log_path,
80 | datapath=datapath,
81 | pipeline=self.pipeline,
82 | trainer=self,
83 | train_lock=self.train_lock,
84 | share=self.config.viewer.make_share_url,
85 | )
86 | banner_messages = self.viewer_state.viewer_info
87 | self._check_viewer_warnings()
88 |
89 | self._load_checkpoint()
90 |
91 | self.callbacks = self.pipeline.get_training_callbacks(
92 | TrainingCallbackAttributes(
93 | optimizers=self.optimizers, grad_scaler=self.grad_scaler, pipeline=self.pipeline, trainer=self
94 | )
95 | )
96 |
97 | # set up writers/profilers if enabled
98 | writer_log_path = self.base_dir / self.config.logging.relative_log_dir
99 | writer.setup_event_writer(
100 | self.config.is_wandb_enabled(),
101 | self.config.is_tensorboard_enabled(),
102 | self.config.is_comet_enabled(),
103 | log_dir=writer_log_path,
104 | experiment_name=self.config.experiment_name,
105 | project_name=self.config.project_name,
106 | )
107 | writer.setup_local_writer(
108 | self.config.logging, max_iter=self.config.max_num_iterations, banner_messages=banner_messages
109 | )
110 | writer.put_config(name="config", config_dict=dataclasses.asdict(self.config), step=0)
111 | profiler.setup_profiler(self.config.logging, writer_log_path)
--------------------------------------------------------------------------------
/examples/nerfstudio/difix3d/difix3d.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The Nerfstudio Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from __future__ import annotations
17 |
18 | from dataclasses import dataclass, field
19 | from typing import Type
20 |
21 | import torch
22 | from nerfstudio.models.nerfacto import NerfactoModel, NerfactoModelConfig
23 | from difix3d.difix3d_field import Difix3DField
24 | from nerfstudio.cameras.camera_optimizers import CameraOptimizer
25 | from nerfstudio.field_components.spatial_distortions import SceneContraction
26 | from nerfstudio.fields.density_fields import HashMLPDensityField
27 |
28 |
29 | @dataclass
30 | class Difix3DModelConfig(NerfactoModelConfig):
31 | """Configuration for the Difix3DModel."""
32 | _target: Type = field(default_factory=lambda: Difix3DModel)
33 |
34 | disable_camera_optimizer: bool = False
35 | """Whether to disable the camera optimizer or not."""
36 | freeze_appearance_embedding: bool = False
37 |
38 | class Difix3DModel(NerfactoModel):
39 | """Model for Difix3D."""
40 |
41 | config: Difix3DModelConfig
42 |
43 | def populate_modules(self):
44 | """Set the fields and modules."""
45 | super().populate_modules()
46 |
47 | if self.config.disable_scene_contraction:
48 | scene_contraction = None
49 | else:
50 | scene_contraction = SceneContraction(order=float("inf"))
51 |
52 | appearance_embedding_dim = self.config.appearance_embed_dim if self.config.use_appearance_embedding else 0
53 |
54 | # Fields
55 | self.field = Difix3DField(
56 | self.scene_box.aabb,
57 | hidden_dim=self.config.hidden_dim,
58 | num_levels=self.config.num_levels,
59 | max_res=self.config.max_res,
60 | base_res=self.config.base_res,
61 | features_per_level=self.config.features_per_level,
62 | log2_hashmap_size=self.config.log2_hashmap_size,
63 | hidden_dim_color=self.config.hidden_dim_color,
64 | hidden_dim_transient=self.config.hidden_dim_transient,
65 | spatial_distortion=scene_contraction,
66 | num_images=self.num_train_data,
67 | use_pred_normals=self.config.predict_normals,
68 | use_average_appearance_embedding=self.config.use_average_appearance_embedding,
69 | appearance_embedding_dim=appearance_embedding_dim,
70 | average_init_density=self.config.average_init_density,
71 | implementation=self.config.implementation,
72 | freeze_appearance_embedding=self.config.freeze_appearance_embedding,
73 | )
74 |
75 | self.camera_optimizer: CameraOptimizer = self.config.camera_optimizer.setup(
76 | num_cameras=self.num_train_data, device="cpu"
77 | )
78 | if self.config.disable_camera_optimizer:
79 | self.camera_optimizer.mode = "off"
80 | self.density_fns = []
81 | num_prop_nets = self.config.num_proposal_iterations
82 | # Build the proposal network(s)
83 | self.proposal_networks = torch.nn.ModuleList()
84 | if self.config.use_same_proposal_network:
85 | assert len(self.config.proposal_net_args_list) == 1, "Only one proposal network is allowed."
86 | prop_net_args = self.config.proposal_net_args_list[0]
87 | network = HashMLPDensityField(
88 | self.scene_box.aabb,
89 | spatial_distortion=scene_contraction,
90 | **prop_net_args,
91 | average_init_density=self.config.average_init_density,
92 | implementation=self.config.implementation,
93 | )
94 | self.proposal_networks.append(network)
95 | self.density_fns.extend([network.density_fn for _ in range(num_prop_nets)])
96 | else:
97 | for i in range(num_prop_nets):
98 | prop_net_args = self.config.proposal_net_args_list[min(i, len(self.config.proposal_net_args_list) - 1)]
99 | network = HashMLPDensityField(
100 | self.scene_box.aabb,
101 | spatial_distortion=scene_contraction,
102 | **prop_net_args,
103 | average_init_density=self.config.average_init_density,
104 | implementation=self.config.implementation,
105 | )
106 | self.proposal_networks.append(network)
107 | self.density_fns.extend([network.density_fn for network in self.proposal_networks])
108 |
--------------------------------------------------------------------------------
/examples/gsplat/datasets/normalize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def similarity_from_cameras(c2w, strict_scaling=False, center_method="focus"):
5 | """
6 | reference: nerf-factory
7 | Get a similarity transform to normalize dataset
8 | from c2w (OpenCV convention) cameras
9 |     :param c2w: (N, 4, 4) camera-to-world matrices
10 |     :return: transform (4, 4) similarity transform (the scale is folded into the matrix)
11 | """
12 | t = c2w[:, :3, 3]
13 | R = c2w[:, :3, :3]
14 |
15 | # (1) Rotate the world so that z+ is the up axis
16 | # we estimate the up axis by averaging the camera up axes
17 | ups = np.sum(R * np.array([0, -1.0, 0]), axis=-1)
18 | world_up = np.mean(ups, axis=0)
19 | world_up /= np.linalg.norm(world_up)
20 |
21 | up_camspace = np.array([0.0, -1.0, 0.0])
22 | c = (up_camspace * world_up).sum()
23 | cross = np.cross(world_up, up_camspace)
24 | skew = np.array(
25 | [
26 | [0.0, -cross[2], cross[1]],
27 | [cross[2], 0.0, -cross[0]],
28 | [-cross[1], cross[0], 0.0],
29 | ]
30 | )
31 | if c > -1:
32 | R_align = np.eye(3) + skew + (skew @ skew) * 1 / (1 + c)
33 | else:
34 | # In the unlikely case the original data has y+ up axis,
35 | # rotate 180-deg about x axis
36 | R_align = np.array([[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
37 |
38 | # R_align = np.eye(3) # DEBUG
39 | R = R_align @ R
40 | fwds = np.sum(R * np.array([0, 0.0, 1.0]), axis=-1)
41 | t = (R_align @ t[..., None])[..., 0]
42 |
43 | # (2) Recenter the scene.
44 | if center_method == "focus":
45 | # find the closest point to the origin for each camera's center ray
46 | nearest = t + (fwds * -t).sum(-1)[:, None] * fwds
47 | translate = -np.median(nearest, axis=0)
48 | elif center_method == "poses":
49 | # use center of the camera positions
50 | translate = -np.median(t, axis=0)
51 | else:
52 | raise ValueError(f"Unknown center_method {center_method}")
53 |
54 | transform = np.eye(4)
55 | transform[:3, 3] = translate
56 | transform[:3, :3] = R_align
57 |
58 | # (3) Rescale the scene using camera distances
59 | scale_fn = np.max if strict_scaling else np.median
60 | scale = 1.0 / scale_fn(np.linalg.norm(t + translate, axis=-1))
61 | transform[:3, :] *= scale
62 |
63 | return transform
64 |
65 |
66 | def align_principle_axes(point_cloud):
67 | # Compute centroid
68 | centroid = np.median(point_cloud, axis=0)
69 |
70 | # Translate point cloud to centroid
71 | translated_point_cloud = point_cloud - centroid
72 |
73 | # Compute covariance matrix
74 | covariance_matrix = np.cov(translated_point_cloud, rowvar=False)
75 |
76 | # Compute eigenvectors and eigenvalues
77 | eigenvalues, eigenvectors = np.linalg.eigh(covariance_matrix)
78 |
79 | # Sort eigenvectors by eigenvalues (descending order) so that the z-axis
80 | # is the principal axis with the smallest eigenvalue.
81 | sort_indices = eigenvalues.argsort()[::-1]
82 | eigenvectors = eigenvectors[:, sort_indices]
83 |
84 | # Check orientation of eigenvectors. If the determinant of the eigenvectors is
85 | # negative, then we need to flip the sign of one of the eigenvectors.
86 | if np.linalg.det(eigenvectors) < 0:
87 | eigenvectors[:, 0] *= -1
88 |
89 | # Create rotation matrix
90 | rotation_matrix = eigenvectors.T
91 |
92 | # Create SE(3) matrix (4x4 transformation matrix)
93 | transform = np.eye(4)
94 | transform[:3, :3] = rotation_matrix
95 | transform[:3, 3] = -rotation_matrix @ centroid
96 |
97 | return transform
98 |
99 |
100 | def transform_points(matrix, points):
101 | """Transform points using an SE(3) matrix.
102 |
103 | Args:
104 | matrix: 4x4 SE(3) matrix
105 | points: Nx3 array of points
106 |
107 | Returns:
108 | Nx3 array of transformed points
109 | """
110 | assert matrix.shape == (4, 4)
111 | assert len(points.shape) == 2 and points.shape[1] == 3
112 | return points @ matrix[:3, :3].T + matrix[:3, 3]
113 |
114 |
115 | def transform_cameras(matrix, camtoworlds):
116 | """Transform cameras using an SE(3) matrix.
117 |
118 | Args:
119 | matrix: 4x4 SE(3) matrix
120 | camtoworlds: Nx4x4 array of camera-to-world matrices
121 |
122 | Returns:
123 | Nx4x4 array of transformed camera-to-world matrices
124 | """
125 | assert matrix.shape == (4, 4)
126 | assert len(camtoworlds.shape) == 3 and camtoworlds.shape[1:] == (4, 4)
127 | camtoworlds = np.einsum("nij, ki -> nkj", camtoworlds, matrix)
128 | scaling = np.linalg.norm(camtoworlds[:, 0, :3], axis=1)
129 | camtoworlds[:, :3, :3] = camtoworlds[:, :3, :3] / scaling[:, None, None]
130 | return camtoworlds
131 |
132 |
133 | def normalize(camtoworlds, points=None):
134 | T1 = similarity_from_cameras(camtoworlds)
135 | camtoworlds = transform_cameras(T1, camtoworlds)
136 | if points is not None:
137 | points = transform_points(T1, points)
138 | T2 = align_principle_axes(points)
139 | camtoworlds = transform_cameras(T2, camtoworlds)
140 | points = transform_points(T2, points)
141 | return camtoworlds, points, T2 @ T1
142 | else:
143 | return camtoworlds, T1
144 |
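# Example usage (illustrative sketch with synthetic data):
#
#   c2w = np.tile(np.eye(4), (10, 1, 1))
#   c2w[:, :3, 3] = np.random.randn(10, 3)       # random camera centers
#   points = np.random.randn(1000, 3)
#   c2w_n, points_n, T = normalize(c2w, points)  # T = T2 @ T1 maps the original world to the normalized one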
--------------------------------------------------------------------------------
/examples/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.spatial.transform import Rotation
3 |
4 |
5 | class CameraPoseInterpolator:
6 | """
7 | A system for interpolating between sets of camera poses with visualization capabilities.
8 | """
9 |
10 | def __init__(self, rotation_weight=1.0, translation_weight=1.0):
11 | """
12 | Initialize the interpolator with weights for pose distance computation.
13 |
14 | Args:
15 | rotation_weight: Weight for rotational distance in pose matching
16 | translation_weight: Weight for translational distance in pose matching
17 | """
18 | self.rotation_weight = rotation_weight
19 | self.translation_weight = translation_weight
20 |
21 | def compute_pose_distance(self, pose1, pose2):
22 | """
23 | Compute weighted distance between two camera poses.
24 |
25 | Args:
26 | pose1, pose2: 4x4 transformation matrices
27 |
28 | Returns:
29 | Combined weighted distance between poses
30 | """
31 | # Translation distance (Euclidean)
32 | t1, t2 = pose1[:3, 3], pose2[:3, 3]
33 | translation_dist = np.linalg.norm(t1 - t2)
34 |
35 | # Rotation distance (angular distance between quaternions)
36 | R1 = Rotation.from_matrix(pose1[:3, :3])
37 | R2 = Rotation.from_matrix(pose2[:3, :3])
38 | q1 = R1.as_quat()
39 | q2 = R2.as_quat()
40 |
41 | # Ensure quaternions are in the same hemisphere
42 | if np.dot(q1, q2) < 0:
43 | q2 = -q2
44 |
45 |         rotation_dist = np.arccos(np.clip(2 * np.dot(q1, q2)**2 - 1, -1.0, 1.0))
46 |
47 | return (self.translation_weight * translation_dist +
48 | self.rotation_weight * rotation_dist)
49 |
50 | def find_nearest_assignments(self, training_poses, testing_poses):
51 | """
52 | Find the nearest training camera pose for each testing camera pose.
53 |
54 | Args:
55 | training_poses: [N, 4, 4] array of training camera poses
56 | testing_poses: [M, 4, 4] array of testing camera poses
57 |
58 | Returns:
59 | assignments: list of closest training pose indices for each testing pose
60 | """
61 | M = len(testing_poses)
62 | assignments = []
63 |
64 | for j in range(M):
65 | # Compute distance from each training pose to this testing pose
66 | distances = [self.compute_pose_distance(training_pose, testing_poses[j])
67 | for training_pose in training_poses]
68 | # Find the index of the nearest training pose
69 | nearest_index = np.argmin(distances)
70 | assignments.append(nearest_index)
71 |
72 | return assignments
73 |
74 | def interpolate_rotation(self, R1, R2, t):
75 | """
76 | Interpolate between two rotation matrices using SLERP.
77 | """
78 | q1 = Rotation.from_matrix(R1).as_quat()
79 | q2 = Rotation.from_matrix(R2).as_quat()
80 |
81 | if np.dot(q1, q2) < 0:
82 | q2 = -q2
83 |
84 | # Clamp dot product to avoid invalid values in arccos
85 | dot_product = np.clip(np.dot(q1, q2), -1.0, 1.0)
86 | theta = np.arccos(dot_product)
87 |
88 | if np.abs(theta) < 1e-6:
89 | q_interp = (1 - t) * q1 + t * q2
90 | else:
91 | q_interp = (np.sin((1-t)*theta) * q1 + np.sin(t*theta) * q2) / np.sin(theta)
92 |
93 | q_interp = q_interp / np.linalg.norm(q_interp)
94 | return Rotation.from_quat(q_interp).as_matrix()
95 |
96 | def interpolate_poses(self, training_poses, testing_poses, num_steps=20):
97 | """
98 | Interpolate between camera poses using nearest assignments.
99 |
100 | Args:
101 | training_poses: [N, 4, 4] array of training poses
102 | testing_poses: [M, 4, 4] array of testing poses
103 | num_steps: number of interpolation steps
104 |
105 | Returns:
106 | interpolated_sequences: list of lists of interpolated poses
107 | """
108 | assignments = self.find_nearest_assignments(training_poses, testing_poses)
109 | interpolated_sequences = []
110 |
111 | for test_idx, train_idx in enumerate(assignments):
112 | train_pose = training_poses[train_idx]
113 | test_pose = testing_poses[test_idx]
114 | sequence = []
115 |
116 | for t in np.linspace(0, 1, num_steps):
117 | # Interpolate rotation
118 | R_interp = self.interpolate_rotation(
119 | train_pose[:3, :3],
120 | test_pose[:3, :3],
121 | t
122 | )
123 |
124 | # Interpolate translation
125 | t_interp = (1-t) * train_pose[:3, 3] + t * test_pose[:3, 3]
126 |
127 | # Construct interpolated pose
128 | pose_interp = np.eye(4)
129 | pose_interp[:3, :3] = R_interp
130 | pose_interp[:3, 3] = t_interp
131 |
132 | sequence.append(pose_interp)
133 |
134 | interpolated_sequences.append(sequence)
135 |
136 | return interpolated_sequences
137 |
138 |
139 | def shift_poses(self, training_poses, testing_poses, distance=0.1, threshold=0.1):
140 | """
141 | Shift nearest training poses toward testing poses by a specified distance.
142 |
143 | Args:
144 | training_poses: [N, 4, 4] array of training camera poses
145 | testing_poses: [M, 4, 4] array of testing camera poses
146 | distance: float, the step size to move training pose toward testing pose
147 |
148 | Returns:
149 | novel_poses: [M, 4, 4] array of shifted poses
150 | """
151 | assignments = self.find_nearest_assignments(training_poses, testing_poses)
152 | novel_poses = []
153 |
154 | for test_idx, train_idx in enumerate(assignments):
155 | train_pose = training_poses[train_idx]
156 | test_pose = testing_poses[test_idx]
157 |
158 | if self.compute_pose_distance(train_pose, test_pose) <= distance:
159 | novel_poses.append(test_pose)
160 | continue
161 |
162 | # Calculate translation step if shifting is necessary
163 | t1, t2 = train_pose[:3, 3], test_pose[:3, 3]
164 | translation_direction = t2 - t1
165 | translation_norm = np.linalg.norm(translation_direction)
166 |
167 | if translation_norm > 1e-6:
168 | translation_step = (translation_direction / translation_norm) * distance
169 | new_translation = t1 + translation_step
170 | else:
171 | # If translation direction is too small, use testing pose translation directly
172 | new_translation = t2
173 |
174 | # Check if the new translation would overshoot the testing pose translation
175 | if np.dot(new_translation - t1, t2 - t1) <= 0 or np.linalg.norm(new_translation - t2) <= distance:
176 | new_translation = t2
177 |
178 | # Update rotation
179 | R1 = train_pose[:3, :3]
180 | R2 = test_pose[:3, :3]
181 | if translation_norm > 1e-6:
182 | R_interp = self.interpolate_rotation(R1, R2, min(distance / translation_norm, 1.0))
183 | else:
184 | R_interp = R2 # Use testing rotation if too close
185 |
186 | # Construct shifted pose
187 | shifted_pose = np.eye(4)
188 | shifted_pose[:3, :3] = R_interp
189 | shifted_pose[:3, 3] = new_translation
190 |
191 | novel_poses.append(shifted_pose)
192 |
193 | return np.array(novel_poses)
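
# Example usage (illustrative sketch with synthetic poses; not part of the pipeline code):
if __name__ == "__main__":
    # Build a few random but valid SE(3) poses.
    def random_pose(seed):
        pose = np.eye(4)
        pose[:3, :3] = Rotation.random(random_state=seed).as_matrix()
        pose[:3, 3] = np.random.default_rng(seed).normal(size=3)
        return pose

    training_poses = np.stack([random_pose(i) for i in range(5)])
    testing_poses = np.stack([random_pose(100 + i) for i in range(3)])

    interpolator = CameraPoseInterpolator(rotation_weight=1.0, translation_weight=1.0)
    sequences = interpolator.interpolate_poses(training_poses, testing_poses, num_steps=10)
    print(len(sequences), len(sequences[0]))  # 3 sequences, 10 poses each

    shifted = interpolator.shift_poses(training_poses, testing_poses, distance=0.1)
    print(shifted.shape)  # (3, 4, 4)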
--------------------------------------------------------------------------------
/examples/gsplat/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import torch
5 | from sklearn.neighbors import NearestNeighbors
6 | from torch import Tensor
7 | import torch.nn.functional as F
8 | import matplotlib.pyplot as plt
9 | from matplotlib import colormaps
10 |
11 |
12 | class CameraOptModule(torch.nn.Module):
13 | """Camera pose optimization module."""
14 |
15 | def __init__(self, n: int):
16 | super().__init__()
17 | # Delta positions (3D) + Delta rotations (6D)
18 | self.embeds = torch.nn.Embedding(n, 9)
19 | # Identity rotation in 6D representation
20 | self.register_buffer("identity", torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]))
21 |
22 | def zero_init(self):
23 | torch.nn.init.zeros_(self.embeds.weight)
24 |
25 | def random_init(self, std: float):
26 | torch.nn.init.normal_(self.embeds.weight, std=std)
27 |
28 | def forward(self, camtoworlds: Tensor, embed_ids: Tensor) -> Tensor:
29 | """Adjust camera pose based on deltas.
30 |
31 | Args:
32 | camtoworlds: (..., 4, 4)
33 | embed_ids: (...,)
34 |
35 | Returns:
36 | updated camtoworlds: (..., 4, 4)
37 | """
38 | assert camtoworlds.shape[:-2] == embed_ids.shape
39 | batch_shape = camtoworlds.shape[:-2]
40 | pose_deltas = self.embeds(embed_ids) # (..., 9)
41 | dx, drot = pose_deltas[..., :3], pose_deltas[..., 3:]
42 | rot = rotation_6d_to_matrix(
43 | drot + self.identity.expand(*batch_shape, -1)
44 | ) # (..., 3, 3)
45 | transform = torch.eye(4, device=pose_deltas.device).repeat((*batch_shape, 1, 1))
46 | transform[..., :3, :3] = rot
47 | transform[..., :3, 3] = dx
48 | return torch.matmul(camtoworlds, transform)
49 |
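# Example usage of CameraOptModule (illustrative sketch): zero-initialized deltas produce the
# identity adjustment, so the module can be enabled from the very start of training.
#
#   pose_adjust = CameraOptModule(n=100)
#   pose_adjust.zero_init()
#   c2w = torch.eye(4).expand(100, 4, 4)            # (N, 4, 4) camera-to-world matrices
#   adjusted = pose_adjust(c2w, torch.arange(100))  # equal to c2w after zero_init()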
50 |
51 | class AppearanceOptModule(torch.nn.Module):
52 | """Appearance optimization module."""
53 |
54 | def __init__(
55 | self,
56 | n: int,
57 | feature_dim: int,
58 | embed_dim: int = 16,
59 | sh_degree: int = 3,
60 | mlp_width: int = 64,
61 | mlp_depth: int = 2,
62 | ):
63 | super().__init__()
64 | self.embed_dim = embed_dim
65 | self.sh_degree = sh_degree
66 | self.embeds = torch.nn.Embedding(n, embed_dim)
67 | layers = []
68 | layers.append(
69 | torch.nn.Linear(embed_dim + feature_dim + (sh_degree + 1) ** 2, mlp_width)
70 | )
71 | layers.append(torch.nn.ReLU(inplace=True))
72 | for _ in range(mlp_depth - 1):
73 | layers.append(torch.nn.Linear(mlp_width, mlp_width))
74 | layers.append(torch.nn.ReLU(inplace=True))
75 | layers.append(torch.nn.Linear(mlp_width, 3))
76 | self.color_head = torch.nn.Sequential(*layers)
77 |
78 | def forward(
79 | self, features: Tensor, embed_ids: Tensor, dirs: Tensor, sh_degree: int
80 | ) -> Tensor:
81 | """Adjust appearance based on embeddings.
82 |
83 | Args:
84 | features: (N, feature_dim)
85 | embed_ids: (C,)
86 | dirs: (C, N, 3)
87 |
88 | Returns:
89 | colors: (C, N, 3)
90 | """
91 | from gsplat.cuda._torch_impl import _eval_sh_bases_fast
92 |
93 | C, N = dirs.shape[:2]
94 | # Camera embeddings
95 | if embed_ids is None:
96 | embeds = torch.zeros(C, self.embed_dim, device=features.device)
97 | else:
98 | embeds = self.embeds(embed_ids) # [C, D2]
99 | embeds = embeds[:, None, :].expand(-1, N, -1) # [C, N, D2]
100 | # GS features
101 | features = features[None, :, :].expand(C, -1, -1) # [C, N, D1]
102 | # View directions
103 | dirs = F.normalize(dirs, dim=-1) # [C, N, 3]
104 | num_bases_to_use = (sh_degree + 1) ** 2
105 | num_bases = (self.sh_degree + 1) ** 2
106 | sh_bases = torch.zeros(C, N, num_bases, device=features.device) # [C, N, K]
107 | sh_bases[:, :, :num_bases_to_use] = _eval_sh_bases_fast(num_bases_to_use, dirs)
108 | # Get colors
109 | if self.embed_dim > 0:
110 | h = torch.cat([embeds, features, sh_bases], dim=-1) # [C, N, D1 + D2 + K]
111 | else:
112 | h = torch.cat([features, sh_bases], dim=-1)
113 | colors = self.color_head(h)
114 | return colors
115 |
116 |
117 | def rotation_6d_to_matrix(d6: Tensor) -> Tensor:
118 | """
119 | Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
120 | using Gram--Schmidt orthogonalization per Section B of [1]. Adapted from pytorch3d.
121 | Args:
122 | d6: 6D rotation representation, of size (*, 6)
123 |
124 | Returns:
125 | batch of rotation matrices of size (*, 3, 3)
126 |
127 | [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
128 | On the Continuity of Rotation Representations in Neural Networks.
129 | IEEE Conference on Computer Vision and Pattern Recognition, 2019.
130 | Retrieved from http://arxiv.org/abs/1812.07035
131 | """
132 |
133 | a1, a2 = d6[..., :3], d6[..., 3:]
134 | b1 = F.normalize(a1, dim=-1)
135 | b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
136 | b2 = F.normalize(b2, dim=-1)
137 | b3 = torch.cross(b1, b2, dim=-1)
138 | return torch.stack((b1, b2, b3), dim=-2)
139 |
140 |
141 | def knn(x: Tensor, K: int = 4) -> Tensor:
142 | x_np = x.cpu().numpy()
143 | model = NearestNeighbors(n_neighbors=K, metric="euclidean").fit(x_np)
144 | distances, _ = model.kneighbors(x_np)
145 | return torch.from_numpy(distances).to(x)
146 |
147 |
148 | def rgb_to_sh(rgb: Tensor) -> Tensor:
149 | C0 = 0.28209479177387814
150 | return (rgb - 0.5) / C0
151 |
152 |
153 | def set_random_seed(seed: int):
154 | random.seed(seed)
155 | np.random.seed(seed)
156 | torch.manual_seed(seed)
157 |
158 |
159 | # ref: https://github.com/hbb1/2d-gaussian-splatting/blob/main/utils/general_utils.py#L163
160 | def colormap(img, cmap="jet"):
161 | W, H = img.shape[:2]
162 | dpi = 300
163 | fig, ax = plt.subplots(1, figsize=(H / dpi, W / dpi), dpi=dpi)
164 | im = ax.imshow(img, cmap=cmap)
165 | ax.set_axis_off()
166 | fig.colorbar(im, ax=ax)
167 | fig.tight_layout()
168 | fig.canvas.draw()
169 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
170 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
171 | img = torch.from_numpy(data).float().permute(2, 0, 1)
172 | plt.close()
173 | return img
174 |
175 |
176 | def apply_float_colormap(img: torch.Tensor, colormap: str = "turbo") -> torch.Tensor:
177 | """Convert single channel to a color img.
178 |
179 | Args:
180 | img (torch.Tensor): (..., 1) float32 single channel image.
181 | colormap (str): Colormap for img.
182 |
183 | Returns:
184 | (..., 3) colored img with colors in [0, 1].
185 | """
186 | img = torch.nan_to_num(img, 0)
187 | if colormap == "gray":
188 | return img.repeat(1, 1, 3)
189 | img_long = (img * 255).long()
190 | img_long_min = torch.min(img_long)
191 | img_long_max = torch.max(img_long)
192 | assert img_long_min >= 0, f"the min value is {img_long_min}"
193 | assert img_long_max <= 255, f"the max value is {img_long_max}"
194 | return torch.tensor(
195 | colormaps[colormap].colors, # type: ignore
196 | device=img.device,
197 | )[img_long[..., 0]]
198 |
199 |
200 | def apply_depth_colormap(
201 | depth: torch.Tensor,
202 | acc: torch.Tensor = None,
203 | near_plane: float = None,
204 | far_plane: float = None,
205 | ) -> torch.Tensor:
206 | """Converts a depth image to color for easier analysis.
207 |
208 | Args:
209 | depth (torch.Tensor): (..., 1) float32 depth.
210 | acc (torch.Tensor | None): (..., 1) optional accumulation mask.
211 | near_plane: Closest depth to consider. If None, use min image value.
212 | far_plane: Furthest depth to consider. If None, use max image value.
213 |
214 | Returns:
215 | (..., 3) colored depth image with colors in [0, 1].
216 | """
217 | near_plane = near_plane or float(torch.min(depth))
218 | far_plane = far_plane or float(torch.max(depth))
219 | depth = (depth - near_plane) / (far_plane - near_plane + 1e-10)
220 | depth = torch.clip(depth, 0.0, 1.0)
221 | img = apply_float_colormap(depth, colormap="turbo")
222 | if acc is not None:
223 | img = img * acc + (1.0 - acc)
224 | return img
225 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Difix3D+
2 |
3 | **Difix3D+: Improving 3D Reconstructions with Single-Step Diffusion Models**
4 | [Jay Zhangjie Wu*](https://zhangjiewu.github.io/), [Yuxuan Zhang*](https://scholar.google.com/citations?user=Jt5VvNgAAAAJ&hl=en), [Haithem Turki](https://haithemturki.com/), [Xuanchi Ren](https://xuanchiren.com/), [Jun Gao](https://www.cs.toronto.edu/~jungao/),
5 | [Mike Zheng Shou](https://sites.google.com/view/showlab/home?authuser=0), [Sanja Fidler](https://www.cs.utoronto.ca/~fidler/), [Zan Gojcic†](https://zgojcic.github.io/), [Huan Ling†](https://www.cs.toronto.edu/~linghuan/) _(*/† equal contribution/advising)_
6 | CVPR 2025 (Oral)
7 | [Project Page](https://research.nvidia.com/labs/toronto-ai/difix3d/) | [Paper](https://arxiv.org/abs/2503.01774) | [Model](https://huggingface.co/nvidia/difix) | [Demo](https://huggingface.co/spaces/nvidia/difix)
8 |
9 |
10 | 
11 |
12 |
13 |
14 | ## News
15 |
16 | * [11/06/2025] Code and models are now available! We will present our work at CVPR 2025 ([oral](https://cvpr.thecvf.com/virtual/2025/oral/35364), [poster](https://cvpr.thecvf.com/virtual/2025/poster/34172)). See you in Nashville🎵!
17 |
18 |
19 | ## Setup
20 |
21 | ```bash
22 | git clone https://github.com/nv-tlabs/Difix3D.git
23 | cd Difix3D
24 | pip install -r requirements.txt
25 | ```
26 |
27 | ## Quickstart (diffusers)
28 |
29 | ```python
30 | from pipeline_difix import DifixPipeline
31 | from diffusers.utils import load_image
32 |
33 | pipe = DifixPipeline.from_pretrained("nvidia/difix", trust_remote_code=True)
34 | pipe.to("cuda")
35 |
36 | input_image = load_image("assets/example_input.png")
37 | prompt = "remove degradation"
38 |
39 | output_image = pipe(prompt, image=input_image, num_inference_steps=1, timesteps=[199], guidance_scale=0.0).images[0]
40 | output_image.save("example_output.png")
41 | ```
42 |
43 | Optionally, you can use a reference image to guide the denoising process.
44 | ```python
45 | from pipeline_difix import DifixPipeline
46 | from diffusers.utils import load_image
47 |
48 | pipe = DifixPipeline.from_pretrained("nvidia/difix_ref", trust_remote_code=True)
49 | pipe.to("cuda")
50 |
51 | input_image = load_image("assets/example_input.png")
52 | ref_image = load_image("assets/example_ref.png")
53 | prompt = "remove degradation"
54 |
55 | output_image = pipe(prompt, image=input_image, ref_image=ref_image, num_inference_steps=1, timesteps=[199], guidance_scale=0.0).images[0]
56 | output_image.save("example_output.png")
57 | ```
58 |
59 | ## Difix: Single-step diffusion for 3D artifact removal
60 |
61 | ### Training
62 |
63 | #### Data Preparation
64 |
65 | Prepare your dataset in the following JSON format:
66 |
67 | ```json
68 | {
69 | "train": {
70 | "{data_id}": {
71 | "image": "{PATH_TO_IMAGE}",
72 | "target_image": "{PATH_TO_TARGET_IMAGE}",
73 | "ref_image": "{PATH_TO_REF_IMAGE}",
74 | "prompt": "remove degradation"
75 | }
76 | },
77 | "test": {
78 | "{data_id}": {
79 | "image": "{PATH_TO_IMAGE}",
80 | "target_image": "{PATH_TO_TARGET_IMAGE}",
81 | "ref_image": "{PATH_TO_REF_IMAGE}",
82 | "prompt": "remove degradation"
83 | }
84 | }
85 | }
86 | ```
87 |
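For reference, the sketch below shows one way to assemble such a `data.json` from parallel folders of rendered (degraded) views, ground-truth targets, and reference views. The folder layout (`renders/`, `targets/`, `refs/`) and the 90/10 split are illustrative assumptions, not part of this repository.

```python
import json
from pathlib import Path

# Assumed layout: renders/, targets/ and refs/ hold PNGs with matching file names.
render_dir, target_dir, ref_dir = Path("renders"), Path("targets"), Path("refs")

def make_split(names):
    return {
        name: {
            "image": str(render_dir / f"{name}.png"),
            "target_image": str(target_dir / f"{name}.png"),
            "ref_image": str(ref_dir / f"{name}.png"),
            "prompt": "remove degradation",
        }
        for name in names
    }

names = sorted(p.stem for p in render_dir.glob("*.png"))
split_idx = int(0.9 * len(names))
data = {"train": make_split(names[:split_idx]), "test": make_split(names[split_idx:])}

Path("data").mkdir(exist_ok=True)
with open("data/data.json", "w") as f:
    json.dump(data, f, indent=2)
```
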
88 | #### Single GPU
89 |
90 | ```bash
91 | accelerate launch --mixed_precision=bf16 src/train_difix.py \
92 | --output_dir=./outputs/difix/train \
93 | --dataset_path="data/data.json" \
94 | --max_train_steps 10000 \
95 | --resolution=512 --learning_rate 2e-5 \
96 | --train_batch_size=1 --dataloader_num_workers 8 \
97 | --enable_xformers_memory_efficient_attention \
98 | --checkpointing_steps=1000 --eval_freq 1000 --viz_freq 100 \
99 | --lambda_lpips 1.0 --lambda_l2 1.0 --lambda_gram 1.0 --gram_loss_warmup_steps 2000 \
100 | --report_to "wandb" --tracker_project_name "difix" --tracker_run_name "train" --timestep 199
101 | ```
102 |
103 | #### Multiple GPUs
104 |
105 | ```bash
106 | export NUM_NODES=1
107 | export NUM_GPUS=8
108 | accelerate launch --mixed_precision=bf16 --main_process_port 29501 --multi_gpu --num_machines $NUM_NODES --num_processes $NUM_GPUS src/train_difix.py \
109 | --output_dir=./outputs/difix/train \
110 | --dataset_path="data/data.json" \
111 | --max_train_steps 10000 \
112 | --resolution=512 --learning_rate 2e-5 \
113 | --train_batch_size=1 --dataloader_num_workers 8 \
114 | --enable_xformers_memory_efficient_attention \
115 | --checkpointing_steps=1000 --eval_freq 1000 --viz_freq 100 \
116 | --lambda_lpips 1.0 --lambda_l2 1.0 --lambda_gram 1.0 --gram_loss_warmup_steps 2000 \
117 | --report_to "wandb" --tracker_project_name "difix" --tracker_run_name "train" --timestep 199
118 | ```
119 |
120 | ### Inference
121 |
122 | Place the `model_*.pkl` file in the `checkpoints` directory, then run inference with the following command:
123 |
124 | ```bash
125 | python src/inference_difix.py \
126 | --model_path "checkpoints/model.pkl" \
127 | --input_image "assets/example_input.png" \
128 | --prompt "remove degradation" \
129 | --output_dir "outputs/difix" \
130 | --timestep 199
131 | ```
132 |
133 |
134 | ## Difix3D: Progressive 3D update
135 |
136 | ### Data Format
137 |
138 | The data should be organized in the following structure:
139 |
140 | ```
141 | DATA_DIR/
142 | ├── {SCENE_ID}
143 | │ ├── colmap
144 | │ │ ├── sparse
145 | │ │ │ └── 0
146 | │ │ │ ├── cameras.bin
147 | │ │ │ ├── database.db
148 | │ │ │ └── ...
149 | │ ├── images
150 | │ │ ├── image_train_000001.png
151 | │ │ ├── image_train_000002.png
152 | │ │ ├── ...
153 | │ │ ├── image_eval_000200.png
154 | │ │ ├── image_eval_000201.png
155 | │ │ └── ...
156 | │ ├── images_2
157 | │ ├── images_4
158 | │ └── images_8
159 | ```
160 |
161 | ### nerfstudio
162 |
163 | Set up the nerfstudio environment.
164 | ```bash
165 | cd examples/nerfstudio
166 | pip install -e .
167 | cd ../..
168 | ```
169 |
170 | Run Difix3D finetuning with nerfstudio.
171 | ```bash
172 | SCENE_ID=032dee9fb0a8bc1b90871dc5fe950080d0bcd3caf166447f44e60ca50ac04ec7
173 | DATA=DATA_DIR/${SCENE_ID}
174 | DATA_FACTOR=4
175 | CKPT_PATH=CKPT_DIR/${SCENE_ID}/nerfacto/nerfstudio_models/step-000029999.ckpt # Path to the pretrained checkpoint file
176 | OUTPUT_DIR=outputs/difix3d/nerfacto/${SCENE_ID}
177 |
178 | CUDA_VISIBLE_DEVICES=0 ns-train difix3d \
179 | --data ${DATA} --pipeline.model.appearance-embed-dim 0 --pipeline.model.camera-optimizer.mode off --save_only_latest_checkpoint False --vis viewer \
180 | --output_dir ${OUTPUT_DIR} --experiment_name ${SCENE_ID} --timestamp '' --load-checkpoint ${CKPT_PATH} \
181 | --max_num_iterations 30000 --steps_per_eval_all_images 0 --steps_per_eval_batch 0 --steps_per_eval_image 0 --steps_per_save 2000 --viewer.quit-on-train-completion True \
182 | nerfstudio-data --orientation-method none --center_method none --auto-scale-poses False --downscale_factor ${DATA_FACTOR} --eval_mode filename
183 | ```
184 |
185 | ### gsplat
186 |
187 | Install gsplat following the instructions in the [gsplat repository](https://github.com/nerfstudio-project/gsplat?tab=readme-ov-file#installation).
188 |
189 | Run Difix3D finetuning with gsplat.
190 | ```bash
191 | SCENE_ID=032dee9fb0a8bc1b90871dc5fe950080d0bcd3caf166447f44e60ca50ac04ec7
192 | DATA=DATA_DIR/${SCENE_ID}/gaussian_splat
193 | DATA_FACTOR=4
194 | CKPT_PATH=CKPT_DIR/${SCENE_ID}/ckpts/ckpt_29999_rank0.pt # Path to the pretrained checkpoint file
195 | OUTPUT_DIR=outputs/difix3d/gsplat/${SCENE_ID}
196 |
197 | CUDA_VISIBLE_DEVICES=0 python examples/gsplat/simple_trainer_difix3d.py default \
198 | --data_dir ${DATA} --data_factor ${DATA_FACTOR} \
199 | --result_dir ${OUTPUT_DIR} --no-normalize-world-space --test_every 1 --ckpt ${CKPT_PATH}
200 | ```
201 |
202 |
203 | ## Difix3D+: With real-time post-rendering
204 |
205 | Due to the limited capacity of reconstruction methods to represent sharp details, some regions remain blurry. To further enhance the novel views, we use our Difix model as the final post-processing step at render time.
206 |
207 | ```bash
208 | python src/inference_difix.py \
209 | --model_path "checkpoints/model.pkl" \
210 | --input_image "PATH_TO_IMAGES" \
211 | --prompt "remove degradation" \
212 | --output_dir "outputs/difix3d+" \
213 | --timestep 199
214 | ```
215 |
216 | ## Acknowledgements
217 |
218 | Our work is built upon the following projects:
219 | - [diffusers](https://github.com/huggingface/diffusers)
220 | - [img2img-turbo](https://github.com/GaParmar/img2img-turbo)
221 | - [nerfstudio](https://github.com/nerfstudio-project/nerfstudio)
222 | - [gsplat](https://github.com/nerfstudio-project/gsplat)
223 | - [DL3DV-10K](https://github.com/DL3DV-10K/Dataset)
224 | - [nerfbusters](https://github.com/ethanweber/nerfbusters)
225 |
226 | Shoutout to all the contributors of these projects for their invaluable work that made this research possible.
227 |
228 | ## License/Terms of Use
229 |
230 | The use of the model and code is governed by the NVIDIA License. See [LICENSE.txt](LICENSE.txt) for details.
231 | Additional Information: [LICENSE.md · stabilityai/sd-turbo at main](https://huggingface.co/stabilityai/sd-turbo/blob/main/LICENSE.md)
232 |
233 | ## Citation
234 |
235 | ```bibtex
236 | @inproceedings{wu2025difix3d+,
237 | title={DIFIX3D+: Improving 3D Reconstructions with Single-Step Diffusion Models},
238 | author={Wu, Jay Zhangjie and Zhang, Yuxuan and Turki, Haithem and Ren, Xuanchi and Gao, Jun and Shou, Mike Zheng and Fidler, Sanja and Gojcic, Zan and Ling, Huan},
239 | booktitle={Proceedings of the Computer Vision and Pattern Recognition Conference},
240 | pages={26024--26035},
241 | year={2025}
242 | }
243 | ```
--------------------------------------------------------------------------------
/examples/nerfstudio/difix3d/difix3d_pipeline.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The Nerfstudio Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from dataclasses import dataclass, field
17 | from typing import Optional, Type
18 | from pathlib import Path
19 | from PIL import Image
20 | import os
21 | import tqdm
22 | import random
23 | import numpy as np
24 | import torch
25 | from torch.cuda.amp.grad_scaler import GradScaler
26 | from typing_extensions import Literal
27 | from nerfstudio.cameras.cameras import Cameras
28 | from nerfstudio.pipelines.base_pipeline import VanillaPipeline, VanillaPipelineConfig
29 | from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs
30 |
31 |
32 | from difix3d.difix3d_datamanager import (
33 | Difix3DDataManagerConfig,
34 | )
35 | from src.pipeline_difix import DifixPipeline
36 | from examples.utils import CameraPoseInterpolator
37 |
38 |
39 | @dataclass
40 | class Difix3DPipelineConfig(VanillaPipelineConfig):
41 | """Configuration for pipeline instantiation"""
42 |
43 | _target: Type = field(default_factory=lambda: Difix3DPipeline)
44 | """target class to instantiate"""
45 | datamanager: Difix3DDataManagerConfig = Difix3DDataManagerConfig()
46 | """specifies the datamanager config"""
47 | steps_per_fix: int = 2000
48 | """rate at which to fix artifacts"""
49 | steps_per_val: int = 5000
50 | """rate at which to evaluate the model"""
51 |
52 | class Difix3DPipeline(VanillaPipeline):
53 | """Difix3D pipeline"""
54 |
55 | config: Difix3DPipelineConfig
56 |
57 | def __init__(
58 | self,
59 | config: Difix3DPipelineConfig,
60 | device: str,
61 | test_mode: Literal["test", "val", "inference"] = "val",
62 | world_size: int = 1,
63 | local_rank: int = 0,
64 | grad_scaler: Optional[GradScaler] = None,
65 | render_dir: str = "renders",
66 | ):
67 | super().__init__(config, device, test_mode, world_size, local_rank)
68 |
69 | self.render_dir = render_dir
70 |
71 | self.difix = DifixPipeline.from_pretrained("nvidia/difix_ref", trust_remote_code=True)
72 | self.difix.set_progress_bar_config(disable=True)
73 | self.difix.to("cuda")
74 |
75 | self.training_poses = self.datamanager.train_dataparser_outputs.cameras.camera_to_worlds
76 | self.training_poses = torch.cat([self.training_poses, torch.tensor([0, 0, 0, 1]).reshape(1, 1, 4).repeat(self.training_poses.shape[0], 1, 1)], dim=1)
77 | self.testing_poses = self.datamanager.dataparser.get_dataparser_outputs(split=self.datamanager.test_split).cameras.camera_to_worlds
78 | self.testing_poses = torch.cat([self.testing_poses, torch.tensor([0, 0, 0, 1]).reshape(1, 1, 4).repeat(self.testing_poses.shape[0], 1, 1)], dim=1)
79 | self.current_novel_poses = self.training_poses
80 | self.current_novel_cameras = self.datamanager.train_dataparser_outputs.cameras
81 |
82 | self.interpolator = CameraPoseInterpolator(rotation_weight=1.0, translation_weight=1.0)
83 | self.novel_datamanagers = []
84 |
85 | def get_train_loss_dict(self, step: int):
86 | """This function gets your training loss dict and performs image editing.
87 | Args:
88 | step: current iteration step to update sampler if using DDP (distributed)
89 | """
90 | if len(self.novel_datamanagers) == 0 or random.random() < 0.6:
91 | ray_bundle, batch = self.datamanager.next_train(step)
92 | else:
93 | ray_bundle, batch = self.novel_datamanagers[-1].next_train(step)
94 |
95 | model_outputs = self.model(ray_bundle)
96 | metrics_dict = self.model.get_metrics_dict(model_outputs, batch)
97 |
98 | loss_dict = self.model.get_loss_dict(model_outputs, batch, metrics_dict)
99 |
100 | # run fixer
101 | if (step % self.config.steps_per_fix == 0):
102 | self.fix(step)
103 |
104 | # run evaluation
105 | if (step % self.config.steps_per_val == 0):
106 | self.val(step)
107 |
108 | return model_outputs, loss_dict, metrics_dict
109 |
110 | def forward(self):
111 | """Not implemented since we only want the parameter saving of the nn module, but not forward()"""
112 | raise NotImplementedError
113 |
114 | @torch.no_grad()
115 | def render_traj(self, step, cameras, tag="novel"):
116 | for i in tqdm.trange(0, len(cameras), desc="Rendering trajectory"):
117 | with torch.no_grad():
118 | outputs = self.model.get_outputs_for_camera(cameras[i])
119 |
120 | rgb_path = f"{self.render_dir}/{tag}/{step}/Pred/{i:04d}.png"
121 | os.makedirs(os.path.dirname(rgb_path), exist_ok=True)
122 | rgb_canvas = outputs['rgb'].cpu().numpy()
123 | rgb_canvas = (rgb_canvas * 255).astype(np.uint8)
124 | Image.fromarray(rgb_canvas).save(rgb_path)
125 |
126 | @torch.no_grad()
127 | def val(self, step):
128 | cameras = self.datamanager.dataparser.get_dataparser_outputs(split=self.datamanager.test_split).cameras
129 | for i in tqdm.trange(0, len(cameras), desc="Running evaluation"):
130 | with torch.no_grad():
131 | outputs = self.model.get_outputs_for_camera(cameras[i])
132 |
133 | rgb_path = f"{self.render_dir}/val/{step}/{i:04d}.png"
134 | os.makedirs(os.path.dirname(rgb_path), exist_ok=True)
135 | rgb_canvas = outputs['rgb'].cpu().numpy()
136 | rgb_canvas = (rgb_canvas * 255).astype(np.uint8)
137 | Image.fromarray(rgb_canvas).save(rgb_path)
138 |
139 | @torch.no_grad()
140 | def fix(self, step: int):
141 |
142 | novel_poses = self.interpolator.shift_poses(self.current_novel_poses.numpy(), self.testing_poses.numpy(), distance=0.5)
143 | novel_poses = torch.from_numpy(novel_poses).to(self.testing_poses.dtype)
144 |
145 | ref_image_indices = self.interpolator.find_nearest_assignments(self.training_poses.numpy(), novel_poses.numpy())
146 | ref_image_filenames = np.array(self.datamanager.train_dataparser_outputs.image_filenames)[ref_image_indices].tolist()
147 |
148 | cameras = self.datamanager.train_dataparser_outputs.cameras
149 | cameras = Cameras(
150 | fx=cameras.fx[0].repeat(len(novel_poses), 1),
151 | fy=cameras.fy[0].repeat(len(novel_poses), 1),
152 | cx=cameras.cx[0].repeat(len(novel_poses), 1),
153 | cy=cameras.cy[0].repeat(len(novel_poses), 1),
154 | distortion_params=cameras.distortion_params[0].repeat(len(novel_poses), 1),
155 | height=cameras.height[0].repeat(len(novel_poses), 1),
156 | width=cameras.width[0].repeat(len(novel_poses), 1),
157 | camera_to_worlds=novel_poses[:, :3, :4],
158 | camera_type=cameras.camera_type[0].repeat(len(novel_poses), 1),
159 | metadata=cameras.metadata,
160 | )
161 |
162 | self.render_traj(step, cameras)
163 |
164 | image_filenames = []
165 | for i in tqdm.trange(0, len(novel_poses), desc="Fixing artifacts..."):
166 | image = Image.open(f"{self.render_dir}/novel/{step}/Pred/{i:04d}.png").convert("RGB")
167 | ref_image = Image.open(ref_image_filenames[i]).convert("RGB")
168 | output_image = self.difix(prompt="remove degradation", image=image, ref_image=ref_image, num_inference_steps=1, timesteps=[199], guidance_scale=0.0).images[0]
169 | output_image = output_image.resize(image.size, Image.LANCZOS)
170 | os.makedirs(f"{self.render_dir}/novel/{step}/Fixed", exist_ok=True)
171 | output_image.save(f"{self.render_dir}/novel/{step}/Fixed/{i:04d}.png")
172 | image_filenames.append(Path(f"{self.render_dir}/novel/{step}/Fixed/{i:04d}.png"))
173 | if ref_image is not None:
174 | os.makedirs(f"{self.render_dir}/novel/{step}/Ref", exist_ok=True)
175 | ref_image.save(f"{self.render_dir}/novel/{step}/Ref/{i:04d}.png")
176 |
177 | dataparser_outputs = self.datamanager.train_dataparser_outputs
178 | dataparser_outputs = DataparserOutputs(
179 | image_filenames=image_filenames,
180 | cameras=cameras,
181 | scene_box=dataparser_outputs.scene_box,
182 | mask_filenames=None,
183 | dataparser_scale=dataparser_outputs.dataparser_scale,
184 | dataparser_transform=dataparser_outputs.dataparser_transform,
185 | metadata=dataparser_outputs.metadata,
186 | )
187 |
188 | datamanager_config = Difix3DDataManagerConfig(
189 | dataparser=self.config.datamanager.dataparser,
190 | train_num_rays_per_batch=16384,
191 | eval_num_rays_per_batch=4096,
192 | )
193 |
194 | datamanager = datamanager_config.setup(
195 | device=self.datamanager.device,
196 | test_mode=self.datamanager.test_mode,
197 | world_size=self.datamanager.world_size,
198 | local_rank=self.datamanager.local_rank
199 | )
200 |
201 | datamanager.train_dataparser_outputs = dataparser_outputs
202 | datamanager.train_dataset = datamanager.create_train_dataset()
203 | datamanager.setup_train()
204 |
205 | self.novel_datamanagers.append(datamanager)
206 | self.current_novel_poses = novel_poses
207 | self.current_novel_cameras = cameras
--------------------------------------------------------------------------------
/examples/gsplat/datasets/traj.py:
--------------------------------------------------------------------------------
1 | """
2 | Code borrowed from
3 |
4 | https://github.com/google-research/multinerf/blob/5b4d4f64608ec8077222c52fdf814d40acc10bc1/internal/camera_utils.py
5 | """
6 |
7 | import numpy as np
8 | import scipy
9 |
10 |
11 | def normalize(x: np.ndarray) -> np.ndarray:
12 | """Normalization helper function."""
13 | return x / np.linalg.norm(x)
14 |
15 |
16 | def viewmatrix(lookdir: np.ndarray, up: np.ndarray, position: np.ndarray) -> np.ndarray:
17 | """Construct lookat view matrix."""
18 | vec2 = normalize(lookdir)
19 | vec0 = normalize(np.cross(up, vec2))
20 | vec1 = normalize(np.cross(vec2, vec0))
21 | m = np.stack([vec0, vec1, vec2, position], axis=1)
22 | return m
23 |
24 |
25 | def focus_point_fn(poses: np.ndarray) -> np.ndarray:
26 | """Calculate nearest point to all focal axes in poses."""
27 | directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4]
28 | m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1])
29 | mt_m = np.transpose(m, [0, 2, 1]) @ m
30 | focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0]
31 | return focus_pt
32 |
33 |
34 | def average_pose(poses: np.ndarray) -> np.ndarray:
35 | """New pose using average position, z-axis, and up vector of input poses."""
36 | position = poses[:, :3, 3].mean(0)
37 | z_axis = poses[:, :3, 2].mean(0)
38 | up = poses[:, :3, 1].mean(0)
39 | cam2world = viewmatrix(z_axis, up, position)
40 | return cam2world
41 |
42 |
43 | def generate_spiral_path(
44 | poses,
45 | bounds,
46 | n_frames=120,
47 | n_rots=2,
48 | zrate=0.5,
49 | spiral_scale_f=1.0,
50 | spiral_scale_r=1.0,
51 | focus_distance=0.75,
52 | ):
53 | """Calculates a forward facing spiral path for rendering."""
54 | # Find a reasonable 'focus depth' for this dataset as a weighted average
55 | # of conservative near and far bounds in disparity space.
56 | near_bound = bounds.min()
57 | far_bound = bounds.max()
58 | # All cameras will point towards the world space point (0, 0, -focal).
59 | focal = 1 / (((1 - focus_distance) / near_bound + focus_distance / far_bound))
60 | focal = focal * spiral_scale_f
61 |
62 | # Get radii for spiral path using 90th percentile of camera positions.
63 | positions = poses[:, :3, 3]
64 | radii = np.percentile(np.abs(positions), 90, 0)
65 | radii = radii * spiral_scale_r
66 | radii = np.concatenate([radii, [1.0]])
67 |
68 | # Generate poses for spiral path.
69 | render_poses = []
70 | cam2world = average_pose(poses)
71 | up = poses[:, :3, 1].mean(0)
72 | for theta in np.linspace(0.0, 2.0 * np.pi * n_rots, n_frames, endpoint=False):
73 | t = radii * [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0]
74 | position = cam2world @ t
75 | lookat = cam2world @ [0, 0, -focal, 1.0]
76 | z_axis = position - lookat
77 | render_poses.append(viewmatrix(z_axis, up, position))
78 | render_poses = np.stack(render_poses, axis=0)
79 | return render_poses
80 |
81 |
82 | def generate_ellipse_path_z(
83 | poses: np.ndarray,
84 | n_frames: int = 120,
85 | # const_speed: bool = True,
86 | variation: float = 0.0,
87 | phase: float = 0.0,
88 | height: float = 0.0,
89 | ) -> np.ndarray:
90 | """Generate an elliptical render path based on the given poses."""
91 | # Calculate the focal point for the path (cameras point toward this).
92 | center = focus_point_fn(poses)
93 | # Path height sits at z=height (in middle of zero-mean capture pattern).
94 | offset = np.array([center[0], center[1], height])
95 |
96 | # Calculate scaling for ellipse axes based on input camera positions.
97 | sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0)
98 | # Use ellipse that is symmetric about the focal point in xy.
99 | low = -sc + offset
100 | high = sc + offset
101 | # Optional height variation need not be symmetric
102 | z_low = np.percentile((poses[:, :3, 3]), 10, axis=0)
103 | z_high = np.percentile((poses[:, :3, 3]), 90, axis=0)
104 |
105 | def get_positions(theta):
106 | # Interpolate between bounds with trig functions to get ellipse in x-y.
107 | # Optionally also interpolate in z to change camera height along path.
108 | return np.stack(
109 | [
110 | low[0] + (high - low)[0] * (np.cos(theta) * 0.5 + 0.5),
111 | low[1] + (high - low)[1] * (np.sin(theta) * 0.5 + 0.5),
112 | variation
113 | * (
114 | z_low[2]
115 | + (z_high - z_low)[2]
116 | * (np.cos(theta + 2 * np.pi * phase) * 0.5 + 0.5)
117 | )
118 | + height,
119 | ],
120 | -1,
121 | )
122 |
123 | theta = np.linspace(0, 2.0 * np.pi, n_frames + 1, endpoint=True)
124 | positions = get_positions(theta)
125 |
126 | # if const_speed:
127 | # # Resample theta angles so that the velocity is closer to constant.
128 | # lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
129 | # theta = stepfun.sample(None, theta, np.log(lengths), n_frames + 1)
130 | # positions = get_positions(theta)
131 |
132 | # Throw away duplicated last position.
133 | positions = positions[:-1]
134 |
135 | # Set path's up vector to axis closest to average of input pose up vectors.
136 | avg_up = poses[:, :3, 1].mean(0)
137 | avg_up = avg_up / np.linalg.norm(avg_up)
138 | ind_up = np.argmax(np.abs(avg_up))
139 | up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up])
140 |
141 | return np.stack([viewmatrix(center - p, up, p) for p in positions])
142 |
143 |
144 | def generate_ellipse_path_y(
145 | poses: np.ndarray,
146 | n_frames: int = 120,
147 | # const_speed: bool = True,
148 | variation: float = 0.0,
149 | phase: float = 0.0,
150 | height: float = 0.0,
151 | ) -> np.ndarray:
152 | """Generate an elliptical render path based on the given poses."""
153 | # Calculate the focal point for the path (cameras point toward this).
154 | center = focus_point_fn(poses)
155 | # Path height sits at y=height (in middle of zero-mean capture pattern).
156 | offset = np.array([center[0], height, center[2]])
157 |
158 | # Calculate scaling for ellipse axes based on input camera positions.
159 | sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0)
160 | # Use ellipse that is symmetric about the focal point in xy.
161 | low = -sc + offset
162 | high = sc + offset
163 | # Optional height variation need not be symmetric
164 | y_low = np.percentile((poses[:, :3, 3]), 10, axis=0)
165 | y_high = np.percentile((poses[:, :3, 3]), 90, axis=0)
166 |
167 | def get_positions(theta):
168 | # Interpolate between bounds with trig functions to get ellipse in x-z.
169 | # Optionally also interpolate in y to change camera height along path.
170 | return np.stack(
171 | [
172 | low[0] + (high - low)[0] * (np.cos(theta) * 0.5 + 0.5),
173 | variation
174 | * (
175 | y_low[1]
176 | + (y_high - y_low)[1]
177 | * (np.cos(theta + 2 * np.pi * phase) * 0.5 + 0.5)
178 | )
179 | + height,
180 | low[2] + (high - low)[2] * (np.sin(theta) * 0.5 + 0.5),
181 | ],
182 | -1,
183 | )
184 |
185 | theta = np.linspace(0, 2.0 * np.pi, n_frames + 1, endpoint=True)
186 | positions = get_positions(theta)
187 |
188 | # if const_speed:
189 | # # Resample theta angles so that the velocity is closer to constant.
190 | # lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
191 | # theta = stepfun.sample(None, theta, np.log(lengths), n_frames + 1)
192 | # positions = get_positions(theta)
193 |
194 | # Throw away duplicated last position.
195 | positions = positions[:-1]
196 |
197 | # Set path's up vector to axis closest to average of input pose up vectors.
198 | avg_up = poses[:, :3, 1].mean(0)
199 | avg_up = avg_up / np.linalg.norm(avg_up)
200 | ind_up = np.argmax(np.abs(avg_up))
201 | up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up])
202 |
203 | return np.stack([viewmatrix(p - center, up, p) for p in positions])
204 |
205 |
206 | def generate_interpolated_path(
207 | poses: np.ndarray,
208 | n_interp: int,
209 | spline_degree: int = 5,
210 | smoothness: float = 0.03,
211 | rot_weight: float = 0.1,
212 | ):
213 | """Creates a smooth spline path between input keyframe camera poses.
214 |
215 | Spline is calculated with poses in format (position, lookat-point, up-point).
216 |
217 | Args:
218 | poses: (n, 3, 4) array of input pose keyframes.
219 | n_interp: returned path will have n_interp * (n - 1) total poses.
220 | spline_degree: polynomial degree of B-spline.
221 | smoothness: parameter for spline smoothing, 0 forces exact interpolation.
222 | rot_weight: relative weighting of rotation/translation in spline solve.
223 |
224 | Returns:
225 | Array of new camera poses with shape (n_interp * (n - 1), 3, 4).
226 | """
227 |
228 | def poses_to_points(poses, dist):
229 | """Converts from pose matrices to (position, lookat, up) format."""
230 | pos = poses[:, :3, -1]
231 | lookat = poses[:, :3, -1] - dist * poses[:, :3, 2]
232 | up = poses[:, :3, -1] + dist * poses[:, :3, 1]
233 | return np.stack([pos, lookat, up], 1)
234 |
235 | def points_to_poses(points):
236 | """Converts from (position, lookat, up) format to pose matrices."""
237 | return np.array([viewmatrix(p - l, u - p, p) for p, l, u in points])
238 |
239 | def interp(points, n, k, s):
240 | """Runs multidimensional B-spline interpolation on the input points."""
241 | sh = points.shape
242 | pts = np.reshape(points, (sh[0], -1))
243 | k = min(k, sh[0] - 1)
244 | tck, _ = scipy.interpolate.splprep(pts.T, k=k, s=s)
245 | u = np.linspace(0, 1, n, endpoint=False)
246 | new_points = np.array(scipy.interpolate.splev(u, tck))
247 | new_points = np.reshape(new_points.T, (n, sh[1], sh[2]))
248 | return new_points
249 |
250 | points = poses_to_points(poses, dist=rot_weight)
251 | new_points = interp(
252 | points, n_interp * (points.shape[0] - 1), k=spline_degree, s=smoothness
253 | )
254 | return points_to_poses(new_points)
255 |
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import requests
3 | import sys
4 | import numpy as np
5 | from PIL import Image
6 | from tqdm import tqdm
7 | import torch
8 | from torchvision import transforms
9 | from transformers import AutoTokenizer, CLIPTextModel
10 | from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler
11 | from peft import LoraConfig
12 | p = "src/"
13 | sys.path.append(p)
14 | from einops import rearrange, repeat
15 |
16 |
17 | def make_1step_sched():
18 | noise_scheduler_1step = DDPMScheduler.from_pretrained("stabilityai/sd-turbo", subfolder="scheduler")
19 | noise_scheduler_1step.set_timesteps(1, device="cuda")
20 | noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda()
21 | return noise_scheduler_1step
22 |
23 |
24 | def my_vae_encoder_fwd(self, sample):
25 | sample = self.conv_in(sample)
26 | l_blocks = []
27 | # down
28 | for down_block in self.down_blocks:
29 | l_blocks.append(sample)
30 | sample = down_block(sample)
31 | # middle
32 | sample = self.mid_block(sample)
33 | sample = self.conv_norm_out(sample)
34 | sample = self.conv_act(sample)
35 | sample = self.conv_out(sample)
36 | self.current_down_blocks = l_blocks
37 | return sample
38 |
39 |
40 | def my_vae_decoder_fwd(self, sample, latent_embeds=None):
41 | sample = self.conv_in(sample)
42 | upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
43 | # middle
44 | sample = self.mid_block(sample, latent_embeds)
45 | sample = sample.to(upscale_dtype)
46 | if not self.ignore_skip:
47 | skip_convs = [self.skip_conv_1, self.skip_conv_2, self.skip_conv_3, self.skip_conv_4]
48 | # up
49 | for idx, up_block in enumerate(self.up_blocks):
50 | skip_in = skip_convs[idx](self.incoming_skip_acts[::-1][idx] * self.gamma)
51 | # add skip
52 | sample = sample + skip_in
53 | sample = up_block(sample, latent_embeds)
54 | else:
55 | for idx, up_block in enumerate(self.up_blocks):
56 | sample = up_block(sample, latent_embeds)
57 | # post-process
58 | if latent_embeds is None:
59 | sample = self.conv_norm_out(sample)
60 | else:
61 | sample = self.conv_norm_out(sample, latent_embeds)
62 | sample = self.conv_act(sample)
63 | sample = self.conv_out(sample)
64 | return sample
65 |
66 |
67 | def download_url(url, outf):
68 | if not os.path.exists(outf):
69 | print(f"Downloading checkpoint to {outf}")
70 | response = requests.get(url, stream=True)
71 | total_size_in_bytes = int(response.headers.get('content-length', 0))
72 | block_size = 1024 # 1 Kibibyte
73 | progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
74 | with open(outf, 'wb') as file:
75 | for data in response.iter_content(block_size):
76 | progress_bar.update(len(data))
77 | file.write(data)
78 | progress_bar.close()
79 | if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
80 | print("ERROR, something went wrong")
81 | print(f"Downloaded successfully to {outf}")
82 | else:
83 | print(f"Skipping download, {outf} already exists")
84 |
85 |
86 | def load_ckpt_from_state_dict(net_difix, optimizer, pretrained_path):
87 | sd = torch.load(pretrained_path, map_location="cpu")
88 |
89 | if "state_dict_vae" in sd:
90 | _sd_vae = net_difix.vae.state_dict()
91 | for k in sd["state_dict_vae"]:
92 | _sd_vae[k] = sd["state_dict_vae"][k]
93 | net_difix.vae.load_state_dict(_sd_vae)
94 | _sd_unet = net_difix.unet.state_dict()
95 | for k in sd["state_dict_unet"]:
96 | _sd_unet[k] = sd["state_dict_unet"][k]
97 | net_difix.unet.load_state_dict(_sd_unet)
98 |
99 | optimizer.load_state_dict(sd["optimizer"])
100 |
101 | return net_difix, optimizer
102 |
103 |
104 | def save_ckpt(net_difix, optimizer, outf):
105 | sd = {}
106 | sd["vae_lora_target_modules"] = net_difix.target_modules_vae
107 | sd["rank_vae"] = net_difix.lora_rank_vae
108 | sd["state_dict_unet"] = net_difix.unet.state_dict()
109 | sd["state_dict_vae"] = {k: v for k, v in net_difix.vae.state_dict().items() if "lora" in k or "skip" in k}
110 |
111 | sd["optimizer"] = optimizer.state_dict()
112 |
113 | torch.save(sd, outf)
114 |
115 |
116 | class Difix(torch.nn.Module):
117 | def __init__(self, pretrained_name=None, pretrained_path=None, ckpt_folder="checkpoints", lora_rank_vae=4, mv_unet=False, timestep=999):
118 | super().__init__()
119 | self.tokenizer = AutoTokenizer.from_pretrained("stabilityai/sd-turbo", subfolder="tokenizer")
120 | self.text_encoder = CLIPTextModel.from_pretrained("stabilityai/sd-turbo", subfolder="text_encoder").cuda()
121 | self.sched = make_1step_sched()
122 |
123 | vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
124 | vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
125 | vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
126 | # add the skip connection convs
127 | vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
128 | vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
129 | vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
130 | vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
131 | vae.decoder.ignore_skip = False
132 |
133 | if mv_unet:
134 | from mv_unet import UNet2DConditionModel
135 | else:
136 | from diffusers import UNet2DConditionModel
137 |
138 | unet = UNet2DConditionModel.from_pretrained("stabilityai/sd-turbo", subfolder="unet")
139 |
140 | if pretrained_path is not None:
141 | sd = torch.load(pretrained_path, map_location="cpu")
142 | vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"])
143 | vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
144 | _sd_vae = vae.state_dict()
145 | for k in sd["state_dict_vae"]:
146 | _sd_vae[k] = sd["state_dict_vae"][k]
147 | vae.load_state_dict(_sd_vae)
148 | _sd_unet = unet.state_dict()
149 | for k in sd["state_dict_unet"]:
150 | _sd_unet[k] = sd["state_dict_unet"][k]
151 | unet.load_state_dict(_sd_unet)
152 |
153 | elif pretrained_name is None and pretrained_path is None:
154 | print("Initializing model with random weights")
155 | target_modules_vae = []
156 |
157 | torch.nn.init.constant_(vae.decoder.skip_conv_1.weight, 1e-5)
158 | torch.nn.init.constant_(vae.decoder.skip_conv_2.weight, 1e-5)
159 | torch.nn.init.constant_(vae.decoder.skip_conv_3.weight, 1e-5)
160 | torch.nn.init.constant_(vae.decoder.skip_conv_4.weight, 1e-5)
161 | target_modules_vae = ["conv1", "conv2", "conv_in", "conv_shortcut", "conv", "conv_out",
162 | "skip_conv_1", "skip_conv_2", "skip_conv_3", "skip_conv_4",
163 | "to_k", "to_q", "to_v", "to_out.0",
164 | ]
165 |
166 | target_modules = []
167 | for id, (name, param) in enumerate(vae.named_modules()):
168 | if 'decoder' in name and any(name.endswith(x) for x in target_modules_vae):
169 | target_modules.append(name)
170 | target_modules_vae = target_modules
171 | vae.encoder.requires_grad_(False)
172 |
173 | vae_lora_config = LoraConfig(r=lora_rank_vae, init_lora_weights="gaussian",
174 | target_modules=target_modules_vae)
175 | vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
176 |
177 | self.lora_rank_vae = lora_rank_vae
178 | self.target_modules_vae = target_modules_vae
179 |
180 | # unet.enable_xformers_memory_efficient_attention()
181 | unet.to("cuda")
182 | vae.to("cuda")
183 |
184 | self.unet, self.vae = unet, vae
185 | self.vae.decoder.gamma = 1
186 | self.timesteps = torch.tensor([timestep], device="cuda").long()
187 | self.text_encoder.requires_grad_(False)
188 |
189 | # print number of trainable parameters
190 | print("="*50)
191 | print(f"Number of trainable parameters in UNet: {sum(p.numel() for p in unet.parameters() if p.requires_grad) / 1e6:.2f}M")
192 | print(f"Number of trainable parameters in VAE: {sum(p.numel() for p in vae.parameters() if p.requires_grad) / 1e6:.2f}M")
193 | print("="*50)
194 |
195 | def set_eval(self):
196 | self.unet.eval()
197 | self.vae.eval()
198 | self.unet.requires_grad_(False)
199 | self.vae.requires_grad_(False)
200 |
201 | def set_train(self):
202 | self.unet.train()
203 | self.vae.train()
204 | self.unet.requires_grad_(True)
205 |
206 | for n, _p in self.vae.named_parameters():
207 | if "lora" in n:
208 | _p.requires_grad = True
209 | self.vae.decoder.skip_conv_1.requires_grad_(True)
210 | self.vae.decoder.skip_conv_2.requires_grad_(True)
211 | self.vae.decoder.skip_conv_3.requires_grad_(True)
212 | self.vae.decoder.skip_conv_4.requires_grad_(True)
213 |
214 | def forward(self, x, timesteps=None, prompt=None, prompt_tokens=None):
215 | # either the prompt or the prompt_tokens should be provided
216 | assert (prompt is None) != (prompt_tokens is None), "Either prompt or prompt_tokens should be provided"
217 | assert (timesteps is None) != (self.timesteps is None), "Either timesteps or self.timesteps should be provided"
218 |
219 | if prompt is not None:
220 | # encode the text prompt
221 | caption_tokens = self.tokenizer(prompt, max_length=self.tokenizer.model_max_length,
222 | padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda()
223 | caption_enc = self.text_encoder(caption_tokens)[0]
224 | else:
225 | caption_enc = self.text_encoder(prompt_tokens)[0]
226 |
227 | num_views = x.shape[1]
228 | x = rearrange(x, 'b v c h w -> (b v) c h w')
229 | z = self.vae.encode(x).latent_dist.sample() * self.vae.config.scaling_factor
230 | caption_enc = repeat(caption_enc, 'b n c -> (b v) n c', v=num_views)
231 |
232 | unet_input = z
233 |
234 | model_pred = self.unet(unet_input, self.timesteps, encoder_hidden_states=caption_enc,).sample
235 | z_denoised = self.sched.step(model_pred, self.timesteps, z, return_dict=True).prev_sample
236 | self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks
237 | output_image = (self.vae.decode(z_denoised / self.vae.config.scaling_factor).sample).clamp(-1, 1)
238 | output_image = rearrange(output_image, '(b v) c h w -> b v c h w', v=num_views)
239 |
240 | return output_image
241 |
242 | def sample(self, image, width, height, ref_image=None, timesteps=None, prompt=None, prompt_tokens=None):
243 | input_width, input_height = image.size
244 | new_width = image.width - image.width % 8
245 | new_height = image.height - image.height % 8
246 | image = image.resize((new_width, new_height), Image.LANCZOS)
247 |
248 | T = transforms.Compose([
249 | transforms.Resize((height, width), interpolation=Image.LANCZOS),
250 | transforms.ToTensor(),
251 | transforms.Normalize([0.5], [0.5]),
252 | ])
253 | if ref_image is None:
254 | x = T(image).unsqueeze(0).unsqueeze(0).cuda()
255 | else:
256 | ref_image = ref_image.resize((new_width, new_height), Image.LANCZOS)
257 | x = torch.stack([T(image), T(ref_image)], dim=0).unsqueeze(0).cuda()
258 |
259 | output_image = self.forward(x, timesteps, prompt, prompt_tokens)[:, 0]
260 | output_pil = transforms.ToPILImage()(output_image[0].cpu() * 0.5 + 0.5)
261 | output_pil = output_pil.resize((input_width, input_height), Image.LANCZOS)
262 |
263 | return output_pil
264 |
265 | def save_model(self, outf, optimizer):
266 | sd = {}
267 | sd["vae_lora_target_modules"] = self.target_modules_vae
268 | sd["rank_vae"] = self.lora_rank_vae
269 | sd["state_dict_unet"] = {k: v for k, v in self.unet.state_dict().items() if "lora" in k or "conv_in" in k}
270 | sd["state_dict_vae"] = {k: v for k, v in self.vae.state_dict().items() if "lora" in k or "skip" in k}
271 |
272 | sd["optimizer"] = optimizer.state_dict()
273 |
274 | torch.save(sd, outf)
275 |
--------------------------------------------------------------------------------
/examples/nerfstudio/difix3d/difix3d_field.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import Dict, Literal, Optional, Tuple
17 |
18 | import torch
19 | from torch import Tensor, nn
20 |
21 | from nerfstudio.cameras.rays import RaySamples
22 | from nerfstudio.data.scene_box import SceneBox
23 | from nerfstudio.field_components.activations import trunc_exp
24 | from nerfstudio.field_components.embedding import Embedding
25 | from nerfstudio.field_components.encodings import NeRFEncoding, SHEncoding
26 | from nerfstudio.field_components.field_heads import (
27 | FieldHeadNames,
28 | PredNormalsFieldHead,
29 | SemanticFieldHead,
30 | TransientDensityFieldHead,
31 | TransientRGBFieldHead,
32 | UncertaintyFieldHead,
33 | )
34 | from nerfstudio.field_components.mlp import MLP, MLPWithHashEncoding
35 | from nerfstudio.field_components.spatial_distortions import SpatialDistortion
36 | from nerfstudio.fields.base_field import Field, get_normalized_directions
37 |
38 |
39 | class Difix3DField(Field):
40 | """Compound Field
41 |
42 | Args:
43 | aabb: parameters of scene aabb bounds
44 | num_images: number of images in the dataset
45 | num_layers: number of hidden layers
46 | hidden_dim: dimension of hidden layers
47 | geo_feat_dim: output geo feat dimensions
48 | num_levels: number of levels of the hashmap for the base mlp
49 | base_res: base resolution of the hashmap for the base mlp
50 | max_res: maximum resolution of the hashmap for the base mlp
51 | log2_hashmap_size: size of the hashmap for the base mlp
52 | num_layers_color: number of hidden layers for color network
53 | num_layers_transient: number of hidden layers for transient network
54 | features_per_level: number of features per level for the hashgrid
55 | hidden_dim_color: dimension of hidden layers for color network
56 | hidden_dim_transient: dimension of hidden layers for transient network
57 | appearance_embedding_dim: dimension of appearance embedding
58 | transient_embedding_dim: dimension of transient embedding
59 | use_transient_embedding: whether to use transient embedding
60 | use_semantics: whether to use semantic segmentation
61 | num_semantic_classes: number of semantic classes
62 | use_pred_normals: whether to use predicted normals
63 | use_average_appearance_embedding: whether to use average appearance embedding or zeros for inference
64 | spatial_distortion: spatial distortion to apply to the scene
65 | """
66 |
67 | aabb: Tensor
68 |
69 | def __init__(
70 | self,
71 | aabb: Tensor,
72 | num_images: int,
73 | num_layers: int = 2,
74 | hidden_dim: int = 64,
75 | geo_feat_dim: int = 15,
76 | num_levels: int = 16,
77 | base_res: int = 16,
78 | max_res: int = 2048,
79 | log2_hashmap_size: int = 19,
80 | num_layers_color: int = 3,
81 | num_layers_transient: int = 2,
82 | features_per_level: int = 2,
83 | hidden_dim_color: int = 64,
84 | hidden_dim_transient: int = 64,
85 | appearance_embedding_dim: int = 32,
86 | transient_embedding_dim: int = 16,
87 | use_transient_embedding: bool = False,
88 | use_semantics: bool = False,
89 | num_semantic_classes: int = 100,
90 | pass_semantic_gradients: bool = False,
91 | use_pred_normals: bool = False,
92 | use_average_appearance_embedding: bool = False,
93 | spatial_distortion: Optional[SpatialDistortion] = None,
94 | average_init_density: float = 1.0,
95 | implementation: Literal["tcnn", "torch"] = "tcnn",
96 | freeze_appearance_embedding: bool = False,
97 | ) -> None:
98 | super().__init__()
99 |
100 | self.register_buffer("aabb", aabb)
101 | self.geo_feat_dim = geo_feat_dim
102 |
103 | self.register_buffer("max_res", torch.tensor(max_res))
104 | self.register_buffer("num_levels", torch.tensor(num_levels))
105 | self.register_buffer("log2_hashmap_size", torch.tensor(log2_hashmap_size))
106 |
107 | self.spatial_distortion = spatial_distortion
108 | self.num_images = num_images
109 | self.appearance_embedding_dim = appearance_embedding_dim
110 | if self.appearance_embedding_dim > 0:
111 | self.embedding_appearance = Embedding(self.num_images, self.appearance_embedding_dim)
112 | else:
113 | self.embedding_appearance = None
114 | self.use_average_appearance_embedding = use_average_appearance_embedding
115 | self.freeze_appearance_embedding = freeze_appearance_embedding
116 | self.use_transient_embedding = use_transient_embedding
117 | self.use_semantics = use_semantics
118 | self.use_pred_normals = use_pred_normals
119 | self.pass_semantic_gradients = pass_semantic_gradients
120 | self.base_res = base_res
121 | self.average_init_density = average_init_density
122 | self.step = 0
123 |
124 | self.direction_encoding = SHEncoding(
125 | levels=4,
126 | implementation=implementation,
127 | )
128 |
129 | self.position_encoding = NeRFEncoding(
130 | in_dim=3, num_frequencies=2, min_freq_exp=0, max_freq_exp=2 - 1, implementation=implementation
131 | )
132 |
133 | self.mlp_base = MLPWithHashEncoding(
134 | num_levels=num_levels,
135 | min_res=base_res,
136 | max_res=max_res,
137 | log2_hashmap_size=log2_hashmap_size,
138 | features_per_level=features_per_level,
139 | num_layers=num_layers,
140 | layer_width=hidden_dim,
141 | out_dim=1 + self.geo_feat_dim,
142 | activation=nn.ReLU(),
143 | out_activation=None,
144 | implementation=implementation,
145 | )
146 |
147 | # transients
148 | if self.use_transient_embedding:
149 | self.transient_embedding_dim = transient_embedding_dim
150 | self.embedding_transient = Embedding(self.num_images, self.transient_embedding_dim)
151 | self.mlp_transient = MLP(
152 | in_dim=self.geo_feat_dim + self.transient_embedding_dim,
153 | num_layers=num_layers_transient,
154 | layer_width=hidden_dim_transient,
155 | out_dim=hidden_dim_transient,
156 | activation=nn.ReLU(),
157 | out_activation=None,
158 | implementation=implementation,
159 | )
160 | self.field_head_transient_uncertainty = UncertaintyFieldHead(in_dim=self.mlp_transient.get_out_dim())
161 | self.field_head_transient_rgb = TransientRGBFieldHead(in_dim=self.mlp_transient.get_out_dim())
162 | self.field_head_transient_density = TransientDensityFieldHead(in_dim=self.mlp_transient.get_out_dim())
163 |
164 | # semantics
165 | if self.use_semantics:
166 | self.mlp_semantics = MLP(
167 | in_dim=self.geo_feat_dim,
168 | num_layers=2,
169 | layer_width=64,
170 | out_dim=hidden_dim_transient,
171 | activation=nn.ReLU(),
172 | out_activation=None,
173 | implementation=implementation,
174 | )
175 | self.field_head_semantics = SemanticFieldHead(
176 | in_dim=self.mlp_semantics.get_out_dim(), num_classes=num_semantic_classes
177 | )
178 |
179 | # predicted normals
180 | if self.use_pred_normals:
181 | self.mlp_pred_normals = MLP(
182 | in_dim=self.geo_feat_dim + self.position_encoding.get_out_dim(),
183 | num_layers=3,
184 | layer_width=64,
185 | out_dim=hidden_dim_transient,
186 | activation=nn.ReLU(),
187 | out_activation=None,
188 | implementation=implementation,
189 | )
190 | self.field_head_pred_normals = PredNormalsFieldHead(in_dim=self.mlp_pred_normals.get_out_dim())
191 |
192 | self.mlp_head = MLP(
193 | in_dim=self.direction_encoding.get_out_dim() + self.geo_feat_dim + self.appearance_embedding_dim,
194 | num_layers=num_layers_color,
195 | layer_width=hidden_dim_color,
196 | out_dim=3,
197 | activation=nn.ReLU(),
198 | out_activation=nn.Sigmoid(),
199 | implementation=implementation,
200 | )
201 |
202 | def get_density(self, ray_samples: RaySamples) -> Tuple[Tensor, Tensor]:
203 | """Computes and returns the densities."""
204 | if self.spatial_distortion is not None:
205 | positions = ray_samples.frustums.get_positions()
206 | positions = self.spatial_distortion(positions)
207 | positions = (positions + 2.0) / 4.0
208 | else:
209 | positions = SceneBox.get_normalized_positions(ray_samples.frustums.get_positions(), self.aabb)
210 | # Make sure the tcnn gets inputs between 0 and 1.
211 | selector = ((positions > 0.0) & (positions < 1.0)).all(dim=-1)
212 | positions = positions * selector[..., None]
213 |
214 | assert positions.numel() > 0, "positions is empty."
215 |
216 | self._sample_locations = positions
217 | if not self._sample_locations.requires_grad:
218 | self._sample_locations.requires_grad = True
219 | positions_flat = positions.view(-1, 3)
220 |
221 | assert positions_flat.numel() > 0, "positions_flat is empty."
222 | h = self.mlp_base(positions_flat).view(*ray_samples.frustums.shape, -1)
223 | density_before_activation, base_mlp_out = torch.split(h, [1, self.geo_feat_dim], dim=-1)
224 | self._density_before_activation = density_before_activation
225 |
226 | # Rectifying the density with an exponential is much more stable than a ReLU or
227 | # softplus, because it enables high post-activation (float32) density outputs
228 | # from smaller internal (float16) parameters.
229 | density = self.average_init_density * trunc_exp(density_before_activation.to(positions))
230 | density = density * selector[..., None]
231 | return density, base_mlp_out
232 |
233 | def get_outputs(
234 | self, ray_samples: RaySamples, density_embedding: Optional[Tensor] = None
235 | ) -> Dict[FieldHeadNames, Tensor]:
236 | assert density_embedding is not None
237 | outputs = {}
238 | if ray_samples.camera_indices is None:
239 | raise AttributeError("Camera indices are not provided.")
240 | camera_indices = ray_samples.camera_indices.squeeze()
241 | directions = get_normalized_directions(ray_samples.frustums.directions)
242 | directions_flat = directions.view(-1, 3)
243 | d = self.direction_encoding(directions_flat)
244 |
245 | outputs_shape = ray_samples.frustums.directions.shape[:-1]
246 |
247 | # appearance
248 | embedded_appearance = None
249 | if self.embedding_appearance is not None:
250 | if self.training and not self.freeze_appearance_embedding:
251 | embedded_appearance = self.embedding_appearance(camera_indices)
252 | else:
253 | if self.use_average_appearance_embedding:
254 | embedded_appearance = torch.ones(
255 | (*directions.shape[:-1], self.appearance_embedding_dim), device=directions.device
256 | ) * self.embedding_appearance.mean(dim=0)
257 | else:
258 | embedded_appearance = torch.zeros(
259 | (*directions.shape[:-1], self.appearance_embedding_dim), device=directions.device
260 | )
261 |
262 | # transients
263 | if self.use_transient_embedding and self.training:
264 | embedded_transient = self.embedding_transient(camera_indices)
265 | transient_input = torch.cat(
266 | [
267 | density_embedding.view(-1, self.geo_feat_dim),
268 | embedded_transient.view(-1, self.transient_embedding_dim),
269 | ],
270 | dim=-1,
271 | )
272 | x = self.mlp_transient(transient_input).view(*outputs_shape, -1).to(directions)
273 | outputs[FieldHeadNames.UNCERTAINTY] = self.field_head_transient_uncertainty(x)
274 | outputs[FieldHeadNames.TRANSIENT_RGB] = self.field_head_transient_rgb(x)
275 | outputs[FieldHeadNames.TRANSIENT_DENSITY] = self.field_head_transient_density(x)
276 |
277 | # semantics
278 | if self.use_semantics:
279 | semantics_input = density_embedding.view(-1, self.geo_feat_dim)
280 | if not self.pass_semantic_gradients:
281 | semantics_input = semantics_input.detach()
282 |
283 | x = self.mlp_semantics(semantics_input).view(*outputs_shape, -1).to(directions)
284 | outputs[FieldHeadNames.SEMANTICS] = self.field_head_semantics(x)
285 |
286 | # predicted normals
287 | if self.use_pred_normals:
288 | positions = ray_samples.frustums.get_positions()
289 |
290 | positions_flat = self.position_encoding(positions.view(-1, 3))
291 | pred_normals_inp = torch.cat([positions_flat, density_embedding.view(-1, self.geo_feat_dim)], dim=-1)
292 |
293 | x = self.mlp_pred_normals(pred_normals_inp).view(*outputs_shape, -1).to(directions)
294 | outputs[FieldHeadNames.PRED_NORMALS] = self.field_head_pred_normals(x)
295 |
296 | h = torch.cat(
297 | [
298 | d,
299 | density_embedding.view(-1, self.geo_feat_dim),
300 | ]
301 | + (
302 | [embedded_appearance.view(-1, self.appearance_embedding_dim)] if embedded_appearance is not None else []
303 | ),
304 | dim=-1,
305 | )
306 | rgb = self.mlp_head(h).view(*outputs_shape, -1).to(directions)
307 | outputs.update({FieldHeadNames.RGB: rgb})
308 |
309 | return outputs
310 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | NVIDIA License
2 |
3 | 1. Definitions
4 |
5 | “Licensor” means any person or entity that distributes its Work.
6 | “Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license.
7 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
8 | Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license.
9 |
10 | 2. License Grant
11 |
12 | 2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
13 |
14 | 3. Limitations
15 |
16 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.
17 |
18 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.
19 |
20 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or evaluation purposes only.
21 |
22 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately.
23 |
24 | 3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this license.
25 |
26 | 3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately.
27 |
28 | 4. Disclaimer of Warranty.
29 |
30 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
31 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
32 |
33 | 5. Limitation of Liability.
34 |
35 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
36 |
37 | STABILITY AI COMMUNITY LICENSE AGREEMENT
38 |
39 | Last Updated: July 5, 2024
40 |
41 | INTRODUCTION
42 | This Agreement applies to any individual person or entity (“You”, “Your” or “Licensee”) that uses or distributes any portion or element of the Stability AI Materials or Derivative Works thereof for any Research & Non-Commercial or Commercial purpose. Capitalized terms not otherwise defined herein are defined in Section V below.
43 |
44 | This Agreement is intended to allow research, non-commercial, and limited commercial uses of the Models free of charge. In order to ensure that certain limited commercial uses of the Models continue to be allowed, this Agreement preserves free access to the Models for people or organizations generating annual revenue of less than US $1,000,000 (or local currency equivalent).
45 |
46 | By clicking “I Accept” or by using or distributing or using any portion or element of the Stability Materials or Derivative Works, You agree that You have read, understood and are bound by the terms of this Agreement. If You are acting on behalf of a company, organization or other entity, then “You” includes you and that entity, and You agree that You: (i) are an authorized representative of such entity with the authority to bind such entity to this Agreement, and (ii) You agree to the terms of this Agreement on that entity’s behalf.
47 |
48 | RESEARCH & NON-COMMERCIAL USE LICENSE
49 | Subject to the terms of this Agreement, Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Research or Non-Commercial Purpose. “Research Purpose” means academic or scientific advancement, and in each case, is not primarily intended for commercial advantage or monetary compensation to You or others. “Non-Commercial Purpose” means any purpose other than a Research Purpose that is not primarily intended for commercial advantage or monetary compensation to You or others, such as personal use (i.e., hobbyist) or evaluation and testing.
50 |
51 | COMMERCIAL USE LICENSE
52 | Subject to the terms of this Agreement (including the remainder of this Section III), Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Commercial Purpose. “Commercial Purpose” means any purpose other than a Research Purpose or Non-Commercial Purpose that is primarily intended for commercial advantage or monetary compensation to You or others, including but not limited to, (i) creating, modifying, or distributing Your product or service, including via a hosted service or application programming interface, and (ii) for Your business’s or organization’s internal operations. If You are using or distributing the Stability AI Materials for a Commercial Purpose, You must register with Stability AI at (https://stability.ai/community-license). If at any time You or Your Affiliate(s), either individually or in aggregate, generate more than USD $1,000,000 in annual revenue (or the equivalent thereof in Your local currency), regardless of whether that revenue is generated directly or indirectly from the Stability AI Materials or Derivative Works, any licenses granted to You under this Agreement shall terminate as of such date. You must request a license from Stability AI at (https://stability.ai/enterprise) , which Stability AI may grant to You in its sole discretion. If you receive Stability AI Materials, or any Derivative Works thereof, from a Licensee as part of an integrated end user product, then Section III of this Agreement will not apply to you.
53 |
54 | GENERAL TERMS
55 | Your Research, Non-Commercial, and Commercial License(s) under this Agreement are subject to the following terms. a. Distribution & Attribution. If You distribute or make available the Stability AI Materials or a Derivative Work to a third party, or a product or service that uses any portion of them, You shall: (i) provide a copy of this Agreement to that third party, (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved”, and (iii) prominently display “Powered by Stability AI” on a related website, user interface, blogpost, about page, or product documentation. If You create a Derivative Work, You may add your own attribution notice(s) to the “Notice” text file included with that Derivative Work, provided that You clearly indicate which attributions apply to the Stability AI Materials and state in the “Notice” text file that You changed the Stability AI Materials and how it was modified. b. Use Restrictions. Your use of the Stability AI Materials and Derivative Works, including any output or results of the Stability AI Materials or Derivative Works, must comply with applicable laws and regulations (including Trade Control Laws and equivalent regulations) and adhere to the Documentation and Stability AI’s AUP, which is hereby incorporated by reference. Furthermore, You will not use the Stability AI Materials or Derivative Works, or any output or results of the Stability AI Materials or Derivative Works, to create or improve any foundational generative AI model (excluding the Models or Derivative Works). c. Intellectual Property. (i) Trademark License. No trademark licenses are granted under this Agreement, and in connection with the Stability AI Materials or Derivative Works, You may not use any name or mark owned by or associated with Stability AI or any of its Affiliates, except as required under Section IV(a) herein. (ii) Ownership of Derivative Works. As between You and Stability AI, You are the owner of Derivative Works You create, subject to Stability AI’s ownership of the Stability AI Materials and any Derivative Works made by or for Stability AI. (iii) Ownership of Outputs. As between You and Stability AI, You own any outputs generated from the Models or Derivative Works to the extent permitted by applicable law. (iv) Disputes. If You or Your Affiliate(s) institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Stability AI Materials, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to Your use or distribution of the Stability AI Materials or Derivative Works in violation of this Agreement. (v) Feedback. From time to time, You may provide Stability AI with verbal and/or written suggestions, comments or other feedback related to Stability AI’s existing or prospective technology, products or services (collectively, “Feedback”). 
You are not obligated to provide Stability AI with Feedback, but to the extent that You do, You hereby grant Stability AI a perpetual, irrevocable, royalty-free, fully-paid, sub-licensable, transferable, non-exclusive, worldwide right and license to exploit the Feedback in any manner without restriction. Your Feedback is provided “AS IS” and You make no warranties whatsoever about any Feedback. d. Disclaimer Of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE STABILITY AI MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OR LAWFULNESS OF USING OR REDISTRIBUTING THE STABILITY AI MATERIALS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE STABILITY AI MATERIALS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS. e. Limitation Of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. f. Term And Termination. The term of this Agreement will commence upon Your acceptance of this Agreement or access to the Stability AI Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, You shall delete and cease use of any Stability AI Materials or Derivative Works. Section IV(d), (e), and (g) shall survive the termination of this Agreement. g. Governing Law. This Agreement will be governed by and constructed in accordance with the laws of the United States and the State of California without regard to choice of law principles, and the UN Convention on Contracts for International Sale of Goods does not apply to this Agreement.
56 |
57 | DEFINITIONS
58 | “Affiliate(s)” means any entity that directly or indirectly controls, is controlled by, or is under common control with the subject entity; for purposes of this definition, “control” means direct or indirect ownership or control of more than 50% of the voting interests of the subject entity.
59 |
60 | "Agreement" means this Stability AI Community License Agreement.
61 |
62 | “AUP” means the Stability AI Acceptable Use Policy available at (https://stability.ai/use-policy), as may be updated from time to time.
63 |
64 | "Derivative Work(s)” means (a) any derivative work of the Stability AI Materials as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output, including “fine tune” and “low-rank adaptation” models derived from a Model or a Model’s output, but do not include the output of any Model.
65 |
66 | “Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software or Models.
67 |
68 | “Model(s)" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing listed on Stability’s Core Models Webpage available at (https://stability.ai/core-models), as may be updated from time to time.
69 |
70 | "Stability AI" or "we" means Stability AI Ltd. and its Affiliates.
71 |
72 | "Software" means Stability AI’s proprietary software made available under this Agreement now or in the future.
73 |
74 | “Stability AI Materials” means, collectively, Stability’s proprietary Models, Software and Documentation (and any portion or combination thereof) made available under this Agreement.
75 |
76 | “Trade Control Laws” means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations.
--------------------------------------------------------------------------------
/src/train_difix.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gc
3 | import lpips
4 | import random
5 | import argparse
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | import torch.utils.checkpoint
10 | import torchvision
11 | import transformers
12 | from torchvision.transforms.functional import crop
13 | from accelerate import Accelerator
14 | from accelerate.utils import set_seed
15 | from PIL import Image
16 | from torchvision import transforms
17 | from tqdm.auto import tqdm
18 | from glob import glob
19 | from einops import rearrange
20 |
21 | import diffusers
22 | from diffusers.utils.import_utils import is_xformers_available
23 | from diffusers.optimization import get_scheduler
24 |
25 | import wandb
26 |
27 | from model import Difix, load_ckpt_from_state_dict, save_ckpt
28 | from dataset import PairedDataset
29 | from loss import gram_loss
30 |
31 |
32 | def main(args):
33 | accelerator = Accelerator(
34 | gradient_accumulation_steps=args.gradient_accumulation_steps,
35 | mixed_precision=args.mixed_precision,
36 | log_with=args.report_to,
37 | )
38 |
39 | if accelerator.is_local_main_process:
40 | transformers.utils.logging.set_verbosity_warning()
41 | diffusers.utils.logging.set_verbosity_info()
42 | else:
43 | transformers.utils.logging.set_verbosity_error()
44 | diffusers.utils.logging.set_verbosity_error()
45 |
46 | if args.seed is not None:
47 | set_seed(args.seed)
48 |
49 | if accelerator.is_main_process:
50 | os.makedirs(os.path.join(args.output_dir, "checkpoints"), exist_ok=True)
51 | os.makedirs(os.path.join(args.output_dir, "eval"), exist_ok=True)
52 |
53 | net_difix = Difix(
54 | lora_rank_vae=args.lora_rank_vae,
55 | timestep=args.timestep,
56 | mv_unet=args.mv_unet,
57 | )
58 | net_difix.set_train()
59 |
60 | if args.enable_xformers_memory_efficient_attention:
61 | if is_xformers_available():
62 | net_difix.unet.enable_xformers_memory_efficient_attention()
63 | else:
64 | raise ValueError("xformers is not available, please install it by running `pip install xformers`")
65 |
66 | if args.gradient_checkpointing:
67 | net_difix.unet.enable_gradient_checkpointing()
68 |
69 | if args.allow_tf32:
70 | torch.backends.cuda.matmul.allow_tf32 = True
71 |
72 | net_lpips = lpips.LPIPS(net='vgg').cuda()
73 |
74 | net_lpips.requires_grad_(False)
75 |
76 | net_vgg = torchvision.models.vgg16(pretrained=True).features
77 | for param in net_vgg.parameters():
78 | param.requires_grad_(False)
79 |
80 | # make the optimizer
81 | layers_to_opt = []
82 | layers_to_opt += list(net_difix.unet.parameters())
83 |
84 | for n, _p in net_difix.vae.named_parameters():
85 | if "lora" in n and "vae_skip" in n:
86 | assert _p.requires_grad
87 | layers_to_opt.append(_p)
88 | layers_to_opt = layers_to_opt + list(net_difix.vae.decoder.skip_conv_1.parameters()) + \
89 | list(net_difix.vae.decoder.skip_conv_2.parameters()) + \
90 | list(net_difix.vae.decoder.skip_conv_3.parameters()) + \
91 | list(net_difix.vae.decoder.skip_conv_4.parameters())
92 |
93 | optimizer = torch.optim.AdamW(layers_to_opt, lr=args.learning_rate,
94 | betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
95 | eps=args.adam_epsilon,)
96 | lr_scheduler = get_scheduler(args.lr_scheduler, optimizer=optimizer,
97 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
98 | num_training_steps=args.max_train_steps * accelerator.num_processes,
99 | num_cycles=args.lr_num_cycles, power=args.lr_power,)
100 |
101 | dataset_train = PairedDataset(dataset_path=args.dataset_path, split="train", tokenizer=net_difix.tokenizer)
102 | dl_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers)
103 | dataset_val = PairedDataset(dataset_path=args.dataset_path, split="test", tokenizer=net_difix.tokenizer)
104 | random.Random(42).shuffle(dataset_val.img_names)
105 | dl_val = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=0)
106 |
107 | # Resume from checkpoint
108 | global_step = 0
109 | if args.resume is not None:
110 | if os.path.isdir(args.resume):
111 | # Resume from last ckpt
112 | ckpt_files = glob(os.path.join(args.resume, "*.pkl"))
113 | assert len(ckpt_files) > 0, f"No checkpoint files found: {args.resume}"
114 | ckpt_files = sorted(ckpt_files, key=lambda x: int(x.split("/")[-1].replace("model_", "").replace(".pkl", "")))
115 | print("="*50); print(f"Loading checkpoint from {ckpt_files[-1]}"); print("="*50)
116 | global_step = int(ckpt_files[-1].split("/")[-1].replace("model_", "").replace(".pkl", ""))
117 | net_difix, optimizer = load_ckpt_from_state_dict(
118 | net_difix, optimizer, ckpt_files[-1]
119 | )
120 | elif args.resume.endswith(".pkl"):
121 | print("="*50); print(f"Loading checkpoint from {args.resume}"); print("="*50)
122 | global_step = int(args.resume.split("/")[-1].replace("model_", "").replace(".pkl", ""))
123 | net_difix, optimizer = load_ckpt_from_state_dict(
124 | net_difix, optimizer, args.resume
125 | )
126 | else:
127 | raise NotImplementedError(f"Invalid resume path: {args.resume}")
128 | else:
129 | print("="*50); print(f"Training from scratch"); print("="*50)
130 |
131 | weight_dtype = torch.float32
132 | if accelerator.mixed_precision == "fp16":
133 | weight_dtype = torch.float16
134 | elif accelerator.mixed_precision == "bf16":
135 | weight_dtype = torch.bfloat16
136 |
137 | # Move all networks to device and cast to weight_dtype
138 | net_difix.to(accelerator.device, dtype=weight_dtype)
139 | net_lpips.to(accelerator.device, dtype=weight_dtype)
140 | net_vgg.to(accelerator.device, dtype=weight_dtype)
141 |
142 | # Prepare everything with our `accelerator`.
143 | net_difix, optimizer, dl_train, lr_scheduler = accelerator.prepare(
144 | net_difix, optimizer, dl_train, lr_scheduler
145 | )
146 | net_lpips, net_vgg = accelerator.prepare(net_lpips, net_vgg)
147 | # renormalize with ImageNet statistics
148 | t_vgg_renorm = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
149 |
150 | # We need to initialize the trackers we use, and also store our configuration.
151 | # The trackers initialize automatically on the main process.
152 | if accelerator.is_main_process:
153 | init_kwargs = {
154 | "wandb": {
155 | "name": args.tracker_run_name,
156 | "dir": args.output_dir,
157 | },
158 | }
159 | tracker_config = dict(vars(args))
160 | accelerator.init_trackers(args.tracker_project_name, config=tracker_config, init_kwargs=init_kwargs)
161 |
162 | progress_bar = tqdm(range(0, args.max_train_steps), initial=global_step, desc="Steps",
163 | disable=not accelerator.is_local_main_process,)
164 |
165 | # start the training loop
166 | for epoch in range(0, args.num_training_epochs):
167 | for step, batch in enumerate(dl_train):
168 | l_acc = [net_difix]
169 | with accelerator.accumulate(*l_acc):
170 | x_src = batch["conditioning_pixel_values"]
171 | x_tgt = batch["output_pixel_values"]
172 | B, V, C, H, W = x_src.shape
173 |
174 | # forward pass
175 | x_tgt_pred = net_difix(x_src, prompt_tokens=batch["input_ids"])
176 |
177 | x_tgt = rearrange(x_tgt, 'b v c h w -> (b v) c h w')
178 | x_tgt_pred = rearrange(x_tgt_pred, 'b v c h w -> (b v) c h w')
179 |
180 | # Reconstruction loss
181 | loss_l2 = F.mse_loss(x_tgt_pred.float(), x_tgt.float(), reduction="mean") * args.lambda_l2
182 | loss_lpips = net_lpips(x_tgt_pred.float(), x_tgt.float()).mean() * args.lambda_lpips
183 | loss = loss_l2 + loss_lpips
184 |
185 | # Gram matrix loss
186 | if args.lambda_gram > 0:
187 | if global_step > args.gram_loss_warmup_steps:
188 | x_tgt_pred_renorm = t_vgg_renorm(x_tgt_pred * 0.5 + 0.5)
189 | crop_h, crop_w = 400, 400
190 | top, left = random.randint(0, H - crop_h), random.randint(0, W - crop_w)
191 | x_tgt_pred_renorm = crop(x_tgt_pred_renorm, top, left, crop_h, crop_w)
192 |
193 | x_tgt_renorm = t_vgg_renorm(x_tgt * 0.5 + 0.5)
194 | x_tgt_renorm = crop(x_tgt_renorm, top, left, crop_h, crop_w)
195 |
196 | loss_gram = gram_loss(x_tgt_pred_renorm.to(weight_dtype), x_tgt_renorm.to(weight_dtype), net_vgg) * args.lambda_gram
197 | loss += loss_gram
198 | else:
199 | loss_gram = torch.tensor(0.0).to(weight_dtype)
200 |
201 | accelerator.backward(loss, retain_graph=False)
202 | if accelerator.sync_gradients:
203 | accelerator.clip_grad_norm_(layers_to_opt, args.max_grad_norm)
204 | optimizer.step()
205 | lr_scheduler.step()
206 | optimizer.zero_grad(set_to_none=args.set_grads_to_none)
207 |
208 | x_tgt = rearrange(x_tgt, '(b v) c h w -> b v c h w', v=V)
209 | x_tgt_pred = rearrange(x_tgt_pred, '(b v) c h w -> b v c h w', v=V)
210 |
211 | # Checks if the accelerator has performed an optimization step behind the scenes
212 | if accelerator.sync_gradients:
213 | progress_bar.update(1)
214 | global_step += 1
215 |
216 | if accelerator.is_main_process:
217 | logs = {}
218 | # log all the losses
219 | logs["loss_l2"] = loss_l2.detach().item()
220 | logs["loss_lpips"] = loss_lpips.detach().item()
221 | if args.lambda_gram > 0:
222 | logs["loss_gram"] = loss_gram.detach().item()
223 | progress_bar.set_postfix(**logs)
224 |
225 | # viz some images
226 | if global_step % args.viz_freq == 1:
227 | log_dict = {
228 | "train/source": [wandb.Image(rearrange(x_src, "b v c h w -> b c (v h) w")[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(B)],
229 | "train/target": [wandb.Image(rearrange(x_tgt, "b v c h w -> b c (v h) w")[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(B)],
230 | "train/model_output": [wandb.Image(rearrange(x_tgt_pred, "b v c h w -> b c (v h) w")[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(B)],
231 | }
232 | for k in log_dict:
233 | logs[k] = log_dict[k]
234 |
235 | # checkpoint the model
236 | if global_step % args.checkpointing_steps == 1:
237 | outf = os.path.join(args.output_dir, "checkpoints", f"model_{global_step}.pkl")
238 | # accelerator.unwrap_model(net_difix).save_model(outf)
239 | save_ckpt(accelerator.unwrap_model(net_difix), optimizer, outf)
240 |
241 | # compute validation set L2, LPIPS
242 | if args.eval_freq > 0 and global_step % args.eval_freq == 1:
243 | l_l2, l_lpips = [], []
244 | log_dict = {"sample/source": [], "sample/target": [], "sample/model_output": []}
245 | for step, batch_val in enumerate(dl_val):
246 | if step >= args.num_samples_eval:
247 | break
248 | x_src = batch_val["conditioning_pixel_values"].to(accelerator.device, dtype=weight_dtype)
249 | x_tgt = batch_val["output_pixel_values"].to(accelerator.device, dtype=weight_dtype)
250 | B, V, C, H, W = x_src.shape
251 | assert B == 1, "Use batch size 1 for eval."
252 | with torch.no_grad():
253 | # forward pass
254 | x_tgt_pred = accelerator.unwrap_model(net_difix)(x_src, prompt_tokens=batch_val["input_ids"].cuda())
255 |
256 | if step % 10 == 0:
257 | log_dict["sample/source"].append(wandb.Image(rearrange(x_src, "b v c h w -> b c (v h) w")[0].float().detach().cpu(), caption=f"idx={len(log_dict['sample/source'])}"))
258 | log_dict["sample/target"].append(wandb.Image(rearrange(x_tgt, "b v c h w -> b c (v h) w")[0].float().detach().cpu(), caption=f"idx={len(log_dict['sample/source'])}"))
259 | log_dict["sample/model_output"].append(wandb.Image(rearrange(x_tgt_pred, "b v c h w -> b c (v h) w")[0].float().detach().cpu(), caption=f"idx={len(log_dict['sample/source'])}"))
260 |
261 | x_tgt = x_tgt[:, 0] # take the input view
262 | x_tgt_pred = x_tgt_pred[:, 0] # take the input view
263 | # compute the reconstruction losses
264 | loss_l2 = F.mse_loss(x_tgt_pred.float(), x_tgt.float(), reduction="mean")
265 | loss_lpips = net_lpips(x_tgt_pred.float(), x_tgt.float()).mean()
266 |
267 | l_l2.append(loss_l2.item())
268 | l_lpips.append(loss_lpips.item())
269 |
270 | logs["val/l2"] = np.mean(l_l2)
271 | logs["val/lpips"] = np.mean(l_lpips)
272 | for k in log_dict:
273 | logs[k] = log_dict[k]
274 | gc.collect()
275 | torch.cuda.empty_cache()
276 | accelerator.log(logs, step=global_step)
277 |
278 |
279 | if __name__ == "__main__":
280 |
281 | parser = argparse.ArgumentParser()
282 | # args for the loss function
283 | parser.add_argument("--lambda_lpips", default=1.0, type=float)
284 | parser.add_argument("--lambda_l2", default=1.0, type=float)
285 | parser.add_argument("--lambda_gram", default=1.0, type=float)
286 | parser.add_argument("--gram_loss_warmup_steps", default=2000, type=int)
287 |
288 | # dataset options
289 | parser.add_argument("--dataset_path", required=True, type=str)
290 | parser.add_argument("--train_image_prep", default="resized_crop_512", type=str)
291 | parser.add_argument("--test_image_prep", default="resized_crop_512", type=str)
292 | parser.add_argument("--prompt", default=None, type=str)
293 |
294 | # validation eval args
295 | parser.add_argument("--eval_freq", default=100, type=int)
296 | parser.add_argument("--num_samples_eval", type=int, default=100, help="Number of samples to use for all evaluation")
297 |
298 | parser.add_argument("--viz_freq", type=int, default=100, help="Frequency of visualizing the outputs.")
299 | parser.add_argument("--tracker_project_name", type=str, default="difix", help="The name of the wandb project to log to.")
300 | parser.add_argument("--tracker_run_name", type=str, required=True)
301 |
302 | # details about the model architecture
303 | parser.add_argument("--pretrained_model_name_or_path")
304 | parser.add_argument("--revision", type=str, default=None,)
305 | parser.add_argument("--variant", type=str, default=None,)
306 | parser.add_argument("--tokenizer_name", type=str, default=None)
307 | parser.add_argument("--lora_rank_vae", default=4, type=int)
308 | parser.add_argument("--timestep", default=199, type=int)
309 | parser.add_argument("--mv_unet", action="store_true")
310 |
311 | # training details
312 | parser.add_argument("--output_dir", required=True)
313 | parser.add_argument("--cache_dir", default=None,)
314 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
315 | parser.add_argument("--resolution", type=int, default=512,)
316 | parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader.")
317 | parser.add_argument("--num_training_epochs", type=int, default=10)
318 | parser.add_argument("--max_train_steps", type=int, default=10_000,)
319 | parser.add_argument("--checkpointing_steps", type=int, default=500,)
320 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.",)
321 | parser.add_argument("--gradient_checkpointing", action="store_true",)
322 | parser.add_argument("--learning_rate", type=float, default=5e-6)
323 | parser.add_argument("--lr_scheduler", type=str, default="constant",
324 | help=(
325 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
326 | ' "constant", "constant_with_warmup"]'
327 | ),
328 | )
329 | parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.")
330 | parser.add_argument("--lr_num_cycles", type=int, default=1,
331 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
332 | )
333 | parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
334 |
335 | parser.add_argument("--dataloader_num_workers", type=int, default=0,)
336 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
337 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
338 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
339 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
340 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
341 | parser.add_argument("--allow_tf32", action="store_true",
342 | help=(
343 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
344 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
345 | ),
346 | )
347 | parser.add_argument("--report_to", type=str, default="wandb",
348 | help=(
349 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
350 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
351 | ),
352 | )
353 | parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"],)
354 | parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
355 | parser.add_argument("--set_grads_to_none", action="store_true",)
356 |
357 | # resume
358 | parser.add_argument("--resume", default=None, type=str)
359 |
360 | args = parser.parse_args()
361 |
362 | main(args)
363 |
--------------------------------------------------------------------------------
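Note on the multi-view batching in `train_difix.py`: batches carry an extra view axis `V`, which the script folds into the batch axis with `einops.rearrange` before computing the pixel losses and unfolds again for checkpointing and logging. Below is a minimal, self-contained sketch of that reshaping; the tensor sizes are dummies, not the actual training data.

```python
# Dummy-shape sketch of the view folding used in train_difix.py.
import torch
from einops import rearrange

B, V, C, H, W = 2, 3, 3, 8, 8
x = torch.rand(B, V, C, H, W)                            # a batch of multi-view images

flat = rearrange(x, "b v c h w -> (b v) c h w")          # (6, 3, 8, 8): per-view frames for MSE/LPIPS
back = rearrange(flat, "(b v) c h w -> b v c h w", v=V)  # restore (2, 3, 3, 8, 8) after the losses
grid = rearrange(x, "b v c h w -> b c (v h) w")          # views stacked vertically, as logged to wandb

assert torch.equal(x, back)
print(flat.shape, back.shape, grid.shape)
```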
/examples/gsplat/datasets/colmap.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from typing import Any, Dict, List, Optional
4 | from typing_extensions import assert_never
5 |
6 | import cv2
7 | import imageio.v2 as imageio
8 | import numpy as np
9 | import torch
10 | from pycolmap import SceneManager
11 |
12 | from .normalize import (
13 | align_principle_axes,
14 | similarity_from_cameras,
15 | transform_cameras,
16 | transform_points,
17 | )
18 |
19 |
20 | def _get_rel_paths(path_dir: str) -> List[str]:
21 | """Recursively get relative paths of files in a directory."""
22 | paths = []
23 | for dp, dn, fn in os.walk(path_dir):
24 | for f in fn:
25 | paths.append(os.path.relpath(os.path.join(dp, f), path_dir))
26 | return paths
27 |
28 |
29 | class Parser:
30 | """COLMAP parser."""
31 |
32 | def __init__(
33 | self,
34 | data_dir: str,
35 | factor: int = 1,
36 | normalize: bool = False,
37 | test_every: int = 8,
38 | ):
39 | self.data_dir = data_dir
40 | self.factor = factor
41 | self.normalize = normalize
42 | self.test_every = test_every
43 |
44 | colmap_dir = os.path.join(data_dir, "sparse/0/")
45 | if not os.path.exists(colmap_dir):
46 | # colmap_dir = os.path.join(data_dir, "sparse")
47 | colmap_dir = os.path.join(data_dir, "colmap/sparse/0")
48 | assert os.path.exists(
49 | colmap_dir
50 | ), f"COLMAP directory {colmap_dir} does not exist."
51 |
52 | manager = SceneManager(colmap_dir)
53 | manager.load_cameras()
54 | manager.load_images()
55 | manager.load_points3D()
56 |
57 | # Extract extrinsic matrices in world-to-camera format.
58 | imdata = manager.images
59 | w2c_mats = []
60 | camera_ids = []
61 | Ks_dict = dict()
62 | params_dict = dict()
63 | imsize_dict = dict() # width, height
64 | mask_dict = dict()
65 | bottom = np.array([0, 0, 0, 1]).reshape(1, 4)
66 | for k in imdata:
67 | im = imdata[k]
68 | rot = im.R()
69 | trans = im.tvec.reshape(3, 1)
70 | w2c = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0)
71 | w2c_mats.append(w2c)
72 |
73 | # support different camera intrinsics
74 | camera_id = im.camera_id
75 | camera_ids.append(camera_id)
76 |
77 | # camera intrinsics
78 | cam = manager.cameras[camera_id]
79 | fx, fy, cx, cy = cam.fx, cam.fy, cam.cx, cam.cy
80 | K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
81 | K[:2, :] /= factor
82 | Ks_dict[camera_id] = K
83 |
84 | # Get distortion parameters.
85 | type_ = cam.camera_type
86 | if type_ == 0 or type_ == "SIMPLE_PINHOLE":
87 | params = np.empty(0, dtype=np.float32)
88 | camtype = "perspective"
89 | elif type_ == 1 or type_ == "PINHOLE":
90 | params = np.empty(0, dtype=np.float32)
91 | camtype = "perspective"
92 | elif type_ == 2 or type_ == "SIMPLE_RADIAL":
93 | params = np.array([cam.k1, 0.0, 0.0, 0.0], dtype=np.float32)
94 | camtype = "perspective"
95 | elif type_ == 3 or type_ == "RADIAL":
96 | params = np.array([cam.k1, cam.k2, 0.0, 0.0], dtype=np.float32)
97 | camtype = "perspective"
98 | elif type_ == 4 or type_ == "OPENCV":
99 | params = np.array([cam.k1, cam.k2, cam.p1, cam.p2], dtype=np.float32)
100 | camtype = "perspective"
101 | elif type_ == 5 or type_ == "OPENCV_FISHEYE":
102 | params = np.array([cam.k1, cam.k2, cam.k3, cam.k4], dtype=np.float32)
103 | camtype = "fisheye"
104 | assert (
105 | camtype == "perspective" or camtype == "fisheye"
106 | ), f"Only perspective and fisheye cameras are supported, got {type_}"
107 |
108 | params_dict[camera_id] = params
109 | imsize_dict[camera_id] = (cam.width // factor, cam.height // factor)
110 | mask_dict[camera_id] = None
111 | print(
112 | f"[Parser] {len(imdata)} images, taken by {len(set(camera_ids))} cameras."
113 | )
114 |
115 | if len(imdata) == 0:
116 | raise ValueError("No images found in COLMAP.")
117 | if not (type_ == 0 or type_ == 1):
118 | print("Warning: COLMAP Camera is not PINHOLE. Images have distortion.")
119 |
120 | w2c_mats = np.stack(w2c_mats, axis=0)
121 |
122 | # Convert extrinsics to camera-to-world.
123 | camtoworlds = np.linalg.inv(w2c_mats)
124 |
125 | # Image names from COLMAP. No need for permuting the poses according to
126 | # image names anymore.
127 | image_names = [imdata[k].name for k in imdata]
128 |
129 | # Previous Nerf results were generated with images sorted by filename,
130 | # ensure metrics are reported on the same test set.
131 | inds = np.argsort(image_names)
132 | image_names = [image_names[i] for i in inds]
133 | camtoworlds = camtoworlds[inds]
134 | camera_ids = [camera_ids[i] for i in inds]
135 |
136 | # Load extended metadata. Used by Bilarf dataset.
137 | self.extconf = {
138 | "spiral_radius_scale": 1.0,
139 | "no_factor_suffix": False,
140 | }
141 | extconf_file = os.path.join(data_dir, "ext_metadata.json")
142 | if os.path.exists(extconf_file):
143 | with open(extconf_file) as f:
144 | self.extconf.update(json.load(f))
145 |
146 | # Load bounds if possible (only used in forward facing scenes).
147 | self.bounds = np.array([0.01, 1.0])
148 | posefile = os.path.join(data_dir, "poses_bounds.npy")
149 | if os.path.exists(posefile):
150 | self.bounds = np.load(posefile)[:, -2:]
151 |
152 | # Load images.
153 | if factor > 1 and not self.extconf["no_factor_suffix"]:
154 | image_dir_suffix = f"_{factor}"
155 | else:
156 | image_dir_suffix = ""
157 | colmap_image_dir = os.path.join(data_dir, "images")
158 | image_dir = os.path.join(data_dir, "images" + image_dir_suffix)
159 | for d in [image_dir, colmap_image_dir]:
160 | if not os.path.exists(d):
161 | raise ValueError(f"Image folder {d} does not exist.")
162 |
163 | # Downsampled images may have different names vs images used for COLMAP,
164 | # so we need to map between the two sorted lists of files.
165 | if "3dv-dataset-nerfstudio" in data_dir:
166 | colmap_files = sorted(_get_rel_paths(colmap_image_dir), key=lambda x: int(x.split(".")[0].split("_")[-1]))
167 | image_files = sorted(_get_rel_paths(image_dir), key=lambda x: int(x.split(".")[0].split("_")[-1]))
168 | colmap_to_image = dict(zip(colmap_files, image_files))
169 | image_names = colmap_files
170 | image_paths = [os.path.join(image_dir, colmap_to_image[f]) for f in image_names]
171 | elif "DL3DV-Benchmark" in data_dir:
172 | colmap_files = sorted(_get_rel_paths(colmap_image_dir))
173 | image_files = sorted(_get_rel_paths(image_dir))
174 | colmap_to_image = dict(zip(colmap_files, image_files))
175 | if len(colmap_files) != len(image_names):
176 | print(f"Warning: colmap_files: {len(colmap_files)}, image_names: {len(image_names)}")
177 | image_names = colmap_files
178 | image_paths = [os.path.join(image_dir, colmap_to_image[f]) for f in image_names]
179 | else:
180 | colmap_files = sorted(_get_rel_paths(colmap_image_dir))
181 | image_files = sorted(_get_rel_paths(image_dir))
182 | colmap_to_image = dict(zip(colmap_files, image_files))
183 | image_paths = [os.path.join(image_dir, colmap_to_image[f]) for f in image_names]
184 |
185 | # 3D points and {image_name -> [point_idx]}
186 | points = manager.points3D.astype(np.float32)
187 | points_err = manager.point3D_errors.astype(np.float32)
188 | points_rgb = manager.point3D_colors.astype(np.uint8)
189 | point_indices = dict()
190 |
191 | image_id_to_name = {v: k for k, v in manager.name_to_image_id.items()}
192 | for point_id, data in manager.point3D_id_to_images.items():
193 | for image_id, _ in data:
194 | image_name = image_id_to_name[image_id]
195 | point_idx = manager.point3D_id_to_point3D_idx[point_id]
196 | point_indices.setdefault(image_name, []).append(point_idx)
197 | point_indices = {
198 | k: np.array(v).astype(np.int32) for k, v in point_indices.items()
199 | }
200 |
201 | # Normalize the world space.
202 | if normalize:
203 | T1 = similarity_from_cameras(camtoworlds)
204 | camtoworlds = transform_cameras(T1, camtoworlds)
205 | points = transform_points(T1, points)
206 |
207 | T2 = align_principle_axes(points)
208 | camtoworlds = transform_cameras(T2, camtoworlds)
209 | points = transform_points(T2, points)
210 |
211 | transform = T2 @ T1
212 | else:
213 | transform = np.eye(4)
214 |
215 | self.image_names = image_names # List[str], (num_images,)
216 | self.image_paths = image_paths # List[str], (num_images,)
217 | self.alpha_mask_paths = None # List[str], (num_images,)
218 | self.camtoworlds = camtoworlds # np.ndarray, (num_images, 4, 4)
219 | self.camera_ids = camera_ids # List[int], (num_images,)
220 | self.Ks_dict = Ks_dict # Dict of camera_id -> K
221 | self.params_dict = params_dict # Dict of camera_id -> params
222 | self.imsize_dict = imsize_dict # Dict of camera_id -> (width, height)
223 | self.mask_dict = mask_dict # Dict of camera_id -> mask
224 | self.points = points # np.ndarray, (num_points, 3)
225 | self.points_err = points_err # np.ndarray, (num_points,)
226 | self.points_rgb = points_rgb # np.ndarray, (num_points, 3)
227 | self.point_indices = point_indices # Dict[str, np.ndarray], image_name -> [M,]
228 | self.transform = transform # np.ndarray, (4, 4)
229 |
230 | # load one image to check the size. In the case of tanksandtemples dataset, the
231 | # intrinsics stored in COLMAP correspond to 2x upsampled images.
232 | actual_image = imageio.imread(self.image_paths[0])[..., :3]
233 | actual_height, actual_width = actual_image.shape[:2]
234 | colmap_width, colmap_height = self.imsize_dict[self.camera_ids[0]]
235 | s_height, s_width = actual_height / colmap_height, actual_width / colmap_width
236 | for camera_id, K in self.Ks_dict.items():
237 | K[0, :] *= s_width
238 | K[1, :] *= s_height
239 | self.Ks_dict[camera_id] = K
240 | width, height = self.imsize_dict[camera_id]
241 | self.imsize_dict[camera_id] = (actual_width, actual_height)
242 |
243 | # undistortion
244 | self.mapx_dict = dict()
245 | self.mapy_dict = dict()
246 | self.roi_undist_dict = dict()
247 | for camera_id in self.params_dict.keys():
248 | params = self.params_dict[camera_id]
249 | if len(params) == 0:
250 | continue # no distortion
251 | assert camera_id in self.Ks_dict, f"Missing K for camera {camera_id}"
252 | assert (
253 | camera_id in self.params_dict
254 | ), f"Missing params for camera {camera_id}"
255 | K = self.Ks_dict[camera_id]
256 | width, height = self.imsize_dict[camera_id]
257 |
258 | if camtype == "perspective":
259 | K_undist, roi_undist = cv2.getOptimalNewCameraMatrix(
260 | K, params, (width, height), 0
261 | )
262 | mapx, mapy = cv2.initUndistortRectifyMap(
263 | K, params, None, K_undist, (width, height), cv2.CV_32FC1
264 | )
265 | mask = None
266 | elif camtype == "fisheye":
267 | fx = K[0, 0]
268 | fy = K[1, 1]
269 | cx = K[0, 2]
270 | cy = K[1, 2]
271 | grid_x, grid_y = np.meshgrid(
272 | np.arange(width, dtype=np.float32),
273 | np.arange(height, dtype=np.float32),
274 | indexing="xy",
275 | )
276 | x1 = (grid_x - cx) / fx
277 | y1 = (grid_y - cy) / fy
278 | theta = np.sqrt(x1**2 + y1**2)
279 | r = (
280 | 1.0
281 | + params[0] * theta**2
282 | + params[1] * theta**4
283 | + params[2] * theta**6
284 | + params[3] * theta**8
285 | )
286 | mapx = fx * x1 * r + width // 2
287 | mapy = fy * y1 * r + height // 2
288 |
289 | # Use mask to define ROI
290 | mask = np.logical_and(
291 | np.logical_and(mapx > 0, mapy > 0),
292 | np.logical_and(mapx < width - 1, mapy < height - 1),
293 | )
294 | y_indices, x_indices = np.nonzero(mask)
295 | y_min, y_max = y_indices.min(), y_indices.max() + 1
296 | x_min, x_max = x_indices.min(), x_indices.max() + 1
297 | mask = mask[y_min:y_max, x_min:x_max]
298 | K_undist = K.copy()
299 | K_undist[0, 2] -= x_min
300 | K_undist[1, 2] -= y_min
301 | roi_undist = [x_min, y_min, x_max - x_min, y_max - y_min]
302 | else:
303 | assert_never(camtype)
304 |
305 | self.mapx_dict[camera_id] = mapx
306 | self.mapy_dict[camera_id] = mapy
307 | self.Ks_dict[camera_id] = K_undist
308 | self.roi_undist_dict[camera_id] = roi_undist
309 | self.imsize_dict[camera_id] = (roi_undist[2], roi_undist[3])
310 | self.mask_dict[camera_id] = mask
311 |
312 | # size of the scene measured by cameras
313 | camera_locations = camtoworlds[:, :3, 3]
314 | scene_center = np.mean(camera_locations, axis=0)
315 | dists = np.linalg.norm(camera_locations - scene_center, axis=1)
316 | self.scene_scale = np.max(dists)
317 |
318 |
319 | class Dataset:
320 | """A simple dataset class."""
321 |
322 | def __init__(
323 | self,
324 | parser: Parser,
325 | split: str = "train",
326 | patch_size: Optional[int] = None,
327 | load_depths: bool = False,
328 | ):
329 | self.parser = parser
330 | self.split = split
331 | self.patch_size = patch_size
332 | self.load_depths = load_depths
333 |
334 | indices = np.arange(len(self.parser.image_names))
335 | if self.parser.test_every == 1:
336 | image_names = sorted(_get_rel_paths(f"{self.parser.data_dir}/images"), key=lambda x: int(x.split(".")[0].split("_")[-1]))
337 | assert len(image_names) == len(self.parser.image_names)
338 | if split == "train":
339 | self.indices = [ind for ind in indices if "_train_" in image_names[ind]]
340 | else:
341 | self.indices = [ind for ind in indices if "_eval_" in image_names[ind]]
342 | elif self.parser.test_every == 0:
343 | self.indices = indices
344 | else:
345 | if split == "train":
346 | self.indices = indices[indices % self.parser.test_every == 0]
347 | else:
348 | self.indices = indices[indices % self.parser.test_every != 0]
349 |
350 | def __len__(self):
351 | return len(self.indices)
352 |
353 | def __getitem__(self, item: int) -> Dict[str, Any]:
354 | index = self.indices[item]
355 | image = imageio.imread(self.parser.image_paths[index])[..., :3]
356 | camera_id = self.parser.camera_ids[index]
357 | K = self.parser.Ks_dict[camera_id].copy() # undistorted K
358 | params = self.parser.params_dict[camera_id]
359 | camtoworlds = self.parser.camtoworlds[index]
360 | mask = self.parser.mask_dict[camera_id]
361 |
362 | if len(params) > 0:
363 | # Images are distorted. Undistort them.
364 | mapx, mapy = (
365 | self.parser.mapx_dict[camera_id],
366 | self.parser.mapy_dict[camera_id],
367 | )
368 | image = cv2.remap(image, mapx, mapy, cv2.INTER_LINEAR)
369 | x, y, w, h = self.parser.roi_undist_dict[camera_id]
370 | image = image[y : y + h, x : x + w]
371 |
372 | if self.patch_size is not None:
373 | # Random crop.
374 | h, w = image.shape[:2]
375 | x = np.random.randint(0, max(w - self.patch_size, 1))
376 | y = np.random.randint(0, max(h - self.patch_size, 1))
377 | image = image[y : y + self.patch_size, x : x + self.patch_size]
378 | K[0, 2] -= x
379 | K[1, 2] -= y
380 |
381 | data = {
382 | "K": torch.from_numpy(K).float(),
383 | "camtoworld": torch.from_numpy(camtoworlds).float(),
384 | "image": torch.from_numpy(image).float(),
385 | "image_id": item, # the index of the image in the dataset
386 | }
387 | if mask is not None:
388 | data["mask"] = torch.from_numpy(mask).bool()
389 |
390 | if self.parser.alpha_mask_paths is not None:
391 | alpha_mask = imageio.imread(self.parser.alpha_mask_paths[index], mode="L")[:,:,None] / 255.0
392 | if self.patch_size is not None:
393 | alpha_mask = alpha_mask[y : y + self.patch_size, x : x + self.patch_size]
394 | data["alpha_mask"] = torch.from_numpy(alpha_mask).float()
395 |
396 | if self.load_depths:
397 | # project points to the image plane to get depths
398 | worldtocams = np.linalg.inv(camtoworlds)
399 | image_name = self.parser.image_names[index]
400 | point_indices = self.parser.point_indices[image_name]
401 | points_world = self.parser.points[point_indices]
402 | points_cam = (worldtocams[:3, :3] @ points_world.T + worldtocams[:3, 3:4]).T
403 | points_proj = (K @ points_cam.T).T
404 | points = points_proj[:, :2] / points_proj[:, 2:3] # (M, 2)
405 | depths = points_cam[:, 2] # (M,)
406 | # filter out points outside the image
407 | selector = (
408 | (points[:, 0] >= 0)
409 | & (points[:, 0] < image.shape[1])
410 | & (points[:, 1] >= 0)
411 | & (points[:, 1] < image.shape[0])
412 | & (depths > 0)
413 | )
414 | points = points[selector]
415 | depths = depths[selector]
416 | data["points"] = torch.from_numpy(points).float()
417 | data["depths"] = torch.from_numpy(depths).float()
418 |
419 | return data
420 |
421 |
422 | if __name__ == "__main__":
423 | import argparse
424 |
425 | import imageio.v2 as imageio
426 | import tqdm
427 |
428 | parser = argparse.ArgumentParser()
429 | parser.add_argument("--data_dir", type=str, default="data/360_v2/garden")
430 | parser.add_argument("--factor", type=int, default=4)
431 | args = parser.parse_args()
432 |
433 | # Parse COLMAP data.
434 | parser = Parser(
435 | data_dir=args.data_dir, factor=args.factor, normalize=True, test_every=8
436 | )
437 | dataset = Dataset(parser, split="train", load_depths=True)
438 | print(f"Dataset: {len(dataset)} images.")
439 |
440 | writer = imageio.get_writer("results/points.mp4", fps=30)
441 | for data in tqdm.tqdm(dataset, desc="Plotting points"):
442 | image = data["image"].numpy().astype(np.uint8)
443 | points = data["points"].numpy()
444 | depths = data["depths"].numpy()
445 | for x, y in points:
446 | cv2.circle(image, (int(x), int(y)), 2, (255, 0, 0), -1)
447 | writer.append_data(image)
448 | writer.close()
449 |
--------------------------------------------------------------------------------
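One detail of `Parser.__init__` worth calling out: when the images on disk are at a different resolution than the one registered in COLMAP (e.g. the 2x-upsampled Tanks and Temples case mentioned in the comment), the first two rows of each intrinsic matrix are rescaled by the width and height ratios. The following standalone sketch reproduces that rescaling with made-up numbers.

```python
# Standalone sketch of the intrinsics rescaling performed in Parser.__init__.
# The resolutions and focal length below are illustrative, not from a real scene.
import numpy as np

K = np.array([[1000.0,    0.0, 480.0],
              [   0.0, 1000.0, 270.0],
              [   0.0,    0.0,   1.0]])   # intrinsics as stored by COLMAP (960x540)
colmap_w, colmap_h = 960, 540
actual_w, actual_h = 1920, 1080           # resolution of the images actually loaded

s_w, s_h = actual_w / colmap_w, actual_h / colmap_h
K[0, :] *= s_w                            # fx and cx scale with image width
K[1, :] *= s_h                            # fy and cy scale with image height
print(K)                                  # now consistent with the 1920x1080 images
```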
/examples/gsplat/lib_bilagrid.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Yuehao Wang (https://github.com/yuehaowang). This part of the code is borrowed from ["Bilateral Guided Radiance Field Processing"](https://bilarfpro.github.io/).
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | This is a standalone PyTorch implementation of 3D bilateral grid and CP-decomposed 4D bilateral grid.
17 | To use this module, you can download the "lib_bilagrid.py" file and simply put it in your project directory.
18 |
19 | For the details, please check our research project: ["Bilateral Guided Radiance Field Processing"](https://bilarfpro.github.io/).
20 |
21 | #### Dependencies
22 |
23 | In addition to PyTorch and Numpy, please install [tensorly](https://github.com/tensorly/tensorly).
24 | We have tested this module on Python 3.9.18, PyTorch 2.0.1 (CUDA 11), tensorly 0.8.1, and Numpy 1.25.2.
25 |
26 | #### Overview
27 |
28 | - For bilateral guided training, you need to construct a `BilateralGrid` instance, which can hold multiple bilateral grids
29 | for input views. Then, use the `slice` function to obtain the transformed RGB output and the corresponding affine transformations.
30 |
31 | - For bilateral guided finishing, you need to instantiate a `BilateralGridCP4D` object and use `slice4d`.
32 |
33 | #### Examples
34 |
35 | - Bilateral grid for approximating ISP:
36 |
37 |
38 |
39 | - Low-rank 4D bilateral grid for MR enhancement:
40 |
41 |
42 |
43 |
44 | Below is the API reference.
45 |
46 | """
47 |
48 | import tensorly as tl
49 | import torch
50 | import torch.nn.functional as F
51 | from torch import nn
52 |
53 | tl.set_backend("pytorch")
54 |
55 |
56 | def color_correct(
57 | img: torch.Tensor, ref: torch.Tensor, num_iters: int = 5, eps: float = 0.5 / 255
58 | ) -> torch.Tensor:
59 | """
60 | Warp `img` to match the colors in `ref` using iterative color matching.
61 |
62 | This function performs color correction by warping the colors of the input image
63 | to match those of a reference image. It uses a least squares method to find a
64 | transformation that maps the input image's colors to the reference image's colors.
65 |
66 | The algorithm iteratively solves a system of linear equations, updating the set of
67 | unsaturated pixels in each iteration. This approach helps handle non-linear color
68 | transformations and reduces the impact of clipping.
69 |
70 | Args:
71 | img (torch.Tensor): Input image to be color corrected. Shape: [..., num_channels]
72 | ref (torch.Tensor): Reference image to match colors. Shape: [..., num_channels]
73 | num_iters (int, optional): Number of iterations for the color matching process.
74 | Default is 5.
75 | eps (float, optional): Small value to determine the range of unclipped pixels.
76 | Default is 0.5 / 255.
77 |
78 | Returns:
79 | torch.Tensor: Color corrected image with the same shape as the input image.
80 |
81 | Note:
82 | - Both input and reference images should be in the range [0, 1].
83 | - The function works with any number of channels, but typically used with 3 (RGB).
84 | """
85 | if img.shape[-1] != ref.shape[-1]:
86 | raise ValueError(
87 | f"img's {img.shape[-1]} and ref's {ref.shape[-1]} channels must match"
88 | )
89 | num_channels = img.shape[-1]
90 | img_mat = img.reshape([-1, num_channels])
91 | ref_mat = ref.reshape([-1, num_channels])
92 |
93 | def is_unclipped(z):
94 | return (z >= eps) & (z <= 1 - eps) # z \in [eps, 1-eps].
95 |
96 | mask0 = is_unclipped(img_mat)
97 | # Because the set of saturated pixels may change after solving for a
98 | # transformation, we repeatedly solve a system `num_iters` times and update
99 | # our estimate of which pixels are saturated.
100 | for _ in range(num_iters):
101 | # Construct the left hand side of a linear system that contains a quadratic
102 | # expansion of each pixel of `img`.
103 | a_mat = []
104 | for c in range(num_channels):
105 | a_mat.append(img_mat[:, c : (c + 1)] * img_mat[:, c:]) # Quadratic term.
106 | a_mat.append(img_mat) # Linear term.
107 | a_mat.append(torch.ones_like(img_mat[:, :1])) # Bias term.
108 | a_mat = torch.cat(a_mat, dim=-1)
109 | warp = []
110 | for c in range(num_channels):
111 | # Construct the right hand side of a linear system containing each color
112 | # of `ref`.
113 | b = ref_mat[:, c]
114 | # Ignore rows of the linear system that were saturated in the input or are
115 | # saturated in the current corrected color estimate.
116 | mask = mask0[:, c] & is_unclipped(img_mat[:, c]) & is_unclipped(b)
117 | ma_mat = torch.where(mask[:, None], a_mat, torch.zeros_like(a_mat))
118 | mb = torch.where(mask, b, torch.zeros_like(b))
119 | w = torch.linalg.lstsq(ma_mat, mb, rcond=-1)[0]
120 | assert torch.all(torch.isfinite(w))
121 | warp.append(w)
122 | warp = torch.stack(warp, dim=-1)
123 | # Apply the warp to update img_mat.
124 | img_mat = torch.clip(torch.matmul(a_mat, warp), 0, 1)
125 | corrected_img = torch.reshape(img_mat, img.shape)
126 | return corrected_img
127 |
128 |
129 | def bilateral_grid_tv_loss(model, config):
130 | """Computes total variations of bilateral grids."""
131 | total_loss = 0.0
132 |
133 | for bil_grids in model.bil_grids:
134 | total_loss += config.bilgrid_tv_loss_mult * total_variation_loss(
135 | bil_grids.grids
136 | )
137 |
138 | return total_loss
139 |
140 |
141 | def color_affine_transform(affine_mats, rgb):
142 | """Applies color affine transformations.
143 |
144 | Args:
145 | affine_mats (torch.Tensor): Affine transformation matrices. Supported shape: $(..., 3, 4)$.
146 | rgb (torch.Tensor): Input RGB values. Supported shape: $(..., 3)$.
147 |
148 | Returns:
149 | Output transformed colors of shape $(..., 3)$.
150 | """
151 | return (
152 | torch.matmul(affine_mats[..., :3], rgb.unsqueeze(-1)).squeeze(-1)
153 | + affine_mats[..., 3]
154 | )
155 |
156 |
157 | def _num_tensor_elems(t):
158 | return max(torch.prod(torch.tensor(t.size()[1:]).float()).item(), 1.0)
159 |
160 |
161 | def total_variation_loss(x): # noqa: F811
162 | """Returns total variation on multi-dimensional tensors.
163 |
164 | Args:
165 | x (torch.Tensor): The input tensor with shape $(B, C, ...)$, where $B$ is the batch size and $C$ is the channel size.
166 | """
167 | batch_size = x.shape[0]
168 | tv = 0
169 | for i in range(2, len(x.shape)):
170 | n_res = x.shape[i]
171 | idx1 = torch.arange(1, n_res, device=x.device)
172 | idx2 = torch.arange(0, n_res - 1, device=x.device)
173 | x1 = x.index_select(i, idx1)
174 | x2 = x.index_select(i, idx2)
175 | count = _num_tensor_elems(x1)
176 | tv += torch.pow((x1 - x2), 2).sum() / count
177 | return tv / batch_size
178 |
179 |
180 | def slice(bil_grids, xy, rgb, grid_idx):
181 | """Slices a batch of 3D bilateral grids by pixel coordinates `xy` and gray-scale guidances of pixel colors `rgb`.
182 |
183 | Supports 2-D, 3-D, and 4-D input shapes. The first dimension of the input is the batch size
184 | and the last dimension is 2 for `xy`, 3 for `rgb`, and 1 for `grid_idx`.
185 |
186 | The return value is a dictionary containing the affine transformations `affine_mats` sliced from bilateral grids and
187 | the output color `rgb` after applying the affine transformations.
188 |
189 | In the 2-D input case, `xy` is a $(N, 2)$ tensor, `rgb` is a $(N, 3)$ tensor, and `grid_idx` is a $(N, 1)$ tensor.
190 | Then `affine_mats[i]` can be obtained via slicing the bilateral grid indexed at `grid_idx[i]` by `xy[i, :]` and `rgb2gray(rgb[i, :])`.
191 | For 3-D and 4-D input cases, the behavior of indexing bilateral grids and coordinates is the same as in the 2-D case.
192 |
193 | .. note::
194 | This function can be regarded as a wrapper of `color_affine_transform` and `BilateralGrid` with a slight performance improvement.
195 | When `grid_idx` contains a unique index, only a single bilateral grid will be used during the slicing. In this case, this function will not
196 | perform tensor indexing to avoid data copy and extra memory
197 | (see [this](https://discuss.pytorch.org/t/does-indexing-a-tensor-return-a-copy-of-it/164905)).
198 |
199 | Args:
200 | bil_grids (`BilateralGrid`): An instance of $N$ bilateral grids.
201 | xy (torch.Tensor): The x-y coordinates of shape $(..., 2)$ in the range of $[0,1]$.
202 | rgb (torch.Tensor): The RGB values of shape $(..., 3)$ for computing the guidance coordinates, ranging in $[0,1]$.
203 | grid_idx (torch.Tensor): The indices of bilateral grids for each slicing. Shape: $(..., 1)$.
204 |
205 | Returns:
206 | A dictionary with keys and values as follows:
207 | ```
208 | {
209 | "rgb": Transformed RGB colors. Shape: (..., 3),
210 | "rgb_affine_mats": The sliced affine transformation matrices from bilateral grids. Shape: (..., 3, 4)
211 | }
212 | ```
213 | """
214 |
215 | sh_ = rgb.shape
216 |
217 | grid_idx_unique = torch.unique(grid_idx)
218 | if len(grid_idx_unique) == 1:
219 | # All pixels are from a single view.
220 | grid_idx = grid_idx_unique # (1,)
221 | xy = xy.unsqueeze(0) # (1, ..., 2)
222 | rgb = rgb.unsqueeze(0) # (1, ..., 3)
223 | else:
224 | # Pixels are randomly sampled from different views.
225 | if len(grid_idx.shape) == 4:
226 | grid_idx = grid_idx[:, 0, 0, 0] # (chunk_size,)
227 | elif len(grid_idx.shape) == 3:
228 | grid_idx = grid_idx[:, 0, 0] # (chunk_size,)
229 | elif len(grid_idx.shape) == 2:
230 | grid_idx = grid_idx[:, 0] # (chunk_size,)
231 | else:
232 | raise ValueError(
233 | "The input to bilateral grid slicing is not supported yet."
234 | )
235 |
236 | affine_mats = bil_grids(xy, rgb, grid_idx)
237 | rgb = color_affine_transform(affine_mats, rgb)
238 |
239 | return {
240 | "rgb": rgb.reshape(*sh_),
241 | "rgb_affine_mats": affine_mats.reshape(
242 | *sh_[:-1], affine_mats.shape[-2], affine_mats.shape[-1]
243 | ),
244 | }
245 |
246 |
247 | class BilateralGrid(nn.Module):
248 | """Class for 3D bilateral grids.
249 |
250 | Holds one or more bilateral grids.
251 | """
252 |
253 | def __init__(self, num, grid_X=16, grid_Y=16, grid_W=8):
254 | """
255 | Args:
256 | num (int): The number of bilateral grids (i.e., the number of views).
257 | grid_X (int): Defines grid width $W$.
258 | grid_Y (int): Defines grid height $H$.
259 | grid_W (int): Defines grid guidance dimension $L$.
260 | """
261 | super(BilateralGrid, self).__init__()
262 |
263 | self.grid_width = grid_X
264 | """Grid width. Type: int."""
265 | self.grid_height = grid_Y
266 | """Grid height. Type: int."""
267 | self.grid_guidance = grid_W
268 | """Grid guidance dimension. Type: int."""
269 |
270 | # Initialize grids.
271 | grid = self._init_identity_grid()
272 | self.grids = nn.Parameter(grid.tile(num, 1, 1, 1, 1)) # (N, 12, L, H, W)
273 | """ A 5-D tensor of shape $(N, 12, L, H, W)$."""
274 |
275 | # Weights of BT601 RGB-to-gray.
276 | self.register_buffer("rgb2gray_weight", torch.Tensor([[0.299, 0.587, 0.114]]))
277 | self.rgb2gray = lambda rgb: (rgb @ self.rgb2gray_weight.T) * 2.0 - 1.0
278 | """ A function that converts RGB to gray-scale guidance in $[-1, 1]$."""
279 |
280 | def _init_identity_grid(self):
281 | grid = torch.tensor(
282 | [
283 | 1.0,
284 | 0,
285 | 0,
286 | 0,
287 | 0,
288 | 1.0,
289 | 0,
290 | 0,
291 | 0,
292 | 0,
293 | 1.0,
294 | 0,
295 | ]
296 | ).float()
297 | grid = grid.repeat(
298 | [self.grid_guidance * self.grid_height * self.grid_width, 1]
299 | ) # (L * H * W, 12)
300 | grid = grid.reshape(
301 | 1, self.grid_guidance, self.grid_height, self.grid_width, -1
302 | ) # (1, L, H, W, 12)
303 | grid = grid.permute(0, 4, 1, 2, 3) # (1, 12, L, H, W)
304 | return grid
305 |
306 | def tv_loss(self):
307 | """Computes and returns total variation loss on the bilateral grids."""
308 | return total_variation_loss(self.grids)
309 |
310 | def forward(self, grid_xy, rgb, idx=None):
311 | """Bilateral grid slicing. Supports 2-D, 3-D, 4-D, and 5-D input.
312 | For the 2-D, 3-D, and 4-D cases, please refer to `slice`.
313 | For the 5-D cases, `idx` will be unused and the first dimension of `xy` should be
314 | equal to the number of bilateral grids. Then this function becomes PyTorch's
315 | [`F.grid_sample`](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html).
316 |
317 | Args:
318 | grid_xy (torch.Tensor): The x-y coordinates in the range of $[0,1]$.
319 | rgb (torch.Tensor): The RGB values in the range of $[0,1]$.
320 | idx (torch.Tensor): The bilateral grid indices.
321 |
322 | Returns:
323 | Sliced affine matrices of shape $(..., 3, 4)$.
324 | """
325 |
326 | grids = self.grids
327 | input_ndims = len(grid_xy.shape)
328 | assert len(rgb.shape) == input_ndims
329 |
330 | if input_ndims > 1 and input_ndims < 5:
331 | # Convert input into 5D
332 | for i in range(5 - input_ndims):
333 | grid_xy = grid_xy.unsqueeze(1)
334 | rgb = rgb.unsqueeze(1)
335 | assert idx is not None
336 | elif input_ndims != 5:
337 | raise ValueError(
338 | "Bilateral grid slicing only takes either 2D, 3D, 4D and 5D inputs"
339 | )
340 |
341 | grids = self.grids
342 | if idx is not None:
343 | grids = grids[idx]
344 | assert grids.shape[0] == grid_xy.shape[0]
345 |
346 | # Generate slicing coordinates.
347 | grid_xy = (grid_xy - 0.5) * 2 # Rescale to [-1, 1].
348 | grid_z = self.rgb2gray(rgb)
349 |
350 | # print(grid_xy.shape, grid_z.shape)
351 | # exit()
352 | grid_xyz = torch.cat([grid_xy, grid_z], dim=-1) # (N, m, h, w, 3)
353 |
354 | affine_mats = F.grid_sample(
355 | grids, grid_xyz, mode="bilinear", align_corners=True, padding_mode="border"
356 | ) # (N, 12, m, h, w)
357 | affine_mats = affine_mats.permute(0, 2, 3, 4, 1) # (N, m, h, w, 12)
358 | affine_mats = affine_mats.reshape(
359 | *affine_mats.shape[:-1], 3, 4
360 | ) # (N, m, h, w, 3, 4)
361 |
362 | for _ in range(5 - input_ndims):
363 | affine_mats = affine_mats.squeeze(1)
364 |
365 | return affine_mats
366 |
367 |
368 | def slice4d(bil_grid4d, xyz, rgb):
369 | """Slices a 4D bilateral grid by point coordinates `xyz` and gray-scale guidances of radiance colors `rgb`.
370 |
371 | Args:
372 | bil_grid4d (`BilateralGridCP4D`): The input 4D bilateral grid.
373 | xyz (torch.Tensor): The xyz coordinates with shape $(..., 3)$.
374 | rgb (torch.Tensor): The RGB values with shape $(..., 3)$.
375 |
376 | Returns:
377 | A dictionary with keys and values as follows:
378 | ```
379 | {
380 | "rgb": Transformed radiance RGB colors. Shape: (..., 3),
381 | "rgb_affine_mats": The sliced affine transformation matrices from the 4D bilateral grid. Shape: (..., 3, 4)
382 | }
383 | ```
384 | """
385 |
386 | affine_mats = bil_grid4d(xyz, rgb)
387 | rgb = color_affine_transform(affine_mats, rgb)
388 |
389 | return {"rgb": rgb, "rgb_affine_mats": affine_mats}
390 |
391 |
392 | class _ScaledTanh(nn.Module):
393 | def __init__(self, s=2.0):
394 | super().__init__()
395 | self.scaler = s
396 |
397 | def forward(self, x):
398 | return torch.tanh(self.scaler * x)
399 |
400 |
401 | class BilateralGridCP4D(nn.Module):
402 | """Class for low-rank 4D bilateral grids."""
403 |
404 | def __init__(
405 | self,
406 | grid_X=16,
407 | grid_Y=16,
408 | grid_Z=16,
409 | grid_W=8,
410 | rank=5,
411 | learn_gray=True,
412 | gray_mlp_width=8,
413 | gray_mlp_depth=2,
414 | init_noise_scale=1e-6,
415 | bound=2.0,
416 | ):
417 | """
418 | Args:
419 | grid_X (int): Defines grid width.
420 | grid_Y (int): Defines grid height.
421 | grid_Z (int): Defines grid depth.
422 | grid_W (int): Defines grid guidance dimension.
423 | rank (int): Rank of the 4D bilateral grid.
424 | learn_gray (bool): If True, an MLP will be learned to convert RGB colors to gray-scale guidances.
425 | gray_mlp_width (int): The MLP width for learnable guidance.
426 | gray_mlp_depth (int): The number of MLP layers for learnable guidance.
427 | init_noise_scale (float): The noise scale of the initialized factors.
428 | bound (float): The bound of the xyz coordinates.
429 | """
430 | super(BilateralGridCP4D, self).__init__()
431 |
432 | self.grid_X = grid_X
433 | """Grid width. Type: int."""
434 | self.grid_Y = grid_Y
435 | """Grid height. Type: int."""
436 | self.grid_Z = grid_Z
437 | """Grid depth. Type: int."""
438 | self.grid_W = grid_W
439 | """Grid guidance dimension. Type: int."""
440 | self.rank = rank
441 | """Rank of the 4D bilateral grid. Type: int."""
442 | self.learn_gray = learn_gray
443 | """Flags of learnable guidance is used. Type: bool."""
444 | self.gray_mlp_width = gray_mlp_width
445 | """The MLP width for learnable guidance. Type: int."""
446 | self.gray_mlp_depth = gray_mlp_depth
447 | """The MLP depth for learnable guidance. Type: int."""
448 | self.init_noise_scale = init_noise_scale
449 | """The noise scale of the initialized factors. Type: float."""
450 | self.bound = bound
451 | """The bound of the xyz coordinates. Type: float."""
452 |
453 | self._init_cp_factors_parafac()
454 |
455 | self.rgb2gray = None
456 | """ A function that converts RGB to gray-scale guidances in $[-1, 1]$.
457 | If `learn_gray` is True, this will be an MLP network."""
458 |
459 | if self.learn_gray:
460 |
461 | def rgb2gray_mlp_linear(layer):
462 | return nn.Linear(
463 | self.gray_mlp_width,
464 | self.gray_mlp_width if layer < self.gray_mlp_depth - 1 else 1,
465 | )
466 |
467 | def rgb2gray_mlp_actfn(_):
468 | return nn.ReLU(inplace=True)
469 |
470 | self.rgb2gray = nn.Sequential(
471 | *(
472 | [nn.Linear(3, self.gray_mlp_width)]
473 | + [
474 | nn_module(layer)
475 | for layer in range(1, self.gray_mlp_depth)
476 | for nn_module in [rgb2gray_mlp_actfn, rgb2gray_mlp_linear]
477 | ]
478 | + [_ScaledTanh(2.0)]
479 | )
480 | )
481 | else:
482 | # Weights of BT601/BT470 RGB-to-gray.
483 | self.register_buffer(
484 | "rgb2gray_weight", torch.Tensor([[0.299, 0.587, 0.114]])
485 | )
486 | self.rgb2gray = lambda rgb: (rgb @ self.rgb2gray_weight.T) * 2.0 - 1.0
487 |
488 | def _init_identity_grid(self):
489 | grid = torch.tensor(
490 | [
491 | 1.0,
492 | 0,
493 | 0,
494 | 0,
495 | 0,
496 | 1.0,
497 | 0,
498 | 0,
499 | 0,
500 | 0,
501 | 1.0,
502 | 0,
503 | ]
504 | ).float()
505 | grid = grid.repeat([self.grid_W * self.grid_Z * self.grid_Y * self.grid_X, 1])
506 | grid = grid.reshape(self.grid_W, self.grid_Z, self.grid_Y, self.grid_X, -1)
507 | grid = grid.permute(4, 0, 1, 2, 3) # (12, grid_W, grid_Z, grid_Y, grid_X)
508 | return grid
509 |
510 | def _init_cp_factors_parafac(self):
511 | # Initialize identity grids.
512 | init_grids = self._init_identity_grid()
513 | # Random noises are added to avoid singularity.
514 | init_grids = torch.randn_like(init_grids) * self.init_noise_scale + init_grids
515 | from tensorly.decomposition import parafac
516 |
517 | # Initialize grid CP factors
518 | _, facs = parafac(init_grids.clone().detach(), rank=self.rank)
519 |
520 | self.num_facs = len(facs)
521 |
522 | self.fac_0 = nn.Linear(facs[0].shape[0], facs[0].shape[1], bias=False)
523 | self.fac_0.weight = nn.Parameter(facs[0]) # (12, rank)
524 |
525 | for i in range(1, self.num_facs):
526 | fac = facs[i].T # (rank, grid_size)
527 | fac = fac.view(1, fac.shape[0], fac.shape[1], 1) # (1, rank, grid_size, 1)
528 | self.register_buffer(f"fac_{i}_init", fac)
529 |
530 | fac_resid = torch.zeros_like(fac)
531 | self.register_parameter(f"fac_{i}", nn.Parameter(fac_resid))
532 |
533 | def tv_loss(self):
534 | """Computes and returns total variation loss on the factors of the low-rank 4D bilateral grids."""
535 |
536 | total_loss = 0
537 | for i in range(1, self.num_facs):
538 | fac = self.get_parameter(f"fac_{i}")
539 | total_loss += total_variation_loss(fac)
540 |
541 | return total_loss
542 |
543 | def forward(self, xyz, rgb):
544 | """Low-rank 4D bilateral grid slicing.
545 |
546 | Args:
547 | xyz (torch.Tensor): The xyz coordinates with shape $(..., 3)$.
548 | rgb (torch.Tensor): The corresponding RGB values with shape $(..., 3)$.
549 |
550 | Returns:
551 | Sliced affine matrices with shape $(..., 3, 4)$.
552 | """
553 | sh_ = xyz.shape
554 | xyz = xyz.reshape(-1, 3) # flatten (N, 3)
555 | rgb = rgb.reshape(-1, 3) # flatten (N, 3)
556 |
557 | xyz = xyz / self.bound
558 | assert self.rgb2gray is not None
559 | gray = self.rgb2gray(rgb)
560 | xyzw = torch.cat([xyz, gray], dim=-1) # (N, 4)
561 | xyzw = xyzw.transpose(0, 1) # (4, N)
562 | coords = torch.stack([torch.zeros_like(xyzw), xyzw], dim=-1) # (4, N, 2)
563 | coords = coords.unsqueeze(1) # (4, 1, N, 2)
564 |
565 | coef = 1.0
566 | for i in range(1, self.num_facs):
567 | fac = self.get_parameter(f"fac_{i}") + self.get_buffer(f"fac_{i}_init")
568 | coef = coef * F.grid_sample(
569 | fac, coords[[i - 1]], align_corners=True, padding_mode="border"
570 | ) # [1, rank, 1, N]
571 | coef = coef.squeeze([0, 2]).transpose(0, 1) # (N, rank) #type: ignore
572 | mat = self.fac_0(coef)
573 | return mat.reshape(*sh_[:-1], 3, 4)
574 |
--------------------------------------------------------------------------------
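For context, here is a minimal usage sketch of the bilateral-guided training path described in the `lib_bilagrid.py` module docstring: build a `BilateralGrid` holding one grid per input view, then call `slice` to obtain per-pixel affine color transforms and the corrected colors. The tensor sizes are dummies, and the snippet assumes `lib_bilagrid.py` is on the import path.

```python
# Dummy-data sketch of slicing per-view 3D bilateral grids with lib_bilagrid.
import torch
from lib_bilagrid import BilateralGrid, slice

num_views = 2
bil_grids = BilateralGrid(num=num_views)           # learnable grids of shape (N, 12, L, H, W)

xy = torch.rand(256, 2)                            # pixel coordinates in [0, 1]
rgb = torch.rand(256, 3)                           # rendered colors in [0, 1]
grid_idx = torch.randint(0, num_views, (256, 1))   # which view each pixel belongs to

out = slice(bil_grids, xy, rgb, grid_idx)
print(out["rgb"].shape)              # torch.Size([256, 3])    color-corrected output
print(out["rgb_affine_mats"].shape)  # torch.Size([256, 3, 4]) sliced affine transforms
print(bil_grids.tv_loss())           # total-variation regularizer added during training
```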