├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── camera.py
├── data
│   ├── base.py
│   ├── blender.py
│   ├── cat.jpg
│   ├── iphone.py
│   └── llff.py
├── evaluate.py
├── extract_mesh.py
├── model
│   ├── barf.py
│   ├── base.py
│   ├── nerf.py
│   └── planar.py
├── options.py
├── options
│   ├── barf_blender.yaml
│   ├── barf_iphone.yaml
│   ├── barf_llff.yaml
│   ├── base.yaml
│   ├── nerf_blender.yaml
│   ├── nerf_blender_repr.yaml
│   ├── nerf_llff.yaml
│   ├── nerf_llff_repr.yaml
│   └── planar.yaml
├── requirements.yaml
├── train.py
├── util.py
├── util_vis.py
└── warp.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "external/pohsun_ssim"]
2 | path = external/pohsun_ssim
3 | url = https://github.com/Po-Hsun-Su/pytorch-ssim
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Chen-Hsuan Lin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## BARF :vomiting_face:: Bundle-Adjusting Neural Radiance Fields
2 | [Chen-Hsuan Lin](https://chenhsuanlin.bitbucket.io/),
3 | [Wei-Chiu Ma](http://people.csail.mit.edu/weichium/),
4 | [Antonio Torralba](https://groups.csail.mit.edu/vision/torralbalab/),
5 | and [Simon Lucey](http://ci2cv.net/people/simon-lucey/)
6 | IEEE International Conference on Computer Vision (ICCV), 2021 (**oral presentation**)
7 |
8 | Project page: https://chenhsuanlin.bitbucket.io/bundle-adjusting-NeRF
9 | Paper: https://chenhsuanlin.bitbucket.io/bundle-adjusting-NeRF/paper.pdf
10 | arXiv preprint: https://arxiv.org/abs/2104.06405
11 |
12 | We provide PyTorch code for all experiments: planar image alignment, NeRF/BARF on both synthetic (Blender) and real-world (LLFF) datasets, and a template for BARFing on your custom sequence.
13 |
14 | --------------------------------------
15 |
16 | ### Prerequisites
17 |
18 | - Note: for Azure ML support for this repository, please consider checking out [this branch](https://github.com/szymanowiczs/bundle-adjusting-NeRF/tree/azureml_training_script) by Stan Szymanowicz.
19 |
20 | This code is developed with Python3 (`python3`). PyTorch 1.9+ is required.
21 | It is recommended to use [Anaconda](https://www.anaconda.com/products/individual) to set up the environment. Install the dependencies and activate the environment `barf-env` with
22 | ```bash
23 | conda env create --file requirements.yaml python=3
24 | conda activate barf-env
25 | ```
26 | Initialize the external submodule dependencies with
27 | ```bash
28 | git submodule update --init --recursive
29 | ```
30 |
31 | --------------------------------------
32 |
33 | ### Dataset
34 |
35 | - #### Synthetic data (Blender) and real-world data (LLFF)
36 | Both the Blender synthetic data and LLFF real-world data can be found in the [NeRF Google Drive](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1).
37 | For convenience, you can download them by running the following script from the repository root:
38 | ```bash
39 | # Blender
40 | gdown --id 18JxhpWD-4ZmuFKLzKlAw-w5PpzZxXOcG # download nerf_synthetic.zip
41 | unzip nerf_synthetic.zip
42 | rm -f nerf_synthetic.zip
43 | mv nerf_synthetic data/blender
44 | # LLFF
45 | gdown --id 16VnMcF1KJYxN9QId6TClMsZRahHNMW5g # download nerf_llff_data.zip
46 | unzip nerf_llff_data.zip
47 | rm -f nerf_llff_data.zip
48 | mv nerf_llff_data data/llff
49 | ```
50 | The `data` directory should contain the subdirectories `blender` and `llff`.
51 | If you already have the datasets downloaded, you can alternatively soft-link them within the `data` directory.
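
If you prefer to script the soft links, here is a minimal sketch (Python; the source paths are placeholders for wherever your existing copies live):
```python
# create data/blender and data/llff as links to existing dataset copies
# (the source paths below are placeholders -- point them at your own copies)
import os
os.symlink("/path/to/nerf_synthetic","data/blender")
os.symlink("/path/to/nerf_llff_data","data/llff")
```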
52 |
53 | - #### Test your own sequence!
54 | If you want to try BARF on your own sequence, we provide a template data file in `data/iphone.py`, which is an example that reads a sequence captured by an iPhone 12.
55 | You should modify `get_image()` to read each image sample and set the raw image sizes (`self.raw_H`, `self.raw_W`) and focal length (`self.focal`) according to your camera specs.
56 | You may ignore the camera poses, as they are assumed unknown in this case; we simply set them to zero vectors.
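
To make this concrete, below is a minimal sketch of such a data file (a hypothetical `data/mysequence.py`, not part of the repo; the structure mirrors `data/iphone.py`, and the image size, root path, and focal length are placeholders you must replace with your own camera specs -- `__getitem__()` and `prefetch_all_data()` can be copied unchanged from `data/iphone.py`):
```python
# data/mysequence.py -- hypothetical template, modeled on data/iphone.py
import os
import torch
import PIL.Image
import imageio

from . import base
import camera

class Dataset(base.Dataset):

    def __init__(self,opt,split="train",subset=None):
        self.raw_H,self.raw_W = 1080,1920 # placeholder: your raw image height/width
        super().__init__(opt,split)
        self.root = opt.data.root or "data/mysequence"
        self.path_image = "{}/images".format(self.root)
        self.list = sorted(os.listdir(self.path_image))
        # (train/val splitting omitted for brevity; see data/iphone.py)
        if subset: self.list = self.list[:subset]

    def get_all_camera_poses(self,opt):
        # poses are unknown, so return dummy identity poses (as in data/iphone.py)
        return camera.pose(t=torch.zeros(len(self),3))

    def get_image(self,opt,idx):
        # modify this to read each image sample of your sequence
        image_fname = "{}/{}".format(self.path_image,self.list[idx])
        return PIL.Image.fromarray(imageio.imread(image_fname))

    def get_camera(self,opt,idx):
        self.focal = 1500. # placeholder: focal length in pixels, from your camera specs
        intr = torch.tensor([[self.focal,0,self.raw_W/2],
                             [0,self.focal,self.raw_H/2],
                             [0,0,1]]).float()
        pose = camera.pose(t=torch.zeros(3)) # dummy pose, not used
        return intr,pose
```
Datasets are imported by name in `model/base.py` (`data.<opt.data.dataset>`), so setting `--data.dataset=mysequence` (or editing a copy of `options/barf_iphone.yaml`) should pick this file up.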
57 |
58 | --------------------------------------
59 |
60 | ### Running the code
61 |
62 | - #### BARF models
63 | To train and evaluate BARF:
64 | ```bash
65 | # <GROUP> and <NAME> can be set to your likes, while <SCENE> is specific to datasets
66 |
67 | # Blender (<SCENE>={chair,drums,ficus,hotdog,lego,materials,mic,ship})
68 | python3 train.py --group=<GROUP> --model=barf --yaml=barf_blender --name=<NAME> --data.scene=<SCENE> --barf_c2f=[0.1,0.5]
69 | python3 evaluate.py --group=<GROUP> --model=barf --yaml=barf_blender --name=<NAME> --data.scene=<SCENE> --data.val_sub= --resume
70 |
71 | # LLFF (<SCENE>={fern,flower,fortress,horns,leaves,orchids,room,trex})
72 | python3 train.py --group=<GROUP> --model=barf --yaml=barf_llff --name=<NAME> --data.scene=<SCENE> --barf_c2f=[0.1,0.5]
73 | python3 evaluate.py --group=<GROUP> --model=barf --yaml=barf_llff --name=<NAME> --data.scene=<SCENE> --resume
74 | ```
75 | All the results will be stored in the directory `output/<GROUP>/<NAME>`.
76 | You may want to organize your experiments by grouping different runs in the same group.
77 |
78 | To train baseline models:
79 | - Full positional encoding: omit the `--barf_c2f` argument.
80 | - No positional encoding: add `--arch.posenc!`.
81 |
82 | If you want to evaluate a checkpoint at a specific iteration number, use `--resume=<ITER_NUMBER>` instead of just `--resume`.
83 |
84 | - #### Training the original NeRF
85 | If you want to train the reference NeRF models (assuming known camera poses):
86 | ```bash
87 | # Blender
88 | python3 train.py --group=<GROUP> --model=nerf --yaml=nerf_blender --name=<NAME> --data.scene=<SCENE>
89 | python3 evaluate.py --group=<GROUP> --model=nerf --yaml=nerf_blender --name=<NAME> --data.scene=<SCENE> --data.val_sub= --resume
90 |
91 | # LLFF
92 | python3 train.py --group=<GROUP> --model=nerf --yaml=nerf_llff --name=<NAME> --data.scene=<SCENE>
93 | python3 evaluate.py --group=<GROUP> --model=nerf --yaml=nerf_llff --name=<NAME> --data.scene=<SCENE> --resume
94 | ```
95 | If you wish to replicate the results from the original NeRF paper, use `--yaml=nerf_blender_repr` or `--yaml=nerf_llff_repr` instead for Blender or LLFF respectively.
96 | There are some differences, e.g. NDC will be used for the LLFF forward-facing dataset.
97 | (The reference NeRF models considered in the paper do not use NDC to parametrize the 3D points.)
98 |
99 | - #### Planar image alignment experiment
100 | If you want to try the planar image alignment experiment, run:
101 | ```bash
102 | python3 train.py --group=<GROUP> --model=planar --yaml=planar --name=<NAME> --seed=3 --barf_c2f=[0,0.4]
103 | ```
104 | This will fit a neural image representation to a single image (defaulting to `data/cat.jpg`), which takes a couple of minutes to optimize on a modern GPU.
105 | The seed number is set to reproduce the pre-generated warp perturbations in the paper.
106 | For the baseline methods, modify the arguments similarly as in the NeRF case above:
107 | - Full positional encoding: omit the `--barf_c2f` argument.
108 | - No positional encoding: add `--arch.posenc!`.
109 |
110 | A video `vis.mp4` will also be created to visualize the optimization process.
111 |
112 | - #### Visualizing the results
113 | We have included code to visualize the training over TensorBoard and Visdom.
114 | The TensorBoard events include the following:
115 | - **SCALARS**: the rendering losses and PSNR over the course of optimization. For BARF, the rotational/translational errors with respect to the given poses are also computed.
116 | - **IMAGES**: visualization of the RGB images and the RGB/depth rendering.
117 |
118 | We also provide visualization of 3D camera poses in Visdom.
119 | Run `visdom -port 9000` to start the Visdom server.
120 | The Visdom host server defaults to `localhost`; this can be overridden with `--visdom.server` (see `options/base.yaml` for details).
121 | If you want to disable Visdom visualization, add `--visdom!`.
122 |
123 | The `extract_mesh.py` script provides a simple way to extract the underlying 3D geometry using marching cubes. Run as follows:
124 | ```bash
125 | python3 extract_mesh.py --group=<GROUP> --model=barf --yaml=barf_blender --name=<NAME> --data.scene=<SCENE> --data.val_sub= --resume
126 | ```
127 | This works for both BARF and the original NeRF (by modifying the command line accordingly). This is currently supported only for the Blender dataset.
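
The mesh is written to `output/<GROUP>/<NAME>/mesh.obj` (see `extract_mesh.py`); as a quick sketch, it can be inspected with `trimesh`:
```python
# illustrative only: load and inspect the extracted mesh
import trimesh
mesh = trimesh.load("output/<GROUP>/<NAME>/mesh.obj")  # replace the placeholders
print(mesh.vertices.shape,mesh.faces.shape)            # vertex/triangle counts
```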
128 |
129 | --------------------------------------
130 | ### Codebase structure
131 |
132 | The main engine and network architecture in `model/barf.py` inherit those from `model/nerf.py`.
133 | This codebase is structured so that it is easy to see exactly which parts BARF extends on top of NeRF.
134 | It is also simple to build your own applications upon either BARF or NeRF -- just inherit them again!
135 | The same applies to the dataset files (e.g. `data/blender.py`).
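
As a minimal illustration (a hypothetical `model/mymodel.py`, not part of this repo), a new model only needs to expose `Model` and `Graph` and can inherit everything else:
```python
# model/mymodel.py -- hypothetical sketch of building on BARF
from . import barf

class Model(barf.Model):
    # training/evaluation engine: the training loop, pose alignment, logging and
    # evaluation are all inherited from model/barf.py (which inherits model/nerf.py)
    pass

class Graph(barf.Graph):
    # computation graph: override e.g. compute_loss() to change the objective
    pass

class NeRF(barf.NeRF):
    # network architecture: note that barf.Graph instantiates barf.NeRF directly,
    # so also override Graph.__init__() if this class should actually be used
    pass
```
Since the engine and graph are imported by name (`model.<opt.model>`, see `evaluate.py`/`train.py`), running with `--model=mymodel` and a matching config under `options/` should pick this up.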
136 |
137 | To understand the config and command lines, take the below command as an example:
138 | ```bash
139 | python3 train.py --group=<GROUP> --model=barf --yaml=barf_blender --name=<NAME> --data.scene=<SCENE> --barf_c2f=[0.1,0.5]
140 | ```
141 | This will run `model/barf.py` as the main engine with `options/barf_blender.yaml` as the main config file.
142 | Note that `barf` hierarchically inherits `nerf` (which inherits `base`), making the codebase customizable.
143 | The complete configuration will be printed upon execution.
144 | To override specific options, add `--<key>=value` or `--<key1>.<key2>=value` (and so on) to the command line. The configuration will be loaded as the variable `opt` throughout the codebase.
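
As an illustration of what this looks like inside the code (the values below are made up, not the repo defaults): the parsed configuration behaves like an attribute-style nested dictionary, so `--data.scene=chair` simply sets the corresponding nested field of `opt`:
```python
# illustration only: how command-line overrides map onto the nested `opt` variable
from easydict import EasyDict as edict

opt = edict(model="barf",data=edict(dataset="blender",scene="lego"),barf_c2f=[0.1,0.5])
opt.data.scene = "chair"           # what --data.scene=chair amounts to
print(opt.data.scene,opt.barf_c2f) # any module reads options via attribute access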
145 |
146 | Some tips on using and understanding the codebase:
147 | - The computation graph for forward/backprop is stored in `var` throughout the codebase.
148 | - The losses are stored in `loss`. To add a new loss function, just implement it in `compute_loss()` and add its weight to `opt.loss_weight.<name>`. It will automatically be added to the overall loss and logged to TensorBoard (see the sketch after this list).
149 | - If you are using a multi-GPU machine, you can add `--gpu=<gpu_number>` to specify which GPU to use. Multi-GPU training/evaluation is currently not supported.
150 | - To resume from a previous checkpoint, add `--resume=<ITER_NUMBER>`, or just `--resume` to resume from the latest checkpoint.
151 | - (to be continued....)
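
As referenced in the tips above, here is a hedged sketch of adding a new loss term (continuing the hypothetical `model/mymodel.py`; `pose_reg` is an illustrative name, and its weight must be registered under `opt.loss_weight` in your yaml or via `--loss_weight.pose_reg=<value>`):
```python
# hypothetical sketch (not part of the repo): adding a loss term in a custom Graph
from . import barf

class Graph(barf.Graph):

    def compute_loss(self,opt,var,mode=None):
        loss = super().compute_loss(opt,var,mode=mode)
        if mode=="train":
            # illustrative extra term: penalize large learned pose corrections;
            # with opt.loss_weight.pose_reg set, summarize_loss() folds it into the
            # overall loss and it is logged to TensorBoard automatically
            loss.pose_reg = self.se3_refine.weight.norm(dim=-1).mean()
        return loss
```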
152 |
153 | --------------------------------------
154 |
155 | If you find our code useful for your research, please cite
156 | ```
157 | @inproceedings{lin2021barf,
158 | title={BARF: Bundle-Adjusting Neural Radiance Fields},
159 | author={Lin, Chen-Hsuan and Ma, Wei-Chiu and Torralba, Antonio and Lucey, Simon},
160 | booktitle={IEEE International Conference on Computer Vision ({ICCV})},
161 | year={2021}
162 | }
163 | ```
164 |
165 | Please contact me (chlin@cmu.edu) if you have any questions!
166 |
--------------------------------------------------------------------------------
/camera.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os,sys,time
3 | import torch
4 | import torch.nn.functional as torch_F
5 | import collections
6 | from easydict import EasyDict as edict
7 |
8 | import util
9 | from util import log,debug
10 |
11 | class Pose():
12 | """
13 | A class of operations on camera poses (PyTorch tensors with shape [...,3,4])
14 | each [3,4] camera pose takes the form of [R|t]
15 | """
16 |
17 | def __call__(self,R=None,t=None):
18 | # construct a camera pose from the given R and/or t
19 | assert(R is not None or t is not None)
20 | if R is None:
21 | if not isinstance(t,torch.Tensor): t = torch.tensor(t)
22 | R = torch.eye(3,device=t.device).repeat(*t.shape[:-1],1,1)
23 | elif t is None:
24 | if not isinstance(R,torch.Tensor): R = torch.tensor(R)
25 | t = torch.zeros(R.shape[:-1],device=R.device)
26 | else:
27 | if not isinstance(R,torch.Tensor): R = torch.tensor(R)
28 | if not isinstance(t,torch.Tensor): t = torch.tensor(t)
29 | assert(R.shape[:-1]==t.shape and R.shape[-2:]==(3,3))
30 | R = R.float()
31 | t = t.float()
32 | pose = torch.cat([R,t[...,None]],dim=-1) # [...,3,4]
33 | assert(pose.shape[-2:]==(3,4))
34 | return pose
35 |
36 | def invert(self,pose,use_inverse=False):
37 | # invert a camera pose
38 | R,t = pose[...,:3],pose[...,3:]
39 | R_inv = R.inverse() if use_inverse else R.transpose(-1,-2)
40 | t_inv = (-R_inv@t)[...,0]
41 | pose_inv = self(R=R_inv,t=t_inv)
42 | return pose_inv
43 |
44 | def compose(self,pose_list):
45 | # compose a sequence of poses together
46 | # pose_new(x) = poseN o ... o pose2 o pose1(x)
47 | pose_new = pose_list[0]
48 | for pose in pose_list[1:]:
49 | pose_new = self.compose_pair(pose_new,pose)
50 | return pose_new
51 |
52 | def compose_pair(self,pose_a,pose_b):
53 | # pose_new(x) = pose_b o pose_a(x)
54 | R_a,t_a = pose_a[...,:3],pose_a[...,3:]
55 | R_b,t_b = pose_b[...,:3],pose_b[...,3:]
56 | R_new = R_b@R_a
57 | t_new = (R_b@t_a+t_b)[...,0]
58 | pose_new = self(R=R_new,t=t_new)
59 | return pose_new
60 |
61 | class Lie():
62 | """
63 | Lie algebra for SO(3) and SE(3) operations in PyTorch
64 | """
65 |
66 | def so3_to_SO3(self,w): # [...,3]
67 | wx = self.skew_symmetric(w)
68 | theta = w.norm(dim=-1)[...,None,None]
69 | I = torch.eye(3,device=w.device,dtype=torch.float32)
70 | A = self.taylor_A(theta)
71 | B = self.taylor_B(theta)
72 | R = I+A*wx+B*wx@wx
73 | return R
74 |
75 | def SO3_to_so3(self,R,eps=1e-7): # [...,3,3]
76 | trace = R[...,0,0]+R[...,1,1]+R[...,2,2]
77 | theta = ((trace-1)/2).clamp(-1+eps,1-eps).acos_()[...,None,None]%np.pi # ln(R) will explode if theta==pi
78 | lnR = 1/(2*self.taylor_A(theta)+1e-8)*(R-R.transpose(-2,-1)) # FIXME: wei-chiu finds it weird
79 | w0,w1,w2 = lnR[...,2,1],lnR[...,0,2],lnR[...,1,0]
80 | w = torch.stack([w0,w1,w2],dim=-1)
81 | return w
82 |
83 | def se3_to_SE3(self,wu): # [...,3]
84 | w,u = wu.split([3,3],dim=-1)
85 | wx = self.skew_symmetric(w)
86 | theta = w.norm(dim=-1)[...,None,None]
87 | I = torch.eye(3,device=w.device,dtype=torch.float32)
88 | A = self.taylor_A(theta)
89 | B = self.taylor_B(theta)
90 | C = self.taylor_C(theta)
91 | R = I+A*wx+B*wx@wx
92 | V = I+B*wx+C*wx@wx
93 | Rt = torch.cat([R,(V@u[...,None])],dim=-1)
94 | return Rt
95 |
96 | def SE3_to_se3(self,Rt,eps=1e-8): # [...,3,4]
97 | R,t = Rt.split([3,1],dim=-1)
98 | w = self.SO3_to_so3(R)
99 | wx = self.skew_symmetric(w)
100 | theta = w.norm(dim=-1)[...,None,None]
101 | I = torch.eye(3,device=w.device,dtype=torch.float32)
102 | A = self.taylor_A(theta)
103 | B = self.taylor_B(theta)
104 | invV = I-0.5*wx+(1-A/(2*B))/(theta**2+eps)*wx@wx
105 | u = (invV@t)[...,0]
106 | wu = torch.cat([w,u],dim=-1)
107 | return wu
108 |
109 | def skew_symmetric(self,w):
110 | w0,w1,w2 = w.unbind(dim=-1)
111 | O = torch.zeros_like(w0)
112 | wx = torch.stack([torch.stack([O,-w2,w1],dim=-1),
113 | torch.stack([w2,O,-w0],dim=-1),
114 | torch.stack([-w1,w0,O],dim=-1)],dim=-2)
115 | return wx
116 |
117 | def taylor_A(self,x,nth=10):
118 | # Taylor expansion of sin(x)/x
119 | ans = torch.zeros_like(x)
120 | denom = 1.
121 | for i in range(nth+1):
122 | if i>0: denom *= (2*i)*(2*i+1)
123 | ans = ans+(-1)**i*x**(2*i)/denom
124 | return ans
125 | def taylor_B(self,x,nth=10):
126 | # Taylor expansion of (1-cos(x))/x**2
127 | ans = torch.zeros_like(x)
128 | denom = 1.
129 | for i in range(nth+1):
130 | denom *= (2*i+1)*(2*i+2)
131 | ans = ans+(-1)**i*x**(2*i)/denom
132 | return ans
133 | def taylor_C(self,x,nth=10):
134 | # Taylor expansion of (x-sin(x))/x**3
135 | ans = torch.zeros_like(x)
136 | denom = 1.
137 | for i in range(nth+1):
138 | denom *= (2*i+2)*(2*i+3)
139 | ans = ans+(-1)**i*x**(2*i)/denom
140 | return ans
141 |
142 | class Quaternion():
143 |
144 | def q_to_R(self,q):
145 | # https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion
146 | qa,qb,qc,qd = q.unbind(dim=-1)
147 | R = torch.stack([torch.stack([1-2*(qc**2+qd**2),2*(qb*qc-qa*qd),2*(qa*qc+qb*qd)],dim=-1),
148 | torch.stack([2*(qb*qc+qa*qd),1-2*(qb**2+qd**2),2*(qc*qd-qa*qb)],dim=-1),
149 | torch.stack([2*(qb*qd-qa*qc),2*(qa*qb+qc*qd),1-2*(qb**2+qc**2)],dim=-1)],dim=-2)
150 | return R
151 |
152 | def R_to_q(self,R,eps=1e-8): # [B,3,3]
153 | # https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion
154 | # FIXME: this function seems a bit problematic, need to double-check
155 | row0,row1,row2 = R.unbind(dim=-2)
156 | R00,R01,R02 = row0.unbind(dim=-1)
157 | R10,R11,R12 = row1.unbind(dim=-1)
158 | R20,R21,R22 = row2.unbind(dim=-1)
159 | t = R[...,0,0]+R[...,1,1]+R[...,2,2]
160 | r = (1+t+eps).sqrt()
161 | qa = 0.5*r
162 | qb = (R21-R12).sign()*0.5*(1+R00-R11-R22+eps).sqrt()
163 | qc = (R02-R20).sign()*0.5*(1-R00+R11-R22+eps).sqrt()
164 | qd = (R10-R01).sign()*0.5*(1-R00-R11+R22+eps).sqrt()
165 | q = torch.stack([qa,qb,qc,qd],dim=-1)
166 | for i,qi in enumerate(q):
167 | if torch.isnan(qi).any():
168 | K = torch.stack([torch.stack([R00-R11-R22,R10+R01,R20+R02,R12-R21],dim=-1),
169 | torch.stack([R10+R01,R11-R00-R22,R21+R12,R20-R02],dim=-1),
170 | torch.stack([R20+R02,R21+R12,R22-R00-R11,R01-R10],dim=-1),
171 | torch.stack([R12-R21,R20-R02,R01-R10,R00+R11+R22],dim=-1)],dim=-2)/3.0
172 | K = K[i]
173 | eigval,eigvec = torch.linalg.eigh(K)
174 | V = eigvec[:,eigval.argmax()]
175 | q[i] = torch.stack([V[3],V[0],V[1],V[2]])
176 | return q
177 |
178 | def invert(self,q):
179 | qa,qb,qc,qd = q.unbind(dim=-1)
180 | norm = q.norm(dim=-1,keepdim=True)
181 | q_inv = torch.stack([qa,-qb,-qc,-qd],dim=-1)/norm**2
182 | return q_inv
183 |
184 | def product(self,q1,q2): # [B,4]
185 | q1a,q1b,q1c,q1d = q1.unbind(dim=-1)
186 | q2a,q2b,q2c,q2d = q2.unbind(dim=-1)
187 | hamil_prod = torch.stack([q1a*q2a-q1b*q2b-q1c*q2c-q1d*q2d,
188 | q1a*q2b+q1b*q2a+q1c*q2d-q1d*q2c,
189 | q1a*q2c-q1b*q2d+q1c*q2a+q1d*q2b,
190 | q1a*q2d+q1b*q2c-q1c*q2b+q1d*q2a],dim=-1)
191 | return hamil_prod
192 |
193 | pose = Pose()
194 | lie = Lie()
195 | quaternion = Quaternion()
196 |
197 | def to_hom(X):
198 | # get homogeneous coordinates of the input
199 | X_hom = torch.cat([X,torch.ones_like(X[...,:1])],dim=-1)
200 | return X_hom
201 |
202 | # basic operations of transforming 3D points between world/camera/image coordinates
203 | def world2cam(X,pose): # [B,N,3]
204 | X_hom = to_hom(X)
205 | return X_hom@pose.transpose(-1,-2)
206 | def cam2img(X,cam_intr):
207 | return X@cam_intr.transpose(-1,-2)
208 | def img2cam(X,cam_intr):
209 | return X@cam_intr.inverse().transpose(-1,-2)
210 | def cam2world(X,pose):
211 | X_hom = to_hom(X)
212 | pose_inv = Pose().invert(pose)
213 | return X_hom@pose_inv.transpose(-1,-2)
214 |
215 | def angle_to_rotation_matrix(a,axis):
216 | # get the rotation matrix from Euler angle around specific axis
217 | roll = dict(X=1,Y=2,Z=0)[axis]
218 | O = torch.zeros_like(a)
219 | I = torch.ones_like(a)
220 | M = torch.stack([torch.stack([a.cos(),-a.sin(),O],dim=-1),
221 | torch.stack([a.sin(),a.cos(),O],dim=-1),
222 | torch.stack([O,O,I],dim=-1)],dim=-2)
223 | M = M.roll((roll,roll),dims=(-2,-1))
224 | return M
225 |
226 | def get_center_and_ray(opt,pose,intr=None): # [HW,2]
227 | # given the intrinsic/extrinsic matrices, get the camera center and ray directions
228 | assert(opt.camera.model=="perspective")
229 | with torch.no_grad():
230 | # compute image coordinate grid
231 | y_range = torch.arange(opt.H,dtype=torch.float32,device=opt.device).add_(0.5)
232 | x_range = torch.arange(opt.W,dtype=torch.float32,device=opt.device).add_(0.5)
233 | Y,X = torch.meshgrid(y_range,x_range) # [H,W]
234 | xy_grid = torch.stack([X,Y],dim=-1).view(-1,2) # [HW,2]
235 | # compute center and ray
236 | batch_size = len(pose)
237 | xy_grid = xy_grid.repeat(batch_size,1,1) # [B,HW,2]
238 | grid_3D = img2cam(to_hom(xy_grid),intr) # [B,HW,3]
239 | center_3D = torch.zeros_like(grid_3D) # [B,HW,3]
240 | # transform from camera to world coordinates
241 | grid_3D = cam2world(grid_3D,pose) # [B,HW,3]
242 | center_3D = cam2world(center_3D,pose) # [B,HW,3]
243 | ray = grid_3D-center_3D # [B,HW,3]
244 | return center_3D,ray
245 |
246 | def get_3D_points_from_depth(opt,center,ray,depth,multi_samples=False):
247 | if multi_samples: center,ray = center[:,:,None],ray[:,:,None]
248 | # x = c+dv
249 | points_3D = center+ray*depth # [B,HW,3]/[B,HW,N,3]/[N,3]
250 | return points_3D
251 |
252 | def convert_NDC(opt,center,ray,intr,near=1):
253 | # shift camera center (ray origins) to near plane (z=1)
254 | # (unlike conventional NDC, we assume the cameras are facing towards the +z direction)
255 | center = center+(near-center[...,2:])/ray[...,2:]*ray
256 | # projection
257 | cx,cy,cz = center.unbind(dim=-1) # [B,HW]
258 | rx,ry,rz = ray.unbind(dim=-1) # [B,HW]
259 | scale_x = intr[:,0,0]/intr[:,0,2] # [B]
260 | scale_y = intr[:,1,1]/intr[:,1,2] # [B]
261 | cnx = scale_x[:,None]*(cx/cz)
262 | cny = scale_y[:,None]*(cy/cz)
263 | cnz = 1-2*near/cz
264 | rnx = scale_x[:,None]*(rx/rz-cx/cz)
265 | rny = scale_y[:,None]*(ry/rz-cy/cz)
266 | rnz = 2*near/cz
267 | center_ndc = torch.stack([cnx,cny,cnz],dim=-1) # [B,HW,3]
268 | ray_ndc = torch.stack([rnx,rny,rnz],dim=-1) # [B,HW,3]
269 | return center_ndc,ray_ndc
270 |
271 | def rotation_distance(R1,R2,eps=1e-7):
272 | # http://www.boris-belousov.net/2016/12/01/quat-dist/
273 | R_diff = R1@R2.transpose(-2,-1)
274 | trace = R_diff[...,0,0]+R_diff[...,1,1]+R_diff[...,2,2]
275 | angle = ((trace-1)/2).clamp(-1+eps,1-eps).acos_() # numerical stability near -1/+1
276 | return angle
277 |
278 | def procrustes_analysis(X0,X1): # [N,3]
279 | # translation
280 | t0 = X0.mean(dim=0,keepdim=True)
281 | t1 = X1.mean(dim=0,keepdim=True)
282 | X0c = X0-t0
283 | X1c = X1-t1
284 | # scale
285 | s0 = (X0c**2).sum(dim=-1).mean().sqrt()
286 | s1 = (X1c**2).sum(dim=-1).mean().sqrt()
287 | X0cs = X0c/s0
288 | X1cs = X1c/s1
289 | # rotation (use double for SVD, float loses precision)
290 | U,S,V = (X0cs.t()@X1cs).double().svd(some=True)
291 | R = (U@V.t()).float()
292 | if R.det()<0: R[2] *= -1
293 | # align X1 to X0: X1to0 = (X1-t1)/s1@R.t()*s0+t0
294 | sim3 = edict(t0=t0[0],t1=t1[0],s0=s0,s1=s1,R=R)
295 | return sim3
296 |
297 | def get_novel_view_poses(opt,pose_anchor,N=60,scale=1):
298 | # create circular viewpoints (small oscillations)
299 | theta = torch.arange(N)/N*2*np.pi
300 | R_x = angle_to_rotation_matrix((theta.sin()*0.05).asin(),"X")
301 | R_y = angle_to_rotation_matrix((theta.cos()*0.05).asin(),"Y")
302 | pose_rot = pose(R=R_y@R_x)
303 | pose_shift = pose(t=[0,0,-4*scale])
304 | pose_shift2 = pose(t=[0,0,3.8*scale])
305 | pose_oscil = pose.compose([pose_shift,pose_rot,pose_shift2])
306 | pose_novel = pose.compose([pose_oscil,pose_anchor.cpu()[None]])
307 | return pose_novel
308 |
--------------------------------------------------------------------------------
/data/base.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os,sys,time
3 | import torch
4 | import torch.nn.functional as torch_F
5 | import torchvision
6 | import torchvision.transforms.functional as torchvision_F
7 | import torch.multiprocessing as mp
8 | import PIL
9 | import tqdm
10 | import threading,queue
11 | from easydict import EasyDict as edict
12 |
13 | import util
14 | from util import log,debug
15 |
16 | class Dataset(torch.utils.data.Dataset):
17 |
18 | def __init__(self,opt,split="train"):
19 | super().__init__()
20 | self.opt = opt
21 | self.split = split
22 | self.augment = split=="train" and opt.data.augment
23 | # define image sizes
24 | if opt.data.center_crop is not None:
25 | self.crop_H = int(self.raw_H*opt.data.center_crop)
26 | self.crop_W = int(self.raw_W*opt.data.center_crop)
27 | else: self.crop_H,self.crop_W = self.raw_H,self.raw_W
28 | if not opt.H or not opt.W:
29 | opt.H,opt.W = self.crop_H,self.crop_W
30 |
31 | def setup_loader(self,opt,shuffle=False,drop_last=False):
32 | loader = torch.utils.data.DataLoader(self,
33 | batch_size=opt.batch_size or 1,
34 | num_workers=opt.data.num_workers,
35 | shuffle=shuffle,
36 | drop_last=drop_last,
37 | pin_memory=False, # spews warnings in PyTorch 1.9 but should be True in general
38 | )
39 | print("number of samples: {}".format(len(self)))
40 | return loader
41 |
42 | def get_list(self,opt):
43 | raise NotImplementedError
44 |
45 | def preload_worker(self,data_list,load_func,q,lock,idx_tqdm):
46 | while True:
47 | idx = q.get()
48 | data_list[idx] = load_func(self.opt,idx)
49 | with lock:
50 | idx_tqdm.update()
51 | q.task_done()
52 |
53 | def preload_threading(self,opt,load_func,data_str="images"):
54 | data_list = [None]*len(self)
55 | q = queue.Queue(maxsize=len(self))
56 | idx_tqdm = tqdm.tqdm(range(len(self)),desc="preloading {}".format(data_str),leave=False)
57 | for i in range(len(self)): q.put(i)
58 | lock = threading.Lock()
59 | for ti in range(opt.data.num_workers):
60 | t = threading.Thread(target=self.preload_worker,
61 | args=(data_list,load_func,q,lock,idx_tqdm),daemon=True)
62 | t.start()
63 | q.join()
64 | idx_tqdm.close()
65 | assert(all(map(lambda x: x is not None,data_list)))
66 | return data_list
67 |
68 | def __getitem__(self,idx):
69 | raise NotImplementedError
70 |
71 | def get_image(self,opt,idx):
72 | raise NotImplementedError
73 |
74 | def generate_augmentation(self,opt):
75 | brightness = opt.data.augment.brightness or 0.
76 | contrast = opt.data.augment.contrast or 0.
77 | saturation = opt.data.augment.saturation or 0.
78 | hue = opt.data.augment.hue or 0.
79 | color_jitter = torchvision.transforms.ColorJitter.get_params(
80 | brightness=(1-brightness,1+brightness),
81 | contrast=(1-contrast,1+contrast),
82 | saturation=(1-saturation,1+saturation),
83 | hue=(-hue,hue),
84 | )
85 | aug = edict(
86 | color_jitter=color_jitter,
87 | flip=np.random.randn()>0 if opt.data.augment.hflip else False,
88 | rot_angle=(np.random.rand()*2-1)*opt.data.augment.rotate if opt.data.augment.rotate else 0,
89 | )
90 | return aug
91 |
92 | def preprocess_image(self,opt,image,aug=None):
93 | if aug is not None:
94 | image = self.apply_color_jitter(opt,image,aug.color_jitter)
95 | image = torchvision_F.hflip(image) if aug.flip else image
96 | image = image.rotate(aug.rot_angle,resample=PIL.Image.BICUBIC)
97 | # center crop
98 | if opt.data.center_crop is not None:
99 | self.crop_H = int(self.raw_H*opt.data.center_crop)
100 | self.crop_W = int(self.raw_W*opt.data.center_crop)
101 | image = torchvision_F.center_crop(image,(self.crop_H,self.crop_W))
102 | else: self.crop_H,self.crop_W = self.raw_H,self.raw_W
103 | # resize
104 | if opt.data.image_size[0] is not None:
105 | image = image.resize((opt.W,opt.H))
106 | image = torchvision_F.to_tensor(image)
107 | return image
108 |
109 | def preprocess_camera(self,opt,intr,pose,aug=None):
110 | intr,pose = intr.clone(),pose.clone()
111 | # center crop
112 | intr[0,2] -= (self.raw_W-self.crop_W)/2
113 | intr[1,2] -= (self.raw_H-self.crop_H)/2
114 | # resize
115 | intr[0] *= opt.W/self.crop_W
116 | intr[1] *= opt.H/self.crop_H
117 | return intr,pose
118 |
119 | def apply_color_jitter(self,opt,image,color_jitter):
120 | mode = image.mode
121 | if mode!="L":
122 | chan = image.split()
123 | rgb = PIL.Image.merge("RGB",chan[:3])
124 | rgb = color_jitter(rgb)
125 | rgb_chan = rgb.split()
126 | image = PIL.Image.merge(mode,rgb_chan+chan[3:])
127 | return image
128 |
129 | def __len__(self):
130 | return len(self.list)
131 |
--------------------------------------------------------------------------------
/data/blender.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os,sys,time
3 | import torch
4 | import torch.nn.functional as torch_F
5 | import torchvision
6 | import torchvision.transforms.functional as torchvision_F
7 | import PIL
8 | import imageio
9 | from easydict import EasyDict as edict
10 | import json
11 | import pickle
12 |
13 | from . import base
14 | import camera
15 | from util import log,debug
16 |
17 | class Dataset(base.Dataset):
18 |
19 | def __init__(self,opt,split="train",subset=None):
20 | self.raw_H,self.raw_W = 800,800
21 | super().__init__(opt,split)
22 | self.root = opt.data.root or "data/blender"
23 | self.path = "{}/{}".format(self.root,opt.data.scene)
24 | # load/parse metadata
25 | meta_fname = "{}/transforms_{}.json".format(self.path,split)
26 | with open(meta_fname) as file:
27 | self.meta = json.load(file)
28 | self.list = self.meta["frames"]
29 | self.focal = 0.5*self.raw_W/np.tan(0.5*self.meta["camera_angle_x"])
30 | if subset: self.list = self.list[:subset]
31 | # preload dataset
32 | if opt.data.preload:
33 | self.images = self.preload_threading(opt,self.get_image)
34 | self.cameras = self.preload_threading(opt,self.get_camera,data_str="cameras")
35 |
36 | def prefetch_all_data(self,opt):
37 | assert(not opt.data.augment)
38 | # pre-iterate through all samples and group together
39 | self.all = torch.utils.data._utils.collate.default_collate([s for s in self])
40 |
41 | def get_all_camera_poses(self,opt):
42 | pose_raw_all = [torch.tensor(f["transform_matrix"],dtype=torch.float32) for f in self.list]
43 | pose_canon_all = torch.stack([self.parse_raw_camera(opt,p) for p in pose_raw_all],dim=0)
44 | return pose_canon_all
45 |
46 | def __getitem__(self,idx):
47 | opt = self.opt
48 | sample = dict(idx=idx)
49 | aug = self.generate_augmentation(opt) if self.augment else None
50 | image = self.images[idx] if opt.data.preload else self.get_image(opt,idx)
51 | image = self.preprocess_image(opt,image,aug=aug)
52 | intr,pose = self.cameras[idx] if opt.data.preload else self.get_camera(opt,idx)
53 | intr,pose = self.preprocess_camera(opt,intr,pose,aug=aug)
54 | sample.update(
55 | image=image,
56 | intr=intr,
57 | pose=pose,
58 | )
59 | return sample
60 |
61 | def get_image(self,opt,idx):
62 | image_fname = "{}/{}.png".format(self.path,self.list[idx]["file_path"])
63 | image = PIL.Image.fromarray(imageio.imread(image_fname)) # directly using PIL.Image.open() leads to weird corruption....
64 | return image
65 |
66 | def preprocess_image(self,opt,image,aug=None):
67 | image = super().preprocess_image(opt,image,aug=aug)
68 | rgb,mask = image[:3],image[3:]
69 | if opt.data.bgcolor is not None:
70 | rgb = rgb*mask+opt.data.bgcolor*(1-mask)
71 | return rgb
72 |
73 | def get_camera(self,opt,idx):
74 | intr = torch.tensor([[self.focal,0,self.raw_W/2],
75 | [0,self.focal,self.raw_H/2],
76 | [0,0,1]]).float()
77 | pose_raw = torch.tensor(self.list[idx]["transform_matrix"],dtype=torch.float32)
78 | pose = self.parse_raw_camera(opt,pose_raw)
79 | return intr,pose
80 |
81 | def parse_raw_camera(self,opt,pose_raw):
82 | pose_flip = camera.pose(R=torch.diag(torch.tensor([1,-1,-1])))
83 | pose = camera.pose.compose([pose_flip,pose_raw[:3]])
84 | pose = camera.pose.invert(pose)
85 | return pose
86 |
--------------------------------------------------------------------------------
/data/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenhsuanlin/bundle-adjusting-NeRF/803291bd0ee91c7c13fb5cc42195383c5ade7d15/data/cat.jpg
--------------------------------------------------------------------------------
/data/iphone.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os,sys,time
3 | import torch
4 | import torch.nn.functional as torch_F
5 | import torchvision
6 | import torchvision.transforms.functional as torchvision_F
7 | import PIL
8 | import imageio
9 | from easydict import EasyDict as edict
10 | import json
11 | import pickle
12 |
13 | from . import base
14 | import camera
15 | from util import log,debug
16 |
17 | class Dataset(base.Dataset):
18 |
19 | def __init__(self,opt,split="train",subset=None):
20 | self.raw_H,self.raw_W = 1080,1920
21 | super().__init__(opt,split)
22 | self.root = opt.data.root or "data/iphone"
23 | self.path = "{}/{}".format(self.root,opt.data.scene)
24 | self.path_image = "{}/images".format(self.path)
25 | self.list = sorted(os.listdir(self.path_image),key=lambda f: int(f.split(".")[0]))
26 | # manually split train/val subsets
27 | num_val_split = int(len(self)*opt.data.val_ratio)
28 | self.list = self.list[:-num_val_split] if split=="train" else self.list[-num_val_split:]
29 | if subset: self.list = self.list[:subset]
30 | # preload dataset
31 | if opt.data.preload:
32 | self.images = self.preload_threading(opt,self.get_image)
33 | self.cameras = self.preload_threading(opt,self.get_camera,data_str="cameras")
34 |
35 | def prefetch_all_data(self,opt):
36 | assert(not opt.data.augment)
37 | # pre-iterate through all samples and group together
38 | self.all = torch.utils.data._utils.collate.default_collate([s for s in self])
39 |
40 | def get_all_camera_poses(self,opt):
41 | # poses are unknown, so just return some dummy poses (identity transform)
42 | return camera.pose(t=torch.zeros(len(self),3))
43 |
44 | def __getitem__(self,idx):
45 | opt = self.opt
46 | sample = dict(idx=idx)
47 | aug = self.generate_augmentation(opt) if self.augment else None
48 | image = self.images[idx] if opt.data.preload else self.get_image(opt,idx)
49 | image = self.preprocess_image(opt,image,aug=aug)
50 | intr,pose = self.cameras[idx] if opt.data.preload else self.get_camera(opt,idx)
51 | intr,pose = self.preprocess_camera(opt,intr,pose,aug=aug)
52 | sample.update(
53 | image=image,
54 | intr=intr,
55 | pose=pose,
56 | )
57 | return sample
58 |
59 | def get_image(self,opt,idx):
60 | image_fname = "{}/{}".format(self.path_image,self.list[idx])
61 | image = PIL.Image.fromarray(imageio.imread(image_fname)) # directly using PIL.Image.open() leads to weird corruption....
62 | return image
63 |
64 | def get_camera(self,opt,idx):
65 | self.focal = self.raw_W*4.2/(12.8/2.55)
66 | intr = torch.tensor([[self.focal,0,self.raw_W/2],
67 | [0,self.focal,self.raw_H/2],
68 | [0,0,1]]).float()
69 | pose = camera.pose(t=torch.zeros(3)) # dummy pose, won't be used
70 | return intr,pose
71 |
--------------------------------------------------------------------------------
/data/llff.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os,sys,time
3 | import torch
4 | import torch.nn.functional as torch_F
5 | import torchvision
6 | import torchvision.transforms.functional as torchvision_F
7 | import PIL
8 | import imageio
9 | from easydict import EasyDict as edict
10 | import json
11 | import pickle
12 |
13 | from . import base
14 | import camera
15 | from util import log,debug
16 |
17 | class Dataset(base.Dataset):
18 |
19 | def __init__(self,opt,split="train",subset=None):
20 | self.raw_H,self.raw_W = 3024,4032
21 | super().__init__(opt,split)
22 | self.root = opt.data.root or "data/llff"
23 | self.path = "{}/{}".format(self.root,opt.data.scene)
24 | self.path_image = "{}/images".format(self.path)
25 | image_fnames = sorted(os.listdir(self.path_image))
26 | poses_raw,bounds = self.parse_cameras_and_bounds(opt)
27 | self.list = list(zip(image_fnames,poses_raw,bounds))
28 | # manually split train/val subsets
29 | num_val_split = int(len(self)*opt.data.val_ratio)
30 | self.list = self.list[:-num_val_split] if split=="train" else self.list[-num_val_split:]
31 | if subset: self.list = self.list[:subset]
32 | # preload dataset
33 | if opt.data.preload:
34 | self.images = self.preload_threading(opt,self.get_image)
35 | self.cameras = self.preload_threading(opt,self.get_camera,data_str="cameras")
36 |
37 | def prefetch_all_data(self,opt):
38 | assert(not opt.data.augment)
39 | # pre-iterate through all samples and group together
40 | self.all = torch.utils.data._utils.collate.default_collate([s for s in self])
41 |
42 | def parse_cameras_and_bounds(self,opt):
43 | fname = "{}/poses_bounds.npy".format(self.path)
44 | data = torch.tensor(np.load(fname),dtype=torch.float32)
45 | # parse cameras (intrinsics and poses)
46 | cam_data = data[:,:-2].view([-1,3,5]) # [N,3,5]
47 | poses_raw = cam_data[...,:4] # [N,3,4]
48 | poses_raw[...,0],poses_raw[...,1] = poses_raw[...,1],-poses_raw[...,0]
49 | raw_H,raw_W,self.focal = cam_data[0,:,-1]
50 | assert(self.raw_H==raw_H and self.raw_W==raw_W)
51 | # parse depth bounds
52 | bounds = data[:,-2:] # [N,2]
53 | scale = 1./(bounds.min()*0.75) # not sure how this was determined
54 | poses_raw[...,3] *= scale
55 | bounds *= scale
56 | # roughly center camera poses
57 | poses_raw = self.center_camera_poses(opt,poses_raw)
58 | return poses_raw,bounds
59 |
60 | def center_camera_poses(self,opt,poses):
61 | # compute average pose
62 | center = poses[...,3].mean(dim=0)
63 | v1 = torch_F.normalize(poses[...,1].mean(dim=0),dim=0)
64 | v2 = torch_F.normalize(poses[...,2].mean(dim=0),dim=0)
65 | v0 = v1.cross(v2)
66 | pose_avg = torch.stack([v0,v1,v2,center],dim=-1)[None] # [1,3,4]
67 | # apply inverse of averaged pose
68 | poses = camera.pose.compose([poses,camera.pose.invert(pose_avg)])
69 | return poses
70 |
71 | def get_all_camera_poses(self,opt):
72 | pose_raw_all = [tup[1] for tup in self.list]
73 | pose_all = torch.stack([self.parse_raw_camera(opt,p) for p in pose_raw_all],dim=0)
74 | return pose_all
75 |
76 | def __getitem__(self,idx):
77 | opt = self.opt
78 | sample = dict(idx=idx)
79 | aug = self.generate_augmentation(opt) if self.augment else None
80 | image = self.images[idx] if opt.data.preload else self.get_image(opt,idx)
81 | image = self.preprocess_image(opt,image,aug=aug)
82 | intr,pose = self.cameras[idx] if opt.data.preload else self.get_camera(opt,idx)
83 | intr,pose = self.preprocess_camera(opt,intr,pose,aug=aug)
84 | sample.update(
85 | image=image,
86 | intr=intr,
87 | pose=pose,
88 | )
89 | return sample
90 |
91 | def get_image(self,opt,idx):
92 | image_fname = "{}/{}".format(self.path_image,self.list[idx][0])
93 | image = PIL.Image.fromarray(imageio.imread(image_fname)) # directly using PIL.Image.open() leads to weird corruption....
94 | return image
95 |
96 | def get_camera(self,opt,idx):
97 | intr = torch.tensor([[self.focal,0,self.raw_W/2],
98 | [0,self.focal,self.raw_H/2],
99 | [0,0,1]]).float()
100 | pose_raw = self.list[idx][1]
101 | pose = self.parse_raw_camera(opt,pose_raw)
102 | return intr,pose
103 |
104 | def parse_raw_camera(self,opt,pose_raw):
105 | pose_flip = camera.pose(R=torch.diag(torch.tensor([1,-1,-1])))
106 | pose = camera.pose.compose([pose_flip,pose_raw[:3]])
107 | pose = camera.pose.invert(pose)
108 | pose = camera.pose.compose([pose_flip,pose])
109 | return pose
110 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os,sys,time
3 | import torch
4 | import importlib
5 |
6 | import options
7 | from util import log
8 |
9 | def main():
10 |
11 | log.process(os.getpid())
12 | log.title("[{}] (PyTorch code for evaluating NeRF/BARF)".format(sys.argv[0]))
13 |
14 | opt_cmd = options.parse_arguments(sys.argv[1:])
15 | opt = options.set(opt_cmd=opt_cmd)
16 |
17 | with torch.cuda.device(opt.device):
18 |
19 | model = importlib.import_module("model.{}".format(opt.model))
20 | m = model.Model(opt)
21 |
22 | m.load_dataset(opt,eval_split="test")
23 | m.build_networks(opt)
24 |
25 | if opt.model=="barf":
26 | m.generate_videos_pose(opt)
27 |
28 | m.restore_checkpoint(opt)
29 | if opt.data.dataset in ["blender","llff"]:
30 | m.evaluate_full(opt)
31 | m.generate_videos_synthesis(opt)
32 |
33 | if __name__=="__main__":
34 | main()
35 |
--------------------------------------------------------------------------------
/extract_mesh.py:
--------------------------------------------------------------------------------
1 | """Extracts a 3D mesh from a pretrained model using marching cubes."""
2 |
3 | import importlib
4 | import sys
5 |
6 | import numpy as np
7 | import options
8 | import torch
9 | import tqdm
10 | import trimesh
11 | import mcubes
12 |
13 | from util import log,debug
14 |
15 | opt_cmd = options.parse_arguments(sys.argv[1:])
16 | opt = options.set(opt_cmd=opt_cmd)
17 |
18 | with torch.cuda.device(opt.device),torch.no_grad():
19 |
20 | model = importlib.import_module("model.{}".format(opt.model))
21 | m = model.Model(opt)
22 |
23 | m.build_networks(opt)
24 | m.restore_checkpoint(opt)
25 |
26 | t = torch.linspace(*opt.trimesh.range,opt.trimesh.res+1) # the best range might vary from model to model
27 | query = torch.stack(torch.meshgrid(t,t,t),dim=-1)
28 | query_flat = query.view(-1,3)
29 |
30 | density_all = []
31 | for i in tqdm.trange(0,len(query_flat),opt.trimesh.chunk_size,leave=False):
32 | points = query_flat[None,i:i+opt.trimesh.chunk_size].to(opt.device)
33 | ray_unit = torch.zeros_like(points) # dummy ray to comply with interface, not used
34 | _,density_samples = m.graph.nerf.forward(opt,points,ray_unit=ray_unit,mode=None)
35 | density_all.append(density_samples.cpu())
36 | density_all = torch.cat(density_all,dim=1)[0]
37 | density_all = density_all.view(*query.shape[:-1]).numpy()
38 |
39 | log.info("running marching cubes...")
40 | vertices,triangles = mcubes.marching_cubes(density_all,opt.trimesh.thres)
41 | vertices_centered = vertices/opt.trimesh.res-0.5
42 | mesh = trimesh.Trimesh(vertices_centered,triangles)
43 |
44 | obj_fname = "{}/mesh.obj".format(opt.output_path)
45 | log.info("saving 3D mesh to {}...".format(obj_fname))
46 | mesh.export(obj_fname)
47 |
--------------------------------------------------------------------------------
/model/barf.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os,sys,time
3 | import torch
4 | import torch.nn.functional as torch_F
5 | import torchvision
6 | import torchvision.transforms.functional as torchvision_F
7 | import tqdm
8 | from easydict import EasyDict as edict
9 | import visdom
10 | import matplotlib.pyplot as plt
11 |
12 | import util,util_vis
13 | from util import log,debug
14 | from . import nerf
15 | import camera
16 |
17 | # ============================ main engine for training and evaluation ============================
18 |
19 | class Model(nerf.Model):
20 |
21 | def __init__(self,opt):
22 | super().__init__(opt)
23 |
24 | def build_networks(self,opt):
25 | super().build_networks(opt)
26 | if opt.camera.noise:
27 | # pre-generate synthetic pose perturbation
28 | se3_noise = torch.randn(len(self.train_data),6,device=opt.device)*opt.camera.noise
29 | self.graph.pose_noise = camera.lie.se3_to_SE3(se3_noise)
30 | self.graph.se3_refine = torch.nn.Embedding(len(self.train_data),6).to(opt.device)
31 | torch.nn.init.zeros_(self.graph.se3_refine.weight)
32 |
33 | def setup_optimizer(self,opt):
34 | super().setup_optimizer(opt)
35 | optimizer = getattr(torch.optim,opt.optim.algo)
36 | self.optim_pose = optimizer([dict(params=self.graph.se3_refine.parameters(),lr=opt.optim.lr_pose)])
37 | # set up scheduler
38 | if opt.optim.sched_pose:
39 | scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched_pose.type)
40 | if opt.optim.lr_pose_end:
41 | assert(opt.optim.sched_pose.type=="ExponentialLR")
42 | opt.optim.sched_pose.gamma = (opt.optim.lr_pose_end/opt.optim.lr_pose)**(1./opt.max_iter)
43 | kwargs = { k:v for k,v in opt.optim.sched_pose.items() if k!="type" }
44 | self.sched_pose = scheduler(self.optim_pose,**kwargs)
45 |
46 | def train_iteration(self,opt,var,loader):
47 | self.optim_pose.zero_grad()
48 | if opt.optim.warmup_pose:
49 | # simple linear warmup of pose learning rate
50 | self.optim_pose.param_groups[0]["lr_orig"] = self.optim_pose.param_groups[0]["lr"] # cache the original learning rate
51 | self.optim_pose.param_groups[0]["lr"] *= min(1,self.it/opt.optim.warmup_pose)
52 | loss = super().train_iteration(opt,var,loader)
53 | self.optim_pose.step()
54 | if opt.optim.warmup_pose:
55 | self.optim_pose.param_groups[0]["lr"] = self.optim_pose.param_groups[0]["lr_orig"] # reset learning rate
56 | if opt.optim.sched_pose: self.sched_pose.step()
57 | self.graph.nerf.progress.data.fill_(self.it/opt.max_iter)
58 | if opt.nerf.fine_sampling:
59 | self.graph.nerf_fine.progress.data.fill_(self.it/opt.max_iter)
60 | return loss
61 |
62 | @torch.no_grad()
63 | def validate(self,opt,ep=None):
64 | pose,pose_GT = self.get_all_training_poses(opt)
65 | _,self.graph.sim3 = self.prealign_cameras(opt,pose,pose_GT)
66 | super().validate(opt,ep=ep)
67 |
68 | @torch.no_grad()
69 | def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"):
70 | super().log_scalars(opt,var,loss,metric=metric,step=step,split=split)
71 | if split=="train":
72 | # log learning rate
73 | lr = self.optim_pose.param_groups[0]["lr"]
74 | self.tb.add_scalar("{0}/{1}".format(split,"lr_pose"),lr,step)
75 | # compute pose error
76 | if split=="train" and opt.data.dataset in ["blender","llff"]:
77 | pose,pose_GT = self.get_all_training_poses(opt)
78 | pose_aligned,_ = self.prealign_cameras(opt,pose,pose_GT)
79 | error = self.evaluate_camera_alignment(opt,pose_aligned,pose_GT)
80 | self.tb.add_scalar("{0}/error_R".format(split),error.R.mean(),step)
81 | self.tb.add_scalar("{0}/error_t".format(split),error.t.mean(),step)
82 |
83 | @torch.no_grad()
84 | def visualize(self,opt,var,step=0,split="train"):
85 | super().visualize(opt,var,step=step,split=split)
86 | if opt.visdom:
87 | if split=="val":
88 | pose,pose_GT = self.get_all_training_poses(opt)
89 | util_vis.vis_cameras(opt,self.vis,step=step,poses=[pose,pose_GT])
90 |
91 | @torch.no_grad()
92 | def get_all_training_poses(self,opt):
93 | # get ground-truth (canonical) camera poses
94 | pose_GT = self.train_data.get_all_camera_poses(opt).to(opt.device)
95 | # add synthetic pose perturbation to all training data
96 | if opt.data.dataset=="blender":
97 | pose = pose_GT
98 | if opt.camera.noise:
99 | pose = camera.pose.compose([self.graph.pose_noise,pose])
100 | else: pose = self.graph.pose_eye
101 | # add learned pose correction to all training data
102 | pose_refine = camera.lie.se3_to_SE3(self.graph.se3_refine.weight)
103 | pose = camera.pose.compose([pose_refine,pose])
104 | return pose,pose_GT
105 |
106 | @torch.no_grad()
107 | def prealign_cameras(self,opt,pose,pose_GT):
108 | # compute 3D similarity transform via Procrustes analysis
109 | center = torch.zeros(1,1,3,device=opt.device)
110 | center_pred = camera.cam2world(center,pose)[:,0] # [N,3]
111 | center_GT = camera.cam2world(center,pose_GT)[:,0] # [N,3]
112 | try:
113 | sim3 = camera.procrustes_analysis(center_GT,center_pred)
114 | except:
115 | print("warning: SVD did not converge...")
116 | sim3 = edict(t0=0,t1=0,s0=1,s1=1,R=torch.eye(3,device=opt.device))
117 | # align the camera poses
118 | center_aligned = (center_pred-sim3.t1)/sim3.s1@sim3.R.t()*sim3.s0+sim3.t0
119 | R_aligned = pose[...,:3]@sim3.R.t()
120 | t_aligned = (-R_aligned@center_aligned[...,None])[...,0]
121 | pose_aligned = camera.pose(R=R_aligned,t=t_aligned)
122 | return pose_aligned,sim3
123 |
124 | @torch.no_grad()
125 | def evaluate_camera_alignment(self,opt,pose_aligned,pose_GT):
126 | # measure errors in rotation and translation
127 | R_aligned,t_aligned = pose_aligned.split([3,1],dim=-1)
128 | R_GT,t_GT = pose_GT.split([3,1],dim=-1)
129 | R_error = camera.rotation_distance(R_aligned,R_GT)
130 | t_error = (t_aligned-t_GT)[...,0].norm(dim=-1)
131 | error = edict(R=R_error,t=t_error)
132 | return error
133 |
134 | @torch.no_grad()
135 | def evaluate_full(self,opt):
136 | self.graph.eval()
137 | # evaluate rotation/translation
138 | pose,pose_GT = self.get_all_training_poses(opt)
139 | pose_aligned,self.graph.sim3 = self.prealign_cameras(opt,pose,pose_GT)
140 | error = self.evaluate_camera_alignment(opt,pose_aligned,pose_GT)
141 | print("--------------------------")
142 | print("rot: {:8.3f}".format(np.rad2deg(error.R.mean().cpu())))
143 | print("trans: {:10.5f}".format(error.t.mean()))
144 | print("--------------------------")
145 | # dump numbers
146 | quant_fname = "{}/quant_pose.txt".format(opt.output_path)
147 | with open(quant_fname,"w") as file:
148 | for i,(err_R,err_t) in enumerate(zip(error.R,error.t)):
149 | file.write("{} {} {}\n".format(i,err_R.item(),err_t.item()))
150 | # evaluate novel view synthesis
151 | super().evaluate_full(opt)
152 |
153 | @torch.enable_grad()
154 | def evaluate_test_time_photometric_optim(self,opt,var):
155 | # use another se3 Parameter to absorb the remaining pose errors
156 | var.se3_refine_test = torch.nn.Parameter(torch.zeros(1,6,device=opt.device))
157 | optimizer = getattr(torch.optim,opt.optim.algo)
158 | optim_pose = optimizer([dict(params=[var.se3_refine_test],lr=opt.optim.lr_pose)])
159 | iterator = tqdm.trange(opt.optim.test_iter,desc="test-time optim.",leave=False,position=1)
160 | for it in iterator:
161 | optim_pose.zero_grad()
162 | var.pose_refine_test = camera.lie.se3_to_SE3(var.se3_refine_test)
163 | var = self.graph.forward(opt,var,mode="test-optim")
164 | loss = self.graph.compute_loss(opt,var,mode="test-optim")
165 | loss = self.summarize_loss(opt,var,loss)
166 | loss.all.backward()
167 | optim_pose.step()
168 | iterator.set_postfix(loss="{:.3f}".format(loss.all))
169 | return var
170 |
171 | @torch.no_grad()
172 | def generate_videos_pose(self,opt):
173 | self.graph.eval()
174 | fig = plt.figure(figsize=(10,10) if opt.data.dataset=="blender" else (16,8))
175 | cam_path = "{}/poses".format(opt.output_path)
176 | os.makedirs(cam_path,exist_ok=True)
177 | ep_list = []
178 | for ep in range(0,opt.max_iter+1,opt.freq.ckpt):
179 | # load checkpoint (0 is random init)
180 | if ep!=0:
181 | try: util.restore_checkpoint(opt,self,resume=ep)
182 | except: continue
183 | # get the camera poses
184 | pose,pose_ref = self.get_all_training_poses(opt)
185 | if opt.data.dataset in ["blender","llff"]:
186 | pose_aligned,_ = self.prealign_cameras(opt,pose,pose_ref)
187 | pose_aligned,pose_ref = pose_aligned.detach().cpu(),pose_ref.detach().cpu()
188 | dict(
189 | blender=util_vis.plot_save_poses_blender,
190 | llff=util_vis.plot_save_poses,
191 | )[opt.data.dataset](opt,fig,pose_aligned,pose_ref=pose_ref,path=cam_path,ep=ep)
192 | else:
193 | pose = pose.detach().cpu()
194 | util_vis.plot_save_poses(opt,fig,pose,pose_ref=None,path=cam_path,ep=ep)
195 | ep_list.append(ep)
196 | plt.close()
197 | # write videos
198 | print("writing videos...")
199 | list_fname = "{}/temp.list".format(cam_path)
200 | with open(list_fname,"w") as file:
201 | for ep in ep_list: file.write("file {}.png\n".format(ep))
202 | cam_vid_fname = "{}/poses.mp4".format(opt.output_path)
203 | os.system("ffmpeg -y -r 30 -f concat -i {0} -pix_fmt yuv420p {1} >/dev/null 2>&1".format(list_fname,cam_vid_fname))
204 | os.remove(list_fname)
205 |
206 | # ============================ computation graph for forward/backprop ============================
207 |
208 | class Graph(nerf.Graph):
209 |
210 | def __init__(self,opt):
211 | super().__init__(opt)
212 | self.nerf = NeRF(opt)
213 | if opt.nerf.fine_sampling:
214 | self.nerf_fine = NeRF(opt)
215 | self.pose_eye = torch.eye(3,4).to(opt.device)
216 |
217 | def get_pose(self,opt,var,mode=None):
218 | if mode=="train":
219 | # add the pre-generated pose perturbations
220 | if opt.data.dataset=="blender":
221 | if opt.camera.noise:
222 | var.pose_noise = self.pose_noise[var.idx]
223 | pose = camera.pose.compose([var.pose_noise,var.pose])
224 | else: pose = var.pose
225 | else: pose = self.pose_eye
226 | # add learnable pose correction
227 | var.se3_refine = self.se3_refine.weight[var.idx]
228 | pose_refine = camera.lie.se3_to_SE3(var.se3_refine)
229 | pose = camera.pose.compose([pose_refine,pose])
230 | elif mode in ["val","eval","test-optim"]:
231 | # align test pose to refined coordinate system (up to sim3)
232 | sim3 = self.sim3
233 | center = torch.zeros(1,1,3,device=opt.device)
234 | center = camera.cam2world(center,var.pose)[:,0] # [N,3]
235 | center_aligned = (center-sim3.t0)/sim3.s0@sim3.R*sim3.s1+sim3.t1
236 | R_aligned = var.pose[...,:3]@self.sim3.R
237 | t_aligned = (-R_aligned@center_aligned[...,None])[...,0]
238 | pose = camera.pose(R=R_aligned,t=t_aligned)
239 | # additionally factorize the remaining pose imperfection
240 | if opt.optim.test_photo and mode!="val":
241 | pose = camera.pose.compose([var.pose_refine_test,pose])
242 | else: pose = var.pose
243 | return pose
244 |
245 | class NeRF(nerf.NeRF):
246 |
247 | def __init__(self,opt):
248 | super().__init__(opt)
249 | self.progress = torch.nn.Parameter(torch.tensor(0.)) # use Parameter so it could be checkpointed
250 |
251 | def positional_encoding(self,opt,input,L): # [B,...,N]
252 | input_enc = super().positional_encoding(opt,input,L=L) # [B,...,2NL]
253 | # coarse-to-fine: smoothly mask positional encoding for BARF
254 | if opt.barf_c2f is not None:
255 | # set weights for different frequency bands
256 | start,end = opt.barf_c2f
257 | alpha = (self.progress.data-start)/(end-start)*L
258 | k = torch.arange(L,dtype=torch.float32,device=opt.device)
259 | weight = (1-(alpha-k).clamp_(min=0,max=1).mul_(np.pi).cos_())/2
260 | # apply weights
261 | shape = input_enc.shape
262 | input_enc = (input_enc.view(-1,L)*weight).view(*shape)
263 | return input_enc
264 |
--------------------------------------------------------------------------------
/model/base.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os,sys,time
3 | import torch
4 | import torch.nn.functional as torch_F
5 | import torchvision
6 | import torchvision.transforms.functional as torchvision_F
7 | import torch.utils.tensorboard
8 | import visdom
9 | import importlib
10 | import tqdm
11 | from easydict import EasyDict as edict
12 |
13 | import util,util_vis
14 | from util import log,debug
15 |
16 | # ============================ main engine for training and evaluation ============================
17 |
18 | class Model():
19 |
20 | def __init__(self,opt):
21 | super().__init__()
22 | os.makedirs(opt.output_path,exist_ok=True)
23 |
24 | def load_dataset(self,opt,eval_split="val"):
25 | data = importlib.import_module("data.{}".format(opt.data.dataset))
26 | log.info("loading training data...")
27 | self.train_data = data.Dataset(opt,split="train",subset=opt.data.train_sub)
28 | self.train_loader = self.train_data.setup_loader(opt,shuffle=True)
29 | log.info("loading test data...")
30 | if opt.data.val_on_test: eval_split = "test"
31 | self.test_data = data.Dataset(opt,split=eval_split,subset=opt.data.val_sub)
32 | self.test_loader = self.test_data.setup_loader(opt,shuffle=False)
33 |
34 | def build_networks(self,opt):
35 | graph = importlib.import_module("model.{}".format(opt.model))
36 | log.info("building networks...")
37 | self.graph = graph.Graph(opt).to(opt.device)
38 |
39 | def setup_optimizer(self,opt):
40 | log.info("setting up optimizers...")
41 | optimizer = getattr(torch.optim,opt.optim.algo)
42 | self.optim = optimizer([dict(params=self.graph.parameters(),lr=opt.optim.lr)])
43 | # set up scheduler
44 | if opt.optim.sched:
45 | scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched.type)
46 | kwargs = { k:v for k,v in opt.optim.sched.items() if k!="type" }
47 | self.sched = scheduler(self.optim,**kwargs)
48 |
49 | def restore_checkpoint(self,opt):
50 | epoch_start,iter_start = None,None
51 | if opt.resume:
52 | log.info("resuming from previous checkpoint...")
53 | epoch_start,iter_start = util.restore_checkpoint(opt,self,resume=opt.resume)
54 | elif opt.load is not None:
55 | log.info("loading weights from checkpoint {}...".format(opt.load))
56 | epoch_start,iter_start = util.restore_checkpoint(opt,self,load_name=opt.load)
57 | else:
58 | log.info("initializing weights from scratch...")
59 | self.epoch_start = epoch_start or 0
60 | self.iter_start = iter_start or 0
61 |
62 | def setup_visualizer(self,opt):
63 | log.info("setting up visualizers...")
64 | if opt.tb:
65 | self.tb = torch.utils.tensorboard.SummaryWriter(log_dir=opt.output_path,flush_secs=10)
66 | if opt.visdom:
67 |             # check if visdom server is running
68 | is_open = util.check_socket_open(opt.visdom.server,opt.visdom.port)
69 | retry = None
70 | while not is_open:
71 | retry = input("visdom port ({}) not open, retry? (y/n) ".format(opt.visdom.port))
72 | if retry not in ["y","n"]: continue
73 | if retry=="y":
74 | is_open = util.check_socket_open(opt.visdom.server,opt.visdom.port)
75 | else: break
76 | self.vis = visdom.Visdom(server=opt.visdom.server,port=opt.visdom.port,env=opt.group)
77 |
78 | def train(self,opt):
79 | # before training
80 | log.title("TRAINING START")
81 | self.timer = edict(start=time.time(),it_mean=None)
82 | self.it = self.iter_start
83 | # training
84 | if self.iter_start==0: self.validate(opt,ep=0)
85 | for self.ep in range(self.epoch_start,opt.max_epoch):
86 | self.train_epoch(opt)
87 | # after training
88 | if opt.tb:
89 | self.tb.flush()
90 | self.tb.close()
91 | if opt.visdom: self.vis.close()
92 | log.title("TRAINING DONE")
93 |
94 | def train_epoch(self,opt):
95 | # before train epoch
96 | self.graph.train()
97 | # train epoch
98 | loader = tqdm.tqdm(self.train_loader,desc="training epoch {}".format(self.ep+1),leave=False)
99 | for batch in loader:
100 | # train iteration
101 | var = edict(batch)
102 | var = util.move_to_device(var,opt.device)
103 | loss = self.train_iteration(opt,var,loader)
104 | # after train epoch
105 | lr = self.sched.get_last_lr()[0] if opt.optim.sched else opt.optim.lr
106 | log.loss_train(opt,self.ep+1,lr,loss.all,self.timer)
107 | if opt.optim.sched: self.sched.step()
108 | if (self.ep+1)%opt.freq.val==0: self.validate(opt,ep=self.ep+1)
109 | if (self.ep+1)%opt.freq.ckpt==0: self.save_checkpoint(opt,ep=self.ep+1,it=self.it)
110 |
111 | def train_iteration(self,opt,var,loader):
112 | # before train iteration
113 | self.timer.it_start = time.time()
114 | # train iteration
115 | self.optim.zero_grad()
116 | var = self.graph.forward(opt,var,mode="train")
117 | loss = self.graph.compute_loss(opt,var,mode="train")
118 | loss = self.summarize_loss(opt,var,loss)
119 | loss.all.backward()
120 | self.optim.step()
121 | # after train iteration
122 | if (self.it+1)%opt.freq.scalar==0: self.log_scalars(opt,var,loss,step=self.it+1,split="train")
123 | if (self.it+1)%opt.freq.vis==0: self.visualize(opt,var,step=self.it+1,split="train")
124 | self.it += 1
125 | loader.set_postfix(it=self.it,loss="{:.3f}".format(loss.all))
126 | self.timer.it_end = time.time()
127 | util.update_timer(opt,self.timer,self.ep,len(loader))
128 | return loss
129 |
130 | def summarize_loss(self,opt,var,loss):
131 | loss_all = 0.
132 | assert("all" not in loss)
133 | # weigh losses
134 | for key in loss:
135 | assert(key in opt.loss_weight)
136 | assert(loss[key].shape==())
137 | if opt.loss_weight[key] is not None:
138 | assert not torch.isinf(loss[key]),"loss {} is Inf".format(key)
139 | assert not torch.isnan(loss[key]),"loss {} is NaN".format(key)
140 | loss_all += 10**float(opt.loss_weight[key])*loss[key]
141 | loss.update(all=loss_all)
142 | return loss
143 |
144 | @torch.no_grad()
145 | def validate(self,opt,ep=None):
146 | self.graph.eval()
147 | loss_val = edict()
148 | loader = tqdm.tqdm(self.test_loader,desc="validating",leave=False)
149 | for it,batch in enumerate(loader):
150 | var = edict(batch)
151 | var = util.move_to_device(var,opt.device)
152 | var = self.graph.forward(opt,var,mode="val")
153 | loss = self.graph.compute_loss(opt,var,mode="val")
154 | loss = self.summarize_loss(opt,var,loss)
155 | for key in loss:
156 | loss_val.setdefault(key,0.)
157 | loss_val[key] += loss[key]*len(var.idx)
158 | loader.set_postfix(loss="{:.3f}".format(loss.all))
159 | if it==0: self.visualize(opt,var,step=ep,split="val")
160 | for key in loss_val: loss_val[key] /= len(self.test_data)
161 | self.log_scalars(opt,var,loss_val,step=ep,split="val")
162 | log.loss_val(opt,loss_val.all)
163 |
164 | @torch.no_grad()
165 | def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"):
166 | for key,value in loss.items():
167 | if key=="all": continue
168 | if opt.loss_weight[key] is not None:
169 | self.tb.add_scalar("{0}/loss_{1}".format(split,key),value,step)
170 | if metric is not None:
171 | for key,value in metric.items():
172 | self.tb.add_scalar("{0}/{1}".format(split,key),value,step)
173 |
174 | @torch.no_grad()
175 | def visualize(self,opt,var,step=0,split="train"):
176 | raise NotImplementedError
177 |
178 | def save_checkpoint(self,opt,ep=0,it=0,latest=False):
179 | util.save_checkpoint(opt,self,ep=ep,it=it,latest=latest)
180 | if not latest:
181 | log.info("checkpoint saved: ({0}) {1}, epoch {2} (iteration {3})".format(opt.group,opt.name,ep,it))
182 |
183 | # ============================ computation graph for forward/backprop ============================
184 |
185 | class Graph(torch.nn.Module):
186 |
187 | def __init__(self,opt):
188 | super().__init__()
189 |
190 | def forward(self,opt,var,mode=None):
191 | raise NotImplementedError
192 | return var
193 |
194 | def compute_loss(self,opt,var,mode=None):
195 | loss = edict()
196 | raise NotImplementedError
197 | return loss
198 |
199 | def L1_loss(self,pred,label=0):
200 | loss = (pred.contiguous()-label).abs()
201 | return loss.mean()
202 | def MSE_loss(self,pred,label=0):
203 | loss = (pred.contiguous()-label)**2
204 | return loss.mean()
205 |
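# A small standalone sketch (not part of the original file) of the weighting scheme in
# Model.summarize_loss() above: entries of opt.loss_weight are log10-scale exponents,
# so a weight of 0 multiplies a term by 1 and a (hypothetical) weight of -1 by 0.1,
# while terms whose weight is None are skipped entirely.
import torch
from easydict import EasyDict as edict

loss = edict(render=torch.tensor(0.04),render_fine=torch.tensor(0.02)) # illustrative values
loss_weight = edict(render=0,render_fine=-1) # log10-scale weights (hypothetical)
loss_all = 0.
for key in loss:
    if loss_weight[key] is not None:
        loss_all += 10**float(loss_weight[key])*loss[key]
print(loss_all) # tensor(0.0420) = 1*0.04 + 0.1*0.02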
--------------------------------------------------------------------------------
/model/nerf.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os,sys,time
3 | import torch
4 | import torch.nn.functional as torch_F
5 | import torchvision
6 | import torchvision.transforms.functional as torchvision_F
7 | import tqdm
8 | from easydict import EasyDict as edict
9 |
10 | import lpips
11 | from external.pohsun_ssim import pytorch_ssim
12 |
13 | import util,util_vis
14 | from util import log,debug
15 | from . import base
16 | import camera
17 |
18 | # ============================ main engine for training and evaluation ============================
19 |
20 | class Model(base.Model):
21 |
22 | def __init__(self,opt):
23 | super().__init__(opt)
24 | self.lpips_loss = lpips.LPIPS(net="alex").to(opt.device)
25 |
26 | def load_dataset(self,opt,eval_split="val"):
27 | super().load_dataset(opt,eval_split=eval_split)
28 | # prefetch all training data
29 | self.train_data.prefetch_all_data(opt)
30 | self.train_data.all = edict(util.move_to_device(self.train_data.all,opt.device))
31 |
32 | def setup_optimizer(self,opt):
33 | log.info("setting up optimizers...")
34 | optimizer = getattr(torch.optim,opt.optim.algo)
35 | self.optim = optimizer([dict(params=self.graph.nerf.parameters(),lr=opt.optim.lr)])
36 | if opt.nerf.fine_sampling:
37 | self.optim.add_param_group(dict(params=self.graph.nerf_fine.parameters(),lr=opt.optim.lr))
38 | # set up scheduler
39 | if opt.optim.sched:
40 | scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched.type)
41 | if opt.optim.lr_end:
42 | assert(opt.optim.sched.type=="ExponentialLR")
43 | opt.optim.sched.gamma = (opt.optim.lr_end/opt.optim.lr)**(1./opt.max_iter)
44 | kwargs = { k:v for k,v in opt.optim.sched.items() if k!="type" }
45 | self.sched = scheduler(self.optim,**kwargs)
46 |
47 | def train(self,opt):
48 | # before training
49 | log.title("TRAINING START")
50 | self.timer = edict(start=time.time(),it_mean=None)
51 | self.graph.train()
52 | self.ep = 0 # dummy for timer
53 | # training
54 | if self.iter_start==0: self.validate(opt,0)
55 | loader = tqdm.trange(opt.max_iter,desc="training",leave=False)
56 | for self.it in loader:
57 |             if self.it<self.iter_start: continue
158 |             os.system("ffmpeg -y -framerate 30 -i {0}/rgb_%d.png -pix_fmt yuv420p {1} >/dev/null 2>&1".format(test_path,rgb_vid_fname))
159 | os.system("ffmpeg -y -framerate 30 -i {0}/depth_%d.png -pix_fmt yuv420p {1} >/dev/null 2>&1".format(test_path,depth_vid_fname))
160 | else:
161 | pose_pred,pose_GT = self.get_all_training_poses(opt)
162 | poses = pose_pred if opt.model=="barf" else pose_GT
163 | if opt.model=="barf" and opt.data.dataset=="llff":
164 | _,sim3 = self.prealign_cameras(opt,pose_pred,pose_GT)
165 | scale = sim3.s1/sim3.s0
166 | else: scale = 1
167 | # rotate novel views around the "center" camera of all poses
168 | idx_center = (poses-poses.mean(dim=0,keepdim=True))[...,3].norm(dim=-1).argmin()
169 | pose_novel = camera.get_novel_view_poses(opt,poses[idx_center],N=60,scale=scale).to(opt.device)
170 | # render the novel views
171 | novel_path = "{}/novel_view".format(opt.output_path)
172 | os.makedirs(novel_path,exist_ok=True)
173 | pose_novel_tqdm = tqdm.tqdm(pose_novel,desc="rendering novel views",leave=False)
174 | intr = edict(next(iter(self.test_loader))).intr[:1].to(opt.device) # grab intrinsics
175 | for i,pose in enumerate(pose_novel_tqdm):
176 | ret = self.graph.render_by_slices(opt,pose[None],intr=intr) if opt.nerf.rand_rays else \
177 | self.graph.render(opt,pose[None],intr=intr)
178 | invdepth = (1-ret.depth)/ret.opacity if opt.camera.ndc else 1/(ret.depth/ret.opacity+eps)
179 | rgb_map = ret.rgb.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,3,H,W]
180 | invdepth_map = invdepth.view(-1,opt.H,opt.W,1).permute(0,3,1,2) # [B,1,H,W]
181 | torchvision_F.to_pil_image(rgb_map.cpu()[0]).save("{}/rgb_{}.png".format(novel_path,i))
182 | torchvision_F.to_pil_image(invdepth_map.cpu()[0]).save("{}/depth_{}.png".format(novel_path,i))
183 | # write videos
184 | print("writing videos...")
185 | rgb_vid_fname = "{}/novel_view_rgb.mp4".format(opt.output_path)
186 | depth_vid_fname = "{}/novel_view_depth.mp4".format(opt.output_path)
187 | os.system("ffmpeg -y -framerate 30 -i {0}/rgb_%d.png -pix_fmt yuv420p {1} >/dev/null 2>&1".format(novel_path,rgb_vid_fname))
188 | os.system("ffmpeg -y -framerate 30 -i {0}/depth_%d.png -pix_fmt yuv420p {1} >/dev/null 2>&1".format(novel_path,depth_vid_fname))
189 |
190 | # ============================ computation graph for forward/backprop ============================
191 |
192 | class Graph(base.Graph):
193 |
194 | def __init__(self,opt):
195 | super().__init__(opt)
196 | self.nerf = NeRF(opt)
197 | if opt.nerf.fine_sampling:
198 | self.nerf_fine = NeRF(opt)
199 |
200 | def forward(self,opt,var,mode=None):
201 | batch_size = len(var.idx)
202 | pose = self.get_pose(opt,var,mode=mode)
203 | # render images
204 | if opt.nerf.rand_rays and mode in ["train","test-optim"]:
205 | # sample random rays for optimization
206 | var.ray_idx = torch.randperm(opt.H*opt.W,device=opt.device)[:opt.nerf.rand_rays//batch_size]
207 | ret = self.render(opt,pose,intr=var.intr,ray_idx=var.ray_idx,mode=mode) # [B,N,3],[B,N,1]
208 | else:
209 | # render full image (process in slices)
210 | ret = self.render_by_slices(opt,pose,intr=var.intr,mode=mode) if opt.nerf.rand_rays else \
211 | self.render(opt,pose,intr=var.intr,mode=mode) # [B,HW,3],[B,HW,1]
212 | var.update(ret)
213 | return var
214 |
215 | def compute_loss(self,opt,var,mode=None):
216 | loss = edict()
217 | batch_size = len(var.idx)
218 | image = var.image.view(batch_size,3,opt.H*opt.W).permute(0,2,1)
219 | if opt.nerf.rand_rays and mode in ["train","test-optim"]:
220 | image = image[:,var.ray_idx]
221 | # compute image losses
222 | if opt.loss_weight.render is not None:
223 | loss.render = self.MSE_loss(var.rgb,image)
224 | if opt.loss_weight.render_fine is not None:
225 | assert(opt.nerf.fine_sampling)
226 | loss.render_fine = self.MSE_loss(var.rgb_fine,image)
227 | return loss
228 |
229 | def get_pose(self,opt,var,mode=None):
230 | return var.pose
231 |
232 | def render(self,opt,pose,intr=None,ray_idx=None,mode=None):
233 | batch_size = len(pose)
234 | center,ray = camera.get_center_and_ray(opt,pose,intr=intr) # [B,HW,3]
235 |         while ray.isnan().any(): # TODO: weird bug, ray becomes NaN arbitrarily if batch_size>1, not deterministically reproducible
236 | center,ray = camera.get_center_and_ray(opt,pose,intr=intr) # [B,HW,3]
237 | if ray_idx is not None:
238 | # consider only subset of rays
239 | center,ray = center[:,ray_idx],ray[:,ray_idx]
240 | if opt.camera.ndc:
241 | # convert center/ray representations to NDC
242 | center,ray = camera.convert_NDC(opt,center,ray,intr=intr)
243 | # render with main MLP
244 | depth_samples = self.sample_depth(opt,batch_size,num_rays=ray.shape[1]) # [B,HW,N,1]
245 | rgb_samples,density_samples = self.nerf.forward_samples(opt,center,ray,depth_samples,mode=mode)
246 | rgb,depth,opacity,prob = self.nerf.composite(opt,ray,rgb_samples,density_samples,depth_samples)
247 | ret = edict(rgb=rgb,depth=depth,opacity=opacity) # [B,HW,K]
248 | # render with fine MLP from coarse MLP
249 | if opt.nerf.fine_sampling:
250 | with torch.no_grad():
251 |                 # resample depth according to coarse empirical distribution
252 | depth_samples_fine = self.sample_depth_from_pdf(opt,pdf=prob[...,0]) # [B,HW,Nf,1]
253 | depth_samples = torch.cat([depth_samples,depth_samples_fine],dim=2) # [B,HW,N+Nf,1]
254 | depth_samples = depth_samples.sort(dim=2).values
255 | rgb_samples,density_samples = self.nerf_fine.forward_samples(opt,center,ray,depth_samples,mode=mode)
256 | rgb_fine,depth_fine,opacity_fine,_ = self.nerf_fine.composite(opt,ray,rgb_samples,density_samples,depth_samples)
257 | ret.update(rgb_fine=rgb_fine,depth_fine=depth_fine,opacity_fine=opacity_fine) # [B,HW,K]
258 | return ret
259 |
260 | def render_by_slices(self,opt,pose,intr=None,mode=None):
261 | ret_all = edict(rgb=[],depth=[],opacity=[])
262 | if opt.nerf.fine_sampling:
263 | ret_all.update(rgb_fine=[],depth_fine=[],opacity_fine=[])
264 | # render the image by slices for memory considerations
265 | for c in range(0,opt.H*opt.W,opt.nerf.rand_rays):
266 | ray_idx = torch.arange(c,min(c+opt.nerf.rand_rays,opt.H*opt.W),device=opt.device)
267 | ret = self.render(opt,pose,intr=intr,ray_idx=ray_idx,mode=mode) # [B,R,3],[B,R,1]
268 | for k in ret: ret_all[k].append(ret[k])
269 | # group all slices of images
270 | for k in ret_all: ret_all[k] = torch.cat(ret_all[k],dim=1)
271 | return ret_all
272 |
273 | def sample_depth(self,opt,batch_size,num_rays=None):
274 | depth_min,depth_max = opt.nerf.depth.range
275 | num_rays = num_rays or opt.H*opt.W
276 | rand_samples = torch.rand(batch_size,num_rays,opt.nerf.sample_intvs,1,device=opt.device) if opt.nerf.sample_stratified else 0.5
277 | rand_samples += torch.arange(opt.nerf.sample_intvs,device=opt.device)[None,None,:,None].float() # [B,HW,N,1]
278 | depth_samples = rand_samples/opt.nerf.sample_intvs*(depth_max-depth_min)+depth_min # [B,HW,N,1]
279 | depth_samples = dict(
280 | metric=depth_samples,
281 | inverse=1/(depth_samples+1e-8),
282 | )[opt.nerf.depth.param]
283 | return depth_samples
284 |
285 | def sample_depth_from_pdf(self,opt,pdf):
286 | depth_min,depth_max = opt.nerf.depth.range
287 | # get CDF from PDF (along last dimension)
288 | cdf = pdf.cumsum(dim=-1) # [B,HW,N]
289 | cdf = torch.cat([torch.zeros_like(cdf[...,:1]),cdf],dim=-1) # [B,HW,N+1]
290 | # take uniform samples
291 | grid = torch.linspace(0,1,opt.nerf.sample_intvs_fine+1,device=opt.device) # [Nf+1]
292 | unif = 0.5*(grid[:-1]+grid[1:]).repeat(*cdf.shape[:-1],1) # [B,HW,Nf]
293 | idx = torch.searchsorted(cdf,unif,right=True) # [B,HW,Nf] \in {1...N}
294 | # inverse transform sampling from CDF
295 | depth_bin = torch.linspace(depth_min,depth_max,opt.nerf.sample_intvs+1,device=opt.device) # [N+1]
296 | depth_bin = depth_bin.repeat(*cdf.shape[:-1],1) # [B,HW,N+1]
297 | depth_low = depth_bin.gather(dim=2,index=(idx-1).clamp(min=0)) # [B,HW,Nf]
298 | depth_high = depth_bin.gather(dim=2,index=idx.clamp(max=opt.nerf.sample_intvs)) # [B,HW,Nf]
299 | cdf_low = cdf.gather(dim=2,index=(idx-1).clamp(min=0)) # [B,HW,Nf]
300 | cdf_high = cdf.gather(dim=2,index=idx.clamp(max=opt.nerf.sample_intvs)) # [B,HW,Nf]
301 | # linear interpolation
302 | t = (unif-cdf_low)/(cdf_high-cdf_low+1e-8) # [B,HW,Nf]
303 | depth_samples = depth_low+t*(depth_high-depth_low) # [B,HW,Nf]
304 | return depth_samples[...,None] # [B,HW,Nf,1]
305 |
306 | class NeRF(torch.nn.Module):
307 |
308 | def __init__(self,opt):
309 | super().__init__()
310 | self.define_network(opt)
311 |
312 | def define_network(self,opt):
313 | input_3D_dim = 3+6*opt.arch.posenc.L_3D if opt.arch.posenc else 3
314 | if opt.nerf.view_dep:
315 | input_view_dim = 3+6*opt.arch.posenc.L_view if opt.arch.posenc else 3
316 | # point-wise feature
317 | self.mlp_feat = torch.nn.ModuleList()
318 | L = util.get_layer_dims(opt.arch.layers_feat)
319 | for li,(k_in,k_out) in enumerate(L):
320 | if li==0: k_in = input_3D_dim
321 | if li in opt.arch.skip: k_in += input_3D_dim
322 | if li==len(L)-1: k_out += 1
323 | linear = torch.nn.Linear(k_in,k_out)
324 | if opt.arch.tf_init:
325 | self.tensorflow_init_weights(opt,linear,out="first" if li==len(L)-1 else None)
326 | self.mlp_feat.append(linear)
327 | # RGB prediction
328 | self.mlp_rgb = torch.nn.ModuleList()
329 | L = util.get_layer_dims(opt.arch.layers_rgb)
330 | feat_dim = opt.arch.layers_feat[-1]
331 | for li,(k_in,k_out) in enumerate(L):
332 | if li==0: k_in = feat_dim+(input_view_dim if opt.nerf.view_dep else 0)
333 | linear = torch.nn.Linear(k_in,k_out)
334 | if opt.arch.tf_init:
335 | self.tensorflow_init_weights(opt,linear,out="all" if li==len(L)-1 else None)
336 | self.mlp_rgb.append(linear)
337 |
338 | def tensorflow_init_weights(self,opt,linear,out=None):
339 | # use Xavier init instead of Kaiming init
340 | relu_gain = torch.nn.init.calculate_gain("relu") # sqrt(2)
341 | if out=="all":
342 | torch.nn.init.xavier_uniform_(linear.weight)
343 | elif out=="first":
344 | torch.nn.init.xavier_uniform_(linear.weight[:1])
345 | torch.nn.init.xavier_uniform_(linear.weight[1:],gain=relu_gain)
346 | else:
347 | torch.nn.init.xavier_uniform_(linear.weight,gain=relu_gain)
348 | torch.nn.init.zeros_(linear.bias)
349 |
350 | def forward(self,opt,points_3D,ray_unit=None,mode=None): # [B,...,3]
351 | if opt.arch.posenc:
352 | points_enc = self.positional_encoding(opt,points_3D,L=opt.arch.posenc.L_3D)
353 | points_enc = torch.cat([points_3D,points_enc],dim=-1) # [B,...,6L+3]
354 | else: points_enc = points_3D
355 | feat = points_enc
356 | # extract coordinate-based features
357 | for li,layer in enumerate(self.mlp_feat):
358 | if li in opt.arch.skip: feat = torch.cat([feat,points_enc],dim=-1)
359 | feat = layer(feat)
360 | if li==len(self.mlp_feat)-1:
361 | density = feat[...,0]
362 | if opt.nerf.density_noise_reg and mode=="train":
363 | density += torch.randn_like(density)*opt.nerf.density_noise_reg
364 | density_activ = getattr(torch_F,opt.arch.density_activ) # relu_,abs_,sigmoid_,exp_....
365 | density = density_activ(density)
366 | feat = feat[...,1:]
367 | feat = torch_F.relu(feat)
368 | # predict RGB values
369 | if opt.nerf.view_dep:
370 | assert(ray_unit is not None)
371 | if opt.arch.posenc:
372 | ray_enc = self.positional_encoding(opt,ray_unit,L=opt.arch.posenc.L_view)
373 | ray_enc = torch.cat([ray_unit,ray_enc],dim=-1) # [B,...,6L+3]
374 | else: ray_enc = ray_unit
375 | feat = torch.cat([feat,ray_enc],dim=-1)
376 | for li,layer in enumerate(self.mlp_rgb):
377 | feat = layer(feat)
378 | if li!=len(self.mlp_rgb)-1:
379 | feat = torch_F.relu(feat)
380 | rgb = feat.sigmoid_() # [B,...,3]
381 | return rgb,density
382 |
383 | def forward_samples(self,opt,center,ray,depth_samples,mode=None):
384 | points_3D_samples = camera.get_3D_points_from_depth(opt,center,ray,depth_samples,multi_samples=True) # [B,HW,N,3]
385 | if opt.nerf.view_dep:
386 | ray_unit = torch_F.normalize(ray,dim=-1) # [B,HW,3]
387 | ray_unit_samples = ray_unit[...,None,:].expand_as(points_3D_samples) # [B,HW,N,3]
388 | else: ray_unit_samples = None
389 |         rgb_samples,density_samples = self.forward(opt,points_3D_samples,ray_unit=ray_unit_samples,mode=mode) # [B,HW,N,3],[B,HW,N]
390 | return rgb_samples,density_samples
391 |
392 | def composite(self,opt,ray,rgb_samples,density_samples,depth_samples):
393 | ray_length = ray.norm(dim=-1,keepdim=True) # [B,HW,1]
394 | # volume rendering: compute probability (using quadrature)
395 | depth_intv_samples = depth_samples[...,1:,0]-depth_samples[...,:-1,0] # [B,HW,N-1]
396 | depth_intv_samples = torch.cat([depth_intv_samples,torch.empty_like(depth_intv_samples[...,:1]).fill_(1e10)],dim=2) # [B,HW,N]
397 | dist_samples = depth_intv_samples*ray_length # [B,HW,N]
398 | sigma_delta = density_samples*dist_samples # [B,HW,N]
399 | alpha = 1-(-sigma_delta).exp_() # [B,HW,N]
400 | T = (-torch.cat([torch.zeros_like(sigma_delta[...,:1]),sigma_delta[...,:-1]],dim=2).cumsum(dim=2)).exp_() # [B,HW,N]
401 | prob = (T*alpha)[...,None] # [B,HW,N,1]
402 | # integrate RGB and depth weighted by probability
403 | depth = (depth_samples*prob).sum(dim=2) # [B,HW,1]
404 | rgb = (rgb_samples*prob).sum(dim=2) # [B,HW,3]
405 | opacity = prob.sum(dim=2) # [B,HW,1]
406 | if opt.nerf.setbg_opaque:
407 | rgb = rgb+opt.data.bgcolor*(1-opacity)
408 | return rgb,depth,opacity,prob # [B,HW,K]
409 |
410 | def positional_encoding(self,opt,input,L): # [B,...,N]
411 | shape = input.shape
412 | freq = 2**torch.arange(L,dtype=torch.float32,device=opt.device)*np.pi # [L]
413 | spectrum = input[...,None]*freq # [B,...,N,L]
414 | sin,cos = spectrum.sin(),spectrum.cos() # [B,...,N,L]
415 | input_enc = torch.stack([sin,cos],dim=-2) # [B,...,N,2,L]
416 | input_enc = input_enc.view(*shape[:-1],-1) # [B,...,2NL]
417 | return input_enc
418 |
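# A toy standalone sketch (not part of the original file) of the quadrature used in
# NeRF.composite() above, for a single ray with four hand-picked samples: per-interval
# opacity alpha = 1-exp(-sigma*delta), transmittance T from the cumulative sum, and the
# compositing weights prob = T*alpha that integrate RGB/depth. Numbers are illustrative.
import torch

depth = torch.tensor([1.0,2.0,3.0,4.0]) # sample depths along the ray
sigma = torch.tensor([0.0,0.5,2.0,0.1]) # predicted volume densities
delta = torch.cat([depth[1:]-depth[:-1],torch.tensor([1e10])]) # interval lengths (last one "infinite")
sigma_delta = sigma*delta
alpha = 1-(-sigma_delta).exp() # opacity contributed by each interval
T = (-torch.cat([torch.zeros(1),sigma_delta[:-1]]).cumsum(dim=0)).exp() # transmittance up to each sample
prob = T*alpha # compositing weights (sum to at most 1)
print(prob,prob.sum()) # prob.sum() is the rendered opacity of the ray
print((depth*prob).sum()) # expected (rendered) depth along the ray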
--------------------------------------------------------------------------------
/model/planar.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os,sys,time
3 | import torch
4 | import torch.nn.functional as torch_F
5 | import torchvision
6 | import torchvision.transforms.functional as torchvision_F
7 | import tqdm
8 | from easydict import EasyDict as edict
9 | import PIL
10 | import PIL.Image,PIL.ImageDraw
11 | import imageio
12 |
13 | import util,util_vis
14 | from util import log,debug
15 | from . import base
16 | import warp
17 |
18 | # ============================ main engine for training and evaluation ============================
19 |
20 | class Model(base.Model):
21 |
22 | def __init__(self,opt):
23 | super().__init__(opt)
24 | opt.H_crop,opt.W_crop = opt.data.patch_crop
25 |
26 | def load_dataset(self,opt,eval_split=None):
27 | image_raw = PIL.Image.open(opt.data.image_fname)
28 | self.image_raw = torchvision_F.to_tensor(image_raw).to(opt.device)
29 |
30 | def build_networks(self,opt):
31 | super().build_networks(opt)
32 | self.graph.warp_param = torch.nn.Embedding(opt.batch_size,opt.warp.dof).to(opt.device)
33 | torch.nn.init.zeros_(self.graph.warp_param.weight)
34 |
35 | def setup_optimizer(self,opt):
36 | log.info("setting up optimizers...")
37 | optim_list = [
38 | dict(params=self.graph.neural_image.parameters(),lr=opt.optim.lr),
39 | dict(params=self.graph.warp_param.parameters(),lr=opt.optim.lr_warp),
40 | ]
41 | optimizer = getattr(torch.optim,opt.optim.algo)
42 | self.optim = optimizer(optim_list)
43 | # set up scheduler
44 | if opt.optim.sched:
45 | scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched.type)
46 | kwargs = { k:v for k,v in opt.optim.sched.items() if k!="type" }
47 | self.sched = scheduler(self.optim,**kwargs)
48 |
49 | def setup_visualizer(self,opt):
50 | super().setup_visualizer(opt)
51 | # set colors for visualization
52 | box_colors = ["#ff0000","#40afff","#9314ff","#ffd700","#00ff00"]
53 | box_colors = list(map(util.colorcode_to_number,box_colors))
54 | self.box_colors = np.array(box_colors).astype(int)
55 | assert(len(self.box_colors)==opt.batch_size)
56 | # create visualization directory
57 | self.vis_path = "{}/vis".format(opt.output_path)
58 | os.makedirs(self.vis_path,exist_ok=True)
59 | self.video_fname = "{}/vis.mp4".format(opt.output_path)
60 |
61 | def train(self,opt):
62 | # before training
63 | log.title("TRAINING START")
64 | self.timer = edict(start=time.time(),it_mean=None)
65 | self.ep = self.it = self.vis_it = 0
66 | self.graph.train()
67 | var = edict(idx=torch.arange(opt.batch_size))
68 | # pre-generate perturbations
69 | self.warp_pert,var.image_pert = self.generate_warp_perturbation(opt)
70 | # train
71 | var = util.move_to_device(var,opt.device)
72 | loader = tqdm.trange(opt.max_iter,desc="training",leave=False)
73 | # visualize initial state
74 | var = self.graph.forward(opt,var)
75 | self.visualize(opt,var,step=0)
76 | for it in loader:
77 | # train iteration
78 | loss = self.train_iteration(opt,var,loader)
79 | if opt.warp.fix_first:
80 | self.graph.warp_param.weight.data[0] = 0
81 | # after training
82 | os.system("ffmpeg -y -framerate 30 -i {}/%d.png -pix_fmt yuv420p {}".format(self.vis_path,self.video_fname))
83 | self.save_checkpoint(opt,ep=None,it=self.it)
84 | if opt.tb:
85 | self.tb.flush()
86 | self.tb.close()
87 | if opt.visdom: self.vis.close()
88 | log.title("TRAINING DONE")
89 |
90 | def train_iteration(self,opt,var,loader):
91 | loss = super().train_iteration(opt,var,loader)
92 | self.graph.neural_image.progress.data.fill_(self.it/opt.max_iter)
93 | return loss
94 |
95 | def generate_warp_perturbation(self,opt):
96 | # pre-generate perturbations (translational noise + homography noise)
97 | warp_pert_all = torch.zeros(opt.batch_size,opt.warp.dof,device=opt.device)
98 | trans_pert = [(0,0)]+[(x,y) for x in (-opt.warp.noise_t,opt.warp.noise_t)
99 | for y in (-opt.warp.noise_t,opt.warp.noise_t)]
100 | def create_random_perturbation():
101 | warp_pert = torch.randn(opt.warp.dof,device=opt.device)*opt.warp.noise_h
102 | warp_pert[0] += trans_pert[i][0]
103 | warp_pert[1] += trans_pert[i][1]
104 | return warp_pert
105 | for i in range(opt.batch_size):
106 | warp_pert = create_random_perturbation()
107 | while not warp.check_corners_in_range(opt,warp_pert[None]):
108 | warp_pert = create_random_perturbation()
109 | warp_pert_all[i] = warp_pert
110 | if opt.warp.fix_first:
111 | warp_pert_all[0] = 0
112 | # create warped image patches
113 | xy_grid = warp.get_normalized_pixel_grid_crop(opt) # [B,HW,2]
114 | xy_grid_warped = warp.warp_grid(opt,xy_grid,warp_pert_all)
115 | xy_grid_warped = xy_grid_warped.view([opt.batch_size,opt.H_crop,opt.W_crop,2])
116 | xy_grid_warped = torch.stack([xy_grid_warped[...,0]*max(opt.H,opt.W)/opt.W,
117 | xy_grid_warped[...,1]*max(opt.H,opt.W)/opt.H],dim=-1)
118 | image_raw_batch = self.image_raw.repeat(opt.batch_size,1,1,1)
119 | image_pert_all = torch_F.grid_sample(image_raw_batch,xy_grid_warped,align_corners=False)
120 | return warp_pert_all,image_pert_all
121 |
122 | def visualize_patches(self,opt,warp_param):
123 | image_pil = torchvision_F.to_pil_image(self.image_raw).convert("RGBA")
124 | draw_pil = PIL.Image.new("RGBA",image_pil.size,(0,0,0,0))
125 | draw = PIL.ImageDraw.Draw(draw_pil)
126 | corners_all = warp.warp_corners(opt,warp_param)
127 | corners_all[...,0] = (corners_all[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5
128 | corners_all[...,1] = (corners_all[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5
129 | for i,corners in enumerate(corners_all):
130 | P = [tuple(float(n) for n in corners[j]) for j in range(4)]
131 | draw.line([P[0],P[1],P[2],P[3],P[0]],fill=tuple(self.box_colors[i]),width=3)
132 | image_pil.alpha_composite(draw_pil)
133 | image_tensor = torchvision_F.to_tensor(image_pil.convert("RGB"))
134 | return image_tensor
135 |
136 | @torch.no_grad()
137 | def predict_entire_image(self,opt):
138 | xy_grid = warp.get_normalized_pixel_grid(opt)[:1]
139 | rgb = self.graph.neural_image.forward(opt,xy_grid) # [B,HW,3]
140 | image = rgb.view(opt.H,opt.W,3).detach().cpu().permute(2,0,1)
141 | return image
142 |
143 | @torch.no_grad()
144 | def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"):
145 | super().log_scalars(opt,var,loss,metric=metric,step=step,split=split)
146 | # compute PSNR
147 | psnr = -10*loss.render.log10()
148 | self.tb.add_scalar("{0}/{1}".format(split,"PSNR"),psnr,step)
149 | # warp error
150 | warp_error = (self.graph.warp_param.weight-self.warp_pert).norm(dim=-1).mean()
151 | self.tb.add_scalar("{0}/{1}".format(split,"warp error"),warp_error,step)
152 |
153 | @torch.no_grad()
154 | def visualize(self,opt,var,step=0,split="train"):
155 | # dump frames for writing to video
156 | frame_GT = self.visualize_patches(opt,self.warp_pert)
157 | frame = self.visualize_patches(opt,self.graph.warp_param.weight)
158 | frame2 = self.predict_entire_image(opt)
159 | frame_cat = (torch.cat([frame,frame2],dim=1)*255).byte().permute(1,2,0).numpy()
160 | imageio.imsave("{}/{}.png".format(self.vis_path,self.vis_it),frame_cat)
161 | self.vis_it += 1
162 | # visualize in Tensorboard
163 | if opt.tb:
164 | colors = self.box_colors
165 | util_vis.tb_image(opt,self.tb,step,split,"image_pert",util_vis.color_border(var.image_pert,colors))
166 | util_vis.tb_image(opt,self.tb,step,split,"rgb_warped",util_vis.color_border(var.rgb_warped_map,colors))
167 | util_vis.tb_image(opt,self.tb,self.it+1,"train","image_boxes",frame[None])
168 | util_vis.tb_image(opt,self.tb,self.it+1,"train","image_boxes_GT",frame_GT[None])
169 | util_vis.tb_image(opt,self.tb,self.it+1,"train","image_entire",frame2[None])
170 |
171 | # ============================ computation graph for forward/backprop ============================
172 |
173 | class Graph(base.Graph):
174 |
175 | def __init__(self,opt):
176 | super().__init__(opt)
177 | self.neural_image = NeuralImageFunction(opt)
178 |
179 | def forward(self,opt,var,mode=None):
180 | xy_grid = warp.get_normalized_pixel_grid_crop(opt)
181 | xy_grid_warped = warp.warp_grid(opt,xy_grid,self.warp_param.weight)
182 | # render images
183 | var.rgb_warped = self.neural_image.forward(opt,xy_grid_warped) # [B,HW,3]
184 | var.rgb_warped_map = var.rgb_warped.view(opt.batch_size,opt.H_crop,opt.W_crop,3).permute(0,3,1,2) # [B,3,H,W]
185 | return var
186 |
187 | def compute_loss(self,opt,var,mode=None):
188 | loss = edict()
189 | if opt.loss_weight.render is not None:
190 | image_pert = var.image_pert.view(opt.batch_size,3,opt.H_crop*opt.W_crop).permute(0,2,1)
191 | loss.render = self.MSE_loss(var.rgb_warped,image_pert)
192 | return loss
193 |
194 | class NeuralImageFunction(torch.nn.Module):
195 |
196 | def __init__(self,opt):
197 | super().__init__()
198 | self.define_network(opt)
199 |         self.progress = torch.nn.Parameter(torch.tensor(0.)) # use Parameter so it can be checkpointed
200 |
201 | def define_network(self,opt):
202 | input_2D_dim = 2+4*opt.arch.posenc.L_2D if opt.arch.posenc else 2
203 | # point-wise RGB prediction
204 | self.mlp = torch.nn.ModuleList()
205 | L = util.get_layer_dims(opt.arch.layers)
206 | for li,(k_in,k_out) in enumerate(L):
207 | if li==0: k_in = input_2D_dim
208 | if li in opt.arch.skip: k_in += input_2D_dim
209 | linear = torch.nn.Linear(k_in,k_out)
210 | if opt.barf_c2f and li==0:
211 |                 # rescale first layer init (the default init assumes the full pos.enc. input, but only the raw xy is active at the start of coarse-to-fine)
212 | scale = np.sqrt(input_2D_dim/2.)
213 | linear.weight.data *= scale
214 | linear.bias.data *= scale
215 | self.mlp.append(linear)
216 |
217 |     def forward(self,opt,coord_2D): # [B,...,2]
218 | if opt.arch.posenc:
219 | points_enc = self.positional_encoding(opt,coord_2D,L=opt.arch.posenc.L_2D)
220 |             points_enc = torch.cat([coord_2D,points_enc],dim=-1) # [B,...,4L+2]
221 | else: points_enc = coord_2D
222 | feat = points_enc
223 | # extract implicit features
224 | for li,layer in enumerate(self.mlp):
225 | if li in opt.arch.skip: feat = torch.cat([feat,points_enc],dim=-1)
226 | feat = layer(feat)
227 | if li!=len(self.mlp)-1:
228 | feat = torch_F.relu(feat)
229 | rgb = feat.sigmoid_() # [B,...,3]
230 | return rgb
231 |
232 | def positional_encoding(self,opt,input,L): # [B,...,N]
233 | shape = input.shape
234 | freq = 2**torch.arange(L,dtype=torch.float32,device=opt.device)*np.pi # [L]
235 | spectrum = input[...,None]*freq # [B,...,N,L]
236 | sin,cos = spectrum.sin(),spectrum.cos() # [B,...,N,L]
237 | input_enc = torch.stack([sin,cos],dim=-2) # [B,...,N,2,L]
238 | input_enc = input_enc.view(*shape[:-1],-1) # [B,...,2NL]
239 | # coarse-to-fine: smoothly mask positional encoding for BARF
240 | if opt.barf_c2f is not None:
241 | # set weights for different frequency bands
242 | start,end = opt.barf_c2f
243 | alpha = (self.progress.data-start)/(end-start)*L
244 | k = torch.arange(L,dtype=torch.float32,device=opt.device)
245 | weight = (1-(alpha-k).clamp_(min=0,max=1).mul_(np.pi).cos_())/2
246 | # apply weights
247 | shape = input_enc.shape
248 | input_enc = (input_enc.view(-1,L)*weight).view(*shape)
249 | return input_enc
250 |
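# A minimal standalone sketch (not part of the original file) checking the shape
# bookkeeping of the 2D positional encoding above, assuming L_2D=8 as in
# options/planar.yaml: each of the 2 coordinates yields sin/cos at L frequencies,
# so the encoding has 4*L channels and, concatenated with the raw xy, matches the
# 2+4*L input dimension computed in define_network().
import numpy as np
import torch

L = 8
coord_2D = torch.rand(5,180*180,2) # [B,HW,2] normalized pixel coordinates of the crops
freq = 2**torch.arange(L,dtype=torch.float32)*np.pi # [L]
spectrum = coord_2D[...,None]*freq # [B,HW,2,L]
enc = torch.stack([spectrum.sin(),spectrum.cos()],dim=-2).view(*coord_2D.shape[:-1],-1) # [B,HW,4L]
points_enc = torch.cat([coord_2D,enc],dim=-1) # [B,HW,2+4L]
print(points_enc.shape) # torch.Size([5, 32400, 34])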
--------------------------------------------------------------------------------
/options.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os,sys,time
3 | import torch
4 | import random
5 | import string
6 | import yaml
7 | from easydict import EasyDict as edict
8 |
9 | import util
10 | from util import log
11 |
12 | # torch.backends.cudnn.enabled = False
13 | # torch.backends.cudnn.benchmark = False
14 | # torch.backends.cudnn.deterministic = True
15 |
16 | def parse_arguments(args):
17 | """
18 | Parse arguments from command line.
19 | Syntax: --key1.key2.key3=value --> value
20 | --key1.key2.key3= --> None
21 | --key1.key2.key3 --> True
22 | --key1.key2.key3! --> False
23 | """
24 | opt_cmd = {}
25 | for arg in args:
26 | assert(arg.startswith("--"))
27 | if "=" not in arg[2:]:
28 | key_str,value = (arg[2:-1],"false") if arg[-1]=="!" else (arg[2:],"true")
29 | else:
30 | key_str,value = arg[2:].split("=")
31 | keys_sub = key_str.split(".")
32 | opt_sub = opt_cmd
33 | for k in keys_sub[:-1]:
34 | if k not in opt_sub: opt_sub[k] = {}
35 | opt_sub = opt_sub[k]
36 | assert keys_sub[-1] not in opt_sub,keys_sub[-1]
37 | opt_sub[keys_sub[-1]] = yaml.safe_load(value)
38 | opt_cmd = edict(opt_cmd)
39 | return opt_cmd
40 |
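# An illustrative example (not part of the original file) of the syntax documented above:
# for a hypothetical invocation
#     python train.py --model=barf --yaml=barf_blender --data.scene=lego --optim.sched!
# parse_arguments(sys.argv[1:]) returns the nested dict
#     {"model":"barf","yaml":"barf_blender","data":{"scene":"lego"},"optim":{"sched":False}}
# (a bare "--key" yields True, "--key!" yields False, and "--key=" yields None);
# set() below then requires "model" and "yaml" and overrides options/<yaml>.yaml with the rest.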
41 | def set(opt_cmd={}):
42 | log.info("setting configurations...")
43 | assert("model" in opt_cmd)
44 | # load config from yaml file
45 | assert("yaml" in opt_cmd)
46 | fname = "options/{}.yaml".format(opt_cmd.yaml)
47 | opt_base = load_options(fname)
48 | # override with command line arguments
49 | opt = override_options(opt_base,opt_cmd,key_stack=[],safe_check=True)
50 | process_options(opt)
51 | log.options(opt)
52 | return opt
53 |
54 | def load_options(fname):
55 | with open(fname) as file:
56 | opt = edict(yaml.safe_load(file))
57 | if "_parent_" in opt:
58 | # load parent yaml file(s) as base options
59 | parent_fnames = opt.pop("_parent_")
60 | if type(parent_fnames) is str:
61 | parent_fnames = [parent_fnames]
62 | for parent_fname in parent_fnames:
63 | opt_parent = load_options(parent_fname)
64 | opt_parent = override_options(opt_parent,opt,key_stack=[])
65 | opt = opt_parent
66 | print("loading {}...".format(fname))
67 | return opt
68 |
69 | def override_options(opt,opt_over,key_stack=None,safe_check=False):
70 | for key,value in opt_over.items():
71 | if isinstance(value,dict):
72 | # parse child options (until leaf nodes are reached)
73 | opt[key] = override_options(opt.get(key,dict()),value,key_stack=key_stack+[key],safe_check=safe_check)
74 | else:
75 | # ensure command line argument to override is also in yaml file
76 | if safe_check and key not in opt:
77 | add_new = None
78 | while add_new not in ["y","n"]:
79 | key_str = ".".join(key_stack+[key])
80 | add_new = input("\"{}\" not found in original opt, add? (y/n) ".format(key_str))
81 | if add_new=="n":
82 | print("safe exiting...")
83 | exit()
84 | opt[key] = value
85 | return opt
86 |
87 | def process_options(opt):
88 | # set seed
89 | if opt.seed is not None:
90 | random.seed(opt.seed)
91 | np.random.seed(opt.seed)
92 | torch.manual_seed(opt.seed)
93 | torch.cuda.manual_seed_all(opt.seed)
94 | if opt.seed!=0:
95 | opt.name = str(opt.name)+"_seed{}".format(opt.seed)
96 | else:
97 | # create random string as run ID
98 | randkey = "".join(random.choice(string.ascii_uppercase) for _ in range(4))
99 | opt.name = str(opt.name)+"_{}".format(randkey)
100 | # other default options
101 | opt.output_path = "{0}/{1}/{2}".format(opt.output_root,opt.group,opt.name)
102 | os.makedirs(opt.output_path,exist_ok=True)
103 | assert(isinstance(opt.gpu,int)) # disable multi-GPU support for now, single is enough
104 | opt.device = "cpu" if opt.cpu or not torch.cuda.is_available() else "cuda:{}".format(opt.gpu)
105 | opt.H,opt.W = opt.data.image_size
106 |
107 | def save_options_file(opt):
108 | opt_fname = "{}/options.yaml".format(opt.output_path)
109 | if os.path.isfile(opt_fname):
110 | with open(opt_fname) as file:
111 | opt_old = yaml.safe_load(file)
112 | if opt!=opt_old:
113 | # prompt if options are not identical
114 | opt_new_fname = "{}/options_temp.yaml".format(opt.output_path)
115 | with open(opt_new_fname,"w") as file:
116 | yaml.safe_dump(util.to_dict(opt),file,default_flow_style=False,indent=4)
117 | print("existing options file found (different from current one)...")
118 | os.system("diff {} {}".format(opt_fname,opt_new_fname))
119 | os.system("rm {}".format(opt_new_fname))
120 | override = None
121 | while override not in ["y","n"]:
122 | override = input("override? (y/n) ")
123 | if override=="n":
124 | print("safe exiting...")
125 | exit()
126 | else: print("existing options file found (identical)")
127 | else: print("(creating new options file...)")
128 | with open(opt_fname,"w") as file:
129 | yaml.safe_dump(util.to_dict(opt),file,default_flow_style=False,indent=4)
130 |
--------------------------------------------------------------------------------
/options/barf_blender.yaml:
--------------------------------------------------------------------------------
1 | _parent_: options/nerf_blender.yaml
2 |
3 | barf_c2f: # coarse-to-fine scheduling on positional encoding
4 |
5 | camera: # camera options
6 | noise: 0.15 # synthetic perturbations on the camera poses (Blender only)
7 |
8 | optim: # optimization options
9 | lr_pose: 1.e-3 # learning rate of camera poses
10 | lr_pose_end: 1.e-5 # terminal learning rate of camera poses (only used with sched_pose.type=ExponentialLR)
11 | sched_pose: # learning rate scheduling options
12 | type: ExponentialLR # scheduler (see PyTorch doc)
13 | gamma: # decay rate (can be empty if lr_pose_end were specified)
14 | warmup_pose: # linear warmup of the pose learning rate (N iterations)
15 | test_photo: true # test-time photometric optimization for evaluation
16 | test_iter: 100 # number of iterations for test-time optimization
17 |
18 | visdom: # Visdom options
19 | cam_depth: 0.5 # size of visualized cameras
20 |
--------------------------------------------------------------------------------
/options/barf_iphone.yaml:
--------------------------------------------------------------------------------
1 | _parent_: options/barf_llff.yaml
2 |
3 | data: # data options
4 | dataset: iphone # dataset name
5 | scene: IMG_0239 # scene name
6 | image_size: [480,640] # input image sizes [height,width]
7 |
--------------------------------------------------------------------------------
/options/barf_llff.yaml:
--------------------------------------------------------------------------------
1 | _parent_: options/nerf_llff.yaml
2 |
3 | barf_c2f: # coarse-to-fine scheduling on positional encoding
4 |
5 | camera: # camera options
6 | noise: # synthetic perturbations on the camera poses (Blender only)
7 |
8 | optim: # optimization options
9 | lr_pose: 3.e-3 # learning rate of camera poses
10 | lr_pose_end: 1.e-5 # terminal learning rate of camera poses (only used with sched_pose.type=ExponentialLR)
11 | sched_pose: # learning rate scheduling options
12 | type: ExponentialLR # scheduler (see PyTorch doc)
13 | gamma: # decay rate (can be empty if lr_pose_end were specified)
14 | warmup_pose: # linear warmup of the pose learning rate (N iterations)
15 | test_photo: true # test-time photometric optimization for evaluation
16 | test_iter: 100 # number of iterations for test-time optimization
17 |
18 | visdom: # Visdom options
19 | cam_depth: 0.2 # size of visualized cameras
20 |
--------------------------------------------------------------------------------
/options/base.yaml:
--------------------------------------------------------------------------------
1 | # default
2 |
3 | group: 0_test # name of experiment group
4 | name: debug # name of experiment run
5 | model: # type of model (must be specified from command line)
6 | yaml: # config file (must be specified from command line)
7 | seed: 0 # seed number (for both numpy and pytorch)
8 | gpu: 0 # GPU index number
9 | cpu: false # run only on CPU (not supported now)
10 | load: # load checkpoint from filename
11 |
12 | arch: {} # architectural options
13 |
14 | data: # data options
15 | root: # root path to dataset
16 | dataset: # dataset name
17 | image_size: [null,null] # input image sizes [height,width]
18 | num_workers: 8 # number of parallel workers for data loading
19 | preload: false # preload the entire dataset into the memory
20 | augment: {} # data augmentation (training only)
21 | # rotate: # random rotation
22 | # brightness: # 0.2 # random brightness jitter
23 | # contrast: # 0.2 # random contrast jitter
24 | # saturation: # 0.2 # random saturation jitter
25 | # hue: # 0.1 # random hue jitter
26 | # hflip: # True # random horizontal flip
27 | center_crop: # center crop the image by ratio
28 | val_on_test: false # validate on test set during training
29 | train_sub: # consider a subset of N training samples
30 | val_sub: # consider a subset of N validation samples
31 |
32 | loss_weight: {} # loss weights (in log scale)
33 |
34 | optim: # optimization options
35 | lr: 1.e-3 # learning rate (main)
36 | lr_end: # terminal learning rate (only used with sched.type=ExponentialLR)
37 | algo: Adam # optimizer (see PyTorch doc)
38 | sched: {} # learning rate scheduling options
39 | # type: StepLR # scheduler (see PyTorch doc)
40 | # steps: # decay every N epochs
41 | # gamma: 0.1 # decay rate (can be empty if lr_end were specified)
42 |
43 | batch_size: 16 # batch size
44 | max_epoch: 1000 # train to maximum number of epochs
45 | resume: false # resume training (true for latest checkpoint, or number for specific epoch number)
46 |
47 | output_root: output # root path for output files (checkpoints and results)
48 | tb: # TensorBoard options
49 | num_images: [4,8] # number of (tiled) images to visualize in TensorBoard
50 | visdom: # Visdom options
51 | server: localhost # server to host Visdom
52 | port: 9000 # port number for Visdom
53 |
54 | freq: # periodic actions during training
55 | scalar: 200 # log losses and scalar states (every N iterations)
56 | vis: 1000 # visualize results (every N iterations)
57 | val: 20 # validate on val set (every N epochs)
58 | ckpt: 50 # save checkpoint (every N epochs)
59 |
--------------------------------------------------------------------------------
/options/nerf_blender.yaml:
--------------------------------------------------------------------------------
1 | _parent_: options/base.yaml
2 |
3 | arch: # architectural options
4 | layers_feat: [null,256,256,256,256,256,256,256,256] # hidden layers for feature/density MLP
5 | layers_rgb: [null,128,3] # hidden layers for color MLP
6 | skip: [4] # skip connections
7 | posenc: # positional encoding
8 | L_3D: 10 # number of bases (3D point)
9 | L_view: 4 # number of bases (viewpoint)
10 | density_activ: softplus # activation function for output volume density
11 | tf_init: true # initialize network weights in TensorFlow style
12 |
13 | nerf: # NeRF-specific options
14 | view_dep: true # condition MLP on viewpoint
15 | depth: # depth-related options
16 | param: metric # depth parametrization (for sampling along the ray)
17 | range: [2,6] # near/far bounds for depth sampling
18 | sample_intvs: 128 # number of samples
19 | sample_stratified: true # stratified sampling
20 | fine_sampling: false # hierarchical sampling with another NeRF
21 | sample_intvs_fine: # number of samples for the fine NeRF
22 | rand_rays: 1024 # number of random rays for each step
23 | density_noise_reg: # Gaussian noise on density output as regularization
24 | setbg_opaque: false # fill transparent rendering with known background color (Blender only)
25 |
26 | data: # data options
27 | dataset: blender # dataset name
28 | scene: lego # scene name
29 | image_size: [400,400] # input image sizes [height,width]
30 | num_workers: 4 # number of parallel workers for data loading
31 | preload: true # preload the entire dataset into the memory
32 | bgcolor: 1 # background color (Blender only)
33 | val_sub: 4 # consider a subset of N validation samples
34 |
35 | camera: # camera options
36 | model: perspective # type of camera model
37 | ndc: false # reparametrize as normalized device coordinates (NDC)
38 |
39 | loss_weight: # loss weights (in log scale)
40 | render: 0 # RGB rendering loss
41 | render_fine: # RGB rendering loss (for fine NeRF)
42 |
43 | optim: # optimization options
44 | lr: 5.e-4 # learning rate (main)
45 | lr_end: 1.e-4 # terminal learning rate (only used with sched.type=ExponentialLR)
46 | sched: # learning rate scheduling options
47 | type: ExponentialLR # scheduler (see PyTorch doc)
48 | gamma: # decay rate (can be empty if lr_end were specified)
49 |
50 | batch_size: # batch size (not used for NeRF/BARF)
51 | max_epoch: # train to maximum number of epochs (not used for NeRF/BARF)
52 | max_iter: 200000 # train to maximum number of iterations
53 |
54 | trimesh: # options for marching cubes to extract 3D mesh
55 | res: 128 # 3D sampling resolution
56 | range: [-1.2,1.2] # 3D range of interest (assuming same for x,y,z)
57 | thres: 25. # volume density threshold for marching cubes
58 | chunk_size: 16384 # chunk size of dense samples to be evaluated at a time
59 |
60 | freq: # periodic actions during training
61 | scalar: 200 # log losses and scalar states (every N iterations)
62 | vis: 1000 # visualize results (every N iterations)
63 | val: 2000 # validate on val set (every N iterations)
64 | ckpt: 5000 # save checkpoint (every N iterations)
65 |
--------------------------------------------------------------------------------
/options/nerf_blender_repr.yaml:
--------------------------------------------------------------------------------
1 | _parent_: options/base.yaml
2 |
3 | arch: # architectural options
4 | layers_feat: [null,256,256,256,256,256,256,256,256] # hidden layers for feature/density MLP
5 | layers_rgb: [null,128,3] # hidden layers for color MLP
6 | skip: [4] # skip connections
7 | posenc: # positional encoding
8 | L_3D: 10 # number of bases (3D point)
9 | L_view: 4 # number of bases (viewpoint)
10 | density_activ: relu # activation function for output volume density
11 | tf_init: true # initialize network weights in TensorFlow style
12 |
13 | nerf: # NeRF-specific options
14 | view_dep: true # condition MLP on viewpoint
15 | depth: # depth-related options
16 | param: metric # depth parametrization (for sampling along the ray)
17 | range: [2,6] # near/far bounds for depth sampling
18 | sample_intvs: 64 # number of samples
19 | sample_stratified: true # stratified sampling
20 | fine_sampling: true # hierarchical sampling with another NeRF
21 | sample_intvs_fine: 128 # number of samples for the fine NeRF
22 | rand_rays: 1024 # number of random rays for each step
23 | density_noise_reg: 0 # Gaussian noise on density output as regularization
24 | setbg_opaque: true # fill transparent rendering with known background color (Blender only)
25 |
26 | data: # data options
27 | dataset: blender # dataset name
28 | scene: lego # scene name
29 | image_size: [400,400] # input image sizes [height,width]
30 | num_workers: 4 # number of parallel workers for data loading
31 | preload: true # preload the entire dataset into the memory
32 | bgcolor: 1 # background color (Blender only)
33 | val_sub: 4 # consider a subset of N validation samples
34 |
35 | camera: # camera options
36 | model: perspective # type of camera model
37 | ndc: false # reparametrize as normalized device coordinates (NDC)
38 |
39 | loss_weight: # loss weights (in log scale)
40 | render: 0 # RGB rendering loss
41 | render_fine: 0 # RGB rendering loss (for fine NeRF)
42 |
43 | optim: # optimization options
44 | lr: 5.e-4 # learning rate (main)
45 | lr_end: 5.e-5 # terminal learning rate (only used with sched.type=ExponentialLR)
46 | sched: # learning rate scheduling options
47 | type: ExponentialLR # scheduler (see PyTorch doc)
48 | gamma: # decay rate (can be empty if lr_end were specified)
49 |
50 | batch_size: # batch size (not used for NeRF/BARF)
51 | max_epoch: # train to maximum number of epochs (not used for NeRF/BARF)
52 | max_iter: 500000 # train to maximum number of iterations
53 |
54 | trimesh: # options for marching cubes to extract 3D mesh
55 | res: 128 # 3D sampling resolution
56 | range: [-1.2,1.2] # 3D range of interest (assuming same for x,y,z)
57 | thres: 25. # volume density threshold for marching cubes
58 | chunk_size: 16384 # chunk size of dense samples to be evaluated at a time
59 |
60 | freq: # periodic actions during training
61 | scalar: 200 # log losses and scalar states (every N iterations)
62 | vis: 1000 # visualize results (every N iterations)
63 | val: 2000 # validate on val set (every N iterations)
64 | ckpt: 5000 # save checkpoint (every N iterations)
65 |
--------------------------------------------------------------------------------
/options/nerf_llff.yaml:
--------------------------------------------------------------------------------
1 | _parent_: options/base.yaml
2 |
3 | arch: # architectural options
4 |     layers_feat: [null,256,256,256,256,256,256,256,256] # hidden layers for feature/density MLP
5 |     layers_rgb: [null,128,3] # hidden layers for color MLP
6 |     skip: [4] # skip connections
7 |     posenc: # positional encoding
8 | L_3D: 10 # number of bases (3D point)
9 | L_view: 4 # number of bases (viewpoint)
10 | density_activ: softplus # activation function for output volume density
11 | tf_init: true # initialize network weights in TensorFlow style
12 |
13 | nerf: # NeRF-specific options
14 | view_dep: true # condition MLP on viewpoint
15 | depth: # depth-related options
16 | param: inverse # depth parametrization (for sampling along the ray)
17 | range: [1,0] # near/far bounds for depth sampling
18 | sample_intvs: 128 # number of samples
19 | sample_stratified: true # stratified sampling
20 | fine_sampling: false # hierarchical sampling with another NeRF
21 | sample_intvs_fine: # number of samples for the fine NeRF
22 | rand_rays: 2048 # number of random rays for each step
23 | density_noise_reg: # Gaussian noise on density output as regularization
24 | setbg_opaque: # fill transparent rendering with known background color (Blender only)
25 |
26 | data: # data options
27 | dataset: llff # dataset name
28 | scene: fern # scene name
29 | image_size: [480,640] # input image sizes [height,width]
30 | num_workers: 4 # number of parallel workers for data loading
31 | preload: true # preload the entire dataset into the memory
32 | val_ratio: 0.1 # ratio of sequence split for validation
33 |
34 | camera: # camera options
35 | model: perspective # type of camera model
36 | ndc: false # reparametrize as normalized device coordinates (NDC)
37 |
38 | loss_weight: # loss weights (in log scale)
39 | render: 0 # RGB rendering loss
40 | render_fine: # RGB rendering loss (for fine NeRF)
41 |
42 | optim: # optimization options
43 | lr: 1.e-3 # learning rate (main)
44 | lr_end: 1.e-4 # terminal learning rate (only used with sched.type=ExponentialLR)
45 | sched: # learning rate scheduling options
46 | type: ExponentialLR # scheduler (see PyTorch doc)
47 | gamma: # decay rate (can be empty if lr_end were specified)
48 |
49 | batch_size: # batch size (not used for NeRF/BARF)
50 | max_epoch: # train to maximum number of epochs (not used for NeRF/BARF)
51 | max_iter: 200000 # train to maximum number of iterations
52 |
53 | freq: # periodic actions during training
54 | scalar: 200 # log losses and scalar states (every N iterations)
55 | vis: 1000 # visualize results (every N iterations)
56 | val: 2000 # validate on val set (every N iterations)
57 | ckpt: 5000 # save checkpoint (every N iterations)
58 |
--------------------------------------------------------------------------------
/options/nerf_llff_repr.yaml:
--------------------------------------------------------------------------------
1 | _parent_: options/base.yaml
2 |
3 | arch: # architectural options
4 | layers_feat: [null,256,256,256,256,256,256,256,256] # hidden layers for feature/density MLP
5 | layers_rgb: [null,128,3] # hidden layers for color MLP
6 | skip: [4] # skip connections
7 | posenc: # positional encoding
8 | L_3D: 10 # number of bases (3D point)
9 | L_view: 4 # number of bases (viewpoint)
10 | density_activ: relu # activation function for output volume density
11 | tf_init: true # initialize network weights in TensorFlow style
12 |
13 | nerf: # NeRF-specific options
14 | view_dep: true # condition MLP on viewpoint
15 | depth: # depth-related options
16 | param: metric # depth parametrization (for sampling along the ray)
17 | range: [0,1] # near/far bounds for depth sampling
18 | sample_intvs: 64 # number of samples
19 | sample_stratified: true # stratified sampling
20 | fine_sampling: true # hierarchical sampling with another NeRF
21 | sample_intvs_fine: 128 # number of samples for the fine NeRF
22 | rand_rays: 1024 # number of random rays for each step
23 | density_noise_reg: 1 # Gaussian noise on density output as regularization
24 | setbg_opaque: # fill transparent rendering with known background color (Blender only)
25 |
26 | data: # data options
27 | dataset: llff # dataset name
28 | scene: fern # scene name
29 | image_size: [480,640] # input image sizes [height,width]
30 | num_workers: 4 # number of parallel workers for data loading
31 | preload: true # preload the entire dataset into the memory
32 | val_ratio: 0.1 # ratio of sequence split for validation
33 |
34 | camera: # camera options
35 | model: perspective # type of camera model
36 | ndc: false # reparametrize as normalized device coordinates (NDC)
37 |
38 | loss_weight: # loss weights (in log scale)
39 | render: 0 # RGB rendering loss
40 | render_fine: 0 # RGB rendering loss (for fine NeRF)
41 |
42 | optim: # optimization options
43 | lr: 5.e-4 # learning rate (main)
44 | lr_end: 5.e-5 # terminal learning rate (only used with sched.type=ExponentialLR)
45 | sched: # learning rate scheduling options
46 | type: ExponentialLR # scheduler (see PyTorch doc)
47 | gamma: # decay rate (can be empty if lr_end were specified)
48 |
49 | batch_size: # batch size (not used for NeRF/BARF)
50 | max_epoch: # train to maximum number of epochs (not used for NeRF/BARF)
51 | max_iter: 500000 # train to maximum number of iterations
52 |
53 | freq: # periodic actions during training
54 | scalar: 200 # log losses and scalar states (every N iterations)
55 | vis: 1000 # visualize results (every N iterations)
56 | val: 2000 # validate on val set (every N iterations)
57 | ckpt: 5000 # save checkpoint (every N iterations)
58 |
--------------------------------------------------------------------------------
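The `loss_weight` entries above are given in log scale, so presumably each term is multiplied by 10**value when the total loss is assembled: `render: 0` and `render_fine: 0` both weight their RGB losses by 1, and leaving an entry empty drops that term. A hedged sketch (the numeric loss values are placeholders):

```python
# sketch only: log-scale weights become factors of 10**w; empty entries (None) disable a term
loss_weight = dict(render=0, render_fine=0)
losses = dict(render=0.023, render_fine=0.019)   # hypothetical per-term loss values
total = sum((10**w)*losses[k] for k,w in loss_weight.items() if w is not None)
```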
/options/planar.yaml:
--------------------------------------------------------------------------------
1 | _parent_: options/base.yaml
2 |
3 | arch: # architectural options
4 | layers: [null,256,256,256,256,3] # hidden layers for MLP
5 | skip: [] # skip connections
6 | posenc: # positional encoding
7 | L_2D: 8 # number of bases (2D point)
8 |
9 | barf_c2f: # coarse-to-fine scheduling on positional encoding
10 |
11 | data: # data options
12 | image_fname: data/cat.jpg # path to image file
13 | image_size: [360,480] # original image size
14 | patch_crop: [180,180] # crop size of image patches to align
15 |
16 | warp: # image warping options
17 | type: homography # type of warp function
18 | dof: 8 # degrees of freedom of the warp function
19 | noise_h: 0.1 # scale of pre-generated warp perturbation (homography)
20 | noise_t: 0.2 # scale of pre-generated warp perturbation (translation)
21 | fix_first: true # fix the first patch for uniqueness of solution
22 |
23 | loss_weight: # loss weights (in log scale)
24 | render: 0 # RGB rendering loss
25 |
26 | optim: # optimization options
27 | lr: 1.e-3 # learning rate (main)
28 | lr_warp: 1.e-3 # learning rate of warp parameters
29 |
30 | batch_size: 5 # batch size (set to number of patches to consider)
31 | max_iter: 5000 # train to maximum number of iterations
32 |
33 | visdom: # Visdom options (turned off)
34 |
35 | freq: # periodic actions during training
36 | scalar: 20 # log losses and scalar states (every N iterations)
37 | vis: 100 # visualize results (every N iterations)
38 |
--------------------------------------------------------------------------------
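`barf_c2f` is left empty above, i.e. no coarse-to-fine scheduling for this planar alignment config; when it is set (as in the BARF configs), it is expected to hold [start,end] fractions of training progress over which the positional-encoding frequency bands are gradually enabled. A minimal sketch of that schedule following the BARF paper (the function name and call signature are illustrative, not taken from this repository):

```python
import numpy as np
import torch

def c2f_weights(progress,barf_c2f,L):
    # smooth per-band weights in [0,1]: low frequencies switch on first, high frequencies later
    start,end = barf_c2f
    alpha = (progress-start)/(end-start)*L                 # how many bands are currently "active"
    k = torch.arange(L,dtype=torch.float32)
    return (1-torch.cos((alpha-k).clamp(min=0,max=1)*np.pi))/2

print(c2f_weights(progress=0.3,barf_c2f=[0.1,0.5],L=8))    # bands 0-3 fully on, bands 4-7 still off
```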
/requirements.yaml:
--------------------------------------------------------------------------------
1 | name: barf-env
2 | channels:
3 | - conda-forge
4 | - pytorch
5 | dependencies:
6 | - numpy
7 | - scipy
8 | - tqdm
9 | - termcolor
10 | - easydict
11 | - imageio
12 | - ipdb
13 | - pytorch>=1.9.0
14 | - torchvision
15 | - tensorboard
16 | - visdom
17 | - matplotlib
18 | - scikit-video
19 | - trimesh
20 | - pyyaml
21 | - pip
22 | - gdown
23 | - pip:
24 | - lpips
25 | - pymcubes
26 |
--------------------------------------------------------------------------------
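For completeness: the environment above can be instantiated with the standard conda workflow, e.g. `conda env create --file requirements.yaml` followed by `conda activate barf-env`; the trailing `pip:` section pulls in `lpips` and `pymcubes` on top of the conda packages.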
/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os,sys,time
3 | import torch
4 | import importlib
5 |
6 | import options
7 | from util import log
8 |
9 | def main():
10 |
11 | log.process(os.getpid())
12 | log.title("[{}] (PyTorch code for training NeRF/BARF)".format(sys.argv[0]))
13 |
14 | opt_cmd = options.parse_arguments(sys.argv[1:])
15 | opt = options.set(opt_cmd=opt_cmd)
16 | options.save_options_file(opt)
17 |
18 | with torch.cuda.device(opt.device):
19 |
20 | model = importlib.import_module("model.{}".format(opt.model))
21 | m = model.Model(opt)
22 |
23 | m.load_dataset(opt)
24 | m.build_networks(opt)
25 | m.setup_optimizer(opt)
26 | m.restore_checkpoint(opt)
27 | m.setup_visualizer(opt)
28 |
29 | m.train(opt)
30 |
31 | if __name__=="__main__":
32 | main()
33 |
--------------------------------------------------------------------------------
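main() resolves the trainer dynamically from opt.model; a minimal sketch of what that dispatch amounts to, assuming (as the model/ layout suggests) that each model module exposes a `Model` class with the methods called above:

```python
import importlib
# sketch only: opt.model is e.g. "nerf", "barf", or "planar"
module = importlib.import_module("model.nerf")   # resolves to model/nerf.py
m = module.Model(opt)                            # `opt` is the options object built by options.set(...)
```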
/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os,sys,time
3 | import shutil
4 | import datetime
5 | import torch
6 | import torch.nn.functional as torch_F
7 | import ipdb
8 | import types
9 | import termcolor
10 | import socket
11 | import contextlib
12 | from easydict import EasyDict as edict
13 |
14 | # convert to colored strings
15 | def red(message,**kwargs): return termcolor.colored(str(message),color="red",attrs=[k for k,v in kwargs.items() if v is True])
16 | def green(message,**kwargs): return termcolor.colored(str(message),color="green",attrs=[k for k,v in kwargs.items() if v is True])
17 | def blue(message,**kwargs): return termcolor.colored(str(message),color="blue",attrs=[k for k,v in kwargs.items() if v is True])
18 | def cyan(message,**kwargs): return termcolor.colored(str(message),color="cyan",attrs=[k for k,v in kwargs.items() if v is True])
19 | def yellow(message,**kwargs): return termcolor.colored(str(message),color="yellow",attrs=[k for k,v in kwargs.items() if v is True])
20 | def magenta(message,**kwargs): return termcolor.colored(str(message),color="magenta",attrs=[k for k,v in kwargs.items() if v is True])
21 | def grey(message,**kwargs): return termcolor.colored(str(message),color="grey",attrs=[k for k,v in kwargs.items() if v is True])
22 |
23 | def get_time(sec):
24 | d = int(sec//(24*60*60))
25 | h = int(sec//(60*60)%24)
26 | m = int((sec//60)%60)
27 | s = int(sec%60)
28 | return d,h,m,s
29 |
30 | def add_datetime(func):
31 | def wrapper(*args,**kwargs):
32 | datetime_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
33 | print(grey("[{}] ".format(datetime_str),bold=True),end="")
34 | return func(*args,**kwargs)
35 | return wrapper
36 |
37 | def add_functionname(func):
38 | def wrapper(*args,**kwargs):
39 | print(grey("[{}] ".format(func.__name__),bold=True))
40 | return func(*args,**kwargs)
41 | return wrapper
42 |
43 | def pre_post_actions(pre=None,post=None):
44 | def func_decorator(func):
45 | def wrapper(*args,**kwargs):
46 | if pre: pre()
47 | retval = func(*args,**kwargs)
48 | if post: post()
49 | return retval
50 | return wrapper
51 | return func_decorator
52 |
53 | debug = ipdb.set_trace
54 |
55 | class Log:
56 | def __init__(self): pass
57 | def process(self,pid):
58 | print(grey("Process ID: {}".format(pid),bold=True))
59 | def title(self,message):
60 | print(yellow(message,bold=True,underline=True))
61 | def info(self,message):
62 | print(magenta(message,bold=True))
63 | def options(self,opt,level=0):
64 | for key,value in sorted(opt.items()):
65 | if isinstance(value,(dict,edict)):
66 | print(" "*level+cyan("* ")+green(key)+":")
67 | self.options(value,level+1)
68 | else:
69 | print(" "*level+cyan("* ")+green(key)+":",yellow(value))
70 | def loss_train(self,opt,ep,lr,loss,timer):
71 | if not opt.max_epoch: return
72 | message = grey("[train] ",bold=True)
73 | message += "epoch {}/{}".format(cyan(ep,bold=True),opt.max_epoch)
74 | message += ", lr:{}".format(yellow("{:.2e}".format(lr),bold=True))
75 | message += ", loss:{}".format(red("{:.3e}".format(loss),bold=True))
76 | message += ", time:{}".format(blue("{0}-{1:02d}:{2:02d}:{3:02d}".format(*get_time(timer.elapsed)),bold=True))
77 | message += " (ETA:{})".format(blue("{0}-{1:02d}:{2:02d}:{3:02d}".format(*get_time(timer.arrival))))
78 | print(message)
79 | def loss_val(self,opt,loss):
80 | message = grey("[val] ",bold=True)
81 | message += "loss:{}".format(red("{:.3e}".format(loss),bold=True))
82 | print(message)
83 | log = Log()
84 |
85 | def update_timer(opt,timer,ep,it_per_ep):
86 | if not opt.max_epoch: return
87 | momentum = 0.99
88 | timer.elapsed = time.time()-timer.start
89 | timer.it = timer.it_end-timer.it_start
90 | # compute speed with moving average
91 | timer.it_mean = timer.it_mean*momentum+timer.it*(1-momentum) if timer.it_mean is not None else timer.it
92 | timer.arrival = timer.it_mean*it_per_ep*(opt.max_epoch-ep)
93 |
94 | # move tensors to device in-place
95 | def move_to_device(X,device):
96 | if isinstance(X,dict):
97 | for k,v in X.items():
98 | X[k] = move_to_device(v,device)
99 | elif isinstance(X,list):
100 | for i,e in enumerate(X):
101 | X[i] = move_to_device(e,device)
102 | elif isinstance(X,tuple) and hasattr(X,"_fields"): # collections.namedtuple
103 | dd = X._asdict()
104 | dd = move_to_device(dd,device)
105 | return type(X)(**dd)
106 | elif isinstance(X,torch.Tensor):
107 | return X.to(device=device)
108 | return X
109 |
110 | def to_dict(D,dict_type=dict):
111 | D = dict_type(D)
112 | for k,v in D.items():
113 | if isinstance(v,dict):
114 | D[k] = to_dict(v,dict_type)
115 | return D
116 |
117 | def get_child_state_dict(state_dict,key):
118 | return { ".".join(k.split(".")[1:]): v for k,v in state_dict.items() if k.startswith("{}.".format(key)) }
119 |
120 | def restore_checkpoint(opt,model,load_name=None,resume=False):
121 | assert((load_name is None)==(resume is not False)) # resume can be True/False or epoch numbers
122 | if resume:
123 | load_name = "{0}/model.ckpt".format(opt.output_path) if resume is True else \
124 | "{0}/model/{1}.ckpt".format(opt.output_path,resume)
125 | checkpoint = torch.load(load_name,map_location=opt.device)
126 | # load individual (possibly partial) children modules
127 | for name,child in model.graph.named_children():
128 | child_state_dict = get_child_state_dict(checkpoint["graph"],name)
129 | if child_state_dict:
130 | print("restoring {}...".format(name))
131 | child.load_state_dict(child_state_dict)
132 | for key in model.__dict__:
133 | if key.split("_")[0] in ["optim","sched"] and key in checkpoint and resume:
134 | print("restoring {}...".format(key))
135 | getattr(model,key).load_state_dict(checkpoint[key])
136 | if resume:
137 | ep,it = checkpoint["epoch"],checkpoint["iter"]
138 | if resume is not True: assert(resume==(ep or it))
139 | print("resuming from epoch {0} (iteration {1})".format(ep,it))
140 | else: ep,it = None,None
141 | return ep,it
142 |
143 | def save_checkpoint(opt,model,ep,it,latest=False,children=None):
144 | os.makedirs("{0}/model".format(opt.output_path),exist_ok=True)
145 | if children is not None:
146 | graph_state_dict = { k: v for k,v in model.graph.state_dict().items() if k.startswith(children) }
147 | else: graph_state_dict = model.graph.state_dict()
148 | checkpoint = dict(
149 | epoch=ep,
150 | iter=it,
151 | graph=graph_state_dict,
152 | )
153 | for key in model.__dict__:
154 | if key.split("_")[0] in ["optim","sched"]:
155 | checkpoint.update({ key: getattr(model,key).state_dict() })
156 | torch.save(checkpoint,"{0}/model.ckpt".format(opt.output_path))
157 | if not latest:
158 | shutil.copy("{0}/model.ckpt".format(opt.output_path),
159 | "{0}/model/{1}.ckpt".format(opt.output_path,ep or it)) # if ep is None, track it instead
160 |
161 | def check_socket_open(hostname,port):
162 | s = socket.socket(socket.AF_INET,socket.SOCK_STREAM)
163 | is_open = False
164 | try:
165 | s.bind((hostname,port))
166 | except socket.error:
167 | is_open = True
168 | finally:
169 | s.close()
170 | return is_open
171 |
172 | def get_layer_dims(layers):
173 | # return a list of tuples (k_in,k_out)
174 | return list(zip(layers[:-1],layers[1:]))
175 |
176 | @contextlib.contextmanager
177 | def suppress(stdout=False,stderr=False):
178 | with open(os.devnull,"w") as devnull:
179 | if stdout: old_stdout,sys.stdout = sys.stdout,devnull
180 | if stderr: old_stderr,sys.stderr = sys.stderr,devnull
181 | try: yield
182 | finally:
183 | if stdout: sys.stdout = old_stdout
184 | if stderr: sys.stderr = old_stderr
185 |
186 | def colorcode_to_number(code):
187 | ords = [ord(c) for c in code[1:]]
188 | ords = [n-48 if n<58 else n-87 for n in ords]
189 | rgb = (ords[0]*16+ords[1],ords[2]*16+ords[3],ords[4]*16+ords[5])
190 | return rgb
191 |
--------------------------------------------------------------------------------
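Two of the small helpers above are easiest to read from a worked example (a sketch, run from the repository root so util.py is importable):

```python
from util import get_layer_dims,colorcode_to_number

print(get_layer_dims([None,256,256,3]))   # [(None, 256), (256, 256), (256, 3)] -- (k_in,k_out) pairs
print(colorcode_to_number("#ff8800"))     # (255, 136, 0) -- "#rrggbb" hex string to an RGB tuple
```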
/util_vis.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os,sys,time
3 | import torch
4 | import torch.nn.functional as torch_F
5 | import torchvision
6 | import torchvision.transforms.functional as torchvision_F
7 | import matplotlib.pyplot as plt
8 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection
9 | import PIL
10 | import imageio
11 | from easydict import EasyDict as edict
12 |
13 | import camera
14 |
15 | @torch.no_grad()
16 | def tb_image(opt,tb,step,group,name,images,num_vis=None,from_range=(0,1),cmap="gray"):
17 | images = preprocess_vis_image(opt,images,from_range=from_range,cmap=cmap)
18 | num_H,num_W = num_vis or opt.tb.num_images
19 | images = images[:num_H*num_W]
20 | image_grid = torchvision.utils.make_grid(images[:,:3],nrow=num_W,pad_value=1.)
21 | if images.shape[1]==4:
22 | mask_grid = torchvision.utils.make_grid(images[:,3:],nrow=num_W,pad_value=1.)[:1]
23 | image_grid = torch.cat([image_grid,mask_grid],dim=0)
24 | tag = "{0}/{1}".format(group,name)
25 | tb.add_image(tag,image_grid,step)
26 |
27 | def preprocess_vis_image(opt,images,masks=None,from_range=(0,1),cmap="gray"):
28 | min,max = from_range
29 | images = ((images-min)/(max-min)).clamp(min=0,max=1).cpu() # rescale to [0,1]
30 | if masks is not None: images = torch.cat([images,masks.cpu()],dim=1) # append mask as alpha channel
31 | if images.shape[1]==1:
32 | images = get_heatmap(opt,images[:,0].cpu(),cmap=cmap)
33 | return images
34 |
35 | def dump_images(opt,idx,name,images,masks=None,from_range=(0,1),cmap="gray"):
36 | images = preprocess_vis_image(opt,images,masks=masks,from_range=from_range,cmap=cmap) # [B,3,H,W]
37 | images = images.cpu().permute(0,2,3,1).numpy() # [B,H,W,3]
38 | for i,img in zip(idx,images):
39 | fname = "{}/dump/{}_{}.png".format(opt.output_path,i,name)
40 | img_uint8 = (img*255).astype(np.uint8)
41 | imageio.imsave(fname,img_uint8)
42 |
43 | def get_heatmap(opt,gray,cmap): # [N,H,W]
44 | color = plt.get_cmap(cmap)(gray.numpy())
45 | color = torch.from_numpy(color[...,:3]).permute(0,3,1,2).float() # [N,3,H,W]
46 | return color
47 |
48 | def color_border(images,colors,width=3):
49 | images_pad = []
50 | for i,image in enumerate(images):
51 | image_pad = torch.ones(3,image.shape[1]+width*2,image.shape[2]+width*2)*(colors[i,:,None,None]/255.0)
52 | image_pad[:,width:-width,width:-width] = image
53 | images_pad.append(image_pad)
54 | images_pad = torch.stack(images_pad,dim=0)
55 | return images_pad
56 |
57 | @torch.no_grad()
58 | def vis_cameras(opt,vis,step,poses=[],colors=["blue","magenta"],plot_dist=True):
59 | win_name = "{}/{}".format(opt.group,opt.name)
60 | data = []
61 | # set up plots
62 | centers = []
63 | for pose,color in zip(poses,colors):
64 | pose = pose.detach().cpu()
65 | vertices,faces,wireframe = get_camera_mesh(pose,depth=opt.visdom.cam_depth)
66 | center = vertices[:,-1]
67 | centers.append(center)
68 | # camera centers
69 | data.append(dict(
70 | type="scatter3d",
71 | x=[float(n) for n in center[:,0]],
72 | y=[float(n) for n in center[:,1]],
73 | z=[float(n) for n in center[:,2]],
74 | mode="markers",
75 | marker=dict(color=color,size=3),
76 | ))
77 | # colored camera mesh
78 | vertices_merged,faces_merged = merge_meshes(vertices,faces)
79 | data.append(dict(
80 | type="mesh3d",
81 | x=[float(n) for n in vertices_merged[:,0]],
82 | y=[float(n) for n in vertices_merged[:,1]],
83 | z=[float(n) for n in vertices_merged[:,2]],
84 | i=[int(n) for n in faces_merged[:,0]],
85 | j=[int(n) for n in faces_merged[:,1]],
86 | k=[int(n) for n in faces_merged[:,2]],
87 | flatshading=True,
88 | color=color,
89 | opacity=0.05,
90 | ))
91 | # camera wireframe
92 | wireframe_merged = merge_wireframes(wireframe)
93 | data.append(dict(
94 | type="scatter3d",
95 | x=wireframe_merged[0],
96 | y=wireframe_merged[1],
97 | z=wireframe_merged[2],
98 | mode="lines",
99 | line=dict(color=color,),
100 | opacity=0.3,
101 | ))
102 | if plot_dist:
103 | # distance between two poses (camera centers)
104 | center_merged = merge_centers(centers[:2])
105 | data.append(dict(
106 | type="scatter3d",
107 | x=center_merged[0],
108 | y=center_merged[1],
109 | z=center_merged[2],
110 | mode="lines",
111 | line=dict(color="red",width=4,),
112 | ))
113 | if len(centers)==4:
114 | center_merged = merge_centers(centers[2:4])
115 | data.append(dict(
116 | type="scatter3d",
117 | x=center_merged[0],
118 | y=center_merged[1],
119 | z=center_merged[2],
120 | mode="lines",
121 | line=dict(color="red",width=4,),
122 | ))
123 | # send data to visdom
124 | vis._send(dict(
125 | data=data,
126 | win="poses",
127 | eid=win_name,
128 | layout=dict(
129 | title="({})".format(step),
130 | autosize=True,
131 | margin=dict(l=30,r=30,b=30,t=30,),
132 | showlegend=False,
133 | yaxis=dict(
134 | scaleanchor="x",
135 | scaleratio=1,
136 | )
137 | ),
138 | opts=dict(title="{} poses ({})".format(win_name,step),),
139 | ))
140 |
141 | def get_camera_mesh(pose,depth=1):
142 | vertices = torch.tensor([[-0.5,-0.5,1],
143 | [0.5,-0.5,1],
144 | [0.5,0.5,1],
145 | [-0.5,0.5,1],
146 | [0,0,0]])*depth
147 | faces = torch.tensor([[0,1,2],
148 | [0,2,3],
149 | [0,1,4],
150 | [1,2,4],
151 | [2,3,4],
152 | [3,0,4]])
153 | vertices = camera.cam2world(vertices[None],pose)
154 | wireframe = vertices[:,[0,1,2,3,0,4,1,2,4,3]]
155 | return vertices,faces,wireframe
156 |
157 | def merge_wireframes(wireframe):
158 | wireframe_merged = [[],[],[]]
159 | for w in wireframe:
160 | wireframe_merged[0] += [float(n) for n in w[:,0]]+[None]
161 | wireframe_merged[1] += [float(n) for n in w[:,1]]+[None]
162 | wireframe_merged[2] += [float(n) for n in w[:,2]]+[None]
163 | return wireframe_merged
164 | def merge_meshes(vertices,faces):
165 | mesh_N,vertex_N = vertices.shape[:2]
166 | faces_merged = torch.cat([faces+i*vertex_N for i in range(mesh_N)],dim=0)
167 | vertices_merged = vertices.view(-1,vertices.shape[-1])
168 | return vertices_merged,faces_merged
169 | def merge_centers(centers):
170 | center_merged = [[],[],[]]
171 | for c1,c2 in zip(*centers):
172 | center_merged[0] += [float(c1[0]),float(c2[0]),None]
173 | center_merged[1] += [float(c1[1]),float(c2[1]),None]
174 | center_merged[2] += [float(c1[2]),float(c2[2]),None]
175 | return center_merged
176 |
177 | def plot_save_poses(opt,fig,pose,pose_ref=None,path=None,ep=None):
178 | # get the camera meshes
179 | _,_,cam = get_camera_mesh(pose,depth=opt.visdom.cam_depth)
180 | cam = cam.numpy()
181 | if pose_ref is not None:
182 | _,_,cam_ref = get_camera_mesh(pose_ref,depth=opt.visdom.cam_depth)
183 | cam_ref = cam_ref.numpy()
184 | # set up plot window(s)
185 | plt.title("epoch {}".format(ep))
186 | ax1 = fig.add_subplot(121,projection="3d")
187 | ax2 = fig.add_subplot(122,projection="3d")
188 | setup_3D_plot(ax1,elev=-90,azim=-90,lim=edict(x=(-1,1),y=(-1,1),z=(-1,1)))
189 | setup_3D_plot(ax2,elev=0,azim=-90,lim=edict(x=(-1,1),y=(-1,1),z=(-1,1)))
190 | ax1.set_title("forward-facing view",pad=0)
191 | ax2.set_title("top-down view",pad=0)
192 | plt.subplots_adjust(left=0,right=1,bottom=0,top=0.95,wspace=0,hspace=0)
193 | plt.margins(tight=True,x=0,y=0)
194 | # plot the cameras
195 | N = len(cam)
196 | color = plt.get_cmap("gist_rainbow")
197 | for i in range(N):
198 | if pose_ref is not None:
199 | ax1.plot(cam_ref[i,:,0],cam_ref[i,:,1],cam_ref[i,:,2],color=(0.3,0.3,0.3),linewidth=1)
200 | ax2.plot(cam_ref[i,:,0],cam_ref[i,:,1],cam_ref[i,:,2],color=(0.3,0.3,0.3),linewidth=1)
201 | ax1.scatter(cam_ref[i,5,0],cam_ref[i,5,1],cam_ref[i,5,2],color=(0.3,0.3,0.3),s=40)
202 | ax2.scatter(cam_ref[i,5,0],cam_ref[i,5,1],cam_ref[i,5,2],color=(0.3,0.3,0.3),s=40)
203 | c = np.array(color(float(i)/N))*0.8
204 | ax1.plot(cam[i,:,0],cam[i,:,1],cam[i,:,2],color=c)
205 | ax2.plot(cam[i,:,0],cam[i,:,1],cam[i,:,2],color=c)
206 | ax1.scatter(cam[i,5,0],cam[i,5,1],cam[i,5,2],color=c,s=40)
207 | ax2.scatter(cam[i,5,0],cam[i,5,1],cam[i,5,2],color=c,s=40)
208 | png_fname = "{}/{}.png".format(path,ep)
209 | plt.savefig(png_fname,dpi=75)
210 | # clean up
211 | plt.clf()
212 |
213 | def plot_save_poses_blender(opt,fig,pose,pose_ref=None,path=None,ep=None):
214 | # get the camera meshes
215 | _,_,cam = get_camera_mesh(pose,depth=opt.visdom.cam_depth)
216 | cam = cam.numpy()
217 | if pose_ref is not None:
218 | _,_,cam_ref = get_camera_mesh(pose_ref,depth=opt.visdom.cam_depth)
219 | cam_ref = cam_ref.numpy()
220 | # set up plot window(s)
221 | ax = fig.add_subplot(111,projection="3d")
222 | ax.set_title("epoch {}".format(ep),pad=0)
223 | setup_3D_plot(ax,elev=45,azim=35,lim=edict(x=(-3,3),y=(-3,3),z=(-3,2.4)))
224 | plt.subplots_adjust(left=0,right=1,bottom=0,top=0.95,wspace=0,hspace=0)
225 | plt.margins(tight=True,x=0,y=0)
226 | # plot the cameras
227 | N = len(cam)
228 | ref_color = (0.7,0.2,0.7)
229 | pred_color = (0,0.6,0.7)
230 | ax.add_collection3d(Poly3DCollection([v[:4] for v in cam_ref],alpha=0.2,facecolor=ref_color))
231 | for i in range(N):
232 | ax.plot(cam_ref[i,:,0],cam_ref[i,:,1],cam_ref[i,:,2],color=ref_color,linewidth=0.5)
233 | ax.scatter(cam_ref[i,5,0],cam_ref[i,5,1],cam_ref[i,5,2],color=ref_color,s=20)
234 | if ep==0:
235 | png_fname = "{}/GT.png".format(path)
236 | plt.savefig(png_fname,dpi=75)
237 | ax.add_collection3d(Poly3DCollection([v[:4] for v in cam],alpha=0.2,facecolor=pred_color))
238 | for i in range(N):
239 | ax.plot(cam[i,:,0],cam[i,:,1],cam[i,:,2],color=pred_color,linewidth=1)
240 | ax.scatter(cam[i,5,0],cam[i,5,1],cam[i,5,2],color=pred_color,s=20)
241 | for i in range(N):
242 | ax.plot([cam[i,5,0],cam_ref[i,5,0]],
243 | [cam[i,5,1],cam_ref[i,5,1]],
244 | [cam[i,5,2],cam_ref[i,5,2]],color=(1,0,0),linewidth=3)
245 | png_fname = "{}/{}.png".format(path,ep)
246 | plt.savefig(png_fname,dpi=75)
247 | # clean up
248 | plt.clf()
249 |
250 | def setup_3D_plot(ax,elev,azim,lim=None):
251 | ax.xaxis.set_pane_color((1.0,1.0,1.0,0.0))
252 | ax.yaxis.set_pane_color((1.0,1.0,1.0,0.0))
253 | ax.zaxis.set_pane_color((1.0,1.0,1.0,0.0))
254 | ax.xaxis._axinfo["grid"]["color"] = (0.9,0.9,0.9,1)
255 | ax.yaxis._axinfo["grid"]["color"] = (0.9,0.9,0.9,1)
256 | ax.zaxis._axinfo["grid"]["color"] = (0.9,0.9,0.9,1)
257 | ax.xaxis.set_tick_params(labelsize=8)
258 | ax.yaxis.set_tick_params(labelsize=8)
259 | ax.zaxis.set_tick_params(labelsize=8)
260 | ax.set_xlabel("X",fontsize=16)
261 | ax.set_ylabel("Y",fontsize=16)
262 | ax.set_zlabel("Z",fontsize=16)
263 | ax.set_xlim(lim.x[0],lim.x[1])
264 | ax.set_ylim(lim.y[0],lim.y[1])
265 | ax.set_zlim(lim.z[0],lim.z[1])
266 | ax.view_init(elev=elev,azim=azim)
267 |
--------------------------------------------------------------------------------
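A note on the camera glyphs drawn above: get_camera_mesh builds five vertices per camera (four image-plane corners at z=depth plus the camera center at the origin) and reorders them into a ten-point wireframe polyline whose index 5 lands back on the camera center, which is why the plotting code scatters cam[i,5]. A small sketch (the identity pose and the [B,3,4] pose shape are assumptions about camera.cam2world):

```python
import torch
from util_vis import get_camera_mesh

pose = torch.eye(3,4)[None]                       # assumed [B,3,4] camera pose (identity here)
vertices,faces,wireframe = get_camera_mesh(pose,depth=0.5)
print(vertices.shape,wireframe.shape)             # [1,5,3] and [1,10,3]
print(wireframe[0,5])                             # the camera center (vertex 4 revisited)
```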
/warp.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os,sys,time
3 | import torch
4 | import torch.nn.functional as torch_F
5 |
6 | import util
7 | from util import log,debug
8 | import camera
9 |
10 | def get_normalized_pixel_grid(opt):
11 | y_range = ((torch.arange(opt.H,dtype=torch.float32,device=opt.device)+0.5)/opt.H*2-1)*(opt.H/max(opt.H,opt.W))
12 | x_range = ((torch.arange(opt.W,dtype=torch.float32,device=opt.device)+0.5)/opt.W*2-1)*(opt.W/max(opt.H,opt.W))
13 | Y,X = torch.meshgrid(y_range,x_range) # [H,W]
14 | xy_grid = torch.stack([X,Y],dim=-1).view(-1,2) # [HW,2]
15 | xy_grid = xy_grid.repeat(opt.batch_size,1,1) # [B,HW,2]
16 | return xy_grid
17 |
18 | def get_normalized_pixel_grid_crop(opt):
19 | y_crop = (opt.H//2-opt.H_crop//2,opt.H//2+opt.H_crop//2)
20 | x_crop = (opt.W//2-opt.W_crop//2,opt.W//2+opt.W_crop//2)
21 | y_range = ((torch.arange(*(y_crop),dtype=torch.float32,device=opt.device)+0.5)/opt.H*2-1)*(opt.H/max(opt.H,opt.W))
22 | x_range = ((torch.arange(*(x_crop),dtype=torch.float32,device=opt.device)+0.5)/opt.W*2-1)*(opt.W/max(opt.H,opt.W))
23 | Y,X = torch.meshgrid(y_range,x_range) # [H,W]
24 | xy_grid = torch.stack([X,Y],dim=-1).view(-1,2) # [HW,2]
25 | xy_grid = xy_grid.repeat(opt.batch_size,1,1) # [B,HW,2]
26 | return xy_grid
27 |
28 | def warp_grid(opt,xy_grid,warp):
29 | if opt.warp.type=="translation":
30 | assert(opt.warp.dof==2)
31 | warped_grid = xy_grid+warp[...,None,:]
32 | elif opt.warp.type=="rotation":
33 | assert(opt.warp.dof==1)
34 | warp_matrix = lie.so2_to_SO2(warp)
35 | warped_grid = xy_grid@warp_matrix.transpose(-2,-1) # [B,HW,2]
36 | elif opt.warp.type=="rigid":
37 | assert(opt.warp.dof==3)
38 | xy_grid_hom = camera.to_hom(xy_grid)
39 | warp_matrix = lie.se2_to_SE2(warp)
40 | warped_grid = xy_grid_hom@warp_matrix.transpose(-2,-1) # [B,HW,2]
41 | elif opt.warp.type=="homography":
42 | assert(opt.warp.dof==8)
43 | xy_grid_hom = camera.to_hom(xy_grid)
44 | warp_matrix = lie.sl3_to_SL3(warp)
45 | warped_grid_hom = xy_grid_hom@warp_matrix.transpose(-2,-1)
46 | warped_grid = warped_grid_hom[...,:2]/(warped_grid_hom[...,2:]+1e-8) # [B,HW,2]
47 | else: assert(False)
48 | return warped_grid
49 |
50 | def warp_corners(opt,warp_param):
51 | y_crop = (opt.H//2-opt.H_crop//2,opt.H//2+opt.H_crop//2)
52 | x_crop = (opt.W//2-opt.W_crop//2,opt.W//2+opt.W_crop//2)
53 | Y = [((y+0.5)/opt.H*2-1)*(opt.H/max(opt.H,opt.W)) for y in y_crop]
54 | X = [((x+0.5)/opt.W*2-1)*(opt.W/max(opt.H,opt.W)) for x in x_crop]
55 | corners = [(X[0],Y[0]),(X[0],Y[1]),(X[1],Y[1]),(X[1],Y[0])]
56 | corners = torch.tensor(corners,dtype=torch.float32,device=opt.device).repeat(opt.batch_size,1,1)
57 | corners_warped = warp_grid(opt,corners,warp_param)
58 | return corners_warped
59 |
60 | def check_corners_in_range(opt,warp_param):
61 | corners_all = warp_corners(opt,warp_param)
62 | X = (corners_all[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5
63 | Y = (corners_all[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5
64 | return (0<=X).all() and (X<opt.W).all() and (0<=Y).all() and (Y<opt.H).all()
134 | def taylor_A(self,x,nth=10):
135 | # Taylor expansion of sin(x)/x
136 | ans = torch.zeros_like(x)
137 | denom = 1.
138 | for i in range(nth+1):
139 | if i>0: denom *= (2*i)*(2*i+1)
140 | ans = ans+(-1)**i*x**(2*i)/denom
141 | return ans
142 | def taylor_B(self,x,nth=10):
143 | # Taylor expansion of (1-cos(x))/x
144 | ans = torch.zeros_like(x)
145 | denom = 1.
146 | for i in range(nth+1):
147 | denom *= (2*i+1)*(2*i+2)
148 | ans = ans+(-1)**i*x**(2*i+1)/denom
149 | return ans
150 |
151 | def taylor_C(self,x,nth=10):
152 | # Taylor expansion of (x*cos(x)-sin(x))/x**2
153 | ans = torch.zeros_like(x)
154 | denom = 1.
155 | for i in range(nth+1):
156 | denom *= (2*i+2)*(2*i+3)
157 | ans = ans+(-1)**(i+1)*x**(2*i+1)*(2*i+2)/denom
158 | return ans
159 |
160 | def taylor_D(self,x,nth=10):
161 | # Taylor expansion of (x*sin(x)+cos(x)-1)/x**2
162 | ans = torch.zeros_like(x)
163 | denom = 1.
164 | for i in range(nth+1):
165 | denom *= (2*i+1)*(2*i+2)
166 | ans = ans+(-1)**i*x**(2*i)*(2*i+1)/denom
167 | return ans
168 |
169 | lie = Lie()
170 |
--------------------------------------------------------------------------------
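The Taylor helpers at the end of the Lie class evaluate sin(x)/x-style coefficients stably near x=0, where the closed forms divide by zero; a quick sanity-check sketch:

```python
import torch
from warp import lie

x = torch.tensor([0.1,0.5,1.0])
print(lie.taylor_A(x), torch.sin(x)/x)        # ~ sin(x)/x
print(lie.taylor_B(x), (1-torch.cos(x))/x)    # ~ (1-cos(x))/x
# as x -> 0 the closed forms lose precision or divide by zero; the truncated series stay well-behaved
```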