├── .gitignore ├── LICENSE ├── README.md ├── bunny ├── bun_zipper.ply ├── bun_zipper_1000_1.ply ├── bun_zipper_50k.ply ├── bun_zipper_895_1.ply ├── bun_zipper_992_1.ply ├── bun_zipper_pds_1000_1.ply ├── bun_zipper_pts_1000_1.ply ├── bun_zipper_res2.ply ├── bun_zipper_res3.ply ├── bun_zipper_res4.ply ├── bun_zipper_res4_25k_pds.ply ├── bun_zipper_res4_pds.ply ├── bun_zipper_res4_sds.ply └── bun_zipper_res4_spr.ply ├── cube-covariance test.ipynb ├── format_graphs-Copy1.ipynb ├── format_graphs.ipynb ├── gen_gmm-mine-com.ipynb ├── gen_gmm-mine-patch.ipynb ├── gen_gmm-mine.ipynb ├── gen_gmm.ipynb ├── gmm_fit.py ├── gmm_fit2.py ├── gmm_fit3.py ├── gmm_fit_extra.py ├── graph_scales.ipynb ├── ground_truth_testing.ipynb ├── icp_test.ipynb ├── likelihood.ipynb ├── llcomp.py ├── mixture ├── __init__.py ├── base.py ├── bayesian_mixture.py ├── gaussian_mixture.py ├── gaussian_mixture_v1.py └── tests │ ├── __init__.py │ ├── test_bayesian_mixture.py │ ├── test_gaussian_mixture.py │ └── test_mixture.py ├── plot_trajec.ipynb ├── reg_results.ipynb ├── reg_viz-extra.ipynb ├── reg_viz.ipynb ├── registration_test.py ├── registration_test_extra.py ├── render_tmp.py ├── road_graphic.py ├── tri_test.py ├── tri_verts_graph-Copy1.ipynb ├── tri_verts_graph.ipynb ├── vis_fig.py ├── vis_fitting.py ├── vis_fitting_bunny.py ├── vis_fitting_bunny_mesh.py └── visualize gmm.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | # extras 6 | .vscode/ 7 | .DS_Store 8 | *.png 9 | *.pdf 10 | old_stuff/ 11 | # C extensions 12 | *.so 13 | *.csv 14 | # Distribution / packaging 15 | .Python 16 | env/ 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python 
script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # dotenv 88 | .env 89 | 90 | # virtualenv 91 | .venv 92 | venv/ 93 | ENV/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Leonid Keselman 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial 
portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [Direct Fitting of GMMs project website](https://leonidk.github.io/direct_gmm/) 2 | This is the source code and project history for the following publication: 3 | 4 | **Direct Fitting of Gaussian Mixture Models** by Leonid Keselman and Martial Hebert ([arXiv version here](https://arxiv.org/abs/1904.05537)) 5 | 6 | ## Overview 7 | Almost all files used in the development and testing of this project are in this folder. The data files for the Stanford Bunny are included in `bunny`. 8 | 9 | * `mixture` contains the modified version of scikit-learn with the proposed techniques. 10 | * `gmm_fit.py` and `gmm_fit2.py` contain the two sets of bunny likelihood experiments. 11 | * `registration_test.py` contains the mesh registration (P2D) experiments. 12 | * Files with `_extra` (or `-extra`) in their names are usually just copies for non-Stanford Bunny experiments. 13 | * `gen_gmm.ipynb` and `gen_gmm-mine.ipynb` generate GMM models from the TUM dataset, with and without uncertainty models. 14 | * `reg_results.ipynb` performs D2D registration between the GMM models built from the TUM dataset.
15 | -------------------------------------------------------------------------------- /bunny/bun_zipper_1000_1.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonidk/direct_gmm/6735bef868dc3ba061ca544f5e95ce2fed09f70a/bunny/bun_zipper_1000_1.ply -------------------------------------------------------------------------------- /bunny/bun_zipper_50k.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonidk/direct_gmm/6735bef868dc3ba061ca544f5e95ce2fed09f70a/bunny/bun_zipper_50k.ply -------------------------------------------------------------------------------- /bunny/bun_zipper_895_1.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonidk/direct_gmm/6735bef868dc3ba061ca544f5e95ce2fed09f70a/bunny/bun_zipper_895_1.ply -------------------------------------------------------------------------------- /bunny/bun_zipper_992_1.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonidk/direct_gmm/6735bef868dc3ba061ca544f5e95ce2fed09f70a/bunny/bun_zipper_992_1.ply -------------------------------------------------------------------------------- /bunny/bun_zipper_pds_1000_1.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonidk/direct_gmm/6735bef868dc3ba061ca544f5e95ce2fed09f70a/bunny/bun_zipper_pds_1000_1.ply -------------------------------------------------------------------------------- /bunny/bun_zipper_pts_1000_1.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonidk/direct_gmm/6735bef868dc3ba061ca544f5e95ce2fed09f70a/bunny/bun_zipper_pts_1000_1.ply -------------------------------------------------------------------------------- 
/bunny/bun_zipper_res4_25k_pds.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonidk/direct_gmm/6735bef868dc3ba061ca544f5e95ce2fed09f70a/bunny/bun_zipper_res4_25k_pds.ply -------------------------------------------------------------------------------- /bunny/bun_zipper_res4_pds.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonidk/direct_gmm/6735bef868dc3ba061ca544f5e95ce2fed09f70a/bunny/bun_zipper_res4_pds.ply -------------------------------------------------------------------------------- /bunny/bun_zipper_res4_sds.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonidk/direct_gmm/6735bef868dc3ba061ca544f5e95ce2fed09f70a/bunny/bun_zipper_res4_sds.ply -------------------------------------------------------------------------------- /bunny/bun_zipper_res4_spr.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonidk/direct_gmm/6735bef868dc3ba061ca544f5e95ce2fed09f70a/bunny/bun_zipper_res4_spr.ply -------------------------------------------------------------------------------- /cube-covariance test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pylab import *\n", 10 | "\n", 11 | "from mpl_toolkits.mplot3d import axes3d, Axes3D\n", 12 | "from mpl_toolkits.mplot3d.art3d import Poly3DCollection, Line3DCollection\n", 13 | "import transforms3d\n", 14 | "import numpy as np\n", 15 | "import os\n", 16 | "\n" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "NUM_SAMPLES = 1\n", 26 | "NUM_POINTS = 1500\n", 27 | "F = 
550\n", 28 | "B = 70" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "def get_plane_eq(p1,p2,p3):\n", 38 | " #p1 = np.array(p1)\n", 39 | " #p2 = np.array(p2)\n", 40 | " #p3 = np.array(p3)\n", 41 | "\n", 42 | " v1 = p3 - p1\n", 43 | " v2 = p2 - p1\n", 44 | " cp = np.cross(v1, v2)\n", 45 | " a, b, c = cp/np.linalg.norm(cp)\n", 46 | "\n", 47 | " # This evaluates a * x3 + b * y3 + c * z3 which equals d\n", 48 | " d = np.dot(cp, p3)\n", 49 | " return np.array([a,b,c,d])\n", 50 | "for s in range(NUM_SAMPLES):\n", 51 | " x,y = np.random.randint(-320,320), np.random.randint(-240,240) \n", 52 | " d = np.random.randint(2,64)\n", 53 | " #x,y,d = 1,0,3\n", 54 | " z = (F*B)/d\n", 55 | " ogpt = (z*x/F,z*y/F,z)\n", 56 | " \n", 57 | " zf = (F*B)/(d+1)\n", 58 | " zb = (F*B)/(d-1)\n", 59 | " \n", 60 | "\n", 61 | " frtpt = (zf*(x+0.5)/F,zf*y/F,zf)\n", 62 | " lftpt = (z*(x-0.5)/F,z*y/F,z)\n", 63 | " rgtpt = (z*(x+0.5)/F,z*y/F,z)\n", 64 | " bckpt = (zb*(x-0.5)/F,zb*y/F,zb)\n", 65 | " \n", 66 | " frtptT = np.array(((x+0.5)/F,(y+0.5)/F,1))*zf\n", 67 | " lftptT = np.array(((x-0.5)/F,(y+0.5)/F,1))*z\n", 68 | " rgtptT = np.array(((x+0.5)/F,(y+0.5)/F,1))*z\n", 69 | " bckptT = np.array(((x-0.5)/F,(y+0.5)/F,1))*zb\n", 70 | " \n", 71 | " frtptB = np.array(((x+0.5)/F,(y-0.5)/F,1))*zf\n", 72 | " lftptB = np.array(((x-0.5)/F,(y-0.5)/F,1))*z\n", 73 | " rgtptB = np.array(((x+0.5)/F,(y-0.5)/F,1))*z\n", 74 | " bckptB = np.array(((x-0.5)/F,(y-0.5)/F,1))*zb\n", 75 | " \n", 76 | " box = np.vstack([frtptT,lftptT,rgtptT,bckptT,frtptB,lftptB,rgtptB,bckptB])\n", 77 | " min_bounds = box.min(0)\n", 78 | " max_bounds = box.max(0)\n", 79 | " \n", 80 | " # facing plane, clockwise winding order\n", 81 | " A = get_plane_eq(frtptT,lftptT,rgtptT)\n", 82 | " Z = get_plane_eq(frtptB,rgtptB,lftptB)\n", 83 | "\n", 84 | " C = get_plane_eq(frtptT,rgtptT,rgtptB)\n", 85 | " Y = get_plane_eq(bckptT,lftptT,lftptB)\n", 86 | " \n", 87 | 
" E = get_plane_eq(lftptT,frtptT,frtptB)\n", 88 | " X = get_plane_eq(rgtptT,bckptT,bckptB)\n", 89 | " \n", 90 | " the_planes = np.vstack([A,Z,C,Y,E,X])\n", 91 | " \n", 92 | " span = max_bounds - min_bounds\n", 93 | " dim_probs = span/np.linalg.norm(span)\n", 94 | " pts = []\n", 95 | " #while len(pts) < NUM_POINTS:\n", 96 | " rpts = np.random.rand(NUM_POINTS,3)*span + min_bounds\n", 97 | " dv = np.vstack([frtptB,lftptB,bckptB,rgtptB])\n", 98 | "\n", 99 | " covar = np.zeros((3,3))\n", 100 | " m = box.mean(0,keepdims=1)\n", 101 | " M = m.T @ m\n", 102 | " for pt in box:\n", 103 | " pt = pt.reshape((-1,1))\n", 104 | " covar += (pt @ pt.T - M)\n", 105 | " covar = covar/12\n", 106 | " funny_covar = np.copy(covar)\n", 107 | " \n", 108 | " \n", 109 | " gpts = np.random.multivariate_normal(m.ravel(),covar,NUM_POINTS)\n", 110 | " plt.figure()\n", 111 | " plt.scatter(gpts[:,0],gpts[:,2],s=3,alpha=0.3)\n", 112 | " plt.scatter(dv[:,0],dv[:,2])\n", 113 | "funny_covar" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "print(frtptT-frtptB)\n", 123 | "X,Y,Z = frtptT,lftptT,rgtptT\n", 124 | "np.linalg.norm(np.cross(X-Y,Z-Y))/2\n", 125 | "\n", 126 | "M = 1/3 * (X+Y+Z)\n", 127 | "c1 = X[:,None] @ X[None,:] + Y[:,None] @ Y[None,:] + Z[:,None] @ Z[None,:] - 3 * M[:,None] @ M[None,:]\n", 128 | "\n", 129 | "A = np.vstack([X,Y,Z]).T\n", 130 | "c2 =A @ np.array([[2,1,1],[1,2,1],[1,1,2] ])/120 @ A.T\n", 131 | "c2 = c2 * np.linalg.det(A) #* (c1[0,0]/c2[0,0])\n", 132 | "center = 1/4 * (X+Y+Z)\n", 133 | "mass = np.linalg.det(A/6)\n", 134 | "total_center = center\n", 135 | "correct = np.array(tuple(tuple(x * y for x in total_center) for y in total_center))*mass\n", 136 | "correct\n", 137 | "c3 = c2 - correct\n", 138 | "\n", 139 | "AB = X-Z\n", 140 | "AC = Y-Z\n", 141 | "print('AREA')\n", 142 | "print(np.linalg.norm(np.cross(AB,AC))/2)\n", 143 | "print(np.linalg.norm(np.cross(X-Y,Z-Y))/2)\n", 144 
| "\n", 145 | "Mt = np.array([[2,1,1], [1,2,1], [1,1,2] ])\n", 146 | "v = np.zeros(3)\n", 147 | "A = np.vstack([X-M,Y-M,Z-M]).T\n", 148 | "c4 = A @ Mt @ A.T\n", 149 | "print(c1)\n", 150 | "print(c2)\n", 151 | "print(correct)\n", 152 | "print(c3)\n", 153 | "print(c4)\n", 154 | "frtptT,lftptT,rgtptT,bckptT\n", 155 | "frtptB,lftptB,rgtptB,bckptB\n", 156 | "\n", 157 | "tris = []\n", 158 | "#cw triangles\n", 159 | "# top\n", 160 | "f1t1 = [lftptT,rgtptT,frtptT]\n", 161 | "f1t2 = [lftptT,bckptT,rgtptT]\n", 162 | "#bottom\n", 163 | "f2t1 = [rgtptB,lftptB,frtptB]\n", 164 | "f2t2 = [rgtptB,bckptB,lftptB]\n", 165 | "\n", 166 | "#front right\n", 167 | "f3t1 = [frtptT,rgtptT,rgtptB]\n", 168 | "f3t2 = [rgtptB,frtptB,frtptT]\n", 169 | "\n", 170 | "#front left\n", 171 | "f4t1 = [frtptT,frtptB,lftptB]\n", 172 | "f4t2 = [lftptB,lftptT,frtptT]\n", 173 | "\n", 174 | "#back right\n", 175 | "f5t1 = [rgtptT,bckptT,bckptB]\n", 176 | "f5t2 = [bckptB,rgtptB,rgtptT]\n", 177 | "\n", 178 | "#back right\n", 179 | "f6t1 = [bckptT,lftptT,lftptB]\n", 180 | "f6t2 = [lftptB,bckptB,bckptT]\n", 181 | "verts = [f1t1,f1t2,f2t1,f2t2,f3t1,f3t2,f4t1,f4t2,f5t1,f5t2,f6t1,f6t2]\n", 182 | "canon = np.array([[2,1,1],[1,2,1],[1,1,2]])/120.0\n", 183 | "\n", 184 | "for t in verts:\n", 185 | " print('next')\n", 186 | " print(np.cross(t[2]-t[0],t[1]-t[0]))\n", 187 | "tris = verts\n", 188 | "\n", 189 | "#def get_com_cvar(tris):\n", 190 | "coms = []\n", 191 | "masses = []\n", 192 | "covars = []\n", 193 | "\n", 194 | "\n", 195 | "for t in tris:\n", 196 | " A = np.array([t[0],t[1],t[2]]).T\n", 197 | " det = np.linalg.det(A)\n", 198 | " C = det * A @ canon @ A.T\n", 199 | " mass = det/6\n", 200 | " com = (np.sum(t,0))/4\n", 201 | "\n", 202 | " coms.append(com)\n", 203 | " covars.append(C)\n", 204 | " masses.append(mass)\n", 205 | "total_covar = np.zeros((3,3))\n", 206 | "total_com = np.zeros(3)\n", 207 | "total_mass = sum(masses)\n", 208 | "for com, mass in zip(coms,masses):\n", 209 | " total_com += (mass/total_mass) * 
com\n", 210 | "total_covar = np.array(covars).sum(0) - total_mass * (total_com[:,None] @ total_com[None,:])\n", 211 | "#return total_covar, total_com, total_mass\n", 212 | "\n", 213 | "clast = np.zeros((3,3))\n", 214 | "\n", 215 | "for t,m,cm in zip(tris,masses,coms):\n", 216 | " A = np.array([t[0]-total_com,t[1]-total_com,t[2]-total_com]).T\n", 217 | " clast += np.linalg.det(A) * A @ canon @ A.T #- (1/total_mass)*(m*cm[:,None]) @ (m*cm[None,:])\n", 218 | "\n", 219 | "#final_covar,final_com,final_mass = get_com_cvar(verts)\n", 220 | "\n", 221 | "if False:\n", 222 | " from mpl_toolkits.mplot3d import Axes3D\n", 223 | " from mpl_toolkits.mplot3d.art3d import Poly3DCollection\n", 224 | " import matplotlib.pyplot as plt\n", 225 | " fig = plt.figure()\n", 226 | " ax = Axes3D(fig)\n", 227 | " b2 = np.array(verts)\n", 228 | " minb = b2.reshape((-1,3)).min(0)\n", 229 | " maxb = b2.reshape((-1,3)).max(0)\n", 230 | " meanb = b2.reshape((-1,3)).mean(0)\n", 231 | " span = maxb-minb\n", 232 | " w = span/2\n", 233 | " pc = Poly3DCollection(verts, alpha = 0.2, facecolor='b', linewidths=1,edgecolors='r')\n", 234 | " ax.add_collection3d(pc)\n", 235 | " ax.set_xlim(minb[0]-w[0],maxb[0]+w[0])\n", 236 | " ax.set_ylim(minb[1]-w[1],maxb[1]+w[1])\n", 237 | " ax.set_zlim(minb[2]-w[2],maxb[2]+w[2])\n", 238 | "\n", 239 | " plt.show()\n", 240 | "gpts2 = np.random.multivariate_normal(total_com,total_covar,NUM_POINTS)\n", 241 | "plt.figure()\n", 242 | "#plt.scatter(gpts2[:,0],gpts2[:,2],s=3,alpha=0.3,label='form/12')\n", 243 | "plt.scatter(gpts[:,0],gpts[:,2],s=3,alpha=0.3,label='mine')\n", 244 | "plt.legend()\n", 245 | "plt.scatter(dv[:,0],dv[:,2])\n", 246 | "funny_covar = np.zeros_like(funny_covar)\n", 247 | "m = total_com[None,:]\n", 248 | "M = m.T @ m\n", 249 | "for pt in box:\n", 250 | " pt = pt.reshape((-1,1))\n", 251 | " funny_covar += (pt @ pt.T - M)\n", 252 | "funny_covar = funny_covar/12" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": null, 258 | 
"metadata": {}, 259 | "outputs": [], 260 | "source": [ 261 | "total_covar, total_com, total_mass, clast" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": null, 267 | "metadata": {}, 268 | "outputs": [], 269 | "source": [ 270 | "print(np.linalg.norm(np.cross(X-Y,Z-Y))/2)\n", 271 | "\n", 272 | "temp = np.array([X,Y,Z])\n", 273 | "temp,np.linalg.det(temp)\n" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": null, 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "funny_covar,m,(total_covar/funny_covar),((2*clast/total_mass)/funny_covar)" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": null, 288 | "metadata": {}, 289 | "outputs": [], 290 | "source": [ 291 | "pt_planes = np.hstack([rpts,np.ones((NUM_POINTS,1)) ]) @ the_planes.T\n", 292 | "valid_pts = (pt_planes <= 0).sum(1)\n", 293 | "print((valid_pts == 6).sum())\n", 294 | "print((valid_pts == 0).sum())\n", 295 | "print((pt_planes <= 0).sum(0))\n", 296 | "pts = rpts[(valid_pts <=1)]\n", 297 | "print(len(pts))" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": null, 303 | "metadata": {}, 304 | "outputs": [], 305 | "source": [ 306 | "%matplotlib inline\n", 307 | "plt.scatter(dv[:,0],dv[:,2])\n", 308 | "plt.scatter(pts[:,0],pts[:,2])\n" 309 | ] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "execution_count": null, 314 | "metadata": {}, 315 | "outputs": [], 316 | "source": [ 317 | "\n" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": null, 323 | "metadata": {}, 324 | "outputs": [], 325 | "source": [ 326 | "%matplotlib notebook\n", 327 | "fig = plt.figure()\n", 328 | "ax = Axes3D(fig)\n", 329 | "x = [0, 1, 1, 0]\n", 330 | "y = [0, 0, 1, 1]\n", 331 | "z = [0, 1, 0, 1]\n", 332 | "verts2 = [list(zip(x, y, z))]\n", 333 | "verts = [[tuple(_) for _ in dv]]\n", 334 | "print(verts)\n", 335 | "ax.scatter3D(gpts[:,0],gpts[:,1],gpts[:,2])\n", 336 | "plt.show()\n", 337 | 
"print(span)" 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": null, 343 | "metadata": {}, 344 | "outputs": [], 345 | "source": [ 346 | "verts" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": null, 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [ 355 | "list([tuple(_) for _ in dv])" 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": null, 361 | "metadata": {}, 362 | "outputs": [], 363 | "source": [] 364 | } 365 | ], 366 | "metadata": { 367 | "kernelspec": { 368 | "display_name": "Python 3", 369 | "language": "python", 370 | "name": "python3" 371 | }, 372 | "language_info": { 373 | "codemirror_mode": { 374 | "name": "ipython", 375 | "version": 3 376 | }, 377 | "file_extension": ".py", 378 | "mimetype": "text/x-python", 379 | "name": "python", 380 | "nbconvert_exporter": "python", 381 | "pygments_lexer": "ipython3", 382 | "version": "3.6.7" 383 | } 384 | }, 385 | "nbformat": 4, 386 | "nbformat_minor": 2 387 | } 388 | -------------------------------------------------------------------------------- /format_graphs-Copy1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import pandas as pd\n", 12 | "import matplotlib.ticker" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "df = pd.read_csv(\"arma.log\",names=['k','init','model','l','i'])\n", 22 | "mdf2 = df.groupby(['init','model','k']).mean()\n", 23 | "sdf2 = df.groupby(['init','model','k']).std()\n", 24 | "mdf = df.groupby(['init','model','k']).mean()\n", 25 | "sdf = df.groupby(['init','model','k']).std()" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | 
"metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "pltstuff = 1\n", 35 | "\n", 36 | "fig = plt.figure(figsize=(10,pltstuff*5))\n", 37 | "from matplotlib import rc\n", 38 | "plt.style.use('fivethirtyeight')\n", 39 | "plt.style.use('seaborn-whitegrid')\n", 40 | "if pltstuff == 2:\n", 41 | " subplot_order = [1,3,2,4]\n", 42 | "else:\n", 43 | " subplot_order = [1,2]\n", 44 | "#plt.rcParams[\"font.family\"] = \"sans-serif\"\n", 45 | "rc('font',**{'family':'sans-serif','sans-serif':['cm']})\n", 46 | "#plt.style.use('default')\n", 47 | "for ii, init in enumerate(['kmeans']):\n", 48 | " for model in range(4):\n", 49 | " ls = '--' if model < 2 else '-'\n", 50 | " if model == 0:\n", 51 | " label = 'Mesh (Quadric)'\n", 52 | " if model == 1:\n", 53 | " label = 'Mesh (Cluster)'\n", 54 | " if model == 2:\n", 55 | " label = 'Points (Poisson)'\n", 56 | " if model == 3:\n", 57 | " label = 'Points (Random)'\n", 58 | " if init == 'kmeans':\n", 59 | " x = np.array(mdf2.loc[('kmeans',0),].index)\n", 60 | " y = mdf2.loc[(init,model),].values[:,0]\n", 61 | " error = sdf2.loc[(init,model),].values[:,0]\n", 62 | " else:\n", 63 | " x = np.array(mdf.loc[('random',0),].index)\n", 64 | "\n", 65 | " y = mdf.loc[(init,model),].values[:,0]\n", 66 | " error = sdf.loc[(init,model),].values[:,0]\n", 67 | "\n", 68 | " ax = plt.subplot(pltstuff,2,subplot_order[pltstuff*ii])\n", 69 | " #print(x.shape,y.shape)\n", 70 | "\n", 71 | " plt.plot(x,y,ls=ls,label=label)\n", 72 | " plt.fill_between(x, y-error, y+error,alpha=0.3)\n", 73 | " plt.grid(True)\n", 74 | " plt.xlabel('number of mixtures')\n", 75 | " plt.title('{} init. 
fidelity'.format(init))\n", 76 | " plt.title('{} initialization'.format(init))\n", 77 | "\n", 78 | " plt.ylabel('likelihood of ground truth\\n(higher is better)')\n", 79 | " plt.ylim(2,9)\n", 80 | " ax.set_xscale(\"log\", nonposx='clip')\n", 81 | " ax.tick_params(axis='x', which='minor', bottom=True,width=1,length=5) \n", 82 | " plt.grid(True,axis='x',which='minor')\n", 83 | " ax.tick_params(axis='x', which='major', bottom=True,width=2,length=5) \n", 84 | "\n", 85 | " if pltstuff == 2:\n", 86 | " if init == 'kmeans':\n", 87 | " y = mdf2.loc[(init,model),].values[:,1]\n", 88 | " error = sdf2.loc[(init,model),].values[:,1]\n", 89 | " else:\n", 90 | " #x = np.array(mdf.loc[('random',0),].index)\n", 91 | "\n", 92 | " #x = np.array(mdf.index)\n", 93 | " y = mdf.loc[(init,model),].values[:,1]\n", 94 | " error = sdf.loc[(init,model),].values[:,1]\n", 95 | " ax = plt.subplot(pltstuff,2,subplot_order[1+ii*2])\n", 96 | " print(x.shape,y.shape)\n", 97 | " plt.plot(x,y,ls=ls,label=label)\n", 98 | " plt.fill_between(x, y-error, y+error,alpha=0.15)\n", 99 | "\n", 100 | " ax.set_xscale(\"log\", nonposx='clip')\n", 101 | " plt.grid(True)\n", 102 | " plt.xlabel('number of mixtures')\n", 103 | " plt.title('{} init. 
runtime'.format(init))\n", 104 | " plt.ylabel('iterations until convergence')\n", 105 | " ax.tick_params(axis='x', which='minor', bottom=True,width=1,length=5) \n", 106 | " plt.grid(True,axis='x',which='minor')\n", 107 | " ax.tick_params(axis='x', which='major', bottom=True,width=2,length=5) \n", 108 | "\n", 109 | "for i in range(2*pltstuff):\n", 110 | " plt.subplot(pltstuff,2,1+i)\n", 111 | " plt.legend()\n", 112 | " plt.xlim(6,500)\n", 113 | "\n", 114 | "\n", 115 | "fig.tight_layout()\n", 116 | "fig.subplots_adjust(top=0.7+pltstuff*0.1)\n", 117 | "\n", 118 | "plt.suptitle('Different Decimation Methods',size=24,weight='bold')\n", 119 | "#plt.savefig('graphs-qc.pdf', facecolor=fig.get_facecolor(), edgecolor='none')\n" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "mdf2.loc[('kmeans',0),].values, np.array(mdf2.loc[('kmeans',0),].index)" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "mdf2,mdf" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [] 146 | } 147 | ], 148 | "metadata": { 149 | "kernelspec": { 150 | "display_name": "Python 3", 151 | "language": "python", 152 | "name": "python3" 153 | }, 154 | "language_info": { 155 | "codemirror_mode": { 156 | "name": "ipython", 157 | "version": 3 158 | }, 159 | "file_extension": ".py", 160 | "mimetype": "text/x-python", 161 | "name": "python", 162 | "nbconvert_exporter": "python", 163 | "pygments_lexer": "ipython3", 164 | "version": "3.6.7" 165 | } 166 | }, 167 | "nbformat": 4, 168 | "nbformat_minor": 2 169 | } 170 | -------------------------------------------------------------------------------- /format_graphs.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | 
"cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import pandas as pd\n", 12 | "import matplotlib.ticker" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "df = pd.read_csv(\"bunny_fit_monday_subsamples_25.log\",names=['k','init','model','l','i'])\n", 22 | "mdf2 = df.groupby(['init','model','k']).mean()\n", 23 | "sdf2 = df.groupby(['init','model','k']).std()\n", 24 | "mdf = df.groupby(['init','model','k']).mean()\n", 25 | "sdf = df.groupby(['init','model','k']).std()" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "pltstuff = 1\n", 35 | "\n", 36 | "fig = plt.figure(figsize=(10,pltstuff*5))\n", 37 | "from matplotlib import rc\n", 38 | "plt.style.use('fivethirtyeight')\n", 39 | "plt.style.use('seaborn-whitegrid')\n", 40 | "if pltstuff == 2:\n", 41 | " subplot_order = [1,3,2,4]\n", 42 | "else:\n", 43 | " subplot_order = [1,2]\n", 44 | "#plt.rcParams[\"font.family\"] = \"sans-serif\"\n", 45 | "rc('font',**{'family':'sans-serif','sans-serif':['cm']})\n", 46 | "#plt.style.use('default')\n", 47 | "for ii, init in enumerate(['kmeans','random']):\n", 48 | " for model in range(4):\n", 49 | " ls = '--' if model < 2 else '-'\n", 50 | " if model == 0:\n", 51 | " label = 'Mesh (Quadric)'\n", 52 | " if model == 1:\n", 53 | " label = 'Mesh (Cluster)'\n", 54 | " if model == 2:\n", 55 | " label = 'Points (Poisson)'\n", 56 | " if model == 3:\n", 57 | " label = 'Points (Random)'\n", 58 | " if init == 'kmeans':\n", 59 | " x = np.array(mdf2.loc[('kmeans',0),].index)\n", 60 | " y = mdf2.loc[(init,model),].values[:,0]\n", 61 | " error = sdf2.loc[(init,model),].values[:,0]\n", 62 | " else:\n", 63 | " x = np.array(mdf.loc[('random',0),].index)\n", 64 | "\n", 65 | " y = 
mdf.loc[(init,model),].values[:,0]\n", 66 | " error = sdf.loc[(init,model),].values[:,0]\n", 67 | "\n", 68 | " ax = plt.subplot(pltstuff,2,subplot_order[pltstuff*ii])\n", 69 | " #print(x.shape,y.shape)\n", 70 | "\n", 71 | " plt.plot(x,y,ls=ls,label=label)\n", 72 | " plt.fill_between(x, y-error, y+error,alpha=0.3)\n", 73 | " plt.grid(True)\n", 74 | " plt.xlabel('number of mixtures')\n", 75 | " plt.title('{} init. fidelity'.format(init))\n", 76 | " plt.title('{} initialization'.format(init))\n", 77 | "\n", 78 | " plt.ylabel('likelihood of ground truth\\n(higher is better)')\n", 79 | " plt.ylim(2,9)\n", 80 | " ax.set_xscale(\"log\", nonposx='clip')\n", 81 | " ax.tick_params(axis='x', which='minor', bottom=True,width=1,length=5) \n", 82 | " plt.grid(True,axis='x',which='minor')\n", 83 | " ax.tick_params(axis='x', which='major', bottom=True,width=2,length=5) \n", 84 | "\n", 85 | " if pltstuff == 2:\n", 86 | " if init == 'kmeans':\n", 87 | " y = mdf2.loc[(init,model),].values[:,1]\n", 88 | " error = sdf2.loc[(init,model),].values[:,1]\n", 89 | " else:\n", 90 | " #x = np.array(mdf.loc[('random',0),].index)\n", 91 | "\n", 92 | " #x = np.array(mdf.index)\n", 93 | " y = mdf.loc[(init,model),].values[:,1]\n", 94 | " error = sdf.loc[(init,model),].values[:,1]\n", 95 | " ax = plt.subplot(pltstuff,2,subplot_order[1+ii*2])\n", 96 | " print(x.shape,y.shape)\n", 97 | " plt.plot(x,y,ls=ls,label=label)\n", 98 | " plt.fill_between(x, y-error, y+error,alpha=0.15)\n", 99 | "\n", 100 | " ax.set_xscale(\"log\", nonposx='clip')\n", 101 | " plt.grid(True)\n", 102 | " plt.xlabel('number of mixtures')\n", 103 | " plt.title('{} init. 
runtime'.format(init))\n", 104 | " plt.ylabel('iterations until convergence')\n", 105 | " ax.tick_params(axis='x', which='minor', bottom=True,width=1,length=5) \n", 106 | " plt.grid(True,axis='x',which='minor')\n", 107 | " ax.tick_params(axis='x', which='major', bottom=True,width=2,length=5) \n", 108 | "\n", 109 | "for i in range(2*pltstuff):\n", 110 | " plt.subplot(pltstuff,2,1+i)\n", 111 | " plt.legend()\n", 112 | " plt.xlim(6,400)\n", 113 | "\n", 114 | "\n", 115 | "fig.tight_layout()\n", 116 | "fig.subplots_adjust(top=0.7+pltstuff*0.1)\n", 117 | "\n", 118 | "plt.suptitle('Different Decimation Methods',size=24,weight='bold')\n", 119 | "plt.savefig('graphs-qc.pdf', facecolor=fig.get_facecolor(), edgecolor='none')\n" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "mdf2.loc[('kmeans',0),].values, np.array(mdf2.loc[('kmeans',0),].index)" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "mdf2,mdf" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [] 146 | } 147 | ], 148 | "metadata": { 149 | "kernelspec": { 150 | "display_name": "Python 3", 151 | "language": "python", 152 | "name": "python3" 153 | }, 154 | "language_info": { 155 | "codemirror_mode": { 156 | "name": "ipython", 157 | "version": 3 158 | }, 159 | "file_extension": ".py", 160 | "mimetype": "text/x-python", 161 | "name": "python", 162 | "nbconvert_exporter": "python", 163 | "pygments_lexer": "ipython3", 164 | "version": "3.6.7" 165 | } 166 | }, 167 | "nbformat": 4, 168 | "nbformat_minor": 2 169 | } 170 | -------------------------------------------------------------------------------- /gen_gmm-mine-com.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 
| "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pylab import *\n", 10 | "import os\n", 11 | "import sys\n", 12 | "from mixture import GaussianMixture\n", 13 | "import pickle\n", 14 | "from PIL import Image" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "dataset_dir = 'rgbd_dataset_freiburg3_long_office_household'\n", 24 | "depth_dir = 'depth'\n", 25 | "gmm_dir = 'gmm_qqvga_mine_quarterpixel_com2'\n", 26 | "\n", 27 | "# og \n", 28 | "fx = 525.0 # focal length x\n", 29 | "fy = 525.0 # focal length y\n", 30 | "cx = 319.5 # optical center x\n", 31 | "cy = 239.5 # optical center y\n", 32 | "\n", 33 | "# fri3 \n", 34 | "fx = 535.4 # focal length x\n", 35 | "fy = 539.2 # focal length y\n", 36 | "cx = 320.1 # optical center x\n", 37 | "cy = 247.6 # optical center y\n", 38 | "factor = 5000 # for the 16-bit PNG files\n", 39 | "# OR: factor = 1 # for the 32-bit float images in the ROS bag files\n", 40 | "baseline = 0.075 \n", 41 | "\n", 42 | "full_depth_dir = os.path.join(dataset_dir,depth_dir)\n", 43 | "full_gmm_dir = os.path.join(dataset_dir,gmm_dir)\n", 44 | "if not os.path.exists(full_gmm_dir):\n", 45 | " os.mkdir(full_gmm_dir)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "max_disp = 325 #fx*baseline/0.125\n", 55 | "min_disp = 2#f\n" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "y_grid = np.repeat(np.arange(480)[:,None],640,1)\n", 65 | "x_grid = np.repeat(np.arange(640)[None,:],480,0)\n", 66 | "\n", 67 | "size_of_disp_error = 0.25 #0.125\n", 68 | "\n", 69 | "def img_to_pts(img):\n", 70 | " Z = img/factor\n", 71 | " X = (x_grid-cx) * Z /fx\n", 72 | " Y = (y_grid-cy) * Z /fy\n", 73 | " \n", 74 | " d = (fx*baseline)/Z\n", 75 
| " zf = (fx*baseline)/np.clip(d+size_of_disp_error,min_disp,max_disp)\n", 76 | " zb = (fx*baseline)/np.clip(d-size_of_disp_error,min_disp,max_disp)\n", 77 | " \n", 78 | " subsample = 4\n", 79 | " step_frac = 0.5 * subsample\n", 80 | " Xp = x_grid-cx+step_frac\n", 81 | " Yp = y_grid-cy+step_frac\n", 82 | " Xm = x_grid-cx-step_frac\n", 83 | " Ym = y_grid-cy-step_frac\n", 84 | " frtptT = np.array(((Xp)/fx,(Yp)/fy,np.ones_like(d)))*zf\n", 85 | " lftptT = np.array(((Xm)/fx,(Yp)/fy,np.ones_like(d)))*Z\n", 86 | " rgtptT = np.array(((Xp)/fx,(Yp)/fy,np.ones_like(d)))*Z\n", 87 | " bckptT = np.array(((Xm)/fx,(Yp)/fy,np.ones_like(d)))*zb\n", 88 | " \n", 89 | " frtptB = np.array(((Xp)/fx,(Ym)/fy,np.ones_like(d)))*zf\n", 90 | " lftptB = np.array(((Xm)/fx,(Ym)/fy,np.ones_like(d)))*Z\n", 91 | " rgtptB = np.array(((Xp)/fx,(Ym)/fy,np.ones_like(d)))*Z\n", 92 | " bckptB = np.array(((Xm)/fx,(Ym)/fy,np.ones_like(d)))*zb\n", 93 | " \n", 94 | " box = np.stack([frtptT,lftptT,rgtptT,bckptT,frtptB,lftptB,rgtptB,bckptB])\n", 95 | " print(box.shape)\n", 96 | " \n", 97 | " m = box.mean(0)\n", 98 | "\n", 99 | " m1 = (m.reshape((3,-1)).T)[:,None]\n", 100 | " m2 = m1.reshape((-1,3,1))\n", 101 | " print(m1.shape)\n", 102 | " print(m2.shape)\n", 103 | "\n", 104 | " \n", 105 | " M = m2 @ m1\n", 106 | " print(M.shape)\n", 107 | " print(box.shape)\n", 108 | " \n", 109 | " covar = np.zeros( (640*480,3,3) )\n", 110 | " for pt in box:\n", 111 | " pt = pt.reshape((3,-1)).T\n", 112 | " pt1 = pt.reshape((-1,1,3))\n", 113 | " pt2 = pt.reshape((-1,3,1))\n", 114 | "\n", 115 | " covar += (pt2 @ pt1 - M)\n", 116 | " covar = covar/12\n", 117 | " covar = covar.reshape((480,640,3,3))[::subsample,::subsample,:,:]\n", 118 | " X = m[0][::subsample,::subsample]\n", 119 | " Y = m[1][::subsample,::subsample]\n", 120 | " Z = m[2][::subsample,::subsample]\n", 121 | " print(X.shape)\n", 122 | "\n", 123 | " xl = X[Z > 0]\n", 124 | " yl = Y[Z > 0]\n", 125 | " zl = Z[Z > 0]\n", 126 | " c1 = (covar.reshape((120,160,-1))[Z > 
0]).reshape( (-1,3,3) )\n", 127 | " n = zl.shape[0]\n", 128 | " print(n,c1.shape)\n", 129 | " #print('MSHAPE\\t',m.shape,Z.shape)\n", 130 | " idx = np.random.randint(n, size=n//16)\n", 131 | "\n", 132 | " return np.vstack([xl,yl,zl]).T, c1\n", 133 | "\n", 134 | "def load_data(file):\n", 135 | " image = Image.open(file) \n", 136 | " pixel = np.array(image)\n", 137 | " return img_to_pts(pixel)\n", 138 | "\n", 139 | "for fl in sorted(os.listdir(full_depth_dir)):\n", 140 | " name,ext = os.path.splitext(fl)\n", 141 | " new_name = name + '.pkl'\n", 142 | " \n", 143 | " clf = GaussianMixture(100)\n", 144 | " data = load_data(os.path.join(full_depth_dir,fl))\n", 145 | " print([_.shape for _ in data])\n", 146 | " t1 = time.time()\n", 147 | " \n", 148 | " n = data[0].shape[0]\n", 149 | " clf.set_areas(np.ones(n))\n", 150 | " clf.set_covars(data[1])\n", 151 | " clf.fit(data[0])\n", 152 | " clf.set_areas(None)\n", 153 | " clf.set_covars(None)\n", 154 | " \n", 155 | " print(time.time()-t1)\n", 156 | " with open(os.path.join(full_gmm_dir,new_name),'wb') as fp:\n", 157 | " pickle.dump(clf,fp)\n", 158 | " " 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "frtptB.shape" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [] 176 | } 177 | ], 178 | "metadata": { 179 | "kernelspec": { 180 | "display_name": "Python 3", 181 | "language": "python", 182 | "name": "python3" 183 | }, 184 | "language_info": { 185 | "codemirror_mode": { 186 | "name": "ipython", 187 | "version": 3 188 | }, 189 | "file_extension": ".py", 190 | "mimetype": "text/x-python", 191 | "name": "python", 192 | "nbconvert_exporter": "python", 193 | "pygments_lexer": "ipython3", 194 | "version": "3.6.7" 195 | } 196 | }, 197 | "nbformat": 4, 198 | "nbformat_minor": 2 199 | } 200 | 
-------------------------------------------------------------------------------- /gen_gmm-mine-patch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pylab import *\n", 10 | "import os\n", 11 | "import sys\n", 12 | "from mixture import GaussianMixture\n", 13 | "import pickle\n", 14 | "from PIL import Image" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "dataset_dir = 'rgbd_dataset_freiburg3_long_office_household'\n", 24 | "depth_dir = 'depth'\n", 25 | "gmm_dir = 'gmm_qqvga_mine_patch'\n", 26 | "\n", 27 | "# og \n", 28 | "fx = 525.0 # focal length x\n", 29 | "fy = 525.0 # focal length y\n", 30 | "cx = 319.5 # optical center x\n", 31 | "cy = 239.5 # optical center y\n", 32 | "\n", 33 | "# fri3 \n", 34 | "fx = 535.4 # focal length x\n", 35 | "fy = 539.2 # focal length y\n", 36 | "cx = 320.1 # optical center x\n", 37 | "cy = 247.6 # optical center y\n", 38 | "factor = 5000 # for the 16-bit PNG files\n", 39 | "# OR: factor = 1 # for the 32-bit float images in the ROS bag files\n", 40 | "baseline = 0.075 \n", 41 | "\n", 42 | "full_depth_dir = os.path.join(dataset_dir,depth_dir)\n", 43 | "full_gmm_dir = os.path.join(dataset_dir,gmm_dir)\n", 44 | "if not os.path.exists(full_gmm_dir):\n", 45 | " os.mkdir(full_gmm_dir)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "max_disp = 325 #fx*baseline/0.125\n", 55 | "min_disp = 2#f\n" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "y_grid = np.repeat(np.arange(480)[:,None],640,1)\n", 65 | "x_grid = np.repeat(np.arange(640)[None,:],480,0)\n", 66 | "\n", 67 | "\n", 68 | "def 
img_to_pts(img):\n", 69 | " Z = img/factor\n", 70 | " X = (x_grid-cx) * Z /fx\n", 71 | " Y = (y_grid-cy) * Z /fy\n", 72 | " subsample = 4\n", 73 | "\n", 74 | " #uniform dist\n", 75 | " covar_X = (subsample* 1 * Z /fx)**2 * (1.0/12.0)\n", 76 | " covar_Y = (subsample* 1 * Z /fy)**2 * (1.0/12.0) \n", 77 | " covar_Z = (0 * Z /fx)**2 * (1.0/12.0)\n", 78 | " covar = np.zeros((480,640,3,3))\n", 79 | " covar[:,:,0,0] = covar_X\n", 80 | " covar[:,:,1,1] = covar_Y\n", 81 | " covar[:,:,2,2] = covar_Z\n", 82 | "\n", 83 | " covar = covar[::subsample,::subsample,:,:]\n", 84 | "\n", 85 | " X = X[::subsample,::subsample]\n", 86 | " Y = Y[::subsample,::subsample]\n", 87 | " Z = Z[::subsample,::subsample]\n", 88 | " print(X.shape)\n", 89 | "\n", 90 | " xl = X[Z > 0]\n", 91 | " yl = Y[Z > 0]\n", 92 | " zl = Z[Z > 0]\n", 93 | " c1 = (covar.reshape((120,160,-1))[Z > 0]).reshape( (-1,3,3) )\n", 94 | " n = zl.shape[0]\n", 95 | " print(n,c1.shape)\n", 96 | " \n", 97 | " idx = np.random.randint(n, size=n//16)\n", 98 | " return np.vstack([xl,yl,zl]).T, c1\n", 99 | "\n", 100 | "def load_data(file):\n", 101 | " image = Image.open(file) \n", 102 | " pixel = np.array(image)\n", 103 | " return img_to_pts(pixel)\n", 104 | "\n", 105 | "for fl in sorted(os.listdir(full_depth_dir)):\n", 106 | " name,ext = os.path.splitext(fl)\n", 107 | " new_name = name + '.pkl'\n", 108 | " \n", 109 | " clf = GaussianMixture(100)\n", 110 | " data = load_data(os.path.join(full_depth_dir,fl))\n", 111 | " print([_.shape for _ in data])\n", 112 | " t1 = time.time()\n", 113 | " \n", 114 | " n = data[0].shape[0]\n", 115 | " clf.set_areas(np.ones(n))\n", 116 | " clf.set_covars(data[1])\n", 117 | " clf.fit(data[0])\n", 118 | " clf.set_areas(None)\n", 119 | " clf.set_covars(None)\n", 120 | " \n", 121 | " print(time.time()-t1)\n", 122 | " with open(os.path.join(full_gmm_dir,new_name),'wb') as fp:\n", 123 | " pickle.dump(clf,fp)\n", 124 | " " 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 
130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "frtptB.shape" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [] 142 | } 143 | ], 144 | "metadata": { 145 | "kernelspec": { 146 | "display_name": "Python 3", 147 | "language": "python", 148 | "name": "python3" 149 | }, 150 | "language_info": { 151 | "codemirror_mode": { 152 | "name": "ipython", 153 | "version": 3 154 | }, 155 | "file_extension": ".py", 156 | "mimetype": "text/x-python", 157 | "name": "python", 158 | "nbconvert_exporter": "python", 159 | "pygments_lexer": "ipython3", 160 | "version": "3.6.7" 161 | } 162 | }, 163 | "nbformat": 4, 164 | "nbformat_minor": 2 165 | } 166 | -------------------------------------------------------------------------------- /gen_gmm-mine.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pylab import *\n", 10 | "import os\n", 11 | "import sys\n", 12 | "from mixture import GaussianMixture\n", 13 | "import pickle\n", 14 | "from PIL import Image" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "dataset_dir = 'rgbd_dataset_freiburg3_long_office_household'\n", 24 | "depth_dir = 'depth'\n", 25 | "gmm_dir = 'gmm_qqvga_mine_pixel_fixed'\n", 26 | "\n", 27 | "# og \n", 28 | "fx = 525.0 # focal length x\n", 29 | "fy = 525.0 # focal length y\n", 30 | "cx = 319.5 # optical center x\n", 31 | "cy = 239.5 # optical center y\n", 32 | "\n", 33 | "# fri3 \n", 34 | "fx = 535.4 # focal length x\n", 35 | "fy = 539.2 # focal length y\n", 36 | "cx = 320.1 # optical center x\n", 37 | "cy = 247.6 # optical center y\n", 38 | "factor = 5000 # for the 16-bit PNG files\n", 39 | "# OR: factor = 1 # for the 32-bit float images 
in the ROS bag files\n", 40 | "baseline = 0.075 \n", 41 | "\n", 42 | "full_depth_dir = os.path.join(dataset_dir,depth_dir)\n", 43 | "full_gmm_dir = os.path.join(dataset_dir,gmm_dir)\n", 44 | "if not os.path.exists(full_gmm_dir):\n", 45 | " os.mkdir(full_gmm_dir)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "max_disp = 325 #fx*baseline/0.125\n", 55 | "min_disp = 2#f\n" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "y_grid = np.repeat(np.arange(480)[:,None],640,1)\n", 65 | "x_grid = np.repeat(np.arange(640)[None,:],480,0)\n", 66 | "\n", 67 | "size_of_disp_error = 1.0 #0.125\n", 68 | "\n", 69 | "def img_to_pts(img):\n", 70 | " Z = img/factor\n", 71 | " X = (x_grid-cx) * Z /fx\n", 72 | " Y = (y_grid-cy) * Z /fy\n", 73 | " \n", 74 | " d = (fx*baseline)/Z\n", 75 | " zf = (fx*baseline)/np.clip(d+size_of_disp_error,min_disp,max_disp)\n", 76 | " zb = (fx*baseline)/np.clip(d-size_of_disp_error,min_disp,max_disp)\n", 77 | " \n", 78 | " subsample = 4\n", 79 | " step_frac = 0.5 * subsample\n", 80 | " Xp = x_grid-cx+step_frac\n", 81 | " Yp = y_grid-cy+step_frac\n", 82 | " Xm = x_grid-cx-step_frac\n", 83 | " Ym = y_grid-cy-step_frac\n", 84 | " frtptT = np.array(((Xp)/fx,(Yp)/fy,np.ones_like(d)))*zf\n", 85 | " lftptT = np.array(((Xm)/fx,(Yp)/fy,np.ones_like(d)))*Z\n", 86 | " rgtptT = np.array(((Xp)/fx,(Yp)/fy,np.ones_like(d)))*Z\n", 87 | " bckptT = np.array(((Xm)/fx,(Yp)/fy,np.ones_like(d)))*zb\n", 88 | " \n", 89 | " frtptB = np.array(((Xp)/fx,(Ym)/fy,np.ones_like(d)))*zf\n", 90 | " lftptB = np.array(((Xm)/fx,(Ym)/fy,np.ones_like(d)))*Z\n", 91 | " rgtptB = np.array(((Xp)/fx,(Ym)/fy,np.ones_like(d)))*Z\n", 92 | " bckptB = np.array(((Xm)/fx,(Ym)/fy,np.ones_like(d)))*zb\n", 93 | " \n", 94 | " box = np.stack([frtptT,lftptT,rgtptT,bckptT,frtptB,lftptB,rgtptB,bckptB])\n", 95 | " 
print(box.shape)\n", 96 | " \n", 97 | " m = box.mean(0)\n", 98 | " print(m.shape)\n", 99 | "\n", 100 | " m1 = (m.reshape((3,-1)).T)[:,None]\n", 101 | " m2 = m1.reshape((-1,3,1))\n", 102 | " print(m1.shape)\n", 103 | " print(m2.shape)\n", 104 | "\n", 105 | " \n", 106 | " M = m2 @ m1\n", 107 | " print(M.shape)\n", 108 | " print(box.shape)\n", 109 | " \n", 110 | " covar = np.zeros( (640*480,3,3) )\n", 111 | " for pt in box:\n", 112 | " pt = pt.reshape((3,-1)).T\n", 113 | " pt1 = pt.reshape((-1,1,3))\n", 114 | " pt2 = pt.reshape((-1,3,1))\n", 115 | "\n", 116 | " covar += (pt2 @ pt1 - M)\n", 117 | " covar = covar/12\n", 118 | " covar = covar.reshape((480,640,3,3))[::subsample,::subsample,:,:]\n", 119 | " X = X[::subsample,::subsample]\n", 120 | " Y = Y[::subsample,::subsample]\n", 121 | " Z = Z[::subsample,::subsample]\n", 122 | " print(X.shape)\n", 123 | "\n", 124 | " xl = X[Z > 0]\n", 125 | " yl = Y[Z > 0]\n", 126 | " zl = Z[Z > 0]\n", 127 | " c1 = (covar.reshape((120,160,-1))[Z > 0]).reshape( (-1,3,3) )\n", 128 | " n = zl.shape[0]\n", 129 | " print(n,c1.shape)\n", 130 | " \n", 131 | " idx = np.random.randint(n, size=n//16)\n", 132 | "\n", 133 | " return np.vstack([xl,yl,zl]).T, c1\n", 134 | "\n", 135 | "def load_data(file):\n", 136 | " image = Image.open(file) \n", 137 | " pixel = np.array(image)\n", 138 | " return img_to_pts(pixel)\n", 139 | "\n", 140 | "for fl in sorted(os.listdir(full_depth_dir)):\n", 141 | " name,ext = os.path.splitext(fl)\n", 142 | " new_name = name + '.pkl'\n", 143 | " \n", 144 | " clf = GaussianMixture(100)\n", 145 | " data = load_data(os.path.join(full_depth_dir,fl))\n", 146 | " print([_.shape for _ in data])\n", 147 | " t1 = time.time()\n", 148 | " \n", 149 | " n = data[0].shape[0]\n", 150 | " clf.set_areas(np.ones(n))\n", 151 | " clf.set_covars(data[1])\n", 152 | " clf.fit(data[0])\n", 153 | " clf.set_areas(None)\n", 154 | " clf.set_covars(None)\n", 155 | " \n", 156 | " print(time.time()-t1)\n", 157 | " with 
open(os.path.join(full_gmm_dir,new_name),'wb') as fp:\n", 158 | " pickle.dump(clf,fp)\n", 159 | " " 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "frtptB.shape" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [] 177 | } 178 | ], 179 | "metadata": { 180 | "kernelspec": { 181 | "display_name": "Python 3", 182 | "language": "python", 183 | "name": "python3" 184 | }, 185 | "language_info": { 186 | "codemirror_mode": { 187 | "name": "ipython", 188 | "version": 3 189 | }, 190 | "file_extension": ".py", 191 | "mimetype": "text/x-python", 192 | "name": "python", 193 | "nbconvert_exporter": "python", 194 | "pygments_lexer": "ipython3", 195 | "version": "3.6.7" 196 | } 197 | }, 198 | "nbformat": 4, 199 | "nbformat_minor": 2 200 | } 201 | -------------------------------------------------------------------------------- /gen_gmm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pylab import *\n", 10 | "import os\n", 11 | "import sys\n", 12 | "from sklearn.mixture import GaussianMixture\n", 13 | "import pickle\n", 14 | "from PIL import Image" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "dataset_dir = 'rgbd_dataset_freiburg3_long_office_household'\n", 24 | "depth_dir = 'depth'\n", 25 | "gmm_dir = 'gmm_qqvga'\n", 26 | "\n", 27 | "# og \n", 28 | "fx = 525.0 # focal length x\n", 29 | "fy = 525.0 # focal length y\n", 30 | "cx = 319.5 # optical center x\n", 31 | "cy = 239.5 # optical center y\n", 32 | "\n", 33 | "# fri3 \n", 34 | "fx = 535.4 # focal length x\n", 35 | "fy = 539.2 # focal length y\n", 36 | "cx = 
320.1 # optical center x\n", 37 | "cy = 247.6 # optical center y\n", 38 | "factor = 5000 # for the 16-bit PNG files\n", 39 | "# OR: factor = 1 # for the 32-bit float images in the ROS bag files\n", 40 | "\n", 41 | "full_depth_dir = os.path.join(dataset_dir,depth_dir)\n", 42 | "full_gmm_dir = os.path.join(dataset_dir,gmm_dir)\n", 43 | "if not os.path.exists(full_gmm_dir):\n", 44 | " os.mkdir(full_gmm_dir)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "y_grid = np.repeat(np.arange(480)[:,None],640,1)\n", 54 | "x_grid = np.repeat(np.arange(640)[None,:],480,0)\n", 55 | "\n", 56 | "def img_to_pts(img):\n", 57 | " Z = img/factor\n", 58 | " X = (x_grid-cx) * Z /fx\n", 59 | " Y = (y_grid-cy) * Z /fy\n", 60 | " #for v in range(depth_image.height):\n", 61 | " # for u in range(depth_image.width):\n", 62 | " # Z = depth_image[v,u] / factor;\n", 63 | " # X = (u - cx) * Z / fx;\n", 64 | " # Y = (v - cy) * Z / fy;\n", 65 | " X = X[::4,::4]\n", 66 | " Y = Y[::4,::4]\n", 67 | " Z = Z[::4,::4]\n", 68 | " print(X.shape)\n", 69 | "\n", 70 | " xl = X[Z > 0]\n", 71 | " yl = Y[Z > 0]\n", 72 | " zl = Z[Z > 0]\n", 73 | " n = zl.shape[0]\n", 74 | " print(n)\n", 75 | " idx = np.random.randint(n, size=n//16)\n", 76 | "\n", 77 | " return np.vstack([xl,yl,zl]).T\n", 78 | "\n", 79 | "def load_data(file):\n", 80 | " image = Image.open(file) \n", 81 | " pixel = np.array(image)\n", 82 | " return img_to_pts(pixel)\n", 83 | "\n", 84 | "for fl in sorted(os.listdir(full_depth_dir)):\n", 85 | " name,ext = os.path.splitext(fl)\n", 86 | " new_name = name + '.pkl'\n", 87 | " \n", 88 | " clf = GaussianMixture(100)\n", 89 | " data = load_data(os.path.join(full_depth_dir,fl))\n", 90 | " t1 = time.time()\n", 91 | " clf.fit(data)\n", 92 | " print(time.time()-t1)\n", 93 | " with open(os.path.join(full_gmm_dir,new_name),'wb') as fp:\n", 94 | " pickle.dump(clf,fp)\n", 95 | " " 96 | ] 97 | }, 98 | { 99 | "cell_type": 
"code", 100 | "execution_count": null, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "name,ext" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [] 113 | } 114 | ], 115 | "metadata": { 116 | "kernelspec": { 117 | "display_name": "Python 3", 118 | "language": "python", 119 | "name": "python3" 120 | }, 121 | "language_info": { 122 | "codemirror_mode": { 123 | "name": "ipython", 124 | "version": 3 125 | }, 126 | "file_extension": ".py", 127 | "mimetype": "text/x-python", 128 | "name": "python", 129 | "nbconvert_exporter": "python", 130 | "pygments_lexer": "ipython3", 131 | "version": "3.6.7" 132 | } 133 | }, 134 | "nbformat": 4, 135 | "nbformat_minor": 2 136 | } 137 | -------------------------------------------------------------------------------- /gmm_fit.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.stats import multivariate_normal as mvn_pdf 3 | 4 | import matplotlib.pyplot as plt 5 | from sklearn.cluster import MiniBatchKMeans 6 | from mixture import GaussianMixture 7 | import pymesh 8 | 9 | def compute_gmm(x,k=2,w=None,iter_max=10000,i_tol=1e-9,e_tol=1e-3): 10 | km = MiniBatchKMeans(k) 11 | if w is None: 12 | w = np.ones(x.shape[0]) 13 | km.fit(x) 14 | labels = km.labels_ 15 | mu = km.cluster_centers_ 16 | 17 | sigma = [] 18 | for i in range(k): 19 | new_sigma = np.identity(x.shape[1])*i_tol 20 | pts = x[(km.labels_ == i),:] 21 | sigma.append(np.cov(pts,rowvar=False) + new_sigma) 22 | sigma = np.array(sigma) 23 | #mu = x[np.random.choice(x.shape[0],k,replace=False),:] 24 | #sigma = np.array([(x.std()/k)*np.identity(x.shape[1]) for _ in range(k)]) 25 | pi = np.ones(shape=k)/k 26 | 27 | mu_prev = mu.copy() 28 | for iternum in range(iter_max): 29 | # e-step 30 | gamma = np.zeros(shape=(x.shape[0],k)) 31 | g2 = np.zeros(shape=(x.shape[0],k)) 32 | 33 | for i in range(k): 34 | 
gamma_i = pi[i]*mvn_pdf.pdf(x,mean=mu[i],cov=sigma[i]) 35 | gamma[:,i] = gamma_i 36 | g2[:,i] = pi[i]*mvn_pdf.pdf(x,mean=mu[i],cov=sigma[i]) 37 | g2 = np.copy(gamma) 38 | #gamma = w.reshape((-1,1)) * gamma 39 | gamma = gamma/gamma.sum(1,keepdims=True) 40 | g2 = g2/g2.sum(1,keepdims=True) 41 | print(np.linalg.norm(gamma-g2),gamma.shape,w.shape) 42 | # m-step 43 | for i in range(k): 44 | new_mu = np.zeros(x.shape[1]) 45 | for j in range(x.shape[0]): 46 | new_mu += gamma[j,i] * x[j,:] 47 | new_mu /= gamma.sum(0)[0] 48 | mu[i,:] = new_mu 49 | new_sigma = np.identity(x.shape[1])*i_tol 50 | for j in range(x.shape[0]): 51 | xv = x[j,:][:,np.newaxis] 52 | xm = new_mu[:,np.newaxis] 53 | xd = xv - xm 54 | new_sigma += gamma[j,i] * (xd @ xd.T) 55 | new_sigma /= gamma.sum(0)[0] 56 | sigma[i,:,:] = new_sigma 57 | pi = gamma.mean(0) 58 | if ((mu-mu_prev)**2).sum() < e_tol: 59 | break 60 | mu_prev = mu.copy() 61 | print(iternum) 62 | return mu,sigma,pi 63 | 64 | mesh0 = pymesh.load_mesh("bunny/bun_zipper_1000_1.ply") 65 | mesh1 = pymesh.load_mesh("bunny/bun_zipper_992_1.ply") 66 | 67 | mesh2 = pymesh.load_mesh("bunny/bun_zipper_pts_1000_1.ply") 68 | mesh3 = pymesh.load_mesh("bunny/bun_zipper_pds_1000_1.ply") 69 | #mesh3 = pymesh.load_mesh("bunny/bun_zipper_res4_pds.ply") 70 | mesh4 = pymesh.load_mesh("bunny/bun_zipper_50k.ply") 71 | 72 | def get_tri_covar(tris): 73 | covars = [] 74 | for face in tris: 75 | A = face[0][:,None] 76 | B = face[1][:,None] 77 | C = face[2][:,None] 78 | M = (A+B+C)/3 79 | covars.append(A @ A.T + B @ B.T + C @ C.T - 3* M @ M.T) 80 | return np.array(covars)*(1/12.0) 81 | 82 | def get_centroids(mesh): 83 | # obtain a vertex for each face index 84 | face_vert = mesh.vertices[mesh.faces.reshape(-1),:].reshape((mesh.faces.shape[0],3,-1)) 85 | # face_vert is size (faces,3(one for each vert), 3(one for each dimension)) 86 | centroids = face_vert.sum(1)/3.0 87 | ABAC = face_vert[:,1:3,:] - face_vert[:,0:1,:] 88 | areas = 
np.linalg.norm(np.cross(ABAC[:,0,:],ABAC[:,1,:]),axis=1)/2.0 89 | return centroids, areas,face_vert 90 | 91 | coma,aa,fv1 = get_centroids(mesh0) 92 | com,a,fv2 = get_centroids(mesh1) 93 | 94 | a = a/a.min() 95 | aa = aa/aa.min() 96 | 97 | data_covar1 = get_tri_covar(fv1) 98 | 99 | data_covar2 = get_tri_covar(fv2) 100 | 101 | #verts = mesh2.vertices#[np.random.choice(mesh2.vertices.shape[0], com.shape[0], replace=False), :] 102 | #res = compute_gmm(com,100,a) 103 | #res2 = compute_gmm(verts,100) 104 | #raise 105 | with open('bunny_fit_monday_subsamples_25.log','w') as fout: 106 | for km in [6,12,25,50,100,200,400]: 107 | for init in ['random','kmeans']: 108 | for exp_n in range(10): 109 | gm0 = GaussianMixture(km,init_params=init,max_iter=25,tol=1e-12); gm0.set_covars(data_covar1); gm0.set_areas(aa); gm0.fit(coma); gm0.set_covars(None); gm0.set_areas(None) 110 | gm1 = GaussianMixture(km,init_params=init,max_iter=25,tol=1e-12); gm1.set_covars(data_covar2); gm1.set_areas(a); gm1.fit(com); gm1.set_covars(None); gm1.set_areas(None) 111 | gm2 = GaussianMixture(km,init_params=init,max_iter=25,tol=1e-12); gm2.fit(mesh3.vertices) 112 | gm3 = GaussianMixture(km,init_params=init,max_iter=25,tol=1e-12); gm3.fit(mesh2.vertices) 113 | 114 | #gm3 = GaussianMixture(100); gm3.fit(mesh4.vertices) 115 | #print(coma.shape[0],com.shape[0],mesh2.vertices.shape[0],mesh3.vertices.shape[0]) 116 | s0 = gm0.score(mesh4.vertices) 117 | s1 = gm1.score(mesh4.vertices) 118 | s2 = gm2.score(mesh4.vertices) 119 | s3 = gm3.score(mesh4.vertices) 120 | 121 | #print(gm0.n_iter_,gm1.n_iter_) 122 | #print(gm2.n_iter_,gm3.n_iter_) 123 | #print(s0,s1) 124 | #print(s2,s3) 125 | fout.write("{},{},{},{},{}\n".format(km,init,'0',s0,gm0.n_iter_)) 126 | fout.write("{},{},{},{},{}\n".format(km,init,'1',s1,gm1.n_iter_)) 127 | fout.write("{},{},{},{},{}\n".format(km,init,'2',s2,gm2.n_iter_)) 128 | fout.write("{},{},{},{},{}\n".format(km,init,'3',s3,gm3.n_iter_)) 129 | 130 | 
#print(gm1.aic(mesh4.vertices),gm2.aic(mesh4.vertices))#,gm3.aic(mesh4.vertices)) 131 | 132 | #print((res[2] >0).sum(),(res2[2] >0).sum()) 133 | if False: 134 | import matplotlib.pyplot as plt 135 | import mpl_toolkits.mplot3d as m3d 136 | ax = m3d.Axes3D(plt.figure()) 137 | ax.scatter(com[:,0],com[:,1],com[:,2],s=a) 138 | ax.scatter(verts[:,0],verts[:,1],verts[:,2],s=20) 139 | plt.show() 140 | -------------------------------------------------------------------------------- /gmm_fit2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.stats import multivariate_normal as mvn_pdf 3 | 4 | import matplotlib.pyplot as plt 5 | from mixture import GaussianMixture 6 | import pymesh 7 | 8 | mesh0 = pymesh.load_mesh("bunny/bun_zipper_res4.ply") 9 | #mesh3 = pymesh.load_mesh("bunny/bun_zipper_res4_pds.ply") 10 | mesh4 = pymesh.load_mesh("bunny/bun_zipper_res4_25k_pds.ply") 11 | 12 | def get_centroids(mesh): 13 | # obtain a vertex for each face index 14 | face_vert = mesh.vertices[mesh.faces.reshape(-1),:].reshape((mesh.faces.shape[0],3,-1)) 15 | # face_vert is size (faces,3(one for each vert), 3(one for each dimension)) 16 | centroids = face_vert.sum(1)/3.0 17 | ABAC = face_vert[:,1:3,:] - face_vert[:,0:1,:] 18 | areas = np.linalg.norm(np.cross(ABAC[:,0,:],ABAC[:,1,:]),axis=1)/2.0 19 | return centroids, areas 20 | 21 | def get_tri_covar(tris): 22 | covars = [] 23 | for face in tris: 24 | A = face[0][:,None] 25 | B = face[1][:,None] 26 | C = face[2][:,None] 27 | M = (A+B+C)/3 28 | covars.append(A @ A.T + B @ B.T + C @ C.T - 3* M @ M.T) 29 | return np.array(covars)*(1/12.0) 30 | com,a = get_centroids(mesh0) 31 | face_vert = mesh0.vertices[mesh0.faces.reshape(-1),:].reshape((mesh0.faces.shape[0],3,-1)) 32 | data_covar = get_tri_covar(face_vert) 33 | 34 | with open('bunny_1k_com_verts_tuesday_25.log','w') as fout: 35 | for km in [6,12,25,50,100,200,400]: 36 | for init in ['kmeans','random']: 37 | for exp_n in 
range(10): 38 | gm3 = GaussianMixture(km,init_params=init,max_iter=25,tol=1e-12); gm3.set_covars(data_covar); gm3.set_areas(a); gm3.fit(com); gm3.set_covars(None); gm3.set_areas(None) 39 | gm0 = GaussianMixture(km,init_params=init,max_iter=25,tol=1e-12); gm0.set_areas(a); gm0.fit(com); gm0.set_areas(None) 40 | gm1 = GaussianMixture(km,init_params=init,max_iter=25,tol=1e-12); gm1.fit(com) 41 | gm2 = GaussianMixture(km,init_params=init,max_iter=25,tol=1e-12); gm2.fit(mesh0.vertices) 42 | #gm3 = GaussianMixture(km,init_params=init,max_iter=25,tol=1e-4); gm3.fit(mesh2.vertices) 43 | 44 | #gm3 = GaussianMixture(100); gm3.fit(mesh4.vertices) 45 | #print(coma.shape[0],com.shape[0],mesh2.vertices.shape[0],mesh3.vertices.shape[0]) 46 | s0 = gm0.score(mesh4.vertices) 47 | s1 = gm1.score(mesh4.vertices) 48 | s2 = gm2.score(mesh4.vertices) 49 | s3 = gm3.score(mesh4.vertices) 50 | 51 | #print(gm0.n_iter_,gm1.n_iter_) 52 | #print(gm2.n_iter_,gm3.n_iter_) 53 | #print(s0,s1) 54 | #print(s2,s3) 55 | print('.',end='',flush=True) 56 | fout.write("{},{},{},{},{}\n".format(km,init,'0',s0,gm0.n_iter_)) 57 | fout.write("{},{},{},{},{}\n".format(km,init,'1',s1,gm1.n_iter_)) 58 | fout.write("{},{},{},{},{}\n".format(km,init,'2',s2,gm2.n_iter_)) 59 | fout.write("{},{},{},{},{}\n".format(km,init,'3',s3,gm3.n_iter_)) 60 | print('') 61 | #print(gm1.aic(mesh4.vertices),gm2.aic(mesh4.vertices))#,gm3.aic(mesh4.vertices)) 62 | 63 | #print((res[2] >0).sum(),(res2[2] >0).sum()) 64 | if False: 65 | import matplotlib.pyplot as plt 66 | import mpl_toolkits.mplot3d as m3d 67 | ax = m3d.Axes3D(plt.figure()) 68 | ax.scatter(com[:,0],com[:,1],com[:,2],s=a) 69 | ax.scatter(verts[:,0],verts[:,1],verts[:,2],s=20) 70 | plt.show() 71 | -------------------------------------------------------------------------------- /gmm_fit3.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.stats import multivariate_normal as mvn_pdf 3 | 4 | import 
matplotlib.pyplot as plt 5 | from cluster import MiniBatchKMeans 6 | from mixture import GaussianMixture 7 | import pymesh 8 | from scipy.special import logsumexp 9 | 10 | mesh0 = pymesh.load_mesh("bunny/bun_zipper_res4.ply") 11 | #mesh3 = pymesh.load_mesh("bunny/bun_zipper_res4_pds.ply") 12 | #mesh4 = pymesh.load_mesh("bunny/bun_zipper_res4_25k_pds.ply") 13 | mesh4 = pymesh.load_mesh("bunny/bun_zipper_res4_sds.ply") 14 | 15 | def get_centroids(mesh): 16 | # obtain a vertex for each face index 17 | face_vert = mesh.vertices[mesh.faces.reshape(-1),:].reshape((mesh.faces.shape[0],3,-1)) 18 | # face_vert is size (faces,3(one for each vert), 3(one for each dimension)) 19 | centroids = face_vert.sum(1)/3.0 20 | ABAC = face_vert[:,1:3,:] - face_vert[:,0:1,:] 21 | areas = np.linalg.norm(np.cross(ABAC[:,0,:],ABAC[:,1,:]),axis=1)/2.0 22 | return centroids, areas 23 | 24 | com,a = get_centroids(mesh0) 25 | face_vert = mesh0.vertices[mesh0.faces.reshape(-1),:].reshape((mesh0.faces.shape[0],3,-1)) 26 | 27 | #gm3 = GaussianMixture(100,init_params='kmeans'); gm3.set_triangles(face_vert); gm3.fit(com); gm3.set_triangles(None) 28 | gm3 = GaussianMixture(1,init_params='kmeans',tol=1e-4,max_iter=100); gm3.fit(mesh4.vertices) 29 | 30 | def tri_loss(gmm,faces_and_verts): 31 | centroids = face_vert.mean(1) 32 | ABAC = face_vert[:,1:3,:] - face_vert[:,0:1,:] 33 | areas = np.linalg.norm(np.cross(ABAC[:,0,:],ABAC[:,1,:]),axis=1)/2.0 34 | #areas = areas/areas.sum() 35 | total = 0.0 36 | #for idx, face in enumerate(faces_and_verts): 37 | #face is 3 faces with 3d locs 38 | #center = face.mean(0) 39 | #centr2 = centroids[idx,:] 40 | A = faces_and_verts[:,0,:] 41 | B = faces_and_verts[:,1,:] 42 | C = faces_and_verts[:,2,:] 43 | #m = center.reshape((-1,1)) 44 | #thing = np.zeros(gmm.weights_.shape) 45 | thing = np.zeros((faces_and_verts.shape[0],gmm.weights_.shape[0])) 46 | 47 | i = 0 48 | #things = 49 | weights = np.zeros(thing.shape) 50 | for mu, s, si, pi in 
zip(gmm.means_,gmm.covariances_,gmm.precisions_,gmm.weights_): 51 | weights[:,i] = mvn_pdf.pdf(centroids,mu,s) 52 | #print(mvn_pdf.pdf(points,mu,s).shape,weights.shape) 53 | i+=1 54 | row_sums = weights.sum(axis=1) 55 | #print(row_sums.shape) 56 | weights = weights / row_sums[:, np.newaxis] 57 | i=0 58 | 59 | for mu, s, si, pi in zip(gmm.means_,gmm.covariances_,gmm.precisions_,gmm.weights_): 60 | res = 0.0 61 | dev = (centroids - mu) 62 | 63 | res = 0.0 64 | res -= 0.5 * np.log(2*np.pi) *3 65 | res -= 0.5 * np.log(np.linalg.det(s)) 66 | t1 = (dev.dot(si)*dev).sum(1) 67 | t2 = (A.dot(si)*A + B.dot(si)*B + C.dot(si)*C - 3*centroids.dot(si)*centroids).sum(1) 68 | #print("T1\t",t1.sum(),t1.min(),t1.max(),t1.mean()) 69 | 70 | #print("T2\t",t2.sum(),t2.min(),t2.max(),t2.mean()) 71 | res -= 0.5 * (t1 + (1.0/12.0) * t2) 72 | total += ((res + np.log(pi))).sum() 73 | thing[:,i] = ((res+ np.log(pi))) 74 | i+=1 75 | #total += thing.sum()*#areas[idx]#logsumexp(thing)*areas[idx] 76 | return logsumexp(thing,axis=1).mean()#.sum()/areas.sum()#.mean()#/points.shape[0] 77 | #return total/areas.sum()#faces_and_verts.shape[0] 78 | def tri_loss_lb(gmm,faces_and_verts): 79 | centroids = face_vert.mean(1) 80 | ABAC = face_vert[:,1:3,:] - face_vert[:,0:1,:] 81 | areas = np.linalg.norm(np.cross(ABAC[:,0,:],ABAC[:,1,:]),axis=1)/2.0 82 | #areas = areas/areas.sum() 83 | total = 0.0 84 | #for idx, face in enumerate(faces_and_verts): 85 | #face is 3 faces with 3d locs 86 | #center = face.mean(0) 87 | #centr2 = centroids[idx,:] 88 | A = faces_and_verts[:,0,:] 89 | B = faces_and_verts[:,1,:] 90 | C = faces_and_verts[:,2,:] 91 | #m = center.reshape((-1,1)) 92 | #thing = np.zeros(gmm.weights_.shape) 93 | thing = np.zeros((faces_and_verts.shape[0],gmm.weights_.shape[0])) 94 | 95 | i = 0 96 | #things = 97 | weights = np.zeros(thing.shape) 98 | for mu, s, si, pi in zip(gmm.means_,gmm.covariances_,gmm.precisions_,gmm.weights_): 99 | weights[:,i] = mvn_pdf.pdf(centroids,mu,s) 100 | 
#print(mvn_pdf.pdf(points,mu,s).shape,weights.shape) 101 | i+=1 102 | row_sums = weights.sum(axis=1) 103 | #print(row_sums.shape) 104 | weights = weights / row_sums[:, np.newaxis] 105 | i=0 106 | 107 | for mu, s, si, pi in zip(gmm.means_,gmm.covariances_,gmm.precisions_,gmm.weights_): 108 | res = 0.0 109 | dev = (centroids - mu) 110 | 111 | res = 0.0 112 | res -= 0.5 * np.log(2*np.pi) *3 113 | res -= 0.5 * np.log(np.linalg.det(s)) 114 | t1 = (dev.dot(si)*dev).sum(1) 115 | t2 = (A.dot(si)*A + B.dot(si)*B + C.dot(si)*C - 3*centroids.dot(si)*centroids).sum(1) 116 | #print("T1\t",t1.sum(),t1.min(),t1.max(),t1.mean()) 117 | 118 | #print("T2\t",t2.sum(),t2.min(),t2.max(),t2.mean()) 119 | res -= 0.5 * (t1 + (1.0/12.0) * t2) 120 | total += ((res + np.log(pi))).sum() 121 | thing[:,i] = ((res+ np.log(pi)))*areas#/areas.mean() 122 | i+=1 123 | #total += thing.sum()*#areas[idx]#logsumexp(thing)*areas[idx] 124 | #thing = thing*weights 125 | 126 | return np.sum(thing,axis=1).sum()/areas.sum()#.sum()/areas.sum()#.mean()#/points.shape[0] 127 | #return total/areas.sum()#faces_and_verts.shape[0] 128 | def pt_loss(gmm,points): 129 | total = 0.0 130 | #for p in points: 131 | thing = np.zeros((points.shape[0],gmm.weights_.shape[0])) 132 | i = 0 133 | #things = 134 | for mu, s, si, pi in zip(gmm.means_,gmm.covariances_,gmm.precisions_,gmm.weights_): 135 | res = 0.0 136 | dev = points-mu 137 | 138 | res = 0.0 139 | res -= 0.5 * np.log(2*np.pi) *3 140 | res -= 0.5 * np.log(np.linalg.det(s)) 141 | t1 = (dev.dot(si) * dev).sum(1) 142 | res -= 0.5 * t1 143 | #total += (res + np.log(pi)).sum() 144 | thing[:,i] = (res + np.log(pi)) 145 | i+=1 146 | #total += thing.sum()#logsumexp(thing) 147 | return logsumexp(thing,axis=1).mean()#logsumexp(thing,axis=1).mean()#/points.shape[0] 148 | def pt_loss_lb(gmm,points): 149 | total = 0.0 150 | #for p in points: 151 | thing = np.zeros((points.shape[0],gmm.weights_.shape[0])) 152 | i = 0 153 | #things = 154 | weights = np.zeros(thing.shape) 155 | for mu, 
s, si, pi in zip(gmm.means_,gmm.covariances_,gmm.precisions_,gmm.weights_): 156 | weights[:,i] = mvn_pdf.pdf(points,mu,s) 157 | #print(mvn_pdf.pdf(points,mu,s).shape,weights.shape) 158 | i+=1 159 | row_sums = weights.sum(axis=1) 160 | #print(row_sums.shape) 161 | weights = weights / row_sums[:, np.newaxis] 162 | i=0 163 | for mu, s, si, pi in zip(gmm.means_,gmm.covariances_,gmm.precisions_,gmm.weights_): 164 | res = 0.0 165 | dev = points-mu 166 | 167 | res = 0.0 168 | res -= 0.5 * np.log(2*np.pi) *3 169 | res -= 0.5 * np.log(np.linalg.det(s)) 170 | t1 = (dev.dot(si) * dev).sum(1) 171 | res -= 0.5 * t1 172 | #total += (res + np.log(pi)).sum() 173 | thing[:,i] = (res + np.log(pi)) 174 | i+=1 175 | #total += thing.sum()#logsumexp(thing) 176 | #thing = thing*weights 177 | return thing.sum(axis=1).mean()#logsumexp(thing,axis=1).mean()#/points.shape[0] 178 | def com_loss(gmm,points,areas): 179 | total = 0.0 180 | #for p in points: 181 | thing = np.zeros((points.shape[0],gmm.weights_.shape[0])) 182 | i = 0 183 | #things = 184 | for mu, s, si, pi in zip(gmm.means_,gmm.covariances_,gmm.precisions_,gmm.weights_): 185 | res = 0.0 186 | dev = points-mu 187 | 188 | res = 0.0 189 | res -= 0.5 * np.log(2*np.pi) *3 190 | res -= 0.5 * np.log(np.linalg.det(s)) 191 | t1 = (dev.dot(si) * dev).sum(1) 192 | res -= 0.5 * t1 193 | #total += (res + np.log(pi)).sum() 194 | thing[:,i] = (res + np.log(pi))*(areas/areas.mean()) 195 | i+=1 196 | #total += thing.sum()#logsumexp(thing) 197 | return logsumexp(thing,axis=1).mean()#/points.shape[0] 198 | def com_loss_lb(gmm,points,areas): 199 | total = 0.0 200 | #for p in points: 201 | thing = np.zeros((points.shape[0],gmm.weights_.shape[0])) 202 | i = 0 203 | #things = 204 | for mu, s, si, pi in zip(gmm.means_,gmm.covariances_,gmm.precisions_,gmm.weights_): 205 | res = 0.0 206 | dev = points-mu 207 | 208 | res = 0.0 209 | res -= 0.5 * np.log(2*np.pi) *3 210 | res -= 0.5 * np.log(np.linalg.det(s)) 211 | t1 = (dev.dot(si) * dev).sum(1) 212 | res -= 
0.5 * t1 213 | #total += (res + np.log(pi)).sum() 214 | thing[:,i] = (res + np.log(pi))*(areas/areas.mean()) 215 | i+=1 216 | #total += thing.sum()#logsumexp(thing) 217 | return np.sum(thing,axis=1).mean()#/points.shape[0] 218 | 219 | if True: 220 | tl = tri_loss_lb 221 | cl = com_loss_lb 222 | pl = pt_loss_lb 223 | print("OMG") 224 | else: 225 | tl = tri_loss 226 | cl = com_loss 227 | pl = pt_loss 228 | 229 | print("tri\t",tl(gm3,face_vert),'\t',0) 230 | print("mpt\t",pl(gm3,com),'\t',0) 231 | print('com\t',cl(gm3,com,a),'\t',0) 232 | #print("ptLB\t",pt_loss_lb(gm3,com)) 233 | 234 | #print("spt\t",gm3.score(com)) 235 | 236 | #print("sp\t",gm3._estimate_weighted_log_prob(com).sum()) 237 | 238 | 239 | for pn in np.logspace(1,np.log10(mesh4.vertices.shape[0]*.95),10): 240 | scores = [] 241 | for itern in range(10): 242 | ptsn = np.random.choice(range(mesh4.vertices.shape[0]),int(pn),replace=False) 243 | scores.append(pl(gm3,mesh4.vertices[ptsn,:])) 244 | #scores.append(gm3._estimate_weighted_log_prob(mesh4.vertices[ptsn,:]).sum()/pn) 245 | scores = np.array(scores) 246 | print(ptsn.shape[0],'\t',scores.mean(),'\t',scores.std()) 247 | #print(" ",gm3.score(mesh4.vertices)) 248 | 249 | #print(" ",gm3._estimate_weighted_log_prob(mesh4.vertices).sum()/mesh4.vertices.shape[0]) 250 | -------------------------------------------------------------------------------- /gmm_fit_extra.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.stats import multivariate_normal as mvn_pdf 3 | 4 | import matplotlib.pyplot as plt 5 | from mixture import GaussianMixture 6 | import pymesh 7 | 8 | if False: 9 | mesh0 = pymesh.load_mesh("bunny/bun_zipper_res4.ply") 10 | mesh4 = pymesh.load_mesh("bunny/bun_zipper_res4_25k_pds.ply") 11 | elif False: 12 | mesh0 = pymesh.load_mesh("arma/Armadillo_1000.ply") 13 | mesh4 = pymesh.load_mesh("arma/Armadillo_25k.ply") 14 | elif False: 15 | mesh0 = 
pymesh.load_mesh("dragon/dragon_recon/dragon_vrip_1000.ply") 16 | mesh4 = pymesh.load_mesh("dragon/dragon_recon/dragon_vrip_25k.ply") 17 | elif False: 18 | mesh0 = pymesh.load_mesh("happy/happy_recon/happy_vrip_1000.ply") 19 | mesh4 = pymesh.load_mesh("happy/happy_recon/happy_vrip_25k.ply") 20 | elif True: 21 | mesh0 = pymesh.load_mesh("lucy/lucy_1000.ply") 22 | mesh4 = pymesh.load_mesh("lucy/lucy_25k.ply") 23 | def get_centroids(mesh): 24 | # obtain a vertex for each face index 25 | face_vert = mesh.vertices[mesh.faces.reshape(-1),:].reshape((mesh.faces.shape[0],3,-1)) 26 | # face_vert is size (faces,3(one for each vert), 3(one for each dimension)) 27 | centroids = face_vert.sum(1)/3.0 28 | ABAC = face_vert[:,1:3,:] - face_vert[:,0:1,:] 29 | areas = np.linalg.norm(np.cross(ABAC[:,0,:],ABAC[:,1,:]),axis=1)/2.0 30 | return centroids, areas 31 | 32 | def get_tri_covar(tris): 33 | covars = [] 34 | for face in tris: 35 | A = face[0][:,None] 36 | B = face[1][:,None] 37 | C = face[2][:,None] 38 | M = (A+B+C)/3 39 | covars.append(A @ A.T + B @ B.T + C @ C.T - 3* M @ M.T) 40 | return np.array(covars)*(1/12.0) 41 | com,a = get_centroids(mesh0) 42 | face_vert = mesh0.vertices[mesh0.faces.reshape(-1),:].reshape((mesh0.faces.shape[0],3,-1)) 43 | data_covar = get_tri_covar(face_vert) 44 | 45 | with open('lucy.log','w') as fout: 46 | for km in [100]: 47 | for init in ['kmeans']: 48 | for exp_n in range(10): 49 | gm3 = GaussianMixture(km,init_params=init,max_iter=25,tol=1e-12); gm3.set_covars(data_covar); gm3.set_areas(a); gm3.fit(com); gm3.set_covars(None); gm3.set_areas(None) 50 | gm0 = GaussianMixture(km,init_params=init,max_iter=25,tol=1e-12); gm0.set_areas(a); gm0.fit(com); gm0.set_areas(None) 51 | gm1 = GaussianMixture(km,init_params=init,max_iter=25,tol=1e-12); gm1.fit(com) 52 | gm2 = GaussianMixture(km,init_params=init,max_iter=25,tol=1e-12); gm2.fit(mesh0.vertices) 53 | #gm3 = GaussianMixture(km,init_params=init,max_iter=25,tol=1e-4); gm3.fit(mesh2.vertices) 54 | 55 | 
#gm3 = GaussianMixture(100); gm3.fit(mesh4.vertices) 56 | #print(coma.shape[0],com.shape[0],mesh2.vertices.shape[0],mesh3.vertices.shape[0]) 57 | s0 = gm0.score(mesh4.vertices) 58 | s1 = gm1.score(mesh4.vertices) 59 | s2 = gm2.score(mesh4.vertices) 60 | s3 = gm3.score(mesh4.vertices) 61 | 62 | #print(gm0.n_iter_,gm1.n_iter_) 63 | #print(gm2.n_iter_,gm3.n_iter_) 64 | #print(s0,s1) 65 | #print(s2,s3) 66 | print('.',end='',flush=True) 67 | fout.write("{},{},{},{},{}\n".format(km,init,'0',s0,gm0.n_iter_)) 68 | fout.write("{},{},{},{},{}\n".format(km,init,'1',s1,gm1.n_iter_)) 69 | fout.write("{},{},{},{},{}\n".format(km,init,'2',s2,gm2.n_iter_)) 70 | fout.write("{},{},{},{},{}\n".format(km,init,'3',s3,gm3.n_iter_)) 71 | print('') 72 | #print(gm1.aic(mesh4.vertices),gm2.aic(mesh4.vertices))#,gm3.aic(mesh4.vertices)) 73 | 74 | #print((res[2] >0).sum(),(res2[2] >0).sum()) 75 | if False: 76 | import matplotlib.pyplot as plt 77 | import mpl_toolkits.mplot3d as m3d 78 | ax = m3d.Axes3D(plt.figure()) 79 | ax.scatter(com[:,0],com[:,1],com[:,2],s=a) 80 | ax.scatter(verts[:,0],verts[:,1],verts[:,2],s=20) 81 | plt.show() 82 | -------------------------------------------------------------------------------- /graph_scales.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import pandas as pd" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "df = pd.read_csv(\"bunny_fit_extra7.log\",names=['k','init','model','l','i','scale'])\n", 21 | "mdf = df.groupby(['init','model','scale','k']).mean()\n", 22 | "sdf = df.groupby(['init','model','scale','k']).std()\n" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": 
{}, 29 | "outputs": [], 30 | "source": [ 31 | "fig = plt.figure(figsize=(10,10))\n", 32 | "for ii, init in enumerate(['kmeans','random']):\n", 33 | " for model in range(1):\n", 34 | " for scale in sorted(df.scale.unique()):\n", 35 | " ls = '--' if model < 1 else '-'\n", 36 | " if model == 0:\n", 37 | " label = 'Triangles scale: e^{0:.1f}'.format(np.log(scale))\n", 38 | " if model == 1:\n", 39 | " label = 'Center of Mass'\n", 40 | " if model == 2:\n", 41 | " label = 'Vertices'\n", 42 | " ldf = mdf.loc[(init,model,scale),]\n", 43 | " x = np.array(ldf.index)\n", 44 | "\n", 45 | " y = ldf.values[:,0]\n", 46 | " error = 2*sdf.loc[(init,model,scale),].values[:,0]\n", 47 | "\n", 48 | " ax = plt.subplot(2,2,1+ii*2)\n", 49 | " plt.plot(x,y,ls=ls,label=label)\n", 50 | " plt.fill_between(x, y-error, y+error,alpha=0.3)\n", 51 | " plt.grid(True)\n", 52 | " plt.xlabel('number of mixtures (m)')\n", 53 | " plt.title('{} initialization'.format(init))\n", 54 | " plt.ylabel('likelihood of ground truth ')\n", 55 | " plt.ylim(2,9)\n", 56 | " ax.set_xscale(\"log\", nonposx='clip')\n", 57 | "\n", 58 | " y = ldf.values[:,1]\n", 59 | " error = sdf.loc[(init,model,scale),].values[:,1]\n", 60 | "\n", 61 | " ax = plt.subplot(2,2,2+ii*2)\n", 62 | " plt.plot(x,y,ls=ls,label=label)\n", 63 | " plt.fill_between(x, y-error, y+error,alpha=0.15)\n", 64 | "\n", 65 | " ax.set_xscale(\"log\", nonposx='clip')\n", 66 | " plt.grid(True)\n", 67 | " plt.xlabel('number of mixtures (m)')\n", 68 | " plt.title('{} initialization'.format(init))\n", 69 | " plt.ylabel('iterations until convergence')\n", 70 | " #plt.subplot(2,2,3)\n", 71 | " #plt.subplot(2,2,4)\n", 72 | "for i in range(4):\n", 73 | " plt.subplot(2,2,1+i)\n", 74 | " plt.legend()\n", 75 | "fig.subplots_adjust(wspace=0.3,hspace=0.3)\n", 76 | "plt.savefig('graphs7.pdf')" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "mdf.loc[('kmeans',0),].values, 
np.array(mdf.loc[('kmeans',0),].index)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "df2" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [] 103 | } 104 | ], 105 | "metadata": { 106 | "kernelspec": { 107 | "display_name": "Python 3", 108 | "language": "python", 109 | "name": "python3" 110 | }, 111 | "language_info": { 112 | "codemirror_mode": { 113 | "name": "ipython", 114 | "version": 3 115 | }, 116 | "file_extension": ".py", 117 | "mimetype": "text/x-python", 118 | "name": "python", 119 | "nbconvert_exporter": "python", 120 | "pygments_lexer": "ipython3", 121 | "version": "3.6.7" 122 | } 123 | }, 124 | "nbformat": 4, 125 | "nbformat_minor": 2 126 | } 127 | -------------------------------------------------------------------------------- /ground_truth_testing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pylab import *\n", 10 | "import os\n", 11 | "import sys\n", 12 | "from sklearn.mixture import GaussianMixture\n", 13 | "import pickle\n", 14 | "from PIL import Image" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "dataset_dir = 'rgbd_dataset_freiburg3_long_office_household'\n", 24 | "depth_dir = 'depth'\n", 25 | "gmm_dir = 'gmm_qqvga'\n", 26 | "\n", 27 | "# og \n", 28 | "fx = 525.0 # focal length x\n", 29 | "fy = 525.0 # focal length y\n", 30 | "cx = 319.5 # optical center x\n", 31 | "cy = 239.5 # optical center y\n", 32 | "\n", 33 | "# fri3 \n", 34 | "fx = 535.4 # focal length x\n", 35 | "fy = 539.2 # focal length y\n", 36 | "cx = 320.1 # optical center x\n", 37 | "cy = 247.6 # optical center y\n", 38 | 
"factor = 5000 # for the 16-bit PNG files\n", 39 | "# OR: factor = 1 # for the 32-bit float images in the ROS bag files\n", 40 | "\n", 41 | "full_depth_dir = os.path.join(dataset_dir,depth_dir)\n" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "y_grid = np.repeat(np.arange(480)[:,None],640,1)\n", 51 | "x_grid = np.repeat(np.arange(640)[None,:],480,0)\n", 52 | "\n", 53 | "def img_to_pts(img):\n", 54 | " Z = img/factor\n", 55 | " X = (x_grid-cx) * Z /fx\n", 56 | " Y = (y_grid-cy) * Z /fy\n", 57 | " #for v in range(depth_image.height):\n", 58 | " # for u in range(depth_image.width):\n", 59 | " # Z = depth_image[v,u] / factor;\n", 60 | " # X = (u - cx) * Z / fx;\n", 61 | " # Y = (v - cy) * Z / fy;\n", 62 | " X = X[::4,::4]\n", 63 | " Y = Y[::4,::4]\n", 64 | " Z = Z[::4,::4]\n", 65 | "\n", 66 | " xl = X[Z > 0]\n", 67 | " yl = Y[Z > 0]\n", 68 | " zl = Z[Z > 0]\n", 69 | " n = zl.shape[0]\n", 70 | " idx = np.random.randint(n, size=n//16)\n", 71 | "\n", 72 | " return np.vstack([xl,yl,zl]).T\n", 73 | "\n", 74 | "def load_data(file):\n", 75 | " image = Image.open(file) \n", 76 | " pixel = np.array(image)\n", 77 | " return img_to_pts(pixel)\n", 78 | "\n", 79 | "dataset = []\n", 80 | "names = []\n", 81 | "for fl in sorted(os.listdir(full_depth_dir)):\n", 82 | " name,ext = os.path.splitext(fl)\n", 83 | " data = load_data(os.path.join(full_depth_dir,fl))\n", 84 | " dataset.append(data)\n", 85 | " names.append(name)\n" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "%matplotlib notebook\n", 95 | "from mpl_toolkits.mplot3d import Axes3D\n", 96 | "from mpl_toolkits.mplot3d.art3d import Poly3DCollection\n", 97 | "import matplotlib.pyplot as plt\n", 98 | "from matplotlib.colors import LightSource" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": {}, 
105 | "outputs": [], 106 | "source": [ 107 | "name_times = [float(_) for _ in names]" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "import pandas as pd\n", 117 | "gt = pd.read_csv('rgbd_dataset_freiburg3_long_office_household-groundtruth.txt',' ',comment='#',header=None,names='timestamp tx ty tz qx qy qz qw'.split(' '))\n", 118 | "def row_to_vec(row):\n", 119 | " a = np.array(row)\n", 120 | " return a[1:4], a[4:]\n", 121 | "from scipy.spatial.distance import cdist, pdist\n", 122 | "time_dists = cdist(np.array(gt.timestamp)[:,None],np.array(name_times)[:,None])\n", 123 | "matches = time_dists.argmin(axis=0)\n", 124 | "\n", 125 | "start_idx = matches[25]\n", 126 | "end_idx = matches[125]\n", 127 | "\n", 128 | "frm = row_to_vec(gt.iloc[start_idx])\n", 129 | "to = row_to_vec(gt.iloc[end_idx])\n", 130 | "print(gt.timestamp[end_idx] - gt.timestamp[start_idx])\n", 131 | "\n", 132 | "start_idx,end_idx\n" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "for i in range(len(matches)-1):\n", 142 | " f1idx = matches[i]\n", 143 | " f2idx = matches[i+1]\n", 144 | " f1 = row_to_vec(gt.iloc[f1idx])\n", 145 | " f2 = row_to_vec(gt.iloc[f2idx])\n", 146 | "\n", 147 | " print(f1)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "frm,to" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "fig = plt.figure()\n", 166 | "ax = Axes3D(fig)\n", 167 | "\n", 168 | "start_frame = dataset[start_idx][np.random.randint(0,dataset[start_idx].shape[0],1000)]\n", 169 | "end_frame = dataset[end_idx][np.random.randint(0,dataset[end_idx].shape[0],1000)]\n", 170 | "\n", 171 | 
"ax.scatter(start_frame[:,0],start_frame[:,1],start_frame[:,2])\n", 172 | "ax.scatter(end_frame[:,0],end_frame[:,1],end_frame[:,2])\n", 173 | "ax.set_xlim(-1.5,1.5)\n", 174 | "ax.set_ylim(-1.5,1.5)\n", 175 | "ax.set_zlim(0,3)" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "import transforms3d\n", 185 | "fig = plt.figure()\n", 186 | "ax = Axes3D(fig)\n", 187 | "\n", 188 | "start_frame = dataset[start_idx][np.random.randint(0,dataset[start_idx].shape[0],1000)]\n", 189 | "end_frame = dataset[end_idx][np.random.randint(0,dataset[end_idx].shape[0],1000)]\n", 190 | "\n", 191 | "r1 = transforms3d.quaternions.quat2mat(to[1])\n", 192 | "r2 = transforms3d.quaternions.quat2mat(frm[1])\n", 193 | "r1 = np.identity(3)\n", 194 | "r2 = np.identity(3)\n", 195 | "\n", 196 | "end_frame = end_frame + (to[0]-frm[0])\n", 197 | "end_frame = (r1 @ r2.T @(end_frame - end_frame.mean(0)).T).T + end_frame.mean(0)\n", 198 | "\n", 199 | "ax.scatter(start_frame[:,0],start_frame[:,1],start_frame[:,2])\n", 200 | "ax.scatter(end_frame[:,0],end_frame[:,1],end_frame[:,2])\n", 201 | "ax.set_xlim(-2,2)\n", 202 | "ax.set_ylim(-2,2)\n", 203 | "ax.set_zlim(0.5,4.5)" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [] 212 | } 213 | ], 214 | "metadata": { 215 | "kernelspec": { 216 | "display_name": "Python 3", 217 | "language": "python", 218 | "name": "python3" 219 | }, 220 | "language_info": { 221 | "codemirror_mode": { 222 | "name": "ipython", 223 | "version": 3 224 | }, 225 | "file_extension": ".py", 226 | "mimetype": "text/x-python", 227 | "name": "python", 228 | "nbconvert_exporter": "python", 229 | "pygments_lexer": "ipython3", 230 | "version": "3.6.7" 231 | } 232 | }, 233 | "nbformat": 4, 234 | "nbformat_minor": 2 235 | } 236 | -------------------------------------------------------------------------------- 
/icp_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pymesh\n", 10 | "import numpy as np\n", 11 | "import transforms3d\n", 12 | "from scipy.spatial.distance import cdist,pdist\n", 13 | "from mpl_toolkits.mplot3d import Axes3D\n", 14 | "\n", 15 | "import matplotlib.pyplot as plt\n", 16 | "import matplotlib\n" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "mesh_pts = pymesh.load_mesh(\"bunny/bun_zipper_res4_sds.ply\")\n", 26 | "mesh0 = pymesh.load_mesh(\"bunny/bun_zipper_res4.ply\")\n", 27 | "SAMPLE_PTS = 453\n", 28 | "full_points = mesh0.vertices\n", 29 | "t = np.random.rand(3)*0.1 - 0.05\n", 30 | "angles = np.random.rand(3)*30 - 15\n", 31 | "angles *= np.pi/180.0\n", 32 | "#angles *= 0 \n", 33 | "M = transforms3d.euler.euler2mat(angles[0],angles[1],angles[2])\n", 34 | "true_q = transforms3d.quaternions.mat2quat(M)\n", 35 | "indices = np.random.randint(0,full_points.shape[0],SAMPLE_PTS)\n", 36 | "samples= full_points#[indices]\n", 37 | "source = (samples-samples.mean(0)) @ M + samples.mean(0) + t" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "icp_t = np.zeros(3)\n", 47 | "R = np.identity(3)\n", 48 | "source2 = np.copy(source)\n", 49 | "prev_err = 100000000\n", 50 | "indices2 = np.random.randint(0,full_points.shape[0],SAMPLE_PTS)\n", 51 | "samples_for_icp = np.copy(samples) #full_points[indices2]\n", 52 | "flag = True\n", 53 | "icp_q = np.array([1,0,0,0])\n", 54 | "for icp_iter in range(50):\n", 55 | " fig = plt.figure()\n", 56 | " ax = fig.add_subplot(111, projection='3d')\n", 57 | " ax.scatter(samples[:,0],samples[:,1],samples[:,2],label='orig',alpha=0.5)\n", 58 | "\n", 59 | " 
ax.scatter(source[:,0],source[:,1],source[:,2],label='trans',alpha=0.5) \n", 60 | " result = (source - icp_t) \n", 61 | " result = (result - result.mean(0)) @ R.T + result.mean(0)\n", 62 | " ax.scatter(result[:,0],result[:,1],result[:,2],label='registered',alpha=0.8,s=4)\n", 63 | " plt.title('{:.3f} {:.3f} {:.3f}'.format(icp_q.dot(transforms3d.quaternions.qconjugate(true_q)),icp_q.dot(true_q),np.linalg.norm(icp_t-t)) )\n", 64 | " plt.legend()\n", 65 | " plt.show()\n", 66 | " \n", 67 | " dist = cdist(source2,samples_for_icp)\n", 68 | " sample_idx = np.argmin(dist,1)\n", 69 | " matched_pts = samples_for_icp[sample_idx]\n", 70 | " it = source2.mean(0) - matched_pts.mean(0)\n", 71 | " if flag:\n", 72 | " idx2 = np.argmin(dist,0)\n", 73 | " matched2 = source2[idx2]\n", 74 | " it += matched2.mean(0) - samples_for_icp.mean(0)\n", 75 | "\n", 76 | " H = (source2-source2.mean(0)).T @ (matched_pts-matched_pts.mean(0))\n", 77 | " if flag:\n", 78 | " H2 = (matched2-matched2.mean(0)).T @ (samples_for_icp-samples_for_icp.mean(0))\n", 79 | " H2 *= source2.shape[0]/samples_for_icp.shape[0]\n", 80 | " H = H + H2\n", 81 | " u,s,vt = np.linalg.svd(H)\n", 82 | " rotmat = vt.T @ np.diag([1,1,np.linalg.det(vt.T @ u.T)]) @ u.T\n", 83 | " rotmat = rotmat.T\n", 84 | " #print(rotmat,'\\n',M)\n", 85 | " #print(it,'\\n',t)\n", 86 | "\n", 87 | " source2 = (source2 - source2.mean(0)) @ rotmat + source2.mean(0) - it \n", 88 | " err = np.linalg.norm(source2-matched_pts,axis=1)\n", 89 | " err = err.mean()\n", 90 | " #print(err)\n", 91 | " #print(np.diag(cdist(source2,matched_pts)).mean(),len(matched_pts))\n", 92 | " if np.linalg.norm(err-prev_err) < 1e-6:\n", 93 | " break\n", 94 | " prev_err = err\n", 95 | " icp_t += it\n", 96 | " R = R @ rotmat.T\n", 97 | " #print(it)\n", 98 | " #print(rotmat)\n", 99 | "\n", 100 | " icp_q = transforms3d.quaternions.mat2quat(R)\n", 101 | " icp_t = icp_t\n", 102 | "\n" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | 
"metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "t,icp_t" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "%matplotlib notebook\n", 121 | "fig = plt.figure()\n", 122 | "ax = fig.add_subplot(111, projection='3d')\n", 123 | "ax.scatter(samples[:,0],samples[:,1],samples[:,2],label='orig',alpha=0.5)\n", 124 | "\n", 125 | "ax.scatter(source[:,0],source[:,1],source[:,2],label='trans',alpha=0.5) \n", 126 | "result = (source - icp_t) \n", 127 | "result = (result - result.mean(0)) @ R.T + result.mean(0)\n", 128 | "ax.scatter(result[:,0],result[:,1],result[:,2],label='registered',alpha=0.8,s=4)\n", 129 | "plt.title('{:.3f} {:.3f} {:.3f}'.format(icp_q.dot(transforms3d.quaternions.qconjugate(true_q)),icp_q.dot(true_q),np.linalg.norm(icp_t-t)) )\n", 130 | "plt.legend()\n", 131 | "plt.show()\n", 132 | " " 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "res = source\n", 142 | "res = (res - res.mean(0)) @ M.T + res.mean(0) - t\n", 143 | "res" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "source2" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "np.diag(cdist(res,matched_pts)).mean()" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "np.diag(cdist(matched_pts,samples)).mean()" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": null, 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [ 179 | "%matplotlib notebook\n", 180 | "# 'GTK3Agg', 'GTK3Cairo', 'MacOSX', 'nbAgg', 'Qt4Agg', 'Qt4Cairo', 'Qt5Agg', 'Qt5Cairo', 'TkAgg', 'TkCairo', 
'WebAgg', 'WX', 'WXAgg', 'WXCairo', 'agg', 'cairo', 'pdf', 'pgf', 'ps', 'svg', 'template'" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "from functools import partial\n", 190 | "import matplotlib.pyplot as plt\n", 191 | "from mpl_toolkits.mplot3d import Axes3D\n", 192 | "from pycpd import rigid_registration\n", 193 | "import numpy as np\n", 194 | "import time\n", 195 | "\n", 196 | "def visualize(iteration, error, X, Y, ax):\n", 197 | " plt.cla()\n", 198 | " ax.scatter(X[:,0], X[:,1], X[:,2], color='red', label='Target')\n", 199 | " ax.scatter(Y[:,0], Y[:,1], Y[:,2], color='blue', label='Source')\n", 200 | " ax.text2D(0.87, 0.92, 'Iteration: {:d}\\nError: {:06.4f}'.format(iteration, error), horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, fontsize='x-large')\n", 201 | " ax.legend(loc='upper left', fontsize='x-large')\n", 202 | " plt.draw()\n", 203 | " plt.pause(0.001)\n", 204 | "\n", 205 | "fig = plt.figure()\n", 206 | "ax = fig.add_subplot(111, projection='3d')\n", 207 | "callback = partial(visualize, ax=ax)\n", 208 | "\n", 209 | "reg = rigid_registration(**{ 'X': source, 'Y': samples_for_icp })\n", 210 | "reg.register(callback)\n", 211 | "plt.show()\n", 212 | "\n" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "reg = rigid_registration(**{ 'X': source, 'Y': samples_for_icp, 'max_iterations':500,'tolerance':1e-8 })\n" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": null, 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [ 230 | " TY, (s_reg, R_reg, t_reg) = reg.register()" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": null, 236 | "metadata": {}, 237 | "outputs": [], 238 | "source": [ 239 | "R_reg" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | 
"execution_count": null, 245 | "metadata": {}, 246 | "outputs": [], 247 | "source": [ 248 | "s_reg" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": null, 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "TY" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "t_reg,t,np.linalg.norm(t_reg-t)/0.05" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": null, 272 | "metadata": {}, 273 | "outputs": [], 274 | "source": [ 275 | "M" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": null, 281 | "metadata": {}, 282 | "outputs": [], 283 | "source": [ 284 | "samples= full_points\n", 285 | "samples_mean = samples.mean(0)\n", 286 | "centered_points = samples - samples_mean\n", 287 | "source = centered_points @ M + samples_mean+ t" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": null, 293 | "metadata": {}, 294 | "outputs": [], 295 | "source": [ 296 | "np.linalg.norm( (source-source.mean(0))@M.T +source.mean(0) - t - samples,axis=1).sum()" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": null, 302 | "metadata": {}, 303 | "outputs": [], 304 | "source": [ 305 | "samples_mean" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": null, 311 | "metadata": {}, 312 | "outputs": [], 313 | "source": [ 314 | "samples_mean.shape" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": null, 320 | "metadata": {}, 321 | "outputs": [], 322 | "source": [] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": null, 327 | "metadata": {}, 328 | "outputs": [], 329 | "source": [] 330 | } 331 | ], 332 | "metadata": { 333 | "kernelspec": { 334 | "display_name": "Python 3", 335 | "language": "python", 336 | "name": "python3" 337 | }, 338 | "language_info": { 339 | "codemirror_mode": { 340 | 
"name": "ipython", 341 | "version": 3 342 | }, 343 | "file_extension": ".py", 344 | "mimetype": "text/x-python", 345 | "name": "python", 346 | "nbconvert_exporter": "python", 347 | "pygments_lexer": "ipython3", 348 | "version": "3.6.7" 349 | } 350 | }, 351 | "nbformat": 4, 352 | "nbformat_minor": 2 353 | } 354 | -------------------------------------------------------------------------------- /likelihood.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import numpy as np" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "fp1 = pd.read_csv('full_probs.log',delimiter='\\t',names=['label','mu','s'])\n", 21 | "fps = pd.read_csv('full_probs_single.log',delimiter='\\t',names=['label','mu','s'])\n", 22 | "lb1 = pd.read_csv('lower_bounds.log',delimiter='\\t',names=['label','mu','s'])\n", 23 | "lbs = pd.read_csv('lower_bounds_single.log',delimiter='\\t',names=['label','mu','s'])\n", 24 | "dfs = [fp1,lb1]" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "fig = plt.figure(figsize=(12,8))\n", 41 | "def pname(s):\n", 42 | " try: \n", 43 | " int(s)\n", 44 | " return \"Points = \" + s\n", 45 | " except ValueError:\n", 46 | " if s == 'tri':\n", 47 | " return 'Triangle'\n", 48 | " elif s == 'mpt':\n", 49 | " return 'Centroid'\n", 50 | " elif s == 'com':\n", 51 | " return 'Area × Centroid'\n", 52 | " else:\n", 53 | " return s\n", 54 | "\n", 55 | "for idx,df in enumerate(dfs[-1:]):\n", 56 | " ind = np.arange(1,df.shape[0]+1)\n", 
57 | " # inferno summer magma\n", 58 | " ax = plt.gca()#plt.subplot(2,1,idx+1)\n", 59 | " colors =[_['color'] for _ in list(plt.rcParams['axes.prop_cycle'])[:3]]\n", 60 | " colors = colors + [plt.cm.inferno(i/float(df.shape[0]-1)) for i in range(df.shape[0]-3)]\n", 61 | " #print(colors)\n", 62 | " ax.bar(np.arange(1, df.shape[0]+1), -df.mu, yerr=df.s*2, align='center', \n", 63 | " alpha=0.66, ecolor='black', capsize=7, color=colors)\n", 64 | " #if idx != 0:\n", 65 | " ax.set_xticklabels([pname(_) for _ in df.label],rotation=75,fontdict={'fontsize':14})\n", 66 | " #else:\n", 67 | " # ax.set_xticklabels([])\n", 68 | " ax.set_xticks(ind)\n", 69 | " if idx == 0:\n", 70 | " ax.set_ylabel('average negative log-likelihood')\n", 71 | " else:\n", 72 | " ax.set_ylabel('average of all mixture & sample log-likelihoods')\n", 73 | " #ax.set_ylim((df.mu-2.5*df.s).min(),(df.mu+2.5*df.s).max())\n", 74 | " plt.hlines(-df.iloc[0,1],1,ind.max()+0.5,linestyles='--')\n", 75 | " #ax.set_ylim([0, 100])\n", 76 | " #ax.set_ylabel('Percent usage')\n", 77 | " #ax.set_title('System Monitor')\n", 78 | " ax.set_ylim([30000,45000])\n", 79 | "#plt.tight_layout()\n", 80 | "fig.suptitle(\"Numerical evaluation of likelihood expressions\\n for a GMM (k=50) fit to 100,000 points\", fontsize=20)\n", 81 | "fig.tight_layout(rect=[0, 0.03, 1, 0.9])\n", 82 | "fig.savefig(\"likelihood3.pdf\")" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "###### " 90 | ] 91 | } 92 | ], 93 | "metadata": { 94 | "kernelspec": { 95 | "display_name": "Python 3", 96 | "language": "python", 97 | "name": "python3" 98 | }, 99 | "language_info": { 100 | "codemirror_mode": { 101 | "name": "ipython", 102 | "version": 3 103 | }, 104 | "file_extension": ".py", 105 | "mimetype": "text/x-python", 106 | "name": "python", 107 | "nbconvert_exporter": "python", 108 | "pygments_lexer": "ipython3", 109 | "version": "3.6.7" 110 | } 111 | }, 112 | "nbformat": 4, 113 | "nbformat_minor": 2 114 | } 115 
| -------------------------------------------------------------------------------- /llcomp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.stats import multivariate_normal as mvn_pdf 3 | 4 | import matplotlib.pyplot as plt 5 | from cluster import MiniBatchKMeans 6 | from mixture import GaussianMixture 7 | import pymesh 8 | from scipy.special import logsumexp 9 | 10 | mesh0 = pymesh.load_mesh("bunny/bun_zipper_res4.ply") 11 | #mesh3 = pymesh.load_mesh("bunny/bun_zipper_res4_pds.ply") 12 | #mesh4 = pymesh.load_mesh("bunny/bun_zipper_res4_25k_pds.ply") 13 | mesh4 = pymesh.load_mesh("bunny/bun_zipper_res4_sds.ply") 14 | 15 | def get_centroids(mesh): 16 | # obtain a vertex for each face index 17 | face_vert = mesh.vertices[mesh.faces.reshape(-1),:].reshape((mesh.faces.shape[0],3,-1)) 18 | # face_vert is size (faces,3(one for each vert), 3(one for each dimension)) 19 | centroids = face_vert.sum(1)/3.0 20 | ABAC = face_vert[:,1:3,:] - face_vert[:,0:1,:] 21 | areas = np.linalg.norm(np.cross(ABAC[:,0,:],ABAC[:,1,:]),axis=1)/2.0 22 | return centroids, areas 23 | 24 | com,a = get_centroids(mesh0) 25 | face_vert = mesh0.vertices[mesh0.faces.reshape(-1),:].reshape((mesh0.faces.shape[0],3,-1)) 26 | 27 | #gm3 = GaussianMixture(100,init_params='kmeans'); gm3.set_triangles(face_vert); gm3.fit(com); gm3.set_triangles(None) 28 | gm3 = GaussianMixture(20,init_params='random',tol=1e-2,max_iter=5); gm3.fit(mesh4.vertices) 29 | 30 | def pt_loss(gmm,points): 31 | total = 0.0 32 | thing = np.zeros((points.shape[0],gmm.weights_.shape[0])) 33 | i = 0 34 | for mu, s, si, pi in zip(gmm.means_,gmm.covariances_,gmm.precisions_,gmm.weights_): 35 | thing[:,i] = pi*mvn_pdf.pdf(points,mu,s) 36 | i+=1 37 | return np.log(thing.sum(1)).sum()#logsumexp(thing,axis=1).mean()#/points.shape[0] 38 | def pt_loss_lb(gmm,points): 39 | total = 0.0 40 | #for p in points: 41 | thing = np.zeros((points.shape[0],gmm.weights_.shape[0])) 42 | i 
= 0 43 | #things = 44 | weights = np.zeros(thing.shape) 45 | for mu, s, si, pi in zip(gmm.means_,gmm.covariances_,gmm.precisions_,gmm.weights_): 46 | weights[:,i] = mvn_pdf.pdf(points,mu,s) 47 | i+=1 48 | row_sums = weights.sum(axis=1) 49 | #print(row_sums.shape) 50 | weights = weights / row_sums[:, np.newaxis] 51 | i=0 52 | print(weights.shape) 53 | for mu, s, si, pi in zip(gmm.means_,gmm.covariances_,gmm.precisions_,gmm.weights_): 54 | res = 0.0 55 | dev = points-mu 56 | 57 | res = 0.0 58 | res -= 0.5 * np.log(2*np.pi) *3 59 | res -= 0.5 * np.log(np.linalg.det(s)) 60 | t1 = (dev.dot(si) * dev).sum(1) 61 | res -= 0.5 * t1 62 | #total += (res + np.log(pi)).sum() 63 | thing[:,i] = (res + np.log(pi)) 64 | i+=1 65 | #total += thing.sum()#logsumexp(thing) 66 | #thing = thing*weights 67 | return ((thing-np.log(weights))*weights).sum()#logsumexp(thing,axis=1).mean()#/points.shape[0] 68 | 69 | 70 | for pn in np.logspace(1,np.log10(mesh4.vertices.shape[0]*.95),10): 71 | scores = [] 72 | scores2 = [] 73 | for itern in range(2): 74 | ptsn = np.random.choice(range(mesh4.vertices.shape[0]),int(pn),replace=False) 75 | scores.append(pt_loss(gm3,mesh4.vertices[ptsn,:])) 76 | scores2.append(pt_loss_lb(gm3,mesh4.vertices[ptsn,:])) 77 | #scores.append(gm3._estimate_weighted_log_prob(mesh4.vertices[ptsn,:]).sum()/pn) 78 | scores = np.array(scores) 79 | scores2 = np.array(scores2) 80 | 81 | print(ptsn.shape[0],'\t',scores.mean(),'\t',scores2.mean()) 82 | #print(" ",gm3.score(mesh4.vertices)) 83 | 84 | #print(" ",gm3._estimate_weighted_log_prob(mesh4.vertices).sum()/mesh4.vertices.shape[0]) 85 | -------------------------------------------------------------------------------- /mixture/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The :mod:`sklearn.mixture` module implements mixture modeling algorithms. 
3 | """ 4 | 5 | from .gaussian_mixture import GaussianMixture 6 | from .bayesian_mixture import BayesianGaussianMixture 7 | 8 | 9 | __all__ = ['GaussianMixture', 10 | 'BayesianGaussianMixture'] 11 | -------------------------------------------------------------------------------- /mixture/base.py: -------------------------------------------------------------------------------- 1 | """Base class for mixture models.""" 2 | 3 | # Author: Wei Xue 4 | # Modified by Thierry Guillemot 5 | # License: BSD 3 clause 6 | 7 | from __future__ import print_function 8 | 9 | import warnings 10 | from abc import ABCMeta, abstractmethod 11 | from time import time 12 | 13 | import numpy as np 14 | 15 | from sklearn import cluster 16 | from sklearn.base import BaseEstimator 17 | from sklearn.base import DensityMixin 18 | from sklearn.externals import six 19 | from sklearn.exceptions import ConvergenceWarning 20 | from sklearn.utils import check_array, check_random_state 21 | from sklearn.utils.fixes import logsumexp 22 | 23 | 24 | def _check_shape(param, param_shape, name): 25 | """Validate the shape of the input parameter 'param'. 26 | 27 | Parameters 28 | ---------- 29 | param : array 30 | 31 | param_shape : tuple 32 | 33 | name : string 34 | """ 35 | param = np.array(param) 36 | if param.shape != param_shape: 37 | raise ValueError("The parameter '%s' should have the shape of %s, " 38 | "but got %s" % (name, param_shape, param.shape)) 39 | 40 | 41 | def _check_X(X, n_components=None, n_features=None, ensure_min_samples=1): 42 | """Check the input data X. 
43 | 44 | Parameters 45 | ---------- 46 | X : array-like, shape (n_samples, n_features) 47 | 48 | n_components : int 49 | 50 | Returns 51 | ------- 52 | X : array, shape (n_samples, n_features) 53 | """ 54 | X = check_array(X, dtype=[np.float64, np.float32], 55 | ensure_min_samples=ensure_min_samples) 56 | if n_components is not None and X.shape[0] < n_components: 57 | raise ValueError('Expected n_samples >= n_components ' 58 | 'but got n_components = %d, n_samples = %d' 59 | % (n_components, X.shape[0])) 60 | if n_features is not None and X.shape[1] != n_features: 61 | raise ValueError("Expected the input data X have %d features, " 62 | "but got %d features" 63 | % (n_features, X.shape[1])) 64 | return X 65 | 66 | 67 | class BaseMixture(six.with_metaclass(ABCMeta, DensityMixin, BaseEstimator)): 68 | """Base class for mixture models. 69 | 70 | This abstract class specifies an interface for all mixture classes and 71 | provides basic common methods for mixture models. 72 | """ 73 | 74 | def __init__(self, n_components, tol, reg_covar, 75 | max_iter, n_init, init_params, random_state, warm_start, 76 | verbose, verbose_interval): 77 | self.n_components = n_components 78 | self.tol = tol 79 | self.reg_covar = reg_covar 80 | self.max_iter = max_iter 81 | self.n_init = n_init 82 | self.init_params = init_params 83 | self.random_state = random_state 84 | self.warm_start = warm_start 85 | self.verbose = verbose 86 | self.verbose_interval = verbose_interval 87 | 88 | def _check_initial_parameters(self, X): 89 | """Check values of the basic parameters. 
90 | 91 | Parameters 92 | ---------- 93 | X : array-like, shape (n_samples, n_features) 94 | """ 95 | if self.n_components < 1: 96 | raise ValueError("Invalid value for 'n_components': %d " 97 | "Estimation requires at least one component" 98 | % self.n_components) 99 | 100 | if self.tol < 0.: 101 | raise ValueError("Invalid value for 'tol': %.5f " 102 | "Tolerance used by the EM must be non-negative" 103 | % self.tol) 104 | 105 | if self.n_init < 1: 106 | raise ValueError("Invalid value for 'n_init': %d " 107 | "Estimation requires at least one run" 108 | % self.n_init) 109 | 110 | if self.max_iter < 1: 111 | raise ValueError("Invalid value for 'max_iter': %d " 112 | "Estimation requires at least one iteration" 113 | % self.max_iter) 114 | 115 | if self.reg_covar < 0.: 116 | raise ValueError("Invalid value for 'reg_covar': %.5f " 117 | "regularization on covariance must be " 118 | "non-negative" 119 | % self.reg_covar) 120 | 121 | # Check all the parameters values of the derived class 122 | self._check_parameters(X) 123 | 124 | @abstractmethod 125 | def _check_parameters(self, X): 126 | """Check initial parameters of the derived class. 127 | 128 | Parameters 129 | ---------- 130 | X : array-like, shape (n_samples, n_features) 131 | """ 132 | pass 133 | 134 | def _initialize_parameters(self, X, random_state): 135 | """Initialize the model parameters. 136 | 137 | Parameters 138 | ---------- 139 | X : array-like, shape (n_samples, n_features) 140 | 141 | random_state : RandomState 142 | A random number generator instance. 
143 | """ 144 | n_samples, _ = X.shape 145 | 146 | if self.init_params == 'kmeans': 147 | resp = np.zeros((n_samples, self.n_components)) 148 | label = cluster.MiniBatchKMeans(n_clusters=self.n_components, n_init=1,batch_size=150, 149 | random_state=random_state).fit(X).labels_ 150 | resp[np.arange(n_samples), label] = 1 151 | elif self.init_params == 'random': 152 | resp = random_state.rand(n_samples, self.n_components) 153 | resp /= resp.sum(axis=1)[:, np.newaxis] 154 | else: 155 | raise ValueError("Unimplemented initialization method '%s'" 156 | % self.init_params) 157 | 158 | self._initialize(X, resp) 159 | 160 | @abstractmethod 161 | def _initialize(self, X, resp): 162 | """Initialize the model parameters of the derived class. 163 | 164 | Parameters 165 | ---------- 166 | X : array-like, shape (n_samples, n_features) 167 | 168 | resp : array-like, shape (n_samples, n_components) 169 | """ 170 | pass 171 | 172 | def fit(self, X, y=None): 173 | """Estimate model parameters with the EM algorithm. 174 | 175 | The method fits the model ``n_init`` times and sets the parameters with 176 | which the model has the largest likelihood or lower bound. Within each 177 | trial, the method iterates between E-step and M-step for ``max_iter`` 178 | times until the change of likelihood or lower bound is less than 179 | ``tol``, otherwise, a ``ConvergenceWarning`` is raised. 180 | If ``warm_start`` is ``True``, then ``n_init`` is ignored and a single 181 | initialization is performed upon the first call. Upon consecutive 182 | calls, training starts where it left off. 183 | 184 | Parameters 185 | ---------- 186 | X : array-like, shape (n_samples, n_features) 187 | List of n_features-dimensional data points. Each row 188 | corresponds to a single data point. 189 | 190 | Returns 191 | ------- 192 | self 193 | """ 194 | self.fit_predict(X, y) 195 | return self 196 | 197 | def fit_predict(self, X, y=None): 198 | """Estimate model parameters using X and predict the labels for X. 
199 | 200 | The method fits the model n_init times and sets the parameters with 201 | which the model has the largest likelihood or lower bound. Within each 202 | trial, the method iterates between E-step and M-step for `max_iter` 203 | times until the change of likelihood or lower bound is less than 204 | `tol`, otherwise, a `ConvergenceWarning` is raised. After fitting, it 205 | predicts the most probable label for the input data points. 206 | 207 | .. versionadded:: 0.20 208 | 209 | Parameters 210 | ---------- 211 | X : array-like, shape (n_samples, n_features) 212 | List of n_features-dimensional data points. Each row 213 | corresponds to a single data point. 214 | 215 | Returns 216 | ------- 217 | labels : array, shape (n_samples,) 218 | Component labels. 219 | """ 220 | X = _check_X(X, self.n_components, ensure_min_samples=2) 221 | self._check_initial_parameters(X) 222 | 223 | # if we enable warm_start, we will have a unique initialisation 224 | do_init = not(self.warm_start and hasattr(self, 'converged_')) 225 | n_init = self.n_init if do_init else 1 226 | 227 | max_lower_bound = -np.infty 228 | self.converged_ = False 229 | 230 | random_state = check_random_state(self.random_state) 231 | 232 | n_samples, _ = X.shape 233 | for init in range(n_init): 234 | self._print_verbose_msg_init_beg(init) 235 | 236 | if do_init: 237 | self._initialize_parameters(X, random_state) 238 | 239 | lower_bound = (-np.infty if do_init else self.lower_bound_) 240 | 241 | for n_iter in range(1, self.max_iter + 1): 242 | prev_lower_bound = lower_bound 243 | 244 | log_prob_norm, log_resp = self._e_step(X) 245 | self._m_step(X, log_resp) 246 | lower_bound = self._compute_lower_bound( 247 | log_resp, log_prob_norm) 248 | 249 | change = lower_bound - prev_lower_bound 250 | self._print_verbose_msg_iter_end(n_iter, change) 251 | 252 | if False and abs(change) < self.tol: 253 | self.converged_ = True 254 | break 255 | 256 | self._print_verbose_msg_init_end(lower_bound) 257 | 258 | if 
lower_bound > max_lower_bound: 259 | max_lower_bound = lower_bound 260 | best_params = self._get_parameters() 261 | best_n_iter = n_iter 262 | 263 | if not self.converged_: 264 | warnings.warn('Initialization %d did not converge. ' 265 | 'Try different init parameters, ' 266 | 'or increase max_iter, tol ' 267 | 'or check for degenerate data.' 268 | % (init + 1), ConvergenceWarning) 269 | 270 | self._set_parameters(best_params) 271 | self.n_iter_ = best_n_iter 272 | self.lower_bound_ = max_lower_bound 273 | 274 | return log_resp.argmax(axis=1) 275 | 276 | def _e_step(self, X): 277 | """E step. 278 | 279 | Parameters 280 | ---------- 281 | X : array-like, shape (n_samples, n_features) 282 | 283 | Returns 284 | ------- 285 | log_prob_norm : float 286 | Mean of the logarithms of the probabilities of each sample in X 287 | 288 | log_responsibility : array, shape (n_samples, n_components) 289 | Logarithm of the posterior probabilities (or responsibilities) of 290 | the point of each sample in X. 291 | """ 292 | log_prob_norm, log_resp = self._estimate_log_prob_resp(X) 293 | return np.mean(log_prob_norm), log_resp 294 | 295 | @abstractmethod 296 | def _m_step(self, X, log_resp): 297 | """M step. 298 | 299 | Parameters 300 | ---------- 301 | X : array-like, shape (n_samples, n_features) 302 | 303 | log_resp : array-like, shape (n_samples, n_components) 304 | Logarithm of the posterior probabilities (or responsibilities) of 305 | the point of each sample in X. 306 | """ 307 | pass 308 | 309 | @abstractmethod 310 | def _check_is_fitted(self): 311 | pass 312 | 313 | @abstractmethod 314 | def _get_parameters(self): 315 | pass 316 | 317 | @abstractmethod 318 | def _set_parameters(self, params): 319 | pass 320 | 321 | def score_samples(self, X): 322 | """Compute the weighted log probabilities for each sample. 323 | 324 | Parameters 325 | ---------- 326 | X : array-like, shape (n_samples, n_features) 327 | List of n_features-dimensional data points. 
Each row 328 | corresponds to a single data point. 329 | 330 | Returns 331 | ------- 332 | log_prob : array, shape (n_samples,) 333 | Log probabilities of each data point in X. 334 | """ 335 | self._check_is_fitted() 336 | X = _check_X(X, None, self.means_.shape[1]) 337 | 338 | return logsumexp(self._estimate_weighted_log_prob(X), axis=1) 339 | 340 | def score(self, X, y=None): 341 | """Compute the per-sample average log-likelihood of the given data X. 342 | 343 | Parameters 344 | ---------- 345 | X : array-like, shape (n_samples, n_dimensions) 346 | List of n_features-dimensional data points. Each row 347 | corresponds to a single data point. 348 | 349 | Returns 350 | ------- 351 | log_likelihood : float 352 | Log likelihood of the Gaussian mixture given X. 353 | """ 354 | return self.score_samples(X).mean() 355 | 356 | def predict(self, X): 357 | """Predict the labels for the data samples in X using trained model. 358 | 359 | Parameters 360 | ---------- 361 | X : array-like, shape (n_samples, n_features) 362 | List of n_features-dimensional data points. Each row 363 | corresponds to a single data point. 364 | 365 | Returns 366 | ------- 367 | labels : array, shape (n_samples,) 368 | Component labels. 369 | """ 370 | self._check_is_fitted() 371 | X = _check_X(X, None, self.means_.shape[1]) 372 | return self._estimate_weighted_log_prob(X).argmax(axis=1) 373 | 374 | def predict_proba(self, X): 375 | """Predict posterior probability of each component given the data. 376 | 377 | Parameters 378 | ---------- 379 | X : array-like, shape (n_samples, n_features) 380 | List of n_features-dimensional data points. Each row 381 | corresponds to a single data point. 382 | 383 | Returns 384 | ------- 385 | resp : array, shape (n_samples, n_components) 386 | Returns the probability each Gaussian (state) in 387 | the model given each sample. 
388 | """ 389 | self._check_is_fitted() 390 | X = _check_X(X, None, self.means_.shape[1]) 391 | _, log_resp = self._estimate_log_prob_resp(X) 392 | return np.exp(log_resp) 393 | 394 | def sample(self, n_samples=1): 395 | """Generate random samples from the fitted Gaussian distribution. 396 | 397 | Parameters 398 | ---------- 399 | n_samples : int, optional 400 | Number of samples to generate. Defaults to 1. 401 | 402 | Returns 403 | ------- 404 | X : array, shape (n_samples, n_features) 405 | Randomly generated sample 406 | 407 | y : array, shape (nsamples,) 408 | Component labels 409 | 410 | """ 411 | self._check_is_fitted() 412 | 413 | if n_samples < 1: 414 | raise ValueError( 415 | "Invalid value for 'n_samples': %d . The sampling requires at " 416 | "least one sample." % (self.n_components)) 417 | 418 | _, n_features = self.means_.shape 419 | rng = check_random_state(self.random_state) 420 | n_samples_comp = rng.multinomial(n_samples, self.weights_) 421 | 422 | if self.covariance_type == 'full': 423 | X = np.vstack([ 424 | rng.multivariate_normal(mean, covariance, int(sample)) 425 | for (mean, covariance, sample) in zip( 426 | self.means_, self.covariances_, n_samples_comp)]) 427 | elif self.covariance_type == "tied": 428 | X = np.vstack([ 429 | rng.multivariate_normal(mean, self.covariances_, int(sample)) 430 | for (mean, sample) in zip( 431 | self.means_, n_samples_comp)]) 432 | else: 433 | X = np.vstack([ 434 | mean + rng.randn(sample, n_features) * np.sqrt(covariance) 435 | for (mean, covariance, sample) in zip( 436 | self.means_, self.covariances_, n_samples_comp)]) 437 | 438 | y = np.concatenate([np.full(sample, j, dtype=int) 439 | for j, sample in enumerate(n_samples_comp)]) 440 | 441 | return (X, y) 442 | 443 | def _estimate_weighted_log_prob(self, X): 444 | """Estimate the weighted log-probabilities, log P(X | Z) + log weights. 
445 | 446 | Parameters 447 | ---------- 448 | X : array-like, shape (n_samples, n_features) 449 | 450 | Returns 451 | ------- 452 | weighted_log_prob : array, shape (n_samples, n_component) 453 | """ 454 | return self._estimate_log_prob(X) + self._estimate_log_weights() 455 | 456 | @abstractmethod 457 | def _estimate_log_weights(self): 458 | """Estimate log-weights in EM algorithm, E[ log pi ] in VB algorithm. 459 | 460 | Returns 461 | ------- 462 | log_weight : array, shape (n_components, ) 463 | """ 464 | pass 465 | 466 | @abstractmethod 467 | def _estimate_log_prob(self, X): 468 | """Estimate the log-probabilities log P(X | Z). 469 | 470 | Compute the log-probabilities per each component for each sample. 471 | 472 | Parameters 473 | ---------- 474 | X : array-like, shape (n_samples, n_features) 475 | 476 | Returns 477 | ------- 478 | log_prob : array, shape (n_samples, n_component) 479 | """ 480 | pass 481 | 482 | def _estimate_log_prob_resp(self, X): 483 | """Estimate log probabilities and responsibilities for each sample. 484 | 485 | Compute the log probabilities, weighted log probabilities per 486 | component and responsibilities for each sample in X with respect to 487 | the current state of the model. 
488 | 489 | Parameters 490 | ---------- 491 | X : array-like, shape (n_samples, n_features) 492 | 493 | Returns 494 | ------- 495 | log_prob_norm : array, shape (n_samples,) 496 | log p(X) 497 | 498 | log_responsibilities : array, shape (n_samples, n_components) 499 | logarithm of the responsibilities 500 | """ 501 | weighted_log_prob = self._estimate_weighted_log_prob(X) 502 | log_prob_norm = logsumexp(weighted_log_prob, axis=1) 503 | with np.errstate(under='ignore'): 504 | # ignore underflow 505 | log_resp = weighted_log_prob - log_prob_norm[:, np.newaxis] 506 | return log_prob_norm, log_resp 507 | 508 | def _print_verbose_msg_init_beg(self, n_init): 509 | """Print verbose message on initialization.""" 510 | if self.verbose == 1: 511 | print("Initialization %d" % n_init) 512 | elif self.verbose >= 2: 513 | print("Initialization %d" % n_init) 514 | self._init_prev_time = time() 515 | self._iter_prev_time = self._init_prev_time 516 | 517 | def _print_verbose_msg_iter_end(self, n_iter, diff_ll): 518 | """Print verbose message on initialization.""" 519 | if n_iter % self.verbose_interval == 0: 520 | if self.verbose == 1: 521 | print(" Iteration %d" % n_iter) 522 | elif self.verbose >= 2: 523 | cur_time = time() 524 | print(" Iteration %d\t time lapse %.5fs\t ll change %.5f" % ( 525 | n_iter, cur_time - self._iter_prev_time, diff_ll)) 526 | self._iter_prev_time = cur_time 527 | 528 | def _print_verbose_msg_init_end(self, ll): 529 | """Print verbose message on the end of iteration.""" 530 | if self.verbose == 1: 531 | print("Initialization converged: %s" % self.converged_) 532 | elif self.verbose >= 2: 533 | print("Initialization converged: %s\t time lapse %.5fs\t ll %.5f" % 534 | (self.converged_, time() - self._init_prev_time, ll)) 535 | -------------------------------------------------------------------------------- /mixture/tests/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/leonidk/direct_gmm/6735bef868dc3ba061ca544f5e95ce2fed09f70a/mixture/tests/__init__.py -------------------------------------------------------------------------------- /mixture/tests/test_mixture.py: -------------------------------------------------------------------------------- 1 | # Author: Guillaume Lemaitre 2 | # License: BSD 3 clause 3 | 4 | import pytest 5 | import numpy as np 6 | 7 | from sklearn.mixture import GaussianMixture 8 | from sklearn.mixture import BayesianGaussianMixture 9 | 10 | 11 | @pytest.mark.parametrize( 12 | "estimator", 13 | [GaussianMixture(), 14 | BayesianGaussianMixture()] 15 | ) 16 | def test_gaussian_mixture_n_iter(estimator): 17 | # check that n_iter is the number of iteration performed. 18 | rng = np.random.RandomState(0) 19 | X = rng.rand(10, 5) 20 | max_iter = 1 21 | estimator.set_params(max_iter=max_iter) 22 | estimator.fit(X) 23 | assert estimator.n_iter_ == max_iter 24 | -------------------------------------------------------------------------------- /plot_trajec.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pylab import *" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "normal_path = np.loadtxt('normal_traj.txt')[:,1:]\n", 19 | "patch_path = np.loadtxt('patch_traj.txt')[:,1:]\n", 20 | "icp_path = np.loadtxt('icp_traj.txt')[:,1:]\n", 21 | "icp2k_path = np.loadtxt('icp2k_traj.txt')[:,1:]\n", 22 | "\n", 23 | "gt_path = np.loadtxt('rgbd_dataset_freiburg3_long_office_household-groundtruth.txt')[:,1:4]" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "if False:\n", 33 | " gt_path = gt_path - gt_path[0]\n", 34 | " test_path = 
np.loadtxt('gmm_res_normal_nodet.txt')[:,1:4]\n", 35 | " test_path.shape\n", 36 | " test_path = test_path[:,[0,2,1]]\n", 37 | " normal_path = test_path - gt_path[0]\n", 38 | "\n", 39 | " test_path = np.loadtxt('gmm_res_patch_nodet.txt')[:,1:4]\n", 40 | " test_path.shape\n", 41 | " test_path = test_path[:,[0,2,1]]\n", 42 | " patch_path = test_path - gt_path[0]\n" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "from matplotlib import rc\n", 52 | "plt.style.use('fivethirtyeight')\n", 53 | "plt.style.use('seaborn-white')\n", 54 | "plt.rcParams[\"font.family\"] = \"sans-serif\"\n", 55 | "rc('font',**{'family':'sans-serif','sans-serif':['cm']})\n", 56 | "fig = plt.figure(figsize=(10,10))\n", 57 | "\n", 58 | "for lbl,path in zip(['GMM','ICP','Ground Truth'],[patch_path,icp2k_path,gt_path]):\n", 59 | " if 'Truth' in lbl:\n", 60 | " plt.plot(path[:,0],path[:,1],label=lbl,color='k')\n", 61 | " else:\n", 62 | " plt.plot(path[:,0],path[:,1],label=lbl)\n", 63 | "plt.xlabel('x location (m)')\n", 64 | "plt.ylabel('y location (m)')\n", 65 | "plt.legend(frameon=True,edgecolor='k')\n", 66 | "plt.tight_layout()\n", 67 | "\n", 68 | "# Hide the right and top spines\n", 69 | "plt.gca().spines['right'].set_visible(False)\n", 70 | "plt.gca().spines['top'].set_visible(False)\n", 71 | "\n", 72 | "# Only show ticks on the left and bottom spines\n", 73 | "plt.gca().yaxis.set_ticks_position('left')\n", 74 | "plt.gca().xaxis.set_ticks_position('bottom')\n", 75 | "fig.savefig('slam2.pdf', facecolor=fig.get_facecolor(), edgecolor='none')" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "gt_path[0]" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [] 93 | } 94 | ], 95 | "metadata": { 96 | "kernelspec": { 97 | "display_name": 
"Python 3", 98 | "language": "python", 99 | "name": "python3" 100 | }, 101 | "language_info": { 102 | "codemirror_mode": { 103 | "name": "ipython", 104 | "version": 3 105 | }, 106 | "file_extension": ".py", 107 | "mimetype": "text/x-python", 108 | "name": "python", 109 | "nbconvert_exporter": "python", 110 | "pygments_lexer": "ipython3", 111 | "version": "3.6.7" 112 | } 113 | }, 114 | "nbformat": 4, 115 | "nbformat_minor": 2 116 | } 117 | -------------------------------------------------------------------------------- /reg_viz-extra.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pylab import *\n", 10 | "import matplotlib\n", 11 | "import matplotlib.colors" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "names = ['meshk','mesh','vertsk','verts','icp'] #,'areas'\n", 21 | "#data_20k_fullverts/ #data_monday_final_k100/\n", 22 | "for namef in ['arma','bunny','dragon','happy','lucy']:\n", 23 | " res = [np.loadtxt('./{}_{}2.csv'.format(namef,n),delimiter=',') for n in names]\n", 24 | "\n", 25 | "\n", 26 | " angle_errors = []\n", 27 | " for result in res:\n", 28 | " #new_res = np.minimum(2*np.arccos(result[:,0]),2*np.arccos(-result[:,0]))\n", 29 | " new_res = np.clip(2*result[:,0]**2-1,1e-9,1-1e-9)\n", 30 | " new_res =np.arccos(new_res)\n", 31 | " new_res[np.isnan(new_res)] = np.pi\n", 32 | " angle_errors.append(new_res*180.0/np.pi)\n", 33 | " print(namef)\n", 34 | " #print((angle_errors[2].mean()-angle_errors[0].mean())/angle_errors[2].mean() *100)\n", 35 | " #print((res[2][:,1].mean()-res[0][:,1].mean())/res[2][:,1].mean() * 100)\n", 36 | " print((angle_errors[2].mean())/angle_errors[4].mean() *100)\n", 37 | " print((angle_errors[0].mean())/angle_errors[4].mean() *100)\n", 38 | "\n", 39 | " 
print((res[2][:,1].mean())/res[4][:,1].mean() *100)\n", 40 | " print((res[0][:,1].mean())/res[4][:,1].mean() *100)\n", 41 | "\n", 42 | "names[0] = 'mesh\\n(kmeans)'\n", 43 | "names[1] = 'mesh\\n(random)'\n", 44 | "#names[2] = 'mesh (a)\\n(kmeans)'\n", 45 | "names[2] = 'points\\n(kmeans)'\n", 46 | "names[3] = 'points\\n(random)'" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "for x,l in zip(res,names):\n", 56 | " plt.scatter(np.sqrt(x[:,0]**2),x[:,1],label=l,s=100,alpha=0.5)\n", 57 | "plt.legend()\n", 58 | "plt.xlabel('rotation error')\n", 59 | "plt.ylabel('translation error')\n" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "res[2].mean(0)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "from matplotlib import rc\n", 99 | "plt.style.use('fivethirtyeight')\n", 100 | "plt.style.use('seaborn-white')\n", 101 | "#plt.rcParams[\"font.family\"] = \"sans-serif\"\n", 102 | "rc('font',**{'family':'sans-serif','sans-serif':['cm']})\n", 103 | "#plt.rcParams['font.sans-serif'] = ['Helvetica']\n", 104 | "fig = plt.figure(figsize=(15,6))\n", 105 | "plt.subplot(1,2,1)\n", 106 | "model_scale = np.sqrt((0.1513)**2 + (0.1483)**2 + (0.1144)**2)\n", 107 | "\n", 108 | "angle_errors = []\n", 109 | "for result in res:\n", 110 | " #new_res = np.minimum(2*np.arccos(result[:,0]),2*np.arccos(-result[:,0]))\n", 111 | " new_res = 
np.clip(2*result[:,0]**2-1,1e-9,1-1e-9)\n", 112 | " new_res =np.arccos(new_res)\n", 113 | " new_res[np.isnan(new_res)] = np.pi\n", 114 | " angle_errors.append(new_res*180.0/np.pi)\n", 115 | "\n", 116 | "plt.boxplot([_ for _ in angle_errors],labels=names,notch=True,flierprops={'marker':''},medianprops={'color':'k'})\n", 117 | "plt.title('rotation error')\n", 118 | "plt.ylabel('degrees')\n", 119 | "#plt.ylim(0,.05)\n", 120 | "#plt.ylim(top=4,bottom=0)\n", 121 | "plt.ylim(bottom=0)\n", 122 | "for i,err in enumerate(res):\n", 123 | " y = angle_errors[i]\n", 124 | " x = np.random.normal(i+1, 0.05, size=len(y))\n", 125 | " p = plt.plot(x,y,'.',alpha=0.3)\n", 126 | " plt_color = matplotlib.colors.hex2color(p[0].get_color())\n", 127 | " clr2 = tuple(np.array(plt_color)*0.5)\n", 128 | " plt.plot(i+1,y.mean(),'.',c=plt_color,ms=20,alpha=0.8,mec=clr2,lw=1.0)\n", 129 | "plt.ylim(top=5.5)\n", 130 | "plt.grid(True,axis='y')\n", 131 | "plt.subplot(1,2,2)\n", 132 | "\n", 133 | "plt.boxplot([_[:,1]/model_scale * 100 for _ in res],labels=names,notch=True,flierprops={'marker':''},medianprops={'color':'k'})\n", 134 | "#plt.ylim(0,0.01)\n", 135 | "plt.title('position error')\n", 136 | "plt.ylabel('percent of model scale')\n", 137 | "#print([2*np.arccos(abs(_[:,0])).mean() for _ in res])\n", 138 | "print([_.mean() for _ in angle_errors])\n", 139 | "\n", 140 | "print([_[:,1].mean() for _ in res])\n", 141 | "#print([2*np.arccos(abs(_[:,0])).max() for _ in res])\n", 142 | "\n", 143 | "plt.ylim(bottom=0)\n", 144 | "for i,err in enumerate(res):\n", 145 | " y = err[:,1]/model_scale * 100\n", 146 | " x = np.random.normal(i+1, 0.05, size=len(y))\n", 147 | " p = plt.plot(x,y,'.',alpha=0.3)\n", 148 | " plt_color = matplotlib.colors.hex2color(p[0].get_color())\n", 149 | " clr2 = tuple(np.array(plt_color)*0.5)\n", 150 | " plt.plot(i+1,y.mean(),'.',c=plt_color,ms=20,alpha=0.8,mec=clr2,lw=1.0)\n", 151 | "#plt.ylim(top=0.025/0.05)\n", 152 | "plt.ylim(top=1.1)\n", 153 | "plt.grid(True,axis='y')\n", 154 
| "plt.tight_layout()\n", 155 | "\n", 156 | "\n", 157 | "fig.savefig('new-viz-k100_2.pdf', facecolor=fig.get_facecolor(), edgecolor='none')" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [ 166 | "plt.rcParams" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [] 175 | } 176 | ], 177 | "metadata": { 178 | "kernelspec": { 179 | "display_name": "Python 3", 180 | "language": "python", 181 | "name": "python3" 182 | }, 183 | "language_info": { 184 | "codemirror_mode": { 185 | "name": "ipython", 186 | "version": 3 187 | }, 188 | "file_extension": ".py", 189 | "mimetype": "text/x-python", 190 | "name": "python", 191 | "nbconvert_exporter": "python", 192 | "pygments_lexer": "ipython3", 193 | "version": "3.6.7" 194 | } 195 | }, 196 | "nbformat": 4, 197 | "nbformat_minor": 2 198 | } 199 | -------------------------------------------------------------------------------- /reg_viz.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pylab import *\n", 10 | "import matplotlib\n", 11 | "import matplotlib.colors" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "names = ['meshk','mesh','vertsk','verts','icp','cpd'] #,'areas'\n", 21 | "#data_20k_fullverts/ #data_monday_final_k100/\n", 22 | "res = [np.loadtxt('data_sunday_full/{}2.csv'.format(n),delimiter=',') for n in names]\n", 23 | "names[0] = 'mesh\\n(kmeans)'\n", 24 | "names[1] = 'mesh\\n(random)'\n", 25 | "#names[2] = 'mesh (a)\\n(kmeans)'\n", 26 | "names[2] = 'points\\n(kmeans)'\n", 27 | "names[3] = 'points\\n(random)'" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | 
"execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "for x,l in zip(res,names):\n", 37 | " plt.scatter(np.sqrt(x[:,0]**2),x[:,1],label=l,s=100,alpha=0.5)\n", 38 | "plt.legend()\n", 39 | "plt.xlabel('rotation error')\n", 40 | "plt.ylabel('translation error')\n" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "[(min(X),max(X)) for X in [_[:,0] for _ in res]]" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "from matplotlib import rc\n", 59 | "plt.style.use('fivethirtyeight')\n", 60 | "plt.style.use('seaborn-white')\n", 61 | "#plt.rcParams[\"font.family\"] = \"sans-serif\"\n", 62 | "rc('font',**{'family':'sans-serif','sans-serif':['cm']})\n", 63 | "#plt.rcParams['font.sans-serif'] = ['Helvetica']\n", 64 | "fig = plt.figure(figsize=(15,6))\n", 65 | "plt.subplot(1,2,1)\n", 66 | "model_scale = np.sqrt((0.1513)**2 + (0.1483)**2 + (0.1144)**2)\n", 67 | "\n", 68 | "angle_errors = []\n", 69 | "for result in res:\n", 70 | " #new_res = np.minimum(2*np.arccos(result[:,0]),2*np.arccos(-result[:,0]))\n", 71 | " new_res = np.clip(2*result[:,0]**2-1,1e-9,1-1e-9)\n", 72 | " new_res =np.arccos(new_res)\n", 73 | " new_res[np.isnan(new_res)] = np.pi\n", 74 | " angle_errors.append(new_res*180.0/np.pi)\n", 75 | "\n", 76 | "plt.boxplot([_ for _ in angle_errors],labels=names,notch=True,flierprops={'marker':''},medianprops={'color':'k'})\n", 77 | "plt.title('rotation error')\n", 78 | "plt.ylabel('degrees')\n", 79 | "#plt.ylim(0,.05)\n", 80 | "#plt.ylim(top=4,bottom=0)\n", 81 | "plt.ylim(bottom=0)\n", 82 | "for i,err in enumerate(res):\n", 83 | " y = angle_errors[i]\n", 84 | " x = np.random.normal(i+1, 0.05, size=len(y))\n", 85 | " p = plt.plot(x,y,'.',alpha=0.3)\n", 86 | " plt_color = matplotlib.colors.hex2color(p[0].get_color())\n", 87 | " clr2 = 
tuple(np.array(plt_color)*0.5)\n", 88 | " plt.plot(i+1,y.mean(),'.',c=plt_color,ms=20,alpha=0.8,mec=clr2,lw=1.0)\n", 89 | "plt.ylim(top=5.5)\n", 90 | "plt.grid(True,axis='y')\n", 91 | "plt.subplot(1,2,2)\n", 92 | "\n", 93 | "plt.boxplot([_[:,1]/model_scale * 100 for _ in res],labels=names,notch=True,flierprops={'marker':''},medianprops={'color':'k'})\n", 94 | "#plt.ylim(0,0.01)\n", 95 | "plt.title('translation error')\n", 96 | "plt.ylabel('percent of model scale')\n", 97 | "#print([2*np.arccos(abs(_[:,0])).mean() for _ in res])\n", 98 | "print([_.mean() for _ in angle_errors])\n", 99 | "\n", 100 | "print([_[:,1].mean() for _ in res])\n", 101 | "#print([2*np.arccos(abs(_[:,0])).max() for _ in res])\n", 102 | "\n", 103 | "plt.ylim(bottom=0)\n", 104 | "for i,err in enumerate(res):\n", 105 | " y = err[:,1]/model_scale * 100\n", 106 | " x = np.random.normal(i+1, 0.05, size=len(y))\n", 107 | " p = plt.plot(x,y,'.',alpha=0.3)\n", 108 | " plt_color = matplotlib.colors.hex2color(p[0].get_color())\n", 109 | " clr2 = tuple(np.array(plt_color)*0.5)\n", 110 | " plt.plot(i+1,y.mean(),'.',c=plt_color,ms=20,alpha=0.8,mec=clr2,lw=1.0)\n", 111 | "#plt.ylim(top=0.025/0.05)\n", 112 | "plt.ylim(top=1.1)\n", 113 | "plt.grid(True,axis='y')\n", 114 | "plt.tight_layout()\n", 115 | "\n", 116 | "\n", 117 | "fig.savefig('new-viz-k100_2.pdf', facecolor=fig.get_facecolor(), edgecolor='none')" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "plt.rcParams" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [] 135 | } 136 | ], 137 | "metadata": { 138 | "kernelspec": { 139 | "display_name": "Python 3", 140 | "language": "python", 141 | "name": "python3" 142 | }, 143 | "language_info": { 144 | "codemirror_mode": { 145 | "name": "ipython", 146 | "version": 3 147 | }, 148 | "file_extension": ".py", 149 | 
"mimetype": "text/x-python", 150 | "name": "python", 151 | "nbconvert_exporter": "python", 152 | "pygments_lexer": "ipython3", 153 | "version": "3.6.7" 154 | } 155 | }, 156 | "nbformat": 4, 157 | "nbformat_minor": 2 158 | } 159 | -------------------------------------------------------------------------------- /registration_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.stats import multivariate_normal as mvn_pdf 3 | from mpl_toolkits.mplot3d import Axes3D 4 | import matplotlib.pyplot as plt 5 | from mixture import GaussianMixture 6 | from scipy.spatial.distance import cdist,pdist 7 | import pymesh 8 | import pickle 9 | from scipy.special import logsumexp 10 | import scipy.optimize as opt 11 | import transforms3d 12 | from pycpd import rigid_registration 13 | import time 14 | 15 | SAMPLE_NUM = 250 16 | method = None#'CG'#None#'CG'#None#CG' 17 | K = 100 18 | SAMPLE_PTS = 453 # number of vertecies! 19 | ICP_ITERS = 50000 #150 20 | ICP_THRESH = 1e-9 21 | CPD_THRESH = 1e-9 22 | CPD_ITERS = 150 #500 23 | mesh0 = pymesh.load_mesh("bunny/bun_zipper_res4.ply") 24 | mesh_pts = pymesh.load_mesh("bunny/bun_zipper_res4_sds.ply") 25 | #mesh0 = pymesh.load_mesh("bunny/bun_zipper.ply") 26 | #mesh_pts = pymesh.load_mesh("bunny/bun_zipper_50k.ply") 27 | #mesh0 = pymesh.load_mesh("bunny/bun_zipper_1000_1.ply") 28 | #mesh_pts = pymesh.load_mesh("bunny/bun_zipper_50k.ply") 29 | 30 | def get_centroids(mesh): 31 | # obtain a vertex for each face index 32 | face_vert = mesh.vertices[mesh.faces.reshape(-1),:].reshape((mesh.faces.shape[0],3,-1)) 33 | # face_vert is size (faces,3(one for each vert), 3(one for each dimension)) 34 | centroids = face_vert.sum(1)/3.0 35 | #face_vert = ((face_vert.shape[0]/SAMPLE_PTS)*(face_vert.reshape((-1,9))-np.repeat(centroids,3,axis=1)) + np.repeat(centroids,3,axis=1)).reshape((-1,3,3)) 36 | ABAC = face_vert[:,1:3,:] - face_vert[:,0:1,:] 37 | areas = 
np.linalg.norm(np.cross(ABAC[:,0,:],ABAC[:,1,:]),axis=1)/2.0 38 | return centroids, areas 39 | 40 | def get_tri_covar(tris): 41 | covars = [] 42 | for face in tris: 43 | A = face[0][:,None] 44 | B = face[1][:,None] 45 | C = face[2][:,None] 46 | M = (A+B+C)/3 47 | covars.append(A @ A.T + B @ B.T + C @ C.T - 3* M @ M.T) 48 | return np.array(covars)*(1/12.0) 49 | 50 | com,a = get_centroids(mesh0) 51 | print(com.shape) 52 | face_vert = mesh0.vertices[mesh0.faces.reshape(-1),:].reshape((mesh0.faces.shape[0],3,-1)) 53 | data_covar = get_tri_covar(face_vert) 54 | print(data_covar.shape) 55 | 56 | indices2 = np.random.randint(0,mesh_pts.vertices.shape[0],SAMPLE_PTS) 57 | samples_for_icp = mesh0.vertices#np.copy(mesh_pts.vertices[indices2]) 58 | #gm3 = GaussianMixture(100,init_params='kmeans'); gm3.set_triangles(face_vert); gm3.fit(com); gm3.set_triangles(None) 59 | #usually tol=1e-4,max_iter=100 60 | t1 = time.time() 61 | gm_std_km = GaussianMixture(K,init_params='kmeans',tol=1e-5,max_iter=100); gm_std_km.fit(samples_for_icp) 62 | print((time.time()-t1)*1000) 63 | t1 = time.time() 64 | gm_std = GaussianMixture(K,init_params='random',tol=1e-5,max_iter=100); gm_std.fit(samples_for_icp) 65 | print((time.time()-t1)*1000) 66 | t1 = time.time() 67 | #indices3 = np.random.randint(0,mesh0.vertices.shape[0],SAMPLE_PTS) 68 | gm_mesh = GaussianMixture(K,init_params='random',tol=1e-5,max_iter=100); gm_mesh.set_covars(data_covar); gm_mesh.set_areas(a); gm_mesh.fit(com); gm_mesh.set_covars(None); gm_mesh.set_areas(None) 69 | print((time.time()-t1)*1000) 70 | t1 = time.time() 71 | #indices3 = np.random.randint(0,mesh0.vertices.shape[0],SAMPLE_PTS) 72 | gm_mesh_kmeans = GaussianMixture(K,init_params='kmeans',tol=1e-5,max_iter=100); gm_mesh_kmeans.set_covars(data_covar); gm_mesh_kmeans.set_areas(a); gm_mesh_kmeans.fit(com); gm_mesh_kmeans.set_covars(None); gm_mesh_kmeans.set_areas(None) 73 | print((time.time()-t1)*1000) 74 | 75 | 76 | gm_areas_kmeans = 
GaussianMixture(K,init_params='kmeans',tol=1e-5,max_iter=100); gm_areas_kmeans.set_areas(a); gm_areas_kmeans.fit(com); gm_areas_kmeans.set_areas(None) 77 | 78 | data_log_mesh = [] 79 | data_log_meshk = [] 80 | data_log_verts = [] 81 | data_log_vertsk = [] 82 | data_log_areas = [] 83 | data_log_icp = [] 84 | data_log_cpd = [] 85 | data_log_oracle = [] 86 | opt_times = [] 87 | opt_times_pts = [] 88 | icp_times = [] 89 | cpd_times = [] 90 | 91 | prev_time = time.time() 92 | for n in range(SAMPLE_NUM): 93 | 94 | indices2 = np.random.randint(0,mesh_pts.vertices.shape[0],SAMPLE_PTS) 95 | samples_for_icp = np.copy(mesh_pts.vertices[indices2]) 96 | full_points = samples_for_icp#mesh_pts.vertices 97 | indices = np.random.randint(0,mesh_pts.vertices.shape[0],SAMPLE_PTS) 98 | samples = np.copy(mesh_pts.vertices[indices]) 99 | samples_mean = samples.mean(0) 100 | centered_points = samples - samples_mean 101 | 102 | print(n,round(time.time()-prev_time,1),'seconds') 103 | prev_time = time.time() 104 | if False: # random transformations 105 | q = np.random.randn(4) 106 | q = q/np.linalg.norm(q) 107 | M = transforms3d.quaternions.quat2mat(q) 108 | t = np.random.randn(3)*0.05 109 | else: 110 | t = np.random.rand(3)*0.1 - 0.05 111 | angles = np.random.rand(3)*30 - 15 112 | angles *= np.pi/180.0 113 | M = transforms3d.euler.euler2mat(angles[0],angles[1],angles[2]) 114 | 115 | true_q = transforms3d.quaternions.mat2quat(M) 116 | 117 | source = centered_points @ M + samples_mean+ t 118 | sourcemean = source.mean(0) 119 | source_centered = source - sourcemean 120 | 121 | H = (source-source.mean(0)).T @ (samples-samples.mean(0)) 122 | u,s,vt = np.linalg.svd(H) 123 | R_reg = vt.T @ np.diag([1,1,np.linalg.det(vt.T @ u.T)]) @ u.T 124 | t_reg = source.mean(0)-samples.mean(0) 125 | oracle_q = transforms3d.quaternions.mat2quat(R_reg) 126 | data_log_oracle.append( [oracle_q.dot(true_q),np.linalg.norm(t_reg-t)] ) 127 | 128 | 129 | 130 | def loss_verts(x): 131 | qs = x[:4] 132 | ts = x[4:] 133 | 
qs = qs/np.linalg.norm(qs) 134 | Ms = transforms3d.quaternions.quat2mat(qs) 135 | tpts = (source_centered) @ Ms.T + sourcemean - ts 136 | return -gm_std.score(tpts) 137 | t1 = time.time() 138 | res = opt.minimize(loss_verts,np.array([1,0,0,0,0,0,0]),method=method) 139 | opt_times_pts.append(time.time()-t1) 140 | rq = res.x[:4] 141 | rq = rq/np.linalg.norm(rq) 142 | rt = res.x[4:] 143 | #print(method) 144 | #print(np.arccos(rq.dot(true_q)),np.linalg.norm(rt-t)) 145 | data_log_verts.append( [rq.dot(true_q),np.linalg.norm(rt-t)] ) 146 | def loss_mesh(x): 147 | qs = x[:4] 148 | ts = x[4:] 149 | qs = qs/np.linalg.norm(qs) 150 | Ms = transforms3d.quaternions.quat2mat(qs) 151 | tpts = (source_centered) @ Ms.T + sourcemean - ts 152 | return -gm_mesh.score(tpts) 153 | start_opt = time.time() 154 | res = opt.minimize(loss_mesh,np.array([1,0,0,0,0,0,0]),method=method) 155 | end_opt = time.time() 156 | opt_times.append(end_opt-start_opt) 157 | rq = res.x[:4] 158 | rq = rq/np.linalg.norm(rq) 159 | rt = res.x[4:] 160 | data_log_mesh.append( [rq.dot(true_q),np.linalg.norm(rt-t)] ) 161 | def loss_mesh_k(x): 162 | qs = x[:4] 163 | ts = x[4:] 164 | qs = qs/np.linalg.norm(qs) 165 | Ms = transforms3d.quaternions.quat2mat(qs) 166 | tpts = (source_centered) @ Ms.T + sourcemean - ts 167 | return -gm_std_km.score(tpts) 168 | start_opt = time.time() 169 | res = opt.minimize(loss_mesh_k,np.array([1,0,0,0,0,0,0]),method=method) 170 | end_opt = time.time() 171 | opt_times.append(end_opt-start_opt) 172 | rq = res.x[:4] 173 | rq = rq/np.linalg.norm(rq) 174 | rt = res.x[4:] 175 | data_log_vertsk.append( [rq.dot(true_q),np.linalg.norm(rt-t)] ) 176 | 177 | def loss_areas_k(x): 178 | qs = x[:4] 179 | ts = x[4:] 180 | qs = qs/np.linalg.norm(qs) 181 | Ms = transforms3d.quaternions.quat2mat(qs) 182 | tpts = (source_centered) @ Ms.T + sourcemean - ts 183 | return -gm_areas_kmeans.score(tpts) 184 | start_opt = time.time() 185 | res = opt.minimize(loss_areas_k,np.array([1,0,0,0,0,0,0]),method=method) 186 
| end_opt = time.time() 187 | rq = res.x[:4] 188 | rq = rq/np.linalg.norm(rq) 189 | rt = res.x[4:] 190 | data_log_areas.append( [rq.dot(true_q),np.linalg.norm(rt-t)] ) 191 | 192 | if True: 193 | def loss_mesh_k2(x): 194 | qs = x[:4] 195 | ts = x[4:] 196 | qs = qs/np.linalg.norm(qs) 197 | Ms = transforms3d.quaternions.quat2mat(qs) 198 | tpts = (source_centered) @ Ms.T + sourcemean - ts 199 | return -gm_mesh_kmeans.score(tpts) 200 | start_opt = time.time() 201 | res = opt.minimize(loss_mesh_k2,np.array([1,0,0,0,0,0,0]),method=method) 202 | end_opt = time.time() 203 | rq = res.x[:4] 204 | rq = rq/np.linalg.norm(rq) 205 | rt = res.x[4:] 206 | data_log_meshk.append( [rq.dot(true_q),np.linalg.norm(rt-t)] ) 207 | else: 208 | def loss_mesh_k2(x): 209 | qs = x[:3] 210 | ts = x[3:] 211 | qs[-1] += 1e-9 212 | angle = np.linalg.norm(qs) 213 | axis = qs/angle 214 | Ms = transforms3d.axangles.axangle2mat(axis,angle) 215 | tpts = (source_centered) @ Ms.T + sourcemean - ts 216 | return -gm_mesh_kmeans.score(tpts) 217 | start_opt = time.time() 218 | res = opt.minimize(loss_mesh_k2,np.array([0,0,0,0,0,0]),method=method) 219 | end_opt = time.time() 220 | 221 | rq = res.x[:3] 222 | rq[-1] += 1e-9 223 | 224 | angle = np.linalg.norm(rq) 225 | axis = rq/angle 226 | rq = transforms3d.quaternions.axangle2quat(axis,angle) 227 | rt = res.x[3:] 228 | data_log_meshk.append( [rq.dot(true_q),np.linalg.norm(rt-t)] ) 229 | icp_t = np.zeros(3) 230 | R = np.identity(3) 231 | source2 = np.copy(source) 232 | prev_err = 100000000 233 | flag = True 234 | t1 = time.time() 235 | for icp_iter in range(ICP_ITERS): 236 | dist = cdist(source2,samples_for_icp) 237 | sample_idx = np.argmin(dist,1) 238 | matched_pts = samples_for_icp[sample_idx] 239 | source2mean = source2.mean(0) 240 | matchedptsmean = matched_pts.mean(0) 241 | source2centered = source2-source2mean 242 | it = source2mean - matchedptsmean 243 | if flag: 244 | idx2 = np.argmin(dist,0) 245 | matched2 = source2[idx2] 246 | it = (0.5*it) + 
0.5*(matched2.mean(0) - samples_for_icp.mean(0)) 247 | 248 | H = (source2centered).T @ (matched_pts-matchedptsmean) 249 | if flag: 250 | H2 = (matched2-matched2.mean(0)).T @ (samples_for_icp-samples_for_icp.mean(0)) 251 | H2 *= source2.shape[0]/samples_for_icp.shape[0] 252 | H = H + H2 253 | u,s,vt = np.linalg.svd(H) 254 | rotmat = vt.T @ np.diag([1,1,np.linalg.det(vt.T @ u.T)]) @ u.T 255 | 256 | #print(rotmat,'\n',M) 257 | #print(it,'\n',t) 258 | source2 = (source2centered) @ rotmat.T + source2mean - it 259 | err = np.linalg.norm(source2-matched_pts,axis=1) 260 | #print(err) 261 | #print(np.diag(cdist(source2,matched_pts)).mean(),len(matched_pts)) 262 | if np.linalg.norm(err-prev_err) < ICP_THRESH: 263 | break 264 | prev_err = err 265 | icp_t += it 266 | R = R @ rotmat 267 | #print(it) 268 | #print(rotmat) 269 | 270 | icp_q = transforms3d.quaternions.mat2quat(R) 271 | icp_t = icp_t 272 | icp_times.append(time.time()-t1) 273 | data_log_icp.append( [icp_q.dot(true_q),np.linalg.norm(icp_t-t)] ) 274 | 275 | t1 = time.time() 276 | 277 | reg = rigid_registration(X=source,Y=samples_for_icp,max_iterations=CPD_ITERS,tolerance=CPD_THRESH) 278 | TY, (s_reg, R_reg, t_reg) = reg.register() 279 | cpd_times.append(time.time()-t1) 280 | H = (TY-TY.mean(0)).T @ (samples_for_icp-samples_for_icp.mean(0)) 281 | u,s,vt = np.linalg.svd(H) 282 | R_reg = vt.T @ np.diag([1,1,np.linalg.det(vt.T @ u.T)]) @ u.T 283 | t_reg = TY.mean(0)-samples_for_icp.mean(0) 284 | cpd_q = transforms3d.quaternions.mat2quat(R_reg) 285 | data_log_cpd.append( [cpd_q.dot(true_q),np.linalg.norm(t_reg-t)] ) 286 | 287 | if False: 288 | fig = plt.figure() 289 | ax = fig.add_subplot(111, projection='3d') 290 | ax.scatter(source[:,0],source[:,1],source[:,2],label='orig') 291 | ax.scatter(samples[:,0],samples[:,1],samples[:,2],label='trans') 292 | result = (source + icp_t) @ R.T 293 | ax.scatter(source2[:,0],source2[:,1],source2[:,2],label='registered') 294 | plt.title(str(icp_q.dot(true_q)) + ' ' + 
str(np.linalg.norm(icp_t-t))) 295 | plt.legend() 296 | plt.show() 297 | print(np.array(opt_times_pts).mean()*1000) 298 | print(np.array(opt_times).mean()*1000) 299 | print(np.array(icp_times).mean()*1000) 300 | print(np.array(cpd_times).mean()*1000) 301 | if len(data_log_verts) > 0 : 302 | np.savetxt('verts2.csv',np.array(data_log_verts),delimiter=',') 303 | np.savetxt('vertsk2.csv',np.array(data_log_vertsk),delimiter=',') 304 | np.savetxt('mesh2.csv',np.array(data_log_mesh),delimiter=',') 305 | np.savetxt('icp2.csv',np.array(data_log_icp),delimiter=',') 306 | np.savetxt('cpd2.csv',np.array(data_log_cpd),delimiter=',') 307 | np.savetxt('oracle2.csv',np.array(data_log_oracle),delimiter=',') 308 | np.savetxt('meshk2.csv',np.array(data_log_meshk),delimiter=',') 309 | np.savetxt('areas2.csv',np.array(data_log_areas),delimiter=',') 310 | -------------------------------------------------------------------------------- /registration_test_extra.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.stats import multivariate_normal as mvn_pdf 3 | from mpl_toolkits.mplot3d import Axes3D 4 | import matplotlib.pyplot as plt 5 | from mixture import GaussianMixture 6 | from scipy.spatial.distance import cdist,pdist 7 | import pymesh 8 | import pickle 9 | from scipy.special import logsumexp 10 | import scipy.optimize as opt 11 | import transforms3d 12 | import time 13 | 14 | SAMPLE_NUM = 25 15 | method = None#'CG'#None#'CG'#None#CG' 16 | K = 100 17 | ICP_ITERS = 50000 #150 18 | ICP_THRESH = 1e-9 19 | 20 | #mesh0 = pymesh.load_mesh("bunny/bun_zipper.ply") 21 | #mesh_pts = pymesh.load_mesh("bunny/bun_zipper_50k.ply") 22 | #mesh0 = pymesh.load_mesh("bunny/bun_zipper_1000_1.ply") 23 | #mesh_pts = pymesh.load_mesh("bunny/bun_zipper_50k.ply") 24 | if False: 25 | mesh0 = pymesh.load_mesh("bunny/bun_zipper_res4.ply") 26 | mesh_pts = pymesh.load_mesh("bunny/bun_zipper_res4_25k_pds.ply") 27 | tag = 'bunny' 28 | elif False: 
29 | mesh0 = pymesh.load_mesh("arma/Armadillo_1000.ply") 30 | mesh_pts = pymesh.load_mesh("arma/Armadillo_25k.ply") 31 | tag = 'arma' 32 | 33 | elif False: 34 | mesh0 = pymesh.load_mesh("dragon/dragon_recon/dragon_vrip_1000.ply") 35 | mesh_pts = pymesh.load_mesh("dragon/dragon_recon/dragon_vrip_25k.ply") 36 | tag = 'dragon' 37 | 38 | elif False: 39 | mesh0 = pymesh.load_mesh("happy/happy_recon/happy_vrip_1000.ply") 40 | mesh_pts = pymesh.load_mesh("happy/happy_recon/happy_vrip_25k.ply") 41 | tag = 'happy' 42 | 43 | elif True: 44 | mesh0 = pymesh.load_mesh("lucy/lucy_1000.ply") 45 | mesh_pts = pymesh.load_mesh("lucy/lucy_25k.ply") 46 | tag = 'lucy' 47 | 48 | SAMPLE_PTS = mesh0.vertices.shape[0] # number of vertecies! 49 | 50 | def get_centroids(mesh): 51 | # obtain a vertex for each face index 52 | face_vert = mesh.vertices[mesh.faces.reshape(-1),:].reshape((mesh.faces.shape[0],3,-1)) 53 | # face_vert is size (faces,3(one for each vert), 3(one for each dimension)) 54 | centroids = face_vert.sum(1)/3.0 55 | #face_vert = ((face_vert.shape[0]/SAMPLE_PTS)*(face_vert.reshape((-1,9))-np.repeat(centroids,3,axis=1)) + np.repeat(centroids,3,axis=1)).reshape((-1,3,3)) 56 | ABAC = face_vert[:,1:3,:] - face_vert[:,0:1,:] 57 | areas = np.linalg.norm(np.cross(ABAC[:,0,:],ABAC[:,1,:]),axis=1)/2.0 58 | return centroids, areas 59 | 60 | def get_tri_covar(tris): 61 | covars = [] 62 | for face in tris: 63 | A = face[0][:,None] 64 | B = face[1][:,None] 65 | C = face[2][:,None] 66 | M = (A+B+C)/3 67 | covars.append(A @ A.T + B @ B.T + C @ C.T - 3* M @ M.T) 68 | return np.array(covars)*(1/12.0) 69 | 70 | com,a = get_centroids(mesh0) 71 | print(com.shape) 72 | face_vert = mesh0.vertices[mesh0.faces.reshape(-1),:].reshape((mesh0.faces.shape[0],3,-1)) 73 | data_covar = get_tri_covar(face_vert) 74 | print(data_covar.shape) 75 | 76 | indices2 = np.random.randint(0,mesh_pts.vertices.shape[0],SAMPLE_PTS) 77 | samples_for_icp = mesh0.vertices#np.copy(mesh_pts.vertices[indices2]) 78 | #gm3 = 
GaussianMixture(100,init_params='kmeans'); gm3.set_triangles(face_vert); gm3.fit(com); gm3.set_triangles(None) 79 | #usually tol=1e-4,max_iter=100 80 | t1 = time.time() 81 | gm_std_km = GaussianMixture(K,init_params='kmeans',tol=1e-5,max_iter=100); gm_std_km.fit(samples_for_icp) 82 | print((time.time()-t1)*1000) 83 | t1 = time.time() 84 | gm_std = GaussianMixture(K,init_params='random',tol=1e-5,max_iter=100); gm_std.fit(samples_for_icp) 85 | print((time.time()-t1)*1000) 86 | t1 = time.time() 87 | #indices3 = np.random.randint(0,mesh0.vertices.shape[0],SAMPLE_PTS) 88 | gm_mesh = GaussianMixture(K,init_params='random',tol=1e-5,max_iter=100); gm_mesh.set_covars(data_covar); gm_mesh.set_areas(a); gm_mesh.fit(com); gm_mesh.set_covars(None); gm_mesh.set_areas(None) 89 | print((time.time()-t1)*1000) 90 | t1 = time.time() 91 | #indices3 = np.random.randint(0,mesh0.vertices.shape[0],SAMPLE_PTS) 92 | gm_mesh_kmeans = GaussianMixture(K,init_params='kmeans',tol=1e-5,max_iter=100); gm_mesh_kmeans.set_covars(data_covar); gm_mesh_kmeans.set_areas(a); gm_mesh_kmeans.fit(com); gm_mesh_kmeans.set_covars(None); gm_mesh_kmeans.set_areas(None) 93 | print((time.time()-t1)*1000) 94 | 95 | 96 | gm_areas_kmeans = GaussianMixture(K,init_params='kmeans',tol=1e-5,max_iter=100); gm_areas_kmeans.set_areas(a); gm_areas_kmeans.fit(com); gm_areas_kmeans.set_areas(None) 97 | 98 | data_log_mesh = [] 99 | data_log_meshk = [] 100 | data_log_verts = [] 101 | data_log_vertsk = [] 102 | data_log_areas = [] 103 | data_log_icp = [] 104 | data_log_cpd = [] 105 | data_log_oracle = [] 106 | opt_times = [] 107 | opt_times_pts = [] 108 | icp_times = [] 109 | cpd_times = [] 110 | 111 | prev_time = time.time() 112 | for n in range(SAMPLE_NUM): 113 | 114 | indices2 = np.random.randint(0,mesh_pts.vertices.shape[0],SAMPLE_PTS) 115 | samples_for_icp = np.copy(mesh_pts.vertices[indices2]) 116 | full_points = samples_for_icp#mesh_pts.vertices 117 | indices = np.random.randint(0,mesh_pts.vertices.shape[0],SAMPLE_PTS) 
118 | samples = np.copy(mesh_pts.vertices[indices]) 119 | samples_mean = samples.mean(0) 120 | centered_points = samples - samples_mean 121 | 122 | print(n,round(time.time()-prev_time,1),'seconds') 123 | prev_time = time.time() 124 | if False: # random transformations 125 | q = np.random.randn(4) 126 | q = q/np.linalg.norm(q) 127 | M = transforms3d.quaternions.quat2mat(q) 128 | t = np.random.randn(3)*0.05 129 | else: 130 | t = np.random.rand(3)*0.1 - 0.05 131 | angles = np.random.rand(3)*30 - 15 132 | angles *= np.pi/180.0 133 | M = transforms3d.euler.euler2mat(angles[0],angles[1],angles[2]) 134 | 135 | true_q = transforms3d.quaternions.mat2quat(M) 136 | 137 | source = centered_points @ M + samples_mean+ t 138 | sourcemean = source.mean(0) 139 | source_centered = source - sourcemean 140 | 141 | H = (source-source.mean(0)).T @ (samples-samples.mean(0)) 142 | u,s,vt = np.linalg.svd(H) 143 | R_reg = vt.T @ np.diag([1,1,np.linalg.det(vt.T @ u.T)]) @ u.T 144 | t_reg = source.mean(0)-samples.mean(0) 145 | oracle_q = transforms3d.quaternions.mat2quat(R_reg) 146 | data_log_oracle.append( [oracle_q.dot(true_q),np.linalg.norm(t_reg-t)] ) 147 | 148 | 149 | 150 | def loss_verts(x): 151 | qs = x[:4] 152 | ts = x[4:] 153 | qs = qs/np.linalg.norm(qs) 154 | Ms = transforms3d.quaternions.quat2mat(qs) 155 | tpts = (source_centered) @ Ms.T + sourcemean - ts 156 | return -gm_std.score(tpts) 157 | t1 = time.time() 158 | res = opt.minimize(loss_verts,np.array([1,0,0,0,0,0,0]),method=method) 159 | opt_times_pts.append(time.time()-t1) 160 | rq = res.x[:4] 161 | rq = rq/np.linalg.norm(rq) 162 | rt = res.x[4:] 163 | #print(method) 164 | #print(np.arccos(rq.dot(true_q)),np.linalg.norm(rt-t)) 165 | data_log_verts.append( [rq.dot(true_q),np.linalg.norm(rt-t)] ) 166 | def loss_mesh(x): 167 | qs = x[:4] 168 | ts = x[4:] 169 | qs = qs/np.linalg.norm(qs) 170 | Ms = transforms3d.quaternions.quat2mat(qs) 171 | tpts = (source_centered) @ Ms.T + sourcemean - ts 172 | return -gm_mesh.score(tpts) 173 | 
start_opt = time.time() 174 | res = opt.minimize(loss_mesh,np.array([1,0,0,0,0,0,0]),method=method) 175 | end_opt = time.time() 176 | opt_times.append(end_opt-start_opt) 177 | rq = res.x[:4] 178 | rq = rq/np.linalg.norm(rq) 179 | rt = res.x[4:] 180 | data_log_mesh.append( [rq.dot(true_q),np.linalg.norm(rt-t)] ) 181 | def loss_mesh_k(x): 182 | qs = x[:4] 183 | ts = x[4:] 184 | qs = qs/np.linalg.norm(qs) 185 | Ms = transforms3d.quaternions.quat2mat(qs) 186 | tpts = (source_centered) @ Ms.T + sourcemean - ts 187 | return -gm_std_km.score(tpts) 188 | start_opt = time.time() 189 | res = opt.minimize(loss_mesh_k,np.array([1,0,0,0,0,0,0]),method=method) 190 | end_opt = time.time() 191 | opt_times.append(end_opt-start_opt) 192 | rq = res.x[:4] 193 | rq = rq/np.linalg.norm(rq) 194 | rt = res.x[4:] 195 | data_log_vertsk.append( [rq.dot(true_q),np.linalg.norm(rt-t)] ) 196 | 197 | def loss_areas_k(x): 198 | qs = x[:4] 199 | ts = x[4:] 200 | qs = qs/np.linalg.norm(qs) 201 | Ms = transforms3d.quaternions.quat2mat(qs) 202 | tpts = (source_centered) @ Ms.T + sourcemean - ts 203 | return -gm_areas_kmeans.score(tpts) 204 | start_opt = time.time() 205 | res = opt.minimize(loss_areas_k,np.array([1,0,0,0,0,0,0]),method=method) 206 | end_opt = time.time() 207 | rq = res.x[:4] 208 | rq = rq/np.linalg.norm(rq) 209 | rt = res.x[4:] 210 | data_log_areas.append( [rq.dot(true_q),np.linalg.norm(rt-t)] ) 211 | 212 | if True: 213 | def loss_mesh_k2(x): 214 | qs = x[:4] 215 | ts = x[4:] 216 | qs = qs/np.linalg.norm(qs) 217 | Ms = transforms3d.quaternions.quat2mat(qs) 218 | tpts = (source_centered) @ Ms.T + sourcemean - ts 219 | return -gm_mesh_kmeans.score(tpts) 220 | start_opt = time.time() 221 | res = opt.minimize(loss_mesh_k2,np.array([1,0,0,0,0,0,0]),method=method) 222 | end_opt = time.time() 223 | rq = res.x[:4] 224 | rq = rq/np.linalg.norm(rq) 225 | rt = res.x[4:] 226 | data_log_meshk.append( [rq.dot(true_q),np.linalg.norm(rt-t)] ) 227 | else: 228 | def loss_mesh_k2(x): 229 | qs = x[:3] 230 
| ts = x[3:] 231 | qs[-1] += 1e-9 232 | angle = np.linalg.norm(qs) 233 | axis = qs/angle 234 | Ms = transforms3d.axangles.axangle2mat(axis,angle) 235 | tpts = (source_centered) @ Ms.T + sourcemean - ts 236 | return -gm_mesh_kmeans.score(tpts) 237 | start_opt = time.time() 238 | res = opt.minimize(loss_mesh_k2,np.array([0,0,0,0,0,0]),method=method) 239 | end_opt = time.time() 240 | 241 | rq = res.x[:3] 242 | rq[-1] += 1e-9 243 | 244 | angle = np.linalg.norm(rq) 245 | axis = rq/angle 246 | rq = transforms3d.quaternions.axangle2quat(axis,angle) 247 | rt = res.x[3:] 248 | data_log_meshk.append( [rq.dot(true_q),np.linalg.norm(rt-t)] ) 249 | icp_t = np.zeros(3) 250 | R = np.identity(3) 251 | source2 = np.copy(source) 252 | prev_err = 100000000 253 | flag = True 254 | t1 = time.time() 255 | for icp_iter in range(ICP_ITERS): 256 | dist = cdist(source2,samples_for_icp) 257 | sample_idx = np.argmin(dist,1) 258 | matched_pts = samples_for_icp[sample_idx] 259 | source2mean = source2.mean(0) 260 | matchedptsmean = matched_pts.mean(0) 261 | source2centered = source2-source2mean 262 | it = source2mean - matchedptsmean 263 | if flag: 264 | idx2 = np.argmin(dist,0) 265 | matched2 = source2[idx2] 266 | it = (0.5*it) + 0.5*(matched2.mean(0) - samples_for_icp.mean(0)) 267 | 268 | H = (source2centered).T @ (matched_pts-matchedptsmean) 269 | if flag: 270 | H2 = (matched2-matched2.mean(0)).T @ (samples_for_icp-samples_for_icp.mean(0)) 271 | H2 *= source2.shape[0]/samples_for_icp.shape[0] 272 | H = H + H2 273 | u,s,vt = np.linalg.svd(H) 274 | rotmat = vt.T @ np.diag([1,1,np.linalg.det(vt.T @ u.T)]) @ u.T 275 | 276 | #print(rotmat,'\n',M) 277 | #print(it,'\n',t) 278 | source2 = (source2centered) @ rotmat.T + source2mean - it 279 | err = np.linalg.norm(source2-matched_pts,axis=1) 280 | #print(err) 281 | #print(np.diag(cdist(source2,matched_pts)).mean(),len(matched_pts)) 282 | if np.linalg.norm(err-prev_err) < ICP_THRESH: 283 | break 284 | prev_err = err 285 | icp_t += it 286 | R = R @ 
rotmat 287 | #print(it) 288 | #print(rotmat) 289 | 290 | icp_q = transforms3d.quaternions.mat2quat(R) 291 | icp_t = icp_t 292 | icp_times.append(time.time()-t1) 293 | data_log_icp.append( [icp_q.dot(true_q),np.linalg.norm(icp_t-t)] ) 294 | 295 | t1 = time.time() 296 | 297 | 298 | if False: 299 | fig = plt.figure() 300 | ax = fig.add_subplot(111, projection='3d') 301 | ax.scatter(source[:,0],source[:,1],source[:,2],label='orig') 302 | ax.scatter(samples[:,0],samples[:,1],samples[:,2],label='trans') 303 | result = (source + icp_t) @ R.T 304 | ax.scatter(source2[:,0],source2[:,1],source2[:,2],label='registered') 305 | plt.title(str(icp_q.dot(true_q)) + ' ' + str(np.linalg.norm(icp_t-t))) 306 | plt.legend() 307 | plt.show() 308 | print(np.array(opt_times_pts).mean()*1000) 309 | print(np.array(opt_times).mean()*1000) 310 | print(np.array(icp_times).mean()*1000) 311 | print(np.array(cpd_times).mean()*1000) 312 | if len(data_log_verts) > 0 : 313 | np.savetxt('{}_verts2.csv'.format(tag),np.array(data_log_verts),delimiter=',') 314 | np.savetxt('{}_vertsk2.csv'.format(tag),np.array(data_log_vertsk),delimiter=',') 315 | np.savetxt('{}_mesh2.csv'.format(tag),np.array(data_log_mesh),delimiter=',') 316 | np.savetxt('{}_icp2.csv'.format(tag),np.array(data_log_icp),delimiter=',') 317 | np.savetxt('{}_cpd2.csv'.format(tag),np.array(data_log_cpd),delimiter=',') 318 | np.savetxt('{}_oracle2.csv'.format(tag),np.array(data_log_oracle),delimiter=',') 319 | np.savetxt('{}_meshk2.csv'.format(tag),np.array(data_log_meshk),delimiter=',') 320 | np.savetxt('{}_areas2.csv'.format(tag),np.array(data_log_areas),delimiter=',') 321 | -------------------------------------------------------------------------------- /render_tmp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import trimesh 3 | from autolab_core import RigidTransform 4 | from perception import CameraIntrinsics, RenderMode 5 | 6 | from meshrender import Scene, 
MaterialProperties, AmbientLight, PointLight, SceneObject, VirtualCamera 7 | 8 | # Start with an empty scene 9 | scene = Scene() 10 | 11 | #==================================== 12 | # Add objects to the scene 13 | #==================================== 14 | 15 | # Begin by loading meshes 16 | cube_mesh = trimesh.load_mesh('cube.obj') 17 | sphere_mesh = trimesh.load_mesh('sphere.obj') 18 | 19 | # Set up each object's pose in the world 20 | cube_pose = RigidTransform( 21 | rotation=np.eye(3), 22 | translation=np.array([0.0, 0.0, 0.0]), 23 | from_frame='obj', 24 | to_frame='world' 25 | ) 26 | sphere_pose = RigidTransform( 27 | rotation=np.eye(3), 28 | translation=np.array([1.0, 1.0, 0.0]), 29 | from_frame='obj', 30 | to_frame='world' 31 | ) 32 | 33 | # Set up each object's material properties 34 | cube_material = MaterialProperties( 35 | color = np.array([0.1, 0.1, 0.5]), 36 | k_a = 0.3, 37 | k_d = 1.0, 38 | k_s = 1.0, 39 | alpha = 10.0, 40 | smooth=False 41 | ) 42 | sphere_material = MaterialProperties( 43 | color = np.array([0.1, 0.1, 0.5]), 44 | k_a = 0.3, 45 | k_d = 1.0, 46 | k_s = 1.0, 47 | alpha = 10.0, 48 | smooth=True 49 | ) 50 | 51 | # Create SceneObjects for each object 52 | cube_obj = SceneObject(cube_mesh, cube_pose, cube_material) 53 | sphere_obj = SceneObject(sphere_mesh, sphere_pose, sphere_material) 54 | 55 | # Add the SceneObjects to the scene 56 | scene.add_object('cube', cube_obj) 57 | scene.add_object('sphere', sphere_obj) 58 | 59 | #==================================== 60 | # Add lighting to the scene 61 | #==================================== 62 | 63 | # Create an ambient light 64 | ambient = AmbientLight( 65 | color=np.array([1.0, 1.0, 1.0]), 66 | strength=1.0 67 | ) 68 | 69 | # Create a point light 70 | point = PointLight( 71 | location=np.array([1.0, 2.0, 3.0]), 72 | color=np.array([1.0, 1.0, 1.0]), 73 | strength=10.0 74 | ) 75 | 76 | # Add the lights to the scene 77 | scene.ambient_light = ambient # only one ambient light per scene 78 | 
scene.add_light('point_light_one', point) 79 | 80 | #==================================== 81 | # Add a camera to the scene 82 | #==================================== 83 | 84 | # Set up camera intrinsics 85 | ci = CameraIntrinsics( 86 | frame = 'camera', 87 | fx = 525.0, 88 | fy = 525.0, 89 | cx = 319.5, 90 | cy = 239.5, 91 | skew=0.0, 92 | height=480, 93 | width=640 94 | ) 95 | 96 | # Set up the camera pose (z axis faces away from scene, x to right, y up) 97 | cp = RigidTransform( 98 | rotation = np.array([ 99 | [0.0, 0.0, -1.0], 100 | [0.0, 1.0, 0.0], 101 | [1.0, 0.0, 0.0] 102 | ]), 103 | translation = np.array([-0.3, 0.0, 0.0]), 104 | from_frame='camera', 105 | to_frame='world' 106 | ) 107 | 108 | # Create a VirtualCamera 109 | camera = VirtualCamera(ci, cp) 110 | 111 | # Add the camera to the scene 112 | scene.camera = camera 113 | 114 | #==================================== 115 | # Render images 116 | #==================================== 117 | 118 | # Render raw numpy arrays containing color and depth 119 | color_image_raw, depth_image_raw = scene.render(render_color=True) 120 | 121 | # Alternatively, just render a depth image 122 | depth_image_raw = scene.render(render_color=False) 123 | 124 | # Alternatively, collect wrapped images 125 | wrapped_color, wrapped_depth, wrapped_segmask = scene.wrapped_render( 126 | [RenderMode.COLOR, RenderMode.DEPTH, RenderMode.SEGMASK] 127 | ) 128 | -------------------------------------------------------------------------------- /road_graphic.py: -------------------------------------------------------------------------------- 1 | from pylab import * 2 | 3 | from skimage import io, filters 4 | img = io.imread('aerial.jpeg') 5 | diff = (abs(img - img.mean(2,keepdims=True))).mean(2) * (img.mean(2)) 6 | diff = diff**2 7 | diff = diff/diff.max()*255*12*10 8 | diff = np.minimum(diff,255) 9 | 10 | plt.imshow(diff.astype(np.uint8)) 11 | clean_diff = 255-filters.rank.median(diff.astype(np.uint8),np.ones((51,51))) 12 | crop = 
clean_diff[400:1100:4,2700:3300:4].astype(np.float) 13 | crop = crop/crop.sum() 14 | plt.imshow(crop) 15 | plt.colorbar() 16 | plt.tight_layout() 17 | 18 | plt.show() 19 | #diff = np.exp(-diff**2/10) 20 | #diff = diff/diff.max()*255 21 | #plt.show() -------------------------------------------------------------------------------- /tri_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import mpl_toolkits.mplot3d as m3d 4 | from scipy.stats import multivariate_normal as mvg 5 | 6 | 7 | 8 | import random 9 | 10 | def point_on_triangle(pt1, pt2, pt3): 11 | """ 12 | Random point on the triangle with vertices pt1, pt2 and pt3. 13 | """ 14 | s, t = sorted([random.random(), random.random()]) 15 | return (s * pt1[0] + (t-s)*pt2[0] + (1-t)*pt3[0], 16 | s * pt1[1] + (t-s)*pt2[1] + (1-t)*pt3[1], 17 | s * pt1[2] + (t-s)*pt2[2] + (1-t)*pt3[2]) 18 | 19 | ax = m3d.Axes3D(plt.figure()) 20 | 21 | vtx = np.array(((0.0,0.0,0.0),(0.0,0.0,2.0),(0.0,1.0,0.0))) 22 | vtx = np.random.randn(3,3) 23 | tri = m3d.art3d.Poly3DCollection([vtx]) 24 | tri.set_facecolor((0.7,0.7,1.0,0.5)) 25 | tri.set_edgecolor('k') 26 | tri.set_alpha(0.1) 27 | 28 | ax.add_collection3d(tri) 29 | 30 | mu = vtx.mean(0) 31 | mu = np.random.randn(3,3)[0,:]/3 32 | covar = np.random.randn(3,3)/2 33 | covar = np.abs(covar.T.dot(covar)) 34 | #covar = np.identity(3) 35 | A = vtx[0,:] 36 | B = vtx[1,:] 37 | C = vtx[2,:] 38 | M = (A+B+C)/3.0 39 | s = np.random.multivariate_normal(mu,covar,200) 40 | ax.scatter(s[:,0],s[:,1],s[:,2]) 41 | for sn in range(1000,1001): 42 | x = np.linspace(0,1,sn) 43 | y = np.linspace(0,1,sn) 44 | 45 | 46 | pts = [] 47 | if False: 48 | for i in range(len(x)): 49 | for j in range(i,len(y)): 50 | u = x[i] 51 | v = 1-y[j] 52 | #pt = (1-u) * A + u * ((1-v)*B + v*C) 53 | pt = A + u*(B-A) + v*(C-A) 54 | pts.append(pt) 55 | else: 56 | #nrm1 = np.linalg.norm(B-A) 57 | #nrm2 = np.linalg.norm(C-A) 58 | #bound 
= max(nrm1,nrm2) 59 | #pts1 = np.random.uniform(-bound,bound,(sn,2)) 60 | pts = [point_on_triangle(A, B, C) for _ in range(sn)] 61 | 62 | #print(pts1) 63 | #aise 64 | pts = np.array(pts) 65 | #ax.scatter(pts[:,0],pts[:,1],pts[:,2]) 66 | 67 | #cov = np.vstack(((B-A),(C-A))).T 68 | #js = np.sqrt(np.linalg.det(cov.T.dot(cov))) 69 | 70 | at = np.linalg.norm(np.cross((B-A),(C-A)))/2 71 | #print('TRUTH:\t',at) 72 | #print('JAC/2:\t',js/2) 73 | 74 | #u = np.random.randn(3,3)[0,:]/3 75 | #covar = np.random.randn(3,3)/2 76 | #covar = np.abs(covar.T.dot(covar)) 77 | 78 | 79 | #l2 = np.exp( np.log() ) 80 | 81 | lklh_add = 0.0 82 | lklh_mul = 0.0 83 | mll_mul = 0.0 84 | # dev = x - mean 85 | # maha = np.sum(np.square(np.dot(dev, prec_U)), axis=-1) 86 | # return -0.5 * (rank * _LOG_2PI + log_det_cov + maha) 87 | def my_ll(x,u,s): 88 | dev = x - u 89 | dev = dev.reshape((3,1)) 90 | res = 0.0 91 | res -= 0.5 * np.log(2*np.pi) *3 92 | res -= 0.5 * np.log(np.linalg.det(s)) 93 | #res -= 0.5 * dev.reshape((1,-1)).dot(np.linalg.inv(s) ).dot(dev.reshape((-1,1))) 94 | res -= 0.5 * dev.T.dot(np.linalg.inv(s)).dot(dev) 95 | 96 | #res = -0.5 * () 97 | return res.sum() 98 | def tri_ll(A,B,C,mu,s): 99 | m = ((A+B+C)/3 ) 100 | dev = (m - mu).reshape((-1,1)) 101 | a = A.reshape((-1,1)) 102 | b = B.reshape((-1,1)) 103 | c = C.reshape((-1,1)) 104 | m = m.reshape((-1,1)) 105 | 106 | res = 0.0 107 | res -= at*0.5 * np.log(2*np.pi) *3 108 | #print(res) 109 | res -= at*0.5 * np.log(np.linalg.det(s)) 110 | t1 = dev.dot(dev.T) 111 | t2 = (a.dot(a.T) + b.dot(b.T) + c.dot(c.T) - 3*m.dot(m.T)) 112 | #print(res,at,'\n',t1,'\n',t2/12.0) 113 | res -= 0.5 *at * np.trace(( t1 + (1/12.0) * t2).dot(np.linalg.inv(s))) 114 | print(t1.mean(),t2.mean()) 115 | #print(t2.shape) 116 | #res = -0.5 * () 117 | return res.sum()/at,t2 118 | def tri_ll2(A,B,C,mu,s): 119 | m = ((A+B+C)/3 ) 120 | dev = (m - mu).reshape((-1,1)) 121 | a = A.reshape((-1,1)) 122 | b = B.reshape((-1,1)) 123 | c = C.reshape((-1,1)) 124 | m = 
m.reshape((-1,1)) 125 | 126 | res = 0.0 127 | res -= at*0.5 * np.log(2*np.pi) *3 128 | #print(res) 129 | res -= at*0.5 * np.log(np.linalg.det(s)) 130 | prec = np.linalg.inv(s) 131 | t1 = dev.dot(dev.T) 132 | t2 = (a.dot(a.T) + b.dot(b.T) + c.dot(c.T) - 3*m.dot(m.T)) 133 | t22 = (a.T.dot(prec).dot(a) + b.T.dot(prec).dot(b) + c.T.dot(prec).dot(c) - 3*m.T.dot(prec).dot(m)) 134 | #np.trace(( t1 + (1/12.0) * t2).dot(np.linalg.inv(s))) 135 | #print(res,at,'\n',t1,'\n',t2/12.0) 136 | res -= 0.5 *at * ((dev.T.dot(np.linalg.inv(s))).dot(dev) + (1.0/12.0)*(t22)) 137 | #print(t2.shape) 138 | #res = -0.5 * () 139 | return res.sum()/at,t2 140 | l1 = at*mvg.pdf((A+B+C)/3.0,mu,covar) 141 | est_covar = np.identity(3)*0 142 | for p in pts: 143 | lklh_add += mvg.pdf(p,mu,covar) 144 | lklh_mul += mvg.logpdf(p,mu,covar)#np.log(mvg.pdf(p,u,covar)) 145 | mll_mul += my_ll(p,mu,covar) 146 | d = (p-M).reshape((-1,1)) 147 | est_covar += (d.dot(d.T)) 148 | 149 | est_covar *= (1.0/len(pts)) 150 | ll,cv = tri_ll(A,B,C,mu,covar) 151 | print(len(pts)) 152 | #print("p(m)A",np.log(l1)) 153 | #print("S p(m)",np.log(lklh_add)) 154 | #print("P p(m)",lklh_mul) 155 | #print("P p(m2",lklh_mul+(1/len(pts))) 156 | print("COM 4E", mvg.logpdf(M,mu,covar)) 157 | print("P p(m3",lklh_mul*(1.0/len(pts))) 158 | #print("constat",np.log(len(pts))) 159 | #print("P p(m4",lklh_mul/(1/len(pts))) 160 | #print("P p(m5",lklh_mul-np.log(1/len(pts))) 161 | #print("P p(m6",lklh_mul*np.log(1/len(pts))) 162 | #print("P p(m7",lklh_mul/np.log(1/len(pts))) 163 | #print("M p(m)",mll_mul) 164 | #print("M p(m3",mll_mul*(1.0/len(pts))) 165 | 166 | print("T p(m)2",tri_ll2(A,B,C,mu,covar)[0]) 167 | 168 | print("T p(m)",ll) 169 | #print("R1p(m)",(ll)/(at*mll_mul*(1.0/len(pts)))) 170 | 171 | #print("R2p(m)",(ll/at)/(at*mll_mul*(1.0/len(pts)))) 172 | #print("D1p(m)",(ll) - (at*mll_mul*(1.0/len(pts)))) 173 | 174 | #print("D2p(m)",(ll/at)- (at*mll_mul*(1.0/len(pts)))) 175 | #print("covar\n",cv/12.0) 176 | #print("est covar\n",est_covar) 
177 | #print(np.linalg.norm(cv/12.0 - est_covar)) 178 | 179 | m = ((A+B+C)/3 ) 180 | dev = (m - mu).reshape((-1,1)) 181 | a = A.reshape((-1,1)) 182 | b = B.reshape((-1,1)) 183 | c = C.reshape((-1,1)) 184 | m = m.reshape((-1,1)) 185 | 186 | 187 | 188 | 189 | 190 | mu2 = (A+B+C)/3 191 | t1 = dev.dot(dev.T) 192 | 193 | t2 = (a.dot(a.T) + b.dot(b.T) + c.dot(c.T) - 3*m.dot(m.T)) 194 | covar2 = ( (1/12.0) * t2) 195 | s2 = np.random.multivariate_normal(mu2,covar2,2000) 196 | #print("covar2\n",covar2) 197 | #print("ratio\n",covar2/est_covar) 198 | 199 | #ax.scatter(s2[:,0],s2[:,1],s2[:,2]) 200 | #plt.show() 201 | -------------------------------------------------------------------------------- /tri_verts_graph-Copy1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import pandas as pd\n", 12 | "import matplotlib.ticker" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "mdf" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "for namef in ['arma','bunny','dragon','happy','lucy']:\n", 31 | " print(namef)\n", 32 | " df = pd.read_csv(\"{}.log\".format(namef),names=['k','init','model','l','i'])\n", 33 | " mdf = df.groupby(['init','model','k']).mean()\n", 34 | " sdf = df.groupby(['init','model','k']).std()\n", 35 | " print(mdf[\"l\"][2])#(mdf['l'][3]-mdf['l'][2])/sdf['l'][2]\n", 36 | " print(mdf[\"l\"][3])" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "mdf" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | 
"outputs": [], 53 | "source": [] 54 | } 55 | ], 56 | "metadata": { 57 | "kernelspec": { 58 | "display_name": "Python 3", 59 | "language": "python", 60 | "name": "python3" 61 | }, 62 | "language_info": { 63 | "codemirror_mode": { 64 | "name": "ipython", 65 | "version": 3 66 | }, 67 | "file_extension": ".py", 68 | "mimetype": "text/x-python", 69 | "name": "python", 70 | "nbconvert_exporter": "python", 71 | "pygments_lexer": "ipython3", 72 | "version": "3.6.7" 73 | } 74 | }, 75 | "nbformat": 4, 76 | "nbformat_minor": 2 77 | } 78 | -------------------------------------------------------------------------------- /tri_verts_graph.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import pandas as pd\n", 12 | "import matplotlib.ticker" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "df = pd.read_csv(\"bunny_1k_com_verts_monday_25.log\",names=['k','init','model','l','i'])\n", 22 | "mdf = df.groupby(['init','model','k']).mean()\n", 23 | "sdf = df.groupby(['init','model','k']).std()\n", 24 | "from matplotlib import rc\n", 25 | "plt.style.use('fivethirtyeight')\n", 26 | "plt.style.use('seaborn-whitegrid')\n", 27 | "#plt.rcParams[\"font.family\"] = \"sans-serif\"\n", 28 | "rc('font',**{'family':'sans-serif','sans-serif':['cm']})\n", 29 | "\n", 30 | "pltstuff = 1\n", 31 | "if pltstuff == 2:\n", 32 | " subplot_order = [1,3,2,4]\n", 33 | "else:\n", 34 | " subplot_order = [1,2]\n", 35 | "fig = plt.figure(figsize=(10,5*pltstuff))\n", 36 | "for ii, init in enumerate(['kmeans','random']):\n", 37 | " for model in [3,0,1,2]:\n", 38 | " ls = '--' if model == 0 or model ==3 else '-'\n", 39 | " if model == 0:\n", 40 | " label = 'Mesh (approx)'\n", 41 | 
" if model == 1:\n", 42 | " label = 'Points (Center of Mass)'\n", 43 | " if model == 2:\n", 44 | " label = 'Points (Vertices)'\n", 45 | " if model == 3:\n", 46 | " label = 'Mesh (exact)'\n", 47 | " ldf = mdf.loc[(init,model),]\n", 48 | " x = np.array(ldf.index)\n", 49 | " \n", 50 | " y = ldf.values[:,0]\n", 51 | " error = 2*sdf.loc[(init,model),].values[:,0]\n", 52 | "\n", 53 | " ax = plt.subplot(pltstuff,2,subplot_order[pltstuff*ii])\n", 54 | " plt.plot(x,y,ls=ls,label=label)\n", 55 | " plt.fill_between(x, y-error, y+error,alpha=0.3)\n", 56 | " plt.grid(True)\n", 57 | " plt.xlabel('number of mixtures')\n", 58 | " plt.title('{} init. fidelity'.format(init))\n", 59 | " plt.title('{} initialization'.format(init))\n", 60 | "\n", 61 | " plt.ylabel('likelihood of ground truth\\n(higher is better)')\n", 62 | " plt.ylim(2,9)\n", 63 | " ax.set_xscale(\"log\", nonposx='clip')\n", 64 | " ax.tick_params(axis='x', which='minor', bottom=True,width=1,length=5) \n", 65 | " plt.grid(True,axis='x',which='minor')\n", 66 | " ax.tick_params(axis='x', which='major', bottom=True,width=2,length=5) \n", 67 | "\n", 68 | " y = ldf.values[:,1]\n", 69 | " error = sdf.loc[(init,model),].values[:,1]\n", 70 | "\n", 71 | " if pltstuff == 2:\n", 72 | " ax = plt.subplot(pltstuff,2,subplot_order[pltstuff*ii+1])\n", 73 | " plt.plot(x,y,ls=ls,label=label)\n", 74 | " plt.fill_between(x, y-error, y+error,alpha=0.15)\n", 75 | "\n", 76 | " ax.set_xscale(\"log\", nonposx='clip')\n", 77 | " plt.grid(True)\n", 78 | " plt.xlabel('number of mixtures')\n", 79 | " plt.title('{} init. 
runtime'.format(init))\n", 80 | " plt.ylabel('iterations until convergence')\n", 81 | " ax.tick_params(axis='x', which='minor', bottom=True,width=1,length=5) \n", 82 | " plt.grid(True,axis='x',which='minor')\n", 83 | " ax.tick_params(axis='x', which='major', bottom=True,width=2,length=5) \n", 84 | "\n", 85 | " #plt.subplot(2,2,3)\n", 86 | " #plt.subplot(2,2,4)\n", 87 | "for i in range(pltstuff*2):\n", 88 | " plt.subplot(pltstuff,2,1+i)\n", 89 | " plt.xlim(6,400)\n", 90 | " plt.legend()\n", 91 | "\n", 92 | "fig.tight_layout()\n", 93 | "fig.subplots_adjust(top=0.7+pltstuff*0.1)\n", 94 | "\n", 95 | "plt.suptitle('Different Mesh Information',size=24,weight='bold')\n", 96 | "plt.savefig('graph_triverts.pdf', facecolor=fig.get_facecolor(), edgecolor='none')" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "#mdf.loc[('kmeans',0),].values, np.array(mdf.loc[('kmeans',0),].index)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "mdf" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [] 123 | } 124 | ], 125 | "metadata": { 126 | "kernelspec": { 127 | "display_name": "Python 3", 128 | "language": "python", 129 | "name": "python3" 130 | }, 131 | "language_info": { 132 | "codemirror_mode": { 133 | "name": "ipython", 134 | "version": 3 135 | }, 136 | "file_extension": ".py", 137 | "mimetype": "text/x-python", 138 | "name": "python", 139 | "nbconvert_exporter": "python", 140 | "pygments_lexer": "ipython3", 141 | "version": "3.6.7" 142 | } 143 | }, 144 | "nbformat": 4, 145 | "nbformat_minor": 2 146 | } 147 | -------------------------------------------------------------------------------- /vis_fig.py: -------------------------------------------------------------------------------- 1 | from pylab 
import * 2 | import scipy.stats 3 | from matplotlib import rc 4 | rc('font',**{'family':'mono','sans-serif':['Helvetica']}) 5 | ## for Palatino and other serif fonts use: 6 | #rc('font',**{'family':'serif','serif':['Palatino']}) 7 | rc('text', usetex=True) 8 | plt.style.use('seaborn-pastel') 9 | means = [-2,0,1] 10 | stds = [1,0.8,0.6] 11 | weights = [1/4,1/4,1/2] 12 | 13 | xs = np.linspace(-5,5,200) 14 | 15 | plt.figure(figsize=(9,3)) 16 | all_ys = np.zeros_like(xs) 17 | for i in range(3): 18 | ys = weights[i]*scipy.stats.norm(means[i],stds[i]).pdf(xs) 19 | all_ys+=ys 20 | plt.plot(xs,ys,label=r'$\mu=${} $\sigma=${:.1f} $\lambda=${:.2f}'.format(means[i],stds[i],weights[i])) 21 | 22 | plt.plot(xs,all_ys,label='GMM',c='k',lw=4) 23 | plt.legend() 24 | plt.tight_layout() 25 | plt.savefig('thing.png',dpi=300) 26 | plt.show() 27 | -------------------------------------------------------------------------------- /vis_fitting.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.stats import multivariate_normal as mvn_pdf 3 | 4 | import matplotlib.pyplot as plt 5 | from scipy.special import logsumexp 6 | import mpl_toolkits.mplot3d as m3d 7 | 8 | 9 | means = [[0,0,-2],[0,2,0]] 10 | covars = [np.diag([.1,1,.3]),np.diag([1,.2,.1])] 11 | 12 | np.random.seed(30) 13 | N = 30 14 | pts = [] 15 | for m,c in zip(means,covars): 16 | pts.append(np.random.multivariate_normal(m,c,N)) 17 | pts = np.vstack(pts) 18 | 19 | labels = np.random.rand(N*len(means),len(means)) 20 | labels /= labels.sum(1,keepdims=True) 21 | 22 | for iteration in range(15): 23 | # m-step 24 | new_means = [] 25 | new_covars = [] 26 | new_pis = [] 27 | for k in range(len(means)): 28 | weights = labels[:,k:k+1] 29 | weight_norm = weights.sum() 30 | new_mean = (weights * pts).sum(0)/weight_norm 31 | new_means.append(new_mean) 32 | 33 | t = pts - new_mean 34 | new_covar = (weights/weight_norm * t).T @ t 35 | new_covars.append(new_covar) 36 | 37 | 
new_pis.append( weight_norm.mean() ) 38 | new_pis = np.array(new_pis) 39 | new_pis /= new_pis.sum() 40 | 41 | # e-step 42 | for k in range(len(means)): 43 | labels[:,k] = new_pis[k]*mvn_pdf(new_means[k],new_covars[k]).pdf(pts) 44 | labels /= labels.sum(1,keepdims=True) 45 | 46 | if (iteration % 1) == 0: 47 | fig = plt.figure(figsize=plt.figaspect(0.5)) 48 | 49 | ax = fig.add_subplot(1, 2, 1, projection='3d') 50 | 51 | colors = [tuple(int(h[i:i+2], 16) for i in (0, 2, 4)) for h in ['CA3542','27646B']] 52 | colors = np.array(colors)/255 53 | colors = np.array([[1,0,0],[0,0,1]]) 54 | ax.scatter(pts[:,0],pts[:,1],pts[:,2],s=20,c=labels@ colors) 55 | ax.set_xlim(-3.5,3.5) 56 | ax.set_ylim(-3.5,3.5) 57 | ax.set_zlim(-3.5,3.5) 58 | 59 | plt.title('E-Step Result',size=24,weight='demibold') 60 | plt.tight_layout() 61 | 62 | ax = fig.add_subplot(1, 2, 2, projection='3d') 63 | 64 | for k in range(len(new_means)): 65 | mean,covar = new_means[k],new_covars[k] 66 | u,s,vt = np.linalg.svd(covar) 67 | coefs = (.002, .002, .002) # Coefficients in a0/c x**2 + a1/c y**2 + a2/c z**2 = 1 68 | # Radii corresponding to the coefficients: 69 | rx, ry, rz = 1.7*np.sqrt(s)#s#1/np.sqrt(coefs) 70 | 71 | R_reg = vt.T @ np.diag([1,1,np.linalg.det(vt.T @ u.T)]) @ u.T 72 | 73 | #print(eigs) 74 | # Set of all spherical angles: 75 | u = np.linspace(0, 2 * np.pi, 30) 76 | v = np.linspace(0, np.pi, 30) 77 | 78 | # Cartesian coordinates that correspond to the spherical angles: 79 | # (this is the equation of an ellipsoid): 80 | x = rx * np.outer(np.cos(u), np.sin(v)) #+ mean[0] 81 | y = ry * np.outer(np.sin(u), np.sin(v)) #+ mean[1] 82 | z = rz * np.outer(np.ones_like(u), np.cos(v)) #+ mean[2] 83 | 84 | for i in range(len(x)): 85 | for j in range(len(x)): 86 | x[i,j],y[i,j],z[i,j] = np.dot([x[i,j],y[i,j],z[i,j]], vt) + mean 87 | # Plot: 88 | res = ax.plot_surface(x,y,z, color=colors[k],shade=True,linewidth=0.0,alpha=new_pis[k]) 89 | ax.set_xlim(-3.5,3.5) 90 | ax.set_ylim(-3.5,3.5) 91 | 
ax.set_zlim(-3.5,3.5) 92 | 93 | plt.title('M-Step Result',size=24,weight='demibold') 94 | plt.tight_layout() 95 | #plt.show() 96 | plt.savefig('output/{:02d}.png'.format(iteration),dpi=300) -------------------------------------------------------------------------------- /vis_fitting_bunny.py: -------------------------------------------------------------------------------- 1 | _tab20c_data = ( 2 | (0.19215686274509805, 0.5098039215686274, 0.7411764705882353 ), # 3182bd 3 | (0.4196078431372549, 0.6823529411764706, 0.8392156862745098 ), # 6baed6 4 | (0.6196078431372549, 0.792156862745098, 0.8823529411764706 ), # 9ecae1 5 | (0.7764705882352941, 0.8588235294117647, 0.9372549019607843 ), # c6dbef 6 | (0.9019607843137255, 0.3333333333333333, 0.050980392156862744), # e6550d 7 | (0.9921568627450981, 0.5529411764705883, 0.23529411764705882 ), # fd8d3c 8 | (0.9921568627450981, 0.6823529411764706, 0.4196078431372549 ), # fdae6b 9 | (0.9921568627450981, 0.8156862745098039, 0.6352941176470588 ), # fdd0a2 10 | (0.19215686274509805, 0.6392156862745098, 0.32941176470588235 ), # 31a354 11 | (0.4549019607843137, 0.7686274509803922, 0.4627450980392157 ), # 74c476 12 | (0.6313725490196078, 0.8509803921568627, 0.6078431372549019 ), # a1d99b 13 | (0.7803921568627451, 0.9137254901960784, 0.7529411764705882 ), # c7e9c0 14 | (0.4588235294117647, 0.4196078431372549, 0.6941176470588235 ), # 756bb1 15 | (0.6196078431372549, 0.6039215686274509, 0.7843137254901961 ), # 9e9ac8 16 | (0.7372549019607844, 0.7411764705882353, 0.8627450980392157 ), # bcbddc 17 | (0.8549019607843137, 0.8549019607843137, 0.9215686274509803 ), # dadaeb 18 | (0.38823529411764707, 0.38823529411764707, 0.38823529411764707 ), # 636363 19 | (0.5882352941176471, 0.5882352941176471, 0.5882352941176471 ), # 969696 20 | (0.7411764705882353, 0.7411764705882353, 0.7411764705882353 ), # bdbdbd 21 | (0.8509803921568627, 0.8509803921568627, 0.8509803921568627 ), # d9d9d9 22 | ) 23 | 24 | import numpy as np 25 | from scipy.stats import 
multivariate_normal as mvn_pdf 26 | 27 | import matplotlib.pyplot as plt 28 | from scipy.special import logsumexp 29 | import mpl_toolkits.mplot3d as m3d 30 | from mpl_toolkits.mplot3d import Axes3D 31 | import matplotlib.tri as mtri 32 | 33 | import pymesh 34 | 35 | 36 | 37 | def get_centroids(mesh): 38 | # obtain a vertex for each face index 39 | face_vert = mesh.vertices[mesh.faces.reshape(-1),:].reshape((mesh.faces.shape[0],3,-1)) #@ np.array([[1,0,0],[0,0,1],[0,-1,0] ]) 40 | # face_vert is size (faces,3(one for each vert), 3(one for each dimension)) 41 | centroids = face_vert.sum(1)/3.0 42 | ABAC = face_vert[:,1:3,:] - face_vert[:,0:1,:] 43 | areas = np.linalg.norm(np.cross(ABAC[:,0,:],ABAC[:,1,:]),axis=1)/2.0 44 | areas /= areas.min() 45 | areas = areas.reshape((-1,1)) 46 | return centroids, areas 47 | 48 | def get_tri_covar(tris): 49 | covars = [] 50 | for face in tris: 51 | A = face[0][:,None] 52 | B = face[1][:,None] 53 | C = face[2][:,None] 54 | M = (A+B+C)/3 55 | covars.append(A @ A.T + B @ B.T + C @ C.T - 3* M @ M.T) 56 | return np.array(covars)*(1/12.0) 57 | 58 | 59 | mesh0 = pymesh.load_mesh("bunny/bun_zipper_res4.ply") 60 | 61 | #pts = mesh0.vertices @ np.array([[1,0,0],[0,0,1],[0,-1,0] ]) 62 | 63 | pts,a = get_centroids(mesh0) 64 | face_vert = mesh0.vertices[mesh0.faces.reshape(-1),:].reshape((mesh0.faces.shape[0],3,-1)) #@ np.array([[1,0,0],[0,0,1],[0,-1,0] ]) 65 | data_covar = get_tri_covar(face_vert) 66 | 67 | K = 20 68 | colors = np.array(_tab20c_data)[:K] 69 | 70 | np.random.seed(24) 71 | 72 | labels = np.zeros((pts.shape[0],K)) 73 | labels[np.arange(pts.shape[0]), np.random.randint(0,K,pts.shape[0])] = 1 74 | #labels = np.exp(10*np.random.rand(pts.shape[0],K)) 75 | #labels /= labels.sum(1,keepdims=True) 76 | 77 | print(labels.max()) 78 | for iteration in range(150): 79 | # m-step 80 | new_means = [] 81 | new_covars = [] 82 | new_pis = [] 83 | for k in range(K): 84 | weights = a * labels[:,k:k+1] 85 | weight_norm = weights.sum() 86 | new_mean = 
(weights * pts).sum(0)/weight_norm 87 | new_means.append(new_mean) 88 | 89 | t = pts - new_mean 90 | new_covar = (weights/weight_norm * t).T @ t + ((weights/weight_norm).reshape((-1,1,1)) * data_covar).sum(0) 91 | new_covars.append(new_covar) 92 | 93 | new_pis.append( weight_norm.mean() ) 94 | new_pis = np.array(new_pis) 95 | new_pis /= new_pis.sum() 96 | 97 | # e-step 98 | for k in range(K): 99 | labels[:,k] = new_pis[k]*mvn_pdf(new_means[k],new_covars[k]).pdf(pts) 100 | labels /= labels.sum(1,keepdims=True) 101 | 102 | if (iteration % 1) == 0: 103 | fig = plt.figure(figsize=plt.figaspect(0.5),frameon=False) 104 | 105 | ax = fig.add_subplot(1,2, 1, projection='3d') 106 | 107 | #colors = [tuple(int(h[i:i+2], 16) for i in (0, 2, 4)) for h in ['CA3542','27646B']] 108 | #colors = np.array(colors)/255 109 | #colors = np.array([[1,0,0],[0,0,1]]) 110 | #ax.scatter(pts[:,0],pts[:,1],pts[:,2],s=20,c=labels@ colors) 111 | res = ax.plot_trisurf(mesh0.vertices[:,0],mesh0.vertices[:,1],mesh0.vertices[:,2],triangles=mesh0.faces,facecolors=labels@ colors) 112 | normals = ax._generate_normals(face_vert) 113 | colset = ax._shade_colors(labels@ colors, normals) 114 | res.set_facecolors(colset) 115 | 116 | r = max(pts.max(1) - pts.min(1))/2 117 | m = pts.mean(1) 118 | 119 | ax.set_xlim(m[0]-r,m[0]+r) 120 | ax.set_xlim(m[1]-r,m[1]+r) 121 | ax.set_xlim(m[2]-r,m[2]+r) 122 | ax.view_init(100,-90) 123 | ax.xaxis.set_ticklabels([]) 124 | ax.yaxis.set_ticklabels([]) 125 | ax.zaxis.set_ticklabels([]) 126 | #ax.set_aspect('equal', 'box') 127 | 128 | plt.title('E-Step Result',size=24,weight='demibold') 129 | plt.tight_layout() 130 | 131 | ax = fig.add_subplot(1,2, 2, projection='3d') 132 | 133 | for k in range(len(new_means)): 134 | mean,covar = new_means[k],new_covars[k] 135 | u,s,vt = np.linalg.svd(covar) 136 | coefs = (.002, .002, .002) # Coefficients in a0/c x**2 + a1/c y**2 + a2/c z**2 = 1 137 | # Radii corresponding to the coefficients: 138 | rx, ry, rz = 
1.7*np.sqrt(s)#s#1/np.sqrt(coefs) 139 | 140 | R_reg = vt.T @ np.diag([1,1,np.linalg.det(vt.T @ u.T)]) @ u.T 141 | 142 | #print(eigs) 143 | # Set of all spherical angles: 144 | u = np.linspace(0, 2 * np.pi, 10) 145 | v = np.linspace(0, np.pi, 10) 146 | 147 | # Cartesian coordinates that correspond to the spherical angles: 148 | # (this is the equation of an ellipsoid): 149 | x = rx * np.outer(np.cos(u), np.sin(v)) #+ mean[0] 150 | y = ry * np.outer(np.sin(u), np.sin(v)) #+ mean[1] 151 | z = rz * np.outer(np.ones_like(u), np.cos(v)) #+ mean[2] 152 | 153 | for i in range(len(x)): 154 | for j in range(len(x)): 155 | x[i,j],y[i,j],z[i,j] = np.dot([x[i,j],y[i,j],z[i,j]], vt) + mean 156 | # Plot: 157 | res = ax.plot_surface(x,y,z, color=colors[k],shade=True,linewidth=0.0,alpha=min(0.5,new_pis[k]*K)) 158 | ax.set_xlim(m[0]-r,m[0]+r) 159 | ax.set_xlim(m[1]-r,m[1]+r) 160 | ax.set_xlim(m[2]-r,m[2]+r) 161 | ax.view_init(100,-90) 162 | ax.xaxis.set_ticklabels([]) 163 | ax.yaxis.set_ticklabels([]) 164 | ax.zaxis.set_ticklabels([]) 165 | #ax.set_aspect('equal', 'box') 166 | 167 | 168 | plt.title('M-Step Result',size=24,weight='demibold') 169 | plt.tight_layout() 170 | #plt.show() 171 | plt.savefig('output4/{:02d}.png'.format(iteration),dpi=300,pad_inches=0) 172 | plt.close('all') -------------------------------------------------------------------------------- /vis_fitting_bunny_mesh.py: -------------------------------------------------------------------------------- 1 | _tab20c_data = ( 2 | (0.19215686274509805, 0.5098039215686274, 0.7411764705882353 ), # 3182bd 3 | (0.4196078431372549, 0.6823529411764706, 0.8392156862745098 ), # 6baed6 4 | (0.6196078431372549, 0.792156862745098, 0.8823529411764706 ), # 9ecae1 5 | (0.7764705882352941, 0.8588235294117647, 0.9372549019607843 ), # c6dbef 6 | (0.9019607843137255, 0.3333333333333333, 0.050980392156862744), # e6550d 7 | (0.9921568627450981, 0.5529411764705883, 0.23529411764705882 ), # fd8d3c 8 | (0.9921568627450981, 0.6823529411764706, 
0.4196078431372549 ), # fdae6b 9 | (0.9921568627450981, 0.8156862745098039, 0.6352941176470588 ), # fdd0a2 10 | (0.19215686274509805, 0.6392156862745098, 0.32941176470588235 ), # 31a354 11 | (0.4549019607843137, 0.7686274509803922, 0.4627450980392157 ), # 74c476 12 | (0.6313725490196078, 0.8509803921568627, 0.6078431372549019 ), # a1d99b 13 | (0.7803921568627451, 0.9137254901960784, 0.7529411764705882 ), # c7e9c0 14 | (0.4588235294117647, 0.4196078431372549, 0.6941176470588235 ), # 756bb1 15 | (0.6196078431372549, 0.6039215686274509, 0.7843137254901961 ), # 9e9ac8 16 | (0.7372549019607844, 0.7411764705882353, 0.8627450980392157 ), # bcbddc 17 | (0.8549019607843137, 0.8549019607843137, 0.9215686274509803 ), # dadaeb 18 | (0.38823529411764707, 0.38823529411764707, 0.38823529411764707 ), # 636363 19 | (0.5882352941176471, 0.5882352941176471, 0.5882352941176471 ), # 969696 20 | (0.7411764705882353, 0.7411764705882353, 0.7411764705882353 ), # bdbdbd 21 | (0.8509803921568627, 0.8509803921568627, 0.8509803921568627 ), # d9d9d9 22 | ) 23 | 24 | import numpy as np 25 | from scipy.stats import multivariate_normal as mvn_pdf 26 | 27 | import matplotlib.pyplot as plt 28 | from scipy.special import logsumexp 29 | import mpl_toolkits.mplot3d as m3d 30 | from mpl_toolkits.mplot3d import Axes3D 31 | import matplotlib.tri as mtri 32 | import sys 33 | import pymesh 34 | 35 | 36 | 37 | def get_centroids(mesh): 38 | # obtain a vertex for each face index 39 | face_vert = mesh.vertices[mesh.faces.reshape(-1),:].reshape((mesh.faces.shape[0],3,-1)) #@ np.array([[1,0,0],[0,0,1],[0,-1,0] ]) 40 | # face_vert is size (faces,3(one for each vert), 3(one for each dimension)) 41 | centroids = face_vert.sum(1)/3.0 42 | ABAC = face_vert[:,1:3,:] - face_vert[:,0:1,:] 43 | areas = np.linalg.norm(np.cross(ABAC[:,0,:],ABAC[:,1,:]),axis=1)/2.0 44 | areas /= areas.min() 45 | areas = areas.reshape((-1,1)) 46 | return centroids, areas 47 | 48 | def get_tri_covar(tris): 49 | covars = [] 50 | for face in 
tris: 51 | A = face[0][:,None] 52 | B = face[1][:,None] 53 | C = face[2][:,None] 54 | M = (A+B+C)/3 55 | covars.append(A @ A.T + B @ B.T + C @ C.T - 3* M @ M.T) 56 | return np.array(covars)*(1/12.0) 57 | 58 | 59 | mesh0 = pymesh.load_mesh("bunny/bun_zipper_res4.ply") 60 | 61 | #pts = mesh0.vertices @ np.array([[1,0,0],[0,0,1],[0,-1,0] ]) 62 | 63 | 64 | pts,a = get_centroids(mesh0) 65 | r = max(pts.max(1) - pts.min(1))/2 66 | m = pts.mean(1) 67 | #face_vert = mesh0.vertices[mesh0.faces.reshape(-1),:].reshape((mesh0.faces.shape[0],3,-1)) #@ np.array([[1,0,0],[0,0,1],[0,-1,0] ]) 68 | #data_covar = get_tri_covar(face_vert) 69 | mesh0 = pymesh.load_mesh("bunny/bun_zipper_pts_1000_1.ply") 70 | 71 | pts = mesh0.vertices 72 | 73 | K = 20 74 | colors = np.array(_tab20c_data)[:K] 75 | 76 | np.random.seed(42) 77 | 78 | labels = np.zeros((pts.shape[0],K)) 79 | labels[np.arange(pts.shape[0]), np.random.randint(0,K,pts.shape[0])] = 1 80 | #labels = np.exp(10*np.random.rand(pts.shape[0],K)) 81 | #labels /= labels.sum(1,keepdims=True) 82 | 83 | print(labels.max()) 84 | for iteration in range(150): 85 | # m-step 86 | new_means = [] 87 | new_covars = [] 88 | new_pis = [] 89 | for k in range(K): 90 | weights = labels[:,k:k+1] 91 | weight_norm = weights.sum() 92 | new_mean = (weights * pts).sum(0)/weight_norm 93 | new_means.append(new_mean) 94 | 95 | t = pts - new_mean 96 | new_covar = (weights/weight_norm * t).T @ t #+ ((weights/weight_norm).reshape((-1,1,1)) * data_covar).sum(0) 97 | new_covars.append(new_covar) 98 | 99 | new_pis.append( weight_norm.mean() ) 100 | new_pis = np.array(new_pis) 101 | new_pis /= new_pis.sum() 102 | 103 | # e-step 104 | for k in range(K): 105 | try: 106 | labels[:,k] = new_pis[k]*mvn_pdf(new_means[k],new_covars[k]).pdf(pts) 107 | except: 108 | new_covars[k] = np.identity(3) * 1e-6 109 | labels[:,k] = new_pis[k]*mvn_pdf(new_means[k],new_covars[k]).pdf(pts) 110 | 111 | labels /= labels.sum(1,keepdims=True) 112 | 113 | if (iteration % 1) == 0: 114 | fig = 
plt.figure(figsize=plt.figaspect(0.5),frameon=False) 115 | 116 | ax = fig.add_subplot(1,2, 1, projection='3d') 117 | 118 | #colors = [tuple(int(h[i:i+2], 16) for i in (0, 2, 4)) for h in ['CA3542','27646B']] 119 | #colors = np.array(colors)/255 120 | #colors = np.array([[1,0,0],[0,0,1]]) 121 | ax.scatter(pts[:,0],pts[:,1],pts[:,2],s=20,c=labels@ colors) 122 | 123 | 124 | 125 | ax.set_xlim(m[0]-r,m[0]+r) 126 | ax.set_xlim(m[1]-r,m[1]+r) 127 | ax.set_xlim(m[2]-r,m[2]+r) 128 | ax.view_init(100,-90) 129 | ax.xaxis.set_ticklabels([]) 130 | ax.yaxis.set_ticklabels([]) 131 | ax.zaxis.set_ticklabels([]) 132 | #ax.set_aspect('equal', 'box') 133 | 134 | plt.title('E-Step Result',size=24,weight='demibold') 135 | plt.tight_layout() 136 | 137 | ax = fig.add_subplot(1,2, 2, projection='3d') 138 | 139 | for k in range(len(new_means)): 140 | mean,covar = new_means[k],new_covars[k] 141 | u,s,vt = np.linalg.svd(covar) 142 | coefs = (.002, .002, .002) # Coefficients in a0/c x**2 + a1/c y**2 + a2/c z**2 = 1 143 | # Radii corresponding to the coefficients: 144 | rx, ry, rz = 1.7*np.sqrt(s)#s#1/np.sqrt(coefs) 145 | 146 | R_reg = vt.T @ np.diag([1,1,np.linalg.det(vt.T @ u.T)]) @ u.T 147 | 148 | #print(eigs) 149 | # Set of all spherical angles: 150 | u = np.linspace(0, 2 * np.pi, 10) 151 | v = np.linspace(0, np.pi, 10) 152 | 153 | # Cartesian coordinates that correspond to the spherical angles: 154 | # (this is the equation of an ellipsoid): 155 | x = rx * np.outer(np.cos(u), np.sin(v)) #+ mean[0] 156 | y = ry * np.outer(np.sin(u), np.sin(v)) #+ mean[1] 157 | z = rz * np.outer(np.ones_like(u), np.cos(v)) #+ mean[2] 158 | 159 | for i in range(len(x)): 160 | for j in range(len(x)): 161 | x[i,j],y[i,j],z[i,j] = np.dot([x[i,j],y[i,j],z[i,j]], vt) + mean 162 | # Plot: 163 | res = ax.plot_surface(x,y,z, color=colors[k],shade=True,linewidth=0.0,alpha=min(0.5,new_pis[k]*K)) 164 | ax.set_xlim(m[0]-r,m[0]+r) 165 | ax.set_xlim(m[1]-r,m[1]+r) 166 | ax.set_xlim(m[2]-r,m[2]+r) 167 | 
ax.view_init(100,-90) 168 | ax.xaxis.set_ticklabels([]) 169 | ax.yaxis.set_ticklabels([]) 170 | ax.zaxis.set_ticklabels([]) 171 | #ax.set_aspect('equal', 'box') 172 | 173 | 174 | plt.title('M-Step Result',size=24,weight='demibold') 175 | plt.tight_layout() 176 | #plt.show() 177 | plt.savefig('output5/{:02d}.png'.format(iteration),dpi=300,pad_inches=0) 178 | plt.close('all') -------------------------------------------------------------------------------- /visualize gmm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pylab import *\n", 10 | "import pymesh\n", 11 | "from mixture import GaussianMixture\n", 12 | "\n", 13 | "def get_centroids(mesh):\n", 14 | " # obtain a vertex for each face index\n", 15 | " face_vert = mesh.vertices[mesh.faces.reshape(-1),:].reshape((mesh.faces.shape[0],3,-1))\n", 16 | " # face_vert is size (faces,3(one for each vert), 3(one for each dimension))\n", 17 | " centroids = face_vert.sum(1)/3.0\n", 18 | " #face_vert = ((face_vert.shape[0]/SAMPLE_PTS)*(face_vert.reshape((-1,9))-np.repeat(centroids,3,axis=1)) + np.repeat(centroids,3,axis=1)).reshape((-1,3,3))\n", 19 | " ABAC = face_vert[:,1:3,:] - face_vert[:,0:1,:]\n", 20 | " areas = np.linalg.norm(np.cross(ABAC[:,0,:],ABAC[:,1,:]),axis=1)/2.0\n", 21 | " return centroids, areas\n", 22 | "\n", 23 | "def get_tri_covar(tris):\n", 24 | " covars = []\n", 25 | " for face in face_vert:\n", 26 | " A = face[0][:,None]\n", 27 | " B = face[1][:,None]\n", 28 | " C = face[2][:,None]\n", 29 | " M = (A+B+C)/3\n", 30 | " covars.append(A @ A.T + B @ B.T + C @ C.T - 3* M @ M.T)\n", 31 | " return np.array(covars)*(1/12.0)\n" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "import transforms3d\n", 41 | "mesh0 = 
pymesh.load_mesh(\"bunny/bun_zipper_1000_1.ply\")\n", 42 | "K = 100\n", 43 | "com,a = get_centroids(mesh0)\n", 44 | "center = mesh0.vertices.mean(0)\n", 45 | "com -= center\n", 46 | "print(com.shape)\n", 47 | "face_vert = mesh0.vertices[mesh0.faces.reshape(-1),:].reshape((mesh0.faces.shape[0],3,-1)) - center\n", 48 | "R = np.identity(3)\n", 49 | "R = transforms3d.euler.euler2mat(0,np.pi/2,np.pi/2, 'sxyz')\n", 50 | "R = transforms3d.euler.euler2mat(np.pi,0,np.pi, 'sxyz') @ R\n", 51 | "\n", 52 | "face_vert = face_vert @ R\n", 53 | "com = com @ R\n", 54 | "\n", 55 | "data_covar = get_tri_covar(face_vert)\n", 56 | "print(data_covar.shape)" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "gm_mesh_kmeans = GaussianMixture(K,init_params='kmeans',tol=1e-9,max_iter=250); gm_mesh_kmeans.set_covars(data_covar); gm_mesh_kmeans.set_areas(a); gm_mesh_kmeans.fit(com); gm_mesh_kmeans.set_covars(None); gm_mesh_kmeans.set_areas(None)\n", 66 | "#gm_mesh_kmeans = GaussianMixture(K,init_params='kmeans',tol=1e-9,max_iter=250); gm_mesh_kmeans.fit(com);" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "%matplotlib notebook\n", 76 | "from mpl_toolkits.mplot3d import Axes3D\n", 77 | "from mpl_toolkits.mplot3d.art3d import Poly3DCollection\n", 78 | "import matplotlib.pyplot as plt\n", 79 | "from matplotlib.colors import LightSource\n", 80 | "\n", 81 | "\n", 82 | "fig = plt.figure(figsize=(15,6))\n", 83 | "ax = fig.add_subplot(1, 2, 1, projection='3d')\n", 84 | "b2 = np.array(face_vert)\n", 85 | "minb = b2.reshape((-1,3)).min(0)\n", 86 | "maxb = b2.reshape((-1,3)).max(0)\n", 87 | "meanb = b2.reshape((-1,3)).mean(0)\n", 88 | "span = maxb-minb\n", 89 | "w = np.max(span) * np.ones(3)\n", 90 | "pc = Poly3DCollection(face_vert, alpha = 0.75, facecolor='#1f77b4', 
linewidths=1,edgecolors=np.array([0.7,0.0,0.5,0.4]))\n", 91 | "ax.add_collection3d(pc)\n", 92 | "ax.set_xlim(minb[0]-w[0],maxb[0]+w[0])\n", 93 | "ax.set_ylim(minb[1]-w[1],maxb[1]+w[1])\n", 94 | "ax.set_zlim(minb[2]-w[2],maxb[2]+w[2])\n", 95 | "ax.set_proj_type('persp')\n", 96 | "\n", 97 | "\n", 98 | "# Hide grid lines\n", 99 | "ax.grid(False)\n", 100 | "\n", 101 | "\n", 102 | "\n", 103 | "# Hide axes ticks\n", 104 | "ax.set_xticks([])\n", 105 | "ax.set_yticks([])\n", 106 | "ax.set_zticks([])\n", 107 | "ax.view_init(26,-28)\n", 108 | "plt.axis('off')\n", 109 | "ax.dist = 3\n", 110 | "\n", 111 | "xlm=ax.get_xlim3d() #These are two tupples\n", 112 | "ylm=ax.get_ylim3d() #we use them in the next\n", 113 | "zlm=ax.get_zlim3d() #graph to reproduce the magnification from mousing\n", 114 | "\n", 115 | "ax.set_xlim3d(xlm[0]-0.04,xlm[1]-0.04) #Reproduce magnification\n", 116 | "ax.set_ylim3d(ylm[0],ylm[1]) #...\n", 117 | "ax.set_zlim3d(zlm[0],zlm[1]) #...\n", 118 | "\n", 119 | "ax = fig.add_subplot(1, 2, 2, projection='3d')\n", 120 | "#plot_gmm_contour(ax, gm_mesh_kmeans.means_, gm_mesh_kmeans.covariances_, gm_mesh_kmeans.weights_)\n", 121 | "points=10000\n", 122 | "z = []\n", 123 | "c = []\n", 124 | "colors = plt.cm.tab10(np.arange(10)/10) #plt.rcParams['axes.prop_cycle'].by_key()['color']\n", 125 | "means,covars,weights = com,data_covar,a/np.sum(a)#gm_mesh_kmeans.means_, gm_mesh_kmeans.covariances_, gm_mesh_kmeans.weights_\n", 126 | "for e,i in enumerate([14,428,412,446,451,411,439,477]):#range(means.shape[0]):\n", 127 | " if np.random.rand() < 5:\n", 128 | " samples = int(round(points*weights[i]))\n", 129 | " z.append(np.random.multivariate_normal(means[i],covars[i],samples))\n", 130 | " c += [colors[e%10] for _ in range(samples)]\n", 131 | "z = np.vstack(z)\n", 132 | "c = np.vstack(c)\n", 133 | "\n", 134 | "ax.set_xticks([])\n", 135 | "ax.set_yticks([])\n", 136 | "ax.set_zticks([])\n", 137 | "ax.view_init(26,-28)\n", 138 | "plt.axis('off')\n", 139 | 
"ax.set_proj_type('persp')\n", 140 | "\n", 141 | "ax.dist = 3\n", 142 | "pc = Poly3DCollection(face_vert, alpha = 0.1, facecolor='#ffffff', linewidths=1,edgecolors=np.array([0.7,0.0,0.5,0.1]))\n", 143 | "ax.add_collection3d(pc)\n", 144 | "#ax.scatter(z[:,0],z[:,1],z[:,2],c=c,depthshade=False)\n", 145 | "\n", 146 | "ax.set_xlim3d(xlm[0]-0.04,xlm[1]-0.04) #Reproduce magnification\n", 147 | "ax.set_ylim3d(ylm[0],ylm[1]) #...\n", 148 | "ax.set_zlim3d(zlm[0],zlm[1]) #...\n", 149 | "\n", 150 | "faces_of_interest = [14,428,412,446,451,411,439,477]\n", 151 | "for e,i in enumerate(faces_of_interest):#range(means.shape[0]):\n", 152 | " mean,covar = means[i],covars[i]\n", 153 | " u,s,vt = np.linalg.svd(covar)\n", 154 | " coefs = (.002, .002, .002) # Coefficients in a0/c x**2 + a1/c y**2 + a2/c z**2 = 1 \n", 155 | " # Radii corresponding to the coefficients:\n", 156 | " rx, ry, rz = 1.5*np.sqrt(s)#s#1/np.sqrt(coefs)\n", 157 | " \n", 158 | " R_reg = vt.T @ np.diag([1,1,np.linalg.det(vt.T @ u.T)]) @ u.T\n", 159 | " \n", 160 | " #print(eigs)\n", 161 | " # Set of all spherical angles:\n", 162 | " u = np.linspace(0, 2 * np.pi, 25)\n", 163 | " v = np.linspace(0, np.pi, 25)\n", 164 | "\n", 165 | " # Cartesian coordinates that correspond to the spherical angles:\n", 166 | " # (this is the equation of an ellipsoid):\n", 167 | " x = rx * np.outer(np.cos(u), np.sin(v)) #+ mean[0]\n", 168 | " y = ry * np.outer(np.sin(u), np.sin(v)) #+ mean[1]\n", 169 | " z = rz * np.outer(np.ones_like(u), np.cos(v)) #+ mean[2]\n", 170 | " \n", 171 | " for i in range(len(x)):\n", 172 | " for j in range(len(x)):\n", 173 | " x[i,j],y[i,j],z[i,j] = np.dot([x[i,j],y[i,j],z[i,j]], vt) + mean \n", 174 | " # Plot:\n", 175 | " ax.plot_surface(x,y,z, color=colors[e%10],shade=False,linewidth=0.0)\n", 176 | "pc = Poly3DCollection(face_vert[faces_of_interest], alpha = 0.0, facecolor='#ffffff', linewidths=2,edgecolors=np.array([0.7,0.0,0.5,0.5]))\n", 177 | "ax.add_collection3d(pc)\n", 178 | "plt.tight_layout()\n", 179 
| "plt.savefig('bunny_head.pdf')\n", 180 | "\n", 181 | "plt.show()" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "sz = 4\n", 191 | "fig = plt.figure(figsize=(1.4475*sz,sz))\n", 192 | "padw = (1.4475-1)/2\n", 193 | "#ax = fig.add_subplot(1, 1, 1, projection='3d')#Axes3D(fig)\n", 194 | "# L B W H\n", 195 | "ax = Axes3D(fig, [0, -0.15, 1-padw, 1.1], )\n", 196 | "#plot_gmm_contour(ax, gm_mesh_kmeans.means_, gm_mesh_kmeans.covariances_, gm_mesh_kmeans.weights_)\n", 197 | "points=10000\n", 198 | "z = []\n", 199 | "c = []\n", 200 | "ax.autoscale_view(tight=True, scalex=False, scaley=False, scalez=False)\n", 201 | "\n", 202 | "pc = Poly3DCollection(face_vert, alpha = 0.01, facecolor='#ffffff', linewidths=1,edgecolors=np.array([0.7,0.0,0.5,0.1]))\n", 203 | "ax.add_collection3d(pc)\n", 204 | "#ax.scatter(z[:,0],z[:,1],z[:,2],c=c,depthshade=False)\n", 205 | "\n", 206 | "means,covars,weights = gm_mesh_kmeans.means_, gm_mesh_kmeans.covariances_, gm_mesh_kmeans.weights_\n", 207 | "\n", 208 | "colors = plt.cm.Pastel1(np.arange(means.shape[0])/(means.shape[0])) #plt.rcParams['axes.prop_cycle'].by_key()['color']\n", 209 | "np.random.shuffle(colors)\n", 210 | "\n", 211 | "faces_of_interest = [14,428,412,446,451,411,439,477]\n", 212 | "for i in range(means.shape[0]):\n", 213 | " orig_i = i\n", 214 | " mean,covar = means[i],covars[i]\n", 215 | " u,s,vt = np.linalg.svd(covar)\n", 216 | " coefs = (.002, .002, .002) # Coefficients in a0/c x**2 + a1/c y**2 + a2/c z**2 = 1 \n", 217 | " # Radii corresponding to the coefficients:\n", 218 | " rx, ry, rz = 1.7*np.sqrt(s)#s#1/np.sqrt(coefs)\n", 219 | " \n", 220 | " R_reg = vt.T @ np.diag([1,1,np.linalg.det(vt.T @ u.T)]) @ u.T\n", 221 | " \n", 222 | " #print(eigs)\n", 223 | " # Set of all spherical angles:\n", 224 | " u = np.linspace(0, 2 * np.pi, 10)\n", 225 | " v = np.linspace(0, np.pi, 10)\n", 226 | "\n", 227 | " # Cartesian coordinates 
that correspond to the spherical angles:\n", 228 | " # (this is the equation of an ellipsoid):\n", 229 | " x = rx * np.outer(np.cos(u), np.sin(v)) #+ mean[0]\n", 230 | " y = ry * np.outer(np.sin(u), np.sin(v)) #+ mean[1]\n", 231 | " z = rz * np.outer(np.ones_like(u), np.cos(v)) #+ mean[2]\n", 232 | " \n", 233 | " for i in range(len(x)):\n", 234 | " for j in range(len(x)):\n", 235 | " x[i,j],y[i,j],z[i,j] = np.dot([x[i,j],y[i,j],z[i,j]], vt) + mean \n", 236 | " # Plot:\n", 237 | " res = ax.plot_surface(x,y,z, color=colors[orig_i],shade=True,linewidth=0.0)\n", 238 | "plt.tight_layout()\n", 239 | "\n", 240 | "ax.set_xticks([])\n", 241 | "ax.set_yticks([])\n", 242 | "ax.set_zticks([])\n", 243 | "ax.view_init(39.4,14)\n", 244 | "plt.axis('off')\n", 245 | "ax.set_proj_type('persp')\n", 246 | "ax.dist = 4.2\n", 247 | "ax.set_xlim(meanb[0]-w[0],meanb[0]+w[0])\n", 248 | "ax.set_ylim(meanb[1]-w[1],meanb[1]+w[1])\n", 249 | "ax.set_zlim(meanb[2]-w[2],meanb[2]+w[2])\n", 250 | "ax.set_proj_type('persp')\n", 251 | "plt.tight_layout()\n", 252 | "plt.savefig('bunny_gmm_k100_mesh.pdf')\n", 253 | "plt.show()" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "res[0]" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": null, 268 | "metadata": {}, 269 | "outputs": [], 270 | "source": [] 271 | } 272 | ], 273 | "metadata": { 274 | "kernelspec": { 275 | "display_name": "Python 3", 276 | "language": "python", 277 | "name": "python3" 278 | }, 279 | "language_info": { 280 | "codemirror_mode": { 281 | "name": "ipython", 282 | "version": 3 283 | }, 284 | "file_extension": ".py", 285 | "mimetype": "text/x-python", 286 | "name": "python", 287 | "nbconvert_exporter": "python", 288 | "pygments_lexer": "ipython3", 289 | "version": "3.6.8" 290 | } 291 | }, 292 | "nbformat": 4, 293 | "nbformat_minor": 2 294 | } 295 | 
--------------------------------------------------------------------------------