├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── camera.py ├── data ├── base.py ├── blender.py ├── cat.jpg ├── iphone.py └── llff.py ├── evaluate.py ├── extract_mesh.py ├── model ├── barf.py ├── base.py ├── nerf.py └── planar.py ├── options.py ├── options ├── barf_blender.yaml ├── barf_iphone.yaml ├── barf_llff.yaml ├── base.yaml ├── nerf_blender.yaml ├── nerf_blender_repr.yaml ├── nerf_llff.yaml ├── nerf_llff_repr.yaml └── planar.yaml ├── requirements.yaml ├── train.py ├── util.py ├── util_vis.py └── warp.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "external/pohsun_ssim"] 2 | path = external/pohsun_ssim 3 | url = https://github.com/Po-Hsun-Su/pytorch-ssim -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Chen-Hsuan Lin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## BARF :vomiting_face:: Bundle-Adjusting Neural Radiance Fields 2 | [Chen-Hsuan Lin](https://chenhsuanlin.bitbucket.io/), 3 | [Wei-Chiu Ma](http://people.csail.mit.edu/weichium/), 4 | [Antonio Torralba](https://groups.csail.mit.edu/vision/torralbalab/), 5 | and [Simon Lucey](http://ci2cv.net/people/simon-lucey/) 6 | IEEE International Conference on Computer Vision (ICCV), 2021 (**oral presentation**) 7 | 8 | Project page: https://chenhsuanlin.bitbucket.io/bundle-adjusting-NeRF 9 | Paper: https://chenhsuanlin.bitbucket.io/bundle-adjusting-NeRF/paper.pdf 10 | arXiv preprint: https://arxiv.org/abs/2104.06405 11 | 12 | We provide PyTorch code for all experiments: planar image alignment, NeRF/BARF on both synthetic (Blender) and real-world (LLFF) datasets, and a template for BARFing on your custom sequence. 13 | 14 | -------------------------------------- 15 | 16 | ### Prerequisites 17 | 18 | - Note: for Azure ML support for this repository, please consider checking out [this branch](https://github.com/szymanowiczs/bundle-adjusting-NeRF/tree/azureml_training_script) by Stan Szymanowicz. 
19 | 20 | This code is developed with Python3 (`python3`). PyTorch 1.9+ is required. 21 | It is recommended to use [Anaconda](https://www.anaconda.com/products/individual) to set up the environment. Install the dependencies and activate the environment `barf-env` with 22 | ```bash 23 | conda env create --file requirements.yaml python=3 24 | conda activate barf-env 25 | ``` 26 | Initialize the external submodule dependencies with 27 | ```bash 28 | git submodule update --init --recursive 29 | ``` 30 | 31 | -------------------------------------- 32 | 33 | ### Dataset 34 | 35 | - #### Synthetic data (Blender) and real-world data (LLFF) 36 | Both the Blender synthetic data and LLFF real-world data can be found in the [NeRF Google Drive](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1). 37 | For convenience, you can download them with the following script (run under this repo): 38 | ```bash 39 | # Blender 40 | gdown --id 18JxhpWD-4ZmuFKLzKlAw-w5PpzZxXOcG # download nerf_synthetic.zip 41 | unzip nerf_synthetic.zip 42 | rm -f nerf_synthetic.zip 43 | mv nerf_synthetic data/blender 44 | # LLFF 45 | gdown --id 16VnMcF1KJYxN9QId6TClMsZRahHNMW5g # download nerf_llff_data.zip 46 | unzip nerf_llff_data.zip 47 | rm -f nerf_llff_data.zip 48 | mv nerf_llff_data data/llff 49 | ``` 50 | The `data` directory should contain the subdirectories `blender` and `llff`. 51 | If you already have the datasets downloaded, you can alternatively soft-link them within the `data` directory. 52 | 53 | - #### Test your own sequence! 54 | If you want to try BARF on your own sequence, we provide a template data file in `data/iphone.py`, which is an example of reading from a sequence captured by an iPhone 12. 55 | You should modify `get_image()` to read each image sample and set the raw image sizes (`self.raw_H`, `self.raw_W`) and focal length (`self.focal`) according to your camera specs. 56 | You may ignore the camera poses, as they are assumed unknown in this case and are simply set to zero vectors. 57 | 58 | -------------------------------------- 59 | 60 | ### Running the code 61 | 62 | - #### BARF models 63 | To train and evaluate BARF: 64 | ```bash 65 | # <GROUP> and <NAME> can be set to your likes, while <SCENE> is specific to datasets 66 | 67 | # Blender (<SCENE>={chair,drums,ficus,hotdog,lego,materials,mic,ship}) 68 | python3 train.py --group=<GROUP> --model=barf --yaml=barf_blender --name=<NAME> --data.scene=<SCENE> --barf_c2f=[0.1,0.5] 69 | python3 evaluate.py --group=<GROUP> --model=barf --yaml=barf_blender --name=<NAME> --data.scene=<SCENE> --data.val_sub= --resume 70 | 71 | # LLFF (<SCENE>={fern,flower,fortress,horns,leaves,orchids,room,trex}) 72 | python3 train.py --group=<GROUP> --model=barf --yaml=barf_llff --name=<NAME> --data.scene=<SCENE> --barf_c2f=[0.1,0.5] 73 | python3 evaluate.py --group=<GROUP> --model=barf --yaml=barf_llff --name=<NAME> --data.scene=<SCENE> --resume 74 | ``` 75 | All the results will be stored in the directory `output/<GROUP>/<NAME>`. 76 | You may want to organize your experiments by grouping different runs in the same group. 77 | 78 | To train baseline models: 79 | - Full positional encoding: omit the `--barf_c2f` argument. 80 | - No positional encoding: add `--arch.posenc!`. 81 | 82 | If you want to evaluate a checkpoint at a specific iteration number, use `--resume=<ITER_NUMBER>` instead of just `--resume`.
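For intuition, `--barf_c2f=[0.1,0.5]` sets the coarse-to-fine schedule: the positional-encoding frequency bands are gradually activated between 10% and 50% of training. Below is a minimal standalone sketch of that weighting, mirroring `positional_encoding()` in `model/barf.py` (the function name and example values here are illustrative, not part of the repo):
```python
import numpy as np
import torch

def c2f_weights(progress, L, start=0.1, end=0.5):
    # progress: fraction of training completed, in [0,1] (tracked as self.progress in model/barf.py)
    # L: number of positional-encoding frequency bands
    alpha = (progress-start)/(end-start)*L
    k = torch.arange(L, dtype=torch.float32)
    # each band k ramps smoothly from 0 (masked) to 1 (fully active) as alpha sweeps past it
    return (1-(alpha-k).clamp(min=0, max=1).mul(np.pi).cos())/2

print(c2f_weights(progress=0.3, L=10))  # low-frequency bands are on, high-frequency bands still masked
```
Early in training BARF therefore only sees a smooth, low-frequency signal, which keeps the pose registration well behaved; the full encoding is restored by the end of the schedule.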
83 | 84 | - #### Training the original NeRF 85 | If you want to train the reference NeRF models (assuming known camera poses): 86 | ```bash 87 | # Blender 88 | python3 train.py --group=<GROUP> --model=nerf --yaml=nerf_blender --name=<NAME> --data.scene=<SCENE> 89 | python3 evaluate.py --group=<GROUP> --model=nerf --yaml=nerf_blender --name=<NAME> --data.scene=<SCENE> --data.val_sub= --resume 90 | 91 | # LLFF 92 | python3 train.py --group=<GROUP> --model=nerf --yaml=nerf_llff --name=<NAME> --data.scene=<SCENE> 93 | python3 evaluate.py --group=<GROUP> --model=nerf --yaml=nerf_llff --name=<NAME> --data.scene=<SCENE> --resume 94 | ``` 95 | If you wish to replicate the results from the original NeRF paper, use `--yaml=nerf_blender_repr` or `--yaml=nerf_llff_repr` instead for Blender or LLFF respectively. 96 | There are some differences, e.g. NDC will be used for the LLFF forward-facing dataset. 97 | (The reference NeRF models considered in the paper do not use NDC to parametrize the 3D points.) 98 | 99 | - #### Planar image alignment experiment 100 | If you want to try the planar image alignment experiment, run: 101 | ```bash 102 | python3 train.py --group=<GROUP> --model=planar --yaml=planar --name=<NAME> --seed=3 --barf_c2f=[0,0.4] 103 | ``` 104 | This will fit a neural image representation to a single image (defaulting to `data/cat.jpg`), which takes a couple of minutes to optimize on a modern GPU. 105 | The seed number is set to reproduce the pre-generated warp perturbations in the paper. 106 | For the baseline methods, modify the arguments similarly to the NeRF case above: 107 | - Full positional encoding: omit the `--barf_c2f` argument. 108 | - No positional encoding: add `--arch.posenc!`. 109 | 110 | A video `vis.mp4` will also be created to visualize the optimization process. 111 | 112 | - #### Visualizing the results 113 | We have included code to visualize the training progress with TensorBoard and Visdom. 114 | The TensorBoard events include the following: 115 | - **SCALARS**: the rendering losses and PSNR over the course of optimization. For BARF, the rotational/translational errors with respect to the given poses are also computed. 116 | - **IMAGES**: visualization of the RGB images and the RGB/depth rendering. 117 | 118 | We also provide visualization of 3D camera poses in Visdom. 119 | Run `visdom -port 9000` to start the Visdom server. 120 | The Visdom host server defaults to `localhost`; this can be overridden with `--visdom.server` (see `options/base.yaml` for details). 121 | If you want to disable Visdom visualization, add `--visdom!`. 122 | 123 | The `extract_mesh.py` script provides a simple way to extract the underlying 3D geometry using marching cubes. Run as follows: 124 | ```bash 125 | python3 extract_mesh.py --group=<GROUP> --model=barf --yaml=barf_blender --name=<NAME> --data.scene=<SCENE> --data.val_sub= --resume 126 | ``` 127 | This works for both BARF and the original NeRF (by modifying the command line accordingly). It is currently supported only for the Blender dataset. 128 | 129 | -------------------------------------- 130 | ### Codebase structure 131 | 132 | The main engine and network architecture in `model/barf.py` inherit those from `model/nerf.py`. 133 | This codebase is structured so that it is easy to understand the actual parts BARF is extending from NeRF. 134 | It is also simple to build your own exciting applications upon either BARF or NeRF -- just inherit them again! 135 | This is the same for dataset files (e.g. `data/blender.py`).
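As a concrete (hypothetical) illustration of this inheritance pattern, a custom experiment could live in a new module such as `model/my_barf.py` -- the file and class names below are made up for the example, but the `Model`/`Graph` split mirrors `model/barf.py`:
```python
# model/my_barf.py -- hypothetical example module, selectable with --model=my_barf
from . import barf

class Model(barf.Model):
    # reuse the entire BARF training/evaluation engine;
    # override only what your application needs (e.g. visualization, evaluation)
    def __init__(self, opt):
        super().__init__(opt)

class Graph(barf.Graph):
    # the forward/backprop computation graph; extend compute_loss() to add objectives
    def compute_loss(self, opt, var, mode=None):
        loss = super().compute_loss(opt, var, mode=mode)
        # e.g. loss.my_reg = ...   (with a matching opt.loss_weight.my_reg entry in the config)
        return loss
```
A corresponding yaml under `options/` would then configure it, just as `options/barf_blender.yaml` does for BARF.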
136 | 137 | To understand the config and command lines, take the command below as an example: 138 | ```bash 139 | python3 train.py --group=<GROUP> --model=barf --yaml=barf_blender --name=<NAME> --data.scene=<SCENE> --barf_c2f=[0.1,0.5] 140 | ``` 141 | This will run `model/barf.py` as the main engine with `options/barf_blender.yaml` as the main config file. 142 | Note that `barf` hierarchically inherits `nerf` (which inherits `base`), making the codebase customizable. 143 | The complete configuration will be printed upon execution. 144 | To override specific options, add `--<key>=value` or `--<key1>.<key2>=value` (and so on) to the command line. The configuration will be loaded as the variable `opt` throughout the codebase. 145 | 146 | Some tips on using and understanding the codebase: 147 | - The computation graph for forward/backprop is stored in `var` throughout the codebase. 148 | - The losses are stored in `loss`. To add a new loss function, just implement it in `compute_loss()` and add its weight to `opt.loss_weight.<name>`. It will automatically be added to the overall loss and logged to TensorBoard (see the sketch at the end of this README). 149 | - If you are using a multi-GPU machine, you can add `--gpu=<gpu_number>` to specify which GPU to use. Multi-GPU training/evaluation is currently not supported. 150 | - To resume from a previous checkpoint, add `--resume=<ITER_NUMBER>`, or just `--resume` to resume from the latest checkpoint. 151 | - (to be continued....) 152 | 153 | -------------------------------------- 154 | 155 | If you find our code useful for your research, please cite 156 | ``` 157 | @inproceedings{lin2021barf, 158 | title={BARF: Bundle-Adjusting Neural Radiance Fields}, 159 | author={Lin, Chen-Hsuan and Ma, Wei-Chiu and Torralba, Antonio and Lucey, Simon}, 160 | booktitle={IEEE International Conference on Computer Vision ({ICCV})}, 161 | year={2021} 162 | } 163 | ``` 164 | 165 | Please contact me (chlin@cmu.edu) if you have any questions!
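To make the loss tip above concrete: `summarize_loss()` in `model/base.py` combines the individual terms with log-scale weights, i.e. each term is multiplied by `10**opt.loss_weight.<name>`. A small sketch of that behavior (the `render`/`my_reg` names and the numbers are made up for illustration):
```python
import torch
from easydict import EasyDict as edict

# loss weights are exponents, typically set in the yaml config (e.g. loss_weight: {render: 0, my_reg: -2})
opt = edict(loss_weight=edict(render=0, my_reg=-2))

# what a compute_loss() implementation would return
loss = edict(render=torch.tensor(0.05), my_reg=torch.tensor(1.3))

# the combination performed by Model.summarize_loss() in model/base.py
loss_all = sum(10**float(opt.loss_weight[k])*loss[k] for k in loss)
print(loss_all)  # 1.0*0.05 + 0.01*1.3 = 0.063
```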
166 | -------------------------------------------------------------------------------- /camera.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os,sys,time 3 | import torch 4 | import torch.nn.functional as torch_F 5 | import collections 6 | from easydict import EasyDict as edict 7 | 8 | import util 9 | from util import log,debug 10 | 11 | class Pose(): 12 | """ 13 | A class of operations on camera poses (PyTorch tensors with shape [...,3,4]) 14 | each [3,4] camera pose takes the form of [R|t] 15 | """ 16 | 17 | def __call__(self,R=None,t=None): 18 | # construct a camera pose from the given R and/or t 19 | assert(R is not None or t is not None) 20 | if R is None: 21 | if not isinstance(t,torch.Tensor): t = torch.tensor(t) 22 | R = torch.eye(3,device=t.device).repeat(*t.shape[:-1],1,1) 23 | elif t is None: 24 | if not isinstance(R,torch.Tensor): R = torch.tensor(R) 25 | t = torch.zeros(R.shape[:-1],device=R.device) 26 | else: 27 | if not isinstance(R,torch.Tensor): R = torch.tensor(R) 28 | if not isinstance(t,torch.Tensor): t = torch.tensor(t) 29 | assert(R.shape[:-1]==t.shape and R.shape[-2:]==(3,3)) 30 | R = R.float() 31 | t = t.float() 32 | pose = torch.cat([R,t[...,None]],dim=-1) # [...,3,4] 33 | assert(pose.shape[-2:]==(3,4)) 34 | return pose 35 | 36 | def invert(self,pose,use_inverse=False): 37 | # invert a camera pose 38 | R,t = pose[...,:3],pose[...,3:] 39 | R_inv = R.inverse() if use_inverse else R.transpose(-1,-2) 40 | t_inv = (-R_inv@t)[...,0] 41 | pose_inv = self(R=R_inv,t=t_inv) 42 | return pose_inv 43 | 44 | def compose(self,pose_list): 45 | # compose a sequence of poses together 46 | # pose_new(x) = poseN o ... o pose2 o pose1(x) 47 | pose_new = pose_list[0] 48 | for pose in pose_list[1:]: 49 | pose_new = self.compose_pair(pose_new,pose) 50 | return pose_new 51 | 52 | def compose_pair(self,pose_a,pose_b): 53 | # pose_new(x) = pose_b o pose_a(x) 54 | R_a,t_a = pose_a[...,:3],pose_a[...,3:] 55 | R_b,t_b = pose_b[...,:3],pose_b[...,3:] 56 | R_new = R_b@R_a 57 | t_new = (R_b@t_a+t_b)[...,0] 58 | pose_new = self(R=R_new,t=t_new) 59 | return pose_new 60 | 61 | class Lie(): 62 | """ 63 | Lie algebra for SO(3) and SE(3) operations in PyTorch 64 | """ 65 | 66 | def so3_to_SO3(self,w): # [...,3] 67 | wx = self.skew_symmetric(w) 68 | theta = w.norm(dim=-1)[...,None,None] 69 | I = torch.eye(3,device=w.device,dtype=torch.float32) 70 | A = self.taylor_A(theta) 71 | B = self.taylor_B(theta) 72 | R = I+A*wx+B*wx@wx 73 | return R 74 | 75 | def SO3_to_so3(self,R,eps=1e-7): # [...,3,3] 76 | trace = R[...,0,0]+R[...,1,1]+R[...,2,2] 77 | theta = ((trace-1)/2).clamp(-1+eps,1-eps).acos_()[...,None,None]%np.pi # ln(R) will explode if theta==pi 78 | lnR = 1/(2*self.taylor_A(theta)+1e-8)*(R-R.transpose(-2,-1)) # FIXME: wei-chiu finds it weird 79 | w0,w1,w2 = lnR[...,2,1],lnR[...,0,2],lnR[...,1,0] 80 | w = torch.stack([w0,w1,w2],dim=-1) 81 | return w 82 | 83 | def se3_to_SE3(self,wu): # [...,3] 84 | w,u = wu.split([3,3],dim=-1) 85 | wx = self.skew_symmetric(w) 86 | theta = w.norm(dim=-1)[...,None,None] 87 | I = torch.eye(3,device=w.device,dtype=torch.float32) 88 | A = self.taylor_A(theta) 89 | B = self.taylor_B(theta) 90 | C = self.taylor_C(theta) 91 | R = I+A*wx+B*wx@wx 92 | V = I+B*wx+C*wx@wx 93 | Rt = torch.cat([R,(V@u[...,None])],dim=-1) 94 | return Rt 95 | 96 | def SE3_to_se3(self,Rt,eps=1e-8): # [...,3,4] 97 | R,t = Rt.split([3,1],dim=-1) 98 | w = self.SO3_to_so3(R) 99 | wx = self.skew_symmetric(w) 100 | theta = 
w.norm(dim=-1)[...,None,None] 101 | I = torch.eye(3,device=w.device,dtype=torch.float32) 102 | A = self.taylor_A(theta) 103 | B = self.taylor_B(theta) 104 | invV = I-0.5*wx+(1-A/(2*B))/(theta**2+eps)*wx@wx 105 | u = (invV@t)[...,0] 106 | wu = torch.cat([w,u],dim=-1) 107 | return wu 108 | 109 | def skew_symmetric(self,w): 110 | w0,w1,w2 = w.unbind(dim=-1) 111 | O = torch.zeros_like(w0) 112 | wx = torch.stack([torch.stack([O,-w2,w1],dim=-1), 113 | torch.stack([w2,O,-w0],dim=-1), 114 | torch.stack([-w1,w0,O],dim=-1)],dim=-2) 115 | return wx 116 | 117 | def taylor_A(self,x,nth=10): 118 | # Taylor expansion of sin(x)/x 119 | ans = torch.zeros_like(x) 120 | denom = 1. 121 | for i in range(nth+1): 122 | if i>0: denom *= (2*i)*(2*i+1) 123 | ans = ans+(-1)**i*x**(2*i)/denom 124 | return ans 125 | def taylor_B(self,x,nth=10): 126 | # Taylor expansion of (1-cos(x))/x**2 127 | ans = torch.zeros_like(x) 128 | denom = 1. 129 | for i in range(nth+1): 130 | denom *= (2*i+1)*(2*i+2) 131 | ans = ans+(-1)**i*x**(2*i)/denom 132 | return ans 133 | def taylor_C(self,x,nth=10): 134 | # Taylor expansion of (x-sin(x))/x**3 135 | ans = torch.zeros_like(x) 136 | denom = 1. 137 | for i in range(nth+1): 138 | denom *= (2*i+2)*(2*i+3) 139 | ans = ans+(-1)**i*x**(2*i)/denom 140 | return ans 141 | 142 | class Quaternion(): 143 | 144 | def q_to_R(self,q): 145 | # https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion 146 | qa,qb,qc,qd = q.unbind(dim=-1) 147 | R = torch.stack([torch.stack([1-2*(qc**2+qd**2),2*(qb*qc-qa*qd),2*(qa*qc+qb*qd)],dim=-1), 148 | torch.stack([2*(qb*qc+qa*qd),1-2*(qb**2+qd**2),2*(qc*qd-qa*qb)],dim=-1), 149 | torch.stack([2*(qb*qd-qa*qc),2*(qa*qb+qc*qd),1-2*(qb**2+qc**2)],dim=-1)],dim=-2) 150 | return R 151 | 152 | def R_to_q(self,R,eps=1e-8): # [B,3,3] 153 | # https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion 154 | # FIXME: this function seems a bit problematic, need to double-check 155 | row0,row1,row2 = R.unbind(dim=-2) 156 | R00,R01,R02 = row0.unbind(dim=-1) 157 | R10,R11,R12 = row1.unbind(dim=-1) 158 | R20,R21,R22 = row2.unbind(dim=-1) 159 | t = R[...,0,0]+R[...,1,1]+R[...,2,2] 160 | r = (1+t+eps).sqrt() 161 | qa = 0.5*r 162 | qb = (R21-R12).sign()*0.5*(1+R00-R11-R22+eps).sqrt() 163 | qc = (R02-R20).sign()*0.5*(1-R00+R11-R22+eps).sqrt() 164 | qd = (R10-R01).sign()*0.5*(1-R00-R11+R22+eps).sqrt() 165 | q = torch.stack([qa,qb,qc,qd],dim=-1) 166 | for i,qi in enumerate(q): 167 | if torch.isnan(qi).any(): 168 | K = torch.stack([torch.stack([R00-R11-R22,R10+R01,R20+R02,R12-R21],dim=-1), 169 | torch.stack([R10+R01,R11-R00-R22,R21+R12,R20-R02],dim=-1), 170 | torch.stack([R20+R02,R21+R12,R22-R00-R11,R01-R10],dim=-1), 171 | torch.stack([R12-R21,R20-R02,R01-R10,R00+R11+R22],dim=-1)],dim=-2)/3.0 172 | K = K[i] 173 | eigval,eigvec = torch.linalg.eigh(K) 174 | V = eigvec[:,eigval.argmax()] 175 | q[i] = torch.stack([V[3],V[0],V[1],V[2]]) 176 | return q 177 | 178 | def invert(self,q): 179 | qa,qb,qc,qd = q.unbind(dim=-1) 180 | norm = q.norm(dim=-1,keepdim=True) 181 | q_inv = torch.stack([qa,-qb,-qc,-qd],dim=-1)/norm**2 182 | return q_inv 183 | 184 | def product(self,q1,q2): # [B,4] 185 | q1a,q1b,q1c,q1d = q1.unbind(dim=-1) 186 | q2a,q2b,q2c,q2d = q2.unbind(dim=-1) 187 | hamil_prod = torch.stack([q1a*q2a-q1b*q2b-q1c*q2c-q1d*q2d, 188 | q1a*q2b+q1b*q2a+q1c*q2d-q1d*q2c, 189 | q1a*q2c-q1b*q2d+q1c*q2a+q1d*q2b, 190 | q1a*q2d+q1b*q2c-q1c*q2b+q1d*q2a],dim=-1) 191 | return hamil_prod 192 | 193 | pose = Pose() 194 | lie = Lie() 195 | quaternion = Quaternion() 196 | 197 | def to_hom(X): 198 | # get homogeneous 
coordinates of the input 199 | X_hom = torch.cat([X,torch.ones_like(X[...,:1])],dim=-1) 200 | return X_hom 201 | 202 | # basic operations of transforming 3D points between world/camera/image coordinates 203 | def world2cam(X,pose): # [B,N,3] 204 | X_hom = to_hom(X) 205 | return X_hom@pose.transpose(-1,-2) 206 | def cam2img(X,cam_intr): 207 | return X@cam_intr.transpose(-1,-2) 208 | def img2cam(X,cam_intr): 209 | return X@cam_intr.inverse().transpose(-1,-2) 210 | def cam2world(X,pose): 211 | X_hom = to_hom(X) 212 | pose_inv = Pose().invert(pose) 213 | return X_hom@pose_inv.transpose(-1,-2) 214 | 215 | def angle_to_rotation_matrix(a,axis): 216 | # get the rotation matrix from Euler angle around specific axis 217 | roll = dict(X=1,Y=2,Z=0)[axis] 218 | O = torch.zeros_like(a) 219 | I = torch.ones_like(a) 220 | M = torch.stack([torch.stack([a.cos(),-a.sin(),O],dim=-1), 221 | torch.stack([a.sin(),a.cos(),O],dim=-1), 222 | torch.stack([O,O,I],dim=-1)],dim=-2) 223 | M = M.roll((roll,roll),dims=(-2,-1)) 224 | return M 225 | 226 | def get_center_and_ray(opt,pose,intr=None): # [HW,2] 227 | # given the intrinsic/extrinsic matrices, get the camera center and ray directions] 228 | assert(opt.camera.model=="perspective") 229 | with torch.no_grad(): 230 | # compute image coordinate grid 231 | y_range = torch.arange(opt.H,dtype=torch.float32,device=opt.device).add_(0.5) 232 | x_range = torch.arange(opt.W,dtype=torch.float32,device=opt.device).add_(0.5) 233 | Y,X = torch.meshgrid(y_range,x_range) # [H,W] 234 | xy_grid = torch.stack([X,Y],dim=-1).view(-1,2) # [HW,2] 235 | # compute center and ray 236 | batch_size = len(pose) 237 | xy_grid = xy_grid.repeat(batch_size,1,1) # [B,HW,2] 238 | grid_3D = img2cam(to_hom(xy_grid),intr) # [B,HW,3] 239 | center_3D = torch.zeros_like(grid_3D) # [B,HW,3] 240 | # transform from camera to world coordinates 241 | grid_3D = cam2world(grid_3D,pose) # [B,HW,3] 242 | center_3D = cam2world(center_3D,pose) # [B,HW,3] 243 | ray = grid_3D-center_3D # [B,HW,3] 244 | return center_3D,ray 245 | 246 | def get_3D_points_from_depth(opt,center,ray,depth,multi_samples=False): 247 | if multi_samples: center,ray = center[:,:,None],ray[:,:,None] 248 | # x = c+dv 249 | points_3D = center+ray*depth # [B,HW,3]/[B,HW,N,3]/[N,3] 250 | return points_3D 251 | 252 | def convert_NDC(opt,center,ray,intr,near=1): 253 | # shift camera center (ray origins) to near plane (z=1) 254 | # (unlike conventional NDC, we assume the cameras are facing towards the +z direction) 255 | center = center+(near-center[...,2:])/ray[...,2:]*ray 256 | # projection 257 | cx,cy,cz = center.unbind(dim=-1) # [B,HW] 258 | rx,ry,rz = ray.unbind(dim=-1) # [B,HW] 259 | scale_x = intr[:,0,0]/intr[:,0,2] # [B] 260 | scale_y = intr[:,1,1]/intr[:,1,2] # [B] 261 | cnx = scale_x[:,None]*(cx/cz) 262 | cny = scale_y[:,None]*(cy/cz) 263 | cnz = 1-2*near/cz 264 | rnx = scale_x[:,None]*(rx/rz-cx/cz) 265 | rny = scale_y[:,None]*(ry/rz-cy/cz) 266 | rnz = 2*near/cz 267 | center_ndc = torch.stack([cnx,cny,cnz],dim=-1) # [B,HW,3] 268 | ray_ndc = torch.stack([rnx,rny,rnz],dim=-1) # [B,HW,3] 269 | return center_ndc,ray_ndc 270 | 271 | def rotation_distance(R1,R2,eps=1e-7): 272 | # http://www.boris-belousov.net/2016/12/01/quat-dist/ 273 | R_diff = R1@R2.transpose(-2,-1) 274 | trace = R_diff[...,0,0]+R_diff[...,1,1]+R_diff[...,2,2] 275 | angle = ((trace-1)/2).clamp(-1+eps,1-eps).acos_() # numerical stability near -1/+1 276 | return angle 277 | 278 | def procrustes_analysis(X0,X1): # [N,3] 279 | # translation 280 | t0 = X0.mean(dim=0,keepdim=True) 281 | 
t1 = X1.mean(dim=0,keepdim=True) 282 | X0c = X0-t0 283 | X1c = X1-t1 284 | # scale 285 | s0 = (X0c**2).sum(dim=-1).mean().sqrt() 286 | s1 = (X1c**2).sum(dim=-1).mean().sqrt() 287 | X0cs = X0c/s0 288 | X1cs = X1c/s1 289 | # rotation (use double for SVD, float loses precision) 290 | U,S,V = (X0cs.t()@X1cs).double().svd(some=True) 291 | R = (U@V.t()).float() 292 | if R.det()<0: R[2] *= -1 293 | # align X1 to X0: X1to0 = (X1-t1)/s1@R.t()*s0+t0 294 | sim3 = edict(t0=t0[0],t1=t1[0],s0=s0,s1=s1,R=R) 295 | return sim3 296 | 297 | def get_novel_view_poses(opt,pose_anchor,N=60,scale=1): 298 | # create circular viewpoints (small oscillations) 299 | theta = torch.arange(N)/N*2*np.pi 300 | R_x = angle_to_rotation_matrix((theta.sin()*0.05).asin(),"X") 301 | R_y = angle_to_rotation_matrix((theta.cos()*0.05).asin(),"Y") 302 | pose_rot = pose(R=R_y@R_x) 303 | pose_shift = pose(t=[0,0,-4*scale]) 304 | pose_shift2 = pose(t=[0,0,3.8*scale]) 305 | pose_oscil = pose.compose([pose_shift,pose_rot,pose_shift2]) 306 | pose_novel = pose.compose([pose_oscil,pose_anchor.cpu()[None]]) 307 | return pose_novel 308 | -------------------------------------------------------------------------------- /data/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os,sys,time 3 | import torch 4 | import torch.nn.functional as torch_F 5 | import torchvision 6 | import torchvision.transforms.functional as torchvision_F 7 | import torch.multiprocessing as mp 8 | import PIL 9 | import tqdm 10 | import threading,queue 11 | from easydict import EasyDict as edict 12 | 13 | import util 14 | from util import log,debug 15 | 16 | class Dataset(torch.utils.data.Dataset): 17 | 18 | def __init__(self,opt,split="train"): 19 | super().__init__() 20 | self.opt = opt 21 | self.split = split 22 | self.augment = split=="train" and opt.data.augment 23 | # define image sizes 24 | if opt.data.center_crop is not None: 25 | self.crop_H = int(self.raw_H*opt.data.center_crop) 26 | self.crop_W = int(self.raw_W*opt.data.center_crop) 27 | else: self.crop_H,self.crop_W = self.raw_H,self.raw_W 28 | if not opt.H or not opt.W: 29 | opt.H,opt.W = self.crop_H,self.crop_W 30 | 31 | def setup_loader(self,opt,shuffle=False,drop_last=False): 32 | loader = torch.utils.data.DataLoader(self, 33 | batch_size=opt.batch_size or 1, 34 | num_workers=opt.data.num_workers, 35 | shuffle=shuffle, 36 | drop_last=drop_last, 37 | pin_memory=False, # spews warnings in PyTorch 1.9 but should be True in general 38 | ) 39 | print("number of samples: {}".format(len(self))) 40 | return loader 41 | 42 | def get_list(self,opt): 43 | raise NotImplementedError 44 | 45 | def preload_worker(self,data_list,load_func,q,lock,idx_tqdm): 46 | while True: 47 | idx = q.get() 48 | data_list[idx] = load_func(self.opt,idx) 49 | with lock: 50 | idx_tqdm.update() 51 | q.task_done() 52 | 53 | def preload_threading(self,opt,load_func,data_str="images"): 54 | data_list = [None]*len(self) 55 | q = queue.Queue(maxsize=len(self)) 56 | idx_tqdm = tqdm.tqdm(range(len(self)),desc="preloading {}".format(data_str),leave=False) 57 | for i in range(len(self)): q.put(i) 58 | lock = threading.Lock() 59 | for ti in range(opt.data.num_workers): 60 | t = threading.Thread(target=self.preload_worker, 61 | args=(data_list,load_func,q,lock,idx_tqdm),daemon=True) 62 | t.start() 63 | q.join() 64 | idx_tqdm.close() 65 | assert(all(map(lambda x: x is not None,data_list))) 66 | return data_list 67 | 68 | def __getitem__(self,idx): 69 | raise NotImplementedError 70 | 
71 | def get_image(self,opt,idx): 72 | raise NotImplementedError 73 | 74 | def generate_augmentation(self,opt): 75 | brightness = opt.data.augment.brightness or 0. 76 | contrast = opt.data.augment.contrast or 0. 77 | saturation = opt.data.augment.saturation or 0. 78 | hue = opt.data.augment.hue or 0. 79 | color_jitter = torchvision.transforms.ColorJitter.get_params( 80 | brightness=(1-brightness,1+brightness), 81 | contrast=(1-contrast,1+contrast), 82 | saturation=(1-saturation,1+saturation), 83 | hue=(-hue,hue), 84 | ) 85 | aug = edict( 86 | color_jitter=color_jitter, 87 | flip=np.random.randn()>0 if opt.data.augment.hflip else False, 88 | rot_angle=(np.random.rand()*2-1)*opt.data.augment.rotate if opt.data.augment.rotate else 0, 89 | ) 90 | return aug 91 | 92 | def preprocess_image(self,opt,image,aug=None): 93 | if aug is not None: 94 | image = self.apply_color_jitter(opt,image,aug.color_jitter) 95 | image = torchvision_F.hflip(image) if aug.flip else image 96 | image = image.rotate(aug.rot_angle,resample=PIL.Image.BICUBIC) 97 | # center crop 98 | if opt.data.center_crop is not None: 99 | self.crop_H = int(self.raw_H*opt.data.center_crop) 100 | self.crop_W = int(self.raw_W*opt.data.center_crop) 101 | image = torchvision_F.center_crop(image,(self.crop_H,self.crop_W)) 102 | else: self.crop_H,self.crop_W = self.raw_H,self.raw_W 103 | # resize 104 | if opt.data.image_size[0] is not None: 105 | image = image.resize((opt.W,opt.H)) 106 | image = torchvision_F.to_tensor(image) 107 | return image 108 | 109 | def preprocess_camera(self,opt,intr,pose,aug=None): 110 | intr,pose = intr.clone(),pose.clone() 111 | # center crop 112 | intr[0,2] -= (self.raw_W-self.crop_W)/2 113 | intr[1,2] -= (self.raw_H-self.crop_H)/2 114 | # resize 115 | intr[0] *= opt.W/self.crop_W 116 | intr[1] *= opt.H/self.crop_H 117 | return intr,pose 118 | 119 | def apply_color_jitter(self,opt,image,color_jitter): 120 | mode = image.mode 121 | if mode!="L": 122 | chan = image.split() 123 | rgb = PIL.Image.merge("RGB",chan[:3]) 124 | rgb = color_jitter(rgb) 125 | rgb_chan = rgb.split() 126 | image = PIL.Image.merge(mode,rgb_chan+chan[3:]) 127 | return image 128 | 129 | def __len__(self): 130 | return len(self.list) 131 | -------------------------------------------------------------------------------- /data/blender.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os,sys,time 3 | import torch 4 | import torch.nn.functional as torch_F 5 | import torchvision 6 | import torchvision.transforms.functional as torchvision_F 7 | import PIL 8 | import imageio 9 | from easydict import EasyDict as edict 10 | import json 11 | import pickle 12 | 13 | from . 
import base 14 | import camera 15 | from util import log,debug 16 | 17 | class Dataset(base.Dataset): 18 | 19 | def __init__(self,opt,split="train",subset=None): 20 | self.raw_H,self.raw_W = 800,800 21 | super().__init__(opt,split) 22 | self.root = opt.data.root or "data/blender" 23 | self.path = "{}/{}".format(self.root,opt.data.scene) 24 | # load/parse metadata 25 | meta_fname = "{}/transforms_{}.json".format(self.path,split) 26 | with open(meta_fname) as file: 27 | self.meta = json.load(file) 28 | self.list = self.meta["frames"] 29 | self.focal = 0.5*self.raw_W/np.tan(0.5*self.meta["camera_angle_x"]) 30 | if subset: self.list = self.list[:subset] 31 | # preload dataset 32 | if opt.data.preload: 33 | self.images = self.preload_threading(opt,self.get_image) 34 | self.cameras = self.preload_threading(opt,self.get_camera,data_str="cameras") 35 | 36 | def prefetch_all_data(self,opt): 37 | assert(not opt.data.augment) 38 | # pre-iterate through all samples and group together 39 | self.all = torch.utils.data._utils.collate.default_collate([s for s in self]) 40 | 41 | def get_all_camera_poses(self,opt): 42 | pose_raw_all = [torch.tensor(f["transform_matrix"],dtype=torch.float32) for f in self.list] 43 | pose_canon_all = torch.stack([self.parse_raw_camera(opt,p) for p in pose_raw_all],dim=0) 44 | return pose_canon_all 45 | 46 | def __getitem__(self,idx): 47 | opt = self.opt 48 | sample = dict(idx=idx) 49 | aug = self.generate_augmentation(opt) if self.augment else None 50 | image = self.images[idx] if opt.data.preload else self.get_image(opt,idx) 51 | image = self.preprocess_image(opt,image,aug=aug) 52 | intr,pose = self.cameras[idx] if opt.data.preload else self.get_camera(opt,idx) 53 | intr,pose = self.preprocess_camera(opt,intr,pose,aug=aug) 54 | sample.update( 55 | image=image, 56 | intr=intr, 57 | pose=pose, 58 | ) 59 | return sample 60 | 61 | def get_image(self,opt,idx): 62 | image_fname = "{}/{}.png".format(self.path,self.list[idx]["file_path"]) 63 | image = PIL.Image.fromarray(imageio.imread(image_fname)) # directly using PIL.Image.open() leads to weird corruption.... 
64 | return image 65 | 66 | def preprocess_image(self,opt,image,aug=None): 67 | image = super().preprocess_image(opt,image,aug=aug) 68 | rgb,mask = image[:3],image[3:] 69 | if opt.data.bgcolor is not None: 70 | rgb = rgb*mask+opt.data.bgcolor*(1-mask) 71 | return rgb 72 | 73 | def get_camera(self,opt,idx): 74 | intr = torch.tensor([[self.focal,0,self.raw_W/2], 75 | [0,self.focal,self.raw_H/2], 76 | [0,0,1]]).float() 77 | pose_raw = torch.tensor(self.list[idx]["transform_matrix"],dtype=torch.float32) 78 | pose = self.parse_raw_camera(opt,pose_raw) 79 | return intr,pose 80 | 81 | def parse_raw_camera(self,opt,pose_raw): 82 | pose_flip = camera.pose(R=torch.diag(torch.tensor([1,-1,-1]))) 83 | pose = camera.pose.compose([pose_flip,pose_raw[:3]]) 84 | pose = camera.pose.invert(pose) 85 | return pose 86 | -------------------------------------------------------------------------------- /data/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenhsuanlin/bundle-adjusting-NeRF/803291bd0ee91c7c13fb5cc42195383c5ade7d15/data/cat.jpg -------------------------------------------------------------------------------- /data/iphone.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os,sys,time 3 | import torch 4 | import torch.nn.functional as torch_F 5 | import torchvision 6 | import torchvision.transforms.functional as torchvision_F 7 | import PIL 8 | import imageio 9 | from easydict import EasyDict as edict 10 | import json 11 | import pickle 12 | 13 | from . import base 14 | import camera 15 | from util import log,debug 16 | 17 | class Dataset(base.Dataset): 18 | 19 | def __init__(self,opt,split="train",subset=None): 20 | self.raw_H,self.raw_W = 1080,1920 21 | super().__init__(opt,split) 22 | self.root = opt.data.root or "data/iphone" 23 | self.path = "{}/{}".format(self.root,opt.data.scene) 24 | self.path_image = "{}/images".format(self.path) 25 | self.list = sorted(os.listdir(self.path_image),key=lambda f: int(f.split(".")[0])) 26 | # manually split train/val subsets 27 | num_val_split = int(len(self)*opt.data.val_ratio) 28 | self.list = self.list[:-num_val_split] if split=="train" else self.list[-num_val_split:] 29 | if subset: self.list = self.list[:subset] 30 | # preload dataset 31 | if opt.data.preload: 32 | self.images = self.preload_threading(opt,self.get_image) 33 | self.cameras = self.preload_threading(opt,self.get_camera,data_str="cameras") 34 | 35 | def prefetch_all_data(self,opt): 36 | assert(not opt.data.augment) 37 | # pre-iterate through all samples and group together 38 | self.all = torch.utils.data._utils.collate.default_collate([s for s in self]) 39 | 40 | def get_all_camera_poses(self,opt): 41 | # poses are unknown, so just return some dummy poses (identity transform) 42 | return camera.pose(t=torch.zeros(len(self),3)) 43 | 44 | def __getitem__(self,idx): 45 | opt = self.opt 46 | sample = dict(idx=idx) 47 | aug = self.generate_augmentation(opt) if self.augment else None 48 | image = self.images[idx] if opt.data.preload else self.get_image(opt,idx) 49 | image = self.preprocess_image(opt,image,aug=aug) 50 | intr,pose = self.cameras[idx] if opt.data.preload else self.get_camera(opt,idx) 51 | intr,pose = self.preprocess_camera(opt,intr,pose,aug=aug) 52 | sample.update( 53 | image=image, 54 | intr=intr, 55 | pose=pose, 56 | ) 57 | return sample 58 | 59 | def get_image(self,opt,idx): 60 | image_fname = "{}/{}".format(self.path_image,self.list[idx]) 61 
| image = PIL.Image.fromarray(imageio.imread(image_fname)) # directly using PIL.Image.open() leads to weird corruption.... 62 | return image 63 | 64 | def get_camera(self,opt,idx): 65 | self.focal = self.raw_W*4.2/(12.8/2.55) 66 | intr = torch.tensor([[self.focal,0,self.raw_W/2], 67 | [0,self.focal,self.raw_H/2], 68 | [0,0,1]]).float() 69 | pose = camera.pose(t=torch.zeros(3)) # dummy pose, won't be used 70 | return intr,pose 71 | -------------------------------------------------------------------------------- /data/llff.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os,sys,time 3 | import torch 4 | import torch.nn.functional as torch_F 5 | import torchvision 6 | import torchvision.transforms.functional as torchvision_F 7 | import PIL 8 | import imageio 9 | from easydict import EasyDict as edict 10 | import json 11 | import pickle 12 | 13 | from . import base 14 | import camera 15 | from util import log,debug 16 | 17 | class Dataset(base.Dataset): 18 | 19 | def __init__(self,opt,split="train",subset=None): 20 | self.raw_H,self.raw_W = 3024,4032 21 | super().__init__(opt,split) 22 | self.root = opt.data.root or "data/llff" 23 | self.path = "{}/{}".format(self.root,opt.data.scene) 24 | self.path_image = "{}/images".format(self.path) 25 | image_fnames = sorted(os.listdir(self.path_image)) 26 | poses_raw,bounds = self.parse_cameras_and_bounds(opt) 27 | self.list = list(zip(image_fnames,poses_raw,bounds)) 28 | # manually split train/val subsets 29 | num_val_split = int(len(self)*opt.data.val_ratio) 30 | self.list = self.list[:-num_val_split] if split=="train" else self.list[-num_val_split:] 31 | if subset: self.list = self.list[:subset] 32 | # preload dataset 33 | if opt.data.preload: 34 | self.images = self.preload_threading(opt,self.get_image) 35 | self.cameras = self.preload_threading(opt,self.get_camera,data_str="cameras") 36 | 37 | def prefetch_all_data(self,opt): 38 | assert(not opt.data.augment) 39 | # pre-iterate through all samples and group together 40 | self.all = torch.utils.data._utils.collate.default_collate([s for s in self]) 41 | 42 | def parse_cameras_and_bounds(self,opt): 43 | fname = "{}/poses_bounds.npy".format(self.path) 44 | data = torch.tensor(np.load(fname),dtype=torch.float32) 45 | # parse cameras (intrinsics and poses) 46 | cam_data = data[:,:-2].view([-1,3,5]) # [N,3,5] 47 | poses_raw = cam_data[...,:4] # [N,3,4] 48 | poses_raw[...,0],poses_raw[...,1] = poses_raw[...,1],-poses_raw[...,0] 49 | raw_H,raw_W,self.focal = cam_data[0,:,-1] 50 | assert(self.raw_H==raw_H and self.raw_W==raw_W) 51 | # parse depth bounds 52 | bounds = data[:,-2:] # [N,2] 53 | scale = 1./(bounds.min()*0.75) # not sure how this was determined 54 | poses_raw[...,3] *= scale 55 | bounds *= scale 56 | # roughly center camera poses 57 | poses_raw = self.center_camera_poses(opt,poses_raw) 58 | return poses_raw,bounds 59 | 60 | def center_camera_poses(self,opt,poses): 61 | # compute average pose 62 | center = poses[...,3].mean(dim=0) 63 | v1 = torch_F.normalize(poses[...,1].mean(dim=0),dim=0) 64 | v2 = torch_F.normalize(poses[...,2].mean(dim=0),dim=0) 65 | v0 = v1.cross(v2) 66 | pose_avg = torch.stack([v0,v1,v2,center],dim=-1)[None] # [1,3,4] 67 | # apply inverse of averaged pose 68 | poses = camera.pose.compose([poses,camera.pose.invert(pose_avg)]) 69 | return poses 70 | 71 | def get_all_camera_poses(self,opt): 72 | pose_raw_all = [tup[1] for tup in self.list] 73 | pose_all = torch.stack([self.parse_raw_camera(opt,p) for p in 
pose_raw_all],dim=0) 74 | return pose_all 75 | 76 | def __getitem__(self,idx): 77 | opt = self.opt 78 | sample = dict(idx=idx) 79 | aug = self.generate_augmentation(opt) if self.augment else None 80 | image = self.images[idx] if opt.data.preload else self.get_image(opt,idx) 81 | image = self.preprocess_image(opt,image,aug=aug) 82 | intr,pose = self.cameras[idx] if opt.data.preload else self.get_camera(opt,idx) 83 | intr,pose = self.preprocess_camera(opt,intr,pose,aug=aug) 84 | sample.update( 85 | image=image, 86 | intr=intr, 87 | pose=pose, 88 | ) 89 | return sample 90 | 91 | def get_image(self,opt,idx): 92 | image_fname = "{}/{}".format(self.path_image,self.list[idx][0]) 93 | image = PIL.Image.fromarray(imageio.imread(image_fname)) # directly using PIL.Image.open() leads to weird corruption.... 94 | return image 95 | 96 | def get_camera(self,opt,idx): 97 | intr = torch.tensor([[self.focal,0,self.raw_W/2], 98 | [0,self.focal,self.raw_H/2], 99 | [0,0,1]]).float() 100 | pose_raw = self.list[idx][1] 101 | pose = self.parse_raw_camera(opt,pose_raw) 102 | return intr,pose 103 | 104 | def parse_raw_camera(self,opt,pose_raw): 105 | pose_flip = camera.pose(R=torch.diag(torch.tensor([1,-1,-1]))) 106 | pose = camera.pose.compose([pose_flip,pose_raw[:3]]) 107 | pose = camera.pose.invert(pose) 108 | pose = camera.pose.compose([pose_flip,pose]) 109 | return pose 110 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os,sys,time 3 | import torch 4 | import importlib 5 | 6 | import options 7 | from util import log 8 | 9 | def main(): 10 | 11 | log.process(os.getpid()) 12 | log.title("[{}] (PyTorch code for evaluating NeRF/BARF)".format(sys.argv[0])) 13 | 14 | opt_cmd = options.parse_arguments(sys.argv[1:]) 15 | opt = options.set(opt_cmd=opt_cmd) 16 | 17 | with torch.cuda.device(opt.device): 18 | 19 | model = importlib.import_module("model.{}".format(opt.model)) 20 | m = model.Model(opt) 21 | 22 | m.load_dataset(opt,eval_split="test") 23 | m.build_networks(opt) 24 | 25 | if opt.model=="barf": 26 | m.generate_videos_pose(opt) 27 | 28 | m.restore_checkpoint(opt) 29 | if opt.data.dataset in ["blender","llff"]: 30 | m.evaluate_full(opt) 31 | m.generate_videos_synthesis(opt) 32 | 33 | if __name__=="__main__": 34 | main() 35 | -------------------------------------------------------------------------------- /extract_mesh.py: -------------------------------------------------------------------------------- 1 | """Extracts a 3D mesh from a pretrained model using marching cubes.""" 2 | 3 | import importlib 4 | import sys 5 | 6 | import numpy as np 7 | import options 8 | import torch 9 | import tqdm 10 | import trimesh 11 | import mcubes 12 | 13 | from util import log,debug 14 | 15 | opt_cmd = options.parse_arguments(sys.argv[1:]) 16 | opt = options.set(opt_cmd=opt_cmd) 17 | 18 | with torch.cuda.device(opt.device),torch.no_grad(): 19 | 20 | model = importlib.import_module("model.{}".format(opt.model)) 21 | m = model.Model(opt) 22 | 23 | m.build_networks(opt) 24 | m.restore_checkpoint(opt) 25 | 26 | t = torch.linspace(*opt.trimesh.range,opt.trimesh.res+1) # the best range might vary from model to model 27 | query = torch.stack(torch.meshgrid(t,t,t),dim=-1) 28 | query_flat = query.view(-1,3) 29 | 30 | density_all = [] 31 | for i in tqdm.trange(0,len(query_flat),opt.trimesh.chunk_size,leave=False): 32 | points = 
query_flat[None,i:i+opt.trimesh.chunk_size].to(opt.device) 33 | ray_unit = torch.zeros_like(points) # dummy ray to comply with interface, not used 34 | _,density_samples = m.graph.nerf.forward(opt,points,ray_unit=ray_unit,mode=None) 35 | density_all.append(density_samples.cpu()) 36 | density_all = torch.cat(density_all,dim=1)[0] 37 | density_all = density_all.view(*query.shape[:-1]).numpy() 38 | 39 | log.info("running marching cubes...") 40 | vertices,triangles = mcubes.marching_cubes(density_all,opt.trimesh.thres) 41 | vertices_centered = vertices/opt.trimesh.res-0.5 42 | mesh = trimesh.Trimesh(vertices_centered,triangles) 43 | 44 | obj_fname = "{}/mesh.obj".format(opt.output_path) 45 | log.info("saving 3D mesh to {}...".format(obj_fname)) 46 | mesh.export(obj_fname) 47 | -------------------------------------------------------------------------------- /model/barf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os,sys,time 3 | import torch 4 | import torch.nn.functional as torch_F 5 | import torchvision 6 | import torchvision.transforms.functional as torchvision_F 7 | import tqdm 8 | from easydict import EasyDict as edict 9 | import visdom 10 | import matplotlib.pyplot as plt 11 | 12 | import util,util_vis 13 | from util import log,debug 14 | from . import nerf 15 | import camera 16 | 17 | # ============================ main engine for training and evaluation ============================ 18 | 19 | class Model(nerf.Model): 20 | 21 | def __init__(self,opt): 22 | super().__init__(opt) 23 | 24 | def build_networks(self,opt): 25 | super().build_networks(opt) 26 | if opt.camera.noise: 27 | # pre-generate synthetic pose perturbation 28 | se3_noise = torch.randn(len(self.train_data),6,device=opt.device)*opt.camera.noise 29 | self.graph.pose_noise = camera.lie.se3_to_SE3(se3_noise) 30 | self.graph.se3_refine = torch.nn.Embedding(len(self.train_data),6).to(opt.device) 31 | torch.nn.init.zeros_(self.graph.se3_refine.weight) 32 | 33 | def setup_optimizer(self,opt): 34 | super().setup_optimizer(opt) 35 | optimizer = getattr(torch.optim,opt.optim.algo) 36 | self.optim_pose = optimizer([dict(params=self.graph.se3_refine.parameters(),lr=opt.optim.lr_pose)]) 37 | # set up scheduler 38 | if opt.optim.sched_pose: 39 | scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched_pose.type) 40 | if opt.optim.lr_pose_end: 41 | assert(opt.optim.sched_pose.type=="ExponentialLR") 42 | opt.optim.sched_pose.gamma = (opt.optim.lr_pose_end/opt.optim.lr_pose)**(1./opt.max_iter) 43 | kwargs = { k:v for k,v in opt.optim.sched_pose.items() if k!="type" } 44 | self.sched_pose = scheduler(self.optim_pose,**kwargs) 45 | 46 | def train_iteration(self,opt,var,loader): 47 | self.optim_pose.zero_grad() 48 | if opt.optim.warmup_pose: 49 | # simple linear warmup of pose learning rate 50 | self.optim_pose.param_groups[0]["lr_orig"] = self.optim_pose.param_groups[0]["lr"] # cache the original learning rate 51 | self.optim_pose.param_groups[0]["lr"] *= min(1,self.it/opt.optim.warmup_pose) 52 | loss = super().train_iteration(opt,var,loader) 53 | self.optim_pose.step() 54 | if opt.optim.warmup_pose: 55 | self.optim_pose.param_groups[0]["lr"] = self.optim_pose.param_groups[0]["lr_orig"] # reset learning rate 56 | if opt.optim.sched_pose: self.sched_pose.step() 57 | self.graph.nerf.progress.data.fill_(self.it/opt.max_iter) 58 | if opt.nerf.fine_sampling: 59 | self.graph.nerf_fine.progress.data.fill_(self.it/opt.max_iter) 60 | return loss 61 | 62 | @torch.no_grad() 63 
| def validate(self,opt,ep=None): 64 | pose,pose_GT = self.get_all_training_poses(opt) 65 | _,self.graph.sim3 = self.prealign_cameras(opt,pose,pose_GT) 66 | super().validate(opt,ep=ep) 67 | 68 | @torch.no_grad() 69 | def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"): 70 | super().log_scalars(opt,var,loss,metric=metric,step=step,split=split) 71 | if split=="train": 72 | # log learning rate 73 | lr = self.optim_pose.param_groups[0]["lr"] 74 | self.tb.add_scalar("{0}/{1}".format(split,"lr_pose"),lr,step) 75 | # compute pose error 76 | if split=="train" and opt.data.dataset in ["blender","llff"]: 77 | pose,pose_GT = self.get_all_training_poses(opt) 78 | pose_aligned,_ = self.prealign_cameras(opt,pose,pose_GT) 79 | error = self.evaluate_camera_alignment(opt,pose_aligned,pose_GT) 80 | self.tb.add_scalar("{0}/error_R".format(split),error.R.mean(),step) 81 | self.tb.add_scalar("{0}/error_t".format(split),error.t.mean(),step) 82 | 83 | @torch.no_grad() 84 | def visualize(self,opt,var,step=0,split="train"): 85 | super().visualize(opt,var,step=step,split=split) 86 | if opt.visdom: 87 | if split=="val": 88 | pose,pose_GT = self.get_all_training_poses(opt) 89 | util_vis.vis_cameras(opt,self.vis,step=step,poses=[pose,pose_GT]) 90 | 91 | @torch.no_grad() 92 | def get_all_training_poses(self,opt): 93 | # get ground-truth (canonical) camera poses 94 | pose_GT = self.train_data.get_all_camera_poses(opt).to(opt.device) 95 | # add synthetic pose perturbation to all training data 96 | if opt.data.dataset=="blender": 97 | pose = pose_GT 98 | if opt.camera.noise: 99 | pose = camera.pose.compose([self.graph.pose_noise,pose]) 100 | else: pose = self.graph.pose_eye 101 | # add learned pose correction to all training data 102 | pose_refine = camera.lie.se3_to_SE3(self.graph.se3_refine.weight) 103 | pose = camera.pose.compose([pose_refine,pose]) 104 | return pose,pose_GT 105 | 106 | @torch.no_grad() 107 | def prealign_cameras(self,opt,pose,pose_GT): 108 | # compute 3D similarity transform via Procrustes analysis 109 | center = torch.zeros(1,1,3,device=opt.device) 110 | center_pred = camera.cam2world(center,pose)[:,0] # [N,3] 111 | center_GT = camera.cam2world(center,pose_GT)[:,0] # [N,3] 112 | try: 113 | sim3 = camera.procrustes_analysis(center_GT,center_pred) 114 | except: 115 | print("warning: SVD did not converge...") 116 | sim3 = edict(t0=0,t1=0,s0=1,s1=1,R=torch.eye(3,device=opt.device)) 117 | # align the camera poses 118 | center_aligned = (center_pred-sim3.t1)/sim3.s1@sim3.R.t()*sim3.s0+sim3.t0 119 | R_aligned = pose[...,:3]@sim3.R.t() 120 | t_aligned = (-R_aligned@center_aligned[...,None])[...,0] 121 | pose_aligned = camera.pose(R=R_aligned,t=t_aligned) 122 | return pose_aligned,sim3 123 | 124 | @torch.no_grad() 125 | def evaluate_camera_alignment(self,opt,pose_aligned,pose_GT): 126 | # measure errors in rotation and translation 127 | R_aligned,t_aligned = pose_aligned.split([3,1],dim=-1) 128 | R_GT,t_GT = pose_GT.split([3,1],dim=-1) 129 | R_error = camera.rotation_distance(R_aligned,R_GT) 130 | t_error = (t_aligned-t_GT)[...,0].norm(dim=-1) 131 | error = edict(R=R_error,t=t_error) 132 | return error 133 | 134 | @torch.no_grad() 135 | def evaluate_full(self,opt): 136 | self.graph.eval() 137 | # evaluate rotation/translation 138 | pose,pose_GT = self.get_all_training_poses(opt) 139 | pose_aligned,self.graph.sim3 = self.prealign_cameras(opt,pose,pose_GT) 140 | error = self.evaluate_camera_alignment(opt,pose_aligned,pose_GT) 141 | print("--------------------------") 142 | print("rot: 
{:8.3f}".format(np.rad2deg(error.R.mean().cpu()))) 143 | print("trans: {:10.5f}".format(error.t.mean())) 144 | print("--------------------------") 145 | # dump numbers 146 | quant_fname = "{}/quant_pose.txt".format(opt.output_path) 147 | with open(quant_fname,"w") as file: 148 | for i,(err_R,err_t) in enumerate(zip(error.R,error.t)): 149 | file.write("{} {} {}\n".format(i,err_R.item(),err_t.item())) 150 | # evaluate novel view synthesis 151 | super().evaluate_full(opt) 152 | 153 | @torch.enable_grad() 154 | def evaluate_test_time_photometric_optim(self,opt,var): 155 | # use another se3 Parameter to absorb the remaining pose errors 156 | var.se3_refine_test = torch.nn.Parameter(torch.zeros(1,6,device=opt.device)) 157 | optimizer = getattr(torch.optim,opt.optim.algo) 158 | optim_pose = optimizer([dict(params=[var.se3_refine_test],lr=opt.optim.lr_pose)]) 159 | iterator = tqdm.trange(opt.optim.test_iter,desc="test-time optim.",leave=False,position=1) 160 | for it in iterator: 161 | optim_pose.zero_grad() 162 | var.pose_refine_test = camera.lie.se3_to_SE3(var.se3_refine_test) 163 | var = self.graph.forward(opt,var,mode="test-optim") 164 | loss = self.graph.compute_loss(opt,var,mode="test-optim") 165 | loss = self.summarize_loss(opt,var,loss) 166 | loss.all.backward() 167 | optim_pose.step() 168 | iterator.set_postfix(loss="{:.3f}".format(loss.all)) 169 | return var 170 | 171 | @torch.no_grad() 172 | def generate_videos_pose(self,opt): 173 | self.graph.eval() 174 | fig = plt.figure(figsize=(10,10) if opt.data.dataset=="blender" else (16,8)) 175 | cam_path = "{}/poses".format(opt.output_path) 176 | os.makedirs(cam_path,exist_ok=True) 177 | ep_list = [] 178 | for ep in range(0,opt.max_iter+1,opt.freq.ckpt): 179 | # load checkpoint (0 is random init) 180 | if ep!=0: 181 | try: util.restore_checkpoint(opt,self,resume=ep) 182 | except: continue 183 | # get the camera poses 184 | pose,pose_ref = self.get_all_training_poses(opt) 185 | if opt.data.dataset in ["blender","llff"]: 186 | pose_aligned,_ = self.prealign_cameras(opt,pose,pose_ref) 187 | pose_aligned,pose_ref = pose_aligned.detach().cpu(),pose_ref.detach().cpu() 188 | dict( 189 | blender=util_vis.plot_save_poses_blender, 190 | llff=util_vis.plot_save_poses, 191 | )[opt.data.dataset](opt,fig,pose_aligned,pose_ref=pose_ref,path=cam_path,ep=ep) 192 | else: 193 | pose = pose.detach().cpu() 194 | util_vis.plot_save_poses(opt,fig,pose,pose_ref=None,path=cam_path,ep=ep) 195 | ep_list.append(ep) 196 | plt.close() 197 | # write videos 198 | print("writing videos...") 199 | list_fname = "{}/temp.list".format(cam_path) 200 | with open(list_fname,"w") as file: 201 | for ep in ep_list: file.write("file {}.png\n".format(ep)) 202 | cam_vid_fname = "{}/poses.mp4".format(opt.output_path) 203 | os.system("ffmpeg -y -r 30 -f concat -i {0} -pix_fmt yuv420p {1} >/dev/null 2>&1".format(list_fname,cam_vid_fname)) 204 | os.remove(list_fname) 205 | 206 | # ============================ computation graph for forward/backprop ============================ 207 | 208 | class Graph(nerf.Graph): 209 | 210 | def __init__(self,opt): 211 | super().__init__(opt) 212 | self.nerf = NeRF(opt) 213 | if opt.nerf.fine_sampling: 214 | self.nerf_fine = NeRF(opt) 215 | self.pose_eye = torch.eye(3,4).to(opt.device) 216 | 217 | def get_pose(self,opt,var,mode=None): 218 | if mode=="train": 219 | # add the pre-generated pose perturbations 220 | if opt.data.dataset=="blender": 221 | if opt.camera.noise: 222 | var.pose_noise = self.pose_noise[var.idx] 223 | pose = 
camera.pose.compose([var.pose_noise,var.pose]) 224 | else: pose = var.pose 225 | else: pose = self.pose_eye 226 | # add learnable pose correction 227 | var.se3_refine = self.se3_refine.weight[var.idx] 228 | pose_refine = camera.lie.se3_to_SE3(var.se3_refine) 229 | pose = camera.pose.compose([pose_refine,pose]) 230 | elif mode in ["val","eval","test-optim"]: 231 | # align test pose to refined coordinate system (up to sim3) 232 | sim3 = self.sim3 233 | center = torch.zeros(1,1,3,device=opt.device) 234 | center = camera.cam2world(center,var.pose)[:,0] # [N,3] 235 | center_aligned = (center-sim3.t0)/sim3.s0@sim3.R*sim3.s1+sim3.t1 236 | R_aligned = var.pose[...,:3]@self.sim3.R 237 | t_aligned = (-R_aligned@center_aligned[...,None])[...,0] 238 | pose = camera.pose(R=R_aligned,t=t_aligned) 239 | # additionally factorize the remaining pose imperfection 240 | if opt.optim.test_photo and mode!="val": 241 | pose = camera.pose.compose([var.pose_refine_test,pose]) 242 | else: pose = var.pose 243 | return pose 244 | 245 | class NeRF(nerf.NeRF): 246 | 247 | def __init__(self,opt): 248 | super().__init__(opt) 249 | self.progress = torch.nn.Parameter(torch.tensor(0.)) # use Parameter so it could be checkpointed 250 | 251 | def positional_encoding(self,opt,input,L): # [B,...,N] 252 | input_enc = super().positional_encoding(opt,input,L=L) # [B,...,2NL] 253 | # coarse-to-fine: smoothly mask positional encoding for BARF 254 | if opt.barf_c2f is not None: 255 | # set weights for different frequency bands 256 | start,end = opt.barf_c2f 257 | alpha = (self.progress.data-start)/(end-start)*L 258 | k = torch.arange(L,dtype=torch.float32,device=opt.device) 259 | weight = (1-(alpha-k).clamp_(min=0,max=1).mul_(np.pi).cos_())/2 260 | # apply weights 261 | shape = input_enc.shape 262 | input_enc = (input_enc.view(-1,L)*weight).view(*shape) 263 | return input_enc 264 | -------------------------------------------------------------------------------- /model/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os,sys,time 3 | import torch 4 | import torch.nn.functional as torch_F 5 | import torchvision 6 | import torchvision.transforms.functional as torchvision_F 7 | import torch.utils.tensorboard 8 | import visdom 9 | import importlib 10 | import tqdm 11 | from easydict import EasyDict as edict 12 | 13 | import util,util_vis 14 | from util import log,debug 15 | 16 | # ============================ main engine for training and evaluation ============================ 17 | 18 | class Model(): 19 | 20 | def __init__(self,opt): 21 | super().__init__() 22 | os.makedirs(opt.output_path,exist_ok=True) 23 | 24 | def load_dataset(self,opt,eval_split="val"): 25 | data = importlib.import_module("data.{}".format(opt.data.dataset)) 26 | log.info("loading training data...") 27 | self.train_data = data.Dataset(opt,split="train",subset=opt.data.train_sub) 28 | self.train_loader = self.train_data.setup_loader(opt,shuffle=True) 29 | log.info("loading test data...") 30 | if opt.data.val_on_test: eval_split = "test" 31 | self.test_data = data.Dataset(opt,split=eval_split,subset=opt.data.val_sub) 32 | self.test_loader = self.test_data.setup_loader(opt,shuffle=False) 33 | 34 | def build_networks(self,opt): 35 | graph = importlib.import_module("model.{}".format(opt.model)) 36 | log.info("building networks...") 37 | self.graph = graph.Graph(opt).to(opt.device) 38 | 39 | def setup_optimizer(self,opt): 40 | log.info("setting up optimizers...") 41 | optimizer = 
getattr(torch.optim,opt.optim.algo) 42 | self.optim = optimizer([dict(params=self.graph.parameters(),lr=opt.optim.lr)]) 43 | # set up scheduler 44 | if opt.optim.sched: 45 | scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched.type) 46 | kwargs = { k:v for k,v in opt.optim.sched.items() if k!="type" } 47 | self.sched = scheduler(self.optim,**kwargs) 48 | 49 | def restore_checkpoint(self,opt): 50 | epoch_start,iter_start = None,None 51 | if opt.resume: 52 | log.info("resuming from previous checkpoint...") 53 | epoch_start,iter_start = util.restore_checkpoint(opt,self,resume=opt.resume) 54 | elif opt.load is not None: 55 | log.info("loading weights from checkpoint {}...".format(opt.load)) 56 | epoch_start,iter_start = util.restore_checkpoint(opt,self,load_name=opt.load) 57 | else: 58 | log.info("initializing weights from scratch...") 59 | self.epoch_start = epoch_start or 0 60 | self.iter_start = iter_start or 0 61 | 62 | def setup_visualizer(self,opt): 63 | log.info("setting up visualizers...") 64 | if opt.tb: 65 | self.tb = torch.utils.tensorboard.SummaryWriter(log_dir=opt.output_path,flush_secs=10) 66 | if opt.visdom: 67 | # check if visdom server is runninng 68 | is_open = util.check_socket_open(opt.visdom.server,opt.visdom.port) 69 | retry = None 70 | while not is_open: 71 | retry = input("visdom port ({}) not open, retry? (y/n) ".format(opt.visdom.port)) 72 | if retry not in ["y","n"]: continue 73 | if retry=="y": 74 | is_open = util.check_socket_open(opt.visdom.server,opt.visdom.port) 75 | else: break 76 | self.vis = visdom.Visdom(server=opt.visdom.server,port=opt.visdom.port,env=opt.group) 77 | 78 | def train(self,opt): 79 | # before training 80 | log.title("TRAINING START") 81 | self.timer = edict(start=time.time(),it_mean=None) 82 | self.it = self.iter_start 83 | # training 84 | if self.iter_start==0: self.validate(opt,ep=0) 85 | for self.ep in range(self.epoch_start,opt.max_epoch): 86 | self.train_epoch(opt) 87 | # after training 88 | if opt.tb: 89 | self.tb.flush() 90 | self.tb.close() 91 | if opt.visdom: self.vis.close() 92 | log.title("TRAINING DONE") 93 | 94 | def train_epoch(self,opt): 95 | # before train epoch 96 | self.graph.train() 97 | # train epoch 98 | loader = tqdm.tqdm(self.train_loader,desc="training epoch {}".format(self.ep+1),leave=False) 99 | for batch in loader: 100 | # train iteration 101 | var = edict(batch) 102 | var = util.move_to_device(var,opt.device) 103 | loss = self.train_iteration(opt,var,loader) 104 | # after train epoch 105 | lr = self.sched.get_last_lr()[0] if opt.optim.sched else opt.optim.lr 106 | log.loss_train(opt,self.ep+1,lr,loss.all,self.timer) 107 | if opt.optim.sched: self.sched.step() 108 | if (self.ep+1)%opt.freq.val==0: self.validate(opt,ep=self.ep+1) 109 | if (self.ep+1)%opt.freq.ckpt==0: self.save_checkpoint(opt,ep=self.ep+1,it=self.it) 110 | 111 | def train_iteration(self,opt,var,loader): 112 | # before train iteration 113 | self.timer.it_start = time.time() 114 | # train iteration 115 | self.optim.zero_grad() 116 | var = self.graph.forward(opt,var,mode="train") 117 | loss = self.graph.compute_loss(opt,var,mode="train") 118 | loss = self.summarize_loss(opt,var,loss) 119 | loss.all.backward() 120 | self.optim.step() 121 | # after train iteration 122 | if (self.it+1)%opt.freq.scalar==0: self.log_scalars(opt,var,loss,step=self.it+1,split="train") 123 | if (self.it+1)%opt.freq.vis==0: self.visualize(opt,var,step=self.it+1,split="train") 124 | self.it += 1 125 | loader.set_postfix(it=self.it,loss="{:.3f}".format(loss.all)) 126 | 
self.timer.it_end = time.time() 127 | util.update_timer(opt,self.timer,self.ep,len(loader)) 128 | return loss 129 | 130 | def summarize_loss(self,opt,var,loss): 131 | loss_all = 0. 132 | assert("all" not in loss) 133 | # weigh losses 134 | for key in loss: 135 | assert(key in opt.loss_weight) 136 | assert(loss[key].shape==()) 137 | if opt.loss_weight[key] is not None: 138 | assert not torch.isinf(loss[key]),"loss {} is Inf".format(key) 139 | assert not torch.isnan(loss[key]),"loss {} is NaN".format(key) 140 | loss_all += 10**float(opt.loss_weight[key])*loss[key] 141 | loss.update(all=loss_all) 142 | return loss 143 | 144 | @torch.no_grad() 145 | def validate(self,opt,ep=None): 146 | self.graph.eval() 147 | loss_val = edict() 148 | loader = tqdm.tqdm(self.test_loader,desc="validating",leave=False) 149 | for it,batch in enumerate(loader): 150 | var = edict(batch) 151 | var = util.move_to_device(var,opt.device) 152 | var = self.graph.forward(opt,var,mode="val") 153 | loss = self.graph.compute_loss(opt,var,mode="val") 154 | loss = self.summarize_loss(opt,var,loss) 155 | for key in loss: 156 | loss_val.setdefault(key,0.) 157 | loss_val[key] += loss[key]*len(var.idx) 158 | loader.set_postfix(loss="{:.3f}".format(loss.all)) 159 | if it==0: self.visualize(opt,var,step=ep,split="val") 160 | for key in loss_val: loss_val[key] /= len(self.test_data) 161 | self.log_scalars(opt,var,loss_val,step=ep,split="val") 162 | log.loss_val(opt,loss_val.all) 163 | 164 | @torch.no_grad() 165 | def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"): 166 | for key,value in loss.items(): 167 | if key=="all": continue 168 | if opt.loss_weight[key] is not None: 169 | self.tb.add_scalar("{0}/loss_{1}".format(split,key),value,step) 170 | if metric is not None: 171 | for key,value in metric.items(): 172 | self.tb.add_scalar("{0}/{1}".format(split,key),value,step) 173 | 174 | @torch.no_grad() 175 | def visualize(self,opt,var,step=0,split="train"): 176 | raise NotImplementedError 177 | 178 | def save_checkpoint(self,opt,ep=0,it=0,latest=False): 179 | util.save_checkpoint(opt,self,ep=ep,it=it,latest=latest) 180 | if not latest: 181 | log.info("checkpoint saved: ({0}) {1}, epoch {2} (iteration {3})".format(opt.group,opt.name,ep,it)) 182 | 183 | # ============================ computation graph for forward/backprop ============================ 184 | 185 | class Graph(torch.nn.Module): 186 | 187 | def __init__(self,opt): 188 | super().__init__() 189 | 190 | def forward(self,opt,var,mode=None): 191 | raise NotImplementedError 192 | return var 193 | 194 | def compute_loss(self,opt,var,mode=None): 195 | loss = edict() 196 | raise NotImplementedError 197 | return loss 198 | 199 | def L1_loss(self,pred,label=0): 200 | loss = (pred.contiguous()-label).abs() 201 | return loss.mean() 202 | def MSE_loss(self,pred,label=0): 203 | loss = (pred.contiguous()-label)**2 204 | return loss.mean() 205 | -------------------------------------------------------------------------------- /model/nerf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os,sys,time 3 | import torch 4 | import torch.nn.functional as torch_F 5 | import torchvision 6 | import torchvision.transforms.functional as torchvision_F 7 | import tqdm 8 | from easydict import EasyDict as edict 9 | 10 | import lpips 11 | from external.pohsun_ssim import pytorch_ssim 12 | 13 | import util,util_vis 14 | from util import log,debug 15 | from . 
import base 16 | import camera 17 | 18 | # ============================ main engine for training and evaluation ============================ 19 | 20 | class Model(base.Model): 21 | 22 | def __init__(self,opt): 23 | super().__init__(opt) 24 | self.lpips_loss = lpips.LPIPS(net="alex").to(opt.device) 25 | 26 | def load_dataset(self,opt,eval_split="val"): 27 | super().load_dataset(opt,eval_split=eval_split) 28 | # prefetch all training data 29 | self.train_data.prefetch_all_data(opt) 30 | self.train_data.all = edict(util.move_to_device(self.train_data.all,opt.device)) 31 | 32 | def setup_optimizer(self,opt): 33 | log.info("setting up optimizers...") 34 | optimizer = getattr(torch.optim,opt.optim.algo) 35 | self.optim = optimizer([dict(params=self.graph.nerf.parameters(),lr=opt.optim.lr)]) 36 | if opt.nerf.fine_sampling: 37 | self.optim.add_param_group(dict(params=self.graph.nerf_fine.parameters(),lr=opt.optim.lr)) 38 | # set up scheduler 39 | if opt.optim.sched: 40 | scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched.type) 41 | if opt.optim.lr_end: 42 | assert(opt.optim.sched.type=="ExponentialLR") 43 | opt.optim.sched.gamma = (opt.optim.lr_end/opt.optim.lr)**(1./opt.max_iter) 44 | kwargs = { k:v for k,v in opt.optim.sched.items() if k!="type" } 45 | self.sched = scheduler(self.optim,**kwargs) 46 | 47 | def train(self,opt): 48 | # before training 49 | log.title("TRAINING START") 50 | self.timer = edict(start=time.time(),it_mean=None) 51 | self.graph.train() 52 | self.ep = 0 # dummy for timer 53 | # training 54 | if self.iter_start==0: self.validate(opt,0) 55 | loader = tqdm.trange(opt.max_iter,desc="training",leave=False) 56 | for self.it in loader: 57 | if self.it/dev/null 2>&1".format(test_path,rgb_vid_fname)) 159 | os.system("ffmpeg -y -framerate 30 -i {0}/depth_%d.png -pix_fmt yuv420p {1} >/dev/null 2>&1".format(test_path,depth_vid_fname)) 160 | else: 161 | pose_pred,pose_GT = self.get_all_training_poses(opt) 162 | poses = pose_pred if opt.model=="barf" else pose_GT 163 | if opt.model=="barf" and opt.data.dataset=="llff": 164 | _,sim3 = self.prealign_cameras(opt,pose_pred,pose_GT) 165 | scale = sim3.s1/sim3.s0 166 | else: scale = 1 167 | # rotate novel views around the "center" camera of all poses 168 | idx_center = (poses-poses.mean(dim=0,keepdim=True))[...,3].norm(dim=-1).argmin() 169 | pose_novel = camera.get_novel_view_poses(opt,poses[idx_center],N=60,scale=scale).to(opt.device) 170 | # render the novel views 171 | novel_path = "{}/novel_view".format(opt.output_path) 172 | os.makedirs(novel_path,exist_ok=True) 173 | pose_novel_tqdm = tqdm.tqdm(pose_novel,desc="rendering novel views",leave=False) 174 | intr = edict(next(iter(self.test_loader))).intr[:1].to(opt.device) # grab intrinsics 175 | for i,pose in enumerate(pose_novel_tqdm): 176 | ret = self.graph.render_by_slices(opt,pose[None],intr=intr) if opt.nerf.rand_rays else \ 177 | self.graph.render(opt,pose[None],intr=intr) 178 | invdepth = (1-ret.depth)/ret.opacity if opt.camera.ndc else 1/(ret.depth/ret.opacity+eps) 179 | rgb_map = ret.rgb.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,3,H,W] 180 | invdepth_map = invdepth.view(-1,opt.H,opt.W,1).permute(0,3,1,2) # [B,1,H,W] 181 | torchvision_F.to_pil_image(rgb_map.cpu()[0]).save("{}/rgb_{}.png".format(novel_path,i)) 182 | torchvision_F.to_pil_image(invdepth_map.cpu()[0]).save("{}/depth_{}.png".format(novel_path,i)) 183 | # write videos 184 | print("writing videos...") 185 | rgb_vid_fname = "{}/novel_view_rgb.mp4".format(opt.output_path) 186 | depth_vid_fname = 
"{}/novel_view_depth.mp4".format(opt.output_path) 187 | os.system("ffmpeg -y -framerate 30 -i {0}/rgb_%d.png -pix_fmt yuv420p {1} >/dev/null 2>&1".format(novel_path,rgb_vid_fname)) 188 | os.system("ffmpeg -y -framerate 30 -i {0}/depth_%d.png -pix_fmt yuv420p {1} >/dev/null 2>&1".format(novel_path,depth_vid_fname)) 189 | 190 | # ============================ computation graph for forward/backprop ============================ 191 | 192 | class Graph(base.Graph): 193 | 194 | def __init__(self,opt): 195 | super().__init__(opt) 196 | self.nerf = NeRF(opt) 197 | if opt.nerf.fine_sampling: 198 | self.nerf_fine = NeRF(opt) 199 | 200 | def forward(self,opt,var,mode=None): 201 | batch_size = len(var.idx) 202 | pose = self.get_pose(opt,var,mode=mode) 203 | # render images 204 | if opt.nerf.rand_rays and mode in ["train","test-optim"]: 205 | # sample random rays for optimization 206 | var.ray_idx = torch.randperm(opt.H*opt.W,device=opt.device)[:opt.nerf.rand_rays//batch_size] 207 | ret = self.render(opt,pose,intr=var.intr,ray_idx=var.ray_idx,mode=mode) # [B,N,3],[B,N,1] 208 | else: 209 | # render full image (process in slices) 210 | ret = self.render_by_slices(opt,pose,intr=var.intr,mode=mode) if opt.nerf.rand_rays else \ 211 | self.render(opt,pose,intr=var.intr,mode=mode) # [B,HW,3],[B,HW,1] 212 | var.update(ret) 213 | return var 214 | 215 | def compute_loss(self,opt,var,mode=None): 216 | loss = edict() 217 | batch_size = len(var.idx) 218 | image = var.image.view(batch_size,3,opt.H*opt.W).permute(0,2,1) 219 | if opt.nerf.rand_rays and mode in ["train","test-optim"]: 220 | image = image[:,var.ray_idx] 221 | # compute image losses 222 | if opt.loss_weight.render is not None: 223 | loss.render = self.MSE_loss(var.rgb,image) 224 | if opt.loss_weight.render_fine is not None: 225 | assert(opt.nerf.fine_sampling) 226 | loss.render_fine = self.MSE_loss(var.rgb_fine,image) 227 | return loss 228 | 229 | def get_pose(self,opt,var,mode=None): 230 | return var.pose 231 | 232 | def render(self,opt,pose,intr=None,ray_idx=None,mode=None): 233 | batch_size = len(pose) 234 | center,ray = camera.get_center_and_ray(opt,pose,intr=intr) # [B,HW,3] 235 | while ray.isnan().any(): # TODO: weird bug, ray becomes NaN arbitrarily if batch_size>1, not deterministic reproducible 236 | center,ray = camera.get_center_and_ray(opt,pose,intr=intr) # [B,HW,3] 237 | if ray_idx is not None: 238 | # consider only subset of rays 239 | center,ray = center[:,ray_idx],ray[:,ray_idx] 240 | if opt.camera.ndc: 241 | # convert center/ray representations to NDC 242 | center,ray = camera.convert_NDC(opt,center,ray,intr=intr) 243 | # render with main MLP 244 | depth_samples = self.sample_depth(opt,batch_size,num_rays=ray.shape[1]) # [B,HW,N,1] 245 | rgb_samples,density_samples = self.nerf.forward_samples(opt,center,ray,depth_samples,mode=mode) 246 | rgb,depth,opacity,prob = self.nerf.composite(opt,ray,rgb_samples,density_samples,depth_samples) 247 | ret = edict(rgb=rgb,depth=depth,opacity=opacity) # [B,HW,K] 248 | # render with fine MLP from coarse MLP 249 | if opt.nerf.fine_sampling: 250 | with torch.no_grad(): 251 | # resample depth acoording to coarse empirical distribution 252 | depth_samples_fine = self.sample_depth_from_pdf(opt,pdf=prob[...,0]) # [B,HW,Nf,1] 253 | depth_samples = torch.cat([depth_samples,depth_samples_fine],dim=2) # [B,HW,N+Nf,1] 254 | depth_samples = depth_samples.sort(dim=2).values 255 | rgb_samples,density_samples = self.nerf_fine.forward_samples(opt,center,ray,depth_samples,mode=mode) 256 | 
rgb_fine,depth_fine,opacity_fine,_ = self.nerf_fine.composite(opt,ray,rgb_samples,density_samples,depth_samples) 257 | ret.update(rgb_fine=rgb_fine,depth_fine=depth_fine,opacity_fine=opacity_fine) # [B,HW,K] 258 | return ret 259 | 260 | def render_by_slices(self,opt,pose,intr=None,mode=None): 261 | ret_all = edict(rgb=[],depth=[],opacity=[]) 262 | if opt.nerf.fine_sampling: 263 | ret_all.update(rgb_fine=[],depth_fine=[],opacity_fine=[]) 264 | # render the image by slices for memory considerations 265 | for c in range(0,opt.H*opt.W,opt.nerf.rand_rays): 266 | ray_idx = torch.arange(c,min(c+opt.nerf.rand_rays,opt.H*opt.W),device=opt.device) 267 | ret = self.render(opt,pose,intr=intr,ray_idx=ray_idx,mode=mode) # [B,R,3],[B,R,1] 268 | for k in ret: ret_all[k].append(ret[k]) 269 | # group all slices of images 270 | for k in ret_all: ret_all[k] = torch.cat(ret_all[k],dim=1) 271 | return ret_all 272 | 273 | def sample_depth(self,opt,batch_size,num_rays=None): 274 | depth_min,depth_max = opt.nerf.depth.range 275 | num_rays = num_rays or opt.H*opt.W 276 | rand_samples = torch.rand(batch_size,num_rays,opt.nerf.sample_intvs,1,device=opt.device) if opt.nerf.sample_stratified else 0.5 277 | rand_samples += torch.arange(opt.nerf.sample_intvs,device=opt.device)[None,None,:,None].float() # [B,HW,N,1] 278 | depth_samples = rand_samples/opt.nerf.sample_intvs*(depth_max-depth_min)+depth_min # [B,HW,N,1] 279 | depth_samples = dict( 280 | metric=depth_samples, 281 | inverse=1/(depth_samples+1e-8), 282 | )[opt.nerf.depth.param] 283 | return depth_samples 284 | 285 | def sample_depth_from_pdf(self,opt,pdf): 286 | depth_min,depth_max = opt.nerf.depth.range 287 | # get CDF from PDF (along last dimension) 288 | cdf = pdf.cumsum(dim=-1) # [B,HW,N] 289 | cdf = torch.cat([torch.zeros_like(cdf[...,:1]),cdf],dim=-1) # [B,HW,N+1] 290 | # take uniform samples 291 | grid = torch.linspace(0,1,opt.nerf.sample_intvs_fine+1,device=opt.device) # [Nf+1] 292 | unif = 0.5*(grid[:-1]+grid[1:]).repeat(*cdf.shape[:-1],1) # [B,HW,Nf] 293 | idx = torch.searchsorted(cdf,unif,right=True) # [B,HW,Nf] \in {1...N} 294 | # inverse transform sampling from CDF 295 | depth_bin = torch.linspace(depth_min,depth_max,opt.nerf.sample_intvs+1,device=opt.device) # [N+1] 296 | depth_bin = depth_bin.repeat(*cdf.shape[:-1],1) # [B,HW,N+1] 297 | depth_low = depth_bin.gather(dim=2,index=(idx-1).clamp(min=0)) # [B,HW,Nf] 298 | depth_high = depth_bin.gather(dim=2,index=idx.clamp(max=opt.nerf.sample_intvs)) # [B,HW,Nf] 299 | cdf_low = cdf.gather(dim=2,index=(idx-1).clamp(min=0)) # [B,HW,Nf] 300 | cdf_high = cdf.gather(dim=2,index=idx.clamp(max=opt.nerf.sample_intvs)) # [B,HW,Nf] 301 | # linear interpolation 302 | t = (unif-cdf_low)/(cdf_high-cdf_low+1e-8) # [B,HW,Nf] 303 | depth_samples = depth_low+t*(depth_high-depth_low) # [B,HW,Nf] 304 | return depth_samples[...,None] # [B,HW,Nf,1] 305 | 306 | class NeRF(torch.nn.Module): 307 | 308 | def __init__(self,opt): 309 | super().__init__() 310 | self.define_network(opt) 311 | 312 | def define_network(self,opt): 313 | input_3D_dim = 3+6*opt.arch.posenc.L_3D if opt.arch.posenc else 3 314 | if opt.nerf.view_dep: 315 | input_view_dim = 3+6*opt.arch.posenc.L_view if opt.arch.posenc else 3 316 | # point-wise feature 317 | self.mlp_feat = torch.nn.ModuleList() 318 | L = util.get_layer_dims(opt.arch.layers_feat) 319 | for li,(k_in,k_out) in enumerate(L): 320 | if li==0: k_in = input_3D_dim 321 | if li in opt.arch.skip: k_in += input_3D_dim 322 | if li==len(L)-1: k_out += 1 323 | linear = torch.nn.Linear(k_in,k_out) 324 | 
if opt.arch.tf_init: 325 | self.tensorflow_init_weights(opt,linear,out="first" if li==len(L)-1 else None) 326 | self.mlp_feat.append(linear) 327 | # RGB prediction 328 | self.mlp_rgb = torch.nn.ModuleList() 329 | L = util.get_layer_dims(opt.arch.layers_rgb) 330 | feat_dim = opt.arch.layers_feat[-1] 331 | for li,(k_in,k_out) in enumerate(L): 332 | if li==0: k_in = feat_dim+(input_view_dim if opt.nerf.view_dep else 0) 333 | linear = torch.nn.Linear(k_in,k_out) 334 | if opt.arch.tf_init: 335 | self.tensorflow_init_weights(opt,linear,out="all" if li==len(L)-1 else None) 336 | self.mlp_rgb.append(linear) 337 | 338 | def tensorflow_init_weights(self,opt,linear,out=None): 339 | # use Xavier init instead of Kaiming init 340 | relu_gain = torch.nn.init.calculate_gain("relu") # sqrt(2) 341 | if out=="all": 342 | torch.nn.init.xavier_uniform_(linear.weight) 343 | elif out=="first": 344 | torch.nn.init.xavier_uniform_(linear.weight[:1]) 345 | torch.nn.init.xavier_uniform_(linear.weight[1:],gain=relu_gain) 346 | else: 347 | torch.nn.init.xavier_uniform_(linear.weight,gain=relu_gain) 348 | torch.nn.init.zeros_(linear.bias) 349 | 350 | def forward(self,opt,points_3D,ray_unit=None,mode=None): # [B,...,3] 351 | if opt.arch.posenc: 352 | points_enc = self.positional_encoding(opt,points_3D,L=opt.arch.posenc.L_3D) 353 | points_enc = torch.cat([points_3D,points_enc],dim=-1) # [B,...,6L+3] 354 | else: points_enc = points_3D 355 | feat = points_enc 356 | # extract coordinate-based features 357 | for li,layer in enumerate(self.mlp_feat): 358 | if li in opt.arch.skip: feat = torch.cat([feat,points_enc],dim=-1) 359 | feat = layer(feat) 360 | if li==len(self.mlp_feat)-1: 361 | density = feat[...,0] 362 | if opt.nerf.density_noise_reg and mode=="train": 363 | density += torch.randn_like(density)*opt.nerf.density_noise_reg 364 | density_activ = getattr(torch_F,opt.arch.density_activ) # relu_,abs_,sigmoid_,exp_.... 
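                # note: the configs in options/ set density_activ to softplus (relu in the *_repr reproduction
                # configs); any elementwise activation available in torch.nn.functional can be named here.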
365 | density = density_activ(density) 366 | feat = feat[...,1:] 367 | feat = torch_F.relu(feat) 368 | # predict RGB values 369 | if opt.nerf.view_dep: 370 | assert(ray_unit is not None) 371 | if opt.arch.posenc: 372 | ray_enc = self.positional_encoding(opt,ray_unit,L=opt.arch.posenc.L_view) 373 | ray_enc = torch.cat([ray_unit,ray_enc],dim=-1) # [B,...,6L+3] 374 | else: ray_enc = ray_unit 375 | feat = torch.cat([feat,ray_enc],dim=-1) 376 | for li,layer in enumerate(self.mlp_rgb): 377 | feat = layer(feat) 378 | if li!=len(self.mlp_rgb)-1: 379 | feat = torch_F.relu(feat) 380 | rgb = feat.sigmoid_() # [B,...,3] 381 | return rgb,density 382 | 383 | def forward_samples(self,opt,center,ray,depth_samples,mode=None): 384 | points_3D_samples = camera.get_3D_points_from_depth(opt,center,ray,depth_samples,multi_samples=True) # [B,HW,N,3] 385 | if opt.nerf.view_dep: 386 | ray_unit = torch_F.normalize(ray,dim=-1) # [B,HW,3] 387 | ray_unit_samples = ray_unit[...,None,:].expand_as(points_3D_samples) # [B,HW,N,3] 388 | else: ray_unit_samples = None 389 | rgb_samples,density_samples = self.forward(opt,points_3D_samples,ray_unit=ray_unit_samples,mode=mode) # [B,HW,N],[B,HW,N,3] 390 | return rgb_samples,density_samples 391 | 392 | def composite(self,opt,ray,rgb_samples,density_samples,depth_samples): 393 | ray_length = ray.norm(dim=-1,keepdim=True) # [B,HW,1] 394 | # volume rendering: compute probability (using quadrature) 395 | depth_intv_samples = depth_samples[...,1:,0]-depth_samples[...,:-1,0] # [B,HW,N-1] 396 | depth_intv_samples = torch.cat([depth_intv_samples,torch.empty_like(depth_intv_samples[...,:1]).fill_(1e10)],dim=2) # [B,HW,N] 397 | dist_samples = depth_intv_samples*ray_length # [B,HW,N] 398 | sigma_delta = density_samples*dist_samples # [B,HW,N] 399 | alpha = 1-(-sigma_delta).exp_() # [B,HW,N] 400 | T = (-torch.cat([torch.zeros_like(sigma_delta[...,:1]),sigma_delta[...,:-1]],dim=2).cumsum(dim=2)).exp_() # [B,HW,N] 401 | prob = (T*alpha)[...,None] # [B,HW,N,1] 402 | # integrate RGB and depth weighted by probability 403 | depth = (depth_samples*prob).sum(dim=2) # [B,HW,1] 404 | rgb = (rgb_samples*prob).sum(dim=2) # [B,HW,3] 405 | opacity = prob.sum(dim=2) # [B,HW,1] 406 | if opt.nerf.setbg_opaque: 407 | rgb = rgb+opt.data.bgcolor*(1-opacity) 408 | return rgb,depth,opacity,prob # [B,HW,K] 409 | 410 | def positional_encoding(self,opt,input,L): # [B,...,N] 411 | shape = input.shape 412 | freq = 2**torch.arange(L,dtype=torch.float32,device=opt.device)*np.pi # [L] 413 | spectrum = input[...,None]*freq # [B,...,N,L] 414 | sin,cos = spectrum.sin(),spectrum.cos() # [B,...,N,L] 415 | input_enc = torch.stack([sin,cos],dim=-2) # [B,...,N,2,L] 416 | input_enc = input_enc.view(*shape[:-1],-1) # [B,...,2NL] 417 | return input_enc 418 | -------------------------------------------------------------------------------- /model/planar.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os,sys,time 3 | import torch 4 | import torch.nn.functional as torch_F 5 | import torchvision 6 | import torchvision.transforms.functional as torchvision_F 7 | import tqdm 8 | from easydict import EasyDict as edict 9 | import PIL 10 | import PIL.Image,PIL.ImageDraw 11 | import imageio 12 | 13 | import util,util_vis 14 | from util import log,debug 15 | from . 
import base 16 | import warp 17 | 18 | # ============================ main engine for training and evaluation ============================ 19 | 20 | class Model(base.Model): 21 | 22 | def __init__(self,opt): 23 | super().__init__(opt) 24 | opt.H_crop,opt.W_crop = opt.data.patch_crop 25 | 26 | def load_dataset(self,opt,eval_split=None): 27 | image_raw = PIL.Image.open(opt.data.image_fname) 28 | self.image_raw = torchvision_F.to_tensor(image_raw).to(opt.device) 29 | 30 | def build_networks(self,opt): 31 | super().build_networks(opt) 32 | self.graph.warp_param = torch.nn.Embedding(opt.batch_size,opt.warp.dof).to(opt.device) 33 | torch.nn.init.zeros_(self.graph.warp_param.weight) 34 | 35 | def setup_optimizer(self,opt): 36 | log.info("setting up optimizers...") 37 | optim_list = [ 38 | dict(params=self.graph.neural_image.parameters(),lr=opt.optim.lr), 39 | dict(params=self.graph.warp_param.parameters(),lr=opt.optim.lr_warp), 40 | ] 41 | optimizer = getattr(torch.optim,opt.optim.algo) 42 | self.optim = optimizer(optim_list) 43 | # set up scheduler 44 | if opt.optim.sched: 45 | scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched.type) 46 | kwargs = { k:v for k,v in opt.optim.sched.items() if k!="type" } 47 | self.sched = scheduler(self.optim,**kwargs) 48 | 49 | def setup_visualizer(self,opt): 50 | super().setup_visualizer(opt) 51 | # set colors for visualization 52 | box_colors = ["#ff0000","#40afff","#9314ff","#ffd700","#00ff00"] 53 | box_colors = list(map(util.colorcode_to_number,box_colors)) 54 | self.box_colors = np.array(box_colors).astype(int) 55 | assert(len(self.box_colors)==opt.batch_size) 56 | # create visualization directory 57 | self.vis_path = "{}/vis".format(opt.output_path) 58 | os.makedirs(self.vis_path,exist_ok=True) 59 | self.video_fname = "{}/vis.mp4".format(opt.output_path) 60 | 61 | def train(self,opt): 62 | # before training 63 | log.title("TRAINING START") 64 | self.timer = edict(start=time.time(),it_mean=None) 65 | self.ep = self.it = self.vis_it = 0 66 | self.graph.train() 67 | var = edict(idx=torch.arange(opt.batch_size)) 68 | # pre-generate perturbations 69 | self.warp_pert,var.image_pert = self.generate_warp_perturbation(opt) 70 | # train 71 | var = util.move_to_device(var,opt.device) 72 | loader = tqdm.trange(opt.max_iter,desc="training",leave=False) 73 | # visualize initial state 74 | var = self.graph.forward(opt,var) 75 | self.visualize(opt,var,step=0) 76 | for it in loader: 77 | # train iteration 78 | loss = self.train_iteration(opt,var,loader) 79 | if opt.warp.fix_first: 80 | self.graph.warp_param.weight.data[0] = 0 81 | # after training 82 | os.system("ffmpeg -y -framerate 30 -i {}/%d.png -pix_fmt yuv420p {}".format(self.vis_path,self.video_fname)) 83 | self.save_checkpoint(opt,ep=None,it=self.it) 84 | if opt.tb: 85 | self.tb.flush() 86 | self.tb.close() 87 | if opt.visdom: self.vis.close() 88 | log.title("TRAINING DONE") 89 | 90 | def train_iteration(self,opt,var,loader): 91 | loss = super().train_iteration(opt,var,loader) 92 | self.graph.neural_image.progress.data.fill_(self.it/opt.max_iter) 93 | return loss 94 | 95 | def generate_warp_perturbation(self,opt): 96 | # pre-generate perturbations (translational noise + homography noise) 97 | warp_pert_all = torch.zeros(opt.batch_size,opt.warp.dof,device=opt.device) 98 | trans_pert = [(0,0)]+[(x,y) for x in (-opt.warp.noise_t,opt.warp.noise_t) 99 | for y in (-opt.warp.noise_t,opt.warp.noise_t)] 100 | def create_random_perturbation(): 101 | warp_pert = 
torch.randn(opt.warp.dof,device=opt.device)*opt.warp.noise_h 102 | warp_pert[0] += trans_pert[i][0] 103 | warp_pert[1] += trans_pert[i][1] 104 | return warp_pert 105 | for i in range(opt.batch_size): 106 | warp_pert = create_random_perturbation() 107 | while not warp.check_corners_in_range(opt,warp_pert[None]): 108 | warp_pert = create_random_perturbation() 109 | warp_pert_all[i] = warp_pert 110 | if opt.warp.fix_first: 111 | warp_pert_all[0] = 0 112 | # create warped image patches 113 | xy_grid = warp.get_normalized_pixel_grid_crop(opt) # [B,HW,2] 114 | xy_grid_warped = warp.warp_grid(opt,xy_grid,warp_pert_all) 115 | xy_grid_warped = xy_grid_warped.view([opt.batch_size,opt.H_crop,opt.W_crop,2]) 116 | xy_grid_warped = torch.stack([xy_grid_warped[...,0]*max(opt.H,opt.W)/opt.W, 117 | xy_grid_warped[...,1]*max(opt.H,opt.W)/opt.H],dim=-1) 118 | image_raw_batch = self.image_raw.repeat(opt.batch_size,1,1,1) 119 | image_pert_all = torch_F.grid_sample(image_raw_batch,xy_grid_warped,align_corners=False) 120 | return warp_pert_all,image_pert_all 121 | 122 | def visualize_patches(self,opt,warp_param): 123 | image_pil = torchvision_F.to_pil_image(self.image_raw).convert("RGBA") 124 | draw_pil = PIL.Image.new("RGBA",image_pil.size,(0,0,0,0)) 125 | draw = PIL.ImageDraw.Draw(draw_pil) 126 | corners_all = warp.warp_corners(opt,warp_param) 127 | corners_all[...,0] = (corners_all[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5 128 | corners_all[...,1] = (corners_all[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5 129 | for i,corners in enumerate(corners_all): 130 | P = [tuple(float(n) for n in corners[j]) for j in range(4)] 131 | draw.line([P[0],P[1],P[2],P[3],P[0]],fill=tuple(self.box_colors[i]),width=3) 132 | image_pil.alpha_composite(draw_pil) 133 | image_tensor = torchvision_F.to_tensor(image_pil.convert("RGB")) 134 | return image_tensor 135 | 136 | @torch.no_grad() 137 | def predict_entire_image(self,opt): 138 | xy_grid = warp.get_normalized_pixel_grid(opt)[:1] 139 | rgb = self.graph.neural_image.forward(opt,xy_grid) # [B,HW,3] 140 | image = rgb.view(opt.H,opt.W,3).detach().cpu().permute(2,0,1) 141 | return image 142 | 143 | @torch.no_grad() 144 | def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"): 145 | super().log_scalars(opt,var,loss,metric=metric,step=step,split=split) 146 | # compute PSNR 147 | psnr = -10*loss.render.log10() 148 | self.tb.add_scalar("{0}/{1}".format(split,"PSNR"),psnr,step) 149 | # warp error 150 | warp_error = (self.graph.warp_param.weight-self.warp_pert).norm(dim=-1).mean() 151 | self.tb.add_scalar("{0}/{1}".format(split,"warp error"),warp_error,step) 152 | 153 | @torch.no_grad() 154 | def visualize(self,opt,var,step=0,split="train"): 155 | # dump frames for writing to video 156 | frame_GT = self.visualize_patches(opt,self.warp_pert) 157 | frame = self.visualize_patches(opt,self.graph.warp_param.weight) 158 | frame2 = self.predict_entire_image(opt) 159 | frame_cat = (torch.cat([frame,frame2],dim=1)*255).byte().permute(1,2,0).numpy() 160 | imageio.imsave("{}/{}.png".format(self.vis_path,self.vis_it),frame_cat) 161 | self.vis_it += 1 162 | # visualize in Tensorboard 163 | if opt.tb: 164 | colors = self.box_colors 165 | util_vis.tb_image(opt,self.tb,step,split,"image_pert",util_vis.color_border(var.image_pert,colors)) 166 | util_vis.tb_image(opt,self.tb,step,split,"rgb_warped",util_vis.color_border(var.rgb_warped_map,colors)) 167 | util_vis.tb_image(opt,self.tb,self.it+1,"train","image_boxes",frame[None]) 168 | 
util_vis.tb_image(opt,self.tb,self.it+1,"train","image_boxes_GT",frame_GT[None]) 169 | util_vis.tb_image(opt,self.tb,self.it+1,"train","image_entire",frame2[None]) 170 | 171 | # ============================ computation graph for forward/backprop ============================ 172 | 173 | class Graph(base.Graph): 174 | 175 | def __init__(self,opt): 176 | super().__init__(opt) 177 | self.neural_image = NeuralImageFunction(opt) 178 | 179 | def forward(self,opt,var,mode=None): 180 | xy_grid = warp.get_normalized_pixel_grid_crop(opt) 181 | xy_grid_warped = warp.warp_grid(opt,xy_grid,self.warp_param.weight) 182 | # render images 183 | var.rgb_warped = self.neural_image.forward(opt,xy_grid_warped) # [B,HW,3] 184 | var.rgb_warped_map = var.rgb_warped.view(opt.batch_size,opt.H_crop,opt.W_crop,3).permute(0,3,1,2) # [B,3,H,W] 185 | return var 186 | 187 | def compute_loss(self,opt,var,mode=None): 188 | loss = edict() 189 | if opt.loss_weight.render is not None: 190 | image_pert = var.image_pert.view(opt.batch_size,3,opt.H_crop*opt.W_crop).permute(0,2,1) 191 | loss.render = self.MSE_loss(var.rgb_warped,image_pert) 192 | return loss 193 | 194 | class NeuralImageFunction(torch.nn.Module): 195 | 196 | def __init__(self,opt): 197 | super().__init__() 198 | self.define_network(opt) 199 | self.progress = torch.nn.Parameter(torch.tensor(0.)) # use Parameter so it could be checkpointed 200 | 201 | def define_network(self,opt): 202 | input_2D_dim = 2+4*opt.arch.posenc.L_2D if opt.arch.posenc else 2 203 | # point-wise RGB prediction 204 | self.mlp = torch.nn.ModuleList() 205 | L = util.get_layer_dims(opt.arch.layers) 206 | for li,(k_in,k_out) in enumerate(L): 207 | if li==0: k_in = input_2D_dim 208 | if li in opt.arch.skip: k_in += input_2D_dim 209 | linear = torch.nn.Linear(k_in,k_out) 210 | if opt.barf_c2f and li==0: 211 | # rescale first layer init (distribution was for pos.enc. but only xy is first used) 212 | scale = np.sqrt(input_2D_dim/2.) 
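            # i.e. with the coarse-to-fine schedule active, the encoded features are masked to zero at the start
            # of training, so only the 2 raw (x,y) channels initially feed this layer; rescaling the default init
            # by sqrt(input_2D_dim/2) keeps the layer's initial output magnitude roughly comparable to the
            # fully-encoded case the init distribution assumes.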
213 | linear.weight.data *= scale 214 | linear.bias.data *= scale 215 | self.mlp.append(linear) 216 | 217 | def forward(self,opt,coord_2D): # [B,...,3] 218 | if opt.arch.posenc: 219 | points_enc = self.positional_encoding(opt,coord_2D,L=opt.arch.posenc.L_2D) 220 | points_enc = torch.cat([coord_2D,points_enc],dim=-1) # [B,...,6L+3] 221 | else: points_enc = coord_2D 222 | feat = points_enc 223 | # extract implicit features 224 | for li,layer in enumerate(self.mlp): 225 | if li in opt.arch.skip: feat = torch.cat([feat,points_enc],dim=-1) 226 | feat = layer(feat) 227 | if li!=len(self.mlp)-1: 228 | feat = torch_F.relu(feat) 229 | rgb = feat.sigmoid_() # [B,...,3] 230 | return rgb 231 | 232 | def positional_encoding(self,opt,input,L): # [B,...,N] 233 | shape = input.shape 234 | freq = 2**torch.arange(L,dtype=torch.float32,device=opt.device)*np.pi # [L] 235 | spectrum = input[...,None]*freq # [B,...,N,L] 236 | sin,cos = spectrum.sin(),spectrum.cos() # [B,...,N,L] 237 | input_enc = torch.stack([sin,cos],dim=-2) # [B,...,N,2,L] 238 | input_enc = input_enc.view(*shape[:-1],-1) # [B,...,2NL] 239 | # coarse-to-fine: smoothly mask positional encoding for BARF 240 | if opt.barf_c2f is not None: 241 | # set weights for different frequency bands 242 | start,end = opt.barf_c2f 243 | alpha = (self.progress.data-start)/(end-start)*L 244 | k = torch.arange(L,dtype=torch.float32,device=opt.device) 245 | weight = (1-(alpha-k).clamp_(min=0,max=1).mul_(np.pi).cos_())/2 246 | # apply weights 247 | shape = input_enc.shape 248 | input_enc = (input_enc.view(-1,L)*weight).view(*shape) 249 | return input_enc 250 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os,sys,time 3 | import torch 4 | import random 5 | import string 6 | import yaml 7 | from easydict import EasyDict as edict 8 | 9 | import util 10 | from util import log 11 | 12 | # torch.backends.cudnn.enabled = False 13 | # torch.backends.cudnn.benchmark = False 14 | # torch.backends.cudnn.deterministic = True 15 | 16 | def parse_arguments(args): 17 | """ 18 | Parse arguments from command line. 19 | Syntax: --key1.key2.key3=value --> value 20 | --key1.key2.key3= --> None 21 | --key1.key2.key3 --> True 22 | --key1.key2.key3! --> False 23 | """ 24 | opt_cmd = {} 25 | for arg in args: 26 | assert(arg.startswith("--")) 27 | if "=" not in arg[2:]: 28 | key_str,value = (arg[2:-1],"false") if arg[-1]=="!" 
else (arg[2:],"true") 29 | else: 30 | key_str,value = arg[2:].split("=") 31 | keys_sub = key_str.split(".") 32 | opt_sub = opt_cmd 33 | for k in keys_sub[:-1]: 34 | if k not in opt_sub: opt_sub[k] = {} 35 | opt_sub = opt_sub[k] 36 | assert keys_sub[-1] not in opt_sub,keys_sub[-1] 37 | opt_sub[keys_sub[-1]] = yaml.safe_load(value) 38 | opt_cmd = edict(opt_cmd) 39 | return opt_cmd 40 | 41 | def set(opt_cmd={}): 42 | log.info("setting configurations...") 43 | assert("model" in opt_cmd) 44 | # load config from yaml file 45 | assert("yaml" in opt_cmd) 46 | fname = "options/{}.yaml".format(opt_cmd.yaml) 47 | opt_base = load_options(fname) 48 | # override with command line arguments 49 | opt = override_options(opt_base,opt_cmd,key_stack=[],safe_check=True) 50 | process_options(opt) 51 | log.options(opt) 52 | return opt 53 | 54 | def load_options(fname): 55 | with open(fname) as file: 56 | opt = edict(yaml.safe_load(file)) 57 | if "_parent_" in opt: 58 | # load parent yaml file(s) as base options 59 | parent_fnames = opt.pop("_parent_") 60 | if type(parent_fnames) is str: 61 | parent_fnames = [parent_fnames] 62 | for parent_fname in parent_fnames: 63 | opt_parent = load_options(parent_fname) 64 | opt_parent = override_options(opt_parent,opt,key_stack=[]) 65 | opt = opt_parent 66 | print("loading {}...".format(fname)) 67 | return opt 68 | 69 | def override_options(opt,opt_over,key_stack=None,safe_check=False): 70 | for key,value in opt_over.items(): 71 | if isinstance(value,dict): 72 | # parse child options (until leaf nodes are reached) 73 | opt[key] = override_options(opt.get(key,dict()),value,key_stack=key_stack+[key],safe_check=safe_check) 74 | else: 75 | # ensure command line argument to override is also in yaml file 76 | if safe_check and key not in opt: 77 | add_new = None 78 | while add_new not in ["y","n"]: 79 | key_str = ".".join(key_stack+[key]) 80 | add_new = input("\"{}\" not found in original opt, add? 
(y/n) ".format(key_str)) 81 | if add_new=="n": 82 | print("safe exiting...") 83 | exit() 84 | opt[key] = value 85 | return opt 86 | 87 | def process_options(opt): 88 | # set seed 89 | if opt.seed is not None: 90 | random.seed(opt.seed) 91 | np.random.seed(opt.seed) 92 | torch.manual_seed(opt.seed) 93 | torch.cuda.manual_seed_all(opt.seed) 94 | if opt.seed!=0: 95 | opt.name = str(opt.name)+"_seed{}".format(opt.seed) 96 | else: 97 | # create random string as run ID 98 | randkey = "".join(random.choice(string.ascii_uppercase) for _ in range(4)) 99 | opt.name = str(opt.name)+"_{}".format(randkey) 100 | # other default options 101 | opt.output_path = "{0}/{1}/{2}".format(opt.output_root,opt.group,opt.name) 102 | os.makedirs(opt.output_path,exist_ok=True) 103 | assert(isinstance(opt.gpu,int)) # disable multi-GPU support for now, single is enough 104 | opt.device = "cpu" if opt.cpu or not torch.cuda.is_available() else "cuda:{}".format(opt.gpu) 105 | opt.H,opt.W = opt.data.image_size 106 | 107 | def save_options_file(opt): 108 | opt_fname = "{}/options.yaml".format(opt.output_path) 109 | if os.path.isfile(opt_fname): 110 | with open(opt_fname) as file: 111 | opt_old = yaml.safe_load(file) 112 | if opt!=opt_old: 113 | # prompt if options are not identical 114 | opt_new_fname = "{}/options_temp.yaml".format(opt.output_path) 115 | with open(opt_new_fname,"w") as file: 116 | yaml.safe_dump(util.to_dict(opt),file,default_flow_style=False,indent=4) 117 | print("existing options file found (different from current one)...") 118 | os.system("diff {} {}".format(opt_fname,opt_new_fname)) 119 | os.system("rm {}".format(opt_new_fname)) 120 | override = None 121 | while override not in ["y","n"]: 122 | override = input("override? (y/n) ") 123 | if override=="n": 124 | print("safe exiting...") 125 | exit() 126 | else: print("existing options file found (identical)") 127 | else: print("(creating new options file...)") 128 | with open(opt_fname,"w") as file: 129 | yaml.safe_dump(util.to_dict(opt),file,default_flow_style=False,indent=4) 130 | -------------------------------------------------------------------------------- /options/barf_blender.yaml: -------------------------------------------------------------------------------- 1 | _parent_: options/nerf_blender.yaml 2 | 3 | barf_c2f: # coarse-to-fine scheduling on positional encoding 4 | 5 | camera: # camera options 6 | noise: 0.15 # synthetic perturbations on the camera poses (Blender only) 7 | 8 | optim: # optimization options 9 | lr_pose: 1.e-3 # learning rate of camera poses 10 | lr_pose_end: 1.e-5 # terminal learning rate of camera poses (only used with sched_pose.type=ExponentialLR) 11 | sched_pose: # learning rate scheduling options 12 | type: ExponentialLR # scheduler (see PyTorch doc) 13 | gamma: # decay rate (can be empty if lr_pose_end were specified) 14 | warmup_pose: # linear warmup of the pose learning rate (N iterations) 15 | test_photo: true # test-time photometric optimization for evaluation 16 | test_iter: 100 # number of iterations for test-time optimization 17 | 18 | visdom: # Visdom options 19 | cam_depth: 0.5 # size of visualized cameras 20 | -------------------------------------------------------------------------------- /options/barf_iphone.yaml: -------------------------------------------------------------------------------- 1 | _parent_: options/barf_llff.yaml 2 | 3 | data: # data options 4 | dataset: iphone # dataset name 5 | scene: IMG_0239 # scene name 6 | image_size: [480,640] # input image sizes [height,width] 7 | 
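The iPhone template above is meant to be driven through the option system in options.py: `--yaml` picks a config, its `_parent_` chain (barf_iphone → barf_llff → nerf_llff → base) supplies defaults, and `--key.subkey=value` arguments override individual leaves. The sketch below is illustrative only: it builds the option tree the way train.py does but does not load data or train, the override values (scene name, the `[0.1,0.5]` coarse-to-fine schedule, disabling test-time photometric optimization) are example choices rather than prescribed settings, and it assumes the working directory is the repository root so the relative `options/*.yaml` paths resolve.

```python
# Minimal sketch (illustrative values): assemble options the same way train.py does, without starting a run.
import options

opt_cmd = options.parse_arguments([
    "--group=custom",            # experiment group (default in base.yaml: 0_test)
    "--name=my_sequence",        # experiment run name
    "--model=barf",              # selects model/barf.py
    "--yaml=barf_iphone",        # selects options/barf_iphone.yaml (+ its _parent_ chain)
    "--data.scene=IMG_0239",     # illustrative: name of the captured sequence
    "--barf_c2f=[0.1,0.5]",      # illustrative coarse-to-fine schedule (start/end training progress)
    "--optim.test_photo!",       # trailing "!" parses to False (skip test-time pose refinement)
])
opt = options.set(opt_cmd=opt_cmd)   # loads and merges the YAML files, applies overrides, creates opt.output_path
print(opt.data.dataset, opt.data.scene, opt.barf_c2f, opt.H, opt.W)
```

train.py performs the same `parse_arguments`/`set` calls at startup, so any option visible in these YAML files can be overridden per run without editing them.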
-------------------------------------------------------------------------------- /options/barf_llff.yaml: -------------------------------------------------------------------------------- 1 | _parent_: options/nerf_llff.yaml 2 | 3 | barf_c2f: # coarse-to-fine scheduling on positional encoding 4 | 5 | camera: # camera options 6 | noise: # synthetic perturbations on the camera poses (Blender only) 7 | 8 | optim: # optimization options 9 | lr_pose: 3.e-3 # learning rate of camera poses 10 | lr_pose_end: 1.e-5 # terminal learning rate of camera poses (only used with sched_pose.type=ExponentialLR) 11 | sched_pose: # learning rate scheduling options 12 | type: ExponentialLR # scheduler (see PyTorch doc) 13 | gamma: # decay rate (can be empty if lr_pose_end were specified) 14 | warmup_pose: # linear warmup of the pose learning rate (N iterations) 15 | test_photo: true # test-time photometric optimization for evaluation 16 | test_iter: 100 # number of iterations for test-time optimization 17 | 18 | visdom: # Visdom options 19 | cam_depth: 0.2 # size of visualized cameras 20 | -------------------------------------------------------------------------------- /options/base.yaml: -------------------------------------------------------------------------------- 1 | # default 2 | 3 | group: 0_test # name of experiment group 4 | name: debug # name of experiment run 5 | model: # type of model (must be specified from command line) 6 | yaml: # config file (must be specified from command line) 7 | seed: 0 # seed number (for both numpy and pytorch) 8 | gpu: 0 # GPU index number 9 | cpu: false # run only on CPU (not supported now) 10 | load: # load checkpoint from filename 11 | 12 | arch: {} # architectural options 13 | 14 | data: # data options 15 | root: # root path to dataset 16 | dataset: # dataset name 17 | image_size: [null,null] # input image sizes [height,width] 18 | num_workers: 8 # number of parallel workers for data loading 19 | preload: false # preload the entire dataset into the memory 20 | augment: {} # data augmentation (training only) 21 | # rotate: # random rotation 22 | # brightness: # 0.2 # random brightness jitter 23 | # contrast: # 0.2 # random contrast jitter 24 | # saturation: # 0.2 # random saturation jitter 25 | # hue: # 0.1 # random hue jitter 26 | # hflip: # True # random horizontal flip 27 | center_crop: # center crop the image by ratio 28 | val_on_test: false # validate on test set during training 29 | train_sub: # consider a subset of N training samples 30 | val_sub: # consider a subset of N validation samples 31 | 32 | loss_weight: {} # loss weights (in log scale) 33 | 34 | optim: # optimization options 35 | lr: 1.e-3 # learning rate (main) 36 | lr_end: # terminal learning rate (only used with sched.type=ExponentialLR) 37 | algo: Adam # optimizer (see PyTorch doc) 38 | sched: {} # learning rate scheduling options 39 | # type: StepLR # scheduler (see PyTorch doc) 40 | # steps: # decay every N epochs 41 | # gamma: 0.1 # decay rate (can be empty if lr_end were specified) 42 | 43 | batch_size: 16 # batch size 44 | max_epoch: 1000 # train to maximum number of epochs 45 | resume: false # resume training (true for latest checkpoint, or number for specific epoch number) 46 | 47 | output_root: output # root path for output files (checkpoints and results) 48 | tb: # TensorBoard options 49 | num_images: [4,8] # number of (tiled) images to visualize in TensorBoard 50 | visdom: # Visdom options 51 | server: localhost # server to host Visdom 52 | port: 9000 # port number for Visdom 53 | 54 | 
freq: # periodic actions during training 55 | scalar: 200 # log losses and scalar states (every N iterations) 56 | vis: 1000 # visualize results (every N iterations) 57 | val: 20 # validate on val set (every N epochs) 58 | ckpt: 50 # save checkpoint (every N epochs) 59 | -------------------------------------------------------------------------------- /options/nerf_blender.yaml: -------------------------------------------------------------------------------- 1 | _parent_: options/base.yaml 2 | 3 | arch: # architectural options 4 | layers_feat: [null,256,256,256,256,256,256,256,256] # hidden layers for feature/density MLP 5 | layers_rgb: [null,128,3] # hidden layers for color MLP 6 | skip: [4] # skip connections 7 | posenc: # positional encoding 8 | L_3D: 10 # number of bases (3D point) 9 | L_view: 4 # number of bases (viewpoint) 10 | density_activ: softplus # activation function for output volume density 11 | tf_init: true # initialize network weights in TensorFlow style 12 | 13 | nerf: # NeRF-specific options 14 | view_dep: true # condition MLP on viewpoint 15 | depth: # depth-related options 16 | param: metric # depth parametrization (for sampling along the ray) 17 | range: [2,6] # near/far bounds for depth sampling 18 | sample_intvs: 128 # number of samples 19 | sample_stratified: true # stratified sampling 20 | fine_sampling: false # hierarchical sampling with another NeRF 21 | sample_intvs_fine: # number of samples for the fine NeRF 22 | rand_rays: 1024 # number of random rays for each step 23 | density_noise_reg: # Gaussian noise on density output as regularization 24 | setbg_opaque: false # fill transparent rendering with known background color (Blender only) 25 | 26 | data: # data options 27 | dataset: blender # dataset name 28 | scene: lego # scene name 29 | image_size: [400,400] # input image sizes [height,width] 30 | num_workers: 4 # number of parallel workers for data loading 31 | preload: true # preload the entire dataset into the memory 32 | bgcolor: 1 # background color (Blender only) 33 | val_sub: 4 # consider a subset of N validation samples 34 | 35 | camera: # camera options 36 | model: perspective # type of camera model 37 | ndc: false # reparametrize as normalized device coordinates (NDC) 38 | 39 | loss_weight: # loss weights (in log scale) 40 | render: 0 # RGB rendering loss 41 | render_fine: # RGB rendering loss (for fine NeRF) 42 | 43 | optim: # optimization options 44 | lr: 5.e-4 # learning rate (main) 45 | lr_end: 1.e-4 # terminal learning rate (only used with sched.type=ExponentialLR) 46 | sched: # learning rate scheduling options 47 | type: ExponentialLR # scheduler (see PyTorch doc) 48 | gamma: # decay rate (can be empty if lr_end were specified) 49 | 50 | batch_size: # batch size (not used for NeRF/BARF) 51 | max_epoch: # train to maximum number of epochs (not used for NeRF/BARF) 52 | max_iter: 200000 # train to maximum number of iterations 53 | 54 | trimesh: # options for marching cubes to extract 3D mesh 55 | res: 128 # 3D sampling resolution 56 | range: [-1.2,1.2] # 3D range of interest (assuming same for x,y,z) 57 | thres: 25. 
# volume density threshold for marching cubes 58 | chunk_size: 16384 # chunk size of dense samples to be evaluated at a time 59 | 60 | freq: # periodic actions during training 61 | scalar: 200 # log losses and scalar states (every N iterations) 62 | vis: 1000 # visualize results (every N iterations) 63 | val: 2000 # validate on val set (every N iterations) 64 | ckpt: 5000 # save checkpoint (every N iterations) 65 | -------------------------------------------------------------------------------- /options/nerf_blender_repr.yaml: -------------------------------------------------------------------------------- 1 | _parent_: options/base.yaml 2 | 3 | arch: # architectural options 4 | layers_feat: [null,256,256,256,256,256,256,256,256] # hidden layers for feature/density MLP 5 | layers_rgb: [null,128,3] # hidden layers for color MLP 6 | skip: [4] # skip connections 7 | posenc: # positional encoding 8 | L_3D: 10 # number of bases (3D point) 9 | L_view: 4 # number of bases (viewpoint) 10 | density_activ: relu # activation function for output volume density 11 | tf_init: true # initialize network weights in TensorFlow style 12 | 13 | nerf: # NeRF-specific options 14 | view_dep: true # condition MLP on viewpoint 15 | depth: # depth-related options 16 | param: metric # depth parametrization (for sampling along the ray) 17 | range: [2,6] # near/far bounds for depth sampling 18 | sample_intvs: 64 # number of samples 19 | sample_stratified: true # stratified sampling 20 | fine_sampling: true # hierarchical sampling with another NeRF 21 | sample_intvs_fine: 128 # number of samples for the fine NeRF 22 | rand_rays: 1024 # number of random rays for each step 23 | density_noise_reg: 0 # Gaussian noise on density output as regularization 24 | setbg_opaque: true # fill transparent rendering with known background color (Blender only) 25 | 26 | data: # data options 27 | dataset: blender # dataset name 28 | scene: lego # scene name 29 | image_size: [400,400] # input image sizes [height,width] 30 | num_workers: 4 # number of parallel workers for data loading 31 | preload: true # preload the entire dataset into the memory 32 | bgcolor: 1 # background color (Blender only) 33 | val_sub: 4 # consider a subset of N validation samples 34 | 35 | camera: # camera options 36 | model: perspective # type of camera model 37 | ndc: false # reparametrize as normalized device coordinates (NDC) 38 | 39 | loss_weight: # loss weights (in log scale) 40 | render: 0 # RGB rendering loss 41 | render_fine: 0 # RGB rendering loss (for fine NeRF) 42 | 43 | optim: # optimization options 44 | lr: 5.e-4 # learning rate (main) 45 | lr_end: 5.e-5 # terminal learning rate (only used with sched.type=ExponentialLR) 46 | sched: # learning rate scheduling options 47 | type: ExponentialLR # scheduler (see PyTorch doc) 48 | gamma: # decay rate (can be empty if lr_end were specified) 49 | 50 | batch_size: # batch size (not used for NeRF/BARF) 51 | max_epoch: # train to maximum number of epochs (not used for NeRF/BARF) 52 | max_iter: 500000 # train to maximum number of iterations 53 | 54 | trimesh: # options for marching cubes to extract 3D mesh 55 | res: 128 # 3D sampling resolution 56 | range: [-1.2,1.2] # 3D range of interest (assuming same for x,y,z) 57 | thres: 25. 
# volume density threshold for marching cubes 58 | chunk_size: 16384 # chunk size of dense samples to be evaluated at a time 59 | 60 | freq: # periodic actions during training 61 | scalar: 200 # log losses and scalar states (every N iterations) 62 | vis: 1000 # visualize results (every N iterations) 63 | val: 2000 # validate on val set (every N iterations) 64 | ckpt: 5000 # save checkpoint (every N iterations) 65 | -------------------------------------------------------------------------------- /options/nerf_llff.yaml: -------------------------------------------------------------------------------- 1 | _parent_: options/base.yaml 2 | 3 | arch: # architectural optionss 4 | layers_feat: [null,256,256,256,256,256,256,256,256] # hidden layers for feature/density MLP] 5 | layers_rgb: [null,128,3] # hidden layers for color MLP] 6 | skip: [4] # skip connections 7 | posenc: # positional encoding: 8 | L_3D: 10 # number of bases (3D point) 9 | L_view: 4 # number of bases (viewpoint) 10 | density_activ: softplus # activation function for output volume density 11 | tf_init: true # initialize network weights in TensorFlow style 12 | 13 | nerf: # NeRF-specific options 14 | view_dep: true # condition MLP on viewpoint 15 | depth: # depth-related options 16 | param: inverse # depth parametrization (for sampling along the ray) 17 | range: [1,0] # near/far bounds for depth sampling 18 | sample_intvs: 128 # number of samples 19 | sample_stratified: true # stratified sampling 20 | fine_sampling: false # hierarchical sampling with another NeRF 21 | sample_intvs_fine: # number of samples for the fine NeRF 22 | rand_rays: 2048 # number of random rays for each step 23 | density_noise_reg: # Gaussian noise on density output as regularization 24 | setbg_opaque: # fill transparent rendering with known background color (Blender only) 25 | 26 | data: # data options 27 | dataset: llff # dataset name 28 | scene: fern # scene name 29 | image_size: [480,640] # input image sizes [height,width] 30 | num_workers: 4 # number of parallel workers for data loading 31 | preload: true # preload the entire dataset into the memory 32 | val_ratio: 0.1 # ratio of sequence split for validation 33 | 34 | camera: # camera options 35 | model: perspective # type of camera model 36 | ndc: false # reparametrize as normalized device coordinates (NDC) 37 | 38 | loss_weight: # loss weights (in log scale) 39 | render: 0 # RGB rendering loss 40 | render_fine: # RGB rendering loss (for fine NeRF) 41 | 42 | optim: # optimization options 43 | lr: 1.e-3 # learning rate (main) 44 | lr_end: 1.e-4 # terminal learning rate (only used with sched.type=ExponentialLR) 45 | sched: # learning rate scheduling options 46 | type: ExponentialLR # scheduler (see PyTorch doc) 47 | gamma: # decay rate (can be empty if lr_end were specified) 48 | 49 | batch_size: # batch size (not used for NeRF/BARF) 50 | max_epoch: # train to maximum number of epochs (not used for NeRF/BARF) 51 | max_iter: 200000 # train to maximum number of iterations 52 | 53 | freq: # periodic actions during training 54 | scalar: 200 # log losses and scalar states (every N iterations) 55 | vis: 1000 # visualize results (every N iterations) 56 | val: 2000 # validate on val set (every N iterations) 57 | ckpt: 5000 # save checkpoint (every N iterations) 58 | -------------------------------------------------------------------------------- /options/nerf_llff_repr.yaml: -------------------------------------------------------------------------------- 1 | _parent_: options/base.yaml 2 | 3 | arch: # 
architectural options 4 | layers_feat: [null,256,256,256,256,256,256,256,256] # hidden layers for feature/density MLP 5 | layers_rgb: [null,128,3] # hidden layers for color MLP 6 | skip: [4] # skip connections 7 | posenc: # positional encoding 8 | L_3D: 10 # number of bases (3D point) 9 | L_view: 4 # number of bases (viewpoint) 10 | density_activ: relu # activation function for output volume density 11 | tf_init: true # initialize network weights in TensorFlow style 12 | 13 | nerf: # NeRF-specific options 14 | view_dep: true # condition MLP on viewpoint 15 | depth: # depth-related options 16 | param: metric # depth parametrization (for sampling along the ray) 17 | range: [0,1] # near/far bounds for depth sampling 18 | sample_intvs: 64 # number of samples 19 | sample_stratified: true # stratified sampling 20 | fine_sampling: true # hierarchical sampling with another NeRF 21 | sample_intvs_fine: 128 # number of samples for the fine NeRF 22 | rand_rays: 1024 # number of random rays for each step 23 | density_noise_reg: 1 # Gaussian noise on density output as regularization 24 | setbg_opaque: # fill transparent rendering with known background color (Blender only) 25 | 26 | data: # data options 27 | dataset: llff # dataset name 28 | scene: fern # scene name 29 | image_size: [480,640] # input image sizes [height,width] 30 | num_workers: 4 # number of parallel workers for data loading 31 | preload: true # preload the entire dataset into the memory 32 | val_ratio: 0.1 # ratio of sequence split for validation 33 | 34 | camera: # camera options 35 | model: perspective # type of camera model 36 | ndc: false # reparametrize as normalized device coordinates (NDC) 37 | 38 | loss_weight: # loss weights (in log scale) 39 | render: 0 # RGB rendering loss 40 | render_fine: 0 # RGB rendering loss (for fine NeRF) 41 | 42 | optim: # optimization options 43 | lr: 5.e-4 # learning rate (main) 44 | lr_end: 5.e-5 # terminal learning rate (only used with sched.type=ExponentialLR) 45 | sched: # learning rate scheduling options 46 | type: ExponentialLR # scheduler (see PyTorch doc) 47 | gamma: # decay rate (can be empty if lr_end were specified) 48 | 49 | batch_size: # batch size (not used for NeRF/BARF) 50 | max_epoch: # train to maximum number of epochs (not used for NeRF/BARF) 51 | max_iter: 500000 # train to maximum number of iterations 52 | 53 | freq: # periodic actions during training 54 | scalar: 200 # log losses and scalar states (every N iterations) 55 | vis: 1000 # visualize results (every N iterations) 56 | val: 2000 # validate on val set (every N iterations) 57 | ckpt: 5000 # save checkpoint (every N iterations) 58 | -------------------------------------------------------------------------------- /options/planar.yaml: -------------------------------------------------------------------------------- 1 | _parent_: options/base.yaml 2 | 3 | arch: # architectural options 4 | layers: [null,256,256,256,256,3] # hidden layers for MLP 5 | skip: [] # skip connections 6 | posenc: # positional encoding 7 | L_2D: 8 # number of bases (3D point) 8 | 9 | barf_c2f: # coarse-to-fine scheduling on positional encoding 10 | 11 | data: # data options 12 | image_fname: data/cat.jpg # path to image file 13 | image_size: [360,480] # original image size 14 | patch_crop: [180,180] # crop size of image patches to align 15 | 16 | warp: # image warping options 17 | type: homography # type of warp function 18 | dof: 8 # degrees of freedom of the warp function 19 | noise_h: 0.1 # scale of pre-generated warp perturbation (homography) 20 
| noise_t: 0.2 # scale of pre-generated warp perturbation (translation) 21 | fix_first: true # fix the first patch for uniqueness of solution 22 | 23 | loss_weight: # loss weights (in log scale) 24 | render: 0 # RGB rendering loss 25 | 26 | optim: # optimization options 27 | lr: 1.e-3 # learning rate (main) 28 | lr_warp: 1.e-3 # learning rate of warp parameters 29 | 30 | batch_size: 5 # batch size (set to number of patches to consider) 31 | max_iter: 5000 # train to maximum number of iterations 32 | 33 | visdom: # Visdom options (turned off) 34 | 35 | freq: # periodic actions during training 36 | scalar: 20 # log losses and scalar states (every N iterations) 37 | vis: 100 # visualize results (every N iterations) 38 | -------------------------------------------------------------------------------- /requirements.yaml: -------------------------------------------------------------------------------- 1 | name: barf-env 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | dependencies: 6 | - numpy 7 | - scipy 8 | - tqdm 9 | - termcolor 10 | - easydict 11 | - imageio 12 | - ipdb 13 | - pytorch>=1.9.0 14 | - torchvision 15 | - tensorboard 16 | - visdom 17 | - matplotlib 18 | - scikit-video 19 | - trimesh 20 | - pyyaml 21 | - pip 22 | - gdown 23 | - pip: 24 | - lpips 25 | - pymcubes 26 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os,sys,time 3 | import torch 4 | import importlib 5 | 6 | import options 7 | from util import log 8 | 9 | def main(): 10 | 11 | log.process(os.getpid()) 12 | log.title("[{}] (PyTorch code for training NeRF/BARF)".format(sys.argv[0])) 13 | 14 | opt_cmd = options.parse_arguments(sys.argv[1:]) 15 | opt = options.set(opt_cmd=opt_cmd) 16 | options.save_options_file(opt) 17 | 18 | with torch.cuda.device(opt.device): 19 | 20 | model = importlib.import_module("model.{}".format(opt.model)) 21 | m = model.Model(opt) 22 | 23 | m.load_dataset(opt) 24 | m.build_networks(opt) 25 | m.setup_optimizer(opt) 26 | m.restore_checkpoint(opt) 27 | m.setup_visualizer(opt) 28 | 29 | m.train(opt) 30 | 31 | if __name__=="__main__": 32 | main() 33 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os,sys,time 3 | import shutil 4 | import datetime 5 | import torch 6 | import torch.nn.functional as torch_F 7 | import ipdb 8 | import types 9 | import termcolor 10 | import socket 11 | import contextlib 12 | from easydict import EasyDict as edict 13 | 14 | # convert to colored strings 15 | def red(message,**kwargs): return termcolor.colored(str(message),color="red",attrs=[k for k,v in kwargs.items() if v is True]) 16 | def green(message,**kwargs): return termcolor.colored(str(message),color="green",attrs=[k for k,v in kwargs.items() if v is True]) 17 | def blue(message,**kwargs): return termcolor.colored(str(message),color="blue",attrs=[k for k,v in kwargs.items() if v is True]) 18 | def cyan(message,**kwargs): return termcolor.colored(str(message),color="cyan",attrs=[k for k,v in kwargs.items() if v is True]) 19 | def yellow(message,**kwargs): return termcolor.colored(str(message),color="yellow",attrs=[k for k,v in kwargs.items() if v is True]) 20 | def magenta(message,**kwargs): return termcolor.colored(str(message),color="magenta",attrs=[k for k,v in kwargs.items() if v is True]) 21 | def 
grey(message,**kwargs): return termcolor.colored(str(message),color="grey",attrs=[k for k,v in kwargs.items() if v is True]) 22 | 23 | def get_time(sec): 24 | d = int(sec//(24*60*60)) 25 | h = int(sec//(60*60)%24) 26 | m = int((sec//60)%60) 27 | s = int(sec%60) 28 | return d,h,m,s 29 | 30 | def add_datetime(func): 31 | def wrapper(*args,**kwargs): 32 | datetime_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") 33 | print(grey("[{}] ".format(datetime_str),bold=True),end="") 34 | return func(*args,**kwargs) 35 | return wrapper 36 | 37 | def add_functionname(func): 38 | def wrapper(*args,**kwargs): 39 | print(grey("[{}] ".format(func.__name__),bold=True)) 40 | return func(*args,**kwargs) 41 | return wrapper 42 | 43 | def pre_post_actions(pre=None,post=None): 44 | def func_decorator(func): 45 | def wrapper(*args,**kwargs): 46 | if pre: pre() 47 | retval = func(*args,**kwargs) 48 | if post: post() 49 | return retval 50 | return wrapper 51 | return func_decorator 52 | 53 | debug = ipdb.set_trace 54 | 55 | class Log: 56 | def __init__(self): pass 57 | def process(self,pid): 58 | print(grey("Process ID: {}".format(pid),bold=True)) 59 | def title(self,message): 60 | print(yellow(message,bold=True,underline=True)) 61 | def info(self,message): 62 | print(magenta(message,bold=True)) 63 | def options(self,opt,level=0): 64 | for key,value in sorted(opt.items()): 65 | if isinstance(value,(dict,edict)): 66 | print(" "*level+cyan("* ")+green(key)+":") 67 | self.options(value,level+1) 68 | else: 69 | print(" "*level+cyan("* ")+green(key)+":",yellow(value)) 70 | def loss_train(self,opt,ep,lr,loss,timer): 71 | if not opt.max_epoch: return 72 | message = grey("[train] ",bold=True) 73 | message += "epoch {}/{}".format(cyan(ep,bold=True),opt.max_epoch) 74 | message += ", lr:{}".format(yellow("{:.2e}".format(lr),bold=True)) 75 | message += ", loss:{}".format(red("{:.3e}".format(loss),bold=True)) 76 | message += ", time:{}".format(blue("{0}-{1:02d}:{2:02d}:{3:02d}".format(*get_time(timer.elapsed)),bold=True)) 77 | message += " (ETA:{})".format(blue("{0}-{1:02d}:{2:02d}:{3:02d}".format(*get_time(timer.arrival)))) 78 | print(message) 79 | def loss_val(self,opt,loss): 80 | message = grey("[val] ",bold=True) 81 | message += "loss:{}".format(red("{:.3e}".format(loss),bold=True)) 82 | print(message) 83 | log = Log() 84 | 85 | def update_timer(opt,timer,ep,it_per_ep): 86 | if not opt.max_epoch: return 87 | momentum = 0.99 88 | timer.elapsed = time.time()-timer.start 89 | timer.it = timer.it_end-timer.it_start 90 | # compute speed with moving average 91 | timer.it_mean = timer.it_mean*momentum+timer.it*(1-momentum) if timer.it_mean is not None else timer.it 92 | timer.arrival = timer.it_mean*it_per_ep*(opt.max_epoch-ep) 93 | 94 | # move tensors to device in-place 95 | def move_to_device(X,device): 96 | if isinstance(X,dict): 97 | for k,v in X.items(): 98 | X[k] = move_to_device(v,device) 99 | elif isinstance(X,list): 100 | for i,e in enumerate(X): 101 | X[i] = move_to_device(e,device) 102 | elif isinstance(X,tuple) and hasattr(X,"_fields"): # collections.namedtuple 103 | dd = X._asdict() 104 | dd = move_to_device(dd,device) 105 | return type(X)(**dd) 106 | elif isinstance(X,torch.Tensor): 107 | return X.to(device=device) 108 | return X 109 | 110 | def to_dict(D,dict_type=dict): 111 | D = dict_type(D) 112 | for k,v in D.items(): 113 | if isinstance(v,dict): 114 | D[k] = to_dict(v,dict_type) 115 | return D 116 | 117 | def get_child_state_dict(state_dict,key): 118 | return { ".".join(k.split(".")[1:]): v for k,v in 
state_dict.items() if k.startswith("{}.".format(key)) } 119 | 120 | def restore_checkpoint(opt,model,load_name=None,resume=False): 121 | assert((load_name is None)==(resume is not False)) # resume can be True/False or epoch numbers 122 | if resume: 123 | load_name = "{0}/model.ckpt".format(opt.output_path) if resume is True else \ 124 | "{0}/model/{1}.ckpt".format(opt.output_path,resume) 125 | checkpoint = torch.load(load_name,map_location=opt.device) 126 | # load individual (possibly partial) children modules 127 | for name,child in model.graph.named_children(): 128 | child_state_dict = get_child_state_dict(checkpoint["graph"],name) 129 | if child_state_dict: 130 | print("restoring {}...".format(name)) 131 | child.load_state_dict(child_state_dict) 132 | for key in model.__dict__: 133 | if key.split("_")[0] in ["optim","sched"] and key in checkpoint and resume: 134 | print("restoring {}...".format(key)) 135 | getattr(model,key).load_state_dict(checkpoint[key]) 136 | if resume: 137 | ep,it = checkpoint["epoch"],checkpoint["iter"] 138 | if resume is not True: assert(resume==(ep or it)) 139 | print("resuming from epoch {0} (iteration {1})".format(ep,it)) 140 | else: ep,it = None,None 141 | return ep,it 142 | 143 | def save_checkpoint(opt,model,ep,it,latest=False,children=None): 144 | os.makedirs("{0}/model".format(opt.output_path),exist_ok=True) 145 | if children is not None: 146 | graph_state_dict = { k: v for k,v in model.graph.state_dict().items() if k.startswith(children) } 147 | else: graph_state_dict = model.graph.state_dict() 148 | checkpoint = dict( 149 | epoch=ep, 150 | iter=it, 151 | graph=graph_state_dict, 152 | ) 153 | for key in model.__dict__: 154 | if key.split("_")[0] in ["optim","sched"]: 155 | checkpoint.update({ key: getattr(model,key).state_dict() }) 156 | torch.save(checkpoint,"{0}/model.ckpt".format(opt.output_path)) 157 | if not latest: 158 | shutil.copy("{0}/model.ckpt".format(opt.output_path), 159 | "{0}/model/{1}.ckpt".format(opt.output_path,ep or it)) # if ep is None, track it instead 160 | 161 | def check_socket_open(hostname,port): 162 | s = socket.socket(socket.AF_INET,socket.SOCK_STREAM) 163 | is_open = False 164 | try: 165 | s.bind((hostname,port)) 166 | except socket.error: 167 | is_open = True 168 | finally: 169 | s.close() 170 | return is_open 171 | 172 | def get_layer_dims(layers): 173 | # return a list of tuples (k_in,k_out) 174 | return list(zip(layers[:-1],layers[1:])) 175 | 176 | @contextlib.contextmanager 177 | def suppress(stdout=False,stderr=False): 178 | with open(os.devnull,"w") as devnull: 179 | if stdout: old_stdout,sys.stdout = sys.stdout,devnull 180 | if stderr: old_stderr,sys.stderr = sys.stderr,devnull 181 | try: yield 182 | finally: 183 | if stdout: sys.stdout = old_stdout 184 | if stderr: sys.stderr = old_stderr 185 | 186 | def colorcode_to_number(code): 187 | ords = [ord(c) for c in code[1:]] 188 | ords = [n-48 if n<58 else n-87 for n in ords] 189 | rgb = (ords[0]*16+ords[1],ords[2]*16+ords[3],ords[4]*16+ords[5]) 190 | return rgb 191 | -------------------------------------------------------------------------------- /util_vis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os,sys,time 3 | import torch 4 | import torch.nn.functional as torch_F 5 | import torchvision 6 | import torchvision.transforms.functional as torchvision_F 7 | import matplotlib.pyplot as plt 8 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection 9 | import PIL 10 | import imageio 11 | from 
easydict import EasyDict as edict 12 | 13 | import camera 14 | 15 | @torch.no_grad() 16 | def tb_image(opt,tb,step,group,name,images,num_vis=None,from_range=(0,1),cmap="gray"): 17 | images = preprocess_vis_image(opt,images,from_range=from_range,cmap=cmap) 18 | num_H,num_W = num_vis or opt.tb.num_images 19 | images = images[:num_H*num_W] 20 | image_grid = torchvision.utils.make_grid(images[:,:3],nrow=num_W,pad_value=1.) 21 | if images.shape[1]==4: 22 | mask_grid = torchvision.utils.make_grid(images[:,3:],nrow=num_W,pad_value=1.)[:1] 23 | image_grid = torch.cat([image_grid,mask_grid],dim=0) 24 | tag = "{0}/{1}".format(group,name) 25 | tb.add_image(tag,image_grid,step) 26 | 27 | def preprocess_vis_image(opt,images,from_range=(0,1),cmap="gray"): 28 | min,max = from_range 29 | images = (images-min)/(max-min) 30 | images = images.clamp(min=0,max=1).cpu() 31 | if images.shape[1]==1: 32 | images = get_heatmap(opt,images[:,0].cpu(),cmap=cmap) 33 | return images 34 | 35 | def dump_images(opt,idx,name,images,masks=None,from_range=(0,1),cmap="gray"): 36 | images = preprocess_vis_image(opt,images,masks=masks,from_range=from_range,cmap=cmap) # [B,3,H,W] 37 | images = images.cpu().permute(0,2,3,1).numpy() # [B,H,W,3] 38 | for i,img in zip(idx,images): 39 | fname = "{}/dump/{}_{}.png".format(opt.output_path,i,name) 40 | img_uint8 = (img*255).astype(np.uint8) 41 | imageio.imsave(fname,img_uint8) 42 | 43 | def get_heatmap(opt,gray,cmap): # [N,H,W] 44 | color = plt.get_cmap(cmap)(gray.numpy()) 45 | color = torch.from_numpy(color[...,:3]).permute(0,3,1,2).float() # [N,3,H,W] 46 | return color 47 | 48 | def color_border(images,colors,width=3): 49 | images_pad = [] 50 | for i,image in enumerate(images): 51 | image_pad = torch.ones(3,image.shape[1]+width*2,image.shape[2]+width*2)*(colors[i,:,None,None]/255.0) 52 | image_pad[:,width:-width,width:-width] = image 53 | images_pad.append(image_pad) 54 | images_pad = torch.stack(images_pad,dim=0) 55 | return images_pad 56 | 57 | @torch.no_grad() 58 | def vis_cameras(opt,vis,step,poses=[],colors=["blue","magenta"],plot_dist=True): 59 | win_name = "{}/{}".format(opt.group,opt.name) 60 | data = [] 61 | # set up plots 62 | centers = [] 63 | for pose,color in zip(poses,colors): 64 | pose = pose.detach().cpu() 65 | vertices,faces,wireframe = get_camera_mesh(pose,depth=opt.visdom.cam_depth) 66 | center = vertices[:,-1] 67 | centers.append(center) 68 | # camera centers 69 | data.append(dict( 70 | type="scatter3d", 71 | x=[float(n) for n in center[:,0]], 72 | y=[float(n) for n in center[:,1]], 73 | z=[float(n) for n in center[:,2]], 74 | mode="markers", 75 | marker=dict(color=color,size=3), 76 | )) 77 | # colored camera mesh 78 | vertices_merged,faces_merged = merge_meshes(vertices,faces) 79 | data.append(dict( 80 | type="mesh3d", 81 | x=[float(n) for n in vertices_merged[:,0]], 82 | y=[float(n) for n in vertices_merged[:,1]], 83 | z=[float(n) for n in vertices_merged[:,2]], 84 | i=[int(n) for n in faces_merged[:,0]], 85 | j=[int(n) for n in faces_merged[:,1]], 86 | k=[int(n) for n in faces_merged[:,2]], 87 | flatshading=True, 88 | color=color, 89 | opacity=0.05, 90 | )) 91 | # camera wireframe 92 | wireframe_merged = merge_wireframes(wireframe) 93 | data.append(dict( 94 | type="scatter3d", 95 | x=wireframe_merged[0], 96 | y=wireframe_merged[1], 97 | z=wireframe_merged[2], 98 | mode="lines", 99 | line=dict(color=color,), 100 | opacity=0.3, 101 | )) 102 | if plot_dist: 103 | # distance between two poses (camera centers) 104 | center_merged = merge_centers(centers[:2]) 105 | 
data.append(dict( 106 | type="scatter3d", 107 | x=center_merged[0], 108 | y=center_merged[1], 109 | z=center_merged[2], 110 | mode="lines", 111 | line=dict(color="red",width=4,), 112 | )) 113 | if len(centers)==4: 114 | center_merged = merge_centers(centers[2:4]) 115 | data.append(dict( 116 | type="scatter3d", 117 | x=center_merged[0], 118 | y=center_merged[1], 119 | z=center_merged[2], 120 | mode="lines", 121 | line=dict(color="red",width=4,), 122 | )) 123 | # send data to visdom 124 | vis._send(dict( 125 | data=data, 126 | win="poses", 127 | eid=win_name, 128 | layout=dict( 129 | title="({})".format(step), 130 | autosize=True, 131 | margin=dict(l=30,r=30,b=30,t=30,), 132 | showlegend=False, 133 | yaxis=dict( 134 | scaleanchor="x", 135 | scaleratio=1, 136 | ) 137 | ), 138 | opts=dict(title="{} poses ({})".format(win_name,step),), 139 | )) 140 | 141 | def get_camera_mesh(pose,depth=1): 142 | vertices = torch.tensor([[-0.5,-0.5,1], 143 | [0.5,-0.5,1], 144 | [0.5,0.5,1], 145 | [-0.5,0.5,1], 146 | [0,0,0]])*depth 147 | faces = torch.tensor([[0,1,2], 148 | [0,2,3], 149 | [0,1,4], 150 | [1,2,4], 151 | [2,3,4], 152 | [3,0,4]]) 153 | vertices = camera.cam2world(vertices[None],pose) 154 | wireframe = vertices[:,[0,1,2,3,0,4,1,2,4,3]] 155 | return vertices,faces,wireframe 156 | 157 | def merge_wireframes(wireframe): 158 | wireframe_merged = [[],[],[]] 159 | for w in wireframe: 160 | wireframe_merged[0] += [float(n) for n in w[:,0]]+[None] 161 | wireframe_merged[1] += [float(n) for n in w[:,1]]+[None] 162 | wireframe_merged[2] += [float(n) for n in w[:,2]]+[None] 163 | return wireframe_merged 164 | def merge_meshes(vertices,faces): 165 | mesh_N,vertex_N = vertices.shape[:2] 166 | faces_merged = torch.cat([faces+i*vertex_N for i in range(mesh_N)],dim=0) 167 | vertices_merged = vertices.view(-1,vertices.shape[-1]) 168 | return vertices_merged,faces_merged 169 | def merge_centers(centers): 170 | center_merged = [[],[],[]] 171 | for c1,c2 in zip(*centers): 172 | center_merged[0] += [float(c1[0]),float(c2[0]),None] 173 | center_merged[1] += [float(c1[1]),float(c2[1]),None] 174 | center_merged[2] += [float(c1[2]),float(c2[2]),None] 175 | return center_merged 176 | 177 | def plot_save_poses(opt,fig,pose,pose_ref=None,path=None,ep=None): 178 | # get the camera meshes 179 | _,_,cam = get_camera_mesh(pose,depth=opt.visdom.cam_depth) 180 | cam = cam.numpy() 181 | if pose_ref is not None: 182 | _,_,cam_ref = get_camera_mesh(pose_ref,depth=opt.visdom.cam_depth) 183 | cam_ref = cam_ref.numpy() 184 | # set up plot window(s) 185 | plt.title("epoch {}".format(ep)) 186 | ax1 = fig.add_subplot(121,projection="3d") 187 | ax2 = fig.add_subplot(122,projection="3d") 188 | setup_3D_plot(ax1,elev=-90,azim=-90,lim=edict(x=(-1,1),y=(-1,1),z=(-1,1))) 189 | setup_3D_plot(ax2,elev=0,azim=-90,lim=edict(x=(-1,1),y=(-1,1),z=(-1,1))) 190 | ax1.set_title("forward-facing view",pad=0) 191 | ax2.set_title("top-down view",pad=0) 192 | plt.subplots_adjust(left=0,right=1,bottom=0,top=0.95,wspace=0,hspace=0) 193 | plt.margins(tight=True,x=0,y=0) 194 | # plot the cameras 195 | N = len(cam) 196 | color = plt.get_cmap("gist_rainbow") 197 | for i in range(N): 198 | if pose_ref is not None: 199 | ax1.plot(cam_ref[i,:,0],cam_ref[i,:,1],cam_ref[i,:,2],color=(0.3,0.3,0.3),linewidth=1) 200 | ax2.plot(cam_ref[i,:,0],cam_ref[i,:,1],cam_ref[i,:,2],color=(0.3,0.3,0.3),linewidth=1) 201 | ax1.scatter(cam_ref[i,5,0],cam_ref[i,5,1],cam_ref[i,5,2],color=(0.3,0.3,0.3),s=40) 202 | 
ax2.scatter(cam_ref[i,5,0],cam_ref[i,5,1],cam_ref[i,5,2],color=(0.3,0.3,0.3),s=40) 203 | c = np.array(color(float(i)/N))*0.8 204 | ax1.plot(cam[i,:,0],cam[i,:,1],cam[i,:,2],color=c) 205 | ax2.plot(cam[i,:,0],cam[i,:,1],cam[i,:,2],color=c) 206 | ax1.scatter(cam[i,5,0],cam[i,5,1],cam[i,5,2],color=c,s=40) 207 | ax2.scatter(cam[i,5,0],cam[i,5,1],cam[i,5,2],color=c,s=40) 208 | png_fname = "{}/{}.png".format(path,ep) 209 | plt.savefig(png_fname,dpi=75) 210 | # clean up 211 | plt.clf() 212 | 213 | def plot_save_poses_blender(opt,fig,pose,pose_ref=None,path=None,ep=None): 214 | # get the camera meshes 215 | _,_,cam = get_camera_mesh(pose,depth=opt.visdom.cam_depth) 216 | cam = cam.numpy() 217 | if pose_ref is not None: 218 | _,_,cam_ref = get_camera_mesh(pose_ref,depth=opt.visdom.cam_depth) 219 | cam_ref = cam_ref.numpy() 220 | # set up plot window(s) 221 | ax = fig.add_subplot(111,projection="3d") 222 | ax.set_title("epoch {}".format(ep),pad=0) 223 | setup_3D_plot(ax,elev=45,azim=35,lim=edict(x=(-3,3),y=(-3,3),z=(-3,2.4))) 224 | plt.subplots_adjust(left=0,right=1,bottom=0,top=0.95,wspace=0,hspace=0) 225 | plt.margins(tight=True,x=0,y=0) 226 | # plot the cameras 227 | N = len(cam) 228 | ref_color = (0.7,0.2,0.7) 229 | pred_color = (0,0.6,0.7) 230 | ax.add_collection3d(Poly3DCollection([v[:4] for v in cam_ref],alpha=0.2,facecolor=ref_color)) 231 | for i in range(N): 232 | ax.plot(cam_ref[i,:,0],cam_ref[i,:,1],cam_ref[i,:,2],color=ref_color,linewidth=0.5) 233 | ax.scatter(cam_ref[i,5,0],cam_ref[i,5,1],cam_ref[i,5,2],color=ref_color,s=20) 234 | if ep==0: 235 | png_fname = "{}/GT.png".format(path) 236 | plt.savefig(png_fname,dpi=75) 237 | ax.add_collection3d(Poly3DCollection([v[:4] for v in cam],alpha=0.2,facecolor=pred_color)) 238 | for i in range(N): 239 | ax.plot(cam[i,:,0],cam[i,:,1],cam[i,:,2],color=pred_color,linewidth=1) 240 | ax.scatter(cam[i,5,0],cam[i,5,1],cam[i,5,2],color=pred_color,s=20) 241 | for i in range(N): 242 | ax.plot([cam[i,5,0],cam_ref[i,5,0]], 243 | [cam[i,5,1],cam_ref[i,5,1]], 244 | [cam[i,5,2],cam_ref[i,5,2]],color=(1,0,0),linewidth=3) 245 | png_fname = "{}/{}.png".format(path,ep) 246 | plt.savefig(png_fname,dpi=75) 247 | # clean up 248 | plt.clf() 249 | 250 | def setup_3D_plot(ax,elev,azim,lim=None): 251 | ax.xaxis.set_pane_color((1.0,1.0,1.0,0.0)) 252 | ax.yaxis.set_pane_color((1.0,1.0,1.0,0.0)) 253 | ax.zaxis.set_pane_color((1.0,1.0,1.0,0.0)) 254 | ax.xaxis._axinfo["grid"]["color"] = (0.9,0.9,0.9,1) 255 | ax.yaxis._axinfo["grid"]["color"] = (0.9,0.9,0.9,1) 256 | ax.zaxis._axinfo["grid"]["color"] = (0.9,0.9,0.9,1) 257 | ax.xaxis.set_tick_params(labelsize=8) 258 | ax.yaxis.set_tick_params(labelsize=8) 259 | ax.zaxis.set_tick_params(labelsize=8) 260 | ax.set_xlabel("X",fontsize=16) 261 | ax.set_ylabel("Y",fontsize=16) 262 | ax.set_zlabel("Z",fontsize=16) 263 | ax.set_xlim(lim.x[0],lim.x[1]) 264 | ax.set_ylim(lim.y[0],lim.y[1]) 265 | ax.set_zlim(lim.z[0],lim.z[1]) 266 | ax.view_init(elev=elev,azim=azim) 267 | -------------------------------------------------------------------------------- /warp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os,sys,time 3 | import torch 4 | import torch.nn.functional as torch_F 5 | 6 | import util 7 | from util import log,debug 8 | import camera 9 | 10 | def get_normalized_pixel_grid(opt): 11 | y_range = ((torch.arange(opt.H,dtype=torch.float32,device=opt.device)+0.5)/opt.H*2-1)*(opt.H/max(opt.H,opt.W)) 12 | x_range = 
((torch.arange(opt.W,dtype=torch.float32,device=opt.device)+0.5)/opt.W*2-1)*(opt.W/max(opt.H,opt.W)) 13 | Y,X = torch.meshgrid(y_range,x_range) # [H,W] 14 | xy_grid = torch.stack([X,Y],dim=-1).view(-1,2) # [HW,2] 15 | xy_grid = xy_grid.repeat(opt.batch_size,1,1) # [B,HW,2] 16 | return xy_grid 17 | 18 | def get_normalized_pixel_grid_crop(opt): 19 | y_crop = (opt.H//2-opt.H_crop//2,opt.H//2+opt.H_crop//2) 20 | x_crop = (opt.W//2-opt.W_crop//2,opt.W//2+opt.W_crop//2) 21 | y_range = ((torch.arange(*(y_crop),dtype=torch.float32,device=opt.device)+0.5)/opt.H*2-1)*(opt.H/max(opt.H,opt.W)) 22 | x_range = ((torch.arange(*(x_crop),dtype=torch.float32,device=opt.device)+0.5)/opt.W*2-1)*(opt.W/max(opt.H,opt.W)) 23 | Y,X = torch.meshgrid(y_range,x_range) # [H,W] 24 | xy_grid = torch.stack([X,Y],dim=-1).view(-1,2) # [HW,2] 25 | xy_grid = xy_grid.repeat(opt.batch_size,1,1) # [B,HW,2] 26 | return xy_grid 27 | 28 | def warp_grid(opt,xy_grid,warp): 29 | if opt.warp.type=="translation": 30 | assert(opt.warp.dof==2) 31 | warped_grid = xy_grid+warp[...,None,:] 32 | elif opt.warp.type=="rotation": 33 | assert(opt.warp.dof==1) 34 | warp_matrix = lie.so2_to_SO2(warp) 35 | warped_grid = xy_grid@warp_matrix.transpose(-2,-1) # [B,HW,2] 36 | elif opt.warp.type=="rigid": 37 | assert(opt.warp.dof==3) 38 | xy_grid_hom = camera.to_hom(xy_grid) 39 | warp_matrix = lie.se2_to_SE2(warp) 40 | warped_grid = xy_grid_hom@warp_matrix.transpose(-2,-1) # [B,HW,2] 41 | elif opt.warp.type=="homography": 42 | assert(opt.warp.dof==8) 43 | xy_grid_hom = camera.to_hom(xy_grid) 44 | warp_matrix = lie.sl3_to_SL3(warp) 45 | warped_grid_hom = xy_grid_hom@warp_matrix.transpose(-2,-1) 46 | warped_grid = warped_grid_hom[...,:2]/(warped_grid_hom[...,2:]+1e-8) # [B,HW,2] 47 | else: assert(False) 48 | return warped_grid 49 | 50 | def warp_corners(opt,warp_param): 51 | y_crop = (opt.H//2-opt.H_crop//2,opt.H//2+opt.H_crop//2) 52 | x_crop = (opt.W//2-opt.W_crop//2,opt.W//2+opt.W_crop//2) 53 | Y = [((y+0.5)/opt.H*2-1)*(opt.H/max(opt.H,opt.W)) for y in y_crop] 54 | X = [((x+0.5)/opt.W*2-1)*(opt.W/max(opt.H,opt.W)) for x in x_crop] 55 | corners = [(X[0],Y[0]),(X[0],Y[1]),(X[1],Y[1]),(X[1],Y[0])] 56 | corners = torch.tensor(corners,dtype=torch.float32,device=opt.device).repeat(opt.batch_size,1,1) 57 | corners_warped = warp_grid(opt,corners,warp_param) 58 | return corners_warped 59 | 60 | def check_corners_in_range(opt,warp_param): 61 | corners_all = warp_corners(opt,warp_param) 62 | X = (corners_all[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5 63 | Y = (corners_all[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5 64 | return (0<=X).all() and (X<opt.W).all() and (0<=Y).all() and (Y<opt.H).all() [... lines 65-133: class Lie() with the exponential-map helpers used in warp_grid (so2_to_SO2, se2_to_SE2, sl3_to_SL3) ...] 134 | def taylor_A(self,x,nth=10): 135 | # Taylor expansion of sin(x)/x 136 | ans = torch.zeros_like(x) 137 | denom = 1. 138 | for i in range(nth+1): 139 | if i>0: denom *= (2*i)*(2*i+1) 140 | ans = ans+(-1)**i*x**(2*i)/denom 141 | return ans 142 | def taylor_B(self,x,nth=10): 143 | # Taylor expansion of (1-cos(x))/x 144 | ans = torch.zeros_like(x) 145 | denom = 1. 146 | for i in range(nth+1): 147 | denom *= (2*i+1)*(2*i+2) 148 | ans = ans+(-1)**i*x**(2*i+1)/denom 149 | return ans 150 | 151 | def taylor_C(self,x,nth=10): 152 | # Taylor expansion of (x*cos(x)-sin(x))/x**2 153 | ans = torch.zeros_like(x) 154 | denom = 1. 155 | for i in range(nth+1): 156 | denom *= (2*i+2)*(2*i+3) 157 | ans = ans+(-1)**(i+1)*x**(2*i+1)*(2*i+2)/denom 158 | return ans 159 | 160 | def taylor_D(self,x,nth=10): 161 | # Taylor expansion of (x*sin(x)+cos(x)-1)/x**2 162 | ans = torch.zeros_like(x) 163 | denom = 1.
164 | for i in range(nth+1): 165 | denom *= (2*i+1)*(2*i+2) 166 | ans = ans+(-1)**i*x**(2*i)*(2*i+1)/denom 167 | return ans 168 | 169 | lie = Lie() 170 | --------------------------------------------------------------------------------
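The grid helpers in warp.py can be exercised on their own. Below is a minimal usage sketch (not part of the repository) for the translation branch of warp_grid; it assumes the repository root is on PYTHONPATH and the environment from requirements.yaml is active, and the opt object here is a hypothetical stand-in carrying only the fields these two functions read (H, W, device, batch_size, warp.type, warp.dof).

import torch
from easydict import EasyDict as edict
import warp

# hypothetical minimal options object (the real one is built by options.py)
opt = edict(H=360,W=480,device="cpu",batch_size=2,
            warp=edict(type="translation",dof=2))

xy_grid = warp.get_normalized_pixel_grid(opt)    # [B,HW,2], long side normalized to roughly [-1,1]
shift = torch.tensor([[0.1,-0.05],[0.0,0.2]])    # one 2-DoF translation per batch element
warped = warp.warp_grid(opt,xy_grid,shift)       # translation branch: adds the shift to every coordinate
print(xy_grid.shape,warped.shape)                # torch.Size([2, 172800, 2]) twice (172800 = 360*480)
assert torch.allclose(warped-xy_grid,shift[:,None,:])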
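The taylor_A-D methods approximate trigonometric ratios that are numerically unstable near zero. A quick numerical check of taylor_B against its closed form (again assuming the repository modules import cleanly from the repo root):

import torch
import warp

x = torch.linspace(1e-4,0.5,100)
approx = warp.lie.taylor_B(x)              # truncated series for (1-cos(x))/x
exact = (1-torch.cos(x))/x                 # closed form, fine away from x=0
print(float((approx-exact).abs().max()))   # should sit at float32 rounding level (~1e-7)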
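util.get_child_state_dict strips a child's prefix from a full state_dict, which is what lets restore_checkpoint restore submodules individually and partially. A small self-contained demonstration; the Graph class and its child names below are made up for illustration, not the repository's actual model graph.

import torch
from util import get_child_state_dict

class Graph(torch.nn.Module):
    # hypothetical two-child module standing in for model.graph
    def __init__(self):
        super().__init__()
        self.nerf = torch.nn.Linear(3,4)
        self.nerf_fine = torch.nn.Linear(3,4)

graph = Graph()
full_sd = graph.state_dict()                    # keys like "nerf.weight", "nerf_fine.bias"
nerf_sd = get_child_state_dict(full_sd,"nerf")  # only "nerf.*" keys, prefix removed
graph.nerf.load_state_dict(nerf_sd)             # partial restore of a single child
print(sorted(nerf_sd.keys()))                   # ['bias', 'weight'] -- "nerf_fine.*" is excluded by the "." suffix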
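The option files chain through the _parent_ key (options/planar.yaml above inherits from options/base.yaml). The repository resolves this inside options.py; the snippet below is only an illustrative sketch of the idea using PyYAML, run from the repository root, with hypothetical helper names (load_options, deep_merge) rather than the repository's actual API.

import yaml

def deep_merge(base,override):
    # recursively overlay override onto base: nested dicts merge, scalars/lists replace
    merged = dict(base)
    for k,v in override.items():
        if isinstance(v,dict) and isinstance(merged.get(k),dict):
            merged[k] = deep_merge(merged[k],v)
        else:
            merged[k] = v
    return merged

def load_options(fname):
    # follow the _parent_ chain, child values taking precedence over the parent's
    with open(fname) as f:
        opt = yaml.safe_load(f) or {}
    parent = opt.pop("_parent_",None)
    return deep_merge(load_options(parent),opt) if parent else opt

opt = load_options("options/planar.yaml")
print(opt["warp"]["type"])   # homography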