├── data
└── mean.npy
├── example
├── screenshot1.png
├── screenshot2.png
├── screenshot3.png
├── screenshot4.png
├── screenshot5.png
└── screenshot6.png
├── main.py
├── README.md
├── mesh_loader.py
└── viewer.py
/data/mean.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yeongjoonJu/Mesh-Viewer-using-pytorch3d/HEAD/data/mean.npy
--------------------------------------------------------------------------------
/example/screenshot1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yeongjoonJu/Mesh-Viewer-using-pytorch3d/HEAD/example/screenshot1.png
--------------------------------------------------------------------------------
/example/screenshot2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yeongjoonJu/Mesh-Viewer-using-pytorch3d/HEAD/example/screenshot2.png
--------------------------------------------------------------------------------
/example/screenshot3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yeongjoonJu/Mesh-Viewer-using-pytorch3d/HEAD/example/screenshot3.png
--------------------------------------------------------------------------------
/example/screenshot4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yeongjoonJu/Mesh-Viewer-using-pytorch3d/HEAD/example/screenshot4.png
--------------------------------------------------------------------------------
/example/screenshot5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yeongjoonJu/Mesh-Viewer-using-pytorch3d/HEAD/example/screenshot5.png
--------------------------------------------------------------------------------
/example/screenshot6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yeongjoonJu/Mesh-Viewer-using-pytorch3d/HEAD/example/screenshot6.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 | sys.path.append(os.path.abspath(''))
3 | from viewer import MeshViewer
4 |
5 | from PyQt5.QtWidgets import (QApplication, QFileDialog)
6 |
7 |
if __name__ == '__main__':
    # Qt requires the QApplication to exist before any widget is constructed.
    qt_app = QApplication(sys.argv)
    viewer_window = MeshViewer()
    viewer_window.show()
    # exec_() blocks until the last window is closed; forward its status code.
    exit_code = qt_app.exec_()
    sys.exit(exit_code)
