├── compression ├── __init__.py ├── npz.py ├── jpeg_xl.py ├── png.py ├── exr.py ├── codec.py ├── decompress.py └── compression_exp.py ├── assets ├── worse.png ├── better.png ├── logo_mpi.png ├── logo_uca.png ├── select.png ├── teaser.png ├── logo_inria.png └── logo_graphdeco.png ├── .gitignore ├── results ├── MipNeRF360 │ ├── room.csv │ ├── bicycle.csv │ ├── bonsai.csv │ ├── counter.csv │ ├── flowers.csv │ ├── garden.csv │ ├── kitchen.csv │ ├── stump.csv │ └── treehill.csv ├── SyntheticNeRF │ ├── lego.csv │ ├── mic.csv │ ├── ship.csv │ ├── chair.csv │ ├── drums.csv │ ├── ficus.csv │ ├── hotdog.csv │ └── materials.csv ├── DeepBlending │ ├── drjohnson.csv │ └── playroom.csv └── TanksAndTemples │ ├── train.csv │ └── truck.csv ├── .gitmodules ├── utils ├── image_utils.py ├── system_utils.py ├── wandb_utils.py ├── graphics_utils.py ├── loss_utils.py ├── camera_utils.py ├── general_utils.py ├── sh_utils.py └── quaternion.py ├── lpipsPyTorch ├── __init__.py └── modules │ ├── utils.py │ ├── lpips.py │ └── networks.py ├── environment.yml ├── .github └── workflows │ └── static.yml ├── .vscode └── launch.json ├── config ├── compression │ ├── umbrella_sh.yaml │ ├── ablation_compression.yaml │ └── umbrella.yaml └── ours_q_sh_local_test.yaml ├── scene ├── cameras.py ├── __init__.py ├── dataset_readers.py └── colmap_loader.py ├── gaussian_renderer ├── network_gui.py └── __init__.py ├── render.py ├── eval ├── collect_eval_per_scene.py ├── copy_compressed.sh └── download_eval.sh ├── full_eval.py ├── metrics.py ├── LICENSE.md ├── arguments └── __init__.py ├── convert.py ├── README.md ├── training_viewer.py └── train.py /compression/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/worse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fraunhoferhhi/Self-Organizing-Gaussians/HEAD/assets/worse.png -------------------------------------------------------------------------------- /assets/better.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fraunhoferhhi/Self-Organizing-Gaussians/HEAD/assets/better.png -------------------------------------------------------------------------------- /assets/logo_mpi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fraunhoferhhi/Self-Organizing-Gaussians/HEAD/assets/logo_mpi.png -------------------------------------------------------------------------------- /assets/logo_uca.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fraunhoferhhi/Self-Organizing-Gaussians/HEAD/assets/logo_uca.png -------------------------------------------------------------------------------- /assets/select.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fraunhoferhhi/Self-Organizing-Gaussians/HEAD/assets/select.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fraunhoferhhi/Self-Organizing-Gaussians/HEAD/assets/teaser.png -------------------------------------------------------------------------------- /assets/logo_inria.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/fraunhoferhhi/Self-Organizing-Gaussians/HEAD/assets/logo_inria.png -------------------------------------------------------------------------------- /assets/logo_graphdeco.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fraunhoferhhi/Self-Organizing-Gaussians/HEAD/assets/logo_graphdeco.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .vscode 3 | output 4 | build 5 | diff_rasterization/diff_rast.egg-info 6 | diff_rasterization/dist 7 | tensorboard_3d 8 | screenshots -------------------------------------------------------------------------------- /results/MipNeRF360/room.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,31.028823852539062,0.9156585931777954,0.2051142305135727,14196402,910116 3 | w/o SH,30.898900985717773,0.9156976342201233,0.20806583762168884,6292781,883600 4 | -------------------------------------------------------------------------------- /results/SyntheticNeRF/lego.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,34.9589729309082,0.9770644903182985,0.0206077620387077,4521377,203401 3 | w/o SH,34.037994384765625,0.9760122895240784,0.021831033751368523,2043294,234256 4 | -------------------------------------------------------------------------------- /results/SyntheticNeRF/mic.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,35.85942840576172,0.9898638129234314,0.0100605441257357,2226599,80089 3 | w/o SH,33.5480842590332,0.9840942621231079,0.017452767118811607,1424409,112896 4 | -------------------------------------------------------------------------------- /results/SyntheticNeRF/ship.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,31.1312313079834,0.8997442126274109,0.1125090420246124,4995871,189225 3 | w/o SH,29.34518051147461,0.8842877149581909,0.1340227574110031,2139112,177241 4 | -------------------------------------------------------------------------------- /results/DeepBlending/drjohnson.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,28.80058479309082,0.8882283568382263,0.2753235101699829,18640330,919681 3 | w/o SH,28.52065658569336,0.8836575746536255,0.27594634890556335,5936749,811801 4 | -------------------------------------------------------------------------------- /results/DeepBlending/playroom.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,29.71374893188477,0.9003745317459106,0.2604320049285888,16833263,861184 3 | w/o SH,29.719018936157227,0.9001690149307251,0.26417237520217896,5487891,788544 4 | -------------------------------------------------------------------------------- /results/MipNeRF360/bicycle.csv: -------------------------------------------------------------------------------- 1 | 
Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,24.54261779785156,0.7260880470275879,0.251568853855133,53202029,2893401 3 | w/o SH,23.95030975341797,0.7128373980522156,0.26602473855018616,22232720,2845969 4 | -------------------------------------------------------------------------------- /results/MipNeRF360/bonsai.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,31.541339874267575,0.936642289161682,0.1814960539340973,23780541,1375929 3 | w/o SH,30.518634796142578,0.9333961009979248,0.19095730781555176,9873224,1329409 4 | -------------------------------------------------------------------------------- /results/MipNeRF360/counter.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,28.460908889770508,0.8959587216377258,0.1929932087659835,17619220,960400 3 | w/o SH,27.549983978271484,0.8878458738327026,0.2110556662082672,7008789,900601 4 | -------------------------------------------------------------------------------- /results/MipNeRF360/flowers.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,21.50467109680176,0.5899333953857422,0.3525771498680115,54974812,2852721 3 | w/o SH,21.159927368164062,0.5714252591133118,0.3665423095226288,22229645,2819041 4 | -------------------------------------------------------------------------------- /results/MipNeRF360/garden.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,27.030656814575195,0.8397635817527771,0.1349593847990036,83494624,4343056 3 | w/o SH,26.5897159576416,0.8323599696159363,0.14341415464878082,33345182,4359744 4 | -------------------------------------------------------------------------------- /results/MipNeRF360/kitchen.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,30.71577453613281,0.9115329384803772,0.1387937813997268,23259372,1210000 3 | w/o SH,29.809154510498047,0.9013781547546387,0.152778759598732,9006151,1181569 4 | -------------------------------------------------------------------------------- /results/MipNeRF360/stump.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,26.26760673522949,0.7532017230987549,0.2484848499298095,50487921,2812329 3 | w/o SH,25.97733497619629,0.7454085946083069,0.26035696268081665,22452017,2822400 4 | -------------------------------------------------------------------------------- /results/SyntheticNeRF/chair.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,35.07296371459961,0.9855194091796876,0.0124865863472223,5067361,210681 3 | w/o SH,33.291229248046875,0.9806248545646667,0.016682274639606476,2412489,228484 4 | -------------------------------------------------------------------------------- /results/SyntheticNeRF/drums.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,26.093549728393555,0.9522567987442015,0.0388978160917758,5378866,194481 3 | w/o 
SH,25.18836212158203,0.9449558258056641,0.04795675724744797,2432363,221841 4 | -------------------------------------------------------------------------------- /results/SyntheticNeRF/ficus.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,35.09575653076172,0.9860581755638124,0.0126749658957123,4288195,144400 3 | w/o SH,31.930349349975586,0.9764118790626526,0.02135816588997841,1883854,152881 4 | -------------------------------------------------------------------------------- /results/SyntheticNeRF/hotdog.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,37.46590805053711,0.9837482571601868,0.0218593832105398,2933962,109561 3 | w/o SH,35.91576385498047,0.9812123775482178,0.027121955528855324,1565777,129600 4 | -------------------------------------------------------------------------------- /results/SyntheticNeRF/materials.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,30.16141128540039,0.956821620464325,0.0419293940067291,3648816,126736 3 | w/o SH,27.668258666992188,0.9418511390686035,0.05965135619044304,1922822,149769 4 | -------------------------------------------------------------------------------- /results/TanksAndTemples/train.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,21.55785179138184,0.7973787188529968,0.2225230187177658,16766776,889249 3 | w/o SH,21.229110717773438,0.7868022918701172,0.23594272136688232,6781291,857476 4 | -------------------------------------------------------------------------------- /results/TanksAndTemples/truck.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,25.56157493591309,0.8775752782821655,0.1484980285167694,28790293,1595169 3 | w/o SH,25.06743049621582,0.8687856793403625,0.1609877645969391,11805259,1557504 4 | -------------------------------------------------------------------------------- /results/MipNeRF360/treehill.csv: -------------------------------------------------------------------------------- 1 | Submethod,PSNR,SSIM,LPIPS,Size [Bytes],#Gaussians 2 | Baseline,22.619102478027344,0.6190798282623291,0.3675214648246765,41545592,2229049 3 | w/o SH,22.630258560180664,0.6150164008140564,0.37323009967803955,17722798,2205225 4 | -------------------------------------------------------------------------------- /compression/npz.py: -------------------------------------------------------------------------------- 1 | from compression.codec import Codec 2 | 3 | import numpy as np 4 | 5 | class NpzCodec(Codec): 6 | 7 | def encode_image(self, image, out_file, **kwargs): 8 | return np.savez_compressed(out_file, image, **kwargs) 9 | 10 | def decode_image(self, file_name): 11 | return np.load(file_name)["arr_0"] 12 | 13 | def file_ending(self): 14 | return "npz" -------------------------------------------------------------------------------- /compression/jpeg_xl.py: -------------------------------------------------------------------------------- 1 | from compression.codec import Codec 2 | 3 | import imagecodecs 4 | 5 | class JpegXlCodec(Codec): 6 | 7 | def encode_image(self, image, out_file, **kwargs): 8 | imagecodecs.imwrite(out_file, image, 
**kwargs) 9 | 10 | def decode_image(self, file_name): 11 | return imagecodecs.imread(file_name) 12 | 13 | def file_ending(self): 14 | return "jxl" -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/simple-knn"] 2 | path = submodules/simple-knn 3 | url = https://gitlab.inria.fr/bkerbl/simple-knn.git 4 | [submodule "submodules/diff-gaussian-rasterization"] 5 | path = submodules/diff-gaussian-rasterization 6 | url = https://github.com/graphdeco-inria/diff-gaussian-rasterization 7 | [submodule "SIBR_viewers"] 8 | path = SIBR_viewers 9 | url = https://gitlab.inria.fr/sibr/sibr_core.git 10 | [submodule "submodules/PLAS"] 11 | path = submodules/PLAS 12 | url = https://github.com/fraunhoferhhi/PLAS.git 13 | -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | 14 | def mse(img1, img2): 15 | return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 16 | 17 | def psnr(img1, img2): 18 | mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) 19 | return 20 * torch.log10(1.0 / torch.sqrt(mse)) 20 | -------------------------------------------------------------------------------- /lpipsPyTorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'alex', 9 | version: str = '0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare. 15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 17 | version (str): the version of LPIPS. Default: 0.1. 
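        Example (illustrative sketch, not part of the original docstring;
        assumes inputs already normalized to the [-1, 1] range LPIPS expects):
            >>> x = torch.rand(1, 3, 64, 64) * 2 - 1
            >>> y = torch.rand(1, 3, 64, 64) * 2 - 1
            >>> score = lpips(x, y, net_type='alex')  # tensor of shape (1, 1, 1, 1)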
18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y) 22 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: sogs 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - python=3.10 9 | 10 | - cudatoolkit=11.8 11 | 12 | - pytorch=2.4 13 | - pytorch-cuda=11.8 14 | - torchvision=0.19 15 | 16 | # for compatibility with imagecodecs 2023.9.18 17 | - numpy<2 18 | 19 | - click=8.1 20 | - hydra-core=1.3 21 | - kornia=0.7 22 | - opencv=4.10 23 | - pandas=2.2 24 | - pip=24.3 25 | - plyfile=1.1 26 | - scipy=1.14 27 | - screeninfo=0.8 28 | - tqdm=4.67 29 | 30 | - pip: 31 | - submodules/diff-gaussian-rasterization 32 | - submodules/simple-knn 33 | 34 | - submodules/PLAS 35 | 36 | # later imagecodecs version produce much larger JPEG XL files 37 | - imagecodecs[all]==2023.9.18 38 | 39 | - wandb==0.18.7 40 | -------------------------------------------------------------------------------- /.github/workflows/static.yml: -------------------------------------------------------------------------------- 1 | name: Deploy static content to Pages 2 | 3 | on: 4 | push: 5 | branches: ["project-page"] 6 | workflow_dispatch: 7 | 8 | permissions: 9 | contents: read 10 | pages: write 11 | id-token: write 12 | 13 | concurrency: 14 | group: "pages" 15 | cancel-in-progress: false 16 | 17 | jobs: 18 | deploy: 19 | environment: 20 | name: github-pages 21 | url: ${{ steps.deployment.outputs.page_url }} 22 | runs-on: ubuntu-latest 23 | steps: 24 | - name: Checkout 25 | uses: actions/checkout@v4 26 | - name: Setup Pages 27 | uses: actions/configure-pages@v5 28 | - name: Upload artifact 29 | uses: actions/upload-pages-artifact@v3 30 | with: 31 | path: './project-page/' 32 | - name: Deploy to GitHub Pages 33 | id: deployment 34 | uses: actions/deploy-pages@v4 35 | -------------------------------------------------------------------------------- /utils/system_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from errno import EEXIST 13 | from os import makedirs, path 14 | import os 15 | 16 | def mkdir_p(folder_path): 17 | # Creates a directory. 
equivalent to using mkdir -p on the command line 18 | try: 19 | makedirs(folder_path) 20 | except OSError as exc: # Python >2.5 21 | if exc.errno == EEXIST and path.isdir(folder_path): 22 | pass 23 | else: 24 | raise 25 | 26 | def searchForMaxIteration(folder): 27 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 28 | return max(saved_iters) 29 | -------------------------------------------------------------------------------- /compression/png.py: -------------------------------------------------------------------------------- 1 | from compression.codec import Codec 2 | 3 | import numpy as np 4 | import cv2 5 | 6 | # dtype: uint8, uint16 7 | 8 | class PNGCodec(Codec): 9 | 10 | def encode_image(self, image, out_file, dtype): 11 | 12 | match dtype: 13 | case "uint8": 14 | image = image * 255 15 | image = image.astype("uint8") 16 | case "uint16": 17 | image = image * 65535 18 | image = image.astype("uint16") 19 | 20 | cv2.imwrite(out_file, image) 21 | 22 | def decode_image(self, file_name): 23 | img = cv2.imread(file_name, cv2.IMREAD_UNCHANGED | cv2.IMREAD_ANYDEPTH | cv2.IMREAD_ANYCOLOR) 24 | match img.dtype: 25 | case np.uint8: 26 | img = img / 255 27 | case np.uint16: 28 | img = img / 65535 29 | return img 30 | 31 | def file_ending(self): 32 | return "png" -------------------------------------------------------------------------------- /lpipsPyTorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 
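        Example (illustrative; building the criterion once and reusing it
        avoids reconstructing the backbone and reloading weights per call):
            >>> criterion = LPIPS(net_type='vgg').to(device)
            >>> distance = criterion(x, y)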
16 | """ 17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 18 | 19 | assert version in ['0.1'], 'v0.1 is only supported now' 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | 36 | return torch.sum(torch.cat(res, 0), 0, True) 37 | -------------------------------------------------------------------------------- /utils/wandb_utils.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import os 3 | from omegaconf import DictConfig, OmegaConf 4 | 5 | from scene import GaussianModel 6 | 7 | 8 | def flatten_dict(d, parent_key='', sep='_'): 9 | items = [] 10 | for k, v in d.items(): 11 | new_key = f"{parent_key}{sep}{k}" if parent_key else k 12 | if isinstance(v, dict): 13 | items.extend(flatten_dict(v, new_key, sep=sep).items()) 14 | else: 15 | items.append((new_key, v)) 16 | return dict(items) 17 | 18 | 19 | 20 | def init_wandb(cfg): 21 | 22 | if os.path.exists("/mnt/output"): 23 | wandb_dir = "/mnt/output/wandb_out" 24 | os.makedirs(wandb_dir, exist_ok=True) 25 | else: 26 | wandb_dir = "wandb_out" 27 | 28 | config_dict = OmegaConf.to_container(cfg, resolve=True) 29 | 30 | 31 | run = wandb.init( 32 | project="ssgs", 33 | config=config_dict, 34 | dir=wandb_dir, 35 | group=cfg.run.group, 36 | name=cfg.run.name, 37 | tags=cfg.run.tags, 38 | ) 39 | 40 | return run.url 41 | 42 | 43 | def save_hist(gaussian: GaussianModel, step, num_bins=200): 44 | hist_dict = {} 45 | for attribute in ["_features_rest", "_xyz", "_features_dc", "_scaling", "_rotation", "_opacity"]: 46 | att = getattr(gaussian, attribute).flatten() 47 | hist = wandb.Histogram(att.cpu().numpy(), num_bins=num_bins) 48 | hist_dict["hist/" + attribute[1:]] = hist 49 | wandb.log(hist_dict, step=step) 50 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Train truck", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "train.py", 12 | "console": "integratedTerminal", 13 | "justMyCode": true, 14 | "args": [ 15 | "--config-name", 16 | "ours_q_sh_local_test", 17 | "hydra.run.dir=/data/output/${now:%Y-%m-%d}/${now:%H-%M-%S}-${run.name}", 18 | "dataset.source_path=/data/gaussian_splatting/tandt_db/tandt/truck", 19 | "run.no_progress_bar=false", 20 | // "local_window_debug_view.enabled=true", 21 | // "run.save_iterations=[1200]", 22 | // "optimization.iterations=1200", 23 | "run.name=vs-code-debug", 24 | // "dataset.sh_degree=3", 25 | // "sorting.shuffle=true", 26 | // "run.compress_iterations=[10,20]", 27 | // "run.test_iterations=[10,20]", 28 | // "run.save_iterations=[15,20]", 29 | // "optimization.densification_interval=10", 30 | // "optimization.iterations=20", 31 | "run.test_lpips=false", 32 | ] 33 | }, 34 | ] 35 | } -------------------------------------------------------------------------------- /config/compression/umbrella_sh.yaml: -------------------------------------------------------------------------------- 1 | experiments: 2 | - name: "exr_jxl_quant_5_norm" 3 | attributes: 4 | - name: "_xyz" 5 | method: "exr" 6 | normalize: false 7 | contract: false 8 | quantize: 13 9 | params: 10 | compression: "zip" 11 | - name: "_features_dc" 12 | method: "jpeg-xl" 13 | normalize: true 14 | params: 15 | level: 90 16 | - name: "_features_rest" 17 | method: "jpeg-xl" 18 | normalize: true 19 | quantize: 5 20 | params: 21 | level: 101 22 | - name: "_scaling" 23 | method: "exr" 24 | normalize: false 25 | contract: false 26 | quantize: 6 27 | params: 28 | compression: "none" 29 | - name: "_rotation" 30 | method: "exr" 31 | normalize: false 32 | quantize: 6 33 | params: 34 | compression: "zip" 35 | - name: "_opacity" 36 | method: "exr" 37 | normalize: true 38 | contract: false 39 | quantize: 5 40 | params: 41 | compression: "none" 42 | 43 | - name: "jxl_quant_sh" 44 | attributes: 45 | - name: "_xyz" 46 | method: "jpeg-xl" 47 | normalize: true 48 | quantize: 14 49 | params: 50 | # compression: "zip" 51 | level: 101 52 | - name: "_features_dc" 53 | method: "jpeg-xl" 54 | normalize: true 55 | params: 56 | level: 100 57 | - name: "_features_rest" 58 | method: "jpeg-xl" 59 | normalize: true 60 | quantize: 5 61 | params: 62 | level: 101 63 | - name: "_scaling" 64 | method: "jpeg-xl" 65 | normalize: false 66 | contract: false 67 | quantize: 6 68 | params: 69 | level: 101 70 | - name: "_rotation" 71 | method: "jpeg-xl" 72 | normalize: true 73 | quantize: 6 74 | params: 75 | level: 101 76 | - name: "_opacity" 77 | method: "jpeg-xl" 78 | normalize: true 79 | contract: false 80 | quantize: 6 81 | params: 82 | level: 101 83 | -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | import numpy as np 15 | from typing import NamedTuple 16 | 17 | class BasicPointCloud(NamedTuple): 18 | points : np.array 19 | colors : np.array 20 | normals : np.array 21 | 22 | def geom_transform_points(points, transf_matrix): 23 | P, _ = points.shape 24 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 25 | points_hom = torch.cat([points, ones], dim=1) 26 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 27 | 28 | denom = points_out[..., 3:] + 0.0000001 29 | return (points_out[..., :3] / denom).squeeze(dim=0) 30 | 31 | def getWorld2View(R, t): 32 | Rt = np.zeros((4, 4)) 33 | Rt[:3, :3] = R.transpose() 34 | Rt[:3, 3] = t 35 | Rt[3, 3] = 1.0 36 | return np.float32(Rt) 37 | 38 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 39 | Rt = np.zeros((4, 4)) 40 | Rt[:3, :3] = R.transpose() 41 | Rt[:3, 3] = t 42 | Rt[3, 3] = 1.0 43 | 44 | C2W = np.linalg.inv(Rt) 45 | cam_center = C2W[:3, 3] 46 | cam_center = (cam_center + translate) * scale 47 | C2W[:3, 3] = cam_center 48 | Rt = np.linalg.inv(C2W) 49 | return np.float32(Rt) 50 | 51 | def getProjectionMatrix(znear, zfar, fovX, fovY): 52 | tanHalfFovY = math.tan((fovY / 2)) 53 | tanHalfFovX = math.tan((fovX / 2)) 54 | 55 | top = tanHalfFovY * znear 56 | bottom = -top 57 | right = tanHalfFovX * znear 58 | left = -right 59 | 60 | P = torch.zeros(4, 4) 61 | 62 | z_sign = 1.0 63 | 64 | P[0, 0] = 2.0 * znear / (right - left) 65 | P[1, 1] = 2.0 * znear / (top - bottom) 66 | P[0, 2] = (right + left) / (right - left) 67 | P[1, 2] = (top + bottom) / (top - bottom) 68 | P[3, 2] = z_sign 69 | P[2, 2] = z_sign * zfar / (zfar - znear) 70 | P[2, 3] = -(zfar * znear) / (zfar - znear) 71 | return P 72 | 73 | def fov2focal(fov, pixels): 74 | return pixels / (2 * math.tan(fov / 2)) 75 | 76 | def focal2fov(focal, pixels): 77 | return 2*math.atan(pixels/(2*focal)) -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
 8 | #
 9 | # For inquiries contact george.drettakis@inria.fr
10 | #
11 | 
12 | import torch
13 | import torch.nn.functional as F
14 | from torch.autograd import Variable
15 | from math import exp
16 | 
17 | def l1_loss(network_output, gt):
18 |     return torch.abs((network_output - gt)).mean()
19 | 
20 | def l2_loss(network_output, gt):
21 |     return ((network_output - gt) ** 2).mean()
22 | 
23 | def gaussian(window_size, sigma):
24 |     gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
25 |     return gauss / gauss.sum()
26 | 
27 | def create_window(window_size, channel):
28 |     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
29 |     _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
30 |     window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
31 |     return window
32 | 
33 | def ssim(img1, img2, window_size=11, size_average=True):
34 |     channel = img1.size(-3)
35 |     window = create_window(window_size, channel)
36 | 
37 |     if img1.is_cuda:
38 |         window = window.cuda(img1.get_device())
39 |     window = window.type_as(img1)
40 | 
41 |     return _ssim(img1, img2, window, window_size, channel, size_average)
42 | 
43 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
44 |     mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
45 |     mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
46 | 
47 |     mu1_sq = mu1.pow(2)
48 |     mu2_sq = mu2.pow(2)
49 |     mu1_mu2 = mu1 * mu2
50 | 
51 |     sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
52 |     sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
53 |     sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
54 | 
55 |     C1 = 0.01 ** 2
56 |     C2 = 0.03 ** 2
57 | 
58 |     ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
59 | 
60 |     if size_average:
61 |         return ssim_map.mean()
62 |     else:
63 |         return ssim_map.mean(1).mean(1).mean(1)
64 | 
65 | 
--------------------------------------------------------------------------------
/compression/exr.py:
--------------------------------------------------------------------------------
 1 | from compression.codec import Codec
 2 | 
 3 | import os
 4 | os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
 5 | 
 6 | import cv2
 7 | 
 8 | 
 9 | # parameters:
10 | # type: ["half", "float"]
11 | # compression: ["none", "rle", "zips", "zip", "piz", "pxr24", "b44", "b44a", "dwaa", "dwab"] (legacy spellings "zps"/"b4a" still accepted)
12 | 
13 | class EXRCodec(Codec):
14 | 
15 |     def encode_image(self, image, out_file, type="half", compression="none"):
16 | 
17 |         imwrite_flags = []
18 | 
19 |         match type:
20 |             case "half":
21 |                 imwrite_flags.extend([cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
22 |             case "float":
23 |                 imwrite_flags.extend([cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
24 |             case _:
25 |                 raise NotImplementedError(f"Unknown type: {type}")
26 | 
27 |         match compression:
28 |             case "rle":
29 |                 imwrite_flags.extend([cv2.IMWRITE_EXR_COMPRESSION, cv2.IMWRITE_EXR_COMPRESSION_RLE])
30 |             case "zips" | "zps":
31 |                 imwrite_flags.extend([cv2.IMWRITE_EXR_COMPRESSION, cv2.IMWRITE_EXR_COMPRESSION_ZIPS])
32 |             case "zip":
33 |                 imwrite_flags.extend([cv2.IMWRITE_EXR_COMPRESSION, cv2.IMWRITE_EXR_COMPRESSION_ZIP])
34 |             case "piz":
35 |                 imwrite_flags.extend([cv2.IMWRITE_EXR_COMPRESSION, cv2.IMWRITE_EXR_COMPRESSION_PIZ])
36 |             case "pxr24":
37 |                 imwrite_flags.extend([cv2.IMWRITE_EXR_COMPRESSION, cv2.IMWRITE_EXR_COMPRESSION_PXR24])
38 |             case "b44":
39 |                 imwrite_flags.extend([cv2.IMWRITE_EXR_COMPRESSION, cv2.IMWRITE_EXR_COMPRESSION_B44])
40 |             case "b44a" | "b4a":
41 |                 imwrite_flags.extend([cv2.IMWRITE_EXR_COMPRESSION, cv2.IMWRITE_EXR_COMPRESSION_B44A])
42 |             case "dwaa":
43 |                 imwrite_flags.extend([cv2.IMWRITE_EXR_COMPRESSION, cv2.IMWRITE_EXR_COMPRESSION_DWAA])
44 |             case "dwab":
45 |                 imwrite_flags.extend([cv2.IMWRITE_EXR_COMPRESSION, cv2.IMWRITE_EXR_COMPRESSION_DWAB])
46 |             case "none":
47 |                 pass
48 |             case _:
49 |                 raise NotImplementedError(f"Unknown compression method: {compression}")
50 | 
51 |         cv2.imwrite(out_file, image, imwrite_flags)
52 | 
53 |     def decode_image(self, file_name):
54 |         return cv2.imread(file_name, cv2.IMREAD_UNCHANGED | cv2.IMREAD_ANYDEPTH | cv2.IMREAD_ANYCOLOR)
55 | 
56 |     def file_ending(self):
57 |         return "exr"
--------------------------------------------------------------------------------
/compression/codec.py:
--------------------------------------------------------------------------------
 1 | from abc import ABC
 2 | 
 3 | def normalize_img(img, min_val, max_val):
 4 | 
 5 |     # min_clipped_count = (img < min_val).sum()
 6 |     # max_clipped_count = (img > max_val).sum()
 7 |     # print(f"Clipped {(min_clipped_count + max_clipped_count) / img.size * 100}% of values")
 8 | 
 9 |     img = img.clip(min_val, max_val)
10 |     img = (img - min_val) / (max_val - min_val)
11 |     return img
12 | 
13 | # from print_ranges.py
14 | min_thresholds = {
15 |     "_features_dc": -2,
16 |     "_features_rest": -1,
17 |     "_scaling": -13,
18 |     "_rotation": -1,  # TODO better range
19 |     "_opacity": -6,
20 | }
21 | 
22 | max_thresholds = {
23 |     "_features_dc": 4,
24 |     "_features_rest": 1,
25 | 
26 |     # from print_ranges.py
27 |     # "_scaling": -1,
28 | 
29 |     # manually overridden, because clipping large Gaussians down to smaller scales severely degrades results
30 |     "_scaling": 3,
31 |     "_rotation": 2,
32 |     "_opacity": 12,
33 | }
34 | 
35 | class Codec(ABC):
36 | 
37 |     def encode_image(self, image, out_file, **kwargs):
38 |         raise NotImplementedError("Subclasses should implement this!")
39 | 
40 |     def decode_image(self, file_name):
41 |         raise NotImplementedError("Subclasses should implement this!")
42 | 
43 |     def file_ending(self):
44 |         raise NotImplementedError("Subclasses should implement this!")
45 | 
46 |     def normalize_to_thresholds(self, img, attr_name):
47 | 
48 |         # normalize coordinates to 0...1
49 |         if attr_name == "_xyz":
50 |             xyz_min = img.min()
51 |             xyz_max = img.max()
52 |             return normalize_img(img, xyz_min, xyz_max), xyz_min, xyz_max
53 | 
54 |         min_val = min_thresholds[attr_name]
55 |         max_val = max_thresholds[attr_name]
56 | 
57 |         return normalize_img(img, min_val, max_val), min_val, max_val
58 | 
59 |     def read_file_bytes(self, file_path):
60 |         with open(file_path, "rb") as f:
61 |             return f.read()
62 | 
63 |     def write_file_bytes(self, file_path, bytes):
64 |         with open(file_path, "wb") as f:
65 |             f.write(bytes)
66 | 
67 |     def encode(self, image, out_file, **kwargs):
68 |         self.encode_image(image, out_file, **kwargs)
69 | 
70 |     def decode(self, image):
71 |         return self.decode_image(image)
72 | 
73 |     def encode_with_normalization(self, image, attr_name, out_file, **kwargs):
74 |         img_norm, min_val, max_val = self.normalize_to_thresholds(image, attr_name)
75 |         self.encode(img_norm, out_file, **kwargs)
76 |         return min_val, max_val
77 | 
78 |     def decode_with_normalization(self, file_name, min_val, max_val):
79 |         img_norm = self.decode(file_name)
80 |         return img_norm * (max_val - min_val) + min_val
81 | 
82 | 
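A minimal round-trip sketch of the Codec interface above (illustrative only, not a file
from this repo: the attribute grid, file name, and tolerance are invented, and the actual
per-attribute pipeline is presumably driven from the compression YAML configs by
compression_exp.py). It pushes a fake "_opacity" grid through the lossless NpzCodec from
compression/npz.py, so after de-normalization the values should match the clipped input
up to float rounding:

    import numpy as np
    from compression.npz import NpzCodec

    codec = NpzCodec()

    # invented 64x64x1 grid of raw "_opacity" values, standing in for
    # Gaussians that have been sorted into a 2D layout
    grid = np.random.uniform(-6.0, 12.0, size=(64, 64, 1)).astype(np.float32)

    out_file = f"opacity.{codec.file_ending()}"  # -> "opacity.npz"

    # clips to the "_opacity" thresholds (-6, 12), rescales to [0, 1], writes the
    # file, and returns the (min, max) needed to undo the scaling on decode
    min_val, max_val = codec.encode_with_normalization(grid, "_opacity", out_file)

    restored = codec.decode_with_normalization(out_file, min_val, max_val)
    assert np.allclose(restored, grid.clip(min_val, max_val), atol=1e-4)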
--------------------------------------------------------------------------------
/compression/decompress.py:
--------------------------------------------------------------------------------
 1 | import shutil
 2 | import sys
 3 | import os
 4 | import json
 5 | 
 6 | 
 7 | from argparse import ArgumentParser
 8 | 
 9 | from compression.compression_exp import run_single_decompression
10 | 
11 | 
12 | def get_size_of_files_in_dir(directory_path):
13 |     """Returns the total size of all files under `directory_path`,
14 |     including files in any subdirectories."""
15 | 
16 |     total_size = 0
17 |     for dirpath, dirnames, filenames in os.walk(directory_path):
18 |         for filename in filenames:
19 |             file_path = os.path.join(dirpath, filename)
20 |             total_size += os.path.getsize(file_path)
21 |     return total_size
22 | 
23 | 
24 | def decompress_single_to_ply(compressed_model_path):
25 | 
26 |     metrics_dict = {"Size [Bytes]": get_size_of_files_in_dir(compressed_model_path)}
27 | 
28 |     decompressed_gaussians = run_single_decompression(compressed_model_path)
29 |     decompressed_model_path = os.path.join(compressed_model_path, "decompressed_model")
30 |     ply_path = os.path.join(decompressed_model_path, "point_cloud", "iteration_1", "point_cloud.ply")
31 | 
32 |     os.makedirs(decompressed_model_path, exist_ok=True)
33 | 
34 |     decompressed_gaussians.save_ply(ply_path)
35 | 
36 |     num_gaussians = decompressed_gaussians.get_xyz.shape[0]
37 |     metrics_dict["#Gaussians"] = num_gaussians
38 | 
39 |     # copy cfg_args file from parent/parent folder to decompressed_model_path
40 |     model_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.normpath(compressed_model_path))))
41 |     for file_name in ["cfg_args", "cameras.json"]:
42 |         shutil.copyfile(os.path.join(model_dir, file_name), os.path.join(decompressed_model_path, file_name))
43 | 
44 |     with open(compressed_model_path + "/stats.json", "w") as fp:
45 |         json.dump(metrics_dict, fp, indent=True)
46 | 
47 | 
48 | def decompress_all_to_ply(compressions_dir):
49 | 
50 |     for compressed_dir in os.listdir(compressions_dir):
51 |         if not os.path.isdir(os.path.join(compressions_dir, compressed_dir)):
52 |             continue
53 |         decompress_single_to_ply(os.path.join(compressions_dir, compressed_dir))
54 | 
55 | 
56 | def decompress():
57 |     # example args: --compressed_model output/2023-11-14/14-01-13-blur-15/5/compression/iteration_30000/jxl_man
58 | 
59 |     # Set up command line argument parser
60 |     parser = ArgumentParser(description="Decompression script parameters")
61 |     parser.add_argument("--compressed_model_path", type=str)
62 |     args_cmdline = parser.parse_args(sys.argv[1:])
63 | 
64 |     # like output/2023-11-14/14-01-13-blur-15/5/compression/iteration_30000/jxl_man
65 |     compressed_model_path = args_cmdline.compressed_model_path
66 | 
67 |     decompress_single_to_ply(compressed_model_path)
68 | 
69 | 
70 | if __name__ == "__main__":
71 |     decompress()
72 | 
--------------------------------------------------------------------------------
/scene/cameras.py:
--------------------------------------------------------------------------------
 1 | #
 2 | # Copyright (C) 2023, Inria
 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco
 4 | # All rights reserved.
 5 | #
 6 | # This software is free for non-commercial, research and evaluation use
 7 | # under the terms of the LICENSE.md file.
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | from utils.graphics_utils import getWorld2View2, getProjectionMatrix 16 | 17 | class Camera(nn.Module): 18 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, 19 | image_name, uid, 20 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda" 21 | ): 22 | super(Camera, self).__init__() 23 | 24 | self.uid = uid 25 | self.colmap_id = colmap_id 26 | self.R = R 27 | self.T = T 28 | self.FoVx = FoVx 29 | self.FoVy = FoVy 30 | self.image_name = image_name 31 | 32 | try: 33 | self.data_device = torch.device(data_device) 34 | except Exception as e: 35 | print(e) 36 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) 37 | self.data_device = torch.device("cuda") 38 | 39 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device) 40 | self.image_width = self.original_image.shape[2] 41 | self.image_height = self.original_image.shape[1] 42 | 43 | if gt_alpha_mask is not None: 44 | self.original_image *= gt_alpha_mask.to(self.data_device) 45 | else: 46 | self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) 47 | 48 | self.zfar = 100.0 49 | self.znear = 0.01 50 | 51 | self.trans = trans 52 | self.scale = scale 53 | 54 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() 55 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda() 56 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 57 | self.camera_center = self.world_view_transform.inverse()[3, :3] 58 | 59 | class MiniCam: 60 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): 61 | self.image_width = width 62 | self.image_height = height 63 | self.FoVy = fovy 64 | self.FoVx = fovx 65 | self.znear = znear 66 | self.zfar = zfar 67 | self.world_view_transform = world_view_transform 68 | self.full_proj_transform = full_proj_transform 69 | view_inv = torch.inverse(self.world_view_transform) 70 | self.camera_center = view_inv[3][:3] 71 | 72 | -------------------------------------------------------------------------------- /gaussian_renderer/network_gui.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import traceback 14 | import socket 15 | import json 16 | from scene.cameras import MiniCam 17 | 18 | host = "127.0.0.1" 19 | port = 6009 20 | 21 | conn = None 22 | addr = None 23 | 24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 25 | 26 | def init(wish_host, wish_port): 27 | global host, port, listener 28 | host = wish_host 29 | port = wish_port 30 | listener.bind((host, port)) 31 | listener.listen() 32 | listener.settimeout(0) 33 | 34 | def try_connect(): 35 | global conn, addr, listener 36 | try: 37 | conn, addr = listener.accept() 38 | print(f"\nConnected by {addr}") 39 | conn.settimeout(None) 40 | except Exception as inst: 41 | pass 42 | 43 | def read(): 44 | global conn 45 | messageLength = conn.recv(4) 46 | messageLength = int.from_bytes(messageLength, 'little') 47 | message = conn.recv(messageLength) 48 | return json.loads(message.decode("utf-8")) 49 | 50 | def send(message_bytes, verify): 51 | global conn 52 | if message_bytes != None: 53 | conn.sendall(message_bytes) 54 | conn.sendall(len(verify).to_bytes(4, 'little')) 55 | conn.sendall(bytes(verify, 'ascii')) 56 | 57 | def receive(): 58 | message = read() 59 | 60 | width = message["resolution_x"] 61 | height = message["resolution_y"] 62 | 63 | if width != 0 and height != 0: 64 | try: 65 | do_training = bool(message["train"]) 66 | fovy = message["fov_y"] 67 | fovx = message["fov_x"] 68 | znear = message["z_near"] 69 | zfar = message["z_far"] 70 | do_shs_python = bool(message["shs_python"]) 71 | do_rot_scale_python = bool(message["rot_scale_python"]) 72 | keep_alive = bool(message["keep_alive"]) 73 | scaling_modifier = message["scaling_modifier"] 74 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() 75 | world_view_transform[:,1] = -world_view_transform[:,1] 76 | world_view_transform[:,2] = -world_view_transform[:,2] 77 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() 78 | full_proj_transform[:,1] = -full_proj_transform[:,1] 79 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) 80 | except Exception as e: 81 | print("") 82 | traceback.print_exc() 83 | raise e 84 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier 85 | else: 86 | return None, None, None, None, None, None -------------------------------------------------------------------------------- /utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from scene.cameras import Camera 13 | import numpy as np 14 | from utils.general_utils import PILtoTorch 15 | from utils.graphics_utils import fov2focal 16 | 17 | WARNED = False 18 | 19 | def loadCam(args, id, cam_info, resolution_scale): 20 | orig_w, orig_h = cam_info.image.size 21 | 22 | if args.resolution in [1, 2, 4, 8]: 23 | resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution)) 24 | else: # should be a type that converts to float 25 | if args.resolution == -1: 26 | if orig_w > 1600: 27 | global WARNED 28 | if not WARNED: 29 | print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n " 30 | "If this is not desired, please explicitly specify '--resolution/-r' as 1") 31 | WARNED = True 32 | global_down = orig_w / 1600 33 | else: 34 | global_down = 1 35 | else: 36 | global_down = orig_w / args.resolution 37 | 38 | scale = float(global_down) * float(resolution_scale) 39 | resolution = (int(orig_w / scale), int(orig_h / scale)) 40 | 41 | resized_image_rgb = PILtoTorch(cam_info.image, resolution) 42 | 43 | gt_image = resized_image_rgb[:3, ...] 44 | loaded_mask = None 45 | 46 | if resized_image_rgb.shape[1] == 4: 47 | loaded_mask = resized_image_rgb[3:4, ...] 48 | 49 | return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 50 | FoVx=cam_info.FovX, FoVy=cam_info.FovY, 51 | image=gt_image, gt_alpha_mask=loaded_mask, 52 | image_name=cam_info.image_name, uid=id, data_device=args.data_device) 53 | 54 | def cameraList_from_camInfos(cam_infos, resolution_scale, args): 55 | camera_list = [] 56 | 57 | for id, c in enumerate(cam_infos): 58 | camera_list.append(loadCam(args, id, c, resolution_scale)) 59 | 60 | return camera_list 61 | 62 | def camera_to_JSON(id, camera : Camera): 63 | Rt = np.zeros((4, 4)) 64 | Rt[:3, :3] = camera.R.transpose() 65 | Rt[:3, 3] = camera.T 66 | Rt[3, 3] = 1.0 67 | 68 | W2C = np.linalg.inv(Rt) 69 | pos = W2C[:3, 3] 70 | rot = W2C[:3, :3] 71 | serializable_array_2d = [x.tolist() for x in rot] 72 | camera_entry = { 73 | 'id' : id, 74 | 'img_name' : camera.image_name, 75 | 'width' : camera.width, 76 | 'height' : camera.height, 77 | 'position': pos.tolist(), 78 | 'rotation': serializable_array_2d, 79 | 'fy' : fov2focal(camera.FovY, camera.height), 80 | 'fx' : fov2focal(camera.FovX, camera.width) 81 | } 82 | return camera_entry 83 | -------------------------------------------------------------------------------- /lpipsPyTorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def 
__init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) 97 | -------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | from scene import Scene 14 | import os 15 | from tqdm import tqdm 16 | from os import makedirs 17 | from gaussian_renderer import render 18 | import torchvision 19 | from utils.general_utils import safe_state 20 | from argparse import ArgumentParser 21 | from arguments import ModelParams, PipelineParams, get_combined_args 22 | from gaussian_renderer import GaussianModel 23 | 24 | def render_set(model_path, name, iteration, views, gaussians, pipeline, background): 25 | render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") 26 | gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") 27 | 28 | makedirs(render_path, exist_ok=True) 29 | makedirs(gts_path, exist_ok=True) 30 | 31 | for idx, view in enumerate(tqdm(views, desc="Rendering progress")): 32 | rendering = render(view, gaussians, pipeline, background)["render"] 33 | gt = view.original_image[0:3, :, :] 34 | torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) 35 | torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) 36 | 37 | def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, disable_xyz_log_activation): 38 | with torch.no_grad(): 39 | gaussians = GaussianModel(dataset.sh_degree, disable_xyz_log_activation=disable_xyz_log_activation) 40 | scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) 41 | 42 | bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] 43 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 44 | 45 | if not skip_train: 46 | render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background) 47 | 48 | if not skip_test: 49 | render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background) 50 | 51 | if __name__ == "__main__": 52 | # Set up command line argument parser 53 | parser = ArgumentParser(description="Testing script parameters") 54 | model = ModelParams(parser, sentinel=True) 55 | pipeline = PipelineParams(parser) 56 | parser.add_argument("--iteration", default=-1, type=int) 57 | parser.add_argument("--skip_train", action="store_true") 58 | parser.add_argument("--skip_test", action="store_true") 59 | parser.add_argument("--disable_xyz_log_activation", action="store_true") 60 | parser.add_argument("--quiet", action="store_true") 61 | args = get_combined_args(parser) 62 | print("Rendering " + args.model_path) 63 | 64 | # Initialize system state (RNG) 65 | safe_state(args.quiet) 66 | 67 | render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, args.disable_xyz_log_activation) -------------------------------------------------------------------------------- /eval/collect_eval_per_scene.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import click 4 | import pandas as pd 5 | from pathlib import Path 6 | 7 | @click.command() 8 | @click.option('--output-dir', required=True, type=click.Path(), help="The output directory where results will be stored.") 9 | @click.option('--dataset', required=True, type=str, help="The dataset name.") 10 | @click.option('--scene', required=True, type=str, help="The scene name.") 11 | @click.option('--model-path', required=True, 
type=click.Path(exists=True), help="The path to the model directory.") 12 | @click.option('--submethod', default="", type=str, help="Submethod name, if applicable.") 13 | def process_data(output_dir, dataset, scene, model_path, submethod): 14 | results_dir = Path(output_dir) / 'results' 15 | results_dir.mkdir(parents=True, exist_ok=True) 16 | 17 | dataset_dir = results_dir / dataset 18 | dataset_dir.mkdir(parents=True, exist_ok=True) 19 | 20 | # stats.json created from decompress.py, with #Gaussians and Size [Bytes] 21 | stats_path = Path(model_path) / 'stats.json' 22 | with open(stats_path, 'r') as f: 23 | stats = json.load(f) 24 | 25 | # results.json created in decompressed model from running render.py and metrics.py 26 | results_path = Path(model_path) / 'decompressed_model/results.json' 27 | with open(results_path, 'r') as f: 28 | results = json.load(f) 29 | 30 | metrics = {**stats, **results.get("ours_1", {})} 31 | 32 | # Check if required values are present in the metrics dictionary 33 | required_keys = ['PSNR', 'SSIM', 'LPIPS', 'Size [Bytes]', '#Gaussians'] 34 | for key in required_keys: 35 | if key not in metrics: 36 | raise ValueError(f"Missing required metric: {key}") 37 | 38 | # Scene CSV path within the results directory 39 | scene_csv_path = dataset_dir / f'{scene}.csv' 40 | 41 | # Define table structure 42 | columns = ['Submethod', 'PSNR', 'SSIM', 'LPIPS', 'Size [Bytes]', '#Gaussians'] 43 | 44 | # Extract row values from metrics 45 | row_values = [ 46 | metrics['PSNR'], 47 | metrics['SSIM'], 48 | metrics['LPIPS'], 49 | metrics['Size [Bytes]'], 50 | metrics['#Gaussians'] 51 | ] 52 | 53 | # If CSV exists, load it; otherwise, create a new DataFrame 54 | if scene_csv_path.exists(): 55 | df = pd.read_csv(scene_csv_path) 56 | updated = False 57 | 58 | # Check if submethod exists, and update if necessary 59 | if submethod in df['Submethod'].values: 60 | df.loc[df['Submethod'] == submethod, columns[1:]] = row_values 61 | updated = True 62 | else: 63 | # Append new row if submethod does not exist 64 | new_row = [submethod] + row_values 65 | df = pd.concat([df, pd.DataFrame([new_row], columns=columns)], ignore_index=True) 66 | 67 | if updated: 68 | action = "updated" 69 | else: 70 | action = "added" 71 | else: 72 | # Create a new DataFrame with the metrics 73 | initial_data = [[submethod] + row_values] 74 | df = pd.DataFrame(initial_data, columns=columns) 75 | 76 | action = "created" 77 | 78 | # Write the updated DataFrame back to the CSV file 79 | df.to_csv(scene_csv_path, index=False) 80 | 81 | print(f"The CSV file at {scene_csv_path} was {action}.") 82 | 83 | if __name__ == '__main__': 84 | process_data() 85 | -------------------------------------------------------------------------------- /config/ours_q_sh_local_test.yaml: -------------------------------------------------------------------------------- 1 | # Ours Q 2 | defaults: 3 | - _self_ 4 | - compression: umbrella_sh 5 | 6 | hydra: 7 | run: 8 | # output directory for results 9 | dir: ./output/${now:%Y-%m-%d}/${now:%H-%M-%S}-${run.name} 10 | 11 | run: 12 | # wandb config: group, run name, tags 13 | group: "rank_sum_choices" 14 | name: "run" 15 | tags: "" 16 | 17 | # don't set manually, will be replaced at runtime 18 | wandb_url: null 19 | 20 | test_iterations: 21 | [1000, 2000, 5000, 7000, 10000, 12000, 15000, 17000, 20000, 25000, 30000] 22 | save_iterations: [7000, 10000, 20000, 30000] 23 | checkpoint_iterations: [] 24 | start_checkpoint: null 25 | 26 | compress_iterations: [7000, 10000, 20000, 30000] 27 | 28 | # use spherical 
harmonics in optimization 29 | use_sh: true 30 | 31 | # throws away any command-line output 32 | quiet: false 33 | 34 | # hide progress bar 35 | no_progress_bar: true 36 | 37 | log_nb_loss_interval: 100 38 | 39 | log_training_report_interval: 500 40 | 41 | # pretty slow to compute, only turn on when needed 42 | test_lpips: false 43 | 44 | debug: 45 | debug_from: -1 46 | detect_anomaly: false 47 | 48 | dataset: 49 | sh_degree: 3 50 | source_path: "" 51 | model_path: "" 52 | images: "images" 53 | resolution: -1 54 | white_background: false 55 | data_device: "cuda" 56 | eval: true 57 | 58 | optimization: 59 | iterations: 30000 60 | 61 | position_lr_init: 0.00016 62 | position_lr_final: 0.0000016 63 | position_lr_delay_mult: 0.01 64 | position_lr_max_steps: 30000 65 | 66 | feature_lr: 0.0025 67 | opacity_lr: 0.05 68 | scaling_lr: 0.005 69 | rotation_lr: 0.001 70 | 71 | percent_dense: 0.1 72 | 73 | lambda_dssim: 0.2 74 | 75 | # 3DGS default: 100 76 | densification_interval: 1000 77 | densify_from_iter: 500 78 | densify_until_iter: 15000 79 | 80 | # 3DGS default: 0.0002 81 | densify_grad_threshold: 0.00007 82 | 83 | # 3DGS default: 1000 84 | opacity_reset_interval: 10000000 85 | 86 | # 3DGS default: 0.005 87 | densify_min_opacity: 0.1 88 | 89 | random_background: false 90 | 91 | neighbor_loss: 92 | # set to 0 to disable neighbor loss 93 | lambda_neighbor: 1.0 94 | 95 | normalize: false 96 | activated: false 97 | 98 | # "mse" or "huber" 99 | loss_fn: "huber" 100 | 101 | blur: 102 | kernel_size: 5 103 | sigma: 3.0 104 | 105 | weights: 106 | xyz: 0.0 107 | features_dc: 0.0 108 | features_rest: 0.0 109 | 110 | # not used: let individual Gaussians die out if they wish 111 | opacity: 1.0 112 | scaling: 0.0 113 | rotation: 10.0 114 | 115 | sorting: 116 | enabled: true 117 | normalize: true 118 | activated: true 119 | shuffle: true 120 | 121 | improvement_break: 0.0001 122 | 123 | weights: 124 | xyz: 1.0 125 | features_dc: 1.0 126 | features_rest: 0.0 127 | opacity: 0.0 128 | scaling: 1.0 129 | rotation: 0.0 130 | 131 | pipeline: 132 | convert_SHs_python: false 133 | compute_cov3D_python: false 134 | debug: false 135 | 136 | gui_server: 137 | ip: "127.0.0.1" 138 | port: 6009 139 | 140 | wandb_debug_view: 141 | view_enabled: false 142 | save_hist: false 143 | 144 | view_id: 100 145 | interval: 500 146 | 147 | local_window_debug_view: 148 | enabled: false 149 | interval: 10 150 | view_id: 100 151 | -------------------------------------------------------------------------------- /full_eval.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | from argparse import ArgumentParser 14 | 15 | mipnerf360_outdoor_scenes = ["bicycle", "flowers", "garden", "stump", "treehill"] 16 | mipnerf360_indoor_scenes = ["room", "counter", "kitchen", "bonsai"] 17 | tanks_and_temples_scenes = ["truck", "train"] 18 | deep_blending_scenes = ["drjohnson", "playroom"] 19 | 20 | parser = ArgumentParser(description="Full evaluation script parameters") 21 | parser.add_argument("--skip_training", action="store_true") 22 | parser.add_argument("--skip_rendering", action="store_true") 23 | parser.add_argument("--skip_metrics", action="store_true") 24 | parser.add_argument("--output_path", default="./eval") 25 | args, _ = parser.parse_known_args() 26 | 27 | all_scenes = [] 28 | all_scenes.extend(mipnerf360_outdoor_scenes) 29 | all_scenes.extend(mipnerf360_indoor_scenes) 30 | all_scenes.extend(tanks_and_temples_scenes) 31 | all_scenes.extend(deep_blending_scenes) 32 | 33 | if not args.skip_training or not args.skip_rendering: 34 | parser.add_argument('--mipnerf360', "-m360", required=True, type=str) 35 | parser.add_argument("--tanksandtemples", "-tat", required=True, type=str) 36 | parser.add_argument("--deepblending", "-db", required=True, type=str) 37 | args = parser.parse_args() 38 | 39 | if not args.skip_training: 40 | common_args = " --quiet --eval --test_iterations -1 " 41 | for scene in mipnerf360_outdoor_scenes: 42 | source = args.mipnerf360 + "/" + scene 43 | os.system("python train.py -s " + source + " -i images_4 -m " + args.output_path + "/" + scene + common_args) 44 | for scene in mipnerf360_indoor_scenes: 45 | source = args.mipnerf360 + "/" + scene 46 | os.system("python train.py -s " + source + " -i images_2 -m " + args.output_path + "/" + scene + common_args) 47 | for scene in tanks_and_temples_scenes: 48 | source = args.tanksandtemples + "/" + scene 49 | os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args) 50 | for scene in deep_blending_scenes: 51 | source = args.deepblending + "/" + scene 52 | os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args) 53 | 54 | if not args.skip_rendering: 55 | all_sources = [] 56 | for scene in mipnerf360_outdoor_scenes: 57 | all_sources.append(args.mipnerf360 + "/" + scene) 58 | for scene in mipnerf360_indoor_scenes: 59 | all_sources.append(args.mipnerf360 + "/" + scene) 60 | for scene in tanks_and_temples_scenes: 61 | all_sources.append(args.tanksandtemples + "/" + scene) 62 | for scene in deep_blending_scenes: 63 | all_sources.append(args.deepblending + "/" + scene) 64 | 65 | common_args = " --quiet --eval --skip_train" 66 | for scene, source in zip(all_scenes, all_sources): 67 | os.system("python render.py --iteration 7000 -s " + source + " -m " + args.output_path + "/" + scene + common_args) 68 | os.system("python render.py --iteration 30000 -s " + source + " -m " + args.output_path + "/" + scene + common_args) 69 | 70 | if not args.skip_metrics: 71 | scenes_string = "" 72 | for scene in all_scenes: 73 | scenes_string += "\"" + args.output_path + "/" + scene + "\" " 74 | 75 | os.system("python metrics.py -m " + scenes_string) -------------------------------------------------------------------------------- /gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights 
reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 15 | from scene.gaussian_model import GaussianModel 16 | from utils.sh_utils import eval_sh 17 | 18 | def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None): 19 | """ 20 | Render the scene. 21 | 22 | Background tensor (bg_color) must be on GPU! 23 | """ 24 | 25 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means 26 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 27 | try: 28 | screenspace_points.retain_grad() 29 | except: 30 | pass 31 | 32 | # Set up rasterization configuration 33 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 34 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 35 | 36 | raster_settings = GaussianRasterizationSettings( 37 | image_height=int(viewpoint_camera.image_height), 38 | image_width=int(viewpoint_camera.image_width), 39 | tanfovx=tanfovx, 40 | tanfovy=tanfovy, 41 | bg=bg_color, 42 | scale_modifier=scaling_modifier, 43 | viewmatrix=viewpoint_camera.world_view_transform, 44 | projmatrix=viewpoint_camera.full_proj_transform, 45 | sh_degree=pc.active_sh_degree, 46 | campos=viewpoint_camera.camera_center, 47 | prefiltered=False, 48 | debug=pipe.debug 49 | ) 50 | 51 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 52 | 53 | means3D = pc.get_xyz 54 | means2D = screenspace_points 55 | opacity = pc.get_opacity 56 | 57 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 58 | # scaling / rotation by the rasterizer. 59 | scales = None 60 | rotations = None 61 | cov3D_precomp = None 62 | if pipe.compute_cov3D_python: 63 | cov3D_precomp = pc.get_covariance(scaling_modifier) 64 | else: 65 | scales = pc.get_scaling 66 | rotations = pc.get_rotation 67 | 68 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 69 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 70 | shs = None 71 | colors_precomp = None 72 | if override_color is None: 73 | if pipe.convert_SHs_python: 74 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2) 75 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)) 76 | dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) 77 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) 78 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) 79 | else: 80 | shs = pc.get_features 81 | else: 82 | colors_precomp = override_color 83 | 84 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 85 | rendered_image, radii = rasterizer( 86 | means3D = means3D, 87 | means2D = means2D, 88 | shs = shs, 89 | colors_precomp = colors_precomp, 90 | opacities = opacity, 91 | scales = scales, 92 | rotations = rotations, 93 | cov3D_precomp = cov3D_precomp) 94 | 95 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 96 | # They will be excluded from value updates used in the splitting criteria. 
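    # screenspace_points is also handed back below as "viewspace_points": after
    # loss.backward() the training loop reads its accumulated .grad, and these 2D
    # screen-space position gradients feed the densification (clone/split) criteria.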
97 | return {"render": rendered_image, 98 | "viewspace_points": screenspace_points, 99 | "visibility_filter" : radii > 0, 100 | "radii": radii} 101 | -------------------------------------------------------------------------------- /config/compression/ablation_compression.yaml: -------------------------------------------------------------------------------- 1 | experiments: 2 | - name: "NPZ" 3 | attributes: 4 | - name: "_xyz" 5 | method: "npz" 6 | - name: "_features_dc" 7 | method: "npz" 8 | - name: "_scaling" 9 | method: "npz" 10 | - name: "_rotation" 11 | method: "npz" 12 | - name: "_opacity" 13 | method: "npz" 14 | 15 | - name: "JXL ll" 16 | attributes: 17 | - name: "_xyz" 18 | method: "jpeg-xl" 19 | params: 20 | level: 101 21 | - name: "_features_dc" 22 | method: "jpeg-xl" 23 | params: 24 | level: 101 25 | - name: "_scaling" 26 | method: "jpeg-xl" 27 | params: 28 | level: 101 29 | - name: "_rotation" 30 | method: "jpeg-xl" 31 | params: 32 | level: 101 33 | - name: "_opacity" 34 | method: "jpeg-xl" 35 | params: 36 | level: 101 37 | 38 | - name: "PNG 16" 39 | attributes: 40 | - name: "_xyz" 41 | method: "png" 42 | normalize: true 43 | contract: false 44 | params: 45 | dtype: "uint16" 46 | - name: "_features_dc" 47 | method: "png" 48 | normalize: true 49 | params: 50 | dtype: "uint16" 51 | - name: "_scaling" 52 | method: "png" 53 | normalize: true 54 | params: 55 | dtype: "uint16" 56 | - name: "_rotation" 57 | method: "png" 58 | normalize: true 59 | params: 60 | dtype: "uint16" 61 | - name: "_opacity" 62 | method: "png" 63 | normalize: true 64 | params: 65 | dtype: "uint16" 66 | 67 | # EXR w/ OpenCV can only write 1, 3 or 4 channels 68 | - name: "EXR" 69 | attributes: 70 | - name: "_xyz" 71 | method: "exr" 72 | params: 73 | compression: "zip" 74 | - name: "_features_dc" 75 | method: "exr" 76 | params: 77 | compression: "zip" 78 | - name: "_scaling" 79 | method: "exr" 80 | params: 81 | compression: "zip" 82 | - name: "_rotation" 83 | method: "exr" 84 | params: 85 | compression: "zip" 86 | - name: "_opacity" 87 | method: "exr" 88 | params: 89 | compression: "zip" 90 | 91 | - name: "EXR+JXL q" 92 | attributes: 93 | - name: "_xyz" 94 | method: "exr" 95 | normalize: false 96 | contract: false 97 | quantize: 13 98 | params: 99 | compression: "zip" 100 | - name: "_features_dc" 101 | method: "jpeg-xl" 102 | normalize: true 103 | # contract: false 104 | # quantize: 8 105 | params: 106 | level: 90 107 | - name: "_scaling" 108 | method: "exr" 109 | normalize: false 110 | contract: false 111 | quantize: 6 112 | params: 113 | compression: "zip" 114 | - name: "_rotation" 115 | method: "exr" 116 | normalize: false 117 | # contract: false 118 | quantize: 6 119 | params: 120 | compression: "zip" 121 | - name: "_opacity" 122 | method: "exr" 123 | normalize: true 124 | contract: false 125 | quantize: 5 126 | params: 127 | compression: "zip" 128 | 129 | - name: "JXL" 130 | attributes: 131 | - name: "_xyz" 132 | method: "jpeg-xl" 133 | normalize: true 134 | quantize: 14 135 | params: 136 | # compression: "zip" 137 | level: 101 138 | - name: "_features_dc" 139 | method: "jpeg-xl" 140 | normalize: true 141 | params: 142 | level: 100 143 | - name: "_scaling" 144 | method: "jpeg-xl" 145 | normalize: false 146 | contract: false 147 | quantize: 6 148 | params: 149 | level: 101 150 | - name: "_rotation" 151 | method: "jpeg-xl" 152 | normalize: true 153 | quantize: 6 154 | params: 155 | level: 101 156 | - name: "_opacity" 157 | method: "jpeg-xl" 158 | normalize: true 159 | contract: false 160 | quantize: 6 161 | 
params: 162 | level: 101 -------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import random 14 | import json 15 | from utils.system_utils import searchForMaxIteration 16 | from scene.dataset_readers import sceneLoadTypeCallbacks 17 | from scene.gaussian_model import GaussianModel 18 | from arguments import ModelParams 19 | from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON 20 | 21 | class Scene: 22 | 23 | gaussians : GaussianModel 24 | 25 | def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]): 26 | """ 27 | :param args: ModelParams whose source_path points to the COLMAP scene main folder. 28 | """ 29 | self.model_path = args.model_path 30 | self.loaded_iter = None 31 | self.gaussians = gaussians 32 | 33 | if load_iteration: 34 | if load_iteration == -1: 35 | self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) 36 | else: 37 | self.loaded_iter = load_iteration 38 | print("Loading trained model at iteration {}".format(self.loaded_iter)) 39 | 40 | self.train_cameras = {} 41 | self.test_cameras = {} 42 | 43 | if os.path.exists(os.path.join(args.source_path, "sparse")): 44 | scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval) 45 | elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")): 46 | print("Found transforms_train.json file, assuming Blender data set!") 47 | scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval) 48 | else: 49 | assert False, "Could not recognize scene type!"
50 | 51 | if not self.loaded_iter: 52 | with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file: 53 | dest_file.write(src_file.read()) 54 | json_cams = [] 55 | camlist = [] 56 | if scene_info.test_cameras: 57 | camlist.extend(scene_info.test_cameras) 58 | if scene_info.train_cameras: 59 | camlist.extend(scene_info.train_cameras) 60 | for id, cam in enumerate(camlist): 61 | json_cams.append(camera_to_JSON(id, cam)) 62 | with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: 63 | json.dump(json_cams, file) 64 | 65 | if shuffle: 66 | random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling 67 | random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling 68 | 69 | self.cameras_extent = scene_info.nerf_normalization["radius"] 70 | 71 | for resolution_scale in resolution_scales: 72 | print("Loading Training Cameras") 73 | self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args) 74 | print("Loading Test Cameras") 75 | self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args) 76 | 77 | if self.loaded_iter: 78 | self.loaded_gaussian_ply = os.path.join(self.model_path, "point_cloud", 79 | "iteration_" + str(self.loaded_iter), 80 | "point_cloud.ply") 81 | self.gaussians.load_ply(self.loaded_gaussian_ply) 82 | else: 83 | self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) 84 | 85 | def save(self, iteration): 86 | point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) 87 | self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) 88 | 89 | def getTrainCameras(self, scale=1.0): 90 | return self.train_cameras[scale] 91 | 92 | def getTestCameras(self, scale=1.0): 93 | return self.test_cameras[scale] -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from pathlib import Path 13 | import os 14 | from PIL import Image 15 | import torch 16 | import torchvision.transforms.functional as tf 17 | from utils.loss_utils import ssim 18 | from lpipsPyTorch import lpips 19 | import json 20 | from tqdm import tqdm 21 | from utils.image_utils import psnr 22 | from argparse import ArgumentParser 23 | 24 | def readImages(renders_dir, gt_dir): 25 | renders = [] 26 | gts = [] 27 | image_names = [] 28 | for fname in os.listdir(renders_dir): 29 | render = Image.open(renders_dir / fname) 30 | gt = Image.open(gt_dir / fname) 31 | renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda()) 32 | gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda()) 33 | image_names.append(fname) 34 | return renders, gts, image_names 35 | 36 | def evaluate(model_paths): 37 | 38 | full_dict = {} 39 | per_view_dict = {} 40 | full_dict_polytopeonly = {} 41 | per_view_dict_polytopeonly = {} 42 | print("") 43 | 44 | for scene_dir in model_paths: 45 | try: 46 | print("Scene:", scene_dir) 47 | full_dict[scene_dir] = {} 48 | per_view_dict[scene_dir] = {} 49 | full_dict_polytopeonly[scene_dir] = {} 50 | per_view_dict_polytopeonly[scene_dir] = {} 51 | 52 | test_dir = Path(scene_dir) / "test" 53 | 54 | for method in os.listdir(test_dir): 55 | print("Method:", method) 56 | 57 | full_dict[scene_dir][method] = {} 58 | per_view_dict[scene_dir][method] = {} 59 | full_dict_polytopeonly[scene_dir][method] = {} 60 | per_view_dict_polytopeonly[scene_dir][method] = {} 61 | 62 | method_dir = test_dir / method 63 | gt_dir = method_dir/ "gt" 64 | renders_dir = method_dir / "renders" 65 | renders, gts, image_names = readImages(renders_dir, gt_dir) 66 | 67 | ssims = [] 68 | psnrs = [] 69 | lpipss = [] 70 | 71 | for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"): 72 | ssims.append(ssim(renders[idx], gts[idx])) 73 | psnrs.append(psnr(renders[idx], gts[idx])) 74 | lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg')) 75 | 76 | print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5")) 77 | print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5")) 78 | print(" LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5")) 79 | print("") 80 | 81 | full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(), 82 | "PSNR": torch.tensor(psnrs).mean().item(), 83 | "LPIPS": torch.tensor(lpipss).mean().item()}) 84 | per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)}, 85 | "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)}, 86 | "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}}) 87 | 88 | with open(scene_dir + "/results.json", 'w') as fp: 89 | json.dump(full_dict[scene_dir], fp, indent=True) 90 | with open(scene_dir + "/per_view.json", 'w') as fp: 91 | json.dump(per_view_dict[scene_dir], fp, indent=True) 92 | except: 93 | print("Unable to compute metrics for model", scene_dir) 94 | 95 | if __name__ == "__main__": 96 | device = torch.device("cuda:0") 97 | torch.cuda.set_device(device) 98 | 99 | # Set up command line argument parser 100 | parser = ArgumentParser(description="Training script parameters") 101 | parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[]) 102 | args = parser.parse_args() 103 | evaluate(args.model_paths) 104 | 
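Note: the psnr helper imported above from utils/image_utils.py is not part of this listing. A minimal sketch of the standard definition, assuming [0, 1]-ranged image tensors of shape [1, C, H, W] as prepared by readImages() (the actual implementation in image_utils.py may differ), could look like this:

import torch

def psnr(img1, img2):
    # Per-image mean squared error over all channels and pixels.
    mse = ((img1 - img2) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
    # PSNR = 20 * log10(MAX / sqrt(MSE)), with MAX = 1.0 for [0, 1] inputs.
    return 20 * torch.log10(1.0 / torch.sqrt(mse))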
-------------------------------------------------------------------------------- /utils/general_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import sys 14 | from datetime import datetime 15 | import numpy as np 16 | import random 17 | 18 | def inverse_sigmoid(x): 19 | return torch.log(x/(1-x)) 20 | 21 | def PILtoTorch(pil_image, resolution): 22 | resized_image_PIL = pil_image.resize(resolution) 23 | resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 24 | if len(resized_image.shape) == 3: 25 | return resized_image.permute(2, 0, 1) 26 | else: 27 | return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) 28 | 29 | def get_expon_lr_func( 30 | lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 31 | ): 32 | """ 33 | Copied from Plenoxels 34 | 35 | Continuous learning rate decay function. Adapted from JaxNeRF 36 | The returned rate is lr_init when step=0 and lr_final when step=max_steps, and 37 | is log-linearly interpolated elsewhere (equivalent to exponential decay). 38 | If lr_delay_steps>0 then the learning rate will be scaled by some smooth 39 | function of lr_delay_mult, such that the initial learning rate is 40 | lr_init*lr_delay_mult at the beginning of optimization but will be eased back 41 | to the normal learning rate when steps>lr_delay_steps. 42 | :param conf: config subtree 'lr' or similar 43 | :param max_steps: int, the number of steps during optimization. 44 | :return HoF which takes step as input 45 | """ 46 | 47 | def helper(step): 48 | if step < 0 or (lr_init == 0.0 and lr_final == 0.0): 49 | # Disable this parameter 50 | return 0.0 51 | if lr_delay_steps > 0: 52 | # A kind of reverse cosine decay. 
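            # delay_rate eases from lr_delay_mult at step 0 up to 1.0 at
            # step == lr_delay_steps (the sin argument ramps from 0 to pi/2);
            # beyond that, np.clip pins it at 1.0 and only the log-linear
            # interpolation below remains.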
53 | delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 54 | 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) 55 | ) 56 | else: 57 | delay_rate = 1.0 58 | t = np.clip(step / max_steps, 0, 1) 59 | log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) 60 | return delay_rate * log_lerp 61 | 62 | return helper 63 | 64 | def strip_lowerdiag(L): 65 | uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") 66 | 67 | uncertainty[:, 0] = L[:, 0, 0] 68 | uncertainty[:, 1] = L[:, 0, 1] 69 | uncertainty[:, 2] = L[:, 0, 2] 70 | uncertainty[:, 3] = L[:, 1, 1] 71 | uncertainty[:, 4] = L[:, 1, 2] 72 | uncertainty[:, 5] = L[:, 2, 2] 73 | return uncertainty 74 | 75 | def strip_symmetric(sym): 76 | return strip_lowerdiag(sym) 77 | 78 | def build_rotation(r): 79 | norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) 80 | 81 | q = r / norm[:, None] 82 | 83 | R = torch.zeros((q.size(0), 3, 3), device='cuda') 84 | 85 | r = q[:, 0] 86 | x = q[:, 1] 87 | y = q[:, 2] 88 | z = q[:, 3] 89 | 90 | R[:, 0, 0] = 1 - 2 * (y*y + z*z) 91 | R[:, 0, 1] = 2 * (x*y - r*z) 92 | R[:, 0, 2] = 2 * (x*z + r*y) 93 | R[:, 1, 0] = 2 * (x*y + r*z) 94 | R[:, 1, 1] = 1 - 2 * (x*x + z*z) 95 | R[:, 1, 2] = 2 * (y*z - r*x) 96 | R[:, 2, 0] = 2 * (x*z - r*y) 97 | R[:, 2, 1] = 2 * (y*z + r*x) 98 | R[:, 2, 2] = 1 - 2 * (x*x + y*y) 99 | return R 100 | 101 | def build_scaling_rotation(s, r): 102 | L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") 103 | R = build_rotation(r) 104 | 105 | L[:,0,0] = s[:,0] 106 | L[:,1,1] = s[:,1] 107 | L[:,2,2] = s[:,2] 108 | 109 | L = R @ L 110 | return L 111 | 112 | def safe_state(silent): 113 | old_f = sys.stdout 114 | class F: 115 | def __init__(self, silent): 116 | self.silent = silent 117 | 118 | def write(self, x): 119 | if not self.silent: 120 | if x.endswith("\n"): 121 | old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) 122 | else: 123 | old_f.write(x) 124 | 125 | def flush(self): 126 | old_f.flush() 127 | 128 | sys.stdout = F(silent) 129 | 130 | random.seed(0) 131 | np.random.seed(0) 132 | torch.manual_seed(0) 133 | torch.cuda.set_device(torch.device("cuda:0")) 134 | -------------------------------------------------------------------------------- /utils/sh_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | import torch 25 | 26 | C0 = 0.28209479177387814 27 | C1 = 0.4886025119029199 28 | C2 = [ 29 | 1.0925484305920792, 30 | -1.0925484305920792, 31 | 0.31539156525252005, 32 | -1.0925484305920792, 33 | 0.5462742152960396 34 | ] 35 | C3 = [ 36 | -0.5900435899266435, 37 | 2.890611442640554, 38 | -0.4570457994644658, 39 | 0.3731763325901154, 40 | -0.4570457994644658, 41 | 1.445305721320277, 42 | -0.5900435899266435 43 | ] 44 | C4 = [ 45 | 2.5033429417967046, 46 | -1.7701307697799304, 47 | 0.9461746957575601, 48 | -0.6690465435572892, 49 | 0.10578554691520431, 50 | -0.6690465435572892, 51 | 0.47308734787878004, 52 | -1.7701307697799304, 53 | 0.6258357354491761, 54 | ] 55 | 56 | 57 | def eval_sh(deg, sh, dirs): 58 | """ 59 | Evaluate spherical harmonics at unit directions 60 | using hardcoded SH polynomials. 61 | Works with torch/np/jnp. 62 | ... Can be 0 or more batch dimensions. 63 | Args: 64 | deg: int SH deg. Currently, 0-3 supported 65 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 66 | dirs: jnp.ndarray unit directions [..., 3] 67 | Returns: 68 | [..., C] 69 | """ 70 | assert deg <= 4 and deg >= 0 71 | coeff = (deg + 1) ** 2 72 | assert sh.shape[-1] >= coeff 73 | 74 | result = C0 * sh[..., 0] 75 | if deg > 0: 76 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 77 | result = (result - 78 | C1 * y * sh[..., 1] + 79 | C1 * z * sh[..., 2] - 80 | C1 * x * sh[..., 3]) 81 | 82 | if deg > 1: 83 | xx, yy, zz = x * x, y * y, z * z 84 | xy, yz, xz = x * y, y * z, x * z 85 | result = (result + 86 | C2[0] * xy * sh[..., 4] + 87 | C2[1] * yz * sh[..., 5] + 88 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 89 | C2[3] * xz * sh[..., 7] + 90 | C2[4] * (xx - yy) * sh[..., 8]) 91 | 92 | if deg > 2: 93 | result = (result + 94 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 95 | C3[1] * xy * z * sh[..., 10] + 96 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 97 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 98 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 99 | C3[5] * z * (xx - yy) * sh[..., 14] + 100 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 101 | 102 | if deg > 3: 103 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 104 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 105 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 106 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 107 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 108 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 109 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 110 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 111 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 112 | return result 113 | 114 | def RGB2SH(rgb): 115 | return (rgb - 0.5) / C0 116 | 117 | def SH2RGB(sh): 118 | return sh * C0 + 0.5 -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Gaussian-Splatting License 2 | =========================== 3 | 4 | **Inria** 
and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**. 5 | The *Software* is in the process of being registered with the Agence pour la Protection des 6 | Programmes (APP). 7 | 8 | The *Software* is still being developed by the *Licensor*. 9 | 10 | *Licensor*'s goal is to allow the research community to use, test and evaluate 11 | the *Software*. 12 | 13 | ## 1. Definitions 14 | 15 | *Licensee* means any person or entity that uses the *Software* and distributes 16 | its *Work*. 17 | 18 | *Licensor* means the owners of the *Software*, i.e. Inria and MPII. 19 | 20 | *Software* means the original work of authorship made available under this 21 | License, i.e. gaussian-splatting. 22 | 23 | *Work* means the *Software* and any additions to or derivative works of the 24 | *Software* that are made available under this License. 25 | 26 | 27 | ## 2. Purpose 28 | This license is intended to define the rights granted to the *Licensee* by 29 | Licensors under the *Software*. 30 | 31 | ## 3. Rights granted 32 | 33 | For the above reasons Licensors have decided to distribute the *Software*. 34 | Licensors grant non-exclusive rights to use the *Software* for research purposes 35 | to research users (both academic and industrial), free of charge, without right 36 | to sublicense. The *Software* may be used "non-commercially", i.e., for research 37 | and/or evaluation purposes only. 38 | 39 | Subject to the terms and conditions of this License, you are granted a 40 | non-exclusive, royalty-free, license to reproduce, prepare derivative works of, 41 | publicly display, publicly perform and distribute its *Work* and any resulting 42 | derivative works in any form. 43 | 44 | ## 4. Limitations 45 | 46 | **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do 47 | so under this License, (b) you include a complete copy of this License with 48 | your distribution, and (c) you retain without modification any copyright, 49 | patent, trademark, or attribution notices that are present in the *Work*. 50 | 51 | **4.2 Derivative Works.** You may specify that additional or different terms apply 52 | to the use, reproduction, and distribution of your derivative works of the *Work* 53 | ("Your Terms") only if (a) Your Terms provide that the use limitation in 54 | Section 2 applies to your derivative works, and (b) you identify the specific 55 | derivative works that are subject to Your Terms. Notwithstanding Your Terms, 56 | this License (including the redistribution requirements in Section 3.1) will 57 | continue to apply to the *Work* itself. 58 | 59 | **4.3** Any other use without prior consent of Licensors is prohibited. Research 60 | users explicitly acknowledge having received from Licensors all information 61 | allowing them to appreciate the adequacy between the *Software* and their needs and 62 | to undertake all necessary precautions for its execution and use. 63 | 64 | **4.4** The *Software* is provided both as a compiled library file and as source 65 | code. In case of using the *Software* for a publication or other results obtained 66 | through the use of the *Software*, users are strongly encouraged to cite the 67 | corresponding publications as explained in the documentation of the *Software*. 68 | 69 | ## 5. Disclaimer 70 | 71 | THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES 72 | WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS.
YOU MUST CONTACT INRIA FOR ANY 73 | UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL 74 | CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES 75 | OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL 76 | USE, PROFESSIONAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR 77 | ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE 78 | AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 79 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE 80 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) 81 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 82 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR 83 | IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*. 84 | 85 | ## 6. Files subject to permissive licenses 86 | The contents of the file ```utils/loss_utils.py``` are based on publicly available code authored by Evan Su, which falls under the permissive MIT license. 87 | 88 | Title: pytorch-ssim\ 89 | Project code: https://github.com/Po-Hsun-Su/pytorch-ssim\ 90 | Copyright Evan Su, 2017\ 91 | License: https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/LICENSE.txt (MIT) -------------------------------------------------------------------------------- /eval/copy_compressed.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Helper script to package up a set of scenes from a training run into a zip file. 4 | # 5 | # Copies the compressed files from a training run to a destination directory. 6 | # Only copies one compression (third argument), does not copy any of the training logs or decompressed files. 7 | # Sorts experiments into folders by dataset. The dataset is inferred from keywords in the run folder name 8 | # (360 -> MipNeRF360, db -> DeepBlending, blender -> SyntheticNeRF, tand -> TanksAndTemples). 9 | # The scene name is taken from the second-to-last underscore-separated part of the folder name. 10 | # After copying, the entire destination directory is zipped, but the unzipped version is retained. 11 | 12 | # Function to print error messages 13 | error_exit() { 14 | echo "Error: $1" >&2 15 | exit 1 16 | } 17 | 18 | # Check for correct number of arguments 19 | if [ "$#" -ne 3 ]; then 20 | error_exit "Usage: $0 <source_dir> <dest_dir> <compression_name>" 21 | fi 22 | 23 | # Assign arguments to variables 24 | SOURCE_DIR="$1" 25 | DEST_DIR="$2" 26 | COMPRESSION_NAME="$3" 27 | 28 | # Check if source directory exists 29 | if [ !
-d "$SOURCE_DIR" ]; then 30 | error_exit "Source directory does not exist: $SOURCE_DIR" 31 | fi 32 | 33 | # Create destination directory if it does not exist 34 | mkdir -p "$DEST_DIR" || error_exit "Failed to create destination directory: $DEST_DIR" 35 | 36 | # Function to determine the dataset and scene name 37 | get_dataset_and_scene() { 38 | local folder_name="$1" 39 | local dataset="" 40 | local scene="" 41 | 42 | # Extract the scene name (between last and second to last underscore) 43 | scene=$(echo "$folder_name" | awk -F'_' '{print $(NF-1)}') 44 | 45 | # Determine the dataset by searching for the specific keywords in the folder name 46 | if [[ "$folder_name" =~ 360 ]]; then 47 | dataset="MipNeRF360" 48 | elif [[ "$folder_name" =~ db ]]; then 49 | dataset="DeepBlending" 50 | elif [[ "$folder_name" =~ blender ]]; then 51 | dataset="SyntheticNeRF" 52 | elif [[ "$folder_name" =~ tand ]]; then 53 | dataset="TanksAndTemples" 54 | else 55 | error_exit "Unknown dataset prefix for folder: $folder_name" 56 | fi 57 | 58 | echo "$dataset/$scene" 59 | } 60 | 61 | 62 | # Function to copy specific files and directories 63 | copy_files() { 64 | local src_dir="$1" 65 | local dest_dir="$2" 66 | local compression_name="$3" 67 | 68 | # Copy cameras.json 69 | if [ -f "$src_dir/cameras.json" ]; then 70 | cp "$src_dir/cameras.json" "$dest_dir" || error_exit "Failed to copy cameras.json" 71 | fi 72 | 73 | # Copy training_config.yaml 74 | if [ -f "$src_dir/training_config.yaml" ]; then 75 | cp "$src_dir/training_config.yaml" "$dest_dir" || error_exit "Failed to copy training_config.yaml" 76 | fi 77 | 78 | # Copy cfg_args 79 | if [ -f "$src_dir/cfg_args" ]; then 80 | cp "$src_dir/cfg_args" "$dest_dir" || error_exit "Failed to copy cfg_args" 81 | fi 82 | 83 | # Copy files from compression/iteration_30000// 84 | local comp_src_dir="$src_dir/compression/iteration_30000/$compression_name" 85 | if [ -d "$comp_src_dir" ]; then 86 | local comp_dest_dir="$dest_dir/compression/iteration_30000/$compression_name" 87 | mkdir -p "$comp_dest_dir" || error_exit "Failed to create directory: $comp_dest_dir" 88 | find "$comp_src_dir" -maxdepth 1 -type f -exec cp {} "$comp_dest_dir" \; || error_exit "Failed to copy files from $comp_src_dir" 89 | fi 90 | } 91 | 92 | # Iterate over each subdirectory in the source directory 93 | for subdir in "$SOURCE_DIR"/*/; do 94 | subdir_name=$(basename "$subdir") 95 | 96 | # Determine dataset and scene 97 | dataset_scene=$(get_dataset_and_scene "$subdir_name") 98 | 99 | dest_subdir="$DEST_DIR/$dataset_scene" 100 | 101 | # Create destination subdirectory 102 | mkdir -p "$dest_subdir" || error_exit "Failed to create directory: $dest_subdir" 103 | 104 | # Copy the specified files and directories 105 | copy_files "$subdir" "$dest_subdir" "$COMPRESSION_NAME" 106 | done 107 | 108 | # Zip the entire destination directory without compression 109 | zip_file="$DEST_DIR.zip" 110 | zip -r -0 "$zip_file" "$DEST_DIR" || error_exit "Failed to zip the directory: $DEST_DIR" 111 | 112 | echo "Files copied and zipped successfully. 
Zip file created at: $zip_file" 113 | 114 | # Function to create a zip file for each dataset without compression 115 | zip_datasets() { 116 | local dest_dir="$1" 117 | 118 | # Iterate over each dataset directory within the destination directory 119 | for dataset_dir in "$dest_dir"/*/; do 120 | dataset_name=$(basename "$dataset_dir") 121 | zip_file="$dest_dir/$dataset_name.zip" 122 | 123 | # Zip the contents of each dataset directory without compression 124 | zip -r -0 "$zip_file" "$dataset_dir" || error_exit "Failed to zip the dataset directory: $dataset_dir" 125 | 126 | echo "Dataset zipped successfully: $zip_file" 127 | done 128 | } 129 | 130 | # Call the function to create individual dataset zips 131 | zip_datasets "$DEST_DIR" 132 | -------------------------------------------------------------------------------- /arguments/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from argparse import ArgumentParser, Namespace 13 | import sys 14 | import os 15 | import re 16 | from omegaconf import OmegaConf 17 | 18 | class GroupParams: 19 | pass 20 | 21 | class ParamGroup: 22 | def __init__(self, parser: ArgumentParser, name : str, fill_none = False): 23 | group = parser.add_argument_group(name) 24 | for key, value in vars(self).items(): 25 | shorthand = False 26 | if key.startswith("_"): 27 | shorthand = True 28 | key = key[1:] 29 | t = type(value) 30 | value = value if not fill_none else None 31 | if shorthand: 32 | if t == bool: 33 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true") 34 | else: 35 | group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t) 36 | else: 37 | if t == bool: 38 | group.add_argument("--" + key, default=value, action="store_true") 39 | else: 40 | group.add_argument("--" + key, default=value, type=t) 41 | 42 | def extract(self, args): 43 | group = GroupParams() 44 | for arg in vars(args).items(): 45 | if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): 46 | setattr(group, arg[0], arg[1]) 47 | return group 48 | 49 | class ModelParams(ParamGroup): 50 | def __init__(self, parser, sentinel=False): 51 | self.sh_degree = 3 52 | self._source_path = "" 53 | self._model_path = "" 54 | self._images = "images" 55 | self._resolution = -1 56 | self._white_background = False 57 | self.data_device = "cuda" 58 | self.eval = False 59 | super().__init__(parser, "Loading Parameters", sentinel) 60 | 61 | def extract(self, args): 62 | g = super().extract(args) 63 | g.source_path = os.path.abspath(g.source_path) 64 | return g 65 | 66 | class PipelineParams(ParamGroup): 67 | def __init__(self, parser): 68 | self.convert_SHs_python = False 69 | self.compute_cov3D_python = False 70 | self.debug = False 71 | super().__init__(parser, "Pipeline Parameters") 72 | 73 | class OptimizationParams(ParamGroup): 74 | def __init__(self, parser): 75 | self.iterations = 30_000 76 | self.position_lr_init = 0.00016 77 | self.position_lr_final = 0.0000016 78 | self.position_lr_delay_mult = 0.01 79 | self.position_lr_max_steps = 30_000 80 | self.feature_lr = 0.0025 81 | self.opacity_lr = 0.05 82 | self.scaling_lr = 0.005 83 | self.rotation_lr = 0.001 84 | self.percent_dense = 0.01 85 | 
self.lambda_dssim = 0.2 86 | self.densification_interval = 100 87 | self.opacity_reset_interval = 3000 88 | self.densify_from_iter = 500 89 | self.densify_until_iter = 15_000 90 | self.densify_grad_threshold = 0.0002 91 | self.random_background = False 92 | super().__init__(parser, "Optimization Parameters") 93 | 94 | def add_quotes_to_strings(s): 95 | # Only wrap model_path, source_path, and images values with quotes if they are not already quoted 96 | s = re.sub(r'(?<=model_path=)([^,\s]+)', r'"\1"', s) 97 | s = re.sub(r'(?<=source_path=)([^,\s]+)', r'"\1"', s) 98 | s = re.sub(r'(?<=images=)([^,\s]+)', r'"\1"', s) 99 | return s 100 | 101 | 102 | 103 | def get_combined_args(parser : ArgumentParser): 104 | cmdlne_string = sys.argv[1:] 105 | cfgfile_string = "Namespace()" 106 | args_cmdline = parser.parse_args(cmdlne_string) 107 | 108 | try: 109 | cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") 110 | print("Looking for config file in", cfgfilepath) 111 | with open(cfgfilepath) as cfg_file: 112 | print("Config file found: {}".format(cfgfilepath)) 113 | cfgfile_string = cfg_file.read() 114 | except FileNotFoundError: 115 | print("Config file not found at", cfgfilepath) 116 | pass 117 | try: 118 | args_cfgfile = eval(cfgfile_string) 119 | except SyntaxError: 120 | # If eval fails due to syntax error, apply quoting and try again 121 | cfgfile_string = add_quotes_to_strings(cfgfile_string) 122 | args_cfgfile = eval(cfgfile_string) 123 | 124 | merged_dict = vars(args_cfgfile).copy() 125 | for k,v in vars(args_cmdline).items(): 126 | if v is not None: 127 | merged_dict[k] = v 128 | return Namespace(**merged_dict) 129 | 130 | def get_hydra_training_args(model_path): 131 | try: 132 | cfgfilepath = os.path.join(model_path, "training_config.yaml") 133 | print("Looking for config file in", cfgfilepath) 134 | with open(cfgfilepath, 'r') as file: 135 | print("Config file found: {}".format(cfgfilepath)) 136 | training_cfg = OmegaConf.load(cfgfilepath) 137 | return training_cfg 138 | except (FileNotFoundError, TypeError): 139 | print("Config file not found!") -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import logging 14 | from argparse import ArgumentParser 15 | import shutil 16 | 17 | # This Python script is based on the shell converter script provided in the Mip-NeRF 360 repository.
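# Overview of the stages below: COLMAP feature extraction -> exhaustive feature
# matching -> mapper (sparse reconstruction with a tightened bundle-adjustment
# tolerance) -> undistortion into ideal pinhole intrinsics, followed by optional
# 50% / 25% / 12.5% resized copies in images_2 / images_4 / images_8.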
18 | parser = ArgumentParser("Colmap converter") 19 | parser.add_argument("--no_gpu", action='store_true') 20 | parser.add_argument("--skip_matching", action='store_true') 21 | parser.add_argument("--source_path", "-s", required=True, type=str) 22 | parser.add_argument("--camera", default="OPENCV", type=str) 23 | parser.add_argument("--colmap_executable", default="", type=str) 24 | parser.add_argument("--resize", action="store_true") 25 | parser.add_argument("--magick_executable", default="", type=str) 26 | args = parser.parse_args() 27 | colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap" 28 | magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick" 29 | use_gpu = 1 if not args.no_gpu else 0 30 | 31 | if not args.skip_matching: 32 | os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True) 33 | 34 | ## Feature extraction 35 | feat_extraction_cmd = colmap_command + " feature_extractor "\ 36 | "--database_path " + args.source_path + "/distorted/database.db \ 37 | --image_path " + args.source_path + "/input \ 38 | --ImageReader.single_camera 1 \ 39 | --ImageReader.camera_model " + args.camera + " \ 40 | --SiftExtraction.use_gpu " + str(use_gpu) 41 | exit_code = os.system(feat_extraction_cmd) 42 | if exit_code != 0: 43 | logging.error(f"Feature extraction failed with code {exit_code}. Exiting.") 44 | exit(exit_code) 45 | 46 | ## Feature matching 47 | feat_matching_cmd = colmap_command + " exhaustive_matcher \ 48 | --database_path " + args.source_path + "/distorted/database.db \ 49 | --SiftMatching.use_gpu " + str(use_gpu) 50 | exit_code = os.system(feat_matching_cmd) 51 | if exit_code != 0: 52 | logging.error(f"Feature matching failed with code {exit_code}. Exiting.") 53 | exit(exit_code) 54 | 55 | ### Bundle adjustment 56 | # The default Mapper tolerance is unnecessarily large, 57 | # decreasing it speeds up bundle adjustment steps. 58 | mapper_cmd = (colmap_command + " mapper \ 59 | --database_path " + args.source_path + "/distorted/database.db \ 60 | --image_path " + args.source_path + "/input \ 61 | --output_path " + args.source_path + "/distorted/sparse \ 62 | --Mapper.ba_global_function_tolerance=0.000001") 63 | exit_code = os.system(mapper_cmd) 64 | if exit_code != 0: 65 | logging.error(f"Mapper failed with code {exit_code}. Exiting.") 66 | exit(exit_code) 67 | 68 | ### Image undistortion 69 | ## We need to undistort our images into ideal pinhole intrinsics. 70 | img_undist_cmd = (colmap_command + " image_undistorter \ 71 | --image_path " + args.source_path + "/input \ 72 | --input_path " + args.source_path + "/distorted/sparse/0 \ 73 | --output_path " + args.source_path + "\ 74 | --output_type COLMAP") 75 | exit_code = os.system(img_undist_cmd) 76 | if exit_code != 0: 77 | logging.error(f"Image undistortion failed with code {exit_code}. Exiting.") 78 | exit(exit_code) 79 | 80 | files = os.listdir(args.source_path + "/sparse") 81 | os.makedirs(args.source_path + "/sparse/0", exist_ok=True) 82 | # Move each file (except the 0 folder itself) into the sparse/0 subdirectory 83 | for file in files: 84 | if file == '0': 85 | continue 86 | source_file = os.path.join(args.source_path, "sparse", file) 87 | destination_file = os.path.join(args.source_path, "sparse", "0", file) 88 | shutil.move(source_file, destination_file) 89 | 90 | if(args.resize): 91 | print("Copying and resizing...") 92 | 93 | # Resize images.
94 | os.makedirs(args.source_path + "/images_2", exist_ok=True) 95 | os.makedirs(args.source_path + "/images_4", exist_ok=True) 96 | os.makedirs(args.source_path + "/images_8", exist_ok=True) 97 | # Get the list of files in the source directory 98 | files = os.listdir(args.source_path + "/images") 99 | # Copy each file from the source directory to the destination directory 100 | for file in files: 101 | source_file = os.path.join(args.source_path, "images", file) 102 | 103 | destination_file = os.path.join(args.source_path, "images_2", file) 104 | shutil.copy2(source_file, destination_file) 105 | exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file) 106 | if exit_code != 0: 107 | logging.error(f"50% resize failed with code {exit_code}. Exiting.") 108 | exit(exit_code) 109 | 110 | destination_file = os.path.join(args.source_path, "images_4", file) 111 | shutil.copy2(source_file, destination_file) 112 | exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file) 113 | if exit_code != 0: 114 | logging.error(f"25% resize failed with code {exit_code}. Exiting.") 115 | exit(exit_code) 116 | 117 | destination_file = os.path.join(args.source_path, "images_8", file) 118 | shutil.copy2(source_file, destination_file) 119 | exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file) 120 | if exit_code != 0: 121 | logging.error(f"12.5% resize failed with code {exit_code}. Exiting.") 122 | exit(exit_code) 123 | 124 | print("Done.") 125 | -------------------------------------------------------------------------------- /eval/download_eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script downloads and evaluates the Self-Organizing Gaussians (SOGS) scenes. 4 | # It will download and unpack the scenes from the SOGS repository, 5 | # decompress the models to .ply, render the evaluation images, compute the metrics, and collect the evaluation results. 6 | 7 | # It supports two datasets: one with spherical harmonics ("w/ SH", Baseline) and one without ("w/o SH"). 8 | # 9 | # Usage: ./download_eval.sh <destination_dir> [source_search_path1 source_search_path2 ...] 10 | # e.g.
./download_eval.sh /data/sogs_results /data/DeepBlending /data/Blender /data/MipNerf360 /data/MipNerf360_extra /data/TandT 11 | # 12 | # Output will be saved as .csv files in the destination directory, under results/ 13 | 14 | set -euo pipefail 15 | 16 | # Add the Self-Organizing Gaussians code folder (the parent of this script's directory) to PYTHONPATH 17 | CODE_DIR="$(dirname "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")")" 18 | echo "CODE_DIR: $CODE_DIR" 19 | export PYTHONPATH="$CODE_DIR:${PYTHONPATH:-}" 20 | 21 | # Ensure at least one argument is passed (data directory) 22 | if [ $# -lt 1 ]; then 23 | echo "Usage: $0 <destination_dir> [source_search_path1 source_search_path2 ...]" 24 | exit 1 25 | fi 26 | 27 | DEST_DIR="$1" 28 | shift # Remove the first argument 29 | 30 | # Remaining arguments are scene source search paths (if any) 31 | SEARCH_PATHS=("$@") 32 | 33 | echo "Destination directory: $DEST_DIR" 34 | echo "Dataset source search paths: ${SEARCH_PATHS[@]}" 35 | 36 | # Function to find a scene in the provided search paths 37 | find_scene_in_paths() { 38 | local scene_name="$1" 39 | 40 | for search_path in "${SEARCH_PATHS[@]}"; do 41 | if [ -d "$search_path/$scene_name" ]; then 42 | echo "$(readlink -f "$search_path/$scene_name")" # Return the absolute path 43 | return 0 44 | fi 45 | done 46 | 47 | echo "Scene directory '$scene_name' not found in search paths: ${SEARCH_PATHS[*]}" >&2 48 | return 1 49 | } 50 | 51 | # List of required scenes 52 | SCENE_LIST=(drjohnson playroom bicycle bonsai counter flowers garden kitchen room stump treehill chair drums ficus hotdog lego materials mic ship train truck) 53 | 54 | # Check that all required scenes can be found 55 | MISSING_SCENES=() 56 | 57 | for SCENE in "${SCENE_LIST[@]}"; do 58 | if ! find_scene_in_paths "$SCENE" >/dev/null 2>&1; then 59 | MISSING_SCENES+=("$SCENE") 60 | fi 61 | done 62 | 63 | if [ ${#MISSING_SCENES[@]} -ne 0 ]; then 64 | echo "The following scene directories were not found in the search paths: ${MISSING_SCENES[*]}" 65 | echo "Please add the dataset source directories as additional arguments to this script." 66 | exit 1 67 | fi 68 | 69 | mkdir -p "$DEST_DIR" 70 | 71 | # Function to download and process a dataset 72 | process_dataset() { 73 | local DATASET_URL="$1" 74 | local DATA_SUBDIR="$2" 75 | local SUBMETHOD="$3" 76 | local COMPRESSION_DIR_NAME="$4" 77 | 78 | # Extract the ZIP file name from the URL 79 | local ZIP_FILE_NAME="${DATASET_URL##*/}" 80 | 81 | # Download and extract the dataset 82 | cd "$DEST_DIR" 83 | echo "Downloading $ZIP_FILE_NAME..." 84 | curl -C - -# -L "$DATASET_URL" -o "$ZIP_FILE_NAME" || { echo "Download failed"; exit 1; } 85 | unzip -n -q "$ZIP_FILE_NAME" || { echo "Failed to unzip"; exit 1; } 86 | echo "Download and extraction complete for $ZIP_FILE_NAME."
87 | 88 | # Process each dataset and scene 89 | local DATASET_PATH="$DEST_DIR/$DATA_SUBDIR" 90 | cd "$DATASET_PATH" 91 | 92 | for DATASET in */; do 93 | DATASET=${DATASET%/} 94 | echo "Dataset: $DATASET" 95 | 96 | for DATASET_SCENE in "$DATASET"/*/; do 97 | DATASET_SCENE=${DATASET_SCENE%/} 98 | SCENE=$(basename "$DATASET_SCENE") 99 | COMPRESSED_MODEL_PATH="$DATASET_PATH/$DATASET_SCENE/compression/iteration_30000/$COMPRESSION_DIR_NAME/" 100 | 101 | SCENE_SOURCE_PATH=$(find_scene_in_paths "$SCENE") || { echo "Scene '$SCENE' not found, please add its source directory to the search paths."; exit 1; } 102 | echo "Scene: $DATASET_SCENE, Source path: $SCENE_SOURCE_PATH" 103 | 104 | echo " Decompressing $DATASET_SCENE to 3DGS .ply" 105 | python "${CODE_DIR}/compression/decompress.py" \ 106 | --compressed_model_path "${COMPRESSED_MODEL_PATH}" 107 | 108 | echo " Rendering eval images for $DATASET_SCENE" 109 | python "${CODE_DIR}/render.py" \ 110 | --source_path "${SCENE_SOURCE_PATH}" \ 111 | --model_path "${COMPRESSED_MODEL_PATH}/decompressed_model" \ 112 | --skip_train \ 113 | --eval \ 114 | --data_device cuda \ 115 | --disable_xyz_log_activation 116 | 117 | echo " Computing metrics for $DATASET_SCENE" 118 | python "${CODE_DIR}/metrics.py" \ 119 | --model_path "${COMPRESSED_MODEL_PATH}/decompressed_model" 120 | 121 | echo " Collecting evaluation results for $DATASET_SCENE" 122 | python "${CODE_DIR}/eval/collect_eval_per_scene.py" \ 123 | --output-dir "${DEST_DIR}" \ 124 | --dataset "${DATASET}" \ 125 | --scene "${SCENE}" \ 126 | --model-path "${COMPRESSED_MODEL_PATH}" \ 127 | --submethod "${SUBMETHOD}" 128 | done 129 | done 130 | } 131 | 132 | # ECCV Baseline, with SH 133 | DATASET1_URL="https://github.com/fraunhoferhhi/Self-Organizing-Gaussians/releases/download/eccv-2024-data/Scenes_SOGS_ECCV_with_SH.zip" 134 | DATASET1_SUBDIR="results_SOGS_ECCV/with_SH" 135 | DATASET1_SUBMETHOD="Baseline" 136 | DATASET1_COMPRESSION_DIR="jxl_quant_sh" 137 | 138 | # ECCV w/o SH 139 | DATASET2_URL="https://github.com/fraunhoferhhi/Self-Organizing-Gaussians/releases/download/eccv-2024-data/Scenes_SOGS_ECCV_without_SH.zip" 140 | DATASET2_SUBDIR="results_SOGS_ECCV/without_SH" 141 | DATASET2_SUBMETHOD=" w/o SH" 142 | DATASET2_COMPRESSION_DIR="jxl_quant" 143 | 144 | process_dataset "$DATASET1_URL" "$DATASET1_SUBDIR" "$DATASET1_SUBMETHOD" "$DATASET1_COMPRESSION_DIR" 145 | 146 | process_dataset "$DATASET2_URL" "$DATASET2_SUBDIR" "$DATASET2_SUBMETHOD" "$DATASET2_COMPRESSION_DIR" 147 | 148 | echo "All scenes processed." 
149 | -------------------------------------------------------------------------------- /config/compression/umbrella.yaml: -------------------------------------------------------------------------------- 1 | experiments: 2 | - name: "jxl_101" 3 | attributes: 4 | - name: "_xyz" 5 | method: "jpeg-xl" 6 | params: 7 | level: 101 8 | - name: "_features_dc" 9 | method: "jpeg-xl" 10 | params: 11 | level: 101 12 | # - name: "_features_rest" 13 | # method: "jpeg-xl" 14 | # params: 15 | # level: 101 16 | - name: "_scaling" 17 | method: "jpeg-xl" 18 | params: 19 | level: 101 20 | - name: "_rotation" 21 | method: "jpeg-xl" 22 | params: 23 | level: 101 24 | - name: "_opacity" 25 | method: "jpeg-xl" 26 | params: 27 | level: 101 28 | 29 | - name: "npz_zip" 30 | attributes: 31 | - name: "_xyz" 32 | method: "npz" 33 | - name: "_features_dc" 34 | method: "npz" 35 | # - name: "_features_rest" 36 | # method: "npz" 37 | - name: "_scaling" 38 | method: "npz" 39 | - name: "_rotation" 40 | method: "npz" 41 | - name: "_opacity" 42 | method: "npz" 43 | 44 | # manually tuned parameters on truck with default params early 231108 45 | # - no proper neighbor weighing 46 | - name: "jxl_man" 47 | attributes: 48 | - name: "_xyz" 49 | method: "jpeg-xl" 50 | normalize: false 51 | params: 52 | level: 101 53 | - name: "_features_dc" 54 | method: "jpeg-xl" 55 | normalize: true 56 | params: 57 | level: 50 58 | # - name: "_features_rest" 59 | # method: "jpeg-xl" 60 | # normalize: true 61 | # params: 62 | # level: 90 63 | - name: "_scaling" 64 | method: "jpeg-xl" 65 | normalize: true 66 | params: 67 | level: 100 68 | - name: "_rotation" 69 | method: "jpeg-xl" 70 | normalize: true 71 | params: 72 | level: 100 73 | - name: "_opacity" 74 | method: "jpeg-xl" 75 | normalize: true 76 | params: 77 | level: 20 78 | 79 | # EXR w/ OpenCV can only write 1, 3 or 4 channels 80 | - name: "exr" 81 | attributes: 82 | - name: "_xyz" 83 | method: "exr" 84 | params: 85 | compression: "zip" 86 | - name: "_features_dc" 87 | method: "exr" 88 | params: 89 | compression: "zip" 90 | # - name: "_features_rest" 91 | # method: "jpeg-xl" 92 | # params: 93 | # level: 101 94 | - name: "_scaling" 95 | method: "exr" 96 | params: 97 | compression: "zip" 98 | - name: "_rotation" 99 | method: "exr" 100 | params: 101 | compression: "zip" 102 | - name: "_opacity" 103 | method: "exr" 104 | params: 105 | compression: "zip" 106 | 107 | - name: "png_16" 108 | attributes: 109 | - name: "_xyz" 110 | method: "png" 111 | normalize: true 112 | contract: false 113 | params: 114 | dtype: "uint16" 115 | - name: "_features_dc" 116 | method: "png" 117 | normalize: true 118 | params: 119 | dtype: "uint16" 120 | # - name: "_features_rest" 121 | # method: "jpeg-xl" 122 | # normalize: true 123 | # params: 124 | # level: 101 125 | - name: "_scaling" 126 | method: "png" 127 | normalize: true 128 | params: 129 | dtype: "uint16" 130 | - name: "_rotation" 131 | method: "png" 132 | normalize: true 133 | params: 134 | dtype: "uint16" 135 | - name: "_opacity" 136 | method: "png" 137 | normalize: true 138 | params: 139 | dtype: "uint16" 140 | 141 | - name: "exr_jxl_quant" 142 | attributes: 143 | - name: "_xyz" 144 | method: "exr" 145 | normalize: false 146 | contract: false 147 | quantize: 13 148 | params: 149 | compression: "zip" 150 | - name: "_features_dc" 151 | method: "jpeg-xl" 152 | normalize: true 153 | # contract: false 154 | # quantize: 8 155 | params: 156 | level: 90 157 | # - name: "_features_rest" 158 | # method: "jpeg-xl" 159 | # normalize: false 160 | # # quantize: 8 161 | # 
params: 162 | # level: 101 163 | # # compression: "none" 164 | # # dtype: "uint8" 165 | - name: "_scaling" 166 | method: "exr" 167 | normalize: false 168 | contract: false 169 | quantize: 6 170 | params: 171 | compression: "none" 172 | - name: "_rotation" 173 | method: "exr" 174 | normalize: false 175 | # contract: false 176 | quantize: 6 177 | params: 178 | compression: "zip" 179 | # - name: "_rotation" 180 | # method: "jpeg-xl" 181 | # normalize: true 182 | # contract: false 183 | # quantize: 8 184 | # params: 185 | # level: 90 186 | - name: "_opacity" 187 | method: "exr" 188 | normalize: true 189 | contract: false 190 | quantize: 5 191 | params: 192 | compression: "none" 193 | 194 | - name: "jxl_quant" 195 | attributes: 196 | - name: "_xyz" 197 | method: "jpeg-xl" 198 | normalize: true 199 | quantize: 14 200 | params: 201 | # compression: "zip" 202 | level: 101 203 | - name: "_features_dc" 204 | method: "jpeg-xl" 205 | normalize: true 206 | params: 207 | level: 100 208 | - name: "_scaling" 209 | method: "jpeg-xl" 210 | normalize: false 211 | contract: false 212 | quantize: 6 213 | params: 214 | level: 101 215 | - name: "_rotation" 216 | method: "jpeg-xl" 217 | normalize: true 218 | quantize: 6 219 | params: 220 | level: 101 221 | - name: "_opacity" 222 | method: "jpeg-xl" 223 | normalize: true 224 | contract: false 225 | quantize: 6 226 | params: 227 | level: 101 228 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
# Compact 3D Scene Representation via Self-Organizing Gaussian Grids

[Teaser figure: millions of Gaussians at 174 MB with a PSNR of 24.90 are sorted into 2D attribute grids, stored at 17 MB with the same PSNR.]

[Project Page](https://fraunhoferhhi.github.io/Self-Organizing-Gaussians/) · [arXiv](https://arxiv.org/abs/2312.13299)
15 | 16 | ### Code 17 | 18 | This repository is a fork of the official authors' implementation associated with the paper "3D Gaussian Splatting for Real-Time Radiance Field Rendering". 19 | 20 | The code for "Compact 3D Scene Representation via Self-Organizing Gaussian Grids" consists of multiple parts. The multi-dimensional sorting algorithm, PLAS, is available under the Apache License at [fraunhoferhhi/PLAS](https://github.com/fraunhoferhhi/PLAS). 21 | 22 | The integration of the sorting, the smoothness regularization and the compression code for training and compressing 3D scenes is available in this repository. 23 | 24 | ## Cloning the Repository 25 | 26 | The repository contains submodules, so please check it out with 27 | 28 | ```shell 29 | # SSH 30 | git clone git@github.com:fraunhoferhhi/Self-Organizing-Gaussians.git --recursive 31 | ``` 32 | 33 | or 34 | 35 | ```shell 36 | # HTTPS 37 | git clone https://github.com/fraunhoferhhi/Self-Organizing-Gaussians.git --recursive 38 | ``` 39 | 40 | ## Python Environment 41 | 42 | The code uses a few additional Python packages on top of graphdeco-inria/gaussian-splatting. We provide an extended environment.yml: 43 | 44 | Installation with [micromamba](https://mamba.readthedocs.io/en/latest/installation/micromamba-installation.html): 45 | 46 | ```shell 47 | micromamba env create --file environment.yml --channel-priority flexible -y 48 | micromamba activate sogs 49 | ``` 50 | 51 | ## Example training 52 | 53 | Download a dataset, e.g. [T&T](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/datasets/input/tandt_db.zip). 54 | 55 | The train.py script expects the name of a .yaml config file in the [config/](config/) folder. All parameters for the run are loaded from the yaml file by default. An example launch file for Visual Studio Code can be found in .vscode/launch.json. 56 | 57 | Example: 58 | 59 | ```shell 60 | python train.py \ 61 | --config-name ours_q_sh_local_test \ 62 | hydra.run.dir=/data/output/${now:%Y-%m-%d}/${now:%H-%M-%S}-${run.name} \ 63 | dataset.source_path=/data/gaussian_splatting/tandt_db/tandt/truck \ 64 | run.no_progress_bar=false \ 65 | run.name=vs-code-debug 66 | ``` 67 | 68 | The parameter configurations can be overridden at launch as shown (using [Hydra](https://hydra.cc/)). 69 | 70 | ## Pre-Trained Models & Evaluation 71 | 72 | Trained and compressed scenes are available for download in the [ECCV 2024 release](https://github.com/fraunhoferhhi/Self-Organizing-Gaussians/releases/tag/eccv-2024-data). 73 | 74 | The script at [eval/download_eval.sh](https://github.com/fraunhoferhhi/Self-Organizing-Gaussians/blob/main/eval/download_eval.sh) will automatically: 75 | * download the pre-trained scenes with and without spherical harmonics 76 | * measure size on disk and number of Gaussians of the compressed scenes 77 | * decompress the scenes into .ply 78 | * render the test images for each scene, using the original 3DGS code 79 | * compute the metrics (PSNR, SSIM, LPIPS) for all test images 80 | * gather the results in .csv, in the format of the [3DGS compression survey](https://w-m.github.io/3dgs-compression-survey/) 81 | 82 | The evaluation results can be found in [results/](https://github.com/fraunhoferhhi/Self-Organizing-Gaussians/blob/main/results/). 
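For example, to run the full evaluation into `/data/sogs_results` (the dataset paths below are placeholders; pass the folders containing your local copies of the scenes as additional search paths):

```shell
./eval/download_eval.sh /data/sogs_results /data/DeepBlending /data/Blender /data/MipNerf360 /data/MipNerf360_extra /data/TandT
```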
83 | 84 | ## Differences with graphdeco-inria/gaussian-splatting 85 | 86 | Code differences can be found in this diff: https://github.com/fraunhoferhhi/Self-Organizing-Gaussians/pull/1/files 87 | 88 | ### Usage 89 | 90 | - different command-line interface for train.py (using Hydra) 91 | - wandb.ai used for logging 92 | 93 | ### Code extensions 94 | 95 | - post-training quantization, compression/decompression 96 | - xyz log activation (gaussian_model.py) 97 | - grid sorting, neighbor loss (gaussian_model.py) 98 | - option to disable spherical harmonics 99 | 100 | ## Citation 101 | 102 | If you use our method in your research, please cite our paper. The paper was presented at ECCV 2024 and [published](https://doi.org/10.1007/978-3-031-73013-9_2) in the official proceedings in 2025. You can use the following BibTeX entry: 103 | 104 | ```bibtex 105 | @InProceedings{morgenstern2024compact, 106 | author = {Wieland Morgenstern and Florian Barthel and Anna Hilsmann and Peter Eisert}, 107 | title = {Compact 3D Scene Representation via Self-Organizing Gaussian Grids}, 108 | booktitle = {Computer Vision -- {ECCV} 2024}, 109 | year = {2025}, 110 | publisher = {Springer Nature Switzerland}, 111 | address = {Cham}, 112 | pages = {18--34}, 113 | doi = {10.1007/978-3-031-73013-9_2}, 114 | url = {https://fraunhoferhhi.github.io/Self-Organizing-Gaussians/}, 115 | } 116 | ``` 117 | 118 | ## Updates 119 | 120 | - 2024-12-03: Freezing package versions in environment.yml, particularly imagecodecs. There is a regression in file size with later imagecodecs versions, see [#3](https://github.com/fraunhoferhhi/Self-Organizing-Gaussians/issues/3). 121 | - 2024-11-28: Add [ECCV 2024 Redux Talk](https://www.youtube.com/watch?v=nb5U9xfx7-w), [View Dependent Podcast](https://www.youtube.com/watch?v=Y0O6R0Keywg) and proceedings .bib to project page. 122 | - 2024-10-30: Update project page with reduction factors from updated metric computation -> 19.9x to 39.5x compression over 3DGS. 123 | - 2024-09-16: Script to compute per-scene metrics from uploaded models (see *Pre-Trained Models & Evaluation*). This fixes issues in the metric computation, previously done from Weights & Biases runs: *Dr Johnson* now correctly attributed to DeepBlending dataset (was: *T&T*); Quality loss from quantization and compression losses correctly incorporated. 124 | - 2024-08-22: Released pre-trained, [compressed scenes](https://github.com/fraunhoferhhi/Self-Organizing-Gaussians/releases/tag/eccv-2024-data) 125 | - 2024-07-09: Project website updated with TLDR, contributions, insights and comparison to concurrent methods 126 | - 2024-07-01: Our work was accepted at **ECCV 2024** 🥳 127 | - 2024-06-13: Training code available 128 | - 2024-05-14: Improved compression scores! New results for paper v2 available on the [project website](https://fraunhoferhhi.github.io/Self-Organizing-Gaussians/) 129 | - 2024-05-02: Revised [paper v2](https://arxiv.org/pdf/2312.13299) on arXiv: Added compression of spherical harmonics, updated compression method with improved results (all attributes compressed with JPEG XL now), added qualitative comparison of additional scenes, moved compression explanation and comparison to main paper, added comparison with "Making Gaussian Splats smaller". 
130 | - 2024-02-22: The code for the sorting algorithm is now available at [fraunhoferhhi/PLAS](https://github.com/fraunhoferhhi/PLAS) 131 | - 2024-02-21: Video comparisons for different scenes available on the [project website](https://fraunhoferhhi.github.io/Self-Organizing-Gaussians/) 132 | - 2023-12-19: Preprint available on [arXiv](https://arxiv.org/abs/2312.13299) 133 | 134 | -------------------------------------------------------------------------------- /training_viewer.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from utils.quaternion import quaternion_to_matrix, matrix_to_rotation_6d 5 | from utils.sh_utils import SH2RGB 6 | from screeninfo import get_monitors 7 | import torch 8 | import wandb 9 | 10 | from scene.gaussian_model import GaussianModel 11 | from dataclasses import dataclass 12 | from gaussian_renderer import render 13 | 14 | 15 | def dcn(x: torch.Tensor, normalize=False): 16 | if normalize: 17 | x = (x - x.min()) / (x.max() - x.min()) 18 | return x.detach().cpu().numpy() 19 | 20 | 21 | def organize_windows(window_names): 22 | # Get screen width and height 23 | monitor = get_monitors()[0] # Assuming you have only one monitor 24 | screen_width, screen_height = monitor.width, monitor.height 25 | 26 | # Grid dimensions (3x3) 27 | grid_cols = 3 28 | 29 | # Calculate window width and height 30 | window_width = screen_width // 2 // grid_cols 31 | window_height = window_width 32 | 33 | min_y = 64 34 | 35 | # Loop through your windows and position them 36 | for i, window_name in enumerate(window_names): 37 | cv2.namedWindow(window_name, cv2.WINDOW_NORMAL) 38 | 39 | # Calculate position 40 | col = i % grid_cols 41 | row = i // grid_cols 42 | x = screen_width // 2 + col * window_width 43 | y = min_y + row * (window_height + min_y) 44 | 45 | # Set window position and size 46 | cv2.moveWindow(window_name, x, y) 47 | cv2.resizeWindow(window_name, window_width, window_height) 48 | 49 | 50 | def show_grad_img(gaussians, grad, name): 51 | grad = torch.norm(grad, dim=-1) 52 | 53 | gradn = grad / grad.max() 54 | gradn = gradn ** 0.3 55 | gradn = gaussians.as_grid_img(gradn) 56 | # grad = cv2.cvtColor(grad, cv2.COLOR_RGB2BGR) 57 | cv2.imshow(name, gradn.cpu().numpy()) 58 | 59 | 60 | @dataclass 61 | class TrainingViewer: 62 | has_updated: bool = False 63 | debug_view: int = 0 64 | 65 | def training_view(self, scene, gaussians, pipe, background=None): 66 | 67 | if not self.has_updated: 68 | # organize_windows(["xyz", "rgb", "grads_xyz_accum", "opacity", "rotation", "scale", "grad_xyz", "grad_rgb"]) 69 | organize_windows(["xyz", "rgb", "grads_xyz_accum", "opacity", "rotation 3:6", "rotation 0:3", "scale"]) 70 | 71 | viewpoint_cam = scene.getTrainCameras().copy()[self.debug_view] 72 | render_pkg = render(viewpoint_cam, gaussians, pipe, background) 73 | image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], \ 74 | render_pkg["visibility_filter"], render_pkg["radii"] 75 | 76 | img = image.moveaxis(0, -1).detach().cpu().numpy() 77 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 78 | cv2.imshow("debug view", img) 79 | 80 | xyzs = gaussians.as_grid_img(gaussians._xyz) 81 | rgbs = gaussians.as_grid_img(SH2RGB(gaussians._features_dc)) 82 | 83 | xyzs_norm = (xyzs - xyzs.min()) / (xyzs.max() - xyzs.min()) 84 | cv2.imshow("xyz", xyzs_norm.detach().cpu().numpy()) 85 | 86 | rgbs_norm = torch.clamp(rgbs, 0.0, 1.0) 87 | cv2.imshow("rgb", 
rgbs_norm.detach().cpu().numpy()) 88 | 89 | MIN_DISPLAY_SCALE = -10 90 | MAX_DISPLAY_SCALE = 3 91 | scales = gaussians.as_grid_img(gaussians._scaling) 92 | display_scales = torch.clamp(scales, MIN_DISPLAY_SCALE, MAX_DISPLAY_SCALE) 93 | scales_norm = (display_scales - MIN_DISPLAY_SCALE) / (MAX_DISPLAY_SCALE - MIN_DISPLAY_SCALE) 94 | cv2.imshow("scale", scales_norm.detach().cpu().numpy()) 95 | 96 | MIN_DISPLAY_OPACITY = -5 97 | MAX_DISPLAY_OPACITY = 8 98 | opacities = gaussians.as_grid_img(gaussians._opacity) 99 | display_opacities = torch.clamp(opacities, MIN_DISPLAY_OPACITY, MAX_DISPLAY_OPACITY) 100 | opacities_norm = (display_opacities - MIN_DISPLAY_OPACITY) / (MAX_DISPLAY_OPACITY - MIN_DISPLAY_OPACITY) 101 | cv2.imshow("opacity", opacities_norm.detach().cpu().numpy()) 102 | 103 | # quaternions = gaussians._rotation 104 | # euler_angles = tgm.quaternion_to_angle_axis(quaternions) 105 | # euler_norm = (euler_angles + np.pi) / (2 * np.pi) 106 | # euler_img = gaussians.as_grid_img(euler_norm) 107 | 108 | quaternions = gaussians._rotation 109 | matrix = quaternion_to_matrix(quaternions) 110 | euler_angles = matrix_to_rotation_6d(matrix) # , convention="XYZ") 111 | euler_norm = (euler_angles + torch.pi) / (2 * torch.pi) 112 | euler_img_03 = gaussians.as_grid_img(euler_norm[..., :3]) 113 | euler_img_36 = gaussians.as_grid_img(euler_norm[..., 3:]) 114 | cv2.imshow("rotation 0:3", euler_img_03.detach().cpu().numpy()) 115 | cv2.imshow("rotation 3:6", euler_img_36.detach().cpu().numpy()) 116 | 117 | grads = gaussians.xyz_gradient_accum / gaussians.denom 118 | grads[grads.isnan()] = 0.0 119 | grads_norm = (grads - grads.min()) / (grads.max() - grads.min()) 120 | grads_img = gaussians.as_grid_img(grads_norm) 121 | cv2.imshow("grads_xyz_accum", grads_img.detach().cpu().numpy()) 122 | 123 | if not self.has_updated: 124 | # while cv2.waitKey(1) != 32: 125 | # pass 126 | self.has_updated = True 127 | 128 | cv2.waitKey(1) 129 | 130 | def training_view_wandb(self, scene, gaussians: GaussianModel, step, pipe, background=None): 131 | 132 | # images are now rendered in evaluation 133 | # viewpoint_cam = scene.getTrainCameras().copy()[self.debug_view] 134 | # render_pkg = render(viewpoint_cam, gaussians, pipe, background) 135 | # image = render_pkg["render"] 136 | # img = dcn(image.moveaxis(0, -1)) 137 | # img = np.clip(img, 0, 1) 138 | # img = wandb.Image(img, caption="debug view") 139 | 140 | xyzs = gaussians.as_grid_img(gaussians._xyz) 141 | xyzs_norm = (xyzs - xyzs.min()) / (xyzs.max() - xyzs.min()) 142 | xyz_img = wandb.Image(dcn(xyzs_norm), caption="XYZ") 143 | 144 | rgbs = gaussians.as_grid_img(SH2RGB(gaussians._features_dc)) 145 | rgbs_norm = torch.clamp(rgbs, 0.0, 1.0) 146 | rgb_img = wandb.Image(dcn(rgbs_norm), caption="RGB") 147 | 148 | MIN_DISPLAY_SCALE = -10 149 | MAX_DISPLAY_SCALE = 3 150 | scales = gaussians.as_grid_img(gaussians._scaling) 151 | display_scales = torch.clamp(scales, MIN_DISPLAY_SCALE, MAX_DISPLAY_SCALE) 152 | scales_norm = (display_scales - MIN_DISPLAY_SCALE) / (MAX_DISPLAY_SCALE - MIN_DISPLAY_SCALE) 153 | scale_img = wandb.Image(dcn(scales_norm), caption="scale") 154 | 155 | MIN_DISPLAY_OPACITY = -5 156 | MAX_DISPLAY_OPACITY = 8 157 | opacities = gaussians.as_grid_img(gaussians._opacity) 158 | display_opacities = torch.clamp(opacities, MIN_DISPLAY_OPACITY, MAX_DISPLAY_OPACITY) 159 | opacities_norm = (display_opacities - MIN_DISPLAY_OPACITY) / (MAX_DISPLAY_OPACITY - MIN_DISPLAY_OPACITY) 160 | opacity_img = wandb.Image(dcn(opacities_norm), caption="opacity") 161 | 162 | # 
quaternions = gaussians._rotation 163 | # euler_angles = tgm.quaternion_to_angle_axis(quaternions) 164 | # euler_norm = (euler_angles + np.pi) / (2 * np.pi) 165 | # euler_img = gaussians.as_grid_img(euler_norm) 166 | 167 | quaternions = gaussians._rotation 168 | matrix = quaternion_to_matrix(quaternions) 169 | euler_angles = matrix_to_rotation_6d(matrix) 170 | euler_norm = (euler_angles + torch.pi) / (2 * torch.pi) 171 | euler_img_03 = gaussians.as_grid_img(euler_norm[..., :3]) 172 | euler_img_36 = gaussians.as_grid_img(euler_norm[..., 3:]) 173 | rotation_03_img = wandb.Image(dcn(euler_img_03), caption="rotation 0:3") 174 | rotation_36_img = wandb.Image(dcn(euler_img_36), caption="rotation 3:6") 175 | 176 | grads = gaussians.xyz_gradient_accum / gaussians.denom 177 | grads[grads.isnan()] = 0.0 178 | grads_norm = (grads - grads.min()) / (grads.max() - grads.min()) 179 | grads_img = gaussians.as_grid_img(grads_norm) 180 | grads_xyz_accum_img = wandb.Image(dcn(grads_img), caption="grads_xyz_accum") 181 | 182 | to_log = { 183 | "grid": [ 184 | xyz_img, 185 | rgb_img, 186 | scale_img, 187 | opacity_img, 188 | rotation_03_img, 189 | rotation_36_img, 190 | grads_xyz_accum_img 191 | ] 192 | } 193 | 194 | if gaussians.max_sh_degree > 0: 195 | sh_composed = self.sh_pyramid(gaussians) 196 | sh_composed_img = wandb.Image(sh_composed) 197 | to_log["spherical harmonics"] = sh_composed_img 198 | 199 | wandb.log(to_log, step=step) 200 | 201 | def sh_pyramid(self, gaussians): 202 | w = gaussians.grid_sidelen 203 | sh = dcn(gaussians.get_features, normalize=True) 204 | sh = np.reshape(sh, [w, w, sh.shape[1], sh.shape[2]]) 205 | sh_composed = np.zeros([w * 4, w * 7, 3]) 206 | sh_composed[0 * w:1 * w, 3 * w:4 * w] = sh[:, :, 0, :] 207 | 208 | if gaussians.active_sh_degree > 1: 209 | sh_composed[1 * w:2 * w, 2 * w:5 * w] = np.concatenate(sh[:, :, 1:4, :].transpose(2, 0, 1, 3), 210 | axis=0).transpose(1, 0, 2) 211 | sh_composed[2 * w:3 * w, 1 * w:6 * w] = np.concatenate(sh[:, :, 4:9, :].transpose(2, 0, 1, 3), 212 | axis=0).transpose(1, 0, 2) 213 | sh_composed[3 * w:4 * w, 0 * w:7 * w] = np.concatenate(sh[:, :, 9:, :].transpose(2, 0, 1, 3), axis=0).transpose( 214 | 1, 0, 2) 215 | return sh_composed 216 | -------------------------------------------------------------------------------- /utils/quaternion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | # https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py 6 | 7 | def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: 8 | """ 9 | Convert rotations given as quaternions to rotation matrices. 10 | 11 | Args: 12 | quaternions: quaternions with real part first, 13 | as tensor of shape (..., 4). 14 | 15 | Returns: 16 | Rotation matrices as tensor of shape (..., 3, 3). 17 | """ 18 | r, i, j, k = torch.unbind(quaternions, -1) 19 | # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. 
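# Note: (quaternions * quaternions).sum(-1) is |q|^2, so two_s = 2 / |q|^2.
# Folding the normalization into the scale factor keeps the conversion valid
# for non-unit quaternions; for unit quaternions two_s is simply 2.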
20 | two_s = 2.0 / (quaternions * quaternions).sum(-1) 21 | 22 | o = torch.stack( 23 | ( 24 | 1 - two_s * (j * j + k * k), 25 | two_s * (i * j - k * r), 26 | two_s * (i * k + j * r), 27 | two_s * (i * j + k * r), 28 | 1 - two_s * (i * i + k * k), 29 | two_s * (j * k - i * r), 30 | two_s * (i * k - j * r), 31 | two_s * (j * k + i * r), 32 | 1 - two_s * (i * i + j * j), 33 | ), 34 | -1, 35 | ) 36 | return o.reshape(quaternions.shape[:-1] + (3, 3)) 37 | 38 | 39 | def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor: 40 | """ 41 | Convert rotations given as rotation matrices to Euler angles in radians. 42 | 43 | Args: 44 | matrix: Rotation matrices as tensor of shape (..., 3, 3). 45 | convention: Convention string of three uppercase letters. 46 | 47 | Returns: 48 | Euler angles in radians as tensor of shape (..., 3). 49 | """ 50 | if len(convention) != 3: 51 | raise ValueError("Convention must have 3 letters.") 52 | if convention[1] in (convention[0], convention[2]): 53 | raise ValueError(f"Invalid convention {convention}.") 54 | for letter in convention: 55 | if letter not in ("X", "Y", "Z"): 56 | raise ValueError(f"Invalid letter {letter} in convention string.") 57 | if matrix.size(-1) != 3 or matrix.size(-2) != 3: 58 | raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") 59 | i0 = _index_from_letter(convention[0]) 60 | i2 = _index_from_letter(convention[2]) 61 | tait_bryan = i0 != i2 62 | if tait_bryan: 63 | central_angle = torch.asin( 64 | matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) 65 | ) 66 | else: 67 | central_angle = torch.acos(matrix[..., i0, i0]) 68 | 69 | o = ( 70 | _angle_from_tan( 71 | convention[0], convention[1], matrix[..., i2], False, tait_bryan 72 | ), 73 | central_angle, 74 | _angle_from_tan( 75 | convention[2], convention[1], matrix[..., i0, :], True, tait_bryan 76 | ), 77 | ) 78 | return torch.stack(o, -1) 79 | 80 | 81 | def _index_from_letter(letter: str) -> int: 82 | if letter == "X": 83 | return 0 84 | if letter == "Y": 85 | return 1 86 | if letter == "Z": 87 | return 2 88 | raise ValueError("letter must be either X, Y or Z.") 89 | 90 | def _angle_from_tan( 91 | axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool 92 | ) -> torch.Tensor: 93 | """ 94 | Extract the first or third Euler angle from the two members of 95 | the matrix which are positive constant times its sine and cosine. 96 | 97 | Args: 98 | axis: Axis label "X" or "Y or "Z" for the angle we are finding. 99 | other_axis: Axis label "X" or "Y or "Z" for the middle axis in the 100 | convention. 101 | data: Rotation matrices as tensor of shape (..., 3, 3). 102 | horizontal: Whether we are looking for the angle for the third axis, 103 | which means the relevant entries are in the same row of the 104 | rotation matrix. If not, they are in the same column. 105 | tait_bryan: Whether the first and third axes in the convention differ. 106 | 107 | Returns: 108 | Euler Angles in radians for each matrix in data as a tensor 109 | of shape (...). 
110 | """ 111 | 112 | i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] 113 | if horizontal: 114 | i2, i1 = i1, i2 115 | even = (axis + other_axis) in ["XY", "YZ", "ZX"] 116 | if horizontal == even: 117 | return torch.atan2(data[..., i1], data[..., i2]) 118 | if tait_bryan: 119 | return torch.atan2(-data[..., i2], data[..., i1]) 120 | return torch.atan2(data[..., i2], -data[..., i1]) 121 | 122 | 123 | def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor: 124 | """ 125 | Converts rotation matrices to 6D rotation representation by Zhou et al. [1] 126 | by dropping the last row. Note that 6D representation is not unique. 127 | Args: 128 | matrix: batch of rotation matrices of size (*, 3, 3) 129 | 130 | Returns: 131 | 6D rotation representation, of size (*, 6) 132 | 133 | [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. 134 | On the Continuity of Rotation Representations in Neural Networks. 135 | IEEE Conference on Computer Vision and Pattern Recognition, 2019. 136 | Retrieved from http://arxiv.org/abs/1812.07035 137 | """ 138 | batch_dim = matrix.size()[:-2] 139 | return matrix[..., :2, :].clone().reshape(batch_dim + (6,)) 140 | 141 | 142 | def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor: 143 | """ 144 | Convert rotations given as Euler angles in radians to rotation matrices. 145 | 146 | Args: 147 | euler_angles: Euler angles in radians as tensor of shape (..., 3). 148 | convention: Convention string of three uppercase letters from 149 | {"X", "Y", and "Z"}. 150 | 151 | Returns: 152 | Rotation matrices as tensor of shape (..., 3, 3). 153 | """ 154 | if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: 155 | raise ValueError("Invalid input euler angles.") 156 | if len(convention) != 3: 157 | raise ValueError("Convention must have 3 letters.") 158 | if convention[1] in (convention[0], convention[2]): 159 | raise ValueError(f"Invalid convention {convention}.") 160 | for letter in convention: 161 | if letter not in ("X", "Y", "Z"): 162 | raise ValueError(f"Invalid letter {letter} in convention string.") 163 | matrices = [ 164 | _axis_angle_rotation(c, e) 165 | for c, e in zip(convention, torch.unbind(euler_angles, -1)) 166 | ] 167 | # return functools.reduce(torch.matmul, matrices) 168 | return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2]) 169 | 170 | 171 | def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor: 172 | """ 173 | Return the rotation matrices for one of the rotations about an axis 174 | of which Euler angles describe, for each value of the angle given. 175 | 176 | Args: 177 | axis: Axis label "X" or "Y or "Z". 178 | angle: any shape tensor of Euler angles in radians 179 | 180 | Returns: 181 | Rotation matrices as tensor of shape (..., 3, 3). 182 | """ 183 | 184 | cos = torch.cos(angle) 185 | sin = torch.sin(angle) 186 | one = torch.ones_like(angle) 187 | zero = torch.zeros_like(angle) 188 | 189 | if axis == "X": 190 | R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) 191 | elif axis == "Y": 192 | R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) 193 | elif axis == "Z": 194 | R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) 195 | else: 196 | raise ValueError("letter must be either X, Y or Z.") 197 | 198 | return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) 199 | 200 | 201 | def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: 202 | """ 203 | Convert rotations given as rotation matrices to quaternions. 
204 | 205 | Args: 206 | matrix: Rotation matrices as tensor of shape (..., 3, 3). 207 | 208 | Returns: 209 | quaternions with real part first, as tensor of shape (..., 4). 210 | """ 211 | if matrix.size(-1) != 3 or matrix.size(-2) != 3: 212 | raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") 213 | 214 | batch_dim = matrix.shape[:-2] 215 | m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( 216 | matrix.reshape(batch_dim + (9,)), dim=-1 217 | ) 218 | 219 | q_abs = _sqrt_positive_part( 220 | torch.stack( 221 | [ 222 | 1.0 + m00 + m11 + m22, 223 | 1.0 + m00 - m11 - m22, 224 | 1.0 - m00 + m11 - m22, 225 | 1.0 - m00 - m11 + m22, 226 | ], 227 | dim=-1, 228 | ) 229 | ) 230 | 231 | # we produce the desired quaternion multiplied by each of r, i, j, k 232 | quat_by_rijk = torch.stack( 233 | [ 234 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 235 | # `int`. 236 | torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), 237 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 238 | # `int`. 239 | torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), 240 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 241 | # `int`. 242 | torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), 243 | # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and 244 | # `int`. 245 | torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), 246 | ], 247 | dim=-2, 248 | ) 249 | 250 | # We floor here at 0.1 but the exact level is not important; if q_abs is small, 251 | # the candidate won't be picked. 252 | flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) 253 | quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) 254 | 255 | # if not for numerical problems, quat_candidates[i] should be same (up to a sign), 256 | # forall i; we pick the best-conditioned one (with the largest denominator) 257 | 258 | return quat_candidates[ 259 | F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : 260 | ].reshape(batch_dim + (4,)) 261 | 262 | 263 | def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: 264 | """ 265 | Returns torch.sqrt(torch.max(0, x)) 266 | but with a zero subgradient where x is 0. 267 | """ 268 | ret = torch.zeros_like(x) 269 | positive_mask = x > 0 270 | ret[positive_mask] = torch.sqrt(x[positive_mask]) 271 | return ret -------------------------------------------------------------------------------- /scene/dataset_readers.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import os 13 | import sys 14 | from PIL import Image 15 | from typing import NamedTuple 16 | from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \ 17 | read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text 18 | from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal 19 | import numpy as np 20 | import json 21 | from pathlib import Path 22 | from plyfile import PlyData, PlyElement 23 | from utils.sh_utils import SH2RGB 24 | from scene.gaussian_model import BasicPointCloud 25 | 26 | class CameraInfo(NamedTuple): 27 | uid: int 28 | R: np.array 29 | T: np.array 30 | FovY: np.array 31 | FovX: np.array 32 | image: np.array 33 | image_path: str 34 | image_name: str 35 | width: int 36 | height: int 37 | 38 | class SceneInfo(NamedTuple): 39 | point_cloud: BasicPointCloud 40 | train_cameras: list 41 | test_cameras: list 42 | nerf_normalization: dict 43 | ply_path: str 44 | 45 | def getNerfppNorm(cam_info): 46 | def get_center_and_diag(cam_centers): 47 | cam_centers = np.hstack(cam_centers) 48 | avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) 49 | center = avg_cam_center 50 | dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) 51 | diagonal = np.max(dist) 52 | return center.flatten(), diagonal 53 | 54 | cam_centers = [] 55 | 56 | for cam in cam_info: 57 | W2C = getWorld2View2(cam.R, cam.T) 58 | C2W = np.linalg.inv(W2C) 59 | cam_centers.append(C2W[:3, 3:4]) 60 | 61 | center, diagonal = get_center_and_diag(cam_centers) 62 | radius = diagonal * 1.1 63 | 64 | translate = -center 65 | 66 | return {"translate": translate, "radius": radius} 67 | 68 | def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder): 69 | cam_infos = [] 70 | for idx, key in enumerate(cam_extrinsics): 71 | sys.stdout.write('\r') 72 | # the exact output you're looking for: 73 | sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics))) 74 | sys.stdout.flush() 75 | 76 | extr = cam_extrinsics[key] 77 | intr = cam_intrinsics[extr.camera_id] 78 | height = intr.height 79 | width = intr.width 80 | 81 | uid = intr.id 82 | R = np.transpose(qvec2rotmat(extr.qvec)) 83 | T = np.array(extr.tvec) 84 | 85 | if intr.model=="SIMPLE_PINHOLE": 86 | focal_length_x = intr.params[0] 87 | FovY = focal2fov(focal_length_x, height) 88 | FovX = focal2fov(focal_length_x, width) 89 | elif intr.model=="PINHOLE": 90 | focal_length_x = intr.params[0] 91 | focal_length_y = intr.params[1] 92 | FovY = focal2fov(focal_length_y, height) 93 | FovX = focal2fov(focal_length_x, width) 94 | else: 95 | assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!" 
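# focal2fov converts a pinhole focal length into a field of view,
# fov = 2 * atan(pixels / (2 * focal)), applied per image axis above.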
96 | 97 | image_path = os.path.join(images_folder, os.path.basename(extr.name)) 98 | image_name = os.path.basename(image_path).split(".")[0] 99 | image = Image.open(image_path) 100 | 101 | cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 102 | image_path=image_path, image_name=image_name, width=width, height=height) 103 | cam_infos.append(cam_info) 104 | sys.stdout.write('\n') 105 | return cam_infos 106 | 107 | def fetchPly(path): 108 | plydata = PlyData.read(path) 109 | vertices = plydata['vertex'] 110 | positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T 111 | colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 112 | normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T 113 | return BasicPointCloud(points=positions, colors=colors, normals=normals) 114 | 115 | def storePly(path, xyz, rgb): 116 | # Define the dtype for the structured array 117 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), 118 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), 119 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] 120 | 121 | normals = np.zeros_like(xyz) 122 | 123 | elements = np.empty(xyz.shape[0], dtype=dtype) 124 | attributes = np.concatenate((xyz, normals, rgb), axis=1) 125 | elements[:] = list(map(tuple, attributes)) 126 | 127 | # Create the PlyData object and write to file 128 | vertex_element = PlyElement.describe(elements, 'vertex') 129 | ply_data = PlyData([vertex_element]) 130 | ply_data.write(path) 131 | 132 | def readColmapSceneInfo(path, images, eval, llffhold=8): 133 | try: 134 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin") 135 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin") 136 | cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file) 137 | cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file) 138 | except: 139 | cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt") 140 | cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt") 141 | cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file) 142 | cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file) 143 | 144 | reading_dir = "images" if images == None else images 145 | cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir)) 146 | cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name) 147 | 148 | if eval: 149 | train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0] 150 | test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0] 151 | else: 152 | train_cam_infos = cam_infos 153 | test_cam_infos = [] 154 | 155 | nerf_normalization = getNerfppNorm(train_cam_infos) 156 | 157 | ply_path = os.path.join(path, "sparse/0/points3D.ply") 158 | bin_path = os.path.join(path, "sparse/0/points3D.bin") 159 | txt_path = os.path.join(path, "sparse/0/points3D.txt") 160 | if not os.path.exists(ply_path): 161 | print("Converting point3d.bin to .ply, will happen only the first time you open the scene.") 162 | try: 163 | xyz, rgb, _ = read_points3D_binary(bin_path) 164 | except: 165 | xyz, rgb, _ = read_points3D_text(txt_path) 166 | storePly(ply_path, xyz, rgb) 167 | try: 168 | pcd = fetchPly(ply_path) 169 | except: 170 | pcd = None 171 | 172 | scene_info = SceneInfo(point_cloud=pcd, 173 | train_cameras=train_cam_infos, 174 | test_cameras=test_cam_infos, 175 | nerf_normalization=nerf_normalization, 176 | 
ply_path=ply_path) 177 | return scene_info 178 | 179 | def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"): 180 | cam_infos = [] 181 | 182 | with open(os.path.join(path, transformsfile)) as json_file: 183 | contents = json.load(json_file) 184 | fovx = contents["camera_angle_x"] 185 | 186 | frames = contents["frames"] 187 | for idx, frame in enumerate(frames): 188 | cam_name = os.path.join(path, frame["file_path"] + extension) 189 | 190 | # NeRF 'transform_matrix' is a camera-to-world transform 191 | c2w = np.array(frame["transform_matrix"]) 192 | # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward) 193 | c2w[:3, 1:3] *= -1 194 | 195 | # get the world-to-camera transform and set R, T 196 | w2c = np.linalg.inv(c2w) 197 | R = np.transpose(w2c[:3,:3]) # R is stored transposed due to 'glm' in CUDA code 198 | T = w2c[:3, 3] 199 | 200 | image_path = os.path.join(path, cam_name) 201 | image_name = Path(cam_name).stem 202 | image = Image.open(image_path) 203 | 204 | im_data = np.array(image.convert("RGBA")) 205 | 206 | bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0]) 207 | 208 | norm_data = im_data / 255.0 209 | arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4]) 210 | image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB") 211 | 212 | fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1]) 213 | FovY = fovy 214 | FovX = fovx 215 | 216 | cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, 217 | image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1])) 218 | 219 | return cam_infos 220 | 221 | def readNerfSyntheticInfo(path, white_background, eval, extension=".png"): 222 | print("Reading Training Transforms") 223 | train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension) 224 | print("Reading Test Transforms") 225 | test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension) 226 | 227 | if not eval: 228 | train_cam_infos.extend(test_cam_infos) 229 | test_cam_infos = [] 230 | 231 | nerf_normalization = getNerfppNorm(train_cam_infos) 232 | 233 | ply_path = os.path.join(path, "points3d.ply") 234 | if not os.path.exists(ply_path): 235 | # Since this data set has no colmap data, we start with random points 236 | num_pts = 100_000 237 | print(f"Generating random point cloud ({num_pts})...") 238 | 239 | # We create random points inside the bounds of the synthetic Blender scenes 240 | xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 241 | shs = np.random.random((num_pts, 3)) / 255.0 242 | pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))) 243 | 244 | storePly(ply_path, xyz, SH2RGB(shs) * 255) 245 | try: 246 | pcd = fetchPly(ply_path) 247 | except: 248 | pcd = None 249 | 250 | scene_info = SceneInfo(point_cloud=pcd, 251 | train_cameras=train_cam_infos, 252 | test_cameras=test_cam_infos, 253 | nerf_normalization=nerf_normalization, 254 | ply_path=ply_path) 255 | return scene_info 256 | 257 | sceneLoadTypeCallbacks = { 258 | "Colmap": readColmapSceneInfo, 259 | "Blender" : readNerfSyntheticInfo 260 | } -------------------------------------------------------------------------------- /scene/colmap_loader.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All 
rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import numpy as np 13 | import collections 14 | import struct 15 | 16 | CameraModel = collections.namedtuple( 17 | "CameraModel", ["model_id", "model_name", "num_params"]) 18 | Camera = collections.namedtuple( 19 | "Camera", ["id", "model", "width", "height", "params"]) 20 | BaseImage = collections.namedtuple( 21 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 22 | Point3D = collections.namedtuple( 23 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 24 | CAMERA_MODELS = { 25 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 26 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 27 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 28 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 29 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 30 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 31 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 32 | CameraModel(model_id=7, model_name="FOV", num_params=5), 33 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 34 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 35 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 36 | } 37 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 38 | for camera_model in CAMERA_MODELS]) 39 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) 40 | for camera_model in CAMERA_MODELS]) 41 | 42 | 43 | def qvec2rotmat(qvec): 44 | return np.array([ 45 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 46 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 47 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 48 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 49 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 50 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 51 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 52 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 53 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 54 | 55 | def rotmat2qvec(R): 56 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 57 | K = np.array([ 58 | [Rxx - Ryy - Rzz, 0, 0, 0], 59 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 60 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 61 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 62 | eigvals, eigvecs = np.linalg.eigh(K) 63 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 64 | if qvec[0] < 0: 65 | qvec *= -1 66 | return qvec 67 | 68 | class Image(BaseImage): 69 | def qvec2rotmat(self): 70 | return qvec2rotmat(self.qvec) 71 | 72 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 73 | """Read and unpack the next bytes from a binary file. 74 | :param fid: 75 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 76 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 77 | :param endian_character: Any of {@, =, <, >, !} 78 | :return: Tuple of read and unpacked values. 
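Example: read_next_bytes(fid, 8, "Q") returns a 1-tuple containing one little-endian uint64, e.g. the leading element count of the COLMAP binary files parsed below.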
79 | """ 80 | data = fid.read(num_bytes) 81 | return struct.unpack(endian_character + format_char_sequence, data) 82 | 83 | def read_points3D_text(path): 84 | """ 85 | see: src/base/reconstruction.cc 86 | void Reconstruction::ReadPoints3DText(const std::string& path) 87 | void Reconstruction::WritePoints3DText(const std::string& path) 88 | """ 89 | xyzs = None 90 | rgbs = None 91 | errors = None 92 | num_points = 0 93 | with open(path, "r") as fid: 94 | while True: 95 | line = fid.readline() 96 | if not line: 97 | break 98 | line = line.strip() 99 | if len(line) > 0 and line[0] != "#": 100 | num_points += 1 101 | 102 | 103 | xyzs = np.empty((num_points, 3)) 104 | rgbs = np.empty((num_points, 3)) 105 | errors = np.empty((num_points, 1)) 106 | count = 0 107 | with open(path, "r") as fid: 108 | while True: 109 | line = fid.readline() 110 | if not line: 111 | break 112 | line = line.strip() 113 | if len(line) > 0 and line[0] != "#": 114 | elems = line.split() 115 | xyz = np.array(tuple(map(float, elems[1:4]))) 116 | rgb = np.array(tuple(map(int, elems[4:7]))) 117 | error = np.array(float(elems[7])) 118 | xyzs[count] = xyz 119 | rgbs[count] = rgb 120 | errors[count] = error 121 | count += 1 122 | 123 | return xyzs, rgbs, errors 124 | 125 | def read_points3D_binary(path_to_model_file): 126 | """ 127 | see: src/base/reconstruction.cc 128 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 129 | void Reconstruction::WritePoints3DBinary(const std::string& path) 130 | """ 131 | 132 | 133 | with open(path_to_model_file, "rb") as fid: 134 | num_points = read_next_bytes(fid, 8, "Q")[0] 135 | 136 | xyzs = np.empty((num_points, 3)) 137 | rgbs = np.empty((num_points, 3)) 138 | errors = np.empty((num_points, 1)) 139 | 140 | for p_id in range(num_points): 141 | binary_point_line_properties = read_next_bytes( 142 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 143 | xyz = np.array(binary_point_line_properties[1:4]) 144 | rgb = np.array(binary_point_line_properties[4:7]) 145 | error = np.array(binary_point_line_properties[7]) 146 | track_length = read_next_bytes( 147 | fid, num_bytes=8, format_char_sequence="Q")[0] 148 | track_elems = read_next_bytes( 149 | fid, num_bytes=8*track_length, 150 | format_char_sequence="ii"*track_length) 151 | xyzs[p_id] = xyz 152 | rgbs[p_id] = rgb 153 | errors[p_id] = error 154 | return xyzs, rgbs, errors 155 | 156 | def read_intrinsics_text(path): 157 | """ 158 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 159 | """ 160 | cameras = {} 161 | with open(path, "r") as fid: 162 | while True: 163 | line = fid.readline() 164 | if not line: 165 | break 166 | line = line.strip() 167 | if len(line) > 0 and line[0] != "#": 168 | elems = line.split() 169 | camera_id = int(elems[0]) 170 | model = elems[1] 171 | assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE" 172 | width = int(elems[2]) 173 | height = int(elems[3]) 174 | params = np.array(tuple(map(float, elems[4:]))) 175 | cameras[camera_id] = Camera(id=camera_id, model=model, 176 | width=width, height=height, 177 | params=params) 178 | return cameras 179 | 180 | def read_extrinsics_binary(path_to_model_file): 181 | """ 182 | see: src/base/reconstruction.cc 183 | void Reconstruction::ReadImagesBinary(const std::string& path) 184 | void Reconstruction::WriteImagesBinary(const std::string& path) 185 | """ 186 | images = {} 187 | with open(path_to_model_file, "rb") as fid: 188 | num_reg_images = read_next_bytes(fid, 8, 
"Q")[0] 189 | for _ in range(num_reg_images): 190 | binary_image_properties = read_next_bytes( 191 | fid, num_bytes=64, format_char_sequence="idddddddi") 192 | image_id = binary_image_properties[0] 193 | qvec = np.array(binary_image_properties[1:5]) 194 | tvec = np.array(binary_image_properties[5:8]) 195 | camera_id = binary_image_properties[8] 196 | image_name = "" 197 | current_char = read_next_bytes(fid, 1, "c")[0] 198 | while current_char != b"\x00": # look for the ASCII 0 entry 199 | image_name += current_char.decode("utf-8") 200 | current_char = read_next_bytes(fid, 1, "c")[0] 201 | num_points2D = read_next_bytes(fid, num_bytes=8, 202 | format_char_sequence="Q")[0] 203 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 204 | format_char_sequence="ddq"*num_points2D) 205 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 206 | tuple(map(float, x_y_id_s[1::3]))]) 207 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 208 | images[image_id] = Image( 209 | id=image_id, qvec=qvec, tvec=tvec, 210 | camera_id=camera_id, name=image_name, 211 | xys=xys, point3D_ids=point3D_ids) 212 | return images 213 | 214 | 215 | def read_intrinsics_binary(path_to_model_file): 216 | """ 217 | see: src/base/reconstruction.cc 218 | void Reconstruction::WriteCamerasBinary(const std::string& path) 219 | void Reconstruction::ReadCamerasBinary(const std::string& path) 220 | """ 221 | cameras = {} 222 | with open(path_to_model_file, "rb") as fid: 223 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 224 | for _ in range(num_cameras): 225 | camera_properties = read_next_bytes( 226 | fid, num_bytes=24, format_char_sequence="iiQQ") 227 | camera_id = camera_properties[0] 228 | model_id = camera_properties[1] 229 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 230 | width = camera_properties[2] 231 | height = camera_properties[3] 232 | num_params = CAMERA_MODEL_IDS[model_id].num_params 233 | params = read_next_bytes(fid, num_bytes=8*num_params, 234 | format_char_sequence="d"*num_params) 235 | cameras[camera_id] = Camera(id=camera_id, 236 | model=model_name, 237 | width=width, 238 | height=height, 239 | params=np.array(params)) 240 | assert len(cameras) == num_cameras 241 | return cameras 242 | 243 | 244 | def read_extrinsics_text(path): 245 | """ 246 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 247 | """ 248 | images = {} 249 | with open(path, "r") as fid: 250 | while True: 251 | line = fid.readline() 252 | if not line: 253 | break 254 | line = line.strip() 255 | if len(line) > 0 and line[0] != "#": 256 | elems = line.split() 257 | image_id = int(elems[0]) 258 | qvec = np.array(tuple(map(float, elems[1:5]))) 259 | tvec = np.array(tuple(map(float, elems[5:8]))) 260 | camera_id = int(elems[8]) 261 | image_name = elems[9] 262 | elems = fid.readline().split() 263 | xys = np.column_stack([tuple(map(float, elems[0::3])), 264 | tuple(map(float, elems[1::3]))]) 265 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 266 | images[image_id] = Image( 267 | id=image_id, qvec=qvec, tvec=tvec, 268 | camera_id=camera_id, name=image_name, 269 | xys=xys, point3D_ids=point3D_ids) 270 | return images 271 | 272 | 273 | def read_colmap_bin_array(path): 274 | """ 275 | Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py 276 | 277 | :param path: path to the colmap binary file. 
278 | :return: nd array with the floating point values in the value 279 | """ 280 | with open(path, "rb") as fid: 281 | width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, 282 | usecols=(0, 1, 2), dtype=int) 283 | fid.seek(0) 284 | num_delimiter = 0 285 | byte = fid.read(1) 286 | while True: 287 | if byte == b"&": 288 | num_delimiter += 1 289 | if num_delimiter >= 3: 290 | break 291 | byte = fid.read(1) 292 | array = np.fromfile(fid, np.float32) 293 | array = array.reshape((width, height, channels), order="F") 294 | return np.transpose(array, (1, 0, 2)).squeeze() 295 | -------------------------------------------------------------------------------- /compression/compression_exp.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import torch 4 | from scene import Scene 5 | import os 6 | from tqdm import tqdm 7 | from gaussian_renderer import render 8 | from utils.general_utils import safe_state 9 | from argparse import ArgumentParser 10 | from arguments import get_hydra_training_args 11 | from gaussian_renderer import GaussianModel 12 | from utils.image_utils import psnr 13 | from utils.loss_utils import ssim 14 | from lpipsPyTorch import lpips 15 | 16 | import yaml 17 | from dataclasses import dataclass, asdict 18 | import pandas as pd 19 | 20 | from compression.jpeg_xl import JpegXlCodec 21 | from compression.npz import NpzCodec 22 | from compression.exr import EXRCodec 23 | from compression.png import PNGCodec 24 | 25 | codecs = { 26 | "jpeg-xl": JpegXlCodec, 27 | "npz": NpzCodec, 28 | "exr": EXRCodec, 29 | "png": PNGCodec, 30 | } 31 | 32 | pd.set_option('display.max_rows', 500) 33 | pd.set_option('display.max_columns', 500) 34 | pd.set_option('display.width', 1000) 35 | 36 | 37 | @dataclass 38 | class QuantEval: 39 | psnr: float 40 | ssim: float 41 | lpips: float 42 | 43 | @dataclass 44 | class Measurement: 45 | name: str 46 | path: str 47 | size_bytes: int 48 | quant_eval: QuantEval = None 49 | 50 | @property 51 | def human_readable_byte_size(self): 52 | if self.size_bytes == 0: 53 | return "0B" 54 | size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB") 55 | i = int(np.floor(np.log(self.size_bytes) / np.log(1000))) 56 | p = np.power(1000, i) 57 | s = round(self.size_bytes / p, 2) 58 | return f"{s}{size_name[i]}" 59 | 60 | def to_dict(self): 61 | d = asdict(self) 62 | d.pop('quant_eval') 63 | if self.quant_eval is not None: 64 | d.update(self.quant_eval.__dict__) 65 | d['size'] = self.human_readable_byte_size 66 | return d 67 | 68 | 69 | 70 | def log_transform(coords): 71 | positive = coords > 0 72 | negative = coords < 0 73 | zero = coords == 0 74 | 75 | transformed_coords = np.zeros_like(coords) 76 | transformed_coords[positive] = np.log1p(coords[positive]) 77 | transformed_coords[negative] = -np.log1p(-coords[negative]) 78 | # For zero, no change is needed as transformed_coords is already initialized to zeros 79 | 80 | return transformed_coords 81 | 82 | def inverse_log_transform(transformed_coords): 83 | positive = transformed_coords > 0 84 | negative = transformed_coords < 0 85 | zero = transformed_coords == 0 86 | 87 | original_coords = np.zeros_like(transformed_coords) 88 | original_coords[positive] = np.expm1(transformed_coords[positive]) 89 | original_coords[negative] = -np.expm1(-transformed_coords[negative]) 90 | # For zero, no change is needed as original_coords is already initialized to zeros 91 | 92 | return original_coords 93 | 94 | 95 | 96 | def 
get_attr_numpy(gaussians, attr_name): 97 | attr_tensor = gaussians.attr_as_grid_img(attr_name) 98 | attr_numpy = attr_tensor.detach().cpu().numpy() 99 | return attr_numpy 100 | 101 | 102 | def compress_attr(attr_config, gaussians, out_folder): 103 | attr_name = attr_config['name'] 104 | attr_method = attr_config['method'] 105 | attr_params = attr_config.get('params', {}) 106 | 107 | if not attr_params: 108 | attr_params = {} 109 | 110 | codec = codecs[attr_method]() 111 | attr_np = get_attr_numpy(gaussians, attr_name) 112 | 113 | file_name = f"{attr_name}.{codec.file_ending()}" 114 | out_file = os.path.join(out_folder, file_name) 115 | 116 | if attr_config.get('contract', False): 117 | # sc = SceneContraction() 118 | # TODO take the original cuda array 119 | # attr = torch.tensor(attr_np, device="cuda") 120 | # attr_contracted = sc(attr) 121 | # attr_np = attr_contracted.cpu().numpy() 122 | attr_np = log_transform(attr_np) 123 | 124 | if "quantize" in attr_config: 125 | quantization = attr_config["quantize"] 126 | min_val = attr_np.min() 127 | max_val = attr_np.max() 128 | val_range = max_val - min_val 129 | # no division by zero 130 | if val_range == 0: 131 | val_range = 1 132 | attr_np_norm = (attr_np - min_val) / (val_range) 133 | qpow = 2 ** quantization 134 | attr_np_quantized = np.round(attr_np_norm * qpow) / qpow 135 | attr_np = attr_np_quantized * (val_range) + min_val 136 | attr_np = attr_np.astype(np.float32) 137 | 138 | if attr_config.get('normalize', False): 139 | min_val, max_val = codec.encode_with_normalization(attr_np, attr_name, out_file, **attr_params) 140 | return file_name, min_val, max_val 141 | else: 142 | codec.encode(attr_np, out_file, **attr_params) 143 | return file_name, None, None 144 | 145 | 146 | def decompress_attr(gaussians, attr_config, compressed_file, min_val, max_val): 147 | attr_name = attr_config['name'] 148 | attr_method = attr_config['method'] 149 | 150 | codec = codecs[attr_method]() 151 | 152 | if attr_config.get('normalize', False): 153 | decompressed_attr = codec.decode_with_normalization(compressed_file, min_val, max_val) 154 | else: 155 | decompressed_attr = codec.decode(compressed_file) 156 | 157 | if attr_config.get('contract', False): 158 | decompressed_attr = inverse_log_transform(decompressed_attr) 159 | 160 | # TODO dtype? 161 | # TODO to device? 162 | # TODO add grad? 
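# Write the decoded 2D grid back into the model's per-Gaussian attribute
# tensor; this mirrors the attr_as_grid_img layout that get_attr_numpy
# used on the compression side.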
163 | gaussians.set_attr_from_grid_img(attr_name, decompressed_attr) 164 | 165 | 166 | def run_single_compression(gaussians, experiment_out_path, experiment_config): 167 | compressed_min_vals = {} 168 | compressed_max_vals = {} 169 | 170 | compressed_files = {} 171 | 172 | total_size_bytes = 0 173 | 174 | for attribute in experiment_config['attributes']: 175 | compressed_file, min_val, max_val = compress_attr(attribute, gaussians, experiment_out_path) 176 | attr_name = attribute['name'] 177 | compressed_files[attr_name] = compressed_file 178 | compressed_min_vals[attr_name] = min_val 179 | compressed_max_vals[attr_name] = max_val 180 | total_size_bytes += os.path.getsize(os.path.join(experiment_out_path, compressed_file)) 181 | 182 | compr_info = pd.DataFrame([compressed_min_vals, compressed_max_vals, compressed_files], index=["min", "max", "file"]).T 183 | compr_info.to_csv(os.path.join(experiment_out_path, "compression_info.csv")) 184 | 185 | experiment_config['max_sh_degree'] = gaussians.max_sh_degree 186 | experiment_config['active_sh_degree'] = gaussians.active_sh_degree 187 | experiment_config['disable_xyz_log_activation'] = gaussians.disable_xyz_log_activation 188 | with open(os.path.join(experiment_out_path, "compression_config.yml"), 'w') as stream: 189 | yaml.dump(experiment_config, stream) 190 | 191 | return total_size_bytes 192 | 193 | def run_compressions(gaussians, out_path, compr_exp_config): 194 | 195 | # TODO some code duplication with run_experiments / run_roundtrip 196 | 197 | results = {} 198 | 199 | for experiment in compr_exp_config['experiments']: 200 | 201 | experiment_name = experiment['name'] 202 | experiment_out_path = os.path.join(out_path, experiment_name) 203 | os.makedirs(experiment_out_path, exist_ok=True) 204 | 205 | size_bytes = run_single_compression(gaussians, experiment_out_path, experiment) 206 | results[f"size_bytes/cmpr_{experiment['name']}"] = size_bytes 207 | 208 | return results 209 | 210 | def run_single_decompression(compressed_dir): 211 | 212 | compr_info = pd.read_csv(os.path.join(compressed_dir, "compression_info.csv"), index_col=0) 213 | 214 | with open(os.path.join(compressed_dir, "compression_config.yml"), 'r') as stream: 215 | experiment_config = yaml.safe_load(stream) 216 | 217 | decompressed_gaussians = GaussianModel(experiment_config['max_sh_degree'], experiment_config['disable_xyz_log_activation']) 218 | decompressed_gaussians.active_sh_degree = experiment_config['active_sh_degree'] 219 | 220 | for attribute in experiment_config['attributes']: 221 | attr_name = attribute["name"] 222 | # compressed_bytes = compressed_attrs[attr_name] 223 | compressed_file = os.path.join(compressed_dir, compr_info.loc[attr_name, "file"]) 224 | 225 | decompress_attr(decompressed_gaussians, attribute, compressed_file, compr_info.loc[attr_name, "min"], compr_info.loc[attr_name, "max"]) 226 | 227 | return decompressed_gaussians 228 | 229 | def run_decompressions(compressions_dir): 230 | 231 | for compressed_dir in os.listdir(compressions_dir): 232 | compressed_dir_path = os.path.join(compressions_dir, compressed_dir) 233 | if not os.path.isdir(compressed_dir_path): 234 | continue 235 | yield os.path.basename(compressed_dir_path), run_single_decompression(compressed_dir_path) 236 | 237 | def run_roundtrip(gaussians, out_path, experiment_config): 238 | 239 | experiment_name = experiment_config['name'] 240 | experiment_out_path = os.path.join(out_path, experiment_name) 241 | os.makedirs(experiment_out_path, exist_ok=True) 242 | 243 | 
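# The codecs encode each attribute as a 2D grid image (see attr_as_grid_img
# above), so the number of Gaussians must fill a square grid; pruning to a
# square shape before encoding makes the attribute images well-formed.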
gaussians.prune_to_square_shape() 244 | 245 | total_size_bytes = run_single_compression(gaussians, experiment_out_path, experiment_config) 246 | 247 | decompressed_gaussians = run_single_decompression(experiment_out_path) 248 | 249 | return decompressed_gaussians, total_size_bytes, experiment_out_path 250 | 251 | 252 | 253 | 254 | 255 | 256 | def run_experiments(training_cfg, cmdline_iteration, compr_exp_config, disable_lpips=False): 257 | 258 | gaussians = GaussianModel(training_cfg.dataset.sh_degree, False) 259 | 260 | scene = Scene(training_cfg.dataset, gaussians, load_iteration=cmdline_iteration, shuffle=False) 261 | iteration = scene.loaded_iter 262 | 263 | gaussians._xyz = gaussians.inverse_xyz_activation(gaussians._xyz.detach()) 264 | 265 | print(f"Compressing {training_cfg.dataset.model_path} iteration {iteration}") 266 | out_path = os.path.join(training_cfg.dataset.model_path, "compression", f"iteration_{iteration}") 267 | os.makedirs(out_path, exist_ok=True) 268 | 269 | bg_color = [1,1,1] if training_cfg.dataset.white_background else [0, 0, 0] 270 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 271 | all_cameras = scene.getTestCameras() # + scene.getTrainCameras() 272 | 273 | def render_test_measure(gaussians_to_render): 274 | 275 | with torch.inference_mode(): 276 | psnrs = [] 277 | ssims = [] 278 | lpipss = [] 279 | 280 | for idx, view in enumerate(all_cameras): 281 | rendering = render(view, gaussians_to_render, training_cfg.pipeline, background)["render"] 282 | gt = view.original_image[0:3, :, :] 283 | psnrs.append(psnr(rendering, gt).cpu().numpy()) 284 | ssims.append(ssim(rendering, gt).cpu().numpy()) 285 | if disable_lpips: 286 | lpipss.append(np.nan) 287 | else: 288 | lpipss.append(lpips(rendering, gt, net_type='vgg').cpu().numpy()) 289 | 290 | return QuantEval(psnr=np.mean(psnrs), ssim=np.mean(ssims), lpips=np.mean(lpipss)) 291 | 292 | exp_results = [] 293 | 294 | original_eval = render_test_measure(gaussians) 295 | exp_results.append(Measurement(name="PLY", path=scene.loaded_gaussian_ply, size_bytes=os.path.getsize(scene.loaded_gaussian_ply), quant_eval=original_eval)) 296 | 297 | for experiment in compr_exp_config['experiments']: 298 | gaussians_roundtrip, compressed_size_bytes, exp_out_path = run_roundtrip(gaussians, out_path, experiment) 299 | rendered_eval = render_test_measure(gaussians_roundtrip) 300 | meas = Measurement(name=experiment['name'], path=exp_out_path, size_bytes=compressed_size_bytes, quant_eval=rendered_eval) 301 | print(meas) 302 | exp_results.append(meas) 303 | 304 | exp_df = pd.DataFrame([m.to_dict() for m in exp_results]) 305 | 306 | sorted_columns_for_easy_comparison = ['name', 'size', 'psnr', 'ssim', 'lpips', 'path', 'size_bytes'] 307 | 308 | assert len(exp_df.columns) == len(sorted_columns_for_easy_comparison), "Hey, you added a column to the dataframe, please add it to the sorted_columns_for_easy_comparison list as well" 309 | 310 | exp_df = exp_df[sorted_columns_for_easy_comparison] 311 | exp_df.to_csv(os.path.join(out_path, "results.csv"), index=False) 312 | return exp_df 313 | 314 | 315 | 316 | 317 | def load_config(config_path: str): 318 | with open(config_path, 'r') as stream: 319 | config = yaml.safe_load(stream) 320 | return config 321 | 322 | 323 | def compression_exp(): 324 | # example args: --model_path=../models/truck --iteration 10000 --compression_config compression/configs/jpeg_xl.yml [--results_csv results.csv] [--disable_lpips] 325 | 326 | parser = ArgumentParser(description="Compression script 
parameters") 327 | parser.add_argument("--model_path", type=str) 328 | parser.add_argument("--source_path", type=str) 329 | parser.add_argument("--iteration", default=-1, type=int) 330 | parser.add_argument("--compression_config", type=str) 331 | parser.add_argument("--results_csv", type=str) 332 | parser.add_argument("--results_tex", type=str) 333 | parser.add_argument("--disable_lpips", action="store_true") 334 | 335 | cmdlne_string = sys.argv[1:] 336 | args_cmdline = parser.parse_args(cmdlne_string) 337 | 338 | iteration = args_cmdline.iteration 339 | model_path = args_cmdline.model_path 340 | 341 | compr_exp_config = load_config(args_cmdline.compression_config) 342 | 343 | training_cfg = get_hydra_training_args(model_path) 344 | 345 | training_cfg.dataset.model_path = model_path 346 | training_cfg.dataset.source_path = args_cmdline.source_path 347 | 348 | disable_lpips = args_cmdline.disable_lpips 349 | 350 | results_csv = args_cmdline.results_csv 351 | results_tex = args_cmdline.results_tex 352 | 353 | exp_df = run_experiments(training_cfg, iteration, compr_exp_config, disable_lpips=disable_lpips) 354 | print(exp_df) 355 | 356 | if results_csv: 357 | csv_dirname = os.path.dirname(results_csv) 358 | if csv_dirname: 359 | os.makedirs(csv_dirname, exist_ok=True) 360 | exp_df.to_csv(results_csv, index=False) 361 | 362 | if results_tex: 363 | tex_dirname = os.path.dirname(results_tex) 364 | if tex_dirname: 365 | os.makedirs(tex_dirname, exist_ok=True) 366 | exp_df.to_latex(results_tex, index=False, 367 | columns=["name", "psnr", "ssim", "lpips", "size_bytes"], 368 | header=["Name", "PSNR $\\uparrow$", "SSIM $\\uparrow$", "LPIPS $\\downarrow$", "Size (MB)"], 369 | formatters={"size_bytes": lambda x: f"{x / 1000 / 1000:.2f}", "psnr": lambda x: f"{x:.2f}", "ssim": lambda x: f"{x:.3f}", "lpips": lambda x: f"{x:.3f}"} 370 | ) 371 | 372 | 373 | 374 | 375 | if __name__ == "__main__": 376 | compression_exp() 377 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import os 14 | import hydra 15 | 16 | from omegaconf import DictConfig, OmegaConf 17 | from tqdm import tqdm 18 | from copy import deepcopy 19 | from lpipsPyTorch import lpips 20 | from random import randint 21 | 22 | from gaussian_renderer import render, network_gui 23 | from scene import Scene, GaussianModel 24 | 25 | from utils.general_utils import safe_state 26 | from utils.loss_utils import l1_loss, ssim 27 | from utils.image_utils import psnr 28 | from utils.wandb_utils import init_wandb, save_hist, wandb 29 | 30 | from training_viewer import TrainingViewer 31 | 32 | from compression.compression_exp import run_compressions, run_decompressions 33 | from compression.decompress import decompress_all_to_ply 34 | 35 | 36 | def training(cfg): 37 | 38 | first_iter = 0 39 | 40 | if not cfg.run.use_sh: 41 | print("use_sh not set, disabling sorting for spherical harmonics") 42 | cfg.sorting.weights.features_rest = 0 43 | cfg.neighbor_loss.weights.features_rest = 0 44 | for compression in cfg.compression["experiments"]: 45 | # rebuild the attribute list; deleting entries while iterating would skip elements 46 | compression["attributes"] = [att for att in compression["attributes"] 47 | if att["name"] != "_features_rest"] 48 | 49 | print(f"Starting training on dataset {cfg.dataset.source_path}") 50 | 51 | disable_xyz_log_activation = "disable_xyz_log_activation" in cfg.optimization and cfg.optimization.disable_xyz_log_activation 52 | print(f"{disable_xyz_log_activation=}") 53 | 54 | gaussians = GaussianModel(cfg.dataset.sh_degree, disable_xyz_log_activation=disable_xyz_log_activation) 55 | scene = Scene(cfg.dataset, gaussians) 56 | gaussians.training_setup(cfg.optimization) 57 | if cfg.run.start_checkpoint: 58 | (model_params, first_iter) = torch.load(cfg.run.start_checkpoint) 59 | gaussians.restore(model_params, cfg.optimization) 60 | 61 | # ---------------------- 62 | # SSGS 63 | if cfg.sorting.enabled: 64 | gaussians.prune_to_square_shape() 65 | gaussians.sort_into_grid(cfg.sorting, not cfg.run.no_progress_bar) 66 | 67 | debug_viewer = TrainingViewer(debug_view=cfg.local_window_debug_view.view_id) 68 | # ---------------------- 69 | 70 | bg_color = [1, 1, 1] if cfg.dataset.white_background else [0, 0, 0] 71 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 72 | 73 | iter_start = torch.cuda.Event(enable_timing=True) 74 | iter_end = torch.cuda.Event(enable_timing=True) 75 | 76 | viewpoint_stack = None 77 | ema_loss_for_log = 0.0 78 | if cfg.run.no_progress_bar: 79 | progress_bar = None 80 | else: 81 | progress_bar = tqdm(range(first_iter, cfg.optimization.iterations), desc="Training progress") 82 | first_iter += 1 83 | for iteration in range(first_iter, cfg.optimization.iterations + 1): 84 | if network_gui.conn is None: 85 | network_gui.try_connect() 86 | while network_gui.conn is not None: 87 | try: 88 | net_image_bytes = None 89 | custom_cam, do_training, cfg.pipeline.convert_SHs_python, cfg.pipeline.compute_cov3D_python, keep_alive, scaling_modifier = network_gui.receive() 90 | if custom_cam is not None: 91 | net_image = render(custom_cam, gaussians, cfg.pipeline, background, scaling_modifier)["render"] 92 | net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy()) 93 | network_gui.send(net_image_bytes, cfg.dataset.source_path) 94 | if do_training and ((iteration < int(cfg.optimization.iterations)) or not keep_alive): 95 | break 96 | except Exception as e: 97 | network_gui.conn = None 
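# iter_start / iter_end are CUDA events created above with enable_timing=True;
# the GPU time between their record() calls is read out further down via
# iter_start.elapsed_time(iter_end) and logged as "iter_time".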
98 | 99 | iter_start.record() 100 | 101 | gaussians.update_learning_rate(iteration) 102 | 103 | # Every 1000 iterations we increase the SH degree, up to the maximum 104 | if iteration % 1000 == 0 and cfg.run.use_sh: 105 | gaussians.oneupSHdegree() 106 | 107 | # Pick a random camera 108 | if not viewpoint_stack: 109 | viewpoint_stack = scene.getTrainCameras().copy() 110 | viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1)) 111 | 112 | # Render 113 | if (iteration - 1) == cfg.debug.debug_from: 114 | cfg.pipeline.debug = True 115 | 116 | bg = torch.rand((3), device="cuda") if cfg.optimization.random_background else background 117 | 118 | render_pkg = render(viewpoint_cam, gaussians, cfg.pipeline, bg) 119 | image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], \ 120 | render_pkg["visibility_filter"], render_pkg["radii"] 121 | 122 | # Loss 123 | gt_image = viewpoint_cam.original_image.cuda() 124 | Ll1 = l1_loss(image, gt_image) 125 | 126 | loss = (1.0 - cfg.optimization.lambda_dssim) * Ll1 + cfg.optimization.lambda_dssim * (1.0 - ssim(image, gt_image)) 127 | 128 | # ---------------------- 129 | # SSGS 130 | if cfg.neighbor_loss.lambda_neighbor > 0: 131 | 132 | nb_losses = [] 133 | wandb_log = {} 134 | 135 | attr_getter_fn = gaussians.get_activated_attr_flat if cfg.neighbor_loss.activated else gaussians.get_attr_flat 136 | 137 | weight_sum = sum(cfg.neighbor_loss.weights.values()) 138 | for attr_name, attr_weight in cfg.neighbor_loss.weights.items(): 139 | if attr_weight > 0: 140 | nb_losses.append(gaussians.neighborloss_2d(attr_getter_fn(attr_name), cfg.neighbor_loss) * attr_weight / weight_sum) 141 | wandb_log[f"neighbor_loss/{attr_name}"] = nb_losses[-1] 142 | 143 | nb_loss = cfg.neighbor_loss.lambda_neighbor * sum(nb_losses) 144 | 145 | if iteration % cfg.run.log_nb_loss_interval == 0: 146 | wandb.log(wandb_log, step=iteration) 147 | else: 148 | nb_loss = torch.tensor(0.0) 149 | # ---------------------- 150 | 151 | loss += nb_loss 152 | loss.backward() 153 | 154 | iter_end.record() 155 | 156 | with torch.no_grad(): 157 | # Progress bar 158 | ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log 159 | 160 | if progress_bar is not None: 161 | if iteration % 10 == 0: 162 | progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) 163 | progress_bar.update(10) 164 | if iteration == cfg.optimization.iterations: 165 | progress_bar.close() 166 | 167 | # Debug view 168 | if cfg.wandb_debug_view.interval != -1 and iteration % cfg.wandb_debug_view.interval == 0: 169 | if cfg.wandb_debug_view.view_enabled: 170 | debug_viewer.training_view_wandb(scene, gaussians, pipe=cfg.pipeline, step=iteration, background=background) 171 | if cfg.wandb_debug_view.save_hist: 172 | save_hist(gaussians, step=iteration) 173 | 174 | if cfg.local_window_debug_view.enabled and cfg.local_window_debug_view.interval != -1 and iteration % cfg.local_window_debug_view.interval == 0: 175 | debug_viewer.training_view(scene, gaussians, pipe=cfg.pipeline, background=background) 176 | 177 | 178 | # Log and save 179 | if iteration % cfg.run.log_training_report_interval == 0: 180 | wandb.log( 181 | { 182 | "loss/l1_loss": Ll1.item(), 183 | "loss/total_loss": loss.item(), 184 | "loss/nb_loss": nb_loss.item(), 185 | "iter_time": iter_start.elapsed_time(iter_end), 186 | "num gaussians": len(gaussians.get_xyz), 187 | }, 188 | step=iteration 189 | ) 190 | if iteration in cfg.run.test_iterations: 191 | training_report(cfg, iteration, scene, 
gaussians, (cfg.pipeline, background), log_name="uncompressed") 192 | 193 | if (iteration in cfg.run.save_iterations): 194 | print("\n[ITER {}] Saving Gaussians".format(iteration)) 195 | scene.save(iteration) 196 | 197 | # Compression 198 | if (iteration in cfg.run.compress_iterations): 199 | print("\n[ITER {}] Compressing Gaussians".format(iteration)) 200 | compr_path = os.path.join(cfg.dataset.model_path, "compression", f"iteration_{iteration}") 201 | 202 | # enable compression of non-sorted gaussians without affecting results 203 | gaussians_to_compress = deepcopy(gaussians) 204 | gaussians_to_compress.prune_to_square_shape() 205 | 206 | compr_results = run_compressions(gaussians_to_compress, compr_path, OmegaConf.to_container(cfg.compression)) 207 | wandb.log(compr_results, step=iteration) 208 | 209 | for compr_name, decompressed_gaussians in run_decompressions(compr_path): 210 | training_report(cfg, iteration, scene, decompressed_gaussians, (cfg.pipeline, background), log_name=f"cmpr_{compr_name}", log_GT=False) 211 | 212 | # decompress plys in last compression iteration 213 | if iteration == max(cfg.run.compress_iterations): 214 | decompress_all_to_ply(compr_path) 215 | 216 | # Densification 217 | if iteration < cfg.optimization.densify_until_iter: 218 | # Keep track of max radii in image-space for pruning 219 | gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], 220 | radii[visibility_filter]) 221 | gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) 222 | 223 | if iteration > cfg.optimization.densify_from_iter and iteration % cfg.optimization.densification_interval == 0: 224 | size_threshold = 20 if iteration > cfg.optimization.opacity_reset_interval else None 225 | gaussians.densify_and_prune(max_grad=cfg.optimization.densify_grad_threshold, min_opacity=cfg.optimization.densify_min_opacity, extent=scene.cameras_extent, max_screen_size=size_threshold) 226 | 227 | if iteration > cfg.optimization.densify_from_iter and iteration % cfg.optimization.densification_interval == 0: 228 | # ---------------------- 229 | # SSGS 230 | if cfg.sorting.enabled: 231 | gaussians.prune_to_square_shape() 232 | gaussians.sort_into_grid(cfg.sorting, not cfg.run.no_progress_bar) 233 | # ---------------------- 234 | 235 | if iteration < cfg.optimization.densify_until_iter: 236 | if iteration % cfg.optimization.opacity_reset_interval == 0 or ( 237 | cfg.dataset.white_background and iteration == cfg.optimization.densify_from_iter): 238 | gaussians.reset_opacity() 239 | 240 | # Optimizer step 241 | if iteration < cfg.optimization.iterations: 242 | gaussians.optimizer.step() 243 | gaussians.optimizer.zero_grad(set_to_none=True) 244 | 245 | if (iteration in cfg.run.checkpoint_iterations): 246 | print("\n[ITER {}] Saving Checkpoint".format(iteration)) 247 | torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") 248 | 249 | 250 | def training_report(cfg, iteration, scene, gaussians, renderArgs, log_name, log_GT=True): 251 | # Report test and samples of training set 252 | torch.cuda.empty_cache() 253 | validation_configs = ( 254 | { 255 | 'name': 'test', 256 | 'cameras': scene.getTestCameras() 257 | }, 258 | { 259 | 'name': 'train', 260 | 'cameras': [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)] 261 | } 262 | ) 263 | 264 | for config in validation_configs: 265 | if config['cameras'] and len(config['cameras']) > 0: 266 | l1_test = 0.0 267 | psnr_test = 0.0 268 
| ssim_test = 0.0 269 | lpipss_test = 0.0 270 | wandb_images = [] 271 | wandb_gt_images = [] 272 | for idx, viewpoint in enumerate(config['cameras']): 273 | image = torch.clamp(render(viewpoint, gaussians, *renderArgs)["render"], 0.0, 1.0) 274 | gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0) 275 | 276 | if cfg.wandb_debug_view.view_enabled and idx < 5: 277 | name = config['name'] + "_view_{}/render".format(viewpoint.image_name) 278 | wandb_img = wandb.Image(image[None], caption=name) 279 | wandb_images.append(wandb_img) 280 | name = config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name) 281 | wandb_img = wandb.Image(gt_image[None], caption=name) 282 | wandb_gt_images.append(wandb_img) 283 | l1_test += l1_loss(image, gt_image).mean().double() 284 | psnr_test += psnr(image, gt_image).mean().double() 285 | ssim_test += ssim(image, gt_image) 286 | if cfg.run.test_lpips: 287 | lpipss_test += lpips(image, gt_image, net_type='vgg').item() 288 | 289 | psnr_test /= len(config['cameras']) 290 | l1_test /= len(config['cameras']) 291 | ssim_test /= len(config['cameras']) 292 | lpipss_test /= len(config['cameras']) 293 | print(f"\n[ITER {iteration}] Evaluating {log_name} {config['name']}: L1: {l1_test:.4f} - PSNR: {psnr_test:.4f} - SSIM: {ssim_test:.4f} - LPIPS: {lpipss_test:.4f}") 294 | 295 | to_log = { 296 | f"eval_{config['name']}_PSNR/{log_name}": psnr_test, 297 | f"eval_{config['name']}_SSIM/{log_name}": ssim_test, 298 | f"eval_{config['name']}_LPIPS/{log_name}": lpipss_test, 299 | f"eval_{config['name']}_L1/{log_name}": l1_test, 300 | f"eval_{config['name']}_renders/{log_name}": wandb_images, 301 | } 302 | 303 | if log_GT: 304 | to_log[f"eval_{config['name']}_gt_img/{log_name}"] = wandb_gt_images 305 | 306 | wandb.log(to_log, step=iteration) 307 | torch.cuda.empty_cache() 308 | 309 | 310 | @hydra.main(version_base=None, config_path='config', config_name='training') 311 | def main(cfg: DictConfig): 312 | 313 | # Initialize system state (RNG) 314 | safe_state(cfg.run.quiet) 315 | cfg.run.wandb_url = init_wandb(cfg) 316 | 317 | output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir 318 | 319 | if not cfg.dataset.model_path: 320 | cfg.dataset.model_path = output_dir 321 | 322 | yaml_config = OmegaConf.to_yaml(cfg) 323 | 324 | training_config_yml_path = os.path.join(cfg.dataset.model_path, "training_config.yaml") 325 | with open(training_config_yml_path, 'w') as file: 326 | file.write("# Also available at .hydra/config.yaml\n") 327 | file.write(yaml_config) 328 | 329 | cfg_args_path = os.path.join(cfg.dataset.model_path, "cfg_args") 330 | with open(cfg_args_path, "w") as file: 331 | file.write(f"Namespace(model_path='{cfg.dataset.model_path}', source_path='{cfg.dataset.source_path}', images='{cfg.dataset.images}', resolution='{cfg.dataset.resolution}', sh_degree={cfg.dataset.sh_degree}, white_background={cfg.dataset.white_background}, eval={cfg.dataset.eval})") 332 | 333 | 334 | # Start GUI server, configure and run training 335 | network_gui.init(cfg.gui_server.ip, cfg.gui_server.port) 336 | torch.autograd.set_detect_anomaly(cfg.debug.detect_anomaly) 337 | training(cfg) 338 | 339 | # All done 340 | print("\nTraining complete.") 341 | 342 | if __name__ == "__main__": 343 | main() 344 | 345 | 346 | 347 | --------------------------------------------------------------------------------
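For reference, the shape of the experiment config consumed by run_compressions and run_single_compression in compression/compression_exp.py can be sketched as a plain Python dict. This is a hypothetical example: the experiment name, attribute names, codec choices, and quantization settings below are illustrative, not copied from the shipped config/compression/*.yaml files.

# Hypothetical sketch of a compression experiment config; only the keys
# ("experiments", "name", "attributes", "method", "contract", "quantize",
# "normalize") are taken from the code above, the values are made up.
example_compr_exp_config = {
    "experiments": [
        {
            "name": "jpeg_xl_q16",  # illustrative experiment name
            "attributes": [
                # xyz: log-contracted, snapped to a grid of 2**16 steps, then encoded
                {"name": "_xyz", "method": "jpeg-xl", "contract": True, "quantize": 16},
                # opacity: min/max-normalized by the codec before encoding
                {"name": "_opacity", "method": "png", "normalize": True},
            ],
        },
    ],
}

# run_compressions(gaussians, out_path, example_compr_exp_config) writes one
# subfolder per experiment containing the encoded attribute images plus
# compression_info.csv and compression_config.yml, and returns a dict of
# compressed sizes in bytes keyed by experiment name.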