├── data
│   └── mean.npy
├── example
│   ├── screenshot1.png
│   ├── screenshot2.png
│   ├── screenshot3.png
│   ├── screenshot4.png
│   ├── screenshot5.png
│   └── screenshot6.png
├── main.py
├── README.md
├── mesh_loader.py
└── viewer.py

/data/mean.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yeongjoonJu/Mesh-Viewer-using-pytorch3d/HEAD/data/mean.npy
--------------------------------------------------------------------------------
/example/screenshot1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yeongjoonJu/Mesh-Viewer-using-pytorch3d/HEAD/example/screenshot1.png
--------------------------------------------------------------------------------
/example/screenshot2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yeongjoonJu/Mesh-Viewer-using-pytorch3d/HEAD/example/screenshot2.png
--------------------------------------------------------------------------------
/example/screenshot3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yeongjoonJu/Mesh-Viewer-using-pytorch3d/HEAD/example/screenshot3.png
--------------------------------------------------------------------------------
/example/screenshot4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yeongjoonJu/Mesh-Viewer-using-pytorch3d/HEAD/example/screenshot4.png
--------------------------------------------------------------------------------
/example/screenshot5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yeongjoonJu/Mesh-Viewer-using-pytorch3d/HEAD/example/screenshot5.png
--------------------------------------------------------------------------------
/example/screenshot6.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeongjoonJu/Mesh-Viewer-using-pytorch3d/HEAD/example/screenshot6.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | sys.path.append(os.path.abspath('')) 3 | from viewer import MeshViewer 4 | 5 | from PyQt5.QtWidgets import (QApplication, QFileDialog) 6 | 7 | 8 | if __name__ == '__main__': 9 | 10 | app = QApplication(sys.argv) 11 | meshViewer = MeshViewer() 12 | meshViewer.show() 13 | sys.exit(app.exec_()) 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Mesh Viewer using pytorch3D 2 | 3 | This code is implemented using Pytorch3D and PyQt5 4 | 5 | 6 | 7 | ### Installation 8 | 9 | The code uses **Python 3.7** in **Ubuntu 18.04 LTS** 10 | 11 | [Pytorch3D](https://github.com/facebookresearch/pytorch3d) 12 | 13 | You can install Pytorch3D in upper link. 14 | 15 | To install 16 | 17 | ~~~ 18 | git clone https://github.com/yeongjoonJu/Mesh-Viewer-using-pytorch3d.git 19 | cd Mesh-Viewer-using-pytorch3d 20 | mkdir data 21 | pip install pyqt5 22 | pip install pillow 23 | ~~~ 24 | 25 | 26 | 27 | ### Run 28 | 29 | ~~~ 30 | python main.py 31 | ~~~ 32 | 33 | Rotate an object : dragging the mouse 34 | 35 | Zoom in / out an object : wheeling the mouse 36 | 37 | Change the illumination position :
**W - move up, A - move left, S - move down, D - move right**
38 | **Q - move to front, E - move to back** 39 | 40 | ### Examples 41 | 42 | Texture mapping with color value per vertex [update August 26th] 43 | 44 | ![](./example/screenshot6.png) 45 | 46 | You can open .obj file after pressing file button. 47 | 48 | ![screenshot1](./example/screenshot1.png) 49 | 50 | ![](./example/screenshot2.png) 51 | 52 | **Rotation** 53 | 54 | You can rotate object through dragging mouse. 55 | 56 | ![](./example/screenshot3.png) 57 | 58 | **Zoom in/out** 59 | 60 | You can zoom in/out through wheeling. 61 | 62 | ![](./example/screenshot4.png) 63 | 64 | **Change the illumination** 65 | 66 | You can change the illumination using the keyboard. 67 | 68 | ![](./example/screenshot5.png) -------------------------------------------------------------------------------- /mesh_loader.py: -------------------------------------------------------------------------------- 1 | import os, torch 2 | 3 | # Util function for loading meshes 4 | from pytorch3d.io import load_obj, load_ply 5 | 6 | # Data structures and functions for rendering 7 | from pytorch3d.structures import Meshes, Textures 8 | from pytorch3d.ops import GraphConv, sample_points_from_meshes, vert_align 9 | from pytorch3d.renderer import ( 10 | look_at_view_transform, 11 | OpenGLPerspectiveCameras, 12 | PointLights, HardPhongShader, 13 | RasterizationSettings, 14 | MeshRenderer, MeshRasterizer, 15 | BlendParams 16 | ) 17 | from pytorch3d.renderer.mesh.shader import TexturedSoftPhongShader 18 | import numpy as np 19 | 20 | 21 | class MeshLoader(object): 22 | def __init__(self, device='cuda:0'): 23 | self.device = torch.device(device) 24 | torch.cuda.set_device(self.device) 25 | self.initialize_renderer() 26 | 27 | def set_phong_renderer(self, light_location): 28 | # Place a point light in front of the object 29 | self.light_location = light_location 30 | lights = PointLights(device=self.device, location=[light_location]) 31 | 32 | # Create a phong renderer by composing a rasterizer and a shader 
33 | self.phong_renderer = MeshRenderer( 34 | rasterizer=MeshRasterizer( 35 | cameras=self.cameras, 36 | raster_settings=self.raster_settings 37 | ), 38 | shader=HardPhongShader(device=self.device, lights=lights) 39 | ) 40 | 41 | def initialize_renderer(self): 42 | # Initialize an OpenGL perspective camera 43 | self.cameras = OpenGLPerspectiveCameras(device=self.device) 44 | 45 | self.raster_settings = RasterizationSettings( 46 | image_size = 512, 47 | blur_radius = 0.0, 48 | faces_per_pixel=2, 49 | ) 50 | 51 | self.set_phong_renderer([0.0,3.0,5.0]) 52 | 53 | def load(self, obj_filename): 54 | # Load obj file 55 | extension = obj_filename[-3:] 56 | 57 | if extension == 'obj': 58 | verts, faces, aux = load_obj(obj_filename) 59 | verts_idx = faces.verts_idx 60 | elif extension == 'ply': 61 | verts, faces = load_ply(obj_filename) 62 | verts_idx = faces 63 | 64 | if os.path.exists(obj_filename[:-3]+'npy'): 65 | colors = np.load(obj_filename[:-3]+'npy') 66 | verts_rgb = torch.FloatTensor(colors[...,[2,1,0]]) 67 | verts_rgb = verts_rgb.unsqueeze(0) 68 | verts_rgb = verts_rgb.to(self.device) 69 | else: 70 | # Initialize each vertex to be white in color - bgr 71 | verts_rgb = torch.ones_like(verts)[None] 72 | verts_rgb = verts_rgb.to(self.device) 73 | #textures = Textures(faces_uvs=faces.textures_idx[None,...], verts_uvs=aux.verts_uvs[None,...], verts_rgb=verts_rgb.to(self.device)) 74 | 75 | # Create a Meshes object for the face. 
76 | self.face_mesh = Meshes( 77 | verts = [verts.to(self.device)], 78 | faces = [verts_idx.to(self.device)], 79 | textures= Textures(verts_rgb=verts_rgb) 80 | ) 81 | 82 | def set_camera_location(self, distance, elevation, azimuth): 83 | self.distance = distance 84 | self.elevation = elevation 85 | self.azimuth = azimuth 86 | 87 | def render(self, distance=3, elevation=1.0, azimuth=0.0): 88 | """ Select the viewpoint using spherical angles""" 89 | 90 | self.set_camera_location(distance, elevation, azimuth) 91 | 92 | # Get the position of the camera based on the spherical angles 93 | R, T = look_at_view_transform(distance, elevation, azimuth, device=self.device) 94 | 95 | # Render the face providing the values of R and T 96 | image_ref = self.phong_renderer(meshes_world=self.face_mesh, R=R, T=T) 97 | 98 | #silhouette = silhouette.cpu().numpy() 99 | image_ref = image_ref.cpu().numpy() 100 | 101 | return image_ref.squeeze() 102 | 103 | def change_light(self, light_location): 104 | self.set_phong_renderer(light_location) 105 | return self.render(self.distance, self.elevation, self.azimuth) 106 | 107 | def get_camera_params(self): 108 | return self.distance, self.elevation, self.azimuth 109 | 110 | def get_light_location(self): 111 | return self.light_location -------------------------------------------------------------------------------- /viewer.py: -------------------------------------------------------------------------------- 1 | from PyQt5.QtCore import QDir, Qt 2 | from PyQt5.QtGui import QImage, QPainter, QPalette, QPixmap, qRgb, QIcon 3 | from PyQt5.QtWidgets import (QAction, QApplication, QFileDialog, QLabel, 4 | QMainWindow, QMenu, QMessageBox, QScrollArea, QSizePolicy, QInputDialog) 5 | import os 6 | import numpy as np 7 | from PIL import Image, ImageFont, ImageDraw 8 | 9 | import sys 10 | sys.path.append(os.path.abspath('')) 11 | from mesh_loader import MeshLoader 12 | 13 | class MeshViewer(QMainWindow): 14 | def __init__(self): 15 | super(MeshViewer, 
self).__init__() 16 | 17 | self.gray_color_table = [qRgb(i, i, i) for i in range(256)] 18 | 19 | self.width = 1024 20 | self.height = 1024 21 | 22 | self.imageLabel = QLabel() 23 | self.imageLabel.setBackgroundRole(QPalette.Base) 24 | self.imageLabel.setSizePolicy(QSizePolicy.Ignored, QSizePolicy.Ignored) 25 | self.imageLabel.setScaledContents(True) 26 | 27 | self.setCentralWidget(self.imageLabel) 28 | 29 | self.loaded = False 30 | self.font = ImageFont.truetype("/usr/share/fonts/dejavu/DejaVuSans.ttf", 15) 31 | 32 | self.initMenu() 33 | self.meshLoader = MeshLoader() 34 | 35 | self.setWindowTitle("Mesh Viewer by yj_ju") 36 | self.resize(self.width, self.height) 37 | 38 | def initMenu(self): 39 | self.statusBar() 40 | 41 | openFile = QAction(QIcon('open.png'), 'Open', self) 42 | openFile.setShortcut('Ctrl+O') 43 | openFile.setStatusTip('Open .obj file') 44 | openFile.triggered.connect(self.showFileDialog) 45 | 46 | menubar = self.menuBar() 47 | menubar.setNativeMenuBar(False) 48 | fileMenu = menubar.addMenu('&File') 49 | fileMenu.addAction(openFile) 50 | 51 | def showFileDialog(self): 52 | self.fname = QFileDialog.getOpenFileName(self, 'Open obj file', './data') 53 | self.meshLoader.load(self.fname[0]) 54 | self.loaded = True 55 | # image = self.meshLoader.render() 56 | # image = image * 255 57 | # image = image.astype('uint8') 58 | # self.openImage(image) 59 | self.meshLoader.set_camera_location(280.0, 0.0, 0.0) 60 | self.change_light_location(0.0, 0.0, 150.0) 61 | 62 | 63 | def toQImage(self, im, copy=False): 64 | if im is None: 65 | return QImage() 66 | if im.dtype == np.uint8: 67 | if len(im.shape) == 2: 68 | qim = QImage(im.data, im.shape[1], im.shape[0], im.strides[0], QImage.Format_Indexed8) 69 | qim.setColorTable(self.gray_color_table) 70 | return qim.copy() if copy else qim 71 | 72 | elif len(im.shape) == 3: 73 | if im.shape[2] == 3: 74 | qim = QImage(im.data, im.shape[1], im.shape[0], im.strides[0], QImage.Format_RGB888) 75 | return qim.copy() if copy 
else qim 76 | elif im.shape[2] == 4: 77 | qim = QImage(im.data, im.shape[1], im.shape[0], im.strides[0], QImage.Format_ARGB32) 78 | return qim.copy() if copy else qim 79 | 80 | def openImage(self, image): 81 | self.imageLabel.setPixmap(QPixmap.fromImage(self.toQImage(image))) 82 | 83 | #self.fitToWindowAct.setEnabled(True) 84 | #self.updateActions() 85 | #if not self.fitToWindowAct.isChecked(): 86 | # self.imageLabel.adjustSize() 87 | 88 | def render_for_camera(self, dist, elev, azim): 89 | image = self.meshLoader.render(dist, elev, azim) 90 | image = image * 255 91 | image = image.astype('uint8') 92 | 93 | img = Image.fromarray(image) 94 | draw = ImageDraw.Draw(img) 95 | draw.text((10,10), "distance: {0}, elevation: {1}, azimuth: {2}" 96 | .format(round(dist,3), round(elev,3), round(azim,3)), (255,255,255), font=self.font) 97 | 98 | self.openImage(np.array(img)) 99 | 100 | def mousePressEvent(self, e): 101 | self.prev_pos = (e.x(), e.y()) 102 | 103 | def mouseMoveEvent(self, e): 104 | if not self.loaded: 105 | return 106 | 107 | dist, elev, azim = self.meshLoader.get_camera_params() 108 | # Adjust rotation speed 109 | azim = azim + (self.prev_pos[0] - e.x())*0.1 110 | elev = elev - (self.prev_pos[1] - e.y())*0.1 111 | self.render_for_camera(dist, elev, azim) 112 | 113 | self.prev_pos = (e.x(), e.y()) 114 | 115 | def wheelEvent(self, e): 116 | if not self.loaded: 117 | return 118 | dist, elev, azim = self.meshLoader.get_camera_params() 119 | # Adjust rotation speed 120 | dist = dist - e.angleDelta().y()*0.01 121 | self.render_for_camera(dist, elev, azim) 122 | 123 | def change_light_location(self, x, y, z): 124 | image = self.meshLoader.change_light([x, y, z]) 125 | image = image * 255 126 | image = image.astype('uint8') 127 | 128 | img = Image.fromarray(image) 129 | draw = ImageDraw.Draw(img) 130 | draw.text((10,20), "Illumination x: {0}, y: {1}, z: {2}" 131 | .format(x, y, z), (255,255,255), font=self.font) 132 | 133 | self.openImage(np.array(img)) 134 | 135 | 
def keyPressEvent(self, e): 136 | # Key a 137 | if e.key() == 65: 138 | x, y, z = self.meshLoader.get_light_location() 139 | x -= 0.5 140 | self.change_light_location(x, y, z) 141 | 142 | # Key d 143 | elif e.key() == 68: 144 | x, y, z = self.meshLoader.get_light_location() 145 | x += 0.5 146 | self.change_light_location(x, y, z) 147 | 148 | # Key w 149 | elif e.key() == 87: 150 | x, y, z = self.meshLoader.get_light_location() 151 | y -= 0.5 152 | self.change_light_location(x, y, z) 153 | 154 | # Key s 155 | elif e.key() == 83: 156 | x, y, z = self.meshLoader.get_light_location() 157 | y += 0.5 158 | self.change_light_location(x, y, z) 159 | 160 | # Key e 161 | elif e.key() == 69: 162 | x, y, z = self.meshLoader.get_light_location() 163 | z += 0.5 164 | self.change_light_location(x, y, z) 165 | 166 | # Key q 167 | elif e.key() == 81: 168 | x, y, z = self.meshLoader.get_light_location() 169 | z -= 0.5 170 | self.change_light_location(x, y, z) --------------------------------------------------------------------------------