├── .gitignore ├── README.md ├── assets └── demo.gif ├── configs └── scalarflowreal.txt ├── data └── ScalarReal │ ├── info.json │ ├── train00.mp4 │ ├── train01.mp4 │ ├── train02.mp4 │ ├── train03.mp4 │ └── train04.mp4 ├── environment.yml ├── load_scalarflow.py ├── loss.py ├── radam.py ├── ray_utils.py ├── requirements.txt ├── run_nerf_density.py ├── run_nerf_helpers.py ├── run_nerf_jointly.py ├── run_nerf_vort.py ├── scripts ├── test_future_pred.sh ├── test_resim.sh ├── train.sh ├── train_j.sh └── train_vort.sh ├── taichi_encoders ├── __init__.py ├── hash4.py ├── mgpcg.py ├── taichi_utils.py └── utils.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | .idea/ 131 | .DS_Store 132 | 133 | logs/ 134 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Inferring Hybrid Neural Fluid Fields from Videos 2 | This is the official code for Inferring Hybrid Neural Fluid Fields from Videos (NeurIPS 2023). 
3 | 4 | ![teaser](assets/demo.gif) 5 | 6 | **[[Paper](https://arxiv.org/pdf/2312.06561.pdf)] [[Project Page](https://kovenyu.com/hyfluid/)]** 7 | 8 | ## Installation 9 | Install with conda: 10 | ```bash 11 | conda env create -f environment.yml 12 | conda activate hyfluid 13 | ``` 14 | or with pip: 15 | ```bash 16 | pip install -r requirements.txt 17 | ``` 18 | 19 | ## Data 20 | The demo data is available at [data/ScalarReal](data/ScalarReal). 21 | The full ScalarFlow dataset can be downloaded [here](https://ge.in.tum.de/publications/2019-scalarflow-eckert/). 22 | 23 | ## Quick Start 24 | To learn the hybrid neural fluid fields from the demo data, firstly reconstruct the density field by running (~40min): 25 | ```bash 26 | bash scripts/train.sh 27 | ``` 28 | Then, reconstruct the velocity field by jointly training with the density field (~15 hours on a single A6000 GPU.): 29 | ```bash 30 | bash scripts/train_j.sh 31 | ``` 32 | Finally, add vortex particles and optimize their physical parameters (~40min): 33 | ```bash 34 | bash scripts/train_vort.sh 35 | ``` 36 | The results will be saved in `./logs/exp_real`. With the learned hybrid neural fluid fields, you can re-simulate the fluid by using the velocity fields to advect density: 37 | ```bash 38 | bash scripts/test_resim.sh 39 | ``` 40 | Or, you can predict the future states by extrapolating the velocity fields: 41 | ```bash 42 | bash scripts/test_future_pred.sh 43 | ``` 44 | 45 | ## Citation 46 | If you find this code useful for your research, please cite our paper: 47 | ``` 48 | @article{yu2023inferring, 49 | title={Inferring Hybrid Neural Fluid Fields from Videos}, 50 | author={Yu, Hong-Xing and Zheng, Yang and Gao, Yuan and Deng, Yitong and Zhu, Bo and Wu, Jiajun}, 51 | journal={NeurIPS}, 52 | year={2023} 53 | } 54 | ``` -------------------------------------------------------------------------------- /assets/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/y-zheng18/HyFluid/70b4b962c66d0371e366ad6ebbb61af562b08846/assets/demo.gif -------------------------------------------------------------------------------- /configs/scalarflowreal.txt: -------------------------------------------------------------------------------- 1 | expname = scalarflowreal 2 | basedir = ./logs 3 | datadir = ./data/ScalarReal 4 | 5 | N_samples = 192 6 | N_rand = 1024 7 | 8 | half_res = True -------------------------------------------------------------------------------- /data/ScalarReal/info.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_videos": [ 3 | { 4 | "file_name": "train00.mp4", 5 | "frame_rate": 30, 6 | "frame_num": 120, 7 | "camera_angle_x": 0.40746459248665245, 8 | "camera_hw": [ 9 | 1920, 10 | 1080 11 | ], 12 | "transform_matrix": [ 13 | [ 14 | 0.48627835512161255, 15 | -0.24310240149497986, 16 | -0.8393059968948364, 17 | -0.7697111964225769 18 | ], 19 | [ 20 | -0.01889985240995884, 21 | 0.9573688507080078, 22 | -0.2882491946220398, 23 | 0.013170702382922173 24 | ], 25 | [ 26 | 0.8735995292663574, 27 | 0.15603208541870117, 28 | 0.4609531760215759, 29 | 0.3249526023864746 30 | ], 31 | [ 32 | 0.0, 33 | 0.0, 34 | 0.0, 35 | 1.0 36 | ] 37 | ] 38 | }, 39 | { 40 | "file_name": "train01.mp4", 41 | "frame_rate": 30, 42 | "frame_num": 120, 43 | "camera_angle_x": 0.39413608028840563, 44 | "camera_hw": [ 45 | 1920, 46 | 1080 47 | ], 48 | "transform_matrix": [ 49 | [ 50 | 0.8157652020454407, 51 | -0.1372431218624115, 52 | 
-0.5618642568588257, 53 | -0.39192497730255127 54 | ], 55 | [ 56 | -0.04113851860165596, 57 | 0.9552109837532043, 58 | -0.2930521070957184, 59 | 0.010452679358422756 60 | ], 61 | [ 62 | 0.5769183039665222, 63 | 0.262175977230072, 64 | 0.7735819220542908, 65 | 0.8086869120597839 66 | ], 67 | [ 68 | 0.0, 69 | 0.0, 70 | 0.0, 71 | 1.0 72 | ] 73 | ] 74 | }, 75 | { 76 | "file_name": "train03.mp4", 77 | "frame_rate": 30, 78 | "frame_num": 120, 79 | "camera_angle_x": 0.41320072172607875, 80 | "camera_hw": [ 81 | 1920, 82 | 1080 83 | ], 84 | "transform_matrix": [ 85 | [ 86 | 0.8836436867713928, 87 | 0.15215487778186798, 88 | 0.44274458289146423, 89 | 0.8974969983100891 90 | ], 91 | [ 92 | -0.021659603342413902, 93 | 0.9579861760139465, 94 | -0.28599533438682556, 95 | 0.02680988796055317 96 | ], 97 | [ 98 | -0.46765878796577454, 99 | 0.24312829971313477, 100 | 0.8498140573501587, 101 | 0.8316138386726379 102 | ], 103 | [ 104 | 0.0, 105 | 0.0, 106 | 0.0, 107 | 1.0 108 | ] 109 | ] 110 | }, 111 | { 112 | "file_name": "train04.mp4", 113 | "frame_rate": 30, 114 | "frame_num": 120, 115 | "camera_angle_x": 0.40746459248665245, 116 | "camera_hw": [ 117 | 1920, 118 | 1080 119 | ], 120 | "transform_matrix": [ 121 | [ 122 | 0.6336104273796082, 123 | 0.20118704438209534, 124 | 0.7470352053642273, 125 | 1.2956339120864868 126 | ], 127 | [ 128 | 0.014488859102129936, 129 | 0.9623404741287231, 130 | -0.27146074175834656, 131 | 0.02436656318604946 132 | ], 133 | [ 134 | -0.7735165357589722, 135 | 0.1828240603208542, 136 | 0.6068339943885803, 137 | 0.497546911239624 138 | ], 139 | [ 140 | 0.0, 141 | 0.0, 142 | 0.0, 143 | 1.0 144 | ] 145 | ] 146 | } 147 | ], 148 | "test_videos": [ 149 | { 150 | "file_name": "train02.mp4", 151 | "frame_rate": 30, 152 | "frame_num": 120, 153 | "camera_angle_x": 0.41505697544547304, 154 | "camera_hw": [ 155 | 1920, 156 | 1080 157 | ], 158 | "transform_matrix": [ 159 | [ 160 | 0.999511182308197, 161 | -0.0030406631994992495, 162 | -0.03111351653933525, 163 | 0.2844361364841461 164 | ], 165 | [ 166 | -0.005995774641633034, 167 | 0.9581364989280701, 168 | -0.2862490713596344, 169 | 0.011681094765663147 170 | ], 171 | [ 172 | 0.03068138100206852, 173 | 0.28629571199417114, 174 | 0.9576499462127686, 175 | 0.9857829809188843 176 | ], 177 | [ 178 | 0.0, 179 | 0.0, 180 | 0.0, 181 | 1.0 182 | ] 183 | ] 184 | } 185 | ], 186 | "frame_bkg_color": [ 187 | 0.0, 188 | 0.0, 189 | 0.0 190 | ], 191 | "voxel_scale": [ 192 | 0.4909, 193 | 0.73635, 194 | 0.4909 195 | ], 196 | "voxel_matrix": [ 197 | [ 198 | 7.549790126404332e-08, 199 | 0.0, 200 | 1.0, 201 | 0.081816665828228 202 | ], 203 | [ 204 | 0.0, 205 | 1.0, 206 | 0.0, 207 | -0.044627271592617035 208 | ], 209 | [ 210 | -1.0, 211 | 0.0, 212 | 7.549790126404332e-08, 213 | -0.004908999893814325 214 | ], 215 | [ 216 | 0.0, 217 | 0.0, 218 | 0.0, 219 | 1.0 220 | ] 221 | ], 222 | "render_center":[ 223 | 0.3382070094283088, 224 | 0.38795384153014023, 225 | -0.2609209839653898 226 | ], 227 | "near":1.1, 228 | "far":1.5, 229 | "phi":20.0, 230 | "rot":"Y" 231 | } -------------------------------------------------------------------------------- /data/ScalarReal/train00.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/y-zheng18/HyFluid/70b4b962c66d0371e366ad6ebbb61af562b08846/data/ScalarReal/train00.mp4 -------------------------------------------------------------------------------- /data/ScalarReal/train01.mp4: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/y-zheng18/HyFluid/70b4b962c66d0371e366ad6ebbb61af562b08846/data/ScalarReal/train01.mp4 -------------------------------------------------------------------------------- /data/ScalarReal/train02.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/y-zheng18/HyFluid/70b4b962c66d0371e366ad6ebbb61af562b08846/data/ScalarReal/train02.mp4 -------------------------------------------------------------------------------- /data/ScalarReal/train03.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/y-zheng18/HyFluid/70b4b962c66d0371e366ad6ebbb61af562b08846/data/ScalarReal/train03.mp4 -------------------------------------------------------------------------------- /data/ScalarReal/train04.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/y-zheng18/HyFluid/70b4b962c66d0371e366ad6ebbb61af562b08846/data/ScalarReal/train04.mp4 -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: hyfluid_ 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - bzip2=1.0.8=h7b6447c_0 9 | - ca-certificates=2023.11.17=hbcca054_0 10 | - ffmpeg=4.2.2=h20bf706_0 11 | - freetype=2.10.4=h0708190_1 12 | - gmp=6.1.2=hf484d3e_1000 13 | - gnutls=3.6.15=he1e5248_0 14 | - lame=3.100=h7f98852_1001 15 | - ld_impl_linux-64=2.38=h1181459_1 16 | - libffi=3.4.4=h6a678d5_0 17 | - libgcc-ng=11.2.0=h1234567_1 18 | - libgomp=11.2.0=h1234567_1 19 | - libidn2=2.3.4=h5eee18b_0 20 | - libopus=1.3.1=h7f98852_1 21 | - libpng=1.6.39=h5eee18b_0 22 | - libstdcxx-ng=11.2.0=h1234567_1 23 | - libtasn1=4.19.0=h5eee18b_0 24 | - libunistring=0.9.10=h7f98852_0 25 | - libuuid=1.41.5=h5eee18b_0 26 | - libvpx=1.7.0=h439df22_0 27 | - ncurses=6.4=h6a678d5_0 28 | - nettle=3.7.3=hbbd107a_1 29 | - openh264=2.1.1=h4ff587b_0 30 | - openssl=3.0.12=h7f8727e_0 31 | - pip=23.3.1=py310h06a4308_0 32 | - python=3.10.13=h955ad1f_0 33 | - readline=8.2=h5eee18b_0 34 | - setuptools=68.2.2=py310h06a4308_0 35 | - sqlite=3.41.2=h5eee18b_0 36 | - tk=8.6.12=h1ccaba5_0 37 | - tzdata=2023c=h04d1e81_0 38 | - wheel=0.41.2=py310h06a4308_0 39 | - x264=1!157.20191217=h7b6447c_0 40 | - xz=5.4.5=h5eee18b_0 41 | - zlib=1.2.13=h5eee18b_0 42 | - pip: 43 | - asttokens==2.4.1 44 | - certifi==2023.11.17 45 | - charset-normalizer==3.3.2 46 | - cmake==3.28.1 47 | - colorama==0.4.6 48 | - configargparse==1.5.3 49 | - cycler==0.12.1 50 | - decorator==5.1.1 51 | - dill==0.3.7 52 | - exceptiongroup==1.2.0 53 | - executing==2.0.1 54 | - filelock==3.13.1 55 | - fonttools==4.47.0 56 | - idna==3.6 57 | - imageio==2.27.0 58 | - imageio-ffmpeg==0.4.5 59 | - ipdb==0.13.13 60 | - ipython==8.18.1 61 | - jedi==0.19.1 62 | - jinja2==3.1.2 63 | - joblib==1.3.2 64 | - kiwisolver==1.4.5 65 | - kornia==0.6.11 66 | - lazy-loader==0.3 67 | - lit==17.0.6 68 | - lpips==0.1.4 69 | - markdown-it-py==3.0.0 70 | - markupsafe==2.1.3 71 | - matplotlib==3.5.3 72 | - matplotlib-inline==0.1.6 73 | - mdurl==0.1.2 74 | - mpmath==1.3.0 75 | - networkx==3.2.1 76 | - numpy==1.24.2 77 | - nvidia-cublas-cu11==11.10.3.66 78 | - nvidia-cuda-cupti-cu11==11.7.101 79 | - nvidia-cuda-nvrtc-cu11==11.7.99 80 | - nvidia-cuda-runtime-cu11==11.7.99 81 | - nvidia-cudnn-cu11==8.5.0.96 82 | - 
nvidia-cufft-cu11==10.9.0.58 83 | - nvidia-curand-cu11==10.2.10.91 84 | - nvidia-cusolver-cu11==11.4.0.1 85 | - nvidia-cusparse-cu11==11.7.4.91 86 | - nvidia-nccl-cu11==2.14.3 87 | - nvidia-nvtx-cu11==11.7.91 88 | - opencv-python==4.6.0.66 89 | - packaging==23.2 90 | - parso==0.8.3 91 | - pexpect==4.9.0 92 | - pillow==10.1.0 93 | - plyfile==1.0.2 94 | - prompt-toolkit==3.0.43 95 | - psutil==5.9.7 96 | - ptyprocess==0.7.0 97 | - pure-eval==0.2.2 98 | - pygments==2.17.2 99 | - pyparsing==3.1.1 100 | - python-dateutil==2.8.2 101 | - pywavelets==1.5.0 102 | - requests==2.31.0 103 | - rich==13.7.0 104 | - scikit-image==0.20.0 105 | - scikit-learn==1.3.2 106 | - scipy==1.11.4 107 | - six==1.16.0 108 | - stack-data==0.6.3 109 | - sympy==1.12 110 | - taichi==1.5.0 111 | - threadpoolctl==3.2.0 112 | - tifffile==2023.12.9 113 | - tomli==2.0.1 114 | - torch==2.0.1 115 | - torchvision==0.15.2 116 | - tqdm==4.65.0 117 | - traitlets==5.14.0 118 | - triton==2.0.0 119 | - typing-extensions==4.9.0 120 | - urllib3==2.1.0 121 | - vtk==9.2.6 122 | - wcwidth==0.2.12 123 | -------------------------------------------------------------------------------- /load_scalarflow.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import imageio 5 | import json 6 | import torch.nn.functional as F 7 | import cv2 8 | 9 | trans_t = lambda t: torch.Tensor([ 10 | [1, 0, 0, 0], 11 | [0, 1, 0, 0], 12 | [0, 0, 1, t], 13 | [0, 0, 0, 1]]).float() 14 | 15 | rot_phi = lambda phi: torch.Tensor([ 16 | [1, 0, 0, 0], 17 | [0, np.cos(phi), -np.sin(phi), 0], 18 | [0, np.sin(phi), np.cos(phi), 0], 19 | [0, 0, 0, 1]]).float() 20 | 21 | rot_theta = lambda th: torch.Tensor([ 22 | [np.cos(th), 0, -np.sin(th), 0], 23 | [0, 1, 0, 0], 24 | [np.sin(th), 0, np.cos(th), 0], 25 | [0, 0, 0, 1]]).float() 26 | 27 | 28 | def pose_spherical(theta, phi, radius, rotZ=True, wx=0.0, wy=0.0, wz=0.0): 29 | # spherical, rotZ=True: theta rotate around Z; rotZ=False: theta rotate around Y 30 | # wx,wy,wz, additional translation, normally the center coord. 31 | c2w = trans_t(radius) 32 | c2w = rot_phi(phi / 180. * np.pi) @ c2w 33 | c2w = rot_theta(theta / 180. 
* np.pi) @ c2w 34 | if rotZ: # swap yz, and keep right-hand 35 | c2w = torch.Tensor(np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])) @ c2w 36 | 37 | ct = torch.Tensor([ 38 | [1, 0, 0, wx], 39 | [0, 1, 0, wy], 40 | [0, 0, 1, wz], 41 | [0, 0, 0, 1]]).float() 42 | c2w = ct @ c2w 43 | 44 | return c2w 45 | 46 | 47 | def load_pinf_frame_data(basedir, half_res=False, split='train'): 48 | # frame data 49 | all_imgs = [] 50 | all_poses = [] 51 | 52 | with open(os.path.join(basedir, 'info.json'), 'r') as fp: 53 | # read render settings 54 | meta = json.load(fp) 55 | near = float(meta['near']) 56 | far = float(meta['far']) 57 | radius = (near + far) * 0.5 58 | phi = float(meta['phi']) 59 | rotZ = (meta['rot'] == 'Z') 60 | r_center = np.float32(meta['render_center']) 61 | 62 | # read scene data 63 | voxel_tran = np.float32(meta['voxel_matrix']) 64 | voxel_tran = np.stack([voxel_tran[:, 2], voxel_tran[:, 1], voxel_tran[:, 0], voxel_tran[:, 3]], 65 | axis=1) # swap_zx 66 | voxel_scale = np.broadcast_to(meta['voxel_scale'], [3]) 67 | 68 | # read video frames 69 | # all videos should be synchronized, having the same frame_rate and frame_num 70 | 71 | video_list = meta[split + '_videos'] if (split + '_videos') in meta else meta['train_videos'][0:1] 72 | 73 | for video_id, train_video in enumerate(video_list): 74 | imgs = [] 75 | 76 | f_name = os.path.join(basedir, train_video['file_name']) 77 | reader = imageio.get_reader(f_name, "ffmpeg") 78 | for frame_i in range(train_video['frame_num']): 79 | reader.set_image_index(frame_i) 80 | frame = reader.get_next_data() 81 | 82 | H, W = frame.shape[:2] 83 | camera_angle_x = float(train_video['camera_angle_x']) 84 | Focal = .5 * W / np.tan(.5 * camera_angle_x) 85 | imgs.append(frame) 86 | 87 | reader.close() 88 | imgs = (np.float32(imgs) / 255.) 89 | 90 | if half_res: 91 | H = H // 2 92 | W = W // 2 93 | Focal = Focal / 2. 94 | 95 | imgs_half_res = np.zeros((imgs.shape[0], H, W, imgs.shape[-1])) 96 | for i, img in enumerate(imgs): 97 | imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) 98 | imgs = imgs_half_res 99 | 100 | all_imgs.append(imgs) 101 | all_poses.append(np.array( 102 | train_video['transform_matrix_list'][frame_i] 103 | if 'transform_matrix_list' in train_video else train_video['transform_matrix'] 104 | ).astype(np.float32)) 105 | 106 | imgs = np.stack(all_imgs, 0) # [V, T, H, W, 3] 107 | imgs = np.transpose(imgs, [1, 0, 2, 3, 4]) # [T, V, H, W, 3] 108 | poses = np.stack(all_poses, 0) # [V, 4, 4] 109 | hwf = np.float32([H, W, Focal]) 110 | 111 | # set render settings: 112 | sp_n = 120 # an even number! 
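    # The list comprehension below builds the novel-view render path: sp_n camera
    # poses are placed on a full circle (angles swept from -180 to 180 degrees) of
    # radius (near + far) / 2 around the scene center r_center, all at the
    # elevation angle `phi` read from info.json, and each pose is paired with a
    # normalized timestep so rendering the path also sweeps the captured time range.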
113 | sp_poses = [ 114 | pose_spherical(angle, phi, radius, rotZ, r_center[0], r_center[1], r_center[2]) 115 | for angle in np.linspace(-180, 180, sp_n + 1)[:-1] 116 | ] 117 | render_poses = torch.stack(sp_poses, 0) # [sp_poses[36]]*sp_n, for testing a single pose 118 | render_timesteps = np.arange(sp_n) / (sp_n - 1) 119 | 120 | return imgs, poses, hwf, render_poses, render_timesteps, voxel_tran, voxel_scale, near, far 121 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | # Author: Yash Bhalgat 2 | 3 | from math import exp, log, floor 4 | import torch 5 | import torch.nn.functional as F 6 | import pdb 7 | 8 | from utils import hash 9 | 10 | 11 | def total_variation_loss(embeddings, min_resolution, max_resolution, level, log2_hashmap_size, n_levels=16): 12 | # Get resolution 13 | b = exp((log(max_resolution)-log(min_resolution))/(n_levels-1)) 14 | resolution = torch.tensor(floor(min_resolution * b**level)) 15 | 16 | # Cube size to apply TV loss 17 | min_cube_size = min_resolution - 1 18 | max_cube_size = 50 # can be tuned 19 | if min_cube_size > max_cube_size: 20 | print("ALERT! min cuboid size greater than max!") 21 | pdb.set_trace() 22 | cube_size = torch.floor(torch.clip(resolution/10.0, min_cube_size, max_cube_size)).int() 23 | 24 | # Sample cuboid 25 | min_vertex = torch.randint(0, resolution-cube_size, (3,)) 26 | idx = min_vertex + torch.stack([torch.arange(cube_size+1) for _ in range(3)], dim=-1) 27 | cube_indices = torch.stack(torch.meshgrid(idx[:,0], idx[:,1], idx[:,2], indexing='ij'), dim=-1) 28 | 29 | hashed_indices = hash(cube_indices, log2_hashmap_size) 30 | cube_embeddings = embeddings(hashed_indices) 31 | #hashed_idx_offset_x = hash(idx+torch.tensor([1,0,0]), log2_hashmap_size) 32 | #hashed_idx_offset_y = hash(idx+torch.tensor([0,1,0]), log2_hashmap_size) 33 | #hashed_idx_offset_z = hash(idx+torch.tensor([0,0,1]), log2_hashmap_size) 34 | 35 | # Compute loss 36 | #tv_x = torch.pow(embeddings(hashed_idx)-embeddings(hashed_idx_offset_x), 2).sum() 37 | #tv_y = torch.pow(embeddings(hashed_idx)-embeddings(hashed_idx_offset_y), 2).sum() 38 | #tv_z = torch.pow(embeddings(hashed_idx)-embeddings(hashed_idx_offset_z), 2).sum() 39 | tv_x = torch.pow(cube_embeddings[1:,:,:,:]-cube_embeddings[:-1,:,:,:], 2).sum() 40 | tv_y = torch.pow(cube_embeddings[:,1:,:,:]-cube_embeddings[:,:-1,:,:], 2).sum() 41 | tv_z = torch.pow(cube_embeddings[:,:,1:,:]-cube_embeddings[:,:,:-1,:], 2).sum() 42 | 43 | return (tv_x + tv_y + tv_z)/cube_size 44 | 45 | def sigma_sparsity_loss(sigmas): 46 | # Using Cauchy Sparsity loss on sigma values 47 | return torch.log(1.0 + 2*sigmas**2).sum(dim=-1) 48 | -------------------------------------------------------------------------------- /radam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer, required 4 | 5 | class RAdam(Optimizer): 6 | 7 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=False): 8 | if not 0.0 <= lr: 9 | raise ValueError("Invalid learning rate: {}".format(lr)) 10 | if not 0.0 <= eps: 11 | raise ValueError("Invalid epsilon value: {}".format(eps)) 12 | if not 0.0 <= betas[0] < 1.0: 13 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 14 | if not 0.0 <= betas[1] < 1.0: 15 | raise ValueError("Invalid beta parameter at index 1: 
{}".format(betas[1])) 16 | 17 | self.degenerated_to_sgd = degenerated_to_sgd 18 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): 19 | for param in params: 20 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): 21 | param['buffer'] = [[None, None, None] for _ in range(10)] 22 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)]) 23 | super(RAdam, self).__init__(params, defaults) 24 | 25 | def __setstate__(self, state): 26 | super(RAdam, self).__setstate__(state) 27 | 28 | def step(self, closure=None): 29 | 30 | loss = None 31 | if closure is not None: 32 | loss = closure() 33 | 34 | for group in self.param_groups: 35 | 36 | for p in group['params']: 37 | if p.grad is None: 38 | continue 39 | grad = p.grad.data.float() 40 | if grad.is_sparse: 41 | raise RuntimeError('RAdam does not support sparse gradients') 42 | 43 | p_data_fp32 = p.data.float() 44 | 45 | state = self.state[p] 46 | 47 | if len(state) == 0: 48 | state['step'] = 0 49 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 50 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 51 | else: 52 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 53 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 54 | 55 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 56 | beta1, beta2 = group['betas'] 57 | 58 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2) 59 | exp_avg.mul_(beta1).add_(grad, alpha=1-beta1) 60 | 61 | state['step'] += 1 62 | buffered = group['buffer'][int(state['step'] % 10)] 63 | if state['step'] == buffered[0]: 64 | N_sma, step_size = buffered[1], buffered[2] 65 | else: 66 | buffered[0] = state['step'] 67 | beta2_t = beta2 ** state['step'] 68 | N_sma_max = 2 / (1 - beta2) - 1 69 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 70 | buffered[1] = N_sma 71 | 72 | # more conservative since it's an approximated value 73 | if N_sma >= 5: 74 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 75 | elif self.degenerated_to_sgd: 76 | step_size = 1.0 / (1 - beta1 ** state['step']) 77 | else: 78 | step_size = -1 79 | buffered[2] = step_size 80 | 81 | # more conservative since it's an approximated value 82 | if N_sma >= 5: 83 | if group['weight_decay'] != 0: 84 | p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr']) 85 | denom = exp_avg_sq.sqrt().add_(group['eps']) 86 | p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr']) 87 | p.data.copy_(p_data_fp32) 88 | elif step_size > 0: 89 | if group['weight_decay'] != 0: 90 | p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr']) 91 | p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr']) 92 | p.data.copy_(p_data_fp32) 93 | 94 | return loss -------------------------------------------------------------------------------- /ray_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kornia import create_meshgrid 3 | 4 | 5 | def get_ray_directions(H, W, focal): 6 | """ 7 | Get ray directions for all pixels in camera coordinate. 
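    The camera frame follows the usual NeRF/OpenGL convention: x points right,
    y points up, and the camera looks along -z, which is why the row index j is
    negated and the z component is -torch.ones_like(i) in the code below.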
8 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 9 | ray-tracing-generating-camera-rays/standard-coordinate-systems 10 | 11 | Inputs: 12 | H, W, focal: image height, width and focal length 13 | 14 | Outputs: 15 | directions: (H, W, 3), the direction of the rays in camera coordinate 16 | """ 17 | grid = create_meshgrid(H, W, normalized_coordinates=False)[0] 18 | i, j = grid.unbind(-1) 19 | # the direction here is without +0.5 pixel centering as calibration is not so accurate 20 | # see https://github.com/bmild/nerf/issues/24 21 | directions = \ 22 | torch.stack([(i-W/2)/focal, -(j-H/2)/focal, -torch.ones_like(i)], -1) # (H, W, 3) 23 | 24 | dir_bounds = directions.view(-1, 3) 25 | # print("Directions ", directions[0,0,:], directions[H-1,0,:], directions[0,W-1,:], directions[H-1, W-1, :]) 26 | # print("Directions ", dir_bounds[0], dir_bounds[W-1], dir_bounds[H*W-W], dir_bounds[H*W-1]) 27 | 28 | return directions 29 | 30 | 31 | def get_rays(directions, c2w): 32 | """ 33 | Get ray origin and normalized directions in world coordinate for all pixels in one image. 34 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 35 | ray-tracing-generating-camera-rays/standard-coordinate-systems 36 | 37 | Inputs: 38 | directions: (H, W, 3) precomputed ray directions in camera coordinate 39 | c2w: (3, 4) transformation matrix from camera coordinate to world coordinate 40 | 41 | Outputs: 42 | rays_o: (H*W, 3), the origin of the rays in world coordinate 43 | rays_d: (H*W, 3), the normalized direction of the rays in world coordinate 44 | """ 45 | # Rotate ray directions from camera coordinate to the world coordinate 46 | rays_d = directions @ c2w[:3, :3].T # (H, W, 3) 47 | rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 48 | # The origin of all rays is the camera origin in world coordinate 49 | rays_o = c2w[:3, -1].expand(rays_d.shape) # (H, W, 3) 50 | 51 | rays_d = rays_d.view(-1, 3) 52 | rays_o = rays_o.view(-1, 3) 53 | 54 | return rays_o, rays_d 55 | 56 | 57 | def get_ndc_rays(H, W, focal, near, rays_o, rays_d): 58 | """ 59 | Transform rays from world coordinate to NDC. 60 | NDC: Space such that the canvas is a cube with sides [-1, 1] in each axis. 61 | For detailed derivation, please see: 62 | http://www.songho.ca/opengl/gl_projectionmatrix.html 63 | https://github.com/bmild/nerf/files/4451808/ndc_derivation.pdf 64 | 65 | In practice, use NDC "if and only if" the scene is unbounded (has a large depth). 66 | See https://github.com/bmild/nerf/issues/18 67 | 68 | Inputs: 69 | H, W, focal: image height, width and focal length 70 | near: (N_rays) or float, the depths of the near plane 71 | rays_o: (N_rays, 3), the origin of the rays in world coordinate 72 | rays_d: (N_rays, 3), the direction of the rays in world coordinate 73 | 74 | Outputs: 75 | rays_o: (N_rays, 3), the origin of the rays in NDC 76 | rays_d: (N_rays, 3), the direction of the rays in NDC 77 | """ 78 | # Shift ray origins to near plane 79 | t = -(near + rays_o[...,2]) / rays_d[...,2] 80 | rays_o = rays_o + t[...,None] * rays_d 81 | 82 | # Store some intermediate homogeneous results 83 | ox_oz = rays_o[...,0] / rays_o[...,2] 84 | oy_oz = rays_o[...,1] / rays_o[...,2] 85 | 86 | # Projection 87 | o0 = -1./(W/(2.*focal)) * ox_oz 88 | o1 = -1./(H/(2.*focal)) * oy_oz 89 | o2 = 1. + 2. 
* near / rays_o[...,2] 90 | 91 | d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - ox_oz) 92 | d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - oy_oz) 93 | d2 = 1 - o2 94 | 95 | rays_o = torch.stack([o0, o1, o2], -1) # (B, 3) 96 | rays_d = torch.stack([d0, d1, d2], -1) # (B, 3) 97 | 98 | return rays_o, rays_d 99 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ConfigArgParse==1.5.3 2 | imageio==2.27.0 3 | imageio-ffmpeg==0.4.5 4 | ipdb==0.13.13 5 | kornia==0.6.11 6 | lpips==0.1.4 7 | matplotlib==3.5.3 8 | numpy==1.24.2 9 | opencv_python==4.6.0.66 10 | plyfile==1.0.2 11 | scikit-image==0.20.0 12 | scikit-learn==1.3.2 13 | taichi==1.5.0 14 | torch==2.0.1 15 | tqdm==4.65.0 16 | vtk==9.2.6 17 | -------------------------------------------------------------------------------- /run_nerf_density.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | from tqdm import tqdm, trange 3 | 4 | import matplotlib.pyplot as plt 5 | 6 | from run_nerf_helpers import NeRFSmall, to8b, img2mse, mse2psnr, get_rays_np, get_rays, get_rays_np_continuous, sample_bilinear 7 | import torch.nn.functional as F 8 | from radam import RAdam 9 | from load_scalarflow import load_pinf_frame_data 10 | import lpips 11 | import torch 12 | 13 | from skimage.metrics import structural_similarity 14 | 15 | 16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 17 | np.random.seed(0) 18 | 19 | 20 | def batchify_rays(rays_flat, chunk=1024 * 64, **kwargs): 21 | """Render rays in smaller minibatches to avoid OOM. 22 | """ 23 | all_ret = {} 24 | for i in range(0, rays_flat.shape[0], chunk): 25 | ret = render_rays(rays_flat[i:i + chunk], **kwargs) 26 | for k in ret: 27 | if k not in all_ret: 28 | all_ret[k] = [] 29 | all_ret[k].append(ret[k]) 30 | 31 | all_ret = {k: torch.cat(all_ret[k], 0) for k in all_ret} 32 | return all_ret 33 | 34 | 35 | def render(H, W, K, rays=None, c2w=None, 36 | near=0., far=1., time_step=None, 37 | **kwargs): 38 | """Render rays 39 | Args: 40 | H: int. Height of image in pixels. 41 | W: int. Width of image in pixels. 42 | K: float. Focal length of pinhole camera. 43 | rays: array of shape [2, batch_size, 3]. Ray origin and direction for 44 | each example in batch. 45 | c2w: array of shape [3, 4]. Camera-to-world transformation matrix. 46 | near: float or array of shape [batch_size]. Nearest distance for a ray. 47 | far: float or array of shape [batch_size]. Farthest distance for a ray. 48 | Returns: 49 | rgb_map: [batch_size, 3]. Predicted RGB values for rays. 50 | disp_map: [batch_size]. Disparity map. Inverse of depth. 51 | acc_map: [batch_size]. Accumulated opacity (alpha) along a ray. 52 | extras: dict with everything returned by render_rays(). 
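    Note:
        time_step is a [N_t] tensor of (normalized) frame times; every ray in the
        batch is duplicated for each time step and queried at (x, y, z, t), so
        with N_t > 1 the outputs cover N_t * batch_size rays.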
53 | """ 54 | if c2w is not None: 55 | # special case to render full image 56 | rays_o, rays_d = get_rays(H, W, K, c2w) 57 | else: 58 | # use provided ray batch 59 | rays_o, rays_d = rays 60 | 61 | sh = rays_d.shape # [..., 3] 62 | 63 | # Create ray batch 64 | rays_o = torch.reshape(rays_o, [-1, 3]).float() 65 | rays_d = torch.reshape(rays_d, [-1, 3]).float() 66 | 67 | near, far = near * torch.ones_like(rays_d[..., :1]), far * torch.ones_like(rays_d[..., :1]) 68 | rays = torch.cat([rays_o, rays_d, near, far], -1) 69 | time_step = time_step[:, None, None] # [N_t, 1, 1] 70 | N_t = time_step.shape[0] 71 | N_r = rays.shape[0] 72 | rays = torch.cat([rays[None].expand(N_t, -1, -1), time_step.expand(-1, N_r, -1)], -1) # [N_t, n_rays, 7] 73 | rays = rays.flatten(0, 1) # [n_time_steps * n_rays, 7] 74 | 75 | # Render and reshape 76 | all_ret = batchify_rays(rays, **kwargs) 77 | if N_t == 1: 78 | for k in all_ret: 79 | k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:]) 80 | all_ret[k] = torch.reshape(all_ret[k], k_sh) 81 | 82 | k_extract = ['rgb_map', 'depth_map', 'acc_map'] 83 | ret_list = [all_ret[k] for k in k_extract] 84 | ret_dict = [{k: all_ret[k] for k in all_ret if k not in k_extract}, ] 85 | return ret_list + ret_dict 86 | 87 | 88 | def render_path(render_poses, hwf, K, render_kwargs, gt_imgs=None, savedir=None, time_steps=None): 89 | def merge_imgs(save_dir, framerate=30, prefix=''): 90 | os.system( 91 | 'ffmpeg -hide_banner -loglevel error -y -i {0}/{1}%03d.png -vf palettegen {0}/palette.png'.format(save_dir, 92 | prefix)) 93 | os.system( 94 | 'ffmpeg -hide_banner -loglevel error -y -framerate {0} -i {1}/{2}%03d.png -i {1}/palette.png -lavfi paletteuse {1}/_{2}.gif'.format( 95 | framerate, save_dir, prefix)) 96 | os.system( 97 | 'ffmpeg -hide_banner -loglevel error -y -framerate {0} -i {1}/{2}%03d.png -i {1}/palette.png -lavfi paletteuse {1}/_{2}.mp4'.format( 98 | framerate, save_dir, prefix)) 99 | 100 | 101 | render_kwargs.update(chunk=512 * 64) 102 | H, W, focal = hwf 103 | near, far = render_kwargs['near'], render_kwargs['far'] 104 | if time_steps is None: 105 | time_steps = torch.ones(render_poses.shape[0], dtype=torch.float32) 106 | 107 | rgbs = [] 108 | depths = [] 109 | psnrs = [] 110 | ssims = [] 111 | lpipss = [] 112 | 113 | lpips_net = lpips.LPIPS().cuda() 114 | 115 | for i, c2w in enumerate(tqdm(render_poses)): 116 | rgb, depth, acc, _ = render(H, W, K, c2w=c2w[:3, :4], time_step=time_steps[i][None], **render_kwargs) 117 | rgbs.append(rgb.cpu().numpy()) 118 | # normalize depth to [0,1] 119 | depth = (depth - near) / (far - near) 120 | depths.append(depth.cpu().numpy()) 121 | 122 | if gt_imgs is not None: 123 | gt_img = torch.tensor(gt_imgs[i].squeeze(), dtype=torch.float32) # [H, W, 3] 124 | gt_img8 = to8b(gt_img.cpu().numpy()) 125 | gt_img = gt_img[90:960, 45:540] 126 | rgb = rgb[90:960, 45:540] 127 | lpips_value = lpips_net(rgb.permute(2, 0, 1), gt_img.permute(2, 0, 1), normalize=True).item() 128 | p = -10. 
* np.log10(np.mean(np.square(rgb.detach().cpu().numpy() - gt_img.cpu().numpy()))) 129 | ssim_value = structural_similarity(gt_img.cpu().numpy(), rgb.cpu().numpy(), data_range=1.0, channel_axis=2) 130 | lpipss.append(lpips_value) 131 | psnrs.append(p) 132 | ssims.append(ssim_value) 133 | print(f'PSNR: {p:.4g}, SSIM: {ssim_value:.4g}, LPIPS: {lpips_value:.4g}') 134 | 135 | 136 | if savedir is not None: 137 | # save rgb and depth as a figure 138 | rgb8 = to8b(rgbs[-1]) 139 | imageio.imsave(os.path.join(savedir, 'rgb_{:03d}.png'.format(i)), rgb8) 140 | depth = depths[-1] 141 | colored_depth_map = plt.cm.viridis(depth.squeeze()) 142 | imageio.imwrite(os.path.join(savedir, 'depth_{:03d}.png'.format(i)), 143 | (colored_depth_map * 255).astype(np.uint8)) 144 | 145 | if savedir is not None: 146 | merge_imgs(savedir, prefix='rgb_') 147 | merge_imgs(savedir, prefix='depth_') 148 | 149 | rgbs = np.stack(rgbs, 0) 150 | depths = np.stack(depths, 0) 151 | if gt_imgs is not None: 152 | avg_psnr = sum(psnrs) / len(psnrs) 153 | avg_lpips = sum(lpipss) / len(lpipss) 154 | avg_ssim = sum(ssims) / len(ssims) 155 | print("Avg PSNR over Test set: ", avg_psnr) 156 | print("Avg LPIPS over Test set: ", avg_lpips) 157 | print("Avg SSIM over Test set: ", avg_ssim) 158 | with open(os.path.join(savedir, "test_psnrs_{:0.4f}_lpips_{:0.4f}_ssim_{:0.4f}.json".format(avg_psnr, avg_lpips, avg_ssim)), 'w') as fp: 159 | json.dump(psnrs, fp) 160 | 161 | return rgbs, depths 162 | 163 | 164 | def create_nerf(args): 165 | """Instantiate NeRF's MLP model. 166 | """ 167 | # from encoding import get_encoder 168 | from taichi_encoders.hash4 import Hash4Encoder 169 | # embed_fn, input_ch = get_encoder('hashgrid', input_dim=4, num_levels=args.num_levels, base_resolution=args.base_resolution, 170 | # finest_resolution=args.finest_resolution, log2_hashmap_size=args.log2_hashmap_size,) 171 | if args.encoder == 'ingp': 172 | max_res = np.array( 173 | [args.finest_resolution, args.finest_resolution, args.finest_resolution, args.finest_resolution_t]) 174 | min_res = np.array([args.base_resolution, args.base_resolution, args.base_resolution, args.base_resolution_t]) 175 | 176 | embed_fn = Hash4Encoder(max_res=max_res, min_res=min_res, num_scales=args.num_levels, 177 | max_params=2 ** args.log2_hashmap_size) 178 | input_ch = embed_fn.num_scales * 2 # default 2 params per scale 179 | embedding_params = list(embed_fn.parameters()) 180 | else: 181 | raise NotImplementedError 182 | 183 | model = NeRFSmall(num_layers=2, 184 | hidden_dim=64, 185 | geo_feat_dim=15, 186 | num_layers_color=2, 187 | hidden_dim_color=16, 188 | input_ch=input_ch).to(device) 189 | print(model) 190 | print('Total number of trainable parameters in model: {}'.format( 191 | sum([p.numel() for p in model.parameters() if p.requires_grad]))) 192 | print('Total number of parameters in embedding: {}'.format( 193 | sum([p.numel() for p in embedding_params if p.requires_grad]))) 194 | grad_vars = list(model.parameters()) 195 | 196 | network_query_fn = lambda x: model(embed_fn(x)) 197 | 198 | # Create optimizer 199 | optimizer = RAdam([ 200 | {'params': grad_vars, 'weight_decay': 1e-6}, 201 | {'params': embedding_params, 'eps': 1e-15} 202 | ], lr=args.lrate, betas=(0.9, 0.99)) 203 | grad_vars += list(embedding_params) 204 | start = 0 205 | basedir = args.basedir 206 | expname = args.expname 207 | 208 | ########################## 209 | 210 | # Load checkpoints 211 | if args.ft_path is not None and args.ft_path != 'None': 212 | ckpts = [args.ft_path] 213 | else: 214 | ckpts = 
[os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 215 | 'tar' in f] 216 | 217 | print('Found ckpts', ckpts) 218 | if len(ckpts) > 0 and not args.no_reload: 219 | ckpt_path = ckpts[-1] 220 | print('Reloading from', ckpt_path) 221 | ckpt = torch.load(ckpt_path) 222 | 223 | start = ckpt['global_step'] 224 | # Load model 225 | model.load_state_dict(ckpt['network_fn_state_dict']) 226 | embed_fn.load_state_dict(ckpt['embed_fn_state_dict']) 227 | # Load optimizer 228 | optimizer.load_state_dict(ckpt['optimizer_state_dict']) 229 | 230 | ########################## 231 | # pdb.set_trace() 232 | 233 | render_kwargs_train = { 234 | 'network_query_fn': network_query_fn, 235 | 'perturb': args.perturb, 236 | 'N_samples': args.N_samples, 237 | 'network_fn': model, 238 | 'embed_fn': embed_fn, 239 | } 240 | 241 | render_kwargs_test = {k: render_kwargs_train[k] for k in render_kwargs_train} 242 | render_kwargs_test['perturb'] = False 243 | 244 | return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer 245 | 246 | 247 | def raw2outputs(raw, z_vals, rays_d, learned_rgb=None): 248 | """Transforms model's predictions to semantically meaningful values. 249 | Args: 250 | raw: [num_rays, num_samples along ray, 4]. Prediction from model. 251 | z_vals: [num_rays, num_samples along ray]. Integration time. 252 | rays_d: [num_rays, 3]. Direction of each ray. 253 | Returns: 254 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. 255 | disp_map: [num_rays]. Disparity map. Inverse of depth map. 256 | acc_map: [num_rays]. Sum of weights along each ray. 257 | weights: [num_rays, num_samples]. Weights assigned to each sampled color. 258 | depth_map: [num_rays]. Estimated distance to object. 259 | """ 260 | raw2alpha = lambda raw, dists, act_fn=F.relu: 1. - torch.exp(-act_fn(raw) * dists) 261 | 262 | dists = z_vals[..., 1:] - z_vals[..., :-1] 263 | dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[..., :1].shape)], -1) # [N_rays, N_samples] 264 | 265 | dists = dists * torch.norm(rays_d[..., None, :], dim=-1) 266 | 267 | rgb = torch.ones(3) * (0.6 + torch.tanh(learned_rgb) * 0.4) 268 | # rgb = 0.6 + torch.tanh(learned_rgb) * 0.4 269 | noise = 0. 270 | 271 | alpha = raw2alpha(raw[..., -1] + noise, dists) # [N_rays, N_samples] 272 | weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1. - alpha + 1e-10], -1), -1)[:, 273 | :-1] # [N_rays, N_samples] 274 | rgb_map = torch.sum(weights[..., None] * rgb, -2) # [N_rays, 3] 275 | 276 | depth_map = torch.sum(weights * z_vals, -1) / (torch.sum(weights, -1) + 1e-10) 277 | disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map), depth_map) 278 | acc_map = torch.sum(weights, -1) 279 | depth_map[acc_map < 1e-1] = 0. 280 | 281 | return rgb_map, disp_map, acc_map, weights, depth_map 282 | 283 | 284 | def render_rays(ray_batch, 285 | network_query_fn, 286 | N_samples, 287 | retraw=False, 288 | perturb=0., 289 | **kwargs): 290 | """Volumetric rendering. 291 | Args: 292 | ray_batch: array of shape [batch_size, ...]. All information necessary 293 | for sampling along a ray, including: ray origin, ray direction, min 294 | dist, max dist, and unit-magnitude viewing direction. 295 | network_query_fn: function used for passing queries to network_fn. 296 | N_samples: int. Number of different times to sample along each ray. 297 | retraw: bool. If True, include model's raw, unprocessed predictions. 298 | perturb: float, 0 or 1. 
If non-zero, each ray is sampled at stratified 299 | random points in time. 300 | Returns: 301 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. 302 | disp_map: [num_rays]. Disparity map. 1 / depth. 303 | acc_map: [num_rays]. Accumulated opacity along each ray. 304 | raw: [num_rays, num_samples, 4]. Raw predictions from model. 305 | z_std: [num_rays]. Standard deviation of distances along ray for each 306 | sample. 307 | """ 308 | N_rays = ray_batch.shape[0] 309 | rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6] # [N_rays, 3] each 310 | time_step = ray_batch[:, -1] 311 | bounds = torch.reshape(ray_batch[..., 6:8], [-1, 1, 2]) 312 | near, far = bounds[..., 0], bounds[..., 1] # [-1,1] 313 | 314 | t_vals = torch.linspace(0., 1., steps=N_samples) 315 | z_vals = near * (1. - t_vals) + far * (t_vals) 316 | 317 | z_vals = z_vals.expand([N_rays, N_samples]) 318 | 319 | if perturb > 0.: 320 | # get intervals between samples 321 | mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) 322 | upper = torch.cat([mids, z_vals[..., -1:]], -1) 323 | lower = torch.cat([z_vals[..., :1], mids], -1) 324 | # stratified samples in those intervals 325 | t_rand = torch.rand(z_vals.shape) 326 | 327 | z_vals = lower + (upper - lower) * t_rand 328 | 329 | pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None] # [N_rays, N_samples, 3] 330 | pts_time_step = time_step[..., None, None].expand(-1, pts.shape[1], -1) 331 | pts = torch.cat([pts, pts_time_step], -1) # [..., 4] 332 | pts_flat = torch.reshape(pts, [-1, 4]) 333 | out_dim = 1 334 | raw_flat = torch.zeros([N_rays, N_samples, out_dim]).reshape(-1, out_dim) 335 | 336 | bbox_mask = bbox_model.insideMask(pts_flat[..., :3], to_float=False) 337 | if bbox_mask.sum() == 0: 338 | bbox_mask[0] = True # in case zero rays are inside the bbox 339 | pts = pts_flat[bbox_mask] 340 | 341 | raw_flat[bbox_mask] = network_query_fn(pts) 342 | raw = raw_flat.reshape(N_rays, N_samples, out_dim) 343 | rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, 344 | learned_rgb=kwargs['network_fn'].rgb,) 345 | 346 | ret = {'rgb_map': rgb_map, 'depth_map': depth_map, 'acc_map': acc_map} 347 | if retraw: 348 | ret['raw'] = raw 349 | return ret 350 | 351 | 352 | def config_parser(): 353 | import configargparse 354 | parser = configargparse.ArgumentParser() 355 | parser.add_argument('--config', is_config_file=True, 356 | help='config file path') 357 | parser.add_argument("--expname", type=str, 358 | help='experiment name') 359 | parser.add_argument("--basedir", type=str, default='./logs/', 360 | help='where to store ckpts and logs') 361 | parser.add_argument("--datadir", type=str, default='./data/llff/fern', 362 | help='input data directory') 363 | 364 | # training options 365 | parser.add_argument("--encoder", type=str, default='ingp', 366 | choices=['ingp', 'plane']) 367 | parser.add_argument("--N_rand", type=int, default=32 * 32 * 4, 368 | help='batch size (number of random rays per gradient step)') 369 | parser.add_argument("--N_time", type=int, default=1, 370 | help='batch size in time') 371 | parser.add_argument("--lrate", type=float, default=5e-4, 372 | help='learning rate') 373 | parser.add_argument("--lrate_decay", type=int, default=250, 374 | help='exponential learning rate decay') 375 | parser.add_argument("--N_iters", type=int, default=50000) 376 | parser.add_argument("--no_reload", action='store_true', 377 | help='do not reload weights from saved ckpt') 378 | parser.add_argument("--ft_path", type=str, default=None, 379 | 
help='specific weights npy file to reload for coarse network') 380 | 381 | # rendering options 382 | parser.add_argument("--N_samples", type=int, default=64, 383 | help='number of coarse samples per ray') 384 | parser.add_argument("--perturb", type=float, default=1., 385 | help='set to 0. for no jitter, 1. for jitter') 386 | 387 | parser.add_argument("--render_only", action='store_true', 388 | help='do not optimize, reload weights and render out render_poses path') 389 | parser.add_argument("--half_res", action='store_true', 390 | help='load at half resolution') 391 | 392 | # logging/saving options 393 | parser.add_argument("--i_print", type=int, default=100, 394 | help='frequency of console printout and metric loggin') 395 | parser.add_argument("--i_weights", type=int, default=10000, 396 | help='frequency of weight ckpt saving') 397 | parser.add_argument("--i_video", type=int, default=9999999, 398 | help='frequency of render_poses video saving') 399 | 400 | parser.add_argument("--finest_resolution", type=int, default=512, 401 | help='finest resolultion for hashed embedding') 402 | parser.add_argument("--finest_resolution_t", type=int, default=512, 403 | help='finest resolultion for hashed embedding') 404 | parser.add_argument("--num_levels", type=int, default=16, 405 | help='number of levels for hashed embedding') 406 | parser.add_argument("--base_resolution", type=int, default=16, 407 | help='base resolution for hashed embedding') 408 | parser.add_argument("--base_resolution_t", type=int, default=16, 409 | help='base resolution for hashed embedding') 410 | parser.add_argument("--log2_hashmap_size", type=int, default=19, 411 | help='log2 of hashmap size') 412 | parser.add_argument("--feats_dim", type=int, default=36, 413 | help='feature dimension of kplanes') 414 | parser.add_argument("--tv-loss-weight", type=float, default=1e-6, 415 | help='learning rate') 416 | 417 | return parser 418 | 419 | 420 | def train(): 421 | parser = config_parser() 422 | args = parser.parse_args() 423 | 424 | # Load data 425 | images_train_, poses_train, hwf, render_poses, render_timesteps, voxel_tran, voxel_scale, near, far = \ 426 | load_pinf_frame_data(args.datadir, args.half_res, split='train') 427 | images_test, poses_test, hwf, render_poses, render_timesteps, voxel_tran, voxel_scale, near, far = \ 428 | load_pinf_frame_data(args.datadir, args.half_res, split='test') 429 | global bbox_model 430 | voxel_tran_inv = np.linalg.inv(voxel_tran) 431 | bbox_model = BBox_Tool(voxel_tran_inv, voxel_scale) 432 | render_timesteps = torch.tensor(render_timesteps, dtype=torch.float32) 433 | print('Loaded scalarflow', images_train_.shape, render_poses.shape, hwf, args.datadir) 434 | 435 | # Cast intrinsics to right types 436 | H, W, focal = hwf 437 | H, W = int(H), int(W) 438 | hwf = [H, W, focal] 439 | 440 | K = np.array([ 441 | [focal, 0, 0.5 * W], 442 | [0, focal, 0.5 * H], 443 | [0, 0, 1] 444 | ]) 445 | 446 | # Create log dir and copy the config file 447 | basedir = args.basedir 448 | expname = args.expname 449 | 450 | os.makedirs(os.path.join(basedir, expname), exist_ok=True) 451 | f = os.path.join(basedir, expname, 'args.txt') 452 | with open(f, 'w') as file: 453 | for arg in sorted(vars(args)): 454 | attr = getattr(args, arg) 455 | file.write('{} = {}\n'.format(arg, attr)) 456 | if args.config is not None: 457 | f = os.path.join(basedir, expname, 'config.txt') 458 | with open(f, 'w') as file: 459 | file.write(open(args.config, 'r').read()) 460 | 461 | # Create nerf model 462 | render_kwargs_train, 
render_kwargs_test, start, grad_vars, optimizer = create_nerf(args) 463 | global_step = start 464 | 465 | bds_dict = { 466 | 'near': near, 467 | 'far': far, 468 | } 469 | render_kwargs_train.update(bds_dict) 470 | render_kwargs_test.update(bds_dict) 471 | 472 | # Move testing data to GPU 473 | render_poses = torch.Tensor(render_poses).to(device) 474 | 475 | # Short circuit if only rendering out from trained model 476 | if args.render_only: 477 | print('RENDER ONLY') 478 | with torch.no_grad(): 479 | testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(start)) 480 | os.makedirs(testsavedir, exist_ok=True) 481 | with torch.no_grad(): 482 | test_view_pose = torch.tensor(poses_test[0]) 483 | N_timesteps = images_test.shape[0] 484 | test_timesteps = torch.arange(N_timesteps) / (N_timesteps - 1) 485 | test_view_poses = test_view_pose.unsqueeze(0).repeat(N_timesteps, 1, 1) 486 | print(test_view_poses.shape) 487 | test_view_poses = torch.tensor(poses_train[0]).unsqueeze(0).repeat(N_timesteps, 1, 1) 488 | print(test_view_poses.shape) 489 | render_path(test_view_poses, hwf, K, render_kwargs_test, time_steps=test_timesteps, gt_imgs=images_test, 490 | savedir=testsavedir) 491 | return 492 | 493 | # Prepare raybatch tensor if batching random rays 494 | N_rand = args.N_rand 495 | # For random ray batching 496 | print('get rays') 497 | rays = [] 498 | ij = [] 499 | 500 | # anti-aliasing 501 | for p in poses_train[:, :3, :4]: 502 | r_o, r_d, i_, j_ = get_rays_np_continuous(H, W, K, p) 503 | rays.append([r_o, r_d]) 504 | ij.append([i_, j_]) 505 | rays = np.stack(rays, 0) # [V, ro+rd=2, H, W, 3] 506 | ij = np.stack(ij, 0) # [V, 2, H, W] 507 | images_train = sample_bilinear(images_train_, ij) # [T, V, H, W, 3] 508 | 509 | rays = np.transpose(rays, [0, 2, 3, 1, 4]) # [V, H, W, ro+rd=2, 3] 510 | rays = np.reshape(rays, [-1, 2, 3]) # [VHW, ro+rd=2, 3] 511 | rays = rays.astype(np.float32) 512 | 513 | print('done') 514 | i_batch = 0 515 | 516 | # Move training data to GPU 517 | images_train = torch.Tensor(images_train).to(device).flatten(start_dim=1, end_dim=3) # [T, VHW, 3] 518 | T, S, _ = images_train.shape 519 | rays = torch.Tensor(rays).to(device) 520 | ray_idxs = torch.randperm(rays.shape[0]) 521 | 522 | loss_list = [] 523 | psnr_list = [] 524 | start = start + 1 525 | loss_meter, psnr_meter = AverageMeter(), AverageMeter() 526 | resample_rays = False 527 | for i in trange(start, args.N_iters + 1): 528 | # Sample random ray batch 529 | batch_ray_idx = ray_idxs[i_batch:i_batch + N_rand] 530 | batch_rays = rays[batch_ray_idx] # [B, 2, 3] 531 | batch_rays = torch.transpose(batch_rays, 0, 1) # [2, B, 3] 532 | 533 | i_batch += N_rand 534 | # temporal bilinear sampling 535 | time_idx = torch.randperm(T)[:args.N_time].float().to(device) # [N_t] 536 | time_idx += torch.randn(args.N_time) - 0.5 # -0.5 ~ 0.5 537 | time_idx_floor = torch.floor(time_idx).long() 538 | time_idx_ceil = torch.ceil(time_idx).long() 539 | time_idx_floor = torch.clamp(time_idx_floor, 0, T - 1) 540 | time_idx_ceil = torch.clamp(time_idx_ceil, 0, T - 1) 541 | time_idx_residual = time_idx - time_idx_floor.float() 542 | frames_floor = images_train[time_idx_floor] # [N_t, VHW, 3] 543 | frames_ceil = images_train[time_idx_ceil] # [N_t, VHW, 3] 544 | frames_interp = frames_floor * (1 - time_idx_residual).unsqueeze(-1) + \ 545 | frames_ceil * time_idx_residual.unsqueeze(-1) # [N_t, VHW, 3] 546 | time_step = time_idx / (T - 1) if T > 1 else torch.zeros_like(time_idx) 547 | points = frames_interp[:, batch_ray_idx] # [N_t, B, 3] 548 | 
target_s = points.flatten(0, 1) # [N_t*B, 3] 549 | 550 | if i_batch >= rays.shape[0]: 551 | print("Shuffle data after an epoch!") 552 | ray_idxs = torch.randperm(rays.shape[0]) 553 | i_batch = 0 554 | resample_rays = True 555 | 556 | ##### Core optimization loop ##### 557 | rgb, depth, acc, extras = render(H, W, K, rays=batch_rays, time_step=time_step, 558 | **render_kwargs_train) 559 | 560 | img_loss = img2mse(rgb, target_s) 561 | loss = img_loss 562 | psnr = mse2psnr(img_loss) 563 | loss_meter.update(loss.item()) 564 | psnr_meter.update(psnr.item()) 565 | 566 | for param in grad_vars: # slightly faster than optimizer.zero_grad() 567 | param.grad = None 568 | loss.backward() 569 | optimizer.step() 570 | 571 | ### update learning rate ### 572 | decay_rate = 0.1 573 | decay_steps = args.lrate_decay 574 | new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps)) 575 | for param_group in optimizer.param_groups: 576 | param_group['lr'] = new_lrate 577 | ################################ 578 | 579 | # Rest is logging 580 | if i % args.i_weights == 0: 581 | path = os.path.join(basedir, expname, '{:06d}.tar'.format(i)) 582 | torch.save({ 583 | 'global_step': global_step, 584 | 'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(), 585 | 'embed_fn_state_dict': render_kwargs_train['embed_fn'].state_dict(), 586 | 'optimizer_state_dict': optimizer.state_dict(), 587 | }, path) 588 | print('Saved checkpoints at', path) 589 | 590 | if i % args.i_video == 0 and i > 0: 591 | # Turn on testing mode 592 | testsavedir = os.path.join(basedir, expname, 'spiral_{:06d}'.format(i)) 593 | os.makedirs(testsavedir, exist_ok=True) 594 | with torch.no_grad(): 595 | render_path(render_poses, hwf, K, render_kwargs_test, time_steps=render_timesteps, savedir=testsavedir) 596 | 597 | testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i)) 598 | os.makedirs(testsavedir, exist_ok=True) 599 | with torch.no_grad(): 600 | test_view_pose = torch.tensor(poses_test[0]) 601 | N_timesteps = images_test.shape[0] 602 | test_timesteps = torch.arange(N_timesteps) / (N_timesteps - 1) 603 | test_view_poses = test_view_pose.unsqueeze(0).repeat(N_timesteps, 1, 1) 604 | render_path(test_view_poses, hwf, K, render_kwargs_test, time_steps=test_timesteps, gt_imgs=images_test, 605 | savedir=testsavedir) 606 | 607 | if i % args.i_print == 0: 608 | tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss_meter.avg:.2g} PSNR: {psnr_meter.avg:.4g}") 609 | loss_list.append(loss_meter.avg) 610 | psnr_list.append(psnr_meter.avg) 611 | loss_psnr = { 612 | "losses": loss_list, 613 | "psnr": psnr_list, 614 | } 615 | loss_meter.reset() 616 | psnr_meter.reset() 617 | with open(os.path.join(basedir, expname, "loss_vs_time.json"), "w") as fp: 618 | json.dump(loss_psnr, fp) 619 | 620 | if resample_rays: 621 | print("Sampling new rays!") 622 | rays = [] 623 | ij = [] 624 | for p in poses_train[:, :3, :4]: 625 | r_o, r_d, i_, j_ = get_rays_np_continuous(H, W, K, p) 626 | rays.append([r_o, r_d]) 627 | ij.append([i_, j_]) 628 | rays = np.stack(rays, 0) # [V, ro+rd=2, H, W, 3] 629 | ij = np.stack(ij, 0) # [V, 2, H, W] 630 | images_train = sample_bilinear(images_train_, ij) # [T, V, H, W, 3] 631 | rays = np.transpose(rays, [0, 2, 3, 1, 4]) # [V, H, W, ro+rd=2, 3] 632 | rays = np.reshape(rays, [-1, 2, 3]) # [VHW, ro+rd=2, 3] 633 | rays = rays.astype(np.float32) 634 | 635 | # Move training data to GPU 636 | images_train = torch.Tensor(images_train).to(device).flatten(start_dim=1, end_dim=3) # [T, VHW, 3] 637 | T, S, _ = 
images_train.shape 638 | rays = torch.Tensor(rays).to(device) 639 | 640 | ray_idxs = torch.randperm(rays.shape[0]) 641 | i_batch = 0 642 | resample_rays = False 643 | global_step += 1 644 | 645 | 646 | if __name__ == '__main__': 647 | import taichi as ti 648 | 649 | ti.init(arch=ti.cuda, device_memory_GB=6.0) 650 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 651 | import ipdb 652 | 653 | try: 654 | train() 655 | except Exception as e: 656 | print(e) 657 | ipdb.post_mortem() 658 | 659 | -------------------------------------------------------------------------------- /run_nerf_helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # torch.autograd.set_detect_anomaly(True) 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | # from hash_encoding import HashEmbedder, SHEncoder 8 | 9 | # Misc 10 | img2mse = lambda x, y: torch.mean((x - y) ** 2) 11 | mse2psnr = lambda x: -10. * torch.log(x) / torch.log(torch.Tensor([10.])) 12 | to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) 13 | 14 | 15 | def batchify_query(inputs, query_function, batch_size=2 ** 22): 16 | """ 17 | args: 18 | inputs: [..., input_dim] 19 | return: 20 | outputs: [..., output_dim] 21 | """ 22 | input_dim = inputs.shape[-1] 23 | input_shape = inputs.shape 24 | inputs = inputs.view(-1, input_dim) # flatten all but last dim 25 | N = inputs.shape[0] 26 | outputs = [] 27 | for i in range(0, N, batch_size): 28 | output = query_function(inputs[i:i + batch_size]) 29 | if isinstance(output, tuple): 30 | output = output[0] 31 | outputs.append(output) 32 | outputs = torch.cat(outputs, dim=0) 33 | return outputs.view(*input_shape[:-1], -1) # unflatten 34 | 35 | 36 | class SineLayer(nn.Module): 37 | # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of omega_0. 38 | 39 | # If is_first=True, omega_0 is a frequency factor which simply multiplies the activations before the 40 | # nonlinearity. Different signals may require different omega_0 in the first layer - this is a 41 | # hyperparameter. 42 | 43 | # If is_first=False, then the weights will be divided by omega_0 so as to keep the magnitude of 44 | # activations constant, but boost gradients to the weight matrix (see supplement Sec. 
1.5) 45 | 46 | def __init__(self, in_features, out_features, bias=True, 47 | is_first=False, omega_0=30): 48 | super().__init__() 49 | self.omega_0 = omega_0 50 | self.is_first = is_first 51 | 52 | self.in_features = in_features 53 | self.linear = nn.Linear(in_features, out_features, bias=bias) 54 | 55 | self.init_weights() 56 | 57 | def init_weights(self): 58 | with torch.no_grad(): 59 | if self.is_first: 60 | self.linear.weight.uniform_(-1 / self.in_features, 61 | 1 / self.in_features) 62 | else: 63 | self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0, 64 | np.sqrt(6 / self.in_features) / self.omega_0) 65 | 66 | def forward(self, input): 67 | return torch.sin(self.omega_0 * self.linear(input)) 68 | 69 | 70 | class SirenNeRF(nn.Module): 71 | def __init__(self, D=8, W=256, input_ch=3, output_ch=4, skips=[4], 72 | first_omega_0=30, hidden_omega_0=1): 73 | """ 74 | """ 75 | super(SirenNeRF, self).__init__() 76 | self.D = D 77 | self.W = W 78 | self.input_ch = input_ch 79 | self.skips = skips 80 | 81 | self.pts_linears = nn.ModuleList([SineLayer(input_ch, W, omega_0=first_omega_0, is_first=True)] \ 82 | + [SineLayer(W, W, omega_0=hidden_omega_0) for i in range(D - 1)]) 83 | 84 | self.output_linear = nn.Linear(W, output_ch) 85 | # with torch.no_grad(): 86 | # self.output_linear.weight.uniform_(-np.sqrt(6 / W) / hidden_omega_0, 87 | # np.sqrt(6 / W) / hidden_omega_0) 88 | 89 | def forward(self, x): 90 | input_pts = x # torch.split(x, [self.input_ch, self.input_ch_views], dim=-1) 91 | 92 | h = input_pts 93 | for i, l in enumerate(self.pts_linears): 94 | h = self.pts_linears[i](h) 95 | 96 | outputs = self.output_linear(h) 97 | return outputs 98 | 99 | 100 | # Positional encoding (section 5.1) 101 | class Embedder: 102 | def __init__(self, **kwargs): 103 | self.kwargs = kwargs 104 | self.create_embedding_fn() 105 | 106 | def create_embedding_fn(self): 107 | embed_fns = [] 108 | d = self.kwargs['input_dims'] 109 | out_dim = 0 110 | if self.kwargs['include_input']: 111 | embed_fns.append(lambda x: x) 112 | out_dim += d 113 | 114 | max_freq = self.kwargs['max_freq_log2'] 115 | N_freqs = self.kwargs['num_freqs'] 116 | 117 | if self.kwargs['log_sampling']: 118 | freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs) 119 | else: 120 | freq_bands = torch.linspace(2. ** 0., 2. 
** max_freq, steps=N_freqs) 121 | 122 | for freq in freq_bands: 123 | for p_fn in self.kwargs['periodic_fns']: 124 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 125 | out_dim += d 126 | 127 | self.embed_fns = embed_fns 128 | self.out_dim = out_dim 129 | 130 | def embed(self, inputs): 131 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 132 | 133 | 134 | def get_embedder(multires, args, i=0): 135 | if i == -1: 136 | return nn.Identity(), 3 137 | elif i == 0: 138 | embed_kwargs = { 139 | 'include_input': True, 140 | 'input_dims': 3, 141 | 'max_freq_log2': multires - 1, 142 | 'num_freqs': multires, 143 | 'log_sampling': True, 144 | 'periodic_fns': [torch.sin, torch.cos], 145 | } 146 | 147 | embedder_obj = Embedder(**embed_kwargs) 148 | embed = lambda x, eo=embedder_obj: eo.embed(x) 149 | out_dim = embedder_obj.out_dim 150 | elif i == 1: 151 | embed = HashEmbedder(bounding_box=args.bounding_box, \ 152 | log2_hashmap_size=args.log2_hashmap_size, \ 153 | finest_resolution=args.finest_res) 154 | out_dim = embed.out_dim 155 | elif i == 2: 156 | embed = SHEncoder() 157 | out_dim = embed.out_dim 158 | return embed, out_dim 159 | 160 | 161 | # Small NeRF for Hash embeddings 162 | class NeRFSmall(nn.Module): 163 | def __init__(self, 164 | num_layers=3, 165 | hidden_dim=64, 166 | geo_feat_dim=15, 167 | num_layers_color=2, 168 | hidden_dim_color=16, 169 | input_ch=3, 170 | ): 171 | super(NeRFSmall, self).__init__() 172 | 173 | self.input_ch = input_ch 174 | self.rgb = torch.nn.Parameter(torch.tensor([0.0])) 175 | 176 | # sigma network 177 | self.num_layers = num_layers 178 | self.hidden_dim = hidden_dim 179 | self.geo_feat_dim = geo_feat_dim 180 | 181 | sigma_net = [] 182 | for l in range(num_layers): 183 | if l == 0: 184 | in_dim = self.input_ch 185 | else: 186 | in_dim = hidden_dim 187 | 188 | if l == num_layers - 1: 189 | out_dim = 1 # 1 sigma + 15 SH features for color 190 | else: 191 | out_dim = hidden_dim 192 | 193 | sigma_net.append(nn.Linear(in_dim, out_dim, bias=False)) 194 | 195 | self.sigma_net = nn.ModuleList(sigma_net) 196 | 197 | self.color_net = [] 198 | for l in range(num_layers_color): 199 | if l == 0: 200 | in_dim = 1 201 | else: 202 | in_dim = hidden_dim_color 203 | 204 | if l == num_layers_color - 1: 205 | out_dim = 1 206 | else: 207 | out_dim = hidden_dim_color 208 | 209 | self.color_net.append(nn.Linear(in_dim, out_dim, bias=True)) 210 | 211 | def forward(self, x): 212 | h = x 213 | for l in range(self.num_layers): 214 | h = self.sigma_net[l](h) 215 | h = F.relu(h, inplace=True) 216 | 217 | sigma = h 218 | return sigma 219 | 220 | class NeRFSmall_c(nn.Module): 221 | def __init__(self, 222 | num_layers=3, 223 | hidden_dim=64, 224 | geo_feat_dim=15, 225 | num_layers_color=2, 226 | hidden_dim_color=16, 227 | input_ch=3, 228 | ): 229 | super(NeRFSmall_c, self).__init__() 230 | 231 | self.input_ch = input_ch 232 | self.rgb = torch.nn.Parameter(torch.tensor([0.0])) 233 | 234 | # sigma network 235 | self.num_layers = num_layers 236 | self.hidden_dim = hidden_dim 237 | self.geo_feat_dim = geo_feat_dim 238 | self.num_layers_color = num_layers_color 239 | 240 | sigma_net = [] 241 | for l in range(num_layers): 242 | if l == 0: 243 | in_dim = self.input_ch 244 | else: 245 | in_dim = hidden_dim 246 | 247 | if l == num_layers - 1: 248 | out_dim = 1 + geo_feat_dim # 1 sigma + 15 SH features for color 249 | else: 250 | out_dim = hidden_dim 251 | 252 | sigma_net.append(nn.Linear(in_dim, out_dim, bias=False)) 253 | 254 | self.sigma_net = nn.ModuleList(sigma_net) 255 
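# The sigma branch above ends with 1 + geo_feat_dim outputs: channel 0 is the raw
# density and the remaining geo_feat_dim channels are a geometry feature vector that
# the color head below consumes (forward() returns sigma[..., :1] and feeds
# sigma[..., 1:] into color_net[0]). Minimal usage sketch -- the input width 32 is
# only an illustrative value matching a 16-level x 2-feature hash embedding:
#   model = NeRFSmall_c(input_ch=32)
#   sigma, color = model(torch.rand(4096, 32))  # sigma: [4096, 1], color: [4096, 1]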
| 256 | self.color_net = [] 257 | for l in range(num_layers_color): 258 | if l == 0: 259 | in_dim = geo_feat_dim 260 | else: 261 | in_dim = hidden_dim_color 262 | 263 | if l == num_layers_color - 1: 264 | out_dim = 1 265 | else: 266 | out_dim = hidden_dim_color 267 | 268 | self.color_net.append(nn.Linear(in_dim, out_dim, bias=True)) 269 | self.color_net = nn.ModuleList(self.color_net) 270 | 271 | def forward(self, x): 272 | h = x 273 | for l in range(self.num_layers): 274 | h = self.sigma_net[l](h) 275 | h = F.relu(h, inplace=True) 276 | 277 | sigma = h 278 | color = self.color_net[0](sigma[..., 1:]) 279 | for l in range(1, self.num_layers_color): 280 | color = F.relu(color, inplace=True) 281 | color = self.color_net[l](color) 282 | return sigma[..., :1], color 283 | 284 | 285 | 286 | class NeRFSmall_bg(nn.Module): 287 | def __init__(self, 288 | num_layers=3, 289 | hidden_dim=64, 290 | geo_feat_dim=15, 291 | num_layers_color=2, 292 | hidden_dim_color=16, 293 | input_ch=3, 294 | ): 295 | super(NeRFSmall_bg, self).__init__() 296 | 297 | self.input_ch = input_ch 298 | self.rgb = torch.nn.Parameter(torch.tensor([0.0])) 299 | 300 | # sigma network 301 | self.num_layers = num_layers 302 | self.hidden_dim = hidden_dim 303 | self.geo_feat_dim = geo_feat_dim 304 | 305 | sigma_net = [] 306 | for l in range(num_layers): 307 | if l == 0: 308 | in_dim = self.input_ch 309 | else: 310 | in_dim = hidden_dim 311 | 312 | if l == num_layers - 1: 313 | out_dim = 1 + geo_feat_dim # 1 sigma + 15 SH features for color 314 | else: 315 | out_dim = hidden_dim 316 | 317 | sigma_net.append(nn.Linear(in_dim, out_dim, bias=False)) 318 | 319 | self.sigma_net = nn.ModuleList(sigma_net) 320 | 321 | # color network 322 | self.color_net = [] 323 | for l in range(num_layers_color): 324 | if l == 0: 325 | in_dim = input_ch + geo_feat_dim # 1 for sigma, 15 for SH features 326 | else: 327 | in_dim = hidden_dim_color 328 | 329 | if l == num_layers_color - 1: 330 | out_dim = 3 # RGB color channels 331 | else: 332 | out_dim = hidden_dim_color 333 | 334 | self.color_net.append(nn.Linear(in_dim, out_dim, bias=True)) 335 | 336 | self.color_net = nn.ModuleList(self.color_net) 337 | 338 | def forward(self, x): 339 | h = x 340 | for l in range(self.num_layers): 341 | h = self.sigma_net[l](h) 342 | h = F.relu(h, inplace=True) 343 | 344 | sigma = h[..., :1] 345 | geo_feat = h[..., 1:] 346 | 347 | # color network 348 | h_color = torch.cat([geo_feat, x], dim=-1) # concatenate sigma and SH features 349 | for l in range(len(self.color_net)): 350 | h_color = self.color_net[l](h_color) 351 | if l < len(self.color_net) - 1: 352 | h_color = F.relu(h_color, inplace=True) 353 | 354 | color = torch.sigmoid(h_color) # apply sigmoid activation to get color values in range [0, 1] 355 | 356 | return sigma, color 357 | 358 | 359 | 360 | class NeRFSmallPotential(nn.Module): 361 | def __init__(self, 362 | num_layers=3, 363 | hidden_dim=64, 364 | geo_feat_dim=15, 365 | num_layers_color=2, 366 | hidden_dim_color=16, 367 | input_ch=3, 368 | use_f=False 369 | ): 370 | super(NeRFSmallPotential, self).__init__() 371 | 372 | self.input_ch = input_ch 373 | self.rgb = torch.nn.Parameter(torch.tensor([0.0])) 374 | 375 | # sigma network 376 | self.num_layers = num_layers 377 | self.hidden_dim = hidden_dim 378 | self.geo_feat_dim = geo_feat_dim 379 | 380 | sigma_net = [] 381 | for l in range(num_layers): 382 | if l == 0: 383 | in_dim = self.input_ch 384 | else: 385 | in_dim = hidden_dim 386 | 387 | if l == num_layers - 1: 388 | out_dim = hidden_dim # 1 sigma + 15 SH 
features for color 389 | else: 390 | out_dim = hidden_dim 391 | 392 | sigma_net.append(nn.Linear(in_dim, out_dim, bias=False)) 393 | self.sigma_net = nn.ModuleList(sigma_net) 394 | self.out = nn.Linear(hidden_dim, 3, bias=True) 395 | self.use_f = use_f 396 | if use_f: 397 | self.out_f = nn.Linear(hidden_dim, hidden_dim, bias=True) 398 | self.out_f2 = nn.Linear(hidden_dim, 3, bias=True) 399 | 400 | 401 | def forward(self, x): 402 | h = x 403 | for l in range(self.num_layers): 404 | h = self.sigma_net[l](h) 405 | h = F.relu(h, True) 406 | 407 | v = self.out(h) 408 | if self.use_f: 409 | f = self.out_f(h) 410 | f = F.relu(f, True) 411 | f = self.out_f2(f) 412 | else: 413 | f = v * 0 414 | return v, f 415 | 416 | 417 | def save_quiver_plot(u, v, res, save_path, scale=0.00000002): 418 | """ 419 | Args: 420 | u: [H, W], vel along x (W) 421 | v: [H, W], vel along y (H) 422 | res: resolution of the plot along the longest axis; if None, let step = 1 423 | save_path: 424 | """ 425 | import matplotlib.pyplot as plt 426 | import matplotlib 427 | H, W = u.shape 428 | y, x = np.mgrid[0:H, 0:W] 429 | axis_len = max(H, W) 430 | step = 1 if res is None else axis_len // res 431 | xq = [i[::step] for i in x[::step]] 432 | yq = [i[::step] for i in y[::step]] 433 | uq = [i[::step] for i in u[::step]] 434 | vq = [i[::step] for i in v[::step]] 435 | 436 | uv_norm = np.sqrt(np.array(uq) ** 2 + np.array(vq) ** 2).max() 437 | short_len = min(H, W) 438 | matplotlib.rcParams['font.size'] = 10 / short_len * axis_len 439 | fig, ax = plt.subplots(figsize=(10 / short_len * W, 10 / short_len * H)) 440 | q = ax.quiver(xq, yq, uq, vq, pivot='tail', angles='uv', scale_units='xy', scale=scale / step) 441 | ax.invert_yaxis() 442 | plt.quiverkey(q, X=0.6, Y=1.05, U=uv_norm, label=f'Max arrow length = {uv_norm:.2g}', labelpos='E') 443 | plt.savefig(save_path) 444 | plt.close() 445 | return 446 | 447 | 448 | # Ray helpers 449 | def get_rays(H, W, K, c2w): 450 | i, j = torch.meshgrid(torch.linspace(0, W - 1, W), torch.linspace(0, H - 1, H), 451 | indexing='ij') # pytorch's meshgrid has indexing='ij' 452 | i = i.t() 453 | j = j.t() 454 | dirs = torch.stack([(i - K[0][2]) / K[0][0], -(j - K[1][2]) / K[1][1], -torch.ones_like(i)], -1) 455 | # Rotate ray directions from camera frame to the world frame 456 | rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], 457 | -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 458 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 459 | rays_o = c2w[:3, -1].expand(rays_d.shape) 460 | return rays_o, rays_d 461 | 462 | 463 | def get_rays_np(H, W, K, c2w): 464 | i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy') 465 | dirs = np.stack([(i - K[0][2]) / K[0][0], -(j - K[1][2]) / K[1][1], -np.ones_like(i)], -1) 466 | # Rotate ray directions from camera frame to the world frame 467 | rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], 468 | -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 469 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 
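# np.broadcast_to below just repeats the camera position c2w[:3, -1] across the
# H x W grid (a view, not a copy), so rays_o matches rays_d in shape and any point
# along a ray is rays_o + t * rays_d.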
470 | rays_o = np.broadcast_to(c2w[:3, -1], np.shape(rays_d)) 471 | return rays_o, rays_d 472 | 473 | 474 | def get_rays_np_continuous(H, W, K, c2w): 475 | i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy') 476 | random_offset_i = np.random.uniform(0, 1, size=(H, W)) 477 | random_offset_j = np.random.uniform(0, 1, size=(H, W)) 478 | i = i + random_offset_i 479 | j = j + random_offset_j 480 | i = np.clip(i, 0, W - 1) 481 | j = np.clip(j, 0, H - 1) 482 | 483 | dirs = np.stack([(i - K[0][2]) / K[0][0], -(j - K[1][2]) / K[1][1], -np.ones_like(i)], -1) 484 | # Rotate ray directions from camera frame to the world frame 485 | rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], 486 | -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 487 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 488 | rays_o = np.broadcast_to(c2w[:3, -1], np.shape(rays_d)) 489 | return rays_o, rays_d, i, j 490 | 491 | 492 | def ndc_rays(H, W, focal, near, rays_o, rays_d): 493 | # Shift ray origins to near plane 494 | t = -(near + rays_o[..., 2]) / rays_d[..., 2] 495 | rays_o = rays_o + t[..., None] * rays_d 496 | 497 | # Projection 498 | o0 = -1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2] 499 | o1 = -1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2] 500 | o2 = 1. + 2. * near / rays_o[..., 2] 501 | 502 | d0 = -1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2]) 503 | d1 = -1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2]) 504 | d2 = -2. * near / rays_o[..., 2] 505 | 506 | rays_o = torch.stack([o0, o1, o2], -1) 507 | rays_d = torch.stack([d0, d1, d2], -1) 508 | 509 | return rays_o, rays_d 510 | 511 | 512 | def sample_bilinear(img, xy): 513 | """ 514 | Sample image with bilinear interpolation 515 | :param img: (T, V, H, W, 3) 516 | :param xy: (V, 2, H, W) 517 | :return: img: (T, V, H, W, 3) 518 | """ 519 | T, V, H, W, _ = img.shape 520 | u, v = xy[:, 0], xy[:, 1] 521 | 522 | u = np.clip(u, 0, W - 1) 523 | v = np.clip(v, 0, H - 1) 524 | 525 | u_floor, v_floor = np.floor(u).astype(int), np.floor(v).astype(int) 526 | u_ceil, v_ceil = np.ceil(u).astype(int), np.ceil(v).astype(int) 527 | 528 | u_ratio, v_ratio = u - u_floor, v - v_floor 529 | u_ratio, v_ratio = u_ratio[None, ..., None], v_ratio[None, ..., None] 530 | 531 | bottom_left = img[:, np.arange(V)[:, None, None], v_floor, u_floor] 532 | bottom_right = img[:, np.arange(V)[:, None, None], v_floor, u_ceil] 533 | top_left = img[:, np.arange(V)[:, None, None], v_ceil, u_floor] 534 | top_right = img[:, np.arange(V)[:, None, None], v_ceil, u_ceil] 535 | 536 | bottom = (1 - u_ratio) * bottom_left + u_ratio * bottom_right 537 | top = (1 - u_ratio) * top_left + u_ratio * top_right 538 | 539 | interpolated = (1 - v_ratio) * bottom + v_ratio * top 540 | 541 | return interpolated 542 | 543 | 544 | # Hierarchical sampling (section 5.2) 545 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False): 546 | # Get pdf 547 | weights = weights + 1e-5 # prevent nans 548 | pdf = weights / torch.sum(weights, -1, keepdim=True) 549 | cdf = torch.cumsum(pdf, -1) 550 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # (batch, len(bins)) 551 | 552 | # Take uniform samples 553 | if det: 554 | u = torch.linspace(0., 1., steps=N_samples) 555 | u = u.expand(list(cdf.shape[:-1]) + [N_samples]) 556 | else: 557 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples]) 558 | 559 | # Pytest, 
overwrite u with numpy's fixed random numbers 560 | if pytest: 561 | np.random.seed(0) 562 | new_shape = list(cdf.shape[:-1]) + [N_samples] 563 | if det: 564 | u = np.linspace(0., 1., N_samples) 565 | u = np.broadcast_to(u, new_shape) 566 | else: 567 | u = np.random.rand(*new_shape) 568 | u = torch.Tensor(u) 569 | 570 | # Invert CDF 571 | u = u.contiguous() 572 | inds = torch.searchsorted(cdf, u, right=True) 573 | below = torch.max(torch.zeros_like(inds - 1), inds - 1) 574 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) 575 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 576 | 577 | # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 578 | # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 579 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 580 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 581 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 582 | 583 | denom = (cdf_g[..., 1] - cdf_g[..., 0]) 584 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) 585 | t = (u - cdf_g[..., 0]) / denom 586 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) 587 | 588 | return samples 589 | -------------------------------------------------------------------------------- /scripts/test_future_pred.sh: -------------------------------------------------------------------------------- 1 | python run_nerf_vort.py --config configs/scalarflowreal.txt --lrate 0.01 --lrate_den 1e-4 --lrate_decay 5000 --N_iters 10000 --i_weights 5000 \ 2 | --expname exp_real/vort50 --finest_resolution 256 --base_resolution 16 --finest_resolution_t 128 --base_resolution_t 16 --num_levels 16 --N_samples 192 --N_rand 512 --log2_hashmap_size 19 --vel_num_layers 2 \ 3 | --ft_path ./logs/exp_real/p_v128_128/den/100000.tar \ 4 | --vel_path ./logs/exp_real/p_v128_128/100000.tar --no_vel_der --vel_scale 0.05 \ 5 | --finest_resolution_v 128 --base_resolution_v 16 --finest_resolution_v_t 128 --base_resolution_v_t 16 \ 6 | --n_particles 50 --vort_intensity 5 --vort_weight 0.01 --run_future_pred -------------------------------------------------------------------------------- /scripts/test_resim.sh: -------------------------------------------------------------------------------- 1 | python run_nerf_vort.py --config configs/scalarflowreal.txt --lrate 0.01 --lrate_den 1e-4 --lrate_decay 5000 --N_iters 10000 --i_weights 5000 \ 2 | --expname exp_real/vort50 --finest_resolution 256 --base_resolution 16 --finest_resolution_t 128 --base_resolution_t 16 --num_levels 16 --N_samples 192 --N_rand 512 --log2_hashmap_size 19 --vel_num_layers 2 \ 3 | --ft_path ./logs/exp_real/p_v128_128/den/100000.tar \ 4 | --vel_path ./logs/exp_real/p_v128_128/100000.tar --no_vel_der --vel_scale 0.05 \ 5 | --finest_resolution_v 128 --base_resolution_v 16 --finest_resolution_v_t 128 --base_resolution_v_t 16 \ 6 | --n_particles 50 --vort_intensity 5 --vort_weight 0.01 --run_advect_den -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | python run_nerf_density.py --config configs/scalarflowreal.txt --lrate 0.01 \ 2 | --lrate_decay 100000 --N_iters 300000 --i_weights 100000 --N_time 1 \ 3 | --expname exp_real/density_256_128 --i_video 100000 --finest_resolution 256 \ 4 | --base_resolution 16 --finest_resolution_t 128 --base_resolution_t 16 --num_levels 16 
--N_samples 192 --N_rand 256 --log2_hashmap_size 19 5 | -------------------------------------------------------------------------------- /scripts/train_j.sh: -------------------------------------------------------------------------------- 1 | python run_nerf_jointly.py --config configs/scalarflowreal.txt --lrate_decay 100000 --N_iters 100000 --i_weights 10000 \ 2 | --expname exp_real/p_v128_128 --lrate 5e-4 --lrate_den 1e-4 --rec_weight 10000 --d2v_weight 10 --coef_den2vel 0.2 --vel_weight 1 --d_weight 0 --proj_weight 1 --flow_weight 0.001 --vel_num_layers 2 --i_video 10000 --i_print 100 --finest_resolution 256 --base_resolution 16 --finest_resolution_t 128 --base_resolution_t 16 --num_levels 16 --N_samples 192 --N_rand 512 --log2_hashmap_size 19 \ 3 | --ft_path ./logs/exp_real/density_256_128/300000.tar --vel_scale 0.025 \ 4 | --finest_resolution_v 128 --base_resolution_v 16 --finest_resolution_v_t 128 --base_resolution_v_t 16 --no_vel_der -------------------------------------------------------------------------------- /scripts/train_vort.sh: -------------------------------------------------------------------------------- 1 | python run_nerf_vort.py --config configs/scalarflowreal.txt --lrate 0.01 --lrate_den 1e-4 --lrate_decay 5000 --N_iters 10000 --i_weights 5000 \ 2 | --expname exp_real/vort50 --i_video 1000 --i_print 100 --finest_resolution 256 --base_resolution 16 --finest_resolution_t 128 --base_resolution_t 16 --num_levels 16 --N_samples 192 --N_rand 512 --log2_hashmap_size 19 --rec_weight 10000 \ 3 | --rec_weight 10000 --vel_weight 0 --d_weight 0 --flow_weight 0.001 --vel_num_layers 2 \ 4 | --ft_path ./logs/exp_real/p_v128_128/den/100000.tar \ 5 | --vel_path ./logs/exp_real/p_v128_128/100000.tar --no_vel_der --vel_scale 0.05 \ 6 | --finest_resolution_v 128 --base_resolution_v 16 --finest_resolution_v_t 128 --base_resolution_v_t 16 \ 7 | --n_particles 50 --vort_intensity 5 --vort_weight 0.01 -------------------------------------------------------------------------------- /taichi_encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/y-zheng18/HyFluid/70b4b962c66d0371e366ad6ebbb61af562b08846/taichi_encoders/__init__.py -------------------------------------------------------------------------------- /taichi_encoders/hash4.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import taichi as ti 3 | import torch 4 | from taichi.math import uvec3, uvec4 5 | from torch.cuda.amp import custom_bwd, custom_fwd 6 | 7 | from torch.func import jacrev, vmap 8 | 9 | from .utils import (data_type, ti2torch, ti2torch_grad, ti2torch_grad_vec, 10 | ti2torch_vec, torch2ti, torch2ti_grad, torch2ti_grad_vec, 11 | torch2ti_vec, torch_type) 12 | 13 | 14 | @ti.kernel 15 | def random_initialize(data: ti.types.ndarray()): 16 | for I in ti.grouped(data): 17 | data[I] = (ti.random() * 2.0 - 1.0) * 1e-4 18 | # data[I] = ti.random() * 0.4 + 0.1 19 | 20 | 21 | @ti.func 22 | def fast_hash(pos_grid_local): 23 | result = ti.uint32(0) 24 | # primes = uvec3(ti.uint32(1), ti.uint32(1958374283), ti.uint32(2654435761)) 805459861u, 3674653429u 25 | primes = uvec4(ti.uint32(1), ti.uint32(2654435761), ti.uint32(805459861), ti.uint32(3674653429)) 26 | for i in ti.static(range(4)): 27 | result ^= ti.uint32(pos_grid_local[i]) * primes[i] 28 | return result 29 | 30 | 31 | # ravel (i, j, k, t) to i + i_dim * j + (i_dim * j_dim) * k + (i_dim * j_dim * k_dim) * t 32 | @ti.func 33 | def 
under_hash(pos_grid_local, resolution): 34 | result = ti.uint32(0) 35 | stride = ti.uint32(1) 36 | for i in ti.static(range(4)): 37 | result += ti.uint32(pos_grid_local[i] * stride) 38 | stride *= resolution[i] + 1 # note the +1 here, because 256 x 256 grid actually has 257 x 257 entries 39 | return result 40 | 41 | 42 | @ti.func 43 | def grid_pos2hash_index(indicator, pos_grid_local, plane_res, map_size): 44 | hash_result = ti.uint32(0) 45 | if indicator == 1: 46 | hash_result = under_hash(pos_grid_local, plane_res) 47 | else: 48 | hash_result = fast_hash(pos_grid_local) 49 | 50 | return hash_result % map_size 51 | 52 | 53 | @ti.func 54 | def smooth_step(t): 55 | return t * t * (3 - 2 * t) 56 | 57 | 58 | @ti.func 59 | def d_smooth_step(t): 60 | return 6 * t * (1 - t) 61 | 62 | 63 | @ti.func 64 | def linear_step(t): 65 | return t 66 | 67 | 68 | @ti.func 69 | def d_linear_step(t): 70 | return 1 71 | 72 | 73 | @ti.func 74 | def isnan(x): 75 | return not (x < 0 or 0 < x or x == 0) 76 | 77 | @ti.kernel 78 | def hash_encode_kernel_smoothstep( 79 | xyzts: ti.template(), table: ti.template(), 80 | xyzts_embedding: ti.template(), hash_map_indicator: ti.template(), 81 | hash_map_sizes_field: ti.template(), hash_map_shapes_field: ti.template(), 82 | offsets: ti.template(), B: ti.i32, num_scales: ti.i32): 83 | # # # get hash table embedding 84 | ti.loop_config(block_dim=16) 85 | for i, level in ti.ndrange(B, num_scales): 86 | res_x = hash_map_shapes_field[level, 0] 87 | res_y = hash_map_shapes_field[level, 1] 88 | res_z = hash_map_shapes_field[level, 2] 89 | res_t = hash_map_shapes_field[level, 3] 90 | plane_res = ti.Vector([res_x, res_y, res_z, res_t]) 91 | pos = ti.Vector([xyzts[i, 0], xyzts[i, 1], xyzts[i, 2], xyzts[i, 3]]) * plane_res 92 | 93 | pos_grid_uint = ti.cast(ti.floor(pos), ti.uint32) # floor 94 | pos_grid_uint = ti.math.clamp(pos_grid_uint, 0, plane_res - 1) 95 | pos -= pos_grid_uint # pos now represents frac 96 | pos = ti.math.clamp(pos, 0.0, 1.0) 97 | 98 | offset = offsets[level] 99 | 100 | indicator = hash_map_indicator[level] 101 | map_size = hash_map_sizes_field[level] 102 | 103 | local_feature_0 = 0.0 104 | local_feature_1 = 0.0 105 | 106 | for idx in ti.static(range(16)): 107 | w = 1. 
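# Quadrilinear interpolation over the 2^4 = 16 corners of the enclosing 4D cell:
# each bit of idx picks the floor or ceil corner along one axis, w accumulates the
# per-axis weight (smooth_step(frac) on the ceil side, 1 - smooth_step(frac) on the
# floor side), and grid_pos2hash_index maps the corner either to a dense ravel
# (when the level fits in the table) or to a spatial hash slot.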
108 | pos_grid_local = uvec4(0) 109 | 110 | for d in ti.static(range(4)): 111 | t = smooth_step(pos[d]) 112 | if (idx & (1 << d)) == 0: 113 | pos_grid_local[d] = pos_grid_uint[d] 114 | w *= 1 - t 115 | else: 116 | pos_grid_local[d] = pos_grid_uint[d] + 1 117 | w *= t 118 | 119 | index = grid_pos2hash_index(indicator, pos_grid_local, plane_res, map_size) 120 | index_table = offset + index * 2 # the flat index for the 1st entry 121 | index_table_int = ti.cast(index_table, ti.int32) 122 | local_feature_0 += w * table[index_table_int] 123 | local_feature_1 += w * table[index_table_int + 1] 124 | 125 | xyzts_embedding[i, level * 2] = local_feature_0 126 | xyzts_embedding[i, level * 2 + 1] = local_feature_1 127 | 128 | 129 | @ti.kernel 130 | def hash_encode_kernel( 131 | xyzts: ti.template(), table: ti.template(), 132 | xyzts_embedding: ti.template(), hash_map_indicator: ti.template(), 133 | hash_map_sizes_field: ti.template(), hash_map_shapes_field: ti.template(), 134 | offsets: ti.template(), B: ti.i32, num_scales: ti.i32): 135 | # # # get hash table embedding 136 | ti.loop_config(block_dim=16) 137 | for i, level in ti.ndrange(B, num_scales): 138 | res_x = hash_map_shapes_field[level, 0] 139 | res_y = hash_map_shapes_field[level, 1] 140 | res_z = hash_map_shapes_field[level, 2] 141 | res_t = hash_map_shapes_field[level, 3] 142 | plane_res = ti.Vector([res_x, res_y, res_z, res_t]) 143 | pos = ti.Vector([xyzts[i, 0], xyzts[i, 1], xyzts[i, 2], xyzts[i, 3]]) * plane_res 144 | 145 | pos_grid_uint = ti.cast(ti.floor(pos), ti.uint32) # floor 146 | pos_grid_uint = ti.math.clamp(pos_grid_uint, 0, plane_res - 1) 147 | pos -= pos_grid_uint # pos now represents frac 148 | pos = ti.math.clamp(pos, 0.0, 1.0) 149 | 150 | offset = offsets[level] 151 | 152 | indicator = hash_map_indicator[level] 153 | map_size = hash_map_sizes_field[level] 154 | 155 | local_feature_0 = 0.0 156 | local_feature_1 = 0.0 157 | 158 | for idx in ti.static(range(16)): 159 | w = 1. 
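# Same 16-corner blend as the smoothstep kernel above, but with linear_step, i.e.
# plain multilinear interpolation of the two per-level features.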
160 | pos_grid_local = uvec4(0) 161 | 162 | for d in ti.static(range(4)): 163 | t = linear_step(pos[d]) 164 | if (idx & (1 << d)) == 0: 165 | pos_grid_local[d] = pos_grid_uint[d] 166 | w *= 1 - t 167 | else: 168 | pos_grid_local[d] = pos_grid_uint[d] + 1 169 | w *= t 170 | 171 | index = grid_pos2hash_index(indicator, pos_grid_local, plane_res, map_size) 172 | index_table = offset + index * 2 # the flat index for the 1st entry 173 | index_table_int = ti.cast(index_table, ti.int32) 174 | local_feature_0 += w * table[index_table_int] 175 | local_feature_1 += w * table[index_table_int + 1] 176 | 177 | xyzts_embedding[i, level * 2] = local_feature_0 178 | xyzts_embedding[i, level * 2 + 1] = local_feature_1 179 | 180 | 181 | @ti.kernel 182 | def hash_encode_kernel_grad( 183 | xyzts: ti.template(), table: ti.template(), 184 | xyzts_embedding: ti.template(), hash_map_indicator: ti.template(), 185 | hash_map_sizes_field: ti.template(), hash_map_shapes_field: ti.template(), 186 | offsets: ti.template(), B: ti.i32, num_scales: ti.i32, xyzts_grad: ti.template(), table_grad: ti.template(), 187 | output_grad: ti.template()): 188 | # # # get hash table embedding 189 | 190 | ti.loop_config(block_dim=16) 191 | for i, level in ti.ndrange(B, num_scales): 192 | res_x = hash_map_shapes_field[level, 0] 193 | res_y = hash_map_shapes_field[level, 1] 194 | res_z = hash_map_shapes_field[level, 2] 195 | res_t = hash_map_shapes_field[level, 3] 196 | plane_res = ti.Vector([res_x, res_y, res_z, res_t]) 197 | pos = ti.Vector([xyzts[i, 0], xyzts[i, 1], xyzts[i, 2], xyzts[i, 3]]) * plane_res 198 | 199 | pos_grid_uint = ti.cast(ti.floor(pos), ti.uint32) # floor 200 | pos_grid_uint = ti.math.clamp(pos_grid_uint, 0, plane_res - 1) 201 | pos -= pos_grid_uint # pos now represents frac 202 | pos = ti.math.clamp(pos, 0.0, 1.0) 203 | 204 | offset = offsets[level] 205 | 206 | indicator = hash_map_indicator[level] 207 | map_size = hash_map_sizes_field[level] 208 | 209 | local_feature_0 = 0.0 210 | local_feature_1 = 0.0 211 | 212 | for idx in ti.static(range(16)): 213 | w = 1. 
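# Hand-written backward pass for the linear kernel: for every corner, table_grad
# accumulates w * dL/d(feature) for the two features stored at that slot, and
# xyzts_grad applies the product rule -- dw[d], the derivative of this axis' weight,
# times the three remaining axes' weights -- scaled by plane_res[d] because the
# input was multiplied by plane_res before interpolation.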
214 | pos_grid_local = uvec4(0) 215 | dw = ti.Vector([0., 0., 0., 0.]) 216 | # prods = ti.Vector([0., 0., 0.,0.]) 217 | for d in ti.static(range(4)): 218 | t = linear_step(pos[d]) 219 | dt = d_linear_step(pos[d]) 220 | if (idx & (1 << d)) == 0: 221 | pos_grid_local[d] = pos_grid_uint[d] 222 | w *= 1 - t 223 | dw[d] = -dt 224 | 225 | else: 226 | pos_grid_local[d] = pos_grid_uint[d] + 1 227 | w *= t 228 | dw[d] = dt 229 | 230 | index = grid_pos2hash_index(indicator, pos_grid_local, plane_res, map_size) 231 | index_table = offset + index * 2 # the flat index for the 1st entry 232 | index_table_int = ti.cast(index_table, ti.int32) 233 | table_grad[index_table_int] += w * output_grad[i, 2 * level] 234 | table_grad[index_table_int + 1] += w * output_grad[i, 2 * level + 1] 235 | for d in ti.static(range(4)): 236 | # eps = 1e-15 237 | # prod = w / ((linear_step(pos[d]) if idx & (1 << d) > 0 else 1 - linear_step(pos[d])) + eps) 238 | # prod=1.0 239 | # for k in range(4): 240 | # if k == d: 241 | # prod *= dw[k] 242 | # else: 243 | # prod *= 1- linear_step(pos[k]) if (idx & (1 << k) == 0) else linear_step(pos[k]) 244 | prod = dw[d] * ( 245 | linear_step(pos[(d + 1) % 4]) if (idx & (1 << ((d + 1) % 4)) > 0) else 1 - linear_step( 246 | pos[(d + 1) % 4]) 247 | ) * ( 248 | linear_step(pos[(d + 2) % 4]) if (idx & (1 << ((d + 2) % 4)) > 0) else 1 - linear_step( 249 | pos[(d + 2) % 4]) 250 | ) * ( 251 | linear_step(pos[(d + 3) % 4]) if (idx & (1 << ((d + 3) % 4)) > 0) else 1 - linear_step( 252 | pos[(d + 3) % 4]) 253 | ) 254 | xyzts_grad[i, d] += table[index_table_int] * prod * plane_res[d] * output_grad[i, 2 * level] 255 | xyzts_grad[i, d] += table[index_table_int + 1] * prod * plane_res[d] * output_grad[i, 2 * level + 1] 256 | 257 | 258 | @ti.kernel 259 | def hash_encode_kernel_smoothstep_grad( 260 | xyzts: ti.template(), table: ti.template(), 261 | xyzts_embedding: ti.template(), hash_map_indicator: ti.template(), 262 | hash_map_sizes_field: ti.template(), hash_map_shapes_field: ti.template(), 263 | offsets: ti.template(), B: ti.i32, num_scales: ti.i32, xyzts_grad: ti.template(), table_grad: ti.template(), 264 | output_grad: ti.template()): 265 | # # # get hash table embedding 266 | 267 | ti.loop_config(block_dim=16) 268 | for i, level in ti.ndrange(B, num_scales): 269 | res_x = hash_map_shapes_field[level, 0] 270 | res_y = hash_map_shapes_field[level, 1] 271 | res_z = hash_map_shapes_field[level, 2] 272 | res_t = hash_map_shapes_field[level, 3] 273 | plane_res = ti.Vector([res_x, res_y, res_z, res_t]) 274 | pos = ti.Vector([xyzts[i, 0], xyzts[i, 1], xyzts[i, 2], xyzts[i, 3]]) * plane_res 275 | 276 | pos_grid_uint = ti.cast(ti.floor(pos), ti.uint32) # floor 277 | pos_grid_uint = ti.math.clamp(pos_grid_uint, 0, plane_res - 1) 278 | pos -= pos_grid_uint # pos now represents frac 279 | pos = ti.math.clamp(pos, 0.0, 1.0) 280 | 281 | offset = offsets[level] 282 | 283 | indicator = hash_map_indicator[level] 284 | map_size = hash_map_sizes_field[level] 285 | 286 | local_feature_0 = 0.0 287 | local_feature_1 = 0.0 288 | 289 | for idx in ti.static(range(16)): 290 | w = 1. 
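# Same gradient computation as hash_encode_kernel_grad above, with smooth_step /
# d_smooth_step replacing the linear ramp and its unit derivative.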
291 | pos_grid_local = uvec4(0) 292 | dw = ti.Vector([0., 0., 0., 0.]) 293 | # prods = ti.Vector([0., 0., 0.,0.]) 294 | for d in ti.static(range(4)): 295 | t = smooth_step(pos[d]) 296 | dt = d_smooth_step(pos[d]) 297 | if (idx & (1 << d)) == 0: 298 | pos_grid_local[d] = pos_grid_uint[d] 299 | w *= 1 - t 300 | dw[d] = -dt 301 | 302 | else: 303 | pos_grid_local[d] = pos_grid_uint[d] + 1 304 | w *= t 305 | dw[d] = dt 306 | 307 | index = grid_pos2hash_index(indicator, pos_grid_local, plane_res, map_size) 308 | index_table = offset + index * 2 # the flat index for the 1st entry 309 | index_table_int = ti.cast(index_table, ti.int32) 310 | table_grad[index_table_int] += w * output_grad[i, 2 * level] 311 | table_grad[index_table_int + 1] += w * output_grad[i, 2 * level + 1] 312 | for d in ti.static(range(4)): 313 | # eps = 1e-15 314 | # prod = w / ((smooth_step(pos[d]) if idx & (1 << d) > 0 else 1 - smooth_step(pos[d])) + eps) 315 | # prod=1.0 316 | # for k in range(4): 317 | # if k == d: 318 | # prod *= dw[k] 319 | # else: 320 | # prod *= 1- smooth_step(pos[k]) if (idx & (1 << k) == 0) else smooth_step(pos[k]) 321 | prod = dw[d] * ( 322 | smooth_step(pos[(d + 1) % 4]) if (idx & (1 << ((d + 1) % 4)) > 0) else 1 - smooth_step( 323 | pos[(d + 1) % 4]) 324 | ) * ( 325 | smooth_step(pos[(d + 2) % 4]) if (idx & (1 << ((d + 2) % 4)) > 0) else 1 - smooth_step( 326 | pos[(d + 2) % 4]) 327 | ) * ( 328 | smooth_step(pos[(d + 3) % 4]) if (idx & (1 << ((d + 3) % 4)) > 0) else 1 - smooth_step( 329 | pos[(d + 3) % 4]) 330 | ) 331 | xyzts_grad[i, d] += table[index_table_int] * prod * plane_res[d] * output_grad[i, 2 * level] 332 | xyzts_grad[i, d] += table[index_table_int + 1] * prod * plane_res[d] * output_grad[i, 2 * level + 1] 333 | 334 | 335 | class Hash4Encoder(torch.nn.Module): 336 | def __init__(self, 337 | max_res=np.array([512, 512, 512, 512]), 338 | min_res=np.array([16, 16, 16, 16]), 339 | num_scales=16, 340 | max_num_queries=10000000, 341 | data_type=data_type, 342 | max_params=2 ** 19, 343 | interpolation='linear' 344 | ): 345 | super(Hash4Encoder, self).__init__() 346 | 347 | b = np.exp((np.log(max_res) - np.log(min_res)) / (num_scales - 1)) 348 | 349 | self.num_scales = num_scales 350 | self.interpolation = interpolation 351 | self.offsets = ti.field(ti.i32, shape=(num_scales,)) 352 | self.hash_map_sizes_field = ti.field(ti.uint32, shape=(num_scales,)) 353 | self.hash_map_shapes_field = ti.field(ti.uint32, shape=(num_scales, 4)) 354 | self.hash_map_indicator = ti.field(ti.i32, shape=(num_scales,)) 355 | 356 | offset_ = 0 357 | hash_map_sizes = [] 358 | hash_map_shapes = [] 359 | for i in range(num_scales): # loop through each level 360 | res = np.ceil(min_res * np.power(b, i)).astype(int) 361 | hash_map_shapes.append(res) 362 | params_in_level_raw = (res[0] + 1) * (res[1] + 1) * ( 363 | res[2] + 1) * (res[3] + 1) # number of params required to store everything 364 | params_in_level = int(params_in_level_raw) if params_in_level_raw % 8 == 0 \ 365 | else int((params_in_level_raw + 8 - 1) / 8) * 8 # make sure is multiple of 8 366 | # if max_params has enough space, store everything; otherwise store as much as we can 367 | params_in_level = min(max_params, params_in_level) 368 | hash_map_sizes.append(params_in_level) 369 | self.hash_map_indicator[ 370 | i] = 1 if params_in_level_raw <= params_in_level else 0 # i if have stored everything, 0 if collision 371 | self.offsets[i] = offset_ 372 | offset_ += params_in_level * 2 # multiply by two because we store 2 features per entry 373 | print("hash map 
sizes", hash_map_sizes) 374 | print("hash map shapes", hash_map_shapes) 375 | print("offsets", self.offsets.to_numpy()) 376 | print("hash map indicator", self.hash_map_indicator.to_numpy()) 377 | size = np.uint32(np.array(hash_map_sizes)) 378 | self.hash_map_sizes_field.from_numpy(size) 379 | shape = np.uint32(np.array(hash_map_shapes)) 380 | self.hash_map_shapes_field.from_numpy(shape) 381 | 382 | self.total_hash_size = offset_ 383 | 384 | # the main storage, pytorch 385 | self.hash_table = torch.nn.Parameter(torch.zeros(self.total_hash_size, 386 | dtype=torch_type), 387 | requires_grad=True) 388 | random_initialize(self.hash_table) # randomly initialize 389 | 390 | # the taichi counterpart of self.hash_table 391 | self.parameter_fields = ti.field(data_type, 392 | shape=(self.total_hash_size,), 393 | needs_grad=True) 394 | 395 | # output fields will have num_scales * 2 entries (2 features per scale) 396 | self.output_fields = ti.field(dtype=data_type, 397 | shape=(max_num_queries, num_scales * 2), 398 | needs_grad=True) 399 | if interpolation == 'linear': 400 | self._hash_encode_kernel = hash_encode_kernel 401 | self._hash_encode_kernel_grad = hash_encode_kernel_grad 402 | elif interpolation == 'smoothstep': 403 | self._hash_encode_kernel = hash_encode_kernel_smoothstep 404 | self._hash_encode_kernel_grad = hash_encode_kernel_smoothstep_grad 405 | else: 406 | raise NotImplementedError 407 | # input assumes a dimension of 4 408 | self.input_fields = ti.field(dtype=data_type, 409 | shape=(max_num_queries, 4), 410 | needs_grad=True) 411 | self.input_fields_grad = ti.field(dtype=data_type, 412 | shape=(max_num_queries, 4), 413 | needs_grad=True) 414 | self.parameter_fields_grad = ti.field(dtype=data_type, 415 | shape=(self.total_hash_size,), 416 | needs_grad=True) 417 | self.output_grad = ti.field(dtype=data_type, 418 | shape=(max_num_queries, num_scales * 2), 419 | needs_grad=True) 420 | 421 | self.register_buffer( 422 | 'hash_grad', 423 | torch.zeros(self.total_hash_size, dtype=torch_type), 424 | persistent=False 425 | ) 426 | self.register_buffer( 427 | 'hash_grad2', 428 | torch.zeros(self.total_hash_size, dtype=torch_type), 429 | persistent=False 430 | ) 431 | self.register_buffer( 432 | 'input_grad', 433 | torch.zeros(max_num_queries, 4, dtype=torch_type), 434 | persistent=False 435 | ) 436 | self.register_buffer( 437 | 'input_grad2', 438 | torch.zeros(max_num_queries, 4, dtype=torch_type), 439 | persistent=False 440 | ) 441 | self.register_buffer( 442 | 'output_embedding', 443 | torch.zeros(max_num_queries, num_scales * 2, dtype=torch_type), 444 | persistent=False 445 | ) 446 | 447 | class _module_function(torch.autograd.Function): 448 | @staticmethod 449 | @custom_fwd(cast_inputs=torch_type) 450 | def forward(ctx, input_pos, params): 451 | output_embedding = self.output_embedding[:input_pos. 
452 | shape[0]].contiguous( 453 | ) 454 | torch2ti(self.input_fields, input_pos.contiguous()) 455 | torch2ti(self.parameter_fields, params.contiguous()) 456 | 457 | self._hash_encode_kernel( 458 | self.input_fields, 459 | self.parameter_fields, 460 | self.output_fields, 461 | self.hash_map_indicator, 462 | self.hash_map_sizes_field, 463 | self.hash_map_shapes_field, 464 | self.offsets, 465 | input_pos.shape[0], 466 | self.num_scales, 467 | ) 468 | ti2torch(self.output_fields, output_embedding) 469 | ctx.save_for_backward(input_pos, params) 470 | return output_embedding 471 | 472 | @staticmethod 473 | @custom_bwd 474 | def backward(ctx, doutput): 475 | self.zero_grad() 476 | input_pos, params = ctx.saved_tensors 477 | return self._module_function_grad.apply(input_pos, params, doutput) 478 | 479 | class _module_function_ad(torch.autograd.Function): 480 | 481 | @staticmethod 482 | @custom_fwd(cast_inputs=torch_type) 483 | def forward(ctx, input_pos, params): 484 | output_embedding = self.output_embedding[:input_pos. 485 | shape[0]].contiguous( 486 | ) 487 | torch2ti(self.input_fields, input_pos.contiguous()) 488 | torch2ti(self.parameter_fields, params.contiguous()) 489 | 490 | self._hash_encode_kernel( 491 | self.input_fields, 492 | self.parameter_fields, 493 | self.output_fields, 494 | self.hash_map_indicator, 495 | self.hash_map_sizes_field, 496 | self.hash_map_shapes_field, 497 | self.offsets, 498 | input_pos.shape[0], 499 | self.num_scales, 500 | ) 501 | ti2torch(self.output_fields, output_embedding) 502 | ctx.save_for_backward(input_pos, params) 503 | return output_embedding 504 | 505 | @staticmethod 506 | @custom_bwd 507 | def backward(ctx, doutput): 508 | self.zero_grad() 509 | 510 | torch2ti_grad(self.output_fields, doutput.contiguous()) 511 | self._hash_encode_kernel.grad( 512 | self.input_fields, 513 | self.parameter_fields, 514 | self.output_fields, 515 | self.hash_map_indicator, 516 | self.hash_map_sizes_field, 517 | self.hash_map_shapes_field, 518 | self.offsets, 519 | doutput.shape[0], 520 | self.num_scales, 521 | ) 522 | ti2torch_grad(self.parameter_fields, 523 | self.hash_grad.contiguous()) 524 | ti2torch_grad(self.input_fields, self.input_grad.contiguous()[:doutput.shape[0]]) 525 | return self.input_grad[:doutput.shape[0]], self.hash_grad 526 | 527 | class _module_function_grad(torch.autograd.Function): 528 | @staticmethod 529 | @custom_fwd(cast_inputs=torch_type) 530 | def forward(ctx, input_pos, params, doutput): 531 | torch2ti(self.input_fields, input_pos.contiguous()) 532 | torch2ti(self.parameter_fields, params.contiguous()) 533 | torch2ti(self.output_grad, doutput.contiguous()) 534 | self._hash_encode_kernel_grad( 535 | self.input_fields, 536 | self.parameter_fields, 537 | self.output_fields, 538 | self.hash_map_indicator, 539 | self.hash_map_sizes_field, 540 | self.hash_map_shapes_field, 541 | self.offsets, 542 | doutput.shape[0], 543 | self.num_scales, 544 | self.input_fields_grad, 545 | self.parameter_fields_grad, 546 | self.output_grad 547 | ) 548 | 549 | ti2torch(self.input_fields_grad, self.input_grad.contiguous()) 550 | ti2torch(self.parameter_fields_grad, self.hash_grad.contiguous()) 551 | return self.input_grad[:doutput.shape[0]], self.hash_grad 552 | 553 | @staticmethod 554 | @custom_bwd 555 | def backward(ctx, d_input_grad, d_hash_grad): 556 | self.zero_grad_2() 557 | torch2ti_grad(self.input_fields_grad, d_input_grad.contiguous()) 558 | torch2ti_grad(self.parameter_fields_grad, d_hash_grad.contiguous()) 559 | self._hash_encode_kernel_grad.grad( 560 | 
self.input_fields, 561 | self.parameter_fields, 562 | self.output_fields, 563 | self.hash_map_indicator, 564 | self.hash_map_sizes_field, 565 | self.hash_map_shapes_field, 566 | self.offsets, 567 | d_input_grad.shape[0], 568 | self.num_scales, 569 | self.input_fields_grad, 570 | self.parameter_fields_grad, 571 | self.output_grad 572 | ) 573 | ti2torch_grad(self.input_fields, self.input_grad2.contiguous()[:d_input_grad.shape[0]]) 574 | ti2torch_grad(self.parameter_fields, self.hash_grad2.contiguous()) 575 | # set_trace(term_size=(120,30)) 576 | return self.input_grad2[:d_input_grad.shape[0]], self.hash_grad2, None 577 | 578 | self._module_function = _module_function 579 | self._module_function_grad = _module_function_grad 580 | 581 | def zero_grad(self): 582 | self.parameter_fields.grad.fill(0.) 583 | self.input_fields.grad.fill(0.) 584 | self.input_fields_grad.fill(0.) 585 | self.parameter_fields_grad.fill(0.) 586 | 587 | def zero_grad_2(self): 588 | self.parameter_fields.grad.fill(0.) 589 | self.input_fields.grad.fill(0.) 590 | # self.input_fields_grad.grad.fill(0.) 591 | # self.parameter_fields_grad.grad.fill(0.) 592 | 593 | def forward(self, positions): 594 | # positions: (N, 4), normalized to [-1, 1] 595 | positions = positions * 0.5 + 0.5 596 | return self._module_function.apply(positions, self.hash_table) 597 | 598 | if __name__ == '__main__': 599 | ti.init(arch=ti.cpu, device_memory_GB=4.0) 600 | 601 | import torch.nn as nn 602 | import torch.nn.functional as F 603 | 604 | print(torch.__version__) 605 | 606 | 607 | class NeRFSmallPotential(nn.Module): 608 | def __init__(self, 609 | num_layers=3, 610 | hidden_dim=64, 611 | geo_feat_dim=15, 612 | num_layers_color=2, 613 | hidden_dim_color=16, 614 | input_ch=3, 615 | use_f=False 616 | ): 617 | super(NeRFSmallPotential, self).__init__() 618 | 619 | self.input_ch = input_ch 620 | self.rgb = torch.nn.Parameter(torch.tensor([0.0])) 621 | 622 | # sigma network 623 | self.num_layers = num_layers 624 | self.hidden_dim = hidden_dim 625 | self.geo_feat_dim = geo_feat_dim 626 | 627 | sigma_net = [] 628 | for l in range(num_layers): 629 | if l == 0: 630 | in_dim = self.input_ch 631 | else: 632 | in_dim = hidden_dim 633 | 634 | if l == num_layers - 1: 635 | out_dim = hidden_dim # 1 sigma + 15 SH features for color 636 | else: 637 | out_dim = hidden_dim 638 | 639 | sigma_net.append(nn.Linear(in_dim, out_dim, bias=False)) 640 | self.sigma_net = nn.ModuleList(sigma_net) 641 | self.out = nn.Linear(hidden_dim, 3, bias=True) 642 | self.use_f = use_f 643 | if use_f: 644 | self.out_f = nn.Linear(hidden_dim, hidden_dim, bias=True) 645 | self.out_f2 = nn.Linear(hidden_dim, 3, bias=True) 646 | 647 | def forward(self, x): 648 | h = x 649 | for l in range(self.num_layers): 650 | h = self.sigma_net[l](h) 651 | h = F.relu(h, True) 652 | 653 | v = self.out(h) 654 | if self.use_f: 655 | f = self.out_f(h) 656 | f = F.relu(f, True) 657 | f = self.out_f2(f) 658 | else: 659 | f = v * 0 660 | return v, f 661 | 662 | 663 | 664 | # embedding = h(x) 665 | network_vel = NeRFSmallPotential(input_ch=32) 666 | embed_vel = Hash4Encoder() 667 | 668 | pts = torch.rand(100, 4) 669 | pts.requires_grad = True 670 | with torch.enable_grad(): 671 | h = embed_vel(pts) 672 | vel_output, f_output = network_vel(h) 673 | 674 | print('vel_output', vel_output.shape) 675 | print('h', h.shape) 676 | def g(x): 677 | return network_vel(x)[0] 678 | 679 | jac = vmap(jacrev(g))(h) 680 | print('jac', jac.shape) 681 | jac_x = [] #_get_minibatch_jacobian(h, pts) 682 | for j in range(h.shape[1]): 683 
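# Build d(embedding)/d(x, y, z, t) channel by channel with autograd; chaining it
# with the network Jacobian from vmap(jacrev(g)) gives du/d(x, y, z, t), whose
# spatial 3x3 block is used below to form the curl (vorticity) of the velocity.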
| dy_j_dx = torch.autograd.grad( 684 | h[:, j], 685 | pts, 686 | torch.ones_like(h[:, j], device='cpu'), 687 | retain_graph=True, 688 | create_graph=True, 689 | )[0].view(pts.shape[0], -1) 690 | jac_x.append(dy_j_dx.unsqueeze(1)) 691 | jac_x = torch.cat(jac_x, dim=1) 692 | print(jac_x.shape) 693 | jac = jac @ jac_x 694 | assert jac.shape == (pts.shape[0], 3, 4) 695 | _u_x, _u_y, _u_z, _u_t = [torch.squeeze(_, -1) for _ in jac.split(1, dim=-1)] # (N,1) 696 | 697 | jac = torch.stack([_u_x, _u_y, _u_z], dim=-1) # [N, 3, 3] 698 | curl = torch.stack([jac[:, 2, 1] - jac[:, 1, 2], 699 | jac[:, 0, 2] - jac[:, 2, 0], 700 | jac[:, 1, 0] - jac[:, 0, 1]], dim=-1) # [N, 3] 701 | # curl = curl.view(list(pts_shape[:-1]) + [3]) # [..., 3] 702 | print(curl.shape) 703 | vorticity_norm = torch.norm(curl, dim=-1, keepdim=True) 704 | 705 | vorticity_norm_grad = [] 706 | 707 | print(vorticity_norm.shape) 708 | for j in range(vorticity_norm.shape[1]): 709 | # breakpoint() 710 | 711 | dy_j_dx = torch.autograd.grad( 712 | vorticity_norm[:, j], 713 | pts, 714 | torch.ones_like(vorticity_norm[:, j], device='cpu'), 715 | retain_graph=True, 716 | create_graph=True, 717 | )[0] 718 | vorticity_norm_grad.append(dy_j_dx.unsqueeze(1)) 719 | vorticity_norm_grad = torch.cat(vorticity_norm_grad, dim=1) 720 | print(vorticity_norm_grad.shape) -------------------------------------------------------------------------------- /taichi_encoders/mgpcg.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | from .taichi_utils import * 5 | 6 | 7 | @ti.data_oriented 8 | class MGPCG: 9 | ''' 10 | Grid-based MGPCG solver for the possion equation. 11 | 12 | .. note:: 13 | 14 | This solver only runs on CPU and CUDA backends since it requires the 15 | ``pointer`` SNode. 16 | ''' 17 | 18 | def __init__(self, boundary_types, N, dim=2, base_level=3, real=float): 19 | ''' 20 | :parameter dim: Dimensionality of the fields. 21 | :parameter N: Grid resolutions. 22 | :parameter n_mg_levels: Number of multigrid levels. 
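:parameter base_level: Coarsest level of the hierarchy; the number of levels is
            derived below as int(log2(min(N))) - base_level + 1 rather than passed in.
:parameter boundary_types: (dim, 2) array of per-face flags; a value of 2 is treated
            as a closed (Neumann) face by num_fluid_neighbors and apply_bc.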
23 | ''' 24 | 25 | # grid parameters 26 | self.use_multigrid = True 27 | 28 | self.N = N 29 | self.n_mg_levels = int(math.log2(min(N))) - base_level + 1 30 | self.pre_and_post_smoothing = 2 31 | self.bottom_smoothing = 50 32 | self.dim = dim 33 | self.real = real 34 | 35 | # setup sparse simulation data arrays 36 | self.r = [ti.field(dtype=self.real) 37 | for _ in range(self.n_mg_levels)] # residual 38 | self.z = [ti.field(dtype=self.real) 39 | for _ in range(self.n_mg_levels)] # M^-1 self.r 40 | self.x = ti.field(dtype=self.real) # solution 41 | self.p = ti.field(dtype=self.real) # conjugate gradient 42 | self.Ap = ti.field(dtype=self.real) # matrix-vector product 43 | self.alpha = ti.field(dtype=self.real) # step size 44 | self.beta = ti.field(dtype=self.real) # step size 45 | self.sum = ti.field(dtype=self.real) # storage for reductions 46 | self.r_mean = ti.field(dtype=self.real) # storage for avg of r 47 | self.num_entries = math.prod(self.N) 48 | 49 | indices = ti.ijk if self.dim == 3 else ti.ij 50 | self.grid = ti.root.pointer(indices, [n // 4 for n in self.N]).dense( 51 | indices, 4).place(self.x, self.p, self.Ap) 52 | 53 | for l in range(self.n_mg_levels): 54 | self.grid = ti.root.pointer(indices, 55 | [n // (4 * 2 ** l) for n in self.N]).dense( 56 | indices, 57 | 4).place(self.r[l], self.z[l]) 58 | 59 | ti.root.place(self.alpha, self.beta, self.sum, self.r_mean) 60 | 61 | self.boundary_types = boundary_types 62 | 63 | @ti.func 64 | def init_r(self, I, r_I): 65 | self.r[0][I] = r_I 66 | self.z[0][I] = 0 67 | self.Ap[I] = 0 68 | self.p[I] = 0 69 | self.x[I] = 0 70 | 71 | @ti.kernel 72 | def init(self, r: ti.template(), k: ti.template()): 73 | ''' 74 | Set up the solver for $\nabla^2 x = k r$, a scaled Poisson problem. 75 | :parameter k: (scalar) A scaling factor of the right-hand side. 76 | :parameter r: (ti.field) Unscaled right-hand side. 77 | ''' 78 | for I in ti.grouped(ti.ndrange(*self.N)): 79 | self.init_r(I, r[I] * k) 80 | 81 | @ti.kernel 82 | def get_result(self, x: ti.template()): 83 | ''' 84 | Get the solution field. 
85 | 86 | :parameter x: (ti.field) The field to store the solution 87 | ''' 88 | for I in ti.grouped(ti.ndrange(*self.N)): 89 | x[I] = self.x[I] 90 | 91 | @ti.func 92 | def neighbor_sum(self, x, I): 93 | dims = x.shape 94 | ret = ti.cast(0.0, self.real) 95 | for i in ti.static(range(self.dim)): 96 | offset = ti.Vector.unit(self.dim, i) 97 | # add right if has right 98 | if I[i] < dims[i] - 1: 99 | ret += x[I + offset] 100 | # add left if has left 101 | if I[i] > 0: 102 | ret += x[I - offset] 103 | return ret 104 | 105 | @ti.func 106 | def num_fluid_neighbors(self, x, I): 107 | dims = x.shape 108 | num = 2.0 * self.dim 109 | for i in ti.static(range(self.dim)): 110 | if I[i] <= 0 and self.boundary_types[i, 0] == 2: 111 | num -= 1.0 112 | if I[i] >= dims[i] - 1 and self.boundary_types[i, 1] == 2: 113 | num -= 1.0 114 | return num 115 | 116 | @ti.kernel 117 | def compute_Ap(self): 118 | for I in ti.grouped(self.Ap): 119 | multiplier = self.num_fluid_neighbors(self.p, I) 120 | self.Ap[I] = multiplier * self.p[I] - self.neighbor_sum( 121 | self.p, I) 122 | 123 | @ti.kernel 124 | def get_Ap(self, p: ti.template(), Ap: ti.template()): 125 | for I in ti.grouped(Ap): 126 | multiplier = self.num_fluid_neighbors(p, I) 127 | Ap[I] = multiplier * p[I] - self.neighbor_sum( 128 | p, I) 129 | 130 | @ti.kernel 131 | def reduce(self, p: ti.template(), q: ti.template()): 132 | self.sum[None] = 0 133 | for I in ti.grouped(p): 134 | self.sum[None] += p[I] * q[I] 135 | 136 | @ti.kernel 137 | def update_x(self): 138 | for I in ti.grouped(self.p): 139 | self.x[I] += self.alpha[None] * self.p[I] 140 | 141 | @ti.kernel 142 | def update_r(self): 143 | for I in ti.grouped(self.p): 144 | self.r[0][I] -= self.alpha[None] * self.Ap[I] 145 | 146 | @ti.kernel 147 | def update_p(self): 148 | for I in ti.grouped(self.p): 149 | self.p[I] = self.z[0][I] + self.beta[None] * self.p[I] 150 | 151 | @ti.kernel 152 | def restrict(self, l: ti.template()): 153 | for I in ti.grouped(self.r[l]): 154 | multiplier = self.num_fluid_neighbors(self.z[l], I) 155 | res = self.r[l][I] - (multiplier * self.z[l][I] - 156 | self.neighbor_sum(self.z[l], I)) 157 | self.r[l + 1][I // 2] += res * 1.0 / (self.dim - 1.0) 158 | 159 | @ti.kernel 160 | def prolongate(self, l: ti.template()): 161 | for I in ti.grouped(self.z[l]): 162 | self.z[l][I] += self.z[l + 1][I // 2] 163 | 164 | @ti.kernel 165 | def smooth(self, l: ti.template(), phase: ti.template()): 166 | # phase = red/black Gauss-Seidel phase 167 | for I in ti.grouped(self.r[l]): 168 | if (I.sum()) & 1 == phase: 169 | multiplier = self.num_fluid_neighbors(self.z[l], I) 170 | self.z[l][I] = (self.r[l][I] + self.neighbor_sum( 171 | self.z[l], I)) / multiplier 172 | 173 | @ti.kernel 174 | def recenter(self, r: ti.template()): # so that the mean value of r is 0 175 | self.r_mean[None] = 0.0 176 | for I in ti.grouped(r): 177 | self.r_mean[None] += r[I] / self.num_entries 178 | for I in ti.grouped(r): 179 | r[I] -= self.r_mean[None] 180 | 181 | def apply_preconditioner(self): 182 | self.z[0].fill(0) 183 | for l in range(self.n_mg_levels - 1): 184 | for i in range(self.pre_and_post_smoothing): 185 | self.smooth(l, 0) 186 | self.smooth(l, 1) 187 | self.z[l + 1].fill(0) 188 | self.r[l + 1].fill(0) 189 | self.restrict(l) 190 | 191 | for i in range(self.bottom_smoothing): 192 | self.smooth(self.n_mg_levels - 1, 0) 193 | self.smooth(self.n_mg_levels - 1, 1) 194 | 195 | for l in reversed(range(self.n_mg_levels - 1)): 196 | self.prolongate(l) 197 | for i in range(self.pre_and_post_smoothing): 198 | 
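# post-smoothing on the way back up the V-cycle; the red/black phases run in the
# reverse order (1 then 0) of the pre-smoothing pass above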
self.smooth(l, 1) 199 | self.smooth(l, 0) 200 | 201 | def solve(self, 202 | max_iters=-1, 203 | eps=1e-12, 204 | tol=1e-12, 205 | verbose=False): 206 | ''' 207 | Solve a Poisson problem. 208 | 209 | :parameter max_iters: Specify the maximal iterations. -1 for no limit. 210 | :parameter eps: Specify a non-zero value to prevent ZeroDivisionError. 211 | :parameter abs_tol: Specify the absolute tolerance of loss. 212 | :parameter rel_tol: Specify the tolerance of loss relative to initial loss. 213 | ''' 214 | all_neumann = (self.boundary_types.sum() == 2 * 2 * self.dim) 215 | 216 | # self.r = b - Ax = b since self.x = 0 217 | # self.p = self.r = self.r + 0 self.p 218 | 219 | if all_neumann: 220 | self.recenter(self.r[0]) 221 | if self.use_multigrid: 222 | self.apply_preconditioner() 223 | else: 224 | self.z[0].copy_from(self.r[0]) 225 | 226 | self.update_p() 227 | 228 | self.reduce(self.z[0], self.r[0]) 229 | old_zTr = self.sum[None] 230 | #print("[MGPCG] Starting error: ", math.sqrt(old_zTr)) 231 | 232 | # Conjugate gradients 233 | it = 0 234 | start_t = time.time() 235 | while max_iters == -1 or it < max_iters: 236 | # self.alpha = rTr / pTAp 237 | self.compute_Ap() 238 | self.reduce(self.p, self.Ap) 239 | pAp = self.sum[None] 240 | self.alpha[None] = old_zTr / (pAp + eps) 241 | 242 | # self.x = self.x + self.alpha self.p 243 | self.update_x() 244 | 245 | # self.r = self.r - self.alpha self.Ap 246 | self.update_r() 247 | 248 | # check for convergence 249 | self.reduce(self.r[0], self.r[0]) 250 | rTr = self.sum[None] 251 | 252 | if verbose: 253 | print(f'iter {it}, |residual|_2={math.sqrt(rTr)}') 254 | 255 | if rTr < tol: 256 | end_t = time.time() 257 | # print("[MGPCG] final error: ", math.sqrt(rTr), " using time: ", end_t - start_t) 258 | return 259 | 260 | if all_neumann: 261 | self.recenter(self.r[0]) 262 | # self.z = M^-1 self.r 263 | if self.use_multigrid: 264 | self.apply_preconditioner() 265 | else: 266 | self.z[0].copy_from(self.r[0]) 267 | 268 | # self.beta = new_rTr / old_rTr 269 | self.reduce(self.z[0], self.r[0]) 270 | new_zTr = self.sum[None] 271 | self.beta[None] = new_zTr / (old_zTr + eps) 272 | 273 | # self.p = self.z + self.beta self.p 274 | self.update_p() 275 | old_zTr = new_zTr 276 | 277 | it += 1 278 | 279 | end_t = time.time() 280 | # print("[MGPCG] Return without converging at iter: ", it, " with final error: ", math.sqrt(rTr), " using time: ", 281 | # end_t - start_t) 282 | 283 | 284 | class MGPCG_2(MGPCG): 285 | 286 | def __init__(self, boundary_types, N, base_level=3, real=float): 287 | super().__init__(boundary_types, N, dim=2, base_level=base_level, real=real) 288 | 289 | self.u_div = ti.field(float, shape=N) 290 | self.p = ti.field(float, shape=N) 291 | self.boundary_types = boundary_types 292 | 293 | @ti.kernel 294 | def apply_bc(self, u_horizontal: ti.template(), u_vertical: ti.template()): 295 | u_dim, v_dim = u_horizontal.shape 296 | for i, j in u_horizontal: 297 | if i == 0 and self.boundary_types[0, 0] == 2: 298 | u_horizontal[i, j] = 0 299 | if i == u_dim - 1 and self.boundary_types[0, 1] == 2: 300 | u_horizontal[i, j] = 0 301 | u_dim, v_dim = u_vertical.shape 302 | for i, j in u_vertical: 303 | if j == 0 and self.boundary_types[1, 0] == 2: 304 | u_vertical[i, j] = 0 305 | if j == v_dim - 1 and self.boundary_types[1, 1] == 2: 306 | u_vertical[i, j] = 0 307 | 308 | @ti.kernel 309 | def divergence(self, u_horizontal: ti.template(), u_vertical: ti.template()): 310 | u_dim, v_dim = self.u_div.shape 311 | for i, j in self.u_div: 312 | vl = sample(u_horizontal, 
i, j) 313 | vr = sample(u_horizontal, i + 1, j) 314 | vb = sample(u_vertical, i, j) 315 | vt = sample(u_vertical, i, j + 1) 316 | self.u_div[i, j] = vr - vl + vt - vb 317 | 318 | @ti.kernel 319 | def subtract_grad_p(self, u_horizontal: ti.template(), u_vertical: ti.template()): 320 | u_dim, v_dim = self.p.shape 321 | for i, j in u_horizontal: 322 | pr = sample(self.p, i, j) 323 | pl = sample(self.p, i - 1, j) 324 | if i - 1 < 0: 325 | pl = 0 326 | if i >= u_dim: 327 | pr = 0 328 | u_horizontal[i, j] -= (pr - pl) 329 | for i, j in u_vertical: 330 | pt = sample(self.p, i, j) 331 | pb = sample(self.p, i, j - 1) 332 | if j - 1 < 0: 333 | pb = 0 334 | if j >= v_dim: 335 | pt = 0 336 | u_vertical[i, j] -= pt - pb 337 | 338 | def solve_pressure_MGPCG(self, verbose): 339 | self.init(self.u_div, -1) 340 | self.solve(max_iters=400, verbose=verbose, tol=1.e-12) 341 | self.get_result(self.p) 342 | 343 | def Poisson(self, u_horizontal, u_vertical, verbose=False): 344 | self.apply_bc(u_horizontal, u_vertical) 345 | self.divergence(u_horizontal, u_vertical) 346 | self.solve_pressure_MGPCG(verbose=verbose) 347 | self.subtract_grad_p(u_horizontal, u_vertical) 348 | self.apply_bc(u_horizontal, u_vertical) 349 | 350 | 351 | class MGPCG_3(MGPCG): 352 | 353 | def __init__(self, boundary_types, N, base_level=3, real=float): 354 | super().__init__(boundary_types, N, dim=3, base_level=base_level, real=real) 355 | 356 | rx, ry, rz = N 357 | self.u_div = ti.field(float, shape=N) 358 | self.p = ti.field(float, shape=N) 359 | self.boundary_types = boundary_types 360 | self.u_x = ti.field(float, shape=(rx + 1, ry, rz)) 361 | self.u_y = ti.field(float, shape=(rx, ry + 1, rz)) 362 | self.u_z = ti.field(float, shape=(rx, ry, rz + 1)) 363 | self.u = ti.Vector.field(3, float, shape=(rx, ry, rz)) 364 | self.u_y_bottom = ti.field(float, shape=(rx, 1, rz)) 365 | 366 | @ti.kernel 367 | def apply_bc(self, u_x: ti.template(), u_y: ti.template(), u_z: ti.template()): 368 | u_dim, v_dim, w_dim = u_x.shape 369 | for i, j, k in u_x: 370 | if i == 0 and self.boundary_types[0, 0] == 2: 371 | u_x[i, j, k] = 0 372 | if i == u_dim - 1 and self.boundary_types[0, 1] == 2: 373 | u_x[i, j, k] = 0 374 | u_dim, v_dim, w_dim = u_y.shape 375 | for i, j, k in u_y: 376 | if j == 0 and self.boundary_types[1, 0] == 2: 377 | u_y[i, j, k] = self.u_y_bottom[i, j, k] 378 | # u_y[i, j, k] = 0.5 379 | if j == v_dim - 1 and self.boundary_types[1, 1] == 2: 380 | u_y[i, j, k] = 0 381 | u_dim, v_dim, w_dim = u_z.shape 382 | for i, j, k in u_z: 383 | if k == 0 and self.boundary_types[2, 0] == 2: 384 | u_z[i, j, k] = 0 385 | if k == w_dim - 1 and self.boundary_types[2, 1] == 2: 386 | u_z[i, j, k] = 0 387 | 388 | @ti.kernel 389 | def divergence(self, u_x: ti.template(), u_y: ti.template(), u_z: ti.template()): 390 | u_dim, v_dim, w_dim = self.u_div.shape 391 | for i, j, k in self.u_div: 392 | vl = sample(u_x, i, j, k) 393 | vr = sample(u_x, i + 1, j, k) 394 | vb = sample(u_y, i, j, k) 395 | vt = sample(u_y, i, j + 1, k) 396 | va = sample(u_z, i, j, k) 397 | vc = sample(u_z, i, j, k + 1) 398 | self.u_div[i, j, k] = vr - vl + vt - vb + vc - va 399 | 400 | @ti.kernel 401 | def subtract_grad_p(self, u_x: ti.template(), u_y: ti.template(), u_z: ti.template()): 402 | u_dim, v_dim, w_dim = self.p.shape 403 | for i, j, k in u_x: 404 | pr = sample(self.p, i, j, k) 405 | pl = sample(self.p, i - 1, j, k) 406 | if i - 1 < 0: 407 | pl = 0 408 | if i >= u_dim: 409 | pr = 0 410 | u_x[i, j, k] -= (pr - pl) 411 | for i, j, k in u_y: 412 | pt = sample(self.p, i, j, k) 413 | pb = 
sample(self.p, i, j - 1, k) 414 | if j - 1 < 0: 415 | pb = 0 416 | if j >= v_dim: 417 | pt = 0 418 | u_y[i, j, k] -= pt - pb 419 | for i, j, k in u_z: 420 | pc = sample(self.p, i, j, k) 421 | pa = sample(self.p, i, j, k - 1) 422 | if k - 1 < 0: 423 | pa = 0 424 | if k >= w_dim: 425 | pc = 0 426 | u_z[i, j, k] -= pc - pa 427 | 428 | def solve_pressure_MGPCG(self, verbose): 429 | self.init(self.u_div, -1) 430 | self.solve(max_iters=400, verbose=verbose, tol=1.e-12) 431 | self.get_result(self.p) 432 | 433 | @ti.kernel 434 | def set_uy_bottom(self): 435 | for i, j, k in self.u_y: 436 | if j == 0 and self.boundary_types[1, 0] == 2: 437 | self.u_y_bottom[i, j, k] = self.u_y[i, j, k] 438 | 439 | def Poisson(self, vel, verbose=False): 440 | """ 441 | args: 442 | vel: torch tensor of shape (X, Y, Z, 3) 443 | returns: 444 | vel: torch tensor of shape (X, Y, Z, 3), projected 445 | """ 446 | self.u.from_torch(vel) 447 | split_central_vector(self.u, self.u_x, self.u_y, self.u_z) 448 | self.set_uy_bottom() 449 | self.apply_bc(self.u_x, self.u_y, self.u_z) 450 | self.divergence(self.u_x, self.u_y, self.u_z) 451 | self.solve_pressure_MGPCG(verbose=verbose) 452 | self.subtract_grad_p(self.u_x, self.u_y, self.u_z) 453 | self.apply_bc(self.u_x, self.u_y, self.u_z) 454 | get_central_vector(self.u_x, self.u_y, self.u_z, self.u) 455 | vel = self.u.to_torch() 456 | return vel 457 | -------------------------------------------------------------------------------- /taichi_encoders/taichi_utils.py: -------------------------------------------------------------------------------- 1 | import taichi as ti 2 | 3 | eps = 1.e-6 4 | 5 | @ti.kernel 6 | def copy_to(source: ti.template(), dest: ti.template()): 7 | for I in ti.grouped(source): 8 | dest[I] = source[I] 9 | 10 | @ti.kernel 11 | def scale_field(a: ti.template(), alpha: float, result: ti.template()): 12 | for I in ti.grouped(result): 13 | result[I] = alpha * a[I] 14 | 15 | @ti.kernel 16 | def add_fields(f1: ti.template(), f2: ti.template(), dest: ti.template(), multiplier: float): 17 | for I in ti.grouped(dest): 18 | dest[I] = f1[I] + multiplier * f2[I] 19 | 20 | @ti.func 21 | def lerp(vl, vr, frac): 22 | # frac: [0.0, 1.0] 23 | return vl + frac * (vr - vl) 24 | 25 | @ti.kernel 26 | def center_coords_func(pf: ti.template(), dx: float): 27 | for I in ti.grouped(pf): 28 | pf[I] = (I+0.5) * dx 29 | 30 | @ti.kernel 31 | def x_coords_func(pf: ti.template(), dx: float): 32 | for i, j, k in pf: 33 | pf[i, j, k] = ti.Vector([i, j + 0.5, k + 0.5]) * dx 34 | 35 | @ti.kernel 36 | def y_coords_func(pf: ti.template(), dx: float): 37 | for i, j, k in pf: 38 | pf[i, j, k] = ti.Vector([i + 0.5, j, k + 0.5]) * dx 39 | 40 | @ti.kernel 41 | def z_coords_func(pf: ti.template(), dx: float): 42 | for i, j, k in pf: 43 | pf[i, j, k] = ti.Vector([i + 0.5, j + 0.5, k]) * dx 44 | 45 | @ti.func 46 | def sample(qf: ti.template(), u: float, v: float, w: float): 47 | u_dim, v_dim, w_dim = qf.shape 48 | i = ti.max(0, ti.min(int(u), u_dim-1)) 49 | j = ti.max(0, ti.min(int(v), v_dim-1)) 50 | k = ti.max(0, ti.min(int(w), w_dim-1)) 51 | return qf[i, j, k] 52 | 53 | @ti.kernel 54 | def curl(vf: ti.template(), cf: ti.template(), dx: float): 55 | inv_dist = 1./(2*dx) 56 | for i, j, k in cf: 57 | vr = sample(vf, i+1, j, k) 58 | vl = sample(vf, i-1, j, k) 59 | vt = sample(vf, i, j+1, k) 60 | vb = sample(vf, i, j-1, k) 61 | vc = sample(vf, i, j, k+1) 62 | va = sample(vf, i, j, k-1) 63 | 64 | d_vx_dz = inv_dist * (vc.x - va.x) 65 | d_vx_dy = inv_dist * (vt.x - vb.x) 66 | 67 | d_vy_dx = inv_dist * (vr.y - 
vl.y) 68 | d_vy_dz = inv_dist * (vc.y - va.y) 69 | 70 | d_vz_dx = inv_dist * (vr.z - vl.z) 71 | d_vz_dy = inv_dist * (vt.z - vb.z) 72 | 73 | cf[i,j,k][0] = d_vz_dy - d_vy_dz 74 | cf[i,j,k][1] = d_vx_dz - d_vz_dx 75 | cf[i,j,k][2] = d_vy_dx - d_vx_dy 76 | 77 | @ti.kernel 78 | def get_central_vector(vx: ti.template(), vy: ti.template(), vz: ti.template(), vc: ti.template()): 79 | for i, j, k in vc: 80 | vc[i,j,k].x = 0.5 * (vx[i+1, j, k] + vx[i, j, k]) 81 | vc[i,j,k].y = 0.5 * (vy[i, j+1, k] + vy[i, j, k]) 82 | vc[i,j,k].z = 0.5 * (vz[i, j, k+1] + vz[i, j, k]) 83 | 84 | @ti.kernel 85 | def split_central_vector(vc: ti.template(), vx: ti.template(), vy: ti.template(), vz: ti.template()): 86 | for i, j, k in vx: 87 | r = sample(vc, i, j, k) 88 | l = sample(vc, i-1, j, k) 89 | vx[i,j,k] = 0.5 * (r.x + l.x) 90 | for i, j, k in vy: 91 | t = sample(vc, i, j, k) 92 | b = sample(vc, i, j-1, k) 93 | vy[i,j,k] = 0.5 * (t.y + b.y) 94 | for i, j, k in vz: 95 | c = sample(vc, i, j, k) 96 | a = sample(vc, i, j, k-1) 97 | vz[i,j,k] = 0.5 * (c.z + a.z) 98 | 99 | # # # interpolation 100 | @ti.func 101 | def N_2(x): 102 | result = 0.0 103 | abs_x = ti.abs(x) 104 | if abs_x < 0.5: 105 | result = 3.0/4.0 - abs_x ** 2 106 | elif abs_x < 1.5: 107 | result = 0.5 * (3.0/2.0-abs_x) ** 2 108 | return result 109 | 110 | @ti.func 111 | def dN_2(x): 112 | result = 0.0 113 | abs_x = ti.abs(x) 114 | if abs_x < 0.5: 115 | result = -2 * abs_x 116 | elif abs_x < 1.5: 117 | result = 0.5 * (2 * abs_x - 3) 118 | if x < 0.0: # if x < 0 then abs_x is -1 * x 119 | result *= -1 120 | return result 121 | 122 | @ti.func 123 | def interp_grad_2(vf, p, dx, BL_x = 0.5, BL_y = 0.5, BL_z = 0.5): 124 | u_dim, v_dim, w_dim = vf.shape 125 | 126 | u, v, w = p / dx 127 | u = u - BL_x 128 | v = v - BL_y 129 | w = w - BL_z 130 | s = ti.max(1., ti.min(u, u_dim-2-eps)) 131 | t = ti.max(1., ti.min(v, v_dim-2-eps)) 132 | l = ti.max(1., ti.min(w, w_dim-2-eps)) 133 | 134 | # floor 135 | iu, iv, iw = ti.floor(s), ti.floor(t), ti.floor(l) 136 | 137 | partial_x = 0. 138 | partial_y = 0. 139 | partial_z = 0. 140 | interped = 0. 
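The `N_2` / `dN_2` pair above is the quadratic B-spline kernel and its derivative, and `interp_grad_2` blends a 4×4×4 stencil of grid values with these weights. A quick sanity check is that, for any fractional offset, the four 1D stencil weights sum to 1 and their derivative weights sum to 0, so constants are reproduced exactly and have zero gradient. The snippet below is an illustrative plain-Python re-implementation used only for this check; it is not part of `taichi_utils.py`:
```python
# Plain-Python version of the kernel above, used only to check its weights.
def N_2(x):
    ax = abs(x)
    if ax < 0.5:
        return 0.75 - ax ** 2
    if ax < 1.5:
        return 0.5 * (1.5 - ax) ** 2
    return 0.0

def dN_2(x):
    ax = abs(x)
    if ax < 0.5:
        r = -2.0 * ax
    elif ax < 1.5:
        r = 0.5 * (2.0 * ax - 3.0)
    else:
        r = 0.0
    return -r if x < 0.0 else r

frac = 0.3                                   # fractional position inside a cell
offsets = [frac - i for i in range(-1, 3)]   # 1.3, 0.3, -0.7, -1.7 (same stencil as the loop below)
print(sum(N_2(o) for o in offsets))          # ~1.0 up to rounding: partition of unity
print(sum(dN_2(o) for o in offsets))         # ~0.0 up to rounding: derivative weights cancel
```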
141 | 142 | # loop over 16 indices 143 | for i in range(-1, 3): 144 | for j in range(-1, 3): 145 | for k in range(-1, 3): 146 | x_p_x_i = s - (iu + i) # x_p - x_i 147 | y_p_y_i = t - (iv + j) 148 | z_p_z_i = l - (iw + k) 149 | value = sample(vf, iu + i, iv + j, iw + k) 150 | partial_x += 1./dx * (value * dN_2(x_p_x_i) * N_2(y_p_y_i) * N_2(z_p_z_i)) 151 | partial_y += 1./dx * (value * N_2(x_p_x_i) * dN_2(y_p_y_i) * N_2(z_p_z_i)) 152 | partial_z += 1./dx * (value * N_2(x_p_x_i) * N_2(y_p_y_i) * dN_2(z_p_z_i)) 153 | interped += value * N_2(x_p_x_i) * N_2(y_p_y_i) * N_2(z_p_z_i) 154 | 155 | return interped, ti.Vector([partial_x, partial_y, partial_z]) -------------------------------------------------------------------------------- /taichi_encoders/utils.py: -------------------------------------------------------------------------------- 1 | import taichi as ti 2 | import torch 3 | from taichi.math import uvec3 4 | import numpy as np 5 | import cv2 6 | 7 | taichi_block_size = 128 8 | 9 | data_type = ti.f32 10 | torch_type = torch.float32 11 | 12 | MAX_SAMPLES = 1024 13 | NEAR_DISTANCE = 0.01 14 | SQRT3 = 1.7320508075688772 15 | SQRT3_MAX_SAMPLES = SQRT3 / 1024 16 | SQRT3_2 = 1.7320508075688772 * 2 17 | 18 | 19 | @ti.func 20 | def scalbn(x, exponent): 21 | return x * ti.math.pow(2, exponent) 22 | 23 | 24 | @ti.func 25 | def calc_dt(t, exp_step_factor, grid_size, scale): 26 | return ti.math.clamp(t * exp_step_factor, SQRT3_MAX_SAMPLES, 27 | SQRT3_2 * scale / grid_size) 28 | 29 | 30 | @ti.func 31 | def frexp_bit(x): 32 | exponent = 0 33 | if x != 0.0: 34 | # frac = ti.abs(x) 35 | bits = ti.bit_cast(x, ti.u32) 36 | exponent = ti.i32((bits & ti.u32(0x7f800000)) >> 23) - 127 37 | # exponent = (ti.i32(bits & ti.u32(0x7f800000)) >> 23) - 127 38 | bits &= ti.u32(0x7fffff) 39 | bits |= ti.u32(0x3f800000) 40 | frac = ti.bit_cast(bits, ti.f32) 41 | if frac < 0.5: 42 | exponent -= 1 43 | elif frac > 1.0: 44 | exponent += 1 45 | return exponent 46 | 47 | 48 | @ti.func 49 | def mip_from_pos(xyz, cascades): 50 | mx = ti.abs(xyz).max() 51 | # _, exponent = _frexp(mx) 52 | exponent = frexp_bit(ti.f32(mx)) + 1 53 | # frac, exponent = ti.frexp(ti.f32(mx)) 54 | return ti.min(cascades - 1, ti.max(0, exponent)) 55 | 56 | 57 | @ti.func 58 | def mip_from_dt(dt, grid_size, cascades): 59 | # _, exponent = _frexp(dt*grid_size) 60 | exponent = frexp_bit(ti.f32(dt * grid_size)) 61 | # frac, exponent = ti.frexp(ti.f32(dt*grid_size)) 62 | return ti.min(cascades - 1, ti.max(0, exponent)) 63 | 64 | 65 | @ti.func 66 | def __expand_bits(v): 67 | v = (v * ti.uint32(0x00010001)) & ti.uint32(0xFF0000FF) 68 | v = (v * ti.uint32(0x00000101)) & ti.uint32(0x0F00F00F) 69 | v = (v * ti.uint32(0x00000011)) & ti.uint32(0xC30C30C3) 70 | v = (v * ti.uint32(0x00000005)) & ti.uint32(0x49249249) 71 | return v 72 | 73 | 74 | @ti.func 75 | def __morton3D(xyz): 76 | xyz = __expand_bits(xyz) 77 | return xyz[0] | (xyz[1] << 1) | (xyz[2] << 2) 78 | 79 | 80 | @ti.func 81 | def __morton3D_invert(x): 82 | x = x & (0x49249249) 83 | x = (x | (x >> 2)) & ti.uint32(0xc30c30c3) 84 | x = (x | (x >> 4)) & ti.uint32(0x0f00f00f) 85 | x = (x | (x >> 8)) & ti.uint32(0xff0000ff) 86 | x = (x | (x >> 16)) & ti.uint32(0x0000ffff) 87 | return ti.int32(x) 88 | 89 | 90 | @ti.kernel 91 | def morton3D_invert_kernel(indices: ti.types.ndarray(ndim=1), 92 | coords: ti.types.ndarray(ndim=2)): 93 | for i in indices: 94 | ind = ti.uint32(indices[i]) 95 | coords[i, 0] = __morton3D_invert(ind >> 0) 96 | coords[i, 1] = __morton3D_invert(ind >> 1) 97 | coords[i, 2] = 
__morton3D_invert(ind >> 2) 98 | 99 | 100 | def morton3D_invert(indices): 101 | coords = torch.zeros(indices.size(0), 102 | 3, 103 | device=indices.device, 104 | dtype=torch.int32) 105 | morton3D_invert_kernel(indices.contiguous(), coords) 106 | ti.sync() 107 | return coords 108 | 109 | 110 | @ti.kernel 111 | def morton3D_kernel(xyzs: ti.types.ndarray(ndim=2), 112 | indices: ti.types.ndarray(ndim=1)): 113 | for s in indices: 114 | xyz = uvec3([xyzs[s, 0], xyzs[s, 1], xyzs[s, 2]]) 115 | indices[s] = ti.cast(__morton3D(xyz), ti.int32) 116 | 117 | 118 | def morton3D(coords1): 119 | indices = torch.zeros(coords1.size(0), 120 | device=coords1.device, 121 | dtype=torch.int32) 122 | morton3D_kernel(coords1.contiguous(), indices) 123 | ti.sync() 124 | return indices 125 | 126 | 127 | @ti.kernel 128 | def packbits(density_grid: ti.types.ndarray(ndim=1), 129 | density_threshold: float, 130 | density_bitfield: ti.types.ndarray(ndim=1)): 131 | 132 | for n in density_bitfield: 133 | bits = ti.uint8(0) 134 | 135 | for i in ti.static(range(8)): 136 | bits |= (ti.uint8(1) << i) if ( 137 | density_grid[8 * n + i] > density_threshold) else ti.uint8(0) 138 | 139 | density_bitfield[n] = bits 140 | 141 | 142 | @ti.kernel 143 | def torch2ti(field: ti.template(), data: ti.types.ndarray()): 144 | for I in ti.grouped(data): 145 | field[I] = data[I] 146 | 147 | 148 | @ti.kernel 149 | def ti2torch(field: ti.template(), data: ti.types.ndarray()): 150 | for I in ti.grouped(data): 151 | data[I] = field[I] 152 | 153 | 154 | @ti.kernel 155 | def ti2torch_grad(field: ti.template(), grad: ti.types.ndarray()): 156 | for I in ti.grouped(grad): 157 | grad[I] = field.grad[I] 158 | 159 | 160 | @ti.kernel 161 | def torch2ti_grad(field: ti.template(), grad: ti.types.ndarray()): 162 | for I in ti.grouped(grad): 163 | field.grad[I] = grad[I] 164 | 165 | 166 | @ti.kernel 167 | def torch2ti_vec(field: ti.template(), data: ti.types.ndarray()): 168 | for I in range(data.shape[0] // 2): 169 | field[I] = ti.Vector([data[I * 2], data[I * 2 + 1]]) 170 | 171 | 172 | @ti.kernel 173 | def ti2torch_vec(field: ti.template(), data: ti.types.ndarray()): 174 | for i, j in ti.ndrange(data.shape[0], data.shape[1] // 2): 175 | data[i, j * 2] = field[i, j][0] 176 | data[i, j * 2 + 1] = field[i, j][1] 177 | 178 | 179 | @ti.kernel 180 | def ti2torch_grad_vec(field: ti.template(), grad: ti.types.ndarray()): 181 | for I in range(grad.shape[0] // 2): 182 | grad[I * 2] = field.grad[I][0] 183 | grad[I * 2 + 1] = field.grad[I][1] 184 | 185 | 186 | @ti.kernel 187 | def torch2ti_grad_vec(field: ti.template(), grad: ti.types.ndarray()): 188 | for i, j in ti.ndrange(grad.shape[0], grad.shape[1] // 2): 189 | field.grad[i, j][0] = grad[i, j * 2] 190 | field.grad[i, j][1] = grad[i, j * 2 + 1] 191 | 192 | 193 | def extract_model_state_dict(ckpt_path, 194 | model_name='model', 195 | prefixes_to_ignore=[]): 196 | checkpoint = torch.load(ckpt_path, map_location='cpu') 197 | checkpoint_ = {} 198 | if 'state_dict' in checkpoint: # if it's a pytorch-lightning checkpoint 199 | checkpoint = checkpoint['state_dict'] 200 | for k, v in checkpoint.items(): 201 | if not k.startswith(model_name): 202 | continue 203 | k = k[len(model_name) + 1:] 204 | for prefix in prefixes_to_ignore: 205 | if k.startswith(prefix): 206 | break 207 | else: 208 | checkpoint_[k] = v 209 | return checkpoint_ 210 | 211 | 212 | def load_ckpt(model, ckpt_path, model_name='model', prefixes_to_ignore=[]): 213 | if not ckpt_path: 214 | return 215 | model_dict = model.state_dict() 216 | checkpoint_ = 
extract_model_state_dict(ckpt_path, model_name, 217 | prefixes_to_ignore) 218 | model_dict.update(checkpoint_) 219 | model.load_state_dict(model_dict) 220 | 221 | def depth2img(depth): 222 | depth = (depth - depth.min()) / (depth.max() - depth.min()) 223 | depth_img = cv2.applyColorMap((depth * 255).astype(np.uint8), 224 | cv2.COLORMAP_TURBO) 225 | 226 | return depth_img -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import pdb 4 | import torch 5 | import vtk 6 | from vtk.util import numpy_support 7 | 8 | from ray_utils import get_rays, get_ray_directions, get_ndc_rays 9 | import os, imageio, json 10 | from tqdm import tqdm 11 | import torch.nn.functional as F 12 | from run_nerf_helpers import batchify_query, to8b 13 | from lpips import LPIPS 14 | from skimage.metrics import structural_similarity 15 | 16 | 17 | BOX_OFFSETS = torch.tensor([[[i,j,k] for i in [0, 1] for j in [0, 1] for k in [0, 1]]], 18 | device='cuda') 19 | 20 | class Vortex_Particles(torch.nn.Module): 21 | def __init__(self, P, T, R, fix_intensity=False): 22 | super(Vortex_Particles, self).__init__() 23 | self.P = P 24 | self.T = T 25 | 26 | self.initialized = False 27 | self.register_buffer('particle_time_mask', torch.zeros(P, T)) # [P, T] 28 | self.register_buffer('particle_pos_world', torch.zeros(P, T, 3)) # [P, T, 3] 29 | self.register_buffer('particle_dir_world', torch.zeros(P, T, 3)) # [P, T, 3] 30 | self.register_buffer('particle_intensity', torch.zeros(P, T, 1)) # [P, T, 1] 31 | self.register_buffer('radius', R * (0.5 * torch.rand(P, 1)+1)) # [P, 1] 32 | # self.radius = torch.nn.Parameter(R * torch.ones(P, 1)) # [P, 1] 33 | self.particle_intensity_raw = torch.nn.Parameter((10/P * torch.ones(P, 1)).clamp(0, 0.2)) # [P, 1] 34 | 35 | self.register_buffer('particle_time_coef', torch.zeros(P, T)) # [P, T] 36 | 37 | def initialize_with_state_dict(self, state_dict): 38 | self.load_state_dict(state_dict) 39 | self.particle_time_mask = self.particle_time_mask.bool() 40 | self.initialized = True 41 | print('Load vortex particles from state dict.') 42 | 43 | def initialize_from_generation(self, generated_dict): 44 | self.particle_time_mask = generated_dict['particle_time_mask'] 45 | self.particle_pos_world = generated_dict['particle_pos_world'] 46 | self.particle_dir_world = generated_dict['particle_dir_world'] 47 | self.particle_time_coef = generated_dict['particle_time_coef'] 48 | self.particle_intensity = generated_dict['particle_intensity'] / 200 49 | assert self.particle_time_mask.shape == (self.P, self.T) 50 | assert self.particle_time_coef.shape == (self.P, self.T) 51 | assert self.particle_pos_world.shape == (self.P, self.T, 3) 52 | assert self.particle_dir_world.shape == (self.P, self.T, 3) 53 | self.initialized = True 54 | 55 | def forward(self, coord_3d_world, time_idx, chunk=50): 56 | """ 57 | args: 58 | coord_3d_world: [..., 3] 59 | time_idx: int 60 | return: 61 | confinement_field: [..., 3] 62 | """ 63 | assert self.initialized, 'Vortex_Particles not initialized' 64 | mask_particle = self.particle_time_mask[:, time_idx] # [P, T] -> [P] 65 | particle_pos_world = self.particle_pos_world[:, time_idx] # [P, T, 3] -> [P, 3] 66 | particle_dir_world = self.particle_dir_world[:, time_idx] # [P, T, 3] -> [P, 3] 67 | particle_intensity = self.particle_intensity_raw.clamp(0, 10) + 1e-8 # [P, 1] 68 | particle_intensity = particle_intensity.pow(0.5) # 
associated with energy 69 | particle_intensity = particle_intensity * self.particle_intensity[:, time_idx] # [P, 1] 70 | radius = torch.relu(self.radius) 71 | if any(mask_particle): 72 | confinement_field = compute_confinement_field(particle_pos_world[mask_particle], particle_dir_world[mask_particle], 73 | particle_intensity[mask_particle], radius[mask_particle], coord_3d_world, chunk=chunk) 74 | else: 75 | confinement_field = torch.zeros_like(coord_3d_world) 76 | return confinement_field 77 | 78 | def vort_kernel(x, x_p, r): 79 | dist = torch.norm(x - x_p, dim=-1, keepdim=True) 80 | influence = torch.exp(-dist ** 2 / (2 * r ** 2)) / (r**3 * 40000) 81 | mask = dist < 3*r 82 | influence = influence * mask.float().detach() 83 | return influence 84 | 85 | def generate_vort_trajectory_curl(time_steps, bbox_model, rx=128, ry=192, rz=128, get_vel_der_fn=None, 86 | P=100, N_sample=2**10, den_net=None, **render_kwargs): 87 | print('Generating vortex trajectory using curl...') 88 | dt = time_steps[1] - time_steps[0] 89 | T = len(time_steps) 90 | 91 | # construct simulation domain grid 92 | xs, ys, zs = torch.meshgrid([torch.linspace(0, 1, rx), torch.linspace(0, 1, ry), torch.linspace(0, 1, rz)], indexing='ij') 93 | coord_3d_sim = torch.stack([xs, ys, zs], dim=-1) # [X, Y, Z, 3] 94 | coord_3d_world = bbox_model.sim2world(coord_3d_sim) # [X, Y, Z, 3] 95 | 96 | # initialize density field 97 | time_step = torch.ones_like(coord_3d_world[..., :1]) * time_steps[0] 98 | coord_4d_world = torch.cat([coord_3d_world, time_step], dim=-1) # [X, Y, Z, 4] 99 | 100 | # place empty vortex particles 101 | all_init_pos = [] 102 | all_init_dir = [] 103 | all_init_int = [] 104 | all_init_time = [] 105 | 106 | for i in range(P): 107 | # sample 4d points 108 | timesteps = 0.25 + torch.rand(N_sample) * 0.65 # sample from t=0.25 to t=0.9 109 | sampled_3d_coord_x = 0.25 + torch.rand(N_sample) * 0.5 # [N] 110 | sampled_3d_coord_y = 0.25 + torch.rand(N_sample) * 0.5 # [N] 111 | sampled_3d_coord_z = 0.25 + torch.rand(N_sample) * 0.5 # [N] 112 | sampled_3d_coord = torch.stack([sampled_3d_coord_x, sampled_3d_coord_y, sampled_3d_coord_z], dim=-1) # [N, 3] 113 | sampled_3d_coord_world = bbox_model.sim2world(sampled_3d_coord) # [N, 3] 114 | sampled_4d_coord_world = torch.cat([sampled_3d_coord_world, timesteps[:, None]], dim=-1) # [N, 4] 115 | 116 | # compute curl of sampled points 117 | density = den_net(sampled_4d_coord_world) # [N, 1] 118 | density = density.squeeze(-1) # [N] 119 | mask = density > 1 120 | curls = compute_curl_batch(sampled_4d_coord_world, get_vel_der_fn) # [N, 3] 121 | curls = curls[mask] 122 | timesteps = timesteps[mask] 123 | sampled_3d_coord_world = sampled_3d_coord_world[mask] 124 | curls_norm = curls.norm(dim=-1) # [N] 125 | print(i, 'max curl norm: ', curls_norm.max().item()) 126 | 127 | # get points with highest curl norm 128 | max_idx = curls_norm.argmax() # get points with highest curl norm 129 | init_pos = sampled_3d_coord_world[max_idx] # [3] 130 | init_dir = curls[max_idx] / curls_norm[max_idx] # [3] 131 | init_int = curls_norm[max_idx] # [1] 132 | init_time = timesteps[max_idx] # [1] 133 | all_init_pos.append(init_pos) 134 | all_init_dir.append(init_dir) 135 | all_init_int.append(init_int) 136 | all_init_time.append(init_time) 137 | 138 | all_init_pos = torch.stack(all_init_pos, dim=0) # [P, 3] 139 | all_init_dir = torch.stack(all_init_dir, dim=0) # [P, 3] 140 | all_init_int = torch.stack(all_init_int, dim=0)[:, None] # [P, 1] 141 | all_init_time = torch.stack(all_init_time, dim=0)[:, None] # [P, 1] 
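Each pass of the seeding loop above reduces to: keep only samples whose density clears the threshold, then take the sample with the largest curl magnitude as one vortex seed (position, unit direction, strength, spawn time). The snippet below is a self-contained, illustrative sketch of that selection; it is not part of `utils.py`, and the random tensors stand in for the `den_net` and `compute_curl_batch` queries:
```python
import torch

N = 1024
pos = torch.rand(N, 3)           # stand-in for sampled_3d_coord_world
density = torch.rand(N) * 2.0    # stand-in for den_net(...) output
curl = torch.randn(N, 3)         # stand-in for compute_curl_batch(...) output

mask = density > 1.0             # same density gate as above
pos_kept, curl_kept = pos[mask], curl[mask]
curl_norm = curl_kept.norm(dim=-1)
best = curl_norm.argmax()        # sample with the strongest vorticity

seed_pos = pos_kept[best]                      # plays the role of init_pos
seed_dir = curl_kept[best] / curl_norm[best]   # init_dir: unit vorticity direction
seed_strength = curl_norm[best]                # init_int
```
The loop above repeats this selection P times with fresh random samples, so each particle is seeded at a different high-vorticity location and time.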
142 | 143 | # initialize vortex particle position, direction, and when it spawns 144 | particle_start_timestep = all_init_time # [P, 1] 145 | particle_start_timestep = torch.floor(particle_start_timestep * T).expand(-1, T) # [P, T] 146 | particle_time_mask = torch.arange(T).unsqueeze(0).expand(P, -1) >= particle_start_timestep # [P, T] 147 | particle_time_coef = particle_time_mask.float() # [P, T] 148 | for time_coef in particle_time_coef: 149 | n = 20 150 | first_idx = time_coef.nonzero()[0] 151 | try: 152 | time_coef[first_idx:first_idx+n] = torch.linspace(0, 1, n) 153 | except: 154 | time_coef[first_idx:] = torch.linspace(0, 1, T - first_idx.item()) 155 | particle_pos_world = all_init_pos # [P, 3] 156 | particle_dir_world = all_init_dir # [P, 3] 157 | particle_int_multiplier = torch.ones_like(all_init_int) # [P, 1] 158 | particle_int = all_init_int.clone() # [P, 1] 159 | 160 | all_pos = [] 161 | all_dir = [] 162 | all_int = [] 163 | 164 | for i in range(T): 165 | # update simulation den and source den 166 | if i > 0: 167 | coord_4d_world[..., 3] = time_steps[i - 1] # sample velocity at previous moment 168 | vel = batchify_query(coord_4d_world, render_kwargs['network_query_fn_vel']) # [X, Y, Z, 3] 169 | 170 | # advect vortex particles 171 | mask_to_evolve = particle_time_mask[:, i] 172 | print('particles to evolve: ', mask_to_evolve.sum().item(), '/', P) 173 | if any(mask_to_evolve): 174 | particle_pos_world[mask_to_evolve] = advect_maccormack_particle(particle_pos_world[mask_to_evolve], vel, coord_3d_sim, dt, bbox_model=bbox_model, **render_kwargs) 175 | 176 | # stretch vortex particles 177 | grad_u, grad_v, grad_w = get_particle_vel_der(particle_pos_world[mask_to_evolve], bbox_model, get_vel_der_fn, time_steps[i - 1]) 178 | particle_dir_world[mask_to_evolve], particle_int_multiplier[mask_to_evolve] = stretch_vortex_particles(particle_dir_world[mask_to_evolve], grad_u, grad_v, grad_w, dt) 179 | particle_int[mask_to_evolve] = particle_int[mask_to_evolve] * particle_int_multiplier[mask_to_evolve] 180 | particle_int[particle_int > all_init_int] = all_init_int[particle_int > all_init_int] 181 | 182 | all_pos.append(particle_pos_world.clone()) 183 | all_dir.append(particle_dir_world.clone()) 184 | all_int.append(particle_int.clone()) 185 | particle_pos_world = torch.stack(all_pos, dim=0).permute(1, 0, 2) # [P, T, 3] 186 | particle_dir_world = torch.stack(all_dir, dim=0).permute(1, 0, 2) # [P, T, 3] 187 | particle_intensity = torch.stack(all_int, dim=0).permute(1, 0, 2) # [P, T, 1] 188 | radius = 0.03 * torch.ones(P, 1)[:, None].expand(-1, T, -1) # [P, T, 1] 189 | vort_particles = {'particle_time_mask': particle_time_mask, 190 | 'particle_pos_world': particle_pos_world, 191 | 'particle_dir_world': particle_dir_world, 192 | 'particle_intensity': particle_intensity, 193 | 'particle_time_coef': particle_time_coef, 194 | 'radius': radius} 195 | return vort_particles 196 | 197 | def stretch_vortex_particles(particle_dir, grad_u, grad_v, grad_w, dt): 198 | stretch_term = torch.cat([(particle_dir * grad_u).sum(dim=-1, keepdim=True), 199 | (particle_dir * grad_v).sum(dim=-1, keepdim=True), 200 | (particle_dir * grad_w).sum(dim=-1, keepdim=True), ], dim=-1) # [P, 3] 201 | particle_dir = particle_dir + stretch_term * dt 202 | particle_int = torch.norm(particle_dir, dim=-1, keepdim=True) 203 | particle_dir = particle_dir / (particle_int + 1e-8) 204 | return particle_dir, particle_int 205 | 206 | def get_particle_vel_der(particle_pos_3d_world, bbox_model, get_vel_der_fn, t): 207 | time_step = 
torch.ones_like(particle_pos_3d_world[..., :1]) * t 208 | particle_pos_4d_world = torch.cat([particle_pos_3d_world, time_step], dim=-1) # [P, 4] 209 | particle_pos_4d_world.requires_grad_() 210 | with torch.enable_grad(): 211 | _, _, _u_x, _u_y, _u_z, _u_t = get_vel_der_fn(particle_pos_4d_world) # [P, 3], partial der of u,v,w 212 | jac = torch.stack([_u_x, _u_y, _u_z], dim=-1) # [P, 3, 3] 213 | grad_u_world, grad_v_world, grad_w_world = jac[:, 0], jac[:, 1], jac[:, 2] # [P, 3] 214 | return grad_u_world, grad_v_world, grad_w_world 215 | 216 | def compute_confinement_field(particle_pos_world, particle_dir_world, particle_intensity, radius, coord_3d_world, chunk=50): 217 | """ 218 | :param particle_pos_world: [P, 3] 219 | :param particle_dir_world: [P, 3] 220 | :param particle_intensity: [P, 1] 221 | :param radius: [P, 1] 222 | :param coord_3d_world: [..., 3] 223 | :param chunk: int 224 | return: 225 | confinement_field: [..., 3] 226 | """ 227 | coord_3d_world_shape = coord_3d_world.shape 228 | assert coord_3d_world_shape[-1] == 3 229 | coord_3d_world = coord_3d_world.view(-1, 3) # [N, 3] 230 | P = particle_pos_world.shape[0] 231 | confinement_field = torch.zeros_like(coord_3d_world) # [N, 3] 232 | for i in range(0, P, chunk): 233 | location_field = particle_pos_world[i:i+chunk, None, :] - coord_3d_world # [P, N, 3] 234 | location_field = location_field / torch.norm(location_field, dim=-1, keepdim=True) # [P, N, 3] 235 | vorticity_field = vort_kernel(coord_3d_world, particle_pos_world[i:i+chunk, None, :], r=radius[i:i+chunk, None, :])\ 236 | * particle_dir_world[i:i+chunk, None, :] # [P, N, 3] 237 | confinement_field_each = particle_intensity[i:i+chunk, None, :] \ 238 | * torch.cross(location_field, vorticity_field, dim=-1) # [P, N, 3] 239 | confinement_field += confinement_field_each.sum(dim=0) # [N, 3] 240 | confinement_field = confinement_field.view(coord_3d_world_shape) # [..., 3] 241 | return confinement_field 242 | 243 | def compute_curl_batch(pts, get_vel_der_fn, chunk=64*96*64): 244 | pts_shape = pts.shape 245 | pts = pts.view(-1, pts_shape[-1]) # [N, 3] 246 | N = pts.shape[0] 247 | curls = [] 248 | for i in range(0, N, chunk): 249 | curl = compute_curl(pts[i:i+chunk], get_vel_der_fn) 250 | curls.append(curl) 251 | curl = torch.cat(curls, dim=0) # [N, 3] 252 | curl = curl.view(list(pts_shape[:-1]) + [3]) # [..., 3] 253 | return curl 254 | 255 | def compute_curl(pts, get_vel_der_fn): 256 | """ 257 | :param pts: [..., 4] 258 | :param get_vel_der_fn: function 259 | :return: 260 | curl: [..., 3] 261 | """ 262 | pts_shape = pts.shape 263 | pts = pts.view(-1, pts_shape[-1]) # [N, 3] 264 | pts.requires_grad_() 265 | with torch.enable_grad(): 266 | _, _, _u_x, _u_y, _u_z, _u_t = get_vel_der_fn(pts) # [N, 3], partial der of u,v,w 267 | jac = torch.stack([_u_x, _u_y, _u_z], dim=-1) # [N, 3, 3] 268 | curl = torch.stack([jac[:, 2, 1] - jac[:, 1, 2], 269 | jac[:, 0, 2] - jac[:, 2, 0], 270 | jac[:, 1, 0] - jac[:, 0, 1]], dim=-1) # [N, 3] 271 | curl = curl.view(list(pts_shape[:-1]) + [3]) # [..., 3] 272 | return curl 273 | 274 | def compute_curl_FD(vel, reverse_z=True): 275 | X, Y, Z, _ = vel.shape 276 | curl = torch.zeros_like(vel) 277 | 278 | if reverse_z: 279 | curl[1:-1, 1:-1, 1:-1, 0] = (vel[1:-1, 2:, 1:-1, 2] - vel[1:-1, :-2, 1:-1, 2]) / 2.0 - (vel[1:-1, 1:-1, :-2, 1] - vel[1:-1, 1:-1, 2:, 1]) / 2.0 280 | curl[1:-1, 1:-1, 1:-1, 1] = (vel[1:-1, 1:-1, :-2, 0] - vel[1:-1, 1:-1, 2:, 0]) / 2.0 - (vel[2:, 1:-1, 1:-1, 2] - vel[:-2, 1:-1, 1:-1, 2]) / 2.0 281 | curl[1:-1, 1:-1, 1:-1, 2] = (vel[2:, 
1:-1, 1:-1, 1] - vel[:-2, 1:-1, 1:-1, 1]) / 2.0 - (vel[1:-1, 2:, 1:-1, 0] - vel[1:-1, :-2, 1:-1, 0]) / 2.0 282 | 283 | else: 284 | curl[1:-1, 1:-1, 1:-1, 0] = (vel[1:-1, 2:, 1:-1, 2] - vel[1:-1, :-2, 1:-1, 2]) / 2.0 - (vel[1:-1, 1:-1, 2:, 1] - vel[1:-1, 1:-1, :-2, 1]) / 2.0 285 | curl[1:-1, 1:-1, 1:-1, 1] = (vel[1:-1, 1:-1, 2:, 0] - vel[1:-1, 1:-1, :-2, 0]) / 2.0 - (vel[2:, 1:-1, 1:-1, 2] - vel[:-2, 1:-1, 1:-1, 2]) / 2.0 286 | curl[1:-1, 1:-1, 1:-1, 2] = (vel[2:, 1:-1, 1:-1, 1] - vel[:-2, 1:-1, 1:-1, 1]) / 2.0 - (vel[1:-1, 2:, 1:-1, 0] - vel[1:-1, :-2, 1:-1, 0]) / 2.0 287 | return curl 288 | 289 | def compute_grad_FD(scalar_field): 290 | X, Y, Z, _ = scalar_field.shape 291 | grad = torch.zeros((X, Y, Z, 3), dtype=scalar_field.dtype, device=scalar_field.device) 292 | 293 | # Compute finite differences and update grad, except for boundaries 294 | grad[1:-1, :, :, 0] = (scalar_field[2:, :, :, 0] - scalar_field[:-2, :, :, 0]) / 2.0 295 | grad[:, 1:-1, :, 1] = (scalar_field[:, 2:, :, 0] - scalar_field[:, :-2, :, 0]) / 2.0 296 | grad[:, :, 1:-1, 2] = (scalar_field[:, :, 2:, 0] - scalar_field[:, :, :-2, 0]) / 2.0 297 | 298 | return grad 299 | 300 | def compute_curl_and_grad_batch(pts, get_vel_der_fn, chunk=64*96*64): 301 | pts_shape = pts.shape 302 | pts = pts.view(-1, pts_shape[-1]) # [N, 3] 303 | N = pts.shape[0] 304 | curls = [] 305 | vorticity_norm_grads = [] 306 | for i in range(0, N, chunk): 307 | curl, vorticity_norm_grad = compute_curl_and_grad(pts[i:i+chunk], get_vel_der_fn) 308 | curls.append(curl) 309 | vorticity_norm_grads.append(vorticity_norm_grad) 310 | curl = torch.cat(curls, dim=0) # [N, 3] 311 | vorticity_norm_grad = torch.cat(vorticity_norm_grads, dim=0) # [N, 3] 312 | curl = curl.view(list(pts_shape[:-1]) + [3]) # [..., 3] 313 | vorticity_norm_grad = vorticity_norm_grad.view(list(pts_shape[:-1]) + [3]) # [..., 3] 314 | return curl, vorticity_norm_grad 315 | 316 | def compute_curl_and_grad(pts, get_vel_der_fn): 317 | pts_shape = pts.shape 318 | pts = pts.view(-1, pts_shape[-1]) # [N, 3] 319 | pts.requires_grad_() 320 | with torch.enable_grad(): 321 | _, _, _u_x, _u_y, _u_z, _u_t = get_vel_der_fn(pts) # [N, 3], partial der of u,v,w 322 | jac = torch.stack([_u_x, _u_y, _u_z], dim=-1) # [N, 3, 3] 323 | curl = torch.stack([jac[:, 2, 1] - jac[:, 1, 2], 324 | jac[:, 0, 2] - jac[:, 2, 0], 325 | jac[:, 1, 0] - jac[:, 0, 1]], dim=-1) # [N, 3] 326 | 327 | vorticity_norm = torch.norm(curl, dim=-1, keepdim=True) 328 | vorticity_norm_grad = [] 329 | 330 | for j in range(vorticity_norm.shape[1]): 331 | dy_j_dx = torch.autograd.grad( 332 | vorticity_norm[:, j], 333 | pts, 334 | torch.ones_like(vorticity_norm[:, j], device=vorticity_norm.get_device()), 335 | retain_graph=True, 336 | create_graph=True, 337 | )[0].view(pts.shape[0], -1) 338 | vorticity_norm_grad.append(dy_j_dx.unsqueeze(1)) 339 | vorticity_norm_grad = torch.cat(vorticity_norm_grad, dim=1) 340 | curl = curl.view(list(pts_shape[:-1]) + [3]) # [..., 3] 341 | vorticity_norm_grad = vorticity_norm_grad.view(list(pts_shape[:-1]) + [4])[..., :3] # [..., 3] 342 | 343 | return curl, vorticity_norm_grad 344 | 345 | def run_advect_den(render_poses, hwf, K, time_steps, savedir, gt_imgs, bbox_model, rx=128, ry=192, rz=128, 346 | save_fields=False, save_den=False, vort_particles=None, render=None, get_vel_der_fn=None, **render_kwargs): 347 | H, W, focal = hwf 348 | dt = time_steps[1] - time_steps[0] 349 | render_kwargs.update(chunk=512 * 16) 350 | psnrs = [] 351 | lpipss = [] 352 | ssims = [] 353 | lpips_net = LPIPS().cuda() # input should be 
[-1, 1] or [0, 1] (normalize=True) 354 | 355 | # construct simulation domain grid 356 | xs, ys, zs = torch.meshgrid([torch.linspace(0, 1, rx), torch.linspace(0, 1, ry), torch.linspace(0, 1, rz)], indexing='ij') 357 | coord_3d_sim = torch.stack([xs, ys, zs], dim=-1) # [X, Y, Z, 3] 358 | coord_3d_world = bbox_model.sim2world(coord_3d_sim) # [X, Y, Z, 3] 359 | 360 | # initialize density field 361 | time_step = torch.ones_like(coord_3d_world[..., :1]) * time_steps[0] 362 | coord_4d_world = torch.cat([coord_3d_world, time_step], dim=-1) # [X, Y, Z, 4] 363 | den = batchify_query(coord_4d_world, render_kwargs['network_query_fn']) # [X, Y, Z, 1] 364 | den_ori = den 365 | vel = batchify_query(coord_4d_world, render_kwargs['network_query_fn_vel']) # [X, Y, Z, 3] 366 | vel_saved = vel 367 | bbox_mask = bbox_model.insideMask(coord_3d_world[..., :3].reshape(-1, 3), to_float=False) 368 | bbox_mask = bbox_mask.reshape(rx, ry, rz) 369 | 370 | source_height = 0.25 371 | y_start = int(source_height * ry) 372 | print('y_start: {}'.format(y_start)) 373 | render_kwargs.update(y_start=y_start) 374 | for i, c2w in enumerate(tqdm(render_poses)): 375 | # update simulation den and source den 376 | mask_to_sim = coord_3d_sim[..., 1] > source_height 377 | if i > 0: 378 | coord_4d_world[..., 3] = time_steps[i - 1] # sample velocity at previous moment 379 | 380 | vel = batchify_query(coord_4d_world, render_kwargs['network_query_fn_vel']) # [X, Y, Z, 3] 381 | vel_saved = vel 382 | # advect vortex particles 383 | if vort_particles is not None: 384 | confinement_field = vort_particles(coord_3d_world, i) 385 | print('Vortex energy over velocity: {:.2f}%'.format(torch.norm(confinement_field, dim=-1).pow(2).sum() / torch.norm(vel, dim=-1).pow(2).sum() * 100)) 386 | else: 387 | confinement_field = torch.zeros_like(vel) 388 | 389 | vel_confined = vel + confinement_field 390 | den, vel = advect_maccormack(den, vel_confined, coord_3d_sim, dt, bbox_model=bbox_model, **render_kwargs) 391 | den_ori = batchify_query(coord_4d_world, render_kwargs['network_query_fn']) # [X, Y, Z, 1] 392 | # zero grad for coord_4d_world 393 | # coord_4d_world.grad = None 394 | # coord_4d_world = coord_4d_world.detach() 395 | 396 | coord_4d_world[..., 3] = time_steps[i] # source density at current moment 397 | den[~mask_to_sim] = batchify_query(coord_4d_world[~mask_to_sim], render_kwargs['network_query_fn']) 398 | den[~bbox_mask] *= 0.0 399 | 400 | if save_fields: 401 | # save_fields_to_vti(vel.permute(2, 1, 0, 3).detach().cpu().numpy(), 402 | # den.permute(2, 1, 0, 3).detach().cpu().numpy(), 403 | # os.path.join(savedir, 'fields_{:03d}.vti'.format(i))) 404 | np.save(os.path.join(savedir, 'den_{:03d}.npy'.format(i)), den.permute(2, 1, 0, 3).detach().cpu().numpy()) 405 | np.save(os.path.join(savedir, 'den_ori_{:03d}.npy'.format(i)), den_ori.permute(2, 1, 0, 3).detach().cpu().numpy()) 406 | np.save(os.path.join(savedir, 'vel_{:03d}.npy'.format(i)), vel_saved.permute(2, 1, 0, 3).detach().cpu().numpy()) 407 | if save_den: 408 | # save_vdb(den[..., 0].detach().cpu().numpy(), 409 | # os.path.join(savedir, 'den_{:03d}.vdb'.format(i))) 410 | # save npy files 411 | np.save(os.path.join(savedir, 'den_{:03d}.npy'.format(i)), den[..., 0].detach().cpu().numpy()) 412 | rgb, _ = render(H, W, K, c2w=c2w[:3, :4], time_step=time_steps[i][None], render_grid=True, den_grid=den, 413 | **render_kwargs) 414 | rgb8 = to8b(rgb.detach().cpu().numpy()) 415 | if gt_imgs is not None: 416 | gt_img = torch.tensor(gt_imgs[i].squeeze(), dtype=torch.float32) # [H, W, 3] 417 | gt_img8 = 
to8b(gt_img.cpu().numpy()) 418 | gt_img = gt_img[90:960, 45:540] 419 | rgb = rgb[90:960, 45:540] 420 | lpips_value = lpips_net(rgb.permute(2, 0, 1), gt_img.permute(2, 0, 1), normalize=True).item() 421 | p = -10. * np.log10(np.mean(np.square(rgb.detach().cpu().numpy() - gt_img.cpu().numpy()))) 422 | ssim_value = structural_similarity(gt_img.cpu().numpy(), rgb.cpu().numpy(), data_range=1.0, channel_axis=2) 423 | lpipss.append(lpips_value) 424 | psnrs.append(p) 425 | ssims.append(ssim_value) 426 | print(f'PSNR: {p:.4g}, SSIM: {ssim_value:.4g}, LPIPS: {lpips_value:.4g}') 427 | imageio.imsave(os.path.join(savedir, 'rgb_{:03d}.png'.format(i)), rgb8) 428 | imageio.imsave(os.path.join(savedir, 'gt_{:03d}.png'.format(i)), gt_img8) 429 | merge_imgs(savedir, prefix='rgb_') 430 | merge_imgs(savedir, prefix='gt_') 431 | 432 | if gt_imgs is not None: 433 | avg_psnr = sum(psnrs)/len(psnrs) 434 | print(f"Avg PSNR over full simulation: ", avg_psnr) 435 | avg_ssim = sum(ssims)/len(ssims) 436 | print(f"Avg SSIM over full simulation: ", avg_ssim) 437 | avg_lpips = sum(lpipss)/len(lpipss) 438 | print(f"Avg LPIPS over full simulation: ", avg_lpips) 439 | with open(os.path.join(savedir, "psnrs_{:0.2f}_ssim_{:.2g}_lpips_{:.2g}.json".format(avg_psnr, avg_ssim, avg_lpips)), "w") as fp: 440 | json.dump(psnrs, fp) 441 | 442 | 443 | def run_future_pred(render_poses, hwf, K, time_steps, savedir, gt_imgs, bbox_model, rx=128, ry=192, rz=128, 444 | save_fields=False, vort_particles=None, project_solver=None, render=None, get_vel_der_fn=None, **render_kwargs): 445 | H, W, focal = hwf 446 | dt = time_steps[1] - time_steps[0] 447 | render_kwargs.update(chunk=512 * 16) 448 | psnrs = [] 449 | lpipss = [] 450 | ssims = [] 451 | lpips_net = LPIPS().cuda() # input should be [-1, 1] or [0, 1] (normalize=True) 452 | 453 | # construct simulation domain grid 454 | xs, ys, zs = torch.meshgrid([torch.linspace(0, 1, rx), torch.linspace(0, 1, ry), torch.linspace(0, 1, rz)], indexing='ij') 455 | coord_3d_sim = torch.stack([xs, ys, zs], dim=-1) # [X, Y, Z, 3] 456 | coord_3d_world = bbox_model.sim2world(coord_3d_sim) # [X, Y, Z, 3] 457 | 458 | # initialize density field 459 | starting_frame = 89 460 | n_pred = 30 461 | time_step = torch.ones_like(coord_3d_world[..., :1]) * time_steps[starting_frame] 462 | coord_4d_world = torch.cat([coord_3d_world, time_step], dim=-1) # [X, Y, Z, 4] 463 | den = batchify_query(coord_4d_world, render_kwargs['network_query_fn']) # [X, Y, Z, 1] 464 | vel = batchify_query(coord_4d_world, render_kwargs['network_query_fn_vel']) # [X, Y, Z, 3] 465 | 466 | source_height = 0.25 467 | y_start = int(source_height * ry) 468 | print('y_start: {}'.format(y_start)) 469 | render_kwargs.update(y_start=y_start) 470 | proj_y = render_kwargs['proj_y'] 471 | for idx, i in enumerate(range(starting_frame+1, starting_frame+n_pred+1)): 472 | c2w = render_poses[0] 473 | mask_to_sim = coord_3d_sim[..., 1] > source_height 474 | n_substeps = 1 475 | if vort_particles is not None: 476 | confinement_field = vort_particles(coord_3d_world, i) 477 | print('Vortex energy over velocity: {:.2f}%'.format( 478 | torch.norm(confinement_field, dim=-1).pow(2).sum() / torch.norm(vel, dim=-1).pow(2).sum() * 100)) 479 | else: 480 | confinement_field = torch.zeros_like(vel) 481 | vel_confined = vel + confinement_field 482 | 483 | for _ in range(n_substeps): 484 | dt_ = dt/n_substeps 485 | den, _ = advect_SL(den, vel_confined, coord_3d_sim, dt_, bbox_model=bbox_model, **render_kwargs) 486 | vel, _ = advect_SL(vel, vel, coord_3d_sim, dt_, 
bbox_model=bbox_model, **render_kwargs) 487 | vel[..., 2] *= -1 # world coord is left handed, while solver assumes right handed 488 | vel[:, y_start:y_start + proj_y] = project_solver.Poisson(vel[:, y_start:y_start + proj_y]) 489 | vel[..., 2] *= -1 490 | 491 | try: 492 | coord_4d_world[..., 3] = time_steps[i] # sample density source at current moment 493 | den[~mask_to_sim] = batchify_query(coord_4d_world[~mask_to_sim], render_kwargs['network_query_fn']) 494 | vel[~mask_to_sim] = batchify_query(coord_4d_world[~mask_to_sim], render_kwargs['network_query_fn_vel']) 495 | except IndexError: 496 | pass 497 | 498 | if save_fields: 499 | save_fields_to_vti(vel.permute(2, 1, 0, 3).cpu().numpy(), 500 | den.permute(2, 1, 0, 3).cpu().numpy(), 501 | os.path.join(savedir, 'fields_{:03d}.vti'.format(idx))) 502 | print('Saved fields to {}'.format(os.path.join(savedir, 'fields_{:03d}.vti'.format(idx)))) 503 | rgb, _ = render(H, W, K, c2w=c2w[:3, :4], time_step=time_steps[0][None], render_grid=True, den_grid=den, 504 | **render_kwargs) 505 | rgb8 = to8b(rgb.cpu().numpy()) 506 | try: 507 | gt_img = torch.tensor(gt_imgs[i].squeeze(), dtype=torch.float32) # [H, W, 3] 508 | gt_img8 = to8b(gt_img.cpu().numpy()) 509 | gt_img = gt_img[90:960, 45:540] 510 | rgb = rgb[90:960, 45:540] 511 | lpips_value = lpips_net(rgb.permute(2, 0, 1), gt_img.permute(2, 0, 1), normalize=True).item() 512 | p = -10. * np.log10(np.mean(np.square(rgb.detach().cpu().numpy() - gt_img.cpu().numpy()))) 513 | ssim_value = structural_similarity(gt_img.cpu().numpy(), rgb.cpu().numpy(), data_range=1.0, channel_axis=2) 514 | lpipss.append(lpips_value) 515 | psnrs.append(p) 516 | ssims.append(ssim_value) 517 | print(f'PSNR: {p:.4g}, SSIM: {ssim_value:.4g}, LPIPS: {lpips_value:.4g}') 518 | except IndexError: 519 | pass 520 | imageio.imsave(os.path.join(savedir, 'rgb_{:03d}.png'.format(idx)), rgb8) 521 | imageio.imsave(os.path.join(savedir, 'gt_{:03d}.png'.format(idx)), gt_img8) 522 | merge_imgs(savedir, framerate=10, prefix='rgb_') 523 | merge_imgs(savedir, framerate=10, prefix='gt_') 524 | 525 | if gt_imgs is not None: 526 | try: 527 | avg_psnr = sum(psnrs) / len(psnrs) 528 | print(f"Avg PSNR over full simulation: ", avg_psnr) 529 | avg_ssim = sum(ssims) / len(ssims) 530 | print(f"Avg SSIM over full simulation: ", avg_ssim) 531 | avg_lpips = sum(lpipss) / len(lpipss) 532 | print(f"Avg LPIPS over full simulation: ", avg_lpips) 533 | with open(os.path.join(savedir, "psnrs_{:0.2f}_ssim_{:.2g}_lpips_{:.2g}.json".format(avg_psnr, avg_ssim, avg_lpips)), "w") as fp: 534 | json.dump(psnrs, fp) 535 | except: 536 | pass 537 | 538 | def run_view_synthesis(render_poses, hwf, K, time_steps, savedir, gt_imgs, bbox_model, rx=128, ry=192, rz=128, 539 | save_fields=False, vort_particles=None, project_solver=None, render=None, get_vel_der_fn=None, **render_kwargs): 540 | H, W, focal = hwf 541 | dt = time_steps[1] - time_steps[0] 542 | render_kwargs.update(chunk=512 * 16) 543 | psnrs = [] 544 | lpipss = [] 545 | ssims = [] 546 | lpips_net = LPIPS().cuda() # input should be [-1, 1] or [0, 1] (normalize=True) 547 | 548 | # initialize density field 549 | starting_frame = 0 550 | n_pred = 120 551 | for idx, i in enumerate(range(starting_frame, starting_frame+n_pred)): 552 | c2w = render_poses[i] 553 | rgb, _ = render(H, W, K, c2w=c2w[:3, :4], time_step=time_steps[i][None], render_den=True, 554 | **render_kwargs) 555 | rgb8 = to8b(rgb.cpu().numpy()) 556 | if gt_imgs is not None: 557 | gt_img = torch.tensor(gt_imgs[i].squeeze(), dtype=torch.float32) # [H, W, 3] 558 | 
gt_img8 = to8b(gt_img.cpu().numpy()) 559 | gt_img = gt_img[90:960, 45:540] 560 | rgb = rgb[90:960, 45:540] 561 | lpips_value = lpips_net(rgb.permute(2, 0, 1), gt_img.permute(2, 0, 1), normalize=True).item() 562 | p = -10. * np.log10(np.mean(np.square(rgb.detach().cpu().numpy() - gt_img.cpu().numpy()))) 563 | ssim_value = structural_similarity(gt_img.cpu().numpy(), rgb.cpu().numpy(), data_range=1.0, channel_axis=2) 564 | lpipss.append(lpips_value) 565 | psnrs.append(p) 566 | ssims.append(ssim_value) 567 | print(f'PSNR: {p:.4g}, SSIM: {ssim_value:.4g}, LPIPS: {lpips_value:.4g}') 568 | imageio.imsave(os.path.join(savedir, 'rgb_{:03d}.png'.format(idx)), rgb8) 569 | imageio.imsave(os.path.join(savedir, 'gt_{:03d}.png'.format(idx)), gt_img8) 570 | merge_imgs(savedir, framerate=10, prefix='rgb_') 571 | merge_imgs(savedir, framerate=10, prefix='gt_') 572 | 573 | if gt_imgs is not None: 574 | avg_psnr = sum(psnrs) / len(psnrs) 575 | print(f"Avg PSNR over full simulation: ", avg_psnr) 576 | avg_ssim = sum(ssims) / len(ssims) 577 | print(f"Avg SSIM over full simulation: ", avg_ssim) 578 | avg_lpips = sum(lpipss) / len(lpipss) 579 | print(f"Avg LPIPS over full simulation: ", avg_lpips) 580 | with open(os.path.join(savedir, "psnrs_{:0.2f}_ssim_{:.2g}_lpips_{:.2g}.json".format(avg_psnr, avg_ssim, avg_lpips)), "w") as fp: 581 | json.dump(psnrs, fp) 582 | 583 | def advect_SL(q_grid, vel_world_prev, coord_3d_sim, dt, RK=2, y_start=48, proj_y=128, 584 | use_project=False, project_solver=None, bbox_model=None, **kwargs): 585 | """Advect a scalar quantity using a given velocity field. 586 | Args: 587 | q_grid: [X', Y', Z', C] 588 | vel_world_prev: [X, Y, Z, 3] 589 | coord_3d_sim: [X, Y, Z, 3] 590 | dt: float 591 | RK: int, number of Runge-Kutta steps 592 | y_start: where to start at y-axis 593 | proj_y: simulation domain resolution at y-axis 594 | use_project: whether to use Poisson solver 595 | project_solver: Poisson solver 596 | bbox_model: bounding box model 597 | Returns: 598 | advected_quantity: [X, Y, Z, 1] 599 | vel_world: [X, Y, Z, 3] 600 | """ 601 | if RK == 1: 602 | vel_world = vel_world_prev.clone() 603 | vel_world[:, y_start:y_start+proj_y] = project_solver.Poisson(vel_world[:, y_start:y_start+proj_y]) if use_project else vel_world[:, y_start:y_start+proj_y] 604 | vel_sim = bbox_model.world2sim_rot(vel_world) # [X, Y, Z, 3] 605 | elif RK == 2: 606 | vel_world = vel_world_prev.clone() # [X, Y, Z, 3] 607 | vel_world[:, y_start:y_start+proj_y] = project_solver.Poisson(vel_world[:, y_start:y_start+proj_y]) if use_project else vel_world[:, y_start:y_start+proj_y] 608 | # breakpoint() 609 | vel_sim = bbox_model.world2sim_rot(vel_world) # [X, Y, Z, 3] 610 | coord_3d_sim_midpoint = coord_3d_sim - 0.5 * dt * vel_sim # midpoint 611 | midpoint_sampled = coord_3d_sim_midpoint * 2 - 1 # [X, Y, Z, 3] 612 | vel_sim = F.grid_sample(vel_sim.permute(3, 2, 1, 0)[None], midpoint_sampled.permute(2, 1, 0, 3)[None], align_corners=True, padding_mode='zeros').squeeze(0).permute(3, 2, 1, 0) # [X, Y, Z, 3] 613 | else: 614 | raise NotImplementedError 615 | backtrace_coord = coord_3d_sim - dt * vel_sim # [X, Y, Z, 3] 616 | backtrace_coord_sampled = backtrace_coord * 2 - 1 # ranging [-1, 1] 617 | q_grid = q_grid[None, ...].permute([0, 4, 3, 2, 1]) # [N, C, Z, Y, X] i.e., [N, C, D, H, W] 618 | q_backtraced = F.grid_sample(q_grid, backtrace_coord_sampled.permute(2, 1, 0, 3)[None, ...], align_corners=True, padding_mode='zeros') # [N, C, D, H, W] 619 | q_backtraced = q_backtraced.squeeze(0).permute([3, 2, 1, 0]) # [X, Y, Z, C] 
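The core of `advect_SL` above is a semi-Lagrangian step: trace each grid point backwards along the velocity field by `dt` and resample the old quantity there. The snippet below is an illustrative 1D version, not part of `utils.py`; it swaps `F.grid_sample` for `np.interp` on a periodic domain purely to show the backtrace-and-resample mechanics:
```python
import numpy as np

n = 64
x = (np.arange(n) + 0.5) / n              # cell-center coordinates in [0, 1)
q = np.exp(-((x - 0.3) / 0.05) ** 2)      # scalar quantity to advect (a bump at x = 0.3)
vel = np.full(n, 0.5)                     # constant velocity in domain units per second
dt = 0.1

backtrace = (x - dt * vel) % 1.0          # analogue of coord_3d_sim - dt * vel_sim
q_advected = np.interp(backtrace, x, q, period=1.0)
# The bump has moved from x = 0.3 to roughly x = 0.35 after one step.
```
In the function above, the `RK == 2` branch refines this by first sampling the velocity at the half-step position `coord - 0.5 * dt * vel` before taking the full backtrace, and `advect_maccormack` below sharpens the result by advecting forward, tracing back again, and adding half of the round-trip error, clamped to the input's per-channel range.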
620 | return q_backtraced, vel_world 621 | 622 | def advect_maccormack(q_grid, vel_world_prev, coord_3d_sim, dt, **kwargs): 623 | """ 624 | Args: 625 | q_grid: [X', Y', Z', C] 626 | vel_world_prev: [X, Y, Z, 3] 627 | coord_3d_sim: [X, Y, Z, 3] 628 | dt: float 629 | Returns: 630 | advected_quantity: [X, Y, Z, C] 631 | vel_world: [X, Y, Z, 3] 632 | """ 633 | q_grid_next, _ = advect_SL(q_grid, vel_world_prev, coord_3d_sim, dt, **kwargs) 634 | q_grid_back, vel_world = advect_SL(q_grid_next, vel_world_prev, coord_3d_sim, -dt, **kwargs) 635 | q_advected = q_grid_next + (q_grid - q_grid_back) / 2 636 | C = q_advected.shape[-1] 637 | for i in range(C): 638 | q_max, q_min = q_grid[..., i].max(), q_grid[..., i].min() 639 | q_advected[..., i] = q_advected[..., i].clamp_(q_min, q_max) 640 | return q_advected, vel_world 641 | 642 | def advect_SL_particle(particle_pos, vel_world_prev, coord_3d_sim, dt, RK=2, y_start=48, proj_y=128, 643 | use_project=False, project_solver=None, bbox_model=None, **kwargs): 644 | """Advect a scalar quantity using a given velocity field. 645 | Args: 646 | particle_pos: [N, 3], in world coordinate domain 647 | vel_world_prev: [X, Y, Z, 3] 648 | coord_3d_sim: [X, Y, Z, 3] 649 | dt: float 650 | RK: int, number of Runge-Kutta steps 651 | y_start: where to start at y-axis 652 | proj_y: simulation domain resolution at y-axis 653 | use_project: whether to use Poisson solver 654 | project_solver: Poisson solver 655 | bbox_model: bounding box model 656 | Returns: 657 | new_particle_pos: [N, 3], in simulation coordinate domain 658 | """ 659 | if RK == 1: 660 | vel_world = vel_world_prev.clone() 661 | vel_world[:, y_start:y_start+proj_y] = project_solver.Poisson(vel_world[:, y_start:y_start+proj_y]) if use_project else vel_world[:, y_start:y_start+proj_y] 662 | vel_sim = bbox_model.world2sim_rot(vel_world) # [X, Y, Z, 3] 663 | elif RK == 2: 664 | vel_world = vel_world_prev.clone() # [X, Y, Z, 3] 665 | vel_world[:, y_start:y_start+proj_y] = project_solver.Poisson(vel_world[:, y_start:y_start+proj_y]) if use_project else vel_world[:, y_start:y_start+proj_y] 666 | vel_sim = bbox_model.world2sim_rot(vel_world) # [X, Y, Z, 3] 667 | coord_3d_sim_midpoint = coord_3d_sim - 0.5 * dt * vel_sim # midpoint 668 | midpoint_sampled = coord_3d_sim_midpoint * 2 - 1 # [X, Y, Z, 3] 669 | vel_sim = F.grid_sample(vel_sim.permute(3, 2, 1, 0)[None], midpoint_sampled.permute(2, 1, 0, 3)[None], align_corners=True).squeeze(0).permute(3, 2, 1, 0) # [X, Y, Z, 3] 670 | else: 671 | raise NotImplementedError 672 | particle_pos_sampled = bbox_model.world2sim(particle_pos) * 2 - 1 # ranging [-1, 1] 673 | particle_vel_sim = F.grid_sample(vel_sim.permute(3, 2, 1, 0)[None], particle_pos_sampled[None, None, None], align_corners=True).permute([0, 2, 3, 4, 1]).flatten(0, 3) # [N, 3] 674 | particle_pos_new = particle_pos + dt * bbox_model.sim2world_rot(particle_vel_sim) # [N, 3] 675 | return particle_pos_new 676 | 677 | def advect_maccormack_particle(particle_pos, vel_world_prev, coord_3d_sim, dt, **kwargs): 678 | """ 679 | Args: 680 | particle_pos: [N, 3], in world coordinate domain 681 | vel_world_prev: [X, Y, Z, 3] 682 | coord_3d_sim: [X, Y, Z, 3] 683 | dt: float 684 | Returns: 685 | particle_pos_new: [N, 3], in simulation coordinate domain 686 | """ 687 | particle_pos_next = advect_SL_particle(particle_pos, vel_world_prev, coord_3d_sim, dt, **kwargs) 688 | particle_pos_back = advect_SL_particle(particle_pos_next, vel_world_prev, coord_3d_sim, -dt, **kwargs) 689 | particle_pos_new = particle_pos_next + (particle_pos - 
particle_pos_back) / 2 690 | return particle_pos_new 691 | 692 | 693 | def merge_imgs(save_dir, framerate=30, prefix=''): 694 | os.system( 695 | 'ffmpeg -hide_banner -loglevel error -y -i {0}/{1}%03d.png -vf palettegen {0}/palette.png'.format(save_dir, 696 | prefix)) 697 | os.system( 698 | 'ffmpeg -hide_banner -loglevel error -y -framerate {0} -i {1}/{2}%03d.png -i {1}/palette.png -lavfi paletteuse {1}/_{2}.gif'.format( 699 | framerate, save_dir, prefix)) 700 | os.system( 701 | 'ffmpeg -hide_banner -loglevel error -y -framerate {0} -i {1}/{2}%03d.png -i {1}/palette.png -lavfi paletteuse -vcodec prores {1}/_{2}.mov'.format( 702 | framerate, save_dir, prefix)) 703 | 704 | 705 | def hash(coords, log2_hashmap_size): 706 | ''' 707 | coords: this function can process upto 7 dim coordinates 708 | log2T: logarithm of T w.r.t 2 709 | ''' 710 | primes = [1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737] 711 | 712 | xor_result = torch.zeros_like(coords)[..., 0] 713 | for i in range(coords.shape[-1]): 714 | xor_result ^= coords[..., i]*primes[i] 715 | 716 | return torch.tensor((1< pt[i]): 738 | min_bound[i] = pt[i] 739 | if(max_bound[i] < pt[i]): 740 | max_bound[i] = pt[i] 741 | return 742 | 743 | for i in [0, W-1, H*W-W, H*W-1]: 744 | min_point = rays_o[i] + near*rays_d[i] 745 | max_point = rays_o[i] + far*rays_d[i] 746 | points += [min_point, max_point] 747 | find_min_max(min_point) 748 | find_min_max(max_point) 749 | 750 | return (torch.tensor(min_bound)-torch.tensor([1.0,1.0,1.0]), torch.tensor(max_bound)+torch.tensor([1.0,1.0,1.0])) 751 | 752 | 753 | def get_bbox3d_for_llff(poses, hwf, near=0.0, far=1.0): 754 | H, W, focal = hwf 755 | H, W = int(H), int(W) 756 | 757 | # ray directions in camera coordinates 758 | directions = get_ray_directions(H, W, focal) 759 | 760 | min_bound = [100, 100, 100] 761 | max_bound = [-100, -100, -100] 762 | 763 | points = [] 764 | poses = torch.FloatTensor(poses) 765 | for pose in poses: 766 | rays_o, rays_d = get_rays(directions, pose) 767 | rays_o, rays_d = get_ndc_rays(H, W, focal, 1.0, rays_o, rays_d) 768 | 769 | def find_min_max(pt): 770 | for i in range(3): 771 | if(min_bound[i] > pt[i]): 772 | min_bound[i] = pt[i] 773 | if(max_bound[i] < pt[i]): 774 | max_bound[i] = pt[i] 775 | return 776 | 777 | for i in [0, W-1, H*W-W, H*W-1]: 778 | min_point = rays_o[i] + near*rays_d[i] 779 | max_point = rays_o[i] + far*rays_d[i] 780 | points += [min_point, max_point] 781 | find_min_max(min_point) 782 | find_min_max(max_point) 783 | 784 | return (torch.tensor(min_bound)-torch.tensor([0.1,0.1,0.0001]), torch.tensor(max_bound)+torch.tensor([0.1,0.1,0.0001])) 785 | 786 | 787 | def get_voxel_vertices(xyz, bounding_box, resolution, log2_hashmap_size): 788 | ''' 789 | xyz: 3D coordinates of samples. B x 3 790 | bounding_box: min and max x,y,z coordinates of object bbox 791 | resolution: number of voxels per axis 792 | ''' 793 | box_min, box_max = bounding_box 794 | 795 | keep_mask = xyz==torch.max(torch.min(xyz, box_max), box_min) 796 | if not torch.all(xyz <= box_max) or not torch.all(xyz >= box_min): 797 | # print("ALERT: some points are outside bounding box. 
Clipping them!") 798 | xyz = torch.clamp(xyz, min=box_min, max=box_max) 799 | 800 | grid_size = (box_max-box_min)/resolution 801 | 802 | bottom_left_idx = torch.floor((xyz-box_min)/grid_size).int() 803 | voxel_min_vertex = bottom_left_idx*grid_size + box_min 804 | voxel_max_vertex = voxel_min_vertex + torch.tensor([1.0,1.0,1.0])*grid_size 805 | 806 | voxel_indices = bottom_left_idx.unsqueeze(1) + BOX_OFFSETS 807 | hashed_voxel_indices = hash(voxel_indices, log2_hashmap_size) 808 | 809 | return voxel_min_vertex, voxel_max_vertex, hashed_voxel_indices, keep_mask 810 | 811 | 812 | def pos_world2smoke(Pworld, w2s, scale_vector): 813 | pos_rot = torch.sum(Pworld[..., None, :] * (w2s[:3,:3]), -1) # 4.world to 3.target 814 | pos_off = (w2s[:3, -1]).expand(pos_rot.shape) # 4.world to 3.target 815 | new_pose = pos_rot + pos_off 816 | pos_scale = new_pose / (scale_vector) # 3.target to 2.simulation 817 | return pos_scale 818 | 819 | class BBox_Tool(object): 820 | def __init__(self, smoke_tran_inv, smoke_scale, in_min=[0.15, 0.0, 0.15], in_max=[0.85, 1., 0.85]): 821 | self.s_w2s = torch.tensor(smoke_tran_inv).expand([4, 4]).float() 822 | self.s2w = torch.inverse(self.s_w2s) 823 | self.s_scale = torch.tensor(smoke_scale.copy()).expand([3]).float() 824 | self.s_min = torch.Tensor(in_min) 825 | self.s_max = torch.Tensor(in_max) 826 | 827 | def world2sim(self, pts_world): 828 | pts_world_homo = torch.cat([pts_world, torch.ones_like(pts_world[..., :1])], dim=-1) 829 | pts_sim_ = torch.matmul(self.s_w2s, pts_world_homo[..., None]).squeeze(-1)[..., :3] 830 | pts_sim = pts_sim_ / (self.s_scale) # 3.target to 2.simulation 831 | return pts_sim 832 | 833 | def world2sim_rot(self, pts_world): 834 | pts_sim_ = torch.matmul(self.s_w2s[:3, :3], pts_world[..., None]).squeeze(-1) 835 | pts_sim = pts_sim_ / (self.s_scale) # 3.target to 2.simulation 836 | return pts_sim 837 | 838 | def sim2world(self, pts_sim): 839 | pts_sim_ = pts_sim * self.s_scale 840 | pts_sim_homo = torch.cat([pts_sim_, torch.ones_like(pts_sim_[..., :1])], dim=-1) 841 | pts_world = torch.matmul(self.s2w, pts_sim_homo[..., None]).squeeze(-1)[..., :3] 842 | return pts_world 843 | 844 | def sim2world_rot(self, pts_sim): 845 | pts_sim_ = pts_sim * self.s_scale 846 | pts_world = torch.matmul(self.s2w[:3, :3], pts_sim_[..., None]).squeeze(-1) 847 | return pts_world 848 | 849 | def isInside(self, inputs_pts): 850 | target_pts = pos_world2smoke(inputs_pts, self.s_w2s, self.s_scale) 851 | above = torch.logical_and(target_pts[...,0] >= self.s_min[0], target_pts[...,1] >= self.s_min[1] ) 852 | above = torch.logical_and(above, target_pts[...,2] >= self.s_min[2] ) 853 | below = torch.logical_and(target_pts[...,0] <= self.s_max[0], target_pts[...,1] <= self.s_max[1] ) 854 | below = torch.logical_and(below, target_pts[...,2] <= self.s_max[2] ) 855 | outputs = torch.logical_and(below, above) 856 | return outputs 857 | 858 | def insideMask(self, inputs_pts, to_float=True): 859 | return self.isInside(inputs_pts).to(torch.float) if to_float else self.isInside(inputs_pts) 860 | 861 | 862 | class AverageMeter(object): 863 | """Computes and stores the average and current value""" 864 | val = 0 865 | avg = 0 866 | sum = 0 867 | count = 0 868 | tot_count = 0 869 | 870 | def __init__(self): 871 | self.reset() 872 | self.tot_count = 0 873 | 874 | def reset(self): 875 | self.val = 0 876 | self.avg = 0 877 | self.sum = 0 878 | self.count = 0 879 | 880 | def update(self, val, n=1): 881 | self.val = float(val) 882 | self.sum += float(val) * n 883 | self.count += n 884 | 
self.tot_count += n 885 | self.avg = self.sum / self.count 886 | 887 | def write_ply(points, filename, text=True): 888 | from plyfile import PlyData, PlyElement 889 | """ input: Nx3 or Nx6, write points to filename as PLY format. """ 890 | if points.shape[1] == 3: 891 | points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])] 892 | vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')]) 893 | elif points.shape[1] == 6: 894 | if points[:, 3:6].max() <= 1.0: 895 | points[:, 3:6] *= 255 896 | points = [(p[0], p[1], p[2], p[3], p[4], p[5]) for p in points] 897 | vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4'),('red', 'u1'), ('green', 'u1'),('blue', 'u1')]) 898 | else: 899 | assert False, 'points shape:{}, not valid (2nd dim should be 3 or 6).'.format(points.shape) 900 | el = PlyElement.describe(vertex, 'vertex', comments=['vertices']) 901 | PlyData([el], text=text).write(filename) 902 | 903 | def save_time_varying_fields_to_vti(velocity_field, density_field=None, save_dir='', basename='fields'): 904 | """ 905 | Save a time-varying velocity field and density field to a series of VTI files. 906 | args: 907 | velocity_field: a 5D NumPy array of shape (T, D, H, W, 3) containing the velocity field at each time step 908 | density_field (optional): a 5D NumPy array of shape (T, D, H, W, 1) containing the density field at each time step 909 | save_dir: the directory to save the VTI files 910 | basename: the base name of the VTI files to be saved 911 | """ 912 | assert velocity_field.ndim == 5 and velocity_field.shape[4] == 3, "Invalid velocity field shape" 913 | if density_field is not None: 914 | assert density_field.ndim == 5 and density_field.shape[4] == 1, "Invalid density field shape" 915 | assert velocity_field.shape[:4] == density_field.shape[:4], "Velocity and density fields must have the same time and grid dimensions" 916 | 917 | T, D, H, W, _ = velocity_field.shape 918 | 919 | for t in range(T): 920 | save_path = os.path.join(save_dir, f"{basename}_{t:04d}.vti") 921 | single_time_velocity_field = velocity_field[t, :, :, :] 922 | single_time_density_field = None if density_field is None else density_field[t, :, :, :] 923 | 924 | save_fields_to_vti(single_time_velocity_field, single_time_density_field, save_path) 925 | 926 | def save_fields_to_vti(velocity_field, density_field=None, save_path='fields.vti', vel_name='velocity', den_name='density'): 927 | D, H, W, _ = velocity_field.shape 928 | 929 | # Create a VTK image data object 930 | image_data = vtk.vtkImageData() 931 | image_data.SetDimensions(W, H, D) 932 | image_data.SetSpacing(1, 1, 1) 933 | 934 | # Convert the velocity NumPy array to a VTK array 935 | vtk_velocity_array = numpy_support.numpy_to_vtk(velocity_field.reshape(-1, 3), deep=True, array_type=vtk.VTK_FLOAT) 936 | vtk_velocity_array.SetName(vel_name) 937 | image_data.GetPointData().SetVectors(vtk_velocity_array) 938 | 939 | # Convert the density NumPy array to a VTK array 940 | if density_field is not None: 941 | vtk_density_array = numpy_support.numpy_to_vtk(density_field.ravel(), deep=True, array_type=vtk.VTK_FLOAT) 942 | vtk_density_array.SetName(den_name) 943 | image_data.GetPointData().SetScalars(vtk_density_array) 944 | 945 | # Save the image data object to a VTI file 946 | writer = vtk.vtkXMLImageDataWriter() 947 | writer.SetFileName(save_path) 948 | writer.SetInputData(image_data) 949 | writer.Write() 950 | 951 | def advect_bfecc(q_grid, coord_3d_sim, coord_4d_world, dt, RK=1, vel_net=None): 952 | """ 953 | 
Args: 954 | q_grid: [X, Y, Z, C] 955 | coord_3d_sim: [X, Y, Z, 3] 956 | coord_4d_world: [X, Y, Z, 4] 957 | dt: float 958 | RK: int, number of Runge-Kutta steps 959 | vel_net: function, velocity network 960 | Returns: 961 | advected_quantity: [XYZ, C] 962 | """ 963 | X, Y, Z, _ = coord_3d_sim.shape 964 | C = q_grid.shape[-1] 965 | q_grid_next = advect_SL(q_grid, coord_3d_sim.view(-1, 3), coord_4d_world.view(-1, 4), dt, RK=RK, vel_net=vel_net) 966 | q_grid_back = advect_SL(q_grid_next.view(X, Y, Z, -1), coord_3d_sim.view(-1, 3), coord_4d_world.view(-1, 4), -dt, RK=RK, vel_net=vel_net) 967 | q_grid_corrected = q_grid + (q_grid - q_grid_back.view(X, Y, Z, -1)) / 2 968 | q_advected = advect_SL(q_grid_corrected, coord_3d_sim.view(-1, 3), coord_4d_world.view(-1, 4), dt, RK=RK, vel_net=vel_net) 969 | return q_advected 970 | 971 | if __name__=="__main__": 972 | with open("data/nerf_synthetic/chair/transforms_train.json", "r") as f: 973 | camera_transforms = json.load(f) 974 | 975 | bounding_box = get_bbox3d_for_blenderobj(camera_transforms, 800, 800) 976 | --------------------------------------------------------------------------------
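
A minimal usage sketch for the `advect_bfecc` helper above, assuming `utils.py` is importable and using a hypothetical grid resolution, time step, and stand-in `vel_net` (the real velocity network and the simulation-to-world mapping come from the training scripts, not from `utils.py`):

```python
import torch
from utils import advect_bfecc  # assumes this repo's utils.py is on the path

# Hypothetical resolution and time step; the actual values come from the training config.
X, Y, Z, dt = 64, 96, 64, 1.0 / 120.0

# Regular sampling grid in [0, 1]^3 simulation space, shape [X, Y, Z, 3].
xs, ys, zs = torch.meshgrid(
    torch.linspace(0, 1, X), torch.linspace(0, 1, Y), torch.linspace(0, 1, Z), indexing='ij')
coord_3d_sim = torch.stack([xs, ys, zs], dim=-1)

# Append a (normalized) time channel to form 4D query coordinates, shape [X, Y, Z, 4].
# In the actual pipeline the spatial part would additionally be mapped to world space,
# e.g. via BBox_Tool.sim2world; this sketch keeps it self-contained.
t_cur = 0.5
coord_4d_world = torch.cat(
    [coord_3d_sim, torch.full_like(coord_3d_sim[..., :1], t_cur)], dim=-1)

def vel_net(pts_4d):
    # Stand-in for the learned velocity field: a constant upward flow.
    # Assumed interface: [N, 4] space-time points in, [N, 3] velocities out.
    vel = torch.zeros(pts_4d.shape[0], 3)
    vel[:, 1] = 0.1
    return vel

density = torch.rand(X, Y, Z, 1)  # quantity to advect, [X, Y, Z, C]
density_next = advect_bfecc(density, coord_3d_sim, coord_4d_world, dt, RK=2, vel_net=vel_net)
# Per the docstring the result is flattened to [X*Y*Z, C]; reshape to compare with the grid.
density_next = density_next.reshape(X, Y, Z, -1)
```

The error-compensation step in `advect_bfecc` is the standard BFECC scheme: advect forward, advect back, and subtract half of the round-trip error before the final forward advection, which reduces the numerical diffusion of plain semi-Lagrangian advection.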