├── .github
│   └── workflows
│       └── docker-build.yaml
├── .gitignore
├── Dockerfile
├── README.md
├── assets
│   ├── lego.mp4
│   ├── overview.jpg
│   └── teaser.jpg
├── config_files
│   ├── blender
│   │   └── TriMipRF.gin
│   └── ms_blender
│       └── TriMipRF.gin
├── dataset
│   ├── __init__.py
│   ├── parsers
│   │   ├── __init__.py
│   │   ├── nerf_synthetic.py
│   │   └── nerf_synthetic_multiscale.py
│   ├── ray_dataset.py
│   └── utils
│       ├── __init__.py
│       ├── cameras.py
│       ├── io.py
│       └── utils.py
├── main.py
├── neural_field
│   ├── __init__.py
│   ├── encoding
│   │   ├── __init__.py
│   │   └── tri_mip.py
│   ├── field
│   │   ├── __init__.py
│   │   └── trimipRF.py
│   ├── model
│   │   ├── RFModel.py
│   │   ├── __init__.py
│   │   └── trimipRF.py
│   └── nn_utils
│       ├── __init__.py
│       └── activations.py
├── pyproject.toml
├── requirements.txt
├── scripts
│   └── convert_blender_data.py
├── trainer
│   ├── __init__.py
│   └── trainer.py
└── utils
    ├── __init__.py
    ├── colormaps.py
    ├── common.py
    ├── ray.py
    ├── render_buffer.py
    ├── tensor_dataclass.py
    └── writer.py
/.github/workflows/docker-build.yaml:
--------------------------------------------------------------------------------
1 | name: Docker build & test
2 |
3 | on:
4 | push:
5 | branches:
6 | - develop
7 | - staging
8 | - prod
9 | pull_request:
10 | types: [ synchronize, opened, reopened, labeled ]
11 |
12 | jobs:
13 | build-docker-image:
14 | if: github.event_name == 'push' || (github.event_name == 'pull_request' && contains(github.event.*.labels.*.name, 'build docker'))
15 |
16 | runs-on: buildjet-4vcpu-ubuntu-2204
17 |
18 | permissions:
19 | contents: read
20 | packages: write
21 |
22 | steps:
23 | - name: Checkout
24 | uses: actions/checkout@v3
25 |
26 | - name: Get image tag
27 | shell: bash
28 | run: |
29 | echo "IMAGE_VERSION=pr-${{ github.event.pull_request.number }}-$(echo ${{github.sha}} | cut -c1-7)" >> $GITHUB_ENV
30 | - name: Echo image tag
31 | shell: bash
32 | run: |
33 | echo "using image tag: ${{ env.IMAGE_VERSION }}"
34 |
35 | - name: Set up Docker Buildx
36 | uses: docker/setup-buildx-action@v2
37 |
38 | - name: Login to Docker Hub
39 | uses: docker/login-action@v2
40 | with:
41 | username: ${{ secrets.DOCKERHUB_USERNAME }}
42 | password: ${{ secrets.DOCKERHUB_TOKEN }}
43 |
44 | - name: Build and push
45 | uses: docker/build-push-action@v4
46 | with:
47 | context: .
48 | file: ./Dockerfile
49 | push: true
50 | tags: joshpwrk/trimip:${{ env.IMAGE_VERSION }}
51 | cache-from: type=registry,ref=joshpwrk/trimip:${{ env.IMAGE_VERSION }}
52 | cache-to: type=inline
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | .idea
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | !scripts/downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 | cover/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | .pybuilder/
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | # For a library or package, you might want to ignore these files since the code is
89 | # intended to run in multiple environments; otherwise, check them in:
90 | # .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # poetry
100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
101 | # This is especially recommended for binary packages to ensure reproducibility, and is more
102 | # commonly ignored for libraries.
103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
104 | #poetry.lock
105 |
106 | # pdm
107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
108 | #pdm.lock
109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
110 | # in version control.
111 | # https://pdm.fming.dev/#use-with-ide
112 | .pdm.toml
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .envrc
127 | .venv
128 | env/
129 | venv/
130 | ENV/
131 | env.bak/
132 | venv.bak/
133 |
134 | # Spyder project settings
135 | .spyderproject
136 | .spyproject
137 |
138 | # Rope project settings
139 | .ropeproject
140 |
141 | # mkdocs documentation
142 | /site
143 |
144 | # mypy
145 | .mypy_cache/
146 | .dmypy.json
147 | dmypy.json
148 |
149 | # Pyre type checker
150 | .pyre/
151 |
152 | # pytype static type analyzer
153 | .pytype/
154 |
155 | # Cython debug symbols
156 | cython_debug/
157 |
158 | # PyCharm
159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161 | # and can be added to the global gitignore or merged into this file. For a more nuclear
162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163 | #.idea/
164 |
165 | # Experiments and outputs
166 | runs/
167 | outputs/
168 | # tensorboard log files
169 | events.out.*
170 |
171 | # Data
172 | data
173 | !*/data
174 |
175 | # Misc
176 | old/
177 | temp*
178 | .nfs*
179 | external/
180 | __MACOSX/
181 | outputs*/
182 | node_modules/
183 | bash/
184 | cache/
185 | package-lock.json
186 | camera_paths/
187 | exp/
188 | */._DS_Store
189 | .vscode/
190 | launch_logs/
191 | **/binary
192 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
2 |
3 | # Set environment variables to non-interactive (this prevents some prompts)
4 | ENV DEBIAN_FRONTEND=noninteractive
5 |
6 | # Update and install some essential packages
7 | RUN apt-get update && apt-get install -y \
8 | software-properties-common \
9 | build-essential \
10 | curl \
11 | git \
12 | && rm -rf /var/lib/apt/lists/*
13 |
14 | # Install system libraries needed by the Python requirements (PyAV / ffmpeg bindings)
15 | RUN apt-get update && apt-get install -y \
16 | ffmpeg \
17 | libavformat-dev \
18 | libavcodec-dev \
19 | libavdevice-dev \
20 |     libavutil-dev \
22 | libavfilter-dev \
23 | libswscale-dev \
24 | libswresample-dev
25 |
26 | # RUN apt-get install -y libavformat-dev
27 | # RUN apt-get install -y libavcodec-dev
28 | # RUN apt-get install -y libavdevice-dev
29 | # RUN apt-get install -y libavutil-dev
30 | # RUN apt-get install -y libavfilter-dev
31 | # RUN apt-get install -y libswscale-dev
32 | # RUN apt-get install -y libswresample-dev
33 |
34 | # Install Python 3 (Ubuntu 22.04 comes with Python 3.10)
35 | RUN apt-get update && apt-get install -y python3 python3-pip
36 |
37 | # Set the working directory inside the container
38 | WORKDIR /usr/src/app
39 |
40 | COPY requirements.txt ./
41 | RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
42 | RUN TCNN_CUDA_ARCHITECTURES=89 pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
43 |
44 | # Install nvdiffrast: https://nvlabs.github.io/nvdiffrast/#linux
45 |
46 | RUN pip3 install --no-cache-dir -r requirements.txt
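# Note: nvdiffrast is referenced above but never installed in this image.
# A minimal sketch of one way to add it (an assumption, not part of the original
# Dockerfile); adjust to the linked install guide if the pip+git route fails:
RUN pip3 install --no-cache-dir git+https://github.com/NVlabs/nvdiffrast/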
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Tri-MipRF
2 |
3 | Official PyTorch implementation of the paper:
4 |
5 | > **Tri-MipRF: Tri-Mip Representation for Efficient Anti-Aliasing Neural Radiance Fields**
6 | >
7 | > ***ICCV 2023***
8 | >
9 | > Wenbo Hu, Yuling Wang, Lin Ma, Bangbang Yang, Lin Gao, Xiao Liu, Yuewen Ma
10 | >
11 | >
12 |
13 | https://github.com/wbhu/Tri-MipRF/assets/11834645/6c50baf7-ac46-46fd-a36f-172f99ea9054
14 |
15 |
16 | > Instant-ngp (left) suffers from aliasing in distant or low-resolution views and blurriness in
17 | > close-up shots, while Tri-MipRF (right) renders both fine-grained details in close-ups
18 | > and high-fidelity zoomed-out images.
19 |
20 |
21 |
22 |
23 |
24 | > To render a pixel, we emit a cone from the camera’s projection center to the pixel on the
25 | > image plane, and then we cast a set of spheres inside the cone. Next, the spheres are
26 | > orthogonally projected
27 | > onto the three planes and featurized by our Tri-Mip encoding. After that, the feature
28 | > vector is fed into a tiny MLP that non-linearly maps it to
29 | > density and color. Finally, the density and
30 | > color of the spheres are integrated using volume rendering to produce the final color for the pixel.
31 |
32 |
33 |
34 |
35 |
36 |
37 | > Our Tri-MipRF achieves state-of-the-art rendering quality while being reconstructed efficiently,
38 | > compared with cutting-edge radiance field methods, e.g., NeRF, Mip-NeRF, Plenoxels,
39 | > TensoRF, and Instant-ngp. Equipping Instant-ngp with super-sampling (named Instant-ngp↑5×)
40 | > improves the rendering quality to a certain extent but significantly slows down the reconstruction.
41 |
42 | ## **Installation**
43 | Please install the following dependencies first
44 | - [PyTorch (1.13.1 + cu11.6)](https://pytorch.org/get-started/locally/)
45 | - [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn)
46 | - [nvdiffrast](https://nvlabs.github.io/nvdiffrast/)
47 |
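For reference, a minimal sketch of installing these prerequisites (assumptions: a CUDA 11.x toolchain and a recent pip; the tiny-cuda-nn line mirrors the one in the `Dockerfile`, and `TCNN_CUDA_ARCHITECTURES` must match your GPU's compute capability):
```shell
# PyTorch with a CUDA build matching your driver (see pytorch.org for the right index URL)
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu118
# tiny-cuda-nn PyTorch bindings (89 = Ada; use your GPU's compute capability)
TCNN_CUDA_ARCHITECTURES=89 pip3 install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
# nvdiffrast
pip3 install git+https://github.com/NVlabs/nvdiffrast/
```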
48 | And then install the following dependencies using *pip*
49 | ```shell
50 | pip3 install av==9.2.0 \
51 | beautifulsoup4==4.11.1 \
52 | entrypoints==0.4 \
53 | gdown==4.5.1 \
54 | gin-config==0.5.0 \
55 | h5py==3.7.0 \
56 | imageio==2.21.1 \
57 | imageio-ffmpeg \
58 | ipython==7.19.0 \
59 | kornia==0.6.8 \
60 | loguru==0.6.0 \
61 | lpips==0.1.4 \
62 | mediapy==1.1.0 \
63 | mmcv==1.6.2 \
64 | ninja==1.10.2.3 \
65 | numpy==1.23.3 \
66 | open3d==0.16.0 \
67 | opencv-python==4.6.0.66 \
68 | pandas==1.5.0 \
69 | Pillow==9.2.0 \
70 | plotly==5.7.0 \
71 | pycurl==7.43.0.6 \
72 | PyMCubes==0.1.2 \
73 | pyransac3d==0.6.0 \
74 | PyYAML==6.0 \
75 | rich==12.6.0 \
76 | scipy==1.9.2 \
77 | tensorboard==2.9.0 \
78 | torch-fidelity==0.3.0 \
79 | torchmetrics==0.10.0 \
80 | torchtyping==0.1.4 \
81 | tqdm==4.64.1 \
82 | tyro==0.3.25 \
83 | appdirs \
84 | nerfacc==0.3.5 \
85 | plyfile \
86 | scikit-image \
87 | trimesh \
88 | torch_efficient_distloss \
89 | umsgpack \
90 | pyngrok \
91 | cryptography==39.0.2 \
92 | omegaconf==2.2.3 \
93 | segmentation-refinement \
94 | xatlas \
95 | protobuf==3.20.0 \
96 | jinja2 \
97 | click==8.1.7 \
98 | tensorboardx \
99 | termcolor
100 | ```
101 |
102 | ## **Data**
103 |
104 | ### nerf_synthetic dataset
105 | Please download and unzip `nerf_synthetic.zip` from the [NeRF official Google Drive](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1).
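If you prefer the command line, a minimal sketch using `gdown` (already listed in `requirements.txt`); `<file-id>` is a placeholder for the id of `nerf_synthetic.zip` shown in the Drive folder above:
```shell
gdown <file-id>               # download nerf_synthetic.zip by its Drive file id
unzip nerf_synthetic.zip -d /path/to/
```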
106 |
107 | ### Generate multiscale dataset
108 | Please generate it by
109 | ```shell
110 | python scripts/convert_blender_data.py --blenderdir /path/to/nerf_synthetic --outdir /path/to/nerf_synthetic_multiscale
111 | ```
112 |
113 | ## **Training and evaluation**
114 | ```shell
115 | python main.py --ginc config_files/ms_blender/TriMipRF.gin
116 | ```
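Individual gin bindings can also be overridden from the command line via `--ginb` (parsed in `main.py`), which avoids editing the config files; the scene and path below are illustrative:
```shell
python main.py --ginc config_files/ms_blender/TriMipRF.gin \
    --ginb "RayDataset.scene='lego'" \
    --ginb "Trainer.base_exp_dir='/path/to/exp'"
```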
117 |
118 |
119 | ## **TODO**
120 |
121 | - [x] ~~Release source code~~.
122 |
123 | ## **Citation**
124 |
125 | If you find the code useful for your work, please star this repo and consider citing:
126 |
127 | ```
128 | @inproceedings{hu2023Tri-MipRF,
129 | author = {Hu, Wenbo and Wang, Yuling and Ma, Lin and Yang, Bangbang and Gao, Lin and Liu, Xiao and Ma, Yuewen},
130 | title = {Tri-MipRF: Tri-Mip Representation for Efficient Anti-Aliasing Neural Radiance Fields},
131 | booktitle = {ICCV},
132 | year = {2023}
133 | }
134 | ```
135 |
136 |
137 | ## **Related Work**
138 |
139 | - [Mip-NeRF (ICCV 2021)](https://jonbarron.info/mipnerf/)
140 | - [Instant-ngp (SIGGRAPH 2022)](https://nvlabs.github.io/instant-ngp/)
141 | - [Zip-NeRF (ICCV 2023)](https://jonbarron.info/zipnerf/)
142 |
--------------------------------------------------------------------------------
/assets/lego.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wbhu/Tri-MipRF/5b54a274d338ab64405d8123837e2d393c2262c2/assets/lego.mp4
--------------------------------------------------------------------------------
/assets/overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wbhu/Tri-MipRF/5b54a274d338ab64405d8123837e2d393c2262c2/assets/overview.jpg
--------------------------------------------------------------------------------
/assets/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wbhu/Tri-MipRF/5b54a274d338ab64405d8123837e2d393c2262c2/assets/teaser.jpg
--------------------------------------------------------------------------------
/config_files/blender/TriMipRF.gin:
--------------------------------------------------------------------------------
1 | main.train_split = 'trainval'
2 | main.num_workers = 16
3 | main.model_name = 'Tri-MipRF'
4 | main.batch_size = 28 # this is not the actual batch_size, but the prefetch size
5 |
6 | RayDataset.base_path = '/path/to/nerf_synthetic'
7 | RayDataset.scene = 'chair'
8 | RayDataset.scene_type = 'nerf_synthetic'
9 |
10 |
11 | Trainer.base_exp_dir = '/path/to/experiment/lod/dir'
12 | Trainer.exp_name = None
13 | Trainer.eval_step = 25000
14 | Trainer.log_step = 1000
15 | Trainer.max_steps = 25001
16 | Trainer.target_sample_batch_size = 262144
17 |
18 |
19 |
--------------------------------------------------------------------------------
/config_files/ms_blender/TriMipRF.gin:
--------------------------------------------------------------------------------
1 | main.train_split = 'trainval'
2 | main.num_workers = 16
3 | main.model_name = 'Tri-MipRF'
4 | main.batch_size = 24 # this is not the actual batch_size, but the prefetch size
5 |
6 | RayDataset.base_path = '/path/to/nerf_synthetic_multiscale'
7 | RayDataset.scene = 'chair'
8 | RayDataset.scene_type = 'nerf_synthetic_multiscale'
9 |
10 |
11 | Trainer.base_exp_dir = '/path/to/experiment/lod/dir'
12 | Trainer.exp_name = None
13 | Trainer.eval_step = 25000
14 | Trainer.log_step = 1000
15 | Trainer.max_steps = 25001
16 | Trainer.target_sample_batch_size = 262144
17 |
18 |
19 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wbhu/Tri-MipRF/5b54a274d338ab64405d8123837e2d393c2262c2/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/parsers/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | from . import nerf_synthetic
4 | from . import nerf_synthetic_multiscale
5 |
6 |
7 | def get_parser(parser_name: str) -> Callable:
8 | if 'nerf_synthetic' == parser_name:
9 | return nerf_synthetic.load_data
10 | elif 'nerf_synthetic_multiscale' == parser_name:
11 | return nerf_synthetic_multiscale.load_data
12 | else:
13 | raise NotImplementedError
14 |
--------------------------------------------------------------------------------
/dataset/parsers/nerf_synthetic.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import numpy as np
3 |
4 | import dataset.utils.io as data_io
5 | from dataset.utils.cameras import PinholeCamera
6 |
7 |
8 | def load_data(base_path: Path, scene: str, split: str):
9 | data_path = base_path / scene
10 | splits = ['train', 'val'] if split == "trainval" else [split]
11 | meta = None
12 | for s in splits:
13 | meta_path = data_path / "transforms_{}.json".format(s)
14 | m = data_io.load_from_json(meta_path)
15 | if meta is None:
16 | meta = m
17 | else:
18 | for k, v in meta.items():
19 | if type(v) is list:
20 | v.extend(m[k])
21 | else:
22 | assert v == m[k]
23 |
24 | image_height, image_width = 800, 800
25 | camera_angle_x = float(meta["camera_angle_x"])
26 | focal_length = 0.5 * image_width / np.tan(0.5 * camera_angle_x)
27 | cx = image_width / 2.0
28 | cy = image_height / 2.0
29 | cameras = [
30 | PinholeCamera(
31 | fx=focal_length,
32 | fy=focal_length,
33 | cx=cx,
34 | cy=cy,
35 | width=image_width,
36 | height=image_height,
37 | )
38 | ]
39 | cam_num = len(cameras)
40 |
41 | frames, poses = {k: [] for k in range(len(cameras))}, {
42 | k: [] for k in range(len(cameras))
43 | }
44 | index = 0
45 | for frame in meta["frames"]:
46 | fname = data_path / Path(frame["file_path"].replace("./", "") + ".png")
47 | frames[index % cam_num].append(
48 | {
49 | 'image_filename': fname,
50 | 'lossmult': 1.0,
51 | }
52 | )
53 | poses[index % cam_num].append(
54 | np.array(frame["transform_matrix"]).astype(np.float32)
55 | )
56 | index = index + 1
57 |
58 | aabb = np.array([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5])
59 |
60 | outputs = {
61 | 'frames': frames,
62 | 'poses': poses,
63 | 'cameras': cameras,
64 | 'aabb': aabb,
65 | }
66 | return outputs
67 |
--------------------------------------------------------------------------------
/dataset/parsers/nerf_synthetic_multiscale.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import numpy as np
3 | from tqdm import tqdm
4 |
5 | import dataset.utils.io as data_io
6 | from dataset.utils.cameras import PinholeCamera
7 |
8 |
9 | def load_data(base_path: Path, scene: str, split: str, cam_num: int = 4):
10 | # ipdb.set_trace()
11 | data_path = base_path / scene
12 | meta_path = data_path / 'metadata.json'
13 |
14 | splits = ['train', 'val'] if split == "trainval" else [split]
15 | meta = None
16 | for s in splits:
17 | m = data_io.load_from_json(meta_path)[s]
18 | if meta is None:
19 | meta = m
20 | else:
21 | for k, v in meta.items():
22 | v.extend(m[k])
23 |
24 | pix2cam = meta['pix2cam']
25 | poses = meta['cam2world']
26 | image_width = meta['width']
27 | image_height = meta['height']
28 | lossmult = meta['lossmult']
29 |
30 | assert image_height[0] == image_height[cam_num]
31 | assert image_width[0] == image_width[cam_num]
32 | assert pix2cam[0] == pix2cam[cam_num]
33 | assert lossmult[0] == lossmult[cam_num]
34 | cameras = []
35 | for i in range(cam_num):
36 | k = np.linalg.inv(pix2cam[i])
37 | fx = k[0, 0]
38 | fy = -k[1, 1]
39 | cx = -k[0, 2]
40 | cy = -k[1, 2]
41 | cam = PinholeCamera(
42 | fx=fx,
43 | fy=fy,
44 | cx=cx,
45 | cy=cy,
46 | width=image_width[i],
47 | height=image_height[i],
48 | # loss_multi=lossmult[i],
49 | )
50 | cameras.append(cam)
51 |
52 | frames = {k: [] for k in range(len(cameras))}
53 | index = 0
54 | for frame in tqdm(meta['file_path']):
55 | fname = data_path / frame
56 | frames[index % cam_num].append(
57 | {
58 | 'image_filename': fname,
59 | 'lossmult': lossmult[index],
60 | }
61 | )
62 | index = index + 1
63 | poses = {k: poses[k :: len(cameras)] for k in range(len(cameras))}
64 |
65 | aabb = np.array([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5])
66 | outputs = {
67 | 'frames': frames,
68 | 'poses': poses,
69 | 'cameras': cameras,
70 | 'aabb': aabb,
71 | }
72 | return outputs
73 |
74 |
75 | if __name__ == '__main__':
76 | data = load_data(
77 | Path('/mnt/bn/wbhu-nerf/Dataset/nerf_synthetic_multiscale'),
78 | 'lego',
79 | split='train',
80 | )
81 | pass
82 |
--------------------------------------------------------------------------------
/dataset/ray_dataset.py:
--------------------------------------------------------------------------------
1 | from concurrent.futures import ThreadPoolExecutor
2 | import multiprocessing
3 |
4 | import gin
5 | import numpy as np
6 | import torch
7 | from loguru import logger
8 | from torch.utils.data import Dataset, DataLoader
9 | from pathlib import Path
10 |
11 | from tqdm import tqdm
12 |
13 | from dataset.parsers import get_parser
14 | from dataset.utils import io as data_io
15 | from utils.ray import RayBundle
16 | from utils.render_buffer import RenderBuffer
17 | from utils.tensor_dataclass import TensorDataclass
18 |
19 |
20 | @gin.configurable()
21 | class RayDataset(Dataset):
22 | def __init__(
23 | self,
24 | base_path: str,
25 | scene: str = 'lego',
26 | scene_type: str = 'nerf_synthetic_multiscale',
27 | split: str = 'train',
28 | to_world: bool = True,
29 | num_rays: int = 8192,
30 | render_bkgd: str = 'white',
31 | **kwargs
32 | ):
33 | super().__init__()
34 | parser = get_parser(scene_type)
35 | data_source = parser(
36 | base_path=Path(base_path), scene=scene, split=split, **kwargs
37 | )
38 | self.training = split.find('train') >= 0
39 |
40 | self.cameras = data_source['cameras']
41 | self.ray_bundles = [c.build('cpu') for c in self.cameras]
42 |         logger.info('==> Found {} cameras'.format(len(self.cameras)))
43 | self.poses = {
44 | k: torch.tensor(np.asarray(v)).float() # Nx4x4
45 | for k, v in data_source["poses"].items()
46 | }
47 | # parallel loading frames
48 | self.frames = {}
49 | for k, cam_frames in data_source['frames'].items():
50 | with ThreadPoolExecutor(
51 | max_workers=min(multiprocessing.cpu_count(), 32)
52 | ) as executor:
53 | frames = list(
54 | tqdm(
55 | executor.map(
56 | lambda f: torch.tensor(
57 | data_io.imread(f['image_filename'])
58 | ),
59 | cam_frames,
60 | ),
61 | total=len(cam_frames),
62 | dynamic_ncols=True,
63 | )
64 | )
65 | self.frames[k] = torch.stack(frames, dim=0)
66 | self.frame_number = {k: x.shape[0] for k, x in self.frames.items()}
67 | self.aabb = torch.tensor(np.asarray(data_source['aabb'])).float()
68 | self.loss_multi = {
69 | k: torch.tensor([x['lossmult'] for x in v])
70 | for k, v in data_source['frames'].items()
71 | }
72 | self.file_names = {
73 | k: [x['image_filename'].stem for x in v]
74 | for k, v in data_source['frames'].items()
75 | }
76 | self.to_world = to_world
77 | self.num_rays = num_rays
78 | self.render_bkgd = render_bkgd
79 |
80 |         # read one sample up front to initialize the RenderBuffer subclass
81 | self[0]
82 |
83 | def __len__(self):
84 | if self.training:
85 |             return 10**9  # hack: effectively infinite length for streaming-style training
86 | else:
87 | return sum([x.shape[0] for k, x in self.poses.items()])
88 |
89 | def update_num_rays(self, num_rays):
90 | self.num_rays = num_rays
91 |
92 | @torch.no_grad()
93 | def __getitem__(self, index):
94 | if self.training:
95 | rgb, c2w, cam_rays, loss_multi = [], [], [], []
96 | for cam_idx in range(len(self.cameras)):
97 | num_rays = int(
98 | self.num_rays
99 | * (1.0 / self.loss_multi[cam_idx][0])
100 | / sum([1.0 / v[0] for _, v in self.loss_multi.items()])
101 | )
102 | idx = torch.randint(
103 | 0,
104 | self.frames[cam_idx].shape[0],
105 | size=(num_rays,),
106 | )
107 | sample_x = torch.randint(
108 | 0,
109 | self.cameras[cam_idx].width,
110 | size=(num_rays,),
111 | ) # uniform sampling
112 | sample_y = torch.randint(
113 | 0,
114 | self.cameras[cam_idx].height,
115 | size=(num_rays,),
116 | ) # uniform sampling
117 | rgb.append(self.frames[cam_idx][idx, sample_y, sample_x])
118 | c2w.append(self.poses[cam_idx][idx])
119 | cam_rays.append(self.ray_bundles[cam_idx][sample_y, sample_x])
120 | loss_multi.append(self.loss_multi[cam_idx][idx, None])
121 | rgb = torch.cat(rgb, dim=0)
122 | c2w = torch.cat(c2w, dim=0)
123 | cam_rays = RayBundle.direct_cat(cam_rays, dim=0)
124 | loss_multi = torch.cat(loss_multi, dim=0)
125 | if 'white' == self.render_bkgd:
126 | render_bkgd = torch.ones_like(rgb[..., [-1]])
127 | elif 'rand' == self.render_bkgd:
128 | render_bkgd = torch.rand_like(rgb[..., :3])
129 | elif 'randn' == self.render_bkgd:
130 | render_bkgd = (torch.randn_like(rgb[..., :3]) + 0.5).clamp(
131 | 0.0, 1.0
132 | )
133 | else:
134 | raise NotImplementedError
135 |
136 | else:
137 | for cam_idx, num in self.frame_number.items():
138 | if index < num:
139 | break
140 | index = index - num
141 | num_rays = len(self.ray_bundles[cam_idx])
142 | idx = torch.ones(size=(num_rays,), dtype=torch.int64) * index
143 | sample_x, sample_y = torch.meshgrid(
144 | torch.arange(self.cameras[cam_idx].width),
145 | torch.arange(self.cameras[cam_idx].height),
146 | indexing="xy",
147 | )
148 | sample_x = sample_x.reshape(-1)
149 | sample_y = sample_y.reshape(-1)
150 |
151 | rgb = self.frames[cam_idx][idx, sample_y, sample_x]
152 | c2w = self.poses[cam_idx][idx]
153 | cam_rays = self.ray_bundles[cam_idx][sample_y, sample_x]
154 | loss_multi = self.loss_multi[cam_idx][idx, None]
155 | render_bkgd = torch.ones_like(rgb[..., [-1]])
156 |
157 | if self.to_world:
158 | cam_rays.directions = (
159 | c2w[:, :3, :3] @ cam_rays.directions[..., None]
160 | ).squeeze(-1)
161 | cam_rays.origins = c2w[:, :3, -1]
162 | target = RenderBuffer(
163 | rgb=rgb[..., :3] * rgb[..., [-1]]
164 | + (1.0 - rgb[..., [-1]]) * render_bkgd,
165 | render_bkgd=render_bkgd,
166 | # alpha=rgb[..., [-1]],
167 | loss_multi=loss_multi,
168 | )
169 | if not self.training:
170 | cam_rays = cam_rays.reshape(
171 | (self.cameras[cam_idx].height, self.cameras[cam_idx].width)
172 | )
173 | target = target.reshape(
174 | (self.cameras[cam_idx].height, self.cameras[cam_idx].width)
175 | )
176 | outputs = {
177 | # 'c2w': c2w,
178 | 'cam_rays': cam_rays,
179 | 'target': target,
180 | # 'idx': idx,
181 | }
182 | if not self.training:
183 | outputs['name'] = self.file_names[cam_idx][index]
184 | return outputs
185 |
186 |
187 | def ray_collate(batch):
188 | res = {k: [] for k in batch[0].keys()}
189 | for data in batch:
190 | for k, v in data.items():
191 | res[k].append(v)
192 | for k, v in res.items():
193 | if isinstance(v[0], RenderBuffer) or isinstance(v[0], RayBundle):
194 | res[k] = TensorDataclass.direct_cat(v, dim=0)
195 | else:
196 | res[k] = torch.cat(v, dim=0)
197 | return res
198 |
199 |
200 | if __name__ == '__main__':
201 | training_dataset = RayDataset(
202 | # '/mnt/bn/wbhu-nerf/Dataset/nerf_synthetic',
203 | '/mnt/bn/wbhu-nerf/Dataset/nerf_synthetic_multiscale',
204 | 'lego',
205 | # 'nerf_synthetic',
206 | 'nerf_synthetic_multiscale',
207 | )
208 | train_loader = iter(
209 | DataLoader(
210 | training_dataset,
211 | batch_size=8,
212 | shuffle=False,
213 | num_workers=0,
214 | collate_fn=ray_collate,
215 | pin_memory=True,
216 | worker_init_fn=None,
217 | pin_memory_device='cuda',
218 | )
219 | )
220 | test_dataset = RayDataset(
221 | # '/mnt/bn/wbhu-nerf/Dataset/nerf_synthetic',
222 | '/mnt/bn/wbhu-nerf/Dataset/nerf_synthetic_multiscale',
223 | 'lego',
224 | # 'nerf_synthetic',
225 | 'nerf_synthetic_multiscale',
226 | num_rays=81920,
227 | split='test',
228 | )
229 | for i in tqdm(range(1000)):
230 | data = next(train_loader)
231 | pass
232 | for i in tqdm(range(len(test_dataset))):
233 | data = test_dataset[i]
234 | pass
235 | pass
236 |
--------------------------------------------------------------------------------
/dataset/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wbhu/Tri-MipRF/5b54a274d338ab64405d8123837e2d393c2262c2/dataset/utils/__init__.py
--------------------------------------------------------------------------------
/dataset/utils/cameras.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 |
5 | from utils.ray import RayBundle
6 |
7 |
8 | class PinholeCamera:
9 | def __init__(
10 | self,
11 | fx: float,
12 | fy: float,
13 | cx: float,
14 | cy: float,
15 | width: int = None,
16 | height: int = None,
17 | coord_type: str = 'opengl',
18 | device: str = 'cuda:0',
19 | normalize_ray: bool = True,
20 | ):
21 | self.fx, self.fy, self.cx, self.cy = fx, fy, cx, cy
22 | self.width, self.height = width, height
23 | self.coord_type = coord_type
24 | self.K = torch.tensor(
25 | [
26 | [self.fx, 0, self.cx],
27 | [0, self.fy, self.cy],
28 | [0, 0, 1],
29 | ],
30 | dtype=torch.float32,
31 | ) # (3, 3)
32 | self.device = device
33 | self.normalize_ray = normalize_ray
34 | self.near = 0.1
35 | self.far = 100
36 | if self.coord_type == 'opencv':
37 | self.sign_z = 1.0
38 | elif self.coord_type == 'opengl':
39 | self.sign_z = -1.0
40 | else:
41 | raise ValueError
42 | self.ray_bundle = None
43 |
44 | def build(self, device):
45 | x, y = torch.meshgrid(
46 | torch.arange(self.width, device=device),
47 | torch.arange(self.height, device=device),
48 | indexing="xy",
49 | )
50 | directions = F.pad(
51 | torch.stack(
52 | [
53 | (x - self.K[0, 2] + 0.5) / self.K[0, 0],
54 | (y - self.K[1, 2] + 0.5) / self.K[1, 1] * self.sign_z,
55 | ],
56 | dim=-1,
57 | ),
58 | (0, 1),
59 | value=self.sign_z,
60 | ) # [H,W,3]
61 |         # Distance from each direction vector to its neighbor along the x axis
62 | dx = torch.linalg.norm(
63 | (directions[:, :-1, :] - directions[:, 1:, :]),
64 | dim=-1,
65 | keepdims=True,
66 | ) # [H,W-1,1]
67 | dx = torch.cat([dx, dx[:, -2:-1, :]], 1) # [H,W,1]
68 | dy = torch.linalg.norm(
69 | (directions[:-1, :, :] - directions[1:, :, :]),
70 | dim=-1,
71 | keepdims=True,
72 | ) # [H-1,W,1]
73 | dy = torch.cat([dy, dy[-2:-1, :, :]], 0) # [H,W,1]
74 |         # Use the radius of a disc whose area matches the pixel footprint (dx * dy)
75 |         # as the pixel/cone radius at unit distance along the ray.
76 | area = dx * dy
77 | radii = torch.sqrt(area / torch.pi)
78 | if self.normalize_ray:
79 | directions = directions / torch.linalg.norm(
80 | directions, dim=-1, keepdims=True
81 | )
82 | self.ray_bundle = RayBundle(
83 | origins=torch.zeros_like(directions),
84 | directions=directions,
85 | radiis=radii,
86 | ray_cos=torch.matmul(
87 | directions,
88 | torch.tensor([[0.0, 0.0, self.sign_z]], device=device).T,
89 | ),
90 | )
91 | return self.ray_bundle
92 |
93 | @property
94 | def fov_y(self):
95 | return np.degrees(2 * np.arctan(self.cy / self.fy))
96 |
97 | def get_proj(self):
98 | # projection
99 | proj = np.eye(4, dtype=np.float32)
100 | proj[0, 0] = 2 * self.fx / self.width
101 | proj[1, 1] = 2 * self.fy / self.height
102 | proj[0, 2] = 2 * self.cx / self.width - 1
103 | proj[1, 2] = 2 * self.cy / self.height - 1
104 | proj[2, 2] = -(self.far + self.near) / (self.far - self.near)
105 | proj[2, 3] = -2 * self.far * self.near / (self.far - self.near)
106 | proj[3, 2] = -1
107 | proj[3, 3] = 0
108 | return proj
109 |
110 | def get_PVM(self, c2w):
111 | c2w = c2w.copy()
112 | # to right up backward (opengl)
113 | c2w[:3, 1] *= -1
114 | c2w[:3, 2] *= -1
115 | w2c = np.linalg.inv(c2w)
116 | return np.matmul(self.get_proj(), w2c)
117 |
--------------------------------------------------------------------------------
/dataset/utils/io.py:
--------------------------------------------------------------------------------
1 | import json
2 | import gzip
3 | import numpy as np
4 | import cv2
5 | from pathlib import Path
6 | from typing import Union, Any
7 | import open3d as o3d
8 |
9 | cv2.ocl.setUseOpenCL(False)
10 | cv2.setNumThreads(0)
11 |
12 |
13 | def load_from_json(file_path: Path):
14 | assert file_path.suffix == ".json"
15 | with open(file_path, encoding="UTF-8") as file:
16 | return json.load(file)
17 |
18 |
19 | def load_from_jgz(file_path: Path):
20 | assert file_path.suffix == ".jgz"
21 | with gzip.GzipFile(file_path, "rb") as file:
22 | return json.load(file)
23 |
24 |
25 | def write_to_json(file_path: Path, content: dict):
26 | assert file_path.suffix == ".json"
27 | with open(file_path, "w", encoding="UTF-8") as file:
28 | json.dump(content, file)
29 |
30 |
31 | def imread(file_path: Path, dtype: np.dtype = np.float32) -> np.ndarray:
32 | im = cv2.imread(str(file_path), flags=cv2.IMREAD_UNCHANGED)
33 | if 2 == len(im.shape):
34 | im = im[..., None]
35 | if 4 == im.shape[-1]:
36 | im = cv2.cvtColor(im, cv2.COLOR_BGRA2RGBA)
37 | elif 3 == im.shape[-1]:
38 | im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
39 | elif 1 == im.shape[-1]:
40 | pass
41 | else:
42 | raise NotImplementedError
43 | if dtype != np.uint8:
44 | im = im / 255.0
45 | return im.astype(dtype)
46 |
47 |
48 | def imwrite(im: np.ndarray, file_path: Path) -> None:
49 | if not file_path.parent.exists():
50 | file_path.parent.mkdir(parents=True, exist_ok=True)
51 | if len(im.shape) == 4:
52 | assert im.shape[0] == 1
53 | im = im[0]
54 | assert len(im.shape) == 3
55 | if im.dtype == np.float32:
56 | im = (im.clip(0.0, 1.0) * 255).astype(np.uint8)
57 | if 4 == im.shape[-1]:
58 | im = cv2.cvtColor(im, cv2.COLOR_RGBA2BGRA)
59 | elif 3 == im.shape[-1]:
60 | im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
61 | elif 1 == im.shape[-1]:
62 | im = im[..., 0]
63 | else:
64 | raise NotImplementedError
65 | cv2.imwrite(str(file_path), im)
66 |
67 |
68 | def write_rendering(data: Any, parrent_path: Path, name: str):
69 | if isinstance(data, np.ndarray):
70 | imwrite(data, parrent_path / (name + '.png'))
71 | elif isinstance(data, o3d.geometry.PointCloud):
72 | if not parrent_path.exists():
73 | parrent_path.mkdir(exist_ok=True, parents=True)
74 | o3d.io.write_point_cloud(str(parrent_path / (name + '.ply')), data)
75 |
--------------------------------------------------------------------------------
/dataset/utils/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 |
5 |
6 | def split_training(num_images, train_split_percentage, split):
7 | # filter image_filenames and poses based on train/eval split percentage
8 | num_train_images = math.ceil(train_split_percentage * num_images)
9 | num_test_images = num_images - num_train_images
10 | i_all = np.arange(num_images)
11 | i_train = np.linspace(
12 | 0, num_images - 1, num_train_images, dtype=int
13 | ) # equally spaced training images starting and ending at 0 and num_images-1
14 | # eval images are the remaining images
15 | i_test = np.setdiff1d(i_all, i_train)
16 | assert len(i_test) == num_test_images
17 | if split == "train":
18 | indices = i_train
19 | elif split in ["val", "test"]:
20 | indices = i_test
21 | elif split == 'all' or split == 'rendering':
22 | indices = i_all
23 | else:
24 | raise ValueError(f"Unknown dataparser split {split}")
25 | return indices
26 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from datetime import datetime
3 | import gin
4 | from loguru import logger
5 | from torch.utils.data import DataLoader
6 |
7 | from utils.common import set_random_seed
8 | from dataset.ray_dataset import RayDataset, ray_collate
9 | from neural_field.model import get_model
10 | from trainer import Trainer
11 |
12 |
13 | @gin.configurable()
14 | def main(
15 | seed: int = 42,
16 | num_workers: int = 0,
17 | train_split: str = "train",
18 | stages: str = "train_eval",
19 | batch_size: int = 16,
20 | model_name="Tri-MipRF",
21 | ):
22 | stages = list(stages.split("_"))
23 | set_random_seed(seed)
24 |
25 | logger.info("==> Init dataloader ...")
26 | train_dataset = RayDataset(split=train_split)
27 | train_loader = DataLoader(
28 | train_dataset,
29 | batch_size=batch_size,
30 | num_workers=num_workers,
31 | shuffle=False,
32 | collate_fn=ray_collate,
33 | pin_memory=True,
34 | worker_init_fn=None,
35 | pin_memory_device='cuda',
36 | prefetch_factor=2,
37 | )
38 | test_dataset = RayDataset(split='test')
39 | test_loader = DataLoader(
40 | test_dataset,
41 | batch_size=None,
42 | num_workers=1,
43 | shuffle=False,
44 | pin_memory=True,
45 | worker_init_fn=None,
46 | pin_memory_device='cuda',
47 | )
48 |
49 | logger.info("==> Init model ...")
50 | model = get_model(model_name=model_name)(aabb=train_dataset.aabb)
51 | logger.info(model)
52 |
53 | logger.info("==> Init trainer ...")
54 | trainer = Trainer(model, train_loader, eval_loader=test_loader)
55 | if "train" in stages:
56 | trainer.fit()
57 | if "eval" in stages:
58 | if "train" not in stages:
59 | trainer.load_ckpt()
60 | trainer.eval(save_results=True, rendering_channels=["rgb", "depth"])
61 |
62 |
63 | if __name__ == "__main__":
64 | parser = argparse.ArgumentParser()
65 | parser.add_argument(
66 | "--ginc",
67 | action="append",
68 | help="gin config file",
69 | )
70 | parser.add_argument(
71 | "--ginb",
72 | action="append",
73 | help="gin bindings",
74 | )
75 | args = parser.parse_args()
76 |
77 | ginbs = []
78 | if args.ginb:
79 | ginbs.extend(args.ginb)
80 | gin.parse_config_files_and_bindings(args.ginc, ginbs, finalize_config=False)
81 |
82 | exp_name = gin.query_parameter("Trainer.exp_name")
83 | exp_name = (
84 | "%s/%s/%s/%s"
85 | % (
86 | gin.query_parameter("RayDataset.scene_type"),
87 | gin.query_parameter("RayDataset.scene"),
88 | gin.query_parameter("main.model_name"),
89 | datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
90 | )
91 | if exp_name is None
92 | else exp_name
93 | )
94 | gin.bind_parameter("Trainer.exp_name", exp_name)
95 | gin.finalize()
96 | main()
97 |
--------------------------------------------------------------------------------
/neural_field/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/neural_field/encoding/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wbhu/Tri-MipRF/5b54a274d338ab64405d8123837e2d393c2262c2/neural_field/encoding/__init__.py
--------------------------------------------------------------------------------
/neural_field/encoding/tri_mip.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | import nvdiffrast.torch
5 |
6 |
7 | class TriMipEncoding(nn.Module):
8 | def __init__(
9 | self,
10 | n_levels: int,
11 | plane_size: int,
12 | feature_dim: int,
13 | include_xyz: bool = False,
14 | ):
15 | super(TriMipEncoding, self).__init__()
16 | self.n_levels = n_levels
17 | self.plane_size = plane_size
18 | self.feature_dim = feature_dim
19 | self.include_xyz = include_xyz
20 |
21 | self.register_parameter(
22 | "fm",
23 | nn.Parameter(torch.zeros(3, plane_size, plane_size, feature_dim)),
24 | )
25 | self.init_parameters()
26 | self.dim_out = (
27 | self.feature_dim * 3 + 3 if include_xyz else self.feature_dim * 3
28 | )
29 |
30 | def init_parameters(self) -> None:
31 | # Important for performance
32 | nn.init.uniform_(self.fm, -1e-2, 1e-2)
33 |
34 | def forward(self, x, level):
35 | # x in [0,1], level in [0,max_level]
36 | # x is Nx3, level is Nx1
37 | if 0 == x.shape[0]:
38 | return torch.zeros([x.shape[0], self.feature_dim * 3]).to(x)
39 | decomposed_x = torch.stack(
40 | [
41 | x[:, None, [1, 2]],
42 | x[:, None, [0, 2]],
43 | x[:, None, [0, 1]],
44 | ],
45 | dim=0,
46 | ) # 3xNx1x2
47 | if 0 == self.n_levels:
48 | level = None
49 | else:
50 | # assert level.shape[0] > 0, [level.shape, x.shape]
51 |             level = torch.stack([level, level, level], dim=0)
52 | level = torch.broadcast_to(
53 | level, decomposed_x.shape[:3]
54 | ).contiguous()
55 | enc = nvdiffrast.torch.texture(
56 | self.fm,
57 | decomposed_x,
58 | mip_level_bias=level,
59 | boundary_mode="clamp",
60 | max_mip_level=self.n_levels - 1,
61 | ) # 3xNx1xC
62 | enc = (
63 | enc.permute(1, 2, 0, 3)
64 | .contiguous()
65 | .view(
66 | x.shape[0],
67 | self.feature_dim * 3,
68 | )
69 | ) # Nx(3C)
70 | if self.include_xyz:
71 | enc = torch.cat([x, enc], dim=-1)
72 | return enc
73 |
--------------------------------------------------------------------------------
/neural_field/field/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wbhu/Tri-MipRF/5b54a274d338ab64405d8123837e2d393c2262c2/neural_field/field/__init__.py
--------------------------------------------------------------------------------
/neural_field/field/trimipRF.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Callable
3 |
4 | import gin
5 | import torch
6 | from torch import Tensor, nn
7 | import tinycudann as tcnn
8 |
9 | from neural_field.encoding.tri_mip import TriMipEncoding
10 | from neural_field.nn_utils.activations import trunc_exp
11 |
12 |
13 | @gin.configurable()
14 | class TriMipRF(nn.Module):
15 | def __init__(
16 | self,
17 | n_levels: int = 8,
18 | plane_size: int = 512,
19 | feature_dim: int = 16,
20 | geo_feat_dim: int = 15,
21 | net_depth_base: int = 2,
22 | net_depth_color: int = 4,
23 | net_width: int = 128,
24 | density_activation: Callable = lambda x: trunc_exp(x - 1),
25 | ) -> None:
26 | super().__init__()
27 | self.plane_size = plane_size
28 | self.log2_plane_size = math.log2(plane_size)
29 | self.geo_feat_dim = geo_feat_dim
30 | self.density_activation = density_activation
31 |
32 | self.encoding = TriMipEncoding(n_levels, plane_size, feature_dim)
33 | self.direction_encoding = tcnn.Encoding(
34 | n_input_dims=3,
35 | encoding_config={
36 | "otype": "SphericalHarmonics",
37 | "degree": 4,
38 | },
39 | )
40 | self.mlp_base = tcnn.Network(
41 | n_input_dims=self.encoding.dim_out,
42 | n_output_dims=geo_feat_dim + 1,
43 | network_config={
44 | "otype": "FullyFusedMLP",
45 | "activation": "ReLU",
46 | "output_activation": "None",
47 | "n_neurons": net_width,
48 | "n_hidden_layers": net_depth_base,
49 | },
50 | )
51 | self.mlp_head = tcnn.Network(
52 | n_input_dims=self.direction_encoding.n_output_dims + geo_feat_dim,
53 | n_output_dims=3,
54 | network_config={
55 | "otype": "FullyFusedMLP",
56 | "activation": "ReLU",
57 | "output_activation": "Sigmoid",
58 | "n_neurons": net_width,
59 | "n_hidden_layers": net_depth_color,
60 | },
61 | )
62 |
63 | def query_density(
64 | self, x: Tensor, level_vol: Tensor, return_feat: bool = False
65 | ):
66 | level = (
67 | level_vol if level_vol is None else level_vol + self.log2_plane_size
68 | )
69 | selector = ((x > 0.0) & (x < 1.0)).all(dim=-1)
70 | enc = self.encoding(
71 | x.view(-1, 3),
72 | level=level.view(-1, 1),
73 | )
74 | x = (
75 | self.mlp_base(enc)
76 | .view(list(x.shape[:-1]) + [1 + self.geo_feat_dim])
77 | .to(x)
78 | )
79 | density_before_activation, base_mlp_out = torch.split(
80 | x, [1, self.geo_feat_dim], dim=-1
81 | )
82 | density = (
83 | self.density_activation(density_before_activation)
84 | * selector[..., None]
85 | )
86 | return {
87 | "density": density,
88 | "feature": base_mlp_out if return_feat else None,
89 | }
90 |
91 | def query_rgb(self, dir, embedding):
92 | # dir in [-1,1]
93 | dir = (dir + 1.0) / 2.0 # SH encoding must be in the range [0, 1]
94 | d = self.direction_encoding(dir.view(-1, dir.shape[-1]))
95 | h = torch.cat([d, embedding.view(-1, self.geo_feat_dim)], dim=-1)
96 | rgb = (
97 | self.mlp_head(h)
98 | .view(list(embedding.shape[:-1]) + [3])
99 | .to(embedding)
100 | )
101 | return {"rgb": rgb}
102 |
--------------------------------------------------------------------------------
/neural_field/model/RFModel.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Union, List, Dict
3 |
4 | import gin
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 | from torchmetrics.functional import peak_signal_noise_ratio
9 |
10 | from utils.ray import RayBundle
11 | from utils.render_buffer import RenderBuffer
12 |
13 |
14 | # @gin.configurable()
15 | class RFModel(nn.Module):
16 | def __init__(
17 | self,
18 | aabb: Union[torch.Tensor, List[float]],
19 | samples_per_ray: int = 1024,
20 | ) -> None:
21 | super().__init__()
22 | if not isinstance(aabb, torch.Tensor):
23 | aabb = torch.tensor(aabb, dtype=torch.float32)
24 | self.register_buffer("aabb", aabb)
25 | self.samples_per_ray = samples_per_ray
26 | self.render_step_size = (
27 | (self.aabb[3:] - self.aabb[:3]).max()
28 | * math.sqrt(3)
29 | / samples_per_ray
30 | ).item()
31 | aabb_min, aabb_max = torch.split(self.aabb, 3, dim=-1)
32 | self.aabb_size = aabb_max - aabb_min
33 | assert (
34 | self.aabb_size[0] == self.aabb_size[1] == self.aabb_size[2]
35 | ), "Current implementation only supports cube aabb"
36 | self.field = None
37 | self.ray_sampler = None
38 |
39 | def contraction(self, x):
40 | aabb_min, aabb_max = self.aabb[:3].unsqueeze(0), self.aabb[
41 | 3:
42 | ].unsqueeze(0)
43 | x = (x - aabb_min) / (aabb_max - aabb_min)
44 | return x
45 |
46 | def before_iter(self, step):
47 | pass
48 |
49 | def after_iter(self, step):
50 | pass
51 |
52 | def forward(
53 | self,
54 | rays: RayBundle,
55 | background_color=None,
56 | ):
57 | raise NotImplementedError
58 |
59 | @gin.configurable()
60 | def get_optimizer(
61 | self, lr=1e-3, weight_decay=1e-5, feature_lr_scale=10.0, **kwargs
62 | ):
63 | raise NotImplementedError
64 |
65 | @gin.configurable()
66 | def compute_loss(
67 | self,
68 | rays: RayBundle,
69 | rb: RenderBuffer,
70 | target: RenderBuffer,
71 | # Configurable
72 | metric='smooth_l1',
73 | **kwargs
74 | ) -> Dict:
75 | if 'smooth_l1' == metric:
76 | loss_fn = F.smooth_l1_loss
77 | elif 'mse' == metric:
78 | loss_fn = F.mse_loss
79 | elif 'mae' == metric:
80 | loss_fn = F.l1_loss
81 | else:
82 | raise NotImplementedError
83 |
84 | alive_ray_mask = (rb.alpha.squeeze(-1) > 0).detach()
85 | loss = loss_fn(
86 | rb.rgb[alive_ray_mask], target.rgb[alive_ray_mask], reduction='none'
87 | )
88 | loss = (
89 | loss * target.loss_multi[alive_ray_mask]
90 | ).sum() / target.loss_multi[alive_ray_mask].sum()
91 | return {'total_loss': loss}
92 |
93 | @gin.configurable()
94 | def compute_metrics(
95 | self,
96 | rays: RayBundle,
97 | rb: RenderBuffer,
98 | target: RenderBuffer,
99 | # Configurable
100 | **kwargs
101 | ) -> Dict:
102 | # ray info
103 | alive_ray_mask = (rb.alpha.squeeze(-1) > 0).detach()
104 | rendering_samples_actual = rb.num_samples[0].item()
105 | ray_info = {
106 | 'num_alive_ray': alive_ray_mask.long().sum().item(),
107 | 'rendering_samples_actual': rendering_samples_actual,
108 | 'num_rays': len(target),
109 | }
110 | # quality
111 | quality = {'PSNR': peak_signal_noise_ratio(rb.rgb, target.rgb).item()}
112 | return {**ray_info, **quality}
113 |
--------------------------------------------------------------------------------
/neural_field/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .RFModel import RFModel
2 | from .trimipRF import TriMipRFModel
3 |
4 |
5 | def get_model(model_name: str = 'Tri-MipRF') -> RFModel:
6 | if 'Tri-MipRF' == model_name:
7 | return TriMipRFModel
8 | else:
9 | raise NotImplementedError
10 |
--------------------------------------------------------------------------------
/neural_field/model/trimipRF.py:
--------------------------------------------------------------------------------
1 | from typing import Union, List, Optional, Callable
2 |
3 | import gin
4 | import torch
5 | import nerfacc
6 | from nerfacc import render_weight_from_density, accumulate_along_rays
7 |
8 | from neural_field.model.RFModel import RFModel
9 | from utils.ray import RayBundle
10 | from utils.render_buffer import RenderBuffer
11 | from neural_field.field.trimipRF import TriMipRF
12 |
13 |
14 | @gin.configurable()
15 | class TriMipRFModel(RFModel):
16 | def __init__(
17 | self,
18 | aabb: Union[torch.Tensor, List[float]],
19 | samples_per_ray: int = 1024,
20 | occ_grid_resolution: int = 128,
21 | ) -> None:
22 | super().__init__(aabb=aabb, samples_per_ray=samples_per_ray)
23 | self.field = TriMipRF()
24 | self.ray_sampler = nerfacc.OccupancyGrid(
25 | roi_aabb=self.aabb, resolution=occ_grid_resolution
26 | )
27 |
28 | self.feature_vol_radii = self.aabb_size[0] / 2.0
29 | self.register_buffer(
30 | "occ_level_vol",
31 | torch.log2(
32 | self.aabb_size[0]
33 | / occ_grid_resolution
34 | / 2.0
35 | / self.feature_vol_radii
36 | ),
37 | )
38 |
39 | def before_iter(self, step):
40 | # update_ray_sampler
41 | self.ray_sampler.every_n_step(
42 | step=step,
43 | occ_eval_fn=lambda x: self.field.query_density(
44 | x=self.contraction(x),
45 | level_vol=torch.empty_like(x[..., 0]).fill_(self.occ_level_vol),
46 | )['density']
47 | * self.render_step_size,
48 | occ_thre=5e-3,
49 | )
50 |
51 | @staticmethod
52 | def compute_ball_radii(distance, radiis, cos):
53 | inverse_cos = 1.0 / cos
54 | tmp = (inverse_cos * inverse_cos - 1).sqrt() - radiis
55 | sample_ball_radii = distance * radiis * cos / (tmp * tmp + 1.0).sqrt()
56 | return sample_ball_radii
57 |
58 | def forward(
59 | self,
60 | rays: RayBundle,
61 | background_color=None,
62 | alpha_thre=0.0,
63 | ray_marching_aabb=None,
64 | ):
65 | # Ray sampling with occupancy grid
66 | with torch.no_grad():
67 |
68 | def sigma_fn(t_starts, t_ends, ray_indices):
69 | ray_indices = ray_indices.long()
70 | t_origins = rays.origins[ray_indices]
71 | t_dirs = rays.directions[ray_indices]
72 | radiis = rays.radiis[ray_indices]
73 | cos = rays.ray_cos[ray_indices]
74 | distance = (t_starts + t_ends) / 2.0
75 | positions = t_origins + t_dirs * distance
76 | positions = self.contraction(positions)
77 | sample_ball_radii = self.compute_ball_radii(
78 | distance, radiis, cos
79 | )
80 | level_vol = torch.log2(
81 | sample_ball_radii / self.feature_vol_radii
82 |                 )  # the true mip level also adds log2(plane_size); applied in TriMipRF.query_density
83 | return self.field.query_density(positions, level_vol)['density']
84 |
85 | ray_indices, t_starts, t_ends = nerfacc.ray_marching(
86 | rays.origins,
87 | rays.directions,
88 | scene_aabb=self.aabb,
89 | grid=self.ray_sampler,
90 | sigma_fn=sigma_fn,
91 | render_step_size=self.render_step_size,
92 | stratified=self.training,
93 | early_stop_eps=1e-4,
94 | )
95 |
96 | # Ray rendering
97 | def rgb_sigma_fn(t_starts, t_ends, ray_indices):
98 | t_origins = rays.origins[ray_indices]
99 | t_dirs = rays.directions[ray_indices]
100 | radiis = rays.radiis[ray_indices]
101 | cos = rays.ray_cos[ray_indices]
102 | distance = (t_starts + t_ends) / 2.0
103 | positions = t_origins + t_dirs * distance
104 | positions = self.contraction(positions)
105 | sample_ball_radii = self.compute_ball_radii(distance, radiis, cos)
106 | level_vol = torch.log2(
107 | sample_ball_radii / self.feature_vol_radii
108 |             )  # the true mip level also adds log2(plane_size); applied in TriMipRF.query_density
109 | res = self.field.query_density(
110 | x=positions,
111 | level_vol=level_vol,
112 | return_feat=True,
113 | )
114 | density, feature = res['density'], res['feature']
115 | rgb = self.field.query_rgb(dir=t_dirs, embedding=feature)['rgb']
116 | return rgb, density
117 |
118 | return self.rendering(
119 | t_starts,
120 | t_ends,
121 | ray_indices,
122 | rays,
123 | rgb_sigma_fn=rgb_sigma_fn,
124 | render_bkgd=background_color,
125 | )
126 |
127 | def rendering(
128 | self,
129 | # ray marching results
130 | t_starts: torch.Tensor,
131 | t_ends: torch.Tensor,
132 | ray_indices: torch.Tensor,
133 | rays: RayBundle,
134 | # radiance field
135 | rgb_sigma_fn: Callable = None, # rendering options
136 | render_bkgd: Optional[torch.Tensor] = None,
137 | ) -> RenderBuffer:
138 | n_rays = rays.origins.shape[0]
139 | # Query sigma/alpha and color with gradients
140 | rgbs, sigmas = rgb_sigma_fn(t_starts, t_ends, ray_indices.long())
141 |
142 | # Rendering
143 | weights = render_weight_from_density(
144 | t_starts,
145 | t_ends,
146 | sigmas,
147 | ray_indices=ray_indices,
148 | n_rays=n_rays,
149 | )
150 | sample_buffer = {
151 | 'num_samples': torch.as_tensor(
152 | [len(t_starts)], dtype=torch.int32, device=rgbs.device
153 | ),
154 | }
155 |
156 | # Rendering: accumulate rgbs, opacities, and depths along the rays.
157 | colors = accumulate_along_rays(
158 | weights, ray_indices=ray_indices, values=rgbs, n_rays=n_rays
159 | )
160 | opacities = accumulate_along_rays(
161 | weights, values=None, ray_indices=ray_indices, n_rays=n_rays
162 | )
163 | opacities.clamp_(
164 | 0.0, 1.0
165 |         )  # sometimes it may be slightly bigger than 1.0, which would lead to abnormal behaviours
166 |
167 | depths = accumulate_along_rays(
168 | weights,
169 | ray_indices=ray_indices,
170 | values=(t_starts + t_ends) / 2.0,
171 | n_rays=n_rays,
172 | )
173 | depths = (
174 | depths * rays.ray_cos
175 | ) # from distance to real depth (z value in camera space)
176 |
177 | # Background composition.
178 | if render_bkgd is not None:
179 | colors = colors + render_bkgd * (1.0 - opacities)
180 |
181 | return RenderBuffer(
182 | rgb=colors,
183 | alpha=opacities,
184 | depth=depths,
185 | **sample_buffer,
186 | _static_field=set(sample_buffer),
187 | )
188 |
189 | @gin.configurable()
190 | def get_optimizer(
191 | self, lr=2e-3, weight_decay=1e-5, feature_lr_scale=10.0, **kwargs
192 | ):
193 | params_list = []
194 | params_list.append(
195 | dict(
196 | params=self.field.encoding.parameters(),
197 | lr=lr * feature_lr_scale,
198 | )
199 | )
200 | params_list.append(
201 | dict(params=self.field.direction_encoding.parameters(), lr=lr)
202 | )
203 | params_list.append(dict(params=self.field.mlp_base.parameters(), lr=lr))
204 | params_list.append(dict(params=self.field.mlp_head.parameters(), lr=lr))
205 |
206 | optim = torch.optim.AdamW(
207 | params_list,
208 | weight_decay=weight_decay,
209 | **kwargs,
210 | eps=1e-15,
211 | )
212 | return optim
213 |
--------------------------------------------------------------------------------
/neural_field/nn_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wbhu/Tri-MipRF/5b54a274d338ab64405d8123837e2d393c2262c2/neural_field/nn_utils/__init__.py
--------------------------------------------------------------------------------
/neural_field/nn_utils/activations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | from torch.cuda.amp import custom_bwd, custom_fwd
4 |
5 |
6 | class TruncExp(Function): # pylint: disable=abstract-method
7 | # Implementation from torch-ngp:
8 | # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
9 | @staticmethod
10 | @custom_fwd(cast_inputs=torch.float32)
11 | def forward(ctx, x): # pylint: disable=arguments-differ
12 | ctx.save_for_backward(x)
13 | return torch.exp(x)
14 |
15 | @staticmethod
16 | @custom_bwd
17 | def backward(ctx, g): # pylint: disable=arguments-differ
18 | x = ctx.saved_tensors[0]
19 | return g * torch.exp(torch.clamp(x, min=-15, max=15))
20 |
21 |
22 | trunc_exp = TruncExp.apply
23 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 80
3 | skip-string-normalization = true
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | av==9.2.0
2 | beautifulsoup4==4.11.1
3 | entrypoints==0.4
4 | gdown==4.5.1
5 | gin-config==0.5.0
6 | h5py==3.7.0
7 | imageio==2.21.1
8 | imageio-ffmpeg
9 | ipython==7.19.0
10 | kornia==0.6.8
11 | loguru==0.6.0
12 | lpips==0.1.4
13 | mediapy==1.1.0
14 | mmcv==1.6.2
15 | ninja==1.10.2.3
16 | numpy==1.23.3
17 | open3d==0.16.0
18 | opencv-python==4.6.0.66
19 | pandas==1.5.0
20 | Pillow==9.2.0
21 | plotly==5.7.0
22 | pycurl==7.43.0.6
23 | PyMCubes==0.1.2
24 | pyransac3d==0.6.0
25 | PyYAML==6.0
26 | rich==12.6.0
27 | scipy==1.9.2
28 | tensorboard==2.9.0
29 | torch-fidelity==0.3.0
30 | torchmetrics==0.10.0
31 | torchtyping==0.1.4
32 | tqdm==4.64.1
33 | tyro==0.3.25
34 | appdirs
35 | nerfacc==0.3.5
36 | plyfile
37 | scikit-image
38 | trimesh
39 | torch_efficient_distloss
40 | umsgpack
41 | pyngrok
42 | cryptography==39.0.2
43 | omegaconf==2.2.3
44 | segmentation-refinement
45 | xatlas
46 | protobuf==3.20.0
47 | jinja2
48 | click==8.1.7
49 | tensorboardx
50 | termcolor
--------------------------------------------------------------------------------
/scripts/convert_blender_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from os import path
4 |
5 | from absl import app
6 | from absl import flags
7 | import numpy as np
8 | from PIL import Image
9 |
10 | FLAGS = flags.FLAGS
11 |
12 | flags.DEFINE_string('blenderdir', None, 'Base directory for all Blender data.')
13 | flags.DEFINE_string('outdir', None, 'Where to save multiscale data.')
14 | flags.DEFINE_integer('n_down', 4, 'How many levels of downscaling to use.')
15 |
16 |
17 | def load_renderings(data_dir, split):
18 | """Load images and metadata from disk."""
19 | f = 'transforms_{}.json'.format(split)
20 | with open(path.join(data_dir, f), 'r') as fp:
21 | meta = json.load(fp)
22 | images = []
23 | cams = []
24 | print('Loading imgs')
25 | for frame in meta['frames']:
26 | fname = os.path.join(data_dir, frame['file_path'] + '.png')
27 | with open(fname, 'rb') as imgin:
28 | image = np.array(Image.open(imgin), dtype=np.float32) / 255.0
29 | cams.append(frame['transform_matrix'])
30 | images.append(image)
31 | ret = {}
32 | ret['images'] = np.stack(images, axis=0)
33 | print('Loaded all images, shape is', ret['images'].shape)
34 | ret['camtoworlds'] = np.stack(cams, axis=0)
35 | w = ret['images'].shape[2]
36 | camera_angle_x = float(meta['camera_angle_x'])
37 | ret['focal'] = 0.5 * w / np.tan(0.5 * camera_angle_x)
38 | return ret
39 |
40 |
41 | def down2(img):
42 | sh = img.shape
43 | return np.mean(np.reshape(img, [sh[0] // 2, 2, sh[1] // 2, 2, -1]), (1, 3))
44 |
45 |
46 | def convert_to_nerfdata(basedir, newdir, n_down):
47 | """Convert Blender data to multiscale."""
48 | if not os.path.exists(newdir):
49 | os.makedirs(newdir)
50 | splits = ['train', 'val', 'test']
51 | bigmeta = {}
52 |     # For each split in the dataset

53 | for split in splits:
54 | print('Split', split)
55 | # Load everything
56 | data = load_renderings(basedir, split)
57 |
58 | # Save out all the images
59 | imgdir = 'images_{}'.format(split)
60 | os.makedirs(os.path.join(newdir, imgdir), exist_ok=True)
61 | fnames = []
62 | widths = []
63 | heights = []
64 | focals = []
65 | cam2worlds = []
66 | lossmults = []
67 | labels = []
68 | nears, fars = [], []
69 | f = data['focal']
70 | print('Saving images')
71 | for i, img in enumerate(data['images']):
72 | for j in range(n_down):
73 | fname = '{}/{:03d}_d{}.png'.format(imgdir, i, j)
74 | fnames.append(fname)
75 | fname = os.path.join(newdir, fname)
76 | with open(fname, 'wb') as imgout:
77 | img8 = Image.fromarray(np.uint8(img * 255))
78 | img8.save(imgout)
79 | widths.append(img.shape[1])
80 | heights.append(img.shape[0])
81 | focals.append(f / 2**j)
82 | cam2worlds.append(data['camtoworlds'][i].tolist())
83 | lossmults.append(4.0**j)
84 | labels.append(j)
85 | nears.append(2.0)
86 | fars.append(6.0)
87 | img = down2(img)
88 |
89 | # Create metadata
90 | meta = {}
91 | meta['file_path'] = fnames
92 | meta['cam2world'] = cam2worlds
93 | meta['width'] = widths
94 | meta['height'] = heights
95 | meta['focal'] = focals
96 | meta['label'] = labels
97 | meta['near'] = nears
98 | meta['far'] = fars
99 | meta['lossmult'] = lossmults
100 |
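        # Build per-image pix2cam matrices (a note added for clarity): K^{-1} maps
        # homogeneous pixel coordinates [u, v, 1] to camera-space ray directions.
        # The negated y and z rows follow the Blender/OpenGL camera convention
        # (camera looking down -z), as used by the multiscale Blender data format.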
101 | fx = np.array(focals)
102 | fy = np.array(focals)
103 | cx = np.array(meta['width']) * 0.5
104 | cy = np.array(meta['height']) * 0.5
105 | arr0 = np.zeros_like(cx)
106 | arr1 = np.ones_like(cx)
107 | k_inv = np.array(
108 | [
109 | [arr1 / fx, arr0, -cx / fx],
110 | [arr0, -arr1 / fy, cy / fy],
111 | [arr0, arr0, -arr1],
112 | ]
113 | )
114 | k_inv = np.moveaxis(k_inv, -1, 0)
115 | meta['pix2cam'] = k_inv.tolist()
116 |
117 | bigmeta[split] = meta
118 |
119 | for k in bigmeta:
120 | for j in bigmeta[k]:
121 | print(k, j, type(bigmeta[k][j]), np.array(bigmeta[k][j]).shape)
122 |
123 | jsonfile = os.path.join(newdir, 'metadata.json')
124 | with open(jsonfile, 'w') as f:
125 | json.dump(bigmeta, f, ensure_ascii=False, indent=4)
126 |
127 |
128 | def main(unused_argv):
129 |
130 | blenderdir = FLAGS.blenderdir
131 | outdir = FLAGS.outdir
132 | n_down = FLAGS.n_down
133 | if not os.path.exists(outdir):
134 | os.makedirs(outdir)
135 |
136 | dirs = [os.path.join(blenderdir, f) for f in os.listdir(blenderdir)]
137 | dirs = [d for d in dirs if os.path.isdir(d)]
138 | print(dirs)
139 | for basedir in dirs:
140 | print()
141 | newdir = os.path.join(outdir, os.path.basename(basedir))
142 | print('Converting from', basedir, 'to', newdir)
143 | convert_to_nerfdata(basedir, newdir, n_down)
144 |
145 |
146 | if __name__ == '__main__':
147 | app.run(main)
148 |
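# Example invocation (paths are illustrative):
#   python scripts/convert_blender_data.py \
#       --blenderdir /path/to/nerf_synthetic \
#       --outdir /path/to/nerf_synthetic_multiscale \
#       --n_down 4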
--------------------------------------------------------------------------------
/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainer import Trainer
2 |
--------------------------------------------------------------------------------
/trainer/trainer.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import gin
4 | import numpy as np
5 | import torch
6 | from pathlib import Path
7 | from loguru import logger
8 | from torch.utils.data import DataLoader
9 | from tqdm import tqdm
10 | from typing import Dict, List
11 |
12 | from neural_field.model.RFModel import RFModel
13 | from utils.writer import TensorboardWriter
14 | from utils.colormaps import apply_depth_colormap
15 | import dataset.utils.io as data_io
16 |
17 |
18 | @gin.configurable()
19 | class Trainer:
20 | def __init__(
21 | self,
22 | model: RFModel,
23 | train_loader: DataLoader,
24 | eval_loader: DataLoader,
25 | # configurable
26 | base_exp_dir: str = 'experiments',
27 | exp_name: str = 'Tri-MipRF',
28 | max_steps: int = 50000,
29 | log_step: int = 500,
30 | eval_step: int = 500,
31 | target_sample_batch_size: int = 65536,
32 | test_chunk_size: int = 8192,
33 | dynamic_batch_size: bool = True,
34 | num_rays: int = 8192,
35 | varied_eval_img: bool = True,
36 | ):
37 | self.model = model.cuda()
38 | self.train_loader = train_loader
39 | self.eval_loader = eval_loader
40 |
41 | self.max_steps = max_steps
42 | self.target_sample_batch_size = target_sample_batch_size
43 | # exp_dir
44 | self.exp_dir = Path(base_exp_dir) / exp_name
45 | self.log_step = log_step
46 | self.eval_step = eval_step
47 | self.test_chunk_size = test_chunk_size
48 | self.dynamic_batch_size = dynamic_batch_size
49 | self.num_rays = num_rays
50 | self.varied_eval_img = varied_eval_img
51 |
52 | self.writer = TensorboardWriter(log_dir=self.exp_dir)
53 |
54 | self.optimizer = self.model.get_optimizer()
55 | self.scheduler = self.get_scheduler()
56 | self.grad_scaler = torch.cuda.amp.GradScaler(2**10)
57 |
58 | # Save configure
59 | conf = gin.operative_config_str()
60 | logger.info(conf)
61 | self.save_config(conf)
62 |
63 | def train_iter(self, step: int, data: Dict, logging=False):
64 | tic = time.time()
65 | cam_rays = data['cam_rays']
66 | num_rays = min(self.num_rays, len(cam_rays))
67 | cam_rays = cam_rays[:num_rays].cuda(non_blocking=True)
68 | target = data['target'][:num_rays].cuda(non_blocking=True)
69 |
70 | rb = self.model(cam_rays, target.render_bkgd)
71 |
72 | # compute loss
73 | loss_dict = self.model.compute_loss(cam_rays, rb, target)
74 | metrics = self.model.compute_metrics(cam_rays, rb, target)
75 | if 0 == metrics.get("rendering_samples_actual", -1):
76 | return metrics
77 |
78 | # update
79 | self.optimizer.zero_grad()
80 | self.grad_scaler.scale(loss_dict['total_loss']).backward()
81 | self.optimizer.step()
82 |
83 | # logging
84 | if logging:
85 | with torch.no_grad():
86 | iter_time = time.time() - tic
87 | remaining_time = (self.max_steps - step) * iter_time
88 | status = {
89 | 'lr': self.optimizer.param_groups[0]["lr"],
90 | 'step': step,
91 | 'iter_time': iter_time,
92 | 'ETA': remaining_time,
93 | }
94 | self.writer.write_scalar_dicts(
95 | ['loss', 'metrics', 'status'],
96 | [
97 | {k: v.item() for k, v in loss_dict.items()},
98 | metrics,
99 | status,
100 | ],
101 | step,
102 | )
103 | return metrics
104 |
105 | def fit(self):
106 | logger.info("==> Start training ...")
107 |
108 | iter_train_loader = iter(self.train_loader)
109 | iter_eval_loader = iter(self.eval_loader)
110 | eval_0 = next(iter_eval_loader)
111 | self.model.train()
112 | for step in range(self.max_steps):
113 | self.model.before_iter(step)
114 | metrics = self.train_iter(
115 | step,
116 | data=next(iter_train_loader),
117 | logging=(step % self.log_step == 0 and step > 0)
118 | or (step == 100),
119 | )
120 | if 0 == metrics.get("rendering_samples_actual", -1):
121 | continue
122 |
123 | self.scheduler.step()
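            # Dynamic batch sizing: rescale the number of rays per batch so that
            # the total number of volume-rendering samples stays close to
            # target_sample_batch_size.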
124 | if self.dynamic_batch_size:
125 | rendering_samples_actual = metrics.get(
126 | "rendering_samples_actual",
127 | self.target_sample_batch_size,
128 | )
129 | self.num_rays = (
130 | self.num_rays
131 | * self.target_sample_batch_size
132 | // rendering_samples_actual
133 | + 1
134 | )
135 |
136 | self.model.after_iter(step)
137 |
138 | if step > 0 and step % self.eval_step == 0:
139 | self.model.eval()
140 | metrics, final_rb, target = self.eval_img(
141 | next(iter_eval_loader) if self.varied_eval_img else eval_0,
142 | compute_metrics=True,
143 | )
144 | self.writer.write_scalar_dicts(['eval'], [metrics], step)
145 | self.writer.write_image('eval/rgb', final_rb.rgb, step)
146 | self.writer.write_image('gt/rgb', target.rgb, step)
147 | self.writer.write_image(
148 | 'eval/depth',
149 | apply_depth_colormap(final_rb.depth),
150 | step,
151 | )
152 | self.writer.write_image('eval/alpha', final_rb.alpha, step)
153 |
154 | self.model.train()
155 |
156 | logger.info('==> Training done!')
157 | self.save_ckpt()
158 |
159 | @torch.no_grad()
160 | def eval_img(self, data, compute_metrics=True):
161 | cam_rays = data['cam_rays'].cuda(non_blocking=True)
162 | target = data['target'].cuda(non_blocking=True)
163 |
164 | final_rb = None
165 | flatten_rays = cam_rays.reshape(-1)
166 | flatten_target = target.reshape(-1)
167 | for i in range(0, len(cam_rays), self.test_chunk_size):
168 | rb = self.model(
169 | flatten_rays[i : i + self.test_chunk_size],
170 | flatten_target[i : i + self.test_chunk_size].render_bkgd,
171 | )
172 | final_rb = rb if final_rb is None else final_rb.cat(rb)
173 | final_rb = final_rb.reshape(cam_rays.shape)
174 | metrics = None
175 | if compute_metrics:
176 | metrics = self.model.compute_metrics(cam_rays, final_rb, target)
177 | return metrics, final_rb, target
178 |
179 | @torch.no_grad()
180 | def eval(
181 | self,
182 | save_results: bool = False,
183 | rendering_channels: List[str] = ["rgb"],
184 | ):
185 | # ipdb.set_trace()
186 | logger.info("==> Start evaluation on testset ...")
187 | if save_results:
188 | res_dir = self.exp_dir / 'rendering'
189 | res_dir.mkdir(parents=True, exist_ok=True)
190 | results = {"names": []}
191 | results.update({k: [] for k in rendering_channels})
192 |
193 | self.model.eval()
194 | metrics = []
195 | for idx, data in enumerate(tqdm(self.eval_loader)):
196 | metric, rb, target = self.eval_img(data)
197 | metrics.append(metric)
198 | if save_results:
199 | results["names"].append(data['name'])
200 | for channel in rendering_channels:
201 | if hasattr(rb, channel):
202 | values = getattr(rb, channel).cpu().numpy()
203 | if 'depth' == channel:
204 | values = (values * 10000.0).astype(
205 | np.uint16
206 | ) # scale the depth by 10k, and save it as uint16 png images
207 | results[channel].append(values)
208 | else:
209 | raise NotImplementedError
210 | del rb
211 | if save_results:
212 | for idx, name in enumerate(tqdm(results['names'])):
213 | for channel in rendering_channels:
214 | channel_path = res_dir / channel
215 | data = results[channel][idx]
216 | data_io.write_rendering(data, channel_path, name)
217 |
218 | metrics = {k: [dct[k] for dct in metrics] for k in metrics[0]}
219 | logger.info("==> Evaluation done")
220 | for k, v in metrics.items():
221 | metrics[k] = sum(v) / len(v)
222 | self.writer.write_scalar_dicts(['benchmark'], [metrics], 0)
223 | self.writer.tb_writer.close()
224 |
225 | def save_config(self, config):
226 | dest = self.exp_dir / 'config.gin'
227 | if dest.exists():
228 | return
229 | self.exp_dir.mkdir(parents=True, exist_ok=True)
230 | with open(self.exp_dir / 'config.gin', 'w') as f:
231 | f.write(config)
232 | md_config_str = gin.config.markdown(config)
233 | self.writer.write_config(md_config_str)
234 | self.writer.tb_writer.flush()
235 |
236 | def save_ckpt(self):
237 | dest = self.exp_dir / 'model.ckpt'
238 | logger.info('==> Saving checkpoints to ' + str(dest))
239 | torch.save(
240 | {
241 | "model": self.model.state_dict(),
242 | },
243 | dest,
244 | )
245 |
246 | def load_ckpt(self):
247 | dest = self.exp_dir / 'model.ckpt'
248 | loaded_state = torch.load(dest, map_location="cpu")
249 | logger.info('==> Loading checkpoints from ' + str(dest))
250 | self.model.load_state_dict(loaded_state['model'])
251 |
252 | @gin.configurable()
253 | def get_scheduler(self, gamma=0.6, **kwargs):
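        # Step-decay schedule: the learning rate is multiplied by `gamma` at
        # 1/2, 3/4, 5/6 and 9/10 of max_steps.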
254 | scheduler = torch.optim.lr_scheduler.MultiStepLR(
255 | self.optimizer,
256 | milestones=[
257 | self.max_steps // 2,
258 | self.max_steps * 3 // 4,
259 | self.max_steps * 5 // 6,
260 | self.max_steps * 9 // 10,
261 | ],
262 | gamma=gamma,
263 | **kwargs,
264 | )
265 | return scheduler
266 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wbhu/Tri-MipRF/5b54a274d338ab64405d8123837e2d393c2262c2/utils/__init__.py
--------------------------------------------------------------------------------
/utils/colormaps.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | from matplotlib import cm
5 | from torchtyping import TensorType
6 |
7 | WHITE = torch.tensor([1.0, 1.0, 1.0])
8 | BLACK = torch.tensor([0.0, 0.0, 0.0])
9 | RED = torch.tensor([1.0, 0.0, 0.0])
10 | GREEN = torch.tensor([0.0, 1.0, 0.0])
11 | BLUE = torch.tensor([0.0, 0.0, 1.0])
12 |
13 |
14 | def apply_colormap(
15 | image: TensorType["bs":..., 1],
16 | cmap="viridis",
17 | ) -> TensorType["bs":..., "rgb":3]:
18 | """Convert single channel to a color image.
19 | Args:
20 | image: Single channel image.
21 | cmap: Colormap for image.
22 | Returns:
23 | TensorType: Colored image
24 | """
25 |
26 | colormap = cm.get_cmap(cmap)
27 | colormap = torch.tensor(colormap.colors).to(image.device) # type: ignore
28 | image_long = (image * 255).long()
29 | image_long_min = torch.min(image_long)
30 | image_long_max = torch.max(image_long)
31 | assert image_long_min >= 0, f"the min value is {image_long_min}"
32 | assert image_long_max <= 255, f"the max value is {image_long_max}"
33 | return colormap[image_long[..., 0]]
34 |
35 |
36 | def apply_depth_colormap(
37 | depth: TensorType["bs":..., 1],
38 | accumulation: Optional[TensorType["bs":..., 1]] = None,
39 | near_plane: Optional[float] = None,
40 | far_plane: Optional[float] = None,
41 | cmap="turbo",
42 | ) -> TensorType["bs":..., "rgb":3]:
43 | """Converts a depth image to color for easier analysis.
44 | Args:
45 | depth: Depth image.
46 | accumulation: Ray accumulation used for masking vis.
47 | near_plane: Closest depth to consider. If None, use min image value.
48 | far_plane: Furthest depth to consider. If None, use max image value.
49 | cmap: Colormap to apply.
50 | Returns:
51 | Colored depth image
52 | """
53 |
54 | near_plane = near_plane or float(torch.min(depth))
55 | far_plane = far_plane or float(torch.max(depth))
56 |
57 | depth = (depth - near_plane) / (far_plane - near_plane + 1e-10)
58 | depth = torch.clip(depth, 0, 1)
59 | depth = torch.nan_to_num(depth, nan=0.0)
60 |
61 | colored_image = apply_colormap(depth, cmap=cmap)
62 |
63 | if accumulation is not None:
64 | colored_image = colored_image * accumulation + (1 - accumulation)
65 |
66 | return colored_image
67 |
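# Usage note (from trainer.py): rendered depth is visualized in Tensorboard via
# apply_depth_colormap(final_rb.depth), with near/far inferred from the image
# when not given explicitly.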
68 |
69 | def apply_boolean_colormap(
70 | image: TensorType["bs":..., 1, bool],
71 | true_color: TensorType["bs":..., "rgb":3] = WHITE,
72 | false_color: TensorType["bs":..., "rgb":3] = BLACK,
73 | ) -> TensorType["bs":..., "rgb":3]:
74 |     """Converts a boolean image to color for easier analysis.
75 | Args:
76 | image: Boolean image.
77 | true_color: Color to use for True.
78 | false_color: Color to use for False.
79 | Returns:
80 | Colored boolean image
81 | """
82 |
83 | colored_image = torch.ones(image.shape[:-1] + (3,))
84 | colored_image[image[..., 0], :] = true_color
85 | colored_image[~image[..., 0], :] = false_color
86 | return colored_image
87 |
--------------------------------------------------------------------------------
/utils/common.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | import os
5 | from loguru import logger
6 |
7 |
8 | def set_random_seed(seed):
9 | random.seed(seed)
10 | np.random.seed(seed)
11 | torch.manual_seed(seed)
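    # Note: torch.manual_seed also seeds all CUDA devices, so no separate
    # torch.cuda.manual_seed_all(seed) call is needed here.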
12 |
--------------------------------------------------------------------------------
/utils/ray.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from dataclasses import dataclass
3 | from typing import Optional
4 |
5 | from utils.tensor_dataclass import TensorDataclass
6 |
7 |
8 | @dataclass
9 | class RayBundle(TensorDataclass):
10 | origins: Optional[torch.Tensor] = None
11 | """Ray origins (XYZ)"""
12 |
13 | directions: Optional[torch.Tensor] = None
14 | """Unit ray direction vector"""
15 |
16 | radiis: Optional[torch.Tensor] = None
17 | """Ray image plane intersection circle radii"""
18 |
19 | ray_cos: Optional[torch.Tensor] = None
20 |     """Cosine between the ray direction and the camera optical axis; used to convert ray distance to camera-space depth"""
21 |
22 | def __len__(self):
23 | num_rays = torch.numel(self.origins) // self.origins.shape[-1]
24 | return num_rays
25 |
26 | @property
27 | def shape(self):
28 | return list(super().shape)
29 |
30 |
31 | @dataclass
32 | class RayBundleExt(RayBundle):
33 |
34 | ray_depth: Optional[torch.Tensor] = None
35 |
36 |
37 | @dataclass
38 | class RayBundleRast(RayBundleExt):
39 |
40 | ray_uv: Optional[torch.Tensor] = None
41 | ray_mip_level: Optional[torch.Tensor] = None
42 |
--------------------------------------------------------------------------------
/utils/render_buffer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8 |
9 | from __future__ import annotations
10 | from dataclasses import fields, dataclass, make_dataclass
11 | from typing import Optional, List, Tuple, Set, Dict, Iterator
12 | import torch
13 | import types
14 |
15 | from utils.tensor_dataclass import TensorDataclass
16 |
17 |
18 | __TD_VARIANTS__ = dict()
19 |
20 |
21 | @dataclass
22 | class RenderBuffer(TensorDataclass):
23 | """
24 | A torch based, multi-channel, pixel buffer object.
25 | RenderBuffers are "smart" data buffers, used for accumulating tracing results, blending buffers of information,
26 | and providing discretized images.
27 |
28 | The spatial dimensions of RenderBuffer channels are flexible, see TensorDataclass.
29 | """
30 |
31 | rgb: Optional[torch.Tensor] = None
32 | """ rgb is a shaded RGB color. """
33 |
34 | alpha: Optional[torch.Tensor] = None
35 | """ alpha is the alpha component of RGB-A. """
36 |
37 | depth: Optional[torch.Tensor] = None
38 | """ depth is usually a distance to the surface hit point."""
39 |
40 | # Renderbuffer supports additional custom channels passed to the Renderbuffer constructor.
41 |     # Some examples of custom channels used throughout wisp:
42 | # xyz=None, # xyz is usually the xyz position for the surface hit point.
43 | # hit=None, # hit is usually a segmentation mask of hit points.
44 | # normal=None, # normal is usually the surface normal at the hit point.
45 | # shadow =None, # shadow is usually some additional buffer for shadowing.
46 |     # ao=None, # ao is usually some additional buffer for ambient occlusion.
47 | # ray_o=None, # ray_o is usually the ray origin.
48 | # ray_d=None, # ray_d is usually the ray direction.
49 | # err=None, # err is usually some error metric against the ground truth.
50 | # gts=None, # gts is usually the ground truth image.
51 |
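    # Example (illustrative): passing a custom channel creates and caches a
    # dedicated dataclass variant on the fly via __new__ below, e.g.
    #   rb = RenderBuffer(rgb=colors, alpha=opacities, normal=normals)
    #   rb.has_channel('normal')  # True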
52 | def __new__(cls, *args, **kwargs):
53 | class_fields = [f.name for f in fields(RenderBuffer)]
54 | new_fields = [k for k in kwargs.keys() if k not in class_fields]
55 | if 0 < len(new_fields):
56 | class_key = frozenset(new_fields)
57 | rb_class = __TD_VARIANTS__.get(class_key)
58 | if rb_class is None:
59 | rb_class = make_dataclass(
60 | f'RenderBuffer_{len(__TD_VARIANTS__)}',
61 | fields=[
62 | (
63 | k,
64 | Optional[torch.Tensor],
65 | None,
66 | )
67 | for k in kwargs.keys()
68 | ],
69 | bases=(RenderBuffer,),
70 | )
71 | # Cache for future __new__ calls
72 | __TD_VARIANTS__[class_key] = rb_class
73 | setattr(types, rb_class.__name__, rb_class)
74 | return super(RenderBuffer, rb_class).__new__(rb_class)
75 | else:
76 | return super(TensorDataclass, cls).__new__(cls)
77 |
78 | @property
79 | def rgba(self) -> Optional[torch.Tensor]:
80 | """
81 | Returns:
82 | (Optional[torch.Tensor]) A concatenated rgba. If rgb or alpha are none, this property will return None.
83 | """
84 | if self.alpha is None or self.rgb is None:
85 | return None
86 | else:
87 | return torch.cat((self.rgb, self.alpha), dim=-1)
88 |
89 | @rgba.setter
90 | def rgba(self, val: Optional[torch.Tensor]) -> None:
91 | """
92 | Args:
93 | val (Optional[torch.Tensor]) A concatenated rgba channel value, which sets values for the rgb and alpha
94 | internal channels simultaneously.
95 | """
96 | self.rgb = val[..., 0:-1]
97 | self.alpha = val[..., -1:]
98 |
99 | @property
100 | def channels(self) -> Set[str]:
101 | """Returns a set of channels supported by this RenderBuffer"""
102 | all_channels = self.fields
103 | static_channels = self._static_field
104 | return all_channels.difference(static_channels)
105 |
106 | def has_channel(self, name: str) -> bool:
107 | """Returns whether the RenderBuffer supports the specified channel"""
108 | return name in self.channels
109 |
110 | def get_channel(self, name: str) -> Optional[torch.Tensor]:
111 | """Returns the pixels value of the specified channel,
112 | assuming this RenderBuffer supports the specified channel.
113 | """
114 | return getattr(self, name)
115 |
116 | def transpose(self) -> RenderBuffer:
117 | """Permutes dimensions 0 and 1 of each channel.
118 | The rest of the channel dimensions will remain in the same order.
119 | """
120 | fn = lambda x: x.permute(1, 0, *tuple(range(2, x.ndim)))
121 | return self._apply(fn)
122 |
123 | def scale(self, size: Tuple, interpolation='bilinear') -> RenderBuffer:
124 | """Upsamples or downsamples the renderbuffer pixels using the specified interpolation.
125 | Scaling assumes renderbuffers with 2 spatial dimensions, e.g. (H, W, C) or (W, H, C).
126 |
127 | Warning: for non-floating point channels, this function will upcast to floating point dtype
128 | to perform interpolation, and will then re-cast back to the original dtype.
129 | Hence truncations due to rounding may occur.
130 |
131 | Args:
132 | size (Tuple): The new spatial dimensions of the renderbuffer.
133 | interpolation (str): Interpolation method applied to cope with missing or decimated pixels due to
134 | up / downsampling. The interpolation methods supported are aligned with
135 | :func:`torch.nn.functional.interpolate`.
136 |
137 | Returns:
138 | (RenderBuffer): A new RenderBuffer object with rescaled channels.
139 | """
140 |
141 | def _scale(x):
142 | assert (
143 | x.ndim == 3
144 | ), 'RenderBuffer scale() assumes channels have 2D spatial dimensions.'
145 | # Some versions of torch don't support direct interpolation of non-fp tensors
146 | dtype = x.dtype
147 | if not torch.is_floating_point(x):
148 | x = x.float()
149 | x = x.permute(2, 0, 1)[None]
150 | x = torch.nn.functional.interpolate(
151 | x, size=size, mode=interpolation
152 | )
153 | x = x[0].permute(1, 2, 0)
154 | if x.dtype != dtype:
155 | x = torch.round(x).to(dtype)
156 | return x
157 |
158 | return self._apply(_scale)
159 |
160 | def exr_dict(self) -> Dict[str, torch.Tensor]:
161 | """This function returns an EXR format compatible dictionary.
162 |
163 | Returns:
164 | (Dict[str, torch.Tensor])
165 | a dictionary suitable for use with `pyexr` to output multi-channel EXR images which can be
166 | viewed interactively with software like `tev`.
167 | This is suitable for debugging geometric quantities like ray origins and ray directions.
168 | """
169 | _dict = self.numpy_dict()
170 | if 'rgb' in _dict:
171 | _dict['default'] = _dict['rgb']
172 | del _dict['rgb']
173 | return _dict
174 |
175 | def image(self) -> RenderBuffer:
176 | """This function will return a copy of the RenderBuffer which will contain 8-bit [0,255] images.
177 |
178 |         This function is used to output a RenderBuffer suitable for saving as an 8-bit RGB image (e.g. with
179 |         Pillow). Since this quantization operation permanently loses information, this is not an inplace
180 |         operation and will return a copy of the RenderBuffer. Currently this function will only return
181 |         RGB, alpha, normalized depth, the hit segmentation mask, and the surface normals.
182 |
183 | If users want custom behaviour, users can implement their own conversion function which takes a
184 | RenderBuffer as input.
185 | """
186 | norm = lambda arr: ((arr + 1.0) / 2.0) if arr is not None else None
187 | bwrgb = (
188 | lambda arr: torch.cat([arr] * 3, dim=-1)
189 | if arr is not None
190 | else None
191 | )
192 | rgb8 = lambda arr: (arr * 255.0) if arr is not None else None
193 |
194 | channels = dict()
195 | if self.rgb is not None:
196 | channels['rgb'] = rgb8(self.rgb)
197 | if self.alpha is not None:
198 | channels['alpha'] = rgb8(self.alpha)
199 | if self.depth is not None:
200 |             # If the relative depth is with respect to some camera clipping plane, the depth should
201 | # be clipped in advance.
202 | relative_depth = self.depth / (torch.max(self.depth) + 1e-8)
203 | channels['depth'] = rgb8(bwrgb(relative_depth))
204 |
205 | if hasattr(self, 'hit') and self.hit is not None:
206 | channels['hit'] = rgb8(bwrgb(self.hit))
207 | else:
208 | channels['hit'] = None
209 | if hasattr(self, 'normal') and self.normal is not None:
210 | channels['normal'] = rgb8(norm(self.normal))
211 | else:
212 | channels['normal'] = None
213 |
214 | return RenderBuffer(**channels)
215 |
--------------------------------------------------------------------------------
/utils/tensor_dataclass.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The Nerfstudio Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tensor dataclass"""
16 |
17 | import dataclasses
18 | from dataclasses import dataclass
19 | from typing import (
20 | Dict,
21 | Set,
22 | List,
23 | NoReturn,
24 | Optional,
25 | Tuple,
26 | TypeVar,
27 | Union,
28 | Iterator,
29 | )
30 |
31 | import numpy as np
32 | import torch
33 |
34 | TensorDataclassT = TypeVar("TensorDataclassT", bound="TensorDataclass")
35 |
36 |
37 | @dataclass
38 | class TensorDataclass:
39 |     """Dataclass of tensors with the same batch size. Allows indexing and standard tensor ops.
40 | Fields that are not Tensors will not be batched unless they are also a TensorDataclass.
41 |
42 | Example:
43 |
44 | .. code-block:: python
45 |
46 | @dataclass
47 | class TestTensorDataclass(TensorDataclass):
48 | a: torch.Tensor
49 | b: torch.Tensor
50 | c: torch.Tensor = None
51 |
52 | # Create a new tensor dataclass with batch size of [2,3,4]
53 | test = TestTensorDataclass(a=torch.ones((2, 3, 4, 2)), b=torch.ones((4, 3)))
54 |
55 | test.shape # [2, 3, 4]
56 | test.a.shape # [2, 3, 4, 2]
57 | test.b.shape # [2, 3, 4, 3]
58 |
59 | test.reshape((6,4)).shape # [6, 4]
60 | test.flatten().shape # [24,]
61 |
62 | test[..., 0].shape # [2, 3]
63 | test[:, 0, :].shape # [2, 4]
64 | """
65 |
66 | _shape: tuple = torch.Size([])
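    # Fields listed in _static_field are excluded from broadcasting, indexing,
    # and reshaping; they are carried over unchanged to derived instances.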
67 | _static_field: set = dataclasses.field(default_factory=set)
68 |
69 | def __post_init__(self) -> None:
70 | """Finishes setting up the TensorDataclass
71 |
72 | This will 1) find the broadcasted shape and 2) broadcast all fields to this shape 3)
73 | set _shape to be the broadcasted shape.
74 | """
75 | if not dataclasses.is_dataclass(self):
76 | raise TypeError("TensorDataclass must be a dataclass")
77 |
78 | batch_shapes = self._get_dict_batch_shapes(
79 | {
80 | f.name: self.__getattribute__(f.name)
81 | for f in dataclasses.fields(self)
82 | }
83 | )
84 | if len(batch_shapes) == 0:
85 | raise ValueError("TensorDataclass must have at least one tensor")
86 | try:
87 | batch_shape = torch.broadcast_shapes(*batch_shapes)
88 |
89 | broadcasted_fields = self._broadcast_dict_fields(
90 | {
91 | f.name: self.__getattribute__(f.name)
92 | for f in dataclasses.fields(self)
93 | },
94 | batch_shape,
95 | )
96 | for f, v in broadcasted_fields.items():
97 | self.__setattr__(f, v)
98 |
99 | self.__setattr__("_shape", batch_shape)
100 | except RuntimeError:
101 | pass
102 | except IndexError:
103 | # import ipdb;ipdb.set_trace()
104 | pass
105 |
106 | def _get_dict_batch_shapes(self, dict_: Dict) -> List:
107 | """Returns batch shapes of all tensors in a dictionary
108 |
109 | Args:
110 | dict_: The dictionary to get the batch shapes of.
111 |
112 | Returns:
113 | The batch shapes of all tensors in the dictionary.
114 | """
115 | batch_shapes = []
116 | for k, v in dict_.items():
117 | if k in self._static_field:
118 | continue
119 | if isinstance(v, torch.Tensor):
120 | batch_shapes.append(v.shape[:-1])
121 | elif isinstance(v, TensorDataclass):
122 | batch_shapes.append(v.shape)
123 | return batch_shapes
124 |
125 | def _broadcast_dict_fields(self, dict_: Dict, batch_shape) -> Dict:
126 | """Broadcasts all tensors in a dictionary according to batch_shape
127 |
128 | Args:
129 | dict_: The dictionary to broadcast.
130 |
131 | Returns:
132 | The broadcasted dictionary.
133 | """
134 | new_dict = {}
135 | for k, v in dict_.items():
136 | if k in self._static_field:
137 | continue
138 | if isinstance(v, torch.Tensor):
139 | new_dict[k] = v.broadcast_to((*batch_shape, v.shape[-1]))
140 | elif isinstance(v, TensorDataclass):
141 | new_dict[k] = v.broadcast_to(batch_shape)
142 | return new_dict
143 |
144 | def __getitem__(self: TensorDataclassT, indices) -> TensorDataclassT:
145 | if isinstance(indices, torch.Tensor):
146 | return self._apply_exclude_static(lambda x: x[indices])
147 | if isinstance(indices, (int, slice)):
148 | indices = (indices,)
149 | return self._apply_exclude_static(lambda x: x[indices + (slice(None),)])
150 |
151 | def __setitem__(self, indices, value) -> NoReturn:
152 | raise RuntimeError(
153 | "Index assignment is not supported for TensorDataclass"
154 | )
155 |
156 | def __len__(self) -> int:
157 | return self.shape[0]
158 |
159 | def __bool__(self) -> bool:
160 | if len(self) == 0:
161 | raise ValueError(
162 | f"The truth value of {self.__class__.__name__} when `len(x) == 0` "
163 | "is ambiguous. Use `len(x)` or `x is not None`."
164 | )
165 | return True
166 |
167 | def __iter__(self) -> Iterator[Tuple[str, Optional[torch.Tensor]]]:
168 | return iter(
169 | (f.name, getattr(self, f.name))
170 | for f in dataclasses.fields(self)
171 | if f.name not in ('_shape', '_static_field')
172 | )
173 |
174 | @property
175 | def shape(self) -> Tuple[int, ...]:
176 | """Returns the batch shape of the tensor dataclass."""
177 | return self._shape
178 |
179 | @property
180 | def size(self) -> int:
181 | """Returns the number of elements in the tensor dataclass batch dimension."""
182 | if len(self._shape) == 0:
183 | return 1
184 | return int(np.prod(self._shape))
185 |
186 | @property
187 | def ndim(self) -> int:
188 | """Returns the number of dimensions of the tensor dataclass."""
189 | return len(self._shape)
190 |
191 | @property
192 | def fields(self) -> Set[str]:
193 | return set([f[0] for f in self])
194 |
195 | def _apply(self, fn) -> TensorDataclassT:
196 | """Applies the function fn on each of the Renderbuffer channels, if not None.
197 | Returns a new instance with the processed channels.
198 | """
199 | data = {}
200 | for f in self:
201 | attr = f[1]
202 | data[f[0]] = None if attr is None else fn(attr)
203 | return dataclasses.replace(
204 | self,
205 | _static_field=self._static_field,
206 | **data,
207 | )
208 | # return TensorDataclass(**data)
209 |
210 | def _apply_exclude_static(self, fn) -> TensorDataclassT:
211 | data = {}
212 | for f in self:
213 | if f[0] in self._static_field:
214 | continue
215 | attr = f[1]
216 | data[f[0]] = None if attr is None else fn(attr)
217 | return dataclasses.replace(
218 | self,
219 | _static_field=self._static_field,
220 | **data,
221 | )
222 |
223 | @staticmethod
224 | def _apply_on_pair(td1, td2, fn) -> TensorDataclassT:
225 | """Applies the function fn on each of the Renderbuffer channels, if not None.
226 | Returns a new instance with the processed channels.
227 | """
228 | joint_fields = TensorDataclass._join_fields(
229 | td1, td2
230 | ) # Union of field names and tuples of values
231 | combined_channels = map(
232 | fn, joint_fields.values()
233 | ) # Invoke on pair per Renderbuffer field
234 | return dataclasses.replace(
235 | td1,
236 | _static_field=td1._static_field.union(td2._static_field),
237 | **dict(zip(joint_fields.keys(), combined_channels)),
238 | )
239 | # return TensorDataclass(**dict(zip(joint_fields.keys(), combined_channels))) # Pack combined fields to a new rb
240 |
241 | @staticmethod
242 | def _apply_on_list(tds, fn) -> TensorDataclassT:
243 | joint_fields = set().union(*[td.fields for td in tds])
244 | joint_fields = {
245 | f: [getattr(td, f, None) for td in tds] for f in joint_fields
246 | }
247 | combined_channels = map(fn, joint_fields.values())
248 |
249 | return dataclasses.replace(
250 | tds[0],
251 | _static_field=tds[0]._static_field.union(
252 | *[td._static_field for td in tds[1:]]
253 | ),
254 | **dict(zip(joint_fields.keys(), combined_channels)),
255 | )
256 |
257 | @staticmethod
258 | def _join_fields(td1, td2):
259 | """Creates a joint mapping of renderbuffer fields in a format of
260 | {
261 | channel1_name: (rb1.c1, rb2.c1),
262 |             channel2_name: (rb1.c2, rb2.c2),
263 | channel3_name: (rb1.c1, None), # rb2 doesn't define channel3
264 | }
265 |         If a renderbuffer does not define a specific channel, None is used for its entry.
266 | """
267 | joint_fields = td1.fields.union(td2.fields)
268 | return {
269 | f: (getattr(td1, f, None), getattr(td2, f, None))
270 | for f in joint_fields
271 | }
272 |
273 | def numpy_dict(self) -> Dict[str, np.array]:
274 | """This function returns a dictionary of numpy arrays containing the data of each channel.
275 |
276 | Returns:
277 | (Dict[str, numpy.Array])
278 | a dictionary with entries of (channel_name, channel_data)
279 | """
280 | _dict = dict(iter(self))
281 | _dict = {k: v.numpy() for k, v in _dict.items() if v is not None}
282 | return _dict
283 |
284 | def reshape(
285 | self: TensorDataclassT, shape: Tuple[int, ...]
286 | ) -> TensorDataclassT:
287 | """Returns a new TensorDataclass with the same data but with a new shape.
288 |
289 | This should deepcopy as well.
290 |
291 | Args:
292 | shape: The new shape of the tensor dataclass.
293 |
294 | Returns:
295 | A new TensorDataclass with the same data but with a new shape.
296 | """
297 | if isinstance(shape, int):
298 | shape = (shape,)
299 | return self._apply_exclude_static(
300 | lambda x: x.reshape((*shape, x.shape[-1]))
301 | )
302 |
303 | def flatten(self: TensorDataclassT) -> TensorDataclassT:
304 | """Returns a new TensorDataclass with flattened batch dimensions
305 |
306 | Returns:
307 | TensorDataclass: A new TensorDataclass with the same data but with a new shape.
308 | """
309 | return self.reshape((-1,))
310 |
311 | def broadcast_to(
312 | self: TensorDataclassT,
313 | shape: Union[torch.Size, Tuple[int, ...]],
314 | ) -> TensorDataclassT:
315 | """Returns a new TensorDataclass broadcast to new shape.
316 |
317 |         Changes to the original tensor dataclass should affect the returned tensor dataclass,
318 | meaning it is NOT a deepcopy, and they are still linked.
319 |
320 | Args:
321 | shape: The new shape of the tensor dataclass.
322 |
323 | Returns:
324 | A new TensorDataclass with the same data but with a new shape.
325 | """
326 | return self._apply_exclude_static(
327 | lambda x: x.broadcast_to((*shape, x.shape[-1]))
328 | )
329 |
330 | def to(self: TensorDataclassT, device) -> TensorDataclassT:
331 | """Returns a new TensorDataclass with the same data but on the specified device.
332 |
333 | Args:
334 | device: The device to place the tensor dataclass.
335 |
336 | Returns:
337 | A new TensorDataclass with the same data but on the specified device.
338 | """
339 | return self._apply(lambda x: x.to(device))
340 |
341 | def cuda(self, non_blocking=False) -> TensorDataclassT:
342 | """Shifts the renderbuffer to the default torch cuda device"""
343 | fn = lambda x: x.cuda(non_blocking=non_blocking)
344 | return self._apply(fn)
345 |
346 | def cpu(self) -> TensorDataclassT:
347 | """Shifts the renderbuffer to the torch cpu device"""
348 | fn = lambda x: x.cpu()
349 | return self._apply(fn)
350 |
351 | def detach(self) -> TensorDataclassT:
352 | """Detaches the gradients of all channel tensors of the renderbuffer"""
353 | fn = lambda x: x.detach()
354 | return self._apply(fn)
355 |
356 | @staticmethod
357 | def direct_cat(
358 | tds: List[TensorDataclassT],
359 | dim: int = 0,
360 | ) -> TensorDataclassT:
361 | # cat_func = partial(torch.cat, dim=dim)
362 | def cat_func(arr):
363 | _arr = [ele for ele in arr if ele is not None]
364 | if 0 == len(_arr):
365 | return None
366 | return torch.cat(_arr, dim=dim)
367 |
368 | return TensorDataclass._apply_on_list(tds, cat_func)
369 |
370 | @staticmethod
371 | def direct_stack(
372 | tds: List[TensorDataclassT],
373 | dim: int = 0,
374 | ) -> TensorDataclassT:
375 | # cat_func = partial(torch.cat, dim=dim)
376 | def cat_func(arr):
377 | _arr = [ele for ele in arr if ele is not None]
378 | if 0 == len(_arr):
379 | return None
380 | return torch.stack(_arr, dim=dim)
381 |
382 | return TensorDataclass._apply_on_list(tds, cat_func)
383 |
384 | def cat(self, other: TensorDataclassT, dim: int = 0) -> TensorDataclassT:
385 | """Concatenates the channels of self and other RenderBuffers.
386 | If a channel only exists in one of the RBs, that channel will be returned as is.
387 |         For channels that exist in both RBs, the spatial dimensions are assumed to be identical except for the
388 | concatenated dimension.
389 |
390 | Args:
391 | other (TensorDataclass) A second buffer to concatenate to the current buffer.
392 | dim (int): The index of spatial dimension used to concat the channels
393 |
394 | Returns:
395 | A new TensorDataclass with the concatenated channels.
396 | """
397 |
398 | def _cat(pair):
399 | if None not in pair:
400 | # Concatenating tensors of different dims where one is unsqueezed with dimensionality 1
401 | if (
402 | pair[0].ndim == (pair[1].ndim + 1)
403 | and pair[0].shape[-1] == 1
404 | ):
405 | pair = (pair[0], pair[1].unsqueeze(-1))
406 | elif (
407 | pair[1].ndim == (pair[0].ndim + 1)
408 | and pair[1].shape[-1] == 1
409 | ):
410 | pair = (pair[0].unsqueeze(-1), pair[1])
411 | return torch.cat(pair, dim=dim)
412 | elif (
413 | pair[0] is not None and pair[1] is None
414 | ): # Channel is None for other but not self
415 | return pair[0]
416 | elif (
417 | pair[0] is None and pair[1] is not None
418 | ): # Channel is None for self but not other
419 | return pair[1]
420 | else:
421 | return None
422 |
423 | return TensorDataclass._apply_on_pair(self, other, _cat)
424 |
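# Usage note (from trainer.py): chunked evaluation accumulates per-chunk
# RenderBuffers along the batch dimension via
#   final_rb = rb if final_rb is None else final_rb.cat(rb)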
--------------------------------------------------------------------------------
/utils/writer.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | from abc import abstractmethod
3 | from pathlib import Path
4 | from typing import Any, Dict, List, Union
5 | from collections import ChainMap
6 | from loguru import logger
7 | from termcolor import colored
8 |
9 | import torch
10 |
11 | # from torch.utils.tensorboard import SummaryWriter
12 | from tensorboardX import SummaryWriter
13 | from torchtyping import TensorType
14 |
15 | to8b = lambda x: (255 * torch.clamp(x, min=0, max=1)).to(torch.uint8)
16 |
17 |
18 | class Writer:
19 | """Writer class"""
20 |
21 | def __init__(self):
22 | self.std_logger = logger
23 |
24 | @abstractmethod
25 | def write_image(
26 | self,
27 | name: str,
28 | image: TensorType["H", "W", "C"],
29 | step: int,
30 | ) -> None:
31 |         """Method to write out an image
32 | Args:
33 | name: data identifier
34 | image: rendered image to write
35 | step: the time step to log
36 | """
37 | raise NotImplementedError
38 |
39 | @abstractmethod
40 | def write_scalar(
41 | self,
42 | name: str,
43 | scalar: Union[float, torch.Tensor],
44 | step: int,
45 | ) -> None:
46 | """Required method to write a single scalar value to the logger
47 | Args:
48 | name: data identifier
49 | scalar: value to write out
50 | step: the time step to log
51 | """
52 | raise NotImplementedError
53 |
54 | def write_scalar_dict(
55 | self,
56 | name: str,
57 | scalar_dict: Dict[str, Any],
58 | step: int,
59 | ) -> None:
60 | """Function that writes out all scalars from a given dictionary to the logger
61 | Args:
62 | scalar_dict: dictionary containing all scalar values with key names and quantities
63 | step: the time step to log
64 | """
65 | for key, scalar in scalar_dict.items():
66 | try:
67 | float_scalar = float(scalar)
68 | self.write_scalar(name + "/" + key, float_scalar, step)
69 | except:
70 | pass
71 |
72 | def write_scalar_dicts(
73 | self,
74 | names: List[str],
75 | scalar_dicts: List[Dict[str, Any]],
76 | step: int,
77 | ) -> None:
78 | # self.std_logger.info(scalar_dicts)
79 | self.std_logger.info(
80 | ''.join(
81 | [
82 | '{}{} '.format(
83 | colored('{}:'.format(k), 'light_magenta'),
84 | v
85 | if k != 'ETA'
86 | else str(datetime.timedelta(seconds=int(v))),
87 | )
88 | for k, v in dict(ChainMap(*scalar_dicts)).items()
89 | ]
90 | )
91 | )
92 | assert len(names) == len(scalar_dicts)
93 | for n, d in zip(names, scalar_dicts):
94 | self.write_scalar_dict(n, d, step)
95 |
96 |
97 | class TensorboardWriter(Writer):
98 | """Tensorboard Writer Class"""
99 |
100 | def __init__(self, log_dir: Path):
101 | super(TensorboardWriter, self).__init__()
102 | self.tb_writer = SummaryWriter(log_dir=str(log_dir))
103 |
104 | def write_image(
105 | self,
106 | name: str,
107 | image: TensorType["H", "W", "C"],
108 | step: int,
109 | ) -> None:
110 | image = to8b(image)
111 | self.tb_writer.add_image(name, image, step, dataformats="HWC")
112 |
113 | def write_scalar(
114 | self,
115 | name: str,
116 | scalar: Union[float, torch.Tensor],
117 | step: int,
118 | ) -> None:
119 | self.tb_writer.add_scalar(name, scalar, step)
120 |
121 | def write_config(self, config: str): # pylint: disable=unused-argument
122 | self.tb_writer.add_text("config", config)
123 |
--------------------------------------------------------------------------------