├── .gitignore ├── .gitmodules ├── .vscode ├── launch.json └── settings.json ├── LICENSE ├── README.md ├── controls ├── __init__.py ├── pose_slider.py ├── shape_slider.py └── trans_slider.py ├── docs └── smal_viewer.gif ├── p3d_renderer.py ├── pyqt_viewer.py ├── pyrenderer.py ├── smal_model ├── __init__.py ├── batch_lbs.py ├── smal_basics.py ├── smal_torch.py ├── template_w_tex_uv.mtl └── template_w_tex_uv.obj └── smal_viewer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | *.pkl 106 | 107 | data/ -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "SMPL"] 2 | path = SMPL 3 | url = https://github.com/benjiebob/SMPL 4 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: SMAL_Viewer", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${workspaceFolder}/smal_viewer.py", 12 | "console": "integratedTerminal" 13 | } 14 | ] 15 | } -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.pythonPath": "C:\\Users\\bjb10042\\.conda\\envs\\bjb_env\\python.exe" 3 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Benjamin Biggs 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SMAL Viewer 2 | PyQt5 app for viewing SMAL meshes 3 | 4 | 5 | 6 | ## Installation 7 | 1. Clone the repository and enter the directory 8 | ``` 9 | git clone https://github.com/benjiebob/SMALViewer 10 | cd SMALViewer 11 | ``` 12 | 13 | 2. Clone the [SMALST](https://github.com/silviazuffi/smalst) repository in order to access the latest version of the SMAL deformable animal model. You should copy all of [these files](https://github.com/silviazuffi/smalst/tree/master/smpl_models) underneath a SMALViewer/data directory. 14 | 15 | Windows tip: you can use these files, but you'll need to convert the line endings. Try the following PowerShell commands, shown here for one example file: 16 | ``` 17 | $path="my_smpl_00781_4_all_template_w_tex_uv_001.pkl" 18 | (Get-Content $path -Raw).Replace("`r`n","`n") | Set-Content $path -Force 19 | ``` 20 | 21 | For more information, check out the Stack Overflow answer [here](https://stackoverflow.com/questions/19127741/replace-crlf-using-powershell). 22 | 23 | 24 | 3. Install dependencies, particularly [PyTorch](https://pytorch.org/), [PyQt5](https://pypi.org/project/PyQt5/), [Pyrender](https://github.com/mmatl/pyrender) and [nibabel](https://github.com/nipy/nibabel).
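For reference, here is a minimal, unpinned install sketch (an assumption on my part rather than an official requirements list; install PyTorch from its own site if you need a specific CUDA build, and add PyTorch3D only if you plan to use the optional p3d_renderer described below):
```
pip install torch torchvision PyQt5 pyrender trimesh nibabel numpy scipy
```
Each of these packages is imported somewhere in the repository.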
25 | 26 | Tip for debugging offscreen rendering: If you are a Linux user and have trouble with Pyrender's OffscreenRenderer, I recommend following the steps to install OSMesa [here](https://pyrender.readthedocs.io/en/latest/examples/offscreen.html), which include adding the following to the top of pyrenderer.py: 27 | 28 | ``` 29 | os.environ['PYOPENGL_PLATFORM'] = 'osmesa' 30 | ``` 31 | 32 | If you are a Windows user and you experience issues with the OffscreenRenderer, you can fix this by following the advice [here](https://github.com/mmatl/pyrender/issues/117). A quick fix is to edit the function "make_uncurrent" in pyrender/platforms/pyglet_platform.py, L53 (wherever it's installed for you) to: 33 | 34 | ``` 35 | def make_uncurrent(self): 36 | try: 37 | import pyglet.gl.xlib 38 | pyglet.gl.xlib.glx.glXMakeContextCurrent(self._window.context.x_display, 0, 0, None) 39 | except: 40 | pass 41 | ``` 42 | 43 | 4. Download [SMPL](https://smpl.is.tue.mpg.de/) and copy its smpl_webuser directory underneath SMALViewer/smal_model (the code imports smal_model.smpl_webuser.serialization) 44 | 45 | 5. Test the Python 3 script 46 | ``` 47 | python smal_viewer.py 48 | ``` 49 | ## Differentiable Rendering 50 | 51 | For many research applications, it is useful to be able to propagate gradients from 2D losses (e.g. silhouette or perceptual losses) back through the rendering process. For this, one should use a differentiable renderer such as [PyTorch3D](https://github.com/facebookresearch/pytorch3d) or [Neural Mesh Renderer](https://github.com/daniilidis-group/neural_renderer). Although not needed for this simple demo app, I have included a script p3d_renderer.py which shows how to achieve differentiable rendering of the SMAL mesh with PyTorch3D. You can flip between the two rendering methods by switching between the two imports at the top of pyqt_viewer.py: 52 | 53 | ``` 54 | from pyrenderer import Renderer 55 | # from p3d_renderer import Renderer 56 | ``` 57 | 58 | Please note that PyTorch3D is significantly slower than Pyrender, so you'll probably experience some lag with this option. 59 | 60 | For completeness, I've also shown how to apply a texture map to the SMAL mesh with p3d_renderer (again useful for perceptual losses). To do this, you will need to download an example SMAL texture map: create an account on the [SMALR page](http://smalr.is.tue.mpg.de/downloads), choose CVPR Downloads and download (for example) the Dog B zip file. Extract this underneath ./data. 61 | 62 | ## Acknowledgements 63 | This work was completed in relation to the paper [Creatures Great and SMAL: Recovering the shape and motion of animals from video](https://arxiv.org/abs/1811.05804): 64 | ``` 65 | @inproceedings{biggs2018creatures, 66 | title={{C}reatures great and {SMAL}: {R}ecovering the shape and motion of animals from video}, 67 | author={Biggs, Benjamin and Roddick, Thomas and Fitzgibbon, Andrew and Cipolla, Roberto}, 68 | booktitle={ACCV}, 69 | year={2018} 70 | } 71 | ``` 72 | 73 | and more recently [Who Left the Dogs Out? 
3D Animal Reconstruction with Expectation Maximization in the Loop](https://arxiv.org/abs/2007.11110): 74 | ``` 75 | @inproceedings{biggs2020wldo, 76 | title={{W}ho left the dogs out?: {3D} animal reconstruction with expectation maximization in the loop}, 77 | author={Biggs, Benjamin and Boyne, Oliver and Charles, James and Fitzgibbon, Andrew and Cipolla, Roberto}, 78 | booktitle={ECCV}, 79 | year={2020} 80 | } 81 | ``` 82 | 83 | Please also acknowledge the original authors of the SMAL animal model: 84 | ``` 85 | @inproceedings{Zuffi:CVPR:2017, 86 | title = {{3D} Menagerie: Modeling the {3D} Shape and Pose of Animals}, 87 | author = {Zuffi, Silvia and Kanazawa, Angjoo and Jacobs, David and Black, Michael J.}, 88 | booktitle = {IEEE Conf. on Computer Vision and Pattern Recognition (CVPR)}, 89 | month = jul, 90 | year = {2017}, 91 | month_numeric = {7} 92 | } 93 | ``` 94 | -------------------------------------------------------------------------------- /controls/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/benjiebob/SMALViewer/ca54072ad5d7c78b2bf4ed19945ed2e052e7fbe9/controls/__init__.py -------------------------------------------------------------------------------- /controls/pose_slider.py: -------------------------------------------------------------------------------- 1 | from PyQt5 import QtGui, QtCore 2 | from PyQt5.QtWidgets import QWidget, QLabel, QSlider, QVBoxLayout, QHBoxLayout 3 | 4 | import numpy as np 5 | from nibabel import eulerangles 6 | 7 | class PoseSlider(QWidget): 8 | value_changed = QtCore.pyqtSignal(np.ndarray) 9 | 10 | def __init__(self, idx, angle_range, vert_stack = False, label_to_side = True): 11 | super(PoseSlider, self).__init__() 12 | self.idx = idx 13 | 14 | self.max_angle = 0.5 * angle_range 15 | self.min_angle = -0.5 * angle_range 16 | 17 | self.value = np.array([0.0, 0.0, 0.0]) 18 | self.axis_value = self.eul_to_axis(self.value) 19 | self.eul_value_label = QLabel(self.np_rots_to_label_text(self.value)) 20 | self.axis_value_label = QLabel(self.np_rots_to_label_text(self.axis_value)) 21 | 22 | self.x_slider = QSlider(QtCore.Qt.Horizontal) 23 | self.y_slider = QSlider(QtCore.Qt.Horizontal) 24 | self.z_slider = QSlider(QtCore.Qt.Horizontal) 25 | self.sliders = [self.x_slider, self.y_slider, self.z_slider] 26 | 27 | vert_layout = QVBoxLayout() 28 | horiz_layout = QHBoxLayout() 29 | for slider in self.sliders: 30 | slider.setRange(0, 100) 31 | slider.setValue(50) 32 | slider.valueChanged[int].connect(self.__slider_value_changed) 33 | min_label = QLabel(str(np.around(self.min_angle, decimals=2))) 34 | max_label = QLabel(str(np.around(self.max_angle, decimals=2))) 35 | 36 | horiz_layout.addWidget(min_label) 37 | horiz_layout.addWidget(slider) 38 | horiz_layout.addWidget(max_label) 39 | if vert_stack: 40 | if label_to_side: 41 | horiz_layout.addWidget(self.eul_value_label) 42 | horiz_layout.addWidget(self.axis_value_label) 43 | vert_layout.addLayout(horiz_layout) 44 | horiz_layout = QHBoxLayout() 45 | 46 | if not label_to_side: 47 | horiz_layout.addWidget(self.eul_value_label) 48 | horiz_layout.addWidget(self.axis_value_label) 49 | vert_layout.addLayout(horiz_layout) 50 | 51 | if vert_stack: 52 | self.setLayout(vert_layout) 53 | else: 54 | horiz_layout.addWidget(self.eul_value_label) 55 | horiz_layout.addWidget(self.axis_value_label) 56 | self.setLayout(horiz_layout) 57 | 58 | def reset(self): 59 | self.value = np.array([0.0, 0.0, 0.0]) 60 | self.axis_value = 
self.eul_to_axis(self.value) 61 | self.eul_value_label.setText(self.np_rots_to_label_text(self.value)) 62 | self.axis_value_label.setText(self.np_rots_to_label_text(self.axis_value)) 63 | self.x_slider.setValue(50) 64 | self.y_slider.setValue(50) 65 | self.z_slider.setValue(50) 66 | 67 | def setValue(self, value): 68 | self.value = value 69 | self.axis_value = self.eul_to_axis(self.value) 70 | self.eul_value_label.setText(self.np_rots_to_label_text(self.value)) 71 | self.axis_value_label.setText(self.np_rots_to_label_text(self.axis_value)) 72 | 73 | self.x_slider.setValue(self.__rot_to_slider_int(value[0])) 74 | self.y_slider.setValue(self.__rot_to_slider_int(value[1])) 75 | self.z_slider.setValue(self.__rot_to_slider_int(value[2])) 76 | 77 | def np_rots_to_label_text(self, np_array): 78 | return str(np.around(np_array, decimals = 2)) 79 | 80 | def __slider_value_changed(self, value): 81 | x_value = self.__slider_int_to_float(self.x_slider.value()) 82 | y_value = self.__slider_int_to_float(self.y_slider.value()) 83 | z_value = self.__slider_int_to_float(self.z_slider.value()) 84 | 85 | self.value = np.array([x_value, y_value, z_value]) 86 | self.axis_value = self.eul_to_axis(self.value) 87 | 88 | self.eul_value_label.setText(self.np_rots_to_label_text(self.value)) 89 | self.axis_value_label.setText(self.np_rots_to_label_text(self.axis_value)) 90 | self.value_changed.emit(self.axis_value) 91 | 92 | def force_emit(self): 93 | axis_value = self.eul_to_axis(self.value) 94 | self.value_changed.emit(self.axis_value) 95 | 96 | def eul_to_axis(self, euler_value): 97 | theta, vector = eulerangles.euler2angle_axis(euler_value[2], euler_value[1], euler_value[0]) 98 | return vector * theta 99 | 100 | def __rot_to_slider_int(self, rotation): 101 | total_range = self.max_angle - self.min_angle 102 | offset = rotation - self.min_angle 103 | return (offset / total_range) * 100.0 104 | 105 | def __slider_int_to_float(self, slider_val): 106 | total_range = self.max_angle - self.min_angle 107 | float_percentage = (total_range / 100.0) * slider_val # slider ranges between [0, 100] so slider_val is just a percentage 108 | return self.min_angle + float_percentage 109 | -------------------------------------------------------------------------------- /controls/shape_slider.py: -------------------------------------------------------------------------------- 1 | from PyQt5 import QtGui, QtCore 2 | from PyQt5.QtWidgets import QWidget 3 | from PyQt5.QtWidgets import QWidget, QLabel, QSlider, QVBoxLayout, QHBoxLayout 4 | 5 | import numpy as np 6 | 7 | class ShapeSlider(QWidget): 8 | value_changed = QtCore.pyqtSignal(float) 9 | 10 | def __init__(self, idx, std_dev, std_offset = 2): 11 | super(ShapeSlider, self).__init__() 12 | self.idx = idx 13 | self.std_offset = std_offset 14 | 15 | self.slider = QSlider(QtCore.Qt.Horizontal) 16 | self.slider.setRange(0, 100) 17 | self.slider.setValue(50) 18 | 19 | self.value = 0 20 | self.value_label = QLabel("--") 21 | self.min_label = QLabel("--") 22 | self.max_label = QLabel("--") 23 | 24 | self.__set_range_by_std_dev(std_dev) 25 | 26 | self.slider.valueChanged[int].connect(self.__slider_value_changed) 27 | 28 | horiz_layout = QHBoxLayout() 29 | horiz_layout.addWidget(self.min_label) 30 | horiz_layout.addWidget(self.slider) 31 | horiz_layout.addWidget(self.max_label) 32 | horiz_layout.addWidget(self.value_label) 33 | 34 | self.setLayout(horiz_layout) 35 | 36 | def reset(self): 37 | self.slider.setValue(50) 38 | 39 | def setValue(self, value): 40 | 
self.slider.setValue(self.__float_to_slider(value)) 41 | 42 | def __set_range_by_std_dev(self, std_dev): 43 | self.std_dev = std_dev 44 | 45 | min = -1.0 * std_dev * self.std_offset 46 | max = std_dev * self.std_offset 47 | 48 | min_rnd = np.around(min, decimals=2) 49 | max_rnd = np.around(max, decimals=2) 50 | 51 | self.min_label.setText(str(min_rnd)) 52 | self.max_label.setText(str(max_rnd)) 53 | self.__slider_value_changed(50) 54 | 55 | def __slider_value_changed(self, value): 56 | self.value = self.__slider_int_to_float(value) 57 | self.value_label.setText(str(np.around(self.value, decimals=2))) 58 | self.value_changed.emit(self.value) 59 | 60 | def __float_to_slider(self, float_val): 61 | total_range = self.std_dev * 2 * self.std_offset 62 | offset_val = float_val + (self.std_dev * self.std_offset) 63 | return (offset_val / total_range) * 100 64 | 65 | def __slider_int_to_float(self, slider_val): 66 | total_range = self.std_dev * 2 * self.std_offset 67 | float_percentage = (total_range / 100.0) * slider_val # slider ranges between [0, 100] so slider_val is just a percentage 68 | return (-1.0 * self.std_dev * self.std_offset) + float_percentage 69 | -------------------------------------------------------------------------------- /controls/trans_slider.py: -------------------------------------------------------------------------------- 1 | from PyQt5 import QtGui, QtCore 2 | from PyQt5.QtWidgets import QWidget, QLabel, QSlider, QVBoxLayout, QHBoxLayout 3 | 4 | import numpy as np 5 | from nibabel import eulerangles 6 | 7 | class TransSlider(QWidget): 8 | value_changed = QtCore.pyqtSignal(np.ndarray) 9 | 10 | def __init__(self, idx, min_trans, max_trans, initial_trans, vert_stack = False, label_to_side = True): 11 | super(TransSlider, self).__init__() 12 | self.idx = idx 13 | self.default = initial_trans 14 | 15 | self.max_trans = max_trans 16 | self.min_trans = min_trans 17 | self.value_label = QLabel() 18 | 19 | self.x_slider = QSlider(QtCore.Qt.Horizontal) 20 | self.y_slider = QSlider(QtCore.Qt.Horizontal) 21 | self.z_slider = QSlider(QtCore.Qt.Horizontal) 22 | self.sliders = [self.x_slider, self.y_slider, self.z_slider] 23 | 24 | vert_layout = QVBoxLayout() 25 | horiz_layout = QHBoxLayout() 26 | for slider in self.sliders: 27 | slider.setRange(0, 100) 28 | slider.setValue(50) 29 | slider.valueChanged[int].connect(self.__slider_value_changed) 30 | min_label = QLabel(str(np.around(self.min_trans, decimals=2))) 31 | max_label = QLabel(str(np.around(self.max_trans, decimals=2))) 32 | 33 | horiz_layout.addWidget(min_label) 34 | horiz_layout.addWidget(slider) 35 | horiz_layout.addWidget(max_label) 36 | 37 | if vert_stack: 38 | if label_to_side: 39 | horiz_layout.addWidget(self.value_label) 40 | vert_layout.addLayout(horiz_layout) 41 | horiz_layout = QHBoxLayout() 42 | 43 | if not label_to_side: 44 | horiz_layout.addWidget(self.value_label) 45 | vert_layout.addLayout(horiz_layout) 46 | 47 | if vert_stack: 48 | self.setLayout(vert_layout) 49 | else: 50 | horiz_layout.addWidget(self.value_label) 51 | self.setLayout(horiz_layout) 52 | 53 | self.setValue(initial_trans) 54 | 55 | def reset(self): 56 | self.setValue(self.default) 57 | 58 | def setValue(self, value): 59 | self.value = value 60 | self.value_label.setText(self.np_to_label_text(self.value)) 61 | 62 | self.x_slider.setValue(self.__np_to_slider_int(value[0])) 63 | self.y_slider.setValue(self.__np_to_slider_int(value[1])) 64 | self.z_slider.setValue(self.__np_to_slider_int(value[2])) 65 | 66 | def np_to_label_text(self, np_array): 67 | 
return str(np.around(np_array, decimals = 2)) 68 | 69 | def __slider_value_changed(self, value): 70 | x_value = self.__slider_int_to_np(self.x_slider.value()) 71 | y_value = self.__slider_int_to_np(self.y_slider.value()) 72 | z_value = self.__slider_int_to_np(self.z_slider.value()) 73 | 74 | self.value = np.array([x_value, y_value, z_value]) 75 | 76 | self.value_label.setText(self.np_to_label_text(self.value)) 77 | self.value_changed.emit(self.value) 78 | 79 | def __np_to_slider_int(self, rotation): 80 | total_range = self.max_trans - self.min_trans 81 | offset = rotation - self.min_trans 82 | return (offset / total_range) * 100.0 83 | 84 | def __slider_int_to_np(self, slider_val): 85 | total_range = self.max_trans - self.min_trans 86 | float_percentage = (total_range / 100.0) * slider_val # slider ranges between [0, 100] so slider_val is just a percentage 87 | return self.min_trans + float_percentage 88 | -------------------------------------------------------------------------------- /docs/smal_viewer.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/benjiebob/SMALViewer/ca54072ad5d7c78b2bf4ed19945ed2e052e7fbe9/docs/smal_viewer.gif -------------------------------------------------------------------------------- /p3d_renderer.py: -------------------------------------------------------------------------------- 1 | # Data structures and functions for rendering 2 | import torch 3 | import torch.nn.functional as F 4 | from scipy.io import loadmat 5 | import numpy as np 6 | 7 | from pytorch3d.structures import Meshes 8 | from pytorch3d.renderer import ( 9 | OpenGLPerspectiveCameras, look_at_view_transform, look_at_rotation, 10 | RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams, 11 | PointLights, SoftPhongShader, SoftSilhouetteShader 12 | ) 13 | from pytorch3d.io import load_objs_as_meshes 14 | 15 | class Renderer(torch.nn.Module): 16 | def __init__(self, image_size): 17 | super(Renderer, self).__init__() 18 | 19 | self.image_size = image_size 20 | self.dog_obj = load_objs_as_meshes(['data/dog_B/dog_B/dog_B_tpose.obj']) 21 | 22 | raster_settings = RasterizationSettings( 23 | image_size=self.image_size, 24 | blur_radius=0.0, 25 | faces_per_pixel=1, 26 | bin_size=None 27 | ) 28 | 29 | R, T = look_at_view_transform(2.7, 0, 0) 30 | cameras = OpenGLPerspectiveCameras(device=R.device, R=R, T=T) 31 | lights = PointLights(device=R.device, location=[[0.0, 1.0, 0.0]]) 32 | 33 | self.renderer = MeshRenderer( 34 | rasterizer=MeshRasterizer( 35 | cameras=cameras, 36 | raster_settings=raster_settings 37 | ), 38 | shader=SoftPhongShader( 39 | device=R.device, 40 | cameras=cameras, 41 | lights=lights 42 | ) 43 | ) 44 | 45 | def forward(self, vertices, faces): 46 | mesh = Meshes(verts=vertices, faces=faces, textures=self.dog_obj.textures) 47 | images = self.renderer(mesh) 48 | return images -------------------------------------------------------------------------------- /pyqt_viewer.py: -------------------------------------------------------------------------------- 1 | # Imports 2 | import numpy as np 3 | import torch 4 | 5 | from PyQt5 import QtGui, QtCore 6 | from PyQt5.QtWidgets import QMainWindow, QWidget, QFrame, QLabel, QPushButton, QHBoxLayout, QVBoxLayout, QScrollArea, QGridLayout, QCheckBox 7 | from PyQt5.QtGui import QImage, QPixmap 8 | 9 | from controls.shape_slider import ShapeSlider 10 | from controls.pose_slider import PoseSlider 11 | from controls.trans_slider import TransSlider 12 | 13 | from 
smal_model.smal_torch import SMAL 14 | from pyrenderer import Renderer 15 | # from p3d_renderer import Renderer 16 | from functools import partial 17 | import time 18 | 19 | import os 20 | import collections 21 | 22 | import scipy.misc 23 | import datetime 24 | 25 | import pickle as pkl 26 | 27 | NUM_POSE_PARAMS = 34 28 | NUM_SHAPE_PARAMS = 41 29 | RENDER_SIZE = 256 30 | DISPLAY_SIZE = 512 31 | 32 | class MainWindow(QMainWindow): 33 | def __init__(self, parent=None): 34 | QMainWindow.__init__(self, parent) 35 | 36 | self.smal_params = { 37 | 'betas' : torch.zeros(1, NUM_SHAPE_PARAMS), 38 | 'joint_rotations' : torch.zeros(1, NUM_POSE_PARAMS, 3), 39 | 'global_rotation' : torch.zeros(1, 1, 3), 40 | 'trans' : torch.zeros(1, 1, 3), 41 | } 42 | 43 | self.model_renderer = Renderer(RENDER_SIZE) 44 | self.smal_model = SMAL('data/my_smpl_00781_4_all.pkl') 45 | 46 | with open('data/my_smpl_data_00781_4_all.pkl', 'rb') as f: 47 | u = pkl._Unpickler(f) 48 | u.encoding = 'latin1' 49 | smal_data = u.load() 50 | 51 | self.toy_betas = smal_data['toys_betas'] 52 | self.setup_ui() 53 | 54 | def get_layout_region(self, control_set): 55 | layout_region = QVBoxLayout() 56 | 57 | scrollArea = QScrollArea() 58 | scrollArea.setWidgetResizable(True) 59 | scrollAreaWidgetContents = QWidget(scrollArea) 60 | scrollArea.setWidget(scrollAreaWidgetContents) 61 | scrollArea.setMinimumWidth(750) 62 | 63 | grid_layout = QGridLayout() 64 | 65 | for idx, (label, com_slider) in enumerate(control_set): 66 | grid_layout.addWidget(label, idx, 0) 67 | grid_layout.addWidget(com_slider, idx, 1) 68 | 69 | scrollAreaWidgetContents.setLayout(grid_layout) 70 | 71 | layout_region.addWidget(scrollArea) 72 | return layout_region 73 | 74 | def setup_ui(self): 75 | self.shape_controls = [] 76 | self.pose_controls = [] 77 | self.update_poly = True 78 | self.toy_pbs = [] 79 | 80 | def ctrl_layout_add_separator(): 81 | line = QFrame() 82 | line.setFrameShape(QFrame.HLine) 83 | line.setFrameShadow(QFrame.Sunken) 84 | ctrl_layout.addWidget(line) 85 | 86 | # SHAPE REGION 87 | std_devs = np.std(self.toy_betas, axis=1) 88 | for idx, toy_std in enumerate(std_devs): 89 | label = QLabel("S{0}".format(idx)) 90 | sliders = ShapeSlider(idx, toy_std) 91 | sliders.value_changed.connect(self.update_model) 92 | self.shape_controls.append((label, sliders)) 93 | self.toy_pbs.append(QPushButton("T{0}".format(idx))) 94 | 95 | reset_shape_pb = QPushButton('Reset Shape') 96 | reset_shape_pb.clicked.connect(self.reset_shape) 97 | 98 | self.toy_frame = QFrame() 99 | self.toy_layout = QGridLayout() 100 | for idx, pb in enumerate(self.toy_pbs): 101 | row = idx // 5 102 | col = idx - row * 5 103 | pb.clicked.connect(partial(self.make_toy_shape, idx)) 104 | self.toy_layout.addWidget(pb, row, col) 105 | 106 | self.toy_frame.setLayout(self.toy_layout) 107 | self.toy_frame.setHidden(True) 108 | 109 | show_toys_cb = QCheckBox('Show Toys', self) 110 | show_toys_cb.stateChanged.connect(partial(self.toggle_control, self.toy_frame)) 111 | 112 | shape_layout = self.get_layout_region(self.shape_controls) 113 | shape_layout.addWidget(reset_shape_pb) 114 | shape_layout.addWidget(show_toys_cb) 115 | 116 | ctrl_layout = QVBoxLayout() 117 | ctrl_layout.addLayout(shape_layout) 118 | ctrl_layout.addWidget(self.toy_frame) 119 | ctrl_layout_add_separator() 120 | 121 | # POSE REGION 122 | model_joints = NUM_POSE_PARAMS 123 | for idx in range(model_joints): 124 | if idx == 0: 125 | label = QLabel("Root Pose (P{0})".format(idx)) 126 | slider = PoseSlider(idx, 2 * np.pi, vert_stack = True) 127 
| else: 128 | label = QLabel("P{0}".format(idx)) 129 | slider = PoseSlider(idx, np.pi) 130 | 131 | slider.value_changed.connect(self.update_model) 132 | self.pose_controls.append((label, slider)) 133 | 134 | reset_pose_pb = QPushButton('Reset Pose') 135 | reset_pose_pb.clicked.connect(self.reset_pose) 136 | 137 | root_pose_dict = collections.OrderedDict() 138 | # root_pose_dict[ "Face Left" ] = np.array([0, 0, np.pi]) 139 | # root_pose_dict[ "Diag Left" ] = np.array([0, 0, 3 * np.pi / 2]) 140 | # root_pose_dict[ "Head On" ] = np.array([0, 0, np.pi / 2]) 141 | # root_pose_dict[ "Diag Right" ] = np.array([0, 0, np.pi / 4]) 142 | # root_pose_dict[ "Face Right" ] = np.array([0, 0, 0]) 143 | # root_pose_dict[ "Straight Up" ] = np.array([np.pi / 2, 0, np.pi / 2]) 144 | # root_pose_dict[ "Straight Down" ] = np.array([-np.pi / 2, 0, np.pi / 2]) 145 | 146 | root_pose_dict[ "Face Left" ] = np.array([-np.pi / 2, 0, -np.pi]) 147 | root_pose_dict[ "Diag Left" ] = np.array([-np.pi / 2, 0, -3 * np.pi / 4]) 148 | root_pose_dict[ "Head On" ] = np.array([-np.pi / 2, 0, -np.pi / 2]) 149 | root_pose_dict[ "Diag Right" ] = np.array([-np.pi / 2, 0, -np.pi / 4]) 150 | root_pose_dict[ "Face Right" ] = np.array([-np.pi / 2, 0, 0]) 151 | 152 | root_pose_dict[ "Straight Up" ] = np.array([np.pi, np.pi, -np.pi / 2]) 153 | root_pose_dict[ "Straight Down" ] = np.array([np.pi, np.pi, np.pi / 2]) 154 | 155 | root_pose_layout = QGridLayout() 156 | idx = 0 157 | for key, value in root_pose_dict.items(): 158 | head_on_pb = QPushButton(key) 159 | head_on_pb.clicked.connect(partial(self.set_known_pose, value)) 160 | root_pose_layout.addWidget(head_on_pb, 0, idx) 161 | idx = idx + 1 162 | 163 | pose_layout = QGridLayout() 164 | root_label, root_pose_sliders = self.pose_controls[0] 165 | 166 | pose_layout.addWidget(root_label, 0, 0) 167 | pose_layout.addWidget(root_pose_sliders, 1, 0) 168 | pose_layout.addLayout(self.get_layout_region(self.pose_controls[1:]), 2, 0) 169 | pose_layout.addWidget(reset_pose_pb) 170 | 171 | ctrl_layout.addLayout(pose_layout) 172 | ctrl_layout.addLayout(root_pose_layout) 173 | ctrl_layout_add_separator() 174 | 175 | # TRANSLATION REGION 176 | trans_label = QLabel("Root Translation".format(idx)) 177 | self.trans_sliders = TransSlider(idx, -5.0, 5.0, np.array([0.0, 0.0, 0.0]), vert_stack = True) 178 | self.trans_sliders.value_changed.connect(self.update_model) 179 | 180 | reset_trans_pb = QPushButton('Reset Translation') 181 | reset_trans_pb.clicked.connect(self.reset_trans) 182 | 183 | # Add the translation slider 184 | trans_layout = QGridLayout() 185 | trans_layout.addWidget(trans_label, 0, 0) 186 | trans_layout.addWidget(self.trans_sliders, 1, 0) 187 | trans_layout.addWidget(reset_trans_pb, 2, 0) 188 | 189 | self.trans_frame = QFrame() 190 | self.trans_frame.setLayout(trans_layout) 191 | self.trans_frame.setHidden(True) 192 | 193 | show_trans_cb = QCheckBox('Show Translation Parameters', self) 194 | show_trans_cb.stateChanged.connect(partial(self.toggle_control, self.trans_frame)) 195 | 196 | ctrl_layout.addWidget(show_trans_cb) 197 | ctrl_layout.addWidget(self.trans_frame) 198 | ctrl_layout_add_separator() 199 | 200 | # ACTION BUTTONS 201 | reset_pb = QPushButton('&Reset') 202 | reset_pb.clicked.connect(self.reset_model) 203 | 204 | export_image_pb = QPushButton('&Export Image') 205 | export_image_pb.clicked.connect(self.export_image) 206 | 207 | misc_pbs_layout = QGridLayout() 208 | misc_pbs_layout.addWidget(reset_pb, 0, 0) 209 | misc_pbs_layout.addWidget(export_image_pb, 0, 1) 210 | 
ctrl_layout.addLayout(misc_pbs_layout) 211 | ctrl_layout_add_separator() 212 | 213 | view_layout = QVBoxLayout() 214 | self.render_img_label = QLabel() 215 | view_layout.addWidget(self.render_img_label) 216 | 217 | main_layout = QHBoxLayout() 218 | main_layout.addLayout(ctrl_layout) 219 | main_layout.addLayout(view_layout) 220 | 221 | main_widget = QWidget() 222 | main_widget.setLayout(main_layout) 223 | self.setCentralWidget(main_widget) 224 | 225 | # WINDOW 226 | self.window_title_stem = 'SMAL Model Viewer' 227 | self.setWindowTitle(self.window_title_stem) 228 | 229 | self.statusBar().showMessage('Ready...') 230 | self.update_render() 231 | 232 | self.showMaximized() 233 | 234 | def update_model(self, value): 235 | sender = self.sender() 236 | 237 | if type(sender) is ShapeSlider: 238 | self.smal_params['betas'][0, sender.idx] = value 239 | elif type(sender) is PoseSlider: 240 | if sender.idx == 0: 241 | self.smal_params['global_rotation'][0, sender.idx] = torch.FloatTensor(value) 242 | else: 243 | self.smal_params['joint_rotations'][0, sender.idx - 1] = torch.FloatTensor(value) 244 | elif type(sender) is TransSlider: 245 | self.smal_params['trans'][0] = torch.FloatTensor(value) 246 | 247 | if self.update_poly: 248 | self.update_render() 249 | 250 | def update_render(self): 251 | with torch.no_grad(): 252 | start = time.time() 253 | verts, joints, Rs, v_shaped = self.smal_model( 254 | self.smal_params['betas'], 255 | torch.cat([self.smal_params['global_rotation'], self.smal_params['joint_rotations']], dim = 1)) 256 | 257 | # normalize by center of mass 258 | verts = verts - torch.mean(verts, dim = 1, keepdim=True) 259 | 260 | # add on the translation 261 | verts = verts + self.smal_params['trans'] 262 | 263 | end = time.time() 264 | ellapsed = end - start 265 | print (f"SMAL Time: {ellapsed }") 266 | 267 | start = time.time() 268 | rendered_images = self.model_renderer(verts, self.smal_model.faces.unsqueeze(0)) 269 | end = time.time() 270 | ellapsed = end - start 271 | print (f"Renderer Time: {ellapsed }") 272 | 273 | self.image_np = rendered_images[0, :, :, :3] 274 | self.render_img_label.setPixmap(self.image_to_pixmap(self.image_np, DISPLAY_SIZE)) 275 | self.render_img_label.update() 276 | 277 | def reset_shape(self): 278 | # Reset sliders to zero 279 | self.update_poly = False 280 | for label, com_slider in self.shape_controls: 281 | com_slider.reset() 282 | self.update_poly = True 283 | self.update_render() 284 | 285 | def reset_pose(self): 286 | self.update_poly = False 287 | for label, com_slider in self.pose_controls: 288 | com_slider.reset() 289 | self.update_poly = True 290 | self.update_render() 291 | 292 | def make_toy_shape(self, toy_id): 293 | self.update_poly = False 294 | toy_betas = self.toy_betas[toy_id] 295 | for idx, val in enumerate(toy_betas): 296 | label, shape_slider = self.shape_controls[idx] 297 | shape_slider.setValue(val) 298 | 299 | self.update_poly = True 300 | self.statusBar().showMessage(str(toy_betas)) 301 | self.smal_params['betas'][0] = torch.from_numpy(toy_betas) 302 | self.update_render() 303 | 304 | def toggle_control(self, layout): 305 | sender = self.sender() 306 | layout.setHidden(not sender.isChecked()) 307 | 308 | def set_known_pose(self, pose): 309 | label, root_pose_slider = self.pose_controls[0] 310 | root_pose_slider.setValue(pose) 311 | root_pose_slider.force_emit() 312 | 313 | def reset_trans(self): 314 | self.trans_sliders.reset() 315 | 316 | def reset_model(self): 317 | self.reset_shape() 318 | self.reset_pose() 319 | self.reset_trans() 
320 | 321 | def image_to_pixmap(self, img, img_size): 322 | im = np.require(img * 255.0, dtype='uint8') 323 | qim = QImage(im.data, im.shape[1], im.shape[0], im.strides[0], QImage.Format_RGB888).copy() 324 | pixmap = QPixmap(qim) 325 | return pixmap.scaled(img_size, img_size, QtCore.Qt.KeepAspectRatio) 326 | 327 | def export_image(self): 328 | out_dir = "output" 329 | if not os.path.exists(out_dir): 330 | os.mkdir(out_dir) 331 | 332 | time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') 333 | scipy.misc.imsave(os.path.join(out_dir, "{0}.png".format(time_str)), self.image_np) -------------------------------------------------------------------------------- /pyrenderer.py: -------------------------------------------------------------------------------- 1 | import os 2 | #https://pyrender.readthedocs.io/en/latest/examples/offscreen.html 3 | # fix for windows from https://github.com/mmatl/pyrender/issues/117 4 | # edit C:\Users\bjb10042\.conda\envs\bjb_env\Lib\site-packages\pyrender 5 | # os.environ['PYOPENGL_PLATFORM'] = 'osmesa' 6 | import torch 7 | from torchvision.utils import make_grid 8 | import numpy as np 9 | import pyrender 10 | import trimesh 11 | 12 | class Renderer: 13 | """ 14 | Renderer used for visualizing the SMPL model 15 | Code adapted from https://github.com/vchoutas/smplify-x 16 | """ 17 | def __init__(self, img_res=224): 18 | self.renderer = pyrender.OffscreenRenderer(viewport_width=img_res, 19 | viewport_height=img_res, 20 | point_size=1.0) 21 | self.focal_length = 5000 22 | self.camera_center = [img_res // 2, img_res // 2] 23 | 24 | def __call__(self, vertices, faces): 25 | material = pyrender.MetallicRoughnessMaterial( 26 | metallicFactor=0.2, 27 | alphaMode='OPAQUE', 28 | baseColorFactor=(0.8, 0.3, 0.3, 1.0)) 29 | 30 | camera_translation = np.array([0.0, 0.0, 50.0]) 31 | 32 | mesh = trimesh.Trimesh(vertices[0], faces[0], process=False) 33 | # rot = trimesh.transformations.rotation_matrix( 34 | # np.radians(180), [1, 0, 0]) 35 | # mesh.apply_transform(rot) 36 | mesh = pyrender.Mesh.from_trimesh(mesh, material=material) 37 | 38 | scene = pyrender.Scene(ambient_light=(0.5, 0.5, 0.5)) 39 | scene.add(mesh, 'mesh') 40 | 41 | camera_pose = np.eye(4) 42 | camera_pose[:3, 3] = camera_translation 43 | camera = pyrender.IntrinsicsCamera(fx=self.focal_length, fy=self.focal_length, 44 | cx=self.camera_center[0], cy=self.camera_center[1]) 45 | scene.add(camera, pose=camera_pose) 46 | 47 | 48 | light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=1) 49 | light_pose = np.eye(4) 50 | 51 | light_pose[:3, 3] = np.array([0, -1, 1]) 52 | scene.add(light, pose=light_pose) 53 | 54 | light_pose[:3, 3] = np.array([0, 1, 1]) 55 | scene.add(light, pose=light_pose) 56 | 57 | light_pose[:3, 3] = np.array([1, 1, 2]) 58 | scene.add(light, pose=light_pose) 59 | 60 | color, rend_depth = self.renderer.render(scene, flags=pyrender.RenderFlags.RGBA) 61 | color = color.astype(np.float32) / 255.0 62 | return torch.from_numpy(color).float().unsqueeze(0) -------------------------------------------------------------------------------- /smal_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/benjiebob/SMALViewer/ca54072ad5d7c78b2bf4ed19945ed2e052e7fbe9/smal_model/__init__.py -------------------------------------------------------------------------------- /smal_model/batch_lbs.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 
| from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import numpy as np 7 | 8 | def batch_skew(vec, batch_size=None, opts=None): 9 | """ 10 | vec is N x 3, batch_size is int 11 | returns N x 3 x 3. Skew_sym version of each matrix. 12 | """ 13 | if batch_size is None: 14 | batch_size = vec.shape.as_list()[0] 15 | col_inds = torch.LongTensor([1, 2, 3, 5, 6, 7]) 16 | indices = torch.reshape(torch.reshape(torch.arange(0, batch_size) * 9, [-1, 1]) + col_inds, [-1, 1]) 17 | updates = torch.reshape( 18 | torch.stack( 19 | [ 20 | -vec[:, 2], vec[:, 1], vec[:, 2], -vec[:, 0], -vec[:, 1], 21 | vec[:, 0] 22 | ], 23 | dim=1), [-1]) 24 | out_shape = [batch_size * 9] 25 | res = torch.Tensor(np.zeros(out_shape[0])).to(device=vec.device) 26 | res[np.array(indices.flatten())] = updates 27 | res = torch.reshape(res, [batch_size, 3, 3]) 28 | 29 | return res 30 | 31 | 32 | 33 | def batch_rodrigues(theta, opts=None): 34 | """ 35 | Theta is Nx3 36 | """ 37 | batch_size = theta.shape[0] 38 | 39 | angle = (torch.norm(theta + 1e-8, p=2, dim=1)).unsqueeze(-1) 40 | r = (torch.div(theta, angle)).unsqueeze(-1) 41 | 42 | angle = angle.unsqueeze(-1) 43 | cos = torch.cos(angle) 44 | sin = torch.sin(angle) 45 | 46 | outer = torch.matmul(r, r.transpose(1,2)) 47 | 48 | eyes = torch.eye(3).unsqueeze(0).repeat([batch_size, 1, 1]).to(device=theta.device) 49 | H = batch_skew(r, batch_size=batch_size, opts=opts) 50 | R = cos * eyes + (1 - cos) * outer + sin * H 51 | 52 | return R 53 | 54 | def batch_lrotmin(theta): 55 | """ 56 | Output of this is used to compute joint-to-pose blend shape mapping. 57 | Equation 9 in SMPL paper. 58 | Args: 59 | pose: `Tensor`, N x 72 vector holding the axis-angle rep of K joints. 60 | This includes the global rotation so K=24 61 | Returns 62 | diff_vec : `Tensor`: N x 207 rotation matrix of 23=(K-1) joints with identity subtracted., 63 | """ 64 | # Ignore global rotation 65 | theta = theta[:,3:] 66 | 67 | Rs = batch_rodrigues(torch.reshape(theta, [-1,3])) 68 | lrotmin = torch.reshape(Rs - torch.eye(3), [-1, 207]) 69 | 70 | return lrotmin 71 | 72 | def batch_global_rigid_transformation(Rs, Js, parent, rotate_base = False, opts=None): 73 | """ 74 | Computes absolute joint locations given pose. 75 | rotate_base: if True, rotates the global rotation by 90 deg in x axis. 76 | if False, this is the original SMPL coordinate. 77 | Args: 78 | Rs: N x 24 x 3 x 3 rotation vector of K joints 79 | Js: N x 24 x 3, joint locations before posing 80 | parent: 24 holding the parent id for each index 81 | Returns 82 | new_J : `Tensor`: N x 24 x 3 location of absolute joints 83 | A : `Tensor`: N x 24 4 x 4 relative joint transformations for LBS. 
84 | """ 85 | if rotate_base: 86 | print('Flipping the SMPL coordinate frame!!!!') 87 | rot_x = torch.Tensor([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) 88 | rot_x = torch.reshape(torch.repeat(rot_x, [N, 1]), [N, 3, 3]) # In tf it was tile 89 | root_rotation = torch.matmul(Rs[:, 0, :, :], rot_x) 90 | else: 91 | root_rotation = Rs[:, 0, :, :] 92 | 93 | # Now Js is N x 24 x 3 x 1 94 | Js = Js.unsqueeze(-1) 95 | N = Rs.shape[0] 96 | 97 | def make_A(R, t): 98 | # Rs is N x 3 x 3, ts is N x 3 x 1 99 | R_homo = torch.nn.functional.pad(R, (0,0,0,1,0,0)) 100 | t_homo = torch.cat([t, torch.ones([N, 1, 1]).to(device=Rs.device)], 1) 101 | return torch.cat([R_homo, t_homo], 2) 102 | 103 | A0 = make_A(root_rotation, Js[:, 0]) 104 | results = [A0] 105 | for i in range(1, parent.shape[0]): 106 | j_here = Js[:, i] - Js[:, parent[i]] 107 | A_here = make_A(Rs[:, i], j_here) 108 | res_here = torch.matmul( 109 | results[parent[i]], A_here) 110 | results.append(res_here) 111 | 112 | # 10 x 24 x 4 x 4 113 | results = torch.stack(results, dim=1) 114 | 115 | new_J = results[:, :, :3, 3] 116 | 117 | # --- Compute relative A: Skinning is based on 118 | # how much the bone moved (not the final location of the bone) 119 | # but (final_bone - init_bone) 120 | # --- 121 | Js_w0 = torch.cat([Js, torch.zeros([N, 35, 1, 1]).to(device=Rs.device)], 2) 122 | init_bone = torch.matmul(results, Js_w0) 123 | # Append empty 4 x 3: 124 | init_bone = torch.nn.functional.pad(init_bone, (3,0,0,0,0,0,0,0)) 125 | A = results - init_bone 126 | 127 | return new_J, A -------------------------------------------------------------------------------- /smal_model/smal_basics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle as pkl 3 | import numpy as np 4 | from smal_model.smpl_webuser.serialization import load_model 5 | 6 | def align_smal_template_to_symmetry_axis(v, sym_file): 7 | # These are the indexes of the points that are on the symmetry axis 8 | I = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 37, 55, 119, 120, 163, 209, 210, 211, 213, 216, 227, 326, 395, 452, 578, 910, 959, 964, 975, 976, 977, 1172, 1175, 1176, 1178, 1194, 1243, 1739, 1796, 1797, 1798, 1799, 1800, 1801, 1802, 1803, 1804, 1805, 1806, 1807, 1808, 1809, 1810, 1811, 1812, 1813, 1814, 1815, 1816, 1817, 1818, 1819, 1820, 1821, 1822, 1823, 1824, 1825, 1826, 1827, 1828, 1829, 1830, 1831, 1832, 1833, 1834, 1835, 1836, 1837, 1838, 1839, 1840, 1842, 1843, 1844, 1845, 1846, 1847, 1848, 1849, 1850, 1851, 1852, 1853, 1854, 1855, 1856, 1857, 1858, 1859, 1860, 1861, 1862, 1863, 1870, 1919, 1960, 1961, 1965, 1967, 2003] 9 | 10 | v = v - np.mean(v) 11 | y = np.mean(v[I,1]) 12 | v[:,1] = v[:,1] - y 13 | v[I,1] = 0 14 | 15 | sym_path = sym_file 16 | # symIdx = pkl.load(open(sym_path)) 17 | with open(sym_path, 'rb') as f: 18 | u = pkl._Unpickler(f) 19 | u.encoding = 'latin1' 20 | symIdx = u.load() 21 | 22 | 23 | left = v[:, 1] < 0 24 | right = v[:, 1] > 0 25 | center = v[:, 1] == 0 26 | v[left[symIdx]] = np.array([1,-1,1])*v[left] 27 | 28 | left_inds = np.where(left)[0] 29 | right_inds = np.where(right)[0] 30 | center_inds = np.where(center)[0] 31 | 32 | try: 33 | assert(len(left_inds) == len(right_inds)) 34 | except: 35 | import pdb; pdb.set_trace() 36 | 37 | return v, left_inds, right_inds, center_inds 38 | 39 | def load_smal_model(): 40 | model = load_model(config.SMAL_FILE) 41 | v = align_smal_template_to_symmetry_axis(model.r.copy()) 42 | 43 | 44 | 
return v, model.f 45 | 46 | def get_smal_template(model_name, data_name, shape_family_id=-1): 47 | model = load_model(model_name) 48 | nBetas = len(model.betas.r) 49 | 50 | with open(data_name, 'rb') as f: 51 | u = pkl._Unpickler(f) 52 | u.encoding = 'latin1' 53 | data = u.load() 54 | 55 | # Select average zebra/horse 56 | # betas = data['cluster_means'][2][:nBetas] 57 | betas = data['cluster_means'][shape_family_id][:nBetas] 58 | model.betas[:] = betas 59 | 60 | if shape_family_id == -1: 61 | model.betas[:] = np.zeros_like(betas) 62 | 63 | v = model.r.copy() 64 | return v 65 | 66 | 67 | -------------------------------------------------------------------------------- /smal_model/smal_torch.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | PyTorch implementation of the SMAL/SMPL model 4 | 5 | """ 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import numpy as np 11 | import torch 12 | from torch.autograd import Variable 13 | import pickle as pkl 14 | from .batch_lbs import batch_rodrigues, batch_global_rigid_transformation 15 | from .smal_basics import align_smal_template_to_symmetry_axis, get_smal_template 16 | import torch.nn as nn 17 | 18 | # There are chumpy variables so convert them to numpy. 19 | def undo_chumpy(x): 20 | return x if isinstance(x, np.ndarray) else x.r 21 | 22 | class SMAL(nn.Module): 23 | def __init__(self, pkl_path, shape_family_id=-1, dtype=torch.float): 24 | super(SMAL, self).__init__() 25 | 26 | # -- Load SMPL params -- 27 | # with open(pkl_path, 'r') as f: 28 | # dd = pkl.load(f) 29 | 30 | with open(pkl_path, 'rb') as f: 31 | u = pkl._Unpickler(f) 32 | u.encoding = 'latin1' 33 | dd = u.load() 34 | 35 | self.f = dd['f'] 36 | 37 | self.faces = torch.from_numpy(self.f.astype(int)) 38 | 39 | v_template = get_smal_template(model_name='data/my_smpl_00781_4_all.pkl', data_name='data/my_smpl_data_00781_4_all.pkl', shape_family_id=shape_family_id) 40 | v, self.left_inds, self.right_inds, self.center_inds = align_smal_template_to_symmetry_axis(v_template, sym_file='data/symIdx.pkl') 41 | 42 | # Mean template vertices 43 | self.v_template = Variable( 44 | torch.Tensor(v), 45 | requires_grad=False) 46 | # Size of mesh [Number of vertices, 3] 47 | self.size = [self.v_template.shape[0], 3] 48 | self.num_betas = dd['shapedirs'].shape[-1] 49 | # Shape blend shape basis 50 | 51 | shapedir = np.reshape( 52 | undo_chumpy(dd['shapedirs']), [-1, self.num_betas]).T 53 | self.shapedirs = Variable( 54 | torch.Tensor(shapedir), requires_grad=False) 55 | 56 | # Regressor for joint locations given shape 57 | self.J_regressor = Variable( 58 | torch.Tensor(dd['J_regressor'].T.todense()), 59 | requires_grad=False) 60 | 61 | # Pose blend shape basis 62 | num_pose_basis = dd['posedirs'].shape[-1] 63 | 64 | posedirs = np.reshape( 65 | undo_chumpy(dd['posedirs']), [-1, num_pose_basis]).T 66 | self.posedirs = Variable( 67 | torch.Tensor(posedirs), requires_grad=False) 68 | 69 | # indices of parents for each joints 70 | self.parents = dd['kintree_table'][0].astype(np.int32) 71 | 72 | # LBS weights 73 | self.weights = Variable( 74 | torch.Tensor(undo_chumpy(dd['weights'])), 75 | requires_grad=False) 76 | 77 | def __call__(self, beta, theta, trans=None, del_v=None, betas_logscale=None, get_skin=True): 78 | 79 | if True: 80 | nBetas = beta.shape[1] 81 | else: 82 | nBetas = 0 83 | 84 | 85 | # v_template = self.v_template.unsqueeze(0).expand(beta.shape[0], 3889, 3) 86 | 
v_template = self.v_template 87 | # 1. Add shape blend shapes 88 | 89 | if nBetas > 0: 90 | if del_v is None: 91 | v_shaped = v_template + torch.reshape(torch.matmul(beta, self.shapedirs[:nBetas,:]), [-1, self.size[0], self.size[1]]) 92 | else: 93 | v_shaped = v_template + del_v + torch.reshape(torch.matmul(beta, self.shapedirs[:nBetas,:]), [-1, self.size[0], self.size[1]]) 94 | else: 95 | if del_v is None: 96 | v_shaped = v_template.unsqueeze(0) 97 | else: 98 | v_shaped = v_template + del_v 99 | 100 | # 2. Infer shape-dependent joint locations. 101 | Jx = torch.matmul(v_shaped[:, :, 0], self.J_regressor) 102 | Jy = torch.matmul(v_shaped[:, :, 1], self.J_regressor) 103 | Jz = torch.matmul(v_shaped[:, :, 2], self.J_regressor) 104 | J = torch.stack([Jx, Jy, Jz], dim=2) 105 | 106 | # 3. Add pose blend shapes 107 | # N x 24 x 3 x 3 108 | if len(theta.shape) == 4: 109 | Rs = theta 110 | else: 111 | Rs = torch.reshape( batch_rodrigues(torch.reshape(theta, [-1, 3])), [-1, 35, 3, 3]) 112 | 113 | # Ignore global rotation. 114 | pose_feature = torch.reshape(Rs[:, 1:, :, :] - torch.eye(3).to(beta.device), [-1, 306]) 115 | 116 | v_posed = torch.reshape( 117 | torch.matmul(pose_feature, self.posedirs), 118 | [-1, self.size[0], self.size[1]]) + v_shaped 119 | 120 | #4. Get the global joint location 121 | self.J_transformed, A = batch_global_rigid_transformation( 122 | Rs, J, self.parents) 123 | 124 | 125 | # 5. Do skinning: 126 | num_batch = theta.shape[0] 127 | 128 | weights_t = self.weights.repeat([num_batch, 1]) 129 | W = torch.reshape(weights_t, [num_batch, -1, 35]) 130 | 131 | 132 | T = torch.reshape( 133 | torch.matmul(W, torch.reshape(A, [num_batch, 35, 16])), 134 | [num_batch, -1, 4, 4]) 135 | v_posed_homo = torch.cat( 136 | [v_posed, torch.ones([num_batch, v_posed.shape[1], 1]).to(device=beta.device)], 2) 137 | v_homo = torch.matmul(T, v_posed_homo.unsqueeze(-1)) 138 | 139 | verts = v_homo[:, :, :3, 0] 140 | 141 | if trans is None: 142 | trans = torch.zeros((num_batch,3)).to(device=beta.device) 143 | 144 | verts = verts + trans[:,None,:] 145 | 146 | # Get joints: 147 | joint_x = torch.matmul(verts[:, :, 0], self.J_regressor) 148 | joint_y = torch.matmul(verts[:, :, 1], self.J_regressor) 149 | joint_z = torch.matmul(verts[:, :, 2], self.J_regressor) 150 | joints = torch.stack([joint_x, joint_y, joint_z], dim=2) 151 | 152 | joints = torch.cat([ 153 | joints, 154 | verts[:, None, 1863], # end_of_nose 155 | verts[:, None, 26], # chin 156 | verts[:, None, 2124], # right ear tip 157 | verts[:, None, 150], # left ear tip 158 | verts[:, None, 3055], # left eye 159 | verts[:, None, 1097], # right eye 160 | ], dim = 1) 161 | 162 | if get_skin: 163 | return verts, joints, Rs, v_shaped 164 | else: 165 | return joints 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | -------------------------------------------------------------------------------- /smal_model/template_w_tex_uv.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'None' 2 | # Material Count: 1 3 | 4 | newmtl my_mat 5 | Ns 96.078431 6 | Ka 0.000000 0.000000 0.000000 7 | Kd 0.640000 0.640000 0.640000 8 | Ks 0.500000 0.500000 0.500000 9 | Ni 1.000000 10 | d 1.000000 11 | illum 2 12 | map_Kd texture_debug.jpg 13 | -------------------------------------------------------------------------------- /smal_viewer.py: -------------------------------------------------------------------------------- 1 | from PyQt5.QtWidgets import QApplication 2 | import pyqt_viewer 3 | 4 | def 
main(): 5 | qapp = QApplication([]) 6 | main_window = pyqt_viewer.MainWindow() 7 | 8 | main_window.setWindowTitle("SMAL Model Viewer") 9 | main_window.show() 10 | qapp.exec_() 11 | 12 | if __name__ == '__main__': 13 | main() 14 | --------------------------------------------------------------------------------
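For readers who want to drive the model without the Qt GUI, the following is a minimal headless sketch of the same pipeline that pyqt_viewer.py runs. It is my own illustrative example, not a file in the repository; it assumes the data files from the installation steps are in place under data/, and the neutral shape/pose values and 256-pixel render size are arbitrary choices:

```python
# Hypothetical standalone example based on SMAL(...) and Renderer(...) as defined above.
import torch

from smal_model.smal_torch import SMAL
from pyrenderer import Renderer   # or: from p3d_renderer import Renderer

smal = SMAL('data/my_smpl_00781_4_all.pkl')  # SMAL model file copied from SMALST
renderer = Renderer(256)                     # offscreen render resolution

betas = torch.zeros(1, 41)     # 41 shape coefficients (NUM_SHAPE_PARAMS)
theta = torch.zeros(1, 35, 3)  # global rotation + 34 joint rotations, in axis-angle form

with torch.no_grad():
    verts, joints, Rs, v_shaped = smal(betas, theta)
    verts = verts - verts.mean(dim=1, keepdim=True)   # centre the mesh, as the viewer does
    image = renderer(verts, smal.faces.unsqueeze(0))  # (1, H, W, 4) float RGBA in [0, 1]

print(image.shape)
```

The same verts/faces pair can be passed to the PyTorch3D Renderer in p3d_renderer.py when gradients through the rendering step are required (that path additionally needs the Dog B texture files described in the README).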