├── .idea
├── .gitignore
├── vcs.xml
├── modules.xml
├── other.xml
├── misc.xml
├── inspectionProfiles
│ └── Project_Default.xml
└── 3DMM.iml
├── avgModel.mat
├── requirements.txt
├── resources
├── icon.png
└── gitScreen1.png
├── data
└── avgModel_bh_1779_NE.mat
├── registration_parameters.conf
├── pointRegistration
├── registration_parameters.conf
├── registration_param.py
├── batchRegistration.py
├── registration.py
├── model.py
├── file3D.py
└── displacementMap.py
├── main.py
├── graphicInterface
├── upper_toolbar_controls.py
├── file_dialogs.py
├── rotatable_figure.py
├── plot_button_collection.py
├── window.py
├── console.py
├── show_displacement.py
├── plot_interactive_figure.py
├── conf_param_window.py
├── upper_toolbar.py
├── plot_figure.py
└── main_widget.py
├── morpmodel
├── RP.py
├── Matrix_operations.py
├── _3DMM.py
└── util_for_graphic.py
└── README.md
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Default ignored files
3 | /workspace.xml
--------------------------------------------------------------------------------
/avgModel.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rickie95/3DMMRegistration/HEAD/avgModel.mat
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.15.0
2 | scipy>=1.3
3 | matplotlib
4 | PyQt5
5 | h5py
6 | pycpd
7 |
8 |
--------------------------------------------------------------------------------
/resources/icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rickie95/3DMMRegistration/HEAD/resources/icon.png
--------------------------------------------------------------------------------
/resources/gitScreen1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rickie95/3DMMRegistration/HEAD/resources/gitScreen1.png
--------------------------------------------------------------------------------
/data/avgModel_bh_1779_NE.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rickie95/3DMMRegistration/HEAD/data/avgModel_bh_1779_NE.mat
--------------------------------------------------------------------------------
/registration_parameters.conf:
--------------------------------------------------------------------------------
1 | [PARAMETERS]
2 | tolerance = 0.001
3 | max_iterations = 100
4 | sigma2 = None
5 | w = 0.0
6 |
7 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/pointRegistration/registration_parameters.conf:
--------------------------------------------------------------------------------
1 | # Configuration file for CPD rigid registration, view readme for insights
2 | # This is a comment.
3 | [PARAMETERS]
4 | tolerance=0.001
5 | max_iterations=100
6 | sigma2=None
7 | w=0
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from PyQt5.QtWidgets import QApplication
2 | from graphicInterface.window import *
3 | import sys
4 |
if __name__ == "__main__":
    # Exactly one QApplication per process; it must exist before any widget is built.
    app = QApplication(sys.argv)

    # Create the process-wide singletons (the log console widget and the
    # registration parameters read from registration_parameters.conf) before
    # the window starts using them.
    Logger()
    RegistrationParameters()

    # Keep the window bound to a name so it is not garbage-collected while the
    # event loop runs.
    w = Window()

    exit_code = app.exec_()
    sys.exit(exit_code)
15 |
--------------------------------------------------------------------------------
/.idea/other.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
--------------------------------------------------------------------------------
/.idea/3DMM.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/graphicInterface/upper_toolbar_controls.py:
--------------------------------------------------------------------------------
1 | from PyQt5.QtWidgets import QPushButton, QComboBox
2 |
3 |
class ControlPushButton(QPushButton):
    """QPushButton whose ``clicked`` signal is wired to a callback at construction.

    :param text: label shown on the button.
    :param callback: callable connected to ``clicked``. When ``None`` (the
        default) no slot is connected; the caller may connect one later.
    :param enabled: initial enabled state of the button.
    """

    def __init__(self, text="Button", callback=None, enabled=True):
        super().__init__(text)
        # PyQt rejects ``connect(None)``, so honour the ``callback=None``
        # default by simply not connecting anything.
        if callback is not None:
            self.clicked.connect(callback)
        self.setEnabled(enabled)
10 |
11 |
class PercentageComboBox(QComboBox):
    """Combo box listing point percentages from 100% down to 30% in steps of 10.

    Each entry is displayed as "<n>%" and carries the integer ``n`` as item data.
    """

    def __init__(self):
        super().__init__()
        percentages = list(range(100, 20, -10))  # 100, 90, ..., 30
        for pct in percentages:
            self.addItem("{}%".format(pct), pct)
18 |
19 |
class RegistrationMethodsCombobox(QComboBox):
    """Combo box offering the CPD registration variants.

    Item data is an integer id: 1 = rigid, 2 = affine, 3 = deformable.
    An "ICP" entry with id 0 is currently disabled.
    """

    def __init__(self):
        super(RegistrationMethodsCombobox, self).__init__()
        methods = (
            ("CPD - Rigid", 1),
            ("CPD - Affine", 2),
            ("CPD - Deformable", 3),
        )
        for label, method_id in methods:
            self.addItem(label, method_id)
28 |
--------------------------------------------------------------------------------
/graphicInterface/file_dialogs.py:
--------------------------------------------------------------------------------
1 | from PyQt5.Qt import QFileDialog
2 |
3 |
def load_file_dialog(parent, filters, multiple_files=False):
    """Ask the user to pick one or more model files to open.

    :param parent: widget that owns the dialog.
    :param filters: Qt name-filter string (entries separated by ";;").
    :param multiple_files: when True, allow selecting several files.
    :return: the selected path (str), or a list of paths when
        ``multiple_files`` is True; ``None`` if the user cancelled.
    """
    dlg = QFileDialog()
    options = dlg.Options()
    options |= QFileDialog.DontUseNativeDialog
    if multiple_files:
        selection, _ = dlg.getOpenFileNames(parent, "Load a model", "", filters, "File WRML (*.wrl)",
                                            options=options)
    else:
        selection, _ = dlg.getOpenFileName(parent, "Load a model", "", filters, "File WRML (*.wrl)",
                                           options=options)
    # A cancelled dialog yields "" (single file) or [] (multiple files); the
    # old ``== ""`` test missed the empty list. Report both as None.
    if not selection:
        return None
    return selection
15 |
16 |
def save_file_dialog(parent, filters):
    """Ask the user for a destination path to save a model.

    :param parent: widget that owns the dialog.
    :param filters: Qt name-filter string (entries separated by ";;").
    :return: the chosen path, or ``None`` if the user cancelled.
    """
    dlg = QFileDialog()
    options = dlg.Options()
    options |= dlg.DontUseNativeDialog
    # Signature is getSaveFileName(parent, caption, directory, filter, ...).
    # The original passed None as caption and "Save model" as the starting
    # directory; "Save model" is meant to be the dialog title.
    filename, _ = dlg.getSaveFileName(parent, "Save model", "", filter=filters, options=options)
    if not filename:
        return None
    return filename
25 |
--------------------------------------------------------------------------------
/graphicInterface/rotatable_figure.py:
--------------------------------------------------------------------------------
1 | from graphicInterface.plot_button_collection import PlotButtonCollection
2 | from graphicInterface.plot_figure import PlotFigure
3 | from pointRegistration.model import Model
4 |
5 |
class RotatableFigure(PlotFigure):
    """PlotFigure with rotation buttons and an optional overlaid secondary model.

    The secondary model's points are drawn in red and its landmarks in black.
    It is rotated together with the primary model so both stay aligned.
    """

    def __init__(self, parent, model=None, landmarks=True, title=None, secondary_model=None):
        # Set every attribute read by the overridden load_model()/draw_data()
        # hooks *before* PlotFigure.__init__ runs. If the base initialiser loads
        # or draws ``model`` through those hooks, draw_data() would otherwise
        # hit a missing ``secondary_model`` attribute.
        self.rotate_buttons = None
        self.secondary_model = secondary_model
        super().__init__(parent, model, landmarks, title)
        self.registered = False

    def rotate(self, axis, theta, registered=False):
        """Rotate the shown model(s) by ``theta`` about ``axis`` ('x', 'y' or 'z') and redraw.

        PlotButtonCollection passes +/-15 from buttons labelled "15°", so
        ``theta`` is presumably in degrees. ``registered`` is accepted for
        compatibility but unused. Nothing happens when no primary model is
        loaded.
        """
        if self.model is not None:
            self.clear()
            self.model.rotate(axis, theta)
            if self.secondary_model is not None:
                self.secondary_model.rotate(axis, theta)

            self.draw_data()

    def load_model(self, model):
        """Load ``model`` via PlotFigure and create the rotation buttons on first use."""
        super().load_model(model)
        if self.rotate_buttons is None:
            self.rotate_buttons = PlotButtonCollection(self.rotate, self)

    def set_secondary_model(self, model):
        """Overlay ``model`` (stored via Model.from_model, presumably a copy) on this figure."""
        self.secondary_model = Model.from_model(model)

    def draw_data(self, clear=False):
        """Queue the secondary model (points red, landmarks black), then draw via PlotFigure."""
        if self.secondary_model is not None:
            self.load_data(self.secondary_model.points, "r")
            if self.secondary_model.landmarks is not None:
                self.load_data(self.secondary_model.landmarks, "black")
        super().draw_data(clear=clear)
37 |
--------------------------------------------------------------------------------
/morpmodel/RP.py:
--------------------------------------------------------------------------------
class RP:
    """Default rendering parameters (image size, camera, projection, lighting).

    A plain attribute container filled with fixed defaults. Several fields of
    the structure this was ported from (which used MATLAB-style syntax such as
    ``eye(3)`` and ``struct([])``) are not ported: object_size, ambient_col,
    rotm, illum_method and global_illum.*.
    """

    def __init__(self):
        # Output image size.
        self.width = 640
        self.height = 486

        # Angles, all starting at zero.
        self.gamma = 0
        self.theta = 0
        self.phi = 0
        self.alpha = 0

        # Scene / camera placement and scaling.
        self.t2d = [0, 0]
        self.camera_pos = [0, 0, 3400]
        self.scale_scene = 0.0
        self.shift_object = [0, 0, -46125]
        self.shift_world = [0, 0, 0]
        self.scale = 0.001

        # ac_* coefficients (meaning not documented in the original).
        self.ac_g = [1, 1, 1]
        self.ac_c = 1
        self.ac_o = [0, 0, 0]

        self.use_rotm = 0
        self.do_remap = 0

        # Lighting / shading settings.
        self.dir_light = []
        self.do_specular = 0.1
        self.phong_exp = 8
        self.specular = 0.1 * 255
        self.do_cast_shadows = 1
        self.sbufsize = 200

        # Projection method. When scale_scene == 0 the focal value f is used.
        self.proj = 'perspective'
        self.f = 6000

        # Channels: 1 for grey-level images, 3 for colour images.
        self.n_chan = 3
        # 2 = default for the current projection.
        self.backface_culling = 2
        # Empty list: no blending is performed.
        self.ablend = []
