├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── cow.mtl
│   ├── cow.obj
│   └── cow_texture.png
├── fm_render.py
├── generate_inputs.ipynb
├── get_co3d.sh
├── images_to_video.sh
├── pose_estimation.ipynb
├── requirements.txt
├── run_co3d_sp-zpfm.ipynb
├── run_co3d_sp.ipynb
├── util.py
├── util_load.py
├── util_render.py
├── utils_opt.py
└── zpfm_render.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # project stuff
2 | *.png
3 | *.ply
4 | data/
5 | *.mp4
6 | *.pkl
7 | rvid/
8 | tmp_out/
9 | tmp_out_zpfm/
10 | *.zip
11 | teddybear/
12 | 
13 | # Byte-compiled / optimized / DLL files
14 | __pycache__/
15 | *.py[cod]
16 | *$py.class
17 | 
18 | # C extensions
19 | *.so
20 | 
21 | # Distribution / packaging
22 | .Python
23 | build/
24 | develop-eggs/
25 | dist/
26 | downloads/
27 | eggs/
28 | .eggs/
29 | lib/
30 | lib64/
31 | parts/
32 | sdist/
33 | var/
34 | wheels/
35 | pip-wheel-metadata/
36 | share/python-wheels/
37 | *.egg-info/
38 | .installed.cfg
39 | *.egg
40 | MANIFEST
41 | 
42 | # PyInstaller
43 | # Usually these files are written by a python script from a template
44 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
45 | *.manifest
46 | *.spec
47 | 
48 | # Installer logs
49 | pip-log.txt
50 | pip-delete-this-directory.txt
51 | 
52 | # Unit test / coverage reports
53 | htmlcov/
54 | .tox/
55 | .nox/
56 | .coverage
57 | .coverage.*
58 | .cache
59 | nosetests.xml
60 | coverage.xml
61 | *.cover
62 | *.py,cover
63 | .hypothesis/
64 | .pytest_cache/
65 | 
66 | # Translations
67 | *.mo
68 | *.pot
69 | 
70 | # Django stuff:
71 | *.log
72 | local_settings.py
73 | db.sqlite3
74 | db.sqlite3-journal
75 | 
76 | # Flask stuff:
77 | instance/
78 | .webassets-cache
79 | 
80 | # Scrapy stuff:
81 | .scrapy
82 | 
83 | # Sphinx documentation
84 | docs/_build/
85 | 
86 | # PyBuilder
87 | target/
88 | 
89 | # Jupyter Notebook
90 | .ipynb_checkpoints
91 | 
92 | # IPython
93 | profile_default/
94 | ipython_config.py
95 | 
96 | # pyenv
97 | .python-version
98 | 
99 | # pipenv
100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
103 | # install all needed dependencies.
104 | #Pipfile.lock
105 | 
106 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
107 | __pypackages__/
108 | 
109 | # Celery stuff
110 | celerybeat-schedule
111 | celerybeat.pid
112 | 
113 | # SageMath parsed files
114 | *.sage.py
115 | 
116 | # Environments
117 | .env
118 | .venv
119 | env/
120 | venv/
121 | ENV/
122 | env.bak/
123 | venv.bak/
124 | 
125 | # Spyder project settings
126 | .spyderproject
127 | .spyproject
128 | 
129 | # Rope project settings
130 | .ropeproject
131 | 
132 | # mkdocs documentation
133 | /site
134 | 
135 | # mypy
136 | .mypy_cache/
137 | .dmypy.json
138 | dmypy.json
139 | 
140 | # Pyre type checker
141 | .pyre/
142 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 | 
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 | 1. Definitions.
8 | 
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at
194 | 
195 | http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Fuzzy Metaballs+: Differentiable Rendering with 3D Gaussians, Flow and More.
2 | # [Project Page](https://leonidk.github.io/fmb-plus/)
3 | This is an expanded version of the original [FM renderer](https://leonidk.github.io/fuzzy-metaballs/) with support for flow, mesh exporting, and 2-parameter and zero-parameter renderers.
4 | 
5 | It is primarily useful for reconstructing CO3D sequences. The general workflow is as follows (a runnable sketch follows the list):
6 | 
7 | * Git clone [unimatch](https://github.com/autonomousvision/unimatch) (for generating optical flow) and [XMem](https://github.com/hkchengrex/XMem) (for propagating the first mask) into the same root directory as this repository.
8 | * Fetch some CO3D sequences, for example with `get_co3d.sh` for the single-sequence teddy bear.
9 | * Run `generate_inputs.ipynb` to generate the flows and masks used by the reconstruction.
10 | * Run `run_co3d_sp.ipynb` or `run_co3d_sp-zpfm.ipynb` to run the reconstruction with either the two-parameter or zero-parameter model.
11 | * Compile [PoissonRecon](https://github.com/mkazhdan/PoissonRecon) and run it to generate a mesh via `PoissonRecon --in tmp_out/teddybear_34_1479_4753.ply --out teddy.ply --bType 2 --depth 6`
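A minimal sketch of the steps above, assuming the default teddy bear sequence and that unimatch/XMem (with their pretrained weights, e.g. `pretrained/gmflow-scale1-mixdata-train320x576-4c3a6e9a.pth` and `saves/XMem.pth`) sit next to this repository; driving the notebooks with `jupyter nbconvert --execute` is just one option, and opening them interactively works equally well:

```sh
# flow and mask-propagation dependencies, cloned as siblings of this repo
git clone https://github.com/autonomousvision/unimatch
git clone https://github.com/hkchengrex/XMem

# fetch a CO3D sequence, then build flows/masks and run the reconstruction
bash get_co3d.sh
jupyter nbconvert --execute --to notebook generate_inputs.ipynb
jupyter nbconvert --execute --to notebook run_co3d_sp.ipynb

# mesh the recovered Gaussians with PoissonRecon
PoissonRecon --in tmp_out/teddybear_34_1479_4753.ply --out teddy.ply --bType 2 --depth 6
```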
12 | 
13 | ## TODO
14 | 
15 | * Clean up code so it's easier to run non-CO3D sequences
16 | * Add an importer for the released version of the 3DGS format
--------------------------------------------------------------------------------
/data/cow.mtl:
--------------------------------------------------------------------------------
1 | newmtl material_1
2 | map_Kd cow_texture.png
3 | 
4 | # Test colors
5 | 
6 | Ka 1.000 1.000 1.000 # white
7 | Kd 1.000 1.000 1.000 # white
8 | Ks 0.000 0.000 0.000 # black
9 | Ns 10.0
10 | 
--------------------------------------------------------------------------------
/data/cow_texture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leonidk/fmb-plus/235a078a402968554186a2ca752fb13afffb84f8/data/cow_texture.png
--------------------------------------------------------------------------------
/fm_render.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | 
4 | # this file implements most of https://arxiv.org/abs/2308.14737
5 | 
6 | # contains various rotation conversions
7 | from util_render import *
8 | 
9 | # core rendering function
10 | def render_func_rays(means, prec_full, weights_log, camera_starts_rays, beta_2, beta_3):
11 |     # precision is fully parameterized by triangle matrix
12 |     # we use upper triangle for compatibility with sklearn
13 |     prec = jnp.triu(prec_full)
14 | 
15 |     # if doing normalized weights for proper GMM
16 |     # typically not used for shape reconstruction
17 |     #weights = jnp.exp(weights_log)
18 |     #weights = weights/weights.sum()
19 | 
20 |     # gets run per gaussian with [precision, log(weight), mean]
21 |     def perf_idx(prcI,w,meansI):
22 |         # math is easier with lower triangle
23 |         prc = prcI.T
24 | 
25 |         # gaussian scale
26 |         # could be useful for log likelihood but not used here
27 |         div = jnp.prod(jnp.diag(jnp.abs(prc))) + 1e-20
28 | 
29 |         # gets run per ray
30 |         def perf_ray(r_t):
31 |             # unpack the ray (r) and position (t)
32 |             r = r_t[0]
33 |             t = r_t[1]
34 | 
35 |             # shift the mean to be relative to ray start
36 |             p = meansI - t
37 | 
38 |             # compute \sigma^{-0.5} p, which is reused
39 |             projp = prc @ p
40 | 
41 |             # compute v^T \sigma^{-1} v
42 |             vsv = ((prc @ r)**2).sum()
43 | 
44 |             # compute p^T \sigma^{-1} v
45 |             psv = ((projp) * (prc@r)).sum()
46 | 
47 |             # compute the surface normal as \sigma^{-1} p
48 |             projp2 = prc.T @ projp
49 | 
50 |             # distance to get maximum likelihood point for this gaussian
51 |             # scale here is based on r!
52 |             # if r = [x, y, 1], then depth. if ||r|| = 1, then distance
53 |             res = (psv)/(vsv)
54 | 
55 |             # get the intersection point
56 |             v = r * res - p
57 | 
58 |             # compute intersection's unnormalized Gaussian log likelihood
59 |             d0 = ((prc @ v)**2).sum()# + 3*jnp.log(jnp.pi*2)
60 | 
61 |             # multiply by the weight
62 |             d2 = -0.5*d0 + w
63 | 
64 |             # if you wanted real probability
65 |             #d3 = d2 + jnp.log(div) #+ 3*jnp.log(res)
66 | 
67 |             # compute a normalized normal
68 |             norm_est = projp2/jnp.linalg.norm(projp2)
69 |             norm_est = jnp.where(r@norm_est < 0,norm_est,-norm_est)
70 | 
71 |             # return ray distance, gaussian distance, normal
72 |             return res, d2, norm_est
73 | 
74 |         # runs parallel for each ray across each gaussian
75 |         res,d2,projp = jax.vmap((perf_ray))(camera_starts_rays)
76 | 
77 |         return res, d2,projp
78 | 
79 |     # runs parallel for gaussian
80 |     zs,stds,projp = jax.vmap(perf_idx)(prec,weights_log,means)
81 | 
82 |     # alpha is based on distance from all gaussians
83 |     est_alpha = 1-jnp.exp(-jnp.exp(stds).sum(0) )
84 | 
85 |     # points behind camera should be zero
86 |     # BUG: est_alpha should also use this
87 |     sig1 = (zs > 0) # hard step in place of a sigmoid
88 | 
89 |     # compute the algebraic weights in the paper
90 |     w = sig1*jnp.nan_to_num(jax_stable_exp(-zs*beta_2 + beta_3*stds))+1e-20
91 | 
92 |     # normalize weights
93 |     wgt = w.sum(0)
94 |     div = jnp.where(wgt==0,1,wgt)
95 |     w = w/div
96 | 
97 |     # compute weighted z and normal
98 |     init_t = (w*jnp.nan_to_num(zs)).sum(0)
99 |     est_norm = (projp * w[:,:,None]).sum(axis=0)
100 |     est_norm = est_norm/jnp.linalg.norm(est_norm,axis=1,keepdims=True)
101 | 
102 |     # return z, alpha, normal, and the weights
103 |     # weights can be used to compute color, DINO features, or any other per-Gaussian property
104 |     return init_t,est_alpha,est_norm,w
105 | 
106 | # renders image if rotation is in: axis angle rotations n * theta
107 | def render_func_axangle(means, prec_full, weights_log, camera_rays, axangl, t, beta_2, beta_3):
108 |     Rest = axangle_to_rot(axangl)
109 |     camera_rays = camera_rays @ Rest
110 |     trans = jnp.tile(t[None],(camera_rays.shape[0],1))
111 | 
112 |     camera_starts_rays = jnp.stack([camera_rays,trans],1)
113 |     return render_func_rays(means, prec_full, weights_log, camera_starts_rays, beta_2, beta_3,)
114 | 
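# a usage sketch for the rotation-parameterized wrappers below, mirroring
# pose_estimation.ipynb; means/prec/weights_log come from a GMM fit to surface
# samples and camera_rays is an (N,3) array of per-pixel rays:
#   render_jit = jax.jit(render_func_quat)
#   est_depth, est_alpha, est_norm, w = render_jit(
#       means, prec, weights_log, camera_rays, quat, trans, beta2/shape_scale, beta3)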
115 | # renders image if rotation is in: modified Rodrigues parameters n * tan(theta/4)
116 | def render_func_mrp(means, prec_full, weights_log, camera_rays, mrp, t, beta_2, beta_3):
117 |     Rest = mrp_to_rot(mrp)
118 |     camera_rays = camera_rays @ Rest
119 |     trans = jnp.tile(t[None],(camera_rays.shape[0],1))
120 | 
121 |     camera_starts_rays = jnp.stack([camera_rays,trans],1)
122 |     return render_func_rays(means, prec_full, weights_log, camera_starts_rays, beta_2, beta_3)
123 | 
124 | # renders image if rotation is in: quaternions [cos(theta/2), sin(theta/2) * n]
125 | def render_func_quat(means, prec_full, weights_log, camera_rays, quat, t, beta_2, beta_3):
126 |     Rest = quat_to_rot(quat)
127 |     camera_rays = camera_rays @ Rest
128 |     trans = jnp.tile(t[None],(camera_rays.shape[0],1))
129 | 
130 |     camera_starts_rays = jnp.stack([camera_rays,trans],1)
131 |     return render_func_rays(means, prec_full, weights_log, camera_starts_rays, beta_2, beta_3)
132 | 
133 | 
134 | # renders image if rotation is in: quaternions, with pixels mapped to rays by a single-parameter inverse focal length
135 | def render_func_quat_cam(means, prec_full, weights_log, pixel_list, aspect, invF, quat, t, beta_2, beta_3):
136 |     camera_rays = (pixel_list - jnp.array([0.5,0.5,0]))*jnp.array([invF,aspect*invF,1])
137 | 
138 |     Rest = quat_to_rot(quat)
139 |     camera_rays = camera_rays @ Rest
140 |     trans = jnp.tile(t[None],(camera_rays.shape[0],1))
141 | 
142 |     camera_starts_rays = jnp.stack([camera_rays,trans],1)
143 |     return render_func_rays(means, prec_full, weights_log, camera_starts_rays, beta_2, beta_3)
144 | 
145 | # renders batch of rays
146 | # takes pixel coords (pixels, pose)
147 | # takes inverse focal length
148 | # takes full set of poses as (quaternion, translation) pairs
149 | def render_func_idx_quattrans(means, prec_full, weights_log, pixel_posei, invF, poses, beta_2, beta_3):
150 |     rot_mats = jax.vmap(quat_to_rot)(poses[:,:4])
151 |     def rot_ray_t(rayi):
152 |         ray = rayi[:3] * jnp.array([invF,invF,1])
153 |         pose_idx = rayi[3].astype(int)
154 |         return jnp.array([ray@rot_mats[pose_idx],poses[pose_idx][4:]])
155 |     camera_rays_start = jax.vmap(rot_ray_t)(pixel_posei)
156 |     return render_func_rays(means, prec_full, weights_log, camera_rays_start, beta_2, beta_3)
157 | 
158 | # renders a batch of rays, as above, but also computes fwd & backward flow for each pixel.
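# a sketch of the per-pixel flow math implemented below, assuming each row of
# pixel_posei is pix = [x, y, 1, pose_idx] and depth comes from the render:
#   ray      = pix[:3] * [invF, invF, 1]           # pixel -> camera ray
#   world_p  = (ray * depth) @ R1 + t1             # lift with this frame's pose
#   cam2_p   = (world_p - tp1) @ Rp1.T             # move into the next camera
#   flow_fwd = -(cam2_p[:2] / (cam2_p[2] * invF) - pix[:2])
# the backward flow repeats the last two steps with the previous pose (Rm1, tm1)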
158 | def render_func_idx_quattrans_flow(means, prec_full, weights_log, pixel_posei, invF, poses, beta_2, beta_3): 159 | rot_mats = jax.vmap(quat_to_rot)(poses[:,:4]) 160 | def rot_ray_t(rayi): 161 | ray = rayi[:3] * jnp.array([invF,invF,1]) 162 | pose_idx = rayi[3].astype(int) 163 | return jnp.array([ray@rot_mats[pose_idx],poses[pose_idx][4:]]) 164 | camera_rays_start = jax.vmap(rot_ray_t)(pixel_posei) 165 | 166 | # render the pixels 167 | est_depth,est_alpha,est_norm,est_w = render_func_rays(means, prec_full, weights_log, camera_rays_start, beta_2, beta_3) 168 | 169 | # per pixel, compute the flow 170 | def flow_ray_i(rayi,depth): 171 | # find the pose index 172 | pose_idx = rayi[3].astype(int) 173 | pose_idxp1 = jax.lax.min(pose_idx+1,poses.shape[0]-1) 174 | pose_idxm1 = jax.lax.max(pose_idx-1,0) 175 | 176 | # get the pose 177 | R1 = rot_mats[pose_idx] 178 | t1 = poses[pose_idx,4:] 179 | 180 | # get the next pose 181 | Rp1 = rot_mats[pose_idxp1] 182 | tp1 = poses[pose_idxp1,4:] 183 | 184 | # compute this pose 3D 185 | ray = rayi[:3] * jnp.array([invF,invF,1]) 186 | pt_cldc1 = ray * depth 187 | world_p = pt_cldc1 @ R1 + t1 188 | 189 | # transform and project back into next camera 190 | pt_cldc2 = (world_p- tp1) @ Rp1.T 191 | coord1 = pt_cldc2[:2]/(pt_cldc2[2]*invF) 192 | px_coordp = -(coord1 - rayi[:2]) 193 | 194 | # get the previous pose 195 | Rm1 = rot_mats[pose_idxm1] 196 | tm1 = poses[pose_idxm1,4:] 197 | 198 | # transform and project back into previous camera 199 | pt_cldc3 = (world_p - tm1) @ Rm1.T 200 | coord2 = pt_cldc3[:2]/(pt_cldc3[2]*invF) 201 | px_coordm = -(coord2 - rayi[:2]) 202 | 203 | return px_coordp, px_coordm 204 | 205 | # per ray, compute the flow 206 | flowp,flowm = jax.vmap(flow_ray_i)(pixel_posei,est_depth) 207 | 208 | return est_depth,est_alpha,est_norm,est_w,flowp,flowm 209 | 210 | # just a log likelihood function 211 | # in case you want to maximize 3D points 212 | # verify that the above math/structure matches typical GMM log likelihood 213 | # compare with sklearn 214 | def log_likelihood(params, points): 215 | means, prec_full, weights_log = params 216 | prec = jnp.triu(prec_full) 217 | weights = jnp.exp(weights_log) 218 | weights = weights/weights.sum() 219 | 220 | def perf_idx(prcI,w,meansI): 221 | prc = prcI.T 222 | div = jnp.prod(jnp.diag(jnp.abs(prc))) 223 | 224 | def perf_ray(pt): 225 | p = meansI -pt 226 | 227 | pteval = ((prc @ p)**2).sum() 228 | 229 | d0 = pteval+ 3*jnp.log(jnp.pi*2) 230 | d2 = -0.5*d0 + jnp.log(w) 231 | d3 = d2 + jnp.log(div) 232 | 233 | return d3 234 | res = jax.vmap((perf_ray))(points) # jit perf 235 | return res 236 | 237 | res = jax.vmap(perf_idx)(prec,weights,means) # jit perf 238 | 239 | return -jax.scipy.special.logsumexp(res.T, axis=1).ravel().mean(),res# + ent.mean() 240 | -------------------------------------------------------------------------------- /generate_inputs.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os, glob\n", 10 | "import pathlib\n", 11 | "\n", 12 | "import numpy as np\n", 13 | "import skimage\n", 14 | "import skimage.io as sio\n", 15 | "import skimage.transform as strans\n", 16 | "\n", 17 | "import matplotlib.pyplot as plt\n", 18 | "\n", 19 | "import pandas as pd\n", 20 | "import transforms3d" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | 
"source": [ 29 | "# process CO3D data in data2 for a certain class\n", 30 | "data_dir = '.'\n", 31 | "type_f = 'teddybear' # hydrant, plant\n", 32 | "idx_num = 1 # which idx" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "\n", 42 | "viables = [_ for _ in sorted(os.listdir(os.path.join(data_dir,type_f))) if os.path.isdir(os.path.join(data_dir,type_f,_)) and len(os.path.join(data_dir,type_f,_).split('_')) == 3]\n", 43 | "co3d_seq = viables[idx_num]\n", 44 | "output_folder = type_f+'_'+co3d_seq\n", 45 | "co3d_seq_folder = os.path.join(data_dir,type_f,co3d_seq)\n", 46 | "co3d_seq,co3d_seq_folder" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "# Requirements\n", 54 | "\n", 55 | "[unimatch](https://github.com/autonomousvision/unimatch) for generating optical flow\n", 56 | "\n", 57 | "\n", 58 | "[XMem](https://github.com/hkchengrex/XMem) for propogating the first mask\n" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "base_folder = co3d_seq_folder\n", 68 | "input_folder = os.path.join(base_folder,'images')\n", 69 | "output_folder = os.path.join('rvid','{}_{}'.format(type_f,co3d_seq))\n", 70 | "\n", 71 | "target_size = 125000\n", 72 | "gmflow_path = '../unimatch/'\n", 73 | "xmem_path = '../XMem/'\n", 74 | "\n", 75 | "frame1_mask = os.path.join(base_folder,'masks','frame000001.png')\n", 76 | "\n", 77 | "imgs_folder = os.path.join(output_folder,'JPEGImages','video1')\n", 78 | "silh_folder = os.path.join(output_folder,'Annotations','video1')\n", 79 | "flow_folder = os.path.join(output_folder,'Flow','video1')\n", 80 | "\n", 81 | "for gen_folder in [output_folder,imgs_folder,silh_folder,flow_folder]:\n", 82 | " if not os.path.exists(gen_folder):\n", 83 | " #import shutil\n", 84 | " #shutil.rmtree(gen_folder)\n", 85 | " os.makedirs(gen_folder)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "in_files = sorted(glob.glob(os.path.join(input_folder,'*.jpg')) + glob.glob(os.path.join(input_folder,'*.png')))" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "PYo,PXo = sio.imread(in_files[0]).shape[:2]\n", 104 | "init_scale = np.prod([PYo,PXo])\n", 105 | "scales = {}\n", 106 | "for i in range(10):\n", 107 | " scale = 2**i\n", 108 | " scales[scale] = init_scale/(scale**2)\n", 109 | "scale_to_use = sorted([(abs(np.log(v/target_size)),k) for k,v in scales.items() ])[0][1]\n", 110 | "PY,PX = int(round(PYo/scale_to_use)),int(round(PXo/scale_to_use))\n", 111 | "scale_to_use,PY,PX" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "valid_inputs = []\n", 121 | "file_map = {}\n", 122 | "for idx,file in enumerate(in_files):\n", 123 | " name = pathlib.Path(file).parts[-1]\n", 124 | " #if not os.path.exists(os.path.join(imgs_folder,name)):\n", 125 | " img = sio.imread(file)\n", 126 | " valid_inputs.append(img.sum() != 0)\n", 127 | " new_name = 'frame{:06d}.jpg'.format(sum(valid_inputs))\n", 128 | " if valid_inputs[-1] == False:\n", 129 | " continue\n", 130 | " #print(new_name)\n", 131 | " file_map[idx] = sum(valid_inputs)\n", 132 | " simg = strans.resize(img,(PY,PX))\n", 133 | " 
sio.imsave(os.path.join(imgs_folder,new_name),skimage.img_as_ubyte(simg))" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "df = pd.read_json(os.path.join(*(base_folder.split('/')[:-1] + ['frame_annotations.jgz'])),compression={'method':'gzip'})\n", 143 | "df2 = df[df.sequence_name == int(co3d_seq.replace('_',''))]\n", 144 | "fls = []\n", 145 | "pps = []\n", 146 | "sizes = []\n", 147 | "assert(len(df2) == len(valid_inputs))\n", 148 | "for i,row in enumerate(df2.sort_values('frame_number').itertuples()):\n", 149 | " fn, imgd, maskd, view = row[2],row[4],row[6],row[7]\n", 150 | " if not valid_inputs[i]:\n", 151 | " continue\n", 152 | " fl = np.array(view['focal_length'])\n", 153 | " pp = np.array(view['principal_point'])\n", 154 | " sizeA = list(row[4]['size'])\n", 155 | "\n", 156 | " if 'intrinsics_format' in view and view['intrinsics_format'] == 'ndc_isotropic':\n", 157 | " half_image_size_wh_orig = np.array(list(reversed(sizeA))) / 2.0\n", 158 | " rescale = half_image_size_wh_orig.min()\n", 159 | " # principal point and focal length in pixels\n", 160 | " principal_point_px = half_image_size_wh_orig - pp * rescale\n", 161 | " focal_length_px = fl * rescale\n", 162 | " else:\n", 163 | " half_image_size_wh_orig = np.array(list(reversed(sizeA))) / 2.0\n", 164 | " # principal point and focal length in pixels\n", 165 | " principal_point_px = (\n", 166 | " -1.0 * (pp - 1.0) * half_image_size_wh_orig\n", 167 | " )\n", 168 | " focal_length_px = fl * half_image_size_wh_orig\n", 169 | "\n", 170 | " fls.append(focal_length_px)\n", 171 | " pps.append(principal_point_px)\n", 172 | "\n", 173 | " sizes.append(sizeA)\n", 174 | "assert(np.array(sizes).std(0).sum() == 0) # same sizes\n", 175 | "pp = np.array(pps).mean(0)\n", 176 | "fl = np.array(fls).mean(0).mean()\n", 177 | "meanpp = (np.array([pp[1],pp[0]])/np.array(sizes).mean(0)).mean() \n", 178 | "assert(abs(meanpp - 0.5) < 1e-3) # basically center of frame\n", 179 | "fl = fl/scale_to_use" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "len(valid_inputs),df2.shape,len(in_files)" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "poses = []\n", 198 | "depths = []\n", 199 | "import skimage.io as sio\n", 200 | "import skimage.transform as sktrans\n", 201 | "\n", 202 | "for i,row in enumerate(df2.sort_values('frame_number').itertuples()):\n", 203 | " fn, imgd, maskd, view = row[2],row[4],row[6],row[7]\n", 204 | " depthd = row[5]\n", 205 | " if not valid_inputs[i]:\n", 206 | " continue\n", 207 | " maskd = maskd['path'][maskd['path'].index(co3d_seq):]\n", 208 | " imgd = imgd['path'][imgd['path'].index(co3d_seq):]\n", 209 | " \n", 210 | " Rmat = np.array(view['R'])\n", 211 | " Tvec = np.array(view['T'])\n", 212 | " Tvec = -Rmat @ Tvec\n", 213 | " q = transforms3d.quaternions.mat2quat(Rmat.T)\n", 214 | " poses.append(list(q) + list(Tvec))\n", 215 | " \n", 216 | " depth_r = sio.imread(os.path.join(data_dir,type_f,depthd['path'][depthd['path'].index(co3d_seq):]))#.astype(float)\n", 217 | " depth_m = sio.imread(os.path.join(data_dir,type_f,depthd['mask_path'][depthd['mask_path'].index(co3d_seq):])).astype(float)\n", 218 | " \n", 219 | " depth_r_s = depth_r.shape\n", 220 | " depth_r = 
depthd['scale_adjustment']*np.frombuffer(depth_r,dtype=np.float16).astype(np.float32).reshape(depth_r_s)\n", 221 | "\n", 222 | " valid_d = (depth_r > 0)\n", 223 | " depth_r[~valid_d] = np.nan\n", 224 | " depth_r = sktrans.resize(depth_r,(PY,PX),anti_aliasing=False,order=0)\n", 225 | " depths.append(depth_r)" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": null, 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [ 234 | "import pickle\n", 235 | "import gzip\n", 236 | "import trimesh\n", 237 | "with gzip.open(os.path.join(output_folder,'pose_depth.pkl.gz'), \"wb\") as f:\n", 238 | " out_dict = {'fl':fl,'poses':poses,'depths':depths}\n", 239 | " ply_path = os.path.join(co3d_seq_folder,'pointcloud.ply')\n", 240 | " if os.path.exists(ply_path):\n", 241 | " out_dict['mesh'] = trimesh.load(ply_path)\n", 242 | " pickle.dump(out_dict, f)" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": null, 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [ 251 | "gm_flow_cmd = \"\"\"CUDA_VISIBLE_DEVICES=\"0\" python {} \\\n", 252 | "--inference_dir {} \\\n", 253 | "--output_path {} \\\n", 254 | "--pred_bidir_flow \\\n", 255 | "--save_flo_flow \\\n", 256 | "--resume {} {}\n", 257 | "\"\"\"\n", 258 | "alt_cmd = '--inference_size {} {}'.format(PX*2,PY*2) if target_size < 1e5 else ''\n", 259 | "gm_flow_cmd_f = gm_flow_cmd.format(os.path.join(gmflow_path,'main_flow.py'),imgs_folder,flow_folder,os.path.join(gmflow_path,'pretrained','gmflow-scale1-mixdata-train320x576-4c3a6e9a.pth'),alt_cmd)\n", 260 | "print(gm_flow_cmd_f)" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": null, 266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [ 269 | "if len(glob.glob(os.path.join(flow_folder,'*.flo'))) != (len(in_files)*2-2):\n", 270 | " os.system(gm_flow_cmd_f)" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": null, 276 | "metadata": {}, 277 | "outputs": [], 278 | "source": [ 279 | "imgM = sio.imread(frame1_mask)\n", 280 | "simgM = skimage.img_as_ubyte(strans.resize(imgM,(PY,PX)) >0.5)\n", 281 | "sio.imsave(os.path.join(silh_folder,pathlib.Path(in_files[0]).parts[-1].replace('.jpg','.png')),simgM)" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": null, 287 | "metadata": {}, 288 | "outputs": [], 289 | "source": [ 290 | "xmem_output = os.path.join(output_folder,'masks')\n", 291 | "xmem_cmd = 'python {} --model {} --dataset G --generic_path {} --output {}'.format(os.path.join(xmem_path,'eval.py'),os.path.join(xmem_path,'saves','XMem.pth'),output_folder,xmem_output)" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": null, 297 | "metadata": {}, 298 | "outputs": [], 299 | "source": [ 300 | "os.system(xmem_cmd)" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": null, 306 | "metadata": {}, 307 | "outputs": [], 308 | "source": [ 309 | "xmem_cmd" 310 | ] 311 | } 312 | ], 313 | "metadata": { 314 | "language_info": { 315 | "name": "python" 316 | } 317 | }, 318 | "nbformat": 4, 319 | "nbformat_minor": 5 320 | } 321 | -------------------------------------------------------------------------------- /get_co3d.sh: -------------------------------------------------------------------------------- 1 | wget -N https://dl.fbaipublicfiles.com/co3dv2_231130/teddybear_000_singlesequence.zip 2 | wget -N https://dl.fbaipublicfiles.com/co3dv2_231130/teddybear_001_singlesequence.zip 3 | unzip teddybear_000_singlesequence.zip 4 | unzip 
teddybear_001_singlesequence.zip 5 | -------------------------------------------------------------------------------- /images_to_video.sh: -------------------------------------------------------------------------------- 1 | /usr/local/bin/ffmpeg -framerate 24 -i $1/%03d.jpg -vf "pad=ceil(iw/2)*2:ceil(ih/2)*2" -c:v h264 -pix_fmt yuv420p $2 2 | -------------------------------------------------------------------------------- /pose_estimation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "%matplotlib inline" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": { 17 | "colab_type": "text", 18 | "id": "qkX7DiM6rmeM" 19 | }, 20 | "source": [ 21 | "# Pose Estimation\n", 22 | "Compare to PyTorch3D `Camera position optimization` sample. " 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "import os\n", 32 | "import sys\n", 33 | "import pickle\n", 34 | "import glob\n", 35 | "\n", 36 | "import numpy as np\n", 37 | "import matplotlib.pyplot as plt" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "from tqdm.notebook import tqdm" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "import trimesh\n", 56 | "import pyrender\n", 57 | "import transforms3d\n", 58 | "\n", 59 | "from tqdm.notebook import tqdm\n", 60 | "class QuasiRandom():\n", 61 | " def __init__(self,dim=1,seed=None):\n", 62 | " self.dim = dim\n", 63 | " self.x = np.random.rand(dim) if seed is None else seed\n", 64 | " root_sys = [1] +[0 for i in range(dim-1)] + [-1,-1]\n", 65 | " self.const = sorted(np.roots(root_sys))[-1].real\n", 66 | " self.phi = np.array([1/(self.const)**(i+1) for i in range(dim)])\n", 67 | " def generate(self,n_points=1):\n", 68 | " res = np.zeros((n_points,self.dim))\n", 69 | " for i in range(n_points):\n", 70 | " res[i] = self.x = (self.x+self.phi)\n", 71 | " return np.squeeze(res%1)\n", 72 | " \n", 73 | "mesh_file = 'data/cow.obj'\n", 74 | "\n", 75 | "mesh_tri = trimesh.load(mesh_file)\n", 76 | "\n", 77 | "# seems sane to fetch/estimate scale\n", 78 | "shape_scale = float(mesh_tri.vertices.std(0).mean())*3\n", 79 | "center = np.array(mesh_tri.vertices.mean(0))\n", 80 | "t_model_scale = np.ptp(mesh_tri.vertices,0).mean()\n", 81 | "\n", 82 | "print('model is {:.2f}x the size of the cow'.format(shape_scale/1.18))" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "image_size = (64,64)\n", 92 | "vfov_degrees = 45\n", 93 | "focal_length = 0.5*image_size[0]/np.tan((np.pi/180.0)*vfov_degrees/2)\n", 94 | "cx = (image_size[1]-1)/2\n", 95 | "cy = (image_size[0]-1)/2\n", 96 | "rand_quat = QuasiRandom(dim=4,seed=0).generate(1)\n", 97 | "rand_quat = rand_quat/np.linalg.norm(rand_quat)\n", 98 | "\n", 99 | "mesh = pyrender.Mesh.from_trimesh(mesh_tri)\n", 100 | "\n", 101 | "scene = pyrender.Scene()\n", 102 | "scene.add(mesh)\n", 103 | "\n", 104 | "\n", 105 | "R = transforms3d.quaternions.quat2mat(rand_quat)\n", 106 | "loc = np.array([0,0,3*shape_scale]) @ R + center\n", 107 | "pose = 
np.vstack([np.vstack([R,loc]).T,np.array([0,0,0,1])])\n", 108 | "\n", 109 | "light = pyrender.SpotLight(color=np.ones(3), intensity=50.0,\n", 110 | " innerConeAngle=np.pi/16.0,\n", 111 | " outerConeAngle=np.pi/6.0)\n", 112 | "scene.add(light, pose=pose)\n", 113 | "\n", 114 | "camera = pyrender.IntrinsicsCamera(focal_length,focal_length,cx,cy,znear=0.1*shape_scale,zfar=100*shape_scale)\n", 115 | "scene.add(camera,pose=pose)\n", 116 | "\n", 117 | "r = pyrender.OffscreenRenderer(image_size[1],image_size[0])\n", 118 | "color, target_depth = r.render(scene)\n", 119 | "target_depth[target_depth ==0] = np.nan\n", 120 | "\n", 121 | "plt.subplot(1,2,1)\n", 122 | "plt.imshow(color)\n", 123 | "plt.subplot(1,2,2)\n", 124 | "plt.imshow(target_depth)\n", 125 | "plt.tight_layout()" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "metadata": {}, 131 | "source": [ 132 | "# Setup Fuzzy Metaball renderer" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "import os\n", 142 | "os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"\n", 143 | "\n", 144 | "import jax\n", 145 | "#jax.config.update('jax_platform_name', 'cpu')\n", 146 | "\n", 147 | "import jax.numpy as jnp\n", 148 | "import fm_render\n", 149 | "\n", 150 | "# volume usually False since color optimization implies surface samples\n", 151 | "# And code defaults towards that sort of usage now\n", 152 | "show_volume = False\n", 153 | "\n", 154 | "NUM_MIXTURE = 40\n", 155 | "beta2 = 21.4\n", 156 | "beta3 = 2.66\n", 157 | "\n", 158 | "gmm_init_scale = 80\n", 159 | "\n", 160 | "render_jit = jax.jit(fm_render.render_func_quat)" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "\n", 170 | "import trimesh\n", 171 | "import sklearn.mixture\n", 172 | "if show_volume:\n", 173 | " pts = trimesh.sample.volume_mesh(mesh_tri,10000)\n", 174 | "else:\n", 175 | " pts = trimesh.sample.sample_surface_even(mesh_tri,10000)[0]\n", 176 | "gmm = sklearn.mixture.GaussianMixture(NUM_MIXTURE)\n", 177 | "gmm.fit(pts)\n", 178 | "weights_log = np.log( gmm.weights_) + np.log(gmm_init_scale)\n", 179 | "mean = gmm.means_\n", 180 | "prec = gmm.precisions_cholesky_\n" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "\n", 190 | "height, width = image_size\n", 191 | "K = np.array([[focal_length, 0, cx],[0,focal_length,cy],[0,0,1]])\n", 192 | "pixel_list = (np.array(np.meshgrid(np.arange(width),height-np.arange(height)-1,[0]))[:,:,:,0]).reshape((3,-1)).T\n", 193 | "camera_rays = (pixel_list - K[:,2])/np.diag(K)\n", 194 | "camera_rays[:,-1] = -1\n", 195 | "\n", 196 | "trans_true = loc\n", 197 | "quat_true = rand_quat" 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "metadata": {}, 203 | "source": [ 204 | "# Add noise to pose" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "while True:\n", 214 | " t_err_cap = 0.5\n", 215 | " rad_eps = (np.pi/180.0)*90 # range. 
so 90 is -45 to +45\n", 216 | "\n", 217 | " t_err_vec = np.random.randn(3)\n", 218 | " t_err_vec = t_err_vec/np.linalg.norm(t_err_vec)\n", 219 | " t_err_mag = np.random.rand()\n", 220 | "\n", 221 | " trans_offset = t_err_cap*t_err_mag*t_err_vec*t_model_scale\n", 222 | " trans_shift = trans_true - trans_offset\n", 223 | "\n", 224 | " angles = np.random.randn(3)\n", 225 | " angles = angles/np.linalg.norm(angles)\n", 226 | " angle_mag = (np.random.rand()-0.5)*rad_eps\n", 227 | " R_I = transforms3d.quaternions.quat2mat(quat_true).T\n", 228 | " R_R = transforms3d.axangles.axangle2mat(angles,angle_mag)\n", 229 | " R_C = R_R @ R_I\n", 230 | "\n", 231 | " quat_init = transforms3d.quaternions.mat2quat(R_C.T)\n", 232 | " trans_init = R_R@trans_shift\n", 233 | "\n", 234 | " rand_rot = abs(angle_mag*(180.0/np.pi))\n", 235 | " rand_trans = 100*(t_err_mag*t_err_cap)\n", 236 | " init_pose_err = np.sqrt(rand_rot*rand_trans)\n", 237 | " if rand_trans >30 and rand_rot >30:\n", 238 | " print('pose error of {:.1f}, random rotation of {:.1f} degrees and translation of {:.1f}%'.format(init_pose_err,rand_rot,rand_trans))\n", 239 | " break\n", 240 | "#axangl_init = axangl_true.copy()\n", 241 | "#trans_init = trans_true.copy()" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": null, 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [ 250 | "def compute_normals(camera_rays, depth_py_px,image_size):\n", 251 | " nan_depth = depth_py_px.ravel()\n", 252 | " PY,PX=image_size\n", 253 | " #nan_depth = jnp.nan_to_num(depth_py_px.ravel(),nan=1e-9)\n", 254 | "\n", 255 | " dpt = jnp.array( camera_rays.reshape((-1,3)) * nan_depth[:,None] )\n", 256 | " dpt = dpt.reshape((PY,PX,3))\n", 257 | " ydiff = dpt - jnp.roll(dpt,1,0)\n", 258 | " xdiff = dpt - jnp.roll(dpt,1,1)\n", 259 | " ydiff = jnp.nan_to_num(ydiff,nan=1e-9) # new\n", 260 | " xdiff = jnp.nan_to_num(xdiff,nan=1e-9) # new \n", 261 | "\n", 262 | " ddiff = jnp.cross(xdiff.reshape((-1,3)),ydiff.reshape((-1,3)),)\n", 263 | " nan_ddiff = jnp.nan_to_num(ddiff,nan=0)\n", 264 | " norms = nan_ddiff/(1e-20+jnp.linalg.norm(nan_ddiff,axis=1,keepdims=True))\n", 265 | "\n", 266 | " return norms\n" 267 | ] 268 | }, 269 | { 270 | "cell_type": "markdown", 271 | "metadata": {}, 272 | "source": [ 273 | "# Solve for camera pose" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": null, 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "def error_func(est_depth,est_alpha,true_depth):\n", 283 | " cond = jnp.isnan(est_depth) | jnp.isnan(true_depth)\n", 284 | " #err = (est_depth - true_depth)/jnp.nan_to_num(true_depth,nan=1)\n", 285 | " err = (est_depth - true_depth)/jnp.nanmean(true_depth)\n", 286 | "\n", 287 | " depth_loss = abs(jnp.where(cond,0,err)).mean()\n", 288 | "\n", 289 | " true_alpha = ~jnp.isnan(true_depth)\n", 290 | " est_alpha = jnp.clip(est_alpha,1e-7,1-1e-7)\n", 291 | " mask_loss = -((true_alpha * jnp.log(est_alpha)) + (~true_alpha)*jnp.log(1-est_alpha))\n", 292 | "\n", 293 | " term1 = depth_loss.mean()\n", 294 | " term2 = mask_loss.mean()\n", 295 | " return 50*term1 + term2\n", 296 | "\n", 297 | "def objective(params,means,prec,weights_log,camera_rays,beta2,beta3,depth):\n", 298 | " mrp,trans= params\n", 299 | " render_res = render_jit(means,prec,weights_log,camera_rays,mrp,trans,beta2,beta3)\n", 300 | " return error_func(render_res[0],render_res[1],depth)\n", 301 | "\n", 302 | "def objective_simple(params,means,prec,weights_log,camera_rays,beta2,beta3,depth):\n", 303 | " mrp = 
jnp.array(params[:3])\n", 304 | " trans = jnp.array(params[3:])\n", 305 | " render_res = render_jit(means,prec,weights_log,camera_rays,mrp,trans,beta2,beta3)\n", 306 | " return error_func(render_res[0],render_res[1],depth)\n", 307 | "grad_render3 = jax.jit(jax.value_and_grad(objective))" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": null, 313 | "metadata": {}, 314 | "outputs": [], 315 | "source": [ 316 | "from jax.example_libraries import optimizers\n", 317 | "from util import DegradeLR\n", 318 | "# Number of optimization steps\n", 319 | "# typically only needs a few hundred\n", 320 | "# and early exits\n", 321 | "Niter = 2000\n", 322 | "\n", 323 | "loop = tqdm(range(Niter))\n", 324 | "\n", 325 | "# babysit learning rates\n", 326 | "adjust_lr = DegradeLR(3e-4,0.1,50,10,-1e-4)\n", 327 | "opt_init, opt_update, opt_params = optimizers.momentum(adjust_lr.step_func,0.95)\n", 328 | "\n", 329 | "# to test scale invariance\n", 330 | "HUHSCALE = 1\n", 331 | "# should get same result even if world scale changes\n", 332 | "\n", 333 | "tmp = [quat_init,HUHSCALE*trans_init]\n", 334 | "opt_state = opt_init(tmp)\n", 335 | "\n", 336 | "losses = []\n", 337 | "jax_tdepth = jnp.array(target_depth.ravel())\n", 338 | "\n", 339 | "for i in loop:\n", 340 | " p = opt_params(opt_state)\n", 341 | "\n", 342 | " val,g = grad_render3(p,HUHSCALE*mean,prec/HUHSCALE,weights_log,camera_rays,beta2/(HUHSCALE*shape_scale),beta3,HUHSCALE*jax_tdepth)\n", 343 | " \n", 344 | " S = jnp.linalg.norm(p[1])\n", 345 | " S2 = S*S\n", 346 | "\n", 347 | " g1 = g[0]\n", 348 | " g2 = g[1]*S2\n", 349 | "\n", 350 | " opt_state = opt_update(i, [g1,g2], opt_state)\n", 351 | "\n", 352 | " val = float(val)\n", 353 | " losses.append(val)\n", 354 | " if adjust_lr.add(val):\n", 355 | " break\n", 356 | " # Print the losses\n", 357 | " loop.set_description(\"total_loss = %.3f\" % val)" 358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": null, 363 | "metadata": {}, 364 | "outputs": [], 365 | "source": [ 366 | "mrp_final, trans_final = opt_params(opt_state)\n", 367 | "trans_final = trans_final/HUHSCALE" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": null, 373 | "metadata": {}, 374 | "outputs": [], 375 | "source": [ 376 | "# 2nd order is also possible\n", 377 | "if False:\n", 378 | " from jax.scipy.optimize import minimize\n", 379 | " res = minimize(objective_simple,jnp.hstack([mrp_init,trans_init]),method='BFGS',args=(mean,prec,weight_log,camera_rays,beta2,beta3,beta4,beta5,jax_tdepth,))\n", 380 | " mrp_final = res.x[:3]\n", 381 | " trans_final = res.x[3:]" 382 | ] 383 | }, 384 | { 385 | "cell_type": "code", 386 | "execution_count": null, 387 | "metadata": {}, 388 | "outputs": [], 389 | "source": [ 390 | "plt.title('convergence plot')\n", 391 | "plt.plot(losses,marker='.',lw=0,ms=5,alpha=0.5)\n", 392 | "plt.xlabel('iteration')\n", 393 | "plt.ylabel('log loss')" 394 | ] 395 | }, 396 | { 397 | "cell_type": "code", 398 | "execution_count": null, 399 | "metadata": {}, 400 | "outputs": [], 401 | "source": [ 402 | "quat_final, trans_final = opt_params(opt_state)\n", 403 | "trans_final = trans_final/HUHSCALE" 404 | ] 405 | }, 406 | { 407 | "cell_type": "markdown", 408 | "metadata": {}, 409 | "source": [ 410 | "# Visualize Results" 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "execution_count": null, 416 | "metadata": {}, 417 | "outputs": [], 418 | "source": [ 419 | "vmin,vmax = np.nanmin(target_depth),np.nanmax(target_depth)" 420 | ] 421 | }, 422 | { 423 | 
"cell_type": "code", 424 | "execution_count": null, 425 | "metadata": {}, 426 | "outputs": [], 427 | "source": [ 428 | "plt.subplot(2,3,1)\n", 429 | "plt.imshow(color)\n", 430 | "plt.title('image')\n", 431 | "plt.axis('off')\n", 432 | "plt.subplot(2,3,4)\n", 433 | "est_depth_true, est_alpha_true, _, _ = render_jit(mean,prec,weights_log,camera_rays,quat_true,trans_true,beta2/shape_scale,beta3)\n", 434 | "est_depth_true = np.array(est_depth_true)\n", 435 | "est_depth_true[est_alpha_true < 0.5] = np.nan\n", 436 | "plt.imshow(est_depth_true.reshape(image_size),vmin=vmin,vmax=vmax)\n", 437 | "plt.title('true pose')\n", 438 | "plt.axis('off')\n", 439 | "est_depth_init, est_alpha, _, _ = render_jit(mean,prec,weights_log,camera_rays,quat_init,trans_init,beta2/shape_scale,beta3)\n", 440 | "est_depth_init = np.array(est_depth_init)\n", 441 | "est_depth_init[est_alpha < 0.5] = np.nan\n", 442 | "plt.subplot(2,3,2)\n", 443 | "plt.imshow(est_alpha.reshape(image_size),cmap='Greys')\n", 444 | "plt.title('init FM alpha')\n", 445 | "plt.axis('off')\n", 446 | "plt.subplot(2,3,5)\n", 447 | "plt.imshow(est_depth_init.reshape(image_size),vmin=vmin,vmax=vmax)\n", 448 | "plt.title('init FM depth')\n", 449 | "plt.axis('off')\n", 450 | "est_depth, est_alpha, _, _ = render_jit(mean,prec,weights_log,camera_rays,quat_final,trans_final,beta2/shape_scale,beta3)\n", 451 | "est_depth = np.array(est_depth)\n", 452 | "est_depth[est_alpha < 0.5] = np.nan\n", 453 | "plt.subplot(2,3,3)\n", 454 | "plt.imshow(est_alpha.reshape(image_size),cmap='Greys')\n", 455 | "plt.title('final FM alpha')\n", 456 | "plt.axis('off')\n", 457 | "plt.subplot(2,3,6)\n", 458 | "plt.imshow(est_depth.reshape(image_size),vmin=vmin,vmax=vmax)\n", 459 | "plt.title('final FM depth')\n", 460 | "plt.axis('off')\n", 461 | "plt.tight_layout()" 462 | ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": null, 467 | "metadata": {}, 468 | "outputs": [], 469 | "source": [ 470 | "q1 = quat_true/np.linalg.norm(quat_true)\n", 471 | "q2 = quat_final/np.linalg.norm(quat_final)\n", 472 | "e1 = np.arccos(np.clip((q1 * q2).sum(),-1,1))\n", 473 | "e2 = np.arccos(np.clip((-q1 * q2).sum(),-1,1))\n", 474 | "rot_err = float((180.0/np.pi)*2*min(e1,e2))\n", 475 | "\n", 476 | "R1 = np.array(transforms3d.quaternions.quat2mat(q1))\n", 477 | "R2 = np.array(transforms3d.quaternions.quat2mat(q2))\n", 478 | "t_norm = np.linalg.norm(R1.T@np.array(trans_true)-R2.T@np.array(trans_final))\n", 479 | "trans_err = 100*t_norm/t_model_scale\n", 480 | "\n", 481 | "pose_err = np.sqrt(rot_err*trans_err)\n", 482 | "print('init. pose error of {:04.1f} with rot. of {:04.1f} deg and trans. of {:04.1f}%'.format(init_pose_err,rand_rot,rand_trans))\n", 483 | "print('final pose error of {:04.1f} with rot. of {:04.1f} deg and trans. 
of {:04.1f}%'.format(pose_err,rot_err,trans_err))" 484 | ] 485 | } 486 | ], 487 | "metadata": { 488 | "language_info": { 489 | "name": "python" 490 | } 491 | }, 492 | "nbformat": 4, 493 | "nbformat_minor": 1 494 | } 495 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | matplotlib 3 | scipy 4 | jax 5 | tqdm 6 | Pillow 7 | scikit-image 8 | pyrender 9 | transforms3d 10 | optax 11 | open3d -------------------------------------------------------------------------------- /run_co3d_sp-zpfm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Run CO3D Sequence (zero parameter)" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import sys, os, glob\n", 17 | "import numpy as np\n", 18 | "import pandas as pd\n", 19 | "from utils_opt import readFlow\n", 20 | "\n", 21 | "import skimage.io as sio\n", 22 | "import matplotlib.pyplot as plt" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "dataset_dir = 'rvid/teddybear_34_1479_4753/'\n", 32 | "co3d_seq = os.path.split(dataset_dir.rstrip('/').lstrip('/'))[-1]\n", 33 | "output_folder = os.path.join('tmp_out_zpfm',co3d_seq)\n", 34 | "NUM_MIXTURE = 40\n", 35 | "shape_scale = 1.8\n", 36 | "c_scale = 4.5\n", 37 | "f_scale = 210\n", 38 | "rand_sphere_size = 55\n", 39 | "cov_scale = 1.2e-2\n", 40 | "weight_scale = 1.1\n", 41 | "LR_RATE = 0.08\n", 42 | "Nepoch = 10\n", 43 | "batch_size = 50000" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "## Load Data" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "import gzip\n", 60 | "import pickle\n", 61 | "\n", 62 | "\n", 63 | "with gzip.open(os.path.join(dataset_dir,'pose_depth.pkl.gz'),'rb') as fp:\n", 64 | " depth_and_pose = pickle.load(fp)\n", 65 | " \n", 66 | "true_depths = depth_and_pose['depths']\n", 67 | "fl = depth_and_pose['fl']\n", 68 | "poses = np.array(depth_and_pose['poses'])" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "masks_folder = os.path.join(dataset_dir,'masks','video1')\n", 78 | "in_files = sorted(glob.glob(masks_folder + '/*.png'))\n", 79 | "\n", 80 | "masks = []\n", 81 | "for img_loc in in_files:\n", 82 | " mask = sio.imread(img_loc).astype(np.float32)\n", 83 | " mask = (mask > 0).astype(np.float32)\n", 84 | " masks.append(mask)\n", 85 | "masks = np.array(masks)\n", 86 | "PY,PX = mask.shape\n", 87 | "image_size = (PY,PX)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "masks_folder = os.path.join(dataset_dir,'JPEGImages','video1')\n", 97 | "in_files = sorted(glob.glob(masks_folder + '/*.jpg'))\n", 98 | "\n", 99 | "images = []\n", 100 | "for img_loc in in_files:\n", 101 | " img = sio.imread(img_loc).astype(np.float32)\n", 102 | " images.append(img)\n", 103 | "images = np.array(images)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | 
"source": [ 112 | "fwd_flows = []\n", 113 | "bwd_flows = []\n", 114 | "\n", 115 | "flow_fol = os.path.join(dataset_dir,'Flow','video1','*.flo')\n", 116 | "\n", 117 | "flow_files = sorted(glob.glob(flow_fol))\n", 118 | "\n", 119 | "for flfile in flow_files:\n", 120 | " new_flow = readFlow(flfile)\n", 121 | " if PY > PX:\n", 122 | " new_flow = np.stack([new_flow[:,:,1],new_flow[:,:,0]],axis=2)\n", 123 | " if 'bwd' in flfile:\n", 124 | " bwd_flows.append(new_flow)\n", 125 | " else:\n", 126 | " fwd_flows.append(new_flow)\n", 127 | "\n", 128 | "\n", 129 | "# last flow has no fowards\n", 130 | "fwd_flows = fwd_flows + [new_flow*0]\n", 131 | "# first flow has no backwards\n", 132 | "bwd_flows = [new_flow*0] + bwd_flows" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "if 'mesh' in depth_and_pose:\n", 142 | " pt_cld = depth_and_pose['mesh'].vertices\n", 143 | " import sklearn.mixture as mixture\n", 144 | "\n", 145 | " idx2 = np.arange(pt_cld.shape[0])\n", 146 | " np.random.shuffle(idx2)\n", 147 | " clf = mixture.GaussianMixture(40)\n", 148 | " clf.fit(pt_cld[idx2[:10000]])\n", 149 | "\n", 150 | " pt_cld_shape_scale = float(pt_cld.std(0).mean())*3 \n", 151 | " center = pt_cld.mean(0)\n", 152 | "else:\n", 153 | " pt_cld_shape_scale = 3.0\n", 154 | " center = np.zeros(3,dtype=np.float32)" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "SCALE_MUL_FACTOR = shape_scale/pt_cld_shape_scale\n", 164 | "# gradients can be sensitive to scale. here we solve on scale = 2.2" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "import os\n", 174 | "#os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"\n", 175 | "\n", 176 | "import jax\n", 177 | "#jax.config.update('jax_platform_name', 'cpu')\n", 178 | "\n", 179 | "import jax.numpy as jnp\n", 180 | "import zpfm_render\n", 181 | "\n", 182 | "render_jit = jax.jit(zpfm_render.render_func_idx_quattrans)\n" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "img_shape = (PY,PX)\n", 192 | "min_size_idx = np.argmin(img_shape)\n", 193 | "min_size = img_shape[min_size_idx]\n", 194 | "max_size = img_shape[1-min_size_idx]\n", 195 | "invF = 0.5*min_size/fl\n", 196 | "min_dim = np.linspace(-1,1,min_size)\n", 197 | "aspect = max_size/min_size\n", 198 | "max_dim = np.linspace(-aspect,aspect,max_size)\n", 199 | "grid = [-max_dim,-min_dim,1,0] if min_size_idx == 0 else [-min_dim,-max_dim,1,0]\n", 200 | "pixel_list = np.transpose(np.squeeze(np.meshgrid(*grid,indexing='ij')),(2,1,0))\n", 201 | "\n", 202 | "print(pixel_list.shape,img_shape)\n", 203 | "pixel_list = pixel_list.reshape((-1,4))" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "poses = jnp.array(poses)" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "from util import image_grid" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": null, 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [ 230 | "# random init settings\n", 231 | "rand_mean = 
center+pt_cld_shape_scale*np.random.multivariate_normal(mean=[0,0,0],cov=cov_scale*np.identity(3),size=NUM_MIXTURE)\n", 232 | "rand_weight_log = jnp.log(weight_scale*np.ones(NUM_MIXTURE)/NUM_MIXTURE)\n", 233 | "rand_prec = jnp.array([np.identity(3)*rand_sphere_size/pt_cld_shape_scale for _ in range(NUM_MIXTURE)])\n", 234 | "rand_color = jnp.array(np.random.randn(NUM_MIXTURE,3))\n", 235 | "\n", 236 | "init_alphas = []\n", 237 | "for i in range(len(poses)):\n", 238 | " pixel_list[:,3] = i\n", 239 | " res_img,est_alpha,_,_ = render_jit(rand_mean,rand_prec,rand_weight_log,pixel_list,invF,poses)\n", 240 | "\n", 241 | " res_imgA = np.array(res_img)\n", 242 | " res_imgA[est_alpha < 0.5] = np.nan\n", 243 | " init_alphas.append(est_alpha.reshape((PY,PX)))\n", 244 | "image_grid(init_alphas,6,6,rgb=False)" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": null, 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [ 253 | "total_ray_set = []\n", 254 | "for i in range(len(poses)):\n", 255 | " pixel_list[:,3] = i\n", 256 | "\n", 257 | " total_ray_set.append(pixel_list.copy())" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "all_rays = jnp.vstack(total_ray_set)\n", 267 | "# scaled into ray coord space, vectorized flows\n", 268 | "fwv_flow = jnp.array(np.array(fwd_flows).reshape((-1,2)))/(min_size/2)\n", 269 | "bwv_flow = jnp.array(np.array(bwd_flows).reshape((-1,2)))/(min_size/2)" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": null, 275 | "metadata": {}, 276 | "outputs": [], 277 | "source": [ 278 | "render_jit_ray = jax.jit(zpfm_render.render_func_rays)\n", 279 | "last_img_size = np.prod(img_shape)\n", 280 | "v_idx = 40" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": null, 286 | "metadata": {}, 287 | "outputs": [], 288 | "source": [ 289 | "def objective(params,camera_rays,invF,poses,true_alpha,true_fwd,true_bwd,true_color):\n", 290 | " CLIP_ALPHA = 1e-6\n", 291 | " means,prec,weights_log,colors = params\n", 292 | " est_depth, est_alpha, est_norm, est_w,flowp,flowm = zpfm_render.render_func_idx_quattrans_flow(means,prec,weights_log,camera_rays,invF,poses)\n", 293 | " est_w = est_w.T\n", 294 | " est_w = est_w/jnp.maximum(est_w.sum(axis=1,keepdims=True),1e-7)\n", 295 | " \n", 296 | " est_color = est_w @ (jnp.tanh(colors)*0.5+0.5)\n", 297 | " est_alpha = jnp.clip(est_alpha,CLIP_ALPHA,1-CLIP_ALPHA)\n", 298 | " mask_loss = - ((true_alpha * jnp.log(est_alpha)) + (1-true_alpha)*jnp.log(1-est_alpha))\n", 299 | " pad_alpha = (true_alpha)[:,None]\n", 300 | " flow1 = jnp.abs(pad_alpha*true_fwd-pad_alpha*jnp.nan_to_num(flowp))\n", 301 | " flow2 = jnp.abs(pad_alpha*true_bwd-pad_alpha*jnp.nan_to_num(flowm))\n", 302 | " cdiff = jnp.abs( (true_color-est_color)*true_alpha[:,None] )\n", 303 | " return mask_loss.mean() + c_scale*cdiff.mean() + f_scale*(flow1.mean() + flow2.mean()) #+ 0\n", 304 | "grad_render3 = jax.value_and_grad(objective)\n" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": null, 310 | "metadata": {}, 311 | "outputs": [], 312 | "source": [ 313 | "SCALE_MUL_FACTOR" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": null, 319 | "metadata": {}, 320 | "outputs": [], 321 | "source": [ 322 | "import optax\n", 323 | "from tqdm.notebook import tqdm\n", 324 | "from util import DegradeLR\n", 325 | "\n", 326 | "vecM = 
jnp.array([1,1,1,1,SCALE_MUL_FACTOR,SCALE_MUL_FACTOR,SCALE_MUL_FACTOR])[None]\n", 327 | "\n", 328 | "train_size = all_rays.shape[0]\n", 329 | "Niter_epoch = int(round(train_size/batch_size))\n", 330 | "\n", 331 | "def irc(x): return int(round(x))\n", 332 | "\n", 333 | "# babysit learning rates\n", 334 | "adjust_lr = DegradeLR(LR_RATE,0.5,irc(Niter_epoch*0.25),irc(Niter_epoch*0.1),-1e-4)\n", 335 | "\n", 336 | "optimizer = optax.adam(adjust_lr.step_func)\n", 337 | "\n", 338 | "tmp = [rand_mean,rand_prec,rand_weight_log,rand_color]\n", 339 | "#tmp = [means,prec,weights_log]\n", 340 | "\n", 341 | "opt_state = optimizer.init(tmp)\n", 342 | "\n", 343 | "all_sils = jnp.hstack([_.ravel() for _ in masks]).astype(jnp.float32)\n", 344 | "all_colors = jnp.hstack([_.ravel()/255.0 for _ in images]).astype(jnp.float32).reshape((-1,3))\n", 345 | "all_colors = all_colors**(1/2.2)\n", 346 | "\n", 347 | "losses = []\n", 348 | "opt_configs = []\n", 349 | "outer_loop = tqdm(range(Nepoch), desc=\" epoch\", position=0)\n", 350 | "\n", 351 | "rand_idx = np.arange(train_size)\n", 352 | "params = tmp\n", 353 | "def inner_iter(j_idx,rand_idx_local,opt_state,p):\n", 354 | " idx = jax.lax.dynamic_slice(rand_idx_local,[j_idx*batch_size],[batch_size])\n", 355 | "\n", 356 | " val,g = grad_render3([p[0]*SCALE_MUL_FACTOR,p[1]/SCALE_MUL_FACTOR,p[2],p[3]],all_rays[idx],invF,vecM*poses,all_sils[idx],fwv_flow[idx],bwv_flow[idx],all_colors[idx]) \n", 357 | " updates, opt_state = optimizer.update(g, opt_state,p)\n", 358 | " p = optax.apply_updates(p, updates)\n", 359 | " return val, opt_state, p \n", 360 | "jax_iter = jax.jit(inner_iter)\n", 361 | "done = False\n", 362 | "for i in outer_loop:\n", 363 | " np.random.shuffle(rand_idx)\n", 364 | " rand_idx_jnp = jnp.array(rand_idx)\n", 365 | "\n", 366 | " for j in tqdm(range(Niter_epoch), desc=\" iteration\", position=1, leave=False):\n", 367 | " opt_configs.append(list(params))\n", 368 | " val,opt_state,params = jax_iter(j,rand_idx_jnp,opt_state,params)\n", 369 | " val = float(val)\n", 370 | " losses.append(val)\n", 371 | " if np.isnan(val):\n", 372 | " raise ValueError('loss became NaN')\n", 373 | "\n", 374 | " if adjust_lr.add(val):\n", 375 | " done = True\n", 376 | " break\n", 377 | " outer_loop.set_description(\" loss {:.3f}\".format(val))\n", 378 | " if done:\n", 379 | " break" 380 | ] 381 | }, 382 | { 383 | "cell_type": "code", 384 | "execution_count": null, 385 | "metadata": {}, 386 | "outputs": [], 387 | "source": [ 388 | "plt.plot(losses)" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": null, 394 | "metadata": {}, 395 | "outputs": [], 396 | "source": [ 397 | "final_mean, final_prec, final_weight_log,final_color = params" 398 | ] 399 | }, 400 | { 401 | "cell_type": "code", 402 | "execution_count": null, 403 | "metadata": {}, 404 | "outputs": [], 405 | "source": [ 406 | "result_depths = []\n", 407 | "result_alphas = []\n", 408 | "results_colors = []\n", 409 | "\n", 410 | "for i in range(len(poses)):\n", 411 | " pixel_list[:,3] = i\n", 412 | " res_img,est_alpha,_,w = render_jit(final_mean, final_prec, final_weight_log,pixel_list,invF,poses)\n", 413 | " est_color = np.array(w.T @ (jnp.tanh(final_color)*0.5+0.5))**(2.2)\n", 414 | "\n", 415 | " res_imgA = np.array(res_img)\n", 416 | " est_alpha = np.array(est_alpha)\n", 417 | " res_imgA[est_alpha < 0.5] = np.nan\n", 418 | " est_color[est_alpha < 0.5] = np.nan\n", 419 | "\n", 420 | " result_depths.append(res_imgA.reshape((PY,PX)))\n", 421 | " result_alphas.append(est_alpha.reshape((PY,PX)))\n", 422 | " 
results_colors.append(est_color.reshape((PY,PX,3)))" 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "execution_count": null, 428 | "metadata": {}, 429 | "outputs": [], 430 | "source": [ 431 | "plt.subplot(1,3,1)\n", 432 | "plt.imshow(result_alphas[-1])\n", 433 | "plt.axis('off')\n", 434 | "\n", 435 | "plt.subplot(1,3,2)\n", 436 | "plt.imshow(result_depths[-1])\n", 437 | "plt.axis('off')\n", 438 | "plt.subplot(1,3,3)\n", 439 | "plt.imshow(est_color.reshape((PY,PX,3)),interpolation='nearest')\n", 440 | "plt.axis('off')\n" 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": null, 446 | "metadata": {}, 447 | "outputs": [], 448 | "source": [ 449 | "from scipy.stats import trim_mean\n", 450 | "errs = []\n", 451 | "d1f = np.hstack([_.ravel() for _ in true_depths]).ravel()\n", 452 | "d2f = np.hstack([_.ravel() for _ in result_depths]).ravel()\n", 453 | "\n", 454 | "mask = (all_sils !=0 ) & (~np.isnan(d1f)) & (~np.isnan(d2f)) & (d1f !=0) \n", 455 | "\n", 456 | "trim_mean(abs(d1f[mask]-d2f[mask]),0.1)" 457 | ] 458 | }, 459 | { 460 | "cell_type": "code", 461 | "execution_count": null, 462 | "metadata": {}, 463 | "outputs": [], 464 | "source": [ 465 | "image_grid(masks,rows=3,cols=5,rgb=False)" 466 | ] 467 | }, 468 | { 469 | "cell_type": "code", 470 | "execution_count": null, 471 | "metadata": {}, 472 | "outputs": [], 473 | "source": [ 474 | "max_frame = len(poses)\n", 475 | "FWD_BCK_TIMES = 4\n", 476 | "THRESH_IDX = np.where(np.array(losses)/min(losses) < 1.02)[0][0]\n", 477 | "USE_FIRST_N_FRAC = THRESH_IDX/len(losses)\n", 478 | "N_FRAMES = max_frame*FWD_BCK_TIMES\n", 479 | "opt_to_use = np.round(np.linspace(0,int(np.floor(len(opt_configs)*USE_FIRST_N_FRAC-1)),N_FRAMES)).astype(int)" 480 | ] 481 | }, 482 | { 483 | "cell_type": "code", 484 | "execution_count": null, 485 | "metadata": {}, 486 | "outputs": [], 487 | "source": [ 488 | "THRESH_IDX/len(losses)" 489 | ] 490 | }, 491 | { 492 | "cell_type": "code", 493 | "execution_count": null, 494 | "metadata": {}, 495 | "outputs": [], 496 | "source": [ 497 | "plt.plot(losses[:THRESH_IDX])" 498 | ] 499 | }, 500 | { 501 | "cell_type": "code", 502 | "execution_count": null, 503 | "metadata": {}, 504 | "outputs": [], 505 | "source": [ 506 | "if os.path.exists(output_folder):\n", 507 | " import shutil\n", 508 | " shutil.rmtree(output_folder)\n", 509 | "os.makedirs(output_folder, exist_ok=True)" 510 | ] 511 | }, 512 | { 513 | "cell_type": "code", 514 | "execution_count": null, 515 | "metadata": {}, 516 | "outputs": [], 517 | "source": [ 518 | "frame_idxs = []\n", 519 | "frame_list = list(range(max_frame))\n", 520 | "for i in range(FWD_BCK_TIMES):\n", 521 | " if (i % 2) == 0:\n", 522 | " frame_idxs += frame_list\n", 523 | " else:\n", 524 | " frame_idxs += frame_list[::-1]" 525 | ] 526 | }, 527 | { 528 | "cell_type": "code", 529 | "execution_count": null, 530 | "metadata": {}, 531 | "outputs": [], 532 | "source": [ 533 | "full_res_alpha = []\n", 534 | "full_res_depth = []\n", 535 | "full_res_color = []\n", 536 | "\n", 537 | "for r_idx,c_idx in zip(frame_idxs,opt_to_use):\n", 538 | " p = opt_configs[c_idx]\n", 539 | "\n", 540 | " pixel_list[:,3] = r_idx\n", 541 | " est_depth,est_alpha,_,w = render_jit(p[0],p[1],p[2],pixel_list,invF,poses)\n", 542 | " est_color = np.array(w.T @ (jnp.tanh(p[3])*0.5+0.5))**(2.2)\n", 543 | "\n", 544 | " est_alpha = np.array(est_alpha)\n", 545 | " est_depth = np.array(est_depth)\n", 546 | " est_depth[est_alpha < 0.5] = np.nan\n", 547 | " est_color[est_alpha < 0.5] = np.nan\n", 548 | "\n", 549 | " 
full_res_alpha.append(est_alpha.reshape((PY,PX)))\n", 550 | " full_res_depth.append(est_depth.reshape((PY,PX)))\n", 551 | " full_res_color.append(est_color.reshape((PY,PX,3)))\n", 552 | "\n", 553 | " print('.',end='')" 554 | ] 555 | }, 556 | { 557 | "cell_type": "code", 558 | "execution_count": null, 559 | "metadata": {}, 560 | "outputs": [], 561 | "source": [ 562 | "if os.path.exists(output_folder):\n", 563 | " import shutil\n", 564 | " shutil.rmtree(output_folder)\n", 565 | "os.makedirs(output_folder, exist_ok=True)" 566 | ] 567 | }, 568 | { 569 | "cell_type": "code", 570 | "execution_count": null, 571 | "metadata": {}, 572 | "outputs": [], 573 | "source": [ 574 | "vecr = np.hstack([_.ravel() for _ in full_res_depth])\n", 575 | "vecr = vecr[~np.isnan(vecr)]\n", 576 | "vmin = np.percentile(vecr,5)\n", 577 | "vmax = np.percentile(vecr,95)\n", 578 | "vscale = vmax-vmin" 579 | ] 580 | }, 581 | { 582 | "cell_type": "code", 583 | "execution_count": null, 584 | "metadata": {}, 585 | "outputs": [], 586 | "source": [ 587 | "import matplotlib\n", 588 | "from PIL import Image, ImageDraw, ImageFont\n", 589 | "start_f = 0\n", 590 | "avg_size = np.array([PX,PY])\n", 591 | "fsize = irc(96/4)\n", 592 | "\n", 593 | "font = ImageFont.truetype('Roboto-Regular.ttf', size=irc(avg_size[0]/8))\n", 594 | "cmap = matplotlib.cm.get_cmap('viridis')\n", 595 | "cmap2 = matplotlib.cm.get_cmap('magma')\n", 596 | "\n", 597 | "for i,mask_res in enumerate(full_res_alpha):\n", 598 | " r_idx = frame_idxs[i]\n", 599 | " #img1 = ground_images[r_idx]/255.0*np.clip(full_masks[r_idx] > .1,0.3,1)[:,:,None]\n", 600 | " #img2 = ground_images[r_idx]*np.clip((mask_res)**0.4,0.05,1)[:,:,None]\n", 601 | " img2 = full_res_color[i]#np.tile(mask_res[:,:,None],(1,1,3))\n", 602 | " img_gt_mask = np.tile(masks[r_idx][:,:,None],(1,1,3))\n", 603 | "\n", 604 | " true_alpha = masks[r_idx]\n", 605 | "\n", 606 | " est_alpha = jnp.clip(mask_res,1e-6,1-1e-6)\n", 607 | " mask_loss = - ((true_alpha * jnp.log(est_alpha)) + (1-true_alpha)*jnp.log(1-est_alpha))\n", 608 | " loss_viz = cmap2(0.25*mask_loss)[:,:,:3]\n", 609 | "\n", 610 | " depth = cmap((full_res_depth[i]-vmin)/vscale)[:,:,:3]\n", 611 | " img2 = np.concatenate((images[r_idx]/255.0,img_gt_mask,loss_viz,img2, depth), axis=1)\n", 612 | " int_img = np.round(img2*255).astype(np.uint8)\n", 613 | " pil_img = Image.fromarray(int_img)\n", 614 | " d1 = ImageDraw.Draw(pil_img)\n", 615 | " d1.text((avg_size[0]*1.1, irc(fsize*0.1)), \"Iteration: {:3d}\\nEpoch: {:.1f}\".format(opt_to_use[i],opt_to_use[i]/Niter_epoch), ha='center',font=font,fill=(180, 180, 180))\n", 616 | " d1.text((avg_size[0]*1.3, irc(avg_size[1]-fsize*1.5)), \"Target Mask\", font=font,fill=(255, 255, 255),ha='center')\n", 617 | " d1.text((avg_size[0]*2.4, irc(avg_size[1]-fsize*1.5)), \"Loss\", font=font,fill=(255, 255, 255),ha='center',align='center')\n", 618 | " d1.text((avg_size[0]*3.3, irc(avg_size[1]-fsize*2.5)), \"Estimated\\nColor\", font=font,fill=(255, 255, 255),ha='center',align='center')\n", 619 | " d1.text((avg_size[0]*4.3, irc(avg_size[1]-fsize*2.5)), \"Estimated\\nDepth\", font=font,fill=(255, 255, 255),ha='center',align='center')\n", 620 | "\n", 621 | " img3 = np.array(pil_img)\n", 622 | " \n", 623 | " sio.imsave('{}/{:03d}.jpg'.format(output_folder,i),img3,quality=95)\n" 624 | ] 625 | }, 626 | { 627 | "cell_type": "code", 628 | "execution_count": null, 629 | "metadata": {}, 630 | "outputs": [], 631 | "source": [ 632 | "(avg_size[0]*1.3, irc(avg_size[1]-fsize*1.5)),avg_size" 633 | ] 634 | }, 635 | { 636 | "cell_type": 
"code", 637 | "execution_count": null, 638 | "metadata": {}, 639 | "outputs": [], 640 | "source": [ 641 | "plt.figure(figsize=(18,8))\n", 642 | "plt.imshow(img3)\n", 643 | "plt.axis('off')" 644 | ] 645 | }, 646 | { 647 | "cell_type": "code", 648 | "execution_count": null, 649 | "metadata": {}, 650 | "outputs": [], 651 | "source": [ 652 | "import subprocess\n", 653 | "if os.path.exists('{}.mp4'.format(output_folder)):\n", 654 | " os.remove('{}.mp4'.format(output_folder))\n", 655 | "subprocess.call(' '.join(['/usr/bin/ffmpeg',\n", 656 | " '-framerate','60',\n", 657 | " '-i','{}/%03d.jpg'.format(output_folder),\n", 658 | " '-vf','\\\"pad=ceil(iw/2)*2:ceil(ih/2)*2\\\"',\n", 659 | " '-c:v','h264',\n", 660 | " '-pix_fmt','yuv420p',\n", 661 | " '{}.mp4'.format(output_folder)]),shell=True)" 662 | ] 663 | } 664 | ], 665 | "metadata": { 666 | "language_info": { 667 | "name": "python" 668 | } 669 | }, 670 | "nbformat": 4, 671 | "nbformat_minor": 5 672 | } 673 | -------------------------------------------------------------------------------- /run_co3d_sp.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Run CO3D Sequence (2 parameter)" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import sys, os, glob\n", 17 | "import numpy as np\n", 18 | "import pandas as pd\n", 19 | "from utils_opt import readFlow\n", 20 | "\n", 21 | "import skimage.io as sio\n", 22 | "import matplotlib.pyplot as plt" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "import os\n", 32 | "#os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"\n", 33 | "\n", 34 | "import jax\n", 35 | "#jax.config.update('jax_platform_name', 'cpu')\n", 36 | "\n", 37 | "import jax.numpy as jnp\n", 38 | "import fm_render\n", 39 | "\n", 40 | "render_jit = jax.jit(fm_render.render_func_idx_quattrans)\n", 41 | "render_jit_ray = jax.jit(fm_render.render_func_rays)\n", 42 | "jax_flow_rend = jax.jit(fm_render.render_func_idx_quattrans_flow)" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "dataset_dir = 'rvid/teddybear_34_1479_4753//'\n", 52 | "co3d_seq = os.path.split(dataset_dir.rstrip('/').lstrip('/'))[-1]\n", 53 | "output_folder = os.path.join('tmp_out',co3d_seq)\n", 54 | "NUM_MIXTURE = 40\n", 55 | "shape_scale = 1.8\n", 56 | "c_scale = 4.5\n", 57 | "f_scale = 210\n", 58 | "rand_sphere_size = 55\n", 59 | "cov_scale = 1.2e-2\n", 60 | "weight_scale = 1.1\n", 61 | "LR_RATE = 0.08\n", 62 | "beta2 = 21.4\n", 63 | "beta3 = 2.66\n", 64 | "Nepoch = 10\n", 65 | "batch_size = 50000\n" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "## Load Data" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "import gzip\n", 82 | "import pickle\n", 83 | "with gzip.open(os.path.join(dataset_dir,'pose_depth.pkl.gz'),'rb') as fp:\n", 84 | " depth_and_pose = pickle.load(fp)\n", 85 | "\n", 86 | "true_depths = depth_and_pose['depths']\n", 87 | "fl = depth_and_pose['fl']\n", 88 | "poses = np.array(depth_and_pose['poses'])\n" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": 
[], 96 | "source": [ 97 | "masks_folder = os.path.join(dataset_dir,'masks','video1')\n", 98 | "in_files = sorted(glob.glob(masks_folder + '/*.png'))\n", 99 | "\n", 100 | "masks = []\n", 101 | "for img_loc in in_files:\n", 102 | " mask = sio.imread(img_loc)\n", 103 | " mask = (mask > 0).astype(np.float32)\n", 104 | " masks.append(mask)\n", 105 | "masks = np.array(masks)\n", 106 | "PY,PX = mask.shape\n", 107 | "image_size = (PY,PX)" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "masks_folder = os.path.join(dataset_dir,'JPEGImages','video1')\n", 117 | "in_files = sorted(glob.glob(masks_folder + '/*.jpg'))\n", 118 | "\n", 119 | "images = []\n", 120 | "for img_loc in in_files:\n", 121 | " img = sio.imread(img_loc).astype(np.float32)\n", 122 | " images.append(img)\n", 123 | "images = np.array(images)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "fwd_flows = []\n", 133 | "bwd_flows = []\n", 134 | "\n", 135 | "flow_fol = os.path.join(dataset_dir,'Flow','video1','*.flo')\n", 136 | "\n", 137 | "flow_files = sorted(glob.glob(flow_fol))\n", 138 | "\n", 139 | "for flfile in flow_files:\n", 140 | " new_flow = readFlow(flfile)\n", 141 | " if PY > PX:\n", 142 | " new_flow = np.stack([new_flow[:,:,1],new_flow[:,:,0]],axis=2)\n", 143 | " if 'bwd' in flfile:\n", 144 | " bwd_flows.append(new_flow)\n", 145 | " else:\n", 146 | " fwd_flows.append(new_flow)\n", 147 | "\n", 148 | "\n", 149 | "# the last frame has no forward flow\n", 150 | "fwd_flows = fwd_flows + [new_flow*0]\n", 151 | "# the first frame has no backward flow\n", 152 | "bwd_flows = [new_flow*0] + bwd_flows" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "if 'mesh' in depth_and_pose:\n", 162 | " pt_cld = depth_and_pose['mesh'].vertices\n", 163 | " import sklearn.mixture as mixture\n", 164 | "\n", 165 | " idx2 = np.arange(pt_cld.shape[0])\n", 166 | " np.random.shuffle(idx2)\n", 167 | " clf = mixture.GaussianMixture(40)\n", 168 | " clf.fit(pt_cld[idx2[:10000]])\n", 169 | "\n", 170 | " pt_cld_shape_scale = float(pt_cld.std(0).mean())*3\n", 171 | " center = pt_cld.mean(0)\n", 172 | "else: \n", 173 | " pt_cld_shape_scale = 3.0\n", 174 | " center = np.zeros(3,dtype=np.float32)" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": null, 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [ 183 | "SCALE_MUL_FACTOR = shape_scale/pt_cld_shape_scale\n", 184 | "SCALE_MUL_FACTOR" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "img_shape = (PY,PX)\n", 194 | "min_size_idx = np.argmin(img_shape)\n", 195 | "min_size = img_shape[min_size_idx]\n", 196 | "max_size = img_shape[1-min_size_idx]\n", 197 | "invF = 0.5*min_size/fl\n", 198 | "min_dim = np.linspace(-1,1,min_size)\n", 199 | "aspect = max_size/min_size\n", 200 | "max_dim = np.linspace(-aspect,aspect,max_size)\n", 201 | "grid = [-max_dim,-min_dim,1,0] if min_size_idx == 0 else [-min_dim,-max_dim,1,0]\n", 202 | "pixel_list = np.transpose(np.squeeze(np.meshgrid(*grid,indexing='ij')),(2,1,0))\n", 203 | "\n", 204 | "pixel_list = pixel_list.reshape((-1,4))" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 
213 | "poses = jnp.array(poses)" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [ 222 | "from util import image_grid" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "# random init settings\n", 232 | "rand_mean = center+pt_cld_shape_scale*np.random.multivariate_normal(mean=[0,0,0],cov=cov_scale*np.identity(3),size=NUM_MIXTURE)\n", 233 | "rand_weight_log = jnp.log(weight_scale*np.ones(NUM_MIXTURE)/NUM_MIXTURE)\n", 234 | "rand_prec = jnp.array([np.identity(3)*rand_sphere_size/pt_cld_shape_scale for _ in range(NUM_MIXTURE)])\n", 235 | "rand_color = jnp.array(np.random.randn(NUM_MIXTURE,3))\n", 236 | "\n", 237 | "init_alphas = []\n", 238 | "for i in range(min(36,len(poses))):\n", 239 | " pixel_list[:,3] = i\n", 240 | " res_img,est_alpha,_,_ = render_jit(rand_mean,rand_prec,rand_weight_log,pixel_list,invF,poses,beta2/pt_cld_shape_scale,beta3)\n", 241 | "\n", 242 | " res_imgA = np.array(res_img)\n", 243 | " res_imgA[est_alpha < 0.5] = np.nan\n", 244 | " init_alphas.append(est_alpha.reshape((PY,PX)))\n", 245 | "image_grid(init_alphas,6,6,rgb=False)" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "total_ray_set = []\n", 255 | "for i in range(len(poses)):\n", 256 | " pixel_list[:,3] = i\n", 257 | "\n", 258 | " total_ray_set.append(pixel_list.copy())\n", 259 | "all_rays = jnp.vstack(total_ray_set)" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": null, 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "# scaled into ray coord space, vectorized flows\n", 269 | "fwv_flow = jnp.array(np.array(fwd_flows).reshape((-1,2)))/(min_size/2)\n", 270 | "bwv_flow = jnp.array(np.array(bwd_flows).reshape((-1,2)))/(min_size/2)" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": null, 276 | "metadata": {}, 277 | "outputs": [], 278 | "source": [ 279 | "last_img_size = np.prod(img_shape)\n", 280 | "v_idx = 40" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": null, 286 | "metadata": {}, 287 | "outputs": [], 288 | "source": [ 289 | "def objective(params,camera_rays,invF,poses,beta2,beta3,true_alpha,true_fwd,true_bwd,true_color):\n", 290 | " CLIP_ALPHA = 1e-6\n", 291 | " means,prec,weights_log,colors = params\n", 292 | " est_depth, est_alpha, est_norm, est_w,flowp,flowm = fm_render.render_func_idx_quattrans_flow(means,prec,weights_log,camera_rays,invF,poses,beta2,beta3)\n", 293 | " est_color = est_w.T @ (jnp.tanh(colors)*0.5+0.5)\n", 294 | " est_alpha = jnp.clip(est_alpha,CLIP_ALPHA,1-CLIP_ALPHA)\n", 295 | " mask_loss = - ((true_alpha * jnp.log(est_alpha)) + (1-true_alpha)*jnp.log(1-est_alpha))\n", 296 | " pad_alpha = true_alpha[:,None]\n", 297 | " flow1 = jnp.abs(pad_alpha*true_fwd-pad_alpha*flowp)\n", 298 | " flow2 = jnp.abs(pad_alpha*true_bwd-pad_alpha*flowm)\n", 299 | " cdiff = jnp.abs( (true_color-est_color)*true_alpha[:,None] )\n", 300 | " return mask_loss.mean() + c_scale*cdiff.mean() + f_scale*(flow1.mean() + flow2.mean()) \n", 301 | "grad_render3 = jax.value_and_grad(objective)" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": null, 307 | "metadata": {}, 308 | "outputs": [], 309 | "source": [ 310 | "import optax\n", 311 | "from tqdm.notebook import tqdm\n", 312 | "from util import DegradeLR\n", 
313 | "\n", 314 | "vecM = jnp.array([1,1,1,1,SCALE_MUL_FACTOR,SCALE_MUL_FACTOR,SCALE_MUL_FACTOR])[None]\n", 315 | "\n", 316 | "train_size = all_rays.shape[0]\n", 317 | "Niter_epoch = int(round(train_size/batch_size))\n", 318 | "\n", 319 | "def irc(x): return int(round(x))\n", 320 | "\n", 321 | "# babysit learning rates\n", 322 | "adjust_lr = DegradeLR(LR_RATE,0.5,irc(Niter_epoch*0.25),irc(Niter_epoch*0.1),-1e-4)\n", 323 | "\n", 324 | "optimizer = optax.adam(adjust_lr.step_func)\n", 325 | "\n", 326 | "tmp = [rand_mean,rand_prec,rand_weight_log,rand_color]\n", 327 | "#tmp = [means,prec,weights_log]\n", 328 | "\n", 329 | "opt_state = optimizer.init(tmp)\n", 330 | "\n", 331 | "all_sils = jnp.hstack([_.ravel() for _ in masks]).astype(jnp.float32)\n", 332 | "all_colors = jnp.hstack([_.ravel()/255.0 for _ in images]).astype(jnp.float32).reshape((-1,3))\n", 333 | "all_colors = all_colors**(1/2.2)\n", 334 | "\n", 335 | "losses = []\n", 336 | "opt_configs = []\n", 337 | "outer_loop = tqdm(range(Nepoch), desc=\" epoch\", position=0)\n", 338 | "\n", 339 | "rand_idx = np.arange(train_size)\n", 340 | "params = tmp\n", 341 | "def inner_iter(j_idx,rand_idx_local,opt_state,p):\n", 342 | " idx = jax.lax.dynamic_slice(rand_idx_local,[j_idx*batch_size],[batch_size])\n", 343 | "\n", 344 | " val,g = grad_render3([p[0]*SCALE_MUL_FACTOR,p[1]/SCALE_MUL_FACTOR,p[2],p[3]],all_rays[idx],invF,vecM*poses,\n", 345 | " beta2/(shape_scale),beta3,all_sils[idx],fwv_flow[idx],bwv_flow[idx],all_colors[idx]) \n", 346 | " updates, opt_state = optimizer.update(g, opt_state,p)\n", 347 | " p = optax.apply_updates(p, updates)\n", 348 | " return val, opt_state, p \n", 349 | "jax_iter = jax.jit(inner_iter)\n", 350 | "done = False\n", 351 | "for i in outer_loop:\n", 352 | " np.random.shuffle(rand_idx)\n", 353 | " rand_idx_jnp = jnp.array(rand_idx)\n", 354 | "\n", 355 | " for j in tqdm(range(Niter_epoch), desc=\" iteration\", position=1, leave=False):\n", 356 | " opt_configs.append(list(params))\n", 357 | " val,opt_state,params = jax_iter(j,rand_idx_jnp,opt_state,params)\n", 358 | " val = float(val)\n", 359 | " losses.append(val)\n", 360 | "\n", 361 | " if adjust_lr.add(val):\n", 362 | " done = True\n", 363 | " break\n", 364 | " outer_loop.set_description(\" loss {:.3f}\".format(val))\n", 365 | " if done:\n", 366 | " break" 367 | ] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "execution_count": null, 372 | "metadata": {}, 373 | "outputs": [], 374 | "source": [ 375 | "plt.plot(losses)" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": null, 381 | "metadata": {}, 382 | "outputs": [], 383 | "source": [ 384 | "final_mean, final_prec, final_weight_log,final_color = params" 385 | ] 386 | }, 387 | { 388 | "cell_type": "code", 389 | "execution_count": null, 390 | "metadata": {}, 391 | "outputs": [], 392 | "source": [ 393 | "dump_out = {\n", 394 | " 'mean': np.array(final_mean),\n", 395 | " 'prec': np.array(final_prec),\n", 396 | " 'wlog': np.array(final_weight_log),\n", 397 | " 'color': np.array(final_color)\n", 398 | "}\n", 399 | "import pickle\n", 400 | "with open('output.pkl','wb') as fp:\n", 401 | " pickle.dump(dump_out,fp)" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": null, 407 | "metadata": {}, 408 | "outputs": [], 409 | "source": [ 410 | "result_depths = []\n", 411 | "result_alphas = []\n", 412 | "results_colors = []\n", 413 | "\n", 414 | "for i in range(len(poses)):\n", 415 | " pixel_list[:,3] = i\n", 416 | " res_img,est_alpha,_,w = render_jit(final_mean, final_prec, 
final_weight_log,pixel_list,invF,poses,beta2/pt_cld_shape_scale,beta3)\n", 417 | " est_color = np.array(w.T @ (jnp.tanh(final_color)*0.5+0.5))**(2.2)\n", 418 | "\n", 419 | " res_imgA = np.array(res_img)\n", 420 | " est_alpha = np.array(est_alpha)\n", 421 | " res_imgA[est_alpha < 0.5] = np.nan\n", 422 | " est_color[est_alpha < 0.5] = np.nan\n", 423 | "\n", 424 | " result_depths.append(res_imgA.reshape((PY,PX)))\n", 425 | " result_alphas.append(est_alpha.reshape((PY,PX)))\n", 426 | " results_colors.append(est_color.reshape((PY,PX,3)))" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": null, 432 | "metadata": {}, 433 | "outputs": [], 434 | "source": [ 435 | "plt.subplot(1,3,1)\n", 436 | "plt.imshow(result_alphas[-1])\n", 437 | "plt.axis('off')\n", 438 | "\n", 439 | "plt.subplot(1,3,2)\n", 440 | "plt.imshow(result_depths[-1])\n", 441 | "plt.axis('off')\n", 442 | "plt.subplot(1,3,3)\n", 443 | "plt.imshow(est_color.reshape((PY,PX,3)),interpolation='nearest')\n", 444 | "plt.axis('off')\n" 445 | ] 446 | }, 447 | { 448 | "cell_type": "code", 449 | "execution_count": null, 450 | "metadata": {}, 451 | "outputs": [], 452 | "source": [ 453 | "def per_gaussian_error(params,camera_rays,invF,poses,beta2,beta3,true_alpha,true_color,lower_std,upper_std):\n", 454 | " CLIP_ALPHA = 1e-7\n", 455 | " CLIP_ALPHA = 1e-6\n", 456 | " means,prec,weights_log,colors = params\n", 457 | " est_depth, est_alpha, est_norm, est_w =render_jit(means,prec,weights_log,camera_rays,invF,poses,beta2,beta3)\n", 458 | " est_color = est_w.T @ (jnp.tanh(colors)*0.5+0.5)\n", 459 | " est_alpha = jnp.clip(est_alpha,CLIP_ALPHA,1-CLIP_ALPHA)\n", 460 | " mask_loss = - ((true_alpha * jnp.log(est_alpha)) + (1-true_alpha)*jnp.log(1-est_alpha))\n", 461 | " cdiff = jnp.abs( (true_color-est_color)*true_alpha[:,None] )\n", 462 | " \n", 463 | " per_err = ((mask_loss*est_w).mean(axis=1) + c_scale*(cdiff.mean(axis=1) * est_w).mean(axis=1) )\n", 464 | " avg_w = est_w.mean(axis=1)\n", 465 | " keep_idx = (avg_w > (avg_w.mean() - lower_std*avg_w.std()))\n", 466 | " split_idx = (per_err >= (per_err.mean() + upper_std*per_err.std()))\n", 467 | " c_var = (true_color[:,None,:] *est_w.T[:,:,None]).std(axis=0)\n", 468 | " return split_idx, keep_idx, c_var\n", 469 | "\n", 470 | "def get_split_gaussian(params,camera_rays,invF,poses,beta2,beta3,true_alpha,true_color,lower_std,upper_std):\n", 471 | " split_idx,keep_idx,c_var = per_gaussian_error(params,camera_rays,invF,poses,beta2,beta3,true_alpha,true_color,lower_std,upper_std)\n", 472 | " t_keep_idx = keep_idx & (~split_idx)\n", 473 | "\n", 474 | " means,prec,weights_log,colors = params\n", 475 | "\n", 476 | " new_means, new_prec, new_weights, new_colors = [],[],[], []\n", 477 | " for i in np.where(np.array(split_idx))[0]:\n", 478 | " mu, preco, wlog, col = means[i], prec[i], weights_log[i], colors[i]\n", 479 | " covar = np.linalg.pinv(preco.T @ preco)\n", 480 | " u,s,vt = np.linalg.svd(covar)\n", 481 | " s2 = s.copy()\n", 482 | " s2[0] = s2[0] * np.sqrt(1-2/np.pi)\n", 483 | " covar2 = u@np.diag(s2)@vt\n", 484 | " m1 = mu + (u[0] * np.sqrt(s[0]) * np.sqrt(2/np.pi))\n", 485 | " m2 = mu - (u[0] * np.sqrt(s[0]) * np.sqrt(2/np.pi))\n", 486 | " precn = np.linalg.cholesky(np.linalg.pinv(covar2)).T\n", 487 | "\n", 488 | " new_means.append(m1)\n", 489 | " new_means.append(m2)\n", 490 | " new_prec.append(precn)\n", 491 | " new_prec.append(precn)\n", 492 | " new_weights.append(wlog+ 0.1*np.random.randn())\n", 493 | " new_weights.append(wlog+ 0.1*np.random.randn())\n", 494 | " 
new_colors.append(col + 0.1*np.random.randn(3))\n", 495 | " new_colors.append(col + 0.1*np.random.randn(3))\n", 496 | " oldp = [np.array(_)[t_keep_idx] for _ in params]\n", 497 | " m2 = np.vstack([oldp[0],new_means])\n", 498 | " p2 = np.vstack([oldp[1],new_prec])\n", 499 | " w2 = np.hstack([oldp[2],new_weights])\n", 500 | " c2 = np.vstack([oldp[3],new_colors])\n", 501 | " return [jnp.array(_).astype(jnp.float32) for _ in [m2,p2,w2,c2]]\n", 502 | "idx =rand_idx_jnp[:10*batch_size] \n", 503 | "params2 = get_split_gaussian(params,all_rays[idx],invF,vecM*poses,beta2/(shape_scale),beta3,all_sils[idx],all_colors[idx],2,1)\n", 504 | "print(params2[0].shape)" 505 | ] 506 | }, 507 | { 508 | "cell_type": "code", 509 | "execution_count": null, 510 | "metadata": {}, 511 | "outputs": [], 512 | "source": [ 513 | "from scipy.stats import trim_mean\n", 514 | "errs = []\n", 515 | "d1f = np.hstack([_.ravel() for _ in true_depths]).ravel()\n", 516 | "d2f = np.hstack([_.ravel() for _ in result_depths]).ravel()\n", 517 | "\n", 518 | "mask = (all_sils !=0 ) & (~np.isnan(d1f)) & (~np.isnan(d2f)) & (d1f !=0) \n", 519 | "\n", 520 | "trim_mean(abs(d1f[mask]-d2f[mask]),0.1)" 521 | ] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": null, 526 | "metadata": {}, 527 | "outputs": [], 528 | "source": [ 529 | "image_grid(masks,rows=3,cols=5,rgb=False)" 530 | ] 531 | }, 532 | { 533 | "cell_type": "code", 534 | "execution_count": null, 535 | "metadata": {}, 536 | "outputs": [], 537 | "source": [ 538 | "max_frame = len(poses)\n", 539 | "FWD_BCK_TIMES = 4\n", 540 | "THRESH_IDX = np.where(np.array(losses)/min(losses) < 1.02)[0][0]\n", 541 | "USE_FIRST_N_FRAC = THRESH_IDX/len(losses)\n", 542 | "N_FRAMES = max_frame*FWD_BCK_TIMES\n", 543 | "opt_to_use = np.round(np.linspace(0,int(np.floor(len(opt_configs)*USE_FIRST_N_FRAC-1)),N_FRAMES)).astype(int)\n", 544 | "loss_v = np.log(losses)\n", 545 | "loss_v -= loss_v.min()\n", 546 | "loss_v /= loss_v.max()\n", 547 | "loss_v = np.cumsum(loss_v)\n", 548 | "loss_v -= loss_v.min()\n", 549 | "loss_v /= loss_v.max()\n", 550 | "tv = np.stack([N_FRAMES*loss_v,(len(opt_configs)-1)*np.linspace(0,1,len(losses))]).T\n", 551 | "#plt.plot(tv[:,0],tv[:,1])\n", 552 | "#opt_to_use = np.round(np.interp(np.arange(N_FRAMES),tv[:,0],tv[:,1])).astype(int)" 553 | ] 554 | }, 555 | { 556 | "cell_type": "code", 557 | "execution_count": null, 558 | "metadata": {}, 559 | "outputs": [], 560 | "source": [ 561 | "THRESH_IDX/len(losses)" 562 | ] 563 | }, 564 | { 565 | "cell_type": "code", 566 | "execution_count": null, 567 | "metadata": {}, 568 | "outputs": [], 569 | "source": [ 570 | "plt.plot(losses[:THRESH_IDX])" 571 | ] 572 | }, 573 | { 574 | "cell_type": "code", 575 | "execution_count": null, 576 | "metadata": {}, 577 | "outputs": [], 578 | "source": [ 579 | "len(opt_configs)" 580 | ] 581 | }, 582 | { 583 | "cell_type": "code", 584 | "execution_count": null, 585 | "metadata": {}, 586 | "outputs": [], 587 | "source": [ 588 | "if os.path.exists(output_folder):\n", 589 | " import shutil\n", 590 | " shutil.rmtree(output_folder)\n", 591 | "os.makedirs(output_folder, exist_ok=True)" 592 | ] 593 | }, 594 | { 595 | "cell_type": "code", 596 | "execution_count": null, 597 | "metadata": {}, 598 | "outputs": [], 599 | "source": [ 600 | "frame_idxs = []\n", 601 | "frame_list = list(range(max_frame))\n", 602 | "for i in range(FWD_BCK_TIMES):\n", 603 | " if (i % 2) == 0:\n", 604 | " frame_idxs += frame_list\n", 605 | " else:\n", 606 | " frame_idxs += frame_list[::-1]" 607 | ] 608 | }, 609 | { 610 | 
"cell_type": "code", 611 | "execution_count": null, 612 | "metadata": {}, 613 | "outputs": [], 614 | "source": [ 615 | "full_res_alpha = []\n", 616 | "full_res_depth = []\n", 617 | "full_res_color = []\n", 618 | "\n", 619 | "for r_idx,c_idx in zip(frame_idxs,opt_to_use):\n", 620 | " p = opt_configs[c_idx]\n", 621 | "\n", 622 | " pixel_list[:,3] = r_idx\n", 623 | " est_depth,est_alpha,_,w = render_jit(p[0],p[1],p[2],pixel_list,invF,poses,beta2/pt_cld_shape_scale,beta3)\n", 624 | " est_color = (1-est_alpha[:,None])*0 + est_alpha[:,None] * np.array(w.T @ (jnp.tanh(p[3])*0.5+0.5))**(2.2)\n", 625 | "\n", 626 | " est_alpha = np.array(est_alpha)\n", 627 | " est_depth = np.array(est_depth)\n", 628 | " est_depth[est_alpha < max(0.5,np.percentile(est_alpha,0.99))] = np.nan\n", 629 | " #est_color[est_alpha < 0.5] = np.nan\n", 630 | "\n", 631 | " full_res_alpha.append(est_alpha.reshape((PY,PX)))\n", 632 | " full_res_depth.append(est_depth.reshape((PY,PX)))\n", 633 | " full_res_color.append(est_color.reshape((PY,PX,3)))\n", 634 | "\n", 635 | " print('.',end='')" 636 | ] 637 | }, 638 | { 639 | "cell_type": "code", 640 | "execution_count": null, 641 | "metadata": {}, 642 | "outputs": [], 643 | "source": [ 644 | "vecr = np.hstack([_.ravel() for _ in full_res_depth])\n", 645 | "vecr = vecr[~np.isnan(vecr)]\n", 646 | "vmin = np.percentile(vecr,5)\n", 647 | "vmax = np.percentile(vecr,95)\n", 648 | "vscale = vmax-vmin" 649 | ] 650 | }, 651 | { 652 | "cell_type": "code", 653 | "execution_count": null, 654 | "metadata": {}, 655 | "outputs": [], 656 | "source": [ 657 | "import matplotlib\n", 658 | "from PIL import Image, ImageDraw, ImageFont\n", 659 | "start_f = 0\n", 660 | "avg_size = np.array([PX,PY])\n", 661 | "fsize = irc(96/4)\n", 662 | "\n", 663 | "font = ImageFont.truetype('Roboto-Regular.ttf', size=irc(avg_size[0]/16))\n", 664 | "cmap = matplotlib.cm.get_cmap('viridis')\n", 665 | "cmap2 = matplotlib.cm.get_cmap('magma')\n", 666 | "\n", 667 | "for i,mask_res in enumerate(full_res_alpha):\n", 668 | " r_idx = frame_idxs[i]\n", 669 | " #img1 = ground_images[r_idx]/255.0*np.clip(full_masks[r_idx] > .1,0.3,1)[:,:,None]\n", 670 | " #img2 = ground_images[r_idx]*np.clip((mask_res)**0.4,0.05,1)[:,:,None]\n", 671 | " img2 = full_res_color[i]#np.tile(mask_res[:,:,None],(1,1,3))\n", 672 | " img_gt_mask = np.tile(masks[r_idx][:,:,None],(1,1,3))\n", 673 | "\n", 674 | " true_alpha = masks[r_idx]\n", 675 | "\n", 676 | " est_alpha = jnp.clip(mask_res,1e-6,1-1e-6)\n", 677 | " mask_loss = - ((true_alpha * jnp.log(est_alpha)) + (1-true_alpha)*jnp.log(1-est_alpha))\n", 678 | " loss_viz = cmap2(0.25*mask_loss)[:,:,:3]\n", 679 | "\n", 680 | " depth = cmap((full_res_depth[i]-vmin)/vscale)[:,:,:3]\n", 681 | " img2 = np.concatenate((images[r_idx]/255.0,img_gt_mask,img2, depth), axis=1)\n", 682 | " int_img = np.round(img2*255).astype(np.uint8)\n", 683 | " pil_img = Image.fromarray(int_img)\n", 684 | " d1 = ImageDraw.Draw(pil_img)\n", 685 | " d1.text((avg_size[0]*1.1, irc(fsize*0.1)), \"Iteration: {:3d}\\nEpoch: {:.1f}\".format(opt_to_use[i],opt_to_use[i]/Niter_epoch), ha='center',font=font,fill=(180, 180, 180))\n", 686 | " d1.text((avg_size[0]*1.3, irc(avg_size[1]-fsize*2.5)), \"Target Mask\", font=font,fill=(255, 255, 255),ha='center')\n", 687 | " #d1.text((avg_size[0]*2.4, irc(avg_size[1]-fsize*1.5)), \"Loss\", font=font,fill=(255, 255, 255),ha='center',align='center')\n", 688 | " d1.text((avg_size[0]*2.3, irc(avg_size[1]-fsize*3.5)), \"Estimated\\nColor\", font=font,fill=(255, 255, 255),ha='center',align='center')\n", 689 | " 
d1.text((avg_size[0]*3.3, irc(avg_size[1]-fsize*3.5)), \"Estimated\\nDepth\", font=font,fill=(255, 255, 255),ha='center',align='center')\n", 690 | "\n", 691 | " img3 = np.array(pil_img)\n", 692 | " \n", 693 | " \n", 694 | " sio.imsave('{}/{:03d}.jpg'.format(output_folder,i),img3,quality=95)" 695 | ] 696 | }, 697 | { 698 | "cell_type": "code", 699 | "execution_count": null, 700 | "metadata": {}, 701 | "outputs": [], 702 | "source": [ 703 | "plt.imshow(img3)" 704 | ] 705 | }, 706 | { 707 | "cell_type": "code", 708 | "execution_count": null, 709 | "metadata": {}, 710 | "outputs": [], 711 | "source": [ 712 | "(avg_size[0]*1.3, irc(avg_size[1]-fsize*1.5)),avg_size" 713 | ] 714 | }, 715 | { 716 | "cell_type": "code", 717 | "execution_count": null, 718 | "metadata": {}, 719 | "outputs": [], 720 | "source": [ 721 | "plt.figure(figsize=(18,8))\n", 722 | "plt.imshow(img3)\n", 723 | "plt.axis('off')" 724 | ] 725 | }, 726 | { 727 | "cell_type": "code", 728 | "execution_count": null, 729 | "metadata": {}, 730 | "outputs": [], 731 | "source": [ 732 | "import subprocess\n", 733 | "if os.path.exists('{}.mp4'.format(output_folder)):\n", 734 | " os.remove('{}.mp4'.format(output_folder))\n", 735 | "subprocess.call(' '.join(['/usr/bin/ffmpeg',\n", 736 | " '-framerate','60',\n", 737 | " '-i','{}/%03d.jpg'.format(output_folder),\n", 738 | " '-vf','\\\"pad=ceil(iw/2)*2:ceil(ih/2)*2\\\"',\n", 739 | " '-c:v','h264',\n", 740 | " '-pix_fmt','yuv420p',\n", 741 | " '{}.mp4'.format(output_folder)]),shell=True)" 742 | ] 743 | }, 744 | { 745 | "cell_type": "code", 746 | "execution_count": null, 747 | "metadata": {}, 748 | "outputs": [], 749 | "source": [ 750 | "#raise" 751 | ] 752 | }, 753 | { 754 | "cell_type": "code", 755 | "execution_count": null, 756 | "metadata": {}, 757 | "outputs": [], 758 | "source": [ 759 | "output_folder" 760 | ] 761 | }, 762 | { 763 | "cell_type": "code", 764 | "execution_count": null, 765 | "metadata": {}, 766 | "outputs": [], 767 | "source": [ 768 | "p = opt_configs[-1]" 769 | ] 770 | }, 771 | { 772 | "cell_type": "code", 773 | "execution_count": null, 774 | "metadata": {}, 775 | "outputs": [], 776 | "source": [ 777 | "base_idx = min(len(masks)-1,28)\n", 778 | "pixel_list[:,3] = base_idx\n", 779 | "est_depth,est_alpha,_,_,flowp,flowm = jax_flow_rend(p[0],p[1],p[2],pixel_list,invF,poses,beta2/pt_cld_shape_scale,beta3)\n", 780 | "\n", 781 | "flowp = (min_size/2)*np.array(flowp)\n", 782 | "flowm = (min_size/2)*np.array(flowm)\n", 783 | "flowp[est_alpha < 0.5] = np.nan\n", 784 | "flowm[est_alpha < 0.5] = np.nan" 785 | ] 786 | }, 787 | { 788 | "cell_type": "code", 789 | "execution_count": null, 790 | "metadata": {}, 791 | "outputs": [], 792 | "source": [ 793 | "tmp_f = np.copy(fwd_flows[base_idx])\n", 794 | "tmp_f[est_alpha.reshape((PY,PX)) < 0.5] = np.nan\n", 795 | "\n", 796 | "plt.subplot(1,2,1)\n", 797 | "plt.imshow(tmp_f[:,:,0],vmin=-6,vmax=6,cmap='RdBu' )\n", 798 | "plt.axis('off')\n", 799 | "plt.colorbar()\n", 800 | "plt.subplot(1,2,2)\n", 801 | "plt.imshow(tmp_f[:,:,1],vmin=-6,vmax=6,cmap='RdBu' )\n", 802 | "plt.axis('off')\n", 803 | "\n", 804 | "plt.colorbar()\n", 805 | "\n", 806 | "plt.figure()\n", 807 | "plt.subplot(1,2,1)\n", 808 | "plt.imshow(flowp[:,0].reshape((PY,PX)),vmin=-6,vmax=6,cmap='RdBu' )\n", 809 | "plt.axis('off')\n", 810 | "plt.colorbar()\n", 811 | "plt.subplot(1,2,2)\n", 812 | "plt.imshow(flowp[:,1].reshape((PY,PX)),vmin=-6,vmax=6,cmap='RdBu' )\n", 813 | "plt.axis('off')\n", 814 | "plt.colorbar()" 815 | ] 816 | }, 817 | { 818 | "cell_type": "code", 819 | 
"execution_count": null, 820 | "metadata": {}, 821 | "outputs": [], 822 | "source": [ 823 | "tmp_b = np.copy(bwd_flows[base_idx])\n", 824 | "tmp_b[est_alpha.reshape((PY,PX)) < 0.5] = np.nan\n", 825 | "\n", 826 | "plt.subplot(1,2,1)\n", 827 | "plt.imshow(tmp_b[:,:,0],vmin=-6,vmax=6,cmap='RdBu' )\n", 828 | "plt.axis('off')\n", 829 | "plt.colorbar()\n", 830 | "plt.subplot(1,2,2)\n", 831 | "plt.imshow(tmp_b[:,:,1],vmin=-6,vmax=6,cmap='RdBu' )\n", 832 | "plt.axis('off')\n", 833 | "\n", 834 | "plt.colorbar()\n", 835 | "\n", 836 | "plt.figure()\n", 837 | "plt.subplot(1,2,1)\n", 838 | "plt.imshow(flowm[:,0].reshape((PY,PX)),vmin=-6,vmax=6,cmap='RdBu' )\n", 839 | "plt.axis('off')\n", 840 | "plt.colorbar()\n", 841 | "plt.subplot(1,2,2)\n", 842 | "plt.imshow(flowm[:,1].reshape((PY,PX)),vmin=-6,vmax=6,cmap='RdBu' )\n", 843 | "plt.axis('off')\n", 844 | "plt.colorbar()" 845 | ] 846 | }, 847 | { 848 | "cell_type": "code", 849 | "execution_count": null, 850 | "metadata": {}, 851 | "outputs": [], 852 | "source": [ 853 | "import zpfm_render\n", 854 | "render_jit2 = jax.jit(zpfm_render.render_func_idx_quattrans_flow)\n" 855 | ] 856 | }, 857 | { 858 | "cell_type": "code", 859 | "execution_count": null, 860 | "metadata": {}, 861 | "outputs": [], 862 | "source": [ 863 | "points_export = []\n", 864 | "colors_export = []\n", 865 | "colors_export_plain = []\n", 866 | "\n", 867 | "normals_export = []\n", 868 | "\n", 869 | "scaleE = 2\n", 870 | "\n", 871 | "thesh_min = 0.9\n", 872 | "\n", 873 | "for i in range(len(poses)):\n", 874 | " pixel_list[:,3] = i\n", 875 | " rot_mats = jax.vmap(fm_render.quat_to_rot)(poses[:,:4])\n", 876 | " def rot_ray_t(rayi):\n", 877 | " ray = rayi[:3] * jnp.array([invF,invF,1])\n", 878 | " pose_idx = rayi[3].astype(int)\n", 879 | " return jnp.array([ray@rot_mats[pose_idx],poses[pose_idx][4:]])\n", 880 | " camera_rays_start = jax.vmap(rot_ray_t)(pixel_list)\n", 881 | " est_depth,est_alpha,est_norm,est_w,flowp,flowm = render_jit2(final_mean, final_prec*scaleE,(scaleE**2)*final_weight_log,pixel_list,invF,poses)\n", 882 | "\n", 883 | " est_color = np.array(est_w.T @ (jnp.tanh(final_color)*0.5+0.5))**(2.2)\n", 884 | " \n", 885 | " # nneed RGBA\n", 886 | " # or images[i] # #np.round(images[i])\n", 887 | " export_c = np.round(np.clip(est_color,0,1)*255).astype(np.uint8)\n", 888 | " alpha_c = (np.ones(export_c.shape[:-1])*255).astype(np.uint8)\n", 889 | " export_c = np.hstack([export_c.reshape((-1,3)),alpha_c.reshape((-1,1))]).reshape((-1,4))\n", 890 | " \n", 891 | " export_c2 = np.round(images[i]).astype(np.uint8)\n", 892 | " export_c2 = np.hstack([export_c2.reshape((-1,3)),alpha_c.reshape((-1,1))]).reshape((-1,4))\n", 893 | " \n", 894 | " est_3d = est_depth[:,None]*camera_rays_start[:,0]+camera_rays_start[:,1] \n", 895 | " \n", 896 | " est_3d = np.array(est_3d)\n", 897 | " est_alpha = np.array(est_alpha)\n", 898 | " \n", 899 | " export_cond = (est_alpha > thesh_min) & (est_w.max(axis=0) > thesh_min)\n", 900 | "\n", 901 | " points_export.append(est_3d[export_cond])\n", 902 | " colors_export.append(export_c2[export_cond])\n", 903 | " normals_export.append(est_norm[export_cond])\n", 904 | " colors_export_plain.append(export_c[export_cond])\n", 905 | " \n", 906 | " \n", 907 | "points_export = np.concatenate(points_export)\n", 908 | "colors_export = np.concatenate(colors_export)\n", 909 | "colors_export_plain = np.concatenate(colors_export_plain)\n", 910 | "normals_export = np.concatenate(normals_export)" 911 | ] 912 | }, 913 | { 914 | "cell_type": "code", 915 | "execution_count": null, 916 | 
"metadata": {}, 917 | "outputs": [], 918 | "source": [ 919 | "est_color.max()" 920 | ] 921 | }, 922 | { 923 | "cell_type": "code", 924 | "execution_count": null, 925 | "metadata": {}, 926 | "outputs": [], 927 | "source": [ 928 | "import open3d as o3d\n", 929 | "o3d_cld = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(points_export))\n", 930 | "o3d_cld.colors = o3d.utility.Vector3dVector(colors_export[:,:3].astype(float)/255.0)\n", 931 | "o3d_cld.normals = o3d.utility.Vector3dVector(normals_export)\n", 932 | "o3d.io.write_point_cloud(\"{}.ply\".format(output_folder), o3d_cld)\n", 933 | "\n", 934 | "o3d_cld.colors = o3d.utility.Vector3dVector(colors_export_plain[:,:3].astype(float)/255.0)\n", 935 | "o3d.io.write_point_cloud(\"{}_plain.ply\".format(output_folder), o3d_cld)" 936 | ] 937 | }, 938 | { 939 | "cell_type": "code", 940 | "execution_count": null, 941 | "metadata": {}, 942 | "outputs": [], 943 | "source": [ 944 | "output_folder" 945 | ] 946 | }, 947 | { 948 | "cell_type": "code", 949 | "execution_count": null, 950 | "metadata": {}, 951 | "outputs": [], 952 | "source": [ 953 | "result_depths2 = []\n", 954 | "result_alphas2 = []\n", 955 | "results_colors2 = []\n", 956 | "scaleE=1\n", 957 | "\n", 958 | "for i in range(len(poses)):\n", 959 | " pixel_list[:,3] = i\n", 960 | " est_depth,est_alpha,est_norm,est_w,flowp,flowm = render_jit2(final_mean, final_prec*scaleE,(scaleE**2)*final_weight_log,pixel_list,invF,poses)\n", 961 | " est_color = np.array(w.T @ (jnp.tanh(final_color)*0.5+0.5))**(2.2)\n", 962 | "\n", 963 | " est_depth = np.array(est_depth)\n", 964 | " est_alpha = np.array(est_alpha)\n", 965 | " est_depth[est_alpha < thesh_min] = np.nan\n", 966 | " est_color[est_alpha < thesh_min] = np.nan\n", 967 | "\n", 968 | " result_depths2.append(est_depth.reshape((PY,PX)))\n", 969 | " result_alphas2.append(est_alpha.reshape((PY,PX)))\n", 970 | " results_colors2.append(est_color.reshape((PY,PX,3)))\n", 971 | " break" 972 | ] 973 | }, 974 | { 975 | "cell_type": "code", 976 | "execution_count": null, 977 | "metadata": {}, 978 | "outputs": [], 979 | "source": [ 980 | "plt.imshow(est_w.T[:,6].reshape((PY,PX)))" 981 | ] 982 | }, 983 | { 984 | "cell_type": "code", 985 | "execution_count": null, 986 | "metadata": {}, 987 | "outputs": [], 988 | "source": [ 989 | "plt.plot(est_w[:,est_w.shape[1]//2+100])\n", 990 | "plt.plot(est_w[:,0])\n" 991 | ] 992 | }, 993 | { 994 | "cell_type": "code", 995 | "execution_count": null, 996 | "metadata": {}, 997 | "outputs": [], 998 | "source": [ 999 | "plt.imshow(result_alphas2[-1])" 1000 | ] 1001 | }, 1002 | { 1003 | "cell_type": "code", 1004 | "execution_count": null, 1005 | "metadata": {}, 1006 | "outputs": [], 1007 | "source": [ 1008 | "plt.imshow(result_depths2[-1])\n", 1009 | "plt.colorbar()\n", 1010 | "plt.figure()\n", 1011 | "plt.imshow(result_depths[-1])\n", 1012 | "plt.colorbar()" 1013 | ] 1014 | }, 1015 | { 1016 | "cell_type": "code", 1017 | "execution_count": null, 1018 | "metadata": {}, 1019 | "outputs": [], 1020 | "source": [ 1021 | "plt.imshow(results_colors2[-1])" 1022 | ] 1023 | }, 1024 | { 1025 | "cell_type": "code", 1026 | "execution_count": null, 1027 | "metadata": {}, 1028 | "outputs": [], 1029 | "source": [ 1030 | "import transforms3d\n", 1031 | "Rr = transforms3d.quaternions.quat2mat(poses[0][:4])\n", 1032 | "est_norm2 = -np.array(est_norm) @ Rr\n", 1033 | "est_norm2[est_alpha < 0.25] = np.nan\n", 1034 | "plt.imshow(est_norm2.reshape((image_size[0],image_size[1],3))*0.5+0.5)" 1035 | ] 1036 | }, 1037 | { 1038 | "cell_type": "code", 1039 | 
"execution_count": null, 1040 | "metadata": {}, 1041 | "outputs": [], 1042 | "source": [ 1043 | "from util import compute_normals\n", 1044 | "est_norms3 = compute_normals(camera_rays_start[:,0,:],est_depth.reshape((PY,PX)))\n", 1045 | "plt.imshow(est_norms3.reshape((image_size[0],image_size[1],3))*0.5+0.5)" 1046 | ] 1047 | } 1048 | ], 1049 | "metadata": { 1050 | "language_info": { 1051 | "name": "python" 1052 | } 1053 | }, 1054 | "nbformat": 4, 1055 | "nbformat_minor": 5 1056 | } 1057 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | def image_grid( 6 | images, 7 | rows=None, 8 | cols=None, 9 | fill: bool = True, 10 | show_axes: bool = False, 11 | rgb: bool = True, 12 | vmin = None, 13 | vmax = None, 14 | cmap = None, 15 | interp = 'nearest', 16 | ): 17 | """ 18 | A util function for plotting a grid of images. 19 | 20 | Args: 21 | images: (N, H, W, 4) array of RGBA images 22 | rows: number of rows in the grid 23 | cols: number of columns in the grid 24 | fill: boolean indicating if the space between images should be filled 25 | show_axes: boolean indicating if the axes of the plots should be visible 26 | rgb: boolean, If True, only RGB channels are plotted. 27 | If False, only the alpha channel is plotted. 28 | 29 | Returns: 30 | None 31 | """ 32 | if (rows is None) != (cols is None): 33 | raise ValueError("Specify either both rows and cols or neither.") 34 | 35 | if rows is None: 36 | rows = len(images) 37 | cols = 1 38 | 39 | gridspec_kw = {"wspace": 0.0, "hspace": 0.0} if fill else {} 40 | fig, axarr = plt.subplots(rows, cols, gridspec_kw=gridspec_kw) 41 | bleed = 0 42 | fig.subplots_adjust(left=bleed, bottom=bleed, right=(1 - bleed), top=(1 - bleed)) 43 | 44 | for ax, im in zip(axarr.ravel(), images): 45 | if rgb: 46 | # only render RGB channels 47 | ax.imshow(im[..., :3],interpolation=interp) 48 | else: 49 | # only render Alpha channel 50 | ax.imshow(im[...],vmin=vmin,vmax=vmax,cmap=cmap,interpolation=interp) 51 | if not show_axes: 52 | ax.set_axis_off() 53 | plt.tight_layout() 54 | 55 | 56 | from scipy.special import erf 57 | import numpy as np 58 | class DegradeLR: 59 | def __init__(self, init_lr, p_thresh=5e-2, window=10, p_window=5, slope_less=0, max_drops = 4, print_debug=True): 60 | assert( (init_lr >0) and (p_thresh > 0) and (p_thresh < 1)) 61 | self.init_lr = init_lr 62 | self.p_thresh = p_thresh 63 | self.window = int(round(window)) 64 | if self.window < 3: 65 | print('window too small! clipped to 3') 66 | self.window = 3 67 | self.slope_less = slope_less 68 | self.p_window = int(round(p_window)) 69 | if self.p_window < 1: 70 | print('p_window too small! 
clipped to 1') 71 | self.p_window = 1 72 | self.train_val = [] 73 | self.prior_p = [] 74 | self.n_drops = 0 75 | self.max_drops = max_drops 76 | self.last_drop_len = self.window+1 77 | self.step_func = lambda x: self.init_lr/(10** self.n_drops) 78 | self.print_debug = print_debug 79 | self.counter = 0 80 | def add(self,error): 81 | self.counter += 1 82 | self.train_val.append(error) 83 | len_of_opt = len(self.train_val) 84 | 85 | if len_of_opt >= self.window + self.p_window: 86 | yo = np.array(self.train_val[-self.window:]) 87 | yo = yo/yo.mean() 88 | xo = np.arange(self.window) 89 | xv = np.vstack([xo,np.ones_like(xo)]).T 90 | w = np.linalg.pinv(xv.T @ xv) @ xv.T @ yo 91 | yh = xo*w[0] + w[1] 92 | var =((yh-yo)**2).sum() / (self.window-2) 93 | var_slope = (12*var)/(self.window**3) 94 | ps = 0.5*(1+ erf((self.slope_less-w[0])/(np.sqrt(2*var_slope)))) 95 | self.prior_p.append(ps) 96 | 97 | p_eval = np.array(self.prior_p[-self.p_window:]) 98 | if (p_eval < self.p_thresh).all(): 99 | self.n_drops += 1 100 | if self.n_drops > self.max_drops: 101 | if self.print_debug: 102 | print('early exit due to max drops') 103 | return True 104 | if self.print_debug: 105 | print('dropping LR to {:.2e} after {} steps'.format(self.step_func(0),self.counter-1)) 106 | min_len = self.window+self.p_window 107 | if self.last_drop_len == min_len and len_of_opt == min_len: 108 | if self.print_debug: 109 | print('early exit due to no progress') 110 | return True 111 | self.last_drop_len = len(self.train_val) 112 | self.train_val = [] 113 | return False 114 | 115 | import jax.numpy as jnp 116 | def compute_normals(camera_rays, depth_py_px, eps=1e-20): 117 | PY,PX = depth_py_px.shape 118 | nan_depth = jnp.nan_to_num(depth_py_px.ravel()) 119 | dpt = jnp.array( camera_rays.reshape((-1,3)) * nan_depth[:,None] ) 120 | dpt = dpt.reshape((PY,PX,3)) 121 | ydiff = dpt - jnp.roll(dpt,1,0) 122 | xdiff = dpt - jnp.roll(dpt,1,1) 123 | ddiff = jnp.cross(xdiff.reshape((-1,3)),ydiff.reshape((-1,3)),) 124 | nan_ddiff = jnp.nan_to_num(ddiff,nan=1e-6) 125 | norms = nan_ddiff/(eps+jnp.linalg.norm(nan_ddiff,axis=1,keepdims=True)) 126 | 127 | return norms
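128 |  129 | # A minimal usage sketch for compute_normals (an illustrative addition, not part of the original code; the grid sizes and ray layout below are made up): recover per-pixel normals for a flat synthetic depth plane. 130 | # 131 | # import numpy as np 132 | # PY, PX = 8, 8 133 | # xs, ys = np.meshgrid(np.linspace(-1,1,PX), np.linspace(-1,1,PY)) 134 | # rays = np.stack([xs.ravel(), ys.ravel(), np.ones(PY*PX)], axis=1) # one ray per pixel 135 | # depth = np.ones((PY, PX)) # flat plane at z=1 136 | # normals = compute_normals(rays, depth) # -> (PY*PX, 3), ~(0,0,1) away from the wrap-around border 137 |  -------------------------------------------------------------------------------- /util_load.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pytorch3d.transforms import so3_log_map, matrix_to_quaternion 3 | 4 | # wraps function call to save own history 5 | # useful if using scipy.minimize but want to exit optimize whenever 6 | class ExceptionWrap: 7 | def __init__(self,func): 8 | self.func = func 9 | self.results = [] 10 | def __call__(self,*argv): 11 | f = self.func(*argv) 12 | self.results.append((f,tuple(argv[0]))) 13 | return f 14 | 15 | def load_mesh_with_pyt3d(shape_file,torch_device): 16 | from pytorch3d.io import load_obj, load_ply 17 | import pytorch3d.io as py3dIO 18 | from pytorch3d.structures import Meshes 19 | from iopath.common.file_io import PathManager 20 | import torch 21 | from pytorch3d.renderer import TexturesVertex 22 | 23 | if shape_file[-3:] == 'obj': 24 | #print('got obj') 25 | verts, faces_idx, _ = load_obj(shape_file) 26 | faces = faces_idx.verts_idx 27 | elif shape_file[-3:] == 'ply': 28 | #print('got ply') 29 | verts, faces = load_ply(shape_file) 30 | elif shape_file[-3:] == 'off': 31 | mesh2 = py3dIO.off_io.MeshOffFormat().read(shape_file,include_textures=False, 32 | device=torch_device,path_manager=PathManager()) 33 | verts = mesh2.verts_list()[0] 34 | faces = mesh2.faces_list()[0] 35 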
--------------------------------------------------------------------------------
/util_load.py:
--------------------------------------------------------------------------------
import numpy as np
from pytorch3d.transforms import so3_log_map, matrix_to_quaternion

# wraps a function call to record its own evaluation history;
# useful with scipy.optimize.minimize when you want to abort the optimizer at
# any time (e.g. via KeyboardInterrupt) yet keep every evaluated point
class ExceptionWrap:
    def __init__(self,func):
        self.func = func
        self.results = []
    def __call__(self,*argv):
        f = self.func(*argv)
        self.results.append((f,tuple(argv[0])))
        return f

def load_mesh_with_pyt3d(shape_file,torch_device):
    from pytorch3d.io import load_obj, load_ply
    import pytorch3d.io as py3dIO
    from pytorch3d.structures import Meshes
    from iopath.common.file_io import PathManager
    import torch
    from pytorch3d.renderer import TexturesVertex

    if shape_file[-3:] == 'obj':
        #print('got obj')
        verts, faces_idx, _ = load_obj(shape_file)
        faces = faces_idx.verts_idx
    elif shape_file[-3:] == 'ply':
        #print('got ply')
        verts, faces = load_ply(shape_file)
    elif shape_file[-3:] == 'off':
        mesh2 = py3dIO.off_io.MeshOffFormat().read(shape_file,include_textures=False,
                                                   device=torch_device,path_manager=PathManager())
        verts = mesh2.verts_list()[0]
        faces = mesh2.faces_list()[0]
    else:
        raise Exception("Unsupported mesh format")

    #verts = verts-verts.mean(0)
    #shape_scale = float(verts.std(0).mean())*3
    #verts = verts/shape_scale

    # Initialize each vertex to be white in color.
    verts_rgb = torch.ones_like(verts)[None]  # (1, V, 3)
    textures = TexturesVertex(verts_features=verts_rgb.to(torch_device))

    # Create a Meshes object. Here we have only one mesh in the batch.
    mesh = Meshes(
        verts=[verts.to(torch_device)],
        faces=[faces.to(torch_device)],
        textures=textures
    )
    shape_scale = float(verts.std(0).mean())*3
    center = np.array(mesh.verts_list()[0].mean(0).detach().cpu())

    return mesh,shape_scale,center

def resize_2d_nonan(array,factor):
    """
    Downsample a 2D array by an integer factor while tolerating NaNs.
    initial author: damo_ma
    """
    xsize, ysize = array.shape

    if isinstance(factor,int):
        factor_x = factor
        factor_y = factor
    elif isinstance(factor,tuple):
        factor_x, factor_y = factor[0], factor[1]
    else:
        raise ValueError('Factor must be a tuple (x,y) or an integer')

    if not (xsize % factor_x == 0 and ysize % factor_y == 0):
        raise ValueError('Factors must be integer divisors of the array shape')

    new_xsize, new_ysize = xsize//factor_x, ysize//factor_y

    new_array = np.empty([new_xsize, new_ysize])
    new_array[:] = np.nan # this saves us an assignment in the loop below

    # submatrix indexes: the averaging box on the original matrix
    subrow, subcol = np.indices((factor_x, factor_y))

    # new matrix indexes
    row, col = np.indices((new_xsize, new_ysize))

    for i, j, ind in zip(row.reshape(-1), col.reshape(-1), range(row.size)):
        # define the small sub_matrix as a view of the input matrix subset
        sub_matrix = array[subrow+i*factor_x,subcol+j*factor_y]
        # modified from any(a) and all(a) to a.any() and a.all()
        # see https://stackoverflow.com/a/10063039/1435167
        # keep the cell only if (with a random tie-break) fewer than half of
        # its entries are NaN; otherwise it stays NaN from the initialization
        if (np.isnan(sub_matrix)).sum() < (factor_x*factor_y)/2.0 + (np.random.rand()-0.5):
            if (np.isnan(sub_matrix)).any(): # some NaNs: average the valid entries
                (new_array.reshape(-1))[ind] = np.nanmean(sub_matrix)
            else: # no NaNs at all: plain mean
                (new_array.reshape(-1))[ind] = np.mean(sub_matrix)

    return new_array

# convert a PyTorch3D FoV camera into per-pixel rays plus a world pose
# (axis-angle rotation and camera center)
def convert_pyt3dcamera(cam, image_size):
    height, width = image_size
    cx = (width-1)/2
    cy = (height-1)/2
    f = (height/np.tan((np.pi/180)*float(cam.fov[0])/2))*0.5
    K = np.array([[f, 0, cx],[0, f, cy],[0, 0, 1]])
    pixel_list = (np.array(np.meshgrid(width-np.arange(width)-1,height-np.arange(height)-1,[0]))[:,:,:,0]).reshape((3,-1)).T

    camera_rays = (pixel_list - K[:,2])/np.diag(K)
    camera_rays[:,-1] = 1
    return np.array(camera_rays), np.array(so3_log_map(cam.R.cpu())[0]), np.array(-cam.R.cpu()[0]@cam.T.cpu()[0])
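# Usage sketch (hypothetical addition, not in the original file): wrap an
# objective so that an interrupted scipy.optimize.minimize run still leaves
# every evaluated (cost, parameters) pair in wrap.results. scipy's rosen is
# just a stand-in objective.
if __name__ == '__main__':
    from scipy.optimize import minimize, rosen
    wrap = ExceptionWrap(rosen)
    try:
        minimize(wrap, np.zeros(5), method='Nelder-Mead')
    except KeyboardInterrupt:
        pass  # abort whenever; the history survives in wrap.results
    best_cost, best_x = min(wrap.results, key=lambda r: r[0])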
--------------------------------------------------------------------------------
/util_render.py:
--------------------------------------------------------------------------------
import jax.numpy as jnp

# numerically stable exponential for softmax purposes
def jax_stable_exp(z,s=1,axis=0):
    z = s*z
    z = z - z.max(axis)
    z = jnp.exp(z)
    return z

# numerically stable softmax
def local_softmax(z,s=1,axis=0):
    z = jax_stable_exp(z,s,axis)
    return z/z.sum(keepdims=True,axis=axis)

# converts modified Rodrigues parameters (n * tan(theta/4)) to a rotation matrix
def mrp_to_rot(vec):
    vec_mag = vec @ vec
    vec_mag_num = (1-vec_mag)
    vec_mag_den = ((1+vec_mag)**2)
    x,y,z = vec
    K = jnp.array(
        [[ 0, -z,  y ],
         [ z,  0, -x ],
         [-y,  x,  0 ]])
    R1 = jnp.eye(3) - (((4*vec_mag_num)/vec_mag_den) * K) + ((8/vec_mag_den) * (K @ K))
    R2 = jnp.eye(3)

    # fall back to the identity near the singularity at theta = 0
    Rest = jnp.where(vec_mag > 1e-12,R1,R2)
    return Rest

# converts an axis-angle vector (n * theta) to a rotation matrix
def axangle_to_rot(axangl):
    scale = jnp.sqrt(axangl @ axangl)
    vec = axangl/scale
    x,y,z = vec
    K = jnp.array(
        [[ 0, -z,  y ],
         [ z,  0, -x ],
         [-y,  x,  0 ]])
    ctheta = jnp.cos(scale)
    stheta = jnp.sin(scale)
    R1 = jnp.eye(3) + stheta*K + (1-ctheta)*(K @ K)
    R2 = jnp.eye(3)
    Rest = jnp.where(scale > 1e-12,R1.T,R2)
    return Rest

# converts a quaternion [w, x, y, z] to a rotation matrix
def quat_to_rot(q):
    w, x, y, z = q
    Nq = w*w + x*x + y*y + z*z

    s = 2.0/Nq
    X = x*s
    Y = y*s
    Z = z*s
    wX = w*X; wY = w*Y; wZ = w*Z
    xX = x*X; xY = x*Y; xZ = x*Z
    yY = y*Y; yZ = y*Z; zZ = z*Z
    R1 = jnp.array(
        [[ 1.0-(yY+zZ), xY-wZ,       xZ+wY       ],
         [ xY+wZ,       1.0-(xX+zZ), yZ-wX       ],
         [ xZ-wY,       yZ+wX,       1.0-(xX+yY) ]])
    R2 = jnp.eye(3)
    return jnp.where(Nq > 1e-12,R1,R2)
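# Self-check sketch (hypothetical addition, not in the original file): the same
# rotation written as an axis-angle vector (n * theta) and as modified
# Rodrigues parameters (n * tan(theta/4)) should give matching matrices, since
# both conversions above share the same handedness convention.
if __name__ == '__main__':
    theta = 0.3
    axis = jnp.array([0.0, 0.0, 1.0])
    R_ax = axangle_to_rot(theta * axis)
    R_mrp = mrp_to_rot(jnp.tan(theta/4) * axis)
    print(jnp.allclose(R_ax, R_mrp, atol=1e-6))  # expected: True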
--------------------------------------------------------------------------------
/utils_opt.py:
--------------------------------------------------------------------------------
import optax
from util import DegradeLR
import fm_render

import jax
import jax.numpy as jnp

def shape_objective(params,pixel_list,invF,poses,beta2,beta3,beta4,reg_amt,true_alpha):
    CLIP_ALPHA = 1e-6
    flat_alpha = true_alpha.ravel()
    mean,prec,weight_log = params
    render_res = fm_render.render_func_idx_quattrans_flow(mean,prec,weight_log,pixel_list,invF,poses,beta2,beta3,beta4)

    # binary cross-entropy between rendered and reference alpha masks
    est_alpha = render_res[2]
    est_alpha = jnp.clip(est_alpha,CLIP_ALPHA,1-CLIP_ALPHA)
    mask_loss = -((flat_alpha * jnp.log(est_alpha)) + (1-flat_alpha)*jnp.log(1-est_alpha))

    # depth-smoothness regularizer, applied only inside the mask
    est_depth = render_res[0]
    xdiff = (jnp.diff(est_depth.reshape(true_alpha.shape),axis=0,append=0)**2).ravel()
    ydiff = (jnp.diff(est_depth.reshape(true_alpha.shape),axis=1,append=0)**2).ravel()

    reg = jnp.where(flat_alpha > 0.5,xdiff+ydiff,0)
    return mask_loss.mean() + reg_amt*reg.mean()

def pose_objective(poses,mean,prec,weight_log,pixel_list,invF,beta2,beta3,beta4,true_alpha,true_fwd,true_bwd,flow_mul):
    CLIP_ALPHA = 1e-6
    render_res = fm_render.render_func_idx_quattrans_flow(mean,prec,weight_log,pixel_list,invF,poses,beta2,beta3,beta4)

    est_alpha = render_res[2]
    est_alpha = jnp.clip(est_alpha,CLIP_ALPHA,1-CLIP_ALPHA)
    mask_loss = -((true_alpha * jnp.log(est_alpha)) + (1-true_alpha)*jnp.log(1-est_alpha))

    # L1 error of forward/backward optical flow, masked by the reference alpha
    pad_alpha = true_alpha[:,None]
    flow1 = jnp.abs(pad_alpha*true_fwd.reshape((-1,2)) - pad_alpha*render_res[5])
    flow2 = jnp.abs(pad_alpha*true_bwd.reshape((-1,2)) - pad_alpha*render_res[6])
    return mask_loss.mean() + flow_mul*(flow1.mean() + flow2.mean())

def shape_pose_objective(params,pixel_list,beta2,beta3,beta4,true_alphas):
    CLIP_ALPHA = 1e-6
    mean,prec,weight_log,invF,poses = params

    def eval_frame(pose,true_alpha):
        flat_alpha = true_alpha.ravel()
        render_res = fm_render.render_func_idx_quattrans_flow(mean,prec,weight_log,pixel_list,invF,poses,beta2,beta3,beta4)

        est_alpha = render_res[2]
        est_alpha = jnp.clip(est_alpha,CLIP_ALPHA,1-CLIP_ALPHA)
        mask_loss = -((flat_alpha * jnp.log(est_alpha)) + (1-flat_alpha)*jnp.log(1-est_alpha))
        return mask_loss.mean()
    per_frames = jax.vmap(eval_frame)(poses,true_alphas)
    return per_frames.mean()

def reconstruct_shape(vg_objective,init_shape,degrade_settings,render_settings,pose,Niter,reg_amt,reference):

    # babysit learning rates
    adjust_lr = DegradeLR(*degrade_settings)
    beta2,beta3,beta4,pixel_list,invF = render_settings

    optimizer = optax.adam(adjust_lr.step_func)
    opt_state = optimizer.init(init_shape)

    params = init_shape

    losses = []
    for i in range(Niter):
        val,g = vg_objective(params,pixel_list,invF,pose,beta2,beta3,beta4,reg_amt,reference)
        val = float(val)
        losses.append(val)
        updates, opt_state = optimizer.update(g, opt_state, params)
        params = optax.apply_updates(params, updates)
        if adjust_lr.add(val):
            break
    return params,losses

def obtain_pose(vg_objective,shape,degrade_settings,render_settings,poses,Niter,reference,fwdflow,bwdflow,flow_amt):

    # babysit learning rates
    adjust_lr = DegradeLR(*degrade_settings)
    beta2,beta3,beta4,pixel_list,invF = render_settings
    mean, prec, weight_log = shape

    optimizer = optax.sgd(adjust_lr.step_func,0.95)
    opt_state = optimizer.init(poses)
    params = poses
    losses = []
    for i in range(Niter):
        val,g = vg_objective(params,mean,prec,weight_log,pixel_list,invF,beta2,beta3,beta4,reference,fwdflow,bwdflow,flow_amt)
        val = float(val)
        losses.append(val)
        updates, opt_state = optimizer.update(g, opt_state, params)
        params = optax.apply_updates(params, updates)

        if adjust_lr.add(val):
            break
    return params,losses

def refine_shapepose(vg_objective,init_shape,degrade_settings,render_settings,poses,Niter,references):
    # babysit learning rates
    adjust_lr = DegradeLR(*degrade_settings)
    beta2,beta3,beta4,pixel_list,invF = render_settings

    optimizer = optax.adam(adjust_lr.step_func,0.9)
    init_state = list(init_shape) + [invF] + [poses]
    opt_state = optimizer.init(init_state)

    losses = []
    loop = range(Niter)

    params = init_state

    for i in loop:
        val,g = vg_objective(params,pixel_list,beta2,beta3,beta4,references[:len(poses)])

        val = float(val)
        losses.append(val)
        updates, opt_state = optimizer.update(g, opt_state, params)
        params = optax.apply_updates(params, updates)

        if adjust_lr.add(val):
            break
    return params,losses

def readFlow(fn):
    """Read a .flo optical flow file in Middlebury format."""
    # Code adapted from:
    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy

    # WARNING: this will work on little-endian architectures (eg Intel x86) only!
    import numpy as np
    with open(fn, 'rb') as f:
        magic = np.fromfile(f, np.float32, count=1)
        if 202021.25 != magic:
            print('Magic number incorrect. Invalid .flo file')
            return None
        else:
            w = np.fromfile(f, np.int32, count=1)
            h = np.fromfile(f, np.int32, count=1)
            data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
            # Reshape data into a 3D array (rows, columns, bands);
            # the original code used (w, h, 2)
            return np.resize(data, (int(h), int(w), 2))
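# Wiring sketch (hypothetical addition, not in the original file): the loops
# above expect vg_objective to return a (value, gradient) pair, i.e. the result
# of jax.value_and_grad applied to the matching objective; the jit is optional
# but avoids retracing on every iteration.
if __name__ == '__main__':
    vg_shape = jax.jit(jax.value_and_grad(shape_objective))
    vg_pose = jax.jit(jax.value_and_grad(pose_objective))
    vg_joint = jax.jit(jax.value_and_grad(shape_pose_objective))
    # e.g. shape, losses = reconstruct_shape(vg_shape, init_shape,
    #          degrade_settings, render_settings, pose, Niter, reg_amt, ref_alpha)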
--------------------------------------------------------------------------------
/zpfm_render.py:
--------------------------------------------------------------------------------
import jax
import jax.numpy as jnp

# this file implements most of https://arxiv.org/abs/2308.14737
# it closely follows fm_render.py but uses a different blending function;
# the other file has more detailed comments.

# contains various rotation conversions
from util_render import *

def render_func_rays(means, prec_full, weights_log, camera_starts_rays):
    prec = jnp.triu(prec_full)
    #weights = jnp.exp(weights_log)
    #weights = weights/weights.sum()

    def perf_idx(prcI,w,meansI):
        prc = prcI.T
        #prc = jnp.diag(jnp.sign(jnp.diag(prc))) @ prc
        div = jnp.prod(jnp.diag(jnp.abs(prc))) + 1e-20

        def perf_ray(r_t):
            r = r_t[0]
            t = r_t[1]
            p = meansI - t

            projp = prc @ p
            vsv = ((prc @ r)**2).sum()
            psv = (projp * (prc @ r)).sum()
            projp2 = prc.T @ projp

            # linear solve: ray depth of maximum density for this Gaussian
            res = psv/vsv

            v = r * res - p

            d0 = ((prc @ v)**2).sum()# + 3*jnp.log(jnp.pi*2)
            d2 = -0.5*d0 + w
            #d3 = d2 + jnp.log(div) #+ 3*jnp.log(res)
            norm_est = projp2/jnp.linalg.norm(projp2)
            norm_est = jnp.where(r@norm_est < 0,norm_est,-norm_est)
            return res,d2,norm_est
        res,d2,projp = jax.vmap(perf_ray)(camera_starts_rays) # jit perf
        return res,d2,projp

    zs,stds,projp = jax.vmap(perf_idx)(prec,weights_log,means) # jit perf

    # compositing
    sample_density = jnp.exp(stds) # simpler but splotchier
    def sort_w(z,densities):
        # get the order of the z values
        idxs = jnp.argsort(z,axis=0)
        # sample the densities in z-order
        order_density = densities[idxs]
        # integrate
        order_summed_density = jnp.cumsum(order_density)
        # get "prior sum"
        order_prior_density = order_summed_density - order_density
        # compute expected alpha as final ray weight
        ea = 1 - jnp.exp(-order_summed_density[-1])
        # resample the densities out of z-order, into original order
        prior_density = jnp.zeros_like(densities)
        prior_density = prior_density.at[idxs].set(order_prior_density)
        # compute the transmission of current and prior, Max/NeRF style
        transmit = jnp.exp(-prior_density)
        wout = transmit * (1-jnp.exp(-densities))
        # return weight and total expected alpha
        return wout, ea
    w,est_alpha = jax.vmap(sort_w)(zs.T,sample_density.T)
    w = w.T

    # normalize weights for the weighted averages below, guarding empty rays
    wgt = w.sum(0)
    div = jnp.where(wgt==0,1,wgt)
    w_n = w/div

    init_t = (w_n*jnp.nan_to_num(zs)).sum(0)
    est_norm = (projp * w_n[:,:,None]).sum(axis=0)
    est_norm = est_norm/jnp.linalg.norm(est_norm,axis=1,keepdims=True)

    return init_t,est_alpha,est_norm,w
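# Sanity-check sketch (hypothetical addition, not in the original file): the
# compositing rule in sort_w gives each blob the weight
# exp(-prior_density) * (1 - exp(-own_density)), so along a ray the weights
# telescope to the expected alpha 1 - exp(-total_density):
#
#   d = jnp.array([0.5, 1.0, 0.25])     # toy densities, already depth-sorted
#   prior = jnp.cumsum(d) - d           # density in front of each blob
#   w_toy = jnp.exp(-prior) * (1 - jnp.exp(-d))
#   jnp.allclose(w_toy.sum(), 1 - jnp.exp(-d.sum()))   # True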
# axis-angle rotations n * theta
def render_func_axangle(means, prec_full, weights_log, camera_rays, axangl, t):
    Rest = axangle_to_rot(axangl)
    camera_rays = camera_rays @ Rest
    trans = jnp.tile(t[None],(camera_rays.shape[0],1))

    camera_starts_rays = jnp.stack([camera_rays,trans],1)
    return render_func_rays(means, prec_full, weights_log, camera_starts_rays)

# modified Rodrigues parameters n * tan(theta/4)
def render_func_mrp(means, prec_full, weights_log, camera_rays, mrp, t):
    Rest = mrp_to_rot(mrp)
    camera_rays = camera_rays @ Rest
    trans = jnp.tile(t[None],(camera_rays.shape[0],1))

    camera_starts_rays = jnp.stack([camera_rays,trans],1)
    return render_func_rays(means, prec_full, weights_log, camera_starts_rays)

# quaternions [cos(theta/2), sin(theta/2) * n]
def render_func_quat(means, prec_full, weights_log, camera_rays, quat, t):
    Rest = quat_to_rot(quat)
    camera_rays = camera_rays @ Rest
    trans = jnp.tile(t[None],(camera_rays.shape[0],1))

    camera_starts_rays = jnp.stack([camera_rays,trans],1)
    return render_func_rays(means, prec_full, weights_log, camera_starts_rays)

def render_func_quat_cam(means, prec_full, weights_log, pixel_list, aspect, invF, quat, t):
    camera_rays = (pixel_list - jnp.array([0.5,0.5,0]))*jnp.array([invF,aspect*invF,1])

    Rest = quat_to_rot(quat)
    camera_rays = camera_rays @ Rest
    trans = jnp.tile(t[None],(camera_rays.shape[0],1))

    camera_starts_rays = jnp.stack([camera_rays,trans],1)
    return render_func_rays(means, prec_full, weights_log, camera_starts_rays)

def render_func_idx_quattrans(means, prec_full, weights_log, pixel_posei, invF, poses):
    rot_mats = jax.vmap(quat_to_rot)(poses[:,:4])
    def rot_ray_t(rayi):
        ray = rayi[:3] * jnp.array([invF,invF,1])
        pose_idx = rayi[3].astype(int)
        return jnp.array([ray@rot_mats[pose_idx],poses[pose_idx][4:]])
    camera_rays_start = jax.vmap(rot_ray_t)(pixel_posei)
    return render_func_rays(means, prec_full, weights_log, camera_rays_start)


def render_func_idx_quattrans_flow(means, prec_full, weights_log, pixel_posei, invF, poses):
    rot_mats = jax.vmap(quat_to_rot)(poses[:,:4])
    def rot_ray_t(rayi):
        ray = rayi[:3] * jnp.array([invF,invF,1])
        pose_idx = rayi[3].astype(int)
        return jnp.array([ray@rot_mats[pose_idx],poses[pose_idx][4:]])
    camera_rays_start = jax.vmap(rot_ray_t)(pixel_posei)

    est_depth,est_alpha,est_norm,est_w = render_func_rays(means, prec_full, weights_log, camera_rays_start)

    # screen-space forward/backward flow: back-project each pixel at its
    # rendered depth, then reproject into the next and previous frames
    def flow_ray_i(rayi,depth):
        pose_idx = rayi[3].astype(int)
        pose_idxp1 = jax.lax.min(pose_idx+1,poses.shape[0]-1)
        pose_idxm1 = jax.lax.max(pose_idx-1,0)

        R1 = rot_mats[pose_idx]
        t1 = poses[pose_idx,4:]

        Rp1 = rot_mats[pose_idxp1]
        tp1 = poses[pose_idxp1,4:]

        ray = rayi[:3] * jnp.array([invF,invF,1])
        pt_cldc1 = ray * depth
        world_p = pt_cldc1 @ R1 + t1

        pt_cldc2 = (world_p - tp1) @ Rp1.T
        coord1 = pt_cldc2[:2]/(pt_cldc2[2]*invF)
        px_coordp = -(coord1 - rayi[:2])

        Rm1 = rot_mats[pose_idxm1]
        tm1 = poses[pose_idxm1,4:]

        pt_cldc3 = (world_p - tm1) @ Rm1.T
        coord2 = pt_cldc3[:2]/(pt_cldc3[2]*invF)
        px_coordm = -(coord2 - rayi[:2])

        return px_coordp,px_coordm

    flowp,flowm = jax.vmap(flow_ray_i)(pixel_posei,jnp.where(est_depth!=0,jnp.maximum(1e-7,est_depth),1e-7))

    return est_depth,est_alpha,est_norm,est_w,flowp,flowm
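# Usage sketch (hypothetical addition, not in the original file): rendering a
# single isotropic Gaussian through render_func_quat with an identity pose.
# Conventions: means (N,3); prec_full (N,3,3) upper-triangular precision
# factors; weights_log (N,); quat in [w, x, y, z] order.
#
#   means = jnp.array([[0.0, 0.0, 5.0]])        # one blob 5 units ahead
#   prec = jnp.eye(3)[None]                     # unit precision
#   wlog = jnp.zeros(1)
#   rays = jnp.array([[0.0, 0.0, 1.0],
#                     [0.05, 0.0, 1.0]])        # two pixel rays
#   quat = jnp.array([1.0, 0.0, 0.0, 0.0])      # identity rotation
#   t = jnp.zeros(3)
#   depth, alpha, normals, w = render_func_quat(means, prec, wlog, rays, quat, t)
#   # depth[0] ~ 5.0 for the central ray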
def log_likelihood(params, points):
    means, prec_full, weights_log = params
    prec = jnp.triu(prec_full)
    weights = jnp.exp(weights_log)
    weights = weights/weights.sum()

    def perf_idx(prcI,w,meansI):
        prc = prcI.T
        div = jnp.prod(jnp.diag(jnp.abs(prc)))

        def perf_ray(pt):
            p = meansI - pt

            pteval = ((prc @ p)**2).sum()

            d0 = pteval + 3*jnp.log(jnp.pi*2)
            d2 = -0.5*d0 + jnp.log(w)
            d3 = d2 + jnp.log(div)

            return d3
        res = jax.vmap(perf_ray)(points) # jit perf
        return res

    res = jax.vmap(perf_idx)(prec,weights,means) # jit perf

    return -jax.scipy.special.logsumexp(res.T, axis=1).ravel().mean(),res # + ent.mean()
--------------------------------------------------------------------------------
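A minimal usage sketch for log_likelihood, assuming zpfm_render.py is importable
from the repository root; the random points and mixture initialization are
arbitrary illustrations:

    import jax
    import jax.numpy as jnp
    from zpfm_render import log_likelihood

    pts = jax.random.normal(jax.random.PRNGKey(0), (128, 3))
    params = (jnp.zeros((4, 3)),                 # means
              jnp.tile(jnp.eye(3), (4, 1, 1)),   # upper-triangular precision factors
              jnp.log(jnp.ones(4) / 4))          # log-weights
    nll, per_component = log_likelihood(params, pts)   # scalar NLL, (4, 128) log-densities
    grads = jax.grad(lambda p: log_likelihood(p, pts)[0])(params)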