├── .github └── workflows │ └── docker-build.yaml ├── .gitignore ├── Dockerfile ├── README.md ├── assets ├── lego.mp4 ├── overview.jpg └── teaser.jpg ├── config_files ├── blender │ └── TriMipRF.gin └── ms_blender │ └── TriMipRF.gin ├── dataset ├── __init__.py ├── parsers │ ├── __init__.py │ ├── nerf_synthetic.py │ └── nerf_synthetic_multiscale.py ├── ray_dataset.py └── utils │ ├── __init__.py │ ├── cameras.py │ ├── io.py │ └── utils.py ├── main.py ├── neural_field ├── __init__.py ├── encoding │ ├── __init__.py │ └── tri_mip.py ├── field │ ├── __init__.py │ └── trimipRF.py ├── model │ ├── RFModel.py │ ├── __init__.py │ └── trimipRF.py └── nn_utils │ ├── __init__.py │ └── activations.py ├── pyproject.toml ├── requirements.txt ├── scripts └── convert_blender_data.py ├── trainer ├── __init__.py └── trainer.py └── utils ├── __init__.py ├── colormaps.py ├── common.py ├── ray.py ├── render_buffer.py ├── tensor_dataclass.py └── writer.py /.github/workflows/docker-build.yaml: -------------------------------------------------------------------------------- 1 | name: Docker build & test 2 | 3 | on: 4 | push: 5 | branches: 6 | - develop 7 | - staging 8 | - prod 9 | pull_request: 10 | types: [ synchronize, opened, reopened, labeled ] 11 | 12 | jobs: 13 | build-docker-image: 14 | if: github.event_name == 'push' || (github.event_name == 'pull_request' && contains(github.event.*.labels.*.name, 'build docker')) 15 | 16 | runs-on: buildjet-4vcpu-ubuntu-2204 17 | 18 | permissions: 19 | contents: read 20 | packages: write 21 | 22 | steps: 23 | - name: Checkout 24 | uses: actions/checkout@v3 25 | 26 | - name: Get image tag 27 | shell: bash 28 | run: | 29 | echo "IMAGE_VERSION=pr-${{ github.event.pull_request.number }}-$(echo ${{github.sha}} | cut -c1-7)" >> $GITHUB_ENV 30 | - name: Echo image tag 31 | shell: bash 32 | run: | 33 | echo "using image tag: ${{ env.IMAGE_VERSION }}" 34 | 35 | - name: Set up Docker Buildx 36 | uses: docker/setup-buildx-action@v2 37 | 38 | - name: Login to Docker Hub 39 | uses: docker/login-action@v2 40 | with: 41 | username: ${{ secrets.DOCKERHUB_USERNAME }} 42 | password: ${{ secrets.DOCKERHUB_TOKEN }} 43 | 44 | - name: Build and push 45 | uses: docker/build-push-action@v4 46 | with: 47 | context: . 48 | file: ./Dockerfile 49 | push: true 50 | tags: joshpwrk/trimip:${{ env.IMAGE_VERSION }} 51 | cache-from: type=registry,ref=joshpwrk/trimip:${{ env.IMAGE_VERSION }} 52 | cache-to: type=inline -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .idea 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | !scripts/downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .envrc 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
163 | #.idea/ 164 | 165 | # Experiments and outputs 166 | runs/ 167 | outputs/ 168 | # tensorboard log files 169 | events.out.* 170 | 171 | # Data 172 | data 173 | !*/data 174 | 175 | # Misc 176 | old/ 177 | temp* 178 | .nfs* 179 | external/ 180 | __MACOSX/ 181 | outputs*/ 182 | node_modules/ 183 | bash/ 184 | cache/ 185 | package-lock.json 186 | camera_paths/ 187 | exp/ 188 | */._DS_Store 189 | .vscode/ 190 | launch_logs/ 191 | **/binary 192 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 2 | 3 | # Set environment variables to non-interactive (this prevents some prompts) 4 | ENV DEBIAN_FRONTEND=noninteractive 5 | 6 | # Update and install some essential packages 7 | RUN apt-get update && apt-get install -y \ 8 | software-properties-common \ 9 | build-essential \ 10 | curl \ 11 | git \ 12 | && rm -rf /var/lib/apt/lists/* 13 | 14 | # install requirements for trimip -r requirements.txt 15 | RUN apt-get update && apt-get install -y \ 16 | ffmpeg \ 17 | libavformat-dev \ 18 | libavcodec-dev \ 19 | libavdevice-dev \ 20 | libavutil-dev \ 21 | libavutil-dev \ 22 | libavfilter-dev \ 23 | libswscale-dev \ 24 | libswresample-dev 25 | 26 | # RUN apt-get install -y libavformat-dev 27 | # RUN apt-get install -y libavcodec-dev 28 | # RUN apt-get install -y libavdevice-dev 29 | # RUN apt-get install -y libavutil-dev 30 | # RUN apt-get install -y libavfilter-dev 31 | # RUN apt-get install -y libswscale-dev 32 | # RUN apt-get install -y libswresample-dev 33 | 34 | # Install Python 3 (Ubuntu 22.04 comes with Python 3.10) 35 | RUN apt-get update && apt-get install -y python3 python3-pip 36 | 37 | # Set the working directory inside the container 38 | WORKDIR /usr/src/app 39 | 40 | COPY requirements.txt ./ 41 | RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 42 | RUN TCNN_CUDA_ARCHITECTURES=89 pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch 43 | 44 | # Install nvdiffrast: https://nvlabs.github.io/nvdiffrast/#linux 45 | 46 | RUN pip3 install --no-cache-dir -r requirements.txt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tri-MipRF 2 | 3 | Official PyTorch implementation of the paper: 4 | 5 | > **Tri-MipRF: Tri-Mip Representation for Efficient Anti-Aliasing Neural Radiance Fields** 6 | > 7 | > ***ICCV 2023*** 8 | > 9 | > Wenbo Hu, Yuling Wang, Lin Ma, Bangbang Yang, Lin Gao, Xiao Liu, Yuewen Ma 10 | > 11 | > 12 | 13 | https://github.com/wbhu/Tri-MipRF/assets/11834645/6c50baf7-ac46-46fd-a36f-172f99ea9054 14 | 15 | 16 | > Instant-ngp (left) suffers from aliasing in distant or low-resolution views and blurriness in 17 | > close-up shots, while Tri-MipRF (right) renders both fine-grained details in close-ups 18 | > and high-fidelity zoomed-out images. 19 | 20 |

21 | 22 |

23 | 24 | > To render a pixel, we emit a cone from the camera’s projection center to the pixel on the 25 | > image plane, and then we cast a set of spheres inside the cone. Next, the spheres are 26 | > orthogonally projected 27 | > on the three planes and featurized by our Tri-Mip encoding. After that the feature 28 | > vector is fed into the tiny MLP to non-linearly map to 29 | > density and color. Finally, the density and 30 | > color of the spheres are integrated using volume rendering to produce final color for the pixel. 31 | 32 | 33 |
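For intuition, here is a minimal, self-contained sketch of how a sampled sphere is mapped to the mip level at which the three feature planes are queried. This is not the repository's exact code (`TriMipRFModel.compute_ball_radii` accounts for the full cone geometry), and the numbers below are assumed toy values for illustration only:

```python
import math

# Assumed toy values, for illustration only.
distance = 3.0            # distance from the camera to the sampled sphere
pixel_radius = 1.0 / 800  # cone radius at unit distance (roughly one pixel)
plane_size = 512          # resolution of each feature plane (mip level 0)
feature_vol_radius = 1.5  # half-extent of the scene AABB covered by the planes

# Spheres inside the cone grow linearly with distance from the camera.
sphere_radius = distance * pixel_radius
# Larger spheres are read from coarser, pre-filtered mipmap levels.
level = math.log2(sphere_radius / feature_vol_radius) + math.log2(plane_size)
print(f"query the three feature planes at mip level {level:.2f}")
# Fractional levels are blended between adjacent mipmaps by the texture lookup.
```

Distant or low-resolution pixels therefore read coarser, pre-filtered features, which suppresses aliasing, while close-up pixels read from the finest level.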

34 | 35 |

36 | 37 | > Our Tri-MipRF achieves state-of-the-art rendering quality while can be reconstructed efficiently, 38 | > compared with cutting-edge radiance fields methods, e.g., NeRF, MipNeRF, Plenoxels, 39 | > TensoRF, and Instant-ngp. Equipping Instant-ngp with super-sampling (named Instant-ngp↑5×) 40 | > improves the rendering quality to a certain extent but significantly slows down the reconstruction. 41 | 42 | ## **Installation** 43 | Please install the following dependencies first 44 | - [PyTorch (1.13.1 + cu11.6)](https://pytorch.org/get-started/locally/) 45 | - [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn) 46 | - [nvdiffrast](https://nvlabs.github.io/nvdiffrast/) 47 | 48 | And then install the following dependencies using *pip* 49 | ```shell 50 | pip3 install av==9.2.0 \ 51 | beautifulsoup4==4.11.1 \ 52 | entrypoints==0.4 \ 53 | gdown==4.5.1 \ 54 | gin-config==0.5.0 \ 55 | h5py==3.7.0 \ 56 | imageio==2.21.1 \ 57 | imageio-ffmpeg \ 58 | ipython==7.19.0 \ 59 | kornia==0.6.8 \ 60 | loguru==0.6.0 \ 61 | lpips==0.1.4 \ 62 | mediapy==1.1.0 \ 63 | mmcv==1.6.2 \ 64 | ninja==1.10.2.3 \ 65 | numpy==1.23.3 \ 66 | open3d==0.16.0 \ 67 | opencv-python==4.6.0.66 \ 68 | pandas==1.5.0 \ 69 | Pillow==9.2.0 \ 70 | plotly==5.7.0 \ 71 | pycurl==7.43.0.6 \ 72 | PyMCubes==0.1.2 \ 73 | pyransac3d==0.6.0 \ 74 | PyYAML==6.0 \ 75 | rich==12.6.0 \ 76 | scipy==1.9.2 \ 77 | tensorboard==2.9.0 \ 78 | torch-fidelity==0.3.0 \ 79 | torchmetrics==0.10.0 \ 80 | torchtyping==0.1.4 \ 81 | tqdm==4.64.1 \ 82 | tyro==0.3.25 \ 83 | appdirs \ 84 | nerfacc==0.3.5 \ 85 | plyfile \ 86 | scikit-image \ 87 | trimesh \ 88 | torch_efficient_distloss \ 89 | umsgpack \ 90 | pyngrok \ 91 | cryptography==39.0.2 \ 92 | omegaconf==2.2.3 \ 93 | segmentation-refinement \ 94 | xatlas \ 95 | protobuf==3.20.0 \ 96 | jinja2 \ 97 | click==8.1.7 \ 98 | tensorboardx \ 99 | termcolor 100 | ``` 101 | 102 | ## **Data** 103 | 104 | ### nerf_synthetic dataset 105 | Please download and unzip `nerf_synthetic.zip` from the [NeRF official Google Drive](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1). 106 | 107 | ### Generate multiscale dataset 108 | Please generate it by 109 | ```shell 110 | python scripts/convert_blender_data.py --blenderdir /path/to/nerf_synthetic --outdir /path/to/nerf_synthetic_multiscale 111 | ``` 112 | 113 | ## **Training and evaluation** 114 | ```shell 115 | python main.py --ginc config_files/ms_blender/TriMipRF.gin 116 | ``` 117 | 118 | 119 | ## **TODO** 120 | 121 | - [x] ~~Release source code~~. 
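As a note on the training and evaluation command above: `main.py` also accepts gin bindings via `--ginb`, so a finished experiment can be re-evaluated without retraining. The following is a sketch only; `<existing_run>` is a placeholder for a run directory that already contains a checkpoint (by default runs are named `<scene_type>/<scene>/<model_name>/<timestamp>` under `Trainer.base_exp_dir`), and checkpoint loading is handled by `Trainer.load_ckpt`:

```shell
python main.py --ginc config_files/ms_blender/TriMipRF.gin \
    --ginb "main.stages='eval'" \
    --ginb "Trainer.exp_name='nerf_synthetic_multiscale/lego/Tri-MipRF/<existing_run>'"
```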
122 | 123 | ## **Citation** 124 | 125 | If you find the code useful for your work, please star this repo and consider citing: 126 | 127 | ``` 128 | @inproceedings{hu2023Tri-MipRF, 129 | author = {Hu, Wenbo and Wang, Yuling and Ma, Lin and Yang, Bangbang and Gao, Lin and Liu, Xiao and Ma, Yuewen}, 130 | title = {Tri-MipRF: Tri-Mip Representation for Efficient Anti-Aliasing Neural Radiance Fields}, 131 | booktitle = {ICCV}, 132 | year = {2023} 133 | } 134 | ``` 135 | 136 | 137 | ## **Related Work** 138 | 139 | - [Mip-NeRF (ICCV 2021)](https://jonbarron.info/mipnerf/) 140 | - [Instant-ngp (SIGGRAPH 2022)](https://nvlabs.github.io/instant-ngp/) 141 | - [Zip-NeRF (ICCV 2023)](https://jonbarron.info/zipnerf/) 142 | -------------------------------------------------------------------------------- /assets/lego.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wbhu/Tri-MipRF/5b54a274d338ab64405d8123837e2d393c2262c2/assets/lego.mp4 -------------------------------------------------------------------------------- /assets/overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wbhu/Tri-MipRF/5b54a274d338ab64405d8123837e2d393c2262c2/assets/overview.jpg -------------------------------------------------------------------------------- /assets/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wbhu/Tri-MipRF/5b54a274d338ab64405d8123837e2d393c2262c2/assets/teaser.jpg -------------------------------------------------------------------------------- /config_files/blender/TriMipRF.gin: -------------------------------------------------------------------------------- 1 | main.train_split = 'trainval' 2 | main.num_workers = 16 3 | main.model_name = 'Tri-MipRF' 4 | main.batch_size = 28 # this is not the actual batch_size, but the prefetch size 5 | 6 | RayDataset.base_path = '/path/to/nerf_synthetic' 7 | RayDataset.scene = 'chair' 8 | RayDataset.scene_type = 'nerf_synthetic' 9 | 10 | 11 | Trainer.base_exp_dir = '/path/to/experiment/lod/dir' 12 | Trainer.exp_name = None 13 | Trainer.eval_step = 25000 14 | Trainer.log_step = 1000 15 | Trainer.max_steps = 25001 16 | Trainer.target_sample_batch_size = 262144 17 | 18 | 19 | -------------------------------------------------------------------------------- /config_files/ms_blender/TriMipRF.gin: -------------------------------------------------------------------------------- 1 | main.train_split = 'trainval' 2 | main.num_workers = 16 3 | main.model_name = 'Tri-MipRF' 4 | main.batch_size = 24 # this is not the actual batch_size, but the prefetch size 5 | 6 | RayDataset.base_path = '/path/to/nerf_synthetic_multiscale' 7 | RayDataset.scene = 'chair' 8 | RayDataset.scene_type = 'nerf_synthetic_multiscale' 9 | 10 | 11 | Trainer.base_exp_dir = '/path/to/experiment/lod/dir' 12 | Trainer.exp_name = None 13 | Trainer.eval_step = 25000 14 | Trainer.log_step = 1000 15 | Trainer.max_steps = 25001 16 | Trainer.target_sample_batch_size = 262144 17 | 18 | 19 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wbhu/Tri-MipRF/5b54a274d338ab64405d8123837e2d393c2262c2/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/parsers/__init__.py: 
-------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | from . import nerf_synthetic 4 | from . import nerf_synthetic_multiscale 5 | 6 | 7 | def get_parser(parser_name: str) -> Callable: 8 | if 'nerf_synthetic' == parser_name: 9 | return nerf_synthetic.load_data 10 | elif 'nerf_synthetic_multiscale' == parser_name: 11 | return nerf_synthetic_multiscale.load_data 12 | else: 13 | raise NotImplementedError 14 | -------------------------------------------------------------------------------- /dataset/parsers/nerf_synthetic.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import numpy as np 3 | 4 | import dataset.utils.io as data_io 5 | from dataset.utils.cameras import PinholeCamera 6 | 7 | 8 | def load_data(base_path: Path, scene: str, split: str): 9 | data_path = base_path / scene 10 | splits = ['train', 'val'] if split == "trainval" else [split] 11 | meta = None 12 | for s in splits: 13 | meta_path = data_path / "transforms_{}.json".format(s) 14 | m = data_io.load_from_json(meta_path) 15 | if meta is None: 16 | meta = m 17 | else: 18 | for k, v in meta.items(): 19 | if type(v) is list: 20 | v.extend(m[k]) 21 | else: 22 | assert v == m[k] 23 | 24 | image_height, image_width = 800, 800 25 | camera_angle_x = float(meta["camera_angle_x"]) 26 | focal_length = 0.5 * image_width / np.tan(0.5 * camera_angle_x) 27 | cx = image_width / 2.0 28 | cy = image_height / 2.0 29 | cameras = [ 30 | PinholeCamera( 31 | fx=focal_length, 32 | fy=focal_length, 33 | cx=cx, 34 | cy=cy, 35 | width=image_width, 36 | height=image_height, 37 | ) 38 | ] 39 | cam_num = len(cameras) 40 | 41 | frames, poses = {k: [] for k in range(len(cameras))}, { 42 | k: [] for k in range(len(cameras)) 43 | } 44 | index = 0 45 | for frame in meta["frames"]: 46 | fname = data_path / Path(frame["file_path"].replace("./", "") + ".png") 47 | frames[index % cam_num].append( 48 | { 49 | 'image_filename': fname, 50 | 'lossmult': 1.0, 51 | } 52 | ) 53 | poses[index % cam_num].append( 54 | np.array(frame["transform_matrix"]).astype(np.float32) 55 | ) 56 | index = index + 1 57 | 58 | aabb = np.array([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5]) 59 | 60 | outputs = { 61 | 'frames': frames, 62 | 'poses': poses, 63 | 'cameras': cameras, 64 | 'aabb': aabb, 65 | } 66 | return outputs 67 | -------------------------------------------------------------------------------- /dataset/parsers/nerf_synthetic_multiscale.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import numpy as np 3 | from tqdm import tqdm 4 | 5 | import dataset.utils.io as data_io 6 | from dataset.utils.cameras import PinholeCamera 7 | 8 | 9 | def load_data(base_path: Path, scene: str, split: str, cam_num: int = 4): 10 | # ipdb.set_trace() 11 | data_path = base_path / scene 12 | meta_path = data_path / 'metadata.json' 13 | 14 | splits = ['train', 'val'] if split == "trainval" else [split] 15 | meta = None 16 | for s in splits: 17 | m = data_io.load_from_json(meta_path)[s] 18 | if meta is None: 19 | meta = m 20 | else: 21 | for k, v in meta.items(): 22 | v.extend(m[k]) 23 | 24 | pix2cam = meta['pix2cam'] 25 | poses = meta['cam2world'] 26 | image_width = meta['width'] 27 | image_height = meta['height'] 28 | lossmult = meta['lossmult'] 29 | 30 | assert image_height[0] == image_height[cam_num] 31 | assert image_width[0] == image_width[cam_num] 32 | assert pix2cam[0] == pix2cam[cam_num] 33 | assert 
lossmult[0] == lossmult[cam_num] 34 | cameras = [] 35 | for i in range(cam_num): 36 | k = np.linalg.inv(pix2cam[i]) 37 | fx = k[0, 0] 38 | fy = -k[1, 1] 39 | cx = -k[0, 2] 40 | cy = -k[1, 2] 41 | cam = PinholeCamera( 42 | fx=fx, 43 | fy=fy, 44 | cx=cx, 45 | cy=cy, 46 | width=image_width[i], 47 | height=image_height[i], 48 | # loss_multi=lossmult[i], 49 | ) 50 | cameras.append(cam) 51 | 52 | frames = {k: [] for k in range(len(cameras))} 53 | index = 0 54 | for frame in tqdm(meta['file_path']): 55 | fname = data_path / frame 56 | frames[index % cam_num].append( 57 | { 58 | 'image_filename': fname, 59 | 'lossmult': lossmult[index], 60 | } 61 | ) 62 | index = index + 1 63 | poses = {k: poses[k :: len(cameras)] for k in range(len(cameras))} 64 | 65 | aabb = np.array([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5]) 66 | outputs = { 67 | 'frames': frames, 68 | 'poses': poses, 69 | 'cameras': cameras, 70 | 'aabb': aabb, 71 | } 72 | return outputs 73 | 74 | 75 | if __name__ == '__main__': 76 | data = load_data( 77 | Path('/mnt/bn/wbhu-nerf/Dataset/nerf_synthetic_multiscale'), 78 | 'lego', 79 | split='train', 80 | ) 81 | pass 82 | -------------------------------------------------------------------------------- /dataset/ray_dataset.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import ThreadPoolExecutor 2 | import multiprocessing 3 | 4 | import gin 5 | import numpy as np 6 | import torch 7 | from loguru import logger 8 | from torch.utils.data import Dataset, DataLoader 9 | from pathlib import Path 10 | 11 | from tqdm import tqdm 12 | 13 | from dataset.parsers import get_parser 14 | from dataset.utils import io as data_io 15 | from utils.ray import RayBundle 16 | from utils.render_buffer import RenderBuffer 17 | from utils.tensor_dataclass import TensorDataclass 18 | 19 | 20 | @gin.configurable() 21 | class RayDataset(Dataset): 22 | def __init__( 23 | self, 24 | base_path: str, 25 | scene: str = 'lego', 26 | scene_type: str = 'nerf_synthetic_multiscale', 27 | split: str = 'train', 28 | to_world: bool = True, 29 | num_rays: int = 8192, 30 | render_bkgd: str = 'white', 31 | **kwargs 32 | ): 33 | super().__init__() 34 | parser = get_parser(scene_type) 35 | data_source = parser( 36 | base_path=Path(base_path), scene=scene, split=split, **kwargs 37 | ) 38 | self.training = split.find('train') >= 0 39 | 40 | self.cameras = data_source['cameras'] 41 | self.ray_bundles = [c.build('cpu') for c in self.cameras] 42 | logger.info('==> Find {} cameras'.format(len(self.cameras))) 43 | self.poses = { 44 | k: torch.tensor(np.asarray(v)).float() # Nx4x4 45 | for k, v in data_source["poses"].items() 46 | } 47 | # parallel loading frames 48 | self.frames = {} 49 | for k, cam_frames in data_source['frames'].items(): 50 | with ThreadPoolExecutor( 51 | max_workers=min(multiprocessing.cpu_count(), 32) 52 | ) as executor: 53 | frames = list( 54 | tqdm( 55 | executor.map( 56 | lambda f: torch.tensor( 57 | data_io.imread(f['image_filename']) 58 | ), 59 | cam_frames, 60 | ), 61 | total=len(cam_frames), 62 | dynamic_ncols=True, 63 | ) 64 | ) 65 | self.frames[k] = torch.stack(frames, dim=0) 66 | self.frame_number = {k: x.shape[0] for k, x in self.frames.items()} 67 | self.aabb = torch.tensor(np.asarray(data_source['aabb'])).float() 68 | self.loss_multi = { 69 | k: torch.tensor([x['lossmult'] for x in v]) 70 | for k, v in data_source['frames'].items() 71 | } 72 | self.file_names = { 73 | k: [x['image_filename'].stem for x in v] 74 | for k, v in data_source['frames'].items() 75 | 
} 76 | self.to_world = to_world 77 | self.num_rays = num_rays 78 | self.render_bkgd = render_bkgd 79 | 80 | # try to read a data to initialize RenderBuffer subclass 81 | self[0] 82 | 83 | def __len__(self): 84 | if self.training: 85 | return 10**9 # hack of streaming dataset 86 | else: 87 | return sum([x.shape[0] for k, x in self.poses.items()]) 88 | 89 | def update_num_rays(self, num_rays): 90 | self.num_rays = num_rays 91 | 92 | @torch.no_grad() 93 | def __getitem__(self, index): 94 | if self.training: 95 | rgb, c2w, cam_rays, loss_multi = [], [], [], [] 96 | for cam_idx in range(len(self.cameras)): 97 | num_rays = int( 98 | self.num_rays 99 | * (1.0 / self.loss_multi[cam_idx][0]) 100 | / sum([1.0 / v[0] for _, v in self.loss_multi.items()]) 101 | ) 102 | idx = torch.randint( 103 | 0, 104 | self.frames[cam_idx].shape[0], 105 | size=(num_rays,), 106 | ) 107 | sample_x = torch.randint( 108 | 0, 109 | self.cameras[cam_idx].width, 110 | size=(num_rays,), 111 | ) # uniform sampling 112 | sample_y = torch.randint( 113 | 0, 114 | self.cameras[cam_idx].height, 115 | size=(num_rays,), 116 | ) # uniform sampling 117 | rgb.append(self.frames[cam_idx][idx, sample_y, sample_x]) 118 | c2w.append(self.poses[cam_idx][idx]) 119 | cam_rays.append(self.ray_bundles[cam_idx][sample_y, sample_x]) 120 | loss_multi.append(self.loss_multi[cam_idx][idx, None]) 121 | rgb = torch.cat(rgb, dim=0) 122 | c2w = torch.cat(c2w, dim=0) 123 | cam_rays = RayBundle.direct_cat(cam_rays, dim=0) 124 | loss_multi = torch.cat(loss_multi, dim=0) 125 | if 'white' == self.render_bkgd: 126 | render_bkgd = torch.ones_like(rgb[..., [-1]]) 127 | elif 'rand' == self.render_bkgd: 128 | render_bkgd = torch.rand_like(rgb[..., :3]) 129 | elif 'randn' == self.render_bkgd: 130 | render_bkgd = (torch.randn_like(rgb[..., :3]) + 0.5).clamp( 131 | 0.0, 1.0 132 | ) 133 | else: 134 | raise NotImplementedError 135 | 136 | else: 137 | for cam_idx, num in self.frame_number.items(): 138 | if index < num: 139 | break 140 | index = index - num 141 | num_rays = len(self.ray_bundles[cam_idx]) 142 | idx = torch.ones(size=(num_rays,), dtype=torch.int64) * index 143 | sample_x, sample_y = torch.meshgrid( 144 | torch.arange(self.cameras[cam_idx].width), 145 | torch.arange(self.cameras[cam_idx].height), 146 | indexing="xy", 147 | ) 148 | sample_x = sample_x.reshape(-1) 149 | sample_y = sample_y.reshape(-1) 150 | 151 | rgb = self.frames[cam_idx][idx, sample_y, sample_x] 152 | c2w = self.poses[cam_idx][idx] 153 | cam_rays = self.ray_bundles[cam_idx][sample_y, sample_x] 154 | loss_multi = self.loss_multi[cam_idx][idx, None] 155 | render_bkgd = torch.ones_like(rgb[..., [-1]]) 156 | 157 | if self.to_world: 158 | cam_rays.directions = ( 159 | c2w[:, :3, :3] @ cam_rays.directions[..., None] 160 | ).squeeze(-1) 161 | cam_rays.origins = c2w[:, :3, -1] 162 | target = RenderBuffer( 163 | rgb=rgb[..., :3] * rgb[..., [-1]] 164 | + (1.0 - rgb[..., [-1]]) * render_bkgd, 165 | render_bkgd=render_bkgd, 166 | # alpha=rgb[..., [-1]], 167 | loss_multi=loss_multi, 168 | ) 169 | if not self.training: 170 | cam_rays = cam_rays.reshape( 171 | (self.cameras[cam_idx].height, self.cameras[cam_idx].width) 172 | ) 173 | target = target.reshape( 174 | (self.cameras[cam_idx].height, self.cameras[cam_idx].width) 175 | ) 176 | outputs = { 177 | # 'c2w': c2w, 178 | 'cam_rays': cam_rays, 179 | 'target': target, 180 | # 'idx': idx, 181 | } 182 | if not self.training: 183 | outputs['name'] = self.file_names[cam_idx][index] 184 | return outputs 185 | 186 | 187 | def ray_collate(batch): 188 | res = 
{k: [] for k in batch[0].keys()} 189 | for data in batch: 190 | for k, v in data.items(): 191 | res[k].append(v) 192 | for k, v in res.items(): 193 | if isinstance(v[0], RenderBuffer) or isinstance(v[0], RayBundle): 194 | res[k] = TensorDataclass.direct_cat(v, dim=0) 195 | else: 196 | res[k] = torch.cat(v, dim=0) 197 | return res 198 | 199 | 200 | if __name__ == '__main__': 201 | training_dataset = RayDataset( 202 | # '/mnt/bn/wbhu-nerf/Dataset/nerf_synthetic', 203 | '/mnt/bn/wbhu-nerf/Dataset/nerf_synthetic_multiscale', 204 | 'lego', 205 | # 'nerf_synthetic', 206 | 'nerf_synthetic_multiscale', 207 | ) 208 | train_loader = iter( 209 | DataLoader( 210 | training_dataset, 211 | batch_size=8, 212 | shuffle=False, 213 | num_workers=0, 214 | collate_fn=ray_collate, 215 | pin_memory=True, 216 | worker_init_fn=None, 217 | pin_memory_device='cuda', 218 | ) 219 | ) 220 | test_dataset = RayDataset( 221 | # '/mnt/bn/wbhu-nerf/Dataset/nerf_synthetic', 222 | '/mnt/bn/wbhu-nerf/Dataset/nerf_synthetic_multiscale', 223 | 'lego', 224 | # 'nerf_synthetic', 225 | 'nerf_synthetic_multiscale', 226 | num_rays=81920, 227 | split='test', 228 | ) 229 | for i in tqdm(range(1000)): 230 | data = next(train_loader) 231 | pass 232 | for i in tqdm(range(len(test_dataset))): 233 | data = test_dataset[i] 234 | pass 235 | pass 236 | -------------------------------------------------------------------------------- /dataset/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wbhu/Tri-MipRF/5b54a274d338ab64405d8123837e2d393c2262c2/dataset/utils/__init__.py -------------------------------------------------------------------------------- /dataset/utils/cameras.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | from utils.ray import RayBundle 6 | 7 | 8 | class PinholeCamera: 9 | def __init__( 10 | self, 11 | fx: float, 12 | fy: float, 13 | cx: float, 14 | cy: float, 15 | width: int = None, 16 | height: int = None, 17 | coord_type: str = 'opengl', 18 | device: str = 'cuda:0', 19 | normalize_ray: bool = True, 20 | ): 21 | self.fx, self.fy, self.cx, self.cy = fx, fy, cx, cy 22 | self.width, self.height = width, height 23 | self.coord_type = coord_type 24 | self.K = torch.tensor( 25 | [ 26 | [self.fx, 0, self.cx], 27 | [0, self.fy, self.cy], 28 | [0, 0, 1], 29 | ], 30 | dtype=torch.float32, 31 | ) # (3, 3) 32 | self.device = device 33 | self.normalize_ray = normalize_ray 34 | self.near = 0.1 35 | self.far = 100 36 | if self.coord_type == 'opencv': 37 | self.sign_z = 1.0 38 | elif self.coord_type == 'opengl': 39 | self.sign_z = -1.0 40 | else: 41 | raise ValueError 42 | self.ray_bundle = None 43 | 44 | def build(self, device): 45 | x, y = torch.meshgrid( 46 | torch.arange(self.width, device=device), 47 | torch.arange(self.height, device=device), 48 | indexing="xy", 49 | ) 50 | directions = F.pad( 51 | torch.stack( 52 | [ 53 | (x - self.K[0, 2] + 0.5) / self.K[0, 0], 54 | (y - self.K[1, 2] + 0.5) / self.K[1, 1] * self.sign_z, 55 | ], 56 | dim=-1, 57 | ), 58 | (0, 1), 59 | value=self.sign_z, 60 | ) # [H,W,3] 61 | # Distance from each unit-norm direction vector to its x-axis neighbor 62 | dx = torch.linalg.norm( 63 | (directions[:, :-1, :] - directions[:, 1:, :]), 64 | dim=-1, 65 | keepdims=True, 66 | ) # [H,W-1,1] 67 | dx = torch.cat([dx, dx[:, -2:-1, :]], 1) # [H,W,1] 68 | dy = torch.linalg.norm( 69 | (directions[:-1, :, :] - 
directions[1:, :, :]), 70 | dim=-1, 71 | keepdims=True, 72 | ) # [H-1,W,1] 73 | dy = torch.cat([dy, dy[-2:-1, :, :]], 0) # [H,W,1] 74 | # Cut the distance in half, and then round it out so that it's 75 | # halfway between inscribed by / circumscribed about the pixel. 76 | area = dx * dy 77 | radii = torch.sqrt(area / torch.pi) 78 | if self.normalize_ray: 79 | directions = directions / torch.linalg.norm( 80 | directions, dim=-1, keepdims=True 81 | ) 82 | self.ray_bundle = RayBundle( 83 | origins=torch.zeros_like(directions), 84 | directions=directions, 85 | radiis=radii, 86 | ray_cos=torch.matmul( 87 | directions, 88 | torch.tensor([[0.0, 0.0, self.sign_z]], device=device).T, 89 | ), 90 | ) 91 | return self.ray_bundle 92 | 93 | @property 94 | def fov_y(self): 95 | return np.degrees(2 * np.arctan(self.cy / self.fy)) 96 | 97 | def get_proj(self): 98 | # projection 99 | proj = np.eye(4, dtype=np.float32) 100 | proj[0, 0] = 2 * self.fx / self.width 101 | proj[1, 1] = 2 * self.fy / self.height 102 | proj[0, 2] = 2 * self.cx / self.width - 1 103 | proj[1, 2] = 2 * self.cy / self.height - 1 104 | proj[2, 2] = -(self.far + self.near) / (self.far - self.near) 105 | proj[2, 3] = -2 * self.far * self.near / (self.far - self.near) 106 | proj[3, 2] = -1 107 | proj[3, 3] = 0 108 | return proj 109 | 110 | def get_PVM(self, c2w): 111 | c2w = c2w.copy() 112 | # to right up backward (opengl) 113 | c2w[:3, 1] *= -1 114 | c2w[:3, 2] *= -1 115 | w2c = np.linalg.inv(c2w) 116 | return np.matmul(self.get_proj(), w2c) 117 | -------------------------------------------------------------------------------- /dataset/utils/io.py: -------------------------------------------------------------------------------- 1 | import json 2 | import gzip 3 | import numpy as np 4 | import cv2 5 | from pathlib import Path 6 | from typing import Union, Any 7 | import open3d as o3d 8 | 9 | cv2.ocl.setUseOpenCL(False) 10 | cv2.setNumThreads(0) 11 | 12 | 13 | def load_from_json(file_path: Path): 14 | assert file_path.suffix == ".json" 15 | with open(file_path, encoding="UTF-8") as file: 16 | return json.load(file) 17 | 18 | 19 | def load_from_jgz(file_path: Path): 20 | assert file_path.suffix == ".jgz" 21 | with gzip.GzipFile(file_path, "rb") as file: 22 | return json.load(file) 23 | 24 | 25 | def write_to_json(file_path: Path, content: dict): 26 | assert file_path.suffix == ".json" 27 | with open(file_path, "w", encoding="UTF-8") as file: 28 | json.dump(content, file) 29 | 30 | 31 | def imread(file_path: Path, dtype: np.dtype = np.float32) -> np.ndarray: 32 | im = cv2.imread(str(file_path), flags=cv2.IMREAD_UNCHANGED) 33 | if 2 == len(im.shape): 34 | im = im[..., None] 35 | if 4 == im.shape[-1]: 36 | im = cv2.cvtColor(im, cv2.COLOR_BGRA2RGBA) 37 | elif 3 == im.shape[-1]: 38 | im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) 39 | elif 1 == im.shape[-1]: 40 | pass 41 | else: 42 | raise NotImplementedError 43 | if dtype != np.uint8: 44 | im = im / 255.0 45 | return im.astype(dtype) 46 | 47 | 48 | def imwrite(im: np.ndarray, file_path: Path) -> None: 49 | if not file_path.parent.exists(): 50 | file_path.parent.mkdir(parents=True, exist_ok=True) 51 | if len(im.shape) == 4: 52 | assert im.shape[0] == 1 53 | im = im[0] 54 | assert len(im.shape) == 3 55 | if im.dtype == np.float32: 56 | im = (im.clip(0.0, 1.0) * 255).astype(np.uint8) 57 | if 4 == im.shape[-1]: 58 | im = cv2.cvtColor(im, cv2.COLOR_RGBA2BGRA) 59 | elif 3 == im.shape[-1]: 60 | im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR) 61 | elif 1 == im.shape[-1]: 62 | im = im[..., 0] 63 | else: 64 | raise 
NotImplementedError 65 | cv2.imwrite(str(file_path), im) 66 | 67 | 68 | def write_rendering(data: Any, parrent_path: Path, name: str): 69 | if isinstance(data, np.ndarray): 70 | imwrite(data, parrent_path / (name + '.png')) 71 | elif isinstance(data, o3d.geometry.PointCloud): 72 | if not parrent_path.exists(): 73 | parrent_path.mkdir(exist_ok=True, parents=True) 74 | o3d.io.write_point_cloud(str(parrent_path / (name + '.ply')), data) 75 | -------------------------------------------------------------------------------- /dataset/utils/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | 5 | 6 | def split_training(num_images, train_split_percentage, split): 7 | # filter image_filenames and poses based on train/eval split percentage 8 | num_train_images = math.ceil(train_split_percentage * num_images) 9 | num_test_images = num_images - num_train_images 10 | i_all = np.arange(num_images) 11 | i_train = np.linspace( 12 | 0, num_images - 1, num_train_images, dtype=int 13 | ) # equally spaced training images starting and ending at 0 and num_images-1 14 | # eval images are the remaining images 15 | i_test = np.setdiff1d(i_all, i_train) 16 | assert len(i_test) == num_test_images 17 | if split == "train": 18 | indices = i_train 19 | elif split in ["val", "test"]: 20 | indices = i_test 21 | elif split == 'all' or split == 'rendering': 22 | indices = i_all 23 | else: 24 | raise ValueError(f"Unknown dataparser split {split}") 25 | return indices 26 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datetime import datetime 3 | import gin 4 | from loguru import logger 5 | from torch.utils.data import DataLoader 6 | 7 | from utils.common import set_random_seed 8 | from dataset.ray_dataset import RayDataset, ray_collate 9 | from neural_field.model import get_model 10 | from trainer import Trainer 11 | 12 | 13 | @gin.configurable() 14 | def main( 15 | seed: int = 42, 16 | num_workers: int = 0, 17 | train_split: str = "train", 18 | stages: str = "train_eval", 19 | batch_size: int = 16, 20 | model_name="Tri-MipRF", 21 | ): 22 | stages = list(stages.split("_")) 23 | set_random_seed(seed) 24 | 25 | logger.info("==> Init dataloader ...") 26 | train_dataset = RayDataset(split=train_split) 27 | train_loader = DataLoader( 28 | train_dataset, 29 | batch_size=batch_size, 30 | num_workers=num_workers, 31 | shuffle=False, 32 | collate_fn=ray_collate, 33 | pin_memory=True, 34 | worker_init_fn=None, 35 | pin_memory_device='cuda', 36 | prefetch_factor=2, 37 | ) 38 | test_dataset = RayDataset(split='test') 39 | test_loader = DataLoader( 40 | test_dataset, 41 | batch_size=None, 42 | num_workers=1, 43 | shuffle=False, 44 | pin_memory=True, 45 | worker_init_fn=None, 46 | pin_memory_device='cuda', 47 | ) 48 | 49 | logger.info("==> Init model ...") 50 | model = get_model(model_name=model_name)(aabb=train_dataset.aabb) 51 | logger.info(model) 52 | 53 | logger.info("==> Init trainer ...") 54 | trainer = Trainer(model, train_loader, eval_loader=test_loader) 55 | if "train" in stages: 56 | trainer.fit() 57 | if "eval" in stages: 58 | if "train" not in stages: 59 | trainer.load_ckpt() 60 | trainer.eval(save_results=True, rendering_channels=["rgb", "depth"]) 61 | 62 | 63 | if __name__ == "__main__": 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument( 66 | "--ginc", 67 | action="append", 
68 | help="gin config file", 69 | ) 70 | parser.add_argument( 71 | "--ginb", 72 | action="append", 73 | help="gin bindings", 74 | ) 75 | args = parser.parse_args() 76 | 77 | ginbs = [] 78 | if args.ginb: 79 | ginbs.extend(args.ginb) 80 | gin.parse_config_files_and_bindings(args.ginc, ginbs, finalize_config=False) 81 | 82 | exp_name = gin.query_parameter("Trainer.exp_name") 83 | exp_name = ( 84 | "%s/%s/%s/%s" 85 | % ( 86 | gin.query_parameter("RayDataset.scene_type"), 87 | gin.query_parameter("RayDataset.scene"), 88 | gin.query_parameter("main.model_name"), 89 | datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), 90 | ) 91 | if exp_name is None 92 | else exp_name 93 | ) 94 | gin.bind_parameter("Trainer.exp_name", exp_name) 95 | gin.finalize() 96 | main() 97 | -------------------------------------------------------------------------------- /neural_field/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /neural_field/encoding/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wbhu/Tri-MipRF/5b54a274d338ab64405d8123837e2d393c2262c2/neural_field/encoding/__init__.py -------------------------------------------------------------------------------- /neural_field/encoding/tri_mip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | import nvdiffrast.torch 5 | 6 | 7 | class TriMipEncoding(nn.Module): 8 | def __init__( 9 | self, 10 | n_levels: int, 11 | plane_size: int, 12 | feature_dim: int, 13 | include_xyz: bool = False, 14 | ): 15 | super(TriMipEncoding, self).__init__() 16 | self.n_levels = n_levels 17 | self.plane_size = plane_size 18 | self.feature_dim = feature_dim 19 | self.include_xyz = include_xyz 20 | 21 | self.register_parameter( 22 | "fm", 23 | nn.Parameter(torch.zeros(3, plane_size, plane_size, feature_dim)), 24 | ) 25 | self.init_parameters() 26 | self.dim_out = ( 27 | self.feature_dim * 3 + 3 if include_xyz else self.feature_dim * 3 28 | ) 29 | 30 | def init_parameters(self) -> None: 31 | # Important for performance 32 | nn.init.uniform_(self.fm, -1e-2, 1e-2) 33 | 34 | def forward(self, x, level): 35 | # x in [0,1], level in [0,max_level] 36 | # x is Nx3, level is Nx1 37 | if 0 == x.shape[0]: 38 | return torch.zeros([x.shape[0], self.feature_dim * 3]).to(x) 39 | decomposed_x = torch.stack( 40 | [ 41 | x[:, None, [1, 2]], 42 | x[:, None, [0, 2]], 43 | x[:, None, [0, 1]], 44 | ], 45 | dim=0, 46 | ) # 3xNx1x2 47 | if 0 == self.n_levels: 48 | level = None 49 | else: 50 | # assert level.shape[0] > 0, [level.shape, x.shape] 51 | torch.stack([level, level, level], dim=0) 52 | level = torch.broadcast_to( 53 | level, decomposed_x.shape[:3] 54 | ).contiguous() 55 | enc = nvdiffrast.torch.texture( 56 | self.fm, 57 | decomposed_x, 58 | mip_level_bias=level, 59 | boundary_mode="clamp", 60 | max_mip_level=self.n_levels - 1, 61 | ) # 3xNx1xC 62 | enc = ( 63 | enc.permute(1, 2, 0, 3) 64 | .contiguous() 65 | .view( 66 | x.shape[0], 67 | self.feature_dim * 3, 68 | ) 69 | ) # Nx(3C) 70 | if self.include_xyz: 71 | enc = torch.cat([x, enc], dim=-1) 72 | return enc 73 | -------------------------------------------------------------------------------- /neural_field/field/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wbhu/Tri-MipRF/5b54a274d338ab64405d8123837e2d393c2262c2/neural_field/field/__init__.py -------------------------------------------------------------------------------- /neural_field/field/trimipRF.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Callable 3 | 4 | import gin 5 | import torch 6 | from torch import Tensor, nn 7 | import tinycudann as tcnn 8 | 9 | from neural_field.encoding.tri_mip import TriMipEncoding 10 | from neural_field.nn_utils.activations import trunc_exp 11 | 12 | 13 | @gin.configurable() 14 | class TriMipRF(nn.Module): 15 | def __init__( 16 | self, 17 | n_levels: int = 8, 18 | plane_size: int = 512, 19 | feature_dim: int = 16, 20 | geo_feat_dim: int = 15, 21 | net_depth_base: int = 2, 22 | net_depth_color: int = 4, 23 | net_width: int = 128, 24 | density_activation: Callable = lambda x: trunc_exp(x - 1), 25 | ) -> None: 26 | super().__init__() 27 | self.plane_size = plane_size 28 | self.log2_plane_size = math.log2(plane_size) 29 | self.geo_feat_dim = geo_feat_dim 30 | self.density_activation = density_activation 31 | 32 | self.encoding = TriMipEncoding(n_levels, plane_size, feature_dim) 33 | self.direction_encoding = tcnn.Encoding( 34 | n_input_dims=3, 35 | encoding_config={ 36 | "otype": "SphericalHarmonics", 37 | "degree": 4, 38 | }, 39 | ) 40 | self.mlp_base = tcnn.Network( 41 | n_input_dims=self.encoding.dim_out, 42 | n_output_dims=geo_feat_dim + 1, 43 | network_config={ 44 | "otype": "FullyFusedMLP", 45 | "activation": "ReLU", 46 | "output_activation": "None", 47 | "n_neurons": net_width, 48 | "n_hidden_layers": net_depth_base, 49 | }, 50 | ) 51 | self.mlp_head = tcnn.Network( 52 | n_input_dims=self.direction_encoding.n_output_dims + geo_feat_dim, 53 | n_output_dims=3, 54 | network_config={ 55 | "otype": "FullyFusedMLP", 56 | "activation": "ReLU", 57 | "output_activation": "Sigmoid", 58 | "n_neurons": net_width, 59 | "n_hidden_layers": net_depth_color, 60 | }, 61 | ) 62 | 63 | def query_density( 64 | self, x: Tensor, level_vol: Tensor, return_feat: bool = False 65 | ): 66 | level = ( 67 | level_vol if level_vol is None else level_vol + self.log2_plane_size 68 | ) 69 | selector = ((x > 0.0) & (x < 1.0)).all(dim=-1) 70 | enc = self.encoding( 71 | x.view(-1, 3), 72 | level=level.view(-1, 1), 73 | ) 74 | x = ( 75 | self.mlp_base(enc) 76 | .view(list(x.shape[:-1]) + [1 + self.geo_feat_dim]) 77 | .to(x) 78 | ) 79 | density_before_activation, base_mlp_out = torch.split( 80 | x, [1, self.geo_feat_dim], dim=-1 81 | ) 82 | density = ( 83 | self.density_activation(density_before_activation) 84 | * selector[..., None] 85 | ) 86 | return { 87 | "density": density, 88 | "feature": base_mlp_out if return_feat else None, 89 | } 90 | 91 | def query_rgb(self, dir, embedding): 92 | # dir in [-1,1] 93 | dir = (dir + 1.0) / 2.0 # SH encoding must be in the range [0, 1] 94 | d = self.direction_encoding(dir.view(-1, dir.shape[-1])) 95 | h = torch.cat([d, embedding.view(-1, self.geo_feat_dim)], dim=-1) 96 | rgb = ( 97 | self.mlp_head(h) 98 | .view(list(embedding.shape[:-1]) + [3]) 99 | .to(embedding) 100 | ) 101 | return {"rgb": rgb} 102 | -------------------------------------------------------------------------------- /neural_field/model/RFModel.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Union, List, Dict 3 | 4 | import gin 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as 
F 8 | from torchmetrics.functional import peak_signal_noise_ratio 9 | 10 | from utils.ray import RayBundle 11 | from utils.render_buffer import RenderBuffer 12 | 13 | 14 | # @gin.configurable() 15 | class RFModel(nn.Module): 16 | def __init__( 17 | self, 18 | aabb: Union[torch.Tensor, List[float]], 19 | samples_per_ray: int = 1024, 20 | ) -> None: 21 | super().__init__() 22 | if not isinstance(aabb, torch.Tensor): 23 | aabb = torch.tensor(aabb, dtype=torch.float32) 24 | self.register_buffer("aabb", aabb) 25 | self.samples_per_ray = samples_per_ray 26 | self.render_step_size = ( 27 | (self.aabb[3:] - self.aabb[:3]).max() 28 | * math.sqrt(3) 29 | / samples_per_ray 30 | ).item() 31 | aabb_min, aabb_max = torch.split(self.aabb, 3, dim=-1) 32 | self.aabb_size = aabb_max - aabb_min 33 | assert ( 34 | self.aabb_size[0] == self.aabb_size[1] == self.aabb_size[2] 35 | ), "Current implementation only supports cube aabb" 36 | self.field = None 37 | self.ray_sampler = None 38 | 39 | def contraction(self, x): 40 | aabb_min, aabb_max = self.aabb[:3].unsqueeze(0), self.aabb[ 41 | 3: 42 | ].unsqueeze(0) 43 | x = (x - aabb_min) / (aabb_max - aabb_min) 44 | return x 45 | 46 | def before_iter(self, step): 47 | pass 48 | 49 | def after_iter(self, step): 50 | pass 51 | 52 | def forward( 53 | self, 54 | rays: RayBundle, 55 | background_color=None, 56 | ): 57 | raise NotImplementedError 58 | 59 | @gin.configurable() 60 | def get_optimizer( 61 | self, lr=1e-3, weight_decay=1e-5, feature_lr_scale=10.0, **kwargs 62 | ): 63 | raise NotImplementedError 64 | 65 | @gin.configurable() 66 | def compute_loss( 67 | self, 68 | rays: RayBundle, 69 | rb: RenderBuffer, 70 | target: RenderBuffer, 71 | # Configurable 72 | metric='smooth_l1', 73 | **kwargs 74 | ) -> Dict: 75 | if 'smooth_l1' == metric: 76 | loss_fn = F.smooth_l1_loss 77 | elif 'mse' == metric: 78 | loss_fn = F.mse_loss 79 | elif 'mae' == metric: 80 | loss_fn = F.l1_loss 81 | else: 82 | raise NotImplementedError 83 | 84 | alive_ray_mask = (rb.alpha.squeeze(-1) > 0).detach() 85 | loss = loss_fn( 86 | rb.rgb[alive_ray_mask], target.rgb[alive_ray_mask], reduction='none' 87 | ) 88 | loss = ( 89 | loss * target.loss_multi[alive_ray_mask] 90 | ).sum() / target.loss_multi[alive_ray_mask].sum() 91 | return {'total_loss': loss} 92 | 93 | @gin.configurable() 94 | def compute_metrics( 95 | self, 96 | rays: RayBundle, 97 | rb: RenderBuffer, 98 | target: RenderBuffer, 99 | # Configurable 100 | **kwargs 101 | ) -> Dict: 102 | # ray info 103 | alive_ray_mask = (rb.alpha.squeeze(-1) > 0).detach() 104 | rendering_samples_actual = rb.num_samples[0].item() 105 | ray_info = { 106 | 'num_alive_ray': alive_ray_mask.long().sum().item(), 107 | 'rendering_samples_actual': rendering_samples_actual, 108 | 'num_rays': len(target), 109 | } 110 | # quality 111 | quality = {'PSNR': peak_signal_noise_ratio(rb.rgb, target.rgb).item()} 112 | return {**ray_info, **quality} 113 | -------------------------------------------------------------------------------- /neural_field/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .RFModel import RFModel 2 | from .trimipRF import TriMipRFModel 3 | 4 | 5 | def get_model(model_name: str = 'Tri-MipRF') -> RFModel: 6 | if 'Tri-MipRF' == model_name: 7 | return TriMipRFModel 8 | else: 9 | raise NotImplementedError 10 | -------------------------------------------------------------------------------- /neural_field/model/trimipRF.py: 
-------------------------------------------------------------------------------- 1 | from typing import Union, List, Optional, Callable 2 | 3 | import gin 4 | import torch 5 | import nerfacc 6 | from nerfacc import render_weight_from_density, accumulate_along_rays 7 | 8 | from neural_field.model.RFModel import RFModel 9 | from utils.ray import RayBundle 10 | from utils.render_buffer import RenderBuffer 11 | from neural_field.field.trimipRF import TriMipRF 12 | 13 | 14 | @gin.configurable() 15 | class TriMipRFModel(RFModel): 16 | def __init__( 17 | self, 18 | aabb: Union[torch.Tensor, List[float]], 19 | samples_per_ray: int = 1024, 20 | occ_grid_resolution: int = 128, 21 | ) -> None: 22 | super().__init__(aabb=aabb, samples_per_ray=samples_per_ray) 23 | self.field = TriMipRF() 24 | self.ray_sampler = nerfacc.OccupancyGrid( 25 | roi_aabb=self.aabb, resolution=occ_grid_resolution 26 | ) 27 | 28 | self.feature_vol_radii = self.aabb_size[0] / 2.0 29 | self.register_buffer( 30 | "occ_level_vol", 31 | torch.log2( 32 | self.aabb_size[0] 33 | / occ_grid_resolution 34 | / 2.0 35 | / self.feature_vol_radii 36 | ), 37 | ) 38 | 39 | def before_iter(self, step): 40 | # update_ray_sampler 41 | self.ray_sampler.every_n_step( 42 | step=step, 43 | occ_eval_fn=lambda x: self.field.query_density( 44 | x=self.contraction(x), 45 | level_vol=torch.empty_like(x[..., 0]).fill_(self.occ_level_vol), 46 | )['density'] 47 | * self.render_step_size, 48 | occ_thre=5e-3, 49 | ) 50 | 51 | @staticmethod 52 | def compute_ball_radii(distance, radiis, cos): 53 | inverse_cos = 1.0 / cos 54 | tmp = (inverse_cos * inverse_cos - 1).sqrt() - radiis 55 | sample_ball_radii = distance * radiis * cos / (tmp * tmp + 1.0).sqrt() 56 | return sample_ball_radii 57 | 58 | def forward( 59 | self, 60 | rays: RayBundle, 61 | background_color=None, 62 | alpha_thre=0.0, 63 | ray_marching_aabb=None, 64 | ): 65 | # Ray sampling with occupancy grid 66 | with torch.no_grad(): 67 | 68 | def sigma_fn(t_starts, t_ends, ray_indices): 69 | ray_indices = ray_indices.long() 70 | t_origins = rays.origins[ray_indices] 71 | t_dirs = rays.directions[ray_indices] 72 | radiis = rays.radiis[ray_indices] 73 | cos = rays.ray_cos[ray_indices] 74 | distance = (t_starts + t_ends) / 2.0 75 | positions = t_origins + t_dirs * distance 76 | positions = self.contraction(positions) 77 | sample_ball_radii = self.compute_ball_radii( 78 | distance, radiis, cos 79 | ) 80 | level_vol = torch.log2( 81 | sample_ball_radii / self.feature_vol_radii 82 | ) # real level should + log2(feature_resolution) 83 | return self.field.query_density(positions, level_vol)['density'] 84 | 85 | ray_indices, t_starts, t_ends = nerfacc.ray_marching( 86 | rays.origins, 87 | rays.directions, 88 | scene_aabb=self.aabb, 89 | grid=self.ray_sampler, 90 | sigma_fn=sigma_fn, 91 | render_step_size=self.render_step_size, 92 | stratified=self.training, 93 | early_stop_eps=1e-4, 94 | ) 95 | 96 | # Ray rendering 97 | def rgb_sigma_fn(t_starts, t_ends, ray_indices): 98 | t_origins = rays.origins[ray_indices] 99 | t_dirs = rays.directions[ray_indices] 100 | radiis = rays.radiis[ray_indices] 101 | cos = rays.ray_cos[ray_indices] 102 | distance = (t_starts + t_ends) / 2.0 103 | positions = t_origins + t_dirs * distance 104 | positions = self.contraction(positions) 105 | sample_ball_radii = self.compute_ball_radii(distance, radiis, cos) 106 | level_vol = torch.log2( 107 | sample_ball_radii / self.feature_vol_radii 108 | ) # real level should + log2(feature_resolution) 109 | res = self.field.query_density( 110 | 
x=positions, 111 | level_vol=level_vol, 112 | return_feat=True, 113 | ) 114 | density, feature = res['density'], res['feature'] 115 | rgb = self.field.query_rgb(dir=t_dirs, embedding=feature)['rgb'] 116 | return rgb, density 117 | 118 | return self.rendering( 119 | t_starts, 120 | t_ends, 121 | ray_indices, 122 | rays, 123 | rgb_sigma_fn=rgb_sigma_fn, 124 | render_bkgd=background_color, 125 | ) 126 | 127 | def rendering( 128 | self, 129 | # ray marching results 130 | t_starts: torch.Tensor, 131 | t_ends: torch.Tensor, 132 | ray_indices: torch.Tensor, 133 | rays: RayBundle, 134 | # radiance field 135 | rgb_sigma_fn: Callable = None, # rendering options 136 | render_bkgd: Optional[torch.Tensor] = None, 137 | ) -> RenderBuffer: 138 | n_rays = rays.origins.shape[0] 139 | # Query sigma/alpha and color with gradients 140 | rgbs, sigmas = rgb_sigma_fn(t_starts, t_ends, ray_indices.long()) 141 | 142 | # Rendering 143 | weights = render_weight_from_density( 144 | t_starts, 145 | t_ends, 146 | sigmas, 147 | ray_indices=ray_indices, 148 | n_rays=n_rays, 149 | ) 150 | sample_buffer = { 151 | 'num_samples': torch.as_tensor( 152 | [len(t_starts)], dtype=torch.int32, device=rgbs.device 153 | ), 154 | } 155 | 156 | # Rendering: accumulate rgbs, opacities, and depths along the rays. 157 | colors = accumulate_along_rays( 158 | weights, ray_indices=ray_indices, values=rgbs, n_rays=n_rays 159 | ) 160 | opacities = accumulate_along_rays( 161 | weights, values=None, ray_indices=ray_indices, n_rays=n_rays 162 | ) 163 | opacities.clamp_( 164 | 0.0, 1.0 165 | ) # sometimes it may slightly bigger than 1.0, which will lead abnormal behaviours 166 | 167 | depths = accumulate_along_rays( 168 | weights, 169 | ray_indices=ray_indices, 170 | values=(t_starts + t_ends) / 2.0, 171 | n_rays=n_rays, 172 | ) 173 | depths = ( 174 | depths * rays.ray_cos 175 | ) # from distance to real depth (z value in camera space) 176 | 177 | # Background composition. 
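        # `opacities` holds the alpha accumulated along each ray, so the background
        # only fills in the remaining (1 - opacities) fraction of every pixel; when
        # render_bkgd is None the returned rgb stays premultiplied by alpha.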
178 | if render_bkgd is not None: 179 | colors = colors + render_bkgd * (1.0 - opacities) 180 | 181 | return RenderBuffer( 182 | rgb=colors, 183 | alpha=opacities, 184 | depth=depths, 185 | **sample_buffer, 186 | _static_field=set(sample_buffer), 187 | ) 188 | 189 | @gin.configurable() 190 | def get_optimizer( 191 | self, lr=2e-3, weight_decay=1e-5, feature_lr_scale=10.0, **kwargs 192 | ): 193 | params_list = [] 194 | params_list.append( 195 | dict( 196 | params=self.field.encoding.parameters(), 197 | lr=lr * feature_lr_scale, 198 | ) 199 | ) 200 | params_list.append( 201 | dict(params=self.field.direction_encoding.parameters(), lr=lr) 202 | ) 203 | params_list.append(dict(params=self.field.mlp_base.parameters(), lr=lr)) 204 | params_list.append(dict(params=self.field.mlp_head.parameters(), lr=lr)) 205 | 206 | optim = torch.optim.AdamW( 207 | params_list, 208 | weight_decay=weight_decay, 209 | **kwargs, 210 | eps=1e-15, 211 | ) 212 | return optim 213 | -------------------------------------------------------------------------------- /neural_field/nn_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wbhu/Tri-MipRF/5b54a274d338ab64405d8123837e2d393c2262c2/neural_field/nn_utils/__init__.py -------------------------------------------------------------------------------- /neural_field/nn_utils/activations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.cuda.amp import custom_bwd, custom_fwd 4 | 5 | 6 | class TruncExp(Function): # pylint: disable=abstract-method 7 | # Implementation from torch-ngp: 8 | # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py 9 | @staticmethod 10 | @custom_fwd(cast_inputs=torch.float32) 11 | def forward(ctx, x): # pylint: disable=arguments-differ 12 | ctx.save_for_backward(x) 13 | return torch.exp(x) 14 | 15 | @staticmethod 16 | @custom_bwd 17 | def backward(ctx, g): # pylint: disable=arguments-differ 18 | x = ctx.saved_tensors[0] 19 | return g * torch.exp(torch.clamp(x, min=-15, max=15)) 20 | 21 | 22 | trunc_exp = TruncExp.apply 23 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 80 3 | skip-string-normalization = true -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | av==9.2.0 2 | beautifulsoup4==4.11.1 3 | entrypoints==0.4 4 | gdown==4.5.1 5 | gin-config==0.5.0 6 | h5py==3.7.0 7 | imageio==2.21.1 8 | imageio-ffmpeg 9 | ipython==7.19.0 10 | kornia==0.6.8 11 | loguru==0.6.0 12 | lpips==0.1.4 13 | mediapy==1.1.0 14 | mmcv==1.6.2 15 | ninja==1.10.2.3 16 | numpy==1.23.3 17 | open3d==0.16.0 18 | opencv-python==4.6.0.66 19 | pandas==1.5.0 20 | Pillow==9.2.0 21 | plotly==5.7.0 22 | pycurl==7.43.0.6 23 | PyMCubes==0.1.2 24 | pyransac3d==0.6.0 25 | PyYAML==6.0 26 | rich==12.6.0 27 | scipy==1.9.2 28 | tensorboard==2.9.0 29 | torch-fidelity==0.3.0 30 | torchmetrics==0.10.0 31 | torchtyping==0.1.4 32 | tqdm==4.64.1 33 | tyro==0.3.25 34 | appdirs 35 | nerfacc==0.3.5 36 | plyfile 37 | scikit-image 38 | trimesh 39 | torch_efficient_distloss 40 | umsgpack 41 | pyngrok 42 | cryptography==39.0.2 43 | omegaconf==2.2.3 44 | segmentation-refinement 45 | 
xatlas 46 | protobuf==3.20.0 47 | jinja2 48 | click==8.1.7 49 | tensorboardx 50 | termcolor -------------------------------------------------------------------------------- /scripts/convert_blender_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from os import path 4 | 5 | from absl import app 6 | from absl import flags 7 | import numpy as np 8 | from PIL import Image 9 | 10 | FLAGS = flags.FLAGS 11 | 12 | flags.DEFINE_string('blenderdir', None, 'Base directory for all Blender data.') 13 | flags.DEFINE_string('outdir', None, 'Where to save multiscale data.') 14 | flags.DEFINE_integer('n_down', 4, 'How many levels of downscaling to use.') 15 | 16 | 17 | def load_renderings(data_dir, split): 18 | """Load images and metadata from disk.""" 19 | f = 'transforms_{}.json'.format(split) 20 | with open(path.join(data_dir, f), 'r') as fp: 21 | meta = json.load(fp) 22 | images = [] 23 | cams = [] 24 | print('Loading imgs') 25 | for frame in meta['frames']: 26 | fname = os.path.join(data_dir, frame['file_path'] + '.png') 27 | with open(fname, 'rb') as imgin: 28 | image = np.array(Image.open(imgin), dtype=np.float32) / 255.0 29 | cams.append(frame['transform_matrix']) 30 | images.append(image) 31 | ret = {} 32 | ret['images'] = np.stack(images, axis=0) 33 | print('Loaded all images, shape is', ret['images'].shape) 34 | ret['camtoworlds'] = np.stack(cams, axis=0) 35 | w = ret['images'].shape[2] 36 | camera_angle_x = float(meta['camera_angle_x']) 37 | ret['focal'] = 0.5 * w / np.tan(0.5 * camera_angle_x) 38 | return ret 39 | 40 | 41 | def down2(img): 42 | sh = img.shape 43 | return np.mean(np.reshape(img, [sh[0] // 2, 2, sh[1] // 2, 2, -1]), (1, 3)) 44 | 45 | 46 | def convert_to_nerfdata(basedir, newdir, n_down): 47 | """Convert Blender data to multiscale.""" 48 | if not os.path.exists(newdir): 49 | os.makedirs(newdir) 50 | splits = ['train', 'val', 'test'] 51 | bigmeta = {} 52 | # Foreach split in the dataset 53 | for split in splits: 54 | print('Split', split) 55 | # Load everything 56 | data = load_renderings(basedir, split) 57 | 58 | # Save out all the images 59 | imgdir = 'images_{}'.format(split) 60 | os.makedirs(os.path.join(newdir, imgdir), exist_ok=True) 61 | fnames = [] 62 | widths = [] 63 | heights = [] 64 | focals = [] 65 | cam2worlds = [] 66 | lossmults = [] 67 | labels = [] 68 | nears, fars = [], [] 69 | f = data['focal'] 70 | print('Saving images') 71 | for i, img in enumerate(data['images']): 72 | for j in range(n_down): 73 | fname = '{}/{:03d}_d{}.png'.format(imgdir, i, j) 74 | fnames.append(fname) 75 | fname = os.path.join(newdir, fname) 76 | with open(fname, 'wb') as imgout: 77 | img8 = Image.fromarray(np.uint8(img * 255)) 78 | img8.save(imgout) 79 | widths.append(img.shape[1]) 80 | heights.append(img.shape[0]) 81 | focals.append(f / 2**j) 82 | cam2worlds.append(data['camtoworlds'][i].tolist()) 83 | lossmults.append(4.0**j) 84 | labels.append(j) 85 | nears.append(2.0) 86 | fars.append(6.0) 87 | img = down2(img) 88 | 89 | # Create metadata 90 | meta = {} 91 | meta['file_path'] = fnames 92 | meta['cam2world'] = cam2worlds 93 | meta['width'] = widths 94 | meta['height'] = heights 95 | meta['focal'] = focals 96 | meta['label'] = labels 97 | meta['near'] = nears 98 | meta['far'] = fars 99 | meta['lossmult'] = lossmults 100 | 101 | fx = np.array(focals) 102 | fy = np.array(focals) 103 | cx = np.array(meta['width']) * 0.5 104 | cy = np.array(meta['height']) * 0.5 105 | arr0 = np.zeros_like(cx) 106 | arr1 = 
np.ones_like(cx) 107 | k_inv = np.array( 108 | [ 109 | [arr1 / fx, arr0, -cx / fx], 110 | [arr0, -arr1 / fy, cy / fy], 111 | [arr0, arr0, -arr1], 112 | ] 113 | ) 114 | k_inv = np.moveaxis(k_inv, -1, 0) 115 | meta['pix2cam'] = k_inv.tolist() 116 | 117 | bigmeta[split] = meta 118 | 119 | for k in bigmeta: 120 | for j in bigmeta[k]: 121 | print(k, j, type(bigmeta[k][j]), np.array(bigmeta[k][j]).shape) 122 | 123 | jsonfile = os.path.join(newdir, 'metadata.json') 124 | with open(jsonfile, 'w') as f: 125 | json.dump(bigmeta, f, ensure_ascii=False, indent=4) 126 | 127 | 128 | def main(unused_argv): 129 | 130 | blenderdir = FLAGS.blenderdir 131 | outdir = FLAGS.outdir 132 | n_down = FLAGS.n_down 133 | if not os.path.exists(outdir): 134 | os.makedirs(outdir) 135 | 136 | dirs = [os.path.join(blenderdir, f) for f in os.listdir(blenderdir)] 137 | dirs = [d for d in dirs if os.path.isdir(d)] 138 | print(dirs) 139 | for basedir in dirs: 140 | print() 141 | newdir = os.path.join(outdir, os.path.basename(basedir)) 142 | print('Converting from', basedir, 'to', newdir) 143 | convert_to_nerfdata(basedir, newdir, n_down) 144 | 145 | 146 | if __name__ == '__main__': 147 | app.run(main) 148 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import Trainer 2 | -------------------------------------------------------------------------------- /trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import gin 4 | import numpy as np 5 | import torch 6 | from pathlib import Path 7 | from loguru import logger 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | from typing import Dict, List 11 | 12 | from neural_field.model.RFModel import RFModel 13 | from utils.writer import TensorboardWriter 14 | from utils.colormaps import apply_depth_colormap 15 | import dataset.utils.io as data_io 16 | 17 | 18 | @gin.configurable() 19 | class Trainer: 20 | def __init__( 21 | self, 22 | model: RFModel, 23 | train_loader: DataLoader, 24 | eval_loader: DataLoader, 25 | # configurable 26 | base_exp_dir: str = 'experiments', 27 | exp_name: str = 'Tri-MipRF', 28 | max_steps: int = 50000, 29 | log_step: int = 500, 30 | eval_step: int = 500, 31 | target_sample_batch_size: int = 65536, 32 | test_chunk_size: int = 8192, 33 | dynamic_batch_size: bool = True, 34 | num_rays: int = 8192, 35 | varied_eval_img: bool = True, 36 | ): 37 | self.model = model.cuda() 38 | self.train_loader = train_loader 39 | self.eval_loader = eval_loader 40 | 41 | self.max_steps = max_steps 42 | self.target_sample_batch_size = target_sample_batch_size 43 | # exp_dir 44 | self.exp_dir = Path(base_exp_dir) / exp_name 45 | self.log_step = log_step 46 | self.eval_step = eval_step 47 | self.test_chunk_size = test_chunk_size 48 | self.dynamic_batch_size = dynamic_batch_size 49 | self.num_rays = num_rays 50 | self.varied_eval_img = varied_eval_img 51 | 52 | self.writer = TensorboardWriter(log_dir=self.exp_dir) 53 | 54 | self.optimizer = self.model.get_optimizer() 55 | self.scheduler = self.get_scheduler() 56 | self.grad_scaler = torch.cuda.amp.GradScaler(2**10) 57 | 58 | # Save configure 59 | conf = gin.operative_config_str() 60 | logger.info(conf) 61 | self.save_config(conf) 62 | 63 | def train_iter(self, step: int, data: Dict, logging=False): 64 | tic = time.time() 65 | cam_rays = data['cam_rays'] 66 | num_rays = 
min(self.num_rays, len(cam_rays)) 67 | cam_rays = cam_rays[:num_rays].cuda(non_blocking=True) 68 | target = data['target'][:num_rays].cuda(non_blocking=True) 69 | 70 | rb = self.model(cam_rays, target.render_bkgd) 71 | 72 | # compute loss 73 | loss_dict = self.model.compute_loss(cam_rays, rb, target) 74 | metrics = self.model.compute_metrics(cam_rays, rb, target) 75 | if 0 == metrics.get("rendering_samples_actual", -1): 76 | return metrics 77 | 78 | # update 79 | self.optimizer.zero_grad() 80 | self.grad_scaler.scale(loss_dict['total_loss']).backward() 81 | self.optimizer.step() 82 | 83 | # logging 84 | if logging: 85 | with torch.no_grad(): 86 | iter_time = time.time() - tic 87 | remaining_time = (self.max_steps - step) * iter_time 88 | status = { 89 | 'lr': self.optimizer.param_groups[0]["lr"], 90 | 'step': step, 91 | 'iter_time': iter_time, 92 | 'ETA': remaining_time, 93 | } 94 | self.writer.write_scalar_dicts( 95 | ['loss', 'metrics', 'status'], 96 | [ 97 | {k: v.item() for k, v in loss_dict.items()}, 98 | metrics, 99 | status, 100 | ], 101 | step, 102 | ) 103 | return metrics 104 | 105 | def fit(self): 106 | logger.info("==> Start training ...") 107 | 108 | iter_train_loader = iter(self.train_loader) 109 | iter_eval_loader = iter(self.eval_loader) 110 | eval_0 = next(iter_eval_loader) 111 | self.model.train() 112 | for step in range(self.max_steps): 113 | self.model.before_iter(step) 114 | metrics = self.train_iter( 115 | step, 116 | data=next(iter_train_loader), 117 | logging=(step % self.log_step == 0 and step > 0) 118 | or (step == 100), 119 | ) 120 | if 0 == metrics.get("rendering_samples_actual", -1): 121 | continue 122 | 123 | self.scheduler.step() 124 | if self.dynamic_batch_size: 125 | rendering_samples_actual = metrics.get( 126 | "rendering_samples_actual", 127 | self.target_sample_batch_size, 128 | ) 129 | self.num_rays = ( 130 | self.num_rays 131 | * self.target_sample_batch_size 132 | // rendering_samples_actual 133 | + 1 134 | ) 135 | 136 | self.model.after_iter(step) 137 | 138 | if step > 0 and step % self.eval_step == 0: 139 | self.model.eval() 140 | metrics, final_rb, target = self.eval_img( 141 | next(iter_eval_loader) if self.varied_eval_img else eval_0, 142 | compute_metrics=True, 143 | ) 144 | self.writer.write_scalar_dicts(['eval'], [metrics], step) 145 | self.writer.write_image('eval/rgb', final_rb.rgb, step) 146 | self.writer.write_image('gt/rgb', target.rgb, step) 147 | self.writer.write_image( 148 | 'eval/depth', 149 | apply_depth_colormap(final_rb.depth), 150 | step, 151 | ) 152 | self.writer.write_image('eval/alpha', final_rb.alpha, step) 153 | 154 | self.model.train() 155 | 156 | logger.info('==> Training done!') 157 | self.save_ckpt() 158 | 159 | @torch.no_grad() 160 | def eval_img(self, data, compute_metrics=True): 161 | cam_rays = data['cam_rays'].cuda(non_blocking=True) 162 | target = data['target'].cuda(non_blocking=True) 163 | 164 | final_rb = None 165 | flatten_rays = cam_rays.reshape(-1) 166 | flatten_target = target.reshape(-1) 167 | for i in range(0, len(cam_rays), self.test_chunk_size): 168 | rb = self.model( 169 | flatten_rays[i : i + self.test_chunk_size], 170 | flatten_target[i : i + self.test_chunk_size].render_bkgd, 171 | ) 172 | final_rb = rb if final_rb is None else final_rb.cat(rb) 173 | final_rb = final_rb.reshape(cam_rays.shape) 174 | metrics = None 175 | if compute_metrics: 176 | metrics = self.model.compute_metrics(cam_rays, final_rb, target) 177 | return metrics, final_rb, target 178 | 179 | @torch.no_grad() 180 | def eval( 181 | 
self, 182 | save_results: bool = False, 183 | rendering_channels: List[str] = ["rgb"], 184 | ): 185 | # ipdb.set_trace() 186 | logger.info("==> Start evaluation on testset ...") 187 | if save_results: 188 | res_dir = self.exp_dir / 'rendering' 189 | res_dir.mkdir(parents=True, exist_ok=True) 190 | results = {"names": []} 191 | results.update({k: [] for k in rendering_channels}) 192 | 193 | self.model.eval() 194 | metrics = [] 195 | for idx, data in enumerate(tqdm(self.eval_loader)): 196 | metric, rb, target = self.eval_img(data) 197 | metrics.append(metric) 198 | if save_results: 199 | results["names"].append(data['name']) 200 | for channel in rendering_channels: 201 | if hasattr(rb, channel): 202 | values = getattr(rb, channel).cpu().numpy() 203 | if 'depth' == channel: 204 | values = (values * 10000.0).astype( 205 | np.uint16 206 | ) # scale the depth by 10k, and save it as uint16 png images 207 | results[channel].append(values) 208 | else: 209 | raise NotImplementedError 210 | del rb 211 | if save_results: 212 | for idx, name in enumerate(tqdm(results['names'])): 213 | for channel in rendering_channels: 214 | channel_path = res_dir / channel 215 | data = results[channel][idx] 216 | data_io.write_rendering(data, channel_path, name) 217 | 218 | metrics = {k: [dct[k] for dct in metrics] for k in metrics[0]} 219 | logger.info("==> Evaluation done") 220 | for k, v in metrics.items(): 221 | metrics[k] = sum(v) / len(v) 222 | self.writer.write_scalar_dicts(['benchmark'], [metrics], 0) 223 | self.writer.tb_writer.close() 224 | 225 | def save_config(self, config): 226 | dest = self.exp_dir / 'config.gin' 227 | if dest.exists(): 228 | return 229 | self.exp_dir.mkdir(parents=True, exist_ok=True) 230 | with open(self.exp_dir / 'config.gin', 'w') as f: 231 | f.write(config) 232 | md_config_str = gin.config.markdown(config) 233 | self.writer.write_config(md_config_str) 234 | self.writer.tb_writer.flush() 235 | 236 | def save_ckpt(self): 237 | dest = self.exp_dir / 'model.ckpt' 238 | logger.info('==> Saving checkpoints to ' + str(dest)) 239 | torch.save( 240 | { 241 | "model": self.model.state_dict(), 242 | }, 243 | dest, 244 | ) 245 | 246 | def load_ckpt(self): 247 | dest = self.exp_dir / 'model.ckpt' 248 | loaded_state = torch.load(dest, map_location="cpu") 249 | logger.info('==> Loading checkpoints from ' + str(dest)) 250 | self.model.load_state_dict(loaded_state['model']) 251 | 252 | @gin.configurable() 253 | def get_scheduler(self, gamma=0.6, **kwargs): 254 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 255 | self.optimizer, 256 | milestones=[ 257 | self.max_steps // 2, 258 | self.max_steps * 3 // 4, 259 | self.max_steps * 5 // 6, 260 | self.max_steps * 9 // 10, 261 | ], 262 | gamma=gamma, 263 | **kwargs, 264 | ) 265 | return scheduler 266 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wbhu/Tri-MipRF/5b54a274d338ab64405d8123837e2d393c2262c2/utils/__init__.py -------------------------------------------------------------------------------- /utils/colormaps.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from matplotlib import cm 5 | from torchtyping import TensorType 6 | 7 | WHITE = torch.tensor([1.0, 1.0, 1.0]) 8 | BLACK = torch.tensor([0.0, 0.0, 0.0]) 9 | RED = torch.tensor([1.0, 0.0, 0.0]) 10 | GREEN = torch.tensor([0.0, 1.0, 
0.0]) 11 | BLUE = torch.tensor([0.0, 0.0, 1.0]) 12 | 13 | 14 | def apply_colormap( 15 | image: TensorType["bs":..., 1], 16 | cmap="viridis", 17 | ) -> TensorType["bs":..., "rgb":3]: 18 | """Convert single channel to a color image. 19 | Args: 20 | image: Single channel image. 21 | cmap: Colormap for image. 22 | Returns: 23 | TensorType: Colored image 24 | """ 25 | 26 | colormap = cm.get_cmap(cmap) 27 | colormap = torch.tensor(colormap.colors).to(image.device) # type: ignore 28 | image_long = (image * 255).long() 29 | image_long_min = torch.min(image_long) 30 | image_long_max = torch.max(image_long) 31 | assert image_long_min >= 0, f"the min value is {image_long_min}" 32 | assert image_long_max <= 255, f"the max value is {image_long_max}" 33 | return colormap[image_long[..., 0]] 34 | 35 | 36 | def apply_depth_colormap( 37 | depth: TensorType["bs":..., 1], 38 | accumulation: Optional[TensorType["bs":..., 1]] = None, 39 | near_plane: Optional[float] = None, 40 | far_plane: Optional[float] = None, 41 | cmap="turbo", 42 | ) -> TensorType["bs":..., "rgb":3]: 43 | """Converts a depth image to color for easier analysis. 44 | Args: 45 | depth: Depth image. 46 | accumulation: Ray accumulation used for masking vis. 47 | near_plane: Closest depth to consider. If None, use min image value. 48 | far_plane: Furthest depth to consider. If None, use max image value. 49 | cmap: Colormap to apply. 50 | Returns: 51 | Colored depth image 52 | """ 53 | 54 | near_plane = near_plane or float(torch.min(depth)) 55 | far_plane = far_plane or float(torch.max(depth)) 56 | 57 | depth = (depth - near_plane) / (far_plane - near_plane + 1e-10) 58 | depth = torch.clip(depth, 0, 1) 59 | depth = torch.nan_to_num(depth, nan=0.0) 60 | 61 | colored_image = apply_colormap(depth, cmap=cmap) 62 | 63 | if accumulation is not None: 64 | colored_image = colored_image * accumulation + (1 - accumulation) 65 | 66 | return colored_image 67 | 68 | 69 | def apply_boolean_colormap( 70 | image: TensorType["bs":..., 1, bool], 71 | true_color: TensorType["bs":..., "rgb":3] = WHITE, 72 | false_color: TensorType["bs":..., "rgb":3] = BLACK, 73 | ) -> TensorType["bs":..., "rgb":3]: 74 | """Converts a depth image to color for easier analysis. 75 | Args: 76 | image: Boolean image. 77 | true_color: Color to use for True. 78 | false_color: Color to use for False. 
79 | Returns: 80 | Colored boolean image 81 | """ 82 | 83 | colored_image = torch.ones(image.shape[:-1] + (3,)) 84 | colored_image[image[..., 0], :] = true_color 85 | colored_image[~image[..., 0], :] = false_color 86 | return colored_image 87 | -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | import os 5 | from loguru import logger 6 | 7 | 8 | def set_random_seed(seed): 9 | random.seed(seed) 10 | np.random.seed(seed) 11 | torch.manual_seed(seed) 12 | -------------------------------------------------------------------------------- /utils/ray.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | 5 | from utils.tensor_dataclass import TensorDataclass 6 | 7 | 8 | @dataclass 9 | class RayBundle(TensorDataclass): 10 | origins: Optional[torch.Tensor] = None 11 | """Ray origins (XYZ)""" 12 | 13 | directions: Optional[torch.Tensor] = None 14 | """Unit ray direction vector""" 15 | 16 | radiis: Optional[torch.Tensor] = None 17 | """Ray image plane intersection circle radii""" 18 | 19 | ray_cos: Optional[torch.Tensor] = None 20 | """Ray cos""" 21 | 22 | def __len__(self): 23 | num_rays = torch.numel(self.origins) // self.origins.shape[-1] 24 | return num_rays 25 | 26 | @property 27 | def shape(self): 28 | return list(super().shape) 29 | 30 | 31 | @dataclass 32 | class RayBundleExt(RayBundle): 33 | 34 | ray_depth: Optional[torch.Tensor] = None 35 | 36 | 37 | @dataclass 38 | class RayBundleRast(RayBundleExt): 39 | 40 | ray_uv: Optional[torch.Tensor] = None 41 | ray_mip_level: Optional[torch.Tensor] = None 42 | -------------------------------------------------------------------------------- /utils/render_buffer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | 9 | from __future__ import annotations 10 | from dataclasses import fields, dataclass, make_dataclass 11 | from typing import Optional, List, Tuple, Set, Dict, Iterator 12 | import torch 13 | import types 14 | 15 | from utils.tensor_dataclass import TensorDataclass 16 | 17 | 18 | __TD_VARIANTS__ = dict() 19 | 20 | 21 | @dataclass 22 | class RenderBuffer(TensorDataclass): 23 | """ 24 | A torch based, multi-channel, pixel buffer object. 25 | RenderBuffers are "smart" data buffers, used for accumulating tracing results, blending buffers of information, 26 | and providing discretized images. 27 | 28 | The spatial dimensions of RenderBuffer channels are flexible, see TensorDataclass. 29 | """ 30 | 31 | rgb: Optional[torch.Tensor] = None 32 | """ rgb is a shaded RGB color. """ 33 | 34 | alpha: Optional[torch.Tensor] = None 35 | """ alpha is the alpha component of RGB-A. 
""" 36 | 37 | depth: Optional[torch.Tensor] = None 38 | """ depth is usually a distance to the surface hit point.""" 39 | 40 | # Renderbuffer supports additional custom channels passed to the Renderbuffer constructor. 41 | # Some example of custom channels used throughout wisp: 42 | # xyz=None, # xyz is usually the xyz position for the surface hit point. 43 | # hit=None, # hit is usually a segmentation mask of hit points. 44 | # normal=None, # normal is usually the surface normal at the hit point. 45 | # shadow =None, # shadow is usually some additional buffer for shadowing. 46 | # ao=None, # ao is usually some addition buffer for ambient occlusion. 47 | # ray_o=None, # ray_o is usually the ray origin. 48 | # ray_d=None, # ray_d is usually the ray direction. 49 | # err=None, # err is usually some error metric against the ground truth. 50 | # gts=None, # gts is usually the ground truth image. 51 | 52 | def __new__(cls, *args, **kwargs): 53 | class_fields = [f.name for f in fields(RenderBuffer)] 54 | new_fields = [k for k in kwargs.keys() if k not in class_fields] 55 | if 0 < len(new_fields): 56 | class_key = frozenset(new_fields) 57 | rb_class = __TD_VARIANTS__.get(class_key) 58 | if rb_class is None: 59 | rb_class = make_dataclass( 60 | f'RenderBuffer_{len(__TD_VARIANTS__)}', 61 | fields=[ 62 | ( 63 | k, 64 | Optional[torch.Tensor], 65 | None, 66 | ) 67 | for k in kwargs.keys() 68 | ], 69 | bases=(RenderBuffer,), 70 | ) 71 | # Cache for future __new__ calls 72 | __TD_VARIANTS__[class_key] = rb_class 73 | setattr(types, rb_class.__name__, rb_class) 74 | return super(RenderBuffer, rb_class).__new__(rb_class) 75 | else: 76 | return super(TensorDataclass, cls).__new__(cls) 77 | 78 | @property 79 | def rgba(self) -> Optional[torch.Tensor]: 80 | """ 81 | Returns: 82 | (Optional[torch.Tensor]) A concatenated rgba. If rgb or alpha are none, this property will return None. 83 | """ 84 | if self.alpha is None or self.rgb is None: 85 | return None 86 | else: 87 | return torch.cat((self.rgb, self.alpha), dim=-1) 88 | 89 | @rgba.setter 90 | def rgba(self, val: Optional[torch.Tensor]) -> None: 91 | """ 92 | Args: 93 | val (Optional[torch.Tensor]) A concatenated rgba channel value, which sets values for the rgb and alpha 94 | internal channels simultaneously. 95 | """ 96 | self.rgb = val[..., 0:-1] 97 | self.alpha = val[..., -1:] 98 | 99 | @property 100 | def channels(self) -> Set[str]: 101 | """Returns a set of channels supported by this RenderBuffer""" 102 | all_channels = self.fields 103 | static_channels = self._static_field 104 | return all_channels.difference(static_channels) 105 | 106 | def has_channel(self, name: str) -> bool: 107 | """Returns whether the RenderBuffer supports the specified channel""" 108 | return name in self.channels 109 | 110 | def get_channel(self, name: str) -> Optional[torch.Tensor]: 111 | """Returns the pixels value of the specified channel, 112 | assuming this RenderBuffer supports the specified channel. 113 | """ 114 | return getattr(self, name) 115 | 116 | def transpose(self) -> RenderBuffer: 117 | """Permutes dimensions 0 and 1 of each channel. 118 | The rest of the channel dimensions will remain in the same order. 119 | """ 120 | fn = lambda x: x.permute(1, 0, *tuple(range(2, x.ndim))) 121 | return self._apply(fn) 122 | 123 | def scale(self, size: Tuple, interpolation='bilinear') -> RenderBuffer: 124 | """Upsamples or downsamples the renderbuffer pixels using the specified interpolation. 125 | Scaling assumes renderbuffers with 2 spatial dimensions, e.g. 
(H, W, C) or (W, H, C). 126 | 127 | Warning: for non-floating point channels, this function will upcast to floating point dtype 128 | to perform interpolation, and will then re-cast back to the original dtype. 129 | Hence truncations due to rounding may occur. 130 | 131 | Args: 132 | size (Tuple): The new spatial dimensions of the renderbuffer. 133 | interpolation (str): Interpolation method applied to cope with missing or decimated pixels due to 134 | up / downsampling. The interpolation methods supported are aligned with 135 | :func:`torch.nn.functional.interpolate`. 136 | 137 | Returns: 138 | (RenderBuffer): A new RenderBuffer object with rescaled channels. 139 | """ 140 | 141 | def _scale(x): 142 | assert ( 143 | x.ndim == 3 144 | ), 'RenderBuffer scale() assumes channels have 2D spatial dimensions.' 145 | # Some versions of torch don't support direct interpolation of non-fp tensors 146 | dtype = x.dtype 147 | if not torch.is_floating_point(x): 148 | x = x.float() 149 | x = x.permute(2, 0, 1)[None] 150 | x = torch.nn.functional.interpolate( 151 | x, size=size, mode=interpolation 152 | ) 153 | x = x[0].permute(1, 2, 0) 154 | if x.dtype != dtype: 155 | x = torch.round(x).to(dtype) 156 | return x 157 | 158 | return self._apply(_scale) 159 | 160 | def exr_dict(self) -> Dict[str, torch.Tensor]: 161 | """This function returns an EXR format compatible dictionary. 162 | 163 | Returns: 164 | (Dict[str, torch.Tensor]) 165 | a dictionary suitable for use with `pyexr` to output multi-channel EXR images which can be 166 | viewed interactively with software like `tev`. 167 | This is suitable for debugging geometric quantities like ray origins and ray directions. 168 | """ 169 | _dict = self.numpy_dict() 170 | if 'rgb' in _dict: 171 | _dict['default'] = _dict['rgb'] 172 | del _dict['rgb'] 173 | return _dict 174 | 175 | def image(self) -> RenderBuffer: 176 | """This function will return a copy of the RenderBuffer which will contain 8-bit [0,255] images. 177 | 178 | This function is used to output a RenderBuffer suitable for saving as a 8-bit RGB image (e.g. with 179 | Pillow). Since this quantization operation permanently loses information, this is not an inplace 180 | operation and will return a copy of the RenderBuffer. Currently this function will only return 181 | the hit segmentation mask, normalized depth, RGB, and the surface normals. 182 | 183 | If users want custom behaviour, users can implement their own conversion function which takes a 184 | RenderBuffer as input. 185 | """ 186 | norm = lambda arr: ((arr + 1.0) / 2.0) if arr is not None else None 187 | bwrgb = ( 188 | lambda arr: torch.cat([arr] * 3, dim=-1) 189 | if arr is not None 190 | else None 191 | ) 192 | rgb8 = lambda arr: (arr * 255.0) if arr is not None else None 193 | 194 | channels = dict() 195 | if self.rgb is not None: 196 | channels['rgb'] = rgb8(self.rgb) 197 | if self.alpha is not None: 198 | channels['alpha'] = rgb8(self.alpha) 199 | if self.depth is not None: 200 | # If the relative depth is respect to some camera clipping plane, the depth should 201 | # be clipped in advance. 
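            # The division below rescales depth into [0, 1] relative to the
            # furthest sample in this buffer (the 1e-8 term guards against
            # division by zero for an all-zero depth map); bwrgb then repeats
            # the single channel three times and rgb8 scales it to the 0-255 range.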
202 | relative_depth = self.depth / (torch.max(self.depth) + 1e-8) 203 | channels['depth'] = rgb8(bwrgb(relative_depth)) 204 | 205 | if hasattr(self, 'hit') and self.hit is not None: 206 | channels['hit'] = rgb8(bwrgb(self.hit)) 207 | else: 208 | channels['hit'] = None 209 | if hasattr(self, 'normal') and self.normal is not None: 210 | channels['normal'] = rgb8(norm(self.normal)) 211 | else: 212 | channels['normal'] = None 213 | 214 | return RenderBuffer(**channels) 215 | -------------------------------------------------------------------------------- /utils/tensor_dataclass.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tensor dataclass""" 16 | 17 | import dataclasses 18 | from dataclasses import dataclass 19 | from typing import ( 20 | Dict, 21 | Set, 22 | List, 23 | NoReturn, 24 | Optional, 25 | Tuple, 26 | TypeVar, 27 | Union, 28 | Iterator, 29 | ) 30 | 31 | import numpy as np 32 | import torch 33 | 34 | TensorDataclassT = TypeVar("TensorDataclassT", bound="TensorDataclass") 35 | 36 | 37 | @dataclass 38 | class TensorDataclass: 39 | """@dataclass of tensors with the same size batch. Allows indexing and standard tensor ops. 40 | Fields that are not Tensors will not be batched unless they are also a TensorDataclass. 41 | 42 | Example: 43 | 44 | .. code-block:: python 45 | 46 | @dataclass 47 | class TestTensorDataclass(TensorDataclass): 48 | a: torch.Tensor 49 | b: torch.Tensor 50 | c: torch.Tensor = None 51 | 52 | # Create a new tensor dataclass with batch size of [2,3,4] 53 | test = TestTensorDataclass(a=torch.ones((2, 3, 4, 2)), b=torch.ones((4, 3))) 54 | 55 | test.shape # [2, 3, 4] 56 | test.a.shape # [2, 3, 4, 2] 57 | test.b.shape # [2, 3, 4, 3] 58 | 59 | test.reshape((6,4)).shape # [6, 4] 60 | test.flatten().shape # [24,] 61 | 62 | test[..., 0].shape # [2, 3] 63 | test[:, 0, :].shape # [2, 4] 64 | """ 65 | 66 | _shape: tuple = torch.Size([]) 67 | _static_field: set = dataclasses.field(default_factory=set) 68 | 69 | def __post_init__(self) -> None: 70 | """Finishes setting up the TensorDataclass 71 | 72 | This will 1) find the broadcasted shape and 2) broadcast all fields to this shape 3) 73 | set _shape to be the broadcasted shape. 
74 | """ 75 | if not dataclasses.is_dataclass(self): 76 | raise TypeError("TensorDataclass must be a dataclass") 77 | 78 | batch_shapes = self._get_dict_batch_shapes( 79 | { 80 | f.name: self.__getattribute__(f.name) 81 | for f in dataclasses.fields(self) 82 | } 83 | ) 84 | if len(batch_shapes) == 0: 85 | raise ValueError("TensorDataclass must have at least one tensor") 86 | try: 87 | batch_shape = torch.broadcast_shapes(*batch_shapes) 88 | 89 | broadcasted_fields = self._broadcast_dict_fields( 90 | { 91 | f.name: self.__getattribute__(f.name) 92 | for f in dataclasses.fields(self) 93 | }, 94 | batch_shape, 95 | ) 96 | for f, v in broadcasted_fields.items(): 97 | self.__setattr__(f, v) 98 | 99 | self.__setattr__("_shape", batch_shape) 100 | except RuntimeError: 101 | pass 102 | except IndexError: 103 | # import ipdb;ipdb.set_trace() 104 | pass 105 | 106 | def _get_dict_batch_shapes(self, dict_: Dict) -> List: 107 | """Returns batch shapes of all tensors in a dictionary 108 | 109 | Args: 110 | dict_: The dictionary to get the batch shapes of. 111 | 112 | Returns: 113 | The batch shapes of all tensors in the dictionary. 114 | """ 115 | batch_shapes = [] 116 | for k, v in dict_.items(): 117 | if k in self._static_field: 118 | continue 119 | if isinstance(v, torch.Tensor): 120 | batch_shapes.append(v.shape[:-1]) 121 | elif isinstance(v, TensorDataclass): 122 | batch_shapes.append(v.shape) 123 | return batch_shapes 124 | 125 | def _broadcast_dict_fields(self, dict_: Dict, batch_shape) -> Dict: 126 | """Broadcasts all tensors in a dictionary according to batch_shape 127 | 128 | Args: 129 | dict_: The dictionary to broadcast. 130 | 131 | Returns: 132 | The broadcasted dictionary. 133 | """ 134 | new_dict = {} 135 | for k, v in dict_.items(): 136 | if k in self._static_field: 137 | continue 138 | if isinstance(v, torch.Tensor): 139 | new_dict[k] = v.broadcast_to((*batch_shape, v.shape[-1])) 140 | elif isinstance(v, TensorDataclass): 141 | new_dict[k] = v.broadcast_to(batch_shape) 142 | return new_dict 143 | 144 | def __getitem__(self: TensorDataclassT, indices) -> TensorDataclassT: 145 | if isinstance(indices, torch.Tensor): 146 | return self._apply_exclude_static(lambda x: x[indices]) 147 | if isinstance(indices, (int, slice)): 148 | indices = (indices,) 149 | return self._apply_exclude_static(lambda x: x[indices + (slice(None),)]) 150 | 151 | def __setitem__(self, indices, value) -> NoReturn: 152 | raise RuntimeError( 153 | "Index assignment is not supported for TensorDataclass" 154 | ) 155 | 156 | def __len__(self) -> int: 157 | return self.shape[0] 158 | 159 | def __bool__(self) -> bool: 160 | if len(self) == 0: 161 | raise ValueError( 162 | f"The truth value of {self.__class__.__name__} when `len(x) == 0` " 163 | "is ambiguous. Use `len(x)` or `x is not None`." 
164 | ) 165 | return True 166 | 167 | def __iter__(self) -> Iterator[Tuple[str, Optional[torch.Tensor]]]: 168 | return iter( 169 | (f.name, getattr(self, f.name)) 170 | for f in dataclasses.fields(self) 171 | if f.name not in ('_shape', '_static_field') 172 | ) 173 | 174 | @property 175 | def shape(self) -> Tuple[int, ...]: 176 | """Returns the batch shape of the tensor dataclass.""" 177 | return self._shape 178 | 179 | @property 180 | def size(self) -> int: 181 | """Returns the number of elements in the tensor dataclass batch dimension.""" 182 | if len(self._shape) == 0: 183 | return 1 184 | return int(np.prod(self._shape)) 185 | 186 | @property 187 | def ndim(self) -> int: 188 | """Returns the number of dimensions of the tensor dataclass.""" 189 | return len(self._shape) 190 | 191 | @property 192 | def fields(self) -> Set[str]: 193 | return set([f[0] for f in self]) 194 | 195 | def _apply(self, fn) -> TensorDataclassT: 196 | """Applies the function fn on each of the Renderbuffer channels, if not None. 197 | Returns a new instance with the processed channels. 198 | """ 199 | data = {} 200 | for f in self: 201 | attr = f[1] 202 | data[f[0]] = None if attr is None else fn(attr) 203 | return dataclasses.replace( 204 | self, 205 | _static_field=self._static_field, 206 | **data, 207 | ) 208 | # return TensorDataclass(**data) 209 | 210 | def _apply_exclude_static(self, fn) -> TensorDataclassT: 211 | data = {} 212 | for f in self: 213 | if f[0] in self._static_field: 214 | continue 215 | attr = f[1] 216 | data[f[0]] = None if attr is None else fn(attr) 217 | return dataclasses.replace( 218 | self, 219 | _static_field=self._static_field, 220 | **data, 221 | ) 222 | 223 | @staticmethod 224 | def _apply_on_pair(td1, td2, fn) -> TensorDataclassT: 225 | """Applies the function fn on each of the Renderbuffer channels, if not None. 226 | Returns a new instance with the processed channels. 227 | """ 228 | joint_fields = TensorDataclass._join_fields( 229 | td1, td2 230 | ) # Union of field names and tuples of values 231 | combined_channels = map( 232 | fn, joint_fields.values() 233 | ) # Invoke on pair per Renderbuffer field 234 | return dataclasses.replace( 235 | td1, 236 | _static_field=td1._static_field.union(td2._static_field), 237 | **dict(zip(joint_fields.keys(), combined_channels)), 238 | ) 239 | # return TensorDataclass(**dict(zip(joint_fields.keys(), combined_channels))) # Pack combined fields to a new rb 240 | 241 | @staticmethod 242 | def _apply_on_list(tds, fn) -> TensorDataclassT: 243 | joint_fields = set().union(*[td.fields for td in tds]) 244 | joint_fields = { 245 | f: [getattr(td, f, None) for td in tds] for f in joint_fields 246 | } 247 | combined_channels = map(fn, joint_fields.values()) 248 | 249 | return dataclasses.replace( 250 | tds[0], 251 | _static_field=tds[0]._static_field.union( 252 | *[td._static_field for td in tds[1:]] 253 | ), 254 | **dict(zip(joint_fields.keys(), combined_channels)), 255 | ) 256 | 257 | @staticmethod 258 | def _join_fields(td1, td2): 259 | """Creates a joint mapping of renderbuffer fields in a format of 260 | { 261 | channel1_name: (rb1.c1, rb2.c1), 262 | channel2_name: (rb1.c2, rb2.cb), 263 | channel3_name: (rb1.c1, None), # rb2 doesn't define channel3 264 | } 265 | If a renderbuffer does not have define a specific channel, None is returned. 
266 | """ 267 | joint_fields = td1.fields.union(td2.fields) 268 | return { 269 | f: (getattr(td1, f, None), getattr(td2, f, None)) 270 | for f in joint_fields 271 | } 272 | 273 | def numpy_dict(self) -> Dict[str, np.array]: 274 | """This function returns a dictionary of numpy arrays containing the data of each channel. 275 | 276 | Returns: 277 | (Dict[str, numpy.Array]) 278 | a dictionary with entries of (channel_name, channel_data) 279 | """ 280 | _dict = dict(iter(self)) 281 | _dict = {k: v.numpy() for k, v in _dict.items() if v is not None} 282 | return _dict 283 | 284 | def reshape( 285 | self: TensorDataclassT, shape: Tuple[int, ...] 286 | ) -> TensorDataclassT: 287 | """Returns a new TensorDataclass with the same data but with a new shape. 288 | 289 | This should deepcopy as well. 290 | 291 | Args: 292 | shape: The new shape of the tensor dataclass. 293 | 294 | Returns: 295 | A new TensorDataclass with the same data but with a new shape. 296 | """ 297 | if isinstance(shape, int): 298 | shape = (shape,) 299 | return self._apply_exclude_static( 300 | lambda x: x.reshape((*shape, x.shape[-1])) 301 | ) 302 | 303 | def flatten(self: TensorDataclassT) -> TensorDataclassT: 304 | """Returns a new TensorDataclass with flattened batch dimensions 305 | 306 | Returns: 307 | TensorDataclass: A new TensorDataclass with the same data but with a new shape. 308 | """ 309 | return self.reshape((-1,)) 310 | 311 | def broadcast_to( 312 | self: TensorDataclassT, 313 | shape: Union[torch.Size, Tuple[int, ...]], 314 | ) -> TensorDataclassT: 315 | """Returns a new TensorDataclass broadcast to new shape. 316 | 317 | Changes to the original tensor dataclass should effect the returned tensor dataclass, 318 | meaning it is NOT a deepcopy, and they are still linked. 319 | 320 | Args: 321 | shape: The new shape of the tensor dataclass. 322 | 323 | Returns: 324 | A new TensorDataclass with the same data but with a new shape. 325 | """ 326 | return self._apply_exclude_static( 327 | lambda x: x.broadcast_to((*shape, x.shape[-1])) 328 | ) 329 | 330 | def to(self: TensorDataclassT, device) -> TensorDataclassT: 331 | """Returns a new TensorDataclass with the same data but on the specified device. 332 | 333 | Args: 334 | device: The device to place the tensor dataclass. 335 | 336 | Returns: 337 | A new TensorDataclass with the same data but on the specified device. 
338 | """ 339 | return self._apply(lambda x: x.to(device)) 340 | 341 | def cuda(self, non_blocking=False) -> TensorDataclassT: 342 | """Shifts the renderbuffer to the default torch cuda device""" 343 | fn = lambda x: x.cuda(non_blocking=non_blocking) 344 | return self._apply(fn) 345 | 346 | def cpu(self) -> TensorDataclassT: 347 | """Shifts the renderbuffer to the torch cpu device""" 348 | fn = lambda x: x.cpu() 349 | return self._apply(fn) 350 | 351 | def detach(self) -> TensorDataclassT: 352 | """Detaches the gradients of all channel tensors of the renderbuffer""" 353 | fn = lambda x: x.detach() 354 | return self._apply(fn) 355 | 356 | @staticmethod 357 | def direct_cat( 358 | tds: List[TensorDataclassT], 359 | dim: int = 0, 360 | ) -> TensorDataclassT: 361 | # cat_func = partial(torch.cat, dim=dim) 362 | def cat_func(arr): 363 | _arr = [ele for ele in arr if ele is not None] 364 | if 0 == len(_arr): 365 | return None 366 | return torch.cat(_arr, dim=dim) 367 | 368 | return TensorDataclass._apply_on_list(tds, cat_func) 369 | 370 | @staticmethod 371 | def direct_stack( 372 | tds: List[TensorDataclassT], 373 | dim: int = 0, 374 | ) -> TensorDataclassT: 375 | # cat_func = partial(torch.cat, dim=dim) 376 | def cat_func(arr): 377 | _arr = [ele for ele in arr if ele is not None] 378 | if 0 == len(_arr): 379 | return None 380 | return torch.stack(_arr, dim=dim) 381 | 382 | return TensorDataclass._apply_on_list(tds, cat_func) 383 | 384 | def cat(self, other: TensorDataclassT, dim: int = 0) -> TensorDataclassT: 385 | """Concatenates the channels of self and other RenderBuffers. 386 | If a channel only exists in one of the RBs, that channel will be returned as is. 387 | For channels that exists in both RBs, the spatial dimensions are assumed to be identical except for the 388 | concatenated dimension. 389 | 390 | Args: 391 | other (TensorDataclass) A second buffer to concatenate to the current buffer. 392 | dim (int): The index of spatial dimension used to concat the channels 393 | 394 | Returns: 395 | A new TensorDataclass with the concatenated channels. 
396 | """ 397 | 398 | def _cat(pair): 399 | if None not in pair: 400 | # Concatenating tensors of different dims where one is unsqueezed with dimensionality 1 401 | if ( 402 | pair[0].ndim == (pair[1].ndim + 1) 403 | and pair[0].shape[-1] == 1 404 | ): 405 | pair = (pair[0], pair[1].unsqueeze(-1)) 406 | elif ( 407 | pair[1].ndim == (pair[0].ndim + 1) 408 | and pair[1].shape[-1] == 1 409 | ): 410 | pair = (pair[0].unsqueeze(-1), pair[1]) 411 | return torch.cat(pair, dim=dim) 412 | elif ( 413 | pair[0] is not None and pair[1] is None 414 | ): # Channel is None for other but not self 415 | return pair[0] 416 | elif ( 417 | pair[0] is None and pair[1] is not None 418 | ): # Channel is None for self but not other 419 | return pair[1] 420 | else: 421 | return None 422 | 423 | return TensorDataclass._apply_on_pair(self, other, _cat) 424 | -------------------------------------------------------------------------------- /utils/writer.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from abc import abstractmethod 3 | from pathlib import Path 4 | from typing import Any, Dict, List, Union 5 | from collections import ChainMap 6 | from loguru import logger 7 | from termcolor import colored 8 | 9 | import torch 10 | 11 | # from torch.utils.tensorboard import SummaryWriter 12 | from tensorboardX import SummaryWriter 13 | from torchtyping import TensorType 14 | 15 | to8b = lambda x: (255 * torch.clamp(x, min=0, max=1)).to(torch.uint8) 16 | 17 | 18 | class Writer: 19 | """Writer class""" 20 | 21 | def __init__(self): 22 | self.std_logger = logger 23 | 24 | @abstractmethod 25 | def write_image( 26 | self, 27 | name: str, 28 | image: TensorType["H", "W", "C"], 29 | step: int, 30 | ) -> None: 31 | """method to write out image 32 | Args: 33 | name: data identifier 34 | image: rendered image to write 35 | step: the time step to log 36 | """ 37 | raise NotImplementedError 38 | 39 | @abstractmethod 40 | def write_scalar( 41 | self, 42 | name: str, 43 | scalar: Union[float, torch.Tensor], 44 | step: int, 45 | ) -> None: 46 | """Required method to write a single scalar value to the logger 47 | Args: 48 | name: data identifier 49 | scalar: value to write out 50 | step: the time step to log 51 | """ 52 | raise NotImplementedError 53 | 54 | def write_scalar_dict( 55 | self, 56 | name: str, 57 | scalar_dict: Dict[str, Any], 58 | step: int, 59 | ) -> None: 60 | """Function that writes out all scalars from a given dictionary to the logger 61 | Args: 62 | scalar_dict: dictionary containing all scalar values with key names and quantities 63 | step: the time step to log 64 | """ 65 | for key, scalar in scalar_dict.items(): 66 | try: 67 | float_scalar = float(scalar) 68 | self.write_scalar(name + "/" + key, float_scalar, step) 69 | except: 70 | pass 71 | 72 | def write_scalar_dicts( 73 | self, 74 | names: List[str], 75 | scalar_dicts: List[Dict[str, Any]], 76 | step: int, 77 | ) -> None: 78 | # self.std_logger.info(scalar_dicts) 79 | self.std_logger.info( 80 | ''.join( 81 | [ 82 | '{}{} '.format( 83 | colored('{}:'.format(k), 'light_magenta'), 84 | v 85 | if k != 'ETA' 86 | else str(datetime.timedelta(seconds=int(v))), 87 | ) 88 | for k, v in dict(ChainMap(*scalar_dicts)).items() 89 | ] 90 | ) 91 | ) 92 | assert len(names) == len(scalar_dicts) 93 | for n, d in zip(names, scalar_dicts): 94 | self.write_scalar_dict(n, d, step) 95 | 96 | 97 | class TensorboardWriter(Writer): 98 | """Tensorboard Writer Class""" 99 | 100 | def __init__(self, log_dir: Path): 101 | 
super(TensorboardWriter, self).__init__() 102 | self.tb_writer = SummaryWriter(log_dir=str(log_dir)) 103 | 104 | def write_image( 105 | self, 106 | name: str, 107 | image: TensorType["H", "W", "C"], 108 | step: int, 109 | ) -> None: 110 | image = to8b(image) 111 | self.tb_writer.add_image(name, image, step, dataformats="HWC") 112 | 113 | def write_scalar( 114 | self, 115 | name: str, 116 | scalar: Union[float, torch.Tensor], 117 | step: int, 118 | ) -> None: 119 | self.tb_writer.add_scalar(name, scalar, step) 120 | 121 | def write_config(self, config: str): # pylint: disable=unused-argument 122 | self.tb_writer.add_text("config", config) 123 | --------------------------------------------------------------------------------
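A brief, self-contained sketch of how the utility classes above compose (not part of the repository; the log directory, tag names, and random tensors are purely illustrative): RenderBuffer results produced in chunks can be concatenated and reshaped like tensors, and TensorboardWriter handles the image and scalar logging used by the Trainer.

import torch
from pathlib import Path

from utils.render_buffer import RenderBuffer
from utils.colormaps import apply_depth_colormap
from utils.writer import TensorboardWriter

# Two chunks of rendered rays (e.g. from chunked evaluation). The extra
# 'normal' channel is not a declared field, so RenderBuffer.__new__ builds a
# dataclass variant for it on the fly.
rb_a = RenderBuffer(
    rgb=torch.rand(4096, 3),
    alpha=torch.rand(4096, 1),
    depth=torch.rand(4096, 1),
    normal=torch.rand(4096, 3),
)
rb_b = RenderBuffer(
    rgb=torch.rand(4096, 3),
    alpha=torch.rand(4096, 1),
    depth=torch.rand(4096, 1),
    normal=torch.rand(4096, 3),
)

full = rb_a.cat(rb_b)          # 8192 rays, concatenated channel by channel
img = full.reshape((64, 128))  # back to an image-shaped buffer, e.g. (64, 128, 3) for rgb

writer = TensorboardWriter(log_dir=Path('experiments/demo'))
writer.write_image('demo/rgb', img.rgb, step=0)
writer.write_image('demo/depth', apply_depth_colormap(img.depth), step=0)
writer.write_scalar('demo/mean_alpha', img.alpha.mean(), step=0)
writer.tb_writer.close()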