43 |
--------------------------------------------------------------------------------
/morpmodel/Matrix_operations.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numpy.linalg import norm
3 |
class Matrix_op:
    """Container for data matrices plus a few NumPy operations on them.

    Two construction modes are supported:

    * ``Matrix_op(Components, None)`` stores ``Components`` in
      ``X_after_training`` and its triplet reshaping in ``X_res``
      (see :meth:`reshape`). ``X_aligned_model`` is set to ``None``.
    * ``Matrix_op(None, aligned_models_data)`` stores the data as a NumPy
      array in ``X_aligned_model``. This is the matrix that :meth:`mean`,
      :meth:`normalization` and :meth:`transpose` operate on.
    """
    # Class-level fallbacks kept for compatibility; __init__ sets instance
    # attributes in both supported construction modes.
    X_aligned_model = []
    X_after_training = []
    X_res = []

    def __init__(self, Components, aligned_models_data):
        if aligned_models_data is None:
            self.X_after_training = Components
            self.reshape(Components)
            # Previously a bare ``X_aligned_model = None``. That only bound an
            # unused local, so the instance fell back to the shared class list.
            self.X_aligned_model = None
        if Components is None:
            # Presumably a 2-D (rows x columns) matrix -- TODO confirm with callers.
            self.X_aligned_model = np.array(aligned_models_data)
            self.X_after_training = []
            self.X_res = []

    def mean(self):
        """Centre every row of ``X_aligned_model`` by subtracting that row's mean."""
        # axis=1 averages along each row (axis=0 would average columns).
        self.X_aligned_model = self.X_aligned_model - self.X_aligned_model.mean(axis=1, keepdims=True)

    def normalization(self):
        """Replace ``X_aligned_model`` with the L1 norm of each of its rows.

        NOTE(review): despite its name this does not normalise the matrix. It
        overwrites it with a 1-D vector of row L1 norms. If per-row L1
        normalisation was intended, the data should instead be divided by
        ``norm(X, axis=1, ord=1, keepdims=True)``. It is left unchanged
        because its callers are not visible here.
        """
        self.X_aligned_model = norm(self.X_aligned_model, axis=1, ord=1)

    def transpose(self):
        """Transpose ``X_aligned_model``."""
        self.X_aligned_model = np.transpose(self.X_aligned_model)

    def reshape(self, D):
        """Store in ``X_res`` a float copy of ``D`` shaped (rows // 3, 3, columns).

        Each column of ``D`` is read as consecutive triplets (presumably
        x, y, z), so ``X_res[j, k, c] == D[3 * j + k, c]``. ``D.shape[0]``
        must be a multiple of 3, otherwise ``ValueError`` is raised.
        """
        n_rows, n_cols = D.shape
        # One C-order reshape yields the same layout as the old per-column
        # Fortran-order reshape+transpose loop. np.array copies, so X_res
        # never aliases ``D``.
        self.X_res = np.array(D, dtype=float).reshape(n_rows // 3, 3, n_cols)
35 |
class Vector_op:
    """Wraps a numeric vector ``V`` and rescales it linearly into a target range."""
    V = []

    def __init__(self, v_init):
        self.V = v_init

    def scale(self, mx, mn):
        """Linearly map ``V`` so that its minimum becomes ``mn`` and its maximum ``mx``.

        ``V`` may be a NumPy array or any array-like; the result is an ndarray.
        A constant ``V`` (max == min) has no spread to stretch and is mapped
        entirely to ``mn``, instead of producing NaNs from 0/0. An empty ``V``
        is left unchanged.
        """
        values = np.asarray(self.V)
        if values.size == 0:
            return
        min_w = np.amin(values)
        max_w = np.amax(values)
        span = max_w - min_w
        if span == 0:
            self.V = np.full(values.shape, float(mn))
            return
        self.V = (((values - min_w) * (mx - mn)) / span) + mn
45 |
--------------------------------------------------------------------------------
/graphicInterface/plot_button_collection.py:
--------------------------------------------------------------------------------
1 | from matplotlib.widgets import Button
2 | from matplotlib import pyplot
3 |
4 |
class PlotButtonCollection(object):
    """Six matplotlib buttons that rotate a plot by 15 degrees around x, y or z.

    A click calls ``callback(axis, degrees)`` with ``axis`` in {'x', 'y', 'z'}
    and ``degrees`` equal to 15 or -15.
    """

    def __init__(self, callback, parent):
        # ``parent`` is accepted for API compatibility but not used.
        self.callback = callback

        # Button axes in figure coordinates [left, bottom, width, height]:
        # "minus" buttons in the left column, "plus" buttons in the right one.
        # Rows from bottom to top are z, y, x. Creation order is unchanged.
        ax_yminus = pyplot.axes([0, 0.07, 0.09, 0.05])
        ax_yplus = pyplot.axes([0.1, 0.07, 0.09, 0.05])
        ax_xminus = pyplot.axes([0, 0.13, 0.09, 0.05])
        ax_xplus = pyplot.axes([0.1, 0.13, 0.09, 0.05])
        ax_zminus = pyplot.axes([0, 0.01, 0.09, 0.05])
        ax_zplus = pyplot.axes([0.1, 0.01, 0.09, 0.05])

        self.b_yplus = self._make_button(ax_yplus, '+15°', self.rotate_y_plus)
        self.b_yminus = self._make_button(ax_yminus, 'Y -15°', self.rotate_y_minus)
        self.b_xplus = self._make_button(ax_xplus, '+15°', self.rotate_x_plus)
        self.b_xminus = self._make_button(ax_xminus, 'X -15°', self.rotate_x_minus)
        self.b_zminus = self._make_button(ax_zminus, 'Z -15°', self.rotate_z_minus)
        self.b_zplus = self._make_button(ax_zplus, '+15°', self.rotate_z_plus)

    @staticmethod
    def _make_button(ax, label, handler):
        """Create a Button on ``ax`` labelled ``label`` and register ``handler`` for clicks."""
        button = Button(ax, label)
        button.on_clicked(handler)
        return button

    def _rotate(self, axis, degrees):
        """Forward one rotation request to the callback."""
        self.callback(axis, degrees)

    def rotate_x_plus(self, event):
        self._rotate('x', 15)

    def rotate_x_minus(self, event):
        self._rotate('x', -15)

    def rotate_y_plus(self, event):
        self._rotate('y', 15)

    def rotate_y_minus(self, event):
        self._rotate('y', -15)

    def rotate_z_plus(self, event):
        self._rotate('z', 15)

    def rotate_z_minus(self, event):
        self._rotate('z', -15)
52 |
--------------------------------------------------------------------------------
/graphicInterface/window.py:
--------------------------------------------------------------------------------
1 | from graphicInterface.main_widget import *
2 | from graphicInterface.console import Logger
3 | from pointRegistration.registration_param import RegistrationParameters
4 | from PyQt5.QtGui import QIcon
5 | from graphicInterface.conf_param_window import EditConfWindow
6 |
7 |
class Window(QMainWindow):
    """Application main window: MainWidget in the centre plus a status bar.

    The status bar holds a transient status label on the left and a permanent
    ConfigLabel with the registration parameters on the right.
    """

    def __init__(self):
        super().__init__()
        self.setWindowIcon(QIcon('resources/icon.png'))
        # Logger is a singleton; this is a no-op if it already exists.
        Logger()
        self.statusLabel = QLabel("")
        self.initUI()

    def initUI(self):
        """Build the widgets, size and centre the window, then show it."""
        self.mainWidget = MainWidget(self)
        self.setCentralWidget(self.mainWidget)
        self.resize(1000, 600)
        self.center()

        status_bar = QStatusBar()
        self.setStatusBar(status_bar)
        status_bar.addWidget(self.statusLabel)
        status_bar.addPermanentWidget(ConfigLabel())

        self.setWindowTitle('Shape Registrator')
        self.setStatus("Ready.")
        self.show()

    def center(self):
        """Move the window so it is centred on the available desktop area."""
        frame = self.frameGeometry()
        frame.moveCenter(QDesktopWidget().availableGeometry().center())
        self.move(frame.topLeft())

    def setStatus(self, message):
        """Show ``message`` in the status bar label."""
        self.statusLabel.setText(message)

    def setStatusReady(self):
        """Reset the status bar label to "Ready"."""
        self.setStatus("Ready")
41 |
42 |
class ConfigLabel(QLabel):
    """Status-bar label summarising the registration parameters.

    Double-clicking it opens an EditConfWindow parented to the label.
    """

    def __init__(self):
        super(ConfigLabel, self).__init__()
        self.update_parameter_label()

    def mouseDoubleClickEvent(self, *args, **kwargs):
        """Open the configuration editor."""
        EditConfWindow(self).show()

    def update_parameter_label(self):
        """Refresh the text from the current RegistrationParameters."""
        summary = RegistrationParameters.to_string()
        self.setText("(Double click to edit) |" + summary)
56 |
--------------------------------------------------------------------------------
/pointRegistration/registration_param.py:
--------------------------------------------------------------------------------
1 | from configparser import ConfigParser
2 | from graphicInterface.console import Logger
3 |
4 |
class RegistrationParameters:
    """Process-wide access point to the CPD registration parameters.

    Values come from ``registration_parameters.conf`` (resolved against the
    current working directory, section ``[PARAMETERS]``). The file is read the
    first time the parameters are needed. Keys missing from the file, or a
    missing file/section, fall back to ``_DEFAULTS``. Changes can be persisted
    with :meth:`write_on_file`.
    """

    instance = None

    # Fallbacks mirroring the shipped registration_parameters.conf, stored as
    # strings so they go through the same parsing as values read from disk.
    _DEFAULTS = {'tolerance': '0.001', 'max_iterations': '100', 'sigma2': 'None', 'w': '0.0'}

    class __PrivateParams:  # Singleton
        """Holds the parsed parameters; the single instance lives in ``RegistrationParameters.instance``."""

        def __init__(self):
            c = ConfigParser()
            # ConfigParser.read silently skips a missing file; the section
            # check below keeps that case from becoming a KeyError.
            c.read('registration_parameters.conf')
            ps = c['PARAMETERS'] if c.has_section('PARAMETERS') else {}
            defaults = RegistrationParameters._DEFAULTS

            def raw(key):
                return ps.get(key, defaults[key])

            # Now read every key
            self.tolerance = float(raw('tolerance'))
            self.max_iterations = int(raw('max_iterations'))
            sigma2 = raw('sigma2')
            self.sigma2 = None if sigma2 == 'None' else float(sigma2)
            self.w = float(raw('w'))
            self.params = {'tolerance': self.tolerance,
                           'max_iterations': self.max_iterations,
                           'sigma2': self.sigma2,
                           'w': self.w}

        def get_param(self, key):
            """Return the value for ``key``, or None if it is unknown."""
            return self.params.get(key)

        def write_on_file(self):
            """Persist the current parameters to registration_parameters.conf."""
            # Build the whole configuration before opening (and truncating)
            # the file, so a formatting error cannot leave it empty.
            conf = ConfigParser()
            conf['PARAMETERS'] = {str(key): str(value) for key, value in self.params.items()}
            with open('registration_parameters.conf', 'w') as conf_file:
                conf.write(conf_file)
            Logger.addRow("Configuration file updated")

        def get_params(self):
            return self.params

        def set_params(self, key, value):
            self.params[key] = value

    def __init__(self):
        RegistrationParameters._ensure_instance()

    @staticmethod
    def _ensure_instance():
        """Return the singleton, creating it from the configuration file on first use."""
        if RegistrationParameters.instance is None:
            RegistrationParameters.instance = RegistrationParameters.__PrivateParams()
        return RegistrationParameters.instance

    @staticmethod
    def get_params():
        """ Returns a dictionary with all the registration parameters. """
        return RegistrationParameters._ensure_instance().get_params()

    @staticmethod
    def get_param(key):
        """Return a single parameter value, or None for an unknown key."""
        return RegistrationParameters._ensure_instance().get_param(key)

    @staticmethod
    def set_param(key, value):
        """Set ``key`` to ``value`` in memory (call write_on_file to persist)."""
        RegistrationParameters._ensure_instance().set_params(key, value)

    @staticmethod
    def write_on_file():
        """Write the current parameters to registration_parameters.conf."""
        RegistrationParameters._ensure_instance().write_on_file()

    @staticmethod
    def to_string():
        """Return ``"key : value | key : value | ..."`` for all parameters."""
        param = RegistrationParameters._ensure_instance().get_params()
        # join() avoids the trailing space left by trimming only 2 characters
        # of the 3-character " | " separator.
        return " | ".join("{} : {}".format(k, param[k]) for k in param)
75 |
76 |
--------------------------------------------------------------------------------
/graphicInterface/console.py:
--------------------------------------------------------------------------------
1 | from PyQt5.QtWidgets import QPlainTextEdit
2 | from PyQt5.QtGui import QTextCursor
3 | import datetime
4 | import time
5 | import os
6 |
7 |
class Logger:
    """Singleton wrapper around the application's read-only log console widget.

    ``Logger()`` creates the underlying QPlainTextEdit on its first call;
    later calls are no-ops. Every other method is static and is safe to call
    before the widget exists.
    """

    instance = None
    lock = False  # NOTE(review): not read anywhere in this file.

    class _PrivateLog(QPlainTextEdit):
        """Read-only, non-wrapping text widget holding the log lines."""

        def __init__(self):
            super().__init__()
            self.setReadOnly(True)
            self.setLineWrapMode(QPlainTextEdit.NoWrap)
            # NOTE(review): this attribute shadows QWidget.cursor() on the instance.
            self.cursor = QTextCursor(self.document())
            self.setTextCursor(self.cursor)

        def save_onfile(self):
            """Write the whole log to ``log-<h>-<min>_<y>-<m>-<d>.txt`` in the current directory."""
            now = datetime.datetime.now()
            filename = ("log-%d-%d_%d-%d-%d.txt" % (now.hour, now.minute, now.year, now.month, now.day))
            # ``with`` closes the file even if the write raises. An explicit
            # encoding avoids locale-dependent UnicodeEncodeError.
            with open(filename, "w", encoding="utf-8") as log_file:
                log_file.write(self.toPlainText())

    def __init__(self):
        if Logger.instance is None:
            Logger.instance = self._PrivateLog()

    @staticmethod
    def addRow(row):
        """Append ``row`` to the console widget (if created) and echo it to stdout."""
        if Logger.instance is not None:
            Logger.instance.insertPlainText(row + "\n")
            # NOTE(review): these sleeps block the calling thread for 0.2 s per
            # row. addRow is also invoked from worker threads (e.g.
            # BatchRegistrationThread.run), and Qt widgets should only be
            # touched from the GUI thread. A queued signal would be the robust
            # fix; the sleeps are kept so existing timing is unchanged.
            time.sleep(0.1)
            Logger.instance.moveCursor(QTextCursor.End)
            time.sleep(0.1)
        print(row)

    @staticmethod
    def setParent(parent):
        """Re-parent the console widget (no-op before Logger() was called)."""
        if Logger.instance is not None:
            Logger.instance.setParent(parent)

    @staticmethod
    def save_log():
        """Save the console contents to a timestamped text file (no-op before Logger())."""
        if Logger.instance is not None:
            Logger.instance.save_onfile()
51 |
52 |
class suppress_stdout_stderr(object):
    """
    A context manager for doing a "deep suppression" of stdout and stderr in
    Python, i.e. will suppress all print, even if the print originates in a
    compiled C/Fortran sub-function.
    This will not suppress raised exceptions, since exceptions are printed
    to stderr just before a script exits, and after the context manager has
    exited.

    File descriptors are acquired in ``__enter__`` and released in
    ``__exit__``. Creating an instance without entering it therefore leaks
    nothing, and the same instance can be used more than once.
    """

    def __init__(self):
        self.null_fds = []
        self.save_fds = []

    @staticmethod
    def _flush_std_streams():
        """Flush Python-level stdout/stderr buffers (they may be None under pythonw)."""
        import sys
        for stream in (sys.stdout, sys.stderr):
            if stream is not None:
                stream.flush()

    def __enter__(self):
        # Flush first, so text printed *before* the block is not swallowed.
        self._flush_std_streams()
        # Open a pair of null files and save the real stdout (1) / stderr (2).
        self.null_fds = [os.open(os.devnull, os.O_RDWR) for _ in range(2)]
        self.save_fds = [os.dup(1), os.dup(2)]
        # Assign the null pointers to stdout and stderr.
        os.dup2(self.null_fds[0], 1)
        os.dup2(self.null_fds[1], 2)
        return self

    def __exit__(self, *_):
        try:
            # Flush text written inside the block while it still goes to the
            # null device, so it does not leak out after restoration.
            self._flush_std_streams()
            # Re-assign the real stdout/stderr back to (1) and (2).
            os.dup2(self.save_fds[0], 1)
            os.dup2(self.save_fds[1], 2)
        finally:
            # Close the null files and the saved duplicates.
            for fd in self.null_fds + self.save_fds:
                os.close(fd)
            self.null_fds = []
            self.save_fds = []
--------------------------------------------------------------------------------
/graphicInterface/show_displacement.py:
--------------------------------------------------------------------------------
1 | from PyQt5.QtGui import QIcon
2 | from PyQt5.QtWidgets import QMainWindow, QLabel, QGridLayout, QWidget, QPushButton
3 | from graphicInterface.rotatable_figure import RotatableFigure
4 | from graphicInterface.upper_toolbar_controls import ControlPushButton
5 | from graphicInterface.file_dialogs import *
6 |
7 |
class DisplacementMapWindow(QMainWindow):
    """Window showing a displacement map in a RotatableFigure with save/load buttons.

    :param parent: owning widget.
    :param model: displacement-map model to show, or ``None`` to start empty.
        In that case "Save on file" stays disabled until a map is loaded.
    """

    def __init__(self, parent, model):
        super().__init__(parent=parent)
        self.setWindowIcon(QIcon('resources/icon.png'))
        self.statusLabel = QLabel("Ready")
        self.setWindowTitle("Displacement Map")
        self.model = model
        self.rot_figure = None
        self.initUI()

    def initUI(self):
        """Lay out the figure above the toolbar, size the window, draw and show it."""
        central_widget = QWidget(self)
        # The grid is created without a parent and installed on the central
        # widget only. QMainWindow already owns a layout, so the original
        # ``QGridLayout(self)`` / ``self.setLayout(...)`` calls were rejected
        # by Qt with an "already has a layout" warning.
        grid_central = QGridLayout()
        central_widget.setLayout(grid_central)
        self.setCentralWidget(central_widget)
        self.rot_figure = RotatableFigure(parent=self, model=self.model, title="Displacement Map")
        self.toolbar = LowerToolbar(self)
        # The figure spans rows 0-18; the toolbar sits on row 20.
        grid_central.addWidget(self.rot_figure, 0, 0, 19, 1)
        grid_central.addWidget(self.toolbar, 20, 0, 1, 1)
        if self.model is not None:
            self.toolbar.save_button.setEnabled(True)
        self.resize(600, 600)
        self.rot_figure.draw_data(clear=True)
        self.show()

    def save_displacement_map(self):
        """Ask for a .pickle / .h5py destination and save the current model there."""
        if self.model is None:
            return
        filters = "Serialized Python Obj (*.pickle);;H5Py Compatible file (*.h5py)"
        filename = save_file_dialog(self, filters)
        if filename is not None:
            self.model.save_model(filename)

    def load_displacement_map(self):
        """Ask for a .h5py file, load it as a DisplacementMap and display it."""
        filters = "H5Py Compatible file (*.h5py)"
        filename = load_file_dialog(self, filters)
        if filename is not None:
            from pointRegistration.displacementMap import DisplacementMap
            self.model = DisplacementMap.load_model(filename)
            self.rot_figure.load_model(self.model)
            self.rot_figure.draw_data()
            self.toolbar.save_button.setEnabled(True)
50 |
51 |
class LowerToolbar(QWidget):
    """Button row under the displacement map ("Save on file" / "Load from file..").

    :param parent: window providing the ``save_displacement_map`` and
        ``load_displacement_map`` callbacks (a DisplacementMapWindow).
    """

    def __init__(self, parent):
        super().__init__(parent)
        # Keep the owner under a private name. Assigning ``self.parent`` and
        # ``self.layout`` shadowed QWidget.parent() and QWidget.layout() on
        # this instance. The Qt parent can also change once the toolbar is
        # added to a layout, so the owner is stored explicitly.
        self._owner = parent
        self.initUI()

    def initUI(self):
        """Create the (initially disabled) save button and the load button."""
        grid = QGridLayout()
        self.setLayout(grid)
        self.save_button = ControlPushButton("Save on file", self._owner.save_displacement_map, False)
        load_button = ControlPushButton("Load from file..", self._owner.load_displacement_map)
        grid.addWidget(self.save_button, 0, 1, 1, 1)
        grid.addWidget(load_button, 0, 3, 1, 1)
--------------------------------------------------------------------------------
/graphicInterface/plot_interactive_figure.py:
--------------------------------------------------------------------------------
1 | import matplotlib.patches as patches
2 | import numpy as np
3 | from matplotlib.widgets import RectangleSelector
4 | from scipy import spatial
5 |
6 | from graphicInterface.console import suppress_stdout_stderr
7 | from graphicInterface.plot_figure import PlotFigure
8 |
9 |
class PlotInteractiveFigure(PlotFigure):
    """PlotFigure that lets the user pick points with the mouse.

    Dragging a rectangle (left or right button) adds the model points inside
    it as registration points. :meth:`select_nearest_pixel` toggles the
    colour of the landmark nearest to a given (x, y) position.
    """

    def __init__(self, parent, model=None, landmarks=True, title=None):
        super().__init__(parent, model, landmarks, title)
        self.myTree = None  # KD-tree over landmark (x, y) coordinates, built lazily
        selector_kwargs = dict(useblit=True, button=[1, 3], minspanx=5, minspany=5,
                               spancoords='pixels', interactive=True)
        try:
            self.RS = RectangleSelector(self.ax, self.square_select_callback, drawtype='box',
                                        **selector_kwargs)
        except TypeError:
            # ``drawtype`` was deprecated in matplotlib 3.5 and removed in 3.7,
            # where a box is the only shape. requirements.txt does not pin
            # matplotlib, so fresh installs may get the newer signature.
            self.RS = RectangleSelector(self.ax, self.square_select_callback, **selector_kwargs)

    def load_model(self, model):
        """Load a new model and invalidate the cached landmark KD-tree."""
        self.myTree = None
        super().load_model(model)

    def select_nearest_pixel(self, x_coord, y_coord):
        """Toggle the colour ('r' <-> 'y') of the landmark nearest to (x, y) if it is within 5 units."""
        if self.myTree is None:
            print("Calculating 2DTree...")
            # Build the KD-tree on the (x, y) columns of the model's landmarks.
            self.myTree = spatial.cKDTree(self.model.landmarks_3D[:, 0:2])

        dist, index = self.myTree.query([[x_coord, y_coord]], k=1)
        if dist < 5:
            self.landmarks_colors[index[0]] = "y" if self.landmarks_colors[index[0]] == "r" else "r"
            self.draw()
            if self.parent() is not None:
                self.parent().landmark_selected(self.landmarks_colors)

    def square_select_callback(self, eclick, erelease):
        """RectangleSelector callback: outline the dragged rectangle and select the points inside it."""
        # eclick and erelease are the press and release events
        x1, y1 = eclick.xdata, eclick.ydata
        x2, y2 = erelease.xdata, erelease.ydata
        rect = patches.Rectangle((min(x1, x2), min(y1, y2)), np.abs(x1 - x2), np.abs(y1 - y2),
                                 linewidth=1, edgecolor='r', facecolor='none', fill=True)
        self.get_ax().add_patch(rect)
        self.select_area(min(x1, x2), min(y1, y2), np.abs(x1 - x2), np.abs(y1 - y2))
        self.draw()

    def there_are_points_highlighted(self):
        """Return whether the model currently has registration points selected."""
        return self.model.has_registration_points()

    def select_area(self, x_coord, y_coord, width, height):
        """Highlight the model points whose (x, y) lie inside the given axis-aligned box."""
        with suppress_stdout_stderr():
            x_data = self.model.points[:, 0]
            y_data = self.model.points[:, 1]

            x_ind = np.where((x_coord <= x_data) & (x_data <= x_coord + width))
            y_ind = np.where((y_coord <= y_data) & (y_data <= y_coord + height))

            ind = np.intersect1d(np.array(x_ind), np.array(y_ind), assume_unique=True)
            self.highlight_data(ind)

    def highlight_data(self, indices):
        """Add ``indices`` as registration points, or reset them when ``indices[0] == -1``.

        An empty selection (e.g. a rectangle that contains no points) is
        ignored. Previously ``indices[0]`` raised IndexError inside the GUI
        event callback.
        """
        if len(indices) == 0:
            return
        if indices[0] != -1:
            self.model.add_registration_points(indices)
        else:
            self.model.init_registration_points()
        self.draw_data()
65 |
--------------------------------------------------------------------------------
/pointRegistration/batchRegistration.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from threading import Thread
3 | from graphicInterface.console import Logger
4 | from pycpd.rigid_registration import rigid_registration
5 | from pointRegistration.model import Model
6 | from functools import partial
7 | import time
8 | import datetime
9 | import os
10 |
11 |
class BatchRegistrationThread(Thread):
    """Worker thread that rigidly registers one source model against a list of targets.

    For each target path, the model is loaded together with the ``.bnd``
    file of the same stem. The target points are decimated to ``percentage``
    percent, CPD rigid registration runs, and the result is saved under
    ``results/``. ``final_callback`` is always called when the thread
    finishes. Setting ``should_stop = True`` aborts at the next CPD
    iteration or before the next target.
    """

    def __init__(self, source, target_list, percentage, final_callback):
        Thread.__init__(self)
        self.source_model = source
        self.target_list = target_list
        self.perc = percentage
        self.finalCallback = final_callback
        self.should_stop = False
        Logger.addRow("Starting Batch Thread..")

    @staticmethod
    def _result_path(now, index, target_path):
        """Return a per-target result path under ``results/``, creating the directory if needed.

        The batch index and target stem make names unique within a batch. The
        old minute-resolution name made every target finished in the same
        minute overwrite the previous result.
        """
        stem = os.path.splitext(os.path.basename(target_path))[0]
        save_filename = "RIGID_REG_{0}_{1}_{2}_{3}_{4}_{5}_{6}.mat".format(
            now.day, now.month, now.year, now.hour, now.minute, index, stem)
        os.makedirs("results", exist_ok=True)
        return os.path.join("results", save_filename)

    def run(self):
        """Register the source against every target in turn; always calls ``final_callback``."""
        source = self.source_model.get_registration_points()
        total = len(self.target_list)

        try:
            for index, targ in enumerate(self.target_list, start=1):
                if self.should_stop:
                    Logger.addRow("Registration has been stopped")
                    break
                Logger.addRow("Batch %d of %d:" % (index, total))
                # Landmarks live next to the model: same stem, ".bnd" extension.
                # (Slicing off the last 3 characters only worked for 3-letter
                # extensions.)
                path_bnd = os.path.splitext(targ)[0] + ".bnd"
                t = Model(targ, path_bnd)
                target = Model.decimate(t.points, self.perc)
                Logger.addRow("Points decimated.")
                if t.landmarks is not None:
                    # Landmarks are appended after the decimated points, so they
                    # occupy the last rows of the registered array below.
                    target = np.concatenate((target, t.landmarks), axis=0)
                reg = rigid_registration(X=source, Y=target)
                meth = "CPD Rigid"

                Logger.addRow("Starting registration with " + meth + ", using " + str(self.perc) + "% of points.")
                model = Model()

                reg_time = time.time()

                # To visualise progress use instead:
                # data, reg_param = reg.register(partial(self.drawCallback, ax=None))
                data, reg_param = reg.register(partial(self.log, ax=None))

                model.set_points(reg.transform_point_cloud(self.source_model.model_data))

                model.registration_params = reg_param
                if t.landmarks is not None:
                    model.set_landmarks(data[target.shape[0] - t.landmarks.shape[0]: data.shape[0]])
                model.filename = t.filename
                model.compute_displacement_map(target, 3)
                save_path = self._result_path(datetime.datetime.now(), index, targ)
                model.save_model(save_path)
                model.shoot_displacement_map(save_path)
                Logger.addRow("Took " + str(round(time.time() - reg_time, 3)) + "s.")

        except Exception as ex:
            # Thread boundary: report the failure rather than letting the thread die silently.
            Logger.addRow(str(ex))
            print(ex)
        finally:
            self.finalCallback()

    def log(self, iteration, error, X, Y, ax):
        """CPD per-iteration callback: log progress and abort if a stop was requested."""
        sss = "Iteration #" + str(iteration) + " error: " + str(error)
        Logger.addRow(sss)
        if self.should_stop:
            raise Exception("Registration has been stopped")
74 |
75 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Shape Registrator
2 |
3 | *A GUI tool for aligning 3D face models.*
4 |
5 | 
6 |
7 | Shape Registrator works with 3D models and supports **.mat** (Matlab), **.wrl** (VRML) and **.off** (Object File Format) file formats.
8 |
9 | ## 2019 Update:
10 | - Displacement map can now be saved as a Pickle file.
11 | - Target model can be rotated, in order to provide a good starting point for CPD algorithm.
12 | - The source model is now shown in the target section while it is being registered.
13 |
14 | ### Installation
15 |
16 | 1. Clone repo.
17 | 2. `pip install -r requirements.txt`
18 | 3. Run main.py
19 |
20 | ### Usage
21 |
22 | 1. Load two models; you can choose between .mat, .wrl + .bnd and .off formats.
23 | 2. Select the amount of points to be used and the CPD version: rigid, affine or deformable.
24 | 3. Click "Registrate" and wait for the result.
25 |
26 | In the log box you can see the registration error and some messages.
27 |
28 | ### About file formats:
29 |
30 | **.mat** files must contain a '3ddata' field storing the 3D point coordinates (an Nx3 matrix) and a 'landmarks3d' field storing the landmark coordinates (a Kx3 matrix, one row per landmark).
31 |
32 | **.wrl** files should be paired with a .bnd file storing the landmark coordinates. If a .png image with the same name is available, it will be used as the background of the plot tool.
33 |
34 | **.off** files must contain 3D points coordinates, faces are not mandatory and will be ignored.
35 |
36 | ## Insights
37 |
38 | At the heart of the application is the *Coherent Point Drift* algorithm [1], which can register two point sets regardless of the nature of the transformation.
39 | CPD provides three kinds of transformation: **rigid**, **affine** and **deformable**, with increasing degrees of freedom.
40 |
41 | You can control the number of points involved in the registration process: keep all of them (100%) or reduce them down to 30% of the original number. Points are sampled *uniformly at random*.
42 |
43 | Reducing the number of points makes the registration faster. In the affine and deformable cases, however, the resulting transformation cannot be applied to the entire point set. In the rigid case the estimated rotation and translation can still be applied to the whole set, so this option can be used safely.
44 |
45 | ## Packages required
46 | *A requirements.txt is provided, in order to be used with pip*
47 |
48 | Package | Version
49 | --------|--------
50 | numpy | >=1.15.0
51 | scipy | >=1.3
52 | matplotlib| ~2.2.3
53 | PyQt5| ~5.11.2
54 | h5py| ~2.8
55 | pycpd| ~1.0.3
56 |
57 | The application has so far been tested with Python 3.7 x64 on Windows 10 and Ubuntu 16.04 LTS. It should work on macOS as well, since PyQt5 is cross-platform.
58 |
59 | ## References
60 | [1] Andriy Myronenko and Xubo Song, "*Point Set Registration: Coherent Point Drift*", IEEE Trans. on Pattern Analysis and Machine Intelligence, vol. 32, issue 12, pp. 2262-2275, December 2010 (arXiv preprint: 15 May 2009) {[link](https://arxiv.org/pdf/0905.2635.pdf)}
61 |
62 | [2] Alessandro Soci and Gabriele Barlacchi, [AlessandroSoci/3DMM-Facial-Expression-from-Webcam](https://github.com/AlessandroSoci/3DMM-Facial-Expression-from-Webcam)
63 |
64 | [3] Alessandro Sestini and Francesco Lombardi, [fralomba/Facial-Expression-Prediction](https://github.com/fralomba/Facial-Expression-Prediction)
65 |
--------------------------------------------------------------------------------
/graphicInterface/conf_param_window.py:
--------------------------------------------------------------------------------
1 | from PyQt5.QtWidgets import QLabel, QMainWindow, QWidget, QGridLayout, QGroupBox, QLineEdit
2 | from PyQt5.QtGui import QIntValidator, QDoubleValidator, QValidator
3 | from graphicInterface.upper_toolbar_controls import ControlPushButton
4 | from pointRegistration.registration_param import RegistrationParameters
5 |
6 |
class EditConfWindow(QMainWindow):
    """Window that hosts the registration-parameter editor; it shows itself on creation."""

    def __init__(self, parent):
        super(EditConfWindow, self).__init__(parent=parent)
        self.setWindowTitle("Edit configuration")
        # The editor is parented to this window so its "Close" button can close it.
        editor = ConfEditCentralWidget(self)
        self.setCentralWidget(editor)
        self.resize(400, 200)
        self.show()
15 |
16 |
class ConfEditCentralWidget(QWidget):
    """
    Form used to edit the CPD registration parameters kept in RegistrationParameters.

    Each field is a ValidatedLineEdit bound to one parameter key. "Apply" stores
    every valid value, writes the configuration file and asks the window that
    opened the editor (the grand-parent of this widget) to refresh its parameter
    label. "Close" closes the enclosing window.
    """

    def __init__(self, parent):
        super(ConfEditCentralWidget, self).__init__(parent=parent)
        general_layout = QGridLayout()
        self.setLayout(general_layout)
        group = QGroupBox("Configuration parameters")
        general_layout.addWidget(group, 0, 0, 1, 6)

        general_layout.addWidget(ControlPushButton("Apply", self.apply_changes), 1, 4, 1, 1)
        general_layout.addWidget(ControlPushButton("Close", self.parent().close), 1, 5, 1, 1)

        layout_conf = QGridLayout()
        group.setLayout(layout_conf)

        layout_conf.addWidget(QLabel("Tolerance"), 0, 0)
        layout_conf.addWidget(QLabel("Max iterations for Local Registration"), 1, 0)
        layout_conf.addWidget(QLabel("Sigma2"), 2, 0)
        weight_label = QLabel("Weight ?")
        # The weight validator below uses a slightly negative lower bound and a strict
        # upper bound of 1, so it accepts [0, 1) - the tooltip must say the same.
        weight_label.setToolTip("Uniform distribution component in GMM for Expectation-Maximization step. "
                                "Must be in interval [0,1)")
        layout_conf.addWidget(weight_label, 3, 0)

        self.tolerance_input = ValidatedLineEdit('tolerance', DoubleNoneValidator(0.0000001, 1000, 7))
        layout_conf.addWidget(self.tolerance_input, 0, 1)

        self.max_iterations_input = ValidatedLineEdit('max_iterations', IntValidator(1, 999))
        layout_conf.addWidget(self.max_iterations_input, 1, 1)

        # sigma2 may also be "None" (nullable validator).
        self.sigma_input = ValidatedLineEdit('sigma2', DoubleNoneValidator(0, 1, 6, nullable=True))
        layout_conf.addWidget(self.sigma_input, 2, 1)

        self.weight_input = ValidatedLineEdit('w', DoubleNoneValidator(-0.0000001, 1, 6))
        layout_conf.addWidget(self.weight_input, 3, 1)

        self.inputs = [self.tolerance_input, self.max_iterations_input, self.sigma_input, self.weight_input]

    def apply_changes(self):
        """
        Store every valid field, save the configuration and refresh the opener's label.

        Fields whose text is invalid are reported on stdout and keep their previous value.
        """
        for value_input in self.inputs:
            if value_input.is_valid():
                RegistrationParameters.set_param(value_input.key, value_input.get_value())
            else:
                print(f"Parameter for {value_input.key} is invalid, using previous value.")
        RegistrationParameters.write_on_file()
        self.parent().parent().update_parameter_label()
62 |
63 |
class ValidatedLineEdit(QLineEdit):
    """
    Line edit pre-filled with a registration parameter and checked on demand.

    The validator is kept as a plain attribute (it is never installed with
    setValidator), so the user can type anything; validity is only evaluated
    when is_valid() is called.
    """

    def __init__(self, key, validator=None):
        super(ValidatedLineEdit, self).__init__()
        self.key = key
        initial_text = str(RegistrationParameters.get_param(key))
        self.setText(initial_text)
        self.validator = validator

    def get_value(self):
        """Return None if the text mentions "none" (any case), else the text converted by validator.type."""
        raw = self.text()
        if "none" in raw.lower():
            return None
        convert = self.validator.type
        return convert(raw)

    def is_valid(self):
        """Return True unless the validator reports the current text as Invalid."""
        return self.validator.validate(self.text()) != QValidator.Invalid
85 |
86 |
class IntValidator(QValidator):
    """
    Accepts integer strings strictly between ``bottom`` and ``top`` (both excluded).

    validate() returns a bare QValidator state rather than Qt's (state, text, pos)
    tuple; in this file it is called directly by ValidatedLineEdit.is_valid().
    """

    def __init__(self, bottom, top):
        super(IntValidator, self).__init__()
        self.bottom, self.top = bottom, top
        self.type = int

    def validate(self, p_str, p_int=None):
        try:
            number = int(p_str)
        except ValueError:
            return QValidator.Invalid
        if self.bottom < number < self.top:
            return QValidator.Acceptable
        return QValidator.Invalid
103 |
104 |
class DoubleNoneValidator(QValidator):
    """
    Accepts float strings strictly between ``bottom`` and ``top`` and at most
    ``digits + 2`` characters long.

    When ``nullable`` is True, any text containing "none" (case-insensitive) is
    accepted too. validate() returns a bare QValidator state and is called
    directly by ValidatedLineEdit.is_valid().
    """

    def __init__(self, bottom, top, digits, nullable=False):
        super(DoubleNoneValidator, self).__init__()
        self.bottom, self.top = bottom, top
        self.digits = digits
        self.type = float
        self.nullable = nullable

    def validate(self, p_str, p_int=None):
        if self.nullable and "none" in p_str.lower():
            return QValidator.Acceptable
        try:
            number = float(p_str)
        except ValueError:
            return QValidator.Invalid
        in_range = self.bottom < number < self.top
        short_enough = len(p_str) <= self.digits + 2
        return QValidator.Acceptable if in_range and short_enough else QValidator.Invalid
125 |
--------------------------------------------------------------------------------
/pointRegistration/registration.py:
--------------------------------------------------------------------------------
1 | from pointRegistration.registration_param import RegistrationParameters
2 | from pointRegistration.model import Model
3 | from graphicInterface.console import Logger
4 | from threading import Thread
5 | from functools import partial
6 | from pycpd import *
7 | import numpy as np
8 | import time
9 |
10 |
class Registration(Thread):
    """
    Worker thread running a Coherent Point Drift registration of a source model onto a target model.

    The moving point set is the source model's registration points; the fixed set is
    the target's points (randomly decimated to ``perc`` percent) plus its landmarks.
    Whatever the outcome (success, stop request or error), ``callback`` is invoked
    exactly once with the resulting Model.
    """

    def __init__(self, method, source_model, target_model, perc, callback, iteration_callback=None):
        """
        :param method: 1 = CPD rigid, 2 = CPD affine, 3 = CPD deformable
        :param source_model: Model whose registration points are moved onto the target
        :param target_model: Model used as the fixed point set
        :param perc: percentage of target points to keep (passed to Model.decimate)
        :param callback: called with the resulting Model when the thread finishes
        :param iteration_callback: optional per-iteration callback; it receives the keyword
            arguments produced by the registration (plus ``ax=None``). When None, each
            iteration is written to the log instead.
        """
        Thread.__init__(self)
        self.method = method
        self.source_model = source_model
        self.target_model = target_model
        self.percentage = perc
        self.callback = callback
        self.should_stop = False
        self.iteration_callback = iteration_callback
        self.registration_method = None

    def run(self):
        """Prepare the point sets, run the selected CPD variant and hand the result to the callback."""
        source = self.source_model.get_registration_points()
        target = Model.decimate(self.target_model.points, self.percentage)
        Logger.addRow("Points decimated.")
        if self.target_model.landmarks is not None:
            target = np.concatenate((target, self.target_model.landmarks), axis=0)
            Logger.addRow("Landmarks added.")

        ps = RegistrationParameters().get_params()

        # method id -> (pycpd registration class, name shown in the log)
        available_methods = {
            1: (rigid_registration, "CPD Rigid"),
            2: (affine_registration, "CPD Affine"),
            3: (deformable_registration, "CPD Deformable"),
        }
        if self.method not in available_methods:
            # Without this guard an unknown id crashed the thread before the callback
            # was reached, leaving the caller waiting for a result forever.
            Logger.addRow("Err: unknown registration method " + str(self.method))
            self.callback(self.target_model)
            return
        registration_class, method_name = available_methods[self.method]

        Logger.addRow("Starting registration with " + method_name + ", using " + str(self.percentage) + "% of points.")
        model = Model()
        reg_time = time.time()

        try:
            # Built inside the try so that parameters the registration class rejects
            # are reported in the log and still end with the callback.
            self.registration_method = registration_class(
                X=target, Y=source, sigma2=ps['sigma2'], max_iterations=ps['max_iterations'],
                tolerance=ps['tolerance'], w=ps['w'])
            try:
                self.registration_method.register(partial(self.interruptable_wrapper, ax=None))
            except InterruptedException as ex:
                # A user stop is not a failure: keep the transformation reached so far.
                Logger.addRow(str(ex))
            model = self.aligned_model(model)
        except Exception as ex:
            Logger.addRow("Err: " + str(ex))
            model = self.target_model  # Fail: back with the original target model
        finally:
            Logger.addRow("Took " + str(round(time.time() - reg_time, 3)) + "s.")
            self.callback(model)

    def aligned_model(self, model):
        """
        Apply the transformation found during registration to the whole source model.

        :param model: an initialized, empty Model object to fill
        :return: ``model`` holding the transformed source points and landmarks, the
            target's filename and the registration parameters of the registration object.
        """
        model.registration_params = self.registration_method.get_registration_parameters()
        points = self.registration_method.transform_point_cloud(self.source_model.points)
        if self.source_model.landmarks is not None:
            landmarks = self.registration_method.transform_point_cloud(self.source_model.landmarks)
            model.set_landmarks(landmarks)

        model.set_points(points)
        model.filename = self.target_model.filename

        # Displacement map computation is currently disabled:
        # model.compute_displacement_map(self.target_model, 3)
        return model

    def stop(self):
        """Request the registration to stop at the next iteration."""
        self.should_stop = True

    def log(self, iteration, error, **kwargs):
        """
        Default per-iteration callback: write the iteration number and error to the log.

        Any other keyword argument passed by the registration loop (point sets, the
        ``ax`` bound in run()) is accepted and ignored; previously they raised TypeError
        and made every run without an iteration_callback fail.
        """
        row = "Iteration #" + str(iteration) + " error: " + str(error)
        Logger.addRow(row)

    def interruptable_wrapper(self, **kwargs):
        """Per-iteration hook: abort if stop() was called, otherwise forward to the active callback."""
        if self.should_stop:
            raise InterruptedException("Registration has been stopped")

        if self.iteration_callback is None:
            self.log(**kwargs)
        else:
            self.iteration_callback(**kwargs)
102 |
103 |
class InterruptedException(Exception):
    """Raised from the per-iteration hook to abort a registration after stop() was requested."""
106 |
--------------------------------------------------------------------------------
/graphicInterface/upper_toolbar.py:
--------------------------------------------------------------------------------
1 | from graphicInterface.file_dialogs import *
2 | from PyQt5.Qt import *
3 | from graphicInterface.console import Logger
4 | from graphicInterface.upper_toolbar_controls import *
5 |
6 |
class UpperToolbar(QWidget):
    """
    Toolbar shown above the plots: registration controls, model buttons and the log.

    Button handlers forward to the owning widget (``self.parent``), which must provide
    registrate, registrate_batch, stop_registration_thread, restore, load_source,
    load_target, save_target, show_displacement_map and savelog_onfile.
    """

    def __init__(self, parent):
        super().__init__(parent)
        # NOTE: these attributes shadow QWidget.parent() and QWidget.layout(); they are
        # kept because the handlers below (and possibly other code) rely on them.
        self.parent = parent
        self.layout = QGridLayout()
        self.setLayout(self.layout)

        # "Registration" group
        registration_group = QGroupBox("Registration")
        registration_group_layout = QGridLayout()
        registration_group.setLayout(registration_group_layout)
        label = QLabel("Registration method:")
        label.setAlignment(Qt.AlignCenter)
        registration_group_layout.addWidget(label, 0, 0)

        self.registration_method_combobox = RegistrationMethodsCombobox()
        registration_group_layout.addWidget(self.registration_method_combobox, 0, 1)

        self.start_registration_button = ControlPushButton("Start Registration", self.registrate, False)
        registration_group_layout.addWidget(self.start_registration_button, 0, 2)

        self.stop_registration_button = ControlPushButton("Stop", self.stop_registration, False)
        registration_group_layout.addWidget(self.stop_registration_button, 1, 2)

        point_percentage_label = QLabel("Target's points used:")
        point_percentage_label.setAlignment(Qt.AlignCenter)
        registration_group_layout.addWidget(point_percentage_label, 1, 0)

        self.point_percentage_combobox = PercentageComboBox()
        registration_group_layout.addWidget(self.point_percentage_combobox, 1, 1)

        # "Model" group
        model_group = QGroupBox("Models")
        model_group_layout = QGridLayout()
        model_group.setLayout(model_group_layout)

        model_group_layout.addWidget(ControlPushButton("Load Source", self.load_source, True), 0, 0)
        model_group_layout.addWidget(ControlPushButton("Restore", self.restore, True), 1, 0)
        model_group_layout.addWidget(ControlPushButton("Load Target", self.load_target, True), 0, 1)

        self.save_target_btn = ControlPushButton("Save Target", self.save_target, False)
        model_group_layout.addWidget(self.save_target_btn, 1, 1)

        model_group_layout.addWidget(ControlPushButton("Batch registration", self.batch_reg, False), 0, 2)

        self.show_displacement_btn = ControlPushButton("Show Displacement Map", self.show_displacement, True)
        model_group_layout.addWidget(self.show_displacement_btn, 1, 2)

        # "Log" group
        logger_group = QGroupBox("Log")
        logger_group_layout = QGridLayout()
        logger_group.setLayout(logger_group_layout)

        save_log_btn = QPushButton("Save log on file")
        save_log_btn.clicked.connect(self.savelog_onfile)
        logger_group_layout.addWidget(Logger.instance)
        logger_group_layout.addWidget(save_log_btn)

        self.layout.addWidget(registration_group, 0, 0)
        self.layout.addWidget(model_group, 0, 1)
        self.layout.addWidget(logger_group, 0, 2)

        self.layout.setColumnStretch(0, 2)
        self.layout.setColumnStretch(1, 1)
        self.layout.setColumnStretch(2, 4)

    @staticmethod
    def _report_error(ex):
        """Show an exception both on stdout and in the GUI log."""
        print(ex)
        Logger.addRow("Err: " + str(ex))

    def registrate(self):
        """Start a registration with the selected method and percentage; toggle Start/Stop on success."""
        method = self.registration_method_combobox.currentData()
        percent = self.point_percentage_combobox.currentData()
        try:
            self.parent.registrate(method, percent)
        except Exception as ex:
            # Previously only printed to stdout, invisible to GUI users.
            self._report_error(ex)
            return
        self.start_registration_button.setEnabled(False)
        self.stop_registration_button.setEnabled(True)

    def show_displacement(self):
        self.parent.show_displacement_map()

    def savelog_onfile(self):
        self.parent.savelog_onfile()

    def stop_registration(self):
        """Ask the owner to stop the running registration and disable the Stop button."""
        self.parent.stop_registration_thread()
        self.stop_registration_button.setEnabled(False)

    def restore(self):
        self.parent.restore()

    @pyqtSlot()
    def batch_reg(self):
        """Register several files in a batch with the selected method and percentage."""
        # NOTE(review): UpperToolbar defines no load_file method, so this line raises
        # AttributeError if the (initially disabled) button is ever enabled; confirm
        # which file dialog helper was intended.
        file_names = self.load_file(multiple_files=True)
        if file_names:
            method = self.registration_method_combobox.currentData()
            percent = self.point_percentage_combobox.currentData()
            try:
                self.parent.registrate_batch(method, percent, file_names)
            except Exception as ex:
                self._report_error(ex)
                return
            self.start_registration_button.setEnabled(False)
            self.stop_registration_button.setEnabled(True)

    @pyqtSlot()
    def load_target(self):
        self.parent.load_target()

    @pyqtSlot()
    def load_source(self):
        self.parent.load_source()

    @pyqtSlot()
    def save_target(self):
        self.parent.save_target()
121 |
--------------------------------------------------------------------------------
/graphicInterface/plot_figure.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PyQt5.QtWidgets import QSizePolicy
3 | from matplotlib import pyplot
4 | from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
5 |
6 | from graphicInterface.console import suppress_stdout_stderr
7 |
8 |
class PlotFigure(FigureCanvas):
    """
    Qt-embeddable Matplotlib canvas that plots a Model projected on the XY plane.

    Points, landmarks and the model's registration points are drawn with the colours
    stored on the model; an optional background image (``bgImage``) is drawn behind
    them. ``update_plot_callback`` redraws the two point sets at each registration
    iteration.
    """

    def __init__(self, parent, model=None, landmarks=True, title=None):
        self.fig, self.ax = pyplot.subplots()
        self.title = title
        FigureCanvas.__init__(self, self.fig)
        FigureCanvas.setSizePolicy(self, QSizePolicy.Expanding, QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)
        self.setParent(parent)
        self.draw_landmarks = landmarks
        self.bgImage = None
        self.model = model
        self.drawDisplacement = False
        self.landmarks_colors = None
        self.data_colors = None
        self.scale_width = None
        self.scale_height = None
        self.legend_handlers = []
        self.legend_has_been_drawn = True

        if model is not None:
            self.load_model(model)

    def load_model(self, model):
        """Set the model to plot and reset the displacement flag (nothing is redrawn here)."""
        self.model = model
        self.drawDisplacement = False

    def load_data(self, data, color, marker='o', points_size=0.5, label=""):
        """
        Plot the first two columns of ``data`` as markers.

        Does nothing when ``data`` is None or has no rows. A non-empty ``label`` gets the
        row count appended and the resulting line handle is stored in ``legend_handlers``.
        """
        if data is None or data.shape[0] == 0:
            return
        if label:
            label += f" ({data.shape[0]})"
        # Z-based size scale; it is stored on the instance but not read in this class.
        max_v = np.max(data[:, 2])
        min_v = np.min(data[:, 2])
        sizes = np.copy(data[:, 2])
        self.sizes = ((sizes + np.abs(min_v)) / np.abs(max_v)) + 0.5
        legend_handler, = self.ax.plot(data[:, 0], data[:, 1], c=color, marker=marker, linestyle='None',
                                       markersize=points_size, label=label)
        if label:
            self.legend_handlers.append(legend_handler)

    def load_displacement(self):
        """Plot the first two columns of the model's displacement map as a line."""
        self.ax.plot(self.model.displacement_map[:, 0], self.model.displacement_map[:, 1])

    def load_image(self):
        """
        Draw ``bgImage`` (if set) behind the data, with an extent derived from ``scale_height``.

        NOTE(review): scale_height is None unless set elsewhere, which would make this
        fail whenever bgImage is set; confirm both are always set together.
        """
        if self.bgImage is not None:
            img = pyplot.imread(self.bgImage)
            # extent order: left, right, bottom, top
            self.ax.imshow(img, extent=[-self.scale_height * 1.05, self.scale_height * 1.03, -self.scale_height * 1.03,
                                        self.scale_height * 1.05])

    def show_displacement(self):
        """Enable displacement drawing if the model has a displacement map, and hide landmarks."""
        self.drawDisplacement = self.model.displacement_map is not None
        self.draw_landmarks = False

    def clear(self):
        self.ax.cla()

    def draw_data(self, clear=False):
        """
        (Re)draw the title, the model's points, landmarks, registration points and
        any "missed" points/landmarks, then the background image, and refresh the canvas.

        :param clear: clear the axes first
        """
        if clear:
            self.ax.cla()
        if self.title is not None:
            self.ax.set_title(self.title)
        if self.model is not None:
            self.ax.autoscale()
            self.ax.set_aspect('equal')
            self.ax.set_xlabel('X axis')
            self.ax.set_ylabel('Y axis')

            self.load_data(self.model.points, self.model.points_color, label="Points")
            if self.model.landmarks is not None:
                self.load_data(self.model.landmarks, self.model.landmarks_color, points_size=5, label="Landmarks")

            # Models that never initialised their selection keep registration_points = None;
            # indexing with None would add an axis and plot spurious points.
            registration_points = getattr(self.model, "registration_points", None)
            if registration_points is not None:
                try:
                    self.load_data(self.model.points[registration_points], 'y', label="Registration Points")
                except (IndexError, TypeError):
                    # Stale or non-integer indices: skip the overlay rather than fail the redraw.
                    pass

            # Optional attributes that only some model objects provide; load_data ignores None.
            self.load_data(getattr(self.model, "missed_points", None),
                           getattr(self.model, "missed_points_color", None), marker='v',
                           label="Missed Points")
            # points_size, not "size": load_data has no "size" parameter, and passing it made
            # this call always fail silently.
            self.load_data(getattr(self.model, "missed_landmarks", None),
                           getattr(self.model, "missed_landmarks_color", None), points_size=5, marker='v',
                           label="Missed Landmarks")
            # Legend drawing is currently disabled; handles are still collected in legend_handlers.
        self.load_image()
        self.draw()

    def draw(self, clear=False):
        """Render the canvas (optionally clearing the axes first) and process pending GUI events."""
        if clear is True:
            self.clear()
        super().draw()
        self.flush_events()

    def restore_model(self):
        """Clear the axes and redraw the current model."""
        self.ax.cla()
        self.draw_data()

    def landmarks(self, l):
        self.draw_landmarks = l

    def set_landmarks_colors(self, colors):
        """Change the model's landmark colour(s) and redraw."""
        self.model.landmarks_color = colors
        self.draw_data()

    def set_data_colors(self, colors):
        """Change the model's point colour(s) and redraw."""
        self.model.points_color = colors
        self.draw_data()

    def get_ax(self):
        return self.ax

    def update_plot_callback(self, iteration, error, X, Y, ax):
        """
        Per-iteration registration callback: plot Y in red and X in blue, overlay the
        background image (sized from the model's rangeY) and the iteration/error text.
        """
        self.ax.cla()
        print(iteration, error)
        self.ax.scatter(Y[:, 0], Y[:, 1], 0.1, c='r')
        self.ax.scatter(X[:, 0], X[:, 1], 0.1, c='b')
        if self.bgImage is not None:
            img = pyplot.imread(self.bgImage)
            self.ax.imshow(img, extent=[-self.model.rangeY / 2 * 1.05, self.model.rangeY / 2 * 1.03,
                                        -self.model.rangeY / 2 * 1.03, self.model.rangeY / 2 * 1.05])
        try:
            self.ax.text(0.87, 0.92, 'Iteration: {:d}\nError: {:06.4f}'.format(iteration, error),
                         horizontalalignment='center', verticalalignment='center', transform=self.ax.transAxes,
                         fontsize='x-large')
        except Exception as ex:
            print(ex)
        with suppress_stdout_stderr():
            super().draw()
            self.flush_events()

    def has_model(self):
        """Return True if a model is attached to this figure."""
        return self.model is not None
149 |
--------------------------------------------------------------------------------
/pointRegistration/model.py:
--------------------------------------------------------------------------------
1 | import ntpath
2 | import os
3 |
4 | import h5py
5 | import numpy as np
6 | from scipy.spatial.transform import Rotation
7 |
8 | from graphicInterface.console import Logger
9 | from pointRegistration import file3D
10 |
11 |
class Model:
    """
    A 3D point cloud, optionally with landmarks, plus the state the GUI attaches to it.

    Besides ``points`` and ``landmarks`` (arrays with one XYZ row per point), a model
    keeps plot colours, the indices of the points selected for registration, the
    parameters of the last registration, an optional displacement map and an
    optional background image path.
    """

    def __init__(self, path_data=None):
        """
        Create an empty model, or load one from a file and centre it.

        :param path_data: path to a .mat, .off or .wrl file (for .wrl, landmarks are read
            from the .bnd file with the same base name); None creates an empty model
        :raises ValueError: if no points can be loaded from the file
        """
        self.landmarks = None
        self.landmarks_color = None
        self.points = None
        self.points_color = None

        self.highlighted_color = "y"

        self.init_attributes()
        if path_data is not None:
            self.load_model(path_data)
            self.center_model()

    @classmethod
    def from_model(cls, model):
        """Return a new Model with copies of ``model``'s points and landmarks, its colours and filename."""
        copy = cls()
        copy.set_points(np.copy(model.points))
        if model.points_color is not None:
            copy.points_color = model.points_color
        if model.landmarks is not None:
            copy.set_landmarks(np.copy(model.landmarks))
        if model.landmarks_color is not None:
            copy.landmarks_color = model.landmarks_color
        copy.filename = model.filename
        return copy

    def init_attributes(self):
        """Reset the auxiliary attributes (image, registration data, ranges, filename) to None."""
        self.bgImage = None
        self.registration_points = None
        self.registration_params = None
        self.displacement_map = None
        self.rangeX = None
        self.rangeY = None
        self.filename = None

    def center_model(self):
        """Translate points and landmarks so that the median of the points is the origin."""
        # Scans are not always perfectly centred; the median is robust to large
        # clusters of extreme values.
        points_median = np.median(self.points, axis=0)
        self.points -= points_median
        if self.landmarks is not None:
            self.landmarks -= points_median

    def set_points(self, data):
        """Store the points, update the X/Y extents and reset the point colour to blue."""
        self.points = data
        self.rangeX = np.ptp(self.points[:, 0])
        self.rangeY = np.ptp(self.points[:, 1])
        self.points_color = "b"

    def set_landmarks(self, land):
        """Store the landmarks and reset their colour to red."""
        self.landmarks = land
        self.landmarks_color = "r"

    def add_registration_points(self, reg_points):
        """
        Add point indices to the registration selection (duplicates are removed).

        If the first index is -1 the current selection is cleared first.
        NOTE(review): the -1 itself is then stored as well and, used as an index,
        selects the last point; confirm whether callers expect that.
        """
        if reg_points[0] == -1:
            self.registration_points = np.empty((0, 3), dtype=int)

        self.registration_points = np.unique(np.append(self.registration_points, reg_points))

    def init_registration_points(self):
        """Start with an empty registration selection."""
        self.registration_points = np.empty((0, 3), dtype=int)

    def get_registration_points(self):
        """Return a copy of the selected points (selection indices are de-duplicated in place)."""
        self.registration_points = np.unique(self.registration_points)
        return np.array(self.points[self.registration_points])

    def has_registration_points(self):
        """Return True if at least one point is selected for registration."""
        # registration_points stays None until a file is loaded or a selection is initialised.
        return self.registration_points is not None and self.registration_points.shape[0] > 0

    def save_model(self, filename):
        """
        Save points, landmarks, displacement map and registration parameters to ``filename``.

        The format is chosen by file3D.save_file from the extension.
        """
        model = {"model_data": self.points}

        if self.landmarks is not None:
            model["landmarks3D"] = self.landmarks
        if self.displacement_map is not None:
            model["displacement_map"] = self.displacement_map
        if self.registration_params is not None:
            for i, param in enumerate(self.registration_params):
                model["reg_param" + str(i)] = param

        file3D.save_file(filename, model)
        Logger.addRow("File saved: " + filename)

    def compute_displacement_map(self, target_model, distance):
        """Return the displacement map between this model and ``target_model`` (None without a target)."""
        from pointRegistration.displacementMap import DisplacementMap
        if target_model is None:
            return None
        return DisplacementMap.compute_map(source_model=self, target_model=target_model, max_dist=distance)

    def rotate(self, axis, theta):
        """Rotate points and landmarks in place; see rotate_model for the meaning of the arguments."""
        self.points = Model.rotate_model(axis, theta, self.points)
        if self.landmarks is not None:
            self.landmarks = Model.rotate_model(axis, theta, self.landmarks)

    @staticmethod
    def rotate_model(axis, theta, data):
        """
        Return ``data`` rotated by ``theta`` degrees.

        :param axis: 'x', 'y' or 'z'. As in the original implementation, 'x' rotates
            about the Y axis, 'y' about the X axis and 'z' about the Z axis.
        :param theta: rotation angle in degrees
        :param data: array of XYZ rows
        :raises ValueError: if ``axis`` is not 'x', 'y' or 'z'
        """
        # A unit quaternion for a rotation by angle t uses sin(t/2) and cos(t/2);
        # using the full angle rotated the data by twice the requested amount.
        half_angle = np.radians(theta) / 2.0
        sin_h = np.sin(half_angle)
        cos_h = np.cos(half_angle)
        # scipy quaternions are scalar-last: (x, y, z, w)
        quaternions = {
            'x': [0, sin_h, 0, cos_h],
            'y': [sin_h, 0, 0, cos_h],
            'z': [0, 0, sin_h, cos_h],
        }
        if axis not in quaternions:
            raise ValueError("axis must be 'x', 'y' or 'z', got " + repr(axis))
        return Rotation.from_quat(quaternions[axis]).apply(data)

    def load_model(self, path_data):
        """
        Load points (and landmarks when available) from ``path_data``.

        Supported extensions (case-insensitive): .mat (HDF5 datasets "avgModel" and
        "landmarks3D", transposed on load), .wrl (+ .bnd landmarks) and .off (points only).
        A .png with the same base name, if present, becomes the background image.

        :raises ValueError: if no points could be read from the file
        """
        self.filename, self.file_extension = os.path.splitext(path_data)
        extension = self.file_extension.lower()

        if os.path.exists(self.filename + ".png"):
            self.bgImage = self.filename + ".png"

        if extension == ".mat":
            # Context manager so the HDF5 handle is always closed.
            with h5py.File(path_data, 'r') as file:
                try:
                    self.set_points(np.transpose(np.array(file["avgModel"])))
                    self.set_landmarks(np.transpose(np.array(file["landmarks3D"])))
                except Exception as ex:
                    # str(ex): concatenating the exception object itself raised TypeError.
                    print("File is not compatible:" + str(ex))

        if extension == ".wrl":
            points = file3D.load_wrml(path_data)
            if points is not None:
                self.set_points(points)
            self.set_landmarks(file3D.load_bnd(self.filename + ".bnd"))
            # NOTE(review): this replaces an existing <name>.png with <name minus 3 chars>F2D.png
            # without checking that the latter exists; confirm the intended condition.
            if self.bgImage is not None and os.path.exists(self.bgImage):
                self.bgImage = self.filename[:-3] + "F2D.png"

        if extension == ".off":
            self.set_points(file3D.load_off(path_data))

        if self.points is None:
            # Unsupported extensions and unreadable files used to fail later with
            # an obscure AttributeError/TypeError.
            raise ValueError("No points could be loaded from " + path_data)

        row = "Model loaded: " + str(self.points.shape[0]) + " points"

        if self.landmarks is not None:
            row += " and " + str(self.landmarks.shape[0]) + " landmarks."

        self.init_registration_points()
        Logger.addRow(row)

    @staticmethod
    def __path_leaf__(path):
        """Return the last component of ``path`` (works with trailing separators)."""
        head, tail = ntpath.split(path)
        return tail or ntpath.basename(head)

    @staticmethod
    def decimate(old_array, percentage):
        """
        Return a random subset containing ``percentage`` percent of the rows of ``old_array``.

        With percentage >= 100 the input array itself is returned (not a copy). Otherwise
        rows are drawn without replacement, in random order (using np.random), and
        returned as a new float array.
        """
        if percentage >= 100:
            return old_array

        length = old_array.shape[0]
        order = np.arange(length)
        np.random.shuffle(order)
        limit = int(length / 100 * percentage)
        return old_array[order[:limit]].astype(float)
181 |
--------------------------------------------------------------------------------
/pointRegistration/file3D.py:
--------------------------------------------------------------------------------
1 | from graphicInterface.console import Logger
2 | import numpy as np
3 | import os
4 | import h5py
5 |
6 |
def load_wrml(path):
    """
    Return an Nx3 array with the 3D coordinates of the points in a .wrl file.

    The parser is line based and expects this layout:

        ...anything...
        point [
        x.xxx y.yyyyyy z.zzzzz,
        x.xxx y.yyyyyy z.zzzzz,
        ...
        x.xxx y.yyyyyy z.zzzzz
        ]
        ...anything...

    i.e. "point [" and "]" on their own lines and one point per line in between,
    each optionally followed by a comma (the last one usually has none).
    Everything outside that section is ignored.

    :param path: path of the file (e.g. "model.wrl")
    :return: a numpy array of shape (N, 3), or None if the file cannot be opened
    :raises ValueError: if there is no "point [" line or a point row is malformed
    """
    file = load_file(path)
    if file is None:
        return None

    coordinates = []
    with file:
        for line in file:
            if line.strip() == "point [":
                break
        else:
            # Previously the search looped forever once the end of the file was reached.
            raise ValueError("No 'point [' section found in " + str(path))

        for line in file:
            line = line.strip()
            if line == "]":
                break
            # Remove only a trailing comma: cutting the last character unconditionally
            # truncated the z coordinate of the final row, which has no comma.
            x, y, z = line.rstrip(",").split()
            coordinates.append((float(x), float(y), float(z)))

    return np.array(coordinates, dtype=float).reshape(-1, 3)
56 |
57 |
def load_bnd(path):
    """
    Return an Nx3 array with the 3D landmark coordinates stored in a .bnd file.

    Each line must hold an index followed by three coordinates:

        nnnn x.xxx y.yyyyyy z.zzzzz

    Fields may be separated by any whitespace (spaces or tabs). Reading stops at
    the first blank line or at the end of the file; the index column is ignored.

    :param path: path of the file (e.g. "model.bnd")
    :return: a numpy array of shape (N, 3), or None if the file cannot be opened
    :raises ValueError: if a non-blank line does not have exactly four fields
    """
    file = load_file(path)
    if file is None:
        return None

    landmarks = []
    with file:
        for line in file:
            fields = line.split()
            if not fields:
                break
            _, x, y, z = fields
            landmarks.append((float(x), float(y), float(z)))

    return np.array(landmarks, dtype=float).reshape(-1, 3)
98 |
99 |
def load_off(file, faces_required=False):
    """
    Read vertices (and optionally faces) from an OFF file.

    Blank lines and lines starting with '#' are skipped, and fields may be separated
    by any amount of whitespace.

    :param file: path to the file to read
    :type file: str
    :param faces_required: True if the function should also return the faces
    :type faces_required: bool
    :return: an (N, 3) numpy array of vertices, or ``(vertices, faces)`` when
        ``faces_required`` is True, where each face is the list of ints
        ``[n, i_1, ..., i_n]`` as stored in the file
    :raises FileNotFoundError: if the file does not exist
    :raises ValueError: if the file does not follow the OFF layout
    """
    if not os.path.exists(file):
        raise FileNotFoundError(file)

    with open(file, 'r') as fp:
        lines = [line.strip() for line in fp]
    # Drop blank lines and full-line comments before positional parsing.
    lines = [line for line in lines if line and not line.startswith('#')]

    if not lines or lines[0] != 'OFF':
        raise ValueError("Invalid preambole")

    parts = lines[1].split() if len(lines) > 1 else []
    if len(parts) != 3:
        raise ValueError("Need exactly 3 parameters on 2nd line (n_vertices, n_faces, n_edges).")

    num_vertices = int(parts[0])
    num_faces = int(parts[1])
    if num_vertices <= 0:
        raise ValueError("The file declares no vertices.")
    if faces_required and num_faces <= 0:
        raise ValueError("Faces were required but the file declares none.")

    vertex_rows = lines[2:2 + num_vertices]
    if len(vertex_rows) < num_vertices:
        raise ValueError("{0} vertices declared but only {1} found.".format(num_vertices, len(vertex_rows)))

    vertices = []
    for i, row in enumerate(vertex_rows):
        vertex = [float(value) for value in row.split()]
        if len(vertex) != 3:
            raise ValueError("Invalid vertex row for vertex #" + str(i))
        vertices.append(vertex)
    vertices = np.asarray(vertices)

    if faces_required:
        face_rows = lines[2 + num_vertices:2 + num_vertices + num_faces]
        if len(face_rows) < num_faces:
            raise ValueError("{0} faces declared but only {1} found.".format(num_faces, len(face_rows)))
        faces = []
        for row in face_rows:
            face = [int(index) for index in row.split()]
            if len(face) < 2 or face[0] != len(face) - 1:
                raise ValueError("Invalid face row: " + row)
            # face[0] is the vertex count, not an index: validate only the indices.
            if not all(0 <= index < num_vertices for index in face[1:]):
                raise ValueError("Face references a missing vertex: " + row)
            faces.append(face)
        return vertices, faces

    return vertices
160 |
161 |
def load_file(path):
    """
    Open ``path`` for reading in text mode.

    :return: the open file object (the caller must close it), or None after
        printing a short message if the file could not be opened
    """
    try:
        return open(path, "r")
    except FileNotFoundError:
        message = "Can't find the file"
    except FileExistsError:
        message = "File exists but got troubles"
    except Exception:
        message = "Something gone wrong"
    print(message)
    return None
176 |
177 |
def save_file(filepath, model):
    """
    Save ``model`` to ``filepath`` with the writer matching its extension.

    Recognised extensions (case-sensitive) are '.off', '.wrl' and '.mat'; for any
    other extension nothing is written.
    """
    writers = {'.off': save_off, '.wrl': save_wrl, '.mat': save_mat}
    extension = os.path.splitext(filepath)[1]
    writer = writers.get(extension)
    if writer is not None:
        writer(filepath, model)
187 |
188 |
def save_off(filepath, model):
    """Write the XYZ rows of ``model["model_data"]`` to ``filepath`` as an OFF file with no faces."""
    with open(filepath, "w") as handle:
        handle.write("OFF\n")
        data = model["model_data"]
        count = int(data.size / 3)
        handle.write(str(count) + " 0 0\n")
        handle.writelines("{0} {1} {2}\n".format(x, y, z) for x, y, z in data[:count])
199 |
200 |
def save_mat(filepath, model):
    """
    Write every ``model`` entry as a dataset (named by its key) of an HDF5 file.

    This is the layout Model.load_model reads back with h5py. A context manager is
    used so the file is closed even if creating a dataset fails.
    """
    with h5py.File(filepath, "w") as f:
        for key, value in model.items():
            f.create_dataset(key, data=value)
206 |
207 |
def save_wrl(filepath, model):
    """
    Write the XYZ rows of ``model["model_data"]`` as a VRML 2.0 PointSet.

    Previously this was an empty stub that silently wrote nothing. The
    ``point [`` and ``]`` lines stand on their own and every row ends with a comma
    (VRML treats commas as whitespace), matching the layout load_wrml parses.
    Landmarks are not written.
    """
    points = model["model_data"]
    with open(filepath, "w") as file:
        file.write("#VRML V2.0 utf8\n")
        file.write("Shape {\ngeometry PointSet {\ncoord Coordinate {\n")
        file.write("point [\n")
        file.writelines("{0} {1} {2},\n".format(x, y, z) for x, y, z in points)
        file.write("]\n}\n}\n}\n")
--------------------------------------------------------------------------------
/graphicInterface/main_widget.py:
--------------------------------------------------------------------------------
1 | import os
2 | from graphicInterface.plot_interactive_figure import PlotInteractiveFigure
3 | from graphicInterface.rotatable_figure import RotatableFigure
4 | from graphicInterface.show_displacement import DisplacementMapWindow
5 | from graphicInterface.upper_toolbar import *
6 | from pointRegistration.batchRegistration import BatchRegistrationThread
7 | from pointRegistration.model import Model
8 | from pointRegistration.registration import Registration
9 |
10 |
class MainWidget(QWidget):
    """Central widget: toolbar on top, source plot (left) and target plot (right).

    It owns the source/target models and at most one running registration
    thread, and receives the registration-completed callbacks.
    """

    # Filter for the load dialogs. VRML files use the ``.wrl`` extension (the
    # one file3D.save_file writes); ``*.wrml`` is kept so that files which
    # were selectable before remain selectable.
    _LOAD_FILTERS = "OFF Files (*.off);;VRML Files (*.wrl *.wrml);;MAT Files (*.mat)"
    # Filter for the save-target dialog.
    _SAVE_FILTERS = "MAT File (*.mat);;OFF File (*.off);;"
    # Distance threshold passed to compute_displacement_map (was a FIXME'd literal).
    _DISPLACEMENT_MAX_DIST = 3

    def __init__(self, parent):
        super().__init__(parent)
        Logger.addRow("Starting up..")
        # Default source: the average model shipped in ./data.
        self.source_model = Model(os.path.join(".", "data", "avgModel_bh_1779_NE.mat"))
        self.target_model = None
        self.sx_widget = None  # left, interactive plot of the source
        self.dx_widget = None  # right, rotatable plot of the target
        self.registration_thread = None
        self.toolbar = None
        self.registrated = None
        Logger.addRow("Ready.")
        self.initUI()

    def initUI(self):
        """Lay out the toolbar (row 0) above the source and target plots (row 1)."""
        grid_central = QGridLayout(self)
        self.setLayout(grid_central)
        self.sx_widget = PlotInteractiveFigure(self, self.source_model, title="Source")
        self.dx_widget = RotatableFigure(self, None, title="Target")
        grid_central.addWidget(self.sx_widget, 1, 0, 1, 2)
        self.sx_widget.draw_data()
        grid_central.addWidget(self.dx_widget, 1, 2, 1, 2)
        self.dx_widget.draw_data()
        self.toolbar = UpperToolbar(self)
        grid_central.addWidget(self.toolbar, 0, 0, 1, 4)
        grid_central.setRowStretch(0, 1)
        grid_central.setRowStretch(1, 30)

    def load_target(self):
        """Ask for a file, load it as the target model and show it on the right."""
        file_name = load_file_dialog(self, self._LOAD_FILTERS)
        if file_name is None:
            return

        self.toolbar.start_registration_button.setEnabled(True)
        self.target_model = Model(file_name)
        self.dx_widget.load_model(self.target_model)
        self.dx_widget.draw_data(clear=True)
        Logger.addRow("File loaded correctly: " + file_name)
        self.toolbar.save_target_btn.setEnabled(True)

    def load_source(self):
        """Ask for a file, load it as the source model and show it on the left."""
        file_name = load_file_dialog(self, self._LOAD_FILTERS)
        if file_name is None:
            return

        self.toolbar.start_registration_button.setEnabled(True)
        self.source_model = Model(file_name)
        self.sx_widget.load_model(self.source_model)
        self.sx_widget.draw_data()
        Logger.addRow("File loaded correctly: " + file_name)

    def restore(self):
        """Clear the source highlight and, if a target is shown, restore it."""
        self.restore_highlight()
        if self.dx_widget.has_model():
            self.restore_target()
            self.toolbar.start_registration_button.setEnabled(True)

    def restore_highlight(self):
        self.sx_widget.highlight_data([-1])

    def restore_target(self):
        self.dx_widget.restore_model()

    def landmark_selected(self, colors):
        self.dx_widget.set_landmarks_colors(colors)

    def data_selected(self, x_coord, y_coord, width, height):
        """Forward an area selection made on the source plot to the target plot."""
        self.dx_widget.select_area(x_coord, y_coord, width, height)

    def registrate(self, method, percentage):
        """Start registering the source onto the target in a worker thread.

        Shows an error dialog and raises ``Exception`` when no rigid points are
        highlighted on the source or no target is loaded. Does nothing while a
        registration is already running.
        """
        if not self.sx_widget.there_are_points_highlighted():
            QMessageBox.critical(self, 'Error', "No rigid points have been selected.")
            raise Exception("No rigid points selected")

        if self.dx_widget.model is None:
            QMessageBox.critical(self, 'Error', "Please, load a target model.")
            raise Exception("Target model is not present.")

        if self.registration_thread is not None:
            return
        self.toolbar.show_displacement_btn.setEnabled(False)
        self.parent().setStatus("Busy...")
        self.registration_thread = Registration(method, self.source_model, self.target_model, percentage,
                                                self.registration_completed_callback,
                                                self.dx_widget.update_plot_callback)
        self.registration_thread.start()

    def stop_registration_thread(self):
        """Ask the running registration thread (if any) to stop."""
        if self.registration_thread is not None:
            Logger.addRow("Trying to stop registration thread...")
            self.registration_thread.stop()

    def show_displacement_map(self):
        """Open a window with the displacement map of source vs. current target."""
        DisplacementMapWindow(self.parent(),
                              self.source_model.compute_displacement_map(self.target_model,
                                                                         self._DISPLACEMENT_MAX_DIST))
        self.parent().setStatusReady()

    def save_target(self):
        """Ask for a destination and save the current target model there."""
        if self.target_model is None:
            QMessageBox.critical(self, 'Error', "There is no target model to save: load a target first.")
            return

        filename = save_file_dialog(self, self._SAVE_FILTERS)
        if filename is None:
            return
        # TODO: verify that the model is actually written to disk.
        self.target_model.save_model(filename)

    def registration_completed_callback(self, model):
        """Show the registered model over the original target and enable the map.

        After this call ``self.target_model`` is the registered model, while the
        target widget keeps the original target as its primary model.
        """
        Logger.addRow("Registration completed.")
        self.target_model.bgImage = self.dx_widget.bgImage
        self.dx_widget.clear()

        self.dx_widget.set_secondary_model(model)
        self.dx_widget.load_model(self.target_model)
        self.target_model = model

        self.dx_widget.draw_data()
        self.parent().setStatusReady()
        self.registration_thread = None
        self.toolbar.stop_registration_button.setEnabled(False)
        self.toolbar.show_displacement_btn.setEnabled(True)
        self.registrated = True
        self.parent().setStatus("Registration completed, displacement map available. Click Show Displacement Map.")

    def registrate_batch_callback(self):
        """Called when a batch registration finishes: release the thread slot."""
        try:
            Logger.addRow("Registration completed.")
            self.registration_thread = None
            self.parent().setStatus("Ready.")
        except Exception as ex:
            print(ex)

    def registrate_batch(self, method, percentage, filenames):
        """Register the source against every file in ``filenames`` in a worker thread.

        Shows an error dialog and raises ``Exception`` when no rigid points are
        highlighted; does nothing while a registration is already running.
        ``method`` is accepted for interface symmetry with ``registrate``.
        """
        if not self.sx_widget.there_are_points_highlighted():
            QMessageBox.critical(self, 'Error', "No rigid points have been selected.")
            raise Exception("No rigid points selected")

        if self.registration_thread is not None:
            return
        self.parent().setStatus("Busy...")
        self.registration_thread = BatchRegistrationThread(self.sx_widget.model, filenames, percentage,
                                                           self.registrate_batch_callback)
        self.registration_thread.start()

    @staticmethod
    def savelog_onfile():
        Logger.save_log()
161 |
--------------------------------------------------------------------------------
/pointRegistration/displacementMap.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | import os
3 | import pickle # Look Morty, I'm a pickle!
4 |
5 | import h5py
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | from scipy.spatial import cKDTree
9 |
10 | from graphicInterface.console import Logger
11 | from pointRegistration.model import Model
12 |
13 |
class DisplacementMap(Model):
    """Source-model points split into "hit" and "missed" sets.

    A source point is *hit* when it is the nearest source point, within
    ``max_dist``, of at least one target point; every other source point is
    *missed*. Landmarks, when both models have them, are split the same way.
    A colour string is stored alongside each set for plotting.
    """

    def __init__(self, points, landmarks, points_color, landmarks_color, missed_points, missed_points_color,
                 missed_landmarks, missed_landmarks_color):
        self.init_attributes()
        self.set_points(points)
        self.set_landmarks(landmarks)
        self.points_color = points_color
        self.landmarks_color = landmarks_color
        self.missed_points = missed_points
        self.missed_points_color = missed_points_color
        self.missed_landmarks = missed_landmarks
        self.missed_landmarks_color = missed_landmarks_color

        # X/Y extent over hit and missed points together.
        self.rangeX = np.ptp(np.append(self.points[:, 0], self.missed_points[:, 0]))
        self.rangeY = np.ptp(np.append(self.points[:, 1], self.missed_points[:, 1]))

    @staticmethod
    def _matched_indices(reference, query, max_dist):
        """Sorted indices of ``reference`` rows that are the nearest neighbour,
        within ``max_dist``, of at least one ``query`` row."""
        tree = cKDTree(reference)
        try:
            # SciPy >= 1.6 names the parallelism argument ``workers``
            # (``n_jobs`` was removed in SciPy 1.9).
            _, indices = tree.query(query, distance_upper_bound=max_dist, workers=-1)
        except TypeError:
            # Older SciPy (requirements.txt allows >= 1.3) only knows ``n_jobs``.
            _, indices = tree.query(query, distance_upper_bound=max_dist,
                                    n_jobs=multiprocessing.cpu_count())
        indices = np.atleast_1d(indices)
        # Queries without a neighbour within max_dist report index == len(reference).
        return np.unique(indices[indices < reference.shape[0]])

    @classmethod
    def compute_map(cls, source_model, target_model, max_dist=2):
        """Build the displacement map of ``source_model`` against ``target_model``.

        :param source_model: model whose points are classified as hit/missed.
        :type source_model: Model
        :param target_model: model whose points are used as queries.
        :type target_model: Model
        :param max_dist: maximum neighbour distance for a source point to count as hit.
        :type max_dist: number
        :returns: a new ``DisplacementMap``.
        """
        source_points = source_model.points
        hit_idx = cls._matched_indices(source_points, target_model.points, max_dist)
        hit_points = source_points[hit_idx]
        missed_points = np.delete(source_points, hit_idx, axis=0)  # points without a feasible neighbour

        hit_landmarks = missed_landmarks = landmarks_color = missed_landmarks_color = None
        if source_model.landmarks is not None and target_model.landmarks is not None:
            hit_idx = cls._matched_indices(source_model.landmarks, target_model.landmarks, max_dist)
            hit_landmarks = source_model.landmarks[hit_idx]
            missed_landmarks = np.delete(source_model.landmarks, hit_idx, axis=0)
            landmarks_color = "r"
            # NOTE(review): "p" is not a matplotlib colour code; confirm how the
            # displacement viewer interprets it.
            missed_landmarks_color = "p"

        Logger.addRow("Displacement map created.")
        return cls(points=hit_points, landmarks=hit_landmarks, points_color="b",
                   landmarks_color=landmarks_color, missed_points=missed_points, missed_landmarks=missed_landmarks,
                   missed_points_color="y", missed_landmarks_color=missed_landmarks_color)

    def save_model(self, filename):
        """Save as HDF5 (``.h5py``) or as two pickles (``.pickle``)."""
        file_extension = os.path.splitext(filename)[1]
        if file_extension == ".h5py":
            self.save_h5py(filename)
        elif file_extension == ".pickle":
            self.save_pickle(filename)
        else:
            Logger.addRow(f"ERROR: unsupported extension '{file_extension}', {filename} was not saved.")

    def save_pickle(self, filename):
        """Pickle the hit and the missed coordinates into two sibling files.

        Writes ``<name>_missed_points<ext>`` and ``<name>_hit_points<ext>``;
        when landmarks are present they are appended to the matching points.
        """
        name, file_extension = os.path.splitext(filename)
        missed_map = self.missed_points
        hit_map = self.points
        if self.landmarks is not None:
            # np.concatenate takes a *sequence* of arrays (2nd positional arg is the axis).
            missed_map = np.concatenate((missed_map, self.missed_landmarks))
            hit_map = np.concatenate((hit_map, self.landmarks))
        with open(name + "_missed_points" + file_extension, "wb") as out:
            pickle.dump(missed_map, out)
        with open(name + "_hit_points" + file_extension, "wb") as out:
            pickle.dump(hit_map, out)

    def save_h5py(self, filename):
        """Write points, landmarks and colour attributes to an HDF5 file.

        Errors are logged, not raised; the success message is logged only
        when everything was written.
        """
        def encode_ss_utf_8(ss):
            # Wrap the UTF-8 bytes in an opaque np.void scalar; load_model decodes it back.
            return np.void(ss.encode('utf-8'))

        try:
            with h5py.File(filename, "w") as file:
                file.create_dataset("points", data=self.points)
                file.create_dataset("missed_points", data=self.missed_points)
                file.attrs["points_color"] = encode_ss_utf_8(self.points_color)
                file.attrs["missed_points_color"] = encode_ss_utf_8(self.missed_points_color)
                if self.landmarks is not None:
                    file.create_dataset("landmarks", data=self.landmarks)
                    file.create_dataset("missed_landmarks", data=self.missed_landmarks)
                    file.attrs["landmarks_color"] = encode_ss_utf_8(self.landmarks_color)
                    file.attrs["missed_landmarks_color"] = encode_ss_utf_8(self.missed_landmarks_color)
        except Exception as ex:
            Logger.addRow("ERROR: Displacement map was not saved. =>" + str(ex))
        else:
            Logger.addRow(f'File {filename} saved correctly.')

    @staticmethod
    def _decode_attr(value):
        """Decode a colour attribute written by ``save_h5py`` back to ``str``."""
        if isinstance(value, str):
            return value
        if isinstance(value, bytes):
            return value.decode('utf-8')
        # np.void scalar; tobytes() replaces the removed ndarray.tostring().
        return value.tobytes().decode('utf-8')

    @classmethod
    def load_model(cls, filename):
        """Load a displacement map previously written by ``save_h5py``.

        :returns: a new ``DisplacementMap``.
        :raises FileNotFoundError: (after logging) when ``filename`` is missing.
        :raises Exception: (after logging) for any other read/format problem.
        """
        try:
            with h5py.File(filename, 'r') as file:
                # ``[()]`` reads a whole dataset; ``Dataset.value`` was removed in h5py 3.
                points = file["points"][()]
                missed_points = file["missed_points"][()]
                points_color = cls._decode_attr(file.attrs["points_color"])
                missed_points_color = cls._decode_attr(file.attrs["missed_points_color"])
                if "landmarks" in file:
                    landmarks = file["landmarks"][()]
                    missed_landmarks = file["missed_landmarks"][()]
                    landmarks_color = cls._decode_attr(file.attrs["landmarks_color"])
                    missed_landmarks_color = cls._decode_attr(file.attrs["missed_landmarks_color"])
                else:
                    Logger.addRow("INFO: Displacement model has no landmarks.")
                    landmarks = missed_landmarks = None
                    landmarks_color = missed_landmarks_color = None
        except FileNotFoundError:
            Logger.addRow(f"ERROR: {filename} not found, check path.")
            raise
        except Exception as ex:
            Logger.addRow(f"ERROR: Problem during opening of {filename} =>" + str(ex))
            raise

        displacement_model = cls(points=points, landmarks=landmarks,
                                 missed_points=missed_points,
                                 missed_landmarks=missed_landmarks,
                                 points_color=points_color,
                                 landmarks_color=landmarks_color,
                                 missed_points_color=missed_points_color,
                                 missed_landmarks_color=missed_landmarks_color)
        Logger.addRow(f'File {filename} load correctly.')
        return displacement_model

    def shoot_displacement_map(self, filepath):
        """Save an X/Y scatter plot next to ``filepath`` with a ``.png`` extension."""
        # NOTE(review): nothing visible here sets ``displacement_map``; fall back
        # to the hit points, which hold the map's coordinates.
        data = getattr(self, "displacement_map", None)
        if data is None:
            data = self.points
        figure, axes = plt.subplots()
        axes.scatter(data[:, 0], data[:, 1], s=0.5)
        figure.savefig(os.path.splitext(filepath)[0] + ".png")
        plt.close(figure)

    def rotate(self, axis, theta):
        """Rotate hit points (via ``Model.rotate``) and the missed sets alike."""
        super().rotate(axis, theta)
        try:
            self.missed_points = Model.rotate_model(axis, theta, self.missed_points)
            if self.missed_landmarks is not None:
                self.missed_landmarks = Model.rotate_model(axis, theta, self.missed_landmarks)
        except Exception as ex:
            Logger.addRow("ERROR: rotation of missed points failed =>" + str(ex))
172 |
--------------------------------------------------------------------------------
/morpmodel/_3DMM.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numpy.linalg import pinv
3 | from numpy.linalg import inv
4 | from scipy.spatial import ConvexHull
5 |
class _3DMM:
    """Fit a 3D Morphable Model (3DMM) to a set of landmarks.

    The code follows MATLAB conventions (Fortran-order reshapes, "as diag in
    matlab" notes), so it appears to be a port of a MATLAB implementation.
    """

    # Kept only for backward compatibility: opt_3DMM_fast now stores a fresh
    # dict on the instance, so separate instances no longer share (and
    # overwrite) a single class-level dictionary.
    result = {}

    def opt_3DMM_fast(self, Weights, Components, Components_res, landmarks3D, id_landmarks_3D, landImage, avgFace,
                      _lambda, rounds, r, C_dist, ex_to_me):
        """Alternate pose estimation and shape fitting, then estimate visibility.

        Starting from ``avgFace``, ``rounds + 1`` iterations each estimate the
        pose from the current landmarks, estimate the shape coefficients and
        deform the current shape with them.

        :param Weights: one weight per component, regularising ``alpha``.
        :param Components: deformation components, one column per component.
        :param Components_res: components indexed as (vertex, 3, component).
        :param landmarks3D: only its row count (number of landmarks) is used.
        :param id_landmarks_3D: vertex indices of the landmarks.
        :param landImage: target landmark coordinates, one row per landmark.
        :param avgFace: mean shape, one row per vertex.
        :param _lambda: regularisation strength (0 = plain least squares).
        :param rounds: refinement rounds after the first fit.
        :param r: HPR exponent used by the visibility test.
        :param C_dist: distance of the viewpoint along the z axis.
        :param ex_to_me: forwarded to deform_3D_shape_fast (unused there).
        :returns: dict with pose ``A``, ``S``, ``R``, ``T``, ``defShape``,
            ``alpha`` and ``visIdx`` (also stored as ``self.result``).
        """
        index = np.array(id_landmarks_3D, dtype=np.intp)
        n_landmarks = landmarks3D.shape[0]

        def landmarks_of(shape):
            # Landmark rows of ``shape`` as an (n_landmarks, 3) array.
            return np.reshape(shape[index, :], (n_landmarks, 3), order='F')

        defShape = avgFace
        alpha = None
        # First fit from the mean face plus ``rounds`` refinements.
        for _ in range(max(rounds, 0) + 1):
            defLand = landmarks_of(defShape)
            _, S, R, T = self.estimate_pose(defLand, landImage)
            proj = np.transpose(self.getProjectedVertex(defLand, S, R, T))
            alpha = self.alphaEstimation(landImage, proj, Components_res, id_landmarks_3D, S, R, T, Weights, _lambda)
            defShape = np.transpose(self.deform_3D_shape_fast(np.transpose(defShape), Components, alpha, ex_to_me))

        # Final pose from the fitted shape, then visibility from that pose.
        A, S, R, T = self.estimate_pose(landmarks_of(defShape), landImage)
        # HPR returns 0-based indices of real vertices only, so nothing has to
        # be dropped from its result.
        visIdx = self.estimateVis_vertex(defShape, R, C_dist, r)

        self.result = {"A": A, "S": S, "R": R, "T": T, "defShape": defShape, "alpha": alpha, "visIdx": visIdx}
        return self.result

    def estimate_pose(self, landModel, landImage):
        """Affine pose mapping model landmarks onto image landmarks.

        ``A`` is the least-squares linear map between the centred landmark
        sets; it is factored as ``A = S @ R`` through a QR decomposition of
        ``A.T`` (``R`` orthogonal, ``S`` the transposed triangular factor).
        ``t`` aligns the centroids.

        :returns: ``[A, S, R, t]``
        """
        baricM = np.mean(landModel, axis=0)
        baricI = np.mean(landImage, axis=0)
        P = np.transpose(landModel - baricM)
        p = np.transpose(landImage - baricI)
        A = p.dot(pinv(P))
        Q, upper = np.linalg.qr(np.transpose(A), mode='complete')  # A.T = Q @ upper
        S = np.transpose(upper)
        R = np.transpose(Q)
        t = np.transpose(baricI) - A.dot(np.transpose(baricM))
        return [A, S, R, t]

    def getProjectedVertex(self, vertex, S, R, t):
        """Return ``S @ R @ v + t`` for every vertex, one column per vertex.

        ``t`` is reshaped to a (3, 1) column, so a 3-row result is expected.
        """
        rotated = R.dot(self.transVertex(vertex))
        return S.dot(rotated) + np.reshape(t, (3, 1), order='F')

    def transVertex(self, vertex):
        """Return ``vertex`` with coordinates along axis 0 (transposed unless it has 3 rows)."""
        if vertex.shape[0] != 3:
            return np.transpose(vertex)
        return vertex

    def alphaEstimation(self, landImage, projLandModel, Components_res, id_landmarks_3D, S, R, t, Weights, _lambda):
        """Estimate the shape coefficients explaining the landmark residual.

        Each component, restricted to the landmarks and projected with the
        current pose, becomes a column of ``Y``; the residual
        ``landImage - projLandModel`` (flattened in Fortran order) is ``X``.

        :returns: 1-D coefficient vector. With ``_lambda == 0`` it is the
            least-squares solution of ``Y @ alpha = X`` (MATLAB ``Y\\X``);
            otherwise the ridge solution with penalty ``_lambda / Weights``.
        """
        X = (landImage - projLandModel).flatten(order='F')
        index = np.array(id_landmarks_3D, dtype=np.intp)
        n_components = Components_res.shape[2]
        Y = np.zeros((X.shape[0], n_components))
        for c in range(n_components):
            vect = Components_res[index, :, c].reshape(landImage.shape[0], 3, order='F')
            Y[:, c] = np.transpose(self.getProjectedVertex(vect, S, R, t)).flatten(order='F')

        if _lambda == 0:
            return np.linalg.lstsq(Y, X, rcond=None)[0]
        with np.errstate(divide='ignore'):
            inv_w = _lambda / np.ravel(Weights)
        lhs = np.transpose(Y).dot(Y) + np.diag(inv_w)
        return np.linalg.solve(lhs, np.transpose(Y).dot(X))

    def alpha_estimation_test(self, avg_colors, target_colors, Components, Weights, _lambda):
        """Ridge coefficients explaining ``target_colors - avg_colors`` with the
        columns of ``Components`` (penalty ``_lambda / Weights``).

        :returns: coefficients as a column vector.
        """
        avg_colors = np.reshape(avg_colors, (avg_colors.shape[0], 1))
        target_colors = np.reshape(target_colors, (target_colors.shape[0], 1))
        X = target_colors - avg_colors
        Y = Components
        with np.errstate(divide='ignore'):
            inv_w = _lambda / np.ravel(Weights)
        lhs = np.transpose(Y).dot(Y) + np.diag(inv_w)
        return np.linalg.solve(lhs, np.transpose(Y).dot(X))

    def deform_3D_shape_fast(self, mean_face, eigenvecs, alpha, ex_to_me):
        """Return ``mean_face + reshape(eigenvecs @ alpha, (3, dim), 'F')``.

        ``mean_face`` has one column per vertex (3 rows); ``eigenvecs`` has
        ``3 * dim`` rows and one column per coefficient. ``ex_to_me`` is unused.
        """
        dim = eigenvecs.shape[0] // 3
        offsets = eigenvecs.dot(np.ravel(alpha))
        return mean_face + np.reshape(offsets, (3, dim), order='F')

    def estimateVis_vertex(self, vertex, R, C_dist, r):
        """Indices of ``vertex`` rows visible from the viewpoint ``(0, 0, C_dist)``
        rotated by ``R``, computed with HPR."""
        viewPoint_front = np.array([0, 0, C_dist]).reshape(1, 3, order='F')
        viewPoint = np.transpose(self.rotatePointCloud(viewPoint_front, R, []))
        return self.HPR(vertex, viewPoint, r)

    def rotatePointCloud(self, P, R, t):
        """Return ``R @ P.T`` (the translation ``t`` is currently ignored)."""
        return np.dot(R, np.transpose(P))

    def HPR(self, p, C, param):
        """Hidden Point Removal (Katz et al.): points of ``p`` visible from ``C``.

        Points are spherically flipped around ``C`` with radius
        ``max(|p - C|) * 10**param``; visible points are those whose flipped
        image is a vertex of the convex hull of the flipped cloud plus ``C``.

        :returns: sorted 0-based row indices into ``p``, shaped (k, 1).
        """
        dim = p.shape[1]
        num_pts = p.shape[0]
        p = p - C
        norm_p = np.linalg.norm(p, axis=1).reshape(-1, 1)
        # 10.0 ** param also works for negative integer exponents.
        radius = np.amax(norm_p) * (10.0 ** param)
        flipped = p + 2 * (radius - norm_p) * p / norm_p
        hull = ConvexHull(np.vstack([flipped, np.zeros((1, dim))]))
        # ConvexHull indices are already 0-based; index num_pts is the
        # viewpoint appended above, not a data point.
        visible = np.unique(hull.vertices)
        visible = visible[visible != num_pts]
        return visible.reshape(-1, 1)

    def getVisLand(self, vertex, landImage, visIdx, id_landmarks_3D):
        """Split landmarks into visible / non-visible.

        :returns: ``[visLand, id_Vland, Nvis]`` — visible landmark vertex ids,
            positions (in ``id_landmarks_3D``) of the visible landmarks, and
            positions of the non-visible ones. ``vertex`` is unused.
        """
        visLand = np.intersect1d(visIdx, id_landmarks_3D)
        id_Vland = np.nonzero(np.in1d(id_landmarks_3D, visIdx))[0]
        all_positions = np.arange(landImage.shape[0])
        Nvis = np.setxor1d(all_positions, id_Vland)
        return [visLand, id_Vland, Nvis]
183 |
184 |
--------------------------------------------------------------------------------
/morpmodel/util_for_graphic.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.interpolate import LinearNDInterpolator
3 | from numpy import dot, mean, std, argsort
4 | from numpy.linalg import eigh
5 | from scipy.spatial import ConvexHull
6 | from matplotlib import path
7 |
8 |
9 | class graphic_tools:
10 | _3DMM_obj = ()
11 |
12 | def __init__(self, _3dmm_obj):
13 | if _3dmm_obj is None:
14 | self._3DMM_obj = None
15 | else:
16 | self._3DMM_obj = _3dmm_obj
17 |
18 | def render3DMM(self, xIm, yIm, rgb, w, h):
19 | # need to perfomr the proj shape from the 3D model of the shape. The first column of projShape is xIm, the secondo is yIm.
20 | # The texture of the 3D model is the var rgb
21 | x_grid, y_grid = np.meshgrid([int(x) for x in range(w)], [int(y) for y in range(h)])
22 | _first_cord = np.transpose(xIm)
23 | _second_cord = np.transpose(yIm)
24 | # Interpolate z values on the grid
25 | coords = self._2linear_vects(_first_cord.reshape(_first_cord.shape[0],1), _second_cord.reshape(_second_cord.shape[0],1))
26 | Fr = LinearNDInterpolator(coords, np.transpose(rgb[:,0]))
27 | Fb = LinearNDInterpolator(coords, np.transpose(rgb[:,1]))
28 | Fg = LinearNDInterpolator(coords, np.transpose(rgb[:,2]))
29 | # Get values for each location
30 | imR = Fr(x_grid,y_grid)
31 | imG = Fb(x_grid,y_grid)
32 | imB = Fg(x_grid,y_grid)
33 |
34 | pts = np.column_stack([xIm, yIm])
35 | pts = np.round(pts)
36 | pts = np.unique(pts, axis=0)
37 | allPoints=np.column_stack((pts[:,0],pts[:,1]))
38 | hullPoints = ConvexHull(allPoints)
39 | bb = hullPoints.vertices
40 | bb = np.array(bb, dtype=np.intp)
41 | mask = self.inpolygon(x_grid, y_grid, pts[bb,0], pts[bb,1])
42 | mask = np.logical_not(mask)
43 |
44 | imR[mask] = 0
45 | imG[mask] = 0
46 | imB[mask] = 0
47 | # building image
48 | img = np.empty((imR.shape[0], imR.shape[1], 3))
49 | img[:,:,0] = imR
50 | img[:,:,1] = imG
51 | img[:,:,2] = imB
52 | #img = img*255
53 | return np.array(img,dtype=np.uint8)
54 |
55 | def inpolygon(self, xq, yq, xv, yv):
56 | shape = xq.shape
57 | xq = xq.reshape(-1)
58 | yq = yq.reshape(-1)
59 | xv = xv.reshape(-1)
60 | yv = yv.reshape(-1)
61 | q = [(xq[i], yq[i]) for i in range(xq.shape[0])]
62 | p = path.Path([(xv[i], yv[i]) for i in range(xv.shape[0])])
63 | return p.contains_points(q).reshape(shape)
64 |
65 | def denseResampling(self, defShape, proj_shape, img, S, R, t, visIdx):
66 | grid_step = 1
67 | max_sh = (np.amax(proj_shape, axis=0))
68 | max_sh = np.reshape(max_sh,(1,2),order='F')
69 | min_sh = np.amin(proj_shape, axis=0)
70 | min_sh = np.reshape(min_sh, (1,2), order='F')
71 | Xsampling = np.arange(min_sh[0,0], max_sh[0,0], grid_step, dtype='float')
72 | Ysampling = np.arange(max_sh[0,1], min_sh[0,1], -grid_step, dtype='float')
73 | # round the float numbers
74 | Xsampling = np.around(Xsampling)
75 | Ysampling = np.around(Ysampling)
76 |
77 | print ("SAMPLING THE GRID")
78 | x_grid, y_grid = np.meshgrid(Xsampling, Ysampling)
79 | print ("3D LOCATION INTERPOLATION")
80 | index = np.array(visIdx, dtype=np.intp)
81 | X = proj_shape[index, 0]
82 | Y = proj_shape[index, 1]
83 | coords = self._2linear_vects(X,Y)
84 | #coords = list(zip(X,Y))
85 | Fx = LinearNDInterpolator(coords, defShape[index, 0])
86 | Fy = LinearNDInterpolator(coords, defShape[index, 1])
87 | Fz = LinearNDInterpolator(coords, defShape[index, 2])
88 | print ("PIXEL SAMPLING")
89 | x = Fx(x_grid.flatten(order='F'), y_grid.flatten(order='F'))
90 | y = Fy(x_grid.flatten(order='F'), y_grid.flatten(order='F'))
91 | z = Fz(x_grid.flatten(order='F'), y_grid.flatten(order='F'))
92 | print ("IMAGE BUILDING")
93 | x = x[~np.isnan(x).any(axis=1)]
94 | y = y[~np.isnan(y).any(axis=1)]
95 | z = z[~np.isnan(z).any(axis=1)]
96 | mod3d = np.dstack([x, y, z])
97 | #mod3d = np.delete(mod3d, [0,1,2,3], axis=0) # c'erano 3 valori in piu dall'interpolazione...??
98 | mod3d = np.reshape(mod3d,(mod3d.shape[0],3))
99 | print ("FINAL RENDERING")
100 | projMod3d = np.transpose(self._3DMM_obj.getProjectedVertex(mod3d, S, R, t))
101 | colors = self.getRGBtexture(np.round(projMod3d), img)
102 | colors = self._3DMM_obj.transVertex(colors)
103 | print ("DONE DENSE RESAMPLING")
104 | return [mod3d,colors]
105 |
106 | def resampleRGB(self,P,colors,step):
107 | min_x = np.amin(P[0,:], axis=0)
108 | max_x = np.amax(P[0,:], axis=0)
109 | min_y = np.amin(P[1,:], axis=0)
110 | max_y = np.amax(P[1,:], axis=0)
111 | Xsampling = np.arange(min_x, max_x, step, dtype='float')
112 | Ysampling = np.arange(max_y, min_y, -step, dtype='float')
113 | x, y = np.meshgrid(Xsampling, Ysampling)
114 | _first_cord = np.transpose(P[0,:])
115 | _second_cord = np.transpose(P[1,:])
116 | coords = self._2linear_vects(_first_cord.reshape(_first_cord.shape[0],1), _second_cord.reshape(_second_cord.shape[0],1))
117 | #Fr = griddata(coords, np.transpose(colors[0, :]), (x, y), method='linear')
118 | #Fg = griddata(coords, np.transpose(colors[1, :]), (x, y))
119 | #Fb = griddata(coords, np.transpose(colors[2, :]), (x, y))
120 |
121 | Fr = LinearNDInterpolator(coords, np.transpose(colors[0, :]))
122 | Fb = LinearNDInterpolator(coords, np.transpose(colors[1, :]))
123 | Fg = LinearNDInterpolator(coords, np.transpose(colors[2, :]))
124 |
125 | texR = Fr(x,y)
126 | texG = Fb(x,y)
127 | texB = Fg(x,y)
128 | #remove Nan elements
129 | mask = np.isnan(texR)
130 | index = np.array(mask, dtype=np.intp)
131 | texR[mask] = 0
132 | texG[mask] = 0
133 | texB[mask] = 0
134 | img = np.empty((texR.shape[0], texR.shape[1], 3))
135 | img[:,:,0] = texR
136 | img[:,:,1] = texG
137 | img[:,:,2] = texB
138 | return img, Xsampling, Ysampling
139 |
140 |
141 |
142 | '''texR = np.nan_to_num(Fr(x, y))
143 | texG = np.nan_to_num(Fg(x, y))
144 | texB = np.nan_to_num(Fb(x, y))
145 | img = np.empty((texR.shape[0], texR.shape[1], 3))
146 | img[:, :, 0] = texR
147 | img[:, :, 1] = texG
148 | img[:, :, 2] = texB
149 | #img = img * 255
150 | return img, Xsampling, Ysampling'''
151 |
152 | def _2linear_vects(self, X,Y): # Creo le coordinate (X,Y) da passare alla funzione interpolatrice
153 | coords = np.empty((1, 2))
154 | for i in range(X.shape[0]):
155 | new_cord = np.array([X[i, 0], Y[i, 0]]).reshape(1, 2, order='F')
156 | coords = np.row_stack((coords, new_cord))
157 | coords = np.delete(coords, (0), axis=0)
158 | return coords
159 |
160 | def getRGBtexture(self, coordTex, tex):
161 | R = (tex[:,:,0]).flatten(order='F')
162 | G = (tex[:,:,1]).flatten(order='F')
163 | B = (tex[:,:,2]).flatten(order='F')
164 | #print(coordTex.shape)
165 | Xtex = np.round(coordTex[:,0])
166 | Ytex = np.round(coordTex[:,1])
167 | Xtex = np.maximum(Xtex, np.tile(1, Xtex.shape[0]))
168 | Ytex = np.maximum(Ytex, np.tile(1, Ytex.shape[0]))
169 | Xtex = np.minimum(Xtex, np.tile(tex.shape[1], Xtex.shape[0]))
170 | Ytex = np.minimum(Ytex, np.tile(tex.shape[0], Ytex.shape[0]))
171 | #I = (self.sub2ind(tex.shape, Ytex, Xtex)).astype(int)
172 | #I = np.sub2ind(tex.shape, Ytex, Xtex)
173 | I = Ytex + (Xtex - 1) * tex.shape[0]
174 | I = I.astype(int)
175 | #print(I.shape)
176 | Ri = R[I]
177 | Gi = G[I]
178 | Bi = B[I]
179 | colors = np.empty((Ri.shape[0],1))
180 | colors = np.hstack((colors, Ri.reshape(Ri.shape[0],1)))
181 | colors = np.hstack((colors, Gi.reshape(Gi.shape[0],1)))
182 | colors = np.hstack((colors, Bi.reshape(Bi.shape[0],1)))
183 | colors = np.delete(colors, (0), axis=1)
184 | colors = (colors.astype(float))/255.0
185 |
186 | return colors
187 |
188 | def renderFaceLossLess(self, defShape, projShape, img, S, R, T, renderingStep, visIdx):
189 | [mod3d, colors] = self.denseResampling(defShape, projShape, img, S, R, T, visIdx)
190 | frontal_view = self.build_image(np.transpose(mod3d), np.transpose(defShape), colors, renderingStep)
191 | return frontal_view, colors, mod3d
192 |
193 | def build_image(self, P_f, P, colors, step):
194 | print("START BUILD FRONTAL VIEW")
195 | min_x = np.amin(P[0,:])
196 | max_x = np.amax(P[0,:])
197 | min_y = np.amin(P[1,:])
198 | max_y = np.amax(P[1,:])
199 | Xsampling = np.arange(min_x, max_x, step, dtype='float')
200 | Ysampling = np.arange(max_y, min_y, -step, dtype='float')
201 |
202 | [x, y] = np.meshgrid(Xsampling, Ysampling)
203 | _first_cord = np.transpose(P_f[0,:])
204 | _second_cord = np.transpose(P_f[1,:])
205 |
206 | coords = self._2linear_vects(_first_cord.reshape(_first_cord.shape[0],1), _second_cord.reshape(_second_cord.shape[0],1))
207 | Fr = LinearNDInterpolator(coords, np.transpose(colors[0,:]))
208 | Fb = LinearNDInterpolator(coords, np.transpose(colors[1,:]))
209 | Fg = LinearNDInterpolator(coords, np.transpose(colors[2,:]))
210 | texR = np.nan_to_num(Fr(x, y))
211 | texG = np.nan_to_num(Fg(x, y))
212 | texB = np.nan_to_num(Fb(x, y))
213 | #mask = np.zeros((texR.shape[0], texR.shape[1]), dtype='bool')
214 | #for i in range(mask.shape[0]):
215 | # for j in range(mask.shape[1]):
216 | # if np.isnan(texR[i, j]):
217 | # mask[i, j] = True
218 | #texR = texR[~mask]
219 | img = np.empty((texR.shape[0],texR.shape[1],3))
220 | img[:,:,0] = texR
221 | img[:,:,1] = texG
222 | img[:,:,2] = texB
223 | img = img*255
224 | #img = img[~mask]
225 | print("DONE")
226 | return np.array(img,dtype=np.uint8)
227 |
228 | def resize_imgs(self, img, size_row, size_col):
229 | rows = img.shape[0]
230 | cols = img.shape[1]
231 | # resize the rows
232 | if rows > size_row:
233 | diff = rows - size_row
234 | for i in range(diff):
235 | img = np.delete(img, (i), axis=0)
236 | else:
237 | diff = size_row - rows
238 | new_row = np.zeros((1, cols, 3))
239 | for i in range(diff):
240 | img = np.vstack([img, new_row])
241 | if cols > size_col:
242 | diff = cols - size_col
243 | for i in range(diff):
244 | img = np.delete(img, (i), axis=1)
245 | else:
246 | diff = size_col - cols
247 | new_col = np.zeros((img.shape[0],1, 3))
248 | #print img.shape
249 | #print new_col.shape
250 | for i in range(diff):
251 | img = np.hstack([img,new_col])
252 | return img
253 |
254 | #def sub2ind(self, array_shape, rows, cols):
255 | # return rows*array_shape[1] + cols
256 |
257 | def avgModel(self, object):
258 | # creo il modello medio
259 | rows = object[0].frontalView.shape[0]
260 | cols = object[0].frontalView.shape[1]
261 | R_cumulative_matrix = object[0].frontalView[:, :, 0]
262 | G_cumulative_matrix = object[0].frontalView[:, :, 1]
263 | B_cumulative_matrix = object[0].frontalView[:, :, 2]
264 | for i in range(len(object)):
265 | current_R = object[i].frontalView[:, :, 0]
266 | current_G = object[i].frontalView[:, :, 1]
267 | current_B = object[i].frontalView[:, :, 2]
268 | R_cumulative_matrix = R_cumulative_matrix + current_R
269 | G_cumulative_matrix = G_cumulative_matrix + current_G
270 | B_cumulative_matrix = B_cumulative_matrix + current_B
271 | mean_R = R_cumulative_matrix / len(object)
272 | mean_G = G_cumulative_matrix / len(object)
273 | mean_B = B_cumulative_matrix / len(object)
274 | avgModel = np.empty((rows, cols, 3))
275 | avgModel[:,:,0] = mean_R
276 | avgModel[:,:,1] = mean_G
277 | avgModel[:,:,2] = mean_B
278 |
279 | return np.array(avgModel,dtype=np.uint8)
280 |
281 | def colors_AVG(self, object):
282 | cumulative_matrix_col = object[0].colors
283 | for i in range(len(object)):
284 | cumulative_matrix_col = cumulative_matrix_col + object[i].colors
285 | return cumulative_matrix_col/len(object)
286 |
def deform_texture_fast(self, mean, eigenves, alpha, ex_to_ne):
    """Add a linear combination of eigenvectors to ``mean``.

    Args:
        mean: array broadcast-compatible with shape (3, N).
        eigenves: (3*N, K) array, one eigenvector per column. The three
            values belonging to point j are stored consecutively at rows
            3*j, 3*j+1 and 3*j+2 (implied by the column-major reshape below).
        alpha: the K combination coefficients, shape (K,), (K, 1) or (1, K).
        ex_to_ne: unused; kept for interface compatibility.

    Returns:
        ``mean`` plus the (3, N) deformation.
    """
    # Integer division: reshape needs an int (``/`` gives a float in Python 3).
    dim = eigenves.shape[0] // 3
    # eigenves @ alpha == sum_k alpha_k * eigenves[:, k], without building a
    # tiled (3N, K) copy of alpha.
    offsets = np.asarray(eigenves).dot(np.ravel(alpha))
    deformation = np.reshape(offsets, (3, dim), order='F')
    return mean + deformation
295 |
def cov(self, X):
    """Return the population covariance ``X.T @ X / n`` of mean-centred data.

    Rows of ``X`` are samples. The mean is NOT subtracted here, and the
    normalisation is by n (numpy's ``cov`` subtracts the mean and uses n - 1).
    """
    n_samples = X.shape[0]
    gram = dot(X.T, X)
    return gram / n_samples
310 |
def pca(self, data, pc_count=None):
    """Principal component analysis via eigendecomposition of the covariance.

    ``data`` is standardised IN PLACE: every column is mean-centred and then
    divided by its population standard deviation. A column with zero
    standard deviation is only centred (its scale is treated as 1), so a
    constant feature becomes a zero column instead of NaNs that would poison
    the eigendecomposition.

    Args:
        data: 2-D float array, one sample per row (modified in place).
        pc_count: number of leading components to keep; None keeps all.

    Returns:
        (U, E, V): the projected data ``data @ V``, the eigenvalues in
        descending order, and the matching eigenvectors as columns of ``V``.
    """
    data -= mean(data, 0)
    scale = std(data, 0)
    # Avoid 0/0 on constant columns.
    data /= np.where(scale == 0, 1.0, scale)
    C = self.cov(data)
    E, V = eigh(C)
    # Order components by decreasing eigenvalue and keep the first pc_count.
    key = argsort(E)[::-1][:pc_count]
    E, V = E[key], V[:, key]
    U = dot(data, V)
    return U, E, V
324 |
def PIL2array(self, img):
    """Convert an image into a (height, width, 3) uint8 array.

    ``img`` must provide ``size`` as (width, height) and ``getdata()``
    yielding one 3-value pixel per position (PIL-style image).
    """
    width, height = img.size
    pixels = np.array(img.getdata(), np.uint8)
    return pixels.reshape(height, width, 3)
328 |
def mse(self, imageA, imageB):
    """Return the mean squared error between two same-shaped images.

    Squared differences are summed over every element and divided by
    height * width (``shape[0] * shape[1]``); the channel count is not part
    of the divisor. Lower values mean more similar images.
    """
    diff = imageA.astype("float") - imageB.astype("float")
    pixel_count = float(imageA.shape[0] * imageA.shape[1])
    return np.sum(diff ** 2) / pixel_count
339 |
def compare_imgs(self, i1, i2):
    """Return the mean absolute per-channel difference of two images, in percent.

    Args:
        i1, i2: PIL-style images (``getdata()``, ``getbands()``, ``size``),
            presumably of the same size and mode — not checked here.

    Returns:
        ``100 * sum|c1 - c2| / (255 * width * height * bands)``; for 8-bit
        channels this is 0 for identical images and 100 for maximally
        different ones.
    """
    n_bands = len(i1.getbands())
    pairs = zip(i1.getdata(), i2.getdata())
    if n_bands == 1:
        # Single-band images: getdata() yields one value per pixel.
        dif = sum(abs(p1 - p2) for p1, p2 in pairs)
    else:
        dif = sum(abs(c1 - c2) for p1, p2 in pairs for c1, c2 in zip(p1, p2))

    # Normalise by the number of channel values actually compared, so
    # single-band and RGBA images are scored on the same scale as RGB.
    ncomponents = i1.size[0] * i1.size[1] * n_bands
    return (dif / 255.0 * 100) / ncomponents
351 |
def mean(self, obj):
    """Return the average of the ``ssim_D`` attribute over the items of ``obj``.

    Raises ZeroDivisionError when ``obj`` is empty.
    """
    total = 0
    for item in obj:
        total += item.ssim_D
    return total / len(obj)
357 |
def plot_mesh(self, vertex, face, texture, texture_coords, face_index=0):
    """Sample the texture patch of one mesh triangle onto a square grid.

    The triangle's three texture coordinates are blended with the
    barycentric weights returned by
    ``self.calculateBarycentricInterpolationValues`` over a
    (sizep + 1) x (sizep + 1) grid. The texture is then sampled
    (nearest neighbour) at the rounded positions. Nothing is drawn; the
    sampled patch is returned.

    Args:
        vertex: mesh vertex coordinates. Not needed to sample the texture;
            kept for interface compatibility.
        face: (F, 3) integer array of 0-based indices into
            ``texture_coords`` (faces double as texture faces).
        texture: (H, W, C) image with C >= 3; channels 0-2 are sampled.
        texture_coords: per-vertex texture coordinates. Column 0 is used as
            a 1-based row index and column 1 as a 1-based column index into
            ``texture``; out-of-range values are clamped to the image.
        face_index: which face to sample (default 0, the first face).

    Returns:
        (sizep + 1, sizep + 1, 3) float array, with sizep = 64.
    """
    sizep = 64
    lambda1, lambda2, lambda3, jind = self.calculateBarycentricInterpolationValues(sizep)
    # Column vectors (P, 1) so they broadcast against a (2,) coordinate pair.
    l1, l2, l3 = (np.reshape(l, (-1, 1)) for l in (lambda1, lambda2, lambda3))

    texture = np.asarray(texture)
    tri_uv = np.asarray(texture_coords, dtype=np.float64)[np.asarray(face)[face_index], :2]  # (3, 2)
    pos = np.round(l1 * tri_uv[0] + l2 * tri_uv[1] + l3 * tri_uv[2])  # (P, 2)

    # Clamp to the valid 1-based range, then convert to 0-based indices.
    # Without the lower bound, a coordinate of 0 would wrap to the far end.
    n_rows, n_cols = texture.shape[0], texture.shape[1]
    tex_rows = np.clip(pos[:, 0], 1, n_rows).astype(np.intp) - 1
    tex_cols = np.clip(pos[:, 1], 1, n_cols).astype(np.intp) - 1
    # Sample all three colour channels at the same positions.
    samples = texture[tex_rows, tex_cols, :3]

    # jind holds 1-based column-major (Fortran-order) linear indices into the
    # patch, so unravel them with order='F' to get (row, col) pairs.
    side = sizep + 1
    patch_rows, patch_cols = np.unravel_index(
        np.asarray(jind, dtype=np.intp) - 1, (side, side), order='F')
    J = np.zeros((side, side, 3))
    J[patch_rows, patch_cols, :] = samples
    return J
424 |
425 |
def calculateBarycentricInterpolationValues(self, sizep):
    """Barycentric weights and target indices for a (sizep+1)^2 sampling grid.

    The reference triangle has corners (sizep, sizep), (sizep, 0) and (0, 0),
    i.e. it lies in one half of the square, the only part the surface uses.
    Grid points come from ``np.meshgrid`` and are flattened in C order.

    Returns:
        (lambda1, lambda2, lambda3, jind): three (P, 1) weight columns with
        lambda1 + lambda2 + lambda3 == 1, and a (P,) array of 1-based linear
        indices (column-major over the mirrored grid) for each point, where
        P = (sizep + 1) ** 2.
    """
    ax, ay = sizep, sizep
    bx, by = sizep, 0
    cx, cy = 0, 0
    det = (ax - cx) * (by - cy) - (bx - cx) * (ay - cy)

    axis = np.arange(sizep + 1)
    gx, gy = np.meshgrid(axis, axis)
    gx = gx.reshape(-1, 1)
    gy = gy.reshape(-1, 1)

    w1 = ((by - cy) * (gx - cx) + (cx - bx) * (gy - cy)) / det
    w2 = ((cy - ay) * (gx - cx) + (ax - cx) * (gy - cy)) / det
    w3 = 1 - w1 - w2

    # Mirror both grid axes, then build 1-based linear indices into the patch.
    mirrored = sizep - axis + 1
    jx, jy = np.meshgrid(mirrored, mirrored)
    linear = (jx.ravel() - 1) + (jy.ravel() - 1) * (sizep + 1) + 1
    return w1, w2, w3, linear
448 |
--------------------------------------------------------------------------------