14 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Mesh Viewer using PyTorch3D
2 |
3 | This mesh viewer is implemented with PyTorch3D and PyQt5.
4 |
5 |
6 |
7 | ### Installation
8 |
9 | The code uses **Python 3.7** on **Ubuntu 18.04 LTS**.
10 |
11 | [Pytorch3D](https://github.com/facebookresearch/pytorch3d)
12 |
13 | Install PyTorch3D by following the instructions at the link above.
14 |
15 | Then install the viewer and its dependencies:
16 |
17 | ~~~
18 | git clone https://github.com/yeongjoonJu/Mesh-Viewer-using-pytorch3d.git
19 | cd Mesh-Viewer-using-pytorch3d
20 | mkdir data
21 | pip install pyqt5
22 | pip install pillow
23 | ~~~
24 |
25 |
26 |
27 | ### Run
28 |
29 | ~~~
30 | python main.py
31 | ~~~
32 |
33 | Rotate the object: drag the mouse
34 |
35 | Zoom in/out: scroll the mouse wheel
36 |
37 | Change the light position:
**W - move up, A - move left, S - move down, D - move right**
38 | **Q - move to the front, E - move to the back**
39 |
40 | ### Examples
41 |
42 | Texture mapping with per-vertex color values [updated August 26th]
43 |
44 | 
45 |
46 | You can open an .obj (or .ply) file via **File → Open** (Ctrl+O).
47 |
48 | 
49 |
50 | 
51 |
52 | **Rotation**
53 |
54 | You can rotate the object by dragging the mouse.
55 |
56 | 
57 |
58 | **Zoom in/out**
59 |
60 | You can zoom in/out with the mouse wheel.
61 |
62 | 
63 |
64 | **Change the illumination**
65 |
66 | You can change the illumination using the keyboard.
67 |
68 | 
--------------------------------------------------------------------------------
/mesh_loader.py:
--------------------------------------------------------------------------------
1 | import os, torch
2 |
3 | # Util function for loading meshes
4 | from pytorch3d.io import load_obj, load_ply
5 |
6 | # Data structures and functions for rendering
7 | from pytorch3d.structures import Meshes, Textures
8 | from pytorch3d.ops import GraphConv, sample_points_from_meshes, vert_align
9 | from pytorch3d.renderer import (
10 | look_at_view_transform,
11 | OpenGLPerspectiveCameras,
12 | PointLights, HardPhongShader,
13 | RasterizationSettings,
14 | MeshRenderer, MeshRasterizer,
15 | BlendParams
16 | )
17 | from pytorch3d.renderer.mesh.shader import TexturedSoftPhongShader
18 | import numpy as np
19 |
20 |
class MeshLoader(object):
    """Load a triangle mesh (.obj / .ply) and render it with a PyTorch3D Phong renderer.

    The scene uses one ``PointLights`` light, an ``OpenGLPerspectiveCameras``
    camera placed with ``look_at_view_transform(distance, elevation, azimuth)``,
    and a ``HardPhongShader``. The last camera parameters and light location
    are remembered so the view can be re-rendered after changing only one of them.
    """

    # File extensions accepted by load(); compared case-insensitively.
    SUPPORTED_EXTENSIONS = ('.obj', '.ply')

    def __init__(self, device='cuda:0'):
        self.device = torch.device(device)
        # torch.cuda.set_device only accepts CUDA devices; skipping it for other
        # device types lets the loader run e.g. with device='cpu'.
        if self.device.type == 'cuda':
            torch.cuda.set_device(self.device)
        # No mesh until load() is called; render() checks this.
        self.face_mesh = None
        # Start from render()'s default viewpoint so get_camera_params() and
        # change_light() work before the first render() call.
        self.set_camera_location(3, 1.0, 0.0)
        self.initialize_renderer()

    def set_phong_renderer(self, light_location):
        """(Re)build the Phong renderer with a single point light at ``light_location``.

        Args:
            light_location: [x, y, z] position of the point light; stored as-is
                and returned by get_light_location().
        """
        self.light_location = light_location
        lights = PointLights(device=self.device, location=[light_location])

        # A renderer is a rasterizer (geometry -> fragments) plus a shader
        # (fragments -> colours).
        self.phong_renderer = MeshRenderer(
            rasterizer=MeshRasterizer(
                cameras=self.cameras,
                raster_settings=self.raster_settings,
            ),
            shader=HardPhongShader(device=self.device, lights=lights),
        )

    def initialize_renderer(self):
        """Create the camera and rasterization settings, then the default-light renderer."""
        self.cameras = OpenGLPerspectiveCameras(device=self.device)

        self.raster_settings = RasterizationSettings(
            image_size=512,
            blur_radius=0.0,
            faces_per_pixel=2,
        )

        self.set_phong_renderer([0.0, 3.0, 5.0])

    def load(self, obj_filename):
        """Load the mesh at ``obj_filename`` into ``self.face_mesh``.

        ``.obj`` files are read with ``load_obj`` and ``.ply`` files with
        ``load_ply``. If a file with the same base name and a ``.npy``
        extension exists, it is loaded as per-vertex colours with the channel
        order of its last axis reversed ([2, 1, 0]); otherwise every vertex is white.

        Raises:
            ValueError: if the extension is not .obj/.ply, or if the colour
                array does not have one 3-channel entry per vertex.
        """
        base, extension = os.path.splitext(obj_filename)
        extension = extension.lower()

        if extension == '.obj':
            verts, faces, _aux = load_obj(obj_filename)
            verts_idx = faces.verts_idx
        elif extension == '.ply':
            verts, verts_idx = load_ply(obj_filename)
        else:
            # Previously this fell through and crashed with UnboundLocalError
            # (e.g. for '' when a file dialog is cancelled, or for 'MESH.OBJ').
            raise ValueError(
                "Unsupported mesh file {0!r}: expected one of {1}".format(
                    obj_filename, ', '.join(self.SUPPORTED_EXTENSIONS)))

        color_filename = base + '.npy'
        if os.path.exists(color_filename):
            colors = np.load(color_filename)
            # Reverse the channel order; the viewer's Qt image format appears
            # to swap red and blue on display -- NOTE(review): confirm before
            # changing either side.
            verts_rgb = torch.tensor(colors[..., [2, 1, 0]], dtype=torch.float32)
            if verts_rgb.shape != verts.shape:
                raise ValueError(
                    "Colour file {0!r} has shape {1}; expected {2} "
                    "(one 3-channel colour per vertex)".format(
                        color_filename, tuple(colors.shape), tuple(verts.shape)))
            verts_rgb = verts_rgb.unsqueeze(0)  # add batch axis -> (1, V, 3)
        else:
            # White for every vertex, with a leading batch axis.
            verts_rgb = torch.ones_like(verts)[None]
        verts_rgb = verts_rgb.to(self.device)

        # Single-mesh batch; ``Textures(verts_rgb=...)`` stores per-vertex colours.
        self.face_mesh = Meshes(
            verts=[verts.to(self.device)],
            faces=[verts_idx.to(self.device)],
            textures=Textures(verts_rgb=verts_rgb),
        )

    def set_camera_location(self, distance, elevation, azimuth):
        """Remember the spherical camera parameters used by the next re-render."""
        self.distance = distance
        self.elevation = elevation
        self.azimuth = azimuth

    def render(self, distance=3, elevation=1.0, azimuth=0.0):
        """Render the loaded mesh from the given spherical viewpoint.

        Returns:
            The rendered image as a NumPy array with size-1 axes squeezed out
            (assumed (H, W, 4) RGBA floats for a single mesh -- TODO confirm
            against the PyTorch3D version in use).

        Raises:
            RuntimeError: if no mesh has been loaded yet.
        """
        if self.face_mesh is None:
            raise RuntimeError('No mesh loaded; call load() before render().')

        self.set_camera_location(distance, elevation, azimuth)

        # World-to-view rotation/translation for a camera looking at the origin.
        R, T = look_at_view_transform(distance, elevation, azimuth, device=self.device)
        image_ref = self.phong_renderer(meshes_world=self.face_mesh, R=R, T=T)

        return image_ref.cpu().numpy().squeeze()

    def change_light(self, light_location):
        """Move the point light and re-render from the remembered camera position."""
        self.set_phong_renderer(light_location)
        return self.render(self.distance, self.elevation, self.azimuth)

    def get_camera_params(self):
        """Return the remembered (distance, elevation, azimuth)."""
        return self.distance, self.elevation, self.azimuth

    def get_light_location(self):
        """Return the current point-light location as passed to set_phong_renderer()."""
        return self.light_location
--------------------------------------------------------------------------------
/viewer.py:
--------------------------------------------------------------------------------
1 | from PyQt5.QtCore import QDir, Qt
2 | from PyQt5.QtGui import QImage, QPainter, QPalette, QPixmap, qRgb, QIcon
3 | from PyQt5.QtWidgets import (QAction, QApplication, QFileDialog, QLabel,
4 | QMainWindow, QMenu, QMessageBox, QScrollArea, QSizePolicy, QInputDialog)
5 | import os
6 | import numpy as np
7 | from PIL import Image, ImageFont, ImageDraw
8 |
9 | import sys
10 | sys.path.append(os.path.abspath(''))
11 | from mesh_loader import MeshLoader
12 |
class MeshViewer(QMainWindow):
    """Main window that displays a PyTorch3D rendering of one mesh.

    Input mapping:
      * mouse drag         -> change camera azimuth / elevation
      * mouse wheel        -> change camera distance (zoom)
      * W/A/S/D/Q/E keys   -> move the point light along y / x / z

    Rendering is delegated to ``MeshLoader``. This class converts the rendered
    float arrays into Qt pixmaps and draws a one-line status text on them.
    """

    # Initial window size in pixels. These are class constants rather than
    # ``self.width`` / ``self.height`` so the inherited QWidget.width() and
    # QWidget.height() methods are not shadowed by ints.
    DEFAULT_WIDTH = 1024
    DEFAULT_HEIGHT = 1024

    # Overlay font. The path is distribution-specific; see _load_font().
    FONT_PATH = "/usr/share/fonts/dejavu/DejaVuSans.ttf"
    FONT_SIZE = 15

    # Camera (distance, elevation, azimuth) and light (x, y, z) used right
    # after a mesh is opened.
    INITIAL_CAMERA = (280.0, 0.0, 0.0)
    INITIAL_LIGHT = (0.0, 0.0, 150.0)

    # Smallest distance reachable by zooming: at distance 0 the camera would
    # sit on the look-at point and the view transform is undefined.
    MIN_DISTANCE = 1e-2

    # Rotation change per pixel of mouse drag, and distance change per
    # angleDelta unit (a standard wheel notch reports 120).
    ROTATE_SPEED = 0.1
    ZOOM_SPEED = 0.01

    # Light displacement per key press, and key -> (index into [x, y, z], sign).
    LIGHT_STEP = 0.5
    LIGHT_KEYS = {
        Qt.Key_A: (0, -1),
        Qt.Key_D: (0, 1),
        Qt.Key_W: (1, -1),
        Qt.Key_S: (1, 1),
        Qt.Key_Q: (2, -1),
        Qt.Key_E: (2, 1),
    }

    def __init__(self):
        super(MeshViewer, self).__init__()

        # Grey-level palette for single-channel (Indexed8) images in toQImage().
        self.gray_color_table = [qRgb(i, i, i) for i in range(256)]

        self.imageLabel = QLabel()
        self.imageLabel.setBackgroundRole(QPalette.Base)
        self.imageLabel.setSizePolicy(QSizePolicy.Ignored, QSizePolicy.Ignored)
        self.imageLabel.setScaledContents(True)
        self.setCentralWidget(self.imageLabel)

        self.loaded = False      # True once a mesh has been loaded successfully
        self.prev_pos = None     # last mouse position seen while dragging
        # Named overlay_font so QWidget.font() is not shadowed.
        self.overlay_font = self._load_font()

        self.initMenu()
        self.meshLoader = MeshLoader()

        self.setWindowTitle("Mesh Viewer by yj_ju")
        self.resize(self.DEFAULT_WIDTH, self.DEFAULT_HEIGHT)

    def _load_font(self):
        """Return the overlay font, falling back to PIL's built-in font if it is missing."""
        try:
            return ImageFont.truetype(self.FONT_PATH, self.FONT_SIZE)
        except OSError:
            # The font file is not present on this system; the default bitmap
            # font ignores FONT_SIZE but keeps the viewer usable.
            return ImageFont.load_default()

    def initMenu(self):
        """Create the status bar and the File -> Open (Ctrl+O) action."""
        self.statusBar()

        openFile = QAction(QIcon('open.png'), 'Open', self)
        openFile.setShortcut('Ctrl+O')
        openFile.setStatusTip('Open .obj file')
        openFile.triggered.connect(self.showFileDialog)

        menubar = self.menuBar()
        menubar.setNativeMenuBar(False)
        fileMenu = menubar.addMenu('&File')
        fileMenu.addAction(openFile)

    def showFileDialog(self):
        """Ask for a mesh file, load it, and show it from the initial viewpoint."""
        self.fname = QFileDialog.getOpenFileName(self, 'Open obj file', './data')
        path = self.fname[0]
        if not path:
            # Dialog was cancelled; getOpenFileName returns an empty path.
            return

        try:
            self.meshLoader.load(path)
        except Exception as err:
            # UI boundary: report the failure instead of letting the exception
            # escape the Qt slot, which aborts a PyQt5 application.
            QMessageBox.warning(
                self, 'Mesh Viewer', 'Could not load {0}:\n{1}'.format(path, err))
            return

        self.loaded = True
        self.meshLoader.set_camera_location(*self.INITIAL_CAMERA)
        self.change_light_location(*self.INITIAL_LIGHT)

    def toQImage(self, im, copy=False):
        """Wrap a uint8 NumPy image in a QImage.

        Supported inputs are (H, W) grey, (H, W, 3) RGB and (H, W, 4) 4-channel
        arrays; ``None`` yields an empty QImage. Without ``copy`` the QImage
        shares ``im``'s buffer, so ``im`` must outlive it.

        Raises:
            ValueError: for any other dtype or shape.
        """
        if im is None:
            return QImage()

        if im.dtype == np.uint8:
            if im.ndim == 2:
                qim = QImage(im.data, im.shape[1], im.shape[0], im.strides[0],
                             QImage.Format_Indexed8)
                qim.setColorTable(self.gray_color_table)
                return qim.copy() if copy else qim

            if im.ndim == 3 and im.shape[2] in (3, 4):
                # NOTE(review): Format_ARGB32 is a native-endian 32-bit format,
                # so on little-endian hosts the bytes are read as B, G, R, A and
                # an RGBA array is shown with red and blue swapped. MeshLoader
                # reverses vertex colour channels, presumably to compensate --
                # confirm before switching formats.
                fmt = QImage.Format_RGB888 if im.shape[2] == 3 else QImage.Format_ARGB32
                qim = QImage(im.data, im.shape[1], im.shape[0], im.strides[0], fmt)
                return qim.copy() if copy else qim

        # Previously fell through and returned None, which broke QPixmap.fromImage.
        raise ValueError(
            'Unsupported image: dtype={0}, shape={1}'.format(im.dtype, im.shape))

    def openImage(self, image):
        """Show a NumPy image in the central label."""
        self.imageLabel.setPixmap(QPixmap.fromImage(self.toQImage(image)))

    def _show_rendering(self, image, position, text):
        """Display a float rendering with ``text`` drawn in white at ``position``."""
        # Clip before scaling: a value above 1.0 (for example from .npy vertex
        # colours outside [0, 1]) would otherwise wrap around in the uint8 cast
        # and show up dark instead of saturating.
        image = (np.clip(image, 0.0, 1.0) * 255).astype('uint8')
        img = Image.fromarray(image)
        ImageDraw.Draw(img).text(position, text, (255, 255, 255), font=self.overlay_font)
        self.openImage(np.array(img))

    def render_for_camera(self, dist, elev, azim):
        """Render from the given camera position and overlay the camera parameters."""
        image = self.meshLoader.render(dist, elev, azim)
        text = "distance: {0}, elevation: {1}, azimuth: {2}".format(
            round(dist, 3), round(elev, 3), round(azim, 3))
        self._show_rendering(image, (10, 10), text)

    def mousePressEvent(self, e):
        """Remember where a drag starts."""
        self.prev_pos = (e.x(), e.y())

    def mouseMoveEvent(self, e):
        """Rotate the camera by the mouse displacement since the last event."""
        if not self.loaded or self.prev_pos is None:
            return

        dist, elev, azim = self.meshLoader.get_camera_params()
        azim = azim + (self.prev_pos[0] - e.x()) * self.ROTATE_SPEED
        elev = elev - (self.prev_pos[1] - e.y()) * self.ROTATE_SPEED
        self.render_for_camera(dist, elev, azim)

        self.prev_pos = (e.x(), e.y())

    def wheelEvent(self, e):
        """Zoom by moving the camera toward/away from the origin."""
        if not self.loaded:
            return
        dist, elev, azim = self.meshLoader.get_camera_params()
        # Clamp so the camera never reaches or passes the look-at point.
        dist = max(dist - e.angleDelta().y() * self.ZOOM_SPEED, self.MIN_DISTANCE)
        self.render_for_camera(dist, elev, azim)

    def change_light_location(self, x, y, z):
        """Move the point light to (x, y, z), re-render and overlay its position."""
        image = self.meshLoader.change_light([x, y, z])
        text = "Illumination x: {0}, y: {1}, z: {2}".format(x, y, z)
        self._show_rendering(image, (10, 20), text)

    def keyPressEvent(self, e):
        """Step the light by LIGHT_STEP: A/D along x, W/S along y, Q/E along z."""
        move = self.LIGHT_KEYS.get(e.key())
        if move is None:
            # Not a light key: let Qt handle/propagate it as usual.
            super(MeshViewer, self).keyPressEvent(e)
            return
        if not self.loaded:
            # Nothing to re-render yet; previously this crashed.
            return

        axis, sign = move
        location = list(self.meshLoader.get_light_location())
        location[axis] += sign * self.LIGHT_STEP
        self.change_light_location(*location)
--------------------------------------------------------------------------------